|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197 |
- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- #
- # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- import copy
- from collections.abc import MutableMapping, MutableSequence
- from typing import Dict, Iterable, List, Optional, Sequence
-
- from ..module import Module
-
-
- def replace_container_with_module_container(container):
- has_module = False
- module_container = None
- if isinstance(container, Dict):
- m_dic = copy.copy(container)
- for key, value in container.items():
- if isinstance(value, Module):
- has_module = True
- elif isinstance(value, (List, Dict)):
- (
- _has_module,
- _module_container,
- ) = replace_container_with_module_container(value)
- m_dic[key] = _module_container
- if _has_module:
- has_module = True
- if not all(isinstance(v, Module) for v in m_dic.values()):
- return has_module, None
- else:
- return has_module, _ModuleDict(m_dic)
- elif isinstance(container, List):
- m_list = copy.copy(container)
- for ind, value in enumerate(container):
- if isinstance(value, Module):
- has_module = True
- elif isinstance(value, (List, Dict)):
- (
- _has_module,
- _module_container,
- ) = replace_container_with_module_container(value)
- m_list[ind] = _module_container
- if _has_module:
- has_module = True
- if not all(isinstance(v, Module) for v in m_list):
- return has_module, None
- else:
- return has_module, _ModuleList(m_list)
- return has_module, module_container
-
-
- class _ModuleList(Module, MutableSequence):
- r"""A List-like container.
-
- Using a ``ModuleList``, one can visit, add, delete and modify submodules
- just like an ordinary python list.
- """
-
- def __init__(self, modules: Optional[Iterable[Module]] = None):
- super().__init__()
- self._size = 0
- if modules is None:
- return
- for mod in modules:
- self.append(mod)
-
- @classmethod
- def _ikey(cls, idx):
- return "{}".format(idx)
-
- def _check_idx(self, idx):
- L = len(self)
- if idx < 0:
- idx = L + idx
- if idx < 0 or idx >= L:
- raise IndexError("list index out of range")
- return idx
-
- def __getitem__(self, idx: int):
- if isinstance(idx, slice):
- idx = range(self._size)[idx]
- if not isinstance(idx, Sequence):
- idx = [
- idx,
- ]
- rst = []
- for i in idx:
- i = self._check_idx(i)
- key = self._ikey(i)
- try:
- rst.append(getattr(self, key))
- except AttributeError:
- raise IndexError("list index out of range")
- return rst if len(rst) > 1 else rst[0]
-
- def __setattr__(self, key, value):
- # clear mod name to avoid warning in Module's setattr
- if isinstance(value, Module):
- value._name = None
- super().__setattr__(key, value)
-
- def __setitem__(self, idx: int, mod: Module):
- if not isinstance(mod, Module):
- raise ValueError("invalid sub-module")
- idx = self._check_idx(idx)
- setattr(self, self._ikey(idx), mod)
-
- def __delitem__(self, idx):
- idx = self._check_idx(idx)
- L = len(self)
- for orig_idx in range(idx + 1, L):
- new_idx = orig_idx - 1
- self[new_idx] = self[orig_idx]
- delattr(self, self._ikey(L - 1))
- self._size -= 1
-
- def __len__(self):
- return self._size
-
- def insert(self, idx, mod: Module):
- assert isinstance(mod, Module)
- L = len(self)
- if idx < 0:
- idx = L - idx
- # clip idx to (0, L)
- if idx > L:
- idx = L
- elif idx < 0:
- idx = 0
-
- for new_idx in range(L, idx, -1):
- orig_idx = new_idx - 1
- key = self._ikey(new_idx)
- setattr(self, key, self[orig_idx])
-
- key = self._ikey(idx)
- setattr(self, key, mod)
- self._size += 1
-
- def forward(self):
- raise RuntimeError("ModuleList is not callable")
-
-
- class _ModuleDict(Module, MutableMapping):
- r"""A Dict-like container.
-
- Using a ``ModuleDict``, one can visit, add, delete and modify submodules
- just like an ordinary python dict.
- """
-
- def __init__(self, modules: Optional[Dict[str, Module]] = None):
- super().__init__()
- self._module_keys = []
- if modules is not None:
- self.update(modules)
-
- def __delitem__(self, key):
- delattr(self, key)
- assert key in self._module_keys
- self._module_keys.remove(key)
-
- def __getitem__(self, key):
- return getattr(self, key)
-
- def __setattr__(self, key, value):
- # clear mod name to avoid warning in Module's setattr
- if isinstance(value, Module):
- value._name = None
- super().__setattr__(key, value)
-
- def __setitem__(self, key, value):
- if not isinstance(value, Module):
- raise ValueError("invalid sub-module")
- setattr(self, key, value)
- if key not in self._module_keys:
- self._module_keys.append(key)
-
- def __iter__(self):
- return iter(self.keys())
-
- def __len__(self):
- return len(self._module_keys)
-
- def items(self):
- return [(key, getattr(self, key)) for key in self._module_keys]
-
- def values(self):
- return [getattr(self, key) for key in self._module_keys]
-
- def keys(self):
- return self._module_keys
-
- def forward(self):
- raise RuntimeError("ModuleList is not callable")
|