GitOrigin-RevId: 0e688ebd59
tags/v1.3.0
@@ -27,7 +27,6 @@ __all__ = [ | |||||
"replace_vars", | "replace_vars", | ||||
"replace_oprs", | "replace_oprs", | ||||
"set_priority_to_id", | "set_priority_to_id", | ||||
"load_and_inference", | |||||
"GraphInference", | "GraphInference", | ||||
] | ] | ||||
@@ -274,21 +273,6 @@ def replace_oprs( | |||||
return _imperative_rt.graph._replace_oprs(repl_src_vec, repl_dst_vec, dst_vec) | return _imperative_rt.graph._replace_oprs(repl_src_vec, repl_dst_vec, dst_vec) | ||||
def load_and_inference(file, inp_data_list: List[numpy.ndarray]) -> List[numpy.ndarray]:
    """
    Builds a GraphInference from a serialized computing graph and runs it once on the given inputs.

    :param file: path or handle of the input file.
    :param inp_data_list: list of input data; its items are passed to ``run`` as positional arguments.
    :return: list of inference results, in the iteration order of the mapping returned by ``run``.
    """
    outputs = GraphInference(file).run(*inp_data_list)
    # ``run`` hands back a mapping; callers of this helper only want its values.
    return list(outputs.values())
class GraphInference: | class GraphInference: | ||||
""" | """ | ||||
Loads a serialized computing graph as a GraphInference object which can be used to execute the computing graph. | Loads a serialized computing graph as a GraphInference object which can be used to execute the computing graph. | ||||
@@ -201,5 +201,6 @@ def test_quantize_batchmatmul_activation(): | |||||
file = io.BytesIO() | file = io.BytesIO() | ||||
f.dump(file, enable_nchw4=True) | f.dump(file, enable_nchw4=True) | ||||
file.seek(0) | file.seek(0) | ||||
dumped_outputs = cgtools.load_and_inference(file, [inputs])[0] | |||||
infer_cg = cgtools.GraphInference(file)[0] | |||||
dumped_outputs = list(infer_cg.run(inputs.numpy()).values())[0] | |||||
np.testing.assert_allclose(quantize_outputs.numpy(), dumped_outputs, atol=1e-6) | np.testing.assert_allclose(quantize_outputs.numpy(), dumped_outputs, atol=1e-6) |
@@ -141,7 +141,8 @@ def test_dump(): | |||||
np.testing.assert_equal(dump_info.inputs, ["arg_0", "arg_1"]) | np.testing.assert_equal(dump_info.inputs, ["arg_0", "arg_1"]) | ||||
np.testing.assert_equal(dump_info.outputs, ["ADD(arg_0,arg_1)[4]"]) | np.testing.assert_equal(dump_info.outputs, ["ADD(arg_0,arg_1)[4]"]) | ||||
file.seek(0) | file.seek(0) | ||||
result = cgtools.load_and_inference(file, [a, b]) | |||||
infer_cg = cgtools.GraphInference(file) | |||||
result = list((infer_cg.run(a, b)).values())[0] | |||||
np.testing.assert_equal(result[0], y) | np.testing.assert_equal(result[0], y) | ||||
@@ -161,7 +162,8 @@ def test_capture_dump(): | |||||
file = io.BytesIO() | file = io.BytesIO() | ||||
f.dump(file) | f.dump(file) | ||||
file.seek(0) | file.seek(0) | ||||
result = cgtools.load_and_inference(file, [x]) | |||||
infer_cg = cgtools.GraphInference(file) | |||||
result = list((infer_cg.run(x)).values())[0] | |||||
np.testing.assert_equal(result[0], y) | np.testing.assert_equal(result[0], y) | ||||