@@ -56,7 +56,7 @@ struct ArgmxxOp { | |||||
ArgmxxOp(stype_ *src, dt_int32 *dst, uint32_t A, uint32_t B, uint32_t C): | ArgmxxOp(stype_ *src, dt_int32 *dst, uint32_t A, uint32_t B, uint32_t C): | ||||
src(src), dst(dst), A(A), B(B), C(C), | src(src), dst(dst), A(A), B(B), C(C), | ||||
INIT(wtype(is_max ? DTypeTrait<stype_>::min() : | INIT(wtype(is_max ? DTypeTrait<stype_>::min() : | ||||
DTypeTrait<stype_>::max(), -1)) | |||||
DTypeTrait<stype_>::max(), 0)) | |||||
{ | { | ||||
} | } | ||||
MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) | MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) | ||||
@@ -45,7 +45,7 @@ void exec_forward(_megdnn_tensor_in src, | |||||
reduce::get_ABC(src.layout, A, B, C, param.axis); | reduce::get_ABC(src.layout, A, B, C, param.axis); | ||||
for (size_t a = 0; a < A; ++a) for (size_t c = 0; c < C; ++c) { | for (size_t a = 0; a < A; ++a) for (size_t c = 0; c < C; ++c) { | ||||
float best_val = traits<is_max>::init; | float best_val = traits<is_max>::init; | ||||
size_t best_arg = -1; | |||||
size_t best_arg = 0; | |||||
for (size_t b = 0; b < B; ++b) { | for (size_t b = 0; b < B; ++b) { | ||||
float curr_val = float(src.ptr<T>()[(a*B+b)*C+c]); | float curr_val = float(src.ptr<T>()[(a*B+b)*C+c]); | ||||
if (traits<is_max>::better_than(curr_val, best_val)) { | if (traits<is_max>::better_than(curr_val, best_val)) { | ||||
@@ -527,3 +527,20 @@ def test_nms_is_same(): | |||||
assert op3 != op4 | assert op3 != op4 | ||||
def test_argmxx_on_inf(): | |||||
def run_argmax(): | |||||
x = F.zeros((100, 100)) | |||||
x[:] = -float("inf") | |||||
idxs = F.argmax(x, axis=0) | |||||
return idxs | |||||
def run_argmin(): | |||||
x = F.zeros((100, 100)) | |||||
x[:] = float("inf") | |||||
idxs = F.argmin(x, axis=0) | |||||
return idxs | |||||
assert all(run_argmax() >= 0) | |||||
assert all(run_argmin() >= 0) |