Browse Source

docs(mge/tensor): add advanced index related docs

GitOrigin-RevId: 31735ddac4
tags/v0.5.0
Megvii Engine Team Xu Xinran 5 years ago
parent
commit
7751a0676e
1 changed files with 158 additions and 32 deletions
  1. +158
    -32
      python_module/megengine/core/tensor.py

+ 158
- 32
python_module/megengine/core/tensor.py View File

@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
# #
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
@@ -10,7 +9,7 @@ import collections
import functools import functools
import itertools import itertools
import weakref import weakref
from typing import Union
from typing import Callable, Tuple, Union


import numpy as np import numpy as np


@@ -68,24 +67,38 @@ def _wrap_symbolvar_binary_op(f):
return wrapped return wrapped




def wrap_slice(inp):
def _wrap_slice(inp: slice):
r"""
A wrapper to handle Tensor values in ``inp`` slice.
"""
start = inp.start._symvar if isinstance(inp.start, Tensor) else inp.start start = inp.start._symvar if isinstance(inp.start, Tensor) else inp.start
stop = inp.stop._symvar if isinstance(inp.stop, Tensor) else inp.stop stop = inp.stop._symvar if isinstance(inp.stop, Tensor) else inp.stop
step = inp.step._symvar if isinstance(inp.step, Tensor) else inp.step step = inp.step._symvar if isinstance(inp.step, Tensor) else inp.step
return slice(start, stop, step) return slice(start, stop, step)




def wrap_idx(idx):
def _wrap_idx(idx: Tuple[Union[int, "Tensor"]]):
r"""
A wrapper to handle Tensor values in ``idx``.
"""
if not isinstance(idx, tuple): if not isinstance(idx, tuple):
idx = (idx,) idx = (idx,)


idx = tuple(i._symvar if isinstance(i, Tensor) else i for i in idx) idx = tuple(i._symvar if isinstance(i, Tensor) else i for i in idx)
idx = tuple(wrap_slice(i) if isinstance(i, slice) else i for i in idx)
idx = tuple(_wrap_slice(i) if isinstance(i, slice) else i for i in idx)
return idx return idx




class MGBIndexWrapper:
def __init__(self, dest, mgb_index, val=None):
class _MGBIndexWrapper:
r"""
A wrapper class to handle ``__getitem__`` for index containing Tensor values.

:param dest: a destination Tensor to do indexing on.
:param mgb_index: an ``_internal`` helper function indicating how to index.
:param val: a optional Tensor parameter used for ``mgb_index``.
"""

def __init__(self, dest: "Tensor", mgb_index: Callable, val=None):
self.dest = dest self.dest = dest
self.val = val self.val = val
self.mgb_index = mgb_index self.mgb_index = mgb_index
@@ -93,16 +106,22 @@ class MGBIndexWrapper:
def __getitem__(self, idx): def __getitem__(self, idx):
if self.val is None: if self.val is None:
return wrap_io_tensor(self.mgb_index(self.dest._symvar).__getitem__)( return wrap_io_tensor(self.mgb_index(self.dest._symvar).__getitem__)(
wrap_idx(idx)
_wrap_idx(idx)
) )
else: else:
return wrap_io_tensor( return wrap_io_tensor(
self.mgb_index(self.dest._symvar, self.val._symvar).__getitem__ self.mgb_index(self.dest._symvar, self.val._symvar).__getitem__
)(wrap_idx(idx))
)(_wrap_idx(idx))


class _Guard:
r"""
A wrapper class with custom ``__del__`` method calling ``deleter``.


:param deleter: a function to be called in ``__del__``.
"""


class Guard:
def __init__(self, deleter):
def __init__(self, deleter: Callable):
self.deleter = deleter self.deleter = deleter


def __del__(self): def __del__(self):
@@ -161,6 +180,7 @@ class Tensor:
return self.__sym.inferred_value return self.__sym.inferred_value


def item(self): def item(self):
r"""If tensor only has only one value, return it."""
return self.numpy().item() return self.numpy().item()


def _attach(self, comp_graph, *, volatile=True): def _attach(self, comp_graph, *, volatile=True):
@@ -204,7 +224,7 @@ class Tensor:
if self is not None: if self is not None:
self.__sym_override = None self.__sym_override = None


deleters.add(Guard(restore))
deleters.add(_Guard(restore))
self.__sym_override = symvar self.__sym_override = symvar


@property @property
@@ -403,43 +423,149 @@ class Tensor:


# mgb indexing family # mgb indexing family
def __getitem__(self, idx): def __getitem__(self, idx):
return wrap_io_tensor(self._symvar.__getitem__)(wrap_idx(idx))
return wrap_io_tensor(self._symvar.__getitem__)(_wrap_idx(idx))

def set_subtensor(self, val: "Tensor"):
r"""
Return a object which supports using ``__getitem__`` to set subtensor.


def set_subtensor(self, val):
return MGBIndexWrapper(self, mgb.opr.set_subtensor, val)
``c = a.set_subtensor(b)[idx]`` is equivalent to ``c = a.copy()`` and ``c[idx] = b``.
"""
return _MGBIndexWrapper(self, mgb.opr.set_subtensor, val)


def incr_subtensor(self, val):
return MGBIndexWrapper(self, mgb.opr.incr_subtensor, val)
def incr_subtensor(self, val: "Tensor"):
r"""
Return a object which supports using ``__getitem__`` to increase subtensor.

``c = a.incr_subtensor(b)[idx]`` is equivalent to ``c = a.copy()`` and ``c[idx] += b``.
"""
return _MGBIndexWrapper(self, mgb.opr.incr_subtensor, val)


