diff --git a/imperative/python/megengine/core/tensor/tensor_wrapper.py b/imperative/python/megengine/core/tensor/tensor_wrapper.py index fd2e6bea..722d5dc9 100644 --- a/imperative/python/megengine/core/tensor/tensor_wrapper.py +++ b/imperative/python/megengine/core/tensor/tensor_wrapper.py @@ -211,7 +211,18 @@ def _expand_args(args): class ArrayMethodMixin(abc.ABC): - __array_priority__ = 233333 + # enable tensor to be converted to numpy array + __array_priority__ = 1001 + + def __array__(self, dtype=None): + if dtype == None: + return self.numpy() + return self.numpy().astype(dtype) + + def __array_wrap__(self, array): + return TensorWrapper( + as_raw_tensor(array, dtype=array.dtype, device=self.device) + ) @abc.abstractmethod def _reset(self, other): diff --git a/imperative/python/test/unit/test_tensor_wrapper.py b/imperative/python/test/unit/test_tensor_wrapper.py index c2f8def6..26bc9c9c 100644 --- a/imperative/python/test/unit/test_tensor_wrapper.py +++ b/imperative/python/test/unit/test_tensor_wrapper.py @@ -50,3 +50,13 @@ def test_set_subtensor(): np.testing.assert_almost_equal(x.numpy(), [3, 1, 2], decimal=6) x[1:3] = [4, 5] np.testing.assert_almost_equal(x.numpy(), [3, 4, 5], decimal=6) + + +def test_computing_with_numpy_array(): + x = np.array([1, 2, 3], dtype=np.int32) + xx = TensorWrapper(x, device="cpu0") + y = np.array([1, 0, 3], dtype=np.int32) + assert np.add(xx, y).device == xx.device + np.testing.assert_equal(np.add(xx, y).numpy(), np.add(x, y)) + np.testing.assert_equal(np.equal(xx, y).numpy(), np.equal(x, y)) + np.testing.assert_equal(np.equal(xx, xx).numpy(), np.equal(x, x))