|
@@ -1,4 +1,3 @@ |
|
|
# -*- coding: utf-8 -*- |
|
|
|
|
|
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") |
|
|
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") |
|
|
# |
|
|
# |
|
|
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. |
|
|
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. |
|
@@ -10,7 +9,7 @@ import collections |
|
|
import functools |
|
|
import functools |
|
|
import itertools |
|
|
import itertools |
|
|
import weakref |
|
|
import weakref |
|
|
from typing import Union |
|
|
|
|
|
|
|
|
from typing import Callable, Tuple, Union |
|
|
|
|
|
|
|
|
import numpy as np |
|
|
import numpy as np |
|
|
|
|
|
|
|
@@ -68,24 +67,38 @@ def _wrap_symbolvar_binary_op(f): |
|
|
return wrapped |
|
|
return wrapped |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def wrap_slice(inp): |
|
|
|
|
|
|
|
|
def _wrap_slice(inp: slice): |
|
|
|
|
|
r""" |
|
|
|
|
|
A wrapper to handle Tensor values in ``inp`` slice. |
|
|
|
|
|
""" |
|
|
start = inp.start._symvar if isinstance(inp.start, Tensor) else inp.start |
|
|
start = inp.start._symvar if isinstance(inp.start, Tensor) else inp.start |
|
|
stop = inp.stop._symvar if isinstance(inp.stop, Tensor) else inp.stop |
|
|
stop = inp.stop._symvar if isinstance(inp.stop, Tensor) else inp.stop |
|
|
step = inp.step._symvar if isinstance(inp.step, Tensor) else inp.step |
|
|
step = inp.step._symvar if isinstance(inp.step, Tensor) else inp.step |
|
|
return slice(start, stop, step) |
|
|
return slice(start, stop, step) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def wrap_idx(idx): |
|
|
|
|
|
|
|
|
def _wrap_idx(idx: Tuple[Union[int, "Tensor"]]): |
|
|
|
|
|
r""" |
|
|
|
|
|
A wrapper to handle Tensor values in ``idx``. |
|
|
|
|
|
""" |
|
|
if not isinstance(idx, tuple): |
|
|
if not isinstance(idx, tuple): |
|
|
idx = (idx,) |
|
|
idx = (idx,) |
|
|
|
|
|
|
|
|
idx = tuple(i._symvar if isinstance(i, Tensor) else i for i in idx) |
|
|
idx = tuple(i._symvar if isinstance(i, Tensor) else i for i in idx) |
|
|
idx = tuple(wrap_slice(i) if isinstance(i, slice) else i for i in idx) |
|
|
|
|
|
|
|
|
idx = tuple(_wrap_slice(i) if isinstance(i, slice) else i for i in idx) |
|
|
return idx |
|
|
return idx |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MGBIndexWrapper: |
|
|
|
|
|
def __init__(self, dest, mgb_index, val=None): |
|
|
|
|
|
|
|
|
class _MGBIndexWrapper: |
|
|
|
|
|
r""" |
|
|
|
|
|
A wrapper class to handle ``__getitem__`` for index containing Tensor values. |
|
|
|
|
|
|
|
|
|
|
|
:param dest: a destination Tensor to do indexing on. |
|
|
|
|
|
:param mgb_index: an ``_internal`` helper function indicating how to index. |
|
|
|
|
|
:param val: a optional Tensor parameter used for ``mgb_index``. |
|
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, dest: "Tensor", mgb_index: Callable, val=None): |
|
|
self.dest = dest |
|
|
self.dest = dest |
|
|
self.val = val |
|
|
self.val = val |
|
|
self.mgb_index = mgb_index |
|
|
self.mgb_index = mgb_index |
|
@@ -93,16 +106,22 @@ class MGBIndexWrapper: |
|
|
def __getitem__(self, idx): |
|
|
def __getitem__(self, idx): |
|
|
if self.val is None: |
|
|
if self.val is None: |
|
|
return wrap_io_tensor(self.mgb_index(self.dest._symvar).__getitem__)( |
|
|
return wrap_io_tensor(self.mgb_index(self.dest._symvar).__getitem__)( |
|
|
wrap_idx(idx) |
|
|
|
|
|
|
|
|
_wrap_idx(idx) |
|
|
) |
|
|
) |
|
|
else: |
|
|
else: |
|
|
return wrap_io_tensor( |
|
|
return wrap_io_tensor( |
|
|
self.mgb_index(self.dest._symvar, self.val._symvar).__getitem__ |
|
|
self.mgb_index(self.dest._symvar, self.val._symvar).__getitem__ |
|
|
)(wrap_idx(idx)) |
|
|
|
|
|
|
|
|
)(_wrap_idx(idx)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _Guard: |
|
|
|
|
|
r""" |
|
|
|
|
|
A wrapper class with custom ``__del__`` method calling ``deleter``. |
|
|
|
|
|
|
|
|
|
|
|
:param deleter: a function to be called in ``__del__``. |
|
|
|
|
|
""" |
|
|
|
|
|
|
|
|
class Guard: |
|
|
|
|
|
def __init__(self, deleter): |
|
|
|
|
|
|
|
|
def __init__(self, deleter: Callable): |
|
|
self.deleter = deleter |
|
|
self.deleter = deleter |
|
|
|
|
|
|
|
|
def __del__(self): |
|
|
def __del__(self): |
|
@@ -161,6 +180,7 @@ class Tensor: |
|
|
return self.__sym.inferred_value |
|
|
return self.__sym.inferred_value |
|
|
|
|
|
|
|
|
def item(self): |
|
|
def item(self): |
|
|
|
|
|
r"""If tensor only has only one value, return it.""" |
|
|
return self.numpy().item() |
|
|
return self.numpy().item() |
|
|
|
|
|
|
|
|
def _attach(self, comp_graph, *, volatile=True): |
|
|
def _attach(self, comp_graph, *, volatile=True): |
|
@@ -204,7 +224,7 @@ class Tensor: |
|
|
if self is not None: |
|
|
if self is not None: |
|
|
self.__sym_override = None |
|
|
self.__sym_override = None |
|
|
|
|
|
|
|
|
deleters.add(Guard(restore)) |
|
|
|
|
|
|
|
|
deleters.add(_Guard(restore)) |
|
|
self.__sym_override = symvar |
|
|
self.__sym_override = symvar |
|
|
|
|
|
|
|
|
@property |
|
|
@property |
|
@@ -403,43 +423,149 @@ class Tensor: |
|
|
|
|
|
|
|
|
# mgb indexing family |
|
|
# mgb indexing family |
|
|
def __getitem__(self, idx): |
|
|
def __getitem__(self, idx): |
|
|
return wrap_io_tensor(self._symvar.__getitem__)(wrap_idx(idx)) |
|
|
|
|
|
|
|
|
return wrap_io_tensor(self._symvar.__getitem__)(_wrap_idx(idx)) |
|
|
|
|
|
|
|
|
|
|
|
def set_subtensor(self, val: "Tensor"): |
|
|
|
|
|
r""" |
|
|
|
|
|
Return a object which supports using ``__getitem__`` to set subtensor. |
|
|
|
|
|
|
|
|
def set_subtensor(self, val): |
|
|
|
|
|
return MGBIndexWrapper(self, mgb.opr.set_subtensor, val) |
|
|
|
|
|
|
|
|
``c = a.set_subtensor(b)[idx]`` is equivalent to ``c = a.copy()`` and ``c[idx] = b``. |
|
|
|
|
|
""" |
|
|
|
|
|
return _MGBIndexWrapper(self, mgb.opr.set_subtensor, val) |
|
|
|
|
|
|
|
|
def incr_subtensor(self, val): |
|
|
|
|
|
return MGBIndexWrapper(self, mgb.opr.incr_subtensor, val) |
|
|
|
|
|
|
|
|
def incr_subtensor(self, val: "Tensor"): |
|
|
|
|
|
r""" |
|
|
|
|
|
Return a object which supports using ``__getitem__`` to increase subtensor. |
|
|
|
|
|
|
|
|
|
|
|
``c = a.incr_subtensor(b)[idx]`` is equivalent to ``c = a.copy()`` and ``c[idx] += b``. |
|
|
|
|
|
""" |
|
|
|
|
|
return _MGBIndexWrapper(self, mgb.opr.incr_subtensor, val) |
|
|
|
|
|
|
|
|
@property |
|
|
@property |
|
|
def ai(self): |
|
|
def ai(self): |
|
|
return MGBIndexWrapper(self, mgb.opr.advanced_indexing) |
|
|
|
|
|
|
|
|
r""" |
|
|
|
|
|
Return a object which supports complex index method to get subtensor. |
|
|
|
|
|
|
|
|
def set_ai(self, val): |
|
|
|
|
|
return MGBIndexWrapper(self, mgb.opr.set_advanced_indexing, val) |
|
|
|
|
|
|
|
|
Examples: |
|
|
|
|
|
|
|
|
def incr_ai(self, val): |
|
|
|
|
|
return MGBIndexWrapper(self, mgb.opr.incr_advanced_indexing, val) |
|
|
|
|
|
|
|
|
.. testcode:: |
|
|
|
|
|
|
|
|
|
|
|
from megengine import tensor |
|
|
|
|
|
a = tensor(np.arange(16, dtype=np.float32).reshape((4, 4))) |
|
|
|
|
|
print(a.ai[:, [2, 3]]) |
|
|
|
|
|
|
|
|
|
|
|
Outputs: |
|
|
|
|
|
|
|
|
|
|
|
.. testoutput:: |
|
|
|
|
|
|
|
|
|
|
|
Tensor([[ 2. 3.] |
|
|
|
|
|
[ 6. 7.] |
|
|
|
|
|
[10. 11.] |
|
|
|
|
|
[14. 15.]]) |
|
|
|
|
|
""" |
|
|
|
|
|
return _MGBIndexWrapper(self, mgb.opr.advanced_indexing) |
|
|
|
|
|
|
|
|
|
|
|
def set_ai(self, val: "Tensor"): |
|
|
|
|
|
r""" |
|
|
|
|
|
Equal to :meth:`~.Tensor.set_subtensor` which supports advanced indexing. |
|
|
|
|
|
""" |
|
|
|
|
|
return _MGBIndexWrapper(self, mgb.opr.set_advanced_indexing, val) |
|
|
|
|
|
|
|
|
|
|
|
def incr_ai(self, val: "Tensor"): |
|
|
|
|
|
r""" |
|
|
|
|
|
Equal to :meth:`~.Tensor.incr_subtensor` which supports advanced indexing. |
|
|
|
|
|
""" |
|
|
|
|
|
return _MGBIndexWrapper(self, mgb.opr.incr_advanced_indexing, val) |
|
|
|
|
|
|
|
|
@property |
|
|
@property |
|
|
def mi(self): |
|
|
def mi(self): |
|
|
return MGBIndexWrapper(self, mgb.opr.mesh_indexing) |
|
|
|
|
|
|
|
|
r""" |
|
|
|
|
|
Return a object which supports getting subtensor by |
|
|
|
|
|
the coordinates which is Cartesian product of given index. |
|
|
|
|
|
|
|
|
|
|
|
Examples: |
|
|
|
|
|
|
|
|
def set_mi(self, val): |
|
|
|
|
|
return MGBIndexWrapper(self, mgb.opr.set_mesh_indexing, val) |
|
|
|
|
|
|
|
|
.. testcode:: |
|
|
|
|
|
|
|
|
|
|
|
from megengine import tensor |
|
|
|
|
|
a = tensor(np.arange(16, dtype=np.float32).reshape((4, 4))) |
|
|
|
|
|
print(a.mi[[1, 2], [2, 3]]) |
|
|
|
|
|
# is equal to elements on [1, 2] * [2, 3] = [[(1,2), (1, 3)], [(2, 2), (2, 3)]] |
|
|
|
|
|
# a[1,2] = 6, a[1,3] = 7, a[2,2] = 10, a[2,3] = 11 |
|
|
|
|
|
|
|
|
|
|
|
Outputs: |
|
|
|
|
|
|
|
|
|
|
|
.. testoutput:: |
|
|
|
|
|
|
|
|
|
|
|
Tensor([[ 6. 7.] |
|
|
|
|
|
[10. 11.]]) |
|
|
|
|
|
""" |
|
|
|
|
|
return _MGBIndexWrapper(self, mgb.opr.mesh_indexing) |
|
|
|
|
|
|
|
|
|
|
|
def set_mi(self, val: "Tensor"): |
|
|
|
|
|
r""" |
|
|
|
|
|
Equal to :meth:`~.Tensor.set_subtensor` which using mesh indexing. |
|
|
|
|
|
""" |
|
|
|
|
|
return _MGBIndexWrapper(self, mgb.opr.set_mesh_indexing, val) |
|
|
|
|
|
|
|
|
def incr_mi(self, val): |
|
|
|
|
|
return MGBIndexWrapper(self, mgb.opr.incr_mesh_indexing, val) |
|
|
|
|
|
|
|
|
def incr_mi(self, val: "Tensor"): |
|
|
|
|
|
r""" |
|
|
|
|
|
Equal to :meth:`~.Tensor.incr_subtensor` which using mesh indexing. |
|
|
|
|
|
""" |
|
|
|
|
|
return _MGBIndexWrapper(self, mgb.opr.incr_mesh_indexing, val) |
|
|
|
|
|
|
|
|
@property |
|
|
@property |
|
|
def batched_mi(self): |
|
|
def batched_mi(self): |
|
|
return MGBIndexWrapper(self, mgb.opr.batched_mesh_indexing) |
|
|
|
|
|
|
|
|
r""" |
|
|
|
|
|
Return a object which supports getting subtensor by |
|
|
|
|
|
batched mesh indexing. |
|
|
|
|
|
|
|
|
def batched_set_mi(self, val): |
|
|
|
|
|
return MGBIndexWrapper(self, mgb.opr.batched_set_mesh_indexing, val) |
|
|
|
|
|
|
|
|
For Tensor ``a`` and index ``idx``, each value of the ``idx`` need to be a 2-dim matrix or slice. |
|
|
|
|
|
Cartesian product ``... * idx[k-1][i] * idx[k][i] * idx[k+1][i] * ...`` will be a subtensor from ``a[i]``. |
|
|
|
|
|
Each matrix ``idx[k]`` should have the size of ``batched_dim`` rows as ``idx[0]`` indicated. |
|
|
|
|
|
And for slice value, it will apply same slice for each ``batched_dim``. For more details see the example below. |
|
|
|
|
|
|
|
|
def batched_incr_mi(self, val): |
|
|
|
|
|
return MGBIndexWrapper(self, mgb.opr.batched_incr_mesh_indexing, val) |
|
|
|
|
|
|
|
|
Examples: |
|
|
|
|
|
|
|
|
|
|
|
.. testcode:: |
|
|
|
|
|
|
|
|
|
|
|
from megengine import tensor |
|
|
|
|
|
a = tensor(np.arange(144, dtype=np.float32).reshape((3, 3, 4, 4))) |
|
|
|
|
|
|
|
|
|
|
|
print(a.batched_mi[:2, [[0],[1]],[[0,1],[2,3]],[[0],[1]]]) |
|
|
|
|
|
# is equal to elements from a[0] with ``[0] * [0,1] * [0] = [[[(0,0,0)], [(0,1,0)]]]``(shape is [1,2,1]) |
|
|
|
|
|
# and from a[1] with ``[1] * [2,3] * [1] = [[[(1,2,1)], [(1,3,1)]]]``(shape is also [1,2,1]) |
|
|
|
|
|
# a[0,0,0,0] = 0, a[0,0,1,0] = 4, a[1,1,2,1] = 73, a[1,1,3,1] = 77 |
|
|
|
|
|
|
|
|
|
|
|
print(a.batched_mi[:2, [[0],[1]], :2, :1]) |
|
|
|
|
|
# is equal to ``a.batched_mi[:2, [[0],[1]], [[0,1],[0,1]],[[0],[0]]]`` |
|
|
|
|
|
|
|
|
|
|
|
Outputs: |
|
|
|
|
|
|
|
|
|
|
|
.. testoutput:: |
|
|
|
|
|
|
|
|
|
|
|
Tensor([[[[ 0.] |
|
|
|
|
|
[ 4.]]] |
|
|
|
|
|
[[[73.] |
|
|
|
|
|
[77.]]]]) |
|
|
|
|
|
Tensor([[[[ 0.] |
|
|
|
|
|
[ 4.]]] |
|
|
|
|
|
[[[64.] |
|
|
|
|
|
[68.]]]]) |
|
|
|
|
|
""" |
|
|
|
|
|
return _MGBIndexWrapper(self, mgb.opr.batched_mesh_indexing) |
|
|
|
|
|
|
|
|
|
|
|
def batched_set_mi(self, val: "Tensor"): |
|
|
|
|
|
r""" |
|
|
|
|
|
Equal to :meth:`~.Tensor.incr_subtensor` which using batched mesh indexing. |
|
|
|
|
|
""" |
|
|
|
|
|
return _MGBIndexWrapper(self, mgb.opr.batched_set_mesh_indexing, val) |
|
|
|
|
|
|
|
|
|
|
|
def batched_incr_mi(self, val: "Tensor"): |
|
|
|
|
|
r""" |
|
|
|
|
|
Equal to :meth:`~.Tensor.incr_subtensor` which using batched mesh indexing. |
|
|
|
|
|
""" |
|
|
|
|
|
return _MGBIndexWrapper(self, mgb.opr.batched_incr_mesh_indexing, val) |
|
|
|
|
|
|
|
|
def __array__(self, dtype=None): |
|
|
def __array__(self, dtype=None): |
|
|
if dtype is None: |
|
|
if dtype is None: |
|
|