From 5198b7834cf4e042c9d3bb4ddcef613089ab0a1e Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 2 Jul 2021 15:28:45 +0800 Subject: [PATCH] fix(mge/functional): fix expand_dims for scalar GitOrigin-RevId: 253ea608f7e45a86a90e53cf6159964b2ab54678 --- imperative/python/megengine/functional/tensor.py | 9 ++++++++- imperative/python/test/unit/functional/test_tensor.py | 13 +++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py index 4cd1daba..e7b7434d 100644 --- a/imperative/python/megengine/functional/tensor.py +++ b/imperative/python/megengine/functional/tensor.py @@ -851,7 +851,14 @@ def expand_dims(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor: axis = get_axes() ndim = inp.ndim + len(axis) axis = sorted(i + ndim if i < 0 else i for i in axis) - + assert axis, "axis could not be empty" + if inp._isscalar(): + assert axis[0] == 0, "invalid axis {} for ndim 0".format(axis[0]) + if len(axis) == 1: + inp = copy(inp, device=None) + inp._unsetscalar() + return inp + axis = axis[1:] op = builtin.AddAxis(axis=axis) (result,) = apply(op, inp) return result diff --git a/imperative/python/test/unit/functional/test_tensor.py b/imperative/python/test/unit/functional/test_tensor.py index 144a5e28..c87a5e74 100644 --- a/imperative/python/test/unit/functional/test_tensor.py +++ b/imperative/python/test/unit/functional/test_tensor.py @@ -253,6 +253,19 @@ def test_expand_dims(is_varnode): np.testing.assert_equal(y, yy.numpy()) +def test_expand_dims_for_scalar(): + x = np.array(1, dtype="float32") + xx = make_tensor(x, None) + for axis in [0, -1, (0, 1), (-1, -2), (0, -1)]: + y = np.expand_dims(x, axis) + yy = F.expand_dims(xx, axis) + np.testing.assert_equal(y, yy.numpy()) + + for axis in [1, -2, (1, 2), (-2, -3)]: + np.testing.assert_raises(np.AxisError, np.expand_dims, x, axis) + np.testing.assert_raises(AssertionError, F.expand_dims, xx, axis) + + @pytest.mark.parametrize("is_varnode", [True, False]) def test_elemwise_dtype_promotion(is_varnode): if is_varnode: