You can not select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

expr.py 26 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import builtins
  9. import collections
  10. import copy
  11. import inspect
  12. import re
  13. import weakref
  14. from importlib import import_module
  15. from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
  16. from ..core._imperative_rt import OpDef
  17. from ..core._imperative_rt.core2 import Tensor as RawTensor
  18. from ..core._imperative_rt.core2 import (
  19. apply,
  20. is_tracing_module,
  21. set_module_tracing,
  22. unset_module_tracing,
  23. )
  24. from ..core.ops.builtin import FakeQuant
  25. from ..core.ops.special import Const
  26. from ..module import Module
  27. from ..tensor import Parameter, Tensor
  28. from ..version import __version__
  29. from .module_tracer import active_module_tracer, module_tracer
  30. from .node import ModuleNode, Node, NodeMixin, TensorNode
  31. from .pytree import ArgsIndex, TreeDef, _is_const_leaf, _is_leaf, tree_flatten
  32. from .serialization import _ModuleState
  33. from .utils import _check_builtin_module_attr, _check_obj_attr, _convert_kwargs_to_args
  34. def rstrip(s: str, __chars: str):
  35. __chars = re.escape(__chars)
  36. s = re.sub(r"^(?P<left>.*?)(?:%s)+$" % __chars, "\g<left>", s)
  37. return s
  38. def get_suffix_name(prefix: str, name: str):
  39. if prefix == name:
  40. return ""
  41. matchd = re.compile("^%s\.(.*)" % prefix).match(name)
  42. if matchd is None:
  43. return None
  44. return matchd.group(1)
  45. def is_call_module(expr):
  46. return (
  47. isinstance(expr, CallMethod)
  48. and isinstance(expr.inputs[0], ModuleNode)
  49. and expr.method == "__call__"
  50. )
  51. def is_call_tensor_method(expr):
  52. return isinstance(expr, CallMethod) and not is_call_module(expr)
  53. def is_call_function(expr):
  54. return isinstance(expr, CallFunction)
  55. def is_constant(expr):
  56. return isinstance(expr, Constant)
  57. def is_getattr(expr):
  58. return isinstance(expr, GetAttr)
  59. def is_apply_def(expr):
  60. return isinstance(expr, Apply)
  61. def is_input(expr):
  62. return isinstance(expr, Input)
  63. class Expr:
  64. r"""``Expr`` represents the operations (i.e. ``CallMethod``, ``CallFunction``, ``Apply``,
  65. ``GetAttr``, ``Input``, ``Constant``) on ``Node``.
  66. """
  67. inputs = None # type: List[Node]
  68. r"""The input Nodes of this Expr."""
  69. outputs = None # type: List[Node]
  70. r"""The output Nodes of this Expr."""
  71. const_val = None # type: List[Any]
  72. r"""The non-tensor object in the input of the operation."""
  73. arg_def = None # type: TreeDef
  74. r"""The :class:`TreeDef` used to reconstruct the input of the operation."""
  75. out_def = None # type: TreeDef
  76. r"""The :class:`TreeDef` used to reconstruct the output of the operation."""
  77. _top_graph = None # type: weakref.ReferenceType
  78. __total_id = 0
  79. def __init__(self) -> None:
  80. self._id = Expr.__total_id
  81. Expr.__total_id += 1
  82. self._disable_remove = False
  83. def enable_remove(self):
  84. self._disable_remove = False
  85. def disable_remove(self):
  86. self._disable_remove = True
  87. def add_inputs(self, vals):
  88. if not isinstance(vals, collections.abc.Sequence):
  89. vals = (vals,)
  90. for val in vals:
  91. node = NodeMixin.get(val, None)
  92. if isinstance(node, (TensorNode, ModuleNode)):
  93. self.inputs.append(node)
  94. node.users.append(self)
  95. else:
  96. assert node is None
  97. assert not isinstance(val, (Module, RawTensor))
  98. assert _is_leaf(val) and _is_const_leaf(val)
  99. idx = len(self.inputs) + len(self.const_val)
  100. self.const_val.append((idx, val))
  101. def add_outputs(self, outputs):
  102. assert active_module_tracer() is not None
  103. self.outputs = []
  104. if outputs is None:
  105. return
  106. current_graph = active_module_tracer().current_scope()
  107. if not isinstance(outputs, collections.Sequence):
  108. outputs = (outputs,)
  109. for i in outputs:
  110. assert isinstance(i, RawTensor), "The output must be a Tensor"
  111. node = NodeMixin.get_wrapped_type(i)(expr=self, name="", qualname="",)
  112. NodeMixin.wrap_safe(i, node)
  113. self.outputs.append(node)
  114. current_graph._namespace.auto_naming_for_outputs(self)
  115. def unflatten_args(self, inputs):
  116. assert self.arg_def is not None, "{} expr doesn't have args/kwargs".format(
  117. type(self).__name__
  118. )
  119. inputs = list(inputs)
  120. for idx, val in self.const_val:
  121. inputs.insert(idx, val)
  122. args, kwargs = self.arg_def.unflatten(inputs)
  123. return args, kwargs
  124. def replace_inputs(self, repl_dict: Dict[Node, Node]):
  125. r"""Replace the input Nodes of this Expr.
  126. Args:
  127. repl_dict: the map {old_Node: new_Node} that specifies how to replace the input Nodes.
  128. """
  129. while repl_dict:
  130. node, repl_node = repl_dict.popitem()
  131. assert type(node) == type(repl_node)
  132. assert node in self.inputs, "({}) is not in the ({})".format(node, self)
  133. assert (
  134. repl_node.top_graph == node.top_graph
  135. ), "({}) and ({}) are not in the same graph".format(node, repl_node)
  136. graph = self.top_graph
  137. repl_expr_idx = graph._exprs.index(repl_node.expr)
  138. self_idx = graph._exprs.index(self)
  139. assert (
  140. repl_expr_idx < self_idx
  141. ), "({}) must be generated before ({})".format(repl_node, self)
  142. idx = self.inputs.index(node)
  143. self.inputs[idx] = repl_node
  144. node.users.remove(self)
  145. repl_node.users.append(self)
  146. @property
  147. def _support_set_args_kwargs(self):
  148. return False
  149. def set_args_kwargs(self, *args, **kwargs):
  150. r""" Set args and kwargs for Expr.
  151. """
  152. assert (
  153. self._support_set_args_kwargs
  154. ), "Doesn't support set args/kwargs for {} expr".format(type(self).__name__)
  155. args, kwargs = _convert_kwargs_to_args(self._get_func(), args, kwargs)
  156. inputs, arg_def = tree_flatten((args, kwargs))
  157. orig_inputs = self.inputs
  158. self.inputs = []
  159. self.const_val = []
  160. for val in inputs:
  161. if isinstance(val, (TensorNode, ModuleNode)):
  162. self.inputs.append(val)
  163. else:
  164. assert _is_leaf(val) and _is_const_leaf(val)
  165. idx = len(self.inputs) + len(self.const_val)
  166. self.const_val.append((idx, val))
  167. for n in orig_inputs:
  168. if n not in self.inputs:
  169. n.users.remove(self)
  170. for n in self.inputs:
  171. if n not in orig_inputs:
  172. n.users.append(self)
  173. self.arg_def = arg_def
  174. @property
  175. def kwargs(self):
  176. r"""Get the keyword arguments of the operation corresponding to this Expr."""
  177. _, kwargs = self.unflatten_args(self.inputs)
  178. return kwargs
  179. @property
  180. def args(self):
  181. r"""Get the positional arguments of the operation corresponding to this Expr."""
  182. args, _ = self.unflatten_args(self.inputs)
  183. return args
  184. def _get_func(self):
  185. # get called function when the expr is interpreted
  186. raise NotImplementedError
  187. @property
  188. def named_args(self):
  189. func = self._get_func()
  190. return inspect.getcallargs(func, *self.args, **self.kwargs)
  191. def set_arg(self, name, val):
  192. func = self._get_func()
  193. if name in self.kwargs:
  194. new_kwargs = self.kwargs
  195. new_kwargs[name] = val
  196. self.set_args_kwargs(*self.args, **new_kwargs)
  197. else:
  198. arg_spec = inspect.getfullargspec(func)
  199. if name in arg_spec.args:
  200. ind = arg_spec.args.index(name)
  201. new_args = list(self.args)
  202. new_args[ind] = val
  203. self.set_args_kwargs(*new_args)
  204. elif name == arg_spec.varargs:
  205. assert arg_spec.varargs is not None
  206. assert len(self.args) >= len(arg_spec.args)
  207. val = (val,) if not isinstance(val, Sequence) else val
  208. self.set_args_kwargs(*self.args[0 : len(arg_spec.args)], *val)
  209. else:
  210. assert (
  211. arg_spec.varkw is not None
  212. ), "func {} does't have argument named {}".format(func, name)
  213. new_kwargs = self.kwargs
  214. new_kwargs[name] = val
  215. self.set_args_kwargs(*self.args, **new_kwargs)
  216. @property
  217. def return_val(self):
  218. return self.out_def.unflatten(self.outputs)
  219. @return_val.setter
  220. def return_val(self, new_outputs):
  221. outputs, out_def = tree_flatten(
  222. new_outputs, is_leaf=lambda x: isinstance(x, Node)
  223. )
  224. assert all(
  225. isinstance(o, Node) for o in outputs
  226. ), "Return values of expr must be ModuleNode or TensorNode or Container with them"
  227. assert all(
  228. o.expr in (None, self) for o in outputs
  229. ), "Some nodes are produced by other expr, can not be output of expr {}".format(
  230. self
  231. )
  232. self.outputs = outputs
  233. self.out_def = out_def
  234. @property
  235. def top_graph(self):
  236. r"""Get the parent graph of this Expr."""
  237. if self._top_graph:
  238. return self._top_graph()
  239. return None
  240. @classmethod
  241. def _get_next_id(cls):
  242. return cls.__total_id
  243. @classmethod
  244. def _set_next_id(cls, id: int = 0):
  245. assert isinstance(id, int)
  246. cls.__total_id = id
  247. def __copy__(self):
  248. cls = self.__class__
  249. result = cls.__new__(cls)
  250. result.__dict__.update(self.__dict__)
  251. return result
  252. def __deepcopy__(self, memo):
  253. cls = self.__class__
  254. result = cls.__new__(cls)
  255. state = {}
  256. memo[id(self)] = result
  257. for k, v in self.__dict__.items():
  258. if not isinstance(v, weakref.ReferenceType):
  259. state[k] = copy.deepcopy(v, memo)
  260. result.__dict__.update(state)
  261. return result
  262. # expr: None (i.e. fake expression which is used to mark input)
  263. class Input(Expr):
  264. r"""A fake Expr which is used to mark the input of graph."""
  265. name = None
  266. def __init__(self, type: List[Node], name: str = "args", qualname: str = ""):
  267. super().__init__()
  268. assert type in [ModuleNode, TensorNode]
  269. assert name and qualname
  270. self.inputs = []
  271. node_cls = type if type else Node
  272. self.outputs = [
  273. node_cls(self, name=name, qualname=qualname),
  274. ]
  275. self.name = name
  276. @classmethod
  277. def make(cls, *args, **kwargs):
  278. assert active_module_tracer() is not None
  279. current_graph = active_module_tracer().current_scope()
  280. expr = cls(*args, **kwargs)
  281. out_node = expr.outputs[0]
  282. current_graph._namespace.auto_naming_for_outputs(expr)
  283. current_graph._add_input(out_node)
  284. return expr.outputs[0]
  285. def __repr__(self):
  286. return "%{}:\t{} = Input()".format(self._id, self.outputs[0])
  287. def __getstate__(self):
  288. state = {
  289. "_id": self._id,
  290. "_disable_remove": self._disable_remove,
  291. "inputs": self.inputs,
  292. "outputs": self.outputs,
  293. "name": self.name,
  294. }
  295. _check_obj_attr(state)
  296. return state
  297. # expr: outputs = getattr(inputs[0], self.name)
  298. class GetAttr(Expr):
  299. r"""``Getattr`` represents the fetch of an attribute from the ``Module`` hierarchy."""
  300. name = None
  301. r"""name: the qualified name of the attribute to be retrieved."""
  302. def __init__(
  303. self, module: ModuleNode, type: Union[Node], attr_name: str, name: str = "",
  304. ):
  305. super().__init__()
  306. assert isinstance(module, ModuleNode)
  307. assert type in [TensorNode, ModuleNode]
  308. self.inputs = [
  309. module,
  310. ]
  311. module.users.append(self)
  312. self.name = attr_name
  313. self.outputs = [
  314. type(self, name=name, qualname="{}.{}".format(module.qualname, attr_name)),
  315. ]
  316. @classmethod
  317. def make(cls, *args, **kwargs):
  318. assert active_module_tracer() is not None
  319. current_graph = active_module_tracer().current_scope()
  320. expr = cls(*args, **kwargs)
  321. current_graph._namespace.auto_naming_for_outputs(expr)
  322. current_graph._insert(expr)
  323. return expr.outputs[0]
  324. def interpret(self, *inputs):
  325. mod = inputs[0]
  326. module_path, _, name = self.name.rpartition(".")
  327. if module_path == "":
  328. return (getattr(mod, name),)
  329. module_names = module_path.split(".")
  330. for item in module_names:
  331. mod = getattr(mod, item)
  332. if not isinstance(mod, Module):
  333. raise AttributeError("`{}` is not an Module".format(item))
  334. return (getattr(mod, name),)
  335. def __repr__(self):
  336. out_type = "Tensor"
  337. if isinstance(self.outputs[0], ModuleNode):
  338. m_type = self.outputs[0].module_type
  339. out_type = m_type.__name__ if isinstance(m_type, type) else m_type[1]
  340. return '%{}:\t{} = getattr({}, "{}") -> ({})'.format(
  341. self._id, self.outputs[0], self.inputs[0], self.name, out_type
  342. )
  343. def __getstate__(self):
  344. state = {
  345. "_id": self._id,
  346. "_disable_remove": self._disable_remove,
  347. "inputs": self.inputs,
  348. "outputs": self.outputs,
  349. "name": self.name,
  350. }
  351. _check_obj_attr(state)
  352. return state
  353. # expr: outputs = inputs[0].__call__(*inputs[1:])
