You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

strategy_6x3_8x8.cpp 19 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422
  1. /**
  2. * \file dnn/src/x86/conv_bias/f32/strategy_6x3_8x8.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "src/common/unroll_macro.h"
  13. #include "src/common/utils.h"
  14. #include "src/common/winograd/winograd_helper.h"
  15. #include "src/fallback/conv_bias/winograd/winograd.h"
  16. #include "src/x86/conv_bias/f32/strategy.h"
  17. #include "src/x86/elemwise_helper/op_unary.h"
  18. #include "src/x86/simd_helper.h"
  19. #include <x86intrin.h>
  20. #ifdef WIN32CMAKE
  21. #include <avx2intrin.h>
  22. #include <avxintrin.h>
  23. #include <fmaintrin.h>
  24. #include <smmintrin.h>
  25. #endif
  26. #include "midout.h"
  27. MIDOUT_DECL(megdnn_x86_winograd_nchw88_fp32_F63_8x8)
  28. using namespace megdnn;
  29. using namespace x86;
  30. namespace {
//! winograd tile side: output block size (6) + kernel size (3) - 1 = 8
constexpr size_t alpha = 6 + 3 - 1;
//! Input transform for winograd F(6, 3) on the NCHW88 layout.
//! prepare() gathers one alpha x alpha spatial patch of an 8-channel block
//! into the contiguous scratch buffer patchT; transform() then applies
//! B^T * d * B to it and scatters the result into input_transform_buf.
struct InputTransform6X3_NCHW88 {
    //! Copy the alpha x alpha patch of channel block `ic` to patchT.
    //! \tparam inner true when the whole patch lies inside the feature map,
    //!         so no zero padding is needed and a straight block copy works.
    //! \param input   feature map, NCHW88 layout (IC/8, IH, IW, 8)
    //! \param patch   unused scratch, kept for interface compatibility
    //! \param patchT  destination buffer of 8 * alpha * alpha floats
    //! \param ih_start top row of the patch; may be negative under padding
    //! \param iw_start left column of the patch; may be negative under padding
    template <bool inner>
    MEGDNN_ATTRIBUTE_TARGET("avx2")
    static void prepare(const float* input, float* patch, float* patchT,
                        int ih_start, int iw_start, size_t IH, size_t IW,
                        size_t ic, size_t IC) {
        MEGDNN_MARK_USED_VAR(patch);
        size_t IW8 = IW * 8;              //! For nchw88 mode
        size_t iw8_start = iw_start * 8;  //! For nchw88 mode
        size_t icb = ic / 8;
        //! Zero-fill whenever some of patchT may stay unwritten: border
        //! tiles, and presumably the last channel block -- TODO confirm why
        //! `ic + 8 < IC` (rather than inner alone) gates the memset.
        if (!(inner && ic + 8 < IC)) {
            memset(patchT, 0, sizeof(float) * 8 * alpha * alpha);
        }
        if (inner) {
            //! Copy to continue memory patchT,
            //! TODO:can be optimized
            const float* input_ptr =
                    input + icb * IH * IW8 + ih_start * IW8 + iw8_start;
            //! Each row moves alpha (8) vectors of 8 channel lanes.
            for (size_t ih = 0; ih < alpha; ih++) {
#define cb(i) auto v##i = _mm256_loadu_ps(input_ptr + 8 * i);
                UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
#define cb(i) _mm256_storeu_ps(patchT + ih * 8 * alpha + i * 8, v##i);
                UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
                input_ptr += IW8;
            }
        } else {
            //! Clamp the patch window to the valid feature-map region; the
            //! out-of-range border keeps the zeros written by memset above.
            int ih0_act = std::max<int>(ih_start, 0),
                ih1_act = std::min<int>(ih_start + alpha, IH),
                iw0_act = std::max<int>(iw_start, 0),
                iw1_act = std::min<int>(iw_start + alpha, IW);
            const float* input_ptr = input + icb * IH * IW8;
            // partial copy
            for (int ih = ih0_act; ih < ih1_act; ++ih) {
                for (int iw = iw0_act; iw < iw1_act; ++iw) {
                    size_t iho = ih - ih_start, iwo = iw - iw_start;
                    auto src = _mm256_loadu_ps(input_ptr + ih * IW8 + iw * 8);
                    _mm256_storeu_ps(patchT + iho * 8 * alpha + iwo * 8, src);
                }
            }
        }
    }

    //! Apply the winograd input transform B^T * d * B to the gathered patch
    //! and store the result into input_transform_buf with layout
    //! (alpha * alpha, ICB, nr_units_in_tile, 8) -- see the save macro below.
    MEGDNN_ATTRIBUTE_TARGET("avx2")
    static void transform(const float* patchT, float* input_transform_buf,
                          size_t unit_idx, size_t nr_units_in_tile, size_t ic,
                          size_t IC) {
        // BT * d * B
        //! Load the 8x8 patch into d00..d77, one 8-lane vector per element.
#define cb(m, n)                 \
    Vector<float, 8> d##m##n =   \
            Vector<float, 8>::load(patchT + m * 8 * 8 + n * 8);
        UNROLL_CALL_NOWRAPPER_D2(8, 8, cb);
#undef cb
        //! B
        //! 1     0     0     0     0     0     0     0
        //! 0     1     -1    0.5   -0.5  2     -2    -1
        //! -5.25 1     1     0.25  0.25  4     4     0
        //! 0     -4.25 4.25  -2.5  2.5   -2.5  2.5   5.25
        //! 5.25  -4.25 -4.25 -1.25 -1.25 -5    -5    0
        //! 0     1     -1    2     -2    0.5   -0.5  -5.25
        //! -1    1     1     1     1     1     1     0
        //! 0     0     0     0     0     0     0     1
        //! First pass: t = B^T * d (column-wise transform).
#define cb(m)                                                                 \
    auto t0##m = d0##m + (d4##m - d2##m) * 5.25f - d6##m;                     \
    auto t1##m = d1##m + d2##m + d5##m + d6##m - (d3##m + d4##m) * 4.25f;     \
    auto t2##m = d2##m + d6##m - (d1##m + d5##m) + (d3##m - d4##m) * 4.25f;   \
    auto t3##m = d1##m * 0.5f + d2##m * 0.25f - d3##m * 2.5f - d4##m * 1.25f + \
                 d5##m * 2.f + d6##m;                                         \
    auto t4##m = d1##m * (-0.5f) + d2##m * 0.25f + d3##m * 2.5f -             \
                 d4##m * 1.25f - d5##m * 2.f + d6##m;                         \
    auto t5##m = d1##m * 2.f + d2##m * 4.f - d3##m * 2.5f - d4##m * 5.f +     \
                 d5##m * 0.5f + d6##m;                                        \
    auto t6##m = d1##m * (-2.f) + d2##m * 4.f + d3##m * 2.5f - d4##m * 5.f -  \
                 d5##m * 0.5f + d6##m;                                        \
    auto t7##m = (d7##m - d1##m) + (d3##m - d5##m) * 5.25f;
        UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
        //! Second pass: d = t * B (row-wise transform, written back in place).
#define cb(m)                                                                 \
    d##m##0 = t##m##0 + (t##m##4 - t##m##2) * 5.25f - t##m##6;                \
    d##m##1 = t##m##1 + t##m##2 + t##m##5 + t##m##6 -                         \
              (t##m##3 + t##m##4) * 4.25f;                                    \
    d##m##2 = t##m##2 + t##m##6 - (t##m##1 + t##m##5) +                       \
              (t##m##3 - t##m##4) * 4.25f;                                    \
    d##m##3 = t##m##1 * 0.5f + t##m##2 * 0.25f - t##m##3 * 2.5f -             \
              t##m##4 * 1.25f + t##m##5 * 2.f + t##m##6;                      \
    d##m##4 = t##m##1 * (-0.5f) + t##m##2 * 0.25f + t##m##3 * 2.5f -          \
              t##m##4 * 1.25f - t##m##5 * 2.f + t##m##6;                      \
    d##m##5 = t##m##1 * 2.f + t##m##2 * 4.f - t##m##3 * 2.5f - t##m##4 * 5.f + \
              t##m##5 * 0.5f + t##m##6;                                       \
    d##m##6 = t##m##1 * (-2.f) + t##m##2 * 4.f + t##m##3 * 2.5f -             \
              t##m##4 * 5.f - t##m##5 * 0.5f + t##m##6;                       \
    d##m##7 = (t##m##7 - t##m##1) + (t##m##3 - t##m##5) * 5.25f;
        UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
        size_t ICB = IC / 8;
        size_t icb = ic / 8;
        //! Scatter: element (m, n) goes to plane (m * alpha + n), then by
        //! channel block, then by tile unit, 8 lanes contiguous.
#define cb(m, n)                                                    \
    d##m##n.save(input_transform_buf +                              \
                 (m * alpha + n) * ICB * nr_units_in_tile * 8 +     \
                 icb * nr_units_in_tile * 8 + unit_idx * 8);
        UNROLL_CALL_NOWRAPPER_D2(8, 8, cb)
#undef cb
    }
};
//! Filter transform for winograd F(6, 3) on NCHW88: computes G * g * G^T for
//! every 3x3 kernel in output-channel range [oc_start, oc_end) and stores it
//! in the winograd filter layout (alpha * alpha, OCB, ICB, 8, 8).
struct FilterTransform6X3_MCHW88 {
    MEGDNN_ATTRIBUTE_TARGET("avx2")
    static void transform(const float* filter, float* filter_transform_buf,
                          float* transform_mid_buf, size_t OC, size_t IC,
                          size_t oc_start, size_t oc_end) {
        // Gg * GT
        // G
        // 1.0000000  0.0000000  0.0000000
        // -0.2222222 -0.2222222 -0.2222222
        // -0.2222222 0.2222222  -0.2222222
        // 0.0111111  0.0222222  0.0444444
        // 0.0111111  -0.0222222 0.0444444
        // 0.7111111  0.3555556  0.1777778
        // 0.7111111  -0.3555556 0.1777778
        // 0.0000000  0.0000000  1.0000000
        MEGDNN_MARK_USED_VAR(transform_mid_buf);
        megdnn_assert(
                (oc_end - oc_start) % 8 == 0 && oc_start % 8 == 0 &&
                oc_end % 8 == 0 && IC % 8 == 0 && OC % 8 == 0,
                "Winograd filter transform input param is not times of 8!");
        size_t OCB = OC / 8;
        size_t ICB = IC / 8;
        //! Source filter layout per the indexing below:
        //! (OCB, ICB, 3, 3, 8(ic_inner), 8(oc_inner)); one 8-lane vector
        //! covers the 8 oc lanes of a fixed ic lane.
        for (size_t ocb = oc_start / 8; ocb < oc_end / 8; ocb++) {
            for (size_t icb = 0; icb < ICB; icb++) {
                for (size_t ic_inner = 0; ic_inner < 8; ic_inner++) {
                    const float* fptr = filter +
                                        (ocb * ICB + icb) * 3 * 3 * 8 * 8 +
                                        ic_inner * 8;
                    //! Load the 3x3 kernel as vectors g00..g22.
#define cb(m, n)                    \
    Vector<float, 8> g##m##n =      \
            Vector<float, 8>::load(fptr + (m * 3 + n) * 8 * 8);
                    UNROLL_CALL_NOWRAPPER_D2(3, 3, cb)
#undef cb
                    //! One pass of multiplication by G; run twice below:
                    //! first wd = G * g (3 columns), then ret = wd * G^T
                    //! (8 rows).  tmp0/tmp1 are shared scratch registers.
#define FILTER_TRANSFORM(n, wd, g)                           \
    auto wd##n##0 = g##0##n;                                 \
    tmp0 = (g##0##n + g##2##n) * -0.2222222f;                \
    tmp1 = g##1##n * -0.2222222f;                            \
    auto wd##n##1 = tmp0 + tmp1;                             \
    auto wd##n##2 = tmp0 - tmp1;                             \
    tmp0 = g##0##n * 0.0111111f + g##2##n * 0.0444444f;      \
    tmp1 = g##1##n * 0.0222222f;                             \
    auto wd##n##3 = tmp0 + tmp1;                             \
    auto wd##n##4 = tmp0 - tmp1;                             \
    tmp0 = g##0##n * 0.7111111f + g##2##n * 0.1777778f;      \
    tmp1 = g##1##n * 0.3555556f;                             \
    auto wd##n##5 = tmp0 + tmp1;                             \
    auto wd##n##6 = tmp0 - tmp1;                             \
    auto wd##n##7 = g##2##n;
                    Vector<float, 8> tmp0, tmp1;
                    UNROLL_CALL_RAW(3, FILTER_TRANSFORM, wd, g);
                    UNROLL_CALL_RAW(8, FILTER_TRANSFORM, ret, wd);
#undef FILTER_TRANSFORM
                    //! Store ret(m, n) into the winograd filter layout.
#define cb_save(m, n)                                                       \
    ret##m##n.save(filter_transform_buf +                                   \
                   (m * alpha + n) * OCB * ICB * 8 * 8 + ocb * ICB * 8 * 8 + \
                   icb * 8 * 8 + ic_inner * 8);
                    UNROLL_CALL_NOWRAPPER_D2(8, 8, cb_save)
#undef cb_save
                }
            }
        }
    }
};
#define CONCAT(a, idx) a##idx
//! Output transform for winograd F(6, 3) on NCHW88: applies A^T * m * A to
//! one transformed tile, adds the bias according to `bmode`, runs the
//! activation `Op`, and writes the valid 6x6 output block (clipped to OH/OW).
template <BiasMode bmode, typename Op>
struct OutputTransform6X3_NCHW88 {
    MEGDNN_ATTRIBUTE_TARGET("avx2")
    static void transform(const float* output_transform_buf, const float* bias,
                          float* output, float* transform_mid_buf,
                          size_t oh_start, size_t ow_start, size_t OH,
                          size_t OW, size_t oc_start, size_t oc_end,
                          size_t oc_index, size_t unit_idx,
                          size_t nr_units_in_tile, const DType& src_dtype,
                          const DType& dst_dtype) {
        MEGDNN_MARK_USED_VAR(transform_mid_buf);
        Op op(src_dtype, dst_dtype);
        //! AT * m * A
        size_t OCB = (oc_end - oc_start) / 8;
        size_t oc = oc_start + oc_index;
        size_t ocb = oc_index / 8;
        //! Gather the 8x8 transformed tile; layout mirrors the save order of
        //! the input/filter transforms: (alpha * alpha, OCB, units, 8).
#define cb(m, n)                                              \
    auto v##m##n = Vector<float, 8>::load(                    \
            output_transform_buf +                            \
            (m * alpha + n) * OCB * nr_units_in_tile * 8 +    \
            ocb * nr_units_in_tile * 8 + unit_idx * 8);
        UNROLL_CALL_NOWRAPPER_D2(8, 8, cb);
#undef cb
        /**
         * A
         *
         * 1 0    0    0      0      0
         * 1 1    1    1      1      1
         * 1 -1   1    -1     1      -1
         * 1 2    4    8      16     32
         * 1 -2   4    -8     16     -32
         * 1 0.5  0.25 0.125  0.0625 0.03125
         * 1 -0.5 0.25 -0.125 0.0625 -0.03125
         * 0 0.0  0    0      0      1
         */
        //! Shared scratch for the paired-sum/difference factorization below.
        Vector<float, 8> v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, v5subv6;
        //! First pass: t (6x8) = A^T * v, column-wise.
