Browse Source

fix(utils/network): fix replace oprs

GitOrigin-RevId: eba27e3dfb
tags/v1.6.0-rc1
Megvii Engine Team 3 years ago
parent
commit
c54a731b9b
2 changed files with 10 additions and 1 deletions
  1. +1
    -0
      imperative/python/megengine/utils/network.py
  2. +9
    -1
      imperative/python/test/unit/utils/test_network.py

+ 1
- 0
imperative/python/megengine/utils/network.py View File

@@ -399,6 +399,7 @@ class Network:
var.owner = repl_dict[opr]
var.__dict__.update(repl_dict[opr].outputs[ind].__dict__)
var.var = repl_dict[opr].outputs[ind].var
repl_dict[opr].outputs = opr.outputs
self._compile()

def get_opr_by_type(self, oprcls, unique=True):


+ 9
- 1
imperative/python/test/unit/utils/test_network.py View File

@@ -119,8 +119,16 @@ def test_replace_opr():
out1 = graph.add_dep_oprs(out1)
orig_opr = graph.opr_filter.has_input(vara).as_unique()

repl_dict = {orig_opr: out1[0].owner}
new_opr = out1[0].owner
repl_dict = {orig_opr: new_opr}
graph.replace_oprs(repl_dict)

var_out = orig_opr.outputs

for idx, node in enumerate(var_out):
assert node.owner is new_opr
assert node.owner.outputs[idx] is node

modified_model1 = io.BytesIO()
graph.dump(modified_model1)
modified_model1.seek(0)


Loading…
Cancel
Save