diff --git a/imperative/python/megengine/core/tensor/array_method.py b/imperative/python/megengine/core/tensor/array_method.py index d6985274..790c9629 100644 --- a/imperative/python/megengine/core/tensor/array_method.py +++ b/imperative/python/megengine/core/tensor/array_method.py @@ -165,7 +165,7 @@ def _remove_axis(inp: Tensor, axis) -> Tensor: return list(map(int, axis)) axis = get_axes() - axis = sorted(i + inp.ndim if i < 0 else i for i in axis) + axis = utils._normalize_axis(inp.ndim, axis) axis = [a - i for i, a in enumerate(axis)] op = builtin.RemoveAxis(axis=axis) @@ -190,8 +190,7 @@ def _reduce(mode): op = builtin.Reduce(mode=mode, axis=0) (result,) = apply(op, data) elif isinstance(axis, collections.abc.Iterable): - axis = list(axis) - axis.sort(reverse=True) + axis = utils._normalize_axis(self.ndim, axis, reverse=True) for ai in axis: op = builtin.Reduce(mode=mode, axis=ai) (data,) = apply(op, data) @@ -199,6 +198,7 @@ def _reduce(mode): data = _remove_axis(data, ai) result = data else: + # builtin.Reduce already accept negtive axis op = builtin.Reduce(mode=mode, axis=axis) (result,) = apply(op, data) diff --git a/imperative/python/megengine/core/tensor/utils.py b/imperative/python/megengine/core/tensor/utils.py index 98379a01..b1f834ce 100644 --- a/imperative/python/megengine/core/tensor/utils.py +++ b/imperative/python/megengine/core/tensor/utils.py @@ -178,3 +178,28 @@ def make_shape_tuple(shape): s = [] _expand_int(s, shape) return tuple(s) + + +def _normalize_axis( + ndim: int, axis: Union[int, Iterable], reverse=False +) -> Union[int, list]: + def convert(x): + x_org = x + if x < 0: + x = ndim + x + assert ( + x >= 0 and x < ndim + ), "axis {} is out of bounds for tensor of dimension {}".format(x_org, ndim) + return x + + if isinstance(axis, int): + return convert(axis) + elif isinstance(axis, Iterable): + axis_org = axis + axis = list(sorted(map(convert, axis), reverse=reverse)) + for i in range(len(axis) - 1): + assert axis[i] != axis[i + 1], "axis {} contains duplicated indices".format( + axis_org + ) + return axis + raise diff --git a/imperative/python/megengine/functional/math.py b/imperative/python/megengine/functional/math.py index cf6ecebf..f2187b3f 100644 --- a/imperative/python/megengine/functional/math.py +++ b/imperative/python/megengine/functional/math.py @@ -466,9 +466,13 @@ def argmin( 0 """ + if axis is None: + assert not keepdims, "can not set axis=None and keepdims=True" + inp = inp.flatten() + axis = 0 + + axis = utils._normalize_axis(inp.ndim, axis, reverse=True) if isinstance(axis, collections.abc.Iterable): - axis = list(axis) - axis.sort(reverse=True) for ai in axis: op = builtin.Argmin(axis=ai) @@ -479,11 +483,6 @@ def argmin( return inp - if axis is None: - assert not keepdims, "can not set axis=None and keepdims=True" - inp = inp.flatten() - axis = 0 - op = builtin.Argmin(axis=axis) (result,) = apply(op, inp) if not keepdims: @@ -525,9 +524,13 @@ def argmax( 5 """ + if axis is None: + assert not keepdims, "can not set axis=None and keepdims=True" + inp = inp.flatten() + axis = 0 + axis = utils._normalize_axis(inp.ndim, axis, reverse=True) + if isinstance(axis, collections.abc.Iterable): - axis = list(axis) - axis.sort(reverse=True) for ai in axis: op = builtin.Argmax(axis=ai) @@ -538,11 +541,6 @@ def argmax( return inp - if axis is None: - assert not keepdims, "can not set axis=None and keepdims=True" - inp = inp.flatten() - axis = 0 - op = builtin.Argmax(axis=axis) (result,) = apply(op, inp) if not keepdims: diff --git a/imperative/python/test/unit/functional/test_functional.py b/imperative/python/test/unit/functional/test_functional.py index 90d801de..c2a178a8 100644 --- a/imperative/python/test/unit/functional/test_functional.py +++ b/imperative/python/test/unit/functional/test_functional.py @@ -811,3 +811,19 @@ def test_assert_not_equal(): y = F.zeros(shape, dtype=np.float32) + 1.1 with pytest.raises(RuntimeError): z = F.utils._assert_equal(x, y) + + +def test_neg_axis(): + x = tensor(np.random.normal(0, 1, (32, 5))) + + y = F.argmax(x, axis=-1) + yy = F.argmax(x, axis=1) + np.testing.assert_equal(y.numpy(), yy.numpy()) + + y = F.argmax(x, axis=(-1, -2)) + yy = F.argmax(x, axis=(0, 1)) + np.testing.assert_equal(y.numpy(), yy.numpy()) + + y = F.argmin(x, axis=(-1, -2)) + yy = F.argmin(x, axis=(0, 1)) + np.testing.assert_equal(y.numpy(), yy.numpy()) diff --git a/imperative/python/test/unit/functional/test_math.py b/imperative/python/test/unit/functional/test_math.py index acff5858..661e40d1 100644 --- a/imperative/python/test/unit/functional/test_math.py +++ b/imperative/python/test/unit/functional/test_math.py @@ -9,6 +9,7 @@ from functools import partial import numpy as np +import pytest from utils import opr_test import megengine.functional as F @@ -48,6 +49,14 @@ def common_test_reduce(opr, ref_opr): ref_fn=lambda x: ref_opr(x, axis=axis).astype(np.int32), axis=axis, ) + # test negative axis + axis = axis - len(data1_shape) + opr_test( + cases, + opr, + ref_fn=lambda x: ref_opr(x, axis=axis).astype(np.int32), + axis=axis, + ) def test_sum(): @@ -137,3 +146,14 @@ def test_normalize(): cases[0]["input"][0, 0, 0, :] = 0 cases[1]["input"][0, 0, 0, :] = 0 opr_test(cases, partial(F.normalize, axis=3), ref_fn=partial(np_normalize, axis=3)) + + +def test_sum_neg_axis(): + shape = (2, 3) + data = np.random.random(shape).astype(np.float32) + for axis in (-1, -2, (-2, 1), (-1, 0)): + get = F.sum(tensor(data), axis=axis) + ref = np.sum(data, axis=axis) + np.testing.assert_allclose(get.numpy(), ref, rtol=1e-6) + with pytest.raises(AssertionError): + F.sum(tensor(data), axis=(-1, 1))