You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

op_common.h 60 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447
  1. #pragma once
  2. #include "megdnn/dtype.h"
  3. #include "src/fallback/elemwise_helper/kimpl/pow.h"
  4. namespace megdnn {
  5. namespace elemwise {
/*!
 * \brief broadcast type
 * BCAST_x[0]x[1]...: x[i] == !stride[i]
 * A digit 1 in the name marks a dimension with stride 0 (broadcast); the
 * first token names operand 0, the second operand 1 (and a third token the
 * extra operand of ternary ops). See the OpCaller specializations below for
 * the exact iteration pattern of each tag.
 */
enum BcastType {
    VEC,                        //! unary: contiguous input
    VEC_VEC,                    //! binary: both inputs contiguous
    VEC_BCAST101,               //! input1 broadcast: one value per channel
    VEC_BCASTX0X,               //! input1 broadcast: one row per batch
    VEC_BCAST111C,              //! input1 broadcast: one trailing-dim row shared by all
    VEC_BCAST101xX,             //! input1 broadcast: channel-blocked (nchw44/nchw88)
    VEC_SCALAR,                 //! input1 is a single scalar
    SCALAR_VEC,                 //! input0 is a single scalar
    BCAST101_VEC,               //! mirror of VEC_BCAST101 (input0 broadcast)
    BCASTX0X_VEC,               //! mirror of VEC_BCASTX0X
    BCAST111C_VEC,              //! mirror of VEC_BCAST111C
    BCAST101xX_VEC,             //! mirror of VEC_BCAST101xX
    //! ternary patterns: the three tokens describe the three operands
    VEC_VEC_VEC,
    VEC_VEC_SCALAR,
    BCAST101_VEC_BCAST101,
    BCAST111C_VEC_BCAST111C,
    BCAST101xX_VEC_BCAST101xX,
    VEC_BCAST101_VEC,
    VEC_BCAST111C_VEC,
    VEC_BCAST101xX_VEC,
    VEC_SCALAR_VEC,
    VEC_SCALAR_SCALAR,
    UNKNOWN_BCAST_TYPE          //! sentinel: no specialization matched
};
///////////////////////////////// ParamElemVistor v2///////////////////////////
//! Packs the vectors loaded at `src` and `src_1` into one paired (V2) SIMD
//! register; specialized per element type by the `cb` macro below.
template <typename ctype>
struct ParamElemVisitorV2;
//! visitor single elemwise, and dup to vector
template <typename ctype>
struct ParamElemVisitorDupV2;
template <typename ctype>
struct ParamElemVisitorBcast101x4V2;
//! \param _ctype megdnn element type the visitors are specialized for
//! \param _inner_ctype raw scalar type the source pointer is reinterpreted as
//! \param _simd_type single SIMD vector type; _simd_type_v2 its paired form
//! \param _fun_suffix suffix selecting the Gi* intrinsic family
#define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix, _simd_type_v2) \
    template <>                                                          \
    struct ParamElemVisitorV2<_ctype> {                                  \
        _simd_type_v2 operator()(const _ctype* src, const _ctype* src_1) const { \
            _simd_type_v2 ret;                                           \
            GiSetSubVector##_fun_suffix##V2(ret, 0, GiLoad##_fun_suffix(src));   \
            GiSetSubVector##_fun_suffix##V2(ret, 1, GiLoad##_fun_suffix(src_1)); \
            return ret;                                                  \
        }                                                                \
    };                                                                   \
    template <>                                                          \
    struct ParamElemVisitorDupV2<_ctype> {                               \
        _simd_type_v2 operator()(const _ctype* src) const {              \
            _simd_type_v2 ret;                                           \
            _simd_type tmp = GiBroadcast##_fun_suffix(                   \
                    *reinterpret_cast<const _inner_ctype*>(src));        \
            GiSetSubVector##_fun_suffix##V2(ret, 0, tmp);                \
            GiSetSubVector##_fun_suffix##V2(ret, 1, tmp);                \
            return ret;                                                  \
        }                                                                \
    }
cb(dt_qint32, int32_t, GI_INT32_t, Int32, GI_INT32_V2_t);
cb(dt_qint8, int8_t, GI_INT8_t, Int8, GI_INT8_V2_t);
cb(dt_float32, float, GI_FLOAT32_t, Float32, GI_FLOAT32_V2_t);
cb(dt_int32, int32_t, GI_INT32_t, Int32, GI_INT32_V2_t);
cb(dt_int8, int8_t, GI_INT8_t, Int8, GI_INT8_V2_t);
#undef cb
//! NOTE(review): re-declares ParamElemVisitorBcast101x4V2, which is already
//! declared above — harmless but redundant.
template <typename ctype>
struct ParamElemVisitorBcast101x4V2;
//! int8 family: broadcast one 32-bit group (4 packed int8 channel values)
//! across the vector via GiBroadcast##rel_suffix, then reinterpret back to
//! the int8 vector type and duplicate into both halves of the V2 register.
#define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix, rel_suffix, _simd_type_v2) \
    template <>                                                                      \
    struct ParamElemVisitorBcast101x4V2<_ctype> {                                    \
        _simd_type_v2 operator()(const _ctype* src) const {                          \
            _simd_type_v2 ret;                                                       \
            _simd_type tmp =                                                         \
                    GiReinter##rel_suffix##To##_fun_suffix(GiBroadcast##rel_suffix(  \
                            *reinterpret_cast<const _inner_ctype*>(src)));           \
            GiSetSubVector##_fun_suffix##V2(ret, 0, tmp);                            \
            GiSetSubVector##_fun_suffix##V2(ret, 1, tmp);                            \
            return ret;                                                              \
        }                                                                            \
    }
cb(dt_qint8, int32_t, GI_INT8_t, Int8, Int32, GI_INT8_V2_t);
cb(dt_int8, int32_t, GI_INT8_t, Int8, Int32, GI_INT8_V2_t);
#undef cb
//! 32-bit element types: the 4-element channel block is read with a plain
//! full-vector load (assumes the block spans one SIMD vector — TODO confirm
//! against Op::SIMD_WIDTH) and duplicated into both halves of the V2 register.
#define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix, _simd_type_v2) \
    template <>                                                          \
    struct ParamElemVisitorBcast101x4V2<_ctype> {                        \
        _simd_type_v2 operator()(const _ctype* src) const {              \
            _simd_type_v2 ret;                                           \
            _simd_type tmp = GiLoad##_fun_suffix(src);                   \
            GiSetSubVector##_fun_suffix##V2(ret, 0, tmp);                \
            GiSetSubVector##_fun_suffix##V2(ret, 1, tmp);                \
            return ret;                                                  \
        }                                                                \
    }
