You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

algo.h 17 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452
  1. /**
  2. * \file dnn/src/cuda/convolution/backward_data/algo.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #pragma once
  13. #include <unordered_map>
  14. #include "src/common/algo_base.h"
  15. #include "src/common/metahelper.h"
  16. #include "src/cuda/convolution/helper.h"
  17. #include "src/cuda/cudnn_wrapper.h"
  18. namespace megdnn {
  19. namespace cuda {
  20. /*!
  21. * \brief base class for convolution algos
  22. *
  23. * All the algo impls should try to support non-contiguous batch dim, for group
  24. * conv execution.
  25. */
class ConvolutionBackwardDataImpl::AlgoBase : public Algorithm {
protected:
    //! protected non-virtual dtor: instances are never deleted through an
    //! AlgoBase pointer (they live inside AlgoPack for the process lifetime)
    ~AlgoBase() = default;

public:
    //! per-algo tag consumed by MEGDNN_DECL_ALGO_TYPE; used to identify the
    //! concrete algo type across (de)serialization
    enum class AlgoType : uint32_t {
        CUDA_CUDNN,
        CUDA_MATMUL,
        CUDA_CHANWISE,
        CUDA_CHANWISE_SMALL,
        CUDA_DEPTHWISE_LARGE_FILTER,
        CUDA_BFLOAT16,
        CUDA_GROUP_CONV_GENERAL,
        CUDA_IMPLICIT_GEMM_NCHW4_DOTPROD_INT8,
        CUDA_IMPLICIT_GEMM_NCHW_DOTPROD_INT8,
        CUDA_IMPLICIT_GEMM_NHWC_IMMA_INT8,
        CUDA_IMPLICIT_BATCHED_GEMM_FMA_NCHW_F32,
        CUDA_IMPLICIT_BATCHED_GEMM_HMMA_NCHW_F16,
    };
    //! maps algorithm descriptors back to algo instances
    using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;

    AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }

    //! layout-only arguments: everything needed to decide availability and
    //! workspace size without touching tensor data
    struct SizeArgs {
        HandleImpl* handle;
        CanonizedFilterMeta filter_meta;
        const TensorLayout *diff_layout, *grad_layout, *filter_layout;
        const ConvolutionBackwardDataImpl* opr;

        std::string to_string() const;

        //! fill the cuDNN descriptor pack from the stored layouts and the
        //! operator param
        void init_desc(convolution::CUDNNBwdDataDescs& desc) const {
            desc.set(filter_meta, *diff_layout, *grad_layout, opr->param());
        }
        SizeArgs(
                const ConvolutionBackwardDataImpl* opr, const TensorLayout& filter,
                const TensorLayout& diff, const TensorLayout& grad);
        SizeArgs(
                const ConvolutionBackwardDataImpl* opr, const TensorLayout& filter,
                const CanonizedFilterMeta& filter_meta, const TensorLayout& diff,
                const TensorLayout& grad);

        //! reinterpret as forward-conv size args; note grad and diff take the
        //! src/dst slots of the forward problem (bwd-data is the transposed
        //! convolution) -- confirm ordering against ForwardSizeArgs fields
        convolution::ForwardSizeArgs as_fwd_args() const {
            return {handle, grad_layout, filter_layout, filter_meta, diff_layout};
        }
    };

    //! SizeArgs extended with device tensors and the workspace for execution
    struct ExecArgs : public SizeArgs {
        const TensorND *filter_tensor, *diff_tensor, *grad_tensor;
        Workspace workspace;

        ExecArgs(
                const ConvolutionBackwardDataImpl* opr, _megdnn_tensor_in filter,
                _megdnn_tensor_in diff, _megdnn_tensor_out grad,
                _megdnn_workspace workspace);
    };

    virtual bool is_available(const SizeArgs& args) const = 0;
    virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0;
    virtual void exec(const ExecArgs& args) const = 0;

    //! available and fits into the given workspace limit (in bytes)
    bool is_available_wk(const SizeArgs& args, size_t limit) {
        return is_available(args) && get_workspace_in_bytes(args) <= limit;
    }

    //! availability additionally constrained by required/forbidden attributes
    //! and an optional workspace limit
    bool is_available_attribute(
            const SizeArgs& args,
            const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE,
            const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT,
            size_t limit = std::numeric_limits<size_t>::max()) {
        return contain_attribute_all(positive_attr) &&
               !contain_attribute_any(negative_attr) && is_available_wk(args, limit);
    }

    //! assert the provided workspace is large enough; returns *this so the
    //! call can be chained at the exec site
    AlgoBase& check_workspace(const SizeArgs& args, const Workspace& workspace) {
        auto req = get_workspace_in_bytes(args);
        megdnn_assert(
                req <= workspace.size,
                "conv bwd data algo %s: "
                "required workspace %zu bytes, got %zu",
                name(), req, workspace.size);
        return *this;
    }

    //! overridden to true only by AlgoCUDNN; lets heuristics special-case
    //! cuDNN-backed algos
    virtual bool is_cudnn() const { return false; }
};
  99. class ConvolutionBackwardDataImpl::AlgoCUDNN final : public AlgoBase {
  100. cudnnConvolutionBwdDataAlgo_t m_cudnn_enum;
  101. CudnnAlgoPack::Attr m_attr;
  102. public:
  103. AlgoCUDNN(cudnnConvolutionBwdDataAlgo_t cudnn_enum) : m_cudnn_enum(cudnn_enum) {
  104. megdnn_assert(
  105. CudnnAlgoPack::conv_bwd_data_algos().find(cudnn_enum) !=
  106. CudnnAlgoPack::conv_bwd_data_algos().end());
  107. m_attr = CudnnAlgoPack::conv_bwd_data_algos().at(cudnn_enum);
  108. }
  109. bool is_available(const SizeArgs& args) const override;
  110. size_t get_workspace_in_bytes(const SizeArgs& args) const override;
  111. void exec(const ExecArgs& args) const override;
  112. const char* name() const override { return m_attr.name.c_str(); }
  113. AlgoAttribute attribute() const override {
  114. auto ret = static_cast<AlgoAttribute>(0);
  115. if (m_attr.is_reproducible) {
  116. ret |= AlgoAttribute::REPRODUCIBLE;
  117. }
  118. if (m_attr.accuracy_depend_on_batch) {
  119. ret |= AlgoAttribute::ACCURACY_DEPEND_ON_BATCH;
  120. }
  121. return ret;
  122. }
  123. cudnnConvolutionBwdDataAlgo_t cudnn_enum() const { return m_cudnn_enum; }
  124. bool is_cudnn() const override { return true; }
  125. MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN)
  126. std::string param() const override {
  127. std::string ret;
  128. serialize_write_pod(m_cudnn_enum, ret);
  129. return ret;
  130. }
  131. };
