From 9b4132192693264917c89dc7241f4a7b8026887e Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 22 Apr 2020 16:50:17 +0800 Subject: [PATCH] fix(mge/functional): support negative axis in math.py GitOrigin-RevId: 75143a73096eb46a1ef821b95a4228fd8f06ab58 --- python_module/src/python/opr_template.py | 2 -- python_module/src/swig/symbol_var_SymbolVar.py | 1 - python_module/test/unit/functional/test_math.py | 43 +++++++++++++------------ 3 files changed, 22 insertions(+), 24 deletions(-) diff --git a/python_module/src/python/opr_template.py b/python_module/src/python/opr_template.py index c83d4172..29ea0a9d 100644 --- a/python_module/src/python/opr_template.py +++ b/python_module/src/python/opr_template.py @@ -106,8 +106,6 @@ def reduce_(src, mode, axis=None, keepdims=False, *, inputs.append(1) assert not keepdims, 'can not set axis=None and keepdims=True' else: - assert isinstance(axis, int) and axis >= 0, ( - 'bad axis: {!r}'.format(axis)) remove_axis = not keepdims kwargs['axis'] = axis diff --git a/python_module/src/swig/symbol_var_SymbolVar.py b/python_module/src/swig/symbol_var_SymbolVar.py index f2c10437..96fb3383 100644 --- a/python_module/src/swig/symbol_var_SymbolVar.py +++ b/python_module/src/swig/symbol_var_SymbolVar.py @@ -196,7 +196,6 @@ def shape(self): return get_var_shape(self) def axis_shape(self, axis): - assert axis >= 0 from .opr import get_var_shape return get_var_shape(self, axis=axis) diff --git a/python_module/test/unit/functional/test_math.py b/python_module/test/unit/functional/test_math.py index 96e49408..b5cb4a49 100644 --- a/python_module/test/unit/functional/test_math.py +++ b/python_module/test/unit/functional/test_math.py @@ -20,30 +20,31 @@ def common_test_reduce(opr, ref_opr): cases = [{"input": data1}, {"input": data2}] if opr not in (F.argmin, F.argmax): + # test default axis opr_test(cases, opr, ref_fn=ref_opr) - - axis = 2 - opr_test(cases, opr, ref_fn=lambda x: ref_opr(x, axis=axis), axis=axis) - - axis = 2 - keepdims = True - opr_test( - cases, - opr, - ref_fn=lambda x: ref_opr(x, axis=axis, keepdims=keepdims), - axis=axis, - keepdims=keepdims, - ) + # test all axises in range of input shape + for axis in range(-3, 3): + # test keepdims False + opr_test(cases, opr, ref_fn=lambda x: ref_opr(x, axis=axis), axis=axis) + # test keepdims True + opr_test( + cases, + opr, + ref_fn=lambda x: ref_opr(x, axis=axis, keepdims=True), + axis=axis, + keepdims=True, + ) else: + # test defaut axis opr_test(cases, opr, ref_fn=lambda x: ref_opr(x).astype(np.int32)) - - axis = 2 - opr_test( - cases, - opr, - ref_fn=lambda x: ref_opr(x, axis=axis).astype(np.int32), - axis=axis, - ) + # test all axises in range of input shape + for axis in range(0, 3): + opr_test( + cases, + opr, + ref_fn=lambda x: ref_opr(x, axis=axis).astype(np.int32), + axis=axis, + ) def test_sum():