You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

expr.py 26 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import builtins
  9. import collections
  10. import copy
  11. import inspect
  12. import re
  13. import weakref
  14. from importlib import import_module
  15. from typing import Callable, Dict, Iterable, List, Optional, Sequence, Union
  16. from ..core._imperative_rt import OpDef
  17. from ..core._imperative_rt.core2 import Tensor as RawTensor
  18. from ..core._imperative_rt.core2 import (
  19. apply,
  20. is_tracing_module,
  21. set_module_tracing,
  22. unset_module_tracing,
  23. )
  24. from ..core.ops.builtin import FakeQuant
  25. from ..core.ops.special import Const
  26. from ..module import Module
  27. from ..tensor import Parameter, Tensor
  28. from ..version import __version__
  29. from .module_tracer import active_module_tracer, module_tracer
  30. from .node import ModuleNode, Node, NodeMixin, TensorNode
  31. from .pytree import ArgsIndex, TreeDef, _is_const_leaf, _is_leaf, tree_flatten
  32. from .serialization import _ModuleState
  33. from .utils import _check_builtin_module_attr, _check_obj_attr, _convert_kwargs_to_args
  34. def rstrip(s: str, __chars: str):
  35. __chars = re.escape(__chars)
  36. s = re.sub(r"^(?P<left>.*?)(?:%s)+$" % __chars, "\g<left>", s)
  37. return s
  38. def get_suffix_name(prefix: str, name: str):
  39. if prefix == name:
  40. return ""
  41. matchd = re.compile("^%s\.(.*)" % prefix).match(name)
  42. if matchd is None:
  43. return None
  44. return matchd.group(1)
  45. def is_call_module(expr, module_cls: Module = None):
  46. return (
  47. isinstance(expr, CallMethod)
  48. and isinstance(expr.inputs[0], ModuleNode)
  49. and expr.method == "__call__"
  50. ) and (module_cls is None or isinstance(expr.inputs[0].owner, module_cls))
  51. def is_call_tensor_method(expr, method: Iterable[str] = None):
  52. if method and isinstance(method, str):
  53. method = (method,)
  54. return (
  55. isinstance(expr, CallMethod)
  56. and not is_call_module(expr)
  57. and (method is None or any(expr.method == f for f in method))
  58. )
  59. def is_call_function(expr, func: Iterable[Callable] = None):
  60. if func and not isinstance(func, Iterable):
  61. func = (func,)
  62. return isinstance(expr, CallFunction) and (
  63. func is None or any(expr.func == f for f in func)
  64. )
def is_constant(expr):
    r"""Return whether ``expr`` is a :class:`Constant` expr."""
    return isinstance(expr, Constant)
def is_getattr(expr):
    r"""Return whether ``expr`` is a :class:`GetAttr` expr."""
    return isinstance(expr, GetAttr)
def is_apply_def(expr, opdef=None):
    r"""Return whether ``expr`` is an :class:`Apply` expr; if ``opdef`` is given,
    ``expr.opdef`` must also be an instance of it.
    """
    return isinstance(expr, Apply) and (opdef is None or isinstance(expr.opdef, opdef))
