Browse Source

feat(dnn/cuda): opt dp4a conv for small channel base on cutlass

GitOrigin-RevId: 2a74c35f27
release-1.1
Megvii Engine Team 4 years ago
parent
commit
739f927c4c
11 changed files with 268 additions and 26 deletions
  1. +11
    -10
      dnn/src/cuda/conv_bias/algo.cpp
  2. +4
    -3
      dnn/src/cuda/conv_bias/algo.h
  3. +13
    -12
      dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cu
  4. +35
    -0
      dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_16x128x16_16x128x16_hswish.cu
  5. +35
    -0
      dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_16x128x16_16x128x16_id.cu
  6. +35
    -0
      dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_16x128x16_16x128x16_relu.cu
  7. +35
    -0
      dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_1x1_16x128x16_16x128x16_hswish.cu
  8. +35
    -0
      dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_1x1_16x128x16_16x128x16_id.cu
  9. +35
    -0
      dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_1x1_16x128x16_16x128x16_relu.cu
  10. +29
    -0
      dnn/test/cuda/conv_bias_int8.cpp
  11. +1
    -1
      third_party/cutlass

+ 11
- 10
dnn/src/cuda/conv_bias/algo.cpp View File

@@ -260,16 +260,17 @@ void ConvBiasForwardImpl::AlgoPack::fill_imma_algos() {

void ConvBiasForwardImpl::AlgoPack::fill_dp4a_algos() {
using AlgoParam = AlgoInt8NCHW4DotProdImplicitGemm::AlgoParam;
int8_nchw4_dotprod.emplace_back(AlgoParam{128, 128, 32, 64, 32, 32});
int8_nchw4_dotprod.emplace_back(AlgoParam{128, 64, 32, 64, 32, 32});
int8_nchw4_dotprod.emplace_back(AlgoParam{64, 128, 32, 64, 32, 32});
int8_nchw4_dotprod.emplace_back(AlgoParam{32, 128, 32, 32, 64, 32});
int8_nchw4_dotprod.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32});
int8_nchw4_dotprod.emplace_back(AlgoParam{64, 64, 32, 64, 32, 32});
int8_nchw4_dotprod.emplace_back(AlgoParam{32, 64, 32, 32, 64, 32});
int8_nchw4_dotprod.emplace_back(AlgoParam{64, 32, 32, 64, 32, 32});
int8_nchw4_dotprod.emplace_back(AlgoParam{32, 32, 32, 32, 32, 32});
int8_nchw4_dotprod.emplace_back(AlgoParam{16, 64, 8, 16, 64, 8});
int8_nchw4_dotprod.emplace_back(AlgoParam{128, 128, 32, 64, 32, 32, 2});
int8_nchw4_dotprod.emplace_back(AlgoParam{128, 64, 32, 64, 32, 32, 2});
int8_nchw4_dotprod.emplace_back(AlgoParam{64, 128, 32, 64, 32, 32, 2});
int8_nchw4_dotprod.emplace_back(AlgoParam{32, 128, 32, 32, 64, 32, 2});
int8_nchw4_dotprod.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 2});
int8_nchw4_dotprod.emplace_back(AlgoParam{64, 64, 32, 64, 32, 32, 2});
int8_nchw4_dotprod.emplace_back(AlgoParam{32, 64, 32, 32, 64, 32, 2});
int8_nchw4_dotprod.emplace_back(AlgoParam{64, 32, 32, 64, 32, 32, 2});
int8_nchw4_dotprod.emplace_back(AlgoParam{32, 32, 32, 32, 32, 32, 2});
int8_nchw4_dotprod.emplace_back(AlgoParam{16, 128, 16, 16, 128, 16, 1});
int8_nchw4_dotprod.emplace_back(AlgoParam{16, 64, 8, 16, 64, 8, 2});
}




+ 4
- 3
dnn/src/cuda/conv_bias/algo.h View File

@@ -407,15 +407,16 @@ public:
int warp_m;
int warp_n;
int warp_k;
int stage;
std::string to_string() {
/// default algorithm
if (threadblock_m == 128 && threadblock_n == 128 &&
threadblock_k == 32 && warp_m == 32 && warp_n == 64 &&
warp_k == 32) {
warp_k == 32 && stage == 2) {
return "";
}
return ssprintf("_%dX%dX%d_%dX%dX%d", threadblock_m, threadblock_n,
threadblock_k, warp_m, warp_n, warp_k);
return ssprintf("_%dX%dX%d_%dX%dX%d_%dstage", threadblock_m, threadblock_n,
threadblock_k, warp_m, warp_n, warp_k, stage);
}
};
AlgoInt8NCHW4DotProdImplicitGemm(AlgoParam algo_param)


