You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

opr_impl.h 8.3 kB


  1. /**
  2. * \file dnn/src/cuda/convolution/opr_impl.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #pragma once
  13. #include "megdnn/oprs/nn.h"
  14. #include "src/common/utils.h"
  15. namespace megdnn {
  16. namespace cuda {
  17. class ConvolutionForwardImpl : public ConvolutionForward {
  18. public:
  19. using ConvolutionForward::ConvolutionForward;
  20. void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
  21. _megdnn_tensor_out dst,
  22. const PreprocessedFilter* preprocessed_filter,
  23. _megdnn_workspace workspace) override;
  24. size_t get_workspace_in_bytes(
  25. const TensorLayout& src, const TensorLayout& filter,
  26. const TensorLayout& dst,
  27. const PreprocessedFilter* preprocessed_filter) override;
  28. const char* get_algorithm_set_name() const override;
  29. SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
  30. const TensorLayout&, const TensorLayout&,
  31. const TensorLayout&) override {
  32. return {};
  33. }
  34. size_t get_preprocess_workspace_in_bytes(const TensorLayout&,
  35. const TensorLayout&,
  36. const TensorLayout&) override {
  37. return 0;
  38. }
  39. void exec_preprocess(const TensorLayout&, _megdnn_tensor_in,
  40. const TensorLayout&, PreprocessedFilter*,
  41. _megdnn_workspace) override {
  42. megdnn_throw("cuda exec_preprocess has not implemeted yet");
  43. }
  44. Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;
  45. class AlgoBase;
  46. class AlgoDefault;
  47. class AlgoPack;
  48. static const AlgoPack& algo_pack() { return sm_algo_pack; }
  49. protected:
  50. std::vector<Algorithm*> get_all_algorithms(
  51. const TensorLayout& src, const TensorLayout& filter,
  52. const TensorLayout& dst) override;
  53. Algorithm* get_algorithm_heuristic(const TensorLayout& src,
  54. const TensorLayout& filter,
  55. const TensorLayout& dst,
  56. size_t workspace_limit_in_bytes,
  57. const AlgoAttribute& attr) override;
  58. private:
  59. static AlgoPack sm_algo_pack;
  60. };
  61. class ConvolutionBackwardDataImpl : public ConvolutionBackwardData {
  62. public:
  63. using ConvolutionBackwardData::ConvolutionBackwardData;
  64. void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff,
  65. _megdnn_tensor_out grad, _megdnn_workspace workspace) override;
  66. AlgorithmInfo get_algorithm_info_heuristic(
  67. const TensorLayout& filter, const CanonizedFilterMeta& filter_meta,
  68. const TensorLayout& diff, const TensorLayout& grad,
  69. size_t workspace_limit_in_bytes, const AlgoAttribute& attr) {
  70. return get_algorithm_heuristic(filter, filter_meta, diff, grad,
  71. workspace_limit_in_bytes, attr)
  72. ->info();
  73. }
  74. AlgorithmInfo get_algorithm_info_heuristic(const TensorLayout& filter,
  75. const TensorLayout& diff,
  76. const TensorLayout& grad,
  77. size_t workspace_limit_in_bytes,
  78. const AlgoAttribute& attr) {
  79. return get_algorithm_heuristic(filter, diff, grad,
  80. workspace_limit_in_bytes, attr)
  81. ->info();
  82. }
  83. size_t get_workspace_in_bytes(const TensorLayout& filter,
  84. const TensorLayout& diff,
  85. const TensorLayout& grad) override;
  86. const char* get_algorithm_set_name() const override;
  87. class AlgoBase;
  88. class AlgoCUDNN;
  89. class AlgoMatmul;
  90. class AlgoChanwise;
  91. class AlgoChanwiseSmall;
  92. class AlgoGroupConvGeneral;
  93. class AlgoBFloat16;
  94. class AlgoInt8NCHW4DotProdImplicitGemm;
  95. class AlgoInt8NCHWDotProdImplicitGemm;
  96. class AlgoPack;
  97. static const AlgoPack& algo_pack() { return sm_algo_pack; }
  98. Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;
  99. protected:
  100. std::vector<Algorithm*> get_all_algorithms(
  101. const TensorLayout& filter, const TensorLayout& diff,
  102. const TensorLayout& grad) override;
  103. Algorithm* get_algorithm_heuristic(const TensorLayout& filter,
  104. const TensorLayout& diff,
  105. const TensorLayout& grad,
  106. size_t workspace_limit_in_bytes,
  107. const AlgoAttribute& attr) override;
  108. private:
  109. Algorithm* get_algorithm_heuristic(const TensorLayout& filter,
  110. const CanonizedFilterMeta& filter_meta,
  111. const TensorLayout& diff,
  112. const TensorLayout& grad,
  113. size_t workspace_limit_in_bytes,
  114. const AlgoAttribute& attr);
  115. static AlgoPack sm_algo_pack;
  116. };
  117. class ConvolutionBackwardFilterImpl : public ConvolutionBackwardFilter {
  118. public:
  119. using ConvolutionBackwardFilter::ConvolutionBackwardFilter;
  120. void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff,
  121. _megdnn_tensor_out grad, _megdnn_workspace workspace) override;
  122. size_t get_workspace_in_bytes(const TensorLayout& src,
  123. const TensorLayout& diff,
  124. const TensorLayout& grad) override;
  125. AlgorithmInfo get_algorithm_info_heuristic(
  126. const TensorLayout& src, const TensorLayout& diff,
  127. const TensorLayout& grad, const CanonizedFilterMeta& grad_meta,
  128. size_t workspace_limit_in_bytes, const AlgoAttribute& attr) {
  129. return get_algorithm_heuristic(src, diff, grad, grad_meta,
  130. workspace_limit_in_bytes, attr)
  131. ->info();
  132. }
  133. AlgorithmInfo get_algorithm_info_heuristic(const TensorLayout& filter,
  134. const TensorLayout& diff,
  135. const TensorLayout& grad,
  136. size_t workspace_limit_in_bytes,
  137. const AlgoAttribute& attr) {
  138. return get_algorithm_heuristic(filter, diff, grad,
  139. workspace_limit_in_bytes, attr)
  140. ->info();
  141. }
  142. const char* get_algorithm_set_name() const override;
  143. class AlgoBase;
  144. class AlgoCUDNN;
  145. class AlgoMatmul;
  146. class AlgoChanwise;
  147. class AlgoGroupConvGeneral;
  148. class AlgoBFloat16;
  149. class AlgoPack;
  150. static const AlgoPack& algo_pack() { return sm_algo_pack; }
  151. Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;
  152. protected:
  153. std::vector<Algorithm*> get_all_algorithms(
  154. const TensorLayout& src, const TensorLayout& diff,
  155. const TensorLayout& grad) override;
  156. Algorithm* get_algorithm_heuristic(const TensorLayout& src,
  157. const TensorLayout& diff,
  158. const TensorLayout& grad,
  159. size_t workspace_limit_in_bytes,
  160. const AlgoAttribute& attr) override;
  161. private:
  162. Algorithm* get_algorithm_heuristic(const TensorLayout& src,
  163. const TensorLayout& diff,
  164. const TensorLayout& grad,
  165. const CanonizedFilterMeta& grad_meta,
  166. size_t workspace_limit_in_bytes,
  167. const AlgoAttribute& attr);
  168. static AlgoPack sm_algo_pack;
  169. };
  170. } // namespace cuda
  171. } // namespace megdnn
  172. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台