class CallMethod(Expr):
    r"""``CallMethod`` represents a call to the ``__call__`` method of ``Module`` or a method of ``Tensor``.

    Args:
        node: the Node to be called.
        method: the method name.
            Default: "__call__"
    """

    def __init__(self, node, method="__call__"):
        super().__init__()
        if isinstance(node, type):
            # ``node`` is the Tensor/Parameter class itself (a constructor
            # call): record the class as a constant at argument position 0.
            assert issubclass(node, Tensor)
            cls = Parameter if issubclass(node, Parameter) else Tensor
            self.inputs = []
            self.const_val = [(0, cls)]
        else:
            # ``node`` is the receiver of the method call.
            assert isinstance(node, (TensorNode, ModuleNode))
            node.users.append(self)
            self.inputs = [
                node,
            ]
            self.const_val = []
        # Flatten args=(node,), kwargs={}; only the TreeDef (index 1) is kept.
        self.arg_def = tree_flatten(((node,), {}))[1]
        self.method = method

    @classmethod
    def make(cls, *args, **kwargs):
        # Build the expr and insert it into the graph currently being traced.
        assert active_module_tracer() is not None
        expr = cls(*args, **kwargs)
        active_module_tracer().current_scope()._insert(expr)
        return expr

    @property
    def graph(self):
        # The sub-graph traced for the called Module (if any); None for
        # Tensor-method calls or Modules without a traced graph.
        if isinstance(self.inputs[0], ModuleNode):
            m_node = self.inputs[0]
            if (
                hasattr(m_node.owner, "argdef_graph_map")
                and m_node.owner.argdef_graph_map
            ):
                assert self.arg_def in m_node.owner.argdef_graph_map
                return m_node.owner.argdef_graph_map[self.arg_def]
        return None

    def interpret(self, *inputs):
        # Re-execute the recorded method call with concrete values; returns a
        # flat list of output tensors (or None for None-returning methods).
        args, kwargs = self.unflatten_args(inputs)
        obj = args[0]
        meth = getattr(obj, self.method)
        if inspect.ismethod(meth):
            # Bound method: the receiver is passed implicitly, drop it.
            args = args[1:]
        outputs = getattr(obj, self.method)(*args, **kwargs)
        if self.method == "__setitem__":
            # __setitem__ returns None; the mutated object is the real output.
            outputs = obj
        if outputs is None:
            return outputs
        outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor))
        return outputs

    def _get_func(self):
        # Resolve the callable this expr represents, for signature inspection
        # by ``set_args_kwargs``/``named_args``.
        if isinstance(self.args[0], type):
            obj_type = self.args[0]
        elif isinstance(self.args[0], ModuleNode):
            obj_type = self.args[0].module_type
        else:
            assert isinstance(self.args[0], TensorNode)
            obj_type = Tensor
        meth = getattr(
            obj_type, "forward" if issubclass(obj_type, Module) else self.method
        )
        return meth

    @property
    def _support_set_args_kwargs(self):
        # only expr call tensor method or builtin module support modify args/kwargs
        return (
            isinstance(self.args[0], (TensorNode, type))
            or self.args[0].module_type is not Module
        )

    def __repr__(self):
        args = ", ".join(str(i) for i in self.args[1:])
        kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
        outputs = self.outputs
        if self.out_def:
            outputs = self.out_def.unflatten(outputs)
        method = ".%s" % self.method
        # ``__call__`` is rendered as a plain call: ``receiver(...)``.
        if method == ".__call__":
            method = ""
        return "%{}:\t{}{}{}({})".format(
            self._id,
            str(outputs) + " = " if outputs else "",
            self.args[0],
            method,
            ", ".join([args, kwargs]),
        )

    def __getstate__(self):
        # Pickle support; ``version`` records the MegEngine version the state
        # was produced with.
        state = {
            "_id": self._id,
            "_disable_remove": self._disable_remove,
            "inputs": self.inputs,
            "const_val": self.const_val,
            "method": self.method,
            "arg_def": self.arg_def,
            "out_def": self.out_def,
            "outputs": self.outputs,
            "version": __version__,
        }
        _check_obj_attr(state)
        return state
  456. # expr: outputs = apply(self.opdef, *inputs)
