You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

algo.h 23 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636
  1. /**
  2. * \file dnn/src/cuda/conv_bias/algo.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
  12. #include "megdnn/oprs.h"
  13. #include "src/common/utils.h"
  14. #include "src/cuda/conv_bias/conv_bias_int8.cuh"
  15. #include "src/cuda/conv_bias/helper.h"
  16. #include "src/cuda/conv_bias/opr_impl.h"
  17. #include "src/cuda/convolution_helper/parameter.cuh"
  18. #include "src/cuda/handle.h"
  19. #include <cuda.h>
  20. #include <memory>
  21. #include <unordered_map>
  22. namespace megdnn {
  23. namespace cuda {
  24. /*!
  25. * \brief base class for conv bias algos
  26. *
  27. * All the algo impls should try to support non-contiguous batch dim, for group
  28. * conv execution.
  29. */
class ConvBiasForwardImpl::AlgoBase : public Algorithm {
protected:
    //! non-virtual protected dtor: algos are owned/destroyed via AlgoPack,
    //! never deleted through an AlgoBase pointer
    ~AlgoBase() = default;

public:
    AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }
    //! layouts plus the owning operator; consumed by is_available() /
    //! get_workspace_in_bytes() queries
    struct SizeArgs : public conv_bias::BiasForwardSizeArgs {
        ConvBiasForwardImpl* opr;  //! operator carrying param() and handle

        std::string to_string() const;
        SizeArgs(ConvBiasForwardImpl* opr, const TensorLayout& src,
                 const TensorLayout& filter, const TensorLayout& bias,
                 const TensorLayout& z, const TensorLayout& dst);
        SizeArgs(ConvBiasForwardImpl* opr, const TensorLayout& src,
                 const TensorLayout& filter,
                 const CanonizedFilterMeta& filter_meta,
                 const TensorLayout& bias, const TensorLayout& z,
                 const TensorLayout& dst);
        //! fill cuDNN descriptors for the fused conv+bias(+z) path
        void init_conv_bias_desc(conv_bias::CUDNNForwardDescs& desc) const {
            desc.set_conv_bias(*src_layout, filter_meta, *dst_layout,
                               *bias_layout, *z_layout, opr->param());
        }
        //! fill cuDNN descriptors for a plain convolution (no fused bias)
        void init_conv_desc(conv_bias::CUDNNForwardDescs& desc) const {
            desc.set_conv(*src_layout, filter_meta, *dst_layout, opr->param());
        }
    };
    //! SizeArgs extended with concrete tensors and a workspace for exec()
    struct ExecArgs : public SizeArgs {
        const TensorND *src_tensor, *filter_tensor, *bias_tensor, *z_tensor,
                *dst_tensor;
        Workspace workspace;

        ExecArgs(ConvBiasForwardImpl* opr, _megdnn_tensor_in src,
                 _megdnn_tensor_in filter, _megdnn_tensor_in bias,
                 _megdnn_tensor_in z, _megdnn_tensor_out dst,
                 _megdnn_workspace workspace);
    };
    virtual bool is_available(const SizeArgs& args) const = 0;
    virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0;
    virtual void exec(const ExecArgs& args) const = 0;

    //! whether the algo is applicable and fits within \p limit workspace bytes
    bool is_available_wk(const SizeArgs& args, size_t limit) {
        return is_available(args) && get_workspace_in_bytes(args) <= limit;
    }
    //! availability check that can additionally require reproducibility
    bool is_available_reproducible(
            const SizeArgs& args, bool reproducible = true,
            size_t limit = std::numeric_limits<size_t>::max()) {
        return (!reproducible || is_reproducible()) &&
               is_available_wk(args, limit);
    }
    //! assert that \p workspace is large enough; returns *this for chaining
    AlgoBase& check_workspace(const SizeArgs& args,
                              const Workspace& workspace) {
        auto req = get_workspace_in_bytes(args);
        megdnn_assert(
                req <= workspace.size,
                "conv bias fwd algo %s: required workspace %zu bytes, got %zu",
                name(), req, workspace.size);
        return *this;
    }

    //! true for algos that dispatch into cuDNN (used by heuristics)
    virtual bool is_cudnn() const { return false; }
};
  86. class ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation final : public AlgoBase {
  87. public:
  88. AlgoCUDNNConvBiasActivation(bool is_reproducible, const char* name,
  89. cudnnConvolutionFwdAlgo_t cudnn_enum)
  90. : m_is_reproducible(is_reproducible),
  91. m_name(ConvBiasForward::algo_name<DefaultParam>(name, {})),
  92. m_cudnn_enum(cudnn_enum) {}
  93. size_t get_workspace_in_bytes(const SizeArgs& args) const override;
  94. void exec(const ExecArgs& args) const override;
  95. param::Convolution get_param_convolution(const SizeArgs& args) const;
  96. bool is_available(const SizeArgs&) const override;
  97. const char* name() const override { return m_name.c_str(); }
  98. bool is_reproducible() const override { return m_is_reproducible; }
  99. cudnnConvolutionFwdAlgo_t cudnn_enum() { return m_cudnn_enum; }
  100. bool is_cudnn() const override { return true; }
  101. private:
  102. bool m_is_reproducible;
  103. std::string m_name;
  104. cudnnConvolutionFwdAlgo_t m_cudnn_enum;
  105. };