def is_input(expr):
    r"""Return whether ``expr`` is an :class:`Input` expr."""
    return isinstance(expr, Input)
  73. class Expr:
  74. r"""``Expr`` represents the operations (i.e. ``CallMethod``, ``CallFunction``, ``Apply``,
  75. ``GetAttr``, ``Input``, ``Constant``) on ``Node``.
  76. """
  77. inputs = None # type: List[Node]
  78. r"""The input Nodes of this Expr."""
  79. outputs = None # type: List[Node]
  80. r"""The output Nodes of this Expr."""
  81. const_val = None # type: List[Any]
  82. r"""The non-tensor object in the input of the operation."""
  83. arg_def = None # type: TreeDef
  84. r"""The :class:`TreeDef` used to reconstruct the input of the operation."""
  85. out_def = None # type: TreeDef
  86. r"""The :class:`TreeDef` used to reconstruct the output of the operation."""
  87. _top_graph = None # type: weakref.ReferenceType
  88. __total_id = 0
  89. def __init__(self) -> None:
  90. self._id = Expr.__total_id
  91. Expr.__total_id += 1
  92. self._disable_remove = False
  93. def enable_remove(self):
  94. self._disable_remove = False
  95. def disable_remove(self):
  96. self._disable_remove = True
  97. def add_inputs(self, vals):
  98. if not isinstance(vals, collections.abc.Sequence):
  99. vals = (vals,)
  100. for val in vals:
  101. node = NodeMixin.get(val, None)
  102. if isinstance(node, (TensorNode, ModuleNode)):
  103. self.inputs.append(node)
  104. node.users.append(self)
  105. else:
  106. assert node is None
  107. assert not isinstance(val, (Module, RawTensor))
  108. assert _is_leaf(val) and _is_const_leaf(val)
  109. idx = len(self.inputs) + len(self.const_val)
  110. self.const_val.append((idx, val))
  111. def add_outputs(self, outputs):
  112. assert active_module_tracer() is not None
  113. self.outputs = []
  114. if outputs is None:
  115. return
  116. current_graph = active_module_tracer().current_scope()
  117. if not isinstance(outputs, collections.Sequence):
  118. outputs = (outputs,)
  119. for i in outputs:
  120. assert isinstance(i, RawTensor), "The output must be a Tensor"
  121. node = NodeMixin.get_wrapped_type(i)(expr=self, name="", qualname="",)
  122. NodeMixin.wrap_safe(i, node)
  123. self.outputs.append(node)
  124. current_graph._namespace.auto_naming_for_outputs(self)
  125. def unflatten_args(self, inputs):
  126. assert self.arg_def is not None, "{} expr doesn't have args/kwargs".format(
  127. type(self).__name__
  128. )
  129. inputs = list(inputs)
  130. for idx, val in self.const_val:
  131. inputs.insert(idx, val)
  132. args, kwargs = self.arg_def.unflatten(inputs)
  133. return args, kwargs
  134. def replace_inputs(self, repl_dict: Dict[Node, Node]):
  135. r"""Replace the input Nodes of this Expr.
  136. Args:
  137. repl_dict: the map {old_Node: new_Node} that specifies how to replace the input Nodes.
  138. """
  139. while repl_dict:
  140. node, repl_node = repl_dict.popitem()
  141. assert type(node) == type(repl_node)
  142. assert node in self.inputs, "({}) is not in the ({})".format(node, self)
  143. assert (
  144. repl_node.top_graph == node.top_graph
  145. ), "({}) and ({}) are not in the same graph".format(node, repl_node)
  146. graph = self.top_graph
  147. repl_expr_idx = graph._exprs.index(repl_node.expr)
  148. self_idx = graph._exprs.index(self)
  149. assert (
  150. repl_expr_idx < self_idx
  151. ), "({}) must be generated before ({})".format(repl_node, self)
  152. idx = self.inputs.index(node)
  153. self.inputs[idx] = repl_node
  154. node.users.remove(self)
  155. repl_node.users.append(self)
  156. @property
  157. def _support_set_args_kwargs(self):
  158. return False
  159. def set_args_kwargs(self, *args, **kwargs):
  160. r""" Set args and kwargs for Expr.
  161. """
  162. assert (
  163. self._support_set_args_kwargs
  164. ), "Doesn't support set args/kwargs for {} expr".format(type(self).__name__)
  165. args, kwargs = _convert_kwargs_to_args(self._get_func(), args, kwargs)
  166. inputs, arg_def = tree_flatten((args, kwargs))
  167. orig_inputs = self.inputs
  168. self.inputs = []
  169. self.const_val = []
  170. for val in inputs:
  171. if isinstance(val, (TensorNode, ModuleNode)):
  172. self.inputs.append(val)
  173. else:
  174. assert _is_leaf(val) and _is_const_leaf(val)
  175. idx = len(self.inputs) + len(self.const_val)
  176. self.const_val.append((idx, val))
  177. for n in orig_inputs:
  178. if n not in self.inputs:
  179. n.users.remove(self)
  180. for n in self.inputs:
  181. if n not in orig_inputs:
  182. n.users.append(self)
  183. self.arg_def = arg_def
  184. @property
  185. def kwargs(self):
  186. r"""Get the keyword arguments of the operation corresponding to this Expr."""
  187. _, kwargs = self.unflatten_args(self.inputs)
  188. return kwargs
  189. @property
  190. def args(self):
  191. r"""Get the positional arguments of the operation corresponding to this Expr."""
  192. args, _ = self.unflatten_args(self.inputs)
  193. return args
  194. def _get_func(self):
  195. # get called function when the expr is interpreted
  196. raise NotImplementedError
  197. @property
  198. def named_args(self):
  199. func = self._get_func()
  200. return inspect.getcallargs(func, *self.args, **self.kwargs)
  201. def set_arg(self, name, val):
  202. func = self._get_func()
  203. if name in self.kwargs:
  204. new_kwargs = self.kwargs
  205. new_kwargs[name] = val
  206. self.set_args_kwargs(*self.args, **new_kwargs)
  207. else:
  208. arg_spec = inspect.getfullargspec(func)
  209. if name in arg_spec.args:
  210. ind = arg_spec.args.index(name)
  211. new_args = list(self.args)
  212. new_args[ind] = val
  213. self.set_args_kwargs(*new_args)
  214. elif name == arg_spec.varargs:
  215. assert arg_spec.varargs is not None
  216. assert len(self.args) >= len(arg_spec.args)
  217. val = (val,) if not isinstance(val, Sequence) else val
  218. self.set_args_kwargs(*self.args[0 : len(arg_spec.args)], *val)
  219. else:
  220. assert (
  221. arg_spec.varkw is not None
  222. ), "func {} does't have argument named {}".format(func, name)
  223. new_kwargs = self.kwargs
  224. new_kwargs[name] = val
  225. self.set_args_kwargs(*self.args, **new_kwargs)
  226. @property
  227. def return_val(self):
  228. return self.out_def.unflatten(self.outputs)
  229. @return_val.setter
  230. def return_val(self, new_outputs):
  231. outputs, out_def = tree_flatten(
  232. new_outputs, is_leaf=lambda x: isinstance(x, Node)
  233. )
  234. assert all(
  235. isinstance(o, Node) for o in outputs
  236. ), "Return values of expr must be ModuleNode or TensorNode or Container with them"
  237. assert all(
  238. o.expr in (None, self) for o in outputs
  239. ), "Some nodes are produced by other expr, can not be output of expr {}".format(
  240. self
  241. )
  242. self.outputs = outputs
  243. self.out_def = out_def
  244. @property
  245. def top_graph(self):
  246. r"""Get the parent graph of this Expr."""
  247. if self._top_graph:
  248. return self._top_graph()
  249. return None
  250. @classmethod
  251. def _get_next_id(cls):
  252. return cls.__total_id
  253. @classmethod
  254. def _set_next_id(cls, id: int = 0):
  255. assert isinstance(id, int)
  256. cls.__total_id = id
  257. def __copy__(self):
  258. cls = self.__class__
  259. result = cls.__new__(cls)
  260. result.__dict__.update(self.__dict__)
  261. return result
  262. def __deepcopy__(self, memo):
  263. cls = self.__class__
  264. result = cls.__new__(cls)
  265. state = {}
  266. memo[id(self)] = result
  267. for k, v in self.__dict__.items():
  268. if not isinstance(v, weakref.ReferenceType):
  269. state[k] = copy.deepcopy(v, memo)
  270. result.__dict__.update(state)
  271. return result
