@@ -3,3 +3,4 @@ | |||
dnn/src/cuda/conv_bias/int8/kimpl/* binary | |||
dnn/src/cuda/conv_bias/int8_imma/kimpl/* binary | |||
dnn/src/cuda/batch_conv_bias/int8/kimpl/* binary | |||
dnn/src/cuda/sass/prebuilt/map_defs.cpp binary |
@@ -236,6 +236,7 @@ void ConvBiasForwardImpl::AlgoPack::fill_imma_algos() { | |||
} | |||
#endif | |||
ConvBiasForwardImpl::AlgoBase* | |||
ConvBiasForwardImpl::AlgoPack::cudnn_conv_from_enum( | |||
cudnnConvolutionFwdAlgo_t algo) { | |||
@@ -14,11 +14,11 @@ | |||
#include "megdnn/oprs.h" | |||
#include "src/common/utils.h" | |||
#include "src/cuda/conv_bias/conv_bias_int8.cuh" | |||
#include "src/cuda/conv_bias/helper.h" | |||
#include "src/cuda/conv_bias/opr_impl.h" | |||
#include "src/cuda/handle.h" | |||
#include "src/cuda/conv_bias/conv_bias_int8.cuh" | |||
#include "src/cuda/convolution_helper/parameter.cuh" | |||
#include "src/cuda/handle.h" | |||
#include <cuda.h> | |||
#include <memory> | |||
@@ -521,6 +521,7 @@ private: | |||
std::string m_name; | |||
}; | |||
class ConvBiasForwardImpl::AlgoPack { | |||
AlgoPack(const AlgoPack&) = delete; | |||
AlgoPack& operator=(const AlgoPack&) = delete; | |||
@@ -6,7 +6,8 @@ | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include "../elemwise/opr_impl.h" | |||
@@ -94,5 +95,5 @@ private: | |||
} // namespace cuda | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -10,7 +10,7 @@ | |||
*/ | |||
#include "src/cuda/pooling/opr_impl.h" | |||
#include "./pooling2d_int8_cdiv4hwn4.cuh" | |||
#include "./pooling2d_int8.cuh" | |||
#include "src/cuda/utils.h" | |||
namespace megdnn { | |||
@@ -67,7 +67,24 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in ssrc, _megdnn_tensor_out sdst, | |||
kern_param.window_h = window_h, kern_param.window_w = window_w, | |||
kern_param.sh = sh, kern_param.sw = sw; | |||
auto&& stream = cuda_stream(handle()); | |||
return pooling2d::_do_pooling2d_int8_cdiv4hwn4( | |||
return pooling2d::do_pooling2d_int8_cdiv4hwn4( | |||
src.compatible_ptr<int8_t>(), dst.compatible_ptr<int8_t>(), | |||
kern_param, stream, static_cast<uint32_t>(param().mode)); | |||
} else if (param().format == Format::NCHW4) { | |||
pooling2d::Param kern_param; | |||
size_t n = src.layout[0], hi = src.layout[2], wi = src.layout[3], | |||
c = src.layout[1], ho = dst.layout[2], wo = dst.layout[3]; | |||
c = c * 4; | |||
size_t ph = param().pad_h, pw = param().pad_w; | |||
size_t window_h = param().window_h, window_w = param().window_w; | |||
size_t sh = param().stride_h, sw = param().stride_w; | |||
kern_param.n = n, kern_param.c = c, kern_param.hi = hi, | |||
kern_param.wi = wi, kern_param.ho = ho, kern_param.wo = wo, | |||
kern_param.ph = ph, kern_param.pw = pw, | |||
kern_param.window_h = window_h, kern_param.window_w = window_w, | |||
kern_param.sh = sh, kern_param.sw = sw; | |||
auto&& stream = cuda_stream(handle()); | |||
return pooling2d::do_pooling2d_int8_ncdiv4hw4( | |||
src.compatible_ptr<int8_t>(), dst.compatible_ptr<int8_t>(), | |||
kern_param, stream, static_cast<uint32_t>(param().mode)); | |||
} | |||
@@ -8,8 +8,9 @@ | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "./pooling2d_int8_cdiv4hwn4.cuh" | |||
#include "./pooling2d_int8.cuh" | |||
#include "src/common/opr_param_defs_enumv.cuh" | |||
#include "src/cuda/query_blocksize.cuh" | |||
using namespace megdnn; | |||
using namespace cuda; | |||
@@ -360,11 +361,65 @@ __global__ void pooling2d_device_template_int8_cdiv4hwn4( | |||
ldg_type res = pooler.get_ans(); | |||
*(reinterpret_cast<ldg_type*>(g_dst_ptr)) = res; | |||
} | |||
template <typename Pooler> | |||
__global__ void pooling2d_device_template_int8_ncdiv4hw4( | |||
const int8_t* __restrict__ src, int8_t* __restrict__ dst, Param param) { | |||
const int tid = blockIdx.x * blockDim.x + threadIdx.x; | |||
using ldg_type = typename Pooler::feed_type; | |||
static int constexpr pack_size = 4; | |||
static int constexpr ldg_width = sizeof(ldg_type) / sizeof(int32_t); | |||
MEGDNN_STATIC_ASSERT( | |||
ldg_width == 1, | |||
"pooling2d (NCHW4) kernel must use 32bit width ldg instruction"); | |||
const int wo_ldg = param.wo / ldg_width; | |||
const int c_packed = param.c / pack_size; | |||
const int batch = tid / (param.ho * wo_ldg * c_packed); | |||
const int chw = tid - batch * param.ho * wo_ldg * c_packed; | |||
const int oc_packed = chw / (param.ho * wo_ldg); | |||
const int hw = chw - oc_packed * param.ho * wo_ldg; | |||
const int oh = hw / wo_ldg; | |||
const int ow = (hw - wo_ldg * oh) * ldg_width; | |||
if (batch >= param.n || oc_packed >= c_packed || oh >= param.ho || | |||
ow >= param.wo) | |||
return; | |||
const int in_batch_stride = param.hi * param.wi * param.c; | |||
const int out_batch_stride = param.ho * param.wo * param.c; | |||
const int in_channel_stride = param.hi * param.wi * pack_size; | |||
const int out_channel_stride = param.ho * param.wo * pack_size; | |||
const int8_t* __restrict__ g_src_ptr = | |||
src + batch * in_batch_stride + oc_packed * in_channel_stride; | |||
int8_t* __restrict__ g_dst_ptr = dst + batch * out_batch_stride + | |||
oc_packed * out_channel_stride + | |||
(oh * param.wo + ow) * pack_size; | |||
Pooler pooler(param.window_h * param.window_w); | |||
pooler.init(); | |||
for (int fh = 0; fh < param.window_h; fh++) { | |||
uint32_t ih = oh * param.sh + fh - param.ph; | |||
for (int fw = 0; fw < param.window_w; fw++) { | |||
uint32_t iw = ow * param.sw + fw - param.pw; | |||
if (ih < param.hi && iw < param.wi) { | |||
const int8_t* __restrict__ cur_src_ptr = | |||
g_src_ptr + (ih * param.wi + iw) * pack_size; | |||
ldg_type sval = __ldg(reinterpret_cast<const ldg_type*>(cur_src_ptr)); | |||
pooler.feed(sval); | |||
} | |||
} | |||
} | |||
ldg_type res = pooler.get_ans(); | |||
*(reinterpret_cast<ldg_type*>(g_dst_ptr)) = res; | |||
} | |||
}; // namespace | |||
void megdnn::cuda::pooling2d::_do_pooling2d_int8_cdiv4hwn4( | |||
const int8_t* d_src, int8_t* d_dst, const Param& param, | |||
cudaStream_t stream, uint32_t mode) { | |||
void megdnn::cuda::pooling2d::do_pooling2d_int8_cdiv4hwn4(const int8_t* d_src, | |||
int8_t* d_dst, | |||
const Param& param, | |||
cudaStream_t stream, | |||
uint32_t mode) { | |||
using Mode = megdnn::param_enumv::Pooling::Mode; | |||
void (*kern)(const int8_t* __restrict__, int8_t* __restrict__, Param param); | |||
uint32_t vthreads_x = 0, vthreads_y = param.c / 4; | |||
@@ -397,8 +452,7 @@ void megdnn::cuda::pooling2d::_do_pooling2d_int8_cdiv4hwn4( | |||
} | |||
#undef dispatch_pooling_mode | |||
constexpr uint32_t threads_x = 16; | |||
uint32_t nr_threads = | |||
_get_kern_block_size(reinterpret_cast<const void*>(kern)); | |||
uint32_t nr_threads = query_blocksize_for_kernel(kern); | |||
uint32_t nr_threads_x = std::min(threads_x, vthreads_x), | |||
nr_threads_y = std::min(nr_threads / nr_threads_x, vthreads_y); | |||
uint32_t nr_blocks_x = param.ho * param.wo, | |||
@@ -410,4 +464,34 @@ void megdnn::cuda::pooling2d::_do_pooling2d_int8_cdiv4hwn4( | |||
after_kernel_launch(); | |||
} | |||
void megdnn::cuda::pooling2d::do_pooling2d_int8_ncdiv4hw4(const int8_t* d_src, | |||
int8_t* d_dst, | |||
const Param& param, | |||
cudaStream_t stream, | |||
uint32_t mode) { | |||
using Mode = megdnn::param_enumv::Pooling::Mode; | |||
void (*kern)(const int8_t* __restrict__, int8_t* __restrict__, Param param); | |||
uint32_t vthreads = param.n * param.c * param.ho * param.wo / 4; | |||
switch (mode) { | |||
case Mode::MAX: | |||
kern = pooling2d_device_template_int8_ncdiv4hw4< | |||
MaxPooler<int8_t, int32_t>>; | |||
break; | |||
case Mode::AVERAGE: | |||
kern = pooling2d_device_template_int8_ncdiv4hw4< | |||
MeanIncludeRoundedPooler<int8_t, int32_t, int32_t>>; | |||
break; | |||
case Mode::AVERAGE_COUNT_EXCLUDE_PADDING: | |||
kern = pooling2d_device_template_int8_ncdiv4hw4< | |||
MeanExcludeRoundedPooler<int8_t, int32_t, int32_t>>; | |||
break; | |||
default: | |||
megdnn_assert(false, "invalid pooling mode"); | |||
} | |||
uint32_t nr_threads = query_blocksize_for_kernel(kern); | |||
nr_threads = std::min(nr_threads, vthreads); | |||
uint32_t nr_blocks = DIVUP(vthreads, nr_threads); | |||
kern<<<nr_blocks, nr_threads, 0, stream>>>(d_src, d_dst, param); | |||
after_kernel_launch(); | |||
} | |||
// vim: syntax=cuda.doxygen |
@@ -1,12 +1,13 @@ | |||
/** | |||
* \file dnn/src/cuda/pooling/pooling2d_int8_cdiv4hwn4.cuh | |||
* \file dnn/src/cuda/pooling/pooling2d_int8.cuh | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
@@ -20,15 +21,16 @@ struct Param { | |||
int n, c, hi, wi, ho, wo, ph, pw, window_h, window_w, sh, sw; | |||
}; | |||
uint32_t _get_kern_block_size(const void* kern); | |||
void do_pooling2d_int8_cdiv4hwn4(const int8_t* d_src, int8_t* d_dst, | |||
const Param& param, cudaStream_t stream, | |||
uint32_t mode); | |||
void _do_pooling2d_int8_cdiv4hwn4(const int8_t* d_src, int8_t* d_dst, | |||
const Param& param, cudaStream_t stream, | |||
uint32_t mode); | |||
void do_pooling2d_int8_ncdiv4hw4(const int8_t* d_src, int8_t* d_dst, | |||
const Param& param, cudaStream_t stream, | |||
uint32_t mode); | |||
} // namespace pooling2d | |||
} // namespace cuda | |||
} // namespace megdnn | |||
// vim: syntax=cuda.doxygen |
@@ -1,27 +0,0 @@ | |||
/** | |||
* \file dnn/src/cuda/pooling/pooling2d_int8_cdiv4hwn4.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "./pooling2d_int8_cdiv4hwn4.cuh" | |||
#include "src/cuda/query_blocksize.cuh" | |||
namespace megdnn { | |||
namespace cuda { | |||
namespace pooling2d { | |||
uint32_t _get_kern_block_size(const void* kern) { | |||
uint32_t ret = query_blocksize_for_kernel(kern); | |||
return ret; | |||
} | |||
} // namespace pooling2d | |||
} // namespace cuda | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -82,6 +82,11 @@ void cuda::__throw_cusolver_error__(cusolverStatus_t err, const char* msg) { | |||
megdnn_throw(s.c_str()); | |||
} | |||
void cuda::__throw_cuda_driver_error__(CUresult err, const char* msg) { | |||
auto s = ssprintf("cuda driver error %d occurred; expr: %s", int(err), msg); | |||
megdnn_throw(s.c_str()); | |||
} | |||
void cuda::report_error(const char *msg) { | |||
megdnn_throw(msg); | |||
MEGDNN_MARK_USED_VAR(msg); | |||
@@ -118,9 +123,31 @@ bool cuda::is_compute_capability_required(int major, int minor) { | |||
(device_prop.major == major && device_prop.minor >= minor); | |||
} | |||
bool cuda::is_compute_capability_equalto(int major, int minor) { | |||
auto&& device_prop = cuda::current_device_prop(); | |||
return device_prop.major == major && device_prop.minor == minor; | |||
} | |||
size_t cuda::max_batch_x_channel_size() { | |||
return current_device_prop().maxGridSize[2]; | |||
} | |||
const char* cuda::current_device_arch_name() { | |||
auto&& device_prop = current_device_prop(); | |||
int cap = 10 * device_prop.major + device_prop.minor; | |||
if (cap >= 50 && cap < 60) | |||
return "maxwell"; | |||
else if (cap >= 60 && cap < 70) | |||
return "pascal"; | |||
else if (cap >= 70 && cap < 75) | |||
return "volta"; | |||
else if (cap >= 75 && cap < 80) | |||
return "turing"; | |||
else if (cap >= 80) | |||
return "ampere"; | |||
megdnn_throw( | |||
ssprintf("unsupported cuda compute capability %d", cap).c_str()); | |||
} | |||
// vim: syntax=cpp.doxygen | |||
@@ -53,6 +53,14 @@ | |||
} \ | |||
} while (0) | |||
#define cucheck(_x) \ | |||
do { \ | |||
CUresult _err = (_x); \ | |||
if (_err != CUDA_SUCCESS) { \ | |||
::megdnn::cuda::__throw_cuda_driver_error__(_err, #_x); \ | |||
} \ | |||
} while (0) | |||
#define after_kernel_launch() \ | |||
do { \ | |||
cuda_check(cudaGetLastError()); \ | |||
@@ -84,6 +92,7 @@ MEGDNN_NORETURN void __throw_cublas_error__(cublasStatus_t err, | |||
const char* msg); | |||
MEGDNN_NORETURN void __throw_cusolver_error__(cusolverStatus_t err, | |||
const char* msg); | |||
MEGDNN_NORETURN void __throw_cuda_driver_error__(CUresult err, const char* msg); | |||
MEGDNN_NORETURN void report_error(const char* msg); | |||
template <typename T, size_t N> | |||
@@ -57,10 +57,15 @@ cudaDeviceProp current_device_prop(); | |||
//! check compute capability satisfied with given sm version | |||
bool is_compute_capability_required(int major, int minor); | |||
//! check compute capability equal to the given sm version | |||
bool is_compute_capability_equalto(int major, int minor); | |||
//! get the CUDNN_MAX_BATCH_X_CHANNEL_SIZE, it's just return the max size of the | |||
//! third demension | |||
size_t max_batch_x_channel_size(); | |||
const char* current_device_arch_name(); | |||
} // namespace cuda | |||
} // namespace megdnn | |||
@@ -493,6 +493,7 @@ std::vector<TestArg> get_int8_nchw44_args(size_t kernel_size, size_t pack_size, | |||
return args; | |||
} | |||
std::vector<TestArg> get_int8_nchw4_args_check_bounds(size_t kernel_size) { | |||
std::vector<TestArg> args; | |||
param::ConvBias cur_param; | |||
@@ -528,6 +529,7 @@ std::vector<TestArg> get_int8_nchw4_args_check_bounds(size_t kernel_size) { | |||
return args; | |||
} | |||
std::vector<TestArg> get_int8_nchw4_args_small_batch(size_t kernel_size) { | |||
std::vector<TestArg> args; | |||
param::ConvBias cur_param; | |||
@@ -728,7 +730,7 @@ std::vector<TestArg> get_int8_chwn4_tensorcore_args(size_t kernel_size) { | |||
void check_conv_bias(DType src_dtype, DType filter_dtype, DType bias_dtype, | |||
DType dst_dtype, Handle* handle, const char* algo, | |||
param::ConvBias::Format format, | |||
const std::vector<TestArg>& args) { | |||
const std::vector<TestArg>& args, bool fuse_z) { | |||
megdnn_assert(src_dtype.enumv() == filter_dtype.enumv()); | |||
Checker<ConvBiasForward> checker(handle); | |||
if (algo) { | |||
@@ -758,36 +760,72 @@ void check_conv_bias(DType src_dtype, DType filter_dtype, DType bias_dtype, | |||
bias_rng = std::make_unique<NormalRNG>(2.f); | |||
} | |||
using Param = param::ConvBias; | |||
using Format = Param::Format; | |||
auto get_z_shape = [&fuse_z, &format](TestArg arg) -> TensorShape { | |||
TensorShape z{}; | |||
if (fuse_z) { | |||
size_t hi, wi, sh, sw, ph, pw, fh, fw; | |||
z = arg.src; | |||
size_t spatial_idx = 2; | |||
if (format == Format::NCHW4) { | |||
hi = arg.src[2]; | |||
wi = arg.src[3]; | |||
fh = arg.filter[2]; | |||
fw = arg.filter[3]; | |||
z[1] = arg.filter[0] / 4; | |||
} else { | |||
megdnn_assert(format == Format::CHWN4); | |||
hi = arg.src[1]; | |||
wi = arg.src[2]; | |||
fh = arg.filter[1]; | |||
fw = arg.filter[2]; | |||
z[0] = arg.filter[3] / 4; | |||
spatial_idx = 1; | |||
} | |||
sh = arg.param.stride_h; | |||
sw = arg.param.stride_w; | |||
ph = arg.param.pad_h; | |||
pw = arg.param.pad_w; | |||
size_t ho = infer_conv_shape(hi, fh, sh, ph); | |||
size_t wo = infer_conv_shape(wi, fw, sw, pw); | |||
z[spatial_idx] = ho; | |||
z[spatial_idx + 1] = wo; | |||
} | |||
return z; | |||
}; | |||
megdnn_assert(rng != nullptr && bias_rng != nullptr); | |||
checker.set_rng(0, rng.get()) | |||
checker.set_rng(0, rng.get()) | |||
.set_rng(1, rng.get()) | |||
.set_rng(2, rng.get()) | |||
.set_rng(3, rng.get()); | |||
if (args.empty()) { | |||
std::vector<TestArg> default_args; | |||
using Param = param::ConvBias; | |||
using Format = Param::Format; | |||
if (format == Format::NCHW4) { | |||
default_args = get_int8_nchw4_args(3); | |||
} else if (format == Format::CHWN4) { | |||
default_args = get_int8_chwn4_args(3); | |||
} | |||
for (auto&& arg : default_args) { | |||
auto z = get_z_shape(arg); | |||
checker.set_dtype(0, src_dtype) | |||
.set_dtype(1, filter_dtype) | |||
.set_dtype(2, bias_dtype) | |||
.set_dtype(3, dst_dtype) | |||
.set_dtype(4, dst_dtype) | |||
.set_param(arg.param) | |||
.execs({arg.src, arg.filter, arg.bias, {}, {}}); | |||
.execs({arg.src, arg.filter, arg.bias, z, {}}); | |||
} | |||
} else { | |||
for (auto&& arg : args) { | |||
auto z = get_z_shape(arg); | |||
checker.set_dtype(0, src_dtype) | |||
.set_dtype(1, filter_dtype) | |||
.set_dtype(2, bias_dtype) | |||
.set_dtype(3, dst_dtype) | |||
.set_dtype(4, dst_dtype) | |||
.set_param(arg.param) | |||
.execs({arg.src, arg.filter, arg.bias, {}, {}}); | |||
.execs({arg.src, arg.filter, arg.bias, z, {}}); | |||
} | |||
} | |||
} | |||
@@ -66,7 +66,7 @@ void check_conv_bias( | |||
DType src_dtype, DType filter_dtype, DType bias_dtype, DType dst_dtype, | |||
Handle* handle, const char* algo = nullptr, | |||
param::ConvBias::Format format = param::ConvBias::Format::NCHW4, | |||
const std::vector<TestArg>& args = {}); | |||
const std::vector<TestArg>& args = {}, bool fuse_z = false); | |||
#if MEGDNN_WITH_BENCHMARK | |||
std::vector<conv_bias::TestArg> get_winograd_benchmark_args( | |||
@@ -18,10 +18,14 @@ | |||
#include "test/cuda/fixture.h" | |||
#include "test/cuda/utils.h" | |||
#define V1(x) #x | |||
#define V(x) V1(x) | |||
namespace megdnn { | |||
namespace test { | |||
#if MEGDNN_WITH_BENCHMARK | |||
namespace { | |||
#if MEGDNN_WITH_BENCHMARK | |||
struct BenchArgs { | |||
size_t n, ci, hi, wi, co, f, s; | |||
}; | |||
@@ -29,9 +33,16 @@ struct BenchArgs { | |||
std::vector<BenchArgs> get_resnet50_bench_args(size_t batch = 64) { | |||
std::vector<BenchArgs> args; | |||
args.emplace_back(BenchArgs{batch, 64, 56, 56, 256, 1, 1}); | |||
args.emplace_back(BenchArgs{batch, 256, 56, 56, 32, 3, 1}); | |||
args.emplace_back(BenchArgs{batch, 256, 56, 56, 32, 3, 2}); | |||
args.emplace_back(BenchArgs{batch, 4, 256, 256, 32, 7, 2}); | |||
args.emplace_back(BenchArgs{batch, 256, 56, 56, 64, 1, 1}); | |||
args.emplace_back(BenchArgs{batch, 64, 56, 56, 64, 1, 1}); | |||
args.emplace_back(BenchArgs{batch, 64, 56, 56, 64, 3, 1}); | |||
args.emplace_back(BenchArgs{batch, 64, 56, 56, 64, 3, 2}); | |||
args.emplace_back(BenchArgs{batch, 256, 56, 56, 64, 3, 2}); | |||
args.emplace_back(BenchArgs{batch, 64, 56, 56, 256, 1, 1}); | |||
args.emplace_back(BenchArgs{batch, 256, 56, 56, 512, 1, 2}); | |||
@@ -101,13 +112,12 @@ void benchmark_target_algo( | |||
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(algo)); | |||
} | |||
#define V1(x) #x | |||
#define V(x) V1(x) | |||
#define CUDNN_VERSION_STRING \ | |||
"v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) "." V(CUDNN_PATCHLEVEL) | |||
benchmarker_cudnn.set_before_exec_callback( | |||
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>( | |||
"CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_" | |||
"DEFAULT:CUDNN:ConvBiasActivation:CUDNN_CONVOLUTION_FWD_" | |||
"ALGO_IMPLICIT_PRECOMP_" | |||
"GEMM" CUDNN_VERSION_STRING)); | |||
benchmarker.set_dtype(0, src_dtype) | |||
@@ -141,6 +151,7 @@ void benchmark_target_algo( | |||
{}, | |||
{}}) / | |||
RUNS; | |||
param.nonlineMode = Param::NonlineMode::IDENTITY; | |||
benchmarker_cudnn.set_param(param); | |||
auto time_in_ms_cudnn = | |||
benchmarker_cudnn.execs( | |||
@@ -162,6 +173,47 @@ void benchmark_target_algo( | |||
(flo / (time_in_ms_cudnn * 1e-3)), algo, | |||
time_in_ms_cudnn / time_in_ms); | |||
} | |||
printf("bench with z tensor\n"); | |||
for (auto&& arg : args) { | |||
Param param; | |||
param.pad_h = param.pad_w = arg.f / 2; | |||
param.stride_h = param.stride_w = arg.s; | |||
param.format = Format::NCHW4; | |||
size_t ho = infer_conv_shape(arg.hi, arg.f, arg.s, arg.f / 2); | |||
size_t wo = infer_conv_shape(arg.wi, arg.f, arg.s, arg.f / 2); | |||
benchmarker.set_param(param); | |||
auto time_in_ms = | |||
benchmarker.execs({{arg.n, arg.ci / 4, arg.hi, arg.wi, 4}, | |||
{arg.co, arg.ci / 4, arg.f, arg.f, 4}, | |||
{1, arg.co / 4, 1, 1, 4}, | |||
{arg.n, arg.co / 4, ho, wo, 4}, | |||
{}}) / | |||
RUNS; | |||
param.format = Format::NCHW4; | |||
param.nonlineMode = Param::NonlineMode::IDENTITY; | |||
benchmarker_cudnn.set_param(param); | |||
auto time_in_ms_cudnn = | |||
benchmarker_cudnn.execs( | |||
{{arg.n, arg.ci / 4, arg.hi, arg.wi, 4}, | |||
{arg.co, arg.ci / 4, arg.f, arg.f, 4}, | |||
{1, arg.co / 4, 1, 1, 4}, | |||
{arg.n, arg.co / 4, ho, wo, 4}, | |||
{}}) / | |||
RUNS; | |||
float flo = 2.0 * arg.n * arg.co * ho * wo * arg.ci * arg.f * | |||
arg.f / (1e12); | |||
TensorShape src{arg.n, arg.ci, arg.hi, arg.wi}, | |||
filter{arg.co, arg.ci, arg.f, arg.f}; | |||
printf("src=%s, filter=%s, time(algo=%s)=%.2f %.2fTops, " | |||
"time(cudnn)=%.2f %.2fTops, " | |||
"perf(algo=%s)/perf(cudnn)=%.2f\n", | |||
src.to_string().c_str(), filter.to_string().c_str(), algo, | |||
time_in_ms, (flo / (time_in_ms * 1e-3)), time_in_ms_cudnn, | |||
(flo / (time_in_ms_cudnn * 1e-3)), algo, | |||
time_in_ms_cudnn / time_in_ms); | |||
} | |||
} else if (format == Format::CHWN4) { | |||
for (auto&& arg : args) { | |||
Param param; | |||
@@ -222,6 +274,7 @@ void benchmark_target_algo( | |||
RUNS; | |||
param.format = Format::NCHW4; | |||
benchmarker_cudnn.set_param(param); | |||
param.nonlineMode = Param::NonlineMode::IDENTITY; | |||
auto time_in_ms_cudnn = | |||
benchmarker_cudnn.execs( | |||
{{arg.n, arg.ci / 4, arg.hi, arg.wi, 4}, | |||
@@ -242,7 +295,6 @@ void benchmark_target_algo( | |||
(flo / (time_in_ms_cudnn * 1e-3)), algo, | |||
time_in_ms_cudnn / time_in_ms); | |||
} | |||
} | |||
} | |||
@@ -265,15 +317,14 @@ void benchmark_target_algo_with_cudnn_tsc( | |||
benchmarker.set_before_exec_callback( | |||
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(algo)); | |||
} else { | |||
benchmarker.set_proxy(proxy); | |||
benchmarker.set_proxy(proxy); | |||
} | |||
benchmarker_cudnn.set_before_exec_callback( | |||
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>( | |||
"CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_" | |||
"DEFAULT:CUDNN:ConvBiasActivation:CUDNN_CONVOLUTION_FWD_" | |||
"ALGO_IMPLICIT_PRECOMP_" | |||
"GEMM" CUDNN_VERSION_STRING)); | |||
#undef V1 | |||
#undef V | |||
#undef CUDNN_VERSION_STRING | |||
benchmarker.set_dtype(0, src_dtype) | |||
@@ -446,12 +497,10 @@ void benchmark_target_algo_with_cudnn_tsc( | |||
(flo / (time_in_ms_cudnn * 1e-3)), algo, | |||
time_in_ms_cudnn / time_in_ms); | |||
} | |||
} | |||
} | |||
} // namespace | |||
#endif | |||
} // namespace | |||
TEST_F(CUDA, CONV_BIAS_INT8_NCHW4_1x1) { | |||
require_compute_capability(6, 1); | |||
@@ -1116,6 +1165,7 @@ TEST_F(CUDA, CONV_BIAS_INT8_CHWN4_UNROLL_WIDTH_TENSORCORE_1x1_ALGO_2) { | |||
conv_bias::get_int8_chwn4_args_small_batch(1)); | |||
} | |||
#if MEGDNN_WITH_BENCHMARK | |||
TEST_F(CUDA, BENCHMARK_CONV_BIAS_INT8_CHWN4) { | |||
require_compute_capability(6, 1); | |||
@@ -1182,9 +1232,13 @@ TEST_F(CUDA, BENCHMARK_CONV_BIAS_INT8_CHWN4_SMALL_CHANNEL) { | |||
dtype::QuantizedS8{1.0f}, "INT8_CHWN4_DOTPROD_IMPLICIT_GEMM", | |||
param::ConvBias::Format::CHWN4); | |||
} | |||
#endif | |||
} // namespace test | |||
} // namespace megdnn | |||
#undef V1 | |||
#undef V | |||
// vim: syntax=cpp.doxygen |
@@ -290,6 +290,26 @@ TEST_F(CUDA, POOLING_FORWARD_CHWN4) { | |||
} | |||
} | |||
TEST_F(CUDA, POOLING_FORWARD_INT8_NCHW4) { | |||
require_compute_capability(6, 1); | |||
using Param = param::Pooling; | |||
Checker<Pooling> checker(handle_cuda()); | |||
Param param; | |||
auto i8_min = std::numeric_limits<int8_t>().min(); | |||
auto i8_max = std::numeric_limits<int8_t>().max(); | |||
UniformIntRNG int_rng{i8_min, i8_max}; | |||
checker.set_dtype(0, dtype::QuantizedS8(0.1f)); | |||
param.format = Param::Format::NCHW4; | |||
for (auto mode : {Param::Mode::MAX, Param::Mode::AVERAGE, | |||
Param::Mode::AVERAGE_COUNT_EXCLUDE_PADDING}) { | |||
param.mode = mode; | |||
checker.set_epsilon(1e-3).set_rng(0, &int_rng); | |||
checker.set_param(param).exec({{64, 8, 28, 28, 4}, {}}); | |||
checker.set_param(param).exec({{15, 8, 28, 28, 4}, {}}); | |||
checker.set_param(param).exec({{30, 8, 28, 28, 4}, {}}); | |||
} | |||
} | |||
#if MEGDNN_WITH_BENCHMARK | |||
TEST_F(CUDA, BENCHMARK_POOLING_CHWN4) { | |||
CUBenchmarker<Pooling> bencher(handle_cuda()); | |||
@@ -20,6 +20,14 @@ bool check_compute_capability(int major, int minor) { | |||
cuda_check(cudaGetDeviceProperties(&prop, dev)); | |||
return prop.major > major || (prop.major == major && prop.minor >= minor); | |||
} | |||
bool check_compute_capability_eq(int major, int minor) { | |||
int dev; | |||
cuda_check(cudaGetDevice(&dev)); | |||
cudaDeviceProp prop; | |||
cuda_check(cudaGetDeviceProperties(&prop, dev)); | |||
return (prop.major == major && prop.minor == minor); | |||
} | |||
} // namespace test | |||
} // namespace megdnn | |||
@@ -26,13 +26,28 @@ | |||
namespace megdnn { | |||
namespace test { | |||
bool check_compute_capability(int major, int minor); | |||
bool check_compute_capability_eq(int major, int minor); | |||
} // namespace test | |||
} // namespace megdnn | |||
#define require_compute_capability(x, y) \ | |||
do { \ | |||
if (!megdnn::test::check_compute_capability((x), (y))) \ | |||
return; \ | |||
#define require_compute_capability(x, y) \ | |||
do { \ | |||
if (!megdnn::test::check_compute_capability((x), (y))) { \ | |||
printf("skip testcase due to cuda compute capability not " \ | |||
"require.(expected:%d.%d)", \ | |||
(x), (y)); \ | |||
return; \ | |||
} \ | |||
} while (0) | |||
#define require_compute_capability_eq(x, y) \ | |||
do { \ | |||
if (!megdnn::test::check_compute_capability_eq((x), (y))) { \ | |||
printf("skip testcase due to cuda compute capability not " \ | |||
"equal to %d.%d", \ | |||
(x), (y)); \ | |||
return; \ | |||
} \ | |||
} while (0) | |||
// vim: syntax=cpp.doxygen |