|
@@ -147,13 +147,14 @@ class _ModuleDict(Module, MutableMapping): |
|
|
|
|
|
|
|
|
def __init__(self, modules: Optional[Dict[str, Module]] = None): |
|
|
def __init__(self, modules: Optional[Dict[str, Module]] = None): |
|
|
super().__init__() |
|
|
super().__init__() |
|
|
self._size = 0 |
|
|
|
|
|
|
|
|
self._module_keys = [] |
|
|
if modules is not None: |
|
|
if modules is not None: |
|
|
self.update(modules) |
|
|
self.update(modules) |
|
|
|
|
|
|
|
|
def __delitem__(self, key): |
|
|
def __delitem__(self, key): |
|
|
delattr(self, key) |
|
|
delattr(self, key) |
|
|
self._size -= 1 |
|
|
|
|
|
|
|
|
assert key in self._module_keys |
|
|
|
|
|
self._module_keys.remove(key) |
|
|
|
|
|
|
|
|
def __getitem__(self, key): |
|
|
def __getitem__(self, key): |
|
|
return getattr(self, key) |
|
|
return getattr(self, key) |
|
@@ -162,22 +163,23 @@ class _ModuleDict(Module, MutableMapping): |
|
|
if not isinstance(value, Module): |
|
|
if not isinstance(value, Module): |
|
|
raise ValueError("invalid sub-module") |
|
|
raise ValueError("invalid sub-module") |
|
|
setattr(self, key, value) |
|
|
setattr(self, key, value) |
|
|
self._size += 1 |
|
|
|
|
|
|
|
|
if key not in self._module_keys: |
|
|
|
|
|
self._module_keys.append(key) |
|
|
|
|
|
|
|
|
def __iter__(self): |
|
|
def __iter__(self): |
|
|
return iter(self.keys()) |
|
|
return iter(self.keys()) |
|
|
|
|
|
|
|
|
def __len__(self): |
|
|
def __len__(self): |
|
|
return self._size |
|
|
|
|
|
|
|
|
return len(self._module_keys) |
|
|
|
|
|
|
|
|
def items(self): |
|
|
def items(self): |
|
|
return dict(self.named_children()).items() |
|
|
|
|
|
|
|
|
return [(key, getattr(self, key)) for key in self._module_keys] |
|
|
|
|
|
|
|
|
def values(self): |
|
|
def values(self): |
|
|
return dict(self.named_children()).values() |
|
|
|
|
|
|
|
|
return [getattr(self, key) for key in self._module_keys] |
|
|
|
|
|
|
|
|
def keys(self): |
|
|
def keys(self): |
|
|
return dict(self.named_children()).keys() |
|
|
|
|
|
|
|
|
return self._module_keys |
|
|
|
|
|
|
|
|
def forward(self): |
|
|
def forward(self): |
|
|
raise RuntimeError("ModuleList is not callable") |
|
|
raise RuntimeError("ModuleList is not callable") |