You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

expr.py 8.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import collections
  10. from typing import Callable, List
  11. from ...core._imperative_rt import OpDef
  12. from ...core._imperative_rt.core2 import Tensor as RawTensor
  13. from ...core._imperative_rt.core2 import apply, set_module_tracing, unset_module_tracing
  14. from ...core.ops.special import Const
  15. from ...module import Module
  16. from ...tensor import Tensor
  17. from .module_tracer import active_module_tracer
  18. from .node import ModuleNode, Node, NodeMixin, TensorNode
  class Expr:
      """
      ``Expr`` represents the operations(i.e. CallMethod, CallFunction, Apply, GetAttr, Input, Constant) on ``Node``.

      The class-level ``inputs``/``outputs`` defaults are ``None``; instances are
      expected to assign their own lists before ``add_input`` is used.
      """

      inputs = None  # type: List[Node]
      outputs = None  # type: List[Node]

      def add_input(self, node):
          """Append ``node`` to ``self.inputs``."""
          self.inputs.append(node)

      def add_outputs(self, outputs):
          """
          Create one output node per produced value and bind each value to its node.

          :param outputs: a single value or a sequence of values; a non-sequence
              value is treated as a one-element sequence.
          """
          self.outputs = []
          # ``collections.Sequence`` was removed in Python 3.10; the ABC lives in
          # ``collections.abc`` (as the other methods of this class already use).
          if not isinstance(outputs, collections.abc.Sequence):
              outputs = (outputs,)
          for i in outputs:
              self.outputs.append(NodeMixin.get_wrapped_type(i)(self))
          for i, node in zip(outputs, self.outputs):
              NodeMixin.wrap_safe(i, node)

      @staticmethod
      def _is_container(obj):
          """Return True if ``obj`` is a sequence whose elements should be visited recursively.

          ``str``/``bytes`` are treated as atomic values: every element of a ``str`` is
          itself a ``str``, so recursing into one never terminates.
          """
          return isinstance(obj, collections.abc.Sequence) and not isinstance(
              obj, (str, bytes)
          )

      @staticmethod
      def _rebuild(proto, items):
          """Build a container of ``type(proto)`` holding ``items``.

          Named tuples take their fields positionally, so they are rebuilt with ``*items``.
          """
          seq_cls = type(proto)
          if isinstance(proto, tuple) and hasattr(proto, "_fields"):
              return seq_cls(*items)
          return seq_cls(items)

      @classmethod
      def get_args_node(cls, arg):
          """
          Create nodes by ``arg``, which may be a container.
          Return the same structure with arg.
          If ``arg`` was not Tensor or Module, it will be stored as const.

          :param arg: tensor, module or const (or a sequence nesting them).
          :return: the node bound to ``arg``, a container of the same type holding
              the converted elements, or ``arg`` itself for other values.
          """
          if isinstance(arg, (RawTensor, Module)):
              # Tensors/modules not yet bound to a node are captured as constants.
              if not NodeMixin.get(arg, None):
                  NodeMixin.wrap_safe(arg, Constant.make(arg))
              return NodeMixin.get(arg)
          if cls._is_container(arg):
              return cls._rebuild(arg, [cls.get_args_node(a) for a in arg])
          # TODO: assert arg type
          return arg  # as const

      @classmethod
      def get_arg_value(cls, inp_node, node2value):
          """
          Get values from node2value by inp_node, which may be a container.
          Return the same structure with inp_node.
          If ``inp_node`` was not in node2value, it is a const.

          :param inp_node: nodes (or a sequence nesting them).
          :param node2value: dict from node to tensor and module.
          """
          try:
              if inp_node in node2value:
                  return node2value[inp_node]
          except TypeError:
              # Unhashable containers (e.g. a list of nodes) cannot be dict keys;
              # fall through and resolve their elements individually.
              pass
          if cls._is_container(inp_node):
              return cls._rebuild(
                  inp_node, [cls.get_arg_value(i, node2value) for i in inp_node]
              )
          return inp_node
  # expr: None (i.e. fake expression which is used to mark input)
  class Input(Expr):
      """Placeholder expression marking a graph input; it has no input nodes and one output node."""

      name = None

      def __init__(self, name=None, type=None):
          # ``type`` optionally overrides the node class used for the single output.
          output_cls = Node if not type else type
          self.inputs = []
          self.outputs = [output_cls(self, name=name)]
          self.name = name

      @classmethod
      def make(cls, *args, **kwargs):
          """Build an ``Input``, register its output node as an input of the current scope, and return that node."""
          node = cls(*args, **kwargs).outputs[0]
          active_module_tracer().current_scope().add_input(node)
          return node

      def __repr__(self):
          out = self.outputs[0]
          return "{} = Input({})".format(out, self.name)
  # expr: outputs = getattr(inputs[0], self.name)
  class GetAttr(Expr):
      """Read attribute ``name`` from a module node: ``outputs[0] = getattr(inputs[0], name)``."""

      name = None

      def __init__(self, module, name, type=None):
          assert isinstance(module, ModuleNode)
          self.inputs = [module]
          self.name = name
          # ``type`` optionally overrides the node class used for the single output.
          output_cls = Node if not type else type
          self.outputs = [output_cls(self)]

      @classmethod
      def make(cls, *args, **kwargs):
          """Build a ``GetAttr``, insert it into the current scope, name its output after the attribute, and return that output node."""
          expr = cls(*args, **kwargs)
          active_module_tracer().current_scope().insert(expr)
          out = expr.outputs[0]
          out._name = expr.name
          return out

      def interpret(self, *inputs):
          owner = inputs[0]
          return (getattr(owner, self.name),)

      def __repr__(self):
          out, owner = self.outputs[0], self.inputs[0]
          return '{} = GetAttr({}, "{}")'.format(out, owner, self.name)
  # expr: outputs = inputs[0].__call__(*inputs[1:])
  class CallMethod(Expr):
      """
      Call ``method`` (``"__call__"`` by default) on ``inputs[0]`` with ``inputs[1:]``
      as positional arguments and ``kwargs`` as constant keyword arguments.
      """

      def __init__(self, module, method="__call__"):
          """
          :param module: the receiver node (``TensorNode`` or ``ModuleNode``); stored as ``inputs[0]``.
          :param method: name of the method invoked on the receiver.
          """
          assert isinstance(module, (TensorNode, ModuleNode))
          self.inputs = [
              module,
          ]
          self.method = method
          self.arg_names = []
          self.kwargs = {}  # const kwargs

      def add_input(self, node, arg_name=None):
          """
          Append an argument node to the call.

          An argument named ``"self"`` is skipped (presumably because the receiver is
          already ``inputs[0]``). ``arg_name`` is recorded only when given, so
          ``arg_names`` can be shorter than ``inputs[1:]``.
          """
          if arg_name == "self":  # FIXME: <XP>
              return
          self.inputs.append(node)
          if arg_name is not None:
              self.arg_names.append(arg_name)

      @classmethod
      def make(cls, *args, **kwargs):
          """Create a ``CallMethod`` expr, insert it into the current scope and return the expr."""
          expr = cls(*args, **kwargs)
          active_module_tracer().current_scope().insert(expr)
          return expr

      def interpret(self, *inputs):
          """
          Execute the call on concrete values.

          :param inputs: concrete values matching ``self.inputs``; ``inputs[0]`` is the receiver.
          :return: the results as a sequence. A tensor, or any other non-sequence
              result, is wrapped in a 1-tuple so it lines up with the single output
              node ``add_outputs`` creates for a non-sequence value (and with
              ``CallFunction.interpret``).
          """
          mod = inputs[0]
          args = inputs[1:]
          outputs = getattr(mod, self.method)(*args, **self.kwargs)
          if isinstance(outputs, RawTensor) or not isinstance(
              outputs, collections.abc.Sequence
          ):
              outputs = (outputs,)
          return outputs

      def __repr__(self):
          return "{} = CallMethod({}, {})({})".format(
              ", ".join(str(i) for i in self.outputs),
              self.inputs[0],
              self.method,
              ", ".join(str(i) for i in self.inputs[1:]),
          )
  # expr: outputs = apply(self.opdef, *inputs)
  class Apply(Expr):
      """Apply the primitive operator ``opdef`` to the input nodes."""

      opdef = None

      def __init__(self, opdef):
          assert isinstance(opdef, OpDef)
          self.opdef = opdef
          self.inputs = []

      @classmethod
      def make(cls, *args, **kwargs):
          """Create an ``Apply`` expr, insert it into the current scope and return the expr."""
          expr = cls(*args, **kwargs)
          active_module_tracer().current_scope().insert(expr)
          return expr

      def interpret(self, *inputs):
          return apply(self.opdef, *inputs)

      def __repr__(self):
          return "{} = {}({})".format(
              ", ".join(str(i) for i in self.outputs),
              self.opdef,
              ", ".join(str(i) for i in self.inputs),
          )

      @classmethod
      def apply_module_trace_hook(cls, opdef, *inputs):
          """
          Record ``apply(opdef, *inputs)`` as an ``Apply`` expr in the current scope and run it.

          Inputs not yet bound to a node are first captured as ``Constant`` exprs.
          Module tracing is switched off while the op runs and switched back on
          afterwards, even if ``apply`` raises.

          :return: the op's outputs as a list.
          """
          for i in inputs:
              node = NodeMixin.get(i, None)
              if node is None:  # capture as constant
                  NodeMixin.wrap_safe(i, Constant.make(i))
          apply_node = cls.make(opdef)
          for i in inputs:
              apply_node.add_input(NodeMixin.get(i))
          # Run the real op with tracing disabled; restore it in ``finally`` so an
          # exception from ``apply`` cannot leave module tracing permanently off.
          unset_module_tracing()
          try:
              outputs = apply(opdef, *inputs)
          finally:
              set_module_tracing()
          apply_node.add_outputs(outputs)
          for n, v in zip(apply_node.outputs, outputs):
              NodeMixin.wrap_safe(v, n)
          return list(outputs)
  class CallFunction(Expr):
      """
      Call the Python callable ``func``: each input node is passed as the keyword
      argument named in ``arg_names``, together with the constant ``kwargs``.
      """

      def __init__(self, func):
          assert isinstance(func, Callable)
          self.func = func
          self.inputs = []
          self.arg_names = []
          self.kwargs = {}  # const kwargs

      def add_input(self, node, arg_name):
          """Append ``node`` as the argument named ``arg_name``."""
          self.inputs.append(node)
          self.arg_names.append(arg_name)

      @classmethod
      def make(cls, *args, **kwargs):
          """Create a ``CallFunction`` expr, insert it into the current scope and return the expr."""
          expr = cls(*args, **kwargs)
          active_module_tracer().current_scope().insert(expr)
          return expr

      def interpret(self, *inputs):
          """
          Call ``func`` on concrete values.

          :param inputs: concrete values, paired positionally with ``arg_names``.
          :return: the result as a sequence; a non-sequence result is wrapped in a 1-tuple.
          """
          inp_dict = {name: value for value, name in zip(inputs, self.arg_names)}
          outputs = self.func(**inp_dict, **self.kwargs)
          outputs = (
              outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,)
          )
          return outputs

      def _func_name(self):
          """Return ``"module.name"`` for ``self.func``, tolerating callables that lack either part.

          ``functools.partial`` objects and callable instances have no ``__name__``,
          and ``__module__`` may be ``None``; plain string concatenation raised in
          those cases.
          """
          name = getattr(self.func, "__name__", None)
          if name is None:
              return repr(self.func)
          module = getattr(self.func, "__module__", None)
          return name if module is None else "{}.{}".format(module, name)

      def __repr__(self):
          return "{} = {}({})".format(
              ", ".join(str(i) for i in self.outputs),
              self._func_name(),
              ", ".join(str(i) for i in self.inputs),
          )
  # expr outputs = self.value
  class Constant(Expr):
      """Expression with no inputs whose single output is the captured ``value``."""

      value = None
      # TODO: constant cache to reduce the size of dumped model
      _constant_cache = {}

      def __init__(self, c):
          # TODO: type check, since not all types should be captured as constant
          self.value = c
          self.inputs = []
          self.outputs = [NodeMixin.get_wrapped_type(c)(self)]

      @classmethod
      def make(cls, *args, **kwargs):
          """Build a ``Constant``, insert it into the current scope, and return its output node."""
          expr = cls(*args, **kwargs)
          active_module_tracer().current_scope().insert(expr)
          return expr.outputs[0]

      def interpret(self, *inputs):
          captured = self.value
          if not isinstance(captured, RawTensor):
              return (captured,)
          # Tensors are re-created from their numpy data through the ``Const`` op.
          return Const(captured.numpy())()

      def __repr__(self):
          out = self.outputs[0]
          return "{} = Constant({})".format(out, self.value)

      def __getstate__(self):
          """Return a copy of the instance dict for pickling, with a raw-tensor ``value`` wrapped as ``Tensor``."""
          state = dict(self.__dict__)
          if isinstance(self.value, RawTensor):
              state["value"] = Tensor(self.value)
          return state

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台