GitOrigin-RevId: 9c69254866
tags/v1.3.1
@@ -772,7 +772,8 @@ class trace: | |||||
len(self._output_bindings) | len(self._output_bindings) | ||||
) | ) | ||||
) | ) | ||||
if arg_names is None: | |||||
without_arg_names = arg_names is None | |||||
if without_arg_names: | |||||
arg_names = ["arg_%d" % i for i in range(len(self._arg_bindings))] | arg_names = ["arg_%d" % i for i in range(len(self._arg_bindings))] | ||||
if arg_names and not isinstance(arg_names, collections.abc.Sequence): | if arg_names and not isinstance(arg_names, collections.abc.Sequence): | ||||
arg_names = (arg_names,) | arg_names = (arg_names,) | ||||
@@ -802,7 +803,7 @@ class trace: | |||||
dtype=info.dtype, | dtype=info.dtype, | ||||
device=dumped_device(info), | device=dumped_device(info), | ||||
shape=info.shape or (1,), | shape=info.shape or (1,), | ||||
name=arg_names[i] if arg_names else None, | |||||
name=info.name if without_arg_names and info.name else arg_names[i], | |||||
) | ) | ||||
for k, h in self._kwarg_bindings.items(): | for k, h in self._kwarg_bindings.items(): | ||||
info = self._tinfo[h] | info = self._tinfo[h] | ||||
@@ -889,6 +890,7 @@ class trace: | |||||
return | return | ||||
h, info = self._new_handle() | h, info = self._new_handle() | ||||
info.external = False | info.external = False | ||||
info.name = x.c_name | |||||
info.device = x.device | info.device = x.device | ||||
info.dtype = x.dtype | info.dtype = x.dtype | ||||
info.shape = x.numpy().shape | info.shape = x.numpy().shape | ||||
@@ -203,14 +203,31 @@ def test_with_same_operators(symbolic): | |||||
assert ops[-2].name == "simple.RELU[0]" | assert ops[-2].name == "simple.RELU[0]" | ||||
def test_not_keep_opr_name(): | |||||
@pytest.mark.parametrize("symbolic", [False, True]) | |||||
def test_not_keep_opr_name(symbolic): | |||||
def f(x): | def f(x): | ||||
return 2 * x | return 2 * x | ||||
op = _dump_and_load(f, True, False)[-1] | |||||
op = _dump_and_load(f, symbolic, False)[-1] | |||||
assert op.name == "MUL(x,const<2>[2])[4]" | assert op.name == "MUL(x,const<2>[2])[4]" | ||||
@pytest.mark.parametrize("tensor_name, var_name", [("data", "data"), (None, "arg_0")]) | |||||
def test_catch_input_name(tensor_name, var_name): | |||||
def f(x): | |||||
return 2 * x | |||||
func = trace(f, symbolic=True, capture_as_const=True) | |||||
x = Tensor(np.ones(shape=(2, 3)), name=tensor_name) | |||||
func(x).numpy() | |||||
file = io.BytesIO() | |||||
func.dump(file, optimize_for_inference=False, keep_opr_name=True, keep_var_name=2) | |||||
file.seek(0) | |||||
*_, outputs = G.load_graph(file) | |||||
op = cgtools.get_oprs_seq(outputs)[-1] | |||||
assert op.inputs[0].name == var_name | |||||
@pytest.mark.parametrize("symbolic", [False, True]) | @pytest.mark.parametrize("symbolic", [False, True]) | ||||
def test_quantized_module_auto_naming(symbolic): | def test_quantized_module_auto_naming(symbolic): | ||||
class Simple(M.Module): | class Simple(M.Module): | ||||