Browse Source

refactor(mge/tensor): combine Dict and TensorDict

GitOrigin-RevId: 6b6c03c04b
tags/v0.5.0
Megvii Engine Team Xu Xinran 5 years ago
parent
commit
b8d8886e35
1 changed files with 31 additions and 39 deletions
  1. +31
    -39
      python_module/megengine/core/tensor.py

+ 31
- 39
python_module/megengine/core/tensor.py View File

@@ -425,7 +425,7 @@ class Tensor:
def __getitem__(self, idx): def __getitem__(self, idx):
return wrap_io_tensor(self._symvar.__getitem__)(_wrap_idx(idx)) return wrap_io_tensor(self._symvar.__getitem__)(_wrap_idx(idx))


def set_subtensor(self, val: "Tensor"):
def set_subtensor(self, val: "Tensor") -> _MGBIndexWrapper:
r""" r"""
Return a object which supports using ``__getitem__`` to set subtensor. Return a object which supports using ``__getitem__`` to set subtensor.


@@ -433,7 +433,7 @@ class Tensor:
""" """
return _MGBIndexWrapper(self, mgb.opr.set_subtensor, val) return _MGBIndexWrapper(self, mgb.opr.set_subtensor, val)


def incr_subtensor(self, val: "Tensor"):
def incr_subtensor(self, val: "Tensor") -> _MGBIndexWrapper:
r""" r"""
Return a object which supports using ``__getitem__`` to increase subtensor. Return a object which supports using ``__getitem__`` to increase subtensor.


@@ -442,7 +442,7 @@ class Tensor:
return _MGBIndexWrapper(self, mgb.opr.incr_subtensor, val) return _MGBIndexWrapper(self, mgb.opr.incr_subtensor, val)


@property @property
def ai(self):
def ai(self) -> _MGBIndexWrapper:
r""" r"""
Return a object which supports complex index method to get subtensor. Return a object which supports complex index method to get subtensor.


@@ -465,20 +465,20 @@ class Tensor:
""" """
return _MGBIndexWrapper(self, mgb.opr.advanced_indexing) return _MGBIndexWrapper(self, mgb.opr.advanced_indexing)


def set_ai(self, val: "Tensor"):
def set_ai(self, val: "Tensor") -> _MGBIndexWrapper:
r""" r"""
Equal to :meth:`~.Tensor.set_subtensor` which supports advanced indexing. Equal to :meth:`~.Tensor.set_subtensor` which supports advanced indexing.
""" """
return _MGBIndexWrapper(self, mgb.opr.set_advanced_indexing, val) return _MGBIndexWrapper(self, mgb.opr.set_advanced_indexing, val)


def incr_ai(self, val: "Tensor"):
def incr_ai(self, val: "Tensor") -> _MGBIndexWrapper:
r""" r"""
Equal to :meth:`~.Tensor.incr_subtensor` which supports advanced indexing. Equal to :meth:`~.Tensor.incr_subtensor` which supports advanced indexing.
""" """
return _MGBIndexWrapper(self, mgb.opr.incr_advanced_indexing, val) return _MGBIndexWrapper(self, mgb.opr.incr_advanced_indexing, val)


@property @property
def mi(self):
def mi(self) -> _MGBIndexWrapper:
r""" r"""
Return a object which supports getting subtensor by Return a object which supports getting subtensor by
the coordinates which is Cartesian product of given index. the coordinates which is Cartesian product of given index.
@@ -502,20 +502,20 @@ class Tensor:
""" """
return _MGBIndexWrapper(self, mgb.opr.mesh_indexing) return _MGBIndexWrapper(self, mgb.opr.mesh_indexing)


def set_mi(self, val: "Tensor"):
def set_mi(self, val: "Tensor") -> _MGBIndexWrapper:
r""" r"""
Equal to :meth:`~.Tensor.set_subtensor` which using mesh indexing. Equal to :meth:`~.Tensor.set_subtensor` which using mesh indexing.
""" """
return _MGBIndexWrapper(self, mgb.opr.set_mesh_indexing, val) return _MGBIndexWrapper(self, mgb.opr.set_mesh_indexing, val)


