|
@@ -8,7 +8,6 @@ |
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
from abc import ABCMeta, abstractmethod |
|
|
from abc import ABCMeta, abstractmethod |
|
|
from collections.abc import Iterable |
|
|
from collections.abc import Iterable |
|
|
from contextlib import contextmanager |
|
|
|
|
|
from typing import Dict |
|
|
from typing import Dict |
|
|
from typing import Iterable as Iter |
|
|
from typing import Iterable as Iter |
|
|
from typing import Union |
|
|
from typing import Union |
|
@@ -180,7 +179,7 @@ class Optimizer(metaclass=ABCMeta): |
|
|
param.grad = None |
|
|
param.grad = None |
|
|
pop_scope("clear_grad") |
|
|
pop_scope("clear_grad") |
|
|
|
|
|
|
|
|
def state_dict(self) -> Dict: |
|
|
|
|
|
|
|
|
def state_dict(self, keep_var=False) -> Dict: |
|
|
r""" |
|
|
r""" |
|
|
Export the optimizer state. |
|
|
Export the optimizer state. |
|
|
|
|
|
|
|
@@ -198,6 +197,9 @@ class Optimizer(metaclass=ABCMeta): |
|
|
cur_id += 1 |
|
|
cur_id += 1 |
|
|
|
|
|
|
|
|
for param, st in self._state.items(): |
|
|
for param, st in self._state.items(): |
|
|
|
|
|
if not keep_var: |
|
|
|
|
|
for k, v in st.items(): |
|
|
|
|
|
st[k] = v.numpy() |
|
|
state[param2id[param]] = st |
|
|
state[param2id[param]] = st |
|
|
|
|
|
|
|
|
for group in self.param_groups: |
|
|
for group in self.param_groups: |
|
@@ -218,7 +220,6 @@ class Optimizer(metaclass=ABCMeta): |
|
|
raise ValueError( |
|
|
raise ValueError( |
|
|
"loaded state dict has a different number of parameter groups" |
|
|
"loaded state dict has a different number of parameter groups" |
|
|
) |
|
|
) |
|
|
parameter_map = dict() # type: Dict |
|
|
|
|
|
for group_new, group_saved in zip(self.param_groups, state["param_groups"]): |
|
|
for group_new, group_saved in zip(self.param_groups, state["param_groups"]): |
|
|
if len(group_new["params"]) != len(group_saved["params"]): |
|
|
if len(group_new["params"]) != len(group_saved["params"]): |
|
|
raise ValueError( |
|
|
raise ValueError( |
|
@@ -232,8 +233,9 @@ class Optimizer(metaclass=ABCMeta): |
|
|
self._state[p] = state["state"][param_saved].copy() |
|
|
self._state[p] = state["state"][param_saved].copy() |
|
|
for k, v in self._state[p].items(): |
|
|
for k, v in self._state[p].items(): |
|
|
if isinstance(v, Tensor): |
|
|
if isinstance(v, Tensor): |
|
|
# TODO: maybe a more efficient way? |
|
|
|
|
|
self._state[p][k] = Tensor(v.numpy()) |
|
|
|
|
|
|
|
|
self._state[p][k] = v.detach() |
|
|
|
|
|
else: |
|
|
|
|
|
self._state[p][k] = Tensor(v) |
|
|
|
|
|
|
|
|
if set(group_new.keys()) != set(group_saved.keys()): |
|
|
if set(group_new.keys()) != set(group_saved.keys()): |
|
|
raise ValueError( |
|
|
raise ValueError( |
|
|