You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

imgproc.h 8.1 kB


  1. /**
  2. * \file dnn/include/megdnn/oprs/imgproc.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
  12. #include "megdnn/internal/opr_header_prologue.h"
  13. namespace megdnn {
  14. class WarpPerspectiveBase : public OperatorBase {
  15. DEF_OPR_IMPL_CTOR(WarpPerspectiveBase, OperatorBase);
  16. DEF_OPR_PARAM(WarpPerspective);
  17. public:
  18. using InterpolationMode = Param::InterpolationMode;
  19. using BorderMode = Param::BorderMode;
  20. protected:
  21. void check_layout_fwd(
  22. const TensorLayout& src, const TensorLayout& mat, const TensorLayout& dst) {
  23. check_layout_fwd(src, mat, {}, dst);
  24. }
  25. void check_layout_fwd(
  26. const TensorLayout& src, const TensorLayout& mat,
  27. const TensorLayout& mat_idx, const TensorLayout& dst);
  28. std::string param_msg() const;
  29. int get_real_coord(int p, int len);
  30. };
  31. class WarpPerspectiveForward : public WarpPerspectiveBase {
  32. DEF_OPR_IMPL(WarpPerspectiveForward, WarpPerspectiveBase, 0, 1);
  33. public:
  34. /**
  35. * \param[in] src (n, channel, in_height, in_width)
  36. * \param[in] mat (n, 3, 3)
  37. * \param[out] dst (n, channel, out_height, out_width)
  38. *
  39. * \see
  40. * http://docs.opencv.org/2.4/modules/imgproc/doc/geometric_transformations.html?highlight=warpaffine
  41. *
  42. * denominator = mat[2][0]*w+mat[2][1]*h+mat[2][2]
  43. * dst(h, w) = src((mat[1][0]*w+mat[1][1]*h+mat[1][2])/denominator,
  44. * (mat[0][0]*w+mat[0][1]*h+mat[0][2])/denominator)
  45. *
  46. * src and dst can have different shapes, as long as their n and c agree.
  47. * src, mat and dst should be contiguous.
  48. */
  49. void exec(
  50. _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_out dst,
  51. _megdnn_workspace workspace) {
  52. exec(src, mat, {}, dst, workspace);
  53. }
  54. /**
  55. * \p src should have batch size m, and \p mat and \p mat_idx should
  56. * both have batch size n. Each item in \p mat_idx must be in the range
  57. * of [0, m-1].
  58. *
  59. * \param mat_idx the indices of input image that each matrix in \p mat
  60. * should act on. It can also be empty and in such case \p mat
  61. * should have the same batch size as \p src.
  62. */
  63. virtual void exec(
  64. _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx,
  65. _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
  66. size_t get_workspace_in_bytes(
  67. const TensorLayout& src, const TensorLayout& mat, const TensorLayout& dst) {
  68. return get_workspace_in_bytes(src, mat, {}, dst);
  69. }
  70. virtual size_t get_workspace_in_bytes(
  71. const TensorLayout& src, const TensorLayout& mat,
  72. const TensorLayout& mat_idx, const TensorLayout& dst) = 0;
  73. protected:
  74. void check_exec(
  75. const TensorLayout& src, const TensorLayout& mat,
  76. const TensorLayout& mat_idx, const TensorLayout& dst,
  77. size_t workspace_in_bytes);
  78. void check_exec_allow_nhwc_mat_idx(
  79. const TensorLayout& src, const TensorLayout& mat,
  80. const TensorLayout& mat_idx, const TensorLayout& dst,
  81. size_t workspace_in_bytes);
  82. };
  83. using WarpPerspective = WarpPerspectiveForward;
  84. class WarpPerspectiveBackwardData : public WarpPerspectiveBase {
  85. DEF_OPR_IMPL(WarpPerspectiveBackwardData, WarpPerspectiveBase, 2, 1);
  86. public:
  87. /**
  88. * \param[in] mat the `mat' parameter in WarpPerspectiveForward::exec
  89. * \param[in] diff the backpropagated gradient wrt. dst
  90. * \param[out] grad the backpropagated gradient wrt. src
  91. * \param[out] workspace temporary workspace to perform backward
  92. */
  93. void exec(
  94. _megdnn_tensor_in mat, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
  95. _megdnn_workspace workspace) {
  96. exec(mat, {}, diff, grad, workspace);
  97. }
  98. virtual void exec(
  99. _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx, _megdnn_tensor_in diff,
  100. _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
  101. size_t get_workspace_in_bytes(
  102. const TensorLayout& mat, const TensorLayout& diff,
  103. const TensorLayout& grad) {
  104. return get_workspace_in_bytes(mat, {}, diff, grad);
  105. }
  106. virtual size_t get_workspace_in_bytes(
  107. const TensorLayout& mat, const TensorLayout& mat_idx,
  108. const TensorLayout& diff, const TensorLayout& grad) = 0;
  109. protected:
  110. void check_exec(
  111. const TensorLayout& mat, const TensorLayout& mat_idx,
  112. const TensorLayout& diff, const TensorLayout& grad,
  113. size_t workspace_in_bytes);
  114. };
  115. class WarpPerspectiveBackwardMat : public WarpPerspectiveBase {
  116. DEF_OPR_IMPL(WarpPerspectiveBackwardMat, WarpPerspectiveBase, 3, 1);
  117. public:
  118. /**
  119. * \param[in] src the `src' parameter in WarpPerspectiveForward::exec
  120. * \param[in] mat the `mat' parameter in WarpPerspectiveForward::exec
  121. * \param[in] diff the backpropagated gradient wrt. dst
  122. * \param[out] grad the backpropagated gradient wrt. mat
  123. * \param[out] workspace temporary workspace to perform backward
  124. */
  125. void exec(
  126. _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in diff,
  127. _megdnn_tensor_out grad, _megdnn_workspace workspace) {
  128. exec(src, mat, {}, diff, grad, workspace);
  129. }
  130. virtual void exec(
  131. _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx,
  132. _megdnn_tensor_in diff, _megdnn_tensor_out grad,
  133. _megdnn_workspace workspace) = 0;
  134. size_t get_workspace_in_bytes(
  135. const TensorLayout& src, const TensorLayout& mat, const TensorLayout& diff,
  136. const TensorLayout& grad) {
  137. return get_workspace_in_bytes(src, mat, {}, diff, grad);
  138. }
  139. virtual size_t get_workspace_in_bytes(
  140. const TensorLayout& src, const TensorLayout& mat,
  141. const TensorLayout& mat_idx, const TensorLayout& diff,
  142. const TensorLayout& grad) = 0;
  143. protected:
  144. void check_exec(
  145. const TensorLayout& src, const TensorLayout& mat,
  146. const TensorLayout& mat_idx, const TensorLayout& diff,
  147. const TensorLayout& grad, size_t workspace_in_bytes);
  148. };
  149. class DctChannelSelectForward : public OperatorBase {
  150. DEF_OPR_PARAM(DctChannelSelect);
  151. DEF_OPR_IMPL(DctChannelSelectForward, OperatorBase, 3, 1);
  152. public:
  153. /**
  154. * \param[in] DctChannelSelectForward input, must be uint8 nchw tensor
  155. * \param[in] mask_offset input, must be int32 nchw tensor
  156. * \param[in] mask_val input, must be int32 nchw tensor
  157. * \param[dst] DctChannelSelectForward output, default fp32 nchw tensor
  158. * \param[out] workspace temporary workspace to perform forward
  159. */
  160. virtual void exec(
  161. _megdnn_tensor_in src, _megdnn_tensor_in mask_offset,
  162. _megdnn_tensor_in mask_val, _megdnn_tensor_out dst,
  163. _megdnn_workspace workspace) = 0;
  164. void deduce_layout(
  165. const TensorLayout& src, const TensorLayout& mask_offset,
  166. const TensorLayout& mask_val, TensorLayout& dst);
  167. virtual size_t get_workspace_in_bytes(
  168. const TensorLayout& src, const TensorLayout& mask_offset,
  169. const TensorLayout& mask_val, const TensorLayout& dst) = 0;
  170. protected:
  171. void check_layout_fwd(
  172. const TensorLayout& src, const TensorLayout& mask_offset,
  173. const TensorLayout& mask_val, const TensorLayout& dst);
  174. void deduce_layout_fwd(
  175. const TensorLayout& src, const TensorLayout& mask_offset,
  176. const TensorLayout& mask_val, TensorLayout& dst);
  177. std::string param_msg() const;
  178. };
  179. } // namespace megdnn
  180. #include "megdnn/internal/opr_header_epilogue.h"
  181. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。