
forward.cpp 31 kB

#include "src/cuda/warp_perspective/opr_impl.h"
#include "src/cuda/warp_perspective/warp_perspective_cv.cuh"

#include "src/cuda/utils.h"
#include "src/cuda/warp_perspective/common.h"
#include "src/cuda/warp_perspective/helper.h"

#include "src/common/cv/common.h"
#include "src/common/warp_common.h"

namespace megdnn {
namespace cuda {
namespace {
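// deduce the layout a RelayoutFormat opr would produce for the given mode;
// an empty source layout (ndim == 0) is passed through unchanged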
inline void deduce_reformat_layout(
        std::unique_ptr<RelayoutFormat>& relayout, const TensorLayout& src_layout,
        TensorLayout& dst_layout, RelayoutFormat::Param::Mode mode, const int oc = 0,
        const int group = 1) {
    if (src_layout.ndim > 0) {
        RelayoutFormat::Param trans_param;
        trans_param.mode = mode;
        trans_param.oc = oc;
        trans_param.group = group;
        relayout->param() = trans_param;
        relayout->deduce_layout(src_layout, dst_layout);
    } else {
        dst_layout = src_layout;
    }
}

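// compute the inner NCHW64 layouts used to stage 4-bit quantized NCHW
// tensors; any other dtype/format combination is rejected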
void get_inner_layout(
        const TensorLayout& src, const TensorLayout& dst, TensorLayout& inner_src,
        TensorLayout& inner_dst, Handle* handle,
        WarpPerspectiveForwardImpl::Param::Format format) {
    if ((src.dtype.enumv() == DTypeEnum::QuantizedS4 ||
         src.dtype.enumv() == DTypeEnum::Quantized4Asymm) &&
        dst.dtype.enumv() == src.dtype.enumv() &&
        format == param::WarpPerspective::Format::NCHW) {
        auto relayout_opr = handle->create_operator<RelayoutFormat>();
        deduce_reformat_layout(
                relayout_opr, src, inner_src, RelayoutFormat::Param::Mode::NCHW_NCHW64,
                0, 1);
        deduce_reformat_layout(
                relayout_opr, dst, inner_dst, RelayoutFormat::Param::Mode::NCHW_NCHW64,
                0, 1);
    } else {
        megdnn_assert(0, "not supported");
    }
}
}  // namespace

namespace warp_perspective {
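// OpenCV-style fallback: each image in the batch is warped with its own
// 3x3 transform; src must have 1 or 3 channels and Float32/Uint8 dtype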
void warp_perspective_cv_exec(
        _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in dst,
        float border_val, BorderMode bmode, InterpolationMode imode,
        _megdnn_workspace workspace, cudaStream_t stream) {
    megdnn_assert(src.layout[3] == 1 || src.layout[3] == 3, "unsupported src channel");
    megdnn_assert(
            src.layout.dtype == dtype::Float32() || src.layout.dtype == dtype::Uint8(),
            "unsupported src dtype");
    if (imode == InterpolationMode::INTER_AREA) {
        imode = InterpolationMode::INTER_LINEAR;
    }
    using namespace megcv;
    const float* trans_ptr = mat.ptr<dt_float32>();
    double* workspace_ptr = workspace.ptr<double>();
    for (size_t i = 0; i < src.layout.shape[0]; ++i) {
        if (dst.layout.dtype == dtype::Float32()) {
            Mat<float> src_mat = TensorND2Mat<float>(src, i);
            Mat<float> dst_mat = TensorND2Mat<float>(dst, i);
            if (src_mat.channels() == 1) {
                warp_perspective_cv_proxy<float, 1>(
                        src_mat.ptr(), dst_mat.ptr(), src_mat.rows(), src_mat.cols(),
                        dst_mat.rows(), dst_mat.cols(), src_mat.step(), dst_mat.step(),
                        bmode, imode, trans_ptr, border_val, workspace_ptr, stream);
            } else {
                warp_perspective_cv_proxy<float, 3>(
                        src_mat.ptr(), dst_mat.ptr(), src_mat.rows(), src_mat.cols(),
                        dst_mat.rows(), dst_mat.cols(), src_mat.step(), dst_mat.step(),
                        bmode, imode, trans_ptr, border_val, workspace_ptr, stream);
            }
        } else if (dst.layout.dtype == dtype::Uint8()) {
            Mat<uchar> src_mat = TensorND2Mat<uchar>(src, i);
            Mat<uchar> dst_mat = TensorND2Mat<uchar>(dst, i);
            if (src_mat.channels() == 1) {
                warp_perspective_cv_proxy<uchar, 1>(
                        src_mat.ptr(), dst_mat.ptr(), src_mat.rows(), src_mat.cols(),
                        dst_mat.rows(), dst_mat.cols(), src_mat.step(), dst_mat.step(),
                        bmode, imode, trans_ptr, static_cast<uchar>(border_val),
                        workspace_ptr, stream);
            } else {
                warp_perspective_cv_proxy<uchar, 3>(
                        src_mat.ptr(), dst_mat.ptr(), src_mat.rows(), src_mat.cols(),
                        dst_mat.rows(), dst_mat.cols(), src_mat.step(), dst_mat.step(),
                        bmode, imode, trans_ptr, static_cast<uchar>(border_val),
                        workspace_ptr, stream);
            }
        } else {
            megdnn_throw("Unsupported datatype of WarpPerspective optr.");
        }
        trans_ptr += 3 * 3;
        workspace_ptr += 3 * 3;
    }
}
}  // namespace warp_perspective

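// workspace for the single-source entry: staging buffers for the NCHW64
// copies of 4-bit quantized src/dst, Float32 copies of BFloat16 tensors,
// and (for NHWC) a double-precision copy of the transform matrices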
WorkspaceBundle WarpPerspectiveForwardImpl::get_workspace_bundle(
        void* ptr, const TensorLayout& src, const TensorLayout& mat,
        const TensorLayout& mat_idx, const TensorLayout& dst) const {
    MEGDNN_MARK_USED_VAR(mat_idx);
    SmallVector<size_t> sizes;
    TensorLayout fsrc = src;
    TensorLayout fmat = mat;
    TensorLayout fdst = dst;
    if ((src.dtype.enumv() == DTypeEnum::QuantizedS4 ||
         src.dtype.enumv() == DTypeEnum::Quantized4Asymm) &&
        param().format == param::WarpPerspective::Format::NCHW) {
        get_inner_layout(src, dst, fsrc, fdst, handle(), param().format);
        sizes.push_back(fsrc.span().dist_byte());
        sizes.push_back(fdst.span().dist_byte());
    } else {
        auto get_workspace = [&sizes](TensorLayout& layout) {
            if (layout.dtype == dtype::BFloat16()) {
                layout.dtype = dtype::Float32();
                sizes.push_back(layout.span().dist_byte());
            }
        };
        get_workspace(fsrc);
        get_workspace(fmat);
        get_workspace(fdst);
    }
    if (param().format == param::WarpPerspective::Format::NHWC) {
        //! use double for the workspace dtype as float may cause
        //! accuracy problems
        sizes.push_back(mat.total_nr_elems() * sizeof(double));
    }
    return {ptr, std::move(sizes)};
}

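// multi-source overload: same staging rules as above, plus one slot for a
// device-side array of per-source input pointers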
WorkspaceBundle WarpPerspectiveForwardImpl::get_workspace_bundle(
        void* ptr, const TensorLayoutArray& srcs, const TensorLayout& mat,
        const TensorLayout& mat_idx, const TensorLayout& dst) const {
    MEGDNN_MARK_USED_VAR(mat_idx);
    SmallVector<size_t> sizes;
    TensorLayoutArray fsrcs = srcs;
    TensorLayout fmat = mat;
    TensorLayout fdst = dst;
    auto get_workspace = [&sizes](TensorLayout& layout) {
        if (layout.dtype == dtype::BFloat16()) {
            layout.dtype = dtype::Float32();
            sizes.push_back(layout.span().dist_byte());
        }
    };
    for (auto&& fsrc : fsrcs) {
        get_workspace(fsrc);
    }
    get_workspace(fmat);
    get_workspace(fdst);
    sizes.push_back(sizeof(dt_float32*) * srcs.size());
    if (param().format == param::WarpPerspective::Format::NHWC) {
        //! use double for the workspace dtype as float may cause
        //! accuracy problems
        sizes.push_back(mat.total_nr_elems() * sizeof(double));
    }
    return {ptr, std::move(sizes)};
}

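// single-source forward: converts BFloat16 to Float32 and relayouts 4-bit
// quantized NCHW inputs to NCHW64 before dispatching to the CUDA kernels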
void WarpPerspectiveForwardImpl::exec(
        _megdnn_tensor_in ssrc, _megdnn_tensor_in smat, _megdnn_tensor_in smat_idx,
        _megdnn_tensor_out sdst, _megdnn_workspace sworkspace) {
    check_exec_allow_nhwc_mat_idx(
            ssrc.layout, smat.layout, smat_idx.layout, sdst.layout, sworkspace.size);
    TensorND src = ssrc;
    TensorND mat = smat;
    TensorND mat_idx = smat_idx;
    TensorND dst = sdst;
    Param::Format inner_format = param().format;
    auto bundle = get_workspace_bundle(
            sworkspace.raw_ptr, ssrc.layout, smat.layout, smat_idx.layout, sdst.layout);
    auto ctypecvt = CompTypeCvter<dtype::BFloat16, dtype::Float32>(
            concrete_handle(this->handle()), &bundle);
    if (ssrc.layout.dtype.enumv() == DTypeTrait<dtype::BFloat16>::enumv) {
        ctypecvt.src_to_comp_type(ssrc, src)
                .src_to_comp_type(smat, mat)
                .src_to_comp_type(sdst, dst);
    } else if (
            (ssrc.layout.dtype.enumv() == DTypeEnum::QuantizedS4 ||
             ssrc.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) &&
            param().format == Param::Format::NCHW) {
        auto handle_ptr = handle();
        get_inner_layout(
                ssrc.layout, sdst.layout, src.layout, dst.layout, handle_ptr,
                param().format);
        src = TensorND{bundle.get(0), src.layout};
        dst = TensorND{bundle.get(1), dst.layout};
        auto relayout_opr = handle_ptr->create_operator<RelayoutFormat>();
        RelayoutFormat::Param trans_param;
        trans_param.mode = RelayoutFormat::Param::Mode::NCHW_NCHW64;
        relayout_opr->param() = trans_param;
        relayout_opr->exec(ssrc, src, {});
        inner_format = Param::Format::NCHW64;
    }
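    // dispatch: the OpenCV fallback handles NHWC with non-linear
    // interpolation; all other paths derive C/IH/IW/OH/OW from the inner
    // format and pick a kernel by dtype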
    {
        auto stream = cuda_stream(this->handle());
        bool is_nhwc = inner_format == param::WarpPerspective::Format::NHWC;
        if (is_nhwc && param().imode != Param::InterpolationMode::LINEAR) {
            // use opencv impl only for nhwc and non-linear interp
            megdnn_assert(
                    !mat_idx.raw_ptr(),
                    "mat_idx is not supported in NHWC case with "
                    "non-linear interpolation");
            warp_perspective::warp_perspective_cv_exec(
                    src, mat, dst, param().border_val,
                    warp_perspective::get_bmode(param().bmode),
                    warp_perspective::get_imode(param().imode), ctypecvt.workspace(),
                    stream);
        } else {
            megdnn_assert(warp::is_dnn_available(
                    src.layout, mat.layout, dst.layout, param().imode, inner_format));
            size_t C, IH, IW, OH, OW;
            if (is_nhwc) {
                C = src.layout.shape[3];
                IH = src.layout.shape[1];
                IW = src.layout.shape[2];
                OH = dst.layout.shape[1];
                OW = dst.layout.shape[2];
            } else if (inner_format == Param::Format::NCHW4) {
                C = src.layout.shape[1] * 4;
                IH = src.layout.shape[2];
                IW = src.layout.shape[3];
                OH = dst.layout.shape[2];
                OW = dst.layout.shape[3];
            } else if (inner_format == Param::Format::NHWC_NCHW) {
                C = src.layout.shape[3];
                IH = src.layout.shape[1];
                IW = src.layout.shape[2];
                OH = dst.layout.shape[2];
                OW = dst.layout.shape[3];
            } else if (inner_format == Param::Format::NHWC_NCHW4_IC_SMALL) {
                C = src.layout.shape[3];
                IH = src.layout.shape[1];
                IW = src.layout.shape[2];
                OH = dst.layout.shape[2];
                OW = dst.layout.shape[3];
                megdnn_assert(
                        (C == 1) || (C == 3),
                        "NHWC_NCHW4_IC_SMALL only support C == 1 or C == 3");
            } else if (inner_format == Param::Format::NCHW_NCHW4_IC_SMALL) {
                C = src.layout.shape[1];
                IH = src.layout.shape[2];
                IW = src.layout.shape[3];
                OH = dst.layout.shape[2];
                OW = dst.layout.shape[3];
                megdnn_assert(
                        (C == 1) || (C == 3),
                        "NCHW_NCHW4_IC_SMALL only support C == 1 or C == 3");
            } else if (inner_format == Param::Format::NCHW64) {
                C = src.layout.shape[1] * 64;
                IH = src.layout.shape[2];
                IW = src.layout.shape[3];
                OH = dst.layout.shape[2];
                OW = dst.layout.shape[3];
            } else {
                megdnn_assert(
                        inner_format == param::WarpPerspective::Format::NCHW,
                        "invalid warp_perspective format");
                C = src.layout.shape[1];
                IH = src.layout.shape[2];
                IW = src.layout.shape[3];
                OH = dst.layout.shape[2];
                OW = dst.layout.shape[3];
            }
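            // the CUDA kernels below implement bilinear interpolation only;
            // border values are clamped/cast per dtype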
            megdnn_assert(
                    param().imode == Param::InterpolationMode::LINEAR,
                    "unsupported interpolation mode for NCHW format");
            auto bval = param().border_val;
            auto bmode = warp_perspective::get_bmode(param().bmode);
            if (src.layout.dtype == dst.layout.dtype) {
                if (src.layout.dtype == dtype::Float32{}) {
                    warp_perspective::forward_proxy(
                            is_nhwc, src.ptr<dt_float32>(), mat.ptr<dt_float32>(),
                            mat_idx.raw_ptr() ? mat_idx.ptr<int>() : nullptr,
                            dst.ptr<dt_float32>(), src.layout[0], mat.layout[0], C, IH,
                            IW, OH, OW, bval, bmode, async_error_info(handle()),
                            m_error_tracker, stream);
                } else if (DNN_FLOAT16_SELECT(
                                   src.layout.dtype == dtype::Float16(), false)) {
#ifndef MEGDNN_DISABLE_FLOAT16
                    warp_perspective::forward_proxy(
                            is_nhwc, src.ptr<dt_float16>(), mat.ptr<dt_float32>(),
                            mat_idx.raw_ptr() ? mat_idx.ptr<int>() : nullptr,
                            dst.ptr<dt_float16>(), src.layout[0], mat.layout[0], C, IH,
                            IW, OH, OW, static_cast<dt_float16>(bval), bmode,
                            async_error_info(handle()), m_error_tracker, stream);
#endif
                } else if (src.layout.dtype == dtype::Uint8()) {
                    warp_perspective::forward_proxy<dt_uint8>(
                            is_nhwc, src.ptr<dt_uint8>(), mat.ptr<dt_float32>(),
                            mat_idx.raw_ptr() ? mat_idx.ptr<int>() : nullptr,
                            dst.ptr<dt_uint8>(), src.layout[0], mat.layout[0], C, IH,
                            IW, OH, OW, bval, bmode, async_error_info(handle()),
                            m_error_tracker, stream);
                } else if (src.layout.dtype == dtype::Int8()) {
                    megdnn_assert(
                            !is_nhwc,
                            "WarpPerspective on CUDA does not support "
                            "NHWC + Int8");
                    warp_perspective::forward_proxy<dt_int8>(
                            false, src.ptr<dt_int8>(), mat.ptr<dt_float32>(),
                            mat_idx.raw_ptr() ? mat_idx.ptr<int>() : nullptr,
                            dst.ptr<dt_int8>(), src.layout[0], mat.layout[0], C, IH, IW,
                            OH, OW,
                            bval /* implicit float -> int8 conversion, should be safe */,
                            bmode, async_error_info(handle()), m_error_tracker, stream);
                } else if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS8) {
                    megdnn_assert(
                            param().format == Param::Format::NCHW4,
                            "WarpPerspective on CUDA supports NCHW4 + "
                            "QuantizedS8 only");
                    warp_perspective::forward_proxy_nchw4<dt_int8>(
                            src.compatible_ptr<dt_int8>(), mat.ptr<dt_float32>(),
                            mat_idx.raw_ptr() ? mat_idx.ptr<int>() : nullptr,
                            dst.compatible_ptr<dt_int8>(), src.layout[0], mat.layout[0],
                            C, IH, IW, OH, OW, bval, bmode, async_error_info(handle()),
                            m_error_tracker, stream);
                } else if (
                        (src.layout.dtype.enumv() == DTypeEnum::QuantizedS4) &&
                        (param().format == Param::Format::NCHW64 ||
                         param().format == Param::Format::NCHW)) {
                    bval = roundf(bval);
                    bval = fmin(fmax(-8.f, bval), 7.f);
                    warp_perspective::forward_proxy_nchw64<dt_qint4>(
                            src.compatible_ptr<dt_qint4>(), mat.ptr<dt_float32>(),
                            mat_idx.raw_ptr() ? mat_idx.ptr<int>() : nullptr,
                            dst.compatible_ptr<dt_qint4>(), src.layout[0],
                            mat.layout[0], C, IH, IW, OH, OW,
                            static_cast<dt_qint4>(bval), bmode,
                            async_error_info(handle()), m_error_tracker, stream);
                    if (param().format == Param::Format::NCHW) {
                        auto relayout_opr = handle()->create_operator<RelayoutFormat>();
                        RelayoutFormat::Param trans_param;
                        trans_param.mode = RelayoutFormat::Param::Mode::NCHW64_NCHW;
                        trans_param.oc = sdst.layout[1];
                        relayout_opr->param() = trans_param;
                        relayout_opr->exec(dst, sdst, {});
                    }
                } else if (
                        (src.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) &&
                        (param().format == Param::Format::NCHW64 ||
                         param().format == Param::Format::NCHW)) {
                    bval = roundf(bval);
                    bval = fmin(fmax(0, bval), 15);
                    warp_perspective::forward_proxy_nchw64<dt_quint4>(
                            src.compatible_ptr<dt_quint4>(), mat.ptr<dt_float32>(),
                            mat_idx.raw_ptr() ? mat_idx.ptr<int>() : nullptr,
                            dst.compatible_ptr<dt_quint4>(), src.layout[0],
                            mat.layout[0], C, IH, IW, OH, OW,
                            static_cast<dt_quint4>(bval), bmode,
                            async_error_info(handle()), m_error_tracker, stream);
                    if (param().format == Param::Format::NCHW) {
                        auto relayout_opr = handle()->create_operator<RelayoutFormat>();
                        RelayoutFormat::Param trans_param;
                        trans_param.mode = RelayoutFormat::Param::Mode::NCHW64_NCHW;
                        trans_param.oc = sdst.layout[1];
                        relayout_opr->param() = trans_param;
                        relayout_opr->exec(dst, sdst, {});
                    }
                } else if (
                        (src.layout.dtype.enumv() == DTypeEnum::QuantizedS4 ||
                         src.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) &&
                        (param().format == Param::Format::NHWC)) {
                    constexpr int pack_c = 8;
                    megdnn_assert(C % pack_c == 0);
                    bval = roundf(bval);
                    if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS4) {
                        bval = fmin(fmax(-8.f, bval), 7.f);
                        if (C % 16 == 0) {
                            warp_perspective::forward_proxy_nhwc_bit4<dt_qint4, 16>(
                                    src.ptr<dt_qint4>(), mat.ptr<dt_float32>(),
                                    mat_idx.raw_ptr() ? mat_idx.ptr<int>() : nullptr,
                                    dst.ptr<dt_qint4>(), src.layout[0], mat.layout[0],
                                    C, IH, IW, OH, OW, static_cast<dt_qint4>(bval),
                                    bmode, async_error_info(handle()), m_error_tracker,
                                    stream);
                        } else {
                            warp_perspective::forward_proxy_nhwc_bit4<dt_qint4, pack_c>(
                                    src.ptr<dt_qint4>(), mat.ptr<dt_float32>(),
                                    mat_idx.raw_ptr() ? mat_idx.ptr<int>() : nullptr,
                                    dst.ptr<dt_qint4>(), src.layout[0], mat.layout[0],
                                    C, IH, IW, OH, OW, static_cast<dt_qint4>(bval),
                                    bmode, async_error_info(handle()), m_error_tracker,
                                    stream);
                        }
                    } else {
                        bval = fmin(fmax(0.f, bval), 15.f);
                        if (C % 16 == 0) {
                            warp_perspective::forward_proxy_nhwc_bit4<dt_quint4, 16>(
                                    src.ptr<dt_quint4>(), mat.ptr<dt_float32>(),
                                    mat_idx.raw_ptr() ? mat_idx.ptr<int>() : nullptr,
                                    dst.ptr<dt_quint4>(), src.layout[0], mat.layout[0],
                                    C, IH, IW, OH, OW, static_cast<dt_quint4>(bval),
                                    bmode, async_error_info(handle()), m_error_tracker,
                                    stream);
                        } else {
                            warp_perspective::forward_proxy_nhwc_bit4<
                                    dt_quint4, pack_c>(
                                    src.ptr<dt_quint4>(), mat.ptr<dt_float32>(),
                                    mat_idx.raw_ptr() ? mat_idx.ptr<int>() : nullptr,
                                    dst.ptr<dt_quint4>(), src.layout[0], mat.layout[0],
                                    C, IH, IW, OH, OW, static_cast<dt_quint4>(bval),
                                    bmode, async_error_info(handle()), m_error_tracker,
                                    stream);
                        }
                    }
                }
            } else if (
                    (src.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm ||
                     src.layout.dtype.enumv() == DTypeEnum::Uint8)) {
                uint8_t zero_point = 0;
                float scale = 1.f;
                if (src.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) {
                    zero_point =
                            src.layout.dtype.param<dtype::Quantized8Asymm>().zero_point;
                    scale = src.layout.dtype.param<dtype::Quantized8Asymm>().scale;
                } else if (
                        src.layout.dtype.enumv() == DTypeEnum::Uint8 &&
                        dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) {
                    zero_point = 128;
                    scale = 1.f;
                }
                DTypeParamImpl<dt_quint8> src_dtype_param(scale, zero_point);
                if ((dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8 &&
                     dst.layout.dtype.param<dtype::QuantizedS8>().scale == scale) &&
                    ((param().format == Param::Format::NCHW_NCHW4_IC_SMALL) ||
                     (param().format == Param::Format::NHWC_NCHW4_IC_SMALL))) {
                    bool is_nhwc_ic_small =
                            (param().format == Param::Format::NHWC_NCHW4_IC_SMALL);
                    warp_perspective::forward_proxy_quint8_dimshuffle_typecvt_nchw4<
                            dt_quint8, dt_uint8, dt_int8>(
                            is_nhwc_ic_small, src.compatible_ptr<dt_uint8>(),
                            mat.ptr<dt_float32>(),
                            mat_idx.raw_ptr() ? mat_idx.ptr<int>() : nullptr,
                            dst.compatible_ptr<dt_int8>(), src.layout[0], mat.layout[0],
                            C, IH, IW, OH, OW, bval, src_dtype_param, bmode,
                            async_error_info(handle()), m_error_tracker, stream);
                } else {
                    megdnn_assert(
                            ((dst.layout.dtype.enumv() == DTypeEnum::Float32) &&
                             ((param().format == Param::Format::NCHW) ||
                              (param().format == Param::Format::NHWC_NCHW))),
                            "invalid format for Quantized8Asymm input");
                    bool is_nhwc = (param().format == Param::Format::NHWC_NCHW);
                    warp_perspective::forward_proxy_quint8_dimshuffle_typecvt_nchw<
                            dt_quint8, dt_uint8, dt_float32>(
                            is_nhwc, src.compatible_ptr<dt_uint8>(),
                            mat.ptr<dt_float32>(),
                            mat_idx.raw_ptr() ? mat_idx.ptr<int>() : nullptr,
                            dst.compatible_ptr<dt_float32>(), src.layout[0],
                            mat.layout[0], C, IH, IW, OH, OW, bval, src_dtype_param,
                            bmode, async_error_info(handle()), m_error_tracker, stream);
                }
            } else {
                megdnn_throw(
                        ssprintf("unsupported dtype: %s", src.layout.dtype.name()));
            }
        }
    }
    if (ssrc.layout.dtype.enumv() == DTypeTrait<dtype::BFloat16>::enumv) {
        ctypecvt.comp_to_dst_type(dst, sdst);
    }
}

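// multi-source forward: the per-source device pointers are gathered into a
// host buffer, copied into the workspace asynchronously, and the host buffer
// is freed from a stream callback after the copy completes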
void WarpPerspectiveForwardImpl::exec(
        _megdnn_in const TensorNDArray& ssrcs, _megdnn_tensor_in smat,
        _megdnn_tensor_in smat_idx, _megdnn_tensor_out sdst,
        _megdnn_workspace sworkspace) {
    TensorLayoutArray ssrcs_layout;
    for (auto&& s : ssrcs) {
        ssrcs_layout.push_back(s.layout);
    }
    check_exec_allow_nhwc_mat_idx(
            ssrcs_layout, smat.layout, smat_idx.layout, sdst.layout, sworkspace.size);
    TensorNDArray srcs = ssrcs;
    TensorND mat = smat;
    TensorND mat_idx = smat_idx;
    TensorND dst = sdst;
    Param::Format inner_format = param().format;
    auto bundle = get_workspace_bundle(
            sworkspace.raw_ptr, ssrcs_layout, smat.layout, smat_idx.layout,
            sdst.layout);
    auto ctypecvt = CompTypeCvter<dtype::BFloat16, dtype::Float32>(
            concrete_handle(this->handle()), &bundle);
    if (ssrcs.front().layout.dtype.enumv() == DTypeTrait<dtype::BFloat16>::enumv) {
        for (size_t i = 0; i < ssrcs.size(); i++) {
            ctypecvt.src_to_comp_type(ssrcs[i], srcs[i]);
        }
        ctypecvt.src_to_comp_type(smat, mat).src_to_comp_type(sdst, dst);
    }
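    // shapes are taken from the first source; the kernel consumes a single
    // C/IH/IW for the whole batch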
    {
        auto stream = cuda_stream(this->handle());
        bool is_nhwc = inner_format == param::WarpPerspective::Format::NHWC;
        TensorND src = srcs.front();
        megdnn_assert(warp::is_dnn_available(
                ssrcs_layout, mat.layout, dst.layout, param().imode, inner_format));
        size_t C, IH, IW, OH, OW;
        if (is_nhwc) {
            C = src.layout.shape[3];
            IH = src.layout.shape[1];
            IW = src.layout.shape[2];
            OH = dst.layout.shape[1];
            OW = dst.layout.shape[2];
        } else {
            megdnn_assert(
                    inner_format == param::WarpPerspective::Format::NCHW,
                    "invalid warp_perspective format");
            C = src.layout.shape[1];
            IH = src.layout.shape[2];
            IW = src.layout.shape[3];
            OH = dst.layout.shape[2];
            OW = dst.layout.shape[3];
        }
        megdnn_assert(
                param().imode == Param::InterpolationMode::LINEAR,
                "unsupported interpolation mode for NCHW format");
        auto bval = param().border_val;
        auto bmode = warp_perspective::get_bmode(param().bmode);
        if (src.layout.dtype == dst.layout.dtype) {
            if (src.layout.dtype == dtype::Float32{}) {
                SmallVector<size_t> workspace_sizes{sizeof(dt_float32*) * srcs.size()};
                WorkspaceBundle workspace_cpu(nullptr, workspace_sizes);
                auto total_workspace_size = workspace_cpu.total_size_in_bytes();
                void* workspace_cpu_raw = malloc(total_workspace_size);
                workspace_cpu = WorkspaceBundle(workspace_cpu_raw, workspace_sizes);
                auto srcs_cpu = static_cast<const dt_float32**>(workspace_cpu.get(0));
                size_t ptrs_idx =
                        is_nhwc ? bundle.nr_workspace() - 2 : bundle.nr_workspace() - 1;
                auto srcs_gpu = static_cast<const dt_float32**>(bundle.get(ptrs_idx));
                for (size_t i = 0; i < srcs.size(); ++i) {
                    srcs_cpu[i] = srcs[i].ptr<dt_float32>();
                }
                cuda_check(cudaMemcpyAsync(
                        bundle.get(ptrs_idx), workspace_cpu.get(0),
                        workspace_cpu.get_size(0), cudaMemcpyHostToDevice, stream));
                cuda_check(cudaStreamAddCallback(
                        stream, callback_free, static_cast<void*>(workspace_cpu_raw),
                        0));
                warp_perspective::forward_proxy_multi_src(
                        is_nhwc, srcs_gpu, mat.ptr<dt_float32>(),
                        mat_idx.raw_ptr() ? mat_idx.ptr<int>() : nullptr,
                        dst.ptr<dt_float32>(), srcs.size(), mat.layout[0], C, IH, IW,
                        OH, OW, bval, bmode, async_error_info(handle()),
                        m_error_tracker, stream);
            } else if (DNN_FLOAT16_SELECT(
                               src.layout.dtype == dtype::Float16(), false)) {
#ifndef MEGDNN_DISABLE_FLOAT16
                SmallVector<size_t> workspace_sizes{sizeof(dt_float16*) * srcs.size()};
                WorkspaceBundle workspace_cpu(nullptr, workspace_sizes);
                auto total_workspace_size = workspace_cpu.total_size_in_bytes();
                void* workspace_cpu_raw = malloc(total_workspace_size);
                workspace_cpu = WorkspaceBundle(workspace_cpu_raw, workspace_sizes);
                auto srcs_cpu = static_cast<const dt_float16**>(workspace_cpu.get(0));
                auto srcs_gpu = static_cast<const dt_float16**>(bundle.get(0));
                for (size_t i = 0; i < srcs.size(); ++i) {
                    srcs_cpu[i] = srcs[i].ptr<dt_float16>();
                }
                cuda_check(cudaMemcpyAsync(
                        bundle.get(0), workspace_cpu.get(0), workspace_cpu.get_size(0),
                        cudaMemcpyHostToDevice, stream));
                cuda_check(cudaStreamAddCallback(
                        stream, callback_free, static_cast<void*>(workspace_cpu_raw),
                        0));
                warp_perspective::forward_proxy_multi_src(
                        is_nhwc, srcs_gpu, mat.ptr<dt_float32>(),
                        mat_idx.raw_ptr() ? mat_idx.ptr<int>() : nullptr,
                        dst.ptr<dt_float16>(), srcs.size(), mat.layout[0], C, IH, IW,
                        OH, OW, static_cast<dt_float16>(bval), bmode,
                        async_error_info(handle()), m_error_tracker, stream);
#endif
            }
        } else {
            megdnn_throw(ssprintf("unsupported dtype: %s", src.layout.dtype.name()));
        }
    }
    if (ssrcs.front().layout.dtype.enumv() == DTypeTrait<dtype::BFloat16>::enumv) {
        ctypecvt.comp_to_dst_type(dst, sdst);
    }
}

}  // namespace cuda
}  // namespace megdnn

// vim: syntax=cpp.doxygen