feat(mge/optimizer): save state's numpy value by default in `state_dict`

GitOrigin-RevId: ec7e4d56f5
tags/v1.3.0
Megvii Engine Team 4 years ago
parent commit 4917534b65
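
With this change, `state_dict()` exports each per-parameter state entry as its numpy value unless `keep_var=True` is passed, so the returned dict can be pickled without carrying live Tensor objects. A minimal usage sketch (the toy module, hyperparameters, and file name are illustrative, not part of this commit):

    import megengine
    import megengine.module as M
    from megengine.optimizer import SGD

    net = M.Linear(4, 2)
    opt = SGD(net.parameters(), lr=0.1, momentum=0.9)

    # Default (keep_var=False): per-parameter state such as momentum
    # buffers comes out as numpy arrays, so the dict pickles cleanly.
    megengine.save(opt.state_dict(), "opt.pkl")

    # keep_var=True preserves the state entries as Tensor objects.
    state_with_vars = opt.state_dict(keep_var=True)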
1 changed file with 7 additions and 5 deletions
imperative/python/megengine/optimizer/optimizer.py (+7, −5)

@@ -8,7 +8,6 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from abc import ABCMeta, abstractmethod
 from collections.abc import Iterable
-from contextlib import contextmanager
 from typing import Dict
 from typing import Iterable as Iter
 from typing import Union
@@ -180,7 +179,7 @@ class Optimizer(metaclass=ABCMeta):
                 param.grad = None
         pop_scope("clear_grad")
 
-    def state_dict(self) -> Dict:
+    def state_dict(self, keep_var=False) -> Dict:
         r"""
         Export the optimizer state.
 
@@ -198,6 +197,9 @@
             cur_id += 1
 
         for param, st in self._state.items():
+            if not keep_var:
+                for k, v in st.items():
+                    st[k] = v.numpy()
             state[param2id[param]] = st
 
         for group in self.param_groups:
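
After this hunk, the default export path calls `v.numpy()` on every state entry; note that `st` is the optimizer's own per-parameter dict, so the conversion also rewrites the live state in place rather than copying it. A hedged check of the resulting structure, continuing the sketch above (the nesting under a "state" key is inferred from the load path below):

    import numpy as np

    # (assumes opt has run at least one step so momentum state exists)
    exported = opt.state_dict()  # keep_var defaults to False
    for st in exported["state"].values():
        for v in st.values():
            assert isinstance(v, np.ndarray)  # converted via v.numpy()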
@@ -218,7 +220,6 @@
             raise ValueError(
                 "loaded state dict has a different number of parameter groups"
             )
-        parameter_map = dict()  # type: Dict
         for group_new, group_saved in zip(self.param_groups, state["param_groups"]):
             if len(group_new["params"]) != len(group_saved["params"]):
                 raise ValueError(
@@ -232,8 +233,9 @@
                 self._state[p] = state["state"][param_saved].copy()
                 for k, v in self._state[p].items():
                     if isinstance(v, Tensor):
-                        # TODO: maybe a more efficient way?
-                        self._state[p][k] = Tensor(v.numpy())
+                        self._state[p][k] = v.detach()
+                    else:
+                        self._state[p][k] = Tensor(v)
 
             if set(group_new.keys()) != set(group_saved.keys()):
                 raise ValueError(
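
On the load side, Tensor-valued entries are now detached directly instead of round-tripping through numpy, and the new `else` branch rebuilds a Tensor from plain values such as the numpy arrays the default `state_dict()` now produces. A round-trip sketch, continuing the illustrative setup above:

    # Restore into a freshly constructed optimizer: numpy entries go
    # through Tensor(v), Tensor entries are kept via v.detach().
    opt2 = SGD(net.parameters(), lr=0.1, momentum=0.9)
    opt2.load_state_dict(megengine.load("opt.pkl"))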