//! im2col and matmul, with dilation
class ConvolutionBackwardDataImpl::AlgoMatmul final : public AlgoBase {
    //! helper template used by exec(); declared here, defined out of line
    template <typename T>
    static void exec_internal(const ExecArgs& args);

public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    //! sub-operators this algo delegates to, so heuristic search can
    //! recursively pick algos for them
    std::vector<SearchItem> get_subopr_list(
            const TensorLayoutArray& layouts, const OperatorBase* opr) const override;

    const char* name() const override { return "MATMUL"; }
    MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL)

    //! carries ACCURACY_DEPEND_ON_BATCH in addition to REPRODUCIBLE
    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::ACCURACY_DEPEND_ON_BATCH;
    }
};
//! dedicated kernel for channel-wise convolution backward data
class ConvolutionBackwardDataImpl::AlgoChanwise final : public AlgoBase {
public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    const char* name() const override { return "CHANNEL_WISE"; }
    MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE)
    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
};
//! channel-wise variant specialized for small problems; whether it applies
//! depends on the concrete shapes (hence USABLE_DEPEND_ON_SHAPE)
class ConvolutionBackwardDataImpl::AlgoChanwiseSmall final : public AlgoBase {
public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    const char* name() const override { return "CHANNEL_WISE_SMALL"; }
    MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE_SMALL)
    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
    }
};
//! depthwise convolution backward data specialized for large filter sizes
class ConvolutionBackwardDataImpl::AlgoDepthwiseLargeFilter final : public AlgoBase {
public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    const char* name() const override { return "DEPTHWISE_LARGE_FILTER"; }
    MEGDNN_DECL_ALGO_TYPE(CUDA_DEPTHWISE_LARGE_FILTER)
    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }

private:
    //! NOTE(review): not referenced by any declaration visible here (name()
    //! returns a string literal); verify whether the .cpp writes it before
    //! removing
    mutable std::string m_name;
};
//! bfloat16 support implemented via a delegated sub-operator (see
//! get_subopr_list) -- presumably computing in a wider dtype; confirm in the
//! .cpp implementation
class ConvolutionBackwardDataImpl::AlgoBFloat16 final : public AlgoBase {
public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    std::vector<SearchItem> get_subopr_list(
            const TensorLayoutArray& layouts, const OperatorBase* opr) const override;
    //! NOTE(review): "DATD" looks like a typo for "DATA", but this string is
    //! the persistent algorithm name; renaming would break matching against
    //! previously serialized/profiled results, so it must stay as-is
    const char* name() const override { return "CONVOLUTION_BACKWARD_DATD_BFLOAT16"; }
    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }

private:
    WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
    MEGDNN_DECL_ALGO_TYPE(CUDA_BFLOAT16)
};
//! implement group conv by delegating each group to another algo
class ConvolutionBackwardDataImpl::AlgoGroupConvGeneral final : public AlgoBase {
public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    //! the inner non-grouped operator whose algo is searched recursively
    std::vector<SearchItem> get_subopr_list(
            const TensorLayoutArray& layouts, const OperatorBase* opr) const override;
    const char* name() const override { return "CUDA:GROUP_CONV_BACKWARD_DATA"; }
    MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL)
    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }

private:
    WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
};
  206. class ConvolutionBackwardDataImpl::AlgoInt8NCHW4DotProdImplicitGemm final
  207. : public AlgoBase {
  208. public:
  209. struct AlgoParam {
  210. int threadblock_m;
  211. int threadblock_n;
  212. int threadblock_k;
  213. int warp_m;
  214. int warp_n;
  215. int warp_k;
  216. int stage;
  217. std::string to_string() {
  218. return ssprintf(
  219. "_%dX%dX%d_%dX%dX%d_%dstage", threadblock_m, threadblock_n,
  220. threadblock_k, warp_m, warp_n, warp_k, stage);
  221. }
  222. };
  223. AlgoInt8NCHW4DotProdImplicitGemm(AlgoParam algo_param)
  224. : m_algo_param{algo_param},
  225. m_name{ssprintf(
  226. "INT8_NCHW4_DOTPROD_IMPLICIT_GEMM%s",
  227. m_algo_param.to_string().c_str())} {}
  228. bool is_available(const SizeArgs& args) const override;
  229. size_t get_workspace_in_bytes(const SizeArgs& args) const override;
  230. void exec(const ExecArgs& args) const override;
  231. const char* name() const override { return m_name.c_str(); }
  232. AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
  233. MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_NCHW4_DOTPROD_INT8)
  234. private:
  235. WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, const SizeArgs& args) const;
  236. const void* get_available_op(const SizeArgs& args) const;
  237. AlgoParam m_algo_param;
  238. std::string m_name;
  239. };
  240. class ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm final
  241. : public AlgoBase {
  242. public:
  243. bool is_available(const SizeArgs& args) const override;
  244. size_t get_workspace_in_bytes(const SizeArgs& args) const override;
  245. void exec(const ExecArgs& args) const override;
  246. const char* name() const override { return "INT8_NCHW_DOTPROD_IMPLICIT_GEMM"; }
  247. AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
  248. MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_NCHW_DOTPROD_INT8);
  249. private:
  250. WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, const SizeArgs& args) const;
  251. const void* get_available_op(const SizeArgs& args) const;
  252. };
  253. class ConvolutionBackwardDataImpl::AlgoInt8NHWCIMMAImplicitGemm final
  254. : public AlgoBase {
  255. public:
  256. struct AlgoParam {
  257. int threadblock_m;
  258. int threadblock_n;
  259. int threadblock_k;
  260. int warp_m;
  261. int warp_n;
  262. int warp_k;
  263. int stage;
  264. int access_size;
  265. std::string to_string() {
  266. return ssprintf(
  267. "_%dX%dX%d_%dX%dX%d_%dstage_%d", threadblock_m, threadblock_n,
  268. threadblock_k, warp_m, warp_n, warp_k, stage, access_size);
  269. }
  270. };
  271. AlgoInt8NHWCIMMAImplicitGemm(AlgoParam algo_param)
  272. : m_algo_param{algo_param},
  273. m_name{ssprintf(
  274. "INT8_NHWC_IMMA_IMPLICIT_GEMM%s",
  275. m_algo_param.to_string().c_str())} {}
  276. bool is_available(const SizeArgs& args) const override;
  277. size_t get_workspace_in_bytes(const SizeArgs& args) const override;
  278. void exec(const ExecArgs& args) const override;
  279. const char* name() const override { return m_name.c_str(); }
  280. AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
  281. MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_NHWC_IMMA_INT8)
  282. private:
  283. WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, const SizeArgs& args) const;
  284. const void* get_available_op(const SizeArgs& args) const;
  285. void reorder_filter(
  286. const ExecArgs& args, const int iterleaved, int8_t* reordered_filter) const;
  287. AlgoParam m_algo_param;
  288. std::string m_name;
  289. };
  290. class ConvolutionBackwardDataImpl::AlgoFloat32NCHWFMAImplicitBatchedGemm final
  291. : public AlgoBase {
  292. public:
  293. struct AlgoParam {
  294. int threadblock_m;
  295. int threadblock_n;
  296. int threadblock_k;
  297. int warp_m;
  298. int warp_n;
  299. int warp_k;
  300. int stage;
  301. std::string to_string() {
  302. return ssprintf(
  303. "_%dX%dX%d_%dX%dX%d_%dstage", threadblock_m, threadblock_n,
  304. threadblock_k, warp_m, warp_n, warp_k, stage);
  305. }
  306. };
  307. AlgoFloat32NCHWFMAImplicitBatchedGemm(AlgoParam algo_param)
  308. : m_algo_param{algo_param},
  309. m_name{ssprintf(
  310. "FLOAT32_NCHW_FMA_IMPLICIT_BATCHED_GEMM%s",
  311. m_algo_param.to_string().c_str())} {}
  312. bool is_available(const SizeArgs& args) const override;
  313. size_t get_workspace_in_bytes(const SizeArgs& args) const override { return 0; }
  314. void exec(const ExecArgs& args) const override;
  315. const char* name() const override { return m_name.c_str(); }
  316. AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
  317. MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_BATCHED_GEMM_FMA_NCHW_F32)
  318. private:
  319. const void* get_available_op(const SizeArgs& args) const;
  320. AlgoParam m_algo_param;
  321. std::string m_name;
  322. };
  323. class ConvolutionBackwardDataImpl::AlgoFloat16NCHWHMMAImplicitBatchedGemm final
  324. : public AlgoBase {
  325. public:
  326. /// add instruction shape as member of algo param, because f16 tensor core has 2
  327. /// different matrix shapes (i.e. mma.884 and mma.1688)
  328. struct AlgoParam {
  329. int threadblock_m;
  330. int threadblock_n;
  331. int threadblock_k;
  332. int warp_m;
  333. int warp_n;
  334. int warp_k;
  335. int instruction_m;
  336. int instruction_n;
  337. int instruction_k;
  338. int stage;
  339. std::string to_string() {
  340. return ssprintf(
  341. "_%dX%dX%d_%dX%dX%d_mma%dX%dX%d_%dstage", threadblock_m,
  342. threadblock_n, threadblock_k, warp_m, warp_n, warp_k, instruction_m,
  343. instruction_n, instruction_k, stage);
  344. }
  345. };
  346. AlgoFloat16NCHWHMMAImplicitBatchedGemm(AlgoParam algo_param)
  347. : m_algo_param{algo_param},
  348. m_name{ssprintf(
  349. "FLOAT16_NCHW_HMMA_IMPLICIT_BATCHED_GEMM%s",
  350. m_algo_param.to_string().c_str())} {}
  351. bool is_available(const SizeArgs& args) const override;
  352. size_t get_workspace_in_bytes(const SizeArgs& args) const override { return 0; }
  353. void exec(const ExecArgs& args) const override;
  354. const char* name() const override { return m_name.c_str(); }
  355. AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
  356. MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_BATCHED_GEMM_HMMA_NCHW_F16)
  357. private:
  358. const void* get_available_op(const SizeArgs& args) const;
  359. AlgoParam m_algo_param;
  360. std::string m_name;
  361. };
