|
|
@@ -16,6 +16,8 @@ import pytest |
|
|
|
|
|
|
|
import megengine as mge |
|
|
|
import megengine.functional as F |
|
|
|
from megengine import jit |
|
|
|
from megengine.core._trace_option import set_tensor_shape |
|
|
|
from megengine.functional.debug_param import set_conv_execution_strategy |
|
|
|
from megengine.module import AvgPool2d, BatchNorm2d, Conv2d, Linear, Module |
|
|
|
from megengine.optimizer import SGD |
|
|
@@ -129,7 +131,7 @@ def update_model(model_path): |
|
|
|
mge.save(checkpoint, model_path) |
|
|
|
|
|
|
|
|
|
|
|
def run_test( |
|
|
|
def run_train( |
|
|
|
model_path, use_jit, use_symbolic, sublinear_memory_config=None, max_err=None, |
|
|
|
): |
|
|
|
|
|
|
@@ -175,6 +177,37 @@ def run_test( |
|
|
|
assertTensorClose(param[1], param_ref[1], max_err=max_err) |
|
|
|
|
|
|
|
|
|
|
|
def run_eval( |
|
|
|
model_path, use_symbolic, sublinear_memory_config=None, max_err=None, |
|
|
|
): |
|
|
|
|
|
|
|
""" |
|
|
|
Load the model with test cases and run the training for one iter. |
|
|
|
The loss and updated weights are compared with reference value to verify the correctness. |
|
|
|
|
|
|
|
Dump a new file with updated result by calling update_model |
|
|
|
if you think the test fails due to numerical rounding errors instead of bugs. |
|
|
|
Please think twice before you do so. |
|
|
|
|
|
|
|
""" |
|
|
|
net = MnistNet(has_bn=True) |
|
|
|
checkpoint = mge.load(model_path) |
|
|
|
net.load_state_dict(checkpoint["net_init"]) |
|
|
|
|
|
|
|
data = Tensor(checkpoint["data"], dtype=np.float32) |
|
|
|
|
|
|
|
def eval_fun(data, *, net=None): |
|
|
|
pred = net(data) |
|
|
|
return pred |
|
|
|
|
|
|
|
refer_value = eval_fun(data, net=net) |
|
|
|
eval_fun = jit.trace(eval_fun, symbolic=use_symbolic) |
|
|
|
|
|
|
|
for _ in range(3): |
|
|
|
new_value = eval_fun(data, net=net) |
|
|
|
assertTensorClose(new_value.numpy(), refer_value.numpy(), max_err=max_err) |
|
|
|
|
|
|
|
|
|
|
|
def test_correctness(): |
|
|
|
if mge.is_cuda_available(): |
|
|
|
model_name = "mnist_model_with_test.mge" |
|
|
@@ -183,7 +216,7 @@ def test_correctness(): |
|
|
|
model_path = os.path.join(os.path.dirname(__file__), model_name) |
|
|
|
set_conv_execution_strategy("HEURISTIC_REPRODUCIBLE") |
|
|
|
|
|
|
|
run_test(model_path, False, False, max_err=1e-5) |
|
|
|
run_train(model_path, False, False, max_err=1e-5) |
|
|
|
# run_test(model_path, True, False) |
|
|
|
# run_test(model_path, True, True) |
|
|
|
|
|
|
@@ -192,3 +225,6 @@ def test_correctness(): |
|
|
|
# run_test( |
|
|
|
# model_path, True, True, sublinear_memory_config=config, max_err=1e-5, |
|
|
|
# ) |
|
|
|
|
|
|
|
run_eval(model_path, False, max_err=1e-7) |
|
|
|
# run_eval(model_path, True, max_err=1e-7) # XXX: fix me |