You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

opr_impl.cpp 32 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781
  1. /**
  2. * \file dnn/src/naive/pooling/opr_impl.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "src/naive/pooling/opr_impl.h"
  13. #include <cstring>
  14. #include "megdnn/heuristic_cache.h"
  15. #include "megdnn/dtype.h"
  16. #include "src/common/utils.h"
  17. #include "src/naive/handle.h"
  18. #include "src/naive/lowbit_utils.h"
  19. #include "midout.h"
  20. MIDOUT_DECL(megdnn_naive_pooling)
  21. namespace {
  22. using namespace megdnn;
  23. template <typename ctype_>
  24. struct MaxPooler {
  25. using ctype = ctype_;
  26. ctype answer;
  27. bool fed;
  28. MaxPooler(size_t, DType) : answer(DTypeTrait<ctype>::min()) {}
  29. void init() {
  30. answer = DTypeTrait<ctype>::min();
  31. fed = false;
  32. }
  33. void feed(ctype x) {
  34. answer = answer > x ? answer : x;
  35. fed = true;
  36. }
  37. ctype get_ans() {
  38. if (!fed) {
  39. megdnn_throw("The pooling window lies outside completely");
  40. }
  41. return answer;
  42. }
  43. };
  44. template <typename stype_, typename ctype_>
  45. struct MeanIncludePoolerBase {
  46. using stype = stype_;
  47. using ctype = ctype_;
  48. ctype sum;
  49. const ctype count;
  50. MeanIncludePoolerBase(size_t count, DType) : count(ctype(count)) {}
  51. void init() { sum = ctype(0); }
  52. void feed(stype x) { sum += x; }
  53. };
//! generic AVERAGE (include padding) pooler: mean = sum / window_size.
//! For integral T this division truncates; the specializations below handle
//! the int8 / quantized cases that need clamping or round-to-nearest.
template <typename T>
struct MeanIncludePooler : public MeanIncludePoolerBase<T, T> {
    using MeanIncludePoolerBase<T, T>::MeanIncludePoolerBase;
    using ctype = typename MeanIncludePoolerBase<T, T>::ctype;
    ctype get_ans() { return this->sum / this->count; }
};
  60. template <>
  61. struct MeanIncludePooler<int8_t>
  62. : public MeanIncludePoolerBase<int8_t, int32_t> {
  63. using MeanIncludePoolerBase::MeanIncludePoolerBase;
  64. ctype get_ans() {
  65. return std::min<int32_t>(
  66. std::max<int32_t>(std::numeric_limits<int8_t>::min(),
  67. sum / count),
  68. std::numeric_limits<int8_t>::max());
  69. }
  70. };
/*!
 * \brief AVERAGE (include padding) pooler for asymmetric quantized uint8.
 *
 * Out-of-bounds (padded) window positions are counted as the quantization
 * zero point, so the mean is still taken over the full window size.
 */
template <>
struct MeanIncludePooler<dt_quint8> {
    int32_t sum;                //!< running sum of the raw uint8 values fed
    size_t feed_count;          //!< number of in-bounds elements fed
    const int32_t count;        //!< full window size (FH * FW)
    const int32_t zero_point;   //!< zero point of the input dtype
    MeanIncludePooler(size_t count, DType dtype)
            : count(int32_t(count)),
              zero_point(dtype.param<dtype::Quantized8Asymm>().zero_point) {}
    void init() {
        sum = 0;
        feed_count = 0;
    }
    void feed(dt_quint8 x) {
        sum += x.as_uint8();
        ++feed_count;
    }
    dt_quint8 get_ans() {
        // treat the (count - feed_count) padded positions as zero_point.
        // NOTE(review): `count - feed_count` mixes int32_t with size_t so the
        // subtraction happens in unsigned arithmetic; it is non-negative here
        // because feed_count never exceeds the window size.
        int32_t summie = sum + (count - feed_count) * zero_point;
        // round to nearest, then clamp into the uint8 representable range
        int32_t rounded = std::round(static_cast<float>(summie) / count);
        return dt_quint8(std::min<int32_t>(
                std::max<int32_t>(rounded, std::numeric_limits<uint8_t>::min()),
                std::numeric_limits<uint8_t>::max()));
    }
};
/*!
 * \brief Average pooling operation within a single window.
 * Works on integers; the mean is rounded to the nearest integer
 * (ties away from zero, via std::round).
 * \tparam T input data type
 * \tparam U convert input data type to U before accumulating
 * \tparam ICType data type for intermediate result
 */
