Browse Source

feat(mge/imperative): add more trace test

GitOrigin-RevId: b02e420a8a
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
parent
commit
afddefb677
2 changed files with 40 additions and 4 deletions
  1. +38
    -2
      imperative/python/test/integration/test_correctness.py
  2. +2
    -2
      imperative/src/test/opr_utility.cpp

+ 38
- 2
imperative/python/test/integration/test_correctness.py View File

@@ -16,6 +16,8 @@ import pytest

import megengine as mge
import megengine.functional as F
from megengine import jit
from megengine.core._trace_option import set_tensor_shape
from megengine.functional.debug_param import set_conv_execution_strategy
from megengine.module import AvgPool2d, BatchNorm2d, Conv2d, Linear, Module
from megengine.optimizer import SGD
@@ -129,7 +131,7 @@ def update_model(model_path):
mge.save(checkpoint, model_path)


def run_test(
def run_train(
model_path, use_jit, use_symbolic, sublinear_memory_config=None, max_err=None,
):

@@ -175,6 +177,37 @@ def run_test(
assertTensorClose(param[1], param_ref[1], max_err=max_err)


def run_eval(
model_path, use_symbolic, sublinear_memory_config=None, max_err=None,
):

"""
Load the model with test cases and run the training for one iter.
The loss and updated weights are compared with reference value to verify the correctness.

Dump a new file with updated result by calling update_model
if you think the test fails due to numerical rounding errors instead of bugs.
Please think twice before you do so.

"""
net = MnistNet(has_bn=True)
checkpoint = mge.load(model_path)
net.load_state_dict(checkpoint["net_init"])

data = Tensor(checkpoint["data"], dtype=np.float32)

def eval_fun(data, *, net=None):
pred = net(data)
return pred

refer_value = eval_fun(data, net=net)
eval_fun = jit.trace(eval_fun, symbolic=use_symbolic)

for _ in range(3):
new_value = eval_fun(data, net=net)
assertTensorClose(new_value.numpy(), refer_value.numpy(), max_err=max_err)


def test_correctness():
if mge.is_cuda_available():
model_name = "mnist_model_with_test.mge"
@@ -183,7 +216,7 @@ def test_correctness():
model_path = os.path.join(os.path.dirname(__file__), model_name)
set_conv_execution_strategy("HEURISTIC_REPRODUCIBLE")

run_test(model_path, False, False, max_err=1e-5)
run_train(model_path, False, False, max_err=1e-5)
# run_test(model_path, True, False)
# run_test(model_path, True, True)

@@ -192,3 +225,6 @@ def test_correctness():
# run_test(
# model_path, True, True, sublinear_memory_config=config, max_err=1e-5,
# )

run_eval(model_path, False, max_err=1e-7)
# run_eval(model_path, True, max_err=1e-7) # XXX: fix me

+ 2
- 2
imperative/src/test/opr_utility.cpp View File

@@ -25,7 +25,7 @@ TEST(TestOprUtility, InputCallback) {
dv.copy_from(*hv).sync();
auto graph = ComputingGraph::make();
auto callback = [dv]() {return dv;};
auto outputs = opr::InputCallback::make(*graph, callback, dv.comp_node(), dv.dtype());
auto outputs = opr::InputCallback::make(*graph, callback, dv.comp_node(), dv.dtype(), {2, 3});

HostTensorND hout;
ComputingGraph::OutputSpec outspec{make_callback_copy(outputs[0], hout)};
@@ -99,7 +99,7 @@ TEST(TestOprUtility, CallbackChain) {
dev_x.storage({});
return ret;
};
auto out = opr::InputCallback::make(*graph, callback, cn, dev_x.dtype());
auto out = opr::InputCallback::make(*graph, callback, cn, dev_x.dtype(), {2, 3});
x = out[0];
dummy = out[1];
}


Loading…
Cancel
Save