diff --git a/dnn/src/common/argmxx_helper.h b/dnn/src/common/argmxx_helper.h index 87c69be3..a59e26fc 100644 --- a/dnn/src/common/argmxx_helper.h +++ b/dnn/src/common/argmxx_helper.h @@ -56,7 +56,7 @@ struct ArgmxxOp { ArgmxxOp(stype_ *src, dt_int32 *dst, uint32_t A, uint32_t B, uint32_t C): src(src), dst(dst), A(A), B(B), C(C), INIT(wtype(is_max ? DTypeTrait::min() : - DTypeTrait::max(), -1)) + DTypeTrait::max(), 0)) { } MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) diff --git a/dnn/src/naive/argmxx/opr_impl.cpp b/dnn/src/naive/argmxx/opr_impl.cpp index 5754303a..8cca17ae 100644 --- a/dnn/src/naive/argmxx/opr_impl.cpp +++ b/dnn/src/naive/argmxx/opr_impl.cpp @@ -45,7 +45,7 @@ void exec_forward(_megdnn_tensor_in src, reduce::get_ABC(src.layout, A, B, C, param.axis); for (size_t a = 0; a < A; ++a) for (size_t c = 0; c < C; ++c) { float best_val = traits::init; - size_t best_arg = -1; + size_t best_arg = 0; for (size_t b = 0; b < B; ++b) { float curr_val = float(src.ptr()[(a*B+b)*C+c]); if (traits::better_than(curr_val, best_val)) { diff --git a/imperative/python/test/unit/functional/test_functional.py b/imperative/python/test/unit/functional/test_functional.py index 76ca492d..acdf1a51 100644 --- a/imperative/python/test/unit/functional/test_functional.py +++ b/imperative/python/test/unit/functional/test_functional.py @@ -527,3 +527,20 @@ def test_nms_is_same(): assert op3 != op4 + + +def test_argmxx_on_inf(): + def run_argmax(): + x = F.zeros((100, 100)) + x[:] = -float("inf") + idxs = F.argmax(x, axis=0) + return idxs + + def run_argmin(): + x = F.zeros((100, 100)) + x[:] = float("inf") + idxs = F.argmin(x, axis=0) + return idxs + + assert all(run_argmax() >= 0) + assert all(run_argmin() >= 0)