diff --git a/imperative/python/megengine/traced_module/traced_module.py b/imperative/python/megengine/traced_module/traced_module.py index b18f66a7..41d90de8 100644 --- a/imperative/python/megengine/traced_module/traced_module.py +++ b/imperative/python/megengine/traced_module/traced_module.py @@ -756,24 +756,32 @@ class InternalGraph: return not end_nodes_set return False + ref_count = lambda n: len(n.users) + (1 if n in self._outputs else 0) + for n, v in zip(self._inputs, inputs): - node2value[n] = v + if ref_count(n) > 0: + node2value[n] = [v, ref_count(n)] if n in self._watch_point: self._rst[n].append(v) if n in self._end_point and get_all_endnode_val(n, v): return list(endnode2value[i] for i in self._end_point) for expr in self._exprs: - values = expr.interpret(*list(node2value[i] for i in expr.inputs)) + values = expr.interpret(*list(node2value[i][0] for i in expr.inputs)) + for n in expr.inputs: + node2value[n][1] -= 1 + if node2value[n][1] == 0: + node2value.pop(n) if values is not None: for n, v in zip(expr.outputs, values): - node2value[n] = v + if ref_count(n) > 0: + node2value[n] = [v, ref_count(n)] if n in self._watch_point: self._rst[n] = v if self._end_point and get_all_endnode_val(n, v): return list(endnode2value[i] for i in self._end_point) - return list(node2value[i] for i in self._outputs) + return list(node2value[i][0] for i in self._outputs) def eval(self, *inputs): assert len(inputs) == len(self._inputs) - 1 @@ -1575,6 +1583,7 @@ class TracedModule(Module): for index, inp in enumerate(expr.inputs): if inp is call_out: expr.inputs[index] = repl_dict[out] + repl_dict[out].users.append(expr) continue repl_dict[out] = call.outputs[ind]