You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

traced_module.py 15 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import collections
  10. import copy
  11. import functools
  12. from inspect import getmembers, isclass, ismethod
  13. from typing import Dict, List, Type
  14. from ... import module as M
  15. from ...core._imperative_rt.core2 import Tensor as RawTensor
  16. from ...core._imperative_rt.core2 import (
  17. is_tracing_module,
  18. set_module_tracing,
  19. unset_module_tracing,
  20. )
  21. from ...core.tensor.array_method import ArrayMethodMixin
  22. from ...module import Module
  23. from ...tensor import Tensor
  24. from .expr import Apply, CallFunction, CallMethod, Constant, Expr, GetAttr, Input
  25. from .module_tracer import (
  26. Patcher,
  27. active_module_tracer,
  28. module_tracer,
  29. set_active_module_tracer,
  30. )
  31. from .node import ModuleNode, Node, NodeMixin, TensorNode
  32. from .pytree import tree_flatten
  33. def _leaf_type(node):
  34. if isinstance(node, RawTensor):
  35. return (Tensor, TensorNode)
  36. elif isinstance(node, (NodeMixin, Module)):
  37. return (Module, ModuleNode, NodeMixin)
  38. else:
  39. return type(node)
  40. def _is_const_leaf(node):
  41. if isinstance(node, (RawTensor, NodeMixin, Module)):
  42. return False
  43. return True
  44. class InternalGraph:
  45. """
  46. ``InternalGraph`` is a graph consist of ``Node`` and ``Expr``, it is used to represent the execution procedure of Module's forward method.
  47. Attributes:
  48. _exprs: List of Exprs in order of execution
  49. _inputs: Input Nodes of InternalGraph
  50. _outputs: Output Nodes of InternalGraph
  51. """
  52. _exprs = None # type: List[Expr]
  53. _inputs = None # type: List[Node]
  54. _outputs = None # type: List[Node]
  55. def __init__(self):
  56. self._exprs = []
  57. self._inputs = []
  58. self._outputs = []
  59. def insert(self, expr):
  60. self._exprs.append(expr)
  61. @property
  62. def inputs(self):
  63. return self._inputs
  64. @property
  65. def outputs(self):
  66. return self._outputs
  67. @property
  68. def exprs(self):
  69. return _expr_list(self)
  70. def add_input(self, i):
  71. self._inputs.append(i)
  72. def add_output(self, o):
  73. self._outputs.append(o)
  74. def interpret(self, *inputs):
  75. # TODO: support kwargs ?
  76. # TODO: skip expressions which are independent and have no side effect
  77. node2value = {}
  78. for n, v in zip(self._inputs, inputs):
  79. node2value[n] = v
  80. for expr in self._exprs:
  81. values = expr.interpret(*list(node2value[i] for i in expr.inputs))
  82. for n, v in zip(expr.outputs, values):
  83. node2value[n] = v
  84. return list(node2value[i] for i in self._outputs)
  85. def __repr__(self):
  86. return "InternalGraph ({}) {{\n\t{}\n\treturn {}\n}}".format(
  87. ", ".join(str(i) for i in self._inputs),
  88. "\n\t".join(str(i) for i in self._exprs),
  89. ", ".join(str(i) for i in self._outputs),
  90. )
  91. def _get_meth_name(obj, func):
  92. for cls in type(obj).mro():
  93. for k, v in cls.__dict__.items():
  94. if v == func:
  95. return k
  96. return None
  97. def _wrapped_function(orig_func):
  98. @functools.wraps(orig_func)
  99. def wrapped_fn(*args, **kwargs):
  100. if is_tracing_module():
  101. unset_module_tracing()
  102. inputs, tree_def = tree_flatten(
  103. (args, kwargs), leaf_type=_leaf_type, is_const_leaf=_is_const_leaf
  104. )
  105. for i in inputs:
  106. if not NodeMixin.get(i, None):
  107. if isinstance(i, (RawTensor, NodeMixin)):
  108. NodeMixin.wrap_safe(i, Constant.make(i))
  109. meth_name = _get_meth_name(args[0], wrapped_fn)
  110. if meth_name:
  111. self = inputs[0]
  112. call_node = CallMethod.make(NodeMixin.get(self), meth_name)
  113. else:
  114. call_node = CallFunction.make(orig_func)
  115. call_node.add_inputs(inputs)
  116. call_node.arg_def = tree_def
  117. outputs = orig_func(*args, **kwargs)
  118. call_node.add_outputs(outputs)
  119. set_module_tracing()
  120. return outputs
  121. return orig_func(*args, **kwargs)
  122. return wrapped_fn
