Browse Source

feat(module): add tensors and named_tensors

GitOrigin-RevId: cb56d65d38
release-1.6
Megvii Engine Team 3 years ago
parent
commit
1019865071
1 changed files with 27 additions and 0 deletions
  1. +27
    -0
      imperative/python/megengine/module/module.py

+ 27
- 0
imperative/python/megengine/module/module.py View File

@@ -301,6 +301,33 @@ class Module(metaclass=ABCMeta):
**kwargs,
)

def tensors(self, recursive: bool = True, **kwargs) -> Iterable[Parameter]:
r"""
Returns an iterable for the :class:`~.Tensor` of the module.

:param recursive: If ``True``, returns all :class:`~.Tensor` within this
module, else only returns :class:`~.Tensor` that are direct attributes
of this module.
"""

yield from self._flatten(with_key=False, recursive=recursive, **kwargs)

def named_tensors(
self, prefix: Optional[str] = None, recursive: bool = True, **kwargs
) -> Iterable[Tuple[str, Tensor]]:
"""
Returns an iterable for key tensor pairs of the module, where
``key`` is the dotted path from this module to the tensor.

:param prefix: prefix prepended to the keys.
:param recursive: if ``True``, returns all tensors within this
module, else only returns tensors that are direct attributes
of this module.
"""
yield from self._flatten(
with_key=True, prefix=prefix, recursive=recursive, **kwargs,
)

def children(self, **kwargs) -> "Iterable[Module]":
r"""Returns an iterable for all the submodules that are direct attributes of this
module.


Loading…
Cancel
Save