You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

traced_module.py 43 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import collections
  10. import copy
  11. import functools
  12. import inspect
  13. import weakref
  14. from inspect import getmembers, isclass, ismethod
  15. from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, Union
  16. from ... import functional as F
  17. from ... import get_logger
  18. from ... import module as M
  19. from ...core._imperative_rt.core2 import Tensor as RawTensor
  20. from ...core._imperative_rt.core2 import (
  21. is_tracing_module,
  22. set_module_tracing,
  23. unset_module_tracing,
  24. )
  25. from ...core._trace_option import set_symbolic_shape
  26. from ...core.tensor.array_method import ArrayMethodMixin
  27. from ...module import Module
  28. from ...quantization.fake_quant import LSQ, TQT, FakeQuantize
  29. from ...quantization.observer import (
  30. ExponentialMovingAverageObserver,
  31. MinMaxObserver,
  32. SyncMinMaxObserver,
  33. )
  34. from ...tensor import Tensor
  35. from .expr import Apply, CallFunction, CallMethod, Constant, Expr, GetAttr, Input
  36. from .module_tracer import (
  37. Patcher,
  38. active_module_tracer,
  39. module_tracer,
  40. set_active_module_tracer,
  41. )
  42. from .node import ModuleNode, Node, NodeMixin, TensorNode
  43. from .pytree import tree_flatten
# Module-level logger used for tracing diagnostics in this file.
logger = get_logger(__name__)
  45. def _is_leaf(node):
  46. assert isinstance(node, RawTensor), "doesn't support {} in return values".format(
  47. type(node)
  48. )
  49. return isinstance(node, RawTensor)
  50. def wrap_tensors(tensors: Tensor, nodes: TensorNode):
  51. inp_tensors = copy.deepcopy(tensors)
  52. inp_tensors, inp_def_v = tree_flatten(inp_tensors)
  53. inp_nodes, inp_def_n = tree_flatten(nodes)
  54. for v, n in zip(inp_tensors, inp_nodes):
  55. if isinstance(n, TensorNode) and isinstance(v, Tensor):
  56. NodeMixin.wrap_safe(v, n)
  57. return inp_def_v.unflatten(inp_tensors)
  58. class _InsertExprs:
  59. def __init__(self, graph, expr: Optional[Expr] = None, after: bool = True):
  60. self.graph = graph
  61. self.global_scope = InternalGraph()
  62. self.expr = expr
  63. self.after = after
  64. def __enter__(self):
  65. self.use_sym_shape = set_symbolic_shape(True)
  66. set_module_tracing()
  67. assert active_module_tracer() is None
  68. set_active_module_tracer(module_tracer(_wrapped_function))
  69. active_module_tracer().patcher.__enter__()
  70. active_module_tracer().push_scope(self.global_scope)
  71. def __exit__(self, ty, va, tr):
  72. set_symbolic_shape(self.use_sym_shape)
  73. unset_module_tracing()
  74. active_module_tracer().patcher.__exit__(ty, va, tr)
  75. set_active_module_tracer(None)
  76. index = len(self.graph._exprs) if self.after else 0
  77. if self.expr is not None:
  78. index = self.graph._exprs.index(self.expr)
  79. if self.after:
  80. index += 1
  81. for expr in self.global_scope._exprs:
  82. self.graph._exprs.insert(index, expr)
  83. index += 1