template <typename T, typename U = T, typename ICType = int32_t>
struct MeanIncludeRoundedPooler {
    ICType sum;           //!< running sum in the intermediate type
    const int32_t count;  //!< full window size (FH * FW)
    MeanIncludeRoundedPooler(size_t count, DType) : count(ICType(count)) {}
    void init() { sum = 0; }
    void feed(T x) { sum += static_cast<ICType>(static_cast<U>(x)); }
    T get_ans() { return T(std::round(static_cast<float>(sum) / count)); }
};
//! AVERAGE (include padding) pooler for QuantizedS32: rounded integer mean
template <>
struct MeanIncludePooler<dt_qint32>
        : MeanIncludeRoundedPooler<dt_qint32, int32_t> {
    using MeanIncludeRoundedPooler::MeanIncludeRoundedPooler;
};
//! AVERAGE (include padding) pooler for QuantizedS8: rounded integer mean
//! accumulated in int32 (the ICType default)
template <>
struct MeanIncludePooler<dt_qint8>
        : MeanIncludeRoundedPooler<dt_qint8, int8_t> {
    using MeanIncludeRoundedPooler::MeanIncludeRoundedPooler;
};
  122. struct NCHWIdxGetter {
  123. static size_t get_idx(size_t n, size_t c, size_t h, size_t w,
  124. size_t /* N */, size_t C, size_t H, size_t W) {
  125. return ((n * C + c) * H + h) * W + w;
  126. }
  127. };
  128. struct NHWCIdxGetter {
  129. static size_t get_idx(size_t n, size_t c, size_t h, size_t w,
  130. size_t /* N */, size_t C, size_t H, size_t W) {
  131. return ((n * H + h) * W + w) * C + c;
  132. }
  133. };
  134. struct NHWCD4IdxGetter {
  135. static size_t get_idx(size_t n, size_t c, size_t h, size_t w,
  136. size_t /* N */, size_t C, size_t H, size_t W) {
  137. return (((n * H + h) * (C >> 2) + (c >> 2)) * W + w) * 4 + (c & 0x3);
  138. }
  139. };
  140. struct NCHW4IdxGetter {
  141. static size_t get_idx(size_t n, size_t c, size_t h, size_t w, size_t,
  142. size_t C, size_t H, size_t W) {
  143. return (((n * (C >> 2) + (c >> 2)) * H + h) * W + w) * 4 + (c & 0b11);
  144. }
  145. };
  146. struct NCHW88IdxGetter {
  147. static size_t get_idx(size_t n, size_t c, size_t h, size_t w, size_t,
  148. size_t C, size_t H, size_t W) {
  149. size_t id =
  150. (((n * (C >> 3) + (c >> 3)) * H + h) * W + w) * 8 + (c & 0b111);
  151. return id;
  152. }
  153. };
  154. struct NCHW44IdxGetter {
  155. static size_t get_idx(size_t n, size_t c, size_t h, size_t w, size_t,
  156. size_t C, size_t H, size_t W) {
  157. size_t id = (((n * (C >> 2) + (c >> 2)) * H + h) * W + w) * 4 + (c % 4);
  158. return id;
  159. }
  160. };
  161. struct CHWN4IdxGetter {
  162. static size_t get_idx(size_t n, size_t c, size_t h, size_t w, size_t N,
  163. size_t, size_t H, size_t W) {
  164. return ((((c >> 2) * H + h) * W + w) * N + n) * 4 + (c & 0b11);
  165. }
  166. };
  167. struct NCHW32IdxGetter {
  168. static size_t get_idx(size_t n, size_t c, size_t h, size_t w, size_t,
  169. size_t C, size_t H, size_t W) {
  170. return (((n * (C >> 5) + (c >> 5)) * H + h) * W + w) * 32 + (c & 0x1f);
  171. }
  172. };
  173. struct NCHW64IdxGetter {
  174. static size_t get_idx(size_t n, size_t c, size_t h, size_t w, size_t,
  175. size_t C, size_t H, size_t W) {
  176. return (((n * (C >> 6) + (c >> 6)) * H + h) * W + w) * 64 + (c & 0x3f);
  177. }
  178. };
/*!
 * Pooler for AVERAGE_COUNT_EXCLUDE_PADDING mode.
 * Only in-bounds elements are accumulated and the divisor is the number of
 * elements actually fed, not the full window size.
 */
template <typename ctype>
struct MeanExcludePooler {
    ctype sum;     //!< running sum of the fed elements
    size_t count;  //!< number of in-bounds elements fed
    MeanExcludePooler(size_t, DType) {}
    void init() {
        // NOTE(review): the 0.0f literal also zero-initializes integral
        // ctype correctly via implicit conversion
        sum = 0.0f;
        count = 0u;
    }
    void feed(ctype x) {
        sum += x;
        ++count;
    }
    ctype get_ans() {
        // a window lying fully outside the input has no defined average
        if (count == 0u) {
            megdnn_throw("The pooling window lies outside completely");
        }
        return sum / static_cast<ctype>(count);
    }
};
/*!
 * \brief Average pooling over only the in-bounds elements of a window.
 * Works on integers; the mean is rounded to the nearest integer
 * (ties away from zero, via std::round).
 * \tparam T input data type
 * \tparam U convert input data type to U before accumulating
 * \tparam ICType data type for intermediate result
 */
