Browse Source

feat(cuda): add convbias ptx algo testcase

GitOrigin-RevId: 9ad6d4561f
HuaHua404-patch-4
Megvii Engine Team 2 years ago
parent
commit
1a2ed8c47b
2 changed files with 37 additions and 0 deletions
  1. +36
    -0
      dnn/test/common/conv_bias.cpp
  2. +1
    -0
      dnn/test/common/conv_bias.h

+ 36
- 0
dnn/test/common/conv_bias.cpp View File

@@ -524,6 +524,42 @@ std::vector<TestArg> get_int8_nchw4_args_check_bounds(size_t kernel_size) {
return args;
}

std::vector<TestArg> get_int4_nchw64_args_ptx(size_t kernel_size, bool is_uint4) {
std::vector<TestArg> args;
param::ConvBias cur_param;

using NLMode = param::ConvBias::NonlineMode;

// clang-format off
for (auto nlmode : {NLMode::RELU, NLMode::IDENTITY}) {//{NLMode::H_SWISH} are not currently supported
for (auto mode : {param::ConvBias::Mode::CROSS_CORRELATION}) {
for (size_t b : {3, 7}) {
for (size_t ic : {64, 128}) {
for (size_t oc : {64, 320}) {
for (size_t h : {13}) {
for (size_t w : {28}) {
for (int p : {0, static_cast<int>(kernel_size / 2)}) {
for (size_t s : {1, 2}) {
if (is_uint4 && nlmode == NLMode::H_SWISH) continue;
size_t f = kernel_size;
cur_param.mode = mode;
cur_param.nonlineMode = nlmode;
cur_param.format = param::ConvBias::Format::NCHW64;
cur_param.sparse = param::ConvBias::Sparse::DENSE;
cur_param.pad_h = cur_param.pad_w = p;
cur_param.stride_h = cur_param.stride_w = s;

//! bias channel
args.emplace_back(cur_param, TensorShape{b, ic / 64, h, w, 64},
TensorShape{oc, ic / 64, f, f, 64},
TensorShape{1, oc / 64, 1, 1, 64});

} } } } } } } } }
// clang-format on

return args;
}

std::vector<TestArg> get_int8_nchw4_args_small_batch(size_t kernel_size) {
std::vector<TestArg> args;
param::ConvBias cur_param;


+ 1
- 0
dnn/test/common/conv_bias.h View File

@@ -30,6 +30,7 @@ std::vector<TestArg> get_quantized_winograd_mk_packed_args(
std::vector<TestArg> get_quantized_args_with_nlmode(
param::ConvBias::NonlineMode nlmode);
std::vector<TestArg> get_quantized_args();
std::vector<TestArg> get_int4_nchw64_args_ptx(size_t kernel_size, bool is_uint4);
std::vector<TestArg> get_int8_nchw4_args(size_t kernel_size);
std::vector<TestArg> get_int8_nchw4_args_check_bounds(size_t kernel_size);
std::vector<TestArg> get_int8_nchw4_small_channel_args(size_t kernel_size);


Loading…
Cancel
Save