class InternalGraph:
    """
    ``InternalGraph`` is a graph consist of ``Node`` and ``Expr``, it is used to
    represent the execution procedure of Module's forward method.

    Attributes:
        _exprs: List of Exprs in order of execution
        _inputs: Input Nodes of InternalGraph
        _outputs: Output Nodes of InternalGraph
    """

    _exprs = None  # type: List[Expr]
    _inputs = None  # type: List[Node]
    _outputs = None  # type: List[Node]

    def __init__(self):
        self._exprs = []
        self._inputs = []
        self._outputs = []
        # Nodes whose runtime values should be recorded during interpret().
        self._watch_point = []
        # Nodes at which interpret() may stop early and return.
        self._end_point = []
        # Recorded watch-point values, keyed by node.
        self._rst = collections.defaultdict(list)

    def insert(self, expr):
        """Append *expr* at the end of the execution order."""
        self._exprs.append(expr)

    @property
    def inputs(self):
        return self._inputs

    @property
    def outputs(self):
        return self._outputs

    @property
    def expr_filter(self):
        """Recursive iterator/filter over all exprs (including sub-graphs)."""
        return ExprFilter(_expr_iter(self))

    @property
    def node_filter(self):
        """Iterator/filter over all distinct nodes touched by this graph."""
        return NodeFilter(_node_iter(self))

    def get_function_by_type(self, func: Callable = None):
        return self.expr_filter.call_function(func)

    def get_method_by_type(self, method: str = None):
        return self.expr_filter.call_method(method)

    def get_expr_by_id(self, expr_id: List[int] = None):
        return self.expr_filter.expr_id(expr_id)

    def get_module_by_type(self, module_cls: Module):
        assert issubclass(module_cls, Module)
        return self.node_filter.type(module_cls, ModuleNode)

    def get_node_by_id(self, node_id: List[int] = None):
        return self.node_filter.node_id(node_id)

    def add_input(self, i):
        self._inputs.append(i)

    def add_output(self, o):
        self._outputs.append(o)

    def _replace_inputs_outputs(self, repl_dict):
        """Replace boundary (input/output) nodes everywhere they appear,
        merging user lists and rebinding replaced expr outputs."""
        for node, repl_node in repl_dict.items():
            assert node in self._inputs or node in self._outputs
            for i in node.users:
                if i not in repl_node.users:
                    repl_node.users.append(i)
        for idx, i in enumerate(self._inputs):
            if i in repl_dict:
                self._inputs[idx] = repl_dict[i]
        for idx, o in enumerate(self._outputs):
            if o in repl_dict:
                self._outputs[idx] = repl_dict[o]
        for expr in self._exprs:
            for idx, i in enumerate(expr.inputs):
                if i in repl_dict:
                    expr.inputs[idx] = repl_dict[i]
            for idx, o in enumerate(expr.outputs):
                if o in repl_dict:
                    expr.outputs[idx] = repl_dict[o]
                    # A replaced output must point back at its producing expr.
                    expr.outputs[idx].expr = expr

    def get_dep_exprs(self, nodes: Sequence[Node]) -> List[Expr]:
        """Return every expr that (transitively) produces *nodes*."""
        if not isinstance(nodes, Sequence):
            nodes = (nodes,)
        ret = list()
        queue = list(nodes)
        visited_queue = list()
        while queue:
            node = queue.pop()
            visited_queue.append(node)
            expr = node.expr
            if expr not in ret:
                ret.append(expr)
            for i in expr.inputs:
                if i not in queue and i not in visited_queue:
                    queue.append(i)
        return ret

    def reset_inputs(self, *args, **kwargs):
        """Rebuild the graph's formal inputs from concrete (*args, **kwargs)
        tensors, updating every ``__call__`` site of the owning module."""
        forma_mnode = self.inputs[0]
        actual_mnodes = forma_mnode.actual_mnode
        call_nodes = []
        for n in actual_mnodes:
            for c_expr in n.users:
                if isinstance(c_expr, CallMethod) and c_expr.method == "__call__":
                    call_nodes.append((c_expr, n))
        moudle = forma_mnode.owner
        assert moudle._is_top, "reset_inputs only support the top-level graph"
        inputs, tree_def = tree_flatten(((moudle, *args), kwargs))

        def create_node(val: Tensor):
            # Fresh graph input node mirroring the sample tensor's metadata.
            node = Input(type=TensorNode).outputs[0]
            node.shape = val.shape
            node.dtype = val.dtype
            return node

        formal_node_inputs = [
            forma_mnode,
        ]
        org_argdef = list(moudle.argdef_graph_map.keys())[0]
        if call_nodes:
            org_argdef = call_nodes[0][0].arg_def
        for v in inputs[1:]:
            assert isinstance(v, RawTensor)
            formal_node_inputs.append(create_node(v))
        actual_nodes = []
        for e, n in call_nodes:
            e.arg_def = tree_def
            actual_node_inputs = [
                n,
            ]
            for v in inputs[1:]:
                actual_node_inputs.append(create_node(v))
            for org_n in e.inputs:
                # NOTE(review): if ``users`` is a plain list, ``pop`` expects an
                # integer index — ``remove(e)`` may be intended here; confirm
                # the container type of ``Node.users``.
                org_n.users.pop(e)
            e.inputs[:] = actual_node_inputs
            e.const_val = []
            actual_nodes.append(actual_node_inputs[1:])
        self._inputs[:] = formal_node_inputs
        # Re-key both maps from the old argdef to the new one.
        moudle.argdef_graph_map[tree_def] = moudle.argdef_graph_map.pop(org_argdef)
        moudle.argdef_outdef_map[tree_def] = moudle.argdef_outdef_map.pop(org_argdef)
        # return formal_node_inputs[1:], actual_nodes
        return formal_node_inputs[1:]

    def add_input_node(self, shape, dtype="float32"):
        """Append one extra tensor input (of given shape/dtype) to the graph
        and to every ``__call__`` site; returns the new formal input node."""
        forma_mnode = self.inputs[0]
        actual_mnodes = forma_mnode.actual_mnode
        moudle = forma_mnode.owner
        assert moudle._is_top, "add_input_node only support the top-level graph"
        call_nodes = []
        for n in actual_mnodes:
            for c_expr in n.users:
                if isinstance(c_expr, CallMethod) and c_expr.method == "__call__":
                    call_nodes.append(c_expr)

        def create_node(is_input: bool = True):
            # Formal inputs become Input exprs; call-site placeholders are
            # bare TensorNodes with no producing expr.
            if is_input:
                node = Input(type=TensorNode).outputs[0]
            else:
                node = TensorNode(expr=None)
            node.shape = shape
            node.dtype = dtype
            return node

        org_argdef = list(moudle.argdef_graph_map.keys())[0]
        if call_nodes:
            org_argdef = call_nodes[0].arg_def
        args, kwargs = org_argdef.unflatten(self._inputs)
        formal_inp_node = create_node(True)
        inputs, tree_def = tree_flatten(
            ((*args, formal_inp_node), kwargs),
            is_const_leaf=lambda x: not isinstance(x, (TensorNode, ModuleNode)),
        )
        self._inputs[:] = inputs[:]
        actual_inp_nodes = []
        for e in call_nodes:
            args, kwargs = e.unflatten_args(e.inputs)
            args = args + (create_node(False),)
            inputs, tree_def = tree_flatten(
                (args, kwargs),
                is_const_leaf=lambda x: not isinstance(x, (TensorNode, ModuleNode)),
            )
            e.inputs[:] = inputs[:]
            e.arg_def = tree_def
            actual_inp_nodes.append(args[-1])
        moudle.argdef_graph_map[tree_def] = moudle.argdef_graph_map.pop(org_argdef)
        moudle.argdef_outdef_map[tree_def] = moudle.argdef_outdef_map.pop(org_argdef)
        # return formal_inp_node, actual_inp_nodes
        return formal_inp_node

    def reset_outputs(self, outputs):
        """Replace the graph's outputs with *outputs* (a pytree of
        TensorNodes), mirroring the change at every ``__call__`` site."""
        outputs, out_def = tree_flatten(
            outputs, is_leaf=lambda x: isinstance(x, TensorNode),
        )
        forma_mnode = self.inputs[0]
        moudle = forma_mnode.owner
        assert moudle._is_top, "reset_outputs only support the top-level graph"
        actual_mnodes = forma_mnode.actual_mnode
        call_nodes = []
        for n in actual_mnodes:
            for c_expr in n.users:
                if isinstance(c_expr, CallMethod) and c_expr.method == "__call__":
                    call_nodes.append((c_expr))

        def create_node(val: TensorNode, expr: Expr):
            node = TensorNode(expr)
            node.shape = val.shape
            node.dtype = val.dtype
            return node

        tree_def = list(moudle.argdef_graph_map.keys())[0]
        if call_nodes:
            tree_def = call_nodes[0].arg_def
        actual_nodes = []
        for e in call_nodes:
            actual_node_outputs = []
            for v in outputs:
                actual_node_outputs.append(create_node(v, e))
            e.outputs[:] = actual_node_outputs
            e.out_def = out_def
            actual_nodes.append(actual_node_outputs)
        self._outputs[:] = outputs
        moudle.argdef_outdef_map[tree_def] = out_def
        return actual_nodes

    def add_output_node(self, node: TensorNode):
        """Append *node* as an additional graph output (and at every
        ``__call__`` site); returns the per-call-site output nodes."""
        forma_mnode = self.inputs[0]
        moudle = forma_mnode.owner
        assert moudle._is_top, "add_output_node only support the top-level graph"
        actual_mnodes = forma_mnode.actual_mnode
        call_nodes = []
        for n in actual_mnodes:
            for c_expr in n.users:
                if isinstance(c_expr, CallMethod) and c_expr.method == "__call__":
                    call_nodes.append((c_expr))

        def create_node(val: TensorNode, expr: Expr):
            node = TensorNode(expr)
            node.shape = val.shape
            node.dtype = val.dtype
            return node

        tree_def = list(moudle.argdef_graph_map.keys())[0]
        if call_nodes:
            tree_def = call_nodes[0].arg_def
        org_out_def = moudle.argdef_outdef_map[tree_def]
        org_outs = org_out_def.unflatten(self._outputs)
        outputs, out_def = tree_flatten(
            (org_outs, node), is_leaf=lambda x: isinstance(x, TensorNode),
        )
        self._outputs[:] = outputs
        actual_out_nodes = []
        for e in call_nodes:
            actual_node = create_node(node, e)
            org_outs = org_out_def.unflatten(e.outputs)
            outputs, out_def = tree_flatten(
                (org_outs, actual_node), is_leaf=lambda x: isinstance(x, TensorNode),
            )
            e.outputs[:] = outputs
            e.out_def = out_def
            actual_out_nodes.append(actual_node)
        moudle.argdef_outdef_map[tree_def] = out_def
        return actual_out_nodes

    def insert_function(self, func: Callable, *args, **kwargs):
        """Trace a call of *func* on the given node arguments and insert the
        resulting exprs after the last expr that produces one of them."""
        assert isinstance(func, Callable)
        inp_nodes, inp_def = tree_flatten((args, kwargs))
        insert_idx = -1
        for i in inp_nodes:
            if isinstance(i, TensorNode) and i.expr in self._exprs:
                insert_idx = max(insert_idx, self._exprs.index(i.expr))
        # Stand-in zero tensors let *func* run for real during tracing.
        fake_inp_val = list(
            F.zeros(shape=i.shape, dtype=i.dtype) if isinstance(i, TensorNode) else i
            for i in inp_nodes
        )
        for v, n in zip(fake_inp_val, inp_nodes):
            if isinstance(n, TensorNode):
                NodeMixin.wrap_safe(v, n)
        fake_args, fake_kwargs = inp_def.unflatten(fake_inp_val)
        insert_point = self.insert_exprs_before()
        if insert_idx != -1:
            insert_point = self.insert_exprs_after(self._exprs[insert_idx])
        with insert_point:
            rst = func(*fake_args, **fake_kwargs)
        if rst is None:
            return None
        outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
        node_outputs = []
        for out in outputs:
            assert isinstance(out, RawTensor)
            node_outputs.append(NodeMixin.get(out, None))
        node_outputs = out_def.unflatten(node_outputs)
        return node_outputs

    def insert_exprs_after(self, expr: Optional[Expr] = None):
        if expr is not None:
            assert expr.top_graph == self, "Expr to insert after is not in graph."
        return _InsertExprs(self, expr, after=True)

    def insert_exprs_before(self, expr: Optional[Expr] = None):
        if expr is not None:
            assert expr.top_graph == self, "Expr to insert before is not in graph."
        return _InsertExprs(self, expr, after=False)

    def replace_node(self, repl_dict: Dict[Node, Node]):
        """Redirect every user of each key node to the mapped value node.

        Users that the replacement node (transitively) depends on are skipped
        to avoid creating a cycle.
        """
        while repl_dict:
            node, repl_node = repl_dict.popitem()
            # check graph inputs and outputs
            assert node not in self.inputs, "Cannot replace inputs"
            for i, n in enumerate(self.outputs):
                if n is node:
                    self.outputs[i] = repl_node
            # update users of node and repl_node
            # update inputs of expr in node.users
            dep_exprs = self.get_dep_exprs(repl_node)
            i = 0
            while i < len(node.users):
                n = node.users[i]
                if n in dep_exprs:
                    logger.info("Find a loop: ignore this replacement once")
                    logger.info("node: %s" % node.__repr__())
                    logger.info("repl_node: %s" % repl_node.__repr__())
                    i += 1
                    continue
                repl_node.users.append(n)
                node.users.pop(i)
                idx = n.inputs.index(node)
                n.inputs[idx] = repl_node

    def compile(self):
        """
        Delete unused expr.
        """
        dep_exprs = self.get_dep_exprs(self.outputs)
        i = 0
        while i < len(self._exprs):
            expr = self._exprs[i]
            if expr in dep_exprs or expr._disable_remove:
                i += 1
                continue
            for n in expr.inputs:
                n.users.remove(expr)
            self._exprs.remove(expr)

    def interpret(self, *inputs):
        """Execute the recorded exprs on concrete *inputs*.

        Records values at watch points and returns early once every end
        point has produced a value.
        """
        node2value = {}
        end_nodes_set = set(self._end_point)
        endnode2value = {}

        def get_all_endnode_val(n, v):
            # True once the last outstanding end point has been filled.
            if n in end_nodes_set:
                endnode2value[n] = v
                end_nodes_set.remove(n)
                return not end_nodes_set
            return False

        for n, v in zip(self._inputs, inputs):
            node2value[n] = v
            if n in self._watch_point:
                self._rst[n].append(v)
            if n in self._end_point and get_all_endnode_val(n, v):
                return list(endnode2value[i] for i in self._end_point)
        for expr in self._exprs:
            values = expr.interpret(*list(node2value[i] for i in expr.inputs))
            if values is not None:
                for n, v in zip(expr.outputs, values):
                    node2value[n] = v
                    if n in self._watch_point:
                        # NOTE(review): this overwrites the list with a single
                        # value, unlike the ``append`` used for inputs above —
                        # confirm whether ``self._rst[n].append(v)`` was meant.
                        self._rst[n] = v
                    if self._end_point and get_all_endnode_val(n, v):
                        return list(endnode2value[i] for i in self._end_point)
        return list(node2value[i] for i in self._outputs)

    def eval(self, *inputs):
        """Interpret the graph with the owning module prepended as ``self``."""
        assert len(inputs) == len(self._inputs) - 1
        inp = [self._inputs[0].owner] + list(inputs)
        return self.interpret(*inp)

    def __repr__(self):
        return "InternalGraph ({}) {{\n\t{}\n\treturn {}\n}}".format(
            ", ".join(str(i) for i in self._inputs),
            "\n\t".join("{}".format(str(i)) for i in self._exprs),
            ", ".join(str(i) for i in self._outputs),
        )
  432. def _get_meth_name(obj, func):
  433. tp = obj if isinstance(obj, type) else type(obj)
  434. for cls in tp.mro():
  435. for k, v in cls.__dict__.items():
  436. if v == func:
  437. return k
  438. return None