//! channel-wise convolution (algo name: CHANNEL_WISE)
class ConvBiasForwardImpl::AlgoChanwise final : public AlgoBase {
public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    //! name is built lazily on first call; m_name is mutable for this reason
    const char* name() const override {
        if (m_name.empty()) {
            m_name =
                    ConvBiasForward::algo_name<DirectParam>("CHANNEL_WISE", {});
        }
        return m_name.c_str();
    }
    bool is_reproducible() const override { return true; }

private:
    mutable std::string m_name;
};
//! channel-wise convolution variant for small spatial sizes
//! (algo name: CHANNEL_WISE_SMALL)
class ConvBiasForwardImpl::AlgoChanwiseSmall final : public AlgoBase {
public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    //! name is built lazily on first call; m_name is mutable for this reason
    const char* name() const override {
        if (m_name.empty()) {
            m_name = ConvBiasForward::algo_name<DirectParam>(
                    "CHANNEL_WISE_SMALL", {});
        }
        return m_name.c_str();
    }
    bool is_reproducible() const override { return true; }

private:
    mutable std::string m_name;
};
//! channel-wise convolution with 8x8x32 (int8 in, int32 accumulate) compute
//! (algo name: CHANNEL_WISE_8X8X32)
class ConvBiasForwardImpl::AlgoChanwise8x8x32 final : public AlgoBase {
public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    //! name is built lazily on first call; m_name is mutable for this reason
    const char* name() const override {
        if (m_name.empty()) {
            m_name = ConvBiasForward::algo_name<DirectParam>(
                    "CHANNEL_WISE_8X8X32", {});
        }
        return m_name.c_str();
    }
    bool is_reproducible() const override { return true; }

private:
    mutable std::string m_name;
};
//! plain cuDNN convolution; bias/activation applied outside the cuDNN call
class ConvBiasForwardImpl::AlgoCUDNNConv final : public AlgoBase {
public:
    AlgoCUDNNConv(bool is_reproducible, const char* name,
                  cudnnConvolutionFwdAlgo_t cudnn_enum)
            : m_is_reproducible(is_reproducible),
              m_name(ConvBiasForward::algo_name<DefaultParam>(name, {})),
              m_cudnn_enum(cudnn_enum) {}

    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    bool is_reproducible() const override { return m_is_reproducible; }
    const char* name() const override { return m_name.c_str(); }
    //! the wrapped cuDNN forward algorithm enum
    cudnnConvolutionFwdAlgo_t cudnn_enum() const { return m_cudnn_enum; }
    bool is_cudnn() const override { return true; }

private:
    bool m_is_reproducible;
    std::string m_name;
    cudnnConvolutionFwdAlgo_t m_cudnn_enum;

    WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
};
  174. //! compute small matmul in the kernel
//! compute small matmul in the kernel (algo name: INPLACE_MATMUL)
class ConvBiasForwardImpl::AlgoInplaceMatmul final : public AlgoBase {
public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    //! name is built lazily on first call; m_name is mutable for this reason
    const char* name() const override {
        if (m_name.empty()) {
            m_name = ConvBiasForward::algo_name<ConvBias::MatmulParam>(
                    "INPLACE_MATMUL", {});
        }
        return m_name.c_str();
    }
    bool is_reproducible() const override { return true; }

private:
    mutable std::string m_name;
};
  191. //! im2col and matmul, with dilation
