You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

expr.py 27 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import builtins
  9. import collections
  10. import copy
  11. import inspect
  12. import re
  13. import weakref
  14. from importlib import import_module
  15. from typing import Callable, Dict, Iterable, List, Optional, Sequence, Union
  16. from ..core._imperative_rt import OpDef
  17. from ..core._imperative_rt.core2 import Tensor as RawTensor
  18. from ..core._imperative_rt.core2 import (
  19. apply,
  20. is_tracing_module,
  21. set_module_tracing,
  22. unset_module_tracing,
  23. )
  24. from ..core.ops.builtin import FakeQuant
  25. from ..core.ops.special import Const
  26. from ..module import Module
  27. from ..tensor import Parameter, Tensor
  28. from ..version import __version__
  29. from .module_tracer import active_module_tracer, module_tracer
  30. from .node import ModuleNode, Node, NodeMixin, TensorNode
  31. from .pytree import ArgsIndex, TreeDef, _is_const_leaf, _is_leaf, tree_flatten
  32. from .serialization import _ModuleState
  33. from .tm_config import _exclude_from_trace, _get_expr_checker
  34. from .utils import _check_builtin_module_attr, _check_obj_attr, _convert_kwargs_to_args
  35. def rstrip(s: str, __chars: str):
  36. __chars = re.escape(__chars)
  37. s = re.sub(r"^(?P<left>.*?)(?:%s)+$" % __chars, "\g<left>", s)
  38. return s
  39. def get_suffix_name(prefix: str, name: str):
  40. if prefix == name:
  41. return ""
  42. matchd = re.compile("^%s\.(.*)" % prefix).match(name)
  43. if matchd is None:
  44. return None
  45. return matchd.group(1)
  46. def is_call_module(expr, module_cls: Module = None):
  47. return (
  48. isinstance(expr, CallMethod)
  49. and isinstance(expr.inputs[0], ModuleNode)
  50. and expr.method == "__call__"
  51. ) and (module_cls is None or isinstance(expr.inputs[0].owner, module_cls))
  52. def is_call_tensor_method(expr, method: Iterable[str] = None):
  53. if method and isinstance(method, str):
  54. method = (method,)
  55. return (
  56. isinstance(expr, CallMethod)
  57. and not is_call_module(expr)
  58. and (method is None or any(expr.method == f for f in method))
  59. )
  60. def is_call_function(expr, func: Iterable[Callable] = None):
  61. if func and not isinstance(func, Iterable):
  62. func = (func,)
  63. return isinstance(expr, CallFunction) and (
  64. func is None or any(expr.func == f for f in func)
  65. )
def is_constant(expr):
    r"""Return whether ``expr`` is a ``Constant`` expr."""
    return isinstance(expr, Constant)
def is_getattr(expr):
    r"""Return whether ``expr`` is a ``GetAttr`` expr."""
    return isinstance(expr, GetAttr)
def is_apply_def(expr, opdef=None):
    r"""Return whether ``expr`` is an ``Apply`` expr, optionally one whose
    applied op is an instance of ``opdef``."""
    return isinstance(expr, Apply) and (opdef is None or isinstance(expr.opdef, opdef))
