You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tensor.py 6.6 kB

(line numbers 1–204)
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. from typing import Union
  10. import numpy as np
  11. from .core._imperative_rt import CompNode
  12. from .core._imperative_rt.core2 import Tensor as _Tensor
  13. from .core._imperative_rt.core2 import apply
  14. from .core._trace_option import use_symbolic_shape
  15. from .core._wrap import device as as_device
  16. from .core.ops.builtin import Copy, GetVarShape
  17. from .core.tensor.array_method import ArrayMethodMixin
  18. from .device import _valid_device, get_default_device
  19. from .logger import get_logger
  20. from .utils.deprecation import deprecated
  21. from .utils.naming import auto_naming
  22. class Tensor(_Tensor, ArrayMethodMixin):
  23. r"""
  24. A tensor object represents a multidimensional, homogeneous array of fixed-size items.
  25. """
  26. grad = None
  27. dmap_callback = None
  28. _q_dict = None
  29. def __new__(
  30. cls, data, dtype=None, device=None, is_const=False, no_cache=False, name=None
  31. ):
  32. if device is None:
  33. cn = get_default_device()
  34. elif isinstance(device, str):
  35. if cls.dmap_callback is not None:
  36. cn = CompNode(cls.dmap_callback(device))
  37. else:
  38. cn = CompNode(device)
  39. else:
  40. if isinstance(device, CompNode):
  41. cn = device
  42. else:
  43. cn = device._cn
  44. if isinstance(data, _Tensor):
  45. if dtype is not None:
  46. get_logger().warning(
  47. "dtype does not work when creating a new Tensor with another Tensor"
  48. )
  49. obj = _Tensor.__new__(cls, data)
  50. else:
  51. if isinstance(data, np.ndarray):
  52. if 0 in data.strides:
  53. data = data.squeeze().reshape(data.shape)
  54. obj = _Tensor.__new__(cls, data, dtype, cn, is_const, no_cache, name)
  55. return obj
  56. @property
  57. def shape(self) -> Union[tuple, "Tensor"]:
  58. r"""
  59. Returns a :class:`tuple` or a :class:`~.Tensor` represents tensor dimensions.
  60. .. note::
  61. The shape of a tensor was usually represented by a :class:`tuple`.
  62. But if a tensor was treated as symbolic placeholder with tracing,
  63. it's shape could also be a :class:`~.Tensor`. See :class:`~.trace` for more details.
  64. The shape property is usually used to get the current shape of a tensor,
  65. but may also be used to reshape the tensor in-place by assigning a tuple of tensor dimensions to it.
  66. As with :func:`~.reshape`, one of the new shape dimensions can be -1,
  67. in which case its value is inferred from the size of the tensor and the remaining dimensions.
  68. """
  69. shape = super().shape
  70. if shape == () or not use_symbolic_shape():
  71. return shape
  72. return apply(GetVarShape(), self)[0]
  73. @property
  74. def _tuple_shape(self):
  75. return super().shape
  76. @property
  77. def device(self) -> CompNode:
  78. r"""
  79. Returns a string represents the device a :class:`~.Tensor` storaged on.
  80. """
  81. return super().device
  82. @property
  83. def dtype(self) -> np.dtype:
  84. r"""
  85. Returns a :class:`numpy.dtype` object represents the data type of a :class:`~.Tensor`.
  86. """
  87. return super().dtype
  88. @property
  89. def q_dict(self):
  90. if self._q_dict is None:
  91. self._q_dict = {"mode": None, "scale": None, "zero_point": None}
  92. return self._q_dict
  93. def numpy(self) -> np.ndarray:
  94. r"""
  95. Returns self :class:`~.Tensor` as a :class:`numpy.ndarray`.
  96. """
  97. return super().numpy()
  98. def detach(self):
  99. r"""
  100. Returns a new :class:`~.Tensor`, detached from the current graph.
  101. """
  102. return super().detach()
  103. def _reset(self, other):
  104. super()._reset(other)
  105. def __repr__(self):
  106. piece = "Tensor("
  107. with np.printoptions(precision=4, suppress=True):
  108. piece += "{}".format(str(self.numpy()))
  109. if self.dtype != np.float32:
  110. piece += ", dtype={}".format(np.dtype(self.dtype).name)
  111. piece += ", device={}".format(self.device) + ")"
  112. return piece
  113. @property
  114. def name(self):
  115. return self.c_name
  116. @name.setter
  117. def name(self, name):
  118. self.c_name = name
  119. auto_naming.record_var_name(self._mixin_handle, name)
  120. @deprecated(version="1.0", reason="no need to reuse an existing tensor since 1.0")
  121. def set_value(self, value):
  122. if not isinstance(value, _Tensor):
  123. value = Tensor(value, dtype=self.dtype, device=self.device)
  124. self._reset(value)
  125. @deprecated(version="1.0", reason="use *= 0 instead")
  126. def reset_zero(self):
  127. self *= 0
  128. def to(self, device):
  129. r"""
  130. Copy self :class:`~.Tensor` to specified device. See :func:`~.copy`
  131. """
  132. if isinstance(device, str) and not _valid_device(device):
  133. raise ValueError(
  134. "invalid device name {}. For the correct format of the device name, please refer to the instruction of megengine.device.set_default_device()".format(
  135. device
  136. )
  137. )
  138. cn = as_device(device).to_c()
  139. return apply(Copy(comp_node=cn), self)[0]
  140. @property
  141. def requires_grad(self):
  142. raise AttributeError("requires_grad is reserved for future use")
  143. @requires_grad.setter
  144. def requires_grad(self, value):
  145. raise AttributeError("requires_grad is reserved for future use")
  146. @requires_grad.deleter
  147. def requires_grad(self):
  148. raise AttributeError("requires_grad is reserved for future use")
  149. def __hash__(self):
  150. return id(self)
  151. def __getnewargs__(self):
  152. r""" __getnewargs__ will be called for pickle serialization or deep copy
  153. """
  154. return (self.numpy(), self.dtype, self.device.logical_name)
  155. def __getstate__(self):
  156. r""" __getstate__ will be called for pickle serialization or deep copy
  157. """
  158. state = {
  159. "qdict": self.q_dict,
  160. }
  161. return state
  162. def __setstate__(self, state):
  163. self._q_dict = state.pop("qdict")
  164. tensor = Tensor
  165. class Parameter(Tensor):
  166. r"""
  167. A kind of Tensor that is to be considered a module parameter.
  168. """

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。