# expr: None (i.e. fake expression which is used to mark input)
class Input(Expr):
    r"""A fake Expr which is used to mark the input of graph."""

    # The name of the single output Node.
    name = None

    def __init__(self, type: List[Node], name: str = "args", qualname: str = ""):
        super().__init__()
        # ``type`` selects the Node class of the graph input.
        assert type in [ModuleNode, TensorNode]
        assert name and qualname
        self.inputs = []
        node_cls = type if type else Node
        # An Input expr has no inputs and exactly one output Node.
        self.outputs = [
            node_cls(self, name=name, qualname=qualname),
        ]
        self.name = name

    @classmethod
    def make(cls, *args, **kwargs):
        r"""Create an Input, register it as an input of the graph currently
        being traced, and return its output Node.
        """
        assert active_module_tracer() is not None
        current_graph = active_module_tracer().current_scope()
        expr = cls(*args, **kwargs)
        out_node = expr.outputs[0]
        current_graph._namespace.auto_naming_for_outputs(expr)
        current_graph._add_input(out_node)
        return expr.outputs[0]

    def __repr__(self):
        return "%{}:\t{} = Input()".format(self._id, self.outputs[0])

    def __getstate__(self):
        # Pickle only the plain attributes; ``_check_obj_attr`` validates the
        # state dict before serialization (see .utils).
        state = {
            "_id": self._id,
            "_disable_remove": self._disable_remove,
            "inputs": self.inputs,
            "outputs": self.outputs,
            "name": self.name,
        }
        _check_obj_attr(state)
        return state
