diff --git a/imperative/python/megengine/core/tensor/tensor_wrapper.py b/imperative/python/megengine/core/tensor/tensor_wrapper.py index 5d8cdf84..9f1cb5b5 100644 --- a/imperative/python/megengine/core/tensor/tensor_wrapper.py +++ b/imperative/python/megengine/core/tensor/tensor_wrapper.py @@ -392,6 +392,13 @@ class ArrayMethodMixin(abc.ABC): return _broadcast(self, _expand_args(args)) def transpose(self, *args): + if self.ndim == 0: + assert ( + len(args) == 0 + ), "transpose for scalar does not accept additional args" + ret = self.to(self.device) + setscalar(ret) + return ret if not args: args = range(self.ndim)[::-1] return _transpose(self, _expand_args(args)) diff --git a/imperative/python/test/unit/test_zero_dim_tensor.py b/imperative/python/test/unit/test_zero_dim_tensor.py index 8130ca82..fd1a1031 100644 --- a/imperative/python/test/unit/test_zero_dim_tensor.py +++ b/imperative/python/test/unit/test_zero_dim_tensor.py @@ -50,3 +50,8 @@ def test_elemementwise(): def test_astype(): a = Tensor(1.0) assert a.astype("int32").ndim == 0 + + +def test_tranpose(): + a = Tensor(1.0) + assert a.transpose().ndim == 0