Browse Source

fix(imperative): fix inplace operation of optim

GitOrigin-RevId: 2aaa71eb66
HuaHua404-patch-4
Megvii Engine Team 3 years ago
parent
commit
f386381080
2 changed files with 43 additions and 1 deletions
  1. +1
    -1
      imperative/python/megengine/optimizer/optimizer.py
  2. +42
    -0
      imperative/python/test/unit/amp/test_convert_format.py

+ 1
- 1
imperative/python/megengine/optimizer/optimizer.py View File

@@ -97,7 +97,7 @@ class Optimizer(metaclass=ABCMeta):
"optimizer can only optimize Parameters, but one of the params is "
+ str(type(param))
)
param[...] = Tensor(param.numpy(), no_cache=True)
param[...] = Tensor(param, no_cache=True)


for name, default in self._defaults.items():
    if default is required and name not in param_group:


+ 42
- 0
imperative/python/test/unit/amp/test_convert_format.py View File

@@ -1,8 +1,11 @@
import numpy as np
import pytest


import megengine as mge
import megengine.autodiff as autodiff
import megengine.functional as F
import megengine.module as M
import megengine.optimizer as optim
from megengine import Parameter, Tensor, amp
from megengine.core._config import set_auto_format_convert
from megengine.core._trace_option import use_symbolic_shape
@@ -57,3 +60,42 @@ def test_convert_module(is_inplace):
)
else:
    assert param.shape == expected_shape[name], name


class Module(M.Module):
    """Minimal conv -> batch-norm -> ReLU network used by the format test."""

    def __init__(self):
        super().__init__()
        # 3-channel input to 16 feature maps; 3x3 kernel with padding=1
        # keeps the spatial size unchanged. Bias is omitted since BN follows.
        self.conv = M.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn = M.BatchNorm2d(16)

    def forward(self, x):
        """Run conv, then batch norm, then ReLU, and return the activation."""
        return F.relu(self.bn(self.conv(x)))


def test_format_remained():
    """Regression test: the nhwc tensor format must survive repeated
    optimizer steps (optimizer used to clobber it via its in-place
    parameter rewrite)."""
    net = amp.convert_module_format(Module())

    grad_manager = autodiff.GradManager().attach(net.parameters())
    sgd = optim.SGD(net.parameters(), lr=0.01)
    grad_scaler = amp.GradScaler()

    # One fake image batch and a dense segmentation-style label.
    batch = mge.tensor(np.random.normal(size=(1, 3, 224, 224)), dtype="float32")
    target = mge.tensor(np.ones((1, 224, 224)), dtype="int32")
    batch = amp.convert_tensor_format(batch)

    @amp.autocast(enabled=True)
    def run_step(inp):
        with grad_manager:
            logits = net(inp)
            loss = F.nn.cross_entropy(logits, target)
            grad_scaler.backward(grad_manager, loss)
        sgd.step().clear_grad()
        return logits

    # Several iterations: the format must still be nhwc after each update.
    for _ in range(5):
        out = run_step(batch)
        assert out.format == "nhwc"

Loading…
Cancel
Save