
traced_module.py

  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import builtins
  9. import collections
  10. import copy
  11. import fnmatch
  12. import functools
  13. import inspect
  14. import keyword
  15. import re
  16. import weakref
  17. from inspect import getcallargs, getmembers, isclass, ismethod
  18. from itertools import chain
  19. from types import FunctionType
  20. from typing import (
  21. Any,
  22. Callable,
  23. Dict,
  24. Iterable,
  25. List,
  26. Optional,
  27. Sequence,
  28. Tuple,
  29. Type,
  30. Union,
  31. )
  32. from .. import functional as F
  33. from .. import get_logger
  34. from .. import module as M
  35. from ..core._imperative_rt.core2 import Tensor as RawTensor
  36. from ..core._imperative_rt.core2 import (
  37. is_tracing_module,
  38. set_module_tracing,
  39. unset_module_tracing,
  40. )
  41. from ..core._trace_option import set_symbolic_shape
  42. from ..module import Module
  43. from ..module.qat import QATModule
  44. from ..quantization.fake_quant import LSQ, TQT, FakeQuantize, _FakeQuantize
  45. from ..quantization.observer import (
  46. ExponentialMovingAverageObserver,
  47. HistogramObserver,
  48. MinMaxObserver,
  49. Observer,
  50. PassiveObserver,
  51. SyncExponentialMovingAverageObserver,
  52. SyncMinMaxObserver,
  53. )
  54. from ..tensor import Tensor
  55. from .expr import (
  56. Apply,
  57. CallFunction,
  58. CallMethod,
  59. Constant,
  60. Expr,
  61. GetAttr,
  62. Input,
  63. get_suffix_name,
  64. is_apply_def,
  65. is_call_function,
  66. is_call_module,
  67. is_call_tensor_method,
  68. is_constant,
  69. is_getattr,
  70. )
  71. from .fake_quant import FakeQuantize as TM_FakeQuant
  72. from .module_tracer import (
  73. PatchedFn,
  74. Patcher,
  75. active_module_tracer,
  76. get_tensor_wrapable_method,
  77. module_tracer,
  78. set_active_module_tracer,
  79. )
  80. from .node import ModuleNode, Node, NodeMixin, TensorNode
  81. from .pytree import ArgsIndex, tree_flatten
  82. from .utils import replace_container_with_module_container
  83. logger = get_logger(__name__)
  84. def _is_builtin_name(name: str) -> bool:
  85. return (
  86. name in builtins.__dict__
  87. or name in keyword.kwlist
  88. or name in {"inf", "nan", "NoneType"}
  89. )
  90. def _is_leaf(node):
  91. assert isinstance(node, RawTensor), "doesn't support {} in return values".format(
  92. type(node)
  93. )
  94. return isinstance(node, RawTensor)
  95. _enable_node_to_tensor = False
  96. def _convert_node_flag():
  97. return _enable_node_to_tensor
  98. def _set_convert_node_flag(flag: bool = False):
  99. global _enable_node_to_tensor
  100. pre_flag = _enable_node_to_tensor
  101. _enable_node_to_tensor = flag
  102. return pre_flag
  103. def _node_to_tensor(*args, **kwargs):
  104. tensors = []
  105. nodes, tree_def = tree_flatten((args, kwargs))
  106. for n in nodes:
  107. if isinstance(n, TensorNode):
  108. if n.top_graph is not None:
  109. active_module_tracer().current_scope()._add_input(n)
  110. value = n.value
  111. if value is None:
  112. flag = _set_convert_node_flag(False)
  113. unset_module_tracing()
  114. value = F.zeros(shape=n._shape, dtype=n._dtype)
  115. set_module_tracing()
  116. _set_convert_node_flag(flag)
  117. orig_n = NodeMixin.get(value, None)
  118. if orig_n is None or "setitem" not in orig_n._name:
  119. NodeMixin.wrap_safe(value, n)
  120. tensors.append(value)
  121. else:
  122. tensors.append(n)
  123. tensors = tree_def.unflatten(tensors)
  124. return tensors
  125. def _tensor_to_node(tensors):
  126. if tensors is None:
  127. return None
  128. nodes = []
  129. tensors, out_def = tree_flatten(tensors)
  130. for t in tensors:
  131. if isinstance(t, Tensor):
  132. n = NodeMixin.get(t, None)
  133. if isinstance(n, TensorNode):
  134. n.value = t
  135. nodes.append(n)
  136. else:
  137. nodes.append(t)
  138. else:
  139. nodes.append(t)
  140. nodes = out_def.unflatten(nodes)
  141. return nodes
  142. def _wrap_method_to_tensor_node():
  143. def _any_method(name):
  144. def _any(*args, **kwargs):
  145. args, kwargs = _node_to_tensor(*args, **kwargs)
  146. attr = getattr(args[0], name)
  147. outs = attr
  148. if callable(attr):
  149. outs = attr(*(args[1:]), **kwargs)
  150. if name == "__setitem__":
  151. _node_to_tensor(outs)
  152. return None
  153. outs = _tensor_to_node(outs)
  154. return outs
  155. return _any
  156. tensor_method_patch = []
  157. for method in get_tensor_wrapable_method():
  158. patch = PatchedFn(TensorNode, method)
  159. if type(getattr(Tensor, method)) == property:
  160. patch.set_func(property(_any_method(method)))
  161. else:
  162. patch.set_func(_any_method(method))
  163. tensor_method_patch.append(patch)
  164. return tensor_method_patch
  165. def _convert_node_and_tensor(orig_func):
  166. @functools.wraps(orig_func)
  167. def _convert(*args, **kwargs):
  168. if _convert_node_flag() and is_tracing_module():
  169. args, kwargs = _node_to_tensor(*args, **kwargs)
  170. rst = orig_func(*args, **kwargs, method_func=_convert)
  171. rst = _tensor_to_node(rst)
  172. return rst
  173. else:
  174. rst = orig_func(*args, **kwargs)
  175. return rst
  176. return _convert
  177. def _wrap_mnode_getattr(orig_getattr):
  178. @functools.wraps(orig_getattr)
  179. def wraped_fn(self, name):
  180. obj = self.owner
  181. current_graph = active_module_tracer().current_scope()
  182. if self.top_graph is not None:
  183. current_graph._add_input(self)
  184. attr = getattr(obj, name)
  185. node = attr
  186. if not isinstance(attr, TracedModuleBuilder):
  187. if isinstance(attr, Module):
  188. attr = TracedModuleBuilder(attr)
  189. setattr(obj, name, attr)
  190. if isinstance(attr, (NodeMixin, RawTensor)):
  191. NodeMixin.wrap(
  192. attr,
  193. lambda: GetAttr.make(
  194. self,
  195. type=NodeMixin.get_wrapped_type(attr),
  196. attr_name=name,
  197. name="",
  198. ),
  199. )
  200. if isinstance(attr, (NodeMixin, RawTensor)):
  201. node = NodeMixin.get(attr)
  202. if isinstance(node, ModuleNode):
  203. node._owner = weakref.ref(attr)
  204. return node
  205. return wraped_fn
  206. def _wrap_mnode_call(orig_call):
  207. @functools.wraps(orig_call)
  208. def wraped_fn(self, *args, **kwargs):
  209. obj = self.owner
  210. if self.top_graph is not None:
  211. active_module_tracer().current_scope()._add_input(self)
  212. rst = obj(*args, **kwargs)
  213. return rst
  214. return wraped_fn
  215. class _InsertExprs:
  216. def __init__(self, graph, expr: Optional[Expr] = None):
  217. self.graph = graph
  218. while graph.top_graph is not None:
  219. graph = graph.top_graph
  220. assert graph.inputs[0].owner._is_top
  221. self.root_graph = graph
  222. self.global_scope = InternalGraph(self.graph._name, self.graph._qualname)
  223. self.global_scope._namespace.merge(self.graph._namespace)
  224. self.expr = expr
  225. self._tensor_method_patch = None
  226. def __enter__(self):
  227. self.use_sym_shape = set_symbolic_shape(True)
  228. node_id, expr_id = self.root_graph._total_ids
  229. Node._set_next_id(node_id)
  230. Expr._set_next_id(expr_id)
  231. set_module_tracing()
  232. _set_convert_node_flag(True)
  233. assert active_module_tracer() is None
  234. set_active_module_tracer(
  235. module_tracer(lambda x: _convert_node_and_tensor(_wrapped_function(x)))
  236. )
  237. active_module_tracer().patcher.__enter__()
  238. for cls, name, func in [
  239. [ModuleNode, "__getattr__", _wrap_mnode_getattr],
  240. [ModuleNode, "__call__", _wrap_mnode_call],
  241. [TracedModuleBuilder, "__call__", _convert_node_and_tensor],
  242. ]:
  243. active_module_tracer().patcher.patch_function(cls, name, func)
  244. self._tensor_method_patch = _wrap_method_to_tensor_node()
  245. active_module_tracer().push_scope(self.global_scope)
  246. def __exit__(self, ty, va, tr):
  247. if va is not None:
  248. return False
  249. set_symbolic_shape(self.use_sym_shape)
  250. active_module_tracer().patcher.__exit__(ty, va, tr)
  251. _set_convert_node_flag(False)
  252. set_active_module_tracer(None)
  253. unset_module_tracing()
  254. while self._tensor_method_patch:
  255. pf = self._tensor_method_patch.pop()
  256. pf.set_func(pf.origin_fn)
  257. module = self.graph.inputs[0].owner
  258. for mod, parent in module.modules(with_parent=True):
  259. name = mod._name
  260. if isinstance(mod, TracedModuleBuilder):
  261. mod = mod.build()
  262. if hasattr(mod, "argdef_graph_map"):
  263. for g in mod.argdef_graph_map.values():
  264. for n in g.nodes(False):
  265. if isinstance(n, TensorNode):
  266. n.value = None
  267. setattr(parent, name, mod)
  268. for node in self.global_scope.nodes(False):
  269. node.value = None
  270. extra_inp_nodes = set(self.global_scope.inputs)
  271. max_inp_expr_idx = -1
  272. for node in extra_inp_nodes:
  273. assert (
  274. node.top_graph == self.graph
  275. ), "The input node ({}) is not in the graph ({})".format(node, self.graph)
  276. if isinstance(node, TensorNode) and node.expr in self.graph._exprs:
  277. max_inp_expr_idx = max(
  278. max_inp_expr_idx, self.graph._exprs.index(node.expr)
  279. )
  280. max_inp_expr_idx += 1
  281. insert_index = -1
  282. if self.expr is not None:
  283. insert_index = self.graph._exprs.index(self.expr)
  284. insert_index += 1
  285. if insert_index < max_inp_expr_idx:
  286. insert_index = max_inp_expr_idx
  287. for expr in self.global_scope._exprs:
  288. self.graph._exprs.insert(insert_index, expr)
  289. insert_index += 1
  290. self.graph._namespace.merge(self.global_scope._namespace)
  291. self.root_graph._total_ids = (Node._get_next_id(), Expr._get_next_id())
  292. self.root_graph.inputs[0].owner._update_ref()
  293. return True
  294. class NameSpace:
  295. def __init__(self, name, qualname):
  296. self.name = name
  297. self.qualname = qualname
  298. self._used_names = {}
  299. def create_unique_name(self, name: str) -> str:
  300. assert isinstance(name, str), "The name must be a string"
  301. name = re.sub("[^0-9a-zA-Z_]+", "_", name)
  302. if name[0].isdigit():
  303. name = "_{}".format(name)
  304. while name in self._used_names or _is_builtin_name(name):
  305. match = re.match(r"(.*)_(\d+)$", name)
  306. if match is None:
  307. name = name + "_1"
  308. else:
  309. base, num = match.group(1, 2)
  310. name = "{}_{}".format(base, int(num) + 1)
  311. self._used_names.setdefault(name)
  312. return name
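  # A hedged sketch of the collision handling above (a bare ``NameSpace``; the
  # names are illustrative assumptions):
  #
  #     ns = NameSpace("demo", "demo")
  #     ns.create_unique_name("relu-out")   # -> "relu_out" (sanitized)
  #     ns.create_unique_name("relu_out")   # -> "relu_out_1" (already used)
  #     ns.create_unique_name("1conv")      # -> "_1conv" (digit-leading)
  #     ns.create_unique_name("input")      # -> "input_1" ("input" is a builtin)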
  313. def auto_naming_for_outputs(self, expr: Expr):
  314. _add_suffix = lambda x: x + "_out"
  315. if is_call_module(expr):
  316. call_node = expr.inputs[0]
  317. qualname = "%s.[out]" % (call_node.qualname)
  318. name = call_node.name
  319. elif is_call_tensor_method(expr):
  320. name = expr.method.strip("_")
  321. qualname = "{}.[{}]".format(
  322. self.qualname, self.create_unique_name("method_%s" % (name)),
  323. )
  324. elif is_call_function(expr):
  325. name = expr.func.__name__
  326. qualname = "{}.[{}]".format(
  327. self.qualname, self.create_unique_name("func_%s" % name),
  328. )
  329. elif is_apply_def(expr):
  330. name = str(expr.opdef).lower()
  331. qualname = "{}.[{}]".format(
  332. self.qualname, self.create_unique_name("def_%s" % name),
  333. )
  334. elif is_getattr(expr):
  335. qualname = "{}.{}".format(expr.inputs[0].qualname, expr.name)
  336. name = get_suffix_name(self.qualname, qualname)
  337. _add_suffix = lambda x: x
  338. elif is_constant(expr):
  339. name = (
  340. expr.name if expr.name else "const_" + type(expr.value).__name__.lower()
  341. )
  342. qualname = "{}.[{}]".format(self.qualname, name)
  343. _add_suffix = lambda x: x
  344. for node in expr.outputs:
  345. if node._name == "" or node._name in self.used_names:
  346. assert _add_suffix(name) == name or isinstance(node, TensorNode)
  347. node._name = self.create_unique_name(_add_suffix(name))
  348. if node._qualname == "":
  349. node._qualname = qualname
  350. assert get_suffix_name(self.qualname, qualname)
  351. def merge(self, other: "NameSpace"):
  352. self._used_names.update(other.used_names)
  353. @property
  354. def used_names(self):
  355. return self._used_names
  356. class InternalGraph:
  357. r"""``InternalGraph`` is the main data structure used in the TracedModule.
  358. It represents the execution procedure of a Module's ``forward`` method.
  359. For example, the following code
  360. .. code-block::
  361. import megengine.random as rand
  362. import megengine.functional as F
  363. import megengine.module as M
  364. import megengine.traced_module as tm
  365. class MyModule(M.Module):
  366. def __init__(self):
  367. super().__init__()
  368. self.param = rand.normal(size=(3, 4))
  369. self.linear = M.Linear(4, 5)
  370. def forward(self, x):
  371. return F.relu(self.linear(x + self.param))
  372. net = MyModule()
  373. inp = F.zeros(shape = (3, 4))
  374. traced_module = tm.trace_module(net, inp)
  375. Will produce the following ``InternalGraph``::
  376. print(traced_module.graph)
  377. .. code-block:: text
  378. MyModule.Graph (self, x) {
  379. %2: linear = getattr(self, "linear") -> (Linear)
  380. %3: param = getattr(self, "param") -> (Tensor)
  381. %4: add_out = x.__add__(param, )
  382. %5: linear_out = linear(add_out, )
  383. %6: relu_out = nn.relu(linear_out, )
  384. return relu_out
  385. }
  386. """
  387. _exprs = None # type: List[Expr]
  388. _inputs = None # type: List[Node]
  389. _outputs = None # type: List[Node]
  390. _top_graph = None # type: InternalGraph
  391. _total_ids = None # type: List[int]
  392. def __init__(self, name: str, qualname: str):
  393. self._exprs = []
  394. self._inputs = []
  395. self._outputs = []
  396. self._watch_point = []
  397. self._end_point = []
  398. self._namespace = NameSpace(name, qualname)
  399. self._rst = collections.defaultdict(list)
  400. self._name = name
  401. self._qualname = qualname
  402. def _insert(self, expr):
  403. self._exprs.append(expr)
  404. @property
  405. def name(self) -> str:
  406. r"""Get the name of this graph."""
  407. return self._name
  408. @name.setter
  409. def name(self, new_name: str):
  410. r"""Set a new name to this graph."""
  411. mod = self.inputs[0].owner
  412. graph = self.top_graph
  413. assert graph is not None or mod._is_top, "The parent graph cannot be None."
  414. if graph is not None:
  415. assert new_name not in self._namespace.used_names, (
  416. "The name (%s) is already in use. Please try a different one."
  417. % (new_name)
  418. )
  419. new_name = self._namespace.create_unique_name(new_name)
  420. self._name = new_name
  421. @property
  422. def qualname(self) -> str:
  423. r"""Get the `qualname` of this graph. The `qualname` can be used to get the
  424. submodule from the traced Module or Module.
  425. Example:
  426. .. code-block::
  427. import megengine.module as M
  428. import megengine.traced_module as tm
  429. import megengine as mge
  430. class block(M.Module):
  431. def __init__(self):
  432. super().__init__()
  433. self.relu = M.ReLU()
  434. def forward(self, x):
  435. return self.relu(x)
  436. class module(M.Module):
  437. def __init__(self):
  438. super().__init__()
  439. self.block = block()
  440. def forward(self, x):
  441. x = self.block(x)
  442. return x
  443. net = module()
  444. traced_net = tm.trace_module(net, mge.Tensor([0.]))
  445. qualname = traced_net.block.graph.qualname # qualname = "module.block"
  446. qualname = qualname.split(".", 1)[-1] # qualname = "block"
  447. assert qualname in list(map(lambda x: x[0], net.named_modules()))
  448. assert qualname in list(map(lambda x: x[0], traced_net.named_modules()))
  449. """
  450. return self._qualname
  451. @property
  452. def inputs(self) -> List[Node]:
  453. r"""Get the list of input Nodes of this graph.
  454. Returns:
  455. A list of ``Node``.
  456. """
  457. return self._inputs
  458. @property
  459. def outputs(self) -> List[Node]:
  460. r"""Get the list of output Nodes of this graph.
  461. Returns:
  462. A list of ``Node``.
  463. """
  464. return self._outputs
  465. @property
  466. def top_graph(self):
  467. r"""Get the parent graph of this graph.
  468. Returns:
  469. An ``InternalGraph``.
  470. """
  471. if self._top_graph:
  472. return self._top_graph()
  473. return None
  474. def exprs(self, recursive=True):
  475. r"""Get the Exprs that constitute this graph.
  476. Args:
  477. recursive: whether to get the Exprs in the subgraph.
  478. Default: True
  479. Returns:
  480. A ``ExprFilter`` containing all Exprs of this graph.
  481. """
  482. return ExprFilter(_expr_iter(self, recursive))
  483. def nodes(self, recursive=True):
  484. r"""Get the Nodes that constitute this graph.
  485. Args:
  486. recursive: whether to get the Nodes in the subgraph.
  487. Default: True
  488. Returns:
  489. A ``NodeFilter`` containing all Nodes of this graph.
  490. """
  491. return NodeFilter(_node_iter(self, recursive))
  492. def get_function_by_type(self, func: Callable = None, recursive=True):
  493. r"""Filter Exprs by the type of ``CallFunction``.
  494. Args:
  495. func: a built-in function, such as ``F.relu``.
  496. recursive: whether to get the Exprs in the subgraph.
  497. Default: True
  498. Returns:
  499. A :class:`~.TracedModule.ExprFilterCallFunction`.
  500. """
  501. return self.exprs(recursive).call_function(func)
  502. def get_method_by_type(self, method: str = None, recursive=True):
  503. r"""Filter Exprs by the type of ``CallMethod``.
  504. Args:
  505. method: a method string, such as "__add__".
  506. recursive: whether to get the Exprs in the subgraph.
  507. Default: True
  508. Returns:
  509. A :class:`~.TracedModule.ExprFilterCallMethod`.
  510. """
  511. return self.exprs(recursive).call_method(method)
  512. def get_expr_by_id(self, expr_id: List[int] = None, recursive=True):
  513. r"""Filter Exprs by their ``id``.
  514. Args:
  515. expr_id: a list of :class:`int`.
  516. recursive: whether to get the Exprs in the subgraph.
  517. Default: True
  518. Returns:
  519. A :class:`~.TracedModule.ExprFilterExprId`.
  520. """
  521. return self.exprs(recursive).expr_id(expr_id)
  522. def get_module_by_type(self, module_cls: Module, recursive=True):
  523. r"""Filter Nodes by the ``module_type`` of ``ModuleNode``.
  524. Args:
  525. module_cls: a subclass of :class:`~.Module`.
  526. recursive: whether to get the Nodes in the subgraph.
  527. Default: True
  528. Returns:
  529. A :class:`~.TracedModule.NodeFilterType`.
  530. """
  531. assert issubclass(module_cls, Module)
  532. return self.nodes(recursive).type(module_cls)
  533. def get_node_by_id(self, node_id: List[int] = None, recursive=True):
  534. r"""Filter Nodes by their ``id``.
  535. The ``id`` of the ``Node`` can be obtained by the following code
  536. .. code-block::
  537. # node : Node
  538. print("{:i}".format(node))
  539. print(node.__format__("i"))
  540. # graph : InternalGraph
  541. print("{:i}".format(graph))
  542. print(graph.__format__("i"))
  543. Args:
  544. node_id: a list of :class:`int`.
  545. recursive: whether to get the Nodes in the subgraph.
  546. Default: True
  547. Returns:
  548. A :class:`~.TracedModule.NodeFilterNodeId`.
  549. """
  550. return self.nodes(recursive).node_id(node_id)
  551. def get_node_by_name(
  552. self, name: str = None, ignorecase: bool = True, recursive=True
  553. ):
  554. r"""Filter Nodes by their full name.
  555. The full name of the ``Node`` can be obtained by the following code
  556. .. code-block::
  557. # node : Node
  558. print("{:p}".format(node))
  559. print(node.__format__("p"))
  560. # graph : InternalGraph
  561. print("{:p}".format(graph))
  562. print(graph.__format__("p"))
  563. Args:
  564. name: a string in glob syntax that can contain ``?`` and
  565. ``*`` to match a single or arbitrary characters.
  566. ignorecase: whether to ignore case.
  567. Default: True
  568. recursive: whether to get the Nodes in the subgraph.
  569. Default: True
  570. Returns:
  571. A :class:`~.TracedModule.NodeFilterName`.
  572. """
  573. return self.nodes(recursive).name(name, ignorecase)
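  # The getters above return chainable filters; a hedged usage sketch, reusing
  # the ``traced_module`` from the class docstring:
  #
  #     graph = traced_module.graph
  #     relu_exprs = graph.get_function_by_type(F.relu).as_list()
  #     linear_node = graph.get_module_by_type(M.Linear).as_unique()
  #     out_node = graph.get_node_by_name("*relu_out*").as_unique()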
  574. def _add_input(self, i):
  575. self._inputs.append(i)
  576. def _add_output(self, o):
  577. self._outputs.append(o)
  578. def get_dep_exprs(self, nodes: Sequence[Node]) -> List[Expr]:
  579. r"""Get the dependent Exprs of the ``nodes``.
  580. Args:
  581. nodes: a list of :class:`Node`.
  582. Returns:
  583. A list of dependent :class:`Expr`.
  584. """
  585. if not isinstance(nodes, Sequence):
  586. nodes = (nodes,)
  587. ret = list()
  588. queue = list(nodes)
  589. visited_queue = list()
  590. while queue:
  591. node = queue.pop()
  592. visited_queue.append(node)
  593. expr = node.expr
  594. if expr not in ret:
  595. ret.append(expr)
  596. for i in expr.inputs:
  597. if i not in queue and i not in visited_queue:
  598. queue.append(i)
  599. return ret
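  # For example (a sketch), the Exprs that the graph outputs depend on, which
  # ``compile`` below uses to prune dead Exprs:
  #
  #     dep_exprs = graph.get_dep_exprs(graph.outputs)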
  600. def reset_inputs(self, *args, **kwargs):
  601. forma_mnode = self.inputs[0]
  602. moudle = forma_mnode.owner
  603. assert moudle._is_top, "reset_inputs only supports top graph"
  604. inputs, tree_def = tree_flatten(((moudle, *args), kwargs))
  605. def create_node(val: Tensor):
  606. name = self._namespace.create_unique_name("args")
  607. node = Input(
  608. type=TensorNode, name=name, qualname="%s.[%s]" % (self._qualname, name)
  609. ).outputs[0]
  610. node.shape = val.shape
  611. node.dtype = val.dtype
  612. return node
  613. formal_node_inputs = [
  614. forma_mnode,
  615. ]
  616. org_argdef = list(moudle.argdef_graph_map.keys())[0]
  617. for v in inputs[1:]:
  618. assert isinstance(v, RawTensor)
  619. formal_node_inputs.append(create_node(v))
  620. self._inputs[:] = formal_node_inputs
  621. moudle.argdef_graph_map[tree_def] = moudle.argdef_graph_map.pop(org_argdef)
  622. moudle.argdef_outdef_map[tree_def] = moudle.argdef_outdef_map.pop(org_argdef)
  623. return formal_node_inputs[1:]
  624. def add_input_node(
  625. self, shape: Tuple[int], dtype: str = "float32", name: str = "args"
  626. ):
  627. r"""Add an input node to the graph.
  628. The new Node will be appended as the last positional argument.
  629. Args:
  630. shape: the shape of the new input Node.
  631. dtype: the dtype of the new input Node.
  632. Default: float32
  633. name: the name of the new input Node. When the name is used in the graph,
  634. a suffix will be added to it.
  635. """
  636. forma_mnode = self.inputs[0]
  637. moudle = forma_mnode.owner
  638. assert moudle._is_top, "add_input_node only supports top graph"
  639. def create_node(name=None):
  640. node = Input(
  641. type=TensorNode, name=name, qualname="%s.[%s]" % (self._qualname, name)
  642. ).outputs[0]
  643. node.shape = shape
  644. node.dtype = dtype
  645. return node
  646. org_argdef = list(moudle.argdef_graph_map.keys())[0]
  647. args, kwargs = org_argdef.unflatten(self._inputs)
  648. formal_inp_node = create_node(self._namespace.create_unique_name(name))
  649. inputs, tree_def = tree_flatten(
  650. ((*args, formal_inp_node), kwargs),
  651. is_const_leaf=lambda x: not isinstance(x, (TensorNode, ModuleNode)),
  652. )
  653. self._inputs[:] = inputs[:]
  654. moudle.argdef_graph_map[tree_def] = moudle.argdef_graph_map.pop(org_argdef)
  655. moudle.argdef_outdef_map[tree_def] = moudle.argdef_outdef_map.pop(org_argdef)
  656. return formal_inp_node
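  # A hedged sketch of the two input-editing helpers above (the shapes and the
  # ``graph`` variable are assumptions):
  #
  #     new_inp = graph.add_input_node(shape=(3, 4), dtype="float32", name="y")
  #     # or rebuild all positional inputs from concrete tensors:
  #     arg_nodes = graph.reset_inputs(F.zeros(shape=(3, 4)))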
  657. def reset_outputs(self, outputs):
  658. r"""Reset the output Nodes of the graph.
  659. .. note::
  660. This method only supports resetting the output of graphs
  661. that do not have a parent graph.
  662. Args:
  663. outputs: an object whose inner elements are Nodes. Tuples, lists,
  664. dicts, etc. are supported.
  665. For example, the following code
  666. .. code-block::
  667. import megengine.functional as F
  668. import megengine.module as M
  669. import megengine.traced_module as tm
  670. class MyModule(M.Module):
  671. def forward(self, x):
  672. x = x + 1
  673. return x
  674. net = MyModule()
  675. inp = F.zeros(shape = (1, ))
  676. traced_module = tm.trace_module(net, inp)
  677. graph = traced_module.graph
  678. inp_node = graph.inputs[1]
  679. out_node = graph.outputs[0]
  680. graph.reset_outputs((out_node, {"input": inp_node}))
  681. out = traced_module(inp)
  682. Will produce the following ``InternalGraph`` and ``out``::
  683. print(graph)
  684. print(out)
  685. .. code-block:: text
  686. MyModule.Graph (self, x) {
  687. %2: add_out = x.__add__(1, )
  688. return add_out, x
  689. }
  690. (Tensor([1.], device=xpux:0), {'input': Tensor([0.], device=xpux:0)})
  691. """
  692. outputs, out_def = tree_flatten(
  693. outputs, is_leaf=lambda x: isinstance(x, TensorNode),
  694. )
  695. forma_mnode = self.inputs[0]
  696. moudle = forma_mnode.owner
  697. assert moudle._is_top, "reset_outputs only supports top graph"
  698. tree_def = list(moudle.argdef_graph_map.keys())[0]
  699. self._outputs[:] = outputs
  700. moudle.argdef_outdef_map[tree_def] = out_def
  701. def add_output_node(self, node: TensorNode):
  702. r"""Add an output node to the Graph.
  703. The Graph output will become a ``tuple`` after calling ``add_output_node``.
  704. The first element of the ``tuple`` is the original output, and the second
  705. is the ``node``.
  706. For example, the following code
  707. .. code-block::
  708. import megengine.functional as F
  709. import megengine.module as M
  710. import megengine.traced_module as tm
  711. class MyModule(M.Module):
  712. def forward(self, x):
  713. x = x + 1
  714. return x
  715. net = MyModule()
  716. inp = F.zeros(shape = (1, ))
  717. traced_module = tm.trace_module(net, inp)
  718. graph = traced_module.graph
  719. inp_node = graph.inputs[1]
  720. out_node = graph.outputs[0]
  721. graph.add_output_node(inp_node)
  722. graph.add_output_node(out_node)
  723. out = traced_module(inp)
  724. Will produce the following ``InternalGraph`` and ``out``::
  725. print(graph)
  726. print(out)
  727. .. code-block:: text
  728. MyModule.Graph (self, x) {
  729. %2: add_out = x.__add__(1, )
  730. return add_out, x, add_out
  731. }
  732. ((Tensor([1.], device=xpux:0), Tensor([0.], device=xpux:0)), Tensor([1.], device=xpux:0))
  733. """
  734. forma_mnode = self.inputs[0]
  735. moudle = forma_mnode.owner
  736. assert moudle._is_top, "add_output_node only supports top graph"
  737. tree_def = list(moudle.argdef_graph_map.keys())[0]
  738. org_out_def = moudle.argdef_outdef_map[tree_def]
  739. org_outs = org_out_def.unflatten(self._outputs)
  740. outputs, out_def = tree_flatten(
  741. (org_outs, node), is_leaf=lambda x: isinstance(x, TensorNode),
  742. )
  743. self._outputs[:] = outputs
  744. moudle.argdef_outdef_map[tree_def] = out_def
  745. def insert_exprs(self, expr: Optional[Expr] = None):
  746. r"""Initialize the trace mode and insertion position.
  747. When used within a 'with' statement, this will temporarily set the trace mode and
  748. then restore normal mode when the with statement exits::
  749. with graph.insert_exprs(e): # set the trace mode
  750. ... # trace function or module
  751. ... # insert exprs into graph and restore normal mode
  752. Args:
  753. expr: the ``expr`` after which to insert. If None, the insertion position will be
  754. automatically set based on the input node.
  755. Returns:
  756. A context manager that will initialize trace mode on ``__enter__`` and
  757. restore normal mode on ``__exit__``.
  758. """
  759. if expr is not None:
  760. assert expr.top_graph == self, "Expr to insert after is not in graph."
  761. return _InsertExprs(self, expr)
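  # A minimal sketch of inserting newly traced Exprs (the node names are
  # assumptions; ``F.neg`` stands in for any traceable function):
  #
  #     relu_expr = graph.get_function_by_type(F.relu).as_unique()
  #     with graph.insert_exprs(relu_expr):
  #         new_out = F.neg(relu_expr.outputs[0])
  #     graph.replace_node({graph.outputs[0]: new_out})
  #     graph.compile()  # drop Exprs made unreachable by the replacement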
  762. def replace_node(self, repl_dict: Dict[Node, Node]):
  763. r"""Replace the Nodes in the graph.
  764. Args:
  765. repl_dict: the map {old_Node: new_Node} that specifies how to replace the Nodes.
  766. """
  767. while repl_dict:
  768. node, repl_node = repl_dict.popitem()
  769. assert type(node) == type(
  770. repl_node
  771. ), "The type of {}({}) and {}({}) are not the same".format(
  772. node, type(node).__name__, repl_node, type(repl_node).__name__
  773. )
  774. # check graph inputs and outputs
  775. for i, n in enumerate(self.outputs):
  776. if n is node:
  777. self.outputs[i] = repl_node
  778. # update users of node and repl_node
  779. # update inputs of expr in node.users
  780. graph = repl_node.top_graph
  781. assert graph is not None
  782. assert graph is self
  783. index = -1
  784. if not isinstance(repl_node.expr, Input):
  785. index = graph._exprs.index(repl_node.expr)
  786. dep_exprs = self.get_dep_exprs(repl_node)
  787. i = 0
  788. while i < len(node.users):
  789. n = node.users[i]
  790. if n in graph._exprs and index >= graph._exprs.index(n):
  791. i += 1
  792. continue
  793. if n in dep_exprs:
  794. logger.info("Found a loop: ignoring this replacement once")
  795. logger.info("node: %s" % node.__repr__())
  796. logger.info("expr: %s" % n.__repr__())
  797. i += 1
  798. continue
  799. repl_node.users.append(n)
  800. node.users.pop(i)
  801. idx = n.inputs.index(node)
  802. n.inputs[idx] = repl_node
  803. def _merge_getattr_expr(self):
  804. getattr_nodes_map = dict()
  805. for expr in self._exprs:
  806. if not isinstance(expr, GetAttr):
  807. continue
  808. attr_name = get_suffix_name(self.qualname, expr.outputs[0].qualname)
  809. assert attr_name, '"{}" is not a prefix of "{}"'.format(
  810. self.qualname, expr.outputs[0].qualname
  811. )
  812. if attr_name in getattr_nodes_map:
  813. base_node = getattr_nodes_map[attr_name]
  814. repl_node = expr.outputs[0]
  815. for expr in repl_node.users:
  816. base_node.users.append(expr)
  817. idx = expr.inputs.index(repl_node)
  818. expr.inputs[idx] = base_node
  819. repl_node.users = []
  820. else:
  821. if attr_name != expr.name:
  822. expr.name = attr_name
  823. expr.inputs[0].users.remove(expr)
  824. self.inputs[0].users.append(expr)
  825. expr.inputs[0] = self.inputs[0]
  826. getattr_nodes_map[attr_name] = expr.outputs[0]
  827. def compile(self):
  828. r"""Merge duplicated ``GetAttr`` Exprs and delete unused Exprs."""
  829. self._merge_getattr_expr()
  830. dep_exprs = self.get_dep_exprs(self.outputs)
  831. i = 0
  832. while i < len(self._exprs):
  833. expr = self._exprs[i]
  834. if expr in dep_exprs or expr._disable_remove:
  835. i += 1
  836. continue
  837. for n in expr.inputs:
  838. n.users.remove(expr)
  839. self._exprs.remove(expr)
  840. def _reset_ids(self):
  841. for total_expr_id, expr in enumerate(self.exprs()):
  842. expr._id = total_expr_id
  843. for total_node_id, node in enumerate(self.nodes()):
  844. node._id = total_node_id
  845. self._total_ids = (total_node_id + 1, total_expr_id + 1)
  846. def interpret(self, *inputs):
  847. node2value = {}
  848. end_nodes_set = set(self._end_point)
  849. endnode2value = {}
  850. def get_all_endnode_val(n, v):
  851. if n in end_nodes_set:
  852. endnode2value[n] = v
  853. end_nodes_set.remove(n)
  854. return not end_nodes_set
  855. return False
  856. ref_count = lambda n: len(n.users) + (1 if n in self._outputs else 0)
  857. for n, v in zip(self._inputs, inputs):
  858. if ref_count(n) > 0:
  859. node2value[n] = [v, ref_count(n)]
  860. if n in self._watch_point:
  861. self._rst[n].append(v)
  862. if n in self._end_point and get_all_endnode_val(n, v):
  863. return list(endnode2value[i] for i in self._end_point)
  864. for expr in self._exprs:
  865. values = expr.interpret(*list(node2value[i][0] for i in expr.inputs))
  866. for n in expr.inputs:
  867. node2value[n][1] -= 1
  868. if node2value[n][1] == 0:
  869. node2value.pop(n)
  870. if values is not None:
  871. for n, v in zip(expr.outputs, values):
  872. if ref_count(n) > 0:
  873. node2value[n] = [v, ref_count(n)]
  874. if n in self._watch_point:
  875. self._rst[n] = v
  876. if self._end_point and get_all_endnode_val(n, v):
  877. return list(endnode2value[i] for i in self._end_point)
  878. return list(node2value[i][0] for i in self._outputs)
  879. def eval(self, *inputs: Tuple[Tensor]):
  880. r"""Call this method to execute the graph.
  881. Args:
  882. inputs: the tensors corresponding to the ``graph.inputs[1:]``.
  883. """
  884. assert len(inputs) == len(self._inputs) - 1
  885. inp = [self._inputs[0].owner] + list(inputs)
  886. return self.interpret(*inp)
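  # Hedged example, reusing the ``traced_module`` from the class docstring:
  #
  #     (out,) = traced_module.graph.eval(F.zeros(shape=(3, 4)))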
  887. def __repr__(self):
  888. return self.__format__()
  889. def __format__(self, format_spec: str = "") -> str:
  890. saved_format_spec = Node._set_format_spec(format_spec)
  891. name = ""
  892. if self._name:
  893. name = "%s.Graph" % self._name
  894. res = "{} ({}) {{\n\t{}\n\treturn {}\n}}".format(
  895. name,
  896. ", ".join(str(i) for i in self._inputs),
  897. "\n\t".join("{}".format(str(i)) for i in self._exprs),
  898. ", ".join(str(i) for i in self._outputs),
  899. )
  900. Node._set_format_spec(saved_format_spec)
  901. return res
  902. def __getstate__(self):
  903. state = self.__dict__.copy()
  904. if "_top_graph" in state:
  905. state.pop("_top_graph")
  906. return state
  907. def __setstate__(self, state):
  908. old_version = False
  909. if "_module_name" in state:
  910. old_version = True
  911. state["_qualname"] = state.pop("_module_name")
  912. prefix_name = state.pop("_prefix_name")
  913. if prefix_name:
  914. state["_name"] = "{}_{}".format(prefix_name, state["_name"])
  915. self.__dict__.update(state)
  916. if old_version:
  917. for n in self.nodes(False):
  918. qualname = self._qualname
  919. if isinstance(n.expr, CallMethod) and isinstance(
  920. n.expr.inputs[0], ModuleNode
  921. ):
  922. n._qualname = n.expr.inputs[0]._qualname + ".[out]"
  923. continue
  924. if n._qualname:
  925. qualname = "{}.{}".format(qualname, n._qualname)
  926. n._qualname = qualname
  927. def _get_meth_name(obj, func):
  928. tp = obj if isinstance(obj, type) else type(obj)
  929. for cls in tp.mro():
  930. for k, v in cls.__dict__.items():
  931. if v == func:
  932. return k
  933. return None
  934. def _wrapped_function(orig_func):
  935. @functools.wraps(orig_func)
  936. def wrapped_fn(*args, **kwargs):
  937. method_func = wrapped_fn
  938. if "method_func" in kwargs:
  939. method_func = kwargs.pop("method_func")
  940. if is_tracing_module():
  941. unset_module_tracing()
  942. inputs, tree_def = tree_flatten((args, kwargs))
  943. for i in inputs:
  944. if not NodeMixin.get(i, None):
  945. if isinstance(i, (RawTensor, NodeMixin)):
  946. NodeMixin.wrap_safe(i, Constant.make(i))
  947. meth_name, arg_type = None, None
  948. if args:
  949. meth_name = _get_meth_name(args[0], method_func)
  950. arg_type = args[0] if isinstance(args[0], type) else type(args[0])
  951. if meth_name and arg_type and issubclass(arg_type, RawTensor):
  952. self = inputs[0]
  953. if meth_name == "__new__":
  954. if all([not isinstance(i, RawTensor) for i in inputs]):
  955. # only trace Tensor.__new__() when there are tensors in args
  956. set_module_tracing()
  957. return orig_func(*args, **kwargs)
  958. if isinstance(args[1], RawTensor):
  959. node = NodeMixin.get(inputs[1])
  960. inputs[1] = copy.copy(inputs[1])
  961. # copy inputs[1] to avoid tensor and Tensor(tensor) sharing the same m_tensor,
  962. # which would cause them to have the same _NodeMixin__node during tracing.
  963. NodeMixin.wrap_safe(inputs[1], node)
  964. args, kwargs = tree_def.unflatten(inputs)
  965. call_node = CallMethod.make(self, meth_name)
  966. else:
  967. call_node = CallMethod.make(NodeMixin.get(self), meth_name)
  968. call_node.add_inputs(inputs[1:])
  969. else:
  970. call_node = CallFunction.make(orig_func)
  971. call_node.add_inputs(inputs)
  972. call_node.arg_def = tree_def
  973. rst = orig_func(*args, **kwargs)
  974. if meth_name == "__setitem__":
  975. rst = self
  976. if rst is not None:
  977. outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
  978. call_node.out_def = out_def
  979. else:
  980. outputs = None
  981. call_node.add_outputs(outputs)
  982. set_module_tracing()
  983. return rst
  984. return orig_func(*args, **kwargs)
  985. return wrapped_fn
  986. class TracedModuleBuilder(NodeMixin):
  987. _mod = None # type: Module
  988. _body = None # type: InternalGraph
  989. _is_builtin = None # type: bool
  990. _argdef_graph_map = None # type: Dict[Treedef, "InternalGraph"]
  991. _argdef_outdef_map = None # type: Dict[Treedef, Treedef]
  992. nodes = None
  993. __builder_attributes__ = [
  994. "_mod",
  995. "_body",
  996. "_NodeMixin__node",
  997. "_is_builtin",
  998. "build",
  999. "_record_wrapped_nodes",
  1000. "_argdef_graph_map",
  1001. "_argdef_outdef_map",
  1002. "nodes",
  1003. "__class__",
  1004. "__dict__",
  1005. ]
  1006. def __init__(self, mod, is_top_module=False):
  1007. super(TracedModuleBuilder, self).__init__()
  1008. assert isinstance(mod, Module)
  1009. self._mod = mod
  1010. self._body = None
  1011. self._is_top = is_top_module
  1012. self._is_builtin = (
  1013. True
  1014. if isinstance(mod, (Observer, _FakeQuantize))
  1015. else module_tracer.is_builtin(mod)
  1016. )
  1017. if isinstance(self._mod, QATModule):
  1018. unset_module_tracing()
  1019. self._check_qat_module(self._mod)
  1020. set_module_tracing()
  1021. self._argdef_graph_map = {}
  1022. self._argdef_outdef_map = {}
  1023. self.nodes = set()
  1024. # The builder will be passed to self._mod.forward as the 'self' argument. If 'forward' calls methods of its base classes via super().xxx, the trace procedure will throw an exception, because the builder does not inherit from self._mod.__bases__.
  1025. # So we modify self.__class__ to let the builder inherit from both TracedModuleBuilder and mod.__class__.
  1026. self.__class__ = type(
  1027. "TracedModuleBuilder",
  1028. (TracedModuleBuilder, mod.__class__),
  1029. dict(TracedModuleBuilder.__dict__),
  1030. )
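  # The re-basing trick above in isolation (a toy sketch, independent of this
  # module):
  #
  #     class Base:
  #         def forward(self):
  #             return "base"
  #     class Proxy:
  #         pass
  #     p = Proxy()
  #     p.__class__ = type("Proxy", (Proxy, Base), dict(Proxy.__dict__))
  #     assert p.forward() == "base"  # Base methods now resolve on the proxy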
  1031. def _check_qat_module(self, qat_module):
  1032. def isbuiltin(m):
  1033. return m is None or module_tracer.is_builtin(m)
  1034. if qat_module.with_act:
  1035. act_observer = qat_module.act_observer
  1036. act_fakequant = qat_module.act_fake_quant
  1037. if not isbuiltin(act_observer) or not isbuiltin(act_fakequant):
  1038. qparams = (
  1039. act_observer.get_qparams()
  1040. if hasattr(act_observer, "get_qparams")
  1041. else act_fakequant.get_qparams()
  1042. )
  1043. dtype = (
  1044. act_observer.dtype
  1045. if hasattr(act_observer, "dtype")
  1046. else act_fakequant.dtype
  1047. )
  1048. qat_module.act_observer = None
  1049. qat_module.act_fake_quant = TM_FakeQuant(dtype)
  1050. qat_module.act_fake_quant.set_qparams(qparams)
  1051. if qat_module.with_weight:
  1052. weight_observer = qat_module.weight_observer
  1053. weight_fakequant = qat_module.weight_fake_quant
  1054. if not isbuiltin(weight_observer) or not isbuiltin(weight_fakequant):
  1055. qparams = (
  1056. weight_observer.get_qparams()
  1057. if hasattr(weight_observer, "get_qparams")
  1058. else weight_fakequant.get_qparams()
  1059. )
  1060. dtype = (
  1061. weight_observer.dtype
  1062. if hasattr(weight_observer, "dtype")
  1063. else weight_fakequant.dtype
  1064. )
  1065. qat_module.weight_observer = None
  1066. qat_module.weight_fake_quant = TM_FakeQuant(dtype)
  1067. qat_module.weight_fake_quant.set_qparams(qparams)
  1068. def build(self):
  1069. if self._is_builtin or isinstance(self._mod, TracedModule):
  1070. if module_tracer.is_builtin(self._mod) or isinstance(
  1071. self._mod, TracedModule
  1072. ):
  1073. mod_type = type(self._mod)
  1074. else:
  1075. assert isinstance(self._mod, (Observer, _FakeQuantize))
  1076. mod_type = (
  1077. Observer if isinstance(self._mod, Observer) else _FakeQuantize
  1078. )
  1079. for node in self.nodes:
  1080. node.module_type = mod_type
  1081. return self._mod
  1082. else:
  1083. is_qat = isinstance(self._mod, QATModule)
  1084. traced_module = TracedModule(
  1085. self._is_top, self._argdef_graph_map, self._argdef_outdef_map, is_qat
  1086. )
  1087. for _, g in self._argdef_graph_map.items():
  1088. g.compile()
  1089. if self._is_top:
  1090. g._total_ids = (Node._get_next_id(), Expr._get_next_id())
  1091. for k, v in self.__dict__.items():
  1092. if k not in TracedModuleBuilder.__builder_attributes__:
  1093. if isinstance(v, TracedModuleBuilder):
  1094. v = v.build()
  1095. setattr(traced_module, k, v)
  1096. elif isinstance(v, RawTensor):
  1097. setattr(traced_module, k, v)
  1098. if isinstance(self._mod, QATModule):
  1099. unset_module_tracing()
  1100. traced_module.with_act = self._mod.with_act
  1101. traced_module.with_weight = self._mod.with_weight
  1102. if not hasattr(traced_module, "act_fake_quant"):
  1103. traced_module.act_fake_quant = None
  1104. if not hasattr(traced_module, "act_observer"):
  1105. traced_module.act_observer = None
  1106. if not hasattr(traced_module, "weight_fake_quant"):
  1107. traced_module.weight_fake_quant = None
  1108. if not hasattr(traced_module, "weight_observer"):
  1109. traced_module.weight_observer = None
  1110. set_module_tracing()
  1111. return traced_module
  1112. def _record_wrapped_nodes(self, node):
  1113. self.nodes.add(node)
  1114. def __call__(self, *args, **kwargs):
  1115. assert isinstance(self._mod, Module)
  1116. # prepare args and kwargs for inner graph
  1117. if "method_func" in kwargs:
  1118. kwargs.pop("method_func")
  1119. def mark_constant(x):
  1120. node = NodeMixin.get(x, None)
  1121. if node is None: # capture as constant
  1122. NodeMixin.wrap(x, lambda: Constant.make(x))
  1123. inputs, tree_def = tree_flatten(((self, *args), kwargs))
  1124. for i in inputs:
  1125. mark_constant(i)
  1126. callnode = CallMethod.make(NodeMixin.get(self))
  1127. callnode.add_inputs(inputs[1:])
  1128. callnode.arg_def = tree_def
  1129. if (
  1130. self._is_builtin
  1131. or tree_def in self._argdef_graph_map
  1132. or isinstance(self._mod, TracedModule)
  1133. ):
  1134. unset_module_tracing()
  1135. rst = self._mod(*args, **kwargs)
  1136. outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
  1137. set_module_tracing()
  1138. if self._is_builtin:
  1139. self._body = None
  1140. elif tree_def in self._argdef_graph_map:
  1141. self._body = self._argdef_graph_map[tree_def]
  1142. else:
  1143. self._mod._is_top = False
  1144. self._body = self._mod.argdef_graph_map[tree_def]
  1145. module_qualname = NodeMixin.get(self).qualname
  1146. if module_qualname != self._body.qualname:
  1147. src_name, dst_name = self._body.qualname, module_qualname
  1148. def replace_qualname(g):
  1149. attr_name = get_suffix_name(src_name, g.qualname)
  1150. if attr_name is not None:
  1151. g._qualname = (
  1152. ("%s.%s" % (dst_name, attr_name))
  1153. if attr_name
  1154. else dst_name
  1155. )
  1156. assert get_suffix_name(dst_name, g.qualname) is not None
  1157. for mod in self._mod.modules():
  1158. if not hasattr(mod, "argdef_graph_map"):
  1159. continue
  1160. for g in mod.argdef_graph_map.values():
  1161. replace_qualname(g)
  1162. for n in g.nodes(False):
  1163. replace_qualname(n)
  1164. else:
  1165. self_node = None
  1166. orig_self = NodeMixin.get(self)
  1167. parent_graph = active_module_tracer().current_scope()
  1168. module_qualname = orig_self._qualname
  1169. self._body = InternalGraph(
  1170. name=parent_graph._namespace.create_unique_name(module_qualname),
  1171. qualname=module_qualname,
  1172. )
  1173. active_module_tracer().push_scope(self._body)
  1174. # rebind self to new input node
  1175. if self_node:
  1176. NodeMixin.wrap_safe(self, self_node)
  1177. active_module_tracer().current_scope()._add_input(self_node)
  1178. else:
  1179. NodeMixin.wrap_safe(
  1180. self,
  1181. self_node
  1182. if self_node
  1183. else Input.make(
  1184. name="self",
  1185. qualname=module_qualname,
  1186. type=NodeMixin.get_wrapped_type(self),
  1187. ),
  1188. )
  1189. origin_inp_node = [NodeMixin.get(i, None) for i in inputs[1:]]
  1190. # prepare args and kwargs for inner graph
  1191. index_args, index_kwargs = tree_def.unflatten(
  1192. [
  1193. ArgsIndex(0),
  1194. *list(ArgsIndex(i + 1) for i in range(len(origin_inp_node))),
  1195. ]
  1196. )
  1197. key2idx = getcallargs(type(self._mod).forward, *index_args, **index_kwargs)
  1198. idx2key = {}
  1199. for k, v in key2idx.items():
  1200. if isinstance(v, ArgsIndex):
  1201. idx2key[v.index] = k
  1202. else:
  1203. flatten_argidx, _ = tree_flatten(v)
  1204. for _i, v in enumerate(flatten_argidx):
  1205. if isinstance(v, ArgsIndex):
  1206. idx2key[v.index] = k + "_%d" % _i
  1207. def wrap(x, name):
  1208. if isinstance(x, (RawTensor, NodeMixin)):
  1209. NodeMixin.wrap(
  1210. x,
  1211. lambda: Input.make(
  1212. type=NodeMixin.get_wrapped_type(x),
  1213. name=name,
  1214. qualname="%s.[%s]" % (module_qualname, name),
  1215. ),
  1216. )
  1217. return x
  1218. args = [self]
  1219. for i, v in enumerate(inputs[1:]):
  1220. args.append(wrap(v, idx2key[i + 1]))
  1221. args, kwargs = tree_def.unflatten(args)
  1222. active_module_tracer().patcher.auto_patch(
  1223. getattr(getattr(self._mod, "forward", self._mod), "__globals__", {})
  1224. )
  1225. rst = type(self._mod).forward(*args, **kwargs)
  1226. if _convert_node_flag():
  1227. rst = _node_to_tensor(rst)[0][0]
  1228. outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
  1229. for i in (
  1230. outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,)
  1231. ):
  1232. mark_constant(i)
  1233. active_module_tracer().current_scope()._add_output(NodeMixin.get(i))
  1234. NodeMixin.wrap_safe(self, orig_self)
  1235. for arg, node in zip(inputs[1:], origin_inp_node):
  1236. if node:
  1237. NodeMixin.wrap_safe(arg, node)
  1238. active_module_tracer().pop_scope()
  1239. # rebind output to outer graph
  1240. callnode.out_def = out_def
  1241. callnode.add_outputs(outputs)
  1242. self._argdef_graph_map[callnode.arg_def] = self._body
  1243. self._argdef_outdef_map[callnode.arg_def] = out_def
  1244. return rst
  1245. def __setattr__(self, name, value):
  1246. object.__setattr__(self, name, value)
  1247. def __repr__(self):
  1248. return repr(self._mod)
  1249. def __getattr__(self, name):
  1250. if name not in self._mod.__dict__:
  1251. attr = getattr(type(self._mod), name).__get__(self, type(self))
  1252. else:
  1253. attr = getattr(self._mod, name)
  1254. if (
  1255. isinstance(attr, FunctionType)
  1256. and id(attr) in active_module_tracer().patcher.patched_fn_ids
  1257. ):
  1258. return active_module_tracer().patcher.wrap_fn(attr)
  1259. if isinstance(attr, (List, Dict)):
  1260. unset_module_tracing()
  1261. has_module, m_container = replace_container_with_module_container(attr)
  1262. if m_container:
  1263. attr = m_container
  1264. if has_module and not m_container:
  1265. raise ValueError(
  1266. "Can not trace the module that uses the same container to store"
  1267. " Module and Non-Module objects."
  1268. )
  1269. set_module_tracing()
  1270. if isinstance(attr, Module):
  1271. attr = TracedModuleBuilder(attr)
  1272. if isinstance(attr, (Module, RawTensor)):
  1273. setattr(self, name, attr)
  1274. NodeMixin.wrap(
  1275. attr,
  1276. lambda: GetAttr.make(
  1277. NodeMixin.get(self),
  1278. type=NodeMixin.get_wrapped_type(attr),
  1279. attr_name=name,
  1280. name="",
  1281. ),
  1282. )
  1283. return attr
  1284. def __getattribute__(self, name):
  1285. if name in TracedModuleBuilder.__builder_attributes__:
  1286. return object.__getattribute__(self, name)
  1287. else:
  1288. wrapped = object.__getattribute__(self, name)
  1289. class_members = dict(inspect.getmembers(self.__class__))
  1290. if name in self._mod.__dict__:
  1291. mod_attr = getattr(self._mod, name)
  1292. if name in class_members:
  1293. if (
  1294. not isinstance(wrapped, TracedModuleBuilder)
  1295. and wrapped is not mod_attr
  1296. ):
  1297. wrapped = self.__getattr__(name)
  1298. if isinstance(wrapped, TracedModuleBuilder):
  1299. if not isinstance(mod_attr, (List, Dict)):
  1300. assert mod_attr is wrapped._mod
  1301. else:
  1302. assert mod_attr is wrapped
  1303. if isinstance(wrapped, (NodeMixin, RawTensor)):
  1304. NodeMixin.wrap(
  1305. wrapped,
  1306. lambda: GetAttr.make(
  1307. NodeMixin.get(self),
  1308. type=NodeMixin.get_wrapped_type(wrapped),
  1309. attr_name=name,
  1310. name="",
  1311. ),
  1312. )
  1313. return wrapped


class _expr_iter:
    def __init__(self, graph: InternalGraph, recursive: bool = True):
        self.graph = graph
        self.recursive = recursive

    def __iter__(self):
        for inp_node in self.graph.inputs:
            yield inp_node.expr
        for expr in self.graph._exprs:
            if isinstance(expr, CallMethod) and isinstance(expr.inputs[0], ModuleNode):
                yield expr
                if self.recursive and expr.graph is not None:
                    yield from expr.graph.exprs(self.recursive)
            else:
                yield expr


class _node_iter:
    def __init__(self, graph: InternalGraph, recursive: bool = True) -> None:
        nodes = []
        node_ids = set()
        for expr in graph.exprs(recursive):
            for n in expr.inputs + expr.outputs:
                if id(n) in node_ids:
                    continue
                nodes.append(n)
                node_ids.add(id(n))
        self.nodes = list(sorted(nodes, key=lambda x: x._id))

    def __iter__(self):
        for node in self.nodes:
            yield node
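

# Illustrative sketch (not part of the library): ``_expr_iter`` and ``_node_iter``
# are normally reached through ``InternalGraph.exprs()`` / ``InternalGraph.nodes()``,
# which wrap them in the filter classes defined below. ``_example_dump_graph`` is a
# hypothetical helper showing how the two iterators are consumed.
def _example_dump_graph(graph: InternalGraph):
    # Walk every Expr, descending into sub-graphs (recursive=True by default).
    for expr in ExprFilter(_expr_iter(graph)):
        print(expr)
    # Nodes are deduplicated by id() and sorted by creation order (``_id``).
    for node in NodeFilter(_node_iter(graph)):
        print(node)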


class BaseFilter:
    r"""``BaseFilter`` exposes some methods for converting ``_node_iter``/``_expr_iter`` to ``list``, ``dict``, etc."""

    def __init__(self, iter: Iterable):
        self._iter = iter

    def __iter__(self):
        return iter(self._iter)

    def as_list(self):
        r"""Consume this iterator and return its content as a list.

        Returns:
            A list of ``Node`` or ``Expr``.
        """
        return list(self)

    def as_dict(self):
        r"""Construct an ordered dict that maps from ``id`` to the objects in this iterator.

        Returns:
            An :class:`OrderedDict`.
        """
        return collections.OrderedDict((i._id, i) for i in self)

    def as_unique(self):
        r"""Assert that this iterator yields exactly one ``Node`` or ``Expr`` and return it.

        Returns:
            A ``Node`` or ``Expr``.
        """
        rst = self.as_list()
        assert len(rst) == 1, "{} elements found".format(len(rst))
        (elem,) = self
        return elem

    def as_count(self):
        r"""Consume this iterator and return the number of elements."""
        return sum(1 for _ in self)
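

# Usage sketch (illustrative only): the ``as_*`` terminators above consume a
# filter chain. ``_example_filter_terminators`` is a hypothetical helper, not
# public API; it assumes ``InternalGraph.exprs()`` returns an ``ExprFilter``,
# as the cross-references in this file suggest.
def _example_filter_terminators(graph: InternalGraph):
    n = graph.exprs().as_count()  # number of Exprs; consumes one iteration
    id2expr = graph.exprs().as_dict()  # OrderedDict mapping Expr._id -> Expr
    first = graph.exprs().as_list()[0]  # plain list materialization
    return n, id2expr, first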


class ExprFilter(BaseFilter):
    """Filter on Expr iterator.

    This class is an iterator of :class:`.Expr` objects and multiple
    filtering conditions and mappers can be chained.
    """

    def call_function(self, func):
        r"""Filter by specific ``CallFunction.func``.
        See :meth:`~.InternalGraph.get_function_by_type` for details.
        """
        return ExprFilterCallFunction(self, func)

    def call_method(self, method):
        r"""Filter by specific ``CallMethod.method``.
        See :meth:`~.InternalGraph.get_method_by_type` for details.
        """
        return ExprFilterCallMethod(self, method)

    def expr_id(self, expr_id: List[int]):
        r"""Filter Exprs by their ``id``.
        See :meth:`~.InternalGraph.get_expr_by_id` for details.
        """
        return ExprFilterExprId(self, expr_id)


class NodeFilter(BaseFilter):
    """Filter on Node iterator.

    This class is an iterator of :class:`.Node` objects and multiple
    filtering conditions and mappers can be chained.
    """

    def type(self, owner_type):
        r"""Filter by specific Module type.
        See :meth:`~.InternalGraph.get_module_by_type` for details.
        """
        return NodeFilterType(self, owner_type)

    def node_id(self, node_id: List[int]):
        r"""Filter Nodes by their ``id``.
        See :meth:`~.InternalGraph.get_node_by_id` for details.
        """
        return NodeFilterNodeId(self, node_id)

    def name(self, name: str, ignorecase: bool = True):
        r"""Filter Nodes by their full name.
        See :meth:`~.InternalGraph.get_node_by_name` for details.
        """
        return NodeFilterName(self, name, ignorecase)
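

# Chaining sketch (illustrative; ``_example_chain_filters`` is hypothetical and
# assumes ``InternalGraph.exprs()``/``nodes()`` return the filters above).
# ``F.relu`` and ``M.Conv2d`` are just representative targets.
def _example_chain_filters(traced_module):
    import megengine.functional as F
    import megengine.module as M

    # Every CallFunction Expr whose func is F.relu, across sub-graphs.
    relus = traced_module.graph.exprs().call_function(F.relu).as_list()
    # Every ModuleNode owned by a Conv2d instance.
    convs = traced_module.graph.nodes().type(M.Conv2d).as_list()
    # Name filtering uses glob patterns (see NodeFilterName below).
    named = traced_module.graph.nodes().name("*conv*").as_list()
    return relus, convs, named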


class NodeFilterType(NodeFilter):
    """See :meth:`~.InternalGraph.get_module_by_type`"""

    def __init__(self, expr_iter, owner_type):
        super().__init__(expr_iter)
        self.owner_type = owner_type

    def __iter__(self):
        for node in self._iter:
            if not isinstance(node, ModuleNode):
                continue
            if not hasattr(node, "owner"):
                continue
            if isinstance(node.owner, self.owner_type):
                yield node


class NodeFilterNodeId(NodeFilter):
    """See :meth:`~.InternalGraph.get_node_by_id`"""

    def __init__(self, expr_iter, node_id: List[int]):
        super().__init__(expr_iter)
        if not isinstance(node_id, Sequence):
            node_id = [node_id]
        self.node_id = node_id

    def __iter__(self):
        for node in self._iter:
            if node._id in self.node_id:
                yield node


class NodeFilterName(NodeFilter):
    """See :meth:`~.InternalGraph.get_node_by_name`"""

    _re = None

    def __init__(self, node_iter, pattern, ignorecase):
        super().__init__(node_iter)
        self.pattern = pattern
        self._re = self.make_re(pattern, ignorecase)

    @classmethod
    def make_re(cls, pattern, ignorecase=True):
        assert isinstance(pattern, str), "bad pattern: {!r}".format(pattern)
        assert isinstance(ignorecase, bool)
        flags = 0
        if ignorecase:
            flags |= re.IGNORECASE
        return re.compile(fnmatch.translate(pattern), flags=flags)

    def __iter__(self):
        for i in self._iter:
            graph = i.top_graph
            name = "{}_{}".format(graph._name, i._name)
            if self.pattern == name or self._re.match(name):
                yield i
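

# Matching semantics of NodeFilterName (a sketch, not library code): the glob
# pattern is translated by ``fnmatch.translate`` and matched against
# ``"<graph name>_<node name>"``, case-insensitively by default.
def _example_name_matching():
    regex = NodeFilterName.make_re("*conv*")
    assert regex.match("top_conv2d_out")  # the glob allows any surrounding text
    assert not regex.match("top_pool_out")  # no "conv" in the full name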


class ExprFilterCallFunction(ExprFilter):
    """See :meth:`~.InternalGraph.get_function_by_type`"""

    def __init__(self, expr_iter, func: Callable = None):
        super().__init__(expr_iter)
        self.func = func

    def __iter__(self):
        for expr in self._iter:
            if not isinstance(expr, CallFunction):
                continue
            if self.func is None or expr.func == self.func:
                yield expr


class ExprFilterCallMethod(ExprFilter):
    """See :meth:`~.InternalGraph.get_method_by_type`"""

    def __init__(self, expr_iter, method: str = None):
        super().__init__(expr_iter)
        self.method = method

    def __iter__(self):
        for expr in self._iter:
            if not isinstance(expr, CallMethod):
                continue
            if self.method is None or expr.method == self.method:
                yield expr


class ExprFilterExprId(ExprFilter):
    """See :meth:`~.InternalGraph.get_expr_by_id`"""

    def __init__(self, expr_iter, expr_id: List[int]):
        super().__init__(expr_iter)
        if not isinstance(expr_id, Sequence):
            expr_id = [expr_id]
        self.expr_id = expr_id

    def __iter__(self):
        for expr in self._iter:
            if expr._id in self.expr_id:
                yield expr
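

# Extension sketch (hypothetical, not part of the public API): a custom Expr
# filter only needs to subclass ExprFilter and implement __iter__, mirroring
# the built-in filters above.
class _ExampleExprFilterConstant(ExprFilter):
    """Yield only Constant Exprs; illustrative only."""

    def __iter__(self):
        for expr in self._iter:
            if isinstance(expr, Constant):
                yield expr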


class TracedModule(Module):
    r"""``TracedModule`` is the Module created by tracing a normal Module.

    It owns an argdef-to-graph (:class:`InternalGraph`) map. The forward method of
    ``TracedModule`` looks up a graph in ``argdef_graph_map`` according to the argdef
    of the input ``args/kwargs`` and interprets it.

    .. note::
        ``TracedModule`` can only be created by :func:`~.trace_module`. See :func:`~.trace_module`
        for more details.
    """

    # m_node = None  # type: ModuleNode
    argdef_graph_map = None
    argdef_outdef_map = None

    def __init__(self, is_top, argdef_graph_map, argdef_outdef_map, is_qat=False):
        super(TracedModule, self).__init__()
        self.argdef_graph_map = argdef_graph_map
        self.argdef_outdef_map = argdef_outdef_map
        self._is_top = is_top
        self.watch_points = []
        self.watch_node_value = {}
        self.end_points = []
        self.is_qat = is_qat

    def forward(self, *args, **kwargs):
        inputs, treedef = tree_flatten(((self, *args), kwargs))
        assert treedef in self.argdef_graph_map
        inputs = filter(
            lambda i: isinstance(i, (Module, TracedModuleBuilder, RawTensor)), inputs
        )  # allow TracedModuleBuilder for retrace.
        outputs = self.argdef_graph_map[treedef].interpret(*inputs)
        if self.watch_points:
            self.watch_node_value = {}
            for n in self.watch_points:
                self.watch_node_value[n] = n.top_graph._rst.pop(n)
        if self.end_points:
            return outputs
        out_def = self.argdef_outdef_map[treedef]
        outputs = out_def.unflatten(outputs)
        return outputs
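
    # Descriptive note (added commentary): ``forward`` dispatches on the
    # *structure* of the call rather than on tensor values. ``tree_flatten``
    # reduces ``((self, *args), kwargs)`` to a ``treedef`` key; each distinct
    # treedef maps to its own traced InternalGraph in ``argdef_graph_map`` and
    # to the matching output layout in ``argdef_outdef_map``.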

    def set_watch_points(self, nodes):
        r"""Initialize the :attr:`~.TracedModule.watch_points`.

        You can call this function to get the ``Tensor/Module`` corresponding to a ``Node`` at runtime.

        Args:
            nodes: a list of ``Node``.

        For example, the following code

        .. code-block::

            import megengine.module as M
            import megengine as mge
            import megengine.traced_module as tm

            class MyModule(M.Module):
                def forward(self, x):
                    x = x + 1 + 2
                    return x

            net = MyModule()
            inp = mge.Tensor([0])
            traced_module = tm.trace_module(net, inp)
            add_1_node = traced_module.graph.get_node_by_id(2).as_unique()
            traced_module.set_watch_points(add_1_node)
            out = traced_module(inp)

        will produce the following ``watch_node_value``::

            print(traced_module.watch_node_value)

        .. code-block:: text

            {add_out: Tensor([1.], device=xpux:0)}
        """
        if not isinstance(nodes, Sequence):
            nodes = [nodes]
        self.watch_points = nodes
        if nodes:
            nodes[0].top_graph._watch_point = []
        for n in nodes:
            n.top_graph._watch_point.append(n)

    def clear_watch_points(self):
        r"""Clear the :attr:`~.TracedModule.watch_points` and :attr:`~.TracedModule.watch_node_value`."""
        for n in self.watch_points:
            n.top_graph._watch_point = []
        self.watch_points = []
        self.watch_node_value = {}

    def set_end_points(self, nodes: Sequence[Node]):
        r"""Initialize the :attr:`~.TracedModule.end_points`.

        When all the ``nodes`` have been generated, the Module will stop execution and return directly.

        Args:
            nodes: a list of ``Node``.

        For example, the following code

        .. code-block::

            import megengine.module as M
            import megengine as mge
            import megengine.traced_module as tm

            class MyModule(M.Module):
                def forward(self, x):
                    x = x + 1 + 2
                    return x

            net = MyModule()
            inp = mge.Tensor([0])
            traced_module = tm.trace_module(net, inp)
            add_1_node = traced_module.graph.get_node_by_id(2).as_unique()
            traced_module.set_end_points(add_1_node)
            out = traced_module(inp)

        will produce the following ``out``::

            print(out)

        .. code-block:: text

            [Tensor([1.], device=xpux:0)]
        """
        if not isinstance(nodes, Sequence):
            nodes = [nodes]
        self.end_points = nodes
        graphs = list(self.argdef_graph_map.values())
        for n in nodes:
            assert n.top_graph in graphs
            n.top_graph._end_point.append(n)

    def clear_end_points(self):
        r"""Clear the :attr:`~.TracedModule.end_points`."""
        for n in self.end_points:
            n.top_graph._end_point = []
        self.end_points = []

    @property
    def graph(self) -> InternalGraph:
        """Return the ``InternalGraph`` of this ``TracedModule``."""
        if self._is_top:
            self._update_ref()
        assert len(self.argdef_graph_map) == 1
        return list(self.argdef_graph_map.values())[0]

    def _update_ref(self, actual_node_map: Union[Dict] = None, top_graph=None):
        for inp_def, graph in self.argdef_graph_map.items():
            if top_graph is not None:
                graph._top_graph = weakref.ref(top_graph)
            for n in graph._inputs + graph.outputs:
                n._top_graph = weakref.ref(graph)
            graph._inputs[0]._owner = weakref.ref(self)
            for i, n in enumerate(graph._inputs):
                n.actual_node = []
                if actual_node_map is not None and inp_def in actual_node_map.keys():
                    n.actual_node = list(list(zip(*(actual_node_map[inp_def])))[i])
            node2obj = {}
            next_actual_node_map = collections.defaultdict(
                lambda: collections.defaultdict(list)
            )
            node2obj[graph._inputs[0]] = self
            for expr in graph._exprs:
                for n in expr.inputs + expr.outputs:
                    n._top_graph = weakref.ref(graph)
                expr._top_graph = weakref.ref(graph)
                if isinstance(expr, GetAttr) and isinstance(
                    expr.outputs[0], ModuleNode
                ):
                    obj = expr.interpret(node2obj[expr.inputs[0]])[0]
                    expr.outputs[0]._owner = weakref.ref(obj)
                    node2obj[expr.outputs[0]] = obj
                if isinstance(expr, Constant) and isinstance(
                    expr.outputs[0], ModuleNode
                ):
                    obj = expr.value
                    expr.outputs[0]._owner = weakref.ref(obj)
                    node2obj[expr.outputs[0]] = obj
                if (
                    isinstance(expr, CallMethod)
                    and expr.method == "__call__"
                    and isinstance(expr.inputs[0], ModuleNode)
                ):
                    obj = node2obj[expr.inputs[0]]
                    if expr.arg_def is not None:
                        next_actual_node_map[obj][expr.arg_def].append(expr.inputs)

            for obj in node2obj.values():
                if obj is self:
                    continue
                mnode_map = None
                if obj in next_actual_node_map.keys():
                    mnode_map = next_actual_node_map[obj]
                if isinstance(obj, TracedModule):
                    obj._update_ref(mnode_map, graph)

    def flatten(self):
        r"""Get a new TracedModule, which eliminates ``GetAttr`` and has no hierarchy.

        Returns:
            A new :class:`TracedModule`.
        """
        new_module = copy.deepcopy(self)

        def _replace_inputs_and_outputs(expr: Expr, repl_dict: Dict[Node, Node]):
            inputs, outputs = expr.inputs, expr.outputs
            for i, node in enumerate(inputs):
                if node in repl_dict:
                    inputs[i] = repl_dict[node]
            for i, node in enumerate(outputs):
                if node in repl_dict:
                    outputs[i] = repl_dict[node]
                    outputs[i].expr = expr

        def _flatten_subgraph(
            parent_graph: InternalGraph,
            graph: InternalGraph,
            call: CallMethod,
            module: Module,
        ):
            repl_dict, node2obj, rename_blacklist = {}, {}, []
            if call is not None:
                graph = copy.deepcopy(graph)
                node2obj[call.inputs[0]] = module
                repl_dict = dict(zip(graph._inputs, call.inputs))
                for ind, out in enumerate(graph.outputs):
                    if isinstance(out.expr, Input):
                        assert out in repl_dict
                        call_out = call.outputs[ind]
                        for expr in call.outputs[ind].users:
                            for index, inp in enumerate(expr.inputs):
                                if inp is call_out:
                                    expr.inputs[index] = repl_dict[out]
                                    repl_dict[out].users.append(expr)
                        if parent_graph is not None:
                            for index, parent_out in enumerate(parent_graph._outputs):
                                if parent_out is call_out:
                                    parent_graph._outputs[index] = repl_dict[out]
                        continue
                    repl_dict[out] = call.outputs[ind]
                    if isinstance(out, TensorNode):
                        call.outputs[ind]._qualname = out._qualname
                for node, repl_node in repl_dict.items():
                    assert node in graph._inputs or node in graph._outputs
                    for i in node.users:
                        if i not in repl_node.users:
                            repl_node.users.append(i)
                rename_blacklist = list(chain(call.inputs, call.outputs))

            node2obj[graph._inputs[0]] = module
            prefix_name = call.inputs[0]._name if call else ""

            exprs = []
            for expr in graph._exprs:
                if call is not None:
                    _replace_inputs_and_outputs(expr, repl_dict)
                if isinstance(expr, GetAttr):
                    mnode = expr.inputs[0]
                    node2obj[expr.outputs[0]] = expr.interpret(node2obj[mnode])[0]
                if isinstance(expr, CallMethod):
                    obj_node = expr.inputs[0]
                    if isinstance(obj_node, ModuleNode) and isinstance(
                        obj_node.expr, GetAttr
                    ):
                        obj = node2obj[obj_node]
                        expr_graph = (
                            obj.argdef_graph_map[expr.arg_def]
                            if hasattr(obj, "argdef_graph_map")
                            else None
                        )
                        if expr_graph is not None:
                            exprs.extend(
                                _flatten_subgraph(graph, expr_graph, expr, obj)
                            )
                            continue
                if parent_graph is not None:
                    for node in expr.outputs:
                        if node in rename_blacklist:
                            continue
                        name = "{}_{}".format(prefix_name, node._name)
                        node._name = parent_graph._namespace.create_unique_name(name)
                exprs.append(expr)

            if call is not None:
                for i in call.inputs:
                    i.users.remove(call)
            return exprs

        new_module.graph._exprs = _flatten_subgraph(
            None, new_module.graph, None, new_module
        )
        new_module.graph.compile()
        new_module.graph._reset_ids()
        return new_module

    def __getstate__(self):
        d = self.__dict__
        for k in Module.__dict__:
            d.pop(k, None)
        return d
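

# End-to-end sketch (illustrative; ``_example_trace_and_flatten`` is a
# hypothetical helper, not public API): trace a toy Module, then remove the
# module hierarchy with ``flatten``.
def _example_trace_and_flatten():
    import megengine as mge
    import megengine.module as M

    class _TwoAdds(M.Module):
        def forward(self, x):
            return x + 1 + 2

    traced = trace_module(_TwoAdds(), mge.Tensor([0.0]))
    print(traced.graph)  # hierarchical InternalGraph
    flat = traced.flatten()  # no GetAttr Exprs, single-level graph
    print(flat.graph)
    return flat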


def cpp_apply_module_trace(opdef, *args):
    return Apply.apply_module_trace_hook(opdef, *args)


def register_as_builtin(mod_cls: Type[Module]) -> None:
    r"""Registers class ``mod_cls`` (subclass of :class:`~.Module`) as a builtin module.

    Args:
        mod_cls: the module class which will be treated as a builtin module in tracing.
    """
    module_tracer.register_as_builtin(mod_cls)


def wrap(func: Callable):
    r"""Call this function to register ``func`` as a builtin function.

    This function can be called at module-level scope to register ``func`` as a builtin
    function. A builtin function will be converted to a :class:`CallFunction` Expr in tracing::

        def my_func(x, y):
            return x + y

        import megengine.traced_module as tm
        tm.wrap(my_func)

    This function can also equivalently be used as a decorator::

        @tm.wrap
        def my_func(x, y):
            return x + y

    Args:
        func: the global function to be inserted into the graph when it's called.
    """
    assert callable(func), "func must be a callable"
    assert hasattr(func, "__code__")
    fn_name = func.__code__.co_name
    currentframe = inspect.currentframe()
    assert currentframe is not None
    f = currentframe.f_back
    assert f is not None
    assert (
        f.f_code.co_name == "<module>"
    ), "wrap must be called at the top level of a module"
    Patcher._builtin_functions.append((f.f_globals, fn_name))
    return func


def _register_all_builtin_module():
    for sub_mod in [M, M.qat, M.quantized]:
        for m in getmembers(sub_mod):
            if (
                isclass(m[1])
                and issubclass(m[1], M.Module)
                and m[1] is not M.Sequential
            ):
                module_tracer.register_as_builtin(m[1])

    module_tracer.register_as_builtin(Observer)
    module_tracer.register_as_builtin(MinMaxObserver)
    module_tracer.register_as_builtin(SyncMinMaxObserver)
    module_tracer.register_as_builtin(ExponentialMovingAverageObserver)
    module_tracer.register_as_builtin(SyncExponentialMovingAverageObserver)
    module_tracer.register_as_builtin(HistogramObserver)
    module_tracer.register_as_builtin(PassiveObserver)
    module_tracer.register_as_builtin(LSQ)
    module_tracer.register_as_builtin(TQT)
    module_tracer.register_as_builtin(FakeQuantize)
    module_tracer.register_as_builtin(TM_FakeQuant)


def trace_module(
    mod: Module, *args: Tuple[Any], **kwargs: Dict[str, Any]
) -> TracedModule:
    r"""Traces module ``mod`` and returns the corresponding :class:`TracedModule`.

    Args:
        mod: the module to be converted to a :class:`TracedModule`.
        args: the positional arguments passed to the forward method of ``mod``.
        kwargs: the keyword arguments passed to the forward method of ``mod``.
    """
    assert active_module_tracer() is None
    assert isinstance(mod, Module)
    try:
        net_name = mod._name if mod._name else mod.__class__.__name__
        use_sym_shape = set_symbolic_shape(True)
        set_module_tracing()
        set_active_module_tracer(module_tracer(_wrapped_function))
        for cls in [Expr, Node]:
            cls._set_next_id(0)
        with active_module_tracer().patcher:
            global_scope = InternalGraph(name="top", qualname=net_name)
            active_module_tracer().push_scope(global_scope)
            builder = TracedModuleBuilder(mod, True)

            NodeMixin.wrap_safe(
                builder, Input.make(name="top", type=ModuleNode, qualname=net_name)
            )
            inputs, _ = tree_flatten((args, kwargs))
            for idx, i in enumerate(inputs):
                # assert isinstance(i, Tensor), "not support "
                if isinstance(i, RawTensor):
                    NodeMixin.wrap_safe(
                        i,
                        Input.make(
                            name="arg_{}".format(idx),
                            type=NodeMixin.get_wrapped_type(i),
                            qualname="{}.[{}]".format(net_name, "arg_{}".format(idx)),
                        ),
                    )
            builder(*args, **kwargs)
            active_module_tracer().pop_scope()
            traced_mod = builder.build()
            traced_mod.graph._reset_ids()
            return traced_mod
    finally:
        set_symbolic_shape(use_sym_shape)
        set_active_module_tracer(None)
        unset_module_tracing()
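

# Usage sketch (illustrative; mirrors the docstring examples above, with a
# hypothetical helper ``_example_watch`` that is not part of the module):
# capture an intermediate value at runtime with watch points.
def _example_watch():
    import megengine as mge
    import megengine.module as M

    class _AddTwice(M.Module):
        def forward(self, x):
            return x + 1 + 2

    traced = trace_module(_AddTwice(), mge.Tensor([0.0]))
    node = traced.graph.get_node_by_id(2).as_unique()
    traced.set_watch_points(node)
    traced(mge.Tensor([0.0]))
    print(traced.watch_node_value)  # {node: Tensor(...)} captured during forward
    traced.clear_watch_points()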
