You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

traced_module.py 13 kB

(line-number gutter: lines 1–390)
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import collections
  10. import copy
  11. import functools
  12. from typing import List, Type
  13. from ... import module as M
  14. from ...core._imperative_rt.core2 import (
  15. is_tracing_module,
  16. set_module_tracing,
  17. unset_module_tracing,
  18. )
  19. from ...core.tensor.array_method import ArrayMethodMixin
  20. from ...module import Module
  21. from ...tensor import Tensor
  22. from .expr import Apply, CallFunction, CallMethod, Constant, Expr, GetAttr, Input
  23. from .module_tracer import (
  24. Patcher,
  25. active_module_tracer,
  26. module_tracer,
  27. set_active_module_tracer,
  28. )
  29. from .node import ModuleNode, Node, NodeMixin, TensorNode
  30. class InternalGraph:
  31. """
  32. ``InternalGraph`` is a graph consist of ``Node`` and ``Expr``, it is used to represent the execution procedure of Module's forward method.
  33. Attributes:
  34. _exprs: List of Exprs in order of execution
  35. _inputs: Input Nodes of InternalGraph
  36. _outputs: Output Nodes of InternalGraph
  37. """
  38. _exprs = None # type: List[Expr]
  39. _inputs = None # type: List[Node]
  40. _outputs = None # type: List[Node]
  41. def __init__(self):
  42. self._exprs = []
  43. self._inputs = []
  44. self._outputs = []
  45. def insert(self, expr):
  46. self._exprs.append(expr)
  47. def add_input(self, i):
  48. self._inputs.append(i)
  49. def add_output(self, o):
  50. self._outputs.append(o)
  51. def interpret(self, *inputs):
  52. # TODO: support kwargs ?
  53. # TODO: skip expressions which are independent and have no side effect
  54. node2value = {}
  55. for n, v in zip(self._inputs, inputs):
  56. node2value[n] = v
  57. for expr in self._exprs:
  58. values = expr.interpret(
  59. *list(Expr.get_arg_value(i, node2value) for i in expr.inputs)
  60. )
  61. for n, v in zip(expr.outputs, values):
  62. node2value[n] = v
  63. return list(node2value[i] for i in self._outputs)
  64. def __repr__(self):
  65. return "InternalGraph ({}) {{\n\t{}\n\treturn {}\n}}".format(
  66. ", ".join(str(i) for i in self._inputs),
  67. "\n\t".join(str(i) for i in self._exprs),
  68. ", ".join(str(i) for i in self._outputs),
  69. )
  70. def _wrapped_function(orig_func):
  71. @functools.wraps(orig_func)
  72. def wrapped_fn(*inputs, **kwargs):
  73. if is_tracing_module():
  74. unset_module_tracing()
  75. const_kwargs = {}
  76. arg_names = orig_func.__code__.co_varnames
  77. if orig_func.__qualname__.split(".").__len__() > 1:
  78. # FIXME: a robust way to distinguish method and function. <XP>
  79. self = inputs[0]
  80. call_node = CallMethod.make(NodeMixin.get(self), orig_func.__name__)
  81. else:
  82. call_node = CallFunction.make(orig_func)
  83. def add_input(inp, varname=None):
  84. node = Expr.get_args_node(inp)
  85. if node is not None:
  86. call_node.add_input(node, varname)
  87. else:
  88. const_kwargs[varname] = inp
  89. for ind, inp in enumerate(inputs):
  90. add_input(inp, arg_names[ind])
  91. for k, v in kwargs.items():
  92. add_input(v, k)
  93. call_node.kwargs = const_kwargs
  94. outputs = orig_func(*inputs, **kwargs)
  95. call_node.add_outputs(outputs)
  96. set_module_tracing()
  97. return outputs
  98. return orig_func(*inputs, **kwargs)
  99. return wrapped_fn
  100. class TracedModuleBuilder(NodeMixin):
  101. _mod = None # type: Module
  102. _body = None # type: InternalGraph
  103. _is_builtin = None # type: bool
  104. __builder_attributes__ = [
  105. "_mod",
  106. "_body",
  107. "_NodeMixin__node",
  108. "_is_builtin",
  109. "_is_traced",
  110. "build",
  111. ]
  112. def __init__(self, mod):
  113. super(TracedModuleBuilder, self).__init__()
  114. self._mod = mod
  115. self._body = InternalGraph()
  116. self._is_traced = False
  117. self._is_builtin = module_tracer.is_builtin(mod)
  118. def build(self):
  119. if self._is_builtin:
  120. node = NodeMixin.get(self)
  121. node.module_type = type(self._mod)
  122. return self._mod
  123. else:
  124. node = NodeMixin.get(self)
  125. node.graph = self._body
  126. node.attr_type_map = {}
  127. traced_module = TracedModule(node)
  128. for k, v in self.__dict__.items():
  129. if k not in TracedModuleBuilder.__builder_attributes__:
  130. if isinstance(v, TracedModuleBuilder):
  131. v = v.build()
  132. setattr(traced_module, k, v)
  133. traced_module.m_node.attr_type_map[k] = type(v)
  134. return traced_module
  135. def __call__(self, *inputs, **kwargs):
  136. assert isinstance(self._mod, Module)
  137. # prepare args and kwargs for inner graph
  138. def mark_constant(x):
  139. node = NodeMixin.get(x, None)
  140. if node is None: # capture as constant
  141. NodeMixin.wrap(x, lambda: Constant.make(x))
  142. for i in inputs:
  143. mark_constant(i)
  144. for k, v in kwargs.items():
  145. mark_constant(v)
  146. callnode = CallMethod.make(NodeMixin.get(self))
  147. def add_input(x):
  148. callnode.add_input(NodeMixin.get(x))
  149. for i in inputs:
  150. add_input(i)
  151. for k, v in kwargs.items():
  152. add_input(v)
  153. if self._is_builtin or self._is_traced:
  154. unset_module_tracing()
  155. outputs = self._mod(*inputs, **kwargs)
  156. set_module_tracing()
  157. if self._is_builtin:
  158. self._body = None
  159. else:
  160. active_module_tracer().push_scope(self._body)
  161. # rebind self to new input node
  162. orig_self = NodeMixin.get(self)
  163. NodeMixin.wrap_safe(
  164. self, Input.make("self", NodeMixin.get_wrapped_type(self))
  165. )
  166. # prepare args and kwargs for inner graph
  167. def wrap(x):
  168. # wrapped = copy.copy(x) # FIXME
  169. wrapped = x # FIXME: <XP>
  170. NodeMixin.wrap(
  171. wrapped,
  172. lambda: Input.make(type=NodeMixin.get_wrapped_type(wrapped)),
  173. )
  174. return wrapped
  175. args = []
  176. for i in inputs:
  177. args.append(wrap(i))
  178. for k, v in kwargs.items():
  179. kwargs[k] = wrap(v)
  180. active_module_tracer().patcher.auto_patch(
  181. getattr(getattr(self._mod, "forward", self._mod), "__globals__", {})
  182. )
  183. outputs = type(self._mod).forward(self, *args, **kwargs)
  184. for i in (
  185. outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,)
  186. ):
  187. active_module_tracer().current_scope().add_output(NodeMixin.get(i))
  188. NodeMixin.wrap_safe(self, orig_self)
  189. self._is_traced = True
  190. active_module_tracer().pop_scope()
  191. # rebind output to outer graph
  192. callnode.add_outputs(outputs)
  193. return outputs
  194. def __getattr__(self, name):
  195. if name not in self._mod.__dict__:
  196. attr = getattr(type(self._mod), name).__get__(self, type(self))
  197. else:
  198. attr = getattr(self._mod, name)
  199. if isinstance(attr, Module):
  200. attr = TracedModuleBuilder(attr)
  201. setattr(self, name, attr)
  202. NodeMixin.wrap(
  203. attr,
  204. lambda: GetAttr.make(
  205. NodeMixin.get(self), name, type=NodeMixin.get_wrapped_type(attr)
  206. ),
  207. )
  208. return attr
  209. def __getattribute__(self, name):
  210. if name in TracedModuleBuilder.__builder_attributes__:
  211. return super().__getattribute__(name)
  212. else:
  213. wrapped = super().__getattribute__(name)
  214. if name in self._mod.__dict__ and not NodeMixin.get(wrapped, None):
  215. assert not self._is_builtin
  216. NodeMixin.wrap(
  217. wrapped,
  218. lambda: GetAttr.make(
  219. NodeMixin.get(self),
  220. name,
  221. type=NodeMixin.get_wrapped_type(wrapped),
  222. ),
  223. )
  224. return wrapped
  225. class TracedModule(Module):
  226. """
  227. `TracedModule` is the Module created by tracing normal module. It owns a ModuleNode(m_node), and will interpret the m_node.graph when it is called.
  228. """
  229. m_node = None # type: ModuleNode
  230. def __init__(self, node):
  231. super(TracedModule, self).__init__()
  232. self.m_node = node
  233. def forward(self, *inputs):
  234. rst = self.m_node.graph.interpret(self, *inputs)
  235. if len(rst) == 1:
  236. rst = rst[0]
  237. return rst
  238. @property
  239. def all_exprs(self):
  240. """
  241. Visit all ``Expr``s in the graph recursively.
  242. :return: List[Expr]
  243. """
  244. in_nodes = [i.expr for i in self.m_node.graph._inputs if not i is self]
  245. def _flatten_submodule(module, call=None):
  246. if not isinstance(module, TracedModule):
  247. call.inputs[0] = module
  248. return (call,)
  249. exprs = []
  250. graph = module.m_node.graph
  251. for expr in graph._exprs:
  252. # replace inputs for submodule's expr
  253. for idx, inp in enumerate(expr.inputs):
  254. if call and inp in graph._inputs:
  255. expr.inputs[idx] = call.inputs[idx]
  256. # replace outputs for submodule's expr
  257. for idx, outp in enumerate(expr.outputs):
  258. if call and outp in graph._outputs:
  259. expr.outputs[idx] = call.outputs[idx]
  260. if isinstance(expr, GetAttr):
  261. # replace GetAttr with Constant
  262. if isinstance(expr.outputs[0], TensorNode):
  263. const = Constant(getattr(module, expr.name))
  264. const.outputs = expr.outputs
  265. exprs.append(const)
  266. elif isinstance(expr, CallMethod):
  267. obj_node = expr.inputs[0]
  268. if isinstance(obj_node, ModuleNode):
  269. (obj,) = expr.inputs[0].expr.interpret(module)
  270. exprs.extend(_flatten_submodule(obj, expr))
  271. else:
  272. exprs.append(expr)
  273. else:
  274. exprs.append(expr)
  275. return exprs
  276. return in_nodes + _flatten_submodule(self)
  277. def __getstate__(self):
  278. d = self.__dict__
  279. for k in Module.__dict__:
  280. d.pop(k, None)
  281. return d
  282. def cpp_apply_module_trace(opdef, *args):
  283. return Apply.apply_module_trace_hook(opdef, *args)
  284. def register_as_builtin(mod_cls: Type[Module]) -> None:
  285. """
  286. Registers class ``mod_cls`` (subclass of megengine.module.Module) as builtin module.
  287. param mod_cls: the Module class which will be threated as builtin module in tracing
  288. """
  289. module_tracer.register_as_builtin(mod_cls)
  290. def _register_all_builtin_module():
  291. from inspect import getmembers, isclass
  292. for sub_mod in [M, M.qat, M.quantized]:
  293. for m in getmembers(sub_mod):
  294. if (
  295. isclass(m[1])
  296. and issubclass(m[1], M.Module)
  297. and m[1] is not M.Sequential
  298. ):
  299. module_tracer.register_as_builtin(m[1])
  300. def trace_module(mod: Module, *inputs: Tensor, **kwargs: Tensor) -> TracedModule:
  301. """
  302. Traces module ``mod`` and returns corresponding TracedModule.
  303. param mod: the module will be converted to TracedModule
  304. param input: the positional arguments passed to forward method of ``mod``
  305. param kwargs: the keyword arguments passed to forward method of ``mod``
  306. """
  307. assert active_module_tracer() is None
  308. try:
  309. set_module_tracing()
  310. set_active_module_tracer(module_tracer(_wrapped_function))
  311. with active_module_tracer().patcher:
  312. global_scope = InternalGraph()
  313. active_module_tracer().push_scope(global_scope)
  314. builder = TracedModuleBuilder(mod)
  315. NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode))
  316. for _, i in enumerate(inputs):
  317. NodeMixin.wrap_safe(i, Input.make("arg_{}".format(_)))
  318. for k, v in kwargs.items():
  319. NodeMixin.wrap_safe(v, Input.make("kwarg_{}".format(k)))
  320. builder(*inputs, **kwargs)
  321. active_module_tracer().pop_scope()
  322. return builder.build()
  323. finally:
  324. set_active_module_tracer(None)
  325. unset_module_tracing()

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。