You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

traced_module.py 15 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import collections
  10. import copy
  11. import functools
  12. from inspect import getmembers, isclass, ismethod
  13. from typing import Dict, List, Type
  14. from ... import module as M
  15. from ...core._imperative_rt.core2 import Tensor as RawTensor
  16. from ...core._imperative_rt.core2 import (
  17. is_tracing_module,
  18. set_module_tracing,
  19. unset_module_tracing,
  20. )
  21. from ...core.tensor.array_method import ArrayMethodMixin
  22. from ...module import Module
  23. from ...tensor import Tensor
  24. from .expr import Apply, CallFunction, CallMethod, Constant, Expr, GetAttr, Input
  25. from .module_tracer import (
  26. Patcher,
  27. active_module_tracer,
  28. module_tracer,
  29. set_active_module_tracer,
  30. )
  31. from .node import ModuleNode, Node, NodeMixin, TensorNode
  32. from .pytree import tree_flatten
  33. def _leaf_type(node):
  34. if isinstance(node, RawTensor):
  35. return (Tensor, TensorNode)
  36. elif isinstance(node, (NodeMixin, Module)):
  37. return (Module, ModuleNode, NodeMixin)
  38. else:
  39. return type(node)
  40. class InternalGraph:
  41. """
  42. ``InternalGraph`` is a graph consist of ``Node`` and ``Expr``, it is used to represent the execution procedure of Module's forward method.
  43. Attributes:
  44. _exprs: List of Exprs in order of execution
  45. _inputs: Input Nodes of InternalGraph
  46. _outputs: Output Nodes of InternalGraph
  47. """
  48. _exprs = None # type: List[Expr]
  49. _inputs = None # type: List[Node]
  50. _outputs = None # type: List[Node]
  51. def __init__(self):
  52. self._exprs = []
  53. self._inputs = []
  54. self._outputs = []
  55. def insert(self, expr):
  56. self._exprs.append(expr)
  57. @property
  58. def inputs(self):
  59. return self._inputs
  60. @property
  61. def outputs(self):
  62. return self._outputs
  63. def add_input(self, i):
  64. self._inputs.append(i)
  65. def add_output(self, o):
  66. self._outputs.append(o)
  67. def interpret(self, *inputs):
  68. # TODO: support kwargs ?
  69. # TODO: skip expressions which are independent and have no side effect
  70. node2value = {}
  71. for n, v in zip(self._inputs, inputs):
  72. node2value[n] = v
  73. for expr in self._exprs:
  74. values = expr.interpret(*list(node2value[i] for i in expr.inputs))
  75. for n, v in zip(expr.outputs, values):
  76. node2value[n] = v
  77. return list(node2value[i] for i in self._outputs)
  78. def __repr__(self):
  79. return "InternalGraph ({}) {{\n\t{}\n\treturn {}\n}}".format(
  80. ", ".join(str(i) for i in self._inputs),
  81. "\n\t".join(str(i) for i in self._exprs),
  82. ", ".join(str(i) for i in self._outputs),
  83. )
  84. def _get_meth_name(obj, func):
  85. for cls in type(obj).mro():
  86. for k, v in cls.__dict__.items():
  87. if v == func:
  88. return k
  89. return None
  90. def _wrapped_function(orig_func):
  91. @functools.wraps(orig_func)
  92. def wrapped_fn(*args, **kwargs):
  93. if is_tracing_module():
  94. unset_module_tracing()
  95. inputs, tree_def = tree_flatten((args, kwargs), leaf_type=_leaf_type)
  96. for i in inputs:
  97. if not NodeMixin.get(i, None):
  98. if isinstance(i, (RawTensor, NodeMixin)):
  99. NodeMixin.wrap_safe(i, Constant.make(i))
  100. meth_name = _get_meth_name(args[0], wrapped_fn)
  101. if meth_name:
  102. self = inputs[0]
  103. call_node = CallMethod.make(NodeMixin.get(self), meth_name)
  104. else:
  105. call_node = CallFunction.make(orig_func)
  106. call_node.add_inputs(inputs)
  107. call_node.arg_def = tree_def
  108. outputs = orig_func(*args, **kwargs)
  109. call_node.add_outputs(outputs)
  110. set_module_tracing()
  111. return outputs
  112. return orig_func(*args, **kwargs)
  113. return wrapped_fn
  114. class TracedModuleBuilder(NodeMixin):
  115. _mod = None # type: Module
  116. _body = None # type: InternalGraph
  117. _is_builtin = None # type: bool
  118. _arg_def = None # type: TreeDef
  119. __builder_attributes__ = [
  120. "_mod",
  121. "_body",
  122. "_NodeMixin__node",
  123. "_is_builtin",
  124. "_is_traced",
  125. "_arg_def" "build",
  126. ]
  127. def __init__(self, mod):
  128. super(TracedModuleBuilder, self).__init__()
  129. self._mod = mod
  130. self._body = InternalGraph()
  131. self._is_traced = False
  132. self._is_builtin = module_tracer.is_builtin(mod)
  133. def build(self):
  134. if self._is_builtin:
  135. node = NodeMixin.get(self)
  136. node.module_type = type(self._mod)
  137. return self._mod
  138. else:
  139. node = NodeMixin.get(self)
  140. node.graph = self._body
  141. node.attr_type_map = {}
  142. node.arg_def = self._arg_def
  143. traced_module = TracedModule(node)
  144. for k, v in self.__dict__.items():
  145. if k not in TracedModuleBuilder.__builder_attributes__:
  146. if isinstance(v, TracedModuleBuilder):
  147. v = v.build()
  148. setattr(traced_module, k, v)
  149. traced_module.m_node.attr_type_map[k] = type(v)
  150. return traced_module
  151. def __call__(self, *args, **kwargs):
  152. assert isinstance(self._mod, Module)
  153. for arg in args:
  154. assert isinstance(arg, RawTensor)
  155. for k, v in kwargs.items():
  156. assert isinstance(v, RawTensor)
  157. # prepare args and kwargs for inner graph
  158. def mark_constant(x):
  159. node = NodeMixin.get(x, None)
  160. if node is None: # capture as constant
  161. NodeMixin.wrap(x, lambda: Constant.make(x))
  162. inputs, tree_def = tree_flatten(((self, *args), kwargs), leaf_type=_leaf_type)
  163. if self._arg_def is None:
  164. self._arg_def = tree_def
  165. assert self._arg_def == tree_def
  166. for i in inputs:
  167. mark_constant(i)
  168. callnode = CallMethod.make(NodeMixin.get(self))
  169. callnode.add_inputs(inputs)
  170. callnode.arg_def = tree_def
  171. if self._is_builtin or self._is_traced:
  172. unset_module_tracing()
  173. outputs = self._mod(*args, **kwargs)
  174. set_module_tracing()
  175. if self._is_builtin:
  176. self._body = None
  177. else:
  178. active_module_tracer().push_scope(self._body)
  179. # rebind self to new input node
  180. orig_self = NodeMixin.get(self)
  181. NodeMixin.wrap_safe(
  182. self, Input.make("self", NodeMixin.get_wrapped_type(self))
  183. )
  184. # prepare args and kwargs for inner graph
  185. def wrap(x):
  186. wrapped = copy.copy(x) # FIXME
  187. NodeMixin.wrap(
  188. wrapped,
  189. lambda: Input.make(type=NodeMixin.get_wrapped_type(wrapped)),
  190. )
  191. return wrapped
  192. args = [self]
  193. for i in inputs[1:]:
  194. args.append(wrap(i))
  195. args, kwargs = tree_def.unflatten(args)
  196. active_module_tracer().patcher.auto_patch(
  197. getattr(getattr(self._mod, "forward", self._mod), "__globals__", {})
  198. )
  199. outputs = type(self._mod).forward(*args, **kwargs)
  200. for i in (
  201. outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,)
  202. ):
  203. active_module_tracer().current_scope().add_output(NodeMixin.get(i))
  204. NodeMixin.wrap_safe(self, orig_self)
  205. self._is_traced = True
  206. active_module_tracer().pop_scope()
  207. # rebind output to outer graph
  208. callnode.add_outputs(outputs)
  209. return outputs
  210. def __getattr__(self, name):
  211. if name not in self._mod.__dict__:
  212. attr = getattr(type(self._mod), name).__get__(self, type(self))
  213. else:
  214. attr = getattr(self._mod, name)
  215. if isinstance(attr, Module):
  216. attr = TracedModuleBuilder(attr)
  217. setattr(self, name, attr)
  218. NodeMixin.wrap(
  219. attr,
  220. lambda: GetAttr.make(
  221. NodeMixin.get(self), name, type=NodeMixin.get_wrapped_type(attr)
  222. ),
  223. )
  224. return attr
  225. def __getattribute__(self, name):
  226. if name in TracedModuleBuilder.__builder_attributes__:
  227. return super().__getattribute__(name)
  228. else:
  229. wrapped = super().__getattribute__(name)
  230. if name in self._mod.__dict__ and not NodeMixin.get(wrapped, None):
  231. assert not self._is_builtin
  232. NodeMixin.wrap(
  233. wrapped,
  234. lambda: GetAttr.make(
  235. NodeMixin.get(self),
  236. name,
  237. type=NodeMixin.get_wrapped_type(wrapped),
  238. ),
  239. )
  240. return wrapped
  241. class _expr_list:
  242. def __init__(self, module: "TracedModule"):
  243. self.module = module
  244. def __iter__(self):
  245. graph = self.module.m_node.graph
  246. for expr in graph._exprs:
  247. if isinstance(expr, CallMethod) and isinstance(expr.inputs[0], ModuleNode):
  248. yield expr
  249. assert isinstance(expr.inputs[0].expr, GetAttr)
  250. (obj,) = expr.inputs[0].expr.interpret(self.module)
  251. if isinstance(obj, TracedModule):
  252. yield from obj.exprs
  253. yield expr
  254. class TracedModule(Module):
  255. """
  256. `TracedModule` is the Module created by tracing normal module. It owns a ModuleNode(m_node), and will interpret the m_node.graph when it is called.
  257. """
  258. m_node = None # type: ModuleNode
  259. def __init__(self, node):
  260. super(TracedModule, self).__init__()
  261. self.m_node = node
  262. def forward(self, *args, **kwargs):
  263. inputs, treedef = tree_flatten(((self, *args), kwargs), leaf_type=_leaf_type)
  264. assert treedef == self.m_node.arg_def
  265. rst = self.m_node.graph.interpret(*inputs)
  266. if len(rst) == 1:
  267. rst = rst[0]
  268. return rst
  269. @property
  270. def exprs(self):
  271. """
  272. Get all ``Expr`` s recursively.
  273. :return: Iterator[Expr]
  274. """
  275. return _expr_list(self)
  276. def flatten(self):
  277. """
  278. Get a new module, which eliminates ``GetAttr`` and has no hierarchy.
  279. :return: :class:`TracedModule`
  280. """
  281. new_module = copy.deepcopy(self)
  282. def _flatten_submodule(module, call=None):
  283. if not isinstance(module, TracedModule):
  284. call.inputs[0] = module
  285. return (call,)
  286. exprs = []
  287. graph = module.m_node.graph
  288. for expr in graph._exprs:
  289. # replace inputs for submodule's expr
  290. for idx, inp in enumerate(expr.inputs):
  291. if call and inp in graph._inputs:
  292. expr.inputs[idx] = call.inputs[idx]
  293. # replace outputs for submodule's expr
  294. for idx, outp in enumerate(expr.outputs):
  295. if call and outp in graph._outputs:
  296. expr.outputs[idx] = call.outputs[idx]
  297. if isinstance(expr, GetAttr):
  298. # replace GetAttr with Constant
  299. if isinstance(expr.outputs[0], TensorNode):
  300. const = Constant(getattr(module, expr.name))
  301. const.outputs = expr.outputs
  302. exprs.append(const)
  303. elif isinstance(expr, CallMethod):
  304. obj_node = expr.inputs[0]
  305. if isinstance(obj_node, ModuleNode):
  306. assert isinstance(expr.inputs[0].expr, GetAttr)
  307. (obj,) = expr.inputs[0].expr.interpret(module)
  308. exprs.extend(_flatten_submodule(obj, expr))
  309. else:
  310. exprs.append(expr)
  311. else:
  312. exprs.append(expr)
  313. return exprs
  314. new_module.m_node.graph._exprs = _flatten_submodule(new_module)
  315. return new_module
  316. def __getstate__(self):
  317. d = self.__dict__
  318. for k in Module.__dict__:
  319. d.pop(k, None)
  320. return d
  321. def cpp_apply_module_trace(opdef, *args):
  322. return Apply.apply_module_trace_hook(opdef, *args)
  323. def register_as_builtin(mod_cls: Type[Module]) -> None:
  324. """
  325. Registers class ``mod_cls`` (subclass of megengine.module.Module) as builtin module.
  326. param mod_cls: the Module class which will be threated as builtin module in tracing
  327. """
  328. module_tracer.register_as_builtin(mod_cls)
  329. def _register_all_builtin_module():
  330. for sub_mod in [M, M.qat, M.quantized]:
  331. for m in getmembers(sub_mod):
  332. if (
  333. isclass(m[1])
  334. and issubclass(m[1], M.Module)
  335. and m[1] is not M.Sequential
  336. ):
  337. module_tracer.register_as_builtin(m[1])
  338. def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
  339. """
  340. Traces module ``mod`` and returns corresponding TracedModule.
  341. param mod: the module will be converted to TracedModule
  342. param input: the positional arguments passed to forward method of ``mod``
  343. param kwargs: the keyword arguments passed to forward method of ``mod``
  344. """
  345. assert active_module_tracer() is None
  346. try:
  347. set_module_tracing()
  348. set_active_module_tracer(module_tracer(_wrapped_function))
  349. with active_module_tracer().patcher:
  350. global_scope = InternalGraph()
  351. active_module_tracer().push_scope(global_scope)
  352. builder = TracedModuleBuilder(mod)
  353. NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode))
  354. inputs, _ = tree_flatten((args, kwargs))
  355. for _, i in enumerate(inputs):
  356. NodeMixin.wrap_safe(
  357. i, Input.make("arg_{}".format(_), NodeMixin.get_wrapped_type(i))
  358. )
  359. builder(*args, **kwargs)
  360. active_module_tracer().pop_scope()
  361. return builder.build()
  362. finally:
  363. set_active_module_tracer(None)
  364. unset_module_tracing()

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台