cb(dt_qint32, int32_t, GI_INT32_t, Int32, GI_INT32_V2_t);
cb(dt_float32, float, GI_FLOAT32_t, Float32, GI_FLOAT32_V2_t);
cb(dt_int32, int32_t, GI_INT32_t, Int32, GI_INT32_V2_t);
#undef cb
///////////////////////////////// OpCaller /////////////////////////////
//! Primary template; specialized per broadcast pattern below.
template <typename Op, BcastType bcast_type>
struct OpCallerUnary;
  106. template <typename Op>
  107. struct OpCallerUnary<Op, VEC> {
  108. static void run(
  109. const typename Op::src_ctype* src, typename Op::dst_ctype* dst,
  110. DType src_dtype, DType dst_dtype, size_t nr_elems) {
  111. Op op(src_dtype, dst_dtype);
  112. ParamElemVisitorV2<typename Op::src_ctype> vis;
  113. size_t i = 0;
  114. for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) {
  115. op(vis(src, src + Op::SIMD_WIDTH), dst);
  116. src += Op::SIMD_WIDTH * 2;
  117. dst += Op::SIMD_WIDTH * 2;
  118. }
  119. #if MEGDNN_FIX_AARCH32_BUG
  120. // FIXME: as llvm may cause cannot select error if enable vectorize
  121. #pragma clang loop vectorize(disable)
  122. #endif
  123. for (; i < nr_elems; i++) {
  124. op(*src, dst);
  125. src++;
  126. dst++;
  127. }
  128. }
  129. };
  130. template <typename Op, BcastType bcast_type, typename enbale = void>
  131. struct OpCallerBinary;
  132. ///////////////////////// Pow ////////////////////////////////
  133. template <typename ctype>
  134. struct OpCallerBinary<fallback::PowOp<ctype, ctype>, VEC_VEC> {
  135. using Op = fallback::PowOp<ctype, ctype>;
  136. static void run(
  137. const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
  138. typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
  139. DType dst_dtype, size_t nr_elems) {
  140. Op op(src0_dtype, src1_dtype, dst_dtype);
  141. size_t i = 0;
  142. #if MEGDNN_FIX_AARCH32_BUG
  143. // FIXME: as llvm may cause cannot select error if enable vectorize
  144. #pragma clang loop vectorize(disable)
  145. #endif
  146. for (; i < nr_elems; i++) {
  147. op(*src0, *src1, dst);
  148. src0++;
  149. src1++;
  150. dst++;
  151. }
  152. }
  153. };
  154. template <typename ctype>
  155. struct OpCallerBinary<fallback::PowOp<ctype, ctype>, VEC_SCALAR> {
  156. using Op = fallback::PowOp<ctype, ctype>;
  157. static void run(
  158. const typename Op::src_ctype* src0, const typename Op::src_ctype src1,
  159. typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
  160. DType dst_dtype, size_t nr_elems) {
  161. Op op(src0_dtype, src1_dtype, dst_dtype);
  162. size_t i = 0;
  163. #if MEGDNN_FIX_AARCH32_BUG
  164. // FIXME: as llvm may cause cannot select error if enable vectorize
  165. #pragma clang loop vectorize(disable)
  166. #endif
  167. for (; i < nr_elems; i++) {
  168. op(*src0, src1, dst);
  169. src0++;
  170. dst++;
  171. }
  172. }
  173. };
  174. template <typename ctype>
  175. struct OpCallerBinary<fallback::PowOp<ctype, ctype>, VEC_BCAST101> {
  176. using Op = fallback::PowOp<ctype, ctype>;
  177. static void run(
  178. const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
  179. typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
  180. DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) {
  181. Op op(src0_dtype, src1_dtype, dst_dtype);
  182. for (size_t b = 0; b < batch; b++) {
  183. const typename Op::src_ctype* src1_ptr = src1;
  184. for (size_t c = 0; c < channel; c++) {
  185. size_t i = 0;
  186. #if MEGDNN_FIX_AARCH32_BUG
  187. // FIXME: as llvm may cause cannot select error if enable vectorize
  188. #pragma clang loop vectorize(disable)
  189. #endif
  190. for (; i < channel_stride; i++) {
  191. op(*src0, *src1_ptr, dst);
  192. src0++;
  193. dst++;
  194. }
  195. src1_ptr++;
  196. }
  197. }
  198. }
  199. };
  200. template <typename ctype>
  201. struct OpCallerBinary<fallback::PowOp<ctype, ctype>, VEC_BCASTX0X> {
  202. using Op = fallback::PowOp<ctype, ctype>;
  203. static void run(
  204. const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
  205. typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
  206. DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) {
  207. Op op(src0_dtype, src1_dtype, dst_dtype);
  208. for (size_t b = 0; b < batch; b++) {
  209. const typename Op::src_ctype* src1_ptr_base = src1 + b * channel_stride;
  210. for (size_t c = 0; c < channel; c++) {
  211. size_t i = 0;
  212. auto src1_ptr = src1_ptr_base;
  213. #if MEGDNN_FIX_AARCH32_BUG
  214. // FIXME: as llvm may cause cannot select error if enable vectorize
  215. #pragma clang loop vectorize(disable)
  216. #endif
  217. for (; i < channel_stride; i++) {
  218. op(*src0, *src1_ptr, dst);
  219. src0++;
  220. src1_ptr++;
  221. dst++;
  222. }
  223. }
  224. }
  225. }
  226. };
  227. template <typename ctype>
  228. struct OpCallerBinary<fallback::PowOp<ctype, ctype>, VEC_BCAST111C> {
  229. using Op = fallback::PowOp<ctype, ctype>;
  230. static void run(
  231. const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
  232. typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
  233. DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) {
  234. Op op(src0_dtype, src1_dtype, dst_dtype);
  235. for (size_t b = 0; b < batch; b++) {
  236. for (size_t c = 0; c < channel; c++) {
  237. size_t i = 0;
  238. const typename Op::src_ctype* src1_ptr = src1;
  239. #if MEGDNN_FIX_AARCH32_BUG
  240. // FIXME: as llvm may cause cannot select error if enable vectorize
  241. #pragma clang loop vectorize(disable)
  242. #endif
  243. for (; i < channel_stride; i++) {
  244. op(*src0, *src1_ptr, dst);
  245. src0++;
  246. src1_ptr++;
  247. dst++;
  248. }
  249. }
  250. }
  251. }
  252. };
  253. template <typename ctype>
  254. struct OpCallerBinary<fallback::PowOp<ctype, ctype>, BCAST111C_VEC> {
  255. using Op = fallback::PowOp<ctype, ctype>;
  256. static void run(
  257. const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
  258. typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
  259. DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) {
  260. Op op(src0_dtype, src1_dtype, dst_dtype);
  261. for (size_t b = 0; b < batch; b++) {
  262. for (size_t c = 0; c < channel; c++) {
  263. size_t i = 0;
  264. const typename Op::src_ctype* src0_ptr = src0;
  265. #if MEGDNN_FIX_AARCH32_BUG
  266. // FIXME: as llvm may cause cannot select error if enable vectorize
  267. #pragma clang loop vectorize(disable)
  268. #endif
  269. for (; i < channel_stride; i++) {
  270. op(*src0_ptr, *src1, dst);
  271. src0_ptr++;
  272. src1++;
  273. dst++;
  274. }
  275. }
  276. }
  277. }
  278. };
  279. template <typename ctype>
  280. struct OpCallerBinary<fallback::PowOp<ctype, ctype>, SCALAR_VEC> {
  281. using Op = fallback::PowOp<ctype, ctype>;
  282. static void run(
  283. const typename Op::src_ctype src0, const typename Op::src_ctype* src1,
  284. typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
  285. DType dst_dtype, size_t nr_elems) {
  286. Op op(src0_dtype, src1_dtype, dst_dtype);
  287. size_t i = 0;
  288. #if MEGDNN_FIX_AARCH32_BUG
  289. // FIXME: as llvm may cause cannot select error if enable vectorize
  290. #pragma clang loop vectorize(disable)
  291. #endif
  292. for (; i < nr_elems; i++) {
  293. op(src0, *src1, dst);
  294. src1++;
  295. dst++;
  296. }
  297. }
  298. };
  299. template <typename ctype>
  300. struct OpCallerBinary<fallback::PowOp<ctype, ctype>, BCAST101_VEC> {
  301. using Op = fallback::PowOp<ctype, ctype>;
  302. static void run(
  303. const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
  304. typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
  305. DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) {
  306. Op op(src0_dtype, src1_dtype, dst_dtype);
  307. for (size_t b = 0; b < batch; b++) {
  308. auto src0_ptr = src0;
  309. for (size_t c = 0; c < channel; c++) {
  310. size_t i = 0;
  311. #if MEGDNN_FIX_AARCH32_BUG
  312. // FIXME: as llvm may cause cannot select error if enable vectorize
  313. #pragma clang loop vectorize(disable)
  314. #endif
  315. for (; i < channel_stride; i++) {
  316. op(*src0_ptr, *src1, dst);
  317. src1++;
  318. dst++;
  319. }
  320. src0_ptr++;
  321. }
  322. }
  323. }
  324. };
  325. template <typename ctype>
  326. struct OpCallerBinary<fallback::PowOp<ctype, ctype>, BCASTX0X_VEC> {
  327. using Op = fallback::PowOp<ctype, ctype>;
  328. static void run(
  329. const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
  330. typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
  331. DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) {
  332. Op op(src0_dtype, src1_dtype, dst_dtype);
  333. for (size_t b = 0; b < batch; b++) {
  334. auto src0_ptr_base = src0 + b * channel_stride;
  335. for (size_t c = 0; c < channel; c++) {
  336. size_t i = 0;
  337. auto src0_ptr = src0_ptr_base;
  338. #if MEGDNN_FIX_AARCH32_BUG
  339. // FIXME: as llvm may cause cannot select error if enable vectorize
  340. #pragma clang loop vectorize(disable)
  341. #endif
  342. for (; i < channel_stride; i++) {
  343. op(*src0_ptr, *src1, dst);
  344. src0_ptr++;
  345. src1++;
  346. dst++;
  347. }
  348. }
  349. }
  350. }
  351. };
