diff --git a/third_party/fwkacllib/inc/ops/nn_training_ops.h b/third_party/fwkacllib/inc/ops/nn_training_ops.h index b8f4003e..922869c3 100644 --- a/third_party/fwkacllib/inc/ops/nn_training_ops.h +++ b/third_party/fwkacllib/inc/ops/nn_training_ops.h @@ -299,6 +299,18 @@ REG_OP(ApplyMomentumCCE) .ATTR(use_locking, Bool, false) .OP_END_FACTORY_REG(ApplyMomentumCCE) +REG_OP(ApplyMomentumD) + .INPUT(var, TensorType::NumberType()) + .INPUT(accum, TensorType::NumberType()) + .INPUT(lr, TensorType::NumberType()) + .INPUT(grad, TensorType::NumberType()) + .INPUT(momentum, TensorType::NumberType()) + .OUTPUT(var, TensorType::NumberType()) + .OUTPUT(accum, TensorType::NumberType()) + .ATTR(use_nesterov, Bool, false) + .ATTR(use_locking, Bool, false) + .OP_END_FACTORY_REG(ApplyMomentumD) + /** *@brief Updates "var" according to the AddSign update.\n * t-1 mean previous period.