You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

expr.py 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import builtins
  10. import collections
  11. import copy
  12. import inspect
  13. import re
  14. from typing import Callable, Dict, List
  15. from ...core._imperative_rt import OpDef
  16. from ...core._imperative_rt.core2 import Tensor as RawTensor
  17. from ...core._imperative_rt.core2 import apply, set_module_tracing, unset_module_tracing
  18. from ...core.ops.special import Const
  19. from ...module import Module
  20. from ...tensor import Parameter, Tensor
  21. from .module_tracer import active_module_tracer, module_tracer
  22. from .node import ModuleNode, Node, NodeMixin, TensorNode
  23. from .pytree import ArgsIndex, TreeDef, tree_flatten
  24. def rstrip(s: str, __chars: str):
  25. __chars = re.escape(__chars)
  26. s = re.sub(r"^(?P<left>.*?)(?:%s)+$" % __chars, "\g<left>", s)
  27. return s
  28. def lstrip(s: str, __chars: str):
  29. __chars = re.escape(__chars)
  30. s = re.sub(r"^(?:%s)+(?P<right>.*)$" % __chars, "\g<right>", s)
  31. return s
  32. def strip(s: str, __chars: str):
  33. s = lstrip(rstrip(s, __chars), __chars)
  34. return s
  35. class Expr:
  36. """
  37. ``Expr`` represents the operations(i.e. CallMethod, CallFunction, Apply, GetAttr, Input, Constant) on ``Node``.
  38. """
  39. __total_id = 0
  40. inputs = None # type: List[Node]
  41. outputs = None # type: List[Node]
  42. const_val = None # type: List[Any]
  43. arg_def = None # type: TreeDef
  44. out_def = None # type: TreeDef
  45. _top_graph = None # type: weakref.ReferenceType
  46. def __init__(self) -> None:
  47. self._id = Expr.__total_id
  48. Expr.__total_id += 1
  49. self._disable_remove = False
  50. def enable_remove(self):
  51. self._disable_remove = False
  52. def disable_remove(self):
  53. self._disable_remove = True
  54. def add_inputs(self, vals):
  55. if not isinstance(vals, collections.abc.Sequence):
  56. vals = (vals,)
  57. for val in vals:
  58. node = NodeMixin.get(val, None)
  59. if isinstance(node, (TensorNode, ModuleNode)):
  60. self.inputs.append(node)
  61. node.users.append(self)
  62. else:
  63. assert node is None
  64. idx = len(self.inputs) + len(self.const_val)
  65. self.const_val.append((idx, val))
  66. def add_outputs(self, outputs):
  67. self.outputs = []
  68. if outputs is not None:
  69. if not isinstance(outputs, collections.Sequence):
  70. outputs = (outputs,)
  71. name = None
  72. if isinstance(self, CallMethod):
  73. name = self.inputs[0]._name
  74. assert name is not None
  75. name = rstrip(name, "_out")
  76. if self.method == "__call__":
  77. name += "_out"
  78. else:
  79. strip_method = strip(self.method, "_")
  80. name = "%s_out" % strip_method
  81. elif isinstance(self, CallFunction):
  82. name = self.func.__name__ + "_out"
  83. elif isinstance(self, Apply):
  84. name = str(self.opdef).lower() + "_out"
  85. for i in outputs:
  86. assert isinstance(i, RawTensor)
  87. o_name = (
  88. active_module_tracer().current_scope()._create_unique_name(name)
  89. )
  90. self.outputs.append(
  91. NodeMixin.get_wrapped_type(i)(expr=self, name=o_name)
  92. )
  93. for i, node in zip(outputs, self.outputs,):
  94. NodeMixin.wrap_safe(i, node)
  95. def unflatten_args(self, inputs):
  96. if self.arg_def is not None:
  97. inputs = list(inputs)
  98. for idx, val in self.const_val:
  99. inputs.insert(idx, val)
  100. args, kwargs = self.arg_def.unflatten(inputs)
  101. return args, kwargs
  102. else:
  103. return inputs, {}
  104. def _replace_nodes(self, repl_dict: Dict[Node, Node], nodes: List[Node]):
  105. while repl_dict:
  106. node, repl_node = repl_dict.popitem()
  107. assert type(node) == type(repl_node)
  108. assert node in nodes
  109. index = nodes.index(node)
  110. nodes[index] = repl_node
  111. repl_node.users.append(self)
  112. node.users.pop(self)
  113. def replace_inputs(self, repl_dict: Dict[Node, Node]):
  114. self._replace_nodes(repl_dict, self.inputs)
  115. def replace_outputs(self, repl_dict: Dict[Node, Node]):
  116. self._replace_nodes(repl_dict, self.outputs)
  117. @property
  118. def kwargs(self):
  119. _, kwargs = self.unflatten_args(self.inputs)
  120. return kwargs
  121. @property
  122. def args(self):
  123. args, _ = self.unflatten_args(self.inputs)
  124. return args
  125. @property
  126. def top_graph(self):
  127. if self._top_graph:
  128. return self._top_graph()
  129. return None
  130. # expr: None (i.e. fake expression which is used to mark input)
  131. class Input(Expr):
  132. name = None
  133. def __init__(self, name=None, type=None):
  134. super().__init__()
  135. self.inputs = []
  136. node_cls = type if type else Node
  137. self.outputs = [
  138. node_cls(self, name=name),
  139. ]
  140. self.name = name
  141. @classmethod
  142. def make(cls, *args, **kwargs):
  143. expr = cls(*args, **kwargs)
  144. oup_node = expr.outputs[0]
  145. name = (
  146. active_module_tracer().current_scope()._create_unique_name(oup_node._name)
  147. )
  148. oup_node._name = name
  149. active_module_tracer().current_scope().add_input(oup_node)
  150. return expr.outputs[0]
  151. def __repr__(self):
  152. return "%{}:\t{} = Input({})".format(self._id, self.outputs[0], self.name)
  153. # expr: outputs = getattr(inputs[0], self.name)
  154. class GetAttr(Expr):
  155. name = None
  156. def __init__(self, module, name, type=None):
  157. super().__init__()
  158. assert isinstance(module, ModuleNode)
  159. self.inputs = [
  160. module,
  161. ]
  162. module.users.append(self)
  163. self.name = name
  164. node_cls = type if type else Node
  165. self.outputs = [
  166. node_cls(self, name=name),
  167. ]
  168. @classmethod
  169. def make(cls, *args, **kwargs):
  170. expr = cls(*args, **kwargs)
  171. module = expr.inputs[0]
  172. oup_name = expr.name
  173. while module._name != "self":
  174. oup_name = module._name + "_" + oup_name
  175. module = module.expr.inputs[0]
  176. oup_name = active_module_tracer().current_scope()._create_unique_name(oup_name)
  177. expr.outputs[0]._name = oup_name
  178. active_module_tracer().current_scope().insert(expr)
  179. return expr.outputs[0]
  180. def interpret(self, *inputs):
  181. return (getattr(inputs[0], self.name),)
  182. def __repr__(self):
  183. out_type = "Tensor"
  184. if isinstance(self.outputs[0], ModuleNode):
  185. out_type = self.outputs[0].module_type.__name__
  186. return '%{}:\t{} = getattr({}, "{}") -> ({})'.format(
  187. self._id, self.outputs[0], self.inputs[0], self.name, out_type
  188. )
  189. # expr: outputs = inputs[0].__call__(*inputs[1:])
  190. class CallMethod(Expr):
  191. def __init__(self, node, method="__call__"):
  192. super().__init__()
  193. if isinstance(node, type):
  194. assert issubclass(node, Tensor)
  195. cls = Parameter if issubclass(node, Parameter) else Tensor
  196. self.inputs = []
  197. self.const_val = [(0, cls)]
  198. else:
  199. assert isinstance(node, (TensorNode, ModuleNode))
  200. node.users.append(self)
  201. self.inputs = [
  202. node,
  203. ]
  204. self.const_val = []
  205. self.method = method
  206. @classmethod
  207. def make(cls, *args, **kwargs):
  208. expr = cls(*args, **kwargs)
  209. active_module_tracer().current_scope().insert(expr)
  210. return expr
  211. @property
  212. def graph(self):
  213. if isinstance(self.inputs[0], ModuleNode):
  214. m_node = self.inputs[0]
  215. if (
  216. hasattr(m_node.owner, "argdef_graph_map")
  217. and m_node.owner.argdef_graph_map
  218. ):
  219. assert self.arg_def in m_node.owner.argdef_graph_map
  220. return m_node.owner.argdef_graph_map[self.arg_def]
  221. return None
  222. def interpret(self, *inputs):
  223. args, kwargs = self.unflatten_args(inputs)
  224. obj = args[0]
  225. meth = getattr(obj, self.method)
  226. if inspect.ismethod(meth):
  227. args = args[1:]
  228. outputs = getattr(obj, self.method)(*args, **kwargs)
  229. if self.method == "__setitem__":
  230. outputs = obj
  231. if outputs is None:
  232. return outputs
  233. outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor))
  234. return outputs
  235. def __repr__(self):
  236. args = ", ".join(str(i) for i in self.args[1:])
  237. kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
  238. outputs = self.outputs
  239. if self.out_def:
  240. outputs = self.out_def.unflatten(outputs)
  241. method = ".%s" % self.method
  242. if method == ".__call__":
  243. method = ""
  244. return "%{}:\t{}{}{}({})".format(
  245. self._id,
  246. str(outputs) + " = " if outputs else "",
  247. self.args[0],
  248. method,
  249. ", ".join([args, kwargs]),
  250. )
  251. # expr: outputs = apply(self.opdef, *inputs)
  252. class Apply(Expr):
  253. opdef = None
  254. def __init__(self, opdef):
  255. super().__init__()
  256. assert isinstance(opdef, OpDef)
  257. self.opdef = opdef
  258. self.inputs = []
  259. @classmethod
  260. def make(cls, *args, **kwargs):
  261. expr = cls(*args, **kwargs)
  262. active_module_tracer().current_scope().insert(expr)
  263. return expr
  264. def interpret(self, *inputs):
  265. return apply(self.opdef, *inputs)
  266. def __repr__(self):
  267. return "%{}:\t{} = {}({})".format(
  268. self._id,
  269. ", ".join(str(i) for i in self.outputs),
  270. self.opdef,
  271. ", ".join(str(i) for i in self.inputs),
  272. )
  273. @classmethod
  274. def apply_module_trace_hook(cls, opdef, *inputs):
  275. for i in inputs:
  276. node = NodeMixin.get(i, None)
  277. if node is None: # capture as constant
  278. NodeMixin.wrap_safe(i, Constant.make(i))
  279. apply_node = cls.make(opdef)
  280. apply_node.add_inputs(inputs)
  281. assert not apply_node.const_val
  282. unset_module_tracing()
  283. outputs = apply(opdef, *inputs)
  284. set_module_tracing()
  285. apply_node.add_outputs(outputs)
  286. for n, v in zip(apply_node.outputs, outputs):
  287. NodeMixin.wrap_safe(v, n)
  288. return list(outputs)
  289. class CallFunction(Expr):
  290. def __init__(self, func):
  291. super().__init__()
  292. assert isinstance(func, Callable)
  293. self.func = func
  294. self.const_val = []
  295. self.inputs = []
  296. @classmethod
  297. def make(cls, *args, **kwargs):
  298. expr = cls(*args, **kwargs)
  299. active_module_tracer().current_scope().insert(expr)
  300. return expr
  301. def interpret(self, *inputs):
  302. args, kwargs = self.unflatten_args(inputs)
  303. outputs = self.func(*args, **kwargs)
  304. if outputs is None:
  305. return outputs
  306. outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor))
  307. return outputs
  308. def __repr__(self):
  309. args = ", ".join(str(i) for i in self.args)
  310. kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
  311. outputs = self.outputs
  312. if self.out_def:
  313. outputs = self.out_def.unflatten(outputs)
  314. return "%{}:\t{}{}({})".format(
  315. self._id,
  316. str(outputs) + " = " if outputs else "",
  317. self.func.__module__.rsplit(".")[-1] + "." + self.func.__name__,
  318. ", ".join([args, kwargs]),
  319. )
  320. # expr outputs = self.value
  321. class Constant(Expr):
  322. value = None
  323. # TODO: constant cache to reduce the size of dumped model
  324. _constant_cache = {}
  325. def __init__(self, c, name=None):
  326. super().__init__()
  327. assert isinstance(c, (RawTensor, Module))
  328. if isinstance(c, Module):
  329. assert module_tracer.is_builtin(c)
  330. self.value = c
  331. self.name = name
  332. self.inputs = []
  333. node_cls = NodeMixin.get_wrapped_type(c)
  334. self.outputs = [
  335. node_cls(self, name=name),
  336. ]
  337. @classmethod
  338. def make(cls, *args, **kwargs):
  339. expr = cls(*args, **kwargs)
  340. name = "const_module" if isinstance(expr.value, Module) else "const_tensor"
  341. name = active_module_tracer().current_scope()._create_unique_name(name)
  342. expr.outputs[0]._name = name
  343. active_module_tracer().current_scope().insert(expr)
  344. return expr.outputs[0]
  345. def interpret(self, *inputs):
  346. if isinstance(self.value, RawTensor):
  347. return Const(self.value.numpy())()
  348. return (self.value,)
  349. def __repr__(self):
  350. name = self.name
  351. if name is None:
  352. name = type(self.value)
  353. node_type = "Module"
  354. if isinstance(self.outputs[0], TensorNode):
  355. node_type = "Tensor"
  356. return "%{}:\t{} = Constant({}) -> ({})".format(
  357. self._id, self.outputs[0], name, node_type
  358. )
  359. def __getstate__(self):
  360. state = self.__dict__.copy()
  361. if isinstance(self.value, RawTensor):
  362. state["value"] = Tensor(self.value)
  363. return state

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台