You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

traced_module.py 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import collections
  10. import copy
  11. import functools
  12. from inspect import getmembers, isclass, ismethod
  13. from typing import List, Type
  14. from ... import module as M
  15. from ...core._imperative_rt.core2 import Tensor as RawTensor
  16. from ...core._imperative_rt.core2 import (
  17. is_tracing_module,
  18. set_module_tracing,
  19. unset_module_tracing,
  20. )
  21. from ...core.tensor.array_method import ArrayMethodMixin
  22. from ...module import Module
  23. from ...tensor import Tensor
  24. from .expr import Apply, CallFunction, CallMethod, Constant, Expr, GetAttr, Input
  25. from .module_tracer import (
  26. Patcher,
  27. active_module_tracer,
  28. module_tracer,
  29. set_active_module_tracer,
  30. )
  31. from .node import ModuleNode, Node, NodeMixin, TensorNode
  32. from .pytree import tree_flatten
  33. def _leaf_type(node):
  34. if isinstance(node, RawTensor):
  35. return (Tensor, TensorNode)
  36. elif isinstance(node, (NodeMixin, Module)):
  37. return (Module, ModuleNode, NodeMixin)
  38. else:
  39. return type(node)
  40. class InternalGraph:
  41. """
  42. ``InternalGraph`` is a graph consist of ``Node`` and ``Expr``, it is used to represent the execution procedure of Module's forward method.
  43. Attributes:
  44. _exprs: List of Exprs in order of execution
  45. _inputs: Input Nodes of InternalGraph
  46. _outputs: Output Nodes of InternalGraph
  47. """
  48. _exprs = None # type: List[Expr]
  49. _inputs = None # type: List[Node]
  50. _outputs = None # type: List[Node]
  51. def __init__(self):
  52. self._exprs = []
  53. self._inputs = []
  54. self._outputs = []
  55. def insert(self, expr):
  56. self._exprs.append(expr)
  57. def add_input(self, i):
  58. self._inputs.append(i)
  59. def add_output(self, o):
  60. self._outputs.append(o)
  61. def interpret(self, *inputs):
  62. # TODO: support kwargs ?
  63. # TODO: skip expressions which are independent and have no side effect
  64. node2value = {}
  65. for n, v in zip(self._inputs, inputs):
  66. node2value[n] = v
  67. for expr in self._exprs:
  68. values = expr.interpret(*list(node2value[i] for i in expr.inputs))
  69. for n, v in zip(expr.outputs, values):
  70. node2value[n] = v
  71. return list(node2value[i] for i in self._outputs)
  72. def __repr__(self):
  73. return "InternalGraph ({}) {{\n\t{}\n\treturn {}\n}}".format(
  74. ", ".join(str(i) for i in self._inputs),
  75. "\n\t".join(str(i) for i in self._exprs),
  76. ", ".join(str(i) for i in self._outputs),
  77. )
  78. def _get_meth_name(obj, func):
  79. for cls in type(obj).mro():
  80. for k, v in cls.__dict__.items():
  81. if v == func:
  82. return k
  83. return None
  84. def _wrapped_function(orig_func):
  85. @functools.wraps(orig_func)
  86. def wrapped_fn(*args, **kwargs):
  87. if is_tracing_module():
  88. unset_module_tracing()
  89. inputs, tree_def = tree_flatten((args, kwargs), leaf_type=_leaf_type)
  90. for i in inputs:
  91. if not NodeMixin.get(i, None):
  92. if isinstance(i, (RawTensor, NodeMixin)):
  93. NodeMixin.wrap_safe(i, Constant.make(i))
  94. meth_name = _get_meth_name(args[0], wrapped_fn)
  95. if meth_name:
  96. self = inputs[0]
  97. call_node = CallMethod.make(NodeMixin.get(self), meth_name)
  98. else:
  99. call_node = CallFunction.make(orig_func)
  100. call_node.add_inputs(inputs)
  101. call_node.arg_def = tree_def
  102. outputs = orig_func(*args, **kwargs)
  103. call_node.add_outputs(outputs)
  104. set_module_tracing()
  105. return outputs
  106. return orig_func(*args, **kwargs)
  107. return wrapped_fn
  108. class TracedModuleBuilder(NodeMixin):
  109. _mod = None # type: Module
  110. _body = None # type: InternalGraph
  111. _is_builtin = None # type: bool
  112. _arg_def = None # type: TreeDef
  113. __builder_attributes__ = [
  114. "_mod",
  115. "_body",
  116. "_NodeMixin__node",
  117. "_is_builtin",
  118. "_is_traced",
  119. "_arg_def" "build",
  120. ]
  121. def __init__(self, mod):
  122. super(TracedModuleBuilder, self).__init__()
  123. self._mod = mod
  124. self._body = InternalGraph()
  125. self._is_traced = False
  126. self._is_builtin = module_tracer.is_builtin(mod)
  127. def build(self):
  128. if self._is_builtin:
  129. node = NodeMixin.get(self)
  130. node.module_type = type(self._mod)
  131. return self._mod
  132. else:
  133. node = NodeMixin.get(self)
  134. node.graph = self._body
  135. node.attr_type_map = {}
  136. node.arg_def = self._arg_def
  137. traced_module = TracedModule(node)
  138. for k, v in self.__dict__.items():
  139. if k not in TracedModuleBuilder.__builder_attributes__:
  140. if isinstance(v, TracedModuleBuilder):
  141. v = v.build()
  142. setattr(traced_module, k, v)
  143. traced_module.m_node.attr_type_map[k] = type(v)
  144. return traced_module
  145. def __call__(self, *args, **kwargs):
  146. assert isinstance(self._mod, Module)
  147. for arg in args:
  148. assert isinstance(arg, RawTensor)
  149. for k, v in kwargs.items():
  150. assert isinstance(v, RawTensor)
  151. # prepare args and kwargs for inner graph
  152. def mark_constant(x):
  153. node = NodeMixin.get(x, None)
  154. if node is None: # capture as constant
  155. NodeMixin.wrap(x, lambda: Constant.make(x))
  156. inputs, tree_def = tree_flatten(((self, *args), kwargs), leaf_type=_leaf_type)
  157. if self._arg_def is None:
  158. self._arg_def = tree_def
  159. assert self._arg_def == tree_def
  160. for i in inputs:
  161. mark_constant(i)
  162. callnode = CallMethod.make(NodeMixin.get(self))
  163. callnode.add_inputs(inputs)
  164. callnode.arg_def = tree_def
  165. if self._is_builtin or self._is_traced:
  166. unset_module_tracing()
  167. outputs = self._mod(*args, **kwargs)
  168. set_module_tracing()
  169. if self._is_builtin:
  170. self._body = None
  171. else:
  172. active_module_tracer().push_scope(self._body)
  173. # rebind self to new input node
  174. orig_self = NodeMixin.get(self)
  175. NodeMixin.wrap_safe(
  176. self, Input.make("self", NodeMixin.get_wrapped_type(self))
  177. )
  178. # prepare args and kwargs for inner graph
  179. def wrap(x):
  180. wrapped = copy.copy(x) # FIXME
  181. NodeMixin.wrap(
  182. wrapped,
  183. lambda: Input.make(type=NodeMixin.get_wrapped_type(wrapped)),
  184. )
  185. return wrapped
  186. args = [self]
  187. for i in inputs[1:]:
  188. args.append(wrap(i))
  189. args, kwargs = tree_def.unflatten(args)
  190. active_module_tracer().patcher.auto_patch(
  191. getattr(getattr(self._mod, "forward", self._mod), "__globals__", {})
  192. )
  193. outputs = type(self._mod).forward(*args, **kwargs)
  194. for i in (
  195. outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,)
  196. ):
  197. active_module_tracer().current_scope().add_output(NodeMixin.get(i))
  198. NodeMixin.wrap_safe(self, orig_self)
  199. self._is_traced = True
  200. active_module_tracer().pop_scope()
  201. # rebind output to outer graph
  202. callnode.add_outputs(outputs)
  203. return outputs
  204. def __getattr__(self, name):
  205. if name not in self._mod.__dict__:
  206. attr = getattr(type(self._mod), name).__get__(self, type(self))
  207. else:
  208. attr = getattr(self._mod, name)
  209. if isinstance(attr, Module):
  210. attr = TracedModuleBuilder(attr)
  211. setattr(self, name, attr)
  212. NodeMixin.wrap(
  213. attr,
  214. lambda: GetAttr.make(
  215. NodeMixin.get(self), name, type=NodeMixin.get_wrapped_type(attr)
  216. ),
  217. )
  218. return attr
  219. def __getattribute__(self, name):
  220. if name in TracedModuleBuilder.__builder_attributes__:
  221. return super().__getattribute__(name)
  222. else:
  223. wrapped = super().__getattribute__(name)
  224. if name in self._mod.__dict__ and not NodeMixin.get(wrapped, None):
  225. assert not self._is_builtin
  226. NodeMixin.wrap(
  227. wrapped,
  228. lambda: GetAttr.make(
  229. NodeMixin.get(self),
  230. name,
  231. type=NodeMixin.get_wrapped_type(wrapped),
  232. ),
  233. )
  234. return wrapped
  235. class TracedModule(Module):
  236. """
  237. `TracedModule` is the Module created by tracing normal module. It owns a ModuleNode(m_node), and will interpret the m_node.graph when it is called.
  238. """
  239. m_node = None # type: ModuleNode
  240. def __init__(self, node):
  241. super(TracedModule, self).__init__()
  242. self.m_node = node
  243. def forward(self, *args, **kwargs):
  244. inputs, treedef = tree_flatten(((self, *args), kwargs), leaf_type=_leaf_type)
  245. assert treedef == self.m_node.arg_def
  246. rst = self.m_node.graph.interpret(*inputs)
  247. if len(rst) == 1:
  248. rst = rst[0]
  249. return rst
  250. @property
  251. def all_exprs(self):
  252. """
  253. Visit all ``Expr``s in the graph recursively.
  254. :return: List[Expr]
  255. """
  256. in_nodes = [i.expr for i in self.m_node.graph._inputs if not i is self]
  257. def _flatten_submodule(module, call=None):
  258. if not isinstance(module, TracedModule):
  259. call.inputs[0] = module
  260. return (call,)
  261. exprs = []
  262. graph = module.m_node.graph
  263. for expr in graph._exprs:
  264. # replace inputs for submodule's expr
  265. for idx, inp in enumerate(expr.inputs):
  266. if call and inp in graph._inputs:
  267. expr.inputs[idx] = call.inputs[idx]
  268. # replace outputs for submodule's expr
  269. for idx, outp in enumerate(expr.outputs):
  270. if call and outp in graph._outputs:
  271. expr.outputs[idx] = call.outputs[idx]
  272. if isinstance(expr, GetAttr):
  273. # replace GetAttr with Constant
  274. if isinstance(expr.outputs[0], TensorNode):
  275. const = Constant(getattr(module, expr.name))
  276. const.outputs = expr.outputs
  277. exprs.append(const)
  278. elif isinstance(expr, CallMethod):
  279. obj_node = expr.inputs[0]
  280. if isinstance(obj_node, ModuleNode):
  281. (obj,) = expr.inputs[0].expr.interpret(module)
  282. exprs.extend(_flatten_submodule(obj, expr))
  283. else:
  284. exprs.append(expr)
  285. else:
  286. exprs.append(expr)
  287. return exprs
  288. return in_nodes + _flatten_submodule(self)
  289. def __getstate__(self):
  290. d = self.__dict__
  291. for k in Module.__dict__:
  292. d.pop(k, None)
  293. return d
  294. def cpp_apply_module_trace(opdef, *args):
  295. return Apply.apply_module_trace_hook(opdef, *args)
  296. def register_as_builtin(mod_cls: Type[Module]) -> None:
  297. """
  298. Registers class ``mod_cls`` (subclass of megengine.module.Module) as builtin module.
  299. param mod_cls: the Module class which will be threated as builtin module in tracing
  300. """
  301. module_tracer.register_as_builtin(mod_cls)
  302. def _register_all_builtin_module():
  303. for sub_mod in [M, M.qat, M.quantized]:
  304. for m in getmembers(sub_mod):
  305. if (
  306. isclass(m[1])
  307. and issubclass(m[1], M.Module)
  308. and m[1] is not M.Sequential
  309. ):
  310. module_tracer.register_as_builtin(m[1])
  311. def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
  312. """
  313. Traces module ``mod`` and returns corresponding TracedModule.
  314. param mod: the module will be converted to TracedModule
  315. param input: the positional arguments passed to forward method of ``mod``
  316. param kwargs: the keyword arguments passed to forward method of ``mod``
  317. """
  318. assert active_module_tracer() is None
  319. try:
  320. set_module_tracing()
  321. set_active_module_tracer(module_tracer(_wrapped_function))
  322. with active_module_tracer().patcher:
  323. global_scope = InternalGraph()
  324. active_module_tracer().push_scope(global_scope)
  325. builder = TracedModuleBuilder(mod)
  326. NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode))
  327. inputs, _ = tree_flatten((args, kwargs))
  328. for _, i in enumerate(inputs):
  329. NodeMixin.wrap_safe(
  330. i, Input.make("arg_{}".format(_), NodeMixin.get_wrapped_type(i))
  331. )
  332. builder(*args, **kwargs)
  333. active_module_tracer().pop_scope()
  334. return builder.build()
  335. finally:
  336. set_active_module_tracer(None)
  337. unset_module_tracing()

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台