You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

expr.py 26 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804
  1. import builtins
  2. import collections
  3. import copy
  4. import inspect
  5. import re
  6. import weakref
  7. from importlib import import_module
  8. from typing import Callable, Dict, Iterable, List, Optional, Sequence, Union
  9. from ..core._imperative_rt import OpDef
  10. from ..core._imperative_rt.core2 import Const
  11. from ..core._imperative_rt.core2 import Tensor as RawTensor
  12. from ..core._imperative_rt.core2 import (
  13. apply,
  14. is_tracing_module,
  15. set_module_trace_hook,
  16. set_module_tracing,
  17. unset_module_tracing,
  18. )
  19. from ..core.ops.builtin import FakeQuant
  20. from ..module import Module
  21. from ..tensor import Parameter, Tensor
  22. from ..version import __version__
  23. from .module_tracer import active_module_tracer, module_tracer
  24. from .node import ModuleNode, Node, NodeMixin, TensorNode
  25. from .pytree import ArgsIndex, TreeDef, _is_const_leaf, _is_leaf, tree_flatten
  26. from .serialization import _ModuleState
  27. from .tm_config import _exclude_from_trace, _get_expr_checker
  28. from .utils import _check_builtin_module_attr, _check_obj_attr, _convert_kwargs_to_args
  29. def rstrip(s: str, __chars: str):
  30. __chars = re.escape(__chars)
  31. s = re.sub(r"^(?P<left>.*?)(?:%s)+$" % __chars, "\g<left>", s)
  32. return s
  33. def get_suffix_name(prefix: str, name: str):
  34. if prefix == name:
  35. return ""
  36. matchd = re.compile("^%s\.(.*)" % prefix).match(name)
  37. if matchd is None:
  38. return None
  39. return matchd.group(1)
  40. def is_call_module(expr, module_cls: Module = None):
  41. return (
  42. isinstance(expr, CallMethod)
  43. and isinstance(expr.inputs[0], ModuleNode)
  44. and expr.method == "__call__"
  45. ) and (module_cls is None or isinstance(expr.inputs[0].owner, module_cls))
  46. def is_call_tensor_method(expr, method: Iterable[str] = None):
  47. if method and isinstance(method, str):
  48. method = (method,)
  49. return (
  50. isinstance(expr, CallMethod)
  51. and not is_call_module(expr)
  52. and (method is None or any(expr.method == f for f in method))
  53. )
  54. def is_call_function(expr, func: Iterable[Callable] = None):
  55. if func and not isinstance(func, Iterable):
  56. func = (func,)
  57. return isinstance(expr, CallFunction) and (
  58. func is None or any(expr.func == f for f in func)
  59. )
  60. def is_constant(expr):
  61. return isinstance(expr, Constant)
  62. def is_getattr(expr):
  63. return isinstance(expr, GetAttr)
  64. def is_apply_def(expr, opdef=None):
  65. return isinstance(expr, Apply) and (opdef is None or isinstance(expr.opdef, opdef))
  66. def is_input(expr):
  67. return isinstance(expr, Input)
  68. class Expr:
  69. r"""``Expr`` represents the operations (i.e. ``CallMethod``, ``CallFunction``, ``Apply``,
  70. ``GetAttr``, ``Input``, ``Constant``) on ``Node``.
  71. """
  72. inputs = None # type: List[Node]
  73. r"""The input Nodes of this Expr."""
  74. outputs = None # type: List[Node]
  75. r"""The output Nodes of this Expr."""
  76. const_val = None # type: List[Any]
  77. r"""The non-tensor object in the input of the operation."""
  78. arg_def = None # type: TreeDef
  79. r"""The :class:`TreeDef` used to reconstruct the input of the operation."""
  80. out_def = None # type: TreeDef
  81. r"""The :class:`TreeDef` used to reconstruct the output of the operation."""
  82. _top_graph = None # type: weakref.ReferenceType
  83. __total_id = 0
  84. def __init__(self) -> None:
  85. self._id = Expr.__total_id
  86. Expr.__total_id += 1
  87. self._disable_remove = False
  88. def enable_remove(self):
  89. self._disable_remove = False
  90. def disable_remove(self):
  91. self._disable_remove = True
  92. def add_inputs(self, vals):
  93. if not isinstance(vals, collections.abc.Sequence):
  94. vals = (vals,)
  95. for val in vals:
  96. node = NodeMixin.get(val, None)
  97. if isinstance(node, (TensorNode, ModuleNode)):
  98. self.inputs.append(node)
  99. node.users.append(self)
  100. else:
  101. assert node is None
  102. assert not isinstance(val, (Module, RawTensor))
  103. assert _is_leaf(val) and _is_const_leaf(val)
  104. idx = len(self.inputs) + len(self.const_val)
  105. self.const_val.append((idx, val))
  106. def add_outputs(self, outputs):
  107. assert active_module_tracer() is not None
  108. self.outputs = []
  109. if outputs is None:
  110. return
  111. current_graph = active_module_tracer().current_scope()
  112. if not isinstance(outputs, collections.abc.Sequence):
  113. outputs = (outputs,)
  114. for i in outputs:
  115. assert isinstance(i, RawTensor), "The output must be a Tensor"
  116. node = NodeMixin.get_wrapped_type(i)(expr=self, name="", qualname="",)
  117. NodeMixin.wrap_safe(i, node)
  118. self.outputs.append(node)
  119. current_graph._namespace.auto_naming_for_outputs(self)
  120. def unflatten_args(self, inputs):
  121. assert self.arg_def is not None, "{} expr doesn't have args/kwargs".format(
  122. type(self).__name__
  123. )
  124. inputs = list(inputs)
  125. for idx, val in self.const_val:
  126. inputs.insert(idx, val)
  127. args, kwargs = self.arg_def.unflatten(inputs)
  128. return args, kwargs
  129. def replace_inputs(self, repl_dict: Dict[Node, Node]):
  130. r"""Replace the input Nodes of this Expr.
  131. Args:
  132. repl_dict: the map {old_Node: new_Node} that specifies how to replace the input Nodes.
  133. """
  134. while repl_dict:
  135. node, repl_node = repl_dict.popitem()
  136. assert type(node) == type(repl_node)
  137. assert node in self.inputs, "({}) is not in the ({})".format(node, self)
  138. assert (
  139. repl_node.top_graph == node.top_graph
  140. ), "({}) and ({}) are not in the same graph".format(node, repl_node)
  141. graph = self.top_graph
  142. repl_expr_idx = graph._exprs.index(repl_node.expr)
  143. self_idx = graph._exprs.index(self)
  144. assert (
  145. repl_expr_idx < self_idx
  146. ), "({}) must be generated before ({})".format(repl_node, self)
  147. idx = self.inputs.index(node)
  148. self.inputs[idx] = repl_node
  149. node.users.remove(self)
  150. repl_node.users.append(self)
  151. @property
  152. def _support_set_args_kwargs(self):
  153. return False
  154. def set_args_kwargs(self, *args, **kwargs):
  155. r""" Set args and kwargs for Expr.
  156. """
  157. assert (
  158. self._support_set_args_kwargs
  159. ), "Doesn't support set args/kwargs for {} expr".format(type(self).__name__)
  160. args, kwargs = _convert_kwargs_to_args(self._get_func(), args, kwargs)
  161. inputs, arg_def = tree_flatten((args, kwargs))
  162. orig_inputs = self.inputs
  163. self.inputs = []
  164. self.const_val = []
  165. for val in inputs:
  166. if isinstance(val, (TensorNode, ModuleNode)):
  167. self.inputs.append(val)
  168. else:
  169. assert _is_leaf(val) and _is_const_leaf(val)
  170. idx = len(self.inputs) + len(self.const_val)
  171. self.const_val.append((idx, val))
  172. for n in orig_inputs:
  173. if n not in self.inputs:
  174. n.users.remove(self)
  175. for n in self.inputs:
  176. if n not in orig_inputs:
  177. n.users.append(self)
  178. self.arg_def = arg_def
  179. @property
  180. def kwargs(self):
  181. r"""Get the keyword arguments of the operation corresponding to this Expr."""
  182. _, kwargs = self.unflatten_args(self.inputs)
  183. return kwargs
  184. @property
  185. def args(self):
  186. r"""Get the positional arguments of the operation corresponding to this Expr."""
  187. args, _ = self.unflatten_args(self.inputs)
  188. return args
  189. def _get_func(self):
  190. # get called function when the expr is interpreted
  191. raise NotImplementedError
  192. @property
  193. def named_args(self):
  194. func = self._get_func()
  195. return inspect.getcallargs(func, *self.args, **self.kwargs)
  196. def set_arg(self, name, val):
  197. func = self._get_func()
  198. if name in self.kwargs:
  199. new_kwargs = self.kwargs
  200. new_kwargs[name] = val
  201. self.set_args_kwargs(*self.args, **new_kwargs)
  202. else:
  203. arg_spec = inspect.getfullargspec(func)
  204. if name in arg_spec.args:
  205. ind = arg_spec.args.index(name)
  206. new_args = list(self.args)
  207. new_args[ind] = val
  208. self.set_args_kwargs(*new_args)
  209. elif name == arg_spec.varargs:
  210. assert arg_spec.varargs is not None
  211. assert len(self.args) >= len(arg_spec.args)
  212. val = (val,) if not isinstance(val, Sequence) else val
  213. self.set_args_kwargs(*self.args[0 : len(arg_spec.args)], *val)
  214. else:
  215. assert (
  216. arg_spec.varkw is not None
  217. ), "func {} does't have argument named {}".format(func, name)
  218. new_kwargs = self.kwargs
  219. new_kwargs[name] = val
  220. self.set_args_kwargs(*self.args, **new_kwargs)
  221. @property
  222. def return_val(self):
  223. return self.out_def.unflatten(self.outputs)
  224. @return_val.setter
  225. def return_val(self, new_outputs):
  226. outputs, out_def = tree_flatten(
  227. new_outputs, is_leaf=lambda x: isinstance(x, Node)
  228. )
  229. assert all(
  230. isinstance(o, Node) for o in outputs
  231. ), "Return values of expr must be ModuleNode or TensorNode or Container with them"
  232. assert all(
  233. o.expr in (None, self) for o in outputs
  234. ), "Some nodes are produced by other expr, can not be output of expr {}".format(
  235. self
  236. )
  237. self.outputs = outputs
  238. self.out_def = out_def
  239. @property
  240. def top_graph(self):
  241. r"""Get the parent graph of this Expr."""
  242. if self._top_graph:
  243. return self._top_graph()
  244. return None
  245. @classmethod
  246. def _get_next_id(cls):
  247. return cls.__total_id
  248. @classmethod
  249. def _set_next_id(cls, id: int = 0):
  250. assert isinstance(id, int)
  251. cls.__total_id = id
  252. def __copy__(self):
  253. cls = self.__class__
  254. result = cls.__new__(cls)
  255. result.__dict__.update(self.__dict__)
  256. return result
  257. def __deepcopy__(self, memo):
  258. cls = self.__class__
  259. result = cls.__new__(cls)
  260. state = {}
  261. memo[id(self)] = result
  262. for k, v in self.__dict__.items():
  263. if not isinstance(v, weakref.ReferenceType):
  264. state[k] = copy.deepcopy(v, memo)
  265. result.__dict__.update(state)
  266. return result
  267. # expr: None (i.e. fake expression which is used to mark input)
  268. class Input(Expr):
  269. r"""A fake Expr which is used to mark the input of graph."""
  270. name = None
  271. def __init__(self, type: List[Node], name: str = "args", qualname: str = ""):
  272. super().__init__()
  273. assert type in [ModuleNode, TensorNode]
  274. assert name and qualname
  275. self.inputs = []
  276. node_cls = type if type else Node
  277. self.outputs = [
  278. node_cls(self, name=name, qualname=qualname),
  279. ]
  280. self.name = name
  281. @classmethod
  282. def make(cls, *args, **kwargs):
  283. assert active_module_tracer() is not None
  284. current_graph = active_module_tracer().current_scope()
  285. expr = cls(*args, **kwargs)
  286. out_node = expr.outputs[0]
  287. current_graph._namespace.auto_naming_for_outputs(expr)
  288. current_graph._add_input(out_node)
  289. return expr.outputs[0]
  290. def __repr__(self):
  291. return "%{}:\t{} = Input()".format(self._id, self.outputs[0])
  292. def __getstate__(self):
  293. state = {
  294. "_id": self._id,
  295. "_disable_remove": self._disable_remove,
  296. "inputs": self.inputs,
  297. "outputs": self.outputs,
  298. "name": self.name,
  299. }
  300. _check_obj_attr(state)
  301. return state
  302. # expr: outputs = getattr(inputs[0], self.name)
  303. class GetAttr(Expr):
  304. r"""``Getattr`` represents the fetch of an attribute from the ``Module`` hierarchy."""
  305. name = None
  306. r"""name: the qualified name of the attribute to be retrieved."""
  307. def __init__(
  308. self, module: ModuleNode, type: Union[Node], attr_name: str, name: str = "",
  309. ):
  310. super().__init__()
  311. assert isinstance(module, ModuleNode)
  312. assert type in [TensorNode, ModuleNode]
  313. self.inputs = [
  314. module,
  315. ]
  316. module.users.append(self)
  317. self.name = attr_name
  318. self.outputs = [
  319. type(self, name=name, qualname="{}.{}".format(module.qualname, attr_name)),
  320. ]
  321. @classmethod
  322. def make(cls, *args, **kwargs):
  323. assert active_module_tracer() is not None
  324. current_graph = active_module_tracer().current_scope()
  325. expr = cls(*args, **kwargs)
  326. current_graph._namespace.auto_naming_for_outputs(expr)
  327. current_graph._insert(expr)
  328. return expr.outputs[0]
  329. def interpret(self, *inputs):
  330. mod = inputs[0]
  331. module_path, _, name = self.name.rpartition(".")
  332. if module_path == "":
  333. return (getattr(mod, name),)
  334. module_names = module_path.split(".")
  335. for item in module_names:
  336. mod = getattr(mod, item)
  337. if not isinstance(mod, Module):
  338. raise AttributeError("`{}` is not an Module".format(item))
  339. return (getattr(mod, name),)
  340. def __repr__(self):
  341. out_type = "Tensor"
  342. if isinstance(self.outputs[0], ModuleNode):
  343. m_type = self.outputs[0].module_type
  344. out_type = m_type.__name__ if isinstance(m_type, type) else m_type[1]
  345. return '%{}:\t{} = getattr({}, "{}") -> ({})'.format(
  346. self._id, self.outputs[0], self.inputs[0], self.name, out_type
  347. )
  348. def __getstate__(self):
  349. state = {
  350. "_id": self._id,
  351. "_disable_remove": self._disable_remove,
  352. "inputs": self.inputs,
  353. "outputs": self.outputs,
  354. "name": self.name,
  355. }
  356. _check_obj_attr(state)
  357. return state
  358. # expr: outputs = inputs[0].__call__(*inputs[1:])
  359. class CallMethod(Expr):
  360. r"""``CallMethod`` represents a call to the ``__call__`` method of ``Module`` or a method of ``Tensor``.
  361. Args:
  362. node: the Node to be called.
  363. method: the method name.
  364. Default: "__call__"
  365. """
  366. def __init__(self, node, method="__call__"):
  367. super().__init__()
  368. if isinstance(node, type):
  369. assert issubclass(node, Tensor)
  370. cls = Parameter if issubclass(node, Parameter) else Tensor
  371. self.inputs = []
  372. self.const_val = [(0, cls)]
  373. else:
  374. assert isinstance(node, (TensorNode, ModuleNode))
  375. node.users.append(self)
  376. self.inputs = [
  377. node,
  378. ]
  379. self.const_val = []
  380. self.arg_def = tree_flatten(((node,), {}))[1]
  381. self.method = method
  382. @classmethod
  383. def make(cls, *args, **kwargs):
  384. assert active_module_tracer() is not None
  385. expr = cls(*args, **kwargs)
  386. active_module_tracer().current_scope()._insert(expr)
  387. return expr
  388. @property
  389. def graph(self):
  390. if isinstance(self.inputs[0], ModuleNode):
  391. m_node = self.inputs[0]
  392. if (
  393. hasattr(m_node.owner, "argdef_graph_map")
  394. and m_node.owner.argdef_graph_map
  395. ):
  396. assert self.arg_def in m_node.owner.argdef_graph_map
  397. return m_node.owner.argdef_graph_map[self.arg_def]
  398. return None
  399. def interpret(self, *inputs):
  400. args, kwargs = self.unflatten_args(inputs)
  401. obj = args[0]
  402. meth = getattr(obj, self.method)
  403. if inspect.ismethod(meth):
  404. args = args[1:]
  405. outputs = getattr(obj, self.method)(*args, **kwargs)
  406. if self.method == "__setitem__":
  407. outputs = obj
  408. if outputs is None:
  409. return outputs
  410. outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor))
  411. return outputs
  412. def _get_func(self):
  413. if isinstance(self.args[0], type):
  414. obj_type = self.args[0]
  415. elif isinstance(self.args[0], ModuleNode):
  416. obj_type = self.args[0].module_type
  417. else:
  418. assert isinstance(self.args[0], TensorNode)
  419. obj_type = Tensor
  420. meth = getattr(
  421. obj_type, "forward" if issubclass(obj_type, Module) else self.method
  422. )
  423. return meth
  424. @property
  425. def _support_set_args_kwargs(self):
  426. # only expr call tensor method or builtin module support modify args/kwargs
  427. return (
  428. isinstance(self.args[0], (TensorNode, type))
  429. or self.args[0].module_type is not Module
  430. )
  431. def __repr__(self):
  432. args = ", ".join(str(i) for i in self.args[1:])
  433. kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
  434. outputs = self.outputs
  435. if self.out_def:
  436. outputs = self.out_def.unflatten(outputs)
  437. method = ".%s" % self.method
  438. if method == ".__call__":
  439. method = ""
  440. return "%{}:\t{}{}{}({})".format(
  441. self._id,
  442. str(outputs) + " = " if outputs else "",
  443. self.args[0],
  444. method,
  445. ", ".join([args, kwargs]),
  446. )
  447. def __getstate__(self):
  448. state = {
  449. "_id": self._id,
  450. "_disable_remove": self._disable_remove,
  451. "inputs": self.inputs,
  452. "const_val": self.const_val,
  453. "method": self.method,
  454. "arg_def": self.arg_def,
  455. "out_def": self.out_def,
  456. "outputs": self.outputs,
  457. "version": __version__,
  458. }
  459. _check_obj_attr(state)
  460. return state
  461. # expr: outputs = apply(self.opdef, *inputs)
  462. class Apply(Expr):
  463. r"""``Apply`` represents a call to :func:`apply`.
  464. Args:
  465. opdef: the applied :class:`OpDef`.
  466. """
  467. opdef = None
  468. def __init__(self, opdef):
  469. super().__init__()
  470. assert isinstance(opdef, OpDef)
  471. self.opdef = opdef
  472. self.inputs = []
  473. @classmethod
  474. def make(cls, *args, **kwargs):
  475. assert active_module_tracer() is not None
  476. expr = cls(*args, **kwargs)
  477. active_module_tracer().current_scope()._insert(expr)
  478. return expr
  479. def interpret(self, *inputs):
  480. return apply(self.opdef, *inputs)
  481. def __repr__(self):
  482. return "%{}:\t{} = {}({})".format(
  483. self._id,
  484. ", ".join(str(i) for i in self.outputs),
  485. self.opdef,
  486. ", ".join(str(i) for i in self.inputs),
  487. )
  488. def __getstate__(self):
  489. opdef_state = self.opdef.__getstate__()
  490. opdef_state["opdef_type"] = type(self.opdef)
  491. state = {
  492. "_id": self._id,
  493. "_disable_remove": self._disable_remove,
  494. "opdef_state": opdef_state,
  495. "inputs": self.inputs,
  496. "outputs": self.outputs,
  497. "version": __version__,
  498. }
  499. _check_obj_attr(state)
  500. return state
  501. def __setstate__(self, state):
  502. # compat with mge 1.6
  503. if "opdef" in state and "opdef_state" not in state:
  504. opdef_state = state.pop("opdef")
  505. opdef_state["opdef_type"] = opdef_state.pop("type")
  506. state["opdef_state"] = opdef_state
  507. self.__dict__.update(state)
  508. assert isinstance(state["opdef_state"], dict)
  509. opdef_state = state["opdef_state"].copy()
  510. opdef_type = opdef_state.pop("opdef_type")
  511. opdef_obj = opdef_type()
  512. opdef_obj.__setstate__(opdef_state)
  513. setattr(self, "opdef", opdef_obj)
  514. @classmethod
  515. def apply_module_trace_hook(cls, opdef, *inputs):
  516. for i in inputs:
  517. node = NodeMixin.get(i, None)
  518. if node is None: # capture as constant
  519. NodeMixin.wrap_safe(i, Constant.make(i))
  520. if isinstance(opdef, FakeQuant):
  521. inp_nodes = [NodeMixin.get(inputs[0])]
  522. for i in inputs[1:]:
  523. node = Constant.make(i)
  524. if _get_expr_checker():
  525. active_module_tracer().checker.record_node2value(node, Tensor(i))
  526. inp_nodes.append(node)
  527. apply_node = cls.make(opdef)
  528. for n in inp_nodes:
  529. n.users.append(apply_node)
  530. apply_node.inputs = inp_nodes
  531. else:
  532. apply_node = cls.make(opdef)
  533. apply_node.add_inputs(inputs)
  534. assert not apply_node.const_val
  535. unset_module_tracing()
  536. outputs = apply(opdef, *inputs)
  537. set_module_tracing()
  538. apply_node.add_outputs(outputs)
  539. for n, v in zip(apply_node.outputs, outputs):
  540. NodeMixin.wrap_safe(v, n)
  541. if _get_expr_checker():
  542. with _exclude_from_trace():
  543. active_module_tracer().checker.check_apply(apply_node, outputs, opdef)
  544. return list(outputs)
  545. class CallFunction(Expr):
  546. r"""``CallFunction`` represents a call to a built-in function.
  547. Args:
  548. func: a built-in function.
  549. """
  550. def __init__(self, func):
  551. super().__init__()
  552. assert isinstance(func, Callable)
  553. self.func = func
  554. self.const_val = []
  555. self.inputs = []
  556. @classmethod
  557. def make(cls, *args, **kwargs):
  558. assert active_module_tracer() is not None
  559. expr = cls(*args, **kwargs)
  560. active_module_tracer().current_scope()._insert(expr)
  561. return expr
  562. def interpret(self, *inputs):
  563. args, kwargs = self.unflatten_args(inputs)
  564. func = (
  565. self.func
  566. if not is_tracing_module()
  567. else active_module_tracer().patcher.wrap_fn(self.func)
  568. )
  569. outputs = func(*args, **kwargs)
  570. if outputs is None:
  571. return outputs
  572. outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor))
  573. return outputs
  574. def _get_func(self):
  575. return self.func
  576. @property
  577. def _support_set_args_kwargs(self):
  578. return True
  579. def __repr__(self):
  580. args = ", ".join(str(i) for i in self.args)
  581. kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
  582. outputs = self.outputs
  583. if self.out_def:
  584. outputs = self.out_def.unflatten(outputs)
  585. return "%{}:\t{}{}({})".format(
  586. self._id,
  587. str(outputs) + " = " if outputs else "",
  588. self.func.__module__.rsplit(".")[-1] + "." + self.func.__name__,
  589. ", ".join([args, kwargs]),
  590. )
  591. def __getstate__(self):
  592. state = {
  593. "_id": self._id,
  594. "_disable_remove": self._disable_remove,
  595. "func": (self.func.__module__, self.func.__qualname__),
  596. "const_val": self.const_val,
  597. "inputs": self.inputs,
  598. "arg_def": self.arg_def,
  599. "out_def": self.out_def,
  600. "outputs": self.outputs,
  601. "version": __version__,
  602. }
  603. _check_obj_attr(state)
  604. return state
  605. def __setstate__(self, state):
  606. self.__dict__.update(state)
  607. try:
  608. if isinstance(self.func, tuple):
  609. mname, fname = self.func
  610. f = import_module(mname)
  611. for i in fname.split("."):
  612. f = getattr(f, i)
  613. self.func = f
  614. except Exception:
  615. pass
  616. # expr outputs = self.value
  617. class Constant(Expr):
  618. r"""``Constant`` represents a ``Tensor`` or "Module" which is not the attribute of a Module.
  619. Args:
  620. c: a const Tensor or Module.
  621. name: the name of output Node.
  622. """
  623. value = None
  624. r"""The const Tensor or Module"""
  625. # TODO: constant cache to reduce the size of dumped model
  626. _constant_cache = {}
  627. def __init__(self, c, name: str = "", qualname: str = ""):
  628. super().__init__()
  629. assert isinstance(c, (RawTensor, Module))
  630. if isinstance(c, Module):
  631. assert module_tracer.is_builtin(c) or c.is_qat
  632. if type(c) is RawTensor:
  633. with _exclude_from_trace():
  634. c = Tensor(c)
  635. self.value = c
  636. self.name = name
  637. self.inputs = []
  638. node_cls = NodeMixin.get_wrapped_type(c)
  639. self.outputs = [
  640. node_cls(self, name=name, qualname=qualname),
  641. ]
  642. @classmethod
  643. def make(cls, *args, **kwargs):
  644. assert active_module_tracer() is not None
  645. expr = cls(*args, **kwargs)
  646. current_graph = active_module_tracer().current_scope()
  647. current_graph._namespace.auto_naming_for_outputs(expr)
  648. current_graph._insert(expr)
  649. active_module_tracer().current_constant_cache().append(expr.value)
  650. return expr.outputs[0]
  651. def interpret(self, *inputs):
  652. if isinstance(self.value, RawTensor):
  653. return (Const(self.value.numpy(), None, None),)
  654. return (self.value,)
  655. def __repr__(self):
  656. name = self.name
  657. if name is None:
  658. name = type(self.value)
  659. node_type = "Module"
  660. if isinstance(self.outputs[0], TensorNode):
  661. node_type = "Tensor"
  662. return "%{}:\t{} = Constant({}) -> ({})".format(
  663. self._id, self.outputs[0], name, node_type
  664. )
  665. def __getstate__(self):
  666. state = {
  667. "_id": self._id,
  668. "_disable_remove": self._disable_remove,
  669. "value": self.value,
  670. "name": self.name,
  671. "inputs": self.inputs,
  672. "outputs": self.outputs,
  673. }
  674. _check_obj_attr(state)
  675. if isinstance(self.value, RawTensor):
  676. state["value"] = Tensor(self.value)
  677. if isinstance(self.value, Module) and module_tracer.is_builtin(self.value):
  678. _check_builtin_module_attr(self.value)
  679. state["value"] = _ModuleState.get_module_state(self.value)
  680. return state
  681. def __setstate__(self, state):
  682. for k, v in state.items():
  683. if isinstance(v, _ModuleState):
  684. state[k] = v.to_module()
  685. self.__dict__.update(state)
  686. def _module_trace_capture(value):
  687. node = Constant.make(value)
  688. NodeMixin.wrap_safe(value, node)
  689. return node
# Register ``Apply.apply_module_trace_hook`` with the core via
# ``set_module_trace_hook``. The hook records each op as an ``Apply`` expr.
# NOTE(review): presumably the core invokes it for every ``apply`` executed
# while module tracing is active — confirm in core2.
set_module_trace_hook(Apply.apply_module_trace_hook)