template <typename Op>
struct OpCallerBinary<Op, VEC_VEC> {
    //! Both operands contiguous: the SIMD loop processes 2 * SIMD_WIDTH
    //! elements per iteration, the scalar loop handles the remainder.
    static void run(
            const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
            typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
            DType dst_dtype, size_t nr_elems) {
        Op op(src0_dtype, src1_dtype, dst_dtype);
        ParamElemVisitorV2<typename Op::src_ctype> vis0;
        ParamElemVisitorV2<typename Op::src_ctype> vis1;
        size_t i = 0;
        for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) {
            op(vis0(src0, src0 + Op::SIMD_WIDTH), vis1(src1, src1 + Op::SIMD_WIDTH),
               dst);
            src0 += Op::SIMD_WIDTH * 2;
            src1 += Op::SIMD_WIDTH * 2;
            dst += Op::SIMD_WIDTH * 2;
        }
#if MEGDNN_FIX_AARCH32_BUG
// FIXME: as llvm may cause cannot select error if enable vectorize
#pragma clang loop vectorize(disable)
#endif
        //! scalar tail for the last nr_elems % (2 * SIMD_WIDTH) elements
        for (; i < nr_elems; i++) {
            op(*src0, *src1, dst);
            src0++;
            src1++;
            dst++;
        }
    }
};
template <typename Op>
struct OpCallerBinary<Op, VEC_BCAST101> {
    //! src1 is broadcast per channel: one scalar is dup'd into a V2 SIMD
    //! register (vis1) and reused for the whole channel_stride run.
    static void run(
            const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
            typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
            DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) {
        Op op(src0_dtype, src1_dtype, dst_dtype);
        ParamElemVisitorV2<typename Op::src_ctype> vis0;
        ParamElemVisitorDupV2<typename Op::src_ctype> vis1;
        for (size_t b = 0; b < batch; b++) {
            const typename Op::src_ctype* src1_ptr = src1;  //! rewind per batch
            for (size_t c = 0; c < channel; c++) {
                size_t i = 0;
                auto src1_simd_v2 = vis1(src1_ptr);  //! dup channel scalar once
                for (; i + Op::SIMD_WIDTH * 2 <= channel_stride;
                     i += Op::SIMD_WIDTH * 2) {
                    op(vis0(src0, src0 + Op::SIMD_WIDTH), src1_simd_v2, dst);
                    src0 += Op::SIMD_WIDTH * 2;
                    dst += Op::SIMD_WIDTH * 2;
                }
#if MEGDNN_FIX_AARCH32_BUG
// FIXME: as llvm may cause cannot select error if enable vectorize
#pragma clang loop vectorize(disable)
#endif
                for (; i < channel_stride; i++) {
                    op(*src0, *src1_ptr, dst);
                    src0++;
                    dst++;
                }
                src1_ptr++;  //! advance to the next channel's scalar
            }
        }
    }
};
template <typename Op>
struct OpCallerBinary<Op, VEC_BCASTX0X> {
    //! src1 supplies one channel_stride-long row per batch (offset
    //! b * channel_stride) that is re-read for every channel of that batch.
    static void run(
            const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
            typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
            DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) {
        Op op(src0_dtype, src1_dtype, dst_dtype);
        ParamElemVisitorV2<typename Op::src_ctype> vis;
        for (size_t b = 0; b < batch; b++) {
            const typename Op::src_ctype* src1_ptr_base = src1 + b * channel_stride;
            for (size_t c = 0; c < channel; c++) {
                size_t i = 0;
                auto src1_ptr = src1_ptr_base;  //! rewind the row per channel
                for (; i + Op::SIMD_WIDTH * 2 <= channel_stride;
                     i += Op::SIMD_WIDTH * 2) {
                    auto src0_simd01 = vis(src0, src0 + Op::SIMD_WIDTH);
                    auto src1_simd01 = vis(src1_ptr, src1_ptr + Op::SIMD_WIDTH);
                    op(src0_simd01, src1_simd01, dst);
                    src0 += Op::SIMD_WIDTH * 2;
                    src1_ptr += Op::SIMD_WIDTH * 2;
                    dst += Op::SIMD_WIDTH * 2;
                }
#if MEGDNN_FIX_AARCH32_BUG
// FIXME: as llvm may cause cannot select error if enable vectorize
#pragma clang loop vectorize(disable)
#endif
                for (; i < channel_stride; i++) {
                    op(*src0, *src1_ptr, dst);
                    src0++;
                    src1_ptr++;
                    dst++;
                }
            }
        }
    }
};
template <typename Op>
struct OpCallerBinary<Op, VEC_BCAST111C> {
    //! src1 is a single channel_stride-long row shared by every
    //! (batch, channel) pair; rewound at the start of each channel.
    static void run(
            const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
            typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
            DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) {
        Op op(src0_dtype, src1_dtype, dst_dtype);
        ParamElemVisitorV2<typename Op::src_ctype> vis;
        for (size_t b = 0; b < batch; b++) {
            for (size_t c = 0; c < channel; c++) {
                size_t rest = channel_stride;  //! elements left in this run
                const typename Op::src_ctype* src1_ptr = src1;  //! rewind row
                while (rest >= Op::SIMD_WIDTH * 2) {
                    auto src0_simd01 = vis(src0, src0 + Op::SIMD_WIDTH);
                    auto src1_simd01 = vis(src1_ptr, src1_ptr + Op::SIMD_WIDTH);
                    src0 += Op::SIMD_WIDTH * 2;
                    src1_ptr += Op::SIMD_WIDTH * 2;
                    op(src0_simd01, src1_simd01, dst);
                    dst += Op::SIMD_WIDTH * 2;
                    rest -= Op::SIMD_WIDTH * 2;
                }
#if MEGDNN_FIX_AARCH32_BUG
// FIXME: as llvm may cause cannot select error if enable vectorize
#pragma clang loop vectorize(disable)
#endif
                while (rest > 0) {
                    op(*src0, *src1_ptr, dst);
                    dst++;
                    src0++;
                    src1_ptr++;
                    rest--;
                }
            }
        }
    }
};
template <typename Op>
struct OpCallerBinary<Op, BCAST111C_VEC> {
    //! Mirror of VEC_BCAST111C: src0 is the shared trailing-dim row (rewound
    //! per channel) while src1 and dst walk contiguously.
    static void run(
            const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
            typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
            DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) {
        Op op(src0_dtype, src1_dtype, dst_dtype);
        ParamElemVisitorV2<typename Op::src_ctype> vis;
        for (size_t b = 0; b < batch; b++) {
            for (size_t c = 0; c < channel; c++) {
                size_t rest = channel_stride;  //! elements left in this run
                const typename Op::src_ctype* src0_ptr = src0;  //! rewind row
                while (rest >= Op::SIMD_WIDTH * 2) {
                    auto src0_simd01 = vis(src0_ptr, src0_ptr + Op::SIMD_WIDTH);
                    auto src1_simd01 = vis(src1, src1 + Op::SIMD_WIDTH);
                    src0_ptr += Op::SIMD_WIDTH * 2;
                    src1 += Op::SIMD_WIDTH * 2;
                    op(src0_simd01, src1_simd01, dst);
                    dst += Op::SIMD_WIDTH * 2;
                    rest -= Op::SIMD_WIDTH * 2;
                }
#if MEGDNN_FIX_AARCH32_BUG
// FIXME: as llvm may cause cannot select error if enable vectorize
#pragma clang loop vectorize(disable)
#endif
                while (rest > 0) {
                    op(*src0_ptr, *src1, dst);
                    dst++;
                    src0_ptr++;
                    src1++;
                    rest--;
                }
            }
        }
    }
};
  523. template <typename ctype>
  524. struct OpCallerBinary<fallback::PowOp<ctype, ctype>, BCAST101xX_VEC> {
  525. using Op = fallback::PowOp<ctype, ctype>;
  526. static void run(
  527. const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
  528. typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
  529. DType dst_dtype, size_t batch, size_t nr_channel_blocks,
  530. size_t channel_stride, size_t channel_block_dim) {
  531. Op op(src0_dtype, src1_dtype, dst_dtype);
  532. for (size_t b = 0; b < batch; b++) {
  533. auto src0_ptr = src0;
  534. for (size_t cb = 0; cb < nr_channel_blocks; cb++) {
  535. auto src0_block_ptr = src0_ptr + cb * channel_block_dim;
  536. for (size_t i = 0; i < channel_stride; i++) {
  537. for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) {
  538. op(*(src0_block_ptr + c_iter), *src1, dst);
  539. src1++;
  540. dst++;
  541. }
  542. }
  543. }
  544. }
  545. }
  546. };
  547. template <typename src_ctype, size_t channel_block_dim>
  548. struct OpCallerBinaryBcast101xXVec {
  549. template <typename Op>
  550. static void run(
  551. const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst,
  552. const Op& op, size_t batch, size_t nr_channel_blocks,
  553. size_t channel_stride) {
  554. for (size_t b = 0; b < batch; b++) {
  555. auto src0_ptr = src0;
  556. for (size_t cb = 0; cb < nr_channel_blocks; cb++) {
  557. auto src0_block_ptr = src0_ptr + cb * channel_block_dim;
  558. for (size_t img_index = 0; img_index < channel_stride; img_index++) {
  559. for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) {
  560. op(*(src0_block_ptr + c_iter), *src1, dst);
  561. src1++;
  562. dst++;
  563. }
  564. }
  565. }
  566. }
  567. }
  568. };
