GitOrigin-RevId: d195f44dec
tags/v1.0.0-rc1
@@ -22,26 +22,19 @@ using namespace aarch64; | |||
/* ===================== stride-2 algo ===================== */ | |||
MIDOUT_DECL(megdnn_aarch64_conv_bias_stride2_conv2357_fp16) | |||
bool ConvBiasImpl::AlgoF16DirectStride2::usable( | |||
const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const { | |||
bool ConvBiasImpl::AlgoF16DirectStride2::usable(const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy) const { | |||
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 0) { | |||
auto&& fm = param.filter_meta; | |||
auto FH = fm.spatial[0]; | |||
bool aviliable = | |||
param.filter_meta.format == param::Convolution::Format::NCHW && | |||
param.src_type.enumv() == DTypeEnum::Float16 && | |||
param.filter_type.enumv() == DTypeEnum::Float16 && | |||
param.dst_type.enumv() == DTypeEnum::Float16 && | |||
!fm.should_flip && fm.spatial_ndim == 2 && | |||
fm.dilation[0] == 1 && fm.dilation[1] == 1 && | |||
fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] && | |||
(FH == 2 || FH == 3 || FH == 5 || FH == 7); | |||
if (algo_selection_strategy == AlgoSelectionStrategy::HEURISTIC) { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
aviliable &= (large_group == m_large_group); | |||
} | |||
return aviliable; | |||
return param.filter_meta.format == param::Convolution::Format::NCHW && | |||
param.src_type.enumv() == DTypeEnum::Float16 && | |||
param.filter_type.enumv() == DTypeEnum::Float16 && | |||
param.dst_type.enumv() == DTypeEnum::Float16 && | |||
!fm.should_flip && fm.spatial_ndim == 2 && fm.dilation[0] == 1 && | |||
fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 && | |||
FH == fm.spatial[1] && | |||
(FH == 2 || FH == 3 || FH == 5 || FH == 7); | |||
} | |||
MIDOUT_END(); | |||
return false; | |||
@@ -50,8 +43,9 @@ bool ConvBiasImpl::AlgoF16DirectStride2::usable( | |||
size_t ConvBiasImpl::AlgoF16DirectStride2::get_workspace( | |||
const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 1) { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
auto wbundle = arm_common::MultithreadDirectConvCommon< | |||
dt_float16, __fp16>::get_bundle_stride(param, m_large_group); | |||
dt_float16, __fp16>::get_bundle_stride(param, large_group); | |||
return wbundle.total_size_in_bytes(); | |||
} | |||
MIDOUT_END(); | |||
@@ -77,6 +71,7 @@ ConvBiasImpl::AlgoF16DirectStride2::get_kimpls( | |||
size_t IC = param.filter_meta.icpg; | |||
size_t OC = param.filter_meta.ocpg; | |||
size_t group = fm.group; | |||
bool large_group = group >= param.nr_threads; | |||
using Func = std::function<void(const __fp16*, const __fp16*, __fp16*, | |||
size_t, size_t, size_t, size_t, size_t)>; | |||
Func conv = nullptr; | |||
@@ -91,11 +86,11 @@ ConvBiasImpl::AlgoF16DirectStride2::get_kimpls( | |||
} | |||
WorkspaceBundle bundle = arm_common::MultithreadDirectConvCommon< | |||
dt_float16, __fp16>::get_bundle_stride(param, m_large_group); | |||
dt_float16, __fp16>::get_bundle_stride(param, large_group); | |||
SmallVector<NCBKern> ret_kerns; | |||
//! Dense conv and small group | |||
if (m_large_group) { | |||
if (large_group) { | |||
//! Channel wise conv and big groups | |||
auto exec_one_group = [bundle, conv]( | |||
const NCBKernParam& kern_param, | |||
@@ -18,15 +18,9 @@ namespace aarch64 { | |||
/* ===================== stride-2 algo ===================== */ | |||
class ConvBiasImpl::AlgoF16DirectStride2 final : public AlgoBase { | |||
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | |||
bool m_large_group; | |||
public: | |||
AlgoF16DirectStride2(bool large_group) : m_large_group(large_group) {} | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { | |||
return m_large_group ? "ARMV8F16STRD2_LARGE_GROUP" | |||
: "ARMV8F16STRD2_SMALL_GROUP"; | |||
} | |||
const char* name() const override { return "ARMV8F16STRD2"; } | |||
bool usable(const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||
@@ -21,26 +21,19 @@ using namespace megdnn; | |||
using namespace aarch64; | |||
MIDOUT_DECL(megdnn_aarch64_conv_bias_stride2_conv2357_fp32) | |||
bool ConvBiasImpl::AlgoF32DirectStride2::usable( | |||
const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const { | |||
bool ConvBiasImpl::AlgoF32DirectStride2::usable(const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy) const { | |||
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 0) { | |||
auto&& fm = param.filter_meta; | |||
auto FH = fm.spatial[0]; | |||
bool aviliable = | |||
param.filter_meta.format == param::ConvBias::Format::NCHW && | |||
param.src_type.enumv() == DTypeEnum::Float32 && | |||
param.filter_type.enumv() == DTypeEnum::Float32 && | |||
param.dst_type.enumv() == DTypeEnum::Float32 && | |||
!fm.should_flip && fm.spatial_ndim == 2 && | |||
fm.dilation[0] == 1 && fm.dilation[1] == 1 && | |||
fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] && | |||
(FH == 2 || FH == 3 || FH == 5 || FH == 7); | |||
if (algo_selection_strategy == AlgoSelectionStrategy::HEURISTIC) { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
aviliable &= (large_group == m_large_group); | |||
} | |||
return aviliable; | |||
return param.filter_meta.format == param::ConvBias::Format::NCHW && | |||
param.src_type.enumv() == DTypeEnum::Float32 && | |||
param.filter_type.enumv() == DTypeEnum::Float32 && | |||
param.dst_type.enumv() == DTypeEnum::Float32 && | |||
!fm.should_flip && fm.spatial_ndim == 2 && fm.dilation[0] == 1 && | |||
fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 && | |||
FH == fm.spatial[1] && | |||
(FH == 2 || FH == 3 || FH == 5 || FH == 7); | |||
} | |||
MIDOUT_END(); | |||
return false; | |||
@@ -49,8 +42,9 @@ bool ConvBiasImpl::AlgoF32DirectStride2::usable( | |||
size_t ConvBiasImpl::AlgoF32DirectStride2::get_workspace( | |||
const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 1) { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
auto wbundle = arm_common::MultithreadDirectConvCommon< | |||
float, float>::get_bundle_stride(param, m_large_group); | |||
float, float>::get_bundle_stride(param, large_group); | |||
return wbundle.total_size_in_bytes(); | |||
} | |||
MIDOUT_END(); | |||
@@ -75,6 +69,7 @@ ConvBiasImpl::AlgoF32DirectStride2::get_kimpls( | |||
size_t IC = param.filter_meta.icpg; | |||
size_t OC = param.filter_meta.ocpg; | |||
size_t group = fm.group; | |||
bool large_group = group >= param.nr_threads; | |||
using Func = std::function<void(const float*, const float*, float*, size_t, | |||
size_t, size_t, size_t, size_t)>; | |||
Func conv = nullptr; | |||
@@ -89,11 +84,11 @@ ConvBiasImpl::AlgoF32DirectStride2::get_kimpls( | |||
} | |||
WorkspaceBundle bundle = arm_common::MultithreadDirectConvCommon< | |||
float, float>::get_bundle_stride(param, m_large_group); | |||
float, float>::get_bundle_stride(param, large_group); | |||
SmallVector<NCBKern> ret_kerns; | |||
//! Dense conv and small group | |||
if (m_large_group) { | |||
if (large_group) { | |||
//! Channel wise conv and big groups | |||
auto exec_one_group = [bundle, conv]( | |||
const NCBKernParam& kern_param, | |||
@@ -22,15 +22,9 @@ using FallbackConvBiasImpl = fallback::ConvBiasImpl; | |||
class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase { | |||
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | |||
bool m_large_group; | |||
public: | |||
AlgoF32DirectStride2(bool large_group) : m_large_group(large_group) {} | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { | |||
return m_large_group ? "ARMV8F32STRD2_LARGE_GROUP" | |||
: "ARMV8F32STRD2_SMALL_GROUP"; | |||
} | |||
const char* name() const override { return "ARMV8F32STRD2"; } | |||
bool usable(const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||
@@ -25,13 +25,11 @@ using namespace megdnn; | |||
using namespace aarch64; | |||
class ConvBiasImpl::AlgoPack : NonCopyableObj { | |||
AlgoF32DirectStride2 f32_direct_stride2_large_group{true}; | |||
AlgoF32DirectStride2 f32_direct_stride2_small_group{false}; | |||
AlgoF32DirectStride2 f32_direct_stride2; | |||
AlgoS8MatrixMul s8_matrix_mul; | |||
AlgoQU8MatrixMul qu8_matrix_mul; | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
AlgoF16DirectStride2 f16_direct_stride2_large_group{true}; | |||
AlgoF16DirectStride2 f16_direct_stride2_small_group{false}; | |||
AlgoF16DirectStride2 f16_direct_stride2; | |||
#endif | |||
public: | |||
@@ -39,11 +37,9 @@ public: | |||
matmul_algos.emplace_back(&qu8_matrix_mul); | |||
matmul_algos.emplace_back(&s8_matrix_mul); | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
direct_algos.emplace_back(&f16_direct_stride2_large_group); | |||
direct_algos.emplace_back(&f16_direct_stride2_small_group); | |||
direct_algos.emplace_back(&f16_direct_stride2); | |||
#endif | |||
direct_algos.emplace_back(&f32_direct_stride2_large_group); | |||
direct_algos.emplace_back(&f32_direct_stride2_small_group); | |||
direct_algos.emplace_back(&f32_direct_stride2); | |||
} | |||
SmallVector<AlgoBase*> direct_algos; | |||
SmallVector<AlgoBase*> matmul_algos; | |||
@@ -192,9 +192,8 @@ MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL(AlgoFP16WinogradF23_8x8, | |||
MIDOUT_DECL(megdnn_arm_common_conv_bias_fp16_kimpl) | |||
bool ConvBiasImpl::AlgoF16Direct::usable( | |||
const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const { | |||
bool ConvBiasImpl::AlgoF16Direct::usable(const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_kimpl, 0, 0) { | |||
auto&& fm = param.filter_meta; | |||
auto FH = fm.spatial[0]; | |||
@@ -203,20 +202,14 @@ bool ConvBiasImpl::AlgoF16Direct::usable( | |||
// ``param.osz[0]*param.osz[1] >= 8'' comes from the fact that the | |||
// kernel may have access to up to 8 fp16 after the end of the memory | |||
// chunk. | |||
bool aviliable = fm.format == param::ConvBias::Format::NCHW && | |||
param.src_type.enumv() == DTypeEnum::Float16 && | |||
param.filter_type.enumv() == DTypeEnum::Float16 && | |||
param.dst_type.enumv() == DTypeEnum::Float16 && | |||
fm.spatial_ndim == 2 && fm.dilation[0] == 1 && | |||
fm.dilation[1] == 1 && | |||
param.isz[0] * param.isz[1] >= 8 && | |||
param.osz[0] * param.osz[1] >= 8 && FH <= 7 && | |||
SH == 1 && SW == 1; | |||
if (algo_selection_strategy == AlgoSelectionStrategy::HEURISTIC) { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
aviliable &= (large_group == m_large_group); | |||
} | |||
return aviliable; | |||
return fm.format == param::ConvBias::Format::NCHW && | |||
param.src_type.enumv() == DTypeEnum::Float16 && | |||
param.filter_type.enumv() == DTypeEnum::Float16 && | |||
param.dst_type.enumv() == DTypeEnum::Float16 && | |||
fm.spatial_ndim == 2 && fm.dilation[0] == 1 && | |||
fm.dilation[1] == 1 && param.isz[0] * param.isz[1] >= 8 && | |||
param.osz[0] * param.osz[1] >= 8 && FH <= 7 && SH == 1 && | |||
SW == 1; | |||
} | |||
MIDOUT_END(); | |||
return false; | |||
@@ -225,9 +218,10 @@ bool ConvBiasImpl::AlgoF16Direct::usable( | |||
size_t ConvBiasImpl::AlgoF16Direct::get_workspace( | |||
const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_kimpl, 0, 1) { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
auto wbundle = | |||
MultithreadDirectConvCommon<dt_float16, __fp16>::get_bundle( | |||
param, m_large_group); | |||
param, large_group); | |||
return wbundle.total_size_in_bytes(); | |||
} | |||
MIDOUT_END(); | |||
@@ -241,13 +235,14 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF16Direct::get_kimpls( | |||
size_t IC = param.filter_meta.icpg; | |||
size_t OC = param.filter_meta.ocpg; | |||
size_t group = fm.group; | |||
bool large_group = group >= param.nr_threads; | |||
WorkspaceBundle bundle = | |||
MultithreadDirectConvCommon<dt_float16, __fp16>::get_bundle( | |||
param, m_large_group); | |||
param, large_group); | |||
SmallVector<NCBKern> ret_kerns; | |||
//! When group >= nr_threads, treat it as large_group, each thread process | |||
//! one group for better performance | |||
if (m_large_group) { | |||
if (large_group) { | |||
//! Channel wise conv and big groups | |||
auto exec_one_group = [bundle](const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index) mutable { | |||
@@ -316,27 +311,18 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF16Direct::dispatch_kerns( | |||
/* ===================== stride-1 algo ===================== */ | |||
bool ConvBiasImpl::AlgoF16DirectStride1::usable( | |||
const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const { | |||
bool ConvBiasImpl::AlgoF16DirectStride1::usable(const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_kimpl, 1, 0) { | |||
auto&& fm = param.filter_meta; | |||
auto FH = fm.spatial[0]; | |||
bool aviliable = | |||
param.filter_meta.format == param::ConvBias::Format::NCHW && | |||
param.src_type.enumv() == DTypeEnum::Float16 && | |||
param.filter_type.enumv() == DTypeEnum::Float16 && | |||
param.dst_type.enumv() == DTypeEnum::Float16 && | |||
!fm.should_flip && fm.spatial_ndim == 2 && | |||
fm.dilation[0] == 1 && fm.dilation[1] == 1 && | |||
fm.stride[0] == 1 && fm.stride[1] == 1 && FH == fm.spatial[1] && | |||
(FH == 2 || FH == 3 || FH == 5); | |||
if (algo_selection_strategy == | |||
ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
aviliable &= (large_group == m_large_group); | |||
} | |||
return aviliable; | |||
return param.filter_meta.format == param::ConvBias::Format::NCHW && | |||
param.src_type.enumv() == DTypeEnum::Float16 && | |||
param.filter_type.enumv() == DTypeEnum::Float16 && | |||
param.dst_type.enumv() == DTypeEnum::Float16 && | |||
!fm.should_flip && fm.spatial_ndim == 2 && fm.dilation[0] == 1 && | |||
fm.dilation[1] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1 && | |||
FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5); | |||
} | |||
MIDOUT_END(); | |||
return false; | |||
@@ -351,6 +337,7 @@ ConvBiasImpl::AlgoF16DirectStride1::get_kimpls( | |||
size_t IC = param.filter_meta.icpg; | |||
size_t OC = param.filter_meta.ocpg; | |||
size_t group = fm.group; | |||
bool large_group = group >= param.nr_threads; | |||
using Func = std::function<void(const __fp16*, const __fp16*, __fp16*, | |||
size_t, size_t, size_t, size_t, size_t)>; | |||
Func conv_kern_function = nullptr; | |||
@@ -371,11 +358,11 @@ ConvBiasImpl::AlgoF16DirectStride1::get_kimpls( | |||
WorkspaceBundle bundle = | |||
MultithreadDirectConvCommon<dt_float16, __fp16>::get_bundle_stride( | |||
param, m_large_group); | |||
param, large_group); | |||
SmallVector<NCBKern> ret_kerns; | |||
//! When group >= nr_threads, treat it as large_group, each thread process | |||
//! one group for better performance | |||
if (m_large_group) { | |||
if (large_group) { | |||
//! Channel wise conv and big groups | |||
auto exec_one_group = [bundle, conv_kern_function]( | |||
const NCBKernParam& kern_param, | |||
@@ -423,8 +410,9 @@ ConvBiasImpl::AlgoF16DirectStride1::get_kimpls( | |||
size_t ConvBiasImpl::AlgoF16DirectStride1::get_workspace( | |||
const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_kimpl, 1, 1) { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
auto bundle = MultithreadDirectConvCommon< | |||
dt_float16, __fp16>::get_bundle_stride(param, m_large_group); | |||
dt_float16, __fp16>::get_bundle_stride(param, large_group); | |||
return bundle.total_size_in_bytes(); | |||
} | |||
MIDOUT_END(); | |||
@@ -79,15 +79,10 @@ public: | |||
class ConvBiasImpl::AlgoF16Direct final : public AlgoBase { | |||
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | |||
bool m_large_group; | |||
public: | |||
AlgoF16Direct(bool is_large_group) : m_large_group{is_large_group} {} | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { | |||
return m_large_group ? "F16DIRECT_LARGE_GROUP" | |||
: "F16DIRECT_SMALL_GROUP"; | |||
} | |||
const char* name() const override { return "F16DIRECT"; } | |||
bool usable(const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||
@@ -99,14 +94,10 @@ public: | |||
class ConvBiasImpl::AlgoF16DirectStride1 final : public AlgoBase { | |||
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | |||
bool m_large_group; | |||
public: | |||
AlgoF16DirectStride1(bool is_large_group) : m_large_group{is_large_group} {} | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { | |||
return m_large_group ? "F16STRD1_LARGE_GROUP" : "F16STRD1_SMALL_GROUP"; | |||
} | |||
const char* name() const override { return "F16STRD1"; } | |||
bool usable(const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||
size_t get_workspace(const NCBKernSizeParam& param) const override; | |||
@@ -334,9 +334,8 @@ MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL(AlgoFP32WinogradF63_4x4_NCHW44, | |||
/* ===================== direct algo ===================== */ | |||
MIDOUT_DECL(megdnn_arm_common_conv_bias_f32_kimpl); | |||
bool ConvBiasImpl::AlgoF32Direct::usable( | |||
const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const { | |||
bool ConvBiasImpl::AlgoF32Direct::usable(const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 0, 0) { | |||
auto&& fm = param.filter_meta; | |||
auto FH = fm.spatial[0]; | |||
@@ -345,20 +344,14 @@ bool ConvBiasImpl::AlgoF32Direct::usable( | |||
// ``param.osz[0]*param.osz[1] >= 4'' comes from the fact that the | |||
// kernel may have access to up to 4 floats after the end of the memory | |||
// chunk. | |||
bool aviliable = fm.format == param::ConvBias::Format::NCHW && | |||
param.src_type.enumv() == DTypeEnum::Float32 && | |||
param.filter_type.enumv() == DTypeEnum::Float32 && | |||
param.dst_type.enumv() == DTypeEnum::Float32 && | |||
fm.spatial_ndim == 2 && fm.dilation[0] == 1 && | |||
fm.dilation[1] == 1 && | |||
param.isz[0] * param.isz[1] >= 4 && | |||
param.osz[0] * param.osz[1] >= 4 && FH <= 7 && | |||
SH == 1 && SW == 1; | |||
if (algo_selection_strategy == AlgoSelectionStrategy::HEURISTIC) { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
aviliable &= (large_group == m_large_group); | |||
} | |||
return aviliable; | |||
return fm.format == param::ConvBias::Format::NCHW && | |||
param.src_type.enumv() == DTypeEnum::Float32 && | |||
param.filter_type.enumv() == DTypeEnum::Float32 && | |||
param.dst_type.enumv() == DTypeEnum::Float32 && | |||
fm.spatial_ndim == 2 && fm.dilation[0] == 1 && | |||
fm.dilation[1] == 1 && param.isz[0] * param.isz[1] >= 4 && | |||
param.osz[0] * param.osz[1] >= 4 && FH <= 7 && SH == 1 && | |||
SW == 1; | |||
} | |||
MIDOUT_END(); | |||
return false; | |||
@@ -366,8 +359,9 @@ bool ConvBiasImpl::AlgoF32Direct::usable( | |||
size_t ConvBiasImpl::AlgoF32Direct::get_workspace( | |||
const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 0, 1) { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
auto wbundle = MultithreadDirectConvCommon<float, float>::get_bundle( | |||
param, m_large_group); | |||
param, large_group); | |||
return wbundle.total_size_in_bytes(); | |||
} | |||
MIDOUT_END(); | |||
@@ -380,13 +374,14 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32Direct::get_kimpls( | |||
size_t IC = param.filter_meta.icpg; | |||
size_t OC = param.filter_meta.ocpg; | |||
size_t group = fm.group; | |||
bool large_group = group >= param.nr_threads; | |||
WorkspaceBundle bundle = | |||
MultithreadDirectConvCommon<float, float>::get_bundle( | |||
param, m_large_group); | |||
param, large_group); | |||
SmallVector<NCBKern> ret_kerns; | |||
//! When group >= nr_threads, treat it as large_group, each thread process | |||
//! one group for better performance | |||
if (m_large_group) { | |||
if (large_group) { | |||
//! Channel wise conv and big groups | |||
auto exec_one_group = [bundle](const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index) mutable { | |||
@@ -452,27 +447,19 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32Direct::dispatch_kerns( | |||
return {}; | |||
} | |||
/* ===================== stride-1 algo ===================== */ | |||
bool ConvBiasImpl::AlgoF32DirectStride1::usable( | |||
const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const { | |||
bool ConvBiasImpl::AlgoF32DirectStride1::usable(const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 1, 1) { | |||
auto&& fm = param.filter_meta; | |||
auto FH = fm.spatial[0]; | |||
bool aviliable = | |||
param.filter_meta.format == param::ConvBias::Format::NCHW && | |||
param.src_type.enumv() == DTypeEnum::Float32 && | |||
param.filter_type.enumv() == DTypeEnum::Float32 && | |||
param.dst_type.enumv() == DTypeEnum::Float32 && | |||
!fm.should_flip && fm.spatial_ndim == 2 && | |||
fm.dilation[0] == 1 && fm.dilation[1] == 1 && | |||
fm.stride[0] == 1 && fm.stride[1] == 1 && FH == fm.spatial[1] && | |||
(FH == 2 || FH == 3 || FH == 5 || FH == 7); | |||
if (algo_selection_strategy == | |||
ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
aviliable &= (large_group == m_large_group); | |||
} | |||
return aviliable; | |||
return param.filter_meta.format == param::ConvBias::Format::NCHW && | |||
param.src_type.enumv() == DTypeEnum::Float32 && | |||
param.filter_type.enumv() == DTypeEnum::Float32 && | |||
param.dst_type.enumv() == DTypeEnum::Float32 && | |||
!fm.should_flip && fm.spatial_ndim == 2 && fm.dilation[0] == 1 && | |||
fm.dilation[1] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1 && | |||
FH == fm.spatial[1] && | |||
(FH == 2 || FH == 3 || FH == 5 || FH == 7); | |||
} | |||
MIDOUT_END(); | |||
return false; | |||
@@ -481,9 +468,10 @@ bool ConvBiasImpl::AlgoF32DirectStride1::usable( | |||
size_t ConvBiasImpl::AlgoF32DirectStride1::get_workspace( | |||
const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 1, 1) { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
auto bundle = | |||
MultithreadDirectConvCommon<float, float>::get_bundle_stride( | |||
param, m_large_group); | |||
param, large_group); | |||
return bundle.total_size_in_bytes(); | |||
} | |||
MIDOUT_END(); | |||
@@ -499,6 +487,7 @@ ConvBiasImpl::AlgoF32DirectStride1::get_kimpls( | |||
size_t IC = param.filter_meta.icpg; | |||
size_t OC = param.filter_meta.ocpg; | |||
size_t group = fm.group; | |||
bool large_group = group >= param.nr_threads; | |||
using Func = std::function<void(const float*, const float*, float*, size_t, | |||
size_t, size_t, size_t, size_t)>; | |||
Func conv_kern_function = nullptr; | |||
@@ -522,11 +511,11 @@ ConvBiasImpl::AlgoF32DirectStride1::get_kimpls( | |||
WorkspaceBundle bundle = | |||
MultithreadDirectConvCommon<float, float>::get_bundle_stride( | |||
param, m_large_group); | |||
param, large_group); | |||
SmallVector<NCBKern> ret_kerns; | |||
//! When group >= nr_threads, treat it as large_group, each thread process | |||
//! one group for better performance | |||
if (m_large_group) { | |||
if (large_group) { | |||
//! Channel wise conv and big groups | |||
auto exec_one_group = [bundle, conv_kern_function]( | |||
const NCBKernParam& kern_param, | |||
@@ -580,27 +569,19 @@ ConvBiasImpl::AlgoF32DirectStride1::dispatch_kerns( | |||
/* ===================== stride-2 algo ===================== */ | |||
bool ConvBiasImpl::AlgoF32DirectStride2::usable( | |||
const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const { | |||
bool ConvBiasImpl::AlgoF32DirectStride2::usable(const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 2, 0) { | |||
auto&& fm = param.filter_meta; | |||
auto FH = fm.spatial[0]; | |||
bool aviliable = | |||
param.filter_meta.format == param::ConvBias::Format::NCHW && | |||
param.src_type.enumv() == DTypeEnum::Float32 && | |||
param.filter_type.enumv() == DTypeEnum::Float32 && | |||
param.dst_type.enumv() == DTypeEnum::Float32 && | |||
!fm.should_flip && fm.spatial_ndim == 2 && | |||
fm.dilation[0] == 1 && fm.dilation[1] == 1 && | |||
fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] && | |||
(FH == 2 || FH == 3 || FH == 5 || FH == 7); | |||
if (algo_selection_strategy == | |||
ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
aviliable &= (large_group == m_large_group); | |||
} | |||
return aviliable; | |||
return param.filter_meta.format == param::ConvBias::Format::NCHW && | |||
param.src_type.enumv() == DTypeEnum::Float32 && | |||
param.filter_type.enumv() == DTypeEnum::Float32 && | |||
param.dst_type.enumv() == DTypeEnum::Float32 && | |||
!fm.should_flip && fm.spatial_ndim == 2 && fm.dilation[0] == 1 && | |||
fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 && | |||
FH == fm.spatial[1] && | |||
(FH == 2 || FH == 3 || FH == 5 || FH == 7); | |||
} | |||
MIDOUT_END(); | |||
return false; | |||
@@ -608,9 +589,10 @@ bool ConvBiasImpl::AlgoF32DirectStride2::usable( | |||
size_t ConvBiasImpl::AlgoF32DirectStride2::get_workspace( | |||
const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 2, 1) { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
auto bundle = | |||
MultithreadDirectConvCommon<float, float>::get_bundle_stride( | |||
param, m_large_group); | |||
param, large_group); | |||
return bundle.total_size_in_bytes(); | |||
} | |||
MIDOUT_END(); | |||
@@ -625,6 +607,7 @@ ConvBiasImpl::AlgoF32DirectStride2::get_kimpls( | |||
size_t IC = param.filter_meta.icpg; | |||
size_t OC = param.filter_meta.ocpg; | |||
size_t group = fm.group; | |||
bool large_group = group >= param.nr_threads; | |||
using Func = std::function<void(const float*, const float*, float*, size_t, | |||
size_t, size_t, size_t, size_t)>; | |||
Func conv_kern_function = nullptr; | |||
@@ -648,11 +631,11 @@ ConvBiasImpl::AlgoF32DirectStride2::get_kimpls( | |||
WorkspaceBundle bundle = | |||
MultithreadDirectConvCommon<float, float>::get_bundle_stride( | |||
param, m_large_group); | |||
param, large_group); | |||
SmallVector<NCBKern> ret_kerns; | |||
//! When group >= nr_threads, treat it as large_group, each thread process | |||
//! one group for better performance | |||
if (m_large_group) { | |||
if (large_group) { | |||
//! Channel wise conv and big groups | |||
auto exec_one_group = [bundle, conv_kern_function]( | |||
const NCBKernParam& kern_param, | |||
@@ -128,15 +128,10 @@ public: | |||
class ConvBiasImpl::AlgoF32Direct final : public AlgoBase { | |||
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | |||
bool m_large_group; | |||
public: | |||
AlgoF32Direct(bool is_large_group) : m_large_group{is_large_group} {} | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { | |||
return m_large_group ? "F32DIRECT_LARGE_GROUP" | |||
: "F32DIRECT_SMALL_GROUP"; | |||
} | |||
const char* name() const override { return "F32DIRECT"; } | |||
bool usable(const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||
@@ -147,14 +142,10 @@ public: | |||
class ConvBiasImpl::AlgoF32DirectStride1 final : public AlgoBase { | |||
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | |||
bool m_large_group; | |||
public: | |||
AlgoF32DirectStride1(bool is_large_group) : m_large_group{is_large_group} {} | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { | |||
return m_large_group ? "F32STRD1_LARGE_GROUP" : "F32STRD1_SMALL_GROUP"; | |||
} | |||
const char* name() const override { return "F32STRD1"; } | |||
bool usable(const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||
@@ -165,14 +156,10 @@ public: | |||
class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase { | |||
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | |||
bool m_large_group; | |||
public: | |||
AlgoF32DirectStride2(bool is_large_group) : m_large_group{is_large_group} {} | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { | |||
return m_large_group ? "F32STRD2_LARGE_GROUP" : "F32STRD2_SMALL_GROUP"; | |||
} | |||
const char* name() const override { return "F32STRD2"; } | |||
bool usable(const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||
@@ -27,17 +27,10 @@ using namespace arm_common; | |||
MIDOUT_DECL(megdnn_arm_common_conv_bias_int8) | |||
/* ===================== stride1 algo ===================== */ | |||
bool ConvBiasImpl::AlgoS8DirectStride1::usable( | |||
const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const { | |||
bool avaible = direct_int8_stride1::can_conv_direct_stride1_int8(param); | |||
auto fm = param.filter_meta; | |||
if (algo_selection_strategy == | |||
ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { | |||
bool large_group = fm.group >= param.nr_threads; | |||
avaible &= (large_group == m_large_group); | |||
} | |||
return avaible; | |||
bool ConvBiasImpl::AlgoS8DirectStride1::usable(const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy) const { | |||
return direct_int8_stride1::can_conv_direct_stride1_int8(param); | |||
} | |||
bool ConvBiasImpl::AlgoS8DirectStride1::is_preferred( | |||
const NCBKernSizeParam& param) const { | |||
@@ -53,8 +46,9 @@ bool ConvBiasImpl::AlgoS8DirectStride1::is_preferred( | |||
} | |||
size_t ConvBiasImpl::AlgoS8DirectStride1::get_workspace( | |||
const NCBKernSizeParam& param) const { | |||
auto bundle = direct_int8_stride1::get_bundle(param, m_large_group); | |||
const NCBKernSizeParam& param) const { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
auto bundle = direct_int8_stride1::get_bundle(param, large_group); | |||
return bundle.total_size_in_bytes(); | |||
} | |||
@@ -62,7 +56,8 @@ SmallVector<ConvBiasImpl::NCBKern> | |||
ConvBiasImpl::AlgoS8DirectStride1::dispatch_kerns( | |||
const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, 1, 0) { | |||
return direct_int8_stride1::get_kimpls(param, m_large_group); | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
return direct_int8_stride1::get_kimpls(param, large_group); | |||
} | |||
MIDOUT_END(); | |||
return {}; | |||
@@ -117,21 +112,15 @@ ConvBiasImpl::AlgoS8ChanWiseStride2NCHW44::dispatch_kerns( | |||
} | |||
/* ===================== stride2 algo ===================== */ | |||
bool ConvBiasImpl::AlgoS8DirectStride2::usable( | |||
const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const { | |||
bool avaible = direct_int8_stride2::can_conv_direct_stride2_int8(param); | |||
if (algo_selection_strategy == | |||
ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
avaible &= (large_group == m_large_group); | |||
} | |||
return avaible; | |||
bool ConvBiasImpl::AlgoS8DirectStride2::usable(const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy) const { | |||
return direct_int8_stride2::can_conv_direct_stride2_int8(param); | |||
} | |||
size_t ConvBiasImpl::AlgoS8DirectStride2::get_workspace( | |||
const NCBKernSizeParam& param) const { | |||
auto bundle = direct_int8_stride2::get_bundle(param, m_large_group); | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
auto bundle = direct_int8_stride2::get_bundle(param, large_group); | |||
return bundle.total_size_in_bytes(); | |||
} | |||
@@ -139,7 +128,8 @@ SmallVector<ConvBiasImpl::NCBKern> | |||
ConvBiasImpl::AlgoS8DirectStride2::dispatch_kerns( | |||
const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, 1, 1) { | |||
return direct_int8_stride2::get_kimpls(param, m_large_group); | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
return direct_int8_stride2::get_kimpls(param, large_group); | |||
} | |||
MIDOUT_END(); | |||
return {}; | |||
@@ -147,24 +137,15 @@ ConvBiasImpl::AlgoS8DirectStride2::dispatch_kerns( | |||
#if __ARM_FEATURE_DOTPROD | |||
/* ===================== dot stride1 algo ======================== */ | |||
bool ConvBiasImpl::AlgoDotS8DirectStride1::usable( | |||
const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const { | |||
bool avaible = | |||
direct_dotprod_int8_stride1::can_conv_direct_stride1_int8(param); | |||
if (algo_selection_strategy == | |||
ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
avaible &= (large_group == m_large_group); | |||
} | |||
return avaible; | |||
bool ConvBiasImpl::AlgoDotS8DirectStride1::usable(const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy) const { | |||
return direct_dotprod_int8_stride1::can_conv_direct_stride1_int8(param); | |||
} | |||
size_t ConvBiasImpl::AlgoDotS8DirectStride1::get_workspace( | |||
const NCBKernSizeParam& param) const { | |||
auto bundle = direct_dotprod_int8_stride1::get_bundle(param, m_large_group); | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
auto bundle = direct_dotprod_int8_stride1::get_bundle(param, large_group); | |||
return bundle.total_size_in_bytes(); | |||
} | |||
@@ -172,29 +153,23 @@ SmallVector<ConvBiasImpl::NCBKern> | |||
ConvBiasImpl::AlgoDotS8DirectStride1::dispatch_kerns( | |||
const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, 2, 1) { | |||
return direct_dotprod_int8_stride1::get_kimpls(param, m_large_group); | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
return direct_dotprod_int8_stride1::get_kimpls(param, large_group); | |||
} | |||
MIDOUT_END(); | |||
return {}; | |||
} | |||
/* ===================== dot stride2 algo ======================== */ | |||
bool ConvBiasImpl::AlgoDotS8DirectStride2::usable( | |||
const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const { | |||
bool avaible = | |||
direct_dotprod_int8_stride2::can_conv_direct_stride2_int8(param); | |||
if (algo_selection_strategy == | |||
ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
avaible &= (large_group == m_large_group); | |||
} | |||
return avaible; | |||
bool ConvBiasImpl::AlgoDotS8DirectStride2::usable(const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy) const { | |||
return direct_dotprod_int8_stride2::can_conv_direct_stride2_int8(param); | |||
} | |||
size_t ConvBiasImpl::AlgoDotS8DirectStride2::get_workspace( | |||
const NCBKernSizeParam& param) const { | |||
auto bundle = direct_dotprod_int8_stride2::get_bundle(param, m_large_group); | |||
const NCBKernSizeParam& param) const { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
auto bundle = direct_dotprod_int8_stride2::get_bundle(param, large_group); | |||
return bundle.total_size_in_bytes(); | |||
} | |||
@@ -202,7 +177,8 @@ SmallVector<ConvBiasImpl::NCBKern> | |||
ConvBiasImpl::AlgoDotS8DirectStride2::dispatch_kerns( | |||
const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, 2, 2) { | |||
return direct_dotprod_int8_stride2::get_kimpls(param, m_large_group); | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
return direct_dotprod_int8_stride2::get_kimpls(param, large_group); | |||
} | |||
MIDOUT_END(); | |||
return {}; | |||
@@ -18,14 +18,10 @@ namespace megdnn { | |||
namespace arm_common { | |||
class ConvBiasImpl::AlgoS8DirectStride1 final : public AlgoBase { | |||
bool m_large_group; | |||
public: | |||
AlgoS8DirectStride1(bool large_group) : m_large_group(large_group) {} | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { | |||
return m_large_group ? "S8STRD1_LARGE_GROUP" : "S8STRD1_SMALL_GROUP"; | |||
} | |||
const char* name() const override { return "S8STRD1"; } | |||
bool usable(const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||
size_t get_workspace(const NCBKernSizeParam& param) const override; | |||
@@ -36,14 +32,10 @@ public: | |||
}; | |||
class ConvBiasImpl::AlgoS8DirectStride2 final : public AlgoBase { | |||
bool m_large_group; | |||
public: | |||
AlgoS8DirectStride2(bool large_group) : m_large_group(large_group) {} | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { | |||
return m_large_group ? "S8STRD2_LARGE_GROUP" : "S8STRD2_SMALL_GROUP"; | |||
} | |||
const char* name() const override { return "S8STRD2"; } | |||
bool usable(const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||
@@ -115,16 +107,10 @@ public: | |||
}; | |||
class ConvBiasImpl::AlgoDotS8DirectStride1 final : public AlgoBase { | |||
bool m_large_group; | |||
public: | |||
AlgoDotS8DirectStride1(bool large_group) : m_large_group(large_group) {} | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { | |||
return m_large_group ? "ARMDOTS8STRD1_LARGE_GROUP" | |||
: "ARMDOTS8STRD1_SMALL_GROUP"; | |||
} | |||
const char* name() const override { return "ARMDOTS8STRD1"; } | |||
bool usable(const NCBKernSizeParam&, | |||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||
@@ -134,15 +120,10 @@ public: | |||
}; | |||
class ConvBiasImpl::AlgoDotS8DirectStride2 final : public AlgoBase { | |||
bool m_large_group; | |||
public: | |||
AlgoDotS8DirectStride2(bool large_group) : m_large_group(large_group) {} | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { | |||
return m_large_group ? "ARMDOTS8STRD2_LARGE_GROUP" | |||
: "ARMDOTS8STRD2_SMALL_GROUP"; | |||
} | |||
const char* name() const override { return "ARMDOTS8STRD2"; } | |||
bool usable(const NCBKernSizeParam&, | |||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||
@@ -82,28 +82,20 @@ void get_rectified_size_str2(size_t IH, size_t IW, size_t OH, size_t OW, | |||
} // namespace | |||
/* ===================== direct algo ===================== */ | |||
bool ConvBiasImpl::AlgoI8x8x16Direct::usable( | |||
const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const { | |||
bool ConvBiasImpl::AlgoI8x8x16Direct::usable(const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 1, 0) { | |||
auto&& fm = param.filter_meta; | |||
auto FH = fm.spatial[0]; | |||
bool aviliable = | |||
param.bias_mode == BiasMode::NO_BIAS && | |||
param.nonlineMode == NonlineMode::IDENTITY && | |||
fm.format == param::ConvBias::Format::NCHW && !fm.should_flip && | |||
param.src_type.enumv() == DTypeEnum::Int8 && | |||
param.filter_type.enumv() == DTypeEnum::Int8 && | |||
param.dst_type.enumv() == DTypeEnum::Int16 && | |||
fm.spatial_ndim == 2 && fm.dilation[0] == 1 && | |||
fm.dilation[1] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1 && | |||
FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5); | |||
if (algo_selection_strategy == | |||
ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
aviliable &= (large_group == m_large_group); | |||
} | |||
return aviliable; | |||
return param.bias_mode == BiasMode::NO_BIAS && | |||
param.nonlineMode == NonlineMode::IDENTITY && | |||
fm.format == param::ConvBias::Format::NCHW && !fm.should_flip && | |||
param.src_type.enumv() == DTypeEnum::Int8 && | |||
param.filter_type.enumv() == DTypeEnum::Int8 && | |||
param.dst_type.enumv() == DTypeEnum::Int16 && | |||
fm.spatial_ndim == 2 && fm.dilation[0] == 1 && | |||
fm.dilation[1] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1 && | |||
FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5); | |||
} | |||
MIDOUT_END(); | |||
return false; | |||
@@ -117,11 +109,12 @@ WorkspaceBundle ConvBiasImpl::AlgoI8x8x16Direct::get_bundle( | |||
auto OH = param.osz[0], OW = param.osz[1]; | |||
auto PH = fm.padding[0], PW = fm.padding[1]; | |||
size_t OH2, OW2, IH2, IW2; | |||
bool large_group = group >= param.nr_threads; | |||
get_rectified_size_str1(IH, IW, OH, OW, PH, PW, IH2, IW2, OH2, OW2); | |||
size_t part0 = 0u, part1 = 0u; | |||
if (need_src_copy_str1(param)) { | |||
part0 = m_large_group ? IC * IH2 * IW2 * sizeof(int8_t) * nr_threads | |||
: IC * IH2 * IW2 * sizeof(int8_t) * group * batch; | |||
part0 = large_group ? IC * IH2 * IW2 * sizeof(int8_t) * nr_threads | |||
: IC * IH2 * IW2 * sizeof(int8_t) * group * batch; | |||
} | |||
if (need_dst_copy_str1(param)) { | |||
part1 = OH2 * OW2 * sizeof(int16_t) * nr_threads + 16; | |||
@@ -255,9 +248,10 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoI8x8x16Direct::get_kimpls( | |||
size_t IC = param.filter_meta.icpg; | |||
size_t OC = param.filter_meta.ocpg; | |||
size_t group = fm.group; | |||
bool large_group = group >= param.nr_threads; | |||
WorkspaceBundle bundle = get_bundle(param); | |||
SmallVector<NCBKern> ret_kerns; | |||
if (m_large_group) { | |||
if (large_group) { | |||
auto exec_one_group = [bundle](const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index) mutable { | |||
auto fm = kern_param.filter_meta; | |||
@@ -302,28 +296,20 @@ ConvBiasImpl::AlgoI8x8x16Direct::dispatch_kerns( | |||
} | |||
/* ===================== stride-2 algo ===================== */ | |||
bool ConvBiasImpl::AlgoI8x8x16Stride2::usable( | |||
const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const { | |||
bool ConvBiasImpl::AlgoI8x8x16Stride2::usable(const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 2, 0) { | |||
auto&& fm = param.filter_meta; | |||
auto FH = fm.spatial[0]; | |||
bool aviliable = param.bias_mode == BiasMode::NO_BIAS && | |||
param.nonlineMode == NonlineMode::IDENTITY && | |||
fm.format == param::ConvBias::Format::NCHW && | |||
!fm.should_flip && | |||
param.src_type.enumv() == DTypeEnum::Int8 && | |||
param.filter_type.enumv() == DTypeEnum::Int8 && | |||
param.dst_type.enumv() == DTypeEnum::Int16 && | |||
fm.dilation[0] == 1 && fm.dilation[1] == 1 && | |||
fm.stride[0] == 2 && fm.stride[1] == 2 && | |||
FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5); | |||
if (algo_selection_strategy == | |||
ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
aviliable &= (large_group == m_large_group); | |||
} | |||
return aviliable; | |||
return param.bias_mode == BiasMode::NO_BIAS && | |||
param.nonlineMode == NonlineMode::IDENTITY && | |||
fm.format == param::ConvBias::Format::NCHW && !fm.should_flip && | |||
param.src_type.enumv() == DTypeEnum::Int8 && | |||
param.filter_type.enumv() == DTypeEnum::Int8 && | |||
param.dst_type.enumv() == DTypeEnum::Int16 && | |||
fm.dilation[0] == 1 && fm.dilation[1] == 1 && | |||
fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] && | |||
(FH == 2 || FH == 3 || FH == 5); | |||
} | |||
MIDOUT_END(); | |||
return false; | |||
@@ -340,9 +326,10 @@ WorkspaceBundle ConvBiasImpl::AlgoI8x8x16Stride2::get_bundle( | |||
size_t OH2, OW2, IH2, IW2; | |||
get_rectified_size_str2(IH, IW, OH, OW, FH, FW, PH, PW, IH2, IW2, OH2, OW2); | |||
size_t part0 = 0u, part1 = 0u; | |||
bool large_group = group >= param.nr_threads; | |||
if (need_src_copy_str2(param)) { | |||
part0 = m_large_group ? IC * IH2 * IW2 * sizeof(int8_t) * nr_threads | |||
: IC * IH2 * IW2 * sizeof(int8_t) * group * batch; | |||
part0 = large_group ? IC * IH2 * IW2 * sizeof(int8_t) * nr_threads | |||
: IC * IH2 * IW2 * sizeof(int8_t) * group * batch; | |||
} | |||
if (need_dst_copy_str2(param)) { | |||
part1 = OH2 * OW2 * sizeof(int16_t) * nr_threads + 16; | |||
@@ -475,9 +462,10 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoI8x8x16Stride2::get_kimpls( | |||
size_t IC = param.filter_meta.icpg; | |||
size_t OC = param.filter_meta.ocpg; | |||
size_t group = fm.group; | |||
bool large_group = group >= param.nr_threads; | |||
WorkspaceBundle bundle = get_bundle(param); | |||
SmallVector<NCBKern> ret_kerns; | |||
if (m_large_group) { | |||
if (large_group) { | |||
auto exec_one_group = [bundle](const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index) mutable { | |||
auto fm = kern_param.filter_meta; | |||
@@ -26,15 +26,10 @@ class ConvBiasImpl::AlgoI8x8x16Direct final : public AlgoBase { | |||
const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index, | |||
const CpuNDRange& workspace_ids); | |||
bool m_large_group; | |||
public: | |||
AlgoI8x8x16Direct(bool large_group) : m_large_group(large_group) {} | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { | |||
return m_large_group ? "I8816DIRECT_LARGE_GROUP" | |||
: "I8816DIRECT_SMALL_GROUP"; | |||
} | |||
const char* name() const override { return "I8816DIRECT"; } | |||
bool usable(const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||
size_t get_workspace(const NCBKernSizeParam& param) const override; | |||
@@ -53,15 +48,9 @@ class ConvBiasImpl::AlgoI8x8x16Stride2 final : public AlgoBase { | |||
const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index, | |||
const CpuNDRange& workspace_ids); | |||
bool m_large_group; | |||
public: | |||
AlgoI8x8x16Stride2(bool large_group) : m_large_group(large_group) {} | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { | |||
return m_large_group ? "I8816STRD2_LARGE_GROUP" | |||
: "I8816STRD2_SMALL_GROUP"; | |||
} | |||
const char* name() const override { return "I8816STRD2"; } | |||
bool usable(const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||
@@ -40,28 +40,20 @@ uint8_t arm_common_algo_type_storage; | |||
} // anonymous namespace | |||
class ConvBiasImpl::AlgoPack : NonCopyableObj { | |||
AlgoQU8DirectStride2 qu8_direct_stride2_large_group{true}; | |||
AlgoQU8DirectStride2 qu8_direct_stride2_small_group{false}; | |||
AlgoQU8DirectStride1 qu8_direct_stride1_large_group{true}; | |||
AlgoQU8DirectStride1 qu8_direct_stride1_small_group{false}; | |||
AlgoS8DirectStride2 s8_direct_stride2_large_group{true}; | |||
AlgoS8DirectStride2 s8_direct_stride2_small_group{false}; | |||
AlgoQU8DirectStride2 qu8_direct_stride2; | |||
AlgoQU8DirectStride1 qu8_direct_stride1; | |||
AlgoS8DirectStride2 s8_direct_stride2; | |||
AlgoS8DirectNCHW44 s8_direct_nchw44; | |||
AlgoS8DirectNCHWNCHW44 s8_direct_nchw_nchw44; | |||
AlgoS8DirectStride1 s8_direct_stride1_large_group{true}; | |||
AlgoS8DirectStride1 s8_direct_stride1_small_group{false}; | |||
AlgoS8DirectStride1 s8_direct_stride1; | |||
AlgoS8ChanWiseStride1NCHW44 s8_channel_wise_stride1_nchw44; | |||
AlgoS8ChanWiseStride2NCHW44 s8_channel_wise_stride2_nchw44; | |||
#if __ARM_FEATURE_DOTPROD | |||
AlgoDotS8DirectStride1 ds8_direct_stride1_large_group{true}; | |||
AlgoDotS8DirectStride1 ds8_direct_stride1_small_group{false}; | |||
AlgoDotS8DirectStride2 ds8_direct_stride2_large_group{true}; | |||
AlgoDotS8DirectStride2 ds8_direct_stride2_small_group{false}; | |||
AlgoDotU8DirectStride1 du8_direct_stride1_large_group{true}; | |||
AlgoDotU8DirectStride1 du8_direct_stride1_small_group{false}; | |||
AlgoDotU8DirectStride2 du8_direct_stride2_large_group{true}; | |||
AlgoDotU8DirectStride2 du8_direct_stride2_small_group{false}; | |||
AlgoDotS8DirectStride1 ds8_direct_stride1; | |||
AlgoDotS8DirectStride2 ds8_direct_stride2; | |||
AlgoDotU8DirectStride1 du8_direct_stride1; | |||
AlgoDotU8DirectStride2 du8_direct_stride2; | |||
AlgoDotS8Direct_NCHW44 ds8_direct_nchw44; | |||
AlgoDotS8DirectNCHWNCHW44 ds8_direct_nchw_nchw44; | |||
@@ -71,23 +63,16 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { | |||
AlgoF32ChannelWiseNCHW44 f32_chanel_wise_nchw44; | |||
AlgoF32DirectNCHW44 f32_direct_nchw44; | |||
AlgoF32Direct f32_direct_large_group{true}; | |||
AlgoF32Direct f32_direct_small_group{false}; | |||
AlgoF32DirectStride2 f32_direct_stride2_large_group{true}; | |||
AlgoF32DirectStride2 f32_direct_stride2_small_group{false}; | |||
AlgoF32DirectStride1 f32_direct_stride1_large_group{true}; | |||
AlgoF32DirectStride1 f32_direct_stride1_small_group{false}; | |||
AlgoF32Direct f32_direct; | |||
AlgoF32DirectStride2 f32_direct_stride2; | |||
AlgoF32DirectStride1 f32_direct_stride1; | |||
AlgoI8x8x16Direct i8x8x16_direct_large_group{true}; | |||
AlgoI8x8x16Direct i8x8x16_direct_small_group{false}; | |||
AlgoI8x8x16Stride2 i8x8x16_stride2_large_group{true}; | |||
AlgoI8x8x16Stride2 i8x8x16_stride2_small_group{false}; | |||
AlgoI8x8x16Direct i8x8x16_direct; | |||
AlgoI8x8x16Stride2 i8x8x16_stride2; | |||
AlgoI8x8x16Stride2Filter2 i8x8x16_stride2_filter2; | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
AlgoF16Direct f16_direct_large_group{true}; | |||
AlgoF16Direct f16_direct_small_group{false}; | |||
AlgoF16DirectStride1 f16_direct_stride1_large_group{true}; | |||
AlgoF16DirectStride1 f16_direct_stride1_small_group{false}; | |||
AlgoF16Direct f16_direct; | |||
AlgoF16DirectStride1 f16_direct_stride1; | |||
#endif | |||
SmallVector<std::unique_ptr<AlgoBase>> refhold; | |||
@@ -95,54 +80,39 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { | |||
public: | |||
AlgoPack() { | |||
#if __ARM_FEATURE_DOTPROD | |||
direct_algos.emplace_back(&ds8_direct_stride1_large_group); | |||
direct_algos.emplace_back(&ds8_direct_stride1_small_group); | |||
direct_algos.emplace_back(&ds8_direct_stride2_large_group); | |||
direct_algos.emplace_back(&ds8_direct_stride2_small_group); | |||
direct_algos.emplace_back(&du8_direct_stride1_large_group); | |||
direct_algos.emplace_back(&du8_direct_stride1_small_group); | |||
direct_algos.emplace_back(&du8_direct_stride2_large_group); | |||
direct_algos.emplace_back(&du8_direct_stride2_small_group); | |||
direct_algos.emplace_back(&ds8_direct_stride1); | |||
direct_algos.emplace_back(&ds8_direct_stride2); | |||
direct_algos.emplace_back(&du8_direct_stride1); | |||
direct_algos.emplace_back(&du8_direct_stride2); | |||
direct_algos.emplace_back(&ds8_direct_nchw44); | |||
direct_algos.emplace_back(&ds8_direct_nchw_nchw44); | |||
#endif | |||
direct_algos.emplace_back(&qu8_direct_stride2_large_group); | |||
direct_algos.emplace_back(&qu8_direct_stride2_small_group); | |||
direct_algos.emplace_back(&qu8_direct_stride1_large_group); | |||
direct_algos.emplace_back(&qu8_direct_stride1_small_group); | |||
direct_algos.emplace_back(&s8_direct_stride2_large_group); | |||
direct_algos.emplace_back(&s8_direct_stride2_small_group); | |||
direct_algos.emplace_back(&qu8_direct_stride2); | |||
direct_algos.emplace_back(&qu8_direct_stride1); | |||
direct_algos.emplace_back(&s8_direct_stride2); | |||
direct_algos.emplace_back(&s8_direct_nchw44); | |||
direct_algos.emplace_back(&s8_direct_nchw_nchw44); | |||
direct_algos.emplace_back(&s8_direct_stride1_large_group); | |||
direct_algos.emplace_back(&s8_direct_stride1_small_group); | |||
direct_algos.emplace_back(&s8_direct_stride1); | |||
direct_algos.emplace_back(&s8_channel_wise_stride1_nchw44); | |||
direct_algos.emplace_back(&s8_channel_wise_stride2_nchw44); | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
direct_algos.emplace_back(&f16_direct_stride1_large_group); | |||
direct_algos.emplace_back(&f16_direct_stride1_small_group); | |||
direct_algos.emplace_back(&f16_direct_large_group); | |||
direct_algos.emplace_back(&f16_direct_small_group); | |||
direct_algos.emplace_back(&f16_direct_stride1); | |||
direct_algos.emplace_back(&f16_direct); | |||
#endif | |||
direct_algos.emplace_back(&i8x8x16_direct_large_group); | |||
direct_algos.emplace_back(&i8x8x16_direct_small_group); | |||
direct_algos.emplace_back(&i8x8x16_direct); | |||
direct_algos.emplace_back(&i8x8x16_stride2_filter2); | |||
direct_algos.emplace_back(&i8x8x16_stride2_large_group); | |||
direct_algos.emplace_back(&i8x8x16_stride2_small_group); | |||
direct_algos.emplace_back(&i8x8x16_stride2); | |||
direct_algos.emplace_back(&f32_direct_stride2_nchw_nchw44); | |||
direct_algos.emplace_back(&f32_chanel_wise_nchw44); | |||
direct_algos.emplace_back(&f32_direct_nchw44); | |||
direct_algos.emplace_back(&f32_direct_stride1_large_group); | |||
direct_algos.emplace_back(&f32_direct_stride1_small_group); | |||
direct_algos.emplace_back(&f32_direct_stride2_large_group); | |||
direct_algos.emplace_back(&f32_direct_stride2_small_group); | |||
direct_algos.emplace_back(&f32_direct_large_group); | |||
direct_algos.emplace_back(&f32_direct_small_group); | |||
direct_algos.emplace_back(&f32_direct_stride1); | |||
direct_algos.emplace_back(&f32_direct_stride2); | |||
direct_algos.emplace_back(&f32_direct); | |||
static CpuOprDelegationStorage<2> storage; | |||
auto matmul_opr = storage.get<MatrixMul, 0>(); | |||
@@ -25,21 +25,15 @@ using namespace megdnn; | |||
using namespace arm_common; | |||
/* ===================== stride1 algo ===================== */ | |||
bool ConvBiasImpl::AlgoQU8DirectStride1::usable( | |||
const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const { | |||
bool avaible = direct_quint8_stride1::can_conv_direct_stride1_quint8(param); | |||
if (algo_selection_strategy == | |||
ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
avaible &= (large_group == m_large_group); | |||
} | |||
return avaible; | |||
bool ConvBiasImpl::AlgoQU8DirectStride1::usable(const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy) const { | |||
return direct_quint8_stride1::can_conv_direct_stride1_quint8(param); | |||
} | |||
size_t ConvBiasImpl::AlgoQU8DirectStride1::get_workspace( | |||
const NCBKernSizeParam& param) const { | |||
auto bundle = direct_quint8_stride1::get_bundle(param, m_large_group); | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
auto bundle = direct_quint8_stride1::get_bundle(param, large_group); | |||
return bundle.total_size_in_bytes(); | |||
} | |||
@@ -47,7 +41,8 @@ SmallVector<ConvBiasImpl::NCBKern> | |||
ConvBiasImpl::AlgoQU8DirectStride1::dispatch_kerns( | |||
const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, 0, 0) { | |||
return direct_quint8_stride1::get_kimpls(param, m_large_group); | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
return direct_quint8_stride1::get_kimpls(param, large_group); | |||
} | |||
MIDOUT_END(); | |||
return {}; | |||
@@ -55,20 +50,15 @@ ConvBiasImpl::AlgoQU8DirectStride1::dispatch_kerns( | |||
/* ===================== stride2 algo ===================== */ | |||
bool ConvBiasImpl::AlgoQU8DirectStride2::usable( | |||
const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const { | |||
bool avaible = direct_quint8_stride2::can_conv_direct_stride2_quint8(param); | |||
if (algo_selection_strategy == | |||
ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
avaible &= (large_group == m_large_group); | |||
} | |||
return avaible; | |||
const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy) const { | |||
return direct_quint8_stride2::can_conv_direct_stride2_quint8(param); | |||
} | |||
size_t ConvBiasImpl::AlgoQU8DirectStride2::get_workspace( | |||
const NCBKernSizeParam& param) const { | |||
auto bundle = direct_quint8_stride2::get_bundle(param, m_large_group); | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
auto bundle = direct_quint8_stride2::get_bundle(param, large_group); | |||
return bundle.total_size_in_bytes(); | |||
} | |||
@@ -76,31 +66,23 @@ SmallVector<ConvBiasImpl::NCBKern> | |||
ConvBiasImpl::AlgoQU8DirectStride2::dispatch_kerns( | |||
const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, 0, 1) { | |||
return direct_quint8_stride2::get_kimpls(param, m_large_group); | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
return direct_quint8_stride2::get_kimpls(param, large_group); | |||
} | |||
MIDOUT_END(); | |||
return {}; | |||
} | |||
#if __ARM_FEATURE_DOTPROD | |||
/* ===================== stride1 algo ===================== */ | |||
bool ConvBiasImpl::AlgoDotU8DirectStride1::usable( | |||
const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const { | |||
bool avaible = | |||
direct_dotprod_quint8_stride1::can_conv_direct_stride1_quint8( | |||
param); | |||
if (algo_selection_strategy == | |||
ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
avaible &= (large_group == m_large_group); | |||
} | |||
return avaible; | |||
bool ConvBiasImpl::AlgoDotU8DirectStride1::usable(const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy) const { | |||
return direct_dotprod_quint8_stride1::can_conv_direct_stride1_quint8(param); | |||
} | |||
size_t ConvBiasImpl::AlgoDotU8DirectStride1::get_workspace( | |||
const NCBKernSizeParam& param) const { | |||
auto bundle = | |||
direct_dotprod_quint8_stride1::get_bundle(param, m_large_group); | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
auto bundle = direct_dotprod_quint8_stride1::get_bundle(param, large_group); | |||
return bundle.total_size_in_bytes(); | |||
} | |||
@@ -108,31 +90,23 @@ SmallVector<ConvBiasImpl::NCBKern> | |||
ConvBiasImpl::AlgoDotU8DirectStride1::dispatch_kerns( | |||
const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, 1, 0) { | |||
return direct_dotprod_quint8_stride1::get_kimpls(param, m_large_group); | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
return direct_dotprod_quint8_stride1::get_kimpls(param, large_group); | |||
} | |||
MIDOUT_END(); | |||
return {}; | |||
} | |||
/* ===================== stride2 algo ===================== */ | |||
bool ConvBiasImpl::AlgoDotU8DirectStride2::usable( | |||
const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const { | |||
bool avaible = | |||
direct_dotprod_quint8_stride2::can_conv_direct_stride2_quint8( | |||
param); | |||
if (algo_selection_strategy == | |||
ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
avaible &= (large_group == m_large_group); | |||
} | |||
return avaible; | |||
bool ConvBiasImpl::AlgoDotU8DirectStride2::usable(const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy) const { | |||
return direct_dotprod_quint8_stride2::can_conv_direct_stride2_quint8(param); | |||
} | |||
size_t ConvBiasImpl::AlgoDotU8DirectStride2::get_workspace( | |||
const NCBKernSizeParam& param) const { | |||
auto bundle = | |||
direct_dotprod_quint8_stride2::get_bundle(param, m_large_group); | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
auto bundle = direct_dotprod_quint8_stride2::get_bundle(param, large_group); | |||
return bundle.total_size_in_bytes(); | |||
} | |||
@@ -140,7 +114,8 @@ SmallVector<ConvBiasImpl::NCBKern> | |||
ConvBiasImpl::AlgoDotU8DirectStride2::dispatch_kerns( | |||
const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, 1, 1) { | |||
return direct_dotprod_quint8_stride2::get_kimpls(param, m_large_group); | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
return direct_dotprod_quint8_stride2::get_kimpls(param, large_group); | |||
} | |||
MIDOUT_END(); | |||
return {}; | |||
@@ -18,14 +18,10 @@ namespace megdnn { | |||
namespace arm_common { | |||
class ConvBiasImpl::AlgoQU8DirectStride1 final : public AlgoBase { | |||
bool m_large_group; | |||
public: | |||
AlgoQU8DirectStride1(bool large_group) : m_large_group(large_group) {} | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { | |||
return m_large_group ? "QU8STRD1_LARGE_GROUP" : "QU8STRD1_SMALL_GROUP"; | |||
} | |||
const char* name() const override { return "QU8STRD1"; } | |||
bool usable(const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||
@@ -36,14 +32,10 @@ public: | |||
}; | |||
class ConvBiasImpl::AlgoQU8DirectStride2 final : public AlgoBase { | |||
bool m_large_group; | |||
public: | |||
AlgoQU8DirectStride2(bool large_group) : m_large_group(large_group) {} | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { | |||
return m_large_group ? "QU8STRD2_LARGE_GROUP" : "QU8STRD2_SMALL_GROUP"; | |||
} | |||
const char* name() const override { return "QU8STRD2"; } | |||
bool usable(const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||
@@ -53,15 +45,10 @@ public: | |||
}; | |||
#if __ARM_FEATURE_DOTPROD | |||
class ConvBiasImpl::AlgoDotU8DirectStride1 final : public AlgoBase { | |||
bool m_large_group; | |||
public: | |||
AlgoDotU8DirectStride1(bool large_group) : m_large_group(large_group) {} | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { | |||
return m_large_group ? "ARMDOTU8STRD1_LARGE_GROUP" | |||
: "ARMDOTU8STRD1_SMALL_GROUP"; | |||
} | |||
const char* name() const override { return "ARMDOTU8STRD1"; } | |||
bool usable(const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||
@@ -72,15 +59,10 @@ public: | |||
}; | |||
class ConvBiasImpl::AlgoDotU8DirectStride2 final : public AlgoBase { | |||
bool m_large_group; | |||
public: | |||
AlgoDotU8DirectStride2(bool large_group) : m_large_group(large_group) {} | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { | |||
return m_large_group ? "ARMDOTU8STRD2_LARGE_GROUP" | |||
: "ARMDOTU8STRD2_SMALL_GROUP"; | |||
} | |||
const char* name() const override { return "ARMDOTU8STRD2"; } | |||
bool usable(const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||
@@ -65,9 +65,10 @@ void get_rectified_size(size_t IH, size_t IW, size_t OH, size_t OW, size_t FH, | |||
size_t IC = param.filter_meta.icpg; \ | |||
size_t OC = param.filter_meta.ocpg; \ | |||
size_t group = fm.group; \ | |||
bool large_group = group >= param.nr_threads; \ | |||
WorkspaceBundle bundle = get_bundle(param); \ | |||
SmallVector<NCBKern> ret_kerns; \ | |||
if (m_large_group) { \ | |||
if (large_group) { \ | |||
auto exec_one_group = [bundle]( \ | |||
const NCBKernParam& kern_param, \ | |||
const NCBKernIndex& ncb_index) mutable { \ | |||
@@ -104,22 +105,15 @@ void get_rectified_size(size_t IH, size_t IW, size_t OH, size_t OW, size_t FH, | |||
/* ===================== direct algo ===================== */ | |||
bool ConvBiasImpl::AlgoDirect::usable( | |||
const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const { | |||
bool ConvBiasImpl::AlgoDirect::usable(const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy) const { | |||
auto&& fm = param.filter_meta; | |||
bool aviliable = fm.format == Param::Format::NCHW && fm.spatial_ndim == 2 && | |||
param.src_type.enumv() == DTypeEnum::Float32 && | |||
param.filter_type.enumv() == DTypeEnum::Float32 && | |||
param.dst_type.enumv() == DTypeEnum::Float32 && | |||
fm.dilation[0] == 1 && fm.dilation[1] == 1 && | |||
fm.spatial[0] <= 7 && fm.stride[0] == 1 && | |||
fm.stride[1] == 1; | |||
if (algo_selection_strategy == AlgoSelectionStrategy::HEURISTIC) { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
aviliable &= (large_group == m_large_group); | |||
} | |||
return aviliable; | |||
return fm.format == Param::Format::NCHW && fm.spatial_ndim == 2 && | |||
param.src_type.enumv() == DTypeEnum::Float32 && | |||
param.filter_type.enumv() == DTypeEnum::Float32 && | |||
param.dst_type.enumv() == DTypeEnum::Float32 && | |||
fm.dilation[0] == 1 && fm.dilation[1] == 1 && fm.spatial[0] <= 7 && | |||
fm.stride[0] == 1 && fm.stride[1] == 1; | |||
} | |||
WorkspaceBundle ConvBiasImpl::AlgoDirect::get_bundle( | |||
const NCBKernSizeParam& param) const { | |||
@@ -133,9 +127,10 @@ WorkspaceBundle ConvBiasImpl::AlgoDirect::get_bundle( | |||
get_rectified_img_size(IH, IW, FH, FW, OH, OW, fm.padding[0], fm.padding[1], | |||
IH2, IW2, OH2, OW2); | |||
size_t part0 = 0u, part1 = 0u; | |||
bool large_group = group >= param.nr_threads; | |||
if (IH != IH2 || IW != IW2) { | |||
part0 = m_large_group ? IC * IH2 * IW2 * sizeof(float) * nr_threads | |||
: IC * IH2 * IW2 * sizeof(float) * group * batch; | |||
part0 = large_group ? IC * IH2 * IW2 * sizeof(float) * nr_threads | |||
: IC * IH2 * IW2 * sizeof(float) * group * batch; | |||
} | |||
if (OH != OH2 || OW != OW2) { | |||
part1 = OH2 * OW2 * sizeof(float) * nr_threads; | |||
@@ -319,24 +314,17 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoDirect::get_kimpls( | |||
GET_KERN; | |||
} | |||
/* ===================== direct-stride2 algo ===================== */ | |||
bool ConvBiasImpl::AlgoDirectStride2::usable( | |||
const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const { | |||
bool ConvBiasImpl::AlgoDirectStride2::usable(const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy) const { | |||
auto&& fm = param.filter_meta; | |||
auto FH = fm.spatial[0]; | |||
bool aviliable = | |||
param.filter_meta.format == param::ConvBias::Format::NCHW && | |||
param.src_type.enumv() == DTypeEnum::Float32 && | |||
param.filter_type.enumv() == DTypeEnum::Float32 && | |||
param.dst_type.enumv() == DTypeEnum::Float32 && !fm.should_flip && | |||
fm.spatial_ndim == 2 && fm.dilation[0] == 1 && | |||
fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 && | |||
FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5 || FH == 7); | |||
if (algo_selection_strategy == AlgoSelectionStrategy::HEURISTIC) { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
aviliable &= (large_group == m_large_group); | |||
} | |||
return aviliable; | |||
return param.filter_meta.format == param::ConvBias::Format::NCHW && | |||
param.src_type.enumv() == DTypeEnum::Float32 && | |||
param.filter_type.enumv() == DTypeEnum::Float32 && | |||
param.dst_type.enumv() == DTypeEnum::Float32 && !fm.should_flip && | |||
fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.dilation[1] == 1 && | |||
fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] && | |||
(FH == 2 || FH == 3 || FH == 5 || FH == 7); | |||
} | |||
WorkspaceBundle ConvBiasImpl::AlgoDirectStride2::get_bundle( | |||
@@ -352,10 +340,10 @@ WorkspaceBundle ConvBiasImpl::AlgoDirectStride2::get_bundle( | |||
size_t src_size = 0, dst_size = 0; | |||
size_t IH2, IW2, OH2, OW2; | |||
get_rectified_size(IH, IW, OH, OW, FH, FW, PH, PW, IH2, IW2, OH2, OW2); | |||
bool large_group = group >= param.nr_threads; \ | |||
if (need_src_copy(param)) { | |||
src_size = m_large_group | |||
? IC * IH2 * IW2 * sizeof(float) * nr_threads | |||
: IC * IH2 * IW2 * sizeof(float) * group * batch; | |||
src_size = large_group ? IC * IH2 * IW2 * sizeof(float) * nr_threads | |||
: IC * IH2 * IW2 * sizeof(float) * group * batch; | |||
} | |||
if (need_dst_copy(param)) { | |||
// we only need one dst plane | |||
@@ -29,14 +29,10 @@ class ConvBiasImpl::AlgoDirect final : public AlgoBase { | |||
const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index, | |||
const CpuNDRange& workspace_ids); | |||
bool m_large_group; | |||
public: | |||
AlgoDirect(bool large_group) : m_large_group(large_group) {} | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { | |||
return m_large_group ? "X86_CONV_BIAS_DIRECT_STRIDE1_LARGE_GROUP" | |||
: "X86_CONV_BIAS_DIRECT_STRIDE1_SMALL_GROUP"; | |||
return "X86_CONV_BIAS_DIRECT_STRIDE1_LARGE_GROUP"; | |||
} | |||
bool usable(const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||
@@ -65,14 +61,10 @@ class ConvBiasImpl::AlgoDirectStride2 final : public AlgoBase { | |||
const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index, | |||
const CpuNDRange& workspace_ids); | |||
bool m_large_group; | |||
public: | |||
AlgoDirectStride2(bool large_group) : m_large_group(large_group) {} | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { | |||
return m_large_group ? "X86_CONV_BIAS_DIRECT_STRIDE2_LARGE_GROUP" | |||
: "X86_CONV_BIAS_DIRECT_STRIDE2_SMALL_GROUP"; | |||
return "X86_CONV_BIAS_DIRECT_STRIDE2_LARGE_GROUP"; | |||
} | |||
bool usable(const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||
@@ -76,10 +76,8 @@ void* ConvBiasImpl::AlgoChanWiseAvx2Stride2Qint8::type() const { | |||
} | |||
class ConvBiasImpl::AlgoPack : NonCopyableObj { | |||
AlgoDirect stride1_direct_large_group{true}; | |||
AlgoDirect stride1_direct_small_group{false}; | |||
AlgoDirectStride2 stride2_direct_large_group{true}; | |||
AlgoDirectStride2 stride2_direct_small_group{false}; | |||
AlgoDirect stride1_direct; | |||
AlgoDirectStride2 stride2_direct; | |||
AlgoDirectAvx2Stride1Int8 avx2_stride1_direct_int8; | |||
AlgoAVX2DirectConvStride2 avx2_stride2_direct; | |||
AlgoChanWiseAvx2Stride1Qint8 avx2_stride1_chanwsie_qint8; | |||
@@ -103,10 +101,8 @@ public: | |||
all_algos.emplace_back(&mkldnn_matmul_qint8); | |||
all_algos.emplace_back(&mkldnn_qint8); | |||
#endif | |||
all_algos.emplace_back(&stride1_direct_large_group); | |||
all_algos.emplace_back(&stride1_direct_small_group); | |||
all_algos.emplace_back(&stride2_direct_large_group); | |||
all_algos.emplace_back(&stride2_direct_small_group); | |||
all_algos.emplace_back(&stride1_direct); | |||
all_algos.emplace_back(&stride2_direct); | |||
all_algos.emplace_back(&avx2_stride1_direct_int8); | |||
all_algos.emplace_back(&avx2_stride2_direct); | |||
all_algos.emplace_back(&avx2_stride1_chanwsie_qint8); | |||
@@ -81,15 +81,10 @@ void checker_conv_bias(std::vector<conv_bias::TestArg> args, Handle* handle, | |||
{arg.src, arg.filter, arg.bias, {}, {}}); | |||
} | |||
} | |||
TEST_F(AARCH64_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2_LARGE_GROUP) { | |||
TEST_F(AARCH64_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2) { | |||
check_conv_bias( | |||
conv_bias::get_conv_bias_args({2, 3, 5, 7}, 2, false, false, false), | |||
handle(), "ARMV8F32STRD2_LARGE_GROUP"); | |||
} | |||
TEST_F(AARCH64_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2_SMALL_GROUP) { | |||
check_conv_bias( | |||
conv_bias::get_conv_bias_args({2, 3, 5, 7}, 2, false, false, false), | |||
handle(), "ARMV8F32STRD2_SMALL_GROUP"); | |||
handle(), "ARMV8F32STRD2"); | |||
} | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
@@ -114,17 +109,11 @@ void checker_conv_bias_fp16(std::vector<conv_bias::TestArg> args, | |||
} | |||
} | |||
TEST_F(AARCH64_MULTI_THREADS, CONVBIAS_DIRECT_FP16_STR2_LARGE_GROUP) { | |||
NormalRNG rng(1); | |||
checker_conv_bias_f16( | |||
conv_bias::get_conv_bias_args({2, 3, 5}, 2, false, false, false), | |||
handle(), rng, "ARMV8F16STRD2_LARGE_GROUP", 0.04); | |||
} | |||
TEST_F(AARCH64_MULTI_THREADS, CONVBIAS_DIRECT_FP16_STR2_SMALL_GROUP) { | |||
TEST_F(AARCH64_MULTI_THREADS, CONVBIAS_DIRECT_FP16_STR2) { | |||
NormalRNG rng(1); | |||
checker_conv_bias_f16( | |||
conv_bias::get_conv_bias_args({2, 3, 5}, 2, false, false, false), | |||
handle(), rng, "ARMV8F16STRD2_SMALL_GROUP", 0.04); | |||
handle(), rng, "ARMV8F16STRD2", 0.04); | |||
} | |||
#endif | |||
@@ -1310,8 +1310,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CHANNEL_WISE_F32_STRIDE1_NCHW44) { | |||
benchmark0.set_param(param); | |||
benchmark0.set_times(RUN); | |||
benchmark0.set_before_exec_callback( | |||
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>( | |||
"F32STRD1_LARGE_GROUP")); | |||
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>("F32STRD1")); | |||
auto opr = handle()->create_operator<ConvBias>(); | |||
opr->param() = param; | |||
@@ -1385,8 +1384,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CHANNEL_WISE_F32_STRIDE2_NCHW44) { | |||
benchmark0.set_param(param); | |||
benchmark0.set_times(RUN); | |||
benchmark0.set_before_exec_callback( | |||
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>( | |||
"F32STRD2_LARGE_GROUP")); | |||
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>("F32STRD2")); | |||
auto opr = handle()->create_operator<ConvBias>(); | |||
opr->param() = param; | |||
@@ -1464,8 +1462,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_QINT8_STRIDE1_NCHW44) { | |||
benchmark0.set_param(param); | |||
benchmark0.set_times(RUN); | |||
benchmark0.set_before_exec_callback( | |||
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>( | |||
"S8STRD1_LARGE_GROUP")); | |||
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>("S8STRD1")); | |||
auto opr = handle()->create_operator<ConvBias>(); | |||
opr->param() = param; | |||
@@ -356,15 +356,10 @@ void checker_conv_bias_int8x8x32_multi(std::vector<conv_bias::TestArg> args, | |||
} | |||
/**********************************F32 direct************************/ | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_LARGE_GROUP) { | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32) { | |||
check_conv_bias( | |||
get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false), | |||
handle(), "F32DIRECT_LARGE_GROUP"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_SMALL_GROUP) { | |||
check_conv_bias( | |||
get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false), | |||
handle(), "F32DIRECT_SMALL_GROUP"); | |||
handle(), "F32DIRECT"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K7) { | |||
@@ -391,21 +386,13 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S2) { | |||
handle(), "F32_CONV_NCHW44_DIRECT"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR1_LARGE_GROUP) { | |||
check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 1, false, false, false), | |||
handle(), "F32STRD1_LARGE_GROUP"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR1_SMALL_GROUP) { | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR1) { | |||
check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 1, false, false, false), | |||
handle(), "F32STRD1_SMALL_GROUP"); | |||
handle(), "F32STRD1"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2_LARGE_GROUP) { | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2) { | |||
check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 2, false, false, false), | |||
handle(), "F32STRD2_LARGE_GROUP"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2_SMALL_GROUP) { | |||
check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 2, false, false, false), | |||
handle(), "F32STRD2_SMALL_GROUP"); | |||
handle(), "F32STRD2"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_NCHW_NCHW44_F32_S2) { | |||
check_conv_bias(get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, | |||
@@ -437,72 +424,41 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE2_FP32_NCHW44) { | |||
/**********************************F16 direct************************/ | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_LARGE_GROUP) { | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16) { | |||
NormalRNG rng(1); | |||
checker_conv_bias_f16( | |||
get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false), | |||
handle(), rng, "F16DIRECT_LARGE_GROUP", 0.03); | |||
handle(), rng, "F16DIRECT", 0.03); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_SMALL_GROUP) { | |||
NormalRNG rng(1); | |||
checker_conv_bias_f16( | |||
get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false), | |||
handle(), rng, "F16DIRECT_SMALL_GROUP", 0.03); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_STR1_LARGE_GROUP) { | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_STR1) { | |||
NormalRNG rng(1); | |||
checker_conv_bias_f16(get_conv_bias_args({2, 3, 5}, 1, false, false, false), | |||
handle(), rng, "F16STRD1_LARGE_GROUP", 0.03); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_STR1_SMALL_GROUP) { | |||
NormalRNG rng(1); | |||
checker_conv_bias_f16(get_conv_bias_args({2, 3, 5}, 1, false, false, false), | |||
handle(), rng, "F16STRD1_SMALL_GROUP", 0.03); | |||
handle(), rng, "F16STRD1", 0.03); | |||
} | |||
#endif | |||
/**********************************algo 8816 direct************************/ | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_DIRECT_LARGE_GROUP) { | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_DIRECT) { | |||
checker_conv_bias_int8x8x16( | |||
get_conv_bias_args({2, 3, 5}, 1, false, true, true), handle(), | |||
"I8816DIRECT_LARGE_GROUP"); | |||
"I8816DIRECT"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_DIRECT_SMALL_GROUP) { | |||
checker_conv_bias_int8x8x16( | |||
get_conv_bias_args({2, 3, 5}, 1, false, true, true), handle(), | |||
"I8816DIRECT_SMALL_GROUP"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_STRIDE2_LARGE_GROUP) { | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_STRIDE2) { | |||
checker_conv_bias_int8x8x16( | |||
get_conv_bias_args({2, 3, 5}, 2, false, true, true), handle(), | |||
"I8816STRD2_LARGE_GROUP"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_STRIDE2_SMALL_GROUP) { | |||
checker_conv_bias_int8x8x16( | |||
get_conv_bias_args({2, 3, 5}, 2, false, true, true), handle(), | |||
"I8816STRD2_SMALL_GROUP"); | |||
"I8816STRD2"); | |||
} | |||
/**********************************algo 8-8-32 direct************************/ | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE1_LARGE_GROUP) { | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE1) { | |||
checker_conv_bias_int8x8x32_multi( | |||
get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(), | |||
"S8STRD1_LARGE_GROUP"); | |||
"S8STRD1"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE1_SMALL_GROUP) { | |||
checker_conv_bias_int8x8x32_multi( | |||
get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(), | |||
"S8STRD1_SMALL_GROUP"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE2_LARGE_GROUP) { | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE2) { | |||
checker_conv_bias_int8x8x32_multi( | |||
get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(), | |||
"S8STRD2_LARGE_GROUP"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE2_SMALL_GROUP) { | |||
checker_conv_bias_int8x8x32_multi( | |||
get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(), | |||
"S8STRD2_SMALL_GROUP"); | |||
"S8STRD2"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, | |||
@@ -520,25 +476,15 @@ TEST_F(ARM_COMMON_MULTI_THREADS, | |||
} | |||
/********************************qint8 direct******************************/ | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_LARGE_GROUP) { | |||
checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args( | |||
{2, 3, 5, 7}, 1, false, false, false), | |||
handle(), "S8STRD1_LARGE_GROUP"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_SMALL_GROUP) { | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1) { | |||
checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args( | |||
{2, 3, 5, 7}, 1, false, false, false), | |||
handle(), "S8STRD1_SMALL_GROUP"); | |||
handle(), "S8STRD1"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_LARGE_GROUP) { | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2) { | |||
checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args( | |||
{2, 3, 5, 7}, 2, false, false, false), | |||
handle(), "S8STRD2_LARGE_GROUP"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_SMALL_GROUP) { | |||
checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args( | |||
{2, 3, 5, 7}, 2, false, false, false), | |||
handle(), "S8STRD2_SMALL_GROUP"); | |||
handle(), "S8STRD2"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44) { | |||
checker_conv_bias_qint8x8x8( | |||
@@ -586,25 +532,15 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44_S2) { | |||
} | |||
/*****************************quint8 direct****************************/ | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE1_LARGE_GROUP) { | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE1) { | |||
checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args( | |||
{2, 3, 5, 7}, 1, false, false, false), | |||
handle(), "QU8STRD1_LARGE_GROUP"); | |||
handle(), "QU8STRD1"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE1_SMALL_GROUP) { | |||
checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args( | |||
{2, 3, 5, 7}, 1, false, false, false), | |||
handle(), "QU8STRD1_SMALL_GROUP"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2_LARGE_GROUP) { | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2) { | |||
checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args( | |||
{2, 3, 5, 7}, 2, false, false, false), | |||
handle(), "QU8STRD2_LARGE_GROUP"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2_SMALL_GROUP) { | |||
checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args( | |||
{2, 3, 5, 7}, 2, false, false, false), | |||
handle(), "QU8STRD2_SMALL_GROUP"); | |||
handle(), "QU8STRD2"); | |||
} | |||
/****************************dot qint8 direct*************************/ | |||
@@ -624,100 +560,53 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_NCHW_NCHW44) { | |||
} | |||
checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8_NCHW_NCHW44"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, | |||
CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_LARGE_GROUP) { | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_WITHDOTPROD) { | |||
checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args( | |||
{2, 3, 5, 7}, 1, false, false, false), | |||
handle(), "ARMDOTS8STRD1_LARGE_GROUP"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, | |||
CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_SMALL_GROUP) { | |||
checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args( | |||
{2, 3, 5, 7}, 1, false, false, false), | |||
handle(), "ARMDOTS8STRD1_SMALL_GROUP"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, | |||
CONV_BIAS_INT8_STRIDE2_WITHDOTPROD_LARGE_GROUP) { | |||
checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args( | |||
{2, 3, 5, 7}, 2, false, false, false), | |||
handle(), "ARMDOTS8STRD2_LARGE_GROUP"); | |||
handle(), "ARMDOTS8STRD1"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, | |||
CONV_BIAS_INT8_STRIDE2_WITHDOTPROD_SMALL_GROUP) { | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_WITHDOTPROD) { | |||
checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args( | |||
{2, 3, 5, 7}, 2, false, false, false), | |||
handle(), "ARMDOTS8STRD2_SMALL_GROUP"); | |||
handle(), "ARMDOTS8STRD2"); | |||
} | |||
/****************************dot 8-8-32 direct*************************/ | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD1_WITHDOT_LARGE_GROUP) { | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD1_WITHDOT) { | |||
checker_conv_bias_qint8x8x32( | |||
get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(), | |||
"ARMDOTS8STRD1_LARGE_GROUP"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD1_WITHDOT_SMALL_GROUP) { | |||
checker_conv_bias_qint8x8x32( | |||
get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(), | |||
"ARMDOTS8STRD1_SMALL_GROUP"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD2_WITHDOT_LARGE_GROUP) { | |||
checker_conv_bias_qint8x8x32( | |||
get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(), | |||
"ARMDOTS8STRD2_LARGE_GROUP"); | |||
"ARMDOTS8STRD1"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD2_WITHDOT_SMALL_GROUP) { | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD2_WITHDOT) { | |||
checker_conv_bias_qint8x8x32( | |||
get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(), | |||
"ARMDOTS8STRD2_SMALL_GROUP"); | |||
"ARMDOTS8STRD2"); | |||
} | |||
/******************************dot quint8*****************************/ | |||
TEST_F(ARM_COMMON_MULTI_THREADS, | |||
CONV_BIAS_QUINT8_STRIDE1_WITHDOTPROD_LARGE_GROUP) { | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE1_WITHDOTPROD) { | |||
checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args( | |||
{2, 3, 5, 7}, 1, false, false, false), | |||
handle(), "ARMDOTU8STRD1_LARGE_GROUP"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, | |||
CONV_BIAS_QUINT8_STRIDE1_WITHDOTPROD_SMALL_GROUP) { | |||
checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args( | |||
{2, 3, 5, 7}, 1, false, false, false), | |||
handle(), "ARMDOTU8STRD1_SMALL_GROUP"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, | |||
CONV_BIAS_QUINT8_STRIDE2_WITHDOTPROD_LARGE_GROUP) { | |||
checker_conv_bias_quint8x8x8( | |||
get_int8_quint8_conv_bias_args({2, 5, 7}, 2, false, false, false), | |||
handle(), "ARMDOTU8STRD2_LARGE_GROUP"); | |||
handle(), "ARMDOTU8STRD1"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, | |||
CONV_BIAS_QUINT8_STRIDE2_WITHDOTPROD_SMALL_GROUP) { | |||
//! TODO: this test without test kernel size=3, add it will case buss error now | |||
//! in armv7 | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2_WITHDOTPROD) { | |||
checker_conv_bias_quint8x8x8( | |||
get_int8_quint8_conv_bias_args({2, 5, 7}, 2, false, false, false), | |||
handle(), "ARMDOTU8STRD2_SMALL_GROUP"); | |||
handle(), "ARMDOTU8STRD2"); | |||
} | |||
/******************************dot quint8x8x32***********************/ | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE1_LARGE_GROUP) { | |||
checker_conv_bias_quint8x8x32( | |||
get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(), | |||
"ARMDOTU8STRD1_LARGE_GROUP"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE1_SMALL_GROUP) { | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE1) { | |||
checker_conv_bias_quint8x8x32( | |||
get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(), | |||
"ARMDOTU8STRD1_SMALL_GROUP"); | |||
"ARMDOTU8STRD1"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE2_LARGE_GROUP) { | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE2) { | |||
checker_conv_bias_quint8x8x32( | |||
get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(), | |||
"ARMDOTU8STRD2_LARGE_GROUP"); | |||
"ARMDOTU8STRD2"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE2_SMALL_GROUP) { | |||
checker_conv_bias_quint8x8x32( | |||
get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(), | |||
"ARMDOTU8STRD2_SMALL_GROUP"); | |||
} | |||
/******************************dot int8x8x8 nchw44 ***********************/ | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S1_Q8x8x8) { | |||
using namespace conv_bias; | |||
@@ -125,7 +125,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF32) { | |||
bench_case(1, 32, 32, 80, 80, 3, 4); | |||
bench_case(1, 32, 32, 80, 80, 3, 32); | |||
std::string algo_name = "F32DIRECT_LARGE_GROUP"; | |||
std::string algo_name = "F32DIRECT"; | |||
printf("Benchmark F32DIRECT_LARGE_GROUP algo\n"); | |||
std::vector<DType> data_type = {dtype::Float32(), dtype::Float32(), | |||
dtype::Float32(), dtype::Float32()}; | |||
@@ -137,7 +137,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF32) { | |||
{1, {4}}, data_type); | |||
shapes_and_computation.clear(); | |||
algo_name = "F32DIRECT_SMALL_GROUP"; | |||
algo_name = "F32DIRECT"; | |||
printf("Benchmark F32DIRECT_SMALL_GROUP algo\n"); | |||
bench_case(1, 32, 32, 200, 200, 3, 1); | |||
bench_case(1, 32, 32, 128, 128, 3, 1); | |||
@@ -186,7 +186,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF32_STR1) { | |||
bench_case(1, 32, 32, 80, 80, 3, 4); | |||
bench_case(1, 32, 32, 80, 80, 3, 32); | |||
std::string algo_name = "F32STRD1_LARGE_GROUP"; | |||
std::string algo_name = "F32STRD1"; | |||
printf("Benchmark F32STRD1_LARGE_GROUP algo\n"); | |||
std::vector<DType> data_type = {dtype::Float32(), dtype::Float32(), | |||
dtype::Float32(), dtype::Float32()}; | |||
@@ -198,7 +198,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF32_STR1) { | |||
{1, {4}}, data_type); | |||
shapes_and_computation.clear(); | |||
algo_name = "F32STRD1_SMALL_GROUP"; | |||
algo_name = "F32STRD1"; | |||
printf("Benchmark F32STRD1_SMALL_GROUP algo\n"); | |||
bench_case(1, 32, 32, 200, 200, 3, 1); | |||
bench_case(1, 32, 32, 128, 128, 3, 1); | |||
@@ -249,7 +249,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF32_STR2) { | |||
bench_case(1, 32, 32, 80, 80, 3, 4, 1, 2); | |||
bench_case(1, 32, 32, 80, 80, 3, 32, 1, 2); | |||
std::string algo_name = "F32STRD2_LARGE_GROUP"; | |||
std::string algo_name = "F32STRD2"; | |||
printf("Benchmark F32STRD2_LARGE_GROUP algo\n"); | |||
std::vector<DType> data_type = {dtype::Float32(), dtype::Float32(), | |||
dtype::Float32(), dtype::Float32()}; | |||
@@ -261,7 +261,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF32_STR2) { | |||
{1, {4}}, data_type); | |||
shapes_and_computation.clear(); | |||
algo_name = "F32STRD2_SMALL_GROUP"; | |||
algo_name = "F32STRD2"; | |||
printf("Benchmark F32STRD2_SMALL_GROUP algo\n"); | |||
bench_case(1, 32, 32, 200, 200, 3, 1, 1, 2); | |||
bench_case(1, 32, 32, 128, 128, 3, 1, 1, 2); | |||
@@ -313,7 +313,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF16) { | |||
bench_case(1, 32, 32, 80, 80, 3, 4); | |||
bench_case(1, 32, 32, 80, 80, 3, 32); | |||
std::string algo_name = "F16DIRECT_LARGE_GROUP"; | |||
std::string algo_name = "F16DIRECT"; | |||
printf("Benchmark F16DIRECT_LARGE_GROUP algo\n"); | |||
std::vector<DType> data_type = {dtype::Float16(), dtype::Float16(), | |||
dtype::Float16(), dtype::Float16()}; | |||
@@ -325,7 +325,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF16) { | |||
{1, {4}}, data_type); | |||
shapes_and_computation.clear(); | |||
algo_name = "F16DIRECT_SMALL_GROUP"; | |||
algo_name = "F16DIRECT"; | |||
printf("Benchmark F16DIRECT_SMALL_GROUP algo\n"); | |||
bench_case(1, 32, 32, 200, 200, 3, 1); | |||
bench_case(1, 32, 32, 128, 128, 3, 1); | |||
@@ -375,7 +375,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF16_STR1) { | |||
bench_case(1, 32, 32, 80, 80, 3, 4); | |||
bench_case(1, 32, 32, 80, 80, 3, 32); | |||
std::string algo_name = "F16STRD1_LARGE_GROUP"; | |||
std::string algo_name = "F16STRD1"; | |||
printf("Benchmark F16STRD1_LARGE_GROUP algo\n"); | |||
std::vector<DType> data_type = {dtype::Float16(), dtype::Float16(), | |||
dtype::Float16(), dtype::Float16()}; | |||
@@ -387,7 +387,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF16_STR1) { | |||
{1, {4}}, data_type); | |||
shapes_and_computation.clear(); | |||
algo_name = "F16STRD1_SMALL_GROUP"; | |||
algo_name = "F16STRD1"; | |||
printf("Benchmark F16STRD1_SMALL_GROUP algo\n"); | |||
bench_case(1, 32, 32, 200, 200, 3, 1); | |||
bench_case(1, 32, 32, 128, 128, 3, 1); | |||
@@ -439,7 +439,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | |||
bench_case(1, 32, 32, 80, 80, 3, 4); | |||
bench_case(1, 32, 32, 80, 80, 3, 32); | |||
std::string algo_name = "I8816DIRECT_LARGE_GROUP"; | |||
std::string algo_name = "I8816DIRECT"; | |||
printf("Benchmark I8816DIRECT_LARGE_GROUP algo\n"); | |||
std::vector<DType> data_type = {dtype::Int8(), dtype::Int8(), | |||
dtype::Int16(), dtype::Int16()}; | |||
@@ -451,7 +451,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | |||
{1, {4}}, data_type); | |||
shapes_and_computation.clear(); | |||
algo_name = "I8816DIRECT_SMALL_GROUP"; | |||
algo_name = "I8816DIRECT"; | |||
printf("Benchmark I8816DIRECT_SMALL_GROUP algo\n"); | |||
bench_case(1, 32, 32, 200, 200, 3, 1); | |||
bench_case(1, 32, 32, 128, 128, 3, 1); | |||
@@ -503,7 +503,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | |||
bench_case(1, 32, 32, 80, 80, 3, 4, 1, 2); | |||
bench_case(1, 32, 32, 80, 80, 3, 32, 1, 2); | |||
std::string algo_name = "I8816STRD2_LARGE_GROUP"; | |||
std::string algo_name = "I8816STRD2"; | |||
printf("Benchmark I8816STRD2_LARGE_GROUP algo\n"); | |||
std::vector<DType> data_type = {dtype::Int8(), dtype::Int8(), | |||
dtype::Int16(), dtype::Int16()}; | |||
@@ -515,7 +515,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | |||
{1, {4}}, data_type); | |||
shapes_and_computation.clear(); | |||
algo_name = "I8816STRD2_SMALL_GROUP"; | |||
algo_name = "I8816STRD2"; | |||
printf("Benchmark I8816STRD2_SMALL_GROUP algo\n"); | |||
bench_case(1, 32, 32, 200, 200, 3, 1, 1, 2); | |||
bench_case(1, 32, 32, 128, 128, 3, 1, 1, 2); | |||
@@ -567,7 +567,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | |||
bench_case(1, 32, 32, 80, 80, 3, 4, 1, 1); | |||
bench_case(1, 32, 32, 80, 80, 3, 32, 1, 1); | |||
std::string algo_name = "S8STRD1_LARGE_GROUP"; | |||
std::string algo_name = "S8STRD1"; | |||
printf("Benchmark S8STRD1_LARGE_GROUP algo\n"); | |||
std::vector<DType> data_type = { | |||
dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f), | |||
@@ -580,7 +580,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | |||
{1, {4}}, data_type); | |||
shapes_and_computation.clear(); | |||
algo_name = "S8STRD1_SMALL_GROUP"; | |||
algo_name = "S8STRD1"; | |||
printf("Benchmark S8STRD1_SMALL_GROUP algo\n"); | |||
bench_case(1, 32, 32, 200, 200, 3, 1, 1, 1); | |||
bench_case(1, 32, 32, 128, 128, 3, 1, 1, 1); | |||
@@ -866,7 +866,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | |||
bench_case(1, 32, 32, 80, 80, 3, 4, 1, 2); | |||
bench_case(1, 32, 32, 80, 80, 3, 32, 1, 2); | |||
std::string algo_name = "S8STRD2_LARGE_GROUP"; | |||
std::string algo_name = "S8STRD2"; | |||
printf("Benchmark S8STRD2_LARGE_GROUP algo\n"); | |||
std::vector<DType> data_type = { | |||
dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f), | |||
@@ -879,7 +879,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | |||
{1, {4}}, data_type); | |||
shapes_and_computation.clear(); | |||
algo_name = "S8STRD2_SMALL_GROUP"; | |||
algo_name = "S8STRD2"; | |||
printf("Benchmark S8STRD2_SMALL_GROUP algo\n"); | |||
bench_case(1, 32, 32, 200, 200, 3, 1, 1, 2); | |||
bench_case(1, 32, 32, 128, 128, 3, 1, 1, 2); | |||
@@ -932,7 +932,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | |||
bench_case(1, 32, 32, 80, 80, 3, 4, 1, 1); | |||
bench_case(1, 32, 32, 80, 80, 3, 32, 1, 1); | |||
std::string algo_name = "ARMDOTS8STRD1_LARGE_GROUP"; | |||
std::string algo_name = "ARMDOTS8STRD1"; | |||
printf("Benchmark ARMDOTS8STRD1_LARGE_GROUP algo\n"); | |||
std::vector<DType> data_type = { | |||
dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f), | |||
@@ -945,7 +945,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | |||
{1, {4}}, data_type); | |||
shapes_and_computation.clear(); | |||
algo_name = "ARMDOTS8STRD1_SMALL_GROUP"; | |||
algo_name = "ARMDOTS8STRD1"; | |||
printf("Benchmark ARMDOTS8STRD1_SMALL_GROUP algo\n"); | |||
bench_case(1, 32, 32, 200, 200, 3, 1, 1, 1); | |||
bench_case(1, 32, 32, 128, 128, 3, 1, 1, 1); | |||
@@ -997,7 +997,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | |||
bench_case(1, 32, 32, 80, 80, 3, 4, 1, 2); | |||
bench_case(1, 32, 32, 80, 80, 3, 32, 1, 2); | |||
std::string algo_name = "ARMDOTS8STRD2_LARGE_GROUP"; | |||
std::string algo_name = "ARMDOTS8STRD2"; | |||
printf("Benchmark ARMDOTS8STRD2_LARGE_GROUP algo\n"); | |||
std::vector<DType> data_type = { | |||
dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f), | |||
@@ -1010,7 +1010,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | |||
{1, {4}}, data_type); | |||
shapes_and_computation.clear(); | |||
algo_name = "ARMDOTS8STRD2_SMALL_GROUP"; | |||
algo_name = "ARMDOTS8STRD2"; | |||
printf("Benchmark ARMDOTS8STRD2_SMALL_GROUP algo\n"); | |||
bench_case(1, 32, 32, 200, 200, 3, 1, 1, 2); | |||
bench_case(1, 32, 32, 128, 128, 3, 1, 1, 2); | |||
@@ -1064,7 +1064,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | |||
bench_case(1, 32, 32, 80, 80, 3, 4, 1, 1); | |||
bench_case(1, 32, 32, 80, 80, 3, 32, 1, 1); | |||
std::string algo_name = "QU8STRD1_LARGE_GROUP"; | |||
std::string algo_name = "QU8STRD1"; | |||
printf("Benchmark QU8STRD1_LARGE_GROUP algo\n"); | |||
std::vector<DType> data_type = {dtype::Quantized8Asymm(0.2f, 100), | |||
dtype::Quantized8Asymm(0.2f, 120), | |||
@@ -1078,7 +1078,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | |||
{1, {4}}, data_type); | |||
shapes_and_computation.clear(); | |||
algo_name = "QU8STRD1_SMALL_GROUP"; | |||
algo_name = "QU8STRD1"; | |||
printf("Benchmark QU8STRD1_SMALL_GROUP algo\n"); | |||
bench_case(1, 32, 32, 200, 200, 3, 1, 1, 1); | |||
bench_case(1, 32, 32, 128, 128, 3, 1, 1, 1); | |||
@@ -1130,7 +1130,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | |||
bench_case(1, 32, 32, 80, 80, 3, 4, 1, 2); | |||
bench_case(1, 32, 32, 80, 80, 3, 32, 1, 2); | |||
std::string algo_name = "QU8STRD2_LARGE_GROUP"; | |||
std::string algo_name = "QU8STRD2"; | |||
printf("Benchmark QU8STRD2_LARGE_GROUP algo\n"); | |||
std::vector<DType> data_type = {dtype::Quantized8Asymm(0.2f, 100), | |||
dtype::Quantized8Asymm(0.2f, 120), | |||
@@ -1144,7 +1144,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | |||
{1, {4}}, data_type); | |||
shapes_and_computation.clear(); | |||
algo_name = "QU8STRD2_SMALL_GROUP"; | |||
algo_name = "QU8STRD2"; | |||
printf("Benchmark QU8STRD2_SMALL_GROUP algo\n"); | |||
bench_case(1, 32, 32, 200, 200, 3, 1, 1, 2); | |||
bench_case(1, 32, 32, 128, 128, 3, 1, 1, 2); | |||
@@ -1198,7 +1198,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | |||
bench_case(1, 32, 32, 80, 80, 3, 4, 1, 1); | |||
bench_case(1, 32, 32, 80, 80, 3, 32, 1, 1); | |||
std::string algo_name = "ARMDOTU8STRD1_LARGE_GROUP"; | |||
std::string algo_name = "ARMDOTU8STRD1"; | |||
printf("Benchmark ARMDOTU8STRD1_LARGE_GROUP algo\n"); | |||
std::vector<DType> data_type = {dtype::Quantized8Asymm(0.2f, 100), | |||
dtype::Quantized8Asymm(0.2f, 120), | |||
@@ -1212,7 +1212,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | |||
{1, {4}}, data_type); | |||
shapes_and_computation.clear(); | |||
algo_name = "ARMDOTU8STRD1_SMALL_GROUP"; | |||
algo_name = "ARMDOTU8STRD1"; | |||
printf("Benchmark ARMDOTS8STRD1_SMALL_GROUP algo\n"); | |||
bench_case(1, 32, 32, 200, 200, 3, 1, 1, 1); | |||
bench_case(1, 32, 32, 128, 128, 3, 1, 1, 1); | |||
@@ -1265,7 +1265,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | |||
bench_case(1, 32, 32, 80, 80, 5, 4, 1, 2); | |||
bench_case(1, 32, 32, 80, 80, 5, 32, 1, 2); | |||
std::string algo_name = "ARMDOTU8STRD2_LARGE_GROUP"; | |||
std::string algo_name = "ARMDOTU8STRD2"; | |||
printf("Benchmark ARMDOTU8STRD2_LARGE_GROUP algo\n"); | |||
std::vector<DType> data_type = {dtype::Quantized8Asymm(0.2f, 100), | |||
dtype::Quantized8Asymm(0.2f, 120), | |||
@@ -1279,7 +1279,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | |||
{1, {4}}, data_type); | |||
shapes_and_computation.clear(); | |||
algo_name = "ARMDOTU8STRD2_SMALL_GROUP"; | |||
algo_name = "ARMDOTU8STRD2"; | |||
printf("Benchmark ARMDOTU8STRD2_SMALL_GROUP algo\n"); | |||
bench_case(1, 32, 32, 200, 200, 5, 1, 1, 2); | |||
bench_case(1, 32, 32, 128, 128, 5, 1, 1, 2); | |||
@@ -176,7 +176,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CONVOLUTION_STRIDE1_I8x8x32_WITHDOTPROD) { | |||
constexpr size_t RUN = 50; | |||
Benchmarker<Convolution> benchmark(handle()); | |||
benchmark.set_before_exec_callback( | |||
AlgoChecker<Convolution>("CONVOLUTION_DEFAULT_ARMDOTS8STRD1_SMALL_GROUP")); | |||
AlgoChecker<Convolution>("CONVOLUTION_DEFAULT_ARMDOTS8STRD1")); | |||
benchmark.set_dtype(0, dtype::Int8()) | |||
.set_dtype(1, dtype::Int8()) | |||
.set_dtype(2, dtype::Int32()); | |||
@@ -243,7 +243,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CONVOLUTION_STRIDE2_I8x8x32_WITHDOTPROD) { | |||
constexpr size_t RUN = 10; | |||
Benchmarker<Convolution> benchmark(handle()); | |||
benchmark.set_before_exec_callback( | |||
AlgoChecker<Convolution>("CONVOLUTION_DEFAULT_ARMDOTS8STRD2_SMALL_GROUP")); | |||
AlgoChecker<Convolution>("CONVOLUTION_DEFAULT_ARMDOTS8STRD2")); | |||
benchmark.set_dtype(0, dtype::Int8()) | |||
.set_dtype(1, dtype::Int8()) | |||
.set_dtype(2, dtype::Int32()); | |||
@@ -317,7 +317,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CONVOLUTION_STRIDE1_QUINT8_WITHDOTPROD) { | |||
benchmark.set_display(false); | |||
benchmark.set_times(RUN); | |||
benchmark.set_before_exec_callback(AlgoChecker<ConvolutionForward>( | |||
"CONVOLUTION_DEFAULT_ARMDOTU8STRD1_SMALL_GROUP")); | |||
"CONVOLUTION_DEFAULT_ARMDOTU8STRD1")); | |||
Benchmarker<Convolution> benchmark_float(handle()); | |||
benchmark_float.set_display(false); | |||
@@ -387,7 +387,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CONVOLUTION_STRIDE2_QUINT8_WITHDOTPROD) { | |||
benchmark.set_display(false); | |||
benchmark.set_times(RUN); | |||
benchmark.set_before_exec_callback(AlgoChecker<ConvolutionForward>( | |||
"CONVOLUTION_DEFAULT_ARMDOTU8STRD2_SMALL_GROUP")); | |||
"CONVOLUTION_DEFAULT_ARMDOTU8STRD2")); | |||
Benchmarker<Convolution> benchmark_float(handle()); | |||
benchmark_float.set_display(false); | |||
@@ -583,7 +583,7 @@ TEST_F(X86_MULTI_THREADS, AVX2_CONV_BIAS_DIRECT_STRIDE2_S8S8S8) { | |||
} | |||
} | |||
TEST_F(X86_MULTI_THREADS, CONV_BIAS_DIRECT_STRIDE1_SMALL_GROUP) { | |||
TEST_F(X86_MULTI_THREADS, CONV_BIAS_DIRECT_STRIDE1_DENSE) { | |||
using namespace conv_bias; | |||
std::vector<TestArg> args; | |||
@@ -633,19 +633,19 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_DIRECT_STRIDE1_SMALL_GROUP) { | |||
.set_rng(2, &rng); | |||
checker.set_before_exec_callback( | |||
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>( | |||
"X86_CONV_BIAS_DIRECT_STRIDE1_SMALL_GROUP")); | |||
"X86_CONV_BIAS_DIRECT_STRIDE1_LARGE_GROUP")); | |||
for (auto&& arg : args) { | |||
checker.set_param(arg.param).exec( | |||
{arg.src, arg.filter, arg.bias, {}, {}}); | |||
} | |||
} | |||
TEST_F(X86_MULTI_THREADS, CONV_BIAS_DIRECT_STRIDE1_LARGE_GROUP) { | |||
TEST_F(X86_MULTI_THREADS, CONV_BIAS_DIRECT_STRIDE1_GROUP) { | |||
using namespace conv_bias; | |||
std::vector<TestArg> args; | |||
auto run = [&](size_t oc, size_t ic, size_t w, size_t h, size_t kernel, | |||
size_t p, NonlineMode nonline_mode) { | |||
auto run = [&](size_t group, size_t channel, size_t w, size_t h, | |||
size_t kernel, size_t p, NonlineMode nonline_mode) { | |||
if (w + 2 * p < kernel || h + 2 * p < kernel) | |||
return; | |||
param::ConvBias param; | |||
@@ -654,30 +654,37 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_DIRECT_STRIDE1_LARGE_GROUP) { | |||
param.pad_h = p; | |||
param.pad_w = p; | |||
param.nonlineMode = nonline_mode; | |||
param.sparse = param::ConvBias::Sparse::GROUP; | |||
//! no bias | |||
args.emplace_back(param, TensorShape{1, ic, h, w}, | |||
TensorShape{oc, ic, kernel, kernel}, TensorShape{}); | |||
args.emplace_back( | |||
param, TensorShape{1, channel, h, w}, | |||
TensorShape{group, channel / group, channel / group, kernel, kernel}, | |||
TensorShape{}); | |||
//! bias channel | |||
args.emplace_back(param, TensorShape{2, ic, h, w}, | |||
TensorShape{oc, ic, kernel, kernel}, | |||
TensorShape{1, oc, 1, 1}); | |||
args.emplace_back(param, TensorShape{2, channel, h, w}, | |||
TensorShape{group, channel / group, channel / group, | |||
kernel, kernel}, | |||
TensorShape{1, channel, 1, 1}); | |||
//! bias | |||
args.emplace_back(param, TensorShape{2, ic, h, w}, | |||
TensorShape{oc, ic, kernel, kernel}, | |||
TensorShape{2, oc, (h + param.pad_h * 2 - kernel) + 1, | |||
(w + param.pad_w * 2 - kernel) + 1}); | |||
args.emplace_back( | |||
param, TensorShape{2, channel, h, w}, | |||
TensorShape{group, channel / group, channel / group, kernel, | |||
kernel}, | |||
TensorShape{2, channel, (h + param.pad_h * 2 - kernel) + 1, | |||
(w + param.pad_w * 2 - kernel) + 1}); | |||
}; | |||
for (size_t kernel : {1, 2, 3, 4, 5, 6, 7}) | |||
for (size_t ic : {1, 4, 8, 16}) | |||
for (size_t oc : {1, 4, 8}) | |||
for (size_t channel : {4, 8, 16}) | |||
for (size_t group : {1, 2, 4}) | |||
for (size_t p : {0, 2}) | |||
for (size_t size : {20, 21, 24}) | |||
for (NonlineMode nonline_mode : | |||
{NonlineMode::RELU, NonlineMode::SIGMOID, | |||
NonlineMode::H_SWISH, NonlineMode::IDENTITY}) { | |||
run(oc, ic, size, size, kernel, p, nonline_mode); | |||
run(group, channel, size, size, kernel, p, | |||
nonline_mode); | |||
} | |||
Checker<ConvBias> checker(handle()); | |||
@@ -697,7 +704,7 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_DIRECT_STRIDE1_LARGE_GROUP) { | |||
} | |||
} | |||
TEST_F(X86_MULTI_THREADS, CONV_BIAS_DIRECT_STRIDE2) { | |||
TEST_F(X86_MULTI_THREADS, CONV_BIAS_DIRECT_STRIDE2_DENSE) { | |||
using namespace conv_bias; | |||
std::vector<TestArg> args; | |||
@@ -738,11 +745,68 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_DIRECT_STRIDE2) { | |||
.set_rng(2, &rng); | |||
checker.set_before_exec_callback( | |||
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>( | |||
"X86_CONV_BIAS_DIRECT_STRIDE2_SMALL_GROUP")); | |||
"X86_CONV_BIAS_DIRECT_STRIDE2_LARGE_GROUP")); | |||
for (auto&& arg : args) { | |||
checker.set_param(arg.param).exec( | |||
{arg.src, arg.filter, arg.bias, {}, {}}); | |||
} | |||
} | |||
TEST_F(X86_MULTI_THREADS, CONV_BIAS_DIRECT_STRIDE2_GROUP) { | |||
using namespace conv_bias; | |||
std::vector<TestArg> args; | |||
auto run = [&](size_t group, size_t channel, size_t w, size_t h, | |||
size_t kernel, size_t p, NonlineMode nonline_mode) { | |||
if (w + 2 * p < kernel || h + 2 * p < kernel) | |||
return; | |||
param::ConvBias param; | |||
param.stride_h = 2; | |||
param.stride_w = 2; | |||
param.pad_h = p; | |||
param.pad_w = p; | |||
param.nonlineMode = nonline_mode; | |||
param.sparse = param::ConvBias::Sparse::GROUP; | |||
//! no bias | |||
args.emplace_back( | |||
param, TensorShape{1, channel, h, w}, | |||
TensorShape{group, channel / group, channel / group, kernel, kernel}, | |||
TensorShape{}); | |||
//! bias channel | |||
args.emplace_back(param, TensorShape{2, channel, h, w}, | |||
TensorShape{group, channel / group, channel / group, | |||
kernel, kernel}, | |||
TensorShape{1, channel, 1, 1}); | |||
//! bias | |||
args.emplace_back( | |||
param, TensorShape{2, channel, h, w}, | |||
TensorShape{group, channel / group, channel / group, kernel, | |||
kernel}, | |||
TensorShape{2, channel, (h + param.pad_h * 2 - kernel) / 2 + 1, | |||
(w + param.pad_w * 2 - kernel) / 2 + 1}); | |||
}; | |||
for (size_t kernel : {2, 3, 5, 7}) | |||
for (size_t channel : {4, 8, 16}) | |||
for (size_t group : {1, 2, 4}) | |||
for (size_t p : {0, 2}) | |||
for (size_t size : {20, 21, 24}) | |||
for (NonlineMode nonline_mode : | |||
{NonlineMode::RELU, NonlineMode::SIGMOID, | |||
NonlineMode::H_SWISH, NonlineMode::IDENTITY}) { | |||
run(group, channel, size, size, kernel, p, | |||
nonline_mode); | |||
} | |||
Checker<ConvBias> checker(handle()); | |||
UniformIntRNG rng{-50, 50}; | |||
checker.set_dtype(0, dtype::Float32()) | |||
.set_dtype(1, dtype::Float32()) | |||
.set_dtype(2, dtype::Float32()) | |||
.set_rng(0, &rng) | |||
.set_rng(1, &rng) | |||
.set_rng(2, &rng); | |||
checker.set_before_exec_callback( | |||
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>( | |||
"X86_CONV_BIAS_DIRECT_STRIDE2_LARGE_GROUP")); | |||
@@ -2502,7 +2566,7 @@ TEST_F(X86_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF32) { | |||
bench_case(1, 32, 32, 80, 80, 3, 32); | |||
std::string algo_name = "X86_CONV_BIAS_DIRECT_STRIDE1_LARGE_GROUP"; | |||
printf("Benchmark X86_CONV_BIAS_DIRECT_STRIDE1_LARGE_GROUP algo\n"); | |||
printf("Benchmark X86_CONV_BIAS_DIRECT_STRIDE1_GROUP algo\n"); | |||
benchmark_impl(param, shapes_and_computation, algo_name, RUNS, | |||
{4, {4, 5, 6, 7}}, {1, {4}}, data_type); | |||
benchmark_impl(param, shapes_and_computation, algo_name, RUNS, | |||
@@ -2511,8 +2575,8 @@ TEST_F(X86_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF32) { | |||
{1, {4}}, data_type); | |||
shapes_and_computation.clear(); | |||
algo_name = "X86_CONV_BIAS_DIRECT_STRIDE1_SMALL_GROUP"; | |||
printf("Benchmark X86_CONV_BIAS_DIRECT_STRIDE1_SMALL_GROUP algo\n"); | |||
algo_name = "X86_CONV_BIAS_DIRECT_STRIDE1_LARGE_GROUP"; | |||
printf("Benchmark X86_CONV_BIAS_DIRECT_STRIDE1_DENSE algo\n"); | |||
bench_case(1, 32, 32, 200, 200, 3, 1); | |||
bench_case(1, 32, 32, 128, 128, 3, 1); | |||
bench_case(1, 32, 32, 100, 100, 3, 1); | |||
@@ -125,7 +125,7 @@ TEST_F(X86, DEFAULT_CONV_DIRECT_STRIDE1) { | |||
Checker<ConvolutionForward> checker(handle()); | |||
checker.set_before_exec_callback(AlgoChecker<ConvolutionForward>( | |||
"CONVOLUTION_DEFAULT_X86_CONV_BIAS_DIRECT_STRIDE1_SMALL_GROUP")); | |||
"CONVOLUTION_DEFAULT_X86_CONV_BIAS_DIRECT_STRIDE1_LARGE_GROUP")); | |||
checker.set_epsilon(1); | |||
UniformIntRNG rng{-50, 50}; | |||
checker.set_dtype(0, dtype::Float32()) | |||
@@ -167,7 +167,7 @@ TEST_F(X86, DEFAULT_CONV_DIRECT_STRIDE2) { | |||
Checker<ConvolutionForward> checker(handle()); | |||
checker.set_before_exec_callback(AlgoChecker<ConvolutionForward>( | |||
"CONVOLUTION_DEFAULT_X86_CONV_BIAS_DIRECT_STRIDE2_SMALL_GROUP")); | |||
"CONVOLUTION_DEFAULT_X86_CONV_BIAS_DIRECT_STRIDE2_LARGE_GROUP")); | |||
checker.set_epsilon(1); | |||
UniformIntRNG rng{-50, 50}; | |||
checker.set_dtype(0, dtype::Float32()) | |||