GitOrigin-RevId: d195f44dec
tags/v1.0.0-rc1
@@ -22,26 +22,19 @@ using namespace aarch64; | |||||
/* ===================== stride-2 algo ===================== */ | /* ===================== stride-2 algo ===================== */ | ||||
MIDOUT_DECL(megdnn_aarch64_conv_bias_stride2_conv2357_fp16) | MIDOUT_DECL(megdnn_aarch64_conv_bias_stride2_conv2357_fp16) | ||||
bool ConvBiasImpl::AlgoF16DirectStride2::usable( | |||||
const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy algo_selection_strategy) const { | |||||
bool ConvBiasImpl::AlgoF16DirectStride2::usable(const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy) const { | |||||
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 0) { | MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 0) { | ||||
auto&& fm = param.filter_meta; | auto&& fm = param.filter_meta; | ||||
auto FH = fm.spatial[0]; | auto FH = fm.spatial[0]; | ||||
bool aviliable = | |||||
param.filter_meta.format == param::Convolution::Format::NCHW && | |||||
param.src_type.enumv() == DTypeEnum::Float16 && | |||||
param.filter_type.enumv() == DTypeEnum::Float16 && | |||||
param.dst_type.enumv() == DTypeEnum::Float16 && | |||||
!fm.should_flip && fm.spatial_ndim == 2 && | |||||
fm.dilation[0] == 1 && fm.dilation[1] == 1 && | |||||
fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] && | |||||
(FH == 2 || FH == 3 || FH == 5 || FH == 7); | |||||
if (algo_selection_strategy == AlgoSelectionStrategy::HEURISTIC) { | |||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
aviliable &= (large_group == m_large_group); | |||||
} | |||||
return aviliable; | |||||
return param.filter_meta.format == param::Convolution::Format::NCHW && | |||||
param.src_type.enumv() == DTypeEnum::Float16 && | |||||
param.filter_type.enumv() == DTypeEnum::Float16 && | |||||
param.dst_type.enumv() == DTypeEnum::Float16 && | |||||
!fm.should_flip && fm.spatial_ndim == 2 && fm.dilation[0] == 1 && | |||||
fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 && | |||||
FH == fm.spatial[1] && | |||||
(FH == 2 || FH == 3 || FH == 5 || FH == 7); | |||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return false; | return false; | ||||
@@ -50,8 +43,9 @@ bool ConvBiasImpl::AlgoF16DirectStride2::usable( | |||||
size_t ConvBiasImpl::AlgoF16DirectStride2::get_workspace( | size_t ConvBiasImpl::AlgoF16DirectStride2::get_workspace( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 1) { | MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 1) { | ||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
auto wbundle = arm_common::MultithreadDirectConvCommon< | auto wbundle = arm_common::MultithreadDirectConvCommon< | ||||
dt_float16, __fp16>::get_bundle_stride(param, m_large_group); | |||||
dt_float16, __fp16>::get_bundle_stride(param, large_group); | |||||
return wbundle.total_size_in_bytes(); | return wbundle.total_size_in_bytes(); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
@@ -77,6 +71,7 @@ ConvBiasImpl::AlgoF16DirectStride2::get_kimpls( | |||||
size_t IC = param.filter_meta.icpg; | size_t IC = param.filter_meta.icpg; | ||||
size_t OC = param.filter_meta.ocpg; | size_t OC = param.filter_meta.ocpg; | ||||
size_t group = fm.group; | size_t group = fm.group; | ||||
bool large_group = group >= param.nr_threads; | |||||
using Func = std::function<void(const __fp16*, const __fp16*, __fp16*, | using Func = std::function<void(const __fp16*, const __fp16*, __fp16*, | ||||
size_t, size_t, size_t, size_t, size_t)>; | size_t, size_t, size_t, size_t, size_t)>; | ||||
Func conv = nullptr; | Func conv = nullptr; | ||||
@@ -91,11 +86,11 @@ ConvBiasImpl::AlgoF16DirectStride2::get_kimpls( | |||||
} | } | ||||
WorkspaceBundle bundle = arm_common::MultithreadDirectConvCommon< | WorkspaceBundle bundle = arm_common::MultithreadDirectConvCommon< | ||||
dt_float16, __fp16>::get_bundle_stride(param, m_large_group); | |||||
dt_float16, __fp16>::get_bundle_stride(param, large_group); | |||||
SmallVector<NCBKern> ret_kerns; | SmallVector<NCBKern> ret_kerns; | ||||
//! Dense conv and small group | //! Dense conv and small group | ||||
if (m_large_group) { | |||||
if (large_group) { | |||||
//! Channel wise conv and big groups | //! Channel wise conv and big groups | ||||
auto exec_one_group = [bundle, conv]( | auto exec_one_group = [bundle, conv]( | ||||
const NCBKernParam& kern_param, | const NCBKernParam& kern_param, | ||||
@@ -18,15 +18,9 @@ namespace aarch64 { | |||||
/* ===================== stride-2 algo ===================== */ | /* ===================== stride-2 algo ===================== */ | ||||
class ConvBiasImpl::AlgoF16DirectStride2 final : public AlgoBase { | class ConvBiasImpl::AlgoF16DirectStride2 final : public AlgoBase { | ||||
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | ||||
bool m_large_group; | |||||
public: | public: | ||||
AlgoF16DirectStride2(bool large_group) : m_large_group(large_group) {} | |||||
bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
const char* name() const override { | |||||
return m_large_group ? "ARMV8F16STRD2_LARGE_GROUP" | |||||
: "ARMV8F16STRD2_SMALL_GROUP"; | |||||
} | |||||
const char* name() const override { return "ARMV8F16STRD2"; } | |||||
bool usable(const NCBKernSizeParam& param, | bool usable(const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy algo_selection_strategy) const override; | AlgoSelectionStrategy algo_selection_strategy) const override; | ||||
@@ -21,26 +21,19 @@ using namespace megdnn; | |||||
using namespace aarch64; | using namespace aarch64; | ||||
MIDOUT_DECL(megdnn_aarch64_conv_bias_stride2_conv2357_fp32) | MIDOUT_DECL(megdnn_aarch64_conv_bias_stride2_conv2357_fp32) | ||||
bool ConvBiasImpl::AlgoF32DirectStride2::usable( | |||||
const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy algo_selection_strategy) const { | |||||
bool ConvBiasImpl::AlgoF32DirectStride2::usable(const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy) const { | |||||
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 0) { | MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 0) { | ||||
auto&& fm = param.filter_meta; | auto&& fm = param.filter_meta; | ||||
auto FH = fm.spatial[0]; | auto FH = fm.spatial[0]; | ||||
bool aviliable = | |||||
param.filter_meta.format == param::ConvBias::Format::NCHW && | |||||
param.src_type.enumv() == DTypeEnum::Float32 && | |||||
param.filter_type.enumv() == DTypeEnum::Float32 && | |||||
param.dst_type.enumv() == DTypeEnum::Float32 && | |||||
!fm.should_flip && fm.spatial_ndim == 2 && | |||||
fm.dilation[0] == 1 && fm.dilation[1] == 1 && | |||||
fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] && | |||||
(FH == 2 || FH == 3 || FH == 5 || FH == 7); | |||||
if (algo_selection_strategy == AlgoSelectionStrategy::HEURISTIC) { | |||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
aviliable &= (large_group == m_large_group); | |||||
} | |||||
return aviliable; | |||||
return param.filter_meta.format == param::ConvBias::Format::NCHW && | |||||
param.src_type.enumv() == DTypeEnum::Float32 && | |||||
param.filter_type.enumv() == DTypeEnum::Float32 && | |||||
param.dst_type.enumv() == DTypeEnum::Float32 && | |||||
!fm.should_flip && fm.spatial_ndim == 2 && fm.dilation[0] == 1 && | |||||
fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 && | |||||
FH == fm.spatial[1] && | |||||
(FH == 2 || FH == 3 || FH == 5 || FH == 7); | |||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return false; | return false; | ||||
@@ -49,8 +42,9 @@ bool ConvBiasImpl::AlgoF32DirectStride2::usable( | |||||
size_t ConvBiasImpl::AlgoF32DirectStride2::get_workspace( | size_t ConvBiasImpl::AlgoF32DirectStride2::get_workspace( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 1) { | MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 1) { | ||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
auto wbundle = arm_common::MultithreadDirectConvCommon< | auto wbundle = arm_common::MultithreadDirectConvCommon< | ||||
float, float>::get_bundle_stride(param, m_large_group); | |||||
float, float>::get_bundle_stride(param, large_group); | |||||
return wbundle.total_size_in_bytes(); | return wbundle.total_size_in_bytes(); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
@@ -75,6 +69,7 @@ ConvBiasImpl::AlgoF32DirectStride2::get_kimpls( | |||||
size_t IC = param.filter_meta.icpg; | size_t IC = param.filter_meta.icpg; | ||||
size_t OC = param.filter_meta.ocpg; | size_t OC = param.filter_meta.ocpg; | ||||
size_t group = fm.group; | size_t group = fm.group; | ||||
bool large_group = group >= param.nr_threads; | |||||
using Func = std::function<void(const float*, const float*, float*, size_t, | using Func = std::function<void(const float*, const float*, float*, size_t, | ||||
size_t, size_t, size_t, size_t)>; | size_t, size_t, size_t, size_t)>; | ||||
Func conv = nullptr; | Func conv = nullptr; | ||||
@@ -89,11 +84,11 @@ ConvBiasImpl::AlgoF32DirectStride2::get_kimpls( | |||||
} | } | ||||
WorkspaceBundle bundle = arm_common::MultithreadDirectConvCommon< | WorkspaceBundle bundle = arm_common::MultithreadDirectConvCommon< | ||||
float, float>::get_bundle_stride(param, m_large_group); | |||||
float, float>::get_bundle_stride(param, large_group); | |||||
SmallVector<NCBKern> ret_kerns; | SmallVector<NCBKern> ret_kerns; | ||||
//! Dense conv and small group | //! Dense conv and small group | ||||
if (m_large_group) { | |||||
if (large_group) { | |||||
//! Channel wise conv and big groups | //! Channel wise conv and big groups | ||||
auto exec_one_group = [bundle, conv]( | auto exec_one_group = [bundle, conv]( | ||||
const NCBKernParam& kern_param, | const NCBKernParam& kern_param, | ||||
@@ -22,15 +22,9 @@ using FallbackConvBiasImpl = fallback::ConvBiasImpl; | |||||
class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase { | class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase { | ||||
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | ||||
bool m_large_group; | |||||
public: | public: | ||||
AlgoF32DirectStride2(bool large_group) : m_large_group(large_group) {} | |||||
bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
const char* name() const override { | |||||
return m_large_group ? "ARMV8F32STRD2_LARGE_GROUP" | |||||
: "ARMV8F32STRD2_SMALL_GROUP"; | |||||
} | |||||
const char* name() const override { return "ARMV8F32STRD2"; } | |||||
bool usable(const NCBKernSizeParam& param, | bool usable(const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy algo_selection_strategy) const override; | AlgoSelectionStrategy algo_selection_strategy) const override; | ||||
@@ -25,13 +25,11 @@ using namespace megdnn; | |||||
using namespace aarch64; | using namespace aarch64; | ||||
class ConvBiasImpl::AlgoPack : NonCopyableObj { | class ConvBiasImpl::AlgoPack : NonCopyableObj { | ||||
AlgoF32DirectStride2 f32_direct_stride2_large_group{true}; | |||||
AlgoF32DirectStride2 f32_direct_stride2_small_group{false}; | |||||
AlgoF32DirectStride2 f32_direct_stride2; | |||||
AlgoS8MatrixMul s8_matrix_mul; | AlgoS8MatrixMul s8_matrix_mul; | ||||
AlgoQU8MatrixMul qu8_matrix_mul; | AlgoQU8MatrixMul qu8_matrix_mul; | ||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
AlgoF16DirectStride2 f16_direct_stride2_large_group{true}; | |||||
AlgoF16DirectStride2 f16_direct_stride2_small_group{false}; | |||||
AlgoF16DirectStride2 f16_direct_stride2; | |||||
#endif | #endif | ||||
public: | public: | ||||
@@ -39,11 +37,9 @@ public: | |||||
matmul_algos.emplace_back(&qu8_matrix_mul); | matmul_algos.emplace_back(&qu8_matrix_mul); | ||||
matmul_algos.emplace_back(&s8_matrix_mul); | matmul_algos.emplace_back(&s8_matrix_mul); | ||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
direct_algos.emplace_back(&f16_direct_stride2_large_group); | |||||
direct_algos.emplace_back(&f16_direct_stride2_small_group); | |||||
direct_algos.emplace_back(&f16_direct_stride2); | |||||
#endif | #endif | ||||
direct_algos.emplace_back(&f32_direct_stride2_large_group); | |||||
direct_algos.emplace_back(&f32_direct_stride2_small_group); | |||||
direct_algos.emplace_back(&f32_direct_stride2); | |||||
} | } | ||||
SmallVector<AlgoBase*> direct_algos; | SmallVector<AlgoBase*> direct_algos; | ||||
SmallVector<AlgoBase*> matmul_algos; | SmallVector<AlgoBase*> matmul_algos; | ||||
@@ -192,9 +192,8 @@ MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL(AlgoFP16WinogradF23_8x8, | |||||
MIDOUT_DECL(megdnn_arm_common_conv_bias_fp16_kimpl) | MIDOUT_DECL(megdnn_arm_common_conv_bias_fp16_kimpl) | ||||
bool ConvBiasImpl::AlgoF16Direct::usable( | |||||
const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy algo_selection_strategy) const { | |||||
bool ConvBiasImpl::AlgoF16Direct::usable(const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy) const { | |||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_kimpl, 0, 0) { | MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_kimpl, 0, 0) { | ||||
auto&& fm = param.filter_meta; | auto&& fm = param.filter_meta; | ||||
auto FH = fm.spatial[0]; | auto FH = fm.spatial[0]; | ||||
@@ -203,20 +202,14 @@ bool ConvBiasImpl::AlgoF16Direct::usable( | |||||
// ``param.osz[0]*param.osz[1] >= 8'' comes from the fact that the | // ``param.osz[0]*param.osz[1] >= 8'' comes from the fact that the | ||||
// kernel may have access to up to 8 fp16 after the end of the memory | // kernel may have access to up to 8 fp16 after the end of the memory | ||||
// chunk. | // chunk. | ||||
bool aviliable = fm.format == param::ConvBias::Format::NCHW && | |||||
param.src_type.enumv() == DTypeEnum::Float16 && | |||||
param.filter_type.enumv() == DTypeEnum::Float16 && | |||||
param.dst_type.enumv() == DTypeEnum::Float16 && | |||||
fm.spatial_ndim == 2 && fm.dilation[0] == 1 && | |||||
fm.dilation[1] == 1 && | |||||
param.isz[0] * param.isz[1] >= 8 && | |||||
param.osz[0] * param.osz[1] >= 8 && FH <= 7 && | |||||
SH == 1 && SW == 1; | |||||
if (algo_selection_strategy == AlgoSelectionStrategy::HEURISTIC) { | |||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
aviliable &= (large_group == m_large_group); | |||||
} | |||||
return aviliable; | |||||
return fm.format == param::ConvBias::Format::NCHW && | |||||
param.src_type.enumv() == DTypeEnum::Float16 && | |||||
param.filter_type.enumv() == DTypeEnum::Float16 && | |||||
param.dst_type.enumv() == DTypeEnum::Float16 && | |||||
fm.spatial_ndim == 2 && fm.dilation[0] == 1 && | |||||
fm.dilation[1] == 1 && param.isz[0] * param.isz[1] >= 8 && | |||||
param.osz[0] * param.osz[1] >= 8 && FH <= 7 && SH == 1 && | |||||
SW == 1; | |||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return false; | return false; | ||||
@@ -225,9 +218,10 @@ bool ConvBiasImpl::AlgoF16Direct::usable( | |||||
size_t ConvBiasImpl::AlgoF16Direct::get_workspace( | size_t ConvBiasImpl::AlgoF16Direct::get_workspace( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_kimpl, 0, 1) { | MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_kimpl, 0, 1) { | ||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
auto wbundle = | auto wbundle = | ||||
MultithreadDirectConvCommon<dt_float16, __fp16>::get_bundle( | MultithreadDirectConvCommon<dt_float16, __fp16>::get_bundle( | ||||
param, m_large_group); | |||||
param, large_group); | |||||
return wbundle.total_size_in_bytes(); | return wbundle.total_size_in_bytes(); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
@@ -241,13 +235,14 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF16Direct::get_kimpls( | |||||
size_t IC = param.filter_meta.icpg; | size_t IC = param.filter_meta.icpg; | ||||
size_t OC = param.filter_meta.ocpg; | size_t OC = param.filter_meta.ocpg; | ||||
size_t group = fm.group; | size_t group = fm.group; | ||||
bool large_group = group >= param.nr_threads; | |||||
WorkspaceBundle bundle = | WorkspaceBundle bundle = | ||||
MultithreadDirectConvCommon<dt_float16, __fp16>::get_bundle( | MultithreadDirectConvCommon<dt_float16, __fp16>::get_bundle( | ||||
param, m_large_group); | |||||
param, large_group); | |||||
SmallVector<NCBKern> ret_kerns; | SmallVector<NCBKern> ret_kerns; | ||||
//! When group >= nr_threads, treat it as large_group, each thread process | //! When group >= nr_threads, treat it as large_group, each thread process | ||||
//! one group for better performance | //! one group for better performance | ||||
if (m_large_group) { | |||||
if (large_group) { | |||||
//! Channel wise conv and big groups | //! Channel wise conv and big groups | ||||
auto exec_one_group = [bundle](const NCBKernParam& kern_param, | auto exec_one_group = [bundle](const NCBKernParam& kern_param, | ||||
const NCBKernIndex& ncb_index) mutable { | const NCBKernIndex& ncb_index) mutable { | ||||
@@ -316,27 +311,18 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF16Direct::dispatch_kerns( | |||||
/* ===================== stride-1 algo ===================== */ | /* ===================== stride-1 algo ===================== */ | ||||
bool ConvBiasImpl::AlgoF16DirectStride1::usable( | |||||
const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy algo_selection_strategy) const { | |||||
bool ConvBiasImpl::AlgoF16DirectStride1::usable(const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy) const { | |||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_kimpl, 1, 0) { | MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_kimpl, 1, 0) { | ||||
auto&& fm = param.filter_meta; | auto&& fm = param.filter_meta; | ||||
auto FH = fm.spatial[0]; | auto FH = fm.spatial[0]; | ||||
bool aviliable = | |||||
param.filter_meta.format == param::ConvBias::Format::NCHW && | |||||
param.src_type.enumv() == DTypeEnum::Float16 && | |||||
param.filter_type.enumv() == DTypeEnum::Float16 && | |||||
param.dst_type.enumv() == DTypeEnum::Float16 && | |||||
!fm.should_flip && fm.spatial_ndim == 2 && | |||||
fm.dilation[0] == 1 && fm.dilation[1] == 1 && | |||||
fm.stride[0] == 1 && fm.stride[1] == 1 && FH == fm.spatial[1] && | |||||
(FH == 2 || FH == 3 || FH == 5); | |||||
if (algo_selection_strategy == | |||||
ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { | |||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
aviliable &= (large_group == m_large_group); | |||||
} | |||||
return aviliable; | |||||
return param.filter_meta.format == param::ConvBias::Format::NCHW && | |||||
param.src_type.enumv() == DTypeEnum::Float16 && | |||||
param.filter_type.enumv() == DTypeEnum::Float16 && | |||||
param.dst_type.enumv() == DTypeEnum::Float16 && | |||||
!fm.should_flip && fm.spatial_ndim == 2 && fm.dilation[0] == 1 && | |||||
fm.dilation[1] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1 && | |||||
FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5); | |||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return false; | return false; | ||||
@@ -351,6 +337,7 @@ ConvBiasImpl::AlgoF16DirectStride1::get_kimpls( | |||||
size_t IC = param.filter_meta.icpg; | size_t IC = param.filter_meta.icpg; | ||||
size_t OC = param.filter_meta.ocpg; | size_t OC = param.filter_meta.ocpg; | ||||
size_t group = fm.group; | size_t group = fm.group; | ||||
bool large_group = group >= param.nr_threads; | |||||
using Func = std::function<void(const __fp16*, const __fp16*, __fp16*, | using Func = std::function<void(const __fp16*, const __fp16*, __fp16*, | ||||
size_t, size_t, size_t, size_t, size_t)>; | size_t, size_t, size_t, size_t, size_t)>; | ||||
Func conv_kern_function = nullptr; | Func conv_kern_function = nullptr; | ||||
@@ -371,11 +358,11 @@ ConvBiasImpl::AlgoF16DirectStride1::get_kimpls( | |||||
WorkspaceBundle bundle = | WorkspaceBundle bundle = | ||||
MultithreadDirectConvCommon<dt_float16, __fp16>::get_bundle_stride( | MultithreadDirectConvCommon<dt_float16, __fp16>::get_bundle_stride( | ||||
param, m_large_group); | |||||
param, large_group); | |||||
SmallVector<NCBKern> ret_kerns; | SmallVector<NCBKern> ret_kerns; | ||||
//! When group >= nr_threads, treat it as large_group, each thread process | //! When group >= nr_threads, treat it as large_group, each thread process | ||||
//! one group for better performance | //! one group for better performance | ||||
if (m_large_group) { | |||||
if (large_group) { | |||||
//! Channel wise conv and big groups | //! Channel wise conv and big groups | ||||
auto exec_one_group = [bundle, conv_kern_function]( | auto exec_one_group = [bundle, conv_kern_function]( | ||||
const NCBKernParam& kern_param, | const NCBKernParam& kern_param, | ||||
@@ -423,8 +410,9 @@ ConvBiasImpl::AlgoF16DirectStride1::get_kimpls( | |||||
size_t ConvBiasImpl::AlgoF16DirectStride1::get_workspace( | size_t ConvBiasImpl::AlgoF16DirectStride1::get_workspace( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_kimpl, 1, 1) { | MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_kimpl, 1, 1) { | ||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
auto bundle = MultithreadDirectConvCommon< | auto bundle = MultithreadDirectConvCommon< | ||||
dt_float16, __fp16>::get_bundle_stride(param, m_large_group); | |||||
dt_float16, __fp16>::get_bundle_stride(param, large_group); | |||||
return bundle.total_size_in_bytes(); | return bundle.total_size_in_bytes(); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
@@ -79,15 +79,10 @@ public: | |||||
class ConvBiasImpl::AlgoF16Direct final : public AlgoBase { | class ConvBiasImpl::AlgoF16Direct final : public AlgoBase { | ||||
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | ||||
bool m_large_group; | |||||
public: | public: | ||||
AlgoF16Direct(bool is_large_group) : m_large_group{is_large_group} {} | |||||
bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
const char* name() const override { | |||||
return m_large_group ? "F16DIRECT_LARGE_GROUP" | |||||
: "F16DIRECT_SMALL_GROUP"; | |||||
} | |||||
const char* name() const override { return "F16DIRECT"; } | |||||
bool usable(const NCBKernSizeParam& param, | bool usable(const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy algo_selection_strategy) const override; | AlgoSelectionStrategy algo_selection_strategy) const override; | ||||
@@ -99,14 +94,10 @@ public: | |||||
class ConvBiasImpl::AlgoF16DirectStride1 final : public AlgoBase { | class ConvBiasImpl::AlgoF16DirectStride1 final : public AlgoBase { | ||||
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | ||||
bool m_large_group; | |||||
public: | public: | ||||
AlgoF16DirectStride1(bool is_large_group) : m_large_group{is_large_group} {} | |||||
bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
const char* name() const override { | |||||
return m_large_group ? "F16STRD1_LARGE_GROUP" : "F16STRD1_SMALL_GROUP"; | |||||
} | |||||
const char* name() const override { return "F16STRD1"; } | |||||
bool usable(const NCBKernSizeParam& param, | bool usable(const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy algo_selection_strategy) const override; | AlgoSelectionStrategy algo_selection_strategy) const override; | ||||
size_t get_workspace(const NCBKernSizeParam& param) const override; | size_t get_workspace(const NCBKernSizeParam& param) const override; | ||||
@@ -334,9 +334,8 @@ MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL(AlgoFP32WinogradF63_4x4_NCHW44, | |||||
/* ===================== direct algo ===================== */ | /* ===================== direct algo ===================== */ | ||||
MIDOUT_DECL(megdnn_arm_common_conv_bias_f32_kimpl); | MIDOUT_DECL(megdnn_arm_common_conv_bias_f32_kimpl); | ||||
bool ConvBiasImpl::AlgoF32Direct::usable( | |||||
const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy algo_selection_strategy) const { | |||||
bool ConvBiasImpl::AlgoF32Direct::usable(const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy) const { | |||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 0, 0) { | MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 0, 0) { | ||||
auto&& fm = param.filter_meta; | auto&& fm = param.filter_meta; | ||||
auto FH = fm.spatial[0]; | auto FH = fm.spatial[0]; | ||||
@@ -345,20 +344,14 @@ bool ConvBiasImpl::AlgoF32Direct::usable( | |||||
// ``param.osz[0]*param.osz[1] >= 4'' comes from the fact that the | // ``param.osz[0]*param.osz[1] >= 4'' comes from the fact that the | ||||
// kernel may have access to up to 4 floats after the end of the memory | // kernel may have access to up to 4 floats after the end of the memory | ||||
// chunk. | // chunk. | ||||
bool aviliable = fm.format == param::ConvBias::Format::NCHW && | |||||
param.src_type.enumv() == DTypeEnum::Float32 && | |||||
param.filter_type.enumv() == DTypeEnum::Float32 && | |||||
param.dst_type.enumv() == DTypeEnum::Float32 && | |||||
fm.spatial_ndim == 2 && fm.dilation[0] == 1 && | |||||
fm.dilation[1] == 1 && | |||||
param.isz[0] * param.isz[1] >= 4 && | |||||
param.osz[0] * param.osz[1] >= 4 && FH <= 7 && | |||||
SH == 1 && SW == 1; | |||||
if (algo_selection_strategy == AlgoSelectionStrategy::HEURISTIC) { | |||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
aviliable &= (large_group == m_large_group); | |||||
} | |||||
return aviliable; | |||||
return fm.format == param::ConvBias::Format::NCHW && | |||||
param.src_type.enumv() == DTypeEnum::Float32 && | |||||
param.filter_type.enumv() == DTypeEnum::Float32 && | |||||
param.dst_type.enumv() == DTypeEnum::Float32 && | |||||
fm.spatial_ndim == 2 && fm.dilation[0] == 1 && | |||||
fm.dilation[1] == 1 && param.isz[0] * param.isz[1] >= 4 && | |||||
param.osz[0] * param.osz[1] >= 4 && FH <= 7 && SH == 1 && | |||||
SW == 1; | |||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return false; | return false; | ||||
@@ -366,8 +359,9 @@ bool ConvBiasImpl::AlgoF32Direct::usable( | |||||
size_t ConvBiasImpl::AlgoF32Direct::get_workspace( | size_t ConvBiasImpl::AlgoF32Direct::get_workspace( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 0, 1) { | MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 0, 1) { | ||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
auto wbundle = MultithreadDirectConvCommon<float, float>::get_bundle( | auto wbundle = MultithreadDirectConvCommon<float, float>::get_bundle( | ||||
param, m_large_group); | |||||
param, large_group); | |||||
return wbundle.total_size_in_bytes(); | return wbundle.total_size_in_bytes(); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
@@ -380,13 +374,14 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32Direct::get_kimpls( | |||||
size_t IC = param.filter_meta.icpg; | size_t IC = param.filter_meta.icpg; | ||||
size_t OC = param.filter_meta.ocpg; | size_t OC = param.filter_meta.ocpg; | ||||
size_t group = fm.group; | size_t group = fm.group; | ||||
bool large_group = group >= param.nr_threads; | |||||
WorkspaceBundle bundle = | WorkspaceBundle bundle = | ||||
MultithreadDirectConvCommon<float, float>::get_bundle( | MultithreadDirectConvCommon<float, float>::get_bundle( | ||||
param, m_large_group); | |||||
param, large_group); | |||||
SmallVector<NCBKern> ret_kerns; | SmallVector<NCBKern> ret_kerns; | ||||
//! When group >= nr_threads, treat it as large_group, each thread process | //! When group >= nr_threads, treat it as large_group, each thread process | ||||
//! one group for better performance | //! one group for better performance | ||||
if (m_large_group) { | |||||
if (large_group) { | |||||
//! Channel wise conv and big groups | //! Channel wise conv and big groups | ||||
auto exec_one_group = [bundle](const NCBKernParam& kern_param, | auto exec_one_group = [bundle](const NCBKernParam& kern_param, | ||||
const NCBKernIndex& ncb_index) mutable { | const NCBKernIndex& ncb_index) mutable { | ||||
@@ -452,27 +447,19 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32Direct::dispatch_kerns( | |||||
return {}; | return {}; | ||||
} | } | ||||
/* ===================== stride-1 algo ===================== */ | /* ===================== stride-1 algo ===================== */ | ||||
bool ConvBiasImpl::AlgoF32DirectStride1::usable( | |||||
const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy algo_selection_strategy) const { | |||||
bool ConvBiasImpl::AlgoF32DirectStride1::usable(const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy) const { | |||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 1, 1) { | MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 1, 1) { | ||||
auto&& fm = param.filter_meta; | auto&& fm = param.filter_meta; | ||||
auto FH = fm.spatial[0]; | auto FH = fm.spatial[0]; | ||||
bool aviliable = | |||||
param.filter_meta.format == param::ConvBias::Format::NCHW && | |||||
param.src_type.enumv() == DTypeEnum::Float32 && | |||||
param.filter_type.enumv() == DTypeEnum::Float32 && | |||||
param.dst_type.enumv() == DTypeEnum::Float32 && | |||||
!fm.should_flip && fm.spatial_ndim == 2 && | |||||
fm.dilation[0] == 1 && fm.dilation[1] == 1 && | |||||
fm.stride[0] == 1 && fm.stride[1] == 1 && FH == fm.spatial[1] && | |||||
(FH == 2 || FH == 3 || FH == 5 || FH == 7); | |||||
if (algo_selection_strategy == | |||||
ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { | |||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
aviliable &= (large_group == m_large_group); | |||||
} | |||||
return aviliable; | |||||
return param.filter_meta.format == param::ConvBias::Format::NCHW && | |||||
param.src_type.enumv() == DTypeEnum::Float32 && | |||||
param.filter_type.enumv() == DTypeEnum::Float32 && | |||||
param.dst_type.enumv() == DTypeEnum::Float32 && | |||||
!fm.should_flip && fm.spatial_ndim == 2 && fm.dilation[0] == 1 && | |||||
fm.dilation[1] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1 && | |||||
FH == fm.spatial[1] && | |||||
(FH == 2 || FH == 3 || FH == 5 || FH == 7); | |||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return false; | return false; | ||||
@@ -481,9 +468,10 @@ bool ConvBiasImpl::AlgoF32DirectStride1::usable( | |||||
size_t ConvBiasImpl::AlgoF32DirectStride1::get_workspace( | size_t ConvBiasImpl::AlgoF32DirectStride1::get_workspace( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 1, 1) { | MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 1, 1) { | ||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
auto bundle = | auto bundle = | ||||
MultithreadDirectConvCommon<float, float>::get_bundle_stride( | MultithreadDirectConvCommon<float, float>::get_bundle_stride( | ||||
param, m_large_group); | |||||
param, large_group); | |||||
return bundle.total_size_in_bytes(); | return bundle.total_size_in_bytes(); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
@@ -499,6 +487,7 @@ ConvBiasImpl::AlgoF32DirectStride1::get_kimpls( | |||||
size_t IC = param.filter_meta.icpg; | size_t IC = param.filter_meta.icpg; | ||||
size_t OC = param.filter_meta.ocpg; | size_t OC = param.filter_meta.ocpg; | ||||
size_t group = fm.group; | size_t group = fm.group; | ||||
bool large_group = group >= param.nr_threads; | |||||
using Func = std::function<void(const float*, const float*, float*, size_t, | using Func = std::function<void(const float*, const float*, float*, size_t, | ||||
size_t, size_t, size_t, size_t)>; | size_t, size_t, size_t, size_t)>; | ||||
Func conv_kern_function = nullptr; | Func conv_kern_function = nullptr; | ||||
@@ -522,11 +511,11 @@ ConvBiasImpl::AlgoF32DirectStride1::get_kimpls( | |||||
WorkspaceBundle bundle = | WorkspaceBundle bundle = | ||||
MultithreadDirectConvCommon<float, float>::get_bundle_stride( | MultithreadDirectConvCommon<float, float>::get_bundle_stride( | ||||
param, m_large_group); | |||||
param, large_group); | |||||
SmallVector<NCBKern> ret_kerns; | SmallVector<NCBKern> ret_kerns; | ||||
//! When group >= nr_threads, treat it as large_group, each thread process | //! When group >= nr_threads, treat it as large_group, each thread process | ||||
//! one group for better performance | //! one group for better performance | ||||
if (m_large_group) { | |||||
if (large_group) { | |||||
//! Channel wise conv and big groups | //! Channel wise conv and big groups | ||||
auto exec_one_group = [bundle, conv_kern_function]( | auto exec_one_group = [bundle, conv_kern_function]( | ||||
const NCBKernParam& kern_param, | const NCBKernParam& kern_param, | ||||
@@ -580,27 +569,19 @@ ConvBiasImpl::AlgoF32DirectStride1::dispatch_kerns( | |||||
/* ===================== stride-2 algo ===================== */ | /* ===================== stride-2 algo ===================== */ | ||||
bool ConvBiasImpl::AlgoF32DirectStride2::usable( | |||||
const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy algo_selection_strategy) const { | |||||
bool ConvBiasImpl::AlgoF32DirectStride2::usable(const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy) const { | |||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 2, 0) { | MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 2, 0) { | ||||
auto&& fm = param.filter_meta; | auto&& fm = param.filter_meta; | ||||
auto FH = fm.spatial[0]; | auto FH = fm.spatial[0]; | ||||
bool aviliable = | |||||
param.filter_meta.format == param::ConvBias::Format::NCHW && | |||||
param.src_type.enumv() == DTypeEnum::Float32 && | |||||
param.filter_type.enumv() == DTypeEnum::Float32 && | |||||
param.dst_type.enumv() == DTypeEnum::Float32 && | |||||
!fm.should_flip && fm.spatial_ndim == 2 && | |||||
fm.dilation[0] == 1 && fm.dilation[1] == 1 && | |||||
fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] && | |||||
(FH == 2 || FH == 3 || FH == 5 || FH == 7); | |||||
if (algo_selection_strategy == | |||||
ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { | |||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
aviliable &= (large_group == m_large_group); | |||||
} | |||||
return aviliable; | |||||
return param.filter_meta.format == param::ConvBias::Format::NCHW && | |||||
param.src_type.enumv() == DTypeEnum::Float32 && | |||||
param.filter_type.enumv() == DTypeEnum::Float32 && | |||||
param.dst_type.enumv() == DTypeEnum::Float32 && | |||||
!fm.should_flip && fm.spatial_ndim == 2 && fm.dilation[0] == 1 && | |||||
fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 && | |||||
FH == fm.spatial[1] && | |||||
(FH == 2 || FH == 3 || FH == 5 || FH == 7); | |||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return false; | return false; | ||||
@@ -608,9 +589,10 @@ bool ConvBiasImpl::AlgoF32DirectStride2::usable( | |||||
size_t ConvBiasImpl::AlgoF32DirectStride2::get_workspace( | size_t ConvBiasImpl::AlgoF32DirectStride2::get_workspace( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 2, 1) { | MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 2, 1) { | ||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
auto bundle = | auto bundle = | ||||
MultithreadDirectConvCommon<float, float>::get_bundle_stride( | MultithreadDirectConvCommon<float, float>::get_bundle_stride( | ||||
param, m_large_group); | |||||
param, large_group); | |||||
return bundle.total_size_in_bytes(); | return bundle.total_size_in_bytes(); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
@@ -625,6 +607,7 @@ ConvBiasImpl::AlgoF32DirectStride2::get_kimpls( | |||||
size_t IC = param.filter_meta.icpg; | size_t IC = param.filter_meta.icpg; | ||||
size_t OC = param.filter_meta.ocpg; | size_t OC = param.filter_meta.ocpg; | ||||
size_t group = fm.group; | size_t group = fm.group; | ||||
bool large_group = group >= param.nr_threads; | |||||
using Func = std::function<void(const float*, const float*, float*, size_t, | using Func = std::function<void(const float*, const float*, float*, size_t, | ||||
size_t, size_t, size_t, size_t)>; | size_t, size_t, size_t, size_t)>; | ||||
Func conv_kern_function = nullptr; | Func conv_kern_function = nullptr; | ||||
@@ -648,11 +631,11 @@ ConvBiasImpl::AlgoF32DirectStride2::get_kimpls( | |||||
WorkspaceBundle bundle = | WorkspaceBundle bundle = | ||||
MultithreadDirectConvCommon<float, float>::get_bundle_stride( | MultithreadDirectConvCommon<float, float>::get_bundle_stride( | ||||
param, m_large_group); | |||||
param, large_group); | |||||
SmallVector<NCBKern> ret_kerns; | SmallVector<NCBKern> ret_kerns; | ||||
//! When group >= nr_threads, treat it as large_group, each thread process | //! When group >= nr_threads, treat it as large_group, each thread process | ||||
//! one group for better performance | //! one group for better performance | ||||
if (m_large_group) { | |||||
if (large_group) { | |||||
//! Channel wise conv and big groups | //! Channel wise conv and big groups | ||||
auto exec_one_group = [bundle, conv_kern_function]( | auto exec_one_group = [bundle, conv_kern_function]( | ||||
const NCBKernParam& kern_param, | const NCBKernParam& kern_param, | ||||
@@ -128,15 +128,10 @@ public: | |||||
class ConvBiasImpl::AlgoF32Direct final : public AlgoBase { | class ConvBiasImpl::AlgoF32Direct final : public AlgoBase { | ||||
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | ||||
bool m_large_group; | |||||
public: | public: | ||||
AlgoF32Direct(bool is_large_group) : m_large_group{is_large_group} {} | |||||
bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
const char* name() const override { | |||||
return m_large_group ? "F32DIRECT_LARGE_GROUP" | |||||
: "F32DIRECT_SMALL_GROUP"; | |||||
} | |||||
const char* name() const override { return "F32DIRECT"; } | |||||
bool usable(const NCBKernSizeParam& param, | bool usable(const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy algo_selection_strategy) const override; | AlgoSelectionStrategy algo_selection_strategy) const override; | ||||
@@ -147,14 +142,10 @@ public: | |||||
class ConvBiasImpl::AlgoF32DirectStride1 final : public AlgoBase { | class ConvBiasImpl::AlgoF32DirectStride1 final : public AlgoBase { | ||||
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | ||||
bool m_large_group; | |||||
public: | public: | ||||
AlgoF32DirectStride1(bool is_large_group) : m_large_group{is_large_group} {} | |||||
bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
const char* name() const override { | |||||
return m_large_group ? "F32STRD1_LARGE_GROUP" : "F32STRD1_SMALL_GROUP"; | |||||
} | |||||
const char* name() const override { return "F32STRD1"; } | |||||
bool usable(const NCBKernSizeParam& param, | bool usable(const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy algo_selection_strategy) const override; | AlgoSelectionStrategy algo_selection_strategy) const override; | ||||
@@ -165,14 +156,10 @@ public: | |||||
class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase { | class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase { | ||||
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | ||||
bool m_large_group; | |||||
public: | public: | ||||
AlgoF32DirectStride2(bool is_large_group) : m_large_group{is_large_group} {} | |||||
bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
const char* name() const override { | |||||
return m_large_group ? "F32STRD2_LARGE_GROUP" : "F32STRD2_SMALL_GROUP"; | |||||
} | |||||
const char* name() const override { return "F32STRD2"; } | |||||
bool usable(const NCBKernSizeParam& param, | bool usable(const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy algo_selection_strategy) const override; | AlgoSelectionStrategy algo_selection_strategy) const override; | ||||
@@ -27,17 +27,10 @@ using namespace arm_common; | |||||
MIDOUT_DECL(megdnn_arm_common_conv_bias_int8) | MIDOUT_DECL(megdnn_arm_common_conv_bias_int8) | ||||
/* ===================== stride1 algo ===================== */ | /* ===================== stride1 algo ===================== */ | ||||
bool ConvBiasImpl::AlgoS8DirectStride1::usable( | |||||
const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy algo_selection_strategy) const { | |||||
bool avaible = direct_int8_stride1::can_conv_direct_stride1_int8(param); | |||||
auto fm = param.filter_meta; | |||||
if (algo_selection_strategy == | |||||
ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { | |||||
bool large_group = fm.group >= param.nr_threads; | |||||
avaible &= (large_group == m_large_group); | |||||
} | |||||
return avaible; | |||||
bool ConvBiasImpl::AlgoS8DirectStride1::usable(const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy) const { | |||||
return direct_int8_stride1::can_conv_direct_stride1_int8(param); | |||||
} | } | ||||
bool ConvBiasImpl::AlgoS8DirectStride1::is_preferred( | bool ConvBiasImpl::AlgoS8DirectStride1::is_preferred( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
@@ -53,8 +46,9 @@ bool ConvBiasImpl::AlgoS8DirectStride1::is_preferred( | |||||
} | } | ||||
size_t ConvBiasImpl::AlgoS8DirectStride1::get_workspace( | size_t ConvBiasImpl::AlgoS8DirectStride1::get_workspace( | ||||
const NCBKernSizeParam& param) const { | |||||
auto bundle = direct_int8_stride1::get_bundle(param, m_large_group); | |||||
const NCBKernSizeParam& param) const { | |||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
auto bundle = direct_int8_stride1::get_bundle(param, large_group); | |||||
return bundle.total_size_in_bytes(); | return bundle.total_size_in_bytes(); | ||||
} | } | ||||
@@ -62,7 +56,8 @@ SmallVector<ConvBiasImpl::NCBKern> | |||||
ConvBiasImpl::AlgoS8DirectStride1::dispatch_kerns( | ConvBiasImpl::AlgoS8DirectStride1::dispatch_kerns( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, 1, 0) { | MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, 1, 0) { | ||||
return direct_int8_stride1::get_kimpls(param, m_large_group); | |||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
return direct_int8_stride1::get_kimpls(param, large_group); | |||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return {}; | return {}; | ||||
@@ -117,21 +112,15 @@ ConvBiasImpl::AlgoS8ChanWiseStride2NCHW44::dispatch_kerns( | |||||
} | } | ||||
/* ===================== stride2 algo ===================== */ | /* ===================== stride2 algo ===================== */ | ||||
bool ConvBiasImpl::AlgoS8DirectStride2::usable( | |||||
const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy algo_selection_strategy) const { | |||||
bool avaible = direct_int8_stride2::can_conv_direct_stride2_int8(param); | |||||
if (algo_selection_strategy == | |||||
ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { | |||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
avaible &= (large_group == m_large_group); | |||||
} | |||||
return avaible; | |||||
bool ConvBiasImpl::AlgoS8DirectStride2::usable(const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy) const { | |||||
return direct_int8_stride2::can_conv_direct_stride2_int8(param); | |||||
} | } | ||||
size_t ConvBiasImpl::AlgoS8DirectStride2::get_workspace( | size_t ConvBiasImpl::AlgoS8DirectStride2::get_workspace( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
auto bundle = direct_int8_stride2::get_bundle(param, m_large_group); | |||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
auto bundle = direct_int8_stride2::get_bundle(param, large_group); | |||||
return bundle.total_size_in_bytes(); | return bundle.total_size_in_bytes(); | ||||
} | } | ||||
@@ -139,7 +128,8 @@ SmallVector<ConvBiasImpl::NCBKern> | |||||
ConvBiasImpl::AlgoS8DirectStride2::dispatch_kerns( | ConvBiasImpl::AlgoS8DirectStride2::dispatch_kerns( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, 1, 1) { | MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, 1, 1) { | ||||
return direct_int8_stride2::get_kimpls(param, m_large_group); | |||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
return direct_int8_stride2::get_kimpls(param, large_group); | |||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return {}; | return {}; | ||||
@@ -147,24 +137,15 @@ ConvBiasImpl::AlgoS8DirectStride2::dispatch_kerns( | |||||
#if __ARM_FEATURE_DOTPROD | #if __ARM_FEATURE_DOTPROD | ||||
/* ===================== dot stride1 algo ======================== */ | /* ===================== dot stride1 algo ======================== */ | ||||
bool ConvBiasImpl::AlgoDotS8DirectStride1::usable( | |||||
const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy algo_selection_strategy) const { | |||||
bool avaible = | |||||
direct_dotprod_int8_stride1::can_conv_direct_stride1_int8(param); | |||||
if (algo_selection_strategy == | |||||
ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { | |||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
avaible &= (large_group == m_large_group); | |||||
} | |||||
return avaible; | |||||
bool ConvBiasImpl::AlgoDotS8DirectStride1::usable(const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy) const { | |||||
return direct_dotprod_int8_stride1::can_conv_direct_stride1_int8(param); | |||||
} | } | ||||
size_t ConvBiasImpl::AlgoDotS8DirectStride1::get_workspace( | size_t ConvBiasImpl::AlgoDotS8DirectStride1::get_workspace( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
auto bundle = direct_dotprod_int8_stride1::get_bundle(param, m_large_group); | |||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
auto bundle = direct_dotprod_int8_stride1::get_bundle(param, large_group); | |||||
return bundle.total_size_in_bytes(); | return bundle.total_size_in_bytes(); | ||||
} | } | ||||
@@ -172,29 +153,23 @@ SmallVector<ConvBiasImpl::NCBKern> | |||||
ConvBiasImpl::AlgoDotS8DirectStride1::dispatch_kerns( | ConvBiasImpl::AlgoDotS8DirectStride1::dispatch_kerns( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, 2, 1) { | MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, 2, 1) { | ||||
return direct_dotprod_int8_stride1::get_kimpls(param, m_large_group); | |||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
return direct_dotprod_int8_stride1::get_kimpls(param, large_group); | |||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return {}; | return {}; | ||||
} | } | ||||
/* ===================== dot stride2 algo ======================== */ | /* ===================== dot stride2 algo ======================== */ | ||||
bool ConvBiasImpl::AlgoDotS8DirectStride2::usable( | |||||
const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy algo_selection_strategy) const { | |||||
bool avaible = | |||||
direct_dotprod_int8_stride2::can_conv_direct_stride2_int8(param); | |||||
if (algo_selection_strategy == | |||||
ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { | |||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
avaible &= (large_group == m_large_group); | |||||
} | |||||
return avaible; | |||||
bool ConvBiasImpl::AlgoDotS8DirectStride2::usable(const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy) const { | |||||
return direct_dotprod_int8_stride2::can_conv_direct_stride2_int8(param); | |||||
} | } | ||||
size_t ConvBiasImpl::AlgoDotS8DirectStride2::get_workspace( | size_t ConvBiasImpl::AlgoDotS8DirectStride2::get_workspace( | ||||
const NCBKernSizeParam& param) const { | |||||
auto bundle = direct_dotprod_int8_stride2::get_bundle(param, m_large_group); | |||||
const NCBKernSizeParam& param) const { | |||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
auto bundle = direct_dotprod_int8_stride2::get_bundle(param, large_group); | |||||
return bundle.total_size_in_bytes(); | return bundle.total_size_in_bytes(); | ||||
} | } | ||||
@@ -202,7 +177,8 @@ SmallVector<ConvBiasImpl::NCBKern> | |||||
ConvBiasImpl::AlgoDotS8DirectStride2::dispatch_kerns( | ConvBiasImpl::AlgoDotS8DirectStride2::dispatch_kerns( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, 2, 2) { | MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, 2, 2) { | ||||
return direct_dotprod_int8_stride2::get_kimpls(param, m_large_group); | |||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
return direct_dotprod_int8_stride2::get_kimpls(param, large_group); | |||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return {}; | return {}; | ||||
@@ -18,14 +18,10 @@ namespace megdnn { | |||||
namespace arm_common { | namespace arm_common { | ||||
class ConvBiasImpl::AlgoS8DirectStride1 final : public AlgoBase { | class ConvBiasImpl::AlgoS8DirectStride1 final : public AlgoBase { | ||||
bool m_large_group; | |||||
public: | public: | ||||
AlgoS8DirectStride1(bool large_group) : m_large_group(large_group) {} | |||||
bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
const char* name() const override { | |||||
return m_large_group ? "S8STRD1_LARGE_GROUP" : "S8STRD1_SMALL_GROUP"; | |||||
} | |||||
const char* name() const override { return "S8STRD1"; } | |||||
bool usable(const NCBKernSizeParam& param, | bool usable(const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy algo_selection_strategy) const override; | AlgoSelectionStrategy algo_selection_strategy) const override; | ||||
size_t get_workspace(const NCBKernSizeParam& param) const override; | size_t get_workspace(const NCBKernSizeParam& param) const override; | ||||
@@ -36,14 +32,10 @@ public: | |||||
}; | }; | ||||
class ConvBiasImpl::AlgoS8DirectStride2 final : public AlgoBase { | class ConvBiasImpl::AlgoS8DirectStride2 final : public AlgoBase { | ||||
bool m_large_group; | |||||
public: | public: | ||||
AlgoS8DirectStride2(bool large_group) : m_large_group(large_group) {} | |||||
bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
const char* name() const override { | |||||
return m_large_group ? "S8STRD2_LARGE_GROUP" : "S8STRD2_SMALL_GROUP"; | |||||
} | |||||
const char* name() const override { return "S8STRD2"; } | |||||
bool usable(const NCBKernSizeParam& param, | bool usable(const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy algo_selection_strategy) const override; | AlgoSelectionStrategy algo_selection_strategy) const override; | ||||
@@ -115,16 +107,10 @@ public: | |||||
}; | }; | ||||
class ConvBiasImpl::AlgoDotS8DirectStride1 final : public AlgoBase { | class ConvBiasImpl::AlgoDotS8DirectStride1 final : public AlgoBase { | ||||
bool m_large_group; | |||||
public: | public: | ||||
AlgoDotS8DirectStride1(bool large_group) : m_large_group(large_group) {} | |||||
bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
const char* name() const override { | |||||
return m_large_group ? "ARMDOTS8STRD1_LARGE_GROUP" | |||||
: "ARMDOTS8STRD1_SMALL_GROUP"; | |||||
} | |||||
const char* name() const override { return "ARMDOTS8STRD1"; } | |||||
bool usable(const NCBKernSizeParam&, | bool usable(const NCBKernSizeParam&, | ||||
AlgoSelectionStrategy algo_selection_strategy) const override; | AlgoSelectionStrategy algo_selection_strategy) const override; | ||||
@@ -134,15 +120,10 @@ public: | |||||
}; | }; | ||||
class ConvBiasImpl::AlgoDotS8DirectStride2 final : public AlgoBase { | class ConvBiasImpl::AlgoDotS8DirectStride2 final : public AlgoBase { | ||||
bool m_large_group; | |||||
public: | public: | ||||
AlgoDotS8DirectStride2(bool large_group) : m_large_group(large_group) {} | |||||
bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
const char* name() const override { | |||||
return m_large_group ? "ARMDOTS8STRD2_LARGE_GROUP" | |||||
: "ARMDOTS8STRD2_SMALL_GROUP"; | |||||
} | |||||
const char* name() const override { return "ARMDOTS8STRD2"; } | |||||
bool usable(const NCBKernSizeParam&, | bool usable(const NCBKernSizeParam&, | ||||
AlgoSelectionStrategy algo_selection_strategy) const override; | AlgoSelectionStrategy algo_selection_strategy) const override; | ||||
@@ -82,28 +82,20 @@ void get_rectified_size_str2(size_t IH, size_t IW, size_t OH, size_t OW, | |||||
} // namespace | } // namespace | ||||
/* ===================== direct algo ===================== */ | /* ===================== direct algo ===================== */ | ||||
bool ConvBiasImpl::AlgoI8x8x16Direct::usable( | |||||
const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy algo_selection_strategy) const { | |||||
bool ConvBiasImpl::AlgoI8x8x16Direct::usable(const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy) const { | |||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 1, 0) { | MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 1, 0) { | ||||
auto&& fm = param.filter_meta; | auto&& fm = param.filter_meta; | ||||
auto FH = fm.spatial[0]; | auto FH = fm.spatial[0]; | ||||
bool aviliable = | |||||
param.bias_mode == BiasMode::NO_BIAS && | |||||
param.nonlineMode == NonlineMode::IDENTITY && | |||||
fm.format == param::ConvBias::Format::NCHW && !fm.should_flip && | |||||
param.src_type.enumv() == DTypeEnum::Int8 && | |||||
param.filter_type.enumv() == DTypeEnum::Int8 && | |||||
param.dst_type.enumv() == DTypeEnum::Int16 && | |||||
fm.spatial_ndim == 2 && fm.dilation[0] == 1 && | |||||
fm.dilation[1] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1 && | |||||
FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5); | |||||
if (algo_selection_strategy == | |||||
ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { | |||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
aviliable &= (large_group == m_large_group); | |||||
} | |||||
return aviliable; | |||||
return param.bias_mode == BiasMode::NO_BIAS && | |||||
param.nonlineMode == NonlineMode::IDENTITY && | |||||
fm.format == param::ConvBias::Format::NCHW && !fm.should_flip && | |||||
param.src_type.enumv() == DTypeEnum::Int8 && | |||||
param.filter_type.enumv() == DTypeEnum::Int8 && | |||||
param.dst_type.enumv() == DTypeEnum::Int16 && | |||||
fm.spatial_ndim == 2 && fm.dilation[0] == 1 && | |||||
fm.dilation[1] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1 && | |||||
FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5); | |||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return false; | return false; | ||||
@@ -117,11 +109,12 @@ WorkspaceBundle ConvBiasImpl::AlgoI8x8x16Direct::get_bundle( | |||||
auto OH = param.osz[0], OW = param.osz[1]; | auto OH = param.osz[0], OW = param.osz[1]; | ||||
auto PH = fm.padding[0], PW = fm.padding[1]; | auto PH = fm.padding[0], PW = fm.padding[1]; | ||||
size_t OH2, OW2, IH2, IW2; | size_t OH2, OW2, IH2, IW2; | ||||
bool large_group = group >= param.nr_threads; | |||||
get_rectified_size_str1(IH, IW, OH, OW, PH, PW, IH2, IW2, OH2, OW2); | get_rectified_size_str1(IH, IW, OH, OW, PH, PW, IH2, IW2, OH2, OW2); | ||||
size_t part0 = 0u, part1 = 0u; | size_t part0 = 0u, part1 = 0u; | ||||
if (need_src_copy_str1(param)) { | if (need_src_copy_str1(param)) { | ||||
part0 = m_large_group ? IC * IH2 * IW2 * sizeof(int8_t) * nr_threads | |||||
: IC * IH2 * IW2 * sizeof(int8_t) * group * batch; | |||||
part0 = large_group ? IC * IH2 * IW2 * sizeof(int8_t) * nr_threads | |||||
: IC * IH2 * IW2 * sizeof(int8_t) * group * batch; | |||||
} | } | ||||
if (need_dst_copy_str1(param)) { | if (need_dst_copy_str1(param)) { | ||||
part1 = OH2 * OW2 * sizeof(int16_t) * nr_threads + 16; | part1 = OH2 * OW2 * sizeof(int16_t) * nr_threads + 16; | ||||
@@ -255,9 +248,10 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoI8x8x16Direct::get_kimpls( | |||||
size_t IC = param.filter_meta.icpg; | size_t IC = param.filter_meta.icpg; | ||||
size_t OC = param.filter_meta.ocpg; | size_t OC = param.filter_meta.ocpg; | ||||
size_t group = fm.group; | size_t group = fm.group; | ||||
bool large_group = group >= param.nr_threads; | |||||
WorkspaceBundle bundle = get_bundle(param); | WorkspaceBundle bundle = get_bundle(param); | ||||
SmallVector<NCBKern> ret_kerns; | SmallVector<NCBKern> ret_kerns; | ||||
if (m_large_group) { | |||||
if (large_group) { | |||||
auto exec_one_group = [bundle](const NCBKernParam& kern_param, | auto exec_one_group = [bundle](const NCBKernParam& kern_param, | ||||
const NCBKernIndex& ncb_index) mutable { | const NCBKernIndex& ncb_index) mutable { | ||||
auto fm = kern_param.filter_meta; | auto fm = kern_param.filter_meta; | ||||
@@ -302,28 +296,20 @@ ConvBiasImpl::AlgoI8x8x16Direct::dispatch_kerns( | |||||
} | } | ||||
/* ===================== stride-2 algo ===================== */ | /* ===================== stride-2 algo ===================== */ | ||||
bool ConvBiasImpl::AlgoI8x8x16Stride2::usable( | |||||
const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy algo_selection_strategy) const { | |||||
bool ConvBiasImpl::AlgoI8x8x16Stride2::usable(const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy) const { | |||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 2, 0) { | MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 2, 0) { | ||||
auto&& fm = param.filter_meta; | auto&& fm = param.filter_meta; | ||||
auto FH = fm.spatial[0]; | auto FH = fm.spatial[0]; | ||||
bool aviliable = param.bias_mode == BiasMode::NO_BIAS && | |||||
param.nonlineMode == NonlineMode::IDENTITY && | |||||
fm.format == param::ConvBias::Format::NCHW && | |||||
!fm.should_flip && | |||||
param.src_type.enumv() == DTypeEnum::Int8 && | |||||
param.filter_type.enumv() == DTypeEnum::Int8 && | |||||
param.dst_type.enumv() == DTypeEnum::Int16 && | |||||
fm.dilation[0] == 1 && fm.dilation[1] == 1 && | |||||
fm.stride[0] == 2 && fm.stride[1] == 2 && | |||||
FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5); | |||||
if (algo_selection_strategy == | |||||
ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { | |||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
aviliable &= (large_group == m_large_group); | |||||
} | |||||
return aviliable; | |||||
return param.bias_mode == BiasMode::NO_BIAS && | |||||
param.nonlineMode == NonlineMode::IDENTITY && | |||||
fm.format == param::ConvBias::Format::NCHW && !fm.should_flip && | |||||
param.src_type.enumv() == DTypeEnum::Int8 && | |||||
param.filter_type.enumv() == DTypeEnum::Int8 && | |||||
param.dst_type.enumv() == DTypeEnum::Int16 && | |||||
fm.dilation[0] == 1 && fm.dilation[1] == 1 && | |||||
fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] && | |||||
(FH == 2 || FH == 3 || FH == 5); | |||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return false; | return false; | ||||
@@ -340,9 +326,10 @@ WorkspaceBundle ConvBiasImpl::AlgoI8x8x16Stride2::get_bundle( | |||||
size_t OH2, OW2, IH2, IW2; | size_t OH2, OW2, IH2, IW2; | ||||
get_rectified_size_str2(IH, IW, OH, OW, FH, FW, PH, PW, IH2, IW2, OH2, OW2); | get_rectified_size_str2(IH, IW, OH, OW, FH, FW, PH, PW, IH2, IW2, OH2, OW2); | ||||
size_t part0 = 0u, part1 = 0u; | size_t part0 = 0u, part1 = 0u; | ||||
bool large_group = group >= param.nr_threads; | |||||
if (need_src_copy_str2(param)) { | if (need_src_copy_str2(param)) { | ||||
part0 = m_large_group ? IC * IH2 * IW2 * sizeof(int8_t) * nr_threads | |||||
: IC * IH2 * IW2 * sizeof(int8_t) * group * batch; | |||||
part0 = large_group ? IC * IH2 * IW2 * sizeof(int8_t) * nr_threads | |||||
: IC * IH2 * IW2 * sizeof(int8_t) * group * batch; | |||||
} | } | ||||
if (need_dst_copy_str2(param)) { | if (need_dst_copy_str2(param)) { | ||||
part1 = OH2 * OW2 * sizeof(int16_t) * nr_threads + 16; | part1 = OH2 * OW2 * sizeof(int16_t) * nr_threads + 16; | ||||
@@ -475,9 +462,10 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoI8x8x16Stride2::get_kimpls( | |||||
size_t IC = param.filter_meta.icpg; | size_t IC = param.filter_meta.icpg; | ||||
size_t OC = param.filter_meta.ocpg; | size_t OC = param.filter_meta.ocpg; | ||||
size_t group = fm.group; | size_t group = fm.group; | ||||
bool large_group = group >= param.nr_threads; | |||||
WorkspaceBundle bundle = get_bundle(param); | WorkspaceBundle bundle = get_bundle(param); | ||||
SmallVector<NCBKern> ret_kerns; | SmallVector<NCBKern> ret_kerns; | ||||
if (m_large_group) { | |||||
if (large_group) { | |||||
auto exec_one_group = [bundle](const NCBKernParam& kern_param, | auto exec_one_group = [bundle](const NCBKernParam& kern_param, | ||||
const NCBKernIndex& ncb_index) mutable { | const NCBKernIndex& ncb_index) mutable { | ||||
auto fm = kern_param.filter_meta; | auto fm = kern_param.filter_meta; | ||||
@@ -26,15 +26,10 @@ class ConvBiasImpl::AlgoI8x8x16Direct final : public AlgoBase { | |||||
const NCBKernParam& kern_param, | const NCBKernParam& kern_param, | ||||
const NCBKernIndex& ncb_index, | const NCBKernIndex& ncb_index, | ||||
const CpuNDRange& workspace_ids); | const CpuNDRange& workspace_ids); | ||||
bool m_large_group; | |||||
public: | public: | ||||
AlgoI8x8x16Direct(bool large_group) : m_large_group(large_group) {} | |||||
bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
const char* name() const override { | |||||
return m_large_group ? "I8816DIRECT_LARGE_GROUP" | |||||
: "I8816DIRECT_SMALL_GROUP"; | |||||
} | |||||
const char* name() const override { return "I8816DIRECT"; } | |||||
bool usable(const NCBKernSizeParam& param, | bool usable(const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy algo_selection_strategy) const override; | AlgoSelectionStrategy algo_selection_strategy) const override; | ||||
size_t get_workspace(const NCBKernSizeParam& param) const override; | size_t get_workspace(const NCBKernSizeParam& param) const override; | ||||
@@ -53,15 +48,9 @@ class ConvBiasImpl::AlgoI8x8x16Stride2 final : public AlgoBase { | |||||
const NCBKernParam& kern_param, | const NCBKernParam& kern_param, | ||||
const NCBKernIndex& ncb_index, | const NCBKernIndex& ncb_index, | ||||
const CpuNDRange& workspace_ids); | const CpuNDRange& workspace_ids); | ||||
bool m_large_group; | |||||
public: | public: | ||||
AlgoI8x8x16Stride2(bool large_group) : m_large_group(large_group) {} | |||||
bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
const char* name() const override { | |||||
return m_large_group ? "I8816STRD2_LARGE_GROUP" | |||||
: "I8816STRD2_SMALL_GROUP"; | |||||
} | |||||
const char* name() const override { return "I8816STRD2"; } | |||||
bool usable(const NCBKernSizeParam& param, | bool usable(const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy algo_selection_strategy) const override; | AlgoSelectionStrategy algo_selection_strategy) const override; | ||||
@@ -40,28 +40,20 @@ uint8_t arm_common_algo_type_storage; | |||||
} // anonymous namespace | } // anonymous namespace | ||||
class ConvBiasImpl::AlgoPack : NonCopyableObj { | class ConvBiasImpl::AlgoPack : NonCopyableObj { | ||||
AlgoQU8DirectStride2 qu8_direct_stride2_large_group{true}; | |||||
AlgoQU8DirectStride2 qu8_direct_stride2_small_group{false}; | |||||
AlgoQU8DirectStride1 qu8_direct_stride1_large_group{true}; | |||||
AlgoQU8DirectStride1 qu8_direct_stride1_small_group{false}; | |||||
AlgoS8DirectStride2 s8_direct_stride2_large_group{true}; | |||||
AlgoS8DirectStride2 s8_direct_stride2_small_group{false}; | |||||
AlgoQU8DirectStride2 qu8_direct_stride2; | |||||
AlgoQU8DirectStride1 qu8_direct_stride1; | |||||
AlgoS8DirectStride2 s8_direct_stride2; | |||||
AlgoS8DirectNCHW44 s8_direct_nchw44; | AlgoS8DirectNCHW44 s8_direct_nchw44; | ||||
AlgoS8DirectNCHWNCHW44 s8_direct_nchw_nchw44; | AlgoS8DirectNCHWNCHW44 s8_direct_nchw_nchw44; | ||||
AlgoS8DirectStride1 s8_direct_stride1_large_group{true}; | |||||
AlgoS8DirectStride1 s8_direct_stride1_small_group{false}; | |||||
AlgoS8DirectStride1 s8_direct_stride1; | |||||
AlgoS8ChanWiseStride1NCHW44 s8_channel_wise_stride1_nchw44; | AlgoS8ChanWiseStride1NCHW44 s8_channel_wise_stride1_nchw44; | ||||
AlgoS8ChanWiseStride2NCHW44 s8_channel_wise_stride2_nchw44; | AlgoS8ChanWiseStride2NCHW44 s8_channel_wise_stride2_nchw44; | ||||
#if __ARM_FEATURE_DOTPROD | #if __ARM_FEATURE_DOTPROD | ||||
AlgoDotS8DirectStride1 ds8_direct_stride1_large_group{true}; | |||||
AlgoDotS8DirectStride1 ds8_direct_stride1_small_group{false}; | |||||
AlgoDotS8DirectStride2 ds8_direct_stride2_large_group{true}; | |||||
AlgoDotS8DirectStride2 ds8_direct_stride2_small_group{false}; | |||||
AlgoDotU8DirectStride1 du8_direct_stride1_large_group{true}; | |||||
AlgoDotU8DirectStride1 du8_direct_stride1_small_group{false}; | |||||
AlgoDotU8DirectStride2 du8_direct_stride2_large_group{true}; | |||||
AlgoDotU8DirectStride2 du8_direct_stride2_small_group{false}; | |||||
AlgoDotS8DirectStride1 ds8_direct_stride1; | |||||
AlgoDotS8DirectStride2 ds8_direct_stride2; | |||||
AlgoDotU8DirectStride1 du8_direct_stride1; | |||||
AlgoDotU8DirectStride2 du8_direct_stride2; | |||||
AlgoDotS8Direct_NCHW44 ds8_direct_nchw44; | AlgoDotS8Direct_NCHW44 ds8_direct_nchw44; | ||||
AlgoDotS8DirectNCHWNCHW44 ds8_direct_nchw_nchw44; | AlgoDotS8DirectNCHWNCHW44 ds8_direct_nchw_nchw44; | ||||
@@ -71,23 +63,16 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { | |||||
AlgoF32ChannelWiseNCHW44 f32_chanel_wise_nchw44; | AlgoF32ChannelWiseNCHW44 f32_chanel_wise_nchw44; | ||||
AlgoF32DirectNCHW44 f32_direct_nchw44; | AlgoF32DirectNCHW44 f32_direct_nchw44; | ||||
AlgoF32Direct f32_direct_large_group{true}; | |||||
AlgoF32Direct f32_direct_small_group{false}; | |||||
AlgoF32DirectStride2 f32_direct_stride2_large_group{true}; | |||||
AlgoF32DirectStride2 f32_direct_stride2_small_group{false}; | |||||
AlgoF32DirectStride1 f32_direct_stride1_large_group{true}; | |||||
AlgoF32DirectStride1 f32_direct_stride1_small_group{false}; | |||||
AlgoF32Direct f32_direct; | |||||
AlgoF32DirectStride2 f32_direct_stride2; | |||||
AlgoF32DirectStride1 f32_direct_stride1; | |||||
AlgoI8x8x16Direct i8x8x16_direct_large_group{true}; | |||||
AlgoI8x8x16Direct i8x8x16_direct_small_group{false}; | |||||
AlgoI8x8x16Stride2 i8x8x16_stride2_large_group{true}; | |||||
AlgoI8x8x16Stride2 i8x8x16_stride2_small_group{false}; | |||||
AlgoI8x8x16Direct i8x8x16_direct; | |||||
AlgoI8x8x16Stride2 i8x8x16_stride2; | |||||
AlgoI8x8x16Stride2Filter2 i8x8x16_stride2_filter2; | AlgoI8x8x16Stride2Filter2 i8x8x16_stride2_filter2; | ||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
AlgoF16Direct f16_direct_large_group{true}; | |||||
AlgoF16Direct f16_direct_small_group{false}; | |||||
AlgoF16DirectStride1 f16_direct_stride1_large_group{true}; | |||||
AlgoF16DirectStride1 f16_direct_stride1_small_group{false}; | |||||
AlgoF16Direct f16_direct; | |||||
AlgoF16DirectStride1 f16_direct_stride1; | |||||
#endif | #endif | ||||
SmallVector<std::unique_ptr<AlgoBase>> refhold; | SmallVector<std::unique_ptr<AlgoBase>> refhold; | ||||
@@ -95,54 +80,39 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { | |||||
public: | public: | ||||
AlgoPack() { | AlgoPack() { | ||||
#if __ARM_FEATURE_DOTPROD | #if __ARM_FEATURE_DOTPROD | ||||
direct_algos.emplace_back(&ds8_direct_stride1_large_group); | |||||
direct_algos.emplace_back(&ds8_direct_stride1_small_group); | |||||
direct_algos.emplace_back(&ds8_direct_stride2_large_group); | |||||
direct_algos.emplace_back(&ds8_direct_stride2_small_group); | |||||
direct_algos.emplace_back(&du8_direct_stride1_large_group); | |||||
direct_algos.emplace_back(&du8_direct_stride1_small_group); | |||||
direct_algos.emplace_back(&du8_direct_stride2_large_group); | |||||
direct_algos.emplace_back(&du8_direct_stride2_small_group); | |||||
direct_algos.emplace_back(&ds8_direct_stride1); | |||||
direct_algos.emplace_back(&ds8_direct_stride2); | |||||
direct_algos.emplace_back(&du8_direct_stride1); | |||||
direct_algos.emplace_back(&du8_direct_stride2); | |||||
direct_algos.emplace_back(&ds8_direct_nchw44); | direct_algos.emplace_back(&ds8_direct_nchw44); | ||||
direct_algos.emplace_back(&ds8_direct_nchw_nchw44); | direct_algos.emplace_back(&ds8_direct_nchw_nchw44); | ||||
#endif | #endif | ||||
direct_algos.emplace_back(&qu8_direct_stride2_large_group); | |||||
direct_algos.emplace_back(&qu8_direct_stride2_small_group); | |||||
direct_algos.emplace_back(&qu8_direct_stride1_large_group); | |||||
direct_algos.emplace_back(&qu8_direct_stride1_small_group); | |||||
direct_algos.emplace_back(&s8_direct_stride2_large_group); | |||||
direct_algos.emplace_back(&s8_direct_stride2_small_group); | |||||
direct_algos.emplace_back(&qu8_direct_stride2); | |||||
direct_algos.emplace_back(&qu8_direct_stride1); | |||||
direct_algos.emplace_back(&s8_direct_stride2); | |||||
direct_algos.emplace_back(&s8_direct_nchw44); | direct_algos.emplace_back(&s8_direct_nchw44); | ||||
direct_algos.emplace_back(&s8_direct_nchw_nchw44); | direct_algos.emplace_back(&s8_direct_nchw_nchw44); | ||||
direct_algos.emplace_back(&s8_direct_stride1_large_group); | |||||
direct_algos.emplace_back(&s8_direct_stride1_small_group); | |||||
direct_algos.emplace_back(&s8_direct_stride1); | |||||
direct_algos.emplace_back(&s8_channel_wise_stride1_nchw44); | direct_algos.emplace_back(&s8_channel_wise_stride1_nchw44); | ||||
direct_algos.emplace_back(&s8_channel_wise_stride2_nchw44); | direct_algos.emplace_back(&s8_channel_wise_stride2_nchw44); | ||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
direct_algos.emplace_back(&f16_direct_stride1_large_group); | |||||
direct_algos.emplace_back(&f16_direct_stride1_small_group); | |||||
direct_algos.emplace_back(&f16_direct_large_group); | |||||
direct_algos.emplace_back(&f16_direct_small_group); | |||||
direct_algos.emplace_back(&f16_direct_stride1); | |||||
direct_algos.emplace_back(&f16_direct); | |||||
#endif | #endif | ||||
direct_algos.emplace_back(&i8x8x16_direct_large_group); | |||||
direct_algos.emplace_back(&i8x8x16_direct_small_group); | |||||
direct_algos.emplace_back(&i8x8x16_direct); | |||||
direct_algos.emplace_back(&i8x8x16_stride2_filter2); | direct_algos.emplace_back(&i8x8x16_stride2_filter2); | ||||
direct_algos.emplace_back(&i8x8x16_stride2_large_group); | |||||
direct_algos.emplace_back(&i8x8x16_stride2_small_group); | |||||
direct_algos.emplace_back(&i8x8x16_stride2); | |||||
direct_algos.emplace_back(&f32_direct_stride2_nchw_nchw44); | direct_algos.emplace_back(&f32_direct_stride2_nchw_nchw44); | ||||
direct_algos.emplace_back(&f32_chanel_wise_nchw44); | direct_algos.emplace_back(&f32_chanel_wise_nchw44); | ||||
direct_algos.emplace_back(&f32_direct_nchw44); | direct_algos.emplace_back(&f32_direct_nchw44); | ||||
direct_algos.emplace_back(&f32_direct_stride1_large_group); | |||||
direct_algos.emplace_back(&f32_direct_stride1_small_group); | |||||
direct_algos.emplace_back(&f32_direct_stride2_large_group); | |||||
direct_algos.emplace_back(&f32_direct_stride2_small_group); | |||||
direct_algos.emplace_back(&f32_direct_large_group); | |||||
direct_algos.emplace_back(&f32_direct_small_group); | |||||
direct_algos.emplace_back(&f32_direct_stride1); | |||||
direct_algos.emplace_back(&f32_direct_stride2); | |||||
direct_algos.emplace_back(&f32_direct); | |||||
static CpuOprDelegationStorage<2> storage; | static CpuOprDelegationStorage<2> storage; | ||||
auto matmul_opr = storage.get<MatrixMul, 0>(); | auto matmul_opr = storage.get<MatrixMul, 0>(); | ||||
@@ -25,21 +25,15 @@ using namespace megdnn; | |||||
using namespace arm_common; | using namespace arm_common; | ||||
/* ===================== stride1 algo ===================== */ | /* ===================== stride1 algo ===================== */ | ||||
bool ConvBiasImpl::AlgoQU8DirectStride1::usable( | |||||
const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy algo_selection_strategy) const { | |||||
bool avaible = direct_quint8_stride1::can_conv_direct_stride1_quint8(param); | |||||
if (algo_selection_strategy == | |||||
ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { | |||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
avaible &= (large_group == m_large_group); | |||||
} | |||||
return avaible; | |||||
bool ConvBiasImpl::AlgoQU8DirectStride1::usable(const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy) const { | |||||
return direct_quint8_stride1::can_conv_direct_stride1_quint8(param); | |||||
} | } | ||||
size_t ConvBiasImpl::AlgoQU8DirectStride1::get_workspace( | size_t ConvBiasImpl::AlgoQU8DirectStride1::get_workspace( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
auto bundle = direct_quint8_stride1::get_bundle(param, m_large_group); | |||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
auto bundle = direct_quint8_stride1::get_bundle(param, large_group); | |||||
return bundle.total_size_in_bytes(); | return bundle.total_size_in_bytes(); | ||||
} | } | ||||
@@ -47,7 +41,8 @@ SmallVector<ConvBiasImpl::NCBKern> | |||||
ConvBiasImpl::AlgoQU8DirectStride1::dispatch_kerns( | ConvBiasImpl::AlgoQU8DirectStride1::dispatch_kerns( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, 0, 0) { | MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, 0, 0) { | ||||
return direct_quint8_stride1::get_kimpls(param, m_large_group); | |||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
return direct_quint8_stride1::get_kimpls(param, large_group); | |||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return {}; | return {}; | ||||
@@ -55,20 +50,15 @@ ConvBiasImpl::AlgoQU8DirectStride1::dispatch_kerns( | |||||
/* ===================== stride2 algo ===================== */ | /* ===================== stride2 algo ===================== */ | ||||
bool ConvBiasImpl::AlgoQU8DirectStride2::usable( | bool ConvBiasImpl::AlgoQU8DirectStride2::usable( | ||||
const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy algo_selection_strategy) const { | |||||
bool avaible = direct_quint8_stride2::can_conv_direct_stride2_quint8(param); | |||||
if (algo_selection_strategy == | |||||
ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { | |||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
avaible &= (large_group == m_large_group); | |||||
} | |||||
return avaible; | |||||
const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy) const { | |||||
return direct_quint8_stride2::can_conv_direct_stride2_quint8(param); | |||||
} | } | ||||
size_t ConvBiasImpl::AlgoQU8DirectStride2::get_workspace( | size_t ConvBiasImpl::AlgoQU8DirectStride2::get_workspace( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
auto bundle = direct_quint8_stride2::get_bundle(param, m_large_group); | |||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
auto bundle = direct_quint8_stride2::get_bundle(param, large_group); | |||||
return bundle.total_size_in_bytes(); | return bundle.total_size_in_bytes(); | ||||
} | } | ||||
@@ -76,31 +66,23 @@ SmallVector<ConvBiasImpl::NCBKern> | |||||
ConvBiasImpl::AlgoQU8DirectStride2::dispatch_kerns( | ConvBiasImpl::AlgoQU8DirectStride2::dispatch_kerns( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, 0, 1) { | MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, 0, 1) { | ||||
return direct_quint8_stride2::get_kimpls(param, m_large_group); | |||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
return direct_quint8_stride2::get_kimpls(param, large_group); | |||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return {}; | return {}; | ||||
} | } | ||||
#if __ARM_FEATURE_DOTPROD | #if __ARM_FEATURE_DOTPROD | ||||
/* ===================== stride1 algo ===================== */ | /* ===================== stride1 algo ===================== */ | ||||
bool ConvBiasImpl::AlgoDotU8DirectStride1::usable( | |||||
const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy algo_selection_strategy) const { | |||||
bool avaible = | |||||
direct_dotprod_quint8_stride1::can_conv_direct_stride1_quint8( | |||||
param); | |||||
if (algo_selection_strategy == | |||||
ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { | |||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
avaible &= (large_group == m_large_group); | |||||
} | |||||
return avaible; | |||||
bool ConvBiasImpl::AlgoDotU8DirectStride1::usable(const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy) const { | |||||
return direct_dotprod_quint8_stride1::can_conv_direct_stride1_quint8(param); | |||||
} | } | ||||
size_t ConvBiasImpl::AlgoDotU8DirectStride1::get_workspace( | size_t ConvBiasImpl::AlgoDotU8DirectStride1::get_workspace( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
auto bundle = | |||||
direct_dotprod_quint8_stride1::get_bundle(param, m_large_group); | |||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
auto bundle = direct_dotprod_quint8_stride1::get_bundle(param, large_group); | |||||
return bundle.total_size_in_bytes(); | return bundle.total_size_in_bytes(); | ||||
} | } | ||||
@@ -108,31 +90,23 @@ SmallVector<ConvBiasImpl::NCBKern> | |||||
ConvBiasImpl::AlgoDotU8DirectStride1::dispatch_kerns( | ConvBiasImpl::AlgoDotU8DirectStride1::dispatch_kerns( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, 1, 0) { | MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, 1, 0) { | ||||
return direct_dotprod_quint8_stride1::get_kimpls(param, m_large_group); | |||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
return direct_dotprod_quint8_stride1::get_kimpls(param, large_group); | |||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return {}; | return {}; | ||||
} | } | ||||
/* ===================== stride2 algo ===================== */ | /* ===================== stride2 algo ===================== */ | ||||
bool ConvBiasImpl::AlgoDotU8DirectStride2::usable( | |||||
const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy algo_selection_strategy) const { | |||||
bool avaible = | |||||
direct_dotprod_quint8_stride2::can_conv_direct_stride2_quint8( | |||||
param); | |||||
if (algo_selection_strategy == | |||||
ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { | |||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
avaible &= (large_group == m_large_group); | |||||
} | |||||
return avaible; | |||||
bool ConvBiasImpl::AlgoDotU8DirectStride2::usable(const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy) const { | |||||
return direct_dotprod_quint8_stride2::can_conv_direct_stride2_quint8(param); | |||||
} | } | ||||
size_t ConvBiasImpl::AlgoDotU8DirectStride2::get_workspace( | size_t ConvBiasImpl::AlgoDotU8DirectStride2::get_workspace( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
auto bundle = | |||||
direct_dotprod_quint8_stride2::get_bundle(param, m_large_group); | |||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
auto bundle = direct_dotprod_quint8_stride2::get_bundle(param, large_group); | |||||
return bundle.total_size_in_bytes(); | return bundle.total_size_in_bytes(); | ||||
} | } | ||||
@@ -140,7 +114,8 @@ SmallVector<ConvBiasImpl::NCBKern> | |||||
ConvBiasImpl::AlgoDotU8DirectStride2::dispatch_kerns( | ConvBiasImpl::AlgoDotU8DirectStride2::dispatch_kerns( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, 1, 1) { | MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, 1, 1) { | ||||
return direct_dotprod_quint8_stride2::get_kimpls(param, m_large_group); | |||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
return direct_dotprod_quint8_stride2::get_kimpls(param, large_group); | |||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return {}; | return {}; | ||||
@@ -18,14 +18,10 @@ namespace megdnn { | |||||
namespace arm_common { | namespace arm_common { | ||||
class ConvBiasImpl::AlgoQU8DirectStride1 final : public AlgoBase { | class ConvBiasImpl::AlgoQU8DirectStride1 final : public AlgoBase { | ||||
bool m_large_group; | |||||
public: | public: | ||||
AlgoQU8DirectStride1(bool large_group) : m_large_group(large_group) {} | |||||
bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
const char* name() const override { | |||||
return m_large_group ? "QU8STRD1_LARGE_GROUP" : "QU8STRD1_SMALL_GROUP"; | |||||
} | |||||
const char* name() const override { return "QU8STRD1"; } | |||||
bool usable(const NCBKernSizeParam& param, | bool usable(const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy algo_selection_strategy) const override; | AlgoSelectionStrategy algo_selection_strategy) const override; | ||||
@@ -36,14 +32,10 @@ public: | |||||
}; | }; | ||||
class ConvBiasImpl::AlgoQU8DirectStride2 final : public AlgoBase { | class ConvBiasImpl::AlgoQU8DirectStride2 final : public AlgoBase { | ||||
bool m_large_group; | |||||
public: | public: | ||||
AlgoQU8DirectStride2(bool large_group) : m_large_group(large_group) {} | |||||
bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
const char* name() const override { | |||||
return m_large_group ? "QU8STRD2_LARGE_GROUP" : "QU8STRD2_SMALL_GROUP"; | |||||
} | |||||
const char* name() const override { return "QU8STRD2"; } | |||||
bool usable(const NCBKernSizeParam& param, | bool usable(const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy algo_selection_strategy) const override; | AlgoSelectionStrategy algo_selection_strategy) const override; | ||||
@@ -53,15 +45,10 @@ public: | |||||
}; | }; | ||||
#if __ARM_FEATURE_DOTPROD | #if __ARM_FEATURE_DOTPROD | ||||
class ConvBiasImpl::AlgoDotU8DirectStride1 final : public AlgoBase { | class ConvBiasImpl::AlgoDotU8DirectStride1 final : public AlgoBase { | ||||
bool m_large_group; | |||||
public: | public: | ||||
AlgoDotU8DirectStride1(bool large_group) : m_large_group(large_group) {} | |||||
bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
const char* name() const override { | |||||
return m_large_group ? "ARMDOTU8STRD1_LARGE_GROUP" | |||||
: "ARMDOTU8STRD1_SMALL_GROUP"; | |||||
} | |||||
const char* name() const override { return "ARMDOTU8STRD1"; } | |||||
bool usable(const NCBKernSizeParam& param, | bool usable(const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy algo_selection_strategy) const override; | AlgoSelectionStrategy algo_selection_strategy) const override; | ||||
@@ -72,15 +59,10 @@ public: | |||||
}; | }; | ||||
class ConvBiasImpl::AlgoDotU8DirectStride2 final : public AlgoBase { | class ConvBiasImpl::AlgoDotU8DirectStride2 final : public AlgoBase { | ||||
bool m_large_group; | |||||
public: | public: | ||||
AlgoDotU8DirectStride2(bool large_group) : m_large_group(large_group) {} | |||||
bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
const char* name() const override { | |||||
return m_large_group ? "ARMDOTU8STRD2_LARGE_GROUP" | |||||
: "ARMDOTU8STRD2_SMALL_GROUP"; | |||||
} | |||||
const char* name() const override { return "ARMDOTU8STRD2"; } | |||||
bool usable(const NCBKernSizeParam& param, | bool usable(const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy algo_selection_strategy) const override; | AlgoSelectionStrategy algo_selection_strategy) const override; | ||||
@@ -65,9 +65,10 @@ void get_rectified_size(size_t IH, size_t IW, size_t OH, size_t OW, size_t FH, | |||||
size_t IC = param.filter_meta.icpg; \ | size_t IC = param.filter_meta.icpg; \ | ||||
size_t OC = param.filter_meta.ocpg; \ | size_t OC = param.filter_meta.ocpg; \ | ||||
size_t group = fm.group; \ | size_t group = fm.group; \ | ||||
bool large_group = group >= param.nr_threads; \ | |||||
WorkspaceBundle bundle = get_bundle(param); \ | WorkspaceBundle bundle = get_bundle(param); \ | ||||
SmallVector<NCBKern> ret_kerns; \ | SmallVector<NCBKern> ret_kerns; \ | ||||
if (m_large_group) { \ | |||||
if (large_group) { \ | |||||
auto exec_one_group = [bundle]( \ | auto exec_one_group = [bundle]( \ | ||||
const NCBKernParam& kern_param, \ | const NCBKernParam& kern_param, \ | ||||
const NCBKernIndex& ncb_index) mutable { \ | const NCBKernIndex& ncb_index) mutable { \ | ||||
@@ -104,22 +105,15 @@ void get_rectified_size(size_t IH, size_t IW, size_t OH, size_t OW, size_t FH, | |||||
/* ===================== direct algo ===================== */ | /* ===================== direct algo ===================== */ | ||||
bool ConvBiasImpl::AlgoDirect::usable( | |||||
const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy algo_selection_strategy) const { | |||||
bool ConvBiasImpl::AlgoDirect::usable(const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy) const { | |||||
auto&& fm = param.filter_meta; | auto&& fm = param.filter_meta; | ||||
bool aviliable = fm.format == Param::Format::NCHW && fm.spatial_ndim == 2 && | |||||
param.src_type.enumv() == DTypeEnum::Float32 && | |||||
param.filter_type.enumv() == DTypeEnum::Float32 && | |||||
param.dst_type.enumv() == DTypeEnum::Float32 && | |||||
fm.dilation[0] == 1 && fm.dilation[1] == 1 && | |||||
fm.spatial[0] <= 7 && fm.stride[0] == 1 && | |||||
fm.stride[1] == 1; | |||||
if (algo_selection_strategy == AlgoSelectionStrategy::HEURISTIC) { | |||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
aviliable &= (large_group == m_large_group); | |||||
} | |||||
return aviliable; | |||||
return fm.format == Param::Format::NCHW && fm.spatial_ndim == 2 && | |||||
param.src_type.enumv() == DTypeEnum::Float32 && | |||||
param.filter_type.enumv() == DTypeEnum::Float32 && | |||||
param.dst_type.enumv() == DTypeEnum::Float32 && | |||||
fm.dilation[0] == 1 && fm.dilation[1] == 1 && fm.spatial[0] <= 7 && | |||||
fm.stride[0] == 1 && fm.stride[1] == 1; | |||||
} | } | ||||
WorkspaceBundle ConvBiasImpl::AlgoDirect::get_bundle( | WorkspaceBundle ConvBiasImpl::AlgoDirect::get_bundle( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
@@ -133,9 +127,10 @@ WorkspaceBundle ConvBiasImpl::AlgoDirect::get_bundle( | |||||
get_rectified_img_size(IH, IW, FH, FW, OH, OW, fm.padding[0], fm.padding[1], | get_rectified_img_size(IH, IW, FH, FW, OH, OW, fm.padding[0], fm.padding[1], | ||||
IH2, IW2, OH2, OW2); | IH2, IW2, OH2, OW2); | ||||
size_t part0 = 0u, part1 = 0u; | size_t part0 = 0u, part1 = 0u; | ||||
bool large_group = group >= param.nr_threads; | |||||
if (IH != IH2 || IW != IW2) { | if (IH != IH2 || IW != IW2) { | ||||
part0 = m_large_group ? IC * IH2 * IW2 * sizeof(float) * nr_threads | |||||
: IC * IH2 * IW2 * sizeof(float) * group * batch; | |||||
part0 = large_group ? IC * IH2 * IW2 * sizeof(float) * nr_threads | |||||
: IC * IH2 * IW2 * sizeof(float) * group * batch; | |||||
} | } | ||||
if (OH != OH2 || OW != OW2) { | if (OH != OH2 || OW != OW2) { | ||||
part1 = OH2 * OW2 * sizeof(float) * nr_threads; | part1 = OH2 * OW2 * sizeof(float) * nr_threads; | ||||
@@ -319,24 +314,17 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoDirect::get_kimpls( | |||||
GET_KERN; | GET_KERN; | ||||
} | } | ||||
/* ===================== direct-stride2 algo ===================== */ | /* ===================== direct-stride2 algo ===================== */ | ||||
bool ConvBiasImpl::AlgoDirectStride2::usable( | |||||
const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy algo_selection_strategy) const { | |||||
bool ConvBiasImpl::AlgoDirectStride2::usable(const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy) const { | |||||
auto&& fm = param.filter_meta; | auto&& fm = param.filter_meta; | ||||
auto FH = fm.spatial[0]; | auto FH = fm.spatial[0]; | ||||
bool aviliable = | |||||
param.filter_meta.format == param::ConvBias::Format::NCHW && | |||||
param.src_type.enumv() == DTypeEnum::Float32 && | |||||
param.filter_type.enumv() == DTypeEnum::Float32 && | |||||
param.dst_type.enumv() == DTypeEnum::Float32 && !fm.should_flip && | |||||
fm.spatial_ndim == 2 && fm.dilation[0] == 1 && | |||||
fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 && | |||||
FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5 || FH == 7); | |||||
if (algo_selection_strategy == AlgoSelectionStrategy::HEURISTIC) { | |||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
aviliable &= (large_group == m_large_group); | |||||
} | |||||
return aviliable; | |||||
return param.filter_meta.format == param::ConvBias::Format::NCHW && | |||||
param.src_type.enumv() == DTypeEnum::Float32 && | |||||
param.filter_type.enumv() == DTypeEnum::Float32 && | |||||
param.dst_type.enumv() == DTypeEnum::Float32 && !fm.should_flip && | |||||
fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.dilation[1] == 1 && | |||||
fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] && | |||||
(FH == 2 || FH == 3 || FH == 5 || FH == 7); | |||||
} | } | ||||
WorkspaceBundle ConvBiasImpl::AlgoDirectStride2::get_bundle( | WorkspaceBundle ConvBiasImpl::AlgoDirectStride2::get_bundle( | ||||
@@ -352,10 +340,10 @@ WorkspaceBundle ConvBiasImpl::AlgoDirectStride2::get_bundle( | |||||
size_t src_size = 0, dst_size = 0; | size_t src_size = 0, dst_size = 0; | ||||
size_t IH2, IW2, OH2, OW2; | size_t IH2, IW2, OH2, OW2; | ||||
get_rectified_size(IH, IW, OH, OW, FH, FW, PH, PW, IH2, IW2, OH2, OW2); | get_rectified_size(IH, IW, OH, OW, FH, FW, PH, PW, IH2, IW2, OH2, OW2); | ||||
bool large_group = group >= param.nr_threads; \ | |||||
if (need_src_copy(param)) { | if (need_src_copy(param)) { | ||||
src_size = m_large_group | |||||
? IC * IH2 * IW2 * sizeof(float) * nr_threads | |||||
: IC * IH2 * IW2 * sizeof(float) * group * batch; | |||||
src_size = large_group ? IC * IH2 * IW2 * sizeof(float) * nr_threads | |||||
: IC * IH2 * IW2 * sizeof(float) * group * batch; | |||||
} | } | ||||
if (need_dst_copy(param)) { | if (need_dst_copy(param)) { | ||||
// we only need one dst plane | // we only need one dst plane | ||||
@@ -29,14 +29,10 @@ class ConvBiasImpl::AlgoDirect final : public AlgoBase { | |||||
const NCBKernParam& kern_param, | const NCBKernParam& kern_param, | ||||
const NCBKernIndex& ncb_index, | const NCBKernIndex& ncb_index, | ||||
const CpuNDRange& workspace_ids); | const CpuNDRange& workspace_ids); | ||||
bool m_large_group; | |||||
public: | public: | ||||
AlgoDirect(bool large_group) : m_large_group(large_group) {} | |||||
bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
const char* name() const override { | const char* name() const override { | ||||
return m_large_group ? "X86_CONV_BIAS_DIRECT_STRIDE1_LARGE_GROUP" | |||||
: "X86_CONV_BIAS_DIRECT_STRIDE1_SMALL_GROUP"; | |||||
return "X86_CONV_BIAS_DIRECT_STRIDE1_LARGE_GROUP"; | |||||
} | } | ||||
bool usable(const NCBKernSizeParam& param, | bool usable(const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy algo_selection_strategy) const override; | AlgoSelectionStrategy algo_selection_strategy) const override; | ||||
@@ -65,14 +61,10 @@ class ConvBiasImpl::AlgoDirectStride2 final : public AlgoBase { | |||||
const NCBKernParam& kern_param, | const NCBKernParam& kern_param, | ||||
const NCBKernIndex& ncb_index, | const NCBKernIndex& ncb_index, | ||||
const CpuNDRange& workspace_ids); | const CpuNDRange& workspace_ids); | ||||
bool m_large_group; | |||||
public: | public: | ||||
AlgoDirectStride2(bool large_group) : m_large_group(large_group) {} | |||||
bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
const char* name() const override { | const char* name() const override { | ||||
return m_large_group ? "X86_CONV_BIAS_DIRECT_STRIDE2_LARGE_GROUP" | |||||
: "X86_CONV_BIAS_DIRECT_STRIDE2_SMALL_GROUP"; | |||||
return "X86_CONV_BIAS_DIRECT_STRIDE2_LARGE_GROUP"; | |||||
} | } | ||||
bool usable(const NCBKernSizeParam& param, | bool usable(const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy algo_selection_strategy) const override; | AlgoSelectionStrategy algo_selection_strategy) const override; | ||||
@@ -76,10 +76,8 @@ void* ConvBiasImpl::AlgoChanWiseAvx2Stride2Qint8::type() const { | |||||
} | } | ||||
class ConvBiasImpl::AlgoPack : NonCopyableObj { | class ConvBiasImpl::AlgoPack : NonCopyableObj { | ||||
AlgoDirect stride1_direct_large_group{true}; | |||||
AlgoDirect stride1_direct_small_group{false}; | |||||
AlgoDirectStride2 stride2_direct_large_group{true}; | |||||
AlgoDirectStride2 stride2_direct_small_group{false}; | |||||
AlgoDirect stride1_direct; | |||||
AlgoDirectStride2 stride2_direct; | |||||
AlgoDirectAvx2Stride1Int8 avx2_stride1_direct_int8; | AlgoDirectAvx2Stride1Int8 avx2_stride1_direct_int8; | ||||
AlgoAVX2DirectConvStride2 avx2_stride2_direct; | AlgoAVX2DirectConvStride2 avx2_stride2_direct; | ||||
AlgoChanWiseAvx2Stride1Qint8 avx2_stride1_chanwsie_qint8; | AlgoChanWiseAvx2Stride1Qint8 avx2_stride1_chanwsie_qint8; | ||||
@@ -103,10 +101,8 @@ public: | |||||
all_algos.emplace_back(&mkldnn_matmul_qint8); | all_algos.emplace_back(&mkldnn_matmul_qint8); | ||||
all_algos.emplace_back(&mkldnn_qint8); | all_algos.emplace_back(&mkldnn_qint8); | ||||
#endif | #endif | ||||
all_algos.emplace_back(&stride1_direct_large_group); | |||||
all_algos.emplace_back(&stride1_direct_small_group); | |||||
all_algos.emplace_back(&stride2_direct_large_group); | |||||
all_algos.emplace_back(&stride2_direct_small_group); | |||||
all_algos.emplace_back(&stride1_direct); | |||||
all_algos.emplace_back(&stride2_direct); | |||||
all_algos.emplace_back(&avx2_stride1_direct_int8); | all_algos.emplace_back(&avx2_stride1_direct_int8); | ||||
all_algos.emplace_back(&avx2_stride2_direct); | all_algos.emplace_back(&avx2_stride2_direct); | ||||
all_algos.emplace_back(&avx2_stride1_chanwsie_qint8); | all_algos.emplace_back(&avx2_stride1_chanwsie_qint8); | ||||
@@ -81,15 +81,10 @@ void checker_conv_bias(std::vector<conv_bias::TestArg> args, Handle* handle, | |||||
{arg.src, arg.filter, arg.bias, {}, {}}); | {arg.src, arg.filter, arg.bias, {}, {}}); | ||||
} | } | ||||
} | } | ||||
TEST_F(AARCH64_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2_LARGE_GROUP) { | |||||
TEST_F(AARCH64_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2) { | |||||
check_conv_bias( | check_conv_bias( | ||||
conv_bias::get_conv_bias_args({2, 3, 5, 7}, 2, false, false, false), | conv_bias::get_conv_bias_args({2, 3, 5, 7}, 2, false, false, false), | ||||
handle(), "ARMV8F32STRD2_LARGE_GROUP"); | |||||
} | |||||
TEST_F(AARCH64_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2_SMALL_GROUP) { | |||||
check_conv_bias( | |||||
conv_bias::get_conv_bias_args({2, 3, 5, 7}, 2, false, false, false), | |||||
handle(), "ARMV8F32STRD2_SMALL_GROUP"); | |||||
handle(), "ARMV8F32STRD2"); | |||||
} | } | ||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
@@ -114,17 +109,11 @@ void checker_conv_bias_fp16(std::vector<conv_bias::TestArg> args, | |||||
} | } | ||||
} | } | ||||
TEST_F(AARCH64_MULTI_THREADS, CONVBIAS_DIRECT_FP16_STR2_LARGE_GROUP) { | |||||
NormalRNG rng(1); | |||||
checker_conv_bias_f16( | |||||
conv_bias::get_conv_bias_args({2, 3, 5}, 2, false, false, false), | |||||
handle(), rng, "ARMV8F16STRD2_LARGE_GROUP", 0.04); | |||||
} | |||||
TEST_F(AARCH64_MULTI_THREADS, CONVBIAS_DIRECT_FP16_STR2_SMALL_GROUP) { | |||||
TEST_F(AARCH64_MULTI_THREADS, CONVBIAS_DIRECT_FP16_STR2) { | |||||
NormalRNG rng(1); | NormalRNG rng(1); | ||||
checker_conv_bias_f16( | checker_conv_bias_f16( | ||||
conv_bias::get_conv_bias_args({2, 3, 5}, 2, false, false, false), | conv_bias::get_conv_bias_args({2, 3, 5}, 2, false, false, false), | ||||
handle(), rng, "ARMV8F16STRD2_SMALL_GROUP", 0.04); | |||||
handle(), rng, "ARMV8F16STRD2", 0.04); | |||||
} | } | ||||
#endif | #endif | ||||
@@ -1310,8 +1310,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CHANNEL_WISE_F32_STRIDE1_NCHW44) { | |||||
benchmark0.set_param(param); | benchmark0.set_param(param); | ||||
benchmark0.set_times(RUN); | benchmark0.set_times(RUN); | ||||
benchmark0.set_before_exec_callback( | benchmark0.set_before_exec_callback( | ||||
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>( | |||||
"F32STRD1_LARGE_GROUP")); | |||||
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>("F32STRD1")); | |||||
auto opr = handle()->create_operator<ConvBias>(); | auto opr = handle()->create_operator<ConvBias>(); | ||||
opr->param() = param; | opr->param() = param; | ||||
@@ -1385,8 +1384,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CHANNEL_WISE_F32_STRIDE2_NCHW44) { | |||||
benchmark0.set_param(param); | benchmark0.set_param(param); | ||||
benchmark0.set_times(RUN); | benchmark0.set_times(RUN); | ||||
benchmark0.set_before_exec_callback( | benchmark0.set_before_exec_callback( | ||||
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>( | |||||
"F32STRD2_LARGE_GROUP")); | |||||
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>("F32STRD2")); | |||||
auto opr = handle()->create_operator<ConvBias>(); | auto opr = handle()->create_operator<ConvBias>(); | ||||
opr->param() = param; | opr->param() = param; | ||||
@@ -1464,8 +1462,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_QINT8_STRIDE1_NCHW44) { | |||||
benchmark0.set_param(param); | benchmark0.set_param(param); | ||||
benchmark0.set_times(RUN); | benchmark0.set_times(RUN); | ||||
benchmark0.set_before_exec_callback( | benchmark0.set_before_exec_callback( | ||||
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>( | |||||
"S8STRD1_LARGE_GROUP")); | |||||
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>("S8STRD1")); | |||||
auto opr = handle()->create_operator<ConvBias>(); | auto opr = handle()->create_operator<ConvBias>(); | ||||
opr->param() = param; | opr->param() = param; | ||||
@@ -356,15 +356,10 @@ void checker_conv_bias_int8x8x32_multi(std::vector<conv_bias::TestArg> args, | |||||
} | } | ||||
/**********************************F32 direct************************/ | /**********************************F32 direct************************/ | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_LARGE_GROUP) { | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32) { | |||||
check_conv_bias( | check_conv_bias( | ||||
get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false), | get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false), | ||||
handle(), "F32DIRECT_LARGE_GROUP"); | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_SMALL_GROUP) { | |||||
check_conv_bias( | |||||
get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false), | |||||
handle(), "F32DIRECT_SMALL_GROUP"); | |||||
handle(), "F32DIRECT"); | |||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K7) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K7) { | ||||
@@ -391,21 +386,13 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S2) { | |||||
handle(), "F32_CONV_NCHW44_DIRECT"); | handle(), "F32_CONV_NCHW44_DIRECT"); | ||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR1_LARGE_GROUP) { | |||||
check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 1, false, false, false), | |||||
handle(), "F32STRD1_LARGE_GROUP"); | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR1_SMALL_GROUP) { | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR1) { | |||||
check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 1, false, false, false), | check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 1, false, false, false), | ||||
handle(), "F32STRD1_SMALL_GROUP"); | |||||
handle(), "F32STRD1"); | |||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2_LARGE_GROUP) { | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2) { | |||||
check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 2, false, false, false), | check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 2, false, false, false), | ||||
handle(), "F32STRD2_LARGE_GROUP"); | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2_SMALL_GROUP) { | |||||
check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 2, false, false, false), | |||||
handle(), "F32STRD2_SMALL_GROUP"); | |||||
handle(), "F32STRD2"); | |||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_NCHW_NCHW44_F32_S2) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_NCHW_NCHW44_F32_S2) { | ||||
check_conv_bias(get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, | check_conv_bias(get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, | ||||
@@ -437,72 +424,41 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE2_FP32_NCHW44) { | |||||
/**********************************F16 direct************************/ | /**********************************F16 direct************************/ | ||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_LARGE_GROUP) { | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16) { | |||||
NormalRNG rng(1); | NormalRNG rng(1); | ||||
checker_conv_bias_f16( | checker_conv_bias_f16( | ||||
get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false), | get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false), | ||||
handle(), rng, "F16DIRECT_LARGE_GROUP", 0.03); | |||||
handle(), rng, "F16DIRECT", 0.03); | |||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_SMALL_GROUP) { | |||||
NormalRNG rng(1); | |||||
checker_conv_bias_f16( | |||||
get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false), | |||||
handle(), rng, "F16DIRECT_SMALL_GROUP", 0.03); | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_STR1_LARGE_GROUP) { | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_STR1) { | |||||
NormalRNG rng(1); | NormalRNG rng(1); | ||||
checker_conv_bias_f16(get_conv_bias_args({2, 3, 5}, 1, false, false, false), | checker_conv_bias_f16(get_conv_bias_args({2, 3, 5}, 1, false, false, false), | ||||
handle(), rng, "F16STRD1_LARGE_GROUP", 0.03); | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_STR1_SMALL_GROUP) { | |||||
NormalRNG rng(1); | |||||
checker_conv_bias_f16(get_conv_bias_args({2, 3, 5}, 1, false, false, false), | |||||
handle(), rng, "F16STRD1_SMALL_GROUP", 0.03); | |||||
handle(), rng, "F16STRD1", 0.03); | |||||
} | } | ||||
#endif | #endif | ||||
/**********************************algo 8816 direct************************/ | /**********************************algo 8816 direct************************/ | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_DIRECT_LARGE_GROUP) { | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_DIRECT) { | |||||
checker_conv_bias_int8x8x16( | checker_conv_bias_int8x8x16( | ||||
get_conv_bias_args({2, 3, 5}, 1, false, true, true), handle(), | get_conv_bias_args({2, 3, 5}, 1, false, true, true), handle(), | ||||
"I8816DIRECT_LARGE_GROUP"); | |||||
"I8816DIRECT"); | |||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_DIRECT_SMALL_GROUP) { | |||||
checker_conv_bias_int8x8x16( | |||||
get_conv_bias_args({2, 3, 5}, 1, false, true, true), handle(), | |||||
"I8816DIRECT_SMALL_GROUP"); | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_STRIDE2_LARGE_GROUP) { | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_STRIDE2) { | |||||
checker_conv_bias_int8x8x16( | checker_conv_bias_int8x8x16( | ||||
get_conv_bias_args({2, 3, 5}, 2, false, true, true), handle(), | get_conv_bias_args({2, 3, 5}, 2, false, true, true), handle(), | ||||
"I8816STRD2_LARGE_GROUP"); | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_STRIDE2_SMALL_GROUP) { | |||||
checker_conv_bias_int8x8x16( | |||||
get_conv_bias_args({2, 3, 5}, 2, false, true, true), handle(), | |||||
"I8816STRD2_SMALL_GROUP"); | |||||
"I8816STRD2"); | |||||
} | } | ||||
/**********************************algo 8-8-32 direct************************/ | /**********************************algo 8-8-32 direct************************/ | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE1_LARGE_GROUP) { | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE1) { | |||||
checker_conv_bias_int8x8x32_multi( | checker_conv_bias_int8x8x32_multi( | ||||
get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(), | get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(), | ||||
"S8STRD1_LARGE_GROUP"); | |||||
"S8STRD1"); | |||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE1_SMALL_GROUP) { | |||||
checker_conv_bias_int8x8x32_multi( | |||||
get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(), | |||||
"S8STRD1_SMALL_GROUP"); | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE2_LARGE_GROUP) { | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE2) { | |||||
checker_conv_bias_int8x8x32_multi( | checker_conv_bias_int8x8x32_multi( | ||||
get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(), | get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(), | ||||
"S8STRD2_LARGE_GROUP"); | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE2_SMALL_GROUP) { | |||||
checker_conv_bias_int8x8x32_multi( | |||||
get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(), | |||||
"S8STRD2_SMALL_GROUP"); | |||||
"S8STRD2"); | |||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, | TEST_F(ARM_COMMON_MULTI_THREADS, | ||||
@@ -520,25 +476,15 @@ TEST_F(ARM_COMMON_MULTI_THREADS, | |||||
} | } | ||||
/********************************qint8 direct******************************/ | /********************************qint8 direct******************************/ | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_LARGE_GROUP) { | |||||
checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args( | |||||
{2, 3, 5, 7}, 1, false, false, false), | |||||
handle(), "S8STRD1_LARGE_GROUP"); | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_SMALL_GROUP) { | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1) { | |||||
checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args( | checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args( | ||||
{2, 3, 5, 7}, 1, false, false, false), | {2, 3, 5, 7}, 1, false, false, false), | ||||
handle(), "S8STRD1_SMALL_GROUP"); | |||||
handle(), "S8STRD1"); | |||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_LARGE_GROUP) { | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2) { | |||||
checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args( | checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args( | ||||
{2, 3, 5, 7}, 2, false, false, false), | {2, 3, 5, 7}, 2, false, false, false), | ||||
handle(), "S8STRD2_LARGE_GROUP"); | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_SMALL_GROUP) { | |||||
checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args( | |||||
{2, 3, 5, 7}, 2, false, false, false), | |||||
handle(), "S8STRD2_SMALL_GROUP"); | |||||
handle(), "S8STRD2"); | |||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44) { | ||||
checker_conv_bias_qint8x8x8( | checker_conv_bias_qint8x8x8( | ||||
@@ -586,25 +532,15 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44_S2) { | |||||
} | } | ||||
/*****************************quint8 direct****************************/ | /*****************************quint8 direct****************************/ | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE1_LARGE_GROUP) { | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE1) { | |||||
checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args( | checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args( | ||||
{2, 3, 5, 7}, 1, false, false, false), | {2, 3, 5, 7}, 1, false, false, false), | ||||
handle(), "QU8STRD1_LARGE_GROUP"); | |||||
handle(), "QU8STRD1"); | |||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE1_SMALL_GROUP) { | |||||
checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args( | |||||
{2, 3, 5, 7}, 1, false, false, false), | |||||
handle(), "QU8STRD1_SMALL_GROUP"); | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2_LARGE_GROUP) { | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2) { | |||||
checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args( | checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args( | ||||
{2, 3, 5, 7}, 2, false, false, false), | {2, 3, 5, 7}, 2, false, false, false), | ||||
handle(), "QU8STRD2_LARGE_GROUP"); | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2_SMALL_GROUP) { | |||||
checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args( | |||||
{2, 3, 5, 7}, 2, false, false, false), | |||||
handle(), "QU8STRD2_SMALL_GROUP"); | |||||
handle(), "QU8STRD2"); | |||||
} | } | ||||
/****************************dot qint8 direct*************************/ | /****************************dot qint8 direct*************************/ | ||||
@@ -624,100 +560,53 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_NCHW_NCHW44) { | |||||
} | } | ||||
checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8_NCHW_NCHW44"); | checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8_NCHW_NCHW44"); | ||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, | |||||
CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_LARGE_GROUP) { | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_WITHDOTPROD) { | |||||
checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args( | checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args( | ||||
{2, 3, 5, 7}, 1, false, false, false), | {2, 3, 5, 7}, 1, false, false, false), | ||||
handle(), "ARMDOTS8STRD1_LARGE_GROUP"); | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, | |||||
CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_SMALL_GROUP) { | |||||
checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args( | |||||
{2, 3, 5, 7}, 1, false, false, false), | |||||
handle(), "ARMDOTS8STRD1_SMALL_GROUP"); | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, | |||||
CONV_BIAS_INT8_STRIDE2_WITHDOTPROD_LARGE_GROUP) { | |||||
checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args( | |||||
{2, 3, 5, 7}, 2, false, false, false), | |||||
handle(), "ARMDOTS8STRD2_LARGE_GROUP"); | |||||
handle(), "ARMDOTS8STRD1"); | |||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, | |||||
CONV_BIAS_INT8_STRIDE2_WITHDOTPROD_SMALL_GROUP) { | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_WITHDOTPROD) { | |||||
checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args( | checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args( | ||||
{2, 3, 5, 7}, 2, false, false, false), | {2, 3, 5, 7}, 2, false, false, false), | ||||
handle(), "ARMDOTS8STRD2_SMALL_GROUP"); | |||||
handle(), "ARMDOTS8STRD2"); | |||||
} | } | ||||
/****************************dot 8-8-32 direct*************************/ | /****************************dot 8-8-32 direct*************************/ | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD1_WITHDOT_LARGE_GROUP) { | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD1_WITHDOT) { | |||||
checker_conv_bias_qint8x8x32( | checker_conv_bias_qint8x8x32( | ||||
get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(), | get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(), | ||||
"ARMDOTS8STRD1_LARGE_GROUP"); | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD1_WITHDOT_SMALL_GROUP) { | |||||
checker_conv_bias_qint8x8x32( | |||||
get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(), | |||||
"ARMDOTS8STRD1_SMALL_GROUP"); | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD2_WITHDOT_LARGE_GROUP) { | |||||
checker_conv_bias_qint8x8x32( | |||||
get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(), | |||||
"ARMDOTS8STRD2_LARGE_GROUP"); | |||||
"ARMDOTS8STRD1"); | |||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD2_WITHDOT_SMALL_GROUP) { | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD2_WITHDOT) { | |||||
checker_conv_bias_qint8x8x32( | checker_conv_bias_qint8x8x32( | ||||
get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(), | get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(), | ||||
"ARMDOTS8STRD2_SMALL_GROUP"); | |||||
"ARMDOTS8STRD2"); | |||||
} | } | ||||
/******************************dot quint8*****************************/ | /******************************dot quint8*****************************/ | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, | |||||
CONV_BIAS_QUINT8_STRIDE1_WITHDOTPROD_LARGE_GROUP) { | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE1_WITHDOTPROD) { | |||||
checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args( | checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args( | ||||
{2, 3, 5, 7}, 1, false, false, false), | {2, 3, 5, 7}, 1, false, false, false), | ||||
handle(), "ARMDOTU8STRD1_LARGE_GROUP"); | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, | |||||
CONV_BIAS_QUINT8_STRIDE1_WITHDOTPROD_SMALL_GROUP) { | |||||
checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args( | |||||
{2, 3, 5, 7}, 1, false, false, false), | |||||
handle(), "ARMDOTU8STRD1_SMALL_GROUP"); | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, | |||||
CONV_BIAS_QUINT8_STRIDE2_WITHDOTPROD_LARGE_GROUP) { | |||||
checker_conv_bias_quint8x8x8( | |||||
get_int8_quint8_conv_bias_args({2, 5, 7}, 2, false, false, false), | |||||
handle(), "ARMDOTU8STRD2_LARGE_GROUP"); | |||||
handle(), "ARMDOTU8STRD1"); | |||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, | |||||
CONV_BIAS_QUINT8_STRIDE2_WITHDOTPROD_SMALL_GROUP) { | |||||
//! TODO: this test without test kernel size=3, add it will case buss error now | |||||
//! in armv7 | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2_WITHDOTPROD) { | |||||
checker_conv_bias_quint8x8x8( | checker_conv_bias_quint8x8x8( | ||||
get_int8_quint8_conv_bias_args({2, 5, 7}, 2, false, false, false), | get_int8_quint8_conv_bias_args({2, 5, 7}, 2, false, false, false), | ||||
handle(), "ARMDOTU8STRD2_SMALL_GROUP"); | |||||
handle(), "ARMDOTU8STRD2"); | |||||
} | } | ||||
/******************************dot quint8x8x32***********************/ | /******************************dot quint8x8x32***********************/ | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE1_LARGE_GROUP) { | |||||
checker_conv_bias_quint8x8x32( | |||||
get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(), | |||||
"ARMDOTU8STRD1_LARGE_GROUP"); | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE1_SMALL_GROUP) { | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE1) { | |||||
checker_conv_bias_quint8x8x32( | checker_conv_bias_quint8x8x32( | ||||
get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(), | get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(), | ||||
"ARMDOTU8STRD1_SMALL_GROUP"); | |||||
"ARMDOTU8STRD1"); | |||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE2_LARGE_GROUP) { | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE2) { | |||||
checker_conv_bias_quint8x8x32( | checker_conv_bias_quint8x8x32( | ||||
get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(), | get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(), | ||||
"ARMDOTU8STRD2_LARGE_GROUP"); | |||||
"ARMDOTU8STRD2"); | |||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE2_SMALL_GROUP) { | |||||
checker_conv_bias_quint8x8x32( | |||||
get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(), | |||||
"ARMDOTU8STRD2_SMALL_GROUP"); | |||||
} | |||||
/******************************dot int8x8x8 nchw44 ***********************/ | /******************************dot int8x8x8 nchw44 ***********************/ | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S1_Q8x8x8) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S1_Q8x8x8) { | ||||
using namespace conv_bias; | using namespace conv_bias; | ||||
@@ -125,7 +125,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF32) { | |||||
bench_case(1, 32, 32, 80, 80, 3, 4); | bench_case(1, 32, 32, 80, 80, 3, 4); | ||||
bench_case(1, 32, 32, 80, 80, 3, 32); | bench_case(1, 32, 32, 80, 80, 3, 32); | ||||
std::string algo_name = "F32DIRECT_LARGE_GROUP"; | |||||
std::string algo_name = "F32DIRECT"; | |||||
printf("Benchmark F32DIRECT_LARGE_GROUP algo\n"); | printf("Benchmark F32DIRECT_LARGE_GROUP algo\n"); | ||||
std::vector<DType> data_type = {dtype::Float32(), dtype::Float32(), | std::vector<DType> data_type = {dtype::Float32(), dtype::Float32(), | ||||
dtype::Float32(), dtype::Float32()}; | dtype::Float32(), dtype::Float32()}; | ||||
@@ -137,7 +137,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF32) { | |||||
{1, {4}}, data_type); | {1, {4}}, data_type); | ||||
shapes_and_computation.clear(); | shapes_and_computation.clear(); | ||||
algo_name = "F32DIRECT_SMALL_GROUP"; | |||||
algo_name = "F32DIRECT"; | |||||
printf("Benchmark F32DIRECT_SMALL_GROUP algo\n"); | printf("Benchmark F32DIRECT_SMALL_GROUP algo\n"); | ||||
bench_case(1, 32, 32, 200, 200, 3, 1); | bench_case(1, 32, 32, 200, 200, 3, 1); | ||||
bench_case(1, 32, 32, 128, 128, 3, 1); | bench_case(1, 32, 32, 128, 128, 3, 1); | ||||
@@ -186,7 +186,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF32_STR1) { | |||||
bench_case(1, 32, 32, 80, 80, 3, 4); | bench_case(1, 32, 32, 80, 80, 3, 4); | ||||
bench_case(1, 32, 32, 80, 80, 3, 32); | bench_case(1, 32, 32, 80, 80, 3, 32); | ||||
std::string algo_name = "F32STRD1_LARGE_GROUP"; | |||||
std::string algo_name = "F32STRD1"; | |||||
printf("Benchmark F32STRD1_LARGE_GROUP algo\n"); | printf("Benchmark F32STRD1_LARGE_GROUP algo\n"); | ||||
std::vector<DType> data_type = {dtype::Float32(), dtype::Float32(), | std::vector<DType> data_type = {dtype::Float32(), dtype::Float32(), | ||||
dtype::Float32(), dtype::Float32()}; | dtype::Float32(), dtype::Float32()}; | ||||
@@ -198,7 +198,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF32_STR1) { | |||||
{1, {4}}, data_type); | {1, {4}}, data_type); | ||||
shapes_and_computation.clear(); | shapes_and_computation.clear(); | ||||
algo_name = "F32STRD1_SMALL_GROUP"; | |||||
algo_name = "F32STRD1"; | |||||
printf("Benchmark F32STRD1_SMALL_GROUP algo\n"); | printf("Benchmark F32STRD1_SMALL_GROUP algo\n"); | ||||
bench_case(1, 32, 32, 200, 200, 3, 1); | bench_case(1, 32, 32, 200, 200, 3, 1); | ||||
bench_case(1, 32, 32, 128, 128, 3, 1); | bench_case(1, 32, 32, 128, 128, 3, 1); | ||||
@@ -249,7 +249,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF32_STR2) { | |||||
bench_case(1, 32, 32, 80, 80, 3, 4, 1, 2); | bench_case(1, 32, 32, 80, 80, 3, 4, 1, 2); | ||||
bench_case(1, 32, 32, 80, 80, 3, 32, 1, 2); | bench_case(1, 32, 32, 80, 80, 3, 32, 1, 2); | ||||
std::string algo_name = "F32STRD2_LARGE_GROUP"; | |||||
std::string algo_name = "F32STRD2"; | |||||
printf("Benchmark F32STRD2_LARGE_GROUP algo\n"); | printf("Benchmark F32STRD2_LARGE_GROUP algo\n"); | ||||
std::vector<DType> data_type = {dtype::Float32(), dtype::Float32(), | std::vector<DType> data_type = {dtype::Float32(), dtype::Float32(), | ||||
dtype::Float32(), dtype::Float32()}; | dtype::Float32(), dtype::Float32()}; | ||||
@@ -261,7 +261,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF32_STR2) { | |||||
{1, {4}}, data_type); | {1, {4}}, data_type); | ||||
shapes_and_computation.clear(); | shapes_and_computation.clear(); | ||||
algo_name = "F32STRD2_SMALL_GROUP"; | |||||
algo_name = "F32STRD2"; | |||||
printf("Benchmark F32STRD2_SMALL_GROUP algo\n"); | printf("Benchmark F32STRD2_SMALL_GROUP algo\n"); | ||||
bench_case(1, 32, 32, 200, 200, 3, 1, 1, 2); | bench_case(1, 32, 32, 200, 200, 3, 1, 1, 2); | ||||
bench_case(1, 32, 32, 128, 128, 3, 1, 1, 2); | bench_case(1, 32, 32, 128, 128, 3, 1, 1, 2); | ||||
@@ -313,7 +313,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF16) { | |||||
bench_case(1, 32, 32, 80, 80, 3, 4); | bench_case(1, 32, 32, 80, 80, 3, 4); | ||||
bench_case(1, 32, 32, 80, 80, 3, 32); | bench_case(1, 32, 32, 80, 80, 3, 32); | ||||
std::string algo_name = "F16DIRECT_LARGE_GROUP"; | |||||
std::string algo_name = "F16DIRECT"; | |||||
printf("Benchmark F16DIRECT_LARGE_GROUP algo\n"); | printf("Benchmark F16DIRECT_LARGE_GROUP algo\n"); | ||||
std::vector<DType> data_type = {dtype::Float16(), dtype::Float16(), | std::vector<DType> data_type = {dtype::Float16(), dtype::Float16(), | ||||
dtype::Float16(), dtype::Float16()}; | dtype::Float16(), dtype::Float16()}; | ||||
@@ -325,7 +325,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF16) { | |||||
{1, {4}}, data_type); | {1, {4}}, data_type); | ||||
shapes_and_computation.clear(); | shapes_and_computation.clear(); | ||||
algo_name = "F16DIRECT_SMALL_GROUP"; | |||||
algo_name = "F16DIRECT"; | |||||
printf("Benchmark F16DIRECT_SMALL_GROUP algo\n"); | printf("Benchmark F16DIRECT_SMALL_GROUP algo\n"); | ||||
bench_case(1, 32, 32, 200, 200, 3, 1); | bench_case(1, 32, 32, 200, 200, 3, 1); | ||||
bench_case(1, 32, 32, 128, 128, 3, 1); | bench_case(1, 32, 32, 128, 128, 3, 1); | ||||
@@ -375,7 +375,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF16_STR1) { | |||||
bench_case(1, 32, 32, 80, 80, 3, 4); | bench_case(1, 32, 32, 80, 80, 3, 4); | ||||
bench_case(1, 32, 32, 80, 80, 3, 32); | bench_case(1, 32, 32, 80, 80, 3, 32); | ||||
std::string algo_name = "F16STRD1_LARGE_GROUP"; | |||||
std::string algo_name = "F16STRD1"; | |||||
printf("Benchmark F16STRD1_LARGE_GROUP algo\n"); | printf("Benchmark F16STRD1_LARGE_GROUP algo\n"); | ||||
std::vector<DType> data_type = {dtype::Float16(), dtype::Float16(), | std::vector<DType> data_type = {dtype::Float16(), dtype::Float16(), | ||||
dtype::Float16(), dtype::Float16()}; | dtype::Float16(), dtype::Float16()}; | ||||
@@ -387,7 +387,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF16_STR1) { | |||||
{1, {4}}, data_type); | {1, {4}}, data_type); | ||||
shapes_and_computation.clear(); | shapes_and_computation.clear(); | ||||
algo_name = "F16STRD1_SMALL_GROUP"; | |||||
algo_name = "F16STRD1"; | |||||
printf("Benchmark F16STRD1_SMALL_GROUP algo\n"); | printf("Benchmark F16STRD1_SMALL_GROUP algo\n"); | ||||
bench_case(1, 32, 32, 200, 200, 3, 1); | bench_case(1, 32, 32, 200, 200, 3, 1); | ||||
bench_case(1, 32, 32, 128, 128, 3, 1); | bench_case(1, 32, 32, 128, 128, 3, 1); | ||||
@@ -439,7 +439,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | |||||
bench_case(1, 32, 32, 80, 80, 3, 4); | bench_case(1, 32, 32, 80, 80, 3, 4); | ||||
bench_case(1, 32, 32, 80, 80, 3, 32); | bench_case(1, 32, 32, 80, 80, 3, 32); | ||||
std::string algo_name = "I8816DIRECT_LARGE_GROUP"; | |||||
std::string algo_name = "I8816DIRECT"; | |||||
printf("Benchmark I8816DIRECT_LARGE_GROUP algo\n"); | printf("Benchmark I8816DIRECT_LARGE_GROUP algo\n"); | ||||
std::vector<DType> data_type = {dtype::Int8(), dtype::Int8(), | std::vector<DType> data_type = {dtype::Int8(), dtype::Int8(), | ||||
dtype::Int16(), dtype::Int16()}; | dtype::Int16(), dtype::Int16()}; | ||||
@@ -451,7 +451,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | |||||
{1, {4}}, data_type); | {1, {4}}, data_type); | ||||
shapes_and_computation.clear(); | shapes_and_computation.clear(); | ||||
algo_name = "I8816DIRECT_SMALL_GROUP"; | |||||
algo_name = "I8816DIRECT"; | |||||
printf("Benchmark I8816DIRECT_SMALL_GROUP algo\n"); | printf("Benchmark I8816DIRECT_SMALL_GROUP algo\n"); | ||||
bench_case(1, 32, 32, 200, 200, 3, 1); | bench_case(1, 32, 32, 200, 200, 3, 1); | ||||
bench_case(1, 32, 32, 128, 128, 3, 1); | bench_case(1, 32, 32, 128, 128, 3, 1); | ||||
@@ -503,7 +503,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | |||||
bench_case(1, 32, 32, 80, 80, 3, 4, 1, 2); | bench_case(1, 32, 32, 80, 80, 3, 4, 1, 2); | ||||
bench_case(1, 32, 32, 80, 80, 3, 32, 1, 2); | bench_case(1, 32, 32, 80, 80, 3, 32, 1, 2); | ||||
std::string algo_name = "I8816STRD2_LARGE_GROUP"; | |||||
std::string algo_name = "I8816STRD2"; | |||||
printf("Benchmark I8816STRD2_LARGE_GROUP algo\n"); | printf("Benchmark I8816STRD2_LARGE_GROUP algo\n"); | ||||
std::vector<DType> data_type = {dtype::Int8(), dtype::Int8(), | std::vector<DType> data_type = {dtype::Int8(), dtype::Int8(), | ||||
dtype::Int16(), dtype::Int16()}; | dtype::Int16(), dtype::Int16()}; | ||||
@@ -515,7 +515,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | |||||
{1, {4}}, data_type); | {1, {4}}, data_type); | ||||
shapes_and_computation.clear(); | shapes_and_computation.clear(); | ||||
algo_name = "I8816STRD2_SMALL_GROUP"; | |||||
algo_name = "I8816STRD2"; | |||||
printf("Benchmark I8816STRD2_SMALL_GROUP algo\n"); | printf("Benchmark I8816STRD2_SMALL_GROUP algo\n"); | ||||
bench_case(1, 32, 32, 200, 200, 3, 1, 1, 2); | bench_case(1, 32, 32, 200, 200, 3, 1, 1, 2); | ||||
bench_case(1, 32, 32, 128, 128, 3, 1, 1, 2); | bench_case(1, 32, 32, 128, 128, 3, 1, 1, 2); | ||||
@@ -567,7 +567,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | |||||
bench_case(1, 32, 32, 80, 80, 3, 4, 1, 1); | bench_case(1, 32, 32, 80, 80, 3, 4, 1, 1); | ||||
bench_case(1, 32, 32, 80, 80, 3, 32, 1, 1); | bench_case(1, 32, 32, 80, 80, 3, 32, 1, 1); | ||||
std::string algo_name = "S8STRD1_LARGE_GROUP"; | |||||
std::string algo_name = "S8STRD1"; | |||||
printf("Benchmark S8STRD1_LARGE_GROUP algo\n"); | printf("Benchmark S8STRD1_LARGE_GROUP algo\n"); | ||||
std::vector<DType> data_type = { | std::vector<DType> data_type = { | ||||
dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f), | dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f), | ||||
@@ -580,7 +580,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | |||||
{1, {4}}, data_type); | {1, {4}}, data_type); | ||||
shapes_and_computation.clear(); | shapes_and_computation.clear(); | ||||
algo_name = "S8STRD1_SMALL_GROUP"; | |||||
algo_name = "S8STRD1"; | |||||
printf("Benchmark S8STRD1_SMALL_GROUP algo\n"); | printf("Benchmark S8STRD1_SMALL_GROUP algo\n"); | ||||
bench_case(1, 32, 32, 200, 200, 3, 1, 1, 1); | bench_case(1, 32, 32, 200, 200, 3, 1, 1, 1); | ||||
bench_case(1, 32, 32, 128, 128, 3, 1, 1, 1); | bench_case(1, 32, 32, 128, 128, 3, 1, 1, 1); | ||||
@@ -866,7 +866,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | |||||
bench_case(1, 32, 32, 80, 80, 3, 4, 1, 2); | bench_case(1, 32, 32, 80, 80, 3, 4, 1, 2); | ||||
bench_case(1, 32, 32, 80, 80, 3, 32, 1, 2); | bench_case(1, 32, 32, 80, 80, 3, 32, 1, 2); | ||||
std::string algo_name = "S8STRD2_LARGE_GROUP"; | |||||
std::string algo_name = "S8STRD2"; | |||||
printf("Benchmark S8STRD2_LARGE_GROUP algo\n"); | printf("Benchmark S8STRD2_LARGE_GROUP algo\n"); | ||||
std::vector<DType> data_type = { | std::vector<DType> data_type = { | ||||
dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f), | dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f), | ||||
@@ -879,7 +879,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | |||||
{1, {4}}, data_type); | {1, {4}}, data_type); | ||||
shapes_and_computation.clear(); | shapes_and_computation.clear(); | ||||
algo_name = "S8STRD2_SMALL_GROUP"; | |||||
algo_name = "S8STRD2"; | |||||
printf("Benchmark S8STRD2_SMALL_GROUP algo\n"); | printf("Benchmark S8STRD2_SMALL_GROUP algo\n"); | ||||
bench_case(1, 32, 32, 200, 200, 3, 1, 1, 2); | bench_case(1, 32, 32, 200, 200, 3, 1, 1, 2); | ||||
bench_case(1, 32, 32, 128, 128, 3, 1, 1, 2); | bench_case(1, 32, 32, 128, 128, 3, 1, 1, 2); | ||||
@@ -932,7 +932,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | |||||
bench_case(1, 32, 32, 80, 80, 3, 4, 1, 1); | bench_case(1, 32, 32, 80, 80, 3, 4, 1, 1); | ||||
bench_case(1, 32, 32, 80, 80, 3, 32, 1, 1); | bench_case(1, 32, 32, 80, 80, 3, 32, 1, 1); | ||||
std::string algo_name = "ARMDOTS8STRD1_LARGE_GROUP"; | |||||
std::string algo_name = "ARMDOTS8STRD1"; | |||||
printf("Benchmark ARMDOTS8STRD1_LARGE_GROUP algo\n"); | printf("Benchmark ARMDOTS8STRD1_LARGE_GROUP algo\n"); | ||||
std::vector<DType> data_type = { | std::vector<DType> data_type = { | ||||
dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f), | dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f), | ||||
@@ -945,7 +945,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | |||||
{1, {4}}, data_type); | {1, {4}}, data_type); | ||||
shapes_and_computation.clear(); | shapes_and_computation.clear(); | ||||
algo_name = "ARMDOTS8STRD1_SMALL_GROUP"; | |||||
algo_name = "ARMDOTS8STRD1"; | |||||
printf("Benchmark ARMDOTS8STRD1_SMALL_GROUP algo\n"); | printf("Benchmark ARMDOTS8STRD1_SMALL_GROUP algo\n"); | ||||
bench_case(1, 32, 32, 200, 200, 3, 1, 1, 1); | bench_case(1, 32, 32, 200, 200, 3, 1, 1, 1); | ||||
bench_case(1, 32, 32, 128, 128, 3, 1, 1, 1); | bench_case(1, 32, 32, 128, 128, 3, 1, 1, 1); | ||||
@@ -997,7 +997,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | |||||
bench_case(1, 32, 32, 80, 80, 3, 4, 1, 2); | bench_case(1, 32, 32, 80, 80, 3, 4, 1, 2); | ||||
bench_case(1, 32, 32, 80, 80, 3, 32, 1, 2); | bench_case(1, 32, 32, 80, 80, 3, 32, 1, 2); | ||||
std::string algo_name = "ARMDOTS8STRD2_LARGE_GROUP"; | |||||
std::string algo_name = "ARMDOTS8STRD2"; | |||||
printf("Benchmark ARMDOTS8STRD2_LARGE_GROUP algo\n"); | printf("Benchmark ARMDOTS8STRD2_LARGE_GROUP algo\n"); | ||||
std::vector<DType> data_type = { | std::vector<DType> data_type = { | ||||
dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f), | dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f), | ||||
@@ -1010,7 +1010,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | |||||
{1, {4}}, data_type); | {1, {4}}, data_type); | ||||
shapes_and_computation.clear(); | shapes_and_computation.clear(); | ||||
algo_name = "ARMDOTS8STRD2_SMALL_GROUP"; | |||||
algo_name = "ARMDOTS8STRD2"; | |||||
printf("Benchmark ARMDOTS8STRD2_SMALL_GROUP algo\n"); | printf("Benchmark ARMDOTS8STRD2_SMALL_GROUP algo\n"); | ||||
bench_case(1, 32, 32, 200, 200, 3, 1, 1, 2); | bench_case(1, 32, 32, 200, 200, 3, 1, 1, 2); | ||||
bench_case(1, 32, 32, 128, 128, 3, 1, 1, 2); | bench_case(1, 32, 32, 128, 128, 3, 1, 1, 2); | ||||
@@ -1064,7 +1064,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | |||||
bench_case(1, 32, 32, 80, 80, 3, 4, 1, 1); | bench_case(1, 32, 32, 80, 80, 3, 4, 1, 1); | ||||
bench_case(1, 32, 32, 80, 80, 3, 32, 1, 1); | bench_case(1, 32, 32, 80, 80, 3, 32, 1, 1); | ||||
std::string algo_name = "QU8STRD1_LARGE_GROUP"; | |||||
std::string algo_name = "QU8STRD1"; | |||||
printf("Benchmark QU8STRD1_LARGE_GROUP algo\n"); | printf("Benchmark QU8STRD1_LARGE_GROUP algo\n"); | ||||
std::vector<DType> data_type = {dtype::Quantized8Asymm(0.2f, 100), | std::vector<DType> data_type = {dtype::Quantized8Asymm(0.2f, 100), | ||||
dtype::Quantized8Asymm(0.2f, 120), | dtype::Quantized8Asymm(0.2f, 120), | ||||
@@ -1078,7 +1078,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | |||||
{1, {4}}, data_type); | {1, {4}}, data_type); | ||||
shapes_and_computation.clear(); | shapes_and_computation.clear(); | ||||
algo_name = "QU8STRD1_SMALL_GROUP"; | |||||
algo_name = "QU8STRD1"; | |||||
printf("Benchmark QU8STRD1_SMALL_GROUP algo\n"); | printf("Benchmark QU8STRD1_SMALL_GROUP algo\n"); | ||||
bench_case(1, 32, 32, 200, 200, 3, 1, 1, 1); | bench_case(1, 32, 32, 200, 200, 3, 1, 1, 1); | ||||
bench_case(1, 32, 32, 128, 128, 3, 1, 1, 1); | bench_case(1, 32, 32, 128, 128, 3, 1, 1, 1); | ||||
@@ -1130,7 +1130,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | |||||
bench_case(1, 32, 32, 80, 80, 3, 4, 1, 2); | bench_case(1, 32, 32, 80, 80, 3, 4, 1, 2); | ||||
bench_case(1, 32, 32, 80, 80, 3, 32, 1, 2); | bench_case(1, 32, 32, 80, 80, 3, 32, 1, 2); | ||||
std::string algo_name = "QU8STRD2_LARGE_GROUP"; | |||||
std::string algo_name = "QU8STRD2"; | |||||
printf("Benchmark QU8STRD2_LARGE_GROUP algo\n"); | printf("Benchmark QU8STRD2_LARGE_GROUP algo\n"); | ||||
std::vector<DType> data_type = {dtype::Quantized8Asymm(0.2f, 100), | std::vector<DType> data_type = {dtype::Quantized8Asymm(0.2f, 100), | ||||
dtype::Quantized8Asymm(0.2f, 120), | dtype::Quantized8Asymm(0.2f, 120), | ||||
@@ -1144,7 +1144,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | |||||
{1, {4}}, data_type); | {1, {4}}, data_type); | ||||
shapes_and_computation.clear(); | shapes_and_computation.clear(); | ||||
algo_name = "QU8STRD2_SMALL_GROUP"; | |||||
algo_name = "QU8STRD2"; | |||||
printf("Benchmark QU8STRD2_SMALL_GROUP algo\n"); | printf("Benchmark QU8STRD2_SMALL_GROUP algo\n"); | ||||
bench_case(1, 32, 32, 200, 200, 3, 1, 1, 2); | bench_case(1, 32, 32, 200, 200, 3, 1, 1, 2); | ||||
bench_case(1, 32, 32, 128, 128, 3, 1, 1, 2); | bench_case(1, 32, 32, 128, 128, 3, 1, 1, 2); | ||||
@@ -1198,7 +1198,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | |||||
bench_case(1, 32, 32, 80, 80, 3, 4, 1, 1); | bench_case(1, 32, 32, 80, 80, 3, 4, 1, 1); | ||||
bench_case(1, 32, 32, 80, 80, 3, 32, 1, 1); | bench_case(1, 32, 32, 80, 80, 3, 32, 1, 1); | ||||
std::string algo_name = "ARMDOTU8STRD1_LARGE_GROUP"; | |||||
std::string algo_name = "ARMDOTU8STRD1"; | |||||
printf("Benchmark ARMDOTU8STRD1_LARGE_GROUP algo\n"); | printf("Benchmark ARMDOTU8STRD1_LARGE_GROUP algo\n"); | ||||
std::vector<DType> data_type = {dtype::Quantized8Asymm(0.2f, 100), | std::vector<DType> data_type = {dtype::Quantized8Asymm(0.2f, 100), | ||||
dtype::Quantized8Asymm(0.2f, 120), | dtype::Quantized8Asymm(0.2f, 120), | ||||
@@ -1212,7 +1212,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | |||||
{1, {4}}, data_type); | {1, {4}}, data_type); | ||||
shapes_and_computation.clear(); | shapes_and_computation.clear(); | ||||
algo_name = "ARMDOTU8STRD1_SMALL_GROUP"; | |||||
algo_name = "ARMDOTU8STRD1"; | |||||
printf("Benchmark ARMDOTS8STRD1_SMALL_GROUP algo\n"); | printf("Benchmark ARMDOTS8STRD1_SMALL_GROUP algo\n"); | ||||
bench_case(1, 32, 32, 200, 200, 3, 1, 1, 1); | bench_case(1, 32, 32, 200, 200, 3, 1, 1, 1); | ||||
bench_case(1, 32, 32, 128, 128, 3, 1, 1, 1); | bench_case(1, 32, 32, 128, 128, 3, 1, 1, 1); | ||||
@@ -1265,7 +1265,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | |||||
bench_case(1, 32, 32, 80, 80, 5, 4, 1, 2); | bench_case(1, 32, 32, 80, 80, 5, 4, 1, 2); | ||||
bench_case(1, 32, 32, 80, 80, 5, 32, 1, 2); | bench_case(1, 32, 32, 80, 80, 5, 32, 1, 2); | ||||
std::string algo_name = "ARMDOTU8STRD2_LARGE_GROUP"; | |||||
std::string algo_name = "ARMDOTU8STRD2"; | |||||
printf("Benchmark ARMDOTU8STRD2_LARGE_GROUP algo\n"); | printf("Benchmark ARMDOTU8STRD2_LARGE_GROUP algo\n"); | ||||
std::vector<DType> data_type = {dtype::Quantized8Asymm(0.2f, 100), | std::vector<DType> data_type = {dtype::Quantized8Asymm(0.2f, 100), | ||||
dtype::Quantized8Asymm(0.2f, 120), | dtype::Quantized8Asymm(0.2f, 120), | ||||
@@ -1279,7 +1279,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | |||||
{1, {4}}, data_type); | {1, {4}}, data_type); | ||||
shapes_and_computation.clear(); | shapes_and_computation.clear(); | ||||
algo_name = "ARMDOTU8STRD2_SMALL_GROUP"; | |||||
algo_name = "ARMDOTU8STRD2"; | |||||
printf("Benchmark ARMDOTU8STRD2_SMALL_GROUP algo\n"); | printf("Benchmark ARMDOTU8STRD2_SMALL_GROUP algo\n"); | ||||
bench_case(1, 32, 32, 200, 200, 5, 1, 1, 2); | bench_case(1, 32, 32, 200, 200, 5, 1, 1, 2); | ||||
bench_case(1, 32, 32, 128, 128, 5, 1, 1, 2); | bench_case(1, 32, 32, 128, 128, 5, 1, 1, 2); | ||||
@@ -176,7 +176,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CONVOLUTION_STRIDE1_I8x8x32_WITHDOTPROD) { | |||||
constexpr size_t RUN = 50; | constexpr size_t RUN = 50; | ||||
Benchmarker<Convolution> benchmark(handle()); | Benchmarker<Convolution> benchmark(handle()); | ||||
benchmark.set_before_exec_callback( | benchmark.set_before_exec_callback( | ||||
AlgoChecker<Convolution>("CONVOLUTION_DEFAULT_ARMDOTS8STRD1_SMALL_GROUP")); | |||||
AlgoChecker<Convolution>("CONVOLUTION_DEFAULT_ARMDOTS8STRD1")); | |||||
benchmark.set_dtype(0, dtype::Int8()) | benchmark.set_dtype(0, dtype::Int8()) | ||||
.set_dtype(1, dtype::Int8()) | .set_dtype(1, dtype::Int8()) | ||||
.set_dtype(2, dtype::Int32()); | .set_dtype(2, dtype::Int32()); | ||||
@@ -243,7 +243,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CONVOLUTION_STRIDE2_I8x8x32_WITHDOTPROD) { | |||||
constexpr size_t RUN = 10; | constexpr size_t RUN = 10; | ||||
Benchmarker<Convolution> benchmark(handle()); | Benchmarker<Convolution> benchmark(handle()); | ||||
benchmark.set_before_exec_callback( | benchmark.set_before_exec_callback( | ||||
AlgoChecker<Convolution>("CONVOLUTION_DEFAULT_ARMDOTS8STRD2_SMALL_GROUP")); | |||||
AlgoChecker<Convolution>("CONVOLUTION_DEFAULT_ARMDOTS8STRD2")); | |||||
benchmark.set_dtype(0, dtype::Int8()) | benchmark.set_dtype(0, dtype::Int8()) | ||||
.set_dtype(1, dtype::Int8()) | .set_dtype(1, dtype::Int8()) | ||||
.set_dtype(2, dtype::Int32()); | .set_dtype(2, dtype::Int32()); | ||||
@@ -317,7 +317,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CONVOLUTION_STRIDE1_QUINT8_WITHDOTPROD) { | |||||
benchmark.set_display(false); | benchmark.set_display(false); | ||||
benchmark.set_times(RUN); | benchmark.set_times(RUN); | ||||
benchmark.set_before_exec_callback(AlgoChecker<ConvolutionForward>( | benchmark.set_before_exec_callback(AlgoChecker<ConvolutionForward>( | ||||
"CONVOLUTION_DEFAULT_ARMDOTU8STRD1_SMALL_GROUP")); | |||||
"CONVOLUTION_DEFAULT_ARMDOTU8STRD1")); | |||||
Benchmarker<Convolution> benchmark_float(handle()); | Benchmarker<Convolution> benchmark_float(handle()); | ||||
benchmark_float.set_display(false); | benchmark_float.set_display(false); | ||||
@@ -387,7 +387,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CONVOLUTION_STRIDE2_QUINT8_WITHDOTPROD) { | |||||
benchmark.set_display(false); | benchmark.set_display(false); | ||||
benchmark.set_times(RUN); | benchmark.set_times(RUN); | ||||
benchmark.set_before_exec_callback(AlgoChecker<ConvolutionForward>( | benchmark.set_before_exec_callback(AlgoChecker<ConvolutionForward>( | ||||
"CONVOLUTION_DEFAULT_ARMDOTU8STRD2_SMALL_GROUP")); | |||||
"CONVOLUTION_DEFAULT_ARMDOTU8STRD2")); | |||||
Benchmarker<Convolution> benchmark_float(handle()); | Benchmarker<Convolution> benchmark_float(handle()); | ||||
benchmark_float.set_display(false); | benchmark_float.set_display(false); | ||||
@@ -583,7 +583,7 @@ TEST_F(X86_MULTI_THREADS, AVX2_CONV_BIAS_DIRECT_STRIDE2_S8S8S8) { | |||||
} | } | ||||
} | } | ||||
TEST_F(X86_MULTI_THREADS, CONV_BIAS_DIRECT_STRIDE1_SMALL_GROUP) { | |||||
TEST_F(X86_MULTI_THREADS, CONV_BIAS_DIRECT_STRIDE1_DENSE) { | |||||
using namespace conv_bias; | using namespace conv_bias; | ||||
std::vector<TestArg> args; | std::vector<TestArg> args; | ||||
@@ -633,19 +633,19 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_DIRECT_STRIDE1_SMALL_GROUP) { | |||||
.set_rng(2, &rng); | .set_rng(2, &rng); | ||||
checker.set_before_exec_callback( | checker.set_before_exec_callback( | ||||
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>( | conv_bias::ConvBiasAlgoChecker<ConvBiasForward>( | ||||
"X86_CONV_BIAS_DIRECT_STRIDE1_SMALL_GROUP")); | |||||
"X86_CONV_BIAS_DIRECT_STRIDE1_LARGE_GROUP")); | |||||
for (auto&& arg : args) { | for (auto&& arg : args) { | ||||
checker.set_param(arg.param).exec( | checker.set_param(arg.param).exec( | ||||
{arg.src, arg.filter, arg.bias, {}, {}}); | {arg.src, arg.filter, arg.bias, {}, {}}); | ||||
} | } | ||||
} | } | ||||
TEST_F(X86_MULTI_THREADS, CONV_BIAS_DIRECT_STRIDE1_LARGE_GROUP) { | |||||
TEST_F(X86_MULTI_THREADS, CONV_BIAS_DIRECT_STRIDE1_GROUP) { | |||||
using namespace conv_bias; | using namespace conv_bias; | ||||
std::vector<TestArg> args; | std::vector<TestArg> args; | ||||
auto run = [&](size_t oc, size_t ic, size_t w, size_t h, size_t kernel, | |||||
size_t p, NonlineMode nonline_mode) { | |||||
auto run = [&](size_t group, size_t channel, size_t w, size_t h, | |||||
size_t kernel, size_t p, NonlineMode nonline_mode) { | |||||
if (w + 2 * p < kernel || h + 2 * p < kernel) | if (w + 2 * p < kernel || h + 2 * p < kernel) | ||||
return; | return; | ||||
param::ConvBias param; | param::ConvBias param; | ||||
@@ -654,30 +654,37 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_DIRECT_STRIDE1_LARGE_GROUP) { | |||||
param.pad_h = p; | param.pad_h = p; | ||||
param.pad_w = p; | param.pad_w = p; | ||||
param.nonlineMode = nonline_mode; | param.nonlineMode = nonline_mode; | ||||
param.sparse = param::ConvBias::Sparse::GROUP; | |||||
//! no bias | //! no bias | ||||
args.emplace_back(param, TensorShape{1, ic, h, w}, | |||||
TensorShape{oc, ic, kernel, kernel}, TensorShape{}); | |||||
args.emplace_back( | |||||
param, TensorShape{1, channel, h, w}, | |||||
TensorShape{group, channel / group, channel / group, kernel, kernel}, | |||||
TensorShape{}); | |||||
//! bias channel | //! bias channel | ||||
args.emplace_back(param, TensorShape{2, ic, h, w}, | |||||
TensorShape{oc, ic, kernel, kernel}, | |||||
TensorShape{1, oc, 1, 1}); | |||||
args.emplace_back(param, TensorShape{2, channel, h, w}, | |||||
TensorShape{group, channel / group, channel / group, | |||||
kernel, kernel}, | |||||
TensorShape{1, channel, 1, 1}); | |||||
//! bias | //! bias | ||||
args.emplace_back(param, TensorShape{2, ic, h, w}, | |||||
TensorShape{oc, ic, kernel, kernel}, | |||||
TensorShape{2, oc, (h + param.pad_h * 2 - kernel) + 1, | |||||
(w + param.pad_w * 2 - kernel) + 1}); | |||||
args.emplace_back( | |||||
param, TensorShape{2, channel, h, w}, | |||||
TensorShape{group, channel / group, channel / group, kernel, | |||||
kernel}, | |||||
TensorShape{2, channel, (h + param.pad_h * 2 - kernel) + 1, | |||||
(w + param.pad_w * 2 - kernel) + 1}); | |||||
}; | }; | ||||
for (size_t kernel : {1, 2, 3, 4, 5, 6, 7}) | for (size_t kernel : {1, 2, 3, 4, 5, 6, 7}) | ||||
for (size_t ic : {1, 4, 8, 16}) | |||||
for (size_t oc : {1, 4, 8}) | |||||
for (size_t channel : {4, 8, 16}) | |||||
for (size_t group : {1, 2, 4}) | |||||
for (size_t p : {0, 2}) | for (size_t p : {0, 2}) | ||||
for (size_t size : {20, 21, 24}) | for (size_t size : {20, 21, 24}) | ||||
for (NonlineMode nonline_mode : | for (NonlineMode nonline_mode : | ||||
{NonlineMode::RELU, NonlineMode::SIGMOID, | {NonlineMode::RELU, NonlineMode::SIGMOID, | ||||
NonlineMode::H_SWISH, NonlineMode::IDENTITY}) { | NonlineMode::H_SWISH, NonlineMode::IDENTITY}) { | ||||
run(oc, ic, size, size, kernel, p, nonline_mode); | |||||
run(group, channel, size, size, kernel, p, | |||||
nonline_mode); | |||||
} | } | ||||
Checker<ConvBias> checker(handle()); | Checker<ConvBias> checker(handle()); | ||||
@@ -697,7 +704,7 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_DIRECT_STRIDE1_LARGE_GROUP) { | |||||
} | } | ||||
} | } | ||||
TEST_F(X86_MULTI_THREADS, CONV_BIAS_DIRECT_STRIDE2) { | |||||
TEST_F(X86_MULTI_THREADS, CONV_BIAS_DIRECT_STRIDE2_DENSE) { | |||||
using namespace conv_bias; | using namespace conv_bias; | ||||
std::vector<TestArg> args; | std::vector<TestArg> args; | ||||
@@ -738,11 +745,68 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_DIRECT_STRIDE2) { | |||||
.set_rng(2, &rng); | .set_rng(2, &rng); | ||||
checker.set_before_exec_callback( | checker.set_before_exec_callback( | ||||
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>( | conv_bias::ConvBiasAlgoChecker<ConvBiasForward>( | ||||
"X86_CONV_BIAS_DIRECT_STRIDE2_SMALL_GROUP")); | |||||
"X86_CONV_BIAS_DIRECT_STRIDE2_LARGE_GROUP")); | |||||
for (auto&& arg : args) { | for (auto&& arg : args) { | ||||
checker.set_param(arg.param).exec( | checker.set_param(arg.param).exec( | ||||
{arg.src, arg.filter, arg.bias, {}, {}}); | {arg.src, arg.filter, arg.bias, {}, {}}); | ||||
} | } | ||||
} | |||||
TEST_F(X86_MULTI_THREADS, CONV_BIAS_DIRECT_STRIDE2_GROUP) { | |||||
using namespace conv_bias; | |||||
std::vector<TestArg> args; | |||||
auto run = [&](size_t group, size_t channel, size_t w, size_t h, | |||||
size_t kernel, size_t p, NonlineMode nonline_mode) { | |||||
if (w + 2 * p < kernel || h + 2 * p < kernel) | |||||
return; | |||||
param::ConvBias param; | |||||
param.stride_h = 2; | |||||
param.stride_w = 2; | |||||
param.pad_h = p; | |||||
param.pad_w = p; | |||||
param.nonlineMode = nonline_mode; | |||||
param.sparse = param::ConvBias::Sparse::GROUP; | |||||
//! no bias | |||||
args.emplace_back( | |||||
param, TensorShape{1, channel, h, w}, | |||||
TensorShape{group, channel / group, channel / group, kernel, kernel}, | |||||
TensorShape{}); | |||||
//! bias channel | |||||
args.emplace_back(param, TensorShape{2, channel, h, w}, | |||||
TensorShape{group, channel / group, channel / group, | |||||
kernel, kernel}, | |||||
TensorShape{1, channel, 1, 1}); | |||||
//! bias | |||||
args.emplace_back( | |||||
param, TensorShape{2, channel, h, w}, | |||||
TensorShape{group, channel / group, channel / group, kernel, | |||||
kernel}, | |||||
TensorShape{2, channel, (h + param.pad_h * 2 - kernel) / 2 + 1, | |||||
(w + param.pad_w * 2 - kernel) / 2 + 1}); | |||||
}; | |||||
for (size_t kernel : {2, 3, 5, 7}) | |||||
for (size_t channel : {4, 8, 16}) | |||||
for (size_t group : {1, 2, 4}) | |||||
for (size_t p : {0, 2}) | |||||
for (size_t size : {20, 21, 24}) | |||||
for (NonlineMode nonline_mode : | |||||
{NonlineMode::RELU, NonlineMode::SIGMOID, | |||||
NonlineMode::H_SWISH, NonlineMode::IDENTITY}) { | |||||
run(group, channel, size, size, kernel, p, | |||||
nonline_mode); | |||||
} | |||||
Checker<ConvBias> checker(handle()); | |||||
UniformIntRNG rng{-50, 50}; | |||||
checker.set_dtype(0, dtype::Float32()) | |||||
.set_dtype(1, dtype::Float32()) | |||||
.set_dtype(2, dtype::Float32()) | |||||
.set_rng(0, &rng) | |||||
.set_rng(1, &rng) | |||||
.set_rng(2, &rng); | |||||
checker.set_before_exec_callback( | checker.set_before_exec_callback( | ||||
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>( | conv_bias::ConvBiasAlgoChecker<ConvBiasForward>( | ||||
"X86_CONV_BIAS_DIRECT_STRIDE2_LARGE_GROUP")); | "X86_CONV_BIAS_DIRECT_STRIDE2_LARGE_GROUP")); | ||||
@@ -2502,7 +2566,7 @@ TEST_F(X86_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF32) { | |||||
bench_case(1, 32, 32, 80, 80, 3, 32); | bench_case(1, 32, 32, 80, 80, 3, 32); | ||||
std::string algo_name = "X86_CONV_BIAS_DIRECT_STRIDE1_LARGE_GROUP"; | std::string algo_name = "X86_CONV_BIAS_DIRECT_STRIDE1_LARGE_GROUP"; | ||||
printf("Benchmark X86_CONV_BIAS_DIRECT_STRIDE1_LARGE_GROUP algo\n"); | |||||
printf("Benchmark X86_CONV_BIAS_DIRECT_STRIDE1_GROUP algo\n"); | |||||
benchmark_impl(param, shapes_and_computation, algo_name, RUNS, | benchmark_impl(param, shapes_and_computation, algo_name, RUNS, | ||||
{4, {4, 5, 6, 7}}, {1, {4}}, data_type); | {4, {4, 5, 6, 7}}, {1, {4}}, data_type); | ||||
benchmark_impl(param, shapes_and_computation, algo_name, RUNS, | benchmark_impl(param, shapes_and_computation, algo_name, RUNS, | ||||
@@ -2511,8 +2575,8 @@ TEST_F(X86_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF32) { | |||||
{1, {4}}, data_type); | {1, {4}}, data_type); | ||||
shapes_and_computation.clear(); | shapes_and_computation.clear(); | ||||
algo_name = "X86_CONV_BIAS_DIRECT_STRIDE1_SMALL_GROUP"; | |||||
printf("Benchmark X86_CONV_BIAS_DIRECT_STRIDE1_SMALL_GROUP algo\n"); | |||||
algo_name = "X86_CONV_BIAS_DIRECT_STRIDE1_LARGE_GROUP"; | |||||
printf("Benchmark X86_CONV_BIAS_DIRECT_STRIDE1_DENSE algo\n"); | |||||
bench_case(1, 32, 32, 200, 200, 3, 1); | bench_case(1, 32, 32, 200, 200, 3, 1); | ||||
bench_case(1, 32, 32, 128, 128, 3, 1); | bench_case(1, 32, 32, 128, 128, 3, 1); | ||||
bench_case(1, 32, 32, 100, 100, 3, 1); | bench_case(1, 32, 32, 100, 100, 3, 1); | ||||
@@ -125,7 +125,7 @@ TEST_F(X86, DEFAULT_CONV_DIRECT_STRIDE1) { | |||||
Checker<ConvolutionForward> checker(handle()); | Checker<ConvolutionForward> checker(handle()); | ||||
checker.set_before_exec_callback(AlgoChecker<ConvolutionForward>( | checker.set_before_exec_callback(AlgoChecker<ConvolutionForward>( | ||||
"CONVOLUTION_DEFAULT_X86_CONV_BIAS_DIRECT_STRIDE1_SMALL_GROUP")); | |||||
"CONVOLUTION_DEFAULT_X86_CONV_BIAS_DIRECT_STRIDE1_LARGE_GROUP")); | |||||
checker.set_epsilon(1); | checker.set_epsilon(1); | ||||
UniformIntRNG rng{-50, 50}; | UniformIntRNG rng{-50, 50}; | ||||
checker.set_dtype(0, dtype::Float32()) | checker.set_dtype(0, dtype::Float32()) | ||||
@@ -167,7 +167,7 @@ TEST_F(X86, DEFAULT_CONV_DIRECT_STRIDE2) { | |||||
Checker<ConvolutionForward> checker(handle()); | Checker<ConvolutionForward> checker(handle()); | ||||
checker.set_before_exec_callback(AlgoChecker<ConvolutionForward>( | checker.set_before_exec_callback(AlgoChecker<ConvolutionForward>( | ||||
"CONVOLUTION_DEFAULT_X86_CONV_BIAS_DIRECT_STRIDE2_SMALL_GROUP")); | |||||
"CONVOLUTION_DEFAULT_X86_CONV_BIAS_DIRECT_STRIDE2_LARGE_GROUP")); | |||||
checker.set_epsilon(1); | checker.set_epsilon(1); | ||||
UniformIntRNG rng{-50, 50}; | UniformIntRNG rng{-50, 50}; | ||||
checker.set_dtype(0, dtype::Float32()) | checker.set_dtype(0, dtype::Float32()) | ||||