# -*- coding: utf-8 -*- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") # # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import os import sys import numpy as np import megengine as mge import megengine.functional as F from megengine import jit, tensor from megengine.functional.debug_param import set_conv_execution_strategy from megengine.module import BatchNorm2d, Conv2d, Linear, MaxPool2d, Module from megengine.optimizer import SGD from megengine.test import assertTensorClose class MnistNet(Module): def __init__(self, has_bn=False): super().__init__() self.conv0 = Conv2d(1, 20, kernel_size=5, bias=True) self.pool0 = MaxPool2d(2) self.conv1 = Conv2d(20, 20, kernel_size=5, bias=True) self.pool1 = MaxPool2d(2) self.fc0 = Linear(20 * 4 * 4, 500, bias=True) self.fc1 = Linear(500, 10, bias=True) self.bn0 = None self.bn1 = None if has_bn: self.bn0 = BatchNorm2d(20) self.bn1 = BatchNorm2d(20) def forward(self, x): x = self.conv0(x) if self.bn0: x = self.bn0(x) x = F.relu(x) x = self.pool0(x) x = self.conv1(x) if self.bn1: x = self.bn1(x) x = F.relu(x) x = self.pool1(x) x = F.flatten(x, 1) x = self.fc0(x) x = F.relu(x) x = self.fc1(x) return x def train(data, label, net, opt): pred = net(data) loss = F.cross_entropy_with_softmax(pred, label) opt.backward(loss) return loss def update_model(model_path): """ Update the dumped model with test cases for new reference values """ net = MnistNet(has_bn=True) checkpoint = mge.load(model_path) net.load_state_dict(checkpoint["net_init"]) lr = checkpoint["sgd_lr"] opt = SGD(net.parameters(), lr=lr) data = tensor(dtype=np.float32) label = tensor(dtype=np.int32) data.set_value(checkpoint["data"]) label.set_value(checkpoint["label"]) opt.zero_grad() loss = train(data, label, net=net, opt=opt) opt.step() checkpoint.update({"net_updated": net.state_dict(), "loss": loss.numpy()}) mge.save(checkpoint, model_path) def run_test(model_path, use_jit, use_symbolic): """ Load the model with test cases and run the training for one iter. The loss and updated weights are compared with reference value to verify the correctness. The model with pre-trained weights is trained for one iter and the net state dict is dumped. The test cases is appended to the model file. The reference result is obtained by running the train for one iter. Dump a new file with updated result by calling update_model if you think the test fails due to numerical rounding errors instead of bugs. Please think twice before you do so. """ net = MnistNet(has_bn=True) checkpoint = mge.load(model_path) net.load_state_dict(checkpoint["net_init"]) lr = checkpoint["sgd_lr"] opt = SGD(net.parameters(), lr=lr) data = tensor(dtype=np.float32) label = tensor(dtype=np.int32) data.set_value(checkpoint["data"]) label.set_value(checkpoint["label"]) max_err = 0.0 train_func = train if use_jit: train_func = jit.trace(train_func, symbolic=use_symbolic) opt.zero_grad() loss = train_func(data, label, net=net, opt=opt) opt.step() assertTensorClose(loss.numpy(), checkpoint["loss"], max_err=max_err) for param, param_ref in zip( net.state_dict().items(), checkpoint["net_updated"].items() ): assert param[0] == param_ref[0] assertTensorClose(param[1], param_ref[1], max_err=max_err) def test_correctness(): if mge.is_cuda_available(): model_name = "mnist_model_with_test.mge" else: model_name = "mnist_model_with_test_cpu.mge" model_path = os.path.join(os.path.dirname(__file__), model_name) set_conv_execution_strategy("HEURISTIC_REPRODUCIBLE") run_test(model_path, False, False) run_test(model_path, True, False) run_test(model_path, True, True)