feat(mge/jit): add support for setting output symbol var names in dump

GitOrigin-RevId: 258c03ee34
Tag: tags/v1.0.0-rc1
Author: Megvii Engine Team
Parent commit: 380cb6e47f

4 changed files with 64 additions and 10 deletions
  1. python_module/megengine/core/tensor.py  (+22, -1)
  2. python_module/megengine/jit/__init__.py  (+16, -0)
  3. python_module/megengine/module/module.py  (+1, -1)
  4. python_module/test/unit/core/test_tensor.py  (+25, -8)
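
In practice, the feature reads as follows; a minimal sketch assuming the 0.x-style trace API in this tree (the traced function, file name, and output name are illustrative, not taken from the commit):

    import numpy as np
    import megengine as mge
    from megengine.jit import trace

    @trace(symbolic=True)
    def double(x):
        return x * 2

    # run once so the trace records the symbolic graph
    double(mge.tensor(np.ones((3, 4), dtype="float32")))
    # dump the graph, naming its single output var "doubled"
    double.dump("double.mgb", output_names="doubled")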

python_module/megengine/core/tensor.py  (+22, -1)

@@ -235,15 +235,36 @@ class Tensor:
             return self.__val.dtype
         return self._symvar.dtype

-    def set_dtype(self, dtype: str = None):
+    @dtype.setter
+    def dtype(self, dtype: str = None):
         r"""Set the data type of the tensor.
         """
         if self.__val is not None:
             self.__val = mgb.make_shared(self.device, value=self.astype(dtype).numpy())
         elif self.__sym_override is not None:
             self.__sym_override = self.__sym_override.astype(dtype)
         elif self.__sym is not None:
             self.__sym = self.__sym.astype(dtype)

+    @property
+    def name(self):
+        r"""Get the tensor name; not supported for Parameter and Buffer.
+        """
+        return self._symvar.name
+
+    @name.setter
+    def name(self, name: str = None):
+        r"""Set the tensor name; not supported for Parameter and Buffer.
+        """
+        if self.__val is not None:
+            raise ValueError("name setting is not available for Parameter or Buffer.")
+        if self.__sym_override is not None:
+            self.__sym_override = self.__sym_override.rename(name)
+        if self.__sym is not None:
+            assert not self.__val
+            self.__sym = self.__sym.rename(name)
+
     @property
     def _comp_node(self):
         if self.__val is not None:
             return self.__val.comp_node
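
A short sketch of the two setters at a call site, mirroring the test imports (mge, mgb, numpy as np); the values and the name are illustrative, and behavior follows the branches above:

    t = mge.Buffer(np.ones((3, 4), dtype="float32"))
    s = t + 1                        # s wraps a plain SymbolVar, not a shared value
    s.dtype = mgb.dtype.qint8(0.2)   # rebuilds the var through astype()
    s.name = "scaled_out"            # delegates to SymbolVar.rename()
    # t.name = "weights"             # would raise ValueError: Buffer holds a shared value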


python_module/megengine/jit/__init__.py  (+16, -0)

@@ -436,6 +436,7 @@ class trace:
         arg_names=None,
         append=False,
         optimize_for_inference=False,
+        output_names=None,
         **kwargs
     ):
         """
@@ -446,6 +447,8 @@ class trace:
         :param append: whether output is appended to ``fpath``.
         :param optimize_for_inference: whether to enable optimize_for_inference
             pass before dump.
+        :param output_names: names of the output tensors in the traced function;
+            the default names are used if not specified.

         :param enable_io16xc32: whether to use float16 for I/O between oprs and use
             float32 as internal computation precision. Note the output var would be
@@ -488,6 +491,17 @@ class trace:
                     len(self._args), len(arg_names)
                 )
             )
+        if isinstance(output_names, str):
+            output_names = [output_names]
+        if output_names is None:
+            output_names = [var.name for var in self._sym_outputs]
+        elif len(output_names) != len(self._sym_outputs):
+            raise ValueError(
+                "len(output_names) should be {}, got {}".format(
+                    len(self._sym_outputs), len(output_names)
+                )
+            )
+
         optimize_for_inference_args_map = {
             "enable_io16xc32": "f16_io_f32_comp",
             "enable_ioc16": "f16_io_comp",
@@ -541,6 +555,8 @@ class trace:
             sym_outputs = mgb.optimize_for_inference(
                 sym_outputs, **optimize_for_inference_kwargs
             )
+        for var, name in zip(sym_outputs, output_names):
+            var.rename(name)
         mgb.serialize_comp_graph_to_file(fpath, sym_outputs, append=append)

     def get_profile(self):
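
The added checks give output_names forgiving semantics for the common single-output case; a sketch of the accepted forms, assuming a traced function f with two output vars (names are illustrative):

    f.dump("net.mgb", output_names=["logits", "feature"])  # one name per output var
    f.dump("net.mgb")                                      # keep the vars' existing names
    f.dump("net.mgb", output_names="logits")               # str is promoted to ["logits"],
                                                           # then fails: len should be 2, got 1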


python_module/megengine/module/module.py  (+1, -1)

@@ -464,7 +464,7 @@ class Module(metaclass=ABCMeta):
                 # For quantized dtypes, the initialized
                 # scale/zero_points may be invalid; use the pretrained dtype instead.
                 if is_quantize(to_be_load.dtype) and is_quantize(var.dtype):
-                    var.set_dtype(to_be_load.dtype)
+                    var.dtype = to_be_load.dtype
                 var.set_value(to_be_load)
                 loaded.append(k)
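
For context, the rule this call site implements when loading a quantized checkpoint, sketched with illustrative scales (the is_quantize import path is an assumption here, e.g. mgb.dtype.is_quantize):

    var = mge.Parameter(np.ones((3, 4), dtype=mgb.dtype.qint8(1.0)))  # freshly initialized scale
    to_be_load = np.ones((3, 4), dtype=mgb.dtype.qint8(0.01))         # pretrained scale
    if mgb.dtype.is_quantize(to_be_load.dtype) and mgb.dtype.is_quantize(var.dtype):
        var.dtype = to_be_load.dtype  # adopt the pretrained scale/zero_point
    var.set_value(to_be_load)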



python_module/test/unit/core/test_tensor.py  (+25, -8)

@@ -46,29 +46,46 @@ def test_tensor_set_dtype():
     )

     t = mge.Parameter(np.ones((3, 4), dtype="float32"))
-    t.set_dtype(mgb.dtype.qint8(0.1))
+    t.dtype = mgb.dtype.qint8(0.1)
     check_dtype_value(t, 0.1, 10)

     t = mge.Parameter(np.ones((3, 4), dtype=mgb.dtype.qint8(1)))
-    t.set_dtype(mgb.dtype.qint8(0.3))
+    t.dtype = mgb.dtype.qint8(0.3)
     check_dtype_value(t, 0.3, 3)

     t = mge.Buffer(np.ones((3, 4), dtype="float32"))
-    t.set_dtype(mgb.dtype.qint8(0.1))
+    t.dtype = mgb.dtype.qint8(0.1)
     check_dtype_value(t, 0.1, 10)

     t = mge.Buffer(np.ones((3, 4), dtype=mgb.dtype.qint8(1)))
-    t.set_dtype(mgb.dtype.qint8(0.3))
+    t.dtype = mgb.dtype.qint8(0.3)
     check_dtype_value(t, 0.3, 3)

     t = mge.Buffer(np.ones((3, 4), dtype="float32"))
     s = t + 1
-    s.set_dtype(mgb.dtype.qint8(0.2))
+    s.dtype = mgb.dtype.qint8(0.2)
     check_dtype_value(s, 0.2, 10)

-    t.set_dtype(mgb.dtype.qint8(0.3))
+    t.dtype = mgb.dtype.qint8(0.3)
     s = t + 1
-    s.set_dtype(mgb.dtype.qint8(0.1))
+    s.dtype = mgb.dtype.qint8(0.1)
     check_dtype_value(s, 0.1, 18)
-    s.set_dtype("float32")
+    s.dtype = "float32"
     check_dtype_value(s, 0, 1.8)
+
+
+def test_tensor_name():
+    p = mge.Parameter(np.ones((3, 4), dtype="float32"))
+    assert "shared" in p.name
+    with pytest.raises(ValueError):
+        p.name = "Parameter0"
+
+    b = mge.Buffer(np.ones((3, 4), dtype="float32"))
+    assert "shared" in b.name
+    with pytest.raises(ValueError):
+        b.name = "Buffer0"
+
+    s = b + 1
+    assert "ADD" in s.name
+    s.name = "WeightAdd1"
+    assert s.name == "WeightAdd1"
