|
|
@@ -18,17 +18,25 @@ logger = get_logger(__name__) |
|
|
|
|
|
|
|
|
|
|
|
def _expand_structure(key, obj): |
|
|
|
if isinstance(obj, (list, tuple, dict)): |
|
|
|
if isinstance(obj, (Tensor, Module)): |
|
|
|
return [(key, obj)] |
|
|
|
elif isinstance(obj, (list, tuple, dict)): |
|
|
|
ret = [] |
|
|
|
if isinstance(obj, dict): |
|
|
|
targets = ((k, obj[k]) for k in sorted(obj)) |
|
|
|
else: |
|
|
|
targets = ((str(k), v) for k, v in enumerate(obj)) |
|
|
|
for k, o in targets: |
|
|
|
ret.extend(_expand_structure(key + "." + k, o)) |
|
|
|
sub_ret = _expand_structure(k, o) |
|
|
|
if sub_ret and not isinstance(k, str): |
|
|
|
raise AssertionError( |
|
|
|
"keys for Tensor and Module must be str, error key: {}".format(k) |
|
|
|
) |
|
|
|
for kt, vt in sub_ret: |
|
|
|
ret.extend([(key + "." + kt, vt)]) |
|
|
|
return ret |
|
|
|
else: |
|
|
|
return [(key, obj)] |
|
|
|
return [] |
|
|
|
|
|
|
|
|
|
|
|
def _is_parameter(obj): |
|
|
@@ -72,11 +80,11 @@ class Module(metaclass=ABCMeta): |
|
|
|
predicate: Callable[[Any], bool] = lambda _: True, |
|
|
|
seen: Optional[Set[int]] = None |
|
|
|
) -> Union[Iterable[Any], Iterable[Tuple[str, Any]]]: |
|
|
|
"""Scans the module object and returns an iterable for the attributes that |
|
|
|
agree with the ``predicate``. For multiple calls of this function with same |
|
|
|
arguments, the order of objects within the returned iterable is guaranteed to be |
|
|
|
identical, as long as all the involved module objects' ``__dict__`` does not |
|
|
|
change thoughout those calls. |
|
|
|
"""Scans the module object and returns an iterable for the :class:`~.Tensor` |
|
|
|
and :class:`~.Module` attributes that agree with the ``predicate``. For multiple |
|
|
|
calls of this function with same arguments, the order of objects within the |
|
|
|
returned iterable is guaranteed to be identical, as long as all the involved |
|
|
|
module objects' ``__dict__`` does not change thoughout those calls. |
|
|
|
|
|
|
|
:param recursive: Whether to recursively scan all the submodules. |
|
|
|
:param with_key: Whether to yield keys along with yielded objects. |
|
|
|