|
|
@@ -6,7 +6,8 @@
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
-* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+* implied.
 */

 #include "src/x86/matrix_mul/algos.h"
@@ -17,7 +18,6 @@
 #include "src/x86/matrix_mul/f32/strategy.h"

 MIDOUT_DECL(megdnn_x86_matmul_kern)
 MIDOUT_DECL(megdnn_x86_matmul_kern_mk8_8x8)

 using namespace megdnn;
@@ -45,17 +45,16 @@ void f32_blas_kern(const MatrixMulImpl::KernParam& kern_param) {

 #if MEGDNN_X86_WITH_MKL && SUPPORT_MKL_PACKED_GEMM
 void f32_blas_kern_only_packA(const MatrixMulImpl::KernParam& kern_param,
                               const void* a_panel, const void* b_panel) {
     MEGDNN_MARK_USED_VAR(b_panel);
     auto m = kern_param.M, n = kern_param.N, k = kern_param.K;
     const auto Bptr = kern_param.B<dt_float32>();
     auto Cptr = kern_param.C<dt_float32>();
     auto Atrd = kern_param.LDA, Btrd = kern_param.LDB, Ctrd = kern_param.LDC;
     disable_denorm();
     cblas_sgemm_compute(CblasRowMajor, CblasPacked, CblasNoTrans, m, n, k,
-                        static_cast<const float*>(a_panel), Atrd,
-                        Bptr, Btrd, 0.0f, Cptr,
-                        Ctrd);
+                        static_cast<const float*>(a_panel), Atrd, Bptr, Btrd,
+                        0.0f, Cptr, Ctrd);
 }
 #endif
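For background on the hunk above: the packed-A BLAS path packs the A operand once with MKL's packed GEMM extension and then reuses that panel across kernel calls, which is why only cblas_sgemm_compute appears here. A minimal standalone sketch of the pack-then-compute pattern with plain MKL follows; the function name, buffer handling and leading dimensions are illustrative assumptions, not part of this patch.

#include <mkl.h>

// Sketch: row-major C = A * B where the A panel is packed once and reused.
void packed_sgemm_sketch(int m, int n, int k, const float* A, const float* B,
                         float* C) {
    // Query the packed-buffer size for the A operand and allocate it.
    size_t pack_bytes = cblas_sgemm_pack_get_size(CblasAMatrix, m, n, k);
    float* a_panel = static_cast<float*>(mkl_malloc(pack_bytes, 64));
    // Pack A once; alpha is folded into the packed panel.
    cblas_sgemm_pack(CblasRowMajor, CblasAMatrix, CblasNoTrans, m, n, k,
                     1.0f, A, k /* lda */, a_panel);
    // Reuse the packed panel for the actual product, beta = 0.
    cblas_sgemm_compute(CblasRowMajor, CblasPacked, CblasNoTrans, m, n, k,
                        a_panel, k /* lda */, B, n /* ldb */, 0.0f, C,
                        n /* ldc */);
    mkl_free(a_panel);
}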
|
|
|
|
|
|
@@ -111,8 +110,9 @@ WorkspaceBundle MatrixMulImpl::AlgoF32MKLPackA::get_bundle(
     return {nullptr, {a_size, 0, 0}};
 }

-void MatrixMulImpl::AlgoF32MKLPackA::pack_A(const KernParam& kern_param, void* out,
-                                            size_t index, size_t stride) const {
+void MatrixMulImpl::AlgoF32MKLPackA::pack_A(const KernParam& kern_param,
+                                            void* out, size_t index,
+                                            size_t stride) const {
     MEGDNN_MARK_USED_VAR(stride);
     MEGDNN_MARK_USED_VAR(index);
     auto m = kern_param.M, n = kern_param.N, k = kern_param.K;
@@ -164,7 +164,7 @@ size_t get_kern_workspace(MatrixMulImpl::KernSizeParam kern_size_param) {
 bool MatrixMulImpl::AlgoInt8x8x32Vnni::usable(
         const KernSizeParam& kern_size_param) const {
-    return kern_size_param.A_type == kern_size_param.B_type &&
+    return kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() &&
            ((kern_size_param.A_type.enumv() == DTypeEnum::Int8 &&
              kern_size_param.C_type.enumv() == DTypeEnum::Int32) ||
             (kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8 &&
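The only functional change in this hunk is the first comparison in usable(): comparing the DType objects directly presumably also compares per-tensor quantization parameters, so QuantizedS8 operands with different scales would fail the old check, while comparing enumv() matches on the dtype category alone. A hypothetical helper spelling out the relaxed check; the name same_base_dtype is illustrative and not part of the patch:

// Illustrative only: matches operands by dtype category, ignoring per-tensor
// quantization parameters such as scale (which full DType equality compares).
static bool same_base_dtype(const MatrixMulImpl::KernSizeParam& p) {
    return p.A_type.enumv() == p.B_type.enumv();
}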
|
|
@@ -389,9 +389,10 @@ size_t MatrixMulImpl::AlgoInt8x8x32AVX2M2N4K16::get_workspace(
                    m, n, k, trans_a, trans_b, strategy, cacheline)
             .get_workspace_size();
 }
-MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32AVX2M2N4K16, megdnn_x86_matmul_kern,
-                                     8, x86::matmul::gemm_avx2_s8s8s32_2x4x16,
-                                     dt_int8, dt_int32);
+MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32AVX2M2N4K16,
+                                     megdnn_x86_matmul_kern, 8,
+                                     x86::matmul::gemm_avx2_s8s8s32_2x4x16,
+                                     dt_int8, dt_int32);

 /*************************AlgoInt8x8x32SSEM4N8K2********************/
 MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32SSEM4N8K2::get_kern(
@@ -426,9 +427,10 @@ size_t MatrixMulImpl::AlgoInt8x8x32SSEM4N8K2::get_workspace(
                    m, n, k, trans_a, trans_b, strategy, cacheline)
             .get_workspace_size();
 }
-MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(
-        AlgoInt8x8x32SSEM4N8K2, megdnn_x86_matmul_kern, 9,
-        x86::matmul::gemm_sse_s8s8s32_4x8x2, dt_int8, dt_int32, dt_int16);
+MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(AlgoInt8x8x32SSEM4N8K2,
+                                            megdnn_x86_matmul_kern, 9,
+                                            x86::matmul::gemm_sse_s8s8s32_4x8x2,
+                                            dt_int8, dt_int32, dt_int16);

 /*************************AlgoF32MK8_8x8********************/
 MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32MK8_8x8::get_kern(
|
|
|