diff --git a/imperative/python/megengine/utils/network.py b/imperative/python/megengine/utils/network.py index 3409331e..16c3b59b 100644 --- a/imperative/python/megengine/utils/network.py +++ b/imperative/python/megengine/utils/network.py @@ -371,7 +371,9 @@ class Network: if repl_var is var: continue for opnode in var.users: - assert var in opnode.inputs + # use method 'is' instead of 'in' to avoid + # compare VarNode use elemwise equal + assert any([var is _ for _ in opnode.inputs]) opnode.inputs = [repl_var if var is i else i for i in opnode.inputs] if opnode not in repl_var.users: repl_var.users.append(opnode) diff --git a/imperative/python/test/unit/utils/test_network.py b/imperative/python/test/unit/utils/test_network.py index ebf62fa0..f223d69d 100644 --- a/imperative/python/test/unit/utils/test_network.py +++ b/imperative/python/test/unit/utils/test_network.py @@ -511,3 +511,50 @@ def test_set_symbolic_shape(): set_symbolic_shape(False) assert var_a.shape == var_a.partial_shape set_symbolic_shape(saved_symbolic_shape) + + +def test_replace_var_in_different_network(): + + a = Tensor([1, 2]) + b = Tensor([3, 4]) + + @trace(symbolic=True, capture_as_const=True) + def fwd(a, b): + return (a + b) * 2 + + @trace(symbolic=True, capture_as_const=True) + def fwd1(c, d): + return c + d + + fwd(a, b) + orig_model = io.BytesIO() + fwd.dump( + orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False + ) + orig_model.seek(0) + + fwd1(a, b) + orig_model1 = io.BytesIO() + fwd1.dump( + orig_model1, + arg_names=["c", "d"], + output_names="o", + optimize_for_inference=False, + ) + orig_model1.seek(0) + + graph = Net.load(orig_model) + graph1 = Net.load(orig_model1) + vara = graph.var_filter.name("a").as_unique() + varb = graph.var_filter.name("b").as_unique() + varo = graph1.var_filter.name("o").as_unique() + + graph.replace_vars({vara: varo, varb: varo}) + + modified_model = io.BytesIO() + graph.dump(modified_model) + modified_model.seek(0) + load_graph = GraphInference(modified_model) + + out = load_graph.run(a, b) + np.testing.assert_equal(out["o"], [16, 24])