From 2e530779152d68cab61c690974d21a0e5d634727 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 11 Nov 2020 11:52:32 +0800 Subject: [PATCH] fix(mge/trace): use xpux device when dump GitOrigin-RevId: f37285f70e9d21ca0c3951ebe917351e94e1ec3f --- imperative/python/megengine/jit/tracing.py | 11 ++++++++--- imperative/python/test/unit/test_tracing.py | 4 ++-- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py index a2d141b4..927b954c 100644 --- a/imperative/python/megengine/jit/tracing.py +++ b/imperative/python/megengine/jit/tracing.py @@ -20,6 +20,7 @@ import numpy as np from ..core._imperative_rt import GraphProfiler from ..core._imperative_rt.ops import OprAttr from ..core._trace_option import set_symbolic_shape +from ..core._wrap import device as as_device from ..core.ops.special import Const from ..core.tensor import megbrain_graph as G from ..core.tensor.core import OpBase, TensorBase, TensorWrapperBase, apply @@ -588,6 +589,8 @@ class trace: len(self._output_bindings) ) ) + if arg_names is None: + arg_names = ["arg_%d" % i for i in range(len(self._arg_bindings))] if arg_names and not isinstance(arg_names, collections.abc.Sequence): arg_names = (arg_names,) if arg_names and len(arg_names) != len(self._arg_bindings): @@ -598,6 +601,8 @@ class trace: ) output_names = output_names or self._output_names + dumped_device = as_device("xpux") + h2v = {} graph = G.Graph() # only graph_opt_level takes effect in dump @@ -607,14 +612,14 @@ class trace: info = self._tinfo[h] h2v[h] = graph.make_h2d( dtype=info.dtype, - device=info.device, + device=dumped_device, shape=info.shape, name=arg_names[i] if arg_names else None, ) for k, h in self._kwarg_bindings.items(): info = self._tinfo[h] h2v[h] = graph.make_h2d( - dtype=info.dtype, device=info.device, shape=info.shape, name=k + dtype=info.dtype, device=dumped_device, shape=info.shape, name=k ) for op, ihandles, ohandles in self._seq: @@ -625,7 +630,7 @@ class trace: assert info.external assert info.bound_data h2v[h] = graph.make_const( - info.bound_data.numpy(), dtype=info.dtype, device=info.device + info.bound_data.numpy(), dtype=info.dtype, device=dumped_device ) ivars.append(h2v[h]) ovars = apply(op, *ivars) diff --git a/imperative/python/test/unit/test_tracing.py b/imperative/python/test/unit/test_tracing.py index 1c4699d3..32ec58f5 100644 --- a/imperative/python/test/unit/test_tracing.py +++ b/imperative/python/test/unit/test_tracing.py @@ -100,8 +100,8 @@ def test_dump(): file = io.BytesIO() dump_info = f.dump(file) assert dump_info.nr_opr == 3 - np.testing.assert_equal(dump_info.inputs, ["h2d[0]", "h2d[2]"]) - np.testing.assert_equal(dump_info.outputs, ["ADD(h2d[0],h2d[2])[4]"]) + np.testing.assert_equal(dump_info.inputs, ["arg_0", "arg_1"]) + np.testing.assert_equal(dump_info.outputs, ["ADD(arg_0,arg_1)[4]"]) file.seek(0) result = cgtools.load_and_inference(file, [a, b]) np.testing.assert_equal(result[0], y)