You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

expr.py 18 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import builtins
  9. import collections
  10. import copy
  11. import inspect
  12. import re
  13. from typing import Callable, Dict, List
  14. from ..core._imperative_rt import OpDef
  15. from ..core._imperative_rt.core2 import Tensor as RawTensor
  16. from ..core._imperative_rt.core2 import apply, set_module_tracing, unset_module_tracing
  17. from ..core.ops.builtin import FakeQuant
  18. from ..core.ops.special import Const
  19. from ..module import Module
  20. from ..tensor import Parameter, Tensor
  21. from .module_tracer import active_module_tracer, module_tracer
  22. from .node import ModuleNode, Node, NodeMixin, TensorNode
  23. from .pytree import ArgsIndex, TreeDef, _is_const_leaf, _is_leaf, tree_flatten
  24. from .serialization import get_opdef_state, load_opdef_from_state
  25. def rstrip(s: str, __chars: str):
  26. __chars = re.escape(__chars)
  27. s = re.sub(r"^(?P<left>.*?)(?:%s)+$" % __chars, "\g<left>", s)
  28. return s
  29. class Expr:
  30. r"""``Expr`` represents the operations (i.e. ``CallMethod``, ``CallFunction``, ``Apply``,
  31. ``GetAttr``, ``Input``, ``Constant``) on ``Node``.
  32. """
  33. inputs = None # type: List[Node]
  34. r"""The input Nodes of this Expr."""
  35. outputs = None # type: List[Node]
  36. r"""The output Nodes of this Expr."""
  37. const_val = None # type: List[Any]
  38. r"""The non-tensor object in the input of the operation."""
  39. arg_def = None # type: TreeDef
  40. r"""The :class:`TreeDef` used to reconstruct the input of the operation."""
  41. out_def = None # type: TreeDef
  42. r"""The :class:`TreeDef` used to reconstruct the output of the operation."""
  43. _top_graph = None # type: weakref.ReferenceType
  44. __total_id = 0
  45. def __init__(self) -> None:
  46. self._id = Expr.__total_id
  47. Expr.__total_id += 1
  48. self._disable_remove = False
  49. def enable_remove(self):
  50. self._disable_remove = False
  51. def disable_remove(self):
  52. self._disable_remove = True
  53. def add_inputs(self, vals):
  54. if not isinstance(vals, collections.abc.Sequence):
  55. vals = (vals,)
  56. for val in vals:
  57. node = NodeMixin.get(val, None)
  58. if isinstance(node, (TensorNode, ModuleNode)):
  59. self.inputs.append(node)
  60. node.users.append(self)
  61. else:
  62. assert node is None
  63. assert _is_leaf(val) and _is_const_leaf(val)
  64. idx = len(self.inputs) + len(self.const_val)
  65. self.const_val.append((idx, val))
  66. def add_outputs(self, outputs):
  67. self.outputs = []
  68. if outputs is not None:
  69. if not isinstance(outputs, collections.Sequence):
  70. outputs = (outputs,)
  71. name = None
  72. orig_name = None
  73. if isinstance(self, CallMethod):
  74. name = self.inputs[0]._name
  75. orig_name = self.inputs[0]._orig_name
  76. assert isinstance(name, str), "The name of ({}) must be a str".format(
  77. self.inputs[0]
  78. )
  79. assert isinstance(
  80. orig_name, str
  81. ), "The orig_name of ({}) must be a str".format(self.inputs[0])
  82. name = rstrip(name, "_out")
  83. if self.method == "__call__":
  84. name += "_out"
  85. orig_name += "_out"
  86. else:
  87. strip_method = self.method.strip("_")
  88. name = "%s_out" % strip_method
  89. orig_name = name
  90. elif isinstance(self, CallFunction):
  91. name = self.func.__name__ + "_out"
  92. elif isinstance(self, Apply):
  93. name = str(self.opdef).lower() + "_out"
  94. for i in outputs:
  95. assert isinstance(i, RawTensor), "The output must be a Tensor"
  96. o_name = (
  97. active_module_tracer().current_scope()._create_unique_name(name)
  98. )
  99. self.outputs.append(
  100. NodeMixin.get_wrapped_type(i)(
  101. expr=self,
  102. name=o_name,
  103. orig_name=orig_name if orig_name else o_name,
  104. )
  105. )
  106. for i, node in zip(outputs, self.outputs,):
  107. NodeMixin.wrap_safe(i, node)
  108. def unflatten_args(self, inputs):
  109. if self.arg_def is not None:
  110. inputs = list(inputs)
  111. for idx, val in self.const_val:
  112. inputs.insert(idx, val)
  113. args, kwargs = self.arg_def.unflatten(inputs)
  114. return args, kwargs
  115. else:
  116. return inputs, {}
  117. def replace_inputs(self, repl_dict: Dict[Node, Node]):
  118. r"""Replace the input Nodes of this Expr.
  119. Args:
  120. repl_dict: the map {old_Node: new_Node} that specifies how to replace the input Nodes.
  121. """
  122. while repl_dict:
  123. node, repl_node = repl_dict.popitem()
  124. assert type(node) == type(repl_node)
  125. assert node in self.inputs, "({}) is not in the ({})".format(node, self)
  126. assert (
  127. repl_node.top_graph == node.top_graph
  128. ), "({}) and ({}) are not in the same graph".format(node, repl_node)
  129. graph = self.top_graph
  130. repl_expr_idx = graph._exprs.index(repl_node.expr)
  131. self_idx = graph._exprs.index(self)
  132. assert (
  133. repl_expr_idx < self_idx
  134. ), "({}) must be generated before ({})".format(repl_node, self)
  135. idx = self.inputs.index(node)
  136. self.inputs[idx] = repl_node
  137. user_idx = node.users.index(self)
  138. assert user_idx >= 0
  139. node.users.pop(user_idx)
  140. repl_node.users.append(self)
  141. @property
  142. def kwargs(self):
  143. r"""Get the the keyword arguments of the operation corresponding to this Expr."""
  144. _, kwargs = self.unflatten_args(self.inputs)
  145. return kwargs
  146. @property
  147. def args(self):
  148. r"""Get the the positional arguments of the operation corresponding to this Expr."""
  149. args, _ = self.unflatten_args(self.inputs)
  150. return args
  151. @property
  152. def top_graph(self):
  153. r"""Get the parent graph of this Expr."""
  154. if self._top_graph:
  155. return self._top_graph()
  156. return None
  157. def __getstate__(self):
  158. state = self.__dict__.copy()
  159. if "_top_graph" in state:
  160. state.pop("_top_graph")
  161. return state
  162. @classmethod
  163. def _get_next_id(cls):
  164. return cls.__total_id
  165. @classmethod
  166. def _set_next_id(cls, id: int = 0):
  167. assert isinstance(id, int)
  168. cls.__total_id = id
  169. # expr: None (i.e. fake expression which is used to mark input)
  170. class Input(Expr):
  171. r"""A fake Expr which is used to mark the input of graph."""
  172. name = None
  173. def __init__(self, name=None, type=None, orig_name=None):
  174. super().__init__()
  175. self.inputs = []
  176. node_cls = type if type else Node
  177. if orig_name is None:
  178. orig_name = name
  179. self.outputs = [
  180. node_cls(self, name=name, orig_name=orig_name),
  181. ]
  182. self.name = name
  183. @classmethod
  184. def make(cls, *args, **kwargs):
  185. expr = cls(*args, **kwargs)
  186. oup_node = expr.outputs[0]
  187. name = (
  188. active_module_tracer().current_scope()._create_unique_name(oup_node._name)
  189. )
  190. oup_node._name = name
  191. active_module_tracer().current_scope()._add_input(oup_node)
  192. return expr.outputs[0]
  193. def __repr__(self):
  194. return "%{}:\t{} = Input()".format(self._id, self.outputs[0])
  195. # expr: outputs = getattr(inputs[0], self.name)
  196. class GetAttr(Expr):
  197. r"""``Getattr`` represents the fetch of an attribute from the ``Module`` hierarchy."""
  198. name = None
  199. r"""name: the qualified name of the attribute to be retrieved."""
  200. def __init__(self, module, name, type=None, orig_name=None):
  201. super().__init__()
  202. assert isinstance(module, ModuleNode)
  203. self.inputs = [
  204. module,
  205. ]
  206. module.users.append(self)
  207. self.name = name
  208. node_cls = type if type else Node
  209. self.outputs = [
  210. node_cls(self, name=name, orig_name=orig_name),
  211. ]
  212. @classmethod
  213. def make(cls, *args, **kwargs):
  214. expr = cls(*args, **kwargs)
  215. module = expr.inputs[0]
  216. oup_name = expr.name
  217. while module._name != "self":
  218. oup_name = module._name + "_" + oup_name
  219. module = module.expr.inputs[0]
  220. oup_name = active_module_tracer().current_scope()._create_unique_name(oup_name)
  221. expr.outputs[0]._name = oup_name
  222. active_module_tracer().current_scope()._insert(expr)
  223. return expr.outputs[0]
  224. def interpret(self, *inputs):
  225. return (getattr(inputs[0], self.name),)
  226. def __repr__(self):
  227. out_type = "Tensor"
  228. if isinstance(self.outputs[0], ModuleNode):
  229. out_type = self.outputs[0].module_type.__name__
  230. return '%{}:\t{} = getattr({}, "{}") -> ({})'.format(
  231. self._id, self.outputs[0], self.inputs[0], self.name, out_type
  232. )
  233. # expr: outputs = inputs[0].__call__(*inputs[1:])
  234. class CallMethod(Expr):
  235. r"""``CallMethod`` represents a call to the ``__call__`` method of ``Module`` or a method of ``Tensor``.
  236. Args:
  237. node: the Node to be called.
  238. method: the method name.
  239. Default: "__call__"
  240. """
  241. def __init__(self, node, method="__call__"):
  242. super().__init__()
  243. if isinstance(node, type):
  244. assert issubclass(node, Tensor)
  245. cls = Parameter if issubclass(node, Parameter) else Tensor
  246. self.inputs = []
  247. self.const_val = [(0, cls)]
  248. else:
  249. assert isinstance(node, (TensorNode, ModuleNode))
  250. node.users.append(self)
  251. self.inputs = [
  252. node,
  253. ]
  254. self.const_val = []
  255. self.method = method
  256. @classmethod
  257. def make(cls, *args, **kwargs):
  258. expr = cls(*args, **kwargs)
  259. active_module_tracer().current_scope()._insert(expr)
  260. return expr
  261. @property
  262. def graph(self):
  263. if isinstance(self.inputs[0], ModuleNode):
  264. m_node = self.inputs[0]
  265. if (
  266. hasattr(m_node.owner, "argdef_graph_map")
  267. and m_node.owner.argdef_graph_map
  268. ):
  269. assert self.arg_def in m_node.owner.argdef_graph_map
  270. return m_node.owner.argdef_graph_map[self.arg_def]
  271. return None
  272. def interpret(self, *inputs):
  273. args, kwargs = self.unflatten_args(inputs)
  274. obj = args[0]
  275. meth = getattr(obj, self.method)
  276. if inspect.ismethod(meth):
  277. args = args[1:]
  278. outputs = getattr(obj, self.method)(*args, **kwargs)
  279. if self.method == "__setitem__":
  280. outputs = obj
  281. if outputs is None:
  282. return outputs
  283. outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor))
  284. return outputs
  285. def __repr__(self):
  286. args = ", ".join(str(i) for i in self.args[1:])
  287. kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
  288. outputs = self.outputs
  289. if self.out_def:
  290. outputs = self.out_def.unflatten(outputs)
  291. method = ".%s" % self.method
  292. if method == ".__call__":
  293. method = ""
  294. return "%{}:\t{}{}{}({})".format(
  295. self._id,
  296. str(outputs) + " = " if outputs else "",
  297. self.args[0],
  298. method,
  299. ", ".join([args, kwargs]),
  300. )
# expr: outputs = apply(self.opdef, *inputs)
class Apply(Expr):
    r"""``Apply`` represents a call to :func:`apply`.

    Args:
        opdef: the applied :class:`OpDef`.
    """

    opdef = None  # the OpDef re-executed by interpret()

    def __init__(self, opdef):
        super().__init__()
        assert isinstance(opdef, OpDef)
        self.opdef = opdef
        self.inputs = []

    @classmethod
    def make(cls, *args, **kwargs):
        # Create the Expr and register it into the graph currently being traced.
        expr = cls(*args, **kwargs)
        active_module_tracer().current_scope()._insert(expr)
        return expr

    def interpret(self, *inputs):
        # Re-execute the recorded OpDef on concrete tensors.
        return apply(self.opdef, *inputs)

    def __repr__(self):
        return "%{}:\t{} = {}({})".format(
            self._id,
            ", ".join(str(i) for i in self.outputs),
            self.opdef,
            ", ".join(str(i) for i in self.inputs),
        )

    def __getstate__(self):
        # OpDef instances are not directly picklable; convert to helper state.
        state = super().__getstate__()
        state["opdef"] = get_opdef_state(state["opdef"])
        return state

    def __setstate__(self, state):
        state["opdef"] = load_opdef_from_state(state["opdef"])
        for k, v in state.items():
            setattr(self, k, v)

    @classmethod
    def apply_module_trace_hook(cls, opdef, *inputs):
        # Tracing hook: records a low-level ``apply`` as an Apply Expr and
        # wraps its outputs with Nodes so later ops can reference them.
        for i in inputs:
            node = NodeMixin.get(i, None)
            if node is None:  # capture as constant
                NodeMixin.wrap_safe(i, Constant.make(i))
        if isinstance(opdef, FakeQuant):
            # For FakeQuant, inputs after the first are forced to be Constants
            # and wired up manually (presumably quantization parameters —
            # TODO(review): confirm against FakeQuant's input signature).
            inp_nodes = [NodeMixin.get(inputs[0])]
            for i in inputs[1:]:
                node = Constant.make(i)
                inp_nodes.append(node)
            apply_node = cls.make(opdef)
            for n in inp_nodes:
                n.users.append(apply_node)
            apply_node.inputs = inp_nodes
        else:
            apply_node = cls.make(opdef)
            apply_node.add_inputs(inputs)
        assert not apply_node.const_val
        # Leave tracing mode so the real op runs untraced, then restore it.
        unset_module_tracing()
        outputs = apply(opdef, *inputs)
        set_module_tracing()
        apply_node.add_outputs(outputs)
        for n, v in zip(apply_node.outputs, outputs):
            NodeMixin.wrap_safe(v, n)
        return list(outputs)
  361. class CallFunction(Expr):
  362. r"""``CallFunction`` represents a call to a built-in function.
  363. Args:
  364. func: a built-in function.
  365. """
  366. def __init__(self, func):
  367. super().__init__()
  368. assert isinstance(func, Callable)
  369. self.func = func
  370. self.const_val = []
  371. self.inputs = []
  372. @classmethod
  373. def make(cls, *args, **kwargs):
  374. expr = cls(*args, **kwargs)
  375. active_module_tracer().current_scope()._insert(expr)
  376. return expr
  377. def interpret(self, *inputs):
  378. args, kwargs = self.unflatten_args(inputs)
  379. outputs = self.func(*args, **kwargs)
  380. if outputs is None:
  381. return outputs
  382. outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor))
  383. return outputs
  384. def __repr__(self):
  385. args = ", ".join(str(i) for i in self.args)
  386. kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
  387. outputs = self.outputs
  388. if self.out_def:
  389. outputs = self.out_def.unflatten(outputs)
  390. return "%{}:\t{}{}({})".format(
  391. self._id,
  392. str(outputs) + " = " if outputs else "",
  393. self.func.__module__.rsplit(".")[-1] + "." + self.func.__name__,
  394. ", ".join([args, kwargs]),
  395. )
  396. # expr outputs = self.value
  397. class Constant(Expr):
  398. r"""``Constant`` represents a ``Tensor`` or "Module" which is not the attribute of a Module.
  399. Args:
  400. c: a const Tensor or Module.
  401. name: the name of output Node.
  402. """
  403. value = None
  404. r"""The const Tensor or Module"""
  405. # TODO: constant cache to reduce the size of dumped model
  406. _constant_cache = {}
  407. def __init__(self, c, name=None):
  408. super().__init__()
  409. assert isinstance(c, (RawTensor, Module))
  410. if isinstance(c, Module):
  411. assert module_tracer.is_builtin(c) or c.is_qat
  412. self.value = c
  413. self.name = name
  414. self.inputs = []
  415. node_cls = NodeMixin.get_wrapped_type(c)
  416. self.outputs = [
  417. node_cls(self, name=name, orig_name=name),
  418. ]
  419. self.outputs[0]._name = name if name else "const_" + str(self._id)
  420. @classmethod
  421. def make(cls, *args, **kwargs):
  422. expr = cls(*args, **kwargs)
  423. name = "const_module" if isinstance(expr.value, Module) else "const_tensor"
  424. full_name = name
  425. if (
  426. isinstance(expr.value, RawTensor)
  427. and id(expr.value) in active_module_tracer().id2name
  428. ):
  429. full_name = active_module_tracer().id2name[id(expr.value)]
  430. scope_name = active_module_tracer().current_scope()._module_name
  431. if full_name and scope_name:
  432. full_name = ("self." + full_name)[len(scope_name) + 1 :]
  433. else:
  434. full_name = name
  435. else:
  436. full_name = name
  437. name = active_module_tracer().current_scope()._create_unique_name(full_name)
  438. expr.outputs[0]._name = name
  439. expr.outputs[0]._orig_name = full_name
  440. active_module_tracer().current_scope()._insert(expr)
  441. return expr.outputs[0]
  442. def interpret(self, *inputs):
  443. if isinstance(self.value, RawTensor):
  444. return Const(self.value.numpy())()
  445. return (self.value,)
  446. def __repr__(self):
  447. name = self.name
  448. if name is None:
  449. name = type(self.value)
  450. node_type = "Module"
  451. if isinstance(self.outputs[0], TensorNode):
  452. node_type = "Tensor"
  453. return "%{}:\t{} = Constant({}) -> ({})".format(
  454. self._id, self.outputs[0], name, node_type
  455. )
  456. def __getstate__(self):
  457. state = self.__dict__.copy()
  458. if "_top_graph" in state:
  459. state.pop("_top_graph")
  460. if isinstance(self.value, RawTensor):
  461. state["value"] = Tensor(self.value)
  462. return state

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台