def _wrapped_function(orig_func):
    """Wrap *orig_func* so that, while module tracing is active, each call is
    recorded as a CallMethod/CallFunction expr in the current scope before
    actually executing the original function."""

    @functools.wraps(orig_func)
    def wrapped_fn(*args, **kwargs):
        if is_tracing_module():
            # Leave tracing mode so the bookkeeping below is not itself traced.
            unset_module_tracing()
            inputs, tree_def = tree_flatten((args, kwargs))
            for i in inputs:
                # Inputs with no node yet are captured as constants.
                if not NodeMixin.get(i, None):
                    if isinstance(i, (RawTensor, NodeMixin)):
                        NodeMixin.wrap_safe(i, Constant.make(i))
            # Non-None meth_name means orig_func is a method of args[0].
            meth_name = _get_meth_name(args[0], wrapped_fn) if args else None
            if meth_name:
                self = inputs[0]
                if meth_name == "__new__":
                    if all([not isinstance(i, RawTensor) for i in inputs]):
                        # only trace Tensor.__new__() when there are tensors in args
                        set_module_tracing()
                        return orig_func(*args, **kwargs)
                    if isinstance(args[1], RawTensor):
                        node = NodeMixin.get(inputs[1])
                        inputs[1] = copy.copy(inputs[1])
                        # copy inputs[1] to avoid tensor and Tensor(tensor) share same m_tensor,
                        # which will cause they have same _NodeMixin__node in tracing.
                        NodeMixin.wrap_safe(inputs[1], node)
                        args, kwargs = tree_def.unflatten(inputs)
                    call_node = CallMethod.make(self, meth_name)
                else:
                    call_node = CallMethod.make(NodeMixin.get(self), meth_name)
                call_node.add_inputs(inputs[1:])
            else:
                call_node = CallFunction.make(orig_func)
                call_node.add_inputs(inputs)
            call_node.arg_def = tree_def
            rst = orig_func(*args, **kwargs)
            if meth_name == "__setitem__":
                # __setitem__ returns None; record the mutated object instead.
                rst = self
            if rst is not None:
                outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
                call_node.out_def = out_def
            else:
                outputs = None
            call_node.add_outputs(outputs)
            set_module_tracing()
            return rst
        # Not tracing: behave exactly like the original function.
        return orig_func(*args, **kwargs)

    return wrapped_fn