//! im2col and matmul, with dilation (algo name: MATMUL)
class ConvBiasForwardImpl::AlgoMatmul final : public AlgoBase {
    //! typed exec helper; T selects the element dtype at the call site
    template <typename T>
    static void exec_internal(const ExecArgs& args,
                              const WorkspaceBundle& bundle);

public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    //! name is built lazily on first call; m_name is mutable for this reason
    const char* name() const override {
        if (m_name.empty()) {
            m_name = ConvBiasForward::algo_name<ConvBiasForward::MatmulParam>(
                    "MATMUL", {});
        }
        return m_name.c_str();
    }
    bool is_reproducible() const override { return true; }

private:
    WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
    mutable std::string m_name;
};
//! matmul-based conv with 8x8x32 (int8 in, int32 accumulate) compute
//! (algo name: MATMUL8X8X32)
class ConvBiasForwardImpl::AlgoMatmul8x8x32 final : public AlgoBase {
public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    //! name is built lazily on first call; m_name is mutable for this reason
    const char* name() const override {
        if (m_name.empty()) {
            m_name = ConvBiasForward::algo_name<ConvBiasForward::MatmulParam>(
                    "MATMUL8X8X32", {});
        }
        return m_name.c_str();
    }
    bool is_reproducible() const override { return true; }

private:
    //! whether src must be unrolled into workspace before the matmul
    bool need_src_unroll(const SizeArgs& args) const;
    //! whether the filter must be reshaped before the matmul
    bool need_filter_reshape(const SizeArgs& args) const;
    //! workspace layout / exec specialized per tensor format
    template <Param::Format>
    WorkspaceBundle get_bundle(const SizeArgs& args) const;
    template <Param::Format>
    void exec_internal(const ExecArgs& args) const;
    mutable std::string m_name;
};
  234. //! optimized 1x1 conv
//! optimized 1x1 conv (algo name: MATMUL1X1)
class ConvBiasForwardImpl::Algo1x1 final : public AlgoBase {
    //! derive matmul operand layouts A*B=C from the conv layouts in \p args
    static void extract_matmul_layouts(const SizeArgs& args, TensorLayout& A,
                                       TensorLayout& B, TensorLayout& C);

public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    //! name is built lazily on first call; m_name is mutable for this reason
    const char* name() const override {
        if (m_name.empty()) {
            m_name = ConvBiasForward::algo_name<ConvBiasForward::MatmulParam>(
                    "MATMUL1X1", {});
        }
        return m_name.c_str();
    }
    bool is_reproducible() const override { return true; }

private:
    WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
    mutable std::string m_name;
};
//! conv expressed as a batched matmul (algo name: BATCHEDMATMUL)
class ConvBiasForwardImpl::AlgoBatchedMatmul final : public AlgoBase {
    //! derive batched-matmul operand layouts A*B=C from the conv layouts
    static void extract_matmul_layouts(const SizeArgs& args, TensorLayout& A,
                                       TensorLayout& B, TensorLayout& C);

public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    //! name is built lazily on first call; m_name is mutable for this reason
    const char* name() const override {
        if (m_name.empty()) {
            m_name = ConvBiasForward::algo_name<ConvBiasForward::MatmulParam>(
                    "BATCHEDMATMUL", {});
        }
        return m_name.c_str();
    }
    bool is_reproducible() const override { return true; }

private:
    WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
    mutable std::string m_name;
};
  273. //! implement group conv by another algo
//! implement group conv by another algo: wraps \p impl and invokes it once
//! per group with adjusted layouts
class ConvBiasForwardImpl::AlgoGroupConvGeneral final : public AlgoBase {
public:
    AlgoGroupConvGeneral(AlgoBase* impl);

    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    const char* name() const override { return m_name.c_str(); }
    //! reproducibility is delegated to the wrapped algo
    bool is_reproducible() const override { return m_impl->is_reproducible(); }

    //! rewrite \p args layouts to describe a single group
    static void modify_size_args(SizeArgs& args, TensorLayout& src_pg,
                                 TensorLayout& dst_pg, TensorLayout& bias_pg);

private:
    WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
    AlgoBase* m_impl;  //! non-owning; lifetime managed by AlgoPack
    std::string m_name;
};
  289. #if CUDA_VERSION >= 10000