#define cb(m)                                                     \
    v1addv2 = v1##m + v2##m;                                      \
    v1subv2 = v1##m - v2##m;                                      \
    v3addv4 = v3##m + v4##m;                                      \
    v3subv4 = v3##m - v4##m;                                      \
    v5addv6 = v5##m + v6##m;                                      \
    v5subv6 = v5##m - v6##m;                                      \
    auto t0##m = v0##m + v1addv2 + v3addv4 + v5addv6;             \
    auto t1##m = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f;        \
    auto t2##m = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f;       \
    auto t3##m = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f;      \
    auto t4##m = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f;    \
    auto t5##m = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + v7##m;
        UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
        //! Second pass: v (6x6) = t * A, row-wise, written back in place.
#define cb(m)                                                     \
    v1addv2 = t##m##1 + t##m##2;                                  \
    v1subv2 = t##m##1 - t##m##2;                                  \
    v3addv4 = t##m##3 + t##m##4;                                  \
    v3subv4 = t##m##3 - t##m##4;                                  \
    v5addv6 = t##m##5 + t##m##6;                                  \
    v5subv6 = t##m##5 - t##m##6;                                  \
    v##m##0 = t##m##0 + v1addv2 + v3addv4 + v5addv6;              \
    v##m##1 = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f;           \
    v##m##2 = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f;          \
    v##m##3 = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f;         \
    v##m##4 = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f;       \
    v##m##5 = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + t##m##7;
        UNROLL_CALL_NOWRAPPER(6, cb);
