Browse Source

fix(mge/optimizer): fix optimizer's state_dict bug

GitOrigin-RevId: 67fb112fb8
release-1.3
Megvii Engine Team 4 years ago
parent
commit
cad8568c34
2 changed files with 8 additions and 2 deletions
  1. +4
    -2
      imperative/python/megengine/optimizer/optimizer.py
  2. +4
    -0
      imperative/python/test/integration/test_optimizer.py

+ 4
- 2
imperative/python/megengine/optimizer/optimizer.py View File

@@ -6,6 +6,7 @@
# Unless required by applicable law or agreed to in writing, # Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import copy
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from collections.abc import Iterable from collections.abc import Iterable
from typing import Dict from typing import Dict
@@ -197,10 +198,11 @@ class Optimizer(metaclass=ABCMeta):
cur_id += 1 cur_id += 1


for param, st in self._state.items(): for param, st in self._state.items():
_st = copy.copy(st)
if not keep_var: if not keep_var:
for k, v in st.items(): for k, v in st.items():
st[k] = v.numpy()
state[param2id[param]] = st
_st[k] = v.numpy()
state[param2id[param]] = _st


for group in self.param_groups: for group in self.param_groups:
param_group = {k: v for k, v in group.items() if k != "params"} param_group = {k: v for k, v in group.items() if k != "params"}


+ 4
- 0
imperative/python/test/integration/test_optimizer.py View File

@@ -104,6 +104,10 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False):
) )
step += 1 step += 1
check_func(ori_params, net.parameters(), step) check_func(ori_params, net.parameters(), step)
try_state_dict = {
"net": net.state_dict(),
"opt": opt.state_dict(),
}




def test_sgd(): def test_sgd():


Loading…
Cancel
Save