//! quantized uint4 x uint4 -> int32 conv using WMMA (tensor core) intrinsics;
//! guarded by CUDA_VERSION >= 10000 at the enclosing #if
class ConvBiasForwardImpl::AlgoQUInt4x4x32WMMA final : public AlgoBase {
public:
    AlgoQUInt4x4x32WMMA() = default;

    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    const char* name() const override { return "QUINT4x4x32_WMMA"; }
    bool is_reproducible() const override { return true; }

private:
    WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, const SizeArgs& args) const;
    //! whether the fhxfw kernel variant is applicable for \p args
    bool use_kernel_fhxfw(const SizeArgs& args) const;
    //! workspace needed by the conv stage only (excludes other bundle parts)
    size_t get_workspace_in_bytes_do_conv(const SizeArgs& args) const;
};
  303. #endif
//! int8 implicit-GEMM conv in CHWN4 layout using dp4a dot product
class ConvBiasForwardImpl::AlgoInt8CHWN4DotProdImplicitGemm final
        : public AlgoBase {
public:
    AlgoInt8CHWN4DotProdImplicitGemm() = default;

    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    const char* name() const override {
        return "INT8_CHWN4_DOTPROD_IMPLICIT_GEMM";
    }
    bool is_reproducible() const override { return true; }

    //! launch the kernel specialized for \p nonlinear_mode; alpha/beta/gamma
    //! scale src*filter, z and bias terms respectively
    template <typename BiasVisitor>
    static void dispatch_nonlinear_mode(
            const int8_t* d_src, const int8_t* d_filter,
            BiasVisitor bias_visitor, const int8_t* d_z, int8_t* d_dst,
            const convolution::ConvParam& param, float alpha, float beta,
            float gamma, float scale, cudaStream_t stream,
            param::ConvBias::NonlineMode nonlinear_mode);
};
  323. class ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm final
  324. : public AlgoBase {
  325. public:
  326. struct AlgoParam {
  327. int threadblock_m;
  328. int threadblock_n;
  329. int threadblock_k;
  330. int warp_m;
  331. int warp_n;
  332. int warp_k;
  333. std::string to_string() {
  334. /// default algorithm
  335. if (threadblock_m == 128 && threadblock_n == 128 &&
  336. threadblock_k == 32 && warp_m == 32 && warp_n == 64 &&
  337. warp_k == 32) {
  338. return "";
  339. }
  340. return ssprintf("_%dX%dX%d_%dX%dX%d", threadblock_m, threadblock_n,
  341. threadblock_k, warp_m, warp_n, warp_k);
  342. }
  343. };
  344. AlgoInt8NCHW4DotProdImplicitGemm(AlgoParam algo_param)
  345. : m_algo_param{algo_param},
  346. m_name{ssprintf("INT8_NCHW4_DOTPROD_IMPLICIT_GEMM%s",
  347. m_algo_param.to_string().c_str())} {}
  348. bool is_available(const SizeArgs& args) const override;
  349. size_t get_workspace_in_bytes(const SizeArgs& args) const override;
  350. void exec(const ExecArgs& args) const override;
  351. const char* name() const override { return m_name.c_str(); }
  352. bool is_reproducible() const override { return true; }
  353. private:
  354. WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr,
  355. const SizeArgs& args) const;
  356. AlgoParam m_algo_param;
  357. std::string m_name;
  358. };
  359. #if CUDA_VERSION >= 10000