# expr: outputs = getattr(inputs[0], self.name)
class GetAttr(Expr):
    r"""``Getattr`` represents the fetch of an attribute from the ``Module`` hierarchy."""

    name = None
    r"""name: the qualified name of the attribute to be retrieved."""

    def __init__(
        self, module: ModuleNode, type: Union[Node], attr_name: str, name: str = "",
    ):
        super().__init__()
        assert isinstance(module, ModuleNode)
        # ``type`` is the Node class of the fetched attribute.
        assert type in [TensorNode, ModuleNode]
        self.inputs = [
            module,
        ]
        module.users.append(self)
        self.name = attr_name
        # The output Node's qualname is "<owner qualname>.<attr name>".
        self.outputs = [
            type(self, name=name, qualname="{}.{}".format(module.qualname, attr_name)),
        ]

    @classmethod
    def make(cls, *args, **kwargs):
        r"""Create a GetAttr, insert it into the graph currently being traced,
        and return its output Node.
        """
        assert active_module_tracer() is not None
        current_graph = active_module_tracer().current_scope()
        expr = cls(*args, **kwargs)
        current_graph._namespace.auto_naming_for_outputs(expr)
        current_graph._insert(expr)
        return expr.outputs[0]

    def interpret(self, *inputs):
        r"""Walk the dotted path in ``self.name`` starting from the Module
        ``inputs[0]`` and return the final attribute as a 1-tuple.
        """
        mod = inputs[0]
        module_path, _, name = self.name.rpartition(".")
        if module_path == "":
            # Plain (non-dotted) attribute access.
            return (getattr(mod, name),)
        module_names = module_path.split(".")
        for item in module_names:
            mod = getattr(mod, item)
            # Every intermediate step of the path must be a Module.
            if not isinstance(mod, Module):
                raise AttributeError("`{}` is not an Module".format(item))
        return (getattr(mod, name),)

    def __repr__(self):
        out_type = "Tensor"
        if isinstance(self.outputs[0], ModuleNode):
            # ``module_type`` may be a class or a (module, name) pair.
            m_type = self.outputs[0].module_type
            out_type = m_type.__name__ if isinstance(m_type, type) else m_type[1]
        return '%{}:\t{} = getattr({}, "{}") -> ({})'.format(
            self._id, self.outputs[0], self.inputs[0], self.name, out_type
        )

    def __getstate__(self):
        state = {
            "_id": self._id,
            "_disable_remove": self._disable_remove,
            "inputs": self.inputs,
            "outputs": self.outputs,
            "name": self.name,
        }
        _check_obj_attr(state)
        return state