class TracedModuleBuilder(NodeMixin):
    """Stand-in for a ``Module`` during tracing.

    It impersonates the wrapped module (its ``__class__`` is rewritten to
    inherit from the module's class), records every ``__call__`` as an
    ``InternalGraph``, and is later converted into a ``TracedModule`` via
    :meth:`build`.
    """

    _mod = None  # type: Module
    _body = None  # type: InternalGraph
    _is_builtin = None  # type: bool
    _argdef_graph_map = None  # type: Dict[Treedef, "InternalGraph"]
    _argdef_outdef_map = None  # type: Dict[Treedef, Treedef]
    nodes = None
    # Attribute names served directly by the builder itself (bypassing the
    # wrapped module) in __getattribute__.
    __builder_attributes__ = [
        "_mod",
        "_body",
        "_NodeMixin__node",
        "_is_builtin",
        "build",
        "_record_wrapped_nodes",
        "_argdef_graph_map",
        "_argdef_outdef_map",
        "nodes",
        "__class__",
        "__dict__",
    ]

    def __init__(self, mod, is_top_module=False):
        super(TracedModuleBuilder, self).__init__()
        assert isinstance(mod, Module)
        self._mod = mod
        self._body = None
        self._is_top = is_top_module
        self._is_builtin = module_tracer.is_builtin(mod)
        self._argdef_graph_map = {}
        self._argdef_outdef_map = {}
        self.nodes = set()
        # The builder will be passed to self._mod.forward as 'self' argument.
        # If the 'forward' uses super().xxx to call method of its base classes,
        # the trace procedure will throw exception, because the builder doesn't
        # inherit from self._mod.__bases__.
        # modify self.__class__ and let the builder inherit from
        # TracedModuleBuilder and mod.__class__.
        self.__class__ = type(
            "TracedModuleBuilder",
            (TracedModuleBuilder, mod.__class__),
            dict(TracedModuleBuilder.__dict__),
        )

    def build(self):
        """Turn the recorded graphs into a ``TracedModule`` (or return the
        original module unchanged for builtin module types)."""
        if self._is_builtin:
            for node in self.nodes:
                node.module_type = type(self._mod)
                # node._owner = weakref.ref(self._mod)
            return self._mod
        else:
            traced_module = TracedModule(
                self._is_top, self._argdef_graph_map, self._argdef_outdef_map
            )
            for _, g in self._argdef_graph_map.items():
                g.compile()
            # Copy over user attributes, recursively building sub-builders.
            for k, v in self.__dict__.items():
                if k not in TracedModuleBuilder.__builder_attributes__:
                    if isinstance(v, TracedModuleBuilder):
                        v = v.build()
                    setattr(traced_module, k, v)
            return traced_module

    def _record_wrapped_nodes(self, node):
        self.nodes.add(node)

    def __call__(self, *args, **kwargs):
        """Trace one invocation of the wrapped module, recording a CallMethod
        expr in the outer scope and (for non-builtin modules) a new
        ``InternalGraph`` for the body."""
        assert isinstance(self._mod, Module)

        # prepare args and kwargs for inner graph
        def mark_constant(x):
            node = NodeMixin.get(x, None)
            if node is None:  # capture as constant
                NodeMixin.wrap(x, lambda: Constant.make(x))

        inputs, tree_def = tree_flatten(((self, *args), kwargs))
        for i in inputs:
            mark_constant(i)
        callnode = CallMethod.make(NodeMixin.get(self))
        callnode.add_inputs(inputs[1:])
        callnode.arg_def = tree_def
        if self._is_builtin:
            # Builtin modules run opaquely: no inner graph is traced.
            unset_module_tracing()
            rst = self._mod(*args, **kwargs)
            outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
            set_module_tracing()
        if self._is_builtin:
            self._body = None
        else:
            self_node = None
            # Reuse the existing 'self' input node when this argdef was
            # traced before.
            if tree_def in self._argdef_graph_map:
                self_node = self._argdef_graph_map[tree_def].inputs[0]
            self._body = InternalGraph()
            active_module_tracer().push_scope(self._body)
            # rebind self to new input node
            orig_self = NodeMixin.get(self)
            if self_node:
                NodeMixin.wrap_safe(self, self_node)
                active_module_tracer().current_scope().add_input(self_node)
            else:
                NodeMixin.wrap_safe(
                    self,
                    self_node
                    if self_node
                    else Input.make("self", NodeMixin.get_wrapped_type(self)),
                )
            origin_inp_node = [NodeMixin.get(i, None) for i in inputs[1:]]

            # prepare args and kwargs for inner graph
            def wrap(x):
                if isinstance(x, (RawTensor, NodeMixin)):
                    NodeMixin.wrap(
                        x, lambda: Input.make(type=NodeMixin.get_wrapped_type(x)),
                    )
                return x

            args = [self]
            for i in inputs[1:]:
                args.append(wrap(i))
            args, kwargs = tree_def.unflatten(args)
            active_module_tracer().patcher.auto_patch(
                getattr(getattr(self._mod, "forward", self._mod), "__globals__", {})
            )
            rst = type(self._mod).forward(*args, **kwargs)
            outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
            for i in (
                outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,)
            ):
                active_module_tracer().current_scope().add_output(NodeMixin.get(i))
            NodeMixin.get(self, None).actual_mnode.append(orig_self)
            # Restore the original node bindings of self and the inputs.
            NodeMixin.wrap_safe(self, orig_self)
            for arg, node in zip(inputs[1:], origin_inp_node):
                if node:
                    NodeMixin.wrap_safe(arg, node)
            active_module_tracer().pop_scope()
        # rebind output to outer graph
        callnode.out_def = out_def
        callnode.add_outputs(outputs)
        self._argdef_graph_map[callnode.arg_def] = self._body
        self._argdef_outdef_map[callnode.arg_def] = out_def
        return rst

    def __setattr__(self, name, value):
        object.__setattr__(self, name, value)

    def __repr__(self):
        return repr(self._mod)

    def __getattr__(self, name):
        # Only reached when normal lookup fails on the builder itself.
        if name not in self._mod.__dict__:
            attr = getattr(type(self._mod), name).__get__(self, type(self))
        else:
            attr = getattr(self._mod, name)
            if isinstance(attr, Module):
                attr = TracedModuleBuilder(attr)
            setattr(self, name, attr)
            # Record attribute access as a GetAttr expr.
            NodeMixin.wrap(
                attr,
                lambda: GetAttr.make(
                    NodeMixin.get(self), name, type=NodeMixin.get_wrapped_type(attr)
                ),
            )
        return attr

    def __getattribute__(self, name):
        if name in TracedModuleBuilder.__builder_attributes__:
            return object.__getattribute__(self, name)
        else:
            wrapped = object.__getattribute__(self, name)
            if name in self._mod.__dict__:
                mod_attr = getattr(self._mod, name)
                # Keep the builder's cached copy in sync with the module's
                # current (non-Module) attribute value.
                if not isinstance(mod_attr, Module) and wrapped is not mod_attr:
                    wrapped = mod_attr
                    setattr(self, name, wrapped)
                if isinstance(mod_attr, Module):
                    assert mod_attr is wrapped._mod
                else:
                    assert mod_attr is wrapped
                # assert not self._is_builtin
                if isinstance(wrapped, (NodeMixin, RawTensor)):
                    NodeMixin.wrap(
                        wrapped,
                        lambda: GetAttr.make(
                            NodeMixin.get(self),
                            name,
                            type=NodeMixin.get_wrapped_type(wrapped),
                        ),
                    )
            return wrapped
  656. class _expr_iter:
  657. def __init__(self, graph: InternalGraph):
  658. self.graph = graph
  659. def __iter__(self):
  660. for expr in self.graph._exprs:
  661. if isinstance(expr, CallMethod) and isinstance(expr.inputs[0], ModuleNode):
  662. yield expr
  663. if expr.graph is not None:
  664. yield from expr.graph.expr_filter
  665. else:
  666. yield expr
  667. class _node_iter:
  668. def __init__(self, graph: InternalGraph) -> None:
  669. nodes = []
  670. node_ids = set()
  671. for expr in graph.expr_filter:
  672. for n in expr.inputs + expr.outputs:
  673. if n._id in node_ids:
  674. continue
  675. nodes.append(n)
  676. node_ids.add(n._id)
  677. self.nodes = list(sorted(nodes, key=lambda x: x._id))
  678. def __iter__(self):
  679. for node in self.nodes:
  680. yield node
  681. class BaseFilter:
  682. def __init__(self, expr_iter: Iterable):
  683. self._iter = expr_iter
  684. def __iter__(self):
  685. return iter(self._iter)
  686. def as_list(self):
  687. return list(self)
  688. def as_dict(self):
  689. return collections.OrderedDict((i._id, i) for i in self)
  690. def as_unique(self):
  691. rst = self.as_list()
  692. assert len(rst) == 1, "{} elements found".format(len(rst))
  693. (expr,) = self
  694. return expr
  695. def as_count(self):
  696. return sum(1 for _ in self)
  697. class ExprFilter(BaseFilter):
  698. def call_function(self, func):
  699. return ExprFilterCallFunction(self, func)
  700. def call_method(self, method):
  701. return ExprFilterCallMethod(self, method)
  702. def expr_id(self, expr_id: List[int]):
  703. return ExprFilterExprId(self, expr_id)
  704. class NodeFilter(BaseFilter):
  705. def type(self, owner_type, node_type):
  706. return NodeFilterType(self, owner_type, node_type)
  707. def node_id(self, node_id: List[int]):
  708. return NodeFilterNodeId(self, node_id)
  709. class NodeFilterType(NodeFilter):
  710. def __init__(self, expr_iter, owner_type, node_type):
  711. super().__init__(expr_iter)
  712. self.owner_type = owner_type
  713. self.node_type = node_type
  714. def __iter__(self):
  715. for node in self._iter:
  716. if not isinstance(node, self.node_type):
  717. continue
  718. if not hasattr(node, "owner"):
  719. continue
  720. if isinstance(node.owner, self.owner_type):
  721. yield node
  722. class NodeFilterNodeId(NodeFilter):
  723. def __init__(self, expr_iter, node_id: List[int]):
  724. super().__init__(expr_iter)
  725. if not isinstance(node_id, Sequence):
  726. node_id = [node_id]
  727. self.node_id = node_id
  728. def __iter__(self):
  729. for node in self._iter:
  730. if node._id in self.node_id:
  731. yield node
  732. class ExprFilterCallFunction(ExprFilter):
  733. def __init__(self, expr_iter, func: Callable = None):
  734. super().__init__(expr_iter)
  735. self.func = func
  736. def __iter__(self):
  737. for expr in self._iter:
  738. if not isinstance(expr, CallFunction):
  739. continue
  740. if self.func is None or expr.func == self.func:
  741. yield expr
  742. class ExprFilterCallMethod(ExprFilter):
  743. def __init__(self, expr_iter, method: str = None):
  744. super().__init__(expr_iter)
  745. self.method = method
  746. def __iter__(self):
  747. for expr in self._iter:
  748. if not isinstance(expr, CallMethod):
  749. continue
  750. if self.method is None or expr.method == self.method:
  751. yield expr
  752. class ExprFilterExprId(ExprFilter):
  753. def __init__(self, expr_iter, expr_id: List[int]):
  754. super().__init__(expr_iter)
  755. if not isinstance(expr_id, Sequence):
  756. expr_id = [expr_id]
  757. self.expr_id = expr_id
  758. def __iter__(self):
  759. for expr in self._iter:
  760. if expr._id in self.expr_id:
  761. yield expr
