diff --git a/imperative/python/megengine/optimizer/optimizer.py b/imperative/python/megengine/optimizer/optimizer.py index 1196594f..cea3e49d 100644 --- a/imperative/python/megengine/optimizer/optimizer.py +++ b/imperative/python/megengine/optimizer/optimizer.py @@ -6,6 +6,7 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +import copy from abc import ABCMeta, abstractmethod from collections.abc import Iterable from typing import Dict @@ -197,10 +198,11 @@ class Optimizer(metaclass=ABCMeta): cur_id += 1 for param, st in self._state.items(): + _st = copy.copy(st) if not keep_var: for k, v in st.items(): - st[k] = v.numpy() - state[param2id[param]] = st + _st[k] = v.numpy() + state[param2id[param]] = _st for group in self.param_groups: param_group = {k: v for k, v in group.items() if k != "params"} diff --git a/imperative/python/test/integration/test_optimizer.py b/imperative/python/test/integration/test_optimizer.py index 6210233e..fd51d567 100644 --- a/imperative/python/test/integration/test_optimizer.py +++ b/imperative/python/test/integration/test_optimizer.py @@ -104,6 +104,10 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False): ) step += 1 check_func(ori_params, net.parameters(), step) + try_state_dict = { + "net": net.state_dict(), + "opt": opt.state_dict(), + } def test_sgd():