GitOrigin-RevId: 814b8a83f8
release-1.11.1
@@ -38,7 +38,7 @@ void RegionRestrictedConvolutionForward::deduce_dtype(
             "only float type is supported for region_restricted_conv forward");
     megdnn_assert(
             rin == rout && (rin == dtype::Int32() || rin == dtype::Uint8()),
-            "the dtype of rin/rout should be Int32, got %s.", rin.name());
+            "the dtype of rin/rout should be Int32 or Uint8, got %s.", rin.name());
 }
 
 void RegionRestrictedConvolutionForward::deduce_layout(
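
Semantics, in brief: rin assigns a region id to every input pixel and rout to every output pixel, and an output element only accumulates taps whose input region id equals its own output region id. A minimal single-channel reference of that rule (an illustrative sketch only, assuming dense layout, stride 1 and no padding; the helper name is made up, not library API — the NAIVE tests at the end of this diff pin down the exact values):

    // Sketch: out(oh, ow) sums src * flt only where rin(ih, iw) == rout(oh, ow).
    void region_restricted_conv_ref(
            const float* src, const float* flt, const int* rin, const int* rout,
            float* dst, int IH, int IW, int FH, int FW) {
        const int OH = IH - FH + 1, OW = IW - FW + 1;
        for (int oh = 0; oh < OH; ++oh)
            for (int ow = 0; ow < OW; ++ow) {
                float acc = 0.f;
                for (int kh = 0; kh < FH; ++kh)
                    for (int kw = 0; kw < FW; ++kw) {
                        const int ih = oh + kh, iw = ow + kw;
                        if (rin[ih * IW + iw] == rout[oh * OW + ow])  // region gate
                            acc += src[ih * IW + iw] * flt[kh * FW + kw];
                    }
                dst[oh * OW + ow] = acc;
            }
    }
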
@@ -91,12 +91,12 @@ RegionRestrictedConvolutionBackwardData::check_exec(
     auto ret = check_layout_fwd(grad_fwd, filter_fwd, diff_fwd);
 #define err_msg(lhs, rhs) \
     megdnn_assert(lhs == rhs, "shape mismatch, #lhs:%zu, #rhs:%zu", lhs, rhs);
-    err_msg(rin.shape[0], grad_fwd.shape[0]);
-    err_msg(rin.shape[1], grad_fwd.shape[2]);
-    err_msg(rin.shape[2], grad_fwd.shape[3]);
-    err_msg(rout.shape[0], diff_fwd.shape[0]);
-    err_msg(rout.shape[1], diff_fwd.shape[2]);
-    err_msg(rout.shape[2], diff_fwd.shape[3]);
+    err_msg(rin.shape[0], grad_fwd.shape[0]);   // batch
+    err_msg(rin.shape[1], grad_fwd.shape[2]);   // ih
+    err_msg(rin.shape[2], grad_fwd.shape[3]);   // iw
+    err_msg(rout.shape[0], diff_fwd.shape[0]);  // batch
+    err_msg(rout.shape[1], diff_fwd.shape[2]);  // oh
+    err_msg(rout.shape[2], diff_fwd.shape[3]);  // ow
 #undef err_msg
     auto required_workspace_in_bytes =
             get_workspace_in_bytes(filter, diff, rin, rout, grad);
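
The masks carry no channel axis: with NCHW data tensors {N, C, H, W}, the checks above pin the masks to one region id per spatial location, shared across all channels:

    rin : {N, IH, IW}   // matched against grad_fwd dims 0, 2, 3
    rout: {N, OH, OW}   // matched against diff_fwd dims 0, 2, 3
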
@@ -106,45 +106,22 @@ RegionRestrictedConvolutionBackwardData::check_exec(
 void RegionRestrictedConvolutionBackwardData::deduce_dtype(
         DType filter, DType diff, DType rin, DType rout, DType& grad) {
-    SmallVector<DType> supported_dst_dtype;
-    if (filter.category() == diff.category() &&
-        filter.category() == DTypeCategory::FLOAT) {
-        supported_dst_dtype.push_back(filter);
-    } else if (filter.enumv() == DTypeEnum::Int8 && diff == filter) {
-        supported_dst_dtype.push_back(dtype::Int32());
-    } else if (
-            (filter.enumv() == DTypeEnum::QuantizedS8 &&
-             diff.enumv() == DTypeEnum::QuantizedS8) ||
-            (filter.enumv() == DTypeEnum::Quantized8Asymm &&
-             diff.enumv() == DTypeEnum::Quantized8Asymm)) {
-        supported_dst_dtype.push_back(dtype::QuantizedS32(mul_scale(filter, diff)));
-        if (grad.valid() && grad.enumv() == diff.enumv()) {
-            supported_dst_dtype.push_back(grad);
-        }
-    } else {
-        megdnn_throw(ssprintf(
-                "unsupported input / diff DType: %s x %s", filter.name(), diff.name()));
-    }
-    if (!grad.valid()) {
-        grad = supported_dst_dtype.at(0);
-    } else {
-        megdnn_assert(
-                vec_contains(supported_dst_dtype, grad),
-                "unsupported ConvBwd(%s, %s) -> %s", filter.name(), diff.name(),
-                grad.name());
-    }
-    megdnn_assert(
-            param().compute_mode != Param::ComputeMode::FLOAT32
-            // FIXME: infering dtype of grad via naive impl only support fp32
-            // (lack of quantized dtype infering or others) may not suitable in the furture
 #if !MEGDNN_DISABLE_FLOAT16
-            || filter.enumv() == DTypeEnum::Float16 ||
-            filter.enumv() == DTypeEnum::BFloat16
+    if (diff.enumv() == DTypeEnum::Float32 || diff.enumv() == DTypeEnum::Float16) {
+        grad = diff;
+    }
 #endif
-            ,
-            "ComputeMode::FLOAT32 is only available for Float16/BFloat16 "
-            "input / output.");
+    megdnn_assert(grad.valid(), "dtype of grad requires deducing of assigned");
     megdnn_assert(
-            rin == rout && rin == dtype::Int32(),
-            "the dtype of rin/rout should be Int32, got %s.", rin.name());
+            diff.category() == DTypeCategory::FLOAT &&
+                    filter.category() == DTypeCategory::FLOAT &&
+                    grad.category() == DTypeCategory::FLOAT,
+            "only float type is supported for region_restricted_conv backward data");
+    megdnn_assert(
+            rin == rout && (rin == dtype::Int32() || rin == dtype::Uint8()),
+            "the dtype of rin/rout should be Int32 or Uint8, got %s.", rin.name());
 }
 
 void RegionRestrictedConvolutionBackwardData::deduce_layout(
@@ -1,7 +1,7 @@
-#include "./kern.cuh"
 #include "cuda.h"
 #include "cuda_fp16.h"
 #include "src/cuda/fp16_help.cuh"
+#include "src/cuda/region_restricted_convolution/chanwise/kern.cuh"
 
 using namespace megdnn;
 using namespace cuda;
@@ -15,7 +15,7 @@ namespace cuda {
 namespace region_restricted_convolution {
 namespace chanwise {
 
-// =====================================fwd=====================================
+// =====================================bwd=====================================
 template <>
 void run_bwd_depthwise_large_filter(
@@ -498,16 +498,8 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
     SrcGlobal2ShareVisitor gl2sh_src = {
             smem_src,
             static_cast<int>(param.src_w),
-            static_cast<int>(
-                    is_fwd ? src_start_h
-                           : src_start_h -
-                                     (param.out_h / 2 + param.flt_h / 2 - param.pad_h -
-                                      param.src_h * param.stride_h / 2)),
-            static_cast<int>(
-                    is_fwd ? src_start_w
-                           : src_start_w -
-                                     (param.out_w / 2 + param.flt_w / 2 - param.pad_w -
-                                      param.src_w * param.stride_w / 2)),
+            static_cast<int>(src_start_h),
+            static_cast<int>(src_start_w),
             static_cast<int>(is_fwd ? param.src_h : param.src_h * param.stride_h),
             static_cast<int>(is_fwd ? param.src_w : param.src_w * param.stride_w),
             is_fwd ? 1 : static_cast<int>(param.stride_h),
@@ -516,16 +508,8 @@ __global__ void DepthwiseConv2dGPUKernelNCHW( | |||||
RinGlobal2ShareVisitor gl2sh_rin = { | RinGlobal2ShareVisitor gl2sh_rin = { | ||||
smem_rin, | smem_rin, | ||||
static_cast<int>(param.src_w), | static_cast<int>(param.src_w), | ||||
static_cast<int>( | |||||
is_fwd ? src_start_h | |||||
: src_start_h - | |||||
(param.out_h / 2 + param.flt_h / 2 - param.pad_h - | |||||
param.src_h * param.stride_h / 2)), | |||||
static_cast<int>( | |||||
is_fwd ? src_start_w | |||||
: src_start_w - | |||||
(param.out_w / 2 + param.flt_w / 2 - param.pad_w - | |||||
param.src_w * param.stride_w / 2)), | |||||
static_cast<int>(src_start_h), | |||||
static_cast<int>(src_start_w), | |||||
static_cast<int>(is_fwd ? param.src_h : param.src_h * param.stride_h), | static_cast<int>(is_fwd ? param.src_h : param.src_h * param.stride_h), | ||||
static_cast<int>(is_fwd ? param.src_w : param.src_w * param.stride_w), | static_cast<int>(is_fwd ? param.src_w : param.src_w * param.stride_w), | ||||
is_fwd ? 1 : static_cast<int>(param.stride_h), | is_fwd ? 1 : static_cast<int>(param.stride_h), | ||||
@@ -790,14 +774,15 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
             out_base_h_idx = out_start_h + off_oh * OutTileConfig::unroll_h;
 
     T* smem_src_ptr = smem_src + off_ow * FilterTileConfig::unroll_w;
-    static_assert((FilterTileConfig::unroll_w & 3) == 0);
+    static_assert(
+            (FilterTileConfig::unroll_w & 3) == 0, "filter tile unroll_w & 3 != 0");
     int* smem_rin_ptr = smem_rin + (off_ow * FilterTileConfig::unroll_w >> 2);
     T* smem_flt_ptr = smem_flt + off_ow * FilterTileConfig::unroll_w;
     T* out_base_ptr = output + off_ochannel * param.out_h * param.out_w;
     const uint8_t* rout_base_ptr = rout + batch * param.out_h * param.out_w;
-    static_assert((OutTileConfig::unroll_w & 3) == 0);
-    static_assert((OutTileConfig::block_w & 3) == 0);
+    static_assert((OutTileConfig::unroll_w & 3) == 0, "output tile unroll_w & 3 != 0");
+    static_assert((OutTileConfig::block_w & 3) == 0, "output block_w & 3 != 0");
 
     int reg_rout[OutTileConfig::unroll_size] = {0};
 #pragma unroll
     for (int i = 0; i < OutTileConfig::unroll_h; ++i) {
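
The multiple-of-4 constraints line up with how the uint8 masks are staged: smem_rin is int-typed and element offsets are shifted right by 2, so each 32-bit shared-memory word carries four packed uint8 region ids. A hedged sketch of the corresponding unpack (illustrative only, assuming little-endian byte packing; this is not the kernel's actual code):

    // Assumed layout: region id of element e lives in byte (e & 3) of word e >> 2.
    __device__ inline int region_id_at(const int* smem_rin, int e) {
        return (smem_rin[e >> 2] >> ((e & 3) * 8)) & 0xff;
    }
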
@@ -821,16 +806,8 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
     SrcGlobal2ShareVisitor gl2sh_src = {
             smem_src,
             static_cast<int>(param.src_w),
-            static_cast<int>(
-                    is_fwd ? src_start_h
-                           : src_start_h -
-                                     (param.out_h / 2 + param.flt_h / 2 - param.pad_h -
-                                      param.src_h * param.stride_h / 2)),
-            static_cast<int>(
-                    is_fwd ? src_start_w
-                           : src_start_w -
-                                     (param.out_w / 2 + param.flt_w / 2 - param.pad_w -
-                                      param.src_w * param.stride_w / 2)),
+            static_cast<int>(src_start_h),
+            static_cast<int>(src_start_w),
             static_cast<int>(is_fwd ? param.src_h : param.src_h * param.stride_h),
             static_cast<int>(is_fwd ? param.src_w : param.src_w * param.stride_w),
             is_fwd ? 1 : static_cast<int>(param.stride_h),
@@ -839,16 +816,8 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
     RinGlobal2ShareVisitor gl2sh_rin = {
             smem_rin,
             static_cast<int>(param.src_w),
-            static_cast<int>(
-                    is_fwd ? src_start_h
-                           : src_start_h -
-                                     (param.out_h / 2 + param.flt_h / 2 - param.pad_h -
-                                      param.src_h * param.stride_h / 2)),
-            static_cast<int>(
-                    is_fwd ? src_start_w
-                           : src_start_w -
-                                     (param.out_w / 2 + param.flt_w / 2 - param.pad_w -
-                                      param.src_w * param.stride_w / 2)),
+            static_cast<int>(src_start_h),
+            static_cast<int>(src_start_w),
             static_cast<int>(is_fwd ? param.src_h : param.src_h * param.stride_h),
             static_cast<int>(is_fwd ? param.src_w : param.src_w * param.stride_w),
             is_fwd ? 1 : static_cast<int>(param.stride_h),
@@ -1134,14 +1103,20 @@ void LaunchDepthwiseConv2dGPU(
             RinTileCount::smem_size * sizeof(int);
 
     void (*kernel)(const Param, const T*, const T*, const RT*, const RT*, T*);
+    const bool is_fwd = (kDirection == DIRECTION_FORWARD);
     if (param.is_compute_deafult) {
         kernel = DepthwiseConv2dGPUKernelNCHW<IConvTrait, kDirection, stride>;
     } else {
         megdnn_assert_internal(0);
     }
-    kernel<<<grid, block, shared_storage, stream>>>(
-            param, input, filter, rin, rout, output);
+    if (is_fwd) {
+        kernel<<<grid, block, shared_storage, stream>>>(
+                param, input, filter, rin, rout, output);
+    } else {
+        kernel<<<grid, block, shared_storage, stream>>>(
+                param, input, filter, rout, rin, output);
+    }
     after_kernel_launch();
 }
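
For the backward-data launch the kernel walks diff as its input and grad as its output, so the mask roles exchange: the mask gating the loads is the forward output mask, and the mask indexed per written pixel is the forward input mask. Schematically (a hedged reading of the call above; argument order follows the kernel pointer's signature):

    // forward : kernel(param, src,  flt, rin,  rout, dst );  rin gates the loads
    // backward: kernel(param, diff, flt, rout, rin,  grad);  roles swapped
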
@@ -55,25 +55,65 @@ size_t RegionRestrictedConvolutionBackwardDataImpl::get_workspace_in_bytes(
 void RegionRestrictedConvolutionBackwardDataImpl::exec(
         _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_in rin,
         _megdnn_tensor_in rout, _megdnn_tensor_out grad, _megdnn_workspace workspace) {
-    megdnn_throw(ssprintf(
-            "unsupported RegionRestrictedConvolutionBackwardData(%s, %s, %s, %s) -> %s",
-            filter.layout.dtype.name(), diff.layout.dtype.name(),
-            rin.layout.dtype.name(), rout.layout.dtype.name(),
-            grad.layout.dtype.name()));
+    auto fm = check_exec(
+            filter.layout, diff.layout, rin.layout, rout.layout, grad.layout,
+            workspace.size);
+    // XXX: a naive impl to set deconv padding to param, needs optimization in future.
+    [&]() -> void {
+        size_t stride = fm.stride[0];
+        size_t src_size = grad.layout.shape[2];
+        size_t fwd_pad = fm.padding[0];
+        size_t filter_size = fm.spatial[0];
+        size_t deconv_pad = (stride * src_size - stride + stride * filter_size -
+                             src_size - 2 * fwd_pad + filter_size - 1) /
+                            (2 * stride);
+        fm.padding[0] = fm.padding[1] = deconv_pad;
+        return;
+    }();
+    auto kparam = chanwise::Param::load(
+            diff.layout, grad.layout, fm,
+            param().compute_mode == Param::ComputeMode::DEFAULT);
+    megdnn_assert(
+            fm.group > 1 && diff.layout.dtype.category() == DTypeCategory::FLOAT &&
+            param().compute_mode == Param::ComputeMode::DEFAULT &&
+            fm.spatial_ndim == 2 && fm.icpg == 1 && fm.ocpg == 1 &&
+            fm.dilation[0] == 1 && fm.dilation[1] == 1 && !fm.should_flip &&
+            param().stride_h == 1 && param().stride_w == 1);
+    // NOTE: uint8 dtype region mask requires the spatial size of src&dst to be 4*N
+    if (rin.layout.dtype == dtype::Uint8()) {
+        megdnn_assert(
+                (grad.layout.shape[3] & 3) == 0 && (diff.layout.shape[3] & 3) == 0);
+    }
+    auto stream = cuda_stream(handle());
+    if (filter.layout.dtype == dtype::Float32() && rin.layout.dtype == dtype::Int32() &&
+        rout.layout.dtype == dtype::Int32()) {
+        chanwise::run_bwd_depthwise_large_filter(
+                grad.ptr<dt_float32>(), diff.ptr<dt_float32>(),
+                filter.ptr<dt_float32>(), rin.ptr<dt_int32>(), rout.ptr<dt_int32>(),
+                kparam, stream);
+    } else if (
+            filter.layout.dtype == dtype::Float32() &&
+            rin.layout.dtype == dtype::Uint8() && rout.layout.dtype == dtype::Uint8()) {
+        chanwise::run_bwd_depthwise_large_filter(
+                grad.ptr<dt_float32>(), diff.ptr<dt_float32>(),
+                filter.ptr<dt_float32>(), rin.ptr<dt_uint8>(), rout.ptr<dt_uint8>(),
+                kparam, stream);
+    } else {
+        megdnn_throw("undefined or unimplemented region restricted conv mode");
+    }
 }
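
The lambda above folds the backward-data geometry into an equivalent forward padding. Since the assert restricts stride_h == stride_w == 1, the expression collapses (a worked reduction with s = stride, i = src_size, p = fwd_pad, f = filter_size):

    deconv_pad = (s*i - s + s*f - i - 2*p + f - 1) / (2*s)
               = (2*f - 2 - 2*p) / 2        // with s = 1
               = f - 1 - p

which is the usual full-correlation padding for dgrad; for odd f with p = f/2, the SAME-padding cases used by the tests, it reduces back to f/2.
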
 
 size_t RegionRestrictedConvolutionBackwardFilterImpl::get_workspace_in_bytes(
         const TensorLayout& src, const TensorLayout& diff, const TensorLayout&,
         const TensorLayout&, const TensorLayout& grad) {
-    size_t workspace_size = 0;
-    return workspace_size;
+    return 0;
 }
 
 /* ============== RegionRestrictedConvolutionBackwardFilterImpl ============== */
 void RegionRestrictedConvolutionBackwardFilterImpl::exec(
         _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_in rin,
         _megdnn_tensor_in rout, _megdnn_tensor_out grad, _megdnn_workspace workspace) {
-    megdnn_assert_internal(0);
+    megdnn_throw("Region Restricted Conv BackwardFilter unimplemented");
 }
 
 // vim: syntax=cpp.doxygen
@@ -117,7 +117,7 @@ TEST_F(CUDA, BENCHMARK_REGION_RESTRICTED_CONV_FORWARD_LARGE_FILTER_FP32) {
             .set_dtype(1, dtype::Float32())
             .set_dtype(2, dtype::Int32())
             .set_dtype(3, dtype::Int32());
-    rr_bencher.set_rng(2, &r_rng).set_rng(3, &r_rng).set_rng(0, &r_rng);
+    rr_bencher.set_rng(2, &r_rng).set_rng(3, &r_rng);
     rr_bencher.set_times(nr_times);
 
     size_t ho = infer_conv_shape(hi, fh, sh, param.pad_h);
@@ -169,6 +169,202 @@ TEST_F(CUDA, BENCHMARK_REGION_RESTRICTED_CONV_FORWARD_LARGE_FILTER_FP32) {
     run_bench(64, 384, 32, 32, 31, 31, 1, 1, 10);
 }
 
+TEST_F(CUDA, BENCHMARK_REGION_RESTRICTED_CONV_BACKWARD_LARGE_FILTER_FP32) {
+    require_compute_capability(7, 5);
+    Benchmarker<ConvolutionBackwardData> bencher(handle_cuda());
+    bencher.set_display(false);
+    bencher.set_before_exec_callback(
+            AlgoChecker<ConvolutionBackwardData>("DEPTHWISE_LARGE_FILTER"));
+
+    Benchmarker<RegionRestrictedConvolutionBackwardData> rr_bencher(handle_cuda());
+    rr_bencher.set_display(false);
+
+    ConvolutionBackwardData::Param param;
+    param.format = ConvolutionBackwardData::Param::Format::NCHW;
+    param.sparse = ConvolutionBackwardData::Param::Sparse::GROUP;
+
+    RegionRestrictedConvolutionBackwardData::Param rr_param;
+    rr_param.format = RegionRestrictedConvolutionBackwardData::Param::Format::NCHW;
+    rr_param.sparse = RegionRestrictedConvolutionBackwardData::Param::Sparse::GROUP;
+
+    UniformIntRNG r_rng{1, 3};
+
+    auto run_bench = [&](size_t batch, size_t g, size_t hi, size_t wi, size_t fh,
+                         size_t fw, size_t sh, size_t sw, size_t nr_times) {
+        param.pad_h = fh / 2;
+        param.pad_w = fw / 2;
+        param.stride_h = sh;
+        param.stride_w = sw;
+
+        rr_param.pad_h = fh / 2;
+        rr_param.pad_w = fw / 2;
+        rr_param.stride_h = sh;
+        rr_param.stride_w = sw;
+
+        bencher.set_param(param)
+                .set_dtype(0, dtype::Float32())
+                .set_dtype(1, dtype::Float32())
+                .set_dtype(2, dtype::Float32())
+                .set_dtype(4, dtype::Float32());
+        bencher.set_times(nr_times);
+
+        rr_bencher.set_param(rr_param)
+                .set_dtype(0, dtype::Float32())
+                .set_dtype(1, dtype::Float32())
+                .set_dtype(2, dtype::Int32())
+                .set_dtype(3, dtype::Int32());
+        rr_bencher.set_rng(2, &r_rng).set_rng(3, &r_rng);
+        rr_bencher.set_times(nr_times);
+
+        size_t ho = infer_conv_shape(hi, fh, sh, param.pad_h);
+        size_t wo = infer_conv_shape(wi, fw, sw, param.pad_w);
+        TensorShape inp{batch, g, hi, wi} /*src*/, kern{g, 1, 1, fh, fw} /*filter*/,
+                rin{batch, hi, wi}, rout{batch, ho, wo},
+                out{batch, g, ho, wo} /*output*/;
+
+        float bandwith = static_cast<float>(
+                                 inp.total_nr_elems() + kern.total_nr_elems() +
+                                 out.total_nr_elems()) /
+                         (1024 * 1024 * 1024) * 1e3;
+
+        float rr_bandwith = static_cast<float>(
+                                    inp.total_nr_elems() + kern.total_nr_elems() +
+                                    rin.total_nr_elems() + rout.total_nr_elems() +
+                                    out.total_nr_elems()) /
+                            (1024 * 1024 * 1024) * 1e3;
+
+        auto time_in_ms = bencher.execs({kern, out, inp}) / nr_times;
+        auto ops = 2.0 * batch * g * ho * wo * fh * fw / (time_in_ms * 1e-3) * 1e-12;
+
+        auto rr_time_in_ms = rr_bencher.execs({kern, out, rin, rout, inp}) / nr_times;
+        auto rr_ops =
+                2.0 * batch * g * ho * wo * fh * fw / (rr_time_in_ms * 1e-3) * 1e-12;
+        printf("[DGRAD]RegionRestrictedDepthwiseLargeFilter vs DepthwiseLargeFilter: "
+               "grad=%s, "
+               "kern=%s, diff=%s\n"
+               "time: %.2f ms, time(rr): %.2f ms, perf: %.2fTops, perf(rr): %.2f Tops\n"
+               "bandwidth: %.2fGB/s, bandwidth(rr): %.2fGB/s, speedup: %.2f.\n",
+               inp.to_string().c_str(), kern.to_string().c_str(),
+               out.to_string().c_str(), time_in_ms, rr_time_in_ms, ops, rr_ops,
+               bandwith * 4 / time_in_ms, rr_bandwith * 4 / rr_time_in_ms,
+               time_in_ms / rr_time_in_ms);
+    };
+
+    run_bench(64, 384, 32, 32, 3, 3, 1, 1, 10);
+    run_bench(64, 384, 32, 32, 5, 5, 1, 1, 10);
+    run_bench(64, 384, 32, 32, 7, 7, 1, 1, 10);
+    run_bench(64, 384, 32, 32, 9, 9, 1, 1, 10);
+    run_bench(64, 384, 32, 32, 11, 11, 1, 1, 10);
+    run_bench(64, 384, 32, 32, 13, 13, 1, 1, 10);
+    run_bench(64, 384, 32, 32, 15, 15, 1, 1, 10);
+    run_bench(64, 384, 32, 32, 17, 17, 1, 1, 10);
+    run_bench(64, 384, 32, 32, 19, 19, 1, 1, 10);
+    run_bench(64, 384, 32, 32, 21, 21, 1, 1, 10);
+    run_bench(64, 384, 32, 32, 23, 23, 1, 1, 10);
+    run_bench(64, 384, 32, 32, 25, 25, 1, 1, 10);
+    run_bench(64, 384, 32, 32, 27, 27, 1, 1, 10);
+    run_bench(64, 384, 32, 32, 29, 29, 1, 1, 10);
+    run_bench(64, 384, 32, 32, 31, 31, 1, 1, 10);
+}
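
Both printed figures follow from the shapes (a worked reading of run_bench's arithmetic): each of the batch * g * ho * wo * fh * fw depthwise taps is one multiply-add, i.e. two flops, and the `* 4` at the printf converts float element counts to bytes:

    Tops = 2 * batch * g * ho * wo * fh * fw * 1e-12 / (time_ms * 1e-3)
    GB/s = total_elems * 4 / 1024^3 / (time_ms * 1e-3)

The bandwith variables pre-fold the 1e3 and the 1024^3, so `bandwith * 4 / time_in_ms` equals the GB/s expression above; the rr variant is additionally charged for the rin/rout traffic, which is why rr_bandwith adds their element counts.
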
+
+TEST_F(CUDA, BENCHMARK_REGION_RESTRICTED_CONV_BACKWARD_LARGE_FILTER_FP32_UINT8) {
+    require_compute_capability(7, 5);
+    Benchmarker<ConvolutionBackwardData> bencher(handle_cuda());
+    bencher.set_display(false);
+    bencher.set_before_exec_callback(
+            AlgoChecker<ConvolutionBackwardData>("DEPTHWISE_LARGE_FILTER"));
+
+    Benchmarker<RegionRestrictedConvolutionBackwardData> rr_bencher(handle_cuda());
+    rr_bencher.set_display(false);
+
+    ConvolutionBackwardData::Param param;
+    param.format = ConvolutionBackwardData::Param::Format::NCHW;
+    param.sparse = ConvolutionBackwardData::Param::Sparse::GROUP;
+
+    RegionRestrictedConvolutionBackwardData::Param rr_param;
+    rr_param.format = RegionRestrictedConvolutionBackwardData::Param::Format::NCHW;
+    rr_param.sparse = RegionRestrictedConvolutionBackwardData::Param::Sparse::GROUP;
+
+    UniformIntRNG r_rng{1, 3};
+
+    auto run_bench = [&](size_t batch, size_t g, size_t hi, size_t wi, size_t fh,
+                         size_t fw, size_t sh, size_t sw, size_t nr_times) {
+        param.pad_h = fh / 2;
+        param.pad_w = fw / 2;
+        param.stride_h = sh;
+        param.stride_w = sw;
+
+        rr_param.pad_h = fh / 2;
+        rr_param.pad_w = fw / 2;
+        rr_param.stride_h = sh;
+        rr_param.stride_w = sw;
+
+        bencher.set_param(param)
+                .set_dtype(0, dtype::Float32())
+                .set_dtype(1, dtype::Float32())
+                .set_dtype(2, dtype::Float32())
+                .set_dtype(4, dtype::Float32());
+        bencher.set_times(nr_times);
+
+        rr_bencher.set_param(rr_param)
+                .set_dtype(0, dtype::Float32())
+                .set_dtype(1, dtype::Float32())
+                .set_dtype(2, dtype::Uint8())
+                .set_dtype(3, dtype::Uint8());
+        rr_bencher.set_rng(2, &r_rng).set_rng(3, &r_rng);
+        rr_bencher.set_times(nr_times);
+
+        size_t ho = infer_conv_shape(hi, fh, sh, param.pad_h);
+        size_t wo = infer_conv_shape(wi, fw, sw, param.pad_w);
+        TensorShape inp{batch, g, hi, wi} /*src*/, kern{g, 1, 1, fh, fw} /*filter*/,
+                rin{batch, hi, wi}, rout{batch, ho, wo},
+                out{batch, g, ho, wo} /*output*/;
+
+        float bandwith = static_cast<float>(
+                                 inp.total_nr_elems() + kern.total_nr_elems() +
+                                 out.total_nr_elems()) /
+                         (1024 * 1024 * 1024) * 1e3;
+
+        float rr_bandwith = static_cast<float>(
+                                    inp.total_nr_elems() + kern.total_nr_elems() +
+                                    rin.total_nr_elems() + rout.total_nr_elems() +
+                                    out.total_nr_elems()) /
+                            (1024 * 1024 * 1024) * 1e3;
+
+        auto time_in_ms = bencher.execs({kern, out, inp}) / nr_times;
+        auto ops = 2.0 * batch * g * ho * wo * fh * fw / (time_in_ms * 1e-3) * 1e-12;
+
+        auto rr_time_in_ms = rr_bencher.execs({kern, out, rin, rout, inp}) / nr_times;
+        auto rr_ops =
+                2.0 * batch * g * ho * wo * fh * fw / (rr_time_in_ms * 1e-3) * 1e-12;
+        printf("[DGRAD]RegionRestrictedDepthwiseLargeFilter vs DepthwiseLargeFilter: "
+               "grad=%s, "
+               "kern=%s, diff=%s\n"
+               "time: %.2f ms, time(rr): %.2f ms, perf: %.2fTops, perf(rr): %.2f Tops\n"
+               "bandwidth: %.2fGB/s, bandwidth(rr): %.2fGB/s, speedup: %.2f.\n",
+               inp.to_string().c_str(), kern.to_string().c_str(),
+               out.to_string().c_str(), time_in_ms, rr_time_in_ms, ops, rr_ops,
+               bandwith * 4 / time_in_ms, rr_bandwith * 4 / rr_time_in_ms,
+               time_in_ms / rr_time_in_ms);
+    };
+
+    run_bench(64, 384, 32, 32, 3, 3, 1, 1, 10);
+    run_bench(64, 384, 32, 32, 5, 5, 1, 1, 10);
+    run_bench(64, 384, 32, 32, 7, 7, 1, 1, 10);
+    run_bench(64, 384, 32, 32, 9, 9, 1, 1, 10);
+    run_bench(64, 384, 32, 32, 11, 11, 1, 1, 10);
+    run_bench(64, 384, 32, 32, 13, 13, 1, 1, 10);
+    run_bench(64, 384, 32, 32, 15, 15, 1, 1, 10);
+    run_bench(64, 384, 32, 32, 17, 17, 1, 1, 10);
+    run_bench(64, 384, 32, 32, 19, 19, 1, 1, 10);
+    run_bench(64, 384, 32, 32, 21, 21, 1, 1, 10);
+    run_bench(64, 384, 32, 32, 23, 23, 1, 1, 10);
+    run_bench(64, 384, 32, 32, 25, 25, 1, 1, 10);
+    run_bench(64, 384, 32, 32, 27, 27, 1, 1, 10);
+    run_bench(64, 384, 32, 32, 29, 29, 1, 1, 10);
+    run_bench(64, 384, 32, 32, 31, 31, 1, 1, 10);
+}
 
 TEST_F(CUDA, BENCHMARK_REGION_RESTRICTED_CONV_FORWARD_LARGE_FILTER_UINT8) {
     require_compute_capability(7, 5);
     Benchmarker<ConvBiasForward> bencher(handle_cuda());
@@ -271,6 +467,124 @@ TEST_F(CUDA, BENCHMARK_REGION_RESTRICTED_CONV_FORWARD_LARGE_FILTER_UINT8) {
 #endif
 
+TEST_F(CUDA, REGION_RESTRICTED_CONV_BWD_DATA_FP32) {
+    Checker<RegionRestrictedConvolutionBackwardData> checker(handle_cuda());
+    for (auto dt : std::vector<DType>{dtype::Int32(), dtype::Uint8()}) {
+        auto run = [&checker, &dt](
+                           size_t n, size_t g, size_t ih, size_t fh, size_t padding,
+                           size_t stride) {
+            RegionRestrictedConvolutionBackwardData::Param cur_param;
+            cur_param.mode = RegionRestrictedConvolutionBackwardData::Param::Mode::
+                    CROSS_CORRELATION;
+            cur_param.compute_mode = RegionRestrictedConvolutionBackwardData::Param::
+                    ComputeMode::DEFAULT;
+            cur_param.sparse =
+                    RegionRestrictedConvolutionBackwardData::Param::Sparse::GROUP;
+            checker.set_dtype(0, dtype::Float32())
+                    .set_dtype(1, dtype::Float32())
+                    .set_dtype(2, dt)
+                    .set_dtype(3, dt);
+            float scale = 64.f / sqrt(fh * fh);
+            UniformFloatRNG rng(scale, 2 * scale);
+            UniformIntRNG r_rng{1, 2};
+            checker.set_rng(0, &rng).set_rng(1, &rng).set_rng(2, &r_rng).set_rng(
+                    3, &r_rng);
+            cur_param.pad_h = cur_param.pad_w = padding;
+            cur_param.stride_h = cur_param.stride_w = stride;
+            size_t oh = (ih + 2 * padding - fh + 1) / stride;
+            checker.set_param(cur_param).execs({
+                    {g, 1, 1, fh, fh},   // filter
+                    {n, g * 1, oh, oh},  // diff
+                    {n, ih, ih},         // rin
+                    {n, oh, oh},         // rout
+                    {n, g * 1, ih, ih}   // grad
+            });
+        };
+        if (dt == dtype::Int32()) {
+            run(4, 8, 32, 5, 5 / 2, 1);
+            run(1, 2, 2, 2, 0, 1);
+            run(1, 2, 3, 3, 0, 1);
+            run(1, 2, 4, 4, 0, 1);
+            run(1, 2, 5, 5, 0, 1);
+            run(1, 2, 6, 6, 0, 1);
+            run(1, 2, 7, 7, 0, 1);
+        }
+        run(4, 8, 32, 7, 7 / 2, 1);
+        run(4, 8, 32, 9, 9 / 2, 1);
+        run(4, 8, 32, 11, 11 / 2, 1);
+        run(4, 8, 32, 13, 13 / 2, 1);
+        run(4, 8, 32, 15, 15 / 2, 1);
+        run(4, 8, 32, 17, 17 / 2, 1);
+        run(4, 8, 32, 19, 19 / 2, 1);
+        run(4, 8, 32, 21, 21 / 2, 1);
+        run(4, 8, 32, 23, 23 / 2, 1);
+        run(4, 8, 32, 25, 25 / 2, 1);
+        run(4, 8, 32, 27, 27 / 2, 1);
+        run(4, 8, 32, 29, 29 / 2, 1);
+        run(4, 8, 32, 31, 31 / 2, 1);
+    }
+}
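
A note on the output-size line inside run: `(ih + 2 * padding - fh + 1) / stride` agrees with the general convolution output size only because every case here uses stride 1 (a worked check):

    oh = (ih + 2*p - fh) / stride + 1    // general
       = ih + 2*p - fh + 1               // stride = 1, as exercised above
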
+
+TEST_F(CUDA, REGION_RESTRICTED_CONV_BWD_DATA_FP32_RIN_EQ_ROUT) {
+    Checker<RegionRestrictedConvolutionBackwardData> checker(handle_cuda());
+    for (auto dt : std::vector<DType>{dtype::Int32()}) {
+        auto run = [&checker, &dt](
+                           size_t n, size_t g, size_t ih, size_t fh, size_t padding,
+                           size_t stride) {
+            RegionRestrictedConvolutionBackwardData::Param cur_param;
+            cur_param.mode = RegionRestrictedConvolutionBackwardData::Param::Mode::
+                    CROSS_CORRELATION;
+            cur_param.compute_mode = RegionRestrictedConvolutionBackwardData::Param::
+                    ComputeMode::DEFAULT;
+            cur_param.sparse =
+                    RegionRestrictedConvolutionBackwardData::Param::Sparse::GROUP;
+            checker.set_dtype(2, dt).set_dtype(3, dt);
+            float scale = 64.f / sqrt(fh * fh);
+            UniformFloatRNG rng(scale, 2 * scale);
+            // value 0 mask may cause unexpected behaviour.
+            UniformIntRNG r_rng{1, 1};
+            checker.set_rng(0, &rng).set_rng(1, &rng).set_rng(2, &r_rng).set_rng(
+                    3, &r_rng);
+            cur_param.pad_h = cur_param.pad_w = padding;
+            cur_param.stride_h = cur_param.stride_w = stride;
+            size_t oh = (ih + 2 * padding - fh + 1) / stride;
+            checker.set_param(cur_param).execs(
+                    {/*filter*/ {g, 1, 1, fh, fh},
+                     /*diff*/ {n, g * 1, oh, oh},
+                     /*rin*/ {n, ih, ih},
+                     /*rout*/ {n, oh, oh},
+                     /*grad*/ {n, g * 1, ih, ih}});
+        };
+        if (dt == dtype::Int32()) {
+            // NOTE: Uint8 requires the spatial size of src&dst to be 4*N
+            run(4, 8, 32, 5, 5 / 2, 1);
+            run(1, 2, 2, 2, 0, 1);
+            run(1, 2, 3, 3, 0, 1);
+            run(1, 2, 4, 4, 0, 1);
+            run(1, 2, 5, 5, 0, 1);
+            run(1, 2, 6, 6, 0, 1);
+            run(1, 2, 7, 7, 0, 1);
+        }
+        run(4, 8, 32, 7, 7 / 2, 1);
+        run(4, 8, 32, 9, 9 / 2, 1);
+        run(4, 8, 32, 11, 11 / 2, 1);
+        run(4, 8, 32, 13, 13 / 2, 1);
+        run(4, 8, 32, 15, 15 / 2, 1);
+        run(4, 8, 32, 17, 17 / 2, 1);
+        run(4, 8, 32, 19, 19 / 2, 1);
+        run(4, 8, 32, 21, 21 / 2, 1);
+        run(4, 8, 32, 23, 23 / 2, 1);
+        run(4, 8, 32, 25, 25 / 2, 1);
+        run(4, 8, 32, 27, 27 / 2, 1);
+        run(4, 8, 32, 29, 29 / 2, 1);
+        run(4, 8, 32, 31, 31 / 2, 1);
+    }
+}
 
 } // namespace test
 } // namespace megdnn
@@ -131,4 +131,110 @@ TEST_F(NAIVE, REGIONRESTRICTEDCONVOLUTION_FORWARD) {
                     {}});
 }
+
+TEST_F(NAIVE, REGIONRESTRICTEDCONVOLUTION_FORWARD_DENSE_BRUTE) {
+    Checker<RegionRestrictedConvolutionForward> checker(handle());
+    RegionRestrictedConvolutionForward::Param param;
+    checker.set_param(param).exect(
+            Testcase{
+                    TensorValue(  // src
+                            {1, 1, 4, 4}, dtype::Float32(),
+                            {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}),
+                    TensorValue(  // filter
+                            {1, 1, 2, 2}, dtype::Float32(), {1, 1, 1, 1}),
+                    TensorValue(  // rin
+                            {1, 4, 4}, dtype::Int32(),
+                            {1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1}),
+                    TensorValue(  // rout
+                            {1, 3, 3}, dtype::Int32(), {0, 1, 1, 1, 0, 0, 1, 0, 1}),
+                    {},  // output
+            },
+            Testcase{
+                    {},
+                    {},
+                    {},
+                    {},
+                    TensorValue(
+                            {1, 1, 3, 3}, dtype::Float32(),
+                            {4, 14, 18, 5, 9, 0, 13, 9, 50})});
+}
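
The expected output can be verified by hand against the gating rule sketched earlier (src holds 0..15 row-major, filter is all ones); three spot checks:

    out(0,0) = 4                  // rout id 0; only src 4 has rin 0 in its window
    out(0,1) = 1+2+5+6 = 14       // rout id 1; the whole window has rin 1
    out(2,2) = 10+11+14+15 = 50   // rout id 1; the whole window has rin 1
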
+
+TEST_F(NAIVE, REGIONRESTRICTEDCONVOLUTION_BWD_DATA_DENSE_BRUTE) {
+    Checker<RegionRestrictedConvolutionBackwardData> checker(handle());
+    RegionRestrictedConvolutionBackwardData::Param param;
+    checker.set_param(param).exect(
+            Testcase{
+                    // filter
+                    TensorValue(
+                            {1, 1, 2, 2},      // shape
+                            dtype::Float32(),  // dtype
+                            {1.f, 1.f, 1.f, 1.f}),
+                    // diff
+                    TensorValue(
+                            {1, 1, 3, 3}, dtype::Float32(),
+                            {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f}),
+                    // rin
+                    TensorValue(
+                            {1, 4, 4}, dtype::Int32(),
+                            {1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1}),
+                    // rout
+                    TensorValue({1, 3, 3}, dtype::Int32(), {0, 1, 1, 1, 0, 0, 1, 0, 1}),
+                    // grad
+                    {}},
+            Testcase{// filter
+                     {},
+                     // diff
+                     {},
+                     // rin
+                     {},
+                     // rout
+                     {},
+                     // grad
+                     TensorValue(
+                             {1, 1, 4, 4}, dtype::Float32(),
+                             {0., 2., 5., 3., 1., 6., 5., 3., 0., 13., 9., 9., 0., 7.,
+                              9., 9.})});
+}
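
Backward data applies the same gate, transposed: grad(ih, iw) sums diff over every output whose 2x2 window covers (ih, iw) and whose rout id equals rin(ih, iw). Spot checks worked from the values above:

    grad(0,0) = 0         // rin 1, but the only covering output has rout 0
    grad(0,1) = 2         // covered by rout ids 0 and 1; only diff 2 matches
    grad(1,1) = 2+4 = 6   // outputs (0,1) and (1,0) carry region 1
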
+
+TEST_F(NAIVE, REGIONRESTRICTEDCONVOLUTION_BWD_DATA_GROUP_BRUTE) {
+    Checker<RegionRestrictedConvolutionBackwardData> checker(handle());
+
+    // params
+    RegionRestrictedConvolutionBackwardData::Param param;
+    param.sparse = RegionRestrictedConvolutionBackwardData::Param::Sparse::GROUP;
+    param.mode = RegionRestrictedConvolutionBackwardData::Mode::CROSS_CORRELATION;
+    param.compute_mode =
+            RegionRestrictedConvolutionBackwardData::Param::ComputeMode::DEFAULT;
+    param.pad_h = param.pad_w =
+            0;  // forward param; deconv padding does not matter for naive backward data
+    param.stride_h = param.stride_w = 1;
+
+    // checker setting
+    checker.set_param(param).exect(
+            Testcase{// filter
+                     TensorValue(
+                             {2, 1, 1, 2, 2},   // shape
+                             dtype::Float32(),  // dtype
+                             {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}),
+                     // diff
+                     TensorValue({1, 2, 1, 1}, dtype::Float32(), {1, 2}),
+                     // rin
+                     TensorValue({1, 2, 2}, dtype::Int32(), {1, 1, 1, 1}),
+                     // rout
+                     TensorValue({1, 1, 1}, dtype::Int32(), {1}),
+                     // grad
+                     {}},
+            Testcase{// filter
+                     {},
+                     // diff
+                     {},
+                     // rin
+                     {},
+                     // rout
+                     {},
+                     // grad
+                     TensorValue(
+                             {1, 2, 2, 2}, dtype::Float32(),
+                             {1, 2, 3, 4, 10, 12, 14, 16})});
+}
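
In this group case every mask entry is 1 and the single diff element covers each 2x2 grad window exactly once, so each group's grad is just its filter scaled by its diff value (worked from the tensors above):

    group 0: {1, 2, 3, 4} * 1 = {1, 2, 3, 4}
    group 1: {5, 6, 7, 8} * 2 = {10, 12, 14, 16}
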
 
 // vim: syntax=cpp.doxygen