class TracedModuleBuilder(NodeMixin):
    """
    Stand-in for a ``Module`` while ``trace_module`` is running.

    Calling the builder records a ``CallMethod`` expr in the active tracer's
    current graph. Reading attributes of the wrapped module through the builder
    associates them with ``GetAttr`` exprs. After tracing, ``build`` converts
    the builder into a ``TracedModule``, or returns the wrapped module itself
    when it is a builtin module.
    """

    # The module being traced.
    _mod = None  # type: Module
    # Graph recorded for the module's forward; stays None for builtin modules.
    _body = None  # type: InternalGraph
    # Result of ``module_tracer.is_builtin(_mod)``.
    _is_builtin = None  # type: bool
    # Names that ``__getattribute__`` returns directly (without GetAttr
    # recording). ``build`` does not copy these onto the TracedModule.
    __builder_attributes__ = [
        "_mod",
        "_body",
        "_NodeMixin__node",
        "_is_builtin",
        "build",
    ]

    def __init__(self, mod, is_top_module=False):
        # NOTE(review): ``is_top_module`` is accepted but not used here.
        super(TracedModuleBuilder, self).__init__()
        self._mod = mod
        self._body = None
        self._is_builtin = module_tracer.is_builtin(mod)

    def build(self):
        """
        Convert the builder into its final form.

        :return: for a builtin module, the original module (its node's
            ``module_type`` is set). Otherwise, a new ``TracedModule`` owning
            this builder's node. Every attribute stored on the builder instance
            (e.g. those cached by ``__getattr__``), except the names in
            ``__builder_attributes__``, is copied onto it, and nested builders
            are built recursively.
        """
        if self._is_builtin:
            node = NodeMixin.get(self)
            node.module_type = type(self._mod)
            return self._mod
        else:
            node = NodeMixin.get(self)
            traced_module = TracedModule(node)
            for k, v in self.__dict__.items():
                if k not in TracedModuleBuilder.__builder_attributes__:
                    if isinstance(v, TracedModuleBuilder):
                        v = v.build()
                    setattr(traced_module, k, v)
                    # remember the type of each copied attribute on the node
                    traced_module.m_node.attr_type_map[k] = type(v)
            return traced_module

    def __call__(self, *args, **kwargs):
        """
        Record a call of the wrapped module in the current graph.

        A ``CallMethod`` expr whose inputs are the flattened ``(self, *args)``
        and ``kwargs`` is added to the enclosing graph.

        For a builtin module, the real module is run with module tracing
        disabled. Otherwise, ``forward`` is traced into a fresh
        ``InternalGraph`` scope.

        The resulting graph (``None`` for builtins) is stored in the node's
        ``argdef_graph_map`` under the call's argument tree definition.

        :return: the outputs produced by the module's forward
        """
        assert isinstance(self._mod, Module)

        # capture leaves that are not bound to any node yet as Constant exprs
        def mark_constant(x):
            node = NodeMixin.get(x, None)
            if node is None:  # capture as constant
                NodeMixin.wrap(x, lambda: Constant.make(x))

        inputs, tree_def = tree_flatten(
            ((self, *args), kwargs), leaf_type=_leaf_type, is_const_leaf=_is_const_leaf
        )
        for i in inputs:
            mark_constant(i)
        callnode = CallMethod.make(NodeMixin.get(self))

        callnode.add_inputs(inputs)

        callnode.arg_def = tree_def

        if self._is_builtin:
            # builtin: run the real module with module tracing switched off
            unset_module_tracing()
            outputs = self._mod(*args, **kwargs)
            set_module_tracing()
            if self._is_builtin:  # always true in this branch
                self._body = None
        else:
            # non-builtin: trace forward into a new sub-graph scope
            self._body = InternalGraph()
            active_module_tracer().push_scope(self._body)
            # rebind self to new input node
            orig_self = NodeMixin.get(self)
            NodeMixin.wrap_safe(
                self, Input.make("self", NodeMixin.get_wrapped_type(self))
            )

            # prepare args and kwargs for inner graph: each argument is a
            # shallow copy bound to a new Input node (presumably so the
            # caller's object keeps its outer-graph node)
            def wrap(x):
                wrapped = copy.copy(x)  # FIXME
                NodeMixin.wrap(
                    wrapped,
                    lambda: Input.make(type=NodeMixin.get_wrapped_type(wrapped)),
                )
                return wrapped

            args = [self]
            for i in inputs[1:]:
                args.append(wrap(i))
            args, kwargs = tree_def.unflatten(args)
            # hand the globals of the module's forward (if any) to the patcher
            active_module_tracer().patcher.auto_patch(
                getattr(getattr(self._mod, "forward", self._mod), "__globals__", {})
            )
            # call the class's forward with the builder standing in for self
            outputs = type(self._mod).forward(*args, **kwargs)

            # a non-Sequence result is registered as a single output
            for i in (
                outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,)
            ):
                active_module_tracer().current_scope().add_output(NodeMixin.get(i))

            # restore self's node from the outer graph and leave the scope
            NodeMixin.wrap_safe(self, orig_self)
            active_module_tracer().pop_scope()

        # rebind output to outer graph
        callnode.add_outputs(outputs)
        self_node = NodeMixin.get(self)
        self_node.argdef_graph_map[callnode.arg_def] = self._body
        return outputs

    def __getattr__(self, name):
        """
        Fallback lookup for names the builder instance does not have.

        A name that is not in the wrapped module's instance ``__dict__`` is
        looked up on the module's class and bound to the builder via
        ``__get__`` (e.g. methods).

        Otherwise, the value is fetched from the module. A submodule is
        wrapped in a new ``TracedModuleBuilder``. The result is then cached on
        the builder and associated with a ``GetAttr`` expr on this builder's
        node.
        """
        if name not in self._mod.__dict__:
            attr = getattr(type(self._mod), name).__get__(self, type(self))
        else:
            attr = getattr(self._mod, name)
            if isinstance(attr, Module):
                attr = TracedModuleBuilder(attr)
            setattr(self, name, attr)
            NodeMixin.wrap(
                attr,
                lambda: GetAttr.make(
                    NodeMixin.get(self), name, type=NodeMixin.get_wrapped_type(attr)
                ),
            )
        return attr

    def __getattribute__(self, name):
        """
        Attribute access hook.

        Builder-internal names (``__builder_attributes__``) are returned as-is.
        Any other attribute is looked up normally. If it is an instance
        attribute of the wrapped module and is not yet bound to a node, it is
        associated with a ``GetAttr`` expr before being returned.
        """
        if name in TracedModuleBuilder.__builder_attributes__:
            return super().__getattribute__(name)
        else:
            wrapped = super().__getattribute__(name)
            if name in self._mod.__dict__ and not NodeMixin.get(wrapped, None):
                # builtin modules are expected never to reach this path
                assert not self._is_builtin
                NodeMixin.wrap(
                    wrapped,
                    lambda: GetAttr.make(
                        NodeMixin.get(self),
                        name,
                        type=NodeMixin.get_wrapped_type(wrapped),
                    ),
                )
            return wrapped
  241. class _expr_list:
  242. def __init__(self, graph: InternalGraph):
  243. self.graph = graph
  244. def __iter__(self):
  245. for expr in self.graph._exprs:
  246. if isinstance(expr, CallMethod) and isinstance(expr.inputs[0], ModuleNode):
  247. yield expr
  248. if expr.graph is not None:
  249. yield from expr.graph.exprs
  250. else:
  251. yield expr
  252. class TracedModule(Module):
  253. """
  254. `TracedModule` is the Module created by tracing normal module. It owns a ModuleNode(m_node). `TracedModule` can not be called directly. It can be
  255. interpreted by CallMethod Expr.
  256. """
  257. m_node = None # type: ModuleNode
  258. def __init__(self, node):
  259. super(TracedModule, self).__init__()
  260. self.m_node = node
  261. def forward(self, *args, **kwargs):
  262. inputs, treedef = tree_flatten(
  263. ((self, *args), kwargs), _leaf_type, is_const_leaf=_is_const_leaf
  264. )
  265. assert treedef in self.m_node.argdef_graph_map
  266. inputs = [i for i in inputs if isinstance(i, (Module, RawTensor))]
  267. outputs = self.m_node.argdef_graph_map[treedef].interpret(*inputs)
  268. if len(outputs) == 1:
  269. return outputs[0]
  270. return outputs
  271. @property
  272. def graph(self):
  273. assert len(self.m_node.argdef_graph_map) == 1
  274. return list(self.m_node.argdef_graph_map.values())[0]
  275. @property
  276. def exprs(self):
  277. return self.graph.exprs
  278. def flatten(self):
  279. """
  280. Get a new module, which eliminates ``GetAttr`` and has no hierarchy.
  281. :return: :class:`TracedModule`
  282. """
  283. new_module = copy.deepcopy(self)
  284. def _flatten_subgraph(graph, module, call=None):
  285. if graph is None:
  286. assert not isinstance(module, TracedModule)
  287. const = Constant(module)
  288. modulenode = const.outputs[0]
  289. modulenode.module_type = type(module)
  290. call.inputs[0] = modulenode
  291. return [const, call]
  292. exprs = []
  293. for expr in graph._exprs:
  294. # replace inputs for submodule's expr
  295. for idx, inp in enumerate(expr.inputs):
  296. if call and inp in graph._inputs:
  297. inp_idx = graph._inputs.index(inp)
  298. expr.inputs[idx] = call.inputs[inp_idx]
  299. # replace outputs for submodule's expr
  300. for idx, outp in enumerate(expr.outputs):
  301. if call and outp in graph._outputs:
  302. oup_idx = graph._outputs.index(outp)
  303. expr.outputs[idx] = call.outputs[oup_idx]
  304. if isinstance(expr, GetAttr):
  305. # replace GetAttr with Constant
  306. if isinstance(expr.outputs[0], TensorNode):
  307. const = Constant(getattr(module, expr.name))
  308. const.outputs = expr.outputs
  309. exprs.append(const)
  310. elif isinstance(expr, CallMethod):
  311. obj_node = expr.inputs[0]
  312. if isinstance(obj_node, ModuleNode):
  313. assert isinstance(expr.inputs[0].expr, GetAttr)
  314. (obj,) = expr.inputs[0].expr.interpret(module)
  315. exprs.extend(_flatten_subgraph(expr.graph, obj, expr))
  316. else:
  317. exprs.append(expr)
  318. else:
  319. exprs.append(expr)
  320. return exprs
  321. new_module.graph._exprs = _flatten_subgraph(new_module.graph, new_module)
  322. return new_module
  323. def __getstate__(self):
  324. d = self.__dict__
  325. for k in Module.__dict__:
  326. d.pop(k, None)
  327. return d
  328. def cpp_apply_module_trace(opdef, *args):
  329. return Apply.apply_module_trace_hook(opdef, *args)
  330. def register_as_builtin(mod_cls: Type[Module]) -> None:
  331. """
  332. Registers class ``mod_cls`` (subclass of megengine.module.Module) as builtin module.
  333. param mod_cls: the Module class which will be threated as builtin module in tracing
  334. """
  335. module_tracer.register_as_builtin(mod_cls)
  336. def _register_all_builtin_module():
  337. for sub_mod in [M, M.qat, M.quantized]:
  338. for m in getmembers(sub_mod):
  339. if (
  340. isclass(m[1])
  341. and issubclass(m[1], M.Module)
  342. and m[1] is not M.Sequential
  343. ):
  344. module_tracer.register_as_builtin(m[1])
  345. def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
  346. """
  347. Traces module ``mod`` and returns corresponding TracedModule.
  348. param mod: the module will be converted to TracedModule
  349. param input: the positional arguments passed to forward method of ``mod``
  350. param kwargs: the keyword arguments passed to forward method of ``mod``
  351. """
  352. assert active_module_tracer() is None
  353. try:
  354. set_module_tracing()
  355. set_active_module_tracer(module_tracer(_wrapped_function))
  356. with active_module_tracer().patcher:
  357. global_scope = InternalGraph()
  358. active_module_tracer().push_scope(global_scope)
  359. builder = TracedModuleBuilder(mod, True)
  360. NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode))
  361. inputs, _ = tree_flatten((args, kwargs))
  362. for _, i in enumerate(inputs):
  363. NodeMixin.wrap_safe(
  364. i, Input.make("arg_{}".format(_), NodeMixin.get_wrapped_type(i))
  365. )
  366. builder(*args, **kwargs)
  367. active_module_tracer().pop_scope()
  368. return builder.build()
  369. finally:
  370. set_active_module_tracer(None)
  371. unset_module_tracing()

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。