algo.cpp

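// AlgoPack for the CUDA conv_bias operator: it instantiates and registers
// every available algorithm (cuDNN wrappers, chanwise/matmul fallbacks and
// CUTLASS-style implicit-GEMM kernels) so they can be enumerated and looked
// up by descriptor.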
#include "src/cuda/conv_bias/algo.h"
#include "src/cuda/utils.h"

using namespace megdnn;
using namespace cuda;

ConvBiasForwardImpl::AlgoPack::AlgoPack() {
    non_cudnn_algos.push_back(&chanwise);
    non_cudnn_algos.push_back(&chanwise_small);
    non_cudnn_algos.push_back(&depthwise_large_filter);
    non_cudnn_algos.push_back(&inplace_matmul);
    non_cudnn_algos.push_back(&matmul);
    non_cudnn_algos.push_back(&matmul8x8x32);
    non_cudnn_algos.push_back(&batched_matmul);
    non_cudnn_algos.push_back(&int1_simple);

#if CUDNN_VERSION >= 8020
    all_algos.push_back(&cudnn_conv_v8);
    all_algos.push_back(&cudnn_conv_bias_activation_v8);
#endif

    fill_cudnn_algos();
    for (auto&& algo : cudnn_conv_bias_activations) {
        all_algos.push_back(&algo);
    }

    //! add conv+nonlinear algos
    std::vector<AlgoBase*> conv_algos;
    conv_algos.push_back(&chanwise);
    conv_algos.push_back(&chanwise_small);
    conv_algos.push_back(&depthwise_large_filter);
    conv_algos.push_back(&chanwise8x8x32);
    for (auto&& algo : cudnn_convs) {
        conv_algos.push_back(&algo);
    }
    conv_algos.push_back(&inplace_matmul);
    conv_algos.push_back(&matmul);
    conv_algos.push_back(&matmul8x8x32);
    conv_algos.push_back(&batched_matmul);
    conv_algos.push_back(&group);
    conv_algos.push_back(&int1_simple);

    for (auto&& algo : conv_algos) {
        all_algos.push_back(algo);
    }
    all_algos.push_back(&bfloat16);
    bfloat16_algos.push_back(&bfloat16);

    size_t all_algo_size = all_algos.size();
#if CUDA_VERSION >= 10000
    fill_imma_algos();
    all_algos.push_back(&wmma_quint4x4x32);
    for (auto&& algo : int8_nchw4_imma) {
        all_algos.push_back(&algo);
    }
    for (auto&& algo : int8_chwn4_imma) {
        all_algos.push_back(&algo);
    }
    for (auto&& algo : int8_chwn4_imma_reorder_filter) {
        all_algos.push_back(&algo);
    }
    for (auto&& algo : int8_chwn4_imma_unroll_width) {
        all_algos.push_back(&algo);
    }
#if CUDA_VERSION >= 10020
    for (auto&& algo : int8_nchw32_imma) {
        all_algos.push_back(&algo);
    }
    for (auto&& algo : int8_nhwc_imma) {
        all_algos.push_back(&algo);
    }
    for (auto&& algo : int4_int4_nchw64_imma) {
        all_algos.push_back(&algo);
    }
    for (auto&& algo : uint4_int4_nchw64_imma) {
        all_algos.push_back(&algo);
    }
    for (auto&& algo : int4_int4_nhwc_imma) {
        all_algos.push_back(&algo);
    }
    for (auto&& algo : uint4_int4_nhwc_imma) {
        all_algos.push_back(&algo);
    }
#endif
#endif
    fill_dp4a_algos();
    for (auto&& algo : int8_nchw4_dotprod) {
        all_algos.push_back(&algo);
    }
    fill_dwconv_algos();
    all_algos.push_back(&int8_chwn4_dotprod);
    all_algos.push_back(&fallback_nchw_qs8);
    for (size_t i = all_algo_size; i < all_algos.size(); ++i) {
        non_cudnn_algos.push_back(all_algos[i]);
    }

    for (auto&& algo : all_algos) {
        m_all_algos_map.emplace(algo->info().desc, algo);
    }
}
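
// Single static AlgoPack instance shared by all ConvBiasForwardImpl operators.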
ConvBiasForwardImpl::AlgoPack ConvBiasForwardImpl::sm_algo_pack;

MEGDNN_DEF_GET_ALGO_FROM_DESC(ConvBiasForwardImpl)
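
// SizeArgs bundles the problem description (layouts, canonized filter meta and
// the operator itself); ExecArgs additionally carries the concrete tensors and
// the workspace used at execution time.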
ConvBiasForwardImpl::AlgoBase::SizeArgs::SizeArgs(
        const ConvBiasForwardImpl* o, const TensorLayout& src,
        const TensorLayout& filter, const TensorLayout& bias, const TensorLayout& z,
        const TensorLayout& dst, const PreprocessedFilter* preprocessed_filter)
        : SizeArgs(
                  o, src, filter, o->make_canonized_filter_meta(src.ndim, filter), bias,
                  z, dst, preprocessed_filter) {}

ConvBiasForwardImpl::AlgoBase::SizeArgs::SizeArgs(
        const ConvBiasForwardImpl* o, const TensorLayout& src,
        const TensorLayout& filter, const CanonizedFilterMeta& filter_meta,
        const TensorLayout& bias, const TensorLayout& z, const TensorLayout& dst,
        const PreprocessedFilter* preprocessed_filter)
        : BiasForwardSizeArgs{concrete_handle(o->handle()),
                              &src,
                              &filter,
                              &bias,
                              &z,
                              filter_meta,
                              &dst,
                              o->param().nonlineMode},
          opr{o},
          preprocessed_filter{preprocessed_filter} {}

ConvBiasForwardImpl::AlgoBase::ExecArgs::ExecArgs(
        ConvBiasForwardImpl* opr, _megdnn_tensor_in src, _megdnn_tensor_in filter,
        _megdnn_tensor_in bias, _megdnn_tensor_in z, _megdnn_tensor_out dst,
        _megdnn_workspace workspace, const PreprocessedFilter* preprocessed_filter)
        : SizeArgs(
                  opr, src.layout, filter.layout, bias.layout, z.layout, dst.layout,
                  preprocessed_filter),
          src_tensor{&src},
          filter_tensor{&filter},
          bias_tensor{&bias},
          z_tensor{&z},
          dst_tensor{&dst},
          workspace{workspace} {}

std::string ConvBiasForwardImpl::AlgoBase::SizeArgs::to_string() const {
    auto&& fm = filter_meta;
    MEGDNN_MARK_USED_VAR(fm);
    std::string nonlinear_mode_str;
    switch (nonlinear_mode) {
        case param::ConvBias::NonlineMode::RELU:
            nonlinear_mode_str = "RELU";
            break;
        case param::ConvBias::NonlineMode::SIGMOID:
            nonlinear_mode_str = "SIGMOID";
            break;
        case param::ConvBias::NonlineMode::IDENTITY:
            nonlinear_mode_str = "IDENTITY";
            break;
        case param::ConvBias::NonlineMode::H_SWISH:
            nonlinear_mode_str = "H_SWISH";
            break;
        default:
            megdnn_throw("invalid conv bias nonlinear mode");
    }
    return ssprintf(
            "src=%s, filter=%s, bias=%s, z=%s, dst=%s, "
            "pad=%ux%u, stride=%ux%u, dilate=%ux%u, xcorr=%d, dtype=%s,%s, "
            "nonlinear_mode=%s",
            src_layout->to_string().c_str(), filter_layout->to_string().c_str(),
            bias_layout->to_string().c_str(), z_layout->to_string().c_str(),
            dst_layout->to_string().c_str(), fm.padding[0], fm.padding[1], fm.stride[0],
            fm.stride[1], fm.dilation[0], fm.dilation[1], !fm.should_flip,
            src_layout->dtype.name(), dst_layout->dtype.name(),
            nonlinear_mode_str.c_str());
}
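
// Build an equivalent param::Convolution from the conv_bias arguments, e.g. for
// algorithms that delegate to a plain convolution implementation.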
param::Convolution ConvBiasForwardImpl::AlgoBase::get_param_convolution(
        const SizeArgs& args) const {
    param::Convolution::Mode mode;
    param::Convolution::Sparse sparse = args.filter_meta.group > 1
                                              ? param::Convolution::Sparse::GROUP
                                              : param::Convolution::Sparse::DENSE;
    if (args.filter_meta.should_flip) {
        mode = param::Convolution::Mode::CONVOLUTION;
    } else {
        mode = param::Convolution::Mode::CROSS_CORRELATION;
    }
    return param::Convolution{
            mode,
            args.filter_meta.padding[0],
            args.filter_meta.padding[1],
            args.filter_meta.stride[0],
            args.filter_meta.stride[1],
            args.filter_meta.dilation[1],
            args.filter_meta.dilation[0],
            sparse,
            args.filter_meta.format,
            args.opr->param().compute_mode};
}
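
// Register every cuDNN forward algorithm reported by CudnnAlgoPack, both as a
// plain convolution algorithm and as a fused conv+bias+activation algorithm.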
void ConvBiasForwardImpl::AlgoPack::fill_cudnn_algos() {
    for (auto&& algo : CudnnAlgoPack::conv_fwd_algos()) {
        cudnn_conv_bias_activations.push_back(algo.first);
        cudnn_convs.push_back(algo.first);
    }
}
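
// Tensor-core (IMMA) implicit-GEMM algorithms are instantiated for a fixed set
// of tile configurations. The AlgoParam fields below appear to be
// {threadblock m/n/k, warp m/n/k, instruction m/n/k, stage} (plus a trailing
// access size for the NHWC variants); see the AlgoParam definitions in
// src/cuda/conv_bias/algo.h for the authoritative field order.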
#if CUDA_VERSION >= 10000
void ConvBiasForwardImpl::AlgoPack::fill_imma_algos() {
    int8_chwn4_imma.push_back(
            {AlgoInt8CHWN4IMMAImplicitGemm::MMATileSize::IMMA16x16x16});
    int8_chwn4_imma.push_back(
            {AlgoInt8CHWN4IMMAImplicitGemm::MMATileSize::IMMA32x8x16});
    int8_chwn4_imma.push_back(
            {AlgoInt8CHWN4IMMAImplicitGemm::MMATileSize::IMMA8x32x16});
    int8_nchw4_imma.push_back(
            {AlgoInt8NCHW4IMMAImplicitGemm::MMATileSize::IMMA16x16x16});
    int8_nchw4_imma.push_back(
            {AlgoInt8NCHW4IMMAImplicitGemm::MMATileSize::IMMA32x8x16});
    int8_nchw4_imma.push_back(
            {AlgoInt8NCHW4IMMAImplicitGemm::MMATileSize::IMMA8x32x16});
    int8_chwn4_imma_reorder_filter.push_back(
            {AlgoInt8CHWN4IMMAImplicitGemmReorderFilter::MMATileSize::IMMA16x16x16});
    int8_chwn4_imma_reorder_filter.push_back(
            {AlgoInt8CHWN4IMMAImplicitGemmReorderFilter::MMATileSize::IMMA32x8x16});
    int8_chwn4_imma_reorder_filter.push_back(
            {AlgoInt8CHWN4IMMAImplicitGemmReorderFilter::MMATileSize::IMMA8x32x16});
    int8_chwn4_imma_unroll_width.push_back(
            {AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth::MMATileSize::IMMA16x16x16});
    int8_chwn4_imma_unroll_width.push_back(
            {AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth::MMATileSize::IMMA32x8x16});
    int8_chwn4_imma_unroll_width.push_back(
            {AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth::MMATileSize::IMMA8x32x16});
#if CUDA_VERSION >= 10020
    {
        using AlgoParam = AlgoInt8NCHW32IMMAImplicitGemm::AlgoParam;
        int8_nchw32_imma.emplace_back(AlgoParam{128, 256, 64, 64, 64, 64, 8, 8, 16, 2});
        int8_nchw32_imma.emplace_back(AlgoParam{256, 128, 64, 64, 64, 64, 8, 8, 16, 2});
        int8_nchw32_imma.emplace_back(AlgoParam{128, 128, 64, 64, 64, 64, 8, 8, 16, 2});
        int8_nchw32_imma.emplace_back(AlgoParam{128, 64, 64, 64, 32, 64, 8, 8, 16, 2});
        int8_nchw32_imma.emplace_back(AlgoParam{64, 128, 64, 32, 64, 64, 8, 8, 16, 2});
        int8_nchw32_imma.emplace_back(AlgoParam{128, 64, 32, 64, 32, 32, 8, 8, 16, 1});
        int8_nchw32_imma.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 8, 8, 16, 1});
        int8_nchw32_imma.emplace_back(AlgoParam{64, 128, 32, 32, 64, 32, 8, 8, 16, 1});
        int8_nchw32_imma.emplace_back(AlgoParam{32, 128, 32, 32, 64, 32, 8, 8, 16, 1});
    }
    {
        using AlgoParam = AlgoInt8NHWCIMMAImplicitGemm::AlgoParam;
        int8_nhwc_imma.emplace_back(AlgoParam{64, 16, 32, 64, 16, 32, 8, 8, 16, 2, 16});
        int8_nhwc_imma.emplace_back(AlgoParam{64, 16, 32, 64, 16, 32, 8, 8, 16, 2, 8});
        int8_nhwc_imma.emplace_back(AlgoParam{64, 16, 32, 64, 16, 32, 8, 8, 16, 2, 4});
        int8_nhwc_imma.emplace_back(
                AlgoParam{128, 32, 32, 64, 32, 32, 8, 8, 16, 1, 16});
        int8_nhwc_imma.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 8, 8, 16, 1, 8});
        int8_nhwc_imma.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 8, 8, 16, 1, 4});
    }
    {
        using AlgoParam = AlgoInt4Int4NCHW64IMMAImplicitGemm::AlgoParam;
        int4_int4_nchw64_imma.emplace_back(
                AlgoParam{128, 128, 128, 64, 64, 128, 8, 8, 32, 2});
        int4_int4_nchw64_imma.emplace_back(
                AlgoParam{128, 256, 128, 64, 64, 128, 8, 8, 32, 2});
        int4_int4_nchw64_imma.emplace_back(
                AlgoParam{128, 64, 128, 64, 64, 128, 8, 8, 32, 2});
        int4_int4_nchw64_imma.emplace_back(
                AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1});
    }
    {
        using AlgoParam = AlgoUInt4Int4NCHW64IMMAImplicitGemm::AlgoParam;
        uint4_int4_nchw64_imma.emplace_back(
                AlgoParam{128, 128, 128, 64, 64, 128, 8, 8, 32, 2});
        uint4_int4_nchw64_imma.emplace_back(
                AlgoParam{128, 256, 128, 64, 64, 128, 8, 8, 32, 2});
        uint4_int4_nchw64_imma.emplace_back(
                AlgoParam{128, 64, 128, 64, 64, 128, 8, 8, 32, 2});
        uint4_int4_nchw64_imma.emplace_back(
                AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1});
    }
    {
        using AlgoParam = AlgoInt4Int4NHWCIMMAImplicitGemm::AlgoParam;
        int4_int4_nhwc_imma.emplace_back(
                AlgoParam{128, 16, 64, 128, 16, 64, 8, 8, 32, 2, 32});
        int4_int4_nhwc_imma.emplace_back(
                AlgoParam{128, 16, 64, 128, 16, 64, 8, 8, 32, 2, 16});
        int4_int4_nhwc_imma.emplace_back(
                AlgoParam{128, 16, 64, 128, 16, 64, 8, 8, 32, 2, 8});
        int4_int4_nhwc_imma.emplace_back(
                AlgoParam{128, 32, 64, 64, 32, 64, 8, 8, 32, 1, 32});
        int4_int4_nhwc_imma.emplace_back(
                AlgoParam{128, 32, 64, 64, 32, 64, 8, 8, 32, 1, 16});
        int4_int4_nhwc_imma.emplace_back(
                AlgoParam{128, 32, 64, 64, 32, 64, 8, 8, 32, 1, 8});
        int4_int4_nhwc_imma.emplace_back(
                AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1, 32});
        int4_int4_nhwc_imma.emplace_back(
                AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1, 16});
        int4_int4_nhwc_imma.emplace_back(
                AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1, 8});
    }
    {
        using AlgoParam = AlgoUInt4Int4NHWCIMMAImplicitGemm::AlgoParam;
        uint4_int4_nhwc_imma.emplace_back(
                AlgoParam{128, 16, 64, 128, 16, 64, 8, 8, 32, 2, 32});
        uint4_int4_nhwc_imma.emplace_back(
                AlgoParam{128, 16, 64, 128, 16, 64, 8, 8, 32, 2, 16});
        uint4_int4_nhwc_imma.emplace_back(
                AlgoParam{128, 16, 64, 128, 16, 64, 8, 8, 32, 2, 8});
        uint4_int4_nhwc_imma.emplace_back(
                AlgoParam{128, 32, 64, 64, 32, 64, 8, 8, 32, 1, 32});
        uint4_int4_nhwc_imma.emplace_back(
                AlgoParam{128, 32, 64, 64, 32, 64, 8, 8, 32, 1, 16});
        uint4_int4_nhwc_imma.emplace_back(
                AlgoParam{128, 32, 64, 64, 32, 64, 8, 8, 32, 1, 8});
        uint4_int4_nhwc_imma.emplace_back(
                AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1, 32});
        uint4_int4_nhwc_imma.emplace_back(
                AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1, 16});
        uint4_int4_nhwc_imma.emplace_back(
                AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1, 8});
    }
#endif
}
#endif
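
// Depthwise convolution implemented as implicit batched matmul (CUTLASS-based);
// the f16 variants are only compiled for CUDA >= 10.1.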
void ConvBiasForwardImpl::AlgoPack::fill_dwconv_algos() {
    using AlgoParam = AlgoCutlassConvolutionBase::AlgoParam;
    /// preferred algo
    f32_implicit_bmm.emplace_back(AlgoParam{64, 128, 8, 32, 64, 8, 1, 1, 1, 2});
    f32_implicit_bmm.emplace_back(AlgoParam{128, 128, 8, 32, 64, 8, 1, 1, 1, 2});
    f32_implicit_bmm.emplace_back(AlgoParam{128, 64, 8, 64, 32, 8, 1, 1, 1, 2});
    f32_implicit_bmm.emplace_back(AlgoParam{128, 32, 8, 64, 32, 8, 1, 1, 1, 2});
    f32_implicit_bmm.emplace_back(AlgoParam{32, 128, 8, 32, 64, 8, 1, 1, 1, 2});
    f32_implicit_bmm.emplace_back(AlgoParam{64, 64, 8, 32, 64, 8, 1, 1, 1, 2});
    f32_implicit_bmm.emplace_back(AlgoParam{32, 64, 8, 32, 64, 8, 1, 1, 1, 2});
    f32_implicit_bmm.emplace_back(AlgoParam{32, 32, 8, 32, 32, 8, 1, 1, 1, 2});
    f32_implicit_bmm.emplace_back(AlgoParam{64, 32, 8, 64, 32, 8, 1, 1, 1, 2});
    for (auto&& algo : f32_implicit_bmm) {
        all_algos.push_back(&algo);
    }
#if CUDA_VERSION >= 10010
    /// preferred algo
    f16_implicit_bmm.emplace_back(AlgoParam{64, 128, 32, 32, 32, 32, 8, 8, 4, 2});
    f16_implicit_bmm.emplace_back(AlgoParam{128, 128, 32, 32, 32, 32, 8, 8, 4, 2});
    f16_implicit_bmm.emplace_back(AlgoParam{128, 256, 32, 64, 64, 32, 8, 8, 4, 2});
    f16_implicit_bmm.emplace_back(AlgoParam{128, 64, 32, 32, 32, 32, 8, 8, 4, 2});
    f16_implicit_bmm.emplace_back(AlgoParam{64, 64, 32, 32, 32, 32, 8, 8, 4, 2});
    for (auto&& algo : f16_implicit_bmm) {
        all_algos.push_back(&algo);
    }
#endif
}
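
// Tile configurations for the int8 NCHW4 dot-product (dp4a) implicit-GEMM
// algorithms.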
void ConvBiasForwardImpl::AlgoPack::fill_dp4a_algos() {
    using AlgoParam = AlgoInt8NCHW4DotProdImplicitGemm::AlgoParam;
    int8_nchw4_dotprod.emplace_back(AlgoParam{128, 128, 32, 64, 32, 32, 1, 1, 4, 2});
    int8_nchw4_dotprod.emplace_back(AlgoParam{128, 64, 32, 64, 32, 32, 1, 1, 4, 2});
    int8_nchw4_dotprod.emplace_back(AlgoParam{64, 128, 32, 64, 32, 32, 1, 1, 4, 2});
    int8_nchw4_dotprod.emplace_back(AlgoParam{32, 128, 32, 32, 64, 32, 1, 1, 4, 2});
    int8_nchw4_dotprod.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 1, 1, 4, 2});
    int8_nchw4_dotprod.emplace_back(AlgoParam{32, 64, 32, 32, 64, 32, 1, 1, 4, 2});
    int8_nchw4_dotprod.emplace_back(AlgoParam{64, 32, 32, 64, 32, 32, 1, 1, 4, 2});
    int8_nchw4_dotprod.emplace_back(AlgoParam{16, 128, 16, 16, 128, 16, 1, 1, 4, 1});
    int8_nchw4_dotprod.emplace_back(AlgoParam{16, 64, 8, 16, 64, 8, 1, 1, 4, 2});
}
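
// Map a raw cudnnConvolutionFwdAlgo_t back to the corresponding wrapper algo;
// throws if the enum value was never registered in fill_cudnn_algos().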
ConvBiasForwardImpl::AlgoBase* ConvBiasForwardImpl::AlgoPack::cudnn_conv_from_enum(
        cudnnConvolutionFwdAlgo_t algo) {
    for (auto&& i : cudnn_convs) {
        if (i.cudnn_enum() == algo)
            return &i;
    }
    megdnn_throw(ssprintf(
            "can not find cudnn conv fwd algorithm %d", static_cast<int>(algo)));
}

ConvBiasForwardImpl::AlgoBase* ConvBiasForwardImpl::AlgoPack::
        cudnn_conv_bias_act_from_enum(cudnnConvolutionFwdAlgo_t algo) {
    for (auto&& i : cudnn_conv_bias_activations) {
        if (i.cudnn_enum() == algo)
            return &i;
    }
    megdnn_throw(ssprintf(
            "can not find cudnn conv bias act algorithm %d", static_cast<int>(algo)));
}

// vim: syntax=cpp.doxygen