# expr: outputs = inputs[0].__call__(*inputs[1:])
class CallMethod(Expr):
    r"""``CallMethod`` represents a call to the ``__call__`` method of ``Module`` or a method of ``Tensor``.

    Args:
        node: the Node to be called.
        method: the method name.
            Default: "__call__"
    """

    def __init__(self, node, method="__call__"):
        super().__init__()
        if isinstance(node, type):
            # The callee is the Tensor (or Parameter) class itself; record the
            # class as a constant value rather than as an input Node.
            assert issubclass(node, Tensor)
            cls = Parameter if issubclass(node, Parameter) else Tensor
            self.inputs = []
            self.const_val = [(0, cls)]
        else:
            assert isinstance(node, (TensorNode, ModuleNode))
            node.users.append(self)
            self.inputs = [
                node,
            ]
            self.const_val = []
        self.arg_def = tree_flatten(((node,), {}))[1]
        self.method = method

    @classmethod
    def make(cls, *args, **kwargs):
        r"""Create a CallMethod and insert it into the graph currently being traced."""
        assert active_module_tracer() is not None
        expr = cls(*args, **kwargs)
        active_module_tracer().current_scope()._insert(expr)
        return expr

    @property
    def graph(self):
        r"""Return the traced sub-graph registered for this call's ``arg_def``
        in the owner Module's ``argdef_graph_map``, or ``None``.
        """
        if isinstance(self.inputs[0], ModuleNode):
            m_node = self.inputs[0]
            if (
                hasattr(m_node.owner, "argdef_graph_map")
                and m_node.owner.argdef_graph_map
            ):
                assert self.arg_def in m_node.owner.argdef_graph_map
                return m_node.owner.argdef_graph_map[self.arg_def]
        return None

    def interpret(self, *inputs):
        r"""Execute the recorded method call on concrete ``inputs`` and return
        the flattened outputs (or ``None`` if the call returned ``None``).
        """
        args, kwargs = self.unflatten_args(inputs)
        obj = args[0]
        meth = getattr(obj, self.method)
        if inspect.ismethod(meth):
            # Bound method: ``obj`` is passed implicitly, so drop it from args.
            args = args[1:]
        outputs = getattr(obj, self.method)(*args, **kwargs)
        if self.method == "__setitem__":
            # ``__setitem__`` returns None; the mutated object is the result.
            outputs = obj
        if outputs is None:
            return outputs
        outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor))
        return outputs

    def _get_func(self):
        # Resolve the plain function this expr calls: ``forward`` for Modules,
        # otherwise the named method looked up on the (Tensor) class.
        if isinstance(self.args[0], type):
            obj_type = self.args[0]
        elif isinstance(self.args[0], ModuleNode):
            obj_type = self.args[0].module_type
        else:
            assert isinstance(self.args[0], TensorNode)
            obj_type = Tensor
        meth = getattr(
            obj_type, "forward" if issubclass(obj_type, Module) else self.method
        )
        return meth

    @property
    def _support_set_args_kwargs(self):
        # only expr call tensor method or builtin module support modify args/kwargs
        return (
            isinstance(self.args[0], (TensorNode, type))
            or self.args[0].module_type is not Module
        )

    def __repr__(self):
        args = ", ".join(str(i) for i in self.args[1:])
        kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
        outputs = self.outputs
        if self.out_def:
            outputs = self.out_def.unflatten(outputs)
        # ``__call__`` is rendered as a plain call: "obj(...)" instead of
        # "obj.__call__(...)".
        method = ".%s" % self.method
        if method == ".__call__":
            method = ""
        return "%{}:\t{}{}{}({})".format(
            self._id,
            str(outputs) + " = " if outputs else "",
            self.args[0],
            method,
            ", ".join([args, kwargs]),
        )

    def __getstate__(self):
        state = {
            "_id": self._id,
            "_disable_remove": self._disable_remove,
            "inputs": self.inputs,
            "const_val": self.const_val,
            "method": self.method,
            "arg_def": self.arg_def,
            "out_def": self.out_def,
            "outputs": self.outputs,
            "version": __version__,
        }
        _check_obj_attr(state)
        return state
