
avx2_chanwise_stride1.cpp 9.0 kB

/**
 * \file src/x86/conv_bias/int8/avx2_chanwise_stride1.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
 * express or implied.
 */
#include "src/x86/conv_bias/int8/avx2_chanwise_stride1.h"
#include "src/x86/conv_bias/int8/avx2_chanwise_kern.h"
#include "src/x86/elemwise_op.h"

namespace megdnn {
namespace x86 {
namespace avx2_chanwise_stride1 {

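// Kernel body run once per (group, batch) work item: reads the padded source
// prepared by copy_padding_kern (when padding is needed), invokes the AVX2
// direct convolution for the selected filter size, and copies the result back
// into the caller's destination buffer when the rounded-up output is larger
// than the real one.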
template <size_t filter, BiasMode bias_mode, bool is_quantized, typename Op>
void conv_kimpl(const WorkspaceBundle& bundle, const NCBKernParam& kern_param,
                const NCBKernIndex& ncb_index) {
    size_t OH = kern_param.osz[0];
    size_t OW = kern_param.osz[1];
    size_t IH2, IW2, OH2, OW2;
    get_rectified_size(kern_param, IH2, IW2, OH2, OW2);
    bool need_src_copy_var = need_src_copy(kern_param);
    bool need_dst_copy_var = need_dst_copy(kern_param);
    bool need_post_process =
            kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8;
    Op op = Op(1.0f, 4.0f);
    if (need_post_process) {
        float scale_bias =
                kern_param.bias_type.param<dtype::QuantizedS32>().scale;
        float scale_dst = kern_param.dst_type.param<dtype::QuantizedS8>().scale;
        op = Op(scale_bias, scale_dst);
    }
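    // Each worker thread owns one padded-source slot in workspace bundle 0,
    // indexed by thread id so concurrent (group, batch) items never alias.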
    size_t padding_group_size = IH2 * IW2;
    size_t workspace_group_id = ncb_index.thread_id;
    size_t group_id = ncb_index.ndrange_id[0],
           batch_id = ncb_index.ndrange_id[1];
    const int8_t* sptr = kern_param.src<dt_int8>(batch_id, group_id);
    const int8_t* fptr = kern_param.filter<dt_int8>(group_id);
    void* dst = kern_param.dst<void>(batch_id, group_id);
    const int32_t* bptr = kern_param.bias<dt_int32>(batch_id, group_id);
    if (need_src_copy_var) {
        sptr = static_cast<int8_t*>(bundle.get(0)) +
               workspace_group_id * padding_group_size;
    }
    void* dptr = nullptr;
    int32_t* tptr = nullptr;
    if (need_dst_copy_var) {
        dptr = reinterpret_cast<void*>(
                reinterpret_cast<ptrdiff_t>(bundle.get(1)) +
                ncb_index.thread_id * OH2 * OW2 * kern_param.dst_type.size());
    } else {
        dptr = dst;
    }
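
    // Two instantiations of the direct kernel: the first fuses the int32 ->
    // int8 post-process (Op) into the output path via the tptr scratch
    // buffer; the second leaves the raw int32 accumulators in dptr.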
#define KERN_NEED_POST_PROCESS(filter)                                        \
    avx2_chanwise_direct_stride1_##filter##x##filter##_int8<bias_mode, true,  \
                                                            Op>(              \
            sptr, fptr, bptr, tptr, static_cast<int8_t*>(dptr), IH2, IW2,     \
            OH2, OW2, op)

#define KERN_NO_POST_PROCESS(filter)                                          \
    avx2_chanwise_direct_stride1_##filter##x##filter##_int8<bias_mode, false, \
                                                            Op>(              \
            sptr, fptr, bptr, static_cast<int32_t*>(dptr), nullptr, IH2,      \
            IW2, OH2, OW2, op)

    if (need_post_process) {
        tptr = static_cast<int32_t*>(bundle.get(2)) +
               ncb_index.thread_id * OH2 * OW2 * kern_param.dst_type.size();
        DISPATCH_FILTER(filter, KERN_NEED_POST_PROCESS)
    } else {
        DISPATCH_FILTER(filter, KERN_NO_POST_PROCESS)
    }

#undef KERN_NEED_POST_PROCESS
#undef KERN_NO_POST_PROCESS
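
    // The temporary rows are OW2 wide; copy only the valid OW columns of
    // each of the OH output rows back into the caller's buffer.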
    if (need_dst_copy_var) {
        rep(oh, OH) {
            std::memcpy(reinterpret_cast<void*>(
                                reinterpret_cast<ptrdiff_t>(dst) +
                                oh * OW * kern_param.dst_type.size()),
                        reinterpret_cast<void*>(
                                reinterpret_cast<ptrdiff_t>(dptr) +
                                oh * OW2 * kern_param.dst_type.size()),
                        kern_param.dst_type.size() * OW);
        }
    }
}
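
// Builds the kernel list for one convolution: selects the concrete
// conv_kimpl instantiation from the runtime filter size, dst dtype, bias
// mode and nonlinearity, then wraps padding + convolution into a single
// kernel launched over a (group, batch, 1) grid.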
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& kern_param,
                                const WorkspaceBundle& bundle) {
    MEGDNN_MARK_USED_VAR(kern_param);
    auto fm = kern_param.filter_meta;
    size_t group = fm.group;
    size_t n = kern_param.n;
    SmallVector<NCBKern> ncb_kerns;
    conv_fun do_conv_fun = nullptr;
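// Dispatch tree: filter size -> dst dtype (quantized or not) -> bias mode
// -> nonlinearity. Each leaf binds one conv_kimpl instantiation to
// do_conv_fun.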
#define DO_CONV_KERN_FUN(filter, bias_mode, is_quantized, op) \
    do_conv_fun = conv_kimpl<filter, bias_mode, is_quantized, op>;

#define GET_OP_PARAM(i, bias_mode, is_quantized)                         \
    switch (kern_param.nonlineMode) {                                    \
        case param::ConvBias::NonlineMode::IDENTITY:                     \
            DO_CONV_KERN_FUN(i, bias_mode, is_quantized,                 \
                             TypeCvtOp<SIMDType::AVX2 MEGDNN_COMMA       \
                                       dt_qint32 MEGDNN_COMMA dt_qint8>) \
            break;                                                       \
        case param::ConvBias::NonlineMode::RELU:                         \
            DO_CONV_KERN_FUN(i, bias_mode, is_quantized,                 \
                             ReluOp<SIMDType::AVX2 MEGDNN_COMMA          \
                                    dt_qint32 MEGDNN_COMMA dt_qint8>)    \
            break;                                                       \
        case param::ConvBias::NonlineMode::H_SWISH:                      \
            DO_CONV_KERN_FUN(i, bias_mode, is_quantized,                 \
                             HSwishOp<SIMDType::AVX2 MEGDNN_COMMA        \
                                      dt_qint32 MEGDNN_COMMA dt_qint8>)  \
            break;                                                       \
        default:                                                         \
            megdnn_assert(0);                                            \
            break;                                                       \
    }

#define GET_BIAS_MODE_PARAM(i, is_quantized)                          \
    switch (kern_param.bias_mode) {                                   \
        case BiasMode::NO_BIAS:                                       \
            GET_OP_PARAM(i, BiasMode::NO_BIAS, is_quantized)          \
            break;                                                    \
        case BiasMode::BROADCAST_CHANNEL_BIAS:                        \
            GET_OP_PARAM(i, BiasMode::BROADCAST_CHANNEL_BIAS,         \
                         is_quantized)                                \
            break;                                                    \
        default:                                                      \
            megdnn_assert(0);                                         \
            break;                                                    \
    }

#define GET_QUANTIZED(i)                   \
    switch (kern_param.dst_type.enumv()) { \
        case DTypeEnum::QuantizedS8:       \
            GET_BIAS_MODE_PARAM(i, true)   \
            break;                         \
        case DTypeEnum::QuantizedS32:      \
            GET_BIAS_MODE_PARAM(i, false)  \
            break;                         \
        case DTypeEnum::Int32:             \
            GET_BIAS_MODE_PARAM(i, false)  \
            break;                         \
        default:                           \
            megdnn_assert(0);              \
            break;                         \
    }

#define DISPATCH_CONV_KERN()                     \
    switch (kern_param.filter_meta.spatial[0]) { \
        case 2:                                  \
            GET_QUANTIZED(2)                     \
            break;                               \
        case 3:                                  \
            GET_QUANTIZED(3)                     \
            break;                               \
        case 5:                                  \
            GET_QUANTIZED(5)                     \
            break;                               \
        case 7:                                  \
            GET_QUANTIZED(7)                     \
            break;                               \
        default:                                 \
            megdnn_assert(0);                    \
            break;                               \
    }

    DISPATCH_CONV_KERN();

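    // The bundle is captured by value and the lambda is mutable so that
    // bundle.set() can bind the runtime workspace pointer before the padding
    // and convolution steps run for one (group, batch) work item.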
    auto exec_one_group = [bundle = bundle, do_conv_fun](
                                  const NCBKernParam& kern_param,
                                  const NCBKernIndex& ncb_index) mutable {
        bundle.set(kern_param.workspace_ptr);
        copy_padding_kern(bundle, kern_param, ncb_index);
        do_conv_fun(bundle, kern_param, ncb_index);
    };
    ncb_kerns.push_back({exec_one_group, {group, n, 1_z}});
    return ncb_kerns;
}

}  // namespace avx2_chanwise_stride1
}  // namespace x86
}  // namespace megdnn

// vim: syntax=cpp.doxygen
