You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

traced_module.py 64 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700170117021703170417051706170717081709171017111712171317141715171617171718171917201721172217231724172517261727172817291730173117321733173417351736173717381739174017411742174317441745174617471748174917501751
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import builtins
  9. import collections
  10. import copy
  11. import ctypes
  12. import fnmatch
  13. import functools
  14. import inspect
  15. import keyword
  16. import re
  17. import weakref
  18. from inspect import getcallargs, getmembers, isclass, ismethod
  19. from itertools import chain
  20. from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, Union
  21. from megengine import tensor
  22. from .. import functional as F
  23. from .. import get_logger
  24. from .. import module as M
  25. from ..core._imperative_rt.core2 import Tensor as RawTensor
  26. from ..core._imperative_rt.core2 import (
  27. is_tracing_module,
  28. set_module_tracing,
  29. unset_module_tracing,
  30. )
  31. from ..core._trace_option import set_symbolic_shape
  32. from ..core.tensor.array_method import ArrayMethodMixin
  33. from ..module import Module
  34. from ..module.qat import QATModule
  35. from ..quantization.fake_quant import LSQ, TQT, FakeQuantize, _FakeQuantize
  36. from ..quantization.observer import (
  37. ExponentialMovingAverageObserver,
  38. HistogramObserver,
  39. MinMaxObserver,
  40. Observer,
  41. PassiveObserver,
  42. SyncExponentialMovingAverageObserver,
  43. SyncMinMaxObserver,
  44. )
  45. from ..tensor import Tensor
  46. from .expr import Apply, CallFunction, CallMethod, Constant, Expr, GetAttr, Input
  47. from .fake_quant import FakeQuantize as TM_FakeQuant
  48. from .module_tracer import (
  49. PatchedFn,
  50. Patcher,
  51. active_module_tracer,
  52. get_tensor_wrapable_method,
  53. module_tracer,
  54. set_active_module_tracer,
  55. )
  56. from .node import ModuleNode, Node, NodeMixin, TensorNode
  57. from .pytree import ArgsIndex, tree_flatten
  58. from .utils import replace_container_with_module_container
# Module-level logger, obtained from MegEngine's ``get_logger`` under this
# module's ``__name__``; used below for informational messages.
logger = get_logger(__name__)
  60. def _is_builtin_name(name: str) -> bool:
  61. return (
  62. name in builtins.__dict__
  63. or name in keyword.kwlist
  64. or name in {"inf", "nan", "NoneType"}
  65. )
  66. def _is_leaf(node):
  67. assert isinstance(node, RawTensor), "doesn't support {} in return values".format(
  68. type(node)
  69. )
  70. return isinstance(node, RawTensor)
  71. _enable_node_to_tensor = False
  72. def _convert_node_flag():
  73. return _enable_node_to_tensor
  74. def _set_convert_node_flag(flag: bool = False):
  75. global _enable_node_to_tensor
  76. pre_flag = _enable_node_to_tensor
  77. _enable_node_to_tensor = flag
  78. return pre_flag
  79. def _node_to_tensor(*args, **kwargs):
  80. tensors = []
  81. nodes, tree_def = tree_flatten((args, kwargs))
  82. for n in nodes:
  83. if isinstance(n, TensorNode):
  84. if n.top_graph is not None:
  85. active_module_tracer().current_scope()._add_input(n)
  86. value = n.value
  87. if value is None:
  88. flag = _set_convert_node_flag(False)
  89. unset_module_tracing()
  90. value = F.zeros(shape=n._shape, dtype=n._dtype)
  91. set_module_tracing()
  92. _set_convert_node_flag(flag)
  93. orig_n = NodeMixin.get(value, None)
  94. if orig_n is None or "setitem" not in orig_n._name:
  95. NodeMixin.wrap_safe(value, n)
  96. tensors.append(value)
  97. else:
  98. tensors.append(n)
  99. tensors = tree_def.unflatten(tensors)
  100. return tensors
  101. def _tensor_to_node(tensors):
  102. if tensors is None:
  103. return None
  104. nodes = []
  105. tensors, out_def = tree_flatten(tensors)
  106. for t in tensors:
  107. if isinstance(t, Tensor):
  108. n = NodeMixin.get(t, None)
  109. if isinstance(n, TensorNode):
  110. n.value = t
  111. nodes.append(n)
  112. else:
  113. nodes.append(t)
  114. else:
  115. nodes.append(t)
  116. nodes = out_def.unflatten(nodes)
  117. return nodes
  118. def _wrap_method_to_tensor_node():
  119. def _any_method(name):
  120. def _any(*args, **kwargs):
  121. args, kwargs = _node_to_tensor(*args, **kwargs)
  122. attr = getattr(args[0], name)
  123. outs = attr
  124. if callable(attr):
  125. outs = attr(*(args[1:]), **kwargs)
  126. if name == "__setitem__":
  127. _node_to_tensor(outs)
  128. return None
  129. outs = _tensor_to_node(outs)
  130. return outs
  131. return _any
  132. tensor_method_patch = []
  133. for method in get_tensor_wrapable_method():
  134. patch = PatchedFn(TensorNode, method)
  135. if type(getattr(Tensor, method)) == property:
  136. patch.set_func(property(_any_method(method)))
  137. else:
  138. patch.set_func(_any_method(method))
  139. tensor_method_patch.append(patch)
  140. return tensor_method_patch
  141. def _convert_node_and_tensor(orig_func):
  142. @functools.wraps(orig_func)
  143. def _convert(*args, **kwargs):
  144. if _convert_node_flag() and is_tracing_module():
  145. args, kwargs = _node_to_tensor(*args, **kwargs)
  146. rst = orig_func(*args, **kwargs, method_func=_convert)
  147. rst = _tensor_to_node(rst)
  148. return rst
  149. else:
  150. rst = orig_func(*args, **kwargs)
  151. return rst
  152. return _convert
