
traced_module.py 53 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

import builtins
import collections
import copy
import fnmatch
import functools
import keyword
import re
import weakref
from inspect import getcallargs, getmembers, isclass, ismethod
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, Union

from ... import functional as F
from ... import get_logger
from ... import module as M
from ...core._imperative_rt.core2 import Tensor as RawTensor
from ...core._imperative_rt.core2 import (
    is_tracing_module,
    set_module_tracing,
    unset_module_tracing,
)
from ...core._trace_option import set_symbolic_shape
from ...core.tensor.array_method import ArrayMethodMixin
from ...module import Module
from ...module.qat import QATModule
from ...quantization.fake_quant import LSQ, TQT, FakeQuantize, _FakeQuantize
from ...quantization.observer import (
    ExponentialMovingAverageObserver,
    HistogramObserver,
    MinMaxObserver,
    Observer,
    PassiveObserver,
    SyncExponentialMovingAverageObserver,
    SyncMinMaxObserver,
)
from ...tensor import Tensor
from .expr import Apply, CallFunction, CallMethod, Constant, Expr, GetAttr, Input
from .fake_quant import FakeQuantize as TM_FakeQuant
from .module_tracer import (
    Patcher,
    active_module_tracer,
    module_tracer,
    set_active_module_tracer,
)
from .node import ModuleNode, Node, NodeMixin, TensorNode
from .pytree import ArgsIndex, tree_flatten

logger = get_logger(__name__)

def _is_builtin_name(name: str) -> bool:
    return (
        name in builtins.__dict__
        or name in keyword.kwlist
        or name in {"inf", "nan", "NoneType"}
    )


def _is_leaf(node):
    assert isinstance(node, RawTensor), "doesn't support {} in return values".format(
        type(node)
    )
    return isinstance(node, RawTensor)


def wrap_tensors(tensors: Tensor, nodes: TensorNode):
    inp_tensors = copy.deepcopy(tensors)
    inp_tensors, inp_def_v = tree_flatten(inp_tensors)
    inp_nodes, inp_def_n = tree_flatten(nodes)
    for v, n in zip(inp_tensors, inp_nodes):
        if isinstance(n, TensorNode) and isinstance(v, Tensor):
            NodeMixin.wrap_safe(v, n)
    return inp_def_v.unflatten(inp_tensors)

class _InsertExprs:
    def __init__(self, graph, expr: Optional[Expr] = None, after: bool = True):
        self.graph = graph
        self.global_scope = InternalGraph()
        self.global_scope._used_names.update(graph._used_names)
        self.expr = expr
        self.after = after

    def __enter__(self):
        self.use_sym_shape = set_symbolic_shape(True)
        set_module_tracing()
        assert active_module_tracer() is None
        set_active_module_tracer(module_tracer(_wrapped_function))
        active_module_tracer().patcher.__enter__()
        active_module_tracer().push_scope(self.global_scope)

    def __exit__(self, ty, va, tr):
        set_symbolic_shape(self.use_sym_shape)
        unset_module_tracing()
        active_module_tracer().patcher.__exit__(ty, va, tr)
        set_active_module_tracer(None)
        index = len(self.graph._exprs) if self.after else 0
        if self.expr is not None:
            index = self.graph._exprs.index(self.expr)
            if self.after:
                index += 1
        for expr in self.global_scope._exprs:
            self.graph._exprs.insert(index, expr)
            index += 1
        self.graph._used_names.update(self.global_scope._used_names)

class InternalGraph:
    """
    ``InternalGraph`` is a graph that consists of ``Node`` and ``Expr`` objects.
    It is used to represent the execution procedure of a Module's forward method.

    Attributes:
        _exprs: list of Exprs in order of execution
        _inputs: input Nodes of the InternalGraph
        _outputs: output Nodes of the InternalGraph
    """

    _exprs = None  # type: List[Expr]
    _inputs = None  # type: List[Node]
    _outputs = None  # type: List[Node]

    def __init__(self, name: str = None, prefix_name: str = ""):
        self._exprs = []
        self._inputs = []
        self._outputs = []
        self._watch_point = []
        self._end_point = []
        self._used_names = {}
        self._rst = collections.defaultdict(list)
        self._name = name
        self._prefix_name = prefix_name

    def insert(self, expr):
        self._exprs.append(expr)

    def _create_unique_name(self, name: str) -> str:
        assert isinstance(name, str)
        name = re.sub("[^0-9a-zA-Z_]+", "_", name)
        if name[0].isdigit():
            name = "_{}".format(name)
        while name in self._used_names or _is_builtin_name(name):
            match = re.match(r"(.*)_(\d+)$", name)
            if match is None:
                name = name + "_1"
            else:
                base, num = match.group(1, 2)
                name = "{}_{}".format(base, int(num) + 1)
        self._used_names.setdefault(name)
        return name

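    # Illustrative sketch (not part of the original file): _create_unique_name
    # escapes a leading digit and bumps a numeric suffix until the name is free
    # and is not a Python builtin or keyword, e.g.
    #
    #   g = InternalGraph()
    #   g._create_unique_name("fc")     # -> "fc"
    #   g._create_unique_name("fc")     # -> "fc_1"
    #   g._create_unique_name("fc")     # -> "fc_2"
    #   g._create_unique_name("1conv")  # -> "_1conv"
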
    @property
    def inputs(self):
        return self._inputs

    @property
    def outputs(self):
        return self._outputs

    @property
    def expr_filter(self):
        return ExprFilter(_expr_iter(self))

    @property
    def node_filter(self):
        return NodeFilter(_node_iter(self))

    def get_function_by_type(self, func: Callable = None):
        return self.expr_filter.call_function(func)

    def get_method_by_type(self, method: str = None):
        return self.expr_filter.call_method(method)

    def get_expr_by_id(self, expr_id: List[int] = None):
        return self.expr_filter.expr_id(expr_id)

    def get_module_by_type(self, module_cls: Module):
        assert issubclass(module_cls, Module)
        return self.node_filter.type(module_cls, ModuleNode)

    def get_node_by_id(self, node_id: List[int] = None):
        return self.node_filter.node_id(node_id)

    def get_node_by_name(self, name: str = None, ignorecase: bool = True):
        return self.node_filter.name(name, ignorecase)

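    # Illustrative sketch (not part of the original file): these getters return
    # filter objects (see ExprFilter/NodeFilter below) that can be chained and
    # then materialized, e.g.
    #
    #   graph.get_function_by_type(F.relu).as_list()    # all F.relu call exprs
    #   graph.get_method_by_type("__call__").as_count()
    #   graph.get_module_by_type(M.Conv2d).as_unique()  # assumes exactly one match
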
    def add_input(self, i):
        self._inputs.append(i)

    def add_output(self, o):
        self._outputs.append(o)

    def _replace_inputs_outputs_and_add_prefixname(self, repl_dict, prefix_name=""):
        for node, repl_node in repl_dict.items():
            assert node in self._inputs or node in self._outputs
            for i in node.users:
                if i not in repl_node.users:
                    repl_node.users.append(i)

        for idx, i in enumerate(self._inputs):
            if i in repl_dict:
                self._inputs[idx] = repl_dict[i]

        for idx, o in enumerate(self._outputs):
            if o in repl_dict:
                self._outputs[idx] = repl_dict[o]

        for expr in self._exprs:
            for idx, i in enumerate(expr.inputs):
                assert i._name is not None
                if i in repl_dict:
                    expr.inputs[idx] = repl_dict[i]
                elif isinstance(i, TensorNode) and prefix_name not in i._name:
                    if i.top_graph != active_module_tracer().current_scope():
                        i._name = (
                            active_module_tracer()
                            .current_scope()
                            ._create_unique_name(prefix_name + i._name.lstrip("_"))
                        )

            for idx, o in enumerate(expr.outputs):
                assert o._name is not None
                if o in repl_dict:
                    expr.outputs[idx] = repl_dict[o]
                    expr.outputs[idx].expr = expr
                elif isinstance(o, TensorNode) and prefix_name not in o._name:
                    if o.top_graph != active_module_tracer().current_scope():
                        o._name = (
                            active_module_tracer()
                            .current_scope()
                            ._create_unique_name(prefix_name + o._name.lstrip("_"))
                        )

    def get_dep_exprs(self, nodes: Sequence[Node]) -> List[Expr]:
        if not isinstance(nodes, Sequence):
            nodes = (nodes,)
        ret = list()
        queue = list(nodes)
        visited_queue = list()
        while queue:
            node = queue.pop()
            visited_queue.append(node)
            expr = node.expr
            if expr not in ret:
                ret.append(expr)
            for i in expr.inputs:
                if i not in queue and i not in visited_queue:
                    queue.append(i)
        return ret

    def reset_inputs(self, *args, **kwargs):
        formal_mnode = self.inputs[0]
        actual_mnodes = formal_mnode.actual_mnode
        call_nodes = []
        for n in actual_mnodes:
            for c_expr in n.users:
                if isinstance(c_expr, CallMethod) and c_expr.method == "__call__":
                    call_nodes.append((c_expr, n))
        module = formal_mnode.owner
        assert module._is_top, "reset_inputs only supports the top-level graph"
        inputs, tree_def = tree_flatten(((module, *args), kwargs))

        def create_node(val: Tensor):
            node = Input(type=TensorNode).outputs[0]
            node.shape = val.shape
            node.dtype = val.dtype
            return node

        formal_node_inputs = [
            formal_mnode,
        ]
        org_argdef = list(module.argdef_graph_map.keys())[0]
        if call_nodes:
            org_argdef = call_nodes[0][0].arg_def
        for v in inputs[1:]:
            assert isinstance(v, RawTensor)
            formal_node_inputs.append(create_node(v))
        actual_nodes = []
        for e, n in call_nodes:
            e.arg_def = tree_def
            actual_node_inputs = [
                n,
            ]
            for v in inputs[1:]:
                actual_node_inputs.append(create_node(v))
            for org_n in e.inputs:
                org_n.users.pop(e)
            e.inputs[:] = actual_node_inputs
            e.const_val = []
            actual_nodes.append(actual_node_inputs[1:])
        self._inputs[:] = formal_node_inputs
        module.argdef_graph_map[tree_def] = module.argdef_graph_map.pop(org_argdef)
        module.argdef_outdef_map[tree_def] = module.argdef_outdef_map.pop(org_argdef)
        # return formal_node_inputs[1:], actual_nodes
        return formal_node_inputs[1:]

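    # Illustrative sketch (not part of the original file): on a top-level
    # graph, reset_inputs rebuilds the formal inputs from new example tensors
    # and returns the new input nodes (excluding the module node itself).
    # `traced_net` is a hypothetical top-level TracedModule:
    #
    #   new_inp_nodes = traced_net.graph.reset_inputs(megengine.Tensor(data))
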
    def add_input_node(self, shape, dtype="float32", name="args"):
        formal_mnode = self.inputs[0]
        actual_mnodes = formal_mnode.actual_mnode
        module = formal_mnode.owner
        assert module._is_top, "add_input_node only supports the top-level graph"
        call_nodes = []
        for n in actual_mnodes:
            for c_expr in n.users:
                if isinstance(c_expr, CallMethod) and c_expr.method == "__call__":
                    call_nodes.append(c_expr)

        def create_node(name=None, is_input: bool = True):
            if is_input:
                node = Input(type=TensorNode, name=name).outputs[0]
            else:
                node = TensorNode(expr=None, name=None)
            node.shape = shape
            node.dtype = dtype
            return node

        org_argdef = list(module.argdef_graph_map.keys())[0]
        if call_nodes:
            org_argdef = call_nodes[0].arg_def
        args, kwargs = org_argdef.unflatten(self._inputs)
        formal_inp_node = create_node(self._create_unique_name(name), True)
        inputs, tree_def = tree_flatten(
            ((*args, formal_inp_node), kwargs),
            is_const_leaf=lambda x: not isinstance(x, (TensorNode, ModuleNode)),
        )
        self._inputs[:] = inputs[:]
        actual_inp_nodes = []
        for e in call_nodes:
            args, kwargs = e.unflatten_args(e.inputs)
            args = args + (create_node(is_input=False),)
            inputs, tree_def = tree_flatten(
                (args, kwargs),
                is_const_leaf=lambda x: not isinstance(x, (TensorNode, ModuleNode)),
            )
            e.inputs[:] = inputs[:]
            e.arg_def = tree_def
            actual_inp_nodes.append(args[-1])
        module.argdef_graph_map[tree_def] = module.argdef_graph_map.pop(org_argdef)
        module.argdef_outdef_map[tree_def] = module.argdef_outdef_map.pop(org_argdef)
        # return formal_inp_node, actual_inp_nodes
        return formal_inp_node

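    # Illustrative sketch (not part of the original file): append one extra
    # formal input to a top-level graph; shape/dtype describe the new node:
    #
    #   node = traced_net.graph.add_input_node(shape=(1, 3, 32, 32), name="extra")
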
    def reset_outputs(self, outputs):
        outputs, out_def = tree_flatten(
            outputs, is_leaf=lambda x: isinstance(x, TensorNode),
        )
        formal_mnode = self.inputs[0]
        module = formal_mnode.owner
        assert module._is_top, "reset_outputs only supports the top-level graph"
        actual_mnodes = formal_mnode.actual_mnode
        call_nodes = []
        for n in actual_mnodes:
            for c_expr in n.users:
                if isinstance(c_expr, CallMethod) and c_expr.method == "__call__":
                    call_nodes.append(c_expr)

        def create_node(val: TensorNode, expr: Expr):
            node = TensorNode(expr)
            node.shape = val.shape
            node.dtype = val.dtype
            return node

        tree_def = list(module.argdef_graph_map.keys())[0]
        if call_nodes:
            tree_def = call_nodes[0].arg_def
        actual_nodes = []
        for e in call_nodes:
            actual_node_outputs = []
            for v in outputs:
                actual_node_outputs.append(create_node(v, e))
            e.outputs[:] = actual_node_outputs
            e.out_def = out_def
            actual_nodes.append(actual_node_outputs)
        self._outputs[:] = outputs
        module.argdef_outdef_map[tree_def] = out_def
        return actual_nodes

    def add_output_node(self, node: TensorNode):
        formal_mnode = self.inputs[0]
        module = formal_mnode.owner
        assert module._is_top, "add_output_node only supports the top-level graph"
        actual_mnodes = formal_mnode.actual_mnode
        call_nodes = []
        for n in actual_mnodes:
            for c_expr in n.users:
                if isinstance(c_expr, CallMethod) and c_expr.method == "__call__":
                    call_nodes.append(c_expr)

        def create_node(val: TensorNode, expr: Expr):
            node = TensorNode(expr)
            node.shape = val.shape
            node.dtype = val.dtype
            return node

        tree_def = list(module.argdef_graph_map.keys())[0]
        if call_nodes:
            tree_def = call_nodes[0].arg_def
        org_out_def = module.argdef_outdef_map[tree_def]
        org_outs = org_out_def.unflatten(self._outputs)
        outputs, out_def = tree_flatten(
            (org_outs, node), is_leaf=lambda x: isinstance(x, TensorNode),
        )
        self._outputs[:] = outputs
        actual_out_nodes = []
        for e in call_nodes:
            actual_node = create_node(node, e)
            org_outs = org_out_def.unflatten(e.outputs)
            outputs, out_def = tree_flatten(
                (org_outs, actual_node), is_leaf=lambda x: isinstance(x, TensorNode),
            )
            e.outputs[:] = outputs
            e.out_def = out_def
            actual_out_nodes.append(actual_node)
        module.argdef_outdef_map[tree_def] = out_def
        return actual_out_nodes

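    # Illustrative sketch (not part of the original file): append an existing
    # TensorNode to the graph outputs, e.g. to also return an intermediate
    # activation (`act_node` is a hypothetical node taken from this graph):
    #
    #   traced_net.graph.add_output_node(act_node)
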
    def insert_function(self, func: Callable, *args, **kwargs):
        assert isinstance(func, Callable)
        inp_nodes, inp_def = tree_flatten((args, kwargs))
        insert_idx = -1
        for i in inp_nodes:
            if isinstance(i, TensorNode) and i.expr in self._exprs:
                insert_idx = max(insert_idx, self._exprs.index(i.expr))
        fake_inp_val = list(
            F.zeros(shape=i.shape, dtype=i.dtype) if isinstance(i, TensorNode) else i
            for i in inp_nodes
        )
        for v, n in zip(fake_inp_val, inp_nodes):
            if isinstance(n, TensorNode):
                NodeMixin.wrap_safe(v, n)
        fake_args, fake_kwargs = inp_def.unflatten(fake_inp_val)
        insert_point = self.insert_exprs_before()
        if insert_idx != -1:
            insert_point = self.insert_exprs_after(self._exprs[insert_idx])
        with insert_point:
            rst = func(*fake_args, **fake_kwargs)
        if rst is None:
            return None
        outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
        node_outputs = []
        for out in outputs:
            assert isinstance(out, RawTensor)
            node_outputs.append(NodeMixin.get(out, None))
        node_outputs = out_def.unflatten(node_outputs)
        return node_outputs

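    # Illustrative sketch (not part of the original file): insert a function
    # call into the graph; TensorNode arguments are fed with fake zero tensors
    # so the call can be traced, and the new output node(s) are returned.
    # `conv_out_node` is hypothetical:
    #
    #   relu_node = graph.insert_function(F.relu, conv_out_node)
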
    def insert_exprs_after(self, expr: Optional[Expr] = None):
        if expr is not None:
            assert expr.top_graph == self, "Expr to insert after is not in graph."
        return _InsertExprs(self, expr, after=True)

    def insert_exprs_before(self, expr: Optional[Expr] = None):
        if expr is not None:
            assert expr.top_graph == self, "Expr to insert before is not in graph."
        return _InsertExprs(self, expr, after=False)

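    # Illustrative sketch (not part of the original file): both methods return
    # an _InsertExprs context manager; exprs traced inside the `with` block are
    # spliced into this graph at the chosen position when the block exits.
    # `some_expr` and `x` are hypothetical:
    #
    #   with graph.insert_exprs_after(some_expr):
    #       y = F.relu(x)  # recorded as a CallFunction expr after some_expr
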
    def replace_node(self, repl_dict: Dict[Node, Node]):
        while repl_dict:
            node, repl_node = repl_dict.popitem()
            # check graph inputs and outputs
            assert node not in self.inputs, "Cannot replace inputs"
            for i, n in enumerate(self.outputs):
                if n is node:
                    self.outputs[i] = repl_node
            # update users of node and repl_node
            # update inputs of exprs in node.users
            dep_exprs = self.get_dep_exprs(repl_node)
            i = 0
            while i < len(node.users):
                n = node.users[i]
                if n in dep_exprs:
                    logger.info("Found a loop: skipping this replacement")
                    logger.info("node: %s", node)
                    logger.info("repl_node: %s", repl_node)
                    i += 1
                    continue
                repl_node.users.append(n)
                node.users.pop(i)
                idx = n.inputs.index(node)
                n.inputs[idx] = repl_node

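    # Illustrative sketch (not part of the original file): a replacement is
    # usually followed by compile() so exprs that only fed the old node are
    # removed as dead code:
    #
    #   graph.replace_node({old_node: new_node})
    #   graph.compile()
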
    def compile(self):
        """
        Removes unused exprs, i.e. exprs that the graph outputs do not depend on.
        """
        dep_exprs = self.get_dep_exprs(self.outputs)
        i = 0
        while i < len(self._exprs):
            expr = self._exprs[i]
            if expr in dep_exprs or expr._disable_remove:
                i += 1
                continue
            for n in expr.inputs:
                n.users.remove(expr)
            self._exprs.remove(expr)

    def interpret(self, *inputs):
        node2value = {}
        end_nodes_set = set(self._end_point)
        endnode2value = {}

        def get_all_endnode_val(n, v):
            if n in end_nodes_set:
                endnode2value[n] = v
                end_nodes_set.remove(n)
                return not end_nodes_set
            return False

        for n, v in zip(self._inputs, inputs):
            node2value[n] = v
            if n in self._watch_point:
                self._rst[n].append(v)
            if n in self._end_point and get_all_endnode_val(n, v):
                return list(endnode2value[i] for i in self._end_point)

        for expr in self._exprs:
            values = expr.interpret(*list(node2value[i] for i in expr.inputs))
            if values is not None:
                for n, v in zip(expr.outputs, values):
                    node2value[n] = v
                    if n in self._watch_point:
                        self._rst[n] = v
                    if self._end_point and get_all_endnode_val(n, v):
                        return list(endnode2value[i] for i in self._end_point)

        return list(node2value[i] for i in self._outputs)

    def eval(self, *inputs):
        assert len(inputs) == len(self._inputs) - 1
        inp = [self._inputs[0].owner] + list(inputs)
        return self.interpret(*inp)

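    # Illustrative sketch (not part of the original file): eval() prepends the
    # owning module to the inputs and replays the recorded exprs one by one,
    # returning a list with one value per graph output:
    #
    #   (out,) = graph.eval(megengine.Tensor(x))  # one tensor per graph input
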
    def __repr__(self):
        return self.__format__()

    def __format__(self, format_spec: str = "") -> str:
        saved_format_spec = Node.set_format_spec(format_spec)
        name = ""
        if self._name:
            name = "%s.Graph" % self._name
        res = "{} ({}) {{\n\t{}\n\treturn {}\n}}".format(
            name,
            ", ".join(str(i) for i in self._inputs),
            "\n\t".join("{}".format(str(i)) for i in self._exprs),
            ", ".join(str(i) for i in self._outputs),
        )
        Node.set_format_spec(saved_format_spec)
        return res

def _get_meth_name(obj, func):
    tp = obj if isinstance(obj, type) else type(obj)
    for cls in tp.mro():
        for k, v in cls.__dict__.items():
            if v == func:
                return k
    return None

def _wrapped_function(orig_func):
    @functools.wraps(orig_func)
    def wrapped_fn(*args, **kwargs):
        if is_tracing_module():
            unset_module_tracing()
            inputs, tree_def = tree_flatten((args, kwargs))
            for i in inputs:
                if not NodeMixin.get(i, None):
                    if isinstance(i, (RawTensor, NodeMixin)):
                        NodeMixin.wrap_safe(i, Constant.make(i))
            meth_name = _get_meth_name(args[0], wrapped_fn) if args else None
            arg_type = args[0] if isinstance(args[0], type) else type(args[0])
            if meth_name and issubclass(arg_type, RawTensor):
                self = inputs[0]
                if meth_name == "__new__":
                    if all([not isinstance(i, RawTensor) for i in inputs]):
                        # only trace Tensor.__new__() when there are tensors in the args
                        set_module_tracing()
                        return orig_func(*args, **kwargs)
                    if isinstance(args[1], RawTensor):
                        node = NodeMixin.get(inputs[1])
                        inputs[1] = copy.copy(inputs[1])
                        # copy inputs[1] so that tensor and Tensor(tensor) do not
                        # share the same m_tensor, which would give them the same
                        # _NodeMixin__node during tracing
                        NodeMixin.wrap_safe(inputs[1], node)
                        args, kwargs = tree_def.unflatten(inputs)
                    call_node = CallMethod.make(self, meth_name)
                else:
                    call_node = CallMethod.make(NodeMixin.get(self), meth_name)
                call_node.add_inputs(inputs[1:])
            else:
                call_node = CallFunction.make(orig_func)
                call_node.add_inputs(inputs)
            call_node.arg_def = tree_def
            rst = orig_func(*args, **kwargs)
            if meth_name == "__setitem__":
                rst = self
            if rst is not None:
                outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
                call_node.out_def = out_def
            else:
                outputs = None
            call_node.add_outputs(outputs)
            set_module_tracing()
            return rst
        return orig_func(*args, **kwargs)

    return wrapped_fn

class TracedModuleBuilder(NodeMixin):

    _mod = None  # type: Module
    _body = None  # type: InternalGraph
    _is_builtin = None  # type: bool
    _argdef_graph_map = None  # type: Dict[Treedef, "InternalGraph"]
    _argdef_outdef_map = None  # type: Dict[Treedef, Treedef]
    nodes = None

    __builder_attributes__ = [
        "_mod",
        "_body",
        "_NodeMixin__node",
        "_is_builtin",
        "build",
        "_record_wrapped_nodes",
        "_argdef_graph_map",
        "_argdef_outdef_map",
        "nodes",
        "__class__",
        "__dict__",
    ]

    def __init__(self, mod, is_top_module=False):
        super(TracedModuleBuilder, self).__init__()
        assert isinstance(mod, Module)
        self._mod = mod
        self._body = None
        self._is_top = is_top_module
        self._is_builtin = (
            True
            if isinstance(mod, (Observer, _FakeQuantize))
            else module_tracer.is_builtin(mod)
        )
        if isinstance(self._mod, QATModule):
            unset_module_tracing()
            self._check_qat_module(self._mod)
            set_module_tracing()
        self._argdef_graph_map = {}
        self._argdef_outdef_map = {}
        self.nodes = set()
        # The builder will be passed to self._mod.forward as the 'self' argument.
        # If 'forward' uses super().xxx to call a method of one of its base
        # classes, tracing would raise an exception, because the builder does not
        # inherit from self._mod.__bases__. So modify self.__class__ and let the
        # builder inherit from both TracedModuleBuilder and mod.__class__.
        self.__class__ = type(
            "TracedModuleBuilder",
            (TracedModuleBuilder, mod.__class__),
            dict(TracedModuleBuilder.__dict__),
        )

    def _check_qat_module(self, qat_module):
        def isbuiltin(m):
            return m is None or module_tracer.is_builtin(m)

        if qat_module.with_act:
            act_observer = qat_module.act_observer
            act_fakequant = qat_module.act_fake_quant
            if not isbuiltin(act_observer) or not isbuiltin(act_fakequant):
                qparams = (
                    act_observer.get_qparams()
                    if hasattr(act_observer, "get_qparams")
                    else act_fakequant.get_qparams()
                )
                dtype = (
                    act_observer.dtype
                    if hasattr(act_observer, "dtype")
                    else act_fakequant.dtype
                )
                qat_module.act_observer = None
                qat_module.act_fake_quant = TM_FakeQuant(dtype)
                qat_module.act_fake_quant.set_qparams(qparams)
        if qat_module.with_weight:
            weight_observer = qat_module.weight_observer
            weight_fakequant = qat_module.weight_fake_quant
            if not isbuiltin(weight_observer) or not isbuiltin(weight_fakequant):
                qparams = (
                    weight_observer.get_qparams()
                    if hasattr(weight_observer, "get_qparams")
                    else weight_fakequant.get_qparams()
                )
                dtype = (
                    weight_observer.dtype
                    if hasattr(weight_observer, "dtype")
                    else weight_fakequant.dtype
                )
                qat_module.weight_observer = None
                qat_module.weight_fake_quant = TM_FakeQuant(dtype)
                qat_module.weight_fake_quant.set_qparams(qparams)

    def build(self):
        if self._is_builtin or isinstance(self._mod, TracedModule):
            if module_tracer.is_builtin(self._mod) or isinstance(
                self._mod, TracedModule
            ):
                mod_type = type(self._mod)
            else:
                assert isinstance(self._mod, (Observer, _FakeQuantize))
                mod_type = (
                    Observer if isinstance(self._mod, Observer) else _FakeQuantize
                )
            for node in self.nodes:
                node.module_type = mod_type
            return self._mod
        else:
            is_qat = isinstance(self._mod, QATModule)
            traced_module = TracedModule(
                self._is_top, self._argdef_graph_map, self._argdef_outdef_map, is_qat
            )
            for _, g in self._argdef_graph_map.items():
                g.compile()
            for k, v in self.__dict__.items():
                if k not in TracedModuleBuilder.__builder_attributes__:
                    if isinstance(v, TracedModuleBuilder):
                        v = v.build()
                    setattr(traced_module, k, v)
            if isinstance(self._mod, QATModule):
                unset_module_tracing()
                traced_module.with_act = self._mod.with_act
                traced_module.with_weight = self._mod.with_weight
                if not hasattr(traced_module, "act_fake_quant"):
                    traced_module.act_fake_quant = None
                if not hasattr(traced_module, "act_observer"):
                    traced_module.act_observer = None
                if not hasattr(traced_module, "weight_fake_quant"):
                    traced_module.weight_fake_quant = None
                if not hasattr(traced_module, "weight_observer"):
                    traced_module.weight_observer = None
                set_module_tracing()
            return traced_module

    def _record_wrapped_nodes(self, node):
        self.nodes.add(node)

    def __call__(self, *args, **kwargs):
        assert isinstance(self._mod, Module)

        # prepare args and kwargs for the inner graph
        def mark_constant(x):
            node = NodeMixin.get(x, None)
            if node is None:  # capture as constant
                NodeMixin.wrap(x, lambda: Constant.make(x))

        inputs, tree_def = tree_flatten(((self, *args), kwargs))
        for i in inputs:
            mark_constant(i)
        callnode = CallMethod.make(NodeMixin.get(self))
        callnode.add_inputs(inputs[1:])
        callnode.arg_def = tree_def

        if (
            self._is_builtin
            or tree_def in self._argdef_graph_map
            or isinstance(self._mod, TracedModule)
        ):
            unset_module_tracing()
            rst = self._mod(*args, **kwargs)
            outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
            set_module_tracing()
            if self._is_builtin:
                self._body = None
            elif tree_def in self._argdef_graph_map:
                self._body = self._argdef_graph_map[tree_def]
            else:
                self._mod._is_top = False
                self._body = self._mod.graph
                name = NodeMixin.get(self)._name
                if name:
                    self._body._name = name
        else:
            self_node = None
            orig_self = NodeMixin.get(self)
            top_graph = active_module_tracer().current_scope()
            graph_prefix_name = top_graph._name
            if top_graph._prefix_name:
                graph_prefix_name = "{}_{}".format(
                    top_graph._prefix_name, graph_prefix_name.lstrip("_")
                )
            self._body = InternalGraph(orig_self._name, prefix_name=graph_prefix_name)
            active_module_tracer().push_scope(self._body)
            # rebind self to the new input node
            if self_node:
                NodeMixin.wrap_safe(self, self_node)
                active_module_tracer().current_scope().add_input(self_node)
            else:
                NodeMixin.wrap_safe(
                    self,
                    self_node
                    if self_node
                    else Input.make("self", NodeMixin.get_wrapped_type(self)),
                )
            origin_inp_node = [NodeMixin.get(i, None) for i in inputs[1:]]
            # map flattened input positions back to forward's parameter names
            index_args, index_kwargs = tree_def.unflatten(
                [
                    ArgsIndex(0),
                    *list(ArgsIndex(i + 1) for i in range(len(origin_inp_node))),
                ]
            )
            key2idx = getcallargs(type(self._mod).forward, *index_args, **index_kwargs)
            idx2key = {}
            for k, v in key2idx.items():
                if isinstance(v, ArgsIndex):
                    idx2key[v.index] = k
                else:
                    flatten_argidx, _ = tree_flatten(v)
                    for _i, v in enumerate(flatten_argidx):
                        if isinstance(v, ArgsIndex):
                            idx2key[v.index] = k + "_%d" % _i

            def wrap(x, name):
                if isinstance(x, (RawTensor, NodeMixin)):
                    NodeMixin.wrap(
                        x,
                        lambda: Input.make(
                            type=NodeMixin.get_wrapped_type(x), name=name
                        ),
                    )
                return x

            args = [self]
            for i, v in enumerate(inputs[1:]):
                args.append(wrap(v, idx2key[i + 1]))
            args, kwargs = tree_def.unflatten(args)
            active_module_tracer().patcher.auto_patch(
                getattr(getattr(self._mod, "forward", self._mod), "__globals__", {})
            )
            rst = type(self._mod).forward(*args, **kwargs)
            outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
            for i in (
                outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,)
            ):
                active_module_tracer().current_scope().add_output(NodeMixin.get(i))
            NodeMixin.get(self, None).actual_mnode.append(orig_self)
            NodeMixin.wrap_safe(self, orig_self)
            for arg, node in zip(inputs[1:], origin_inp_node):
                if node:
                    NodeMixin.wrap_safe(arg, node)
            active_module_tracer().pop_scope()

        # rebind output to the outer graph
        callnode.out_def = out_def
        callnode.add_outputs(outputs)
        self._argdef_graph_map[callnode.arg_def] = self._body
        self._argdef_outdef_map[callnode.arg_def] = out_def
        return rst

    def __setattr__(self, name, value):
        object.__setattr__(self, name, value)

    def __repr__(self):
        return repr(self._mod)

    def __getattr__(self, name):
        if name not in self._mod.__dict__:
            attr = getattr(type(self._mod), name).__get__(self, type(self))
        else:
            attr = getattr(self._mod, name)
            if isinstance(attr, Module):
                attr = TracedModuleBuilder(attr)
            if isinstance(attr, (Module, RawTensor)):
                setattr(self, name, attr)
                NodeMixin.wrap(
                    attr,
                    lambda: GetAttr.make(
                        NodeMixin.get(self), name, type=NodeMixin.get_wrapped_type(attr)
                    ),
                )
        return attr

    def __getattribute__(self, name):
        if name in TracedModuleBuilder.__builder_attributes__:
            return object.__getattribute__(self, name)
        else:
            wrapped = object.__getattribute__(self, name)
            if name in self._mod.__dict__:
                mod_attr = getattr(self._mod, name)
                if not isinstance(mod_attr, Module) and wrapped is not mod_attr:
                    wrapped = mod_attr
                    setattr(self, name, wrapped)
                if isinstance(mod_attr, Module):
                    assert mod_attr is wrapped._mod
                else:
                    assert mod_attr is wrapped
                # assert not self._is_builtin
                if isinstance(wrapped, (NodeMixin, RawTensor)):
                    NodeMixin.wrap(
                        wrapped,
                        lambda: GetAttr.make(
                            NodeMixin.get(self),
                            name,
                            type=NodeMixin.get_wrapped_type(wrapped),
                        ),
                    )
            return wrapped

class _expr_iter:
    def __init__(self, graph: InternalGraph):
        self.graph = graph

    def __iter__(self):
        for expr in self.graph._exprs:
            if isinstance(expr, CallMethod) and isinstance(expr.inputs[0], ModuleNode):
                yield expr
                if expr.graph is not None:
                    yield from expr.graph.expr_filter
            else:
                yield expr


class _node_iter:
    def __init__(self, graph: InternalGraph) -> None:
        nodes = []
        node_ids = set()
        for expr in graph.expr_filter:
            for n in expr.inputs + expr.outputs:
                if n._id in node_ids:
                    continue
                nodes.append(n)
                node_ids.add(n._id)
        self.nodes = list(sorted(nodes, key=lambda x: x._id))

    def __iter__(self):
        for node in self.nodes:
            yield node

class BaseFilter:
    def __init__(self, expr_iter: Iterable):
        self._iter = expr_iter

    def __iter__(self):
        return iter(self._iter)

    def as_list(self):
        return list(self)

    def as_dict(self):
        return collections.OrderedDict((i._id, i) for i in self)

    def as_unique(self):
        rst = self.as_list()
        assert len(rst) == 1, "{} elements found".format(len(rst))
        (expr,) = self
        return expr

    def as_count(self):
        return sum(1 for _ in self)

class ExprFilter(BaseFilter):
    def call_function(self, func):
        return ExprFilterCallFunction(self, func)

    def call_method(self, method):
        return ExprFilterCallMethod(self, method)

    def expr_id(self, expr_id: List[int]):
        return ExprFilterExprId(self, expr_id)


class NodeFilter(BaseFilter):
    def type(self, owner_type, node_type):
        return NodeFilterType(self, owner_type, node_type)

    def node_id(self, node_id: List[int]):
        return NodeFilterNodeId(self, node_id)

    def name(self, name: str, ignorecase: bool = True):
        return NodeFilterName(self, name, ignorecase)

class NodeFilterType(NodeFilter):
    def __init__(self, expr_iter, owner_type, node_type):
        super().__init__(expr_iter)
        self.owner_type = owner_type
        self.node_type = node_type

    def __iter__(self):
        for node in self._iter:
            if not isinstance(node, self.node_type):
                continue
            if not hasattr(node, "owner"):
                continue
            if isinstance(node.owner, self.owner_type):
                yield node


class NodeFilterNodeId(NodeFilter):
    def __init__(self, expr_iter, node_id: List[int]):
        super().__init__(expr_iter)
        if not isinstance(node_id, Sequence):
            node_id = [node_id]
        self.node_id = node_id

    def __iter__(self):
        for node in self._iter:
            if node._id in self.node_id:
                yield node


class NodeFilterName(NodeFilter):
    _re = None

    def __init__(self, node_iter, pattern, ignorecase):
        super().__init__(node_iter)
        self.pattern = pattern
        self._re = self.make_re(pattern, ignorecase)

    @classmethod
    def make_re(cls, pattern, ignorecase=True):
        assert isinstance(pattern, str), "bad pattern: {!r}".format(pattern)
        assert isinstance(ignorecase, bool)
        flags = 0
        if ignorecase:
            flags |= re.IGNORECASE
        return re.compile(fnmatch.translate(pattern), flags=flags)

    def __iter__(self):
        for i in self._iter:
            graph = i.top_graph
            name = "{}_{}".format(graph._name, i._name.lstrip("_"))
            if graph._prefix_name:
                name = "{}_{}".format(graph._prefix_name, name.lstrip("_"))
            if self.pattern == name or self._re.match(name):
                yield i


class ExprFilterCallFunction(ExprFilter):
    def __init__(self, expr_iter, func: Callable = None):
        super().__init__(expr_iter)
        self.func = func

    def __iter__(self):
        for expr in self._iter:
            if not isinstance(expr, CallFunction):
                continue
            if self.func is None or expr.func == self.func:
                yield expr


class ExprFilterCallMethod(ExprFilter):
    def __init__(self, expr_iter, method: str = None):
        super().__init__(expr_iter)
        self.method = method

    def __iter__(self):
        for expr in self._iter:
            if not isinstance(expr, CallMethod):
                continue
            if self.method is None or expr.method == self.method:
                yield expr


class ExprFilterExprId(ExprFilter):
    def __init__(self, expr_iter, expr_id: List[int]):
        super().__init__(expr_iter)
        if not isinstance(expr_id, Sequence):
            expr_id = [expr_id]
        self.expr_id = expr_id

    def __iter__(self):
        for expr in self._iter:
            if expr._id in self.expr_id:
                yield expr

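# Illustrative sketch (not part of the original file): filters compose lazily
# and only run when materialized, e.g.
#
#   graph.expr_filter.call_method("__call__").as_list()  # all submodule calls
#   graph.node_filter.name("conv*").as_list()            # fnmatch-style pattern
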
class TracedModule(Module):
    """
    ``TracedModule`` is the Module created by tracing a normal module. It owns a
    map from argdef to graph (``InternalGraph``). The forward method of
    ``TracedModule`` gets a graph from ``argdef_graph_map`` according to the
    argdef of the input args/kwargs and interprets it.
    """

    # m_node = None  # type: ModuleNode
    argdef_graph_map = None
    argdef_outdef_map = None

    def __init__(self, is_top, argdef_graph_map, argdef_outdef_map, is_qat=False):
        super(TracedModule, self).__init__()
        self.argdef_graph_map = argdef_graph_map
        self.argdef_outdef_map = argdef_outdef_map
        self._is_top = is_top
        self.watch_points = []
        self.watch_node_value = {}
        self.end_points = []
        self.is_qat = is_qat

    def forward(self, *args, **kwargs):
        inputs, treedef = tree_flatten(((self, *args), kwargs))
        assert treedef in self.argdef_graph_map
        inputs = filter(
            lambda i: isinstance(i, (Module, TracedModuleBuilder, RawTensor)), inputs
        )  # allow TracedModuleBuilder for retrace.
        outputs = self.argdef_graph_map[treedef].interpret(*inputs)
        if self.watch_points:
            self.watch_node_value = {}
            for n in self.watch_points:
                self.watch_node_value[n] = n.top_graph._rst.pop(n)
        if self.end_points:
            return outputs
        out_def = self.argdef_outdef_map[treedef]
        outputs = out_def.unflatten(outputs)
        return outputs

    def set_watch_points(self, nodes):
        if not isinstance(nodes, Sequence):
            nodes = [nodes]
        self.watch_points = nodes
        for n in nodes:
            n.top_graph._watch_point.append(n)

    def clear_watch_points(self):
        for n in self.watch_points:
            n.top_graph._watch_point = []
        self.watch_points = []
        self.watch_node_value = {}

    def set_end_points(self, nodes):
        if not isinstance(nodes, Sequence):
            nodes = [nodes]
        self.end_points = nodes
        graphs = list(self.argdef_graph_map.values())
        for n in nodes:
            assert n.top_graph in graphs
            n.top_graph._end_point.append(n)

    def clear_end_points(self):
        for n in self.end_points:
            n.top_graph._end_point = []
        self.end_points = []

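    # Illustrative sketch (not part of the original file): watch points record
    # intermediate node values during forward; end points truncate execution at
    # the chosen nodes. `nodes` is hypothetical:
    #
    #   traced_net.set_watch_points(nodes)
    #   traced_net(inp)                       # run once
    #   values = traced_net.watch_node_value  # node -> recorded value
    #   traced_net.clear_watch_points()
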
    @property
    def graph(self) -> InternalGraph:
        if self._is_top:
            self._update_ref()
        assert len(self.argdef_graph_map) == 1
        return list(self.argdef_graph_map.values())[0]

    def _update_ref(self, actual_node_map: Optional[Dict] = None):
        for inp_def, graph in self.argdef_graph_map.items():
            for n in graph._inputs + graph.outputs:
                n._top_graph = weakref.ref(graph)
            graph._inputs[0]._owner = weakref.ref(self)
            graph._inputs[0].actual_mnode = []
            if actual_node_map is not None and inp_def in actual_node_map.keys():
                graph._inputs[0].actual_mnode = actual_node_map[inp_def]
            node2obj = {}
            next_actual_node_map = collections.defaultdict(
                lambda: collections.defaultdict(list)
            )
            node2obj[graph._inputs[0]] = self
            for expr in graph._exprs:
                for n in expr.inputs + expr.outputs:
                    n._top_graph = weakref.ref(graph)
                expr._top_graph = weakref.ref(graph)
                if isinstance(expr, GetAttr) and isinstance(
                    expr.outputs[0], ModuleNode
                ):
                    obj = getattr(node2obj[expr.inputs[0]], expr.name)
                    expr.outputs[0]._owner = weakref.ref(obj)
                    node2obj[expr.outputs[0]] = obj
                if isinstance(expr, Constant) and isinstance(
                    expr.outputs[0], ModuleNode
                ):
                    obj = expr.value
                    expr.outputs[0]._owner = weakref.ref(obj)
                    node2obj[expr.outputs[0]] = obj
                if (
                    isinstance(expr, CallMethod)
                    and expr.method == "__call__"
                    and isinstance(expr.inputs[0], ModuleNode)
                ):
                    obj = node2obj[expr.inputs[0]]
                    if expr.arg_def is not None:
                        next_actual_node_map[obj][expr.arg_def].append(expr.inputs[0])
            for obj in node2obj.values():
                if obj is self:
                    continue
                mnode_map = None
                if obj in next_actual_node_map.keys():
                    mnode_map = next_actual_node_map[obj]
                if isinstance(obj, TracedModule):
                    obj._update_ref(mnode_map)

    def flatten(self):
        """
        Gets a new module which eliminates ``GetAttr`` and has no hierarchy.

        :return: :class:`TracedModule`
        """
        new_module = copy.deepcopy(self)
        module2name = {}
        assert active_module_tracer() is None
        set_active_module_tracer(module_tracer(lambda x: x))
        active_module_tracer().push_scope(new_module.graph)
        for n, m in new_module.named_modules():
            module2name[id(m)] = n

        def _flatten_subgraph(
            graph: InternalGraph, module: Module, call=None, prefix_name=""
        ):
            if graph is not None and prefix_name and prefix_name[-1] != "_":
                prefix_name += "_"
            if graph is None or module.is_qat:
                assert not isinstance(module, TracedModule) or module.is_qat
                const = Constant(module, "self.%s" % module2name[id(module)])
                m_node = call.inputs[0]
                if m_node.top_graph != active_module_tracer().current_scope():
                    m_node._name = (
                        active_module_tracer()
                        .current_scope()
                        ._create_unique_name(prefix_name)
                    )
                const.outputs[0] = m_node
                const.outputs[0].expr = const
                return [const, call]
            if call is not None:
                graph = copy.deepcopy(graph)
            exprs = []
            node2obj = {}
            node2obj[graph._inputs[0]] = module
            if call:
                node2obj[call.inputs[0]] = module

            # replace inputs of the submodule's exprs
            if call:
                repl_dict = dict(zip(graph._inputs, call.inputs))
                for ind, out in enumerate(graph.outputs):
                    if isinstance(out.expr, Input):
                        assert out in repl_dict
                        call_out = call.outputs[ind]
                        for expr in call.outputs[ind].users:
                            for index, inp in enumerate(expr.inputs):
                                if inp is call_out:
                                    expr.inputs[index] = repl_dict[out]
                        continue
                    repl_dict[out] = call.outputs[ind]
                graph._replace_inputs_outputs_and_add_prefixname(repl_dict, prefix_name)

            for expr in graph._exprs:
                if isinstance(expr, GetAttr):
                    # replace GetAttr with Constant
                    if isinstance(expr.outputs[0], TensorNode):
                        const = Constant(getattr(node2obj[expr.inputs[0]], expr.name))
                        const.outputs = expr.outputs
                        const.outputs[0].expr = const
                        exprs.append(const)
                    elif isinstance(expr.outputs[0], ModuleNode):
                        node2obj[expr.outputs[0]] = getattr(
                            node2obj[expr.inputs[0]], expr.name
                        )
                elif isinstance(expr, CallMethod):
                    obj_node = expr.inputs[0]
                    if isinstance(obj_node, ModuleNode):
                        pre_expr = expr.inputs[0].expr
                        if isinstance(pre_expr, GetAttr):
                            (obj,) = pre_expr.interpret(node2obj[pre_expr.inputs[0]])
                            expr_graph = (
                                obj.argdef_graph_map[expr.arg_def]
                                if hasattr(obj, "argdef_graph_map")
                                else None
                            )
                            exprs.extend(
                                _flatten_subgraph(
                                    expr_graph,
                                    obj,
                                    expr,
                                    prefix_name + obj_node._name.lstrip("_"),
                                )
                            )
                        else:
                            # module has been replaced.
                            assert isinstance(pre_expr, Constant)
                            exprs.append(expr)
                    else:
                        exprs.append(expr)
                else:
                    exprs.append(expr)

            if call is not None:
                for i in call.inputs:
                    i.users.remove(call)

            return exprs

        new_module.graph._exprs = _flatten_subgraph(new_module.graph, new_module)
        new_module.graph.compile()
        set_active_module_tracer(None)
        for _id, expr in enumerate(new_module.graph._exprs):
            expr._id = _id
        total_node_id = 0
        for i in new_module.graph._inputs:
            i._id = total_node_id
            total_node_id += 1
        for expr in new_module.graph._exprs:
            for o in expr.outputs:
                o._id = total_node_id
                total_node_id += 1
        return new_module

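    # Illustrative sketch (not part of the original file): flattening inlines
    # all submodule graphs into a single-level graph:
    #
    #   flat = traced_net.flatten()
    #   print(flat.graph)  # no nested graphs; GetAttr exprs replaced by Constant
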
    def __getstate__(self):
        d = self.__dict__
        for k in Module.__dict__:
            d.pop(k, None)
        return d

def cpp_apply_module_trace(opdef, *args):
    return Apply.apply_module_trace_hook(opdef, *args)


def register_as_builtin(mod_cls: Type[Module]) -> None:
    """
    Registers class ``mod_cls`` (a subclass of :class:`megengine.module.Module`)
    as a builtin module.

    :param mod_cls: the Module class which will be treated as a builtin
        module in tracing
    """
    module_tracer.register_as_builtin(mod_cls)


wrap = _wrapped_function

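# Illustrative sketch (not part of the original file): ``wrap`` is the public
# alias of ``_wrapped_function``; applying it to a global function makes calls
# to that function show up as CallFunction exprs while tracing (``my_op`` is
# hypothetical):
#
#   my_op = wrap(my_op)
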
def _register_all_builtin_module():
    for sub_mod in [M, M.qat, M.quantized]:
        for m in getmembers(sub_mod):
            if (
                isclass(m[1])
                and issubclass(m[1], M.Module)
                and m[1] is not M.Sequential
                and m[1] is not M.ModuleList
            ):
                module_tracer.register_as_builtin(m[1])

    module_tracer.register_as_builtin(Observer)
    module_tracer.register_as_builtin(MinMaxObserver)
    module_tracer.register_as_builtin(SyncMinMaxObserver)
    module_tracer.register_as_builtin(ExponentialMovingAverageObserver)
    module_tracer.register_as_builtin(SyncExponentialMovingAverageObserver)
    module_tracer.register_as_builtin(HistogramObserver)
    module_tracer.register_as_builtin(PassiveObserver)
    module_tracer.register_as_builtin(LSQ)
    module_tracer.register_as_builtin(TQT)
    module_tracer.register_as_builtin(FakeQuantize)
    module_tracer.register_as_builtin(TM_FakeQuant)

def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
    """
    Traces module ``mod`` and returns the corresponding TracedModule.

    :param mod: the module to be converted to a TracedModule
    :param args: the positional arguments passed to the forward method of ``mod``
    :param kwargs: the keyword arguments passed to the forward method of ``mod``
    """
    assert active_module_tracer() is None
    assert isinstance(mod, Module)
    try:
        use_sym_shape = set_symbolic_shape(True)
        set_module_tracing()
        set_active_module_tracer(module_tracer(_wrapped_function))
        with active_module_tracer().patcher:
            global_scope = InternalGraph(name="")
            active_module_tracer().push_scope(global_scope)
            builder = TracedModuleBuilder(mod, True)
            name = mod._name if mod._name else mod.__class__.__name__
            NodeMixin.wrap_safe(builder, Input.make(name, ModuleNode))
            inputs, _ = tree_flatten((args, kwargs))
            for _, i in enumerate(inputs):
                # assert isinstance(i, Tensor), "not supported"
                if isinstance(i, RawTensor):
                    NodeMixin.wrap_safe(
                        i, Input.make("arg_{}".format(_), NodeMixin.get_wrapped_type(i))
                    )
            builder(*args, **kwargs)
            active_module_tracer().pop_scope()
            return builder.build()
    finally:
        set_symbolic_shape(use_sym_shape)
        set_active_module_tracer(None)
        unset_module_tracing()

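# Illustrative sketch (not part of the original file), assuming the usual
# public import path of this module:
#
#   import numpy as np
#   import megengine
#   import megengine.module as M
#   from megengine.traced_module import trace_module
#
#   net = M.Conv2d(3, 8, kernel_size=3)
#   tm = trace_module(net, megengine.Tensor(np.random.rand(1, 3, 32, 32)))
#   print(tm.graph)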
