You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 6.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import copy
  9. from collections.abc import MutableMapping, MutableSequence
  10. from typing import Dict, Iterable, List, Optional, Sequence
  11. from ..module import Module
  12. def replace_container_with_module_container(container):
  13. has_module = False
  14. module_container = None
  15. if isinstance(container, Dict):
  16. m_dic = copy.copy(container)
  17. for key, value in container.items():
  18. if isinstance(value, Module):
  19. has_module = True
  20. elif isinstance(value, (List, Dict)):
  21. (
  22. _has_module,
  23. _module_container,
  24. ) = replace_container_with_module_container(value)
  25. m_dic[key] = _module_container
  26. if _has_module:
  27. has_module = True
  28. if not all(isinstance(v, Module) for v in m_dic.values()):
  29. return has_module, None
  30. else:
  31. return has_module, _ModuleDict(m_dic)
  32. elif isinstance(container, List):
  33. m_list = copy.copy(container)
  34. for ind, value in enumerate(container):
  35. if isinstance(value, Module):
  36. has_module = True
  37. elif isinstance(value, (List, Dict)):
  38. (
  39. _has_module,
  40. _module_container,
  41. ) = replace_container_with_module_container(value)
  42. m_list[ind] = _module_container
  43. if _has_module:
  44. has_module = True
  45. if not all(isinstance(v, Module) for v in m_list):
  46. return has_module, None
  47. else:
  48. return has_module, _ModuleList(m_list)
  49. return has_module, module_container
  50. class _ModuleList(Module, MutableSequence):
  51. r"""A List-like container.
  52. Using a ``ModuleList``, one can visit, add, delete and modify submodules
  53. just like an ordinary python list.
  54. """
  55. def __init__(self, modules: Optional[Iterable[Module]] = None):
  56. super().__init__()
  57. self._size = 0
  58. if modules is None:
  59. return
  60. for mod in modules:
  61. self.append(mod)
  62. @classmethod
  63. def _ikey(cls, idx):
  64. return "{}".format(idx)
  65. def _check_idx(self, idx):
  66. L = len(self)
  67. if idx < 0:
  68. idx = L + idx
  69. if idx < 0 or idx >= L:
  70. raise IndexError("list index out of range")
  71. return idx
  72. def __getitem__(self, idx: int):
  73. if isinstance(idx, slice):
  74. idx = range(self._size)[idx]
  75. if not isinstance(idx, Sequence):
  76. idx = [
  77. idx,
  78. ]
  79. rst = []
  80. for i in idx:
  81. i = self._check_idx(i)
  82. key = self._ikey(i)
  83. try:
  84. rst.append(getattr(self, key))
  85. except AttributeError:
  86. raise IndexError("list index out of range")
  87. return rst if len(rst) > 1 else rst[0]
  88. def __setattr__(self, key, value):
  89. # clear mod name to avoid warning in Module's setattr
  90. if isinstance(value, Module):
  91. value._name = None
  92. super().__setattr__(key, value)
  93. def __setitem__(self, idx: int, mod: Module):
  94. if not isinstance(mod, Module):
  95. raise ValueError("invalid sub-module")
  96. idx = self._check_idx(idx)
  97. setattr(self, self._ikey(idx), mod)
  98. def __delitem__(self, idx):
  99. idx = self._check_idx(idx)
  100. L = len(self)
  101. for orig_idx in range(idx + 1, L):
  102. new_idx = orig_idx - 1
  103. self[new_idx] = self[orig_idx]
  104. delattr(self, self._ikey(L - 1))
  105. self._size -= 1
  106. def __len__(self):
  107. return self._size
  108. def insert(self, idx, mod: Module):
  109. assert isinstance(mod, Module)
  110. L = len(self)
  111. if idx < 0:
  112. idx = L - idx
  113. # clip idx to (0, L)
  114. if idx > L:
  115. idx = L
  116. elif idx < 0:
  117. idx = 0
  118. for new_idx in range(L, idx, -1):
  119. orig_idx = new_idx - 1
  120. key = self._ikey(new_idx)
  121. setattr(self, key, self[orig_idx])
  122. key = self._ikey(idx)
  123. setattr(self, key, mod)
  124. self._size += 1
  125. def forward(self):
  126. raise RuntimeError("ModuleList is not callable")
  127. class _ModuleDict(Module, MutableMapping):
  128. r"""A Dict-like container.
  129. Using a ``ModuleDict``, one can visit, add, delete and modify submodules
  130. just like an ordinary python dict.
  131. """
  132. def __init__(self, modules: Optional[Dict[str, Module]] = None):
  133. super().__init__()
  134. self._module_keys = []
  135. if modules is not None:
  136. self.update(modules)
  137. def __delitem__(self, key):
  138. delattr(self, key)
  139. assert key in self._module_keys
  140. self._module_keys.remove(key)
  141. def __getitem__(self, key):
  142. return getattr(self, key)
  143. def __setattr__(self, key, value):
  144. # clear mod name to avoid warning in Module's setattr
  145. if isinstance(value, Module):
  146. value._name = None
  147. super().__setattr__(key, value)
  148. def __setitem__(self, key, value):
  149. if not isinstance(value, Module):
  150. raise ValueError("invalid sub-module")
  151. setattr(self, key, value)
  152. if key not in self._module_keys:
  153. self._module_keys.append(key)
  154. def __iter__(self):
  155. return iter(self.keys())
  156. def __len__(self):
  157. return len(self._module_keys)
  158. def items(self):
  159. return [(key, getattr(self, key)) for key in self._module_keys]
  160. def values(self):
  161. return [getattr(self, key) for key in self._module_keys]
  162. def keys(self):
  163. return self._module_keys
  164. def forward(self):
  165. raise RuntimeError("ModuleList is not callable")

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。