You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

expr.py 27 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import builtins
  9. import collections
  10. import copy
  11. import inspect
  12. import re
  13. import weakref
  14. from importlib import import_module
  15. from typing import Callable, Dict, Iterable, List, Optional, Sequence, Union
  16. from ..core._imperative_rt import OpDef
  17. from ..core._imperative_rt.core2 import Tensor as RawTensor
  18. from ..core._imperative_rt.core2 import (
  19. apply,
  20. is_tracing_module,
  21. set_module_trace_hook,
  22. set_module_tracing,
  23. unset_module_tracing,
  24. )
  25. from ..core.ops.builtin import FakeQuant
  26. from ..core.ops.special import Const
  27. from ..module import Module
  28. from ..tensor import Parameter, Tensor
  29. from ..version import __version__
  30. from .module_tracer import active_module_tracer, module_tracer
  31. from .node import ModuleNode, Node, NodeMixin, TensorNode
  32. from .pytree import ArgsIndex, TreeDef, _is_const_leaf, _is_leaf, tree_flatten
  33. from .serialization import _ModuleState
  34. from .tm_config import _exclude_from_trace, _get_expr_checker
  35. from .utils import _check_builtin_module_attr, _check_obj_attr, _convert_kwargs_to_args
  def rstrip(s: str, __chars: str):
      r"""Remove every trailing repetition of the substring ``__chars`` from ``s``.

      Unlike :meth:`str.rstrip`, ``__chars`` is treated as one whole suffix rather
      than as a set of characters, e.g. ``rstrip("xabab", "ab") == "x"``.

      Args:
          s: the string to strip.
          __chars: the suffix to remove; an empty suffix leaves ``s`` unchanged.

      Returns:
          ``s`` without any trailing copies of ``__chars``.
      """
      # Plain suffix matching instead of a regex. The old pattern used a non-raw
      # "\g<left>" replacement (invalid escape sequence). Its ``.``/``$`` also
      # mis-handled strings containing newlines.
      if not __chars:
          return s
      width = len(__chars)
      while s.endswith(__chars):
          s = s[:-width]
      return s
  def get_suffix_name(prefix: str, name: str):
      r"""Return the part of ``name`` that follows ``prefix`` and a dot.

      Args:
          prefix: a dotted qualified-name prefix, e.g. ``"net.block"``.
          name: the qualified name to inspect, e.g. ``"net.block.conv.weight"``.

      Returns:
          ``""`` if ``name`` equals ``prefix``.
          The remainder after ``prefix + "."`` (``"conv.weight"`` in the example)
          if ``name`` starts with it.
          ``None`` otherwise.
      """
      if prefix == name:
          return ""
      # Literal prefix test. The old unescaped regex let ``.``, ``[`` or ``(``
      # inside ``prefix`` act as regex metacharacters.
      head = prefix + "."
      if not name.startswith(head):
          return None
      return name[len(head):]
  def is_call_module(expr, module_cls: Module = None):
      r"""Return whether ``expr`` is a ``CallMethod`` invoking ``__call__`` on a module Node.

      Args:
          expr: the Expr to test.
          module_cls: optional class; when given, the ``owner`` of the called
              ``ModuleNode`` must also be an instance of it.
      """
      if not isinstance(expr, CallMethod):
          return False
      callee = expr.inputs[0]
      if not isinstance(callee, ModuleNode) or not expr.method == "__call__":
          return False
      return module_cls is None or isinstance(callee.owner, module_cls)
  def is_call_tensor_method(expr, method: Iterable[str] = None):
      r"""Return whether ``expr`` is a ``CallMethod`` that is not a module call.

      Module calls are those matched by ``is_call_module``.

      Args:
          expr: the Expr to test.
          method: a method name or an iterable of names; when given,
              ``expr.method`` must equal one of them.
      """
      if method and isinstance(method, str):
          method = (method,)
      if not isinstance(expr, CallMethod) or is_call_module(expr):
          return False
      if method is None:
          return True
      return any(expr.method == candidate for candidate in method)
  def is_call_function(expr, func: Iterable[Callable] = None):
      r"""Return whether ``expr`` is a ``CallFunction``, optionally of specific functions.

      Args:
          expr: the Expr to test.
          func: a callable or an iterable of callables; when given,
              ``expr.func`` must compare equal to one of them.
      """
      if func and not isinstance(func, Iterable):
          func = (func,)
      if not isinstance(expr, CallFunction):
          return False
      if func is None:
          return True
      return any(expr.func == candidate for candidate in func)
  def is_constant(expr):
      r"""Return ``True`` when ``expr`` is a :class:`Constant` expr."""
      matched = isinstance(expr, Constant)
      return matched
  def is_getattr(expr):
      r"""Return ``True`` when ``expr`` is a :class:`GetAttr` expr."""
      matched = isinstance(expr, GetAttr)
      return matched
  def is_apply_def(expr, opdef=None):
      r"""Return whether ``expr`` is an :class:`Apply`, optionally of a given OpDef class.

      Args:
          expr: the Expr to test.
          opdef: optional class; when given, ``expr.opdef`` must be an instance of it.
      """
      if not isinstance(expr, Apply):
          return False
      return opdef is None or isinstance(expr.opdef, opdef)
  def is_input(expr):
      r"""Return ``True`` when ``expr`` is an :class:`Input` (graph-input marker) expr."""
      matched = isinstance(expr, Input)
      return matched
  class Expr:
      r"""``Expr`` represents the operations (i.e. ``CallMethod``, ``CallFunction``, ``Apply``,
      ``GetAttr``, ``Input``, ``Constant``) on ``Node``.
      """

      inputs = None  # type: List[Node]
      r"""The input Nodes of this Expr."""
      outputs = None  # type: List[Node]
      r"""The output Nodes of this Expr."""
      const_val = None  # type: List[Any]
      r"""The non-tensor object in the input of the operation."""
      arg_def = None  # type: TreeDef
      r"""The :class:`TreeDef` used to reconstruct the input of the operation."""
      out_def = None  # type: TreeDef
      r"""The :class:`TreeDef` used to reconstruct the output of the operation."""
      _top_graph = None  # type: weakref.ReferenceType
      # Global id counter shared by all Expr subclasses. It is name-mangled to
      # ``Expr._Expr__total_id`` and always accessed through ``Expr``, so a
      # subclass can never end up with its own shadow copy.
      __total_id = 0

      def __init__(self) -> None:
          self._id = Expr.__total_id
          Expr.__total_id += 1
          self._disable_remove = False

      def enable_remove(self):
          r"""Clear the ``_disable_remove`` flag of this Expr."""
          self._disable_remove = False

      def disable_remove(self):
          r"""Set the ``_disable_remove`` flag of this Expr.

          NOTE(review): presumably consulted by graph-editing code before deleting
          this Expr; confirm against the graph implementation.
          """
          self._disable_remove = True

      def add_inputs(self, vals):
          r"""Append ``vals`` to the inputs of this Expr.

          Traced values (wrapped by a ``TensorNode``/``ModuleNode``) become input
          Nodes, and this Expr is added to their ``users``. Any other value must be
          a constant leaf; it is stored in ``const_val`` as ``(position, value)``,
          where position counts both Nodes and constants.

          Args:
              vals: a single value or a sequence of values.
          """
          if not isinstance(vals, collections.abc.Sequence):
              vals = (vals,)
          for val in vals:
              node = NodeMixin.get(val, None)
              if isinstance(node, (TensorNode, ModuleNode)):
                  self.inputs.append(node)
                  node.users.append(self)
              else:
                  assert node is None
                  assert not isinstance(val, (Module, RawTensor))
                  assert _is_leaf(val) and _is_const_leaf(val)
                  idx = len(self.inputs) + len(self.const_val)
                  self.const_val.append((idx, val))

      def add_outputs(self, outputs):
          r"""Create output Nodes for the traced ``outputs`` of this Expr.

          Each output tensor is wrapped with a new Node whose ``expr`` is this
          Expr. The current graph's namespace then names the new Nodes.

          Args:
              outputs: ``None``, a single tensor, or a sequence of tensors.
          """
          assert active_module_tracer() is not None
          self.outputs = []
          if outputs is None:
              return
          current_graph = active_module_tracer().current_scope()
          # ``collections.Sequence`` no longer exists in Python >= 3.10.
          if not isinstance(outputs, collections.abc.Sequence):
              outputs = (outputs,)
          for i in outputs:
              assert isinstance(i, RawTensor), "The output must be a Tensor"
              node = NodeMixin.get_wrapped_type(i)(expr=self, name="", qualname="",)
              NodeMixin.wrap_safe(i, node)
              self.outputs.append(node)
          current_graph._namespace.auto_naming_for_outputs(self)

      def unflatten_args(self, inputs):
          r"""Rebuild ``(args, kwargs)`` from ``inputs`` plus the stored constants.

          Each ``(idx, val)`` in ``const_val`` is re-inserted at its original
          position before ``arg_def`` unflattens the list.
          """
          assert self.arg_def is not None, "{} expr doesn't have args/kwargs".format(
              type(self).__name__
          )
          inputs = list(inputs)
          for idx, val in self.const_val:
              inputs.insert(idx, val)
          args, kwargs = self.arg_def.unflatten(inputs)
          return args, kwargs

      def replace_inputs(self, repl_dict: Dict[Node, Node]):
          r"""Replace the input Nodes of this Expr.

          Args:
              repl_dict: the map {old_Node: new_Node} that specifies how to replace
                  the input Nodes. The mapping itself is left unmodified.
          """
          # Work on a copy so the caller's dict is not emptied.
          # ``popitem`` keeps the original last-inserted-first order.
          pending = dict(repl_dict)
          while pending:
              node, repl_node = pending.popitem()
              assert type(node) == type(repl_node)
              assert node in self.inputs, "({}) is not in the ({})".format(node, self)
              assert (
                  repl_node.top_graph == node.top_graph
              ), "({}) and ({}) are not in the same graph".format(node, repl_node)
              graph = self.top_graph
              repl_expr_idx = graph._exprs.index(repl_node.expr)
              self_idx = graph._exprs.index(self)
              assert (
                  repl_expr_idx < self_idx
              ), "({}) must be generated before ({})".format(repl_node, self)
              idx = self.inputs.index(node)
              self.inputs[idx] = repl_node
              node.users.remove(self)
              repl_node.users.append(self)

      @property
      def _support_set_args_kwargs(self):
          return False

      def set_args_kwargs(self, *args, **kwargs):
          r"""Set args and kwargs for Expr.

          The new arguments are normalized with ``_convert_kwargs_to_args`` and
          flattened. Node leaves become ``inputs``; all other leaves become
          ``const_val``. The ``users`` lists of Nodes that were added or dropped
          are updated.
          """
          assert (
              self._support_set_args_kwargs
          ), "Doesn't support set args/kwargs for {} expr".format(type(self).__name__)
          args, kwargs = _convert_kwargs_to_args(self._get_func(), args, kwargs)
          inputs, arg_def = tree_flatten((args, kwargs))
          orig_inputs = self.inputs
          self.inputs = []
          self.const_val = []
          for val in inputs:
              if isinstance(val, (TensorNode, ModuleNode)):
                  self.inputs.append(val)
              else:
                  assert _is_leaf(val) and _is_const_leaf(val)
                  idx = len(self.inputs) + len(self.const_val)
                  self.const_val.append((idx, val))
          for n in orig_inputs:
              if n not in self.inputs:
                  n.users.remove(self)
          for n in self.inputs:
              if n not in orig_inputs:
                  n.users.append(self)
          self.arg_def = arg_def

      @property
      def kwargs(self):
          r"""Get the keyword arguments of the operation corresponding to this Expr."""
          _, kwargs = self.unflatten_args(self.inputs)
          return kwargs

      @property
      def args(self):
          r"""Get the positional arguments of the operation corresponding to this Expr."""
          args, _ = self.unflatten_args(self.inputs)
          return args

      def _get_func(self):
          # get called function when the expr is interpreted
          raise NotImplementedError

      @property
      def named_args(self):
          r"""Map parameter names of the called function to this Expr's arguments."""
          func = self._get_func()
          return inspect.getcallargs(func, *self.args, **self.kwargs)

      def set_arg(self, name, val):
          r"""Set the single argument ``name`` of the called function to ``val``.

          All other positional and keyword arguments are preserved.

          Args:
              name: a keyword, positional, ``*varargs`` or ``**varkw`` parameter name.
              val: the new value. For ``*varargs``, a non-sequence is wrapped into a
                  1-tuple.
          """
          func = self._get_func()
          kwargs = self.kwargs
          if name in kwargs:
              kwargs[name] = val
              self.set_args_kwargs(*self.args, **kwargs)
              return
          arg_spec = inspect.getfullargspec(func)
          if name in arg_spec.args:
              ind = arg_spec.args.index(name)
              new_args = list(self.args)
              # NOTE(review): raises IndexError when ``name`` was not passed
              # positionally (e.g. it relies on its default); confirm intended.
              new_args[ind] = val
              # Pass the existing kwargs through; they used to be dropped here.
              self.set_args_kwargs(*new_args, **kwargs)
          elif name == arg_spec.varargs:
              assert arg_spec.varargs is not None
              assert len(self.args) >= len(arg_spec.args)
              val = (val,) if not isinstance(val, Sequence) else val
              self.set_args_kwargs(
                  *self.args[0 : len(arg_spec.args)], *val, **kwargs
              )
          else:
              assert (
                  arg_spec.varkw is not None
              ), "func {} does't have argument named {}".format(func, name)
              kwargs[name] = val
              self.set_args_kwargs(*self.args, **kwargs)

      @property
      def return_val(self):
          r"""The outputs of this Expr, unflattened with ``out_def``."""
          return self.out_def.unflatten(self.outputs)

      @return_val.setter
      def return_val(self, new_outputs):
          outputs, out_def = tree_flatten(
              new_outputs, is_leaf=lambda x: isinstance(x, Node)
          )
          assert all(
              isinstance(o, Node) for o in outputs
          ), "Return values of expr must be ModuleNode or TensorNode or Container with them"
          assert all(
              o.expr in (None, self) for o in outputs
          ), "Some nodes are produced by other expr, can not be output of expr {}".format(
              self
          )
          self.outputs = outputs
          self.out_def = out_def

      @property
      def top_graph(self):
          r"""Get the parent graph of this Expr (``None`` if unset)."""
          if self._top_graph is not None:
              return self._top_graph()
          return None

      @classmethod
      def _get_next_id(cls):
          # Always read the counter on ``Expr``, whichever subclass is used.
          return Expr.__total_id

      @classmethod
      def _set_next_id(cls, id: int = 0):
          assert isinstance(id, int)
          # Always write the counter on ``Expr``; ``cls.__total_id = ...`` would
          # create a shadow attribute on a subclass that ``__init__`` never reads.
          Expr.__total_id = id

      def __copy__(self):
          cls = self.__class__
          result = cls.__new__(cls)
          result.__dict__.update(self.__dict__)
          return result

      def __deepcopy__(self, memo):
          cls = self.__class__
          result = cls.__new__(cls)
          state = {}
          memo[id(self)] = result
          # Weak references (e.g. ``_top_graph``) are deliberately not copied.
          for k, v in self.__dict__.items():
              if not isinstance(v, weakref.ReferenceType):
                  state[k] = copy.deepcopy(v, memo)
          result.__dict__.update(state)
          return result
  # expr: None (i.e. fake expression which is used to mark input)
  class Input(Expr):
      r"""A fake Expr which is used to mark the input of graph.

      It has no inputs and a single output Node. ``make`` registers that Node
      as an input of the current graph.
      """

      name = None

      def __init__(self, type: List[Node], name: str = "args", qualname: str = ""):
          super().__init__()
          assert type in [ModuleNode, TensorNode]
          assert name and qualname
          self.inputs = []
          # ``type`` is the Node class used for the output. ``Node`` is only a
          # fallback for a falsy ``type``, reachable only when asserts are disabled.
          out_cls = type or Node
          self.outputs = [out_cls(self, name=name, qualname=qualname)]
          self.name = name

      @classmethod
      def make(cls, *args, **kwargs):
          r"""Create an ``Input`` in the active tracer's current graph and return its output Node."""
          tracer = active_module_tracer()
          assert tracer is not None
          graph = tracer.current_scope()
          expr = cls(*args, **kwargs)
          first_out = expr.outputs[0]
          graph._namespace.auto_naming_for_outputs(expr)
          graph._add_input(first_out)
          return expr.outputs[0]

      def __repr__(self):
          template = "%{}:\t{} = Input()"
          return template.format(self._id, self.outputs[0])

      def __getstate__(self):
          state = dict(
              _id=self._id,
              _disable_remove=self._disable_remove,
              inputs=self.inputs,
              outputs=self.outputs,
              name=self.name,
          )
          _check_obj_attr(state)
          return state
  # expr: outputs = getattr(inputs[0], self.name)
  class GetAttr(Expr):
      r"""``Getattr`` represents the fetch of an attribute from the ``Module`` hierarchy."""

      name = None
      r"""name: the qualified name of the attribute to be retrieved."""

      def __init__(
          self, module: ModuleNode, type: Union[Node], attr_name: str, name: str = "",
      ):
          super().__init__()
          assert isinstance(module, ModuleNode)
          assert type in [TensorNode, ModuleNode]
          self.inputs = [module]
          module.users.append(self)
          self.name = attr_name
          # ``type`` is the Node class of the fetched attribute.
          full_qualname = "{}.{}".format(module.qualname, attr_name)
          self.outputs = [type(self, name=name, qualname=full_qualname)]

      @classmethod
      def make(cls, *args, **kwargs):
          r"""Create a ``GetAttr`` in the active tracer's current graph and return its output Node."""
          tracer = active_module_tracer()
          assert tracer is not None
          graph = tracer.current_scope()
          expr = cls(*args, **kwargs)
          graph._namespace.auto_naming_for_outputs(expr)
          graph._insert(expr)
          return expr.outputs[0]

      def interpret(self, *inputs):
          r"""Resolve the dotted ``self.name`` starting from the module ``inputs[0]``.

          Raises:
              AttributeError: if an intermediate attribute on the path is not a Module.
          """
          owner = inputs[0]
          path, _, leaf = self.name.rpartition(".")
          if path:
              for item in path.split("."):
                  owner = getattr(owner, item)
                  if not isinstance(owner, Module):
                      raise AttributeError("`{}` is not an Module".format(item))
          return (getattr(owner, leaf),)

      def __repr__(self):
          target = self.outputs[0]
          if isinstance(target, ModuleNode):
              m_type = target.module_type
              # ``module_type`` is either a class or an indexable whose item 1 is a name.
              if isinstance(m_type, type):
                  out_type = m_type.__name__
              else:
                  out_type = m_type[1]
          else:
              out_type = "Tensor"
          return '%{}:\t{} = getattr({}, "{}") -> ({})'.format(
              self._id, target, self.inputs[0], self.name, out_type
          )

      def __getstate__(self):
          state = dict(
              _id=self._id,
              _disable_remove=self._disable_remove,
              inputs=self.inputs,
              outputs=self.outputs,
              name=self.name,
          )
          _check_obj_attr(state)
          return state
  # expr: outputs = inputs[0].__call__(*inputs[1:])
  class CallMethod(Expr):
      r"""``CallMethod`` represents a call to the ``__call__`` method of ``Module`` or a method of ``Tensor``.

      Args:
          node: the Node to be called.
          method: the method name.
              Default: "__call__"
      """

      def __init__(self, node, method="__call__"):
          super().__init__()
          if isinstance(node, type):
              # A method called on the Tensor/Parameter class itself: the class is
              # kept as the constant first argument instead of an input Node.
              assert issubclass(node, Tensor)
              cls = Parameter if issubclass(node, Parameter) else Tensor
              self.inputs = []
              self.const_val = [(0, cls)]
          else:
              assert isinstance(node, (TensorNode, ModuleNode))
              node.users.append(self)
              self.inputs = [
                  node,
              ]
              self.const_val = []
          self.arg_def = tree_flatten(((node,), {}))[1]
          self.method = method

      @classmethod
      def make(cls, *args, **kwargs):
          r"""Create a ``CallMethod`` and insert it into the active tracer's current graph."""
          assert active_module_tracer() is not None
          expr = cls(*args, **kwargs)
          active_module_tracer().current_scope()._insert(expr)
          return expr

      @property
      def graph(self):
          r"""The sub-graph traced for this module call, or ``None``.

          It is looked up in the called module's ``argdef_graph_map`` using this
          Expr's ``arg_def``. Returns ``None`` for non-module calls, or when the
          map is missing or empty.
          """
          if isinstance(self.inputs[0], ModuleNode):
              m_node = self.inputs[0]
              if (
                  hasattr(m_node.owner, "argdef_graph_map")
                  and m_node.owner.argdef_graph_map
              ):
                  assert self.arg_def in m_node.owner.argdef_graph_map
                  return m_node.owner.argdef_graph_map[self.arg_def]
          return None

      def interpret(self, *inputs):
          r"""Call ``self.method`` on the first argument and return the flattened outputs.

          For ``__setitem__`` the (mutated) object itself is returned. ``None``
          results are returned unchanged.
          """
          args, kwargs = self.unflatten_args(inputs)
          obj = args[0]
          meth = getattr(obj, self.method)
          # A bound method already carries ``obj``; don't pass it twice.
          if inspect.ismethod(meth):
              args = args[1:]
          outputs = meth(*args, **kwargs)
          if self.method == "__setitem__":
              outputs = obj
          if outputs is None:
              return outputs
          outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor))
          return outputs

      def _get_func(self):
          # For module calls the signature of interest is ``forward``; otherwise
          # it is the named method on the class (Tensor for TensorNode callees).
          if isinstance(self.args[0], type):
              obj_type = self.args[0]
          elif isinstance(self.args[0], ModuleNode):
              obj_type = self.args[0].module_type
          else:
              assert isinstance(self.args[0], TensorNode)
              obj_type = Tensor
          meth = getattr(
              obj_type, "forward" if issubclass(obj_type, Module) else self.method
          )
          return meth

      @property
      def _support_set_args_kwargs(self):
          # only expr call tensor method or builtin module support modify args/kwargs
          return (
              isinstance(self.args[0], (TensorNode, type))
              or self.args[0].module_type is not Module
          )

      def __repr__(self):
          args = ", ".join(str(i) for i in self.args[1:])
          kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
          outputs = self.outputs
          if self.out_def:
              outputs = self.out_def.unflatten(outputs)
          method = ".%s" % self.method
          if method == ".__call__":
              method = ""
          # Join only non-empty parts; joining both unconditionally rendered a
          # stray ", " (e.g. ``m(, )``) when args or kwargs were empty.
          call_args = ", ".join(part for part in (args, kwargs) if part)
          return "%{}:\t{}{}{}({})".format(
              self._id,
              str(outputs) + " = " if outputs else "",
              self.args[0],
              method,
              call_args,
          )

      def __getstate__(self):
          state = {
              "_id": self._id,
              "_disable_remove": self._disable_remove,
              "inputs": self.inputs,
              "const_val": self.const_val,
              "method": self.method,
              "arg_def": self.arg_def,
              "out_def": self.out_def,
              "outputs": self.outputs,
              "version": __version__,
          }
          _check_obj_attr(state)
          return state
  # expr: outputs = apply(self.opdef, *inputs)
  class Apply(Expr):
      r"""``Apply`` represents a call to :func:`apply`.

      Args:
          opdef: the applied :class:`OpDef`.
      """

      opdef = None

      def __init__(self, opdef):
          super().__init__()
          assert isinstance(opdef, OpDef)
          self.opdef = opdef
          self.inputs = []

      @classmethod
      def make(cls, *args, **kwargs):
          r"""Create an ``Apply`` and insert it into the active tracer's current graph."""
          assert active_module_tracer() is not None
          expr = cls(*args, **kwargs)
          active_module_tracer().current_scope()._insert(expr)
          return expr

      def interpret(self, *inputs):
          return apply(self.opdef, *inputs)

      def __repr__(self):
          return "%{}:\t{} = {}({})".format(
              self._id,
              ", ".join(str(i) for i in self.outputs),
              self.opdef,
              ", ".join(str(i) for i in self.inputs),
          )

      def __getstate__(self):
          # The OpDef is pickled as its own state dict plus its type.
          opdef_state = self.opdef.__getstate__()
          opdef_state["opdef_type"] = type(self.opdef)
          state = {
              "_id": self._id,
              "_disable_remove": self._disable_remove,
              "opdef_state": opdef_state,
              "inputs": self.inputs,
              "outputs": self.outputs,
              "version": __version__,
          }
          _check_obj_attr(state)
          return state

      def __setstate__(self, state):
          # compat with mge 1.6
          if "opdef" in state and "opdef_state" not in state:
              opdef_state = state.pop("opdef")
              opdef_state["opdef_type"] = opdef_state.pop("type")
              state["opdef_state"] = opdef_state
          self.__dict__.update(state)
          assert isinstance(state["opdef_state"], dict)
          opdef_state = state["opdef_state"].copy()
          opdef_type = opdef_state.pop("opdef_type")
          opdef_obj = opdef_type()
          opdef_obj.__setstate__(opdef_state)
          setattr(self, "opdef", opdef_obj)

      @classmethod
      def apply_module_trace_hook(cls, opdef, *inputs):
          r"""Record ``apply(opdef, *inputs)`` as an ``Apply`` expr and execute it.

          Every input must already be wrapped by a Node. For :class:`FakeQuant`,
          only the first input is used through its Node. The remaining inputs are
          recorded as new ``Constant`` exprs.

          Returns:
              a list of output Tensors, each wrapped by an output Node of the new expr.
          """
          for i in inputs:
              node = NodeMixin.get(i, None)
              assert node is not None
          if isinstance(opdef, FakeQuant):
              inp_nodes = [NodeMixin.get(inputs[0])]
              for i in inputs[1:]:
                  node = Constant.make(i)
                  if _get_expr_checker():
                      active_module_tracer().checker.record_node2value(node, Tensor(i))
                  inp_nodes.append(node)
              apply_node = cls.make(opdef)
              for n in inp_nodes:
                  n.users.append(apply_node)
              apply_node.inputs = inp_nodes
          else:
              apply_node = cls.make(opdef)
              apply_node.add_inputs(inputs)
          assert not apply_node.const_val

          # Run the op with module tracing switched off (presumably so this hook
          # is not re-entered for the same apply). ``finally`` guarantees tracing
          # is switched back on even if ``apply`` or the Tensor conversion raises.
          unset_module_tracing()
          try:
              outputs = apply(opdef, *inputs)
              outputs = list(map(Tensor, outputs))
          finally:
              set_module_tracing()

          apply_node.add_outputs(outputs)
          for n, v in zip(apply_node.outputs, outputs):
              NodeMixin.wrap_safe(v, n)

          if _get_expr_checker():
              with _exclude_from_trace():
                  active_module_tracer().checker.check_apply(apply_node, outputs, opdef)

          return list(outputs)
  class CallFunction(Expr):
      r"""``CallFunction`` represents a call to a built-in function.

      Args:
          func: a built-in function.
      """

      def __init__(self, func):
          super().__init__()
          assert isinstance(func, Callable)
          self.func = func
          self.const_val = []
          self.inputs = []

      @classmethod
      def make(cls, *args, **kwargs):
          r"""Create a ``CallFunction`` and insert it into the active tracer's current graph."""
          assert active_module_tracer() is not None
          expr = cls(*args, **kwargs)
          active_module_tracer().current_scope()._insert(expr)
          return expr

      def interpret(self, *inputs):
          r"""Call ``self.func`` with the rebuilt arguments and return the flattened outputs."""
          args, kwargs = self.unflatten_args(inputs)
          # While module tracing is active, call the patcher's wrapped version of
          # the function instead of the raw one.
          func = (
              self.func
              if not is_tracing_module()
              else active_module_tracer().patcher.wrap_fn(self.func)
          )
          outputs = func(*args, **kwargs)
          if outputs is None:
              return outputs
          outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor))
          return outputs

      def _get_func(self):
          return self.func

      @property
      def _support_set_args_kwargs(self):
          return True

      def __repr__(self):
          args = ", ".join(str(i) for i in self.args)
          kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
          outputs = self.outputs
          if self.out_def:
              outputs = self.out_def.unflatten(outputs)
          # Join only non-empty parts; joining both unconditionally rendered a
          # stray ", " when args or kwargs were empty.
          call_args = ", ".join(part for part in (args, kwargs) if part)
          return "%{}:\t{}{}({})".format(
              self._id,
              str(outputs) + " = " if outputs else "",
              self.func.__module__.rsplit(".", 1)[-1] + "." + self.func.__name__,
              call_args,
          )

      def __getstate__(self):
          # The function is pickled by reference as (module name, qualname).
          state = {
              "_id": self._id,
              "_disable_remove": self._disable_remove,
              "func": (self.func.__module__, self.func.__qualname__),
              "const_val": self.const_val,
              "inputs": self.inputs,
              "arg_def": self.arg_def,
              "out_def": self.out_def,
              "outputs": self.outputs,
              "version": __version__,
          }
          _check_obj_attr(state)
          return state

      def __setstate__(self, state):
          r"""Restore state and re-import ``func`` from its (module, qualname) pair.

          Resolution is best effort. On failure ``func`` stays the tuple (as
          before), but a warning is emitted instead of failing silently.
          """
          self.__dict__.update(state)
          if not isinstance(self.func, tuple):
              return
          try:
              mname, fname = self.func
              f = import_module(mname)
              for i in fname.split("."):
                  f = getattr(f, i)
          except Exception as exc:
              import warnings

              warnings.warn(
                  "CallFunction: cannot resolve function {!r} ({!r}); "
                  "this expr cannot be interpreted".format(self.func, exc)
              )
              return
          self.func = f
  # expr outputs = self.value
  class Constant(Expr):
      r"""``Constant`` represents a ``Tensor`` or "Module" which is not the attribute of a Module.

      Args:
          c: a const Tensor or Module.
          name: the name of output Node.
      """

      value = None
      r"""The const Tensor or Module"""
      # TODO: constant cache to reduce the size of dumped model
      _constant_cache = {}

      def __init__(self, c, name: str = "", qualname: str = ""):
          super().__init__()
          assert isinstance(c, (RawTensor, Module))
          if isinstance(c, Module):
              assert module_tracer.is_builtin(c) or c.is_qat
          if isinstance(c, RawTensor):
              if is_tracing_module():
                  # Convert with module tracing switched off. ``finally`` makes
                  # sure tracing is switched back on even if the conversion raises.
                  unset_module_tracing()
                  try:
                      c = Tensor(c)
                  finally:
                      set_module_tracing()
              else:
                  c = Tensor(c)
          self.value = c
          self.name = name
          self.inputs = []
          node_cls = NodeMixin.get_wrapped_type(c)
          self.outputs = [
              node_cls(self, name=name, qualname=qualname),
          ]

      @classmethod
      def make(cls, *args, **kwargs):
          r"""Create a ``Constant`` in the current graph and return its output Node.

          The value is also appended to the tracer's current constant cache.
          """
          assert active_module_tracer() is not None
          expr = cls(*args, **kwargs)
          current_graph = active_module_tracer().current_scope()
          current_graph._namespace.auto_naming_for_outputs(expr)
          current_graph._insert(expr)
          active_module_tracer().current_constant_cache().append(expr.value)
          return expr.outputs[0]

      def interpret(self, *inputs):
          if isinstance(self.value, RawTensor):
              return Const(self.value.numpy())()
          return (self.value,)

      def __repr__(self):
          name = self.name
          # The default name is "", so test for any empty name (not just None)
          # before falling back to the value's type.
          if not name:
              name = type(self.value)
          node_type = "Module"
          if isinstance(self.outputs[0], TensorNode):
              node_type = "Tensor"
          return "%{}:\t{} = Constant({}) -> ({})".format(
              self._id, self.outputs[0], name, node_type
          )

      def __getstate__(self):
          state = {
              "_id": self._id,
              "_disable_remove": self._disable_remove,
              "value": self.value,
              "name": self.name,
              "inputs": self.inputs,
              "outputs": self.outputs,
          }
          _check_obj_attr(state)
          if isinstance(self.value, RawTensor):
              state["value"] = Tensor(self.value)
          if isinstance(self.value, Module) and module_tracer.is_builtin(self.value):
              # Builtin modules are pickled as a ``_ModuleState`` snapshot.
              _check_builtin_module_attr(self.value)
              state["value"] = _ModuleState.get_module_state(self.value)
          return state

      def __setstate__(self, state):
          # Rebuild any ``_ModuleState`` snapshot back into a Module.
          for k, v in state.items():
              if isinstance(v, _ModuleState):
                  state[k] = v.to_module()
          self.__dict__.update(state)
  def _module_trace_capture(value):
      r"""Record ``value`` as a new ``Constant`` expr and wrap it with that expr's Node.

      Returns:
          the output Node of the created ``Constant`` expr.
      """
      const_node = Constant.make(value)
      NodeMixin.wrap_safe(value, const_node)
      return const_node
  # Install ``Apply.apply_module_trace_hook`` as the module-trace hook of the
  # imperative runtime. NOTE(review): its (opdef, *inputs) signature suggests
  # the runtime calls it for each ``apply`` while module tracing is enabled;
  # confirm in ``core2``.
  set_module_trace_hook(
      Apply.apply_module_trace_hook
  )