You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

winograd_helper.cpp 16 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383
  1. /**
  2. * \file dnn/src/common/winograd/winograd_helper.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "src/common/winograd/winograd_helper.h"
  12. #include "src/common/winograd/winograd_generator.h"
  13. #include "src/naive/matrix_mul/matrix_mul_helper.h"
  14. using namespace megdnn;
  15. namespace {
  16. template <typename ctype, typename otype, typename enable = void>
  17. struct Getter {
  18. Getter(DType){};
  19. otype operator()(ctype item) { return item; }
  20. };
  21. template <typename ctype, typename otype>
  22. struct Getter<ctype, otype,
  23. typename std::enable_if_t<std::is_same<ctype, uint8_t>::value>> {
  24. otype zp;
  25. Getter(DType dtype) {
  26. zp = dtype.param<dtype::Quantized8Asymm>().zero_point;
  27. }
  28. otype operator()(ctype item) { return static_cast<otype>(item) - zp; }
  29. };
  30. template <typename ctype, typename otype, typename enable = void>
  31. struct OutputGetter {
  32. OutputGetter(DType){};
  33. otype operator()(float item) { return static_cast<otype>(item); }
  34. };
  35. template <typename ctype, typename otype>
  36. struct OutputGetter<
  37. ctype, otype,
  38. typename std::enable_if_t<std::is_same<otype, int8_t>::value>> {
  39. DType dtype;
  40. OutputGetter(DType dtype) : dtype{dtype} {}
  41. otype operator()(float item) {
  42. return dtype.param<dtype::QuantizedS8>().quantize(item).as_int8();
  43. }
  44. };
  45. template <typename ctype, typename otype>
  46. struct OutputGetter<
  47. ctype, otype,
  48. typename std::enable_if_t<std::is_same<otype, uint8_t>::value>> {
  49. DType dtype;
  50. OutputGetter(DType dtype) : dtype{dtype} {}
  51. otype operator()(float item) {
  52. return dtype.param<dtype::Quantized8Asymm>().quantize(item).as_uint8();
  53. }
  54. };
  55. } // namespace
  56. namespace megdnn {
  57. namespace winograd {
  58. constexpr size_t layout_pack_size(param::ConvBias::Format layout) {
  59. switch (layout) {
  60. case param::ConvBias::Format::NHWCD4:
  61. return 4;
  62. case param::ConvBias::Format::NCHW4:
  63. case param::ConvBias::Format::NCHW44:
  64. return 4;
  65. case param::ConvBias::Format::NCHW32:
  66. return 32;
  67. case param::ConvBias::Format::NCHW88:
  68. case param::ConvBias::Format::NCHW8:
  69. return 8;
  70. default:
  71. return 1;
  72. }
  73. }
/**
 * \brief computes flat buffer offsets for the filter transform: \c get reads
 * the raw r x r kernel, \c put writes the alpha x alpha transformed kernel.
 *
 * \tparam layout conv-bias tensor layout of the source filter
 * \tparam format matmul packing format of the transformed destination
 */
