Browse Source

feat(module): add tensors and named_tensors

GitOrigin-RevId: cb56d65d38
release-1.6
Megvii Engine Team 3 years ago
parent
commit
1019865071
1 changed files with 27 additions and 0 deletions
  1. +27
    -0
      imperative/python/megengine/module/module.py

+ 27
- 0
imperative/python/megengine/module/module.py View File

@@ -301,6 +301,33 @@ class Module(metaclass=ABCMeta):
**kwargs, **kwargs,
) )


def tensors(self, recursive: bool = True, **kwargs) -> Iterable[Parameter]:
r"""
Returns an iterable for the :class:`~.Tensor` of the module.

:param recursive: If ``True``, returns all :class:`~.Tensor` within this
module, else only returns :class:`~.Tensor` that are direct attributes
of this module.
"""

yield from self._flatten(with_key=False, recursive=recursive, **kwargs)

def named_tensors(
self, prefix: Optional[str] = None, recursive: bool = True, **kwargs
) -> Iterable[Tuple[str, Tensor]]:
"""
Returns an iterable for key tensor pairs of the module, where
``key`` is the dotted path from this module to the tensor.

:param prefix: prefix prepended to the keys.
:param recursive: if ``True``, returns all tensors within this
module, else only returns tensors that are direct attributes
of this module.
"""
yield from self._flatten(
with_key=True, prefix=prefix, recursive=recursive, **kwargs,
)

def children(self, **kwargs) -> "Iterable[Module]": def children(self, **kwargs) -> "Iterable[Module]":
r"""Returns an iterable for all the submodules that are direct attributes of this r"""Returns an iterable for all the submodules that are direct attributes of this
module. module.


Loading…
Cancel
Save