
nn_training_ops.h

/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef GE_OP_TRAINING_OPS_H
#define GE_OP_TRAINING_OPS_H
#include "../graph/operator_reg.h"
namespace ge {
/**
 *@brief Updates "var" according to the AdaMax algorithm.\n
 * t-1 means the previous period.
 * m_t <- beta1 * m_{t-1} + (1 - beta1) * grad\n
 * v_t <- max(beta2 * v_{t-1}, abs(grad))\n
 * var <- var - lr / (1 - beta1^t) * m_t / (v_t + epsilon)
 *
 *@attention Constraints:\n
 * the input tensors must have the same shape.
 *
 *@par Inputs:
 *@li var: A mutable tensor. Must be one of the following types: TensorType::NumberType().
 * Should be from a Variable().
 *@li m: A mutable tensor. Has the same type as "var".
 * Should be from a Variable().
 *@li v: A mutable tensor. Has the same type as "var".
 * Should be from a Variable().
 *@li beta1_power: A scalar. Has the same type as "var".
 *@li lr: learning_rate. A scalar. Has the same type as "var".
 *@li beta1: A scalar. Has the same type as "var".
 *@li beta2: A scalar. Has the same type as "var".
 *@li epsilon: A scalar. Has the same type as "var".
 *@li grad: A tensor for the gradient. Has the same type as "var".
 *
 *@par Attributes:\n
 * use_locking: An optional bool. Defaults to "False".
 * If "True", updating of the "var", "m", and "v" tensors is protected
 * by a lock; otherwise the behavior is undefined, but may exhibit less
 * contention.
 *
 *@par Outputs:
 * var: A mutable tensor. Has the same type as input "var".
 *
 */
REG_OP(ApplyAdaMax)
    .INPUT(var, TensorType::NumberType())
    .INPUT(m, TensorType::NumberType())
    .INPUT(v, TensorType::NumberType())
    .INPUT(beta1_power, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(beta1, TensorType::NumberType())
    .INPUT(beta2, TensorType::NumberType())
    .INPUT(epsilon, TensorType::NumberType())
    .INPUT(grad, TensorType::NumberType())
    .OUTPUT(var, TensorType::NumberType())
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ApplyAdaMax)
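/*
 * Editorial note (illustrative only, not part of the GE API): a scalar sketch
 * of the AdaMax step documented above, assuming float operands; the registered
 * op applies the same rule element-wise, with "lr", "beta1", "beta2",
 * "beta1_power" and "epsilon" broadcast as scalars.
 *
 *   m   = beta1 * m + (1 - beta1) * grad;
 *   v   = max(beta2 * v, abs(grad));                 // infinity-norm second moment
 *   var = var - (lr / (1 - beta1_power)) * m / (v + epsilon);
 */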
/**
 *@brief Updates "var" according to the momentum scheme. Set use_nesterov = True if you
 * want to use Nesterov momentum.\n
 * computing process: \n
 * accum = accum * momentum + grad\n
 * var -= lr * accum
 *
 *@attention Constraints:\n
 * the input tensors must have the same shape.
 *
 *@par Inputs:
 *@li var: A mutable tensor. Should be from a Variable().
 *@li accum: A mutable tensor. Has the same type as "var".
 * Should be from a Variable().
 *@li lr: A scalar. Has the same type as "var".
 *@li grad: A tensor for the gradient. Has the same type as "var".
 *@li momentum: A scalar. Has the same type as "var".
 *
 *@par Attributes:
 *@li use_nesterov: An optional bool. Defaults to "False".
 * If "True", the tensor passed to compute grad will be
 * var - lr * momentum * accum, so in the end, the var you get is actually
 * var - lr * momentum * accum.
 *
 *@li use_locking: An optional bool. Defaults to "False".\n
 * If "True", updating of the "var" and "accum" tensors is protected by a lock;
 * otherwise the behavior is undefined, but may exhibit less contention.
 *
 *@par Outputs:
 * var: A mutable tensor. Has the same type as input "var".
 *
 */
REG_OP(ApplyMomentum)
    .INPUT(var, TensorType::NumberType())
    .INPUT(accum, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(grad, TensorType::NumberType())
    .INPUT(momentum, TensorType::NumberType())
    .OUTPUT(var, TensorType::NumberType())
    .ATTR(use_nesterov, Bool, false)
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ApplyMomentum)
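/*
 * Editorial note (illustrative only, not part of the GE API): a scalar sketch
 * of the momentum step above, including the Nesterov branch selected by the
 * "use_nesterov" attribute. The Nesterov form assumes the semantics of the
 * TensorFlow op of the same name, which the comment above mirrors.
 *
 *   accum = accum * momentum + grad;
 *   if (!use_nesterov) {
 *     var -= lr * accum;                             // classic momentum
 *   } else {
 *     var -= lr * (grad + momentum * accum);         // Nesterov-style lookahead
 *   }
 */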
// CCE-specific variant with the same signature as ApplyMomentum.
REG_OP(ApplyMomentumCCE)
    .INPUT(var, TensorType::NumberType())
    .INPUT(accum, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(grad, TensorType::NumberType())
    .INPUT(momentum, TensorType::NumberType())
    .OUTPUT(var, TensorType::NumberType())
    .ATTR(use_nesterov, Bool, false)
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ApplyMomentumCCE)
/**
 *@brief Updates "var" according to the PowerSign update.\n
 * t-1 means the previous period.
 * m_t <- beta1 * m_{t-1} + (1 - beta1) * grad\n
 * update <- exp(logbase * sign_decay * sign(grad) * sign(m_t)) * grad\n
 * var <- var - lr * update
 *
 *@attention Constraints:\n
 * the input tensors must have the same shape.
 *
 *@par Inputs:
 *@li var: A mutable tensor. Should be from a Variable().
 *@li m: A mutable tensor. Has the same type as "var".
 * Should be from a Variable().
 *@li lr: A scalar. Has the same type as "var".
 *@li logbase: A scalar. Has the same type as "var".
 *@li sign_decay: A scalar. Has the same type as "var".
 *@li beta: A scalar. Has the same type as "var".
 *@li grad: A tensor for the gradient. Has the same type as "var".
 *
 *@par Attributes:
 * use_locking: An optional bool. Defaults to "False".
 * If "True", updating of the "var" and "m" tensors is protected
 * by a lock; otherwise the behavior is undefined, but may exhibit less
 * contention.
 *
 *@par Outputs:
 * var: A mutable tensor. Has the same type as input "var".
 *
 */
REG_OP(ApplyPowerSign)
    .INPUT(var, TensorType::NumberType())
    .INPUT(m, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(logbase, TensorType::NumberType())
    .INPUT(sign_decay, TensorType::NumberType())
    .INPUT(beta, TensorType::NumberType())
    .INPUT(grad, TensorType::NumberType())
    .OUTPUT(var, TensorType::NumberType())
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ApplyPowerSign)
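/*
 * Editorial note (illustrative only, not part of the GE API): scalar form of
 * the PowerSign step above, with names following the op inputs. sign(x) is
 * -1, 0 or +1; when grad and the running average m agree in sign, the factor
 * exp(logbase * sign_decay) > 1 enlarges the step, otherwise it shrinks it.
 *
 *   m      = beta * m + (1 - beta) * grad;
 *   update = exp(logbase * sign_decay * sign(grad) * sign(m)) * grad;
 *   var   -= lr * update;
 */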
/**
 *@brief Updates "var" according to the FOBOS algorithm with a fixed learning rate.\n
 * prox_v = var - alpha * delta\n
 * var = sign(prox_v)/(1+alpha*l2) * max{|prox_v|-alpha*l1,0}
 *
 *@attention Constraints:\n
 * the input tensors must have the same shape.
 *
 *@par Inputs:
 *@li var: A mutable tensor. Should be from a Variable().
 *@li alpha: A scalar. Has the same type as "var".
 *@li l1: A scalar. Has the same type as "var".
 *@li l2: A scalar. Has the same type as "var".
 *@li delta: A tensor. Has the same type as "var". The change.
 *
 *@par Attributes:
 * use_locking: An optional bool. Defaults to "False".
 * If "True", updating of the "var" tensor is protected
 * by a lock; otherwise the behavior is undefined, but may exhibit less
 * contention.
 *
 *@par Outputs:
 * var: A mutable tensor. Has the same type as input "var".
 *
 */
REG_OP(ApplyProximalGradientDescent)
    .INPUT(var, TensorType::NumberType())
    .INPUT(alpha, TensorType::NumberType())
    .INPUT(l1, TensorType::NumberType())
    .INPUT(l2, TensorType::NumberType())
    .INPUT(delta, TensorType::NumberType())
    .OUTPUT(var, TensorType::NumberType())
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ApplyProximalGradientDescent)
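/*
 * Editorial note (illustrative only, not part of the GE API): scalar form of
 * the FOBOS step above. The l1 term soft-thresholds the intermediate value
 * toward zero and the l2 term shrinks it.
 *
 *   prox_v = var - alpha * delta;
 *   var    = sign(prox_v) * max(abs(prox_v) - alpha * l1, 0.0)
 *            / (1.0 + alpha * l2);
 */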
/**
 *@brief Updates "var" according to the AddSign update.
 *@par Inputs:
 *Seven inputs, including:
 * @li var: A mutable Tensor of type TensorType::NumberType().
 * Should be a Variable Tensor.
 * @li m: A mutable Tensor of the same type as "var".
 * Should be a Variable Tensor.
 * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
 * @li alpha: A Tensor of the same type as "var". Must be a scalar.
 * @li sign_decay: A Tensor of the same type as "var". Must be a scalar.
 * @li beta: A Tensor of the same type as "var". Must be a scalar.
 * @li grad: A Tensor of the same type as "var", for the gradient.
 *@par Attributes:
 *use_locking: An optional bool. Defaults to "False".
 * If "True", updating of the "var" and "m" tensors will be
 * protected by a lock; otherwise the behavior is undefined,
 * but may exhibit less contention.
 *@par Outputs:
 *var: A mutable Tensor. Has the same type as "var".
 */
REG_OP(ApplyAddSign)
    .INPUT(var, TensorType::NumberType())
    .INPUT(m, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(alpha, TensorType::NumberType())
    .INPUT(sign_decay, TensorType::NumberType())
    .INPUT(beta, TensorType::NumberType())
    .INPUT(grad, TensorType::NumberType())
    .OUTPUT(var, TensorType::NumberType())
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ApplyAddSign)
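/*
 * Editorial note (illustrative only): the comment above does not spell out
 * the AddSign rule, so this scalar sketch assumes the update used by the
 * TensorFlow op of the same name, with names following the op inputs.
 *
 *   m      = beta * m + (1 - beta) * grad;
 *   update = (alpha + sign_decay * sign(grad) * sign(m)) * grad;
 *   var   -= lr * update;
 */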
/**
 *@brief Updates "var" according to the centered RMSProp algorithm.\n
 * The centered RMSProp algorithm uses an estimate of the centered second moment
 * (i.e., the variance) for normalization, as opposed to regular RMSProp, which
 * uses the (uncentered) second moment. This often helps with training, but is
 * slightly more expensive in terms of computation and memory.
 *
 * t-1 means the previous period.
 * mg <- rho * mg_{t-1} + (1-rho) * grad\n
 * ms <- rho * ms_{t-1} + (1-rho) * grad * grad\n
 * mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms - mg * mg + epsilon)\n
 * var <- var - mom\n
 *
 *@attention Constraints:\n
 *@li In the dense implementation of this algorithm, mg, ms, and mom will
 * update even if the grad is zero, but in this sparse implementation, mg, ms,
 * and mom will not update in iterations during which the grad is zero.
 *@li the input tensors must have the same shape.
 *
 *@par Inputs:
 *@li var: A mutable tensor. Should be from a Variable().
 *@li mg: A mutable tensor. Has the same type as "var".
 * Should be from a Variable().
 *@li ms: A mutable tensor. Has the same type as "var".
 * Should be from a Variable().
 *@li mom: A mutable tensor. Has the same type as "var".
 * Should be from a Variable().
 *@li lr: A scalar. Has the same type as "var".
 *@li rho: A scalar. Has the same type as "var".
 *@li momentum: A tensor. Has the same type as "var".
 *@li epsilon: A scalar. Has the same type as "var".
 *@li grad: A tensor for the gradient. Has the same type as "var".
 *
 *@par Attributes:
 * use_locking: An optional bool. Defaults to "False".
 * If "True", updating of the "var", "ms", and "mom" tensors is protected
 * by a lock; otherwise the behavior is undefined, but may exhibit less
 * contention.
 *
 *@par Outputs:
 * var: A mutable tensor. Has the same type as input "var".
 *
 */
REG_OP(ApplyCenteredRMSProp)
    .INPUT(var, TensorType::NumberType())
    .INPUT(mg, TensorType::NumberType())
    .INPUT(ms, TensorType::NumberType())
    .INPUT(mom, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(rho, TensorType::NumberType())
    .INPUT(momentum, TensorType::NumberType())
    .INPUT(epsilon, TensorType::NumberType())
    .INPUT(grad, TensorType::NumberType())
    .OUTPUT(var, TensorType::NumberType())
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ApplyCenteredRMSProp)
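/*
 * Editorial note (illustrative only, not part of the GE API): scalar form of
 * the centered RMSProp step above; ms - mg * mg estimates the gradient
 * variance rather than the raw second moment used by plain RMSProp.
 *
 *   mg  = rho * mg + (1 - rho) * grad;
 *   ms  = rho * ms + (1 - rho) * grad * grad;
 *   mom = momentum * mom + lr * grad / sqrt(ms - mg * mg + epsilon);
 *   var = var - mom;
 */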
/**
 *@brief Updates "var" by subtracting 'alpha' * 'delta' from it.\n
 * var -= delta * alpha
 *
 *@attention Constraints:\n
 * the input tensors must have the same shape.
 *
 *@par Inputs:
 *@li var: A mutable tensor. Should be from a Variable().
 *@li alpha: A scalar. Has the same type as "var".
 *@li delta: A tensor for the change. Has the same type as "var".
 *
 *@par Attributes:
 * use_locking: An optional bool. Defaults to "False".
 * If "True", updating of the "var" tensor is protected
 * by a lock; otherwise the behavior is undefined, but may exhibit less
 * contention.
 *
 *@par Outputs:
 * var: A mutable tensor. Has the same type as input "var".
 *
 */
REG_OP(ApplyGradientDescent)
    .INPUT(var, TensorType::NumberType())
    .INPUT(alpha, TensorType::NumberType())
    .INPUT(delta, TensorType::NumberType())
    .OUTPUT(var, TensorType::NumberType())
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ApplyGradientDescent)
/**
 *@brief Updates "var" according to the adagrad scheme.\n
 * accum += grad * grad\n
 * var -= lr * grad * (1 / sqrt(accum))
 *
 *@attention Constraints:\n
 * the input tensors must have the same shape.
 *
 *@par Inputs:
 *@li var: A mutable tensor. Should be from a Variable().
 *@li accum: A mutable tensor. Has the same type as "var".
 * Should be from a Variable().
 *@li lr: A scalar. Has the same type as "var".
 *@li grad: A tensor for the gradient. Has the same type as "var".
 *
 *@par Attributes:
 *@li update_slots: An optional bool. Defaults to "True".
 * If "False", the "accum" tensor is not updated.
 *@li use_locking: An optional bool. Defaults to "False".
 * If "True", updating of the "var" and "accum" tensors is protected
 * by a lock; otherwise the behavior is undefined, but may exhibit less
 * contention.
 *
 *@par Outputs:
 * var: A mutable tensor. Has the same type as input "var".
 *
 */
REG_OP(ApplyAdagrad)
    .INPUT(var, TensorType::NumberType())
    .INPUT(accum, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(grad, TensorType::NumberType())
    .OUTPUT(var, TensorType::NumberType())
    .ATTR(update_slots, Bool, true)
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ApplyAdagrad)
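/*
 * Editorial note (illustrative only, not part of the GE API): scalar form of
 * the Adagrad step above; accum accumulates squared gradients, so the
 * effective step size decays for frequently updated coordinates.
 *
 *   accum += grad * grad;
 *   var   -= lr * grad / sqrt(accum);
 */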
/**
 *@brief Updates "var" according to the proximal adagrad scheme.
 *@par Inputs:
 *Eight inputs, including:
 * @li var: A mutable Tensor. Must be one of the following types:
 * TensorType::NumberType(). Should be a Variable Tensor.
 * @li gradient_accumulator: A mutable Tensor. Must have the same
 * type as "var". Should be a Variable Tensor.
 * @li gradient_squared_accumulator: A mutable Tensor of the same type as "var".
 * Should be a Variable Tensor.
 * @li grad: A Tensor of the same type as "var", for the gradient.
 * @li lr: A Tensor of the same type as "var".
 * Scaling factor. Must be a scalar.
 * @li l1: A Tensor of the same type as "var".
 * L1 regularization. Must be a scalar.
 * @li l2: A Tensor of the same type as "var".
 * L2 regularization. Must be a scalar.
 * @li global_step: A Tensor of type int32 or int64.
 * Training step number. Must be a scalar.
 *@par Attributes:
 *use_locking: An optional bool. Defaults to "False".
 * If "True", updating of the var and accum tensors will be
 * protected by a lock; otherwise the behavior is undefined,
 * but may exhibit less contention.
 *@par Outputs:
 *var: A mutable Tensor. Has the same type as "var".
 */
REG_OP(ApplyAdagradDA)
    .INPUT(var, TensorType::NumberType())
    .INPUT(gradient_accumulator, TensorType::NumberType())
    .INPUT(gradient_squared_accumulator, TensorType::NumberType())
    .INPUT(grad, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(l1, TensorType::NumberType())
    .INPUT(l2, TensorType::NumberType())
    .INPUT(global_step, TensorType({DT_INT32, DT_INT64}))
    .OUTPUT(var, TensorType::NumberType())
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ApplyAdagradDA)
/**
 *@brief Returns the dimension index in the destination data format given the one in
 * the source data format.
 *
 *@par Inputs:
 * x: A tensor of type int32 or int64.
 * A Tensor with each element as a dimension index in the source data format.
 * Must be in the range [-4, 4).
 *
 *@par Attributes:
 *@li src_format: An optional string. Defaults to NHWC.
 * Source data format.
 *@li dst_format: An optional string. Defaults to NCHW.
 * Destination data format.
 *
 *@par Outputs:
 * y: A tensor. Has the same type as "x".
 *
 */
REG_OP(DataFormatDimMap)
    .INPUT(x, TensorType::IndexNumberType())
    .ATTR(src_format, String, "NHWC")
    .ATTR(dst_format, String, "NCHW")
    .OUTPUT(y, TensorType::IndexNumberType())
    .OP_END_FACTORY_REG(DataFormatDimMap)
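/*
 * Editorial note (illustrative only): a worked mapping with the default
 * attributes src_format = "NHWC" and dst_format = "NCHW". NHWC orders the
 * dimensions as N=0, H=1, W=2, C=3 and NCHW as N=0, C=1, H=2, W=3, so:
 *
 *   x = [0, 1, 2, 3]  ->  y = [0, 2, 3, 1]
 *
 * Negative indices count from the end, e.g. x = -1 (C in NHWC) maps to 1.
 */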
/**
 * @brief Implements stochastic gradient descent (optionally with momentum).\n
 * Nesterov momentum is based on the formula from
 * On the importance of initialization and momentum in deep learning.\n
 * @par Inputs:
 * @li parameters: A mutable tensor of type float16 or float32.\n
 * Specifies the parameters to optimize.
 * @li gradient: A tensor of type float16 or float32.\n
 * Specifies the gradient of the training step.
 * @li learning_rate: A tensor of type float16 or float32.\n
 * Specifies the learning rate of the training step.
 * @li accum: A tensor of type float16 or float32.
 * Specifies the velocity of the training step.
 * @li momentum: A tensor of type float16 or float32.
 * Specifies the momentum factor.
 * @li stat: A tensor of type float16 or float32.
 * Specifies the status indicating whether this is the first step.
 * @par Attributes:
 * @li dampening: An optional float, specifying the dampening for momentum.
 * Defaults to "0.0".
 * @li weight_decay: An optional float, specifying the L2 penalty. Defaults to
 * "0.0".
 * @li nesterov: An optional bool, specifying whether to enable Nesterov
 * momentum. Defaults to "False".
 * @par Outputs:
 * parameters: A mutable tensor same as input "parameters".
 * @see ApplyMomentum()
 */
REG_OP(SGD)
    .INPUT(parameters, TensorType(DT_FLOAT, DT_FLOAT16))
    .INPUT(gradient, TensorType(DT_FLOAT, DT_FLOAT16))
    .INPUT(learning_rate, TensorType(DT_FLOAT, DT_FLOAT16))
    .INPUT(accum, TensorType(DT_FLOAT, DT_FLOAT16))
    .INPUT(momentum, TensorType(DT_FLOAT, DT_FLOAT16))
    .INPUT(stat, TensorType(DT_FLOAT, DT_FLOAT16))
    .OUTPUT(parameters, TensorType(DT_FLOAT, DT_FLOAT16))
    .ATTR(dampening, Float, 0.0)
    .ATTR(weight_decay, Float, 0.0)
    .ATTR(nesterov, Bool, false)
    .OP_END_FACTORY_REG(SGD)
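/*
 * Editorial note (illustrative only): the comment above lists the attributes
 * but not the combined rule, so this scalar sketch assumes the usual
 * torch.optim.SGD formulation that the inputs and attributes mirror; the
 * actual kernel may differ, in particular in how "stat" gates the first step.
 *
 *   g     = gradient + weight_decay * parameters;
 *   accum = momentum * accum + (1 - dampening) * g;     // accum = g on the first step
 *   step  = nesterov ? g + momentum * accum : accum;
 *   parameters -= learning_rate * step;
 */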
/**
 * @brief Updates "var" according to the RMSProp algorithm.\n
 * mean_square = decay * mean_square + (1-decay) * gradient ** 2\n
 * Delta = learning_rate * gradient / sqrt(mean_square + epsilon)\n
 * ms <- rho * ms_{t-1} + (1-rho) * grad * grad\n
 * mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)\n
 * var <- var - mom\n
 *
 * @attention Constraints:
 * @li Note that in dense implementation of this algorithm, "ms" and "mom" will\n
 * update even if "grad" is 0, but in this sparse implementation, "ms" and "mom"\n
 * will not update in iterations during which "grad" is 0.\n
 * @li The input tensors "var", "ms", "mom" and "grad" must have the same shape.
 *
 * @par Inputs:
 * @li var: A mutable tensor. Must be one of the data types defined in\n
 * TensorType::NumberType(). Should be from a Variable().
 * @li ms: A mutable tensor. Must have the same type as "var". Should be from a
 * Variable().
 * @li mom: A mutable tensor. Must have the same type as "var". Should be from a
 * Variable().
 * @li lr: A scalar. Must have the same type as "var".
 * @li rho: A scalar. Must have the same type as "var".
 * @li momentum: A scalar. Must have the same type as "var".
 * @li epsilon: A scalar. Must have the same type as "var".
 * @li grad: A tensor, specifying the gradient. Must have the same type as "var".
 *
 * @par Attributes:
 * use_locking: An optional "bool". Defaults to "False". If "True", updating of\n
 * the "var", "ms", and "mom" tensors will be protected by a lock; otherwise the\n
 * behavior is undefined, but may exhibit less contention.
 *
 * @par Outputs:
 * var: A mutable tensor. Has the same type as input "var".
 */
REG_OP(ApplyRMSProp)
    .INPUT(var, TensorType::NumberType())
    .INPUT(ms, TensorType::NumberType())
    .INPUT(mom, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(rho, TensorType::NumberType())
    .INPUT(momentum, TensorType::NumberType())
    .INPUT(epsilon, TensorType::NumberType())
    .INPUT(grad, TensorType::NumberType())
    .OUTPUT(var, TensorType::NumberType())
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ApplyRMSProp)
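/*
 * Editorial note (illustrative only, not part of the GE API): scalar form of
 * the RMSProp step above, with "rho" as the decay rate and "momentum" applied
 * to the accumulated step.
 *
 *   ms  = rho * ms + (1 - rho) * grad * grad;
 *   mom = momentum * mom + lr * grad / sqrt(ms + epsilon);
 *   var = var - mom;
 */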
/**
 * @brief Updates "var" according to the RMSProp algorithm, a const input will be
 * considered as an attribute.\n
 * mean_square = decay * mean_square + (1-decay) * gradient ** 2\n
 * Delta = learning_rate * gradient / sqrt(mean_square + epsilon)\n
 * ms <- rho * ms_{t-1} + (1-rho) * grad * grad\n
 * mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)\n
 * var <- var - mom
 *
 * @attention Constraints:
 * @li Note that in dense implementation of this algorithm, "ms" and "mom" will\n
 * update even if "grad" is 0, but in this sparse implementation, "ms" and "mom"\n
 * will not update in iterations during which "grad" is 0.
 * @li The input tensors "var", "ms", "mom" and "grad" must have the same shape.
 *
 * @par Inputs:
 * @li var: A mutable tensor. Must be one of the data types defined in\n
 * TensorType::NumberType(). Should be from a Variable().
 * @li ms: A mutable tensor. Must have the same type as "var". Should be from a
 * Variable().
 * @li mom: A mutable tensor. Must have the same type as "var". Should be from a
 * Variable().
 * @li lr: A scalar. Must have the same type as "var".
 * @li grad: A tensor, specifying the gradient. Must have the same type as "var".
 *
 * @par Attributes:
 * @li use_locking: An optional "bool". Defaults to "False". If "True", updating\n
 * of the "var", "ms", and "mom" tensors will be protected by a lock; otherwise
 * the behavior is undefined, but may exhibit less contention.
 * @li rho: A required float. The decay rate.
 * @li momentum: A required float. The momentum factor.
 * @li epsilon: A required float. A small constant for numerical stability.
 *
 * @par Outputs:
 * var: A mutable tensor. Must have the same type as input "var".
 */
REG_OP(ApplyRMSPropD)
    .INPUT(var, TensorType::NumberType())
    .INPUT(ms, TensorType::NumberType())
    .INPUT(mom, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(grad, TensorType::NumberType())
    .OUTPUT(var, TensorType::NumberType())
    .REQUIRED_ATTR(rho, Float)
    .REQUIRED_ATTR(momentum, Float)
    .REQUIRED_ATTR(epsilon, Float)
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ApplyRMSPropD)
/**
 *@brief Updates "var" and "accum" according to FOBOS with the Adagrad learning rate.
 *@par Inputs:
 *Six inputs, including:
 * @li var: A mutable Tensor of type TensorType::NumberType().
 * Should be from a Variable().
 * @li accum: A mutable Tensor of the same type as "var". Should be from a Variable().
 * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
 * @li l1: A Tensor of the same type as "var", for L1 regularization. Must be a scalar.
 * @li l2: A Tensor of the same type as "var", for L2 regularization. Must be a scalar.
 * @li grad: A Tensor of the same type as "var", for the gradient.
 *@par Attributes:
 *use_locking: An optional bool. Defaults to "False". If "True", updating of the "var" and "accum"
 * tensors will be protected by a lock; otherwise the behavior is undefined, but may exhibit less
 * contention.
 *@par Outputs:
 *var: A mutable Tensor. Has the same type as "var".
 */
REG_OP(ApplyProximalAdagrad)
    .INPUT(var, TensorType::NumberType())
    .INPUT(accum, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(l1, TensorType::NumberType())
    .INPUT(l2, TensorType::NumberType())
    .INPUT(grad, TensorType::NumberType())
    .OUTPUT(var, TensorType::NumberType())
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ApplyProximalAdagrad)
/**
 *@brief Updates entries in "var" and "accum" according to the proximal Adagrad algorithm.\n
 * Compared with the ApplyProximalAdagrad op, an additional index tensor is input,
 * and only the indices into the first dimension of "var" and "accum" are updated.
 *@par Inputs:
 * Seven inputs, including:\n
 * @li var: A mutable Tensor.\n
 * TensorType::NumberType(). Should be a Variable Tensor.
 * @li accum: A mutable Tensor of the same type as "var".\n
 * Should be a Variable Tensor.
 * @li lr: A Tensor of the same type as "var".\n
 * Scaling factor. Must be a scalar.
 * @li l1: A Tensor of the same type as "var".\n
 * L1 regularization. Must be a scalar.
 * @li l2: A Tensor of the same type as "var".\n
 * L2 regularization. Must be a scalar.
 * @li grad: A Tensor. Has the same type as "var".\n
 * The gradient.
 * @li indices: A vector of indices into the first dimension of "var" and "accum".\n
 * TensorType::IndexNumberType().
 *@par Attributes:
 *use_locking: An optional bool. Defaults to "False".\n
 * If "True", updating of the var and accum tensors will be protected by a lock;\n
 * If "False", the behavior is undefined, but may exhibit less contention.
 *@par Outputs:
 *var: A mutable Tensor. Has the same type as "var".
 */
REG_OP(SparseApplyProximalAdagrad)
    .INPUT(var, TensorType::NumberType())
    .INPUT(accum, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(l1, TensorType::NumberType())
    .INPUT(l2, TensorType::NumberType())
    .INPUT(grad, TensorType::NumberType())
    .INPUT(indices, TensorType::IndexNumberType())
    .OUTPUT(var, TensorType::NumberType())
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(SparseApplyProximalAdagrad)
/**
 *@brief Updates "var" according to the Ftrl-proximal scheme.
 *@par Inputs:
 *Eight inputs, including:
 * @li var: A mutable Tensor. Must be of type TensorType::NumberType().
 * Should be a Variable Tensor.
 * @li accum: A mutable Tensor of the same type as "var".
 * Should be a Variable Tensor.
 * @li linear: A mutable Tensor of the same type as "var".
 * Should be a Variable Tensor.
 * @li grad: A Tensor of the same type as "var", for the gradient.
 * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
 * @li l1: A Tensor of the same type as "var", for L1 regularization. Must be a scalar.
 * @li l2: A Tensor of the same type as "var", for L2 regularization. Must be a scalar.
 * @li lr_power: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
 *@par Attributes:
 *use_locking: An optional bool. Defaults to "False".
 * If "True", updating of the "var" and "accum" tensors will be
 * protected by a lock; otherwise the behavior is undefined,
 * but may exhibit less contention.
 *@par Outputs:
 *var: A mutable Tensor. Has the same type as "var".
 */
REG_OP(ApplyFtrl)
    .INPUT(var, TensorType::NumberType())
    .INPUT(accum, TensorType::NumberType())
    .INPUT(linear, TensorType::NumberType())
    .INPUT(grad, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(l1, TensorType::NumberType())
    .INPUT(l2, TensorType::NumberType())
    .INPUT(lr_power, TensorType::NumberType())
    .OUTPUT(var, TensorType::NumberType())
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ApplyFtrl)
/**
 *@brief Updates "var" according to the Ftrl-proximal scheme.
 *@par Inputs:
 *Nine inputs, including:
 * @li var: A mutable Tensor. Must be of type TensorType::NumberType().
 * Should be a Variable Tensor.
 * @li accum: A mutable Tensor of the same type as "var".
 * Should be a Variable Tensor.
 * @li linear: A mutable Tensor of the same type as "var".
 * Should be a Variable Tensor.
 * @li grad: A Tensor of the same type as "var", for the gradient.
 * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
 * @li l1: A Tensor of the same type as "var", for L1 regularization. Must be a scalar.
 * @li l2: A Tensor of the same type as "var", for L2 regularization. Must be a scalar.
 * @li l2_shrinkage: A Tensor of the same type as "var", for L2 shrinkage regularization. Must be a scalar.
 * @li lr_power: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
 *@par Attributes:
 *use_locking: An optional bool. Defaults to "False".
 * If "True", updating of the "var" and "accum" tensors will be
 * protected by a lock; otherwise the behavior is undefined,
 * but may exhibit less contention.
 *@par Outputs:
 *var: A mutable Tensor. Has the same type as "var".
 */
REG_OP(ApplyFtrlV2)
    .INPUT(var, TensorType::NumberType())
    .INPUT(accum, TensorType::NumberType())
    .INPUT(linear, TensorType::NumberType())
    .INPUT(grad, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(l1, TensorType::NumberType())
    .INPUT(l2, TensorType::NumberType())
    .INPUT(l2_shrinkage, TensorType::NumberType())
    .INPUT(lr_power, TensorType::NumberType())
    .OUTPUT(var, TensorType::NumberType())
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ApplyFtrlV2)
/**
 *@brief Updates "var" according to the Adam algorithm.\n
 * lr_t <- learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t)\n
 * m_t <- beta1 * m_{t-1} + (1 - beta1) * g\n
 * v_t <- beta2 * v_{t-1} + (1 - beta2) * g * g\n
 * var <- var - lr_t * m_t / (sqrt(v_t) + epsilon)
 *
 *@attention Constraints:\n
 * *The input tensors must have the same shape.*
 *
 *@par Inputs:
 *@li var: A mutable Tensor of the type TensorType::NumberType().
 * Should be from a Variable().
 *@li m: A mutable Tensor of the same type as "var".
 * Should be from a Variable().
 *@li v: A mutable Tensor of the same type as "var".
 * Should be from a Variable().
 *@li beta1_power: A scalar of the same type as "var".
 *@li beta2_power: A scalar of the same type as "var".
 *@li lr: learning_rate. A scalar of the same type as "var".
 *@li beta1: A scalar of the same type as "var".
 *@li beta2: A scalar of the same type as "var".
 *@li epsilon: A scalar of the same type as "var".
 *@li grad: A Tensor of the same type as "var", for the gradient.
 *
 *@par Attributes:\n
 *@li use_locking: An optional bool. Defaults to "False".
 * If "True", updating of the "var", "m", and "v" tensors will be protected
 * by a lock; otherwise the behavior is undefined, but may exhibit less
 * contention.
 *@li use_nesterov: An optional bool. Defaults to "False".
 * If "True", uses the Nesterov update.
 *
 *@par Outputs:
 * var: A mutable Tensor. Has the same type as input "var".
 */
REG_OP(ApplyAdam)
    .INPUT(var, TensorType::NumberType())
    .INPUT(m, TensorType::NumberType())
    .INPUT(v, TensorType::NumberType())
    .INPUT(beta1_power, TensorType::NumberType())
    .INPUT(beta2_power, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(beta1, TensorType::NumberType())
    .INPUT(beta2, TensorType::NumberType())
    .INPUT(epsilon, TensorType::NumberType())
    .INPUT(grad, TensorType::NumberType())
    .OUTPUT(var, TensorType::NumberType())
    .ATTR(use_locking, Bool, false)
    .ATTR(use_nesterov, Bool, false)
    .OP_END_FACTORY_REG(ApplyAdam)
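/*
 * Editorial note (illustrative only, not part of the GE API): scalar form of
 * the Adam step above, with "beta1_power" and "beta2_power" carrying beta1^t
 * and beta2^t accumulated over previous iterations.
 *
 *   lr_t = lr * sqrt(1 - beta2_power) / (1 - beta1_power);
 *   m    = beta1 * m + (1 - beta1) * grad;
 *   v    = beta2 * v + (1 - beta2) * grad * grad;
 *   var -= lr_t * m / (sqrt(v) + epsilon);
 */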
/**
 *@brief Updates "var" according to the adadelta scheme.
 *@par Inputs:
 *Seven inputs, including:
 * @li var: A mutable Tensor of type TensorType::NumberType().
 * Should be a Variable Tensor.
 * @li accum: A mutable Tensor of the same type as "var".
 * Should be a Variable Tensor.
 * @li accum_update: A mutable Tensor of the same type as "var".
 * Should be a Variable Tensor.
 * @li lr: A scalar of the same type as "var", for the scaling factor.
 * @li rho: A scalar of the same type as "var", for the decay factor.
 * @li epsilon: A scalar of the same type as "var", for the constant factor.
 * @li grad: A Tensor of the same type as "var", for the gradient.
 *@par Attributes:
 *use_locking: An optional bool. Defaults to "False".
 * If "True", updating of the "var", "accum" and "accum_update" tensors will be
 * protected by a lock; otherwise the behavior is undefined,
 * but may exhibit less contention.
 *@par Outputs:
 *var: A mutable Tensor. Has the same type as "var".
 */
REG_OP(ApplyAdadelta)
    .INPUT(var, TensorType::NumberType())
    .INPUT(accum, TensorType::NumberType())
    .INPUT(accum_update, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(rho, TensorType::NumberType())
    .INPUT(epsilon, TensorType::NumberType())
    .INPUT(grad, TensorType::NumberType())
    .OUTPUT(var, TensorType::NumberType())
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ApplyAdadelta)
// Momentum update fused with the multiplication of "x1" and "x2" that produces the gradient.
REG_OP(FusedMulApplyMomentum)
    .INPUT(var, TensorType::NumberType())
    .INPUT(accum, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(x1, TensorType::NumberType())
    .INPUT(momentum, TensorType::NumberType())
    .INPUT(x2, TensorType::NumberType())
    .OUTPUT(var, TensorType::NumberType())
    .ATTR(use_nesterov, Bool, false)
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(FusedMulApplyMomentum)
// Variant of FusedMulApplyMomentum that also maintains and outputs a copy of "var" ("var_copy").
REG_OP(FusedMulApplyMomentumExtern)
    .INPUT(var, TensorType::NumberType())
    .INPUT(accum, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(x1, TensorType::NumberType())
    .INPUT(momentum, TensorType::NumberType())
    .INPUT(x2, TensorType::NumberType())
    .INPUT(var_copy, TensorType::NumberType())
    .OUTPUT(var, TensorType::NumberType())
    .OUTPUT(var_copy, TensorType::NumberType())
    .ATTR(use_nesterov, Bool, false)
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(FusedMulApplyMomentumExtern)
// LARS (Layer-wise Adaptive Rate Scaling): rescales the gradient "g" based on the weight "w",
// weight decay and learning rate, producing "g_new".
REG_OP(LarsV2)
    .INPUT(w, TensorType(DT_FLOAT))
    .INPUT(g, TensorType(DT_FLOAT))
    .INPUT(weight_decay, TensorType(DT_FLOAT))
    .INPUT(learning_rate, TensorType(DT_FLOAT))
    .OUTPUT(g_new, TensorType(DT_FLOAT))
    .ATTR(hyperpara, Float, 0.001)
    .ATTR(epsilon, Float, 0.00001)
    .ATTR(use_clip, Bool, false)
    .OP_END_FACTORY_REG(LarsV2)
// LARS update that takes the precomputed squared norms of "w" and "g" as inputs.
REG_OP(LarsV2Update)
    .INPUT(w, TensorType(DT_FLOAT))
    .INPUT(g, TensorType(DT_FLOAT))
    .INPUT(w_square_sum, TensorType(DT_FLOAT))
    .INPUT(g_square_sum, TensorType(DT_FLOAT))
    .INPUT(weight_decay, TensorType(DT_FLOAT))
    .INPUT(learning_rate, TensorType(DT_FLOAT))
    .OUTPUT(g_new, TensorType(DT_FLOAT))
    .ATTR(hyperpara, Float, 0.001)
    .ATTR(epsilon, Float, 0.00001)
    .ATTR(use_clip, Bool, false)
    .OP_END_FACTORY_REG(LarsV2Update)
}  // namespace ge
#endif  // GE_OP_TRAINING_OPS_H

The Graph Engine (GE) module is a submodule of MindSpore, implemented in C++. It sits between the front-end module (ME) and the underlying hardware and acts as a bridge between them. GE takes the graph delivered by ME as input, performs a series of deep graph optimizations, and outputs a graph that runs efficiently on the underlying hardware. GE applies optimizations tailored to the hardware architecture of the Ascend AI processor to fully exploit its compute power. During model training and inference, GE is invoked automatically and is transparent to the user. GE consists mainly of two parts, GE API and GE Core; the detailed architecture diagram is shown below.