From 5a01de78519e851cf4442b55a0220427f1395afc Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 24 Nov 2020 17:55:46 +0800 Subject: [PATCH] fix(mge): fix scalar transpose GitOrigin-RevId: c2b9e025c7f305975509e195f6c8e2c2cf7f81b7 --- imperative/python/megengine/core/tensor/tensor_wrapper.py | 7 +++++++ imperative/python/test/unit/test_zero_dim_tensor.py | 5 +++++ 2 files changed, 12 insertions(+) diff --git a/imperative/python/megengine/core/tensor/tensor_wrapper.py b/imperative/python/megengine/core/tensor/tensor_wrapper.py index 5d8cdf84..9f1cb5b5 100644 --- a/imperative/python/megengine/core/tensor/tensor_wrapper.py +++ b/imperative/python/megengine/core/tensor/tensor_wrapper.py @@ -392,6 +392,13 @@ class ArrayMethodMixin(abc.ABC): return _broadcast(self, _expand_args(args)) def transpose(self, *args): + if self.ndim == 0: + assert ( + len(args) == 0 + ), "transpose for scalar does not accept additional args" + ret = self.to(self.device) + setscalar(ret) + return ret if not args: args = range(self.ndim)[::-1] return _transpose(self, _expand_args(args)) diff --git a/imperative/python/test/unit/test_zero_dim_tensor.py b/imperative/python/test/unit/test_zero_dim_tensor.py index 8130ca82..fd1a1031 100644 --- a/imperative/python/test/unit/test_zero_dim_tensor.py +++ b/imperative/python/test/unit/test_zero_dim_tensor.py @@ -50,3 +50,8 @@ def test_elemementwise(): def test_astype(): a = Tensor(1.0) assert a.astype("int32").ndim == 0 + + +def test_tranpose(): + a = Tensor(1.0) + assert a.transpose().ndim == 0