GitOrigin-RevId: 258c03ee34
@@ -235,15 +235,36 @@ class Tensor:
            return self.__val.dtype
        return self._symvar.dtype

    def set_dtype(self, dtype: str = None):
    @dtype.setter
    def dtype(self, dtype: str = None):
        r"""Set the data type of the tensor.
        """
        if self.__val is not None:
            self.__val = mgb.make_shared(self.device, value=self.astype(dtype).numpy())
        elif self.__sym_override is not None:
            self.__sym_override = self.__sym_override.astype(dtype)
        elif self.__sym is not None:
            self.__sym = self.__sym.astype(dtype)

    @property
    def name(self):
r"""Get the tensor name, does not support Parameter and Buffer. | |||
""" | |||
return self._symvar.name | |||
@name.setter | |||
def name(self, name: str = None): | |||
r"""Set the tensor name, does not support Parameter and Buffer. | |||
""" | |||
        if self.__val is not None:
            raise ValueError("name setting is not available for Parameter or Buffer.")
        if self.__sym_override is not None:
            self.__sym_override = self.__sym_override.rename(name)
        if self.__sym is not None:
            assert not self.__val
            self.__sym = self.__sym.rename(name)

    @property
    def _comp_node(self):
        if self.__val is not None:
            return self.__val.comp_node
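A minimal usage sketch of the two properties added above, assuming the old-style MegEngine v0.x API where `megengine._internal` is importable as `mgb` (as in the tests further down); the values and names are illustrative:

```python
import numpy as np
import megengine as mge
import megengine._internal as mgb  # assumed import path for the mgb dtype helpers

# dtype is now an assignable property (previously set_dtype())
t = mge.Buffer(np.ones((3, 4), dtype="float32"))
t.dtype = mgb.dtype.qint8(0.1)   # re-quantizes the stored value in place
print(t.dtype)

# name can only be assigned on functional tensors, not Parameter/Buffer
s = t + 1
s.name = "scaled_input"          # illustrative name; renames the underlying symbolic var
print(s.name)                    # -> "scaled_input"
```

Assigning a name to `t` itself would raise ValueError, matching the setter above.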
@@ -436,6 +436,7 @@ class trace:
        arg_names=None,
        append=False,
        optimize_for_inference=False,
        output_names=None,
        **kwargs
    ):
        """
@@ -446,6 +447,8 @@ class trace:
        :param append: whether output is appended to ``fpath``.
        :param optimize_for_inference: whether to enable optimize_for_inference
            pass before dump.
        :param output_names: names of the output tensors in the traced function;
            the default names are used if not specified.
        :param enable_io16xc32: whether to use float16 for I/O between oprs and use
            float32 as internal computation precision. Note the output var would be
@@ -488,6 +491,17 @@ class trace:
                        len(self._args), len(arg_names)
                    )
                )
        if isinstance(output_names, str):
            output_names = [output_names]
        if output_names is None:
            output_names = [var.name for var in self._sym_outputs]
        elif len(output_names) != len(self._sym_outputs):
            raise ValueError(
                "len(output_names) should be {}, got {}".format(
                    len(self._sym_outputs), len(output_names)
                )
            )
        optimize_for_inference_args_map = {
            "enable_io16xc32": "f16_io_f32_comp",
            "enable_ioc16": "f16_io_comp",
@@ -541,6 +555,8 @@ class trace:
            sym_outputs = mgb.optimize_for_inference(
                sym_outputs, **optimize_for_inference_kwargs
            )
        for var, name in zip(sym_outputs, output_names):
            var.rename(name)
        mgb.serialize_comp_graph_to_file(fpath, sym_outputs, append=append)

    def get_profile(self):
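For the new `output_names` argument, a rough end-to-end sketch under the old `megengine.jit.trace` workflow; the function, file path, and names here are illustrative, and the traced function is run once so there are symbolic outputs to serialize:

```python
import numpy as np
import megengine as mge
from megengine.jit import trace

@trace(symbolic=True)
def double(data):
    return data * 2

# execute once so the trace is finished and has symbolic outputs to dump
double(mge.tensor(np.ones((3, 4), dtype="float32")))

# a bare string is accepted and wrapped into a one-element list;
# omitting output_names keeps the default var names
double.dump("/tmp/double.mge", arg_names=["data"], output_names="out")
```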
@@ -464,7 +464,7 @@ class Module(metaclass=ABCMeta):
            # For quantized dtype, the scale/zero_point of the initialized
            # dtype may be invalid; use the pretrained dtype instead.
            if is_quantize(to_be_load.dtype) and is_quantize(var.dtype):
                var.set_dtype(to_be_load.dtype)
                var.dtype = to_be_load.dtype
            var.set_value(to_be_load)
            loaded.append(k)
@@ -46,29 +46,46 @@ def test_tensor_set_dtype():
        )
    t = mge.Parameter(np.ones((3, 4), dtype="float32"))
    t.set_dtype(mgb.dtype.qint8(0.1))
    t.dtype = mgb.dtype.qint8(0.1)
    check_dtype_value(t, 0.1, 10)
    t = mge.Parameter(np.ones((3, 4), dtype=mgb.dtype.qint8(1)))
    t.set_dtype(mgb.dtype.qint8(0.3))
    t.dtype = mgb.dtype.qint8(0.3)
    check_dtype_value(t, 0.3, 3)
    t = mge.Buffer(np.ones((3, 4), dtype="float32"))
    t.set_dtype(mgb.dtype.qint8(0.1))
    t.dtype = mgb.dtype.qint8(0.1)
    check_dtype_value(t, 0.1, 10)
    t = mge.Buffer(np.ones((3, 4), dtype=mgb.dtype.qint8(1)))
    t.set_dtype(mgb.dtype.qint8(0.3))
    t.dtype = mgb.dtype.qint8(0.3)
    check_dtype_value(t, 0.3, 3)
    t = mge.Buffer(np.ones((3, 4), dtype="float32"))
    s = t + 1
    s.set_dtype(mgb.dtype.qint8(0.2))
    s.dtype = mgb.dtype.qint8(0.2)
    check_dtype_value(s, 0.2, 10)
    t.set_dtype(mgb.dtype.qint8(0.3))
    t.dtype = mgb.dtype.qint8(0.3)
    s = t + 1
    s.set_dtype(mgb.dtype.qint8(0.1))
    s.dtype = mgb.dtype.qint8(0.1)
    check_dtype_value(s, 0.1, 18)
    s.set_dtype("float32")
    s.dtype = "float32"
    check_dtype_value(s, 0, 1.8)


def test_tensor_name():
    p = mge.Parameter(np.ones((3, 4), dtype="float32"))
    assert "shared" in p.name
    with pytest.raises(ValueError):
        p.name = "Parameter0"
    b = mge.Buffer(np.ones((3, 4), dtype="float32"))
    assert "shared" in b.name
    with pytest.raises(ValueError):
        b.name = "Buffer0"
    s = b + 1
    assert "ADD" in s.name
    s.name = "WeightAdd1"
    assert s.name == "WeightAdd1"