You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

cv.h 12 kB


  1. /**
  2. * \file dnn/include/megdnn/oprs/cv.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #pragma once
  13. #include "megdnn/internal/opr_header_prologue.h"
  14. namespace megdnn {
  15. /**
  16. * \brief This file contains CV operators. The layout is NHWC.
  17. */
  18. class FlipBase : public OperatorBase {
  19. DEF_OPR_IMPL_CTOR(FlipBase, OperatorBase);
  20. DEF_OPR_PARAM(Flip);
  21. protected:
  22. void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
  23. void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
  24. };
  25. class FlipForward : public FlipBase {
  26. DEF_OPR_IMPL(FlipForward, FlipBase, 1, 1);
  27. public:
  28. virtual void exec(
  29. _megdnn_tensor_in src, _megdnn_tensor_out dst,
  30. _megdnn_workspace workspace) = 0;
  31. void deduce_layout(const TensorLayout& src, TensorLayout& dst);
  32. virtual size_t get_workspace_in_bytes(
  33. const TensorLayout& src, const TensorLayout& dst) = 0;
  34. protected:
  35. void check_exec(
  36. const TensorLayout& src, const TensorLayout& dst,
  37. size_t workspace_in_bytes);
  38. };
  39. using Flip = FlipForward;
  40. class RotateBase : public OperatorBase {
  41. DEF_OPR_IMPL_CTOR(RotateBase, OperatorBase);
  42. DEF_OPR_PARAM(Rotate);
  43. protected:
  44. void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
  45. void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
  46. };
  47. class RotateForward : public RotateBase {
  48. DEF_OPR_IMPL(RotateForward, RotateBase, 1, 1);
  49. public:
  50. virtual void exec(
  51. _megdnn_tensor_in src, _megdnn_tensor_out dst,
  52. _megdnn_workspace workspace) = 0;
  53. void deduce_layout(const TensorLayout& src, TensorLayout& dst);
  54. virtual size_t get_workspace_in_bytes(
  55. const TensorLayout& src, const TensorLayout& dst) = 0;
  56. protected:
  57. void check_exec(
  58. const TensorLayout& src, const TensorLayout& dst,
  59. size_t workspace_in_bytes);
  60. };
  61. using Rotate = RotateForward;
  62. class ROICopyBase : public OperatorBase {
  63. DEF_OPR_IMPL_CTOR(ROICopyBase, OperatorBase);
  64. DEF_OPR_PARAM(ROICopy);
  65. protected:
  66. void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
  67. void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
  68. };
  69. class ROICopyForward : public ROICopyBase {
  70. DEF_OPR_IMPL(ROICopyForward, ROICopyBase, 1, 1);
  71. public:
  72. virtual void exec(
  73. _megdnn_tensor_in src, _megdnn_tensor_out dst,
  74. _megdnn_workspace workspace) = 0;
  75. void deduce_layout(const TensorLayout& src, TensorLayout& dst);
  76. virtual size_t get_workspace_in_bytes(
  77. const TensorLayout& src, const TensorLayout& dst) = 0;
  78. protected:
  79. void check_exec(
  80. const TensorLayout& src, const TensorLayout& dst,
  81. size_t workspace_in_bytes);
  82. };
  83. using ROICopy = ROICopyForward;
  84. class CvtColorBase : public OperatorBase {
  85. DEF_OPR_IMPL_CTOR(CvtColorBase, OperatorBase);
  86. DEF_OPR_PARAM(CvtColor);
  87. protected:
  88. void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
  89. void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
  90. };
  91. class CvtColorForward : public CvtColorBase {
  92. DEF_OPR_IMPL(CvtColorForward, CvtColorBase, 1, 1);
  93. public:
  94. virtual void exec(
  95. _megdnn_tensor_in src, _megdnn_tensor_out dst,
  96. _megdnn_workspace workspace) = 0;
  97. void deduce_layout(const TensorLayout& src, TensorLayout& dst);
  98. virtual size_t get_workspace_in_bytes(
  99. const TensorLayout& src, const TensorLayout& dst) = 0;
  100. protected:
  101. void check_exec(
  102. const TensorLayout& src, const TensorLayout& dst,
  103. size_t workspace_in_bytes);
  104. };
  105. using CvtColor = CvtColorForward;
  106. /**
  107. * \brief Applices an affine transformation
  108. */
  109. class WarpAffineBase : public OperatorBase {
  110. DEF_OPR_IMPL_CTOR(WarpAffineBase, OperatorBase);
  111. DEF_OPR_PARAM(WarpAffine);
  112. public:
  113. using InterpolationMode = Param::InterpolationMode;
  114. using BorderMode = Param::BorderMode;
  115. protected:
  116. void check_layout_fwd(
  117. const TensorLayout& src, const TensorLayout& trans,
  118. const TensorLayout& dst);
  119. std::string param_msg() const;
  120. int get_real_coord(int p, int len);
  121. };
  122. class WarpAffineForward : public WarpAffineBase {
  123. DEF_OPR_IMPL(WarpAffineForward, WarpAffineBase, 2, 1);
  124. public:
  125. /**
  126. * \param[in] src input tensor
  127. * \param[in] trans transform matrix tensor
  128. * \param[in] dst output tensor
  129. *
  130. * \warning src, trans, border_value, dst should be contiguous
  131. * The size of trans is N * 2 * 3
  132. */
  133. virtual void exec(
  134. _megdnn_tensor_in src, _megdnn_tensor_in trans, _megdnn_tensor_out dst,
  135. _megdnn_workspace workspace) = 0;
  136. virtual size_t get_workspace_in_bytes(
  137. const TensorLayout& src, const TensorLayout& trans,
  138. const TensorLayout& dst) = 0;
  139. protected:
  140. void check_exec(
  141. const TensorLayout& src, const TensorLayout& trans, const TensorLayout& dst,
  142. size_t workspace_in_bytes);
  143. };
  144. using WarpAffine = WarpAffineForward;
  145. class GaussianBlurBase : public OperatorBase {
  146. DEF_OPR_IMPL_CTOR(GaussianBlurBase, OperatorBase);
  147. DEF_OPR_PARAM(GaussianBlur);
  148. protected:
  149. void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
  150. void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
  151. };
  152. class GaussianBlurForward : public GaussianBlurBase {
  153. DEF_OPR_IMPL(GaussianBlurForward, GaussianBlurBase, 1, 1);
  154. public:
  155. virtual void exec(
  156. _megdnn_tensor_in src, _megdnn_tensor_out dst,
  157. _megdnn_workspace workspace) = 0;
  158. void deduce_layout(const TensorLayout& src, TensorLayout& dst);
  159. virtual size_t get_workspace_in_bytes(
  160. const TensorLayout& src, const TensorLayout& dst) = 0;
  161. protected:
  162. void check_exec(
  163. const TensorLayout& src, const TensorLayout& dst,
  164. size_t workspace_in_bytes);
  165. };
  166. using GaussianBlur = GaussianBlurForward;
  167. /**
  168. * \brief Resize opr.
  169. */
  170. class ResizeBase : public OperatorBase {
  171. DEF_OPR_PARAM(Resize);
  172. DEF_OPR_IMPL(ResizeBase, OperatorBase, 1, 1);
  173. public:
  174. using InterpolationMode = Param::InterpolationMode;
  175. protected:
  176. //! get origin coord
  177. std::pair<float, int> get_cubic_coord(float scale, int idx);
  178. std::tuple<float, int, float, int> get_nearest_linear_coord(
  179. InterpolationMode imode, float scale, int size, int idx);
  180. //! get nearest index in src
  181. int get_nearest_src(float scale, int size, int idx);
  182. void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
  183. };
  184. class ResizeForward : public ResizeBase {
  185. DEF_OPR_IMPL(ResizeForward, ResizeBase, 1, 1);
  186. public:
  187. virtual void exec(
  188. _megdnn_tensor_in src, _megdnn_tensor_out dst,
  189. _megdnn_workspace workspace) = 0;
  190. virtual size_t get_workspace_in_bytes(
  191. const TensorLayout& src, const TensorLayout& dst) = 0;
  192. protected:
  193. void check_exec(
  194. const TensorLayout& src, const TensorLayout& dst,
  195. size_t workspace_in_bytes);
  196. };
  197. using Resize = ResizeForward;
  198. class ResizeBackward : public ResizeBase {
  199. DEF_OPR_IMPL(ResizeBackward, ResizeBase, 1, 1);
  200. public:
  201. virtual void exec(
  202. _megdnn_tensor_in diff, _megdnn_tensor_out grad,
  203. _megdnn_workspace workspace) = 0;
  204. virtual size_t get_workspace_in_bytes(
  205. const TensorLayout& diff, const TensorLayout& mat) = 0;
  206. protected:
  207. void check_exec(
  208. const TensorLayout& diff, const TensorLayout& mat,
  209. size_t workspace_in_bytes);
  210. };
  211. /**
  212. * \brief Remap opr.
  213. */
  214. class RemapBase : public OperatorBase {
  215. DEF_OPR_PARAM(Remap);
  216. DEF_OPR_IMPL(RemapBase, OperatorBase, 2, 1);
  217. public:
  218. using InterpolationMode = Param::InterpolationMode;
  219. using BorderMode = Param::BorderMode;
  220. protected:
  221. void check_layout_fwd(
  222. const TensorLayout& src, const TensorLayout& map_xy,
  223. const TensorLayout& dst);
  224. void deduce_layout_fwd(
  225. const TensorLayout& src, const TensorLayout& map_xy, TensorLayout& dst);
  226. };
  227. class RemapForward : public RemapBase {
  228. DEF_OPR_IMPL(RemapForward, RemapBase, 2, 1);
  229. public:
  230. virtual void exec(
  231. _megdnn_tensor_in src, _megdnn_tensor_in map_xy, _megdnn_tensor_out dst,
  232. _megdnn_workspace workspace) = 0;
  233. void deduce_layout(
  234. const TensorLayout& src, const TensorLayout& map_xy, TensorLayout& dst);
  235. virtual size_t get_workspace_in_bytes(
  236. const TensorLayout& src, const TensorLayout& map_xy,
  237. const TensorLayout& dst) = 0;
  238. protected:
  239. void check_exec(
  240. const TensorLayout& src, const TensorLayout& map_xy,
  241. const TensorLayout& dst, size_t workspace_in_bytes);
  242. };
  243. using Remap = RemapForward;
  244. class RemapBackwardData : public RemapBase {
  245. DEF_OPR_IMPL(RemapBackwardData, RemapBase, 2, 1);
  246. public:
  247. virtual void exec(
  248. _megdnn_tensor_in map_xy, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
  249. _megdnn_workspace workspace) = 0;
  250. virtual size_t get_workspace_in_bytes(
  251. const TensorLayout& map_xy, const TensorLayout& diff,
  252. const TensorLayout& grad) = 0;
  253. protected:
  254. void check_exec(
  255. const TensorLayout& map_xy, const TensorLayout& diff,
  256. const TensorLayout& grad, size_t workspace_in_bytes);
  257. };
  258. class RemapBackwardMat : public RemapBase {
  259. DEF_OPR_IMPL(RemapBackwardMat, RemapBase, 3, 1);
  260. public:
  261. virtual void exec(
  262. _megdnn_tensor_in src, _megdnn_tensor_in map_xy, _megdnn_tensor_in diff,
  263. _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
  264. virtual size_t get_workspace_in_bytes(
  265. const TensorLayout& src, const TensorLayout& map_xy,
  266. const TensorLayout& diff, const TensorLayout& grad) = 0;
  267. protected:
  268. void check_exec(
  269. const TensorLayout& src, const TensorLayout& map_xy,
  270. const TensorLayout& diff, const TensorLayout& grad,
  271. size_t workspace_in_bytes);
  272. };
  273. class SeparableFilterBase : public OperatorBase {
  274. DEF_OPR_IMPL_CTOR(SeparableFilterBase, OperatorBase);
  275. DEF_OPR_PARAM(SeparableFilter);
  276. protected:
  277. void deduce_layout_fwd(
  278. const TensorLayout& src, const TensorLayout& filter_x,
  279. const TensorLayout& filter_y, TensorLayout& dst);
  280. void check_layout_fwd(
  281. const TensorLayout& src, const TensorLayout& filter_x,
  282. const TensorLayout& filter_y, const TensorLayout& dst);
  283. };
  284. class SeparableFilterForward : public SeparableFilterBase {
  285. DEF_OPR_IMPL(SeparableFilterForward, SeparableFilterBase, 3, 1);
  286. public:
  287. virtual void exec(
  288. _megdnn_tensor_in src, _megdnn_tensor_in filter_x,
  289. _megdnn_tensor_in filter_y, _megdnn_tensor_out dst,
  290. _megdnn_workspace workspace) = 0;
  291. void deduce_layout(
  292. const TensorLayout& src, const TensorLayout& filter_x,
  293. const TensorLayout& filter_y, TensorLayout& dst);
  294. virtual size_t get_workspace_in_bytes(
  295. const TensorLayout& src, const TensorLayout& filter_x,
  296. const TensorLayout& filter_y, const TensorLayout& dst) = 0;
  297. protected:
  298. void check_exec(
  299. const TensorLayout& src, const TensorLayout& filter_x,
  300. const TensorLayout& filter_y, const TensorLayout& dst,
  301. size_t workspace_in_bytes);
  302. };
  303. using SeparableFilter = SeparableFilterForward;
  304. } // namespace megdnn
  305. #include "megdnn/internal/opr_header_epilogue.h"
  306. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。