diff --git a/dnn/src/x86/conv_bias/opr_impl.cpp b/dnn/src/x86/conv_bias/opr_impl.cpp
index a38d9ad9..4c7d10cd 100644
--- a/dnn/src/x86/conv_bias/opr_impl.cpp
+++ b/dnn/src/x86/conv_bias/opr_impl.cpp
@@ -95,6 +95,14 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
 
 public:
     AlgoPack() {
+        //! FIXME: preference to use mkldnn algo on VNNI devices
+        //! But now mkldnn algo preference issue with NCHW->NHWC->NCHW
+#if MEGDNN_X86_WITH_MKL_DNN
+        //! Create the mkldnn algo
+        all_algos.emplace_back(&mkldnn_conv_fp32);
+        all_algos.emplace_back(&mkldnn_matmul_qint8);
+        all_algos.emplace_back(&mkldnn_qint8);
+#endif
         all_algos.emplace_back(&stride1_direct_large_group);
         all_algos.emplace_back(&stride1_direct_small_group);
         all_algos.emplace_back(&stride2_direct_large_group);
@@ -105,14 +113,6 @@ public:
         all_algos.emplace_back(&avx2_stride2_chanwsie_qint8);
         all_algos.emplace_back(&matmul);
 
-        //! preference to use mkldnn algo on VNNI devices
-#if MEGDNN_X86_WITH_MKL_DNN
-        //! Create the mkldnn algo
-        all_algos.emplace_back(&mkldnn_conv_fp32);
-        all_algos.emplace_back(&mkldnn_matmul_qint8);
-        all_algos.emplace_back(&mkldnn_qint8);
-#endif
-
         static CpuOprDelegationStorage<> storage;
         auto matmul_opr = storage.get<MatrixMul>();
         auto&& matmul_algos =
@@ -172,15 +172,18 @@ bool ConvBiasImpl::is_matmul_quantized_prefer(
                 chanwise_avx2_stride2_qint8_usable_preferred(param) ||
                 direct_avx2_stride1_int8_usable_preferred(param) ||
                 direct_avx2_stride2_int8_usable_preferred(param);
-    }
 #if MEGDNN_X86_WITH_MKL_DNN
-    conv_direct_chanwise_mkldnn_usable =
-            conv_direct_chanwise_mkldnn_usable ||
-            mkldnn_qint8_usable_preferred(param) ||
-            mkldnn_matmul_qint8_usable_preferred(param);
+        conv_direct_chanwise_mkldnn_usable =
+                conv_direct_chanwise_mkldnn_usable ||
+                mkldnn_qint8_usable_preferred(param) ||
+                mkldnn_matmul_qint8_usable_preferred(param);
 #endif
+    }
 
-    return !conv_direct_chanwise_mkldnn_usable;
+    return !conv_direct_chanwise_mkldnn_usable ||
+           (is_supported(SIMDType::VNNI) &&
+            !chanwise_avx2_stride1_qint8_usable_preferred(param) &&
+            !chanwise_avx2_stride2_qint8_usable_preferred(param));
 }
 
 // vim: syntax=cpp.doxygen
diff --git a/dnn/src/x86/matrix_mul/algos.cpp b/dnn/src/x86/matrix_mul/algos.cpp
index 29815679..cf7e3669 100644
--- a/dnn/src/x86/matrix_mul/algos.cpp
+++ b/dnn/src/x86/matrix_mul/algos.cpp
@@ -6,7 +6,8 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 
 #include "src/x86/matrix_mul/algos.h"
@@ -17,7 +18,6 @@
 
 #include "src/x86/matrix_mul/f32/strategy.h"
 
-
 MIDOUT_DECL(megdnn_x86_matmul_kern)
 MIDOUT_DECL(megdnn_x86_matmul_kern_mk8_8x8)
 using namespace megdnn;
