GitOrigin-RevId: 258c03ee34
@@ -235,15 +235,36 @@ class Tensor:
             return self.__val.dtype
         return self._symvar.dtype

-    def set_dtype(self, dtype: str = None):
+    @dtype.setter
+    def dtype(self, dtype: str = None):
         r"""Set the data type of the tensor.
         """
         if self.__val is not None:
             self.__val = mgb.make_shared(self.device, value=self.astype(dtype).numpy())
+        elif self.__sym_override is not None:
+            self.__sym_override = self.__sym_override.astype(dtype)
         elif self.__sym is not None:
             self.__sym = self.__sym.astype(dtype)

     @property
+    def name(self):
+        r"""Get the name of the tensor; not supported for Parameter and Buffer.
+        """
+        return self._symvar.name
+
+    @name.setter
+    def name(self, name: str = None):
+        r"""Set the name of the tensor; not supported for Parameter and Buffer.
+        """
+        if self.__val is not None:
+            raise ValueError("name setting is not available for Parameter or Buffer.")
+        if self.__sym_override is not None:
+            self.__sym_override = self.__sym_override.rename(name)
+        if self.__sym is not None:
+            assert not self.__val
+            self.__sym = self.__sym.rename(name)
+
+    @property
     def _comp_node(self):
         if self.__val is not None:
             return self.__val.comp_node
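
A minimal usage sketch of the new property-style API, assuming `mgb` refers to `megengine._internal` as in the existing tests; the shapes, scale, and name below are illustrative only:

import numpy as np
import megengine as mge
import megengine._internal as mgb

# dtype is now assigned through the property setter instead of set_dtype().
p = mge.Parameter(np.ones((3, 4), dtype="float32"))
p.dtype = mgb.dtype.qint8(0.1)

# name can only be set on ordinary (symbolic) tensors; assigning it on a
# Parameter or Buffer raises ValueError.
s = mge.Buffer(np.ones((3, 4), dtype="float32")) + 1
s.name = "my_output"
assert s.name == "my_output"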
@@ -436,6 +436,7 @@ class trace:
         arg_names=None,
         append=False,
         optimize_for_inference=False,
+        output_names=None,
         **kwargs
     ):
         """
@@ -446,6 +447,8 @@ class trace: | |||||
:param append: whether output is appended to ``fpath``. | :param append: whether output is appended to ``fpath``. | ||||
:param optimize_for_inference: whether to enable optimize_for_inference | :param optimize_for_inference: whether to enable optimize_for_inference | ||||
pass before dump. | pass before dump. | ||||
:param output_names: names of the output tensors in the traced function, | |||||
will use the default name if does not specify. | |||||
:param enable_io16xc32: whether to use float16 for I/O between oprs and use | :param enable_io16xc32: whether to use float16 for I/O between oprs and use | ||||
float32 as internal computation precision. Note the output var would be | float32 as internal computation precision. Note the output var would be | ||||
@@ -488,6 +491,17 @@ class trace: | |||||
len(self._args), len(arg_names) | len(self._args), len(arg_names) | ||||
) | ) | ||||
) | ) | ||||
if isinstance(output_names, str): | |||||
output_names = [output_names] | |||||
if output_names is None: | |||||
output_names = [var.name for var in self._sym_outputs] | |||||
elif len(output_names) != len(self._sym_outputs): | |||||
raise ValueError( | |||||
"len(output_names) should be {}, got {}".format( | |||||
len(self._sym_outputs), len(output_names) | |||||
) | |||||
) | |||||
optimize_for_inference_args_map = { | optimize_for_inference_args_map = { | ||||
"enable_io16xc32": "f16_io_f32_comp", | "enable_io16xc32": "f16_io_f32_comp", | ||||
"enable_ioc16": "f16_io_comp", | "enable_ioc16": "f16_io_comp", | ||||
@@ -541,6 +555,8 @@ class trace: | |||||
sym_outputs = mgb.optimize_for_inference( | sym_outputs = mgb.optimize_for_inference( | ||||
sym_outputs, **optimize_for_inference_kwargs | sym_outputs, **optimize_for_inference_kwargs | ||||
) | ) | ||||
for var, name in zip(sym_outputs, output_names): | |||||
var.rename(name) | |||||
mgb.serialize_comp_graph_to_file(fpath, sym_outputs, append=append) | mgb.serialize_comp_graph_to_file(fpath, sym_outputs, append=append) | ||||
def get_profile(self): | def get_profile(self): | ||||
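
A sketch of how `output_names` is intended to be used with `trace.dump`, assuming the old-style `trace` decorator and `megengine.functional` API of this release; the function body, file name, and the names "data" and "pred" are made up for illustration:

import numpy as np
import megengine as mge
import megengine.functional as F
from megengine.jit import trace

@trace(symbolic=True)
def pred_fun(data):
    return F.relu(data) * 2

# Run once so the traced graph and its symbolic outputs exist before dumping.
x = mge.tensor(np.random.rand(4).astype("float32"))
pred_fun(x)

# A single string is promoted to a one-element list; its length must match
# the number of outputs of the traced function.
pred_fun.dump("model.mge", arg_names=["data"], output_names="pred")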
@@ -464,7 +464,7 @@ class Module(metaclass=ABCMeta): | |||||
# For quantized dtype, the initialized dtype | # For quantized dtype, the initialized dtype | ||||
# scale/zero_points maybe invalid, use pretrained dtype instead. | # scale/zero_points maybe invalid, use pretrained dtype instead. | ||||
if is_quantize(to_be_load.dtype) and is_quantize(var.dtype): | if is_quantize(to_be_load.dtype) and is_quantize(var.dtype): | ||||
var.set_dtype(to_be_load.dtype) | |||||
var.dtype = to_be_load.dtype | |||||
var.set_value(to_be_load) | var.set_value(to_be_load) | ||||
loaded.append(k) | loaded.append(k) | ||||
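
The branch above boils down to the following standalone sketch; `pretrained` stands in for the value taken from the checkpoint's state dict, and the scales are invented for illustration:

import numpy as np
import megengine as mge
import megengine._internal as mgb

# The pretrained value carries the calibrated scale; the freshly initialized
# parameter only has a placeholder scale of 1.0.
pretrained = np.ones((3, 4), dtype=mgb.dtype.qint8(0.3))
var = mge.Parameter(np.ones((3, 4), dtype=mgb.dtype.qint8(1.0)))

var.dtype = pretrained.dtype   # adopt the pretrained quantized dtype first
var.set_value(pretrained)      # then copy the values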
@@ -46,29 +46,46 @@ def test_tensor_set_dtype():
     )

     t = mge.Parameter(np.ones((3, 4), dtype="float32"))
-    t.set_dtype(mgb.dtype.qint8(0.1))
+    t.dtype = mgb.dtype.qint8(0.1)
     check_dtype_value(t, 0.1, 10)

     t = mge.Parameter(np.ones((3, 4), dtype=mgb.dtype.qint8(1)))
-    t.set_dtype(mgb.dtype.qint8(0.3))
+    t.dtype = mgb.dtype.qint8(0.3)
     check_dtype_value(t, 0.3, 3)

     t = mge.Buffer(np.ones((3, 4), dtype="float32"))
-    t.set_dtype(mgb.dtype.qint8(0.1))
+    t.dtype = mgb.dtype.qint8(0.1)
     check_dtype_value(t, 0.1, 10)

     t = mge.Buffer(np.ones((3, 4), dtype=mgb.dtype.qint8(1)))
-    t.set_dtype(mgb.dtype.qint8(0.3))
+    t.dtype = mgb.dtype.qint8(0.3)
     check_dtype_value(t, 0.3, 3)

     t = mge.Buffer(np.ones((3, 4), dtype="float32"))
     s = t + 1
-    s.set_dtype(mgb.dtype.qint8(0.2))
+    s.dtype = mgb.dtype.qint8(0.2)
     check_dtype_value(s, 0.2, 10)
-    t.set_dtype(mgb.dtype.qint8(0.3))
+    t.dtype = mgb.dtype.qint8(0.3)
     s = t + 1
-    s.set_dtype(mgb.dtype.qint8(0.1))
+    s.dtype = mgb.dtype.qint8(0.1)
     check_dtype_value(s, 0.1, 18)
-    s.set_dtype("float32")
+    s.dtype = "float32"
     check_dtype_value(s, 0, 1.8)
+
+
+def test_tensor_name():
+    p = mge.Parameter(np.ones((3, 4), dtype="float32"))
+    assert "shared" in p.name
+    with pytest.raises(ValueError):
+        p.name = "Parameter0"
+
+    b = mge.Buffer(np.ones((3, 4), dtype="float32"))
+    assert "shared" in b.name
+    with pytest.raises(ValueError):
+        b.name = "Buffer0"
+
+    s = b + 1
+    assert "ADD" in s.name
+    s.name = "WeightAdd1"
+    assert s.name == "WeightAdd1"