You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

forward.cpp 23 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460
  1. /**
  2. * \file dnn/src/cuda/warp_perspective/forward.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "src/cuda/warp_perspective/opr_impl.h"
  13. #include "src/cuda/warp_perspective/warp_perspective_cv.cuh"
  14. #include "src/cuda/utils.h"
  15. #include "src/cuda/warp_perspective/common.h"
  16. #include "src/cuda/warp_perspective/helper.h"
  17. #include "src/common/cv/common.h"
  18. #include "src/common/warp_common.h"
  19. namespace megdnn {
  20. namespace cuda {
  21. namespace {
  22. inline void deduce_reformat_layout(std::unique_ptr<RelayoutFormat>& relayout,
  23. const TensorLayout& src_layout,
  24. TensorLayout& dst_layout,
  25. RelayoutFormat::Param::Mode mode,
  26. const int oc = 0, const int group = 1) {
  27. if (src_layout.ndim > 0) {
  28. RelayoutFormat::Param trans_param;
  29. trans_param.mode = mode;
  30. trans_param.oc = oc;
  31. trans_param.group = group;
  32. relayout->param() = trans_param;
  33. relayout->deduce_layout(src_layout, dst_layout);
  34. } else {
  35. dst_layout = src_layout;
  36. }
  37. }
  38. void get_inner_layout(const TensorLayout& src, const TensorLayout& dst,
  39. TensorLayout& inner_src, TensorLayout& inner_dst,
  40. Handle* handle,
  41. WarpPerspectiveForwardImpl::Param::Format format) {
  42. if ((src.dtype.enumv() == DTypeEnum::QuantizedS4 ||
  43. src.dtype.enumv() == DTypeEnum::Quantized4Asymm) &&
  44. dst.dtype.enumv() == src.dtype.enumv() &&
  45. format == param::WarpPerspective::Format::NCHW) {
  46. auto relayout_opr = handle->create_operator<RelayoutFormat>();
  47. deduce_reformat_layout(relayout_opr, src, inner_src,
  48. RelayoutFormat::Param::Mode::NCHW_NCHW64, 0, 1);
  49. deduce_reformat_layout(relayout_opr, dst, inner_dst,
  50. RelayoutFormat::Param::Mode::NCHW_NCHW64, 0, 1);
  51. } else {
  52. megdnn_assert(0, "not support");
  53. }
  54. }
  55. } // namespace
  56. namespace warp_perspective {
  57. void warp_perspective_cv_exec(_megdnn_tensor_in src, _megdnn_tensor_in mat,
  58. _megdnn_tensor_in dst, float border_val,
  59. BorderMode bmode, InterpolationMode imode,
  60. _megdnn_workspace workspace,
  61. cudaStream_t stream) {
  62. megdnn_assert(src.layout[3] == 1 || src.layout[3] == 3,
  63. "unsupported src channel");
  64. megdnn_assert(src.layout.dtype != dtype::Float32() ||
  65. src.layout.dtype != dtype::Uint8(),
  66. "unsupported src dtype");
  67. if (imode == InterpolationMode::INTER_AREA) {
  68. imode = InterpolationMode::INTER_LINEAR;
  69. }
  70. using namespace megcv;
  71. const float* trans_ptr = mat.ptr<dt_float32>();
  72. double* workspace_ptr = workspace.ptr<double>();
  73. for (size_t i = 0; i < src.layout.shape[0]; ++i) {
  74. if (dst.layout.dtype == dtype::Float32()) {
  75. Mat<float> src_mat = TensorND2Mat<float>(src, i);
  76. Mat<float> dst_mat = TensorND2Mat<float>(dst, i);
  77. if (src_mat.channels() == 1) {
  78. warp_perspective_cv_proxy<float, 1>(
  79. src_mat.ptr(), dst_mat.ptr(), src_mat.rows(),
  80. src_mat.cols(), dst_mat.rows(), dst_mat.cols(),
  81. src_mat.step(), dst_mat.step(), bmode, imode, trans_ptr,
  82. border_val, workspace_ptr, stream);
  83. } else {
  84. warp_perspective_cv_proxy<float, 3>(
  85. src_mat.ptr(), dst_mat.ptr(), src_mat.rows(),
  86. src_mat.cols(), dst_mat.rows(), dst_mat.cols(),
  87. src_mat.step(), dst_mat.step(), bmode, imode, trans_ptr,
  88. border_val, workspace_ptr, stream);
  89. }
  90. } else if (dst.layout.dtype == dtype::Uint8()) {
  91. Mat<uchar> src_mat = TensorND2Mat<uchar>(src, i);
  92. Mat<uchar> dst_mat = TensorND2Mat<uchar>(dst, i);
  93. if (src_mat.channels() == 1) {
  94. warp_perspective_cv_proxy<uchar, 1>(
  95. src_mat.ptr(), dst_mat.ptr(), src_mat.rows(),
  96. src_mat.cols(), dst_mat.rows(), dst_mat.cols(),
  97. src_mat.step(), dst_mat.step(), bmode, imode, trans_ptr,
  98. static_cast<uchar>(border_val), workspace_ptr, stream);
  99. } else {
  100. warp_perspective_cv_proxy<uchar, 3>(
  101. src_mat.ptr(), dst_mat.ptr(), src_mat.rows(),
  102. src_mat.cols(), dst_mat.rows(), dst_mat.cols(),
  103. src_mat.step(), dst_mat.step(), bmode, imode, trans_ptr,
  104. static_cast<uchar>(border_val), workspace_ptr, stream);
  105. }
  106. } else {
  107. megdnn_throw("Unsupported datatype of WarpPerspective optr.");
  108. }
  109. trans_ptr += 3 * 3;
  110. workspace_ptr += 3 * 3;
  111. }
  112. }
  113. } // namespace warp_perspective
  114. WorkspaceBundle WarpPerspectiveForwardImpl::get_workspace_bundle(
  115. void* ptr, const TensorLayout& src, const TensorLayout& mat,
  116. const TensorLayout& mat_idx, const TensorLayout& dst) const {
  117. MEGDNN_MARK_USED_VAR(mat_idx);
  118. SmallVector<size_t> sizes;
  119. TensorLayout fsrc = src;
  120. TensorLayout fmat = mat;
  121. TensorLayout fdst = dst;
  122. if ((src.dtype.enumv() == DTypeEnum::QuantizedS4 ||
  123. src.dtype.enumv() == DTypeEnum::Quantized4Asymm) &&
  124. param().format == param::WarpPerspective::Format::NCHW) {
  125. get_inner_layout(src, dst, fsrc, fdst, handle(), param().format);
  126. sizes.push_back(fsrc.span().dist_byte());
  127. sizes.push_back(fdst.span().dist_byte());
  128. } else {
  129. auto get_workspace = [&sizes](TensorLayout& layout) {
  130. if (layout.dtype == dtype::BFloat16()) {
  131. layout.dtype = dtype::Float32();
  132. sizes.push_back(layout.span().dist_byte());
  133. }
  134. };
  135. get_workspace(fsrc);
  136. get_workspace(fmat);
  137. get_workspace(fdst);
  138. }
  139. if (param().format == param::WarpPerspective::Format::NHWC) {
  140. //! use double for the workspace dtype as float may cause
  141. //! accuracy problems
  142. sizes.push_back(mat.total_nr_elems() * sizeof(double));
  143. }
  144. return {ptr, std::move(sizes)};
  145. }
  //! Forward warp-perspective on CUDA.
  //!
  //! \param ssrc       source tensor; its axes are interpreted according to
  //!                   param().format (see the C/IH/IW/OH/OW selection below)
  //! \param smat       transform matrices, read as dt_float32 by every kernel
  //! \param smat_idx   optional; when its raw_ptr is null, nullptr is passed
  //!                   to the kernels as the index pointer.
  //!                   NOTE(review): presumably maps each matrix to a source
  //!                   image (mat.layout[0] and src.layout[0] are passed
  //!                   separately). Confirm.
  //! \param sdst       destination tensor
  //! \param sworkspace scratch memory laid out by get_workspace_bundle()
  //!
  //! Outline:
  //!  1. When src is BFloat16, src/mat/dst go through
  //!     CompTypeCvter<BFloat16, Float32> (workspace-backed copies).
  //!  2. QuantizedS4 / Quantized4Asymm in NCHW are relaid to NCHW64 workspace
  //!     buffers, run through the NCHW64 kernel, and relaid back into sdst.
  //!  3. NHWC with non-LINEAR interpolation runs warp_perspective_cv_exec.
  //!     Every other case requires LINEAR and goes to a forward_proxy* kernel
  //!     chosen by src dtype, dst dtype and format.
  //!  4. For BFloat16 input, the Float32 result is converted back into sdst.
  146. void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in ssrc,
  147. _megdnn_tensor_in smat,
  148. _megdnn_tensor_in smat_idx,
  149. _megdnn_tensor_out sdst,
  150. _megdnn_workspace sworkspace) {
  151. check_exec_allow_nhwc_mat_idx(ssrc.layout, smat.layout, smat_idx.layout,
  152. sdst.layout, sworkspace.size);
  // Working views; they are re-pointed at workspace copies in steps 1 and 2.
  153. TensorND src = ssrc;
  154. TensorND mat = smat;
  155. TensorND mat_idx = smat_idx;
  156. TensorND dst = sdst;
  // Format the kernels actually see. It differs from param().format only in
  // the 4-bit NCHW case below, where it becomes NCHW64.
  157. Param::Format inner_format = param().format;
  158. auto bundle =
  159. get_workspace_bundle(sworkspace.raw_ptr, ssrc.layout, smat.layout,
  160. smat_idx.layout, sdst.layout);
  161. auto ctypecvt = CompTypeCvter<dtype::BFloat16, dtype::Float32>(
  162. concrete_handle(this->handle()), &bundle);
  // Step 1: BFloat16 operands are computed in Float32.
  163. if (ssrc.layout.dtype.enumv() == DTypeTrait<dtype::BFloat16>::enumv) {
  164. ctypecvt.src_to_comp_type(ssrc, src)
  165. .src_to_comp_type(smat, mat)
  166. .src_to_comp_type(sdst, dst);
  // Step 2: 4-bit quantized NCHW runs in NCHW64. bundle chunks 0/1 are the
  // inner src/dst buffers that get_workspace_bundle() sized for this case.
  167. } else if ((ssrc.layout.dtype.enumv() == DTypeEnum::QuantizedS4 ||
  168. ssrc.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) &&
  169. param().format == Param::Format::NCHW) {
  170. auto handle_ptr = handle();
  171. get_inner_layout(ssrc.layout, sdst.layout, src.layout, dst.layout,
  172. handle_ptr, param().format);
  173. src.raw_ptr = bundle.get(0);
  174. dst.raw_ptr = bundle.get(1);
  175. auto relayout_opr = handle_ptr->create_operator<RelayoutFormat>();
  176. RelayoutFormat::Param trans_param;
  177. trans_param.mode = RelayoutFormat::Param::Mode::NCHW_NCHW64;
  178. relayout_opr->param() = trans_param;
  179. relayout_opr->exec(ssrc, src, {});
  180. inner_format = Param::Format::NCHW64;
  181. }
  // Step 3: run the warp.
  182. {
  183. auto stream = cuda_stream(this->handle());
  184. bool is_nhwc = inner_format == param::WarpPerspective::Format::NHWC;
  185. if (is_nhwc && param().imode != Param::InterpolationMode::LINEAR) {
  186. // use opencv impl only for nhwc and non-linear interp
  187. megdnn_assert(!mat_idx.raw_ptr,
  188. "mat_idx is not supported in NHWC case with "
  189. "non-linear interpolation");
  // NOTE(review): ctypecvt.workspace() presumably yields the bundle space
  // after any BFloat16 copies, i.e. the double-per-matrix chunk reserved
  // for NHWC in get_workspace_bundle(). Confirm.
  190. warp_perspective::warp_perspective_cv_exec(
  191. src, mat, dst, param().border_val,
  192. warp_perspective::get_bmode(param().bmode),
  193. warp_perspective::get_imode(param().imode),
  194. ctypecvt.workspace(), stream);
  195. } else {
  // Kernel path. Derive channel count and spatial extents from
  // inner_format. For the blocked formats (NCHW4, NCHW64), C is the block
  // count times the block size.
  196. megdnn_assert(warp::is_dnn_available(src.layout, mat.layout,
  197. dst.layout, param().imode,
  198. inner_format));
  199. size_t C, IH, IW, OH, OW;
  200. if (is_nhwc) {
  201. C = src.layout.shape[3];
  202. IH = src.layout.shape[1];
  203. IW = src.layout.shape[2];
  204. OH = dst.layout.shape[1];
  205. OW = dst.layout.shape[2];
  206. } else if (inner_format == Param::Format::NCHW4) {
  207. C = src.layout.shape[1] * 4;
  208. IH = src.layout.shape[2];
  209. IW = src.layout.shape[3];
  210. OH = dst.layout.shape[2];
  211. OW = dst.layout.shape[3];
  212. } else if (inner_format == Param::Format::NHWC_NCHW) {
  213. C = src.layout.shape[3];
  214. IH = src.layout.shape[1];
  215. IW = src.layout.shape[2];
  216. OH = dst.layout.shape[2];
  217. OW = dst.layout.shape[3];
  218. } else if (inner_format == Param::Format::NHWC_NCHW4_IC_SMALL) {
  219. C = src.layout.shape[3];
  220. IH = src.layout.shape[1];
  221. IW = src.layout.shape[2];
  222. OH = dst.layout.shape[2];
  223. OW = dst.layout.shape[3];
  224. megdnn_assert(
  225. (C == 1) || (C == 3),
  226. "NHWC_NCHW4_IC_SMALL only support C == 1 or C == 3");
  227. } else if (inner_format == Param::Format::NCHW_NCHW4_IC_SMALL) {
  228. C = src.layout.shape[1];
  229. IH = src.layout.shape[2];
  230. IW = src.layout.shape[3];
  231. OH = dst.layout.shape[2];
  232. OW = dst.layout.shape[3];
  233. megdnn_assert(
  234. (C == 1) || (C == 3),
  235. "NCHW_NCHW4_IC_SMALL only support C == 1 or C == 3");
  236. } else if (inner_format == Param::Format::NCHW64) {
  237. C = src.layout.shape[1] * 64;
  238. IH = src.layout.shape[2];
  239. IW = src.layout.shape[3];
  240. OH = dst.layout.shape[2];
  241. OW = dst.layout.shape[3];
  242. } else {
  243. megdnn_assert(
  244. inner_format == param::WarpPerspective::Format::NCHW,
  245. "invalid warp_perspective format");
  246. C = src.layout.shape[1];
  247. IH = src.layout.shape[2];
  248. IW = src.layout.shape[3];
  249. OH = dst.layout.shape[2];
  250. OW = dst.layout.shape[3];
  251. }
  // Every kernel path below requires LINEAR interpolation.
  252. megdnn_assert(param().imode == Param::InterpolationMode::LINEAR,
  253. "unsupported interpolation mode for NCHW format");
  254. auto bval = param().border_val;
  255. auto bmode = warp_perspective::get_bmode(param().bmode);
  // Same src/dst dtype: dispatch on the element type.
  256. if (src.layout.dtype == dst.layout.dtype) {
  257. if (src.layout.dtype == dtype::Float32{}) {
  258. warp_perspective::forward_proxy(
  259. is_nhwc, src.ptr<dt_float32>(),
  260. mat.ptr<dt_float32>(),
  261. mat_idx.raw_ptr ? mat_idx.ptr<int>() : nullptr,
  262. dst.ptr<dt_float32>(), src.layout[0], mat.layout[0],
  263. C, IH, IW, OH, OW, bval, bmode,
  264. async_error_info(handle()), m_error_tracker,
  265. stream);
  266. } else if (DNN_FLOAT16_SELECT(
  267. src.layout.dtype == dtype::Float16(),
  268. false)) {
  269. #ifndef MEGDNN_DISABLE_FLOAT16
  270. warp_perspective::forward_proxy(
  271. is_nhwc, src.ptr<dt_float16>(),
  272. mat.ptr<dt_float32>(),
  273. mat_idx.raw_ptr ? mat_idx.ptr<int>() : nullptr,
  274. dst.ptr<dt_float16>(), src.layout[0], mat.layout[0],
  275. C, IH, IW, OH, OW, static_cast<dt_float16>(bval),
  276. bmode, async_error_info(handle()), m_error_tracker,
  277. stream);
  278. #endif
  // NOTE(review): unlike the 4-bit branches below, the Uint8 and Int8
  // branches pass bval without rounding or clamping. Confirm how
  // forward_proxy handles out-of-range border values.
  279. } else if (src.layout.dtype == dtype::Uint8()) {
  280. warp_perspective::forward_proxy<dt_uint8>(
  281. is_nhwc, src.ptr<dt_uint8>(), mat.ptr<dt_float32>(),
  282. mat_idx.raw_ptr ? mat_idx.ptr<int>() : nullptr,
  283. dst.ptr<dt_uint8>(), src.layout[0], mat.layout[0],
  284. C, IH, IW, OH, OW, bval, bmode,
  285. async_error_info(handle()), m_error_tracker,
  286. stream);
  287. } else if (src.layout.dtype == dtype::Int8()) {
  288. megdnn_assert(!is_nhwc,
  289. "WarpPerspective on CUDA does not support "
  290. "NHWC + Int8");
  291. warp_perspective::forward_proxy<dt_int8>(
  292. false, src.ptr<dt_int8>(), mat.ptr<dt_float32>(),
  293. mat_idx.raw_ptr ? mat_idx.ptr<int>() : nullptr,
  294. dst.ptr<dt_int8>(), src.layout[0], mat.layout[0], C,
  295. IH, IW, OH, OW,
  296. bval /* implicit float -> int8 conversion,
  297. should be safe */
  298. ,
  299. bmode, async_error_info(handle()), m_error_tracker,
  300. stream);
  301. } else if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS8) {
  302. megdnn_assert(param().format == Param::Format::NCHW4,
  303. "WarpPerspective on CUDA supports NCHW4 + "
  304. "QuantizedS8 only");
  305. warp_perspective::forward_proxy_nchw4<dt_int8>(
  306. src.compatible_ptr<dt_int8>(),
  307. mat.ptr<dt_float32>(),
  308. mat_idx.raw_ptr ? mat_idx.ptr<int>() : nullptr,
  309. dst.compatible_ptr<dt_int8>(), src.layout[0],
  310. mat.layout[0], C, IH, IW, OH, OW, bval, bmode,
  311. async_error_info(handle()), m_error_tracker,
  312. stream);
  313. } else if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS4) {
  314. megdnn_assert(
  315. param().format == Param::Format::NCHW64 ||
  316. param().format == Param::Format::NCHW,
  317. "WarpPerspective on CUDA supports NCHW64 or NCHW+ "
  318. "QuantizedS4");
  // Round, then clamp the border value to the signed 4-bit range [-8, 7].
  319. bval = roundf(bval);
  320. bval = fmin(fmax(-8.f, bval), 7.f);
  321. warp_perspective::forward_proxy_nchw64<dt_qint4>(
  322. src.compatible_ptr<dt_qint4>(),
  323. mat.ptr<dt_float32>(),
  324. mat_idx.raw_ptr ? mat_idx.ptr<int>() : nullptr,
  325. dst.compatible_ptr<dt_qint4>(), src.layout[0],
  326. mat.layout[0], C, IH, IW, OH, OW,
  327. static_cast<dt_qint4>(bval), bmode,
  328. async_error_info(handle()), m_error_tracker,
  329. stream);
  // Caller asked for NCHW: relay the NCHW64 workspace result into sdst.
  // oc is set to sdst's channel dim (presumably to drop NCHW64 padding).
  330. if (param().format == Param::Format::NCHW) {
  331. auto relayout_opr =
  332. handle()->create_operator<RelayoutFormat>();
  333. RelayoutFormat::Param trans_param;
  334. trans_param.mode =
  335. RelayoutFormat::Param::Mode::NCHW64_NCHW;
  336. trans_param.oc = sdst.layout[1];
  337. relayout_opr->param() = trans_param;
  338. relayout_opr->exec(dst, sdst, {});
  339. }
  340. } else if (src.layout.dtype.enumv() ==
  341. DTypeEnum::Quantized4Asymm) {
  342. megdnn_assert(
  343. param().format == Param::Format::NCHW64 ||
  344. param().format == Param::Format::NCHW,
  345. "WarpPerspective on CUDA supports NCHW64 or NCHW+ "
  346. "Quantized4Asymm");
  // Round, then clamp the border value to the unsigned 4-bit range [0, 15].
  347. bval = roundf(bval);
  348. bval = fmin(fmax(0, bval), 15);
  349. warp_perspective::forward_proxy_nchw64<dt_quint4>(
  350. src.compatible_ptr<dt_quint4>(),
  351. mat.ptr<dt_float32>(),
  352. mat_idx.raw_ptr ? mat_idx.ptr<int>() : nullptr,
  353. dst.compatible_ptr<dt_quint4>(), src.layout[0],
  354. mat.layout[0], C, IH, IW, OH, OW,
  355. static_cast<dt_quint4>(bval), bmode,
  356. async_error_info(handle()), m_error_tracker,
  357. stream);
  // Same NCHW64 -> NCHW write-back as the QuantizedS4 branch.
  358. if (param().format == Param::Format::NCHW) {
  359. auto relayout_opr =
  360. handle()->create_operator<RelayoutFormat>();
  361. RelayoutFormat::Param trans_param;
  362. trans_param.mode =
  363. RelayoutFormat::Param::Mode::NCHW64_NCHW;
  364. trans_param.oc = sdst.layout[1];
  365. relayout_opr->param() = trans_param;
  366. relayout_opr->exec(dst, sdst, {});
  367. }
  368. }
  // Mixed dtypes (reached only when src and dst dtypes differ):
  // Quantized8Asymm or Uint8 input with a type-converting kernel.
  369. } else if ((src.layout.dtype.enumv() ==
  370. DTypeEnum::Quantized8Asymm ||
  371. src.layout.dtype.enumv() == DTypeEnum::Uint8)) {
  // Describe the input as quint8 (scale, zero_point):
  //  - Quantized8Asymm: its own params;
  //  - Uint8 -> QuantizedS8: (1, 128);
  //  - otherwise: (1, 0).
  372. uint8_t zero_point = 0;
  373. float scale = 1.f;
  374. if (src.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) {
  375. zero_point =
  376. src.layout.dtype.param<dtype::Quantized8Asymm>()
  377. .zero_point;
  378. scale = src.layout.dtype.param<dtype::Quantized8Asymm>()
  379. .scale;
  380. } else if (src.layout.dtype.enumv() == DTypeEnum::Uint8 &&
  381. dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) {
  382. zero_point = 128;
  383. scale = 1.f;
  384. }
  385. DTypeParamImpl<dt_quint8> src_dtype_param(scale, zero_point);
  // QuantizedS8 output with the same scale in an *_NCHW4_IC_SMALL format
  // goes to the fused dimshuffle + typecvt NCHW4 kernel.
  386. if ((dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8 &&
  387. dst.layout.dtype.param<dtype::QuantizedS8>().scale ==
  388. scale) &&
  389. ((param().format == Param::Format::NCHW_NCHW4_IC_SMALL) ||
  390. (param().format == Param::Format::NHWC_NCHW4_IC_SMALL))) {
  391. bool is_nhwc_ic_small =
  392. (param().format ==
  393. Param::Format::NHWC_NCHW4_IC_SMALL);
  394. warp_perspective::
  395. forward_proxy_quint8_dimshuffle_typecvt_nchw4<
  396. dt_quint8, dt_uint8, dt_int8>(
  397. is_nhwc_ic_small,
  398. src.compatible_ptr<dt_uint8>(),
  399. mat.ptr<dt_float32>(),
  400. mat_idx.raw_ptr ? mat_idx.ptr<int>()
  401. : nullptr,
  402. dst.compatible_ptr<dt_int8>(),
  403. src.layout[0], mat.layout[0], C, IH, IW, OH,
  404. OW, bval, src_dtype_param, bmode,
  405. async_error_info(handle()), m_error_tracker,
  406. stream);
  // Otherwise only Float32 output in NCHW or NHWC_NCHW is accepted.
  407. } else {
  408. megdnn_assert(
  409. ((dst.layout.dtype.enumv() == DTypeEnum::Float32) &&
  410. ((param().format == Param::Format::NCHW) ||
  411. (param().format == Param::Format::NHWC_NCHW))),
  412. "invalid format for Quantized8Asymm input");
  // Shadows the outer is_nhwc; here it means NHWC_NCHW input.
  413. bool is_nhwc = (param().format == Param::Format::NHWC_NCHW);
  414. warp_perspective::
  415. forward_proxy_quint8_dimshuffle_typecvt_nchw<
  416. dt_quint8, dt_uint8, dt_float32>(
  417. is_nhwc, src.compatible_ptr<dt_uint8>(),
  418. mat.ptr<dt_float32>(),
  419. mat_idx.raw_ptr ? mat_idx.ptr<int>()
  420. : nullptr,
  421. dst.compatible_ptr<dt_float32>(),
  422. src.layout[0], mat.layout[0], C, IH, IW, OH,
  423. OW, bval, src_dtype_param, bmode,
  424. async_error_info(handle()), m_error_tracker,
  425. stream);
  426. }
  427. } else {
  428. megdnn_throw(ssprintf("unsupported dtype: %s",
  429. src.layout.dtype.name()));
  430. }
  431. }
  432. }
  // Step 4: write the Float32 result back into the BFloat16 output.
  433. if (ssrc.layout.dtype.enumv() == DTypeTrait<dtype::BFloat16>::enumv) {
  434. ctypecvt.comp_to_dst_type(dst, sdst);
  435. }
  436. }
  437. } // namespace cuda
  438. } // namespace megdnn
  439. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。