Browse Source

feat(traced_module): delete value of node when it will not be used by any expr

GitOrigin-RevId: 3fb7350d01
release-1.6
Megvii Engine Team 3 years ago
parent
commit
526c82c858
1 changed files with 13 additions and 4 deletions
  1. +13
    -4
      imperative/python/megengine/traced_module/traced_module.py

+ 13
- 4
imperative/python/megengine/traced_module/traced_module.py View File

@@ -756,24 +756,32 @@ class InternalGraph:
return not end_nodes_set
return False

ref_count = lambda n: len(n.users) + (1 if n in self._outputs else 0)

for n, v in zip(self._inputs, inputs):
node2value[n] = v
if ref_count(n) > 0:
node2value[n] = [v, ref_count(n)]
if n in self._watch_point:
self._rst[n].append(v)
if n in self._end_point and get_all_endnode_val(n, v):
return list(endnode2value[i] for i in self._end_point)

for expr in self._exprs:
values = expr.interpret(*list(node2value[i] for i in expr.inputs))
values = expr.interpret(*list(node2value[i][0] for i in expr.inputs))
for n in expr.inputs:
node2value[n][1] -= 1
if node2value[n][1] == 0:
node2value.pop(n)
if values is not None:
for n, v in zip(expr.outputs, values):
node2value[n] = v
if ref_count(n) > 0:
node2value[n] = [v, ref_count(n)]
if n in self._watch_point:
self._rst[n] = v
if self._end_point and get_all_endnode_val(n, v):
return list(endnode2value[i] for i in self._end_point)

return list(node2value[i] for i in self._outputs)
return list(node2value[i][0] for i in self._outputs)

def eval(self, *inputs):
assert len(inputs) == len(self._inputs) - 1
@@ -1575,6 +1583,7 @@ class TracedModule(Module):
for index, inp in enumerate(expr.inputs):
if inp is call_out:
expr.inputs[index] = repl_dict[out]
repl_dict[out].users.append(expr)

continue
repl_dict[out] = call.outputs[ind]


Loading…
Cancel
Save