@@ -45,17 +45,16 @@ void f32_blas_kern(const MatrixMulImpl::KernParam& kern_param) {
 
 #if MEGDNN_X86_WITH_MKL && SUPPORT_MKL_PACKED_GEMM
 void f32_blas_kern_only_packA(const MatrixMulImpl::KernParam& kern_param,
-                                      const void* a_panel, const void* b_panel) {
-        MEGDNN_MARK_USED_VAR(b_panel);
+                              const void* a_panel, const void* b_panel) {
+    MEGDNN_MARK_USED_VAR(b_panel);
     auto m = kern_param.M, n = kern_param.N, k = kern_param.K;
     const auto Bptr = kern_param.B<dt_float32>();
     auto Cptr = kern_param.C<dt_float32>();
     auto Atrd = kern_param.LDA, Btrd = kern_param.LDB, Ctrd = kern_param.LDC;
     disable_denorm();
     cblas_sgemm_compute(CblasRowMajor, CblasPacked, CblasNoTrans, m, n, k,
-                        static_cast<const float*>(a_panel), Atrd,
-                        Bptr, Btrd, 0.0f, Cptr,
-                        Ctrd);
+                        static_cast<const float*>(a_panel), Atrd, Bptr, Btrd,
+                        0.0f, Cptr, Ctrd);
 }
 #endif
 
@@ -111,8 +110,9 @@ WorkspaceBundle MatrixMulImpl::AlgoF32MKLPackA::get_bundle(
     return {nullptr, {a_size, 0, 0}};
 }
 
-void MatrixMulImpl::AlgoF32MKLPackA::pack_A(const KernParam& kern_param, void* out,
-                                            size_t index, size_t stride) const {
+void MatrixMulImpl::AlgoF32MKLPackA::pack_A(const KernParam& kern_param,
+                                            void* out, size_t index,
+                                            size_t stride) const {
     MEGDNN_MARK_USED_VAR(stride);
     MEGDNN_MARK_USED_VAR(index);
     auto m = kern_param.M, n = kern_param.N, k = kern_param.K;
@@ -164,7 +164,7 @@ size_t get_kern_workspace(MatrixMulImpl::KernSizeParam kern_size_param) {
 
 bool MatrixMulImpl::AlgoInt8x8x32Vnni::usable(
         const KernSizeParam& kern_size_param) const {
-    return kern_size_param.A_type == kern_size_param.B_type &&
+    return kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() &&
            ((kern_size_param.A_type.enumv() == DTypeEnum::Int8 &&
              kern_size_param.C_type.enumv() == DTypeEnum::Int32) ||
             (kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8 &&
@@ -389,9 +389,10 @@ size_t MatrixMulImpl::AlgoInt8x8x32AVX2M2N4K16::get_workspace(
                    m, n, k, trans_a, trans_b, strategy, cacheline)
             .get_workspace_size();
 }
-MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32AVX2M2N4K16, megdnn_x86_matmul_kern,
-                                     8, x86::matmul::gemm_avx2_s8s8s32_2x4x16,
-                                     dt_int8, dt_int32);
+MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32AVX2M2N4K16,
+                                     megdnn_x86_matmul_kern, 8,
+                                     x86::matmul::gemm_avx2_s8s8s32_2x4x16,
+                                     dt_int8, dt_int32);
 
 /*************************AlgoInt8x8x32SSEM4N8K2********************/
 MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32SSEM4N8K2::get_kern(
@@ -426,9 +427,10 @@ size_t MatrixMulImpl::AlgoInt8x8x32SSEM4N8K2::get_workspace(
                    m, n, k, trans_a, trans_b, strategy, cacheline)
             .get_workspace_size();
 }
-MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(
-        AlgoInt8x8x32SSEM4N8K2, megdnn_x86_matmul_kern, 9,
-        x86::matmul::gemm_sse_s8s8s32_4x8x2, dt_int8, dt_int32, dt_int16);
+MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(AlgoInt8x8x32SSEM4N8K2,
+                                            megdnn_x86_matmul_kern, 9,
+                                            x86::matmul::gemm_sse_s8s8s32_4x8x2,
+                                            dt_int8, dt_int32, dt_int16);
 
 /*************************AlgoF32MK8_8x8********************/
 MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32MK8_8x8::get_kern(
diff --git a/dnn/src/x86/matrix_mul/opr_impl.cpp b/dnn/src/x86/matrix_mul/opr_impl.cpp
index 7763ba5b..032356b6 100644
--- a/dnn/src/x86/matrix_mul/opr_impl.cpp
+++ b/dnn/src/x86/matrix_mul/opr_impl.cpp
@@ -6,7 +6,8 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 
 #include "src/x86/matrix_mul/opr_impl.h"
@@ -41,9 +42,6 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
 public:
     AlgoPack() {
         if (is_supported(SIMDType::VNNI)) {
-#if MEGDNN_X86_WITH_MKL_DNN
-            all_algos.emplace_back(&algoint8x8x32mkldnn);
-#endif
 #if MEGDNN_X86_WITH_VNNI
             all_algos.emplace_back(&algoint8x8x32vnni);
 #endif