You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

imgproc.h 7.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
  1. #pragma once
  2. #include "megdnn/internal/opr_header_prologue.h"
  3. namespace megdnn {
  4. class WarpPerspectiveBase : public OperatorBase {
  5. DEF_OPR_IMPL_CTOR(WarpPerspectiveBase, OperatorBase);
  6. DEF_OPR_PARAM(WarpPerspective);
  7. public:
  8. using InterpolationMode = Param::InterpolationMode;
  9. using BorderMode = Param::BorderMode;
  10. protected:
  11. void check_layout_fwd(
  12. const TensorLayout& src, const TensorLayout& mat, const TensorLayout& dst) {
  13. check_layout_fwd(src, mat, {}, dst);
  14. }
  15. void check_layout_fwd(
  16. const TensorLayout& src, const TensorLayout& mat,
  17. const TensorLayout& mat_idx, const TensorLayout& dst);
  18. std::string param_msg() const;
  19. int get_real_coord(int p, int len);
  20. };
  21. class WarpPerspectiveForward : public WarpPerspectiveBase {
  22. DEF_OPR_IMPL(WarpPerspectiveForward, WarpPerspectiveBase, 0, 1);
  23. public:
  24. /**
  25. * \param[in] src (n, channel, in_height, in_width)
  26. * \param[in] mat (n, 3, 3)
  27. * \param[out] dst (n, channel, out_height, out_width)
  28. *
  29. * \see
  30. * http://docs.opencv.org/2.4/modules/imgproc/doc/geometric_transformations.html?highlight=warpaffine
  31. *
  32. * denominator = mat[2][0]*w+mat[2][1]*h+mat[2][2]
  33. * dst(h, w) = src((mat[1][0]*w+mat[1][1]*h+mat[1][2])/denominator,
  34. * (mat[0][0]*w+mat[0][1]*h+mat[0][2])/denominator)
  35. *
  36. * src and dst can have different shapes, as long as their n and c agree.
  37. * src, mat and dst should be contiguous.
  38. */
  39. void exec(
  40. _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_out dst,
  41. _megdnn_workspace workspace) {
  42. exec(src, mat, {}, dst, workspace);
  43. }
  44. /**
  45. * \p src should have batch size m, and \p mat and \p mat_idx should
  46. * both have batch size n. Each item in \p mat_idx must be in the range
  47. * of [0, m-1].
  48. *
  49. * \param mat_idx the indices of input image that each matrix in \p mat
  50. * should act on. It can also be empty and in such case \p mat
  51. * should have the same batch size as \p src.
  52. */
  53. virtual void exec(
  54. _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx,
  55. _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
  56. size_t get_workspace_in_bytes(
  57. const TensorLayout& src, const TensorLayout& mat, const TensorLayout& dst) {
  58. return get_workspace_in_bytes(src, mat, {}, dst);
  59. }
  60. virtual size_t get_workspace_in_bytes(
  61. const TensorLayout& src, const TensorLayout& mat,
  62. const TensorLayout& mat_idx, const TensorLayout& dst) = 0;
  63. protected:
  64. void check_exec(
  65. const TensorLayout& src, const TensorLayout& mat,
  66. const TensorLayout& mat_idx, const TensorLayout& dst,
  67. size_t workspace_in_bytes);
  68. void check_exec_allow_nhwc_mat_idx(
  69. const TensorLayout& src, const TensorLayout& mat,
  70. const TensorLayout& mat_idx, const TensorLayout& dst,
  71. size_t workspace_in_bytes);
  72. };
  73. using WarpPerspective = WarpPerspectiveForward;
  74. class WarpPerspectiveBackwardData : public WarpPerspectiveBase {
  75. DEF_OPR_IMPL(WarpPerspectiveBackwardData, WarpPerspectiveBase, 2, 1);
  76. public:
  77. /**
  78. * \param[in] mat the `mat' parameter in WarpPerspectiveForward::exec
  79. * \param[in] diff the backpropagated gradient wrt. dst
  80. * \param[out] grad the backpropagated gradient wrt. src
  81. * \param[out] workspace temporary workspace to perform backward
  82. */
  83. void exec(
  84. _megdnn_tensor_in mat, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
  85. _megdnn_workspace workspace) {
  86. exec(mat, {}, diff, grad, workspace);
  87. }
  88. virtual void exec(
  89. _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx, _megdnn_tensor_in diff,
  90. _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
  91. size_t get_workspace_in_bytes(
  92. const TensorLayout& mat, const TensorLayout& diff,
  93. const TensorLayout& grad) {
  94. return get_workspace_in_bytes(mat, {}, diff, grad);
  95. }
  96. virtual size_t get_workspace_in_bytes(
  97. const TensorLayout& mat, const TensorLayout& mat_idx,
  98. const TensorLayout& diff, const TensorLayout& grad) = 0;
  99. protected:
  100. void check_exec(
  101. const TensorLayout& mat, const TensorLayout& mat_idx,
  102. const TensorLayout& diff, const TensorLayout& grad,
  103. size_t workspace_in_bytes);
  104. };
  105. class WarpPerspectiveBackwardMat : public WarpPerspectiveBase {
  106. DEF_OPR_IMPL(WarpPerspectiveBackwardMat, WarpPerspectiveBase, 3, 1);
  107. public:
  108. /**
  109. * \param[in] src the `src' parameter in WarpPerspectiveForward::exec
  110. * \param[in] mat the `mat' parameter in WarpPerspectiveForward::exec
  111. * \param[in] diff the backpropagated gradient wrt. dst
  112. * \param[out] grad the backpropagated gradient wrt. mat
  113. * \param[out] workspace temporary workspace to perform backward
  114. */
  115. void exec(
  116. _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in diff,
  117. _megdnn_tensor_out grad, _megdnn_workspace workspace) {
  118. exec(src, mat, {}, diff, grad, workspace);
  119. }
  120. virtual void exec(
  121. _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx,
  122. _megdnn_tensor_in diff, _megdnn_tensor_out grad,
  123. _megdnn_workspace workspace) = 0;
  124. size_t get_workspace_in_bytes(
  125. const TensorLayout& src, const TensorLayout& mat, const TensorLayout& diff,
  126. const TensorLayout& grad) {
  127. return get_workspace_in_bytes(src, mat, {}, diff, grad);
  128. }
  129. virtual size_t get_workspace_in_bytes(
  130. const TensorLayout& src, const TensorLayout& mat,
  131. const TensorLayout& mat_idx, const TensorLayout& diff,
  132. const TensorLayout& grad) = 0;
  133. protected:
  134. void check_exec(
  135. const TensorLayout& src, const TensorLayout& mat,
  136. const TensorLayout& mat_idx, const TensorLayout& diff,
  137. const TensorLayout& grad, size_t workspace_in_bytes);
  138. };
  139. class DctChannelSelectForward : public OperatorBase {
  140. DEF_OPR_PARAM(DctChannelSelect);
  141. DEF_OPR_IMPL(DctChannelSelectForward, OperatorBase, 3, 1);
  142. public:
  143. /**
  144. * \param[in] DctChannelSelectForward input, must be uint8 nchw tensor
  145. * \param[in] mask_offset input, must be int32 nchw tensor
  146. * \param[in] mask_val input, must be int32 nchw tensor
  147. * \param[dst] DctChannelSelectForward output, default fp32 nchw tensor
  148. * \param[out] workspace temporary workspace to perform forward
  149. */
  150. virtual void exec(
  151. _megdnn_tensor_in src, _megdnn_tensor_in mask_offset,
  152. _megdnn_tensor_in mask_val, _megdnn_tensor_out dst,
  153. _megdnn_workspace workspace) = 0;
  154. void deduce_layout(
  155. const TensorLayout& src, const TensorLayout& mask_offset,
  156. const TensorLayout& mask_val, TensorLayout& dst);
  157. virtual size_t get_workspace_in_bytes(
  158. const TensorLayout& src, const TensorLayout& mask_offset,
  159. const TensorLayout& mask_val, const TensorLayout& dst) = 0;
  160. protected:
  161. void check_layout_fwd(
  162. const TensorLayout& src, const TensorLayout& mask_offset,
  163. const TensorLayout& mask_val, const TensorLayout& dst);
  164. void deduce_layout_fwd(
  165. const TensorLayout& src, const TensorLayout& mask_offset,
  166. const TensorLayout& mask_val, TensorLayout& dst);
  167. std::string param_msg() const;
  168. };
  169. } // namespace megdnn
  170. #include "megdnn/internal/opr_header_epilogue.h"
  171. // vim: syntax=cpp.doxygen