@property @property
def ai(self): def ai(self):
return MGBIndexWrapper(self, mgb.opr.advanced_indexing)
r"""
Return a object which supports complex index method to get subtensor.


def set_ai(self, val):
return MGBIndexWrapper(self, mgb.opr.set_advanced_indexing, val)
Examples:


def incr_ai(self, val):
return MGBIndexWrapper(self, mgb.opr.incr_advanced_indexing, val)
.. testcode::

from megengine import tensor
a = tensor(np.arange(16, dtype=np.float32).reshape((4, 4)))
print(a.ai[:, [2, 3]])

Outputs:

.. testoutput::

Tensor([[ 2. 3.]
[ 6. 7.]
[10. 11.]
[14. 15.]])
"""
return _MGBIndexWrapper(self, mgb.opr.advanced_indexing)

def set_ai(self, val: "Tensor"):
r"""
Equal to :meth:`~.Tensor.set_subtensor` which supports advanced indexing.
"""
return _MGBIndexWrapper(self, mgb.opr.set_advanced_indexing, val)

def incr_ai(self, val: "Tensor"):
r"""
Equal to :meth:`~.Tensor.incr_subtensor` which supports advanced indexing.
"""
return _MGBIndexWrapper(self, mgb.opr.incr_advanced_indexing, val)


@property @property
def mi(self): def mi(self):
return MGBIndexWrapper(self, mgb.opr.mesh_indexing)
r"""
Return a object which supports getting subtensor by
the coordinates which is Cartesian product of given index.

Examples:


def set_mi(self, val):
return MGBIndexWrapper(self, mgb.opr.set_mesh_indexing, val)
.. testcode::

from megengine import tensor
a = tensor(np.arange(16, dtype=np.float32).reshape((4, 4)))
print(a.mi[[1, 2], [2, 3]])
# is equal to elements on [1, 2] * [2, 3] = [[(1,2), (1, 3)], [(2, 2), (2, 3)]]
# a[1,2] = 6, a[1,3] = 7, a[2,2] = 10, a[2,3] = 11

Outputs:

.. testoutput::

Tensor([[ 6. 7.]
[10. 11.]])
"""
return _MGBIndexWrapper(self, mgb.opr.mesh_indexing)

def set_mi(self, val: "Tensor"):
r"""
Equal to :meth:`~.Tensor.set_subtensor` which using mesh indexing.
"""
return _MGBIndexWrapper(self, mgb.opr.set_mesh_indexing, val)


def incr_mi(self, val):
return MGBIndexWrapper(self, mgb.opr.incr_mesh_indexing, val)
def incr_mi(self, val: "Tensor"):
r"""
Equal to :meth:`~.Tensor.incr_subtensor` which using mesh indexing.
"""
return _MGBIndexWrapper(self, mgb.opr.incr_mesh_indexing, val)


@property @property
def batched_mi(self): def batched_mi(self):
return MGBIndexWrapper(self, mgb.opr.batched_mesh_indexing)
r"""
Return a object which supports getting subtensor by
batched mesh indexing.


def batched_set_mi(self, val):
return MGBIndexWrapper(self, mgb.opr.batched_set_mesh_indexing, val)
For Tensor ``a`` and index ``idx``, each value of the ``idx`` need to be a 2-dim matrix or slice.
Cartesian product ``... * idx[k-1][i] * idx[k][i] * idx[k+1][i] * ...`` will be a subtensor from ``a[i]``.
Each matrix ``idx[k]`` should have the size of ``batched_dim`` rows as ``idx[0]`` indicated.
And for slice value, it will apply same slice for each ``batched_dim``. For more details see the example below.


def batched_incr_mi(self, val):
return MGBIndexWrapper(self, mgb.opr.batched_incr_mesh_indexing, val)
Examples:

.. testcode::

from megengine import tensor
a = tensor(np.arange(144, dtype=np.float32).reshape((3, 3, 4, 4)))

print(a.batched_mi[:2, [[0],[1]],[[0,1],[2,3]],[[0],[1]]])
# is equal to elements from a[0] with ``[0] * [0,1] * [0] = [[[(0,0,0)], [(0,1,0)]]]``(shape is [1,2,1])
# and from a[1] with ``[1] * [2,3] * [1] = [[[(1,2,1)], [(1,3,1)]]]``(shape is also [1,2,1])
# a[0,0,0,0] = 0, a[0,0,1,0] = 4, a[1,1,2,1] = 73, a[1,1,3,1] = 77

print(a.batched_mi[:2, [[0],[1]], :2, :1])
# is equal to ``a.batched_mi[:2, [[0],[1]], [[0,1],[0,1]],[[0],[0]]]``

Outputs:

.. testoutput::

Tensor([[[[ 0.]
[ 4.]]]
[[[73.]
[77.]]]])
Tensor([[[[ 0.]
[ 4.]]]
[[[64.]
[68.]]]])
"""
return _MGBIndexWrapper(self, mgb.opr.batched_mesh_indexing)

def batched_set_mi(self, val: "Tensor"):
r"""
Equal to :meth:`~.Tensor.incr_subtensor` which using batched mesh indexing.
"""
return _MGBIndexWrapper(self, mgb.opr.batched_set_mesh_indexing, val)

def batched_incr_mi(self, val: "Tensor"):
r"""
Equal to :meth:`~.Tensor.incr_subtensor` which using batched mesh indexing.
"""
return _MGBIndexWrapper(self, mgb.opr.batched_incr_mesh_indexing, val)


def __array__(self, dtype=None): def __array__(self, dtype=None):
if dtype is None: if dtype is None:


Loading…
Cancel
Save