Browse Source

fix(dnn): fix the modulo of int

GitOrigin-RevId: 6f7280246b
master
Megvii Engine Team 2 years ago
parent
commit
0ebd4400d5
4 changed files with 23 additions and 3 deletions
  1. +1
    -1
      dnn/src/common/elemwise/kern_defs.cuh
  2. +1
    -1
      dnn/test/common/elemwise.cpp
  3. +20
    -0
      dnn/test/naive/elemwise_multi_type.cpp
  4. +1
    -1
      src/opr/test/basic_arith/elemwise.cpp

+ 1
- 1
dnn/src/common/elemwise/kern_defs.cuh View File

@@ -229,7 +229,7 @@ DEF_KERN(dt_bool, EQ, x == y);
DEF_KERN_INT(FLOOR_DIV, dispatch_floordiv_int(x, y));
DEF_KERN_FLOAT(FLOOR_DIV, floorf(x / y));

DEF_KERN_INT(MOD, x % y);
DEF_KERN_INT(MOD, ((y + x % y) % y)); // consistent with python modulo
DEF_KERN_FLOAT(MOD, fmodf(x, y));

DEF_KERN_INT(SHL, x << y);


+ 1
- 1
dnn/test/common/elemwise.cpp View File

@@ -878,8 +878,8 @@ DEF_TEST(all_modes) {
} while (0)

if (trait.allow_int) {
run(dtype::Int8{});
run(dtype::Int32{});
run(dtype::Int8{});
}
if (trait.allow_float) {
DNN_FLOAT16_SELECT(


+ 20
- 0
dnn/test/naive/elemwise_multi_type.cpp View File

@@ -280,4 +280,24 @@ TEST_F(NAIVE, ELEMWISE_QUANTIZED_MODE_TERNARY) {
}
}

TEST_F(NAIVE, ELELMWISE_INT_MODULO) {
Checker<Elemwise> checker(handle(), /* check_dispatch */ false);
Elemwise::Param param;
param.mode = Elemwise::Param::Mode::MOD;

checker.set_param(param).exect(
Testcase{
TensorValue(
{10}, dtype::Int32(),
{10, 24, -6, -20, 10, -90, 45, 3, -1, 0}),
TensorValue(
{10}, dtype::Int32(), {3, 7, 5, -3, -6, 11, 7, -1, 8, -1}),
{}},
Testcase{
{},
{},
TensorValue(
{10}, dtype::Int32(), {1, 3, 4, -2, -2, 9, 3, 0, 7, 0})});
}

// vim: syntax=cpp.doxygen

+ 1
- 1
src/opr/test/basic_arith/elemwise.cpp View File

@@ -25,7 +25,7 @@ float do_mod(float a, float b) {
}

int do_mod(int a, int b) {
return a % b;
return (a % b + b) % b;
}

float do_floor_div(float a, float b) {


Loading…
Cancel
Save