You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

traced_module.py 9.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import collections
  10. import copy
  11. from typing import List, Type
  12. from ... import module as M
  13. from ...core._imperative_rt.core2 import set_module_tracing, unset_module_tracing
  14. from ...module import Module
  15. from ...tensor import Tensor
  16. from .expr import Apply, Call, Constant, Expr, GetAttr, Input
  17. from .module_tracer import active_module_tracer, module_tracer, set_active_module_tracer
  18. from .node import ModuleNode, Node, NodeMixin, TensorNode
  19. class InternalGraph:
  20. """
  21. ``InternalGraph`` is a graph consist of ``Node`` and ``Expr``, it is used to represent the execution procedure of Module's forward method.
  22. Attributes:
  23. _exprs: List of Exprs in order of execution
  24. _inputs: Input Nodes of InternalGraph
  25. _outputs: Output Nodes of InternalGraph
  26. """
  27. _exprs = None # type: List[Expr]
  28. _inputs = None # type: List[Node]
  29. _outputs = None # type: List[Node]
  30. def __init__(self):
  31. self._exprs = []
  32. self._inputs = []
  33. self._outputs = []
  34. def insert(self, expr):
  35. self._exprs.append(expr)
  36. def add_input(self, i):
  37. self._inputs.append(i)
  38. def add_output(self, o):
  39. self._outputs.append(o)
  40. def interpret(self, *inputs):
  41. # TODO: support kwargs ?
  42. # TODO: skip expressions which are independent and have no side effect
  43. node2value = {}
  44. for n, v in zip(self._inputs, inputs):
  45. node2value[n] = v
  46. for expr in self._exprs:
  47. values = expr.interpret(*list(node2value[i] for i in expr.inputs))
  48. for n, v in zip(expr.outputs, values):
  49. node2value[n] = v
  50. return list(node2value[i] for i in self._outputs)
  51. def __repr__(self):
  52. return "InternalGraph ({}) {{\n\t{}\n\treturn {}\n}}".format(
  53. ", ".join(str(i) for i in self._inputs),
  54. "\n\t".join(str(i) for i in self._exprs),
  55. ", ".join(str(i) for i in self._outputs),
  56. )
  57. class TracedModuleBuilder(NodeMixin):
  58. _mod = None # type: Module
  59. _body = None # type: InternalGraph
  60. _is_builtin = None # type: bool
  61. __builder_attributes__ = [
  62. "_mod",
  63. "_body",
  64. "_NodeMixin__node",
  65. "_is_builtin",
  66. "_is_traced",
  67. "build",
  68. ]
  69. def __init__(self, mod):
  70. super(TracedModuleBuilder, self).__init__()
  71. self._mod = mod
  72. self._body = InternalGraph()
  73. self._is_traced = False
  74. self._is_builtin = module_tracer.is_builtin(mod)
  75. def build(self):
  76. if self._is_builtin:
  77. node = NodeMixin.get(self)
  78. node.module_type = type(self._mod)
  79. return self._mod
  80. else:
  81. node = NodeMixin.get(self)
  82. node.graph = self._body
  83. node.attr_type_map = {}
  84. traced_module = TracedModule(node)
  85. for k, v in self.__dict__.items():
  86. if k not in TracedModuleBuilder.__builder_attributes__:
  87. if isinstance(v, TracedModuleBuilder):
  88. v = v.build()
  89. setattr(traced_module, k, v)
  90. traced_module.m_node.attr_type_map[k] = type(v)
  91. return traced_module
  92. def __call__(self, *inputs, **kwargs):
  93. assert isinstance(self._mod, Module)
  94. # prepare args and kwargs for inner graph
  95. def mark_constant(x):
  96. node = NodeMixin.get(x, None)
  97. if node is None: # capture as constant
  98. NodeMixin.wrap(x, lambda: Constant.make(x))
  99. for i in inputs:
  100. mark_constant(i)
  101. for k, v in kwargs.items():
  102. mark_constant(v)
  103. callnode = Call.make(NodeMixin.get(self))
  104. def add_input(x):
  105. callnode.add_input(NodeMixin.get(x))
  106. for i in inputs:
  107. add_input(i)
  108. for k, v in kwargs.items():
  109. add_input(v)
  110. if self._is_builtin or self._is_traced:
  111. unset_module_tracing()
  112. outputs = self._mod(*inputs, **kwargs)
  113. set_module_tracing()
  114. if self._is_builtin:
  115. self._body = None
  116. else:
  117. active_module_tracer().push_scope(self._body)
  118. # rebind self to new input node
  119. orig_self = NodeMixin.get(self)
  120. NodeMixin.wrap_safe(
  121. self, Input.make("self", NodeMixin.get_wrapped_type(self))
  122. )
  123. # prepare args and kwargs for inner graph
  124. def wrap(x):
  125. wrapped = copy.copy(x) # FIXME
  126. NodeMixin.wrap(
  127. wrapped,
  128. lambda: Input.make(type=NodeMixin.get_wrapped_type(wrapped)),
  129. )
  130. return wrapped
  131. args = []
  132. for i in inputs:
  133. args.append(wrap(i))
  134. for k, v in kwargs.items():
  135. kwargs[k] = wrap(v)
  136. outputs = type(self._mod).forward(self, *args, **kwargs)
  137. for i in (
  138. outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,)
  139. ):
  140. active_module_tracer().current_scope().add_output(NodeMixin.get(i))
  141. NodeMixin.wrap_safe(self, orig_self)
  142. self._is_traced = True
  143. active_module_tracer().pop_scope()
  144. # rebind output to outer graph
  145. callnode.add_outputs(outputs)
  146. for i, node in zip(
  147. outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,),
  148. callnode.outputs,
  149. ):
  150. NodeMixin.wrap_safe(i, node)
  151. return outputs
  152. def __getattr__(self, name):
  153. if name not in self._mod.__dict__:
  154. attr = getattr(type(self._mod), name).__get__(self, type(self))
  155. else:
  156. attr = getattr(self._mod, name)
  157. if isinstance(attr, Module):
  158. attr = TracedModuleBuilder(attr)
  159. setattr(self, name, attr)
  160. NodeMixin.wrap(
  161. attr,
  162. lambda: GetAttr.make(
  163. NodeMixin.get(self), name, type=NodeMixin.get_wrapped_type(attr)
  164. ),
  165. )
  166. return attr
  167. def __getattribute__(self, name):
  168. if name in TracedModuleBuilder.__builder_attributes__:
  169. return super().__getattribute__(name)
  170. else:
  171. wrapped = super().__getattribute__(name)
  172. if name in self._mod.__dict__ and not NodeMixin.get(wrapped, None):
  173. assert not self._is_builtin
  174. NodeMixin.wrap(
  175. wrapped,
  176. lambda: GetAttr.make(
  177. NodeMixin.get(self),
  178. name,
  179. type=NodeMixin.get_wrapped_type(wrapped),
  180. ),
  181. )
  182. return wrapped
  183. class TracedModule(Module):
  184. """
  185. `TracedModule` is the Module created by tracing normal module. It owns a ModuleNode(m_node), and will interpret the m_node.graph when it is called.
  186. """
  187. m_node = None # type: ModuleNode
  188. def __init__(self, node):
  189. super(TracedModule, self).__init__()
  190. self.m_node = node
  191. def forward(self, *inputs):
  192. rst = self.m_node.graph.interpret(self, *inputs)
  193. if len(rst) == 1:
  194. rst = rst[0]
  195. return rst
  196. def __getstate__(self):
  197. d = self.__dict__
  198. for k in Module.__dict__:
  199. d.pop(k, None)
  200. return d
  201. def cpp_apply_module_trace(opdef, *args):
  202. return Apply.apply_module_trace_hook(opdef, *args)
  203. def register_as_builtin(mod_cls: Type[Module]) -> None:
  204. """
  205. Registers class ``mod_cls`` (subclass of megengine.module.Module) as builtin module.
  206. param mod_cls: the Module class which will be threated as builtin module in tracing
  207. """
  208. module_tracer.register_as_builtin(mod_cls)
  209. def _register_all_builtin_module():
  210. from inspect import getmembers, isclass
  211. for sub_mod in [M, M.qat, M.quantized]:
  212. for m in getmembers(sub_mod):
  213. if (
  214. isclass(m[1])
  215. and issubclass(m[1], M.Module)
  216. and m[1] is not M.Sequential
  217. ):
  218. module_tracer.register_as_builtin(m[1])
  219. def trace_module(mod: Module, *inputs: Tensor, **kwargs: Tensor) -> TracedModule:
  220. """
  221. Traces module ``mod`` and returns corresponding TracedModule.
  222. param mod: the module will be converted to TracedModule
  223. param input: the positional arguments passed to forward method of ``mod``
  224. param kwargs: the keyword arguments passed to forward method of ``mod``
  225. """
  226. assert active_module_tracer() is None
  227. try:
  228. set_module_tracing()
  229. set_active_module_tracer(module_tracer())
  230. global_scope = InternalGraph()
  231. active_module_tracer().push_scope(global_scope)
  232. builder = TracedModuleBuilder(mod)
  233. NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode))
  234. for _, i in enumerate(inputs):
  235. NodeMixin.wrap_safe(i, Input.make("arg_{}".format(_)))
  236. for k, v in kwargs.items():
  237. NodeMixin.wrap_safe(v, Input.make("kwarg_{}".format(k)))
  238. builder(*inputs, **kwargs)
  239. active_module_tracer().pop_scope()
  240. return builder.build()
  241. finally:
  242. set_active_module_tracer(None)
  243. unset_module_tracing()

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。