diff --git a/imperative/python/megengine/optimizer/adadelta.py b/imperative/python/megengine/optimizer/adadelta.py
index 9de92fa9..2eae5184 100644
--- a/imperative/python/megengine/optimizer/adadelta.py
+++ b/imperative/python/megengine/optimizer/adadelta.py
@@ -12,10 +12,10 @@ import numpy as np
 
 from ..functional import sqrt
 from ..tensor_nn import Buffer, Parameter
-from .distributed_optimizer import DistributedOptimizer
+from .optimizer import Optimizer
 
 
-class Adadelta(DistributedOptimizer):
+class Adadelta(Optimizer):
     r"""Implements Adadelta algorithm.
 
     It has been proposed in `"ADADELTA: An Adaptive Learning Rate Method" `_.
@@ -38,7 +38,6 @@ class Adadelta(DistributedOptimizer):
         rho: float = 0.9,
         eps: float = 1e-6,
         weight_decay: float = 0.0,
-        **kwargs
     ):
         assert lr >= 0.0, "Invalid learning rate: {}".format(lr)
         assert rho >= 0.0 and rho <= 1.0, "Invalid rho value: {}".format(rho)
@@ -48,7 +47,7 @@ class Adadelta(DistributedOptimizer):
         )
 
         defaults = dict(lr=lr, rho=rho, eps=eps, weight_decay=weight_decay)
-        super().__init__(params, defaults, **kwargs)
+        super().__init__(params, defaults)
 
     def _create_state(self, param_group):
         for param in param_group["params"]:
diff --git a/imperative/python/megengine/optimizer/adagrad.py b/imperative/python/megengine/optimizer/adagrad.py
index 804c7abe..da1ad46b 100644
--- a/imperative/python/megengine/optimizer/adagrad.py
+++ b/imperative/python/megengine/optimizer/adagrad.py
@@ -12,10 +12,10 @@ import numpy as np
 
 from ..functional import sqrt
 from ..tensor_nn import Buffer, Parameter
-from .distributed_optimizer import DistributedOptimizer
+from .optimizer import Optimizer
 
 
-class Adagrad(DistributedOptimizer):
+class Adagrad(Optimizer):
     r"""Implements Adagrad algorithm.
 
     It has been proposed in `"Adaptive Subgradient Methods for Online Learning
@@ -38,7 +38,6 @@ class Adagrad(DistributedOptimizer):
         lr_decay: float = 0.0,
         eps: float = 1e-10,
         weight_decay: float = 0.0,
-        **kwargs
     ):
         assert lr >= 0.0, "Invalid learning rate: {}".format(lr)
         assert lr_decay >= 0, "Invalid learning rate decay: {}".format(lr_decay)
@@ -48,7 +47,7 @@ class Adagrad(DistributedOptimizer):
         )
 
         defaults = dict(lr=lr, lr_decay=lr_decay, eps=eps, weight_decay=weight_decay)
-        super().__init__(params, defaults, **kwargs)
+        super().__init__(params, defaults)
 
     def _create_state(self, param_group):
         for param in param_group["params"]:
diff --git a/imperative/python/megengine/optimizer/adam.py b/imperative/python/megengine/optimizer/adam.py
index fac9f4cb..7901adb7 100644
--- a/imperative/python/megengine/optimizer/adam.py
+++ b/imperative/python/megengine/optimizer/adam.py
@@ -9,10 +9,10 @@
 from typing import Iterable, Tuple, Union
 
 from ..tensor_nn import Buffer, Parameter
-from .distributed_optimizer import DistributedOptimizer
+from .optimizer import Optimizer
 
 
-class Adam(DistributedOptimizer):
+class Adam(Optimizer):
     r"""Implements Adam algorithm proposed in `"Adam: A Method for Stochastic Optimization" `_.
 
     :param params: iterable of parameters to optimize or dicts defining
@@ -32,7 +32,6 @@ class Adam(DistributedOptimizer):
         betas: Tuple[float, float] = (0.9, 0.999),
         eps: float = 1e-8,
         weight_decay: float = 0.0,
-        **kwargs
     ):
         if lr < 0.0:
             raise ValueError("Invalid learning rate: {}".format(lr))
@@ -44,7 +43,7 @@ class Adam(DistributedOptimizer):
             raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
 
         defaults = dict(lr=lr, weight_decay=weight_decay, betas=betas, eps=eps)
-        super().__init__(params, defaults, **kwargs)
+        super().__init__(params, defaults)
 
     def _create_state(self, param_group):
         for param in param_group["params"]:
diff --git a/imperative/python/megengine/optimizer/lr_scheduler.py b/imperative/python/megengine/optimizer/lr_scheduler.py
index 46d08d5d..d2b6c859 100644
--- a/imperative/python/megengine/optimizer/lr_scheduler.py
+++ b/imperative/python/megengine/optimizer/lr_scheduler.py
@@ -8,7 +8,7 @@
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from abc import ABCMeta
 
-from .distributed_optimizer import DistributedOptimizer
+from .optimizer import Optimizer
 
 
 class LRScheduler(metaclass=ABCMeta):
@@ -19,9 +19,9 @@ class LRScheduler(metaclass=ABCMeta):
     """
 
     def __init__(  # pylint: disable=too-many-branches
-        self, optimizer: DistributedOptimizer, current_epoch: int = -1
+        self, optimizer: Optimizer, current_epoch: int = -1
     ):
-        if not isinstance(optimizer, DistributedOptimizer):
+        if not isinstance(optimizer, Optimizer):
             raise TypeError(
                 "optimizer argument given to the lr_scheduler should be Optimizer"
             )
diff --git a/imperative/python/megengine/optimizer/multi_step_lr.py b/imperative/python/megengine/optimizer/multi_step_lr.py
index 45cc74c3..fc3a43f4 100644
--- a/imperative/python/megengine/optimizer/multi_step_lr.py
+++ b/imperative/python/megengine/optimizer/multi_step_lr.py
@@ -9,7 +9,7 @@
 from bisect import bisect_right
 from typing import Iterable as Iter
 
-from .distributed_optimizer import DistributedOptimizer
+from .optimizer import Optimizer
 from .lr_scheduler import LRScheduler
 
 
@@ -25,7 +25,7 @@ class MultiStepLR(LRScheduler):
 
     def __init__(
         self,
-        optimizer: DistributedOptimizer,
+        optimizer: Optimizer,
         milestones: Iter[int],
         gamma: float = 0.1,
         current_epoch: int = -1,
diff --git a/imperative/python/megengine/optimizer/sgd.py b/imperative/python/megengine/optimizer/sgd.py
index 4dfb485b..88cffd07 100644
--- a/imperative/python/megengine/optimizer/sgd.py
+++ b/imperative/python/megengine/optimizer/sgd.py
@@ -9,10 +9,10 @@
 from typing import Iterable, Union
 
 from ..tensor_nn import Buffer, Parameter
-from .distributed_optimizer import DistributedOptimizer
+from .optimizer import Optimizer
 
 
-class SGD(DistributedOptimizer):
+class SGD(Optimizer):
     r"""Implements stochastic gradient descent.
 
     Nesterov momentum is based on the formula from
@@ -31,7 +31,6 @@ class SGD(DistributedOptimizer):
         lr: float,
         momentum: float = 0.0,
         weight_decay: float = 0.0,
-        **kwargs
     ):
         assert lr >= 0.0, "Invalid learning rate: {}".format(lr)
         assert momentum >= 0.0, "Invalid momentum value: {}".format(momentum)
@@ -40,7 +39,7 @@ class SGD(DistributedOptimizer):
         )
 
        defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay)
-        super().__init__(params, defaults, **kwargs)
+        super().__init__(params, defaults)
 
     def _create_state(self, param_group):
         if param_group["momentum"] != 0.0:
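
For reference, a minimal usage sketch of the constructors this patch touches: with the classes now subclassing the plain Optimizer, they no longer accept the trailing **kwargs that previously forwarded distributed options to DistributedOptimizer, and only the keyword arguments shown in the hunks above remain. The snippet assumes the public megengine.Parameter alias and that SGD and MultiStepLR are re-exported from the megengine.optimizer package; the `weight` parameter is a stand-in for a real model's parameters, not something from the patch.

    import numpy as np
    import megengine as mge
    import megengine.optimizer as optim

    # Stand-in parameter; real code would pass a module's parameters() instead.
    weight = mge.Parameter(np.random.randn(4, 4).astype("float32"))

    # Constructor signatures exactly as shown in the hunks above; no **kwargs.
    opt = optim.SGD([weight], lr=0.1, momentum=0.9, weight_decay=1e-4)
    scheduler = optim.MultiStepLR(opt, milestones=[30, 60], gamma=0.1)

    # Passing a distributed-only keyword argument here now raises TypeError,
    # since the **kwargs passthrough to DistributedOptimizer was removed.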