|
@@ -425,7 +425,7 @@ class Tensor: |
|
|
def __getitem__(self, idx): |
|
|
def __getitem__(self, idx): |
|
|
return wrap_io_tensor(self._symvar.__getitem__)(_wrap_idx(idx)) |
|
|
return wrap_io_tensor(self._symvar.__getitem__)(_wrap_idx(idx)) |
|
|
|
|
|
|
|
|
def set_subtensor(self, val: "Tensor"): |
|
|
|
|
|
|
|
|
def set_subtensor(self, val: "Tensor") -> _MGBIndexWrapper: |
|
|
r""" |
|
|
r""" |
|
|
Return a object which supports using ``__getitem__`` to set subtensor. |
|
|
Return a object which supports using ``__getitem__`` to set subtensor. |
|
|
|
|
|
|
|
@@ -433,7 +433,7 @@ class Tensor: |
|
|
""" |
|
|
""" |
|
|
return _MGBIndexWrapper(self, mgb.opr.set_subtensor, val) |
|
|
return _MGBIndexWrapper(self, mgb.opr.set_subtensor, val) |
|
|
|
|
|
|
|
|
def incr_subtensor(self, val: "Tensor"): |
|
|
|
|
|
|
|
|
def incr_subtensor(self, val: "Tensor") -> _MGBIndexWrapper: |
|
|
r""" |
|
|
r""" |
|
|
Return a object which supports using ``__getitem__`` to increase subtensor. |
|
|
Return a object which supports using ``__getitem__`` to increase subtensor. |
|
|
|
|
|
|
|
@@ -442,7 +442,7 @@ class Tensor: |
|
|
return _MGBIndexWrapper(self, mgb.opr.incr_subtensor, val) |
|
|
return _MGBIndexWrapper(self, mgb.opr.incr_subtensor, val) |
|
|
|
|
|
|
|
|
@property |
|
|
@property |
|
|
def ai(self): |
|
|
|
|
|
|
|
|
def ai(self) -> _MGBIndexWrapper: |
|
|
r""" |
|
|
r""" |
|
|
Return a object which supports complex index method to get subtensor. |
|
|
Return a object which supports complex index method to get subtensor. |
|
|
|
|
|
|
|
@@ -465,20 +465,20 @@ class Tensor: |
|
|
""" |
|
|
""" |
|
|
return _MGBIndexWrapper(self, mgb.opr.advanced_indexing) |
|
|
return _MGBIndexWrapper(self, mgb.opr.advanced_indexing) |
|
|
|
|
|
|
|
|
def set_ai(self, val: "Tensor"): |
|
|
|
|
|
|
|
|
def set_ai(self, val: "Tensor") -> _MGBIndexWrapper: |
|
|
r""" |
|
|
r""" |
|
|
Equal to :meth:`~.Tensor.set_subtensor` which supports advanced indexing. |
|
|
Equal to :meth:`~.Tensor.set_subtensor` which supports advanced indexing. |
|
|
""" |
|
|
""" |
|
|
return _MGBIndexWrapper(self, mgb.opr.set_advanced_indexing, val) |
|
|
return _MGBIndexWrapper(self, mgb.opr.set_advanced_indexing, val) |
|
|
|
|
|
|
|
|
def incr_ai(self, val: "Tensor"): |
|
|
|
|
|
|
|
|
def incr_ai(self, val: "Tensor") -> _MGBIndexWrapper: |
|
|
r""" |
|
|
r""" |
|
|
Equal to :meth:`~.Tensor.incr_subtensor` which supports advanced indexing. |
|
|
Equal to :meth:`~.Tensor.incr_subtensor` which supports advanced indexing. |
|
|
""" |
|
|
""" |
|
|
return _MGBIndexWrapper(self, mgb.opr.incr_advanced_indexing, val) |
|
|
return _MGBIndexWrapper(self, mgb.opr.incr_advanced_indexing, val) |
|
|
|
|
|
|
|
|
@property |
|
|
@property |
|
|
def mi(self): |
|
|
|
|
|
|
|
|
def mi(self) -> _MGBIndexWrapper: |
|
|
r""" |
|
|
r""" |
|
|
Return a object which supports getting subtensor by |
|
|
Return a object which supports getting subtensor by |
|
|
the coordinates which is Cartesian product of given index. |
|
|
the coordinates which is Cartesian product of given index. |
|
@@ -502,20 +502,20 @@ class Tensor: |
|
|
""" |
|
|
""" |
|
|
return _MGBIndexWrapper(self, mgb.opr.mesh_indexing) |
|
|
return _MGBIndexWrapper(self, mgb.opr.mesh_indexing) |
|
|
|
|
|
|
|
|
def set_mi(self, val: "Tensor"): |
|
|
|
|
|
|
|
|
def set_mi(self, val: "Tensor") -> _MGBIndexWrapper: |
|
|
r""" |
|
|
r""" |
|
|
Equal to :meth:`~.Tensor.set_subtensor` which using mesh indexing. |
|
|
Equal to :meth:`~.Tensor.set_subtensor` which using mesh indexing. |
|
|
""" |
|
|
""" |
|
|
return _MGBIndexWrapper(self, mgb.opr.set_mesh_indexing, val) |
|
|
return _MGBIndexWrapper(self, mgb.opr.set_mesh_indexing, val) |
|
|
|
|
|
|
|
|
def incr_mi(self, val: "Tensor"): |
|
|
|
|
|
|
|
|
def incr_mi(self, val: "Tensor") -> _MGBIndexWrapper: |
|
|
r""" |
|
|
r""" |
|
|
Equal to :meth:`~.Tensor.incr_subtensor` which using mesh indexing. |
|
|
Equal to :meth:`~.Tensor.incr_subtensor` which using mesh indexing. |
|
|
""" |
|
|
""" |
|
|
return _MGBIndexWrapper(self, mgb.opr.incr_mesh_indexing, val) |
|
|
return _MGBIndexWrapper(self, mgb.opr.incr_mesh_indexing, val) |
|
|
|
|
|
|
|
|
@property |
|
|
@property |
|
|
def batched_mi(self): |
|
|
|
|
|
|
|
|
def batched_mi(self) -> _MGBIndexWrapper: |
|
|
r""" |
|
|
r""" |
|
|
Return a object which supports getting subtensor by |
|
|
Return a object which supports getting subtensor by |
|
|
batched mesh indexing. |
|
|
batched mesh indexing. |
|
@@ -555,13 +555,13 @@ class Tensor: |
|
|
""" |
|
|
""" |
|
|
return _MGBIndexWrapper(self, mgb.opr.batched_mesh_indexing) |
|
|
return _MGBIndexWrapper(self, mgb.opr.batched_mesh_indexing) |
|
|
|
|
|
|
|
|
def batched_set_mi(self, val: "Tensor"): |
|
|
|
|
|
|
|
|
def batched_set_mi(self, val: "Tensor") -> _MGBIndexWrapper: |
|
|
r""" |
|
|
r""" |
|
|
Equal to :meth:`~.Tensor.incr_subtensor` which using batched mesh indexing. |
|
|
Equal to :meth:`~.Tensor.incr_subtensor` which using batched mesh indexing. |
|
|
""" |
|
|
""" |
|
|
return _MGBIndexWrapper(self, mgb.opr.batched_set_mesh_indexing, val) |
|
|
return _MGBIndexWrapper(self, mgb.opr.batched_set_mesh_indexing, val) |
|
|
|
|
|
|
|
|
def batched_incr_mi(self, val: "Tensor"): |
|
|
|
|
|
|
|
|
def batched_incr_mi(self, val: "Tensor") -> _MGBIndexWrapper: |
|
|
r""" |
|
|
r""" |
|
|
Equal to :meth:`~.Tensor.incr_subtensor` which using batched mesh indexing. |
|
|
Equal to :meth:`~.Tensor.incr_subtensor` which using batched mesh indexing. |
|
|
""" |
|
|
""" |
|
@@ -680,18 +680,31 @@ def tensor( |
|
|
return Tensor(shared_nd, requires_grad=requires_grad) |
|
|
return Tensor(shared_nd, requires_grad=requires_grad) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Dict(collections.MutableMapping): |
|
|
|
|
|
def __init__(self, *args, key=None, **kwargs): |
|
|
|
|
|
|
|
|
class TensorDict(collections.MutableMapping): |
|
|
|
|
|
r""" |
|
|
|
|
|
A helper class to maintain dict with Tensor key. |
|
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, *args, **kwargs): |
|
|
self.data = {} |
|
|
self.data = {} |
|
|
if key: |
|
|
|
|
|
self.keyfn = key |
|
|
|
|
|
for i in args: |
|
|
for i in args: |
|
|
self.update(i) |
|
|
self.update(i) |
|
|
self.update(**kwargs) |
|
|
self.update(**kwargs) |
|
|
|
|
|
|
|
|
@staticmethod |
|
|
|
|
|
def keyfn(key): # pylint: disable=method-hidden |
|
|
|
|
|
return key |
|
|
|
|
|
|
|
|
class keyfn: |
|
|
|
|
|
def __new__(cls, x: Tensor): |
|
|
|
|
|
if not isinstance(x, Tensor): |
|
|
|
|
|
return x |
|
|
|
|
|
return super().__new__(cls) |
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, x: Tensor): |
|
|
|
|
|
self._data = x # do not save id directly to make pickle work |
|
|
|
|
|
|
|
|
|
|
|
def __hash__(self): |
|
|
|
|
|
return id(self._data) |
|
|
|
|
|
|
|
|
|
|
|
def __eq__(self, other): |
|
|
|
|
|
return isinstance(other, type(self)) and id(self._data) == id(other._data) |
|
|
|
|
|
|
|
|
def __getitem__(self, key): |
|
|
def __getitem__(self, key): |
|
|
_, v = self.data[self.keyfn(key)] |
|
|
_, v = self.data[self.keyfn(key)] |
|
@@ -709,24 +722,3 @@ class Dict(collections.MutableMapping): |
|
|
|
|
|
|
|
|
def __len__(self): |
|
|
def __len__(self): |
|
|
return len(self.data) |
|
|
return len(self.data) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TensorDict(Dict): # pylint: disable=too-many-ancestors |
|
|
|
|
|
class keyfn: |
|
|
|
|
|
def __new__(cls, x: Tensor): |
|
|
|
|
|
if not isinstance(x, Tensor): |
|
|
|
|
|
return x |
|
|
|
|
|
return super().__new__(cls) |
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, x: Tensor): |
|
|
|
|
|
self._data = x # do not save id directly to make pickle work |
|
|
|
|
|
|
|
|
|
|
|
def __hash__(self): |
|
|
|
|
|
return id(self._data) |
|
|
|
|
|
|
|
|
|
|
|
def __eq__(self, other): |
|
|
|
|
|
# pylint: disable=undefined-variable |
|
|
|
|
|
return isinstance(other, __class__) and id(self._data) == id(other._data) |
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, *args): |
|
|
|
|
|
super().__init__(*args) |
|
|
|