@@ -14,7 +14,6 @@ import sys
 import time
 import numpy as np
-from resnet50 import Resnet50
 import megengine as mge
 import megengine.distributed as dist
@@ -70,6 +69,9 @@ def run_perf(
     eager=False,
 ):
+    # pylint: disable = import-outside-toplevel
+    from resnet50 import Resnet50
+
     if conv_fastrun:
         set_conv_execution_strategy("PROFILE")
@@ -1,141 +0,0 @@
-# -*- coding: utf-8 -*-
-# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
-#
-# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-import os
-import subprocess
-import sys
-
-import numpy as np
-
-
-def fwd_test(backend):
-    model_path = "../examples/cifar10/resnet_example/checkpoint/pretrained_model_82.mge"
-    # Change the reference number if the change is from numerical rounding-off
-    # FIXME! Need to use different number depending on CPU/GPU
-    loss_ref = np.array([7.315978]).astype(np.float32)
-
-    if backend == "megengine-dynamic":
-        os.environ["MGE_DISABLE_TRACE"] = "true"
-    import megengine
-    from megengine.functional.debug_param import set_conv_execution_strategy
-    from megengine.test import assertTensorClose
-    from megengine.core import Graph
-
-    sys.path.append(
-        os.path.join(os.path.dirname(__file__), "..", "..", "..", "examples")
-    )
-    from cifar10.resnet_example.main import Example as resnet18_config
-    from cifar10.resnet_example.main import eval_one_iter_mge
-
-    mge_root = os.path.dirname(megengine.__file__)
-    model_path = os.path.join(mge_root, model_path)
-
-    run_case = resnet18_config(backend=backend, mode="eval")
-    run_case.init_net()
-    run_case.load_model(model_path)
-
-    np.random.seed(0)
-    inputs = np.random.rand(run_case.train_batch_size, 3, 32, 32)
-    targets = np.random.randint(10, size=(run_case.train_batch_size,))
-    max_err = 0.0
-
-    run_case.net_context["net"].eval()
-    loss, _ = eval_one_iter_mge(inputs, targets, config=run_case)
-    try:
-        loss = loss.numpy()
-        assertTensorClose(loss, loss_ref, max_err=max_err)
-    except:
-        print("calculated loss:", loss)
-        print("expect:", loss_ref)
-        sys.exit(1)
-
-
-def train_test(backend):
-    model_path = "../examples/cifar10/resnet_example/checkpoint/pretrained_model_82.mge"
-    # Change the reference number if the change is from numerical rounding-off
-    # FIXME! Need to use different number depending on CPU/GPU
-    if backend == "megengine-dynamic":
-        os.environ["MGE_DISABLE_TRACE"] = "true"
-        loss_ref = np.array([3.4709125, 12.46342]).astype(np.float32)
-    else:
-        loss_ref = np.array([3.4709125, 12.463419]).astype(np.float32)
-    import megengine
-    from megengine.functional.debug_param import set_conv_execution_strategy
-    from megengine.test import assertTensorClose
-    from megengine.core import Graph
-
-    sys.path.append(
-        os.path.join(os.path.dirname(__file__), "..", "..", "..", "examples")
-    )
-    from cifar10.resnet_example.main import Example as resnet18_config
-    from cifar10.resnet_example.main import train_one_iter_mge
-
-    mge_root = os.path.dirname(megengine.__file__)
-    model_path = os.path.join(mge_root, model_path)
-    set_conv_execution_strategy("HEURISTIC_REPRODUCIBLE")
-
-    run_case = resnet18_config(backend=backend, mode="train")
-    run_case.init_net()
-    run_case.load_model(model_path)
-    max_err = 0.0
-
-    loss = []
-    np.random.seed(0)
-    inputs = np.random.rand(run_case.train_batch_size, 3, 32, 32)
-    targets = np.random.randint(10, size=(run_case.train_batch_size,))
-
-    run_case.set_optimizer(0.0)
-    opt = run_case.net_context["optimizer"]
-    for lr in (1.0, 1.0):
-        run_case.set_optimizer(lr)
-        opt.zero_grad()
-        loss_batch, _ = train_one_iter_mge(inputs, targets, config=run_case)
-        opt.step()
-        loss.append(loss_batch.numpy()[0])
-    try:
-        assertTensorClose(np.array(loss).astype(np.float32), loss_ref, max_err=1e-5)
-    except:
-        print("calculated loss:", loss)
-        print("expect:", loss_ref)
-        sys.exit(1)
-
-
-def run_func(func):
-    cmd_start = ["python3", "-c"]
-    cmd_head = "from verify_correctness import fwd_test, train_test\n"
-    cmd = cmd_start + [cmd_head + func]
-    ret = subprocess.run(
-        cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True
-    )
-    if ret.returncode != 0:
-        print("Failed!!!")
-        print(ret.stdout)
-        print(ret.stderr)
-        raise
-    print("Success")
-
-
-if __name__ == "__main__":
-    print("Running fwd static ...")
-    run_func('fwd_test(backend="megengine-static")')
-    print("Running fwd dynamic ...")
-    run_func('fwd_test(backend="megengine-dynamic")')
-    print("Running train static ...")
-    run_func('train_test(backend="megengine-static")')
-    print("Running train dynamic ...")
-    run_func('train_test(backend="megengine-dynamic")')
@@ -0,0 +1,142 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import os
+import sys
+
+import numpy as np
+
+import megengine as mge
+import megengine.functional as F
+from megengine import jit, tensor
+from megengine.functional.debug_param import set_conv_execution_strategy
+from megengine.module import BatchNorm2d, Conv2d, Linear, MaxPool2d, Module
+from megengine.optimizer import SGD
+from megengine.test import assertTensorClose
+
+
+class MnistNet(Module):
+    def __init__(self, has_bn=False):
+        super().__init__()
+        self.conv0 = Conv2d(1, 20, kernel_size=5, bias=True)
+        self.pool0 = MaxPool2d(2)
+        self.conv1 = Conv2d(20, 20, kernel_size=5, bias=True)
+        self.pool1 = MaxPool2d(2)
+        self.fc0 = Linear(20 * 4 * 4, 500, bias=True)
+        self.fc1 = Linear(500, 10, bias=True)
+        self.bn0 = None
+        self.bn1 = None
+        if has_bn:
+            self.bn0 = BatchNorm2d(20)
+            self.bn1 = BatchNorm2d(20)
+
+    def forward(self, x):
+        x = self.conv0(x)
+        if self.bn0:
+            x = self.bn0(x)
+        x = F.relu(x)
+        x = self.pool0(x)
+        x = self.conv1(x)
+        if self.bn1:
+            x = self.bn1(x)
+        x = F.relu(x)
+        x = self.pool1(x)
+        x = F.flatten(x, 1)
+        x = self.fc0(x)
+        x = F.relu(x)
+        x = self.fc1(x)
+        return x
+
+
+def train(data, label, net, opt):
+    pred = net(data)
+    loss = F.cross_entropy_with_softmax(pred, label)
+    opt.backward(loss)
+    return loss
+
+
+def update_model(model_path):
+    """
+    Update the dumped model with test cases for new reference values.
+    """
+    net = MnistNet(has_bn=True)
+    checkpoint = mge.load(model_path)
+    net.load_state_dict(checkpoint["net_init"])
+    lr = checkpoint["sgd_lr"]
+    opt = SGD(net.parameters(), lr=lr)
+
+    data = tensor(dtype=np.float32)
+    label = tensor(dtype=np.int32)
+    data.set_value(checkpoint["data"])
+    label.set_value(checkpoint["label"])
+
+    opt.zero_grad()
+    loss = train(data, label, net=net, opt=opt)
+    opt.step()
+
+    checkpoint.update({"net_updated": net.state_dict(), "loss": loss.numpy()})
+    mge.save(checkpoint, model_path)
+
+
+def run_test(model_path, use_jit, use_symbolic):
""" | |||||
Load the model with test cases and run the training for one iter. | |||||
The loss and updated weights are compared with reference value to verify the correctness. | |||||
The model with pre-trained weights is trained for one iter and the net state dict is dumped. | |||||
The test cases is appended to the model file. The reference result is obtained | |||||
by running the train for one iter. | |||||
Dump a new file with updated result by calling update_model | |||||
if you think the test fails due to numerical rounding errors instead of bugs. | |||||
Please think twice before you do so. | |||||
""" | |||||
+    net = MnistNet(has_bn=True)
+    checkpoint = mge.load(model_path)
+    net.load_state_dict(checkpoint["net_init"])
+    lr = checkpoint["sgd_lr"]
+    opt = SGD(net.parameters(), lr=lr)
+
+    data = tensor(dtype=np.float32)
+    label = tensor(dtype=np.int32)
+    data.set_value(checkpoint["data"])
+    label.set_value(checkpoint["label"])
+
+    max_err = 0.0
+    train_func = train
+    if use_jit:
+        train_func = jit.trace(train_func, symbolic=use_symbolic)
+
+    opt.zero_grad()
+    loss = train_func(data, label, net=net, opt=opt)
+    opt.step()
+
+    assertTensorClose(loss.numpy(), checkpoint["loss"], max_err=max_err)
+    for param, param_ref in zip(
+        net.state_dict().items(), checkpoint["net_updated"].items()
+    ):
+        assert param[0] == param_ref[0]
+        assertTensorClose(param[1], param_ref[1], max_err=max_err)
+
+
+def test_correctness():
+    if mge.is_cuda_available():
+        model_name = "mnist_model_with_test.mge"
+    else:
+        model_name = "mnist_model_with_test_cpu.mge"
+    model_path = os.path.join(os.path.dirname(__file__), model_name)
+    set_conv_execution_strategy("HEURISTIC_REPRODUCIBLE")
+
+    run_test(model_path, False, False)
+    run_test(model_path, True, False)
+    run_test(model_path, True, True)
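
The three `run_test` calls exercise the same training step eagerly, traced with `symbolic=False`, and traced with `symbolic=True`, all checked against one set of stored references. A minimal sketch of those three modes using random data in place of the checkpoint; the batch size and learning rate are illustrative, and the 28x28 input is chosen so MnistNet's `20 * 4 * 4` flatten holds:

    net = MnistNet(has_bn=True)
    opt = SGD(net.parameters(), lr=0.01)

    data = tensor(dtype=np.float32)
    label = tensor(dtype=np.int32)
    data.set_value(np.random.rand(4, 1, 28, 28).astype(np.float32))
    label.set_value(np.random.randint(10, size=(4,)).astype(np.int32))

    for use_jit, use_symbolic in ((False, False), (True, False), (True, True)):
        # jit.trace wraps the step function; symbolic=True builds a static graph.
        train_func = jit.trace(train, symbolic=use_symbolic) if use_jit else train
        opt.zero_grad()
        loss = train_func(data, label, net=net, opt=opt)
        opt.step()
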
@@ -94,9 +94,7 @@ def test_pytorch_mixed():
         def __init__(self):
             super().__init__()
             self.torch_module = PyTorchModule(self.SubModule())
-            a = list(self.SubModule().named_parameters(recurse=True))
-            a = list(self.SubModule().parameters())
-            self.multiplier = Parameter(np.array(init_param[1]), dtype=np.float32)
+            self.multiplier = Parameter(init_param[1], dtype=np.float32)

         def forward(self, inp):
             return self.torch_module(inp) * self.multiplier
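
The dropped `np.array(...)` wrapper was redundant: this change relies on `Parameter` accepting an array-like directly and casting it via the `dtype` argument. A tiny sketch under that assumption, with illustrative values:

    import numpy as np
    from megengine import Parameter

    # A plain Python list is converted internally; dtype fixes the precision.
    m = Parameter([2.0], dtype=np.float32)
    assert m.numpy().dtype == np.float32
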