//! int8 implicit-GEMM conv in CHWN4 layout using IMMA (tensor core int8)
//! instructions; guarded by CUDA_VERSION >= 10000 at the enclosing #if
class ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemm final
        : public AlgoBase {
public:
    //! supported IMMA fragment shapes (m x n x k)
    enum class MMATileSize : uint32_t {
        IMMA16x16x16,
        IMMA32x8x16,
        IMMA8x32x16
    };

    AlgoInt8CHWN4IMMAImplicitGemm(MMATileSize mma_tile_size)
            : m_mma_tile_size{mma_tile_size},
              m_name{"INT8_CHWN4_IMMA_IMPLICIT_GEMM_" +
                     to_string(m_mma_tile_size)} {}

    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    const char* name() const override {
        return m_name.c_str();
    }
    bool is_reproducible() const override { return true; }

    //! launch the kernel specialized for \p nonlinear_mode and
    //! \p mma_tile_size; alpha/beta/gamma scale src*filter, z and bias terms
    template <typename BiasVisitor>
    static void dispatch_nonlinear_mode(
            const int8_t* d_src, const int8_t* d_filter,
            BiasVisitor bias_visitor, int8_t* d_z, int8_t* d_dst,
            const convolution::ConvParam& param, float alpha, float beta,
            float gamma, float scale, cudaStream_t stream,
            param::ConvBias::NonlineMode nonlinear_mode,
            MMATileSize mma_tile_size);

    //! human-readable suffix for a tile size, used in algo names
    static std::string to_string(MMATileSize mma_tile_size);

private:
    MMATileSize m_mma_tile_size;
    std::string m_name;
};
//! int8 implicit-GEMM conv in NCHW4 layout using IMMA instructions
class ConvBiasForwardImpl::AlgoInt8NCHW4IMMAImplicitGemm final
        : public AlgoBase {
public:
    //! reuse the tile-size enum of the CHWN4 IMMA algo
    using MMATileSize = AlgoInt8CHWN4IMMAImplicitGemm::MMATileSize;

    AlgoInt8NCHW4IMMAImplicitGemm(MMATileSize mma_tile_size)
            : m_mma_tile_size{mma_tile_size},
              m_name{"INT8_NCHW4_IMMA_IMPLICIT_GEMM_" +
                     AlgoInt8CHWN4IMMAImplicitGemm::to_string(
                             m_mma_tile_size)} {}

    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    const char* name() const override {
        return m_name.c_str();
    }
    bool is_reproducible() const override { return true; }

private:
    WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr,
                                         const SizeArgs& args) const;
    MMATileSize m_mma_tile_size;
    std::string m_name;
};
//! CHWN4 IMMA variant that reorders the filter for better access patterns
class ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemmReorderFilter final
        : public AlgoBase {
public:
    //! reuse the tile-size enum of the CHWN4 IMMA algo
    using MMATileSize = AlgoInt8CHWN4IMMAImplicitGemm::MMATileSize;

    AlgoInt8CHWN4IMMAImplicitGemmReorderFilter(MMATileSize mma_tile_size)
            : m_mma_tile_size{mma_tile_size},
              m_name{"INT8_CHWN4_IMMA_IMPLICIT_GEMM_REORDER_FILTER_" +
                     AlgoInt8CHWN4IMMAImplicitGemm::to_string(
                             m_mma_tile_size)} {}

    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    const char* name() const override { return m_name.c_str(); }
    bool is_reproducible() const override { return true; }

private:
    MMATileSize m_mma_tile_size;
    std::string m_name;
};
//! CHWN4 IMMA variant that unrolls along the width dimension
class ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth final
        : public AlgoBase {
public:
    //! reuse the tile-size enum of the CHWN4 IMMA algo
    using MMATileSize = AlgoInt8CHWN4IMMAImplicitGemm::MMATileSize;

    AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth(MMATileSize mma_tile_size)
            : m_mma_tile_size{mma_tile_size},
              m_name{"INT8_CHWN4_IMMA_IMPLICIT_GEMM_UNROLL_WIDTH_" +
                     AlgoInt8CHWN4IMMAImplicitGemm::to_string(
                             m_mma_tile_size)} {}

    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    const char* name() const override { return m_name.c_str(); }
    bool is_reproducible() const override { return true; }

private:
    MMATileSize m_mma_tile_size;
    std::string m_name;
};
  450. #endif
  451. #if CUDA_VERSION >= 10020
//! int8 implicit-GEMM conv in NCHW32 layout using IMMA instructions;
//! guarded by CUDA_VERSION >= 10020 at the enclosing #if
class ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm final
        : public AlgoBase {
public:
    //! threadblock/warp tiling configuration of the underlying GEMM kernel
    struct AlgoParam {
        int threadblock_m;
        int threadblock_n;
        int threadblock_k;
        int warp_m;
        int warp_n;
        int warp_k;
    };

    AlgoInt8NCHW32IMMAImplicitGemm(AlgoParam algo_param)
            : m_algo_param{algo_param} {
        m_name = ConvBias::algo_name<ConvBias::DirectParam>(
                ssprintf("INT8_NCHW32_IMMA_IMPLICIT_GEMM_%s",
                         to_string(m_algo_param).c_str()),
                ConvBias::DirectParam{});
    }

    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    const char* name() const override { return m_name.c_str(); }
    bool is_reproducible() const override { return true; }

    //! human-readable suffix for a tiling config, used in the algo name
    static std::string to_string(AlgoParam algo_param);

private:
    WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr,
                                         const SizeArgs& args) const;
    AlgoParam m_algo_param;
    std::string m_name;
};
  482. #endif
