Browse Source

fix(mge/optimizer): fix optimizer's state_dict bug

GitOrigin-RevId: 67fb112fb8
release-1.3
Megvii Engine Team 4 years ago
parent
commit
cad8568c34
2 changed files with 8 additions and 2 deletions
  1. +4
    -2
      imperative/python/megengine/optimizer/optimizer.py
  2. +4
    -0
      imperative/python/test/integration/test_optimizer.py

+ 4
- 2
imperative/python/megengine/optimizer/optimizer.py View File

@@ -6,6 +6,7 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import copy
from abc import ABCMeta, abstractmethod
from collections.abc import Iterable
from typing import Dict
@@ -197,10 +198,11 @@ class Optimizer(metaclass=ABCMeta):
cur_id += 1

for param, st in self._state.items():
_st = copy.copy(st)
if not keep_var:
for k, v in st.items():
st[k] = v.numpy()
state[param2id[param]] = st
_st[k] = v.numpy()
state[param2id[param]] = _st

for group in self.param_groups:
param_group = {k: v for k, v in group.items() if k != "params"}


+ 4
- 0
imperative/python/test/integration/test_optimizer.py View File

@@ -104,6 +104,10 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False):
)
step += 1
check_func(ori_params, net.parameters(), step)
try_state_dict = {
"net": net.state_dict(),
"opt": opt.state_dict(),
}


def test_sgd():


Loading…
Cancel
Save