|
@@ -119,8 +119,16 @@ def test_replace_opr(): |
|
|
out1 = graph.add_dep_oprs(out1) |
|
|
out1 = graph.add_dep_oprs(out1) |
|
|
orig_opr = graph.opr_filter.has_input(vara).as_unique() |
|
|
orig_opr = graph.opr_filter.has_input(vara).as_unique() |
|
|
|
|
|
|
|
|
repl_dict = {orig_opr: out1[0].owner} |
|
|
|
|
|
|
|
|
new_opr = out1[0].owner |
|
|
|
|
|
repl_dict = {orig_opr: new_opr} |
|
|
graph.replace_oprs(repl_dict) |
|
|
graph.replace_oprs(repl_dict) |
|
|
|
|
|
|
|
|
|
|
|
var_out = orig_opr.outputs |
|
|
|
|
|
|
|
|
|
|
|
for idx, node in enumerate(var_out): |
|
|
|
|
|
assert node.owner is new_opr |
|
|
|
|
|
assert node.owner.outputs[idx] is node |
|
|
|
|
|
|
|
|
modified_model1 = io.BytesIO() |
|
|
modified_model1 = io.BytesIO() |
|
|
graph.dump(modified_model1) |
|
|
graph.dump(modified_model1) |
|
|
modified_model1.seek(0) |
|
|
modified_model1.seek(0) |
|
|