
expr.py

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import builtins
import collections
import copy
import inspect
import re
from typing import Callable, Dict, List

from ...core._imperative_rt import OpDef
from ...core._imperative_rt.core2 import Tensor as RawTensor
from ...core._imperative_rt.core2 import apply, set_module_tracing, unset_module_tracing
from ...core.ops.builtin import FakeQuant
from ...core.ops.special import Const
from ...module import Module
from ...tensor import Parameter, Tensor
from .module_tracer import active_module_tracer, module_tracer
from .node import ModuleNode, Node, NodeMixin, TensorNode
from .pytree import ArgsIndex, TreeDef, _is_const_leaf, _is_leaf, tree_flatten
from .serialization import get_opdef_state, load_opdef_from_state


def rstrip(s: str, __chars: str):
    __chars = re.escape(__chars)
    s = re.sub(r"^(?P<left>.*?)(?:%s)+$" % __chars, r"\g<left>", s)
    return s
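
# Unlike ``str.rstrip`` (which strips a *set* of characters), this helper removes
# trailing occurrences of the whole substring. Expected values, shown here for
# illustration only:
#   rstrip("conv_out", "_out")      -> "conv"
#   rstrip("conv_out_out", "_out")  -> "conv"
#   rstrip("convout", "_out")       -> "convout"  # no full "_out" suffix, unchanged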

class Expr:
    """
    ``Expr`` represents the operations (i.e. CallMethod, CallFunction, Apply, GetAttr, Input, Constant) on ``Node``.
    """

    __total_id = 0
    inputs = None  # type: List[Node]
    outputs = None  # type: List[Node]
    const_val = None  # type: List[Any]
    arg_def = None  # type: TreeDef
    out_def = None  # type: TreeDef
    _top_graph = None  # type: weakref.ReferenceType

    def __init__(self) -> None:
        self._id = Expr.__total_id
        Expr.__total_id += 1
        self._disable_remove = False

    def enable_remove(self):
        self._disable_remove = False

    def disable_remove(self):
        self._disable_remove = True

    def add_inputs(self, vals):
        if not isinstance(vals, collections.abc.Sequence):
            vals = (vals,)
        for val in vals:
            node = NodeMixin.get(val, None)
            if isinstance(node, (TensorNode, ModuleNode)):
                self.inputs.append(node)
                node.users.append(self)
            else:
                assert node is None
                assert _is_leaf(val) and _is_const_leaf(val)
                idx = len(self.inputs) + len(self.const_val)
                self.const_val.append((idx, val))

    def add_outputs(self, outputs):
        self.outputs = []
        if outputs is not None:
            if not isinstance(outputs, collections.abc.Sequence):
                outputs = (outputs,)

            name = None
            orig_name = None
            if isinstance(self, CallMethod):
                name = self.inputs[0]._name
                orig_name = self.inputs[0]._orig_name
                assert isinstance(name, str), "The name of ({}) must be a str".format(
                    self.inputs[0]
                )
                assert isinstance(
                    orig_name, str
                ), "The orig_name of ({}) must be a str".format(self.inputs[0])
                name = rstrip(name, "_out")
                if self.method == "__call__":
                    name += "_out"
                    orig_name += "_out"
                else:
                    strip_method = self.method.strip("_")
                    name = "%s_out" % strip_method
                    orig_name = name
            elif isinstance(self, CallFunction):
                name = self.func.__name__ + "_out"
            elif isinstance(self, Apply):
                name = str(self.opdef).lower() + "_out"

            for i in outputs:
                assert isinstance(i, RawTensor), "The output must be a Tensor"
                o_name = (
                    active_module_tracer().current_scope()._create_unique_name(name)
                )
                self.outputs.append(
                    NodeMixin.get_wrapped_type(i)(
                        expr=self,
                        name=o_name,
                        orig_name=orig_name if orig_name else o_name,
                    )
                )

            for i, node in zip(outputs, self.outputs):
                NodeMixin.wrap_safe(i, node)

    def unflatten_args(self, inputs):
        if self.arg_def is not None:
            inputs = list(inputs)
            for idx, val in self.const_val:
                inputs.insert(idx, val)
            args, kwargs = self.arg_def.unflatten(inputs)
            return args, kwargs
        else:
            return inputs, {}
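    # A minimal sketch of how constant leaves are merged back with Node inputs
    # before ``arg_def.unflatten`` rebuilds (args, kwargs); plain Python for
    # illustration, independent of the real TreeDef:
    #   inputs    = [node_a, node_b]         # flattened non-constant inputs
    #   const_val = [(1, 2), (3, "mean")]    # (position, value) of constant leaves
    #   merged = list(inputs)
    #   for idx, val in const_val:
    #       merged.insert(idx, val)
    #   # merged == [node_a, 2, node_b, "mean"]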
    def replace_inputs(self, repl_dict: Dict[Node, Node]):
        while repl_dict:
            node, repl_node = repl_dict.popitem()
            assert type(node) == type(repl_node)
            assert node in self.inputs, "({}) is not in the ({})".format(node, self)
            assert (
                repl_node.top_graph == node.top_graph
            ), "({}) and ({}) are not in the same graph".format(node, repl_node)
            graph = self.top_graph
            repl_expr_idx = graph._exprs.index(repl_node.expr)
            self_idx = graph._exprs.index(self)
            assert (
                repl_expr_idx < self_idx
            ), "({}) must be generated before ({})".format(repl_node, self)
            idx = self.inputs.index(node)
            self.inputs[idx] = repl_node
            user_idx = node.users.index(self)
            assert user_idx >= 0
            node.users.pop(user_idx)
            repl_node.users.append(self)

    @property
    def kwargs(self):
        _, kwargs = self.unflatten_args(self.inputs)
        return kwargs

    @property
    def args(self):
        args, _ = self.unflatten_args(self.inputs)
        return args

    @property
    def top_graph(self):
        if self._top_graph:
            return self._top_graph()
        return None

    def __getstate__(self):
        state = self.__dict__.copy()
        if "_top_graph" in state:
            state.pop("_top_graph")
        return state
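
# Usage sketch (node names are illustrative): to rewire an expr onto another
# node of the same graph,
#   expr.replace_inputs({old_node: new_node})
# both nodes must live in the same graph and ``new_node`` must be produced by an
# earlier expr; otherwise the assertions in ``replace_inputs`` fire.
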
# expr: None (i.e. fake expression which is used to mark input)
class Input(Expr):
    name = None

    def __init__(self, name=None, type=None, orig_name=None):
        super().__init__()
        self.inputs = []
        node_cls = type if type else Node
        if orig_name is None:
            orig_name = name
        self.outputs = [
            node_cls(self, name=name, orig_name=orig_name),
        ]
        self.name = name

    @classmethod
    def make(cls, *args, **kwargs):
        expr = cls(*args, **kwargs)
        oup_node = expr.outputs[0]
        name = (
            active_module_tracer().current_scope()._create_unique_name(oup_node._name)
        )
        oup_node._name = name
        active_module_tracer().current_scope()._add_input(oup_node)
        return expr.outputs[0]

    def __repr__(self):
        return "%{}:\t{} = Input({})".format(self._id, self.outputs[0], self.name)

# expr: outputs = getattr(inputs[0], self.name)
class GetAttr(Expr):
    name = None

    def __init__(self, module, name, type=None, orig_name=None):
        super().__init__()
        assert isinstance(module, ModuleNode)
        self.inputs = [
            module,
        ]
        module.users.append(self)
        self.name = name
        node_cls = type if type else Node
        self.outputs = [
            node_cls(self, name=name, orig_name=orig_name),
        ]

    @classmethod
    def make(cls, *args, **kwargs):
        expr = cls(*args, **kwargs)
        module = expr.inputs[0]
        oup_name = expr.name
        while module._name != "self":
            oup_name = module._name + "_" + oup_name
            module = module.expr.inputs[0]
        oup_name = active_module_tracer().current_scope()._create_unique_name(oup_name)
        expr.outputs[0]._name = oup_name
        active_module_tracer().current_scope()._insert(expr)
        return expr.outputs[0]

    def interpret(self, *inputs):
        return (getattr(inputs[0], self.name),)

    def __repr__(self):
        out_type = "Tensor"
        if isinstance(self.outputs[0], ModuleNode):
            out_type = self.outputs[0].module_type.__name__
        return '%{}:\t{} = getattr({}, "{}") -> ({})'.format(
            self._id, self.outputs[0], self.inputs[0], self.name, out_type
        )
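
# Naming sketch (illustrative): fetching ``weight`` from a submodule node named
# "conv1" (whose parent module node is "self") yields an output node named
# "conv1_weight"; the ``while`` loop in ``make`` prefixes each enclosing module
# name until it reaches the top-level "self".
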
# expr: outputs = inputs[0].__call__(*inputs[1:])
class CallMethod(Expr):
    def __init__(self, node, method="__call__"):
        super().__init__()
        if isinstance(node, type):
            assert issubclass(node, Tensor)
            cls = Parameter if issubclass(node, Parameter) else Tensor
            self.inputs = []
            self.const_val = [(0, cls)]
        else:
            assert isinstance(node, (TensorNode, ModuleNode))
            node.users.append(self)
            self.inputs = [
                node,
            ]
            self.const_val = []
        self.method = method

    @classmethod
    def make(cls, *args, **kwargs):
        expr = cls(*args, **kwargs)
        active_module_tracer().current_scope()._insert(expr)
        return expr

    @property
    def graph(self):
        if isinstance(self.inputs[0], ModuleNode):
            m_node = self.inputs[0]
            if (
                hasattr(m_node.owner, "argdef_graph_map")
                and m_node.owner.argdef_graph_map
            ):
                assert self.arg_def in m_node.owner.argdef_graph_map
                return m_node.owner.argdef_graph_map[self.arg_def]
        return None

    def interpret(self, *inputs):
        args, kwargs = self.unflatten_args(inputs)
        obj = args[0]
        meth = getattr(obj, self.method)
        if inspect.ismethod(meth):
            args = args[1:]
        outputs = getattr(obj, self.method)(*args, **kwargs)
        if self.method == "__setitem__":
            outputs = obj
        if outputs is None:
            return outputs
        outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor))
        return outputs

    def __repr__(self):
        args = ", ".join(str(i) for i in self.args[1:])
        kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
        outputs = self.outputs
        if self.out_def:
            outputs = self.out_def.unflatten(outputs)
        method = ".%s" % self.method
        if method == ".__call__":
            method = ""
        return "%{}:\t{}{}{}({})".format(
            self._id,
            str(outputs) + " = " if outputs else "",
            self.args[0],
            method,
            ", ".join([args, kwargs]),
        )
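
# Interpretation sketch (names are illustrative): for a traced method call
# ``y = x.method(2, axis=1)`` the first input is the Node of ``x``, the constant
# arguments are recorded via ``const_val``, and ``interpret`` replays
# ``x.method(2, axis=1)``. For ``x[idx] = v`` (method "__setitem__") the method
# itself returns None, so the mutated ``x`` is used as the output instead.
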
# expr: outputs = apply(self.opdef, *inputs)
class Apply(Expr):
    opdef = None

    def __init__(self, opdef):
        super().__init__()
        assert isinstance(opdef, OpDef)
        self.opdef = opdef
        self.inputs = []

    @classmethod
    def make(cls, *args, **kwargs):
        expr = cls(*args, **kwargs)
        active_module_tracer().current_scope()._insert(expr)
        return expr

    def interpret(self, *inputs):
        return apply(self.opdef, *inputs)

    def __repr__(self):
        return "%{}:\t{} = {}({})".format(
            self._id,
            ", ".join(str(i) for i in self.outputs),
            self.opdef,
            ", ".join(str(i) for i in self.inputs),
        )

    def __getstate__(self):
        state = super().__getstate__()
        state["opdef"] = get_opdef_state(state["opdef"])
        return state

    def __setstate__(self, state):
        state["opdef"] = load_opdef_from_state(state["opdef"])
        for k, v in state.items():
            setattr(self, k, v)

    @classmethod
    def apply_module_trace_hook(cls, opdef, *inputs):
        for i in inputs:
            node = NodeMixin.get(i, None)
            if node is None:  # capture as constant
                NodeMixin.wrap_safe(i, Constant.make(i))

        if isinstance(opdef, FakeQuant):
            inp_nodes = [NodeMixin.get(inputs[0])]
            for i in inputs[1:]:
                node = Constant.make(i)
                inp_nodes.append(node)
            apply_node = cls.make(opdef)
            for n in inp_nodes:
                n.users.append(apply_node)
            apply_node.inputs = inp_nodes
        else:
            apply_node = cls.make(opdef)
            apply_node.add_inputs(inputs)

        assert not apply_node.const_val

        unset_module_tracing()
        outputs = apply(opdef, *inputs)
        set_module_tracing()

        apply_node.add_outputs(outputs)
        for n, v in zip(apply_node.outputs, outputs):
            NodeMixin.wrap_safe(v, n)
        return list(outputs)
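
# Hook flow, for reference: ``apply_module_trace_hook`` first wraps any input
# tensor that is not yet tracked as a Constant node, records an ``Apply`` expr
# (with the non-data inputs of a FakeQuant op forced to constants), temporarily
# turns module tracing off so the opdef runs eagerly, then wraps the eager
# outputs as the expr's output nodes.
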
class CallFunction(Expr):
    def __init__(self, func):
        super().__init__()
        assert isinstance(func, Callable)
        self.func = func
        self.const_val = []
        self.inputs = []

    @classmethod
    def make(cls, *args, **kwargs):
        expr = cls(*args, **kwargs)
        active_module_tracer().current_scope()._insert(expr)
        return expr

    def interpret(self, *inputs):
        args, kwargs = self.unflatten_args(inputs)
        outputs = self.func(*args, **kwargs)
        if outputs is None:
            return outputs
        outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor))
        return outputs

    def __repr__(self):
        args = ", ".join(str(i) for i in self.args)
        kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
        outputs = self.outputs
        if self.out_def:
            outputs = self.out_def.unflatten(outputs)
        return "%{}:\t{}{}({})".format(
            self._id,
            str(outputs) + " = " if outputs else "",
            self.func.__module__.rsplit(".")[-1] + "." + self.func.__name__,
            ", ".join([args, kwargs]),
        )
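
# Display note: ``__repr__`` prints the function as
# "<last component of func.__module__>.<func.__name__>", so a function defined
# in a module such as ``megengine.functional.nn`` would be shown as "nn.<name>"
# (the exact module path is illustrative and depends on the traced function).
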
# expr: outputs = self.value
class Constant(Expr):
    value = None
    # TODO: constant cache to reduce the size of dumped model
    _constant_cache = {}

    def __init__(self, c, name=None):
        super().__init__()
        assert isinstance(c, (RawTensor, Module))
        if isinstance(c, Module):
            assert module_tracer.is_builtin(c) or c.is_qat
        self.value = c
        self.name = name
        self.inputs = []
        node_cls = NodeMixin.get_wrapped_type(c)
        self.outputs = [
            node_cls(self, name=name, orig_name=name),
        ]
        self.outputs[0]._name = name if name else "const_" + str(self._id)

    @classmethod
    def make(cls, *args, **kwargs):
        expr = cls(*args, **kwargs)
        name = "const_module" if isinstance(expr.value, Module) else "const_tensor"
        full_name = name
        if (
            isinstance(expr.value, RawTensor)
            and id(expr.value) in active_module_tracer().id2name
        ):
            full_name = active_module_tracer().id2name[id(expr.value)]
            scope_name = active_module_tracer().current_scope()._module_name
            if full_name and scope_name:
                full_name = ("self." + full_name)[len(scope_name) + 1 :]
            else:
                full_name = name
        else:
            full_name = name
        name = active_module_tracer().current_scope()._create_unique_name(full_name)
        expr.outputs[0]._name = name
        expr.outputs[0]._orig_name = full_name
        active_module_tracer().current_scope()._insert(expr)
        return expr.outputs[0]

    def interpret(self, *inputs):
        if isinstance(self.value, RawTensor):
            return Const(self.value.numpy())()
        return (self.value,)

    def __repr__(self):
        name = self.name
        if name is None:
            name = type(self.value)
        node_type = "Module"
        if isinstance(self.outputs[0], TensorNode):
            node_type = "Tensor"
        return "%{}:\t{} = Constant({}) -> ({})".format(
            self._id, self.outputs[0], name, node_type
        )

    def __getstate__(self):
        state = self.__dict__.copy()
        if "_top_graph" in state:
            state.pop("_top_graph")
        if isinstance(self.value, RawTensor):
            state["value"] = Tensor(self.value)
        return state
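
# End-to-end sketch of where these exprs come from (assuming the public
# ``trace_module`` entry point of ``megengine.traced_module``; the example
# network and shapes are illustrative only):
#
#   import numpy as np
#   import megengine.functional as F
#   import megengine.module as M
#   from megengine import Tensor
#   from megengine.traced_module import trace_module
#
#   class Net(M.Module):
#       def __init__(self):
#           super().__init__()
#           self.conv = M.Conv2d(3, 8, 3)
#
#       def forward(self, x):
#           return F.relu(self.conv(x))
#
#   tm = trace_module(Net(), Tensor(np.random.random((1, 3, 32, 32)).astype("float32")))
#   print(tm.graph)  # lists Input / GetAttr / CallMethod / CallFunction /
#                    # Constant / Apply exprs describing the forward pass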
