diff --git a/imperative/python/megengine/optimizer/adadelta.py b/imperative/python/megengine/optimizer/adadelta.py
index 1c321d21..81565a1c 100644
--- a/imperative/python/megengine/optimizer/adadelta.py
+++ b/imperative/python/megengine/optimizer/adadelta.py
@@ -48,6 +48,7 @@ class Adadelta(Optimizer):
 
         defaults = dict(lr=lr, rho=rho, eps=eps, weight_decay=weight_decay)
         super().__init__(params, defaults)
+        self._disable_type_convert = True
 
     def _create_state(self, param_group):
         for param in param_group["params"]:
diff --git a/imperative/python/megengine/optimizer/adagrad.py b/imperative/python/megengine/optimizer/adagrad.py
index c983c791..fadbf48f 100644
--- a/imperative/python/megengine/optimizer/adagrad.py
+++ b/imperative/python/megengine/optimizer/adagrad.py
@@ -48,6 +48,7 @@ class Adagrad(Optimizer):
 
         defaults = dict(lr=lr, lr_decay=lr_decay, eps=eps, weight_decay=weight_decay)
         super().__init__(params, defaults)
+        self._disable_type_convert = True
 
     def _create_state(self, param_group):
         for param in param_group["params"]:
diff --git a/imperative/python/megengine/optimizer/adam.py b/imperative/python/megengine/optimizer/adam.py
index 40d5eec5..9e51c90a 100644
--- a/imperative/python/megengine/optimizer/adam.py
+++ b/imperative/python/megengine/optimizer/adam.py
@@ -47,6 +47,7 @@ class Adam(Optimizer):
 
         defaults = dict(lr=lr, weight_decay=weight_decay, betas=betas, eps=eps)
         super().__init__(params, defaults)
+        self._disable_type_convert = True
 
     def _create_state(self, param_group):
         for param in param_group["params"]:
diff --git a/imperative/python/megengine/optimizer/adamw.py b/imperative/python/megengine/optimizer/adamw.py
index aec655e0..cd3f2d91 100644
--- a/imperative/python/megengine/optimizer/adamw.py
+++ b/imperative/python/megengine/optimizer/adamw.py
@@ -47,6 +47,7 @@ class AdamW(Optimizer):
 
         defaults = dict(lr=lr, weight_decay=weight_decay, betas=betas, eps=eps)
         super().__init__(params, defaults)
+        self._disable_type_convert = True
 
     def _create_state(self, param_group):
         for param in param_group["params"]:
diff --git a/imperative/python/megengine/optimizer/optimizer.py b/imperative/python/megengine/optimizer/optimizer.py
index 8b2d4858..b6f60cd7 100644
--- a/imperative/python/megengine/optimizer/optimizer.py
+++ b/imperative/python/megengine/optimizer/optimizer.py
@@ -42,6 +42,7 @@ class Optimizer(metaclass=ABCMeta):
     ):
         self._state = dict()
         self._defaults = defaults
+        self._disable_type_convert = False
 
         if isinstance(params, (Parameter, dict)):
             params = [params]
@@ -149,7 +150,8 @@ class Optimizer(metaclass=ABCMeta):
         # set the globle state `_enable_convert_inputs` to `False` to disable
         # the `convert_inputs` for param updates
         set_option("record_computing_path", 0)
-        backup = set_convert_inputs(False)
+        if self._disable_type_convert:
+            backup = set_convert_inputs(False)
         for group in self.param_groups:
             if isinstance(group["params"], set):
                 raise TypeError(
@@ -160,8 +162,9 @@ class Optimizer(metaclass=ABCMeta):
             push_scope("step")
             self._updates(group)
             pop_scope("step")
-        # restore the globle state `_enable_convert_inputs`
-        set_convert_inputs(backup)
+        if self._disable_type_convert:
+            # restore the global state `_enable_convert_inputs`
+            set_convert_inputs(backup)
         set_option("record_computing_path", 1)
 
         return self
diff --git a/imperative/python/megengine/optimizer/sgd.py b/imperative/python/megengine/optimizer/sgd.py
index 5ed256d2..9c939eb3 100644
--- a/imperative/python/megengine/optimizer/sgd.py
+++ b/imperative/python/megengine/optimizer/sgd.py
@@ -43,6 +43,7 @@ class SGD(Optimizer):
 
         defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay)
         super().__init__(params, defaults)
+        self._disable_type_convert = True
 
     def _create_state(self, param_group):
         if param_group["momentum"] != 0.0: