You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

math_ops.h 20 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef GE_OP_MATH_OPS_H_
  17. #define GE_OP_MATH_OPS_H_
  18. #include "graph/operator_reg.h"
  19. #include "graph/operator.h"
  20. namespace ge {
/**
*@brief Computes the output as (shift + scale * x) ^ power.

*@par Inputs:
* x: A Tensor of type float16 or float32.

*@par Attributes:
*@li power: Optional. The exponent applied to (shift + scale * x). Defaults to 1.0.
*@li scale: Optional. Multiplier applied to "x". Defaults to 1.0.
*@li shift: Optional. Offset added to scale * x. Defaults to 0.0.

*@par Outputs:
* y: A Tensor. Has the same type and shape as "x".

*@par Third-party framework compatibility
* Compatible with the Caffe operator Power.
*/
REG_OP(Power)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT}))
    .ATTR(power, Float, 1.0)
    .ATTR(scale, Float, 1.0)
    .ATTR(shift, Float, 0.0)
    .OP_END_FACTORY_REG(Power);
/**
*@brief Compute the lower regularized incomplete Gamma function P(a, x).

*@par Inputs:
*The inputs "a" and "x" must have the same type. Inputs include: \n
*@li a: A Tensor. Must be one of the following types: float, double.
*@li x: A Tensor. Must have the same type as "a".

*@par Outputs:
*z: A Tensor. Has the same type as "a".

*@par Third-party framework compatibility.
*Compatible with the TensorFlow Igamma operator.
*/
REG_OP(Igamma)
    .INPUT(a, TensorType({DT_FLOAT, DT_DOUBLE}))
    .INPUT(x, TensorType({DT_FLOAT, DT_DOUBLE}))
    .OUTPUT(z, TensorType({DT_FLOAT, DT_DOUBLE}))
    .OP_END_FACTORY_REG(Igamma)
/**
*@brief Compute the upper regularized incomplete Gamma function Q(a, x).

*@par Inputs:
*The inputs "a" and "x" must have the same type. Inputs include: \n
*@li a: A Tensor. Must be one of the following types: float, double.
*@li x: A Tensor. Must have the same type as "a".

*@par Outputs:
*z: A Tensor. Has the same type as "a".

*@par Third-party framework compatibility.
*Compatible with the TensorFlow Igammac operator.
*/
REG_OP(Igammac)
    .INPUT(a, TensorType({DT_FLOAT, DT_DOUBLE}))
    .INPUT(x, TensorType({DT_FLOAT, DT_DOUBLE}))
    .OUTPUT(z, TensorType({DT_FLOAT, DT_DOUBLE}))
    .OP_END_FACTORY_REG(Igammac)
/**
*@brief Compare values of input to threshold and pack resulting bits into \n
a uint8.

*@par Inputs:
*Inputs include: \n
*@li x: Values to compare against "threshold" and bitpack.
*@li threshold: Threshold to compare against. Must have the same type as "x".

*@par Outputs:
*y: The bitpacked comparisons, of type uint8.

*@attention Constraints: \n
*Currently, the innermost dimension of the tensor must be divisible by 8. \n

*@par Third-party framework compatibility
*Compatible with the TensorFlow CompareAndBitpack operator.
*/
REG_OP(CompareAndBitpack)
    .INPUT(x, TensorType({ DT_FLOAT, DT_FLOAT16, DT_DOUBLE, DT_INT8, \
        DT_INT16, DT_INT32, DT_INT64, DT_BOOL }))
    .INPUT(threshold, TensorType({ DT_FLOAT, DT_FLOAT16, DT_DOUBLE, \
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_BOOL }))
    .OUTPUT(y, TensorType(DT_UINT8))
    .OP_END_FACTORY_REG(CompareAndBitpack)
