
algo.cpp

/**
 * \file dnn/src/arm_common/pooling/algo.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "src/arm_common/pooling/algo.h"
#include "megdnn/opr_param_defs.h"
#include "src/arm_common/pooling/do_max_pooling_3x3_s2x2_int8.h"
#include "src/arm_common/pooling/do_max_pooling_w2x2_s2x2.h"
#include "src/arm_common/pooling/do_max_pooling_w4x4_s2x2.h"
#include "src/arm_common/pooling/do_pooling_2x2_nchw44.h"
#include "src/arm_common/pooling/do_pooling_3x3_nchw44.h"
#include "src/arm_common/pooling/do_pooling_4x4_nchw44.h"
#include "src/arm_common/pooling/do_pooling_5x5_nchw44.h"

#include "midout.h"

MIDOUT_DECL(megdnn_arm_common_pooling)

namespace megdnn {
namespace arm_common {
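
//! Workspace layout for the stride-2 NCHW kernels below (3x3/5x5 max and 3x3
//! average pooling): one OW-sized row buffer per filter row plus two
//! (IW + 1) / 2 buffers for the even/odd column split, 16-byte aligned.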
WorkspaceBundle get_bundle(const PoolingImpl::PoolingKernSizeParam& param) {
    megdnn_assert((param.src_type.category() == DTypeCategory::FLOAT ||
                   param.src_type.enumv() == DTypeEnum::QuantizedS8 ||
                   param.src_type.enumv() == DTypeEnum::Quantized8Asymm ||
                   param.src_type == dtype::Int8{}) &&
                  param.format == param::Pooling::Format::NCHW &&
                  (param.mode == param::Pooling::Mode::MAX ||
                   (param.mode == param::Pooling::Mode::AVERAGE &&
                    param.filter[0] == 3)) &&
                  param.filter[0] == param.filter[1] &&
                  (param.filter[0] == 3 || param.filter[1] == 5) &&
                  param.stride[0] == 2 && param.stride[1] == 2 &&
                  param.isz[0] >= 2 && param.isz[1] >= 2);
    //! max pooling nxn stride 2
    auto IW = param.isz[1];
    auto OW = param.osz[1];
    // To handle an odd-sized filter, first store one input row split into its
    // even and odd columns, then process those buffers to produce one output
    // row; n such output rows have to be kept at the same time.
    SmallVector<size_t> needed_mem;
    for (size_t i = 0; i < param.filter[0]; ++i)
        needed_mem.push_back(OW * param.src_type.size());
    needed_mem.push_back((IW + 1) / 2 * param.src_type.size());
    needed_mem.push_back((IW + 1) / 2 * param.src_type.size());
    WorkspaceBundle ws(nullptr, needed_mem, 16);
    return ws;
}
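
//! Workspace for the NCHW44 int8 kernels: a single buffer large enough to
//! hold the padded copy of one 4-channel-packed feature map; empty when
//! PH == PW == 0.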
WorkspaceBundle get_bundle_nchw44(
        const PoolingImpl::PoolingKernSizeParam& param) {
    megdnn_assert((param.src_type.enumv() == DTypeEnum::QuantizedS8 ||
                   param.src_type.enumv() == DTypeEnum::Int8) &&
                  (param.format == param::Pooling::Format::NCHW44));
    auto IH = param.isz[0];
    auto IW = param.isz[1];
    auto PH = param.padding[0];
    auto PW = param.padding[1];
    size_t padding_size = 0;
    if ((PH != 0) || (PW != 0)) {
        padding_size = (IW + 2 * PW) * (IH + 2 * PH) * 4 * sizeof(int8_t);
    }
    return WorkspaceBundle(nullptr, {padding_size});
}
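
//! Copy \p src into the workspace with the requested padding applied. Border
//! elements are filled with INT8_MIN for max pooling (so padding never wins a
//! comparison) and with 0 for average pooling; IH2/IW2 receive the padded
//! sizes. When no padding is needed, \p src is returned unchanged.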
const int8_t* handle_padding(const int8_t* src, size_t IH, size_t IW,
                             size_t& IH2, size_t& IW2, size_t PH, size_t PW,
                             const WorkspaceBundle& ws, bool is_max_mode) {
    int8_t* sptr_base = nullptr;
    int8_t padding_value = is_max_mode ? INT8_MIN : 0;
    bool need_pad = ((PH != 0) || (PW != 0)) ? true : false;
    if (need_pad) {
        IH2 = IH + 2 * PH;
        IW2 = IW + 2 * PW;
        sptr_base = static_cast<int8_t*>(ws.get(0));
        memset(sptr_base, padding_value, sizeof(int8_t) * IH2 * IW2 * 4);
        rep(ih, IH) {
            std::memcpy(sptr_base + (ih + PH) * IW2 * 4 + PW * 4,
                        src + ih * IW * 4, sizeof(int8_t) * IW * 4);
        }
    } else {
        IH2 = IH;
        IW2 = IW;
    }
    return need_pad ? sptr_base : src;
}
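
//! 2x2 and 3x3 windows with stride 1 on NCHW float/quantized data. The
//! DISPATCH_* macros below select the pooler implementation by data type,
//! pooling mode and window size, and run one (n, c) plane per task.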
bool PoolingImpl::AlgoFilterxModexStride1::usable(
        const PoolingKernSizeParam& param) const {
    auto SH = param.stride[0];
    auto SW = param.stride[1];
    auto FH = param.filter[0];
    auto FW = param.filter[1];
    bool avaible = (param.src_type.category() == DTypeCategory::FLOAT ||
                    param.src_type.category() == DTypeCategory::QUANTIZED) &&
                   param.format == Param::Format::NCHW && SH == 1 && SW == 1 &&
                   FH == FW && (FH == 2 || FH == 3);
    return avaible;
}

void PoolingImpl::AlgoFilterxModexStride1::exec(
        const PoolingKernParam& param) const {
    auto IH = param.isz[0], IW = param.isz[1];
    auto OH = param.osz[0], OW = param.osz[1];
    auto N = param.n, C = param.ic;
    auto PH = param.padding[0];
    auto PW = param.padding[1];
    auto FH = param.filter[0];
    void* src_ptr = param.src_ptr;
    void* dst_ptr = param.dst_ptr;

#define DISPATCH_FUNC(Pooler, NeonPooler, window, midout_type_id) \
    MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(0), \
                 midout_iv(midout_type_id), Pooler::MIDOUT_CASE_NUM, \
                 NeonPooler::MIDOUT_CASE_NUM, window) { \
        auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \
                    src_dtype = param.src_type](size_t index, size_t) { \
            size_t n = index / C; \
            size_t c = index % C; \
            do_pooling_compact< \
                    Pooler MEGDNN_COMMA NeonPooler MEGDNN_COMMA window>( \
                    static_cast<const typename Pooler::ctype*>(src_ptr) + \
                            n * C * IH * IW + c * IH * IW, \
                    static_cast<typename Pooler::ctype*>(dst_ptr) + \
                            n * C * OH * OW + c * OH * OW, \
                    src_dtype, IH, IW, OH, OW, PH, PW); \
        }; \
        MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
                static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \
                run); \
    } \
    MIDOUT_END()

#define DISPATCH_WINDOW(Pooler, NeonPooler, dtype, ctype, comp_type, \
                        midout_type_id) \
    switch (FH) { \
        case 2: { \
            using _Pooler = Pooler<4, dtype, ctype, comp_type>; \
            using _NeonPooler = NeonPooler<4, dtype, ctype, comp_type>; \
            DISPATCH_FUNC(_Pooler, _NeonPooler, 2, midout_type_id); \
            break; \
        } \
        case 3: { \
            using _Pooler = Pooler<9, dtype, ctype, comp_type>; \
            using _NeonPooler = NeonPooler<9, dtype, ctype, comp_type>; \
            DISPATCH_FUNC(_Pooler, _NeonPooler, 3, midout_type_id); \
            break; \
        } \
        default: \
            megdnn_assert(0, "unsupport pooling filter size"); \
            break; \
    }

#define DISPATCH_MODE(dtype, ctype, comp_type, midout_type_id) \
    switch (param.mode) { \
        case Mode::MAX: \
            DISPATCH_WINDOW(MaxPooler, NeonMaxPooler, dtype, ctype, comp_type, \
                            midout_type_id); \
            break; \
        case Mode::AVERAGE: \
            DISPATCH_WINDOW(MeanInPooler, NeonMeanPooler, dtype, ctype, \
                            comp_type, midout_type_id); \
            break; \
        default: \
            megdnn_assert(0, "unsupport pooling mode"); \
            break; \
    }

    if (param.src_type == dtype::Float32{}) {
        DISPATCH_MODE(dt_float32, float, float, 0);
    } else if (param.src_type.enumv() == DTypeEnum::QuantizedS8) {
        DISPATCH_MODE(dt_qint8, int8_t, float, 1);
    } else if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) {
        DISPATCH_MODE(dt_quint8, uint8_t, float, 2);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    } else if (param.src_type == dtype::Float16{}) {
        DISPATCH_MODE(dt_float16, __fp16, __fp16, 3);
#endif
    }
#undef DISPATCH_FUNC
#undef DISPATCH_WINDOW
#undef DISPATCH_MODE
}
bool PoolingImpl::AlgoFilter2ModexStride2::usable(
        const PoolingKernSizeParam& param) const {
    auto SH = param.stride[0];
    auto SW = param.stride[1];
    auto FH = param.filter[0];
    auto FW = param.filter[1];
    bool avaible = (param.src_type.category() == DTypeCategory::FLOAT ||
                    param.src_type.category() == DTypeCategory::QUANTIZED) &&
                   param.format == Param::Format::NCHW && FH == FW &&
                   SH == SW && FH == 2 && SH == 2;
    return avaible;
}

void PoolingImpl::AlgoFilter2ModexStride2::exec(
        const PoolingKernParam& param) const {
    auto IH = param.isz[0], IW = param.isz[1];
    auto OH = param.osz[0], OW = param.osz[1];
    auto N = param.n, C = param.ic;
    auto PH = param.padding[0];
    auto PW = param.padding[1];
    void* src_ptr = param.src_ptr;
    void* dst_ptr = param.dst_ptr;

#define DISPATCH_FUNC(Pooler, mode, midout_type_id) \
    MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(1), \
                 midout_iv(midout_type_id), Pooler::MIDOUT_CASE_NUM) { \
        auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \
                    src_dtype = param.src_type](size_t index, size_t) { \
            size_t n = index / C; \
            size_t c = index % C; \
            do_pooling_2x2<Pooler MEGDNN_COMMA mode>( \
                    static_cast<const typename Pooler::ctype*>(src_ptr) + \
                            n * C * IH * IW + c * IH * IW, \
                    static_cast<typename Pooler::ctype*>(dst_ptr) + \
                            n * C * OH * OW + c * OH * OW, \
                    src_dtype, IH, IW, OH, OW, PH, PW); \
        }; \
        MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
                static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \
                run); \
    } \
    MIDOUT_END()

#define DISPATCH_MODE(dtype, ctype, comp_type, midout_type_id) \
    switch (param.mode) { \
        case Mode::MAX: { \
            using _Pooler = MaxPooler<4, dtype, ctype, comp_type>; \
            DISPATCH_FUNC(_Pooler, Mode::MAX, midout_type_id); \
            break; \
        } \
        case Mode::AVERAGE: { \
            using _Pooler = MeanInPooler<4, dtype, ctype, comp_type>; \
            DISPATCH_FUNC(_Pooler, Mode::AVERAGE, midout_type_id); \
            break; \
        } \
        default: \
            megdnn_assert(0, "unsupport pooling mode"); \
            break; \
    }

    if (param.src_type == dtype::Float32{}) {
        DISPATCH_MODE(dt_float32, float, float, 0);
    } else if (param.src_type.enumv() == DTypeEnum::QuantizedS8) {
        DISPATCH_MODE(dt_qint8, int8_t, float, 1);
    } else if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) {
        DISPATCH_MODE(dt_quint8, uint8_t, float, 2);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    } else if (param.src_type == dtype::Float16{}) {
        DISPATCH_MODE(dt_float16, __fp16, __fp16, 3);
#endif
    }
#undef DISPATCH_FUNC
#undef DISPATCH_PAD
#undef DISPATCH_MODE
}
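
//! 3x3 max pooling with stride 2. Each worker thread gets its own slice of
//! the get_bundle() workspace, offset by thread_id * total_size_in_bytes().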
bool PoolingImpl::AlgoFilter3MaxStride2::usable(
        const PoolingKernSizeParam& param) const {
    bool avaible = (param.src_type.category() == DTypeCategory::FLOAT ||
                    param.src_type.category() == DTypeCategory::QUANTIZED) &&
                   param.format == Param::Format::NCHW &&
                   param.mode == Mode::MAX && param.filter[0] == 3 &&
                   param.filter[1] == 3 && param.stride[0] == 2 &&
                   param.stride[1] == 2 && param.isz[0] >= 2 &&
                   param.isz[1] >= 2;
    return avaible;
}

void PoolingImpl::AlgoFilter3MaxStride2::exec(
        const PoolingKernParam& param) const {
    auto IH = param.isz[0], IW = param.isz[1];
    auto OH = param.osz[0], OW = param.osz[1];
    auto N = param.n, C = param.ic;
    auto PH = param.padding[0];
    auto PW = param.padding[1];
    void* src_ptr = param.src_ptr;
    void* dst_ptr = param.dst_ptr;

#define DISPATCH_FUNC(type, func, midout_type_id) \
    MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), \
                 midout_iv(midout_type_id)) { \
        WorkspaceBundle wbundle = get_bundle(param); \
        auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \
                    wbundle = wbundle, \
                    workspace_ptr = param.workspace<dt_byte>()]( \
                           size_t index, size_t thread_id) { \
            auto ws = wbundle; \
            ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \
            size_t n = index / C; \
            size_t c = index % C; \
            do_max_pooling_3x3_s2x2_##func##_NEON( \
                    static_cast<const type*>(src_ptr) + n * C * IH * IW + \
                            c * IH * IW, \
                    static_cast<type*>(dst_ptr) + n * C * OH * OW + \
                            c * OH * OW, \
                    IH, IW, OH, OW, PH, PW, ws); \
        }; \
        MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
                static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \
                run); \
    } \
    MIDOUT_END();

    if (param.src_type == dtype::Float32{}) {
        DISPATCH_FUNC(float, float, 0);
    } else if (param.src_type.enumv() == DTypeEnum::QuantizedS8) {
        DISPATCH_FUNC(int8_t, int8, 1);
    } else if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) {
        DISPATCH_FUNC(uint8_t, uint8, 2);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    } else if (param.src_type == dtype::Float16{}) {
        DISPATCH_FUNC(__fp16, float16, 3);
#endif
    }
#undef DISPATCH_FUNC
}
bool PoolingImpl::AlgoFilter3AverageStride2::usable(
        const PoolingKernSizeParam& param) const {
    bool avaible = (param.src_type.category() == DTypeCategory::FLOAT) &&
                   param.format == Param::Format::NCHW &&
                   param.mode == Mode::AVERAGE && param.filter[0] == 3 &&
                   param.filter[1] == 3 && param.stride[0] == 2 &&
                   param.stride[1] == 2 && param.isz[0] >= 2 &&
                   param.isz[1] >= 2;
    return avaible;
}

void PoolingImpl::AlgoFilter3AverageStride2::exec(
        const PoolingKernParam& param) const {
    auto IH = param.isz[0], IW = param.isz[1];
    auto OH = param.osz[0], OW = param.osz[1];
    auto N = param.n, C = param.ic;
    auto PH = param.padding[0];
    auto PW = param.padding[1];
    void* src_ptr = param.src_ptr;
    void* dst_ptr = param.dst_ptr;

#define DISPATCH_FUNC(type, MEGDNN_SIMD_WIDTH, midout_type_id) \
    MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(3), \
                 midout_iv(midout_type_id)) { \
        WorkspaceBundle wbundle = get_bundle(param); \
        auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \
                    wbundle = wbundle, \
                    workspace_ptr = param.workspace<dt_byte>()]( \
                           size_t index, size_t thread_id) { \
            auto ws = wbundle; \
            ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \
            size_t n = index / C; \
            size_t c = index % C; \
            do_average_pooling_3x3_s2x2_NEON( \
                    static_cast<const type*>(src_ptr) + n * C * IH * IW + \
                            c * IH * IW, \
                    static_cast<type*>(dst_ptr) + n * C * OH * OW + \
                            c * OH * OW, \
                    IH, IW, OH, OW, PH, PW, ws, MEGDNN_SIMD_WIDTH); \
        }; \
        MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
                static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \
                run); \
    } \
    MIDOUT_END();

    if (param.src_type == dtype::Float32{}) {
        DISPATCH_FUNC(dt_float32, 4, 0);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    } else if (param.src_type == dtype::Float16{}) {
        DISPATCH_FUNC(__fp16, 8, 1);
#endif
    }
#undef DISPATCH_FUNC
}
bool PoolingImpl::AlgoFilter4MaxStride2::usable(
        const PoolingKernSizeParam& param) const {
    auto SH = param.stride[0];
    auto SW = param.stride[1];
    auto FH = param.filter[0];
    auto FW = param.filter[1];
    auto OH = param.osz[0], OW = param.osz[1];
    bool avaible = (param.src_type.category() == DTypeCategory::FLOAT ||
                    param.src_type.category() == DTypeCategory::QUANTIZED) &&
                   param.format == Param::Format::NCHW &&
                   param.mode == Mode::MAX && FH == 4 && FW == 4 && SH == 2 &&
                   SW == 2 && OH >= 2 && OW >= 2;
    return avaible;
}

void PoolingImpl::AlgoFilter4MaxStride2::exec(
        const PoolingKernParam& param) const {
    auto IH = param.isz[0], IW = param.isz[1];
    auto OH = param.osz[0], OW = param.osz[1];
    auto N = param.n, C = param.ic;
    auto PH = param.padding[0];
    auto PW = param.padding[1];
    void* src_ptr = param.src_ptr;
    void* dst_ptr = param.dst_ptr;

#define DISPATCH_FUNC(type, func, midout_type_id) \
    MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(4), \
                 midout_iv(midout_type_id)) { \
        auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \
                    src_dtype = param.src_type](size_t index, size_t) { \
            size_t n = index / C; \
            size_t c = index % C; \
            do_max_pooling_w4x4_s2x2_##func##_NEON( \
                    static_cast<const type*>(src_ptr) + n * C * IH * IW + \
                            c * IH * IW, \
                    static_cast<type*>(dst_ptr) + n * C * OH * OW + \
                            c * OH * OW, \
                    src_dtype, IH, IW, OH, OW, PH, PW); \
        }; \
        MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
                static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \
                run); \
    } \
    MIDOUT_END();

    if (param.src_type == dtype::Float32{}) {
        DISPATCH_FUNC(float, float, 0);
    } else if (param.src_type.enumv() == DTypeEnum::QuantizedS8) {
        DISPATCH_FUNC(int8_t, int8, 1);
    } else if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) {
        DISPATCH_FUNC(uint8_t, uint8, 2);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    } else if (param.src_type == dtype::Float16{}) {
        DISPATCH_FUNC(__fp16, float16, 3);
#endif
    }
#undef DISPATCH_FUNC
}
bool PoolingImpl::AlgoFilter5MaxStride2::usable(
        const PoolingKernSizeParam& param) const {
    auto SH = param.stride[0];
    auto SW = param.stride[1];
    auto FH = param.filter[0];
    auto FW = param.filter[1];
    auto OH = param.osz[0], OW = param.osz[1];
    bool avaible = (param.src_type.category() == DTypeCategory::FLOAT ||
                    param.src_type.category() == DTypeCategory::QUANTIZED) &&
                   param.format == Param::Format::NCHW &&
                   param.mode == Mode::MAX && FH == 5 && FW == 5 && SH == 2 &&
                   SW == 2 && OH >= 2 && OW >= 2;
    return avaible;
}

void PoolingImpl::AlgoFilter5MaxStride2::exec(
        const PoolingKernParam& param) const {
    auto IH = param.isz[0], IW = param.isz[1];
    auto OH = param.osz[0], OW = param.osz[1];
    auto N = param.n, C = param.ic;
    auto PH = param.padding[0];
    auto PW = param.padding[1];
    void* src_ptr = param.src_ptr;
    void* dst_ptr = param.dst_ptr;

#define DISPATCH_FUNC(dtype, type, midout_type_id, MEGDNN_SIMD_WIDTH) \
    MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(5), \
                 midout_iv(midout_type_id)) { \
        WorkspaceBundle wbundle = get_bundle(param); \
        auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \
                    wbundle = wbundle, \
                    workspace_ptr = param.workspace<dt_byte>()]( \
                           size_t index, size_t thread_id) { \
            auto ws = wbundle; \
            ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \
            size_t n = index / C; \
            size_t c = index % C; \
            do_max_pooling_w5x5_s2x2_NEON<dtype>( \
                    static_cast<const type*>(src_ptr) + n * C * IH * IW + \
                            c * IH * IW, \
                    static_cast<type*>(dst_ptr) + n * C * OH * OW + \
                            c * OH * OW, \
                    IH, IW, OH, OW, PH, PW, ws, MEGDNN_SIMD_WIDTH); \
        }; \
        MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
                static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \
                run); \
    } \
    MIDOUT_END();

    if (param.src_type == dtype::Float32{}) {
        DISPATCH_FUNC(dt_float32, float, 0, 4);
    } else if (param.src_type.enumv() == DTypeEnum::QuantizedS8) {
        DISPATCH_FUNC(dt_int8, int8_t, 1, 16);
    } else if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) {
        DISPATCH_FUNC(dt_uint8, uint8_t, 2, 16);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    } else if (param.src_type == dtype::Float16{}) {
        DISPATCH_FUNC(dt_float16, __fp16, 3, 8);
#endif
    }
#undef DISPATCH_FUNC
}
bool PoolingImpl::AlgoInt8Filter2MaxStride2::usable(
        const PoolingKernSizeParam& param) const {
    auto SH = param.stride[0];
    auto SW = param.stride[1];
    auto FH = param.filter[0];
    auto FW = param.filter[1];
    auto PH = param.padding[0];
    auto PW = param.padding[1];
    bool avaible = param.src_type == dtype::Int8() &&
                   param.format == Param::Format::NCHW &&
                   param.mode == Mode::MAX && SH == 2 && SW == 2 && PH == 0 &&
                   PW == 0 && FH == 2 && FW == 2;
    return avaible;
}

void PoolingImpl::AlgoInt8Filter2MaxStride2::exec(
        const PoolingKernParam& param) const {
    auto IH = param.isz[0], IW = param.isz[1];
    auto OH = param.osz[0], OW = param.osz[1];
    auto N = param.n, C = param.ic;
    auto src_ptr = param.src<dt_int8>();
    auto dst_ptr = param.dst<dt_int8>();

    MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(6)) {
        auto run = [C, IH, IW, OH, OW, src_ptr, dst_ptr](size_t index, size_t) {
            size_t n = index / C;
            size_t c = index % C;
            pooling_max_w2x2_s2x2(src_ptr + n * C * IH * IW + c * IH * IW,
                                  dst_ptr + n * C * OH * OW + c * OH * OW, 1, 1,
                                  IH, IW, OH, OW);
        };
        MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN(
                static_cast<::megdnn::naive::HandleImpl*>(param.handle), N * C,
                run);
    }
    MIDOUT_END();
}
bool PoolingImpl::AlgoInt8Filter3MaxStride2::usable(
        const PoolingKernSizeParam& param) const {
    auto SH = param.stride[0];
    auto SW = param.stride[1];
    auto FH = param.filter[0];
    auto FW = param.filter[1];
    auto IH = param.isz[0];
    auto IW = param.isz[1];
    bool avaible = param.src_type == dtype::Int8() &&
                   param.format == Param::Format::NCHW &&
                   param.mode == Mode::MAX && FH == 3 && FW == 3 && SH == 2 &&
                   SW == 2 && IH >= 2 && IW >= 2;
    return avaible;
}

void PoolingImpl::AlgoInt8Filter3MaxStride2::exec(
        const PoolingKernParam& param) const {
    auto IH = param.isz[0], IW = param.isz[1];
    auto OH = param.osz[0], OW = param.osz[1];
    auto N = param.n, C = param.ic;
    auto PH = param.padding[0];
    auto PW = param.padding[1];
    auto src_ptr = param.src<dt_int8>();
    auto dst_ptr = param.dst<dt_int8>();

    MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(7)) {
        WorkspaceBundle wbundle = get_bundle(param);
        auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr,
                    wbundle = wbundle,
                    workspace_ptr = param.workspace<dt_byte>()](
                           size_t index, size_t thread_id) {
            auto ws = wbundle;
            ws.set(workspace_ptr + thread_id * ws.total_size_in_bytes());
            size_t n = index / C;
            size_t c = index % C;
            do_max_pooling_3x3_s2x2_int8_NEON(
                    src_ptr + n * C * IH * IW + c * IH * IW,
                    dst_ptr + n * C * OH * OW + c * OH * OW, IH, IW, OH, OW, PH,
                    PW, ws);
        };
        MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN(
                static_cast<::megdnn::naive::HandleImpl*>(param.handle), N * C,
                run);
    }
    MIDOUT_END();
}
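
//! The NCHW44 algorithms below all follow the same pattern: dispatch on
//! stride (1 or 2) and mode (max/avg), hand every worker thread its own slice
//! of the get_bundle_nchw44() workspace, and call the matching
//! do_<mode>_pooling_<filter>_stride<stride>_int8_nchw44_NEON kernel.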
bool PoolingImpl::AlgoFilter3ModexStridexNCHW44::usable(
        const PoolingKernSizeParam& param) const {
    auto SH = param.stride[0];
    auto SW = param.stride[1];
    auto FH = param.filter[0];
    auto FW = param.filter[1];
    bool avaible = (param.src_type.enumv() == DTypeEnum::QuantizedS8 ||
                    param.src_type.enumv() == DTypeEnum::Int8) &&
                   param.format == Param::Format::NCHW44 &&
                   (param.mode == Mode::MAX || param.mode == Mode::AVERAGE) &&
                   FH == 3 && FW == 3 && SW == SH && (SH == 1 || SW == 2);
    //! Int8 does not support average pooling, because its rounding mode
    //! differs from that of qint8.
    avaible &= !(param.src_type.enumv() == DTypeEnum::Int8 &&
                 param.mode == Mode::AVERAGE);
    return avaible;
}

void PoolingImpl::AlgoFilter3ModexStridexNCHW44::exec(
        const PoolingKernParam& param) const {
    auto IH = param.isz[0], IW = param.isz[1];
    auto OH = param.osz[0], OW = param.osz[1];
    auto N = param.n, C = param.ic;
    auto PH = param.padding[0];
    auto PW = param.padding[1];
    auto SW = param.stride[0];
    void* src_ptr = param.src_ptr;
    void* dst_ptr = param.dst_ptr;

#define DISPATCH_FUNC(type, func, i, mode) \
    MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(8), \
                 midout_iv(#type #i##_hash)) { \
        WorkspaceBundle wbundle = get_bundle_nchw44(param); \
        auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \
                    wbundle = wbundle, \
                    workspace_ptr = param.workspace<dt_byte>()]( \
                           size_t index, size_t thread_id) { \
            auto ws = wbundle; \
            ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \
            size_t n = index / C; \
            size_t c = index % C; \
            do_##mode##_pooling_3x3_stride##i##_##func##_nchw44_NEON( \
                    static_cast<const type*>(src_ptr) + n * C * IH * IW * 4 + \
                            c * IH * IW * 4, \
                    static_cast<type*>(dst_ptr) + n * C * OH * OW * 4 + \
                            c * OH * OW * 4, \
                    IH, IW, OH, OW, PH, PW, ws); \
        }; \
        MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
                static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \
                run); \
    } \
    MIDOUT_END();

#define DISPATCH_MODE(type, func, stride) \
    switch (param.mode) { \
        case Mode::MAX: { \
            DISPATCH_FUNC(type, func, stride, max); \
            break; \
        } \
        case Mode::AVERAGE: { \
            DISPATCH_FUNC(type, func, stride, avg); \
            break; \
        } \
        default: \
            megdnn_throw(ssprintf("Unsupport pooling mode %d", \
                                  static_cast<int>(param.mode)) \
                                 .c_str()); \
    }

#define DISPATCH_STRIDE(type, func) \
    switch (SW) { \
        case 1: { \
            DISPATCH_MODE(type, func, 1); \
            break; \
        } \
        case 2: { \
            DISPATCH_MODE(type, func, 2); \
            break; \
        } \
        default: \
            megdnn_throw(ssprintf("Unsupport stride size %d", SW).c_str()); \
    }

    DISPATCH_STRIDE(int8_t, int8);

#undef DISPATCH_STRIDE
#undef DISPATCH_MODE
#undef DISPATCH_FUNC
}
bool PoolingImpl::AlgoFilter2ModexStridexNCHW44::usable(
        const PoolingKernSizeParam& param) const {
    auto SH = param.stride[0];
    auto SW = param.stride[1];
    auto FH = param.filter[0];
    auto FW = param.filter[1];
    bool avaible = (param.src_type.enumv() == DTypeEnum::QuantizedS8 ||
                    param.src_type.enumv() == DTypeEnum::Int8) &&
                   param.format == Param::Format::NCHW44 &&
                   (param.mode == Mode::MAX || param.mode == Mode::AVERAGE) &&
                   FH == 2 && FW == 2 && SH == SW && (SW == 1 || SW == 2);
    //! Int8 does not support average pooling, because its rounding mode
    //! differs from that of qint8.
    avaible &= !(param.src_type.enumv() == DTypeEnum::Int8 &&
                 param.mode == Mode::AVERAGE);
    return avaible;
}

void PoolingImpl::AlgoFilter2ModexStridexNCHW44::exec(
        const PoolingKernParam& param) const {
    auto IH = param.isz[0], IW = param.isz[1];
    auto OH = param.osz[0], OW = param.osz[1];
    auto N = param.n, C = param.ic;
    auto PH = param.padding[0];
    auto PW = param.padding[1];
    auto SW = param.stride[0];
    void* src_ptr = param.src_ptr;
    void* dst_ptr = param.dst_ptr;

#define DISPATCH_FUNC(type, func, i, mode) \
    MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(9), \
                 midout_iv(#func #i##_hash)) { \
        WorkspaceBundle wbundle = get_bundle_nchw44(param); \
        auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \
                    wbundle = wbundle, \
                    workspace_ptr = param.workspace<dt_byte>()]( \
                           size_t index, size_t thread_id) { \
            auto ws = wbundle; \
            ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \
            size_t n = index / C; \
            size_t c = index % C; \
            do_##mode##_pooling_2x2_stride##i##_##func##_nchw44_NEON( \
                    static_cast<const type*>(src_ptr) + n * C * IH * IW * 4 + \
                            c * IH * IW * 4, \
                    static_cast<type*>(dst_ptr) + n * C * OH * OW * 4 + \
                            c * OH * OW * 4, \
                    IH, IW, OH, OW, PH, PW, ws); \
        }; \
        MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
                static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \
                run); \
    } \
    MIDOUT_END();

#define DISPATCH_MODE(type, func, stride) \
    switch (param.mode) { \
        case Mode::MAX: { \
            DISPATCH_FUNC(type, func, stride, max); \
            break; \
        } \
        case Mode::AVERAGE: { \
            DISPATCH_FUNC(type, func, stride, avg); \
            break; \
        } \
        default: \
            megdnn_throw(ssprintf("Unsupport pooling mode %d", \
                                  static_cast<int>(param.mode)) \
                                 .c_str()); \
    }

#define DISPATCH_STRIDE(type, func) \
    switch (SW) { \
        case 1: { \
            DISPATCH_MODE(type, func, 1); \
            break; \
        } \
        case 2: { \
            DISPATCH_MODE(type, func, 2); \
            break; \
        } \
        default: \
            megdnn_throw(ssprintf("Unsupport stride size %d", SW).c_str()); \
    }

    DISPATCH_STRIDE(int8_t, int8);

#undef DISPATCH_STRIDE
#undef DISPATCH_MODE
#undef DISPATCH_FUNC
}
bool PoolingImpl::AlgoFilter4ModexStridexNCHW44::usable(
        const PoolingKernSizeParam& param) const {
    auto SH = param.stride[0];
    auto SW = param.stride[1];
    auto FH = param.filter[0];
    auto FW = param.filter[1];
    bool avaible = (param.src_type.enumv() == DTypeEnum::QuantizedS8 ||
                    param.src_type.enumv() == DTypeEnum::Int8) &&
                   param.format == Param::Format::NCHW44 &&
                   (param.mode == Mode::MAX || param.mode == Mode::AVERAGE) &&
                   FH == 4 && FW == 4 && SH == SW && (SW == 1 || SW == 2);
    //! Int8 does not support average pooling, because its rounding mode
    //! differs from that of qint8.
    avaible &= !(param.src_type.enumv() == DTypeEnum::Int8 &&
                 param.mode == Mode::AVERAGE);
    return avaible;
}

void PoolingImpl::AlgoFilter4ModexStridexNCHW44::exec(
        const PoolingKernParam& param) const {
    auto IH = param.isz[0], IW = param.isz[1];
    auto OH = param.osz[0], OW = param.osz[1];
    auto N = param.n, C = param.ic;
    auto PH = param.padding[0];
    auto PW = param.padding[1];
    auto SW = param.stride[0];
    void* src_ptr = param.src_ptr;
    void* dst_ptr = param.dst_ptr;

#define DISPATCH_FUNC(type, func, i, mode) \
    MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(10), \
                 midout_iv(#func #i##_hash)) { \
        WorkspaceBundle wbundle = get_bundle_nchw44(param); \
        auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \
                    wbundle = wbundle, \
                    workspace_ptr = param.workspace<dt_byte>()]( \
                           size_t index, size_t thread_id) { \
            auto ws = wbundle; \
            ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \
            size_t n = index / C; \
            size_t c = index % C; \
            do_##mode##_pooling_4x4_stride##i##_##func##_nchw44_NEON( \
                    static_cast<const type*>(src_ptr) + n * C * IH * IW * 4 + \
                            c * IH * IW * 4, \
                    static_cast<type*>(dst_ptr) + n * C * OH * OW * 4 + \
                            c * OH * OW * 4, \
                    IH, IW, OH, OW, PH, PW, ws); \
        }; \
        MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
                static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \
                run); \
    } \
    MIDOUT_END();

#define DISPATCH_MODE(type, func, stride) \
    switch (param.mode) { \
        case Mode::MAX: { \
            DISPATCH_FUNC(type, func, stride, max); \
            break; \
        } \
        case Mode::AVERAGE: { \
            DISPATCH_FUNC(type, func, stride, avg); \
            break; \
        } \
        default: \
            megdnn_throw(ssprintf("Unsupport pooling mode %d", \
                                  static_cast<int>(param.mode)) \
                                 .c_str()); \
    }

#define DISPATCH_STRIDE(type, func) \
    switch (SW) { \
        case 1: { \
            DISPATCH_MODE(type, func, 1); \
            break; \
        } \
        case 2: { \
            DISPATCH_MODE(type, func, 2); \
            break; \
        } \
        default: \
            megdnn_throw(ssprintf("Unsupport stride size %d", SW).c_str()); \
    }

    DISPATCH_STRIDE(int8_t, int8);

#undef DISPATCH_STRIDE
#undef DISPATCH_MODE
#undef DISPATCH_FUNC
}
bool PoolingImpl::AlgoFilter5ModexStridexNCHW44::usable(
        const PoolingKernSizeParam& param) const {
    auto SH = param.stride[0];
    auto SW = param.stride[1];
    auto FH = param.filter[0];
    auto FW = param.filter[1];
    bool avaible = (param.src_type.enumv() == DTypeEnum::QuantizedS8 ||
                    param.src_type.enumv() == DTypeEnum::Int8) &&
                   param.format == Param::Format::NCHW44 &&
                   (param.mode == Mode::MAX || param.mode == Mode::AVERAGE) &&
                   FH == 5 && FW == 5 && SH == SW && (SW == 1 || SW == 2);
    //! Int8 does not support average pooling, because its rounding mode
    //! differs from that of qint8.
    avaible &= !(param.src_type.enumv() == DTypeEnum::Int8 &&
                 param.mode == Mode::AVERAGE);
    return avaible;
}

void PoolingImpl::AlgoFilter5ModexStridexNCHW44::exec(
        const PoolingKernParam& param) const {
    auto IH = param.isz[0], IW = param.isz[1];
    auto OH = param.osz[0], OW = param.osz[1];
    auto N = param.n, C = param.ic;
    auto PH = param.padding[0];
    auto PW = param.padding[1];
    auto SW = param.stride[0];
    void* src_ptr = param.src_ptr;
    void* dst_ptr = param.dst_ptr;

#define DISPATCH_FUNC(type, func, i, mode) \
    MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(11), \
                 midout_iv(#func #i##_hash)) { \
        WorkspaceBundle wbundle = get_bundle_nchw44(param); \
        auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \
                    wbundle = wbundle, \
                    workspace_ptr = param.workspace<dt_byte>()]( \
                           size_t index, size_t thread_id) { \
            auto ws = wbundle; \
            ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \
            size_t n = index / C; \
            size_t c = index % C; \
            do_##mode##_pooling_5x5_stride##i##_##func##_nchw44_NEON( \
                    static_cast<const type*>(src_ptr) + n * C * IH * IW * 4 + \
                            c * IH * IW * 4, \
                    static_cast<type*>(dst_ptr) + n * C * OH * OW * 4 + \
                            c * OH * OW * 4, \
                    IH, IW, OH, OW, PH, PW, ws); \
        }; \
        MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
                static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \
                run); \
    } \
    MIDOUT_END();

#define DISPATCH_MODE(type, func, stride) \
    switch (param.mode) { \
        case Mode::MAX: { \
            DISPATCH_FUNC(type, func, stride, max); \
            break; \
        } \
        case Mode::AVERAGE: { \
            DISPATCH_FUNC(type, func, stride, avg); \
            break; \
        } \
        default: \
            megdnn_throw(ssprintf("Unsupport pooling mode %d", \
                                  static_cast<int>(param.mode)) \
                                 .c_str()); \
    }

#define DISPATCH_STRIDE(type, func) \
    switch (SW) { \
        case 1: { \
            DISPATCH_MODE(type, func, 1); \
            break; \
        } \
        case 2: { \
            DISPATCH_MODE(type, func, 2); \
            break; \
        } \
        default: \
            megdnn_throw(ssprintf("Unsupport stride size %d", SW).c_str()); \
    }

    DISPATCH_STRIDE(int8_t, int8);

#undef DISPATCH_STRIDE
#undef DISPATCH_MODE
#undef DISPATCH_FUNC
}
}  // namespace arm_common
}  // namespace megdnn

// vim: syntax=cpp.doxygen
