add AlgoF32GiMK4Pack4x12 matrix_mul algo
GitOrigin-RevId: 47cfe1d733
release-1.10
@@ -197,10 +197,17 @@ bool ConvBiasImpl::AlgoConv1x1::usable( | |||
return false; | |||
} | |||
} | |||
#else //! x86 only support nchw mode | |||
if (format != param::ConvBias::Format::NCHW) { | |||
#else //! x86 and RISC-V do not support NCHW44_DOT | |||
if (format != param::ConvBias::Format::NCHW && | |||
format != param::ConvBias::Format::NCHW44) { | |||
return false; | |||
} | |||
//! hybird mode is not support | |||
if (param.filter_meta.format == param::ConvBias::Format::NCHW44) { | |||
if (param.filter_meta.icpg < 4_z || param.filter_meta.ocpg == 1) { | |||
return false; | |||
} | |||
} | |||
#endif | |||
//! param | |||
if (FH != 1 || FW != 1 || PH || PW || SH != 1 || SW != 1) { | |||
@@ -345,9 +345,21 @@ bool ConvBiasImpl::AlgoIm2col::usable( | |||
} | |||
} | |||
#else | |||
if (format != param::ConvBias::Format::NCHW) { | |||
if (format != param::ConvBias::Format::NCHW && | |||
format != param::ConvBias::Format::NCHW44) { | |||
return false; | |||
} | |||
if (format == param::ConvBias::Format::NCHW44) { | |||
//! current NCHW44 im2col only support DEFAULT mode matmul | |||
if (matmul_desc.packmode != Pack_Mode::DEFAULT) { | |||
return false; | |||
//! nchw44 hybird mode and channel wise is not support | |||
} else if ( | |||
param.filter_meta.icpg < 4_z || param.filter_meta.icpg == 1 || | |||
param.filter_meta.ocpg == 1) { | |||
return false; | |||
} | |||
} | |||
#endif | |||
if (param.src_type.enumv() != param.filter_type.enumv() || | |||
(param.src_type.enumv() != DTypeEnum::Int8 && | |||
@@ -216,10 +216,9 @@ public: | |||
cb1(NCHW, DEFAULT, dt_float32, dt_float32, PostprocessMode::FLOAT, | |||
"DefaultStrategyType::FLOAT"_hash); | |||
} else if (format == param::ConvBias::Format::NCHW44) { | |||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||
auto matmul_block = matmul_algo->get_inner_block_size(); | |||
//! Optimize NCHW44 3x3s2 aarch64 8X12X1 and armv7 4x12x1 | |||
//! im2col+pack fuse | |||
//! Optimize NCHW44 3x3s2 on aarch64 8X12X4 and fallback/armv7 | |||
//! 4x12x4 im2col+pack fuse | |||
if ((matmul_block.m == 8 || matmul_block.m == 4) && | |||
matmul_block.n == 12 && matmul_block.k == 1 && | |||
param.filter_meta.spatial[0] == 3 && | |||
@@ -236,7 +235,6 @@ public: | |||
MIDOUT_END(); | |||
return {}; | |||
} | |||
#endif | |||
cb1(NCHW44, DEFAULT, dt_float32, dt_float32, PostprocessMode::FLOAT, | |||
"DefaultStrategyTypeNCHW44::FLOAT"_hash); | |||
@@ -530,7 +530,6 @@ public: | |||
}; | |||
#endif | |||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||
template < | |||
typename op_ctype, typename op_dtype, megdnn::PostprocessMode postprocess_mode> | |||
class StrategyFuseXx12x1Nchw44K3x3S2 | |||
@@ -553,7 +552,6 @@ public: | |||
fallback::MatrixMulImpl::KernParam matmul_param, | |||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; | |||
}; | |||
#endif | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -1,7 +1,6 @@ | |||
#include "src/fallback/conv_bias/im2col/strategy_base.h" | |||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||
#include <arm_neon.h> | |||
#include "src/fallback/general_intrinsic/gi_float.h" | |||
using namespace megdnn; | |||
@@ -11,32 +10,32 @@ namespace { | |||
int out_index = 0; \ | |||
outptr = output_base; \ | |||
for (; out_index + 11 < block_size; out_index += 12) { \ | |||
float32x4x4_t v0 = vld4q_f32(tmp_output); \ | |||
float32x4x4_t v1 = vld4q_f32(tmp_output + 16); \ | |||
float32x4x4_t v2 = vld4q_f32(tmp_output + 32); \ | |||
vst1q_f32(outptr, v0.val[0]); \ | |||
vst1q_f32(outptr + 4, v1.val[0]); \ | |||
vst1q_f32(outptr + 8, v2.val[0]); \ | |||
vst1q_f32(outptr + 12, v0.val[1]); \ | |||
vst1q_f32(outptr + 16, v1.val[1]); \ | |||
vst1q_f32(outptr + 20, v2.val[1]); \ | |||
vst1q_f32(outptr + 24, v0.val[2]); \ | |||
vst1q_f32(outptr + 28, v1.val[2]); \ | |||
vst1q_f32(outptr + 32, v2.val[2]); \ | |||
vst1q_f32(outptr + 36, v0.val[3]); \ | |||
vst1q_f32(outptr + 40, v1.val[3]); \ | |||
vst1q_f32(outptr + 44, v2.val[3]); \ | |||
GI_FLOAT32_V4_t v0 = GiLoadUzipFloat32V4(tmp_output); \ | |||
GI_FLOAT32_V4_t v1 = GiLoadUzipFloat32V4(tmp_output + 16); \ | |||
GI_FLOAT32_V4_t v2 = GiLoadUzipFloat32V4(tmp_output + 32); \ | |||
GiStoreFloat32(outptr, GiGetSubVectorFloat32V4(v0, 0)); \ | |||
GiStoreFloat32(outptr + 4, GiGetSubVectorFloat32V4(v1, 0)); \ | |||
GiStoreFloat32(outptr + 8, GiGetSubVectorFloat32V4(v2, 0)); \ | |||
GiStoreFloat32(outptr + 12, GiGetSubVectorFloat32V4(v0, 1)); \ | |||
GiStoreFloat32(outptr + 16, GiGetSubVectorFloat32V4(v1, 1)); \ | |||
GiStoreFloat32(outptr + 20, GiGetSubVectorFloat32V4(v2, 1)); \ | |||
GiStoreFloat32(outptr + 24, GiGetSubVectorFloat32V4(v0, 2)); \ | |||
GiStoreFloat32(outptr + 28, GiGetSubVectorFloat32V4(v1, 2)); \ | |||
GiStoreFloat32(outptr + 32, GiGetSubVectorFloat32V4(v2, 2)); \ | |||
GiStoreFloat32(outptr + 36, GiGetSubVectorFloat32V4(v0, 3)); \ | |||
GiStoreFloat32(outptr + 40, GiGetSubVectorFloat32V4(v1, 3)); \ | |||
GiStoreFloat32(outptr + 44, GiGetSubVectorFloat32V4(v2, 3)); \ | |||
outptr += ksize12; \ | |||
tmp_output += 48; \ | |||
} \ | |||
\ | |||
outptr = output_base4; \ | |||
for (; out_index + 3 < block_size; out_index += 4) { \ | |||
float32x4x4_t v0 = vld4q_f32(tmp_output); \ | |||
vst1q_f32(outptr, v0.val[0]); \ | |||
vst1q_f32(outptr + 4, v0.val[1]); \ | |||
vst1q_f32(outptr + 8, v0.val[2]); \ | |||
vst1q_f32(outptr + 12, v0.val[3]); \ | |||
GI_FLOAT32_V4_t v0 = GiLoadUzipFloat32V4(tmp_output); \ | |||
GiStoreFloat32(outptr, GiGetSubVectorFloat32V4(v0, 0)); \ | |||
GiStoreFloat32(outptr + 4, GiGetSubVectorFloat32V4(v0, 1)); \ | |||
GiStoreFloat32(outptr + 8, GiGetSubVectorFloat32V4(v0, 2)); \ | |||
GiStoreFloat32(outptr + 12, GiGetSubVectorFloat32V4(v0, 3)); \ | |||
outptr += ksize4; \ | |||
tmp_output += 16; \ | |||
} \ | |||
@@ -45,23 +44,23 @@ namespace { | |||
float zerobuffer[16] = {0}; \ | |||
size_t out_remain = std::min(block_size - out_index, 4); \ | |||
std::memcpy(zerobuffer, tmp_output, out_remain * sizeof(float) * 4); \ | |||
float32x4x4_t v0 = vld4q_f32(zerobuffer); \ | |||
vst1q_f32(outptr, v0.val[0]); \ | |||
vst1q_f32(outptr + 4, v0.val[1]); \ | |||
vst1q_f32(outptr + 8, v0.val[2]); \ | |||
vst1q_f32(outptr + 12, v0.val[3]); \ | |||
GI_FLOAT32_V4_t v0 = GiLoadUzipFloat32V4(zerobuffer); \ | |||
GiStoreFloat32(outptr, GiGetSubVectorFloat32V4(v0, 0)); \ | |||
GiStoreFloat32(outptr + 4, GiGetSubVectorFloat32V4(v0, 1)); \ | |||
GiStoreFloat32(outptr + 8, GiGetSubVectorFloat32V4(v0, 2)); \ | |||
GiStoreFloat32(outptr + 12, GiGetSubVectorFloat32V4(v0, 3)); \ | |||
} \ | |||
output_base += 48; \ | |||
output_base4 += 16; | |||
#define LOAD_AND_STOR_IM2COL_DST() \ | |||
float32x4_t v1 = vld1q_f32(&src[index + 4]); \ | |||
float32x4_t v2 = vld1q_f32(&src[index + 8]); \ | |||
vst1q_f32(&output0[i], v0); \ | |||
vst1q_f32(&output1[i], v1); \ | |||
vst1q_f32(&output2[i], v2); \ | |||
i += 4; \ | |||
index += 8; \ | |||
#define LOAD_AND_STOR_IM2COL_DST() \ | |||
GI_FLOAT32_t v1 = GiLoadFloat32(&src[index + 4]); \ | |||
GI_FLOAT32_t v2 = GiLoadFloat32(&src[index + 8]); \ | |||
GiStoreFloat32(&output0[i], v0); \ | |||
GiStoreFloat32(&output1[i], v1); \ | |||
GiStoreFloat32(&output2[i], v2); \ | |||
i += 4; \ | |||
index += 8; \ | |||
v0 = v2; | |||
void fuse_packb( | |||
@@ -94,12 +93,12 @@ void fuse_packb( | |||
size_t index = 4 * (ic * IH * IW + (start_h * SH + fh) * IW + | |||
cur_remain_w * SW); | |||
for (int w = cur_remain_w; w < end_remain_w; w++) { | |||
vst1q_f32(&output02[i], vld1q_f32(&src[index])); | |||
vst1q_f32(&output1[i], vld1q_f32(&src[index + 4])); | |||
GiStoreFloat32(&output02[i], GiLoadFloat32(&src[index])); | |||
GiStoreFloat32(&output1[i], GiLoadFloat32(&src[index + 4])); | |||
i += 4; | |||
index += 8; | |||
} | |||
vst1q_f32(&output02[i], vld1q_f32(&src[index])); | |||
GiStoreFloat32(&output02[i], GiLoadFloat32(&src[index])); | |||
float* output[3]; | |||
output[0] = output02; | |||
output[1] = output1; | |||
@@ -120,19 +119,19 @@ void fuse_packb( | |||
size_t index = 4 * (ic * IH * IW + (start_h * SH + fh) * IW + | |||
(cur_remain_w * SW)); | |||
float32x4_t v0 = vld1q_f32(&src[index]); | |||
GI_FLOAT32_t v0 = GiLoadFloat32(&src[index]); | |||
for (int w = cur_remain_w; w < OW; w++) { | |||
LOAD_AND_STOR_IM2COL_DST(); | |||
} | |||
for (int h = start_h + 1; h < end_h; h++) { | |||
size_t index = 4 * (ic * IH * IW + (h * SH + fh) * IW); | |||
v0 = vld1q_f32(&src[index]); | |||
v0 = GiLoadFloat32(&src[index]); | |||
rep(ow, OW) { LOAD_AND_STOR_IM2COL_DST(); } | |||
} | |||
index = 4 * (ic * IH * IW + (end_h * SH + fh) * IW); | |||
v0 = vld1q_f32(&src[index]); | |||
v0 = GiLoadFloat32(&src[index]); | |||
for (int w = 0; w < end_remain_w; w++) { | |||
LOAD_AND_STOR_IM2COL_DST(); | |||
} | |||
@@ -190,6 +189,4 @@ template class StrategyFuseXx12x1Nchw44K3x3S2< | |||
float, float, megdnn::PostprocessMode::FLOAT>; | |||
} // namespace megdnn | |||
#endif | |||
// vim: syntax=cpp.doxygen |
@@ -57,7 +57,7 @@ void kern_naive(const MatrixMulImpl::KernParam& kern_param) { | |||
size_t pack_size = get_pack_size(); | |||
megdnn_assert( | |||
(M % pack_size == 0 && K % pack_size == 0), | |||
"M and N must time of pack_size M: %zu N: %zu pack_size: %zu", M, N, | |||
"M and K must time of pack_size M: %zu K: %zu pack_size: %zu", M, N, | |||
pack_size); | |||
#define DISPATCH(TA, TB) \ | |||
@@ -263,12 +263,15 @@ void gi_f32_mk4_4x8_kern(const MatrixMulImpl::KernParam& kern_param) { | |||
} // anonymous namespace | |||
bool MatrixMulImpl::AlgoF32GiMK4_4x8::usable( | |||
const KernSizeParam& kern_size_param) const { | |||
constexpr size_t MB = 4; | |||
constexpr size_t KB = 4; | |||
return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && | |||
kern_size_param.format == param::MatrixMul::Format::MK4 && | |||
kern_size_param.B_type == kern_size_param.A_type && | |||
kern_size_param.C_type == kern_size_param.A_type && | |||
kern_size_param.A_type == dtype::Float32() && !kern_size_param.trA && | |||
!kern_size_param.trB; | |||
!kern_size_param.trB && kern_size_param.M % MB == 0 && | |||
kern_size_param.K % KB == 0; | |||
} | |||
size_t MatrixMulImpl::AlgoF32GiMK4_4x8::get_workspace( | |||
@@ -295,6 +298,71 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32GiMK4_4x8::get_kern( | |||
return gi_f32_mk4_4x8_kern; | |||
} | |||
/* ===================== F32 algo gi mk4 pack K4x12 ===================== */ | |||
namespace { | |||
void f32_gi_mk4_pack_4x12_kern(const MatrixMulImpl::KernParam& kern_param) { | |||
MIDOUT_BEGIN( | |||
megdnn_fb_gi_matmul_kern, midout_iv("f32_gi_mk4_pack_4x12_kern"_hash)) { | |||
auto M = kern_param.M, N = kern_param.N, K = kern_param.K; | |||
auto trA = kern_param.trA, trB = kern_param.trB; | |||
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; | |||
auto A_type = kern_param.A_type, B_type = kern_param.B_type, | |||
C_type = kern_param.C_type; | |||
const auto Aptr = kern_param.A<float>(), Bptr = kern_param.B<float>(); | |||
auto Cptr = kern_param.C<float>(); | |||
matmul::fallback::gi_sgemm_mk4_pack_4x12 strategy( | |||
M, N, K, A_type, B_type, C_type); | |||
megdnn::matmul::GemmInterleaved<matmul::fallback::gi_sgemm_mk4_pack_4x12>( | |||
M, N, K, trA, trB, strategy) | |||
.execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr); | |||
} | |||
MIDOUT_END(); | |||
} | |||
} // anonymous namespace | |||
bool MatrixMulImpl::AlgoF32GiMK4Pack4x12::usable( | |||
const KernSizeParam& kern_size_param) const { | |||
return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && | |||
kern_size_param.format == param::MatrixMul::Format::MK4 && | |||
kern_size_param.B_type == kern_size_param.A_type && | |||
kern_size_param.C_type == kern_size_param.A_type && | |||
kern_size_param.A_type == dtype::Float32() && !kern_size_param.trA && | |||
!kern_size_param.trB && kern_size_param.M % 4 == 0 && | |||
kern_size_param.K % 4 == 0 && !kern_size_param.trA && !kern_size_param.trB; | |||
} | |||
size_t MatrixMulImpl::AlgoF32GiMK4Pack4x12::get_workspace( | |||
const KernSizeParam& kern_size_param) const { | |||
MIDOUT_BEGIN( | |||
megdnn_fb_gi_matmul_kern, | |||
midout_iv("AlgoF32GiMK4Pack4x12::get_workspace"_hash)) { | |||
auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K; | |||
auto trA = kern_size_param.trA, trB = kern_size_param.trB; | |||
auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, | |||
C_type = kern_size_param.C_type; | |||
matmul::fallback::gi_sgemm_mk4_pack_4x12 strategy( | |||
M, N, K, A_type, B_type, C_type); | |||
return megdnn::matmul::GemmInterleaved< | |||
matmul::fallback::gi_sgemm_mk4_pack_4x12>( | |||
M, N, K, trA, trB, strategy) | |||
.get_workspace_size(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32GiMK4Pack4x12::get_kern( | |||
const KernSizeParam&) const { | |||
return f32_gi_mk4_pack_4x12_kern; | |||
} | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( | |||
AlgoF32GiMK4Pack4x12, megdnn_fb_gi_matmul_kern, "AlgoF32GiMK4Pack4x12"_hash, | |||
matmul::fallback::gi_sgemm_mk4_pack_4x12, float, float, AlgoDataType::FLOAT32, | |||
MK4); | |||
/* ===================== F32 algo ===================== */ | |||
namespace { | |||
void f32_kern(const MatrixMulImpl::KernParam& kern_param) { | |||
@@ -97,6 +97,19 @@ public: | |||
MEGDNN_DECL_ALGO_TYPE(FB_GI_F32_MK4_4x8) | |||
}; | |||
class MatrixMulImpl::AlgoF32GiMK4Pack4x12 final : public AlgoBase { | |||
public: | |||
AlgoAttribute attribute() const override { | |||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
} | |||
const char* name() const override { return "FB_GI_F32_MK4_PACK_4x12"; } | |||
bool usable(const KernSizeParam&) const override; | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
MEGDNN_DECL_ALGO_TYPE(FB_GI_F32_MK4_PACK_4x12) | |||
}; | |||
class MatrixMulImpl::AlgoF32Gi4x12 final : public AlgoBase { | |||
public: | |||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||
@@ -9,6 +9,8 @@ MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, true, sgemm_8x12) | |||
MEGDNN_REG_GEMM_STRATEGY_NOPACK( | |||
float, float, float, 4, 8, 1, false, true, gi_sgemm_nopack_4x8); | |||
MEGDNN_REG_GEMM_STRATEGY(float, float, float, 4, 12, 1, false, true, gi_sgemm_4x12); | |||
MEGDNN_REG_GEMM_STRATEGY( | |||
float, float, float, 4, 12, 1, false, false, gi_sgemm_mk4_pack_4x12); | |||
} // namespace fallback | |||
} // namespace matmul | |||
@@ -214,6 +214,113 @@ static GI_FORCEINLINE void transpose_4x4_1_s( | |||
outptr += stride; | |||
} | |||
template <typename T> | |||
static inline void transpose_1x12_4_s(const T*& inptr0, T* outptr) { | |||
static_assert(sizeof(T) == 4, "transpose_1x12_4_s only support sizeof(T) == 4"); | |||
GI_FLOAT32_t tmp_a, tmp_b; | |||
#define LOAD() \ | |||
tmp_a = GiLoadFloat32(inptr0); \ | |||
inptr0 += 4; \ | |||
tmp_b = GiLoadFloat32(inptr0); \ | |||
inptr0 += 4; | |||
LOAD(); | |||
GI_FLOAT32_V2_t d0d1d2d3 = GiZipqFloat32(tmp_a, tmp_b); | |||
LOAD(); | |||
GI_FLOAT32_V2_t d4d5d6d7 = GiZipqFloat32(tmp_a, tmp_b); | |||
LOAD(); | |||
GI_FLOAT32_V2_t d8d9d10d11 = GiZipqFloat32(tmp_a, tmp_b); | |||
LOAD(); | |||
GI_FLOAT32_V2_t d12d13d14d15 = GiZipqFloat32(tmp_a, tmp_b); | |||
LOAD(); | |||
GI_FLOAT32_V2_t d16d17d18d19 = GiZipqFloat32(tmp_a, tmp_b); | |||
LOAD(); | |||
GI_FLOAT32_V2_t d20d21d22d23 = GiZipqFloat32(tmp_a, tmp_b); | |||
#undef LOAD | |||
GiSt1Float32(outptr, GiGetLowFloat32(GiGetSubVectorFloat32V2(d0d1d2d3, 0))); | |||
GiSt1Float32(outptr + 1 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d4d5d6d7, 0))); | |||
GiSt1Float32( | |||
outptr + 2 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d8d9d10d11, 0))); | |||
GiSt1Float32( | |||
outptr + 3 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d12d13d14d15, 0))); | |||
GiSt1Float32( | |||
outptr + 4 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d16d17d18d19, 0))); | |||
GiSt1Float32( | |||
outptr + 5 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d20d21d22d23, 0))); | |||
GiSt1Float32( | |||
outptr + 6 * 2, GiGetHighFloat32(GiGetSubVectorFloat32V2(d0d1d2d3, 0))); | |||
GiSt1Float32( | |||
outptr + 7 * 2, GiGetHighFloat32(GiGetSubVectorFloat32V2(d4d5d6d7, 0))); | |||
GiSt1Float32( | |||
outptr + 8 * 2, GiGetHighFloat32(GiGetSubVectorFloat32V2(d8d9d10d11, 0))); | |||
GiSt1Float32( | |||
outptr + 9 * 2, GiGetHighFloat32(GiGetSubVectorFloat32V2(d12d13d14d15, 0))); | |||
GiSt1Float32( | |||
outptr + 10 * 2, | |||
GiGetHighFloat32(GiGetSubVectorFloat32V2(d16d17d18d19, 0))); | |||
GiSt1Float32( | |||
outptr + 11 * 2, | |||
GiGetHighFloat32(GiGetSubVectorFloat32V2(d20d21d22d23, 0))); | |||
GiSt1Float32( | |||
outptr + 12 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d0d1d2d3, 1))); | |||
GiSt1Float32( | |||
outptr + 13 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d4d5d6d7, 1))); | |||
GiSt1Float32( | |||
outptr + 14 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d8d9d10d11, 1))); | |||
GiSt1Float32( | |||
outptr + 15 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d12d13d14d15, 1))); | |||
GiSt1Float32( | |||
outptr + 16 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d16d17d18d19, 1))); | |||
GiSt1Float32( | |||
outptr + 17 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d20d21d22d23, 1))); | |||
GiSt1Float32( | |||
outptr + 18 * 2, GiGetHighFloat32(GiGetSubVectorFloat32V2(d0d1d2d3, 1))); | |||
GiSt1Float32( | |||
outptr + 19 * 2, GiGetHighFloat32(GiGetSubVectorFloat32V2(d4d5d6d7, 1))); | |||
GiSt1Float32( | |||
outptr + 20 * 2, GiGetHighFloat32(GiGetSubVectorFloat32V2(d8d9d10d11, 1))); | |||
GiSt1Float32( | |||
outptr + 21 * 2, | |||
GiGetHighFloat32(GiGetSubVectorFloat32V2(d12d13d14d15, 1))); | |||
GiSt1Float32( | |||
outptr + 22 * 2, | |||
GiGetHighFloat32(GiGetSubVectorFloat32V2(d16d17d18d19, 1))); | |||
GiSt1Float32( | |||
outptr + 23 * 2, | |||
GiGetHighFloat32(GiGetSubVectorFloat32V2(d20d21d22d23, 1))); | |||
outptr += 23 * 2; | |||
} | |||
template <typename T> | |||
static inline void transpose_1x4_4_s(const T*& inptr0, T* outptr) { | |||
static_assert(sizeof(T) == 4, "transpose_1x4_4_s only support sizeof(T) == 4"); | |||
GI_FLOAT32_t tmp_a, tmp_b; | |||
#define LOAD() \ | |||
tmp_a = GiLoadFloat32(inptr0); \ | |||
inptr0 += 4; \ | |||
tmp_b = GiLoadFloat32(inptr0); \ | |||
inptr0 += 4; | |||
LOAD(); | |||
GI_FLOAT32_V2_t d0d1d2d3 = GiZipqFloat32(tmp_a, tmp_b); | |||
LOAD(); | |||
GI_FLOAT32_V2_t d4d5d6d7 = GiZipqFloat32(tmp_a, tmp_b); | |||
#undef LOAD | |||
GiSt1Float32(outptr, GiGetLowFloat32(GiGetSubVectorFloat32V2(d0d1d2d3, 0))); | |||
GiSt1Float32(outptr + 1 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d4d5d6d7, 0))); | |||
GiSt1Float32( | |||
outptr + 2 * 2, GiGetHighFloat32(GiGetSubVectorFloat32V2(d0d1d2d3, 0))); | |||
GiSt1Float32( | |||
outptr + 3 * 2, GiGetHighFloat32(GiGetSubVectorFloat32V2(d4d5d6d7, 0))); | |||
GiSt1Float32(outptr + 4 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d0d1d2d3, 1))); | |||
GiSt1Float32(outptr + 5 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d4d5d6d7, 1))); | |||
GiSt1Float32( | |||
outptr + 6 * 2, GiGetHighFloat32(GiGetSubVectorFloat32V2(d0d1d2d3, 1))); | |||
GiSt1Float32( | |||
outptr + 7 * 2, GiGetHighFloat32(GiGetSubVectorFloat32V2(d4d5d6d7, 1))); | |||
outptr += 7 * 2; | |||
} | |||
} // namespace fallback | |||
} // namespace matmul | |||
} // namespace megdnn | |||
@@ -0,0 +1,458 @@ | |||
//! risc-v gcc will error report uninitialized var at if/else case when use RVV type | |||
#pragma GCC diagnostic push | |||
#pragma GCC diagnostic ignored "-Wuninitialized" | |||
#ifdef __GNUC__ | |||
#ifndef __has_warning | |||
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" | |||
#else | |||
#if __has_warning("-Wmaybe-uninitialized") | |||
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" | |||
#endif | |||
#endif | |||
#endif | |||
#include "src/fallback/matrix_mul/generic_strategy.h" | |||
#include "src/fallback/matrix_mul/gi/fp32/common.h" | |||
using namespace megdnn; | |||
using namespace matmul::fallback; | |||
namespace { | |||
void kern_4x12( | |||
const float* packA, const float* packB, int K, float* output, int LDC, | |||
bool is_first_k) { | |||
MEGDNN_MARK_USED_VAR(LDC); | |||
const float* a_ptr = packA; | |||
const float* b_ptr = packB; | |||
float* output0 = output; | |||
int oddk = (K & 1); | |||
K = ((K + 1) / 2) - 1; | |||
float* r1 = output; | |||
GI_FLOAT32_t d0d1, d2d3, d4d5, d6d7, d8d9, d10d11, d12d13, d14d15, d16d17, d18d19, | |||
d20d21, d22d23, d24d25, d26d27, d28d29, d30d31; | |||
if (is_first_k) { | |||
d8d9 = GiBroadcastFloat32(0.0f); | |||
d10d11 = GiBroadcastFloat32(0.0f); | |||
d12d13 = GiBroadcastFloat32(0.0f); | |||
d14d15 = GiBroadcastFloat32(0.0f); | |||
d0d1 = GiLoadFloat32(a_ptr); | |||
a_ptr = a_ptr + 4; | |||
d16d17 = GiBroadcastFloat32(0.0f); | |||
d18d19 = GiBroadcastFloat32(0.0f); | |||
d20d21 = GiBroadcastFloat32(0.0f); | |||
d22d23 = GiBroadcastFloat32(0.0f); | |||
d4d5 = GiLoadFloat32(b_ptr); | |||
b_ptr = b_ptr + 4; | |||
d6d7 = GiLoadFloat32(b_ptr); | |||
b_ptr = b_ptr + 4; | |||
d24d25 = GiBroadcastFloat32(0.0f); | |||
d26d27 = GiBroadcastFloat32(0.0f); | |||
d28d29 = GiBroadcastFloat32(0.0f); | |||
d30d31 = GiBroadcastFloat32(0.0f); | |||
} else { | |||
d8d9 = GiLoadFloat32(r1); | |||
r1 = r1 + 4; | |||
d10d11 = GiLoadFloat32(r1); | |||
r1 = r1 + 4; | |||
d12d13 = GiLoadFloat32(r1); | |||
r1 = r1 + 4; | |||
d14d15 = GiLoadFloat32(r1); | |||
r1 = r1 + 4; | |||
d16d17 = GiLoadFloat32(r1); | |||
r1 = r1 + 4; | |||
d18d19 = GiLoadFloat32(r1); | |||
r1 = r1 + 4; | |||
d20d21 = GiLoadFloat32(r1); | |||
r1 = r1 + 4; | |||
d22d23 = GiLoadFloat32(r1); | |||
r1 = r1 + 4; | |||
d24d25 = GiLoadFloat32(r1); | |||
r1 = r1 + 4; | |||
d26d27 = GiLoadFloat32(r1); | |||
r1 = r1 + 4; | |||
d28d29 = GiLoadFloat32(r1); | |||
r1 = r1 + 4; | |||
d30d31 = GiLoadFloat32(r1); | |||
r1 = r1 + 4; | |||
d0d1 = GiLoadFloat32(a_ptr); | |||
a_ptr = a_ptr + 4; | |||
d4d5 = GiLoadFloat32(b_ptr); | |||
b_ptr = b_ptr + 4; | |||
} | |||
for (; K > 0; K--) { | |||
d8d9 = GiSimdFmaLane(d8d9, d0d1, d4d5, 0); | |||
d10d11 = GiSimdFmaLane(d10d11, d0d1, d4d5, 1); | |||
d12d13 = GiSimdFmaLane(d12d13, d0d1, d4d5, 2); | |||
d14d15 = GiSimdFmaLane(d14d15, d0d1, d4d5, 3); | |||
d4d5 = GiLoadFloat32(b_ptr); | |||
b_ptr = b_ptr + 4; | |||
d16d17 = GiSimdFmaLane(d16d17, d0d1, d6d7, 0); | |||
d18d19 = GiSimdFmaLane(d18d19, d0d1, d6d7, 1); | |||
d20d21 = GiSimdFmaLane(d20d21, d0d1, d6d7, 2); | |||
d2d3 = GiLoadFloat32(a_ptr); | |||
a_ptr = a_ptr + 4; | |||
d22d23 = GiSimdFmaLane(d22d23, d0d1, d6d7, 3); | |||
d6d7 = GiLoadFloat32(b_ptr); | |||
b_ptr = b_ptr + 4; | |||
d24d25 = GiSimdFmaLane(d24d25, d0d1, d4d5, 0); | |||
d26d27 = GiSimdFmaLane(d26d27, d0d1, d4d5, 1); | |||
d28d29 = GiSimdFmaLane(d28d29, d0d1, d4d5, 2); | |||
d30d31 = GiSimdFmaLane(d30d31, d0d1, d4d5, 3); | |||
d4d5 = GiLoadFloat32(b_ptr); | |||
b_ptr = b_ptr + 4; | |||
d8d9 = GiSimdFmaLane(d8d9, d2d3, d6d7, 0); | |||
d10d11 = GiSimdFmaLane(d10d11, d2d3, d6d7, 1); | |||
d12d13 = GiSimdFmaLane(d12d13, d2d3, d6d7, 2); | |||
d14d15 = GiSimdFmaLane(d14d15, d2d3, d6d7, 3); | |||
d6d7 = GiLoadFloat32(b_ptr); | |||
b_ptr = b_ptr + 4; | |||
d16d17 = GiSimdFmaLane(d16d17, d2d3, d4d5, 0); | |||
d18d19 = GiSimdFmaLane(d18d19, d2d3, d4d5, 1); | |||
d0d1 = GiLoadFloat32(a_ptr); | |||
a_ptr = a_ptr + 4; | |||
d20d21 = GiSimdFmaLane(d20d21, d2d3, d4d5, 2); | |||
d22d23 = GiSimdFmaLane(d22d23, d2d3, d4d5, 3); | |||
d4d5 = GiLoadFloat32(b_ptr); | |||
b_ptr = b_ptr + 4; | |||
d24d25 = GiSimdFmaLane(d24d25, d2d3, d6d7, 0); | |||
d26d27 = GiSimdFmaLane(d26d27, d2d3, d6d7, 1); | |||
d28d29 = GiSimdFmaLane(d28d29, d2d3, d6d7, 2); | |||
d30d31 = GiSimdFmaLane(d30d31, d2d3, d6d7, 3); | |||
d6d7 = GiLoadFloat32(b_ptr); | |||
b_ptr = b_ptr + 4; | |||
} | |||
if (1 == oddk) { | |||
d8d9 = GiSimdFmaLane(d8d9, d0d1, d4d5, 0); | |||
d10d11 = GiSimdFmaLane(d10d11, d0d1, d4d5, 1); | |||
d12d13 = GiSimdFmaLane(d12d13, d0d1, d4d5, 2); | |||
d14d15 = GiSimdFmaLane(d14d15, d0d1, d4d5, 3); | |||
d4d5 = GiLoadFloat32(b_ptr); | |||
b_ptr = b_ptr + 4; | |||
d16d17 = GiSimdFmaLane(d16d17, d0d1, d6d7, 0); | |||
GiStoreFloat32(output0, d8d9); | |||
output0 = output0 + 4; | |||
GiStoreFloat32(output0, d10d11); | |||
output0 = output0 + 4; | |||
d18d19 = GiSimdFmaLane(d18d19, d0d1, d6d7, 1); | |||
d20d21 = GiSimdFmaLane(d20d21, d0d1, d6d7, 2); | |||
GiStoreFloat32(output0, d12d13); | |||
output0 = output0 + 4; | |||
GiStoreFloat32(output0, d14d15); | |||
output0 = output0 + 4; | |||
d22d23 = GiSimdFmaLane(d22d23, d0d1, d6d7, 3); | |||
d24d25 = GiSimdFmaLane(d24d25, d0d1, d4d5, 0); | |||
GiStoreFloat32(output0, d16d17); | |||
output0 = output0 + 4; | |||
GiStoreFloat32(output0, d18d19); | |||
output0 = output0 + 4; | |||
d26d27 = GiSimdFmaLane(d26d27, d0d1, d4d5, 1); | |||
GiStoreFloat32(output0, d20d21); | |||
output0 = output0 + 4; | |||
GiStoreFloat32(output0, d22d23); | |||
output0 = output0 + 4; | |||
d28d29 = GiSimdFmaLane(d28d29, d0d1, d4d5, 2); | |||
GiStoreFloat32(output0, d24d25); | |||
output0 = output0 + 4; | |||
GiStoreFloat32(output0, d26d27); | |||
output0 = output0 + 4; | |||
d30d31 = GiSimdFmaLane(d30d31, d0d1, d4d5, 3); | |||
GiStoreFloat32(output0, d28d29); | |||
output0 = output0 + 4; | |||
GiStoreFloat32(output0, d30d31); | |||
output0 = output0 + 4; | |||
} else { | |||
d8d9 = GiSimdFmaLane(d8d9, d0d1, d4d5, 0); | |||
d10d11 = GiSimdFmaLane(d10d11, d0d1, d4d5, 1); | |||
d12d13 = GiSimdFmaLane(d12d13, d0d1, d4d5, 2); | |||
d14d15 = GiSimdFmaLane(d14d15, d0d1, d4d5, 3); | |||
d4d5 = GiLoadFloat32(b_ptr); | |||
b_ptr = b_ptr + 4; | |||
d16d17 = GiSimdFmaLane(d16d17, d0d1, d6d7, 0); | |||
d18d19 = GiSimdFmaLane(d18d19, d0d1, d6d7, 1); | |||
d20d21 = GiSimdFmaLane(d20d21, d0d1, d6d7, 2); | |||
d2d3 = GiLoadFloat32(a_ptr); | |||
a_ptr = a_ptr + 4; | |||
d22d23 = GiSimdFmaLane(d22d23, d0d1, d6d7, 3); | |||
d6d7 = GiLoadFloat32(b_ptr); | |||
b_ptr = b_ptr + 4; | |||
d24d25 = GiSimdFmaLane(d24d25, d0d1, d4d5, 0); | |||
d26d27 = GiSimdFmaLane(d26d27, d0d1, d4d5, 1); | |||
d28d29 = GiSimdFmaLane(d28d29, d0d1, d4d5, 2); | |||
d30d31 = GiSimdFmaLane(d30d31, d0d1, d4d5, 3); | |||
d4d5 = GiLoadFloat32(b_ptr); | |||
b_ptr = b_ptr + 4; | |||
d8d9 = GiSimdFmaLane(d8d9, d2d3, d6d7, 0); | |||
d10d11 = GiSimdFmaLane(d10d11, d2d3, d6d7, 1); | |||
d12d13 = GiSimdFmaLane(d12d13, d2d3, d6d7, 2); | |||
d14d15 = GiSimdFmaLane(d14d15, d2d3, d6d7, 3); | |||
d6d7 = GiLoadFloat32(b_ptr); | |||
b_ptr = b_ptr + 4; | |||
d16d17 = GiSimdFmaLane(d16d17, d2d3, d4d5, 0); | |||
d18d19 = GiSimdFmaLane(d18d19, d2d3, d4d5, 1); | |||
GiStoreFloat32(output0, d8d9); | |||
output0 = output0 + 4; | |||
GiStoreFloat32(output0, d10d11); | |||
output0 = output0 + 4; | |||
d20d21 = GiSimdFmaLane(d20d21, d2d3, d4d5, 2); | |||
d22d23 = GiSimdFmaLane(d22d23, d2d3, d4d5, 3); | |||
GiStoreFloat32(output0, d12d13); | |||
output0 = output0 + 4; | |||
GiStoreFloat32(output0, d14d15); | |||
output0 = output0 + 4; | |||
d24d25 = GiSimdFmaLane(d24d25, d2d3, d6d7, 0); | |||
d26d27 = GiSimdFmaLane(d26d27, d2d3, d6d7, 1); | |||
GiStoreFloat32(output0, d16d17); | |||
output0 = output0 + 4; | |||
GiStoreFloat32(output0, d18d19); | |||
output0 = output0 + 4; | |||
d28d29 = GiSimdFmaLane(d28d29, d2d3, d6d7, 2); | |||
d30d31 = GiSimdFmaLane(d30d31, d2d3, d6d7, 3); | |||
GiStoreFloat32(output0, d20d21); | |||
output0 = output0 + 4; | |||
GiStoreFloat32(output0, d22d23); | |||
output0 = output0 + 4; | |||
GiStoreFloat32(output0, d24d25); | |||
output0 = output0 + 4; | |||
GiStoreFloat32(output0, d26d27); | |||
output0 = output0 + 4; | |||
GiStoreFloat32(output0, d28d29); | |||
output0 = output0 + 4; | |||
GiStoreFloat32(output0, d30d31); | |||
output0 = output0 + 4; | |||
} | |||
} | |||
void kern_4x4( | |||
const float* packA, const float* packB, int K, float* output, int LDC, | |||
bool is_first_k, int n_remain) { | |||
MEGDNN_MARK_USED_VAR(LDC); | |||
const float* a_ptr = packA; | |||
const float* b_ptr = packB; | |||
int oddk = (K & 1); | |||
K = ((K + 1) / 2) - 1; | |||
float* r1 = output; | |||
GI_FLOAT32_t d0d1, d2d3, d4d5, d6d7, d8d9, d10d11, d12d13, d14d15; | |||
if (is_first_k) { | |||
d8d9 = GiBroadcastFloat32(0.0f); | |||
d10d11 = GiBroadcastFloat32(0.0f); | |||
d0d1 = GiLoadFloat32(a_ptr); | |||
a_ptr = a_ptr + 4; | |||
d12d13 = GiBroadcastFloat32(0.0f); | |||
d4d5 = GiLoadFloat32(b_ptr); | |||
b_ptr = b_ptr + 4; | |||
d14d15 = GiBroadcastFloat32(0.0f); | |||
} else { | |||
if (n_remain == 4) { | |||
d8d9 = GiLoadFloat32(r1); | |||
r1 = r1 + 4; | |||
d10d11 = GiLoadFloat32(r1); | |||
r1 = r1 + 4; | |||
d12d13 = GiLoadFloat32(r1); | |||
r1 = r1 + 4; | |||
d14d15 = GiLoadFloat32(r1); | |||
r1 = r1 + 4; | |||
} else if (n_remain == 3) { | |||
d8d9 = GiLoadFloat32(r1); | |||
r1 = r1 + 4; | |||
d10d11 = GiLoadFloat32(r1); | |||
r1 = r1 + 4; | |||
d12d13 = GiLoadFloat32(r1); | |||
r1 = r1 + 4; | |||
} else if (n_remain == 2) { | |||
d8d9 = GiLoadFloat32(r1); | |||
r1 = r1 + 4; | |||
d10d11 = GiLoadFloat32(r1); | |||
r1 = r1 + 4; | |||
} else if (n_remain == 1) { | |||
d8d9 = GiLoadFloat32(r1); | |||
r1 = r1 + 4; | |||
} | |||
} | |||
for (; K > 0; K--) { | |||
d8d9 = GiSimdFmaLane(d8d9, d0d1, d4d5, 0); | |||
d2d3 = GiLoadFloat32(a_ptr); | |||
a_ptr = a_ptr + 4; | |||
d10d11 = GiSimdFmaLane(d10d11, d0d1, d4d5, 1); | |||
d6d7 = GiLoadFloat32(b_ptr); | |||
b_ptr = b_ptr + 4; | |||
d12d13 = GiSimdFmaLane(d12d13, d0d1, d4d5, 2); | |||
d14d15 = GiSimdFmaLane(d14d15, d0d1, d4d5, 3); | |||
d4d5 = GiLoadFloat32(b_ptr); | |||
b_ptr = b_ptr + 4; | |||
d8d9 = GiSimdFmaLane(d8d9, d2d3, d6d7, 0); | |||
d10d11 = GiSimdFmaLane(d10d11, d2d3, d6d7, 1); | |||
d0d1 = GiLoadFloat32(a_ptr); | |||
a_ptr = a_ptr + 4; | |||
d12d13 = GiSimdFmaLane(d12d13, d2d3, d6d7, 2); | |||
d14d15 = GiSimdFmaLane(d14d15, d2d3, d6d7, 3); | |||
} | |||
if (1 == oddk) { | |||
d8d9 = GiSimdFmaLane(d8d9, d0d1, d4d5, 0); | |||
d10d11 = GiSimdFmaLane(d10d11, d0d1, d4d5, 1); | |||
d12d13 = GiSimdFmaLane(d12d13, d0d1, d4d5, 2); | |||
d14d15 = GiSimdFmaLane(d14d15, d0d1, d4d5, 3); | |||
} else { | |||
d8d9 = GiSimdFmaLane(d8d9, d0d1, d4d5, 0); | |||
d2d3 = GiLoadFloat32(a_ptr); | |||
a_ptr = a_ptr + 4; | |||
d10d11 = GiSimdFmaLane(d10d11, d0d1, d4d5, 1); | |||
d6d7 = GiLoadFloat32(b_ptr); | |||
b_ptr = b_ptr + 4; | |||
d12d13 = GiSimdFmaLane(d12d13, d0d1, d4d5, 2); | |||
d14d15 = GiSimdFmaLane(d14d15, d0d1, d4d5, 3); | |||
d8d9 = GiSimdFmaLane(d8d9, d2d3, d6d7, 0); | |||
d10d11 = GiSimdFmaLane(d10d11, d2d3, d6d7, 1); | |||
d12d13 = GiSimdFmaLane(d12d13, d2d3, d6d7, 2); | |||
d14d15 = GiSimdFmaLane(d14d15, d2d3, d6d7, 3); | |||
} | |||
if (n_remain == 4) { | |||
GiStoreFloat32(output, d8d9); | |||
output = output + 4; | |||
GiStoreFloat32(output, d10d11); | |||
output = output + 4; | |||
GiStoreFloat32(output, d12d13); | |||
output = output + 4; | |||
GiStoreFloat32(output, d14d15); | |||
output = output + 4; | |||
} else if (n_remain == 3) { | |||
GiStoreFloat32(output, d8d9); | |||
output = output + 4; | |||
GiStoreFloat32(output, d10d11); | |||
output = output + 4; | |||
GiStoreFloat32(output, d12d13); | |||
output = output + 4; | |||
} else if (n_remain == 2) { | |||
GiStoreFloat32(output, d8d9); | |||
output = output + 4; | |||
GiStoreFloat32(output, d10d11); | |||
output = output + 4; | |||
} else if (n_remain == 1) { | |||
GiStoreFloat32(output, d8d9); | |||
output = output + 4; | |||
} | |||
} | |||
} // namespace | |||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gi_sgemm_mk4_pack_4x12); | |||
//! Now no matmul mode of only packB support in conv1x1 and im2col, so just copy | |||
//! the weight | |||
void gi_sgemm_mk4_pack_4x12::pack_A( | |||
float* out, const float* in, int ldin, int y0, int ymax, int k0, int kmax, | |||
bool) const { | |||
megdnn_assert(y0 % 4 == 0 && ymax % 4 == 0, "M must be time of 4"); | |||
megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); | |||
constexpr int PACK_C_SIZE = 4; | |||
size_t cp_length = (kmax - k0) * PACK_C_SIZE; | |||
for (int m = y0; m < ymax; m += 4) { | |||
const float* src = in + (m / PACK_C_SIZE) * ldin + k0 * PACK_C_SIZE; | |||
memcpy(out, src, cp_length * sizeof(float)); | |||
out += cp_length; | |||
} | |||
} | |||
void gi_sgemm_mk4_pack_4x12::pack_B( | |||
float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax, | |||
bool transpose_B) const { | |||
megdnn_assert(!transpose_B); | |||
megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); | |||
float tmpbuff[16] = {0.0f}; | |||
constexpr int PACK_C_SIZE = 4; | |||
int ksize = kmax - k0; | |||
int ksize12 = ksize * 12; | |||
int ksize4 = (ksize << 2); | |||
float* outptr_base = out; | |||
float* outptr_base4 = outptr_base + (xmax - x0) / 12 * ksize12; | |||
int k = k0; | |||
for (; k + 3 < kmax; k += 4) { | |||
const float* inptr = in + k / PACK_C_SIZE * ldin + x0 * PACK_C_SIZE; | |||
int x = x0; | |||
auto outptr = outptr_base; | |||
for (; x + 12 <= xmax; x += 12) { | |||
auto outptr_interleave = outptr; | |||
transpose_1x12_4_s(inptr, outptr_interleave); | |||
outptr += ksize12; | |||
} | |||
outptr = outptr_base4; | |||
for (; x + 4 <= xmax; x += 4) { | |||
auto outptr_interleave = outptr; | |||
transpose_1x4_4_s(inptr, outptr_interleave); | |||
outptr += ksize4; | |||
} | |||
if (x < xmax) { | |||
memcpy(tmpbuff, inptr, sizeof(float) * (xmax - x) * PACK_C_SIZE); | |||
auto outptr_interleave = outptr; | |||
const float* tmp_ptr = &tmpbuff[0]; | |||
transpose_1x4_4_s<float>(tmp_ptr, outptr_interleave); | |||
outptr += ksize4; | |||
} | |||
outptr_base += 12 * PACK_C_SIZE; | |||
outptr_base4 += 4 * PACK_C_SIZE; | |||
} | |||
} | |||
void gi_sgemm_mk4_pack_4x12::kern( | |||
const float* packA, const float* packB, size_t M, size_t N, size_t K, float* C, | |||
size_t LDC, bool is_first_k, const float*, float*) const { | |||
megdnn_assert( | |||
A_dtype.enumv() == B_dtype.enumv() && A_dtype.enumv() == C_dtype.enumv() && | |||
A_dtype.enumv() == DTypeEnum::Float32); | |||
constexpr int PACK_C_SIZE = 4; | |||
constexpr size_t A_INTERLEAVE = 4; | |||
constexpr size_t B_INTERLEAVE = 12; | |||
const int K12 = K * 12; | |||
const int K4 = K * 4; | |||
size_t m = 0; | |||
for (; m < M; m += A_INTERLEAVE) { | |||
float* output = C + (m / 4 * LDC); | |||
size_t n = 0; | |||
const float* cur_packB = packB; | |||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
kern_4x12(packA, cur_packB, K, output, LDC, is_first_k); | |||
output += PACK_C_SIZE * B_INTERLEAVE; | |||
cur_packB += K12; | |||
} | |||
for (; n < N; n += 4) { | |||
kern_4x4( | |||
packA, cur_packB, K, output, LDC, is_first_k, | |||
std::min<size_t>(N - n, 4)); | |||
output += PACK_C_SIZE * 4; | |||
cur_packB += K4; | |||
} | |||
packA += K4; | |||
} | |||
} | |||
// vim: syntax=cpp.doxygen |
@@ -28,6 +28,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { | |||
AlgoNaive naive; | |||
AlgoF32GiGemvMK4 f32_gemv_mk4; | |||
AlgoF32GiMK4_4x8 f32_mk4_4x8; | |||
AlgoF32GiMK4Pack4x12 f32_mk4_gi_pack_4x12; | |||
AlgoF32Gi4x12 f32_4x8; | |||
SmallVector<AlgoBase*> m_all_algos; | |||
AlgoBase::Mapper m_all_algos_map; | |||
@@ -36,6 +37,7 @@ public: | |||
AlgoPack() { | |||
m_all_algos.emplace_back(&f32_gemv_mk4); | |||
m_all_algos.emplace_back(&f32_mk4_4x8); | |||
m_all_algos.emplace_back(&f32_mk4_gi_pack_4x12); | |||
m_all_algos.emplace_back(&f32_4x8); | |||
m_all_algos.emplace_back(&gemv); | |||
m_all_algos.emplace_back(&f32_k8x12x1); | |||
@@ -103,6 +103,7 @@ public: | |||
FB_NAIVE, | |||
FB_GI_F32_GEMV_MK4, | |||
FB_GI_F32_MK4_4x8, | |||
FB_GI_F32_MK4_PACK_4x12, | |||
FB_GI_F32_4x12, | |||
#if MEGDNN_X86 | |||
@@ -230,10 +231,11 @@ public: | |||
}; | |||
private: | |||
class AlgoF32K8x12x1; // Fallback F32 Kernel 8x12x1 | |||
class AlgoF32GiGemvMK4; // fallback F32 gi Gemv NCHW44 | |||
class AlgoF32GiMK4_4x8; // fallback F32 gi Gemm NCHW44 | |||
class AlgoF32Gi4x12; // fallback F32 gi Gemm | |||
class AlgoF32K8x12x1; // Fallback F32 Kernel 8x12x1 | |||
class AlgoF32GiGemvMK4; // fallback F32 gi Gemv NCHW44 | |||
class AlgoF32GiMK4_4x8; // fallback F32 gi Gemm NCHW44 | |||
class AlgoF32GiMK4Pack4x12; // fallback F32 gi Gemm pack NCHW44 | |||
class AlgoF32Gi4x12; // fallback F32 gi Gemm | |||
class AlgoGemv; | |||
class AlgoNaive; | |||
class AlgoPack; | |||
@@ -364,7 +364,9 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::ADD_BIAS> { | |||
DType dst_type, size_t N, size_t OC, size_t OH, size_t OW, | |||
size_t pack_oc_size = 1) { | |||
MEGDNN_MARK_USED_VAR(pack_oc_size); | |||
megdnn_assert(pack_oc_size == 1, "PostProcess only support nchw in x86"); | |||
megdnn_assert( | |||
pack_oc_size == 1 || pack_oc_size == 4, | |||
"PostProcess only support nchw/44 in x86"); | |||
megdnn_assert( | |||
nonlineMode == megdnn::param::ConvBiasV0::NonlineMode::IDENTITY, | |||
"Add bias PostProcess only support IDENTITY"); | |||
@@ -59,6 +59,11 @@ cb(dt_float32, float, "avx2", float, __m256, mm256, ps, ps, SIMDType::AVX2); | |||
template <typename ctype, SIMDType simd_type = SIMDType::AVX2> | |||
struct ParamElemVisitorHalfBoardCast; | |||
//! some compiler do not define _mm256_set_m128 | |||
#define _mm256_set_m128ff(xmm1, xmm2) \ | |||
_mm256_permute2f128_ps( \ | |||
_mm256_castps128_ps256(xmm1), _mm256_castps128_ps256(xmm2), 2) | |||
#define cb( \ | |||
_ctype, _simd_ptr_type, load_half_fuc, half_type, _simd_type, board_cast_func) \ | |||
template <> \ | |||
@@ -78,9 +83,10 @@ cb(dt_int32, __m128i, _mm_loadu_si128, __m128i, __m256i, _mm256_set_m128i); | |||
cb(dt_int16, __m128i, _mm_loadu_si128, __m128i, __m256i, _mm256_set_m128i); | |||
cb(dt_int8, __m128i, _mm_loadu_si128, __m128i, __m256i, _mm256_set_m128i); | |||
cb(dt_uint8, __m128i, _mm_loadu_si128, __m128i, __m256i, _mm256_set_m128i); | |||
cb(dt_float32, float, _mm_load_ps, __m128, __m256, _mm256_set_m128); | |||
cb(dt_float32, float, _mm_load_ps, __m128, __m256, _mm256_set_m128ff); | |||
#undef cb | |||
#undef _mm256_set_m128ff | |||
/*! | |||
* \brief broadcast type | |||
* BCAST_x[0]x[1]...: x[i] == !stride[i] | |||
@@ -239,6 +239,74 @@ void checker_conv_bias( | |||
} | |||
} | |||
TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_GI_1X1_S1_MK4_PACK_F32) { | |||
using namespace conv_bias; | |||
std::vector<conv_bias::TestArg> args = | |||
get_nchw44_conv_bias_args({1}, FULL_NLMODE, ALL_BIASMODE, 1, true); | |||
check_conv_bias(args, handle(), "CONV1x1:FB_GI_F32_MK4_PACK_4x12:24"); | |||
} | |||
TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_GI_IM2COL_S1_MK4_PACK_F32_PREPROCESS) { | |||
using namespace conv_bias; | |||
std::vector<conv_bias::TestArg> args = | |||
get_nchw44_conv_bias_args({2, 4, 7}, FULL_NLMODE, BR_AND_NO_BIASMODE, 1); | |||
#define cb(name) \ | |||
check_conv_bias_preprocess( \ | |||
args, handle(), nullptr, 0.001, dtype::Float32(), dtype::Float32(), \ | |||
dtype::Float32(), dtype::Float32(), name); | |||
cb("IM2COLMATMUL:FB_GI_F32_MK4_PACK_4x12"); | |||
#undef cb | |||
} | |||
TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_GI_IM2COL_S2_MK4_PACK_F32_FUSE_PREPROCESS) { | |||
using namespace conv_bias; | |||
std::vector<conv_bias::TestArg> args = | |||
get_nchw44_conv_bias_args({3}, FULL_NLMODE, BR_AND_BIAS_BIASMODE, 2); | |||
#define cb(name) \ | |||
check_conv_bias_preprocess( \ | |||
args, handle(), nullptr, 0.001, dtype::Float32(), dtype::Float32(), \ | |||
dtype::Float32(), dtype::Float32(), name); | |||
cb("IM2COLMATMUL:FB_GI_F32_MK4_PACK_4x12"); | |||
#undef cb | |||
} | |||
TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_GI_1X1_S1_MK4_PACK_F32_PREPROCESS) { | |||
using namespace conv_bias; | |||
std::vector<conv_bias::TestArg> args = | |||
get_nchw44_conv_bias_args({1}, FULL_NLMODE, ALL_BIASMODE, 1, true); | |||
#define cb(name) \ | |||
check_conv_bias_preprocess( \ | |||
args, handle(), nullptr, 0.001, dtype::Float32(), dtype::Float32(), \ | |||
dtype::Float32(), dtype::Float32(), name); | |||
cb("CONV1x1:FB_GI_F32_MK4_PACK_4x12:24"); | |||
#undef cb | |||
} | |||
TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_GI_IM2COL_S1_MK4_PACK_F32) { | |||
using namespace conv_bias; | |||
std::vector<conv_bias::TestArg> args = | |||
get_nchw44_conv_bias_args({2, 4, 7}, FULL_NLMODE, BR_AND_BIAS_BIASMODE, 1); | |||
check_conv_bias(args, handle(), "IM2COLMATMUL:FB_GI_F32_MK4_PACK_4x12"); | |||
} | |||
TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_GI_IM2COL_S2_MK4_PACK_F32) { | |||
using namespace conv_bias; | |||
std::vector<conv_bias::TestArg> args = | |||
get_nchw44_conv_bias_args({3, 5, 6}, FULL_NLMODE, BR_AND_BIAS_BIASMODE, 2); | |||
#define cb(name) check_conv_bias(args, handle(), name); | |||
cb("IM2COLMATMUL:FB_GI_F32_MK4_PACK_4x12"); | |||
#undef cb | |||
} | |||
TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_GI_IM2COL_S2_MK4_PACK_F32_FUSE) { | |||
using namespace conv_bias; | |||
std::vector<conv_bias::TestArg> args = | |||
get_nchw44_conv_bias_args({3}, FULL_NLMODE, ALL_BIASMODE, 2); | |||
#define cb(name) check_conv_bias(args, handle(), name); | |||
cb("IM2COLMATMUL:FB_GI_F32_MK4_PACK_4x12"); | |||
#undef cb | |||
} | |||
TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_FORWARD_IM2COL_8X8X16) { | |||
using namespace conv_bias; | |||
param::ConvBias cur_param; | |||
@@ -42,12 +42,18 @@ TEST_F(FALLBACK, MATRIX_MUL_MK4_GI) { | |||
"FB_GI_F32_MK4_4x8", param::MatrixMul::Format::MK4, 1); | |||
} | |||
TEST_F(FALLBACK, MATRIX_MULF_GI_F32_4x12) { | |||
TEST_F(FALLBACK, MATRIX_MUL_GI_F32_4x12) { | |||
matrix_mul::check_matrix_mul( | |||
dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), | |||
"FB_GI_F32_4x12"); | |||
} | |||
TEST_F(FALLBACK, MATRIX_MUL_GI_PACK_MK4) { | |||
matrix_mul::check_matrix_mul( | |||
dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), | |||
"FB_GI_F32_MK4_PACK_4x12", param::MatrixMul::Format::MK4, 1); | |||
} | |||
TEST_F(FALLBACK, MATRIX_MUL_RECORD) { | |||
TaskRecordChecker<MatrixMul> checker(1); | |||
using Param = MatrixMul::Param; | |||
@@ -163,6 +169,13 @@ TEST_F(FALLBACK, BENCHMARK_MATRIX_MUL_FB_GI_F32_4x12) { | |||
"FB_GI_F32_4x12", param::MatrixMul::Format::DEFAULT); | |||
} | |||
TEST_F(FALLBACK, BENCHMARK_MATRIX_MUL_GI_PACK_MK4) { | |||
auto args = matrix_mul::get_benchmark_matmul_args(); | |||
matrix_mul::benchmark_single_algo( | |||
handle(), args, dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, | |||
"FB_GI_F32_MK4_PACK_4x12", param::MatrixMul::Format::MK4); | |||
} | |||
#endif | |||
} // namespace test | |||
} // namespace megdnn | |||