Browse Source

feat(imperative/jit): catch input tensors name when tracing

GitOrigin-RevId: 9c69254866
tags/v1.3.1
Megvii Engine Team 4 years ago
parent
commit
e474994f46
2 changed files with 23 additions and 4 deletions
  1. +4
    -2
      imperative/python/megengine/jit/tracing.py
  2. +19
    -2
      imperative/python/test/unit/utils/test_dump_naming.py

+ 4
- 2
imperative/python/megengine/jit/tracing.py View File

@@ -772,7 +772,8 @@ class trace:
len(self._output_bindings)
)
)
if arg_names is None:
without_arg_names = arg_names is None
if without_arg_names:
arg_names = ["arg_%d" % i for i in range(len(self._arg_bindings))]
if arg_names and not isinstance(arg_names, collections.abc.Sequence):
arg_names = (arg_names,)
@@ -802,7 +803,7 @@ class trace:
dtype=info.dtype,
device=dumped_device(info),
shape=info.shape or (1,),
name=arg_names[i] if arg_names else None,
name=info.name if without_arg_names and info.name else arg_names[i],
)
for k, h in self._kwarg_bindings.items():
info = self._tinfo[h]
@@ -889,6 +890,7 @@ class trace:
return
h, info = self._new_handle()
info.external = False
info.name = x.c_name
info.device = x.device
info.dtype = x.dtype
info.shape = x.numpy().shape


+ 19
- 2
imperative/python/test/unit/utils/test_dump_naming.py View File

@@ -203,14 +203,31 @@ def test_with_same_operators(symbolic):
assert ops[-2].name == "simple.RELU[0]"


def test_not_keep_opr_name():
@pytest.mark.parametrize("symbolic", [False, True])
def test_not_keep_opr_name(symbolic):
def f(x):
return 2 * x

op = _dump_and_load(f, True, False)[-1]
op = _dump_and_load(f, symbolic, False)[-1]
assert op.name == "MUL(x,const<2>[2])[4]"


@pytest.mark.parametrize("tensor_name, var_name", [("data", "data"), (None, "arg_0")])
def test_catch_input_name(tensor_name, var_name):
def f(x):
return 2 * x

func = trace(f, symbolic=True, capture_as_const=True)
x = Tensor(np.ones(shape=(2, 3)), name=tensor_name)
func(x).numpy()
file = io.BytesIO()
func.dump(file, optimize_for_inference=False, keep_opr_name=True, keep_var_name=2)
file.seek(0)
*_, outputs = G.load_graph(file)
op = cgtools.get_oprs_seq(outputs)[-1]
assert op.inputs[0].name == var_name


@pytest.mark.parametrize("symbolic", [False, True])
def test_quantized_module_auto_naming(symbolic):
class Simple(M.Module):


Loading…
Cancel
Save