class Apply(Expr):
    r"""``Apply`` represents a call to :func:`apply`.

    Args:
        opdef: the applied :class:`OpDef`.
    """

    opdef = None  # the OpDef executed by this expr

    def __init__(self, opdef):
        super().__init__()
        assert isinstance(opdef, OpDef)
        self.opdef = opdef
        self.inputs = []

    @classmethod
    def make(cls, *args, **kwargs):
        # Build the expr and insert it into the graph currently being traced.
        assert active_module_tracer() is not None
        expr = cls(*args, **kwargs)
        active_module_tracer().current_scope()._insert(expr)
        return expr

    def interpret(self, *inputs):
        # Re-execute the recorded op on concrete tensors.
        return apply(self.opdef, *inputs)

    def __repr__(self):
        return "%{}:\t{} = {}({})".format(
            self._id,
            ", ".join(str(i) for i in self.outputs),
            self.opdef,
            ", ".join(str(i) for i in self.inputs),
        )

    def __getstate__(self):
        # The OpDef is serialized as its own state dict plus its type, so it
        # can be rebuilt in ``__setstate__``.
        opdef_state = self.opdef.__getstate__()
        opdef_state["opdef_type"] = type(self.opdef)
        state = {
            "_id": self._id,
            "_disable_remove": self._disable_remove,
            "opdef_state": opdef_state,
            "inputs": self.inputs,
            "outputs": self.outputs,
            "version": __version__,
        }
        _check_obj_attr(state)
        return state

    def __setstate__(self, state):
        # compat with mge 1.6
        if "opdef" in state and "opdef_state" not in state:
            opdef_state = state.pop("opdef")
            opdef_state["opdef_type"] = opdef_state.pop("type")
            state["opdef_state"] = opdef_state
        self.__dict__.update(state)
        assert isinstance(state["opdef_state"], dict)
        # Rebuild the OpDef instance from its saved type and state.
        opdef_state = state["opdef_state"].copy()
        opdef_type = opdef_state.pop("opdef_type")
        opdef_obj = opdef_type()
        opdef_obj.__setstate__(opdef_state)
        setattr(self, "opdef", opdef_obj)

    @classmethod
    def apply_module_trace_hook(cls, opdef, *inputs):
        # Tracing hook: record an Apply expr for ``apply(opdef, *inputs)``,
        # wire up input/output Nodes, then run the real op.
        for i in inputs:
            node = NodeMixin.get(i, None)
            if node is None:  # capture as constant
                NodeMixin.wrap_safe(i, Constant.make(i))
        if isinstance(opdef, FakeQuant):
            # For FakeQuant, every input after the first is always captured
            # as a fresh Constant node rather than a traced input.
            inp_nodes = [NodeMixin.get(inputs[0])]
            for i in inputs[1:]:
                node = Constant.make(i)
                inp_nodes.append(node)
            apply_node = cls.make(opdef)
            for n in inp_nodes:
                n.users.append(apply_node)
            apply_node.inputs = inp_nodes
        else:
            apply_node = cls.make(opdef)
            apply_node.add_inputs(inputs)
        assert not apply_node.const_val
        # Run the real op with tracing off so the execution is not re-recorded.
        unset_module_tracing()
        outputs = apply(opdef, *inputs)
        set_module_tracing()
        apply_node.add_outputs(outputs)
        for n, v in zip(apply_node.outputs, outputs):
            NodeMixin.wrap_safe(v, n)
        return list(outputs)
