You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

imgproc.h 6.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. /**
  2. * \file dnn/include/megdnn/oprs/imgproc.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
  12. #include "megdnn/internal/opr_header_prologue.h"
  13. namespace megdnn {
  /**
  * \brief Common base of the warp-perspective operator classes declared in
  * this header (forward, backward-data and backward-mat).
  *
  * Re-exports the InterpolationMode and BorderMode types of the operator
  * Param and declares layout-checking helpers for the derived operators.
  * The constructor and the Param plumbing are generated by the
  * DEF_OPR_IMPL_CTOR / DEF_OPR_PARAM macros, which are defined outside this
  * file.
  */
  14. class WarpPerspectiveBase: public OperatorBase {
  15. DEF_OPR_IMPL_CTOR(WarpPerspectiveBase, OperatorBase);
  16. DEF_OPR_PARAM(WarpPerspective);
  17. public:
  // Short aliases for the nested types of Param.
  18. using InterpolationMode = Param::InterpolationMode;
  19. using BorderMode = Param::BorderMode;
  20. protected:
  /**
  * Convenience overload for callers that have no mat_idx tensor: forwards
  * to the four-layout overload, passing an empty-brace-initialized
  * TensorLayout as mat_idx.
  */
  21. void check_layout_fwd(const TensorLayout &src, const TensorLayout &mat,
  22. const TensorLayout &dst) {
  23. check_layout_fwd(src, mat, {}, dst);
  24. }
  /**
  * Layout check taking src, mat, mat_idx and dst; only declared here, the
  * definition lives elsewhere.
  * NOTE(review): presumably validates shapes/dtypes of the forward
  * operands -- confirm against the implementation.
  */
  25. void check_layout_fwd(const TensorLayout &src, const TensorLayout &mat,
  26. const TensorLayout &mat_idx, const TensorLayout &dst);
  // Returns a std::string; presumably a printable description of the current
  // param for diagnostics -- TODO confirm against the implementation.
  27. std::string param_msg() const;
  // NOTE(review): presumably maps coordinate p onto an axis of length len
  // according to the border mode -- not visible from this header, confirm.
  28. int get_real_coord(int p, int len);
  29. };
  /**
  * \brief Forward warp-perspective operator interface.
  *
  * Concrete backends implement the two pure virtual members (the
  * five-argument exec and the four-layout get_workspace_in_bytes); the
  * shorter overloads are non-virtual wrappers that pass an empty mat_idx.
  * The trailing "0, 1" arguments are interpreted by the DEF_OPR_IMPL macro,
  * which is defined outside this file.
  */
  30. class WarpPerspectiveForward: public WarpPerspectiveBase {
  31. DEF_OPR_IMPL(WarpPerspectiveForward, WarpPerspectiveBase, 0, 1);
  32. public:
  33. /**
  34. * \param[in] src (n, channel, in_height, in_width)
  35. * \param[in] mat (n, 3, 3)
  36. * \param[out] dst (n, channel, out_height, out_width)
  37. *
  38. * \see http://docs.opencv.org/2.4/modules/imgproc/doc/geometric_transformations.html?highlight=warpaffine
  39. *
  40. * denominator = mat[2][0]*w+mat[2][1]*h+mat[2][2]
  41. * dst(h, w) = src((mat[1][0]*w+mat[1][1]*h+mat[1][2])/denominator,
  42. * (mat[0][0]*w+mat[0][1]*h+mat[0][2])/denominator)
  43. *
  44. * src and dst can have different shapes, as long as their n and c agree.
  45. * src, mat and dst should be contiguous.
  *
  * This overload is a non-virtual convenience wrapper: it calls the
  * five-argument exec with an empty-brace-initialized mat_idx.
  * NOTE(review): a subclass that overrides the virtual exec hides this
  * overload unless it adds a using-declaration.
  46. */
  47. void exec(_megdnn_tensor_in src,
  48. _megdnn_tensor_in mat,
  49. _megdnn_tensor_out dst,
  50. _megdnn_workspace workspace) {
  51. exec(src, mat, {}, dst, workspace);
  52. }
  53. /**
  54. * \p src should have batch size m, and \p mat and \p mat_idx should
  55. * both have batch size n. Each item in \p mat_idx must be in the range
  56. * of [0, m-1].
  57. *
  58. * \param mat_idx the indices of input image that each matrix in \p mat
  59. * should act on. It can also be empty and in such case \p mat
  60. * should have the same batch size as \p src.
  61. */
  62. virtual void exec(_megdnn_tensor_in src,
  63. _megdnn_tensor_in mat,
  64. _megdnn_tensor_in mat_idx,
  65. _megdnn_tensor_out dst,
  66. _megdnn_workspace workspace) = 0;
  /**
  * Workspace-size query for the call without mat_idx: forwards to the
  * four-layout overload with an empty-brace-initialized mat_idx layout.
  */
  67. size_t get_workspace_in_bytes(const TensorLayout &src,
  68. const TensorLayout &mat,
  69. const TensorLayout &dst) {
  70. return get_workspace_in_bytes(src, mat, {}, dst);
  71. }
  /**
  * Workspace-size query implemented by each backend; returns a size_t
  * byte count for the given src/mat/mat_idx/dst layouts.
  */
  72. virtual size_t get_workspace_in_bytes(const TensorLayout &src,
  73. const TensorLayout &mat,
  74. const TensorLayout &mat_idx,
  75. const TensorLayout &dst) = 0;
  76. protected:
  /**
  * Argument check for exec; only declared here.
  * NOTE(review): presumably validates the layouts and that
  * workspace_in_bytes is large enough -- confirm against the
  * implementation.
  */
  77. void check_exec(const TensorLayout &src,
  78. const TensorLayout &mat,
  79. const TensorLayout &mat_idx,
  80. const TensorLayout &dst,
  81. size_t workspace_in_bytes);
  /**
  * Variant of check_exec with the same parameter list; only declared here.
  * NOTE(review): judging by the name it relaxes the check to accept NHWC
  * input together with mat_idx -- confirm against the implementation.
  */
  82. void check_exec_allow_nhwc_mat_idx(const TensorLayout &src,
  83. const TensorLayout &mat,
  84. const TensorLayout &mat_idx,
  85. const TensorLayout &dst,
  86. size_t workspace_in_bytes);
  87. };
  88. using WarpPerspective = WarpPerspectiveForward; // short name for the forward operator
  /**
  * \brief Backward operator interface that produces the gradient with
  * respect to src (see the exec parameter documentation in this class).
  *
  * Both public members are pure virtual and are implemented by concrete
  * backends. The trailing "2, 1" arguments are interpreted by the
  * DEF_OPR_IMPL macro, which is defined outside this file.
  */
  89. class WarpPerspectiveBackwardData: public WarpPerspectiveBase {
  90. DEF_OPR_IMPL(WarpPerspectiveBackwardData, WarpPerspectiveBase, 2, 1);
  91. public:
  92. /**
  93. * \param[in] mat the `mat' parameter in WarpPerspectiveForward::exec
  94. * \param[in] diff the backpropagated gradient wrt. dst
  95. * \param[out] grad the backpropagated gradient wrt. src
  96. * \param[out] workspace temporary workspace to perform backward
  97. */
  98. virtual void exec(_megdnn_tensor_in mat,
  99. _megdnn_tensor_in diff,
  100. _megdnn_tensor_out grad,
  101. _megdnn_workspace workspace) = 0;
  /**
  * Workspace-size query implemented by each backend; returns a size_t
  * byte count for the given mat/diff/grad layouts.
  */
  102. virtual size_t get_workspace_in_bytes(const TensorLayout &mat,
  103. const TensorLayout &diff,
  104. const TensorLayout &grad) = 0;
  105. protected:
  /**
  * Argument check for exec; only declared here.
  * NOTE(review): presumably validates the layouts and that
  * workspace_in_bytes is large enough -- confirm against the
  * implementation.
  */
  106. void check_exec(const TensorLayout &mat,
  107. const TensorLayout &diff,
  108. const TensorLayout &grad,
  109. size_t workspace_in_bytes);
  110. };
  /**
  * \brief Backward operator interface that produces the gradient with
  * respect to mat (see the exec parameter documentation in this class).
  *
  * Both public members are pure virtual and are implemented by concrete
  * backends. The trailing "3, 1" arguments are interpreted by the
  * DEF_OPR_IMPL macro, which is defined outside this file.
  */
  111. class WarpPerspectiveBackwardMat: public WarpPerspectiveBase {
  112. DEF_OPR_IMPL(WarpPerspectiveBackwardMat, WarpPerspectiveBase, 3, 1);
  113. public:
  114. /**
  115. * \param[in] src the `src' parameter in WarpPerspectiveForward::exec
  116. * \param[in] mat the `mat' parameter in WarpPerspectiveForward::exec
  117. * \param[in] diff the backpropagated gradient wrt. dst
  118. * \param[out] grad the backpropagated gradient wrt. mat
  119. * \param[out] workspace temporary workspace to perform backward
  120. */
  121. virtual void exec(_megdnn_tensor_in src,
  122. _megdnn_tensor_in mat,
  123. _megdnn_tensor_in diff,
  124. _megdnn_tensor_out grad,
  125. _megdnn_workspace workspace) = 0;
  /**
  * Workspace-size query implemented by each backend; returns a size_t
  * byte count for the given src/mat/diff/grad layouts.
  */
  126. virtual size_t get_workspace_in_bytes(const TensorLayout &src,
  127. const TensorLayout &mat,
  128. const TensorLayout &diff,
  129. const TensorLayout &grad) = 0;
  130. protected:
  /**
  * Argument check for exec; only declared here.
  * NOTE(review): presumably validates the layouts and that
  * workspace_in_bytes is large enough -- confirm against the
  * implementation.
  */
  131. void check_exec(const TensorLayout &src,
  132. const TensorLayout &mat,
  133. const TensorLayout &diff,
  134. const TensorLayout &grad,
  135. size_t workspace_in_bytes);
  136. };
  137. } // namespace megdnn
  138. #include "megdnn/internal/opr_header_epilogue.h"
  139. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台

Contributors (1)