Browse Source

fix(mge/optimizer): only disable convert inputs in build-in optimizers

GitOrigin-RevId: 1a48fe318d
tags/v1.6.0-rc1
Megvii Engine Team 3 years ago
parent
commit
3103180456
6 changed files with 11 additions and 3 deletions
  1. +1
    -0
      imperative/python/megengine/optimizer/adadelta.py
  2. +1
    -0
      imperative/python/megengine/optimizer/adagrad.py
  3. +1
    -0
      imperative/python/megengine/optimizer/adam.py
  4. +1
    -0
      imperative/python/megengine/optimizer/adamw.py
  5. +6
    -3
      imperative/python/megengine/optimizer/optimizer.py
  6. +1
    -0
      imperative/python/megengine/optimizer/sgd.py

+ 1
- 0
imperative/python/megengine/optimizer/adadelta.py View File

@@ -48,6 +48,7 @@ class Adadelta(Optimizer):

defaults = dict(lr=lr, rho=rho, eps=eps, weight_decay=weight_decay)
super().__init__(params, defaults)
self._disable_type_convert = True

def _create_state(self, param_group):
for param in param_group["params"]:


+ 1
- 0
imperative/python/megengine/optimizer/adagrad.py View File

@@ -48,6 +48,7 @@ class Adagrad(Optimizer):

defaults = dict(lr=lr, lr_decay=lr_decay, eps=eps, weight_decay=weight_decay)
super().__init__(params, defaults)
self._disable_type_convert = True

def _create_state(self, param_group):
for param in param_group["params"]:


+ 1
- 0
imperative/python/megengine/optimizer/adam.py View File

@@ -47,6 +47,7 @@ class Adam(Optimizer):

defaults = dict(lr=lr, weight_decay=weight_decay, betas=betas, eps=eps)
super().__init__(params, defaults)
self._disable_type_convert = True

def _create_state(self, param_group):
for param in param_group["params"]:


+ 1
- 0
imperative/python/megengine/optimizer/adamw.py View File

@@ -47,6 +47,7 @@ class AdamW(Optimizer):

defaults = dict(lr=lr, weight_decay=weight_decay, betas=betas, eps=eps)
super().__init__(params, defaults)
self._disable_type_convert = True

def _create_state(self, param_group):
for param in param_group["params"]:


+ 6
- 3
imperative/python/megengine/optimizer/optimizer.py View File

@@ -42,6 +42,7 @@ class Optimizer(metaclass=ABCMeta):
):
self._state = dict()
self._defaults = defaults
self._disable_type_convert = False

if isinstance(params, (Parameter, dict)):
params = [params]
@@ -149,7 +150,8 @@ class Optimizer(metaclass=ABCMeta):
# set the globle state `_enable_convert_inputs` to `False` to disable
# the `convert_inputs` for param updates
set_option("record_computing_path", 0)
backup = set_convert_inputs(False)
if self._disable_type_convert:
backup = set_convert_inputs(False)
for group in self.param_groups:
if isinstance(group["params"], set):
raise TypeError(
@@ -160,8 +162,9 @@ class Optimizer(metaclass=ABCMeta):
push_scope("step")
self._updates(group)
pop_scope("step")
# restore the globle state `_enable_convert_inputs`
set_convert_inputs(backup)
if self._disable_type_convert:
# restore the globle state `_enable_convert_inputs`
set_convert_inputs(backup)
set_option("record_computing_path", 1)
return self



+ 1
- 0
imperative/python/megengine/optimizer/sgd.py View File

@@ -43,6 +43,7 @@ class SGD(Optimizer):

defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay)
super().__init__(params, defaults)
self._disable_type_convert = True

def _create_state(self, param_group):
if param_group["momentum"] != 0.0:


Loading…
Cancel
Save