Browse Source

fix(dnn/argmxx): fix argmxx on inf

GitOrigin-RevId: 740f67b73a
release-1.1
Megvii Engine Team 4 years ago
parent
commit
215f88f373
3 changed files with 19 additions and 2 deletions
  1. +1
    -1
      dnn/src/common/argmxx_helper.h
  2. +1
    -1
      dnn/src/naive/argmxx/opr_impl.cpp
  3. +17
    -0
      imperative/python/test/unit/functional/test_functional.py

+ 1
- 1
dnn/src/common/argmxx_helper.h View File

@@ -56,7 +56,7 @@ struct ArgmxxOp {
ArgmxxOp(stype_ *src, dt_int32 *dst, uint32_t A, uint32_t B, uint32_t C):
src(src), dst(dst), A(A), B(B), C(C),
INIT(wtype(is_max ? DTypeTrait<stype_>::min() :
DTypeTrait<stype_>::max(), -1))
DTypeTrait<stype_>::max(), 0))
{
}
MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx)


+ 1
- 1
dnn/src/naive/argmxx/opr_impl.cpp View File

@@ -45,7 +45,7 @@ void exec_forward(_megdnn_tensor_in src,
reduce::get_ABC(src.layout, A, B, C, param.axis);
for (size_t a = 0; a < A; ++a) for (size_t c = 0; c < C; ++c) {
float best_val = traits<is_max>::init;
size_t best_arg = -1;
size_t best_arg = 0;
for (size_t b = 0; b < B; ++b) {
float curr_val = float(src.ptr<T>()[(a*B+b)*C+c]);
if (traits<is_max>::better_than(curr_val, best_val)) {


+ 17
- 0
imperative/python/test/unit/functional/test_functional.py View File

@@ -527,3 +527,20 @@ def test_nms_is_same():
assert op3 != op4




def test_argmxx_on_inf():
def run_argmax():
x = F.zeros((100, 100))
x[:] = -float("inf")
idxs = F.argmax(x, axis=0)
return idxs

def run_argmin():
x = F.zeros((100, 100))
x[:] = float("inf")
idxs = F.argmin(x, axis=0)
return idxs

assert all(run_argmax() >= 0)
assert all(run_argmin() >= 0)

Loading…
Cancel
Save