template <typename T, typename U, typename ICType = U>
struct MeanExcludeRoundedPooler {
    ICType sum;    //!< running sum of fed elements, in the intermediate type
    size_t count;  //!< number of in-bounds elements fed
    MeanExcludeRoundedPooler(size_t, DType) {}
    void init() {
        sum = 0;
        count = 0;
    }
    void feed(T x) {
        sum += U(x);
        ++count;
    }
    T get_ans() {
        // a window lying fully outside the input has no defined average
        if (count == 0u) {
            megdnn_throw("The pooling window lies outside completely");
        }
        return T(std::round(static_cast<float>(sum) / count));
    }
};
//! AVERAGE_COUNT_EXCLUDE_PADDING pooler for Quantized8Asymm: unsigned
//! accumulation, rounded mean
template <>
struct MeanExcludePooler<dt_quint8>
        : MeanExcludeRoundedPooler<dt_quint8, uint8_t, uint32_t> {
    using MeanExcludeRoundedPooler::MeanExcludeRoundedPooler;
};
//! AVERAGE_COUNT_EXCLUDE_PADDING pooler for QuantizedS32: rounded mean
template <>
struct MeanExcludePooler<dt_qint32>
        : MeanExcludeRoundedPooler<dt_qint32, int32_t> {
    using MeanExcludeRoundedPooler::MeanExcludeRoundedPooler;
};
//! AVERAGE_COUNT_EXCLUDE_PADDING pooler for QuantizedS8: accumulates in
//! int32, rounded mean
template <>
struct MeanExcludePooler<dt_qint8>
        : MeanExcludeRoundedPooler<dt_qint8, int8_t, int32_t> {
    using MeanExcludeRoundedPooler::MeanExcludeRoundedPooler;
};
/*!
 * \brief generic single-threaded pooling forward.
 *
 * For every output element, feeds all in-bounds positions of its window
 * into a fresh \p Pooler and stores the pooled result.
 *
 * \tparam Pooler    policy that reduces one window (max / mean variants)
 * \tparam IdxGetter maps logical (n, c, h, w) to the flat offset of the
 *                   tensor's memory layout
 */
template <typename Pooler, typename IdxGetter,
          typename ctype = typename Pooler::ctype>
void pooling_forward_impl(const ctype* __restrict src, ctype* __restrict dst,
                          DType src_dtype, size_t N, size_t C, size_t IH,
                          size_t IW, size_t OH, size_t OW, size_t PH, size_t PW,
                          size_t SH, size_t SW, size_t FH, size_t FW) {
    rep(n, N) rep(c, C) rep(oh, OH) rep(ow, OW) {
        Pooler pooler(FH * FW, src_dtype);
        pooler.init();
        rep(fh, FH) rep(fw, FW) {
            // computed in size_t on purpose: positions left of / above the
            // input wrap around to huge values, so the single `< IH` /
            // `< IW` test below rejects both underflow and overflow
            size_t ih = -PH + oh * SH + fh;
            size_t iw = -PW + ow * SW + fw;
            if (ih < IH && iw < IW) {
                size_t idx = IdxGetter::get_idx(n, c, ih, iw, N, C, IH, IW);
                pooler.feed(src[idx]);
            }
        }
        size_t idx = IdxGetter::get_idx(n, c, oh, ow, N, C, OH, OW);
        dst[idx] = pooler.get_ans();
    }
}
/*!
 * \brief backward of AVERAGE pooling: each output gradient is distributed
 * uniformly over the in-bounds positions of its window.
 *
 * \param is_include true  -> divisor is the full window size FH*FW
 *                            (AVERAGE mode);
 *                   false -> divisor is the count of in-bounds positions
 *                            (AVERAGE_COUNT_EXCLUDE_PADDING mode)
 */
