# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") # # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import collections import functools import itertools import weakref from typing import Callable, Tuple, Union import numpy as np import megengine._internal as mgb from .graph import _use_default_if_none, get_default_graph def wrap_io_tensor(func): r"""A wrapper to make ``func`` compatible with functions in ``_internal.opr``. """ @functools.wraps(func) def wrapper(*args, **kwargs): comp_graph = None for i in itertools.chain(args, kwargs.values()): if isinstance(i, Tensor) and i._comp_graph: comp_graph = i._comp_graph break else: comp_graph = get_default_graph() new_args = ( arg._attach(comp_graph) if isinstance(arg, Tensor) else arg for arg in args ) new_kwargs = { k: v._attach(comp_graph) if isinstance(v, Tensor) else v for k, v in kwargs.items() } ret = func(*new_args, **new_kwargs) if isinstance(ret, mgb.SymbolVar): ret = Tensor(ret) elif isinstance(ret, list): ret = [Tensor(t) if isinstance(t, mgb.SymbolVar) else t for t in ret] elif isinstance(ret, tuple): ret = tuple(Tensor(t) if isinstance(t, mgb.SymbolVar) else t for t in ret) return ret return wrapper def _wrap_symbolvar_binary_op(f): @functools.wraps(f) def wrapped(self, other): comp_graph = ( isinstance(other, Tensor) and other._comp_graph or self._comp_graph or get_default_graph() ) if isinstance(other, Tensor): other = other._attach(comp_graph) return Tensor(f(self._attach(comp_graph), other)) return wrapped def _wrap_slice(inp: slice): r""" A wrapper to handle Tensor values in ``inp`` slice. """ start = inp.start._symvar if isinstance(inp.start, Tensor) else inp.start stop = inp.stop._symvar if isinstance(inp.stop, Tensor) else inp.stop step = inp.step._symvar if isinstance(inp.step, Tensor) else inp.step return slice(start, stop, step) def _wrap_idx(idx: Tuple[Union[int, "Tensor"]]): r""" A wrapper to handle Tensor values in ``idx``. """ if not isinstance(idx, tuple): idx = (idx,) idx = tuple(i._symvar if isinstance(i, Tensor) else i for i in idx) idx = tuple(_wrap_slice(i) if isinstance(i, slice) else i for i in idx) return idx class _MGBIndexWrapper: r""" A wrapper class to handle ``__getitem__`` for index containing Tensor values. :param dest: a destination Tensor to do indexing on. :param mgb_index: an ``_internal`` helper function indicating how to index. :param val: a optional Tensor parameter used for ``mgb_index``. """ def __init__(self, dest: "Tensor", mgb_index: Callable, val=None): self.dest = dest self.val = val self.mgb_index = mgb_index def __getitem__(self, idx): if self.val is None: return wrap_io_tensor(self.mgb_index(self.dest._symvar).__getitem__)( _wrap_idx(idx) ) else: return wrap_io_tensor( self.mgb_index(self.dest._symvar, self.val._symvar).__getitem__ )(_wrap_idx(idx)) class _Guard: r""" A wrapper class with custom ``__del__`` method calling ``deleter``. :param deleter: a function to be called in ``__del__``. """ def __init__(self, deleter: Callable): self.deleter = deleter def __del__(self): self.deleter() class Tensor: r"""The main data container in MegEngine. Use :func:`~.tensor` to create a Tensor with existed data. """ requires_grad = False grad = None def __init__(self, val=None, *, requires_grad=None): self._reset(val, requires_grad=requires_grad) def _reset(self, val=None, *, requires_grad=None): self.__sym_override = None if val is None: self.__val = None self.__sym = None elif isinstance(val, mgb.SharedND): self.__val = val self.__sym = None elif isinstance(val, mgb.SymbolVar): self.__val = None self.__sym = val else: raise TypeError("must be initialized with SymbolVar or SharedND") self.requires_grad = requires_grad def _as_tensor(self, obj): r"""Convert the data into a ``Tensor``. If the data is already a Tensor with the same dtype and device, no copy will be performed. Otherwise a new Tensor will be returned with computational graph retained. """ if isinstance(obj, Tensor): return obj if isinstance(obj, mgb.SymbolVar): return Tensor(obj) if isinstance(obj, mgb.SharedScalar): return Tensor(obj._as_sym_var(self._comp_graph, self._comp_node)) return tensor(data=obj, device=self.device) def numpy(self): r"""Return the tensor value in numpy.ndarray format. """ if self.__val is not None: assert self.__sym is None return self.__val.get_value() if self.__sym is None: raise ValueError("uninitialized") if self.__sym.eager_val is not None: return self.__sym.eager_val.get_value() return self.__sym.inferred_value def item(self): r"""If tensor only has only one value, return it.""" return self.numpy().item() def _attach(self, comp_graph, *, volatile=True): sym = self.__sym_override or self.__sym if sym: if sym.owner_graph != comp_graph: raise RuntimeError("internal error") return sym if self.__val: return self.__val.symvar(comp_graph, volatile=volatile) else: raise ValueError("uninitialized") @property def _symvar(self): if self.__sym_override: return self.__sym_override if self.__sym: assert not self.__val return self.__sym if not self.__val: raise ValueError("uninitialized") return self._attach(get_default_graph()) def __mgb_symvar__(self, comp_graph=None, **_): if self.__sym_override: return self.__sym_override if self.__val and comp_graph: return self._attach(comp_graph) return self._symvar # read by mgb.opr def _override_symvar_during_trace(self, trace, symvar): assert self.__val and not self.__sym assert trace is type(trace)._active_instance deleters = trace._user_cache.setdefault(Tensor, set()) self_ref = weakref.ref(self) def restore(): self = self_ref() if self is not None: self.__sym_override = None deleters.add(_Guard(restore)) self.__sym_override = symvar @property def dtype(self): r"""Return the data type of the tensor. """ if self.__val is not None: return self.__val.dtype return self._symvar.dtype def set_dtype(self, dtype: str = None): r"""Set the data type of the tensor. """ if self.__val is not None: self.__val = mgb.make_shared(self.device, value=self.astype(dtype).numpy()) elif self.__sym is not None: self.__sym = self.__sym.astype(dtype) @property def _comp_node(self): if self.__val is not None: return self.__val.comp_node return self._symvar.comp_node device = _comp_node @property def _comp_graph(self): if self.__sym is not None: return self.__sym.owner_graph return None @property def shape(self): r"""Return an int tuple that is the shape/layout of the tensor. Could be invalid in static graph mode. """ from ..jit import trace if trace._active_instance: # pylint: disable=protected-access # NOTE: this is an hack shape = mgb.opr.get_var_shape(self._symvar) return tuple(Tensor(shape[i]) for i in range(self.ndim)) return self._symvar.imm_shape def set_value(self, value, *, sync=True, inplace=False, share=False): r"""Set value to the tensor. """ if not self.__val: raise ValueError("not detached") if isinstance(value, Tensor): value = value.__val or value.__sym.eager_val self.__val.set_value(value, sync=sync, inplace=inplace, share=share) def fill(self, value): r"""Fills the tensor with the specified value. """ self.set_value(np.full(self.shape, value, dtype=self.dtype)) def reset_zero(self): r"""Reset the tensor and fills with zeros. """ if not self.__val: raise ValueError("not detached") self.__val.reset_zero() def to(self, device): r"""Performs Tensor device conversion, returns Tensor with the specified device. """ return wrap_io_tensor(mgb.opr.copy)(self, comp_node=device) # https://docs.python.org/3/reference/datamodel.html#object.__hash__ # > If a class does not define an __eq__() method it should not define a # > __hash__() operation either __hash__ = None # type: ignore[assignment] def __eq__(self, rhs): rhs = self._as_tensor(rhs) return Tensor(self._symvar._binary_opr("EQ", rhs._symvar)) def __ne__(self, rhs): return 1 - self.__eq__(rhs) def __len__(self): if self._symvar.eager_val is not None: return self._symvar.eager_val.shape[0] raise TypeError( "__len__ and __iter__ is not available for tensors on non eager graph." ) __add__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__add__) __radd__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__radd__) __sub__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__sub__) __rsub__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rsub__) __mul__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__mul__) __rmul__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rmul__) __matmul__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__matmul__) __rmatmul__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rmatmul__) __lshift__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__lshift__) __rshift__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rshift__) __truediv__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__truediv__) __rtruediv__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rtruediv__) __floordiv__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__floordiv__) __rfloordiv__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rfloordiv__) __mod__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__mod__) __rmod__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rmod__) __pow__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__pow__) __rpow__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rpow__) __lt__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__lt__) __gt__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__gt__) __le__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__le__) __ge__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__ge__) __neg__ = wrap_io_tensor(mgb.SymbolVar.__neg__) sum = wrap_io_tensor(mgb.SymbolVar.sum) """ Sum up the given tensors. """ max = wrap_io_tensor(mgb.SymbolVar.max) """ Return the maximum value of given tensor. """ min = wrap_io_tensor(mgb.SymbolVar.min) """ Return the minimum value of given tensor. """ prod = wrap_io_tensor(mgb.SymbolVar.prod) """ Return the product value of the given tensor. """ mean = wrap_io_tensor(mgb.SymbolVar.mean) """ Return the mean value of the given tensor. """ dimshuffle = wrap_io_tensor(mgb.SymbolVar.dimshuffle) """ See more details in :func:`~.functional.tensor.dimshuffle`. """ astype = wrap_io_tensor(mgb.SymbolVar.astype) """ Cast the tensor to a specified type. """ def reshape(self, *target_shape): r"""Return a tensor which has given target shape Examples: .. testcode:: import numpy as np from megengine import tensor inp = tensor(np.arange(1, 17, dtype=np.int32).reshape(4,4)) out = tensor(np.arange(100, 116, dtype=np.int32).reshape(1,16)) out = out.reshape(inp.shape) print(out.numpy()) .. testoutput:: [[100 101 102 103] [104 105 106 107] [108 109 110 111] [112 113 114 115]] """ if isinstance(target_shape[0], tuple): if len(target_shape) > 1: raise ValueError("Only single tuple is accepted in reshape") target_shape = target_shape[0] target_shape = (t._symvar if isinstance(t, Tensor) else t for t in target_shape) return Tensor(mgb.SymbolVar.reshape(self._symvar, *target_shape)) def broadcast(self, *target_shape): r"""Return a tesnor broadcasted by current tensor to given target shape Examples: .. testcode:: import numpy as np from megengine import tensor data = tensor(np.arange(100, 104, dtype=np.int32).reshape(1,4)) data = data.broadcast((4,4)) print(data.numpy()) .. testoutput:: [[100 101 102 103] [100 101 102 103] [100 101 102 103] [100 101 102 103]] """ if isinstance(target_shape[0], tuple): if len(target_shape) > 1: raise ValueError("Only single tuple is accepted in broadcast") target_shape = target_shape[0] target_shape = (t._symvar if isinstance(t, Tensor) else t for t in target_shape) return Tensor(mgb.SymbolVar.broadcast(self._symvar, *target_shape)) # Prefer operators on Tensor instead of convert to numpy __array_priority__ = 1000 # mgb indexing family def __getitem__(self, idx): return wrap_io_tensor(self._symvar.__getitem__)(_wrap_idx(idx)) def set_subtensor(self, val: "Tensor") -> _MGBIndexWrapper: r""" Return a object which supports using ``__getitem__`` to set subtensor. ``c = a.set_subtensor(b)[idx]`` is equivalent to ``c = a.copy()`` and ``c[idx] = b``. """ return _MGBIndexWrapper(self, mgb.opr.set_subtensor, val) def incr_subtensor(self, val: "Tensor") -> _MGBIndexWrapper: r""" Return a object which supports using ``__getitem__`` to increase subtensor. ``c = a.incr_subtensor(b)[idx]`` is equivalent to ``c = a.copy()`` and ``c[idx] += b``. """ return _MGBIndexWrapper(self, mgb.opr.incr_subtensor, val) @property def ai(self) -> _MGBIndexWrapper: r""" Return a object which supports complex index method to get subtensor. Examples: .. testcode:: from megengine import tensor a = tensor(np.arange(16, dtype=np.float32).reshape((4, 4))) print(a.ai[:, [2, 3]]) Outputs: .. testoutput:: Tensor([[ 2. 3.] [ 6. 7.] [10. 11.] [14. 15.]]) """ return _MGBIndexWrapper(self, mgb.opr.advanced_indexing) def set_ai(self, val: "Tensor") -> _MGBIndexWrapper: r""" Equal to :meth:`~.Tensor.set_subtensor` which supports advanced indexing. """ return _MGBIndexWrapper(self, mgb.opr.set_advanced_indexing, val) def incr_ai(self, val: "Tensor") -> _MGBIndexWrapper: r""" Equal to :meth:`~.Tensor.incr_subtensor` which supports advanced indexing. """ return _MGBIndexWrapper(self, mgb.opr.incr_advanced_indexing, val) @property def mi(self) -> _MGBIndexWrapper: r""" Return a object which supports getting subtensor by the coordinates which is Cartesian product of given index. Examples: .. testcode:: from megengine import tensor a = tensor(np.arange(16, dtype=np.float32).reshape((4, 4))) print(a.mi[[1, 2], [2, 3]]) # is equal to elements on [1, 2] * [2, 3] = [[(1,2), (1, 3)], [(2, 2), (2, 3)]] # a[1,2] = 6, a[1,3] = 7, a[2,2] = 10, a[2,3] = 11 Outputs: .. testoutput:: Tensor([[ 6. 7.] [10. 11.]]) """ return _MGBIndexWrapper(self, mgb.opr.mesh_indexing) def set_mi(self, val: "Tensor") -> _MGBIndexWrapper: r""" Equal to :meth:`~.Tensor.set_subtensor` which using mesh indexing. """ return _MGBIndexWrapper(self, mgb.opr.set_mesh_indexing, val) def incr_mi(self, val: "Tensor") -> _MGBIndexWrapper: r""" Equal to :meth:`~.Tensor.incr_subtensor` which using mesh indexing. """ return _MGBIndexWrapper(self, mgb.opr.incr_mesh_indexing, val) @property def batched_mi(self) -> _MGBIndexWrapper: r""" Return a object which supports getting subtensor by batched mesh indexing. For Tensor ``a`` and index ``idx``, each value of the ``idx`` need to be a 2-dim matrix or slice. Cartesian product ``... * idx[k-1][i] * idx[k][i] * idx[k+1][i] * ...`` will be a subtensor from ``a[i]``. Each matrix ``idx[k]`` should have the size of ``batched_dim`` rows as ``idx[0]`` indicated. And for slice value, it will apply same slice for each ``batched_dim``. For more details see the example below. Examples: .. testcode:: from megengine import tensor a = tensor(np.arange(144, dtype=np.float32).reshape((3, 3, 4, 4))) print(a.batched_mi[:2, [[0],[1]],[[0,1],[2,3]],[[0],[1]]]) # is equal to elements from a[0] with ``[0] * [0,1] * [0] = [[[(0,0,0)], [(0,1,0)]]]``(shape is [1,2,1]) # and from a[1] with ``[1] * [2,3] * [1] = [[[(1,2,1)], [(1,3,1)]]]``(shape is also [1,2,1]) # a[0,0,0,0] = 0, a[0,0,1,0] = 4, a[1,1,2,1] = 73, a[1,1,3,1] = 77 print(a.batched_mi[:2, [[0],[1]], :2, :1]) # is equal to ``a.batched_mi[:2, [[0],[1]], [[0,1],[0,1]],[[0],[0]]]`` Outputs: .. testoutput:: Tensor([[[[ 0.] [ 4.]]] [[[73.] [77.]]]]) Tensor([[[[ 0.] [ 4.]]] [[[64.] [68.]]]]) """ return _MGBIndexWrapper(self, mgb.opr.batched_mesh_indexing) def batched_set_mi(self, val: "Tensor") -> _MGBIndexWrapper: r""" Equal to :meth:`~.Tensor.incr_subtensor` which using batched mesh indexing. """ return _MGBIndexWrapper(self, mgb.opr.batched_set_mesh_indexing, val) def batched_incr_mi(self, val: "Tensor") -> _MGBIndexWrapper: r""" Equal to :meth:`~.Tensor.incr_subtensor` which using batched mesh indexing. """ return _MGBIndexWrapper(self, mgb.opr.batched_incr_mesh_indexing, val) def __array__(self, dtype=None): if dtype is None: return self.numpy() else: return self.numpy().astype(dtype, copy=False) def __int__(self): return int(self.item()) def __index__(self): return int(self.item()) def __round__(self, ndigits=0): if ndigits != 0: raise ValueError("ndigits must be 0 for Tensor.round") return Tensor(mgb.opr.elemwise([self._symvar], mode="ROUND")) round = __round__ def sqrt(self): r"""Return a tensor that each element is the square root of its original value. """ return Tensor(mgb.opr.sqrt(self._symvar)) def shapeof(self, axis=None): r"""Return a Tensor that represent the shape of the tensor. """ return Tensor(mgb.opr.get_var_shape(self._symvar, axis=axis)) @property def ndim(self): r"""Return the number of dimensions of the tensor. """ return len(self._symvar.imm_shape) def __repr__(self): piece = "Tensor(" with np.printoptions(precision=4, suppress=True): piece += "{}".format(str(self.numpy())) if self.dtype != np.float32: piece += ", dtype={}".format(np.dtype(self.dtype).name) if self._comp_node.locator_logical != ("XPU", -1, 0): piece += ", device={}".format(self.device) piece += ")" return piece def __bool__(self): raise RuntimeError( "Tensor object should not be converted to bool or used in a if statement. Use .numpy(), int() or float() if you want to use its value in if statement, be aware that this may lead to incorrect result in non-eager mode." ) def __getstate__(self): r""" __getstate__ will be called for pickle serialization or deep copy """ assert (self.__val is not None) and ( self.__sym is None ), "Only SharedND initialized Tensor can be serialized or deep copied" metadata = {"requires_grad": self.requires_grad} state = { "data": self.numpy(), "device": self.device, "dtype": self.dtype, "metadata": metadata, } return state def __setstate__(self, state): data = state.pop("data") device = state.pop("device") dtype = state.pop("dtype") metadata = state.pop("metadata", {}) requires_grad = metadata.pop("requires_grad", None) snd = mgb.make_shared(device, value=data, dtype=dtype) self._reset(snd, requires_grad=requires_grad) def tensor( data: Union[list, np.ndarray] = None, *, dtype: str = None, device: mgb.CompNode = None, requires_grad: bool = None ): r"""A helper function to create a :class:`~.Tensor` using existing data. :param data: an existing data array, must be Python list, NumPy array or None. :param dtype: target Tensor data type, one of ``("uint8", "int8", "int16", "int32", "float32", "float16")``. :param device: target device for Tensor storing. :param requires_grad: whether its gradiant will be calculated during :meth:`~.Optimizer.backward` """ supported_dtypes = ("uint8", "int8", "int16", "int32", "float32", "float16") if isinstance(data, Tensor): raise NotImplementedError if dtype is not None and np.dtype(dtype).name not in supported_dtypes: raise TypeError("unsupported dtype {}".format(dtype)) if data is not None: if not isinstance(data, np.ndarray): data = np.array(data, dtype=dtype) # In order to accept tensor([1]), # Automaticlly convert to 32-bit number instead of numpy's default 64-bit when input data is not nparray. dtype = mgb.to_mgb_supported_dtype(data.dtype) if dtype is None: if data.dtype.name not in supported_dtypes: raise TypeError("unsupported dtype {}".format(data.dtype)) device, _ = _use_default_if_none(device, None) shared_nd = mgb.make_shared(device, value=data, dtype=dtype) return Tensor(shared_nd, requires_grad=requires_grad) class TensorDict(collections.MutableMapping): r""" A helper class to maintain dict with Tensor key. """ def __init__(self, *args, **kwargs): self.data = {} for i in args: self.update(i) self.update(**kwargs) class keyfn: def __new__(cls, x: Tensor): if not isinstance(x, Tensor): return x return super().__new__(cls) def __init__(self, x: Tensor): self._data = x # do not save id directly to make pickle work def __hash__(self): return id(self._data) def __eq__(self, other): return isinstance(other, type(self)) and id(self._data) == id(other._data) def __getitem__(self, key): _, v = self.data[self.keyfn(key)] return v def __setitem__(self, key, value): self.data[self.keyfn(key)] = key, value def __delitem__(self, key): del self.data[self.keyfn(key)] def __iter__(self): for _, (k, _) in self.data.items(): yield k def __len__(self): return len(self.data)