GitOrigin-RevId: 329bac640a
tags/v1.3.1
@@ -40,7 +40,7 @@ from ..core.ops.builtin import BackwardGraph, OpDef | |||
from ..core.ops.special import Const | |||
from ..core.tensor import megbrain_graph as G | |||
from ..core.tensor.utils import setscalar | |||
from ..utils.naming import auto_naming | |||
from ..utils.naming import AutoNaming | |||
from .sublinear_memory_config import SublinearMemoryConfig | |||
@@ -297,9 +297,7 @@ class trace: | |||
h = getattr(x, "_mixin_handle", -1) | |||
if h < 0 or (not self._capture_as_const and self._tinfo[h].exported): | |||
h, info = self._new_handle() | |||
name = ( | |||
auto_naming.get_scope() + "." + (x.c_name if x.c_name else x._name) | |||
) | |||
name = AutoNaming.gen_name(x) | |||
info.name = name | |||
info.external = True | |||
info.device = x.device | |||
@@ -845,17 +843,17 @@ class trace: | |||
ivars.append(h2v[h]) | |||
ovars = G.apply_normal_varnode(op, *ivars) | |||
auto_naming.record_opnode(ovars[0].op) | |||
AutoNaming.record_opnode(ovars[0].op) | |||
assert len(ovars) == len(ohandles) | |||
h2v.update(zip(ohandles, ovars)) | |||
for i in ohandles: | |||
name = auto_naming.get_var_name(i) | |||
name = AutoNaming.get_var_name(i) | |||
if name is not None: | |||
h2v[i].name = name | |||
auto_naming.remove_duplicate_names() | |||
AutoNaming.remove_duplicate_names() | |||
dest_vars = [] | |||
for i, h in enumerate(self._output_bindings): | |||
@@ -1173,7 +1171,7 @@ def apply_const_compiled_mode(value, dtype, device, is_const, no_cache, name): | |||
def apply_with_tracing(op: OpDef, *args: RawTensor): | |||
if hasattr(op, "scope"): | |||
op.scope = auto_naming.get_scope() | |||
op.scope = AutoNaming.get_scope() | |||
if active_trace._symbolic: | |||
outputs = apply_symbolic_mode(op, *args) | |||
else: | |||
@@ -16,7 +16,7 @@ from ..logger import get_logger | |||
from ..tensor import Parameter, Tensor | |||
from ..utils.deprecation import deprecated | |||
from ..utils.hook import HookHandler | |||
from ..utils.naming import auto_naming | |||
from ..utils.naming import AutoNaming | |||
logger = get_logger(__name__) | |||
@@ -111,7 +111,7 @@ class Module(metaclass=ABCMeta): | |||
self._forward_hooks = OrderedDict() | |||
# used for profiler and automatic naming | |||
self._name = "{anonymous}" | |||
self._name = None | |||
@abstractmethod | |||
def forward(self, inputs): | |||
@@ -137,7 +137,7 @@ class Module(metaclass=ABCMeta): | |||
return HookHandler(self._forward_hooks, hook) | |||
def __call__(self, *inputs, **kwargs): | |||
auto_naming.push_scope(self.name if self.name is not None else self._name) | |||
AutoNaming.push_scope(self.name if self.name is not None else self._name) | |||
for hook in self._forward_pre_hooks.values(): | |||
modified_inputs = hook(self, inputs) | |||
if modified_inputs is not None: | |||
@@ -151,7 +151,7 @@ class Module(metaclass=ABCMeta): | |||
modified_outputs = hook(self, inputs, outputs) | |||
if modified_outputs is not None: | |||
outputs = modified_outputs | |||
auto_naming.pop_scope() | |||
AutoNaming.pop_scope() | |||
return outputs | |||
def _flatten( | |||
@@ -20,7 +20,7 @@ from .core.tensor.array_method import ArrayMethodMixin | |||
from .device import _valid_device, get_default_device | |||
from .logger import get_logger | |||
from .utils.deprecation import deprecated | |||
from .utils.naming import auto_naming | |||
from .utils.naming import AutoNaming | |||
logger = get_logger(__name__) | |||
@@ -168,7 +168,7 @@ class Tensor(_Tensor, ArrayMethodMixin): | |||
@name.setter | |||
def name(self, name): | |||
self.c_name = name | |||
auto_naming.record_var_name(self._mixin_handle, name) | |||
AutoNaming.record_var_name(self._mixin_handle, name) | |||
@deprecated(version="1.0", reason="no need to reuse an existing tensor since 1.0") | |||
def set_value(self, value): | |||
@@ -15,40 +15,57 @@ class AutoNaming: | |||
renamed by the user. | |||
""" | |||
def __init__(self): | |||
self.scopes = [] | |||
self.c_ops = [] | |||
self.name2ops = {} | |||
self.handle2names = {} | |||
scopes = [] | |||
c_ops = [] | |||
name2ops = {} | |||
handle2names = {} | |||
__cls_attributes__ = {"scopes", "c_ops", "name2ops", "handle2names"} | |||
def clear(self): | |||
for var in vars(self).values(): | |||
var.clear() | |||
@classmethod | |||
def clear(cls): | |||
for attr in cls.__cls_attributes__: | |||
getattr(cls, attr).clear() | |||
def push_scope(self, scope): | |||
push_scope(scope) | |||
self.scopes.append(scope) | |||
@classmethod | |||
def push_scope(cls, scope): | |||
if scope is not None: | |||
push_scope(scope) | |||
cls.scopes.append(scope) | |||
def pop_scope(self): | |||
scope = self.scopes.pop() | |||
pop_scope(scope) | |||
@classmethod | |||
def pop_scope(cls): | |||
scope = cls.scopes.pop() | |||
if scope is not None: | |||
pop_scope(scope) | |||
def get_scope(self): | |||
return ".".join(self.scopes) | |||
@classmethod | |||
def get_scope(cls): | |||
return ".".join(s for s in cls.scopes if s is not None) | |||
def record_var_name(self, handle, name): | |||
self.handle2names[handle] = name | |||
@classmethod | |||
def gen_name(cls, x) -> str: | |||
scope = cls.get_scope() | |||
name = x.c_name if x.c_name else x._name | |||
return scope + "." + name if len(scope) else name | |||
def get_var_name(self, handle): | |||
return self.handle2names.pop(handle, None) | |||
@classmethod | |||
def record_var_name(cls, handle, name): | |||
cls.handle2names[handle] = name | |||
def record_opnode(self, op): | |||
ops = self.name2ops.get(op.name, []) | |||
ops.append(op) | |||
self.name2ops[op.name] = ops | |||
@classmethod | |||
def get_var_name(cls, handle): | |||
return cls.handle2names.pop(handle, None) | |||
def remove_duplicate_names(self): | |||
for key, ops in self.name2ops.items(): | |||
@classmethod | |||
def record_opnode(cls, op): | |||
ops = cls.name2ops.get(op.name, []) | |||
if op not in ops: | |||
ops.append(op) | |||
cls.name2ops[op.name] = ops | |||
@classmethod | |||
def remove_duplicate_names(cls): | |||
for key, ops in cls.name2ops.items(): | |||
if len(ops) == 1: | |||
continue | |||
for i, op in enumerate(ops): | |||
@@ -57,7 +74,4 @@ class AutoNaming: | |||
continue | |||
for var in op.outputs: | |||
var.name = var.name.replace(key, op.name) | |||
self.name2ops.clear() | |||
auto_naming = AutoNaming() | |||
cls.name2ops.clear() |
@@ -28,7 +28,7 @@ from megengine.functional import exp, log | |||
from megengine.jit import exclude_from_trace, trace | |||
from megengine.module import Module | |||
from megengine.random import normal, uniform | |||
from megengine.utils.naming import auto_naming | |||
from megengine.utils.naming import AutoNaming | |||
@pytest.mark.parametrize("trace_mode", [False, True]) | |||
@@ -141,7 +141,7 @@ def test_dump(): | |||
return a + b | |||
# prevent from remaining scope from exception test | |||
auto_naming.clear() | |||
AutoNaming.clear() | |||
a = tensor([2]) | |||
b = tensor([4]) | |||
y = f(a, b).numpy() | |||
@@ -18,11 +18,11 @@ from megengine import Parameter, Tensor | |||
from megengine.core.tensor import megbrain_graph as G | |||
from megengine.jit.tracing import trace | |||
from megengine.quantization.quantize import quantize, quantize_qat | |||
from megengine.utils.naming import auto_naming | |||
from megengine.utils.naming import AutoNaming | |||
def _dump_and_load(func, symbolic, keep_opr_name=True): | |||
auto_naming.clear() | |||
AutoNaming.clear() | |||
func = trace(func, symbolic=symbolic, capture_as_const=True) | |||
x = Tensor(np.ones(shape=(2, 3))) | |||
func(x).numpy() | |||
@@ -104,6 +104,18 @@ def test_without_module(symbolic): | |||
@pytest.mark.parametrize("symbolic", [False, True]) | |||
def test_ignore_top_module(symbolic): | |||
class Simple(M.Module): | |||
def forward(self, x): | |||
return x + x | |||
m = Simple() | |||
op = _dump_and_load(m, symbolic)[-1] | |||
assert op.name == "ADD" | |||
assert op.outputs[0].name == "ADD" | |||
@pytest.mark.parametrize("symbolic", [False, True]) | |||
def test_with_submodule(symbolic): | |||
class Simple(M.Module): | |||
def __init__(self, name): | |||
@@ -196,7 +208,7 @@ def test_not_keep_opr_name(): | |||
return 2 * x | |||
op = _dump_and_load(f, True, False)[-1] | |||
assert op.name == "MUL(x,2[2])[4]" | |||
assert op.name == "MUL(x,const<2>[2])[4]" | |||
@pytest.mark.parametrize("symbolic", [False, True]) | |||
@@ -419,7 +419,7 @@ void ImmutableTensor::Value::setup(CompNode cn, const HostTensorND &val) { | |||
if (one_elem(val.shape())) { | |||
float v; | |||
static_cast_dtype(&v, val.dtype(), val.raw_ptr()); | |||
m_summary = ssprintf("%.3g", v); | |||
m_summary = ssprintf("const<%.3g>", v); | |||
if (val.shape().ndim != 1) { | |||
m_summary += val.shape().to_string(); | |||
} | |||