You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

postprocess_helper.h 25 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467
  1. /**
  2. * \file dnn/src/x86/conv_bias/postprocess_helper.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
  12. #include "megdnn/opr_param_defs.h"
  13. #include "src/fallback/conv_bias/common.h"
  14. #include "src/x86/elemwise_op.h"
  15. #include "src/x86/utils.h"
  16. #include "src/fallback/conv_bias/opr_impl.h"
  17. namespace megdnn {
  18. namespace x86 {
  // Case-label helpers that translate a ConvBias::NonlineMode into the
  // Elemwise::Mode used by the post-processing kernels. Each expands to one
  // `case` (or `default`) label of a `switch` over the nonlinearity and
  // assigns the caller's local `elem_mode`.

  // Bias present: the bias addition is fused with the nonlinearity,
  // e.g. RELU -> FUSE_ADD_RELU.
  #define BIAS_CASE(_nl)                                             \
      case megdnn::param::ConvBias::NonlineMode::_nl:                \
          elem_mode = megdnn::param::Elemwise::Mode::FUSE_ADD_##_nl; \
          break;

  // No bias: the nonlinearity maps to the Elemwise mode of the same name.
  #define NOBIAS_CASE(_nl)                                 \
      case megdnn::param::ConvBias::NonlineMode::_nl:      \
          elem_mode = megdnn::param::Elemwise::Mode::_nl;  \
          break;

  // Accepted mode that leaves `elem_mode` at its initial value.
  #define IDENTITY_CASE(_nl)                          \
      case megdnn::param::ConvBias::NonlineMode::_nl: \
          break;

  // Every nonlinearity not listed by the caller is rejected.
  #define DEFAULT_CASE                            \
      default:                                    \
          megdnn_throw("unsupported nolinemode"); \
          break;
  // Kernel invokers for the float path: the convolution output, the bias and
  // the destination all use the element type `ctype`. Each macro expands to
  // statements that need `ctype`, `conv_dst_ptr`, `bias_ptr`, `dst_ptr`,
  // `bias_type`, `dst_type`, `N`, `OC`, `OH` and `OW` in scope at the
  // expansion site (PostProcess<...>::run).

  // Unary op over all N * OC * OH * OW outputs; no bias operand.
  #define CALL_UNARY(_op, _simd_type)                                      \
      thin_function<void(const ctype*, ctype*, DType, DType, size_t)>      \
              op_caller = OpCallerUnary<_op<_simd_type, ctype, ctype>,     \
                                        _simd_type>::run;                  \
      op_caller(static_cast<ctype*>(conv_dst_ptr),                         \
                reinterpret_cast<ctype*>(dst_ptr), bias_type, dst_type,    \
                N * OC * OH * OW);

  // Binary op with a channel-broadcast bias. The operands are described to
  // the VEC_BCAST101 caller as (N, OC, OH * OW), so the bias presumably holds
  // one value per output channel.
  #define CALL_BINARY_BROADCAST(_op, _simd_type)                           \
      thin_function<void(const ctype*, const ctype*, ctype*, DType, DType, \
                         DType, size_t, size_t, size_t)>                   \
              op_caller = OpCallerBinary<                                  \
                      _op<_simd_type, ctype, ctype>, _simd_type,           \
                      megdnn::x86::BcastType::VEC_BCAST101>::run;          \
      op_caller(static_cast<ctype*>(conv_dst_ptr),                         \
                static_cast<ctype*>(bias_ptr),                             \
                reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type,   \
                dst_type, N, OC, OH * OW);

  // Binary op with a bias operand presumably shaped like the output
  // (N * OC * OH * OW elements, VEC_VEC caller).
  #define CALL_BINARY(_op, _simd_type)                                     \
      thin_function<void(const ctype*, const ctype*, ctype*, DType, DType, \
                         DType, size_t)>                                   \
              op_caller = OpCallerBinary<                                  \
                      _op<_simd_type, ctype, ctype>, _simd_type,           \
                      megdnn::x86::BcastType::VEC_VEC>::run;               \
      op_caller(static_cast<ctype*>(conv_dst_ptr),                         \
                static_cast<ctype*>(bias_ptr),                             \
                reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type,   \
                dst_type, N * OC * OH * OW);
  // Float-path op selection. These expand inside PostProcess<...>::run, read
  // its local `elem_mode`, and rely on the CALL_* invokers being defined.

  // Unary (no-bias) op for one SIMD level. Any other mode (e.g. the ADD left
  // in place for IDENTITY) launches no kernel.
  #define cb_unary(_simd_type)                                          \
      if (elem_mode == megdnn::param::Elemwise::Mode::H_SWISH) {        \
          CALL_UNARY(HSwishOp, _simd_type);                             \
      } else if (elem_mode == megdnn::param::Elemwise::Mode::SIGMOID) { \
          CALL_UNARY(SigmoidOp, _simd_type);                            \
      } else if (elem_mode == megdnn::param::Elemwise::Mode::RELU) {    \
          CALL_UNARY(ReluOp, _simd_type);                               \
      }

  // No-bias dispatch: the widest supported SIMD level wins
  // (AVX2, then SSE4.2, then the scalar NONE kernel).
  #define FOR_NONLINEAR_NOBIAS()                   \
      if (is_supported(SIMDType::AVX2)) {          \
          cb_unary(SIMDType::AVX2)                 \
      } else if (is_supported(SIMDType::SSE4_2)) { \
          cb_unary(SIMDType::SSE4_2)               \
      } else {                                     \
          cb_unary(SIMDType::NONE)                 \
      }

  // Bias op for one SIMD level; `_caller` is CALL_BINARY or
  // CALL_BINARY_BROADCAST depending on the bias layout.
  #define cb_binary(_caller, _simd_type)                                         \
      if (elem_mode == megdnn::param::Elemwise::Mode::FUSE_ADD_H_SWISH) {        \
          _caller(FuseAddHSwishOp, _simd_type);                                  \
      } else if (elem_mode == megdnn::param::Elemwise::Mode::FUSE_ADD_RELU) {    \
          _caller(FuseAddReluOp, _simd_type);                                    \
      } else if (elem_mode == megdnn::param::Elemwise::Mode::FUSE_ADD_SIGMOID) { \
          _caller(FuseAddSigmoidOp, _simd_type);                                 \
      } else if (elem_mode == megdnn::param::Elemwise::Mode::ADD) {              \
          _caller(AddOp, _simd_type);                                            \
      }

  // Bias dispatch across SIMD levels, same priority as FOR_NONLINEAR_NOBIAS.
  #define FOR_NONLINEAR(CALLER)                    \
      if (is_supported(SIMDType::AVX2)) {          \
          cb_binary(CALLER, SIMDType::AVX2)        \
      } else if (is_supported(SIMDType::SSE4_2)) { \
          cb_binary(CALLER, SIMDType::SSE4_2)      \
      } else {                                     \
          cb_binary(CALLER, SIMDType::NONE)        \
      }

  // Top-level dispatch on the bias layout; unknown layouts do nothing.
  #define FOR_BIAS(bias_mode)                       \
      switch (bias_mode) {                          \
          case BiasMode::BIAS:                      \
              FOR_NONLINEAR(CALL_BINARY);           \
              break;                                \
          case BiasMode::BROADCAST_CHANNEL_BIAS:    \
              FOR_NONLINEAR(CALL_BINARY_BROADCAST); \
              break;                                \
          case BiasMode::NO_BIAS:                   \
              FOR_NONLINEAR_NOBIAS();               \
              break;                                \
          default:                                  \
              break;                                \
      }
  103. template <typename ctype, typename dtype = ctype,
  104. megdnn::PostprocessMode postprocess_mode =
  105. megdnn::PostprocessMode::FLOAT>
  106. struct PostProcess {
  107. static void run(void* conv_dst_ptr, void* bias_ptr, void* dst_ptr,
  108. megdnn::ConvBiasForward::BiasMode bias_mode,
  109. megdnn::param::ConvBias::NonlineMode nonlineMode,
  110. DType bias_type, DType dst_type, size_t N, size_t OC,
  111. size_t OH, size_t OW, size_t pack_oc_size = 1) {
  112. MEGDNN_MARK_USED_VAR(pack_oc_size);
  113. megdnn::param::Elemwise::Mode elem_mode =
  114. megdnn::param::Elemwise::Mode::ADD;
  115. if (bias_mode != megdnn::ConvBiasForward::BiasMode::NO_BIAS) {
  116. switch (nonlineMode) {
  117. BIAS_CASE(RELU);
  118. BIAS_CASE(SIGMOID);
  119. BIAS_CASE(H_SWISH);
  120. IDENTITY_CASE(IDENTITY);
  121. DEFAULT_CASE;
  122. }
  123. } else {
  124. switch (nonlineMode) {
  125. NOBIAS_CASE(RELU);
  126. NOBIAS_CASE(SIGMOID);
  127. NOBIAS_CASE(H_SWISH);
  128. IDENTITY_CASE(IDENTITY);
  129. DEFAULT_CASE;
  130. }
  131. }
  132. FOR_BIAS(bias_mode);
  133. }
  134. };
  135. template <typename ctype, typename dtype>
  136. struct PostProcess<ctype, dtype, megdnn::PostprocessMode::FLOAT> {
  137. static void run(void* conv_dst_ptr, void* bias_ptr, void* dst_ptr,
  138. megdnn::ConvBiasForward::BiasMode bias_mode,
  139. megdnn::param::ConvBias::NonlineMode nonlineMode,
  140. DType bias_type, DType dst_type, size_t N, size_t OC,
  141. size_t OH, size_t OW, size_t pack_oc_size=1) {
  142. MEGDNN_MARK_USED_VAR(pack_oc_size);
  143. megdnn::param::Elemwise::Mode elem_mode =
  144. megdnn::param::Elemwise::Mode::ADD;
  145. if (bias_mode != megdnn::ConvBiasForward::BiasMode::NO_BIAS) {
  146. switch (nonlineMode) {
  147. BIAS_CASE(RELU);
  148. BIAS_CASE(SIGMOID);
  149. BIAS_CASE(H_SWISH);
  150. IDENTITY_CASE(IDENTITY);
  151. DEFAULT_CASE;
  152. }
  153. } else {
  154. switch (nonlineMode) {
  155. NOBIAS_CASE(RELU);
  156. NOBIAS_CASE(SIGMOID);
  157. NOBIAS_CASE(H_SWISH);
  158. IDENTITY_CASE(IDENTITY);
  159. DEFAULT_CASE;
  160. }
  161. }
  162. FOR_BIAS(bias_mode);
  163. }
  164. };
  165. template <typename ctype, typename dtype>
  166. struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> {
  167. static void run(void* conv_dst_ptr, void* bias_ptr, void* dst_ptr,
  168. megdnn::ConvBiasForward::BiasMode bias_mode,
  169. megdnn::param::ConvBias::NonlineMode nonlineMode,
  170. DType bias_type, DType dst_type, size_t N, size_t OC,
  171. size_t OH, size_t OW,size_t pack_oc_size = 1) {
  172. MEGDNN_MARK_USED_VAR(pack_oc_size);
  173. MEGDNN_MARK_USED_VAR(conv_dst_ptr);
  174. MEGDNN_MARK_USED_VAR(bias_ptr);
  175. MEGDNN_MARK_USED_VAR(dst_ptr);
  176. MEGDNN_MARK_USED_VAR(bias_mode);
  177. MEGDNN_MARK_USED_VAR(nonlineMode);
  178. MEGDNN_MARK_USED_VAR(bias_type);
  179. MEGDNN_MARK_USED_VAR(dst_type);
  180. MEGDNN_MARK_USED_VAR(N);
  181. MEGDNN_MARK_USED_VAR(OC);
  182. MEGDNN_MARK_USED_VAR(OH);
  183. MEGDNN_MARK_USED_VAR(OW);
  184. }
  185. };
  // The float-path helpers are done; drop them so the quantized variants
  // below can reuse the same names. CALL_BINARY is deliberately not listed:
  // the quantized path does not redefine it.
  #undef CALL_BINARY_BROADCAST
  #undef CALL_UNARY
  #undef cb_unary
  #undef cb_binary
  #undef FOR_BIAS
  #undef FOR_NONLINEAR
  #undef FOR_NONLINEAR_NOBIAS
  // Kernel invokers for the quantized path: the convolution output and bias
  // are read as `ctype` while the destination is written as `dtype`. Same
  // expansion-site requirements as the float invokers (ctype, dtype, the
  // three buffers, bias_type, dst_type, N, OC, OH, OW).

  // Unary ctype -> dtype op over all N * OC * OH * OW outputs.
  #define CALL_UNARY(_op, _simd_type)                                      \
      thin_function<void(const ctype*, dtype*, DType, DType, size_t)>      \
              op_caller = OpCallerUnary<_op<_simd_type, ctype, dtype>,     \
                                        _simd_type>::run;                  \
      op_caller(static_cast<ctype*>(conv_dst_ptr),                         \
                reinterpret_cast<dtype*>(dst_ptr), bias_type, dst_type,    \
                N * OC * OH * OW);

  // Binary op with a channel-broadcast bias, described to the VEC_BCAST101
  // caller as (N, OC, OH * OW).
  #define CALL_BINARY_BROADCAST(_op, _simd_type)                           \
      thin_function<void(const ctype*, const ctype*, dtype*, DType, DType, \
                         DType, size_t, size_t, size_t)>                   \
              op_caller = OpCallerBinary<                                  \
                      _op<_simd_type, ctype, dtype>, _simd_type,           \
                      megdnn::x86::BcastType::VEC_BCAST101>::run;          \
      op_caller(static_cast<ctype*>(conv_dst_ptr),                         \
                static_cast<ctype*>(bias_ptr),                             \
                reinterpret_cast<dtype*>(dst_ptr), bias_type, bias_type,   \
                dst_type, N, OC, OH * OW);
  // Quantized-path op selection. These expand inside
  // PostProcess<..., QUANTIZED>::run, read its locals `elem_mode` and
  // `nonlineMode`, and rely on the quantized CALL_UNARY /
  // CALL_BINARY_BROADCAST invokers being defined.

  // Unary (no-bias) op for one SIMD level. IDENTITY still needs a kernel on
  // this path: TypeCvtOp, presumably a plain ctype -> dtype conversion.
  // Any other unmatched mode is rejected.
  #define cb_unary(_simd_type)                                                 \
      if (elem_mode == megdnn::param::Elemwise::Mode::RELU) {                  \
          CALL_UNARY(ReluOp, _simd_type);                                      \
      } else if (elem_mode == megdnn::param::Elemwise::Mode::H_SWISH) {        \
          CALL_UNARY(HSwishOp, _simd_type);                                    \
      } else {                                                                 \
          if (nonlineMode == megdnn::param::ConvBias::NonlineMode::IDENTITY) { \
              CALL_UNARY(TypeCvtOp, _simd_type);                               \
          } else {                                                             \
              megdnn_throw("not supported nonlinemode\n");                     \
          }                                                                    \
      }

  // No-bias dispatch across SIMD levels. NOTE(review): on AVX2 machines only
  // RELU uses the AVX2 kernel; H_SWISH and IDENTITY (TypeCvtOp) are routed to
  // the SIMDType::NONE kernel. This looks deliberate (presumably no AVX2
  // quantized kernel exists for those ops) and is preserved; confirm.
  #define FOR_NONLINEAR_NOBIAS()                                          \
      if (is_supported(SIMDType::AVX2)) {                                 \
          if (elem_mode == megdnn::param::Elemwise::Mode::RELU) {         \
              CALL_UNARY(ReluOp, SIMDType::AVX2);                         \
          } else if (elem_mode == megdnn::param::Elemwise::Mode::H_SWISH) { \
              CALL_UNARY(HSwishOp, SIMDType::NONE);                       \
          } else {                                                        \
              if (nonlineMode ==                                          \
                  megdnn::param::ConvBias::NonlineMode::IDENTITY) {       \
                  CALL_UNARY(TypeCvtOp, SIMDType::NONE);                  \
              } else {                                                    \
                  megdnn_throw("not supported nonlinemode\n");            \
              }                                                           \
          }                                                               \
      } else if (is_supported(SIMDType::SSE4_2)) {                        \
          cb_unary(SIMDType::SSE4_2)                                      \
      } else {                                                            \
          cb_unary(SIMDType::NONE)                                        \
      }

  // Bias op for one SIMD level (SIGMOID is not offered on this path).
  #define cb_binary(_caller, _simd_type)                                         \
      if (elem_mode == megdnn::param::Elemwise::Mode::ADD) {                     \
          _caller(AddOp, _simd_type);                                            \
      } else if (elem_mode == megdnn::param::Elemwise::Mode::FUSE_ADD_RELU) {    \
          _caller(FuseAddReluOp, _simd_type);                                    \
      } else if (elem_mode == megdnn::param::Elemwise::Mode::FUSE_ADD_H_SWISH) { \
          _caller(FuseAddHSwishOp, _simd_type);                                  \
      }

  // Bias dispatch across SIMD levels: AVX2, then SSE4.2, then scalar. The
  // SSE4.2 branch must only be taken when SSE4.2 IS supported. The previous
  // `!is_supported(...)` test ran SSE4.2 code on CPUs lacking it and sent
  // SSE4.2-capable CPUs to the scalar kernel.
  #define FOR_NONLINEAR(CALLER)                    \
      if (is_supported(SIMDType::AVX2)) {          \
          cb_binary(CALLER, SIMDType::AVX2)        \
      } else if (is_supported(SIMDType::SSE4_2)) { \
          cb_binary(CALLER, SIMDType::SSE4_2)      \
      } else {                                     \
          cb_binary(CALLER, SIMDType::NONE)        \
      }

  // Top-level dispatch on the bias layout. Only NO_BIAS and
  // BROADCAST_CHANNEL_BIAS are handled. Any other mode, including
  // BiasMode::BIAS, hits `default` and performs no post-processing.
  // NOTE(review): confirm callers never request a full bias tensor on the
  // quantized path.
  #define FOR_BIAS(bias_mode)                       \
      switch (bias_mode) {                          \
          case BiasMode::NO_BIAS:                   \
              FOR_NONLINEAR_NOBIAS();               \
              break;                                \
          case BiasMode::BROADCAST_CHANNEL_BIAS:    \
              FOR_NONLINEAR(CALL_BINARY_BROADCAST); \
              break;                                \
          default:                                  \
              break;                                \
      }
  264. template <typename ctype, typename dtype>
  265. struct PostProcess<ctype, dtype, megdnn::PostprocessMode::QUANTIZED> {
  266. static void run(void* conv_dst_ptr, void* bias_ptr, void* dst_ptr,
  267. megdnn::ConvBiasForward::BiasMode bias_mode,
  268. megdnn::param::ConvBiasV0::NonlineMode nonlineMode,
  269. DType bias_type, DType dst_type, size_t N, size_t OC,
  270. size_t OH, size_t OW, size_t pack_oc_size = 1) {
  271. MEGDNN_MARK_USED_VAR(pack_oc_size);
  272. megdnn::param::Elemwise::Mode elem_mode =
  273. megdnn::param::Elemwise::Mode::ADD;
  274. if (bias_mode != megdnn::ConvBiasForward::BiasMode::NO_BIAS) {
  275. switch (nonlineMode) {
  276. BIAS_CASE(RELU);
  277. BIAS_CASE(H_SWISH);
  278. IDENTITY_CASE(IDENTITY);
  279. DEFAULT_CASE;
  280. }
  281. } else {
  282. switch (nonlineMode) {
  283. NOBIAS_CASE(RELU);
  284. NOBIAS_CASE(H_SWISH);
  285. IDENTITY_CASE(IDENTITY);
  286. DEFAULT_CASE;
  287. }
  288. }
  289. FOR_BIAS(bias_mode);
  290. #undef FOR_NONLINEAR_NOBIAS
  291. #undef FOR_NONLINEAR
  292. #undef FOR_BIAS
  293. }
  294. };
  // Drop every remaining private helper macro so none of them leaks into the
  // files that include this header. IDENTITY_CASE is listed too: without it,
  // it would stay defined for every includer.
  #undef cb_unary
  #undef cb_binary
  #undef BIAS_CASE
  #undef NOBIAS_CASE
  #undef IDENTITY_CASE
  #undef DEFAULT_CASE
  #undef CALL_UNARY
  #undef CALL_BINARY
  #undef CALL_BINARY_BROADCAST
  // Dispatch helpers for the winograd conv_bias kernels. `cb` is a
  // caller-supplied macro invoked as
  //     cb(bias_mode, Op<simd MEGDNN_COMMA src MEGDNN_COMMA dst>, extra...)
  // where MEGDNN_COMMA presumably expands to a comma, keeping the template
  // argument list a single macro argument. Each (bias mode, nonlinearity)
  // combination sits in its own MIDOUT region keyed by (_bias_id, nonline id).

  // Nonlinearity -> op: IDENTITY -> NoneOp (id 0), RELU -> ReluOp (1),
  // SIGMOID -> SigmoidOp (2), H_SWISH -> HSwishOp (3); anything else trips
  // megdnn_assert.
  #define DISPATCH_CONV_WINOGRAD_NONLINE(_midout_tag, cb, _bias_id, _simd_type, \
                                         _src_type, _dst_type, _bmode,          \
                                         _nonline_mode, ...)                    \
      switch (_nonline_mode) {                                                  \
          case param::ConvBias::NonlineMode::IDENTITY: {                        \
              MIDOUT_BEGIN(_midout_tag, _bias_id, 0) {                          \
                  cb(_bmode,                                                    \
                     NoneOp<_simd_type MEGDNN_COMMA _src_type MEGDNN_COMMA      \
                            _dst_type>,                                         \
                     __VA_ARGS__);                                              \
              }                                                                 \
              MIDOUT_END();                                                     \
              break;                                                            \
          }                                                                     \
          case param::ConvBias::NonlineMode::RELU: {                            \
              MIDOUT_BEGIN(_midout_tag, _bias_id, 1) {                          \
                  cb(_bmode,                                                    \
                     ReluOp<_simd_type MEGDNN_COMMA _src_type MEGDNN_COMMA      \
                            _dst_type>,                                         \
                     __VA_ARGS__);                                              \
              }                                                                 \
              MIDOUT_END();                                                     \
              break;                                                            \
          }                                                                     \
          case param::ConvBias::NonlineMode::SIGMOID: {                         \
              MIDOUT_BEGIN(_midout_tag, _bias_id, 2) {                          \
                  cb(_bmode,                                                    \
                     SigmoidOp<_simd_type MEGDNN_COMMA _src_type MEGDNN_COMMA   \
                               _dst_type>,                                      \
                     __VA_ARGS__);                                              \
              }                                                                 \
              MIDOUT_END();                                                     \
              break;                                                            \
          }                                                                     \
          case param::ConvBias::NonlineMode::H_SWISH: {                         \
              MIDOUT_BEGIN(_midout_tag, _bias_id, 3) {                          \
                  cb(_bmode,                                                    \
                     HSwishOp<_simd_type MEGDNN_COMMA _src_type MEGDNN_COMMA    \
                              _dst_type>,                                       \
                     __VA_ARGS__);                                              \
              }                                                                 \
              MIDOUT_END();                                                     \
              break;                                                            \
          }                                                                     \
          default:                                                              \
              megdnn_assert(0);                                                 \
              break;                                                            \
      }

  // Bias mode -> MIDOUT bias id: BIAS 0, NO_BIAS 1, BROADCAST_CHANNEL_BIAS 2;
  // anything else trips megdnn_assert.
  #define DISPATCH_CONV_WINOGRAD_BIAS(_midout_tag, cb, _simd_type, _src_type, \
                                      _dst_type, _bmode, _nonline_mode, ...)  \
      switch (_bmode) {                                                       \
          case BiasMode::NO_BIAS: {                                           \
              DISPATCH_CONV_WINOGRAD_NONLINE(                                 \
                      _midout_tag, cb, 1, _simd_type, _src_type, _dst_type,   \
                      BiasMode::NO_BIAS, _nonline_mode, __VA_ARGS__)          \
              break;                                                          \
          }                                                                   \
          case BiasMode::BROADCAST_CHANNEL_BIAS: {                            \
              DISPATCH_CONV_WINOGRAD_NONLINE(                                 \
                      _midout_tag, cb, 2, _simd_type, _src_type, _dst_type,   \
                      BiasMode::BROADCAST_CHANNEL_BIAS, _nonline_mode,        \
                      __VA_ARGS__)                                            \
              break;                                                          \
          }                                                                   \
          case BiasMode::BIAS: {                                              \
              DISPATCH_CONV_WINOGRAD_NONLINE(                                 \
                      _midout_tag, cb, 0, _simd_type, _src_type, _dst_type,   \
                      BiasMode::BIAS, _nonline_mode, __VA_ARGS__)             \
              break;                                                          \
          }                                                                   \
          default:                                                            \
              megdnn_assert(0);                                               \
              break;                                                          \
      }
  // Quantized counterparts of the winograd dispatch helpers. Only IDENTITY
  // (TypeCvtOp, MIDOUT id 0) and RELU (ReluOp, id 1) are accepted; any other
  // nonlinearity trips megdnn_assert. `cb` is invoked exactly as in
  // DISPATCH_CONV_WINOGRAD_NONLINE.
  #define DISPATCH_CONV_WINOGRAD_NONLINE_QUANTIZED(                           \
          _midout_tag, cb, _bias_id, _simd_type, _src_type, _dst_type,        \
          _bmode, _nonline_mode, ...)                                         \
      switch (_nonline_mode) {                                                \
          case param::ConvBias::NonlineMode::IDENTITY: {                      \
              MIDOUT_BEGIN(_midout_tag, _bias_id, 0) {                        \
                  cb(_bmode,                                                  \
                     TypeCvtOp<_simd_type MEGDNN_COMMA _src_type MEGDNN_COMMA \
                               _dst_type>,                                    \
                     __VA_ARGS__);                                            \
              }                                                               \
              MIDOUT_END();                                                   \
              break;                                                          \
          }                                                                   \
          case param::ConvBias::NonlineMode::RELU: {                          \
              MIDOUT_BEGIN(_midout_tag, _bias_id, 1) {                        \
                  cb(_bmode,                                                  \
                     ReluOp<_simd_type MEGDNN_COMMA _src_type MEGDNN_COMMA    \
                            _dst_type>,                                       \
                     __VA_ARGS__);                                            \
              }                                                               \
              MIDOUT_END();                                                   \
              break;                                                          \
          }                                                                   \
          default:                                                            \
              megdnn_assert(0);                                               \
              break;                                                          \
      }

  // Bias mode -> MIDOUT bias id: BIAS 0, NO_BIAS 1, BROADCAST_CHANNEL_BIAS 2;
  // anything else trips megdnn_assert.
  #define DISPATCH_CONV_WINOGRAD_BIAS_QUANTIZED(_midout_tag, cb, _simd_type,  \
                                                _src_type, _dst_type, _bmode, \
                                                _nonline_mode, ...)           \
      switch (_bmode) {                                                       \
          case BiasMode::NO_BIAS: {                                           \
              DISPATCH_CONV_WINOGRAD_NONLINE_QUANTIZED(                       \
                      _midout_tag, cb, 1, _simd_type, _src_type, _dst_type,   \
                      BiasMode::NO_BIAS, _nonline_mode, __VA_ARGS__)          \
              break;                                                          \
          }                                                                   \
          case BiasMode::BROADCAST_CHANNEL_BIAS: {                            \
              DISPATCH_CONV_WINOGRAD_NONLINE_QUANTIZED(                       \
                      _midout_tag, cb, 2, _simd_type, _src_type, _dst_type,   \
                      BiasMode::BROADCAST_CHANNEL_BIAS, _nonline_mode,        \
                      __VA_ARGS__)                                            \
              break;                                                          \
          }                                                                   \
          case BiasMode::BIAS: {                                              \
              DISPATCH_CONV_WINOGRAD_NONLINE_QUANTIZED(                       \
                      _midout_tag, cb, 0, _simd_type, _src_type, _dst_type,   \
                      BiasMode::BIAS, _nonline_mode, __VA_ARGS__)             \
              break;                                                          \
          }                                                                   \
          default:                                                            \
              megdnn_assert(0);                                               \
              break;                                                          \
      }
  431. } // namespace x86
  432. } // namespace megdnn

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。