def is_input(expr):
    r"""Return whether ``expr`` is an ``Input`` expr."""
    return isinstance(expr, Input)
  74. class Expr:
  75. r"""``Expr`` represents the operations (i.e. ``CallMethod``, ``CallFunction``, ``Apply``,
  76. ``GetAttr``, ``Input``, ``Constant``) on ``Node``.
  77. """
  78. inputs = None # type: List[Node]
  79. r"""The input Nodes of this Expr."""
  80. outputs = None # type: List[Node]
  81. r"""The output Nodes of this Expr."""
  82. const_val = None # type: List[Any]
  83. r"""The non-tensor object in the input of the operation."""
  84. arg_def = None # type: TreeDef
  85. r"""The :class:`TreeDef` used to reconstruct the input of the operation."""
  86. out_def = None # type: TreeDef
  87. r"""The :class:`TreeDef` used to reconstruct the output of the operation."""
  88. _top_graph = None # type: weakref.ReferenceType
  89. __total_id = 0
  90. def __init__(self) -> None:
  91. self._id = Expr.__total_id
  92. Expr.__total_id += 1
  93. self._disable_remove = False
  94. def enable_remove(self):
  95. self._disable_remove = False
  96. def disable_remove(self):
  97. self._disable_remove = True
  98. def add_inputs(self, vals):
  99. if not isinstance(vals, collections.abc.Sequence):
  100. vals = (vals,)
  101. for val in vals:
  102. node = NodeMixin.get(val, None)
  103. if isinstance(node, (TensorNode, ModuleNode)):
  104. self.inputs.append(node)
  105. node.users.append(self)
  106. else:
  107. assert node is None
  108. assert not isinstance(val, (Module, RawTensor))
  109. assert _is_leaf(val) and _is_const_leaf(val)
  110. idx = len(self.inputs) + len(self.const_val)
  111. self.const_val.append((idx, val))
  112. def add_outputs(self, outputs):
  113. assert active_module_tracer() is not None
  114. self.outputs = []
  115. if outputs is None:
  116. return
  117. current_graph = active_module_tracer().current_scope()
  118. if not isinstance(outputs, collections.Sequence):
  119. outputs = (outputs,)
  120. for i in outputs:
  121. assert isinstance(i, RawTensor), "The output must be a Tensor"
  122. node = NodeMixin.get_wrapped_type(i)(expr=self, name="", qualname="",)
  123. NodeMixin.wrap_safe(i, node)
  124. self.outputs.append(node)
  125. current_graph._namespace.auto_naming_for_outputs(self)
  126. def unflatten_args(self, inputs):
  127. assert self.arg_def is not None, "{} expr doesn't have args/kwargs".format(
  128. type(self).__name__
  129. )
  130. inputs = list(inputs)
  131. for idx, val in self.const_val:
  132. inputs.insert(idx, val)
  133. args, kwargs = self.arg_def.unflatten(inputs)
  134. return args, kwargs
  135. def replace_inputs(self, repl_dict: Dict[Node, Node]):
  136. r"""Replace the input Nodes of this Expr.
  137. Args:
  138. repl_dict: the map {old_Node: new_Node} that specifies how to replace the input Nodes.
  139. """
  140. while repl_dict:
  141. node, repl_node = repl_dict.popitem()
  142. assert type(node) == type(repl_node)
  143. assert node in self.inputs, "({}) is not in the ({})".format(node, self)
  144. assert (
  145. repl_node.top_graph == node.top_graph
  146. ), "({}) and ({}) are not in the same graph".format(node, repl_node)
  147. graph = self.top_graph
  148. repl_expr_idx = graph._exprs.index(repl_node.expr)
  149. self_idx = graph._exprs.index(self)
  150. assert (
  151. repl_expr_idx < self_idx
  152. ), "({}) must be generated before ({})".format(repl_node, self)
  153. idx = self.inputs.index(node)
  154. self.inputs[idx] = repl_node
  155. node.users.remove(self)
  156. repl_node.users.append(self)
  157. @property
  158. def _support_set_args_kwargs(self):
  159. return False
  160. def set_args_kwargs(self, *args, **kwargs):
  161. r""" Set args and kwargs for Expr.
  162. """
  163. assert (
  164. self._support_set_args_kwargs
  165. ), "Doesn't support set args/kwargs for {} expr".format(type(self).__name__)
  166. args, kwargs = _convert_kwargs_to_args(self._get_func(), args, kwargs)
  167. inputs, arg_def = tree_flatten((args, kwargs))
  168. orig_inputs = self.inputs
  169. self.inputs = []
  170. self.const_val = []
  171. for val in inputs:
  172. if isinstance(val, (TensorNode, ModuleNode)):
  173. self.inputs.append(val)
  174. else:
  175. assert _is_leaf(val) and _is_const_leaf(val)
  176. idx = len(self.inputs) + len(self.const_val)
  177. self.const_val.append((idx, val))
  178. for n in orig_inputs:
  179. if n not in self.inputs:
  180. n.users.remove(self)
  181. for n in self.inputs:
  182. if n not in orig_inputs:
  183. n.users.append(self)
  184. self.arg_def = arg_def
  185. @property
  186. def kwargs(self):
  187. r"""Get the keyword arguments of the operation corresponding to this Expr."""
  188. _, kwargs = self.unflatten_args(self.inputs)
  189. return kwargs
  190. @property
  191. def args(self):
  192. r"""Get the positional arguments of the operation corresponding to this Expr."""
  193. args, _ = self.unflatten_args(self.inputs)
  194. return args
  195. def _get_func(self):
  196. # get called function when the expr is interpreted
  197. raise NotImplementedError
  198. @property
  199. def named_args(self):
  200. func = self._get_func()
  201. return inspect.getcallargs(func, *self.args, **self.kwargs)
  202. def set_arg(self, name, val):
  203. func = self._get_func()
  204. if name in self.kwargs:
  205. new_kwargs = self.kwargs
  206. new_kwargs[name] = val
  207. self.set_args_kwargs(*self.args, **new_kwargs)
  208. else:
  209. arg_spec = inspect.getfullargspec(func)
  210. if name in arg_spec.args:
  211. ind = arg_spec.args.index(name)
  212. new_args = list(self.args)
  213. new_args[ind] = val
  214. self.set_args_kwargs(*new_args)
  215. elif name == arg_spec.varargs:
  216. assert arg_spec.varargs is not None
  217. assert len(self.args) >= len(arg_spec.args)
  218. val = (val,) if not isinstance(val, Sequence) else val
  219. self.set_args_kwargs(*self.args[0 : len(arg_spec.args)], *val)
  220. else:
  221. assert (
  222. arg_spec.varkw is not None
  223. ), "func {} does't have argument named {}".format(func, name)
  224. new_kwargs = self.kwargs
  225. new_kwargs[name] = val
  226. self.set_args_kwargs(*self.args, **new_kwargs)
  227. @property
  228. def return_val(self):
  229. return self.out_def.unflatten(self.outputs)
  230. @return_val.setter
  231. def return_val(self, new_outputs):
  232. outputs, out_def = tree_flatten(
  233. new_outputs, is_leaf=lambda x: isinstance(x, Node)
  234. )
  235. assert all(
  236. isinstance(o, Node) for o in outputs
  237. ), "Return values of expr must be ModuleNode or TensorNode or Container with them"
  238. assert all(
  239. o.expr in (None, self) for o in outputs
  240. ), "Some nodes are produced by other expr, can not be output of expr {}".format(
  241. self
  242. )
  243. self.outputs = outputs
  244. self.out_def = out_def
  245. @property
  246. def top_graph(self):
  247. r"""Get the parent graph of this Expr."""
  248. if self._top_graph:
  249. return self._top_graph()
  250. return None
  251. @classmethod
  252. def _get_next_id(cls):
  253. return cls.__total_id
  254. @classmethod
  255. def _set_next_id(cls, id: int = 0):
  256. assert isinstance(id, int)
  257. cls.__total_id = id
  258. def __copy__(self):
  259. cls = self.__class__
  260. result = cls.__new__(cls)
  261. result.__dict__.update(self.__dict__)
  262. return result
  263. def __deepcopy__(self, memo):
  264. cls = self.__class__
  265. result = cls.__new__(cls)
  266. state = {}
  267. memo[id(self)] = result
  268. for k, v in self.__dict__.items():
  269. if not isinstance(v, weakref.ReferenceType):
  270. state[k] = copy.deepcopy(v, memo)
  271. result.__dict__.update(state)
  272. return result
  273. # expr: None (i.e. fake expression which is used to mark input)
