
linalg.h 8.2 kB

/**
 * \file dnn/include/megdnn/oprs/linalg.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once
#include "megdnn/internal/opr_header_prologue.h"

namespace megdnn {

class BatchedMatrixMulForward
        : public OperatorBase,
          public detail::MultiAlgoOpr<BatchedMatrixMulForward, 3> {
    DEF_OPR_PARAM(MatrixMul);
    DEF_OPR_IMPL(BatchedMatrixMulForward, OperatorBase, 2, 1);

public:
    /**
     * \brief C = op(A) * op(B)
     * \param A (B, m, k) if transposeA is false, (B, k, m) otherwise
     * \param B (B, k, n) if transposeB is false, (B, n, k) otherwise
     * \param C (B, m, n)
     *
     * A, B, C must be 3-dimensional and C must be contiguous. A and B must
     * have stride[2] == 1, and stride[1] >= shape[2],
     * and stride[0] >= shape[1] * stride[1]
     *
     * op(A) = A if transposeA is false, otherwise op(A) = A^t.
     * op(B) = B if transposeB is false, otherwise op(B) = B^t.
     */
    virtual void exec(_megdnn_tensor_in A, _megdnn_tensor_in B,
                      _megdnn_tensor_out C, _megdnn_workspace workspace) = 0;
    void deduce_dtype(DType A, DType B, DType& C);
    void deduce_layout(const TensorLayout& A, const TensorLayout& B,
                       TensorLayout& C);
    virtual size_t get_workspace_in_bytes(const TensorLayout& A,
                                          const TensorLayout& B,
                                          const TensorLayout& C) = 0;

protected:
    void check_exec(const TensorLayout& A, const TensorLayout& B,
                    const TensorLayout& C, size_t workspace_in_bytes);
};
using BatchedMatrixMul = BatchedMatrixMulForward;
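// Illustrative shape example (not part of the original header): with
// transposeA = true and transposeB = false, an A of shape (B, k, m) and a B of
// shape (B, k, n) yield a C of shape (B, m, n), i.e. C[i] = A[i]^t * B[i] for
// each batch index i.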
class MatrixMulForward : public OperatorBase,
                         public detail::MultiAlgoOpr<MatrixMulForward, 3> {
    DEF_OPR_PARAM(MatrixMul);
    DEF_OPR_IMPL(MatrixMulForward, OperatorBase, 2, 1);

public:
    /**
     * \brief C = op(A) * op(B)
     * \param A (m, k) if transposeA is false, (k, m) otherwise
     * \param B (k, n) if transposeB is false, (n, k) otherwise
     * \param C (m, n)
     *
     * A, B, C must be 2-dimensional and C must be contiguous. A and B must
     * have stride[1] == 1, and stride[0] >= shape[1]
     *
     * op(A) = A if transposeA is false, otherwise op(A) = A^t.
     * op(B) = B if transposeB is false, otherwise op(B) = B^t.
     */
    virtual void exec(_megdnn_tensor_in A, _megdnn_tensor_in B,
                      _megdnn_tensor_out C, _megdnn_workspace workspace) = 0;
    void deduce_dtype(DType A, DType B, DType& C);
    void deduce_layout(const TensorLayout& A, const TensorLayout& B,
                       TensorLayout& C);
    virtual size_t get_workspace_in_bytes(const TensorLayout& A,
                                          const TensorLayout& B,
                                          const TensorLayout& C) = 0;
    static size_t pack_size(const Param::Format format);

protected:
    void check_exec(const TensorLayout& A, const TensorLayout& B,
                    const TensorLayout& C, size_t workspace_in_bytes);
};
using MatrixMul = MatrixMulForward;
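// Illustrative layout example (not part of the original header): an A of shape
// (4, 3) with stride (8, 1) is accepted, since stride[1] == 1 and
// stride[0] >= shape[1] (each row is padded to 8 elements); C, by contrast,
// must be fully contiguous.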
/*!
 * \brief compute the inverse of a batch of matrices
 *
 * Input and output tensors have the same shape [..., n, n] where the last two
 * dimensions represent the matrices.
 *
 * Currently only float32 is supported.
 */
class MatrixInverse : public OperatorBase {
    DEF_OPR_IMPL(MatrixInverse, OperatorBase, 1, 1);
    DEF_OPR_PARAM(Empty);

public:
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
                      _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
    size_t get_workspace_in_bytes(const TensorLayout& src,
                                  const TensorLayout& dst);

protected:
    /*!
     * \brief get canonized params; throw exception on error.
     *
     * Note that \p batch and \p n can be null
     */
    static void canonize_params(const TensorLayout& layout, size_t* batch,
                                size_t* n);

    /*!
     * \brief canonize and validate input params for exec() impls
     *
     * Since get_workspace_in_bytes() would be called, \p batch and \p n can not
     * be null
     */
    void check_exec(const TensorLayout& src, const TensorLayout& dst,
                    _megdnn_workspace workspace, size_t* batch, size_t* n);

    virtual size_t get_workspace_in_bytes(size_t batch, size_t n,
                                          size_t dtype_size) = 0;
};
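// Illustrative shape example (not part of the original header): a src of shape
// (5, 3, 4, 4) holds 5 * 3 = 15 matrices of size 4 x 4 (presumably canonized
// to batch = 15, n = 4), and dst has the same shape as src.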
//! inner product of two vectors
class DotForward : public OperatorBase {
    DEF_OPR_PARAM(Empty);
    DEF_OPR_IMPL(DotForward, OperatorBase, 2, 1);

public:
    /**
     * \param[in] A
     * \param[in] B
     * \param[out] C
     *
     * Calculates the dot product of A and B and stores it in C.
     * A, B, C must be contiguous. A and B must have the same 1-dimensional
     * shape and non-negative strides. C must be scalar.
     */
    virtual void exec(_megdnn_tensor_in A, _megdnn_tensor_in B,
                      _megdnn_tensor_out C, _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& A, const TensorLayout& B,
                       TensorLayout& C);
    virtual size_t get_workspace_in_bytes(const TensorLayout& A,
                                          const TensorLayout& B,
                                          const TensorLayout& C) = 0;

protected:
    void check_exec(const TensorLayout& A, const TensorLayout& B,
                    const TensorLayout& C, size_t workspace_in_bytes);
};
using Dot = DotForward;
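// Illustrative example (not part of the original header): A and B of shape
// (n,) produce a scalar C (a single-element tensor) holding sum_i A[i] * B[i].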
/*!
 * \brief Compute the singular value decomposition of a batch of matrices
 *
 * Input tensors have the shape [..., m, n], where the last two
 * dimensions represent the matrices. For the output tensors u, s, vt,
 * the following equation holds: u * diag(s) * vt == src.
 *
 * Currently only float32 is supported.
 */
class SVDForward : public OperatorBase {
    DEF_OPR_IMPL(SVDForward, OperatorBase, 1, 3);
    DEF_OPR_PARAM(SVD);

public:
    /**
     * \brief u, s, vt = SVD(src) and u * diag(s) * vt == src
     * \param src (..., m, n) The input tensor, let p = min(m, n)
     * \param u (..., m, p) if full_matrices is false,
     *          (..., m, m) if full_matrices is true,
     *          an empty tensor if compute_uv is false.
     *          The left singular vectors.
     * \param s (..., p) The singular values.
     * \param vt (..., p, n) if full_matrices is false,
     *           (..., n, n) if full_matrices is true,
     *           an empty tensor if compute_uv is false.
     *           The right singular vectors.
     *
     * src must be contiguous. The computation might be significantly faster
     * if compute_uv is false (defaults to true).
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out u,
                      _megdnn_tensor_out s, _megdnn_tensor_out vt,
                      _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& src, TensorLayout& u,
                       TensorLayout& s, TensorLayout& vt);
    size_t get_workspace_in_bytes(const TensorLayout& src,
                                  const TensorLayout& u, const TensorLayout& s,
                                  const TensorLayout& vt);

protected:
    static void canonize_params(const TensorLayout& layout, size_t* batch,
                                size_t* m, size_t* n);
    virtual size_t get_workspace_in_bytes(size_t block_cnt, size_t m, size_t n,
                                          size_t dtype_size) = 0;
    void check_exec(const TensorLayout& src, const TensorLayout& u,
                    const TensorLayout& s, const TensorLayout& vt,
                    size_t workspace_in_bytes);
};
using SVD = SVDForward;
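// Illustrative shape example (not part of the original header): for src of
// shape (2, 6, 4), p = min(6, 4) = 4; with full_matrices = false and
// compute_uv = true, deduce_layout yields u of shape (2, 6, 4), s of shape
// (2, 4) and vt of shape (2, 4, 4), so that u * diag(s) * vt reconstructs src.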
}  // namespace megdnn

#include "megdnn/internal/opr_header_epilogue.h"

// vim: syntax=cpp.doxygen
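The header above only declares the operator interfaces; in megdnn, operators are created through a megdnn::Handle and driven with the deduce_layout → get_workspace_in_bytes → exec sequence. The sketch below shows that sequence for MatrixMul under stated assumptions: a valid Handle*, already-allocated tensors A, B, C, and a pre-allocated workspace are available; Handle and create_operator come from the wider megdnn headers (not this file), and the helper name run_matmul is purely illustrative.

// Minimal usage sketch (assumptions as described above; not part of linalg.h).
#include <cassert>
#include "megdnn/handle.h"
#include "megdnn/oprs.h"

void run_matmul(megdnn::Handle* handle, const megdnn::TensorND& A,
                const megdnn::TensorND& B, const megdnn::TensorND& C,
                megdnn::dt_byte* workspace_ptr, size_t workspace_size) {
    // Create a MatrixMul operator bound to the given compute handle.
    auto opr = handle->create_operator<megdnn::MatrixMul>();
    opr->param().transposeA = false;
    opr->param().transposeB = false;

    // Deduce C's layout from the input layouts, then query how much
    // workspace this (A, B, C) combination needs on this backend.
    megdnn::TensorLayout layout_c;
    opr->deduce_layout(A.layout, B.layout, layout_c);
    size_t required = opr->get_workspace_in_bytes(A.layout, B.layout, layout_c);
    assert(required <= workspace_size);

    // Run the kernel; C must already be allocated with a contiguous (m, n) layout.
    opr->exec(A, B, C, megdnn::Workspace{workspace_ptr, workspace_size});
}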

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build to choose from. If you want to run GPU programs, make sure the machine has GPU hardware and that the driver is installed. If you would like to try deep learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.