|
|
@@ -6,6 +6,7 @@ |
|
|
|
# Unless required by applicable law or agreed to in writing, |
|
|
|
# software distributed under the License is distributed on an |
|
|
|
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
import copy |
|
|
|
from abc import ABCMeta, abstractmethod |
|
|
|
from collections.abc import Iterable |
|
|
|
from typing import Dict |
|
|
@@ -197,10 +198,11 @@ class Optimizer(metaclass=ABCMeta): |
|
|
|
cur_id += 1 |
|
|
|
|
|
|
|
for param, st in self._state.items(): |
|
|
|
_st = copy.copy(st) |
|
|
|
if not keep_var: |
|
|
|
for k, v in st.items(): |
|
|
|
st[k] = v.numpy() |
|
|
|
state[param2id[param]] = st |
|
|
|
_st[k] = v.numpy() |
|
|
|
state[param2id[param]] = _st |
|
|
|
|
|
|
|
for group in self.param_groups: |
|
|
|
param_group = {k: v for k, v in group.items() if k != "params"} |
|
|
|