
traced_module.py

# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import builtins
import collections
import copy
import ctypes
import fnmatch
import functools
import inspect
import keyword
import re
import weakref
from inspect import getcallargs, getmembers, isclass, ismethod
from itertools import chain
from types import FunctionType
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    Optional,
    Sequence,
    Tuple,
    Type,
    Union,
)

from megengine import tensor

from .. import functional as F
from .. import get_logger
from .. import module as M
from ..core._imperative_rt.core2 import Tensor as RawTensor
from ..core._imperative_rt.core2 import (
    is_tracing_module,
    set_module_tracing,
    unset_module_tracing,
)
from ..core._trace_option import set_symbolic_shape
from ..core.tensor.array_method import ArrayMethodMixin
from ..module import Module
from ..module.qat import QATModule
from ..quantization.fake_quant import LSQ, TQT, FakeQuantize, _FakeQuantize
from ..quantization.observer import (
    ExponentialMovingAverageObserver,
    HistogramObserver,
    MinMaxObserver,
    Observer,
    PassiveObserver,
    SyncExponentialMovingAverageObserver,
    SyncMinMaxObserver,
)
from ..tensor import Tensor
from .expr import Apply, CallFunction, CallMethod, Constant, Expr, GetAttr, Input
from .fake_quant import FakeQuantize as TM_FakeQuant
from .module_tracer import (
    PatchedFn,
    Patcher,
    active_module_tracer,
    get_tensor_wrapable_method,
    module_tracer,
    set_active_module_tracer,
)
from .node import ModuleNode, Node, NodeMixin, TensorNode
from .pytree import ArgsIndex, tree_flatten
from .utils import replace_container_with_module_container

logger = get_logger(__name__)


def _is_builtin_name(name: str) -> bool:
    return (
        name in builtins.__dict__
        or name in keyword.kwlist
        or name in {"inf", "nan", "NoneType"}
    )


def _is_leaf(node):
    assert isinstance(node, RawTensor), "doesn't support {} in return values".format(
        type(node)
    )
    return isinstance(node, RawTensor)


_enable_node_to_tensor = False


def _convert_node_flag():
    return _enable_node_to_tensor


def _set_convert_node_flag(flag: bool = False):
    global _enable_node_to_tensor
    pre_flag = _enable_node_to_tensor
    _enable_node_to_tensor = flag
    return pre_flag


def _node_to_tensor(*args, **kwargs):
    tensors = []
    nodes, tree_def = tree_flatten((args, kwargs))
    for n in nodes:
        if isinstance(n, TensorNode):
            if n.top_graph is not None:
                active_module_tracer().current_scope()._add_input(n)
            value = n.value
            if value is None:
                flag = _set_convert_node_flag(False)
                unset_module_tracing()
                value = F.zeros(shape=n._shape, dtype=n._dtype)
                set_module_tracing()
                _set_convert_node_flag(flag)
            orig_n = NodeMixin.get(value, None)
            if orig_n is None or "setitem" not in orig_n._name:
                NodeMixin.wrap_safe(value, n)
            tensors.append(value)
        else:
            tensors.append(n)
    tensors = tree_def.unflatten(tensors)
    return tensors


def _tensor_to_node(tensors):
    if tensors is None:
        return None
    nodes = []
    tensors, out_def = tree_flatten(tensors)
    for t in tensors:
        if isinstance(t, Tensor):
            n = NodeMixin.get(t, None)
            if isinstance(n, TensorNode):
                n.value = t
                nodes.append(n)
            else:
                nodes.append(t)
        else:
            nodes.append(t)
    nodes = out_def.unflatten(nodes)
    return nodes


def _wrap_method_to_tensor_node():
    def _any_method(name):
        def _any(*args, **kwargs):
            args, kwargs = _node_to_tensor(*args, **kwargs)
            attr = getattr(args[0], name)
            outs = attr
            if callable(attr):
                outs = attr(*(args[1:]), **kwargs)
            if name == "__setitem__":
                _node_to_tensor(outs)
                return None
            outs = _tensor_to_node(outs)
            return outs

        return _any

    tensor_method_patch = []
    for method in get_tensor_wrapable_method():
        patch = PatchedFn(TensorNode, method)
        if type(getattr(Tensor, method)) == property:
            patch.set_func(property(_any_method(method)))
        else:
            patch.set_func(_any_method(method))
        tensor_method_patch.append(patch)
    return tensor_method_patch


def _convert_node_and_tensor(orig_func):
    @functools.wraps(orig_func)
    def _convert(*args, **kwargs):
        if _convert_node_flag() and is_tracing_module():
            args, kwargs = _node_to_tensor(*args, **kwargs)
            rst = orig_func(*args, **kwargs, method_func=_convert)
            rst = _tensor_to_node(rst)
            return rst
        else:
            rst = orig_func(*args, **kwargs)
            return rst

    return _convert


def _wrap_mnode_getattr(orig_getattr):
    @functools.wraps(orig_getattr)
    def wraped_fn(self, name):
        obj = self.owner
        if self.top_graph is not None:
            active_module_tracer().current_scope()._add_input(self)
        attr = getattr(obj, name)
        node = attr
        full_name = None
        if id(attr) in active_module_tracer().id2name:
            full_name = active_module_tracer().id2name[id(attr)]
        if not isinstance(attr, TracedModuleBuilder):
            if isinstance(attr, Module):
                attr = TracedModuleBuilder(attr)
                setattr(obj, name, attr)
                active_module_tracer().id2name[id(attr)] = full_name
            if isinstance(attr, (NodeMixin, RawTensor)):
                if full_name:
                    scope_name = active_module_tracer().current_scope()._module_name
                    if scope_name:
                        full_name = full_name[len(scope_name) + 1 :]
                    else:
                        full_name = name
                else:
                    full_name = name
                NodeMixin.wrap(
                    attr,
                    lambda: GetAttr.make(
                        self,
                        name,
                        type=NodeMixin.get_wrapped_type(attr),
                        orig_name=full_name,
                    ),
                )
        if isinstance(attr, (NodeMixin, RawTensor)):
            node = NodeMixin.get(attr)
        if isinstance(node, ModuleNode):
            node._owner = weakref.ref(attr)
        return node

    return wraped_fn


def _wrap_mnode_call(orig_call):
    @functools.wraps(orig_call)
    def wraped_fn(self, *args, **kwargs):
        obj = self.owner
        if self.top_graph is not None:
            active_module_tracer().current_scope()._add_input(self)
        rst = obj(*args, **kwargs)
        return rst

    return wraped_fn


def _init_id2name(mod: Module, prefix: str = ""):
    id2name = {
        id(m): "%s.%s" % (prefix, key)
        for key, m in chain(
            mod.named_modules(), mod.named_parameters(), mod.named_buffers()
        )
    }
    return id2name
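

# Illustrative sketch of the mapping built above (names are assumptions, not
# taken from a real trace): for a module ``mod`` whose submodule is registered
# as ``linear``, _init_id2name(mod, "self") maps id(mod.linear) to
# "self.linear"; parameters and buffers are indexed the same way.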


class _InsertExprs:
    def __init__(self, graph, expr: Optional[Expr] = None):
        self.graph = graph
        while graph.top_graph is not None:
            graph = graph.top_graph
        assert graph.inputs[0].owner._is_top
        self.root_graph = graph
        self.global_scope = InternalGraph(
            graph._name, graph._prefix_name, graph._module_name
        )
        self.global_scope._used_names.update(graph._used_names)
        self.expr = expr
        self._tensor_method_patch = None

    def __enter__(self):
        self.use_sym_shape = set_symbolic_shape(True)
        node_id, expr_id = self.root_graph._total_ids
        Node._set_next_id(node_id)
        Expr._set_next_id(expr_id)
        set_module_tracing()
        _set_convert_node_flag(True)
        assert active_module_tracer() is None
        module = self.graph.inputs[0].owner
        _wrap_func = lambda x: _convert_node_and_tensor(_wrapped_function(x))
        set_active_module_tracer(
            module_tracer(_wrap_func, _init_id2name(module, self.graph._module_name))
        )
        active_module_tracer().patcher.__enter__()
        for cls, name, func in [
            [ModuleNode, "__getattr__", _wrap_mnode_getattr],
            [ModuleNode, "__call__", _wrap_mnode_call],
            [TracedModuleBuilder, "__call__", _convert_node_and_tensor],
        ]:
            active_module_tracer().patcher.patch_function(cls, name, func)
        self._tensor_method_patch = _wrap_method_to_tensor_node()
        active_module_tracer().push_scope(self.global_scope)

    def __exit__(self, ty, va, tr):
        if va is not None:
            return False
        set_symbolic_shape(self.use_sym_shape)
        unset_module_tracing()
        active_module_tracer().patcher.__exit__(ty, va, tr)
        _set_convert_node_flag(False)
        while self._tensor_method_patch:
            pf = self._tensor_method_patch.pop()
            pf.set_func(pf.origin_fn)
        module = self.graph.inputs[0].owner
        for mod, parent in module.modules(with_parent=True):
            name = mod._name
            if isinstance(mod, TracedModuleBuilder):
                mod = mod.build()
                if hasattr(mod, "graph"):
                    for node in mod.graph.nodes():
                        node.value = None
                setattr(parent, name, mod)
        set_active_module_tracer(None)
        for node in self.global_scope.nodes():
            node.value = None
        extra_inp_nodes = set(self.global_scope.inputs)
        max_inp_expr_idx = -1
        for node in extra_inp_nodes:
            assert (
                node.top_graph == self.graph
            ), "The input node ({}) is not in the graph ({})".format(node, self.graph)
            if isinstance(node, TensorNode) and node.expr in self.graph._exprs:
                max_inp_expr_idx = max(
                    max_inp_expr_idx, self.graph._exprs.index(node.expr)
                )
        max_inp_expr_idx += 1
        insert_index = -1
        if self.expr is not None:
            insert_index = self.graph._exprs.index(self.expr)
        insert_index += 1
        if insert_index < max_inp_expr_idx:
            insert_index = max_inp_expr_idx
        anchor_index = insert_index - 1
        if anchor_index >= 0:
            logger.info(
                "The new expr will be inserted after ( {} )".format(
                    self.graph._exprs[anchor_index]
                )
            )
        for expr in self.global_scope._exprs:
            self.graph._exprs.insert(insert_index, expr)
            insert_index += 1
        self.graph._used_names.update(self.global_scope._used_names)
        self.root_graph._total_ids = (Node._get_next_id(), Expr._get_next_id())
        self.root_graph.inputs[0].owner._update_ref()
        return True


class InternalGraph:
    r"""``InternalGraph`` is the main data structure used in the TracedModule.
    It is used to represent the execution procedure of a Module's forward method.

    For example, the following code

    .. code-block::

        import megengine.random as rand
        import megengine.functional as F
        import megengine.module as M
        import megengine.traced_module as tm

        class MyModule(M.Module):
            def __init__(self):
                super().__init__()
                self.param = rand.normal(size=(3, 4))
                self.linear = M.Linear(4, 5)

            def forward(self, x):
                return F.relu(self.linear(x + self.param))

        net = MyModule()
        inp = F.zeros(shape=(3, 4))
        traced_module = tm.trace_module(net, inp)

    will produce the following ``InternalGraph``::

        print(traced_module.graph)

    .. code-block:: text

        MyModule.Graph (self, x) {
            %2: linear = getattr(self, "linear") -> (Linear)
            %3: param = getattr(self, "param") -> (Tensor)
            %4: add_out = x.__add__(param, )
            %5: linear_out = linear(add_out, )
            %6: relu_out = nn.relu(linear_out, )
            return relu_out
        }
    """
    _exprs = None  # type: List[Expr]
    _inputs = None  # type: List[Node]
    _outputs = None  # type: List[Node]
    _top_graph = None  # type: InternalGraph
    _total_ids = None  # type: List[int]

    def __init__(self, name: str = None, prefix_name: str = "", module_name: str = ""):
        self._exprs = []
        self._inputs = []
        self._outputs = []
        self._watch_point = []
        self._end_point = []
        self._used_names = {}
        self._rst = collections.defaultdict(list)
        self._name = name
        self._prefix_name = prefix_name
        self._module_name = module_name

    def _insert(self, expr):
        self._exprs.append(expr)

    def _create_unique_name(self, name: str) -> str:
        assert isinstance(name, str), "The name must be a str"
        name = re.sub("[^0-9a-zA-Z_]+", "_", name)
        if name[0].isdigit():
            name = "_{}".format(name)
        while name in self._used_names or _is_builtin_name(name):
            match = re.match(r"(.*)_(\d+)$", name)
            if match is None:
                name = name + "_1"
            else:
                base, num = match.group(1, 2)
                name = "{}_{}".format(base, int(num) + 1)
        self._used_names.setdefault(name)
        return name
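    # Illustrative example of the renaming scheme above (names are made up):
    # if "fc" is already used, a second request for "fc" yields "fc_1";
    # requesting "fc_1" when it is taken yields "fc_2", and so on.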

    @property
    def inputs(self) -> List[Node]:
        r"""Get the list of input Nodes of this graph.

        Returns:
            A list of ``Node``.
        """
        return self._inputs

    @property
    def outputs(self) -> List[Node]:
        r"""Get the list of output Nodes of this graph.

        Returns:
            A list of ``Node``.
        """
        return self._outputs

    @property
    def top_graph(self):
        r"""Get the parent graph of this graph.

        Returns:
            An ``InternalGraph``.
        """
        if self._top_graph:
            return self._top_graph()
        return None

    def exprs(self, recursive=True):
        r"""Get the Exprs that constitute this graph.

        Args:
            recursive: whether to get the Exprs in the subgraph.
                Default: True

        Returns:
            An ``ExprFilter`` containing all Exprs of this graph.
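
        A minimal usage sketch (assuming ``graph`` is an ``InternalGraph``,
        e.g. ``traced_module.graph``):

        .. code-block::

            # print every Expr of the graph, including those in sub-graphs
            for expr in graph.exprs(recursive=True):
                print(expr)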
        """
        return ExprFilter(_expr_iter(self, recursive))

    def nodes(self, recursive=True):
        r"""Get the Nodes that constitute this graph.

        Args:
            recursive: whether to get the Nodes in the subgraph.
                Default: True

        Returns:
            A ``NodeFilter`` containing all Nodes of this graph.
        """
        return NodeFilter(_node_iter(self, recursive))

    def get_function_by_type(self, func: Callable = None, recursive=True):
        r"""Filter Exprs by the type of ``CallFunction``.

        Args:
            func: a built-in function, such as ``F.relu``.
            recursive: whether to get the Exprs in the subgraph.
                Default: True

        Returns:
            A :class:`~.TracedModule.ExprFilterCallFunction`.
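
        A minimal usage sketch (assuming ``import megengine.functional as F``
        and a graph that contains ``F.relu`` calls):

        .. code-block::

            # collect every expr that calls F.relu
            relu_exprs = graph.get_function_by_type(F.relu).as_list()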
        """
        return self.exprs(recursive).call_function(func)

    def get_method_by_type(self, method: str = None, recursive=True):
        r"""Filter Exprs by the type of ``CallMethod``.

        Args:
            method: a method string, such as "__add__".
            recursive: whether to get the Exprs in the subgraph.
                Default: True

        Returns:
            A :class:`~.TracedModule.ExprFilterCallMethod`.
        """
        return self.exprs(recursive).call_method(method)

    def get_expr_by_id(self, expr_id: List[int] = None, recursive=True):
        r"""Filter Exprs by their ``id``.

        Args:
            expr_id: a list of :class:`int`.
            recursive: whether to get the Exprs in the subgraph.
                Default: True

        Returns:
            A :class:`~.TracedModule.ExprFilterExprId`.
        """
        return self.exprs(recursive).expr_id(expr_id)

    def get_module_by_type(self, module_cls: Module, recursive=True):
        r"""Filter Nodes by the ``module_type`` of ``ModuleNode``.

        Args:
            module_cls: a subclass of :class:`~.Module`.
            recursive: whether to get the Nodes in the subgraph.
                Default: True

        Returns:
            A :class:`~.TracedModule.NodeFilterType`.
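
        A minimal usage sketch (assuming ``import megengine.module as M`` and
        a traced network that contains ``M.Linear`` submodules):

        .. code-block::

            # collect the ModuleNode of every Linear layer
            linear_nodes = graph.get_module_by_type(M.Linear).as_list()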
        """
        assert issubclass(module_cls, Module)
        return self.nodes(recursive).type(module_cls)

    def get_node_by_id(self, node_id: List[int] = None, recursive=True):
        r"""Filter Nodes by their ``id``.

        The ``id`` of the ``Node`` can be obtained by the following code

        .. code-block::

            # node : Node
            print("{:i}".format(node))
            print(node.__format__("i"))
            # graph : InternalGraph
            print("{:i}".format(graph))
            print(graph.__format__("i"))

        Args:
            node_id: a list of :class:`int`.
            recursive: whether to get the Nodes in the subgraph.
                Default: True

        Returns:
            A :class:`~.TracedModule.NodeFilterNodeId`.
        """
        return self.nodes(recursive).node_id(node_id)

    def get_node_by_name(
        self, name: str = None, ignorecase: bool = True, recursive=True
    ):
        r"""Filter Nodes by their full name.

        The full name of the ``Node`` can be obtained by the following code

        .. code-block::

            # node : Node
            print("{:p}".format(node))
            print(node.__format__("p"))
            # graph : InternalGraph
            print("{:p}".format(graph))
            print(graph.__format__("p"))

        Args:
            name: a string in glob syntax that can contain ``?`` and
                ``*`` to match a single or arbitrary characters.
            ignorecase: whether to ignore case.
                Default: True
            recursive: whether to get the Nodes in the subgraph.
                Default: True

        Returns:
            A :class:`~.TracedModule.NodeFilterName`.
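
        A minimal usage sketch (the glob pattern is illustrative; adapt it to
        the node names of your own graph):

        .. code-block::

            # match every node whose full name ends with "_out"
            out_nodes = graph.get_node_by_name("*_out").as_list()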
        """
        return self.nodes(recursive).name(name, ignorecase)

    def _add_input(self, i):
        self._inputs.append(i)

    def _add_output(self, o):
        self._outputs.append(o)

    def _replace_inputs_outputs(self, repl_dict, prefix_name="", module_name=""):
        for node, repl_node in repl_dict.items():
            assert node in self._inputs or node in self._outputs
            for i in node.users:
                if i not in repl_node.users:
                    repl_node.users.append(i)
        for idx, i in enumerate(self._inputs):
            if i in repl_dict:
                self._inputs[idx] = repl_dict[i]
        for idx, o in enumerate(self._outputs):
            if o in repl_dict:
                repl_dict[o]._orig_name = "{}{}".format(module_name, o._orig_name)
                self._outputs[idx] = repl_dict[o]
        for expr in self._exprs:
            for idx, i in enumerate(expr.inputs):
                assert isinstance(
                    i._name, str
                ), "The node ({}) name must be a str".format(i)
                if i in repl_dict:
                    expr.inputs[idx] = repl_dict[i]
                elif isinstance(i, TensorNode) and prefix_name not in i._name:
                    if i.top_graph != active_module_tracer().current_scope():
                        i._name = (
                            active_module_tracer()
                            .current_scope()
                            ._create_unique_name(prefix_name + i._name.lstrip("_"))
                        )
                        i._orig_name = "{}{}".format(module_name, i._orig_name)
            for idx, o in enumerate(expr.outputs):
                assert isinstance(
                    o._name, str
                ), "The node ({}) name must be a str".format(o)
                if o in repl_dict:
                    expr.outputs[idx] = repl_dict[o]
                    expr.outputs[idx].expr = expr
                elif isinstance(o, TensorNode) and prefix_name not in o._name:
                    if o.top_graph != active_module_tracer().current_scope():
                        o._name = (
                            active_module_tracer()
                            .current_scope()
                            ._create_unique_name(prefix_name + o._name.lstrip("_"))
                        )
                        o._orig_name = "{}{}".format(module_name, o._orig_name)

    def get_dep_exprs(self, nodes: Sequence[Node]) -> List[Expr]:
        r"""Get the dependent Exprs of the ``nodes``.

        Args:
            nodes: a list of :class:`Node`.

        Returns:
            A list of dependent :class:`Expr`.
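
        A minimal usage sketch (assuming ``graph`` is a traced
        ``InternalGraph``):

        .. code-block::

            # every expr that the first output transitively depends on
            dep = graph.get_dep_exprs(graph.outputs[0])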
        """
        if not isinstance(nodes, Sequence):
            nodes = (nodes,)
        ret = list()
        queue = list(nodes)
        visited_queue = list()
        while queue:
            node = queue.pop()
            visited_queue.append(node)
            expr = node.expr
            if expr not in ret:
                ret.append(expr)
            for i in expr.inputs:
                if i not in queue and i not in visited_queue:
                    queue.append(i)
        return ret

    def reset_inputs(self, *args, **kwargs):
        forma_mnode = self.inputs[0]
        actual_mnodes = forma_mnode.actual_node
        call_nodes = []
        for n in actual_mnodes:
            for c_expr in n.users:
                if isinstance(c_expr, CallMethod) and c_expr.method == "__call__":
                    call_nodes.append((c_expr, n))
        moudle = forma_mnode.owner
        assert moudle._is_top, "reset_inputs only supports the top-level graph"
        inputs, tree_def = tree_flatten(((moudle, *args), kwargs))

        def create_node(val: Tensor):
            node = Input(type=TensorNode).outputs[0]
            node.shape = val.shape
            node.dtype = val.dtype
            return node

        formal_node_inputs = [
            forma_mnode,
        ]
        org_argdef = list(moudle.argdef_graph_map.keys())[0]
        if call_nodes:
            org_argdef = call_nodes[0][0].arg_def
        for v in inputs[1:]:
            assert isinstance(v, RawTensor)
            formal_node_inputs.append(create_node(v))
        actual_nodes = []
        for e, n in call_nodes:
            e.arg_def = tree_def
            actual_node_inputs = [
                n,
            ]
            for v in inputs[1:]:
                actual_node_inputs.append(create_node(v))
            for org_n in e.inputs:
                org_n.users.pop(e)
            e.inputs[:] = actual_node_inputs
            e.const_val = []
            actual_nodes.append(actual_node_inputs[1:])
        self._inputs[:] = formal_node_inputs
        moudle.argdef_graph_map[tree_def] = moudle.argdef_graph_map.pop(org_argdef)
        moudle.argdef_outdef_map[tree_def] = moudle.argdef_outdef_map.pop(org_argdef)
        return formal_node_inputs[1:]

    def add_input_node(
        self, shape: Tuple[int], dtype: str = "float32", name: str = "args"
    ):
        r"""Add an input node to the graph.

        The new Node will be the last of the positional arguments.

        Args:
            shape: the shape of the new input Node.
            dtype: the dtype of the new input Node.
                Default: float32
            name: the name of the new input Node. When the name is already
                used in the graph, a suffix will be added to it.
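
        A minimal usage sketch (shape and name are illustrative):

        .. code-block::

            # append a float32 input of shape (1, 3, 224, 224) named "img"
            new_inp = graph.add_input_node((1, 3, 224, 224), name="img")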
        """
        forma_mnode = self.inputs[0]
        actual_mnodes = forma_mnode.actual_node
        moudle = forma_mnode.owner
        assert moudle._is_top, "add_input_node only supports the top-level graph"
        call_nodes = []
        for n in actual_mnodes:
            for c_expr in n.users:
                if isinstance(c_expr, CallMethod) and c_expr.method == "__call__":
                    call_nodes.append(c_expr)

        def create_node(name=None, is_input: bool = True):
            if is_input:
                node = Input(type=TensorNode, name=name).outputs[0]
            else:
                node = TensorNode(expr=None, name=None)
            node.shape = shape
            node.dtype = dtype
            return node

        org_argdef = list(moudle.argdef_graph_map.keys())[0]
        if call_nodes:
            org_argdef = call_nodes[0].arg_def
        args, kwargs = org_argdef.unflatten(self._inputs)
        formal_inp_node = create_node(self._create_unique_name(name), True)
        inputs, tree_def = tree_flatten(
            ((*args, formal_inp_node), kwargs),
            is_const_leaf=lambda x: not isinstance(x, (TensorNode, ModuleNode)),
        )
        self._inputs[:] = inputs[:]
        actual_inp_nodes = []
        for e in call_nodes:
            args, kwargs = e.unflatten_args(e.inputs)
            args = args + (create_node(False),)
            inputs, tree_def = tree_flatten(
                (args, kwargs),
                is_const_leaf=lambda x: not isinstance(x, (TensorNode, ModuleNode)),
            )
            e.inputs[:] = inputs[:]
            e.arg_def = tree_def
            actual_inp_nodes.append(args[-1])
        moudle.argdef_graph_map[tree_def] = moudle.argdef_graph_map.pop(org_argdef)
        moudle.argdef_outdef_map[tree_def] = moudle.argdef_outdef_map.pop(org_argdef)
        return formal_inp_node

    def reset_outputs(self, outputs):
        r"""Reset the output Nodes of the graph.

        .. note::

            This method only supports resetting the output of graphs
            that do not have a parent graph.

        Args:
            outputs: an object whose inner elements are Nodes. Tuples, lists,
                dicts, etc. are supported.

        For example, the following code

        .. code-block::

            import megengine.functional as F
            import megengine.module as M
            import megengine.traced_module as tm

            class MyModule(M.Module):
                def forward(self, x):
                    x = x + 1
                    return x

            net = MyModule()
            inp = F.zeros(shape=(1,))
            traced_module = tm.trace_module(net, inp)
            graph = traced_module.graph
            inp_node = graph.inputs[1]
            out_node = graph.outputs[0]
            graph.reset_outputs((out_node, {"input": inp_node}))
            out = traced_module(inp)

        will produce the following ``InternalGraph`` and ``out``::

            print(graph)
            print(out)

        .. code-block:: text

            MyModule.Graph (self, x) {
                %2: add_out = x.__add__(1, )
                return add_out, x
            }
            (Tensor([1.], device=xpux:0), {'input': Tensor([0.], device=xpux:0)})
        """
        outputs, out_def = tree_flatten(
            outputs, is_leaf=lambda x: isinstance(x, TensorNode),
        )
        forma_mnode = self.inputs[0]
        moudle = forma_mnode.owner
        assert moudle._is_top, "reset_outputs only supports the top graph"
        actual_mnodes = forma_mnode.actual_node
        call_nodes = []
        for n in actual_mnodes:
            for c_expr in n.users:
                if isinstance(c_expr, CallMethod) and c_expr.method == "__call__":
                    call_nodes.append(c_expr)

        def create_node(val: TensorNode, expr: Expr):
            node = TensorNode(expr)
            node.shape = val.shape
            node.dtype = val.dtype
            return node

        tree_def = list(moudle.argdef_graph_map.keys())[0]
        if call_nodes:
            tree_def = call_nodes[0].arg_def
        actual_nodes = []
        for e in call_nodes:
            actual_node_outputs = []
            for v in outputs:
                actual_node_outputs.append(create_node(v, e))
            e.outputs[:] = actual_node_outputs
            e.out_def = out_def
            actual_nodes.append(actual_node_outputs)
        self._outputs[:] = outputs
        moudle.argdef_outdef_map[tree_def] = out_def
        return actual_nodes

    def add_output_node(self, node: TensorNode):
        r"""Add an output node to the Graph.

        The Graph output will become a ``tuple`` after calling
        ``add_output_node``. The first element of the ``tuple`` is the
        original output, and the second is the ``node``.

        For example, the following code

        .. code-block::

            import megengine.functional as F
            import megengine.module as M
            import megengine.traced_module as tm

            class MyModule(M.Module):
                def forward(self, x):
                    x = x + 1
                    return x

            net = MyModule()
            inp = F.zeros(shape=(1,))
            traced_module = tm.trace_module(net, inp)
            graph = traced_module.graph
            inp_node = graph.inputs[1]
            out_node = graph.outputs[0]
            graph.add_output_node(inp_node)
            graph.add_output_node(out_node)
            out = traced_module(inp)

        will produce the following ``InternalGraph`` and ``out``::

            print(graph)
            print(out)

        .. code-block:: text

            MyModule.Graph (self, x) {
                %2: add_out = x.__add__(1, )
                return add_out, x, add_out
            }
            ((Tensor([1.], device=xpux:0), Tensor([0.], device=xpux:0)), Tensor([1.], device=xpux:0))
        """
        forma_mnode = self.inputs[0]
        moudle = forma_mnode.owner
        assert moudle._is_top, "add_output_node only supports the top graph"
        actual_mnodes = forma_mnode.actual_node
        call_nodes = []
        for n in actual_mnodes:
            for c_expr in n.users:
                if isinstance(c_expr, CallMethod) and c_expr.method == "__call__":
                    call_nodes.append(c_expr)

        def create_node(val: TensorNode, expr: Expr):
            node = TensorNode(expr)
            node.shape = val.shape
            node.dtype = val.dtype
            return node

        tree_def = list(moudle.argdef_graph_map.keys())[0]
        if call_nodes:
            tree_def = call_nodes[0].arg_def
        org_out_def = moudle.argdef_outdef_map[tree_def]
        org_outs = org_out_def.unflatten(self._outputs)
        outputs, out_def = tree_flatten(
            (org_outs, node), is_leaf=lambda x: isinstance(x, TensorNode),
        )
        self._outputs[:] = outputs
        actual_out_nodes = []
        for e in call_nodes:
            actual_node = create_node(node, e)
            org_outs = org_out_def.unflatten(e.outputs)
            outputs, out_def = tree_flatten(
                (org_outs, actual_node), is_leaf=lambda x: isinstance(x, TensorNode),
            )
            e.outputs[:] = outputs
            e.out_def = out_def
            actual_out_nodes.append(actual_node)
        moudle.argdef_outdef_map[tree_def] = out_def
        return actual_out_nodes

    def insert_exprs(self, expr: Optional[Expr] = None):
        r"""Initialize the trace mode and insertion position.

        When used within a ``with`` statement, this will temporarily set the
        trace mode and then restore normal mode when the with statement exits::

            with graph.insert_exprs(e):  # set the trace mode
                ...  # trace function or module
            ...  # insert exprs into graph and restore normal mode

        Args:
            expr: the ``expr`` after which to insert. If None, the insertion
                position will be automatically set based on the input node.

        Returns:
            A resource manager that will initialize trace mode on ``__enter__``
            and restore normal mode on ``__exit__``.
        """
        if expr is not None:
            assert expr.top_graph == self, "Expr to insert after is not in graph."
        return _InsertExprs(self, expr)

    def replace_node(self, repl_dict: Dict[Node, Node]):
        r"""Replace the Nodes in the graph.

        Args:
            repl_dict: the map {old_Node: new_Node} that specifies how to
                replace the Nodes.
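
        A minimal usage sketch (``old_node`` and ``new_node`` are illustrative
        names; both Nodes must belong to this graph):

        .. code-block::

            # redirect every user of old_node to new_node, then drop
            # the exprs that became unused
            graph.replace_node({old_node: new_node})
            graph.compile()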
        """
        while repl_dict:
            node, repl_node = repl_dict.popitem()
            assert type(node) == type(
                repl_node
            ), "The type of {}({}) and {}({}) are not the same".format(
                node, type(node).__name__, repl_node, type(repl_node).__name__
            )
            # check graph inputs and outputs
            for i, n in enumerate(self.outputs):
                if n is node:
                    self.outputs[i] = repl_node
            # update users of node and repl_node
            # update inputs of expr in node.users
            graph = repl_node.top_graph
            assert graph is not None
            assert graph is self
            index = -1
            if not isinstance(repl_node.expr, Input):
                index = graph._exprs.index(repl_node.expr)
            dep_exprs = self.get_dep_exprs(repl_node)
            i = 0
            while i < len(node.users):
                n = node.users[i]
                if n in graph._exprs and index >= graph._exprs.index(n):
                    i += 1
                    continue
                if n in dep_exprs:
                    logger.info("Found a loop: ignoring this replacement once")
                    logger.info("node: %s" % node.__repr__())
                    logger.info("expr: %s" % n.__repr__())
                    i += 1
                    continue
                repl_node.users.append(n)
                node.users.pop(i)
                idx = n.inputs.index(node)
                n.inputs[idx] = repl_node

    def compile(self):
        r"""Delete unused exprs."""
        dep_exprs = self.get_dep_exprs(self.outputs)
        i = 0
        while i < len(self._exprs):
            expr = self._exprs[i]
            if expr in dep_exprs or expr._disable_remove:
                i += 1
                continue
            for n in expr.inputs:
                n.users.remove(expr)
            self._exprs.remove(expr)

    def _reset_ids(self):
        for total_expr_id, expr in enumerate(self.exprs()):
            expr._id = total_expr_id
        for total_node_id, node in enumerate(self.nodes()):
            node._id = total_node_id
        self._total_ids = (total_node_id + 1, total_expr_id + 1)

    def interpret(self, *inputs):
        node2value = {}
        end_nodes_set = set(self._end_point)
        endnode2value = {}

        def get_all_endnode_val(n, v):
            if n in end_nodes_set:
                endnode2value[n] = v
                end_nodes_set.remove(n)
                return not end_nodes_set
            return False

        ref_count = lambda n: len(n.users) + (1 if n in self._outputs else 0)
        for n, v in zip(self._inputs, inputs):
            if ref_count(n) > 0:
                node2value[n] = [v, ref_count(n)]
            if n in self._watch_point:
                self._rst[n].append(v)
            if n in self._end_point and get_all_endnode_val(n, v):
                return list(endnode2value[i] for i in self._end_point)
        for expr in self._exprs:
            values = expr.interpret(*list(node2value[i][0] for i in expr.inputs))
            for n in expr.inputs:
                node2value[n][1] -= 1
                if node2value[n][1] == 0:
                    node2value.pop(n)
            if values is not None:
                for n, v in zip(expr.outputs, values):
                    if ref_count(n) > 0:
                        node2value[n] = [v, ref_count(n)]
                    if n in self._watch_point:
                        self._rst[n] = v
                    if self._end_point and get_all_endnode_val(n, v):
                        return list(endnode2value[i] for i in self._end_point)
        return list(node2value[i][0] for i in self._outputs)

    def eval(self, *inputs: Tuple[Tensor]):
        r"""Call this method to execute the graph.

        Args:
            inputs: the tensors corresponding to ``graph.inputs[1:]``.
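
        A minimal usage sketch (assuming the module was traced with a single
        tensor input ``inp``):

        .. code-block::

            graph = traced_module.graph
            out = graph.eval(inp)  # interprets the graph expr by expr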
        """
        assert len(inputs) == len(self._inputs) - 1
        inp = [self._inputs[0].owner] + list(inputs)
        return self.interpret(*inp)

    def __repr__(self):
        return self.__format__()

    def __format__(self, format_spec: str = "") -> str:
        saved_format_spec = Node._set_format_spec(format_spec)
        name = ""
        if self._name:
            name = "%s.Graph" % self._name
        res = "{} ({}) {{\n\t{}\n\treturn {}\n}}".format(
            name,
            ", ".join(str(i) for i in self._inputs),
            "\n\t".join("{}".format(str(i)) for i in self._exprs),
            ", ".join(str(i) for i in self._outputs),
        )
        Node._set_format_spec(saved_format_spec)
        return res

    def __getstate__(self):
        state = self.__dict__.copy()
        if "_top_graph" in state:
            state.pop("_top_graph")
        return state


def _get_meth_name(obj, func):
    tp = obj if isinstance(obj, type) else type(obj)
    for cls in tp.mro():
        for k, v in cls.__dict__.items():
            if v == func:
                return k
    return None


def _wrapped_function(orig_func):
    @functools.wraps(orig_func)
    def wrapped_fn(*args, **kwargs):
        method_func = wrapped_fn
        if "method_func" in kwargs:
            method_func = kwargs.pop("method_func")
        if is_tracing_module():
            unset_module_tracing()
            inputs, tree_def = tree_flatten((args, kwargs))
            for i in inputs:
                if not NodeMixin.get(i, None):
                    if isinstance(i, (RawTensor, NodeMixin)):
                        NodeMixin.wrap_safe(i, Constant.make(i))
            meth_name, arg_type = None, None
            if args:
                meth_name = _get_meth_name(args[0], method_func)
                arg_type = args[0] if isinstance(args[0], type) else type(args[0])
            if meth_name and arg_type and issubclass(arg_type, RawTensor):
                self = inputs[0]
                if meth_name == "__new__":
                    if all([not isinstance(i, RawTensor) for i in inputs]):
                        # no tensors in args: do not trace Tensor.__new__()
                        set_module_tracing()
                        return orig_func(*args, **kwargs)
                    if isinstance(args[1], RawTensor):
                        node = NodeMixin.get(inputs[1])
                        # copy inputs[1] so that tensor and Tensor(tensor) do
                        # not share the same m_tensor, which would give them
                        # the same _NodeMixin__node during tracing
                        inputs[1] = copy.copy(inputs[1])
                        NodeMixin.wrap_safe(inputs[1], node)
                        args, kwargs = tree_def.unflatten(inputs)
                    call_node = CallMethod.make(self, meth_name)
                else:
                    call_node = CallMethod.make(NodeMixin.get(self), meth_name)
                call_node.add_inputs(inputs[1:])
            else:
                call_node = CallFunction.make(orig_func)
                call_node.add_inputs(inputs)
            call_node.arg_def = tree_def
            rst = orig_func(*args, **kwargs)
            if meth_name == "__setitem__":
                rst = self
            if rst is not None:
                outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
                call_node.out_def = out_def
            else:
                outputs = None
            call_node.add_outputs(outputs)
            set_module_tracing()
            return rst
        return orig_func(*args, **kwargs)

    return wrapped_fn


class TracedModuleBuilder(NodeMixin):
    _mod = None  # type: Module
    _body = None  # type: InternalGraph
    _is_builtin = None  # type: bool
    _argdef_graph_map = None  # type: Dict[Treedef, "InternalGraph"]
    _argdef_outdef_map = None  # type: Dict[Treedef, Treedef]
    nodes = None

    __builder_attributes__ = [
        "_mod",
        "_body",
        "_NodeMixin__node",
        "_is_builtin",
        "build",
        "_record_wrapped_nodes",
        "_argdef_graph_map",
        "_argdef_outdef_map",
        "nodes",
        "__class__",
        "__dict__",
    ]

    def __init__(self, mod, is_top_module=False):
        super(TracedModuleBuilder, self).__init__()
        assert isinstance(mod, Module)
        self._mod = mod
        self._body = None
        self._is_top = is_top_module
        self._is_builtin = (
            True
            if isinstance(mod, (Observer, _FakeQuantize))
            else module_tracer.is_builtin(mod)
        )
        if isinstance(self._mod, QATModule):
            unset_module_tracing()
            self._check_qat_module(self._mod)
            set_module_tracing()
        self._argdef_graph_map = {}
        self._argdef_outdef_map = {}
        self.nodes = set()
        # The builder will be passed to self._mod.forward as the 'self'
        # argument. If 'forward' uses super().xxx to call a method of one of
        # its base classes, tracing would raise an exception, because the
        # builder does not inherit from self._mod.__bases__.
        # Modify self.__class__ so the builder inherits from both
        # TracedModuleBuilder and mod.__class__.
        self.__class__ = type(
            "TracedModuleBuilder",
            (TracedModuleBuilder, mod.__class__),
            dict(TracedModuleBuilder.__dict__),
        )

    def _check_qat_module(self, qat_module):
        def isbuiltin(m):
            return m is None or module_tracer.is_builtin(m)

        if qat_module.with_act:
            act_observer = qat_module.act_observer
            act_fakequant = qat_module.act_fake_quant
            if not isbuiltin(act_observer) or not isbuiltin(act_fakequant):
                qparams = (
                    act_observer.get_qparams()
                    if hasattr(act_observer, "get_qparams")
                    else act_fakequant.get_qparams()
                )
                dtype = (
                    act_observer.dtype
                    if hasattr(act_observer, "dtype")
                    else act_fakequant.dtype
                )
                qat_module.act_observer = None
                qat_module.act_fake_quant = TM_FakeQuant(dtype)
                qat_module.act_fake_quant.set_qparams(qparams)
        if qat_module.with_weight:
            weight_observer = qat_module.weight_observer
            weight_fakequant = qat_module.weight_fake_quant
            if not isbuiltin(weight_observer) or not isbuiltin(weight_fakequant):
                qparams = (
                    weight_observer.get_qparams()
                    if hasattr(weight_observer, "get_qparams")
                    else weight_fakequant.get_qparams()
                )
                dtype = (
                    weight_observer.dtype
                    if hasattr(weight_observer, "dtype")
                    else weight_fakequant.dtype
                )
                qat_module.weight_observer = None
                qat_module.weight_fake_quant = TM_FakeQuant(dtype)
                qat_module.weight_fake_quant.set_qparams(qparams)

    def build(self):
        if self._is_builtin or isinstance(self._mod, TracedModule):
            if module_tracer.is_builtin(self._mod) or isinstance(
                self._mod, TracedModule
            ):
                mod_type = type(self._mod)
            else:
                assert isinstance(self._mod, (Observer, _FakeQuantize))
                mod_type = (
                    Observer if isinstance(self._mod, Observer) else _FakeQuantize
                )
            for node in self.nodes:
                node.module_type = mod_type
            return self._mod
        else:
            is_qat = isinstance(self._mod, QATModule)
            traced_module = TracedModule(
                self._is_top, self._argdef_graph_map, self._argdef_outdef_map, is_qat
            )
            for _, g in self._argdef_graph_map.items():
                g.compile()
                if self._is_top:
                    g._total_ids = (Node._get_next_id(), Expr._get_next_id())
            for k, v in self.__dict__.items():
                if k not in TracedModuleBuilder.__builder_attributes__:
                    if isinstance(v, TracedModuleBuilder):
                        v = v.build()
                        setattr(traced_module, k, v)
                    elif isinstance(v, RawTensor):
                        setattr(traced_module, k, v)
            if isinstance(self._mod, QATModule):
                unset_module_tracing()
                traced_module.with_act = self._mod.with_act
                traced_module.with_weight = self._mod.with_weight
                if not hasattr(traced_module, "act_fake_quant"):
                    traced_module.act_fakequant = None
                if not hasattr(traced_module, "act_observer"):
                    traced_module.act_observer = None
                if not hasattr(traced_module, "weight_fake_quant"):
                    traced_module.weight_fakequant = None
                if not hasattr(traced_module, "weight_observer"):
                    traced_module.weight_observer = None
                set_module_tracing()
            return traced_module

    def _record_wrapped_nodes(self, node):
        self.nodes.add(node)

    def __call__(self, *args, **kwargs):
        assert isinstance(self._mod, Module)
        # prepare args and kwargs for the inner graph
        if "method_func" in kwargs:
            kwargs.pop("method_func")

        def mark_constant(x):
            node = NodeMixin.get(x, None)
            if node is None:  # capture as constant
                NodeMixin.wrap(x, lambda: Constant.make(x))

        inputs, tree_def = tree_flatten(((self, *args), kwargs))
        for i in inputs:
            mark_constant(i)
        callnode = CallMethod.make(NodeMixin.get(self))
        callnode.add_inputs(inputs[1:])
        callnode.arg_def = tree_def
        if (
            self._is_builtin
            or tree_def in self._argdef_graph_map
            or isinstance(self._mod, TracedModule)
        ):
            unset_module_tracing()
            rst = self._mod(*args, **kwargs)
            outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
            set_module_tracing()
            if self._is_builtin:
                self._body = None
            elif tree_def in self._argdef_graph_map:
                self._body = self._argdef_graph_map[tree_def]
            else:
                self._mod._is_top = False
                self._body = self._mod.graph
        else:
            self_node = None
            orig_self = NodeMixin.get(self)
            top_graph = active_module_tracer().current_scope()
            graph_prefix_name = top_graph._name
            if top_graph._prefix_name:
                graph_prefix_name = "{}_{}".format(
                    top_graph._prefix_name, graph_prefix_name.lstrip("_")
                )
            module_name = orig_self._orig_name
            if top_graph._module_name:
                module_name = "{}.{}".format(top_graph._module_name, module_name)
            self._body = InternalGraph(
                orig_self._name, prefix_name=graph_prefix_name, module_name=module_name
            )
            active_module_tracer().push_scope(self._body)
            # rebind self to the new input node
            if self_node:
                NodeMixin.wrap_safe(self, self_node)
                active_module_tracer().current_scope()._add_input(self_node)
            else:
                NodeMixin.wrap_safe(
                    self,
                    self_node
                    if self_node
                    else Input.make("self", NodeMixin.get_wrapped_type(self), ""),
                )
            origin_inp_node = [NodeMixin.get(i, None) for i in inputs[1:]]
            # prepare args and kwargs for the inner graph
            index_args, index_kwargs = tree_def.unflatten(
                [
                    ArgsIndex(0),
                    *list(ArgsIndex(i + 1) for i in range(len(origin_inp_node))),
                ]
            )
            key2idx = getcallargs(type(self._mod).forward, *index_args, **index_kwargs)
            idx2key = {}
            for k, v in key2idx.items():
                if isinstance(v, ArgsIndex):
                    idx2key[v.index] = k
                else:
                    flatten_argidx, _ = tree_flatten(v)
                    for _i, v in enumerate(flatten_argidx):
                        if isinstance(v, ArgsIndex):
                            idx2key[v.index] = k + "_%d" % _i

            def wrap(x, name):
                if isinstance(x, (RawTensor, NodeMixin)):
                    NodeMixin.wrap(
                        x,
                        lambda: Input.make(
                            type=NodeMixin.get_wrapped_type(x), name=name
                        ),
                    )
                return x

            args = [self]
            for i, v in enumerate(inputs[1:]):
                args.append(wrap(v, idx2key[i + 1]))
            args, kwargs = tree_def.unflatten(args)
            active_module_tracer().patcher.auto_patch(
                getattr(getattr(self._mod, "forward", self._mod), "__globals__", {})
            )
            rst = type(self._mod).forward(*args, **kwargs)
            if _convert_node_flag():
                rst = _node_to_tensor(rst)[0][0]
            outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
            for i in (
                outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,)
            ):
                active_module_tracer().current_scope()._add_output(NodeMixin.get(i))
            NodeMixin.wrap_safe(self, orig_self)
            for arg, node in zip(inputs[1:], origin_inp_node):
                if node:
                    NodeMixin.wrap_safe(arg, node)
            active_module_tracer().pop_scope()
        # rebind output to outer graph
        callnode.out_def = out_def
        callnode.add_outputs(outputs)
        self._argdef_graph_map[callnode.arg_def] = self._body
        self._argdef_outdef_map[callnode.arg_def] = out_def
        return rst

    def __setattr__(self, name, value):
        object.__setattr__(self, name, value)

    def __repr__(self):
        return repr(self._mod)

    def __getattr__(self, name):
        if name not in self._mod.__dict__:
            attr = getattr(type(self._mod), name).__get__(self, type(self))
        else:
            attr = getattr(self._mod, name)
            full_name = None
            if (
                isinstance(attr, FunctionType)
                and id(attr) in active_module_tracer().patcher.patched_fn_ids
            ):
                return active_module_tracer().patcher.wrap_fn(attr)
            if id(attr) in active_module_tracer().id2name:
                full_name = active_module_tracer().id2name[id(attr)]
            if isinstance(attr, (List, Dict)):
                unset_module_tracing()
                has_module, m_container = replace_container_with_module_container(attr)
                if m_container:
                    attr = m_container
                if has_module and not m_container:
                    raise ValueError(
                        "Cannot trace a module that uses the same container to "
                        "store Module and non-Module objects"
                    )
                set_module_tracing()
            if isinstance(attr, Module):
                attr = TracedModuleBuilder(attr)
            if isinstance(attr, (Module, RawTensor)):
                setattr(self, name, attr)
                active_module_tracer().id2name[id(attr)] = full_name
            if full_name:
                scope_name = active_module_tracer().current_scope()._module_name
                if scope_name:
                    full_name = full_name[len(scope_name) + 1 :]
                else:
                    full_name = name
            else:
                full_name = name
            NodeMixin.wrap(
                attr,
                lambda: GetAttr.make(
                    NodeMixin.get(self),
                    name,
                    type=NodeMixin.get_wrapped_type(attr),
                    orig_name=full_name,
                ),
            )
        return attr

    def __getattribute__(self, name):
        if name in TracedModuleBuilder.__builder_attributes__:
            return object.__getattribute__(self, name)
        else:
            wrapped = object.__getattribute__(self, name)
            class_members = dict(inspect.getmembers(self.__class__))
            if name in self._mod.__dict__:
                mod_attr = getattr(self._mod, name)
                if name in class_members:
                    if (
                        not isinstance(wrapped, TracedModuleBuilder)
                        and wrapped is not mod_attr
                    ):
                        wrapped = self.__getattr__(name)
                if isinstance(wrapped, TracedModuleBuilder):
                    if not isinstance(mod_attr, (List, Dict)):
                        assert mod_attr is wrapped._mod
                else:
                    assert mod_attr is wrapped
                full_name = None
                if id(mod_attr) in active_module_tracer().id2name:
                    full_name = active_module_tracer().id2name[id(mod_attr)]
                    scope_name = active_module_tracer().current_scope()._module_name
                    if full_name and scope_name:
                        full_name = full_name[len(scope_name) + 1 :]
                    else:
                        full_name = name
                else:
                    full_name = name
                # assert not self._is_builtin
                if isinstance(wrapped, (NodeMixin, RawTensor)):
                    NodeMixin.wrap(
                        wrapped,
                        lambda: GetAttr.make(
                            NodeMixin.get(self),
                            name,
                            type=NodeMixin.get_wrapped_type(wrapped),
                            orig_name=full_name,
                        ),
                    )
            return wrapped
class _expr_iter:
    def __init__(self, graph: InternalGraph, recursive: bool = True):
        self.graph = graph
        self.recursive = recursive

    def __iter__(self):
        for inp_node in self.graph.inputs:
            yield inp_node.expr
        for expr in self.graph._exprs:
            if isinstance(expr, CallMethod) and isinstance(expr.inputs[0], ModuleNode):
                yield expr
                if self.recursive and expr.graph is not None:
                    yield from expr.graph.exprs(self.recursive)
            else:
                yield expr


class _node_iter:
    def __init__(self, graph: InternalGraph, recursive: bool = True) -> None:
        nodes = []
        node_ids = set()
        for expr in graph.exprs(recursive):
            for n in expr.inputs + expr.outputs:
                if id(n) in node_ids:
                    continue
                nodes.append(n)
                node_ids.add(id(n))
        self.nodes = list(sorted(nodes, key=lambda x: x._id))

    def __iter__(self):
        for node in self.nodes:
            yield node
class BaseFilter:
    r"""``BaseFilter`` exposes some methods for converting ``_node_iter/_expr_iter`` to ``list``, ``dict``, etc."""

    def __init__(self, iter: Iterable):
        self._iter = iter

    def __iter__(self):
        return iter(self._iter)

    def as_list(self):
        r"""Consume this iterator and return its content as a list.

        Returns:
            A list of ``Node`` or ``Expr``.
        """
        return list(self)

    def as_dict(self):
        r"""Construct an ordered dict to map from ``id`` to objects in this iterator.

        Returns:
            An :class:`OrderedDict`.
        """
        return collections.OrderedDict((i._id, i) for i in self)

    def as_unique(self):
        r"""Assert that this iterator yields only one ``Node`` or ``Expr`` and return it.

        Returns:
            A ``Node`` or ``Expr``.
        """
        rst = self.as_list()
        assert len(rst) == 1, "{} elements found".format(len(rst))
        (elem,) = self
        return elem

    def as_count(self):
        r"""Consume this iterator and get the number of elements."""
        return sum(1 for _ in self)
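
# Illustrative usage sketch (not part of the original module): the ``as_*``
# helpers above consume the iterator, so a fresh filter is built for each
# query. ``graph.nodes()`` is assumed to return a ``NodeFilter``, mirroring
# how ``exprs``/``nodes`` are used elsewhere in this file.
def _example_basefilter_conversions(traced_module):
    all_nodes = traced_module.graph.nodes().as_list()  # [Node, ...]
    id_to_node = traced_module.graph.nodes().as_dict()  # OrderedDict {id: Node}
    n_nodes = traced_module.graph.nodes().as_count()  # int
    # as_unique() asserts that exactly one element matched
    node_0 = traced_module.graph.nodes().node_id(0).as_unique()
    return all_nodes, id_to_node, n_nodes, node_0
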
class ExprFilter(BaseFilter):
    """Filter on Expr iterator.

    This class is an iterator of :class:`.Expr` objects and multiple
    filtering conditions and mappers can be chained.
    """

    def call_function(self, func):
        r"""Filter by specific ``CallFunction.func``.

        See :meth:`~.InternalGraph.get_function_by_type` for details.
        """
        return ExprFilterCallFunction(self, func)

    def call_method(self, method):
        r"""Filter by specific ``CallMethod.method``.

        See :meth:`~.InternalGraph.get_method_by_type` for details.
        """
        return ExprFilterCallMethod(self, method)

    def expr_id(self, expr_id: List[int]):
        r"""Filter Exprs by their ``id``.

        See :meth:`~.InternalGraph.get_expr_by_id` for details.
        """
        return ExprFilterExprId(self, expr_id)


class NodeFilter(BaseFilter):
    """Filter on Node iterator.

    This class is an iterator of :class:`.Node` objects and multiple
    filtering conditions and mappers can be chained.
    """

    def type(self, owner_type):
        r"""Filter by specific Module type.

        See :meth:`~.InternalGraph.get_module_by_type` for details.
        """
        return NodeFilterType(self, owner_type)

    def node_id(self, node_id: List[int]):
        r"""Filter Nodes by their ``id``.

        See :meth:`~.InternalGraph.get_node_by_id` for details.
        """
        return NodeFilterNodeId(self, node_id)

    def name(self, name: str, ignorecase: bool = True):
        r"""Filter Nodes by their full name.

        See :meth:`~.InternalGraph.get_node_by_name` for details.
        """
        return NodeFilterName(self, name, ignorecase)


class NodeFilterType(NodeFilter):
    """See :meth:`~.InternalGraph.get_module_by_type`."""

    def __init__(self, expr_iter, owner_type):
        super().__init__(expr_iter)
        self.owner_type = owner_type

    def __iter__(self):
        for node in self._iter:
            if not isinstance(node, ModuleNode):
                continue
            if not hasattr(node, "owner"):
                continue
            if isinstance(node.owner, self.owner_type):
                yield node


class NodeFilterNodeId(NodeFilter):
    """See :meth:`~.InternalGraph.get_node_by_id`."""

    def __init__(self, expr_iter, node_id: List[int]):
        super().__init__(expr_iter)
        if not isinstance(node_id, Sequence):
            node_id = [node_id]
        self.node_id = node_id

    def __iter__(self):
        for node in self._iter:
            if node._id in self.node_id:
                yield node


class NodeFilterName(NodeFilter):
    """See :meth:`~.InternalGraph.get_node_by_name`."""

    _re = None

    def __init__(self, node_iter, pattern, ignorecase):
        super().__init__(node_iter)
        self.pattern = pattern
        self._re = self.make_re(pattern, ignorecase)

    @classmethod
    def make_re(cls, pattern, ignorecase=True):
        assert isinstance(pattern, str), "bad pattern: {!r}".format(pattern)
        assert isinstance(ignorecase, bool)
        flags = 0
        if ignorecase:
            flags |= re.IGNORECASE
        return re.compile(fnmatch.translate(pattern), flags=flags)

    def __iter__(self):
        for i in self._iter:
            graph = i.top_graph
            name = "{}_{}".format(graph._name, i._name.lstrip("_"))
            if graph._prefix_name:
                name = "{}_{}".format(graph._prefix_name, name.lstrip("_"))
            if self.pattern == name or self._re.match(name):
                yield i


class ExprFilterCallFunction(ExprFilter):
    """See :meth:`~.InternalGraph.get_function_by_type`."""

    def __init__(self, expr_iter, func: Callable = None):
        super().__init__(expr_iter)
        self.func = func

    def __iter__(self):
        for expr in self._iter:
            if not isinstance(expr, CallFunction):
                continue
            if self.func is None or expr.func == self.func:
                yield expr


class ExprFilterCallMethod(ExprFilter):
    """See :meth:`~.InternalGraph.get_method_by_type`."""

    def __init__(self, expr_iter, method: str = None):
        super().__init__(expr_iter)
        self.method = method

    def __iter__(self):
        for expr in self._iter:
            if not isinstance(expr, CallMethod):
                continue
            if self.method is None or expr.method == self.method:
                yield expr


class ExprFilterExprId(ExprFilter):
    """See :meth:`~.InternalGraph.get_expr_by_id`."""

    def __init__(self, expr_iter, expr_id: List[int]):
        super().__init__(expr_iter)
        if not isinstance(expr_id, Sequence):
            expr_id = [expr_id]
        self.expr_id = expr_id

    def __iter__(self):
        for expr in self._iter:
            if expr._id in self.expr_id:
                yield expr
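
# Illustrative usage sketch (not part of the original module): the concrete
# filters above are meant to be chained from ``graph.exprs()``/``graph.nodes()``
# (assumed to return ``ExprFilter``/``NodeFilter``, as elsewhere in this file).
def _example_filter_chaining(traced_module):
    graph = traced_module.graph
    # all submodule invocations: CallMethod exprs whose method is "__call__"
    submodule_calls = graph.exprs().call_method("__call__").as_list()
    # look an Expr up by id; a scalar id is promoted to a one-element list
    expr_0 = graph.exprs().expr_id(0).as_unique()
    # glob-style, case-insensitive match on the node's full name
    conv_nodes = graph.nodes().name("*conv*").as_list()
    return submodule_calls, expr_0, conv_nodes
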
class TracedModule(Module):
    r"""``TracedModule`` is the Module created by tracing a normal module.

    It owns an argdef-to-graph (:class:`InternalGraph`) map. The forward method of a
    ``TracedModule`` picks a graph from ``argdef_graph_map`` according to the argdef
    of the input ``args/kwargs`` and interprets it.

    .. note::
        ``TracedModule`` can only be created by :func:`~.trace_module`. See
        :func:`~.trace_module` for more details.
    """

    # m_node = None  # type: ModuleNode
    argdef_graph_map = None
    argdef_outdef_map = None

    def __init__(self, is_top, argdef_graph_map, argdef_outdef_map, is_qat=False):
        super(TracedModule, self).__init__()
        self.argdef_graph_map = argdef_graph_map
        self.argdef_outdef_map = argdef_outdef_map
        self._is_top = is_top
        self.watch_points = []
        self.watch_node_value = {}
        self.end_points = []
        self.is_qat = is_qat

    def forward(self, *args, **kwargs):
        inputs, treedef = tree_flatten(((self, *args), kwargs))
        assert treedef in self.argdef_graph_map
        inputs = filter(
            lambda i: isinstance(i, (Module, TracedModuleBuilder, RawTensor)), inputs
        )  # allow TracedModuleBuilder for retrace.
        outputs = self.argdef_graph_map[treedef].interpret(*inputs)
        if self.watch_points:
            self.watch_node_value = {}
            for n in self.watch_points:
                self.watch_node_value[n] = n.top_graph._rst.pop(n)
        if self.end_points:
            return outputs
        out_def = self.argdef_outdef_map[treedef]
        outputs = out_def.unflatten(outputs)
        return outputs
    def set_watch_points(self, nodes):
        r"""Initialize the :attr:`~.TracedModule.watch_points`.

        You can call this function to get the ``Tensor/Module`` corresponding to a
        ``Node`` at runtime.

        Args:
            nodes: a list of ``Node``.

        For example, the following code

        .. code-block::

            import megengine.module as M
            import megengine as mge
            import megengine.traced_module as tm

            class MyModule(M.Module):
                def forward(self, x):
                    x = x + 1 + 2
                    return x

            net = MyModule()
            inp = mge.Tensor([0])
            traced_module = tm.trace_module(net, inp)
            add_1_node = traced_module.graph.get_node_by_id(2).as_unique()
            traced_module.set_watch_points(add_1_node)
            out = traced_module(inp)

        will produce the following ``watch_node_value``::

            print(traced_module.watch_node_value)

        .. code-block:: text

            {add_out: Tensor([1.], device=xpux:0)}
        """
        if not isinstance(nodes, Sequence):
            nodes = [nodes]
        self.watch_points = nodes
        if nodes:
            nodes[0].top_graph._watch_point = []
        for n in nodes:
            n.top_graph._watch_point.append(n)

    def clear_watch_points(self):
        r"""Clear the :attr:`~.TracedModule.watch_points` and
        :attr:`~.TracedModule.watch_node_value`.
        """
        for n in self.watch_points:
            n.top_graph._watch_point = []
        self.watch_points = []
        self.watch_node_value = {}

    def set_end_points(self, nodes: Sequence[Node]):
        r"""Initialize the :attr:`~.TracedModule.end_points`.

        When all the ``nodes`` are generated, the Module will stop execution and
        return directly.

        Args:
            nodes: a list of ``Node``.

        For example, the following code

        .. code-block::

            import megengine.module as M
            import megengine as mge
            import megengine.traced_module as tm

            class MyModule(M.Module):
                def forward(self, x):
                    x = x + 1 + 2
                    return x

            net = MyModule()
            inp = mge.Tensor([0])
            traced_module = tm.trace_module(net, inp)
            add_1_node = traced_module.graph.get_node_by_id(2).as_unique()
            traced_module.set_end_points(add_1_node)
            out = traced_module(inp)

        will produce the following ``out``::

            print(out)

        .. code-block:: text

            [Tensor([1.], device=xpux:0)]
        """
        if not isinstance(nodes, Sequence):
            nodes = [nodes]
        self.end_points = nodes
        graphs = list(self.argdef_graph_map.values())
        for n in nodes:
            assert n.top_graph in graphs
            n.top_graph._end_point.append(n)

    def clear_end_points(self):
        r"""Clear the :attr:`~.TracedModule.end_points`."""
        for n in self.end_points:
            n.top_graph._end_point = []
        self.end_points = []
    @property
    def graph(self) -> InternalGraph:
        """Return the ``InternalGraph`` of this ``TracedModule``."""
        if self._is_top:
            self._update_ref()
        assert len(self.argdef_graph_map) == 1
        return list(self.argdef_graph_map.values())[0]

    def _update_ref(self, actual_node_map: Union[Dict, None] = None, top_graph=None):
        for inp_def, graph in self.argdef_graph_map.items():
            if top_graph is not None:
                graph._top_graph = weakref.ref(top_graph)
            for n in graph._inputs + graph.outputs:
                n._top_graph = weakref.ref(graph)
            graph._inputs[0]._owner = weakref.ref(self)
            for i, n in enumerate(graph._inputs):
                n.actual_node = []
                if actual_node_map is not None and inp_def in actual_node_map.keys():
                    n.actual_node = list(list(zip(*(actual_node_map[inp_def])))[i])
            node2obj = {}
            next_actual_node_map = collections.defaultdict(
                lambda: collections.defaultdict(list)
            )
            node2obj[graph._inputs[0]] = self
            for expr in graph._exprs:
                for n in expr.inputs + expr.outputs:
                    n._top_graph = weakref.ref(graph)
                expr._top_graph = weakref.ref(graph)
                if isinstance(expr, GetAttr) and isinstance(
                    expr.outputs[0], ModuleNode
                ):
                    obj = getattr(node2obj[expr.inputs[0]], expr.name)
                    expr.outputs[0]._owner = weakref.ref(obj)
                    node2obj[expr.outputs[0]] = obj
                if isinstance(expr, Constant) and isinstance(
                    expr.outputs[0], ModuleNode
                ):
                    obj = expr.value
                    expr.outputs[0]._owner = weakref.ref(obj)
                    node2obj[expr.outputs[0]] = obj
                if (
                    isinstance(expr, CallMethod)
                    and expr.method == "__call__"
                    and isinstance(expr.inputs[0], ModuleNode)
                ):
                    obj = node2obj[expr.inputs[0]]
                    if expr.arg_def is not None:
                        next_actual_node_map[obj][expr.arg_def].append(expr.inputs)

            for obj in node2obj.values():
                if obj is self:
                    continue
                mnode_map = None
                if obj in next_actual_node_map.keys():
                    mnode_map = next_actual_node_map[obj]
                if isinstance(obj, TracedModule):
                    obj._update_ref(mnode_map, graph)
    def flatten(self):
        r"""Get a new TracedModule, which eliminates ``GetAttr`` and has no hierarchy.

        Returns:
            A new :class:`TracedModule`.
        """
        new_module = copy.deepcopy(self)
        assert active_module_tracer() is None
        id2name = _init_id2name(new_module, "self")
        set_active_module_tracer(module_tracer(lambda x: x, {}))
        active_module_tracer().push_scope(new_module.graph)

        def _flatten_subgraph(
            parent_graph: InternalGraph,
            graph: InternalGraph,
            module: Module,
            call=None,
            prefix_name="",
            module_name="",
        ):
            if isinstance(prefix_name, str) and prefix_name and prefix_name[-1] != "_":
                prefix_name += "_"
            if isinstance(module_name, str) and module_name:
                module_name += "."
            if graph is None or module.is_qat:
                assert not isinstance(module, TracedModule) or module.is_qat
                const = Constant(module, id2name[id(module)])
                m_node = call.inputs[0]
                if m_node.top_graph != active_module_tracer().current_scope():
                    m_node._name = (
                        active_module_tracer()
                        .current_scope()
                        ._create_unique_name(prefix_name)
                    )
                    m_node._orig_name = id2name[id(module)][5:]
                const.outputs[0] = m_node
                const.outputs[0].expr = const
                return [const, call]
            if call is not None:
                graph = copy.deepcopy(graph)
            exprs = []
            node2obj = {}
            node2obj[graph._inputs[0]] = module
            if call:
                node2obj[call.inputs[0]] = module

            # replace inputs for submodule's exprs
            if call:
                repl_dict = dict(zip(graph._inputs, call.inputs))
                for ind, out in enumerate(graph.outputs):
                    if isinstance(out.expr, Input):
                        assert out in repl_dict
                        call_out = call.outputs[ind]
                        for expr in call.outputs[ind].users:
                            for index, inp in enumerate(expr.inputs):
                                if inp is call_out:
                                    expr.inputs[index] = repl_dict[out]
                                    repl_dict[out].users.append(expr)
                        if parent_graph is not None:
                            for index, parent_out in enumerate(parent_graph._outputs):
                                if parent_out is call_out:
                                    parent_graph._outputs[index] = repl_dict[out]
                        continue
                    repl_dict[out] = call.outputs[ind]

                graph._replace_inputs_outputs(repl_dict, prefix_name, module_name)

            for expr in graph._exprs:
                if isinstance(expr, GetAttr):
                    # replace GetAttr with Constant
                    if isinstance(expr.outputs[0], TensorNode):
                        const = Constant(getattr(node2obj[expr.inputs[0]], expr.name))
                        const.outputs = expr.outputs
                        const.outputs[0].expr = const
                        exprs.append(const)
                    elif isinstance(expr.outputs[0], ModuleNode):
                        node2obj[expr.outputs[0]] = getattr(
                            node2obj[expr.inputs[0]], expr.name
                        )
                elif isinstance(expr, CallMethod):
                    obj_node = expr.inputs[0]
                    if isinstance(obj_node, ModuleNode):
                        pre_expr = expr.inputs[0].expr
                        if isinstance(pre_expr, GetAttr):
                            (obj,) = pre_expr.interpret(node2obj[pre_expr.inputs[0]])
                            expr_graph = (
                                obj.argdef_graph_map[expr.arg_def]
                                if hasattr(obj, "argdef_graph_map")
                                else None
                            )
                            exprs.extend(
                                _flatten_subgraph(
                                    graph,
                                    expr_graph,
                                    obj,
                                    expr,
                                    prefix_name + obj_node._name.lstrip("_"),
                                    module_name + obj_node._orig_name,
                                )
                            )
                        else:
                            # module has been replaced.
                            assert isinstance(pre_expr, Constant)
                            exprs.append(expr)
                    else:
                        exprs.append(expr)
                else:
                    exprs.append(expr)

            if call is not None:
                for i in call.inputs:
                    i.users.remove(call)

            return exprs

        new_module.graph._exprs = _flatten_subgraph(None, new_module.graph, new_module)
        new_module.graph.compile()
        set_active_module_tracer(None)
        new_module.graph._reset_ids()
        return new_module

    def __getstate__(self):
        d = self.__dict__
        for k in Module.__dict__:
            d.pop(k, None)
        return d
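
# Illustrative sketch (not part of the original module): after ``flatten`` the
# hierarchy is gone, so recursive and non-recursive traversal should see the
# same exprs (sub-graphs no longer exist), assuming ``graph.exprs(recursive)``
# forwards to ``_expr_iter`` as it does elsewhere in this file.
def _example_flatten(traced_module):
    flat = traced_module.flatten()
    assert (
        flat.graph.exprs(recursive=True).as_count()
        == flat.graph.exprs(recursive=False).as_count()
    )
    return flat
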
def cpp_apply_module_trace(opdef, *args):
    return Apply.apply_module_trace_hook(opdef, *args)


def register_as_builtin(mod_cls: Type[Module]) -> None:
    r"""Registers class ``mod_cls`` (a subclass of :class:`~.Module`) as a builtin module.

    Args:
        mod_cls: the module class which will be treated as a builtin module in tracing.
    """
    module_tracer.register_as_builtin(mod_cls)
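
# Illustrative sketch (not part of the original module): a registered builtin
# module is kept opaque during tracing, i.e. its forward is recorded as a
# single CallMethod expr instead of being traced into a sub-graph.
def _example_register_as_builtin():
    import megengine.functional as F
    import megengine.module as M
    import megengine.traced_module as tm

    class Swish(M.Module):
        def forward(self, x):
            return x * F.sigmoid(x)

    tm.register_as_builtin(Swish)
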
def wrap(func: Callable):
    r"""Call this function to register ``func`` as a builtin function.

    This function can be called at module-level scope to register ``func`` as a
    builtin function. A builtin function will be converted to a
    :class:`CallFunction` Expr in tracing::

        def my_func(x, y):
            return x + y

        import megengine.traced_module as tm
        tm.wrap(my_func)

    This function can also equivalently be used as a decorator::

        @tm.wrap
        def my_func(x, y):
            return x + y

    Args:
        func: the global function to insert into the graph when it is called.
    """
    assert callable(func), "func must be a callable"
    assert hasattr(func, "__code__")
    fn_name = func.__code__.co_name
    currentframe = inspect.currentframe()
    assert currentframe is not None
    f = currentframe.f_back
    assert f is not None
    assert (
        f.f_code.co_name == "<module>"
    ), "wrap must be called at the top level of a module"
    Patcher._builtin_functions.append((f.f_globals, fn_name))
    return func
def _register_all_builtin_module():
    for sub_mod in [M, M.qat, M.quantized]:
        for m in getmembers(sub_mod):
            if (
                isclass(m[1])
                and issubclass(m[1], M.Module)
                and m[1] is not M.Sequential
            ):
                module_tracer.register_as_builtin(m[1])

    module_tracer.register_as_builtin(Observer)
    module_tracer.register_as_builtin(MinMaxObserver)
    module_tracer.register_as_builtin(SyncMinMaxObserver)
    module_tracer.register_as_builtin(ExponentialMovingAverageObserver)
    module_tracer.register_as_builtin(SyncExponentialMovingAverageObserver)
    module_tracer.register_as_builtin(HistogramObserver)
    module_tracer.register_as_builtin(PassiveObserver)
    module_tracer.register_as_builtin(LSQ)
    module_tracer.register_as_builtin(TQT)
    module_tracer.register_as_builtin(FakeQuantize)
    module_tracer.register_as_builtin(TM_FakeQuant)
def trace_module(
    mod: Module, *args: Tuple[Any], **kwargs: Dict[str, Any]
) -> TracedModule:
    r"""Traces module ``mod`` and returns the corresponding :class:`TracedModule`.

    Args:
        mod: the module that will be converted to a :class:`TracedModule`.
        args: the positional arguments passed to the forward method of ``mod``.
        kwargs: the keyword arguments passed to the forward method of ``mod``.
    """
    assert active_module_tracer() is None
    assert isinstance(mod, Module)
    try:
        use_sym_shape = set_symbolic_shape(True)
        set_module_tracing()
        set_active_module_tracer(
            module_tracer(_wrapped_function, _init_id2name(mod, "self"))
        )
        for cls in [Expr, Node]:
            cls._set_next_id(0)
        with active_module_tracer().patcher:
            global_scope = InternalGraph(name="")
            active_module_tracer().push_scope(global_scope)
            builder = TracedModuleBuilder(mod, True)

            name = mod._name if mod._name else mod.__class__.__name__
            NodeMixin.wrap_safe(builder, Input.make(name, ModuleNode, orig_name="self"))
            inputs, _ = tree_flatten((args, kwargs))
            for idx, i in enumerate(inputs):
                # assert isinstance(i, Tensor), "not support "
                if isinstance(i, RawTensor):
                    NodeMixin.wrap_safe(
                        i,
                        Input.make("arg_{}".format(idx), NodeMixin.get_wrapped_type(i)),
                    )
            builder(*args, **kwargs)
            active_module_tracer().pop_scope()
            traced_mod = builder.build()
            traced_mod.graph._reset_ids()
            return traced_mod
    finally:
        set_symbolic_shape(use_sym_shape)
        set_active_module_tracer(None)
        unset_module_tracing()
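
# Illustrative sketch (not part of the original module): end-to-end usage of
# ``trace_module``, mirroring the examples in the docstrings above. Calling
# the returned ``TracedModule`` interprets the recorded graph rather than
# running the original forward.
def _example_trace_module():
    import megengine as mge
    import megengine.module as M
    import megengine.traced_module as tm

    class Net(M.Module):
        def forward(self, x):
            return x * 2 + 1

    traced = tm.trace_module(Net(), mge.Tensor([1.0]))
    out = traced(mge.Tensor([3.0]))  # interpreted via the InternalGraph
    return traced, out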