template <typename src_ctype, size_t channel_block_dim>
struct OpCallerBinaryBcast101xDVec {
    //! SIMD helper for BCAST101xX_VEC: vis0 materializes one channel block of
    //! src0 into a V2 register which is reused while src1/dst advance.
    template <typename Op, typename Vis0, typename Vis1>
    static void run(
            const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst,
            const Op& op, const Vis0& vis0, const Vis1& vis1, size_t batch,
            size_t nr_channel_blocks, size_t channel_stride) {
        for (size_t b = 0; b < batch; b++) {
            auto src0_ptr = src0;
            for (size_t cb = 0; cb < nr_channel_blocks; cb++) {
                auto src0_block_ptr = src0_ptr + cb * channel_block_dim;
                auto channel_block_vec_v2 = vis0(src0_block_ptr);
                size_t img_index = 0;
                //! channel-stride steps covered by one SIMD vector
                auto src1_offset = Op::SIMD_WIDTH / channel_block_dim;
                for (; img_index + 2 * src1_offset <= channel_stride;
                     img_index += 2 * src1_offset) {
                    op(channel_block_vec_v2, vis1(src1, src1 + Op::SIMD_WIDTH), dst);
                    src1 += Op::SIMD_WIDTH * 2;
                    dst += Op::SIMD_WIDTH * 2;
                }
                // TODO:all elemwise_multi_type op imp one simd mode
                //! scalar tail: replay the block against the remaining elements
                for (; img_index < channel_stride; img_index++) {
                    for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) {
                        op(*(src0_block_ptr + c_iter), *src1, dst);
                        src1++;
                        dst++;
                    }
                }
            }
        }
    }
};
  601. template <typename src_ctype>
  602. struct OpCallerBinaryBcast101xXVec<src_ctype, 4> {
  603. template <typename Op>
  604. static void run(
  605. const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst,
  606. const Op& op, size_t batch, size_t nr_channel_blocks,
  607. size_t channel_stride) {
  608. ParamElemVisitorBcast101x4V2<typename Op::src_ctype> vis0;
  609. ParamElemVisitorV2<typename Op::src_ctype> vis1;
  610. OpCallerBinaryBcast101xDVec<src_ctype, 4>::run(
  611. src0, src1, dst, op, vis0, vis1, batch, nr_channel_blocks,
  612. channel_stride);
  613. }
  614. };
  615. template <typename Op>
  616. struct OpCallerBinary<Op, BCAST101xX_VEC> {
  617. static void run(
  618. const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
  619. typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
  620. DType dst_dtype, size_t batch, size_t nr_channel_blocks,
  621. size_t channel_stride, size_t channel_block_dim) {
  622. megdnn_assert(
  623. channel_block_dim == 4 || channel_block_dim == 8,
  624. "only imp for nchw44/nchw88");
  625. Op op(src0_dtype, src1_dtype, dst_dtype);
  626. if (channel_block_dim == 4) {
  627. OpCallerBinaryBcast101xXVec<typename Op::src_ctype, 4>::run(
  628. src0, src1, dst, op, batch, nr_channel_blocks, channel_stride);
  629. } else {
  630. OpCallerBinaryBcast101xXVec<typename Op::src_ctype, 8>::run(
  631. src0, src1, dst, op, batch, nr_channel_blocks, channel_stride);
  632. }
  633. }
  634. };
  635. template <typename ctype>
  636. struct OpCallerBinary<fallback::PowOp<ctype, ctype>, VEC_BCAST101xX> {
  637. using Op = fallback::PowOp<ctype, ctype>;
  638. static void run(
  639. const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
  640. typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
  641. DType dst_dtype, size_t batch, size_t nr_channel_blocks,
  642. size_t channel_stride, size_t channel_block_dim) {
  643. Op op(src0_dtype, src1_dtype, dst_dtype);
  644. for (size_t b = 0; b < batch; b++) {
  645. auto src1_ptr = src1;
  646. for (size_t cb = 0; cb < nr_channel_blocks; cb++) {
  647. auto src1_block_ptr = src1_ptr + cb * channel_block_dim;
  648. for (size_t i = 0; i < channel_stride; i++) {
  649. for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) {
  650. op(*(src0), *(src1_block_ptr + c_iter), dst);
  651. src0++;
  652. dst++;
  653. }
  654. }
  655. }
  656. }
  657. }
  658. };
  659. template <typename src_ctype, size_t channel_block_dim>
  660. struct OpCallerBinaryVecBcast101xX {
  661. template <typename Op>
  662. static void run(
  663. const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst,
  664. const Op& op, size_t batch, size_t nr_channel_blocks,
  665. size_t channel_stride) {
  666. for (size_t b = 0; b < batch; b++) {
  667. auto src1_ptr = src1;
  668. for (size_t cb = 0; cb < nr_channel_blocks; cb++) {
  669. auto src1_block_ptr = src1_ptr + cb * channel_block_dim;
  670. for (size_t img_index = 0; img_index < channel_stride; img_index++) {
  671. for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) {
  672. op(*src0, *(src1_block_ptr + c_iter), dst);
  673. src0++;
  674. dst++;
  675. }
  676. }
  677. }
  678. }
  679. }
  680. };
