You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

cv.h 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349
  1. /**
  2. * \file dnn/include/megdnn/oprs/cv.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #pragma once
  13. #include "megdnn/internal/opr_header_prologue.h"
  14. namespace megdnn {
  15. /**
  16. * \brief This file contains CV operators. The layout is NHWC.
  17. */
  18. class FlipBase : public OperatorBase {
  19. DEF_OPR_IMPL_CTOR(FlipBase, OperatorBase);
  20. DEF_OPR_PARAM(Flip);
  21. protected:
  22. void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
  23. void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
  24. };
  25. class FlipForward : public FlipBase {
  26. DEF_OPR_IMPL(FlipForward, FlipBase, 1, 1);
  27. public:
  28. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
  29. _megdnn_workspace workspace) = 0;
  30. void deduce_layout(const TensorLayout& src, TensorLayout& dst);
  31. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  32. const TensorLayout& dst) = 0;
  33. protected:
  34. void check_exec(const TensorLayout& src, const TensorLayout& dst,
  35. size_t workspace_in_bytes);
  36. };
  37. using Flip = FlipForward;
  38. class RotateBase : public OperatorBase {
  39. DEF_OPR_IMPL_CTOR(RotateBase, OperatorBase);
  40. DEF_OPR_PARAM(Rotate);
  41. protected:
  42. void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
  43. void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
  44. };
  45. class RotateForward : public RotateBase {
  46. DEF_OPR_IMPL(RotateForward, RotateBase, 1, 1);
  47. public:
  48. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
  49. _megdnn_workspace workspace) = 0;
  50. void deduce_layout(const TensorLayout& src, TensorLayout& dst);
  51. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  52. const TensorLayout& dst) = 0;
  53. protected:
  54. void check_exec(const TensorLayout& src, const TensorLayout& dst,
  55. size_t workspace_in_bytes);
  56. };
  57. using Rotate = RotateForward;
  58. class ROICopyBase : public OperatorBase {
  59. DEF_OPR_IMPL_CTOR(ROICopyBase, OperatorBase);
  60. DEF_OPR_PARAM(ROICopy);
  61. protected:
  62. void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
  63. void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
  64. };
  65. class ROICopyForward : public ROICopyBase {
  66. DEF_OPR_IMPL(ROICopyForward, ROICopyBase, 1, 1);
  67. public:
  68. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
  69. _megdnn_workspace workspace) = 0;
  70. void deduce_layout(const TensorLayout& src, TensorLayout& dst);
  71. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  72. const TensorLayout& dst) = 0;
  73. protected:
  74. void check_exec(const TensorLayout& src, const TensorLayout& dst,
  75. size_t workspace_in_bytes);
  76. };
  77. using ROICopy = ROICopyForward;
  78. class CvtColorBase : public OperatorBase {
  79. DEF_OPR_IMPL_CTOR(CvtColorBase, OperatorBase);
  80. DEF_OPR_PARAM(CvtColor);
  81. protected:
  82. void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
  83. void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
  84. };
  85. class CvtColorForward : public CvtColorBase {
  86. DEF_OPR_IMPL(CvtColorForward, CvtColorBase, 1, 1);
  87. public:
  88. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
  89. _megdnn_workspace workspace) = 0;
  90. void deduce_layout(const TensorLayout& src, TensorLayout& dst);
  91. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  92. const TensorLayout& dst) = 0;
  93. protected:
  94. void check_exec(const TensorLayout& src, const TensorLayout& dst,
  95. size_t workspace_in_bytes);
  96. };
  97. using CvtColor = CvtColorForward;
  98. /**
  99. * \brief Applices an affine transformation
  100. */
  101. class WarpAffineBase : public OperatorBase {
  102. DEF_OPR_IMPL_CTOR(WarpAffineBase, OperatorBase);
  103. DEF_OPR_PARAM(WarpAffine);
  104. public:
  105. using InterpolationMode = Param::InterpolationMode;
  106. using BorderMode = Param::BorderMode;
  107. protected:
  108. void check_layout_fwd(const TensorLayout& src, const TensorLayout& trans,
  109. const TensorLayout& dst);
  110. std::string param_msg() const;
  111. int get_real_coord(int p, int len);
  112. };
  113. class WarpAffineForward : public WarpAffineBase {
  114. DEF_OPR_IMPL(WarpAffineForward, WarpAffineBase, 2, 1);
  115. public:
  116. /**
  117. * \param[in] src input tensor
  118. * \param[in] trans transform matrix tensor
  119. * \param[in] dst output tensor
  120. *
  121. * \warning src, trans, border_value, dst should be contiguous
  122. * The size of trans is N * 2 * 3
  123. */
  124. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in trans,
  125. _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
  126. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  127. const TensorLayout& trans,
  128. const TensorLayout& dst) = 0;
  129. protected:
  130. void check_exec(const TensorLayout& src, const TensorLayout& trans,
  131. const TensorLayout& dst, size_t workspace_in_bytes);
  132. };
  133. using WarpAffine = WarpAffineForward;
  134. class GaussianBlurBase : public OperatorBase {
  135. DEF_OPR_IMPL_CTOR(GaussianBlurBase, OperatorBase);
  136. DEF_OPR_PARAM(GaussianBlur);
  137. protected:
  138. void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
  139. void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
  140. };
  141. class GaussianBlurForward : public GaussianBlurBase {
  142. DEF_OPR_IMPL(GaussianBlurForward, GaussianBlurBase, 1, 1);
  143. public:
  144. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
  145. _megdnn_workspace workspace) = 0;
  146. void deduce_layout(const TensorLayout& src, TensorLayout& dst);
  147. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  148. const TensorLayout& dst) = 0;
  149. protected:
  150. void check_exec(const TensorLayout& src, const TensorLayout& dst,
  151. size_t workspace_in_bytes);
  152. };
  153. using GaussianBlur = GaussianBlurForward;
  154. /**
  155. * \brief Resize opr.
  156. */
  157. class ResizeBase : public OperatorBase {
  158. DEF_OPR_PARAM(Resize);
  159. DEF_OPR_IMPL(ResizeBase, OperatorBase, 1, 1);
  160. public:
  161. using InterpolationMode = Param::InterpolationMode;
  162. protected:
  163. //! get origin coord
  164. std::pair<float, int> get_origin_coord(float scale, int size, int idx);
  165. //! get nearest index in src
  166. int get_nearest_src(float scale, int size, int idx);
  167. void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
  168. };
  169. class ResizeForward : public ResizeBase {
  170. DEF_OPR_IMPL(ResizeForward, ResizeBase, 1, 1);
  171. public:
  172. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
  173. _megdnn_workspace workspace) = 0;
  174. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  175. const TensorLayout& dst) = 0;
  176. protected:
  177. void check_exec(const TensorLayout& src, const TensorLayout& dst,
  178. size_t workspace_in_bytes);
  179. };
  180. using Resize = ResizeForward;
  181. class ResizeBackward : public ResizeBase {
  182. DEF_OPR_IMPL(ResizeBackward, ResizeBase, 1, 1);
  183. public:
  184. virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad,
  185. _megdnn_workspace workspace) = 0;
  186. virtual size_t get_workspace_in_bytes(const TensorLayout& diff,
  187. const TensorLayout& mat) = 0;
  188. protected:
  189. void check_exec(const TensorLayout& diff, const TensorLayout& mat,
  190. size_t workspace_in_bytes);
  191. };
  192. /**
  193. * \brief Remap opr.
  194. */
  195. class RemapBase : public OperatorBase {
  196. DEF_OPR_PARAM(Remap);
  197. DEF_OPR_IMPL(RemapBase, OperatorBase, 2, 1);
  198. public:
  199. using InterpolationMode = Param::InterpolationMode;
  200. using BorderMode = Param::BorderMode;
  201. protected:
  202. void check_layout_fwd(const TensorLayout& src, const TensorLayout& map_xy,
  203. const TensorLayout& dst);
  204. void deduce_layout_fwd(const TensorLayout& src, const TensorLayout& map_xy,
  205. TensorLayout& dst);
  206. };
  207. class RemapForward : public RemapBase {
  208. DEF_OPR_IMPL(RemapForward, RemapBase, 2, 1);
  209. public:
  210. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy,
  211. _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
  212. void deduce_layout(const TensorLayout& src, const TensorLayout& map_xy,
  213. TensorLayout& dst);
  214. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  215. const TensorLayout& map_xy,
  216. const TensorLayout& dst) = 0;
  217. protected:
  218. void check_exec(const TensorLayout& src, const TensorLayout& map_xy,
  219. const TensorLayout& dst, size_t workspace_in_bytes);
  220. };
  221. using Remap = RemapForward;
  222. class RemapBackwardData : public RemapBase {
  223. DEF_OPR_IMPL(RemapBackwardData, RemapBase, 2, 1);
  224. public:
  225. virtual void exec(_megdnn_tensor_in map_xy, _megdnn_tensor_in diff,
  226. _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
  227. virtual size_t get_workspace_in_bytes(const TensorLayout& map_xy,
  228. const TensorLayout& diff,
  229. const TensorLayout& grad) = 0;
  230. protected:
  231. void check_exec(const TensorLayout& map_xy, const TensorLayout& diff,
  232. const TensorLayout& grad, size_t workspace_in_bytes);
  233. };
  234. class RemapBackwardMat : public RemapBase {
  235. DEF_OPR_IMPL(RemapBackwardMat, RemapBase, 3, 1);
  236. public:
  237. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy,
  238. _megdnn_tensor_in diff, _megdnn_tensor_out grad,
  239. _megdnn_workspace workspace) = 0;
  240. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  241. const TensorLayout& map_xy,
  242. const TensorLayout& diff,
  243. const TensorLayout& grad) = 0;
  244. protected:
  245. void check_exec(const TensorLayout& src, const TensorLayout& map_xy,
  246. const TensorLayout& diff, const TensorLayout& grad,
  247. size_t workspace_in_bytes);
  248. };
  249. class SeparableFilterBase : public OperatorBase {
  250. DEF_OPR_IMPL_CTOR(SeparableFilterBase, OperatorBase);
  251. DEF_OPR_PARAM(SeparableFilter);
  252. protected:
  253. void deduce_layout_fwd(const TensorLayout& src,
  254. const TensorLayout& filter_x,
  255. const TensorLayout& filter_y, TensorLayout& dst);
  256. void check_layout_fwd(const TensorLayout& src, const TensorLayout& filter_x,
  257. const TensorLayout& filter_y,
  258. const TensorLayout& dst);
  259. };
  260. class SeparableFilterForward : public SeparableFilterBase {
  261. DEF_OPR_IMPL(SeparableFilterForward, SeparableFilterBase, 3, 1);
  262. public:
  263. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter_x,
  264. _megdnn_tensor_in filter_y, _megdnn_tensor_out dst,
  265. _megdnn_workspace workspace) = 0;
  266. void deduce_layout(const TensorLayout& src, const TensorLayout& filter_x,
  267. const TensorLayout& filter_y, TensorLayout& dst);
  268. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  269. const TensorLayout& filter_x,
  270. const TensorLayout& filter_y,
  271. const TensorLayout& dst) = 0;
  272. protected:
  273. void check_exec(const TensorLayout& src, const TensorLayout& filter_x,
  274. const TensorLayout& filter_y, const TensorLayout& dst,
  275. size_t workspace_in_bytes);
  276. };
  277. using SeparableFilter = SeparableFilterForward;
  278. } // namespace megdnn
  279. #include "megdnn/internal/opr_header_epilogue.h"
  280. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台