Browse Source

fix(mge/module): set no_cache=true when loading state dict

GitOrigin-RevId: 83281a3d47
release-1.5
Megvii Engine Team huangxinda 4 years ago
parent
commit
d6db4fea93
3 changed files with 40 additions and 1 deletions
  1. +5
    -1
      imperative/python/megengine/module/module.py
  2. +32
    -0
      imperative/python/test/integration/test_save_load.py
  3. +3
    -0
      imperative/src/impl/ops/elemwise.cpp

+ 5
- 1
imperative/python/megengine/module/module.py View File

@@ -600,7 +600,11 @@ class Module(metaclass=ABCMeta):
k, var_shape, to_be_load_shape
)
)
var._reset(type(var)(to_be_load, dtype=to_be_load.dtype, device=var.device))
var._reset(
type(var)(
to_be_load, dtype=to_be_load.dtype, device=var.device, no_cache=True
)
)
loaded.append(k)

return set(loaded), set(skipped)


+ 32
- 0
imperative/python/test/integration/test_save_load.py View File

@@ -11,6 +11,7 @@ import numpy as np

import megengine as mge
import megengine.autodiff as ad
import megengine.module as M
import megengine.optimizer as optimizer
from megengine import Parameter, tensor
from megengine.module import Module
@@ -26,6 +27,37 @@ class Simple(Module):
return x


class Net(Module):
def __init__(self):
super().__init__()
self.fc = M.Linear(1, 1)

def forward(self, images):
x = self.fc(images)
loss = x.mean() * 10000
return loss


def test_load_state_dict_no_cache(monkeypatch):
with monkeypatch.context() as mk:
mk.setenv("MEGENGINE_INPLACE_UPDATE", "1")
net = Net()

optim = optimizer.SGD(net.parameters(), lr=0.1)
gm = ad.GradManager().attach(net.parameters())
state = {
"fc.weight": np.array([[0]], dtype=np.float32),
"fc.bias": np.array([0.0], dtype=np.float32),
}
net.load_state_dict(state)
images = mge.tensor([[0]], dtype=np.float32)
with gm:
loss = net(images)
gm.backward(loss)
optim.step()
optim.clear_grad()


def test_save_load():
net = Simple()



+ 3
- 0
imperative/src/impl/ops/elemwise.cpp View File

@@ -224,6 +224,9 @@ cg::OperatorNodeBase* apply_inplace_add_on_var_node(
SmallVector<TensorPtr> apply_inplace_add_on_physical_tensor(
const OpDef& def,
const SmallVector<TensorPtr>& inputs){
mgb_assert(inputs[0]->blob().unique() && inputs[0]->blob()->storage().unique(),
"This inplace modification may change the elements of other tensors. "
"Please set MEGENGINE_INPLACE_UPDATE to 0 to ensure the program runs correctly.");
auto dest = inputs[0], delta = inputs[1],
alpha = inputs[2], beta = inputs[3];
auto tensor_to_scalar = [](const TensorPtr& tensor) -> float {


Loading…
Cancel
Save