You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

cv.h 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366
  1. #pragma once
  2. #include "megdnn/internal/opr_header_prologue.h"
  3. namespace megdnn {
  /**
   * \brief This file contains CV operators. The layout is NHWC.
   */
  7. class FlipBase : public OperatorBase {
  8. DEF_OPR_IMPL_CTOR(FlipBase, OperatorBase);
  9. DEF_OPR_PARAM(Flip);
  10. protected:
  11. void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
  12. void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
  13. };
  14. class FlipForward : public FlipBase {
  15. DEF_OPR_IMPL(FlipForward, FlipBase, 1, 1);
  16. public:
  17. virtual void exec(
  18. _megdnn_tensor_in src, _megdnn_tensor_out dst,
  19. _megdnn_workspace workspace) = 0;
  20. void deduce_layout(const TensorLayout& src, TensorLayout& dst);
  21. virtual size_t get_workspace_in_bytes(
  22. const TensorLayout& src, const TensorLayout& dst) = 0;
  23. protected:
  24. void check_exec(
  25. const TensorLayout& src, const TensorLayout& dst,
  26. size_t workspace_in_bytes);
  27. };
  28. using Flip = FlipForward;
  29. class RotateBase : public OperatorBase {
  30. DEF_OPR_IMPL_CTOR(RotateBase, OperatorBase);
  31. DEF_OPR_PARAM(Rotate);
  32. protected:
  33. void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
  34. void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
  35. };
  36. class RotateForward : public RotateBase {
  37. DEF_OPR_IMPL(RotateForward, RotateBase, 1, 1);
  38. public:
  39. virtual void exec(
  40. _megdnn_tensor_in src, _megdnn_tensor_out dst,
  41. _megdnn_workspace workspace) = 0;
  42. void deduce_layout(const TensorLayout& src, TensorLayout& dst);
  43. virtual size_t get_workspace_in_bytes(
  44. const TensorLayout& src, const TensorLayout& dst) = 0;
  45. protected:
  46. void check_exec(
  47. const TensorLayout& src, const TensorLayout& dst,
  48. size_t workspace_in_bytes);
  49. };
  50. using Rotate = RotateForward;
  51. class ROICopyBase : public OperatorBase {
  52. DEF_OPR_IMPL_CTOR(ROICopyBase, OperatorBase);
  53. DEF_OPR_PARAM(ROICopy);
  54. protected:
  55. void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
  56. void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
  57. };
  58. class ROICopyForward : public ROICopyBase {
  59. DEF_OPR_IMPL(ROICopyForward, ROICopyBase, 1, 1);
  60. public:
  61. virtual void exec(
  62. _megdnn_tensor_in src, _megdnn_tensor_out dst,
  63. _megdnn_workspace workspace) = 0;
  64. void deduce_layout(const TensorLayout& src, TensorLayout& dst);
  65. virtual size_t get_workspace_in_bytes(
  66. const TensorLayout& src, const TensorLayout& dst) = 0;
  67. protected:
  68. void check_exec(
  69. const TensorLayout& src, const TensorLayout& dst,
  70. size_t workspace_in_bytes);
  71. };
  72. using ROICopy = ROICopyForward;
  73. class CvtColorBase : public OperatorBase {
  74. DEF_OPR_IMPL_CTOR(CvtColorBase, OperatorBase);
  75. DEF_OPR_PARAM(CvtColor);
  76. protected:
  77. void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
  78. void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
  79. };
  80. class CvtColorForward : public CvtColorBase {
  81. DEF_OPR_IMPL(CvtColorForward, CvtColorBase, 1, 1);
  82. public:
  83. virtual void exec(
  84. _megdnn_tensor_in src, _megdnn_tensor_out dst,
  85. _megdnn_workspace workspace) = 0;
  86. void deduce_layout(const TensorLayout& src, TensorLayout& dst);
  87. virtual size_t get_workspace_in_bytes(
  88. const TensorLayout& src, const TensorLayout& dst) = 0;
  89. protected:
  90. void check_exec(
  91. const TensorLayout& src, const TensorLayout& dst,
  92. size_t workspace_in_bytes);
  93. };
  94. using CvtColor = CvtColorForward;
  95. /**
  96. * \brief Applices an affine transformation
  97. */
  98. class WarpAffineBase : public OperatorBase {
  99. DEF_OPR_IMPL_CTOR(WarpAffineBase, OperatorBase);
  100. DEF_OPR_PARAM(WarpAffine);
  101. public:
  102. using InterpolationMode = Param::InterpolationMode;
  103. using BorderMode = Param::BorderMode;
  104. protected:
  105. void check_layout_fwd(
  106. const TensorLayout& src, const TensorLayout& trans,
  107. const TensorLayout& dst);
  108. std::string param_msg() const;
  109. int get_real_coord(int p, int len);
  110. };
  111. class WarpAffineForward : public WarpAffineBase {
  112. DEF_OPR_IMPL(WarpAffineForward, WarpAffineBase, 2, 1);
  113. public:
  114. /**
  115. * \param[in] src input tensor
  116. * \param[in] trans transform matrix tensor
  117. * \param[in] dst output tensor
  118. *
  119. * \warning src, trans, border_value, dst should be contiguous
  120. * The size of trans is N * 2 * 3
  121. */
  122. virtual void exec(
  123. _megdnn_tensor_in src, _megdnn_tensor_in trans, _megdnn_tensor_out dst,
  124. _megdnn_workspace workspace) = 0;
  125. virtual size_t get_workspace_in_bytes(
  126. const TensorLayout& src, const TensorLayout& trans,
  127. const TensorLayout& dst) = 0;
  128. protected:
  129. void check_exec(
  130. const TensorLayout& src, const TensorLayout& trans, const TensorLayout& dst,
  131. size_t workspace_in_bytes);
  132. };
  133. using WarpAffine = WarpAffineForward;
  134. class GaussianBlurBase : public OperatorBase {
  135. DEF_OPR_IMPL_CTOR(GaussianBlurBase, OperatorBase);
  136. DEF_OPR_PARAM(GaussianBlur);
  137. protected:
  138. void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
  139. void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
  140. };
  141. class GaussianBlurForward : public GaussianBlurBase {
  142. DEF_OPR_IMPL(GaussianBlurForward, GaussianBlurBase, 1, 1);
  143. public:
  144. virtual void exec(
  145. _megdnn_tensor_in src, _megdnn_tensor_out dst,
  146. _megdnn_workspace workspace) = 0;
  147. void deduce_layout(const TensorLayout& src, TensorLayout& dst);
  148. virtual size_t get_workspace_in_bytes(
  149. const TensorLayout& src, const TensorLayout& dst) = 0;
  150. protected:
  151. void check_exec(
  152. const TensorLayout& src, const TensorLayout& dst,
  153. size_t workspace_in_bytes);
  154. };
  155. using GaussianBlur = GaussianBlurForward;
  156. /**
  157. * \brief Resize opr.
  158. */
  159. class ResizeBase : public OperatorBase {
  160. DEF_OPR_PARAM(Resize);
  161. DEF_OPR_IMPL(ResizeBase, OperatorBase, 1, 1);
  162. public:
  163. using InterpolationMode = Param::InterpolationMode;
  164. protected:
  165. //! get origin coord
  166. std::pair<float, int> get_cubic_coord(float scale, int idx);
  167. std::tuple<float, int, float, int> get_nearest_linear_coord(
  168. InterpolationMode imode, float scale, int size, int idx);
  169. //! get nearest index in src
  170. int get_nearest_src(float scale, int size, int idx);
  171. void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
  172. };
  173. class ResizeForward : public ResizeBase {
  174. DEF_OPR_IMPL(ResizeForward, ResizeBase, 1, 1);
  175. public:
  176. virtual void exec(
  177. _megdnn_tensor_in src, _megdnn_tensor_out dst,
  178. _megdnn_workspace workspace) = 0;
  179. virtual size_t get_workspace_in_bytes(
  180. const TensorLayout& src, const TensorLayout& dst) = 0;
  181. protected:
  182. void check_exec(
  183. const TensorLayout& src, const TensorLayout& dst,
  184. size_t workspace_in_bytes);
  185. };
  186. using Resize = ResizeForward;
  187. class ResizeBackward : public ResizeBase {
  188. DEF_OPR_IMPL(ResizeBackward, ResizeBase, 1, 1);
  189. public:
  190. virtual void exec(
  191. _megdnn_tensor_in diff, _megdnn_tensor_out grad,
  192. _megdnn_workspace workspace) = 0;
  193. virtual size_t get_workspace_in_bytes(
  194. const TensorLayout& diff, const TensorLayout& mat) = 0;
  195. protected:
  196. void check_exec(
  197. const TensorLayout& diff, const TensorLayout& mat,
  198. size_t workspace_in_bytes);
  199. };
  200. /**
  201. * \brief Remap opr.
  202. */
  203. class RemapBase : public OperatorBase {
  204. DEF_OPR_PARAM(Remap);
  205. DEF_OPR_IMPL(RemapBase, OperatorBase, 2, 1);
  206. public:
  207. using InterpolationMode = Param::InterpolationMode;
  208. using BorderMode = Param::BorderMode;
  209. protected:
  210. void check_layout_fwd(
  211. const TensorLayout& src, const TensorLayout& map_xy,
  212. const TensorLayout& dst);
  213. void deduce_layout_fwd(
  214. const TensorLayout& src, const TensorLayout& map_xy, TensorLayout& dst);
  215. };
  216. class RemapForward : public RemapBase {
  217. DEF_OPR_IMPL(RemapForward, RemapBase, 2, 1);
  218. public:
  219. virtual void exec(
  220. _megdnn_tensor_in src, _megdnn_tensor_in map_xy, _megdnn_tensor_out dst,
  221. _megdnn_workspace workspace) = 0;
  222. void deduce_layout(
  223. const TensorLayout& src, const TensorLayout& map_xy, TensorLayout& dst);
  224. virtual size_t get_workspace_in_bytes(
  225. const TensorLayout& src, const TensorLayout& map_xy,
  226. const TensorLayout& dst) = 0;
  227. protected:
  228. void check_exec(
  229. const TensorLayout& src, const TensorLayout& map_xy,
  230. const TensorLayout& dst, size_t workspace_in_bytes);
  231. };
  232. using Remap = RemapForward;
  233. class RemapBackwardData : public RemapBase {
  234. DEF_OPR_IMPL(RemapBackwardData, RemapBase, 2, 1);
  235. public:
  236. virtual void exec(
  237. _megdnn_tensor_in map_xy, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
  238. _megdnn_workspace workspace) = 0;
  239. virtual size_t get_workspace_in_bytes(
  240. const TensorLayout& map_xy, const TensorLayout& diff,
  241. const TensorLayout& grad) = 0;
  242. protected:
  243. void check_exec(
  244. const TensorLayout& map_xy, const TensorLayout& diff,
  245. const TensorLayout& grad, size_t workspace_in_bytes);
  246. };
  247. class RemapBackwardMat : public RemapBase {
  248. DEF_OPR_IMPL(RemapBackwardMat, RemapBase, 3, 1);
  249. public:
  250. virtual void exec(
  251. _megdnn_tensor_in src, _megdnn_tensor_in map_xy, _megdnn_tensor_in diff,
  252. _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
  253. virtual size_t get_workspace_in_bytes(
  254. const TensorLayout& src, const TensorLayout& map_xy,
  255. const TensorLayout& diff, const TensorLayout& grad) = 0;
  256. protected:
  257. void check_exec(
  258. const TensorLayout& src, const TensorLayout& map_xy,
  259. const TensorLayout& diff, const TensorLayout& grad,
  260. size_t workspace_in_bytes);
  261. };
  262. class SeparableFilterBase : public OperatorBase {
  263. DEF_OPR_IMPL_CTOR(SeparableFilterBase, OperatorBase);
  264. DEF_OPR_PARAM(SeparableFilter);
  265. protected:
  266. void deduce_layout_fwd(
  267. const TensorLayout& src, const TensorLayout& filter_x,
  268. const TensorLayout& filter_y, TensorLayout& dst);
  269. void check_layout_fwd(
  270. const TensorLayout& src, const TensorLayout& filter_x,
  271. const TensorLayout& filter_y, const TensorLayout& dst);
  272. };
  273. class SeparableFilterForward : public SeparableFilterBase {
  274. DEF_OPR_IMPL(SeparableFilterForward, SeparableFilterBase, 3, 1);
  275. public:
  276. virtual void exec(
  277. _megdnn_tensor_in src, _megdnn_tensor_in filter_x,
  278. _megdnn_tensor_in filter_y, _megdnn_tensor_out dst,
  279. _megdnn_workspace workspace) = 0;
  280. void deduce_layout(
  281. const TensorLayout& src, const TensorLayout& filter_x,
  282. const TensorLayout& filter_y, TensorLayout& dst);
  283. virtual size_t get_workspace_in_bytes(
  284. const TensorLayout& src, const TensorLayout& filter_x,
  285. const TensorLayout& filter_y, const TensorLayout& dst) = 0;
  286. protected:
  287. void check_exec(
  288. const TensorLayout& src, const TensorLayout& filter_x,
  289. const TensorLayout& filter_y, const TensorLayout& dst,
  290. size_t workspace_in_bytes);
  291. };
  292. using SeparableFilter = SeparableFilterForward;
  293. } // namespace megdnn
  294. #include "megdnn/internal/opr_header_epilogue.h"
  295. // vim: syntax=cpp.doxygen