template <param::ConvBias::Format layout, param::MatrixMul::Format format>
struct FilterVisitor {
    size_t IC, OC;
    FilterVisitor(size_t OC, size_t IC) : IC(IC), OC(OC) {}
    //! offset of source filter element (oc, ic, h, w) in an r x r kernel,
    //! honoring the layout's channel-pack blocking
    size_t get(size_t r, size_t oc, size_t ic, size_t h, size_t w) {
        constexpr size_t input_pack_size = layout_pack_size(layout);
        // split both channel indices into (block, within-block) parts
        size_t ocb_layout = oc / input_pack_size;
        size_t oc_layout = oc % input_pack_size;
        size_t icb_layout = ic / input_pack_size;
        size_t ic_layout = ic % input_pack_size;
        return (ocb_layout * (IC / input_pack_size) + icb_layout) * r * r *
                       input_pack_size * input_pack_size +
               ic_layout * input_pack_size + oc_layout +
               (h * r + w) * input_pack_size * input_pack_size;
    }
    //! offset of transformed element (oc, ic) at interpolation point (h, w)
    //! in the alpha x alpha transformed-filter buffer
    size_t put(size_t alpha, size_t oc, size_t ic, size_t h, size_t w) {
        if (format == param::MatrixMul::Format::DEFAULT) {
            // unpacked: [alpha*alpha][IC][OC]
            return (h * alpha + w) * OC * IC + ic * OC + oc;
        }
        // packed (MK4/MK8): channels blocked by the matmul pack size,
        // layout [alpha*alpha][OCB][ICB][ic_pack][oc_pack]
        size_t matmul_pack_size = MatrixMulForward::pack_size(format);
        size_t ocb = oc / matmul_pack_size;
        size_t oc_pack = oc % matmul_pack_size;
        size_t icb = ic / matmul_pack_size;
        size_t ic_pack = ic % matmul_pack_size;
        size_t OCB = OC / matmul_pack_size;
        size_t ICB = IC / matmul_pack_size;
        return (h * alpha + w) * OCB * ICB * matmul_pack_size *
                       matmul_pack_size +
               ocb * ICB * matmul_pack_size * matmul_pack_size +
               icb * matmul_pack_size * matmul_pack_size +
               ic_pack * matmul_pack_size + oc_pack;
    }
};
  107. template <param::ConvBias::Format layout, param::MatrixMul::Format format>
  108. struct InputVisitor {
  109. size_t IC;
  110. InputVisitor(size_t IC) : IC(IC) {}
  111. size_t get(size_t /*alpha*/, size_t ic, size_t IH, size_t IW, size_t ih,
  112. size_t iw) {
  113. constexpr size_t input_pack_size = layout_pack_size(layout);
  114. size_t icb_layout = ic / input_pack_size;
  115. size_t ic_layout = ic % input_pack_size;
  116. return (icb_layout * IH * IW + ih * IW + iw) * input_pack_size +
  117. ic_layout;
  118. }
  119. size_t put(size_t alpha, size_t ic, size_t nr_units_in_tile,
  120. size_t unit_idx, size_t h, size_t w) {
  121. if (format == param::MatrixMul::Format::DEFAULT) {
  122. return (h * alpha + w) * nr_units_in_tile * IC + unit_idx * IC + ic;
  123. }
  124. size_t matmul_pack_size = MatrixMulForward::pack_size(format);
  125. size_t icb = ic / matmul_pack_size;
  126. size_t ic_pack = ic % matmul_pack_size;
  127. size_t ICB = IC / matmul_pack_size;
  128. return (h * alpha + w) * ICB * nr_units_in_tile * matmul_pack_size +
  129. icb * nr_units_in_tile * matmul_pack_size +
  130. unit_idx * matmul_pack_size + ic_pack;
  131. }
  132. };
/**
 * \brief computes flat buffer offsets for the output transform: \c get reads
 * the matmul result tiles, \c put addresses the spatial output (also used to
 * index a BIAS-mode bias tensor, which shares the output's layout).
 *
 * \tparam layout conv-bias tensor layout of the final output
 * \tparam format matmul packing format of the transformed source
 */
