Browse Source

fix(mge): fix some minor problems

GitOrigin-RevId: 43abda2ab9
release-1.7
Megvii Engine Team 3 years ago
parent
commit
f2f3356572
6 changed files with 22 additions and 36 deletions
  1. +2
    -0
      imperative/python/megengine/logger.py
  2. +1
    -1
      imperative/python/megengine/module/conv_bn.py
  3. +4
    -8
      imperative/python/megengine/module/module.py
  4. +3
    -26
      imperative/python/megengine/quantization/utils.py
  5. +0
    -1
      imperative/python/megengine/traced_module/traced_module.py
  6. +12
    -0
      imperative/python/megengine/traced_module/utils.py

+ 2
- 0
imperative/python/megengine/logger.py View File

@@ -158,10 +158,12 @@ def set_log_level(level, update_existing=True):
update_existing: whether to update existing loggers
"""
global _default_level # pylint: disable=global-statement
origin_level = _default_level
_default_level = level
if update_existing:
for i in _all_loggers:
i.setLevel(level)
return origin_level


_logger = get_logger(__name__)


+ 1
- 1
imperative/python/megengine/module/conv_bn.py View File

@@ -32,7 +32,7 @@ class _ConvBnActivation2d(Module):
track_running_stats=True,
**kwargs
):
super().__init__()
super().__init__(**kwargs)
self.conv = Conv2d(
in_channels,
out_channels,


+ 4
- 8
imperative/python/megengine/module/module.py View File

@@ -49,17 +49,13 @@ def _access_structure(obj, key, callback=None):
parent = None
for k in key_list:
parent = cur
if isinstance(cur, (Tensor, Module)):
cur = getattr(cur, k)
elif isinstance(cur, (list, tuple)):
if isinstance(cur, (list, tuple)):
k = int(k)
cur = cur[k]
elif isinstance(cur, dict):
cur = cur[k]
else:
raise ValueError(
"Unsupport value type {} to access attribute".format(type(cur))
)
cur = getattr(cur, k)
return callback(parent, k, cur)


@@ -650,8 +646,8 @@ class Module(metaclass=ABCMeta):
v._name = k
elif v._name != k:
logger.warning(
"try setting the submodule `{}` to a new attribute `{}`, its name `{}` will remain unchanged".format(
v._name, k, v._name
"try setting the submodule `{}` to `{}`'s new attribute `{}`, its name `{}` will remain unchanged".format(
type(v), type(self), k, v._name
)
)
super().__setattr__(name, value)


+ 3
- 26
imperative/python/megengine/quantization/utils.py View File

@@ -111,10 +111,8 @@ class QParams:
return "QParams({})".format(content)


class LSQParams:
r"""To standardize LSQ's qparams format. If custom
qparams is needed, inherit this class and add custom ``__slots__``.
"""
class LSQParams(QParams):
r"""LSQ qparams with extra grad_scale slot."""

__slots__ = "mode", "dtype_meta", "scale", "zero_point", "grad_scale"

@@ -126,30 +124,9 @@ class LSQParams:
zero_point: Tensor,
grad_scale: Tensor,
):
self.mode = mode
self.dtype_meta = dtype_meta
self.scale = scale
self.zero_point = zero_point
super().__init__(mode, dtype_meta, scale, zero_point)
self.grad_scale = grad_scale

def update(self, lsqparams: "LSQParams"):
for key in self.__slots__:
setattr(self, key, getattr(lsqparams, key))

def __eq__(self, other):
if len(self.__slots__) != len(other.__slots__):
return False
for key in self.__slots__:
if not hasattr(other, key) or getattr(self, key) != getattr(other, key):
return False
return True

def __repr__(self):
content = ", ".join(
["{}={}".format(key, getattr(self, key)) for key in self.__slots__]
)
return "LSQParams({})".format(content)


class QParamsModuleMixin(abc.ABC):
def get_quantized_dtype(self):


+ 0
- 1
imperative/python/megengine/traced_module/traced_module.py View File

@@ -642,7 +642,6 @@ class InternalGraph:
Returns:
A :class:`~.TracedModule.NodeFilterType`.
"""
assert issubclass(module_cls, Module)
return self.nodes(recursive).type(module_cls)

def get_node_by_id(self, node_id: List[int] = None, recursive=True):


+ 12
- 0
imperative/python/megengine/traced_module/utils.py View File

@@ -96,6 +96,12 @@ class _ModuleList(Module, MutableSequence):
raise IndexError("list index out of range")
return rst if len(rst) > 1 else rst[0]

def __setattr__(self, key, value):
# clear mod name to avoid warning in Module's setattr
if isinstance(value, Module):
value._name = None
super().__setattr__(key, value)

def __setitem__(self, idx: int, mod: Module):
if not isinstance(mod, Module):
raise ValueError("invalid sub-module")
@@ -159,6 +165,12 @@ class _ModuleDict(Module, MutableMapping):
def __getitem__(self, key):
return getattr(self, key)

def __setattr__(self, key, value):
# clear mod name to avoid warning in Module's setattr
if isinstance(value, Module):
value._name = None
super().__setattr__(key, value)

def __setitem__(self, key, value):
if not isinstance(value, Module):
raise ValueError("invalid sub-module")


Loading…
Cancel
Save