/**
*@brief Counts the number of occurrences of each value in an integer array. \n
Outputs a vector with length "size" and the same dtype as "weights". If \n
"weights" is empty, then index i stores the number of times the value i is \n
counted in "array". If "weights" is non-empty, then index i stores the sum of \n
the values in "weights" at each index where "array" equals i.

*@par Inputs:
*The input "size" must be a non-negative int32 scalar Tensor. Inputs include: \n
*@li array: An int32 Tensor.
*@li size: A non-negative int32 scalar Tensor.
*@li weights: An int32, int64, float32, or double Tensor with the same \n
shape as "array", or a length-0 Tensor, in which case it acts as all weights \n
equal to 1.

*@par Outputs:
*bins: A 1D Tensor with length equal to "size". The counts or summed weights \n
for each value in the range [0, size).

*@par Third-party framework compatibility
*Compatible with the TensorFlow Bincount operator.
*/
REG_OP(Bincount)
    .INPUT(array, TensorType(DT_INT32))
    .INPUT(size, TensorType(DT_INT32))
    .INPUT(weights, TensorType({ DT_FLOAT, DT_INT32, DT_INT64, DT_DOUBLE }))
    .OUTPUT(bins, TensorType({ DT_FLOAT, DT_INT32, DT_INT64, DT_DOUBLE }))
    .OP_END_FACTORY_REG(Bincount)
/**
*@brief Compute the regularized incomplete beta integral.

*@par Inputs:
*The inputs "b" and "x" must have the same type as "a". Inputs include: \n
*@li a: A Tensor. Must be one of the following types: float32, double.
*@li b: A Tensor. Must have the same type as "a".
*@li x: A Tensor. Must have the same type as "a".

*@par Outputs:
*z: A Tensor. Has the same type as "a".

*@par Third-party framework compatibility.
*Compatible with the TensorFlow Betainc operator.
*/
REG_OP(Betainc)
    .INPUT(a, TensorType({DT_DOUBLE, DT_FLOAT}))
    .INPUT(b, TensorType({DT_DOUBLE, DT_FLOAT}))
    .INPUT(x, TensorType({DT_DOUBLE, DT_FLOAT}))
    .OUTPUT(z, TensorType({DT_DOUBLE, DT_FLOAT}))
    .OP_END_FACTORY_REG(Betainc)
/**
*@brief Compute the Hurwitz zeta function.

*@par Inputs:
*The input "q" must have the same type as "x". Inputs include: \n
*@li x: A Tensor. Must be one of the following types: float32, double.
*@li q: A Tensor. Must have the same type as "x".

*@par Outputs:
*z: A Tensor. Has the same type as "x".

*@attention Constraints: \n
*The implementation of Zeta on Ascend uses AI CPU, with poor performance. \n

*@par Third-party framework compatibility.
*Compatible with the TensorFlow Zeta operator.
*/
REG_OP(Zeta)
    .INPUT(x, TensorType({DT_DOUBLE, DT_FLOAT}))
    .INPUT(q, TensorType({DT_DOUBLE, DT_FLOAT}))
    .OUTPUT(z, TensorType({DT_DOUBLE, DT_FLOAT}))
    .OP_END_FACTORY_REG(Zeta)
/**
*@brief Bucketizes "input" based on "boundaries". For example, if the inputs \n
are boundaries = [0, 10, 100] and input = [[-5, 10000] [150, 10] [5, 100]], \n
then the output will be output = [[0, 3] [3, 2] [1, 3]].

*@par Inputs:
*The dtype of input "x" must be int or float. Inputs include: \n
*x: Any shape of Tensor with int or float type.

*@par Attributes:
*boundaries: Required. A sorted list of floats giving the boundaries of the buckets.

*@par Outputs:
*y: Same shape as "input"; each value of the input is replaced with its bucket index.

*@par Third-party framework compatibility.
*Compatible with the TensorFlow Bucketize operator.
*/
REG_OP(Bucketize)
    .INPUT(x, TensorType({DT_INT32, DT_INT64, DT_DOUBLE, DT_FLOAT}))
    .OUTPUT(y, TensorType({DT_INT32}))
    .REQUIRED_ATTR(boundaries, ListFloat)
    .OP_END_FACTORY_REG(Bucketize)
