Browse Source

fix(mge/trace): use xpux device when dump

GitOrigin-RevId: f37285f70e
release-1.1
Megvii Engine Team 4 years ago
parent
commit
2e53077915
2 changed files with 10 additions and 5 deletions
  1. +8
    -3
      imperative/python/megengine/jit/tracing.py
  2. +2
    -2
      imperative/python/test/unit/test_tracing.py

+ 8
- 3
imperative/python/megengine/jit/tracing.py View File

@@ -20,6 +20,7 @@ import numpy as np
from ..core._imperative_rt import GraphProfiler
from ..core._imperative_rt.ops import OprAttr
from ..core._trace_option import set_symbolic_shape
from ..core._wrap import device as as_device
from ..core.ops.special import Const
from ..core.tensor import megbrain_graph as G
from ..core.tensor.core import OpBase, TensorBase, TensorWrapperBase, apply
@@ -588,6 +589,8 @@ class trace:
len(self._output_bindings)
)
)
if arg_names is None:
arg_names = ["arg_%d" % i for i in range(len(self._arg_bindings))]
if arg_names and not isinstance(arg_names, collections.abc.Sequence):
arg_names = (arg_names,)
if arg_names and len(arg_names) != len(self._arg_bindings):
@@ -598,6 +601,8 @@ class trace:
)
output_names = output_names or self._output_names

dumped_device = as_device("xpux")

h2v = {}
graph = G.Graph()
# only graph_opt_level takes effect in dump
@@ -607,14 +612,14 @@ class trace:
info = self._tinfo[h]
h2v[h] = graph.make_h2d(
dtype=info.dtype,
device=info.device,
device=dumped_device,
shape=info.shape,
name=arg_names[i] if arg_names else None,
)
for k, h in self._kwarg_bindings.items():
info = self._tinfo[h]
h2v[h] = graph.make_h2d(
dtype=info.dtype, device=info.device, shape=info.shape, name=k
dtype=info.dtype, device=dumped_device, shape=info.shape, name=k
)

for op, ihandles, ohandles in self._seq:
@@ -625,7 +630,7 @@ class trace:
assert info.external
assert info.bound_data
h2v[h] = graph.make_const(
info.bound_data.numpy(), dtype=info.dtype, device=info.device
info.bound_data.numpy(), dtype=info.dtype, device=dumped_device
)
ivars.append(h2v[h])
ovars = apply(op, *ivars)


+ 2
- 2
imperative/python/test/unit/test_tracing.py View File

@@ -100,8 +100,8 @@ def test_dump():
file = io.BytesIO()
dump_info = f.dump(file)
assert dump_info.nr_opr == 3
np.testing.assert_equal(dump_info.inputs, ["h2d[0]", "h2d[2]"])
np.testing.assert_equal(dump_info.outputs, ["ADD(h2d[0],h2d[2])[4]"])
np.testing.assert_equal(dump_info.inputs, ["arg_0", "arg_1"])
np.testing.assert_equal(dump_info.outputs, ["ADD(arg_0,arg_1)[4]"])
file.seek(0)
result = cgtools.load_and_inference(file, [a, b])
np.testing.assert_equal(result[0], y)


Loading…
Cancel
Save