You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

expr.py 8.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import builtins
  10. import collections
  11. from typing import Callable, List
  12. from ...core._imperative_rt import OpDef
  13. from ...core._imperative_rt.core2 import Tensor as RawTensor
  14. from ...core._imperative_rt.core2 import apply, set_module_tracing, unset_module_tracing
  15. from ...core.ops.special import Const
  16. from ...module import Module
  17. from ...tensor import Tensor
  18. from .module_tracer import active_module_tracer, module_tracer
  19. from .node import ModuleNode, Node, NodeMixin, TensorNode
  20. from .pytree import TreeDef
  21. class Expr:
  22. """
  23. ``Expr`` represents the operations(i.e. CallMethod, CallFunction, Apply, GetAttr, Input, Constant) on ``Node``.
  24. """
  25. inputs = None # type: List[Node]
  26. outputs = None # type: List[Node]
  27. const_val = None # type: List[Any]
  28. arg_def = None # type: TreeDef
  29. def add_inputs(self, vals):
  30. if not isinstance(vals, collections.abc.Sequence):
  31. vals = (vals,)
  32. for val in vals:
  33. node = NodeMixin.get(val, None)
  34. if isinstance(node, (TensorNode, ModuleNode)):
  35. if node not in self.inputs:
  36. self.inputs.append(node)
  37. else:
  38. assert node is None
  39. assert type(val) in builtins.__dict__.values()
  40. idx = len(self.inputs) + len(self.const_val)
  41. self.const_val.append((idx, val))
  42. def add_outputs(self, outputs):
  43. self.outputs = []
  44. if not isinstance(outputs, collections.Sequence):
  45. outputs = (outputs,)
  46. for i in outputs:
  47. assert isinstance(i, RawTensor)
  48. self.outputs.append(NodeMixin.get_wrapped_type(i)(self))
  49. for i, node in zip(outputs, self.outputs,):
  50. NodeMixin.wrap_safe(i, node)
  51. def unflatten_args(self, inputs):
  52. if self.arg_def is not None:
  53. inputs = list(inputs)
  54. for idx, val in self.const_val:
  55. inputs.insert(idx, val)
  56. args, kwargs = self.arg_def.unflatten(inputs)
  57. return args, kwargs
  58. else:
  59. return inputs, {}
  60. @property
  61. def kwargs(self):
  62. _, kwargs = self.unflatten_args(self.inputs)
  63. return kwargs
  64. @property
  65. def args(self):
  66. args, _ = self.unflatten_args(self.inputs)
  67. return args
  68. # expr: None (i.e. fake expression which is used to mark input)
  69. class Input(Expr):
  70. name = None
  71. def __init__(self, name=None, type=None):
  72. self.inputs = []
  73. node_cls = type if type else Node
  74. self.outputs = [
  75. node_cls(self, name=name),
  76. ]
  77. self.name = name
  78. @classmethod
  79. def make(cls, *args, **kwargs):
  80. expr = cls(*args, **kwargs)
  81. active_module_tracer().current_scope().add_input(expr.outputs[0])
  82. return expr.outputs[0]
  83. def __repr__(self):
  84. return "{} = Input({})".format(self.outputs[0], self.name)
  85. # expr: outputs = getattr(inputs[0], self.name)
  86. class GetAttr(Expr):
  87. name = None
  88. def __init__(self, module, name, type=None):
  89. assert isinstance(module, ModuleNode)
  90. self.inputs = [
  91. module,
  92. ]
  93. self.name = name
  94. node_cls = type if type else Node
  95. self.outputs = [
  96. node_cls(self),
  97. ]
  98. @classmethod
  99. def make(cls, *args, **kwargs):
  100. expr = cls(*args, **kwargs)
  101. active_module_tracer().current_scope().insert(expr)
  102. expr.outputs[0]._name = expr.name
  103. return expr.outputs[0]
  104. def interpret(self, *inputs):
  105. return (getattr(inputs[0], self.name),)
  106. def __repr__(self):
  107. return '{} = GetAttr({}, "{}")'.format(
  108. self.outputs[0], self.inputs[0], self.name
  109. )
  110. # expr: outputs = inputs[0].__call__(*inputs[1:])
  111. class CallMethod(Expr):
  112. def __init__(self, module, method="__call__"):
  113. assert isinstance(module, (TensorNode, ModuleNode))
  114. self.inputs = [
  115. module,
  116. ]
  117. self.const_val = []
  118. self.method = method
  119. @classmethod
  120. def make(cls, *args, **kwargs):
  121. expr = cls(*args, **kwargs)
  122. active_module_tracer().current_scope().insert(expr)
  123. return expr
  124. @property
  125. def graph(self):
  126. if isinstance(self.inputs[0], ModuleNode):
  127. m_node = self.inputs[0]
  128. if m_node.argdef_graph_map:
  129. assert self.arg_def in m_node.argdef_graph_map
  130. return m_node.argdef_graph_map[self.arg_def]
  131. return None
  132. def interpret(self, *inputs):
  133. args, kwargs = self.unflatten_args(inputs)
  134. obj = args[0]
  135. args = args[1:]
  136. outputs = getattr(obj, self.method)(*args, **kwargs)
  137. if isinstance(outputs, RawTensor):
  138. outputs = (outputs,)
  139. return outputs
  140. def __repr__(self):
  141. args = ", ".join(str(i) for i in self.args[1:])
  142. kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
  143. return "{} = {}.{}({})".format(
  144. ", ".join(str(i) for i in self.outputs),
  145. self.inputs[0],
  146. self.method,
  147. ", ".join([args, kwargs]),
  148. )
  149. # expr: outputs = apply(self.opdef, *inputs)
  150. class Apply(Expr):
  151. opdef = None
  152. def __init__(self, opdef):
  153. assert isinstance(opdef, OpDef)
  154. self.opdef = opdef
  155. self.inputs = []
  156. @classmethod
  157. def make(cls, *args, **kwargs):
  158. expr = cls(*args, **kwargs)
  159. active_module_tracer().current_scope().insert(expr)
  160. return expr
  161. def interpret(self, *inputs):
  162. return apply(self.opdef, *inputs)
  163. def __repr__(self):
  164. return "{} = {}({})".format(
  165. ", ".join(str(i) for i in self.outputs),
  166. self.opdef,
  167. ", ".join(str(i) for i in self.inputs),
  168. )
  169. @classmethod
  170. def apply_module_trace_hook(cls, opdef, *inputs):
  171. for i in inputs:
  172. node = NodeMixin.get(i, None)
  173. if node is None: # capture as constant
  174. NodeMixin.wrap_safe(i, Constant.make(i))
  175. apply_node = cls.make(opdef)
  176. for i in inputs:
  177. assert isinstance(i, RawTensor)
  178. apply_node.inputs.append(NodeMixin.get(i))
  179. unset_module_tracing()
  180. outputs = apply(opdef, *inputs)
  181. set_module_tracing()
  182. apply_node.add_outputs(outputs)
  183. for n, v in zip(apply_node.outputs, outputs):
  184. NodeMixin.wrap_safe(v, n)
  185. return list(outputs)
  186. class CallFunction(Expr):
  187. def __init__(self, func):
  188. assert isinstance(func, Callable)
  189. self.func = func
  190. self.const_val = []
  191. self.inputs = []
  192. @classmethod
  193. def make(cls, *args, **kwargs):
  194. expr = cls(*args, **kwargs)
  195. active_module_tracer().current_scope().insert(expr)
  196. return expr
  197. def interpret(self, *inputs):
  198. args, kwargs = self.unflatten_args(inputs)
  199. outputs = self.func(*args, **kwargs)
  200. outputs = (
  201. outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,)
  202. )
  203. return outputs
  204. def __repr__(self):
  205. args = ", ".join(str(i) for i in self.args)
  206. kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
  207. return "{} = {}({})".format(
  208. ", ".join(str(i) for i in self.outputs),
  209. self.func.__module__ + "." + self.func.__name__,
  210. ", ".join([args, kwargs]),
  211. )
  212. # expr outputs = self.value
  213. class Constant(Expr):
  214. value = None
  215. # TODO: constant cache to reduce the size of dumped model
  216. _constant_cache = {}
  217. def __init__(self, c):
  218. assert isinstance(c, (RawTensor, Module))
  219. if isinstance(c, Module):
  220. assert module_tracer.is_builtin(c)
  221. self.value = c
  222. self.inputs = []
  223. node_cls = NodeMixin.get_wrapped_type(c)
  224. self.outputs = [
  225. node_cls(self),
  226. ]
  227. @classmethod
  228. def make(cls, *args, **kwargs):
  229. expr = cls(*args, **kwargs)
  230. active_module_tracer().current_scope().insert(expr)
  231. return expr.outputs[0]
  232. def interpret(self, *inputs):
  233. if isinstance(self.value, RawTensor):
  234. return Const(self.value.numpy())()
  235. return (self.value,)
  236. def __repr__(self):
  237. return "{} = Constant({})".format(self.outputs[0], self.value)
  238. def __getstate__(self):
  239. state = self.__dict__.copy()
  240. if isinstance(self.value, RawTensor):
  241. state["value"] = Tensor(self.value)
  242. return state

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台