/**
*@brief Computes the sum along sparse segments of a tensor.

*@par Inputs:
*The inputs "indices" and "segment_ids" must have the same rank. Inputs include: \n
*@li x: A Tensor. Must be one of the following types: float16, float, double, \n
int8, uint8, int16, uint16, int32, int64.
*@li indices: A Tensor of type int32. A 1-D tensor. Has the same rank as \n
segment_ids. (NOTE: TensorFlow also allows int64 indices; this registration \n
accepts int32 only.)
*@li segment_ids: A Tensor of type int32. A 1-D tensor. Values should be \n
sorted and can be repeated.

*@par Outputs:
*y: A Tensor. Has the same type as "x".

*@par Third-party framework compatibility
*Compatible with the TensorFlow SparseSegmentSum operator.
*/
REG_OP(SparseSegmentSum)
    .INPUT(x, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_UINT16,
        DT_INT32, DT_INT64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16}))
    .INPUT(indices, TensorType({DT_INT32}))
    .INPUT(segment_ids, TensorType({DT_INT32}))
    .OUTPUT(y, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_UINT16,
        DT_INT32, DT_INT64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16}))
    .OP_END_FACTORY_REG(SparseSegmentSum)
/**
*@brief Computes the mean along sparse segments of a tensor.

*@par Inputs:
*The inputs "indices" and "segment_ids" must have the same rank. Inputs include: \n
*@li x: A Tensor. Must be one of the following types: float, double.
*@li indices: A Tensor of type int32. A 1-D tensor. Has the same rank as \n
segment_ids. (NOTE: TensorFlow also allows int64 indices; this registration \n
accepts int32 only.)
*@li segment_ids: A Tensor of type int32. A 1-D tensor. Values should be \n
sorted and can be repeated.

*@par Outputs:
*y: A Tensor. Has the same type as "x".

*@par Third-party framework compatibility
*Compatible with the TensorFlow SparseSegmentMean operator.
*/
REG_OP(SparseSegmentMean)
    .INPUT(x, TensorType({DT_FLOAT, DT_DOUBLE}))
    .INPUT(indices, TensorType({DT_INT32}))
    .INPUT(segment_ids, TensorType({DT_INT32}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_DOUBLE}))
    .OP_END_FACTORY_REG(SparseSegmentMean)
/**
*@brief Computes gradients for SparseSegmentMean.

*@par Inputs:
*The input "x" must be of type float or double. Inputs include: \n
*@li x: A Tensor. Must be one of the following types: float, double. The \n
gradient propagated to the SparseSegmentMean op.
*@li indices: A Tensor of type int32. The indices passed to the \n
corresponding SparseSegmentMean op.
*@li segment_ids: A Tensor of type int32. The segment_ids passed to the \n
corresponding SparseSegmentMean op.
*@li output_dim0: A Tensor of type int32. Dimension 0 of the "x" passed to \n
the SparseSegmentMean op.

*@par Outputs:
*y: A Tensor. Has the same type as "x".

*@par Third-party framework compatibility
*Compatible with the TensorFlow SparseSegmentMeanGrad operator.
*/
REG_OP(SparseSegmentMeanGrad)
    .INPUT(x, TensorType({DT_FLOAT, DT_DOUBLE}))
    .INPUT(indices, TensorType({DT_INT32}))
    .INPUT(segment_ids, TensorType({DT_INT32}))
    .INPUT(output_dim0, TensorType({DT_INT32}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_DOUBLE}))
    .OP_END_FACTORY_REG(SparseSegmentMeanGrad)