template <param::ConvBias::Format layout, param::MatrixMul::Format format>
struct OutputVisitor {
    size_t OC;
    OutputVisitor(size_t OC) : OC(OC) {}
    //! offset of the matmul-output element for tile \p unit_idx at
    //! interpolation point (h, w); \p oc_index is the channel offset within
    //! the current oc range, \p oc the absolute channel
    size_t get(size_t alpha, size_t oc_index, size_t oc,
               size_t nr_units_in_tile, size_t unit_idx, size_t h, size_t w) {
        if (format == param::MatrixMul::Format::DEFAULT) {
            // unpacked: [alpha*alpha][nr_units_in_tile][OC]
            return (h * alpha + w) * nr_units_in_tile * OC + unit_idx * OC +
                   oc_index;
        }
        size_t matmul_pack_size = MatrixMulForward::pack_size(format);
        size_t ocb = oc_index / matmul_pack_size;
        // NOTE(review): the block index derives from oc_index while the
        // intra-block index derives from the absolute oc; these agree only
        // when oc_start is a multiple of the pack size — presumably
        // guaranteed by callers. TODO confirm.
        size_t oc_pack = oc % matmul_pack_size;
        size_t OCB = OC / matmul_pack_size;
        return (h * alpha + w) * OCB * nr_units_in_tile * matmul_pack_size +
               ocb * nr_units_in_tile * matmul_pack_size +
               unit_idx * matmul_pack_size + oc_pack;
    }
    //! offset of spatial output element (oc, oh, ow), honoring the layout's
    //! channel-pack blocking
    size_t put(size_t oc, size_t OH, size_t OW, size_t oh, size_t ow) {
        constexpr size_t input_pack_size = layout_pack_size(layout);
        size_t oc_layout = oc % input_pack_size;
        return (oc / input_pack_size * OH * OW + oh * OW + ow) *
                       input_pack_size +
               oc_layout;
    }
};
/**
 * \brief winograd filter transform: for every (oc, ic) kernel slice compute
 * G * g * G^T, where g is the r x r kernel and G is the alpha x r transform
 * matrix (alpha = m + r - 1).
 *
 * Source elements are read through FilterVisitor::get and converted via
 * Getter (which removes the zero point for Quantized8Asymm inputs); results
 * are scattered into \p filter_transform_buf through FilterVisitor::put.
 *
 * \param transform_mid_buf scratch, at least 2 * alpha * alpha elements
 * \param oc_start,oc_end half-open range of output channels to process
 * \param rescale scale folded into G (kept in range for integer compute)
 */
template <typename ctype, typename dst_type, typename input_filter_compute_type,
          typename output_compute_type, param::ConvBias::Format layout,
          param::MatrixMul::Format format>
