You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

imgproc.h 9.0 kB


  1. #pragma once
  2. #include "megdnn/internal/opr_header_prologue.h"
  3. namespace megdnn {
  4. class WarpPerspectiveBase : public OperatorBase {  // common base of the WarpPerspective forward/backward operators declared below
  5. DEF_OPR_IMPL_CTOR(WarpPerspectiveBase, OperatorBase);
  6. DEF_OPR_PARAM(WarpPerspective);  // NOTE(review): presumably declares the Param type used below (Param::InterpolationMode / Param::BorderMode) — confirm
  7. public:
  8. using InterpolationMode = Param::InterpolationMode;  // shorthand for the param's interpolation enum
  9. using BorderMode = Param::BorderMode;  // shorthand for the param's border enum
  10. protected:
  11. void check_layout_fwd(  // single-src overload without mat_idx: forwards an empty layout as mat_idx
  12. const TensorLayout& src, const TensorLayout& mat, const TensorLayout& dst) {
  13. check_layout_fwd(src, mat, {}, dst);
  14. }
  15. void check_layout_fwd(  // multi-src overload without mat_idx: forwards an empty layout as mat_idx
  16. const TensorLayoutArray& srcs, const TensorLayout& mat,
  17. const TensorLayout& dst) {
  18. check_layout_fwd(srcs, mat, {}, dst);
  19. }
  20. void check_layout_fwd(  // full single-src check (defined elsewhere); mat_idx may be an empty layout
  21. const TensorLayout& src, const TensorLayout& mat,
  22. const TensorLayout& mat_idx, const TensorLayout& dst);
  23. void check_layout_fwd(  // full multi-src check (defined elsewhere); mat_idx may be an empty layout
  24. const TensorLayoutArray& srcs, const TensorLayout& mat,
  25. const TensorLayout& mat_idx, const TensorLayout& dst);
  26. std::string param_msg() const;  // NOTE(review): presumably renders the current param as text for error messages — confirm
  27. int get_real_coord(int p, int len);  // NOTE(review): presumably maps coordinate p into [0, len) according to the border mode — confirm against its definition
  28. };
  29. class WarpPerspectiveForward : public WarpPerspectiveBase {  // warps src into dst using one 3x3 matrix per output batch item
  30. DEF_OPR_IMPL(WarpPerspectiveForward, WarpPerspectiveBase, 0, 1);  // NOTE(review): sibling oprs pass counts matching their tensor args (BackwardData 2,1; BackwardMat 3,1); 0 here presumably reflects the optional/multi-src inputs — confirm
  31. public:
  32. /**
  33. * \param[in] src (n, channel, in_height, in_width)
  34. * \param[in] mat (n, 3, 3)
  35. * \param[out] dst (n, channel, out_height, out_width)
  36. *
  37. * \see
  38. * http://docs.opencv.org/2.4/modules/imgproc/doc/geometric_transformations.html?highlight=warpaffine
  39. *
  40. * denominator = mat[2][0]*w+mat[2][1]*h+mat[2][2]
  41. * dst(h, w) = src((mat[1][0]*w+mat[1][1]*h+mat[1][2])/denominator,
  42. * (mat[0][0]*w+mat[0][1]*h+mat[0][2])/denominator)
  43. * i.e. (h, w) is a dst coordinate and mat maps it to the src sampling location.
  44. * src and dst can have different shapes, as long as their n and c agree (with a non-empty mat_idx, the batch sizes of src and mat may differ; see the mat_idx overload below).
  45. * src, mat and dst should be contiguous.
  46. */
  47. void exec(  // convenience overload: forwards an empty mat_idx
  48. _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_out dst,
  49. _megdnn_workspace workspace) {
  50. exec(src, mat, {}, dst, workspace);
  51. }
  52. void exec(  // multi-src convenience overload: forwards an empty mat_idx
  53. _megdnn_in const TensorNDArray& srcs, _megdnn_tensor_in mat,
  54. _megdnn_tensor_out dst, _megdnn_workspace workspace) {
  55. exec(srcs, mat, {}, dst, workspace);
  56. }
  57. /**
  58. * \p src should have batch size m, and \p mat and \p mat_idx should
  59. * both have batch size n. Each item in \p mat_idx must be in the range
  60. * of [0, m-1].
  61. *
  62. * \param mat_idx the index of the input image that each matrix in \p mat
  63. * should act on. It can also be empty, in which case \p mat
  64. * should have the same batch size as \p src.
  65. */
  66. virtual void exec(  // pure virtual; implemented by concrete subclasses
  67. _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx,
  68. _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
  69. virtual void exec(  // multi-src counterpart of the overload above; pure virtual
  70. _megdnn_in const TensorNDArray& srcs, _megdnn_tensor_in mat,
  71. _megdnn_tensor_in mat_idx, _megdnn_tensor_out dst,
  72. _megdnn_workspace workspace) = 0;
  73. size_t get_workspace_in_bytes(  // overload without mat_idx: forwards an empty layout
  74. const TensorLayout& src, const TensorLayout& mat, const TensorLayout& dst) {
  75. return get_workspace_in_bytes(src, mat, {}, dst);
  76. }
  77. size_t get_workspace_in_bytes(  // multi-src overload without mat_idx: forwards an empty layout
  78. const TensorLayoutArray& srcs, const TensorLayout& mat,
  79. const TensorLayout& dst) {
  80. return get_workspace_in_bytes(srcs, mat, {}, dst);
  81. }
  82. virtual size_t get_workspace_in_bytes(  // workspace-size query matching exec(src, mat, mat_idx, dst, ...); pure virtual
  83. const TensorLayout& src, const TensorLayout& mat,
  84. const TensorLayout& mat_idx, const TensorLayout& dst) = 0;
  85. virtual size_t get_workspace_in_bytes(  // multi-src workspace-size query; pure virtual
  86. const TensorLayoutArray& srcs, const TensorLayout& mat,
  87. const TensorLayout& mat_idx, const TensorLayout& dst) = 0;
  88. protected:
  89. void check_exec(  // NOTE(review): presumably validates the layouts and workspace_in_bytes before exec — confirm
  90. const TensorLayout& src, const TensorLayout& mat,
  91. const TensorLayout& mat_idx, const TensorLayout& dst,
  92. size_t workspace_in_bytes);
  93. void check_exec_allow_nhwc_mat_idx(  // NOTE(review): per its name, a check_exec variant that also admits NHWC with mat_idx — confirm
  94. const TensorLayout& src, const TensorLayout& mat,
  95. const TensorLayout& mat_idx, const TensorLayout& dst,
  96. size_t workspace_in_bytes);
  97. void check_exec_allow_nhwc_mat_idx(  // multi-src variant of the check above
  98. const TensorLayoutArray& srcs, const TensorLayout& mat,
  99. const TensorLayout& mat_idx, const TensorLayout& dst,
  100. size_t workspace_in_bytes);
  101. };
  102. using WarpPerspective = WarpPerspectiveForward;  // short name: WarpPerspective refers to the forward operator
  103. class WarpPerspectiveBackwardData : public WarpPerspectiveBase {  // gradient of WarpPerspectiveForward wrt. src; src itself is not an input
  104. DEF_OPR_IMPL(WarpPerspectiveBackwardData, WarpPerspectiveBase, 2, 1);
  105. public:
  106. /**
  107. * \param[in] mat the `mat' parameter in WarpPerspectiveForward::exec
  108. * \param[in] diff the backpropagated gradient wrt. dst
  109. * \param[out] grad the backpropagated gradient wrt. src
  110. * \param workspace temporary workspace to perform backward (scratch memory, not a result)
  111. */
  112. void exec(  // convenience overload: forwards an empty mat_idx
  113. _megdnn_tensor_in mat, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
  114. _megdnn_workspace workspace) {
  115. exec(mat, {}, diff, grad, workspace);
  116. }
  117. virtual void exec(  // pure virtual; mat_idx presumably has the same meaning as in WarpPerspectiveForward::exec and may be empty
  118. _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx, _megdnn_tensor_in diff,
  119. _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
  120. size_t get_workspace_in_bytes(  // overload without mat_idx: forwards an empty layout
  121. const TensorLayout& mat, const TensorLayout& diff,
  122. const TensorLayout& grad) {
  123. return get_workspace_in_bytes(mat, {}, diff, grad);
  124. }
  125. virtual size_t get_workspace_in_bytes(  // workspace-size query matching the 5-argument exec; pure virtual
  126. const TensorLayout& mat, const TensorLayout& mat_idx,
  127. const TensorLayout& diff, const TensorLayout& grad) = 0;
  128. protected:
  129. void check_exec(  // NOTE(review): presumably validates the layouts and workspace_in_bytes before exec — confirm
  130. const TensorLayout& mat, const TensorLayout& mat_idx,
  131. const TensorLayout& diff, const TensorLayout& grad,
  132. size_t workspace_in_bytes);
  133. };
  134. class WarpPerspectiveBackwardMat : public WarpPerspectiveBase {  // gradient of WarpPerspectiveForward wrt. mat
  135. DEF_OPR_IMPL(WarpPerspectiveBackwardMat, WarpPerspectiveBase, 3, 1);
  136. public:
  137. /**
  138. * \param[in] src the `src' parameter in WarpPerspectiveForward::exec
  139. * \param[in] mat the `mat' parameter in WarpPerspectiveForward::exec
  140. * \param[in] diff the backpropagated gradient wrt. dst
  141. * \param[out] grad the backpropagated gradient wrt. mat
  142. * \param workspace temporary workspace to perform backward (scratch memory, not a result)
  143. */
  144. void exec(  // convenience overload: forwards an empty mat_idx
  145. _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in diff,
  146. _megdnn_tensor_out grad, _megdnn_workspace workspace) {
  147. exec(src, mat, {}, diff, grad, workspace);
  148. }
  149. virtual void exec(  // pure virtual; mat_idx presumably has the same meaning as in WarpPerspectiveForward::exec and may be empty
  150. _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx,
  151. _megdnn_tensor_in diff, _megdnn_tensor_out grad,
  152. _megdnn_workspace workspace) = 0;
  153. size_t get_workspace_in_bytes(  // overload without mat_idx: forwards an empty layout
  154. const TensorLayout& src, const TensorLayout& mat, const TensorLayout& diff,
  155. const TensorLayout& grad) {
  156. return get_workspace_in_bytes(src, mat, {}, diff, grad);
  157. }
  158. virtual size_t get_workspace_in_bytes(  // workspace-size query matching the 6-argument exec; pure virtual
  159. const TensorLayout& src, const TensorLayout& mat,
  160. const TensorLayout& mat_idx, const TensorLayout& diff,
  161. const TensorLayout& grad) = 0;
  162. protected:
  163. void check_exec(  // NOTE(review): presumably validates the layouts and workspace_in_bytes before exec — confirm
  164. const TensorLayout& src, const TensorLayout& mat,
  165. const TensorLayout& mat_idx, const TensorLayout& diff,
  166. const TensorLayout& grad, size_t workspace_in_bytes);
  167. };
  168. class DctChannelSelectForward : public OperatorBase {  // standalone operator: derives directly from OperatorBase, not from WarpPerspectiveBase
  169. DEF_OPR_PARAM(DctChannelSelect);
  170. DEF_OPR_IMPL(DctChannelSelectForward, OperatorBase, 3, 1);  // NOTE(review): 3, 1 presumably = inputs (src, mask_offset, mask_val) and outputs (dst) — confirm
  171. public:
  172. /**
  173. * \param[in] src operator input, must be uint8 nchw tensor
  174. * \param[in] mask_offset input, must be int32 nchw tensor
  175. * \param[in] mask_val input, must be int32 nchw tensor
  176. * \param[out] dst operator output, default fp32 nchw tensor
  177. * \param workspace temporary workspace to perform forward (scratch memory, not a result)
  178. */
  179. virtual void exec(  // pure virtual; implemented by concrete subclasses
  180. _megdnn_tensor_in src, _megdnn_tensor_in mask_offset,
  181. _megdnn_tensor_in mask_val, _megdnn_tensor_out dst,
  182. _megdnn_workspace workspace) = 0;
  183. void deduce_layout(  // fills in the dst layout (written through the non-const reference) from the input layouts
  184. const TensorLayout& src, const TensorLayout& mask_offset,
  185. const TensorLayout& mask_val, TensorLayout& dst);
  186. virtual size_t get_workspace_in_bytes(  // workspace-size query matching exec; pure virtual
  187. const TensorLayout& src, const TensorLayout& mask_offset,
  188. const TensorLayout& mask_val, const TensorLayout& dst) = 0;
  189. protected:
  190. void check_layout_fwd(  // NOTE(review): presumably validates the forward layouts — confirm
  191. const TensorLayout& src, const TensorLayout& mask_offset,
  192. const TensorLayout& mask_val, const TensorLayout& dst);
  193. void deduce_layout_fwd(  // protected helper; dst is written through the non-const reference
  194. const TensorLayout& src, const TensorLayout& mask_offset,
  195. const TensorLayout& mask_val, TensorLayout& dst);
  196. std::string param_msg() const;  // NOTE(review): presumably renders the current param as text for error messages — confirm
  197. };
  198. } // namespace megdnn
  199. #include "megdnn/internal/opr_header_epilogue.h"
  200. // vim: syntax=cpp.doxygen