You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

traced_module.py 90 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762277227822792280228122822283228422852286228722882289229022912292229322942295229622972298229923002301230223032304230523062307230823092310231123122313231423152316231723182319232023212322232323242325232623272328232923302331233223332334233523362337233823392340234123422343234423452346234723482349235023512352235323542355235623572358235923602361236223632364236523662367236823692370237123722373237423752376237723782379238023812382238323842385238623872388238923902391239223932394239523962397239823992400240124022403240424052406240724082409241024112412241324142415241624172418241924202421242224232424242524262427242824292430243124322433243424352436243724382439244024412442244324442445244624472448244924502451245224532454245524562457245824592460246124622463246424652466246724682469247024712472247324742475247624772478247924802481248224832484248524862487248824892490249124922493249424952496249724982499250025012502250325042505250625072508250925102511251225132514251525162517251825192520252125222523252425252526
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import builtins
  9. import collections
  10. import copy
  11. import fnmatch
  12. import functools
  13. import inspect
  14. import keyword
  15. import re
  16. import weakref
  17. from importlib import import_module
  18. from inspect import getcallargs, getmembers, isclass, ismethod
  19. from itertools import chain
  20. from types import FunctionType
  21. from typing import (
  22. Any,
  23. Callable,
  24. Dict,
  25. Iterable,
  26. List,
  27. Optional,
  28. Sequence,
  29. Tuple,
  30. Type,
  31. Union,
  32. )
  33. from .. import functional as F
  34. from .. import get_logger
  35. from .. import module as M
  36. from ..core._imperative_rt.core2 import Tensor as RawTensor
  37. from ..core._imperative_rt.core2 import (
  38. apply,
  39. is_tracing_module,
  40. set_module_tracing,
  41. unset_module_tracing,
  42. )
  43. from ..core._trace_option import set_symbolic_shape
  44. from ..core.ops.builtin import Copy
  45. from ..core.tensor.utils import isscalar, setscalar
  46. from ..module import Module
  47. from ..module import external as MExternal
  48. from ..module.qat import QATModule
  49. from ..quantization.fake_quant import LSQ, TQT, FakeQuantize, _FakeQuantize
  50. from ..quantization.observer import (
  51. ExponentialMovingAverageObserver,
  52. HistogramObserver,
  53. MinMaxObserver,
  54. Observer,
  55. PassiveObserver,
  56. SyncExponentialMovingAverageObserver,
  57. SyncMinMaxObserver,
  58. )
  59. from ..tensor import Tensor
  60. from ..utils.max_recursion_limit import max_recursion_limit
  61. from ..version import __version__
  62. from .expr import (
  63. Apply,
  64. CallFunction,
  65. CallMethod,
  66. Constant,
  67. Expr,
  68. GetAttr,
  69. Input,
  70. get_suffix_name,
  71. is_apply_def,
  72. is_call_function,
  73. is_call_module,
  74. is_call_tensor_method,
  75. is_constant,
  76. is_getattr,
  77. is_input,
  78. )
  79. from .fake_quant import FakeQuantize as TM_FakeQuant
  80. from .module_tracer import (
  81. PatchedFn,
  82. Patcher,
  83. active_module_tracer,
  84. get_tensor_wrapable_method,
  85. module_tracer,
  86. set_active_module_tracer,
  87. )
  88. from .node import ModuleNode, Node, NodeMixin, TensorNode
  89. from .pytree import (
  90. USER_REGISTERED_CONTAINER_TYPE,
  91. USER_REGISTERED_LEAF_TYPE,
  92. ArgsIndex,
  93. TreeDef,
  94. _register_supported_type,
  95. tree_flatten,
  96. )
  97. from .serialization import (
  98. _ModuleState,
  99. load_apply_expr,
  100. load_call_module_expr,
  101. load_call_tensor_method_expr,
  102. load_functional,
  103. )
  104. from .tm_config import (
  105. _exclude_from_trace,
  106. _get_default_checker,
  107. _get_expr_checker,
  108. _graph_surgery_mode,
  109. _set_graph_surgery_mode,
  110. )
  111. from .utils import (
  112. _check_builtin_module_attr,
  113. _check_obj_attr,
  114. _convert_kwargs_to_args,
  115. replace_container_with_module_container,
  116. )
  117. logger = get_logger(__name__)
  118. def _is_builtin_name(name: str) -> bool:
  119. return (
  120. name in builtins.__dict__
  121. or name in keyword.kwlist
  122. or name in {"inf", "nan", "NoneType"}
  123. )
  124. def _is_leaf(node):
  125. assert isinstance(
  126. node, RawTensor
  127. ), 'doesn\'t support {} in return values, MUST use Tensor or use "register_supported_type" method to register self-defined type'.format(
  128. type(node)
  129. )
  130. return isinstance(node, RawTensor)
  131. def _node_to_tensor(*args, **kwargs):
  132. tensors = []
  133. nodes, tree_def = tree_flatten((args, kwargs))
  134. for n in nodes:
  135. if isinstance(n, TensorNode):
  136. if n.top_graph is not None:
  137. active_module_tracer().current_scope()._add_input(n)
  138. value = n.value
  139. if value is None:
  140. flag = _set_graph_surgery_mode(False)
  141. unset_module_tracing()
  142. value = F.zeros(shape=n._shape, dtype=n._dtype)
  143. set_module_tracing()
  144. _set_graph_surgery_mode(flag)
  145. orig_n = NodeMixin.get(value, None)
  146. if orig_n is None or "setitem" not in orig_n._name:
  147. NodeMixin.wrap_safe(value, n)
  148. tensors.append(value)
  149. else:
  150. tensors.append(n)
  151. tensors = tree_def.unflatten(tensors)
  152. return tensors
  153. def _tensor_to_node(tensors):
  154. if tensors is None:
  155. return None
  156. nodes = []
  157. tensors, out_def = tree_flatten(tensors)
  158. for t in tensors:
  159. if isinstance(t, Tensor):
  160. n = NodeMixin.get(t, None)
  161. if isinstance(n, TensorNode):
  162. n.value = t
  163. nodes.append(n)
  164. else:
  165. nodes.append(t)
  166. else:
  167. nodes.append(t)
  168. nodes = out_def.unflatten(nodes)
  169. return nodes
  170. def _name_setter(node: Node, new_name: str):
  171. surgery_mode = _set_graph_surgery_mode(False)
  172. graph = active_module_tracer().current_scope()
  173. if node.top_graph is not None:
  174. top_graph = active_module_tracer().top_scope()
  175. if node is top_graph._namespace.used_names.get(node._name, None):
  176. graph = top_graph
  177. else:
  178. graph = node.top_graph
  179. assert (
  180. graph._namespace.used_names.get(new_name, None) is None
  181. ), "The name(%s) is already in use. Please try a different one again." % (new_name)
  182. graph._namespace.unassociate_name_with_obj(node)
  183. node._name = graph._namespace.create_unique_name(new_name, node)
  184. _set_graph_surgery_mode(surgery_mode)
  185. def _wrap_method_to_tensor_node():
  186. def _any_method(name, func):
  187. def _any(*args, **kwargs):
  188. if is_tracing_module() and _graph_surgery_mode():
  189. args, kwargs = _node_to_tensor(*args, **kwargs)
  190. attr = getattr(args[0], name)
  191. outs = attr
  192. if callable(attr):
  193. outs = attr(*(args[1:]), **kwargs)
  194. if name == "__setitem__":
  195. _node_to_tensor(outs)
  196. return None
  197. outs = _tensor_to_node(outs)
  198. return outs
  199. else:
  200. outs = func
  201. if callable(func):
  202. outs = func(*args, **kwargs)
  203. if isinstance(func, property):
  204. outs = func.__get__(*args, **kwargs)
  205. return outs
  206. return _any
  207. tensor_method_patch = []
  208. for method in get_tensor_wrapable_method():
  209. patch = PatchedFn(TensorNode, method)
  210. if type(getattr(Tensor, method)) == property:
  211. # Only support property.getter
  212. patch.set_func(property(_any_method(method, patch.origin_fn)))
  213. else:
  214. patch.set_func(_any_method(method, patch.origin_fn))
  215. tensor_method_patch.append(patch)
  216. patch = PatchedFn(Node, "name")
  217. patch.set_func(property(patch.origin_fn.fget, _name_setter))
  218. tensor_method_patch.append(patch)
  219. return tensor_method_patch
  220. def _convert_node_and_tensor(orig_func):
  221. @functools.wraps(orig_func)
  222. def _convert(*args, **kwargs):
  223. if is_tracing_module() and _graph_surgery_mode():
  224. args, kwargs = _node_to_tensor(*args, **kwargs)
  225. rst = orig_func(*args, **kwargs, method_func=_convert)
  226. rst = _tensor_to_node(rst)
  227. return rst
  228. else:
  229. rst = orig_func(*args, **kwargs)
  230. return rst
  231. return _convert
  232. def _wrap_mnode_getattr(orig_getattr):
  233. @functools.wraps(orig_getattr)
  234. def wraped_fn(self, name):
  235. if is_tracing_module() and _graph_surgery_mode():
  236. obj = self.owner
  237. current_graph = active_module_tracer().current_scope()
  238. if self.top_graph is not None:
  239. current_graph._add_input(self)
  240. attr = getattr(obj, name)
  241. node = attr
  242. if not isinstance(attr, TracedModuleBuilder):
  243. if isinstance(attr, Module):
  244. attr = TracedModuleBuilder(attr)
  245. setattr(obj, name, attr)
  246. if isinstance(attr, (NodeMixin, RawTensor)):
  247. NodeMixin.wrap(
  248. attr,
  249. lambda: GetAttr.make(
  250. self,
  251. type=NodeMixin.get_wrapped_type(attr),
  252. attr_name=name,
  253. name="",
  254. ),
  255. )
  256. if isinstance(attr, (NodeMixin, RawTensor)):
  257. node = NodeMixin.get(attr)
  258. if isinstance(node, ModuleNode) and isinstance(attr, (NodeMixin, Module)):
  259. node._owner = weakref.ref(attr)
  260. return node
  261. else:
  262. node = object.__getattribute__(self, name)
  263. return node
  264. return wraped_fn
  265. def _wrap_mnode_call(orig_call):
  266. @functools.wraps(orig_call)
  267. def wraped_fn(self, *args, **kwargs):
  268. if is_tracing_module() and _graph_surgery_mode():
  269. obj = self.owner
  270. if self.top_graph is not None:
  271. active_module_tracer().current_scope()._add_input(self)
  272. rst = obj(*args, **kwargs)
  273. else:
  274. raise TypeError("'ModuleNode' object is not callable")
  275. return rst
  276. return wraped_fn
  277. class _InsertExprs:
  278. def __init__(self, graph, expr: Optional[Expr] = None):
  279. self.graph = graph
  280. while graph.top_graph is not None:
  281. graph = graph.top_graph
  282. assert graph.inputs[0].owner._is_top
  283. self.root_graph = graph
  284. self.global_scope = InternalGraph(self.graph._name, self.graph._qualname)
  285. self.global_scope._namespace.merge(self.graph._namespace)
  286. self.expr = expr
  287. self._tensor_method_patch = None
  288. def __enter__(self):
  289. self.use_sym_shape = set_symbolic_shape(True)
  290. node_id, expr_id = self.root_graph._total_ids
  291. Node._set_next_id(node_id)
  292. Expr._set_next_id(expr_id)
  293. set_module_tracing()
  294. _set_graph_surgery_mode(True)
  295. assert active_module_tracer() is None
  296. set_active_module_tracer(
  297. module_tracer(lambda x: _convert_node_and_tensor(_wrapped_function(x)))
  298. )
  299. active_module_tracer().patcher.__enter__()
  300. for cls, name, func in [
  301. [ModuleNode, "__getattr__", _wrap_mnode_getattr],
  302. [ModuleNode, "__call__", _wrap_mnode_call],
  303. [TracedModuleBuilder, "__call__", _convert_node_and_tensor],
  304. ]:
  305. active_module_tracer().patcher.patch_function(cls, name, func)
  306. self._tensor_method_patch = _wrap_method_to_tensor_node()
  307. active_module_tracer().push_scope(self.global_scope)
  308. def __exit__(self, ty, va, tr):
  309. if va is not None:
  310. return False
  311. active_module_tracer().patcher.__exit__(ty, va, tr)
  312. while self._tensor_method_patch:
  313. pf = self._tensor_method_patch.pop()
  314. pf.set_func(pf.origin_fn)
  315. # delete ModuleNode.__call__ to avoid entering the
  316. # ModuleNode.__init__ method when call a ModuleNode object.
  317. delattr(ModuleNode, "__call__")
  318. module = self.graph.inputs[0].owner
  319. def build_traced_module(
  320. module: TracedModuleBuilder, target_module: TracedModule
  321. ):
  322. for k, v in module.__dict__.items():
  323. if isinstance(v, TracedModuleBuilder):
  324. traced_v = v.build()
  325. build_traced_module(v, traced_v)
  326. setattr(target_module, k, traced_v)
  327. build_traced_module(module, module)
  328. set_symbolic_shape(self.use_sym_shape)
  329. _set_graph_surgery_mode(False)
  330. set_active_module_tracer(None)
  331. unset_module_tracing()
  332. extra_inp_nodes = set(self.global_scope.inputs)
  333. max_inp_expr_idx = -1
  334. for node in extra_inp_nodes:
  335. assert (
  336. node.top_graph == self.graph
  337. ), "The input node ({}) is not in the graph ({})".format(node, self.graph)
  338. if node.expr in self.graph._exprs:
  339. max_inp_expr_idx = max(
  340. max_inp_expr_idx, self.graph._exprs.index(node.expr)
  341. )
  342. max_inp_expr_idx += 1
  343. insert_index = -1
  344. if self.expr in self.graph._exprs:
  345. insert_index = self.graph._exprs.index(self.expr)
  346. insert_index += 1
  347. if insert_index < max_inp_expr_idx:
  348. insert_index = max_inp_expr_idx
  349. for expr in self.global_scope._exprs:
  350. self.graph._exprs.insert(insert_index, expr)
  351. insert_index += 1
  352. self.graph._namespace.merge(self.global_scope._namespace)
  353. self.root_graph._total_ids = (Node._get_next_id(), Expr._get_next_id())
  354. self.root_graph.inputs[0].owner._update_ref()
  355. for node in self.root_graph.nodes():
  356. if isinstance(node, TensorNode):
  357. node.value = None
  358. return True
  359. class NameSpace:
  360. def __init__(self, name, qualname):
  361. self.name = name
  362. self.qualname = qualname
  363. self._used_names = {}
  364. def create_unique_name(self, name: str, node: Any = None) -> str:
  365. assert isinstance(name, str), "The name must be a string"
  366. if name in self._used_names and (self._used_names[name] is node):
  367. return name
  368. name = re.sub("[^0-9a-zA-Z_]+", "_", name)
  369. if name[0].isdigit():
  370. name = "_{}".format(name)
  371. while (
  372. name in self._used_names and self._used_names[name] is not None
  373. ) or _is_builtin_name(name):
  374. match = re.match(r"(.*)_(\d+)$", name)
  375. if match is None:
  376. name = name + "_1"
  377. else:
  378. base, num = match.group(1, 2)
  379. name = "{}_{}".format(base, int(num) + 1)
  380. self._used_names.setdefault(name)
  381. if node is not None:
  382. self.associate_name_with_obj(name, node)
  383. return name
  384. def auto_naming_for_outputs(self, expr: Expr):
  385. _add_suffix = lambda x: x + "_out"
  386. if is_call_module(expr):
  387. call_node = expr.inputs[0]
  388. qualname = "%s.[out]" % (call_node.qualname)
  389. name = call_node.name
  390. elif is_call_tensor_method(expr):
  391. name = expr.method.strip("_")
  392. qualname = "{}.[{}]".format(
  393. self.qualname, self.create_unique_name("method_%s" % (name)),
  394. )
  395. elif is_call_function(expr):
  396. name = expr.func.__name__
  397. qualname = "{}.[{}]".format(
  398. self.qualname, self.create_unique_name("func_%s" % name),
  399. )
  400. elif is_apply_def(expr):
  401. name = str(expr.opdef).lower()
  402. qualname = "{}.[{}]".format(
  403. self.qualname, self.create_unique_name("def_%s" % name),
  404. )
  405. elif is_getattr(expr):
  406. qualname = "{}.{}".format(expr.inputs[0].qualname, expr.name)
  407. name = get_suffix_name(self.qualname, qualname)
  408. _add_suffix = lambda x: x
  409. elif is_constant(expr) or is_input(expr):
  410. name = (
  411. expr.name if expr.name else "const_" + type(expr.value).__name__.lower()
  412. )
  413. qualname = "{}.[{}]".format(self.qualname, name)
  414. _add_suffix = lambda x: x
  415. for node in expr.outputs:
  416. cur_name = node._name if node._name else _add_suffix(name)
  417. node._name = self.create_unique_name(cur_name, node)
  418. if node._qualname == "":
  419. node._qualname = qualname
  420. assert get_suffix_name(self.qualname, qualname) is not None
  421. def merge(self, other: "NameSpace"):
  422. self._used_names.update(other.used_names)
  423. def associate_name_with_obj(self, name: str, node: Node):
  424. assert name in self.used_names
  425. assert self.used_names[name] is None, "The name(%s) is already in use" % (name)
  426. self._used_names[name] = node
  427. def unassociate_name_with_obj(self, node: Node):
  428. assert node.name in self.used_names
  429. # assert self.used_names[node.name] is node
  430. self._used_names[node.name] = None
  431. @property
  432. def used_names(self):
  433. return self._used_names
  434. class InternalGraph:
  435. r"""``InternalGraph`` is the main data structure used in the TracedModule.
  436. It is used to represent the execution procedure of Module's forward method.
  437. For example, the following code
  438. .. code-block::
  439. import megengine.random as rand
  440. import megengine.functional as F
  441. import megengine.module as M
  442. import megengine.traced_module as tm
  443. class MyModule(M.Module):
  444. def __init__(self):
  445. super().__init__()
  446. self.param = rand.normal(size=(3, 4))
  447. self.linear = M.Linear(4, 5)
  448. def forward(self, x):
  449. return F.relu(self.linear(x + self.param))
  450. net = MyModule()
  451. inp = F.zeros(shape = (3, 4))
  452. traced_module = tm.trace_module(net, inp)
  453. Will produce the following ``InternalGraph``::
  454. print(traced_module.graph)
  455. .. code-block:: text
  456. MyModule.Graph (self, x) {
  457. %2: linear = getattr(self, "linear") -> (Linear)
  458. %3: param = getattr(self, "param") -> (Tensor)
  459. %4: add_out = x.__add__(param, )
  460. %5: linear_out = linear(add_out, )
  461. %6: relu_out = nn.relu(linear_out, )
  462. return relu_out
  463. }
  464. """
  465. _exprs = None # type: List[Expr]
  466. _inputs = None # type: List[Node]
  467. _outputs = None # type: List[Node]
  468. _top_graph = None # type: InternalGraph
  469. _total_ids = None # type: List[int]
  470. def __init__(self, name: str, qualname: str):
  471. self._exprs = []
  472. self._inputs = []
  473. self._outputs = []
  474. self._watch_point = []
  475. self._end_point = []
  476. self._namespace = NameSpace(name, qualname)
  477. self._rst = collections.defaultdict(list)
  478. self._name = name
  479. self._qualname = qualname
  480. def _insert(self, expr):
  481. self._exprs.append(expr)
  482. @property
  483. def name(self) -> str:
  484. r"""Get the name of this graph."""
  485. return self._name
  486. @name.setter
  487. def name(self, new_name: str):
  488. r"""Set a new name to this graph."""
  489. mod = self.inputs[0].owner
  490. graph = self.top_graph
  491. assert graph is not None or mod._is_top, "The parent graph cannot be None."
  492. if graph is not None:
  493. assert graph._namespace.used_names.get(new_name, None) is None, (
  494. "The name(%s) is already in use. Please try a different one again."
  495. % (new_name)
  496. )
  497. new_name = graph._namespace.create_unique_name(new_name, self)
  498. self._name = new_name
  499. @property
  500. def qualname(self) -> str:
  501. r"""Get the `qualname` of this graph. The `qualname` can be used to get the
  502. submodule from the traced Module or Module.
  503. Example:
  504. .. code-block::
  505. import megengine.module as M
  506. import megengine.traced_module as tm
  507. import megengine as mge
  508. class block(M.Module):
  509. def __init__(self):
  510. super().__init__()
  511. self.relu = M.ReLU()
  512. def forward(self, x):
  513. return self.relu(x)
  514. class module(M.Module):
  515. def __init__(self):
  516. super().__init__()
  517. self.block = block()
  518. def forward(self, x):
  519. x = self.block(x)
  520. return x
  521. net = module()
  522. traced_net = tm.trace_module(net, mge.Tensor([0.]))
  523. qualname = traced_net.block.graph.qualname # qualname = "module.block"
  524. qualname = qualname.split(".", 1)[-1] # qualname = "block"
  525. assert qualname in list(map(lambda x: x[0], net.named_modules()))
  526. assert qualname in list(map(lambda x: x[0], traced_net.named_modules()))
  527. """
  528. return self._qualname
  529. @property
  530. def inputs(self) -> List[Node]:
  531. r"""Get the list of input Nodes of this graph.
  532. Returns:
  533. A list of ``Node``.
  534. """
  535. return self._inputs
  536. @property
  537. def outputs(self) -> List[Node]:
  538. r"""Get the list of output Nodes of this graph.
  539. Returns:
  540. A list of ``Node``.
  541. """
  542. return self._outputs
  543. @property
  544. def top_graph(self):
  545. r"""Get the parent graph of this graph.
  546. Returns:
  547. An ``InternalGraph``.
  548. """
  549. if self._top_graph:
  550. return self._top_graph()
  551. return None
  552. def exprs(self, recursive=True):
  553. r"""Get the Exprs that constitute this graph.
  554. Args:
  555. recursive: whether to get the Exprs in the subgraph.
  556. Default: True
  557. Returns:
  558. A ``ExprFilter`` containing all Exprs of this graph.
  559. """
  560. return ExprFilter(_expr_iter(self, recursive))
  561. def nodes(self, recursive=True):
  562. r"""Get the Nodes that constitute this graph.
  563. Args:
  564. recursive: whether to get the Nodes in the subgraph.
  565. Default: True
  566. Returns:
  567. A ``NodeFilter`` containing all Nodes of this graph.
  568. """
  569. return NodeFilter(_node_iter(self, recursive))
  570. def get_function_by_type(self, func: Callable = None, recursive=True):
  571. r"""Filter Exprs by the type of ``CallFunction``.
  572. Args:
  573. func: a built-in function, such as ``F.relu``.
  574. recursive: whether to get the Exprs in the subgraph.
  575. Default: True
  576. Returns:
  577. A :class:`~.TracedModule.ExprFilterCallFunction`.
  578. """
  579. return self.exprs(recursive).call_function(func)
  580. def get_method_by_type(self, method: str = None, recursive=True):
  581. r"""Filter Exprs by the type of ``CallMethod``.
  582. Args:
  583. method: a method string, such as "__add__".
  584. recursive: whether to get the Exprs in the subgraph.
  585. Default: True
  586. Returns:
  587. A :class:`~.TracedModule.ExprFilterCallMethod`.
  588. """
  589. return self.exprs(recursive).call_method(method)
  590. def get_expr_by_id(self, expr_id: List[int] = None, recursive=True):
  591. r"""Filter Exprs by their ``id``.
  592. Args:
  593. expr_id: a list of :class:`int`.
  594. recursive: whether to get the Exprs in the subgraph.
  595. Default: True
  596. Returns:
  597. A :class:`~.TracedModule.ExprFilterExprId`.
  598. """
  599. return self.exprs(recursive).expr_id(expr_id)
  600. def get_module_by_type(self, module_cls: Module, recursive=True):
  601. r"""Filter Nodes by the ``module_type`` of ``ModuleNode``.
  602. Args:
  603. module_cls: a subclass of :class:`~.Module`.
  604. recursive: whether to get the Nodes in the subgraph.
  605. Default: True
  606. Returns:
  607. A :class:`~.TracedModule.NodeFilterType`.
  608. """
  609. return self.nodes(recursive).type(module_cls)
  610. def get_node_by_id(self, node_id: List[int] = None, recursive=True):
  611. r"""Filter Nodes by their ``id``.
  612. The ``id`` of the ``Node`` can be obtained by the following code
  613. .. code-block::
  614. # node : Node
  615. print("{:i}".format(node))
  616. print(node.__format__("i"))
  617. # graph : InternalGraph
  618. print("{:i}".format(graph))
  619. print(graph.__format__("i"))
  620. Args:
  621. node_id: a list of :class:`int`.
  622. recursive: whether to get the Nodes in the subgraph.
  623. Default: True
  624. Returns:
  625. A :class:`~.TracedModule.NodeFilterNodeId`.
  626. """
  627. return self.nodes(recursive).node_id(node_id)
  628. def get_node_by_name(
  629. self, name: str = None, ignorecase: bool = True, recursive=True
  630. ):
  631. r"""Filter Nodes by their full name.
  632. The full name of the ``Node`` can be obtained by the following code
  633. .. code-block::
  634. # node : Node
  635. print("{:p}".format(node))
  636. print(node.__format__("p"))
  637. # graph : InternalGraph
  638. print("{:p}".format(graph))
  639. print(graph.__format__("p"))
  640. Args:
  641. name: a string in glob syntax that can contain ``?`` and
  642. ``*`` to match a single or arbitrary characters.
  643. ignorecase: whether to ignroe case.
  644. Default: True
  645. recursive: whether to get the Nodes in the subgraph.
  646. Default: True
  647. Returns:
  648. A :class:`~.TracedModule.NodeFilterName`.
  649. """
  650. return self.nodes(recursive).name(name, ignorecase)
  651. def _add_input(self, i):
  652. self._inputs.append(i)
  653. def _add_output(self, o):
  654. self._outputs.append(o)
  655. def get_dep_exprs(self, nodes: Sequence[Node]) -> List[Expr]:
  656. r"""Get the dependent Exprs of the ``nodes``.
  657. Args:
  658. nodes: a list of :class:`Node`.
  659. Returns:
  660. A list of dependent :class:`Expr`.
  661. """
  662. if not isinstance(nodes, Sequence):
  663. nodes = (nodes,)
  664. ret = list()
  665. queue = list(nodes)
  666. visited_queue = list()
  667. while queue:
  668. node = queue.pop()
  669. visited_queue.append(node)
  670. expr = node.expr
  671. if expr not in ret:
  672. ret.append(expr)
  673. for i in expr.inputs:
  674. if i not in queue and i not in visited_queue:
  675. queue.append(i)
  676. return ret
  677. def reset_inputs(self, *args, **kwargs):
  678. forma_mnode = self.inputs[0]
  679. moudle = forma_mnode.owner
  680. assert moudle._is_top, "reset_inputs only supports top graph"
  681. inputs, tree_def = tree_flatten(((moudle, *args), kwargs))
  682. def create_node(val: Tensor):
  683. name = self._namespace.create_unique_name("args")
  684. node = Input(
  685. type=TensorNode, name=name, qualname="%s.[%s]" % (self._qualname, name)
  686. ).outputs[0]
  687. self._namespace.associate_name_with_obj(node.name, node)
  688. node.shape = val.shape
  689. node.dtype = val.dtype
  690. return node
  691. formal_node_inputs = [
  692. forma_mnode,
  693. ]
  694. org_argdef = list(moudle.argdef_graph_map.keys())[0]
  695. for v in inputs[1:]:
  696. assert isinstance(v, RawTensor)
  697. formal_node_inputs.append(create_node(v))
  698. self._inputs[:] = formal_node_inputs
  699. moudle.argdef_graph_map[tree_def] = moudle.argdef_graph_map.pop(org_argdef)
  700. moudle.argdef_outdef_map[tree_def] = moudle.argdef_outdef_map.pop(org_argdef)
  701. return formal_node_inputs[1:]
  702. def add_input_node(
  703. self, shape: Tuple[int], dtype: str = "float32", name: str = "args"
  704. ):
  705. r"""Add an input node to the graph.
  706. The new Node will be the last of the positional arguments.
  707. Args:
  708. shape: the shape of the new input Node.
  709. dtype: the dtype of the new input Node.
  710. Default: float32
  711. name: the name of the new input Node. When the name is used in the graph,
  712. a suffix will be added to it.
  713. """
  714. forma_mnode = self.inputs[0]
  715. moudle = forma_mnode.owner
  716. assert moudle._is_top, "add_input_node only supports top graph"
  717. def create_node(name=None):
  718. name = self._namespace.create_unique_name(name)
  719. node = Input(
  720. type=TensorNode, name=name, qualname="%s.[%s]" % (self._qualname, name)
  721. ).outputs[0]
  722. self._namespace.associate_name_with_obj(node.name, node)
  723. node.shape = shape
  724. node.dtype = dtype
  725. return node
  726. org_argdef = list(moudle.argdef_graph_map.keys())[0]
  727. args, kwargs = org_argdef.unflatten(self._inputs)
  728. formal_inp_node = create_node(name)
  729. inputs, tree_def = tree_flatten(
  730. ((*args, formal_inp_node), kwargs),
  731. is_const_leaf=lambda x: not isinstance(x, (TensorNode, ModuleNode)),
  732. )
  733. self._inputs[:] = inputs[:]
  734. moudle.argdef_graph_map[tree_def] = moudle.argdef_graph_map.pop(org_argdef)
  735. moudle.argdef_outdef_map[tree_def] = moudle.argdef_outdef_map.pop(org_argdef)
  736. return formal_inp_node
  737. def reset_outputs(self, outputs):
  738. r"""Reset the output Nodes of the graph.
  739. .. note::
  740. This method only supports resetting the output of graphs
  741. that do not have a parent graph.
  742. Args:
  743. outputs: an object which inner element is Node. Support tuple, list
  744. dict, etc.
  745. For example, the following code
  746. .. code-block::
  747. import megengine.functional as F
  748. import megengine.module as M
  749. import megengine.traced_module as tm
  750. class MyModule(M.Module):
  751. def forward(self, x):
  752. x = x + 1
  753. return x
  754. net = MyModule()
  755. inp = F.zeros(shape = (1, ))
  756. traced_module = tm.trace_module(net, inp)
  757. graph = traced_module.graph
  758. inp_node = graph.inputs[1]
  759. out_node = graph.outputs[0]
  760. graph.reset_outputs((out_node, {"input": inp_node}))
  761. out = traced_module(inp)
  762. Will produce the following ``InternalGraph`` and ``out``::
  763. print(graph)
  764. print(out)
  765. .. code-block:: text
  766. MyModule.Graph (self, x) {
  767. %2: add_out = x.__add__(1, )
  768. return add_out, x
  769. }
  770. (Tensor([1.], device=xpux:0), {'input': Tensor([0.], device=xpux:0)})
  771. """
  772. outputs, out_def = tree_flatten(
  773. outputs, is_leaf=lambda x: isinstance(x, TensorNode),
  774. )
  775. forma_mnode = self.inputs[0]
  776. moudle = forma_mnode.owner
  777. assert moudle._is_top, "reset_outputs only supports top graph"
  778. tree_def = list(moudle.argdef_graph_map.keys())[0]
  779. self._outputs[:] = outputs
  780. moudle.argdef_outdef_map[tree_def] = out_def
  781. def add_output_node(self, node: TensorNode):
  782. r"""Add an output node to the Graph.
  783. The Graph output will become a ``tuple`` after calling ``add_output_node``.
  784. The first element of the ``tuple`` is the original output, and the second
  785. is the ``node``.
  786. For example, the following code
  787. .. code-block::
  788. import megengine.functional as F
  789. import megengine.module as M
  790. import megengine.traced_module as tm
  791. class MyModule(M.Module):
  792. def forward(self, x):
  793. x = x + 1
  794. return x
  795. net = MyModule()
  796. inp = F.zeros(shape = (1, ))
  797. traced_module = tm.trace_module(net, inp)
  798. graph = traced_module.graph
  799. inp_node = graph.inputs[1]
  800. out_node = graph.outputs[0]
  801. graph.add_output_node(inp_node)
  802. graph.add_output_node(out_node)
  803. out = traced_module(inp)
  804. Will produce the following ``InternalGraph`` and ``out``::
  805. print(graph)
  806. print(out)
  807. .. code-block:: text
  808. MyModule.Graph (self, x) {
  809. %2: add_out = x.__add__(1, )
  810. return add_out, x, add_out
  811. }
  812. ((Tensor([1.], device=xpux:0), Tensor([0.], device=xpux:0)), Tensor([1.], device=xpux:0))
  813. """
  814. forma_mnode = self.inputs[0]
  815. moudle = forma_mnode.owner
  816. assert moudle._is_top, "add_output_node only supports top graph"
  817. tree_def = list(moudle.argdef_graph_map.keys())[0]
  818. org_out_def = moudle.argdef_outdef_map[tree_def]
  819. org_outs = org_out_def.unflatten(self._outputs)
  820. outputs, out_def = tree_flatten(
  821. (org_outs, node), is_leaf=lambda x: isinstance(x, TensorNode),
  822. )
  823. self._outputs[:] = outputs
  824. moudle.argdef_outdef_map[tree_def] = out_def
  825. def insert_exprs(self, expr: Optional[Expr] = None):
  826. r"""Initialize the trace mode and insertion position.
  827. When used within a 'with' statement, this will temporary set the trace mode and
  828. then restore normal mode when the with statement exits::
  829. with graph.insert_exprs(e): # set the trace mode
  830. ... # trace function or module
  831. ... # inert exprs into graph and resotre normal mode
  832. Args:
  833. expr: the ``expr`` after which to insert. If None, the insertion position will be
  834. automatically set based on the input node.
  835. Returns:
  836. A resource manager that will initialize trace mode on ``__enter__`` and
  837. restore normal mode on ``__exit__``.
  838. """
  839. if expr is not None:
  840. assert expr.top_graph == self, "Expr to insert after is not in graph."
  841. return _InsertExprs(self, expr)
  842. def replace_node(self, repl_dict: Dict[Node, Node]):
  843. r"""Replace the Nodes in the graph.
  844. Args:
  845. repl_dict: the map {old_Node: new_Node} that specifies how to replace the Nodes.
  846. """
  847. while repl_dict:
  848. node, repl_node = repl_dict.popitem()
  849. assert type(node) == type(
  850. repl_node
  851. ), "The type of {}({}) and {}({}) are not the same".format(
  852. node, type(node).__name__, repl_node, type(repl_node).__name__
  853. )
  854. # check graph inputs and outputs
  855. for i, n in enumerate(self.outputs):
  856. if n is node:
  857. self.outputs[i] = repl_node
  858. # update users of node and repl_node
  859. # update inputs of expr in node.users
  860. graph = repl_node.top_graph
  861. assert graph is not None
  862. assert graph is self
  863. index = -1
  864. if not isinstance(repl_node.expr, Input):
  865. index = graph._exprs.index(repl_node.expr)
  866. dep_exprs = self.get_dep_exprs(repl_node)
  867. i = 0
  868. while i < len(node.users):
  869. n = node.users[i]
  870. if n in graph._exprs and index >= graph._exprs.index(n):
  871. i += 1
  872. continue
  873. if n in dep_exprs:
  874. logger.info("Find a loop: ignore this replacement once")
  875. logger.info("node: %s" % node.__repr__())
  876. logger.info("expr: %s" % n.__repr__())
  877. i += 1
  878. continue
  879. repl_node.users.append(n)
  880. node.users.pop(i)
  881. idx = n.inputs.index(node)
  882. n.inputs[idx] = repl_node
  883. def _merge_getattr_expr(self):
  884. getattr_nodes_map = dict() # Dcit[(Node, str), Node]
  885. node_to_attrname = dict() # Dict[Node, (Node, Str)]
  886. for expr in filter(lambda x: isinstance(x, GetAttr), self._exprs):
  887. base_node, attr_name = expr.inputs[0], expr.name
  888. if expr.inputs[0] in node_to_attrname:
  889. base_node, base_name = node_to_attrname[expr.inputs[0]]
  890. attr_name = "{}.{}".format(base_name, expr.name)
  891. if get_suffix_name(self.qualname, expr.outputs[0].qualname) != attr_name:
  892. expected_qualname = base_node.qualname + "." + attr_name
  893. logger.warning(
  894. "{}.qualname expects {}, got {} actually. You can re-trace this "
  895. "TracedModel to make the name correct.".format(
  896. expr.outputs[0], expected_qualname, expr.outputs[0].qualname
  897. )
  898. )
  899. expr.outputs[0]._qualname = expected_qualname
  900. key = (base_node, attr_name)
  901. node_to_attrname[expr.outputs[0]] = key
  902. if key in getattr_nodes_map:
  903. existed_node = getattr_nodes_map[key]
  904. repl_node = expr.outputs[0]
  905. for expr in repl_node.users:
  906. existed_node.users.append(expr)
  907. idx = expr.inputs.index(repl_node)
  908. expr.inputs[idx] = existed_node
  909. repl_node.users = []
  910. else:
  911. if attr_name != expr.name:
  912. expr.name = attr_name
  913. expr.inputs[0].users.remove(expr)
  914. self.inputs[0].users.append(expr)
  915. expr.inputs[0] = self.inputs[0]
  916. getattr_nodes_map[key] = expr.outputs[0]
  917. def compile(self):
  918. r"""Delete unused expr."""
  919. self._merge_getattr_expr()
  920. dep_exprs = self.get_dep_exprs(self.outputs)
  921. i = 0
  922. while i < len(self._exprs):
  923. expr = self._exprs[i]
  924. if expr in dep_exprs or expr._disable_remove:
  925. i += 1
  926. continue
  927. for n in expr.inputs:
  928. n.users.remove(expr)
  929. self._exprs.remove(expr)
  930. for n in expr.outputs:
  931. self._namespace.unassociate_name_with_obj(n)
  932. def _reset_ids(self):
  933. for total_expr_id, expr in enumerate(self.exprs()):
  934. expr._id = total_expr_id
  935. for total_node_id, node in enumerate(self.nodes()):
  936. node._id = total_node_id
  937. self._total_ids = (total_node_id + 1, total_expr_id + 1)
  938. def _re_associate_name(self):
  939. self._namespace.used_names.clear()
  940. for node in self.nodes(False):
  941. node._name = self._namespace.create_unique_name(node.name, node)
  942. def interpret(self, *inputs):
  943. node2value = {}
  944. end_nodes_set = set(self._end_point)
  945. endnode2value = {}
  946. def get_all_endnode_val(n, v):
  947. if n in end_nodes_set:
  948. endnode2value[n] = v
  949. end_nodes_set.remove(n)
  950. return not end_nodes_set
  951. return False
  952. ref_count = lambda n: len(n.users) + (1 if n in self._outputs else 0)
  953. for n, v in zip(self._inputs, inputs):
  954. if ref_count(n) > 0:
  955. node2value[n] = [v, ref_count(n)]
  956. if n in self._watch_point:
  957. self._rst[n].append(v)
  958. if n in self._end_point and get_all_endnode_val(n, v):
  959. return list(endnode2value[i] for i in self._end_point)
  960. for expr in self._exprs:
  961. values = expr.interpret(*list(node2value[i][0] for i in expr.inputs))
  962. for n in expr.inputs:
  963. node2value[n][1] -= 1
  964. if node2value[n][1] == 0:
  965. node2value.pop(n)
  966. if values is not None:
  967. assert len(values) == len(expr.outputs)
  968. for n, v in zip(expr.outputs, values):
  969. if ref_count(n) > 0:
  970. node2value[n] = [v, ref_count(n)]
  971. if n in self._watch_point:
  972. self._rst[n] = v
  973. if self._end_point and get_all_endnode_val(n, v):
  974. return list(endnode2value[i] for i in self._end_point)
  975. return list(node2value[i][0] for i in self._outputs)
  976. def eval(self, *inputs: Tuple[Tensor]):
  977. r"""Call this method to execute the graph.
  978. Args:
  979. inputs: the tensors corresponding to the ``graph.inputs[1:]``.
  980. """
  981. assert len(inputs) == len(self._inputs) - 1
  982. inp = [self._inputs[0].owner] + list(inputs)
  983. return self.interpret(*inp)
  984. def __repr__(self):
  985. return self.__format__()
  986. def __format__(self, format_spec: str = "") -> str:
  987. saved_format_spec = Node._set_format_spec(format_spec)
  988. name = ""
  989. if self._name:
  990. name = "%s.Graph" % self._name
  991. res = "{} ({}) {{\n\t{}\n\treturn {}\n}}".format(
  992. name,
  993. ", ".join(str(i) for i in self._inputs),
  994. "\n\t".join("{}".format(str(i)) for i in self._exprs),
  995. ", ".join(str(i) for i in self._outputs),
  996. )
  997. Node._set_format_spec(saved_format_spec)
  998. return res
  999. def __getstate__(self):
  1000. state = {
  1001. "_exprs": self._exprs,
  1002. "_inputs": self._inputs,
  1003. "_outputs": self._outputs,
  1004. "_watch_point": [],
  1005. "_end_point": [],
  1006. "_namespace": self._namespace,
  1007. "_rst": collections.defaultdict(list),
  1008. "_name": self._name,
  1009. "_qualname": self._qualname,
  1010. }
  1011. if self._total_ids:
  1012. state["_total_ids"] = self._total_ids
  1013. _check_obj_attr(state)
  1014. return state
  1015. def __setstate__(self, state):
  1016. old_version = False
  1017. if "_module_name" in state:
  1018. old_version = True
  1019. state["_qualname"] = state.pop("_module_name")
  1020. prefix_name = state.pop("_prefix_name")
  1021. if prefix_name:
  1022. state["_name"] = "{}_{}".format(prefix_name, state["_name"])
  1023. self.__dict__.update(state)
  1024. if old_version:
  1025. self.inputs[0]._qualname = self._qualname
  1026. for e in self.exprs(False):
  1027. if isinstance(e, GetAttr):
  1028. e.outputs[0]._qualname = "{}.{}".format(
  1029. e.inputs[0]._qualname, e.name
  1030. )
  1031. for n in self.nodes(False):
  1032. if isinstance(n.expr, CallMethod) and isinstance(
  1033. n.expr.inputs[0], ModuleNode
  1034. ):
  1035. n._qualname = n.expr.inputs[0]._qualname + ".[out]"
  1036. continue
  1037. if (
  1038. not isinstance(n.expr, GetAttr)
  1039. and isinstance(n, TensorNode)
  1040. and n._qualname
  1041. ):
  1042. n._qualname = "{}.{}".format(self._qualname, n._qualname)
  1043. self._namespace = NameSpace(self._name, self._qualname)
  1044. self._re_associate_name()
  1045. def __copy__(self):
  1046. cls = self.__class__
  1047. result = cls.__new__(cls)
  1048. result.__dict__.update(self.__dict__)
  1049. return result
  1050. def __deepcopy__(self, memo):
  1051. with max_recursion_limit():
  1052. if id(self) in memo:
  1053. return memo[id(self)]
  1054. cls = self.__class__
  1055. result = cls.__new__(cls)
  1056. state = {}
  1057. memo[id(self)] = result
  1058. for k, v in self.__dict__.items():
  1059. if not isinstance(v, weakref.ReferenceType):
  1060. state[k] = copy.deepcopy(v, memo)
  1061. result.__dict__.update(state)
  1062. return result
  1063. def _get_meth_name(obj, func):
  1064. tp = obj if isinstance(obj, type) else type(obj)
  1065. for cls in tp.mro():
  1066. for k, v in cls.__dict__.items():
  1067. if v == func:
  1068. return k
  1069. return None
  1070. def _wrapped_function(orig_func):
  1071. @functools.wraps(orig_func)
  1072. def wrapped_fn(*args, **kwargs):
  1073. method_func = kwargs.pop("method_func", wrapped_fn)
  1074. if is_tracing_module():
  1075. unset_module_tracing()
  1076. inputs, tree_def = tree_flatten((args, kwargs))
  1077. for i in inputs:
  1078. if not NodeMixin.get(i, None):
  1079. if isinstance(i, (RawTensor, NodeMixin)):
  1080. NodeMixin.wrap_safe(i, Constant.make(i))
  1081. args, kwargs = _convert_kwargs_to_args(orig_func, args, kwargs)
  1082. meth_name = _get_meth_name(args[0], method_func)
  1083. arg_type = args[0] if isinstance(args[0], type) else type(args[0])
  1084. if meth_name and arg_type and issubclass(arg_type, RawTensor):
  1085. inputs, tree_def = tree_flatten((args, kwargs))
  1086. self = inputs[0]
  1087. if meth_name == "__new__":
  1088. if all([not isinstance(i, RawTensor) for i in inputs]):
  1089. # only trace Tensor.__new__() when there are tensors in args
  1090. set_module_tracing()
  1091. return orig_func(*args, **kwargs)
  1092. if isinstance(args[1], RawTensor):
  1093. node = NodeMixin.get(inputs[1])
  1094. is_scalar = isscalar(inputs[1])
  1095. inputs[1] = apply(
  1096. Copy(comp_node=inputs[1].device), Tensor(inputs[1])
  1097. )[0]
  1098. if is_scalar:
  1099. setscalar(inputs[1])
  1100. # copy inputs[1] to avoid tensor and Tensor(tensor) share same m_tensor,
  1101. # which will cause they have same _NodeMixin__node in tracing.
  1102. NodeMixin.wrap_safe(inputs[1], node)
  1103. args, kwargs = tree_def.unflatten(inputs)
  1104. call_node = CallMethod.make(self, meth_name)
  1105. else:
  1106. call_node = CallMethod.make(NodeMixin.get(self), meth_name)
  1107. call_node.add_inputs(inputs[1:])
  1108. else:
  1109. inputs, tree_def = tree_flatten((args, kwargs))
  1110. call_node = CallFunction.make(orig_func)
  1111. call_node.add_inputs(inputs)
  1112. call_node.arg_def = tree_def
  1113. rst = orig_func(*args, **kwargs)
  1114. if meth_name == "__setitem__":
  1115. rst = self
  1116. if rst is not None:
  1117. outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
  1118. call_node.out_def = out_def
  1119. else:
  1120. outputs = None
  1121. call_node.add_outputs(outputs)
  1122. if _get_expr_checker():
  1123. with _exclude_from_trace():
  1124. active_module_tracer().checker.check_expr_interpret(
  1125. call_node, outputs
  1126. )
  1127. set_module_tracing()
  1128. return rst
  1129. return orig_func(*args, **kwargs)
  1130. return wrapped_fn
  1131. class TracedModuleBuilder(NodeMixin):
  1132. _mod = None # type: Module
  1133. _body = None # type: InternalGraph
  1134. _is_builtin = None # type: bool
  1135. _argdef_graph_map = None # type: Dict[Treedef, "InternalGraph"]
  1136. _argdef_outdef_map = None # type: Dict[Treedef, Treedef]
  1137. nodes = None
  1138. __builder_attributes__ = [
  1139. "_mod",
  1140. "_body",
  1141. "_NodeMixin__node",
  1142. "_is_builtin",
  1143. "build",
  1144. "_record_wrapped_nodes",
  1145. "_argdef_graph_map",
  1146. "_argdef_outdef_map",
  1147. "_check_qat_module",
  1148. "nodes",
  1149. "__class__",
  1150. "__dict__",
  1151. "_is_top",
  1152. ]
  1153. def __init__(self, mod, is_top_module=False):
  1154. super(TracedModuleBuilder, self).__init__()
  1155. assert isinstance(mod, Module)
  1156. self._mod = mod
  1157. self._body = None
  1158. self._is_top = is_top_module
  1159. self._is_builtin = (
  1160. True
  1161. if isinstance(mod, (Observer, _FakeQuantize))
  1162. else module_tracer.is_builtin(mod)
  1163. )
  1164. if isinstance(self._mod, QATModule):
  1165. unset_module_tracing()
  1166. self._check_qat_module(self._mod)
  1167. set_module_tracing()
  1168. self._argdef_graph_map = {}
  1169. self._argdef_outdef_map = {}
  1170. self.nodes = set()
  1171. # The builder will be passed to self._mod.forward as 'self' argument. If the 'forward' uses super().xxx to call method of its base classes, the trace procedure will throw exceprion, because the builder doesn't inherit from self._mod.__bases__.
  1172. # modify self.__class__ and let the builder inherit from TracedModuleBuilder and mod.__class__.
  1173. self.__class__ = type(
  1174. "TracedModuleBuilder",
  1175. (TracedModuleBuilder, mod.__class__),
  1176. dict(TracedModuleBuilder.__dict__),
  1177. )
  1178. def _check_qat_module(self, qat_module):
  1179. def isbuiltin(m):
  1180. return m is None or module_tracer.is_builtin(m)
  1181. if qat_module.with_act:
  1182. act_observer = qat_module.act_observer
  1183. act_fakequant = qat_module.act_fake_quant
  1184. if not isbuiltin(act_observer) or not isbuiltin(act_fakequant):
  1185. qparams = (
  1186. act_observer.get_qparams()
  1187. if hasattr(act_observer, "get_qparams")
  1188. else act_fakequant.get_qparams()
  1189. )
  1190. dtype = (
  1191. act_observer.dtype
  1192. if hasattr(act_observer, "dtype")
  1193. else act_fakequant.dtype
  1194. )
  1195. qat_module.act_observer = None
  1196. qat_module.act_fake_quant = TM_FakeQuant(dtype)
  1197. qat_module.act_fake_quant.set_qparams(qparams)
  1198. if qat_module.with_weight:
  1199. weight_observer = qat_module.weight_observer
  1200. weight_fakequant = qat_module.weight_fake_quant
  1201. if not isbuiltin(weight_observer) or not isbuiltin(weight_fakequant):
  1202. qparams = (
  1203. weight_observer.get_qparams()
  1204. if hasattr(weight_observer, "get_qparams")
  1205. else weight_fakequant.get_qparams()
  1206. )
  1207. dtype = (
  1208. weight_observer.dtype
  1209. if hasattr(weight_observer, "dtype")
  1210. else weight_fakequant.dtype
  1211. )
  1212. qat_module.weight_observer = None
  1213. qat_module.weight_fake_quant = TM_FakeQuant(dtype)
  1214. qat_module.weight_fake_quant.set_qparams(qparams)
  1215. def build(self):
  1216. if self._is_builtin:
  1217. assert module_tracer.is_builtin(self._mod)
  1218. mod_type = type(self._mod)
  1219. for node in self.nodes:
  1220. node.module_type = mod_type
  1221. return self._mod
  1222. elif isinstance(self._mod, TracedModule) and _graph_surgery_mode():
  1223. return self._mod
  1224. else:
  1225. is_qat = isinstance(self._mod, QATModule) or (
  1226. isinstance(self._mod, TracedModule) and self._mod.is_qat
  1227. )
  1228. traced_module = TracedModule(
  1229. self._is_top, self._argdef_graph_map, self._argdef_outdef_map, is_qat
  1230. )
  1231. for _, g in self._argdef_graph_map.items():
  1232. g.compile()
  1233. if self._is_top:
  1234. g._total_ids = (Node._get_next_id(), Expr._get_next_id())
  1235. for k, v in self.__dict__.items():
  1236. if k not in TracedModuleBuilder.__builder_attributes__:
  1237. if isinstance(v, TracedModuleBuilder):
  1238. v = v.build()
  1239. setattr(traced_module, k, v)
  1240. elif isinstance(v, RawTensor):
  1241. setattr(traced_module, k, v)
  1242. if isinstance(self._mod, QATModule):
  1243. unset_module_tracing()
  1244. traced_module.with_act = self._mod.with_act
  1245. traced_module.with_weight = self._mod.with_weight
  1246. if not hasattr(traced_module, "act_fake_quant"):
  1247. traced_module.act_fake_quant = None
  1248. if not hasattr(traced_module, "act_observer"):
  1249. traced_module.act_observer = None
  1250. if not hasattr(traced_module, "weight_fake_quant"):
  1251. traced_module.weight_fake_quant = None
  1252. if not hasattr(traced_module, "weight_observer"):
  1253. traced_module.weight_observer = None
  1254. set_module_tracing()
  1255. if self._is_top:
  1256. traced_module._update_ref()
  1257. return traced_module
  1258. def _record_wrapped_nodes(self, node):
  1259. self.nodes.add(node)
  1260. def __call__(self, *args, **kwargs):
  1261. assert isinstance(self._mod, Module)
  1262. is_graph_surgery_mode = _graph_surgery_mode()
  1263. if isinstance(self._mod, TracedModule) and is_graph_surgery_mode:
  1264. _set_graph_surgery_mode(False)
  1265. # prepare args and kwargs for inner graph
  1266. if "method_func" in kwargs:
  1267. kwargs.pop("method_func")
  1268. args, kwargs = _convert_kwargs_to_args(self._mod.forward, args, kwargs, True)
  1269. def mark_constant(x):
  1270. node = NodeMixin.get(x, None)
  1271. if node is None: # capture as constant
  1272. NodeMixin.wrap(x, lambda: Constant.make(x))
  1273. inputs, tree_def = tree_flatten(((self, *args), kwargs))
  1274. for i in inputs:
  1275. mark_constant(i)
  1276. callnode = CallMethod.make(NodeMixin.get(self))
  1277. callnode.add_inputs(inputs[1:])
  1278. callnode.arg_def = tree_def
  1279. if self._is_builtin or tree_def in self._argdef_graph_map:
  1280. unset_module_tracing()
  1281. rst = self._mod(*args, **kwargs)
  1282. outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
  1283. if _get_expr_checker():
  1284. with _exclude_from_trace():
  1285. tmp = self.build()
  1286. active_module_tracer().checker.check_builtin_module(
  1287. tmp, callnode, outputs
  1288. )
  1289. set_module_tracing()
  1290. if self._is_builtin:
  1291. self._body = None
  1292. elif tree_def in self._argdef_graph_map:
  1293. self._body = self._argdef_graph_map[tree_def]
  1294. else:
  1295. orig_self = NodeMixin.get(self)
  1296. parent_graph = active_module_tracer().current_scope()
  1297. module_qualname = orig_self._qualname
  1298. self._body = InternalGraph(
  1299. name=parent_graph._namespace.create_unique_name(module_qualname),
  1300. qualname=module_qualname,
  1301. )
  1302. parent_graph._namespace.associate_name_with_obj(self._body.name, self._body)
  1303. active_module_tracer().push_scope(self._body)
  1304. # rebind self to new input node
  1305. NodeMixin.wrap_safe(
  1306. self,
  1307. Input.make(
  1308. name="self",
  1309. qualname=module_qualname,
  1310. type=NodeMixin.get_wrapped_type(self),
  1311. ),
  1312. )
  1313. origin_inp_node = [NodeMixin.get(i, None) for i in inputs[1:]]
  1314. # prepare args and kwargs for inner graph
  1315. index_args, index_kwargs = tree_def.unflatten(
  1316. [
  1317. ArgsIndex(0),
  1318. *list(ArgsIndex(i + 1) for i in range(len(origin_inp_node))),
  1319. ]
  1320. )
  1321. key2idx = getcallargs(type(self._mod).forward, *index_args, **index_kwargs)
  1322. idx2key = {}
  1323. for k, v in key2idx.items():
  1324. if isinstance(v, ArgsIndex):
  1325. idx2key[v.index] = k
  1326. else:
  1327. flatten_argidx, _ = tree_flatten(v)
  1328. for _i, v in enumerate(flatten_argidx):
  1329. if isinstance(v, ArgsIndex):
  1330. idx2key[v.index] = k + "_%d" % _i
  1331. def wrap(x, name):
  1332. if isinstance(x, (RawTensor, NodeMixin)):
  1333. NodeMixin.wrap(
  1334. x,
  1335. lambda: Input.make(
  1336. type=NodeMixin.get_wrapped_type(x),
  1337. name=name,
  1338. qualname="%s.[%s]" % (module_qualname, name),
  1339. ),
  1340. )
  1341. return x
  1342. args = [self]
  1343. orig_traced_inputs = (
  1344. None
  1345. if not isinstance(self._mod, TracedModule)
  1346. else self._mod.argdef_graph_map[tree_def].inputs
  1347. )
  1348. ind = 1
  1349. for v in inputs[1:]:
  1350. if isinstance(v, (RawTensor, NodeMixin)):
  1351. args_name = (
  1352. orig_traced_inputs[ind]._name
  1353. if orig_traced_inputs
  1354. else idx2key[ind]
  1355. )
  1356. ind += 1
  1357. args.append(wrap(v, args_name))
  1358. else:
  1359. args.append(v)
  1360. args, kwargs = tree_def.unflatten(args)
  1361. active_module_tracer().patcher.auto_patch(
  1362. getattr(getattr(self._mod, "forward", self._mod), "__globals__", {})
  1363. )
  1364. rst = type(self._mod).forward(*args, **kwargs)
  1365. if _graph_surgery_mode():
  1366. rst = _node_to_tensor(rst)[0][0]
  1367. outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
  1368. for i in (
  1369. outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,)
  1370. ):
  1371. mark_constant(i)
  1372. active_module_tracer().current_scope()._add_output(NodeMixin.get(i))
  1373. NodeMixin.wrap_safe(self, orig_self)
  1374. for arg, node in zip(inputs[1:], origin_inp_node):
  1375. if node:
  1376. NodeMixin.wrap_safe(arg, node)
  1377. active_module_tracer().pop_scope()
  1378. # rebind output to outer graph
  1379. callnode.out_def = out_def
  1380. callnode.add_outputs(outputs)
  1381. self._argdef_graph_map[callnode.arg_def] = self._body
  1382. self._argdef_outdef_map[callnode.arg_def] = out_def
  1383. _set_graph_surgery_mode(is_graph_surgery_mode)
  1384. return rst
  1385. def __setattr__(self, name, value):
  1386. object.__setattr__(self, name, value)
  1387. def __repr__(self):
  1388. return repr(self._mod)
  1389. def __getattr__(self, name):
  1390. if name not in self._mod.__dict__:
  1391. attr = getattr(type(self._mod), name).__get__(self, type(self))
  1392. else:
  1393. attr = getattr(self._mod, name)
  1394. if (
  1395. isinstance(attr, FunctionType)
  1396. and id(attr) in active_module_tracer().patcher.patched_fn_ids
  1397. ):
  1398. return active_module_tracer().patcher.wrap_fn(attr)
  1399. if isinstance(attr, (List, Dict)):
  1400. flag = _set_graph_surgery_mode(False)
  1401. unset_module_tracing()
  1402. has_module, m_container = replace_container_with_module_container(attr)
  1403. if m_container:
  1404. attr = m_container
  1405. if has_module and not m_container:
  1406. raise ValueError(
  1407. "Can not trace the module that uses the same container to store"
  1408. " Module and Non-Module objects."
  1409. )
  1410. set_module_tracing()
  1411. _set_graph_surgery_mode(flag)
  1412. if isinstance(attr, Module):
  1413. attr = TracedModuleBuilder(attr)
  1414. if isinstance(attr, (Module, RawTensor)):
  1415. setattr(self, name, attr)
  1416. NodeMixin.wrap(
  1417. attr,
  1418. lambda: GetAttr.make(
  1419. NodeMixin.get(self),
  1420. type=NodeMixin.get_wrapped_type(attr),
  1421. attr_name=name,
  1422. name="",
  1423. ),
  1424. )
  1425. return attr
  1426. def __getattribute__(self, name):
  1427. if name in TracedModuleBuilder.__builder_attributes__:
  1428. return object.__getattribute__(self, name)
  1429. else:
  1430. wrapped = object.__getattribute__(self, name)
  1431. class_members = dict(inspect.getmembers(self.__class__))
  1432. if name in self._mod.__dict__:
  1433. mod_attr = getattr(self._mod, name)
  1434. if name in class_members:
  1435. if (
  1436. not isinstance(wrapped, TracedModuleBuilder)
  1437. and wrapped is not mod_attr
  1438. ):
  1439. wrapped = self.__getattr__(name)
  1440. if isinstance(wrapped, TracedModuleBuilder):
  1441. if not isinstance(mod_attr, (List, Dict, QATModule)):
  1442. assert mod_attr is wrapped._mod
  1443. else:
  1444. assert (
  1445. mod_attr is wrapped
  1446. ), "TracedModule do not support modify attributes, please check your code."
  1447. if isinstance(wrapped, (NodeMixin, RawTensor)):
  1448. NodeMixin.wrap(
  1449. wrapped,
  1450. lambda: GetAttr.make(
  1451. NodeMixin.get(self),
  1452. type=NodeMixin.get_wrapped_type(wrapped),
  1453. attr_name=name,
  1454. name="",
  1455. ),
  1456. )
  1457. return wrapped
  1458. class _expr_iter:
  1459. def __init__(self, graph: InternalGraph, recursive: bool = True):
  1460. self.graph = graph
  1461. self.recursive = recursive
  1462. self._visited_graph = set()
  1463. def __iter__(self):
  1464. yield from self._gen_expr(self.graph)
  1465. def _gen_expr(self, graph: InternalGraph):
  1466. visit_inp = set()
  1467. for inp_node in graph.inputs:
  1468. if inp_node not in visit_inp:
  1469. yield inp_node.expr
  1470. visit_inp.add(inp_node)
  1471. for expr in graph._exprs:
  1472. yield expr
  1473. if (
  1474. self.recursive
  1475. and hasattr(expr, "graph")
  1476. and expr.graph is not None
  1477. and id(expr.graph) not in self._visited_graph
  1478. ):
  1479. self._visited_graph.add(id(expr.graph))
  1480. yield from self._gen_expr(expr.graph)
  1481. class _node_iter:
  1482. def __init__(self, graph: InternalGraph, recursive: bool = True) -> None:
  1483. nodes = []
  1484. node_ids = set()
  1485. for expr in graph.exprs(recursive):
  1486. for n in expr.outputs:
  1487. assert id(n) not in node_ids
  1488. nodes.append(n)
  1489. node_ids.add(id(n))
  1490. self.nodes = nodes
  1491. def __iter__(self):
  1492. for node in self.nodes:
  1493. yield node
  1494. class BaseFilter:
  1495. r"""``BaseFilter`` exposes some methods for converting ``_node_iter/_expr_iter`` to ``list``, ``dict``, etc."""
  1496. def __init__(self, iter: Iterable):
  1497. self._iter = iter
  1498. def __iter__(self):
  1499. return iter(self._iter)
  1500. def as_list(self):
  1501. r"""Consume this iterator and return its content as a list.
  1502. Returns:
  1503. A list of ``Node`` or ``Expr``.
  1504. """
  1505. return list(self)
  1506. def as_dict(self):
  1507. r"""Construct an ordered dict to map from ``id`` to objects in this iterator.
  1508. Returns:
  1509. An :class:`OrderedDict`.
  1510. """
  1511. return collections.OrderedDict((i._id, i) for i in self)
  1512. def as_unique(self):
  1513. """Assert that this iterator yields only one ``Node`` or ``Expr`` and return it.
  1514. Rerurns:
  1515. A ``Node`` or ``Expr``.
  1516. """
  1517. rst = self.as_list()
  1518. assert len(rst) == 1, "{} elements found".format(len(rst))
  1519. (elem,) = self
  1520. return elem
  1521. def as_count(self):
  1522. r"""Consume this iterator and get the number of elements."""
  1523. return sum(1 for _ in self)
  1524. class ExprFilter(BaseFilter):
  1525. """Filter on Expr iterator.
  1526. This class is an iterator of :class:`.Expr` objects and multiple
  1527. filtering conditions and mappers can be chained.
  1528. """
  1529. def call_function(self, func):
  1530. r"""Filter by specific ``CallFunction.func``.
  1531. See :meth:`~.InternalGraph.get_function_by_type` for details.
  1532. """
  1533. return ExprFilterCallFunction(self, func)
  1534. def call_method(self, method):
  1535. r"""Filter by specific ``CallMethod.method``.
  1536. See :meth:`~.InternalGraph.get_function_by_type` for details.
  1537. """
  1538. return ExprFilterCallMethod(self, method)
  1539. def expr_id(self, expr_id: List[int]):
  1540. r"""Filter Exprs by their ``id``.
  1541. See :meth:`~.InternalGraph.get_function_by_type` for details.
  1542. """
  1543. return ExprFilterExprId(self, expr_id)
  1544. class NodeFilter(BaseFilter):
  1545. """Filter on Node iterator.
  1546. This class is an iterator of :class:`~.traced_module.Node` objects and multiple
  1547. filtering conditions and mappers can be chained.
  1548. """
  1549. def type(self, owner_type):
  1550. r"""Filter by specific Module type.
  1551. See :meth:`~.InternalGraph.get_module_by_type` for details.
  1552. """
  1553. return NodeFilterType(self, owner_type)
  1554. def node_id(self, node_id: List[int]):
  1555. r"""Filter Nodes by their ``id``.
  1556. See :meth:`~.InternalGraph.get_node_by_id` for details.
  1557. """
  1558. return NodeFilterNodeId(self, node_id)
  1559. def name(self, name: str, ignorecase: bool = True):
  1560. r"""Filter Nodes by their full name.
  1561. See :meth:`~.InternalGraph.get_node_by_name` for details.
  1562. """
  1563. return NodeFilterName(self, name, ignorecase)
  1564. class NodeFilterType(NodeFilter):
  1565. """See :meth:`~.InternalGraph.get_module_by_type`"""
  1566. def __init__(self, expr_iter, owner_type):
  1567. super().__init__(expr_iter)
  1568. self.owner_type = owner_type
  1569. def __iter__(self):
  1570. for node in self._iter:
  1571. if not isinstance(node, ModuleNode):
  1572. continue
  1573. if not hasattr(node, "owner"):
  1574. continue
  1575. if isinstance(node.owner, self.owner_type):
  1576. yield node
  1577. class NodeFilterNodeId(NodeFilter):
  1578. """See :meth:`~.InternalGraph.get_node_by_id`"""
  1579. def __init__(self, expr_iter, node_id: List[int]):
  1580. super().__init__(expr_iter)
  1581. if not isinstance(node_id, Sequence):
  1582. node_id = [node_id]
  1583. self.node_id = node_id
  1584. def __iter__(self):
  1585. for node in self._iter:
  1586. if node._id in self.node_id:
  1587. yield node
  1588. class NodeFilterName(NodeFilter):
  1589. """See :meth:`~.InternalGraph.get_node_by_name`"""
  1590. _re = None
  1591. def __init__(self, node_iter, pattern, ignorecase):
  1592. super().__init__(node_iter)
  1593. self.pattern = pattern
  1594. self._re = self.make_re(pattern, ignorecase)
  1595. @classmethod
  1596. def make_re(cls, pattern, ignorecase=True):
  1597. assert isinstance(pattern, str), "bad pattern: {!r}".format(pattern)
  1598. assert isinstance(ignorecase, bool)
  1599. flags = 0
  1600. if ignorecase:
  1601. flags |= re.IGNORECASE
  1602. return re.compile(fnmatch.translate(pattern), flags=flags)
  1603. def __iter__(self):
  1604. for i in self._iter:
  1605. graph = i.top_graph
  1606. name = "{}_{}".format(graph._name, i._name)
  1607. if self.pattern == name or self._re.match(name):
  1608. yield i
  1609. class ExprFilterCallFunction(ExprFilter):
  1610. """See :meth:`~.InternalGraph.get_function_by_type`"""
  1611. def __init__(self, expr_iter, func: Callable = None):
  1612. super().__init__(expr_iter)
  1613. self.func = func
  1614. def __iter__(self):
  1615. for expr in self._iter:
  1616. if not isinstance(expr, CallFunction):
  1617. continue
  1618. if self.func is None or expr.func == self.func:
  1619. yield expr
  1620. class ExprFilterCallMethod(ExprFilter):
  1621. """See :meth:`~.InternalGraph.get_method_by_type`"""
  1622. def __init__(self, expr_iter, method: str = None):
  1623. super().__init__(expr_iter)
  1624. self.method = method
  1625. def __iter__(self):
  1626. for expr in self._iter:
  1627. if not isinstance(expr, CallMethod):
  1628. continue
  1629. if self.method is None or expr.method == self.method:
  1630. yield expr
  1631. class ExprFilterExprId(ExprFilter):
  1632. """See :meth:`~.InternalGraph.get_expr_by_id`"""
  1633. def __init__(self, expr_iter, expr_id: List[int]):
  1634. super().__init__(expr_iter)
  1635. if not isinstance(expr_id, Sequence):
  1636. expr_id = [expr_id]
  1637. self.expr_id = expr_id
  1638. def __iter__(self):
  1639. for expr in self._iter:
  1640. if expr._id in self.expr_id:
  1641. yield expr
  1642. class TracedModule(Module):
  1643. r"""``TracedModule`` is the Module created by tracing normal module.
  1644. It owns an argdef to graph(InternalGraph) map. The forward method of ``TracedModule``
  1645. will get a graph from ``argdef_graph_map`` according to the argdef of input ``args/kwargs``
  1646. and interpret it.
  1647. .. note::
  1648. ``TracedModule`` can only be created by :func:`~.trace_module`. See :func:`~.trace_module`
  1649. for more details.
  1650. """
  1651. # m_node = None # type: ModuleNode
  1652. argdef_graph_map = None
  1653. argdef_outdef_map = None
  1654. def __init__(self, is_top, argdef_graph_map, argdef_outdef_map, is_qat=False):
  1655. super(TracedModule, self).__init__()
  1656. self.argdef_graph_map = argdef_graph_map
  1657. self.argdef_outdef_map = argdef_outdef_map
  1658. self._is_top = is_top
  1659. self.watch_points = []
  1660. self.watch_node_value = {}
  1661. self.end_points = []
  1662. self.is_qat = is_qat
  1663. self.argspec = None
  1664. def forward(self, *args, **kwargs):
  1665. if hasattr(self, "argspec") and self.argspec is not None:
  1666. args, kwargs = _convert_kwargs_to_args(self.argspec, args, kwargs, True)
  1667. inputs, treedef = tree_flatten(((self, *args), kwargs))
  1668. assert (
  1669. treedef in self.argdef_graph_map
  1670. ), "support input args kwargs format: \n{}, but get: \n{}".format(
  1671. "\n ".join(
  1672. "forward({})".format(i._args_kwargs_repr())
  1673. for i in self.argdef_graph_map.keys()
  1674. ),
  1675. treedef._args_kwargs_repr(),
  1676. )
  1677. inputs = filter(
  1678. lambda i: isinstance(i, (Module, TracedModuleBuilder, RawTensor)), inputs
  1679. ) # allow TracedModuleBuilder for retrace.
  1680. outputs = self.argdef_graph_map[treedef].interpret(*inputs)
  1681. if self.watch_points:
  1682. self.watch_node_value = {}
  1683. for n in self.watch_points:
  1684. self.watch_node_value[n] = n.top_graph._rst.pop(n)
  1685. if self.end_points:
  1686. return outputs
  1687. out_def = self.argdef_outdef_map[treedef]
  1688. outputs = out_def.unflatten(outputs)
  1689. return outputs
  1690. def set_watch_points(self, nodes):
  1691. r"""Initialize the :attr:`~.TracedModule.watch_points`.
  1692. You can call this function to get the ``Tensor/Module`` corresponding to a ``Node`` at runtime.
  1693. Args:
  1694. nodes: a list of ``Node``.
  1695. For example, the following code
  1696. .. code-block::
  1697. import megengine.module as M
  1698. import megengine as mge
  1699. import megengine.traced_module as tm
  1700. class MyModule(M.Module):
  1701. def forward(self, x):
  1702. x = x + 1 + 2
  1703. return x
  1704. net = MyModule()
  1705. inp = mge.Tensor([0])
  1706. traced_module = tm.trace_module(net, inp)
  1707. add_1_node = traced_module.graph.get_node_by_id(2).as_unique()
  1708. traced_module.set_watch_points(add_1_node)
  1709. out = traced_module(inp)
  1710. Will get the following ``watch_node_value``::
  1711. print(traced_module.watch_node_value)
  1712. .. code-block:: text
  1713. {add_out: Tensor([1.], device=xpux:0)}
  1714. """
  1715. if not isinstance(nodes, Sequence):
  1716. nodes = [nodes]
  1717. self.watch_points = nodes
  1718. if nodes:
  1719. nodes[0].top_graph._watch_point = []
  1720. for n in nodes:
  1721. n.top_graph._watch_point.append(n)
  1722. def clear_watch_points(self):
  1723. r"""Clear the :attr:`~.TracedModule.watch_points` and :attr:`~.TracedModule.watch_node_value`.
  1724. """
  1725. for n in self.watch_points:
  1726. n.top_graph._watch_point = []
  1727. self.watch_points = []
  1728. self.watch_node_value = {}
  1729. def set_end_points(self, nodes: Sequence[Node]):
  1730. r"""Initialize the :attr:`~.TracedModule.end_points`.
  1731. When all the ``nodes`` are generated, the Module will stop execution and return directly.
  1732. Args:
  1733. nodes: a list of ``Node``.
  1734. For example, the following code
  1735. .. code-block::
  1736. import megengine.module as M
  1737. import megengine as mge
  1738. import megengine.traced_module as tm
  1739. class MyModule(M.Module):
  1740. def forward(self, x):
  1741. x = x + 1 + 2
  1742. return x
  1743. net = MyModule()
  1744. inp = mge.Tensor([0])
  1745. traced_module = tm.trace_module(net, inp)
  1746. add_1_node = traced_module.graph.get_node_by_id(2).as_unique()
  1747. traced_module.set_end_points(add_1_node)
  1748. out = traced_module(inp)
  1749. Will get the following ``out``::
  1750. print(out)
  1751. .. code-block:: text
  1752. [Tensor([1.], device=xpux:0)]
  1753. """
  1754. if not isinstance(nodes, Sequence):
  1755. nodes = [nodes]
  1756. self.end_points = nodes
  1757. graphs = list(self.argdef_graph_map.values())
  1758. for n in nodes:
  1759. assert n.top_graph in graphs
  1760. n.top_graph._end_point.append(n)
  1761. def clear_end_points(self):
  1762. r"""Clear the :attr:`~.TracedModule.end_points`.
  1763. """
  1764. for n in self.end_points:
  1765. n.top_graph._end_point = []
  1766. self.end_points = []
  1767. @property
  1768. def graph(self) -> InternalGraph:
  1769. """Return the ``InternalGraph`` of this ``TracedModule``.
  1770. """
  1771. assert len(self.argdef_graph_map) == 1
  1772. return list(self.argdef_graph_map.values())[0]
  1773. def _update_ref(self, actual_node_map: Union[Dict] = None, top_graph=None):
  1774. for inp_def, graph in self.argdef_graph_map.items():
  1775. if top_graph is not None:
  1776. graph._top_graph = weakref.ref(top_graph)
  1777. for n in graph._inputs + graph._outputs:
  1778. n.expr._top_graph = weakref.ref(graph)
  1779. n._top_graph = weakref.ref(graph)
  1780. graph._inputs[0]._owner = weakref.ref(self)
  1781. for i, n in enumerate(graph._inputs):
  1782. n.actual_node = []
  1783. if actual_node_map is not None and inp_def in actual_node_map.keys():
  1784. n.actual_node = list(list(zip(*(actual_node_map[inp_def])))[i])
  1785. node2obj = {}
  1786. next_actual_node_map = collections.defaultdict(
  1787. lambda: collections.defaultdict(list)
  1788. )
  1789. node2obj[graph._inputs[0]] = self
  1790. for expr in graph._exprs:
  1791. for n in expr.inputs + expr.outputs:
  1792. n._top_graph = weakref.ref(graph)
  1793. expr._top_graph = weakref.ref(graph)
  1794. if isinstance(expr, GetAttr) and isinstance(
  1795. expr.outputs[0], ModuleNode
  1796. ):
  1797. obj = expr.interpret(node2obj[expr.inputs[0]])[0]
  1798. expr.outputs[0]._owner = weakref.ref(obj)
  1799. node2obj[expr.outputs[0]] = obj
  1800. if isinstance(expr, Constant) and isinstance(
  1801. expr.outputs[0], ModuleNode
  1802. ):
  1803. obj = expr.value
  1804. expr.outputs[0]._owner = weakref.ref(obj)
  1805. node2obj[expr.outputs[0]] = obj
  1806. if (
  1807. isinstance(expr, CallMethod)
  1808. and expr.method == "__call__"
  1809. and isinstance(expr.inputs[0], ModuleNode)
  1810. ):
  1811. obj = node2obj[expr.inputs[0]]
  1812. if expr.arg_def is not None:
  1813. next_actual_node_map[obj][expr.arg_def].append(expr.inputs)
  1814. for obj in node2obj.values():
  1815. if obj is self:
  1816. continue
  1817. mnode_map = None
  1818. if obj in next_actual_node_map.keys():
  1819. mnode_map = next_actual_node_map[obj]
  1820. if isinstance(obj, TracedModule):
  1821. obj._update_ref(mnode_map, graph)
  1822. def flatten(self):
  1823. r"""Get a new TracedModule, which eliminates ``GetAttr`` and has no hierarchy.
  1824. Retruns:
  1825. A new :class:`TracedModule`.
  1826. """
  1827. new_module = copy.deepcopy(self)
  1828. def _replace_inputs_and_outputs(expr: Expr, repl_dict: Dict[Node, Node]):
  1829. inputs, outputs = expr.inputs, expr.outputs
  1830. for i, node in enumerate(inputs):
  1831. if node in repl_dict:
  1832. inputs[i] = repl_dict[node]
  1833. for i, node in enumerate(outputs):
  1834. if node in repl_dict:
  1835. outputs[i] = repl_dict[node]
  1836. outputs[i].expr = expr
  1837. def _flatten_subgraph(
  1838. parent_graph: InternalGraph,
  1839. graph: InternalGraph,
  1840. call: CallMethod,
  1841. module: Module,
  1842. ):
  1843. repl_dict, node2obj, rename_blacklist = {}, {}, []
  1844. if call is not None:
  1845. graph = copy.deepcopy(graph)
  1846. node2obj[call.inputs[0]] = module
  1847. repl_dict = dict(zip(graph._inputs, call.inputs))
  1848. for ind, out in enumerate(graph.outputs):
  1849. if isinstance(out.expr, Input):
  1850. assert out in repl_dict
  1851. call_out = call.outputs[ind]
  1852. for expr in call.outputs[ind].users:
  1853. for index, inp in enumerate(expr.inputs):
  1854. if inp is call_out:
  1855. expr.inputs[index] = repl_dict[out]
  1856. repl_dict[out].users.append(expr)
  1857. if parent_graph is not None:
  1858. for index, parent_out in enumerate(parent_graph._outputs):
  1859. if parent_out is call_out:
  1860. parent_graph._outputs[index] = repl_dict[out]
  1861. continue
  1862. repl_dict[out] = call.outputs[ind]
  1863. if isinstance(out, TensorNode):
  1864. call.outputs[ind]._qualname = out._qualname
  1865. for node, repl_node in repl_dict.items():
  1866. assert node in graph._inputs or node in graph._outputs
  1867. repl_node.users.extend(node.users)
  1868. rename_blacklist = list(chain(call.inputs, call.outputs))
  1869. node2obj[graph._inputs[0]] = module
  1870. prefix_name = call.inputs[0]._name if call else ""
  1871. flattened_exprs = []
  1872. for expr in graph._exprs:
  1873. exprs = [expr]
  1874. if call is not None:
  1875. _replace_inputs_and_outputs(expr, repl_dict)
  1876. if isinstance(expr, GetAttr):
  1877. mnode = expr.inputs[0]
  1878. node2obj[expr.outputs[0]] = expr.interpret(node2obj[mnode])[0]
  1879. if isinstance(expr, CallMethod):
  1880. obj_node = expr.inputs[0]
  1881. if isinstance(obj_node, ModuleNode) and isinstance(
  1882. obj_node.expr, GetAttr
  1883. ):
  1884. obj = node2obj[obj_node]
  1885. expr_graph = (
  1886. obj.argdef_graph_map[expr.arg_def]
  1887. if hasattr(obj, "argdef_graph_map")
  1888. else None
  1889. )
  1890. if expr_graph is not None and not obj.is_qat:
  1891. exprs = _flatten_subgraph(graph, expr_graph, expr, obj)
  1892. if parent_graph is not None:
  1893. for node in expr.outputs:
  1894. name = node._name
  1895. if node not in rename_blacklist:
  1896. name = "{}_{}".format(prefix_name, name)
  1897. node._name = parent_graph._namespace.create_unique_name(
  1898. name, node
  1899. )
  1900. flattened_exprs.extend(exprs)
  1901. if call is not None:
  1902. for i in call.inputs:
  1903. i.users.remove(call)
  1904. return flattened_exprs
  1905. new_module.graph._exprs = _flatten_subgraph(
  1906. None, new_module.graph, None, new_module
  1907. )
  1908. new_module.graph._re_associate_name()
  1909. new_module.graph.compile()
  1910. new_module._update_ref()
  1911. new_module.graph._reset_ids()
  1912. return new_module
  1913. def __getstate__(self):
  1914. d = self.__dict__.copy()
  1915. for k in Module.__dict__:
  1916. d.pop(k, None)
  1917. _check_obj_attr(d)
  1918. for k in d:
  1919. if module_tracer.is_builtin(d[k]):
  1920. assert _check_builtin_module_attr(
  1921. d[k]
  1922. ), "Module {} can not be serialized. ".format(type(d[k]))
  1923. d[k] = _ModuleState.get_module_state(d[k])
  1924. dump_info = {
  1925. "version": __version__,
  1926. "register_type": USER_REGISTERED_LEAF_TYPE,
  1927. "register_container_type": USER_REGISTERED_CONTAINER_TYPE,
  1928. "register_mdule": USER_REGISTERED_MODULE,
  1929. "register_function": USER_REGISTERED_FUNCTION,
  1930. }
  1931. d["dump_info"] = dump_info
  1932. return d
  1933. def __setstate__(self, state):
  1934. for k, v in state.items():
  1935. if isinstance(v, _ModuleState):
  1936. state[k] = v.to_module()
  1937. self.__dict__.update(state)
  1938. self._update_ref()
  1939. for _, graph in self.argdef_graph_map.items():
  1940. for expr in graph._exprs:
  1941. if isinstance(expr, CallFunction):
  1942. load_functional(expr)
  1943. if isinstance(expr, CallMethod):
  1944. if expr.method == "__call__":
  1945. load_call_module_expr(expr)
  1946. else:
  1947. load_call_tensor_method_expr(expr)
  1948. if isinstance(expr, Apply):
  1949. load_apply_expr(expr)
  1950. for _, graph in self.argdef_graph_map.items():
  1951. ind = 0
  1952. while ind < len(graph._exprs):
  1953. cur_expr = graph._exprs[ind]
  1954. has_new_expr = False
  1955. for i in cur_expr.inputs:
  1956. if i.expr not in graph._exprs and not isinstance(i.expr, Input):
  1957. graph._exprs.insert(ind, i.expr)
  1958. has_new_expr = True
  1959. if not has_new_expr:
  1960. ind += 1
  1961. for expr in graph._exprs:
  1962. for i in expr.inputs:
  1963. if expr.inputs.count(i) != i.users.count(expr):
  1964. add_or_del_count = expr.inputs.count(i) - i.users.count(expr)
  1965. if add_or_del_count > 0:
  1966. i.users.extend([expr] * add_or_del_count)
  1967. else:
  1968. [i.users.remove(expr) for i in range(-add_or_del_count)]
  1969. for o in expr.outputs:
  1970. if o.expr is not expr:
  1971. assert o not in o.expr.outputs
  1972. o.expr = expr
  1973. for node in graph.nodes(False):
  1974. # remove users of node which doesn't use node as input
  1975. node.users = [e for e in node.users if node in e.inputs]
  1976. for expr in graph._exprs:
  1977. graph._namespace.auto_naming_for_outputs(expr)
  1978. self._update_ref()
  1979. for _, graph in self.argdef_graph_map.items():
  1980. graph._reset_ids()
  1981. def __copy__(self):
  1982. cls = self.__class__
  1983. result = cls.__new__(cls)
  1984. result.__dict__.update(self.__dict__)
  1985. return result
  1986. def __deepcopy__(self, memo):
  1987. with max_recursion_limit():
  1988. cls = self.__class__
  1989. result = cls.__new__(cls)
  1990. state = {}
  1991. memo[id(self)] = result
  1992. for k, v in self.__dict__.items():
  1993. if not isinstance(v, weakref.ReferenceType):
  1994. state[k] = copy.deepcopy(v, memo)
  1995. result.__dict__.update(state)
  1996. result._update_ref()
  1997. return result
  1998. def cpp_apply_module_trace(opdef, *args):
  1999. return Apply.apply_module_trace_hook(opdef, *args)
  2000. USER_REGISTERED_MODULE = []
  2001. USER_REGISTERED_FUNCTION = []
  2002. def register_as_builtin(mod_cls: Type[Module]) -> None:
  2003. r"""Registers class ``mod_cls`` (subclass of :class:`~.Module`) as builtin module.
  2004. Args:
  2005. mod_cls: the module class which will be treated as builtin module in tracing.
  2006. """
  2007. USER_REGISTERED_MODULE.append((mod_cls.__module__, mod_cls.__qualname__))
  2008. module_tracer.register_as_builtin(mod_cls)
  2009. def wrap(func: Callable):
  2010. r"""Call this function to register ``func`` as a builtin function.
  2011. This function can be called at module-level scope to register ``func`` as a builtin function.
  2012. A builtin function will be converted to a :class:`CallFunction` Expr in tracing::
  2013. def my_func(x, y):
  2014. return x + y
  2015. import megengine.traced_module as tm
  2016. tm.wrap(my_func)
  2017. This function can also equivalently be used as a decorator::
  2018. @tm.wrap
  2019. def my_func(x, y):
  2020. return x + y
  2021. Args:
  2022. func: the function of the global function to insert into the graph when it's called.
  2023. """
  2024. USER_REGISTERED_FUNCTION.append((func.__module__, func.__qualname__))
  2025. assert callable(func), "func must be a callable"
  2026. assert hasattr(func, "__code__")
  2027. fn_name = func.__code__.co_name
  2028. currentframe = inspect.currentframe()
  2029. assert currentframe is not None
  2030. f = currentframe.f_back
  2031. assert f is not None
  2032. assert (
  2033. f.f_code.co_name == "<module>"
  2034. ), "wrap must be called at the top level of a module"
  2035. Patcher._builtin_functions.append((f.f_globals, fn_name))
  2036. return func
  2037. def _register_all_builtin_module():
  2038. for sub_mod in [M, M.qat, M.quantized, MExternal]:
  2039. for m in getmembers(sub_mod):
  2040. if (
  2041. isclass(m[1])
  2042. and issubclass(m[1], M.Module)
  2043. and m[1] is not M.Sequential
  2044. ):
  2045. module_tracer.register_as_builtin(m[1])
  2046. module_tracer.register_as_builtin(Observer)
  2047. module_tracer.register_as_builtin(MinMaxObserver)
  2048. module_tracer.register_as_builtin(SyncMinMaxObserver)
  2049. module_tracer.register_as_builtin(ExponentialMovingAverageObserver)
  2050. module_tracer.register_as_builtin(SyncExponentialMovingAverageObserver)
  2051. module_tracer.register_as_builtin(HistogramObserver)
  2052. module_tracer.register_as_builtin(PassiveObserver)
  2053. module_tracer.register_as_builtin(LSQ)
  2054. module_tracer.register_as_builtin(TQT)
  2055. module_tracer.register_as_builtin(FakeQuantize)
  2056. module_tracer.register_as_builtin(TM_FakeQuant)
  2057. def trace_module(
  2058. mod: Module, *args: Tuple[Any], **kwargs: Dict[str, Any]
  2059. ) -> TracedModule:
  2060. r"""Traces module ``mod`` and returns corresponding :class:`TracedModule`.
  2061. Args:
  2062. mod: the module will be converted to :class:`TracedModule`.
  2063. args: the positional arguments passed to forward method of ``mod``.
  2064. kwargs: the keyword arguments passed to forward method of ``mod``.
  2065. """
  2066. assert active_module_tracer() is None
  2067. assert isinstance(mod, Module)
  2068. try:
  2069. net_name = mod._name if mod._name else mod.__class__.__name__
  2070. use_sym_shape = set_symbolic_shape(True)
  2071. set_module_tracing()
  2072. set_active_module_tracer(module_tracer(_wrapped_function))
  2073. for cls in [Expr, Node]:
  2074. cls._set_next_id(0)
  2075. with active_module_tracer().patcher:
  2076. global_scope = InternalGraph(name="top", qualname=net_name)
  2077. active_module_tracer().push_scope(global_scope)
  2078. builder = TracedModuleBuilder(mod, True)
  2079. NodeMixin.wrap_safe(
  2080. builder, Input.make(name="top", type=ModuleNode, qualname=net_name)
  2081. )
  2082. forward_argspec = (
  2083. mod.argspec
  2084. if hasattr(mod, "argspec")
  2085. else inspect.getfullargspec(mod.forward)
  2086. )
  2087. args, kwargs = _convert_kwargs_to_args(forward_argspec, args, kwargs, True)
  2088. inputs, _ = tree_flatten((args, kwargs))
  2089. for _, i in enumerate(inputs):
  2090. # assert isinstance(i, Tensor), "not support "
  2091. if isinstance(i, RawTensor):
  2092. NodeMixin.wrap_safe(
  2093. i,
  2094. Input.make(
  2095. name="arg_{}".format(_),
  2096. type=NodeMixin.get_wrapped_type(i),
  2097. qualname="{}.[{}]".format(net_name, "arg_{}".format(_)),
  2098. ),
  2099. )
  2100. rst = builder(*copy.deepcopy(args), **copy.deepcopy(kwargs))
  2101. active_module_tracer().pop_scope()
  2102. traced_mod = builder.build()
  2103. traced_mod.argspec = forward_argspec
  2104. traced_mod.graph._reset_ids()
  2105. has_expr_not_check = False
  2106. if _get_expr_checker():
  2107. has_expr_not_check = (
  2108. active_module_tracer().checker.check_node_not_in_scope()
  2109. )
  2110. if _get_default_checker() or has_expr_not_check:
  2111. with _exclude_from_trace():
  2112. tm_res = traced_mod(*args, **kwargs)
  2113. tm_res, _ = tree_flatten(tm_res, is_leaf=_is_leaf)
  2114. rst, _ = tree_flatten(rst, is_leaf=_is_leaf)
  2115. active_module_tracer().checker.check_net_outputs(tm_res, rst)
  2116. return traced_mod
  2117. finally:
  2118. set_symbolic_shape(use_sym_shape)
  2119. set_active_module_tracer(None)
  2120. unset_module_tracing()
  2121. for t in mod.tensors(recursive=True):
  2122. NodeMixin.clear_node(t)
  2123. for t in inputs:
  2124. NodeMixin.clear_node(t)