
tensor.py 9.3 kB

# -*- coding: utf-8 -*-
from typing import Union

import numpy as np

from .core._imperative_rt import CompNode
from .core._imperative_rt.core2 import Tensor as _Tensor
from .core._imperative_rt.core2 import apply, set_py_tensor_type
from .core._trace_option import use_symbolic_shape
from .core._wrap import as_device
from .core.ops.builtin import Borrow, Copy, GetVarShape
from .core.tensor.array_method import ArrayMethodMixin
from .device import _valid_device, get_default_device
from .logger import get_logger
from .utils.deprecation import deprecated

logger = get_logger(__name__)


class Tensor(_Tensor, ArrayMethodMixin):
    r"""A tensor object represents a multidimensional, homogeneous array of fixed-size items.

    Tensor is the primary MegEngine data structure.
    Data type (dtype) describes the format of each element, such as ``float32``, ``int8`` and so on;
    see :ref:`tensor-dtype` for more details.

    It is similar to :class:`numpy.ndarray` but not identical in design.
    For example, GPU devices can be used to store Tensors and execute calculations in MegEngine.
    The concept of `view <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.view.html>`_
    does not exist in MegEngine, so indexing and other behaviors might differ from NumPy.

    All manipulations and operations on/between Tensors can be found in the :mod:`~.megengine.functional` module.
    Keep in mind that they are **not in-place**: a new Tensor is always returned and
    the original data remains unchanged.

    For more information, refer to the :ref:`tensor-guide` topic.

    Args:
        data(Tensor, :class:`~.numpy.ndarray`, :class:`list` or Python number):
            The data used for constructing the Tensor.
            A Tensor can be constructed from a Python :class:`list` / :class:`tuple` or sequence,
            a NumPy :class:`~.numpy.ndarray` data structure, MegEngine builtin methods and so on.
            Refer to :ref:`tensor-creation` for more details.
        dtype(:attr:`~.Tensor.dtype`): The data type of the returned Tensor. Inferred from ``data`` if not specified.
        device(:attr:`~.Tensor.device`): The desired device of the returned Tensor. Uses :func:`get_default_device` if not specified.
        is_const: Whether to make it an ``ImmutableTensor`` in tracing mode, refer to :class:`.jit.trace`.
        no_cache: Whether to cache it for memory sharing.
        name: Used to improve convenience in graph operations on a dumped model.

    .. note::

        Some methods like :meth:`~.Tensor.reshape` / :meth:`~.Tensor.flatten` /
        :meth:`~.Tensor.transpose` / :meth:`~.Tensor.min` / :meth:`~.Tensor.max` /
        :meth:`~.Tensor.mean` / :meth:`~.Tensor.sum` / :meth:`~.Tensor.prod` are implemented
        in the ``Tensor`` class for convenience and historical reasons.
        Other methods implemented in the :mod:`~.megengine.functional` module will not be added here anymore,
        since they are hard to maintain and too many candidates hurt the code completion experience.
    """

    grad = None
    dmap_callback = None
    _qparams = None
    _custom_name = ""
    _name = None
    _short_name = None
    _prefix = None

    def __init__(
        self,
        data: Union["Tensor", np.ndarray, list, int, float],
        dtype: np.dtype = None,
        device: str = None,
        is_const: bool = False,
        no_cache: bool = False,
        name: str = None,
    ):
        if name is None:
            name = ""
        else:
            self._set_name(name)
        self._custom_name = name
        self._name = name
        self._short_name = name
        self._prefix = None

    @property
    def shape(self) -> Union[tuple, "Tensor"]:
        r"""Returns a :class:`tuple` or a :class:`~.Tensor` representing the tensor dimensions.

        Note:
            The shape of a tensor is usually represented by a :class:`tuple`.
            But if a tensor is treated as a symbolic placeholder under tracing,
            its shape could also be a :class:`~.Tensor`. See :class:`~.trace` for more details.

        The shape property is usually used to get the current shape of a tensor,
        but may also be used to reshape the tensor in-place by assigning a tuple of tensor dimensions to it.
        As with :func:`~.reshape`, one of the new shape dimensions can be -1,
        in which case its value is inferred from the size of the tensor and the remaining dimensions.
        """
        shape = super().shape
        if shape == () or not use_symbolic_shape():
            return shape
        return apply(GetVarShape(), self)[0]
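
    # Example (sketch, not executed as part of this module): reading and assigning
    # ``shape`` as described in the docstring above. ``m`` is a hypothetical tensor
    # created by the caller; under tracing with symbolic shapes enabled, ``m.shape``
    # may itself be a Tensor rather than a tuple.
    #
    #     m = Tensor([[1, 2, 3], [4, 5, 6]], dtype="float32")
    #     m.shape               # (2, 3) in eager mode
    #     m.shape = (3, -1)     # in-place reshape; the -1 dimension is inferred as 2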

    @property
    def _tuple_shape(self):
        return super().shape

    @property
    def device(self) -> CompNode:
        r"""Returns the device (a :class:`CompNode`) that this :class:`~.Tensor` is stored on."""
        return super().device

    @property
    def dtype(self) -> np.dtype:
        r"""Returns a :class:`numpy.dtype` object representing the data type of this :class:`~.Tensor`."""
        return super().dtype

    @property
    def qparams(self):
        r"""Returns a :class:`~.QParams` object containing quantization params of this :class:`~.Tensor`."""
        from .quantization.utils import create_qparams  # pylint: disable=all

        if self._qparams is None:
            self._qparams = create_qparams()
        return self._qparams

    def numpy(self) -> np.ndarray:
        r"""Returns this :class:`~.Tensor` as a :class:`numpy.ndarray`."""
        return super().numpy()

    def detach(self):
        r"""Returns a new :class:`~.Tensor`, detached from the current graph."""
        return super().detach()

    def _reset(self, other):
        if not isinstance(other, _Tensor):
            other = Tensor(other, dtype=self.dtype, device=self.device)
        super()._reset(other)

    def __repr__(self):
        piece = "{}(".format(self.__class__.__name__)
        with np.printoptions(precision=4, suppress=True):
            piece += "{}".format(str(self.numpy()))
        if self.dtype != np.float32:
            piece += ", dtype={}".format(np.dtype(self.dtype).name)
        piece += ", device={}".format(self.device) + ")"
        return piece

    @property
    def name(self):
        return self._custom_name

    @name.setter
    def name(self, name):
        self._custom_name = name
        self._name = self._prefix + "." + name if self._prefix else name
        self._set_name(self._name)
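
    # Example (sketch): naming a tensor so it is easier to locate in a dumped graph.
    # ``t`` is a hypothetical tensor; when ``_prefix`` is set (e.g. by an enclosing
    # module), the effective name becomes ``"<prefix>.<name>"``.
    #
    #     t = Tensor([1.0, 2.0])
    #     t.name = "my_input"   # stored in _custom_name and applied via _set_name()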

    @deprecated(
        version="1.0", reason="please use ``tensor_name[...] = value``",
    )
    def set_value(self, value):
        self._reset(value)

    @deprecated(version="1.0", reason="use ``*= 0`` instead")
    def reset_zero(self):
        self *= 0

    def to(self, device, *, _borrow=False):
        r"""Copy this :class:`~.Tensor` to the specified device. See :func:`~.copy`."""
        if isinstance(device, str) and not _valid_device(device):
            raise ValueError(
                "invalid device name {}. For the correct format of the device name, please refer to the instruction of megengine.device.set_default_device()".format(
                    device
                )
            )
        cn = as_device(device).to_c()
        op = Borrow(comp_node=cn) if _borrow else Copy(comp_node=cn)
        return apply(op, self)[0]
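
    # Example (sketch): copying a tensor to another device. The device string "cpu0"
    # is an assumption; any name accepted by megengine.device.set_default_device()
    # works, and an invalid name raises the ValueError above.
    #
    #     t = Tensor([1, 2, 3])
    #     t_cpu = t.to("cpu0")  # returns a new Tensor produced by the Copy op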

    @property
    def requires_grad(self):
        raise AttributeError("requires_grad is reserved for future use")

    @requires_grad.setter
    def requires_grad(self, value):
        raise AttributeError("requires_grad is reserved for future use")

    @requires_grad.deleter
    def requires_grad(self):
        raise AttributeError("requires_grad is reserved for future use")

    def __hash__(self):
        return id(self)

    def __getnewargs__(self):
        r"""__getnewargs__ will be called for pickle serialization or deep copy."""
        return (self.numpy(), self.dtype, self.device.logical_name)

    def __getstate__(self):
        r"""__getstate__ will be called for pickle serialization or deep copy."""
        state = {}
        if self._qparams is not None:
            state["qparams"] = self._qparams
        return state

    def __setstate__(self, state):
        # for compatibility with old versions not using fastcore
        if "data" in state:
            data = state.pop("data")
            device = state.pop("device")
            dtype = state.pop("dtype")
            self._reset(Tensor(data, dtype=dtype, device=device))

        # quantization-related state for deepcopy
        if "qdict" in state:
            qparams = state.pop("qdict")
            logger.warning(
                "Tensor's 'qdict' state is deprecated. Use 'qparams' instead"
            )
        elif "qparams" in state:
            qparams = state.pop("qparams")
        else:
            qparams = None
        self._qparams = qparams
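
    # Example (sketch): pickling round-trips through __getnewargs__ / __getstate__ /
    # __setstate__, so the values, dtype, logical device name and qparams survive,
    # while graph/autograd state does not.
    #
    #     import pickle
    #     t = Tensor([1, 2, 3], dtype="int32")
    #     t2 = pickle.loads(pickle.dumps(t))  # same values, dtype and logical device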

set_py_tensor_type(Tensor)

tensor = Tensor
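
# Example (sketch): constructing tensors as described in the Tensor docstring.
# The device string "xpux" is an assumption (a MegEngine placeholder for "any
# available device"); omit ``device`` to fall back to get_default_device().
#
#     a = Tensor([[1, 2], [3, 4]], dtype="float32")
#     b = Tensor(np.arange(6).reshape(2, 3), device="xpux")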

class Parameter(Tensor):
    r"""A kind of Tensor that is to be considered a module parameter.

    Note:
        Operations performed on a Parameter usually return a Tensor instead of a Parameter.
        For example, with a Parameter ``x``, ``x.reshape/to/sum/...`` will result in a Tensor.
        Any operation between a Parameter and a Tensor will have a Tensor as its output.
    """