def _wrap_mnode_getattr(orig_getattr):
    """Build the ``ModuleNode.__getattr__`` replacement patched in by ``_InsertExprs``.

    The returned wrapper forwards attribute access on a ``ModuleNode`` to the
    object it owns (``self.owner``) and records the access in the active
    module tracer:

    * the module node is registered as an input of the current scope when it
      belongs to a graph;
    * a ``Module`` attribute that is not yet a ``TracedModuleBuilder`` is
      replaced on the owner by ``TracedModuleBuilder(attr)``;
    * ``NodeMixin``/``RawTensor`` attributes are bound, via ``NodeMixin.wrap``,
      to a callback that makes a ``GetAttr`` expr whose ``orig_name`` is the
      tracer's full name for the attribute made relative to the current scope
      (or just ``name`` when no full name is known);
    * for ``NodeMixin``/``RawTensor`` attributes the bound node is returned
      (a ``ModuleNode`` gets a weak reference to the attribute as owner); any
      other attribute value is returned as is.

    ``orig_getattr`` is only used for ``functools.wraps`` metadata; it is never
    called.
    """

    @functools.wraps(orig_getattr)
    def wraped_fn(self, name):
        obj = self.owner
        if self.top_graph is not None:
            active_module_tracer().current_scope()._add_input(self)
        attr = getattr(obj, name)
        node = attr
        full_name = None
        # The tracer keeps a dotted name per object id (see ``_init_id2name``).
        if id(attr) in active_module_tracer().id2name:
            full_name = active_module_tracer().id2name[id(attr)]
        if not isinstance(attr, TracedModuleBuilder):
            if isinstance(attr, Module):
                # Swap the sub-module for a builder on the owner and keep the
                # name mapping for the builder's id.
                attr = TracedModuleBuilder(attr)
                setattr(obj, name, attr)
                active_module_tracer().id2name[id(attr)] = full_name
            if isinstance(attr, (NodeMixin, RawTensor)):
                if full_name:
                    scope_name = active_module_tracer().current_scope()._module_name
                    if scope_name:
                        # Drop the "<scope_name>." prefix: name relative to scope.
                        full_name = full_name[len(scope_name) + 1 :]
                    else:
                        full_name = name
                else:
                    full_name = name
                NodeMixin.wrap(
                    attr,
                    lambda: GetAttr.make(
                        self,
                        name,
                        type=NodeMixin.get_wrapped_type(attr),
                        orig_name=full_name,
                    ),
                )
        if isinstance(attr, (NodeMixin, RawTensor)):
            node = NodeMixin.get(attr)
        if isinstance(node, ModuleNode):
            node._owner = weakref.ref(attr)
        return node

    return wraped_fn
  193. def _wrap_mnode_call(orig_call):
  194. @functools.wraps(orig_call)
  195. def wraped_fn(self, *args, **kwargs):
  196. obj = self.owner
  197. if self.top_graph is not None:
  198. active_module_tracer().current_scope()._add_input(self)
  199. rst = obj(*args, **kwargs)
  200. return rst
  201. return wraped_fn
  202. def _init_id2name(mod: Module, prefix: str = ""):
  203. id2name = {
  204. id(m): "%s.%s" % (prefix, key)
  205. for key, m in chain(
  206. mod.named_modules(), mod.named_parameters(), mod.named_buffers()
  207. )
  208. }
  209. return id2name
  210. class _InsertExprs:
  211. def __init__(self, graph, expr: Optional[Expr] = None):
  212. self.graph = graph
  213. self.global_scope = InternalGraph(
  214. graph._name, graph._prefix_name, graph._module_name
  215. )
  216. self.global_scope._used_names.update(graph._used_names)
  217. self.expr = expr
  218. self._tensor_method_patch = None
  219. def __enter__(self):
  220. self.use_sym_shape = set_symbolic_shape(True)
  221. set_module_tracing()
  222. _set_convert_node_flag(True)
  223. assert active_module_tracer() is None
  224. module = self.graph.inputs[0].owner
  225. _wrap_func = lambda x: _convert_node_and_tensor(_wrapped_function(x))
  226. set_active_module_tracer(
  227. module_tracer(_wrap_func, _init_id2name(module, self.graph._module_name))
  228. )
  229. active_module_tracer().patcher.__enter__()
  230. for cls, name, func in [
  231. [ModuleNode, "__getattr__", _wrap_mnode_getattr],
  232. [ModuleNode, "__call__", _wrap_mnode_call],
  233. [TracedModuleBuilder, "__call__", _convert_node_and_tensor],
  234. ]:
  235. active_module_tracer().patcher.patch_function(cls, name, func)
  236. self._tensor_method_patch = _wrap_method_to_tensor_node()
  237. active_module_tracer().push_scope(self.global_scope)
  238. def __exit__(self, ty, va, tr):
  239. if va is not None:
  240. return False
  241. set_symbolic_shape(self.use_sym_shape)
  242. unset_module_tracing()
  243. active_module_tracer().patcher.__exit__(ty, va, tr)
  244. _set_convert_node_flag(False)
  245. while self._tensor_method_patch:
  246. pf = self._tensor_method_patch.pop()
  247. pf.set_func(pf.origin_fn)
  248. module = self.graph.inputs[0].owner
  249. for mod, parent in module.modules(with_parent=True):
  250. name = mod._name
  251. if isinstance(mod, TracedModuleBuilder):
  252. mod = mod.build()
  253. if hasattr(mod, "graph"):
  254. for node in mod.graph.nodes():
  255. node.value = None
  256. setattr(parent, name, mod)
  257. set_active_module_tracer(None)
  258. for node in self.global_scope.nodes():
  259. node.value = None
  260. extra_inp_nodes = set(self.global_scope.inputs)
  261. max_inp_expr_idx = -1
  262. for node in extra_inp_nodes:
  263. assert (
  264. node.top_graph == self.graph
  265. ), "The input node ({}) is not in the graph ({})".format(node, self.graph)
  266. if isinstance(node, TensorNode) and node.expr in self.graph._exprs:
  267. max_inp_expr_idx = max(
  268. max_inp_expr_idx, self.graph._exprs.index(node.expr)
  269. )
  270. max_inp_expr_idx += 1
  271. insert_index = -1
  272. if self.expr is not None:
  273. insert_index = self.graph._exprs.index(self.expr)
  274. insert_index += 1
  275. if insert_index < max_inp_expr_idx:
  276. insert_index = max_inp_expr_idx
  277. anchor_index = insert_index - 1
  278. if anchor_index >= 0:
  279. logger.info(
  280. "The new expr will be inserted after ( {} )".format(
  281. self.graph._exprs[anchor_index]
  282. )
  283. )
  284. for expr in self.global_scope._exprs:
  285. self.graph._exprs.insert(insert_index, expr)
  286. insert_index += 1
  287. self.graph._used_names.update(self.global_scope._used_names)
  288. graph = self.graph
  289. while graph.top_graph is not None:
  290. graph = graph.top_graph
  291. graph.inputs[0].owner._update_ref()
  292. return True
  293. class InternalGraph:
  294. r"""``InternalGraph`` is a graph consist of ``Node`` and ``Expr``, it is used to represent the execution procedure of Module's forward method.
  295. Attributes:
  296. _exprs: List of Exprs in order of execution
  297. _inputs: Input Nodes of InternalGraph
  298. _outputs: Output Nodes of InternalGraph
  299. """
  300. _exprs = None # type: List[Expr]
  301. _inputs = None # type: List[Node]
  302. _outputs = None # type: List[Node]
  303. _top_graph = None
  304. def __init__(self, name: str = None, prefix_name: str = "", module_name: str = ""):
  305. self._exprs = []
  306. self._inputs = []
  307. self._outputs = []
  308. self._watch_point = []
  309. self._end_point = []
  310. self._used_names = {}
  311. self._rst = collections.defaultdict(list)
  312. self._name = name
  313. self._prefix_name = prefix_name
  314. self._module_name = module_name
  315. def _insert(self, expr):
  316. self._exprs.append(expr)
  317. def _create_unique_name(self, name: str) -> str:
  318. assert isinstance(name, str), "The name must be a str"
  319. name = re.sub("[^0-9a-zA-Z_]+", "_", name)
  320. if name[0].isdigit():
  321. name = "_{}".format(name)
  322. while name in self._used_names or _is_builtin_name(name):
  323. match = re.match(r"(.*)_(\d+)$", name)
  324. if match is None:
  325. name = name + "_1"
  326. else:
  327. base, num = match.group(1, 2)
  328. name = "{}_{}".format(base, int(num) + 1)
  329. self._used_names.setdefault(name)
  330. return name
  331. @property
  332. def inputs(self):
  333. return self._inputs
  334. @property
  335. def outputs(self):
  336. return self._outputs
  337. @property
  338. def top_graph(self):
  339. if self._top_graph:
  340. return self._top_graph()
  341. return None
  342. def exprs(self, recursive=True):
  343. return ExprFilter(_expr_iter(self, recursive))
  344. def nodes(self, recursive=True):
  345. return NodeFilter(_node_iter(self, recursive))
  346. def get_function_by_type(self, func: Callable = None, recursive=True):
  347. return self.exprs(recursive).call_function(func)
  348. def get_method_by_type(self, method: str = None, recursive=True):
  349. return self.exprs(recursive).call_method(method)
  350. def get_expr_by_id(self, expr_id: List[int] = None, recursive=True):
  351. return self.exprs(recursive).expr_id(expr_id)
  352. def get_module_by_type(self, module_cls: Module, recursive=True):
  353. assert issubclass(module_cls, Module)
  354. return self.nodes(recursive).type(module_cls, ModuleNode)
  355. def get_node_by_id(self, node_id: List[int] = None, recursive=True):
  356. return self.nodes(recursive).node_id(node_id)
  357. def get_node_by_name(
  358. self, name: str = None, ignorecase: bool = True, recursive=True
  359. ):
  360. return self.nodes(recursive).name(name, ignorecase)
  361. def _add_input(self, i):
  362. self._inputs.append(i)
  363. def _add_output(self, o):
  364. self._outputs.append(o)
  365. def _replace_inputs_outputs(self, repl_dict, prefix_name="", module_name=""):
  366. for node, repl_node in repl_dict.items():
  367. assert node in self._inputs or node in self._outputs
  368. for i in node.users:
  369. if i not in repl_node.users:
  370. repl_node.users.append(i)
  371. for idx, i in enumerate(self._inputs):
  372. if i in repl_dict:
  373. self._inputs[idx] = repl_dict[i]
  374. for idx, o in enumerate(self._outputs):
  375. if o in repl_dict:
  376. repl_dict[o]._orig_name = "{}{}".format(module_name, o._orig_name)
  377. self._outputs[idx] = repl_dict[o]
  378. for expr in self._exprs:
  379. for idx, i in enumerate(expr.inputs):
  380. assert isinstance(
  381. i._name, str
  382. ), "The node ({}) name must be a str".format(i)
  383. if i in repl_dict:
  384. expr.inputs[idx] = repl_dict[i]
  385. elif isinstance(i, TensorNode) and prefix_name not in i._name:
  386. if i.top_graph != active_module_tracer().current_scope():
  387. i._name = (
  388. active_module_tracer()
  389. .current_scope()
  390. ._create_unique_name(prefix_name + i._name.lstrip("_"))
  391. )
  392. i._orig_name = "{}{}".format(module_name, i._orig_name)
  393. for idx, o in enumerate(expr.outputs):
  394. assert isinstance(
  395. o._name, str
  396. ), "The node ({}) name must be a str".format(i)
  397. if o in repl_dict:
  398. expr.outputs[idx] = repl_dict[o]
  399. expr.outputs[idx].expr = expr
  400. elif isinstance(o, TensorNode) and prefix_name not in i._name:
  401. if o.top_graph != active_module_tracer().current_scope():
  402. o._name = (
  403. active_module_tracer()
  404. .current_scope()
  405. ._create_unique_name(prefix_name + o._name.lstrip("_"))
  406. )
  407. o._orig_name = "{}{}".format(module_name, o._orig_name)
  408. def get_dep_exprs(self, nodes: Sequence[Node]) -> List[Expr]:
  409. if not isinstance(nodes, Sequence):
  410. nodes = (nodes,)
  411. ret = list()
  412. queue = list(nodes)
  413. visited_queue = list()
  414. while queue:
  415. node = queue.pop()
  416. visited_queue.append(node)
  417. expr = node.expr
  418. if expr not in ret:
  419. ret.append(expr)
  420. for i in expr.inputs:
  421. if i not in queue and i not in visited_queue:
  422. queue.append(i)
  423. return ret
  424. def reset_inputs(self, *args, **kwargs):
  425. forma_mnode = self.inputs[0]
  426. actual_mnodes = forma_mnode.actual_node
  427. call_nodes = []
  428. for n in actual_mnodes:
  429. for c_expr in n.users:
  430. if isinstance(c_expr, CallMethod) and c_expr.method == "__call__":
  431. call_nodes.append((c_expr, n))
  432. moudle = forma_mnode.owner
  433. assert moudle._is_top, "reset_inputs only support the top-level graph"
  434. inputs, tree_def = tree_flatten(((moudle, *args), kwargs))
  435. def create_node(val: Tensor):
  436. node = Input(type=TensorNode).outputs[0]
  437. node.shape = val.shape
  438. node.dtype = val.dtype
  439. return node
  440. formal_node_inputs = [
  441. forma_mnode,
  442. ]
  443. org_argdef = list(moudle.argdef_graph_map.keys())[0]
  444. if call_nodes:
  445. org_argdef = call_nodes[0][0].arg_def
  446. for v in inputs[1:]:
  447. assert isinstance(v, RawTensor)
  448. formal_node_inputs.append(create_node(v))
  449. actual_nodes = []
  450. for e, n in call_nodes:
  451. e.arg_def = tree_def
  452. actual_node_inputs = [
  453. n,
  454. ]
  455. for v in inputs[1:]:
  456. actual_node_inputs.append(create_node(v))
  457. for org_n in e.inputs:
  458. org_n.users.pop(e)
  459. e.inputs[:] = actual_node_inputs
  460. e.const_val = []
  461. actual_nodes.append(actual_node_inputs[1:])
  462. self._inputs[:] = formal_node_inputs
  463. moudle.argdef_graph_map[tree_def] = moudle.argdef_graph_map.pop(org_argdef)
  464. moudle.argdef_outdef_map[tree_def] = moudle.argdef_outdef_map.pop(org_argdef)
  465. # return formal_node_inputs[1:], actual_nodes
  466. return formal_node_inputs[1:]
  467. def add_input_node(self, shape, dtype="float32", name="args"):
  468. forma_mnode = self.inputs[0]
  469. actual_mnodes = forma_mnode.actual_node
  470. moudle = forma_mnode.owner
  471. assert moudle._is_top, "add_input_node only support the top-level graph"
  472. call_nodes = []
  473. for n in actual_mnodes:
  474. for c_expr in n.users:
  475. if isinstance(c_expr, CallMethod) and c_expr.method == "__call__":
  476. call_nodes.append(c_expr)
  477. def create_node(name=None, is_input: bool = True):
  478. if is_input:
  479. node = Input(type=TensorNode, name=name).outputs[0]
  480. else:
  481. node = TensorNode(expr=None, name=None)
  482. node.shape = shape
  483. node.dtype = dtype
  484. return node
  485. org_argdef = list(moudle.argdef_graph_map.keys())[0]
  486. if call_nodes:
  487. org_argdef = call_nodes[0].arg_def
  488. args, kwargs = org_argdef.unflatten(self._inputs)
  489. formal_inp_node = create_node(self._create_unique_name(name), True)
  490. inputs, tree_def = tree_flatten(
  491. ((*args, formal_inp_node), kwargs),
  492. is_const_leaf=lambda x: not isinstance(x, (TensorNode, ModuleNode)),
  493. )
  494. self._inputs[:] = inputs[:]
  495. actual_inp_nodes = []
  496. for e in call_nodes:
  497. args, kwargs = e.unflatten_args(e.inputs)
  498. args = args + (create_node(False),)
  499. inputs, tree_def = tree_flatten(
  500. (args, kwargs),
  501. is_const_leaf=lambda x: not isinstance(x, (TensorNode, ModuleNode)),
  502. )
  503. e.inputs[:] = inputs[:]
  504. e.arg_def = tree_def
  505. actual_inp_nodes.append(args[-1])
  506. moudle.argdef_graph_map[tree_def] = moudle.argdef_graph_map.pop(org_argdef)
  507. moudle.argdef_outdef_map[tree_def] = moudle.argdef_outdef_map.pop(org_argdef)
  508. # return formal_inp_node, actual_inp_nodes
  509. return formal_inp_node
  510. def reset_outputs(self, outputs):
  511. outputs, out_def = tree_flatten(
  512. outputs, is_leaf=lambda x: isinstance(x, TensorNode),
  513. )
  514. forma_mnode = self.inputs[0]
  515. moudle = forma_mnode.owner
  516. assert moudle._is_top, "reset_outputs only support the top-level graph"
  517. actual_mnodes = forma_mnode.actual_node
  518. call_nodes = []
  519. for n in actual_mnodes:
  520. for c_expr in n.users:
  521. if isinstance(c_expr, CallMethod) and c_expr.method == "__call__":
  522. call_nodes.append((c_expr))
  523. def create_node(val: TensorNode, expr: Expr):
  524. node = TensorNode(expr)
  525. node.shape = val.shape
  526. node.dtype = val.dtype
  527. return node
  528. tree_def = list(moudle.argdef_graph_map.keys())[0]
  529. if call_nodes:
  530. tree_def = call_nodes[0].arg_def
  531. actual_nodes = []
  532. for e in call_nodes:
  533. actual_node_outputs = []
  534. for v in outputs:
  535. actual_node_outputs.append(create_node(v, e))
  536. e.outputs[:] = actual_node_outputs
  537. e.out_def = out_def
  538. actual_nodes.append(actual_node_outputs)
  539. self._outputs[:] = outputs
  540. moudle.argdef_outdef_map[tree_def] = out_def
  541. return actual_nodes
  542. def add_output_node(self, node: TensorNode):
  543. forma_mnode = self.inputs[0]
  544. moudle = forma_mnode.owner
  545. assert moudle._is_top, "add_output_node only support the top-level graph"
  546. actual_mnodes = forma_mnode.actual_node
  547. call_nodes = []
  548. for n in actual_mnodes:
  549. for c_expr in n.users:
  550. if isinstance(c_expr, CallMethod) and c_expr.method == "__call__":
  551. call_nodes.append((c_expr))
  552. def create_node(val: TensorNode, expr: Expr):
  553. node = TensorNode(expr)
  554. node.shape = val.shape
  555. node.dtype = val.dtype
  556. return node
  557. tree_def = list(moudle.argdef_graph_map.keys())[0]
  558. if call_nodes:
  559. tree_def = call_nodes[0].arg_def
  560. org_out_def = moudle.argdef_outdef_map[tree_def]
  561. org_outs = org_out_def.unflatten(self._outputs)
  562. outputs, out_def = tree_flatten(
  563. (org_outs, node), is_leaf=lambda x: isinstance(x, TensorNode),
  564. )
  565. self._outputs[:] = outputs
  566. actual_out_nodes = []
  567. for e in call_nodes:
  568. actual_node = create_node(node, e)
  569. org_outs = org_out_def.unflatten(e.outputs)
  570. outputs, out_def = tree_flatten(
  571. (org_outs, actual_node), is_leaf=lambda x: isinstance(x, TensorNode),
  572. )
  573. e.outputs[:] = outputs
  574. e.out_def = out_def
  575. actual_out_nodes.append(actual_node)
  576. moudle.argdef_outdef_map[tree_def] = out_def
  577. return actual_out_nodes
  578. def insert_exprs(self, expr: Optional[Expr] = None):
  579. if expr is not None:
  580. assert expr.top_graph == self, "Expr to insert after is not in graph."
  581. return _InsertExprs(self, expr)
  582. def replace_node(self, repl_dict: Dict[Node, Node]):
  583. while repl_dict:
  584. node, repl_node = repl_dict.popitem()
  585. # check graph inputs and outputs
  586. # assert node not in self.inputs, "Cannot replace inputs"
  587. for i, n in enumerate(self.outputs):
  588. if n is node:
  589. self.outputs[i] = repl_node
  590. # update users of node and repl_node
  591. # update inputs of expr in node.users
  592. graph = repl_node.top_graph
  593. assert graph is not None
  594. index = graph._exprs.index(repl_node.expr)
  595. dep_exprs = self.get_dep_exprs(repl_node)
  596. i = 0
  597. while i < len(node.users):
  598. n = node.users[i]
  599. if n in graph._exprs and index >= graph._exprs.index(n):
  600. i += 1
  601. continue
  602. if n in dep_exprs:
  603. logger.info("Find a loop: ignore this replacement once")
  604. logger.info("node: %s" % node.__repr__())
  605. logger.info("expr: %s" % n.__repr__())
  606. i += 1
  607. continue
  608. repl_node.users.append(n)
  609. node.users.pop(i)
  610. idx = n.inputs.index(node)
  611. n.inputs[idx] = repl_node
  612. def compile(self):
  613. """Delete unused expr."""
  614. dep_exprs = self.get_dep_exprs(self.outputs)
  615. i = 0
  616. while i < len(self._exprs):
  617. expr = self._exprs[i]
  618. if expr in dep_exprs or expr._disable_remove:
  619. i += 1
  620. continue
  621. for n in expr.inputs:
  622. n.users.remove(expr)
  623. self._exprs.remove(expr)
  624. def interpret(self, *inputs):
  625. node2value = {}
  626. end_nodes_set = set(self._end_point)
  627. endnode2value = {}
  628. def get_all_endnode_val(n, v):
  629. if n in end_nodes_set:
  630. endnode2value[n] = v
  631. end_nodes_set.remove(n)
  632. return not end_nodes_set
  633. return False
  634. ref_count = lambda n: len(n.users) + (1 if n in self._outputs else 0)
  635. for n, v in zip(self._inputs, inputs):
  636. if ref_count(n) > 0:
  637. node2value[n] = [v, ref_count(n)]
  638. if n in self._watch_point:
  639. self._rst[n].append(v)
  640. if n in self._end_point and get_all_endnode_val(n, v):
  641. return list(endnode2value[i] for i in self._end_point)
  642. for expr in self._exprs:
  643. values = expr.interpret(*list(node2value[i][0] for i in expr.inputs))
  644. for n in expr.inputs:
  645. node2value[n][1] -= 1
  646. if node2value[n][1] == 0:
  647. node2value.pop(n)
  648. if values is not None:
  649. for n, v in zip(expr.outputs, values):
  650. if ref_count(n) > 0:
  651. node2value[n] = [v, ref_count(n)]
  652. if n in self._watch_point:
  653. self._rst[n] = v
  654. if self._end_point and get_all_endnode_val(n, v):
  655. return list(endnode2value[i] for i in self._end_point)
  656. return list(node2value[i][0] for i in self._outputs)
  657. def eval(self, *inputs):
  658. assert len(inputs) == len(self._inputs) - 1
  659. inp = [self._inputs[0].owner] + list(inputs)
  660. return self.interpret(*inp)
  661. def __repr__(self):
  662. return self.__format__()
  663. def __format__(self, format_spec: str = "") -> str:
  664. saved_format_spec = Node.set_format_spec(format_spec)
  665. name = ""
  666. if self._name:
  667. name = "%s.Graph" % self._name
  668. res = "{} ({}) {{\n\t{}\n\treturn {}\n}}".format(
  669. name,
  670. ", ".join(str(i) for i in self._inputs),
  671. "\n\t".join("{}".format(str(i)) for i in self._exprs),
  672. ", ".join(str(i) for i in self._outputs),
  673. )
  674. Node.set_format_spec(saved_format_spec)
  675. return res
  676. def __getstate__(self):
  677. state = self.__dict__.copy()
  678. if "_top_graph" in state:
  679. state.pop("_top_graph")
  680. return state
  681. def _get_meth_name(obj, func):
  682. tp = obj if isinstance(obj, type) else type(obj)
  683. for cls in tp.mro():
  684. for k, v in cls.__dict__.items():
  685. if v == func:
  686. return k
  687. return None
