|
|
@@ -20,6 +20,7 @@ import numpy as np |
|
|
|
from ..core._imperative_rt import GraphProfiler |
|
|
|
from ..core._imperative_rt.ops import OprAttr |
|
|
|
from ..core._trace_option import set_symbolic_shape |
|
|
|
from ..core._wrap import device as as_device |
|
|
|
from ..core.ops.special import Const |
|
|
|
from ..core.tensor import megbrain_graph as G |
|
|
|
from ..core.tensor.core import OpBase, TensorBase, TensorWrapperBase, apply |
|
|
@@ -588,6 +589,8 @@ class trace: |
|
|
|
len(self._output_bindings) |
|
|
|
) |
|
|
|
) |
|
|
|
if arg_names is None: |
|
|
|
arg_names = ["arg_%d" % i for i in range(len(self._arg_bindings))] |
|
|
|
if arg_names and not isinstance(arg_names, collections.abc.Sequence): |
|
|
|
arg_names = (arg_names,) |
|
|
|
if arg_names and len(arg_names) != len(self._arg_bindings): |
|
|
@@ -598,6 +601,8 @@ class trace: |
|
|
|
) |
|
|
|
output_names = output_names or self._output_names |
|
|
|
|
|
|
|
dumped_device = as_device("xpux") |
|
|
|
|
|
|
|
h2v = {} |
|
|
|
graph = G.Graph() |
|
|
|
# only graph_opt_level takes effect in dump |
|
|
@@ -607,14 +612,14 @@ class trace: |
|
|
|
info = self._tinfo[h] |
|
|
|
h2v[h] = graph.make_h2d( |
|
|
|
dtype=info.dtype, |
|
|
|
device=info.device, |
|
|
|
device=dumped_device, |
|
|
|
shape=info.shape, |
|
|
|
name=arg_names[i] if arg_names else None, |
|
|
|
) |
|
|
|
for k, h in self._kwarg_bindings.items(): |
|
|
|
info = self._tinfo[h] |
|
|
|
h2v[h] = graph.make_h2d( |
|
|
|
dtype=info.dtype, device=info.device, shape=info.shape, name=k |
|
|
|
dtype=info.dtype, device=dumped_device, shape=info.shape, name=k |
|
|
|
) |
|
|
|
|
|
|
|
for op, ihandles, ohandles in self._seq: |
|
|
@@ -625,7 +630,7 @@ class trace: |
|
|
|
assert info.external |
|
|
|
assert info.bound_data |
|
|
|
h2v[h] = graph.make_const( |
|
|
|
info.bound_data.numpy(), dtype=info.dtype, device=info.device |
|
|
|
info.bound_data.numpy(), dtype=info.dtype, device=dumped_device |
|
|
|
) |
|
|
|
ivars.append(h2v[h]) |
|
|
|
ovars = apply(op, *ivars) |
|
|
|