Browse Source

fix(mge/functional): support negative axis in math.py

GitOrigin-RevId: 75143a7309
tags/v0.4.0
Megvii Engine Team Xinran Xu 5 years ago
parent
commit
9b41321926
3 changed files with 22 additions and 24 deletions
  1. +0
    -2
      python_module/src/python/opr_template.py
  2. +0
    -1
      python_module/src/swig/symbol_var_SymbolVar.py
  3. +22
    -21
      python_module/test/unit/functional/test_math.py

+ 0
- 2
python_module/src/python/opr_template.py View File

@@ -106,8 +106,6 @@ def reduce_(src, mode, axis=None, keepdims=False, *,
inputs.append(1) inputs.append(1)
assert not keepdims, 'can not set axis=None and keepdims=True' assert not keepdims, 'can not set axis=None and keepdims=True'
else: else:
assert isinstance(axis, int) and axis >= 0, (
'bad axis: {!r}'.format(axis))
remove_axis = not keepdims remove_axis = not keepdims
kwargs['axis'] = axis kwargs['axis'] = axis




+ 0
- 1
python_module/src/swig/symbol_var_SymbolVar.py View File

@@ -196,7 +196,6 @@ def shape(self):
return get_var_shape(self) return get_var_shape(self)


def axis_shape(self, axis): def axis_shape(self, axis):
assert axis >= 0
from .opr import get_var_shape from .opr import get_var_shape
return get_var_shape(self, axis=axis) return get_var_shape(self, axis=axis)




+ 22
- 21
python_module/test/unit/functional/test_math.py View File

@@ -20,30 +20,31 @@ def common_test_reduce(opr, ref_opr):
cases = [{"input": data1}, {"input": data2}] cases = [{"input": data1}, {"input": data2}]


if opr not in (F.argmin, F.argmax): if opr not in (F.argmin, F.argmax):
# test default axis
opr_test(cases, opr, ref_fn=ref_opr) opr_test(cases, opr, ref_fn=ref_opr)

axis = 2
opr_test(cases, opr, ref_fn=lambda x: ref_opr(x, axis=axis), axis=axis)

axis = 2
keepdims = True
opr_test(
cases,
opr,
ref_fn=lambda x: ref_opr(x, axis=axis, keepdims=keepdims),
axis=axis,
keepdims=keepdims,
)
# test all axises in range of input shape
for axis in range(-3, 3):
# test keepdims False
opr_test(cases, opr, ref_fn=lambda x: ref_opr(x, axis=axis), axis=axis)
# test keepdims True
opr_test(
cases,
opr,
ref_fn=lambda x: ref_opr(x, axis=axis, keepdims=True),
axis=axis,
keepdims=True,
)
else: else:
# test defaut axis
opr_test(cases, opr, ref_fn=lambda x: ref_opr(x).astype(np.int32)) opr_test(cases, opr, ref_fn=lambda x: ref_opr(x).astype(np.int32))
axis = 2
opr_test(
cases,
opr,
ref_fn=lambda x: ref_opr(x, axis=axis).astype(np.int32),
axis=axis,
)
# test all axises in range of input shape
for axis in range(0, 3):
opr_test(
cases,
opr,
ref_fn=lambda x: ref_opr(x, axis=axis).astype(np.int32),
axis=axis,
)




def test_sum(): def test_sum():


Loading…
Cancel
Save