Browse Source

feat(imperative/utils): optimize the naming rules

GitOrigin-RevId: 329bac640a
tags/v1.3.1
Megvii Engine Team 4 years ago
parent
commit
aed681d325
7 changed files with 74 additions and 50 deletions
  1. +6
    -8
      imperative/python/megengine/jit/tracing.py
  2. +4
    -4
      imperative/python/megengine/module/module.py
  3. +2
    -2
      imperative/python/megengine/tensor.py
  4. +44
    -30
      imperative/python/megengine/utils/naming.py
  5. +2
    -2
      imperative/python/test/unit/jit/test_tracing.py
  6. +15
    -3
      imperative/python/test/unit/utils/test_dump_naming.py
  7. +1
    -1
      src/opr/impl/io.cpp

+ 6
- 8
imperative/python/megengine/jit/tracing.py View File

@@ -40,7 +40,7 @@ from ..core.ops.builtin import BackwardGraph, OpDef
from ..core.ops.special import Const from ..core.ops.special import Const
from ..core.tensor import megbrain_graph as G from ..core.tensor import megbrain_graph as G
from ..core.tensor.utils import setscalar from ..core.tensor.utils import setscalar
from ..utils.naming import auto_naming
from ..utils.naming import AutoNaming
from .sublinear_memory_config import SublinearMemoryConfig from .sublinear_memory_config import SublinearMemoryConfig




@@ -297,9 +297,7 @@ class trace:
h = getattr(x, "_mixin_handle", -1) h = getattr(x, "_mixin_handle", -1)
if h < 0 or (not self._capture_as_const and self._tinfo[h].exported): if h < 0 or (not self._capture_as_const and self._tinfo[h].exported):
h, info = self._new_handle() h, info = self._new_handle()
name = (
auto_naming.get_scope() + "." + (x.c_name if x.c_name else x._name)
)
name = AutoNaming.gen_name(x)
info.name = name info.name = name
info.external = True info.external = True
info.device = x.device info.device = x.device
@@ -845,17 +843,17 @@ class trace:
ivars.append(h2v[h]) ivars.append(h2v[h])
ovars = G.apply_normal_varnode(op, *ivars) ovars = G.apply_normal_varnode(op, *ivars)


auto_naming.record_opnode(ovars[0].op)
AutoNaming.record_opnode(ovars[0].op)


assert len(ovars) == len(ohandles) assert len(ovars) == len(ohandles)
h2v.update(zip(ohandles, ovars)) h2v.update(zip(ohandles, ovars))


for i in ohandles: for i in ohandles:
name = auto_naming.get_var_name(i)
name = AutoNaming.get_var_name(i)
if name is not None: if name is not None:
h2v[i].name = name h2v[i].name = name


auto_naming.remove_duplicate_names()
AutoNaming.remove_duplicate_names()


dest_vars = [] dest_vars = []
for i, h in enumerate(self._output_bindings): for i, h in enumerate(self._output_bindings):
@@ -1173,7 +1171,7 @@ def apply_const_compiled_mode(value, dtype, device, is_const, no_cache, name):


def apply_with_tracing(op: OpDef, *args: RawTensor): def apply_with_tracing(op: OpDef, *args: RawTensor):
if hasattr(op, "scope"): if hasattr(op, "scope"):
op.scope = auto_naming.get_scope()
op.scope = AutoNaming.get_scope()
if active_trace._symbolic: if active_trace._symbolic:
outputs = apply_symbolic_mode(op, *args) outputs = apply_symbolic_mode(op, *args)
else: else:


+ 4
- 4
imperative/python/megengine/module/module.py View File

@@ -16,7 +16,7 @@ from ..logger import get_logger
from ..tensor import Parameter, Tensor from ..tensor import Parameter, Tensor
from ..utils.deprecation import deprecated from ..utils.deprecation import deprecated
from ..utils.hook import HookHandler from ..utils.hook import HookHandler
from ..utils.naming import auto_naming
from ..utils.naming import AutoNaming


logger = get_logger(__name__) logger = get_logger(__name__)


@@ -111,7 +111,7 @@ class Module(metaclass=ABCMeta):
self._forward_hooks = OrderedDict() self._forward_hooks = OrderedDict()


# used for profiler and automatic naming # used for profiler and automatic naming
self._name = "{anonymous}"
self._name = None