void StrategyHelper<
        ctype, dst_type, input_filter_compute_type, output_compute_type, layout,
        format>::filter(const ctype* filter,
                        input_filter_compute_type* filter_transform_buf,
                        input_filter_compute_type* transform_mid_buf, size_t OC,
                        size_t IC, size_t oc_start, size_t oc_end, size_t m,
                        size_t r, const std::vector<float>& interp_points,
                        DType dtype, float rescale) {
    size_t alpha = m + r - 1;
    WinogradCoeff<input_filter_compute_type> winograd_coeff(m, r,
                                                            interp_points);
    // two ping-pong halves of the scratch buffer
    input_filter_compute_type* mid_buf1 = transform_mid_buf;
    input_filter_compute_type* mid_buf2 = transform_mid_buf + alpha * alpha;
    Getter<ctype, input_filter_compute_type> getter(dtype);
    FilterVisitor<layout, format> filter_visitor(OC, IC);
    for (size_t oc = oc_start; oc < oc_end; oc++) {
        rep(ic, IC) {
            // gather the r x r kernel into contiguous row-major mid_buf1
            rep(i, r) rep(j, r) {
                mid_buf1[i * r + j] =
                        getter(filter[filter_visitor.get(r, oc, ic, i, j)]);
            }
            /* tmp = Matmul(G, src) */
            megdnn::naive::run_matrix_mul_tpl<input_filter_compute_type,
                                              input_filter_compute_type, false,
                                              false>(
                    winograd_coeff.G(rescale).data(), mid_buf1, mid_buf2, alpha,
                    r, r, r, r, r, dtype, dtype);
            /* dst = Matmul(tmp, G^T) */
            megdnn::naive::run_matrix_mul_tpl<input_filter_compute_type,
                                              input_filter_compute_type, false,
                                              true>(
                    mid_buf2, winograd_coeff.G(rescale).data(), mid_buf1, alpha,
                    alpha, r, r, r, alpha, dtype, dtype);
            // scatter the alpha x alpha result into the packed destination
            rep(i, alpha) rep(j, alpha) {
                filter_transform_buf[filter_visitor.put(alpha, oc, ic, i, j)] =
                        mid_buf1[i * alpha + j];
            }
        }
    }
}
  202. template <typename ctype, typename dst_type, typename input_filter_compute_type,
  203. typename output_compute_type, param::ConvBias::Format layout,
  204. param::MatrixMul::Format format>
  205. void StrategyHelper<
  206. ctype, dst_type, input_filter_compute_type, output_compute_type, layout,
  207. format>::input(const ctype* input,
  208. input_filter_compute_type* input_transform_buf,
  209. input_filter_compute_type* transform_mid_buf,
  210. int ih_start, int iw_start, size_t IH, size_t IW,
  211. size_t IC, size_t ic, size_t unit_idx, size_t nr_units_in_tile,
  212. size_t m, size_t r,
  213. const std::vector<float>& interp_points, DType dtype,
  214. float rescale) {
  215. size_t alpha = m + r - 1;
  216. WinogradCoeff<input_filter_compute_type> winograd_coeff(m, r,
  217. interp_points);
  218. input_filter_compute_type* mid_buf1 = transform_mid_buf;
  219. input_filter_compute_type* mid_buf2 = transform_mid_buf + alpha * alpha;
  220. Getter<ctype, input_filter_compute_type> getter(dtype);
  221. InputVisitor<layout, format> intput_visitor(IC);
  222. memset(mid_buf1, 0, alpha * alpha * sizeof(input_filter_compute_type));
  223. rep(i, alpha) rep(j, alpha) {
  224. int ih = ih_start + i;
  225. int iw = iw_start + j;
  226. if (ih >= 0 && ih < (int)IH && iw >= 0 && iw < (int)IW) {
  227. mid_buf1[i * alpha + j] = getter(
  228. input[intput_visitor.get(alpha, ic, IH, IW, ih, iw)]);
  229. }
  230. }
  231. megdnn::naive::run_matrix_mul_tpl<input_filter_compute_type,
  232. input_filter_compute_type, true,
  233. false>(
  234. winograd_coeff.B(rescale).data(), mid_buf1, mid_buf2, alpha,
  235. alpha, alpha, alpha, alpha, alpha, dtype, dtype);
  236. megdnn::naive::run_matrix_mul_tpl<input_filter_compute_type,
  237. input_filter_compute_type, false,
  238. false>(
  239. mid_buf2, winograd_coeff.B(rescale).data(), mid_buf1, alpha,
  240. alpha, alpha, alpha, alpha, alpha, dtype, dtype);
  241. rep(i, alpha) rep(j, alpha) {
  242. input_transform_buf[intput_visitor.put(alpha, ic, nr_units_in_tile,
  243. unit_idx, i, j)] =
  244. mid_buf1[i * alpha + j];
  245. }
  246. }
  247. template <typename ctype, typename dst_type, typename input_filter_compute_type,
  248. typename output_compute_type, param::ConvBias::Format layout,
  249. param::MatrixMul::Format format>
  250. void StrategyHelper<
  251. ctype, dst_type, input_filter_compute_type, output_compute_type, layout,
  252. format>::output(const output_compute_type* output_transform_buf,
  253. const output_compute_type* bias, dst_type* output,
  254. output_compute_type* transform_mid_buf, BiasMode bmode,
  255. NonlineMode nonline_mode, size_t oh_start,
  256. size_t ow_start, size_t OH, size_t OW, size_t OC, size_t oc_start,
  257. size_t oc_index, size_t unit_idx, size_t nr_units_in_tile,
  258. size_t m, size_t r,
  259. const std::vector<float>& interp_points, DType dtype,
  260. float input_filter_scale, float input_filter_rescale,
  261. float rescale) {
  262. size_t alpha = m + r - 1;
  263. winograd::WinogradCoeff<output_compute_type> winograd_coeff(m, r,
  264. interp_points);
  265. output_compute_type* mid_buf1 = transform_mid_buf;
  266. output_compute_type* mid_buf2 = transform_mid_buf + alpha * alpha;
  267. OutputGetter<output_compute_type, dst_type> getter(dtype);
  268. OutputVisitor<layout, format> output_visitor(OC);
  269. size_t oc = oc_start + oc_index;
  270. /* gather */
  271. rep(i, alpha) rep(j, alpha) {
  272. mid_buf1[i * alpha + j] = output_transform_buf[output_visitor.get(
  273. alpha, oc_index, oc, nr_units_in_tile, unit_idx, i,
  274. j)];
  275. }
  276. /* A[alpha*m] M[alpha*alpha] */
  277. megdnn::naive::run_matrix_mul_tpl<output_compute_type,
  278. output_compute_type, true, false>(
  279. winograd_coeff.A(rescale).data(), mid_buf1, mid_buf2, m, alpha,
  280. alpha, m, alpha, alpha, dtype, dtype);
  281. megdnn::naive::run_matrix_mul_tpl<output_compute_type,
  282. output_compute_type, false, false>(
  283. mid_buf2, winograd_coeff.A(rescale).data(), mid_buf1, m, m,
  284. alpha, alpha, m, m, dtype, dtype);
  285. rep(i, m) rep(j, m) {
  286. auto oh = oh_start + i;
  287. auto ow = ow_start + j;
  288. if (oh < OH && ow < OW) {
  289. float val = mid_buf1[i * m + j];
  290. if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
  291. val += bias[oc] * input_filter_rescale *
  292. input_filter_rescale;
  293. } else if (bmode == BiasMode::BIAS) {
  294. val += bias[output_visitor.put(oc, OH, OW, oh, ow)] *
  295. input_filter_rescale * input_filter_rescale;
  296. }
  297. val = val * input_filter_scale /
  298. (input_filter_rescale * input_filter_rescale * rescale *
  299. rescale);
  300. if (nonline_mode == NonlineMode::RELU) {
  301. val = val > 0 ? val : 0;
  302. } else if (nonline_mode == NonlineMode::SIGMOID) {
  303. val = 1.f / (expf(-val) + 1.f);
  304. } else if (nonline_mode == NonlineMode::H_SWISH) {
  305. val = val * std::min(std::max(val + 3, 0.f), 6.f) / 6.f;
  306. } else {
  307. megdnn_assert(nonline_mode == NonlineMode::IDENTITY);
  308. }
  309. output[output_visitor.put(oc, OH, OW, oh, ow)] = getter(val);
  310. }
  311. }
  312. };