/**
*@brief Computes the gradient of igamma(a, x) with respect to "a".

*@par Inputs:
*The inputs "a" and "x" must have the same type. Inputs include: \n
*@li a: A Tensor. Must be one of the following types: float32, double.
*@li x: A Tensor. Must have the same type as "a".

*@par Outputs:
*z: A Tensor. Has the same type as "a".

*@par Third-party framework compatibility
*Compatible with the TensorFlow IgammaGradA operator.
*/
REG_OP(IgammaGradA)
    .INPUT(a, TensorType({DT_FLOAT, DT_DOUBLE}))
    .INPUT(x, TensorType({DT_FLOAT, DT_DOUBLE}))
    .OUTPUT(z, TensorType({DT_FLOAT, DT_DOUBLE}))
    .OP_END_FACTORY_REG(IgammaGradA)
/**
*@brief Initialize the data process channel.

*@par Attributes:
*channel_name: A string identifying the channel to initialize. Defaults to "".

*@par Third-party framework compatibility
*Compatible with the TensorFlow InitData operator.
*/
REG_OP(InitData)
    .ATTR(channel_name, String, "")
    .OP_END_FACTORY_REG(InitData)
/**
*@brief Get the next batch of data in data processing.

*@par Attributes:
*@li output_types: A list of type enums (registered as ListInt) corresponding \n
to each component of an element of this dataset. Defaults to {}.
*@li output_shapes: A nested structure of shapes (registered as ListListInt) \n
corresponding to each component of an element of this dataset. Defaults to {}.
*@li output_num: Number of output components. Defaults to 1.
*@li channel_name: A string identifying the data channel. Defaults to "".

*@par Outputs:
*y: A dynamic number of Tensor objects, one per dataset component.

*@par Third-party framework compatibility
*Compatible with the TensorFlow GetNext operator.
*/
REG_OP(GetNext)
    .DYNAMIC_OUTPUT(y, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_UINT16, DT_INT32, DT_INT64, DT_UINT32, DT_UINT64,
        DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_BOOL}))
    .ATTR(output_types, ListInt, {})
    .ATTR(output_shapes, ListListInt, {})
    .ATTR(output_num, Int, 1)
    .ATTR(channel_name, String, "")
    .OP_END_FACTORY_REG(GetNext)
/**
*@brief Marks the end of a data sequence.

*@par Inputs:
*x: A Tensor of type uint8.

*@par Outputs:
*y: A Tensor. Has the same type as "x".
*/
REG_OP(EndOfSequence)
    .INPUT(x, TensorType({DT_UINT8}))
    .OUTPUT(y, TensorType({DT_UINT8}))
    .OP_END_FACTORY_REG(EndOfSequence)
/**
*@brief Computes the Gauss error function of "x" element-wise.

*@par Inputs:
*x: A Tensor of type float16, float32 or double. The format can be
* [NCHW, NC1HWC0, NHWC, ND].

*@par Outputs:
*y: A Tensor. Has the same type and format as "x".

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator Erf.
*/
REG_OP(Erf)
    .INPUT(x, TensorType::FloatingDataType())
    .OUTPUT(y, TensorType::FloatingDataType())
    .OP_END_FACTORY_REG(Erf)
/**
*@brief Computes the Gauss complementary error function of "x" element-wise.

*@par Inputs:
*x: A Tensor of type float16, float32 or double.

*@par Outputs:
*y: A Tensor. Has the same type as "x".

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator Erfc.
*/
REG_OP(Erfc)
    .INPUT(x, TensorType::FloatingDataType())
    .OUTPUT(y, TensorType::FloatingDataType())
    .OP_END_FACTORY_REG(Erfc)
/**
*@brief This operation returns a rank-1 histogram counting the number of entries
* in "x" that fall into every bin. The bins are of equal width and determined
* by the arguments "range" and "nbins".

*@par Inputs:
*Three inputs, including: \n
*@li x: A Tensor of type float32, float16, int32 or int64. The values to histogram.
*@li range: A Tensor of the same type as "x". Shape [2]; the lowest and highest bin edges.
*@li nbins: A Tensor of type int32. Scalar number of bins.

*@par Attributes:
* dtype: An optional attribute selecting the output dtype. Defaults to "int32".
* (NOTE: this registration only outputs int32.)

*@par Outputs:
*y: A 1-D Tensor of type int32 holding the bin counts.

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator HistogramFixedWidth.
*/
REG_OP(HistogramFixedWidth)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT64}))
    .INPUT(range, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT64}))
    .INPUT(nbins, TensorType({DT_INT32}))
    .OUTPUT(y, TensorType({DT_INT32}))
    .ATTR(dtype, String, "int32")
    .OP_END_FACTORY_REG(HistogramFixedWidth)
