
traced_module.py 66 kB

  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import builtins
  9. import collections
  10. import copy
  11. import ctypes
  12. import fnmatch
  13. import functools
  14. import inspect
  15. import keyword
  16. import re
  17. import weakref
  18. from inspect import getcallargs, getmembers, isclass, ismethod
  19. from itertools import chain
  20. from types import FunctionType
  21. from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, Union
  22. from megengine import tensor
  23. from .. import functional as F
  24. from .. import get_logger
  25. from .. import module as M
  26. from ..core._imperative_rt.core2 import Tensor as RawTensor
  27. from ..core._imperative_rt.core2 import (
  28. is_tracing_module,
  29. set_module_tracing,
  30. unset_module_tracing,
  31. )
  32. from ..core._trace_option import set_symbolic_shape
  33. from ..core.tensor.array_method import ArrayMethodMixin
  34. from ..module import Module
  35. from ..module.qat import QATModule
  36. from ..quantization.fake_quant import LSQ, TQT, FakeQuantize, _FakeQuantize
  37. from ..quantization.observer import (
  38. ExponentialMovingAverageObserver,
  39. HistogramObserver,
  40. MinMaxObserver,
  41. Observer,
  42. PassiveObserver,
  43. SyncExponentialMovingAverageObserver,
  44. SyncMinMaxObserver,
  45. )
  46. from ..tensor import Tensor
  47. from .expr import Apply, CallFunction, CallMethod, Constant, Expr, GetAttr, Input
  48. from .fake_quant import FakeQuantize as TM_FakeQuant
  49. from .module_tracer import (
  50. PatchedFn,
  51. Patcher,
  52. active_module_tracer,
  53. get_tensor_wrapable_method,
  54. module_tracer,
  55. set_active_module_tracer,
  56. )
  57. from .node import ModuleNode, Node, NodeMixin, TensorNode
  58. from .pytree import ArgsIndex, tree_flatten
  59. from .utils import replace_container_with_module_container
  60. logger = get_logger(__name__)
  61. def _is_builtin_name(name: str) -> bool:
  62. return (
  63. name in builtins.__dict__
  64. or name in keyword.kwlist
  65. or name in {"inf", "nan", "NoneType"}
  66. )
  67. def _is_leaf(node):
  68. assert isinstance(node, RawTensor), "doesn't support {} in return values".format(
  69. type(node)
  70. )
  71. return isinstance(node, RawTensor)
  72. _enable_node_to_tensor = False
  73. def _convert_node_flag():
  74. return _enable_node_to_tensor
  75. def _set_convert_node_flag(flag: bool = False):
  76. global _enable_node_to_tensor
  77. pre_flag = _enable_node_to_tensor
  78. _enable_node_to_tensor = flag
  79. return pre_flag
  80. def _node_to_tensor(*args, **kwargs):
  81. tensors = []
  82. nodes, tree_def = tree_flatten((args, kwargs))
  83. for n in nodes:
  84. if isinstance(n, TensorNode):
  85. if n.top_graph is not None:
  86. active_module_tracer().current_scope()._add_input(n)
  87. value = n.value
  88. if value is None:
  89. flag = _set_convert_node_flag(False)
  90. unset_module_tracing()
  91. value = F.zeros(shape=n._shape, dtype=n._dtype)
  92. set_module_tracing()
  93. _set_convert_node_flag(flag)
  94. orig_n = NodeMixin.get(value, None)
  95. if orig_n is None or "setitem" not in orig_n._name:
  96. NodeMixin.wrap_safe(value, n)
  97. tensors.append(value)
  98. else:
  99. tensors.append(n)
  100. tensors = tree_def.unflatten(tensors)
  101. return tensors
  102. def _tensor_to_node(tensors):
  103. if tensors is None:
  104. return None
  105. nodes = []
  106. tensors, out_def = tree_flatten(tensors)
  107. for t in tensors:
  108. if isinstance(t, Tensor):
  109. n = NodeMixin.get(t, None)
  110. if isinstance(n, TensorNode):
  111. n.value = t
  112. nodes.append(n)
  113. else:
  114. nodes.append(t)
  115. else:
  116. nodes.append(t)
  117. nodes = out_def.unflatten(nodes)
  118. return nodes
  119. def _wrap_method_to_tensor_node():
  120. def _any_method(name):
  121. def _any(*args, **kwargs):
  122. args, kwargs = _node_to_tensor(*args, **kwargs)
  123. attr = getattr(args[0], name)
  124. outs = attr
  125. if callable(attr):
  126. outs = attr(*(args[1:]), **kwargs)
  127. if name == "__setitem__":
  128. _node_to_tensor(outs)
  129. return None
  130. outs = _tensor_to_node(outs)
  131. return outs
  132. return _any
  133. tensor_method_patch = []
  134. for method in get_tensor_wrapable_method():
  135. patch = PatchedFn(TensorNode, method)
  136. if type(getattr(Tensor, method)) == property:
  137. patch.set_func(property(_any_method(method)))
  138. else:
  139. patch.set_func(_any_method(method))
  140. tensor_method_patch.append(patch)
  141. return tensor_method_patch
  142. def _convert_node_and_tensor(orig_func):
  143. @functools.wraps(orig_func)
  144. def _convert(*args, **kwargs):
  145. if _convert_node_flag() and is_tracing_module():
  146. args, kwargs = _node_to_tensor(*args, **kwargs)
  147. rst = orig_func(*args, **kwargs, method_func=_convert)
  148. rst = _tensor_to_node(rst)
  149. return rst
  150. else:
  151. rst = orig_func(*args, **kwargs)
  152. return rst
  153. return _convert
  154. def _wrap_mnode_getattr(orig_getattr):
  155. @functools.wraps(orig_getattr)
  156. def wraped_fn(self, name):
  157. obj = self.owner
  158. if self.top_graph is not None:
  159. active_module_tracer().current_scope()._add_input(self)
  160. attr = getattr(obj, name)
  161. node = attr
  162. full_name = None
  163. if id(attr) in active_module_tracer().id2name:
  164. full_name = active_module_tracer().id2name[id(attr)]
  165. if not isinstance(attr, TracedModuleBuilder):
  166. if isinstance(attr, Module):
  167. attr = TracedModuleBuilder(attr)
  168. setattr(obj, name, attr)
  169. active_module_tracer().id2name[id(attr)] = full_name
  170. if isinstance(attr, (NodeMixin, RawTensor)):
  171. if full_name:
  172. scope_name = active_module_tracer().current_scope()._module_name
  173. if scope_name:
  174. full_name = full_name[len(scope_name) + 1 :]
  175. else:
  176. full_name = name
  177. else:
  178. full_name = name
  179. NodeMixin.wrap(
  180. attr,
  181. lambda: GetAttr.make(
  182. self,
  183. name,
  184. type=NodeMixin.get_wrapped_type(attr),
  185. orig_name=full_name,
  186. ),
  187. )
  188. if isinstance(attr, (NodeMixin, RawTensor)):
  189. node = NodeMixin.get(attr)
  190. if isinstance(node, ModuleNode):
  191. node._owner = weakref.ref(attr)
  192. return node
  193. return wraped_fn
  194. def _wrap_mnode_call(orig_call):
  195. @functools.wraps(orig_call)
  196. def wraped_fn(self, *args, **kwargs):
  197. obj = self.owner
  198. if self.top_graph is not None:
  199. active_module_tracer().current_scope()._add_input(self)
  200. rst = obj(*args, **kwargs)
  201. return rst
  202. return wraped_fn
  203. def _init_id2name(mod: Module, prefix: str = ""):
  204. id2name = {
  205. id(m): "%s.%s" % (prefix, key)
  206. for key, m in chain(
  207. mod.named_modules(), mod.named_parameters(), mod.named_buffers()
  208. )
  209. }
  210. return id2name
  211. class _InsertExprs:
  212. def __init__(self, graph, expr: Optional[Expr] = None):
  213. self.graph = graph
  214. while graph.top_graph is not None:
  215. graph = graph.top_graph
  216. assert graph.inputs[0].owner._is_top
  217. self.root_graph = graph
  218. self.global_scope = InternalGraph(
  219. graph._name, graph._prefix_name, graph._module_name
  220. )
  221. self.global_scope._used_names.update(graph._used_names)
  222. self.expr = expr
  223. self._tensor_method_patch = None
  224. def __enter__(self):
  225. self.use_sym_shape = set_symbolic_shape(True)
  226. node_id, expr_id = self.root_graph._total_ids
  227. Node.set_total_id(node_id)
  228. Expr.set_total_id(expr_id)
  229. set_module_tracing()
  230. _set_convert_node_flag(True)
  231. assert active_module_tracer() is None
  232. module = self.graph.inputs[0].owner
  233. _wrap_func = lambda x: _convert_node_and_tensor(_wrapped_function(x))
  234. set_active_module_tracer(
  235. module_tracer(_wrap_func, _init_id2name(module, self.graph._module_name))
  236. )
  237. active_module_tracer().patcher.__enter__()
  238. for cls, name, func in [
  239. [ModuleNode, "__getattr__", _wrap_mnode_getattr],
  240. [ModuleNode, "__call__", _wrap_mnode_call],
  241. [TracedModuleBuilder, "__call__", _convert_node_and_tensor],
  242. ]:
  243. active_module_tracer().patcher.patch_function(cls, name, func)
  244. self._tensor_method_patch = _wrap_method_to_tensor_node()
  245. active_module_tracer().push_scope(self.global_scope)
  246. def __exit__(self, ty, va, tr):
  247. if va is not None:
  248. return False
  249. set_symbolic_shape(self.use_sym_shape)
  250. unset_module_tracing()
  251. active_module_tracer().patcher.__exit__(ty, va, tr)
  252. _set_convert_node_flag(False)
  253. while self._tensor_method_patch:
  254. pf = self._tensor_method_patch.pop()
  255. pf.set_func(pf.origin_fn)
  256. module = self.graph.inputs[0].owner
  257. for mod, parent in module.modules(with_parent=True):
  258. name = mod._name
  259. if isinstance(mod, TracedModuleBuilder):
  260. mod = mod.build()
  261. if hasattr(mod, "graph"):
  262. for node in mod.graph.nodes():
  263. node.value = None
  264. setattr(parent, name, mod)
  265. set_active_module_tracer(None)
  266. for node in self.global_scope.nodes():
  267. node.value = None
  268. extra_inp_nodes = set(self.global_scope.inputs)
  269. max_inp_expr_idx = -1
  270. for node in extra_inp_nodes:
  271. assert (
  272. node.top_graph == self.graph
  273. ), "The input node ({}) is not in the graph ({})".format(node, self.graph)
  274. if isinstance(node, TensorNode) and node.expr in self.graph._exprs:
  275. max_inp_expr_idx = max(
  276. max_inp_expr_idx, self.graph._exprs.index(node.expr)
  277. )
  278. max_inp_expr_idx += 1
  279. insert_index = -1
  280. if self.expr is not None:
  281. insert_index = self.graph._exprs.index(self.expr)
  282. insert_index += 1
  283. if insert_index < max_inp_expr_idx:
  284. insert_index = max_inp_expr_idx
  285. anchor_index = insert_index - 1
  286. if anchor_index >= 0:
  287. logger.info(
  288. "The new expr will be inserted after ( {} )".format(
  289. self.graph._exprs[anchor_index]
  290. )
  291. )
  292. for expr in self.global_scope._exprs:
  293. self.graph._exprs.insert(insert_index, expr)
  294. insert_index += 1
  295. self.graph._used_names.update(self.global_scope._used_names)
  296. self.root_graph._total_ids = (Node.get_total_id(), Expr.get_total_id())
  297. self.root_graph.inputs[0].owner._update_ref()
  298. return True
  299. class InternalGraph:
  300. r"""``InternalGraph`` is a graph composed of ``Node`` and ``Expr`` objects; it is used to represent the execution procedure of a Module's forward method.
  301. Attributes:
  302. _exprs: List of Exprs in order of execution
  303. _inputs: Input Nodes of InternalGraph
  304. _outputs: Output Nodes of InternalGraph
  305. """
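# A hedged usage sketch (not in the original file): once a Module has been
# traced (e.g. via the trace_module entry point defined later in this package),
# its InternalGraph can be inspected through the filter helpers below.
# `traced_net` is a hypothetical name.
#
#     graph = traced_net.graph                   # top-level InternalGraph
#     for expr in graph.exprs(recursive=True):   # Exprs in execution order
#         print(expr)
#     convs = graph.get_module_by_type(M.Conv2d).as_list()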
  306. _exprs = None # type: List[Expr]
  307. _inputs = None # type: List[Node]
  308. _outputs = None # type: List[Node]
  309. _top_graph = None # type: InternalGraph
  310. _total_ids = None # type: List[int]
  311. def __init__(self, name: str = None, prefix_name: str = "", module_name: str = ""):
  312. self._exprs = []
  313. self._inputs = []
  314. self._outputs = []
  315. self._watch_point = []
  316. self._end_point = []
  317. self._used_names = {}
  318. self._rst = collections.defaultdict(list)
  319. self._name = name
  320. self._prefix_name = prefix_name
  321. self._module_name = module_name
  322. def _insert(self, expr):
  323. self._exprs.append(expr)
  324. def _create_unique_name(self, name: str) -> str:
  325. assert isinstance(name, str), "The name must be a str"
  326. name = re.sub("[^0-9a-zA-Z_]+", "_", name)
  327. if name[0].isdigit():
  328. name = "_{}".format(name)
  329. while name in self._used_names or _is_builtin_name(name):
  330. match = re.match(r"(.*)_(\d+)$", name)
  331. if match is None:
  332. name = name + "_1"
  333. else:
  334. base, num = match.group(1, 2)
  335. name = "{}_{}".format(base, int(num) + 1)
  336. self._used_names.setdefault(name)
  337. return name
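# Worked example of the uniquification rule above (a sketch, not original code):
#
#     g = InternalGraph("net")
#     g._create_unique_name("conv")   # -> "conv"
#     g._create_unique_name("conv")   # -> "conv_1" (collision gets a suffix)
#     g._create_unique_name("1x1")    # -> "_1x1"   (leading digit gets a "_")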
  338. @property
  339. def inputs(self):
  340. return self._inputs
  341. @property
  342. def outputs(self):
  343. return self._outputs
  344. @property
  345. def top_graph(self):
  346. if self._top_graph:
  347. return self._top_graph()
  348. return None
  349. def exprs(self, recursive=True):
  350. return ExprFilter(_expr_iter(self, recursive))
  351. def nodes(self, recursive=True):
  352. return NodeFilter(_node_iter(self, recursive))
  353. def get_function_by_type(self, func: Callable = None, recursive=True):
  354. return self.exprs(recursive).call_function(func)
  355. def get_method_by_type(self, method: str = None, recursive=True):
  356. return self.exprs(recursive).call_method(method)
  357. def get_expr_by_id(self, expr_id: List[int] = None, recursive=True):
  358. return self.exprs(recursive).expr_id(expr_id)
  359. def get_module_by_type(self, module_cls: Module, recursive=True):
  360. assert issubclass(module_cls, Module)
  361. return self.nodes(recursive).type(module_cls, ModuleNode)
  362. def get_node_by_id(self, node_id: List[int] = None, recursive=True):
  363. return self.nodes(recursive).node_id(node_id)
  364. def get_node_by_name(
  365. self, name: str = None, ignorecase: bool = True, recursive=True
  366. ):
  367. return self.nodes(recursive).name(name, ignorecase)
  368. def _add_input(self, i):
  369. self._inputs.append(i)
  370. def _add_output(self, o):
  371. self._outputs.append(o)
  372. def _replace_inputs_outputs(self, repl_dict, prefix_name="", module_name=""):
  373. for node, repl_node in repl_dict.items():
  374. assert node in self._inputs or node in self._outputs
  375. for i in node.users:
  376. if i not in repl_node.users:
  377. repl_node.users.append(i)
  378. for idx, i in enumerate(self._inputs):
  379. if i in repl_dict:
  380. self._inputs[idx] = repl_dict[i]
  381. for idx, o in enumerate(self._outputs):
  382. if o in repl_dict:
  383. repl_dict[o]._orig_name = "{}{}".format(module_name, o._orig_name)
  384. self._outputs[idx] = repl_dict[o]
  385. for expr in self._exprs:
  386. for idx, i in enumerate(expr.inputs):
  387. assert isinstance(
  388. i._name, str
  389. ), "The node ({}) name must be a str".format(i)
  390. if i in repl_dict:
  391. expr.inputs[idx] = repl_dict[i]
  392. elif isinstance(i, TensorNode) and prefix_name not in i._name:
  393. if i.top_graph != active_module_tracer().current_scope():
  394. i._name = (
  395. active_module_tracer()
  396. .current_scope()
  397. ._create_unique_name(prefix_name + i._name.lstrip("_"))
  398. )
  399. i._orig_name = "{}{}".format(module_name, i._orig_name)
  400. for idx, o in enumerate(expr.outputs):
  401. assert isinstance(
  402. o._name, str
  403. ), "The node ({}) name must be a str".format(o)
  404. if o in repl_dict:
  405. expr.outputs[idx] = repl_dict[o]
  406. expr.outputs[idx].expr = expr
  407. elif isinstance(o, TensorNode) and prefix_name not in o._name:
  408. if o.top_graph != active_module_tracer().current_scope():
  409. o._name = (
  410. active_module_tracer()
  411. .current_scope()
  412. ._create_unique_name(prefix_name + o._name.lstrip("_"))
  413. )
  414. o._orig_name = "{}{}".format(module_name, o._orig_name)
  415. def get_dep_exprs(self, nodes: Sequence[Node]) -> List[Expr]:
  416. if not isinstance(nodes, Sequence):
  417. nodes = (nodes,)
  418. ret = list()
  419. queue = list(nodes)
  420. visited_queue = list()
  421. while queue:
  422. node = queue.pop()
  423. visited_queue.append(node)
  424. expr = node.expr
  425. if expr not in ret:
  426. ret.append(expr)
  427. for i in expr.inputs:
  428. if i not in queue and i not in visited_queue:
  429. queue.append(i)
  430. return ret
  431. def reset_inputs(self, *args, **kwargs):
  432. forma_mnode = self.inputs[0]
  433. actual_mnodes = forma_mnode.actual_node
  434. call_nodes = []
  435. for n in actual_mnodes:
  436. for c_expr in n.users:
  437. if isinstance(c_expr, CallMethod) and c_expr.method == "__call__":
  438. call_nodes.append((c_expr, n))
  439. moudle = forma_mnode.owner
  440. assert moudle._is_top, "reset_inputs only supports the top-level graph"
  441. inputs, tree_def = tree_flatten(((moudle, *args), kwargs))
  442. def create_node(val: Tensor):
  443. node = Input(type=TensorNode).outputs[0]
  444. node.shape = val.shape
  445. node.dtype = val.dtype
  446. return node
  447. formal_node_inputs = [
  448. forma_mnode,
  449. ]
  450. org_argdef = list(moudle.argdef_graph_map.keys())[0]
  451. if call_nodes:
  452. org_argdef = call_nodes[0][0].arg_def
  453. for v in inputs[1:]:
  454. assert isinstance(v, RawTensor)
  455. formal_node_inputs.append(create_node(v))
  456. actual_nodes = []
  457. for e, n in call_nodes:
  458. e.arg_def = tree_def
  459. actual_node_inputs = [
  460. n,
  461. ]
  462. for v in inputs[1:]:
  463. actual_node_inputs.append(create_node(v))
  464. for org_n in e.inputs:
  465. org_n.users.pop(e)
  466. e.inputs[:] = actual_node_inputs
  467. e.const_val = []
  468. actual_nodes.append(actual_node_inputs[1:])
  469. self._inputs[:] = formal_node_inputs
  470. moudle.argdef_graph_map[tree_def] = moudle.argdef_graph_map.pop(org_argdef)
  471. moudle.argdef_outdef_map[tree_def] = moudle.argdef_outdef_map.pop(org_argdef)
  472. # return formal_node_inputs[1:], actual_nodes
  473. return formal_node_inputs[1:]
  474. def add_input_node(self, shape, dtype="float32", name="args"):
  475. forma_mnode = self.inputs[0]
  476. actual_mnodes = forma_mnode.actual_node
  477. moudle = forma_mnode.owner
  478. assert moudle._is_top, "add_input_node only supports the top-level graph"
  479. call_nodes = []
  480. for n in actual_mnodes:
  481. for c_expr in n.users:
  482. if isinstance(c_expr, CallMethod) and c_expr.method == "__call__":
  483. call_nodes.append(c_expr)
  484. def create_node(name=None, is_input: bool = True):
  485. if is_input:
  486. node = Input(type=TensorNode, name=name).outputs[0]
  487. else:
  488. node = TensorNode(expr=None, name=None)
  489. node.shape = shape
  490. node.dtype = dtype
  491. return node
  492. org_argdef = list(moudle.argdef_graph_map.keys())[0]
  493. if call_nodes:
  494. org_argdef = call_nodes[0].arg_def
  495. args, kwargs = org_argdef.unflatten(self._inputs)
  496. formal_inp_node = create_node(self._create_unique_name(name), True)
  497. inputs, tree_def = tree_flatten(
  498. ((*args, formal_inp_node), kwargs),
  499. is_const_leaf=lambda x: not isinstance(x, (TensorNode, ModuleNode)),
  500. )
  501. self._inputs[:] = inputs[:]
  502. actual_inp_nodes = []
  503. for e in call_nodes:
  504. args, kwargs = e.unflatten_args(e.inputs)
  505. args = args + (create_node(is_input=False),)
  506. inputs, tree_def = tree_flatten(
  507. (args, kwargs),
  508. is_const_leaf=lambda x: not isinstance(x, (TensorNode, ModuleNode)),
  509. )
  510. e.inputs[:] = inputs[:]
  511. e.arg_def = tree_def
  512. actual_inp_nodes.append(args[-1])
  513. moudle.argdef_graph_map[tree_def] = moudle.argdef_graph_map.pop(org_argdef)
  514. moudle.argdef_outdef_map[tree_def] = moudle.argdef_outdef_map.pop(org_argdef)
  515. # return formal_inp_node, actual_inp_nodes
  516. return formal_inp_node
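# Sketch (hypothetical names): add an extra positional input to the top-level
# graph; the returned TensorNode can then feed newly inserted exprs.
#
#     new_inp = traced_net.graph.add_input_node(
#         shape=(1, 3, 224, 224), dtype="float32", name="extra"
#     )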
  517. def reset_outputs(self, outputs):
  518. outputs, out_def = tree_flatten(
  519. outputs, is_leaf=lambda x: isinstance(x, TensorNode),
  520. )
  521. forma_mnode = self.inputs[0]
  522. moudle = forma_mnode.owner
  523. assert moudle._is_top, "reset_outputs only supports the top-level graph"
  524. actual_mnodes = forma_mnode.actual_node
  525. call_nodes = []
  526. for n in actual_mnodes:
  527. for c_expr in n.users:
  528. if isinstance(c_expr, CallMethod) and c_expr.method == "__call__":
  529. call_nodes.append((c_expr))
  530. def create_node(val: TensorNode, expr: Expr):
  531. node = TensorNode(expr)
  532. node.shape = val.shape
  533. node.dtype = val.dtype
  534. return node
  535. tree_def = list(moudle.argdef_graph_map.keys())[0]
  536. if call_nodes:
  537. tree_def = call_nodes[0].arg_def
  538. actual_nodes = []
  539. for e in call_nodes:
  540. actual_node_outputs = []
  541. for v in outputs:
  542. actual_node_outputs.append(create_node(v, e))
  543. e.outputs[:] = actual_node_outputs
  544. e.out_def = out_def
  545. actual_nodes.append(actual_node_outputs)
  546. self._outputs[:] = outputs
  547. moudle.argdef_outdef_map[tree_def] = out_def
  548. return actual_nodes
  549. def add_output_node(self, node: TensorNode):
  550. forma_mnode = self.inputs[0]
  551. moudle = forma_mnode.owner
  552. assert moudle._is_top, "add_output_node only supports the top-level graph"
  553. actual_mnodes = forma_mnode.actual_node
  554. call_nodes = []
  555. for n in actual_mnodes:
  556. for c_expr in n.users:
  557. if isinstance(c_expr, CallMethod) and c_expr.method == "__call__":
  558. call_nodes.append((c_expr))
  559. def create_node(val: TensorNode, expr: Expr):
  560. node = TensorNode(expr)
  561. node.shape = val.shape
  562. node.dtype = val.dtype
  563. return node
  564. tree_def = list(moudle.argdef_graph_map.keys())[0]
  565. if call_nodes:
  566. tree_def = call_nodes[0].arg_def
  567. org_out_def = moudle.argdef_outdef_map[tree_def]
  568. org_outs = org_out_def.unflatten(self._outputs)
  569. outputs, out_def = tree_flatten(
  570. (org_outs, node), is_leaf=lambda x: isinstance(x, TensorNode),
  571. )
  572. self._outputs[:] = outputs
  573. actual_out_nodes = []
  574. for e in call_nodes:
  575. actual_node = create_node(node, e)
  576. org_outs = org_out_def.unflatten(e.outputs)
  577. outputs, out_def = tree_flatten(
  578. (org_outs, actual_node), is_leaf=lambda x: isinstance(x, TensorNode),
  579. )
  580. e.outputs[:] = outputs
  581. e.out_def = out_def
  582. actual_out_nodes.append(actual_node)
  583. moudle.argdef_outdef_map[tree_def] = out_def
  584. return actual_out_nodes
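# Sketch (hypothetical names): expose an intermediate TensorNode as an extra
# output, so calls of the traced module also return it.
#
#     relu_out = traced_net.graph.get_node_by_name("*relu*").as_unique()
#     traced_net.graph.add_output_node(relu_out)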
  585. def insert_exprs(self, expr: Optional[Expr] = None):
  586. if expr is not None:
  587. assert expr.top_graph == self, "Expr to insert after is not in graph."
  588. return _InsertExprs(self, expr)
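# Sketch (hypothetical names): functional calls executed inside the
# `with graph.insert_exprs(expr):` block are traced by _InsertExprs and
# spliced into this graph after `expr` (or after the last extra-input expr).
#
#     with traced_net.graph.insert_exprs():
#         y = F.relu(some_tensor_node)   # recorded as a CallFunction expr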
  589. def replace_node(self, repl_dict: Dict[Node, Node]):
  590. while repl_dict:
  591. node, repl_node = repl_dict.popitem()
  592. assert type(node) == type(
  593. repl_node
  594. ), "The type of {}({}) and {}({}) are not the same".format(
  595. node, type(node).__name__, repl_node, type(repl_node).__name__
  596. )
  597. # check graph inputs and outputs
  598. for i, n in enumerate(self.outputs):
  599. if n is node:
  600. self.outputs[i] = repl_node
  601. # update users of node and repl_node
  602. # update inputs of expr in node.users
  603. graph = repl_node.top_graph
  604. assert graph is not None
  605. assert graph is self
  606. index = -1
  607. if not isinstance(repl_node.expr, Input):
  608. index = graph._exprs.index(repl_node.expr)
  609. dep_exprs = self.get_dep_exprs(repl_node)
  610. i = 0
  611. while i < len(node.users):
  612. n = node.users[i]
  613. if n in graph._exprs and index >= graph._exprs.index(n):
  614. i += 1
  615. continue
  616. if n in dep_exprs:
  617. logger.info("Found a loop: skip this replacement")
  618. logger.info("node: %s" % node.__repr__())
  619. logger.info("expr: %s" % n.__repr__())
  620. i += 1
  621. continue
  622. repl_node.users.append(n)
  623. node.users.pop(i)
  624. idx = n.inputs.index(node)
  625. n.inputs[idx] = repl_node
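# Sketch (hypothetical names): the usual rewrite sequence is to build new
# exprs with insert_exprs, swap the old node for the new one, then compile()
# to drop the exprs that became unreachable.
#
#     old = relu_expr.outputs[0]
#     traced_net.graph.replace_node({old: y})
#     traced_net.graph.compile()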
  626. def compile(self):
  627. """Remove exprs that the graph outputs do not depend on."""
  628. dep_exprs = self.get_dep_exprs(self.outputs)
  629. i = 0
  630. while i < len(self._exprs):
  631. expr = self._exprs[i]
  632. if expr in dep_exprs or expr._disable_remove:
  633. i += 1
  634. continue
  635. for n in expr.inputs:
  636. n.users.remove(expr)
  637. self._exprs.remove(expr)
  638. def _reset_ids(self):
  639. for total_expr_id, expr in enumerate(self.exprs()):
  640. expr._id = total_expr_id
  641. for total_node_id, node in enumerate(self.nodes()):
  642. node._id = total_node_id
  643. self._total_ids = (total_node_id + 1, total_expr_id + 1)
  644. def interpret(self, *inputs):
  645. node2value = {}
  646. end_nodes_set = set(self._end_point)
  647. endnode2value = {}
  648. def get_all_endnode_val(n, v):
  649. if n in end_nodes_set:
  650. endnode2value[n] = v
  651. end_nodes_set.remove(n)
  652. return not end_nodes_set
  653. return False
  654. ref_count = lambda n: len(n.users) + (1 if n in self._outputs else 0)
  655. for n, v in zip(self._inputs, inputs):
  656. if ref_count(n) > 0:
  657. node2value[n] = [v, ref_count(n)]
  658. if n in self._watch_point:
  659. self._rst[n].append(v)
  660. if n in self._end_point and get_all_endnode_val(n, v):
  661. return list(endnode2value[i] for i in self._end_point)
  662. for expr in self._exprs:
  663. values = expr.interpret(*list(node2value[i][0] for i in expr.inputs))
  664. for n in expr.inputs:
  665. node2value[n][1] -= 1
  666. if node2value[n][1] == 0:
  667. node2value.pop(n)
  668. if values is not None:
  669. for n, v in zip(expr.outputs, values):
  670. if ref_count(n) > 0:
  671. node2value[n] = [v, ref_count(n)]
  672. if n in self._watch_point:
  673. self._rst[n] = v
  674. if self._end_point and get_all_endnode_val(n, v):
  675. return list(endnode2value[i] for i in self._end_point)
  676. return list(node2value[i][0] for i in self._outputs)
  677. def eval(self, *inputs):
  678. assert len(inputs) == len(self._inputs) - 1
  679. inp = [self._inputs[0].owner] + list(inputs)
  680. return self.interpret(*inp)
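# Sketch: eval() prepends the owning module as inputs[0], so for a graph
# traced from `net.forward(x)` (names hypothetical):
#
#     out = traced_net.graph.eval(x)   # equivalent to interpret(module, x)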
  681. def __repr__(self):
  682. return self.__format__()
  683. def __format__(self, format_spec: str = "") -> str:
  684. saved_format_spec = Node.set_format_spec(format_spec)
  685. name = ""
  686. if self._name:
  687. name = "%s.Graph" % self._name
  688. res = "{} ({}) {{\n\t{}\n\treturn {}\n}}".format(
  689. name,
  690. ", ".join(str(i) for i in self._inputs),
  691. "\n\t".join("{}".format(str(i)) for i in self._exprs),
  692. ", ".join(str(i) for i in self._outputs),
  693. )
  694. Node.set_format_spec(saved_format_spec)
  695. return res
  696. def __getstate__(self):
  697. state = self.__dict__.copy()
  698. if "_top_graph" in state:
  699. state.pop("_top_graph")
  700. return state
  701. def _get_meth_name(obj, func):
  702. tp = obj if isinstance(obj, type) else type(obj)
  703. for cls in tp.mro():
  704. for k, v in cls.__dict__.items():
  705. if v == func:
  706. return k
  707. return None
  708. def _wrapped_function(orig_func):
  709. @functools.wraps(orig_func)
  710. def wrapped_fn(*args, **kwargs):
  711. method_func = wrapped_fn
  712. if "method_func" in kwargs:
  713. method_func = kwargs.pop("method_func")
  714. if is_tracing_module():
  715. unset_module_tracing()
  716. inputs, tree_def = tree_flatten((args, kwargs))
  717. for i in inputs:
  718. if not NodeMixin.get(i, None):
  719. if isinstance(i, (RawTensor, NodeMixin)):
  720. NodeMixin.wrap_safe(i, Constant.make(i))
  721. meth_name, arg_type = None, None
  722. if args:
  723. meth_name = _get_meth_name(args[0], method_func)
  724. arg_type = args[0] if isinstance(args[0], type) else type(args[0])
  725. if meth_name and arg_type and issubclass(arg_type, RawTensor):
  726. self = inputs[0]
  727. if meth_name == "__new__":
  728. if all([not isinstance(i, RawTensor) for i in inputs]):
  729. # only trace Tensor.__new__() when there are tensors in args
  730. set_module_tracing()
  731. return orig_func(*args, **kwargs)
  732. if isinstance(args[1], RawTensor):
  733. node = NodeMixin.get(inputs[1])
  734. inputs[1] = copy.copy(inputs[1])
  735. # copy inputs[1] so that tensor and Tensor(tensor) do not share the same m_tensor, which would give them the same _NodeMixin__node during tracing.
  736. NodeMixin.wrap_safe(inputs[1], node)
  737. args, kwargs = tree_def.unflatten(inputs)
  738. call_node = CallMethod.make(self, meth_name)
  739. else:
  740. call_node = CallMethod.make(NodeMixin.get(self), meth_name)
  741. call_node.add_inputs(inputs[1:])
  742. else:
  743. call_node = CallFunction.make(orig_func)
  744. call_node.add_inputs(inputs)
  745. call_node.arg_def = tree_def
  746. rst = orig_func(*args, **kwargs)
  747. if meth_name == "__setitem__":
  748. rst = self
  749. if rst is not None:
  750. outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
  751. call_node.out_def = out_def
  752. else:
  753. outputs = None
  754. call_node.add_outputs(outputs)
  755. set_module_tracing()
  756. return rst
  757. return orig_func(*args, **kwargs)
  758. return wrapped_fn
  759. class TracedModuleBuilder(NodeMixin):
  760. _mod = None # type: Module
  761. _body = None # type: InternalGraph
  762. _is_builtin = None # type: bool
  763. _argdef_graph_map = None # type: Dict[Treedef, "InternalGraph"]
  764. _argdef_outdef_map = None # type: Dict[Treedef, Treedef]
  765. nodes = None
  766. __builder_attributes__ = [
  767. "_mod",
  768. "_body",
  769. "_NodeMixin__node",
  770. "_is_builtin",
  771. "build",
  772. "_record_wrapped_nodes",
  773. "_argdef_graph_map",
  774. "_argdef_outdef_map",
  775. "nodes",
  776. "__class__",
  777. "__dict__",
  778. ]
  779. def __init__(self, mod, is_top_module=False):
  780. super(TracedModuleBuilder, self).__init__()
  781. assert isinstance(mod, Module)
  782. self._mod = mod
  783. self._body = None
  784. self._is_top = is_top_module
  785. self._is_builtin = (
  786. True
  787. if isinstance(mod, (Observer, _FakeQuantize))
  788. else module_tracer.is_builtin(mod)
  789. )
  790. if isinstance(self._mod, QATModule):
  791. unset_module_tracing()
  792. self._check_qat_module(self._mod)
  793. set_module_tracing()
  794. self._argdef_graph_map = {}
  795. self._argdef_outdef_map = {}
  796. self.nodes = set()
  797. # The builder will be passed to self._mod.forward as the 'self' argument. If 'forward' uses super().xxx to call methods of its base classes, the trace procedure will throw an exception, because the builder does not inherit from self._mod.__bases__.
  798. # Modify self.__class__ so that the builder inherits from both TracedModuleBuilder and mod.__class__.
  799. self.__class__ = type(
  800. "TracedModuleBuilder",
  801. (TracedModuleBuilder, mod.__class__),
  802. dict(TracedModuleBuilder.__dict__),
  803. )
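# The dynamic-subclassing trick above, as a minimal standalone sketch
# (hypothetical classes, same `type(...)` construction as in __init__):
#
#     class Base:
#         def ping(self):
#             return "pong"
#     class Proxy:
#         pass
#     p = Proxy()
#     p.__class__ = type("Proxy", (Proxy, Base), dict(Proxy.__dict__))
#     assert p.ping() == "pong"   # base-class methods now resolve on p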
  804. def _check_qat_module(self, qat_module):
  805. def isbuiltin(m):
  806. return m is None or module_tracer.is_builtin(m)
  807. if qat_module.with_act:
  808. act_observer = qat_module.act_observer
  809. act_fakequant = qat_module.act_fake_quant
  810. if not isbuiltin(act_observer) or not isbuiltin(act_fakequant):
  811. qparams = (
  812. act_observer.get_qparams()
  813. if hasattr(act_observer, "get_qparams")
  814. else act_fakequant.get_qparams()
  815. )
  816. dtype = (
  817. act_observer.dtype
  818. if hasattr(act_observer, "dtype")
  819. else act_fakequant.dtype
  820. )
  821. qat_module.act_observer = None
  822. qat_module.act_fake_quant = TM_FakeQuant(dtype)
  823. qat_module.act_fake_quant.set_qparams(qparams)
  824. if qat_module.with_weight:
  825. weight_observer = qat_module.weight_observer
  826. weight_fakequant = qat_module.weight_fake_quant
  827. if not isbuiltin(weight_observer) or not isbuiltin(weight_fakequant):
  828. qparams = (
  829. weight_observer.get_qparams()
  830. if hasattr(weight_observer, "get_qparams")
  831. else weight_fakequant.get_qparams()
  832. )
  833. dtype = (
  834. weight_observer.dtype
  835. if hasattr(weight_observer, "dtype")
  836. else weight_fakequant.dtype
  837. )
  838. qat_module.weight_observer = None
  839. qat_module.weight_fake_quant = TM_FakeQuant(dtype)
  840. qat_module.weight_fake_quant.set_qparams(qparams)
  841. def build(self):
  842. if self._is_builtin or isinstance(self._mod, TracedModule):
  843. if module_tracer.is_builtin(self._mod) or isinstance(
  844. self._mod, TracedModule
  845. ):
  846. mod_type = type(self._mod)
  847. else:
  848. assert isinstance(self._mod, (Observer, _FakeQuantize))
  849. mod_type = (
  850. Observer if isinstance(self._mod, Observer) else _FakeQuantize
  851. )
  852. for node in self.nodes:
  853. node.module_type = mod_type
  854. return self._mod
  855. else:
  856. is_qat = isinstance(self._mod, QATModule)
  857. traced_module = TracedModule(
  858. self._is_top, self._argdef_graph_map, self._argdef_outdef_map, is_qat
  859. )
  860. for _, g in self._argdef_graph_map.items():
  861. g.compile()
  862. if self._is_top:
  863. g._total_ids = (Node.get_total_id(), Expr.get_total_id())
  864. for k, v in self.__dict__.items():
  865. if k not in TracedModuleBuilder.__builder_attributes__:
  866. if isinstance(v, TracedModuleBuilder):
  867. v = v.build()
  868. setattr(traced_module, k, v)
  869. elif isinstance(v, RawTensor):
  870. setattr(traced_module, k, v)
  871. if isinstance(self._mod, QATModule):
  872. unset_module_tracing()
  873. traced_module.with_act = self._mod.with_act
  874. traced_module.with_weight = self._mod.with_weight
  875. if not hasattr(traced_module, "act_fake_quant"):
  876. traced_module.act_fakequant = None
  877. if not hasattr(traced_module, "act_observer"):
  878. traced_module.act_observer = None
  879. if not hasattr(traced_module, "weight_fake_quant"):
  880. traced_module.weight_fakequant = None
  881. if not hasattr(traced_module, "weight_observer"):
  882. traced_module.weight_observer = None
  883. set_module_tracing()
  884. return traced_module
  885. def _record_wrapped_nodes(self, node):
  886. self.nodes.add(node)
  887. def __call__(self, *args, **kwargs):
  888. assert isinstance(self._mod, Module)
  889. # prepare args and kwargs for inner graph
  890. if "method_func" in kwargs:
  891. kwargs.pop("method_func")
  892. def mark_constant(x):
  893. node = NodeMixin.get(x, None)
  894. if node is None: # capture as constant
  895. NodeMixin.wrap(x, lambda: Constant.make(x))
  896. inputs, tree_def = tree_flatten(((self, *args), kwargs))
  897. for i in inputs:
  898. mark_constant(i)
  899. callnode = CallMethod.make(NodeMixin.get(self))
  900. callnode.add_inputs(inputs[1:])
  901. callnode.arg_def = tree_def
  902. if (
  903. self._is_builtin
  904. or tree_def in self._argdef_graph_map
  905. or isinstance(self._mod, TracedModule)
  906. ):
  907. unset_module_tracing()
  908. rst = self._mod(*args, **kwargs)
  909. outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
  910. set_module_tracing()
  911. if self._is_builtin:
  912. self._body = None
  913. elif tree_def in self._argdef_graph_map:
  914. self._body = self._argdef_graph_map[tree_def]
  915. else:
  916. self._mod._is_top = False
  917. self._body = self._mod.graph
  918. else:
  919. self_node = None
  920. orig_self = NodeMixin.get(self)
  921. top_graph = active_module_tracer().current_scope()
  922. graph_prefix_name = top_graph._name
  923. if top_graph._prefix_name:
  924. graph_prefix_name = "{}_{}".format(
  925. top_graph._prefix_name, graph_prefix_name.lstrip("_")
  926. )
  927. module_name = orig_self._orig_name
  928. if top_graph._module_name:
  929. module_name = "{}.{}".format(top_graph._module_name, module_name)
  930. self._body = InternalGraph(
  931. orig_self._name, prefix_name=graph_prefix_name, module_name=module_name
  932. )
  933. active_module_tracer().push_scope(self._body)
  934. # rebind self to new input node
  935. if self_node:
  936. NodeMixin.wrap_safe(self, self_node)
  937. active_module_tracer().current_scope()._add_input(self_node)
  938. else:
  939. NodeMixin.wrap_safe(
  940. self,
  941. self_node
  942. if self_node
  943. else Input.make("self", NodeMixin.get_wrapped_type(self), ""),
  944. )
  945. origin_inp_node = [NodeMixin.get(i, None) for i in inputs[1:]]
  946. # prepare args and kwargs for inner graph
  947. index_args, index_kwargs = tree_def.unflatten(
  948. [
  949. ArgsIndex(0),
  950. *list(ArgsIndex(i + 1) for i in range(len(origin_inp_node))),
  951. ]
  952. )
  953. key2idx = getcallargs(type(self._mod).forward, *index_args, **index_kwargs)
  954. idx2key = {}
  955. for k, v in key2idx.items():
  956. if isinstance(v, ArgsIndex):
  957. idx2key[v.index] = k
  958. else:
  959. flatten_argidx, _ = tree_flatten(v)
  960. for _i, v in enumerate(flatten_argidx):
  961. if isinstance(v, ArgsIndex):
  962. idx2key[v.index] = k + "_%d" % _i
  963. def wrap(x, name):
  964. if isinstance(x, (RawTensor, NodeMixin)):
  965. NodeMixin.wrap(
  966. x,
  967. lambda: Input.make(
  968. type=NodeMixin.get_wrapped_type(x), name=name
  969. ),
  970. )
  971. return x
  972. args = [self]
  973. for i, v in enumerate(inputs[1:]):
  974. args.append(wrap(v, idx2key[i + 1]))
  975. args, kwargs = tree_def.unflatten(args)
  976. active_module_tracer().patcher.auto_patch(
  977. getattr(getattr(self._mod, "forward", self._mod), "__globals__", {})
  978. )
  979. rst = type(self._mod).forward(*args, **kwargs)
  980. if _convert_node_flag():
  981. rst = _node_to_tensor(rst)[0][0]
  982. outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
  983. for i in (
  984. outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,)
  985. ):
  986. active_module_tracer().current_scope()._add_output(NodeMixin.get(i))
  987. NodeMixin.wrap_safe(self, orig_self)
  988. for arg, node in zip(inputs[1:], origin_inp_node):
  989. if node:
  990. NodeMixin.wrap_safe(arg, node)
  991. active_module_tracer().pop_scope()
  992. # rebind output to outer graph
  993. callnode.out_def = out_def
  994. callnode.add_outputs(outputs)
  995. self._argdef_graph_map[callnode.arg_def] = self._body
  996. self._argdef_outdef_map[callnode.arg_def] = out_def
  997. return rst
  998. def __setattr__(self, name, value):
  999. object.__setattr__(self, name, value)
  1000. def __repr__(self):
  1001. return repr(self._mod)
  1002. def __getattr__(self, name):
  1003. if name not in self._mod.__dict__:
  1004. attr = getattr(type(self._mod), name).__get__(self, type(self))
  1005. else:
  1006. attr = getattr(self._mod, name)
  1007. full_name = None
  1008. if (
  1009. isinstance(attr, FunctionType)
  1010. and id(attr) in active_module_tracer().patcher.patched_fn_ids
  1011. ):
  1012. return active_module_tracer().patcher.wrap_fn(attr)
  1013. if id(attr) in active_module_tracer().id2name:
  1014. full_name = active_module_tracer().id2name[id(attr)]
  1015. if isinstance(attr, (List, Dict)):
  1016. unset_module_tracing()
  1017. has_module, m_container = replace_container_with_module_container(attr)
  1018. if m_container:
  1019. attr = m_container
  1020. if has_module and not m_container:
  1021. raise ValueError(
  1022. "Cannot trace a module that stores Module and non-Module objects in the same container"
  1023. )
  1024. set_module_tracing()
  1025. if isinstance(attr, Module):
  1026. attr = TracedModuleBuilder(attr)
  1027. if isinstance(attr, (Module, RawTensor)):
  1028. setattr(self, name, attr)
  1029. active_module_tracer().id2name[id(attr)] = full_name
  1030. if full_name:
  1031. scope_name = active_module_tracer().current_scope()._module_name
  1032. if scope_name:
  1033. full_name = full_name[len(scope_name) + 1 :]
  1034. else:
  1035. full_name = name
  1036. else:
  1037. full_name = name
  1038. NodeMixin.wrap(
  1039. attr,
  1040. lambda: GetAttr.make(
  1041. NodeMixin.get(self),
  1042. name,
  1043. type=NodeMixin.get_wrapped_type(attr),
  1044. orig_name=full_name,
  1045. ),
  1046. )
  1047. return attr
  1048. def __getattribute__(self, name):
  1049. if name in TracedModuleBuilder.__builder_attributes__:
  1050. return object.__getattribute__(self, name)
  1051. else:
  1052. wrapped = object.__getattribute__(self, name)
  1053. class_members = dict(inspect.getmembers(self.__class__))
  1054. if name in self._mod.__dict__:
  1055. mod_attr = getattr(self._mod, name)
  1056. if name in class_members:
  1057. if (
  1058. not isinstance(wrapped, TracedModuleBuilder)
  1059. and wrapped is not mod_attr
  1060. ):
  1061. wrapped = self.__getattr__(name)
  1062. if isinstance(wrapped, TracedModuleBuilder):
  1063. if not isinstance(mod_attr, (List, Dict)):
  1064. assert mod_attr is wrapped._mod
  1065. else:
  1066. assert mod_attr is wrapped
  1067. full_name = None
  1068. if id(mod_attr) in active_module_tracer().id2name:
  1069. full_name = active_module_tracer().id2name[id(mod_attr)]
  1070. scope_name = active_module_tracer().current_scope()._module_name
  1071. if full_name and scope_name:
  1072. full_name = full_name[len(scope_name) + 1 :]
  1073. else:
  1074. full_name = name
  1075. else:
  1076. full_name = name
  1077. # assert not self._is_builtin
  1078. if isinstance(wrapped, (NodeMixin, RawTensor)):
  1079. NodeMixin.wrap(
  1080. wrapped,
  1081. lambda: GetAttr.make(
  1082. NodeMixin.get(self),
  1083. name,
  1084. type=NodeMixin.get_wrapped_type(wrapped),
  1085. orig_name=full_name,
  1086. ),
  1087. )
  1088. return wrapped
  1089. class _expr_iter:
  1090. def __init__(self, graph: InternalGraph, recursive: bool = True):
  1091. self.graph = graph
  1092. self.recursive = recursive
  1093. def __iter__(self):
  1094. for inp_node in self.graph.inputs:
  1095. yield inp_node.expr
  1096. for expr in self.graph._exprs:
  1097. if isinstance(expr, CallMethod) and isinstance(expr.inputs[0], ModuleNode):
  1098. yield expr
  1099. if self.recursive and expr.graph is not None:
  1100. yield from expr.graph.exprs(self.recursive)
  1101. else:
  1102. yield expr
  1103. class _node_iter:
  1104. def __init__(self, graph: InternalGraph, recursive: bool = True) -> None:
  1105. nodes = []
  1106. node_ids = set()
  1107. for expr in graph.exprs(recursive):
  1108. for n in expr.inputs + expr.outputs:
  1109. if id(n) in node_ids:
  1110. continue
  1111. nodes.append(n)
  1112. node_ids.add(id(n))
  1113. self.nodes = list(sorted(nodes, key=lambda x: x._id))
  1114. def __iter__(self):
  1115. for node in self.nodes:
  1116. yield node
  1117. class BaseFilter:
  1118. def __init__(self, expr_iter: Iterable):
  1119. self._iter = expr_iter
  1120. def __iter__(self):
  1121. return iter(self._iter)
  1122. def as_list(self):
  1123. return list(self)
  1124. def as_dict(self):
  1125. return collections.OrderedDict((i._id, i) for i in self)
  1126. def as_unique(self):
  1127. rst = self.as_list()
  1128. assert len(rst) == 1, "{} elements found".format(len(rst))
  1129. (expr,) = self
  1130. return expr
  1131. def as_count(self):
  1132. return sum(1 for _ in self)
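# Sketch (hypothetical names): filters are lazy iterators that can be chained
# and then materialized with the helpers above.
#
#     graph.get_function_by_type(F.relu).as_count()    # number of relu calls
#     graph.get_method_by_type("__call__").as_list()   # all submodule calls
#     graph.get_expr_by_id(3).as_unique()              # exactly one match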
  1133. class ExprFilter(BaseFilter):
  1134. def call_function(self, func):
  1135. return ExprFilterCallFunction(self, func)
  1136. def call_method(self, method):
  1137. return ExprFilterCallMethod(self, method)
  1138. def expr_id(self, expr_id: List[int]):
  1139. return ExprFilterExprId(self, expr_id)
  1140. class NodeFilter(BaseFilter):
  1141. def type(self, owner_type, node_type):
  1142. return NodeFilterType(self, owner_type, node_type)
  1143. def node_id(self, node_id: List[int]):
  1144. return NodeFilterNodeId(self, node_id)
  1145. def name(self, name: str, ignorecase: bool = True):
  1146. return NodeFilterName(self, name, ignorecase)
  1147. class NodeFilterType(NodeFilter):
  1148. def __init__(self, expr_iter, owner_type, node_type):
  1149. super().__init__(expr_iter)
  1150. self.owner_type = owner_type
  1151. self.node_type = node_type
  1152. def __iter__(self):
  1153. for node in self._iter:
  1154. if not isinstance(node, self.node_type):
  1155. continue
  1156. if not hasattr(node, "owner"):
  1157. continue
  1158. if isinstance(node.owner, self.owner_type):
  1159. yield node
  1160. class NodeFilterNodeId(NodeFilter):
  1161. def __init__(self, expr_iter, node_id: List[int]):
  1162. super().__init__(expr_iter)
  1163. if not isinstance(node_id, Sequence):
  1164. node_id = [node_id]
  1165. self.node_id = node_id
  1166. def __iter__(self):
  1167. for node in self._iter:
  1168. if node._id in self.node_id:
  1169. yield node
  1170. class NodeFilterName(NodeFilter):
  1171. _re = None
  1172. def __init__(self, node_iter, pattern, ignorecase):
  1173. super().__init__(node_iter)
  1174. self.pattern = pattern
  1175. self._re = self.make_re(pattern, ignorecase)
  1176. @classmethod
  1177. def make_re(cls, pattern, ignorecase=True):
  1178. assert isinstance(pattern, str), "bad pattern: {!r}".format(pattern)
  1179. assert isinstance(ignorecase, bool)
  1180. flags = 0
  1181. if ignorecase:
  1182. flags |= re.IGNORECASE
  1183. return re.compile(fnmatch.translate(pattern), flags=flags)
  1184. def __iter__(self):
  1185. for i in self._iter:
  1186. graph = i.top_graph
  1187. name = "{}_{}".format(graph._name, i._name.lstrip("_"))
  1188. if graph._prefix_name:
  1189. name = "{}_{}".format(graph._prefix_name, name.lstrip("_"))
  1190. if self.pattern == name or self._re.match(name):
  1191. yield i
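# Sketch: patterns are fnmatch-style wildcards matched (case-insensitively by
# default) against the graph-qualified node name built above.
#
#     graph.get_node_by_name("*conv*").as_list()
#     graph.get_node_by_name("*fc*", ignorecase=False).as_list()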
  1192. class ExprFilterCallFunction(ExprFilter):
  1193. def __init__(self, expr_iter, func: Callable = None):
  1194. super().__init__(expr_iter)
  1195. self.func = func
  1196. def __iter__(self):
  1197. for expr in self._iter:
  1198. if not isinstance(expr, CallFunction):
  1199. continue
  1200. if self.func is None or expr.func == self.func:
  1201. yield expr
  1202. class ExprFilterCallMethod(ExprFilter):
  1203. def __init__(self, expr_iter, method: str = None):
  1204. super().__init__(expr_iter)
  1205. self.method = method
  1206. def __iter__(self):
  1207. for expr in self._iter:
  1208. if not isinstance(expr, CallMethod):
  1209. continue
  1210. if self.method is None or expr.method == self.method:
  1211. yield expr
  1212. class ExprFilterExprId(ExprFilter):
  1213. def __init__(self, expr_iter, expr_id: List[int]):
  1214. super().__init__(expr_iter)
  1215. if not isinstance(expr_id, Sequence):
  1216. expr_id = [expr_id]
  1217. self.expr_id = expr_id
  1218. def __iter__(self):
  1219. for expr in self._iter:
  1220. if expr._id in self.expr_id:
  1221. yield expr
  1222. class TracedModule(Module):
  1223. r"""`TracedModule` is the Module created by tracing a normal Module. It owns an argdef-to-graph (InternalGraph) map. The forward method of `TracedModule` fetches a graph from `argdef_graph_map` according to the argdef of the input args/kwargs and interprets it."""
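# Hedged sketch: forward() dispatches on the flattened argument structure, so
# a module traced with two call signatures keeps two graphs (`x` hypothetical):
#
#     out = traced_net(x)   # tree_flatten(((self, x), {})) -> treedef
#                           # -> argdef_graph_map[treedef].interpret(...)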
  1224. # m_node = None # type: ModuleNode
  1225. argdef_graph_map = None
  1226. argdef_outdef_map = None
  1227. def __init__(self, is_top, argdef_graph_map, argdef_outdef_map, is_qat=False):
  1228. super(TracedModule, self).__init__()
  1229. self.argdef_graph_map = argdef_graph_map
  1230. self.argdef_outdef_map = argdef_outdef_map
  1231. self._is_top = is_top
  1232. self.watch_points = []
  1233. self.watch_node_value = {}
  1234. self.end_points = []
  1235. self.is_qat = is_qat
  1236. def forward(self, *args, **kwargs):
  1237. inputs, treedef = tree_flatten(((self, *args), kwargs))
  1238. assert treedef in self.argdef_graph_map
  1239. inputs = filter(
  1240. lambda i: isinstance(i, (Module, TracedModuleBuilder, RawTensor)), inputs
  1241. ) # allow TracedModuleBuilder for retrace.
  1242. outputs = self.argdef_graph_map[treedef].interpret(*inputs)
  1243. if self.watch_points:
  1244. self.watch_node_value = {}
  1245. for n in self.watch_points:
  1246. self.watch_node_value[n] = n.top_graph._rst.pop(n)
  1247. if self.end_points:
  1248. return outputs
  1249. out_def = self.argdef_outdef_map[treedef]
  1250. outputs = out_def.unflatten(outputs)
  1251. return outputs
  1252. def set_watch_points(self, nodes):
  1253. if not isinstance(nodes, Sequence):
  1254. nodes = [nodes]
  1255. self.watch_points = nodes
  1256. for n in nodes:
  1257. n.top_graph._watch_point.append(n)
  1258. def clear_watch_points(self):
  1259. for n in self.watch_points:
  1260. n.top_graph._watch_point = []
  1261. self.watch_points = []
  1262. self.watch_node_value = {}
  1263. def set_end_points(self, nodes):
  1264. if not isinstance(nodes, Sequence):
  1265. nodes = [nodes]
  1266. self.end_points = nodes
  1267. graphs = list(self.argdef_graph_map.values())
  1268. for n in nodes:
  1269. assert n.top_graph in graphs
  1270. n.top_graph._end_point.append(n)
  1271. def clear_end_points(self):
  1272. for n in self.end_points:
  1273. n.top_graph._end_point = []
  1274. self.end_points = []
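# Sketch (hypothetical names): watch points record intermediate values during
# forward; end points stop interpretation early at the chosen nodes.
#
#     n = traced_net.graph.get_node_by_name("*relu*").as_unique()
#     traced_net.set_watch_points([n])
#     traced_net(x)
#     print(traced_net.watch_node_value[n])
#     traced_net.clear_watch_points()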
  1275. @property
  1276. def graph(self) -> InternalGraph:
  1277. if self._is_top:
  1278. self._update_ref()
  1279. assert len(self.argdef_graph_map) == 1
  1280. return list(self.argdef_graph_map.values())[0]
    def _update_ref(self, actual_node_map: Union[Dict] = None, top_graph=None):
        for inp_def, graph in self.argdef_graph_map.items():
            if top_graph is not None:
                graph._top_graph = weakref.ref(top_graph)
            for n in graph._inputs + graph.outputs:
                n._top_graph = weakref.ref(graph)
            graph._inputs[0]._owner = weakref.ref(self)
            for i, n in enumerate(graph._inputs):
                n.actual_node = []
                if actual_node_map is not None and inp_def in actual_node_map.keys():
                    n.actual_node = list(list(zip(*(actual_node_map[inp_def])))[i])
            node2obj = {}
            next_actual_node_map = collections.defaultdict(
                lambda: collections.defaultdict(list)
            )
            node2obj[graph._inputs[0]] = self
            for expr in graph._exprs:
                for n in expr.inputs + expr.outputs:
                    n._top_graph = weakref.ref(graph)
                expr._top_graph = weakref.ref(graph)
                if isinstance(expr, GetAttr) and isinstance(
                    expr.outputs[0], ModuleNode
                ):
                    obj = getattr(node2obj[expr.inputs[0]], expr.name)
                    expr.outputs[0]._owner = weakref.ref(obj)
                    node2obj[expr.outputs[0]] = obj
                if isinstance(expr, Constant) and isinstance(
                    expr.outputs[0], ModuleNode
                ):
                    obj = expr.value
                    expr.outputs[0]._owner = weakref.ref(obj)
                    node2obj[expr.outputs[0]] = obj
                if (
                    isinstance(expr, CallMethod)
                    and expr.method == "__call__"
                    and isinstance(expr.inputs[0], ModuleNode)
                ):
                    obj = node2obj[expr.inputs[0]]
                    if expr.arg_def is not None:
                        next_actual_node_map[obj][expr.arg_def].append(expr.inputs)

            for obj in node2obj.values():
                if obj is self:
                    continue
                mnode_map = None
                if obj in next_actual_node_map.keys():
                    mnode_map = next_actual_node_map[obj]
                if isinstance(obj, TracedModule):
                    obj._update_ref(mnode_map, graph)
    def flatten(self):
        r"""Get a new module, which eliminates ``GetAttr`` and has no hierarchy.

        :return: :class:`TracedModule`
        """
        new_module = copy.deepcopy(self)
        assert active_module_tracer() is None
        id2name = _init_id2name(new_module, "self")
        set_active_module_tracer(module_tracer(lambda x: x, {}))
        active_module_tracer().push_scope(new_module.graph)

        def _flatten_subgraph(
            parent_graph: InternalGraph,
            graph: InternalGraph,
            module: Module,
            call=None,
            prefix_name="",
            module_name="",
        ):
            if isinstance(prefix_name, str) and prefix_name and prefix_name[-1] != "_":
                prefix_name += "_"
            if isinstance(module_name, str) and module_name:
                module_name += "."
            if graph is None or module.is_qat:
                assert not isinstance(module, TracedModule) or module.is_qat
                const = Constant(module, id2name[id(module)])
                m_node = call.inputs[0]
                if m_node.top_graph != active_module_tracer().current_scope():
                    m_node._name = (
                        active_module_tracer()
                        .current_scope()
                        ._create_unique_name(prefix_name)
                    )
                    m_node._orig_name = id2name[id(module)][5:]
                const.outputs[0] = m_node
                const.outputs[0].expr = const
                return [const, call]
            if call is not None:
                graph = copy.deepcopy(graph)

            exprs = []
            node2obj = {}
            node2obj[graph._inputs[0]] = module
            if call:
                node2obj[call.inputs[0]] = module

            # replace inputs for submodule's exprs
            if call:
                repl_dict = dict(zip(graph._inputs, call.inputs))
                for ind, out in enumerate(graph.outputs):
                    if isinstance(out.expr, Input):
                        assert out in repl_dict
                        call_out = call.outputs[ind]
                        for expr in call.outputs[ind].users:
                            for index, inp in enumerate(expr.inputs):
                                if inp is call_out:
                                    expr.inputs[index] = repl_dict[out]
                                    repl_dict[out].users.append(expr)
                        if parent_graph is not None:
                            for index, parent_out in enumerate(parent_graph._outputs):
                                if parent_out is call_out:
                                    parent_graph._outputs[index] = repl_dict[out]
                        continue
                    repl_dict[out] = call.outputs[ind]

                graph._replace_inputs_outputs(repl_dict, prefix_name, module_name)

            for expr in graph._exprs:
                if isinstance(expr, GetAttr):
                    # replace GetAttr with Constant
                    if isinstance(expr.outputs[0], TensorNode):
                        const = Constant(getattr(node2obj[expr.inputs[0]], expr.name))
                        const.outputs = expr.outputs
                        const.outputs[0].expr = const
                        exprs.append(const)
                    elif isinstance(expr.outputs[0], ModuleNode):
                        node2obj[expr.outputs[0]] = getattr(
                            node2obj[expr.inputs[0]], expr.name
                        )
                elif isinstance(expr, CallMethod):
                    obj_node = expr.inputs[0]
                    if isinstance(obj_node, ModuleNode):
                        pre_expr = expr.inputs[0].expr
                        if isinstance(pre_expr, GetAttr):
                            (obj,) = pre_expr.interpret(node2obj[pre_expr.inputs[0]])
                            expr_graph = (
                                obj.argdef_graph_map[expr.arg_def]
                                if hasattr(obj, "argdef_graph_map")
                                else None
                            )
                            exprs.extend(
                                _flatten_subgraph(
                                    graph,
                                    expr_graph,
                                    obj,
                                    expr,
                                    prefix_name + obj_node._name.lstrip("_"),
                                    module_name + obj_node._orig_name,
                                )
                            )
                        else:
                            # module has been replaced.
                            assert isinstance(pre_expr, Constant)
                            exprs.append(expr)
                    else:
                        exprs.append(expr)
                else:
                    exprs.append(expr)

            if call is not None:
                for i in call.inputs:
                    i.users.remove(call)

            return exprs

        new_module.graph._exprs = _flatten_subgraph(None, new_module.graph, new_module)
        new_module.graph.compile()
        set_active_module_tracer(None)
        new_module.graph._reset_ids()
        return new_module
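    # A minimal usage sketch for ``flatten`` (``tm`` is a hypothetical
    # TracedModule obtained from ``trace_module``):
    #
    #     flat = tm.flatten()
    #     for expr in flat.graph._exprs:  # single-level graph, no GetAttr exprs
    #         print(expr)
    #
    # ``flat`` computes the same outputs as ``tm``; only the hierarchy of
    # sub-graphs is collapsed into one InternalGraph.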
    def __getstate__(self):
        d = self.__dict__
        for k in Module.__dict__:
            d.pop(k, None)
        return d
def cpp_apply_module_trace(opdef, *args):
    return Apply.apply_module_trace_hook(opdef, *args)
def register_as_builtin(mod_cls: Type[Module]) -> None:
    r"""Registers class ``mod_cls`` (subclass of megengine.module.Module) as a builtin module.

    Args:
        mod_cls: the Module class which will be treated as a builtin module in tracing
    """
    module_tracer.register_as_builtin(mod_cls)
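# A minimal usage sketch (``MyElemwise`` is a hypothetical user-defined module
# that should stay opaque during tracing, i.e. be captured as a single call
# rather than expanded into its internal exprs):
#
#     class MyElemwise(M.Module):
#         def forward(self, x):
#             return x * 2 + 1
#
#     register_as_builtin(MyElemwise)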
def wrap(func: Callable):
    r"""Call this function to register ``func`` as a builtin function."""
    assert callable(func), "func must be a callable"
    assert hasattr(func, "__code__")
    fn_name = func.__code__.co_name
    currentframe = inspect.currentframe()
    assert currentframe is not None
    f = currentframe.f_back
    assert f is not None
    assert (
        f.f_code.co_name == "<module>"
    ), "wrap must be called at the top level of a module"
    Patcher._builtin_functions.append((f.f_globals, fn_name))
    return func
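# A minimal usage sketch: ``wrap`` must run at module import time (the assert
# above enforces a top-level call site), so the typical pattern is a plain
# call right after the definition (``preprocess`` is a hypothetical example):
#
#     def preprocess(x):
#         return (x - 128.0) / 128.0
#
#     wrap(preprocess)  # traced as one opaque call instead of its internals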
def _register_all_builtin_module():
    for sub_mod in [M, M.qat, M.quantized]:
        for m in getmembers(sub_mod):
            if (
                isclass(m[1])
                and issubclass(m[1], M.Module)
                and m[1] is not M.Sequential
            ):
                module_tracer.register_as_builtin(m[1])

    module_tracer.register_as_builtin(Observer)
    module_tracer.register_as_builtin(MinMaxObserver)
    module_tracer.register_as_builtin(SyncMinMaxObserver)
    module_tracer.register_as_builtin(ExponentialMovingAverageObserver)
    module_tracer.register_as_builtin(SyncExponentialMovingAverageObserver)
    module_tracer.register_as_builtin(HistogramObserver)
    module_tracer.register_as_builtin(PassiveObserver)
    module_tracer.register_as_builtin(LSQ)
    module_tracer.register_as_builtin(TQT)
    module_tracer.register_as_builtin(FakeQuantize)
    module_tracer.register_as_builtin(TM_FakeQuant)
def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
    r"""Traces module ``mod`` and returns the corresponding TracedModule.

    Args:
        mod: the module to be converted to a TracedModule
        args: the positional arguments passed to the forward method of ``mod``
        kwargs: the keyword arguments passed to the forward method of ``mod``
    """
    assert active_module_tracer() is None
    assert isinstance(mod, Module)
    try:
        use_sym_shape = set_symbolic_shape(True)
        set_module_tracing()
        set_active_module_tracer(
            module_tracer(_wrapped_function, _init_id2name(mod, "self"))
        )
        for cls in [Expr, Node]:
            cls.set_total_id(0)
        with active_module_tracer().patcher:
            global_scope = InternalGraph(name="")
            active_module_tracer().push_scope(global_scope)
            builder = TracedModuleBuilder(mod, True)
            name = mod._name if mod._name else mod.__class__.__name__
            NodeMixin.wrap_safe(builder, Input.make(name, ModuleNode, orig_name="self"))
            inputs, _ = tree_flatten((args, kwargs))
            for idx, i in enumerate(inputs):
                # only RawTensor leaves become graph Input nodes
                if isinstance(i, RawTensor):
                    NodeMixin.wrap_safe(
                        i, Input.make("arg_{}".format(idx), NodeMixin.get_wrapped_type(i))
                    )
            builder(*args, **kwargs)
            active_module_tracer().pop_scope()
            traced_mod = builder.build()
            traced_mod.graph._reset_ids()
            return traced_mod
    finally:
        set_symbolic_shape(use_sym_shape)
        set_active_module_tracer(None)
        unset_module_tracing()
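# A minimal end-to-end sketch (``Net`` is a hypothetical megengine module;
# the input shape is illustrative only):
#
#     import numpy as np
#     import megengine as mge
#
#     net = Net()
#     data = mge.tensor(np.random.random((1, 3, 224, 224)), dtype="float32")
#     traced = trace_module(net, data)
#     print(traced.graph)        # inspect the captured InternalGraph
#     out = traced(data)         # a TracedModule is itself callable
#     flat = traced.flatten()    # optional: collapse the hierarchy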
