
algos.h 3.6 kB

/**
 * \file dnn/src/cuda/convolution/forward/algos.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
#pragma once

#include "megdnn/oprs.h"
#include "src/common/algo_base.h"
#include "src/common/metahelper.h"
#include "src/common/utils.h"
#include "src/cuda/convolution/opr_impl.h"

#include <unordered_map>

namespace megdnn {
namespace cuda {

/*!
 * \brief base class for ConvolutionForward algos
 */
class ConvolutionForwardImpl::AlgoBase : public Algorithm {
protected:
    ~AlgoBase() = default;

public:
    enum class AlgoType : uint32_t {
        CUDA_DEFAULT,
    };
    using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;

    AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }

    struct SizeArgs {
        ConvolutionForwardImpl* opr;
        const TensorLayout *layout_src, *layout_filter, *layout_dst;

        std::string to_string() const;
        SizeArgs(ConvolutionForwardImpl* opr, const TensorLayout& src,
                 const TensorLayout& filter, const TensorLayout& dst);
    };

    struct ExecArgs : public SizeArgs {
        TensorND tensor_src, tensor_filter, tensor_dst;
        Workspace workspace;

        ExecArgs(ConvolutionForwardImpl* opr, _megdnn_tensor_in src,
                 _megdnn_tensor_in filter, _megdnn_tensor_out dst,
                 _megdnn_workspace workspace);
    };

    virtual bool is_available(const SizeArgs& args) const = 0;
    virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0;
    virtual void exec(const ExecArgs&) const = 0;

    bool is_available_wk(const SizeArgs& args, size_t limit) const {
        return is_available(args) && get_workspace_in_bytes(args) <= limit;
    }

    bool is_available_reproducible(
            const SizeArgs& args, bool reproducible = true,
            size_t limit = std::numeric_limits<size_t>::max()) const {
        return (!reproducible || is_reproducible()) &&
               is_available_wk(args, limit);
    }

    AlgoBase& check_workspace(const SizeArgs& args,
                              const Workspace& workspace) {
        auto req = get_workspace_in_bytes(args);
        megdnn_assert(req <= workspace.size,
                      "convolution fwd algo %s: required workspace %zu bytes, "
                      "got %zu",
                      name(), req, workspace.size);
        return *this;
    }
};

class ConvolutionForwardImpl::AlgoDefault final : public AlgoBase {
public:
    AlgoDefault() = default;
    bool is_available(const SizeArgs&) const override;
    size_t get_workspace_in_bytes(const SizeArgs& /* args */) const override;
    const char* name() const override { return "DEFAULT"; }
    void exec(const ExecArgs&) const override;
    bool is_reproducible() const override { return true; }
    std::vector<SearchItem> get_subopr_list(
            const TensorLayoutArray& layouts,
            const OperatorBase* opr) const override;

    MEGDNN_DECL_ALGO_TYPE(CUDA_DEFAULT)
};

class ConvolutionForwardImpl::AlgoPack : NonCopyableObj {
private:
    AlgoBase::Mapper m_all_algos_map;

public:
    AlgoPack();

    AlgoDefault algo_default;

    std::vector<AlgoBase*> all_algos;

    const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
};

}  // namespace cuda
}  // namespace megdnn

// vim: syntax=cpp.doxygen
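
For orientation, a new algorithm for this operator would follow the same pattern as AlgoDefault above: subclass AlgoBase, implement the three pure virtual methods, add an entry to the AlgoType enum, and register an instance in AlgoPack. The sketch below is illustrative only; AlgoExample and CUDA_EXAMPLE are hypothetical names and are not part of this header.

// Hypothetical sketch only: AlgoExample and CUDA_EXAMPLE do not exist in
// MegEngine; CUDA_EXAMPLE would also need an entry in AlgoBase::AlgoType.
class ConvolutionForwardImpl::AlgoExample final : public AlgoBase {
public:
    bool is_available(const SizeArgs& args) const override {
        // restrict the algo to cases it can handle, e.g. contiguous inputs
        return args.layout_src->is_contiguous();
    }
    size_t get_workspace_in_bytes(const SizeArgs& /* args */) const override {
        return 0;  // this sketch needs no extra device memory
    }
    void exec(const ExecArgs& args) const override {
        // launch the CUDA kernels here, using args.tensor_src / tensor_filter /
        // tensor_dst and args.workspace
    }
    const char* name() const override { return "EXAMPLE"; }
    bool is_reproducible() const override { return true; }
    MEGDNN_DECL_ALGO_TYPE(CUDA_EXAMPLE)
};

AlgoPack would then hold an AlgoExample member and push its address into all_algos, so the dispatcher can enumerate it alongside algo_default.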

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU and GPU build. If you want to run GPU programs, make sure the machine has a GPU installed along with a working driver. If you would like to try deep learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.
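
Independently of MegEngine, a quick way to confirm that the machine actually exposes a usable GPU and driver is to query the CUDA runtime directly. A minimal sketch (plain CUDA runtime API, not a MegEngine call):

#include <cstdio>
#include <cuda_runtime.h>

int main() {
    int count = 0;
    // cudaGetDeviceCount fails if no driver or no CUDA-capable device is present
    cudaError_t err = cudaGetDeviceCount(&count);
    if (err != cudaSuccess || count == 0) {
        std::printf("no usable CUDA device (count=%d, status: %s)\n",
                    count, cudaGetErrorString(err));
        return 1;
    }
    std::printf("found %d CUDA device(s)\n", count);
    return 0;
}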