class TracedModule(Module):
    """
    ``TracedModule`` is the Module created by tracing a normal module.

    It owns an argdef-to-graph (:class:`InternalGraph`) map. The ``forward``
    method of ``TracedModule`` looks up a graph in ``argdef_graph_map``
    according to the argdef of the input args/kwargs and interprets it.
    """

    # m_node = None # type: ModuleNode
    # Class-level defaults; both are populated per instance in __init__.
    argdef_graph_map = None
    argdef_outdef_map = None

    def __init__(self, is_top, argdef_graph_map, argdef_outdef_map):
        super(TracedModule, self).__init__()
        self.argdef_graph_map = argdef_graph_map
        self.argdef_outdef_map = argdef_outdef_map
        # True for the root traced module (affects `graph` property refresh).
        self._is_top = is_top
        # Nodes whose runtime values should be captured during forward.
        self.watch_points = []
        self.watch_node_value = {}
        # Nodes at which interpretation should stop early.
        self.end_points = []

    def forward(self, *args, **kwargs):
        # Flatten (self, args, kwargs) to a leaf list plus its tree layout;
        # the layout (treedef) keys both the graph and outdef maps.
        inputs, treedef = tree_flatten(((self, *args), kwargs))
        assert treedef in self.argdef_graph_map
        inputs = filter(
            lambda i: isinstance(i, (Module, TracedModuleBuilder, RawTensor)), inputs
        )  # allow TracedModuleBuilder for retrace.
        outputs = self.argdef_graph_map[treedef].interpret(*inputs)
        if self.watch_points:
            # Harvest the watched nodes' values from the graph's result cache.
            self.watch_node_value = {}
            for n in self.watch_points:
                self.watch_node_value[n] = n.top_graph._rst.pop(n)
        if self.end_points:
            # Early-exit mode: return the raw (flat) interpreter outputs.
            return outputs
        # Restore the caller-visible output structure recorded at trace time.
        out_def = self.argdef_outdef_map[treedef]
        outputs = out_def.unflatten(outputs)
        return outputs

    def set_watch_points(self, nodes):
        """Register nodes whose values will be captured on the next forward."""
        if not isinstance(nodes, Sequence):
            nodes = [nodes]
        self.watch_points = nodes
        for n in nodes:
            n.top_graph._watch_point.append(n)

    def clear_watch_points(self):
        """Remove all watch points and drop captured values."""
        for n in self.watch_points:
            n.top_graph._watch_point = []
        self.watch_points = []
        self.watch_node_value = {}

    def set_end_points(self, nodes):
        """Register nodes at which graph interpretation stops early.

        Each node must belong to one of this module's graphs.
        """
        if not isinstance(nodes, Sequence):
            nodes = [nodes]
        self.end_points = nodes
        graphs = list(self.argdef_graph_map.values())
        for n in nodes:
            assert n.top_graph in graphs
            n.top_graph._end_point.append(n)

    def clear_end_points(self):
        """Remove all end points."""
        for n in self.end_points:
            n.top_graph._end_point = []
        self.end_points = []

    @property
    def graph(self) -> InternalGraph:
        """The module's single InternalGraph (only valid when exactly one
        argdef was traced)."""
        if self._is_top:
            # Refresh weakrefs/top-graph links before handing the graph out.
            self._update_ref()
        assert len(self.argdef_graph_map) == 1
        return list(self.argdef_graph_map.values())[0]

    def _update_ref(self, actual_node_map: Union[Dict] = None):
        # Re-establish the weakref back-links (node -> graph, module-node ->
        # module object) that deepcopy/pickle cannot preserve, recursing into
        # traced submodules.
        for inp_def, graph in self.argdef_graph_map.items():
            for n in graph._inputs + graph.outputs:
                n._top_graph = weakref.ref(graph)
            # inputs[0] is the module node for `self` by construction.
            graph._inputs[0]._owner = weakref.ref(self)
            graph._inputs[0].actual_mnode = []
            if actual_node_map is not None and inp_def in actual_node_map.keys():
                graph._inputs[0].actual_mnode = actual_node_map[inp_def]
            node2obj = {}
            next_actual_node_map = collections.defaultdict(
                lambda: collections.defaultdict(list)
            )
            node2obj[graph._inputs[0]] = self
            for expr in graph._exprs:
                for n in expr.inputs + expr.outputs:
                    n._top_graph = weakref.ref(graph)
                expr._top_graph = weakref.ref(graph)
                # GetAttr on a module: resolve the attribute to its object so
                # the produced ModuleNode can point back to it.
                if isinstance(expr, GetAttr) and isinstance(
                    expr.outputs[0], ModuleNode
                ):
                    obj = getattr(node2obj[expr.inputs[0]], expr.name)
                    expr.outputs[0]._owner = weakref.ref(obj)
                    node2obj[expr.outputs[0]] = obj
                if isinstance(expr, Constant) and isinstance(
                    expr.outputs[0], ModuleNode
                ):
                    obj = expr.value
                    expr.outputs[0]._owner = weakref.ref(obj)
                    node2obj[expr.outputs[0]] = obj
                # A submodule __call__: remember which module node fed it so
                # the recursive _update_ref can attach actual_mnode info.
                if (
                    isinstance(expr, CallMethod)
                    and expr.method == "__call__"
                    and isinstance(expr.inputs[0], ModuleNode)
                ):
                    obj = node2obj[expr.inputs[0]]
                    if expr.arg_def is not None:
                        next_actual_node_map[obj][expr.arg_def].append(expr.inputs[0])
            # Recurse into traced submodules with their collected node maps.
            for obj in node2obj.values():
                if obj is self:
                    continue
                mnode_map = None
                if obj in next_actual_node_map.keys():
                    mnode_map = next_actual_node_map[obj]
                if isinstance(obj, TracedModule):
                    obj._update_ref(mnode_map)

    def flatten(self):
        """
        Get a new module, which eliminates ``GetAttr`` and has no hierarchy.

        :return: :class:`TracedModule`
        """
        new_module = copy.deepcopy(self)

        def _flatten_subgraph(graph, module, call=None):
            # Inline `graph` (a submodule's graph) at the site of `call`.
            # graph=None means a builtin (untraced) module: keep the call
            # but feed it the module object via a Constant.
            if graph is None:
                assert not isinstance(module, TracedModule)
                const = Constant(module)
                const.outputs[0] = call.inputs[0]
                const.outputs[0].expr = const
                return [const, call]
            if call is not None:
                # Copy so rewiring does not mutate the submodule's own graph.
                graph = copy.deepcopy(graph)
            exprs = []
            node2obj = {}
            node2obj[graph._inputs[0]] = module
            if call:
                node2obj[call.inputs[0]] = module
            # replace inputs for submodule's exprx
            if call:
                repl_dict = dict(zip(graph._inputs, call.inputs))
                for ind, out in enumerate(graph.outputs):
                    if isinstance(out.expr, Input):
                        # Output is a pass-through of an input: rewire every
                        # user of the call's output to the caller-side node.
                        assert out in repl_dict
                        call_out = call.outputs[ind]
                        for expr in call.outputs[ind].users:
                            for index, inp in enumerate(expr.inputs):
                                if inp is call_out:
                                    expr.inputs[index] = repl_dict[out]
                        continue
                    repl_dict[out] = call.outputs[ind]
                graph._replace_inputs_outputs(repl_dict)
            for expr in graph._exprs:
                if isinstance(expr, GetAttr):
                    # replace GetAttr with Constant
                    if isinstance(expr.outputs[0], TensorNode):
                        const = Constant(getattr(node2obj[expr.inputs[0]], expr.name))
                        const.outputs = expr.outputs
                        const.outputs[0].expr = const
                        exprs.append(const)
                    elif isinstance(expr.outputs[0], ModuleNode):
                        # Module attribute: record the resolved object only;
                        # the GetAttr itself is dropped from the flat graph.
                        node2obj[expr.outputs[0]] = getattr(
                            node2obj[expr.inputs[0]], expr.name
                        )
                elif isinstance(expr, CallMethod):
                    obj_node = expr.inputs[0]
                    if isinstance(obj_node, ModuleNode):
                        pre_expr = expr.inputs[0].expr
                        if isinstance(pre_expr, GetAttr):
                            (obj,) = pre_expr.interpret(node2obj[pre_expr.inputs[0]])
                            expr_graph = (
                                obj.argdef_graph_map[expr.arg_def]
                                if hasattr(obj, "argdef_graph_map")
                                else None
                            )
                            # Recursively inline the submodule's graph here.
                            exprs.extend(_flatten_subgraph(expr_graph, obj, expr))
                        else:
                            # module has been replaced.
                            assert isinstance(pre_expr, Constant)
                            exprs.append(expr)
                    else:
                        exprs.append(expr)
                else:
                    exprs.append(expr)
            if call is not None:
                # The inlined call no longer exists; detach it from its inputs.
                for i in call.inputs:
                    i.users.remove(call)
            return exprs

        new_module.graph._exprs = _flatten_subgraph(new_module.graph, new_module)
        return new_module

    def __getstate__(self):
        # Exclude attributes inherited from Module's class dict so pickling
        # only carries this instance's own state.
        d = self.__dict__
        for k in Module.__dict__:
            d.pop(k, None)
        return d
