You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

strategy_default.cpp 9.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. /**
  2. * \file dnn/src/fallback/conv_bias/im2col/strategy_default.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "src/fallback/conv_bias/im2col/strategy_base.h"
  12. #include "src/fallback/convolution/img2col_helper.h"
  13. namespace megdnn {
  14. template <typename src_ctype, typename bias_ctype, typename dst_ctype,
  15. typename op_ctype, typename op_dtype,
  16. megdnn::PostprocessMode postprocess_mode>
  17. void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
  18. postprocess_mode, PackMode::DEFAULT>::
  19. packA_kern(WorkspaceBundle bundle,
  20. const fallback::ConvBiasImpl::NCBKernParam& param,
  21. fallback::MatrixMulImpl::KernSizeParam matmulparam,
  22. fallback::MatrixMulImpl::AlgoBase* matmul_algo,
  23. const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
  24. size_t) {
  25. bundle.set(param.workspace_ptr);
  26. fallback::MatrixMulImpl::KernParam matmul_param;
  27. size_t group_id = ncb_index.ndrange_id[0];
  28. static_cast<fallback::MatrixMulImpl::KernSizeParam&>(matmul_param) =
  29. matmulparam;
  30. size_t packA_group_size = matmul_algo->get_bundle(matmul_param).get_size(0);
  31. size_t packed_per_oc_block_size =
  32. round_up(matmul_param.K, matmul_algo->get_inner_block_size().k) *
  33. matmul_algo->get_inner_block_size().m *
  34. matmul_algo->get_packA_type_size();
  35. size_t a_panel_offset = ncb_index.ndrange_id[1] * packed_per_oc_block_size;
  36. int8_t* a_panel = static_cast<int8_t*>(bundle.get(BUNDLE_PACKA_INDEX)) +
  37. group_id * packA_group_size + a_panel_offset;
  38. matmul_param.A_ptr =
  39. const_cast<src_ctype*>(param.filter<src_ctype>(group_id));
  40. matmul_algo->pack_A(matmul_param, a_panel, ncb_index.ndrange_id[1],
  41. matmul_algo->get_inner_block_size().m);
  42. }
  43. template <typename src_ctype, typename bias_ctype, typename dst_ctype,
  44. typename op_ctype, typename op_dtype,
  45. megdnn::PostprocessMode postprocess_mode>
  46. void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
  47. postprocess_mode, PackMode::DEFAULT>::
  48. exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread,
  49. const StrategyParam& sparam,
  50. const fallback::ConvBiasImpl::NCBKernParam& param,
  51. fallback::MatrixMulImpl::KernParam matmul_param,
  52. fallback::MatrixMulImpl::AlgoBase* matmul_algo) {
  53. size_t sh = param.filter_meta.stride[0];
  54. size_t sw = param.filter_meta.stride[1];
  55. size_t oc = param.filter_meta.ocpg;
  56. size_t oh = param.osz[0];
  57. size_t ow = param.osz[1];
  58. size_t ic = param.filter_meta.icpg;
  59. size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2;
  60. size_t iw = param.isz[1] + param.filter_meta.padding[1] * 2;
  61. size_t fh = param.filter_meta.spatial[0];
  62. size_t fw = param.filter_meta.spatial[1];
  63. size_t is_xcorr = !param.filter_meta.should_flip;
  64. size_t input_offset =
  65. ih * iw * ic *
  66. (sparam.group_id + param.filter_meta.group * sparam.batch_id) *
  67. sizeof(src_ctype);
  68. src_ctype* src2 = reinterpret_cast<src_ctype*>(
  69. reinterpret_cast<uintptr_t>(bundle.get(BUNDLE_PADDING_INDEX)) +
  70. input_offset);
  71. bool is_phpwzero = param.filter_meta.padding[0] == 0 &&
  72. param.filter_meta.padding[1] == 0;
  73. if (is_phpwzero) {
  74. src2 = const_cast<src_ctype*>(
  75. param.src<src_ctype>(sparam.batch_id, sparam.group_id));
  76. }
  77. src_ctype* im2col_dst = static_cast<src_ctype*>(
  78. bundle_thread.get(THREAD_BUNDLE_IM2COL_INDEX));
  79. if (sh == 1 && sw == 1) {
  80. if (is_xcorr) {
  81. img2col<true>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, fw,
  82. sparam.ohw_cur_index, sparam.output_block_size);
  83. } else {
  84. img2col<false>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, fw,
  85. sparam.ohw_cur_index, sparam.output_block_size);
  86. }
  87. } else {
  88. if (is_xcorr) {
  89. img2col_stride<true>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh,
  90. fw, sh, sw, sparam.ohw_cur_index,
  91. sparam.output_block_size);
  92. } else {
  93. img2col_stride<false>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh,
  94. fw, sh, sw, sparam.ohw_cur_index,
  95. sparam.output_block_size);
  96. }
  97. }
  98. matmul_param.M = sparam.output_block_oc_size;
  99. matmul_param.N = sparam.output_block_size;
  100. matmul_param.LDB = sparam.output_block_size;
  101. matmul_param.LDC = sparam.output_block_size;
  102. matmul_param.B_ptr = im2col_dst;
  103. src_ctype* b_panel =
  104. reinterpret_cast<src_ctype*>(reinterpret_cast<uintptr_t>(
  105. bundle_thread.get(THREAD_BUNDLE_PACKB_INDEX)));
  106. matmul_algo->pack_B(matmul_param, b_panel, 0, matmul_param.N);
  107. }
  108. template <typename src_ctype, typename bias_ctype, typename dst_ctype,
  109. typename op_ctype, typename op_dtype,
  110. megdnn::PostprocessMode postprocess_mode>
  111. void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
  112. postprocess_mode, PackMode::DEFAULT>::
  113. get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param,
  114. const WorkspaceBundle& bundle_thread,
  115. const StrategyParam& sparam) {
  116. if (sparam.is_dst_8bit || !sparam.is_ohw_size_bigger) {
  117. return static_cast<void*>(
  118. bundle_thread.get(THREAD_BUNDLE_IM2COL_INDEX));
  119. } else {
  120. bias_ctype* dst =
  121. param.dst<bias_ctype>(sparam.batch_id, sparam.group_id) +
  122. sparam.oc_cur_index * sparam.ohw;
  123. return static_cast<void*>(dst);
  124. }
  125. }
  126. template <typename src_ctype, typename bias_ctype, typename dst_ctype,
  127. typename op_ctype, typename op_dtype,
  128. megdnn::PostprocessMode postprocess_mode>
  129. void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
  130. postprocess_mode, PackMode::DEFAULT>::
  131. exec_matmul(const fallback::ConvBiasImpl::NCBKernParam& param,
  132. const StrategyParam& sparam, WorkspaceBundle bundle,
  133. WorkspaceBundle bundle_thread,
  134. fallback::MatrixMulImpl::KernParam matmul_param,
  135. fallback::MatrixMulImpl::AlgoBase* matmul_algo,
  136. const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) {
  137. size_t packA_per_oc_block_size =
  138. round_up(matmul_param.K, matmul_algo->get_inner_block_size().k) *
  139. sparam.oc_tile_size * matmul_algo->get_packA_type_size();
  140. size_t packA_group_size = matmul_algo->get_bundle(matmul_param).get_size(0);
  141. size_t a_panel_offset = ncb_index.ndrange_id[1] * packA_group_size +
  142. ncb_index.ndrange_id[3] * packA_per_oc_block_size;
  143. void* matmul_dst = get_matmul_dst_ptr(param, bundle_thread, sparam);
  144. src_ctype* a_panel = reinterpret_cast<src_ctype*>(
  145. reinterpret_cast<uintptr_t>(bundle.get(BUNDLE_PACKA_INDEX)) +
  146. a_panel_offset);
  147. src_ctype* b_panel =
  148. reinterpret_cast<src_ctype*>(reinterpret_cast<uintptr_t>(
  149. bundle_thread.get(THREAD_BUNDLE_PACKB_INDEX)));
  150. size_t pack_oc_size = sparam.pack_oc_size;
  151. matmul_param.M = sparam.output_block_oc_size;
  152. matmul_param.N = sparam.output_block_size;
  153. matmul_param.LDB = pack_oc_size * sparam.output_block_size;
  154. matmul_param.LDC = pack_oc_size * sparam.output_block_size;
  155. matmul_param.C_ptr = matmul_dst;
  156. auto matmul_kern_naked = matmul_algo->get_kern_naked(matmul_param);
  157. matmul_kern_naked(matmul_param, a_panel, b_panel);
  158. }
//! Explicitly instantiate the DEFAULT pack-mode Strategy for one
//! (src, bias, dst, op_ctype, op_dtype, postprocess_mode) combination.
#define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \
                         _op_dtype, _postprocess_mode)                   \
    template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \
                            _op_dtype, _postprocess_mode, PackMode::DEFAULT>;
//! float32 path with fused postprocess
INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32,
                 megdnn::PostprocessMode::FLOAT)
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
//! native fp16 arithmetic available: use __fp16 as the op type
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, __fp16, __fp16,
                 megdnn::PostprocessMode::FLOAT)
#else
#if !MEGDNN_DISABLE_FLOAT16
//! fp16 storage without native fp16 vector arithmetic: no fused postprocess
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16,
                 megdnn::PostprocessMode::NO_PROCESS)
#endif
#endif
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
//! x86 do not have uint8 matmul so only armv7 armv8 support uint8
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_uint8, dt_qint32, dt_quint8,
                 megdnn::PostprocessMode::QUANTIZED)
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_int32, dt_qint32, dt_qint32,
                 megdnn::PostprocessMode::NO_PROCESS)
#endif
//! int8 variants: quantized output, raw int32 accumulators, and int16 path
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8,
                 megdnn::PostprocessMode::QUANTIZED)
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_int32, dt_int32,
                 megdnn::PostprocessMode::NO_PROCESS)
INSTANTIAL_CLASS(dt_int8, dt_int16, dt_int16, dt_int16, dt_int16,
                 megdnn::PostprocessMode::NO_PROCESS)
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_qint32, dt_qint32,
                 megdnn::PostprocessMode::NO_PROCESS)
#undef INSTANTIAL_CLASS
  190. } // namespace megdnn
  191. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台