class CallFunction(Expr):
    r"""``CallFunction`` represents a call to a built-in function.

    Args:
        func: a built-in function.
    """

    def __init__(self, func):
        super().__init__()
        assert isinstance(func, Callable)
        self.func = func
        self.const_val = []
        self.inputs = []

    @classmethod
    def make(cls, *args, **kwargs):
        # Build the expr and insert it into the graph currently being traced.
        assert active_module_tracer() is not None
        expr = cls(*args, **kwargs)
        active_module_tracer().current_scope()._insert(expr)
        return expr

    def interpret(self, *inputs):
        # Re-execute the recorded function call with concrete values.
        args, kwargs = self.unflatten_args(inputs)
        # While tracing, call the patched (traceable) version of the function.
        func = (
            self.func
            if not is_tracing_module()
            else active_module_tracer().patcher.wrap_fn(self.func)
        )
        outputs = func(*args, **kwargs)
        if outputs is None:
            return outputs
        outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor))
        return outputs

    def _get_func(self):
        return self.func

    @property
    def _support_set_args_kwargs(self):
        # Function calls always allow their args/kwargs to be modified.
        return True

    def __repr__(self):
        args = ", ".join(str(i) for i in self.args)
        kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
        outputs = self.outputs
        if self.out_def:
            outputs = self.out_def.unflatten(outputs)
        return "%{}:\t{}{}({})".format(
            self._id,
            str(outputs) + " = " if outputs else "",
            self.func.__module__.rsplit(".")[-1] + "." + self.func.__name__,
            ", ".join([args, kwargs]),
        )

    def __getstate__(self):
        # The function is serialized as a (module name, qualified name) pair
        # and re-imported in ``__setstate__``.
        state = {
            "_id": self._id,
            "_disable_remove": self._disable_remove,
            "func": (self.func.__module__, self.func.__qualname__),
            "const_val": self.const_val,
            "inputs": self.inputs,
            "arg_def": self.arg_def,
            "out_def": self.out_def,
            "outputs": self.outputs,
            "version": __version__,
        }
        _check_obj_attr(state)
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        # Best-effort: resolve the (module, qualname) pair back to a callable;
        # on any failure ``self.func`` is deliberately left in tuple form.
        try:
            if isinstance(self.func, tuple):
                mname, fname = self.func
                f = import_module(mname)
                for i in fname.split("."):
                    f = getattr(f, i)
                self.func = f
        except Exception:
            pass
  606. # expr outputs = self.value
