You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

postprocess_helper.h 22 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380
  1. /**
  2. * \file dnn/src/arm_common/conv_bias/postprocess_helper.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
  12. #include "megdnn/basic_types.h"
  13. #include "src/arm_common/elemwise_helper/kimpl/op_base.h"
  14. #include "src/arm_common/elemwise_op.h"
  15. #include "src/fallback/conv_bias/opr_impl.h"
  16. #include "midout.h"
  17. MIDOUT_DECL(arm_common_conv_bias_postprocess_helper)
  18. namespace {
  //! Helpers shared by the FLOAT and QUANTIZED PostProcess dispatch macros.
  //!
  //! CONCAT_OP(X) names the elemwise op megdnn::arm_common::X;
  //! CONCAT_NL(X) names the enumerator megdnn::NonlineMode::X.
  #define CONCAT_OP(_name) megdnn::arm_common::_name
  #define CONCAT_NL(_name) megdnn::NonlineMode::_name
  //! CB(caller, op, mode, tag) emits a `case mode:` label that runs
  //! caller(op) inside midout region
  //! (arm_common_conv_bias_postprocess_helper, 1, tag) and then breaks out of
  //! the enclosing switch.
  #define CB(_caller, _op, _mode, midout_tag)                      \
      case _mode:                                                  \
          MIDOUT_BEGIN(arm_common_conv_bias_postprocess_helper, 1, \
                       midout_tag) {                               \
              _caller(_op);                                        \
          }                                                        \
          MIDOUT_END();                                            \
          break;
  //! DEFAULT rejects any nonlinear mode that no preceding case handled.
  #define DEFAULT                                  \
      default:                                     \
          megdnn_throw("unsupported nonlinemode"); \
          break;
  //! HANDLE_IDENTITY() (float variant): the IDENTITY case performs no work.
  #define HANDLE_IDENTITY()               \
      case megdnn::NonlineMode::IDENTITY: \
          break;
  //! FLOAT-path op invokers. Each one expands to a single
  //! megdnn::arm_common::OpCaller{Unary,Binary}<op<ctype>, layout>::run(...)
  //! statement and expects these names at the expansion site: ctype,
  //! conv_dst_ptr, bias_ptr, dst_ptr, bias_type, dst_type, N, OC, OH, OW and
  //! pack_oc_size.
  //!
  //! Full-shape bias: op(conv_dst, bias) element-wise (VEC_VEC) over
  //! N * OC * OH * OW * pack_oc_size elements.
  #define FOR_NONLINEAR_BINARY(op_)                                    \
      megdnn::arm_common::OpCallerBinary<                              \
              op_<ctype>, megdnn::arm_common::VEC_VEC>::run(           \
              static_cast<ctype*>(conv_dst_ptr),                       \
              reinterpret_cast<const ctype*>(bias_ptr),                \
              reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \
              dst_type, N * OC * OH * OW * pack_oc_size);
  //! Per-channel bias, plain layout: VEC_BCAST101 over (N, OC, OH * OW).
  #define FOR_NONLINEAR_BINARY_BROADCAST(op_)                          \
      megdnn::arm_common::OpCallerBinary<                              \
              op_<ctype>, megdnn::arm_common::VEC_BCAST101>::run(      \
              static_cast<ctype*>(conv_dst_ptr),                       \
              reinterpret_cast<const ctype*>(bias_ptr),                \
              reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \
              dst_type, N, OC, OH * OW);
  //! Per-channel bias, packed layout: VEC_BCAST101x4 over
  //! (N, OC, OH * OW, pack_oc_size).
  #define FOR_NONLINEAR_BINARY_BROADCAST_NCHW44(op_)                   \
      megdnn::arm_common::OpCallerBinary<                              \
              op_<ctype>, megdnn::arm_common::VEC_BCAST101x4>::run(    \
              static_cast<ctype*>(conv_dst_ptr),                       \
              reinterpret_cast<const ctype*>(bias_ptr),                \
              reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \
              dst_type, N, OC, OH * OW, pack_oc_size);
  //! No bias: op(conv_dst) (VEC) over N * OC * OH * OW * pack_oc_size
  //! elements.
  #define FOR_NONLINEAR_UNARY(op_)                                     \
      megdnn::arm_common::OpCallerUnary<                               \
              op_<ctype>, megdnn::arm_common::VEC>::run(               \
              static_cast<ctype*>(conv_dst_ptr),                       \
              reinterpret_cast<ctype*>(dst_ptr), bias_type, dst_type,  \
              N * OC * OH * OW * pack_oc_size);
  //! FLOAT-path selectors (expected in scope: nonlineMode, pack_oc_size and
  //! everything the FOR_NONLINEAR_* invokers read).
  //!
  //! FOR_NONLINEAR(caller): bias present, so every mode maps to an "add bias"
  //! binary op: IDENTITY -> AddOp, RELU -> FuseAddReluOp,
  //! SIGMOID -> FuseAddSigmoidOp, H_SWISH -> FuseAddHSwishOp (midout 3..6).
  #define FOR_NONLINEAR(caller_)                                          \
      switch (nonlineMode) {                                              \
          CB(caller_, CONCAT_OP(AddOp), CONCAT_NL(IDENTITY), 3)           \
          CB(caller_, CONCAT_OP(FuseAddReluOp), CONCAT_NL(RELU), 4)       \
          CB(caller_, CONCAT_OP(FuseAddSigmoidOp), CONCAT_NL(SIGMOID), 5) \
          CB(caller_, CONCAT_OP(FuseAddHSwishOp), CONCAT_NL(H_SWISH), 6)  \
          DEFAULT                                                         \
      }
  //! FOR_NONLINEAR_NOBIAS(caller): no bias, so unary ops are used:
  //! RELU -> ReluOp, SIGMOID -> SigmoidOp, H_SWISH -> HSwishOp (midout
  //! 7..9). IDENTITY performs no work at all (see HANDLE_IDENTITY()).
  #define FOR_NONLINEAR_NOBIAS(caller_)                            \
      switch (nonlineMode) {                                       \
          HANDLE_IDENTITY()                                        \
          CB(caller_, CONCAT_OP(ReluOp), CONCAT_NL(RELU), 7)       \
          CB(caller_, CONCAT_OP(SigmoidOp), CONCAT_NL(SIGMOID), 8) \
          CB(caller_, CONCAT_OP(HSwishOp), CONCAT_NL(H_SWISH), 9)  \
          DEFAULT                                                  \
      }
  //! FOR_BIAS(mode) picks the op-caller shape from the bias mode, each in
  //! its own midout region (sub-id 0):
  //!   NO_BIAS                -> unary invoker
  //!   BROADCAST_CHANNEL_BIAS -> NCHW broadcast when pack_oc_size == 1,
  //!                             otherwise NCHW44 (pack_oc_size must be 4)
  //!   BIAS                   -> element-wise binary invoker
  //! Any other bias mode throws.
  #define FOR_BIAS(mode_)                                                   \
      switch (mode_) {                                                      \
          case megdnn::BiasMode::NO_BIAS:                                   \
              MIDOUT_BEGIN(arm_common_conv_bias_postprocess_helper, 0, 0) { \
                  FOR_NONLINEAR_NOBIAS(FOR_NONLINEAR_UNARY);                \
              }                                                             \
              MIDOUT_END();                                                 \
              break;                                                        \
          case megdnn::BiasMode::BROADCAST_CHANNEL_BIAS:                    \
              MIDOUT_BEGIN(arm_common_conv_bias_postprocess_helper, 0, 1) { \
                  if (pack_oc_size != 1) {                                  \
                      megdnn_assert(pack_oc_size == 4,                      \
                                    "Only support nchw44 in ARM");          \
                      FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST_NCHW44); \
                  } else {                                                  \
                      FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST);        \
                  }                                                         \
              }                                                             \
              MIDOUT_END();                                                 \
              break;                                                        \
          case megdnn::BiasMode::BIAS:                                      \
              MIDOUT_BEGIN(arm_common_conv_bias_postprocess_helper, 0, 2) { \
                  FOR_NONLINEAR(FOR_NONLINEAR_BINARY);                      \
              }                                                             \
              MIDOUT_END();                                                 \
              break;                                                        \
          default:                                                          \
              megdnn_throw("no quantized unsupported biasmode");            \
              break;                                                        \
      }
  107. template <typename ctype, typename dtype = ctype,
  108. megdnn::PostprocessMode postprocess_mode =
  109. megdnn::PostprocessMode::FLOAT>
  110. struct PostProcess {
  111. static void run(void* conv_dst_ptr, const void* bias_ptr, void* dst_ptr,
  112. megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode,
  113. megdnn::DType bias_type, megdnn::DType dst_type, size_t N,
  114. size_t OC, size_t OH, size_t OW, size_t pack_oc_size = 1) {
  115. FOR_BIAS(bias_mode)
  116. }
  117. };
  118. template <typename ctype, typename dtype>
  119. struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> {
  120. static void run(void* conv_dst_ptr, void* bias_ptr, void* dst_ptr,
  121. megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode,
  122. megdnn::DType bias_type, megdnn::DType dst_type, size_t N,
  123. size_t OC, size_t OH, size_t OW, size_t pack_oc_size = 1) {
  124. MEGDNN_MARK_USED_VAR(conv_dst_ptr);
  125. MEGDNN_MARK_USED_VAR(bias_ptr);
  126. MEGDNN_MARK_USED_VAR(dst_ptr);
  127. MEGDNN_MARK_USED_VAR(bias_mode);
  128. MEGDNN_MARK_USED_VAR(nonlineMode);
  129. MEGDNN_MARK_USED_VAR(bias_type);
  130. MEGDNN_MARK_USED_VAR(dst_type);
  131. MEGDNN_MARK_USED_VAR(N);
  132. MEGDNN_MARK_USED_VAR(OC);
  133. MEGDNN_MARK_USED_VAR(OH);
  134. MEGDNN_MARK_USED_VAR(OW);
  135. MEGDNN_MARK_USED_VAR(pack_oc_size);
  136. megdnn_assert(bias_mode == megdnn::BiasMode::NO_BIAS &&
  137. nonlineMode == megdnn::NonlineMode::IDENTITY);
  138. }
  139. };
  //! Retire the FLOAT-path helpers so the QUANTIZED path below can reuse the
  //! names. CB, CONCAT_OP, CONCAT_NL and DEFAULT stay defined because the
  //! quantized macros still expand them.
  #undef FOR_BIAS
  #undef FOR_NONLINEAR
  #undef FOR_NONLINEAR_BINARY
  #undef FOR_NONLINEAR_BINARY_BROADCAST
  #undef FOR_NONLINEAR_BINARY_BROADCAST_NCHW44
  #undef FOR_NONLINEAR_NOBIAS
  #undef FOR_NONLINEAR_UNARY
  #undef HANDLE_IDENTITY
  //! QUANTIZED-path op invokers. Same shapes as the float invokers, but the
  //! op is instantiated as op<opctype, opdtype>: conv_dst/bias are read as
  //! opctype and dst is written as opdtype. Expected in scope: opctype,
  //! opdtype, conv_dst_ptr, bias_ptr, dst_ptr, bias_type, dst_type, N, OC,
  //! OH, OW and pack_oc_size.
  //!
  //! Per-channel bias, packed layout: VEC_BCAST101x4 over
  //! (N, OC, OH * OW, pack_oc_size).
  #define FOR_NONLINEAR_BINARY_BROADCAST_NCHW44(op_)                       \
      megdnn::arm_common::OpCallerBinary<                                  \
              op_<opctype, opdtype>,                                       \
              megdnn::arm_common::VEC_BCAST101x4>::run(                    \
              static_cast<opctype*>(conv_dst_ptr),                         \
              reinterpret_cast<const opctype*>(bias_ptr),                  \
              reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type,   \
              dst_type, N, OC, OH * OW, pack_oc_size);
  //! Per-channel bias, plain layout: VEC_BCAST101 over (N, OC, OH * OW).
  #define FOR_NONLINEAR_BINARY_BROADCAST(op_)                              \
      megdnn::arm_common::OpCallerBinary<                                  \
              op_<opctype, opdtype>,                                       \
              megdnn::arm_common::VEC_BCAST101>::run(                      \
              static_cast<opctype*>(conv_dst_ptr),                         \
              reinterpret_cast<const opctype*>(bias_ptr),                  \
              reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type,   \
              dst_type, N, OC, OH * OW);
  //! No bias: op(conv_dst) (VEC) over N * OC * OH * OW * pack_oc_size
  //! elements.
  #define FOR_NONLINEAR_UNARY(op_)                                         \
      megdnn::arm_common::OpCallerUnary<op_<opctype, opdtype>,             \
                                        megdnn::arm_common::VEC>::run(     \
              static_cast<opctype*>(conv_dst_ptr),                         \
              reinterpret_cast<opdtype*>(dst_ptr), bias_type, dst_type,    \
              N * OC * OH * OW * pack_oc_size);
  //! QUANTIZED-path selectors (expected in scope: nonlineMode, pack_oc_size
  //! and everything the quantized FOR_NONLINEAR_* invokers read).
  //!
  //! HANDLE_IDENTITY(caller, op): unlike the float variant, IDENTITY still
  //! invokes caller(op), so dst is always written. No midout region is
  //! opened for this case.
  #define HANDLE_IDENTITY(_caller, _op)   \
      case megdnn::NonlineMode::IDENTITY: \
          _caller(_op) break;
  //! FOR_NONLINEAR(caller): bias present. IDENTITY -> AddOp,
  //! RELU -> FuseAddReluOp (midout 10), H_SWISH -> FuseAddHSwishOp (11).
  //! SIGMOID is not supported here and falls into DEFAULT.
  #define FOR_NONLINEAR(_caller)                                        \
      switch (nonlineMode) {                                            \
          HANDLE_IDENTITY(_caller, CONCAT_OP(AddOp))                    \
          CB(_caller, CONCAT_OP(FuseAddReluOp), CONCAT_NL(RELU), 10)    \
          CB(_caller, CONCAT_OP(FuseAddHSwishOp), CONCAT_NL(H_SWISH), 11) \
          DEFAULT                                                       \
      }
  //! FOR_NONLINEAR_NOBIAS(caller): no bias. IDENTITY -> TypeCvtOp,
  //! RELU -> ReluOp (midout 12), H_SWISH -> HSwishOp (13).
  #define FOR_NONLINEAR_NOBIAS(_caller)                          \
      switch (nonlineMode) {                                     \
          HANDLE_IDENTITY(_caller, CONCAT_OP(TypeCvtOp))         \
          CB(_caller, CONCAT_OP(ReluOp), CONCAT_NL(RELU), 12)    \
          CB(_caller, CONCAT_OP(HSwishOp), CONCAT_NL(H_SWISH), 13) \
          DEFAULT                                                \
      }
  //! FOR_BIAS(mode, OH, OW):
  //!   NO_BIAS                -> unary invoker
  //!   BROADCAST_CHANNEL_BIAS -> NCHW broadcast (pack_oc_size == 1) or
  //!                             NCHW44 broadcast (pack_oc_size == 4)
  //!   anything else          -> accepted only when OH * OW == 1, where the
  //!                             bias is applied as a per-channel broadcast;
  //!                             otherwise throws.
  //! The OH * OW == 1 fallback honours pack_oc_size exactly like the
  //! BROADCAST_CHANNEL_BIAS branch. The plain VEC_BCAST101 invoker is only
  //! given (N, OC, OH * OW), while the unary path covers
  //! N * OC * OH * OW * pack_oc_size elements. With packed layouts it would
  //! therefore not cover the packed lanes.
  #define FOR_BIAS(_bias_mode, OH, OW)                                  \
      switch (_bias_mode) {                                             \
          case megdnn::BiasMode::NO_BIAS:                               \
              FOR_NONLINEAR_NOBIAS(FOR_NONLINEAR_UNARY);                \
              break;                                                    \
          case megdnn::BiasMode::BROADCAST_CHANNEL_BIAS:                \
              if (pack_oc_size == 1) {                                  \
                  FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST);        \
              } else {                                                  \
                  megdnn_assert(pack_oc_size == 4,                      \
                                "Only support nchw44 in ARM");          \
                  FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST_NCHW44); \
              }                                                         \
              break;                                                    \
          default:                                                      \
              if (OH * OW == 1) {                                       \
                  if (pack_oc_size == 1) {                              \
                      FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST);    \
                  } else {                                              \
                      megdnn_assert(pack_oc_size == 4,                  \
                                    "Only support nchw44 in ARM");      \
                      FOR_NONLINEAR(                                    \
                              FOR_NONLINEAR_BINARY_BROADCAST_NCHW44);   \
                  }                                                     \
                  break;                                                \
              }                                                         \
              megdnn_throw("quantized unsupported biasmode");           \
              break;                                                    \
      }
  208. template <typename opctype, typename opdtype>
  209. struct PostProcess<opctype, opdtype, megdnn::PostprocessMode::QUANTIZED> {
  210. static void run(void* conv_dst_ptr, const void* bias_ptr, void* dst_ptr,
  211. megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode,
  212. megdnn::DType bias_type, megdnn::DType dst_type, size_t N,
  213. size_t OC, size_t OH, size_t OW, size_t pack_oc_size = 1) {
  214. //! when OH * OW = 1, the bias_mode will be BiasMode::BIAS. It is wrong,
  215. //! we deal this case at default branch.
  216. FOR_BIAS(bias_mode, OH, OW);
  217. }
  218. };
  //! Retire every PostProcess helper macro so none leak into including files.
  //! FOR_NONLINEAR_BINARY was never redefined for the quantized path;
  //! undefining it again is harmless and kept for symmetry.
  #undef CB
  #undef CONCAT_NL
  #undef CONCAT_OP
  #undef DEFAULT
  #undef FOR_BIAS
  #undef FOR_NONLINEAR
  #undef FOR_NONLINEAR_BINARY
  #undef FOR_NONLINEAR_BINARY_BROADCAST
  #undef FOR_NONLINEAR_BINARY_BROADCAST_NCHW44
  #undef FOR_NONLINEAR_NOBIAS
  #undef FOR_NONLINEAR_UNARY
  #undef HANDLE_IDENTITY
  //! DISPATCH_CONV_WINOGRAD_NONLINE(tag, cb, bias_id, src_t, dst_t, bmode,
  //!                                nonline_mode, ...)
  //! Switches on nonline_mode and calls cb(bmode, Op<src_t, dst_t>, ...)
  //! inside midout region (tag, bias_id, k):
  //!   IDENTITY -> NoneOp (k = 0), RELU -> ReluOp (k = 1),
  //!   SIGMOID -> SigmoidOp (k = 2), H_SWISH -> HSwishOp (k = 3).
  //! Any other mode trips megdnn_assert(0). MEGDNN_COMMA keeps the op's
  //! template argument list a single macro argument of cb.
  #define DISPATCH_CONV_WINOGRAD_NONLINE(midout_tag_, cb_, bias_id_, src_t_, \
                                         dst_t_, bmode_, nonline_mode_,      \
                                         ...)                                \
      switch (nonline_mode_) {                                               \
          case param::ConvBias::NonlineMode::IDENTITY: {                     \
              MIDOUT_BEGIN(midout_tag_, bias_id_, 0) {                       \
                  cb_(bmode_, NoneOp<src_t_ MEGDNN_COMMA dst_t_>,            \
                      __VA_ARGS__);                                          \
              }                                                              \
              MIDOUT_END();                                                  \
          } break;                                                           \
          case param::ConvBias::NonlineMode::RELU: {                         \
              MIDOUT_BEGIN(midout_tag_, bias_id_, 1) {                       \
                  cb_(bmode_, ReluOp<src_t_ MEGDNN_COMMA dst_t_>,            \
                      __VA_ARGS__);                                          \
              }                                                              \
              MIDOUT_END();                                                  \
          } break;                                                           \
          case param::ConvBias::NonlineMode::SIGMOID: {                      \
              MIDOUT_BEGIN(midout_tag_, bias_id_, 2) {                       \
                  cb_(bmode_, SigmoidOp<src_t_ MEGDNN_COMMA dst_t_>,         \
                      __VA_ARGS__);                                          \
              }                                                              \
              MIDOUT_END();                                                  \
          } break;                                                           \
          case param::ConvBias::NonlineMode::H_SWISH: {                      \
              MIDOUT_BEGIN(midout_tag_, bias_id_, 3) {                       \
                  cb_(bmode_, HSwishOp<src_t_ MEGDNN_COMMA dst_t_>,          \
                      __VA_ARGS__);                                          \
              }                                                              \
              MIDOUT_END();                                                  \
          } break;                                                           \
          default:                                                           \
              megdnn_assert(0);                                              \
              break;                                                         \
      }
  //! DISPATCH_CONV_WINOGRAD_NONLINE_QUANTIZED(tag, cb, bias_id, src_t, dst_t,
  //!                                          bmode, nonline_mode, ...)
  //! Quantized nonlinear-mode switch: calls cb(bmode, Op<src_t, dst_t>, ...)
  //! inside midout region (tag, bias_id, k) for
  //!   IDENTITY -> TypeCvtOp (k = 0), RELU -> ReluOp (k = 1).
  //! Any other mode trips megdnn_assert(0).
  #define DISPATCH_CONV_WINOGRAD_NONLINE_QUANTIZED(                          \
          midout_tag_, cb_, bias_id_, src_t_, dst_t_, bmode_, nonline_mode_, \
          ...)                                                               \
      switch (nonline_mode_) {                                               \
          case param::ConvBias::NonlineMode::IDENTITY: {                     \
              MIDOUT_BEGIN(midout_tag_, bias_id_, 0) {                       \
                  cb_(bmode_, TypeCvtOp<src_t_ MEGDNN_COMMA dst_t_>,         \
                      __VA_ARGS__);                                          \
              }                                                              \
              MIDOUT_END();                                                  \
          } break;                                                           \
          case param::ConvBias::NonlineMode::RELU: {                         \
              MIDOUT_BEGIN(midout_tag_, bias_id_, 1) {                       \
                  cb_(bmode_, ReluOp<src_t_ MEGDNN_COMMA dst_t_>,            \
                      __VA_ARGS__);                                          \
              }                                                              \
              MIDOUT_END();                                                  \
          } break;                                                           \
          default:                                                           \
              megdnn_assert(0);                                              \
              break;                                                         \
      }
  //! DISPATCH_CONV_WINOGRAD_BIAS(tag, cb, src_t, dst_t, bmode, nonline_mode,
  //!                             ...)
  //! Switches on bmode and forwards to DISPATCH_CONV_WINOGRAD_NONLINE with a
  //! compile-time BiasMode and midout bias id:
  //!   BIAS -> 0, NO_BIAS -> 1, BROADCAST_CHANNEL_BIAS -> 2.
  //! Any other bias mode trips megdnn_assert(0).
  #define DISPATCH_CONV_WINOGRAD_BIAS(midout_tag_, cb_, src_t_, dst_t_,      \
                                      bmode_, nonline_mode_, ...)            \
      switch (bmode_) {                                                      \
          case BiasMode::BIAS: {                                             \
              DISPATCH_CONV_WINOGRAD_NONLINE(midout_tag_, cb_, 0, src_t_,    \
                                             dst_t_, BiasMode::BIAS,         \
                                             nonline_mode_, __VA_ARGS__)     \
          } break;                                                           \
          case BiasMode::NO_BIAS: {                                          \
              DISPATCH_CONV_WINOGRAD_NONLINE(midout_tag_, cb_, 1, src_t_,    \
                                             dst_t_, BiasMode::NO_BIAS,      \
                                             nonline_mode_, __VA_ARGS__)     \
          } break;                                                           \
          case BiasMode::BROADCAST_CHANNEL_BIAS: {                           \
              DISPATCH_CONV_WINOGRAD_NONLINE(                                \
                      midout_tag_, cb_, 2, src_t_, dst_t_,                   \
                      BiasMode::BROADCAST_CHANNEL_BIAS, nonline_mode_,       \
                      __VA_ARGS__)                                           \
          } break;                                                           \
          default:                                                           \
              megdnn_assert(0);                                              \
              break;                                                         \
      }
  //! DISPATCH_CONV_WINOGRAD_BIAS_QUANTIZED(tag, cb, src_t, dst_t, bmode,
  //!                                       nonline_mode, ...)
  //! Quantized bias-mode switch: forwards to
  //! DISPATCH_CONV_WINOGRAD_NONLINE_QUANTIZED with a compile-time BiasMode
  //! and midout bias id:
  //!   BIAS -> 0, NO_BIAS -> 1, BROADCAST_CHANNEL_BIAS -> 2.
  //! Any other bias mode trips megdnn_assert(0).
  #define DISPATCH_CONV_WINOGRAD_BIAS_QUANTIZED(                             \
          midout_tag_, cb_, src_t_, dst_t_, bmode_, nonline_mode_, ...)      \
      switch (bmode_) {                                                      \
          case BiasMode::BIAS: {                                             \
              DISPATCH_CONV_WINOGRAD_NONLINE_QUANTIZED(                      \
                      midout_tag_, cb_, 0, src_t_, dst_t_, BiasMode::BIAS,   \
                      nonline_mode_, __VA_ARGS__)                            \
          } break;                                                           \
          case BiasMode::NO_BIAS: {                                          \
              DISPATCH_CONV_WINOGRAD_NONLINE_QUANTIZED(                      \
                      midout_tag_, cb_, 1, src_t_, dst_t_,                   \
                      BiasMode::NO_BIAS, nonline_mode_, __VA_ARGS__)         \
          } break;                                                           \
          case BiasMode::BROADCAST_CHANNEL_BIAS: {                           \
              DISPATCH_CONV_WINOGRAD_NONLINE_QUANTIZED(                      \
                      midout_tag_, cb_, 2, src_t_, dst_t_,                   \
                      BiasMode::BROADCAST_CHANNEL_BIAS, nonline_mode_,       \
                      __VA_ARGS__)                                           \
          } break;                                                           \
          default:                                                           \
              megdnn_assert(0);                                              \
              break;                                                         \
      }
  346. } // namespace

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。