Browse Source

fix(mge): fix bug of tensor.T

GitOrigin-RevId: 9fe9347b00
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
parent
commit
81b6a73382
2 changed files with 7 additions and 1 deletions
  1. +1
    -1
      imperative/python/megengine/core/tensor/tensor_wrapper.py
  2. +6
    -0
      imperative/python/test/unit/core/test_tensor_wrapper.py

+ 1
- 1
imperative/python/megengine/core/tensor/tensor_wrapper.py View File

@@ -371,7 +371,7 @@ class ArrayMethodMixin(abc.ABC):

def transpose(self, *args):
if not args:
args = reversed(range(self.ndim))
args = range(self.ndim)[::-1]
return _transpose(self, _expand_args(args))

def flatten(self):


+ 6
- 0
imperative/python/test/unit/core/test_tensor_wrapper.py View File

@@ -60,3 +60,9 @@ def test_computing_with_numpy_array():
np.testing.assert_equal(np.add(xx, y).numpy(), np.add(x, y))
np.testing.assert_equal(np.equal(xx, y).numpy(), np.equal(x, y))
np.testing.assert_equal(np.equal(xx, xx).numpy(), np.equal(x, x))


def test_transpose():
x = np.random.rand(2, 5).astype("float32")
xx = TensorWrapper(x)
np.testing.assert_almost_equal(xx.T.numpy(), x.T)

Loading…
Cancel
Save