class Input(Expr):
    r"""A fake Expr which is used to mark the input of graph."""

    name = None  # display name of this graph input

    def __init__(self, type: List[Node], name: str = "args", qualname: str = ""):
        # ``type`` is the Node subclass to instantiate for the single output
        # (ModuleNode or TensorNode); the builtin name is shadowed on purpose
        # to keep the public keyword stable.
        super().__init__()
        assert type in [ModuleNode, TensorNode]
        assert name and qualname
        self.inputs = []
        node_cls = type if type else Node
        self.outputs = [
            node_cls(self, name=name, qualname=qualname),
        ]
        self.name = name

    @classmethod
    def make(cls, *args, **kwargs):
        r"""Create an ``Input`` expr, register its output Node as a graph
        input of the current traced scope, and return that Node."""
        assert active_module_tracer() is not None
        current_graph = active_module_tracer().current_scope()
        expr = cls(*args, **kwargs)
        out_node = expr.outputs[0]
        current_graph._namespace.auto_naming_for_outputs(expr)
        current_graph._add_input(out_node)
        return expr.outputs[0]

    def __repr__(self):
        return "%{}:\t{} = Input()".format(self._id, self.outputs[0])

    def __getstate__(self):
        # Pickle only the plain fields; graph back-references are weakrefs
        # and are intentionally dropped.
        state = {
            "_id": self._id,
            "_disable_remove": self._disable_remove,
            "inputs": self.inputs,
            "outputs": self.outputs,
            "name": self.name,
        }
        _check_obj_attr(state)
        return state
  308. # expr: outputs = getattr(inputs[0], self.name)