+ 13
- 12
dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cu View File

@@ -172,7 +172,7 @@ void megdnn::cuda::cutlass_wrapper::
const GemmCoord& warp_shape, cudaStream_t stream) {
#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \
threadblock_k_, warp_m_, warp_n_, \
warp_k_, aligned_) \
warp_k_, stage_, aligned_) \
if (threadblock_shape.m() == threadblock_m_ && \
threadblock_shape.n() == threadblock_n_ && \
threadblock_shape.k() == threadblock_k_ && \
@@ -194,7 +194,7 @@ void megdnn::cuda::cutlass_wrapper::
cutlass::convolution::threadblock:: \
ConvolutionNCxHWxThreadblockSwizzle< \
cutlass::convolution::ConvType::kConvolution>, \
2, 4, aligned_, NeedLoadFromConstMem>; \
stage_, 4, aligned_, NeedLoadFromConstMem>; \
typename Convolution::ConvolutionParameter conv_param{ \
param.n, param.ci, param.co, param.hi, param.wi, \
param.fh, param.fw, param.ho, param.wo, param.sh, \
@@ -204,16 +204,17 @@ void megdnn::cuda::cutlass_wrapper::
epilogue, stream); \
}
#define DISPATCH_KERNEL \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 32, 64, 32, 32, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 32, 64, 32, 32, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 32, 32, 32, 32, 32, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 4); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 32, 64, 32, 32, 2, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 2, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 32, 64, 32, 32, 2, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 32, 32, 32, 32, 32, 2, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 128, 16, 1, 8); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 2, 4); \
megdnn_assert(false, \
"unsupported threadblock shape (%dx%dx%d) and warp shape " \
"(%dx%dx%d)", \


+ 35
- 0
dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_16x128x16_16x128x16_hswish.cu View File

@@ -0,0 +1,35 @@
#if !MEGDNN_TEGRA_X1
// generated by gen_cuda_conv_bias_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/conv_bias/int8/conv_bias_int8_implicit_gemm_cutlass_wrapper.cuinl"

using LayoutSrc = cutlass::layout::TensorNCxHWx<4>;
using LayoutFilter = cutlass::layout::TensorCxRSKx<4>;
using ThreadBlockShape = cutlass::gemm::GemmShape<16, 128, 16>;
using WarpShape = cutlass::gemm::GemmShape<16, 128, 16>;
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 4>;
using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationHSwishClamp<
int8_t, 4, int32_t, int32_t, float>;
using Convolution = cutlass::convolution::device::Convolution<
int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t,
LayoutSrc, int32_t, LayoutSrc, int32_t,
cutlass::convolution::ConvType::kConvolution, cutlass::arch::OpClassSimt, cutlass::arch::Sm61,
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp,
cutlass::convolution::threadblock::ConvolutionNCxHWxThreadblockSwizzle<
cutlass::convolution::ConvType::kConvolution>,
1, 4, 8, true>;
template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>(
const int8_t* d_src,
const int8_t* d_filter,
const int32_t* d_bias,
const int8_t* d_z,
int8_t* d_dst,
int* workspace,
typename Convolution::ConvolutionParameter const& conv_param,
typename Convolution::EpilogueOutputOp::Params const& epilogue,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif

+ 35
- 0
dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_16x128x16_16x128x16_id.cu View File

@@ -0,0 +1,35 @@
#if !MEGDNN_TEGRA_X1
// generated by gen_cuda_conv_bias_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/conv_bias/int8/conv_bias_int8_implicit_gemm_cutlass_wrapper.cuinl"

using LayoutSrc = cutlass::layout::TensorNCxHWx<4>;
using LayoutFilter = cutlass::layout::TensorCxRSKx<4>;
using ThreadBlockShape = cutlass::gemm::GemmShape<16, 128, 16>;
using WarpShape = cutlass::gemm::GemmShape<16, 128, 16>;
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 4>;
using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationClamp<
int8_t, 4, int32_t, int32_t, float>;
using Convolution = cutlass::convolution::device::Convolution<
int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t,
LayoutSrc, int32_t, LayoutSrc, int32_t,
cutlass::convolution::ConvType::kConvolution, cutlass::arch::OpClassSimt, cutlass::arch::Sm61,
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp,
cutlass::convolution::threadblock::ConvolutionNCxHWxThreadblockSwizzle<
cutlass::convolution::ConvType::kConvolution>,
1, 4, 8, true>;
template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>(
const int8_t* d_src,
const int8_t* d_filter,
const int32_t* d_bias,
const int8_t* d_z,
int8_t* d_dst,
int* workspace,
typename Convolution::ConvolutionParameter const& conv_param,
typename Convolution::EpilogueOutputOp::Params const& epilogue,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif

+ 35
- 0
dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_16x128x16_16x128x16_relu.cu View File

@@ -0,0 +1,35 @@
#if !MEGDNN_TEGRA_X1
// generated by gen_cuda_conv_bias_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/conv_bias/int8/conv_bias_int8_implicit_gemm_cutlass_wrapper.cuinl"

using LayoutSrc = cutlass::layout::TensorNCxHWx<4>;
using LayoutFilter = cutlass::layout::TensorCxRSKx<4>;
using ThreadBlockShape = cutlass::gemm::GemmShape<16, 128, 16>;
using WarpShape = cutlass::gemm::GemmShape<16, 128, 16>;
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 4>;
using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationReluClamp<
int8_t, 4, int32_t, int32_t, float>;
using Convolution = cutlass::convolution::device::Convolution<
int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t,
LayoutSrc, int32_t, LayoutSrc, int32_t,
cutlass::convolution::ConvType::kConvolution, cutlass::arch::OpClassSimt, cutlass::arch::Sm61,
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp,
cutlass::convolution::threadblock::ConvolutionNCxHWxThreadblockSwizzle<
cutlass::convolution::ConvType::kConvolution>,
1, 4, 8, true>;
template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>(
const int8_t* d_src,
const int8_t* d_filter,
const int32_t* d_bias,
const int8_t* d_z,
int8_t* d_dst,
int* workspace,
typename Convolution::ConvolutionParameter const& conv_param,
typename Convolution::EpilogueOutputOp::Params const& epilogue,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif

+ 35
- 0
dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_1x1_16x128x16_16x128x16_hswish.cu View File

@@ -0,0 +1,35 @@
#if !MEGDNN_TEGRA_X1
// generated by gen_cuda_conv_bias_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/conv_bias/int8/conv_bias_int8_implicit_gemm_cutlass_wrapper.cuinl"

using LayoutSrc = cutlass::layout::TensorNCxHWx<4>;
using LayoutFilter = cutlass::layout::TensorCxRSKx<4>;
using ThreadBlockShape = cutlass::gemm::GemmShape<16, 128, 16>;
using WarpShape = cutlass::gemm::GemmShape<16, 128, 16>;
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 4>;
using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationHSwishClamp<
int8_t, 4, int32_t, int32_t, float>;
using Convolution = cutlass::convolution::device::Convolution<
int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t,
LayoutSrc, int32_t, LayoutSrc, int32_t,
cutlass::convolution::ConvType::kConvolution, cutlass::arch::OpClassSimt, cutlass::arch::Sm61,
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp,
cutlass::convolution::threadblock::ConvolutionNCxHWxThreadblockSwizzle<
cutlass::convolution::ConvType::kConvolution>,
1, 4, 8, false>;
template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>(
const int8_t* d_src,
const int8_t* d_filter,
const int32_t* d_bias,
const int8_t* d_z,
int8_t* d_dst,
int* workspace,
typename Convolution::ConvolutionParameter const& conv_param,
typename Convolution::EpilogueOutputOp::Params const& epilogue,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif

+ 35
- 0
dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_1x1_16x128x16_16x128x16_id.cu View File

@@ -0,0 +1,35 @@
#if !MEGDNN_TEGRA_X1
// generated by gen_cuda_conv_bias_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/conv_bias/int8/conv_bias_int8_implicit_gemm_cutlass_wrapper.cuinl"

using LayoutSrc = cutlass::layout::TensorNCxHWx<4>;
using LayoutFilter = cutlass::layout::TensorCxRSKx<4>;
using ThreadBlockShape = cutlass::gemm::GemmShape<16, 128, 16>;
using WarpShape = cutlass::gemm::GemmShape<16, 128, 16>;
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 4>;
using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationClamp<
int8_t, 4, int32_t, int32_t, float>;
using Convolution = cutlass::convolution::device::Convolution<
int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t,
LayoutSrc, int32_t, LayoutSrc, int32_t,
cutlass::convolution::ConvType::kConvolution, cutlass::arch::OpClassSimt, cutlass::arch::Sm61,
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp,
cutlass::convolution::threadblock::ConvolutionNCxHWxThreadblockSwizzle<
cutlass::convolution::ConvType::kConvolution>,
1, 4, 8, false>;
template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>(
const int8_t* d_src,
const int8_t* d_filter,
const int32_t* d_bias,
const int8_t* d_z,
int8_t* d_dst,
int* workspace,
typename Convolution::ConvolutionParameter const& conv_param,
typename Convolution::EpilogueOutputOp::Params const& epilogue,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif

+ 35
- 0
dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_1x1_16x128x16_16x128x16_relu.cu View File

@@ -0,0 +1,35 @@
#if !MEGDNN_TEGRA_X1
// generated by gen_cuda_conv_bias_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/conv_bias/int8/conv_bias_int8_implicit_gemm_cutlass_wrapper.cuinl"

using LayoutSrc = cutlass::layout::TensorNCxHWx<4>;
using LayoutFilter = cutlass::layout::TensorCxRSKx<4>;
using ThreadBlockShape = cutlass::gemm::GemmShape<16, 128, 16>;
using WarpShape = cutlass::gemm::GemmShape<16, 128, 16>;
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 4>;
using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationReluClamp<
int8_t, 4, int32_t, int32_t, float>;
using Convolution = cutlass::convolution::device::Convolution<
int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t,
LayoutSrc, int32_t, LayoutSrc, int32_t,
cutlass::convolution::ConvType::kConvolution, cutlass::arch::OpClassSimt, cutlass::arch::Sm61,
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp,
cutlass::convolution::threadblock::ConvolutionNCxHWxThreadblockSwizzle<
cutlass::convolution::ConvType::kConvolution>,
1, 4, 8, false>;
template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>(
const int8_t* d_src,
const int8_t* d_filter,
const int32_t* d_bias,
const int8_t* d_z,
int8_t* d_dst,
int* workspace,
typename Convolution::ConvolutionParameter const& conv_param,
typename Convolution::EpilogueOutputOp::Params const& epilogue,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif

+ 29
- 0
dnn/test/cuda/conv_bias_int8.cpp View File

@@ -97,6 +97,13 @@ std::vector<BenchArgs> get_detection_bench_args(size_t batch = 16) {
return args;
}

std::vector<BenchArgs> get_det_first_bench_args(size_t batch = 16) {
std::vector<BenchArgs> args;
args.emplace_back(BenchArgs{batch, 4, 736, 1280, 16, 3, 2});
args.emplace_back(BenchArgs{batch, 16, 384, 640, 16, 3, 1});
return args;
}

void benchmark_target_algo(
Handle* handle, const std::vector<BenchArgs>& args, DType src_dtype,
DType filter_dtype, DType bias_dtype, DType dst_dtype,
@@ -1236,6 +1243,28 @@ TEST_F(CUDA, BENCHMARK_CUTLASS_CONV_BIAS_INT8_NCHW4) {
dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.0f},
"INT8_NCHW4_DOTPROD_IMPLICIT_GEMM", param::ConvBias::Format::NCHW4);
}

TEST_F(CUDA, BENCHMARK_SASS_CONV_BIAS_INT8_NCHW4_DET_FIRST) {
require_compute_capability(6, 1);
std::string algo = ConvBias::algo_name<ConvBias::DirectParam>(
"SASS_INT8_NCHW4_DOTPROD_IMPLICIT_GEMM_128X32_64",
ConvBias::DirectParam{});
benchmark_target_algo(handle_cuda(), get_det_first_bench_args(16),
dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
dtype::QuantizedS32{1.2f * 1.3f},
dtype::QuantizedS8{1.0f}, algo.c_str(),
param::ConvBias::Format::NCHW4);
}

TEST_F(CUDA, BENCHMARK_CUTLASS_CONV_BIAS_INT8_NCHW4_DET_FIRST) {
require_compute_capability(6, 1);
benchmark_target_algo(
handle_cuda(), get_det_first_bench_args(16),
dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.0f},
"INT8_NCHW4_DOTPROD_IMPLICIT_GEMM_16", param::ConvBias::Format::NCHW4);
}

#endif
} // namespace test
} // namespace megdnn


+ 1
- 1
third_party/cutlass

@@ -1 +1 @@
Subproject commit 5a7f4bfa0e57f92140c8236322a86730132e0847
Subproject commit 41426ea4074dcfc448b1c9979ea7617407590c04

Loading…
Cancel
Save