|
@@ -67,6 +67,10 @@ def _is_parameter(obj): |
|
|
return isinstance(obj, Parameter) |
|
|
return isinstance(obj, Parameter) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _is_tensor(obj): |
|
|
|
|
|
return isinstance(obj, Tensor) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _is_buffer(obj): |
|
|
def _is_buffer(obj): |
|
|
return isinstance(obj, Tensor) and not isinstance(obj, Parameter) |
|
|
return isinstance(obj, Tensor) and not isinstance(obj, Parameter) |
|
|
|
|
|
|
|
@@ -309,8 +313,9 @@ class Module(metaclass=ABCMeta): |
|
|
module, else only returns :class:`~.Tensor` that are direct attributes |
|
|
module, else only returns :class:`~.Tensor` that are direct attributes |
|
|
of this module. |
|
|
of this module. |
|
|
""" |
|
|
""" |
|
|
|
|
|
|
|
|
yield from self._flatten(with_key=False, recursive=recursive, **kwargs) |
|
|
|
|
|
|
|
|
yield from self._flatten( |
|
|
|
|
|
with_key=False, predicate=_is_tensor, recursive=recursive, **kwargs |
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
def named_tensors( |
|
|
def named_tensors( |
|
|
self, prefix: Optional[str] = None, recursive: bool = True, **kwargs |
|
|
self, prefix: Optional[str] = None, recursive: bool = True, **kwargs |
|
@@ -325,7 +330,11 @@ class Module(metaclass=ABCMeta): |
|
|
of this module. |
|
|
of this module. |
|
|
""" |
|
|
""" |
|
|
yield from self._flatten( |
|
|
yield from self._flatten( |
|
|
with_key=True, prefix=prefix, recursive=recursive, **kwargs, |
|
|
|
|
|
|
|
|
with_key=True, |
|
|
|
|
|
prefix=prefix, |
|
|
|
|
|
predicate=_is_tensor, |
|
|
|
|
|
recursive=recursive, |
|
|
|
|
|
**kwargs, |
|
|
) |
|
|
) |
|
|
|
|
|
|
|
|
def children(self, **kwargs) -> "Iterable[Module]": |
|
|
def children(self, **kwargs) -> "Iterable[Module]": |
|
|