|
|
@@ -22,29 +22,6 @@ from megengine.functional import exp, log |
|
|
|
from megengine.jit import exclude_from_trace, trace |
|
|
|
|
|
|
|
|
|
|
|
def load_and_inference(file, inp_data): |
|
|
|
cg, _, out_list = G.load_graph(file) |
|
|
|
inputs = cgtools.get_dep_vars(out_list, "Host2DeviceCopy") |
|
|
|
replace_dict = {} |
|
|
|
inp_node_list = [] |
|
|
|
for i in inputs: |
|
|
|
inp_node = G.InputNode( |
|
|
|
device="xpux", dtype=inputs[0].dtype, graph=inputs[0].graph |
|
|
|
) |
|
|
|
replace_dict[i] = inp_node.outputs[0] |
|
|
|
inp_node_list.append(inp_node) |
|
|
|
new_out = cgtools.replace_vars(out_list, replace_dict) |
|
|
|
out_node_list = [G.OutputNode(i) for i in new_out] |
|
|
|
new_out_list = [i.outputs[0] for i in out_node_list] |
|
|
|
new_cg = new_out_list[0].graph |
|
|
|
func = new_cg.compile(new_out_list) |
|
|
|
for node, value in zip(inp_node_list, inp_data): |
|
|
|
node.set_value(as_raw_tensor(value)._dev_tensor()) |
|
|
|
func.execute() |
|
|
|
out_data_list = [o.get_value().numpy() for o in out_node_list] |
|
|
|
return out_data_list |
|
|
|
|
|
|
|
|
|
|
|
def test_trace(): |
|
|
|
for symbolic in [False, True]: |
|
|
|
|
|
|
@@ -124,7 +101,7 @@ def test_dump(): |
|
|
|
np.testing.assert_equal(dump_info.inputs, ["h2d[0]", "h2d[2]"]) |
|
|
|
np.testing.assert_equal(dump_info.outputs, ["ADD(h2d[0],h2d[2])[4]"]) |
|
|
|
file.seek(0) |
|
|
|
result = load_and_inference(file, [a, b]) |
|
|
|
result = cgtools.load_and_inference(file, [a, b]) |
|
|
|
np.testing.assert_equal(result[0], y) |
|
|
|
|
|
|
|
|
|
|
@@ -146,7 +123,7 @@ def test_capture_dump(): |
|
|
|
file = io.BytesIO() |
|
|
|
f.dump(file) |
|
|
|
file.seek(0) |
|
|
|
result = load_and_inference(file, [x]) |
|
|
|
result = cgtools.load_and_inference(file, [x]) |
|
|
|
np.testing.assert_equal(result[0], y) |
|
|
|
|
|
|
|
|
|
|
|