From 69fe5ab3b36d5a64d67bbf43444dcbe2f3f76ed1 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 12 May 2020 11:41:14 +0800 Subject: [PATCH] feat(dnn/cuda): add conv2d-sass-kernel GitOrigin-RevId: f284d5a4cec44378bdc263cc65eed3986993e1a3 --- .gitattributes | 1 + dnn/src/cuda/conv_bias/algo.cpp | 1 + dnn/src/cuda/conv_bias/algo.h | 5 +- dnn/src/cuda/conv_bias/opr_impl.h | 5 +- dnn/src/cuda/pooling/opr_impl.cpp | 21 ++++- ...oling2d_int8_cdiv4hwn4.cu => pooling2d_int8.cu} | 96 ++++++++++++++++++++-- ...ing2d_int8_cdiv4hwn4.cuh => pooling2d_int8.cuh} | 16 ++-- dnn/src/cuda/pooling/pooling2d_int8_cdiv4hwn4.cpp | 27 ------ dnn/src/cuda/utils.cpp | 27 ++++++ dnn/src/cuda/utils.cuh | 9 ++ dnn/src/cuda/utils.h | 5 ++ dnn/test/common/conv_bias.cpp | 50 +++++++++-- dnn/test/common/conv_bias.h | 2 +- dnn/test/cuda/conv_bias_int8.cpp | 78 +++++++++++++++--- dnn/test/cuda/pooling.cpp | 20 +++++ dnn/test/cuda/utils.cpp | 8 ++ dnn/test/cuda/utils.h | 23 +++++- 17 files changed, 325 insertions(+), 69 deletions(-) rename dnn/src/cuda/pooling/{pooling2d_int8_cdiv4hwn4.cu => pooling2d_int8.cu} (77%) rename dnn/src/cuda/pooling/{pooling2d_int8_cdiv4hwn4.cuh => pooling2d_int8.cuh} (57%) delete mode 100644 dnn/src/cuda/pooling/pooling2d_int8_cdiv4hwn4.cpp diff --git a/.gitattributes b/.gitattributes index da7d1ec4..6e3614c6 100644 --- a/.gitattributes +++ b/.gitattributes @@ -3,3 +3,4 @@ dnn/src/cuda/conv_bias/int8/kimpl/* binary dnn/src/cuda/conv_bias/int8_imma/kimpl/* binary dnn/src/cuda/batch_conv_bias/int8/kimpl/* binary +dnn/src/cuda/sass/prebuilt/map_defs.cpp binary diff --git a/dnn/src/cuda/conv_bias/algo.cpp b/dnn/src/cuda/conv_bias/algo.cpp index 235f3cf9..2a2e7320 100644 --- a/dnn/src/cuda/conv_bias/algo.cpp +++ b/dnn/src/cuda/conv_bias/algo.cpp @@ -236,6 +236,7 @@ void ConvBiasForwardImpl::AlgoPack::fill_imma_algos() { } #endif + ConvBiasForwardImpl::AlgoBase* ConvBiasForwardImpl::AlgoPack::cudnn_conv_from_enum( cudnnConvolutionFwdAlgo_t algo) { diff --git 
a/dnn/src/cuda/conv_bias/algo.h b/dnn/src/cuda/conv_bias/algo.h index bc45781b..4a87bb76 100644 --- a/dnn/src/cuda/conv_bias/algo.h +++ b/dnn/src/cuda/conv_bias/algo.h @@ -14,11 +14,11 @@ #include "megdnn/oprs.h" #include "src/common/utils.h" +#include "src/cuda/conv_bias/conv_bias_int8.cuh" #include "src/cuda/conv_bias/helper.h" #include "src/cuda/conv_bias/opr_impl.h" -#include "src/cuda/handle.h" -#include "src/cuda/conv_bias/conv_bias_int8.cuh" #include "src/cuda/convolution_helper/parameter.cuh" +#include "src/cuda/handle.h" #include #include @@ -521,6 +521,7 @@ private: std::string m_name; }; + class ConvBiasForwardImpl::AlgoPack { AlgoPack(const AlgoPack&) = delete; AlgoPack& operator=(const AlgoPack&) = delete; diff --git a/dnn/src/cuda/conv_bias/opr_impl.h b/dnn/src/cuda/conv_bias/opr_impl.h index 67489134..222c381d 100644 --- a/dnn/src/cuda/conv_bias/opr_impl.h +++ b/dnn/src/cuda/conv_bias/opr_impl.h @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
*/ #pragma once #include "../elemwise/opr_impl.h" @@ -94,5 +95,5 @@ private: } // namespace cuda } // namespace megdnn - + // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/pooling/opr_impl.cpp b/dnn/src/cuda/pooling/opr_impl.cpp index 34ced0a8..85fc7b7d 100644 --- a/dnn/src/cuda/pooling/opr_impl.cpp +++ b/dnn/src/cuda/pooling/opr_impl.cpp @@ -10,7 +10,7 @@ */ #include "src/cuda/pooling/opr_impl.h" -#include "./pooling2d_int8_cdiv4hwn4.cuh" +#include "./pooling2d_int8.cuh" #include "src/cuda/utils.h" namespace megdnn { @@ -67,7 +67,24 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in ssrc, _megdnn_tensor_out sdst, kern_param.window_h = window_h, kern_param.window_w = window_w, kern_param.sh = sh, kern_param.sw = sw; auto&& stream = cuda_stream(handle()); - return pooling2d::_do_pooling2d_int8_cdiv4hwn4( + return pooling2d::do_pooling2d_int8_cdiv4hwn4( + src.compatible_ptr(), dst.compatible_ptr(), + kern_param, stream, static_cast(param().mode)); + } else if (param().format == Format::NCHW4) { + pooling2d::Param kern_param; + size_t n = src.layout[0], hi = src.layout[2], wi = src.layout[3], + c = src.layout[1], ho = dst.layout[2], wo = dst.layout[3]; + c = c * 4; + size_t ph = param().pad_h, pw = param().pad_w; + size_t window_h = param().window_h, window_w = param().window_w; + size_t sh = param().stride_h, sw = param().stride_w; + kern_param.n = n, kern_param.c = c, kern_param.hi = hi, + kern_param.wi = wi, kern_param.ho = ho, kern_param.wo = wo, + kern_param.ph = ph, kern_param.pw = pw, + kern_param.window_h = window_h, kern_param.window_w = window_w, + kern_param.sh = sh, kern_param.sw = sw; + auto&& stream = cuda_stream(handle()); + return pooling2d::do_pooling2d_int8_ncdiv4hw4( src.compatible_ptr(), dst.compatible_ptr(), kern_param, stream, static_cast(param().mode)); } diff --git a/dnn/src/cuda/pooling/pooling2d_int8_cdiv4hwn4.cu b/dnn/src/cuda/pooling/pooling2d_int8.cu similarity index 77% rename from dnn/src/cuda/pooling/pooling2d_int8_cdiv4hwn4.cu 
rename to dnn/src/cuda/pooling/pooling2d_int8.cu index 179a7884..ef14353f 100644 --- a/dnn/src/cuda/pooling/pooling2d_int8_cdiv4hwn4.cu +++ b/dnn/src/cuda/pooling/pooling2d_int8.cu @@ -8,8 +8,9 @@ * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#include "./pooling2d_int8_cdiv4hwn4.cuh" +#include "./pooling2d_int8.cuh" #include "src/common/opr_param_defs_enumv.cuh" +#include "src/cuda/query_blocksize.cuh" using namespace megdnn; using namespace cuda; @@ -360,11 +361,65 @@ __global__ void pooling2d_device_template_int8_cdiv4hwn4( ldg_type res = pooler.get_ans(); *(reinterpret_cast(g_dst_ptr)) = res; } + +template +__global__ void pooling2d_device_template_int8_ncdiv4hw4( + const int8_t* __restrict__ src, int8_t* __restrict__ dst, Param param) { + const int tid = blockIdx.x * blockDim.x + threadIdx.x; + + using ldg_type = typename Pooler::feed_type; + static int constexpr pack_size = 4; + static int constexpr ldg_width = sizeof(ldg_type) / sizeof(int32_t); + MEGDNN_STATIC_ASSERT( + ldg_width == 1, + "pooling2d (NCHW4) kernel must use 32bit width ldg instruction"); + const int wo_ldg = param.wo / ldg_width; + const int c_packed = param.c / pack_size; + const int batch = tid / (param.ho * wo_ldg * c_packed); + const int chw = tid - batch * param.ho * wo_ldg * c_packed; + const int oc_packed = chw / (param.ho * wo_ldg); + const int hw = chw - oc_packed * param.ho * wo_ldg; + const int oh = hw / wo_ldg; + const int ow = (hw - wo_ldg * oh) * ldg_width; + if (batch >= param.n || oc_packed >= c_packed || oh >= param.ho || + ow >= param.wo) + return; + + const int in_batch_stride = param.hi * param.wi * param.c; + const int out_batch_stride = param.ho * param.wo * param.c; + const int in_channel_stride = param.hi * param.wi * pack_size; + const int out_channel_stride = param.ho * param.wo * pack_size; + const int8_t* __restrict__ g_src_ptr = + src + batch * 
in_batch_stride + oc_packed * in_channel_stride; + int8_t* __restrict__ g_dst_ptr = dst + batch * out_batch_stride + + oc_packed * out_channel_stride + + (oh * param.wo + ow) * pack_size; + + Pooler pooler(param.window_h * param.window_w); + pooler.init(); + for (int fh = 0; fh < param.window_h; fh++) { + uint32_t ih = oh * param.sh + fh - param.ph; + for (int fw = 0; fw < param.window_w; fw++) { + uint32_t iw = ow * param.sw + fw - param.pw; + if (ih < param.hi && iw < param.wi) { + const int8_t* __restrict__ cur_src_ptr = + g_src_ptr + (ih * param.wi + iw) * pack_size; + ldg_type sval = __ldg(reinterpret_cast(cur_src_ptr)); + pooler.feed(sval); + } + } + } + ldg_type res = pooler.get_ans(); + *(reinterpret_cast(g_dst_ptr)) = res; +} + }; // namespace -void megdnn::cuda::pooling2d::_do_pooling2d_int8_cdiv4hwn4( - const int8_t* d_src, int8_t* d_dst, const Param& param, - cudaStream_t stream, uint32_t mode) { +void megdnn::cuda::pooling2d::do_pooling2d_int8_cdiv4hwn4(const int8_t* d_src, + int8_t* d_dst, + const Param& param, + cudaStream_t stream, + uint32_t mode) { using Mode = megdnn::param_enumv::Pooling::Mode; void (*kern)(const int8_t* __restrict__, int8_t* __restrict__, Param param); uint32_t vthreads_x = 0, vthreads_y = param.c / 4; @@ -397,8 +452,7 @@ void megdnn::cuda::pooling2d::_do_pooling2d_int8_cdiv4hwn4( } #undef dispatch_pooling_mode constexpr uint32_t threads_x = 16; - uint32_t nr_threads = - _get_kern_block_size(reinterpret_cast(kern)); + uint32_t nr_threads = query_blocksize_for_kernel(kern); uint32_t nr_threads_x = std::min(threads_x, vthreads_x), nr_threads_y = std::min(nr_threads / nr_threads_x, vthreads_y); uint32_t nr_blocks_x = param.ho * param.wo, @@ -410,4 +464,34 @@ void megdnn::cuda::pooling2d::_do_pooling2d_int8_cdiv4hwn4( after_kernel_launch(); } +void megdnn::cuda::pooling2d::do_pooling2d_int8_ncdiv4hw4(const int8_t* d_src, + int8_t* d_dst, + const Param& param, + cudaStream_t stream, + uint32_t mode) { + using Mode = 
megdnn::param_enumv::Pooling::Mode; + void (*kern)(const int8_t* __restrict__, int8_t* __restrict__, Param param); + uint32_t vthreads = param.n * param.c * param.ho * param.wo / 4; + switch (mode) { + case Mode::MAX: + kern = pooling2d_device_template_int8_ncdiv4hw4< + MaxPooler>; + break; + case Mode::AVERAGE: + kern = pooling2d_device_template_int8_ncdiv4hw4< + MeanIncludeRoundedPooler>; + break; + case Mode::AVERAGE_COUNT_EXCLUDE_PADDING: + kern = pooling2d_device_template_int8_ncdiv4hw4< + MeanExcludeRoundedPooler>; + break; + default: + megdnn_assert(false, "invalid pooling mode"); + } + uint32_t nr_threads = query_blocksize_for_kernel(kern); + nr_threads = std::min(nr_threads, vthreads); + uint32_t nr_blocks = DIVUP(vthreads, nr_threads); + kern<<>>(d_src, d_dst, param); + after_kernel_launch(); +} // vim: syntax=cuda.doxygen diff --git a/dnn/src/cuda/pooling/pooling2d_int8_cdiv4hwn4.cuh b/dnn/src/cuda/pooling/pooling2d_int8.cuh similarity index 57% rename from dnn/src/cuda/pooling/pooling2d_int8_cdiv4hwn4.cuh rename to dnn/src/cuda/pooling/pooling2d_int8.cuh index 6e709eed..dd6e352e 100644 --- a/dnn/src/cuda/pooling/pooling2d_int8_cdiv4hwn4.cuh +++ b/dnn/src/cuda/pooling/pooling2d_int8.cuh @@ -1,12 +1,13 @@ /** - * \file dnn/src/cuda/pooling/pooling2d_int8_cdiv4hwn4.cuh + * \file dnn/src/cuda/pooling/pooling2d_int8.cuh * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
*/ #pragma once @@ -20,15 +21,16 @@ struct Param { int n, c, hi, wi, ho, wo, ph, pw, window_h, window_w, sh, sw; }; -uint32_t _get_kern_block_size(const void* kern); +void do_pooling2d_int8_cdiv4hwn4(const int8_t* d_src, int8_t* d_dst, + const Param& param, cudaStream_t stream, + uint32_t mode); -void _do_pooling2d_int8_cdiv4hwn4(const int8_t* d_src, int8_t* d_dst, - const Param& param, cudaStream_t stream, - uint32_t mode); +void do_pooling2d_int8_ncdiv4hw4(const int8_t* d_src, int8_t* d_dst, + const Param& param, cudaStream_t stream, + uint32_t mode); } // namespace pooling2d } // namespace cuda } // namespace megdnn - // vim: syntax=cuda.doxygen diff --git a/dnn/src/cuda/pooling/pooling2d_int8_cdiv4hwn4.cpp b/dnn/src/cuda/pooling/pooling2d_int8_cdiv4hwn4.cpp deleted file mode 100644 index 46766cc9..00000000 --- a/dnn/src/cuda/pooling/pooling2d_int8_cdiv4hwn4.cpp +++ /dev/null @@ -1,27 +0,0 @@ -/** - * \file dnn/src/cuda/pooling/pooling2d_int8_cdiv4hwn4.cpp - * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") - * - * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- */ -#include "./pooling2d_int8_cdiv4hwn4.cuh" -#include "src/cuda/query_blocksize.cuh" - -namespace megdnn { -namespace cuda { -namespace pooling2d { - -uint32_t _get_kern_block_size(const void* kern) { - uint32_t ret = query_blocksize_for_kernel(kern); - return ret; -} - -} // namespace pooling2d -} // namespace cuda -} // namespace megdnn - -// vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/utils.cpp b/dnn/src/cuda/utils.cpp index 016f5498..97a17223 100644 --- a/dnn/src/cuda/utils.cpp +++ b/dnn/src/cuda/utils.cpp @@ -82,6 +82,11 @@ void cuda::__throw_cusolver_error__(cusolverStatus_t err, const char* msg) { megdnn_throw(s.c_str()); } +void cuda::__throw_cuda_driver_error__(CUresult err, const char* msg) { + auto s = ssprintf("cuda driver error %d occurred; expr: %s", int(err), msg); + megdnn_throw(s.c_str()); +} + void cuda::report_error(const char *msg) { megdnn_throw(msg); MEGDNN_MARK_USED_VAR(msg); @@ -118,9 +123,31 @@ bool cuda::is_compute_capability_required(int major, int minor) { (device_prop.major == major && device_prop.minor >= minor); } +bool cuda::is_compute_capability_equalto(int major, int minor) { + auto&& device_prop = cuda::current_device_prop(); + return device_prop.major == major && device_prop.minor == minor; +} + size_t cuda::max_batch_x_channel_size() { return current_device_prop().maxGridSize[2]; } +const char* cuda::current_device_arch_name() { + auto&& device_prop = current_device_prop(); + int cap = 10 * device_prop.major + device_prop.minor; + if (cap >= 50 && cap < 60) + return "maxwell"; + else if (cap >= 60 && cap < 70) + return "pascal"; + else if (cap >= 70 && cap < 75) + return "volta"; + else if (cap >= 75 && cap < 80) + return "turing"; + else if (cap >= 80) + return "ampere"; + megdnn_throw( + ssprintf("unsupported cuda compute capability %d", cap).c_str()); +} + // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/utils.cuh b/dnn/src/cuda/utils.cuh index cccf6495..89354b6b 100644 --- a/dnn/src/cuda/utils.cuh +++ 
b/dnn/src/cuda/utils.cuh @@ -53,6 +53,14 @@ } \ } while (0) +#define cucheck(_x) \ + do { \ + CUresult _err = (_x); \ + if (_err != CUDA_SUCCESS) { \ + ::megdnn::cuda::__throw_cuda_driver_error__(_err, #_x); \ + } \ + } while (0) + #define after_kernel_launch() \ do { \ cuda_check(cudaGetLastError()); \ @@ -84,6 +92,7 @@ MEGDNN_NORETURN void __throw_cublas_error__(cublasStatus_t err, const char* msg); MEGDNN_NORETURN void __throw_cusolver_error__(cusolverStatus_t err, const char* msg); +MEGDNN_NORETURN void __throw_cuda_driver_error__(CUresult err, const char* msg); MEGDNN_NORETURN void report_error(const char* msg); template diff --git a/dnn/src/cuda/utils.h b/dnn/src/cuda/utils.h index 9bbe7ff6..542eb6bc 100644 --- a/dnn/src/cuda/utils.h +++ b/dnn/src/cuda/utils.h @@ -57,10 +57,15 @@ cudaDeviceProp current_device_prop(); //! check compute capability satisfied with given sm version bool is_compute_capability_required(int major, int minor); +//! check compute capability equal to the given sm version +bool is_compute_capability_equalto(int major, int minor); + //! get the CUDNN_MAX_BATCH_X_CHANNEL_SIZE, it's just return the max size of the //! 
third demension size_t max_batch_x_channel_size(); +const char* current_device_arch_name(); + } // namespace cuda } // namespace megdnn diff --git a/dnn/test/common/conv_bias.cpp b/dnn/test/common/conv_bias.cpp index 8bd3f5b0..1a8d6a12 100644 --- a/dnn/test/common/conv_bias.cpp +++ b/dnn/test/common/conv_bias.cpp @@ -493,6 +493,7 @@ std::vector get_int8_nchw44_args(size_t kernel_size, size_t pack_size, return args; } + std::vector get_int8_nchw4_args_check_bounds(size_t kernel_size) { std::vector args; param::ConvBias cur_param; @@ -528,6 +529,7 @@ std::vector get_int8_nchw4_args_check_bounds(size_t kernel_size) { return args; } + std::vector get_int8_nchw4_args_small_batch(size_t kernel_size) { std::vector args; param::ConvBias cur_param; @@ -728,7 +730,7 @@ std::vector get_int8_chwn4_tensorcore_args(size_t kernel_size) { void check_conv_bias(DType src_dtype, DType filter_dtype, DType bias_dtype, DType dst_dtype, Handle* handle, const char* algo, param::ConvBias::Format format, - const std::vector& args) { + const std::vector& args, bool fuse_z) { megdnn_assert(src_dtype.enumv() == filter_dtype.enumv()); Checker checker(handle); if (algo) { @@ -758,36 +760,72 @@ void check_conv_bias(DType src_dtype, DType filter_dtype, DType bias_dtype, bias_rng = std::make_unique(2.f); } + using Param = param::ConvBias; + using Format = Param::Format; + auto get_z_shape = [&fuse_z, &format](TestArg arg) -> TensorShape { + TensorShape z{}; + if (fuse_z) { + size_t hi, wi, sh, sw, ph, pw, fh, fw; + z = arg.src; + size_t spatial_idx = 2; + if (format == Format::NCHW4) { + hi = arg.src[2]; + wi = arg.src[3]; + fh = arg.filter[2]; + fw = arg.filter[3]; + z[1] = arg.filter[0] / 4; + } else { + megdnn_assert(format == Format::CHWN4); + hi = arg.src[1]; + wi = arg.src[2]; + fh = arg.filter[1]; + fw = arg.filter[2]; + z[0] = arg.filter[3] / 4; + spatial_idx = 1; + } + sh = arg.param.stride_h; + sw = arg.param.stride_w; + ph = arg.param.pad_h; + pw = arg.param.pad_w; + size_t ho = 
infer_conv_shape(hi, fh, sh, ph); + size_t wo = infer_conv_shape(wi, fw, sw, pw); + z[spatial_idx] = ho; + z[spatial_idx + 1] = wo; + } + return z; + }; megdnn_assert(rng != nullptr && bias_rng != nullptr); - checker.set_rng(0, rng.get()) + checker.set_rng(0, rng.get()) .set_rng(1, rng.get()) .set_rng(2, rng.get()) .set_rng(3, rng.get()); if (args.empty()) { std::vector default_args; - using Param = param::ConvBias; - using Format = Param::Format; if (format == Format::NCHW4) { default_args = get_int8_nchw4_args(3); } else if (format == Format::CHWN4) { default_args = get_int8_chwn4_args(3); } for (auto&& arg : default_args) { + auto z = get_z_shape(arg); checker.set_dtype(0, src_dtype) .set_dtype(1, filter_dtype) .set_dtype(2, bias_dtype) + .set_dtype(3, dst_dtype) .set_dtype(4, dst_dtype) .set_param(arg.param) - .execs({arg.src, arg.filter, arg.bias, {}, {}}); + .execs({arg.src, arg.filter, arg.bias, z, {}}); } } else { for (auto&& arg : args) { + auto z = get_z_shape(arg); checker.set_dtype(0, src_dtype) .set_dtype(1, filter_dtype) .set_dtype(2, bias_dtype) + .set_dtype(3, dst_dtype) .set_dtype(4, dst_dtype) .set_param(arg.param) - .execs({arg.src, arg.filter, arg.bias, {}, {}}); + .execs({arg.src, arg.filter, arg.bias, z, {}}); } } } diff --git a/dnn/test/common/conv_bias.h b/dnn/test/common/conv_bias.h index e52815b5..d928fda3 100644 --- a/dnn/test/common/conv_bias.h +++ b/dnn/test/common/conv_bias.h @@ -66,7 +66,7 @@ void check_conv_bias( DType src_dtype, DType filter_dtype, DType bias_dtype, DType dst_dtype, Handle* handle, const char* algo = nullptr, param::ConvBias::Format format = param::ConvBias::Format::NCHW4, - const std::vector& args = {}); + const std::vector& args = {}, bool fuse_z = false); #if MEGDNN_WITH_BENCHMARK std::vector get_winograd_benchmark_args( diff --git a/dnn/test/cuda/conv_bias_int8.cpp b/dnn/test/cuda/conv_bias_int8.cpp index f24b120c..c81fd7e2 100644 --- a/dnn/test/cuda/conv_bias_int8.cpp +++ b/dnn/test/cuda/conv_bias_int8.cpp @@ 
-18,10 +18,14 @@ #include "test/cuda/fixture.h" #include "test/cuda/utils.h" +#define V1(x) #x +#define V(x) V1(x) + namespace megdnn { namespace test { -#if MEGDNN_WITH_BENCHMARK namespace { + +#if MEGDNN_WITH_BENCHMARK struct BenchArgs { size_t n, ci, hi, wi, co, f, s; }; @@ -29,9 +33,16 @@ struct BenchArgs { std::vector get_resnet50_bench_args(size_t batch = 64) { std::vector args; args.emplace_back(BenchArgs{batch, 64, 56, 56, 256, 1, 1}); + + args.emplace_back(BenchArgs{batch, 256, 56, 56, 32, 3, 1}); + args.emplace_back(BenchArgs{batch, 256, 56, 56, 32, 3, 2}); + args.emplace_back(BenchArgs{batch, 4, 256, 256, 32, 7, 2}); + args.emplace_back(BenchArgs{batch, 256, 56, 56, 64, 1, 1}); args.emplace_back(BenchArgs{batch, 64, 56, 56, 64, 1, 1}); args.emplace_back(BenchArgs{batch, 64, 56, 56, 64, 3, 1}); + args.emplace_back(BenchArgs{batch, 64, 56, 56, 64, 3, 2}); + args.emplace_back(BenchArgs{batch, 256, 56, 56, 64, 3, 2}); args.emplace_back(BenchArgs{batch, 64, 56, 56, 256, 1, 1}); args.emplace_back(BenchArgs{batch, 256, 56, 56, 512, 1, 2}); @@ -101,13 +112,12 @@ void benchmark_target_algo( conv_bias::ConvBiasAlgoChecker(algo)); } -#define V1(x) #x -#define V(x) V1(x) #define CUDNN_VERSION_STRING \ "v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) "." 
V(CUDNN_PATCHLEVEL) benchmarker_cudnn.set_before_exec_callback( conv_bias::ConvBiasAlgoChecker( - "CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_" + "DEFAULT:CUDNN:ConvBiasActivation:CUDNN_CONVOLUTION_FWD_" + "ALGO_IMPLICIT_PRECOMP_" "GEMM" CUDNN_VERSION_STRING)); benchmarker.set_dtype(0, src_dtype) @@ -141,6 +151,7 @@ void benchmark_target_algo( {}, {}}) / RUNS; + param.nonlineMode = Param::NonlineMode::IDENTITY; benchmarker_cudnn.set_param(param); auto time_in_ms_cudnn = benchmarker_cudnn.execs( @@ -162,6 +173,47 @@ void benchmark_target_algo( (flo / (time_in_ms_cudnn * 1e-3)), algo, time_in_ms_cudnn / time_in_ms); } + printf("bench with z tensor\n"); + for (auto&& arg : args) { + Param param; + param.pad_h = param.pad_w = arg.f / 2; + param.stride_h = param.stride_w = arg.s; + param.format = Format::NCHW4; + + size_t ho = infer_conv_shape(arg.hi, arg.f, arg.s, arg.f / 2); + size_t wo = infer_conv_shape(arg.wi, arg.f, arg.s, arg.f / 2); + + benchmarker.set_param(param); + auto time_in_ms = + benchmarker.execs({{arg.n, arg.ci / 4, arg.hi, arg.wi, 4}, + {arg.co, arg.ci / 4, arg.f, arg.f, 4}, + {1, arg.co / 4, 1, 1, 4}, + {arg.n, arg.co / 4, ho, wo, 4}, + {}}) / + RUNS; + param.format = Format::NCHW4; + param.nonlineMode = Param::NonlineMode::IDENTITY; + benchmarker_cudnn.set_param(param); + auto time_in_ms_cudnn = + benchmarker_cudnn.execs( + {{arg.n, arg.ci / 4, arg.hi, arg.wi, 4}, + {arg.co, arg.ci / 4, arg.f, arg.f, 4}, + {1, arg.co / 4, 1, 1, 4}, + {arg.n, arg.co / 4, ho, wo, 4}, + {}}) / + RUNS; + float flo = 2.0 * arg.n * arg.co * ho * wo * arg.ci * arg.f * + arg.f / (1e12); + TensorShape src{arg.n, arg.ci, arg.hi, arg.wi}, + filter{arg.co, arg.ci, arg.f, arg.f}; + printf("src=%s, filter=%s, time(algo=%s)=%.2f %.2fTops, " + "time(cudnn)=%.2f %.2fTops, " + "perf(algo=%s)/perf(cudnn)=%.2f\n", + src.to_string().c_str(), filter.to_string().c_str(), algo, + time_in_ms, (flo / (time_in_ms * 1e-3)), time_in_ms_cudnn, + (flo / (time_in_ms_cudnn * 1e-3)), algo, + 
time_in_ms_cudnn / time_in_ms); + } } else if (format == Format::CHWN4) { for (auto&& arg : args) { Param param; @@ -222,6 +274,7 @@ void benchmark_target_algo( RUNS; param.format = Format::NCHW4; + param.nonlineMode = Param::NonlineMode::IDENTITY; benchmarker_cudnn.set_param(param); auto time_in_ms_cudnn = benchmarker_cudnn.execs( {{arg.n, arg.ci / 4, arg.hi, arg.wi, 4}, @@ -242,7 +295,6 @@ void benchmark_target_algo( (flo / (time_in_ms_cudnn * 1e-3)), algo, time_in_ms_cudnn / time_in_ms); } - } } @@ -265,15 +317,14 @@ void benchmark_target_algo_with_cudnn_tsc( benchmarker.set_before_exec_callback( conv_bias::ConvBiasAlgoChecker(algo)); } else { - benchmarker.set_proxy(proxy); + benchmarker.set_proxy(proxy); } benchmarker_cudnn.set_before_exec_callback( conv_bias::ConvBiasAlgoChecker( - "CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_" + "DEFAULT:CUDNN:ConvBiasActivation:CUDNN_CONVOLUTION_FWD_" + "ALGO_IMPLICIT_PRECOMP_" "GEMM" CUDNN_VERSION_STRING)); -#undef V1 -#undef V #undef CUDNN_VERSION_STRING benchmarker.set_dtype(0, src_dtype) @@ -446,12 +497,10 @@ void benchmark_target_algo_with_cudnn_tsc( (flo / (time_in_ms_cudnn * 1e-3)), algo, time_in_ms_cudnn / time_in_ms); } - } } - -} // namespace #endif +} // namespace TEST_F(CUDA, CONV_BIAS_INT8_NCHW4_1x1) { require_compute_capability(6, 1); @@ -1116,6 +1165,7 @@ TEST_F(CUDA, CONV_BIAS_INT8_CHWN4_UNROLL_WIDTH_TENSORCORE_1x1_ALGO_2) { conv_bias::get_int8_chwn4_args_small_batch(1)); } + #if MEGDNN_WITH_BENCHMARK TEST_F(CUDA, BENCHMARK_CONV_BIAS_INT8_CHWN4) { require_compute_capability(6, 1); @@ -1182,9 +1232,13 @@ TEST_F(CUDA, BENCHMARK_CONV_BIAS_INT8_CHWN4_SMALL_CHANNEL) { dtype::QuantizedS8{1.0f}, "INT8_CHWN4_DOTPROD_IMPLICIT_GEMM", param::ConvBias::Format::CHWN4); } + #endif } // namespace test } // namespace megdnn +#undef V1 +#undef V + // vim: syntax=cpp.doxygen diff --git a/dnn/test/cuda/pooling.cpp b/dnn/test/cuda/pooling.cpp index 69a7ccc5..db601726 100644 --- a/dnn/test/cuda/pooling.cpp +++ 
b/dnn/test/cuda/pooling.cpp @@ -290,6 +290,26 @@ TEST_F(CUDA, POOLING_FORWARD_CHWN4) { } } +TEST_F(CUDA, POOLING_FORWARD_INT8_NCHW4) { + require_compute_capability(6, 1); + using Param = param::Pooling; + Checker checker(handle_cuda()); + Param param; + auto i8_min = std::numeric_limits().min(); + auto i8_max = std::numeric_limits().max(); + UniformIntRNG int_rng{i8_min, i8_max}; + checker.set_dtype(0, dtype::QuantizedS8(0.1f)); + param.format = Param::Format::NCHW4; + for (auto mode : {Param::Mode::MAX, Param::Mode::AVERAGE, + Param::Mode::AVERAGE_COUNT_EXCLUDE_PADDING}) { + param.mode = mode; + checker.set_epsilon(1e-3).set_rng(0, &int_rng); + checker.set_param(param).exec({{64, 8, 28, 28, 4}, {}}); + checker.set_param(param).exec({{15, 8, 28, 28, 4}, {}}); + checker.set_param(param).exec({{30, 8, 28, 28, 4}, {}}); + } +} + #if MEGDNN_WITH_BENCHMARK TEST_F(CUDA, BENCHMARK_POOLING_CHWN4) { CUBenchmarker bencher(handle_cuda()); diff --git a/dnn/test/cuda/utils.cpp b/dnn/test/cuda/utils.cpp index 9b8bac58..175895f4 100644 --- a/dnn/test/cuda/utils.cpp +++ b/dnn/test/cuda/utils.cpp @@ -20,6 +20,14 @@ bool check_compute_capability(int major, int minor) { cuda_check(cudaGetDeviceProperties(&prop, dev)); return prop.major > major || (prop.major == major && prop.minor >= minor); } + +bool check_compute_capability_eq(int major, int minor) { + int dev; + cuda_check(cudaGetDevice(&dev)); + cudaDeviceProp prop; + cuda_check(cudaGetDeviceProperties(&prop, dev)); + return (prop.major == major && prop.minor == minor); +} } // namespace test } // namespace megdnn diff --git a/dnn/test/cuda/utils.h b/dnn/test/cuda/utils.h index 6dabd494..299d4ec5 100644 --- a/dnn/test/cuda/utils.h +++ b/dnn/test/cuda/utils.h @@ -26,13 +26,28 @@ namespace megdnn { namespace test { bool check_compute_capability(int major, int minor); +bool check_compute_capability_eq(int major, int minor); } // namespace test } // namespace megdnn -#define require_compute_capability(x, y) \ - do { \ - if 
(!megdnn::test::check_compute_capability((x), (y))) \ - return; \ +#define require_compute_capability(x, y) \ + do { \ + if (!megdnn::test::check_compute_capability((x), (y))) { \ + printf("skip testcase due to cuda compute capability not " \ + "satisfied (expected >= %d.%d)\n", \ + (x), (y)); \ + return; \ + } \ + } while (0) + +#define require_compute_capability_eq(x, y) \ + do { \ + if (!megdnn::test::check_compute_capability_eq((x), (y))) { \ + printf("skip testcase due to cuda compute capability not " \ + "equal to %d.%d\n", \ + (x), (y)); \ + return; \ + } \ + } while (0) // vim: syntax=cpp.doxygen