//! bfloat16 conv implemented by converting operands to float and delegating
//! to the wrapped float algo \p impl
class ConvBiasForwardImpl::AlgoBFloat16 final : public AlgoBase {
public:
    AlgoBFloat16(AlgoBase* impl);

    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    const char* name() const override { return m_name.c_str(); }
    //! reproducibility is delegated to the wrapped algo
    bool is_reproducible() const override { return m_impl->is_reproducible(); }

private:
    //! build SizeArgs with float layouts written into fsrc/ffilter/fbias/fz/fdst
    SizeArgs float_args(const SizeArgs& args, ConvBiasForwardImpl* opr,
                        TensorLayout& fsrc, TensorLayout& ffilter,
                        TensorLayout& fbias, TensorLayout& fz,
                        TensorLayout& fdst) const;
    WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
    AlgoBase* m_impl;  //! non-owning; lifetime managed by AlgoPack
    std::string m_name;
};
//! owner/registry of all conv-bias algo instances; non-copyable
class ConvBiasForwardImpl::AlgoPack {
    AlgoPack(const AlgoPack&) = delete;
    AlgoPack& operator=(const AlgoPack&) = delete;

public:
    AlgoPack();

    std::vector<AlgoBase*> all_algos,
            //! non-cudnn algos, used for heuristic if cudnn is not supported
            non_cudnn_algos,
            bfloat16_algos;
    std::vector<AlgoCUDNNConvBiasActivation> cudnn_conv_bias_activations;
    std::vector<AlgoCUDNNConv> cudnn_convs;
    AlgoChanwise chanwise;
    AlgoChanwiseSmall chanwise_small;
    AlgoChanwise8x8x32 chanwise8x8x32;
    AlgoInplaceMatmul inplace_matmul;
    AlgoMatmul matmul;
    AlgoMatmul8x8x32 matmul8x8x32;
    AlgoBatchedMatmul batched_matmul;
    Algo1x1 a1x1;
    std::vector<AlgoInt8NCHW4DotProdImplicitGemm> int8_nchw4_dotprod;
    AlgoInt8CHWN4DotProdImplicitGemm int8_chwn4_dotprod;
#if CUDA_VERSION >= 10000
    AlgoQUInt4x4x32WMMA wmma_quint4x4x32;
    std::vector<AlgoInt8CHWN4IMMAImplicitGemm> int8_chwn4_imma;
    std::vector<AlgoInt8NCHW4IMMAImplicitGemm> int8_nchw4_imma;
    std::vector<AlgoInt8CHWN4IMMAImplicitGemmReorderFilter>
            int8_chwn4_imma_reorder_filter;
    std::vector<AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth>
            int8_chwn4_imma_unroll_width;
#endif
#if CUDA_VERSION >= 10020
    std::vector<AlgoInt8NCHW32IMMAImplicitGemm> int8_nchw32_imma;
#endif
    //! owning storage for wrapper algos created at pack construction
    std::vector<std::unique_ptr<AlgoGroupConvGeneral>> gconv_refhold;
    std::vector<std::unique_ptr<AlgoBFloat16>> bfloat16_refhold;
    //! map from a base algo to its group-conv wrapper
    std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv;

    //! look up the registered algo wrapping a cuDNN enum (conv+bias path)
    AlgoBase* cudnn_conv_bias_act_from_enum(cudnnConvolutionFwdAlgo_t algo);
    //! look up the registered algo wrapping a cuDNN enum (plain conv path)
    AlgoBase* cudnn_conv_from_enum(cudnnConvolutionFwdAlgo_t algo);

private:
#if CUDA_VERSION >= 10000
    void fill_imma_algos();
#endif
    void fill_cudnn_algos();
    void fill_dp4a_algos();
};
  545. } // namespace cuda
  546. } // namespace megdnn
  547. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台