
local_def.inl 21 kB

/**
 * \file dnn/src/common/local/local_def.inl
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

// simd_macro/*_helper.h should be included before including this file.
//
// The following functions are defined in this file:
//
// void local_xcorr_MEGDNN_SIMD_NAME(const LocalKParam &kparam);
// void local_conv_MEGDNN_SIMD_NAME(const LocalKParam &kparam);
//
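// Illustrative sketch (not part of this file): a SIMD helper in the spirit of
// simd_macro/*_helper.h is expected to provide these macros before this file
// is included.  The AVX definitions below are an assumption for illustration
// only, inferred from how the macros are used in this file:
//
//     #include <immintrin.h>
//
//     #define MEGDNN_SIMD_NAME             avx
//     #define MEGDNN_SIMD_WIDTH            8
//     #define MEGDNN_SIMD_TYPE             __m256
//     #define MEGDNN_SIMD_SET1(v)          _mm256_set1_ps(v)
//     #define MEGDNN_SIMD_LOADU(p)         _mm256_loadu_ps(p)
//     #define MEGDNN_SIMD_STOREU(p, v)     _mm256_storeu_ps(p, v)
//     #define MEGDNN_SIMD_FMADD(a, b, c)   _mm256_fmadd_ps(a, b, c)
//     #define MEGDNN_SIMD_ATTRIBUTE_TARGET __attribute__((target("avx,fma")))
//
//     #include "src/common/local/local_def.inl"
//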
#include "src/common/local/local_decl.inl"
#include "src/common/macro_helper.h"
#include "src/common/utils.h"

namespace {

using namespace megdnn;

template <int N, int OC>
void local_xcorr_tpl(const LocalKParam& kparam) MEGDNN_SIMD_ATTRIBUTE_TARGET;
template <int N, int OC>
void local_xcorr_tpl(const LocalKParam& kparam) {
    const float* src = static_cast<const float*>(kparam.src.get_ptr());
    const float* filter = static_cast<const float*>(kparam.filter.get_ptr());
    float* dst = static_cast<float*>(kparam.dst.get_ptr());
    float* workspace = static_cast<float*>(kparam.workspace);
    const int IC = kparam.ic, IH = kparam.ih, IW = kparam.iw, OH = kparam.oh,
              OW = kparam.ow, FH = kparam.fh, FW = kparam.fw;
    const uint32_t PH = kparam.ph, PW = kparam.pw, SH = kparam.sh, SW = kparam.sw;
    const ptrdiff_t INP_BS = kparam.inp_bs, OUT_BS = kparam.out_bs;
    float* dst2 = workspace;
    const int width = MEGDNN_SIMD_WIDTH;
    // dst2 is (H, W, N, C)
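    // Accumulation happens in this (OH*OW, N, OC) workspace layout so that
    // consecutive output channels are contiguous for the SIMD loops below;
    // transpose_knc2nsck() at the end rearranges the result into the final
    // destination layout using the batch stride OUT_BS.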
    memset(dst2, 0, sizeof(float) * OH * OW * N * OC);
    float* dst2_hwnc = dst2;
    rep(oh, OH) rep(ow, OW) {
        const float* src_bak = src;
        rep(ic, IC) {
            rep(fh, FH) for (int fw = 0; fw < FW; ++fw, filter += OC) {
                int ih = -PH + oh * SH + fh;
                int iw = -PW + ow * SW + fw;
                if (ih < 0 || ih >= IH || iw < 0 || iw >= IW)
                    continue;
                float* dst2_bak = dst2;
                rep(n, N) {
                    float s = src[n * INP_BS + ih * IW + iw];
                    const float* filter_bak = filter;
                    MEGDNN_SIMD_TYPE vs = MEGDNN_SIMD_SET1(s);
                    int oc = 0;
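                    // Accumulate s * filter into dst2 over output channels,
                    // processing 4*width, then 2*width, then width SIMD lanes
                    // per step, with a scalar loop for the remaining channels.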
                    for (; oc + 4 * width <= OC; oc += 4 * width, filter += 4 * width) {
                        MEGDNN_SIMD_TYPE vf0 = MEGDNN_SIMD_LOADU(filter + 0 * width);
                        MEGDNN_SIMD_TYPE vf1 = MEGDNN_SIMD_LOADU(filter + 1 * width);
                        MEGDNN_SIMD_TYPE vf2 = MEGDNN_SIMD_LOADU(filter + 2 * width);
                        MEGDNN_SIMD_TYPE vf3 = MEGDNN_SIMD_LOADU(filter + 3 * width);
                        MEGDNN_SIMD_TYPE vd0 = MEGDNN_SIMD_LOADU(dst2 + oc + 0 * width);
                        MEGDNN_SIMD_TYPE vd1 = MEGDNN_SIMD_LOADU(dst2 + oc + 1 * width);
                        MEGDNN_SIMD_TYPE vd2 = MEGDNN_SIMD_LOADU(dst2 + oc + 2 * width);
                        MEGDNN_SIMD_TYPE vd3 = MEGDNN_SIMD_LOADU(dst2 + oc + 3 * width);
                        vd0 = MEGDNN_SIMD_FMADD(vf0, vs, vd0);
                        vd1 = MEGDNN_SIMD_FMADD(vf1, vs, vd1);
                        vd2 = MEGDNN_SIMD_FMADD(vf2, vs, vd2);
                        vd3 = MEGDNN_SIMD_FMADD(vf3, vs, vd3);
                        MEGDNN_SIMD_STOREU(dst2 + oc + 0 * width, vd0);
                        MEGDNN_SIMD_STOREU(dst2 + oc + 1 * width, vd1);
                        MEGDNN_SIMD_STOREU(dst2 + oc + 2 * width, vd2);
                        MEGDNN_SIMD_STOREU(dst2 + oc + 3 * width, vd3);
                    }
                    if (oc + 2 * width <= OC) {
                        MEGDNN_SIMD_TYPE vf0 = MEGDNN_SIMD_LOADU(filter + 0 * width);
                        MEGDNN_SIMD_TYPE vf1 = MEGDNN_SIMD_LOADU(filter + 1 * width);
                        MEGDNN_SIMD_TYPE vd0 = MEGDNN_SIMD_LOADU(dst2 + oc + 0 * width);
                        MEGDNN_SIMD_TYPE vd1 = MEGDNN_SIMD_LOADU(dst2 + oc + 1 * width);
                        vd0 = MEGDNN_SIMD_FMADD(vf0, vs, vd0);
                        vd1 = MEGDNN_SIMD_FMADD(vf1, vs, vd1);
                        MEGDNN_SIMD_STOREU(dst2 + oc + 0 * width, vd0);
                        MEGDNN_SIMD_STOREU(dst2 + oc + 1 * width, vd1);
                        oc += 2 * width;
                        filter += 2 * width;
                    }
                    if (oc + 1 * width <= OC) {
                        MEGDNN_SIMD_TYPE vf0 = MEGDNN_SIMD_LOADU(filter + 0 * width);
                        MEGDNN_SIMD_TYPE vd0 = MEGDNN_SIMD_LOADU(dst2 + oc + 0 * width);
                        vd0 = MEGDNN_SIMD_FMADD(vf0, vs, vd0);
                        MEGDNN_SIMD_STOREU(dst2 + oc + 0 * width, vd0);
                        oc += 1 * width;
                        filter += 1 * width;
                    }
                    for (; oc < OC; ++oc, ++filter) {
                        dst2[oc] += s * (*filter);
                    }
                    filter = filter_bak;
                    dst2 += OC;
                }
                dst2 = dst2_bak;
            }
            src += IH * IW;
        }
        src = src_bak;
        dst2 += N * OC;
    }
    transpose_knc2nsck(dst2_hwnc, dst, OH * OW, N, OC, OUT_BS);
}
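
// Same kernel as local_xcorr_tpl above, but with N and OC taken from kparam
// at runtime (via the unpack macro) instead of being template parameters.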
void local_xcorr_generic(const LocalKParam& kparam) MEGDNN_SIMD_ATTRIBUTE_TARGET;
void local_xcorr_generic(const LocalKParam& kparam) {
    UNPACK_LOCAL_FLOAT_NONCONTIG_BATCH_KERN_PARAM(kparam, float);
    float* dst2 = workspace;
    const int width = MEGDNN_SIMD_WIDTH;
    // dst2 is (H, W, N, C)
    memset(dst2, 0, sizeof(float) * OH * OW * N * OC);
    float* dst2_hwnc = dst2;
    rep(oh, OH) rep(ow, OW) {
        const float* src_bak = src;
        rep(ic, IC) {
            rep(fh, FH) for (int fw = 0; fw < FW; ++fw, filter += OC) {
                int ih = -PH + oh * SH + fh;
                int iw = -PW + ow * SW + fw;
                if (ih < 0 || ih >= IH || iw < 0 || iw >= IW)
                    continue;
                float* dst2_bak = dst2;
                rep(n, N) {
                    float s = src[n * INP_BS + ih * IW + iw];
                    const float* filter_bak = filter;
                    MEGDNN_SIMD_TYPE vs = MEGDNN_SIMD_SET1(s);
                    int oc = 0;
                    for (; oc + 4 * width <= OC; oc += 4 * width, filter += 4 * width) {
                        MEGDNN_SIMD_TYPE vf0 = MEGDNN_SIMD_LOADU(filter + 0 * width);
                        MEGDNN_SIMD_TYPE vf1 = MEGDNN_SIMD_LOADU(filter + 1 * width);
                        MEGDNN_SIMD_TYPE vf2 = MEGDNN_SIMD_LOADU(filter + 2 * width);
                        MEGDNN_SIMD_TYPE vf3 = MEGDNN_SIMD_LOADU(filter + 3 * width);
                        MEGDNN_SIMD_TYPE vd0 = MEGDNN_SIMD_LOADU(dst2 + oc + 0 * width);
                        MEGDNN_SIMD_TYPE vd1 = MEGDNN_SIMD_LOADU(dst2 + oc + 1 * width);
                        MEGDNN_SIMD_TYPE vd2 = MEGDNN_SIMD_LOADU(dst2 + oc + 2 * width);
                        MEGDNN_SIMD_TYPE vd3 = MEGDNN_SIMD_LOADU(dst2 + oc + 3 * width);
                        vd0 = MEGDNN_SIMD_FMADD(vf0, vs, vd0);
                        vd1 = MEGDNN_SIMD_FMADD(vf1, vs, vd1);
                        vd2 = MEGDNN_SIMD_FMADD(vf2, vs, vd2);
                        vd3 = MEGDNN_SIMD_FMADD(vf3, vs, vd3);
                        MEGDNN_SIMD_STOREU(dst2 + oc + 0 * width, vd0);
                        MEGDNN_SIMD_STOREU(dst2 + oc + 1 * width, vd1);
                        MEGDNN_SIMD_STOREU(dst2 + oc + 2 * width, vd2);
                        MEGDNN_SIMD_STOREU(dst2 + oc + 3 * width, vd3);
                    }
                    if (oc + 2 * width <= OC) {
                        MEGDNN_SIMD_TYPE vf0 = MEGDNN_SIMD_LOADU(filter + 0 * width);
                        MEGDNN_SIMD_TYPE vf1 = MEGDNN_SIMD_LOADU(filter + 1 * width);
                        MEGDNN_SIMD_TYPE vd0 = MEGDNN_SIMD_LOADU(dst2 + oc + 0 * width);
                        MEGDNN_SIMD_TYPE vd1 = MEGDNN_SIMD_LOADU(dst2 + oc + 1 * width);
                        vd0 = MEGDNN_SIMD_FMADD(vf0, vs, vd0);
                        vd1 = MEGDNN_SIMD_FMADD(vf1, vs, vd1);
                        MEGDNN_SIMD_STOREU(dst2 + oc + 0 * width, vd0);
                        MEGDNN_SIMD_STOREU(dst2 + oc + 1 * width, vd1);
                        oc += 2 * width;
                        filter += 2 * width;
                    }
                    if (oc + 1 * width <= OC) {
                        MEGDNN_SIMD_TYPE vf0 = MEGDNN_SIMD_LOADU(filter + 0 * width);
                        MEGDNN_SIMD_TYPE vd0 = MEGDNN_SIMD_LOADU(dst2 + oc + 0 * width);
                        vd0 = MEGDNN_SIMD_FMADD(vf0, vs, vd0);
                        MEGDNN_SIMD_STOREU(dst2 + oc + 0 * width, vd0);
                        oc += 1 * width;
                        filter += 1 * width;
                    }
                    for (; oc < OC; ++oc, ++filter) {
                        dst2[oc] += s * (*filter);
                    }
                    filter = filter_bak;
                    dst2 += OC;
                }
                dst2 = dst2_bak;
            }
            src += IH * IW;
        }
        src = src_bak;
        dst2 += N * OC;
    }
    transpose_knc2nsck(dst2_hwnc, dst, OH * OW, N, OC, OUT_BS);
}

template <int N, int OC>
void local_conv_tpl(const LocalKParam& kparam) MEGDNN_SIMD_ATTRIBUTE_TARGET;
template <int N, int OC>
void local_conv_tpl(const LocalKParam& kparam) {
    const float* src = static_cast<const float*>(kparam.src.get_ptr());
    const float* filter = static_cast<const float*>(kparam.filter.get_ptr());
    float* dst = static_cast<float*>(kparam.dst.get_ptr());
    float* workspace = static_cast<float*>(kparam.workspace);
    const int IC = kparam.ic, IH = kparam.ih, IW = kparam.iw, OH = kparam.oh,
              OW = kparam.ow, FH = kparam.fh, FW = kparam.fw;
    const uint32_t PH = kparam.ph, PW = kparam.pw, SH = kparam.sh, SW = kparam.sw;
    const ptrdiff_t INP_BS = kparam.inp_bs, OUT_BS = kparam.out_bs;
    float* dst2 = workspace;
    const int width = MEGDNN_SIMD_WIDTH;
    // dst2 is (H, W, N, C)
    memset(dst2, 0, sizeof(float) * OH * OW * N * OC);
    float* dst2_hwnc = dst2;
    rep(oh, OH) rep(ow, OW) {
        const float* src_bak = src;
        rep(ic, IC) {
            rep(fh, FH) for (int fw = 0; fw < FW; ++fw, filter += OC) {
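                // Unlike the cross-correlation kernel above, convolution
                // flips the filter: the input offset is indexed from the end
                // of the kernel window.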
                int ih = -PH + oh * SH + (FH - fh - 1);
                int iw = -PW + ow * SW + (FW - fw - 1);
                if (ih < 0 || ih >= IH || iw < 0 || iw >= IW)
                    continue;
                float* dst2_bak = dst2;
                rep(n, N) {
                    float s = src[n * INP_BS + ih * IW + iw];
                    const float* filter_bak = filter;
                    MEGDNN_SIMD_TYPE vs = MEGDNN_SIMD_SET1(s);
                    int oc = 0;
                    for (; oc + 4 * width <= OC; oc += 4 * width, filter += 4 * width) {
                        MEGDNN_SIMD_TYPE vf0 = MEGDNN_SIMD_LOADU(filter + 0 * width);
                        MEGDNN_SIMD_TYPE vf1 = MEGDNN_SIMD_LOADU(filter + 1 * width);
                        MEGDNN_SIMD_TYPE vf2 = MEGDNN_SIMD_LOADU(filter + 2 * width);
                        MEGDNN_SIMD_TYPE vf3 = MEGDNN_SIMD_LOADU(filter + 3 * width);
                        MEGDNN_SIMD_TYPE vd0 = MEGDNN_SIMD_LOADU(dst2 + oc + 0 * width);
                        MEGDNN_SIMD_TYPE vd1 = MEGDNN_SIMD_LOADU(dst2 + oc + 1 * width);
                        MEGDNN_SIMD_TYPE vd2 = MEGDNN_SIMD_LOADU(dst2 + oc + 2 * width);
                        MEGDNN_SIMD_TYPE vd3 = MEGDNN_SIMD_LOADU(dst2 + oc + 3 * width);
                        vd0 = MEGDNN_SIMD_FMADD(vf0, vs, vd0);
                        vd1 = MEGDNN_SIMD_FMADD(vf1, vs, vd1);
                        vd2 = MEGDNN_SIMD_FMADD(vf2, vs, vd2);
                        vd3 = MEGDNN_SIMD_FMADD(vf3, vs, vd3);
                        MEGDNN_SIMD_STOREU(dst2 + oc + 0 * width, vd0);
                        MEGDNN_SIMD_STOREU(dst2 + oc + 1 * width, vd1);
                        MEGDNN_SIMD_STOREU(dst2 + oc + 2 * width, vd2);
                        MEGDNN_SIMD_STOREU(dst2 + oc + 3 * width, vd3);
                    }
                    if (oc + 2 * width <= OC) {
                        MEGDNN_SIMD_TYPE vf0 = MEGDNN_SIMD_LOADU(filter + 0 * width);
                        MEGDNN_SIMD_TYPE vf1 = MEGDNN_SIMD_LOADU(filter + 1 * width);
                        MEGDNN_SIMD_TYPE vd0 = MEGDNN_SIMD_LOADU(dst2 + oc + 0 * width);
                        MEGDNN_SIMD_TYPE vd1 = MEGDNN_SIMD_LOADU(dst2 + oc + 1 * width);
                        vd0 = MEGDNN_SIMD_FMADD(vf0, vs, vd0);
                        vd1 = MEGDNN_SIMD_FMADD(vf1, vs, vd1);
                        MEGDNN_SIMD_STOREU(dst2 + oc + 0 * width, vd0);
                        MEGDNN_SIMD_STOREU(dst2 + oc + 1 * width, vd1);
                        oc += 2 * width;
                        filter += 2 * width;
                    }
                    if (oc + 1 * width <= OC) {
                        MEGDNN_SIMD_TYPE vf0 = MEGDNN_SIMD_LOADU(filter + 0 * width);
                        MEGDNN_SIMD_TYPE vd0 = MEGDNN_SIMD_LOADU(dst2 + oc + 0 * width);
                        vd0 = MEGDNN_SIMD_FMADD(vf0, vs, vd0);
                        MEGDNN_SIMD_STOREU(dst2 + oc + 0 * width, vd0);
                        oc += 1 * width;
                        filter += 1 * width;
                    }
                    for (; oc < OC; ++oc, ++filter) {
                        dst2[oc] += s * (*filter);
                    }
                    filter = filter_bak;
                    dst2 += OC;
                }
                dst2 = dst2_bak;
            }
            src += IH * IW;
        }
        src = src_bak;
        dst2 += N * OC;
    }
    transpose_knc2nsck(dst2_hwnc, dst, OH * OW, N, OC, OUT_BS);
}

void local_conv_generic(const LocalKParam& kparam) MEGDNN_SIMD_ATTRIBUTE_TARGET;
void local_conv_generic(const LocalKParam& kparam) {
    UNPACK_LOCAL_FLOAT_NONCONTIG_BATCH_KERN_PARAM(kparam, float);
    float* dst2 = workspace;
    const int width = MEGDNN_SIMD_WIDTH;
    // dst2 is (H, W, N, C)
    memset(dst2, 0, sizeof(float) * OH * OW * N * OC);
    float* dst2_hwnc = dst2;
    rep(oh, OH) rep(ow, OW) {
        const float* src_bak = src;
        rep(ic, IC) {
            rep(fh, FH) for (int fw = 0; fw < FW; ++fw, filter += OC) {
                int ih = -PH + oh * SH + (FH - fh - 1);
                int iw = -PW + ow * SW + (FW - fw - 1);
                if (ih < 0 || ih >= IH || iw < 0 || iw >= IW)
                    continue;
                float* dst2_bak = dst2;
                rep(n, N) {
                    float s = src[n * INP_BS + ih * IW + iw];
                    const float* filter_bak = filter;
                    MEGDNN_SIMD_TYPE vs = MEGDNN_SIMD_SET1(s);
                    int oc = 0;
                    for (; oc + 4 * width <= OC; oc += 4 * width, filter += 4 * width) {
                        MEGDNN_SIMD_TYPE vf0 = MEGDNN_SIMD_LOADU(filter + 0 * width);
                        MEGDNN_SIMD_TYPE vf1 = MEGDNN_SIMD_LOADU(filter + 1 * width);
                        MEGDNN_SIMD_TYPE vf2 = MEGDNN_SIMD_LOADU(filter + 2 * width);
                        MEGDNN_SIMD_TYPE vf3 = MEGDNN_SIMD_LOADU(filter + 3 * width);
                        MEGDNN_SIMD_TYPE vd0 = MEGDNN_SIMD_LOADU(dst2 + oc + 0 * width);
                        MEGDNN_SIMD_TYPE vd1 = MEGDNN_SIMD_LOADU(dst2 + oc + 1 * width);
                        MEGDNN_SIMD_TYPE vd2 = MEGDNN_SIMD_LOADU(dst2 + oc + 2 * width);
                        MEGDNN_SIMD_TYPE vd3 = MEGDNN_SIMD_LOADU(dst2 + oc + 3 * width);
                        vd0 = MEGDNN_SIMD_FMADD(vf0, vs, vd0);
                        vd1 = MEGDNN_SIMD_FMADD(vf1, vs, vd1);
                        vd2 = MEGDNN_SIMD_FMADD(vf2, vs, vd2);
                        vd3 = MEGDNN_SIMD_FMADD(vf3, vs, vd3);
                        MEGDNN_SIMD_STOREU(dst2 + oc + 0 * width, vd0);
                        MEGDNN_SIMD_STOREU(dst2 + oc + 1 * width, vd1);
                        MEGDNN_SIMD_STOREU(dst2 + oc + 2 * width, vd2);
                        MEGDNN_SIMD_STOREU(dst2 + oc + 3 * width, vd3);
                    }
                    if (oc + 2 * width <= OC) {
                        MEGDNN_SIMD_TYPE vf0 = MEGDNN_SIMD_LOADU(filter + 0 * width);
                        MEGDNN_SIMD_TYPE vf1 = MEGDNN_SIMD_LOADU(filter + 1 * width);
                        MEGDNN_SIMD_TYPE vd0 = MEGDNN_SIMD_LOADU(dst2 + oc + 0 * width);
                        MEGDNN_SIMD_TYPE vd1 = MEGDNN_SIMD_LOADU(dst2 + oc + 1 * width);
                        vd0 = MEGDNN_SIMD_FMADD(vf0, vs, vd0);
                        vd1 = MEGDNN_SIMD_FMADD(vf1, vs, vd1);
                        MEGDNN_SIMD_STOREU(dst2 + oc + 0 * width, vd0);
                        MEGDNN_SIMD_STOREU(dst2 + oc + 1 * width, vd1);
                        oc += 2 * width;
                        filter += 2 * width;
                    }
                    if (oc + 1 * width <= OC) {
                        MEGDNN_SIMD_TYPE vf0 = MEGDNN_SIMD_LOADU(filter + 0 * width);
                        MEGDNN_SIMD_TYPE vd0 = MEGDNN_SIMD_LOADU(dst2 + oc + 0 * width);
                        vd0 = MEGDNN_SIMD_FMADD(vf0, vs, vd0);
                        MEGDNN_SIMD_STOREU(dst2 + oc + 0 * width, vd0);
                        oc += 1 * width;
                        filter += 1 * width;
                    }
                    for (; oc < OC; ++oc, ++filter) {
                        dst2[oc] += s * (*filter);
                    }
                    filter = filter_bak;
                    dst2 += OC;
                }
                dst2 = dst2_bak;
            }
            src += IH * IW;
        }
        src = src_bak;
        dst2 += N * OC;
    }
    transpose_knc2nsck(dst2_hwnc, dst, OH * OW, N, OC, OUT_BS);
}

}  // anonymous namespace

namespace megdnn {

#define FUNC_NAME CONCAT_STR(local_xcorr_, MEGDNN_SIMD_NAME)
void FUNC_NAME(const LocalKParam& kparam) {
    auto N = kparam.n, OC = kparam.oc;
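    // Use a fully specialized template kernel for the small batch sizes and
    // common channel counts below; any other (N, OC) combination falls
    // through to local_xcorr_generic() at the end of this function.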
#define DISPATCH_WITH_N_OC(N, OC)       \
    do {                                \
        local_xcorr_tpl<N, OC>(kparam); \
        return;                         \
    } while (0)
#define DISPATCH_WITH_N(N)             \
    switch (OC) {                      \
        case 16:                       \
            DISPATCH_WITH_N_OC(N, 16); \
            break;                     \
        case 32:                       \
            DISPATCH_WITH_N_OC(N, 32); \
            break;                     \
        case 48:                       \
            DISPATCH_WITH_N_OC(N, 48); \
            break;                     \
        case 64:                       \
            DISPATCH_WITH_N_OC(N, 64); \
            break;                     \
    }
#define DISPATCH()              \
    switch (N) {                \
        case 1:                 \
            DISPATCH_WITH_N(1); \
            break;              \
        case 2:                 \
            DISPATCH_WITH_N(2); \
            break;              \
    }
    DISPATCH();
#undef DISPATCH
#undef DISPATCH_WITH_N
#undef DISPATCH_WITH_N_OC
    local_xcorr_generic(kparam);
}
#undef FUNC_NAME

#define FUNC_NAME CONCAT_STR(local_conv_, MEGDNN_SIMD_NAME)
void FUNC_NAME(const LocalKParam& kparam) {
    auto N = kparam.n, OC = kparam.oc;
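    // Same dispatch pattern as the cross-correlation entry point above,
    // falling back to local_conv_generic() when (N, OC) is not specialized.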
#define DISPATCH_WITH_N_OC(N, OC)      \
    do {                               \
        local_conv_tpl<N, OC>(kparam); \
        return;                        \
    } while (0)
#define DISPATCH_WITH_N(N)             \
    switch (OC) {                      \
        case 16:                       \
            DISPATCH_WITH_N_OC(N, 16); \
            break;                     \
        case 32:                       \
            DISPATCH_WITH_N_OC(N, 32); \
            break;                     \
        case 48:                       \
            DISPATCH_WITH_N_OC(N, 48); \
            break;                     \
        case 64:                       \
            DISPATCH_WITH_N_OC(N, 64); \
            break;                     \
    }
#define DISPATCH()              \
    switch (N) {                \
        case 1:                 \
            DISPATCH_WITH_N(1); \
            break;              \
        case 2:                 \
            DISPATCH_WITH_N(2); \
            break;              \
    }
    DISPATCH();
#undef DISPATCH
#undef DISPATCH_WITH_N
#undef DISPATCH_WITH_N_OC
    local_conv_generic(kparam);
}
#undef FUNC_NAME

}  // namespace megdnn

#include "src/common/macro_helper_epilogue.h"