def incr_mi(self, val: "Tensor"):
def incr_mi(self, val: "Tensor") -> _MGBIndexWrapper:
r""" r"""
Equal to :meth:`~.Tensor.incr_subtensor` which using mesh indexing. Equal to :meth:`~.Tensor.incr_subtensor` which using mesh indexing.
""" """
return _MGBIndexWrapper(self, mgb.opr.incr_mesh_indexing, val) return _MGBIndexWrapper(self, mgb.opr.incr_mesh_indexing, val)


@property @property
def batched_mi(self):
def batched_mi(self) -> _MGBIndexWrapper:
r""" r"""
Return a object which supports getting subtensor by Return a object which supports getting subtensor by
batched mesh indexing. batched mesh indexing.
@@ -555,13 +555,13 @@ class Tensor:
""" """
return _MGBIndexWrapper(self, mgb.opr.batched_mesh_indexing) return _MGBIndexWrapper(self, mgb.opr.batched_mesh_indexing)


def batched_set_mi(self, val: "Tensor"):
def batched_set_mi(self, val: "Tensor") -> _MGBIndexWrapper:
r""" r"""
Equal to :meth:`~.Tensor.incr_subtensor` which using batched mesh indexing. Equal to :meth:`~.Tensor.incr_subtensor` which using batched mesh indexing.
""" """
return _MGBIndexWrapper(self, mgb.opr.batched_set_mesh_indexing, val) return _MGBIndexWrapper(self, mgb.opr.batched_set_mesh_indexing, val)


def batched_incr_mi(self, val: "Tensor"):
def batched_incr_mi(self, val: "Tensor") -> _MGBIndexWrapper:
r""" r"""
Equal to :meth:`~.Tensor.incr_subtensor` which using batched mesh indexing. Equal to :meth:`~.Tensor.incr_subtensor` which using batched mesh indexing.
""" """
@@ -680,18 +680,31 @@ def tensor(
return Tensor(shared_nd, requires_grad=requires_grad) return Tensor(shared_nd, requires_grad=requires_grad)




class Dict(collections.MutableMapping):
def __init__(self, *args, key=None, **kwargs):
class TensorDict(collections.MutableMapping):
r"""
A helper class to maintain dict with Tensor key.
"""

def __init__(self, *args, **kwargs):
self.data = {} self.data = {}
if key:
self.keyfn = key
for i in args: for i in args:
self.update(i) self.update(i)
self.update(**kwargs) self.update(**kwargs)


@staticmethod
def keyfn(key): # pylint: disable=method-hidden
return key
class keyfn:
def __new__(cls, x: Tensor):
if not isinstance(x, Tensor):
return x
return super().__new__(cls)

def __init__(self, x: Tensor):
self._data = x # do not save id directly to make pickle work

def __hash__(self):
return id(self._data)

def __eq__(self, other):
return isinstance(other, type(self)) and id(self._data) == id(other._data)


def __getitem__(self, key): def __getitem__(self, key):
_, v = self.data[self.keyfn(key)] _, v = self.data[self.keyfn(key)]
@@ -709,24 +722,3 @@ class Dict(collections.MutableMapping):


def __len__(self): def __len__(self):
return len(self.data) return len(self.data)


class TensorDict(Dict): # pylint: disable=too-many-ancestors
class keyfn:
def __new__(cls, x: Tensor):
if not isinstance(x, Tensor):
return x
return super().__new__(cls)

def __init__(self, x: Tensor):
self._data = x # do not save id directly to make pickle work

def __hash__(self):
return id(self._data)

def __eq__(self, other):
# pylint: disable=undefined-variable
return isinstance(other, __class__) and id(self._data) == id(other._data)

def __init__(self, *args):
super().__init__(*args)

Loading…
Cancel
Save