GitOrigin-RevId: 1a48fe318d
tags/v1.6.0-rc1
@@ -48,6 +48,7 @@ class Adadelta(Optimizer): | |||||
defaults = dict(lr=lr, rho=rho, eps=eps, weight_decay=weight_decay) | defaults = dict(lr=lr, rho=rho, eps=eps, weight_decay=weight_decay) | ||||
super().__init__(params, defaults) | super().__init__(params, defaults) | ||||
self._disable_type_convert = True | |||||
def _create_state(self, param_group): | def _create_state(self, param_group): | ||||
for param in param_group["params"]: | for param in param_group["params"]: | ||||
@@ -48,6 +48,7 @@ class Adagrad(Optimizer): | |||||
defaults = dict(lr=lr, lr_decay=lr_decay, eps=eps, weight_decay=weight_decay) | defaults = dict(lr=lr, lr_decay=lr_decay, eps=eps, weight_decay=weight_decay) | ||||
super().__init__(params, defaults) | super().__init__(params, defaults) | ||||
self._disable_type_convert = True | |||||
def _create_state(self, param_group): | def _create_state(self, param_group): | ||||
for param in param_group["params"]: | for param in param_group["params"]: | ||||
@@ -47,6 +47,7 @@ class Adam(Optimizer): | |||||
defaults = dict(lr=lr, weight_decay=weight_decay, betas=betas, eps=eps) | defaults = dict(lr=lr, weight_decay=weight_decay, betas=betas, eps=eps) | ||||
super().__init__(params, defaults) | super().__init__(params, defaults) | ||||
self._disable_type_convert = True | |||||
def _create_state(self, param_group): | def _create_state(self, param_group): | ||||
for param in param_group["params"]: | for param in param_group["params"]: | ||||
@@ -47,6 +47,7 @@ class AdamW(Optimizer): | |||||
defaults = dict(lr=lr, weight_decay=weight_decay, betas=betas, eps=eps) | defaults = dict(lr=lr, weight_decay=weight_decay, betas=betas, eps=eps) | ||||
super().__init__(params, defaults) | super().__init__(params, defaults) | ||||
self._disable_type_convert = True | |||||
def _create_state(self, param_group): | def _create_state(self, param_group): | ||||
for param in param_group["params"]: | for param in param_group["params"]: | ||||
@@ -42,6 +42,7 @@ class Optimizer(metaclass=ABCMeta): | |||||
): | ): | ||||
self._state = dict() | self._state = dict() | ||||
self._defaults = defaults | self._defaults = defaults | ||||
self._disable_type_convert = False | |||||
if isinstance(params, (Parameter, dict)): | if isinstance(params, (Parameter, dict)): | ||||
params = [params] | params = [params] | ||||
@@ -149,7 +150,8 @@ class Optimizer(metaclass=ABCMeta): | |||||
# set the globle state `_enable_convert_inputs` to `False` to disable | # set the globle state `_enable_convert_inputs` to `False` to disable | ||||
# the `convert_inputs` for param updates | # the `convert_inputs` for param updates | ||||
set_option("record_computing_path", 0) | set_option("record_computing_path", 0) | ||||
backup = set_convert_inputs(False) | |||||
if self._disable_type_convert: | |||||
backup = set_convert_inputs(False) | |||||
for group in self.param_groups: | for group in self.param_groups: | ||||
if isinstance(group["params"], set): | if isinstance(group["params"], set): | ||||
raise TypeError( | raise TypeError( | ||||
@@ -160,8 +162,9 @@ class Optimizer(metaclass=ABCMeta): | |||||
push_scope("step") | push_scope("step") | ||||
self._updates(group) | self._updates(group) | ||||
pop_scope("step") | pop_scope("step") | ||||
# restore the globle state `_enable_convert_inputs` | |||||
set_convert_inputs(backup) | |||||
if self._disable_type_convert: | |||||
# restore the globle state `_enable_convert_inputs` | |||||
set_convert_inputs(backup) | |||||
set_option("record_computing_path", 1) | set_option("record_computing_path", 1) | ||||
return self | return self | ||||
@@ -43,6 +43,7 @@ class SGD(Optimizer): | |||||
defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay) | defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay) | ||||
super().__init__(params, defaults) | super().__init__(params, defaults) | ||||
self._disable_type_convert = True | |||||
def _create_state(self, param_group): | def _create_state(self, param_group): | ||||
if param_group["momentum"] != 0.0: | if param_group["momentum"] != 0.0: | ||||