Browse Source

fix(mge): fix scalar transpose

GitOrigin-RevId: c2b9e025c7
release-1.2
Megvii Engine Team 4 years ago
parent
commit
5a01de7851
2 changed files with 12 additions and 0 deletions
  1. +7
    -0
      imperative/python/megengine/core/tensor/tensor_wrapper.py
  2. +5
    -0
      imperative/python/test/unit/test_zero_dim_tensor.py

+ 7
- 0
imperative/python/megengine/core/tensor/tensor_wrapper.py View File

@@ -392,6 +392,13 @@ class ArrayMethodMixin(abc.ABC):
return _broadcast(self, _expand_args(args))

def transpose(self, *args):
if self.ndim == 0:
assert (
len(args) == 0
), "transpose for scalar does not accept additional args"
ret = self.to(self.device)
setscalar(ret)
return ret
if not args:
args = range(self.ndim)[::-1]
return _transpose(self, _expand_args(args))


+ 5
- 0
imperative/python/test/unit/test_zero_dim_tensor.py View File

@@ -50,3 +50,8 @@ def test_elemementwise():
def test_astype():
a = Tensor(1.0)
assert a.astype("int32").ndim == 0


def test_tranpose():
a = Tensor(1.0)
assert a.transpose().ndim == 0

Loading…
Cancel
Save