|
@@ -20,30 +20,31 @@ def common_test_reduce(opr, ref_opr): |
|
|
cases = [{"input": data1}, {"input": data2}] |
|
|
cases = [{"input": data1}, {"input": data2}] |
|
|
|
|
|
|
|
|
if opr not in (F.argmin, F.argmax): |
|
|
if opr not in (F.argmin, F.argmax): |
|
|
|
|
|
# test default axis |
|
|
opr_test(cases, opr, ref_fn=ref_opr) |
|
|
opr_test(cases, opr, ref_fn=ref_opr) |
|
|
|
|
|
|
|
|
axis = 2 |
|
|
|
|
|
opr_test(cases, opr, ref_fn=lambda x: ref_opr(x, axis=axis), axis=axis) |
|
|
|
|
|
|
|
|
|
|
|
axis = 2 |
|
|
|
|
|
keepdims = True |
|
|
|
|
|
opr_test( |
|
|
|
|
|
cases, |
|
|
|
|
|
opr, |
|
|
|
|
|
ref_fn=lambda x: ref_opr(x, axis=axis, keepdims=keepdims), |
|
|
|
|
|
axis=axis, |
|
|
|
|
|
keepdims=keepdims, |
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
# test all axises in range of input shape |
|
|
|
|
|
for axis in range(-3, 3): |
|
|
|
|
|
# test keepdims False |
|
|
|
|
|
opr_test(cases, opr, ref_fn=lambda x: ref_opr(x, axis=axis), axis=axis) |
|
|
|
|
|
# test keepdims True |
|
|
|
|
|
opr_test( |
|
|
|
|
|
cases, |
|
|
|
|
|
opr, |
|
|
|
|
|
ref_fn=lambda x: ref_opr(x, axis=axis, keepdims=True), |
|
|
|
|
|
axis=axis, |
|
|
|
|
|
keepdims=True, |
|
|
|
|
|
) |
|
|
else: |
|
|
else: |
|
|
|
|
|
# test defaut axis |
|
|
opr_test(cases, opr, ref_fn=lambda x: ref_opr(x).astype(np.int32)) |
|
|
opr_test(cases, opr, ref_fn=lambda x: ref_opr(x).astype(np.int32)) |
|
|
|
|
|
|
|
|
axis = 2 |
|
|
|
|
|
opr_test( |
|
|
|
|
|
cases, |
|
|
|
|
|
opr, |
|
|
|
|
|
ref_fn=lambda x: ref_opr(x, axis=axis).astype(np.int32), |
|
|
|
|
|
axis=axis, |
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
# test all axises in range of input shape |
|
|
|
|
|
for axis in range(0, 3): |
|
|
|
|
|
opr_test( |
|
|
|
|
|
cases, |
|
|
|
|
|
opr, |
|
|
|
|
|
ref_fn=lambda x: ref_opr(x, axis=axis).astype(np.int32), |
|
|
|
|
|
axis=axis, |
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_sum(): |
|
|
def test_sum(): |
|
|