
linalg.h

#pragma once
#include "megdnn/internal/opr_header_prologue.h"

namespace megdnn {

class BatchedMatrixMulForward
        : public OperatorBase,
          public detail::MultiAlgoOpr<BatchedMatrixMulForward, 3> {
    DEF_OPR_PARAM(MatrixMul);
    DEF_OPR_IMPL(BatchedMatrixMulForward, OperatorBase, 2, 1);

public:
    /**
     * \brief C = op(A) * op(B)
     * \param A (B, m, k) if transposeA is false, (B, k, m) otherwise
     * \param B (B, k, n) if transposeB is false, (B, n, k) otherwise
     * \param C (B, m, n)
     *
     * A, B, C must be 3-dimensional and C must be contiguous. A and B must
     * have stride[2] == 1, stride[1] >= shape[2], and
     * stride[0] >= shape[1] * stride[1].
     *
     * op(A) = A if transposeA is false, otherwise op(A) = A^t.
     * op(B) = B if transposeB is false, otherwise op(B) = B^t.
     *
     * (An illustrative usage sketch follows this class definition.)
     */
    virtual void exec(
            _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
            _megdnn_workspace workspace) = 0;
    MGE_WIN_DECLSPEC_FUC void deduce_dtype(DType A, DType B, DType& C);
    MGE_WIN_DECLSPEC_FUC void deduce_layout(
            const TensorLayout& A, const TensorLayout& B, TensorLayout& C);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0;

    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::BATCHED_MATRIX_MUL_FORWARD;
    }

protected:
    void check_exec(
            const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
            size_t workspace_in_bytes);
};
using BatchedMatrixMul = BatchedMatrixMulForward;
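
/*
 * Illustrative usage sketch (not part of the original header). It relies only
 * on the methods declared above; how the operator instance `opr` is obtained
 * (e.g. from a megdnn Handle) and how the tensors and workspace are allocated
 * are assumptions and are not shown here.
 *
 *     // A: (batch, m, k), B: (batch, k, n) with transposeA = transposeB = false
 *     TensorLayout A{{32, 64, 128}, dtype::Float32()};
 *     TensorLayout B{{32, 128, 256}, dtype::Float32()};
 *     TensorLayout C;
 *     opr->deduce_layout(A, B, C);                   // C becomes {32, 64, 256}
 *     size_t ws = opr->get_workspace_in_bytes(A, B, C);
 *     // allocate tensors tA, tB, tC and a workspace of at least `ws` bytes, then:
 *     // opr->exec(tA, tB, tC, workspace);
 */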

class MatrixMulForward : public OperatorBase,
                         public detail::MultiAlgoOpr<MatrixMulForward, 3> {
    DEF_OPR_PARAM(MatrixMul);
    DEF_OPR_IMPL(MatrixMulForward, OperatorBase, 2, 1);

public:
    /**
     * \brief C = op(A) * op(B)
     * \param A (m, k) if transposeA is false, (k, m) otherwise
     * \param B (k, n) if transposeB is false, (n, k) otherwise
     * \param C (m, n)
     *
     * A, B, C must be 2-dimensional and C must be contiguous. A and B must
     * have stride[1] == 1 and stride[0] >= shape[1].
     *
     * op(A) = A if transposeA is false, otherwise op(A) = A^t.
     * op(B) = B if transposeB is false, otherwise op(B) = B^t.
     *
     * (An illustrative usage sketch follows this class definition.)
     */
    virtual void exec(
            _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
            _megdnn_workspace workspace) = 0;
    MGE_WIN_DECLSPEC_FUC void deduce_dtype(DType A, DType B, DType& C);
    MGE_WIN_DECLSPEC_FUC void deduce_layout(
            const TensorLayout& A, const TensorLayout& B, TensorLayout& C);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0;

    static size_t pack_size(const Param::Format format);

    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::MATRIX_MUL_FORWARD;
    }

protected:
    void check_exec(
            const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
            size_t workspace_in_bytes);
};
using MatrixMul = MatrixMulForward;
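
/*
 * Illustrative usage sketch (not part of the original header), showing the
 * transposeA case. `opr` is assumed to be a MatrixMul instance created
 * elsewhere (e.g. via a megdnn Handle); tensor and workspace allocation is
 * not shown.
 *
 *     opr->param().transposeA = true;                // A is supplied as (k, m)
 *     TensorLayout A{{128, 64}, dtype::Float32()};   // op(A) is (64, 128)
 *     TensorLayout B{{128, 256}, dtype::Float32()};  // (k, n)
 *     TensorLayout C;
 *     opr->deduce_layout(A, B, C);                   // C becomes {64, 256}
 *     size_t ws = opr->get_workspace_in_bytes(A, B, C);
 *     // opr->exec(tA, tB, tC, workspace);           // with caller-allocated tensors
 */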

/*!
 * \brief compute the inverse of a batch of matrices
 *
 * Input and output tensors have the same shape [..., n, n], where the last two
 * dimensions represent the matrices.
 *
 * Currently only float32 is supported.
 *
 * (An illustrative usage sketch follows this class definition.)
 */
class MatrixInverse : public OperatorBase {
    DEF_OPR_IMPL(MatrixInverse, OperatorBase, 1, 1);
    DEF_OPR_PARAM(Empty);

public:
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
    size_t get_workspace_in_bytes(const TensorLayout& src, const TensorLayout& dst);

protected:
    /*!
     * \brief get canonized params; throw exception on error
     *
     * Note that \p batch and \p n can be null.
     */
    static void canonize_params(const TensorLayout& layout, size_t* batch, size_t* n);

    /*!
     * \brief canonize and validate input params for exec() impls
     *
     * Since get_workspace_in_bytes() would be called, \p batch and \p n cannot
     * be null.
     */
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst,
            _megdnn_workspace workspace, size_t* batch, size_t* n);

    virtual size_t get_workspace_in_bytes(
            size_t batch, size_t n, size_t dtype_size) = 0;
};
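
/*
 * Illustrative usage sketch (not part of the original header). `opr` is
 * assumed to be a MatrixInverse instance created elsewhere; allocation of the
 * actual tensors and workspace is not shown.
 *
 *     TensorLayout src{{10, 4, 4}, dtype::Float32()};    // a batch of ten 4x4 matrices
 *     TensorLayout dst;
 *     opr->deduce_layout(src, dst);                      // dst becomes {10, 4, 4}
 *     size_t ws = opr->get_workspace_in_bytes(src, dst);
 *     // opr->exec(tsrc, tdst, workspace);
 */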

//! inner product of two vectors
class DotForward : public OperatorBase {
    DEF_OPR_PARAM(Empty);
    DEF_OPR_IMPL(DotForward, OperatorBase, 2, 1);

public:
    /**
     * \param[in] A
     * \param[in] B
     * \param[out] C
     *
     * Computes the dot product of A and B and stores it in C.
     * A, B, C must be contiguous. A and B must have the same 1-dimensional
     * shape and non-negative strides. C must be a scalar.
     *
     * (An illustrative usage sketch follows this class definition.)
     */
    virtual void exec(
            _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
            _megdnn_workspace workspace) = 0;
    MGE_WIN_DECLSPEC_FUC void deduce_layout(
            const TensorLayout& A, const TensorLayout& B, TensorLayout& C);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0;

protected:
    void check_exec(
            const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
            size_t workspace_in_bytes);
};
using Dot = DotForward;
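
/*
 * Illustrative usage sketch (not part of the original header). `opr` is
 * assumed to be a Dot instance created elsewhere; both inputs are contiguous
 * 1-D vectors of equal length and the result is a scalar tensor.
 *
 *     TensorLayout A{{1000}, dtype::Float32()};
 *     TensorLayout B{{1000}, dtype::Float32()};
 *     TensorLayout C;
 *     opr->deduce_layout(A, B, C);                   // C becomes a scalar layout
 *     size_t ws = opr->get_workspace_in_bytes(A, B, C);
 *     // opr->exec(tA, tB, tC, workspace);
 */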

/*!
 * \brief Compute the singular value decomposition of a batch of matrices
 *
 * Input tensors have the shape [..., m, n], where the last two
 * dimensions represent the matrices. For the output tensors u, s, vt,
 * the equation u * diag(s) * vt == src holds.
 *
 * Currently only float32 is supported.
 */
class SVDForward : public OperatorBase {
    DEF_OPR_IMPL(SVDForward, OperatorBase, 1, 3);
    DEF_OPR_PARAM(SVD);

public:
    /**
     * \brief u, s, vt = SVD(src) and u * diag(s) * vt == src
     * \param src (..., m, n) the input tensor; let p = min(m, n)
     * \param u (..., m, p) if full_matrices is false,
     *          (..., m, m) if full_matrices is true,
     *          an empty tensor if compute_uv is false.
     *          The left singular vectors.
     * \param s (..., p) the singular values
     * \param vt (..., p, n) if full_matrices is false,
     *           (..., n, n) if full_matrices is true,
     *           an empty tensor if compute_uv is false.
     *           The right singular vectors.
     *
     * src must be contiguous. The computation may be significantly faster
     * if compute_uv is false (it defaults to true).
     *
     * (An illustrative usage sketch follows this class definition.)
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out u, _megdnn_tensor_out s,
            _megdnn_tensor_out vt, _megdnn_workspace workspace) = 0;
    void deduce_layout(
            const TensorLayout& src, TensorLayout& u, TensorLayout& s,
            TensorLayout& vt);
    size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& u, const TensorLayout& s,
            const TensorLayout& vt);

protected:
    static void canonize_params(
            const TensorLayout& layout, size_t* batch, size_t* m, size_t* n);
    virtual size_t get_workspace_in_bytes(
            size_t block_cnt, size_t m, size_t n, size_t dtype_size) = 0;
    void check_exec(
            const TensorLayout& src, const TensorLayout& u, const TensorLayout& s,
            const TensorLayout& vt, size_t workspace_in_bytes);
};
using SVD = SVDForward;
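
/*
 * Illustrative usage sketch (not part of the original header). `opr` is
 * assumed to be an SVD instance created elsewhere; the parameter field names
 * follow the doc comment above. With full_matrices = false, compute_uv = true
 * and src of shape (4, 6, 5) (so p = min(6, 5) = 5), the deduced shapes are
 * u: (4, 6, 5), s: (4, 5), vt: (4, 5, 5).
 *
 *     opr->param().full_matrices = false;
 *     opr->param().compute_uv = true;
 *     TensorLayout src{{4, 6, 5}, dtype::Float32()};
 *     TensorLayout u, s, vt;
 *     opr->deduce_layout(src, u, s, vt);
 *     size_t ws = opr->get_workspace_in_bytes(src, u, s, vt);
 *     // opr->exec(tsrc, tu, ts, tvt, workspace);
 */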

}  // namespace megdnn

#include "megdnn/internal/opr_header_epilogue.h"

// vim: syntax=cpp.doxygen