You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

cv.h 12 kB


  1. /**
  2. * \file dnn/include/megdnn/oprs/cv.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #pragma once
  13. #include "megdnn/internal/opr_header_prologue.h"
  14. namespace megdnn {
  15. /**
  16. * \brief This file contains CV operators. The layout is NHWC.
  17. */
  /**
   * \brief Common base of the Flip operator classes declared in this file.
   *
   * Declares only the protected layout helpers; exec() is declared by the
   * derived FlipForward class.
   * NOTE(review): DEF_OPR_IMPL_CTOR and DEF_OPR_PARAM are defined outside this
   * file, so what they expand to is not visible here.
   */
  18. class FlipBase : public OperatorBase {
  19. DEF_OPR_IMPL_CTOR(FlipBase, OperatorBase);
  20. DEF_OPR_PARAM(Flip);
  21. protected:
  //! \p dst is taken by non-const reference, i.e. it is the output of this helper.
  //! NOTE(review): presumably derives the dst layout from src -- confirm in the implementation.
  22. void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
  //! Both layouts are read-only here.
  //! NOTE(review): presumably validates the src and dst pair -- confirm in the implementation.
  23. void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
  24. };
  /**
   * \brief Abstract forward interface of the Flip operator.
   *
   * exec() and get_workspace_in_bytes() are pure virtual, so a concrete
   * subclass has to implement them.
   * NOTE(review): the trailing "1, 1" given to DEF_OPR_IMPL presumably are the
   * input and output tensor counts -- confirm against the macro definition.
   */
  25. class FlipForward : public FlipBase {
  26. DEF_OPR_IMPL(FlipForward, FlipBase, 1, 1);
  27. public:
  /**
   * \param[in] src input tensor
   * \param[out] dst output tensor
   * \param[in] workspace workspace handed to the implementation
   *      (presumably sized via get_workspace_in_bytes() -- confirm)
   */
  28. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
  29. _megdnn_workspace workspace) = 0;
  //! Non-virtual; \p dst is the output (non-const reference).
  30. void deduce_layout(const TensorLayout& src, TensorLayout& dst);
  //! \return a byte count (per the function name) for the given layout pair.
  31. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  32. const TensorLayout& dst) = 0;
  33. protected:
  //! NOTE(review): presumably checks the layouts and \p workspace_in_bytes
  //! before exec() runs -- confirm in the implementation.
  34. void check_exec(const TensorLayout& src, const TensorLayout& dst,
  35. size_t workspace_in_bytes);
  36. };
  //! Short alias: Flip names the forward operator.
  37. using Flip = FlipForward;
  /**
   * \brief Common base of the Rotate operator classes declared in this file.
   *
   * Declares only the protected layout helpers; exec() is declared by the
   * derived RotateForward class.
   * NOTE(review): DEF_OPR_IMPL_CTOR and DEF_OPR_PARAM are defined outside this
   * file, so what they expand to is not visible here.
   */
  38. class RotateBase : public OperatorBase {
  39. DEF_OPR_IMPL_CTOR(RotateBase, OperatorBase);
  40. DEF_OPR_PARAM(Rotate);
  41. protected:
  //! \p dst is taken by non-const reference, i.e. it is the output of this helper.
  //! NOTE(review): presumably derives the dst layout from src -- confirm in the implementation.
  42. void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
  //! Both layouts are read-only here.
  //! NOTE(review): presumably validates the src and dst pair -- confirm in the implementation.
  43. void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
  44. };
  /**
   * \brief Abstract forward interface of the Rotate operator.
   *
   * exec() and get_workspace_in_bytes() are pure virtual, so a concrete
   * subclass has to implement them.
   * NOTE(review): the trailing "1, 1" given to DEF_OPR_IMPL presumably are the
   * input and output tensor counts -- confirm against the macro definition.
   */
  45. class RotateForward : public RotateBase {
  46. DEF_OPR_IMPL(RotateForward, RotateBase, 1, 1);
  47. public:
  /**
   * \param[in] src input tensor
   * \param[out] dst output tensor
   * \param[in] workspace workspace handed to the implementation
   *      (presumably sized via get_workspace_in_bytes() -- confirm)
   */
  48. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
  49. _megdnn_workspace workspace) = 0;
  //! Non-virtual; \p dst is the output (non-const reference).
  50. void deduce_layout(const TensorLayout& src, TensorLayout& dst);
  //! \return a byte count (per the function name) for the given layout pair.
  51. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  52. const TensorLayout& dst) = 0;
  53. protected:
  //! NOTE(review): presumably checks the layouts and \p workspace_in_bytes
  //! before exec() runs -- confirm in the implementation.
  54. void check_exec(const TensorLayout& src, const TensorLayout& dst,
  55. size_t workspace_in_bytes);
  56. };
  //! Short alias: Rotate names the forward operator.
  57. using Rotate = RotateForward;
  /**
   * \brief Common base of the ROICopy operator classes declared in this file.
   *
   * Declares only the protected layout helpers; exec() is declared by the
   * derived ROICopyForward class.
   * NOTE(review): DEF_OPR_IMPL_CTOR and DEF_OPR_PARAM are defined outside this
   * file, so what they expand to is not visible here.
   */
  58. class ROICopyBase : public OperatorBase {
  59. DEF_OPR_IMPL_CTOR(ROICopyBase, OperatorBase);
  60. DEF_OPR_PARAM(ROICopy);
  61. protected:
  //! \p dst is taken by non-const reference, i.e. it is the output of this helper.
  //! NOTE(review): presumably derives the dst layout from src -- confirm in the implementation.
  62. void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
  //! Both layouts are read-only here.
  //! NOTE(review): presumably validates the src and dst pair -- confirm in the implementation.
  63. void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
  64. };
  /**
   * \brief Abstract forward interface of the ROICopy operator.
   *
   * exec() and get_workspace_in_bytes() are pure virtual, so a concrete
   * subclass has to implement them.
   * NOTE(review): the trailing "1, 1" given to DEF_OPR_IMPL presumably are the
   * input and output tensor counts -- confirm against the macro definition.
   */
  65. class ROICopyForward : public ROICopyBase {
  66. DEF_OPR_IMPL(ROICopyForward, ROICopyBase, 1, 1);
  67. public:
  /**
   * \param[in] src input tensor
   * \param[out] dst output tensor
   * \param[in] workspace workspace handed to the implementation
   *      (presumably sized via get_workspace_in_bytes() -- confirm)
   */
  68. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
  69. _megdnn_workspace workspace) = 0;
  //! Non-virtual; \p dst is the output (non-const reference).
  70. void deduce_layout(const TensorLayout& src, TensorLayout& dst);
  //! \return a byte count (per the function name) for the given layout pair.
  71. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  72. const TensorLayout& dst) = 0;
  73. protected:
  //! NOTE(review): presumably checks the layouts and \p workspace_in_bytes
  //! before exec() runs -- confirm in the implementation.
  74. void check_exec(const TensorLayout& src, const TensorLayout& dst,
  75. size_t workspace_in_bytes);
  76. };
  //! Short alias: ROICopy names the forward operator.
  77. using ROICopy = ROICopyForward;
  /**
   * \brief Common base of the CvtColor operator classes declared in this file.
   *
   * Declares only the protected layout helpers; exec() is declared by the
   * derived CvtColorForward class.
   * NOTE(review): DEF_OPR_IMPL_CTOR and DEF_OPR_PARAM are defined outside this
   * file, so what they expand to is not visible here.
   */
  78. class CvtColorBase : public OperatorBase {
  79. DEF_OPR_IMPL_CTOR(CvtColorBase, OperatorBase);
  80. DEF_OPR_PARAM(CvtColor);
  81. protected:
  //! \p dst is taken by non-const reference, i.e. it is the output of this helper.
  //! NOTE(review): presumably derives the dst layout from src -- confirm in the implementation.
  82. void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
  //! Both layouts are read-only here.
  //! NOTE(review): presumably validates the src and dst pair -- confirm in the implementation.
  83. void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
  84. };
  /**
   * \brief Abstract forward interface of the CvtColor operator.
   *
   * exec() and get_workspace_in_bytes() are pure virtual, so a concrete
   * subclass has to implement them.
   * NOTE(review): the trailing "1, 1" given to DEF_OPR_IMPL presumably are the
   * input and output tensor counts -- confirm against the macro definition.
   */
  85. class CvtColorForward : public CvtColorBase {
  86. DEF_OPR_IMPL(CvtColorForward, CvtColorBase, 1, 1);
  87. public:
  /**
   * \param[in] src input tensor
   * \param[out] dst output tensor
   * \param[in] workspace workspace handed to the implementation
   *      (presumably sized via get_workspace_in_bytes() -- confirm)
   */
  88. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
  89. _megdnn_workspace workspace) = 0;
  //! Non-virtual; \p dst is the output (non-const reference).
  90. void deduce_layout(const TensorLayout& src, TensorLayout& dst);
  //! \return a byte count (per the function name) for the given layout pair.
  91. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  92. const TensorLayout& dst) = 0;
  93. protected:
  //! NOTE(review): presumably checks the layouts and \p workspace_in_bytes
  //! before exec() runs -- confirm in the implementation.
  94. void check_exec(const TensorLayout& src, const TensorLayout& dst,
  95. size_t workspace_in_bytes);
  96. };
  //! Short alias: CvtColor names the forward operator.
  97. using CvtColor = CvtColorForward;
  98. /**
  99. * \brief Applies an affine transformation
  * 
  * Base class of the WarpAffine operator: re-exports two types nested in
  * Param and declares protected helpers. Unlike most other *Base classes in
  * this file it declares no deduce_layout_fwd().
  * NOTE(review): DEF_OPR_IMPL_CTOR and DEF_OPR_PARAM are defined outside this
  * file, so what they expand to is not visible here.
  100. */
  101. class WarpAffineBase : public OperatorBase {
  102. DEF_OPR_IMPL_CTOR(WarpAffineBase, OperatorBase);
  103. DEF_OPR_PARAM(WarpAffine);
  104. public:
  //! Aliases for the types nested in Param, so they can be named through the operator class.
  105. using InterpolationMode = Param::InterpolationMode;
  106. using BorderMode = Param::BorderMode;
  107. protected:
  //! All three layouts are read-only here.
  //! NOTE(review): presumably validates src, trans and dst together -- confirm.
  108. void check_layout_fwd(const TensorLayout& src, const TensorLayout& trans,
  109. const TensorLayout& dst);
  //! NOTE(review): presumably a textual description of the current param
  //! (e.g. for diagnostics) -- confirm in the implementation.
  110. std::string param_msg() const;
  //! NOTE(review): presumably maps coordinate \p p onto an axis of length
  //! \p len according to the border mode -- confirm in the implementation.
  111. int get_real_coord(int p, int len);
  112. };
  /**
   * \brief Abstract forward interface of the WarpAffine operator.
   *
   * exec() and get_workspace_in_bytes() are pure virtual, so a concrete
   * subclass has to implement them. No deduce_layout() is declared here.
   * NOTE(review): the trailing "2, 1" given to DEF_OPR_IMPL presumably are the
   * input and output tensor counts -- confirm against the macro definition.
   */
  113. class WarpAffineForward : public WarpAffineBase {
  114. DEF_OPR_IMPL(WarpAffineForward, WarpAffineBase, 2, 1);
  115. public:
  116. /**
  117. * \param[in] src input tensor
  118. * \param[in] trans transform matrix tensor
  119. * \param[out] dst output tensor
  * \param[in] workspace workspace handed to the implementation
  *      (presumably sized via get_workspace_in_bytes() -- confirm)
  120. *
  121. * \warning src, trans, border_value, dst should be contiguous
  122. * The size of trans is N * 2 * 3
  *
  * NOTE(review): border_value is not a parameter of exec(); presumably it
  * refers to a field of Param -- confirm or drop it from the warning.
  123. */
  124. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in trans,
  125. _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
  //! \return a byte count (per the function name) for the given layouts.
  126. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  127. const TensorLayout& trans,
  128. const TensorLayout& dst) = 0;
  129. protected:
  //! NOTE(review): presumably checks the layouts and \p workspace_in_bytes
  //! before exec() runs -- confirm in the implementation.
  130. void check_exec(const TensorLayout& src, const TensorLayout& trans,
  131. const TensorLayout& dst, size_t workspace_in_bytes);
  132. };
  //! Short alias: WarpAffine names the forward operator.
  133. using WarpAffine = WarpAffineForward;
  /**
   * \brief Common base of the GaussianBlur operator classes declared in this file.
   *
   * Declares only the protected layout helpers; exec() is declared by the
   * derived GaussianBlurForward class.
   * NOTE(review): DEF_OPR_IMPL_CTOR and DEF_OPR_PARAM are defined outside this
   * file, so what they expand to is not visible here.
   */
  134. class GaussianBlurBase : public OperatorBase {
  135. DEF_OPR_IMPL_CTOR(GaussianBlurBase, OperatorBase);
  136. DEF_OPR_PARAM(GaussianBlur);
  137. protected:
  //! \p dst is taken by non-const reference, i.e. it is the output of this helper.
  //! NOTE(review): presumably derives the dst layout from src -- confirm in the implementation.
  138. void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
  //! Both layouts are read-only here.
  //! NOTE(review): presumably validates the src and dst pair -- confirm in the implementation.
  139. void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
  140. };
  /**
   * \brief Abstract forward interface of the GaussianBlur operator.
   *
   * exec() and get_workspace_in_bytes() are pure virtual, so a concrete
   * subclass has to implement them.
   * NOTE(review): the trailing "1, 1" given to DEF_OPR_IMPL presumably are the
   * input and output tensor counts -- confirm against the macro definition.
   */
  141. class GaussianBlurForward : public GaussianBlurBase {
  142. DEF_OPR_IMPL(GaussianBlurForward, GaussianBlurBase, 1, 1);
  143. public:
  /**
   * \param[in] src input tensor
   * \param[out] dst output tensor
   * \param[in] workspace workspace handed to the implementation
   *      (presumably sized via get_workspace_in_bytes() -- confirm)
   */
  144. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
  145. _megdnn_workspace workspace) = 0;
  //! Non-virtual; \p dst is the output (non-const reference).
  146. void deduce_layout(const TensorLayout& src, TensorLayout& dst);
  //! \return a byte count (per the function name) for the given layout pair.
  147. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  148. const TensorLayout& dst) = 0;
  149. protected:
  //! NOTE(review): presumably checks the layouts and \p workspace_in_bytes
  //! before exec() runs -- confirm in the implementation.
  150. void check_exec(const TensorLayout& src, const TensorLayout& dst,
  151. size_t workspace_in_bytes);
  152. };
  //! Short alias: GaussianBlur names the forward operator.
  153. using GaussianBlur = GaussianBlurForward;
  154. /**
  155. * \brief Resize opr.
  *
  * Base class shared by ResizeForward and ResizeBackward.
  * NOTE(review): this class uses DEF_OPR_IMPL (with "1, 1") and places
  * DEF_OPR_PARAM first, whereas most other *Base classes in this file use
  * DEF_OPR_IMPL_CTOR; whether that difference is intentional cannot be told
  * from this header.
  156. */
  157. class ResizeBase : public OperatorBase {
  158. DEF_OPR_PARAM(Resize);
  159. DEF_OPR_IMPL(ResizeBase, OperatorBase, 1, 1);
  160. public:
  //! Alias for the type nested in Param, so it can be named through the operator class.
  161. using InterpolationMode = Param::InterpolationMode;
  162. protected:
  163. //! get origin coord
  //! NOTE(review): presumably maps index \p idx back to a source coordinate
  //! using \p scale and \p size; the meaning of the returned (float, int)
  //! pair is not visible in this header -- confirm in the implementation.
  164. std::pair<float, int> get_origin_coord(float scale, int size, int idx);
  //! Both layouts are read-only here.
  //! NOTE(review): presumably validates the src and dst pair -- confirm.
  165. void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
  166. };
  /**
   * \brief Abstract forward interface of the Resize operator.
   *
   * exec() and get_workspace_in_bytes() are pure virtual, so a concrete
   * subclass has to implement them. No deduce_layout() is declared here.
   * NOTE(review): the trailing "1, 1" given to DEF_OPR_IMPL presumably are the
   * input and output tensor counts -- confirm against the macro definition.
   */
  167. class ResizeForward : public ResizeBase {
  168. DEF_OPR_IMPL(ResizeForward, ResizeBase, 1, 1);
  169. public:
  /**
   * \param[in] src input tensor
   * \param[out] dst output tensor
   * \param[in] workspace workspace handed to the implementation
   *      (presumably sized via get_workspace_in_bytes() -- confirm)
   */
  170. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
  171. _megdnn_workspace workspace) = 0;
  //! \return a byte count (per the function name) for the given layout pair.
  172. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  173. const TensorLayout& dst) = 0;
  174. protected:
  //! NOTE(review): presumably checks the layouts and \p workspace_in_bytes
  //! before exec() runs -- confirm in the implementation.
  175. void check_exec(const TensorLayout& src, const TensorLayout& dst,
  176. size_t workspace_in_bytes);
  177. };
  //! Short alias: Resize names the forward operator.
  178. using Resize = ResizeForward;
  /**
   * \brief Abstract backward interface of the Resize operator.
   *
   * exec() and get_workspace_in_bytes() are pure virtual, so a concrete
   * subclass has to implement them.
   * NOTE(review): exec() names its output tensor "grad", while
   * get_workspace_in_bytes() and check_exec() name the second layout "mat".
   * These presumably refer to the same tensor -- confirm, and consider
   * unifying the names.
   */
  179. class ResizeBackward : public ResizeBase {
  180. DEF_OPR_IMPL(ResizeBackward, ResizeBase, 1, 1);
  181. public:
  /**
   * \param[in] diff input tensor (presumably the gradient w.r.t. the forward output -- confirm)
   * \param[out] grad output tensor (presumably the gradient w.r.t. the forward input -- confirm)
   * \param[in] workspace workspace handed to the implementation
   */
  182. virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad,
  183. _megdnn_workspace workspace) = 0;
  //! \return a byte count (per the function name) for the given layout pair.
  184. virtual size_t get_workspace_in_bytes(const TensorLayout& diff,
  185. const TensorLayout& mat) = 0;
  186. protected:
  //! NOTE(review): presumably checks the layouts and \p workspace_in_bytes
  //! before exec() runs -- confirm in the implementation.
  187. void check_exec(const TensorLayout& diff, const TensorLayout& mat,
  188. size_t workspace_in_bytes);
  189. };
  190. /**
  191. * \brief Remap opr.
  *
  * Base class shared by RemapForward, RemapBackwardData and RemapBackwardMat.
  * NOTE(review): like ResizeBase, this class uses DEF_OPR_IMPL (with "2, 1")
  * rather than DEF_OPR_IMPL_CTOR; the macro expansions are not visible in
  * this header.
  192. */
  193. class RemapBase : public OperatorBase {
  194. DEF_OPR_PARAM(Remap);
  195. DEF_OPR_IMPL(RemapBase, OperatorBase, 2, 1);
  196. public:
  //! Aliases for the types nested in Param, so they can be named through the operator class.
  197. using InterpolationMode = Param::InterpolationMode;
  198. using BorderMode = Param::BorderMode;
  199. protected:
  //! All three layouts are read-only here.
  //! NOTE(review): presumably validates src, map_xy and dst together -- confirm.
  200. void check_layout_fwd(const TensorLayout& src, const TensorLayout& map_xy,
  201. const TensorLayout& dst);
  //! \p dst is taken by non-const reference, i.e. it is the output of this helper.
  //! NOTE(review): presumably derives the dst layout from src and map_xy -- confirm.
  202. void deduce_layout_fwd(const TensorLayout& src, const TensorLayout& map_xy,
  203. TensorLayout& dst);
  204. };
  /**
   * \brief Abstract forward interface of the Remap operator.
   *
   * exec() and get_workspace_in_bytes() are pure virtual, so a concrete
   * subclass has to implement them.
   * NOTE(review): the trailing "2, 1" given to DEF_OPR_IMPL presumably are the
   * input and output tensor counts -- confirm against the macro definition.
   */
  205. class RemapForward : public RemapBase {
  206. DEF_OPR_IMPL(RemapForward, RemapBase, 2, 1);
  207. public:
  /**
   * \param[in] src input tensor
   * \param[in] map_xy second input tensor (presumably the per-pixel coordinate map -- confirm)
   * \param[out] dst output tensor
   * \param[in] workspace workspace handed to the implementation
   */
  208. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy,
  209. _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
  //! Non-virtual; \p dst is the output (non-const reference).
  210. void deduce_layout(const TensorLayout& src, const TensorLayout& map_xy,
  211. TensorLayout& dst);
  //! \return a byte count (per the function name) for the given layouts.
  212. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  213. const TensorLayout& map_xy,
  214. const TensorLayout& dst) = 0;
  215. protected:
  //! NOTE(review): presumably checks the layouts and \p workspace_in_bytes
  //! before exec() runs -- confirm in the implementation.
  216. void check_exec(const TensorLayout& src, const TensorLayout& map_xy,
  217. const TensorLayout& dst, size_t workspace_in_bytes);
  218. };
  //! Short alias: Remap names the forward operator.
  219. using Remap = RemapForward;
  /**
   * \brief Abstract backward-data interface of the Remap operator.
   *
   * exec() and get_workspace_in_bytes() are pure virtual, so a concrete
   * subclass has to implement them.
   * NOTE(review): presumably computes the gradient w.r.t. the forward input
   * "src" from map_xy and diff -- inferred from the class name, confirm.
   */
  220. class RemapBackwardData : public RemapBase {
  221. DEF_OPR_IMPL(RemapBackwardData, RemapBase, 2, 1);
  222. public:
  /**
   * \param[in] map_xy input tensor
   * \param[in] diff input tensor (presumably the gradient w.r.t. the forward output -- confirm)
   * \param[out] grad output tensor
   * \param[in] workspace workspace handed to the implementation
   */
  223. virtual void exec(_megdnn_tensor_in map_xy, _megdnn_tensor_in diff,
  224. _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
  //! \return a byte count (per the function name) for the given layouts.
  225. virtual size_t get_workspace_in_bytes(const TensorLayout& map_xy,
  226. const TensorLayout& diff,
  227. const TensorLayout& grad) = 0;
  228. protected:
  //! NOTE(review): presumably checks the layouts and \p workspace_in_bytes
  //! before exec() runs -- confirm in the implementation.
  229. void check_exec(const TensorLayout& map_xy, const TensorLayout& diff,
  230. const TensorLayout& grad, size_t workspace_in_bytes);
  231. };
  /**
   * \brief Abstract backward-mat interface of the Remap operator.
   *
   * exec() and get_workspace_in_bytes() are pure virtual, so a concrete
   * subclass has to implement them.
   * NOTE(review): presumably computes the gradient w.r.t. map_xy -- inferred
   * from the class name, confirm. The "3, 1" given to DEF_OPR_IMPL matches the
   * three input tensors and one output tensor of exec().
   */
  232. class RemapBackwardMat : public RemapBase {
  233. DEF_OPR_IMPL(RemapBackwardMat, RemapBase, 3, 1);
  234. public:
  /**
   * \param[in] src input tensor
   * \param[in] map_xy input tensor
   * \param[in] diff input tensor (presumably the gradient w.r.t. the forward output -- confirm)
   * \param[out] grad output tensor
   * \param[in] workspace workspace handed to the implementation
   */
  235. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy,
  236. _megdnn_tensor_in diff, _megdnn_tensor_out grad,
  237. _megdnn_workspace workspace) = 0;
  //! \return a byte count (per the function name) for the given layouts.
  238. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  239. const TensorLayout& map_xy,
  240. const TensorLayout& diff,
  241. const TensorLayout& grad) = 0;
  242. protected:
  //! NOTE(review): presumably checks the layouts and \p workspace_in_bytes
  //! before exec() runs -- confirm in the implementation.
  243. void check_exec(const TensorLayout& src, const TensorLayout& map_xy,
  244. const TensorLayout& diff, const TensorLayout& grad,
  245. size_t workspace_in_bytes);
  246. };
  /**
   * \brief Common base of the SeparableFilter operator classes declared in this file.
   *
   * Declares only the protected layout helpers; exec() is declared by the
   * derived SeparableFilterForward class.
   * NOTE(review): DEF_OPR_IMPL_CTOR and DEF_OPR_PARAM are defined outside this
   * file, so what they expand to is not visible here.
   */
  247. class SeparableFilterBase : public OperatorBase {
  248. DEF_OPR_IMPL_CTOR(SeparableFilterBase, OperatorBase);
  249. DEF_OPR_PARAM(SeparableFilter);
  250. protected:
  //! \p dst is taken by non-const reference, i.e. it is the output of this helper.
  //! NOTE(review): presumably derives the dst layout from src and the two
  //! filter layouts -- confirm in the implementation.
  251. void deduce_layout_fwd(const TensorLayout& src,
  252. const TensorLayout& filter_x,
  253. const TensorLayout& filter_y, TensorLayout& dst);
  //! All four layouts are read-only here.
  //! NOTE(review): presumably validates them together -- confirm in the implementation.
  254. void check_layout_fwd(const TensorLayout& src, const TensorLayout& filter_x,
  255. const TensorLayout& filter_y,
  256. const TensorLayout& dst);
  257. };
  /**
   * \brief Abstract forward interface of the SeparableFilter operator.
   *
   * exec() and get_workspace_in_bytes() are pure virtual, so a concrete
   * subclass has to implement them.
   * NOTE(review): the "3, 1" given to DEF_OPR_IMPL matches the three input
   * tensors and one output tensor of exec(); presumably these are the
   * input and output counts -- confirm against the macro definition.
   */
  258. class SeparableFilterForward : public SeparableFilterBase {
  259. DEF_OPR_IMPL(SeparableFilterForward, SeparableFilterBase, 3, 1);
  260. public:
  /**
   * \param[in] src input tensor
   * \param[in] filter_x input tensor (presumably the filter along x -- confirm)
   * \param[in] filter_y input tensor (presumably the filter along y -- confirm)
   * \param[out] dst output tensor
   * \param[in] workspace workspace handed to the implementation
   */
  261. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter_x,
  262. _megdnn_tensor_in filter_y, _megdnn_tensor_out dst,
  263. _megdnn_workspace workspace) = 0;
  //! Non-virtual; \p dst is the output (non-const reference).
  264. void deduce_layout(const TensorLayout& src, const TensorLayout& filter_x,
  265. const TensorLayout& filter_y, TensorLayout& dst);
  //! \return a byte count (per the function name) for the given layouts.
  266. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  267. const TensorLayout& filter_x,
  268. const TensorLayout& filter_y,
  269. const TensorLayout& dst) = 0;
  270. protected:
  //! NOTE(review): presumably checks the layouts and \p workspace_in_bytes
  //! before exec() runs -- confirm in the implementation.
  271. void check_exec(const TensorLayout& src, const TensorLayout& filter_x,
  272. const TensorLayout& filter_y, const TensorLayout& dst,
  273. size_t workspace_in_bytes);
  274. };
  //! Short alias: SeparableFilter names the forward operator.
  275. using SeparableFilter = SeparableFilterForward;
  276. } // namespace megdnn
  277. #include "megdnn/internal/opr_header_epilogue.h"
  278. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台