# expr: outputs = apply(self.opdef, *inputs)
class Apply(Expr):
    r"""``Apply`` represents a call to :func:`apply`.

    Args:
        opdef: the applied :class:`OpDef`.
    """

    opdef = None

    def __init__(self, opdef):
        super().__init__()
        assert isinstance(opdef, OpDef)
        self.opdef = opdef
        self.inputs = []

    @classmethod
    def make(cls, *args, **kwargs):
        r"""Create an Apply and insert it into the graph currently being traced."""
        assert active_module_tracer() is not None
        expr = cls(*args, **kwargs)
        active_module_tracer().current_scope()._insert(expr)
        return expr

    def interpret(self, *inputs):
        return apply(self.opdef, *inputs)

    def __repr__(self):
        return "%{}:\t{} = {}({})".format(
            self._id,
            ", ".join(str(i) for i in self.outputs),
            self.opdef,
            ", ".join(str(i) for i in self.inputs),
        )

    def __getstate__(self):
        # Serialize the opdef as (state dict + concrete type) so that
        # ``__setstate__`` can rebuild it.
        opdef_state = self.opdef.__getstate__()
        opdef_state["opdef_type"] = type(self.opdef)
        state = {
            "_id": self._id,
            "_disable_remove": self._disable_remove,
            "opdef_state": opdef_state,
            "inputs": self.inputs,
            "outputs": self.outputs,
            "version": __version__,
        }
        _check_obj_attr(state)
        return state

    def __setstate__(self, state):
        # compat with mge 1.6: old pickles stored the opdef under "opdef" with
        # its type under the "type" key.
        if "opdef" in state and "opdef_state" not in state:
            opdef_state = state.pop("opdef")
            opdef_state["opdef_type"] = opdef_state.pop("type")
            state["opdef_state"] = opdef_state
        self.__dict__.update(state)
        assert isinstance(state["opdef_state"], dict)
        opdef_state = state["opdef_state"].copy()
        opdef_type = opdef_state.pop("opdef_type")
        opdef_obj = opdef_type()
        opdef_obj.__setstate__(opdef_state)
        setattr(self, "opdef", opdef_obj)

    @classmethod
    def apply_module_trace_hook(cls, opdef, *inputs):
        r"""Tracing hook: record an Apply expr for ``opdef``, execute the op on
        ``inputs`` (with tracing temporarily disabled) and wrap the outputs
        with fresh Nodes.
        """
        for i in inputs:
            node = NodeMixin.get(i, None)
            if node is None:  # capture as constant
                NodeMixin.wrap_safe(i, Constant.make(i))
        if isinstance(opdef, FakeQuant):
            # For FakeQuant, all inputs after the first are recorded as
            # Constant exprs instead of ordinary inputs.
            inp_nodes = [NodeMixin.get(inputs[0])]
            for i in inputs[1:]:
                node = Constant.make(i)
                inp_nodes.append(node)
            apply_node = cls.make(opdef)
            for n in inp_nodes:
                n.users.append(apply_node)
            apply_node.inputs = inp_nodes
        else:
            apply_node = cls.make(opdef)
            apply_node.add_inputs(inputs)
        assert not apply_node.const_val
        # Run the op outside of tracing so that this apply() is not re-captured.
        unset_module_tracing()
        outputs = apply(opdef, *inputs)
        set_module_tracing()
        apply_node.add_outputs(outputs)
        for n, v in zip(apply_node.outputs, outputs):
            NodeMixin.wrap_safe(v, n)
        return list(outputs)
class CallFunction(Expr):
    r"""``CallFunction`` represents a call to a built-in function.

    Args:
        func: a built-in function.
    """

    def __init__(self, func):
        super().__init__()
        assert isinstance(func, Callable)
        self.func = func
        self.const_val = []
        self.inputs = []

    @classmethod
    def make(cls, *args, **kwargs):
        r"""Create a CallFunction and insert it into the graph currently being traced."""
        assert active_module_tracer() is not None
        expr = cls(*args, **kwargs)
        active_module_tracer().current_scope()._insert(expr)
        return expr

    def interpret(self, *inputs):
        r"""Execute the recorded function on concrete ``inputs`` and return the
        flattened outputs (or ``None`` if the call returned ``None``).
        """
        args, kwargs = self.unflatten_args(inputs)
        # While tracing, call the patched (wrapped) version of the function.
        func = (
            self.func
            if not is_tracing_module()
            else active_module_tracer().patcher.wrap_fn(self.func)
        )
        outputs = func(*args, **kwargs)
        if outputs is None:
            return outputs
        outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor))
        return outputs

    def _get_func(self):
        return self.func

    @property
    def _support_set_args_kwargs(self):
        return True

    def __repr__(self):
        args = ", ".join(str(i) for i in self.args)
        kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
        outputs = self.outputs
        if self.out_def:
            outputs = self.out_def.unflatten(outputs)
        return "%{}:\t{}{}({})".format(
            self._id,
            str(outputs) + " = " if outputs else "",
            self.func.__module__.rsplit(".")[-1] + "." + self.func.__name__,
            ", ".join([args, kwargs]),
        )

    def __getstate__(self):
        state = {
            "_id": self._id,
            "_disable_remove": self._disable_remove,
            # Functions are pickled as (module, qualname) and re-imported on load.
            "func": (self.func.__module__, self.func.__qualname__),
            "const_val": self.const_val,
            "inputs": self.inputs,
            "arg_def": self.arg_def,
            "out_def": self.out_def,
            "outputs": self.outputs,
            "version": __version__,
        }
        _check_obj_attr(state)
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        # Best effort: resolve the (module, qualname) pair back to a callable.
        # NOTE(review): on failure the tuple is silently kept as ``self.func``;
        # downstream code must tolerate that.
        try:
            if isinstance(self.func, tuple):
                mname, fname = self.func
                f = import_module(mname)
                for i in fname.split("."):
                    f = getattr(f, i)
                self.func = f
        except Exception:
            pass