def cpp_apply_module_trace(opdef, *args):
    # Hook invoked from the C++ dispatcher when an op is applied while module
    # tracing is active; forwards to the registered Python-level trace hook.
    return Apply.apply_module_trace_hook(opdef, *args)
def register_as_builtin(mod_cls: Type[Module]) -> None:
    """
    Registers class ``mod_cls`` (subclass of megengine.module.Module) as builtin module.

    :param mod_cls: the Module class which will be treated as a builtin module
        in tracing (its forward is not traced into a subgraph).
    """
    module_tracer.register_as_builtin(mod_cls)
  952. def wrap(func: Union[Callable]):
  953. assert callable(func)
  954. if hasattr(func, "__code__"):
  955. assert not isinstance(func, str)
  956. fn_name = func.__code__.co_name
  957. currentframe = inspect.currentframe()
  958. assert currentframe is not None
  959. f = currentframe.f_back
  960. assert f is not None
  961. if f.f_code.co_name != "<module>":
  962. raise NotImplementedError("wrap must be called at the top level of a module")
  963. Patcher._builtin_functions.append((f.f_globals, fn_name))
  964. return func
  965. def _register_all_builtin_module():
  966. for sub_mod in [M, M.qat, M.quantized]:
  967. for m in getmembers(sub_mod):
  968. if (
  969. isclass(m[1])
  970. and issubclass(m[1], M.Module)
  971. and m[1] is not M.Sequential
  972. ):
  973. module_tracer.register_as_builtin(m[1])
