GitOrigin-RevId: a1bd1102a6
release-1.4
@@ -165,7 +165,7 @@ def _remove_axis(inp: Tensor, axis) -> Tensor: | |||
return list(map(int, axis)) | |||
axis = get_axes() | |||
axis = sorted(i + inp.ndim if i < 0 else i for i in axis) | |||
axis = utils._normalize_axis(inp.ndim, axis) | |||
axis = [a - i for i, a in enumerate(axis)] | |||
op = builtin.RemoveAxis(axis=axis) | |||
@@ -190,8 +190,7 @@ def _reduce(mode): | |||
op = builtin.Reduce(mode=mode, axis=0) | |||
(result,) = apply(op, data) | |||
elif isinstance(axis, collections.abc.Iterable): | |||
axis = list(axis) | |||
axis.sort(reverse=True) | |||
axis = utils._normalize_axis(self.ndim, axis, reverse=True) | |||
for ai in axis: | |||
op = builtin.Reduce(mode=mode, axis=ai) | |||
(data,) = apply(op, data) | |||
@@ -199,6 +198,7 @@ def _reduce(mode): | |||
data = _remove_axis(data, ai) | |||
result = data | |||
else: | |||
# builtin.Reduce already accept negtive axis | |||
op = builtin.Reduce(mode=mode, axis=axis) | |||
(result,) = apply(op, data) | |||
@@ -178,3 +178,28 @@ def make_shape_tuple(shape): | |||
s = [] | |||
_expand_int(s, shape) | |||
return tuple(s) | |||
def _normalize_axis( | |||
ndim: int, axis: Union[int, Iterable], reverse=False | |||
) -> Union[int, list]: | |||
def convert(x): | |||
x_org = x | |||
if x < 0: | |||
x = ndim + x | |||
assert ( | |||
x >= 0 and x < ndim | |||
), "axis {} is out of bounds for tensor of dimension {}".format(x_org, ndim) | |||
return x | |||
if isinstance(axis, int): | |||
return convert(axis) | |||
elif isinstance(axis, Iterable): | |||
axis_org = axis | |||
axis = list(sorted(map(convert, axis), reverse=reverse)) | |||
for i in range(len(axis) - 1): | |||
assert axis[i] != axis[i + 1], "axis {} contains duplicated indices".format( | |||
axis_org | |||
) | |||
return axis | |||
raise |
@@ -466,9 +466,13 @@ def argmin( | |||
0 | |||
""" | |||
if axis is None: | |||
assert not keepdims, "can not set axis=None and keepdims=True" | |||
inp = inp.flatten() | |||
axis = 0 | |||
axis = utils._normalize_axis(inp.ndim, axis, reverse=True) | |||
if isinstance(axis, collections.abc.Iterable): | |||
axis = list(axis) | |||
axis.sort(reverse=True) | |||
for ai in axis: | |||
op = builtin.Argmin(axis=ai) | |||
@@ -479,11 +483,6 @@ def argmin( | |||
return inp | |||
if axis is None: | |||
assert not keepdims, "can not set axis=None and keepdims=True" | |||
inp = inp.flatten() | |||
axis = 0 | |||
op = builtin.Argmin(axis=axis) | |||
(result,) = apply(op, inp) | |||
if not keepdims: | |||
@@ -525,9 +524,13 @@ def argmax( | |||
5 | |||
""" | |||
if axis is None: | |||
assert not keepdims, "can not set axis=None and keepdims=True" | |||
inp = inp.flatten() | |||
axis = 0 | |||
axis = utils._normalize_axis(inp.ndim, axis, reverse=True) | |||
if isinstance(axis, collections.abc.Iterable): | |||
axis = list(axis) | |||
axis.sort(reverse=True) | |||
for ai in axis: | |||
op = builtin.Argmax(axis=ai) | |||
@@ -538,11 +541,6 @@ def argmax( | |||
return inp | |||
if axis is None: | |||
assert not keepdims, "can not set axis=None and keepdims=True" | |||
inp = inp.flatten() | |||
axis = 0 | |||
op = builtin.Argmax(axis=axis) | |||
(result,) = apply(op, inp) | |||
if not keepdims: | |||
@@ -811,3 +811,19 @@ def test_assert_not_equal(): | |||
y = F.zeros(shape, dtype=np.float32) + 1.1 | |||
with pytest.raises(RuntimeError): | |||
z = F.utils._assert_equal(x, y) | |||
def test_neg_axis(): | |||
x = tensor(np.random.normal(0, 1, (32, 5))) | |||
y = F.argmax(x, axis=-1) | |||
yy = F.argmax(x, axis=1) | |||
np.testing.assert_equal(y.numpy(), yy.numpy()) | |||
y = F.argmax(x, axis=(-1, -2)) | |||
yy = F.argmax(x, axis=(0, 1)) | |||
np.testing.assert_equal(y.numpy(), yy.numpy()) | |||
y = F.argmin(x, axis=(-1, -2)) | |||
yy = F.argmin(x, axis=(0, 1)) | |||
np.testing.assert_equal(y.numpy(), yy.numpy()) |
@@ -9,6 +9,7 @@ | |||
from functools import partial | |||
import numpy as np | |||
import pytest | |||
from utils import opr_test | |||
import megengine.functional as F | |||
@@ -48,6 +49,14 @@ def common_test_reduce(opr, ref_opr): | |||
ref_fn=lambda x: ref_opr(x, axis=axis).astype(np.int32), | |||
axis=axis, | |||
) | |||
# test negative axis | |||
axis = axis - len(data1_shape) | |||
opr_test( | |||
cases, | |||
opr, | |||
ref_fn=lambda x: ref_opr(x, axis=axis).astype(np.int32), | |||
axis=axis, | |||
) | |||
def test_sum(): | |||
@@ -137,3 +146,14 @@ def test_normalize(): | |||
cases[0]["input"][0, 0, 0, :] = 0 | |||
cases[1]["input"][0, 0, 0, :] = 0 | |||
opr_test(cases, partial(F.normalize, axis=3), ref_fn=partial(np_normalize, axis=3)) | |||
def test_sum_neg_axis(): | |||
shape = (2, 3) | |||
data = np.random.random(shape).astype(np.float32) | |||
for axis in (-1, -2, (-2, 1), (-1, 0)): | |||
get = F.sum(tensor(data), axis=axis) | |||
ref = np.sum(data, axis=axis) | |||
np.testing.assert_allclose(get.numpy(), ref, rtol=1e-6) | |||
with pytest.raises(AssertionError): | |||
F.sum(tensor(data), axis=(-1, 1)) |