|
|
@@ -511,3 +511,50 @@ def test_set_symbolic_shape(): |
|
|
|
set_symbolic_shape(False) |
|
|
|
assert var_a.shape == var_a.partial_shape |
|
|
|
set_symbolic_shape(saved_symbolic_shape) |
|
|
|
|
|
|
|
|
|
|
|
def test_replace_var_in_different_network(): |
|
|
|
|
|
|
|
a = Tensor([1, 2]) |
|
|
|
b = Tensor([3, 4]) |
|
|
|
|
|
|
|
@trace(symbolic=True, capture_as_const=True) |
|
|
|
def fwd(a, b): |
|
|
|
return (a + b) * 2 |
|
|
|
|
|
|
|
@trace(symbolic=True, capture_as_const=True) |
|
|
|
def fwd1(c, d): |
|
|
|
return c + d |
|
|
|
|
|
|
|
fwd(a, b) |
|
|
|
orig_model = io.BytesIO() |
|
|
|
fwd.dump( |
|
|
|
orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False |
|
|
|
) |
|
|
|
orig_model.seek(0) |
|
|
|
|
|
|
|
fwd1(a, b) |
|
|
|
orig_model1 = io.BytesIO() |
|
|
|
fwd1.dump( |
|
|
|
orig_model1, |
|
|
|
arg_names=["c", "d"], |
|
|
|
output_names="o", |
|
|
|
optimize_for_inference=False, |
|
|
|
) |
|
|
|
orig_model1.seek(0) |
|
|
|
|
|
|
|
graph = Net.load(orig_model) |
|
|
|
graph1 = Net.load(orig_model1) |
|
|
|
vara = graph.var_filter.name("a").as_unique() |
|
|
|
varb = graph.var_filter.name("b").as_unique() |
|
|
|
varo = graph1.var_filter.name("o").as_unique() |
|
|
|
|
|
|
|
graph.replace_vars({vara: varo, varb: varo}) |
|
|
|
|
|
|
|
modified_model = io.BytesIO() |
|
|
|
graph.dump(modified_model) |
|
|
|
modified_model.seek(0) |
|
|
|
load_graph = GraphInference(modified_model) |
|
|
|
|
|
|
|
out = load_graph.run(a, b) |
|
|
|
np.testing.assert_equal(out["o"], [16, 24]) |