template <typename src_ctype, size_t channel_block_dim>
struct OpCallerBinaryVecBcast101xD {
    //! SIMD helper for VEC_BCAST101xX: vis1 materializes one channel block of
    //! src1 into a V2 register which is reused while src0/dst advance.
    template <typename Op, typename Vis0, typename Vis1>
    static void run(
            const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst,
            const Op& op, const Vis0& vis0, const Vis1& vis1, size_t batch,
            size_t nr_channel_blocks, size_t channel_stride) {
        for (size_t b = 0; b < batch; b++) {
            auto src1_ptr = src1;
            for (size_t cb = 0; cb < nr_channel_blocks; cb++) {
                auto src1_block_ptr = src1_ptr + cb * channel_block_dim;
                auto channel_block_vec_v2 = vis1(src1_block_ptr);
                size_t img_index = 0;
                //! channel-stride steps covered by one SIMD vector
                auto src0_offset = Op::SIMD_WIDTH / channel_block_dim;
                for (; img_index + 2 * src0_offset <= channel_stride;
                     img_index += 2 * src0_offset) {
                    op(vis0(src0, src0 + Op::SIMD_WIDTH), channel_block_vec_v2, dst);
                    src0 += Op::SIMD_WIDTH * 2;
                    dst += Op::SIMD_WIDTH * 2;
                }
                // TODO:all elemwise_multi_type op imp one simd mode
                //! scalar tail: replay the block against the remaining elements
                for (; img_index < channel_stride; img_index++) {
                    for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) {
                        op(*src0, *(src1_block_ptr + c_iter), dst);
                        src0++;
                        dst++;
                    }
                }
            }
        }
    }
};
  713. template <typename src_ctype>
  714. struct OpCallerBinaryVecBcast101xX<src_ctype, 4> {
  715. template <typename Op>
  716. static void run(
  717. const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst,
  718. const Op& op, size_t batch, size_t nr_channel_blocks,
  719. size_t channel_stride) {
  720. ParamElemVisitorV2<src_ctype> vis0;
  721. ParamElemVisitorBcast101x4V2<src_ctype> vis1;
  722. OpCallerBinaryVecBcast101xD<src_ctype, 4>::run(
  723. src0, src1, dst, op, vis0, vis1, batch, nr_channel_blocks,
  724. channel_stride);
  725. }
  726. };
  727. template <typename Op>
  728. struct OpCallerBinary<Op, VEC_BCAST101xX> {
  729. static void run(
  730. const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
  731. typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
  732. DType dst_dtype, size_t batch, size_t nr_channel_blocks,
  733. size_t channel_stride, size_t channel_block_dim) {
  734. megdnn_assert(
  735. channel_block_dim == 4 || channel_block_dim == 8,
  736. "only imp for nchw44/nchw88");
  737. Op op(src0_dtype, src1_dtype, dst_dtype);
  738. if (channel_block_dim == 4) {
  739. OpCallerBinaryVecBcast101xX<typename Op::src_ctype, 4>::run(
  740. src0, src1, dst, op, batch, nr_channel_blocks, channel_stride);
  741. } else {
  742. OpCallerBinaryVecBcast101xX<typename Op::src_ctype, 8>::run(
  743. src0, src1, dst, op, batch, nr_channel_blocks, channel_stride);
  744. }
  745. }
  746. };
  747. template <typename Op>
  748. struct OpCallerBinary<Op, VEC_SCALAR> {
  749. static void run(
  750. const typename Op::src_ctype* src0, const typename Op::src_ctype src1,
  751. typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
  752. DType dst_dtype, size_t nr_elems) {
  753. Op op(src0_dtype, src1_dtype, dst_dtype);
  754. ParamElemVisitorV2<typename Op::src_ctype> vis0;
  755. ParamElemVisitorDupV2<typename Op::src_ctype> vis1;
  756. auto vis1_simd_v2 = vis1(&src1);
  757. size_t i = 0;
  758. for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) {
  759. op(vis0(src0, src0 + Op::SIMD_WIDTH), vis1_simd_v2, dst);
  760. src0 += Op::SIMD_WIDTH * 2;
  761. dst += Op::SIMD_WIDTH * 2;
  762. }
  763. #if MEGDNN_FIX_AARCH32_BUG
  764. // FIXME: as llvm may cause cannot select error if enable vectorize
  765. #pragma clang loop vectorize(disable)
  766. #endif
  767. for (; i < nr_elems; i++) {
  768. op(*src0, src1, dst);
  769. src0++;
  770. dst++;
  771. }
  772. }
  773. };
//! only for non-swappable (non-commutative) ops, like SUB and DIV
  775. template <typename Op>
  776. struct OpCallerBinary<Op, SCALAR_VEC> {
  777. static void run(
  778. const typename Op::src_ctype src0, const typename Op::src_ctype* src1,
  779. typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
  780. DType dst_dtype, size_t nr_elems) {
  781. Op op(src0_dtype, src1_dtype, dst_dtype);
  782. ParamElemVisitorDupV2<typename Op::src_ctype> vis0;
  783. ParamElemVisitorV2<typename Op::src_ctype> vis1;
  784. auto vis0_simd_v2 = vis0(&src0);
  785. size_t i = 0;
  786. for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) {
  787. op(vis0_simd_v2, vis1(src1, src1 + Op::SIMD_WIDTH), dst);
  788. src1 += Op::SIMD_WIDTH * 2;
  789. dst += Op::SIMD_WIDTH * 2;
  790. }
  791. #if MEGDNN_FIX_AARCH32_BUG
  792. // FIXME: as llvm may cause cannot select error if enable vectorize
  793. #pragma clang loop vectorize(disable)
  794. #endif
  795. for (; i < nr_elems; i++) {
  796. op(src0, *src1, dst);
  797. src1++;
  798. dst++;
  799. }
  800. }
  801. };
//! src0 broadcast as 1C11 (one value per channel), src1 contiguous vector.
template <typename Op>
struct OpCallerBinary<Op, BCAST101_VEC> {
    static void run(
            const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
            typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
            DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) {
        Op op(src0_dtype, src1_dtype, dst_dtype);
        ParamElemVisitorDupV2<typename Op::src_ctype> vis0;
        ParamElemVisitorV2<typename Op::src_ctype> vis1;
        for (size_t b = 0; b < batch; b++) {
            // src0 holds `channel` scalars; rewound every batch (broadcast)
            auto src0_ptr = src0;
            for (size_t c = 0; c < channel; c++) {
                // splat the current channel's scalar into SIMD registers once
                auto vis0_simd_v2 = vis0(src0_ptr);
                size_t i = 0;
                // main loop: two SIMD vectors of src1 per iteration
                for (; i + Op::SIMD_WIDTH * 2 <= channel_stride;
                     i += Op::SIMD_WIDTH * 2) {
                    op(vis0_simd_v2, vis1(src1, src1 + Op::SIMD_WIDTH), dst);
                    src1 += Op::SIMD_WIDTH * 2;
                    dst += Op::SIMD_WIDTH * 2;
                }
#if MEGDNN_FIX_AARCH32_BUG
                // FIXME: as llvm may cause cannot select error if enable vectorize
#pragma clang loop vectorize(disable)
#endif
                // scalar tail
                for (; i < channel_stride; i++) {
                    op(*src0_ptr, *src1, dst);
                    src1++;
                    dst++;
                }
                src0_ptr++;
            }
        }
    }
};
//! src0 broadcast as X0X: per-batch slice of `channel_stride` elements that is
//! replayed for every channel of that batch; src1 is a contiguous vector.
template <typename Op>
struct OpCallerBinary<Op, BCASTX0X_VEC> {
    static void run(
            const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
            typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
            DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) {
        Op op(src0_dtype, src1_dtype, dst_dtype);
        ParamElemVisitorV2<typename Op::src_ctype> vis;
        for (size_t b = 0; b < batch; b++) {
            // start of this batch's src0 slice
            auto src0_ptr_base = src0 + b * channel_stride;
            for (size_t c = 0; c < channel; c++) {
                // replay the slice for each channel
                auto src0_ptr = src0_ptr_base;
                size_t i = 0;
                // main loop: two SIMD vectors per iteration
                for (; i + Op::SIMD_WIDTH * 2 <= channel_stride;
                     i += Op::SIMD_WIDTH * 2) {
                    auto src0_simd01 = vis(src0_ptr, src0_ptr + Op::SIMD_WIDTH);
                    auto src1_simd01 = vis(src1, src1 + Op::SIMD_WIDTH);
                    op(src0_simd01, src1_simd01, dst);
                    src0_ptr += Op::SIMD_WIDTH * 2;
                    src1 += Op::SIMD_WIDTH * 2;
                    dst += Op::SIMD_WIDTH * 2;
                }
#if MEGDNN_FIX_AARCH32_BUG
                // FIXME: as llvm may cause cannot select error if enable vectorize
#pragma clang loop vectorize(disable)
#endif
                // scalar tail
                for (; i < channel_stride; i++) {
                    op(*src0_ptr, *src1, dst);
                    src0_ptr++;
                    src1++;
                    dst++;
                }
            }
        }
    }
};
//! primary template for ternary op callers; specialized per broadcast pattern
template <typename Op, BcastType bcast_type>
struct OpCallerTernary;
  874. template <typename Op>
  875. struct OpCallerTernary<Op, VEC_VEC_VEC> {
  876. static void run(
  877. const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
  878. const typename Op::src_ctype* src2, typename Op::dst_ctype* dst,
  879. DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype,
  880. size_t nr_elems) {
  881. Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype);
  882. ParamElemVisitorV2<typename Op::src_ctype> vis0;
  883. ParamElemVisitorV2<typename Op::src_ctype> vis1;
  884. ParamElemVisitorV2<typename Op::src_ctype> vis2;
  885. size_t i = 0;
  886. for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) {
  887. op(vis0(src0, src0 + Op::SIMD_WIDTH), vis1(src1, src1 + Op::SIMD_WIDTH),
  888. vis2(src2, src2 + Op::SIMD_WIDTH), dst);
  889. src0 += Op::SIMD_WIDTH * 2;
  890. src1 += Op::SIMD_WIDTH * 2;
  891. src2 += Op::SIMD_WIDTH * 2;
  892. dst += Op::SIMD_WIDTH * 2;
  893. }
  894. #if MEGDNN_FIX_AARCH32_BUG
  895. // FIXME: as llvm may cause cannot select error if enable vectorize
  896. #pragma clang loop vectorize(disable)
  897. #endif
  898. for (; i < nr_elems; i++) {
  899. op(*src0, *src1, *src2, dst);
  900. src0++;
  901. src1++;
  902. src2++;
  903. dst++;
  904. }
  905. }
  906. };
  907. //! src0: vector, src1: vector, src2: scalar
  908. template <typename Op>
  909. struct OpCallerTernary<Op, VEC_VEC_SCALAR> {
  910. static void run(
  911. const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
  912. const typename Op::src_ctype src2, typename Op::dst_ctype* dst,
  913. DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype,
  914. size_t nr_elems) {
  915. Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype);
  916. ParamElemVisitorV2<typename Op::src_ctype> vis0;
  917. ParamElemVisitorV2<typename Op::src_ctype> vis1;
  918. ParamElemVisitorDupV2<typename Op::src_ctype> vis2;
  919. auto vis2_simd_v2 = vis2(&src2);
  920. size_t i = 0;
  921. for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) {
  922. op(vis0(src0, src0 + Op::SIMD_WIDTH), vis1(src1, src1 + Op::SIMD_WIDTH),
  923. vis2_simd_v2, dst);
  924. src0 += Op::SIMD_WIDTH * 2;
  925. src1 += Op::SIMD_WIDTH * 2;
  926. dst += Op::SIMD_WIDTH * 2;
  927. }
  928. #if MEGDNN_FIX_AARCH32_BUG
  929. // FIXME: as llvm may cause cannot select error if enable vectorize
  930. #pragma clang loop vectorize(disable)
  931. #endif
  932. for (; i < nr_elems; i++) {
  933. op(*src0, *src1, src2, dst);
  934. src0++;
  935. src1++;
  936. dst++;
  937. }
  938. }
  939. };
  940. //! src0: 1C11, src1: vector, src2: 1C11