class GetAttr(Expr):
    r"""``Getattr`` represents the fetch of an attribute from the ``Module`` hierarchy."""

    name = None
    r"""name: the qualified name of the attribute to be retrieved."""

    def __init__(
        self, module: ModuleNode, type: Union[Node], attr_name: str, name: str = "",
    ):
        # ``type`` is the Node subclass for the output (TensorNode or
        # ModuleNode); it shadows the builtin deliberately.
        super().__init__()
        assert isinstance(module, ModuleNode)
        assert type in [TensorNode, ModuleNode]
        self.inputs = [
            module,
        ]
        module.users.append(self)
        self.name = attr_name
        # The output Node's qualname is "<module qualname>.<attr name>".
        self.outputs = [
            type(self, name=name, qualname="{}.{}".format(module.qualname, attr_name)),
        ]

    @classmethod
    def make(cls, *args, **kwargs):
        r"""Create a ``GetAttr`` expr, insert it into the current traced
        scope and return its output Node."""
        assert active_module_tracer() is not None
        current_graph = active_module_tracer().current_scope()
        expr = cls(*args, **kwargs)
        current_graph._namespace.auto_naming_for_outputs(expr)
        current_graph._insert(expr)
        return expr.outputs[0]

    def interpret(self, *inputs):
        r"""Fetch the attribute from the given module, walking a dotted path
        one submodule at a time; returns a 1-tuple with the value."""
        mod = inputs[0]
        module_path, _, name = self.name.rpartition(".")
        if module_path == "":
            return (getattr(mod, name),)
        module_names = module_path.split(".")
        for item in module_names:
            mod = getattr(mod, item)
            if not isinstance(mod, Module):
                raise AttributeError("`{}` is not an Module".format(item))
        return (getattr(mod, name),)

    def __repr__(self):
        out_type = "Tensor"
        if isinstance(self.outputs[0], ModuleNode):
            m_type = self.outputs[0].module_type
            # module_type may be a class or a (module, name) pair.
            out_type = m_type.__name__ if isinstance(m_type, type) else m_type[1]
        return '%{}:\t{} = getattr({}, "{}") -> ({})'.format(
            self._id, self.outputs[0], self.inputs[0], self.name, out_type
        )

    def __getstate__(self):
        state = {
            "_id": self._id,
            "_disable_remove": self._disable_remove,
            "inputs": self.inputs,
            "outputs": self.outputs,
            "name": self.name,
        }
        _check_obj_attr(state)
        return state
  364. # expr: outputs = inputs[0].__call__(*inputs[1:])
