GitOrigin-RevId: 87a7c9c575
tags/v0.5.0
@@ -35,7 +35,8 @@ pdef('Axis').add_fields('int32', 'axis', 0) | |||||
). | ). | ||||
add_enum(Doc('Format', 'convolution data/filter/output format; see ' | add_enum(Doc('Format', 'convolution data/filter/output format; see ' | ||||
':class:`RelayoutFormat` for more details'), | ':class:`RelayoutFormat` for more details'), | ||||
'NCHW', 'NHWC', 'NHWCD4', 'NCHW4', 'NCHW8', 'NCHW32', 'NCHW88', 'NCHW44', | |||||
'NCHW', 'NHWC', 'NHWCD4', 'NCHW4', 'NCHW8', 'NCHW32', 'NCHW88', | |||||
'NCHW44','NCHW44_DOT', | |||||
Doc('NCHW_WINOGRAD', 'NCHW layout with weights tranformed by winograd'), | Doc('NCHW_WINOGRAD', 'NCHW layout with weights tranformed by winograd'), | ||||
Doc('NCHW88_WINOGRAD', 'NCHW88 layout with weights tranformed by winograd'), | Doc('NCHW88_WINOGRAD', 'NCHW88 layout with weights tranformed by winograd'), | ||||
Doc('NCHW44_WINOGRAD', 'NCHW44 layout with weights tranformed by winograd'), | Doc('NCHW44_WINOGRAD', 'NCHW44 layout with weights tranformed by winograd'), | ||||
@@ -104,6 +104,7 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec( | |||||
bias.to_string().c_str(), dst.to_string().c_str()); | bias.to_string().c_str(), dst.to_string().c_str()); | ||||
} else if (param().format == param::ConvBias::Format::NCHW4 || | } else if (param().format == param::ConvBias::Format::NCHW4 || | ||||
param().format == param::ConvBias::Format::NCHW44 || | param().format == param::ConvBias::Format::NCHW44 || | ||||
param().format == param::ConvBias::Format::NCHW44_DOT || | |||||
param().format == param::ConvBias::Format::NCHW44_WINOGRAD) { | param().format == param::ConvBias::Format::NCHW44_WINOGRAD) { | ||||
megdnn_assert(bias.shape[0] == 1); | megdnn_assert(bias.shape[0] == 1); | ||||
megdnn_assert(bias.shape[1] == dst.shape[1], "bias:%s, dst:%s", | megdnn_assert(bias.shape[1] == dst.shape[1], "bias:%s, dst:%s", | ||||
@@ -280,6 +280,13 @@ void make_canonized_filter_meta_nchwxx( | |||||
/** | /** | ||||
* input: N IC/pack_size, H, W, pack_size | * input: N IC/pack_size, H, W, pack_size | ||||
* | * | ||||
** NCHW44-DOT mode | |||||
* filter: | |||||
* {OC/pack_size, IC/pack_size, FH, FW, pack_size(OC), pack_size(IC)} | |||||
* [dense] | |||||
* {GROUP, OC_PER_GROUP/pack_size, IC_PER_GROUP/pack_size, \ | |||||
* FH, FW, pack_size(OC), pack_size(IC)} [group] | |||||
* | |||||
* NCHW88 and NCHW44 mode | * NCHW88 and NCHW44 mode | ||||
* filter: | * filter: | ||||
* {OC/pack_size, IC/pack_size, FH, FW, pack_size(IC), pack_size(OC)} | * {OC/pack_size, IC/pack_size, FH, FW, pack_size(IC), pack_size(OC)} | ||||
@@ -300,6 +307,7 @@ void make_canonized_filter_meta_nchwxx( | |||||
megdnn_assert(param.format == Param::Format::NCHW88 || | megdnn_assert(param.format == Param::Format::NCHW88 || | ||||
param.format == Param::Format::NCHW44 || | param.format == Param::Format::NCHW44 || | ||||
param.format == Param::Format::NCHW44_WINOGRAD || | param.format == Param::Format::NCHW44_WINOGRAD || | ||||
param.format == Param::Format::NCHW44_DOT || | |||||
param.format == Param::Format::NCHW88_WINOGRAD); | param.format == Param::Format::NCHW88_WINOGRAD); | ||||
size_t img_ndim = 2; | size_t img_ndim = 2; | ||||
size_t flt_start = 0; | size_t flt_start = 0; | ||||
@@ -554,6 +562,7 @@ ConvolutionBase<Parameter>::make_canonized_filter_meta( | |||||
make_canonized_filter_meta_nchwxx<8, Parameter>(src_ndim, filter, | make_canonized_filter_meta_nchwxx<8, Parameter>(src_ndim, filter, | ||||
param(), ret); | param(), ret); | ||||
} else if (param().format == Param::Format::NCHW44 || | } else if (param().format == Param::Format::NCHW44 || | ||||
param().format == Param::Format::NCHW44_DOT || | |||||
param().format == Param::Format::NCHW44_WINOGRAD) { | param().format == Param::Format::NCHW44_WINOGRAD) { | ||||
make_canonized_filter_meta_nchwxx<4, Parameter>(src_ndim, filter, | make_canonized_filter_meta_nchwxx<4, Parameter>(src_ndim, filter, | ||||
param(), ret); | param(), ret); | ||||
@@ -660,6 +669,7 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src, | |||||
megdnn_assert(param().format == Param::Format::NHWCD4 || | megdnn_assert(param().format == Param::Format::NHWCD4 || | ||||
param().format == Param::Format::NCHW4 || | param().format == Param::Format::NCHW4 || | ||||
param().format == Param::Format::NCHW44 || | param().format == Param::Format::NCHW44 || | ||||
param().format == Param::Format::NCHW44_DOT || | |||||
param().format == Param::Format::NCHW8 || | param().format == Param::Format::NCHW8 || | ||||
param().format == Param::Format::NCHW32 || | param().format == Param::Format::NCHW32 || | ||||
param().format == Param::Format::NCHW88 || | param().format == Param::Format::NCHW88 || | ||||
@@ -668,6 +678,7 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src, | |||||
param().format == Param::Format::CHWN4); | param().format == Param::Format::CHWN4); | ||||
img_dim = src.ndim - 3; | img_dim = src.ndim - 3; | ||||
if ((param().format == Param::Format::NCHW88 || | if ((param().format == Param::Format::NCHW88 || | ||||
param().format == Param::Format::NCHW44_DOT || | |||||
param().format == Param::Format::NCHW44) && | param().format == Param::Format::NCHW44) && | ||||
filter.ndim == 5) { | filter.ndim == 5) { | ||||
img_dim = src.ndim - 2; | img_dim = src.ndim - 2; | ||||
@@ -675,6 +686,7 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src, | |||||
megdnn_assert(filter.ndim == img_dim + 3 || | megdnn_assert(filter.ndim == img_dim + 3 || | ||||
(filter.ndim == img_dim + 2 && | (filter.ndim == img_dim + 2 && | ||||
(param().format == Param::Format::NCHW88 || | (param().format == Param::Format::NCHW88 || | ||||
param().format == Param::Format::NCHW44_DOT || | |||||
param().format == Param::Format::NCHW44)) || | param().format == Param::Format::NCHW44)) || | ||||
filter.ndim == img_dim + 4 || | filter.ndim == img_dim + 4 || | ||||
filter.ndim == img_dim + 5, | filter.ndim == img_dim + 5, | ||||
@@ -727,6 +739,7 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src, | |||||
src.to_string().c_str(), filter.to_string().c_str()); | src.to_string().c_str(), filter.to_string().c_str()); | ||||
} | } | ||||
if (param().format == Param::Format::NCHW44 || | if (param().format == Param::Format::NCHW44 || | ||||
param().format == Param::Format::NCHW44_DOT || | |||||
param().format == Param::Format::NCHW44_WINOGRAD) { | param().format == Param::Format::NCHW44_WINOGRAD) { | ||||
megdnn_assert((src.ndim == 4 && filter.ndim == 5 && | megdnn_assert((src.ndim == 4 && filter.ndim == 5 && | ||||
filter[filter.ndim - 1] == 4) || | filter[filter.ndim - 1] == 4) || | ||||
@@ -859,8 +872,9 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src, | |||||
} | } | ||||
} else if (param().format == Param::Format::NCHW44 || | } else if (param().format == Param::Format::NCHW44 || | ||||
param().format == Param::Format::NCHW44_DOT || | |||||
param().format == Param::Format::NCHW44_WINOGRAD) { | param().format == Param::Format::NCHW44_WINOGRAD) { | ||||
megdnn_assert(src.ndim == 5 || (src.ndim == 4 && src[1] <= 8), | |||||
megdnn_assert(src.ndim == 5 || (src.ndim == 4 && src[1] <= 4), | |||||
"invalid src ndim for NCHW44, expected=5 or 4, got=%zu", | "invalid src ndim for NCHW44, expected=5 or 4, got=%zu", | ||||
src.ndim); | src.ndim); | ||||
dst.ndim = 5; | dst.ndim = 5; | ||||
@@ -29,6 +29,7 @@ using namespace fallback; | |||||
size_t megdnn::fallback::get_format_pack_size(param::ConvBias::Format format) { | size_t megdnn::fallback::get_format_pack_size(param::ConvBias::Format format) { | ||||
switch (format) { | switch (format) { | ||||
case param::ConvBias::Format::NCHW44: | case param::ConvBias::Format::NCHW44: | ||||
case param::ConvBias::Format::NCHW44_DOT: | |||||
case param::ConvBias::Format::NCHW4: | case param::ConvBias::Format::NCHW4: | ||||
return 4_z; | return 4_z; | ||||
case param::ConvBias::Format::NCHW88: | case param::ConvBias::Format::NCHW88: | ||||
@@ -188,6 +189,7 @@ ConvBiasImpl::NCBKernSizeParam ConvBiasImpl::make_ncb_kern_size_param( | |||||
param().format == Param::Format::NCHW8 || | param().format == Param::Format::NCHW8 || | ||||
param().format == Param::Format::NCHW4 || | param().format == Param::Format::NCHW4 || | ||||
param().format == Param::Format::NCHW44 || | param().format == Param::Format::NCHW44 || | ||||
param().format == Param::Format::NCHW44_DOT || | |||||
param().format == Param::Format::NCHW || | param().format == Param::Format::NCHW || | ||||
param().format == Param::Format::NCHW_WINOGRAD || | param().format == Param::Format::NCHW_WINOGRAD || | ||||
param().format == Param::Format::NCHW88_WINOGRAD || | param().format == Param::Format::NCHW88_WINOGRAD || | ||||
@@ -405,6 +407,7 @@ const T* ConvBiasImpl::NCBKernParam::filter(size_t group_pack_id, | |||||
break; | break; | ||||
} | } | ||||
case Param::Format::NCHW44_DOT: | |||||
case Param::Format::NCHW44: { | case Param::Format::NCHW44: { | ||||
size_t group = filter_meta.group; | size_t group = filter_meta.group; | ||||
size_t icpg = filter_meta.icpg; | size_t icpg = filter_meta.icpg; | ||||
@@ -147,6 +147,7 @@ ConvolutionImpl::NCBKernSizeParam ConvolutionImpl::make_ncb_kern_size_param( | |||||
if (param().format == Param::Format::NCHW88 || | if (param().format == Param::Format::NCHW88 || | ||||
param().format == Param::Format::NCHW8 || | param().format == Param::Format::NCHW8 || | ||||
param().format == Param::Format::NCHW4 || | param().format == Param::Format::NCHW4 || | ||||
param().format == Param::Format::NCHW44_DOT || | |||||
param().format == Param::Format::NCHW44) { | param().format == Param::Format::NCHW44) { | ||||
spatial_pos = 2; | spatial_pos = 2; | ||||
} else if (param().format == Param::Format::NCHW || | } else if (param().format == Param::Format::NCHW || | ||||
@@ -147,6 +147,7 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr, | |||||
if (filter_meta.format == Format::NCHW || | if (filter_meta.format == Format::NCHW || | ||||
filter_meta.format == Format::NCHW88 || | filter_meta.format == Format::NCHW88 || | ||||
filter_meta.format == Format::NCHW44 || | filter_meta.format == Format::NCHW44 || | ||||
filter_meta.format == Format::NCHW44_DOT || | |||||
filter_meta.format == Format::NCHW4 || | filter_meta.format == Format::NCHW4 || | ||||
filter_meta.format == Format::NCHW8 || | filter_meta.format == Format::NCHW8 || | ||||
filter_meta.format == Format::NCHW32) { | filter_meta.format == Format::NCHW32) { | ||||
@@ -174,6 +175,7 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr, | |||||
if (filter_meta.format == Format::NCHW4 || | if (filter_meta.format == Format::NCHW4 || | ||||
filter_meta.format == Format::CHWN4 || | filter_meta.format == Format::CHWN4 || | ||||
filter_meta.format == Format::NCHW44_DOT || | |||||
filter_meta.format == Format::NCHW44) { | filter_meta.format == Format::NCHW44) { | ||||
OC *= 4; | OC *= 4; | ||||
} else if (filter_meta.format == Format::NCHW8 || | } else if (filter_meta.format == Format::NCHW8 || | ||||
@@ -219,7 +221,8 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr, | |||||
FS_G = FS_OC * filter_meta.ocpg / 8; | FS_G = FS_OC * filter_meta.ocpg / 8; | ||||
} | } | ||||
} | } | ||||
} else if (filter_meta.format == Format::NCHW44) { | |||||
} else if (filter_meta.format == Format::NCHW44 || | |||||
filter_meta.format == Format::NCHW44_DOT) { | |||||
if (filter_meta.group > 1 && filter_meta.icpg == 1 && | if (filter_meta.group > 1 && filter_meta.icpg == 1 && | ||||
src.layout.ndim == 5 && filter_meta.ocpg == 1) { | src.layout.ndim == 5 && filter_meta.ocpg == 1) { | ||||
FS_SPATIAL = 4; | FS_SPATIAL = 4; | ||||
@@ -282,7 +285,8 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr, | |||||
h * layout.stride[2] + w * layout.stride[3] + | h * layout.stride[2] + w * layout.stride[3] + | ||||
(c & 0b111) * layout.stride[4]; | (c & 0b111) * layout.stride[4]; | ||||
} | } | ||||
} else if (filter_meta.format == Format::NCHW44) { | |||||
} else if (filter_meta.format == Format::NCHW44 || | |||||
filter_meta.format == Format::NCHW44_DOT) { | |||||
if (filter_meta.format == Format::NCHW44 && !is_output && | if (filter_meta.format == Format::NCHW44 && !is_output && | ||||
src.layout.ndim == 4) { | src.layout.ndim == 4) { | ||||
return n * layout.stride[0] + c * layout.stride[1] + | return n * layout.stride[0] + c * layout.stride[1] + | ||||
@@ -327,30 +331,41 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr, | |||||
return gc_out.cur_grp * FS_G + gc_out.cur_off * FS_OC + | return gc_out.cur_grp * FS_G + gc_out.cur_off * FS_OC + | ||||
(ic - ic0) / 4 * FS_IC + (fh * FW + fw) * FS_SPATIAL + | (ic - ic0) / 4 * FS_IC + (fh * FW + fw) * FS_SPATIAL + | ||||
((ic - ic0) % 4); | ((ic - ic0) % 4); | ||||
} else if (filter_meta.format == Format::NCHW88) { | |||||
} else if (filter_meta.format == Format::NCHW88 || | |||||
filter_meta.format == Format::NCHW44) { | |||||
size_t pack_c_size = 4_z; | |||||
if(filter_meta.format == Format::NCHW88){ | |||||
pack_c_size = 8_z; | |||||
} | |||||
if (src.layout.ndim == 4) { | if (src.layout.ndim == 4) { | ||||
// ic < 8, input is nchw | // ic < 8, input is nchw | ||||
return gc_out.cur_grp * FS_G + gc_out.cur_off / 8 * FS_OC + | |||||
return gc_out.cur_grp * FS_G + | |||||
gc_out.cur_off / pack_c_size * FS_OC + | |||||
(fh * FW + fw) * FS_SPATIAL + (ic - ic0) * FS_IC + | (fh * FW + fw) * FS_SPATIAL + (ic - ic0) * FS_IC + | ||||
gc_out.cur_off % 8; | |||||
gc_out.cur_off % pack_c_size; | |||||
} else if (filter_meta.group > 1 && filter_meta.icpg == 1 && | } else if (filter_meta.group > 1 && filter_meta.icpg == 1 && | ||||
filter_meta.ocpg == 1 && src.layout.ndim == 5) { | filter_meta.ocpg == 1 && src.layout.ndim == 5) { | ||||
// dw case | // dw case | ||||
return gc_out.cur_grp / 8 * FS_G + gc_out.cur_off * FS_OC + | |||||
(ic - ic0) * FS_IC + (fh * FW + fw) * FS_SPATIAL + | |||||
gc_out.cur_grp % 8; | |||||
return gc_out.cur_grp / pack_c_size * FS_G + | |||||
gc_out.cur_off * FS_OC + (ic - ic0) * FS_IC + | |||||
(fh * FW + fw) * FS_SPATIAL + | |||||
gc_out.cur_grp % pack_c_size; | |||||
} else if (src.layout.ndim == 5) { | } else if (src.layout.ndim == 5) { | ||||
// normal case | // normal case | ||||
return gc_out.cur_grp * FS_G + gc_out.cur_off / 8 * FS_OC + | |||||
(ic - ic0) / 8 * FS_IC + (fh * FW + fw) * FS_SPATIAL + | |||||
((ic - ic0) & 0b111) * 8 + gc_out.cur_off % 8; | |||||
return gc_out.cur_grp * FS_G + | |||||
gc_out.cur_off / pack_c_size * FS_OC + | |||||
(ic - ic0) / pack_c_size * FS_IC + | |||||
(fh * FW + fw) * FS_SPATIAL + | |||||
((ic - ic0) % pack_c_size) * pack_c_size + | |||||
gc_out.cur_off % pack_c_size; | |||||
} else { | } else { | ||||
megdnn_assert( | |||||
0, "nchw88 naive not support this input and output\n"); | |||||
megdnn_throw( | |||||
"nchw88/nchw44 naive not support this input and " | |||||
"output\n"); | |||||
} | } | ||||
} else if (filter_meta.format == Format::NCHW44) { | |||||
} else if (filter_meta.format == Format::NCHW44_DOT) { | |||||
if (src.layout.ndim == 4) { | if (src.layout.ndim == 4) { | ||||
// ic < 8, input is nchw | |||||
// ic < 4, input is nchw | |||||
return gc_out.cur_grp * FS_G + gc_out.cur_off / 4 * FS_OC + | return gc_out.cur_grp * FS_G + gc_out.cur_off / 4 * FS_OC + | ||||
(fh * FW + fw) * FS_SPATIAL + (ic - ic0) * FS_IC + | (fh * FW + fw) * FS_SPATIAL + (ic - ic0) * FS_IC + | ||||
gc_out.cur_off % 4; | gc_out.cur_off % 4; | ||||
@@ -364,10 +379,10 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr, | |||||
// normal case | // normal case | ||||
return gc_out.cur_grp * FS_G + gc_out.cur_off / 4 * FS_OC + | return gc_out.cur_grp * FS_G + gc_out.cur_off / 4 * FS_OC + | ||||
(ic - ic0) / 4 * FS_IC + (fh * FW + fw) * FS_SPATIAL + | (ic - ic0) / 4 * FS_IC + (fh * FW + fw) * FS_SPATIAL + | ||||
((ic - ic0) % 4) * 4 + gc_out.cur_off % 4; | |||||
(gc_out.cur_off % 4) * 4 + ((ic - ic0) % 4); | |||||
} else { | } else { | ||||
megdnn_assert( | |||||
0, "nchw44 naive not support this input and output\n"); | |||||
megdnn_throw( | |||||
"nchw44_dot naive not support this input and output\n"); | |||||
} | } | ||||
} else { | } else { | ||||
return gc_out.cur_grp * FS_G + gc_out.cur_off * FS_OC + | return gc_out.cur_grp * FS_G + gc_out.cur_off * FS_OC + | ||||
@@ -559,6 +574,7 @@ void forward(_megdnn_tensor_in src, const ftype* fptr, _megdnn_tensor_out dst, | |||||
filter_meta.format == param::Convolution::Format::NHWC || | filter_meta.format == param::Convolution::Format::NHWC || | ||||
filter_meta.format == param::Convolution::Format::NCHW88 || | filter_meta.format == param::Convolution::Format::NCHW88 || | ||||
filter_meta.format == param::Convolution::Format::NCHW44 || | filter_meta.format == param::Convolution::Format::NCHW44 || | ||||
filter_meta.format == param::Convolution::Format::NCHW44_DOT || | |||||
filter_meta.format == param::Convolution::Format::NCHW4); | filter_meta.format == param::Convolution::Format::NCHW4); | ||||
compute2d<stype, ftype, dtype, comp_type, StrategyFwd>( | compute2d<stype, ftype, dtype, comp_type, StrategyFwd>( | ||||
src, const_cast<ftype*>(fptr), dst, filter_meta); | src, const_cast<ftype*>(fptr), dst, filter_meta); | ||||
@@ -613,6 +629,7 @@ void forward_bias(_megdnn_tensor_in src, _megdnn_tensor_in filter, | |||||
case param::Convolution::Format::NCHW: | case param::Convolution::Format::NCHW: | ||||
case param::Convolution::Format::NCHW88: | case param::Convolution::Format::NCHW88: | ||||
case param::Convolution::Format::NCHW44: | case param::Convolution::Format::NCHW44: | ||||
case param::Convolution::Format::NCHW44_DOT: | |||||
case param::Convolution::Format::NHWC: | case param::Convolution::Format::NHWC: | ||||
case param::Convolution::Format::NCHW4: | case param::Convolution::Format::NCHW4: | ||||
case param::Convolution::Format::NCHW8: | case param::Convolution::Format::NCHW8: | ||||
@@ -690,6 +707,7 @@ void forward_bias(_megdnn_tensor_in src, _megdnn_tensor_in filter, | |||||
} \ | } \ | ||||
} while (0) | } while (0) | ||||
case Format::NCHW44: | case Format::NCHW44: | ||||
case Format::NCHW44_DOT: | |||||
case Format::NCHW4: { | case Format::NCHW4: { | ||||
BIAS_ADD_NCHWx(4); | BIAS_ADD_NCHWx(4); | ||||
break; | break; | ||||
@@ -350,9 +350,14 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_SMALL_GROUP) { | |||||
get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false), | get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false), | ||||
handle(), "F32DIRECT_SMALL_GROUP"); | handle(), "F32DIRECT_SMALL_GROUP"); | ||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1) { | |||||
check_conv_bias(get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, | |||||
false, false, true, true), | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_1) { | |||||
check_conv_bias(get_nchw44_conv_bias_args({2, 7}, 1, false, false, false, | |||||
false, true, true), | |||||
handle(), "F32_CONV_NCHW44_DIRECT"); | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_2) { | |||||
check_conv_bias(get_nchw44_conv_bias_args({3, 5}, 1, false, false, false, | |||||
false, true, true), | |||||
handle(), "F32_CONV_NCHW44_DIRECT"); | handle(), "F32_CONV_NCHW44_DIRECT"); | ||||
} | } | ||||
@@ -516,4 +516,177 @@ TEST_F(NAIVE, CONV_BIAS_NCHW44) { | |||||
224, 268, 311, 218, 288, 311, 346, 277})}); | 224, 268, 311, 218, 288, 311, 346, 277})}); | ||||
} | } | ||||
} | } | ||||
TEST_F(NAIVE, CONV_BIAS_NCHW44_DOT) { | |||||
Checker<ConvBias> checker(handle(), /* check_dispatch */ false); | |||||
ConvBias::Param param; | |||||
param.format = ConvBias::Param::Format::NCHW44_DOT; | |||||
size_t n = 1; | |||||
size_t ic = 4; | |||||
size_t oc = 8; | |||||
size_t h = 2; | |||||
size_t w = 2; | |||||
size_t filter_size = 3; | |||||
size_t pad = 1; | |||||
auto src_tensor_shape = TensorShape{n, ic / 4, h, w, 4}; | |||||
auto weight_tensor_shape = | |||||
TensorShape{oc / 4, ic / 4, filter_size, filter_size, 4, 4}; | |||||
auto bias_tensor_shape = TensorShape{1, oc / 4, 1, 1, 4}; | |||||
param.pad_h = pad; | |||||
param.pad_w = pad; | |||||
UniformIntRNG rng{-127, 127}; | |||||
checker.set_dtype(0, dtype::Float32()) | |||||
.set_dtype(1, dtype::Float32()) | |||||
.set_dtype(2, dtype::Float32()) | |||||
.set_dtype(4, dtype::Float32()) | |||||
.set_rng(0, &rng) | |||||
.set_rng(1, &rng) | |||||
.set_rng(2, &rng) | |||||
.set_epsilon(1e-3) | |||||
.set_param(param) | |||||
.execs({src_tensor_shape, | |||||
weight_tensor_shape, | |||||
bias_tensor_shape, | |||||
{}, | |||||
{}}); | |||||
checker.set_dtype(0, dtype::QuantizedS8(2.f)) | |||||
.set_dtype(1, dtype::QuantizedS8(3.f)) | |||||
.set_dtype(2, dtype::QuantizedS32(6.f)) | |||||
.set_dtype(4, dtype::QuantizedS32(6.f)) | |||||
.set_rng(0, &rng) | |||||
.set_rng(1, &rng) | |||||
.set_rng(2, &rng) | |||||
.set_epsilon(1e-3) | |||||
.set_param(param) | |||||
.execs({src_tensor_shape, | |||||
weight_tensor_shape, | |||||
bias_tensor_shape, | |||||
{}, | |||||
{}}); | |||||
{ | |||||
// test normal conv | |||||
ConvBias::Param param; | |||||
param.format = ConvBias::Param::Format::NCHW44_DOT; | |||||
param.sparse = ConvBias::Param::Sparse::DENSE; | |||||
param.pad_h = 1; | |||||
param.pad_w = 1; | |||||
checker.set_param(param).exect( | |||||
Testcase{TensorValue({1, 1, 2, 2, 4}, dtype::Float32(), | |||||
{7, 2, 2, 1, 7, 5, 6, 3, 1, 2, 8, 3, 7, 7, | |||||
6, 4}), | |||||
TensorValue( | |||||
{1, 1, 3, 3, 4, 4}, dtype::Float32(), | |||||
{3, 0, 3, 1, 5, 1, 5, 7, 5, 4, 0, 0, 2, 8, 7, | |||||
7, 6, 5, 7, 3, 4, 2, 6, 2, 7, 2, 6, 2, 7, 4, | |||||
3, 8, 5, 0, 0, 7, 0, 5, 4, 7, 4, 1, 8, 2, 4, | |||||
0, 4, 0, 4, 6, 0, 1, 8, 2, 6, 4, 7, 3, 4, 3, | |||||
3, 0, 4, 8, 8, 2, 3, 7, 8, 5, 2, 0, 7, 5, 8, | |||||
2, 2, 1, 1, 7, 1, 0, 2, 4, 6, 6, 4, 2, 1, 3, | |||||
1, 7, 5, 0, 1, 5, 7, 5, 3, 0, 8, 7, 2, 1, 4, | |||||
0, 8, 4, 5, 3, 6, 6, 6, 2, 1, 5, 6, 4, 7, 2, | |||||
0, 4, 8, 8, 1, 1, 2, 3, 8, 6, 3, 1, 3, 3, 7, | |||||
1, 5, 4, 2, 1, 0, 3, 8, 4}), | |||||
TensorValue({1, 1, 1, 1, 4}, dtype::Float32(), | |||||
{7, 2, 8, 1}), | |||||
TensorValue({1, 1, 2, 2, 4}, dtype::Float32(), | |||||
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||||
0, 0}), | |||||
{}}, | |||||
Testcase{ | |||||
{}, | |||||
{}, | |||||
{}, | |||||
{}, | |||||
TensorValue({1, 1, 2, 2, 4}, dtype::Float32(), | |||||
{264, 338, 309, 195, 276, 332, 390, 199, | |||||
224, 268, 311, 218, 288, 311, 346, 277})}); | |||||
} | |||||
{ | |||||
// test dw conv | |||||
ConvBias::Param param; | |||||
param.format = ConvBias::Param::Format::NCHW44_DOT; | |||||
param.sparse = ConvBias::Param::Sparse::GROUP; | |||||
param.pad_h = 1; | |||||
param.pad_w = 1; | |||||
checker.set_param(param).exect( | |||||
Testcase{TensorValue({1, 1, 2, 2, 4}, dtype::Float32(), | |||||
{5, 8, 3, 2, 4, 6, 1, 5, 0, 8, 2, 6, 8, 6, | |||||
5, 7}), | |||||
TensorValue({1, 1, 1, 3, 3, 4}, dtype::Float32(), | |||||
{3, 0, 3, 1, 6, 5, 7, 3, 5, 0, 0, 7, | |||||
4, 6, 0, 1, 8, 2, 3, 7, 1, 0, 2, 4, | |||||
7, 5, 3, 0, 6, 2, 1, 5, 8, 6, 3, 1}), | |||||
TensorValue({1, 1, 1, 1, 4}, dtype::Float32(), | |||||
{4, 3, 5, 6}), | |||||
TensorValue({1, 1, 2, 2, 4}, dtype::Float32(), | |||||
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||||
0, 0}), | |||||
{}}, | |||||
Testcase{{}, | |||||
{}, | |||||
{}, | |||||
{}, | |||||
TensorValue({1, 1, 2, 2, 4}, dtype::Float32(), | |||||
{112, 71, 33, 77, 104, 115, 19, 78, 62, 59, | |||||
42, 117, 107, 93, 36, 78})}); | |||||
} | |||||
{ | |||||
// test group conv | |||||
ConvBias::Param param; | |||||
param.format = ConvBias::Param::Format::NCHW44_DOT; | |||||
param.sparse = ConvBias::Param::Sparse::GROUP; | |||||
param.pad_h = 1; | |||||
param.pad_w = 1; | |||||
checker.set_param(param).exect( | |||||
Testcase{TensorValue({1, 2, 2, 2, 4}, dtype::Float32(), | |||||
{6, 3, 2, 7, 7, 6, 4, 5, 8, 6, 3, | |||||
1, 1, 2, 8, 3, 1, 0, 6, 1, 3, 3, | |||||
6, 0, 0, 5, 6, 7, 2, 2, 4, 4}), | |||||
TensorValue( | |||||
{2, 1, 1, 3, 3, 4, 4}, dtype::Float32(), | |||||
{3, 0, 3, 1, 5, 1, 5, 7, 5, 4, 0, 0, 2, 8, 7, | |||||
7, 6, 5, 7, 3, 4, 2, 6, 2, 7, 2, 6, 2, 7, 4, | |||||
3, 8, 5, 0, 0, 7, 0, 5, 4, 7, 4, 1, 8, 2, 4, | |||||
0, 4, 0, 4, 6, 0, 1, 8, 2, 6, 4, 7, 3, 4, 3, | |||||
3, 0, 4, 8, 8, 2, 3, 7, 8, 5, 2, 0, 7, 5, 8, | |||||
2, 2, 1, 1, 7, 1, 0, 2, 4, 6, 6, 4, 2, 1, 3, | |||||
1, 7, 5, 0, 1, 5, 7, 5, 3, 0, 8, 7, 2, 1, 4, | |||||
0, 8, 4, 5, 3, 6, 6, 6, 2, 1, 5, 6, 4, 7, 2, | |||||
0, 4, 8, 8, 1, 1, 2, 3, 8, 6, 3, 1, 3, 3, 7, | |||||
1, 5, 4, 2, 1, 0, 3, 8, 4, 7, 6, 8, 3, 4, 8, | |||||
1, 0, 5, 7, 3, 0, 0, 4, 5, 3, 7, 8, 1, 3, 7, | |||||
1, 1, 0, 7, 2, 2, 0, 3, 0, 1, 1, 1, 6, 4, 0, | |||||
3, 3, 1, 2, 0, 0, 4, 1, 5, 5, 7, 6, 7, 1, 3, | |||||
5, 8, 6, 2, 1, 0, 7, 7, 1, 2, 6, 6, 1, 2, 3, | |||||
1, 2, 4, 8, 3, 2, 6, 0, 7, 4, 3, 7, 3, 3, 5, | |||||
3, 0, 3, 5, 1, 4, 5, 6, 2, 0, 5, 3, 3, 3, 5, | |||||
2, 4, 7, 1, 3, 5, 2, 8, 1, 8, 1, 2, 5, 1, 0, | |||||
6, 7, 7, 8, 7, 8, 8, 1, 8, 4, 4, 1, 4, 4, 5, | |||||
0, 2, 2, 2, 0, 1, 8, 4, 4, 7, 6, 8, 0, 1, 5, | |||||
4, 2, 6}), | |||||
TensorValue({1, 2, 1, 1, 4}, dtype::Float32(), | |||||
{1, 8, 5, 6, 2, 8, 7, 7}), | |||||
TensorValue({1, 2, 2, 2, 4}, dtype::Float32(), | |||||
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0}), | |||||
{}}, | |||||
Testcase{ | |||||
{}, | |||||
{}, | |||||
{}, | |||||
{}, | |||||
TensorValue({1, 2, 2, 2, 4}, dtype::Float32(), | |||||
{260, 342, 244, 241, 293, 385, 362, 257, | |||||
278, 301, 303, 226, 273, 306, 318, 307, | |||||
180, 244, 169, 156, 210, 244, 206, 167, | |||||
126, 165, 156, 207, 191, 141, 209, 172})}); | |||||
} | |||||
} | |||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |