Browse Source

refactor(mge/cgtools): remove load_and_inference and use GraphInference

GitOrigin-RevId: 0e688ebd59
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
91a3580f75
3 changed files with 6 additions and 19 deletions
  1. +0
    -16
      imperative/python/megengine/utils/comp_graph_tools.py
  2. +2
    -1
      imperative/python/test/unit/module/test_qat.py
  3. +4
    -2
      imperative/python/test/unit/test_tracing.py

+ 0
- 16
imperative/python/megengine/utils/comp_graph_tools.py View File

@@ -27,7 +27,6 @@ __all__ = [
"replace_vars", "replace_vars",
"replace_oprs", "replace_oprs",
"set_priority_to_id", "set_priority_to_id",
"load_and_inference",
"GraphInference", "GraphInference",
] ]


@@ -274,21 +273,6 @@ def replace_oprs(
return _imperative_rt.graph._replace_oprs(repl_src_vec, repl_dst_vec, dst_vec) return _imperative_rt.graph._replace_oprs(repl_src_vec, repl_dst_vec, dst_vec)




def load_and_inference(file, inp_data_list: List[numpy.ndarray]) -> List[numpy.ndarray]:
"""
Loads a serialized computing graph and run inference with input data.

:param file: path or handle of the input file.
:param inp_data_list: list of input data.
:return: list of inference results.

"""
graph = GraphInference(file)
result = graph.run(*inp_data_list)
out_data_list = list(result.values())
return out_data_list


class GraphInference: class GraphInference:
""" """
Loads a serialized computing graph as a GraphInference object which can be used to execute the computing graph. Loads a serialized computing graph as a GraphInference object which can be used to execute the computing graph.


+ 2
- 1
imperative/python/test/unit/module/test_qat.py View File

@@ -201,5 +201,6 @@ def test_quantize_batchmatmul_activation():
file = io.BytesIO() file = io.BytesIO()
f.dump(file, enable_nchw4=True) f.dump(file, enable_nchw4=True)
file.seek(0) file.seek(0)
dumped_outputs = cgtools.load_and_inference(file, [inputs])[0]
infer_cg = cgtools.GraphInference(file)[0]
dumped_outputs = list(infer_cg.run(inputs.numpy()).values())[0]
np.testing.assert_allclose(quantize_outputs.numpy(), dumped_outputs, atol=1e-6) np.testing.assert_allclose(quantize_outputs.numpy(), dumped_outputs, atol=1e-6)

+ 4
- 2
imperative/python/test/unit/test_tracing.py View File

@@ -141,7 +141,8 @@ def test_dump():
np.testing.assert_equal(dump_info.inputs, ["arg_0", "arg_1"]) np.testing.assert_equal(dump_info.inputs, ["arg_0", "arg_1"])
np.testing.assert_equal(dump_info.outputs, ["ADD(arg_0,arg_1)[4]"]) np.testing.assert_equal(dump_info.outputs, ["ADD(arg_0,arg_1)[4]"])
file.seek(0) file.seek(0)
result = cgtools.load_and_inference(file, [a, b])
infer_cg = cgtools.GraphInference(file)
result = list((infer_cg.run(a, b)).values())[0]
np.testing.assert_equal(result[0], y) np.testing.assert_equal(result[0], y)




@@ -161,7 +162,8 @@ def test_capture_dump():
file = io.BytesIO() file = io.BytesIO()
f.dump(file) f.dump(file)
file.seek(0) file.seek(0)
result = cgtools.load_and_inference(file, [x])
infer_cg = cgtools.GraphInference(file)
result = list((infer_cg.run(x)).values())[0]
np.testing.assert_equal(result[0], y) np.testing.assert_equal(result[0], y)






Loading…
Cancel
Save