template <typename ctype, typename IdxGetter>
void pooling_backward_avg_impl(const ctype* __restrict /* src */,
                               const ctype* __restrict /* dst */,
                               const ctype* __restrict diff,
                               ctype* __restrict grad, size_t N, size_t C,
                               size_t IH, size_t IW, size_t OH, size_t OW,
                               size_t PH, size_t PW, size_t SH, size_t SW,
                               size_t FH, size_t FW, bool is_include = true) {
    // grad accumulates over overlapping windows, so clear all
    // N*C*IH*IW elements first (grad is treated as fully packed)
    std::memset(grad, 0, sizeof(ctype) * (N * C * IH * IW));
    rep(n, N) rep(c, C) rep(oh, OH) rep(ow, OW) {
        // first pass: count how many window positions fall inside the input
        size_t count = 0u;
        rep(fh, FH) rep(fw, FW) {
            // size_t wrap-around: one `< IH` test rejects negatives too
            size_t ih = -PH + oh * SH + fh;
            size_t iw = -PW + ow * SW + fw;
            if (ih < IH && iw < IW)
                ++count;
        }
        if (is_include)
            count = FH * FW;
        if (count == 0u) {
            megdnn_throw("The pooling window lies outside completely");
        }
        // second pass: distribute diff/count to every in-bounds position
        rep(fh, FH) rep(fw, FW) {
            size_t ih = -PH + oh * SH + fh;
            size_t iw = -PW + ow * SW + fw;
            if (ih < IH && iw < IW) {
                size_t gi = IdxGetter::get_idx(n, c, ih, iw, N, C, IH, IW);
                size_t di = IdxGetter::get_idx(n, c, oh, ow, N, C, OH, OW);
                auto& gval = grad[gi];
                auto dval = diff[di];
                gval += dval / ctype(count);
            }
        }
    }
}
//! backward of AVERAGE_COUNT_EXCLUDE_PADDING: forwards to
//! pooling_backward_avg_impl with is_include = false so the divisor is the
//! per-window count of in-bounds elements
template <typename ctype, typename IdxGetter>
void pooling_backward_avg_expd_impl(const ctype* __restrict src,
                                    const ctype* __restrict dst,
                                    const ctype* __restrict diff,
                                    ctype* __restrict grad, size_t N, size_t C,
                                    size_t IH, size_t IW, size_t OH, size_t OW,
                                    size_t PH, size_t PW, size_t SH, size_t SW,
                                    size_t FH, size_t FW) {
    pooling_backward_avg_impl<ctype, IdxGetter>(src, dst, diff, grad, N, C, IH,
                                                IW, OH, OW, PH, PW, SH, SW, FH,
                                                FW, false);
}
/*!
 * \brief backward of MAX pooling: the output gradient flows to every input
 * element whose value equals the pooled maximum.
 *
 * NOTE: on ties (several window elements equal to the max) the gradient is
 * added to each of them, i.e. it is duplicated rather than split.
 */
