|
|
@@ -350,6 +350,7 @@ class trace: |
|
|
|
lazy_eval_graph.options.graph_opt_level = self._graph_opt_level |
|
|
|
else: |
|
|
|
lazy_eval_graph.options.graph_opt_level = 2 |
|
|
|
lazy_eval_graph.set_priority_to_id([*lazy_eval_links, *readers]) |
|
|
|
lazy_eval_graph.compile(*lazy_eval_links, *readers) |
|
|
|
lazy_eval_graph() |
|
|
|
for r, x in zip(readers, lazy_eval_tensors): |
|
|
@@ -484,7 +485,8 @@ class trace: |
|
|
|
# graph.options.graph_opt_level = 0 |
|
|
|
need_reset_nodes = self._need_reset_nodes = [] |
|
|
|
# links enforce ordering of I/O nodes |
|
|
|
links = () |
|
|
|
in_out_links = () |
|
|
|
io_links = () |
|
|
|
readers = [] |
|
|
|
|
|
|
|
if self._capture_as_const: |
|
|
@@ -499,7 +501,7 @@ class trace: |
|
|
|
) |
|
|
|
need_reset_nodes.append(opnode) |
|
|
|
info.varnode = opnode.outputs[0] |
|
|
|
links += opnode.outputs[1:] |
|
|
|
in_out_links += opnode.outputs[1:] |
|
|
|
|
|
|
|
for op, ihandles, ohandles in self._seq: |
|
|
|
if isinstance(op, Const): |
|
|
@@ -536,7 +538,7 @@ class trace: |
|
|
|
) |
|
|
|
else: |
|
|
|
opnode = info.data_setter = G.InputNode( |
|
|
|
*links, |
|
|
|
*in_out_links, |
|
|
|
device=info.device, |
|
|
|
dtype=info.dtype, |
|
|
|
shape=info.shape or (1,), |
|
|
@@ -544,45 +546,48 @@ class trace: |
|
|
|
use_static_shape=_input_node_use_static_shape(), |
|
|
|
) |
|
|
|
need_reset_nodes.append(opnode) |
|
|
|
info.varnode, *links = opnode.outputs |
|
|
|
if require_links and i == 0 and len(links) > 0: |
|
|
|
info.varnode = apply(VirtualDep(), info.varnode, *links)[0] |
|
|
|
links = (info.varnode,) |
|
|
|
info.varnode, *in_out_links = opnode.outputs |
|
|
|
if require_links and i == 0 and len(io_links) > 0: |
|
|
|
info.varnode = apply( |
|
|
|
VirtualDep(str(io_links[0].device)), info.varnode, *io_links |
|
|
|
)[0] |
|
|
|
io_links = (info.varnode,) |
|
|
|
|
|
|
|
ivars.append(info.varnode) |
|
|
|
ovars = apply(op, *ivars) |
|
|
|
if require_links and len(ovars) > 0: |
|
|
|
links = (ovars[0],) |
|
|
|
io_links = (ovars[0],) |
|
|
|
assert len(ovars) == len(ohandles) |
|
|
|
for h, v in zip(ohandles, ovars): |
|
|
|
info = self._tinfo[h] |
|
|
|
info.varnode = v |
|
|
|
|
|
|
|
def add_reader(opnode): |
|
|
|
nonlocal links |
|
|
|
nonlocal in_out_links |
|
|
|
need_reset_nodes.append(opnode) |
|
|
|
readers.append(opnode.outputs[0]) |
|
|
|
links = opnode.outputs |
|
|
|
in_out_links = opnode.outputs |
|
|
|
|
|
|
|
if info.data_read: |
|
|
|
# Shape can be obtained from data so doesn't need its own |
|
|
|
# output node. On the other hand, value is read separately |
|
|
|
# to leverage eager h2d copy |
|
|
|
info.shape_read = False |
|
|
|
opnode = info.data_reader = G.OutputNode(v, *links) |
|
|
|
opnode = info.data_reader = G.OutputNode(v, *in_out_links) |
|
|
|
add_reader(opnode) |
|
|
|
if info.value_read: |
|
|
|
opnode = info.value_reader = G.ValueOutputNode(v, *links) |
|
|
|
opnode = info.value_reader = G.ValueOutputNode(v, *in_out_links) |
|
|
|
add_reader(opnode) |
|
|
|
if info.shape_read: |
|
|
|
opnode = info.shape_reader = G.AttrOutputNode(v, *links) |
|
|
|
opnode = info.shape_reader = G.AttrOutputNode(v, *in_out_links) |
|
|
|
add_reader(opnode) |
|
|
|
# FIXME |
|
|
|
if self._graph_opt_level is not None: |
|
|
|
graph.options.graph_opt_level = self._graph_opt_level |
|
|
|
else: |
|
|
|
graph.options.graph_opt_level = 2 |
|
|
|
graph.compile(*readers, *links) |
|
|
|
graph.set_priority_to_id([*readers, *in_out_links, *io_links]) |
|
|
|
graph.compile(*readers, *in_out_links, *io_links) |
|
|
|
|
|
|
|
def _reset_exec_env(self): |
|
|
|
for opnode in self._need_reset_nodes: |
|
|
@@ -1107,7 +1112,11 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor): |
|
|
|
|
|
|
|
if require_links and active_trace._lazy_eval_links: |
|
|
|
assert len(ivars) > 0, "op should has at least one input" |
|
|
|
ivars[0] = apply(VirtualDep(), ivars[0], *active_trace._lazy_eval_links)[0] |
|
|
|
ivars[0] = apply( |
|
|
|
VirtualDep(str(active_trace._lazy_eval_links[0].device)), |
|
|
|
ivars[0], |
|
|
|
*active_trace._lazy_eval_links, |
|
|
|
)[0] |
|
|
|
active_trace._lazy_eval_links = (ivars[0],) |
|
|
|
|
|
|
|
ovars = apply(op, *ivars) |
|
|
|