class CallMethod(Expr):
    r"""``CallMethod`` represents a call to the ``__call__`` method of ``Module`` or a method of ``Tensor``.

    Args:
        node: the Node to be called.
        method: the method name.
            Default: "__call__"
    """

    def __init__(self, node, method="__call__"):
        super().__init__()
        if isinstance(node, type):
            # Calling a Tensor/Parameter class itself (e.g. a constructor):
            # record the class as the constant at flat position 0.
            assert issubclass(node, Tensor)
            cls = Parameter if issubclass(node, Parameter) else Tensor
            self.inputs = []
            self.const_val = [(0, cls)]
        else:
            assert isinstance(node, (TensorNode, ModuleNode))
            node.users.append(self)
            self.inputs = [
                node,
            ]
            self.const_val = []
        self.arg_def = tree_flatten(((node,), {}))[1]
        self.method = method

    @classmethod
    def make(cls, *args, **kwargs):
        r"""Create a ``CallMethod`` expr and insert it into the current
        traced scope."""
        assert active_module_tracer() is not None
        expr = cls(*args, **kwargs)
        active_module_tracer().current_scope()._insert(expr)
        return expr

    @property
    def graph(self):
        r"""The sub-graph traced for the called Module, if one was recorded."""
        if isinstance(self.inputs[0], ModuleNode):
            m_node = self.inputs[0]
            if (
                hasattr(m_node.owner, "argdef_graph_map")
                and m_node.owner.argdef_graph_map
            ):
                assert self.arg_def in m_node.owner.argdef_graph_map
                return m_node.owner.argdef_graph_map[self.arg_def]
        return None

    def interpret(self, *inputs):
        r"""Execute the recorded method call on concrete inputs and return
        the flattened outputs (or None if the call returned None)."""
        args, kwargs = self.unflatten_args(inputs)
        obj = args[0]
        meth = getattr(obj, self.method)
        # Bound methods already carry ``obj``; drop it from the args.
        if inspect.ismethod(meth):
            args = args[1:]
        outputs = getattr(obj, self.method)(*args, **kwargs)
        if self.method == "__setitem__":
            # __setitem__ returns None; the mutated object is the result.
            outputs = obj
        if outputs is None:
            return outputs
        outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor))
        return outputs

    def _get_func(self):
        # Resolve the plain function that interpretation dispatches to:
        # Module calls go through ``forward``; Tensor calls use the named
        # method on the Tensor class.
        if isinstance(self.args[0], type):
            obj_type = self.args[0]
        elif isinstance(self.args[0], ModuleNode):
            obj_type = self.args[0].module_type
        else:
            assert isinstance(self.args[0], TensorNode)
            obj_type = Tensor
        meth = getattr(
            obj_type, "forward" if issubclass(obj_type, Module) else self.method
        )
        return meth

    @property
    def _support_set_args_kwargs(self):
        # only expr call tensor method or builtin module support modify args/kwargs
        return (
            isinstance(self.args[0], (TensorNode, type))
            or self.args[0].module_type is not Module
        )

    def __repr__(self):
        args = ", ".join(str(i) for i in self.args[1:])
        kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
        outputs = self.outputs
        if self.out_def:
            outputs = self.out_def.unflatten(outputs)
        method = ".%s" % self.method
        # Plain module calls print without the ".__call__" suffix.
        if method == ".__call__":
            method = ""
        return "%{}:\t{}{}{}({})".format(
            self._id,
            str(outputs) + " = " if outputs else "",
            self.args[0],
            method,
            ", ".join([args, kwargs]),
        )

    def __getstate__(self):
        state = {
            "_id": self._id,
            "_disable_remove": self._disable_remove,
            "inputs": self.inputs,
            "const_val": self.const_val,
            "method": self.method,
            "arg_def": self.arg_def,
            "out_def": self.out_def,
            "outputs": self.outputs,
            "version": __version__,
        }
        _check_obj_attr(state)
        return state
  467. # expr: outputs = apply(self.opdef, *inputs)
