You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

tensor.py 9.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. # -*- coding: utf-8 -*-
  2. from typing import Union
  3. import numpy as np
  4. from .core._imperative_rt import CompNode
  5. from .core._imperative_rt.core2 import FormatType
  6. from .core._imperative_rt.core2 import Tensor as _Tensor
  7. from .core._imperative_rt.core2 import apply, set_py_tensor_type
  8. from .core._trace_option import use_symbolic_shape
  9. from .core._wrap import as_device
  10. from .core.ops.builtin import Borrow, Copy, GetVarShape
  11. from .core.tensor.array_method import ArrayMethodMixin
  12. from .device import _valid_device, get_default_device
  13. from .logger import get_logger
  14. from .utils.deprecation import deprecated
  15. logger = get_logger(__name__)
  16. class Tensor(_Tensor, ArrayMethodMixin):
  17. r"""A tensor object represents a multidimensional, homogeneous array of fixed-size items.
  18. Tensor is the primary MegEngine data structure.
  19. Data type(dtype) describes the format of each element, such as ``float32``, ``int8`` and so on,
  20. see :ref:`tensor-dtype` for more details.
  21. It is similar to :class:`numpy.ndarray` but not the same in the design.
  22. For example, GPU devices can be used to store Tensors and execute calculations in MegEngine.
  23. The concept of `view <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.view.html>`_
  24. does not exist in MegEngine so indexing and other behaviors might be different with NumPy.
  25. All manipulations and operations on/between Tensors could be found in the :mod:`~.megengine.functional` module.
  26. Keep in mind that they are **not in-place**, a new Tensor will always be returned and
  27. the original data will remain constant.
  28. For more information, refer to the :ref:`tensor-guide` topic.
  29. Args:
  30. data(Tensor, :class:`~.numpy.ndarray`, :class:`list` or Python number):
  31. The data used for construcing Tensor.
  32. Tensor could be constructed from a Python :class:`list` / :class:`tuple` or sequence;
  33. a NumPy :class:`~.numpy.ndarray` data structure; MegEngine builtin methods and so on.
  34. Refer to :ref:`tensor-creation` for more details.
  35. dtype(:attr:`~.Tensor.dtype`): The data type of returned Tensor. Infer from ``data`` if not specified.
  36. device(:attr:`~.Tensor.device`): The desired device of returned Tensor. Uses :func:`get_default_device` if not specified.
  37. is_const: Whether make it a ``ImutableTensor`` in tracing mode, refer to :class:`.jit.trace`.
  38. no_cache: Whether cache it for memory sharing.
  39. name: Used to improve convenience in graph operation on dumped model.
  40. format: Used to indicate which memory format Tensor uses. It will not affect actual memory order or stride,
  41. but may affect some operators related to indexing and dimension. Only support "default", "nchw" and "nhwc".
  42. .. note::
  43. There are some methods like :meth:`~.Tensor.reshape` / :meth:`~.Tensor.flatten` /
  44. :meth:`~.Tensor.transpose` / :meth:`~.Tensor.min` / :meth:`~.Tensor.max` /
  45. :meth:`~.Tensor.mean` / :meth:`~.Tensor.sum` / :meth:`~.Tensor.prod` implemented
  46. in ``Tensor`` class for convenience and historical reasons.
  47. But other methods implemented in the :mod:`~.megengine.functional` module will not be added here anymore,
  48. it is hard for maintaining and too many candidates will affect code completion experience.
  49. """
  50. grad = None
  51. dmap_callback = None
  52. _qparams = None
  53. _custom_name = ""
  54. _name = None
  55. _short_name = None
  56. _prefix = None
  57. def __init__(
  58. self,
  59. data: Union["Tensor", np.ndarray, list, int, float],
  60. dtype: np.dtype = None,
  61. device: str = None,
  62. is_const: bool = False,
  63. no_cache: bool = False,
  64. name: str = None,
  65. format: str = "default",
  66. ):
  67. if name is None:
  68. name = ""
  69. else:
  70. self._set_name(name)
  71. self._custom_name = name
  72. self._name = name
  73. self._short_name = name
  74. self._prefix = None
  75. @property
  76. def shape(self) -> Union[tuple, "Tensor"]:
  77. r"""Returns a :class:`tuple` or a :class:`~.Tensor` represents tensor dimensions.
  78. Note:
  79. The shape of a tensor was usually represented by a :class:`tuple`.
  80. But if a tensor was treated as symbolic placeholder with tracing,
  81. it's shape could also be a :class:`~.Tensor`. See :class:`~.trace` for more details.
  82. The shape property is usually used to get the current shape of a tensor,
  83. but may also be used to reshape the tensor in-place by assigning a tuple of tensor dimensions to it.
  84. As with :func:`~.reshape`, one of the new shape dimensions can be -1,
  85. in which case its value is inferred from the size of the tensor and the remaining dimensions.
  86. """
  87. shape = super().shape
  88. if shape == () or not use_symbolic_shape():
  89. return shape
  90. return apply(GetVarShape(), self)[0]
  91. @property
  92. def _tuple_shape(self):
  93. return super().shape
  94. @property
  95. def device(self) -> CompNode:
  96. r"""Returns a string represents the device a :class:`~.Tensor` storaged on."""
  97. return super().device
  98. @property
  99. def dtype(self) -> np.dtype:
  100. r"""Returns a :class:`numpy.dtype` object represents the data type of a :class:`~.Tensor`."""
  101. return super().dtype
  102. @property
  103. def format(self) -> str:
  104. return super().format()
  105. @format.setter
  106. def format(self, format):
  107. super()._set_format(format)
  108. @property
  109. def qparams(self):
  110. r"""Returns a :class:`~.QParams` object containing quantization params of a :class:`~.Tensor`."""
  111. from .quantization.utils import create_qparams # pylint: disable=all
  112. if self._qparams is None:
  113. self._qparams = create_qparams()
  114. return self._qparams
  115. def numpy(self) -> np.ndarray:
  116. r"""Returns self :class:`~.Tensor` as a :class:`numpy.ndarray`."""
  117. return super().numpy()
  118. def detach(self):
  119. r"""Returns a new :class:`~.Tensor`, detached from the current graph."""
  120. return super().detach()
  121. def _reset(self, other):
  122. if not isinstance(other, _Tensor):
  123. other = Tensor(other, dtype=self.dtype, device=self.device)
  124. super()._reset(other)
  125. def __repr__(self):
  126. piece = "{}(".format(self.__class__.__name__)
  127. with np.printoptions(precision=4, suppress=True):
  128. piece += "{}".format(str(self.numpy()))
  129. if self.dtype != np.float32:
  130. piece += ", dtype={}".format(np.dtype(self.dtype).name)
  131. piece += ", device={}".format(self.device) + ")"
  132. return piece
  133. @property
  134. def name(self):
  135. return self._custom_name
  136. @name.setter
  137. def name(self, name):
  138. self._custom_name = name
  139. if name == None:
  140. name = ""
  141. self._name = self._prefix + "." + name if self._prefix else name
  142. self._set_name(self._name)
  143. @deprecated(
  144. version="1.0", reason="please use ``tensor_name[...] = value``",
  145. )
  146. def set_value(self, value):
  147. self._reset(value)
  148. @deprecated(version="1.0", reason="use ``*= 0`` instead")
  149. def reset_zero(self):
  150. self *= 0
  151. def to(self, device, *, _borrow=False):
  152. r"""Copy self :class:`~.Tensor` to specified device. See :func:`~.copy`"""
  153. if isinstance(device, str) and not _valid_device(device):
  154. raise ValueError(
  155. "invalid device name {}. For the correct format of the device name, please refer to the instruction of megengine.device.set_default_device()".format(
  156. device
  157. )
  158. )
  159. cn = as_device(device).to_c()
  160. op = Borrow(comp_node=cn) if _borrow else Copy(comp_node=cn)
  161. return apply(op, self)[0]
  162. @property
  163. def requires_grad(self):
  164. raise AttributeError("requires_grad is reserved for future use")
  165. @requires_grad.setter
  166. def requires_grad(self, value):
  167. raise AttributeError("requires_grad is reserved for future use")
  168. @requires_grad.deleter
  169. def requires_grad(self):
  170. raise AttributeError("requires_grad is reserved for future use")
  171. def __hash__(self):
  172. return id(self)
  173. def __getnewargs__(self):
  174. r"""__getnewargs__ will be called for pickle serialization or deep copy"""
  175. return (self.numpy(), self.dtype, self.device.logical_name)
  176. def __getstate__(self):
  177. r"""__getstate__ will be called for pickle serialization or deep copy"""
  178. state = {}
  179. if self._qparams is not None:
  180. state["qparams"] = self._qparams
  181. return state
  182. def __setstate__(self, state):
  183. # for compatibility with old version not using fastcore
  184. if "data" in state:
  185. data = state.pop("data")
  186. device = state.pop("device")
  187. dtype = state.pop("dtype")
  188. self._reset(Tensor(data, dtype=dtype, device=device))
  189. # quantize related state for deepcopy
  190. if "qdict" in state:
  191. qparams = state.pop("qdict")
  192. logger.warning(
  193. "Tensor's 'qdict' state is depreciated. Use 'qparams' instead"
  194. )
  195. elif "qparams" in state:
  196. qparams = state.pop("qparams")
  197. else:
  198. qparams = None
  199. self._qparams = qparams
# Register ``Tensor`` as the Python-level class the C++ runtime (core2)
# instantiates whenever it hands a tensor back to Python.
set_py_tensor_type(Tensor)

# Lowercase module-level alias of :class:`Tensor` — presumably kept for
# backward compatibility with code using the old ``tensor`` name (confirm).
tensor = Tensor
class Parameter(Tensor):
    r"""A kind of Tensor that is to be considered a module parameter.

    Note:
        Operations performed on a Parameter usually return a Tensor instead of a Parameter.
        For example, with a Parameter ``x``, ``x.reshape/to/sum/...`` will result in a Tensor.
        Any operation between a Parameter and a Tensor will have a Tensor as output.
    """