You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

expr.py 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import builtins
  10. import collections
  11. import copy
  12. import inspect
  13. from typing import Callable, Dict, List
  14. from ...core._imperative_rt import OpDef
  15. from ...core._imperative_rt.core2 import Tensor as RawTensor
  16. from ...core._imperative_rt.core2 import apply, set_module_tracing, unset_module_tracing
  17. from ...core.ops.special import Const
  18. from ...module import Module
  19. from ...tensor import Parameter, Tensor
  20. from .module_tracer import active_module_tracer, module_tracer
  21. from .node import ModuleNode, Node, NodeMixin, TensorNode
  22. from .pytree import TreeDef, tree_flatten
  class Expr:
      """
      ``Expr`` represents the operations (i.e. CallMethod, CallFunction, Apply, GetAttr, Input, Constant) on ``Node``.

      Each instance receives a unique, increasing ``_id``. Node arguments are
      kept in ``inputs``. Non-node arguments are kept in ``const_val`` as
      ``(flat_index, value)`` pairs, so ``unflatten_args`` can rebuild the
      original ``(args, kwargs)`` through ``arg_def``.
      """

      __total_id = 0
      inputs = None  # type: List[Node]
      outputs = None  # type: List[Node]
      const_val = None  # type: List[Any]
      arg_def = None  # type: TreeDef
      out_def = None  # type: TreeDef
      _top_graph = None  # type: weakref.ReferenceType

      def __init__(self) -> None:
          self._id = Expr.__total_id
          Expr.__total_id += 1
          self._disable_remove = False

      def enable_remove(self):
          """Clear the ``_disable_remove`` flag."""
          self._disable_remove = False

      def disable_remove(self):
          """Set the ``_disable_remove`` flag."""
          self._disable_remove = True

      def add_inputs(self, vals):
          """Register ``vals`` (a single value or a sequence) as inputs.

          A value already bound to a ``TensorNode``/``ModuleNode`` is appended to
          ``inputs``, and this expr is recorded as one of that node's users.
          Any other value must be unbound. It is stored in ``const_val``
          together with its position in the flattened argument list.
          """
          if not isinstance(vals, collections.abc.Sequence):
              vals = (vals,)
          for val in vals:
              node = NodeMixin.get(val, None)
              if isinstance(node, (TensorNode, ModuleNode)):
                  self.inputs.append(node)
                  node.users.append(self)
              else:
                  assert node is None
                  # The flat position equals the number of inputs and constants seen so far.
                  idx = len(self.inputs) + len(self.const_val)
                  self.const_val.append((idx, val))

      def add_outputs(self, outputs):
          """Create output nodes for ``outputs`` (``None``, a tensor, or a sequence).

          One node is created per tensor, using the node type chosen by
          ``NodeMixin.get_wrapped_type``. Each tensor is then bound to its node
          with ``NodeMixin.wrap_safe``. ``None`` leaves the expr with no outputs.
          """
          self.outputs = []
          if outputs is not None:
              # ``collections.Sequence`` was removed in Python 3.10; use the abc module.
              if not isinstance(outputs, collections.abc.Sequence):
                  outputs = (outputs,)
              for i in outputs:
                  assert isinstance(i, RawTensor)
                  self.outputs.append(NodeMixin.get_wrapped_type(i)(self))
              for i, node in zip(outputs, self.outputs):
                  NodeMixin.wrap_safe(i, node)

      def unflatten_args(self, inputs):
          """Rebuild ``(args, kwargs)`` from flat node ``inputs``.

          The values in ``const_val`` are re-inserted at their recorded flat
          positions before ``arg_def`` unflattens the list. Without an
          ``arg_def``, ``inputs`` is returned unchanged as the positional args.
          """
          if self.arg_def is None:
              return inputs, {}
          flat = list(inputs)
          for idx, val in self.const_val:
              flat.insert(idx, val)
          args, kwargs = self.arg_def.unflatten(flat)
          return args, kwargs

      def _replace_nodes(
          self, repl_dict: "Dict[Node, Node]", nodes: "List[Node]", update_users=True
      ):
          """Substitute entries of ``nodes`` in place according to ``repl_dict``.

          Entries are processed in LIFO order (``popitem``) on a private copy,
          so the caller's dict is left untouched. Only the first occurrence of
          each key in ``nodes`` is replaced.

          Args:
              repl_dict: maps the node to replace to its replacement. Both must
                  have exactly the same type.
              nodes: the list to edit (``self.inputs`` or ``self.outputs``).
              update_users: whether to move this expr from ``node.users`` to
                  ``repl_node.users``. ``users`` holds a node's consumers, so
                  this is only meaningful for inputs.
          """
          pending = dict(repl_dict)
          while pending:
              node, repl_node = pending.popitem()
              assert type(node) == type(repl_node)
              assert node in nodes
              nodes[nodes.index(node)] = repl_node
              if update_users:
                  repl_node.users.append(self)
                  # ``users`` is a list: ``pop(self)`` raised TypeError, so remove by value.
                  node.users.remove(self)

      def replace_inputs(self, repl_dict: "Dict[Node, Node]"):
          """Replace input nodes and move this expr between their ``users`` lists."""
          self._replace_nodes(repl_dict, self.inputs)

      def replace_outputs(self, repl_dict: "Dict[Node, Node]"):
          """Replace output nodes of this expr.

          ``users`` lists are not touched, because this expr produces its
          outputs rather than consuming them.
          NOTE(review): the replacement node's link back to its producing expr
          is not updated here. Confirm that callers handle this.
          """
          self._replace_nodes(repl_dict, self.outputs, update_users=False)

      @property
      def kwargs(self):
          """Keyword arguments rebuilt from ``inputs`` and ``const_val``."""
          _, kwargs = self.unflatten_args(self.inputs)
          return kwargs

      @property
      def args(self):
          """Positional arguments rebuilt from ``inputs`` and ``const_val``."""
          args, _ = self.unflatten_args(self.inputs)
          return args

      @property
      def top_graph(self):
          """Dereference the ``_top_graph`` weak reference, or return ``None`` if unset."""
          if self._top_graph:
              return self._top_graph()
          return None
  # expr: None (i.e. fake expression which is used to mark input)
  class Input(Expr):
      """Placeholder expression that produces one graph-input node.

      It has no inputs. Its single output node is built from ``type`` when
      given, otherwise from ``Node``.
      """

      name = None

      def __init__(self, name=None, type=None):
          super().__init__()
          self.inputs = []
          output_cls = type or Node
          self.outputs = [output_cls(self, name=name)]
          self.name = name

      @classmethod
      def make(cls, *args, **kwargs):
          """Build an ``Input`` and register its output with the active tracer scope.

          Returns the output node rather than the expression itself.
          """
          out_node = cls(*args, **kwargs).outputs[0]
          active_module_tracer().current_scope().add_input(out_node)
          return out_node

      def __repr__(self):
          first_output = self.outputs[0]
          return "%{}: {} = Input({})".format(self._id, first_output, self.name)
  # expr: outputs = getattr(inputs[0], self.name)
  class GetAttr(Expr):
      """Read attribute ``name`` from a module node.

      The single input is the ``ModuleNode``. The single output node is built
      from ``type`` when given, otherwise from ``Node``.
      """

      name = None

      def __init__(self, module, name, type=None):
          super().__init__()
          assert isinstance(module, ModuleNode)
          self.inputs = [module]
          module.users.append(self)
          self.name = name
          output_cls = type or Node
          self.outputs = [output_cls(self)]

      @classmethod
      def make(cls, *args, **kwargs):
          """Build a ``GetAttr``, insert it into the active scope, and return its output.

          The output node's ``_name`` is set to the attribute name.
          """
          expr = cls(*args, **kwargs)
          active_module_tracer().current_scope().insert(expr)
          out_node = expr.outputs[0]
          out_node._name = expr.name
          return out_node

      def interpret(self, *inputs):
          owner = inputs[0]
          return (getattr(owner, self.name),)

      def __repr__(self):
          source, target = self.inputs[0], self.outputs[0]
          return '%{}: {} = GetAttr({}, "{}")'.format(self._id, target, source, self.name)
  # expr: outputs = inputs[0].__call__(*inputs[1:])
  class CallMethod(Expr):
      """Call ``method`` on the first flattened argument.

      ``node`` is either a ``TensorNode``/``ModuleNode`` (an instance call) or a
      ``Tensor`` subclass (a class-level call). In the latter case the class,
      normalised to ``Parameter`` or ``Tensor``, is stored as the constant at
      flat position 0, and ``inputs`` starts out empty.
      """

      def __init__(self, node, method="__call__"):
          super().__init__()
          if isinstance(node, type):
              assert issubclass(node, Tensor)
              cls = Parameter if issubclass(node, Parameter) else Tensor
              self.inputs = []
              self.const_val = [(0, cls)]
          else:
              assert isinstance(node, (TensorNode, ModuleNode))
              node.users.append(self)
              self.inputs = [node]
              self.const_val = []
          self.method = method

      @classmethod
      def make(cls, *args, **kwargs):
          """Build a ``CallMethod`` and insert it into the active tracer scope."""
          expr = cls(*args, **kwargs)
          active_module_tracer().current_scope().insert(expr)
          return expr

      @property
      def graph(self):
          """Return the sub-graph registered for this call's ``arg_def``, or ``None``.

          The lookup uses ``argdef_graph_map`` on the owner of the called
          ``ModuleNode``, when that map exists and is non-empty.
          """
          # A class-level call may have no node inputs, so guard before indexing.
          if self.inputs and isinstance(self.inputs[0], ModuleNode):
              m_node = self.inputs[0]
              argdef_graph_map = getattr(m_node.owner, "argdef_graph_map", None)
              if argdef_graph_map:
                  assert self.arg_def in argdef_graph_map
                  return argdef_graph_map[self.arg_def]
          return None

      def interpret(self, *inputs):
          """Execute the call on concrete values and return the flattened outputs."""
          args, kwargs = self.unflatten_args(inputs)
          obj = args[0]
          meth = getattr(obj, self.method)
          # A bound method already receives ``obj`` implicitly, so drop it from args.
          if inspect.ismethod(meth):
              args = args[1:]
          outputs = meth(*args, **kwargs)
          # ``__setitem__`` mutates ``obj`` in place; report ``obj`` as the result.
          if self.method == "__setitem__":
              outputs = obj
          if outputs is None:
              return outputs
          outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor))
          return outputs

      def __repr__(self):
          # Unflatten once instead of once per ``args``/``kwargs`` property access.
          call_args, call_kwargs = self.unflatten_args(self.inputs)
          args = ", ".join(str(i) for i in call_args[1:])
          kwargs = ", ".join("{}={}".format(k, v) for k, v in call_kwargs.items())
          outputs = self.outputs
          if self.out_def:
              outputs = self.out_def.unflatten(outputs)
          return "%{}: {}{}.{}({})".format(
              self._id,
              str(outputs) + " = " if outputs else "",
              call_args[0],
              self.method,
              # Skip empty parts so no dangling ", " appears.
              ", ".join(part for part in (args, kwargs) if part),
          )
  # expr: outputs = apply(self.opdef, *inputs)
  class Apply(Expr):
      """Apply a raw ``OpDef`` to the input nodes."""

      opdef = None

      def __init__(self, opdef):
          super().__init__()
          assert isinstance(opdef, OpDef)
          self.opdef = opdef
          self.inputs = []
          # Match the sibling exprs: ``add_inputs`` appends non-node values here.
          self.const_val = []

      @classmethod
      def make(cls, *args, **kwargs):
          """Build an ``Apply`` and insert it into the active tracer scope."""
          expr = cls(*args, **kwargs)
          active_module_tracer().current_scope().insert(expr)
          return expr

      def interpret(self, *inputs):
          return apply(self.opdef, *inputs)

      def __repr__(self):
          return "%{}: {} = {}({})".format(
              self._id,
              ", ".join(str(i) for i in self.outputs),
              self.opdef,
              ", ".join(str(i) for i in self.inputs),
          )

      @classmethod
      def apply_module_trace_hook(cls, opdef, *inputs):
          """Record an ``Apply`` of ``opdef`` on ``inputs`` and run the op.

          Steps:
          1. Capture any input not yet bound to a node as a ``Constant``.
          2. Insert the expression into the current scope.
          3. Run ``apply`` with module tracing temporarily disabled.
          4. Bind the resulting tensors to the new output nodes.

          Returns the op outputs as a list.
          """
          for i in inputs:
              node = NodeMixin.get(i, None)
              if node is None:  # capture as constant
                  NodeMixin.wrap_safe(i, Constant.make(i))
          apply_node = cls.make(opdef)
          apply_node.add_inputs(inputs)
          assert not apply_node.const_val
          unset_module_tracing()
          try:
              outputs = apply(opdef, *inputs)
          finally:
              # Re-enable tracing even if ``apply`` raises; otherwise it stays off.
              set_module_tracing()
          apply_node.add_outputs(outputs)
          for n, v in zip(apply_node.outputs, outputs):
              NodeMixin.wrap_safe(v, n)
          return list(outputs)
  class CallFunction(Expr):
      """Call a free function ``func`` with the unflattened arguments."""

      def __init__(self, func):
          super().__init__()
          assert isinstance(func, Callable)
          self.func = func
          self.const_val = []
          self.inputs = []

      @classmethod
      def make(cls, *args, **kwargs):
          """Build a ``CallFunction`` and insert it into the active tracer scope."""
          expr = cls(*args, **kwargs)
          active_module_tracer().current_scope().insert(expr)
          return expr

      def interpret(self, *inputs):
          """Execute ``func`` on concrete values and return the flattened outputs."""
          args, kwargs = self.unflatten_args(inputs)
          outputs = self.func(*args, **kwargs)
          if outputs is None:
              return outputs
          outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor))
          return outputs

      def __repr__(self):
          call_args, call_kwargs = self.unflatten_args(self.inputs)
          args = ", ".join(str(i) for i in call_args)
          kwargs = ", ".join("{}={}".format(k, v) for k, v in call_kwargs.items())
          outputs = self.outputs
          if self.out_def:
              outputs = self.out_def.unflatten(outputs)
          # Some callables (e.g. functools.partial) have no ``__name__``, and some
          # have ``__module__`` set to None; fall back instead of raising.
          module = getattr(self.func, "__module__", None)
          name = getattr(self.func, "__name__", None) or repr(self.func)
          func_name = module + "." + name if module else name
          return "%{}: {}{}({})".format(
              self._id,
              str(outputs) + " = " if outputs else "",
              func_name,
              # Skip empty parts so no dangling ", " appears.
              ", ".join(part for part in (args, kwargs) if part),
          )
  # expr outputs = self.value
  class Constant(Expr):
      """Expression that holds a captured constant: a tensor or a builtin module.

      It has no inputs. Its single output node type is chosen by
      ``NodeMixin.get_wrapped_type`` for the held value.
      """

      value = None
      # TODO: constant cache to reduce the size of dumped model
      _constant_cache = {}

      def __init__(self, c):
          super().__init__()
          assert isinstance(c, (RawTensor, Module))
          # Modules are only accepted when the tracer treats them as builtin.
          assert not isinstance(c, Module) or module_tracer.is_builtin(c)
          self.value = c
          self.inputs = []
          self.outputs = [NodeMixin.get_wrapped_type(c)(self)]

      @classmethod
      def make(cls, *args, **kwargs):
          """Build a ``Constant``, insert it into the active scope, and return its output node."""
          expr = cls(*args, **kwargs)
          active_module_tracer().current_scope().insert(expr)
          return expr.outputs[0]

      def interpret(self, *inputs):
          if not isinstance(self.value, RawTensor):
              return (self.value,)
          return Const(self.value.numpy())()

      def __repr__(self):
          held_type = type(self.value)
          return "%{}: {} = Constant({})".format(self._id, self.outputs[0], held_type)

      def __getstate__(self):
          # Shallow-copy the instance dict; store raw tensors as ``Tensor`` for pickling.
          state = dict(self.__dict__)
          if isinstance(self.value, RawTensor):
              state["value"] = Tensor(self.value)
          return state

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。