template <typename ctype, typename IdxGetter>
void pooling_backward_max_impl(const ctype* __restrict src,
                               const ctype* __restrict dst,
                               const ctype* __restrict diff,
                               ctype* __restrict grad, size_t N, size_t C,
                               size_t IH, size_t IW, size_t OH, size_t OW,
                               size_t PH, size_t PW, size_t SH, size_t SW,
                               size_t FH, size_t FW) {
    // grad accumulates over overlapping windows, so clear all
    // N*C*IH*IW elements first (grad is treated as fully packed)
    std::memset(grad, 0, sizeof(ctype) * (N * C * IH * IW));
    rep(n, N) rep(c, C) rep(oh, OH) rep(ow, OW) {
        // verify the window overlaps the input at all
        size_t count = 0u;
        rep(fh, FH) rep(fw, FW) {
            // size_t wrap-around: one `< IH` test rejects negatives too
            size_t ih = -PH + oh * SH + fh;
            size_t iw = -PW + ow * SW + fw;
            if (ih < IH && iw < IW)
                ++count;
        }
        if (count == 0u) {
            megdnn_throw("The pooling window lies outside completely");
        }
        rep(fh, FH) rep(fw, FW) {
            size_t ih = -PH + oh * SH + fh;
            size_t iw = -PW + ow * SW + fw;
            if (ih < IH && iw < IW) {
                size_t si = IdxGetter::get_idx(n, c, ih, iw, N, C, IH, IW);
                size_t di = IdxGetter::get_idx(n, c, oh, ow, N, C, OH, OW);
                auto sval = src[si];
                auto& gval = grad[si];
                auto dst_val = dst[di];
                auto diff_val = diff[di];
                // route the gradient to the arg-max element(s)
                if (sval == dst_val)
                    gval += diff_val;
            }
        }
    }
}
  348. } // namespace
  349. namespace megdnn {
  350. namespace naive {
  351. WorkspaceBundle PoolingForwardImpl::get_workspace_bundle(
  352. void* ptr, const TensorLayout& src, const TensorLayout& dst) const {
  353. SmallVector<size_t> sizes;
  354. TensorLayout fsrc = src;
  355. TensorLayout fdst = dst;
  356. auto get_workspace = [&sizes](TensorLayout& layout) {
  357. if (layout.dtype.enumv() == DTypeEnum::Quantized4Asymm ||
  358. layout.dtype.enumv() == DTypeEnum::QuantizedS4) {
  359. layout.dtype = dtype::Int8();
  360. layout.format = TensorLayout::Format(layout.dtype);
  361. sizes.push_back(layout.span().dist_byte());
  362. }
  363. };
  364. get_workspace(fsrc);
  365. get_workspace(fdst);
  366. return {ptr, std::move(sizes)};
  367. };
/*!
 * \brief workspace size query; consults the HeuristicCache first so
 * repeated queries with identical layouts/params reuse the cached result.
 */
size_t PoolingForwardImpl::get_workspace_in_bytes(const TensorLayout& src,
                                                  const TensorLayout& dst) {
    TensorLayoutArray layouts{src, dst};
    HeuristicCache::Key key{this->handle(), this->get_opr_type(),
                            layouts.data(), layouts.size(), &this->param(),
                            sizeof(this->param())};
    auto rst = HeuristicCache::instance().get(key);
    if (rst.policy.algo.valid()) {
        // cache hit: reuse the workspace size recorded for this key
        return rst.workspace;
    }
    return get_workspace_bundle(nullptr, src, dst).total_size_in_bytes();
}
namespace {
//! narrow the 8-bit compute result back into the caller's 4-bit dst
//! tensor; no-op for every other dtype
void post_process(const TensorND& dst, TensorND& comp_dst) {
    if (dst.layout.dtype.enumv() == DTypeEnum::QuantizedS4) {
        int8_to_int4(comp_dst, dst);
    } else if (dst.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) {
        uint8_to_uint4(comp_dst, dst);
    }
}
}  // namespace
/*!
 * \brief forward entry point: widens 4-bit inputs to 8-bit, normalizes the
 * layout into (N, C, IH, IW, OH, OW) scalars, then dispatches a
 * pooling_forward_impl instantiation chosen by dtype / mode / format.
 */
void PoolingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
                              _megdnn_workspace workspace) {
    check_exec(src.layout, dst.layout, workspace.size);
    // comp_src / comp_dst are the tensors actually pooled; for 4-bit dtypes
    // they are redirected to widened 8-bit workspace copies below
    TensorND comp_src = src;
    TensorND comp_dst = dst;
    auto wsb = get_workspace_bundle(workspace.raw_ptr, src.layout, dst.layout);
    if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS4) {
        float scale = src.layout.dtype.param<dtype::QuantizedS4>().scale;
        comp_src.layout.dtype = dtype::QuantizedS8(scale);
        comp_src.layout.format = TensorLayout::Format(comp_src.layout.dtype);
        comp_src.layout.init_contiguous_stride();
        comp_src.raw_ptr = wsb.get(0);
        comp_dst.layout.dtype = dtype::QuantizedS8(scale);
        comp_dst.layout.format = TensorLayout::Format(comp_dst.layout.dtype);
        comp_dst.layout.init_contiguous_stride();
        comp_dst.raw_ptr = wsb.get(1);
        int4_to_int8(src, comp_src);
    } else if (src.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) {
        float scale = src.layout.dtype.param<dtype::Quantized4Asymm>().scale;
        uint8_t zero_point =
                src.layout.dtype.param<dtype::Quantized4Asymm>().zero_point;
        comp_src.layout.dtype = dtype::Quantized8Asymm(scale, zero_point);
        comp_src.layout.format = TensorLayout::Format(comp_src.layout.dtype);
        comp_src.layout.init_contiguous_stride();
        comp_src.raw_ptr = wsb.get(0);
        comp_dst.layout.dtype = dtype::Quantized8Asymm(scale, zero_point);
        comp_dst.layout.format = TensorLayout::Format(comp_dst.layout.dtype);
        comp_dst.layout.init_contiguous_stride();
        comp_dst.raw_ptr = wsb.get(1);
        uint4_to_uint8(src, comp_src);
    }
    // locate the batch / channel / spatial axes of the current format
    size_t c_pos, spatial_pos, batch_pos = 0;
    if (param().format == Param::Format::NCHW ||
        param().format == Param::Format::NCHW4 ||
        param().format == Param::Format::NCHW88 ||
        param().format == Param::Format::NCHW44 ||
        param().format == Param::Format::NCHW32 ||
        param().format == Param::Format::NCHW64) {
        c_pos = 1;
        spatial_pos = 2;
    } else if (param().format == Param::Format::NHWC) {
        c_pos = 3;
        spatial_pos = 1;
    } else if (param().format == Param::Format::CHWN4) {
        c_pos = 0;
        spatial_pos = 1;
        batch_pos = 3;
    } else {
        megdnn_assert(param().format == Param::Format::NHWCD4);
        c_pos = 2;
        spatial_pos = 1;
    }
    size_t N = comp_src.layout.shape[batch_pos],
           C = comp_src.layout.shape[c_pos],
           IH = comp_src.layout.shape[spatial_pos + 0],
           IW = comp_src.layout.shape[spatial_pos + 1];
    size_t OH = comp_dst.layout.shape[spatial_pos + 0],
           OW = comp_dst.layout.shape[spatial_pos + 1];
    // blocked formats keep a channel group in the innermost axis: fold the
    // block factor back into C (and fix W for NHWCD4, whose width axis sits
    // after the channel-group axis)
    switch (param().format) {
        case Param::Format::NHWCD4:
            C *= 4;
            IW = comp_src.layout.shape[spatial_pos + 2];
            OW = comp_dst.layout.shape[spatial_pos + 2];
            break;
        case Param::Format::NCHW4:
        case Param::Format::NCHW44:
        case Param::Format::CHWN4:
            C *= 4;
            break;
        case Param::Format::NCHW88:
            C *= 8;
            break;
        case Param::Format::NCHW32:
            C *= 32;
            break;
        case Param::Format::NCHW64:
            C *= 64;
            break;
        default:;
    }
    size_t PH = param().pad_h, PW = param().pad_w;
    size_t FH = param().window_h, FW = param().window_w;
    size_t SH = param().stride_h, SW = param().stride_w;
// enqueue the kernel on the naive handle; the actual computation runs in
// pooling_forward_impl instantiated for the chosen Pooler / IdxGetter
#define DISPATCH_WITH_POOLER_AND_IDX_GETTER(Pooler, IdxGetter)                 \
    MIDOUT_BEGIN(megdnn_naive_pooling, midout_iv(#Pooler #IdxGetter##_hash)) { \
        MEGDNN_DISPATCH_CPU_KERN(                                              \
                static_cast<naive::HandleImpl*>(handle()),                     \
                pooling_forward_impl<Pooler MEGDNN_COMMA IdxGetter>(           \
                        sptr, dptr, comp_src.layout.dtype, N, C, IH, IW, OH,   \
                        OW, PH, PW, SH, SW, FH, FW));                          \
    }                                                                          \
    MIDOUT_END();
// select the IdxGetter matching the tensor format
#define DISPATCH_WITH_POOLER(Pooler)                                     \
    switch (param().format) {                                            \
        case Param::Format::NCHW:                                        \
            DISPATCH_WITH_POOLER_AND_IDX_GETTER(Pooler, NCHWIdxGetter);  \
            break;                                                       \
        case Param::Format::NHWC:                                        \
            DISPATCH_WITH_POOLER_AND_IDX_GETTER(Pooler, NHWCIdxGetter);  \
            break;                                                       \
        case Param::Format::NHWCD4:                                      \
            DISPATCH_WITH_POOLER_AND_IDX_GETTER(Pooler, NHWCD4IdxGetter); \
            break;                                                       \
        case Param::Format::NCHW4:                                       \
            DISPATCH_WITH_POOLER_AND_IDX_GETTER(Pooler, NCHW4IdxGetter); \
            break;                                                       \
        case Param::Format::NCHW88:                                      \
            DISPATCH_WITH_POOLER_AND_IDX_GETTER(Pooler, NCHW88IdxGetter); \
            break;                                                       \
        case Param::Format::NCHW44:                                      \
            DISPATCH_WITH_POOLER_AND_IDX_GETTER(Pooler, NCHW44IdxGetter); \
            break;                                                       \
        case Param::Format::NCHW32:                                      \
            DISPATCH_WITH_POOLER_AND_IDX_GETTER(Pooler, NCHW32IdxGetter); \
            break;                                                       \
        case Param::Format::NCHW64:                                      \
            DISPATCH_WITH_POOLER_AND_IDX_GETTER(Pooler, NCHW64IdxGetter); \
            break;                                                       \
        case Param::Format::CHWN4:                                       \
            DISPATCH_WITH_POOLER_AND_IDX_GETTER(Pooler, CHWN4IdxGetter); \
            break;                                                       \
        default:                                                         \
            megdnn_throw("invalid pooling format");                      \
    }
// select the Pooler matching the pooling mode, for one concrete dtype
#define cb(DType)                                                        \
    if (comp_src.layout.dtype.enumv() == DTypeTrait<DType>::enumv) {     \
        using ctype = typename DTypeTrait<DType>::ctype;                 \
        switch (param().mode) {                                          \
            case Mode::MAX: {                                            \
                auto sptr = comp_src.ptr<ctype>();                       \
                auto dptr = comp_dst.ptr<ctype>();                       \
                DISPATCH_WITH_POOLER(MaxPooler<ctype>);                  \
                break;                                                   \
            }                                                            \
            case Mode::AVERAGE: {                                        \
                auto sptr = comp_src.ptr<ctype>();                       \
                auto dptr = comp_dst.ptr<ctype>();                       \
                DISPATCH_WITH_POOLER(MeanIncludePooler<ctype>);          \
                break;                                                   \
            }                                                            \
            case Mode::AVERAGE_COUNT_EXCLUDE_PADDING: {                  \
                auto sptr = comp_src.ptr<ctype>();                       \
                auto dptr = comp_dst.ptr<ctype>();                       \
                DISPATCH_WITH_POOLER(MeanExcludePooler<ctype>);          \
                break;                                                   \
            }                                                            \
            default:                                                     \
                megdnn_assert(0, "not support mode");                    \
        }                                                                \
        post_process(dst, comp_dst);                                     \
        return;                                                          \
    }
    MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
    MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
#undef cb
#undef DISPATCH_WITH_POOLER_AND_IDX_GETTER
#undef DISPATCH_WITH_POOLER
    // reached only if no cb() matched the dtype
    megdnn_assert_internal(0);
}
  548. PoolingForward::Algorithm* PoolingForwardImpl::get_algorithm_from_desc(
  549. const AlgorithmDesc& desc) {
  550. Algorithm* ret =
  551. static_cast<HandleImpl*>(handle())->default_pooling_fwd_algo();
  552. megdnn_assert(desc == ret->info().desc);
  553. return ret;
  554. }
  555. std::vector<Algorithm*> PoolingForwardImpl::get_all_algorithms(
  556. const TensorLayout&, const TensorLayout&) {
  557. return {static_cast<HandleImpl*>(handle())->default_pooling_fwd_algo()};
  558. }
  559. std::vector<Algorithm*> PoolingForwardImpl::get_all_algorithms_safe(
  560. const TensorLayout&, const TensorLayout&) {
  561. return {static_cast<HandleImpl*>(handle())->default_pooling_fwd_algo()};
  562. }
  563. Algorithm* PoolingForwardImpl::get_algorithm_heuristic(
  564. const TensorLayout& /*src*/, const TensorLayout& /*dst*/,
  565. size_t /*workspace_limit_in_bytes*/, const AlgoAttribute& positive_attr,
  566. const AlgoAttribute& negative_attr) {
  567. auto algo = static_cast<HandleImpl*>(handle())->default_pooling_fwd_algo();
  568. algo->check_attribute(positive_attr, negative_attr);
  569. return algo;
  570. }
  571. Algorithm* PoolingBackwardImpl::get_algorithm_from_desc(
  572. const AlgorithmDesc& desc) {
  573. Algorithm* ret =
  574. static_cast<HandleImpl*>(handle())->default_pooling_bwd_algo();
  575. megdnn_assert(desc == ret->info().desc);
  576. return ret;
  577. }
  578. std::vector<Algorithm*> PoolingBackwardImpl::get_all_algorithms(
  579. const TensorLayout& /*src*/, const TensorLayout& /*dst*/,
  580. const TensorLayout& /*diff*/, const TensorLayout& /*grad*/) {
  581. return {static_cast<HandleImpl*>(handle())->default_pooling_bwd_algo()};
  582. }
  583. std::vector<Algorithm*> PoolingBackwardImpl::get_all_algorithms_safe(
  584. const TensorLayout& /*src*/, const TensorLayout& /*dst*/,
  585. const TensorLayout& /*diff*/, const TensorLayout& /*grad*/) {
  586. return {static_cast<HandleImpl*>(handle())->default_pooling_bwd_algo()};
  587. }
  588. Algorithm* PoolingBackwardImpl::get_algorithm_heuristic(
  589. const TensorLayout& /*src*/, const TensorLayout& /*dst*/,
  590. const TensorLayout& /*diff*/, const TensorLayout& /*grad*/,
  591. size_t /*workspace_limit_in_bytes*/, const AlgoAttribute& positive_attr,
  592. const AlgoAttribute& negative_attr) {
  593. auto algo = static_cast<HandleImpl*>(handle())->default_pooling_bwd_algo();
  594. algo->check_attribute(positive_attr, negative_attr);
  595. return algo;
  596. }
/*!
 * \brief workspace for backward: BFloat16 operands are computed in
 * Float32, so one fp32-sized buffer is reserved per bf16 tensor.
 */
WorkspaceBundle PoolingBackwardImpl::get_workspace_bundle(
        void* ptr, const TensorLayout& src, const TensorLayout& dst,
        const TensorLayout& diff, const TensorLayout& grad) const {
    SmallVector<size_t> sizes;
    TensorLayout fsrc = src;
    TensorLayout fdst = dst;
    TensorLayout fdiff = diff;
    TensorLayout fgrad = grad;
    auto get_workspace = [&sizes](TensorLayout& layout) {
        // NOTE(review): DNN_FLOAT16_SELECT presumably evaluates its first
        // argument when fp16/bf16 support is compiled in and yields `false`
        // otherwise -- confirm against its definition
        if (DNN_FLOAT16_SELECT(layout.dtype == dtype::BFloat16(), false)) {
            layout.dtype = dtype::Float32();
            sizes.push_back(layout.span().dist_byte());
        }
    };
    get_workspace(fsrc);
    get_workspace(fdst);
    get_workspace(fdiff);
    get_workspace(fgrad);
    return {ptr, std::move(sizes)};
}
/*!
 * \brief workspace size query; consults the HeuristicCache first so
 * repeated queries with identical layouts/params reuse the cached result.
 */
size_t PoolingBackwardImpl::get_workspace_in_bytes(
        const TensorLayout& src, const TensorLayout& dst,
        const TensorLayout& diff, const TensorLayout& grad) {
    TensorLayoutArray layouts{src, dst, diff, grad};
    HeuristicCache::Key key{this->handle(), this->get_opr_type(),
                            layouts.data(), layouts.size(), &this->param(),
                            sizeof(this->param())};
    auto rst = HeuristicCache::instance().get(key);
    if (rst.policy.algo.valid()) {
        // cache hit: reuse the workspace size recorded for this key
        return rst.workspace;
    }
    return get_workspace_bundle(nullptr, src, dst, diff, grad)
            .total_size_in_bytes();
}
  631. void PoolingBackwardImpl::exec(_megdnn_tensor_in ssrc, _megdnn_tensor_in sdst,
  632. _megdnn_tensor_in sdiff,
  633. _megdnn_tensor_out sgrad,
  634. _megdnn_workspace workspace) {
  635. check_exec(ssrc.layout, sdst.layout, sdiff.layout, sgrad.layout,
  636. workspace.size);
  637. TensorND src = ssrc;
  638. TensorND dst = sdst;
  639. TensorND diff = sdiff;
  640. TensorND grad = sgrad;
  641. #if !MEGDNN_DISABLE_FLOAT16
  642. auto wsb = get_workspace_bundle(workspace.raw_ptr, ssrc.layout, sdst.layout,
  643. sdiff.layout, sgrad.layout);
  644. auto ctypecvt = CompTypeCvter<dtype::BFloat16, dtype::Float32>(
  645. static_cast<HandleImpl*>(handle()), &wsb);
  646. if (ssrc.layout.dtype.enumv() == DTypeTrait<dtype::BFloat16>::enumv) {
  647. ctypecvt.src_to_comp_type(ssrc, src)
  648. .src_to_comp_type(sdst, dst)
  649. .src_to_comp_type(sdiff, diff)
  650. .src_to_comp_type(sgrad, grad);
  651. }
  652. #endif
  653. size_t c_pos, spatial_pos;
  654. if (param().format == Param::Format::NCHW) {
  655. c_pos = 1;
  656. spatial_pos = 2;
  657. } else {
  658. megdnn_assert(param().format == Param::Format::NHWC);
  659. c_pos = 3;
  660. spatial_pos = 1;
  661. }
  662. size_t N = src.layout.shape[0], C = src.layout.shape[c_pos],
  663. IH = src.layout.shape[spatial_pos + 0],
  664. IW = src.layout.shape[spatial_pos + 1];
  665. size_t OH = dst.layout.shape[spatial_pos + 0],
  666. OW = dst.layout.shape[spatial_pos + 1];
  667. size_t PH = param().pad_h, PW = param().pad_w;
  668. size_t FH = param().window_h, FW = param().window_w;
  669. size_t SH = param().stride_h, SW = param().stride_w;
  670. #define DISPATCH_WITH_FUNC_AND_IDX_GETTER(Func, ctype, IdxGetter) \
  671. MEGDNN_DISPATCH_CPU_KERN(static_cast<naive::HandleImpl*>(handle()), \
  672. Func<ctype MEGDNN_COMMA IdxGetter>( \
  673. sptr, dptr, diffptr, gradptr, N, C, IH, \
  674. IW, OH, OW, PH, PW, SH, SW, FH, FW)); \
  675. #define DISPATCH_WITH_FUNC(Func, ctype) \
  676. switch (param().format) { \
  677. case Param::Format::NCHW: \
  678. DISPATCH_WITH_FUNC_AND_IDX_GETTER(Func, ctype, NCHWIdxGetter); \
  679. break; \
  680. case Param::Format::NHWC: \
  681. DISPATCH_WITH_FUNC_AND_IDX_GETTER(Func, ctype, NHWCIdxGetter); \
  682. break; \
  683. default: \
  684. megdnn_throw("invalid pooling format"); \
  685. }
  686. #define cb(DType) \
  687. if (src.layout.dtype == DType()) { \
  688. using ctype = typename DTypeTrait<DType>::ctype; \
  689. switch (param().mode) { \
  690. case Mode::AVERAGE: { \
  691. auto sptr = src.ptr<ctype>(), dptr = dst.ptr<ctype>(), \
  692. diffptr = diff.ptr<ctype>(), gradptr = grad.ptr<ctype>(); \
  693. DISPATCH_WITH_FUNC(pooling_backward_avg_impl, ctype); \
  694. break; \
  695. } \
  696. case Mode::AVERAGE_COUNT_EXCLUDE_PADDING: { \
  697. auto sptr = src.ptr<ctype>(), dptr = dst.ptr<ctype>(), \
  698. diffptr = diff.ptr<ctype>(), gradptr = grad.ptr<ctype>(); \
  699. DISPATCH_WITH_FUNC(pooling_backward_avg_expd_impl, ctype); \
  700. break; \
  701. } \
  702. case Mode::MAX: { \
  703. auto sptr = src.ptr<ctype>(), dptr = dst.ptr<ctype>(), \
  704. diffptr = diff.ptr<ctype>(), gradptr = grad.ptr<ctype>(); \
  705. DISPATCH_WITH_FUNC(pooling_backward_max_impl, ctype); \
  706. break; \
  707. } \
  708. default: \
  709. megdnn_assert_internal(0); \
  710. } \
  711. }
  712. MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
  713. #undef cb
  714. #undef DISPATCH_WITH_FUNC_AND_IDX_GETTER
  715. #undef DISPATCH_WITH_FUNC
  716. #if !MEGDNN_DISABLE_FLOAT16
  717. if (sgrad.layout.dtype.enumv() == DTypeTrait<dtype::BFloat16>::enumv) {
  718. ctypecvt.comp_to_dst_type(grad, sgrad);
  719. }
  720. #endif
  721. }
  722. } // namespace naive
  723. } // namespace megdnn
  724. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台