Browse Source

fix(dnn): correct behaviour of floor div for int tensor

GitOrigin-RevId: 1444f69cce
tags/v1.7.2.m1
Megvii Engine Team 3 years ago
parent
commit
d9a46ea47b
4 changed files with 35 additions and 3 deletions
  1. +10
    -1
      dnn/src/common/elemwise/kern_defs.cuh
  2. +11
    -1
      imperative/python/test/unit/functional/test_elemwise.py
  3. +13
    -0
      src/opr/test/basic_arith/elemwise.cpp
  4. +1
    -1
      src/opr/test/basic_arith/elemwise_binary_trait_def.inl

+ 10
- 1
dnn/src/common/elemwise/kern_defs.cuh View File

@@ -119,6 +119,15 @@ __device__ __host__ inline float dispatch_powf(float x, float y) {
return powf(x, y);
}

__device__ __host__ inline int dispatch_floordiv_int(int x, int y) {
if ((x ^ y) < 0) {
const auto quot = x / y;
const auto rem = x % y;
return rem ? quot - 1 : quot;
}
return x / y;
}

#include "src/common/elemwise/each_mode.inl"

template <megcorePlatform_t plat, uint32_t mode, typename dtype>
@@ -227,7 +236,7 @@ DEF_KERN(dt_bool, LT, x < y);
DEF_KERN(dt_bool, LEQ, x <= y);
DEF_KERN(dt_bool, EQ, x == y);

DEF_KERN_INT(FLOOR_DIV, x / y);
DEF_KERN_INT(FLOOR_DIV, dispatch_floordiv_int(x, y));
DEF_KERN_FLOAT(FLOOR_DIV, floorf(x / y));

DEF_KERN_INT(MOD, x % y);


+ 11
- 1
imperative/python/test/unit/functional/test_elemwise.py View File

@@ -59,7 +59,7 @@ def test_multiply():

def test_div():
np.testing.assert_allclose(
F.div(tensor([3, 4]), 2).numpy(),
F.div(tensor([3.0, 4.0]), 2).numpy(),
np.divide(np.array([3, 4], dtype=np.float32), 2),
)

@@ -67,6 +67,16 @@ def test_div():
(tensor([3, 4]) / 2).numpy(), np.divide(np.array([3, 4], dtype=np.float32), 2),
)

np.testing.assert_allclose(
F.floor_div(tensor([-5.0, -7.0]), 2).numpy(),
np.floor_divide(np.array([-5.0, -7.0], dtype=np.float32), 2),
)

np.testing.assert_allclose(
(tensor([-5, -7]) // 2).numpy(),
np.floor_divide(np.array([-5, -7], dtype=np.int32), 2),
)


def test_clamp():
"""Fix an issue when `lower` or `upper` is 0, it will be recognized as `False` and


+ 13
- 0
src/opr/test/basic_arith/elemwise.cpp View File

@@ -39,6 +39,19 @@ int do_mod(int a, int b) {
return a % b;
}

float do_floor_div(float a, float b) {
return std::floor(a / b);
}

int do_floor_div(int a, int b) {
if ((a ^ b) < 0) {
const auto quot = a / b;
const auto rem = a % b;
return rem ? quot - 1 : quot;
}
return a / b;
}

float do_erfinv(float x) {
return erfinvf(x);
}


+ 1
- 1
src/opr/test/basic_arith/elemwise_binary_trait_def.inl View File

@@ -41,7 +41,7 @@ DEF_TRAIT(LT, x < y)
#define _ALLOW_INT true
DEF_TRAIT(ABS_GRAD, x > 0 ? y : -y)
DEF_TRAIT(ADD, x + y)
DEF_TRAIT(FLOOR_DIV, floor(x / y))
DEF_TRAIT(FLOOR_DIV, do_floor_div(x, y))
DEF_TRAIT(MAX, std::max(x, y))
DEF_TRAIT(MIN, std::min(x, y))
DEF_TRAIT(MOD, do_mod(x, y))


Loading…
Cancel
Save