You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

linalg.h 7.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. #pragma once
  2. #include "megdnn/internal/opr_header_prologue.h"
  3. namespace megdnn {
  4. class BatchedMatrixMulForward
  5. : public OperatorBase,
  6. public detail::MultiAlgoOpr<BatchedMatrixMulForward, 3> {
  7. DEF_OPR_PARAM(MatrixMul);
  8. DEF_OPR_IMPL(BatchedMatrixMulForward, OperatorBase, 2, 1);
  9. public:
  10. /**
  11. * \brief C = op(A) * op(B)
  12. * \param A (B, m, k) if transposeA is false, (B, k, m) otherwise
  13. * \param B (B, k, n) if transposeB is false, (B, n, k) otherwise
  14. * \param C (B, m, n)
  15. *
  16. * A, B, C must be 3-dimensional and C must be contiguous. A and B must
  17. * have stride[2] == 1, and stride[1] >= shape[2],
  18. * and stride[0] >= shape[1] * stride[1]
  19. *
  20. * op(A) = A if transposeA is false, otherwise op(A) = A^t.
  21. * op(B) = B if transposeB is false, otherwise op(B) = B^t.
  22. */
  23. virtual void exec(
  24. _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
  25. _megdnn_workspace workspace) = 0;
  26. MGE_WIN_DECLSPEC_FUC void deduce_dtype(DType A, DType B, DType& C);
  27. void deduce_layout(const TensorLayout& A, const TensorLayout& B, TensorLayout& C);
  28. virtual size_t get_workspace_in_bytes(
  29. const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0;
  30. static Algorithm::OprType get_opr_type() {
  31. return Algorithm::OprType::BATCHED_MATRIX_MUL_FORWARD;
  32. }
  33. protected:
  34. void check_exec(
  35. const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
  36. size_t workspace_in_bytes);
  37. };
  38. using BatchedMatrixMul = BatchedMatrixMulForward;
  39. class MatrixMulForward : public OperatorBase,
  40. public detail::MultiAlgoOpr<MatrixMulForward, 3> {
  41. DEF_OPR_PARAM(MatrixMul);
  42. DEF_OPR_IMPL(MatrixMulForward, OperatorBase, 2, 1);
  43. public:
  44. /**
  45. * \brief C = op(A) * op(B)
  46. * \param A (m, k) if transposeA is false, (k, m) otherwise
  47. * \param B (k, n) if transposeB is false, (n, k) otherwise
  48. * \param C (m, n)
  49. *
  50. * A, B, C must be 2-dimensional and C must be contiguous. A and B must
  51. * have stride[1] == 1, and stride[0] >= shape[1]
  52. *
  53. * op(A) = A if transposeA is false, otherwise op(A) = A^t.
  54. * op(B) = B if transposeB is false, otherwise op(B) = B^t.
  55. */
  56. virtual void exec(
  57. _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
  58. _megdnn_workspace workspace) = 0;
  59. MGE_WIN_DECLSPEC_FUC void deduce_dtype(DType A, DType B, DType& C);
  60. void deduce_layout(const TensorLayout& A, const TensorLayout& B, TensorLayout& C);
  61. virtual size_t get_workspace_in_bytes(
  62. const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0;
  63. static size_t pack_size(const Param::Format format);
  64. static Algorithm::OprType get_opr_type() {
  65. return Algorithm::OprType::MATRIX_MUL_FORWARD;
  66. }
  67. protected:
  68. void check_exec(
  69. const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
  70. size_t workspace_in_bytes);
  71. };
  72. using MatrixMul = MatrixMulForward;
  73. /*!
  74. * \brief compute the inverse of a batch of matrices
  75. *
  76. * Input and output tensors have the same shape [..., n, n] where the last two
  77. * dimensions represent the matrices.
  78. *
  79. * Currently only float32 is supported.
  80. */
  81. class MatrixInverse : public OperatorBase {
  82. DEF_OPR_IMPL(MatrixInverse, OperatorBase, 1, 1);
  83. DEF_OPR_PARAM(Empty);
  84. public:
  85. virtual void exec(
  86. _megdnn_tensor_in src, _megdnn_tensor_out dst,
  87. _megdnn_workspace workspace) = 0;
  88. void deduce_layout(const TensorLayout& src, TensorLayout& dst);
  89. size_t get_workspace_in_bytes(const TensorLayout& src, const TensorLayout& dst);
  90. protected:
  91. /*!
  92. * \brief get canonized params; throw exception on error.
  93. *
  94. * Note that \p batch and \p n can be null
  95. */
  96. static void canonize_params(const TensorLayout& layout, size_t* batch, size_t* n);
  97. /*!
  98. * \brief canonize and validate input params for exec() impls
  99. *
  100. * Since get_workspace_in_bytes() would be called, \p batch and \p n can not
  101. * be null
  102. */
  103. void check_exec(
  104. const TensorLayout& src, const TensorLayout& dst,
  105. _megdnn_workspace workspace, size_t* batch, size_t* n);
  106. virtual size_t get_workspace_in_bytes(
  107. size_t batch, size_t n, size_t dtype_size) = 0;
  108. };
  109. //! inter-product of two vectors
  110. class DotForward : public OperatorBase {
  111. DEF_OPR_PARAM(Empty);
  112. DEF_OPR_IMPL(DotForward, OperatorBase, 2, 1);
  113. public:
  114. /**
  115. * \param[in] A
  116. * \param[in] B
  117. * \param[out] C
  118. *
  119. * Calculating the dot product of A and B and store it in C.
  120. * A, B, C must be contiguous. A and B must have the same 1-dimensional
  121. * shape and non-negative strides. C must be scalar.
  122. */
  123. virtual void exec(
  124. _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
  125. _megdnn_workspace workspace) = 0;
  126. MGE_WIN_DECLSPEC_FUC void deduce_layout(
  127. const TensorLayout& A, const TensorLayout& B, TensorLayout& C);
  128. virtual size_t get_workspace_in_bytes(
  129. const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0;
  130. protected:
  131. void check_exec(
  132. const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
  133. size_t workspace_in_bytes);
  134. };
  135. using Dot = DotForward;
  136. /*!
  137. * \brief Compute the singular value decomposition of a batch of matrices
  138. *
  139. * Input tensors have the shape [..., m, n], where the last two
  140. * dimensions represent the matrices. For the output tensor u, s, vt,
  141. * the following equation holds: u * diag(s) * vt == src.
  142. *
  143. * Currently only float32 is supported.
  144. */
  145. class SVDForward : public OperatorBase {
  146. DEF_OPR_IMPL(SVDForward, OperatorBase, 1, 3);
  147. DEF_OPR_PARAM(SVD);
  148. public:
  149. /**
  150. * \brief u, s, vt = SVD(src) and u * diag(s) * vt == src
  151. * \param src (..., m, n) The input tensor, let p = min(m, n)
  152. * \param u (..., m, p) if full_matrices is false,
  153. (..., m, m) if full_matrices is true,
  154. empty tensor if compute_uv is false.
  155. The left singular vector.
  156. * \param s (..., p) The singular values.
  157. * \param vt (..., p, n) if full_matrices is false,
  158. (..., n, n) if full_matrices is true,
  159. empty tensor if compute_uv is false.
  160. The right singular vector.
  161. *
  162. * src must be contiguous. The computation might be significantly faster
  163. * if compute_uv is false (default to true).
  164. *
  165. */
  166. virtual void exec(
  167. _megdnn_tensor_in src, _megdnn_tensor_out u, _megdnn_tensor_out s,
  168. _megdnn_tensor_out vt, _megdnn_workspace workspace) = 0;
  169. void deduce_layout(
  170. const TensorLayout& src, TensorLayout& u, TensorLayout& s,
  171. TensorLayout& vt);
  172. size_t get_workspace_in_bytes(
  173. const TensorLayout& src, const TensorLayout& u, const TensorLayout& s,
  174. const TensorLayout& vt);
  175. protected:
  176. static void canonize_params(
  177. const TensorLayout& layout, size_t* batch, size_t* m, size_t* n);
  178. virtual size_t get_workspace_in_bytes(
  179. size_t block_cnt, size_t m, size_t n, size_t dtype_size) = 0;
  180. void check_exec(
  181. const TensorLayout& src, const TensorLayout& u, const TensorLayout& s,
  182. const TensorLayout& vt, size_t workspace_in_bytes);
  183. };
  184. using SVD = SVDForward;
  185. } // namespace megdnn
  186. #include "megdnn/internal/opr_header_epilogue.h"
  187. // vim: syntax=cpp.doxygen