|
|
@@ -299,6 +299,18 @@ REG_OP(ApplyMomentumCCE) |
|
|
|
.ATTR(use_locking, Bool, false) |
|
|
|
.OP_END_FACTORY_REG(ApplyMomentumCCE) |
|
|
|
|
|
|
|
REG_OP(ApplyMomentumD) |
|
|
|
.INPUT(var, TensorType::NumberType()) |
|
|
|
.INPUT(accum, TensorType::NumberType()) |
|
|
|
.INPUT(lr, TensorType::NumberType()) |
|
|
|
.INPUT(grad, TensorType::NumberType()) |
|
|
|
.INPUT(momentum, TensorType::NumberType()) |
|
|
|
.OUTPUT(var, TensorType::NumberType()) |
|
|
|
.OUTPUT(accum, TensorType::NumberType()) |
|
|
|
.ATTR(use_nesterov, Bool, false) |
|
|
|
.ATTR(use_locking, Bool, false) |
|
|
|
.OP_END_FACTORY_REG(ApplyMomentumD) |
|
|
|
|
|
|
|
/** |
|
|
|
*@brief Updates "var" according to the AddSign update.\n |
|
|
|
* t-1 mean previous period. |
|
|
|