GitOrigin-RevId: c710530d93
release-1.1
@@ -16,6 +16,25 @@ from ..ops.special import Const
 from ..tensor.core import OpBase, TensorBase, TensorWrapperBase, apply
 from .dtype import is_equal, is_quantize
+_enable_convert_inputs = True
+def get_convert_inputs():
+    """ get the current state of `_enable_convert_inputs` """
+    return _enable_convert_inputs
+def set_convert_inputs(flag):
+    """ This function is a temporary workaround for reducing the overhead of operator
+    invocations. The function `convert_inputs` is disabled if the global state
+    `_enable_convert_inputs` is set to `False`, otherwise enabled. This function is for
+    internal use only, and should be removed when the tensor-like system is refactored.
+    """
+    global _enable_convert_inputs
+    backup = _enable_convert_inputs
+    _enable_convert_inputs = flag
+    return backup
 def dtype_promotion(inputs):
     """
@@ -129,6 +148,9 @@ def convert_single_value(v, inputs, *, dtype=None, device=None):
 def convert_inputs(*args: TensorBase):
+    if not _enable_convert_inputs:
+        return args
     dtype = dtype_promotion(args)
     device = get_device(args)
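For reference, a minimal sketch of the toggle's save/flip/restore contract, as driven by `Optimizer.step` further below. The absolute import path is an assumption (the relative import in the optimizer hunk suggests it resolves to `megengine.core.tensor.utils`):

```python
from megengine.core.tensor.utils import get_convert_inputs, set_convert_inputs

assert get_convert_inputs()           # `_enable_convert_inputs` defaults to True
backup = set_convert_inputs(False)    # disable conversion, keep the previous state
# ... run code that passes tensors (not Python scalars) to operators ...
set_convert_inputs(backup)            # restore the saved state
assert get_convert_inputs()
```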
@@ -10,8 +10,8 @@ from typing import Iterable, Union
 import numpy as np
-from ..functional import sqrt
-from ..tensor import Parameter
+from ..core.tensor.tensor import Tensor
+from ..tensor import Parameter, tensor
 from .optimizer import Optimizer
@@ -62,6 +62,16 @@ class Adadelta(Optimizer):
         rho = param_group["rho"]
         eps = param_group["eps"]
+        # since `convert_inputs` is disabled for param updates,
+        # scalars should be explicitly converted to tensors
+        _lr = tensor([lr])
+        _weight_decay = tensor([weight_decay])
+        _rho = tensor([rho])
+        _eps = tensor([eps])
+        c05 = tensor([0.5])
+        c1 = tensor([1.0])
+        c2 = tensor([2.0])
         for param in param_group["params"]:
             if param.grad is None:
@@ -69,17 +79,17 @@ class Adadelta(Optimizer):
             states = self._state[param]
             step = states["step"]
-            step += 1.0
+            step += c1
             grad = param.grad
             if weight_decay != 0.0:
-                grad += param * weight_decay
+                grad += param * _weight_decay
             square_avg = states["square_avg"]
             acc_delta = states["acc_delta"]
-            square_avg = rho * square_avg + (1 - rho) * grad ** 2
-            std = sqrt(square_avg + eps)
-            delta = sqrt(acc_delta + eps) / std * grad
-            param -= lr * delta
-            acc_delta = rho * acc_delta + (1 - rho) * delta ** 2
+            square_avg = _rho * square_avg + (c1 - _rho) * grad ** c2
+            std = (square_avg + _eps) ** c05
+            delta = (acc_delta + _eps) ** c05 / std * grad
+            param -= _lr * delta
+            acc_delta = _rho * acc_delta + (c1 - _rho) * delta ** c2
             states["square_avg"]._reset(square_avg)
             states["acc_delta"]._reset(acc_delta)
@@ -10,8 +10,8 @@ from typing import Iterable, Union
 import numpy as np
-from ..functional import sqrt
-from ..tensor import Parameter
+from ..core.tensor.tensor import Tensor
+from ..tensor import Parameter, tensor
 from .optimizer import Optimizer
@@ -61,6 +61,16 @@ class Adagrad(Optimizer):
         weight_decay = param_group["weight_decay"]
         eps = param_group["eps"]
+        # since `convert_inputs` is disabled for param updates,
+        # scalars should be explicitly converted to tensors
+        _lr = tensor([lr])
+        _lr_decay = tensor([lr_decay])
+        _weight_decay = tensor([weight_decay])
+        _eps = tensor([eps])
+        c05 = tensor([0.5])
+        c1 = tensor([1.0])
+        c2 = tensor([2.0])
         for param in param_group["params"]:
             if param.grad is None:
@@ -68,14 +78,14 @@ class Adagrad(Optimizer):
             states = self._state[param]
             step = states["step"]
-            step += 1.0
+            step += c1
             grad = param.grad
             if weight_decay != 0.0:
-                grad += param * weight_decay
+                grad += param * _weight_decay
             square_avg = states["square_avg"]
-            square_avg += grad ** 2
-            delta = grad / sqrt(square_avg + eps)
-            clr = lr / (1 + (step - 1) * lr_decay)
+            square_avg += grad ** c2
+            delta = grad / (square_avg + _eps) ** c05
+            clr = _lr / (c1 + (step - c1) * _lr_decay)
             param -= clr * delta
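As above, only the spelling changes (`c2` for `2`, `** c05` for `sqrt`). A hypothetical plain-Python restatement of the Adagrad rule for comparison:

```python
def adagrad_step(param, grad, square_avg, step, lr, lr_decay, eps, weight_decay):
    """Reference restatement; `step` is the 1-based count after the increment above."""
    if weight_decay != 0.0:
        grad = grad + param * weight_decay
    square_avg = square_avg + grad ** 2
    delta = grad / (square_avg + eps) ** 0.5
    clr = lr / (1 + (step - 1) * lr_decay)   # learning rate decayed by step count
    param = param - clr * delta
    return param, square_avg
```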
@@ -8,7 +8,8 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from typing import Iterable, Tuple, Union
-from ..tensor import Parameter
+from ..core.tensor.tensor import Tensor
+from ..tensor import Parameter, tensor
 from .optimizer import Optimizer
@@ -58,6 +59,15 @@ class Adam(Optimizer):
         eps = param_group["eps"]
         beta0, beta1 = param_group["betas"]
+        # since `convert_inputs` is disabled for param updates,
+        # scalars should be explicitly converted to tensors
+        _lr = tensor([lr])
+        _weight_decay = tensor([weight_decay])
+        _eps = tensor([eps])
+        _beta0, _beta1 = tensor([beta0]), tensor([beta1])
+        c1 = tensor([1.0])
+        c05 = tensor([0.5])
         for param in param_group["params"]:
             if param.grad is None:
@@ -65,20 +75,20 @@ class Adam(Optimizer):
             grad = param.grad
             if weight_decay != 0.0:
-                grad += param * weight_decay
+                grad += param * _weight_decay
             states = self._state[param]
             step = states["step"]
-            step += 1.0
+            step += c1
             exp_avg = states["exp_avg"]
             exp_avg_sq = states["exp_avg_sq"]
-            exp_avg = beta0 * exp_avg + grad * (1 - beta0)
-            exp_avg_sq = beta1 * exp_avg_sq + (1 - beta1) * (grad * grad)
+            exp_avg = _beta0 * exp_avg + grad * (c1 - _beta0)
+            exp_avg_sq = _beta1 * exp_avg_sq + (c1 - _beta1) * (grad * grad)
-            delta = (exp_avg / (1 - beta0 ** step)) / (
-                (exp_avg_sq / (1 - beta1 ** step)) ** 0.5 + eps
+            delta = (exp_avg / (c1 - _beta0 ** step)) / (
+                (exp_avg_sq / (c1 - _beta1 ** step)) ** c05 + _eps
             )
-            param -= lr * delta
+            param -= _lr * delta
             # not inplace change, need to update underlying tensor handler in state
             states["exp_avg"]._reset(exp_avg)
@@ -15,6 +15,7 @@ from typing import Union
 import numpy as np
+from ..core.tensor.utils import set_convert_inputs
 from ..tensor import Parameter, Tensor
 from ..utils.deprecation import deprecated
@@ -143,6 +144,9 @@ class Optimizer(metaclass=ABCMeta):
         Performs a single optimization step.
         """
+        # set the global state `_enable_convert_inputs` to `False` to disable
+        # the `convert_inputs` for param updates
+        backup = set_convert_inputs(False)
         for group in self.param_groups:
             if isinstance(group["params"], set):
                 raise TypeError(
@@ -151,6 +155,8 @@ class Optimizer(metaclass=ABCMeta):
                     "Please use a list instead."
                 )
             self._updates(group)
+        # restore the global state `_enable_convert_inputs`
+        set_convert_inputs(backup)
         return self
     @deprecated(version="1.0", reason="use clear_grad instead")
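One design note on the change above: the flag is restored only after `_updates` returns, so an exception raised during an update would leave `convert_inputs` disabled. A hypothetical context-manager variant of the same save/flip/restore pattern (not part of this change) would restore it unconditionally:

```python
from contextlib import contextmanager

@contextmanager
def _no_convert_inputs():
    """Disable `convert_inputs` for the duration of the block, then restore it."""
    backup = set_convert_inputs(False)
    try:
        yield
    finally:
        set_convert_inputs(backup)

# Possible usage inside step():
#     with _no_convert_inputs():
#         for group in self.param_groups:
#             self._updates(group)
```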
@@ -8,7 +8,8 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from typing import Iterable, Union
-from ..tensor import Parameter
+from ..core.tensor.tensor import Tensor
+from ..tensor import Parameter, tensor
 from .optimizer import Optimizer
@@ -52,18 +53,24 @@ class SGD(Optimizer):
         weight_decay = param_group["weight_decay"]
         momentum = param_group["momentum"]
+        # since `convert_inputs` is disabled for param updates,
+        # scalars should be explicitly converted to tensors
+        _lr = tensor([lr])
+        _weight_decay = tensor([weight_decay])
+        _momentum = tensor([momentum])
         for param in param_group["params"]:
             if param.grad is None:
                 continue
             grad = param.grad
             if weight_decay != 0.0:
-                grad += param * weight_decay
+                grad += param * _weight_decay
             if momentum:
                 v = self._state[param]["momentum_buffer"]
-                v = momentum * v + grad
-                param -= lr * v
+                v = _momentum * v + grad
+                param -= _lr * v
                 self._state[param]["momentum_buffer"]._reset(v)
             else:
-                param -= lr * grad
+                param -= _lr * grad
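And the corresponding hypothetical plain-Python restatement of the SGD update, with `v` as the momentum buffer:

```python
def sgd_step(param, grad, v, lr, momentum, weight_decay):
    """Reference restatement of the update above."""
    if weight_decay != 0.0:
        grad = grad + param * weight_decay
    if momentum:
        v = momentum * v + grad
        return param - lr * v, v
    return param - lr * grad, v
```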