template <typename Op>
struct OpCallerTernary<Op, BCAST101_VEC_BCAST101> {
    static void run(
            const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
            const typename Op::src_ctype* src2, typename Op::dst_ctype* dst,
            DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype,
            size_t batch_size, size_t channel_size, size_t channel_stride,
            size_t batch_offset) {
        Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype);
        ParamElemVisitorV2<typename Op::src_ctype> vis1;
        ParamElemVisitorDupV2<typename Op::src_ctype> vis0;
        ParamElemVisitorDupV2<typename Op::src_ctype> vis2;
        for (size_t batch = 0; batch < batch_size; batch++) {
            // src0/src2 hold one value per channel; rewound every batch
            auto src0_ptr = src0;
            auto src2_ptr = src2;
            // b_offset counts down as elements are consumed; the remainder is
            // added to src1/dst after the channel loop (NOTE(review): presumably
            // skips per-batch padding when src1/dst are not batch-contiguous)
            auto b_offset = batch_offset;
            for (size_t channel = 0; channel < channel_size; channel++) {
                size_t i = 0;
                // splat the per-channel scalars into SIMD registers once
                auto src0_simd_v2 = vis0(src0_ptr);
                auto src2_simd_v2 = vis2(src2_ptr);
                for (; i + Op::SIMD_WIDTH * 2 <= channel_stride;
                     i += Op::SIMD_WIDTH * 2) {
                    op(src0_simd_v2, vis1(src1, src1 + Op::SIMD_WIDTH), src2_simd_v2,
                       dst);
                    src1 += Op::SIMD_WIDTH * 2;
                    dst += Op::SIMD_WIDTH * 2;
                    b_offset -= Op::SIMD_WIDTH * 2;
                }
#if MEGDNN_FIX_AARCH32_BUG
                // FIXME: as llvm may cause cannot select error if enable vectorize
#pragma clang loop vectorize(disable)
#endif
                // scalar tail
                for (; i < channel_stride; i++) {
                    op(*src0_ptr, *src1, *src2_ptr, dst);
                    src1++;
                    dst++;
                    b_offset--;
                }
                src0_ptr++;
                src2_ptr++;
            }
            // skip whatever remains of this batch in src1/dst
            src1 += b_offset;
            dst += b_offset;
        }
    }
};
  987. //! src0: 111C, src1: vector, src2: 111C, src1 may not be contig
template <typename Op>
struct OpCallerTernary<Op, BCAST111C_VEC_BCAST111C> {
    static void run(
            const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
            size_t src1_offset, const typename Op::src_ctype* src2,
            typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
            DType src2_dtype, DType dst_dtype, size_t batch_size, size_t channel_size,
            size_t channel_stride, size_t batch_offset) {
        Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype);
        ParamElemVisitorV2<typename Op::src_ctype> vis;
        for (size_t batch = 0; batch < batch_size; batch++) {
            // b_offset counts down as elements are consumed; the remainder is
            // added to src1/dst after the channel loop (NOTE(review): presumably
            // skips per-batch padding when src1/dst are not batch-contiguous)
            auto b_offset = batch_offset;
            for (size_t channel = 0; channel < channel_size; channel++) {
                // src0/src2 hold one 111C row; replayed for every (batch, channel)
                auto src0_ptr = src0;
                auto src2_ptr = src2;
                size_t i = 0;
                // main loop: two SIMD vectors per iteration
                for (; i + Op::SIMD_WIDTH * 2 <= channel_stride;
                     i += Op::SIMD_WIDTH * 2) {
                    auto src0_simd01 = vis(src0_ptr, src0_ptr + Op::SIMD_WIDTH);
                    auto src1_simd01 = vis(src1, src1 + Op::SIMD_WIDTH);
                    auto src2_simd01 = vis(src2_ptr, src2_ptr + Op::SIMD_WIDTH);
                    op(src0_simd01, src1_simd01, src2_simd01, dst);
                    src0_ptr += Op::SIMD_WIDTH * 2;
                    src1 += Op::SIMD_WIDTH * 2;
                    src2_ptr += Op::SIMD_WIDTH * 2;
                    dst += Op::SIMD_WIDTH * 2;
                    b_offset -= Op::SIMD_WIDTH * 2;
                }
#if MEGDNN_FIX_AARCH32_BUG
                // FIXME: as llvm may cause cannot select error if enable vectorize
#pragma clang loop vectorize(disable)
#endif
                // scalar tail
                for (; i < channel_stride; i++) {
                    op(*src0_ptr, *src1, *src2_ptr, dst);
                    src0_ptr++;
                    src1++;
                    src2_ptr++;
                    dst++;
                    b_offset--;
                }
                // src1 may not be contiguous across rows: skip the gap
                src1 += src1_offset;
            }
            src1 += b_offset;
            dst += b_offset;
        }
    }
};
  1035. template <typename src_ctype, size_t channel_block_dim>
  1036. struct OpCallerTernaryBcast101xXVecBcast101xX {
  1037. template <typename Op>
  1038. static void run(
  1039. const src_ctype* src0, const src_ctype* src1, const src_ctype* src2,
  1040. typename Op::dst_ctype* dst, const Op& op, size_t batch,
  1041. size_t nr_channel_blocks, size_t channel_stride) {
  1042. for (size_t b = 0; b < batch; b++) {
  1043. auto src0_ptr = src0;
  1044. auto src2_ptr = src2;
  1045. for (size_t cb = 0; cb < nr_channel_blocks; cb++) {
  1046. auto src0_block_ptr = src0_ptr + cb * channel_block_dim;
  1047. auto src2_block_ptr = src2_ptr + cb * channel_block_dim;
  1048. for (size_t img_index = 0; img_index < channel_stride; img_index++) {
  1049. for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) {
  1050. op(*(src0_block_ptr + c_iter), *src1,
  1051. *(src2_block_ptr + c_iter), dst);
  1052. src1++;
  1053. dst++;
  1054. }
  1055. }
  1056. }
  1057. }
  1058. }
  1059. };
