diff --git a/python_module/test/integration/manual/resnet50_perf.py b/python_module/test/integration/manual/resnet50_perf.py index 0f193e7d..9f36e8f0 100644 --- a/python_module/test/integration/manual/resnet50_perf.py +++ b/python_module/test/integration/manual/resnet50_perf.py @@ -14,7 +14,6 @@ import sys import time import numpy as np -from resnet50 import Resnet50 import megengine as mge import megengine.distributed as dist @@ -70,6 +69,9 @@ def run_perf( eager=False, ): + # pylint: disable = import-outside-toplevel + from resnet50 import Resnet50 + if conv_fastrun: set_conv_execution_strategy("PROFILE") diff --git a/python_module/test/integration/manual/verify_correctness.py b/python_module/test/integration/manual/verify_correctness.py deleted file mode 100644 index 2efaa1fe..00000000 --- a/python_module/test/integration/manual/verify_correctness.py +++ /dev/null @@ -1,141 +0,0 @@ -# -*- coding: utf-8 -*- -# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") -# -# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -import os -import subprocess -import sys - -import numpy as np - - -def fwd_test(backend): - - model_path = "../examples/cifar10/resnet_example/checkpoint/pretrained_model_82.mge" - - # Change the reference number if the change is from numerical rounding-off - # FIXME! 
Need to use different number depending on CPU/GPU - loss_ref = np.array([7.315978]).astype(np.float32) - - if backend == "megengine-dynamic": - os.environ["MGE_DISABLE_TRACE"] = "true" - - import megengine - from megengine.functional.debug_param import set_conv_execution_strategy - from megengine.test import assertTensorClose - from megengine.core import Graph - - sys.path.append( - os.path.join(os.path.dirname(__file__), "..", "..", "..", "examples") - ) - from cifar10.resnet_example.main import Example as resnet18_config - from cifar10.resnet_example.main import eval_one_iter_mge - - mge_root = os.path.dirname(megengine.__file__) - model_path = os.path.join(mge_root, model_path) - run_case = resnet18_config(backend=backend, mode="eval") - run_case.init_net() - run_case.load_model(model_path) - - np.random.seed(0) - inputs = np.random.rand(run_case.train_batch_size, 3, 32, 32) - targets = np.random.randint(10, size=(run_case.train_batch_size,)) - max_err = 0.0 - - run_case.net_context["net"].eval() - loss, _ = eval_one_iter_mge(inputs, targets, config=run_case) - try: - loss = loss.numpy() - assertTensorClose(loss, loss_ref, max_err=max_err) - except: - print("calculated loss:", loss) - print("expect:", loss_ref) - sys.exit(1) - - -def train_test(backend): - - model_path = "../examples/cifar10/resnet_example/checkpoint/pretrained_model_82.mge" - - # Change the reference number if the change is from numerical rounding-off - # FIXME! 
Need to use different number depending on CPU/GPU - if backend == "megengine-dynamic": - os.environ["MGE_DISABLE_TRACE"] = "true" - loss_ref = np.array([3.4709125, 12.46342]).astype(np.float32) - else: - loss_ref = np.array([3.4709125, 12.463419]).astype(np.float32) - - import megengine - from megengine.functional.debug_param import set_conv_execution_strategy - from megengine.test import assertTensorClose - from megengine.core import Graph - - sys.path.append( - os.path.join(os.path.dirname(__file__), "..", "..", "..", "examples") - ) - from cifar10.resnet_example.main import Example as resnet18_config - from cifar10.resnet_example.main import train_one_iter_mge - - mge_root = os.path.dirname(megengine.__file__) - model_path = os.path.join(mge_root, model_path) - set_conv_execution_strategy("HEURISTIC_REPRODUCIBLE") - run_case = resnet18_config(backend=backend, mode="train") - run_case.init_net() - run_case.load_model(model_path) - - max_err = 0.0 - - loss = [] - np.random.seed(0) - inputs = np.random.rand(run_case.train_batch_size, 3, 32, 32) - targets = np.random.randint(10, size=(run_case.train_batch_size,)) - - run_case.set_optimizer(0.0) - opt = run_case.net_context["optimizer"] - - for lr in (1.0, 1.0): - run_case.set_optimizer(lr) - opt.zero_grad() - loss_batch, _ = train_one_iter_mge(inputs, targets, config=run_case) - opt.step() - loss.append(loss_batch.numpy()[0]) - try: - assertTensorClose(np.array(loss).astype(np.float32), loss_ref, max_err=1e-5) - except: - print("calculated loss:", loss) - print("expect:", loss_ref) - sys.exit(1) - - -def run_func(func): - cmd_start = ["python3", "-c"] - cmd_head = "from verify_correctness import fwd_test, train_test\n" - cmd = cmd_start + [cmd_head + func] - ret = subprocess.run( - cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True - ) - if ret.returncode != 0: - print("Failed!!!") - print(ret.stdout) - print(ret.stderr) - raise - print("Success") - - -if __name__ == "__main__": - - 
print("Running fwd static ...") - run_func('fwd_test(backend="megengine-static")') - - print("Running fwd dynamic ...") - run_func('fwd_test(backend="megengine-dynamic")') - - print("Running train static ...") - run_func('train_test(backend="megengine-static")') - - print("Running train dynamic ...") - run_func('train_test(backend="megengine-dynamic")') diff --git a/python_module/test/integration/mnist_model_with_test.mge b/python_module/test/integration/mnist_model_with_test.mge new file mode 100644 index 00000000..9cf1e20a Binary files /dev/null and b/python_module/test/integration/mnist_model_with_test.mge differ diff --git a/python_module/test/integration/mnist_model_with_test_cpu.mge b/python_module/test/integration/mnist_model_with_test_cpu.mge new file mode 100644 index 00000000..5f87e566 Binary files /dev/null and b/python_module/test/integration/mnist_model_with_test_cpu.mge differ diff --git a/python_module/test/integration/test_correctness.py b/python_module/test/integration/test_correctness.py new file mode 100644 index 00000000..5ea6c5ed --- /dev/null +++ b/python_module/test/integration/test_correctness.py @@ -0,0 +1,142 @@ +# -*- coding: utf-8 -*- +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+import os +import sys + +import numpy as np + +import megengine as mge +import megengine.functional as F +from megengine import jit, tensor +from megengine.functional.debug_param import set_conv_execution_strategy +from megengine.module import BatchNorm2d, Conv2d, Linear, MaxPool2d, Module +from megengine.optimizer import SGD +from megengine.test import assertTensorClose + + +class MnistNet(Module): + def __init__(self, has_bn=False): + super().__init__() + self.conv0 = Conv2d(1, 20, kernel_size=5, bias=True) + self.pool0 = MaxPool2d(2) + self.conv1 = Conv2d(20, 20, kernel_size=5, bias=True) + self.pool1 = MaxPool2d(2) + self.fc0 = Linear(20 * 4 * 4, 500, bias=True) + self.fc1 = Linear(500, 10, bias=True) + self.bn0 = None + self.bn1 = None + if has_bn: + self.bn0 = BatchNorm2d(20) + self.bn1 = BatchNorm2d(20) + + def forward(self, x): + x = self.conv0(x) + if self.bn0: + x = self.bn0(x) + x = F.relu(x) + x = self.pool0(x) + x = self.conv1(x) + if self.bn1: + x = self.bn1(x) + x = F.relu(x) + x = self.pool1(x) + x = F.flatten(x, 1) + x = self.fc0(x) + x = F.relu(x) + x = self.fc1(x) + return x + + +def train(data, label, net, opt): + + pred = net(data) + loss = F.cross_entropy_with_softmax(pred, label) + opt.backward(loss) + return loss + + +def update_model(model_path): + """ + Run one training iteration and save the new reference loss and weights back to model_path + """ + net = MnistNet(has_bn=True) + checkpoint = mge.load(model_path) + net.load_state_dict(checkpoint["net_init"]) + lr = checkpoint["sgd_lr"] + opt = SGD(net.parameters(), lr=lr) + + data = tensor(dtype=np.float32) + label = tensor(dtype=np.int32) + data.set_value(checkpoint["data"]) + label.set_value(checkpoint["label"]) + + opt.zero_grad() + loss = train(data, label, net=net, opt=opt) + opt.step() + + checkpoint.update({"net_updated": net.state_dict(), "loss": loss.numpy()}) + mge.save(checkpoint, model_path) + + +def run_test(model_path, use_jit, use_symbolic): + + """ + Load the model with test cases and run the 
training for one iter. + The loss and updated weights are compared with reference values to verify correctness. + The model with pre-trained weights is trained for one iter and the net state dict is dumped. + The test cases are appended to the model file. The reference result is obtained + by running the training for one iter. + + Dump a new file with updated results by calling update_model + if you think the test fails due to numerical rounding errors instead of bugs. + Please think twice before you do so. + + """ + net = MnistNet(has_bn=True) + checkpoint = mge.load(model_path) + net.load_state_dict(checkpoint["net_init"]) + lr = checkpoint["sgd_lr"] + opt = SGD(net.parameters(), lr=lr) + + data = tensor(dtype=np.float32) + label = tensor(dtype=np.int32) + data.set_value(checkpoint["data"]) + label.set_value(checkpoint["label"]) + + max_err = 0.0 + + train_func = train + if use_jit: + train_func = jit.trace(train_func, symbolic=use_symbolic) + + opt.zero_grad() + loss = train_func(data, label, net=net, opt=opt) + opt.step() + + assertTensorClose(loss.numpy(), checkpoint["loss"], max_err=max_err) + + for param, param_ref in zip( + net.state_dict().items(), checkpoint["net_updated"].items() + ): + assert param[0] == param_ref[0] + assertTensorClose(param[1], param_ref[1], max_err=max_err) + + +def test_correctness(): + + if mge.is_cuda_available(): + model_name = "mnist_model_with_test.mge" + else: + model_name = "mnist_model_with_test_cpu.mge" + model_path = os.path.join(os.path.dirname(__file__), model_name) + set_conv_execution_strategy("HEURISTIC_REPRODUCIBLE") + + run_test(model_path, False, False) + run_test(model_path, True, False) + run_test(model_path, True, True) diff --git a/python_module/test/unit/module/test_pytorch.py b/python_module/test/unit/module/test_pytorch.py index 105bb62f..d7b3ae9a 100644 --- a/python_module/test/unit/module/test_pytorch.py +++ b/python_module/test/unit/module/test_pytorch.py @@ -94,9 +94,7 @@ def test_pytorch_mixed(): def 
__init__(self): super().__init__() self.torch_module = PyTorchModule(self.SubModule()) - a = list(self.SubModule().named_parameters(recurse=True)) - a = list(self.SubModule().parameters()) - self.multiplier = Parameter(np.array(init_param[1]), dtype=np.float32) + self.multiplier = Parameter(init_param[1], dtype=np.float32) def forward(self, inp): return self.torch_module(inp) * self.multiplier