class Apply(Expr):
    r"""``Apply`` represents a call to :func:`apply`.

    Args:
        opdef: the applied :class:`OpDef`.
    """

    opdef = None  # the OpDef executed when this expr is interpreted

    def __init__(self, opdef):
        super().__init__()
        assert isinstance(opdef, OpDef)
        self.opdef = opdef
        self.inputs = []

    @classmethod
    def make(cls, *args, **kwargs):
        r"""Create an ``Apply`` expr and insert it into the current traced
        scope."""
        assert active_module_tracer() is not None
        expr = cls(*args, **kwargs)
        active_module_tracer().current_scope()._insert(expr)
        return expr

    def interpret(self, *inputs):
        r"""Run the recorded OpDef on concrete inputs."""
        return apply(self.opdef, *inputs)

    def __repr__(self):
        return "%{}:\t{} = {}({})".format(
            self._id,
            ", ".join(str(i) for i in self.outputs),
            self.opdef,
            ", ".join(str(i) for i in self.inputs),
        )

    def __getstate__(self):
        # Store the OpDef's state dict together with its type so
        # __setstate__ can rebuild the op on load.
        opdef_state = self.opdef.__getstate__()
        opdef_state["opdef_type"] = type(self.opdef)
        state = {
            "_id": self._id,
            "_disable_remove": self._disable_remove,
            "opdef_state": opdef_state,
            "inputs": self.inputs,
            "outputs": self.outputs,
            "version": __version__,
        }
        _check_obj_attr(state)
        return state

    def __setstate__(self, state):
        # compat with mge 1.6
        if "opdef" in state and "opdef_state" not in state:
            opdef_state = state.pop("opdef")
            opdef_state["opdef_type"] = opdef_state.pop("type")
            state["opdef_state"] = opdef_state
        self.__dict__.update(state)
        assert isinstance(state["opdef_state"], dict)
        # Rebuild the OpDef instance from its pickled state.
        opdef_state = state["opdef_state"].copy()
        opdef_type = opdef_state.pop("opdef_type")
        opdef_obj = opdef_type()
        opdef_obj.__setstate__(opdef_state)
        setattr(self, "opdef", opdef_obj)

    @classmethod
    def apply_module_trace_hook(cls, opdef, *inputs):
        r"""Tracing hook: record an ``Apply`` expr for ``opdef``, execute the
        op eagerly with tracing disabled, and wrap the outputs in Nodes."""
        for i in inputs:
            node = NodeMixin.get(i, None)
            if node is None:  # capture as constant
                NodeMixin.wrap_safe(i, Constant.make(i))
        if isinstance(opdef, FakeQuant):
            # For FakeQuant, the non-primary inputs are captured as
            # Constant exprs and wired in manually.
            inp_nodes = [NodeMixin.get(inputs[0])]
            for i in inputs[1:]:
                node = Constant.make(i)
                if _get_expr_checker():
                    active_module_tracer().checker.record_node2value(node, Tensor(i))
                inp_nodes.append(node)
            apply_node = cls.make(opdef)
            for n in inp_nodes:
                n.users.append(apply_node)
            apply_node.inputs = inp_nodes
        else:
            apply_node = cls.make(opdef)
            apply_node.add_inputs(inputs)
        assert not apply_node.const_val
        # Run the real op outside of tracing so the execution itself is not
        # recorded as more exprs.
        unset_module_tracing()
        outputs = apply(opdef, *inputs)
        outputs = list(map(Tensor, outputs))
        set_module_tracing()
        apply_node.add_outputs(outputs)
        for n, v in zip(apply_node.outputs, outputs):
            NodeMixin.wrap_safe(v, n)
        if _get_expr_checker():
            with _exclude_from_trace():
                active_module_tracer().checker.check_apply(apply_node, outputs, opdef)
        return list(outputs)
class CallFunction(Expr):
    r"""``CallFunction`` represents a call to a built-in function.

    Args:
        func: a built-in function.
    """

    def __init__(self, func):
        super().__init__()
        assert isinstance(func, Callable)
        self.func = func
        self.const_val = []
        self.inputs = []

    @classmethod
    def make(cls, *args, **kwargs):
        r"""Create a ``CallFunction`` expr and insert it into the current
        traced scope."""
        assert active_module_tracer() is not None
        expr = cls(*args, **kwargs)
        active_module_tracer().current_scope()._insert(expr)
        return expr

    def interpret(self, *inputs):
        r"""Call the recorded function on concrete inputs and return the
        flattened outputs (or None if the call returned None)."""
        args, kwargs = self.unflatten_args(inputs)
        # While tracing, call the patched wrapper so nested calls are traced.
        func = (
            self.func
            if not is_tracing_module()
            else active_module_tracer().patcher.wrap_fn(self.func)
        )
        outputs = func(*args, **kwargs)
        if outputs is None:
            return outputs
        outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor))
        return outputs

    def _get_func(self):
        return self.func

    @property
    def _support_set_args_kwargs(self):
        return True

    def __repr__(self):
        args = ", ".join(str(i) for i in self.args)
        kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
        outputs = self.outputs
        if self.out_def:
            outputs = self.out_def.unflatten(outputs)
        return "%{}:\t{}{}({})".format(
            self._id,
            str(outputs) + " = " if outputs else "",
            self.func.__module__.rsplit(".")[-1] + "." + self.func.__name__,
            ", ".join([args, kwargs]),
        )

    def __getstate__(self):
        # The function is pickled as a (module, qualname) pair and
        # re-imported in __setstate__.
        state = {
            "_id": self._id,
            "_disable_remove": self._disable_remove,
            "func": (self.func.__module__, self.func.__qualname__),
            "const_val": self.const_val,
            "inputs": self.inputs,
            "arg_def": self.arg_def,
            "out_def": self.out_def,
            "outputs": self.outputs,
            "version": __version__,
        }
        _check_obj_attr(state)
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        try:
            if isinstance(self.func, tuple):
                mname, fname = self.func
                f = import_module(mname)
                for i in fname.split("."):
                    f = getattr(f, i)
                self.func = f
        except Exception:
            # Best-effort: if the function can no longer be imported, keep
            # the (module, qualname) tuple instead of failing to unpickle.
            pass
  623. # expr outputs = self.value