def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
    """
    Traces module ``mod`` and returns corresponding TracedModule.

    :param mod: the module that will be converted to a TracedModule
    :param args: the positional arguments passed to the forward method of ``mod``
    :param kwargs: the keyword arguments passed to the forward method of ``mod``
    """
    # Tracing must not be re-entered: only one active module tracer at a time.
    assert active_module_tracer() is None
    assert isinstance(mod, Module)
    try:
        # Force symbolic shapes during tracing; restored in `finally`.
        # NOTE(review): if set_symbolic_shape itself raised, `use_sym_shape`
        # would be unbound in the finally block — presumed not to happen.
        use_sym_shape = set_symbolic_shape(True)
        set_module_tracing()
        set_active_module_tracer(module_tracer(_wrapped_function))
        with active_module_tracer().patcher:
            global_scope = InternalGraph()
            active_module_tracer().push_scope(global_scope)
            builder = TracedModuleBuilder(mod, True)
            # The root module itself becomes the graph's first Input node.
            NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode))
            inputs, _ = tree_flatten((args, kwargs))
            for _, i in enumerate(inputs):
                # assert isinstance(i, Tensor), "not support "
                if isinstance(i, RawTensor):
                    # Each tensor leaf gets its own positional Input node.
                    NodeMixin.wrap_safe(
                        i, Input.make("arg_{}".format(_), NodeMixin.get_wrapped_type(i))
                    )
            builder(*args, **kwargs)
            active_module_tracer().pop_scope()
            return builder.build()
    finally:
        # Always restore global tracer state, even if tracing failed.
        set_symbolic_shape(use_sym_shape)
        set_active_module_tracer(None)
        unset_module_tracing()

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台