You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

expr.py 9.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import builtins
  10. import collections
  11. import copy
  12. import inspect
  13. from typing import Callable, List
  14. from ...core._imperative_rt import OpDef
  15. from ...core._imperative_rt.core2 import Tensor as RawTensor
  16. from ...core._imperative_rt.core2 import apply, set_module_tracing, unset_module_tracing
  17. from ...core.ops.special import Const
  18. from ...module import Module
  19. from ...tensor import Parameter, Tensor
  20. from .module_tracer import active_module_tracer, module_tracer
  21. from .node import ModuleNode, Node, NodeMixin, TensorNode
  22. from .pytree import TreeDef, tree_flatten
  23. class Expr:
  24. """
  25. ``Expr`` represents the operations(i.e. CallMethod, CallFunction, Apply, GetAttr, Input, Constant) on ``Node``.
  26. """
  27. inputs = None # type: List[Node]
  28. outputs = None # type: List[Node]
  29. const_val = None # type: List[Any]
  30. arg_def = None # type: TreeDef
  31. def add_inputs(self, vals):
  32. if not isinstance(vals, collections.abc.Sequence):
  33. vals = (vals,)
  34. for val in vals:
  35. node = NodeMixin.get(val, None)
  36. if isinstance(node, (TensorNode, ModuleNode)):
  37. self.inputs.append(node)
  38. node.users.append(self)
  39. else:
  40. assert node is None
  41. idx = len(self.inputs) + len(self.const_val)
  42. self.const_val.append((idx, val))
  43. def add_outputs(self, outputs):
  44. self.outputs = []
  45. if outputs is not None:
  46. if not isinstance(outputs, collections.Sequence):
  47. outputs = (outputs,)
  48. for i in outputs:
  49. assert isinstance(i, RawTensor)
  50. self.outputs.append(NodeMixin.get_wrapped_type(i)(self))
  51. for i, node in zip(outputs, self.outputs,):
  52. NodeMixin.wrap_safe(i, node)
  53. def unflatten_args(self, inputs):
  54. if self.arg_def is not None:
  55. inputs = list(inputs)
  56. for idx, val in self.const_val:
  57. inputs.insert(idx, val)
  58. args, kwargs = self.arg_def.unflatten(inputs)
  59. return args, kwargs
  60. else:
  61. return inputs, {}
  62. @property
  63. def kwargs(self):
  64. _, kwargs = self.unflatten_args(self.inputs)
  65. return kwargs
  66. @property
  67. def args(self):
  68. args, _ = self.unflatten_args(self.inputs)
  69. return args
  70. # expr: None (i.e. fake expression which is used to mark input)
  71. class Input(Expr):
  72. name = None
  73. def __init__(self, name=None, type=None):
  74. self.inputs = []
  75. node_cls = type if type else Node
  76. self.outputs = [
  77. node_cls(self, name=name),
  78. ]
  79. self.name = name
  80. @classmethod
  81. def make(cls, *args, **kwargs):
  82. expr = cls(*args, **kwargs)
  83. active_module_tracer().current_scope().add_input(expr.outputs[0])
  84. return expr.outputs[0]
  85. def __repr__(self):
  86. return "{} = Input({})".format(self.outputs[0], self.name)
  87. # expr: outputs = getattr(inputs[0], self.name)
  88. class GetAttr(Expr):
  89. name = None
  90. def __init__(self, module, name, type=None):
  91. assert isinstance(module, ModuleNode)
  92. self.inputs = [
  93. module,
  94. ]
  95. module.users.append(self)
  96. self.name = name
  97. node_cls = type if type else Node
  98. self.outputs = [
  99. node_cls(self),
  100. ]
  101. @classmethod
  102. def make(cls, *args, **kwargs):
  103. expr = cls(*args, **kwargs)
  104. active_module_tracer().current_scope().insert(expr)
  105. expr.outputs[0]._name = expr.name
  106. return expr.outputs[0]
  107. def interpret(self, *inputs):
  108. return (getattr(inputs[0], self.name),)
  109. def __repr__(self):
  110. return '{} = GetAttr({}, "{}")'.format(
  111. self.outputs[0], self.inputs[0], self.name
  112. )
  113. # expr: outputs = inputs[0].__call__(*inputs[1:])
  114. class CallMethod(Expr):
  115. def __init__(self, node, method="__call__"):
  116. if isinstance(node, type):
  117. assert issubclass(node, Tensor)
  118. cls = Parameter if issubclass(node, Parameter) else Tensor
  119. self.inputs = []
  120. self.const_val = [(0, cls)]
  121. else:
  122. assert isinstance(node, (TensorNode, ModuleNode))
  123. node.users.append(self)
  124. self.inputs = [
  125. node,
  126. ]
  127. self.const_val = []
  128. self.method = method
  129. @classmethod
  130. def make(cls, *args, **kwargs):
  131. expr = cls(*args, **kwargs)
  132. active_module_tracer().current_scope().insert(expr)
  133. return expr
  134. @property
  135. def graph(self):
  136. if isinstance(self.inputs[0], ModuleNode):
  137. m_node = self.inputs[0]
  138. if (
  139. hasattr(m_node.owner, "argdef_graph_map")
  140. and m_node.owner.argdef_graph_map
  141. ):
  142. assert self.arg_def in m_node.owner.argdef_graph_map
  143. return m_node.owner.argdef_graph_map[self.arg_def]
  144. return None
  145. def interpret(self, *inputs):
  146. args, kwargs = self.unflatten_args(inputs)
  147. obj = args[0]
  148. meth = getattr(obj, self.method)
  149. if inspect.ismethod(meth):
  150. args = args[1:]
  151. outputs = getattr(obj, self.method)(*args, **kwargs)
  152. if outputs is None:
  153. return outputs
  154. outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor))
  155. return outputs
  156. def __repr__(self):
  157. args = ", ".join(str(i) for i in self.args[1:])
  158. kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
  159. return "{} = {}.{}({})".format(
  160. ", ".join(str(i) for i in self.outputs),
  161. self.args[0],
  162. self.method,
  163. ", ".join([args, kwargs]),
  164. )
  165. # expr: outputs = apply(self.opdef, *inputs)
  166. class Apply(Expr):
  167. opdef = None
  168. def __init__(self, opdef):
  169. assert isinstance(opdef, OpDef)
  170. self.opdef = opdef
  171. self.inputs = []
  172. @classmethod
  173. def make(cls, *args, **kwargs):
  174. expr = cls(*args, **kwargs)
  175. active_module_tracer().current_scope().insert(expr)
  176. return expr
  177. def interpret(self, *inputs):
  178. return apply(self.opdef, *inputs)
  179. def __repr__(self):
  180. return "{} = {}({})".format(
  181. ", ".join(str(i) for i in self.outputs),
  182. self.opdef,
  183. ", ".join(str(i) for i in self.inputs),
  184. )
  185. @classmethod
  186. def apply_module_trace_hook(cls, opdef, *inputs):
  187. for i in inputs:
  188. node = NodeMixin.get(i, None)
  189. if node is None: # capture as constant
  190. NodeMixin.wrap_safe(i, Constant.make(i))
  191. apply_node = cls.make(opdef)
  192. apply_node.add_inputs(inputs)
  193. assert not apply_node.const_val
  194. unset_module_tracing()
  195. outputs = apply(opdef, *inputs)
  196. set_module_tracing()
  197. apply_node.add_outputs(outputs)
  198. for n, v in zip(apply_node.outputs, outputs):
  199. NodeMixin.wrap_safe(v, n)
  200. return list(outputs)
  201. class CallFunction(Expr):
  202. def __init__(self, func):
  203. assert isinstance(func, Callable)
  204. self.func = func
  205. self.const_val = []
  206. self.inputs = []
  207. @classmethod
  208. def make(cls, *args, **kwargs):
  209. expr = cls(*args, **kwargs)
  210. active_module_tracer().current_scope().insert(expr)
  211. return expr
  212. def interpret(self, *inputs):
  213. args, kwargs = self.unflatten_args(inputs)
  214. outputs = self.func(*args, **kwargs)
  215. outputs = (
  216. outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,)
  217. )
  218. return outputs
  219. def __repr__(self):
  220. args = ", ".join(str(i) for i in self.args)
  221. kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
  222. return "{} = {}({})".format(
  223. ", ".join(str(i) for i in self.outputs),
  224. self.func.__module__ + "." + self.func.__name__,
  225. ", ".join([args, kwargs]),
  226. )
  227. # expr outputs = self.value
  228. class Constant(Expr):
  229. value = None
  230. # TODO: constant cache to reduce the size of dumped model
  231. _constant_cache = {}
  232. def __init__(self, c):
  233. assert isinstance(c, (RawTensor, Module))
  234. if isinstance(c, Module):
  235. assert module_tracer.is_builtin(c)
  236. self.value = c
  237. self.inputs = []
  238. node_cls = NodeMixin.get_wrapped_type(c)
  239. self.outputs = [
  240. node_cls(self),
  241. ]
  242. @classmethod
  243. def make(cls, *args, **kwargs):
  244. expr = cls(*args, **kwargs)
  245. active_module_tracer().current_scope().insert(expr)
  246. return expr.outputs[0]
  247. def interpret(self, *inputs):
  248. if isinstance(self.value, RawTensor):
  249. return Const(self.value.numpy())()
  250. return (self.value,)
  251. def __repr__(self):
  252. return "{} = Constant({})".format(self.outputs[0], type(self.value))
  253. def __getstate__(self):
  254. state = self.__dict__.copy()
  255. if isinstance(self.value, RawTensor):
  256. state["value"] = Tensor(self.value)
  257. return state

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台