module.py

from abc import ABCMeta, abstractmethod
from collections import OrderedDict
from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union

import numpy as np

from ..core.tensor.utils import make_shape_tuple
from ..logger import get_logger
from ..tensor import Parameter, Tensor
from ..utils.deprecation import deprecated
from ..utils.hook import HookHandler
from ..utils.naming import AutoNaming

logger = get_logger(__name__)


def _expand_structure(prefix, obj):
    # Flattens nested containers (list/tuple/dict) into a list of
    # ("dotted.key", leaf) pairs, keeping only Tensor and Module leaves.
    if isinstance(obj, (Tensor, Module)):
        return [(prefix, obj)]
    elif isinstance(obj, (list, tuple, dict)):
        ret = []
        if isinstance(obj, dict):
            targets = ((k, obj[k]) for k in sorted(obj))
        else:
            targets = ((str(k), v) for k, v in enumerate(obj))
        for k, o in targets:
            sub_ret = _expand_structure(k, o)
            if sub_ret and not isinstance(k, str):
                raise AssertionError(
                    "keys for Tensor and Module must be str, error key: {}".format(k)
                )
            for kt, vt in sub_ret:
                ret.extend([(prefix + "." + kt, vt)])
        return ret
    else:
        return []


def _access_structure(obj, key, callback=None):
    # Walks ``obj`` along the dotted ``key`` (indexing lists/tuples/dicts,
    # falling back to attribute access), then invokes ``callback(parent, k, cur)``.
    key_list = key.split(".")
    cur = obj
    parent = None
    for k in key_list:
        parent = cur
        if isinstance(cur, (list, tuple)):
            k = int(k)
            cur = cur[k]
        elif isinstance(cur, dict):
            cur = cur[k]
        else:
            cur = getattr(cur, k)
    return callback(parent, k, cur)


def _is_parameter(obj):
    return isinstance(obj, Parameter)


def _is_tensor(obj):
    return isinstance(obj, Tensor)


def _is_buffer(obj):
    return isinstance(obj, Tensor) and not isinstance(obj, Parameter)


def _is_module(obj):
    return isinstance(obj, Module)


def _get_XNorm_typeclass():
    # Imported lazily to avoid a circular import at module load time.
    from .batchnorm import _BatchNorm
    from .normalization import GroupNorm, InstanceNorm, LayerNorm

    XNorm_types = (_BatchNorm, GroupNorm, LayerNorm, InstanceNorm)
    return XNorm_types


class Module(metaclass=ABCMeta):
    r"""Base Module class.

    Args:
        name: module's name, can be initialized by the ``kwargs`` parameter
            of child class.
    """

    def __init__(self, name=None):
        self._modules = []

        if name is not None:
            assert (
                isinstance(name, str) and name.strip()
            ), "Module's name must be a non-empty string"
        self.name = name

        # runtime attributes
        self.training = True
        self.quantize_disabled = False

        # hooks
        self._forward_pre_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()

        # used for profiler and automatic naming
        self._name = None
        self._short_name = None

    @abstractmethod
    def forward(self, inputs):
        pass

    def register_forward_pre_hook(self, hook: Callable) -> HookHandler:
        """Registers a hook to handle forward inputs. ``hook`` should be a function.

        Args:
            hook: a function that receives ``module`` and ``inputs``, then returns
                a modified ``inputs`` or ``None``.

        Returns:
            a handler with :meth:`~.HookHandler.remove` interface to delete the hook.
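
        For example, a hook that records the shape of every positional input
        before ``forward`` runs (an illustrative sketch; ``m`` stands for any
        existing ``Module`` instance and is not defined in this file):

        .. code-block::

            shapes = []

            def record_shapes(module, inputs):
                shapes.append([x.shape for x in inputs])
                return None  # returning None keeps the inputs unchanged

            handler = m.register_forward_pre_hook(record_shapes)
            # ... run m(...) a few times ...
            handler.remove()  # delete the hook once it is no longer needed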
        """
        return HookHandler(self._forward_pre_hooks, hook)

    def register_forward_hook(self, hook: Callable) -> HookHandler:
        """Registers a hook to handle forward results. ``hook`` should be a function
        that receives ``module``, ``inputs`` and ``outputs``, then returns a modified
        ``outputs`` or ``None``.

        This method returns a handler with :meth:`~.HookHandler.remove` interface
        to delete the hook.
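
        For example, a hook that doubles the module's output (an illustrative
        sketch; ``m`` is an existing ``Module`` instance):

        .. code-block::

            def scale_outputs(module, inputs, outputs):
                return outputs * 2  # the returned value replaces ``outputs``

            handler = m.register_forward_hook(scale_outputs)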
        """
        return HookHandler(self._forward_hooks, hook)

    def __call__(self, *inputs, **kwargs):
        AutoNaming.push_scope(self.name if self.name is not None else self._short_name)
        for hook in self._forward_pre_hooks.values():
            modified_inputs = hook(self, inputs)
            if modified_inputs is not None:
                if not isinstance(modified_inputs, tuple):
                    modified_inputs = (modified_inputs,)
                inputs = modified_inputs
        outputs = self.forward(*inputs, **kwargs)
        for hook in self._forward_hooks.values():
            modified_outputs = hook(self, inputs, outputs)
            if modified_outputs is not None:
                outputs = modified_outputs
        AutoNaming.pop_scope()
        return outputs

    def _flatten(
        self,
        *,
        recursive: bool = True,
        with_key: bool = False,
        with_parent: bool = False,
        prefix: Optional[str] = None,
        predicate: Callable[[Any], bool] = lambda _: True,
        seen: Optional[Set[int]] = None
    ) -> Union[Iterable[Any], Iterable[Tuple[str, Any]]]:
        """Scans the module object and returns an iterable for the :class:`~.Tensor`
        and :class:`~.Module` attributes that agree with the ``predicate``. For multiple
        calls of this function with same arguments, the order of objects within the
        returned iterable is guaranteed to be identical, as long as all the involved
        module objects' ``__dict__`` does not change throughout those calls.

        Args:
            recursive: whether to recursively scan all the submodules.
            with_key: whether to yield keys along with yielded objects.
            with_parent: whether to yield ``self`` along with yielded objects.
            prefix: prefix appended to the yielded keys.
            predicate: the predicate function applied to scanned objects.
            seen: a set of ids of objects that have already been traversed,
                used to avoid yielding the same object twice.
        """
        if seen is None:
            seen = set([id(self)])
        module_dict = vars(self)
        _prefix = "" if prefix is None else prefix + "."
        for key in sorted(module_dict):
            for expanded_key, leaf in _expand_structure(key, module_dict[key]):
                leaf_id = id(leaf)
                if leaf_id in seen:
                    continue
                seen.add(leaf_id)
                if predicate(leaf):
                    if with_key and with_parent:
                        yield _prefix + expanded_key, leaf, self
                    elif with_key:
                        yield _prefix + expanded_key, leaf
                    elif with_parent:
                        yield leaf, self
                    else:
                        yield leaf
                if recursive and isinstance(leaf, Module):
                    yield from leaf._flatten(
                        recursive=recursive,
                        with_key=with_key,
                        with_parent=with_parent,
                        prefix=_prefix + expanded_key if with_key else None,
                        predicate=predicate,
                        seen=seen,
                    )

    def parameters(self, recursive: bool = True, **kwargs) -> Iterable[Parameter]:
        r"""Returns an iterable for the :class:`~.Parameter` of the module.

        Args:
            recursive: If ``True``, returns all :class:`~.Parameter` within this
                module, else only returns :class:`~.Parameter` that are direct
                attributes of this module.
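
        For example, counting the number of trainable elements (an illustrative
        sketch; ``m`` is an existing ``Module`` instance):

        .. code-block::

            total = sum(p.numpy().size for p in m.parameters())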
        """
        if "requires_grad" in kwargs:
            del kwargs["requires_grad"]
            logger.warning(
                "Tensor currently has no requires_grad attribute "
                "so requires_grad argument is ignored here"
            )

        def predicate(obj) -> bool:
            return _is_parameter(obj)

        yield from self._flatten(
            with_key=False, predicate=predicate, recursive=recursive, **kwargs
        )

    def named_parameters(
        self, prefix: Optional[str] = None, recursive: bool = True, **kwargs
    ) -> Iterable[Tuple[str, Parameter]]:
        r"""Returns an iterable for key :class:`~.Parameter` pairs of the module, where
        ``key`` is the dotted path from this module to the :class:`~.Parameter`.

        Args:
            prefix: prefix prepended to the keys.
            recursive: if ``True``, returns all :class:`~.Parameter` within this
                module, else only returns :class:`~.Parameter` that are direct
                attributes of this module.
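
        For example, printing each parameter's dotted path and shape
        (illustrative; ``m`` is an existing ``Module`` instance):

        .. code-block::

            for name, p in m.named_parameters():
                print(name, p.shape)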
        """
        if "requires_grad" in kwargs:
            del kwargs["requires_grad"]
            logger.warning(
                "Tensor currently has no requires_grad attribute "
                "so requires_grad argument is ignored here"
            )

        def predicate(obj) -> bool:
            return _is_parameter(obj)

        yield from self._flatten(
            with_key=True,
            prefix=prefix,
            predicate=predicate,
            recursive=recursive,
            **kwargs,
        )

    def buffers(self, recursive: bool = True, **kwargs) -> Iterable[Tensor]:
        r"""Returns an iterable for the buffers of the module.

        Buffer is defined to be :class:`~.Tensor` excluding :class:`~.Parameter`.

        Args:
            recursive: if ``True``, returns all buffers within this
                module, else only returns buffers that are direct attributes
                of this module.
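
        For example, listing the shape of every buffer (illustrative; ``m`` is
        an existing ``Module`` instance):

        .. code-block::

            for buf in m.buffers():
                print(buf.shape)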
        """
        yield from self._flatten(
            with_key=False, predicate=_is_buffer, recursive=recursive, **kwargs
        )

    def named_buffers(
        self, prefix: Optional[str] = None, recursive: bool = True, **kwargs
    ) -> Iterable[Tuple[str, Tensor]]:
        r"""Returns an iterable for key buffer pairs of the module, where
        ``key`` is the dotted path from this module to the buffer.

        Buffer is defined to be :class:`~.Tensor` excluding :class:`~.Parameter`.

        Args:
            prefix: prefix prepended to the keys.
            recursive: if ``True``, returns all buffers within this
                module, else only returns buffers that are direct attributes
                of this module.
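
        For example, collecting buffers into a plain dict keyed by dotted path
        (illustrative; ``m`` is an existing ``Module`` instance):

        .. code-block::

            buffer_map = dict(m.named_buffers())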
        """
        yield from self._flatten(
            with_key=True,
            prefix=prefix,
            predicate=_is_buffer,
            recursive=recursive,
            **kwargs,
        )

    def tensors(self, recursive: bool = True, **kwargs) -> Iterable[Tensor]:
        r"""Returns an iterable for the :class:`~.Tensor` of the module.

        Args:
            recursive: If ``True``, returns all :class:`~.Tensor` within this
                module, else only returns :class:`~.Tensor` that are direct
                attributes of this module.
        """
        yield from self._flatten(
            with_key=False, predicate=_is_tensor, recursive=recursive, **kwargs
        )

    def named_tensors(
        self, prefix: Optional[str] = None, recursive: bool = True, **kwargs
    ) -> Iterable[Tuple[str, Tensor]]:
        r"""Returns an iterable for key tensor pairs of the module, where
        ``key`` is the dotted path from this module to the tensor.

        Args:
            prefix: prefix prepended to the keys.
            recursive: if ``True``, returns all tensors within this
                module, else only returns tensors that are direct attributes
                of this module.
        """
        yield from self._flatten(
            with_key=True,
            prefix=prefix,
            predicate=_is_tensor,
            recursive=recursive,
            **kwargs,
        )

    def children(self, **kwargs) -> "Iterable[Module]":
        r"""Returns an iterable for all the submodules that are direct attributes of this
        module.
        """
        yield from self._flatten(
            with_key=False, predicate=_is_module, recursive=False, **kwargs
        )

    def named_children(self, **kwargs) -> "Iterable[Tuple[str, Module]]":
        r"""Returns an iterable of key-submodule pairs for all the submodules that are
        direct attributes of this module, where ``key`` is the attribute name of
        submodules.
        """
        yield from self._flatten(
            with_key=True, predicate=_is_module, recursive=False, **kwargs
        )

    def modules(self, **kwargs) -> "Iterable[Module]":
        r"""Returns an iterable for all the modules within this module, including itself."""
        if "with_parent" in kwargs and kwargs["with_parent"]:
            yield self, None
        else:
            yield self
        yield from self._flatten(with_key=False, predicate=_is_module, **kwargs)

    def named_modules(
        self, prefix: Optional[str] = None, **kwargs
    ) -> "Iterable[Tuple[str, Module]]":
        r"""Returns an iterable of key-module pairs for all the modules within this
        module, including itself, where ``key`` is the dotted path from this module to
        the submodules.

        Args:
            prefix: prefix prepended to the path.
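
        For example, printing the dotted path and class of every module in the
        tree (illustrative; ``m`` is an existing ``Module`` instance):

        .. code-block::

            for name, mod in m.named_modules():
                print(name or "<root>", type(mod).__name__)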
        """
        if "with_parent" in kwargs and kwargs["with_parent"]:
            yield ("" if prefix is None else prefix), self, None
        else:
            yield ("" if prefix is None else prefix), self
        yield from self._flatten(
            with_key=True, prefix=prefix, predicate=_is_module, **kwargs
        )

    def apply(self, fn: "Callable[[Module], Any]") -> None:
        r"""Applies function ``fn`` to all the modules within this module, including
        itself.

        Args:
            fn: the function to be applied on modules.
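
        For example, collecting the class name of every module in the tree
        (illustrative; ``m`` is an existing ``Module`` instance):

        .. code-block::

            names = []
            m.apply(lambda mod: names.append(type(mod).__name__))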
        """
        for it in self.modules():
            fn(it)

    @deprecated(version="1.0")
    def zero_grad(self) -> None:
        r"""Sets all parameters' grads to zero."""
        for param in self.parameters():
            if param.grad is not None:
                param.grad.reset_zero()

    def train(self, mode: bool = True, recursive: bool = True) -> None:
        r"""Sets training mode of all the modules within this module (including itself) to
        ``mode``. This effectively sets the ``training`` attributes of those modules
        to ``mode``, but only has an effect on certain modules (e.g.
        :class:`~.BatchNorm2d`, :class:`~.Dropout`, :class:`~.Observer`).

        Args:
            mode: the training mode to be set on modules.
            recursive: whether to recursively call submodules' ``train()``.
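
        For example (illustrative; ``m`` is an existing ``Module`` instance):

        .. code-block::

            m.train()   # enable training behavior, e.g. BatchNorm statistics update
            m.eval()    # equivalent to m.train(False)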
        """
        if not recursive:
            self.training = mode
            return

        def fn(module: Module) -> None:
            module.train(mode, recursive=False)

        self.apply(fn)

    def eval(self) -> None:
        r"""Sets training mode of all the modules within this module (including itself) to
        ``False``. See :meth:`~.Module.train` for details.
        """
        self.train(False)

    def disable_quantize(self, value=True):
        r"""Sets ``module``'s ``quantize_disabled`` attribute and returns ``module``.
        Could be used as a decorator.
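
        For example (illustrative; ``m.fc`` denotes a hypothetical submodule):

        .. code-block::

            m.fc.disable_quantize()  # skip quantization for this submodule only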
        """

        def fn(module: Module) -> None:
            module.quantize_disabled = value

        self.apply(fn)

    @deprecated(version="1.0")
    def replace_param(
        self, params: dict, start_pos: int, seen: Optional[Set[int]] = None
    ):
        r"""Replaces module's parameters with ``params``, used by :class:`~.ParamPack` to
        speed up multi-machine training.
        """
        offset = 0
        if seen is None:
            seen = set([id(self)])
        module_dict = vars(self)
        for key in sorted(module_dict):
            hash_id = id(module_dict[key])
            if hash_id in seen:
                continue
            seen.add(hash_id)
            if isinstance(module_dict[key], Parameter):
                if start_pos + offset in params:
                    assert make_shape_tuple(module_dict[key].shape) == make_shape_tuple(
                        params[start_pos + offset].shape
                    )
                    module_dict[key] = params[start_pos + offset]
                offset += 1
            if isinstance(module_dict[key], Module):
                offset += module_dict[key].replace_param(
                    params, start_pos + offset, seen
                )
        return offset

    def state_dict(self, rst=None, prefix="", keep_var=False):
        r"""Returns a dictionary containing whole states of the module.
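
        For example, saving and restoring a module's states (an illustrative
        sketch; ``m`` is an existing ``Module`` instance):

        .. code-block::

            import megengine as mge

            mge.save(m.state_dict(), "checkpoint.pkl")
            m.load_state_dict(mge.load("checkpoint.pkl"))
        """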
        _rst = self._state_dict(rst=rst, prefix=prefix, keep_var=keep_var)
        rst = OrderedDict()
        XNorm_typeclass = _get_XNorm_typeclass()
        for (module_type, k), v in _rst.items():
            # for performance reasons, parameters in XNorm (e.g., BatchNorm2d) are
            # 4-dim tensors; however, they will be reshaped to 1-dim tensors before
            # being returned by `state_dict()`
            if issubclass(module_type, XNorm_typeclass):
                v = v.reshape(-1)
            rst[k] = v
        return rst

    def _state_dict(self, rst=None, prefix="", keep_var=False):
        r"""Returns a dictionary containing whole states of the module."""

        def is_state(obj):
            return _is_parameter(obj) or _is_buffer(obj)

        module_type = self.__class__
        if rst is None:
            rst = OrderedDict()
        for k, v in self._flatten(recursive=False, with_key=True, predicate=is_state):
            assert prefix + k not in rst, "duplicated state: {}".format(k)
            if keep_var:
                rst[(module_type, prefix + k)] = v
            else:
                rst[(module_type, prefix + k)] = v.numpy()
        for k, submodule in self._flatten(
            recursive=False,
            with_key=True,
            predicate=lambda obj: isinstance(obj, Module),
        ):
            submodule.state_dict(rst, prefix + k + ".", keep_var)
        return rst

    def load_state_dict(
        self,
        state_dict: Union[dict, Callable[[str, Tensor], Optional[np.ndarray]]],
        strict=True,
    ):
        r"""Loads a given dictionary created by :func:`state_dict` into this module.
        If ``strict`` is ``True``, the keys of :func:`state_dict` must exactly match the keys
        returned by :func:`state_dict`.

        Users can also pass a closure: ``Function[key: str, var: Tensor] -> Optional[np.ndarray]``
        as a ``state_dict``, in order to handle complex situations. For example, load everything
        except for the final linear classifier:

        .. code-block::

            state_dict = {...}  # Dict[str, np.ndarray]
            model.load_state_dict({
                k: None if k.startswith('fc') else v
                for k, v in state_dict.items()
            }, strict=False)

        Here returning ``None`` means skipping parameter ``k``.

        To prevent shape mismatch (e.g. when loading PyTorch weights), we can reshape before
        loading:

        .. code-block::

            state_dict = {...}
            def reshape_accordingly(k, v):
                return state_dict[k].reshape(v.shape)
            model.load_state_dict(reshape_accordingly)

        We can also perform in-place re-initialization or pruning:

        .. code-block::

            def reinit_and_pruning(k, v):
                if 'bias' in k:
                    M.init.zero_(v)  # in-place re-initialization
                if 'conv' in k:
                    # an illustrative magnitude-based pruning mask
                    return v.numpy() * (np.abs(v.numpy()) > 1e-3)
            model.load_state_dict(reinit_and_pruning, strict=False)
        """
        unused = []
        if isinstance(state_dict, dict):
            unused = state_dict.keys()

            def closure(k, _):  # var unused
                return state_dict[k] if k in state_dict else None

        elif callable(state_dict):
            closure = state_dict
        else:
            raise ValueError(
                "`state_dict` must be a dict or callable, got {}".format(
                    type(state_dict)
                )
            )

        loaded, skipped = self._load_state_dict_with_closure(closure)
        unused = set(unused) - loaded

        if len(unused) != 0:
            if strict:
                raise KeyError(
                    "Unused params violate `strict=True`, unused={}".format(unused)
                )
            else:
                logger.warning(
                    "Unused params in `strict=False` mode, unused={}".format(unused)
                )

        if len(skipped) != 0:
            if strict:
                raise KeyError(
                    "Missing params violate `strict=True`, missing={}".format(skipped)
                )
            else:
                logger.warning(
                    "Missing params in `strict=False` mode, missing={}".format(skipped)
                )

    def _load_state_dict_with_closure(self, closure):
        r"""Advanced ``state_dict`` loading through a callable ``closure`` whose
        signature is ``closure(key: str, var: Tensor) -> Union[np.ndarray, None]``.
        """
        XNorm_typeclass = _get_XNorm_typeclass()
        assert callable(closure), "closure must be a function"

        loaded = []
        skipped = []

        local_state_dict = self._state_dict(keep_var=True)
        for (module_type, k), var in local_state_dict.items():
            to_be_load = closure(k, var)
            if to_be_load is None:
                skipped.append(k)
                continue
            assert isinstance(
                to_be_load, np.ndarray
            ), "closure should return a `np.ndarray`, now `{}` got {}".format(
                k, to_be_load
            )
            var_shape = make_shape_tuple(var.shape)
            to_be_load_shape = make_shape_tuple(to_be_load.shape)
            if var_shape != to_be_load_shape:
                # weight and bias in BatchNorm1d, BatchNorm2d and SyncBatchNorm are
                # 1-dim tensors in v1.0; since v1.1 they are 4-dim tensors. The
                # following special rule for these modules preserves backward
                # compatibility.
                if issubclass(module_type, XNorm_typeclass):
                    if np.prod(var_shape) == np.prod(to_be_load_shape):
                        to_be_load = to_be_load.reshape(var_shape)
                    else:
                        raise ValueError(
                            "param `{}` size mismatch, should be {}, got {}".format(
                                k, np.prod(var_shape), np.prod(to_be_load_shape)
                            )
                        )
                else:
                    raise ValueError(
                        "param `{}` shape mismatch, should be {}, got {}".format(
                            k, var_shape, to_be_load_shape
                        )
                    )
            var._reset(
                type(var)(
                    to_be_load, dtype=to_be_load.dtype, device=var.device, no_cache=True
                )
            )
            loaded.append(k)

        return set(loaded), set(skipped)

    def __setattr__(self, name: str, value):
        is_module_like = _is_module(value) or isinstance(value, (list, tuple, dict))
        if name != "_modules":
            modules = self.__dict__.get("_modules")
            if modules is None and is_module_like:
                raise AttributeError(
                    "cannot assign module before Module.__init__() call"
                )
            if is_module_like:
                if name not in modules:
                    modules.append(name)
            else:
                if modules is not None and name in modules:
                    modules.remove(name)

        def append_name(prefix, name):
            if prefix is None or prefix == "":
                return name
            return prefix + "." + name

        def set_name(parent, prefix, name, obj):
            if isinstance(obj, Tensor):
                assert obj.name is not None
                if obj.name != "":
                    name = obj.name
            full_name = append_name(prefix, name)
            if obj._short_name and obj._short_name != name:
                logger.warning(
                    "try setting the submodule `{}` to `{}`'s new attribute `{}`, "
                    "its name `{}` will remain unchanged".format(
                        obj._short_name, type(parent), name, obj._short_name
                    )
                )
                return
            if isinstance(obj, Tensor):
                obj._prefix = prefix
                obj._name = full_name
                obj._short_name = name
                obj._set_name(obj._name)
                return obj._name
            elif isinstance(obj, Module):
                obj._name = full_name
                obj._short_name = name
                for k, v in obj._flatten(recursive=False, with_key=True):
                    set_name(obj, full_name, k, v)
                return obj._name
            else:
                assert False

        for k, v in _expand_structure(name, value):
            prefix = self._name if self._name else self.name
            set_name(self, prefix, k, v)
        super().__setattr__(name, value)

    def __setstate__(self, state):
        if "_short_name" not in state:
            state["_short_name"] = state["_name"]
            state["_name"] = None
        self.__dict__.update(state)

    def __delattr__(self, name: str):
        if name in self.__dict__ and _is_module(self.__dict__[name]):
            modules = self.__dict__.get("_modules")
            if name in modules:
                modules.remove(name)
        super().__delattr__(name)

    def _module_info_string(self) -> str:
        r"""Sets the extra representation of the module."""
        return ""

    def __repr__(self):
        def add_indent(repr_str, num_spaces):
            s = repr_str.split("\n")
            # don't do anything for single-line stuff
            if len(s) == 1:
                return repr_str
            first = s.pop(0)
            s = [(num_spaces * " ") + line for line in s]
            s = "\n".join(s)
            s = first + "\n" + s
            return s

        extra_lines = []
        extra_repr = self._module_info_string()
        if extra_repr:
            extra_lines = extra_repr.split("\n")
        child_lines = []
        for name in self._modules:
            if _is_module(self.__dict__[name]):
                child_lines.append(
                    "(" + name + "): " + add_indent(repr(self.__dict__[name]), 2)
                )
            else:
                for k, v in _expand_structure(name, self.__dict__[name]):
                    if _is_module(v):
                        child_lines.append("(" + k + "): " + add_indent(repr(v), 2))
        lines = extra_lines + child_lines
        main_str = self.__class__.__name__ + "("
        if lines:
            # simple one-liner info, which most builtin Modules will use
            if len(extra_lines) == 1 and not child_lines:
                main_str += extra_lines[0]
            else:
                main_str += "\n  " + "\n  ".join(lines) + "\n"
        main_str += ")"
        return main_str