/**
*@brief This operation returns a rank-1 histogram counting the number of entries
* in "x" that fall into every bin. The bins are of equal width and determined
* by the arguments "range" and "nbins". Variant of HistogramFixedWidth with
* "nbins" supplied as a compile-time attribute instead of an input tensor.

*@par Inputs:
*Two inputs, including: \n
*@li x: A Tensor of type float32, float16, int32 or int64. The values to histogram.
*@li range: A Tensor of the same type as "x". Shape [2]; the lowest and highest bin edges.

*@par Attributes:
*@li nbins: Required. An int giving the number of bins.
*@li dtype: An optional attribute. Defaults to "int32".

*@par Outputs:
*y: A 1-D Tensor of type int32 holding the bin counts.

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator HistogramFixedWidth.
*/
REG_OP(HistogramFixedWidthD)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT64}))
    .INPUT(range, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT64}))
    .OUTPUT(y, TensorType({DT_INT32}))
    .REQUIRED_ATTR(nbins, Int)
    .ATTR(dtype, String, "int32")
    .OP_END_FACTORY_REG(HistogramFixedWidthD)
/**
*@brief Returns the next representable value of "x1" in the direction of "x2", element-wise.

*@par Inputs:
*The inputs "x1" and "x2" must have the same type. Inputs include: \n
*@li x1: A Tensor. Must be one of the following types: float32, double.
*@li x2: A Tensor. Must have the same type as "x1".

*@par Outputs:
*output: A Tensor. Has the same type as "x1".

*@par Third-party framework compatibility
*Compatible with the TensorFlow NextAfter operator.
*/
REG_OP(NextAfter)
    .INPUT(x1, TensorType({DT_FLOAT, DT_DOUBLE}))
    .INPUT(x2, TensorType({DT_FLOAT, DT_DOUBLE}))
    .OUTPUT(output, TensorType({DT_FLOAT, DT_DOUBLE}))
    .OP_END_FACTORY_REG(NextAfter)
/**
*@brief Compute element-wise finiteness, returning a boolean tensor.

*@par Inputs:
*x: A Tensor of type float16, float32 or double.

*@par Outputs:
*y: A Tensor of type bool. Has the same shape as "x".

*@par Third-party framework compatibility.
*Compatible with the TensorFlow IsFinite operator.
*/
REG_OP(IsFinite)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
    .OUTPUT(y, TensorType({DT_BOOL}))
    .OP_END_FACTORY_REG(IsFinite)
/**
*@brief Computes the complex absolute value of a tensor.

*@par Inputs:
*x: A Tensor of type complex64 or complex128.

*@par Attributes:
*Tout: Optional. The real output type (DT_FLOAT or DT_DOUBLE). Defaults to DT_FLOAT.

*@par Outputs:
*y: A tensor of type float or double that is the absolute value of each element in "x".

*@par Third-party framework compatibility.
*Compatible with the TensorFlow ComplexAbs operator.
*/
REG_OP(ComplexAbs)
    .INPUT(x, TensorType({DT_COMPLEX64, DT_COMPLEX128}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_DOUBLE}))
    .ATTR(Tout, Type, DT_FLOAT)
    .OP_END_FACTORY_REG(ComplexAbs)
/**
*@brief Returns which elements of "x" are NaN.

*@par Inputs:
*x: A Tensor of type float16, float32 or double.

*@par Outputs:
*y: A Tensor of type bool. Has the same shape as "x".

*@par Third-party framework compatibility.
*Compatible with the TensorFlow IsNan operator.
*/
REG_OP(IsNan)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
    .OUTPUT(y, TensorType({DT_BOOL}))
    .OP_END_FACTORY_REG(IsNan)
