You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

opr_impl.cpp 20 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485
  1. /**
  2. * \file dnn/src/naive/resize/opr_impl.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "src/naive/resize/opr_impl.h"
  12. #include "midout.h"
  13. #include "src/common/cv/enums.h"
  14. #include "src/common/resize.cuh"
  15. #include "src/common/rounding_converter.cuh"
  16. #include "src/common/utils.cuh"
  17. #include "src/naive/handle.h"
  18. #include "src/naive/resize/resize_cv.h"
  19. MIDOUT_DECL(megdnn_naive_resize_layout)
  20. MIDOUT_DECL(megdnn_naive_resize_nchw)
  21. using namespace megdnn;
  22. using namespace naive;
  23. using namespace resize;
  24. template <typename ctype>
  25. ResizeImpl::KernParam<ctype> ResizeImpl::KernParam<ctype>::from_tensors(
  26. Format format, _megdnn_tensor_in src, _megdnn_tensor_out dst,
  27. _megdnn_workspace workspace) {
  28. KernParam<ctype> ret;
  29. ret.format = format;
  30. ret.n = src.layout.shape[0];
  31. if (format == Format::NCHW) {
  32. ret.c = src.layout.shape[1];
  33. ret.ih = src.layout.shape[2];
  34. ret.iw = src.layout.shape[3];
  35. ret.oh = dst.layout.shape[2];
  36. ret.ow = dst.layout.shape[3];
  37. ret.s_in = src.layout.stride[0];
  38. ret.s_ic = src.layout.stride[1];
  39. ret.s_ih = src.layout.stride[2];
  40. ret.s_iw = src.layout.stride[3];
  41. } else if (format == Format::NHWC) {
  42. ret.c = src.layout.shape[3];
  43. ret.ih = src.layout.shape[1];
  44. ret.iw = src.layout.shape[2];
  45. ret.oh = dst.layout.shape[1];
  46. ret.ow = dst.layout.shape[2];
  47. } else if (format == Format::NCHW4) {
  48. ret.c = src.layout.shape[1] * 4;
  49. ret.ih = src.layout.shape[2];
  50. ret.iw = src.layout.shape[3];
  51. ret.oh = dst.layout.shape[2];
  52. ret.ow = dst.layout.shape[3];
  53. } else {
  54. megdnn_assert(format == Format::NHWCD4);
  55. ret.c = src.layout.shape[2] * 4;
  56. ret.ih = src.layout.shape[1];
  57. ret.iw = src.layout.shape[3];
  58. ret.oh = dst.layout.shape[1];
  59. ret.ow = dst.layout.shape[3];
  60. }
  61. if (src.layout.dtype.enumv() == DTypeEnum::Float32 ||
  62. DNN_FLOAT16_SELECT(src.layout.dtype.enumv() == DTypeEnum::Float16,
  63. false) ||
  64. src.layout.dtype.enumv() == DTypeEnum::Int8 ||
  65. src.layout.dtype.enumv() == DTypeEnum::Uint8 ||
  66. src.layout.dtype.enumv() == DTypeEnum::QuantizedS8 ||
  67. src.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) {
  68. ret.sptr = src.compatible_ptr<ctype>();
  69. ret.dptr = dst.compatible_ptr<ctype>();
  70. } else {
  71. megdnn_assert(0, "current do not support dtype %s in resize",
  72. src.layout.dtype.name());
  73. }
  74. ret.workspace = workspace;
  75. return ret;
  76. }
  77. #define INST(_dtype) template struct ResizeImpl::KernParam<_dtype>;
  78. INST(dt_float32);
  79. #ifndef MEGDNN_DISABLE_FLOAT16
  80. INST(dt_float16);
  81. #endif
  82. INST(dt_int8);
  83. INST(dt_uint8);
  84. INST(dt_qint8);
  85. INST(dt_quint8);
  86. #undef INST
  87. template <typename ctype>
  88. void ResizeImpl::kern_nchw(const KernParam<ctype>& kern_param,
  89. InterpolationMode imode) {
  90. megdnn_assert(kern_param.format == Format::NCHW);
  91. UNPACK_RESIZE_FWD_KERN_PARAM_WITH_STRIDE(kern_param);
  92. float scale_h = static_cast<float>(OH) / IH;
  93. float scale_w = static_cast<float>(OW) / IW;
  94. rounding::RoundingConverter<ctype> output_converter;
  95. rep(n, N) {
  96. rep(oh, OH) rep(ow, OW) {
  97. switch (imode) {
  98. case InterpolationMode::NEAREST: {
  99. auto ih = get_nearest_src(scale_h, IH, oh);
  100. auto iw = get_nearest_src(scale_w, IW, ow);
  101. rep(c, static_cast<int>(C)) {
  102. dptr[c * OH * OW + oh * OW + ow] =
  103. sptr[c * S_IC + ih * S_IH + iw * S_IW];
  104. }
  105. break;
  106. }
  107. case InterpolationMode::INTER_LINEAR: {
  108. auto coord_h = get_origin_coord(scale_h, IH, oh);
  109. auto coord_w = get_origin_coord(scale_w, IW, ow);
  110. float alphah = coord_h.first;
  111. float alphaw = coord_w.first;
  112. int ih0 = coord_h.second;
  113. int ih1 = ih0 + 1;
  114. int iw0 = coord_w.second;
  115. int iw1 = iw0 + 1;
  116. rep(c, static_cast<int>(C)) {
  117. dptr[c * OH * OW + oh * OW + ow] = output_converter(
  118. sptr[c * S_IC + ih0 * S_IH + iw0 * S_IW] *
  119. (1.0f - alphaw) * (1.0f - alphah) +
  120. sptr[c * S_IC + ih0 * S_IH + iw1 * S_IW] *
  121. alphaw * (1.0f - alphah) +
  122. sptr[c * S_IC + ih1 * S_IH + iw0 * S_IW] *
  123. (1.0f - alphaw) * alphah +
  124. sptr[c * S_IC + ih1 * S_IH + iw1 * S_IW] *
  125. alphaw * alphah);
  126. }
  127. break;
  128. }
  129. case InterpolationMode::INTER_CUBIC: {
  130. auto coord_h = get_origin_coord(scale_h, IH, oh, true);
  131. auto coord_w = get_origin_coord(scale_w, IW, ow, true);
  132. float alphah = coord_h.first;
  133. float alphaw = coord_w.first;
  134. int ih0 = coord_h.second - 1;
  135. int iw0 = coord_w.second - 1;
  136. float h_coeff[4], w_coeff[4];
  137. interpolate_cubic(alphah, h_coeff);
  138. interpolate_cubic(alphaw, w_coeff);
  139. rep(c, static_cast<int>(C)) {
  140. constexpr int ksize = 4;
  141. float ret = 0;
  142. rep(kh, ksize) {
  143. int h = saturate<int, int>(ih0 + kh, 0, IH - 1);
  144. rep(kw, ksize) {
  145. int w = saturate<int, int>(iw0 + kw, 0, IW - 1);
  146. ret += sptr[c * S_IC + h * S_IH + w * S_IW] *
  147. h_coeff[kh] * w_coeff[kw];
  148. }
  149. }
  150. dptr[c * OH * OW + oh * OW + ow] =
  151. output_converter(ret);
  152. }
  153. break;
  154. }
  155. default:
  156. megdnn_throw("unsupported mode in ResizeBackwardImpl");
  157. break;
  158. }
  159. }
  160. sptr += S_IN;
  161. dptr += C * OH * OW;
  162. }
  163. }
  164. template <typename ctype>
  165. void ResizeImpl::kern_naive(const KernParam<ctype>& kern_param) {
  166. if (kern_param.format == Format::NHWC) {
  167. MIDOUT_BEGIN(megdnn_naive_resize_layout, midout_iv(0)) {
  168. kern_naive_nhwc(kern_param);
  169. }
  170. MIDOUT_END();
  171. return;
  172. } else if (kern_param.format == Format::NHWCD4) {
  173. MIDOUT_BEGIN(megdnn_naive_resize_layout, midout_iv(1)) {
  174. kern_naive_nhwcd4(kern_param);
  175. }
  176. MIDOUT_END();
  177. return;
  178. } else if (kern_param.format == Format::NCHW4) {
  179. MIDOUT_BEGIN(megdnn_naive_resize_layout, midout_iv(2)) {
  180. kern_naive_nchw4(kern_param);
  181. }
  182. MIDOUT_END();
  183. return;
  184. }
  185. }
  186. template <typename ctype>
  187. void ResizeImpl::kern_naive_nhwc(const KernParam<ctype>& kern_param) {
  188. UNPACK_RESIZE_FWD_KERN_PARAM(kern_param);
  189. rounding::RoundingConverter<ctype> output_converter;
  190. float scale_h = static_cast<float>(OH) / IH;
  191. float scale_w = static_cast<float>(OW) / IW;
  192. rep(n, N) {
  193. rep(oh, OH) rep(ow, OW) {
  194. auto coord_h = get_origin_coord(scale_h, IH, oh);
  195. auto coord_w = get_origin_coord(scale_w, IW, ow);
  196. float alphah = coord_h.first;
  197. float alphaw = coord_w.first;
  198. int ih0 = coord_h.second;
  199. int ih1 = ih0 + 1;
  200. int iw0 = coord_w.second;
  201. int iw1 = iw0 + 1;
  202. rep(c, C) {
  203. dptr[(oh * OW + ow) * C + c] = output_converter(
  204. sptr[(ih0 * IW + iw0) * C + c] * (1.0f - alphaw) *
  205. (1.0f - alphah) +
  206. sptr[(ih0 * IW + iw1) * C + c] * alphaw *
  207. (1.0f - alphah) +
  208. sptr[(ih1 * IW + iw0) * C + c] * (1.0f - alphaw) *
  209. alphah +
  210. sptr[(ih1 * IW + iw1) * C + c] * alphaw * alphah);
  211. }
  212. }
  213. sptr += C * IH * IW;
  214. dptr += C * OH * OW;
  215. }
  216. }
  217. template <typename ctype>
  218. void ResizeImpl::kern_naive_nhwcd4(const KernParam<ctype>& kern_param) {
  219. UNPACK_RESIZE_FWD_KERN_PARAM(kern_param);
  220. rounding::RoundingConverter<ctype> output_converter;
  221. float scale_h = static_cast<float>(OH) / IH;
  222. float scale_w = static_cast<float>(OW) / IW;
  223. auto get_tensor_addr = [&](size_t h, size_t w, size_t c, size_t W,
  224. size_t C) -> size_t {
  225. megdnn_assert((C & 0x3) == 0);
  226. size_t CBLK = (C >> 2);
  227. return (h * W * CBLK * 4 + (c >> 2) * W * 4 + w * 4 + (c & 0x3));
  228. };
  229. rep(n, N) {
  230. rep(oh, OH) rep(ow, OW) {
  231. auto coord_h = get_origin_coord(scale_h, IH, oh);
  232. auto coord_w = get_origin_coord(scale_w, IW, ow);
  233. float alphah = coord_h.first;
  234. float alphaw = coord_w.first;
  235. int ih0 = coord_h.second;
  236. int ih1 = ih0 + 1;
  237. int iw0 = coord_w.second;
  238. int iw1 = iw0 + 1;
  239. rep(c, C) {
  240. dptr[get_tensor_addr(oh, ow, c, OW, C)] = output_converter(
  241. sptr[get_tensor_addr(ih0, iw0, c, IW, C)] *
  242. (1.0f - alphaw) * (1.0f - alphah) +
  243. sptr[get_tensor_addr(ih0, iw1, c, IW, C)] * alphaw *
  244. (1.0f - alphah) +
  245. sptr[get_tensor_addr(ih1, iw0, c, IW, C)] *
  246. (1.0f - alphaw) * alphah +
  247. sptr[get_tensor_addr(ih1, iw1, c, IW, C)] * alphaw *
  248. alphah);
  249. }
  250. }
  251. sptr += IH * (C / 4) * IW * 4;
  252. dptr += OH * (C / 4) * OW * 4;
  253. }
  254. }
  255. template <typename ctype>
  256. void ResizeImpl::kern_naive_nchw4(const KernParam<ctype>& kern_param) {
  257. UNPACK_RESIZE_FWD_KERN_PARAM(kern_param);
  258. rounding::RoundingConverter<ctype> output_converter;
  259. float scale_h = static_cast<float>(OH) / IH;
  260. float scale_w = static_cast<float>(OW) / IW;
  261. auto get_tensor_addr = [&](size_t h, size_t w, size_t c, size_t H, size_t W,
  262. size_t C) -> size_t {
  263. megdnn_assert((C & 0x3) == 0);
  264. return (((c >> 2) * H * W + h * W + w) << 2) + (c & 0b11);
  265. };
  266. rep(n, N) {
  267. rep(oh, OH) rep(ow, OW) {
  268. auto coord_h = get_origin_coord(scale_h, IH, oh);
  269. auto coord_w = get_origin_coord(scale_w, IW, ow);
  270. float alphah = coord_h.first;
  271. float alphaw = coord_w.first;
  272. int ih0 = coord_h.second;
  273. int ih1 = ih0 + 1;
  274. int iw0 = coord_w.second;
  275. int iw1 = iw0 + 1;
  276. rep(c, C) {
  277. dptr[get_tensor_addr(oh, ow, c, OH, OW, C)] = output_converter(
  278. sptr[get_tensor_addr(ih0, iw0, c, IH, IW, C)] *
  279. (1.0f - alphaw) * (1.0f - alphah) +
  280. sptr[get_tensor_addr(ih0, iw1, c, IH, IW, C)] * alphaw *
  281. (1.0f - alphah) +
  282. sptr[get_tensor_addr(ih1, iw0, c, IH, IW, C)] *
  283. (1.0f - alphaw) * alphah +
  284. sptr[get_tensor_addr(ih1, iw1, c, IH, IW, C)] * alphaw *
  285. alphah);
  286. }
  287. }
  288. sptr += IH * IW * C;
  289. dptr += OH * OW * C;
  290. }
  291. }
  292. void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
  293. _megdnn_workspace workspace) {
  294. check_exec(src.layout, dst.layout, workspace.size);
  295. if (param().format == param::Resize::Format::NCHW) {
  296. #define cb(dt, ct, _midout_iv) \
  297. case DTypeTrait<dt>::enumv: { \
  298. MIDOUT_BEGIN(megdnn_naive_resize_nchw, midout_iv(_midout_iv)) { \
  299. auto kparam = KernParam<ct>::from_tensors(param().format, src, \
  300. dst, workspace); \
  301. MEGDNN_DISPATCH_CPU_KERN_OPR(kern_nchw(kparam, param().imode)); \
  302. } \
  303. MIDOUT_END(); \
  304. return; \
  305. }
  306. switch (src.layout.dtype.enumv()) {
  307. cb(dtype::Float32, float, 0);
  308. DNN_INC_FLOAT16(cb(dtype::Float16, dt_float16, 1));
  309. cb(dtype::Int8, int8_t, 2);
  310. cb(dtype::QuantizedS8, int8_t, 3);
  311. cb(dtype::Uint8, uint8_t, 4);
  312. cb(dtype::Quantized8Asymm, uint8_t, 5);
  313. default:
  314. megdnn_throw(ssprintf("Unsupported input DType in Resize "
  315. "NEAREST mode: %s",
  316. src.layout.dtype.name())
  317. .c_str());
  318. return;
  319. }
  320. #undef cb
  321. }
  322. if (((src.layout[3] != 1 && src.layout[3] != 3) ||
  323. !is_nhwc_contig_wc(src.layout)) ||
  324. (param().imode == param::Resize::InterpolationMode::LINEAR)) {
  325. #define cb(dt, ct, _midout_iv) \
  326. case DTypeTrait<dt>::enumv: { \
  327. MIDOUT_BEGIN(megdnn_naive_resize_layout, midout_iv(_midout_iv)) { \
  328. auto kparam = KernParam<ct>::from_tensors(param().format, src, \
  329. dst, workspace); \
  330. MEGDNN_DISPATCH_CPU_KERN_OPR(kern_naive(kparam)); \
  331. } \
  332. MIDOUT_END(); \
  333. return; \
  334. }
  335. switch (src.layout.dtype.enumv()) {
  336. cb(dtype::Float32, float, 0);
  337. DNN_INC_FLOAT16(cb(dtype::Float16, dt_float16, 1));
  338. cb(dtype::Int8, int8_t, 2);
  339. cb(dtype::QuantizedS8, int8_t, 3);
  340. cb(dtype::Uint8, uint8_t, 4);
  341. cb(dtype::Quantized8Asymm, uint8_t, 5);
  342. default:
  343. megdnn_throw(ssprintf("Unsupported input DType in Resize: %s",
  344. src.layout.dtype.name())
  345. .c_str());
  346. return;
  347. }
  348. #undef cb
  349. } else {
  350. megdnn_assert(param().format == param::Resize::Format::NHWC,
  351. "invalid resize format");
  352. MEGDNN_DISPATCH_CPU_KERN_OPR(resize_cv_exec(src, dst, param().imode));
  353. }
  354. }
  355. void ResizeBackwardImpl::exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad,
  356. _megdnn_workspace workspace) {
  357. check_exec(diff.layout, grad.layout, workspace.size);
  358. megdnn_assert(param().format == param::Resize::Format::NCHW,
  359. "invalid resize format");
  360. const int N = grad.layout.shape[0], C = grad.layout.shape[1],
  361. IH = grad.layout.shape[2], IW = grad.layout.shape[3];
  362. const int OH = diff.layout.shape[2], OW = diff.layout.shape[3];
  363. const float* hptr_ = diff.ptr<dt_float32>();
  364. float* sptr_ = grad.ptr<dt_float32>();
  365. float scale_h = static_cast<float>(OH) / IH;
  366. float scale_w = static_cast<float>(OW) / IW;
  367. auto kern = [=]() {
  368. auto hptr = hptr_;
  369. auto sptr = sptr_;
  370. std::memset(sptr, 0, sizeof(float) * N * C * IH * IW);
  371. rep(n, N) {
  372. rep(oh, OH) rep(ow, OW) {
  373. switch (param().imode) {
  374. case InterpolationMode::INTER_LINEAR: {
  375. auto coord_h = get_origin_coord(scale_h, IH, oh);
  376. auto coord_w = get_origin_coord(scale_w, IW, ow);
  377. float alphah = coord_h.first;
  378. float alphaw = coord_w.first;
  379. int ih0 = coord_h.second;
  380. int ih1 = ih0 + 1;
  381. int iw0 = coord_w.second;
  382. int iw1 = iw0 + 1;
  383. rep(c, C) {
  384. float hidden = hptr[c * OH * OW + oh * OW + ow];
  385. sptr[c * IH * IW + ih0 * IW + iw0] +=
  386. (1.0f - alphaw) * (1.0f - alphah) * hidden;
  387. sptr[c * IH * IW + ih1 * IW + iw0] +=
  388. (1.0f - alphaw) * alphah * hidden;
  389. sptr[c * IH * IW + ih0 * IW + iw1] +=
  390. alphaw * (1.0f - alphah) * hidden;
  391. sptr[c * IH * IW + ih1 * IW + iw1] +=
  392. alphaw * alphah * hidden;
  393. }
  394. break;
  395. }
  396. case InterpolationMode::NEAREST: {
  397. auto ih = get_nearest_src(scale_h, IH, oh);
  398. auto iw = get_nearest_src(scale_w, IW, ow);
  399. rep(c, static_cast<int>(C)) {
  400. sptr[c * IH * IW + ih * IW + iw] +=
  401. hptr[c * OH * OW + oh * OW + ow];
  402. }
  403. break;
  404. }
  405. case InterpolationMode::INTER_CUBIC: {
  406. auto coord_h = get_origin_coord(scale_h, IH, oh, true);
  407. auto coord_w = get_origin_coord(scale_w, IW, ow, true);
  408. float alphah = coord_h.first;
  409. float alphaw = coord_w.first;
  410. int ih0 = coord_h.second - 1;
  411. int iw0 = coord_w.second - 1;
  412. float h_coeff[4], w_coeff[4];
  413. interpolate_cubic(alphah, h_coeff);
  414. interpolate_cubic(alphaw, w_coeff);
  415. rep(c, static_cast<int>(C)) {
  416. constexpr int ksize = 4;
  417. rep(kh, ksize) {
  418. int h = saturate<int, int>(ih0 + kh, 0, IH - 1);
  419. rep(kw, ksize) {
  420. int w = saturate<int, int>(iw0 + kw, 0, IW - 1);
  421. sptr[c * IH * IW + h * IW + w] +=
  422. hptr[c * OH * OW + oh * OW + ow] *
  423. h_coeff[kh] * w_coeff[kw];
  424. }
  425. }
  426. }
  427. break;
  428. }
  429. default: {
  430. megdnn_throw("unsupported mode in ResizeBackwardImpl");
  431. break;
  432. }
  433. }
  434. }
  435. sptr += C * IH * IW;
  436. hptr += C * OH * OW;
  437. }
  438. };
  439. MEGDNN_DISPATCH_CPU_KERN_OPR(kern());
  440. }
  441. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台