|
|
@@ -26,18 +26,18 @@ from megengine.jit import exclude_from_trace, trace |
|
|
|
|
|
|
|
|
|
|
|
def load_and_inference(file, inp_data): |
|
|
|
cg, _, out_list = mgb_graph.load_graph(file) |
|
|
|
cg, _, out_list = G.load_graph(file) |
|
|
|
inputs = cgtools.get_dep_vars(out_list, "Host2DeviceCopy") |
|
|
|
replace_dict = {} |
|
|
|
inp_node_list = [] |
|
|
|
for i in inputs: |
|
|
|
inp_node = mgb_graph.InputNode( |
|
|
|
inp_node = G.InputNode( |
|
|
|
device="xpux", dtype=inputs[0].dtype, graph=inputs[0].graph |
|
|
|
) |
|
|
|
replace_dict[i] = inp_node.outputs[0] |
|
|
|
inp_node_list.append(inp_node) |
|
|
|
new_out = cgtools.replace_vars(out_list, replace_dict) |
|
|
|
out_node_list = [mgb_graph.OutputNode(i) for i in new_out] |
|
|
|
out_node_list = [G.OutputNode(i) for i in new_out] |
|
|
|
new_out_list = [i.outputs[0] for i in out_node_list] |
|
|
|
new_cg = new_out_list[0].graph |
|
|
|
func = new_cg.compile(new_out_list) |
|
|
@@ -150,6 +150,7 @@ def test_capture_dump(): |
|
|
|
np.testing.assert_equal(result[0], y) |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.skip(reason="get MultipleDeviceTensorHolder instead of SharedDeviceTensor") |
|
|
|
def test_dump_volatile(): |
|
|
|
p = as_raw_tensor([2]) |
|
|
|
|
|
|
@@ -168,7 +169,7 @@ def test_dump_volatile(): |
|
|
|
file = io.BytesIO() |
|
|
|
f.dump(file) |
|
|
|
file.seek(0) |
|
|
|
cg, _, outputs = mgb_graph.load_graph(file) |
|
|
|
cg, _, outputs = G.load_graph(file) |
|
|
|
(out,) = outputs |
|
|
|
assert ( |
|
|
|
cgtools.get_owner_opr_type(cgtools.get_owner_opr_inputs(out)[1]) |
|
|
|