/**
*@brief Returns the real part of a complex number.

*@par Inputs:
*input: A Tensor of type complex64 or complex128.

*@par Attributes:
*Tout: Optional. The real output type (DT_FLOAT or DT_DOUBLE). Defaults to DT_FLOAT.

*@par Outputs:
*output: A Tensor of type float or double. Has the same shape as "input".

*@par Third-party framework compatibility.
*Compatible with the TensorFlow Real operator.
*/
REG_OP(Real)
    .INPUT(input, TensorType({DT_COMPLEX64, DT_COMPLEX128}))
    .OUTPUT(output, TensorType({DT_FLOAT, DT_DOUBLE}))
    .ATTR(Tout, Type, DT_FLOAT)
    .OP_END_FACTORY_REG(Real)
/**
*@brief Returns the complex conjugate of a complex number.

*@par Inputs:
*input: A Tensor of type complex64 or complex128.

*@par Outputs:
*output: A Tensor. Has the same type and shape as "input".

*@par Third-party framework compatibility.
*Compatible with the TensorFlow Conj operator.
*/
REG_OP(Conj)
    .INPUT(input, TensorType({DT_COMPLEX64, DT_COMPLEX128}))
    .OUTPUT(output, TensorType({DT_COMPLEX64, DT_COMPLEX128}))
    .OP_END_FACTORY_REG(Conj)
/**
*@brief The negative log likelihood loss.

*@par Inputs:
*The inputs "x" and "weight" must have the same type. Inputs include: \n
*@li x: A Tensor of type float32. The input predictions.
*@li target: A Tensor of type int32. The target class indices.
*@li weight: A Tensor of type float32. Per-class weights.

*@par Attributes:
*@li reduction: An optional string attribute. Defaults to "mean".

*@par Outputs:
*Two outputs, including:
*@li y: A Tensor of type float32. The computed loss.
*@li total_weight: A Tensor of type float32. The sum of the weights applied.

*@par Third-party framework compatibility
*Compatible with the PyTorch NLLLoss operator.
*/
REG_OP(NLLLoss)
    .INPUT(x, TensorType({DT_FLOAT}))
    .INPUT(target, TensorType({DT_INT32}))
    .INPUT(weight, TensorType({DT_FLOAT}))
    .OUTPUT(y, TensorType({DT_FLOAT}))
    .OUTPUT(total_weight, TensorType({DT_FLOAT}))
    .ATTR(reduction, String, "mean")
    .OP_END_FACTORY_REG(NLLLoss)
/**
*@brief The negative log likelihood loss gradient.

*@par Inputs:
*Inputs include:
*@li x: A Tensor of type float32. The forward input predictions.
*@li y_grad: A Tensor of type float32. The gradient of the loss output.
*@li target: A Tensor of type int32. The target class indices.
*@li weight: A Tensor of type float32. Per-class weights.
*@li total_weight: A Tensor of type float32. The total_weight output of the forward NLLLoss.

*@par Attributes:
*@li reduction: An optional string attribute. Defaults to "mean".

*@par Outputs:
*One output, including:
*@li x_grad: A Tensor of type float32. The gradient with respect to "x".

*@par Third-party framework compatibility
*Compatible with the PyTorch NLLLossGrad operator.
*/
REG_OP(NLLLossGrad)
    .INPUT(x, TensorType({DT_FLOAT}))
    .INPUT(y_grad, TensorType({DT_FLOAT}))
    .INPUT(target, TensorType({DT_INT32}))
    .INPUT(weight, TensorType({DT_FLOAT}))
    .INPUT(total_weight, TensorType({DT_FLOAT}))
    .OUTPUT(x_grad, TensorType({DT_FLOAT}))
    .ATTR(reduction, String, "mean")
    .OP_END_FACTORY_REG(NLLLossGrad)
  527. } // namespace ge
  528. #endif // GE_OP_MATH_OPS_H_

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示。