@@ -315,9 +315,9 @@ class GraphInference: | |||||
inputs = get_dep_vars(output_nodes, "Host2DeviceCopy") | inputs = get_dep_vars(output_nodes, "Host2DeviceCopy") | ||||
self._inp_dict = OrderedDict() | self._inp_dict = OrderedDict() | ||||
replace_dict = {} | replace_dict = {} | ||||
for i in inputs: | |||||
for idx, i in enumerate(inputs): | |||||
inp_node = G.InputNode( | inp_node = G.InputNode( | ||||
device="xpux", dtype=inputs[0].dtype, graph=inputs[0].graph | |||||
device="xpux", dtype=inputs[idx].dtype, graph=inputs[0].graph | |||||
) | ) | ||||
self._inp_dict[i.name] = inp_node | self._inp_dict[i.name] = inp_node | ||||
replace_dict[i] = inp_node.outputs[0] | replace_dict[i] = inp_node.outputs[0] | ||||
@@ -1,13 +1,22 @@ | |||||
import io | |||||
import numpy as np | import numpy as np | ||||
import megengine.utils.comp_graph_tools as cgtools | |||||
from megengine import tensor | from megengine import tensor | ||||
from megengine.jit import trace | |||||
def _default_compare_fn(x, y): | def _default_compare_fn(x, y): | ||||
np.testing.assert_allclose(x.numpy(), y, rtol=1e-6) | |||||
if isinstance(x, np.ndarray): | |||||
np.testing.assert_allclose(x, y, rtol=1e-6) | |||||
else: | |||||
np.testing.assert_allclose(x.numpy(), y, rtol=1e-6) | |||||
def opr_test(cases, func, compare_fn=_default_compare_fn, ref_fn=None, **kwargs): | |||||
def opr_test( | |||||
cases, func, compare_fn=_default_compare_fn, ref_fn=None, test_trace=True, **kwargs | |||||
): | |||||
""" | """ | ||||
:param cases: the list which have dict element, the list length should be 2 for dynamic shape test. | :param cases: the list which have dict element, the list length should be 2 for dynamic shape test. | ||||
and the dict should have input, | and the dict should have input, | ||||
@@ -35,6 +44,8 @@ def opr_test(cases, func, compare_fn=_default_compare_fn, ref_fn=None, **kwargs) | |||||
if not isinstance(results, (tuple, list)): | if not isinstance(results, (tuple, list)): | ||||
results = (results,) | results = (results,) | ||||
for r, e in zip(results, expected): | for r, e in zip(results, expected): | ||||
if not isinstance(r, tensor): | |||||
r = tensor(r) | |||||
compare_fn(r, e) | compare_fn(r, e) | ||||
def get_param(cases, idx): | def get_param(cases, idx): | ||||
@@ -63,5 +74,36 @@ def opr_test(cases, func, compare_fn=_default_compare_fn, ref_fn=None, **kwargs) | |||||
inp, outp = get_param(cases, 0) | inp, outp = get_param(cases, 0) | ||||
inp_tensor = [tensor(inpi) for inpi in inp] | inp_tensor = [tensor(inpi) for inpi in inp] | ||||
if test_trace: | |||||
copied_inp = inp_tensor.copy() | |||||
for symbolic in [False, True]: | |||||
traced_func = trace(symbolic=symbolic)(func) | |||||
for _ in range(3): | |||||
traced_results = traced_func(*copied_inp, **kwargs) | |||||
check_results(traced_results, outp) | |||||
dumped_func = trace(symbolic=True, capture_as_const=True)(func) | |||||
dumped_results = dumped_func(*copied_inp, **kwargs) | |||||
check_results(dumped_results, outp) | |||||
file = io.BytesIO() | |||||
dump_info = dumped_func.dump(file) | |||||
file.seek(0) | |||||
# arg_name has pattern arg_xxx, xxx is int value | |||||
def take_number(arg_name): | |||||
return int(arg_name.split("_")[-1]) | |||||
input_names = dump_info[4] | |||||
inps_np = [i.numpy() for i in copied_inp] | |||||
input_names.sort(key=take_number) | |||||
inp_dict = dict(zip(input_names, inps_np)) | |||||
infer_cg = cgtools.GraphInference(file) | |||||
# assume #outputs == 1 | |||||
loaded_results = list(infer_cg.run(inp_dict=inp_dict).values())[0] | |||||
check_results(loaded_results, outp) | |||||
results = func(*inp_tensor, **kwargs) | results = func(*inp_tensor, **kwargs) | ||||
check_results(results, outp) | check_results(results, outp) |
@@ -36,7 +36,7 @@ def test_where(): | |||||
{"input": [maskv0, xv0, yv0]}, | {"input": [maskv0, xv0, yv0]}, | ||||
{"input": [maskv1, xv1, yv1]}, | {"input": [maskv1, xv1, yv1]}, | ||||
] | ] | ||||
opr_test(cases, F.where, ref_fn=np.where) | |||||
opr_test(cases, F.where, ref_fn=np.where, test_trace=False) | |||||
maskv2 = np.array([1, 1, 1], dtype=np.bool_) | maskv2 = np.array([1, 1, 1], dtype=np.bool_) | ||||
xv2 = np.array([1, 3, 2], dtype=np.float32) | xv2 = np.array([1, 3, 2], dtype=np.float32) | ||||
@@ -50,7 +50,7 @@ def test_where(): | |||||
{"input": [maskv2, xv2, yv2]}, | {"input": [maskv2, xv2, yv2]}, | ||||
{"input": [maskv3, xv3, yv3]}, | {"input": [maskv3, xv3, yv3]}, | ||||
] | ] | ||||
opr_test(cases, F.where, ref_fn=np.where) | |||||
opr_test(cases, F.where, ref_fn=np.where, test_trace=False) | |||||
def test_dropout(): | def test_dropout(): | ||||
@@ -115,14 +115,17 @@ def test_matmul(): | |||||
{"input": [data4, data5]}, | {"input": [data4, data5]}, | ||||
] | ] | ||||
for _ in range(0, batch_size): | for _ in range(0, batch_size): | ||||
# FIXME: remove test_trace=False in the future | |||||
opr_test( | opr_test( | ||||
cases, F.matmul, ref_fn=np.matmul, | |||||
cases, F.matmul, test_trace=False, ref_fn=np.matmul, | |||||
) | ) | ||||
# FIXME: remove test_trace=False in the future | |||||
opr_test( | opr_test( | ||||
[{"input": [data1, data4]}], | [{"input": [data1, data4]}], | ||||
F.matmul, | F.matmul, | ||||
ref_fn=lambda x, y: np.matmul(x, y.transpose(0, 1, 3, 2)), | ref_fn=lambda x, y: np.matmul(x, y.transpose(0, 1, 3, 2)), | ||||
test_trace=False, | |||||
transpose_b=True, | transpose_b=True, | ||||
) | ) | ||||
@@ -162,20 +162,24 @@ def test_linspace(): | |||||
{"input": [1, 9, 9]}, | {"input": [1, 9, 9]}, | ||||
{"input": [3, 10, 8]}, | {"input": [3, 10, 8]}, | ||||
] | ] | ||||
# FIXME: remove test_trace=False in the future | |||||
opr_test( | opr_test( | ||||
cases, | cases, | ||||
F.linspace, | F.linspace, | ||||
ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32), | ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32), | ||||
test_trace=False, | |||||
) | ) | ||||
cases = [ | cases = [ | ||||
{"input": [9, 1, 9]}, | {"input": [9, 1, 9]}, | ||||
{"input": [10, 3, 8]}, | {"input": [10, 3, 8]}, | ||||
] | ] | ||||
# FIXME: remove test_trace=False in the future | |||||
opr_test( | opr_test( | ||||
cases, | cases, | ||||
F.linspace, | F.linspace, | ||||
ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32), | ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32), | ||||
test_trace=False, | |||||
) | ) | ||||
@@ -184,30 +188,36 @@ def test_arange(): | |||||
{"input": [1, 9, 1]}, | {"input": [1, 9, 1]}, | ||||
{"input": [2, 10, 2]}, | {"input": [2, 10, 2]}, | ||||
] | ] | ||||
# FIXME: remove test_trace=False in the future | |||||
opr_test( | opr_test( | ||||
cases, | cases, | ||||
F.arange, | F.arange, | ||||
ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), | ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), | ||||
test_trace=False, | |||||
) | ) | ||||
cases = [ | cases = [ | ||||
{"input": [9, 1, -1]}, | {"input": [9, 1, -1]}, | ||||
{"input": [10, 2, -2]}, | {"input": [10, 2, -2]}, | ||||
] | ] | ||||
# FIXME: remove test_trace=False in the future | |||||
opr_test( | opr_test( | ||||
cases, | cases, | ||||
F.arange, | F.arange, | ||||
ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), | ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), | ||||
test_trace=False, | |||||
) | ) | ||||
cases = [ | cases = [ | ||||
{"input": [9.3, 1.2, -0.5]}, | {"input": [9.3, 1.2, -0.5]}, | ||||
{"input": [10.3, 2.1, -1.7]}, | {"input": [10.3, 2.1, -1.7]}, | ||||
] | ] | ||||
# FIXME: remove test_trace=False in the future | |||||
opr_test( | opr_test( | ||||
cases, | cases, | ||||
F.arange, | F.arange, | ||||
ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), | ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), | ||||
test_trace=False, | |||||
) | ) | ||||
@@ -279,7 +289,8 @@ def test_broadcast(): | |||||
{"input": [data1, output1_shape], "output": output1_shape}, | {"input": [data1, output1_shape], "output": output1_shape}, | ||||
{"input": [data2, output2_shape], "output": output2_shape}, | {"input": [data2, output2_shape], "output": output2_shape}, | ||||
] | ] | ||||
opr_test(cases, F.broadcast_to, compare_fn=compare_fn) | |||||
# FIXME: remove test_trace=False in the future | |||||
opr_test(cases, F.broadcast_to, compare_fn=compare_fn, test_trace=False) | |||||
x = F.ones((2, 1, 3)) | x = F.ones((2, 1, 3)) | ||||
with pytest.raises(RuntimeError): | with pytest.raises(RuntimeError): | ||||