class Constant(Expr):
    r"""``Constant`` represents a ``Tensor`` or "Module" which is not the attribute of a Module.

    Args:
        c: a const Tensor or Module.
        name: the name of output Node.
    """

    value = None
    r"""The const Tensor or Module"""
    # TODO: constant cache to reduce the size of dumped model
    _constant_cache = {}

    def __init__(self, c, name: str = "", qualname: str = ""):
        super().__init__()
        assert isinstance(c, (RawTensor, Module))
        if isinstance(c, Module):
            # Only builtin (leaf) modules or QAT modules may be constants.
            assert module_tracer.is_builtin(c) or c.is_qat
        if isinstance(c, RawTensor):
            # Copy the tensor; tracing is temporarily disabled so the copy
            # itself is not recorded as an expr.
            if is_tracing_module():
                unset_module_tracing()
                c = Tensor(c)
                set_module_tracing()
            else:
                c = Tensor(c)
        self.value = c
        self.name = name
        self.inputs = []
        node_cls = NodeMixin.get_wrapped_type(c)
        self.outputs = [
            node_cls(self, name=name, qualname=qualname),
        ]

    @classmethod
    def make(cls, *args, **kwargs):
        r"""Create a ``Constant`` expr, insert it into the current traced
        scope and return its output Node."""
        assert active_module_tracer() is not None
        expr = cls(*args, **kwargs)
        current_graph = active_module_tracer().current_scope()
        current_graph._namespace.auto_naming_for_outputs(expr)
        current_graph._insert(expr)
        return expr.outputs[0]

    def interpret(self, *inputs):
        r"""Materialize the constant: tensors are re-created via a Const op,
        modules are returned as-is (in a 1-tuple)."""
        if isinstance(self.value, RawTensor):
            return Const(self.value.numpy())()
        return (self.value,)

    def __repr__(self):
        name = self.name
        if name is None:
            name = type(self.value)
        node_type = "Module"
        if isinstance(self.outputs[0], TensorNode):
            node_type = "Tensor"
        return "%{}:\t{} = Constant({}) -> ({})".format(
            self._id, self.outputs[0], name, node_type
        )

    def __getstate__(self):
        state = {
            "_id": self._id,
            "_disable_remove": self._disable_remove,
            "value": self.value,
            "name": self.name,
            "inputs": self.inputs,
            "outputs": self.outputs,
        }
        _check_obj_attr(state)
        # Replace the raw value with a picklable form: a copied Tensor, or a
        # _ModuleState snapshot for builtin modules.
        if isinstance(self.value, RawTensor):
            state["value"] = Tensor(self.value)
        if isinstance(self.value, Module) and module_tracer.is_builtin(self.value):
            _check_builtin_module_attr(self.value)
            state["value"] = _ModuleState.get_module_state(self.value)
        return state

    def __setstate__(self, state):
        # Restore any _ModuleState snapshots back into real modules.
        for k, v in state.items():
            if isinstance(v, _ModuleState):
                state[k] = v.to_module()
        self.__dict__.update(state)