// Explicit instantiations for the DEFAULT matmul format (layout and format
// template arguments take their defaults): float, fp16 and the quantized
// int8/uint8 compute paths.
#define INST(_ctype, _dst_type, _input_filter_compute_type,   \
             _output_compute_type)                            \
    template class StrategyHelper<_ctype, _dst_type,          \
                                  _input_filter_compute_type, \
                                  _output_compute_type>;
INST(float, float, float, float)
MEGDNN_INC_FLOAT16(INST(dt_float16, dt_float16, dt_float16, dt_float16))
INST(int8_t, int8_t, int16_t, int)
INST(uint8_t, uint8_t, int16_t, int)
#undef INST
// Explicit instantiations for the MK4-packed matmul format.
#define INST(_ctype, _dst_type, _input_filter_compute_type, \
             _output_compute_type, layout)                  \
    template class StrategyHelper<                          \
            _ctype, _dst_type, _input_filter_compute_type,  \
            _output_compute_type, layout, param::MatrixMul::Format::MK4>;
INST(float, float, float, float, param::ConvBias::Format::NCHW)
INST(float, float, float, float, param::ConvBias::Format::NCHW44)
#undef INST
// Explicit instantiations for the MK8-packed matmul format.
#define INST(_ctype, _dst_type, _input_filter_compute_type, \
             _output_compute_type, layout)                  \
    template class StrategyHelper<                          \
            _ctype, _dst_type, _input_filter_compute_type,  \
            _output_compute_type, layout, param::MatrixMul::Format::MK8>;
INST(int8_t, int8_t, int16_t, int, param::ConvBias::Format::NCHW)
INST(float, float, float, float, param::ConvBias::Format::NCHW88)
MEGDNN_INC_FLOAT16(INST(dt_float16, dt_float16, dt_float16, dt_float16,
                        param::ConvBias::Format::NCHW))
#undef INST
  341. } // namespace winograd
  342. } // namespace megdnn
  343. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台