//! SIMD implementation: src0/src2 broadcast in channel blocks (101xD layout),
//! src1 contiguous; the concrete visitors are injected by the caller.
template <typename src_ctype, size_t channel_block_dim>
struct OpCallerTernaryBcast101xDVecBcast101xD {
    template <typename Op, typename Vis0, typename Vis1, typename Vis2>
    static void run(
            const src_ctype* src0, const src_ctype* src1, const src_ctype* src2,
            typename Op::dst_ctype* dst, const Op& op, const Vis0& vis0,
            const Vis1& vis1, const Vis2& vis2, size_t batch, size_t nr_channel_blocks,
            size_t channel_stride) {
        for (size_t b = 0; b < batch; b++) {
            // broadcast operands are rewound every batch
            auto src0_ptr = src0;
            auto src2_ptr = src2;
            for (size_t cb = 0; cb < nr_channel_blocks; cb++) {
                auto src0_block_ptr = src0_ptr + cb * channel_block_dim;
                auto src2_block_ptr = src2_ptr + cb * channel_block_dim;
                // materialize each broadcast channel block once per block
                auto channel_block_vec0_v2 = vis0(src0_block_ptr);
                auto channel_block_vec2_v2 = vis2(src2_block_ptr);
                size_t img_index = 0;
                // one SIMD vector covers SIMD_WIDTH / channel_block_dim pixels
                auto src1_offset = Op::SIMD_WIDTH / channel_block_dim;
                // main loop: two SIMD vectors (2 * SIMD_WIDTH elements) per step
                for (; img_index + 2 * src1_offset <= channel_stride;
                     img_index += 2 * src1_offset) {
                    op(channel_block_vec0_v2, vis1(src1, src1 + Op::SIMD_WIDTH),
                       channel_block_vec2_v2, dst);
                    src1 += Op::SIMD_WIDTH * 2;
                    dst += Op::SIMD_WIDTH * 2;
                }
                // scalar tail for the remaining pixels
                // TODO: all elemwise_multi_type op imp one simd mode
                for (; img_index < channel_stride; img_index++) {
                    for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) {
                        op(*(src0_block_ptr + c_iter), *src1,
                           *(src2_block_ptr + c_iter), dst);
                        src1++;
                        dst++;
                    }
                }
            }
        }
    }
};
  1098. //! src0: CHW44, src1: vector, src2: CHW44
  1099. template <typename src_ctype>
  1100. struct OpCallerTernaryBcast101xXVecBcast101xX<src_ctype, 4> {
  1101. template <typename Op>
  1102. static void run(
  1103. const src_ctype* src0, const src_ctype* src1, const src_ctype* src2,
  1104. typename Op::dst_ctype* dst, const Op& op, size_t batch,
  1105. size_t nr_channel_blocks, size_t channel_stride) {
  1106. ParamElemVisitorBcast101x4V2<src_ctype> vis0;
  1107. ParamElemVisitorV2<src_ctype> vis1;
  1108. ParamElemVisitorBcast101x4V2<src_ctype> vis2;
  1109. OpCallerTernaryBcast101xDVecBcast101xD<src_ctype, 4>::run(
  1110. src0, src1, src2, dst, op, vis0, vis1, vis2, batch, nr_channel_blocks,
  1111. channel_stride);
  1112. }
  1113. };
  1114. template <typename Op>
  1115. struct OpCallerTernary<Op, BCAST101xX_VEC_BCAST101xX> {
  1116. static void run(
  1117. const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
  1118. const typename Op::src_ctype* src2, typename Op::dst_ctype* dst,
  1119. DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype,
  1120. size_t batch, size_t nr_channel_blocks, size_t channel_stride,
  1121. size_t channel_block_dim) {
  1122. megdnn_assert(
  1123. channel_block_dim == 4 || channel_block_dim == 8,
  1124. "only imp for nchw44/nchw88");
  1125. Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype);
  1126. if (channel_block_dim == 4) {
  1127. OpCallerTernaryBcast101xXVecBcast101xX<typename Op::src_ctype, 4>::run(
  1128. src0, src1, src2, dst, op, batch, nr_channel_blocks,
  1129. channel_stride);
  1130. } else {
  1131. OpCallerTernaryBcast101xXVecBcast101xX<typename Op::src_ctype, 8>::run(
  1132. src0, src1, src2, dst, op, batch, nr_channel_blocks,
  1133. channel_stride);
  1134. }
  1135. }
  1136. };
  1137. //! src1: 1C11, src0 and src2 are contig
  1138. template <typename Op>
  1139. struct OpCallerTernary<Op, VEC_BCAST101_VEC> {
  1140. static void run(
  1141. const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
  1142. const typename Op::src_ctype* src2, typename Op::dst_ctype* dst,
  1143. DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype,
  1144. size_t batch_size, size_t channel_size, size_t channel_stride) {
  1145. Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype);
  1146. ParamElemVisitorV2<typename Op::src_ctype> vis0;
  1147. ParamElemVisitorDupV2<typename Op::src_ctype> vis1;
  1148. ParamElemVisitorV2<typename Op::src_ctype> vis2;
  1149. for (size_t batch = 0; batch < batch_size; batch++) {
  1150. auto src1_ptr = src1;
  1151. for (size_t channel = 0; channel < channel_size; channel++) {
  1152. size_t i = 0;
  1153. auto src1_simd_v2 = vis1(src1_ptr);
  1154. for (; i + Op::SIMD_WIDTH * 2 <= channel_stride;
  1155. i += Op::SIMD_WIDTH * 2) {
  1156. op(vis0(src0, src0 + Op::SIMD_WIDTH), src1_simd_v2,
  1157. vis2(src2, src2 + Op::SIMD_WIDTH), dst);
  1158. src0 += Op::SIMD_WIDTH * 2;
  1159. src2 += Op::SIMD_WIDTH * 2;
  1160. dst += Op::SIMD_WIDTH * 2;
  1161. }
  1162. #if MEGDNN_FIX_AARCH32_BUG
  1163. // FIXME: as llvm may cause cannot select error if enable vectorize
  1164. #pragma clang loop vectorize(disable)
  1165. #endif
  1166. for (; i < channel_stride; i++) {
  1167. op(*src0, *src1_ptr, *src2, dst);
  1168. src0++;
  1169. src2++;
  1170. dst++;
  1171. }
  1172. src1_ptr++;
  1173. }
  1174. }
  1175. }
  1176. };
  1177. //! src1: 111C, src0 and src2 may not be contig
template <typename Op>
struct OpCallerTernary<Op, VEC_BCAST111C_VEC> {
    static void run(
            const typename Op::src_ctype* src0, size_t src0_offset,
            const typename Op::src_ctype* src1, const typename Op::src_ctype* src2,
            size_t src2_offset, typename Op::dst_ctype* dst, DType src0_dtype,
            DType src1_dtype, DType src2_dtype, DType dst_dtype, size_t batch_size,
            size_t channel_size, size_t channel_stride) {
        Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype);
        ParamElemVisitorV2<typename Op::src_ctype> vis0;
        ParamElemVisitorV2<typename Op::src_ctype> vis1;
        ParamElemVisitorV2<typename Op::src_ctype> vis2;
        for (size_t batch = 0; batch < batch_size; batch++) {
            for (size_t channel = 0; channel < channel_size; channel++) {
                // src1 is the broadcast 111C row; replayed per (batch, channel)
                auto src1_ptr = src1;
                size_t i = 0;
                // main loop: two SIMD vectors per iteration
                for (; i + Op::SIMD_WIDTH * 2 <= channel_stride;
                     i += Op::SIMD_WIDTH * 2) {
                    op(vis0(src0, src0 + Op::SIMD_WIDTH),
                       vis1(src1_ptr, src1_ptr + Op::SIMD_WIDTH),
                       vis2(src2, src2 + Op::SIMD_WIDTH), dst);
                    src0 += Op::SIMD_WIDTH * 2;
                    src1_ptr += Op::SIMD_WIDTH * 2;
                    src2 += Op::SIMD_WIDTH * 2;
                    dst += Op::SIMD_WIDTH * 2;
                }
#if MEGDNN_FIX_AARCH32_BUG
                // FIXME: as llvm may cause cannot select error if enable vectorize
#pragma clang loop vectorize(disable)
#endif
                // scalar tail
                for (; i < channel_stride; i++) {
                    op(*src0, *src1_ptr, *src2, dst);
                    src0++;
                    src1_ptr++;
                    src2++;
                    dst++;
                }
                // src0/src2 may not be contiguous: skip the inter-row gap
                src0 += src0_offset;
                src2 += src2_offset;
            }
        }
    }
};
  1221. template <typename src_ctype, size_t channel_block_dim>
  1222. struct OpCallerTernaryVecBcast101xXVec {
  1223. template <typename Op>
  1224. static void run(
  1225. const src_ctype* src0, const src_ctype* src1, const src_ctype* src2,
  1226. typename Op::dst_ctype* dst, const Op& op, size_t batch,
  1227. size_t nr_channel_blocks, size_t channel_stride) {
  1228. for (size_t b = 0; b < batch; b++) {
  1229. auto src1_ptr = src1;
  1230. for (size_t cb = 0; cb < nr_channel_blocks; cb++) {
  1231. auto src1_block_ptr = src1_ptr + cb * channel_block_dim;
  1232. for (size_t img_index = 0; img_index < channel_stride; img_index++) {
  1233. for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) {
  1234. op(*src0, *(src1_block_ptr + c_iter), *src2, dst);
  1235. src0++;
  1236. src2++;
  1237. dst++;
  1238. }
  1239. }
  1240. }
  1241. }
  1242. }
  1243. };
  1244. //! src1: CHW44, src0 and src2 are contig
