You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

linalg.h 7.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222
  1. /**
  2. * \file dnn/include/megdnn/oprs/linalg.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
  12. #include "megdnn/internal/opr_header_prologue.h"
  13. namespace megdnn {
//! batched matrix multiplication: for each batch index i, C[i] = op(A[i]) * op(B[i])
class BatchedMatrixMulForward
        : public OperatorBase,
          public detail::MultiAlgoOpr<BatchedMatrixMulForward, 3> {
    DEF_OPR_PARAM(MatrixMul);
    DEF_OPR_IMPL(BatchedMatrixMulForward, OperatorBase, 2, 1);

public:
    /**
     * \brief C = op(A) * op(B)
     * \param A (B, m, k) if transposeA is false, (B, k, m) otherwise
     * \param B (B, k, n) if transposeB is false, (B, n, k) otherwise
     * \param C (B, m, n)
     *
     * A, B, C must be 3-dimensional and C must be contiguous. A and B must
     * have stride[2] == 1, and stride[1] >= shape[2],
     * and stride[0] >= shape[1] * stride[1]
     *
     * op(A) = A if transposeA is false, otherwise op(A) = A^t.
     * op(B) = B if transposeB is false, otherwise op(B) = B^t.
     */
    virtual void exec(
            _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
            _megdnn_workspace workspace) = 0;
    //! deduce the dtype of C from the dtypes of A and B
    void deduce_dtype(DType A, DType B, DType& C);
    //! deduce the layout of C from the layouts of A and B (and opr param)
    void deduce_layout(const TensorLayout& A, const TensorLayout& B, TensorLayout& C);
    //! size in bytes of the scratch workspace required by exec()
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0;

    //! operator type tag used by the MultiAlgoOpr algorithm-selection machinery
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::BATCHED_MATRIX_MUL_FORWARD;
    }

protected:
    //! validate input/output layouts and workspace size before exec()
    void check_exec(
            const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
            size_t workspace_in_bytes);
};
using BatchedMatrixMul = BatchedMatrixMulForward;
//! 2-D matrix multiplication: C = op(A) * op(B)
class MatrixMulForward : public OperatorBase,
                         public detail::MultiAlgoOpr<MatrixMulForward, 3> {
    DEF_OPR_PARAM(MatrixMul);
    DEF_OPR_IMPL(MatrixMulForward, OperatorBase, 2, 1);

public:
    /**
     * \brief C = op(A) * op(B)
     * \param A (m, k) if transposeA is false, (k, m) otherwise
     * \param B (k, n) if transposeB is false, (n, k) otherwise
     * \param C (m, n)
     *
     * A, B, C must be 2-dimensional and C must be contiguous. A and B must
     * have stride[1] == 1, and stride[0] >= shape[1]
     *
     * op(A) = A if transposeA is false, otherwise op(A) = A^t.
     * op(B) = B if transposeB is false, otherwise op(B) = B^t.
     */
    virtual void exec(
            _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
            _megdnn_workspace workspace) = 0;
    //! deduce the dtype of C from the dtypes of A and B
    void deduce_dtype(DType A, DType B, DType& C);
    //! deduce the layout of C from the layouts of A and B (and opr param)
    void deduce_layout(const TensorLayout& A, const TensorLayout& B, TensorLayout& C);
    //! size in bytes of the scratch workspace required by exec()
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0;
    //! pack size associated with \p format; exact semantics are defined by the
    //! implementation (presumably the element-packing factor of packed matmul
    //! formats — confirm against the .cpp)
    static size_t pack_size(const Param::Format format);

    //! operator type tag used by the MultiAlgoOpr algorithm-selection machinery
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::MATRIX_MUL_FORWARD;
    }

protected:
    //! validate input/output layouts and workspace size before exec()
    void check_exec(
            const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
            size_t workspace_in_bytes);
};
using MatrixMul = MatrixMulForward;
/*!
 * \brief compute the inverse of a batch of matrices
 *
 * Input and output tensors have the same shape [..., n, n] where the last two
 * dimensions represent the matrices.
 *
 * Currently only float32 is supported.
 */
