You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

stateless_random_ops.h 19 kB

5 years ago
5 years ago
5 years ago
5 years ago
3 years ago
5 years ago
3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. /*!
  17. * \file stateless_random_ops.h
  18. * \brief Operator prototype registrations for the seeded (stateless) random number ops.
  19. */
  20. #ifndef OPS_BUILT_IN_OP_PROTO_INC_STATELESS_RANDOM_OPS_H_
  21. #define OPS_BUILT_IN_OP_PROTO_INC_STATELESS_RANDOM_OPS_H_
  22. #include "graph/operator.h"
  23. #include "graph/operator_reg.h"
  24. namespace ge {
  25. /**
  26. *@brief Draws samples from a multinomial distribution . \n
  27. *@par Inputs:
  28. include:
  29. *@li logits:2-D Tensor with shape [batch_size, num_classes]. Each slice [i, :]
  30. *represents the unnormalized log probabilities for all classes.
  31. *@li num_samples:0-D. Number of independent samples to draw for each row slice.
  32. *@li seed:The seed to generate random . \n
  33. *@par Attributes:
  34. *output_dtype:Output data type . \n
  35. *@par Outputs:
  36. *y:Output random number . \n
  37. *@see StatelessMultinomial()
  38. *@par Third-party framework compatibility
  39. *compatible with StatelessMultinomial op of tensorflow
  40. */
  41. REG_OP(StatelessMultinomial)
  42. .INPUT(logits, TensorType({DT_FLOAT16,DT_FLOAT,DT_DOUBLE}))
  43. .INPUT(num_samples, TensorType({DT_INT32}))
  44. .INPUT(seed, TensorType({DT_INT32, DT_INT64}))
  45. .OUTPUT(y, TensorType({DT_INT32, DT_INT64}))
  46. .ATTR(output_dtype, Type, DT_INT64)
  47. .OP_END_FACTORY_REG(StatelessMultinomial)
  48. /**
  49. *@brief Outputs deterministic pseudorandom random integers from a uniform distribution . \n
  50. *@par Inputs:
  51. *@li shape: The shape of the output tensor.
  52. *@li seed: 2 seeds (shape [2]).
  53. *@li minval: Minimum value (inclusive, scalar).
  54. *@li maxval: Maximum value (exclusive, scalar) . \n
  55. *@par Outputs:
  56. *y: Returns Random values with specified shape . \n
  57. *@par Third-party framework compatibility
  58. * Compatible with TensorFlow StatelessRandomUniformInt operator.
  59. */
  60. REG_OP(StatelessRandomUniformInt)
  61. .INPUT(shape, TensorType({DT_INT32, DT_INT64}))
  62. .INPUT(seed, TensorType({DT_INT32, DT_INT64}))
  63. .INPUT(minval, TensorType({DT_INT32, DT_INT64}))
  64. .INPUT(maxval, TensorType({DT_INT32, DT_INT64}))
  65. .OUTPUT(y, TensorType({DT_INT32, DT_INT64}))
  66. .OP_END_FACTORY_REG(StatelessRandomUniformInt)
  67. /**
  68. * @brief Outputs random values from a normal distribution. \n
  69. * @par Inputs:
  70. * Inputs include:
  71. * @li shape: A Tensor. Must be one of the following types: int32, int64.
  72. The shape of the output tensor. Batches are indexed by the 0th dimension.
  73. * @li seed: 2 seeds (shape [2]).
  74. * @li means: A Tensor. Must be one of the following types: half, bfloat16, float32, float64.
  75. * @li stdevs: A Tensor. Must have the same type as means.
  76. * @li min: A Tensor. Must have the same type as means. The minimum cutoff. May be -infinity.
  77. * @li max: A Tensor. Must have the same type as means. \n
  78. * @par Outputs:
  79. * y: A Tensor. Has the same type as means. \n
  80. * @attention Constraints:
  81. * The implementation for StatelessParameterizedTruncatedNormal on Ascend uses AICPU, with bad performance. \n
  82. * @par Third-party framework compatibility
  83. * @li compatible with tensorflow StatelessParameterizedTruncatedNormal operator.
  84. */
  85. REG_OP(StatelessParameterizedTruncatedNormal)
  86. .INPUT(shape, TensorType({DT_INT32, DT_INT64}))
  87. .INPUT(seed, TensorType({DT_INT32, DT_INT64}))
  88. .INPUT(means, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
  89. .INPUT(stdevs, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
  90. .INPUT(min, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
  91. .INPUT(max, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
  92. .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
  93. .OP_END_FACTORY_REG(StatelessParameterizedTruncatedNormal)
  94. /**
  95. * @brief Generate a single randomly distorted bounding box for an image . \n
  96. * @par Inputs:
  97. * Input images must be a 4-D tensor. Inputs include:
  98. * @li image_size: 1-D, containing [height, width, channels].
  99. * @li bounding_boxes: 3-D with shape [batch, N, 4] describing the N bounding
  100. boxes associated with the image.
  101. * @li min_object_covered: The cropped area of the image must contain at least
  102. this fraction of any bounding box supplied. The value of this parameter should
  103. be non-negative. In the case of 0, the cropped area does not need to overlap
  104. any of the bounding boxes supplied .
  105. * @li seed: A shape [2] Tensor, the seed to the random number generator. \n
  106. * @par Attributes:
  107. * @li aspect_ratio_range: The cropped area of the image must have an aspect
  108. ratio = width / height within this range.
  109. * @li area_range: An optional list of `floats`. Defaults to `[0.05, 1]`. The
  110. cropped area of the image must contain a fraction of the supplied image
  111. within this range.
  112. * @li max_attempts: Number of attempts at generating a cropped region of the
  113. image of the specified constraints. After max_attempts failures, return the
  114. entire image.
  115. * @li use_image_if_no_bounding_boxes: Controls behavior if no bounding boxes
  116. supplied. If true, assume an implicit bounding box covering the whole input.
  117. If false, raise an error . \n
  118. * @par Outputs:
  119. * @li begin: 1-D, containing [offset_height, offset_width, 0].
  120. * @li size: 1-D, containing [target_height, target_width, -1].
  121. * @li bboxes: 3-D with shape [1, 1, 4] containing the distorted bounding box . \n
  122. * @attention Constraints:
  123. * Input images can be of different types but output images are always float . \n
  124. * @par Third-party framework compatibility
  125. * Compatible with tensorflow StatelessSampleDistortedBoundingBox operator.
  126. */
  127. REG_OP(StatelessSampleDistortedBoundingBox)
  128. .INPUT(image_size, TensorType({ DT_UINT8, DT_INT8, DT_INT16, \
  129. DT_INT32, DT_INT64 }))
  130. .INPUT(bounding_boxes, TensorType({ DT_FLOAT }))
  131. .INPUT(min_object_covered, TensorType({ DT_FLOAT }))
  132. .INPUT(seed, TensorType({ DT_INT32, DT_INT64 }))
  133. .OUTPUT(begin, TensorType({ DT_UINT8, DT_INT8, DT_INT16, \
  134. DT_INT32, DT_INT64 }))
  135. .OUTPUT(size, TensorType({ DT_UINT8, DT_INT8, DT_INT16, \
  136. DT_INT32, DT_INT64 }))
  137. .OUTPUT(bboxes, TensorType({ DT_FLOAT }))
  138. .ATTR(aspect_ratio_range, ListFloat, { 0.75f, 1.33f })
  139. .ATTR(area_range, ListFloat, { 0.05f, 1.0f })
  140. .ATTR(max_attempts, Int, 100)
  141. .ATTR(use_image_if_no_bounding_boxes, Bool, false)
  142. .OP_END_FACTORY_REG(StatelessSampleDistortedBoundingBox)
  143. /**
  144. * @brief Outputs random values from a truncated normal distribution. \n
  145. * @par Inputs:
  146. * Inputs include:
  147. * @li shape: A Tensor. Must be one of the following types: int32, int64. \n
  148. * @li key: Key of RNG algorithm. Shape[1]. \n
  149. * @li counter: Counter of RNG algorithm. Shape[2] for philox, shape[1] for threefry. \n
  150. * @li alg: RNG algorithm. 1:philox 2:threefry. \n
  151. * @par Attributes:
  152. * @li dtype: dtype: A optional attr, specifying the output data type. Defaults to "DT_FLOAT". \n
  153. * @par Outputs:
  154. * y: A Tensor of types: float16, float32, double. A tensor of the specified shape
  155. filled with random truncated normal values. \n
  156. * @attention Constraints:
  157. * The implementation for StatelessTruncatedNormalV2 on Ascend uses AICPU, with bad performance.
  158. * @par Third-party framework compatibility
  159. * @li compatible with tensorflow StatelessTruncatedNormalV2 operator.
  160. */
  161. REG_OP(StatelessTruncatedNormalV2)
  162. .INPUT(shape, TensorType({ DT_INT32, DT_INT64 }))
  163. .INPUT(key, TensorType({ DT_UINT64 }))
  164. .INPUT(counter, TensorType({ DT_UINT64 }))
  165. .INPUT(alg, TensorType({ DT_INT32 }))
  166. .OUTPUT(y, TensorType({ DT_FLOAT16, DT_FLOAT, DT_DOUBLE }))
  167. .ATTR(dtype, Type, DT_FLOAT)
  168. .OP_END_FACTORY_REG(StatelessTruncatedNormalV2)
  169. /**
  170. * @brief Outputs deterministic pseudorandom random numbers from a gamma distribution. \n
  171. * @par Inputs:
  172. * @li shape: The shape of the output tensor.
  173. * @li seed: 2 seeds (shape [2]).
  174. * @li alpha: The concentration of the gamma distribution. Shape must match the rightmost dimensions of shape. \n
  175. * @par Outputs:
  176. * y: A Tensor. Has the same type as alpha. \n
  177. * @par Third-party framework compatibility
  178. * Compatible with TensorFlow StatelessRandomGammaV2 operator.
  179. */
  180. REG_OP(StatelessRandomGammaV2)
  181. .INPUT(shape, TensorType({DT_INT32, DT_INT64}))
  182. .INPUT(seed, TensorType({DT_INT32, DT_INT64}))
  183. .INPUT(alpha, TensorType({DT_FLOAT, DT_FLOAT16, DT_DOUBLE}))
  184. .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_DOUBLE}))
  185. .OP_END_FACTORY_REG(StatelessRandomGammaV2)
  186. /**
  187. * @brief Outputs deterministic pseudorandom random integers from a uniform distribution . \n
  188. * @par Inputs:
  189. * @li shape: The shape of the output tensor.
  190. * @li seed: 2 seeds (shape [2]). \n
  191. * @par Attributes:
  192. * dtype:Output data type . \n
  193. * @par Outputs:
  194. * y: Returns Random values with specified shape . \n
  195. * @par Third-party framework compatibility
  196. * Compatible with TensorFlow StatelessRandomUniformFullInt operator.
  197. */
  198. REG_OP(StatelessRandomUniformFullInt)
  199. .INPUT(shape, TensorType({DT_INT32, DT_INT64}))
  200. .INPUT(seed, TensorType({DT_INT32, DT_INT64}))
  201. .OUTPUT(y, TensorType({DT_INT32, DT_INT64, DT_UINT32, DT_UINT64}))
  202. .ATTR(dtype, Type, DT_INT32)
  203. .OP_END_FACTORY_REG(StatelessRandomUniformFullInt)
  204. /**
  205. * @brief Outputs deterministic pseudorandom random integers from a uniform distribution . \n
  206. * @par Inputs:
  207. * @li shape: The shape of the output tensor.
  208. * @li key: Key for the counter-based RNG algorithm.
  209. * @li counter: Initial counter for the counter-based RNG algorithm.
  210. * @li alg: 0-D. The RNG algorithm. \n
  211. * @par Attributes:
  212. * dtype:Output data type . \n
  213. * @par Outputs:
  214. * y: Returns Random values with specified shape . \n
  215. * @par Third-party framework compatibility
  216. * Compatible with TensorFlow StatelessRandomUniformFullIntV2 operator.
  217. */
  218. REG_OP(StatelessRandomUniformFullIntV2)
  219. .INPUT(shape, TensorType({DT_INT32, DT_INT64}))
  220. .INPUT(key, TensorType({DT_UINT64}))
  221. .INPUT(counter, TensorType({DT_UINT64}))
  222. .INPUT(alg, TensorType({DT_INT32}))
  223. .OUTPUT(y, TensorType({DT_INT32, DT_INT64, DT_UINT32, DT_UINT64}))
  224. .ATTR(dtype, Type, DT_INT32)
  225. .OP_END_FACTORY_REG(StatelessRandomUniformFullIntV2)
  226. /**
  227. * @brief Outputs deterministic pseudorandom random integers from a uniform distribution . \n
  228. * @par Inputs:
  229. * @li shape: The shape of the output tensor.
  230. * @li key: Key for the counter-based RNG algorithm.
  231. * @li counter: Initial counter for the counter-based RNG algorithm.
  232. * @li alg: 0-D. The RNG algorithm.
  233. * @li minval: Minimum value (inclusive, scalar).
  234. * @li maxval: Maximum value (exclusive, scalar) . \n
  235. * @par Outputs:
  236. * y: Returns Random values with specified shape . \n
  237. * @par Third-party framework compatibility
  238. * Compatible with TensorFlow StatelessRandomUniformIntV2 operator.
  239. */
  240. REG_OP(StatelessRandomUniformIntV2)
  241. .INPUT(shape, TensorType({DT_INT32, DT_INT64}))
  242. .INPUT(key, TensorType({DT_UINT64}))
  243. .INPUT(counter, TensorType({DT_UINT64}))
  244. .INPUT(alg, TensorType({DT_INT32}))
  245. .INPUT(minval, TensorType({DT_INT32, DT_INT64, DT_UINT32, DT_UINT64}))
  246. .INPUT(maxval, TensorType({DT_INT32, DT_INT64, DT_UINT32, DT_UINT64}))
  247. .OUTPUT(y, TensorType({DT_INT32, DT_INT64, DT_UINT32, DT_UINT64}))
  248. .OP_END_FACTORY_REG(StatelessRandomUniformIntV2)
  249. /**
  250. * @brief Outputs deterministic pseudorandom random integers from a binomial distribution. \n
  251. * @par Inputs:
  252. * @li shape: The shape of the output tensor.
  253. * @li seed: 2 seeds (shape [2]).
  254. * @li counts: The counts of the binomial distribution. Must be broadcastable with probs,
  255. * and broadcastable with the rightmost dimensions of shape.
  256. * @li probs: The probability of success for the binomial distribution.
  257. * Must be broadcastable with counts and broadcastable with the rightmost dimensions of shape. \n
  258. * @par Attributes:
  259. * @li dtype: A optional int32, specifying the output data type. Defaults to "DT_INT32". \n
  260. * @par Outputs:
  261. * @li y: Returns Random values with specified shape. \n
  262. * @par Third-party framework compatibility
  263. * Compatible with TensorFlow StatelessRandomBinomial operator.
  264. */
  265. REG_OP(StatelessRandomBinomial)
  266. .INPUT(shape, TensorType({DT_INT32, DT_INT64}))
  267. .INPUT(seed, TensorType({DT_INT32, DT_INT64}))
  268. .INPUT(counts, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64}))
  269. .INPUT(probs, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64}))
  270. .OUTPUT(y, TensorType({DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
  271. .ATTR(dtype, Type, DT_INT32)
  272. .OP_END_FACTORY_REG(StatelessRandomBinomial)
  273. /**
  274. * @brief Outputs deterministic pseudorandom random integers from a poisson distribution . \n
  275. * @par Inputs:
  276. * @li shape: The shape of the output tensor.
  277. * @li seed: 2 seeds (shape [2]).
  278. * @li lam: mean value value of poisson distribution . \n
  279. * @par Attributes:
  280. * dtype:Output data type . \n
  281. * @par Outputs:
  282. * y: Returns Random values with specified shape . \n
  283. * @par Third-party framework compatibility
  284. * Compatible with TensorFlow StatelessRandomUniformInt operator.
  285. */
  286. REG_OP(StatelessRandomPoisson)
  287. .INPUT(shape, TensorType({DT_INT32, DT_INT64}))
  288. .INPUT(seed, TensorType({DT_INT32, DT_INT64}))
  289. .INPUT(lam, TensorType({DT_FLOAT, DT_FLOAT16, DT_DOUBLE, DT_INT32, DT_INT64}))
  290. .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_DOUBLE, DT_INT32, DT_INT64}))
  291. .REQUIRED_ATTR(dtype, Type)
  292. .OP_END_FACTORY_REG(StatelessRandomPoisson)
  293. /**
  294. * @brief Get the counter of the RNG algorithm. \n
  295. * @par Outputs:
  296. * @li alg: The RNG algorithm. \n
  297. * @par Third-party framework compatibility
  298. * Compatible with TensorFlow StatelessRandomGetAlg operator.
  299. */
  300. REG_OP(StatelessRandomGetAlg)
  301. .OUTPUT(alg, TensorType({DT_INT32}))
  302. .OP_END_FACTORY_REG(StatelessRandomGetAlg)
  303. /**
  304. * @brief This op picks the best counter-based RNG algorithm based on device, and
  305. * scrambles a shape-[2] seed into a key and a counter, both needed by the
  306. * counter-based algorithm. \n
  307. * @par Inputs:
  308. * @li seed: 2 seeds (shape [2]). \n
  309. * @par Outputs:
  310. * @li key: Key for the counter-based RNG algorithm.
  311. * @li counter: Initial counter for the counter-based RNG algorithm. \n
  312. * @par Third-party framework compatibility
  313. * Compatible with TensorFlow StatelessRandomGetKeyCounter operator.
  314. */
  315. REG_OP(StatelessRandomGetKeyCounter)
  316. .INPUT(seed, TensorType({DT_INT32, DT_INT64}))
  317. .OUTPUT(key, TensorType({DT_UINT64}))
  318. .OUTPUT(counter, TensorType({DT_UINT64}))
  319. .OP_END_FACTORY_REG(StatelessRandomGetKeyCounter)
  320. /**
  321. * @brief This op picks the best counter-based RNG algorithm based on device, and
  322. * scrambles a shape-[2] seed into a key and a counter, both needed by the
  323. * counter-based algorithm. \n
  324. * @par Inputs:
  325. * @li seed: 2 seeds (shape [2]). \n
  326. * @par Outputs:
  327. * @li key: Key for the counter-based RNG algorithm.
  328. * @li counter: Initial counter for the counter-based RNG algorithm.
  329. * @li alg: The RNG algorithm. \n
  330. * @par Third-party framework compatibility
  331. * Compatible with TensorFlow StatelessRandomGetKeyCounterAlg operator.
  332. */
  333. REG_OP(StatelessRandomGetKeyCounterAlg)
  334. .INPUT(seed, TensorType({DT_INT32, DT_INT64}))
  335. .OUTPUT(key, TensorType({DT_UINT64}))
  336. .OUTPUT(counter, TensorType({DT_UINT64}))
  337. .OUTPUT(alg, TensorType({DT_INT32}))
  338. .OP_END_FACTORY_REG(StatelessRandomGetKeyCounterAlg)
  339. /**
  340. * @brief Outputs deterministic pseudorandom values from a normal distribution. \n
  341. * @par Inputs:
  342. * @li shape: The shape of the output tensor.
  343. * @li key: Key for the counter-based RNG algorithm.
  344. * @li counter: Initial counter for the counter-based RNG algorithm.
  345. * @li alg: The RNG algorithm. \n
  346. * @par Attributes:
  347. * @li dtype: Output data type . \n
  348. * @par Outputs:
  349. * @li y: Returns Random values with specified shape . \n
  350. * @par Third-party framework compatibility
  351. * Compatible with TensorFlow StatelessRandomNormalV2 operator.
  352. */
  353. REG_OP(StatelessRandomNormalV2)
  354. .INPUT(shape, TensorType({DT_INT32, DT_INT64}))
  355. .INPUT(key, TensorType({DT_UINT64}))
  356. .INPUT(counter, TensorType({DT_UINT64}))
  357. .INPUT(alg, TensorType({DT_INT32}))
  358. .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
  359. .ATTR(dtype, Type, DT_FLOAT)
  360. .OP_END_FACTORY_REG(StatelessRandomNormalV2)
  361. /**
  362. * @brief Outputs deterministic pseudorandom random integers from a uniform distribution . \n
  363. * @par Inputs:
  364. * @li shape: The shape of the output tensor.
  365. * @li key: Key for the counter-based RNG algorithm.
  366. * @li counter: Initial counter for the counter-based RNG algorithm.
  367. * @li alg: 0-D. The RNG algorithm. \n
  368. * @par Attributes:
  369. * dtype:Output data type . \n
  370. * @par Outputs:
  371. * y: Returns Random values with specified shape . \n
  372. * @par Third-party framework compatibility
  373. * Compatible with TensorFlow StatelessRandomUniformV2 operator.
  374. */
  375. REG_OP(StatelessRandomUniformV2)
  376. .INPUT(shape, TensorType({DT_INT32, DT_INT64}))
  377. .INPUT(key, TensorType({DT_UINT64}))
  378. .INPUT(counter, TensorType({DT_UINT64}))
  379. .INPUT(alg, TensorType({DT_INT32}))
  380. .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_DOUBLE}))
  381. .ATTR(dtype, Type, DT_FLOAT)
  382. .OP_END_FACTORY_REG(StatelessRandomUniformV2)
  383. /**
  384. * @brief Create a random number seed generator . \n
  385. * @par Inputs:
  386. * include:
  387. * @li seed:1-D Tensor,the seed to generate random.
  388. * Must be one of the types:int32 or int64.
  389. * @li seed2:1-D Tensor,the seed to generate random.
  390. * Must be one of the types:int32 or int64.
  391. * @li reshuffle:1-D Tensor.Seed selection, True:random seed, False:fixed seed.
  392. * Must be one of the types:bool. \n
  393. * @par Outputs:
  394. * handle:Handle to the random number generator.
  395. * deleter:Handle to the remover.
  396. * Used when deleting the random number seed generator \n
  397. * @see AnonymousSeedGenerator()
  398. * @par Third-party framework compatibility
  399. * compatible with AnonymousSeedGenerator op of tensorflow
  400. */
  401. REG_OP(AnonymousSeedGenerator)
  402. .INPUT(seed, TensorType({DT_INT32,DT_INT64}))
  403. .INPUT(seed2, TensorType({DT_INT32,DT_INT64}))
  404. .INPUT(reshuffle, TensorType({DT_BOOL}))
  405. .OUTPUT(handle, TensorType({DT_RESOURSE}))
  406. .OUTPUT(deleter, TensorType({DT_VARIANT}))
  407. .OP_END_FACTORY_REG(AnonymousSeedGenerator)
  408. /**
  409. * @brief DeleteSeedGenerator . \n
  410. * @par Inputs:
  411. * @li handle: A Tensor of type resource.
  412. * @li deleter: A Tensor of type variant.
  413. * @par Third-party framework compatibility
  414. * Compatible with TensorFlow DeleteSeedGenerator operator.
  415. */
  416. REG_OP(DeleteSeedGenerator)
  417. .INPUT(handle, TensorType({DT_RESOURCE}))
  418. .INPUT(deleter, TensorType({DT_VARIANT}))
  419. .OP_END_FACTORY_REG(DeleteSeedGenerator)
  420. /**
  421. * @brief Create a placeholder handle to rewrite and pass
  422. * to use during the graph compilation phase. \n
  423. * @par Outputs:
  424. * handle:Output random number . \n
  425. */
  426. REG_OP(DummySeedGenerator)
  427. .OUTPUT(handle, TensorType({ DT_RESOURCE }))
  428. .OP_END_FACTORY_REG(DummySeedGenerator)
  429. } // namespace ge
  430. #endif // OPS_BUILT_IN_OP_PROTO_INC_STATELESS_RANDOM_OPS_H_

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，进行一系列深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，用户无需感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：