|
|
@@ -28,6 +28,13 @@ logger = get_logger(__name__) |
|
|
|
class Tensor(_Tensor, ArrayMethodMixin): |
|
|
|
r""" |
|
|
|
A tensor object represents a multidimensional, homogeneous array of fixed-size items. |
|
|
|
|
|
|
|
:param data: The value of returned Tensor. |
|
|
|
:param dtype: The dtype of returned Tensor. Uses data's dtype if not specified. |
|
|
|
:param device: The desired device of returned Tensor. Uses :func:`get_default_device` if not specified. |
|
|
|
:param is_const: Whether make it a ``ImutableTensor`` in tracing mode. |
|
|
|
:param no_cache: Whether cache it for memory sharing. |
|
|
|
:param name: Used to improve convenience in graph operation on dumped model. |
|
|
|
""" |
|
|
|
|
|
|
|
grad = None |
|
|
@@ -35,8 +42,16 @@ class Tensor(_Tensor, ArrayMethodMixin): |
|
|
|
_qparams = None |
|
|
|
|
|
|
|
def __new__( |
|
|
|
cls, data, dtype=None, device=None, is_const=False, no_cache=False, name=None |
|
|
|
cls, |
|
|
|
data: Union["Tensor", np.ndarray, list, "scalar"] = None, |
|
|
|
dtype: np.dtype = None, |
|
|
|
device: str = None, |
|
|
|
is_const: bool = False, |
|
|
|
no_cache: bool = False, |
|
|
|
name: str = None, |
|
|
|
): |
|
|
|
if data is None: |
|
|
|
data = [] |
|
|
|
if device is None: |
|
|
|
cn = get_default_device() |
|
|
|
elif isinstance(device, str): |
|
|
@@ -59,13 +74,24 @@ class Tensor(_Tensor, ArrayMethodMixin): |
|
|
|
obj = _Tensor.__new__(cls, data, dtype, cn, is_const, no_cache, name) |
|
|
|
return obj |
|
|
|
|
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
data: Union["Tensor", np.ndarray, list, "scalar"], |
|
|
|
dtype: np.dtype = None, |
|
|
|
device: str = None, |
|
|
|
is_const: bool = False, |
|
|
|
no_cache: bool = False, |
|
|
|
name: str = None, |
|
|
|
): |
|
|
|
pass |
|
|
|
|
|
|
|
@property |
|
|
|
def shape(self) -> Union[tuple, "Tensor"]: |
|
|
|
r""" |
|
|
|
Returns a :class:`tuple` or a :class:`~.Tensor` represents tensor dimensions. |
|
|
|
|
|
|
|
.. note:: |
|
|
|
|
|
|
|
|
|
|
|
The shape of a tensor was usually represented by a :class:`tuple`. |
|
|
|
But if a tensor was treated as symbolic placeholder with tracing, |
|
|
|
it's shape could also be a :class:`~.Tensor`. See :class:`~.trace` for more details. |
|
|
@@ -100,6 +126,9 @@ class Tensor(_Tensor, ArrayMethodMixin): |
|
|
|
|
|
|
|
@property |
|
|
|
def qparams(self): |
|
|
|
r""" |
|
|
|
Returns a :class:`~.QParams` object containing quantization params of a :class:`~.Tensor`. |
|
|
|
""" |
|
|
|
from .quantization.utils import create_qparams # pylint: disable=all |
|
|
|
|
|
|
|
if self._qparams is None: |
|
|
@@ -185,18 +214,20 @@ class Tensor(_Tensor, ArrayMethodMixin): |
|
|
|
def __getstate__(self): |
|
|
|
r""" __getstate__ will be called for pickle serialization or deep copy |
|
|
|
""" |
|
|
|
state = { |
|
|
|
"numpy": self.numpy(), |
|
|
|
"dtype": self.dtype, |
|
|
|
"device": self.device.logical_name, |
|
|
|
} |
|
|
|
state = {} |
|
|
|
if self._qparams is not None: |
|
|
|
state["qparams"] = self._qparams |
|
|
|
return state |
|
|
|
|
|
|
|
def __setstate__(self, state): |
|
|
|
from .quantization.utils import create_qparams # pylint: disable=all |
|
|
|
|
|
|
|
# for compatibility with old version not using fastcore |
|
|
|
if "data" in state: |
|
|
|
data = state.pop("data") |
|
|
|
device = state.pop("device") |
|
|
|
dtype = state.pop("dtype") |
|
|
|
self._reset(Tensor(data, dtype=dtype, device=device)) |
|
|
|
|
|
|
|
# quantize related state for deepcopy |
|
|
|
if "qdict" in state: |
|
|
|
qparams = state.pop("qdict") |
|
|
|
logger.warning( |
|
|
@@ -206,7 +237,6 @@ class Tensor(_Tensor, ArrayMethodMixin): |
|
|
|
qparams = state.pop("qparams") |
|
|
|
else: |
|
|
|
qparams = None |
|
|
|
self._reset(Tensor(state.pop("numpy"), state.pop("dtype"), state.pop("device"))) |
|
|
|
self._qparams = qparams |
|
|
|
|
|
|
|
|
|
|
|