class Constant(Expr):
    r"""``Constant`` represents a ``Tensor`` or "Module" which is not the attribute of a Module.

    Args:
        c: a const Tensor or Module.
        name: the name of output Node.
    """

    value = None
    r"""The const Tensor or Module"""
    # TODO: constant cache to reduce the size of dumped model
    _constant_cache = {}

    def __init__(self, c, name: str = "", qualname: str = ""):
        super().__init__()
        assert isinstance(c, (RawTensor, Module))
        if isinstance(c, Module):
            # Only builtin or QAT Modules may be captured as constants.
            assert module_tracer.is_builtin(c) or c.is_qat
        if isinstance(c, RawTensor):
            # Copy into a plain Tensor; temporarily disable tracing so the
            # copy itself is not recorded as an expr.
            if is_tracing_module():
                unset_module_tracing()
                c = Tensor(c)
                set_module_tracing()
            else:
                c = Tensor(c)
        self.value = c
        self.name = name
        self.inputs = []
        node_cls = NodeMixin.get_wrapped_type(c)
        self.outputs = [
            node_cls(self, name=name, qualname=qualname),
        ]

    @classmethod
    def make(cls, *args, **kwargs):
        # Build the expr, name its output and insert it into the current
        # graph; returns the output Node.
        assert active_module_tracer() is not None
        expr = cls(*args, **kwargs)
        current_graph = active_module_tracer().current_scope()
        current_graph._namespace.auto_naming_for_outputs(expr)
        current_graph._insert(expr)
        return expr.outputs[0]

    def interpret(self, *inputs):
        # Tensors are re-materialized via a Const op; Modules are returned
        # as-is. Always returns a 1-tuple.
        if isinstance(self.value, RawTensor):
            return Const(self.value.numpy())()
        return (self.value,)

    def __repr__(self):
        name = self.name
        if name is None:
            name = type(self.value)
        node_type = "Module"
        if isinstance(self.outputs[0], TensorNode):
            node_type = "Tensor"
        return "%{}:\t{} = Constant({}) -> ({})".format(
            self._id, self.outputs[0], name, node_type
        )

    def __getstate__(self):
        state = {
            "_id": self._id,
            "_disable_remove": self._disable_remove,
            "value": self.value,
            "name": self.name,
            "inputs": self.inputs,
            "outputs": self.outputs,
        }
        _check_obj_attr(state)
        # Tensors are serialized as a copy; builtin Modules are reduced to a
        # picklable _ModuleState.
        if isinstance(self.value, RawTensor):
            state["value"] = Tensor(self.value)
        if isinstance(self.value, Module) and module_tracer.is_builtin(self.value):
            _check_builtin_module_attr(self.value)
            state["value"] = _ModuleState.get_module_state(self.value)
        return state

    def __setstate__(self, state):
        # Restore any _ModuleState values back into real Modules.
        for k, v in state.items():
            if isinstance(v, _ModuleState):
                state[k] = v.to_module()
        self.__dict__.update(state)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台