You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

expr.py 8.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import builtins
  10. import collections
  11. from typing import Callable, List
  12. from ...core._imperative_rt import OpDef
  13. from ...core._imperative_rt.core2 import Tensor as RawTensor
  14. from ...core._imperative_rt.core2 import apply, set_module_tracing, unset_module_tracing
  15. from ...core.ops.special import Const
  16. from ...module import Module
  17. from ...tensor import Tensor
  18. from .module_tracer import active_module_tracer
  19. from .node import ModuleNode, Node, NodeMixin, TensorNode
  20. from .pytree import TreeDef
  21. class Expr:
  22. """
  23. ``Expr`` represents the operations(i.e. CallMethod, CallFunction, Apply, GetAttr, Input, Constant) on ``Node``.
  24. """
  25. inputs = None # type: List[Node]
  26. outputs = None # type: List[Node]
  27. const_val = None # type: List[Any]
  28. arg_def = None # type: TreeDef
  29. def add_inputs(self, vals):
  30. if not isinstance(vals, collections.abc.Sequence):
  31. vals = (vals,)
  32. for val in vals:
  33. node = NodeMixin.get(val, None)
  34. if isinstance(node, (TensorNode, ModuleNode)):
  35. if node not in self.inputs:
  36. self.inputs.append(node)
  37. else:
  38. assert node is None
  39. assert type(val) in builtins.__dict__.values()
  40. idx = len(self.inputs) + len(self.const_val)
  41. self.const_val.append((idx, val))
  42. def add_outputs(self, outputs):
  43. self.outputs = []
  44. if not isinstance(outputs, collections.Sequence):
  45. outputs = (outputs,)
  46. for i in outputs:
  47. assert isinstance(i, RawTensor)
  48. self.outputs.append(NodeMixin.get_wrapped_type(i)(self))
  49. for i, node in zip(outputs, self.outputs,):
  50. NodeMixin.wrap_safe(i, node)
  51. def unflatten_args(self, inputs):
  52. if self.arg_def is not None:
  53. inputs = list(inputs)
  54. for idx, val in self.const_val:
  55. inputs.insert(idx, val)
  56. args, kwargs = self.arg_def.unflatten(inputs)
  57. return args, kwargs
  58. else:
  59. return inputs, {}
  60. @property
  61. def kwargs(self):
  62. _, kwargs = self.unflatten_args(self.inputs)
  63. return kwargs
  64. @property
  65. def args(self):
  66. args, _ = self.unflatten_args(self.inputs)
  67. return args
  68. # expr: None (i.e. fake expression which is used to mark input)
  69. class Input(Expr):
  70. name = None
  71. def __init__(self, name=None, type=None):
  72. self.inputs = []
  73. node_cls = type if type else Node
  74. self.outputs = [
  75. node_cls(self, name=name),
  76. ]
  77. self.name = name
  78. @classmethod
  79. def make(cls, *args, **kwargs):
  80. expr = cls(*args, **kwargs)
  81. active_module_tracer().current_scope().add_input(expr.outputs[0])
  82. return expr.outputs[0]
  83. def __repr__(self):
  84. return "{} = Input({})".format(self.outputs[0], self.name)
  85. # expr: outputs = getattr(inputs[0], self.name)
  86. class GetAttr(Expr):
  87. name = None
  88. def __init__(self, module, name, type=None):
  89. assert isinstance(module, ModuleNode)
  90. self.inputs = [
  91. module,
  92. ]
  93. self.name = name
  94. node_cls = type if type else Node
  95. self.outputs = [
  96. node_cls(self),
  97. ]
  98. @classmethod
  99. def make(cls, *args, **kwargs):
  100. expr = cls(*args, **kwargs)
  101. active_module_tracer().current_scope().insert(expr)
  102. expr.outputs[0]._name = expr.name
  103. return expr.outputs[0]
  104. def interpret(self, *inputs):
  105. return (getattr(inputs[0], self.name),)
  106. def __repr__(self):
  107. return '{} = GetAttr({}, "{}")'.format(
  108. self.outputs[0], self.inputs[0], self.name
  109. )
  110. # expr: outputs = inputs[0].__call__(*inputs[1:])
  111. class CallMethod(Expr):
  112. def __init__(self, module, method="__call__"):
  113. assert isinstance(module, (TensorNode, ModuleNode))
  114. self.inputs = [
  115. module,
  116. ]
  117. self.const_val = []
  118. self.method = method
  119. @classmethod
  120. def make(cls, *args, **kwargs):
  121. expr = cls(*args, **kwargs)
  122. active_module_tracer().current_scope().insert(expr)
  123. return expr
  124. def interpret(self, *inputs):
  125. args, kwargs = self.unflatten_args(inputs)
  126. obj = args[0]
  127. args = args[1:]
  128. outputs = getattr(obj, self.method)(*args, **kwargs)
  129. if isinstance(outputs, RawTensor):
  130. outputs = (outputs,)
  131. return outputs
  132. def __repr__(self):
  133. args = ", ".join(str(i) for i in self.args[1:])
  134. kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
  135. return "{} = {}.{}({})".format(
  136. ", ".join(str(i) for i in self.outputs),
  137. self.inputs[0],
  138. self.method,
  139. ", ".join([args, kwargs]),
  140. )
  141. # expr: outputs = apply(self.opdef, *inputs)
  142. class Apply(Expr):
  143. opdef = None
  144. def __init__(self, opdef):
  145. assert isinstance(opdef, OpDef)
  146. self.opdef = opdef
  147. self.inputs = []
  148. @classmethod
  149. def make(cls, *args, **kwargs):
  150. expr = cls(*args, **kwargs)
  151. active_module_tracer().current_scope().insert(expr)
  152. return expr
  153. def interpret(self, *inputs):
  154. return apply(self.opdef, *inputs)
  155. def __repr__(self):
  156. return "{} = {}({})".format(
  157. ", ".join(str(i) for i in self.outputs),
  158. self.opdef,
  159. ", ".join(str(i) for i in self.inputs),
  160. )
  161. @classmethod
  162. def apply_module_trace_hook(cls, opdef, *inputs):
  163. for i in inputs:
  164. node = NodeMixin.get(i, None)
  165. if node is None: # capture as constant
  166. NodeMixin.wrap_safe(i, Constant.make(i))
  167. apply_node = cls.make(opdef)
  168. for i in inputs:
  169. apply_node.add_input(NodeMixin.get(i))
  170. unset_module_tracing()
  171. outputs = apply(opdef, *inputs)
  172. set_module_tracing()
  173. apply_node.add_outputs(outputs)
  174. for n, v in zip(apply_node.outputs, outputs):
  175. NodeMixin.wrap_safe(v, n)
  176. return list(outputs)
  177. class CallFunction(Expr):
  178. def __init__(self, func):
  179. assert isinstance(func, Callable)
  180. self.func = func
  181. self.const_val = []
  182. self.inputs = []
  183. @classmethod
  184. def make(cls, *args, **kwargs):
  185. expr = cls(*args, **kwargs)
  186. active_module_tracer().current_scope().insert(expr)
  187. return expr
  188. def interpret(self, *inputs):
  189. args, kwargs = self.unflatten_args(inputs)
  190. outputs = self.func(*args, **kwargs)
  191. outputs = (
  192. outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,)
  193. )
  194. return outputs
  195. def __repr__(self):
  196. args = ", ".join(str(i) for i in self.args)
  197. kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
  198. return "{} = {}({})".format(
  199. ", ".join(str(i) for i in self.outputs),
  200. self.func.__module__ + "." + self.func.__name__,
  201. ", ".join([args, kwargs]),
  202. )
  203. # expr outputs = self.value
  204. class Constant(Expr):
  205. value = None
  206. # TODO: constant cache to reduce the size of dumped model
  207. _constant_cache = {}
  208. def __init__(self, c):
  209. # TODO: type check, since not all types should be captured as constant
  210. self.value = c
  211. self.inputs = []
  212. node_cls = NodeMixin.get_wrapped_type(c)
  213. self.outputs = [
  214. node_cls(self),
  215. ]
  216. @classmethod
  217. def make(cls, *args, **kwargs):
  218. expr = cls(*args, **kwargs)
  219. active_module_tracer().current_scope().insert(expr)
  220. return expr.outputs[0]
  221. def interpret(self, *inputs):
  222. if isinstance(self.value, RawTensor):
  223. return Const(self.value.numpy())()
  224. return (self.value,)
  225. def __repr__(self):
  226. return "{} = Constant({})".format(self.outputs[0], self.value)
  227. def __getstate__(self):
  228. state = self.__dict__.copy()
  229. if isinstance(self.value, RawTensor):
  230. state["value"] = Tensor(self.value)
  231. return state

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。