You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

expr.py 6.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import collections
  10. from typing import List
  11. from ...core._imperative_rt import OpDef
  12. from ...core._imperative_rt.core2 import Tensor as RawTensor
  13. from ...core._imperative_rt.core2 import apply, set_module_tracing, unset_module_tracing
  14. from ...core.ops.special import Const
  15. from ...tensor import Tensor
  16. from .module_tracer import active_module_tracer
  17. from .node import ModuleNode, Node, NodeMixin, TensorNode
  18. class Expr:
  19. """
  20. ``Expr`` represents the operations(i.e. Call, Apply, GetAttr, Input, Constant) on ``Node``.
  21. """
  22. inputs = None # type: List[Node]
  23. outputs = None # type: List[Node]
  24. # expr: None (i.e. fake expression which is used to mark input)
  25. class Input(Expr):
  26. name = None
  27. def __init__(self, name=None, type=None):
  28. self.inputs = []
  29. node_cls = type if type else Node
  30. self.outputs = [
  31. node_cls(self, name=name),
  32. ]
  33. self.name = name
  34. @classmethod
  35. def make(cls, *args, **kwargs):
  36. expr = cls(*args, **kwargs)
  37. active_module_tracer().current_scope().add_input(expr.outputs[0])
  38. return expr.outputs[0]
  39. def __repr__(self):
  40. return "{} = Input({})".format(self.outputs[0], self.name)
  41. # expr: outputs = getattr(inputs[0], self.name)
  42. class GetAttr(Expr):
  43. name = None
  44. def __init__(self, module, name, type=None):
  45. assert isinstance(module, ModuleNode)
  46. self.inputs = [
  47. module,
  48. ]
  49. self.name = name
  50. node_cls = type if type else Node
  51. self.outputs = [
  52. node_cls(self),
  53. ]
  54. @classmethod
  55. def make(cls, *args, **kwargs):
  56. expr = cls(*args, **kwargs)
  57. active_module_tracer().current_scope().insert(expr)
  58. expr.outputs[0]._name = expr.name
  59. return expr.outputs[0]
  60. def interpret(self, *inputs):
  61. return (getattr(inputs[0], self.name),)
  62. def __repr__(self):
  63. return '{} = GetAttr({}, "{}")'.format(
  64. self.outputs[0], self.inputs[0], self.name
  65. )
  66. # expr: outputs = inputs[0].__call__(*inputs[1:])
  67. class Call(Expr):
  68. def __init__(self, module):
  69. assert isinstance(module, ModuleNode)
  70. self.inputs = [
  71. module,
  72. ]
  73. def add_input(self, node):
  74. self.inputs.append(node)
  75. def add_outputs(self, references):
  76. self.outputs = []
  77. if not isinstance(references, collections.Sequence):
  78. references = (references,)
  79. for i in references:
  80. self.outputs.append(NodeMixin.get_wrapped_type(i)(self))
  81. @classmethod
  82. def make(cls, *args, **kwargs):
  83. expr = cls(*args, **kwargs)
  84. active_module_tracer().current_scope().insert(expr)
  85. return expr
  86. def interpret(self, *inputs):
  87. mod = inputs[0]
  88. args = inputs[1:]
  89. outputs = mod(*args)
  90. if isinstance(outputs, RawTensor):
  91. outputs = (outputs,)
  92. return outputs
  93. def __repr__(self):
  94. return "{} = Call({})({})".format(
  95. ", ".join(str(i) for i in self.outputs),
  96. self.inputs[0],
  97. ", ".join(str(i) for i in self.inputs[1:]),
  98. )
  99. # expr: outputs = apply(self.opdef, *inputs)
  100. class Apply(Expr):
  101. opdef = None
  102. def __init__(self, opdef):
  103. assert isinstance(opdef, OpDef)
  104. self.opdef = opdef
  105. self.inputs = []
  106. def add_input(self, node):
  107. self.inputs.append(node)
  108. def add_outputs(self, references):
  109. self.outputs = []
  110. if not isinstance(references, collections.Sequence):
  111. references = (references,)
  112. for i in references:
  113. self.outputs.append(NodeMixin.get_wrapped_type(i)(self))
  114. @classmethod
  115. def make(cls, *args, **kwargs):
  116. expr = cls(*args, **kwargs)
  117. active_module_tracer().current_scope().insert(expr)
  118. return expr
  119. def interpret(self, *inputs):
  120. return apply(self.opdef, *inputs)
  121. def __repr__(self):
  122. return "{} = {}({})".format(
  123. ", ".join(str(i) for i in self.outputs),
  124. self.opdef,
  125. ", ".join(str(i) for i in self.inputs),
  126. )
  127. @classmethod
  128. def apply_module_trace_hook(cls, opdef, *inputs):
  129. for i in inputs:
  130. node = NodeMixin.get(i, None)
  131. if node is None: # capture as constant
  132. NodeMixin.wrap_safe(i, Constant.make(i))
  133. apply_node = cls.make(opdef)
  134. for i in inputs:
  135. apply_node.add_input(NodeMixin.get(i))
  136. unset_module_tracing()
  137. outputs = apply(opdef, *inputs)
  138. set_module_tracing()
  139. apply_node.add_outputs(outputs)
  140. for n, v in zip(apply_node.outputs, outputs):
  141. NodeMixin.wrap_safe(v, n)
  142. return list(outputs)
  143. # expr outputs = self.value
  144. class Constant(Expr):
  145. value = None
  146. # TODO: constant cache to reduce the size of dumped model
  147. _constant_cache = {}
  148. def __init__(self, c):
  149. # TODO: type check, since not all types should be captured as constant
  150. self.value = c
  151. self.inputs = []
  152. node_cls = NodeMixin.get_wrapped_type(c)
  153. self.outputs = [
  154. node_cls(self),
  155. ]
  156. @classmethod
  157. def make(cls, *args, **kwargs):
  158. expr = cls(*args, **kwargs)
  159. active_module_tracer().current_scope().insert(expr)
  160. return expr.outputs[0]
  161. def interpret(self, *inputs):
  162. if isinstance(self.value, RawTensor):
  163. return Const(self.value.numpy())()
  164. return (self.value,)
  165. def __repr__(self):
  166. return "{} = Constant({})".format(self.outputs[0], self.value)
  167. def __getstate__(self):
  168. state = self.__dict__.copy()
  169. if isinstance(self.value, RawTensor):
  170. state["value"] = Tensor(self.value)
  171. return state

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。