You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

warp_perspective.cpp 17 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350
  1. /**
  2. * \file dnn/src/common/warp_perspective.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "megdnn/oprs.h"
  13. #include "src/common/utils.h"
  14. namespace megdnn {
  15. void WarpPerspectiveBase::check_layout_fwd(const TensorLayout& src,
  16. const TensorLayout& mat,
  17. const TensorLayout& mat_idx,
  18. const TensorLayout& dst) {
  19. megdnn_assert_contiguous(mat);
  20. megdnn_assert_contiguous(src);
  21. megdnn_assert_contiguous(dst);
  22. auto errmsg = [&]() {
  23. return megdnn_layout_msg(src) + ", " + megdnn_layout_msg(mat) + ", " +
  24. megdnn_layout_msg(mat_idx) + ", " + megdnn_layout_msg(dst) +
  25. ", " + param_msg();
  26. };
  27. MEGDNN_MARK_USED_VAR(errmsg);
  28. if (param().format == param::WarpPerspective::Format::NHWCD4 ||
  29. param().format == param::WarpPerspective::Format::NCHW4 ||
  30. param().format == param::WarpPerspective::Format::NCHW64) {
  31. megdnn_assert(src.ndim == 5_z, "%s", errmsg().c_str());
  32. megdnn_assert(dst.ndim == 5_z, "%s", errmsg().c_str());
  33. } else if (param().format ==
  34. param::WarpPerspective::Format::NHWC_NCHW4_IC_SMALL ||
  35. param().format ==
  36. param::WarpPerspective::Format::NCHW_NCHW4_IC_SMALL) {
  37. megdnn_assert(src.ndim == 4_z, "%s", errmsg().c_str());
  38. megdnn_assert(dst.ndim == 5_z, "%s", errmsg().c_str());
  39. } else {
  40. megdnn_assert(param().format == param::WarpPerspective::Format::NHWC ||
  41. param().format == param::WarpPerspective::Format::NCHW ||
  42. param().format ==
  43. param::WarpPerspective::Format::NHWC_NCHW);
  44. megdnn_assert(src.ndim == 4_z, "%s", errmsg().c_str());
  45. megdnn_assert(dst.ndim == 4_z, "%s", errmsg().c_str());
  46. }
  47. megdnn_assert(mat.ndim == 3_z, "%s", errmsg().c_str());
  48. megdnn_assert(dst.shape[0] == mat.shape[0], "%s", errmsg().c_str());
  49. if (mat_idx.ndim) {
  50. megdnn_assert(mat_idx.dtype == dtype::Int32() && mat_idx.ndim == 1,
  51. "%s", errmsg().c_str());
  52. megdnn_assert(mat.shape[0] == mat_idx.shape[0], "%s", errmsg().c_str());
  53. megdnn_assert_contiguous(mat_idx);
  54. } else {
  55. megdnn_assert(src.shape[0] == dst.shape[0], "%s", errmsg().c_str());
  56. }
  57. megdnn_assert(mat.shape[1] == 3_z, "%s", errmsg().c_str());
  58. megdnn_assert(mat.shape[2] == 3_z, "%s", errmsg().c_str());
  59. if (src.format == dst.format && dst.dtype == src.dtype) {
  60. if (param().format == param::WarpPerspective::Format::NCHW) {
  61. megdnn_assert(
  62. src.dtype.enumv() == DTypeEnum::Float32 ||
  63. DNN_FLOAT16_SELECT(
  64. (src.dtype.enumv() == DTypeEnum::Float16 ||
  65. src.dtype.enumv() == DTypeEnum::BFloat16),
  66. false) ||
  67. src.dtype.enumv() == DTypeEnum::Int8 ||
  68. src.dtype.enumv() == DTypeEnum::Uint8 ||
  69. (src.dtype.enumv() == DTypeEnum::QuantizedS8 ||
  70. src.dtype.enumv() == DTypeEnum::Quantized8Asymm) ||
  71. src.dtype.enumv() == DTypeEnum::QuantizedS4 ||
  72. src.dtype.enumv() == DTypeEnum::Quantized4Asymm,
  73. "WarpPerspective NCHW input dtype should be "
  74. "Float32/Int8/Uint8/QInt8/QUint8/QInt4/QUInt4" DNN_FLOAT16_SELECT(
  75. "/Float16/BFloat16", "") ".");
  76. megdnn_assert(
  77. (src.dtype.category() == DTypeCategory::FLOAT &&
  78. (src.dtype == mat.dtype ||
  79. mat.dtype.enumv() == DTypeEnum::Float32)) ||
  80. ((src.dtype.category() == DTypeCategory::INT ||
  81. src.dtype.category() ==
  82. DTypeCategory::QUANTIZED) &&
  83. mat.dtype.enumv() == DTypeEnum::Float32),
  84. "The input to WarpPerspective is in NCHW format, in this "
  85. "case, if the input dtype is floating point, the "
  86. "transformation matrix should have same dtype as the "
  87. "input, otherwise, it should be in Float32, %s given.",
  88. mat.dtype.name());
  89. megdnn_assert(src.shape[1] == dst.shape[1], "%s", errmsg().c_str());
  90. megdnn_assert(param().imode ==
  91. param::WarpPerspective::InterpolationMode::LINEAR);
  92. megdnn_assert(param().bmode !=
  93. param::WarpPerspective::BorderMode::TRANSPARENT);
  94. megdnn_assert(param().bmode !=
  95. param::WarpPerspective::BorderMode::ISOLATED);
  96. } else if (param().format == param::WarpPerspective::Format::NHWC) {
  97. megdnn_assert(src.shape[3] == dst.shape[3], "%s", errmsg().c_str());
  98. } else if (param().format == param::WarpPerspective::Format::NCHW4) {
  99. megdnn_assert(src.dtype.enumv() == DTypeEnum::QuantizedS8,
  100. "src expected QuantizedS8, but got %s",
  101. src.dtype.name());
  102. megdnn_assert(mat.dtype == dtype::Float32(),
  103. "matrix dtype expected float, got %s",
  104. mat.dtype.name());
  105. megdnn_assert(src.shape[4] == 4 && dst.shape[4] == 4);
  106. megdnn_assert(src.shape[1] == dst.shape[1], "%s", errmsg().c_str());
  107. megdnn_assert(param().imode ==
  108. param::WarpPerspective::InterpolationMode::LINEAR);
  109. megdnn_assert(param().bmode !=
  110. param::WarpPerspective::BorderMode::TRANSPARENT);
  111. megdnn_assert(param().bmode !=
  112. param::WarpPerspective::BorderMode::ISOLATED);
  113. } else if (param().format == param::WarpPerspective::Format::NCHW64) {
  114. megdnn_assert((src.dtype.enumv() == DTypeEnum::QuantizedS4 ||
  115. src.dtype.enumv() == DTypeEnum::Quantized4Asymm),
  116. "src expected QuantizedS4/Quantized4Asymm, but got %s",
  117. src.dtype.name());
  118. megdnn_assert(mat.dtype == dtype::Float32(),
  119. "matrix dtype expected float, got %s",
  120. mat.dtype.name());
  121. megdnn_assert(src.shape[4] == 64 && dst.shape[4] == 64);
  122. megdnn_assert(src.shape[1] == dst.shape[1], "%s", errmsg().c_str());
  123. megdnn_assert(param().imode ==
  124. param::WarpPerspective::InterpolationMode::LINEAR);
  125. megdnn_assert(param().bmode !=
  126. param::WarpPerspective::BorderMode::TRANSPARENT);
  127. megdnn_assert(param().bmode !=
  128. param::WarpPerspective::BorderMode::ISOLATED);
  129. } else {
  130. megdnn_assert(param().format ==
  131. param::WarpPerspective::Format::NHWCD4);
  132. megdnn_assert(
  133. src.dtype == dtype::Float32() ||
  134. DNN_FLOAT16_SELECT((src.dtype == dtype::Float16() ||
  135. src.dtype == dtype::BFloat16()),
  136. false) ||
  137. src.dtype.enumv() == DTypeEnum::QuantizedS8 ||
  138. src.dtype.enumv() == DTypeEnum::Quantized8Asymm,
  139. "WarpPerspective NHWCD4 input dtype should be "
  140. "Float32" DNN_FLOAT16_SELECT(
  141. "/Float16/BFloat16",
  142. "") ",QunatizedS8, Quantized8Asymm.");
  143. megdnn_assert(
  144. (src.dtype == mat.dtype || mat.dtype == dtype::Float32()),
  145. "The input to WarpPerspective is in NHWCD4 format, in this "
  146. "case, if the input dtype is floating point, the "
  147. "transformation matrix should have same dtype as the "
  148. "input, %s given.",
  149. mat.dtype.name());
  150. //! number of channels is same
  151. megdnn_assert(src.shape[2] == dst.shape[2], "%s", errmsg().c_str());
  152. megdnn_assert(param().imode ==
  153. param::WarpPerspective::InterpolationMode::LINEAR);
  154. megdnn_assert(param().bmode !=
  155. param::WarpPerspective::BorderMode::TRANSPARENT);
  156. megdnn_assert(param().bmode !=
  157. param::WarpPerspective::BorderMode::ISOLATED);
  158. }
  159. } else if (param().format ==
  160. param::WarpPerspective::Format::NHWC_NCHW4_IC_SMALL ||
  161. param().format ==
  162. param::WarpPerspective::Format::NCHW_NCHW4_IC_SMALL) {
  163. megdnn_assert((src.dtype.enumv() == DTypeEnum::Quantized8Asymm ||
  164. src.dtype.enumv() == DTypeEnum::Uint8),
  165. "src expected Quantized8Asymm or Uint8, but got %s",
  166. src.dtype.name());
  167. megdnn_assert(mat.dtype == dtype::Float32(),
  168. "matrix dtype expected float, got %s", mat.dtype.name());
  169. megdnn_assert(dst.shape[4] == 4);
  170. megdnn_assert(param().imode ==
  171. param::WarpPerspective::InterpolationMode::LINEAR);
  172. megdnn_assert(param().bmode !=
  173. param::WarpPerspective::BorderMode::TRANSPARENT);
  174. megdnn_assert(param().bmode !=
  175. param::WarpPerspective::BorderMode::ISOLATED);
  176. } else if (param().format == param::WarpPerspective::Format::NHWC_NCHW) {
  177. megdnn_assert((src.dtype.enumv() == DTypeEnum::Quantized8Asymm ||
  178. src.dtype.enumv() == DTypeEnum::Uint8),
  179. "src expected Quantized8Asymm or Uint8, but got %s",
  180. src.dtype.name());
  181. megdnn_assert(mat.dtype == dtype::Float32(),
  182. "matrix dtype expected float, got %s", mat.dtype.name());
  183. megdnn_assert(src.shape[3] == dst.shape[1], "%s", errmsg().c_str());
  184. megdnn_assert(param().imode ==
  185. param::WarpPerspective::InterpolationMode::LINEAR);
  186. megdnn_assert(param().bmode !=
  187. param::WarpPerspective::BorderMode::TRANSPARENT);
  188. megdnn_assert(param().bmode !=
  189. param::WarpPerspective::BorderMode::ISOLATED);
  190. } else {
  191. megdnn_assert(param().format == param::WarpPerspective::Format::NCHW);
  192. megdnn_assert((src.dtype.enumv() == DTypeEnum::Quantized8Asymm ||
  193. src.dtype.enumv() == DTypeEnum::Uint8) &&
  194. dst.dtype.enumv() == DTypeEnum::Float32);
  195. }
  196. }
  197. std::string WarpPerspectiveBase::param_msg() const {
  198. std::string res;
  199. res.append("imode=");
  200. switch (param().imode) {
  201. case InterpolationMode::NEAREST:
  202. res.append("NEAREST");
  203. break;
  204. case InterpolationMode::LINEAR:
  205. res.append("LINEAR");
  206. break;
  207. case InterpolationMode::AREA:
  208. res.append("AREA");
  209. break;
  210. case InterpolationMode::CUBIC:
  211. res.append("CUBIC");
  212. break;
  213. case InterpolationMode::LANCZOS4:
  214. res.append("LANCZOS4");
  215. break;
  216. }
  217. res.append(", bmode=");
  218. switch (param().bmode) {
  219. case BorderMode::WRAP:
  220. res.append("WRAP");
  221. break;
  222. case BorderMode::CONSTANT:
  223. res.append("CONSTANT");
  224. break;
  225. case BorderMode::REFLECT:
  226. res.append("REFLECT");
  227. break;
  228. case BorderMode::REFLECT_101:
  229. res.append("REFLECT_101");
  230. break;
  231. case BorderMode::REPLICATE:
  232. res.append("REPLICATE");
  233. break;
  234. case BorderMode::TRANSPARENT:
  235. res.append("TRANSPARENT");
  236. break;
  237. case BorderMode::ISOLATED:
  238. res.append("ISOLATED");
  239. break;
  240. }
  241. if (param().bmode == BorderMode::CONSTANT) {
  242. res.append(", " + std::to_string(param().border_val));
  243. }
  244. return res;
  245. }
  246. int WarpPerspectiveBase::get_real_coord(int p, int len) {
  247. auto bmode = param().bmode;
  248. if ((unsigned)p < (unsigned)len)
  249. ;
  250. else if (bmode == BorderMode::REPLICATE)
  251. p = p < 0 ? 0 : len - 1;
  252. else if (bmode == BorderMode::REFLECT || bmode == BorderMode::REFLECT_101) {
  253. int delta = (bmode == BorderMode::REFLECT_101);
  254. if (len == 1)
  255. return 0;
  256. do {
  257. if (p < 0)
  258. p = -p - 1 + delta;
  259. else
  260. p = len - 1 - (p - len) - delta;
  261. } while ((unsigned)p >= (unsigned)len);
  262. } else if (bmode == BorderMode::WRAP) {
  263. if (p < 0)
  264. p -= ((p - len + 1) / len) * len;
  265. /*
  266. if( p >= len )
  267. p %= len;
  268. */
  269. while (p >= len) {
  270. p -= len;
  271. }
  272. } else if (bmode == BorderMode::CONSTANT)
  273. p = -1;
  274. return p;
  275. }
  276. void WarpPerspectiveForward::check_exec(const TensorLayout& src,
  277. const TensorLayout& mat,
  278. const TensorLayout& mat_idx,
  279. const TensorLayout& dst,
  280. size_t workspace_in_bytes) {
  281. check_exec_allow_nhwc_mat_idx(src, mat, mat_idx, dst, workspace_in_bytes);
  282. }
  283. void WarpPerspectiveForward::check_exec_allow_nhwc_mat_idx(
  284. const TensorLayout& src, const TensorLayout& mat,
  285. const TensorLayout& mat_idx, const TensorLayout& dst,
  286. size_t workspace_in_bytes) {
  287. check_layout_fwd(src, mat, mat_idx, dst);
  288. auto required_workspace_in_bytes =
  289. get_workspace_in_bytes(src, mat, mat_idx, dst);
  290. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  291. if (param().format != Param::Format::NHWC &&
  292. param().format != Param::Format::NCHW &&
  293. param().format != Param::Format::NCHW4 &&
  294. param().format != Param::Format::NHWC_NCHW &&
  295. param().format != Param::Format::NHWC_NCHW4_IC_SMALL &&
  296. param().format != Param::Format::NCHW_NCHW4_IC_SMALL &&
  297. param().format != Param::Format::NCHW64) {
  298. megdnn_assert(!mat_idx.ndim,
  299. "mat_idx not supported for current format");
  300. }
  301. }
  302. void WarpPerspectiveBackwardData::check_exec(const TensorLayout& mat,
  303. const TensorLayout& mat_idx,
  304. const TensorLayout& diff,
  305. const TensorLayout& grad,
  306. size_t workspace_in_bytes) {
  307. check_layout_fwd(grad, mat, mat_idx, diff);
  308. megdnn_assert(grad.dtype == dtype::Float32() DNN_INC_FLOAT16(
  309. || grad.dtype == dtype::BFloat16()),
  310. "Backward WarpPerspective only supports Float32/BFloat16.");
  311. auto required_workspace_in_bytes =
  312. get_workspace_in_bytes(mat, mat_idx, diff, grad);
  313. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  314. }
  315. void WarpPerspectiveBackwardMat::check_exec(const TensorLayout& src,
  316. const TensorLayout& mat,
  317. const TensorLayout& mat_idx,
  318. const TensorLayout& diff,
  319. const TensorLayout& grad,
  320. size_t workspace_in_bytes) {
  321. check_layout_fwd(src, mat, mat_idx, diff);
  322. megdnn_assert_eq_layout(mat, grad);
  323. megdnn_assert(grad.dtype == dtype::Float32() DNN_INC_FLOAT16(
  324. || grad.dtype == dtype::BFloat16()),
  325. "Backward WarpPerspective only supports Float32/BFloat16.");
  326. auto required_workspace_in_bytes =
  327. get_workspace_in_bytes(src, mat, mat_idx, diff, grad);
  328. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  329. }
  330. } // namespace megdnn
  331. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台