|
|
@@ -42,6 +42,7 @@ class Optimizer(metaclass=ABCMeta): |
|
|
|
): |
|
|
|
self._state = dict() |
|
|
|
self._defaults = defaults |
|
|
|
self._disable_type_convert = False |
|
|
|
|
|
|
|
if isinstance(params, (Parameter, dict)): |
|
|
|
params = [params] |
|
|
@@ -149,7 +150,8 @@ class Optimizer(metaclass=ABCMeta): |
|
|
|
# set the globle state `_enable_convert_inputs` to `False` to disable |
|
|
|
# the `convert_inputs` for param updates |
|
|
|
set_option("record_computing_path", 0) |
|
|
|
backup = set_convert_inputs(False) |
|
|
|
if self._disable_type_convert: |
|
|
|
backup = set_convert_inputs(False) |
|
|
|
for group in self.param_groups: |
|
|
|
if isinstance(group["params"], set): |
|
|
|
raise TypeError( |
|
|
@@ -160,8 +162,9 @@ class Optimizer(metaclass=ABCMeta): |
|
|
|
push_scope("step") |
|
|
|
self._updates(group) |
|
|
|
pop_scope("step") |
|
|
|
# restore the globle state `_enable_convert_inputs` |
|
|
|
set_convert_inputs(backup) |
|
|
|
if self._disable_type_convert: |
|
|
|
# restore the globle state `_enable_convert_inputs` |
|
|
|
set_convert_inputs(backup) |
|
|
|
set_option("record_computing_path", 1) |
|
|
|
return self |
|
|
|
|
|
|
|