GitOrigin-RevId: 3e5d0612f0
@@ -12,10 +12,10 @@ import numpy as np
 from ..functional import sqrt
 from ..tensor_nn import Buffer, Parameter
-from .distributed_optimizer import DistributedOptimizer
+from .optimizer import Optimizer


-class Adadelta(DistributedOptimizer):
+class Adadelta(Optimizer):
     r"""Implements Adadelta algorithm.

     It has been proposed in `"ADADELTA: An Adaptive Learning Rate Method" <https://arxiv.org/abs/1212.5701>`_.
@@ -38,7 +38,6 @@ class Adadelta(DistributedOptimizer):
         rho: float = 0.9,
         eps: float = 1e-6,
         weight_decay: float = 0.0,
-        **kwargs
     ):
         assert lr >= 0.0, "Invalid learning rate: {}".format(lr)
         assert rho >= 0.0 and rho <= 1.0, "Invalid rho value: {}".format(rho)
@@ -48,7 +47,7 @@ class Adadelta(DistributedOptimizer):
         )
         defaults = dict(lr=lr, rho=rho, eps=eps, weight_decay=weight_decay)
-        super().__init__(params, defaults, **kwargs)
+        super().__init__(params, defaults)

     def _create_state(self, param_group):
         for param in param_group["params"]:
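For orientation, a minimal construction sketch under the new signature (illustration only, not part of the patch): with **kwargs gone, Adadelta.__init__ accepts only the documented arguments, and the asserts above still reject lr < 0 or rho outside [0, 1]. The megengine.module / megengine.optimizer import paths are an assumption about the installed package layout.

    # Usage sketch (assumed public paths; adjust to your install):
    import megengine.module as M
    import megengine.optimizer as optim

    model = M.Linear(4, 2)  # tiny module, just to obtain parameters
    opt = optim.Adadelta(model.parameters(), lr=1.0, rho=0.9, eps=1e-6, weight_decay=0.0)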
@@ -12,10 +12,10 @@ import numpy as np
 from ..functional import sqrt
 from ..tensor_nn import Buffer, Parameter
-from .distributed_optimizer import DistributedOptimizer
+from .optimizer import Optimizer


-class Adagrad(DistributedOptimizer):
+class Adagrad(Optimizer):
     r"""Implements Adagrad algorithm.

     It has been proposed in `"Adaptive Subgradient Methods for Online Learning
@@ -38,7 +38,6 @@ class Adagrad(DistributedOptimizer):
         lr_decay: float = 0.0,
         eps: float = 1e-10,
         weight_decay: float = 0.0,
-        **kwargs
     ):
         assert lr >= 0.0, "Invalid learning rate: {}".format(lr)
         assert lr_decay >= 0, "Invalid learning rate decay: {}".format(lr_decay)
@@ -48,7 +47,7 @@ class Adagrad(DistributedOptimizer):
         )
         defaults = dict(lr=lr, lr_decay=lr_decay, eps=eps, weight_decay=weight_decay)
-        super().__init__(params, defaults, **kwargs)
+        super().__init__(params, defaults)

     def _create_state(self, param_group):
         for param in param_group["params"]:
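Adagrad gets the same treatment. One observable effect of dropping **kwargs: a stray keyword argument is now rejected by __init__ itself rather than being forwarded to the removed distributed base class. A sketch (the keyword name some_unknown_flag is purely hypothetical; import paths assumed as above):

    import megengine.module as M
    import megengine.optimizer as optim

    model = M.Linear(4, 2)
    opt = optim.Adagrad(model.parameters(), lr=0.01, lr_decay=0.0, eps=1e-10)

    try:
        optim.Adagrad(model.parameters(), lr=0.01, some_unknown_flag=True)
    except TypeError as exc:
        print("rejected:", exc)  # unexpected keyword argument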
@@ -9,10 +9,10 @@
 from typing import Iterable, Tuple, Union
 from ..tensor_nn import Buffer, Parameter
-from .distributed_optimizer import DistributedOptimizer
+from .optimizer import Optimizer


-class Adam(DistributedOptimizer):
+class Adam(Optimizer):
     r"""Implements Adam algorithm proposed in `"Adam: A Method for Stochastic Optimization" <https://arxiv.org/abs/1412.6980>`_.

     :param params: iterable of parameters to optimize or dicts defining
@@ -32,7 +32,6 @@ class Adam(DistributedOptimizer):
         betas: Tuple[float, float] = (0.9, 0.999),
         eps: float = 1e-8,
         weight_decay: float = 0.0,
-        **kwargs
     ):
         if lr < 0.0:
             raise ValueError("Invalid learning rate: {}".format(lr))
@@ -44,7 +43,7 @@ class Adam(DistributedOptimizer):
             raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
         defaults = dict(lr=lr, weight_decay=weight_decay, betas=betas, eps=eps)
-        super().__init__(params, defaults, **kwargs)
+        super().__init__(params, defaults)

     def _create_state(self, param_group):
         for param in param_group["params"]:
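Adam follows the same shape, and its constructor still validates lr, eps, and both betas entries, as the ValueError branches above show. A sketch (import paths assumed as above):

    import megengine.module as M
    import megengine.optimizer as optim

    model = M.Linear(4, 2)
    opt = optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-8)

    try:
        optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 1.5))  # second beta out of range
    except ValueError as exc:
        print("rejected:", exc)  # "Invalid beta parameter at index 1: 1.5"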
@@ -8,7 +8,7 @@
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from abc import ABCMeta
-from .distributed_optimizer import DistributedOptimizer
+from .optimizer import Optimizer


 class LRScheduler(metaclass=ABCMeta):
@@ -19,9 +19,9 @@ class LRScheduler(metaclass=ABCMeta):
     """

     def __init__(  # pylint: disable=too-many-branches
-        self, optimizer: DistributedOptimizer, current_epoch: int = -1
+        self, optimizer: Optimizer, current_epoch: int = -1
     ):
-        if not isinstance(optimizer, DistributedOptimizer):
+        if not isinstance(optimizer, Optimizer):
             raise TypeError(
                 "optimizer argument given to the lr_scheduler should be Optimizer"
             )
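The base scheduler now type-checks against Optimizer rather than DistributedOptimizer. LRScheduler itself is abstract, so the check is easiest to exercise through a concrete subclass; a sketch of the TypeError path (import path assumed as above, MultiStepLR used as the concrete scheduler):

    import megengine.optimizer as optim

    try:
        optim.MultiStepLR("not an optimizer", milestones=[10, 20])
    except TypeError as exc:
        print(exc)  # optimizer argument given to the lr_scheduler should be Optimizer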
@@ -9,7 +9,7 @@
 from bisect import bisect_right
 from typing import Iterable as Iter
-from .distributed_optimizer import DistributedOptimizer
+from .optimizer import Optimizer
 from .lr_scheduler import LRScheduler
@@ -25,7 +25,7 @@ class MultiStepLR(LRScheduler):
     def __init__(
         self,
-        optimizer: DistributedOptimizer,
+        optimizer: Optimizer,
         milestones: Iter[int],
         gamma: float = 0.1,
         current_epoch: int = -1,
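And the happy path for MultiStepLR, which now annotates its optimizer argument as Optimizer. Typical wiring, assuming the import paths above and the usual step()-per-epoch scheduler API:

    import megengine.module as M
    import megengine.optimizer as optim

    model = M.Linear(4, 2)
    opt = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    scheduler = optim.MultiStepLR(opt, milestones=[10, 20], gamma=0.1)

    for epoch in range(30):
        # ... one epoch of training with opt would go here ...
        scheduler.step()  # multiplies the base lr by gamma at each milestone epoch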
@@ -9,10 +9,10 @@
 from typing import Iterable, Union
 from ..tensor_nn import Buffer, Parameter
-from .distributed_optimizer import DistributedOptimizer
+from .optimizer import Optimizer


-class SGD(DistributedOptimizer):
+class SGD(Optimizer):
     r"""Implements stochastic gradient descent.

     Nesterov momentum is based on the formula from
@@ -31,7 +31,6 @@ class SGD(DistributedOptimizer):
         lr: float,
         momentum: float = 0.0,
         weight_decay: float = 0.0,
-        **kwargs
     ):
         assert lr >= 0.0, "Invalid learning rate: {}".format(lr)
         assert momentum >= 0.0, "Invalid momentum value: {}".format(momentum)
@@ -40,7 +39,7 @@ class SGD(DistributedOptimizer):
         )
         defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay)
-        super().__init__(params, defaults, **kwargs)
+        super().__init__(params, defaults)

     def _create_state(self, param_group):
         if param_group["momentum"] != 0.0:
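SGD loses **kwargs in the same way. As the _create_state context above shows, a momentum buffer is only allocated when momentum != 0.0, so plain SGD keeps no extra per-parameter state. A sketch (import paths assumed as above):

    import megengine.module as M
    import megengine.optimizer as optim

    model = M.Linear(4, 2)

    plain_sgd = optim.SGD(model.parameters(), lr=0.1)  # no momentum buffers created
    momentum_sgd = optim.SGD(
        model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4
    )  # one momentum buffer per parameter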