def _wrapped_function(orig_func):
    """Wrap ``orig_func`` so that calls made while module tracing is on are recorded.

    The returned ``wrapped_fn`` calls ``orig_func`` with the same arguments.
    When ``is_tracing_module()`` is true it also records the call into the
    current graph: as a ``CallMethod`` if ``args[0]`` (or its type) is a
    ``RawTensor`` subclass and ``method_func`` is found on it by name,
    otherwise as a ``CallFunction``. Tracing is switched off while recording
    and while ``orig_func`` runs, and switched back on before returning.
    """

    @functools.wraps(orig_func)
    def wrapped_fn(*args, **kwargs):
        # ``method_func`` identifies the callable to look up on args[0]'s class
        # to obtain the method name; it is stripped before calling orig_func.
        method_func = wrapped_fn
        if "method_func" in kwargs:
            method_func = kwargs.pop("method_func")
        if is_tracing_module():
            # Pause tracing so the bookkeeping below and orig_func itself are
            # not recorded again.
            unset_module_tracing()
            inputs, tree_def = tree_flatten((args, kwargs))
            # Any tensor / NodeMixin leaf that has no node yet is captured as a constant.
            for i in inputs:
                if not NodeMixin.get(i, None):
                    if isinstance(i, (RawTensor, NodeMixin)):
                        NodeMixin.wrap_safe(i, Constant.make(i))
            meth_name, arg_type = None, None
            if args:
                meth_name = _get_meth_name(args[0], method_func)
                arg_type = args[0] if isinstance(args[0], type) else type(args[0])
            if meth_name and arg_type and issubclass(arg_type, RawTensor):
                # First flattened leaf is args[0]: the tensor instance, or the
                # tensor class itself for ``__new__``.
                self = inputs[0]
                if meth_name == "__new__":
                    if all([not isinstance(i, RawTensor) for i in inputs]):
                        # only trace Tensor.__new__() when there are tensors in args
                        set_module_tracing()
                        return orig_func(*args, **kwargs)
                    # NOTE(review): if the tensor is passed only by keyword,
                    # ``args`` has length 1 and ``args[1]`` raises IndexError - TODO confirm.
                    if isinstance(args[1], RawTensor):
                        node = NodeMixin.get(inputs[1])
                        inputs[1] = copy.copy(inputs[1])
                        # Copy inputs[1] so that ``tensor`` and ``Tensor(tensor)`` do not
                        # share the same m_tensor, which would give them the same
                        # _NodeMixin__node during tracing.
                        NodeMixin.wrap_safe(inputs[1], node)
                        args, kwargs = tree_def.unflatten(inputs)
                    call_node = CallMethod.make(self, meth_name)
                else:
                    call_node = CallMethod.make(NodeMixin.get(self), meth_name)
                call_node.add_inputs(inputs[1:])
            else:
                call_node = CallFunction.make(orig_func)
                call_node.add_inputs(inputs)
            call_node.arg_def = tree_def
            rst = orig_func(*args, **kwargs)
            # For in-place __setitem__, the receiver tensor is recorded as the output.
            if meth_name == "__setitem__":
                rst = self
            if rst is not None:
                outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
                call_node.out_def = out_def
            else:
                outputs = None
            call_node.add_outputs(outputs)
            set_module_tracing()
            return rst
        # Not tracing: plain pass-through.
        return orig_func(*args, **kwargs)

    return wrapped_fn
  739. class TracedModuleBuilder(NodeMixin):
  740. _mod = None # type: Module
  741. _body = None # type: InternalGraph
  742. _is_builtin = None # type: bool
  743. _argdef_graph_map = None # type: Dict[Treedef, "InternalGraph"]
  744. _argdef_outdef_map = None # type: Dict[Treedef, Treedef]
  745. nodes = None
  746. __builder_attributes__ = [
  747. "_mod",
  748. "_body",
  749. "_NodeMixin__node",
  750. "_is_builtin",
  751. "build",
  752. "_record_wrapped_nodes",
  753. "_argdef_graph_map",
  754. "_argdef_outdef_map",
  755. "nodes",
  756. "__class__",
  757. "__dict__",
  758. ]
  759. def __init__(self, mod, is_top_module=False):
  760. super(TracedModuleBuilder, self).__init__()
  761. assert isinstance(mod, Module)
  762. self._mod = mod
  763. self._body = None
  764. self._is_top = is_top_module
  765. self._is_builtin = (
  766. True
  767. if isinstance(mod, (Observer, _FakeQuantize))
  768. else module_tracer.is_builtin(mod)
  769. )
  770. if isinstance(self._mod, QATModule):
  771. unset_module_tracing()
  772. self._check_qat_module(self._mod)
  773. set_module_tracing()
  774. self._argdef_graph_map = {}
  775. self._argdef_outdef_map = {}
  776. self.nodes = set()
  777. # The builder will be passed to self._mod.forward as 'self' argument. If the 'forward' uses super().xxx to call method of its base classes, the trace procedure will throw exceprion, because the builder doesn't inherit from self._mod.__bases__.
  778. # modify self.__class__ and let the builder inherit from TracedModuleBuilder and mod.__class__.
  779. self.__class__ = type(
  780. "TracedModuleBuilder",
  781. (TracedModuleBuilder, mod.__class__),
  782. dict(TracedModuleBuilder.__dict__),
  783. )
  784. def _check_qat_module(self, qat_module):
  785. def isbuiltin(m):
  786. return m is None or module_tracer.is_builtin(m)
  787. if qat_module.with_act:
  788. act_observer = qat_module.act_observer
  789. act_fakequant = qat_module.act_fake_quant
  790. if not isbuiltin(act_observer) or not isbuiltin(act_fakequant):
  791. qparams = (
  792. act_observer.get_qparams()
  793. if hasattr(act_observer, "get_qparams")
  794. else act_fakequant.get_qparams()
  795. )
  796. dtype = (
  797. act_observer.dtype
  798. if hasattr(act_observer, "dtype")
  799. else act_fakequant.dtype
  800. )
  801. qat_module.act_observer = None
  802. qat_module.act_fake_quant = TM_FakeQuant(dtype)
  803. qat_module.act_fake_quant.set_qparams(qparams)
  804. if qat_module.with_weight:
  805. weight_observer = qat_module.weight_observer
  806. weight_fakequant = qat_module.weight_fake_quant
  807. if not isbuiltin(weight_observer) or not isbuiltin(weight_fakequant):
  808. qparams = (
  809. weight_observer.get_qparams()
  810. if hasattr(weight_observer, "get_qparams")
  811. else weight_fakequant.get_qparams()
  812. )
  813. dtype = (
  814. weight_observer.dtype
  815. if hasattr(weight_observer, "dtype")
  816. else weight_fakequant.dtype
  817. )
  818. qat_module.weight_observer = None
  819. qat_module.weight_fake_quant = TM_FakeQuant(dtype)
  820. qat_module.weight_fake_quant.set_qparams(qparams)
  821. def build(self):
  822. if self._is_builtin or isinstance(self._mod, TracedModule):
  823. if module_tracer.is_builtin(self._mod) or isinstance(
  824. self._mod, TracedModule
  825. ):
  826. mod_type = type(self._mod)
  827. else:
  828. assert isinstance(self._mod, (Observer, _FakeQuantize))
  829. mod_type = (
  830. Observer if isinstance(self._mod, Observer) else _FakeQuantize
  831. )
  832. for node in self.nodes:
  833. node.module_type = mod_type
  834. return self._mod
  835. else:
  836. is_qat = isinstance(self._mod, QATModule)
  837. traced_module = TracedModule(
  838. self._is_top, self._argdef_graph_map, self._argdef_outdef_map, is_qat
  839. )
  840. for _, g in self._argdef_graph_map.items():
  841. g.compile()
  842. for k, v in self.__dict__.items():
  843. if k not in TracedModuleBuilder.__builder_attributes__:
  844. if isinstance(v, TracedModuleBuilder):
  845. v = v.build()
  846. setattr(traced_module, k, v)
  847. elif isinstance(v, RawTensor):
  848. setattr(traced_module, k, v)
  849. if isinstance(self._mod, QATModule):
  850. unset_module_tracing()
  851. traced_module.with_act = self._mod.with_act
  852. traced_module.with_weight = self._mod.with_weight
  853. if not hasattr(traced_module, "act_fake_quant"):
  854. traced_module.act_fakequant = None
  855. if not hasattr(traced_module, "act_observer"):
  856. traced_module.act_observer = None
  857. if not hasattr(traced_module, "weight_fake_quant"):
  858. traced_module.weight_fakequant = None
  859. if not hasattr(traced_module, "weight_observer"):
  860. traced_module.weight_observer = None
  861. set_module_tracing()
  862. return traced_module
  863. def _record_wrapped_nodes(self, node):
  864. self.nodes.add(node)
  865. def __call__(self, *args, **kwargs):
  866. assert isinstance(self._mod, Module)
  867. # prepare args and kwargs for inner graph
  868. if "method_func" in kwargs:
  869. kwargs.pop("method_func")
  870. def mark_constant(x):
  871. node = NodeMixin.get(x, None)
  872. if node is None: # capture as constant
  873. NodeMixin.wrap(x, lambda: Constant.make(x))
  874. inputs, tree_def = tree_flatten(((self, *args), kwargs))
  875. for i in inputs:
  876. mark_constant(i)
  877. callnode = CallMethod.make(NodeMixin.get(self))
  878. callnode.add_inputs(inputs[1:])
  879. callnode.arg_def = tree_def
  880. if (
  881. self._is_builtin
  882. or tree_def in self._argdef_graph_map
  883. or isinstance(self._mod, TracedModule)
  884. ):
  885. unset_module_tracing()
  886. rst = self._mod(*args, **kwargs)
  887. outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
  888. set_module_tracing()
  889. if self._is_builtin:
  890. self._body = None
  891. elif tree_def in self._argdef_graph_map:
  892. self._body = self._argdef_graph_map[tree_def]
  893. else:
  894. self._mod._is_top = False
  895. self._body = self._mod.graph
  896. else:
  897. self_node = None
  898. orig_self = NodeMixin.get(self)
  899. top_graph = active_module_tracer().current_scope()
  900. graph_prefix_name = top_graph._name
  901. if top_graph._prefix_name:
  902. graph_prefix_name = "{}_{}".format(
  903. top_graph._prefix_name, graph_prefix_name.lstrip("_")
  904. )
  905. module_name = orig_self._orig_name
  906. if top_graph._module_name:
  907. module_name = "{}.{}".format(top_graph._module_name, module_name)
  908. self._body = InternalGraph(
  909. orig_self._name, prefix_name=graph_prefix_name, module_name=module_name
  910. )
  911. active_module_tracer().push_scope(self._body)
  912. # rebind self to new input node
  913. if self_node:
  914. NodeMixin.wrap_safe(self, self_node)
  915. active_module_tracer().current_scope()._add_input(self_node)
  916. else:
  917. NodeMixin.wrap_safe(
  918. self,
  919. self_node
  920. if self_node
  921. else Input.make("self", NodeMixin.get_wrapped_type(self), ""),
  922. )
  923. origin_inp_node = [NodeMixin.get(i, None) for i in inputs[1:]]
  924. # prepare args and kwargs for inner graph
  925. index_args, index_kwargs = tree_def.unflatten(
  926. [
  927. ArgsIndex(0),
  928. *list(ArgsIndex(i + 1) for i in range(len(origin_inp_node))),
  929. ]
  930. )
  931. key2idx = getcallargs(type(self._mod).forward, *index_args, **index_kwargs)
  932. idx2key = {}
  933. for k, v in key2idx.items():
  934. if isinstance(v, ArgsIndex):
  935. idx2key[v.index] = k
  936. else:
  937. flatten_argidx, _ = tree_flatten(v)
  938. for _i, v in enumerate(flatten_argidx):
  939. if isinstance(v, ArgsIndex):
  940. idx2key[v.index] = k + "_%d" % _i
  941. def wrap(x, name):
  942. if isinstance(x, (RawTensor, NodeMixin)):
  943. NodeMixin.wrap(
  944. x,
  945. lambda: Input.make(
  946. type=NodeMixin.get_wrapped_type(x), name=name
  947. ),
  948. )
  949. return x
  950. args = [self]
  951. for i, v in enumerate(inputs[1:]):
  952. args.append(wrap(v, idx2key[i + 1]))
  953. args, kwargs = tree_def.unflatten(args)
  954. active_module_tracer().patcher.auto_patch(
  955. getattr(getattr(self._mod, "forward", self._mod), "__globals__", {})
  956. )
  957. rst = type(self._mod).forward(*args, **kwargs)
  958. if _convert_node_flag():
  959. rst = _node_to_tensor(rst)[0][0]
  960. outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
  961. for i in (
  962. outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,)
  963. ):
  964. active_module_tracer().current_scope()._add_output(NodeMixin.get(i))
  965. NodeMixin.wrap_safe(self, orig_self)
  966. for arg, node in zip(inputs[1:], origin_inp_node):
  967. if node:
  968. NodeMixin.wrap_safe(arg, node)
  969. active_module_tracer().pop_scope()
  970. # rebind output to outer graph
  971. callnode.out_def = out_def
  972. callnode.add_outputs(outputs)
  973. self._argdef_graph_map[callnode.arg_def] = self._body
  974. self._argdef_outdef_map[callnode.arg_def] = out_def
  975. return rst
  976. def __setattr__(self, name, value):
  977. object.__setattr__(self, name, value)
  978. def __repr__(self):
  979. return repr(self._mod)
  980. def __getattr__(self, name):
  981. if name not in self._mod.__dict__:
  982. attr = getattr(type(self._mod), name).__get__(self, type(self))
  983. else:
  984. attr = getattr(self._mod, name)
  985. full_name = None
  986. if id(attr) in active_module_tracer().id2name:
  987. full_name = active_module_tracer().id2name[id(attr)]
  988. if isinstance(attr, (List, Dict)):
  989. unset_module_tracing()
  990. has_module, m_container = replace_container_with_module_container(attr)
  991. if m_container:
  992. attr = m_container
  993. if has_module and not m_container:
  994. raise ValueError(
  995. "Can not trace the module that uses the same container to store Module and Non-Module objects "
  996. )
  997. set_module_tracing()
  998. if isinstance(attr, Module):
  999. attr = TracedModuleBuilder(attr)
  1000. if isinstance(attr, (Module, RawTensor)):
  1001. setattr(self, name, attr)
  1002. active_module_tracer().id2name[id(attr)] = full_name
  1003. if full_name:
  1004. scope_name = active_module_tracer().current_scope()._module_name
  1005. if scope_name:
  1006. full_name = full_name[len(scope_name) + 1 :]
  1007. else:
  1008. full_name = name
  1009. else:
  1010. full_name = name
  1011. NodeMixin.wrap(
  1012. attr,
  1013. lambda: GetAttr.make(
  1014. NodeMixin.get(self),
  1015. name,
  1016. type=NodeMixin.get_wrapped_type(attr),
  1017. orig_name=full_name,
  1018. ),
  1019. )
  1020. return attr
  1021. def __getattribute__(self, name):
  1022. if name in TracedModuleBuilder.__builder_attributes__:
  1023. return object.__getattribute__(self, name)
  1024. else:
  1025. wrapped = object.__getattribute__(self, name)
  1026. class_members = dict(inspect.getmembers(self.__class__))
  1027. if name in self._mod.__dict__:
  1028. mod_attr = getattr(self._mod, name)
  1029. if name in class_members:
  1030. if (
  1031. not isinstance(wrapped, TracedModuleBuilder)
  1032. and wrapped is not mod_attr
  1033. ):
  1034. wrapped = self.__getattr__(name)
  1035. if isinstance(wrapped, TracedModuleBuilder):
  1036. if not isinstance(mod_attr, (List, Dict)):
  1037. assert mod_attr is wrapped._mod
  1038. else:
  1039. assert mod_attr is wrapped
  1040. full_name = None
  1041. if id(mod_attr) in active_module_tracer().id2name:
  1042. full_name = active_module_tracer().id2name[id(mod_attr)]
  1043. scope_name = active_module_tracer().current_scope()._module_name
  1044. if full_name and scope_name:
  1045. full_name = full_name[len(scope_name) + 1 :]
  1046. else:
  1047. full_name = name
  1048. else:
  1049. full_name = name
  1050. # assert not self._is_builtin
  1051. if isinstance(wrapped, (NodeMixin, RawTensor)):
  1052. NodeMixin.wrap(
  1053. wrapped,
  1054. lambda: GetAttr.make(
  1055. NodeMixin.get(self),
  1056. name,
  1057. type=NodeMixin.get_wrapped_type(wrapped),
  1058. orig_name=full_name,
  1059. ),
  1060. )
  1061. return wrapped
  1062. class _expr_iter:
  1063. def __init__(self, graph: InternalGraph, recursive: bool = True):
  1064. self.graph = graph
  1065. self.recursive = recursive
  1066. def __iter__(self):
  1067. for expr in self.graph._exprs:
  1068. if isinstance(expr, CallMethod) and isinstance(expr.inputs[0], ModuleNode):
  1069. yield expr
  1070. if self.recursive and expr.graph is not None:
  1071. yield from expr.graph.exprs(self.recursive)
  1072. else:
  1073. yield expr
  1074. class _node_iter:
  1075. def __init__(self, graph: InternalGraph, recursive: bool = True) -> None:
  1076. nodes = []
  1077. node_ids = set()
  1078. for expr in graph.exprs(recursive):
  1079. for n in expr.inputs + expr.outputs:
  1080. if n._id in node_ids:
  1081. continue
  1082. nodes.append(n)
  1083. node_ids.add(n._id)
  1084. self.nodes = list(sorted(nodes, key=lambda x: x._id))
  1085. def __iter__(self):
  1086. for node in self.nodes:
  1087. yield node
  1088. class BaseFilter:
  1089. def __init__(self, expr_iter: Iterable):
  1090. self._iter = expr_iter
  1091. def __iter__(self):
  1092. return iter(self._iter)
  1093. def as_list(self):
  1094. return list(self)
  1095. def as_dict(self):
  1096. return collections.OrderedDict((i._id, i) for i in self)
  1097. def as_unique(self):
  1098. rst = self.as_list()
  1099. assert len(rst) == 1, "{} elements found".format(len(rst))
  1100. (expr,) = self
  1101. return expr
  1102. def as_count(self):
  1103. return sum(1 for _ in self)
  1104. class ExprFilter(BaseFilter):
  1105. def call_function(self, func):
  1106. return ExprFilterCallFunction(self, func)
  1107. def call_method(self, method):
  1108. return ExprFilterCallMethod(self, method)
  1109. def expr_id(self, expr_id: List[int]):
  1110. return ExprFilterExprId(self, expr_id)
  1111. class NodeFilter(BaseFilter):
  1112. def type(self, owner_type, node_type):
  1113. return NodeFilterType(self, owner_type, node_type)
  1114. def node_id(self, node_id: List[int]):
  1115. return NodeFilterNodeId(self, node_id)
  1116. def name(self, name: str, ignorecase: bool = True):
  1117. return NodeFilterName(self, name, ignorecase)
  1118. class NodeFilterType(NodeFilter):
  1119. def __init__(self, expr_iter, owner_type, node_type):
  1120. super().__init__(expr_iter)
  1121. self.owner_type = owner_type
  1122. self.node_type = node_type
  1123. def __iter__(self):
  1124. for node in self._iter:
  1125. if not isinstance(node, self.node_type):
  1126. continue
  1127. if not hasattr(node, "owner"):
  1128. continue
  1129. if isinstance(node.owner, self.owner_type):
  1130. yield node
  1131. class NodeFilterNodeId(NodeFilter):
  1132. def __init__(self, expr_iter, node_id: List[int]):
  1133. super().__init__(expr_iter)
  1134. if not isinstance(node_id, Sequence):
  1135. node_id = [node_id]
  1136. self.node_id = node_id
  1137. def __iter__(self):
  1138. for node in self._iter:
  1139. if node._id in self.node_id:
  1140. yield node
  1141. class NodeFilterName(NodeFilter):
  1142. _re = None
  1143. def __init__(self, node_iter, pattern, ignorecase):
  1144. super().__init__(node_iter)
  1145. self.pattern = pattern
  1146. self._re = self.make_re(pattern, ignorecase)
  1147. @classmethod
  1148. def make_re(cls, pattern, ignorecase=True):
  1149. assert isinstance(pattern, str), "bad pattern: {!r}".format(pattern)
  1150. assert isinstance(ignorecase, bool)
  1151. flags = 0
  1152. if ignorecase:
  1153. flags |= re.IGNORECASE
  1154. return re.compile(fnmatch.translate(pattern), flags=flags)
  1155. def __iter__(self):
  1156. for i in self._iter:
  1157. graph = i.top_graph
  1158. name = "{}_{}".format(graph._name, i._name.lstrip("_"))
  1159. if graph._prefix_name:
  1160. name = "{}_{}".format(graph._prefix_name, name.lstrip("_"))
  1161. if self.pattern == name or self._re.match(name):
  1162. yield i
  1163. class ExprFilterCallFunction(ExprFilter):
  1164. def __init__(self, expr_iter, func: Callable = None):
  1165. super().__init__(expr_iter)
  1166. self.func = func
  1167. def __iter__(self):
  1168. for expr in self._iter:
  1169. if not isinstance(expr, CallFunction):
  1170. continue
  1171. if self.func is None or expr.func == self.func:
  1172. yield expr
  1173. class ExprFilterCallMethod(ExprFilter):
  1174. def __init__(self, expr_iter, method: str = None):
  1175. super().__init__(expr_iter)
  1176. self.method = method
  1177. def __iter__(self):
  1178. for expr in self._iter:
  1179. if not isinstance(expr, CallMethod):
  1180. continue
  1181. if self.method is None or expr.method == self.method:
  1182. yield expr
  1183. class ExprFilterExprId(ExprFilter):
  1184. def __init__(self, expr_iter, expr_id: List[int]):
  1185. super().__init__(expr_iter)
  1186. if not isinstance(expr_id, Sequence):
  1187. expr_id = [expr_id]
  1188. self.expr_id = expr_id
  1189. def __iter__(self):
  1190. for expr in self._iter:
  1191. if expr._id in self.expr_id:
  1192. yield expr
  1193. class TracedModule(Module):
  1194. r"""`TracedModule` is the Module created by tracing normal module. It owns an argdef to graph(InternalGraph) map. The forward method of `TracedModule` will get a graph from `argdef_graph_map` according to the argdef of input args/kwargs and interpret it."""
  1195. # m_node = None # type: ModuleNode
  1196. argdef_graph_map = None
  1197. argdef_outdef_map = None
  1198. def __init__(self, is_top, argdef_graph_map, argdef_outdef_map, is_qat=False):
  1199. super(TracedModule, self).__init__()
  1200. self.argdef_graph_map = argdef_graph_map
  1201. self.argdef_outdef_map = argdef_outdef_map
  1202. self._is_top = is_top
  1203. self.watch_points = []
  1204. self.watch_node_value = {}
  1205. self.end_points = []
  1206. self.is_qat = is_qat
  1207. def forward(self, *args, **kwargs):
  1208. inputs, treedef = tree_flatten(((self, *args), kwargs))
  1209. assert treedef in self.argdef_graph_map
  1210. inputs = filter(
  1211. lambda i: isinstance(i, (Module, TracedModuleBuilder, RawTensor)), inputs
  1212. ) # allow TracedModuleBuilder for retrace.
  1213. outputs = self.argdef_graph_map[treedef].interpret(*inputs)
  1214. if self.watch_points:
  1215. self.watch_node_value = {}
  1216. for n in self.watch_points:
  1217. self.watch_node_value[n] = n.top_graph._rst.pop(n)
  1218. if self.end_points:
  1219. return outputs
  1220. out_def = self.argdef_outdef_map[treedef]
  1221. outputs = out_def.unflatten(outputs)
  1222. return outputs
  1223. def set_watch_points(self, nodes):
  1224. if not isinstance(nodes, Sequence):
  1225. nodes = [nodes]
  1226. self.watch_points = nodes
  1227. for n in nodes:
  1228. n.top_graph._watch_point.append(n)
  1229. def clear_watch_points(self):
  1230. for n in self.watch_points:
  1231. n.top_graph._watch_point = []
  1232. self.watch_points = []
  1233. self.watch_node_value = {}
  1234. def set_end_points(self, nodes):
  1235. if not isinstance(nodes, Sequence):
  1236. nodes = [nodes]
  1237. self.end_points = nodes
  1238. graphs = list(self.argdef_graph_map.values())
  1239. for n in nodes:
  1240. assert n.top_graph in graphs
  1241. n.top_graph._end_point.append(n)
  1242. def clear_end_points(self):
  1243. for n in self.end_points:
  1244. n.top_graph._end_point = []
  1245. self.end_points = []
  1246. @property
  1247. def graph(self) -> InternalGraph:
  1248. if self._is_top:
  1249. self._update_ref()
  1250. assert len(self.argdef_graph_map) == 1
  1251. return list(self.argdef_graph_map.values())[0]
  1252. def _update_ref(self, actual_node_map: Union[Dict] = None, top_graph=None):
  1253. for inp_def, graph in self.argdef_graph_map.items():
  1254. if top_graph is not None:
  1255. graph._top_graph = weakref.ref(top_graph)
  1256. for n in graph._inputs + graph.outputs:
  1257. n._top_graph = weakref.ref(graph)
  1258. graph._inputs[0]._owner = weakref.ref(self)
  1259. for i, n in enumerate(graph._inputs):
  1260. n.actual_node = []
  1261. if actual_node_map is not None and inp_def in actual_node_map.keys():
  1262. n.actual_node = list(list(zip(*(actual_node_map[inp_def])))[i])
  1263. node2obj = {}
  1264. next_actual_node_map = collections.defaultdict(
  1265. lambda: collections.defaultdict(list)
  1266. )
  1267. node2obj[graph._inputs[0]] = self
  1268. for expr in graph._exprs:
  1269. for n in expr.inputs + expr.outputs:
  1270. n._top_graph = weakref.ref(graph)
  1271. expr._top_graph = weakref.ref(graph)
  1272. if isinstance(expr, GetAttr) and isinstance(
  1273. expr.outputs[0], ModuleNode
  1274. ):
  1275. obj = getattr(node2obj[expr.inputs[0]], expr.name)
  1276. expr.outputs[0]._owner = weakref.ref(obj)
  1277. node2obj[expr.outputs[0]] = obj
  1278. if isinstance(expr, Constant) and isinstance(
  1279. expr.outputs[0], ModuleNode
  1280. ):
  1281. obj = expr.value
  1282. expr.outputs[0]._owner = weakref.ref(obj)
  1283. node2obj[expr.outputs[0]] = obj
  1284. if (
  1285. isinstance(expr, CallMethod)
  1286. and expr.method == "__call__"
  1287. and isinstance(expr.inputs[0], ModuleNode)
  1288. ):
  1289. obj = node2obj[expr.inputs[0]]
  1290. if expr.arg_def is not None:
  1291. next_actual_node_map[obj][expr.arg_def].append(expr.inputs)
  1292. for obj in node2obj.values():
  1293. if obj is self:
  1294. continue
  1295. mnode_map = None
  1296. if obj in next_actual_node_map.keys():
  1297. mnode_map = next_actual_node_map[obj]
  1298. if isinstance(obj, TracedModule):
  1299. obj._update_ref(mnode_map, graph)
  1300. def flatten(self):
  1301. r"""Get a new module, which eliminates ``GetAttr`` and has no hierarchy.
  1302. :return: :class:`TracedModule`
  1303. """
  1304. new_module = copy.deepcopy(self)
  1305. assert active_module_tracer() is None
  1306. id2name = _init_id2name(new_module, "self")
  1307. set_active_module_tracer(module_tracer(lambda x: x, {}))
  1308. active_module_tracer().push_scope(new_module.graph)
  1309. def _flatten_subgraph(
  1310. graph: InternalGraph,
  1311. module: Module,
  1312. call=None,
  1313. prefix_name="",
  1314. module_name="",
  1315. ):
  1316. if isinstance(prefix_name, str) and prefix_name and prefix_name[-1] != "_":
  1317. prefix_name += "_"
  1318. if isinstance(module_name, str) and module_name:
  1319. module_name += "."
  1320. if graph is None or module.is_qat:
  1321. assert not isinstance(module, TracedModule) or module.is_qat
  1322. const = Constant(module, id2name[id(module)])
  1323. m_node = call.inputs[0]
  1324. if m_node.top_graph != active_module_tracer().current_scope():
  1325. m_node._name = (
  1326. active_module_tracer()
  1327. .current_scope()
  1328. ._create_unique_name(prefix_name)
  1329. )
  1330. m_node._orig_name = id2name[id(module)][5:]
  1331. const.outputs[0] = m_node
  1332. const.outputs[0].expr = const
  1333. return [const, call]
  1334. if call is not None:
  1335. graph = copy.deepcopy(graph)
  1336. exprs = []
  1337. node2obj = {}
  1338. node2obj[graph._inputs[0]] = module
  1339. if call:
  1340. node2obj[call.inputs[0]] = module
  1341. # replace inputs for submodule's exprx
  1342. if call:
  1343. repl_dict = dict(zip(graph._inputs, call.inputs))
  1344. for ind, out in enumerate(graph.outputs):
  1345. if isinstance(out.expr, Input):
  1346. assert out in repl_dict
  1347. call_out = call.outputs[ind]
  1348. for expr in call.outputs[ind].users:
  1349. for index, inp in enumerate(expr.inputs):
  1350. if inp is call_out:
  1351. expr.inputs[index] = repl_dict[out]
  1352. repl_dict[out].users.append(expr)
  1353. continue
  1354. repl_dict[out] = call.outputs[ind]
  1355. graph._replace_inputs_outputs(repl_dict, prefix_name, module_name)
  1356. for expr in graph._exprs:
  1357. if isinstance(expr, GetAttr):
  1358. # replace GetAttr with Constant
  1359. if isinstance(expr.outputs[0], TensorNode):
  1360. const = Constant(getattr(node2obj[expr.inputs[0]], expr.name))
  1361. const.outputs = expr.outputs
  1362. const.outputs[0].expr = const
  1363. exprs.append(const)
  1364. elif isinstance(expr.outputs[0], ModuleNode):
  1365. node2obj[expr.outputs[0]] = getattr(
  1366. node2obj[expr.inputs[0]], expr.name
  1367. )
  1368. elif isinstance(expr, CallMethod):
  1369. obj_node = expr.inputs[0]
  1370. if isinstance(obj_node, ModuleNode):
  1371. pre_expr = expr.inputs[0].expr
  1372. if isinstance(pre_expr, GetAttr):
  1373. (obj,) = pre_expr.interpret(node2obj[pre_expr.inputs[0]])
  1374. expr_graph = (
  1375. obj.argdef_graph_map[expr.arg_def]
  1376. if hasattr(obj, "argdef_graph_map")
  1377. else None
  1378. )
  1379. exprs.extend(
  1380. _flatten_subgraph(
  1381. expr_graph,
  1382. obj,
  1383. expr,
  1384. prefix_name + obj_node._name.lstrip("_"),
  1385. module_name + obj_node._orig_name,
  1386. )
  1387. )
  1388. else:
  1389. # module has been replaced.
  1390. assert isinstance(pre_expr, Constant)
  1391. exprs.append(expr)
  1392. else:
  1393. exprs.append(expr)
  1394. else:
  1395. exprs.append(expr)
  1396. if call is not None:
  1397. for i in call.inputs:
  1398. i.users.remove(call)
  1399. return exprs
  1400. new_module.graph._exprs = _flatten_subgraph(new_module.graph, new_module)
  1401. new_module.graph.compile()
  1402. set_active_module_tracer(None)
  1403. for _id, expr in enumerate(new_module.graph._exprs):
  1404. expr._id = _id
  1405. total_node_id = 0
  1406. for i in new_module.graph._inputs:
  1407. i._id = total_node_id
  1408. total_node_id += 1
  1409. for expr in new_module.graph._exprs:
  1410. for o in expr.outputs:
  1411. o._id = total_node_id
  1412. total_node_id += 1
  1413. return new_module
  1414. def __getstate__(self):
  1415. d = self.__dict__
  1416. for k in Module.__dict__:
  1417. d.pop(k, None)
  1418. return d
  1419. def cpp_apply_module_trace(opdef, *args):
  1420. return Apply.apply_module_trace_hook(opdef, *args)
  1421. def register_as_builtin(mod_cls: Type[Module]) -> None:
  1422. r"""Registers class ``mod_cls`` (subclass of megengine.module.Module) as builtin module.
  1423. Args:
  1424. mod_cls: the Module class which will be threated as builtin module in tracing
  1425. """
  1426. module_tracer.register_as_builtin(mod_cls)
  1427. def wrap(func: Callable):
  1428. r"""Call this function to register func as a builtin function."""
  1429. assert callable(func), "func must be a callable"
  1430. assert hasattr(func, "__code__")
  1431. fn_name = func.__code__.co_name
  1432. currentframe = inspect.currentframe()
  1433. assert currentframe is not None
  1434. f = currentframe.f_back
  1435. assert f is not None
  1436. assert (
  1437. f.f_code.co_name == "<module>"
  1438. ), "wrap must be called at the top level of a module"
  1439. Patcher._builtin_functions.append((f.f_globals, fn_name))
  1440. return func
  1441. def _register_all_builtin_module():
  1442. for sub_mod in [M, M.qat, M.quantized]:
  1443. for m in getmembers(sub_mod):
  1444. if (
  1445. isclass(m[1])
  1446. and issubclass(m[1], M.Module)
  1447. and m[1] is not M.Sequential
  1448. ):
  1449. module_tracer.register_as_builtin(m[1])
  1450. module_tracer.register_as_builtin(Observer)
  1451. module_tracer.register_as_builtin(MinMaxObserver)
  1452. module_tracer.register_as_builtin(SyncMinMaxObserver)
  1453. module_tracer.register_as_builtin(ExponentialMovingAverageObserver)
  1454. module_tracer.register_as_builtin(SyncExponentialMovingAverageObserver)
  1455. module_tracer.register_as_builtin(HistogramObserver)
  1456. module_tracer.register_as_builtin(PassiveObserver)
  1457. module_tracer.register_as_builtin(LSQ)
  1458. module_tracer.register_as_builtin(TQT)
  1459. module_tracer.register_as_builtin(FakeQuantize)
  1460. module_tracer.register_as_builtin(TM_FakeQuant)
  1461. def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
  1462. r"""Traces module ``mod`` and returns corresponding TracedModule.
  1463. Args:
  1464. mod: the module will be converted to TracedModule
  1465. input: the positional arguments passed to forward method of ``mod``
  1466. kwargs: the keyword arguments passed to forward method of ``mod``
  1467. """
  1468. assert active_module_tracer() is None
  1469. assert isinstance(mod, Module)
  1470. try:
  1471. use_sym_shape = set_symbolic_shape(True)
  1472. set_module_tracing()
  1473. set_active_module_tracer(
  1474. module_tracer(_wrapped_function, _init_id2name(mod, "self"))
  1475. )
  1476. with active_module_tracer().patcher:
  1477. global_scope = InternalGraph(name="")
  1478. active_module_tracer().push_scope(global_scope)
  1479. builder = TracedModuleBuilder(mod, True)
  1480. name = mod._name if mod._name else mod.__class__.__name__
  1481. NodeMixin.wrap_safe(builder, Input.make(name, ModuleNode, orig_name="self"))
  1482. inputs, _ = tree_flatten((args, kwargs))
  1483. for _, i in enumerate(inputs):
  1484. # assert isinstance(i, Tensor), "not support "
  1485. if isinstance(i, RawTensor):
  1486. NodeMixin.wrap_safe(
  1487. i, Input.make("arg_{}".format(_), NodeMixin.get_wrapped_type(i))
  1488. )
  1489. builder(*args, **kwargs)
  1490. active_module_tracer().pop_scope()
  1491. return builder.build()
  1492. finally:
  1493. set_symbolic_shape(use_sym_shape)
  1494. set_active_module_tracer(None)
  1495. unset_module_tracing()

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。