
forward.cpp

/**
 * \file dnn/src/cuda/warp_perspective/forward.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
#include "src/cuda/warp_perspective/opr_impl.h"
#include "src/cuda/warp_perspective/warp_perspective_cv.cuh"

#include "src/cuda/utils.h"
#include "src/cuda/warp_perspective/common.h"
#include "src/cuda/warp_perspective/helper.h"

#include "src/common/cv/common.h"
#include "src/common/warp_common.h"

namespace megdnn {
namespace cuda {
namespace {

inline void deduce_reformat_layout(std::unique_ptr<RelayoutFormat>& relayout,
                                   const TensorLayout& src_layout,
                                   TensorLayout& dst_layout,
                                   RelayoutFormat::Param::Mode mode,
                                   const int oc = 0, const int group = 1) {
    if (src_layout.ndim > 0) {
        RelayoutFormat::Param trans_param;
        trans_param.mode = mode;
        trans_param.oc = oc;
        trans_param.group = group;
        relayout->param() = trans_param;
        relayout->deduce_layout(src_layout, dst_layout);
    } else {
        dst_layout = src_layout;
    }
}

void get_inner_layout(const TensorLayout& src, const TensorLayout& dst,
                      TensorLayout& inner_src, TensorLayout& inner_dst,
                      Handle* handle,
                      WarpPerspectiveForwardImpl::Param::Format format) {
    if (src.dtype.enumv() == DTypeEnum::QuantizedS4 &&
        dst.dtype.enumv() == DTypeEnum::QuantizedS4 &&
        format == param::WarpPerspective::Format::NCHW) {
        auto relayout_opr = handle->create_operator<RelayoutFormat>();
        deduce_reformat_layout(relayout_opr, src, inner_src,
                               RelayoutFormat::Param::Mode::NCHW_NCHW64, 0, 1);
        deduce_reformat_layout(relayout_opr, dst, inner_dst,
                               RelayoutFormat::Param::Mode::NCHW_NCHW64, 0, 1);
    } else {
        megdnn_assert(0, "not support");
    }
}

}  // namespace

namespace warp_perspective {

void warp_perspective_cv_exec(_megdnn_tensor_in src, _megdnn_tensor_in mat,
                              _megdnn_tensor_in dst, float border_val,
                              BorderMode bmode, InterpolationMode imode,
                              _megdnn_workspace workspace,
                              cudaStream_t stream) {
    megdnn_assert(src.layout[3] == 1 || src.layout[3] == 3,
                  "unsupported src channel");
    megdnn_assert(src.layout.dtype == dtype::Float32() ||
                          src.layout.dtype == dtype::Uint8(),
                  "unsupported src dtype");
    if (imode == InterpolationMode::INTER_AREA) {
        imode = InterpolationMode::INTER_LINEAR;
    }
    using namespace megcv;
    const float* trans_ptr = mat.ptr<dt_float32>();
    double* workspace_ptr = workspace.ptr<double>();
    for (size_t i = 0; i < src.layout.shape[0]; ++i) {
        if (dst.layout.dtype == dtype::Float32()) {
            Mat<float> src_mat = TensorND2Mat<float>(src, i);
            Mat<float> dst_mat = TensorND2Mat<float>(dst, i);
            if (src_mat.channels() == 1) {
                warp_perspective_cv_proxy<float, 1>(
                        src_mat.ptr(), dst_mat.ptr(), src_mat.rows(),
                        src_mat.cols(), dst_mat.rows(), dst_mat.cols(),
                        src_mat.step(), dst_mat.step(), bmode, imode, trans_ptr,
                        border_val, workspace_ptr, stream);
            } else {
                warp_perspective_cv_proxy<float, 3>(
                        src_mat.ptr(), dst_mat.ptr(), src_mat.rows(),
                        src_mat.cols(), dst_mat.rows(), dst_mat.cols(),
                        src_mat.step(), dst_mat.step(), bmode, imode, trans_ptr,
                        border_val, workspace_ptr, stream);
            }
        } else if (dst.layout.dtype == dtype::Uint8()) {
            Mat<uchar> src_mat = TensorND2Mat<uchar>(src, i);
            Mat<uchar> dst_mat = TensorND2Mat<uchar>(dst, i);
            if (src_mat.channels() == 1) {
                warp_perspective_cv_proxy<uchar, 1>(
                        src_mat.ptr(), dst_mat.ptr(), src_mat.rows(),
                        src_mat.cols(), dst_mat.rows(), dst_mat.cols(),
                        src_mat.step(), dst_mat.step(), bmode, imode, trans_ptr,
                        static_cast<uchar>(border_val), workspace_ptr, stream);
            } else {
                warp_perspective_cv_proxy<uchar, 3>(
                        src_mat.ptr(), dst_mat.ptr(), src_mat.rows(),
                        src_mat.cols(), dst_mat.rows(), dst_mat.cols(),
                        src_mat.step(), dst_mat.step(), bmode, imode, trans_ptr,
                        static_cast<uchar>(border_val), workspace_ptr, stream);
            }
        } else {
            megdnn_throw("Unsupported datatype of WarpPerspective optr.");
        }
        trans_ptr += 3 * 3;
        workspace_ptr += 3 * 3;
    }
}

}  // namespace warp_perspective

WorkspaceBundle WarpPerspectiveForwardImpl::get_workspace_bundle(
        void* ptr, const TensorLayout& src, const TensorLayout& mat,
        const TensorLayout& mat_idx, const TensorLayout& dst) const {
    MEGDNN_MARK_USED_VAR(mat_idx);
    SmallVector<size_t> sizes;
    TensorLayout fsrc = src;
    TensorLayout fmat = mat;
    TensorLayout fdst = dst;
    if (src.dtype.enumv() == DTypeEnum::QuantizedS4 &&
        param().format == param::WarpPerspective::Format::NCHW) {
        get_inner_layout(src, dst, fsrc, fdst, handle(), param().format);
        sizes.push_back(fsrc.span().dist_byte());
        sizes.push_back(fdst.span().dist_byte());
    } else {
        auto get_workspace = [&sizes](TensorLayout& layout) {
            if (layout.dtype == dtype::BFloat16()) {
                layout.dtype = dtype::Float32();
                sizes.push_back(layout.span().dist_byte());
            }
        };
        get_workspace(fsrc);
        get_workspace(fmat);
        get_workspace(fdst);
    }
    if (param().format == param::WarpPerspective::Format::NHWC) {
        //! use double for the workspace dtype as float may cause
        //! accuracy problems
        sizes.push_back(mat.total_nr_elems() * sizeof(double));
    }
    return {ptr, std::move(sizes)};
}

void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in ssrc,
                                      _megdnn_tensor_in smat,
                                      _megdnn_tensor_in smat_idx,
                                      _megdnn_tensor_out sdst,
                                      _megdnn_workspace sworkspace) {
    check_exec_allow_nhwc_mat_idx(ssrc.layout, smat.layout, smat_idx.layout,
                                  sdst.layout, sworkspace.size);
    TensorND src = ssrc;
    TensorND mat = smat;
    TensorND mat_idx = smat_idx;
    TensorND dst = sdst;
    Param::Format inner_format = param().format;
    auto bundle =
            get_workspace_bundle(sworkspace.raw_ptr, ssrc.layout, smat.layout,
                                 smat_idx.layout, sdst.layout);
    auto ctypecvt = CompTypeCvter<dtype::BFloat16, dtype::Float32>(
            concrete_handle(this->handle()), &bundle);
    if (ssrc.layout.dtype.enumv() == DTypeTrait<dtype::BFloat16>::enumv) {
        ctypecvt.src_to_comp_type(ssrc, src)
                .src_to_comp_type(smat, mat)
                .src_to_comp_type(sdst, dst);
    } else if (ssrc.layout.dtype.enumv() == DTypeEnum::QuantizedS4 &&
               param().format == Param::Format::NCHW) {
        auto handle_ptr = handle();
        get_inner_layout(ssrc.layout, sdst.layout, src.layout, dst.layout,
                         handle_ptr, param().format);
        src.raw_ptr = bundle.get(0);
        dst.raw_ptr = bundle.get(1);
        auto relayout_opr = handle_ptr->create_operator<RelayoutFormat>();
        RelayoutFormat::Param trans_param;
        trans_param.mode = RelayoutFormat::Param::Mode::NCHW_NCHW64;
        relayout_opr->param() = trans_param;
        relayout_opr->exec(ssrc, src, {});
        inner_format = Param::Format::NCHW64;
    }
    {
        auto stream = cuda_stream(this->handle());
        bool is_nhwc = inner_format == param::WarpPerspective::Format::NHWC;
        if (is_nhwc && param().imode != Param::InterpolationMode::LINEAR) {
            // use opencv impl only for nhwc and non-linear interp
            megdnn_assert(!mat_idx.raw_ptr,
                          "mat_idx is not supported in NHWC case with "
                          "non-linear interpolation");
            warp_perspective::warp_perspective_cv_exec(
                    src, mat, dst, param().border_val,
                    warp_perspective::get_bmode(param().bmode),
                    warp_perspective::get_imode(param().imode),
                    ctypecvt.workspace(), stream);
        } else {
            megdnn_assert(warp::is_dnn_available(src.layout, mat.layout,
                                                 dst.layout, param().imode,
                                                 inner_format));
            size_t C, IH, IW, OH, OW;
            if (is_nhwc) {
                C = src.layout.shape[3];
                IH = src.layout.shape[1];
                IW = src.layout.shape[2];
                OH = dst.layout.shape[1];
                OW = dst.layout.shape[2];
            } else if (inner_format == Param::Format::NCHW4) {
                C = src.layout.shape[1] * 4;
                IH = src.layout.shape[2];
                IW = src.layout.shape[3];
                OH = dst.layout.shape[2];
                OW = dst.layout.shape[3];
            } else if (inner_format == Param::Format::NHWC_NCHW) {
                C = src.layout.shape[3];
                IH = src.layout.shape[1];
                IW = src.layout.shape[2];
                OH = dst.layout.shape[2];
                OW = dst.layout.shape[3];
            } else if (inner_format == Param::Format::NHWC_NCHW4_IC_SMALL) {
                C = src.layout.shape[3];
                IH = src.layout.shape[1];
                IW = src.layout.shape[2];
                OH = dst.layout.shape[2];
                OW = dst.layout.shape[3];
                megdnn_assert(
                        (C == 1) || (C == 3),
                        "NHWC_NCHW4_IC_SMALL only support C == 1 or C == 3");
            } else if (inner_format == Param::Format::NCHW_NCHW4_IC_SMALL) {
                C = src.layout.shape[1];
                IH = src.layout.shape[2];
                IW = src.layout.shape[3];
                OH = dst.layout.shape[2];
                OW = dst.layout.shape[3];
                megdnn_assert(
                        (C == 1) || (C == 3),
                        "NCHW_NCHW4_IC_SMALL only support C == 1 or C == 3");
            } else if (inner_format == Param::Format::NCHW64) {
                C = src.layout.shape[1] * 64;
                IH = src.layout.shape[2];
                IW = src.layout.shape[3];
                OH = dst.layout.shape[2];
                OW = dst.layout.shape[3];
            } else {
                megdnn_assert(
                        inner_format == param::WarpPerspective::Format::NCHW,
                        "invalid warp_perspective format");
                C = src.layout.shape[1];
                IH = src.layout.shape[2];
                IW = src.layout.shape[3];
                OH = dst.layout.shape[2];
                OW = dst.layout.shape[3];
            }
            megdnn_assert(param().imode == Param::InterpolationMode::LINEAR,
                          "unsupported interpolation mode for NCHW format");
            auto bval = param().border_val;
            auto bmode = warp_perspective::get_bmode(param().bmode);
            if (src.layout.dtype == dst.layout.dtype) {
                if (src.layout.dtype == dtype::Float32{}) {
                    warp_perspective::forward_proxy(
                            is_nhwc, src.ptr<dt_float32>(),
                            mat.ptr<dt_float32>(),
                            mat_idx.raw_ptr ? mat_idx.ptr<int>() : nullptr,
                            dst.ptr<dt_float32>(), src.layout[0], mat.layout[0],
                            C, IH, IW, OH, OW, bval, bmode,
                            async_error_info(handle()), m_error_tracker,
                            stream);
                } else if (DNN_FLOAT16_SELECT(
                                   src.layout.dtype == dtype::Float16(),
                                   false)) {
#ifndef MEGDNN_DISABLE_FLOAT16
                    warp_perspective::forward_proxy(
                            is_nhwc, src.ptr<dt_float16>(),
                            mat.ptr<dt_float32>(),
                            mat_idx.raw_ptr ? mat_idx.ptr<int>() : nullptr,
                            dst.ptr<dt_float16>(), src.layout[0], mat.layout[0],
                            C, IH, IW, OH, OW, static_cast<dt_float16>(bval),
                            bmode, async_error_info(handle()), m_error_tracker,
                            stream);
#endif
                } else if (src.layout.dtype == dtype::Uint8()) {
                    warp_perspective::forward_proxy<dt_uint8>(
                            is_nhwc, src.ptr<dt_uint8>(), mat.ptr<dt_float32>(),
                            mat_idx.raw_ptr ? mat_idx.ptr<int>() : nullptr,
                            dst.ptr<dt_uint8>(), src.layout[0], mat.layout[0],
                            C, IH, IW, OH, OW, bval, bmode,
                            async_error_info(handle()), m_error_tracker,
                            stream);
                } else if (src.layout.dtype == dtype::Int8()) {
                    megdnn_assert(!is_nhwc,
                                  "WarpPerspective on CUDA does not support "
                                  "NHWC + Int8");
                    warp_perspective::forward_proxy<dt_int8>(
                            false, src.ptr<dt_int8>(), mat.ptr<dt_float32>(),
                            mat_idx.raw_ptr ? mat_idx.ptr<int>() : nullptr,
                            dst.ptr<dt_int8>(), src.layout[0], mat.layout[0], C,
                            IH, IW, OH, OW,
                            bval,  // implicit float -> int8 conversion,
                                   // should be safe
                            bmode, async_error_info(handle()), m_error_tracker,
                            stream);
                } else if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS8) {
                    megdnn_assert(param().format == Param::Format::NCHW4,
                                  "WarpPerspective on CUDA supports NCHW4 + "
                                  "QuantizedS8 only");
                    warp_perspective::forward_proxy_nchw4<dt_int8>(
                            src.compatible_ptr<dt_int8>(),
                            mat.ptr<dt_float32>(),
                            mat_idx.raw_ptr ? mat_idx.ptr<int>() : nullptr,
                            dst.compatible_ptr<dt_int8>(), src.layout[0],
                            mat.layout[0], C, IH, IW, OH, OW, bval, bmode,
                            async_error_info(handle()), m_error_tracker,
                            stream);
                } else if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS4) {
                    megdnn_assert(
                            param().format == Param::Format::NCHW64 ||
                                    param().format == Param::Format::NCHW,
                            "WarpPerspective on CUDA supports NCHW64 or "
                            "NCHW + QuantizedS4 only");
                    bval = roundf(bval);
                    bval = fmin(fmax(-8.f, bval), 7.f);
                    warp_perspective::forward_proxy_nchw64<dt_qint4>(
                            src.compatible_ptr<dt_qint4>(),
                            mat.ptr<dt_float32>(),
                            mat_idx.raw_ptr ? mat_idx.ptr<int>() : nullptr,
                            dst.compatible_ptr<dt_qint4>(), src.layout[0],
                            mat.layout[0], C, IH, IW, OH, OW,
                            static_cast<dt_qint4>(bval), bmode,
                            async_error_info(handle()), m_error_tracker,
                            stream);
                    if (param().format == Param::Format::NCHW) {
                        auto relayout_opr =
                                handle()->create_operator<RelayoutFormat>();
                        RelayoutFormat::Param trans_param;
                        trans_param.mode =
                                RelayoutFormat::Param::Mode::NCHW64_NCHW;
                        relayout_opr->param() = trans_param;
                        relayout_opr->exec(dst, sdst, {});
                    }
                }
            } else if ((src.layout.dtype.enumv() ==
                                DTypeEnum::Quantized8Asymm ||
                        src.layout.dtype.enumv() == DTypeEnum::Uint8)) {
                uint8_t zero_point = 0;
                float scale = 1.f;
                if (src.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) {
                    zero_point =
                            src.layout.dtype.param<dtype::Quantized8Asymm>()
                                    .zero_point;
                    scale = src.layout.dtype.param<dtype::Quantized8Asymm>()
                                    .scale;
                } else if (src.layout.dtype.enumv() == DTypeEnum::Uint8 &&
                           dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) {
                    zero_point = 128;
                    scale = 1.f;
                }
                DTypeParamImpl<dt_quint8> src_dtype_param(scale, zero_point);
                if ((dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8 &&
                     dst.layout.dtype.param<dtype::QuantizedS8>().scale ==
                             scale) &&
                    ((param().format == Param::Format::NCHW_NCHW4_IC_SMALL) ||
                     (param().format == Param::Format::NHWC_NCHW4_IC_SMALL))) {
                    bool is_nhwc_ic_small =
                            (param().format ==
                             Param::Format::NHWC_NCHW4_IC_SMALL);
                    warp_perspective::
                            forward_proxy_quint8_dimshuffle_typecvt_nchw4<
                                    dt_quint8, dt_uint8, dt_int8>(
                                    is_nhwc_ic_small,
                                    src.compatible_ptr<dt_uint8>(),
                                    mat.ptr<dt_float32>(),
                                    mat_idx.raw_ptr ? mat_idx.ptr<int>()
                                                    : nullptr,
                                    dst.compatible_ptr<dt_int8>(),
                                    src.layout[0], mat.layout[0], C, IH, IW, OH,
                                    OW, bval, src_dtype_param, bmode,
                                    async_error_info(handle()), m_error_tracker,
                                    stream);
                } else {
                    megdnn_assert(
                            ((dst.layout.dtype.enumv() == DTypeEnum::Float32) &&
                             ((param().format == Param::Format::NCHW) ||
                              (param().format == Param::Format::NHWC_NCHW))),
                            "invalid format for Quantized8Asymm input");
                    bool is_nhwc = (param().format == Param::Format::NHWC_NCHW);
                    warp_perspective::
                            forward_proxy_quint8_dimshuffle_typecvt_nchw<
                                    dt_quint8, dt_uint8, dt_float32>(
                                    is_nhwc, src.compatible_ptr<dt_uint8>(),
                                    mat.ptr<dt_float32>(),
                                    mat_idx.raw_ptr ? mat_idx.ptr<int>()
                                                    : nullptr,
                                    dst.compatible_ptr<dt_float32>(),
                                    src.layout[0], mat.layout[0], C, IH, IW, OH,
                                    OW, bval, src_dtype_param, bmode,
                                    async_error_info(handle()), m_error_tracker,
                                    stream);
                }
            } else {
                megdnn_throw(ssprintf("unsupported dtype: %s",
                                      src.layout.dtype.name()));
            }
        }
    }
    if (ssrc.layout.dtype.enumv() == DTypeTrait<dtype::BFloat16>::enumv) {
        ctypecvt.comp_to_dst_type(dst, sdst);
    }
}

}  // namespace cuda
}  // namespace megdnn

// vim: syntax=cpp.doxygen
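
For context, this operator is normally reached through MegDNN's generic operator interface rather than called directly. The sketch below mirrors the create_operator / param() / exec() pattern that the file itself uses for RelayoutFormat; it is a minimal illustration only. The "megdnn/oprs.h" header path, the get_workspace_in_bytes() query, and the Workspace construction follow the usual MegDNN operator conventions and are assumptions here, not taken from this file.

// Hypothetical caller-side sketch: drive WarpPerspectiveForward via a MegDNN
// handle. exec() then dispatches to the dtype/format branches implemented above.
#include "megdnn/oprs.h"  // assumed header for the operator declarations

using namespace megdnn;

void run_warp_perspective(Handle* handle, const TensorND& src,
                          const TensorND& mat, const TensorND& mat_idx,
                          const TensorND& dst, dt_byte* workspace_ptr) {
    auto opr = handle->create_operator<WarpPerspectiveForward>();
    opr->param().format = param::WarpPerspective::Format::NCHW;
    opr->param().imode = param::WarpPerspective::InterpolationMode::LINEAR;
    // query how much scratch memory exec() needs for these layouts
    // (get_workspace_in_bytes is the conventional MegDNN query, assumed here)
    size_t ws_size = opr->get_workspace_in_bytes(src.layout, mat.layout,
                                                 mat_idx.layout, dst.layout);
    opr->exec(src, mat, mat_idx, dst, Workspace{workspace_ptr, ws_size});
}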

The MegEngine package already bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU and GPU build to choose between. To run GPU programs, make sure the machine has a GPU device and that its driver is properly installed. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.
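
To verify that a CUDA device and driver are actually visible before relying on GPU code paths like the one above, a standalone check against the plain CUDA runtime API can be used. This is a minimal sketch, independent of MegEngine.

// Minimal sketch: confirm a CUDA-capable device and a working driver are present.
// Build with nvcc, or with a host compiler linked against the CUDA runtime.
#include <cstdio>
#include <cuda_runtime.h>

int main() {
    int device_count = 0;
    cudaError_t err = cudaGetDeviceCount(&device_count);
    if (err != cudaSuccess || device_count == 0) {
        std::printf("no usable CUDA device (%s)\n", cudaGetErrorString(err));
        return 1;
    }
    std::printf("found %d CUDA device(s)\n", device_count);
    return 0;
}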