
tensor.py 8.6 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Union

import numpy as np

from .core._imperative_rt import CompNode
from .core._imperative_rt.core2 import Tensor as _Tensor
from .core._imperative_rt.core2 import apply
from .core._trace_option import use_symbolic_shape
from .core._wrap import as_device
from .core.ops.builtin import Copy, GetVarShape
from .core.tensor.array_method import ArrayMethodMixin
from .device import _valid_device, get_default_device
from .logger import get_logger
from .utils.deprecation import deprecated
from .utils.naming import AutoNaming

logger = get_logger(__name__)


class Tensor(_Tensor, ArrayMethodMixin):
    r"""
    A tensor object represents a multidimensional, homogeneous array of fixed-size items.

    :param data: the value of the returned Tensor.
    :type data: Tensor, :class:`~.numpy.ndarray`, :class:`list` or python number.
    :param dtype: the dtype of the returned Tensor. Uses data's dtype if not specified.
    :param device: the desired device of the returned Tensor. Uses :func:`get_default_device` if not specified.
    :param is_const: whether to make it an ``ImmutableTensor`` in tracing mode.
    :param no_cache: whether to cache it for memory sharing.
    :param name: used to make graph operations on a dumped model more convenient.
    """

    grad = None
    dmap_callback = None
    _qparams = None
    def __new__(
        cls,
        data: Union["Tensor", np.ndarray, list, int, float] = None,
        dtype: np.dtype = None,
        device: str = None,
        is_const: bool = False,
        no_cache: bool = False,
        name: str = None,
    ):
        if data is None:
            data = []
        if device is None:
            cn = get_default_device()
        elif isinstance(device, str):
            if cls.dmap_callback is not None:
                cn = CompNode(cls.dmap_callback(device))
            else:
                cn = CompNode(device)
        else:
            if isinstance(device, CompNode):
                cn = device
            else:
                cn = device._cn
        if isinstance(data, _Tensor):
            obj = _Tensor.__new__(cls, data)
        else:
            if isinstance(data, np.ndarray):
                if 0 in data.strides:
                    data = data.squeeze().reshape(data.shape)
            obj = _Tensor.__new__(cls, data, dtype, cn, is_const, no_cache, name)
        return obj

    def __init__(
        self,
        data: Union["Tensor", np.ndarray, list, int, float],
        dtype: np.dtype = None,
        device: str = None,
        is_const: bool = False,
        no_cache: bool = False,
        name: str = None,
    ):
        pass
    @property
    def shape(self) -> Union[tuple, "Tensor"]:
        r"""
        Returns a :class:`tuple` or a :class:`~.Tensor` that represents the tensor dimensions.

        .. note::

           The shape of a tensor is usually represented by a :class:`tuple`.
           But if a tensor is treated as a symbolic placeholder while tracing,
           its shape could also be a :class:`~.Tensor`. See :class:`~.trace` for more details.

        The shape property is usually used to get the current shape of a tensor,
        but it may also be used to reshape the tensor in-place by assigning a tuple of tensor dimensions to it.
        As with :func:`~.reshape`, one of the new shape dimensions can be -1,
        in which case its value is inferred from the size of the tensor and the remaining dimensions.
        """
        shape = super().shape
        if shape == () or not use_symbolic_shape():
            return shape
        return apply(GetVarShape(), self)[0]
    @property
    def _tuple_shape(self):
        return super().shape

    @property
    def device(self) -> CompNode:
        r"""
        Returns a string representing the device on which a :class:`~.Tensor` is stored.
        """
        return super().device

    @property
    def dtype(self) -> np.dtype:
        r"""
        Returns a :class:`numpy.dtype` object that represents the data type of a :class:`~.Tensor`.
        """
        return super().dtype

    @property
    def qparams(self):
        r"""
        Returns a :class:`~.QParams` object containing the quantization params of a :class:`~.Tensor`.
        """
        from .quantization.utils import create_qparams  # pylint: disable=all

        if self._qparams is None:
            self._qparams = create_qparams()
        return self._qparams
    def numpy(self) -> np.ndarray:
        r"""
        Returns this :class:`~.Tensor` as a :class:`numpy.ndarray`.
        """
        return super().numpy()

    def detach(self):
        r"""
        Returns a new :class:`~.Tensor`, detached from the current graph.
        """
        return super().detach()

    def _reset(self, other):
        if not isinstance(other, _Tensor):
            other = Tensor(other, dtype=self.dtype, device=self.device)
        super()._reset(other)

    def __repr__(self):
        piece = "{}(".format(self.__class__.__name__)
        with np.printoptions(precision=4, suppress=True):
            piece += "{}".format(str(self.numpy()))
        if self.dtype != np.float32:
            piece += ", dtype={}".format(np.dtype(self.dtype).name)
        piece += ", device={}".format(self.device) + ")"
        return piece

    @property
    def name(self):
        return self.c_name

    @name.setter
    def name(self, name):
        self.c_name = name
        AutoNaming.record_var_name(self._mixin_handle, name)

    @deprecated(version="1.0", reason="no need to reuse an existing tensor since 1.0")
    def set_value(self, value):
        self._reset(value)

    @deprecated(version="1.0", reason="use ``*= 0`` instead")
    def reset_zero(self):
        self *= 0
    def to(self, device):
        r"""
        Copies this :class:`~.Tensor` to the specified device. See :func:`~.copy`.
        """
        if isinstance(device, str) and not _valid_device(device):
            raise ValueError(
                "invalid device name {}. For the correct format of the device name, please refer to the instruction of megengine.device.set_default_device()".format(
                    device
                )
            )
        cn = as_device(device).to_c()
        return apply(Copy(comp_node=cn), self)[0]

    @property
    def requires_grad(self):
        raise AttributeError("requires_grad is reserved for future use")

    @requires_grad.setter
    def requires_grad(self, value):
        raise AttributeError("requires_grad is reserved for future use")

    @requires_grad.deleter
    def requires_grad(self):
        raise AttributeError("requires_grad is reserved for future use")
    def __hash__(self):
        return id(self)

    def __getnewargs__(self):
        r"""__getnewargs__ will be called for pickle serialization or deep copy."""
        return (self.numpy(), self.dtype, self.device.logical_name)

    def __getstate__(self):
        r"""__getstate__ will be called for pickle serialization or deep copy."""
        state = {}
        if self._qparams is not None:
            state["qparams"] = self._qparams
        return state

    def __setstate__(self, state):
        # for compatibility with old versions not using fastcore
        if "data" in state:
            data = state.pop("data")
            device = state.pop("device")
            dtype = state.pop("dtype")
            self._reset(Tensor(data, dtype=dtype, device=device))
        # quantization-related state for deepcopy
        if "qdict" in state:
            qparams = state.pop("qdict")
            logger.warning(
                "Tensor's 'qdict' state is deprecated. Use 'qparams' instead"
            )
        elif "qparams" in state:
            qparams = state.pop("qparams")
        else:
            qparams = None
        self._qparams = qparams


tensor = Tensor


class Parameter(Tensor):
    r"""
    A kind of Tensor that is to be considered a module parameter.

    .. note::

       Operations performed on a Parameter usually return a Tensor instead of a Parameter.
       For example, with a Parameter ``x``, ``x.reshape/to/sum/...`` will result in a Tensor.
       Any operation between a Parameter and a Tensor will have a Tensor as its output.
    """

The MegEngine package ships with the CUDA environment needed to run code on a GPU, so there is no separate CPU and GPU build. If you want to run GPU programs, make sure the machine actually has a GPU device and that the driver is properly installed. If you would like to try deep-learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.
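A quick way to check whether the bundled CUDA environment is usable on the current machine is a sketch along the following lines; it assumes the top-level helpers megengine.is_cuda_available() and megengine.set_default_device() found in recent MegEngine releases, and the device names "gpu0"/"cpu0" are illustrative.

import megengine as mge

if mge.is_cuda_available():
    # A GPU and a working driver were detected; run subsequent code on the first GPU.
    mge.set_default_device("gpu0")
    print("Running on GPU")
else:
    # No usable GPU; fall back to the CPU device.
    mge.set_default_device("cpu0")
    print("No usable GPU detected; running on CPU")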