@@ -315,9 +315,9 @@ class GraphInference: | |||
inputs = get_dep_vars(output_nodes, "Host2DeviceCopy") | |||
self._inp_dict = OrderedDict() | |||
replace_dict = {} | |||
for i in inputs: | |||
for idx, i in enumerate(inputs): | |||
inp_node = G.InputNode( | |||
device="xpux", dtype=inputs[0].dtype, graph=inputs[0].graph | |||
device="xpux", dtype=inputs[idx].dtype, graph=inputs[0].graph | |||
) | |||
self._inp_dict[i.name] = inp_node | |||
replace_dict[i] = inp_node.outputs[0] | |||
@@ -1,13 +1,22 @@ | |||
import io | |||
import numpy as np | |||
import megengine.utils.comp_graph_tools as cgtools | |||
from megengine import tensor | |||
from megengine.jit import trace | |||
def _default_compare_fn(x, y): | |||
np.testing.assert_allclose(x.numpy(), y, rtol=1e-6) | |||
if isinstance(x, np.ndarray): | |||
np.testing.assert_allclose(x, y, rtol=1e-6) | |||
else: | |||
np.testing.assert_allclose(x.numpy(), y, rtol=1e-6) | |||
def opr_test(cases, func, compare_fn=_default_compare_fn, ref_fn=None, **kwargs): | |||
def opr_test( | |||
cases, func, compare_fn=_default_compare_fn, ref_fn=None, test_trace=True, **kwargs | |||
): | |||
""" | |||
:param cases: the list which have dict element, the list length should be 2 for dynamic shape test. | |||
and the dict should have input, | |||
@@ -35,6 +44,8 @@ def opr_test(cases, func, compare_fn=_default_compare_fn, ref_fn=None, **kwargs) | |||
if not isinstance(results, (tuple, list)): | |||
results = (results,) | |||
for r, e in zip(results, expected): | |||
if not isinstance(r, tensor): | |||
r = tensor(r) | |||
compare_fn(r, e) | |||
def get_param(cases, idx): | |||
@@ -63,5 +74,36 @@ def opr_test(cases, func, compare_fn=_default_compare_fn, ref_fn=None, **kwargs) | |||
inp, outp = get_param(cases, 0) | |||
inp_tensor = [tensor(inpi) for inpi in inp] | |||
if test_trace: | |||
copied_inp = inp_tensor.copy() | |||
for symbolic in [False, True]: | |||
traced_func = trace(symbolic=symbolic)(func) | |||
for _ in range(3): | |||
traced_results = traced_func(*copied_inp, **kwargs) | |||
check_results(traced_results, outp) | |||
dumped_func = trace(symbolic=True, capture_as_const=True)(func) | |||
dumped_results = dumped_func(*copied_inp, **kwargs) | |||
check_results(dumped_results, outp) | |||
file = io.BytesIO() | |||
dump_info = dumped_func.dump(file) | |||
file.seek(0) | |||
# arg_name has pattern arg_xxx, xxx is int value | |||
def take_number(arg_name): | |||
return int(arg_name.split("_")[-1]) | |||
input_names = dump_info[4] | |||
inps_np = [i.numpy() for i in copied_inp] | |||
input_names.sort(key=take_number) | |||
inp_dict = dict(zip(input_names, inps_np)) | |||
infer_cg = cgtools.GraphInference(file) | |||
# assume #outputs == 1 | |||
loaded_results = list(infer_cg.run(inp_dict=inp_dict).values())[0] | |||
check_results(loaded_results, outp) | |||
results = func(*inp_tensor, **kwargs) | |||
check_results(results, outp) |
@@ -36,7 +36,7 @@ def test_where(): | |||
{"input": [maskv0, xv0, yv0]}, | |||
{"input": [maskv1, xv1, yv1]}, | |||
] | |||
opr_test(cases, F.where, ref_fn=np.where) | |||
opr_test(cases, F.where, ref_fn=np.where, test_trace=False) | |||
maskv2 = np.array([1, 1, 1], dtype=np.bool_) | |||
xv2 = np.array([1, 3, 2], dtype=np.float32) | |||
@@ -50,7 +50,7 @@ def test_where(): | |||
{"input": [maskv2, xv2, yv2]}, | |||
{"input": [maskv3, xv3, yv3]}, | |||
] | |||
opr_test(cases, F.where, ref_fn=np.where) | |||
opr_test(cases, F.where, ref_fn=np.where, test_trace=False) | |||
def test_dropout(): | |||
@@ -115,14 +115,17 @@ def test_matmul(): | |||
{"input": [data4, data5]}, | |||
] | |||
for _ in range(0, batch_size): | |||
# FIXME: remove test_trace=False in the future | |||
opr_test( | |||
cases, F.matmul, ref_fn=np.matmul, | |||
cases, F.matmul, test_trace=False, ref_fn=np.matmul, | |||
) | |||
# FIXME: remove test_trace=False in the future | |||
opr_test( | |||
[{"input": [data1, data4]}], | |||
F.matmul, | |||
ref_fn=lambda x, y: np.matmul(x, y.transpose(0, 1, 3, 2)), | |||
test_trace=False, | |||
transpose_b=True, | |||
) | |||
@@ -162,20 +162,24 @@ def test_linspace(): | |||
{"input": [1, 9, 9]}, | |||
{"input": [3, 10, 8]}, | |||
] | |||
# FIXME: remove test_trace=False in the future | |||
opr_test( | |||
cases, | |||
F.linspace, | |||
ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32), | |||
test_trace=False, | |||
) | |||
cases = [ | |||
{"input": [9, 1, 9]}, | |||
{"input": [10, 3, 8]}, | |||
] | |||
# FIXME: remove test_trace=False in the future | |||
opr_test( | |||
cases, | |||
F.linspace, | |||
ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32), | |||
test_trace=False, | |||
) | |||
@@ -184,30 +188,36 @@ def test_arange(): | |||
{"input": [1, 9, 1]}, | |||
{"input": [2, 10, 2]}, | |||
] | |||
# FIXME: remove test_trace=False in the future | |||
opr_test( | |||
cases, | |||
F.arange, | |||
ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), | |||
test_trace=False, | |||
) | |||
cases = [ | |||
{"input": [9, 1, -1]}, | |||
{"input": [10, 2, -2]}, | |||
] | |||
# FIXME: remove test_trace=False in the future | |||
opr_test( | |||
cases, | |||
F.arange, | |||
ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), | |||
test_trace=False, | |||
) | |||
cases = [ | |||
{"input": [9.3, 1.2, -0.5]}, | |||
{"input": [10.3, 2.1, -1.7]}, | |||
] | |||
# FIXME: remove test_trace=False in the future | |||
opr_test( | |||
cases, | |||
F.arange, | |||
ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), | |||
test_trace=False, | |||
) | |||
@@ -279,7 +289,8 @@ def test_broadcast(): | |||
{"input": [data1, output1_shape], "output": output1_shape}, | |||
{"input": [data2, output2_shape], "output": output2_shape}, | |||
] | |||
opr_test(cases, F.broadcast_to, compare_fn=compare_fn) | |||
# FIXME: remove test_trace=False in the future | |||
opr_test(cases, F.broadcast_to, compare_fn=compare_fn, test_trace=False) | |||
x = F.ones((2, 1, 3)) | |||
with pytest.raises(RuntimeError): | |||