You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

imgproc.h 10 kB


  1. #pragma once
  2. #include "megdnn/internal/opr_header_prologue.h"
  3. namespace megdnn {
  4. class WarpPerspectiveBase : public OperatorBase {
  5. DEF_OPR_IMPL_CTOR(WarpPerspectiveBase, OperatorBase);
  6. DEF_OPR_PARAM(WarpPerspective);
  7. public:
  8. using InterpolationMode = Param::InterpolationMode;
  9. using BorderMode = Param::BorderMode;
  10. protected:
  11. void check_layout_fwd(
  12. const TensorLayout& src, const TensorLayout& mat, const TensorLayout& dst) {
  13. check_layout_fwd(src, mat, {}, dst);
  14. }
  15. void check_layout_fwd(
  16. const TensorLayoutArray& srcs, const TensorLayout& mat,
  17. const TensorLayout& dst) {
  18. check_layout_fwd(srcs, mat, {}, dst);
  19. }
  20. void check_layout_fwd(
  21. const TensorLayout& src, const TensorLayout& mat,
  22. const TensorLayout& mat_idx, const TensorLayout& dst);
  23. void check_layout_fwd(
  24. const TensorLayoutArray& srcs, const TensorLayout& mat,
  25. const TensorLayout& mat_idx, const TensorLayout& dst);
  26. std::string param_msg() const;
  27. int get_real_coord(int p, int len);
  28. };
  29. class WarpPerspectiveForward : public WarpPerspectiveBase {
  30. DEF_OPR_IMPL(WarpPerspectiveForward, WarpPerspectiveBase, 0, 1);
  31. public:
  32. /**
  33. * \param[in] src (n, channel, in_height, in_width)
  34. * \param[in] mat (n, 3, 3)
  35. * \param[out] dst (n, channel, out_height, out_width)
  36. *
  37. * \see
  38. * http://docs.opencv.org/2.4/modules/imgproc/doc/geometric_transformations.html?highlight=warpaffine
  39. *
  40. * denominator = mat[2][0]*w+mat[2][1]*h+mat[2][2]
  41. * dst(h, w) = src((mat[1][0]*w+mat[1][1]*h+mat[1][2])/denominator,
  42. * (mat[0][0]*w+mat[0][1]*h+mat[0][2])/denominator)
  43. *
  44. * src and dst can have different shapes, as long as their n and c agree.
  45. * src, mat and dst should be contiguous.
  46. */
  47. void exec(
  48. _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_out dst,
  49. _megdnn_workspace workspace) {
  50. exec(src, mat, {}, dst, workspace);
  51. }
  52. /**
  53. * \param[in] srcs consists of n TensorNDs, each TensorND has shape (1, channel,
  54. * in_height, in_width) \param[in] mat (n, 3, 3) \param[out] dst (n, channel,
  55. * out_height, out_width)
  56. *
  57. * \note
  58. * srcs and dst can have different shapes, as long as their c agree and the size of
  59. * srcs is equal to n. every element of srcs, mat and dst should be contiguous.
  60. *
  61. * equivalent to:
  62. * TensorND src{nullptr, TensorLayout({n, channel, in_height, in_width},
  63. * srcs[0].layout.dtype)}; auto concat = handle()->create_operator<Concat>();
  64. * concat->exec(srcs, src);
  65. * auto warp = handle()->create_operator<WarpPerspectiveForward>();
  66. * warp->exec(src, mat, dst, workspace);
  67. */
  68. void exec(
  69. _megdnn_in const TensorNDArray& srcs, _megdnn_tensor_in mat,
  70. _megdnn_tensor_out dst, _megdnn_workspace workspace) {
  71. exec(srcs, mat, {}, dst, workspace);
  72. }
  73. /**
  74. * \p src should have batch size m, and \p mat and \p mat_idx should
  75. * both have batch size n. Each item in \p mat_idx must be in the range
  76. * of [0, m-1].
  77. *
  78. * \param mat_idx the indices of input image that each matrix in \p mat
  79. * should act on. It can also be empty and in such case \p mat
  80. * should have the same batch size as \p src.
  81. */
  82. virtual void exec(
  83. _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx,
  84. _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
  85. /**
  86. * \p srcs should have m elements, and \p mat and \p mat_idx should
  87. * both have batch size n. Each item in \p mat_idx must be in the range
  88. * of [0, m-1].
  89. *
  90. * \param mat_idx the indices of input image that each matrix in \p mat
  91. * should act on. It can also be empty and in such case \p mat batch size
  92. * should be the same as the number of elements in \p srcs .
  93. */
  94. virtual void exec(
  95. _megdnn_in const TensorNDArray& srcs, _megdnn_tensor_in mat,
  96. _megdnn_tensor_in mat_idx, _megdnn_tensor_out dst,
  97. _megdnn_workspace workspace) {
  98. static_cast<void>(srcs);
  99. static_cast<void>(mat);
  100. static_cast<void>(mat_idx);
  101. static_cast<void>(dst);
  102. static_cast<void>(workspace);
  103. }
  104. size_t get_workspace_in_bytes(
  105. const TensorLayout& src, const TensorLayout& mat, const TensorLayout& dst) {
  106. return get_workspace_in_bytes(src, mat, {}, dst);
  107. }
  108. size_t get_workspace_in_bytes(
  109. const TensorLayoutArray& srcs, const TensorLayout& mat,
  110. const TensorLayout& dst) {
  111. return get_workspace_in_bytes(srcs, mat, {}, dst);
  112. }
  113. virtual size_t get_workspace_in_bytes(
  114. const TensorLayout& src, const TensorLayout& mat,
  115. const TensorLayout& mat_idx, const TensorLayout& dst) = 0;
  116. virtual size_t get_workspace_in_bytes(
  117. const TensorLayoutArray& srcs, const TensorLayout& mat,
  118. const TensorLayout& mat_idx, const TensorLayout& dst) {
  119. static_cast<void>(srcs);
  120. static_cast<void>(mat);
  121. static_cast<void>(mat_idx);
  122. static_cast<void>(dst);
  123. return 0;
  124. }
  125. protected:
  126. void check_exec(
  127. const TensorLayout& src, const TensorLayout& mat,
  128. const TensorLayout& mat_idx, const TensorLayout& dst,
  129. size_t workspace_in_bytes);
  130. void check_exec_allow_nhwc_mat_idx(
  131. const TensorLayout& src, const TensorLayout& mat,
  132. const TensorLayout& mat_idx, const TensorLayout& dst,
  133. size_t workspace_in_bytes);
  134. void check_exec_allow_nhwc_mat_idx(
  135. const TensorLayoutArray& srcs, const TensorLayout& mat,
  136. const TensorLayout& mat_idx, const TensorLayout& dst,
  137. size_t workspace_in_bytes);
  138. };
  139. using WarpPerspective = WarpPerspectiveForward;
  140. class WarpPerspectiveBackwardData : public WarpPerspectiveBase {
  141. DEF_OPR_IMPL(WarpPerspectiveBackwardData, WarpPerspectiveBase, 2, 1);
  142. public:
  143. /**
  144. * \param[in] mat the `mat' parameter in WarpPerspectiveForward::exec
  145. * \param[in] diff the backpropagated gradient wrt. dst
  146. * \param[out] grad the backpropagated gradient wrt. src
  147. * \param[out] workspace temporary workspace to perform backward
  148. */
  149. void exec(
  150. _megdnn_tensor_in mat, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
  151. _megdnn_workspace workspace) {
  152. exec(mat, {}, diff, grad, workspace);
  153. }
  154. virtual void exec(
  155. _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx, _megdnn_tensor_in diff,
  156. _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
  157. size_t get_workspace_in_bytes(
  158. const TensorLayout& mat, const TensorLayout& diff,
  159. const TensorLayout& grad) {
  160. return get_workspace_in_bytes(mat, {}, diff, grad);
  161. }
  162. virtual size_t get_workspace_in_bytes(
  163. const TensorLayout& mat, const TensorLayout& mat_idx,
  164. const TensorLayout& diff, const TensorLayout& grad) = 0;
  165. protected:
  166. void check_exec(
  167. const TensorLayout& mat, const TensorLayout& mat_idx,
  168. const TensorLayout& diff, const TensorLayout& grad,
  169. size_t workspace_in_bytes);
  170. };
  171. class WarpPerspectiveBackwardMat : public WarpPerspectiveBase {
  172. DEF_OPR_IMPL(WarpPerspectiveBackwardMat, WarpPerspectiveBase, 3, 1);
  173. public:
  174. /**
  175. * \param[in] src the `src' parameter in WarpPerspectiveForward::exec
  176. * \param[in] mat the `mat' parameter in WarpPerspectiveForward::exec
  177. * \param[in] diff the backpropagated gradient wrt. dst
  178. * \param[out] grad the backpropagated gradient wrt. mat
  179. * \param[out] workspace temporary workspace to perform backward
  180. */
  181. void exec(
  182. _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in diff,
  183. _megdnn_tensor_out grad, _megdnn_workspace workspace) {
  184. exec(src, mat, {}, diff, grad, workspace);
  185. }
  186. virtual void exec(
  187. _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx,
  188. _megdnn_tensor_in diff, _megdnn_tensor_out grad,
  189. _megdnn_workspace workspace) = 0;
  190. size_t get_workspace_in_bytes(
  191. const TensorLayout& src, const TensorLayout& mat, const TensorLayout& diff,
  192. const TensorLayout& grad) {
  193. return get_workspace_in_bytes(src, mat, {}, diff, grad);
  194. }
  195. virtual size_t get_workspace_in_bytes(
  196. const TensorLayout& src, const TensorLayout& mat,
  197. const TensorLayout& mat_idx, const TensorLayout& diff,
  198. const TensorLayout& grad) = 0;
  199. protected:
  200. void check_exec(
  201. const TensorLayout& src, const TensorLayout& mat,
  202. const TensorLayout& mat_idx, const TensorLayout& diff,
  203. const TensorLayout& grad, size_t workspace_in_bytes);
  204. };
  205. class DctChannelSelectForward : public OperatorBase {
  206. DEF_OPR_PARAM(DctChannelSelect);
  207. DEF_OPR_IMPL(DctChannelSelectForward, OperatorBase, 3, 1);
  208. public:
  209. /**
  210. * \param[in] DctChannelSelectForward input, must be uint8 nchw tensor
  211. * \param[in] mask_offset input, must be int32 nchw tensor
  212. * \param[in] mask_val input, must be int32 nchw tensor
  213. * \param[dst] DctChannelSelectForward output, default fp32 nchw tensor
  214. * \param[out] workspace temporary workspace to perform forward
  215. */
  216. virtual void exec(
  217. _megdnn_tensor_in src, _megdnn_tensor_in mask_offset,
  218. _megdnn_tensor_in mask_val, _megdnn_tensor_out dst,
  219. _megdnn_workspace workspace) = 0;
  220. void deduce_layout(
  221. const TensorLayout& src, const TensorLayout& mask_offset,
  222. const TensorLayout& mask_val, TensorLayout& dst);
  223. virtual size_t get_workspace_in_bytes(
  224. const TensorLayout& src, const TensorLayout& mask_offset,
  225. const TensorLayout& mask_val, const TensorLayout& dst) = 0;
  226. protected:
  227. void check_layout_fwd(
  228. const TensorLayout& src, const TensorLayout& mask_offset,
  229. const TensorLayout& mask_val, const TensorLayout& dst);
  230. void deduce_layout_fwd(
  231. const TensorLayout& src, const TensorLayout& mask_offset,
  232. const TensorLayout& mask_val, TensorLayout& dst);
  233. std::string param_msg() const;
  234. };
  235. } // namespace megdnn
  236. #include "megdnn/internal/opr_header_epilogue.h"
  237. // vim: syntax=cpp.doxygen