template <typename src_ctype, size_t channel_block_dim>
struct OpCallerTernaryVecBcast101xDVec {
    template <typename Op, typename Vis0, typename Vis1, typename Vis2>
    static void run(
            const src_ctype* src0, const src_ctype* src1, const src_ctype* src2,
            typename Op::dst_ctype* dst, const Op& op, const Vis0& vis0,
            const Vis1& vis1, const Vis2& vis2, size_t batch, size_t nr_channel_blocks,
            size_t channel_stride) {
        for (size_t b = 0; b < batch; b++) {
            // broadcast operand src1 is rewound every batch
            auto src1_ptr = src1;
            for (size_t cb = 0; cb < nr_channel_blocks; cb++) {
                auto src1_block_ptr = src1_ptr + cb * channel_block_dim;
                // materialize the broadcast channel block once per block
                auto channel_block_vec_v2 = vis1(src1_block_ptr);
                size_t img_index = 0;
                // one SIMD vector covers SIMD_WIDTH / channel_block_dim pixels
                auto offset = Op::SIMD_WIDTH / channel_block_dim;
                // main loop: two SIMD vectors (2 * SIMD_WIDTH elements) per step
                for (; img_index + 2 * offset <= channel_stride;
                     img_index += 2 * offset) {
                    op(vis0(src0, src0 + Op::SIMD_WIDTH), channel_block_vec_v2,
                       vis2(src2, src2 + Op::SIMD_WIDTH), dst);
                    src0 += Op::SIMD_WIDTH * 2;
                    src2 += Op::SIMD_WIDTH * 2;
                    dst += Op::SIMD_WIDTH * 2;
                }
                // scalar tail for the remaining pixels
                // TODO: all elemwise_multi_type op imp one simd mode
                for (; img_index < channel_stride; img_index++) {
                    for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) {
                        op(*src0, *(src1_block_ptr + c_iter), *src2, dst);
                        src0++;
                        src2++;
                        dst++;
                    }
                }
            }
        }
    }
};
  1281. template <typename src_ctype>
  1282. struct OpCallerTernaryVecBcast101xXVec<src_ctype, 4> {
  1283. template <typename Op>
  1284. static void run(
  1285. const src_ctype* src0, const src_ctype* src1, const src_ctype* src2,
  1286. typename Op::dst_ctype* dst, const Op& op, size_t batch,
  1287. size_t nr_channel_blocks, size_t channel_stride) {
  1288. ParamElemVisitorV2<src_ctype> vis0;
  1289. ParamElemVisitorBcast101x4V2<src_ctype> vis1;
  1290. ParamElemVisitorV2<src_ctype> vis2;
  1291. OpCallerTernaryVecBcast101xDVec<src_ctype, 4>::run(
  1292. src0, src1, src2, dst, op, vis0, vis1, vis2, batch, nr_channel_blocks,
  1293. channel_stride);
  1294. }
  1295. };
  1296. template <typename Op>
  1297. struct OpCallerTernary<Op, VEC_BCAST101xX_VEC> {
  1298. static void run(
  1299. const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
  1300. const typename Op::src_ctype* src2, typename Op::dst_ctype* dst,
  1301. DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype,
  1302. size_t batch, size_t nr_channel_blocks, size_t channel_stride,
  1303. size_t channel_block_dim) {
  1304. megdnn_assert(
  1305. channel_block_dim == 4 || channel_block_dim == 8,
  1306. "only imp for nchw44/nchw88");
  1307. Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype);
  1308. if (channel_block_dim == 4) {
  1309. OpCallerTernaryVecBcast101xXVec<typename Op::src_ctype, 4>::run(
  1310. src0, src1, src2, dst, op, batch, nr_channel_blocks,
  1311. channel_stride);
  1312. } else {
  1313. OpCallerTernaryVecBcast101xXVec<typename Op::src_ctype, 8>::run(
  1314. src0, src1, src2, dst, op, batch, nr_channel_blocks,
  1315. channel_stride);
  1316. }
  1317. }
  1318. };
  1319. //! src1: scalar, src0 and src2 has the same shape
  1320. template <typename Op>
  1321. struct OpCallerTernary<Op, VEC_SCALAR_VEC> {
  1322. static void run(
  1323. const typename Op::src_ctype* src0, const typename Op::src_ctype src1,
  1324. const typename Op::src_ctype* src2, typename Op::dst_ctype* dst,
  1325. DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype,
  1326. size_t nr_elems) {
  1327. Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype);
  1328. ParamElemVisitorV2<typename Op::src_ctype> vis0;
  1329. ParamElemVisitorDupV2<typename Op::src_ctype> vis1;
  1330. ParamElemVisitorV2<typename Op::src_ctype> vis2;
  1331. auto vis1_simd_v2 = vis1(&src1);
  1332. size_t i = 0;
  1333. for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) {
  1334. op(vis0(src0, src0 + Op::SIMD_WIDTH), vis1_simd_v2,
  1335. vis2(src2, src2 + Op::SIMD_WIDTH), dst);
  1336. src0 += Op::SIMD_WIDTH * 2;
  1337. src2 += Op::SIMD_WIDTH * 2;
  1338. dst += Op::SIMD_WIDTH * 2;
  1339. }
  1340. #if MEGDNN_FIX_AARCH32_BUG
  1341. // FIXME: as llvm may cause cannot select error if enable vectorize
  1342. #pragma clang loop vectorize(disable)
  1343. #endif
  1344. for (; i < nr_elems; i++) {
  1345. op(*src0, src1, *src2, dst);
  1346. src0++;
  1347. src2++;
  1348. dst++;
  1349. }
  1350. }
  1351. };
  1352. //! src1, src2: scalar, src0 is vector
  1353. template <typename Op>
  1354. struct OpCallerTernary<Op, VEC_SCALAR_SCALAR> {
  1355. static void run(
  1356. const typename Op::src_ctype* src0, const typename Op::src_ctype src1,
  1357. const typename Op::src_ctype src2, typename Op::dst_ctype* dst,
  1358. DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype,
  1359. size_t nr_elems) {
  1360. Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype);
  1361. ParamElemVisitorV2<typename Op::src_ctype> vis0;
  1362. ParamElemVisitorDupV2<typename Op::src_ctype> vis1;
  1363. ParamElemVisitorDupV2<typename Op::src_ctype> vis2;
  1364. auto vis1_simd_v2 = vis1(&src1);
  1365. auto vis2_simd_v2 = vis2(&src2);
  1366. size_t i = 0;
  1367. for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) {
  1368. op(vis0(src0, src0 + Op::SIMD_WIDTH), vis1_simd_v2, vis2_simd_v2, dst);
  1369. src0 += Op::SIMD_WIDTH * 2;
  1370. dst += Op::SIMD_WIDTH * 2;
  1371. }
  1372. #if MEGDNN_FIX_AARCH32_BUG
  1373. // FIXME: as llvm may cause cannot select error if enable vectorize
  1374. #pragma clang loop vectorize(disable)
  1375. #endif
  1376. for (; i < nr_elems; i++) {
  1377. op(*src0, src1, src2, dst);
  1378. src0++;
  1379. dst++;
  1380. }
  1381. }
  1382. };
  1383. } // namespace elemwise
  1384. } // namespace megdnn
  1385. // vim: syntax=cpp.doxygen