Browse Source

fix(dnn/opencl): fix elemwise negative stride support

GitOrigin-RevId: 506d7e6104
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
0e8b81c20e
2 changed files with 128 additions and 0 deletions
  1. +125
    -0
      dnn/test/common/elemwise.cpp
  2. +3
    -0
      dnn/test/common/elemwise.h

+ 125
- 0
dnn/test/common/elemwise.cpp View File

@@ -924,6 +924,131 @@ DEF_TEST(all_modes) {
#undef run
}

#define UNARY_NEGATIVE_STRIDE_TEST_CASE_INT(_optr) \
checker.set_param(Mode::_optr) \
.execl({{{3, 4, 1}, {-12, -4, -1}, dtype::Int8()}, {}}); \
checker.set_param(Mode::_optr) \
.execl({{{3, 4, 1}, {-12, -4, -1}, dtype::Int16()}, {}}); \
checker.set_param(Mode::_optr) \
.execl({{{3, 4, 1}, {-12, -4, -1}, dtype::Int32()}, {}});

#define UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(_optr) \
checker.set_param(Mode::_optr) \
.execl({{{3, 4, 1}, {-12, -4, -1}, dtype::Float32()}, {}});

#define BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_INT \
UNARY_NEGATIVE_STRIDE_TEST_CASE_INT(RELU); \
UNARY_NEGATIVE_STRIDE_TEST_CASE_INT(ABS);

#define BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT \
UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(ABS) \
UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(LOG) \
UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(COS) \
UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(SIN) \
UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(FLOOR) \
UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(CEIL) \
UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(SIGMOID) \
UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(EXP) \
UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(RELU) \
UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(ROUND) \
UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(TANH) \
UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(FAST_TANH)

DEF_TEST(unary_negative_stride) {
using Mode = ElemwiseForward::Param::Mode;
Checker<ElemwiseForward> checker(handle);
BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_INT;

UniformFloatRNG rng(1e-2, 6e1);
checker.set_rng(0, &rng);
checker.set_epsilon(1e-5);
BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT;
}

#undef UNARY_NEGATIVE_STRIDE_TEST_CASE_INT
#undef UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT
#undef BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_INT
#undef BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT

#define BINARY_NEGATIVE_STRIDE_TEST_CASE_INT(_optr) \
checker.set_param(Mode::_optr) \
.execl({{{3, 4, 7}, {-28, -7, -1}, dtype::Int8()}, \
{{1, 4, 1}, dtype::Int8()}, \
{}}); \
checker.set_param(Mode::_optr) \
.execl({{{3, 4, 7}, {-28, -7, -1}, dtype::Int16()}, \
{{1, 4, 1}, dtype::Int16()}, \
{}}); \
checker.set_param(Mode::_optr) \
.execl({{{3, 4, 7}, {-28, -7, -1}, dtype::Int32()}, \
{{1, 4, 1}, dtype::Int32()}, \
{}});

#define BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(_optr) \
checker.set_param(Mode::_optr) \
.execl({{{3, 4, 7}, {-28, -7, -1}, dtype::Float32()}, \
{{1, 4, 1}, dtype::Float32()}, \
{}});

#define BUILD_BINARY_NEGATIVE_STRIDE_TEST_CASE_INT \
BINARY_NEGATIVE_STRIDE_TEST_CASE_INT(ADD) \
BINARY_NEGATIVE_STRIDE_TEST_CASE_INT(MUL) \
BINARY_NEGATIVE_STRIDE_TEST_CASE_INT(MAX) \
BINARY_NEGATIVE_STRIDE_TEST_CASE_INT(MIN) \
BINARY_NEGATIVE_STRIDE_TEST_CASE_INT(SUB)

#define BUILD_BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32 \
BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(POW) \
BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(TRUE_DIV) \
BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(FUSE_ADD_SIGMOID) \
BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(FUSE_ADD_TANH) \
BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(FUSE_ADD_RELU) \
BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(FUSE_ADD_H_SWISH) \
BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(FAST_TANH_GRAD) \
BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(H_SWISH_GRAD)

DEF_TEST(binary_negative_stride) {
using Mode = ElemwiseForward::Param::Mode;
Checker<ElemwiseForward> checker(handle);
BUILD_BINARY_NEGATIVE_STRIDE_TEST_CASE_INT;

UniformFloatRNG rng(1e-2, 2e1);
checker.set_rng(0, &rng);
checker.set_epsilon(1e-5);
BUILD_BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32;
}

#undef BINARY_NEGATIVE_STRIDE_TEST_CASE_INT
#undef BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32
#undef BUILD_BINARY_NEGATIVE_STRIDE_TEST_CASE_INT
#undef BUILD_BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32

DEF_TEST(ternary_negative_stride) {
using Mode = ElemwiseForward::Param::Mode;
Checker<ElemwiseForward> checker(handle);
checker.set_param(Mode::FUSE_MUL_ADD3);
checker.execl({{{1, 7}, {-7, -1}, dtype::Int8()},
{{1, 7}, {-3, -1}, dtype::Int8()},
{{1, 7}, {-7, -1}, dtype::Int8()},
{}});
checker.execl({{{1, 7}, {-7, -1}, dtype::Int16()},
{{1, 7}, {-3, -1}, dtype::Int16()},
{{1, 7}, {-7, -1}, dtype::Int16()},
{}});
checker.execl({{{1, 7}, {-7, -1}, dtype::Int32()},
{{1, 7}, {-3, -1}, dtype::Int32()},
{{1, 7}, {-7, -1}, dtype::Int32()},
{}});

UniformFloatRNG rng(1e-2, 2e1);
checker.set_rng(0, &rng);
checker.set_epsilon(1e-5);
checker.execl({{{1, 7}, {-7, -1}, dtype::Float32()},
{{1, 7}, {-3, -1}, dtype::Float32()},
{{1, 7}, {-7, -1}, dtype::Float32()},
{}});
}

TEST(TEST_ELEMWISE, MODE_TRAIT) {
using M = Elemwise::Mode;
using T = Elemwise::ModeTrait;


+ 3
- 0
dnn/test/common/elemwise.h View File

@@ -40,6 +40,9 @@ namespace elemwise {
cb(unary3) \
cb(binary3) \
cb(all_modes) \
cb(unary_negative_stride) \
cb(binary_negative_stride) \
cb(ternary_negative_stride) \

#define FOREACH_ELEMWISE_CASE(cb) \
cb(FIRST_ELEMWISE_CASE) \


Loading…
Cancel
Save