class MatrixInverse : public OperatorBase {
    DEF_OPR_IMPL(MatrixInverse, OperatorBase, 1, 1);
    DEF_OPR_PARAM(Empty);

public:
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    //! dst gets the same layout as src (inverse preserves shape)
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
    //! layout-based workspace query; dispatches to the (batch, n) overload below
    size_t get_workspace_in_bytes(const TensorLayout& src, const TensorLayout& dst);

protected:
    /*!
     * \brief get canonized params; throw exception on error.
     *
     * Note that \p batch and \p n can be null
     */
    static void canonize_params(const TensorLayout& layout, size_t* batch, size_t* n);

    /*!
     * \brief canonize and validate input params for exec() impls
     *
     * Since get_workspace_in_bytes() would be called, \p batch and \p n can not
     * be null
     */
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst,
            _megdnn_workspace workspace, size_t* batch, size_t* n);

    //! backend-specific workspace requirement in terms of canonized sizes
    virtual size_t get_workspace_in_bytes(
            size_t batch, size_t n, size_t dtype_size) = 0;
};
//! inner product (dot product) of two vectors
class DotForward : public OperatorBase {
    DEF_OPR_PARAM(Empty);
    DEF_OPR_IMPL(DotForward, OperatorBase, 2, 1);

public:
    /**
     * \param[in] A
     * \param[in] B
     * \param[out] C
     *
     * Calculating the dot product of A and B and store it in C.
     * A, B, C must be contiguous. A and B must have the same 1-dimensional
     * shape and non-negative strides. C must be scalar.
     */
    virtual void exec(
            _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
            _megdnn_workspace workspace) = 0;
    //! deduce the (scalar) layout of C from A and B
    void deduce_layout(const TensorLayout& A, const TensorLayout& B, TensorLayout& C);
    //! size in bytes of the scratch workspace required by exec()
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0;

protected:
    //! validate input/output layouts and workspace size before exec()
    void check_exec(
            const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
            size_t workspace_in_bytes);
};
using Dot = DotForward;
/*!
 * \brief Compute the singular value decomposition of a batch of matrices
 *
 * Input tensors have the shape [..., m, n], where the last two
 * dimensions represent the matrices. For the output tensor u, s, vt,
 * the following equation holds: u * diag(s) * vt == src.
 *
 * Currently only float32 is supported.
 */
class SVDForward : public OperatorBase {
    DEF_OPR_IMPL(SVDForward, OperatorBase, 1, 3);
    DEF_OPR_PARAM(SVD);

public:
    /**
     * \brief u, s, vt = SVD(src) and u * diag(s) * vt == src
     * \param src (..., m, n) The input tensor, let p = min(m, n)
     * \param u (..., m, p) if full_matrices is false,
     *          (..., m, m) if full_matrices is true,
     *          empty tensor if compute_uv is false.
     *          The left singular vectors.
     * \param s (..., p) The singular values.
     * \param vt (..., p, n) if full_matrices is false,
     *           (..., n, n) if full_matrices is true,
     *           empty tensor if compute_uv is false.
     *           The right singular vectors.
     *
     * src must be contiguous. The computation might be significantly faster
     * if compute_uv is false (default to true).
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out u, _megdnn_tensor_out s,
            _megdnn_tensor_out vt, _megdnn_workspace workspace) = 0;
    //! deduce u/s/vt layouts from src layout and opr param
    void deduce_layout(
            const TensorLayout& src, TensorLayout& u, TensorLayout& s,
            TensorLayout& vt);
    //! layout-based workspace query; dispatches to the protected overload below
    size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& u, const TensorLayout& s,
            const TensorLayout& vt);

protected:
    //! extract canonized (batch, m, n) sizes from \p layout
    static void canonize_params(
            const TensorLayout& layout, size_t* batch, size_t* m, size_t* n);
    //! backend-specific workspace requirement in terms of canonized sizes
    virtual size_t get_workspace_in_bytes(
            size_t block_cnt, size_t m, size_t n, size_t dtype_size) = 0;
    //! validate input/output layouts and workspace size before exec()
    void check_exec(
            const TensorLayout& src, const TensorLayout& u, const TensorLayout& s,
            const TensorLayout& vt, size_t workspace_in_bytes);
};
using SVD = SVDForward;
  194. } // namespace megdnn
  195. #include "megdnn/internal/opr_header_epilogue.h"
  196. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台