You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

strategy_default.cpp 17 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365
  1. /**
  2. * \file dnn/src/fallback/conv_bias/im2col/strategy_default.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "src/fallback/conv_bias/im2col/strategy_base.h"
  12. #include "src/fallback/convolution/img2col_helper.h"
  13. #if MEGDNN_X86
  14. #include "src/x86/conv_bias/postprocess_helper.h"
  15. #endif
  16. using namespace megdnn;
  17. #if MEGDNN_X86
  18. using namespace x86;
  19. #endif
  20. namespace megdnn {
//! Copy one input channel into the zero/zero-point padded workspace
//! buffer so the subsequent im2col step never has to test image borders.
//!
//! \param bundle        workspace bundle; slot BUNDLE_PADDING_INDEX holds
//!                      the padded source buffer written by this kernel.
//! \param param         conv-bias kernel parameters (shapes, dtypes, ptrs).
//! \param ncb_index     ndrange id: [0]=batch, [1]=group, [2]=channel.
//! \param pack_oc_size  channel packing factor (1 for plain NCHW; larger
//!                      for packed-channel layouts).
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
          typename op_ctype, typename op_dtype,
          megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
              postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>::
        copy_padding_kern(WorkspaceBundle bundle,
                          const fallback::ConvBiasImpl::NCBKernParam& param,
                          const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
                          size_t pack_oc_size) {
    UNPACK_CONV_F32_NCB_KERN_SIZES(param);
    // Only IH/IW/IC/PH/PW are used below; silence the rest.
    MEGDNN_MARK_USED_VAR(N);
    MEGDNN_MARK_USED_VAR(OC);
    MEGDNN_MARK_USED_VAR(OH);
    MEGDNN_MARK_USED_VAR(OW);
    MEGDNN_MARK_USED_VAR(FH);
    MEGDNN_MARK_USED_VAR(FW);
    MEGDNN_MARK_USED_VAR(SH);
    MEGDNN_MARK_USED_VAR(SW);
    // Padded spatial extents.
    size_t IW2 = IW + 2 * PW;
    size_t IH2 = IH + 2 * PH;
    size_t batch_id = ncb_index.ndrange_id[0];
    size_t group_id = ncb_index.ndrange_id[1];
    size_t channel_id = ncb_index.ndrange_id[2];
    // Element count of the top (and bottom) padding band.
    size_t PH_SIZE = PH * IW2 * pack_oc_size;
    // From here on, PW and IW are measured in elements including the
    // channel-packing factor.
    PW = PW * pack_oc_size;
    IW = IW * pack_oc_size;
    size_t padding_group_size = IH2 * IW2 * IC;
    size_t workspace_channel_offset = pack_oc_size * IH2 * IW2 * channel_id;
    size_t workspace_group_offset = group_id * padding_group_size;
    size_t workspace_batch_offset =
            param.filter_meta.group * batch_id * padding_group_size;
    bundle.set(param.workspace_ptr);
    // Padding value: the zero point for asymmetric quantized input,
    // plain zero for every other dtype.
    src_ctype src_zp = static_cast<src_ctype>(0);
    if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) {
        src_zp = param.src_type.param<dtype::Quantized8Asymm>().zero_point;
    }
    src_ctype* src = const_cast<src_ctype*>(param.src<src_ctype>(
            batch_id, group_id, channel_id, 1, pack_oc_size));
    src_ctype* src2;
    src2 = static_cast<src_ctype*>(bundle.get(BUNDLE_PADDING_INDEX)) +
           workspace_group_offset + workspace_batch_offset +
           workspace_channel_offset;
    src_ctype* src2_ptr = src2;
    const src_ctype* src_ptr = src;
    // NOTE(review): memset fills bytes, not elements; this is exact here
    // because src_zp is nonzero only in the one-byte Quantized8Asymm case
    // and zero (all-zero bytes) otherwise.
    if (PH != 0) {
        std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH_SIZE);
        src2_ptr += PH_SIZE;
    }
    // Copy each input row surrounded by left/right padding columns.
    rep(ih, IH) {
        if (PW != 0)
            rep(pw, PW) * (src2_ptr++) = src_zp;
        std::memcpy(src2_ptr, src_ptr, sizeof(src_ctype) * IW);
        src2_ptr += IW;
        src_ptr += IW;
        if (PW != 0)
            rep(pw, PW) * (src2_ptr++) = src_zp;
    }
    // Bottom padding band.
    if (PH != 0) {
        std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH_SIZE);
        src2_ptr += PH_SIZE;
    }
}
//! Pre-pack one block of the filter tensor into the matmul algorithm's
//! internal "A"-operand layout.
//! ndrange id: [0]=group, [1]=oc-block index.
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
          typename op_ctype, typename op_dtype,
          megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
              postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>::
        packA_kern(WorkspaceBundle bundle,
                   const fallback::ConvBiasImpl::NCBKernParam& param,
                   fallback::MatrixMulImpl::KernSizeParam matmulparam,
                   fallback::MatrixMulImpl::AlgoBase* matmul_algo,
                   const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
                   size_t pack_oc_size) {
    bundle.set(param.workspace_ptr);
    fallback::MatrixMulImpl::KernParam matmul_param;
    size_t group_id = ncb_index.ndrange_id[0];
    // KernParam extends KernSizeParam; copy the size description in.
    static_cast<fallback::MatrixMulImpl::KernSizeParam&>(matmul_param) =
            matmulparam;
    size_t packA_group_size = matmul_algo->get_bundle(matmul_param).get_size(0);
    // Bytes of one packed m-block: K rounded up to the matmul's inner
    // block size, times the inner m block and the packed element size.
    size_t packed_per_oc_block_size =
            round_up(matmul_param.K, matmul_algo->get_inner_block_size().k) *
            matmul_algo->get_inner_block_size().m *
            matmul_algo->get_packA_type_size();
    size_t a_panel_offset = ncb_index.ndrange_id[1] * packed_per_oc_block_size;
    // NOTE(review): for pack_oc_size == 4 no per-block offset is applied —
    // presumably the whole group is packed by a single kernel invocation;
    // confirm against the dispatcher.
    int8_t* a_panel = static_cast<int8_t*>(bundle.get(BUNDLE_PACKA_INDEX)) +
                      group_id * packA_group_size +
                      (pack_oc_size == 4 ? 0 : a_panel_offset);
    matmul_param.A_ptr =
            const_cast<src_ctype*>(param.filter<src_ctype>(group_id));
    matmul_algo->pack_A(matmul_param, a_panel, ncb_index.ndrange_id[1],
                        matmul_algo->get_inner_block_size().m * pack_oc_size);
}
//! Gather the (padded) input window of the current output block into the
//! im2col matrix, then pre-pack it as the "B" operand of the GEMM.
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
          typename op_ctype, typename op_dtype,
          megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
              postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>::
        exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread,
                    const StrategyParam& sparam,
                    const fallback::ConvBiasImpl::NCBKernParam& param,
                    fallback::MatrixMulImpl::KernParam matmul_param,
                    fallback::MatrixMulImpl::AlgoBase* matmul_algo) {
    size_t sh = param.filter_meta.stride[0];
    size_t sw = param.filter_meta.stride[1];
    size_t oc = param.filter_meta.ocpg;
    size_t oh = param.osz[0];
    size_t ow = param.osz[1];
    size_t ic = param.filter_meta.icpg;
    // ih/iw are the padded extents produced by copy_padding_kern.
    size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2;
    size_t iw = param.isz[1] + param.filter_meta.padding[1] * 2;
    size_t fh = param.filter_meta.spatial[0];
    size_t fw = param.filter_meta.spatial[1];
    size_t is_xcorr = !param.filter_meta.should_flip;
    // Byte offset of this (batch, group) slice inside the padded buffer.
    size_t input_offset =
            ih * iw * ic *
            (sparam.group_id + param.filter_meta.group * sparam.batch_id) *
            sizeof(src_ctype);
    src_ctype* src2 = reinterpret_cast<src_ctype*>(
            reinterpret_cast<uintptr_t>(bundle.get(BUNDLE_PADDING_INDEX)) +
            input_offset);
    bool is_phpwzero = param.filter_meta.padding[0] == 0 &&
                       param.filter_meta.padding[1] == 0;
    // With zero padding the original source tensor can be read directly,
    // skipping the copy made by copy_padding_kern.
    if (is_phpwzero) {
        src2 = const_cast<src_ctype*>(
                param.src<src_ctype>(sparam.batch_id, sparam.group_id));
    }
    src_ctype* im2col_dst = static_cast<src_ctype*>(
            bundle_thread.get(THREAD_BUNDLE_IM2COL_INDEX));
    // Stride-1 convolutions take the specialized path; the template flag
    // selects cross-correlation vs. flipped-kernel traversal order.
    if (sh == 1 && sw == 1) {
        if (is_xcorr) {
            img2col<true>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, fw,
                          sparam.ohw_cur_index, sparam.output_block_size);
        } else {
            img2col<false>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, fw,
                           sparam.ohw_cur_index, sparam.output_block_size);
        }
    } else {
        if (is_xcorr) {
            img2col_stride<true>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh,
                                 fw, sh, sw, sparam.ohw_cur_index,
                                 sparam.output_block_size);
        } else {
            img2col_stride<false>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh,
                                  fw, sh, sw, sparam.ohw_cur_index,
                                  sparam.output_block_size);
        }
    }
    // Describe the current block to the matmul and pack the im2col matrix
    // as operand B for the selected kernel.
    matmul_param.M = sparam.output_block_oc_size;
    matmul_param.N = sparam.output_block_size;
    matmul_param.LDB = sparam.output_block_size;
    matmul_param.LDC = sparam.output_block_size;
    matmul_param.B_ptr = im2col_dst;
    src_ctype* b_panel =
            reinterpret_cast<src_ctype*>(reinterpret_cast<uintptr_t>(
                    bundle_thread.get(THREAD_BUNDLE_PACKB_INDEX)));
    matmul_algo->pack_B(matmul_param, b_panel, 0, matmul_param.N);
}
  178. template <typename src_ctype, typename bias_ctype, typename dst_ctype,
  179. typename op_ctype, typename op_dtype,
  180. megdnn::PostprocessMode postprocess_mode>
  181. void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
  182. postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>::
  183. get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param,
  184. const WorkspaceBundle& bundle_thread,
  185. const StrategyParam& sparam) {
  186. if (sparam.is_dst_8bit || !sparam.is_ohw_size_bigger) {
  187. return static_cast<void*>(
  188. bundle_thread.get(THREAD_BUNDLE_IM2COL_INDEX));
  189. } else {
  190. bias_ctype* dst =
  191. param.dst<bias_ctype>(sparam.batch_id, sparam.group_id) +
  192. sparam.oc_cur_index * sparam.ohw;
  193. return static_cast<void*>(dst);
  194. }
  195. }
//! Run the "naked" (both operands pre-packed) GEMM for one output block:
//! C[M x N] = packed filter panel x packed im2col panel.
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
          typename op_ctype, typename op_dtype,
          megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
              postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>::
        exec_matmul(const fallback::ConvBiasImpl::NCBKernParam& param,
                    const StrategyParam& sparam, WorkspaceBundle bundle,
                    WorkspaceBundle bundle_thread,
                    fallback::MatrixMulImpl::KernParam matmul_param,
                    fallback::MatrixMulImpl::AlgoBase* matmul_algo,
                    const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) {
    // Bytes of one packed oc tile: K rounded up to the matmul's inner
    // block size, times the oc tile size and the packed element size.
    size_t packA_per_oc_block_size =
            round_up(matmul_param.K, matmul_algo->get_inner_block_size().k) *
            sparam.oc_tile_size * matmul_algo->get_packA_type_size();
    size_t packA_group_size = matmul_algo->get_bundle(matmul_param).get_size(0);
    // NOTE(review): here ndrange_id[1] appears to index the group and
    // ndrange_id[3] the oc tile — confirm against the kernel dispatcher.
    size_t a_panel_offset = ncb_index.ndrange_id[1] * packA_group_size +
                            ncb_index.ndrange_id[3] * packA_per_oc_block_size;
    void* matmul_dst = get_matmul_dst_ptr(param, bundle_thread, sparam);
    src_ctype* a_panel = reinterpret_cast<src_ctype*>(
            reinterpret_cast<uintptr_t>(bundle.get(BUNDLE_PACKA_INDEX)) +
            a_panel_offset);
    src_ctype* b_panel =
            reinterpret_cast<src_ctype*>(reinterpret_cast<uintptr_t>(
                    bundle_thread.get(THREAD_BUNDLE_PACKB_INDEX)));
    size_t pack_oc_size = sparam.pack_oc_size;
    matmul_param.M = sparam.output_block_oc_size;
    matmul_param.N = sparam.output_block_size;
    // Leading dimensions include the channel-packing factor.
    matmul_param.LDB = pack_oc_size * sparam.output_block_size;
    matmul_param.LDC = pack_oc_size * sparam.output_block_size;
    matmul_param.C_ptr = matmul_dst;
    // The naked kernel multiplies the pre-packed panels without repacking.
    auto matmul_kern_naked = matmul_algo->get_kern_naked(matmul_param);
    matmul_kern_naked(matmul_param, a_panel, b_panel);
}
//! Apply bias addition, nonlinearity and (depending on postprocess_mode)
//! quantized post-processing to the GEMM output of the current block,
//! then copy the result into the destination tensor when required.
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
          typename op_ctype, typename op_dtype,
          megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
              postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>::
        exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param,
                         const StrategyParam& sparam,
                         WorkspaceBundle bundle_thread) {
    // Stage a full (per-element) bias window if BiasMode::BIAS; no-op
    // for the other bias modes.
    copy_bias(param, bundle_thread, sparam);
    void* matmul_dst = get_matmul_dst_ptr(param, bundle_thread, sparam);
    const bias_ctype* bias_ptr = static_cast<const bias_ctype*>(
            param.bias<bias_ctype>(sparam.batch_id, sparam.group_id));
    void* bias_temp_ptr = get_bias_temp_ptr(param, bundle_thread);
    // BiasMode::BIAS reads the staged copy; other modes index the
    // original bias tensor at the current oc offset.
    void* bias_preprocess_ptr = const_cast<void*>(
            param.bias_mode == megdnn::BiasMode::BIAS
                    ? bias_temp_ptr
                    : static_cast<void*>(const_cast<bias_ctype*>(
                              bias_ptr + sparam.oc_cur_index)));
    // In-place post-process: matmul_dst is both source and destination.
    PostProcess<op_ctype, op_dtype, postprocess_mode>::run(
            matmul_dst, bias_preprocess_ptr, matmul_dst, param.bias_mode,
            param.nonlineMode, param.bias_type, param.dst_type, 1_z,
            sparam.output_block_oc_size, 1_z, sparam.output_block_size,
            sparam.pack_oc_size);
    copy_dst(param, matmul_dst, sparam);
}
  254. template <typename src_ctype, typename bias_ctype, typename dst_ctype,
  255. typename op_ctype, typename op_dtype,
  256. megdnn::PostprocessMode postprocess_mode>
  257. void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
  258. postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>::
  259. copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param,
  260. const void* matmul_dst, const StrategyParam& sparam) {
  261. if (!sparam.skip_copy_dst) {
  262. size_t pack_oc_size = sparam.pack_oc_size;
  263. dst_ctype* dst_tmp_ptr =
  264. reinterpret_cast<dst_ctype*>(const_cast<void*>(matmul_dst));
  265. dst_ctype* dst =
  266. param.dst<dst_ctype>(sparam.batch_id, sparam.group_id) +
  267. sparam.oc_cur_index * sparam.ohw +
  268. sparam.ohw_cur_index * pack_oc_size;
  269. size_t oc_loop = sparam.output_block_oc_size / pack_oc_size;
  270. for (size_t oc = 0; oc < oc_loop; oc++) {
  271. std::memcpy(dst, dst_tmp_ptr,
  272. sizeof(dst_ctype) * sparam.output_block_size *
  273. pack_oc_size);
  274. dst_tmp_ptr += sparam.output_block_size * pack_oc_size;
  275. dst += sparam.ohw * pack_oc_size;
  276. }
  277. }
  278. }
  279. template <typename src_ctype, typename bias_ctype, typename dst_ctype,
  280. typename op_ctype, typename op_dtype,
  281. megdnn::PostprocessMode postprocess_mode>
  282. void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
  283. postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>::
  284. get_bias_temp_ptr(const fallback::ConvBiasImpl::NCBKernParam& param,
  285. const WorkspaceBundle& bundle_thread) {
  286. bias_ctype* bias_tmp_ptr =
  287. param.bias_mode == megdnn::BiasMode::BIAS
  288. ? static_cast<bias_ctype*>(
  289. bundle_thread.get(THREAD_BUNDLE_BIAS_INDEX))
  290. : nullptr;
  291. return bias_tmp_ptr;
  292. }
  293. template <typename src_ctype, typename bias_ctype, typename dst_ctype,
  294. typename op_ctype, typename op_dtype,
  295. megdnn::PostprocessMode postprocess_mode>
  296. void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
  297. postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>::
  298. copy_bias(const fallback::ConvBiasImpl::NCBKernParam& param,
  299. WorkspaceBundle bundle_thread, const StrategyParam& sparam) {
  300. const bias_ctype* bias_ptr = static_cast<const bias_ctype*>(
  301. param.bias<bias_ctype>(sparam.batch_id, sparam.group_id));
  302. bias_ctype* bias_temp_ptr =
  303. static_cast<bias_ctype*>(get_bias_temp_ptr(param, bundle_thread));
  304. if (param.bias_mode == megdnn::BiasMode::BIAS) {
  305. bias_ctype* copy_dst = bias_temp_ptr;
  306. const bias_ctype* copy_src = bias_ptr +
  307. sparam.oc_cur_index * sparam.ohw +
  308. sparam.ohw_cur_index;
  309. for (size_t oc = sparam.oc_cur_index; oc < sparam.oc_end_index; oc++) {
  310. std::memcpy(copy_dst, copy_src,
  311. sizeof(bias_ctype) * sparam.output_block_size);
  312. copy_dst += sparam.output_block_size;
  313. copy_src += sparam.ohw;
  314. }
  315. }
  316. }
// Explicitly instantiate the default-pack NCHW strategy for every
// (src, bias, dst, op) type combination used by the im2col algorithm.
#define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype,     \
                         _op_dtype, _postprocess_mode)                       \
    template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype,  \
                            _op_dtype, _postprocess_mode, PackMode::DEFAULT, \
                            FormatMode::NCHW>;
// float32 with float post-process.
INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32,
                 megdnn::PostprocessMode::FLOAT)
#if !MEGDNN_DISABLE_FLOAT16
// float16 path, compiled out when fp16 support is disabled.
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16,
                 megdnn::PostprocessMode::NO_PROCESS)
#endif
// int8 input with quantized post-process producing qint8 output.
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8,
                 megdnn::PostprocessMode::QUANTIZED)
// int8 input, raw int32/int16 accumulator outputs (no post-process).
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_int32, dt_int32,
                 megdnn::PostprocessMode::NO_PROCESS)
INSTANTIAL_CLASS(dt_int8, dt_int16, dt_int16, dt_int16, dt_int16,
                 megdnn::PostprocessMode::NO_PROCESS)
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_qint32, dt_qint32,
                 megdnn::PostprocessMode::NO_PROCESS)
#undef INSTANTIAL_CLASS
  337. } // namespace megdnn

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台