# expr outputs = self.value
class Constant(Expr):
    r"""``Constant`` represents a ``Tensor`` or "Module" which is not the attribute of a Module.

    Args:
        c: a const Tensor or Module.
        name: the name of output Node.
    """

    value = None
    r"""The const Tensor or Module"""

    # TODO: constant cache to reduce the size of dumped model
    _constant_cache = {}

    def __init__(self, c, name: str = "", qualname: str = ""):
        super().__init__()
        assert isinstance(c, (RawTensor, Module))
        # Only builtin modules or qat modules may be captured as constants.
        if isinstance(c, Module):
            assert module_tracer.is_builtin(c) or c.is_qat
        if isinstance(c, RawTensor):
            # Copy the tensor without recording the copy in the traced graph.
            if is_tracing_module():
                unset_module_tracing()
                c = Tensor(c)
                set_module_tracing()
            else:
                c = Tensor(c)
        self.value = c
        self.name = name
        self.inputs = []
        node_cls = NodeMixin.get_wrapped_type(c)
        self.outputs = [
            node_cls(self, name=name, qualname=qualname),
        ]

    @classmethod
    def make(cls, *args, **kwargs):
        r"""Create a Constant, insert it into the graph currently being traced,
        and return its output Node.
        """
        assert active_module_tracer() is not None
        expr = cls(*args, **kwargs)
        current_graph = active_module_tracer().current_scope()
        current_graph._namespace.auto_naming_for_outputs(expr)
        current_graph._insert(expr)
        return expr.outputs[0]

    def interpret(self, *inputs):
        if isinstance(self.value, RawTensor):
            return Const(self.value.numpy())()
        return (self.value,)

    def __repr__(self):
        name = self.name
        # NOTE(review): ``name`` defaults to "" (not None) in __init__, so this
        # branch looks unreachable from normal construction — confirm before
        # relying on it.
        if name is None:
            name = type(self.value)
        node_type = "Module"
        if isinstance(self.outputs[0], TensorNode):
            node_type = "Tensor"
        return "%{}:\t{} = Constant({}) -> ({})".format(
            self._id, self.outputs[0], name, node_type
        )

    def __getstate__(self):
        state = {
            "_id": self._id,
            "_disable_remove": self._disable_remove,
            "value": self.value,
            "name": self.name,
            "inputs": self.inputs,
            "outputs": self.outputs,
        }
        _check_obj_attr(state)
        if isinstance(self.value, RawTensor):
            # Pickle a plain Tensor copy of the value.
            state["value"] = Tensor(self.value)
        if isinstance(self.value, Module) and module_tracer.is_builtin(self.value):
            # Builtin modules are serialized via _ModuleState.
            _check_builtin_module_attr(self.value)
            state["value"] = _ModuleState.get_module_state(self.value)
        return state

    def __setstate__(self, state):
        # Restore any _ModuleState placeholders back to real Modules.
        for k, v in state.items():
            if isinstance(v, _ModuleState):
                state[k] = v.to_module()
        self.__dict__.update(state)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台