@@ -158,10 +158,12 @@ def set_log_level(level, update_existing=True): | |||||
update_existing: whether to update existing loggers | update_existing: whether to update existing loggers | ||||
""" | """ | ||||
global _default_level # pylint: disable=global-statement | global _default_level # pylint: disable=global-statement | ||||
origin_level = _default_level | |||||
_default_level = level | _default_level = level | ||||
if update_existing: | if update_existing: | ||||
for i in _all_loggers: | for i in _all_loggers: | ||||
i.setLevel(level) | i.setLevel(level) | ||||
return origin_level | |||||
_logger = get_logger(__name__) | _logger = get_logger(__name__) | ||||
@@ -32,7 +32,7 @@ class _ConvBnActivation2d(Module): | |||||
track_running_stats=True, | track_running_stats=True, | ||||
**kwargs | **kwargs | ||||
): | ): | ||||
super().__init__() | |||||
super().__init__(**kwargs) | |||||
self.conv = Conv2d( | self.conv = Conv2d( | ||||
in_channels, | in_channels, | ||||
out_channels, | out_channels, | ||||
@@ -49,17 +49,13 @@ def _access_structure(obj, key, callback=None): | |||||
parent = None | parent = None | ||||
for k in key_list: | for k in key_list: | ||||
parent = cur | parent = cur | ||||
if isinstance(cur, (Tensor, Module)): | |||||
cur = getattr(cur, k) | |||||
elif isinstance(cur, (list, tuple)): | |||||
if isinstance(cur, (list, tuple)): | |||||
k = int(k) | k = int(k) | ||||
cur = cur[k] | cur = cur[k] | ||||
elif isinstance(cur, dict): | elif isinstance(cur, dict): | ||||
cur = cur[k] | cur = cur[k] | ||||
else: | else: | ||||
raise ValueError( | |||||
"Unsupport value type {} to access attribute".format(type(cur)) | |||||
) | |||||
cur = getattr(cur, k) | |||||
return callback(parent, k, cur) | return callback(parent, k, cur) | ||||
@@ -650,8 +646,8 @@ class Module(metaclass=ABCMeta): | |||||
v._name = k | v._name = k | ||||
elif v._name != k: | elif v._name != k: | ||||
logger.warning( | logger.warning( | ||||
"try setting the submodule `{}` to a new attribute `{}`, its name `{}` will remain unchanged".format( | |||||
v._name, k, v._name | |||||
"try setting the submodule `{}` to `{}`'s new attribute `{}`, its name `{}` will remain unchanged".format( | |||||
type(v), type(self), k, v._name | |||||
) | ) | ||||
) | ) | ||||
super().__setattr__(name, value) | super().__setattr__(name, value) | ||||
@@ -111,10 +111,8 @@ class QParams: | |||||
return "QParams({})".format(content) | return "QParams({})".format(content) | ||||
class LSQParams: | |||||
r"""To standardize LSQ's qparams format. If custom | |||||
qparams is needed, inherit this class and add custom ``__slots__``. | |||||
""" | |||||
class LSQParams(QParams): | |||||
r"""LSQ qparams with extra grad_scale slot.""" | |||||
__slots__ = "mode", "dtype_meta", "scale", "zero_point", "grad_scale" | __slots__ = "mode", "dtype_meta", "scale", "zero_point", "grad_scale" | ||||
@@ -126,30 +124,9 @@ class LSQParams: | |||||
zero_point: Tensor, | zero_point: Tensor, | ||||
grad_scale: Tensor, | grad_scale: Tensor, | ||||
): | ): | ||||
self.mode = mode | |||||
self.dtype_meta = dtype_meta | |||||
self.scale = scale | |||||
self.zero_point = zero_point | |||||
super().__init__(mode, dtype_meta, scale, zero_point) | |||||
self.grad_scale = grad_scale | self.grad_scale = grad_scale | ||||
def update(self, lsqparams: "LSQParams"): | |||||
for key in self.__slots__: | |||||
setattr(self, key, getattr(lsqparams, key)) | |||||
def __eq__(self, other): | |||||
if len(self.__slots__) != len(other.__slots__): | |||||
return False | |||||
for key in self.__slots__: | |||||
if not hasattr(other, key) or getattr(self, key) != getattr(other, key): | |||||
return False | |||||
return True | |||||
def __repr__(self): | |||||
content = ", ".join( | |||||
["{}={}".format(key, getattr(self, key)) for key in self.__slots__] | |||||
) | |||||
return "LSQParams({})".format(content) | |||||
class QParamsModuleMixin(abc.ABC): | class QParamsModuleMixin(abc.ABC): | ||||
def get_quantized_dtype(self): | def get_quantized_dtype(self): | ||||
@@ -642,7 +642,6 @@ class InternalGraph: | |||||
Returns: | Returns: | ||||
A :class:`~.TracedModule.NodeFilterType`. | A :class:`~.TracedModule.NodeFilterType`. | ||||
""" | """ | ||||
assert issubclass(module_cls, Module) | |||||
return self.nodes(recursive).type(module_cls) | return self.nodes(recursive).type(module_cls) | ||||
def get_node_by_id(self, node_id: List[int] = None, recursive=True): | def get_node_by_id(self, node_id: List[int] = None, recursive=True): | ||||
@@ -96,6 +96,12 @@ class _ModuleList(Module, MutableSequence): | |||||
raise IndexError("list index out of range") | raise IndexError("list index out of range") | ||||
return rst if len(rst) > 1 else rst[0] | return rst if len(rst) > 1 else rst[0] | ||||
def __setattr__(self, key, value): | |||||
# clear mod name to avoid warning in Module's setattr | |||||
if isinstance(value, Module): | |||||
value._name = None | |||||
super().__setattr__(key, value) | |||||
def __setitem__(self, idx: int, mod: Module): | def __setitem__(self, idx: int, mod: Module): | ||||
if not isinstance(mod, Module): | if not isinstance(mod, Module): | ||||
raise ValueError("invalid sub-module") | raise ValueError("invalid sub-module") | ||||
@@ -159,6 +165,12 @@ class _ModuleDict(Module, MutableMapping): | |||||
def __getitem__(self, key): | def __getitem__(self, key): | ||||
return getattr(self, key) | return getattr(self, key) | ||||
def __setattr__(self, key, value): | |||||
# clear mod name to avoid warning in Module's setattr | |||||
if isinstance(value, Module): | |||||
value._name = None | |||||
super().__setattr__(key, value) | |||||
def __setitem__(self, key, value): | def __setitem__(self, key, value): | ||||
if not isinstance(value, Module): | if not isinstance(value, Module): | ||||
raise ValueError("invalid sub-module") | raise ValueError("invalid sub-module") | ||||