|
@@ -68,6 +68,7 @@ class Module(metaclass=ABCMeta): |
|
|
*, |
|
|
*, |
|
|
recursive: bool = True, |
|
|
recursive: bool = True, |
|
|
with_key: bool = False, |
|
|
with_key: bool = False, |
|
|
|
|
|
with_parent: bool = False, |
|
|
prefix: Optional[str] = None, |
|
|
prefix: Optional[str] = None, |
|
|
predicate: Callable[[Any], bool] = lambda _: True, |
|
|
predicate: Callable[[Any], bool] = lambda _: True, |
|
|
seen: Optional[Set[int]] = None |
|
|
seen: Optional[Set[int]] = None |
|
@@ -80,6 +81,7 @@ class Module(metaclass=ABCMeta): |
|
|
|
|
|
|
|
|
:param recursive: Whether to recursively scan all the submodules. |
|
|
:param recursive: Whether to recursively scan all the submodules. |
|
|
:param with_key: Whether to yield keys along with yielded objects. |
|
|
:param with_key: Whether to yield keys along with yielded objects. |
|
|
|
|
|
:param with_parent: Whether to yield ``self`` along with yielded objects. |
|
|
:param prefix: The prefix appended to the yielded keys. |
|
|
:param prefix: The prefix appended to the yielded keys. |
|
|
:param predicate: The predicate function applied to scanned objects. |
|
|
:param predicate: The predicate function applied to scanned objects. |
|
|
:param seen: A dict that records whether a module has been traversed yet. |
|
|
:param seen: A dict that records whether a module has been traversed yet. |
|
@@ -88,7 +90,7 @@ class Module(metaclass=ABCMeta): |
|
|
seen = set([id(self)]) |
|
|
seen = set([id(self)]) |
|
|
|
|
|
|
|
|
module_dict = vars(self) |
|
|
module_dict = vars(self) |
|
|
_prefix = "" if not prefix else prefix + "." |
|
|
|
|
|
|
|
|
_prefix = "" if prefix is None else prefix + "." |
|
|
|
|
|
|
|
|
for key in sorted(module_dict): |
|
|
for key in sorted(module_dict): |
|
|
for expanded_key, leaf in _expand_structure(key, module_dict[key]): |
|
|
for expanded_key, leaf in _expand_structure(key, module_dict[key]): |
|
@@ -98,8 +100,12 @@ class Module(metaclass=ABCMeta): |
|
|
seen.add(leaf_id) |
|
|
seen.add(leaf_id) |
|
|
|
|
|
|
|
|
if predicate(leaf): |
|
|
if predicate(leaf): |
|
|
if with_key: |
|
|
|
|
|
|
|
|
if with_key and with_parent: |
|
|
|
|
|
yield _prefix + expanded_key, leaf, self |
|
|
|
|
|
elif with_key: |
|
|
yield _prefix + expanded_key, leaf |
|
|
yield _prefix + expanded_key, leaf |
|
|
|
|
|
elif with_parent: |
|
|
|
|
|
yield leaf, self |
|
|
else: |
|
|
else: |
|
|
yield leaf |
|
|
yield leaf |
|
|
|
|
|
|
|
@@ -107,22 +113,22 @@ class Module(metaclass=ABCMeta): |
|
|
yield from leaf._flatten( |
|
|
yield from leaf._flatten( |
|
|
recursive=recursive, |
|
|
recursive=recursive, |
|
|
with_key=with_key, |
|
|
with_key=with_key, |
|
|
prefix=None if prefix is None else _prefix + expanded_key, |
|
|
|
|
|
|
|
|
with_parent=with_parent, |
|
|
|
|
|
prefix=_prefix + expanded_key if with_key else None, |
|
|
predicate=predicate, |
|
|
predicate=predicate, |
|
|
seen=seen, |
|
|
seen=seen, |
|
|
) |
|
|
) |
|
|
|
|
|
|
|
|
def parameters( |
|
|
def parameters( |
|
|
self, requires_grad: Optional[bool] = None, recursive: bool = True |
|
|
|
|
|
|
|
|
self, requires_grad: Optional[bool] = None, recursive: bool = True, **kwargs |
|
|
) -> Iterable[Parameter]: |
|
|
) -> Iterable[Parameter]: |
|
|
r"""Returns an iterable for the :class:`~.Parameter` of the module. |
|
|
r"""Returns an iterable for the :class:`~.Parameter` of the module. |
|
|
|
|
|
|
|
|
:param requires_grad: Limitation over the :attr:`~.Parameter.requires_grad` |
|
|
:param requires_grad: Limitation over the :attr:`~.Parameter.requires_grad` |
|
|
attribute of returned :class:`.Parameter`. ``None`` for |
|
|
|
|
|
no limitation. |
|
|
|
|
|
|
|
|
attribute of returned :class:`.Parameter`. ``None`` for no limitation. |
|
|
:param recursive: If ``True``, returns all :class:`~.Parameter` within this |
|
|
:param recursive: If ``True``, returns all :class:`~.Parameter` within this |
|
|
module, else only returns :class:`~.Parameter` that are direct |
|
|
|
|
|
attributes of this module. |
|
|
|
|
|
|
|
|
module, else only returns :class:`~.Parameter` that are direct attributes |
|
|
|
|
|
of this module. |
|
|
""" |
|
|
""" |
|
|
|
|
|
|
|
|
def predicate(obj) -> bool: |
|
|
def predicate(obj) -> bool: |
|
@@ -130,24 +136,26 @@ class Module(metaclass=ABCMeta): |
|
|
requires_grad is None or obj.requires_grad == requires_grad |
|
|
requires_grad is None or obj.requires_grad == requires_grad |
|
|
) |
|
|
) |
|
|
|
|
|
|
|
|
yield from self._flatten(predicate=predicate, recursive=recursive) |
|
|
|
|
|
|
|
|
yield from self._flatten( |
|
|
|
|
|
with_key=False, predicate=predicate, recursive=recursive, **kwargs |
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
def named_parameters( |
|
|
def named_parameters( |
|
|
self, |
|
|
self, |
|
|
requires_grad: Optional[bool] = None, |
|
|
requires_grad: Optional[bool] = None, |
|
|
prefix: str = "", |
|
|
|
|
|
|
|
|
prefix: Optional[str] = None, |
|
|
recursive: bool = True, |
|
|
recursive: bool = True, |
|
|
|
|
|
**kwargs |
|
|
) -> Iterable[Tuple[str, Parameter]]: |
|
|
) -> Iterable[Tuple[str, Parameter]]: |
|
|
"""Returns an iterable for key :class:`~.Parameter` pairs of the module, where |
|
|
"""Returns an iterable for key :class:`~.Parameter` pairs of the module, where |
|
|
``key`` is the dotted path from this module to the :class:`~.Parameter` . |
|
|
``key`` is the dotted path from this module to the :class:`~.Parameter` . |
|
|
|
|
|
|
|
|
:param requires_grad: Limitation over the :attr:`~.Parameter.requires_grad` |
|
|
:param requires_grad: Limitation over the :attr:`~.Parameter.requires_grad` |
|
|
attribute of returned :class:`~.Parameter` . ``None`` for |
|
|
|
|
|
no limitation. |
|
|
|
|
|
|
|
|
attribute of returned :class:`~.Parameter` . ``None`` for no limitation. |
|
|
:param prefix: The prefix prepended to the keys. |
|
|
:param prefix: The prefix prepended to the keys. |
|
|
:param recursive: If ``True``, returns all :class:`~.Parameter` within this |
|
|
:param recursive: If ``True``, returns all :class:`~.Parameter` within this |
|
|
module, else only returns :class:`~.Parameter` that are direct |
|
|
|
|
|
attributes of this module. |
|
|
|
|
|
|
|
|
module, else only returns :class:`~.Parameter` that are direct attributes |
|
|
|
|
|
of this module. |
|
|
""" |
|
|
""" |
|
|
|
|
|
|
|
|
def predicate(obj) -> bool: |
|
|
def predicate(obj) -> bool: |
|
@@ -156,17 +164,23 @@ class Module(metaclass=ABCMeta): |
|
|
) |
|
|
) |
|
|
|
|
|
|
|
|
yield from self._flatten( |
|
|
yield from self._flatten( |
|
|
with_key=True, prefix=prefix, predicate=predicate, recursive=recursive |
|
|
|
|
|
|
|
|
with_key=True, |
|
|
|
|
|
prefix=prefix, |
|
|
|
|
|
predicate=predicate, |
|
|
|
|
|
recursive=recursive, |
|
|
|
|
|
**kwargs, |
|
|
) |
|
|
) |
|
|
|
|
|
|
|
|
def buffers(self, recursive: bool = True) -> Iterable[Buffer]: |
|
|
|
|
|
|
|
|
def buffers(self, recursive: bool = True, **kwargs) -> Iterable[Buffer]: |
|
|
"""Returns an iterable for the :class:`~.Buffer` of the module. |
|
|
"""Returns an iterable for the :class:`~.Buffer` of the module. |
|
|
|
|
|
|
|
|
:param recursive: If ``True``, returns all :class:`~.Buffer` within this |
|
|
:param recursive: If ``True``, returns all :class:`~.Buffer` within this |
|
|
module, else only returns :class:`~.Buffer` that are direct |
|
|
|
|
|
attributes of this module. |
|
|
|
|
|
|
|
|
module, else only returns :class:`~.Buffer` that are direct attributes |
|
|
|
|
|
of this module. |
|
|
""" |
|
|
""" |
|
|
yield from self._flatten(predicate=_is_buffer, recursive=recursive) |
|
|
|
|
|
|
|
|
yield from self._flatten( |
|
|
|
|
|
with_key=False, predicate=_is_buffer, recursive=recursive, **kwargs |
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
def replace_param( |
|
|
def replace_param( |
|
|
self, params: dict, start_pos: int, seen: Optional[Set[int]] = None |
|
|
self, params: dict, start_pos: int, seen: Optional[Set[int]] = None |
|
@@ -192,48 +206,66 @@ class Module(metaclass=ABCMeta): |
|
|
return offset |
|
|
return offset |
|
|
|
|
|
|
|
|
def named_buffers( |
|
|
def named_buffers( |
|
|
self, prefix: str = "", recursive: bool = True |
|
|
|
|
|
|
|
|
self, prefix: Optional[str] = None, recursive: bool = True, **kwargs |
|
|
) -> Iterable[Tuple[str, Buffer]]: |
|
|
) -> Iterable[Tuple[str, Buffer]]: |
|
|
"""Returns an iterable for key :class:`~.Buffer` pairs of the module, where |
|
|
"""Returns an iterable for key :class:`~.Buffer` pairs of the module, where |
|
|
``key`` is the dotted path from this module to the :class:`~.Buffer` . |
|
|
``key`` is the dotted path from this module to the :class:`~.Buffer` . |
|
|
|
|
|
|
|
|
:param prefix: The prefix prepended to the keys. |
|
|
:param prefix: The prefix prepended to the keys. |
|
|
:param recursive: If ``True``, returns all :class:`~.Buffer` within this |
|
|
:param recursive: If ``True``, returns all :class:`~.Buffer` within this |
|
|
module, else only returns :class:`~.Buffer` that are direct |
|
|
|
|
|
attributes of this module. |
|
|
|
|
|
|
|
|
module, else only returns :class:`~.Buffer` that are direct attributes |
|
|
|
|
|
of this module. |
|
|
""" |
|
|
""" |
|
|
yield from self._flatten( |
|
|
yield from self._flatten( |
|
|
with_key=True, prefix=prefix, predicate=_is_buffer, recursive=recursive |
|
|
|
|
|
|
|
|
with_key=True, |
|
|
|
|
|
prefix=prefix, |
|
|
|
|
|
predicate=_is_buffer, |
|
|
|
|
|
recursive=recursive, |
|
|
|
|
|
**kwargs, |
|
|
) |
|
|
) |
|
|
|
|
|
|
|
|
def children(self) -> "Iterable[Module]": |
|
|
|
|
|
|
|
|
def children(self, **kwargs) -> "Iterable[Module]": |
|
|
"""Returns an iterable for all the submodules that are direct attributes of this |
|
|
"""Returns an iterable for all the submodules that are direct attributes of this |
|
|
module. |
|
|
module. |
|
|
""" |
|
|
""" |
|
|
yield from self._flatten(predicate=_is_module, recursive=False) |
|
|
|
|
|
|
|
|
yield from self._flatten( |
|
|
|
|
|
with_key=False, predicate=_is_module, recursive=False, **kwargs |
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
def named_children(self) -> "Iterable[Tuple[str, Module]]": |
|
|
|
|
|
|
|
|
def named_children(self, **kwargs) -> "Iterable[Tuple[str, Module]]": |
|
|
"""Returns an iterable of key-submodule pairs for all the submodules that are |
|
|
"""Returns an iterable of key-submodule pairs for all the submodules that are |
|
|
direct attributes of this module, where 'key' is the attribute name of |
|
|
direct attributes of this module, where 'key' is the attribute name of |
|
|
submodules. |
|
|
submodules. |
|
|
""" |
|
|
""" |
|
|
yield from self._flatten(with_key=True, predicate=_is_module, recursive=False) |
|
|
|
|
|
|
|
|
yield from self._flatten( |
|
|
|
|
|
with_key=True, predicate=_is_module, recursive=False, **kwargs |
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
def modules(self) -> "Iterable[Module]": |
|
|
|
|
|
|
|
|
def modules(self, **kwargs) -> "Iterable[Module]": |
|
|
"""Returns an iterable for all the modules within this module, including itself. |
|
|
"""Returns an iterable for all the modules within this module, including itself. |
|
|
""" |
|
|
""" |
|
|
yield self |
|
|
|
|
|
yield from self._flatten(predicate=_is_module) |
|
|
|
|
|
|
|
|
if "with_parent" in kwargs and kwargs["with_parent"]: |
|
|
|
|
|
yield self, None |
|
|
|
|
|
else: |
|
|
|
|
|
yield self |
|
|
|
|
|
yield from self._flatten(with_key=False, predicate=_is_module, **kwargs) |
|
|
|
|
|
|
|
|
def named_modules(self, prefix: str = "") -> "Iterable[Tuple[str, Module]]": |
|
|
|
|
|
|
|
|
def named_modules( |
|
|
|
|
|
self, prefix: Optional[str] = None, **kwargs |
|
|
|
|
|
) -> "Iterable[Tuple[str, Module]]": |
|
|
"""Returns an iterable of key-module pairs for all the modules within this |
|
|
"""Returns an iterable of key-module pairs for all the modules within this |
|
|
module, including itself, where 'key' is the dotted path from this module to the |
|
|
module, including itself, where 'key' is the dotted path from this module to the |
|
|
submodules. |
|
|
submodules. |
|
|
|
|
|
|
|
|
:param prefix: The prefix prepended to the path. |
|
|
:param prefix: The prefix prepended to the path. |
|
|
""" |
|
|
""" |
|
|
yield prefix, self |
|
|
|
|
|
yield from self._flatten(with_key=True, prefix=prefix, predicate=_is_module) |
|
|
|
|
|
|
|
|
if "with_parent" in kwargs and kwargs["with_parent"]: |
|
|
|
|
|
yield ("" if prefix is None else prefix), self, None |
|
|
|
|
|
else: |
|
|
|
|
|
yield ("" if prefix is None else prefix), self |
|
|
|
|
|
yield from self._flatten( |
|
|
|
|
|
with_key=True, prefix=prefix, predicate=_is_module, **kwargs |
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
def apply(self, fn: "Callable[[Module], Any]") -> None: |
|
|
def apply(self, fn: "Callable[[Module], Any]") -> None: |
|
|
"""Apply function ``fn`` to all the modules within this module, including |
|
|
"""Apply function ``fn`` to all the modules within this module, including |
|
|