Browse Source

fix(mge/functional): fix expand_dims for scalar

GitOrigin-RevId: 253ea608f7
release-1.5
Megvii Engine Team huangxinda 4 years ago
parent
commit
5198b7834c
2 changed files with 21 additions and 1 deletions
  1. +8
    -1
      imperative/python/megengine/functional/tensor.py
  2. +13
    -0
      imperative/python/test/unit/functional/test_tensor.py

+ 8
- 1
imperative/python/megengine/functional/tensor.py View File

@@ -851,7 +851,14 @@ def expand_dims(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
axis = get_axes() axis = get_axes()
ndim = inp.ndim + len(axis) ndim = inp.ndim + len(axis)
axis = sorted(i + ndim if i < 0 else i for i in axis) axis = sorted(i + ndim if i < 0 else i for i in axis)

assert axis, "axis could not be empty"
if inp._isscalar():
assert axis[0] == 0, "invalid axis {} for ndim 0".format(axis[0])
if len(axis) == 1:
inp = copy(inp, device=None)
inp._unsetscalar()
return inp
axis = axis[1:]
op = builtin.AddAxis(axis=axis) op = builtin.AddAxis(axis=axis)
(result,) = apply(op, inp) (result,) = apply(op, inp)
return result return result


+ 13
- 0
imperative/python/test/unit/functional/test_tensor.py View File

@@ -253,6 +253,19 @@ def test_expand_dims(is_varnode):
np.testing.assert_equal(y, yy.numpy()) np.testing.assert_equal(y, yy.numpy())




def test_expand_dims_for_scalar():
x = np.array(1, dtype="float32")
xx = make_tensor(x, None)
for axis in [0, -1, (0, 1), (-1, -2), (0, -1)]:
y = np.expand_dims(x, axis)
yy = F.expand_dims(xx, axis)
np.testing.assert_equal(y, yy.numpy())

for axis in [1, -2, (1, 2), (-2, -3)]:
np.testing.assert_raises(np.AxisError, np.expand_dims, x, axis)
np.testing.assert_raises(AssertionError, F.expand_dims, xx, axis)


@pytest.mark.parametrize("is_varnode", [True, False]) @pytest.mark.parametrize("is_varnode", [True, False])
def test_elemwise_dtype_promotion(is_varnode): def test_elemwise_dtype_promotion(is_varnode):
if is_varnode: if is_varnode:


Loading…
Cancel
Save