@abstractmethod @abstractmethod
def forward(self, inputs): def forward(self, inputs):
@@ -137,7 +137,7 @@ class Module(metaclass=ABCMeta):
return HookHandler(self._forward_hooks, hook) return HookHandler(self._forward_hooks, hook)


def __call__(self, *inputs, **kwargs): def __call__(self, *inputs, **kwargs):
auto_naming.push_scope(self.name if self.name is not None else self._name)
AutoNaming.push_scope(self.name if self.name is not None else self._name)
for hook in self._forward_pre_hooks.values(): for hook in self._forward_pre_hooks.values():
modified_inputs = hook(self, inputs) modified_inputs = hook(self, inputs)
if modified_inputs is not None: if modified_inputs is not None:
@@ -151,7 +151,7 @@ class Module(metaclass=ABCMeta):
modified_outputs = hook(self, inputs, outputs) modified_outputs = hook(self, inputs, outputs)
if modified_outputs is not None: if modified_outputs is not None:
outputs = modified_outputs outputs = modified_outputs
auto_naming.pop_scope()
AutoNaming.pop_scope()
return outputs return outputs


def _flatten( def _flatten(


+ 2
- 2
imperative/python/megengine/tensor.py View File

@@ -20,7 +20,7 @@ from .core.tensor.array_method import ArrayMethodMixin
from .device import _valid_device, get_default_device from .device import _valid_device, get_default_device
from .logger import get_logger from .logger import get_logger
from .utils.deprecation import deprecated from .utils.deprecation import deprecated
from .utils.naming import auto_naming
from .utils.naming import AutoNaming


logger = get_logger(__name__) logger = get_logger(__name__)


@@ -168,7 +168,7 @@ class Tensor(_Tensor, ArrayMethodMixin):
@name.setter @name.setter
def name(self, name): def name(self, name):
self.c_name = name self.c_name = name
auto_naming.record_var_name(self._mixin_handle, name)
AutoNaming.record_var_name(self._mixin_handle, name)


@deprecated(version="1.0", reason="no need to reuse an existing tensor since 1.0") @deprecated(version="1.0", reason="no need to reuse an existing tensor since 1.0")
def set_value(self, value): def set_value(self, value):


+ 44
- 30
imperative/python/megengine/utils/naming.py View File

@@ -15,40 +15,57 @@ class AutoNaming:
renamed by the user. renamed by the user.
""" """


def __init__(self):
self.scopes = []
self.c_ops = []
self.name2ops = {}
self.handle2names = {}
scopes = []
c_ops = []
name2ops = {}
handle2names = {}
__cls_attributes__ = {"scopes", "c_ops", "name2ops", "handle2names"}


def clear(self):
for var in vars(self).values():
var.clear()
@classmethod
def clear(cls):
for attr in cls.__cls_attributes__:
getattr(cls, attr).clear()


def push_scope(self, scope):
push_scope(scope)
self.scopes.append(scope)
@classmethod
def push_scope(cls, scope):
if scope is not None:
push_scope(scope)
cls.scopes.append(scope)


def pop_scope(self):
scope = self.scopes.pop()
pop_scope(scope)
@classmethod
def pop_scope(cls):
scope = cls.scopes.pop()
if scope is not None:
pop_scope(scope)


def get_scope(self):
return ".".join(self.scopes)
@classmethod
def get_scope(cls):
return ".".join(s for s in cls.scopes if s is not None)


def record_var_name(self, handle, name):
self.handle2names[handle] = name
@classmethod
def gen_name(cls, x) -> str:
scope = cls.get_scope()
name = x.c_name if x.c_name else x._name
return scope + "." + name if len(scope) else name


def get_var_name(self, handle):
return self.handle2names.pop(handle, None)
@classmethod
def record_var_name(cls, handle, name):
cls.handle2names[handle] = name


def record_opnode(self, op):
ops = self.name2ops.get(op.name, [])
ops.append(op)
self.name2ops[op.name] = ops
@classmethod
def get_var_name(cls, handle):
return cls.handle2names.pop(handle, None)


def remove_duplicate_names(self):
for key, ops in self.name2ops.items():
@classmethod
def record_opnode(cls, op):
ops = cls.name2ops.get(op.name, [])
if op not in ops:
ops.append(op)
cls.name2ops[op.name] = ops

@classmethod
def remove_duplicate_names(cls):
for key, ops in cls.name2ops.items():
if len(ops) == 1: if len(ops) == 1:
continue continue
for i, op in enumerate(ops): for i, op in enumerate(ops):
@@ -57,7 +74,4 @@ class AutoNaming:
continue continue
for var in op.outputs: for var in op.outputs:
var.name = var.name.replace(key, op.name) var.name = var.name.replace(key, op.name)
self.name2ops.clear()


auto_naming = AutoNaming()
cls.name2ops.clear()

+ 2
- 2
imperative/python/test/unit/jit/test_tracing.py View File

@@ -28,7 +28,7 @@ from megengine.functional import exp, log
from megengine.jit import exclude_from_trace, trace from megengine.jit import exclude_from_trace, trace
from megengine.module import Module from megengine.module import Module
from megengine.random import normal, uniform from megengine.random import normal, uniform
from megengine.utils.naming import auto_naming
from megengine.utils.naming import AutoNaming




@pytest.mark.parametrize("trace_mode", [False, True]) @pytest.mark.parametrize("trace_mode", [False, True])
@@ -141,7 +141,7 @@ def test_dump():
return a + b return a + b


# prevent from remaining scope from exception test # prevent from remaining scope from exception test
auto_naming.clear()
AutoNaming.clear()
a = tensor([2]) a = tensor([2])
b = tensor([4]) b = tensor([4])
y = f(a, b).numpy() y = f(a, b).numpy()


+ 15
- 3
imperative/python/test/unit/utils/test_dump_naming.py View File

@@ -18,11 +18,11 @@ from megengine import Parameter, Tensor
from megengine.core.tensor import megbrain_graph as G from megengine.core.tensor import megbrain_graph as G
from megengine.jit.tracing import trace from megengine.jit.tracing import trace
from megengine.quantization.quantize import quantize, quantize_qat from megengine.quantization.quantize import quantize, quantize_qat
from megengine.utils.naming import auto_naming
from megengine.utils.naming import AutoNaming




def _dump_and_load(func, symbolic, keep_opr_name=True): def _dump_and_load(func, symbolic, keep_opr_name=True):
auto_naming.clear()
AutoNaming.clear()
func = trace(func, symbolic=symbolic, capture_as_const=True) func = trace(func, symbolic=symbolic, capture_as_const=True)
x = Tensor(np.ones(shape=(2, 3))) x = Tensor(np.ones(shape=(2, 3)))
func(x).numpy() func(x).numpy()
@@ -104,6 +104,18 @@ def test_without_module(symbolic):




@pytest.mark.parametrize("symbolic", [False, True]) @pytest.mark.parametrize("symbolic", [False, True])
def test_ignore_top_module(symbolic):
class Simple(M.Module):
def forward(self, x):
return x + x

m = Simple()
op = _dump_and_load(m, symbolic)[-1]
assert op.name == "ADD"
assert op.outputs[0].name == "ADD"


@pytest.mark.parametrize("symbolic", [False, True])
def test_with_submodule(symbolic): def test_with_submodule(symbolic):
class Simple(M.Module): class Simple(M.Module):
def __init__(self, name): def __init__(self, name):
@@ -196,7 +208,7 @@ def test_not_keep_opr_name():
return 2 * x return 2 * x


op = _dump_and_load(f, True, False)[-1] op = _dump_and_load(f, True, False)[-1]
assert op.name == "MUL(x,2[2])[4]"
assert op.name == "MUL(x,const<2>[2])[4]"




@pytest.mark.parametrize("symbolic", [False, True]) @pytest.mark.parametrize("symbolic", [False, True])


+ 1
- 1
src/opr/impl/io.cpp View File

@@ -419,7 +419,7 @@ void ImmutableTensor::Value::setup(CompNode cn, const HostTensorND &val) {
if (one_elem(val.shape())) { if (one_elem(val.shape())) {
float v; float v;
static_cast_dtype(&v, val.dtype(), val.raw_ptr()); static_cast_dtype(&v, val.dtype(), val.raw_ptr());
m_summary = ssprintf("%.3g", v);
m_summary = ssprintf("const<%.3g>", v);
if (val.shape().ndim != 1) { if (val.shape().ndim != 1) {
m_summary += val.shape().to_string(); m_summary += val.shape().to_string();
} }


Loading…
Cancel
Save