|
|
@@ -141,7 +141,8 @@ def test_dump(): |
|
|
|
np.testing.assert_equal(dump_info.inputs, ["arg_0", "arg_1"]) |
|
|
|
np.testing.assert_equal(dump_info.outputs, ["ADD(arg_0,arg_1)[4]"]) |
|
|
|
file.seek(0) |
|
|
|
result = cgtools.load_and_inference(file, [a, b]) |
|
|
|
infer_cg = cgtools.GraphInference(file) |
|
|
|
result = list((infer_cg.run(a, b)).values())[0] |
|
|
|
np.testing.assert_equal(result[0], y) |
|
|
|
|
|
|
|
|
|
|
@@ -161,7 +162,8 @@ def test_capture_dump(): |
|
|
|
file = io.BytesIO() |
|
|
|
f.dump(file) |
|
|
|
file.seek(0) |
|
|
|
result = cgtools.load_and_inference(file, [x]) |
|
|
|
infer_cg = cgtools.GraphInference(file) |
|
|
|
result = list((infer_cg.run(x)).values())[0] |
|
|
|
np.testing.assert_equal(result[0], y) |
|
|
|
|
|
|
|
|
|
|
|