Browse Source

fix(imperative): add __array__ and __array_wrap__ for tensorwrapper

GitOrigin-RevId: 87d4ab6c8e
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
parent
commit
67859f04e1
2 changed files with 22 additions and 1 deletions
  1. +12
    -1
      imperative/python/megengine/core/tensor/tensor_wrapper.py
  2. +10
    -0
      imperative/python/test/unit/test_tensor_wrapper.py

+ 12
- 1
imperative/python/megengine/core/tensor/tensor_wrapper.py View File

@@ -211,7 +211,18 @@ def _expand_args(args):

class ArrayMethodMixin(abc.ABC):

__array_priority__ = 233333
# enable tensor to be converted to numpy array
__array_priority__ = 1001

def __array__(self, dtype=None):
if dtype == None:
return self.numpy()
return self.numpy().astype(dtype)

def __array_wrap__(self, array):
return TensorWrapper(
as_raw_tensor(array, dtype=array.dtype, device=self.device)
)

@abc.abstractmethod
def _reset(self, other):


+ 10
- 0
imperative/python/test/unit/test_tensor_wrapper.py View File

@@ -50,3 +50,13 @@ def test_set_subtensor():
np.testing.assert_almost_equal(x.numpy(), [3, 1, 2], decimal=6)
x[1:3] = [4, 5]
np.testing.assert_almost_equal(x.numpy(), [3, 4, 5], decimal=6)


def test_computing_with_numpy_array():
x = np.array([1, 2, 3], dtype=np.int32)
xx = TensorWrapper(x, device="cpu0")
y = np.array([1, 0, 3], dtype=np.int32)
assert np.add(xx, y).device == xx.device
np.testing.assert_equal(np.add(xx, y).numpy(), np.add(x, y))
np.testing.assert_equal(np.equal(xx, y).numpy(), np.equal(x, y))
np.testing.assert_equal(np.equal(xx, xx).numpy(), np.equal(x, x))

Loading…
Cancel
Save