//! container owning one instance of every backward-data algo; non-copyable
//! because algos are registered by address in m_all_algos_map
class ConvolutionBackwardDataImpl::AlgoPack : NonCopyableObj {
    // defined in cudnn.cpp
    void fill_cudnn_algos();
    // defined in implicit_gemm_int8_nchw4_dp4a.cpp
    void fill_int8_dp4a_algos();
    // defined in implicit_gemm_int8_nhwc_imma.cpp
    void fill_int8_imma_algos();
    void fill_dwconv_algos();

    //! descriptor -> algo lookup built by the constructor
    AlgoBase::Mapper m_all_algos_map;

public:
    AlgoPack();

    std::vector<AlgoCUDNN> cudnn;
    AlgoMatmul matmul;
    AlgoChanwise chanwise;
    AlgoChanwiseSmall chanwise_small;
    AlgoDepthwiseLargeFilter depthwise_large_filter;
    AlgoBFloat16 bfloat16;
    AlgoGroupConvGeneral group;
    std::vector<AlgoInt8NCHW4DotProdImplicitGemm> int8_nchw4_dotprod;
    AlgoInt8NCHWDotProdImplicitGemm int8_nchw_dotprod;
    std::vector<AlgoInt8NHWCIMMAImplicitGemm> int8_nhwc_imma;
    std::vector<AlgoFloat32NCHWFMAImplicitBatchedGemm> implbmm_nchw_fma;
    std::vector<AlgoFloat16NCHWHMMAImplicitBatchedGemm> implbmm_nchw_hmma;
    std::vector<AlgoBase*>
            //! all algorithms
            all_algos,
            //! non-cudnn algos, used for heuristic if cudnn is not supported
            non_cudnn_algos, bfloat16_algos, int8_algos;

    //! map a raw cuDNN enum back to the wrapping AlgoCUDNN instance
    AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdDataAlgo_t algo);
    const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
};
  393. } // namespace cuda
  394. } // namespace megdnn
  395. // vim: syntax=cpp.doxygen