You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

node.py 5.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import abc
  10. import weakref
  11. from typing import Any, Dict, List, Tuple, Type
  12. import numpy
  13. from ...core._imperative_rt.core2 import Tensor as RawTensor
  14. from ...module import Module
  15. from ...tensor import Tensor
  16. from .pytree import TreeDef
  17. class Node:
  18. """
  19. ``Node`` represents the variables (Tensor/Module/other python object) used in Module's forward method. They are inputs/outputs of Expr(the operations on variables).
  20. param expr: the Expr which produces the node
  21. param name: the name of the node
  22. """
  23. expr = None
  24. __total_id = 0
  25. _id = None
  26. _name = None
  27. _top_graph = None # type: weakref.ReferenceType
  28. def __init__(self, expr: "Expr", name: str = None):
  29. self.expr = expr
  30. self.users = [] # List[Expr]
  31. self._id = Node.__total_id
  32. Node.__total_id += 1
  33. self._name = name
  34. def __setstate__(self, d):
  35. self.__dict__ = d
  36. Node.__total_id = max(Node.__total_id, self._id) + 1
  37. def __repr__(self):
  38. if self._name is None:
  39. return "%{}".format(self._id)
  40. else:
  41. return "%{}".format(self._name)
  42. @property
  43. def top_graph(self):
  44. if self._top_graph:
  45. return self._top_graph()
  46. return None
  47. class ModuleNode(Node):
  48. """
  49. ``ModuleNode`` represents the Module objects.
  50. Attributes:
  51. module_type: type of the Module correspending to the ModuleNode
  52. graph: the InternalGraph which will be interpreted when call Module's forward method
  53. attr_type_map: record the type of Module's attributes
  54. """
  55. module_type = Module # type: Type[Module]
  56. _owner = None # type: weakref.ReferenceType
  57. def __init__(self, expr: "Expr", name: str = None):
  58. super().__init__(expr, name)
  59. self.actual_mnode = []
  60. def __repr__(self):
  61. if self._name is None:
  62. return "%{}_({})".format(self._id, self.module_type.__name__)
  63. else:
  64. return "%{}_{}({})".format(self._id, self._name, self.module_type.__name__)
  65. def __getstate__(self):
  66. return {
  67. "expr": self.expr,
  68. "users": self.users,
  69. "_id": self._id,
  70. "_name": self._name,
  71. "module_type": self.module_type,
  72. }
  73. @property
  74. def owner(self):
  75. if self._owner:
  76. return self._owner()
  77. return None
  78. class TensorNode(Node):
  79. """
  80. ``TensorNode`` represents the Tensor objects.
  81. """
  82. shape = None # type: Tuple[int]
  83. dtype = None # type: numpy.dtype
  84. def __repr__(self):
  85. if self._name is None:
  86. return "%{}_(Tensor)".format(self._id)
  87. else:
  88. return "%{}_{}(Tensor)".format(self._id, self._name)
  89. class NodeMixin(abc.ABC):
  90. __node = None
  91. @abc.abstractmethod
  92. def _record_wrapped_nodes(self, node):
  93. # record the nodes which had been bound to this NodeMixin
  94. pass
  95. @classmethod
  96. def wrap(cls, value, node):
  97. if isinstance(value, (NodeMixin, RawTensor)):
  98. if isinstance(node, Node):
  99. if isinstance(value, RawTensor):
  100. node.dtype = value.dtype
  101. node.shape = (
  102. value._tuple_shape if isinstance(value, Tensor) else value.shape
  103. )
  104. if isinstance(value, NodeMixin):
  105. value._record_wrapped_nodes(node)
  106. setattr(value, "_NodeMixin__node", node)
  107. else:
  108. assert callable(node)
  109. n = node()
  110. assert isinstance(n, Node)
  111. if isinstance(value, RawTensor):
  112. n.dtype = value.dtype
  113. n.shape = (
  114. value._tuple_shape if isinstance(value, Tensor) else value.shape
  115. )
  116. if isinstance(value, NodeMixin):
  117. value._record_wrapped_nodes(n)
  118. setattr(value, "_NodeMixin__node", n)
  119. @classmethod
  120. def wrap_safe(cls, value, node):
  121. assert isinstance(value, (NodeMixin, RawTensor))
  122. if isinstance(value, RawTensor):
  123. node.dtype = value.dtype
  124. node.shape = (
  125. value._tuple_shape if isinstance(value, Tensor) else value.shape
  126. )
  127. setattr(value, "_NodeMixin__node", node)
  128. if isinstance(value, NodeMixin):
  129. value._record_wrapped_nodes(node)
  130. @classmethod
  131. def get(cls, value, *default):
  132. return getattr(value, "_NodeMixin__node", *default)
  133. @classmethod
  134. def get_wrapped_type(cls, value):
  135. if isinstance(value, RawTensor):
  136. return TensorNode
  137. if isinstance(value, (Module, NodeMixin)):
  138. return ModuleNode
  139. return Node

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。