linalg_ops.h

/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*!
* \file linalg_ops.h
* \brief
*/
#ifndef OPS_BUILT_IN_OP_PROTO_INC_LINALG_OPS_H_
#define OPS_BUILT_IN_OP_PROTO_INC_LINALG_OPS_H_

#include "graph/operator_reg.h"
#include "graph/operator.h"

namespace ge {

/**
*@brief Computes the reverse mode backpropagated gradient of the Cholesky
algorithm . \n

*@par Inputs:
*The input x has to be symmetric and positive definite. Inputs include:
*@li x:A Tensor. Must be one of the following types: double, float32. Output
of batch Cholesky algorithm x = cholesky(A). Shape is [..., M, M]. Algorithm
depends only on the lower triangular part of the innermost matrices of this tensor.
*@li grad:A Tensor. Must have the same type as x. df/dx where f is some
scalar function. Shape is [..., M, M]. Algorithm depends only on the lower
triangular part of the innermost matrices of this tensor . \n

*@par Outputs:
*y:A Tensor. Has the same type as x . \n

*@attention Constraints:
*The input x is a tensor of shape [..., M, M] whose inner-most 2 dimensions
form square matrices.

*@par Third-party framework compatibility
*Compatible with the TensorFlow CholeskyGrad operator.
*/
REG_OP(CholeskyGrad)
    .INPUT(x, TensorType({DT_FLOAT, DT_DOUBLE}))
    .INPUT(grad, TensorType({DT_FLOAT, DT_DOUBLE}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_DOUBLE}))
    .OP_END_FACTORY_REG(CholeskyGrad)

/**
*@brief Computes the Cholesky decomposition of one or more square matrices . \n

*@par Inputs:
*The input x has to be symmetric and positive definite. Inputs include:
*x:A Tensor. Must be one of the following types: double, float32, float16,
complex64, complex128. Shape is [..., M, M] . \n

*@par Outputs:
*y:A Tensor. Has the same type as x . \n

*@attention Constraints:
*The input x is a tensor of shape [..., M, M] whose inner-most 2 dimensions
form square matrices.

*@par Third-party framework compatibility
*Compatible with the TensorFlow Cholesky operator.
*/
REG_OP(Cholesky)
    .INPUT(x, TensorType({DT_FLOAT, DT_DOUBLE, \
                          DT_FLOAT16, DT_COMPLEX64, DT_COMPLEX128}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_DOUBLE, \
                           DT_FLOAT16, DT_COMPLEX64, DT_COMPLEX128}))
    .OP_END_FACTORY_REG(Cholesky)
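
/* Usage sketch (illustrative, not part of the original header): REG_OP is
 * expected to generate an operator class in namespace ge::op whose setters
 * follow the usual set_input_<name>/set_attr_<name> naming convention; the
 * exact generated signatures and the upstream operator x_data are
 * assumptions here.
 *
 *   ge::op::Cholesky chol("cholesky_0");
 *   chol.set_input_x(x_data);  // x_data: e.g. a ge::op::Data feeding a
 *                              // symmetric positive-definite [..., M, M] tensor
 */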

/**
*@brief Computes the outer product of two 1D vectors . \n

*@par Inputs:
*The inputs x1 and x2 have to be 1D vectors. Inputs include:
*@li x1:A Tensor. Must be one of the following types: float16, float32.
Shape is [N] . \n
*@li x2:A Tensor. Must have the same type as x1. Shape is [M] . \n

*@par Outputs:
*y:A Tensor. Has the same type as x1. Shape is [N, M] . \n
*/
REG_OP(Ger)
    .INPUT(x1, TensorType({DT_FLOAT16, DT_FLOAT}))
    .INPUT(x2, TensorType({DT_FLOAT16, DT_FLOAT}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT}))
    .OP_END_FACTORY_REG(Ger)

/**
*@brief Computes the sign and the log of the absolute value of the determinant
of one or more square matrices . \n

*@par Inputs:
*The input x is a tensor of shape [N, M, M] whose inner-most 2 dimensions
form square matrices. Inputs include:
*x:A Tensor. Must be one of the following types: double, float32,
complex64, complex128. Shape is [..., M, M] . \n

*@par Outputs:
*@li sign:A Tensor. Has the same type as x.
*@li y:A Tensor. Has the same type as x . \n

*@attention Constraints:
*The input x is a tensor of shape [N, M, M] whose inner-most 2 dimensions
form square matrices. \n

*@par Third-party framework compatibility
*Compatible with the TensorFlow LogMatrixDeterminant operator.
*/
REG_OP(LogMatrixDeterminant)
    .INPUT(x, TensorType({DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128}))
    .OUTPUT(sign, TensorType({DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128}))
    .OP_END_FACTORY_REG(LogMatrixDeterminant)

/**
*@brief Computes the determinant of one or more square matrices . \n

*@par Inputs:
*The input x is a tensor of shape [N, M, M] whose inner-most 2 dimensions
form square matrices. Inputs include:
*x:A Tensor. Must be one of the following types: double, float32, complex64,
complex128. Shape is [..., M, M] . \n

*@par Outputs:
*y:A Tensor. Has the same type as x . \n

*@attention Constraints:
*The input x is a tensor of shape [..., M, M] whose inner-most 2 dimensions
form square matrices.

*@par Third-party framework compatibility
*Compatible with the TensorFlow MatrixDeterminant operator.
*/
REG_OP(MatrixDeterminant)
    .INPUT(x, TensorType({DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128}))
    .OP_END_FACTORY_REG(MatrixDeterminant)

/**
*@brief Computes the inverse of one or more square invertible matrices or
their adjoints (conjugate transposes) . \n

*@par Inputs:
*The input x is a tensor of shape [..., M, M] whose inner-most 2 dimensions
form square matrices. Inputs include:
*x:A Tensor of input. Shape is [..., M, M] . \n

*@par Attributes:
*adjoint:An optional bool. Defaults to False. Boolean indicating whether to
deal with the matrix or its (block-wise) adjoint . \n

*@par Outputs:
*y:A Tensor. Has the same type as x . \n

*@attention Constraints:
*The input x is a tensor of shape [..., M, M] whose inner-most 2 dimensions
form square matrices. \n

*@par Third-party framework compatibility
*Compatible with the TensorFlow MatrixInverse operator.
*/
REG_OP(MatrixInverse)
    .INPUT(x, TensorType({DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128}))
    .ATTR(adjoint, Bool, false)
    .OP_END_FACTORY_REG(MatrixInverse)

/**
*@brief Solves systems of linear equations . \n

*@par Inputs:
*The input rhs must have the same type as matrix. Inputs include:
*@li matrix:A Tensor of input. Shape is [..., M, M].
*@li rhs:A Tensor. Must have the same type as matrix. Shape is [..., M, K] . \n

*@par Attributes:
*adjoint:An optional bool. Defaults to False. Boolean indicating whether to
solve with the matrix or its (block-wise) adjoint . \n

*@par Outputs:
*y:A Tensor. Has the same type as matrix . \n

*@attention Constraints:
*The input matrix is a tensor of shape [..., M, M] whose inner-most 2
dimensions form square matrices. \n

*@par Third-party framework compatibility
*Compatible with the TensorFlow MatrixSolve operator.
*/
REG_OP(MatrixSolve)
    .INPUT(matrix, TensorType({DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128}))
    .INPUT(rhs, TensorType({DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128}))
    .ATTR(adjoint, Bool, false)
    .OP_END_FACTORY_REG(MatrixSolve)

/**
*@brief Solves one or more systems of linear equations in the least-squares
sense . \n

*@par Inputs:
*The input rhs must have the same type as matrix. Inputs include:
*@li matrix:A Tensor. Shape is [..., M, M].
*@li rhs:A Tensor. Must have the same type as matrix. Shape is [..., M, K].
*@li l2:A 0-D double Tensor. Ignored if fast=False . \n

*@par Attributes:
*fast:An optional bool. Defaults to True . \n

*@par Outputs:
*y:Tensor of shape [..., N, K] whose inner-most 2 dimensions form M-by-K
matrices that solve the equations matrix[..., :, :] * output[..., :, :] =
rhs[..., :, :] in the least squares sense . \n

*@attention Constraints:
*The input matrix is a tensor of shape [..., M, M] whose inner-most 2
dimensions form square matrices. \n

*@par Third-party framework compatibility
*Compatible with the TensorFlow MatrixSolveLs operator.
*/
REG_OP(MatrixSolveLs)
    .INPUT(matrix, TensorType({DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128}))
    .INPUT(rhs, TensorType({DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128}))
    .INPUT(l2, TensorType({DT_DOUBLE}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_DOUBLE}))
    .ATTR(fast, Bool, true)
    .OP_END_FACTORY_REG(MatrixSolveLs)

/**
*@brief Solves systems of linear equations with upper or lower triangular
matrices by backsubstitution . \n

*@par Inputs:
*The input rhs must have the same type as matrix. Inputs include:
*@li matrix: A Tensor. Shape is [..., M, M].
*@li rhs:A Tensor. Must have the same type as matrix. Shape is [..., M, K] . \n

*@par Attributes:
*@li lower: An optional bool. Defaults to True. Boolean indicating whether
the innermost matrices in matrix are lower or upper triangular.
*@li adjoint: An optional bool. Defaults to False. Boolean indicating whether
to solve with matrix or its (block-wise) adjoint . \n

*@par Outputs:
*y:A Tensor. Has the same type as matrix . \n

*@attention Constraints:
*The input matrix is a tensor of shape [..., M, M] whose inner-most 2
dimensions form square matrices. \n

*@par Third-party framework compatibility
*Compatible with the TensorFlow MatrixTriangularSolve operator.
*/
REG_OP(MatrixTriangularSolve)
    .INPUT(matrix, TensorType({DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128}))
    .INPUT(rhs, TensorType({DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128}))
    .ATTR(lower, Bool, true)
    .ATTR(adjoint, Bool, false)
    .OP_END_FACTORY_REG(MatrixTriangularSolve)
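
/* Usage sketch (illustrative, not part of the original header): setting the
 * lower/adjoint attributes through the setters that REG_OP is expected to
 * generate. The set_input_/set_attr_ names and the upstream operators
 * l_data/rhs_data are assumptions.
 *
 *   ge::op::MatrixTriangularSolve tri("triangular_solve_0");
 *   tri.set_input_matrix(l_data);   // triangular system matrix, [..., M, M]
 *   tri.set_input_rhs(rhs_data);    // right-hand sides, [..., M, K]
 *   tri.set_attr_lower(true);       // treat matrix as lower triangular
 *   tri.set_attr_adjoint(false);    // solve with matrix, not its adjoint
 */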

/**
*@brief Computes the QR decompositions of one or more matrices . \n

*@par Inputs:
*The input shape of x must be [..., M, N]. Inputs include:
*x:A Tensor whose shape is [..., M, N]. Let P be the minimum of M and N . \n

*@par Attributes:
*full_matrices: An optional bool. Defaults to False. If true, compute
full-sized q and r. If false (the default), compute only the leading P
columns of q . \n

*@par Outputs:
*@li q: A Tensor. Has the same type as x.
*@li r: A Tensor. Has the same type as x . \n

*@attention Constraints:
*The input matrix x is a tensor of shape [..., M, N] whose inner-most 2
dimensions form matrices of size [M, N]. \n

*@par Third-party framework compatibility
*Compatible with the TensorFlow Qr operator.
*/
REG_OP(Qr)
    .INPUT(x, TensorType({ DT_FLOAT16, DT_FLOAT, DT_DOUBLE, \
                           DT_COMPLEX64, DT_COMPLEX128 }))
    .OUTPUT(q, TensorType({ DT_FLOAT16, DT_FLOAT, DT_DOUBLE, \
                            DT_COMPLEX64, DT_COMPLEX128 }))
    .OUTPUT(r, TensorType({ DT_FLOAT16, DT_FLOAT, DT_DOUBLE, \
                            DT_COMPLEX64, DT_COMPLEX128 }))
    .ATTR(full_matrices, Bool, false)
    .OP_END_FACTORY_REG(Qr)

/**
*@brief Computes the eigen decomposition of a batch of self-adjoint matrices . \n

*@par Inputs:
*The input shape of x must be [..., N, N]. Inputs include:
*x:Tensor of shape [..., N, N]. Only the lower triangular part of each
inner matrix is referenced . \n

*@par Attributes:
*compute_v:An optional bool. Defaults to True . \n

*@par Outputs:
*@li eigen_value:Eigenvalues. Shape is [..., N]. Sorted in non-decreasing order.
*@li eigen_vector:Shape is [..., N, N]. The columns of the inner-most matrices
contain eigenvectors of the corresponding matrices in the input tensor . \n

*@attention Constraints:
*The input x is a tensor of shape [..., N, N] whose inner-most 2 dimensions
form square matrices. \n

*@par Third-party framework compatibility
*Compatible with the TensorFlow SelfAdjointEig operator.
*/
REG_OP(SelfAdjointEig)
    .INPUT(x, TensorType({ DT_DOUBLE, DT_FLOAT, DT_COMPLEX64, DT_COMPLEX128 }))
    .OUTPUT(eigen_value, TensorType({ DT_DOUBLE, DT_FLOAT, DT_COMPLEX64, DT_COMPLEX128 }))
    .OUTPUT(eigen_vector, TensorType({ DT_DOUBLE, DT_FLOAT, DT_COMPLEX64, DT_COMPLEX128 }))
    .ATTR(compute_v, Bool, true)
    .OP_END_FACTORY_REG(SelfAdjointEig)

/**
*@par Restrictions:
*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.

*@brief Computes the sign and the log of the absolute value of the determinant
of one or more square matrices . \n

*@par Inputs:
*The input x is a tensor of shape [N, M, M] whose inner-most 2 dimensions
form square matrices. Inputs include:
*x:A Tensor. Must be one of the following types: double, float32, float16.
Shape is [..., M, M] . \n

*@par Outputs:
*@li sign:A Tensor. Has the same type as x.
*@li y:A Tensor. Has the same type as x . \n

*@attention Constraints:
*The input x is a tensor of shape [N, M, M] whose inner-most 2 dimensions
form square matrices. \n

*@par Third-party framework compatibility
*Compatible with the TensorFlow LogMatrixDeterminant operator.
*/
REG_OP(Slogdet)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
    .OUTPUT(sign, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
    .OP_END_FACTORY_REG(Slogdet)

/**
*@brief Computes the singular value decompositions of one or more matrices . \n

*@par Inputs:
*The input shape of x must be [..., M, N]. Inputs include:
*x:Tensor of shape [..., M, N]. Let P be the minimum of M and N . \n

*@par Attributes:
*@li compute_uv:If True, the left and right singular vectors will be computed
and returned in u and v, respectively. Otherwise, only the singular values
will be computed, which can be significantly faster.
*@li full_matrices:If True, compute full-sized u and v. If False (the
default), compute only the leading P singular vectors. Ignored if compute_uv
is False. \n

*@par Outputs:
*@li sigma:Singular values. Shape is [..., P]. The values are sorted in
reverse order of magnitude, so sigma[..., 0] is the largest value,
sigma[..., 1] is the second largest, etc.
*@li u:Left singular vectors. If full_matrices is False (default) then shape
is [..., M, P]; if full_matrices is True then shape is [..., M, M]. Not
returned if compute_uv is False.
*@li v:Right singular vectors. If full_matrices is False (default) then shape
is [..., N, P]. If full_matrices is True then shape is [..., N, N]. Not
returned if compute_uv is False . \n

*@attention Constraints:
*The input x is a tensor of shape [..., M, N] whose inner-most 2 dimensions
form matrices of size [M, N]. \n

*@par Third-party framework compatibility
*Compatible with the TensorFlow Svd operator.
*/
REG_OP(Svd)
    .INPUT(x, TensorType({ DT_DOUBLE, DT_FLOAT, DT_COMPLEX64, DT_COMPLEX128 }))
    .OUTPUT(sigma, TensorType({ DT_DOUBLE, DT_FLOAT, DT_COMPLEX64, DT_COMPLEX128 }))
    .OUTPUT(u, TensorType({ DT_DOUBLE, DT_FLOAT, DT_COMPLEX64, DT_COMPLEX128 }))
    .OUTPUT(v, TensorType({ DT_DOUBLE, DT_FLOAT, DT_COMPLEX64, DT_COMPLEX128 }))
    .ATTR(compute_uv, Bool, true)
    .ATTR(full_matrices, Bool, false)
    .OP_END_FACTORY_REG(Svd)
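
/* Shape example (illustrative, not part of the original header): for a single
 * input matrix x of shape [3, 2], P = min(3, 2) = 2, so
 *   sigma : [2]
 *   u     : [3, 2] if full_matrices is False, [3, 3] if True
 *   v     : [2, 2] in either case (since N == P here)
 * Outputs u and v are only produced when compute_uv is True.
 */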

/**
*@brief Computes the LU decomposition of one or more square matrices . \n

*@par Inputs:
*input: A tensor of shape `[..., M, M]` whose inner-most 2 dimensions form
matrices of size `[M, M]` . \n

*@par Outputs:
*@li lu: A tensor of shape `[..., M, M]` whose strictly lower triangular part
denotes the lower triangular factor `L` with unit diagonal, and whose upper
triangular part denotes the upper triangular factor `U`.
*@li p: Permutation of the rows encoded as a list of indices in `0..M-1`.
Shape is `[..., M]` . \n

*@par Attributes:
*output_idx_type: A required DType from: int32, int64. Type of the output p.

*@par Third-party framework compatibility
* Compatible with the TensorFlow Lu operator.
*/
REG_OP(Lu)
    .INPUT(input, TensorType({DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128}))
    .OUTPUT(lu, TensorType({DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128}))
    .OUTPUT(p, TensorType({DT_INT32, DT_INT64}))
    .REQUIRED_ATTR(output_idx_type, Type)
    .OP_END_FACTORY_REG(Lu)
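
/* Worked example (illustrative, not part of the original header): for the
 * single 2x2 input
 *   input = [[0, 1],
 *            [2, 3]]
 * partial pivoting swaps the two rows, giving L = [[1, 0], [0, 1]] and
 * U = [[2, 3], [0, 1]], so the packed outputs are
 *   lu = [[2, 3],
 *         [0, 1]]   (the unit diagonal of L is implicit)
 *   p  = [1, 0]     (row permutation indices)
 */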

/**
*@brief Computes the matrix square root of one or more square matrices . \n

*@par Inputs:
*input: Shape is `[..., M, M]` . \n

*@par Outputs:
*y: Shape is `[..., M, M]` . \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow MatrixSquareRoot operator.
*/
REG_OP(MatrixSquareRoot)
    .INPUT(input, TensorType({DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128}))
    .OP_END_FACTORY_REG(MatrixSquareRoot)

/**
*@brief Solves tridiagonal systems of equations . \n

*@par Inputs:
*@li diagonals: Tensor of shape `[..., 3, M]` whose innermost 2 dimensions
represent the tridiagonal matrices, with the three rows being the
superdiagonal, diagonal, and subdiagonal, in order. The last element of the
superdiagonal and the first element of the subdiagonal are ignored.
*@li rhs: Tensor of shape `[..., M, K]`, representing K right-hand sides per
each left-hand side . \n

*@par Outputs:
*y: Tensor of shape `[..., M, K]` containing the solutions . \n

*@par Attributes:
*partial_pivoting: Whether to perform partial pivoting. `True` by default.
Partial pivoting makes the procedure more stable, but slower. Partial
pivoting is unnecessary in some cases, including diagonally dominant and
symmetric positive definite matrices.

*@par Third-party framework compatibility
* Compatible with the TensorFlow TridiagonalSolve operator.
*/
REG_OP(TridiagonalSolve)
    .INPUT(diagonals, TensorType({DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128}))
    .INPUT(rhs, TensorType({DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128}))
    .ATTR(partial_pivoting, Bool, true)
    .OP_END_FACTORY_REG(TridiagonalSolve)
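
/* Layout example (illustrative, not part of the original header): a single
 * 4x4 tridiagonal system
 *   [[d0, u0,  0,  0],
 *    [l1, d1, u1,  0],
 *    [ 0, l2, d2, u2],
 *    [ 0,  0, l3, d3]]
 * is passed as diagonals of shape [3, 4]:
 *   [[u0, u1, u2,  *],   // superdiagonal, last element ignored
 *    [d0, d1, d2, d3],   // main diagonal
 *    [ *, l1, l2, l3]]   // subdiagonal, first element ignored
 * where `*` marks the ignored positions.
 */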

}  // namespace ge

#endif  // OPS_BUILT_IN_OP_PROTO_INC_LINALG_OPS_H_

The Graph Engine (GE) module is a submodule of MindSpore implemented in C++. It sits between the front-end module (ME) and the underlying hardware and serves as the bridge between them. GE takes the graph produced by ME as input, performs a series of deep graph optimizations, and finally outputs a graph that can run efficiently on the underlying hardware. GE applies optimizations tailored to the hardware architecture of the Ascend AI processor in order to fully exploit its compute power. During model training/inference, GE is invoked automatically and is transparent to the user. GE consists mainly of two parts, the GE API and GE Core; the detailed architecture diagram is shown below.
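
As a rough illustration of the GE API side, the sketch below shows how a client typically drives GE: initialize it, create a Session, add a Graph (built from operators such as those registered in linalg_ops.h), run it, and finalize. This is a minimal sketch assuming the ge_api.h client interface (GEInitialize, Session::AddGraph, Session::RunGraph, GEFinalize) with string-keyed option maps; exact signatures may differ between GE versions, and the graph-building step is elided.

#include <map>
#include <string>
#include <vector>
#include "ge/ge_api.h"    // assumed client header exposing GEInitialize/Session
#include "graph/graph.h"

int main() {
  std::map<std::string, std::string> options;  // session/engine options, left empty here
  if (ge::GEInitialize(options) != ge::SUCCESS) {
    return -1;                                 // GE Core failed to start
  }

  ge::Graph graph("linalg_graph");             // graph handed to GE Core for optimization
  // ... build the graph here, e.g. from ge::op::Cholesky / ge::op::MatrixSolve ...

  ge::Session session(options);
  session.AddGraph(1, graph);                  // register the graph under id 1

  std::vector<ge::Tensor> inputs;              // to be filled with input tensors
  std::vector<ge::Tensor> outputs;
  session.RunGraph(1, inputs, outputs);        // GE optimizes and executes the graph

  ge::GEFinalize();
  return 0;
}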