|
|
@@ -756,24 +756,32 @@ class InternalGraph: |
|
|
|
return not end_nodes_set |
|
|
|
return False |
|
|
|
|
|
|
|
ref_count = lambda n: len(n.users) + (1 if n in self._outputs else 0) |
|
|
|
|
|
|
|
for n, v in zip(self._inputs, inputs): |
|
|
|
node2value[n] = v |
|
|
|
if ref_count(n) > 0: |
|
|
|
node2value[n] = [v, ref_count(n)] |
|
|
|
if n in self._watch_point: |
|
|
|
self._rst[n].append(v) |
|
|
|
if n in self._end_point and get_all_endnode_val(n, v): |
|
|
|
return list(endnode2value[i] for i in self._end_point) |
|
|
|
|
|
|
|
for expr in self._exprs: |
|
|
|
values = expr.interpret(*list(node2value[i] for i in expr.inputs)) |
|
|
|
values = expr.interpret(*list(node2value[i][0] for i in expr.inputs)) |
|
|
|
for n in expr.inputs: |
|
|
|
node2value[n][1] -= 1 |
|
|
|
if node2value[n][1] == 0: |
|
|
|
node2value.pop(n) |
|
|
|
if values is not None: |
|
|
|
for n, v in zip(expr.outputs, values): |
|
|
|
node2value[n] = v |
|
|
|
if ref_count(n) > 0: |
|
|
|
node2value[n] = [v, ref_count(n)] |
|
|
|
if n in self._watch_point: |
|
|
|
self._rst[n] = v |
|
|
|
if self._end_point and get_all_endnode_val(n, v): |
|
|
|
return list(endnode2value[i] for i in self._end_point) |
|
|
|
|
|
|
|
return list(node2value[i] for i in self._outputs) |
|
|
|
return list(node2value[i][0] for i in self._outputs) |
|
|
|
|
|
|
|
def eval(self, *inputs): |
|
|
|
assert len(inputs) == len(self._inputs) - 1 |
|
|
@@ -1575,6 +1583,7 @@ class TracedModule(Module): |
|
|
|
for index, inp in enumerate(expr.inputs): |
|
|
|
if inp is call_out: |
|
|
|
expr.inputs[index] = repl_dict[out] |
|
|
|
repl_dict[out].users.append(expr) |
|
|
|
|
|
|
|
continue |
|
|
|
repl_dict[out] = call.outputs[ind] |
|
|
|