#undef cb
        //! Channel-broadcast bias: one 8-lane vector added to all 36 outputs.
        Vector<float, 8> vbias;
        if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
            vbias = Vector<float, 8>::load(bias + oc);
#define cb(m, n) v##m##n += vbias;
            UNROLL_CALL_RAW_D2(6, 6, cb);
#undef cb
        }
        //! For non-elementwise bias the activation is applied here; for
        //! BiasMode::BIAS it is applied after the per-pixel bias add below.
        if (bmode != BiasMode::BIAS) {
#define cb(m, n) v##m##n = op(CONCAT(v##m, n).value);
            UNROLL_CALL_RAW_D2(6, 6, cb);
#undef cb
        }
        //! Write the 6x6 block, skipping positions past the OH/OW border.
#define out_save(oho, owo)                                                  \
    do {                                                                    \
        size_t oh = oh_start + oho;                                         \
        size_t ow = ow_start + owo;                                         \
        if (oh < OH && ow < OW) {                                           \
            if (bmode == BiasMode::BIAS) {                                  \
                v##oho##owo += Vector<float, 8>::load(                      \
                        bias + oc / 8 * OH * OW * 8 + oh * OW * 8 + ow * 8); \
                v##oho##owo = op(v##oho##owo.value);                        \
            }                                                               \
            v##oho##owo.save(output + oc / 8 * OH * OW * 8 + oh * OW * 8 +  \
                             ow * 8);                                       \
        }                                                                   \
    } while (0);
        UNROLL_CALL_RAW_D2(6, 6, out_save);
    }
};
#undef CONCAT
  296. } // namespace
  297. namespace megdnn {
  298. namespace x86 {
  299. namespace winograd {
  300. MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_nchw88_6x3_8x8_f)
  301. void winograd_nchw88_6x3_8x8_f::filter(const float* filter,
  302. float* filter_transform_buf,
  303. float* transform_mid_buf, size_t OC,
  304. size_t IC, size_t oc_start,
  305. size_t oc_end) {
  306. FilterTransform6X3_MCHW88::transform(filter, filter_transform_buf,
  307. transform_mid_buf, OC, IC, oc_start,
  308. oc_end);
  309. }
  310. void winograd_nchw88_6x3_8x8_f::input(const float* input,
  311. float* input_transform_buf,
  312. float* transform_mid_buf, size_t IH,
  313. size_t IW, size_t IC, size_t PH,
  314. size_t PW, size_t unit_start_idx,
  315. size_t nr_units_in_tile) {
  316. megdnn_assert(IC % 8 == 0);
  317. // OW = IW + 2 * PW - KERNEL_SIZE + 1
  318. auto units_w =
  319. div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE);
  320. float* patch = transform_mid_buf;
  321. float* patchT = transform_mid_buf + 8 * alpha * alpha;
  322. for (size_t ic = 0; ic < IC; ic += 8) {
  323. rep(unit_idx, nr_units_in_tile) {
  324. size_t index = unit_start_idx + unit_idx;
  325. size_t nh = index / units_w;
  326. size_t nw = index % units_w;
  327. int ih_start = nh * OUTPUT_BLOCK_SIZE - PH;
  328. int iw_start = nw * OUTPUT_BLOCK_SIZE - PW;
  329. if (ih_start >= 0 && ih_start + alpha <= static_cast<size_t>(IH) &&
  330. iw_start >= 0 && iw_start + alpha <= static_cast<size_t>(IW)) {
  331. InputTransform6X3_NCHW88::prepare<true>(input, patch, patchT,
  332. ih_start, iw_start, IH,
  333. IW, ic, IC);
  334. InputTransform6X3_NCHW88::transform(patchT, input_transform_buf,
  335. unit_idx, nr_units_in_tile,
  336. ic, IC);
  337. } else {
  338. InputTransform6X3_NCHW88::prepare<false>(input, patch, patchT,
  339. ih_start, iw_start, IH,
  340. IW, ic, IC);
  341. InputTransform6X3_NCHW88::transform(patchT, input_transform_buf,
  342. unit_idx, nr_units_in_tile,
  343. ic, IC);
  344. }
  345. }
  346. }
  347. }
//! Output stage: for each 8-channel block and each tile unit, apply
//! A^T * m * A, add bias, run the non-linearity and write into `output`.
//! The (bmode, nonline_mode) pair is dispatched at runtime to the matching
//! OutputTransform6X3_NCHW88 template instantiation via the macro below.
void winograd_nchw88_6x3_8x8_f::output(const float* output_transform_buf,
                                       const float* bias, float* output,
                                       float* transform_mid_buf, BiasMode bmode,
                                       NonlineMode nonline_mode, size_t OH,
                                       size_t OW, size_t oc_start,
                                       size_t oc_end, size_t unit_start_idx,
                                       size_t nr_units_in_tile) {
#define cb(_bmode, _nonline_op, ...)                                       \
    OutputTransform6X3_NCHW88<_bmode MEGDNN_COMMA _nonline_op>::transform( \
            __VA_ARGS__);
    auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE);
    size_t OC = oc_end - oc_start;
    megdnn_assert(OC % 8 == 0 && oc_start % 8 == 0 && oc_end % 8 == 0,
                  "Winograd output transform input param is not times of 8!");
    for (size_t oc = oc_start; oc + 8 <= oc_end; oc += 8) {
        size_t oc_index = oc - oc_start;
        rep(unit_idx, nr_units_in_tile) {
            size_t index = unit_start_idx + unit_idx;
            auto nh = index / units_w;
            auto nw = index % units_w;
            size_t oh_start = nh * OUTPUT_BLOCK_SIZE;
            size_t ow_start = nw * OUTPUT_BLOCK_SIZE;
            //! NOTE(review): src_dtype / dst_dtype are presumably members
            //! declared by MEGDNN_REG_WINOGRAD_STRATEGY -- not visible in
            //! this file; confirm against the strategy header.
            DISPATCH_CONV_WINOGRAD_BIAS(
                    megdnn_x86_winograd_nchw88_fp32_F63_8x8, cb, SIMDType::AVX2,
                    float, float, bmode, nonline_mode, output_transform_buf,
                    bias, output, transform_mid_buf, oh_start, ow_start, OH, OW,
                    oc_start, oc_end, oc_index, unit_idx, nr_units_in_tile,
                    src_dtype, dst_dtype);
        }
    }
#undef cb
}
  380. } // namespace winograd
  381. } // namespace x86
  382. } // namespace megdnn
  383. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台

Contributors (1)