add AlgoF32GiMK4Pack4x12 matrix_mul algo
GitOrigin-RevId: 47cfe1d733
release-1.10
@@ -197,10 +197,17 @@ bool ConvBiasImpl::AlgoConv1x1::usable( | |||||
return false; | return false; | ||||
} | } | ||||
} | } | ||||
#else //! x86 only support nchw mode | |||||
if (format != param::ConvBias::Format::NCHW) { | |||||
#else //! x86 and RISC-V do not support NCHW44_DOT | |||||
if (format != param::ConvBias::Format::NCHW && | |||||
format != param::ConvBias::Format::NCHW44) { | |||||
return false; | return false; | ||||
} | } | ||||
//! hybird mode is not support | |||||
if (param.filter_meta.format == param::ConvBias::Format::NCHW44) { | |||||
if (param.filter_meta.icpg < 4_z || param.filter_meta.ocpg == 1) { | |||||
return false; | |||||
} | |||||
} | |||||
#endif | #endif | ||||
//! param | //! param | ||||
if (FH != 1 || FW != 1 || PH || PW || SH != 1 || SW != 1) { | if (FH != 1 || FW != 1 || PH || PW || SH != 1 || SW != 1) { | ||||
@@ -345,9 +345,21 @@ bool ConvBiasImpl::AlgoIm2col::usable( | |||||
} | } | ||||
} | } | ||||
#else | #else | ||||
if (format != param::ConvBias::Format::NCHW) { | |||||
if (format != param::ConvBias::Format::NCHW && | |||||
format != param::ConvBias::Format::NCHW44) { | |||||
return false; | return false; | ||||
} | } | ||||
if (format == param::ConvBias::Format::NCHW44) { | |||||
//! current NCHW44 im2col only support DEFAULT mode matmul | |||||
if (matmul_desc.packmode != Pack_Mode::DEFAULT) { | |||||
return false; | |||||
//! nchw44 hybird mode and channel wise is not support | |||||
} else if ( | |||||
param.filter_meta.icpg < 4_z || param.filter_meta.icpg == 1 || | |||||
param.filter_meta.ocpg == 1) { | |||||
return false; | |||||
} | |||||
} | |||||
#endif | #endif | ||||
if (param.src_type.enumv() != param.filter_type.enumv() || | if (param.src_type.enumv() != param.filter_type.enumv() || | ||||
(param.src_type.enumv() != DTypeEnum::Int8 && | (param.src_type.enumv() != DTypeEnum::Int8 && | ||||
@@ -216,10 +216,9 @@ public: | |||||
cb1(NCHW, DEFAULT, dt_float32, dt_float32, PostprocessMode::FLOAT, | cb1(NCHW, DEFAULT, dt_float32, dt_float32, PostprocessMode::FLOAT, | ||||
"DefaultStrategyType::FLOAT"_hash); | "DefaultStrategyType::FLOAT"_hash); | ||||
} else if (format == param::ConvBias::Format::NCHW44) { | } else if (format == param::ConvBias::Format::NCHW44) { | ||||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||||
auto matmul_block = matmul_algo->get_inner_block_size(); | auto matmul_block = matmul_algo->get_inner_block_size(); | ||||
//! Optimize NCHW44 3x3s2 aarch64 8X12X1 and armv7 4x12x1 | |||||
//! im2col+pack fuse | |||||
//! Optimize NCHW44 3x3s2 on aarch64 8X12X4 and fallback/armv7 | |||||
//! 4x12x4 im2col+pack fuse | |||||
if ((matmul_block.m == 8 || matmul_block.m == 4) && | if ((matmul_block.m == 8 || matmul_block.m == 4) && | ||||
matmul_block.n == 12 && matmul_block.k == 1 && | matmul_block.n == 12 && matmul_block.k == 1 && | ||||
param.filter_meta.spatial[0] == 3 && | param.filter_meta.spatial[0] == 3 && | ||||
@@ -236,7 +235,6 @@ public: | |||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return {}; | return {}; | ||||
} | } | ||||
#endif | |||||
cb1(NCHW44, DEFAULT, dt_float32, dt_float32, PostprocessMode::FLOAT, | cb1(NCHW44, DEFAULT, dt_float32, dt_float32, PostprocessMode::FLOAT, | ||||
"DefaultStrategyTypeNCHW44::FLOAT"_hash); | "DefaultStrategyTypeNCHW44::FLOAT"_hash); | ||||
@@ -530,7 +530,6 @@ public: | |||||
}; | }; | ||||
#endif | #endif | ||||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||||
template < | template < | ||||
typename op_ctype, typename op_dtype, megdnn::PostprocessMode postprocess_mode> | typename op_ctype, typename op_dtype, megdnn::PostprocessMode postprocess_mode> | ||||
class StrategyFuseXx12x1Nchw44K3x3S2 | class StrategyFuseXx12x1Nchw44K3x3S2 | ||||
@@ -553,7 +552,6 @@ public: | |||||
fallback::MatrixMulImpl::KernParam matmul_param, | fallback::MatrixMulImpl::KernParam matmul_param, | ||||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; | const fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; | ||||
}; | }; | ||||
#endif | |||||
} // namespace megdnn | } // namespace megdnn | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,7 +1,6 @@ | |||||
#include "src/fallback/conv_bias/im2col/strategy_base.h" | #include "src/fallback/conv_bias/im2col/strategy_base.h" | ||||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||||
#include <arm_neon.h> | |||||
#include "src/fallback/general_intrinsic/gi_float.h" | |||||
using namespace megdnn; | using namespace megdnn; | ||||
@@ -11,32 +10,32 @@ namespace { | |||||
int out_index = 0; \ | int out_index = 0; \ | ||||
outptr = output_base; \ | outptr = output_base; \ | ||||
for (; out_index + 11 < block_size; out_index += 12) { \ | for (; out_index + 11 < block_size; out_index += 12) { \ | ||||
float32x4x4_t v0 = vld4q_f32(tmp_output); \ | |||||
float32x4x4_t v1 = vld4q_f32(tmp_output + 16); \ | |||||
float32x4x4_t v2 = vld4q_f32(tmp_output + 32); \ | |||||
vst1q_f32(outptr, v0.val[0]); \ | |||||
vst1q_f32(outptr + 4, v1.val[0]); \ | |||||
vst1q_f32(outptr + 8, v2.val[0]); \ | |||||
vst1q_f32(outptr + 12, v0.val[1]); \ | |||||
vst1q_f32(outptr + 16, v1.val[1]); \ | |||||
vst1q_f32(outptr + 20, v2.val[1]); \ | |||||
vst1q_f32(outptr + 24, v0.val[2]); \ | |||||
vst1q_f32(outptr + 28, v1.val[2]); \ | |||||
vst1q_f32(outptr + 32, v2.val[2]); \ | |||||
vst1q_f32(outptr + 36, v0.val[3]); \ | |||||
vst1q_f32(outptr + 40, v1.val[3]); \ | |||||
vst1q_f32(outptr + 44, v2.val[3]); \ | |||||
GI_FLOAT32_V4_t v0 = GiLoadUzipFloat32V4(tmp_output); \ | |||||
GI_FLOAT32_V4_t v1 = GiLoadUzipFloat32V4(tmp_output + 16); \ | |||||
GI_FLOAT32_V4_t v2 = GiLoadUzipFloat32V4(tmp_output + 32); \ | |||||
GiStoreFloat32(outptr, GiGetSubVectorFloat32V4(v0, 0)); \ | |||||
GiStoreFloat32(outptr + 4, GiGetSubVectorFloat32V4(v1, 0)); \ | |||||
GiStoreFloat32(outptr + 8, GiGetSubVectorFloat32V4(v2, 0)); \ | |||||
GiStoreFloat32(outptr + 12, GiGetSubVectorFloat32V4(v0, 1)); \ | |||||
GiStoreFloat32(outptr + 16, GiGetSubVectorFloat32V4(v1, 1)); \ | |||||
GiStoreFloat32(outptr + 20, GiGetSubVectorFloat32V4(v2, 1)); \ | |||||
GiStoreFloat32(outptr + 24, GiGetSubVectorFloat32V4(v0, 2)); \ | |||||
GiStoreFloat32(outptr + 28, GiGetSubVectorFloat32V4(v1, 2)); \ | |||||
GiStoreFloat32(outptr + 32, GiGetSubVectorFloat32V4(v2, 2)); \ | |||||
GiStoreFloat32(outptr + 36, GiGetSubVectorFloat32V4(v0, 3)); \ | |||||
GiStoreFloat32(outptr + 40, GiGetSubVectorFloat32V4(v1, 3)); \ | |||||
GiStoreFloat32(outptr + 44, GiGetSubVectorFloat32V4(v2, 3)); \ | |||||
outptr += ksize12; \ | outptr += ksize12; \ | ||||
tmp_output += 48; \ | tmp_output += 48; \ | ||||
} \ | } \ | ||||
\ | \ | ||||
outptr = output_base4; \ | outptr = output_base4; \ | ||||
for (; out_index + 3 < block_size; out_index += 4) { \ | for (; out_index + 3 < block_size; out_index += 4) { \ | ||||
float32x4x4_t v0 = vld4q_f32(tmp_output); \ | |||||
vst1q_f32(outptr, v0.val[0]); \ | |||||
vst1q_f32(outptr + 4, v0.val[1]); \ | |||||
vst1q_f32(outptr + 8, v0.val[2]); \ | |||||
vst1q_f32(outptr + 12, v0.val[3]); \ | |||||
GI_FLOAT32_V4_t v0 = GiLoadUzipFloat32V4(tmp_output); \ | |||||
GiStoreFloat32(outptr, GiGetSubVectorFloat32V4(v0, 0)); \ | |||||
GiStoreFloat32(outptr + 4, GiGetSubVectorFloat32V4(v0, 1)); \ | |||||
GiStoreFloat32(outptr + 8, GiGetSubVectorFloat32V4(v0, 2)); \ | |||||
GiStoreFloat32(outptr + 12, GiGetSubVectorFloat32V4(v0, 3)); \ | |||||
outptr += ksize4; \ | outptr += ksize4; \ | ||||
tmp_output += 16; \ | tmp_output += 16; \ | ||||
} \ | } \ | ||||
@@ -45,23 +44,23 @@ namespace { | |||||
float zerobuffer[16] = {0}; \ | float zerobuffer[16] = {0}; \ | ||||
size_t out_remain = std::min(block_size - out_index, 4); \ | size_t out_remain = std::min(block_size - out_index, 4); \ | ||||
std::memcpy(zerobuffer, tmp_output, out_remain * sizeof(float) * 4); \ | std::memcpy(zerobuffer, tmp_output, out_remain * sizeof(float) * 4); \ | ||||
float32x4x4_t v0 = vld4q_f32(zerobuffer); \ | |||||
vst1q_f32(outptr, v0.val[0]); \ | |||||
vst1q_f32(outptr + 4, v0.val[1]); \ | |||||
vst1q_f32(outptr + 8, v0.val[2]); \ | |||||
vst1q_f32(outptr + 12, v0.val[3]); \ | |||||
GI_FLOAT32_V4_t v0 = GiLoadUzipFloat32V4(zerobuffer); \ | |||||
GiStoreFloat32(outptr, GiGetSubVectorFloat32V4(v0, 0)); \ | |||||
GiStoreFloat32(outptr + 4, GiGetSubVectorFloat32V4(v0, 1)); \ | |||||
GiStoreFloat32(outptr + 8, GiGetSubVectorFloat32V4(v0, 2)); \ | |||||
GiStoreFloat32(outptr + 12, GiGetSubVectorFloat32V4(v0, 3)); \ | |||||
} \ | } \ | ||||
output_base += 48; \ | output_base += 48; \ | ||||
output_base4 += 16; | output_base4 += 16; | ||||
#define LOAD_AND_STOR_IM2COL_DST() \ | |||||
float32x4_t v1 = vld1q_f32(&src[index + 4]); \ | |||||
float32x4_t v2 = vld1q_f32(&src[index + 8]); \ | |||||
vst1q_f32(&output0[i], v0); \ | |||||
vst1q_f32(&output1[i], v1); \ | |||||
vst1q_f32(&output2[i], v2); \ | |||||
i += 4; \ | |||||
index += 8; \ | |||||
#define LOAD_AND_STOR_IM2COL_DST() \ | |||||
GI_FLOAT32_t v1 = GiLoadFloat32(&src[index + 4]); \ | |||||
GI_FLOAT32_t v2 = GiLoadFloat32(&src[index + 8]); \ | |||||
GiStoreFloat32(&output0[i], v0); \ | |||||
GiStoreFloat32(&output1[i], v1); \ | |||||
GiStoreFloat32(&output2[i], v2); \ | |||||
i += 4; \ | |||||
index += 8; \ | |||||
v0 = v2; | v0 = v2; | ||||
void fuse_packb( | void fuse_packb( | ||||
@@ -94,12 +93,12 @@ void fuse_packb( | |||||
size_t index = 4 * (ic * IH * IW + (start_h * SH + fh) * IW + | size_t index = 4 * (ic * IH * IW + (start_h * SH + fh) * IW + | ||||
cur_remain_w * SW); | cur_remain_w * SW); | ||||
for (int w = cur_remain_w; w < end_remain_w; w++) { | for (int w = cur_remain_w; w < end_remain_w; w++) { | ||||
vst1q_f32(&output02[i], vld1q_f32(&src[index])); | |||||
vst1q_f32(&output1[i], vld1q_f32(&src[index + 4])); | |||||
GiStoreFloat32(&output02[i], GiLoadFloat32(&src[index])); | |||||
GiStoreFloat32(&output1[i], GiLoadFloat32(&src[index + 4])); | |||||
i += 4; | i += 4; | ||||
index += 8; | index += 8; | ||||
} | } | ||||
vst1q_f32(&output02[i], vld1q_f32(&src[index])); | |||||
GiStoreFloat32(&output02[i], GiLoadFloat32(&src[index])); | |||||
float* output[3]; | float* output[3]; | ||||
output[0] = output02; | output[0] = output02; | ||||
output[1] = output1; | output[1] = output1; | ||||
@@ -120,19 +119,19 @@ void fuse_packb( | |||||
size_t index = 4 * (ic * IH * IW + (start_h * SH + fh) * IW + | size_t index = 4 * (ic * IH * IW + (start_h * SH + fh) * IW + | ||||
(cur_remain_w * SW)); | (cur_remain_w * SW)); | ||||
float32x4_t v0 = vld1q_f32(&src[index]); | |||||
GI_FLOAT32_t v0 = GiLoadFloat32(&src[index]); | |||||
for (int w = cur_remain_w; w < OW; w++) { | for (int w = cur_remain_w; w < OW; w++) { | ||||
LOAD_AND_STOR_IM2COL_DST(); | LOAD_AND_STOR_IM2COL_DST(); | ||||
} | } | ||||
for (int h = start_h + 1; h < end_h; h++) { | for (int h = start_h + 1; h < end_h; h++) { | ||||
size_t index = 4 * (ic * IH * IW + (h * SH + fh) * IW); | size_t index = 4 * (ic * IH * IW + (h * SH + fh) * IW); | ||||
v0 = vld1q_f32(&src[index]); | |||||
v0 = GiLoadFloat32(&src[index]); | |||||
rep(ow, OW) { LOAD_AND_STOR_IM2COL_DST(); } | rep(ow, OW) { LOAD_AND_STOR_IM2COL_DST(); } | ||||
} | } | ||||
index = 4 * (ic * IH * IW + (end_h * SH + fh) * IW); | index = 4 * (ic * IH * IW + (end_h * SH + fh) * IW); | ||||
v0 = vld1q_f32(&src[index]); | |||||
v0 = GiLoadFloat32(&src[index]); | |||||
for (int w = 0; w < end_remain_w; w++) { | for (int w = 0; w < end_remain_w; w++) { | ||||
LOAD_AND_STOR_IM2COL_DST(); | LOAD_AND_STOR_IM2COL_DST(); | ||||
} | } | ||||
@@ -190,6 +189,4 @@ template class StrategyFuseXx12x1Nchw44K3x3S2< | |||||
float, float, megdnn::PostprocessMode::FLOAT>; | float, float, megdnn::PostprocessMode::FLOAT>; | ||||
} // namespace megdnn | } // namespace megdnn | ||||
#endif | |||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -57,7 +57,7 @@ void kern_naive(const MatrixMulImpl::KernParam& kern_param) { | |||||
size_t pack_size = get_pack_size(); | size_t pack_size = get_pack_size(); | ||||
megdnn_assert( | megdnn_assert( | ||||
(M % pack_size == 0 && K % pack_size == 0), | (M % pack_size == 0 && K % pack_size == 0), | ||||
"M and N must time of pack_size M: %zu N: %zu pack_size: %zu", M, N, | |||||
"M and K must time of pack_size M: %zu K: %zu pack_size: %zu", M, N, | |||||
pack_size); | pack_size); | ||||
#define DISPATCH(TA, TB) \ | #define DISPATCH(TA, TB) \ | ||||
@@ -263,12 +263,15 @@ void gi_f32_mk4_4x8_kern(const MatrixMulImpl::KernParam& kern_param) { | |||||
} // anonymous namespace | } // anonymous namespace | ||||
bool MatrixMulImpl::AlgoF32GiMK4_4x8::usable( | bool MatrixMulImpl::AlgoF32GiMK4_4x8::usable( | ||||
const KernSizeParam& kern_size_param) const { | const KernSizeParam& kern_size_param) const { | ||||
constexpr size_t MB = 4; | |||||
constexpr size_t KB = 4; | |||||
return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && | return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && | ||||
kern_size_param.format == param::MatrixMul::Format::MK4 && | kern_size_param.format == param::MatrixMul::Format::MK4 && | ||||
kern_size_param.B_type == kern_size_param.A_type && | kern_size_param.B_type == kern_size_param.A_type && | ||||
kern_size_param.C_type == kern_size_param.A_type && | kern_size_param.C_type == kern_size_param.A_type && | ||||
kern_size_param.A_type == dtype::Float32() && !kern_size_param.trA && | kern_size_param.A_type == dtype::Float32() && !kern_size_param.trA && | ||||
!kern_size_param.trB; | |||||
!kern_size_param.trB && kern_size_param.M % MB == 0 && | |||||
kern_size_param.K % KB == 0; | |||||
} | } | ||||
size_t MatrixMulImpl::AlgoF32GiMK4_4x8::get_workspace( | size_t MatrixMulImpl::AlgoF32GiMK4_4x8::get_workspace( | ||||
@@ -295,6 +298,71 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32GiMK4_4x8::get_kern( | |||||
return gi_f32_mk4_4x8_kern; | return gi_f32_mk4_4x8_kern; | ||||
} | } | ||||
/* ===================== F32 algo gi mk4 pack K4x12 ===================== */ | |||||
namespace { | |||||
void f32_gi_mk4_pack_4x12_kern(const MatrixMulImpl::KernParam& kern_param) { | |||||
MIDOUT_BEGIN( | |||||
megdnn_fb_gi_matmul_kern, midout_iv("f32_gi_mk4_pack_4x12_kern"_hash)) { | |||||
auto M = kern_param.M, N = kern_param.N, K = kern_param.K; | |||||
auto trA = kern_param.trA, trB = kern_param.trB; | |||||
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; | |||||
auto A_type = kern_param.A_type, B_type = kern_param.B_type, | |||||
C_type = kern_param.C_type; | |||||
const auto Aptr = kern_param.A<float>(), Bptr = kern_param.B<float>(); | |||||
auto Cptr = kern_param.C<float>(); | |||||
matmul::fallback::gi_sgemm_mk4_pack_4x12 strategy( | |||||
M, N, K, A_type, B_type, C_type); | |||||
megdnn::matmul::GemmInterleaved<matmul::fallback::gi_sgemm_mk4_pack_4x12>( | |||||
M, N, K, trA, trB, strategy) | |||||
.execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr); | |||||
} | |||||
MIDOUT_END(); | |||||
} | |||||
} // anonymous namespace | |||||
bool MatrixMulImpl::AlgoF32GiMK4Pack4x12::usable( | |||||
const KernSizeParam& kern_size_param) const { | |||||
return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && | |||||
kern_size_param.format == param::MatrixMul::Format::MK4 && | |||||
kern_size_param.B_type == kern_size_param.A_type && | |||||
kern_size_param.C_type == kern_size_param.A_type && | |||||
kern_size_param.A_type == dtype::Float32() && !kern_size_param.trA && | |||||
!kern_size_param.trB && kern_size_param.M % 4 == 0 && | |||||
kern_size_param.K % 4 == 0 && !kern_size_param.trA && !kern_size_param.trB; | |||||
} | |||||
size_t MatrixMulImpl::AlgoF32GiMK4Pack4x12::get_workspace( | |||||
const KernSizeParam& kern_size_param) const { | |||||
MIDOUT_BEGIN( | |||||
megdnn_fb_gi_matmul_kern, | |||||
midout_iv("AlgoF32GiMK4Pack4x12::get_workspace"_hash)) { | |||||
auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K; | |||||
auto trA = kern_size_param.trA, trB = kern_size_param.trB; | |||||
auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, | |||||
C_type = kern_size_param.C_type; | |||||
matmul::fallback::gi_sgemm_mk4_pack_4x12 strategy( | |||||
M, N, K, A_type, B_type, C_type); | |||||
return megdnn::matmul::GemmInterleaved< | |||||
matmul::fallback::gi_sgemm_mk4_pack_4x12>( | |||||
M, N, K, trA, trB, strategy) | |||||
.get_workspace_size(); | |||||
} | |||||
MIDOUT_END(); | |||||
return 0; | |||||
} | |||||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32GiMK4Pack4x12::get_kern( | |||||
const KernSizeParam&) const { | |||||
return f32_gi_mk4_pack_4x12_kern; | |||||
} | |||||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( | |||||
AlgoF32GiMK4Pack4x12, megdnn_fb_gi_matmul_kern, "AlgoF32GiMK4Pack4x12"_hash, | |||||
matmul::fallback::gi_sgemm_mk4_pack_4x12, float, float, AlgoDataType::FLOAT32, | |||||
MK4); | |||||
/* ===================== F32 algo ===================== */ | /* ===================== F32 algo ===================== */ | ||||
namespace { | namespace { | ||||
void f32_kern(const MatrixMulImpl::KernParam& kern_param) { | void f32_kern(const MatrixMulImpl::KernParam& kern_param) { | ||||
@@ -97,6 +97,19 @@ public: | |||||
MEGDNN_DECL_ALGO_TYPE(FB_GI_F32_MK4_4x8) | MEGDNN_DECL_ALGO_TYPE(FB_GI_F32_MK4_4x8) | ||||
}; | }; | ||||
class MatrixMulImpl::AlgoF32GiMK4Pack4x12 final : public AlgoBase { | |||||
public: | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
} | |||||
const char* name() const override { return "FB_GI_F32_MK4_PACK_4x12"; } | |||||
bool usable(const KernSizeParam&) const override; | |||||
size_t get_workspace(const KernSizeParam&) const override; | |||||
kern_t get_kern(const KernSizeParam&) const override; | |||||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||||
MEGDNN_DECL_ALGO_TYPE(FB_GI_F32_MK4_PACK_4x12) | |||||
}; | |||||
class MatrixMulImpl::AlgoF32Gi4x12 final : public AlgoBase { | class MatrixMulImpl::AlgoF32Gi4x12 final : public AlgoBase { | ||||
public: | public: | ||||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | ||||
@@ -9,6 +9,8 @@ MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, true, sgemm_8x12) | |||||
MEGDNN_REG_GEMM_STRATEGY_NOPACK( | MEGDNN_REG_GEMM_STRATEGY_NOPACK( | ||||
float, float, float, 4, 8, 1, false, true, gi_sgemm_nopack_4x8); | float, float, float, 4, 8, 1, false, true, gi_sgemm_nopack_4x8); | ||||
MEGDNN_REG_GEMM_STRATEGY(float, float, float, 4, 12, 1, false, true, gi_sgemm_4x12); | MEGDNN_REG_GEMM_STRATEGY(float, float, float, 4, 12, 1, false, true, gi_sgemm_4x12); | ||||
MEGDNN_REG_GEMM_STRATEGY( | |||||
float, float, float, 4, 12, 1, false, false, gi_sgemm_mk4_pack_4x12); | |||||
} // namespace fallback | } // namespace fallback | ||||
} // namespace matmul | } // namespace matmul | ||||
@@ -214,6 +214,113 @@ static GI_FORCEINLINE void transpose_4x4_1_s( | |||||
outptr += stride; | outptr += stride; | ||||
} | } | ||||
template <typename T> | |||||
static inline void transpose_1x12_4_s(const T*& inptr0, T* outptr) { | |||||
static_assert(sizeof(T) == 4, "transpose_1x12_4_s only support sizeof(T) == 4"); | |||||
GI_FLOAT32_t tmp_a, tmp_b; | |||||
#define LOAD() \ | |||||
tmp_a = GiLoadFloat32(inptr0); \ | |||||
inptr0 += 4; \ | |||||
tmp_b = GiLoadFloat32(inptr0); \ | |||||
inptr0 += 4; | |||||
LOAD(); | |||||
GI_FLOAT32_V2_t d0d1d2d3 = GiZipqFloat32(tmp_a, tmp_b); | |||||
LOAD(); | |||||
GI_FLOAT32_V2_t d4d5d6d7 = GiZipqFloat32(tmp_a, tmp_b); | |||||
LOAD(); | |||||
GI_FLOAT32_V2_t d8d9d10d11 = GiZipqFloat32(tmp_a, tmp_b); | |||||
LOAD(); | |||||
GI_FLOAT32_V2_t d12d13d14d15 = GiZipqFloat32(tmp_a, tmp_b); | |||||
LOAD(); | |||||
GI_FLOAT32_V2_t d16d17d18d19 = GiZipqFloat32(tmp_a, tmp_b); | |||||
LOAD(); | |||||
GI_FLOAT32_V2_t d20d21d22d23 = GiZipqFloat32(tmp_a, tmp_b); | |||||
#undef LOAD | |||||
GiSt1Float32(outptr, GiGetLowFloat32(GiGetSubVectorFloat32V2(d0d1d2d3, 0))); | |||||
GiSt1Float32(outptr + 1 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d4d5d6d7, 0))); | |||||
GiSt1Float32( | |||||
outptr + 2 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d8d9d10d11, 0))); | |||||
GiSt1Float32( | |||||
outptr + 3 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d12d13d14d15, 0))); | |||||
GiSt1Float32( | |||||
outptr + 4 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d16d17d18d19, 0))); | |||||
GiSt1Float32( | |||||
outptr + 5 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d20d21d22d23, 0))); | |||||
GiSt1Float32( | |||||
outptr + 6 * 2, GiGetHighFloat32(GiGetSubVectorFloat32V2(d0d1d2d3, 0))); | |||||
GiSt1Float32( | |||||
outptr + 7 * 2, GiGetHighFloat32(GiGetSubVectorFloat32V2(d4d5d6d7, 0))); | |||||
GiSt1Float32( | |||||
outptr + 8 * 2, GiGetHighFloat32(GiGetSubVectorFloat32V2(d8d9d10d11, 0))); | |||||
GiSt1Float32( | |||||
outptr + 9 * 2, GiGetHighFloat32(GiGetSubVectorFloat32V2(d12d13d14d15, 0))); | |||||
GiSt1Float32( | |||||
outptr + 10 * 2, | |||||
GiGetHighFloat32(GiGetSubVectorFloat32V2(d16d17d18d19, 0))); | |||||
GiSt1Float32( | |||||
outptr + 11 * 2, | |||||
GiGetHighFloat32(GiGetSubVectorFloat32V2(d20d21d22d23, 0))); | |||||
GiSt1Float32( | |||||
outptr + 12 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d0d1d2d3, 1))); | |||||
GiSt1Float32( | |||||
outptr + 13 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d4d5d6d7, 1))); | |||||
GiSt1Float32( | |||||
outptr + 14 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d8d9d10d11, 1))); | |||||
GiSt1Float32( | |||||
outptr + 15 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d12d13d14d15, 1))); | |||||
GiSt1Float32( | |||||
outptr + 16 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d16d17d18d19, 1))); | |||||
GiSt1Float32( | |||||
outptr + 17 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d20d21d22d23, 1))); | |||||
GiSt1Float32( | |||||
outptr + 18 * 2, GiGetHighFloat32(GiGetSubVectorFloat32V2(d0d1d2d3, 1))); | |||||
GiSt1Float32( | |||||
outptr + 19 * 2, GiGetHighFloat32(GiGetSubVectorFloat32V2(d4d5d6d7, 1))); | |||||
GiSt1Float32( | |||||
outptr + 20 * 2, GiGetHighFloat32(GiGetSubVectorFloat32V2(d8d9d10d11, 1))); | |||||
GiSt1Float32( | |||||
outptr + 21 * 2, | |||||
GiGetHighFloat32(GiGetSubVectorFloat32V2(d12d13d14d15, 1))); | |||||
GiSt1Float32( | |||||
outptr + 22 * 2, | |||||
GiGetHighFloat32(GiGetSubVectorFloat32V2(d16d17d18d19, 1))); | |||||
GiSt1Float32( | |||||
outptr + 23 * 2, | |||||
GiGetHighFloat32(GiGetSubVectorFloat32V2(d20d21d22d23, 1))); | |||||
outptr += 23 * 2; | |||||
} | |||||
template <typename T> | |||||
static inline void transpose_1x4_4_s(const T*& inptr0, T* outptr) { | |||||
static_assert(sizeof(T) == 4, "transpose_1x4_4_s only support sizeof(T) == 4"); | |||||
GI_FLOAT32_t tmp_a, tmp_b; | |||||
#define LOAD() \ | |||||
tmp_a = GiLoadFloat32(inptr0); \ | |||||
inptr0 += 4; \ | |||||
tmp_b = GiLoadFloat32(inptr0); \ | |||||
inptr0 += 4; | |||||
LOAD(); | |||||
GI_FLOAT32_V2_t d0d1d2d3 = GiZipqFloat32(tmp_a, tmp_b); | |||||
LOAD(); | |||||
GI_FLOAT32_V2_t d4d5d6d7 = GiZipqFloat32(tmp_a, tmp_b); | |||||
#undef LOAD | |||||
GiSt1Float32(outptr, GiGetLowFloat32(GiGetSubVectorFloat32V2(d0d1d2d3, 0))); | |||||
GiSt1Float32(outptr + 1 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d4d5d6d7, 0))); | |||||
GiSt1Float32( | |||||
outptr + 2 * 2, GiGetHighFloat32(GiGetSubVectorFloat32V2(d0d1d2d3, 0))); | |||||
GiSt1Float32( | |||||
outptr + 3 * 2, GiGetHighFloat32(GiGetSubVectorFloat32V2(d4d5d6d7, 0))); | |||||
GiSt1Float32(outptr + 4 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d0d1d2d3, 1))); | |||||
GiSt1Float32(outptr + 5 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d4d5d6d7, 1))); | |||||
GiSt1Float32( | |||||
outptr + 6 * 2, GiGetHighFloat32(GiGetSubVectorFloat32V2(d0d1d2d3, 1))); | |||||
GiSt1Float32( | |||||
outptr + 7 * 2, GiGetHighFloat32(GiGetSubVectorFloat32V2(d4d5d6d7, 1))); | |||||
outptr += 7 * 2; | |||||
} | |||||
} // namespace fallback | } // namespace fallback | ||||
} // namespace matmul | } // namespace matmul | ||||
} // namespace megdnn | } // namespace megdnn | ||||
@@ -0,0 +1,458 @@ | |||||
//! risc-v gcc will error report uninitialized var at if/else case when use RVV type | |||||
#pragma GCC diagnostic push | |||||
#pragma GCC diagnostic ignored "-Wuninitialized" | |||||
#ifdef __GNUC__ | |||||
#ifndef __has_warning | |||||
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" | |||||
#else | |||||
#if __has_warning("-Wmaybe-uninitialized") | |||||
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" | |||||
#endif | |||||
#endif | |||||
#endif | |||||
#include "src/fallback/matrix_mul/generic_strategy.h" | |||||
#include "src/fallback/matrix_mul/gi/fp32/common.h" | |||||
using namespace megdnn; | |||||
using namespace matmul::fallback; | |||||
namespace { | |||||
void kern_4x12( | |||||
const float* packA, const float* packB, int K, float* output, int LDC, | |||||
bool is_first_k) { | |||||
MEGDNN_MARK_USED_VAR(LDC); | |||||
const float* a_ptr = packA; | |||||
const float* b_ptr = packB; | |||||
float* output0 = output; | |||||
int oddk = (K & 1); | |||||
K = ((K + 1) / 2) - 1; | |||||
float* r1 = output; | |||||
GI_FLOAT32_t d0d1, d2d3, d4d5, d6d7, d8d9, d10d11, d12d13, d14d15, d16d17, d18d19, | |||||
d20d21, d22d23, d24d25, d26d27, d28d29, d30d31; | |||||
if (is_first_k) { | |||||
d8d9 = GiBroadcastFloat32(0.0f); | |||||
d10d11 = GiBroadcastFloat32(0.0f); | |||||
d12d13 = GiBroadcastFloat32(0.0f); | |||||
d14d15 = GiBroadcastFloat32(0.0f); | |||||
d0d1 = GiLoadFloat32(a_ptr); | |||||
a_ptr = a_ptr + 4; | |||||
d16d17 = GiBroadcastFloat32(0.0f); | |||||
d18d19 = GiBroadcastFloat32(0.0f); | |||||
d20d21 = GiBroadcastFloat32(0.0f); | |||||
d22d23 = GiBroadcastFloat32(0.0f); | |||||
d4d5 = GiLoadFloat32(b_ptr); | |||||
b_ptr = b_ptr + 4; | |||||
d6d7 = GiLoadFloat32(b_ptr); | |||||
b_ptr = b_ptr + 4; | |||||
d24d25 = GiBroadcastFloat32(0.0f); | |||||
d26d27 = GiBroadcastFloat32(0.0f); | |||||
d28d29 = GiBroadcastFloat32(0.0f); | |||||
d30d31 = GiBroadcastFloat32(0.0f); | |||||
} else { | |||||
d8d9 = GiLoadFloat32(r1); | |||||
r1 = r1 + 4; | |||||
d10d11 = GiLoadFloat32(r1); | |||||
r1 = r1 + 4; | |||||
d12d13 = GiLoadFloat32(r1); | |||||
r1 = r1 + 4; | |||||
d14d15 = GiLoadFloat32(r1); | |||||
r1 = r1 + 4; | |||||
d16d17 = GiLoadFloat32(r1); | |||||
r1 = r1 + 4; | |||||
d18d19 = GiLoadFloat32(r1); | |||||
r1 = r1 + 4; | |||||
d20d21 = GiLoadFloat32(r1); | |||||
r1 = r1 + 4; | |||||
d22d23 = GiLoadFloat32(r1); | |||||
r1 = r1 + 4; | |||||
d24d25 = GiLoadFloat32(r1); | |||||
r1 = r1 + 4; | |||||
d26d27 = GiLoadFloat32(r1); | |||||
r1 = r1 + 4; | |||||
d28d29 = GiLoadFloat32(r1); | |||||
r1 = r1 + 4; | |||||
d30d31 = GiLoadFloat32(r1); | |||||
r1 = r1 + 4; | |||||
d0d1 = GiLoadFloat32(a_ptr); | |||||
a_ptr = a_ptr + 4; | |||||
d4d5 = GiLoadFloat32(b_ptr); | |||||
b_ptr = b_ptr + 4; | |||||
} | |||||
for (; K > 0; K--) { | |||||
d8d9 = GiSimdFmaLane(d8d9, d0d1, d4d5, 0); | |||||
d10d11 = GiSimdFmaLane(d10d11, d0d1, d4d5, 1); | |||||
d12d13 = GiSimdFmaLane(d12d13, d0d1, d4d5, 2); | |||||
d14d15 = GiSimdFmaLane(d14d15, d0d1, d4d5, 3); | |||||
d4d5 = GiLoadFloat32(b_ptr); | |||||
b_ptr = b_ptr + 4; | |||||
d16d17 = GiSimdFmaLane(d16d17, d0d1, d6d7, 0); | |||||
d18d19 = GiSimdFmaLane(d18d19, d0d1, d6d7, 1); | |||||
d20d21 = GiSimdFmaLane(d20d21, d0d1, d6d7, 2); | |||||
d2d3 = GiLoadFloat32(a_ptr); | |||||
a_ptr = a_ptr + 4; | |||||
d22d23 = GiSimdFmaLane(d22d23, d0d1, d6d7, 3); | |||||
d6d7 = GiLoadFloat32(b_ptr); | |||||
b_ptr = b_ptr + 4; | |||||
d24d25 = GiSimdFmaLane(d24d25, d0d1, d4d5, 0); | |||||
d26d27 = GiSimdFmaLane(d26d27, d0d1, d4d5, 1); | |||||
d28d29 = GiSimdFmaLane(d28d29, d0d1, d4d5, 2); | |||||
d30d31 = GiSimdFmaLane(d30d31, d0d1, d4d5, 3); | |||||
d4d5 = GiLoadFloat32(b_ptr); | |||||
b_ptr = b_ptr + 4; | |||||
d8d9 = GiSimdFmaLane(d8d9, d2d3, d6d7, 0); | |||||
d10d11 = GiSimdFmaLane(d10d11, d2d3, d6d7, 1); | |||||
d12d13 = GiSimdFmaLane(d12d13, d2d3, d6d7, 2); | |||||
d14d15 = GiSimdFmaLane(d14d15, d2d3, d6d7, 3); | |||||
d6d7 = GiLoadFloat32(b_ptr); | |||||
b_ptr = b_ptr + 4; | |||||
d16d17 = GiSimdFmaLane(d16d17, d2d3, d4d5, 0); | |||||
d18d19 = GiSimdFmaLane(d18d19, d2d3, d4d5, 1); | |||||
d0d1 = GiLoadFloat32(a_ptr); | |||||
a_ptr = a_ptr + 4; | |||||
d20d21 = GiSimdFmaLane(d20d21, d2d3, d4d5, 2); | |||||
d22d23 = GiSimdFmaLane(d22d23, d2d3, d4d5, 3); | |||||
d4d5 = GiLoadFloat32(b_ptr); | |||||
b_ptr = b_ptr + 4; | |||||
d24d25 = GiSimdFmaLane(d24d25, d2d3, d6d7, 0); | |||||
d26d27 = GiSimdFmaLane(d26d27, d2d3, d6d7, 1); | |||||
d28d29 = GiSimdFmaLane(d28d29, d2d3, d6d7, 2); | |||||
d30d31 = GiSimdFmaLane(d30d31, d2d3, d6d7, 3); | |||||
d6d7 = GiLoadFloat32(b_ptr); | |||||
b_ptr = b_ptr + 4; | |||||
} | |||||
if (1 == oddk) { | |||||
d8d9 = GiSimdFmaLane(d8d9, d0d1, d4d5, 0); | |||||
d10d11 = GiSimdFmaLane(d10d11, d0d1, d4d5, 1); | |||||
d12d13 = GiSimdFmaLane(d12d13, d0d1, d4d5, 2); | |||||
d14d15 = GiSimdFmaLane(d14d15, d0d1, d4d5, 3); | |||||
d4d5 = GiLoadFloat32(b_ptr); | |||||
b_ptr = b_ptr + 4; | |||||
d16d17 = GiSimdFmaLane(d16d17, d0d1, d6d7, 0); | |||||
GiStoreFloat32(output0, d8d9); | |||||
output0 = output0 + 4; | |||||
GiStoreFloat32(output0, d10d11); | |||||
output0 = output0 + 4; | |||||
d18d19 = GiSimdFmaLane(d18d19, d0d1, d6d7, 1); | |||||
d20d21 = GiSimdFmaLane(d20d21, d0d1, d6d7, 2); | |||||
GiStoreFloat32(output0, d12d13); | |||||
output0 = output0 + 4; | |||||
GiStoreFloat32(output0, d14d15); | |||||
output0 = output0 + 4; | |||||
d22d23 = GiSimdFmaLane(d22d23, d0d1, d6d7, 3); | |||||
d24d25 = GiSimdFmaLane(d24d25, d0d1, d4d5, 0); | |||||
GiStoreFloat32(output0, d16d17); | |||||
output0 = output0 + 4; | |||||
GiStoreFloat32(output0, d18d19); | |||||
output0 = output0 + 4; | |||||
d26d27 = GiSimdFmaLane(d26d27, d0d1, d4d5, 1); | |||||
GiStoreFloat32(output0, d20d21); | |||||
output0 = output0 + 4; | |||||
GiStoreFloat32(output0, d22d23); | |||||
output0 = output0 + 4; | |||||
d28d29 = GiSimdFmaLane(d28d29, d0d1, d4d5, 2); | |||||
GiStoreFloat32(output0, d24d25); | |||||
output0 = output0 + 4; | |||||
GiStoreFloat32(output0, d26d27); | |||||
output0 = output0 + 4; | |||||
d30d31 = GiSimdFmaLane(d30d31, d0d1, d4d5, 3); | |||||
GiStoreFloat32(output0, d28d29); | |||||
output0 = output0 + 4; | |||||
GiStoreFloat32(output0, d30d31); | |||||
output0 = output0 + 4; | |||||
} else { | |||||
d8d9 = GiSimdFmaLane(d8d9, d0d1, d4d5, 0); | |||||
d10d11 = GiSimdFmaLane(d10d11, d0d1, d4d5, 1); | |||||
d12d13 = GiSimdFmaLane(d12d13, d0d1, d4d5, 2); | |||||
d14d15 = GiSimdFmaLane(d14d15, d0d1, d4d5, 3); | |||||
d4d5 = GiLoadFloat32(b_ptr); | |||||
b_ptr = b_ptr + 4; | |||||
d16d17 = GiSimdFmaLane(d16d17, d0d1, d6d7, 0); | |||||
d18d19 = GiSimdFmaLane(d18d19, d0d1, d6d7, 1); | |||||
d20d21 = GiSimdFmaLane(d20d21, d0d1, d6d7, 2); | |||||
d2d3 = GiLoadFloat32(a_ptr); | |||||
a_ptr = a_ptr + 4; | |||||
d22d23 = GiSimdFmaLane(d22d23, d0d1, d6d7, 3); | |||||
d6d7 = GiLoadFloat32(b_ptr); | |||||
b_ptr = b_ptr + 4; | |||||
d24d25 = GiSimdFmaLane(d24d25, d0d1, d4d5, 0); | |||||
d26d27 = GiSimdFmaLane(d26d27, d0d1, d4d5, 1); | |||||
d28d29 = GiSimdFmaLane(d28d29, d0d1, d4d5, 2); | |||||
d30d31 = GiSimdFmaLane(d30d31, d0d1, d4d5, 3); | |||||
d4d5 = GiLoadFloat32(b_ptr); | |||||
b_ptr = b_ptr + 4; | |||||
d8d9 = GiSimdFmaLane(d8d9, d2d3, d6d7, 0); | |||||
d10d11 = GiSimdFmaLane(d10d11, d2d3, d6d7, 1); | |||||
d12d13 = GiSimdFmaLane(d12d13, d2d3, d6d7, 2); | |||||
d14d15 = GiSimdFmaLane(d14d15, d2d3, d6d7, 3); | |||||
d6d7 = GiLoadFloat32(b_ptr); | |||||
b_ptr = b_ptr + 4; | |||||
d16d17 = GiSimdFmaLane(d16d17, d2d3, d4d5, 0); | |||||
d18d19 = GiSimdFmaLane(d18d19, d2d3, d4d5, 1); | |||||
GiStoreFloat32(output0, d8d9); | |||||
output0 = output0 + 4; | |||||
GiStoreFloat32(output0, d10d11); | |||||
output0 = output0 + 4; | |||||
d20d21 = GiSimdFmaLane(d20d21, d2d3, d4d5, 2); | |||||
d22d23 = GiSimdFmaLane(d22d23, d2d3, d4d5, 3); | |||||
GiStoreFloat32(output0, d12d13); | |||||
output0 = output0 + 4; | |||||
GiStoreFloat32(output0, d14d15); | |||||
output0 = output0 + 4; | |||||
d24d25 = GiSimdFmaLane(d24d25, d2d3, d6d7, 0); | |||||
d26d27 = GiSimdFmaLane(d26d27, d2d3, d6d7, 1); | |||||
GiStoreFloat32(output0, d16d17); | |||||
output0 = output0 + 4; | |||||
GiStoreFloat32(output0, d18d19); | |||||
output0 = output0 + 4; | |||||
d28d29 = GiSimdFmaLane(d28d29, d2d3, d6d7, 2); | |||||
d30d31 = GiSimdFmaLane(d30d31, d2d3, d6d7, 3); | |||||
GiStoreFloat32(output0, d20d21); | |||||
output0 = output0 + 4; | |||||
GiStoreFloat32(output0, d22d23); | |||||
output0 = output0 + 4; | |||||
GiStoreFloat32(output0, d24d25); | |||||
output0 = output0 + 4; | |||||
GiStoreFloat32(output0, d26d27); | |||||
output0 = output0 + 4; | |||||
GiStoreFloat32(output0, d28d29); | |||||
output0 = output0 + 4; | |||||
GiStoreFloat32(output0, d30d31); | |||||
output0 = output0 + 4; | |||||
} | |||||
} | |||||
//! Compute one 4xN output tile (N = n_remain, at most 4) of the MK4-packed
//! GEMM: each "column" is a PACK_C_SIZE(4)-float group, so a full tile is
//! four 4-float vectors (d8d9..d14d15). The main loop is unrolled by two
//! k-steps with loads interleaved between FMAs to hide latency; `oddk`
//! handles a trailing single k-step. On the first K slice (is_first_k) the
//! accumulators start from zero, otherwise partial sums are reloaded from
//! `output`. NOTE(review): LDC is unused because the MK4 tile is contiguous.
void kern_4x4(
        const float* packA, const float* packB, int K, float* output, int LDC,
        bool is_first_k, int n_remain) {
    MEGDNN_MARK_USED_VAR(LDC);
    const float* a_ptr = packA;
    const float* b_ptr = packB;
    //! split K into pairs; oddk marks a leftover single iteration
    int oddk = (K & 1);
    K = ((K + 1) / 2) - 1;
    float* r1 = output;
    GI_FLOAT32_t d0d1, d2d3, d4d5, d6d7, d8d9, d10d11, d12d13, d14d15;
    if (is_first_k) {
        //! first slice of K: zero the accumulators and preload the first
        //! A column group and B row group
        d8d9 = GiBroadcastFloat32(0.0f);
        d10d11 = GiBroadcastFloat32(0.0f);
        d0d1 = GiLoadFloat32(a_ptr);
        a_ptr = a_ptr + 4;
        d12d13 = GiBroadcastFloat32(0.0f);
        d4d5 = GiLoadFloat32(b_ptr);
        b_ptr = b_ptr + 4;
        d14d15 = GiBroadcastFloat32(0.0f);
    } else {
        //! later K slices: reload the partial sums; only the n_remain valid
        //! column groups are read back (the rest stay uninitialized but are
        //! never stored)
        if (n_remain == 4) {
            d8d9 = GiLoadFloat32(r1);
            r1 = r1 + 4;
            d10d11 = GiLoadFloat32(r1);
            r1 = r1 + 4;
            d12d13 = GiLoadFloat32(r1);
            r1 = r1 + 4;
            d14d15 = GiLoadFloat32(r1);
            r1 = r1 + 4;
        } else if (n_remain == 3) {
            d8d9 = GiLoadFloat32(r1);
            r1 = r1 + 4;
            d10d11 = GiLoadFloat32(r1);
            r1 = r1 + 4;
            d12d13 = GiLoadFloat32(r1);
            r1 = r1 + 4;
        } else if (n_remain == 2) {
            d8d9 = GiLoadFloat32(r1);
            r1 = r1 + 4;
            d10d11 = GiLoadFloat32(r1);
            r1 = r1 + 4;
        } else if (n_remain == 1) {
            d8d9 = GiLoadFloat32(r1);
            r1 = r1 + 4;
        }
    }
    //! main loop: two k-steps per iteration, next A/B vectors are loaded
    //! while the current FMAs issue
    for (; K > 0; K--) {
        d8d9 = GiSimdFmaLane(d8d9, d0d1, d4d5, 0);
        d2d3 = GiLoadFloat32(a_ptr);
        a_ptr = a_ptr + 4;
        d10d11 = GiSimdFmaLane(d10d11, d0d1, d4d5, 1);
        d6d7 = GiLoadFloat32(b_ptr);
        b_ptr = b_ptr + 4;
        d12d13 = GiSimdFmaLane(d12d13, d0d1, d4d5, 2);
        d14d15 = GiSimdFmaLane(d14d15, d0d1, d4d5, 3);
        d4d5 = GiLoadFloat32(b_ptr);
        b_ptr = b_ptr + 4;
        d8d9 = GiSimdFmaLane(d8d9, d2d3, d6d7, 0);
        d10d11 = GiSimdFmaLane(d10d11, d2d3, d6d7, 1);
        d0d1 = GiLoadFloat32(a_ptr);
        a_ptr = a_ptr + 4;
        d12d13 = GiSimdFmaLane(d12d13, d2d3, d6d7, 2);
        d14d15 = GiSimdFmaLane(d14d15, d2d3, d6d7, 3);
    }
    if (1 == oddk) {
        //! odd K: one final k-step with the already-preloaded d0d1/d4d5
        d8d9 = GiSimdFmaLane(d8d9, d0d1, d4d5, 0);
        d10d11 = GiSimdFmaLane(d10d11, d0d1, d4d5, 1);
        d12d13 = GiSimdFmaLane(d12d13, d0d1, d4d5, 2);
        d14d15 = GiSimdFmaLane(d14d15, d0d1, d4d5, 3);
    } else {
        //! even K: two final k-steps (the last pair was peeled off the loop)
        d8d9 = GiSimdFmaLane(d8d9, d0d1, d4d5, 0);
        d2d3 = GiLoadFloat32(a_ptr);
        a_ptr = a_ptr + 4;
        d10d11 = GiSimdFmaLane(d10d11, d0d1, d4d5, 1);
        d6d7 = GiLoadFloat32(b_ptr);
        b_ptr = b_ptr + 4;
        d12d13 = GiSimdFmaLane(d12d13, d0d1, d4d5, 2);
        d14d15 = GiSimdFmaLane(d14d15, d0d1, d4d5, 3);
        d8d9 = GiSimdFmaLane(d8d9, d2d3, d6d7, 0);
        d10d11 = GiSimdFmaLane(d10d11, d2d3, d6d7, 1);
        d12d13 = GiSimdFmaLane(d12d13, d2d3, d6d7, 2);
        d14d15 = GiSimdFmaLane(d14d15, d2d3, d6d7, 3);
    }
    //! store back only the valid column groups (full 4-float vectors each)
    if (n_remain == 4) {
        GiStoreFloat32(output, d8d9);
        output = output + 4;
        GiStoreFloat32(output, d10d11);
        output = output + 4;
        GiStoreFloat32(output, d12d13);
        output = output + 4;
        GiStoreFloat32(output, d14d15);
        output = output + 4;
    } else if (n_remain == 3) {
        GiStoreFloat32(output, d8d9);
        output = output + 4;
        GiStoreFloat32(output, d10d11);
        output = output + 4;
        GiStoreFloat32(output, d12d13);
        output = output + 4;
    } else if (n_remain == 2) {
        GiStoreFloat32(output, d8d9);
        output = output + 4;
        GiStoreFloat32(output, d10d11);
        output = output + 4;
    } else if (n_remain == 1) {
        GiStoreFloat32(output, d8d9);
        output = output + 4;
    }
}
} // namespace | |||||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gi_sgemm_mk4_pack_4x12); | |||||
//! Now no matmul mode of only packB support in conv1x1 and im2col, so just copy | |||||
//! the weight | |||||
void gi_sgemm_mk4_pack_4x12::pack_A( | |||||
float* out, const float* in, int ldin, int y0, int ymax, int k0, int kmax, | |||||
bool) const { | |||||
megdnn_assert(y0 % 4 == 0 && ymax % 4 == 0, "M must be time of 4"); | |||||
megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); | |||||
constexpr int PACK_C_SIZE = 4; | |||||
size_t cp_length = (kmax - k0) * PACK_C_SIZE; | |||||
for (int m = y0; m < ymax; m += 4) { | |||||
const float* src = in + (m / PACK_C_SIZE) * ldin + k0 * PACK_C_SIZE; | |||||
memcpy(out, src, cp_length * sizeof(float)); | |||||
out += cp_length; | |||||
} | |||||
} | |||||
//! Pack matrix B for the MK4 4x12 kernel. B is walked in 4-deep K chunks;
//! inside each chunk the full 12-wide column blocks are interleaved first
//! (consumed by kern_4x12), then the 4-wide tail blocks (consumed by
//! kern_4x4). A final block narrower than 4 is zero-padded through
//! `tmpbuff`. NOTE(review): the transpose_1x12_4_s / transpose_1x4_4_s
//! helpers presumably advance `inptr` by reference — confirm in their
//! definitions.
void gi_sgemm_mk4_pack_4x12::pack_B(
        float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax,
        bool transpose_B) const {
    megdnn_assert(!transpose_B);
    megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4");
    //! zero-padding staging area: up to 3 groups * 4 floats fits in 16
    float tmpbuff[16] = {0.0f};
    constexpr int PACK_C_SIZE = 4;
    int ksize = kmax - k0;
    //! total elements of one packed 12-wide (resp. 4-wide) column block
    int ksize12 = ksize * 12;
    int ksize4 = (ksize << 2);
    float* outptr_base = out;
    //! all 4-wide blocks are laid out after the 12-wide blocks
    float* outptr_base4 = outptr_base + (xmax - x0) / 12 * ksize12;
    int k = k0;
    for (; k + 3 < kmax; k += 4) {
        const float* inptr = in + k / PACK_C_SIZE * ldin + x0 * PACK_C_SIZE;
        int x = x0;
        auto outptr = outptr_base;
        //! full 12-wide blocks: jump ksize12 to reach the next block's slot
        for (; x + 12 <= xmax; x += 12) {
            auto outptr_interleave = outptr;
            transpose_1x12_4_s(inptr, outptr_interleave);
            outptr += ksize12;
        }
        //! 4-wide tail blocks
        outptr = outptr_base4;
        for (; x + 4 <= xmax; x += 4) {
            auto outptr_interleave = outptr;
            transpose_1x4_4_s(inptr, outptr_interleave);
            outptr += ksize4;
        }
        //! narrower-than-4 remainder: copy the valid columns into the
        //! zero-initialized buffer and pack it as a full 4-wide block
        if (x < xmax) {
            memcpy(tmpbuff, inptr, sizeof(float) * (xmax - x) * PACK_C_SIZE);
            auto outptr_interleave = outptr;
            const float* tmp_ptr = &tmpbuff[0];
            transpose_1x4_4_s<float>(tmp_ptr, outptr_interleave);
            outptr += ksize4;
        }
        //! advance the bases to this k-chunk's slot inside each block
        outptr_base += 12 * PACK_C_SIZE;
        outptr_base4 += 4 * PACK_C_SIZE;
    }
}
void gi_sgemm_mk4_pack_4x12::kern( | |||||
const float* packA, const float* packB, size_t M, size_t N, size_t K, float* C, | |||||
size_t LDC, bool is_first_k, const float*, float*) const { | |||||
megdnn_assert( | |||||
A_dtype.enumv() == B_dtype.enumv() && A_dtype.enumv() == C_dtype.enumv() && | |||||
A_dtype.enumv() == DTypeEnum::Float32); | |||||
constexpr int PACK_C_SIZE = 4; | |||||
constexpr size_t A_INTERLEAVE = 4; | |||||
constexpr size_t B_INTERLEAVE = 12; | |||||
const int K12 = K * 12; | |||||
const int K4 = K * 4; | |||||
size_t m = 0; | |||||
for (; m < M; m += A_INTERLEAVE) { | |||||
float* output = C + (m / 4 * LDC); | |||||
size_t n = 0; | |||||
const float* cur_packB = packB; | |||||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||||
kern_4x12(packA, cur_packB, K, output, LDC, is_first_k); | |||||
output += PACK_C_SIZE * B_INTERLEAVE; | |||||
cur_packB += K12; | |||||
} | |||||
for (; n < N; n += 4) { | |||||
kern_4x4( | |||||
packA, cur_packB, K, output, LDC, is_first_k, | |||||
std::min<size_t>(N - n, 4)); | |||||
output += PACK_C_SIZE * 4; | |||||
cur_packB += K4; | |||||
} | |||||
packA += K4; | |||||
} | |||||
} | |||||
// vim: syntax=cpp.doxygen |
@@ -28,6 +28,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { | |||||
AlgoNaive naive; | AlgoNaive naive; | ||||
AlgoF32GiGemvMK4 f32_gemv_mk4; | AlgoF32GiGemvMK4 f32_gemv_mk4; | ||||
AlgoF32GiMK4_4x8 f32_mk4_4x8; | AlgoF32GiMK4_4x8 f32_mk4_4x8; | ||||
AlgoF32GiMK4Pack4x12 f32_mk4_gi_pack_4x12; | |||||
AlgoF32Gi4x12 f32_4x8; | AlgoF32Gi4x12 f32_4x8; | ||||
SmallVector<AlgoBase*> m_all_algos; | SmallVector<AlgoBase*> m_all_algos; | ||||
AlgoBase::Mapper m_all_algos_map; | AlgoBase::Mapper m_all_algos_map; | ||||
@@ -36,6 +37,7 @@ public: | |||||
AlgoPack() { | AlgoPack() { | ||||
m_all_algos.emplace_back(&f32_gemv_mk4); | m_all_algos.emplace_back(&f32_gemv_mk4); | ||||
m_all_algos.emplace_back(&f32_mk4_4x8); | m_all_algos.emplace_back(&f32_mk4_4x8); | ||||
m_all_algos.emplace_back(&f32_mk4_gi_pack_4x12); | |||||
m_all_algos.emplace_back(&f32_4x8); | m_all_algos.emplace_back(&f32_4x8); | ||||
m_all_algos.emplace_back(&gemv); | m_all_algos.emplace_back(&gemv); | ||||
m_all_algos.emplace_back(&f32_k8x12x1); | m_all_algos.emplace_back(&f32_k8x12x1); | ||||
@@ -103,6 +103,7 @@ public: | |||||
FB_NAIVE, | FB_NAIVE, | ||||
FB_GI_F32_GEMV_MK4, | FB_GI_F32_GEMV_MK4, | ||||
FB_GI_F32_MK4_4x8, | FB_GI_F32_MK4_4x8, | ||||
FB_GI_F32_MK4_PACK_4x12, | |||||
FB_GI_F32_4x12, | FB_GI_F32_4x12, | ||||
#if MEGDNN_X86 | #if MEGDNN_X86 | ||||
@@ -230,10 +231,11 @@ public: | |||||
}; | }; | ||||
private: | private: | ||||
class AlgoF32K8x12x1; // Fallback F32 Kernel 8x12x1 | |||||
class AlgoF32GiGemvMK4; // fallback F32 gi Gemv NCHW44 | |||||
class AlgoF32GiMK4_4x8; // fallback F32 gi Gemm NCHW44 | |||||
class AlgoF32Gi4x12; // fallback F32 gi Gemm | |||||
class AlgoF32K8x12x1; // Fallback F32 Kernel 8x12x1 | |||||
class AlgoF32GiGemvMK4; // fallback F32 gi Gemv NCHW44 | |||||
class AlgoF32GiMK4_4x8; // fallback F32 gi Gemm NCHW44 | |||||
class AlgoF32GiMK4Pack4x12; // fallback F32 gi Gemm pack NCHW44 | |||||
class AlgoF32Gi4x12; // fallback F32 gi Gemm | |||||
class AlgoGemv; | class AlgoGemv; | ||||
class AlgoNaive; | class AlgoNaive; | ||||
class AlgoPack; | class AlgoPack; | ||||
@@ -364,7 +364,9 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::ADD_BIAS> { | |||||
DType dst_type, size_t N, size_t OC, size_t OH, size_t OW, | DType dst_type, size_t N, size_t OC, size_t OH, size_t OW, | ||||
size_t pack_oc_size = 1) { | size_t pack_oc_size = 1) { | ||||
MEGDNN_MARK_USED_VAR(pack_oc_size); | MEGDNN_MARK_USED_VAR(pack_oc_size); | ||||
megdnn_assert(pack_oc_size == 1, "PostProcess only support nchw in x86"); | |||||
megdnn_assert( | |||||
pack_oc_size == 1 || pack_oc_size == 4, | |||||
"PostProcess only support nchw/44 in x86"); | |||||
megdnn_assert( | megdnn_assert( | ||||
nonlineMode == megdnn::param::ConvBiasV0::NonlineMode::IDENTITY, | nonlineMode == megdnn::param::ConvBiasV0::NonlineMode::IDENTITY, | ||||
"Add bias PostProcess only support IDENTITY"); | "Add bias PostProcess only support IDENTITY"); | ||||
@@ -59,6 +59,11 @@ cb(dt_float32, float, "avx2", float, __m256, mm256, ps, ps, SIMDType::AVX2); | |||||
template <typename ctype, SIMDType simd_type = SIMDType::AVX2> | template <typename ctype, SIMDType simd_type = SIMDType::AVX2> | ||||
struct ParamElemVisitorHalfBoardCast; | struct ParamElemVisitorHalfBoardCast; | ||||
//! some compiler do not define _mm256_set_m128 | |||||
#define _mm256_set_m128ff(xmm1, xmm2) \ | |||||
_mm256_permute2f128_ps( \ | |||||
_mm256_castps128_ps256(xmm1), _mm256_castps128_ps256(xmm2), 2) | |||||
#define cb( \ | #define cb( \ | ||||
_ctype, _simd_ptr_type, load_half_fuc, half_type, _simd_type, board_cast_func) \ | _ctype, _simd_ptr_type, load_half_fuc, half_type, _simd_type, board_cast_func) \ | ||||
template <> \ | template <> \ | ||||
@@ -78,9 +83,10 @@ cb(dt_int32, __m128i, _mm_loadu_si128, __m128i, __m256i, _mm256_set_m128i); | |||||
cb(dt_int16, __m128i, _mm_loadu_si128, __m128i, __m256i, _mm256_set_m128i); | cb(dt_int16, __m128i, _mm_loadu_si128, __m128i, __m256i, _mm256_set_m128i); | ||||
cb(dt_int8, __m128i, _mm_loadu_si128, __m128i, __m256i, _mm256_set_m128i); | cb(dt_int8, __m128i, _mm_loadu_si128, __m128i, __m256i, _mm256_set_m128i); | ||||
cb(dt_uint8, __m128i, _mm_loadu_si128, __m128i, __m256i, _mm256_set_m128i); | cb(dt_uint8, __m128i, _mm_loadu_si128, __m128i, __m256i, _mm256_set_m128i); | ||||
cb(dt_float32, float, _mm_load_ps, __m128, __m256, _mm256_set_m128); | |||||
cb(dt_float32, float, _mm_load_ps, __m128, __m256, _mm256_set_m128ff); | |||||
#undef cb | #undef cb | ||||
#undef _mm256_set_m128ff | |||||
/*! | /*! | ||||
* \brief broadcast type | * \brief broadcast type | ||||
* BCAST_x[0]x[1]...: x[i] == !stride[i] | * BCAST_x[0]x[1]...: x[i] == !stride[i] | ||||
@@ -239,6 +239,74 @@ void checker_conv_bias( | |||||
} | } | ||||
} | } | ||||
TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_GI_1X1_S1_MK4_PACK_F32) { | |||||
using namespace conv_bias; | |||||
std::vector<conv_bias::TestArg> args = | |||||
get_nchw44_conv_bias_args({1}, FULL_NLMODE, ALL_BIASMODE, 1, true); | |||||
check_conv_bias(args, handle(), "CONV1x1:FB_GI_F32_MK4_PACK_4x12:24"); | |||||
} | |||||
TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_GI_IM2COL_S1_MK4_PACK_F32_PREPROCESS) { | |||||
using namespace conv_bias; | |||||
std::vector<conv_bias::TestArg> args = | |||||
get_nchw44_conv_bias_args({2, 4, 7}, FULL_NLMODE, BR_AND_NO_BIASMODE, 1); | |||||
#define cb(name) \ | |||||
check_conv_bias_preprocess( \ | |||||
args, handle(), nullptr, 0.001, dtype::Float32(), dtype::Float32(), \ | |||||
dtype::Float32(), dtype::Float32(), name); | |||||
cb("IM2COLMATMUL:FB_GI_F32_MK4_PACK_4x12"); | |||||
#undef cb | |||||
} | |||||
TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_GI_IM2COL_S2_MK4_PACK_F32_FUSE_PREPROCESS) { | |||||
using namespace conv_bias; | |||||
std::vector<conv_bias::TestArg> args = | |||||
get_nchw44_conv_bias_args({3}, FULL_NLMODE, BR_AND_BIAS_BIASMODE, 2); | |||||
#define cb(name) \ | |||||
check_conv_bias_preprocess( \ | |||||
args, handle(), nullptr, 0.001, dtype::Float32(), dtype::Float32(), \ | |||||
dtype::Float32(), dtype::Float32(), name); | |||||
cb("IM2COLMATMUL:FB_GI_F32_MK4_PACK_4x12"); | |||||
#undef cb | |||||
} | |||||
TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_GI_1X1_S1_MK4_PACK_F32_PREPROCESS) { | |||||
using namespace conv_bias; | |||||
std::vector<conv_bias::TestArg> args = | |||||
get_nchw44_conv_bias_args({1}, FULL_NLMODE, ALL_BIASMODE, 1, true); | |||||
#define cb(name) \ | |||||
check_conv_bias_preprocess( \ | |||||
args, handle(), nullptr, 0.001, dtype::Float32(), dtype::Float32(), \ | |||||
dtype::Float32(), dtype::Float32(), name); | |||||
cb("CONV1x1:FB_GI_F32_MK4_PACK_4x12:24"); | |||||
#undef cb | |||||
} | |||||
TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_GI_IM2COL_S1_MK4_PACK_F32) { | |||||
using namespace conv_bias; | |||||
std::vector<conv_bias::TestArg> args = | |||||
get_nchw44_conv_bias_args({2, 4, 7}, FULL_NLMODE, BR_AND_BIAS_BIASMODE, 1); | |||||
check_conv_bias(args, handle(), "IM2COLMATMUL:FB_GI_F32_MK4_PACK_4x12"); | |||||
} | |||||
TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_GI_IM2COL_S2_MK4_PACK_F32) { | |||||
using namespace conv_bias; | |||||
std::vector<conv_bias::TestArg> args = | |||||
get_nchw44_conv_bias_args({3, 5, 6}, FULL_NLMODE, BR_AND_BIAS_BIASMODE, 2); | |||||
#define cb(name) check_conv_bias(args, handle(), name); | |||||
cb("IM2COLMATMUL:FB_GI_F32_MK4_PACK_4x12"); | |||||
#undef cb | |||||
} | |||||
TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_GI_IM2COL_S2_MK4_PACK_F32_FUSE) { | |||||
using namespace conv_bias; | |||||
std::vector<conv_bias::TestArg> args = | |||||
get_nchw44_conv_bias_args({3}, FULL_NLMODE, ALL_BIASMODE, 2); | |||||
#define cb(name) check_conv_bias(args, handle(), name); | |||||
cb("IM2COLMATMUL:FB_GI_F32_MK4_PACK_4x12"); | |||||
#undef cb | |||||
} | |||||
TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_FORWARD_IM2COL_8X8X16) { | TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_FORWARD_IM2COL_8X8X16) { | ||||
using namespace conv_bias; | using namespace conv_bias; | ||||
param::ConvBias cur_param; | param::ConvBias cur_param; | ||||
@@ -42,12 +42,18 @@ TEST_F(FALLBACK, MATRIX_MUL_MK4_GI) { | |||||
"FB_GI_F32_MK4_4x8", param::MatrixMul::Format::MK4, 1); | "FB_GI_F32_MK4_4x8", param::MatrixMul::Format::MK4, 1); | ||||
} | } | ||||
TEST_F(FALLBACK, MATRIX_MULF_GI_F32_4x12) { | |||||
TEST_F(FALLBACK, MATRIX_MUL_GI_F32_4x12) { | |||||
matrix_mul::check_matrix_mul( | matrix_mul::check_matrix_mul( | ||||
dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), | dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), | ||||
"FB_GI_F32_4x12"); | "FB_GI_F32_4x12"); | ||||
} | } | ||||
TEST_F(FALLBACK, MATRIX_MUL_GI_PACK_MK4) { | |||||
matrix_mul::check_matrix_mul( | |||||
dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), | |||||
"FB_GI_F32_MK4_PACK_4x12", param::MatrixMul::Format::MK4, 1); | |||||
} | |||||
TEST_F(FALLBACK, MATRIX_MUL_RECORD) { | TEST_F(FALLBACK, MATRIX_MUL_RECORD) { | ||||
TaskRecordChecker<MatrixMul> checker(1); | TaskRecordChecker<MatrixMul> checker(1); | ||||
using Param = MatrixMul::Param; | using Param = MatrixMul::Param; | ||||
@@ -163,6 +169,13 @@ TEST_F(FALLBACK, BENCHMARK_MATRIX_MUL_FB_GI_F32_4x12) { | |||||
"FB_GI_F32_4x12", param::MatrixMul::Format::DEFAULT); | "FB_GI_F32_4x12", param::MatrixMul::Format::DEFAULT); | ||||
} | } | ||||
TEST_F(FALLBACK, BENCHMARK_MATRIX_MUL_GI_PACK_MK4) { | |||||
auto args = matrix_mul::get_benchmark_matmul_args(); | |||||
matrix_mul::benchmark_single_algo( | |||||
handle(), args, dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, | |||||
"FB_GI_F32_MK4_PACK_4x12", param::MatrixMul::Format::MK4); | |||||
} | |||||
#endif | #endif | ||||
} // namespace test | } // namespace test | ||||
} // namespace megdnn | } // namespace megdnn | ||||