@@ -352,6 +352,61 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16MK8_8x8::get_kern(
return kern_mk8_8x8;
}
/* ==================== F16_MK8_16x12x1 algo ====================*/ | |||||
bool MatrixMulImpl::AlgoF16MK8_16x12x1::usable( | |||||
const KernSizeParam& kern_size_param) const { | |||||
return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && | |||||
kern_size_param.C_type == kern_size_param.A_type && | |||||
kern_size_param.B_type == kern_size_param.A_type && | |||||
kern_size_param.A_type == dtype::Float16() && | |||||
kern_size_param.format == param::MatrixMul::Format::MK8 && | |||||
!kern_size_param.trA && !kern_size_param.trB; | |||||
} | |||||
size_t MatrixMulImpl::AlgoF16MK8_16x12x1::get_workspace( | |||||
const KernSizeParam& kern_size_param) const { | |||||
MIDOUT_BEGIN( | |||||
megdnn_aarch64_matmul_kern, | |||||
midout_iv("AlgoF16MK8_16x12x1::get_workspace"_hash)) { | |||||
auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K; | |||||
auto trA = kern_size_param.trA, trB = kern_size_param.trB; | |||||
auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, | |||||
C_type = kern_size_param.C_type; | |||||
aarch64::matmul::hgemm_mk8_16x12 strategy(M, N, K, A_type, B_type, C_type); | |||||
return megdnn::matmul::GemmInterleaved<aarch64::matmul::hgemm_mk8_16x12>( | |||||
M, N, K, trA, trB, strategy) | |||||
.get_workspace_size(); | |||||
} | |||||
MIDOUT_END(); | |||||
return 0; | |||||
} | |||||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16MK8_16x12x1::get_kern( | |||||
const KernSizeParam&) const { | |||||
auto kern_mk8_16x12x1 = [](const MatrixMulImpl::KernParam& kern_param) { | |||||
MIDOUT_BEGIN( | |||||
megdnn_aarch64_matmul_kern, | |||||
midout_iv("AlgoF16MK8_16x12x1::get_kern"_hash)) { | |||||
auto M = kern_param.M, N = kern_param.N, K = kern_param.K; | |||||
auto trA = kern_param.trA, trB = kern_param.trB; | |||||
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; | |||||
auto A_type = kern_param.A_type, B_type = kern_param.B_type, | |||||
C_type = kern_param.C_type; | |||||
const auto Aptr = kern_param.A<dt_float16>(), | |||||
Bptr = kern_param.B<dt_float16>(); | |||||
auto Cptr = kern_param.C<dt_float16>(); | |||||
aarch64::matmul::hgemm_mk8_16x12 strategy(M, N, K, A_type, B_type, C_type); | |||||
megdnn::matmul::GemmInterleaved<aarch64::matmul::hgemm_mk8_16x12>( | |||||
M, N, K, trA, trB, strategy) | |||||
.execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr); | |||||
} | |||||
MIDOUT_END(); | |||||
}; | |||||
return kern_mk8_16x12x1; | |||||
} | |||||
#endif
#if MGB_ENABLE_DOT
@@ -86,6 +86,17 @@ public:
MEGDNN_DECL_ALGO_TYPE(AARCH64_F16_MK8_8X8)
};
class MatrixMulImpl::AlgoF16MK8_16x12x1 final : public AlgoBase { | |||||
public: | |||||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||||
const char* name() const override { return "AARCH64_F16_MK8_16X12X1"; } | |||||
bool usable(const KernSizeParam&) const override; | |||||
size_t get_workspace(const KernSizeParam&) const override; | |||||
kern_t get_kern(const KernSizeParam&) const override; | |||||
MEGDNN_OVERRIDE_MATMUL_DESC(16, 12, 1, 2, AlgoDataType::FLOAT16, MK8); | |||||
MEGDNN_DECL_ALGO_TYPE(AARCH64_F16_MK8_16X12X1)
}; | |||||
#endif
#if MGB_ENABLE_DOT
@@ -1143,6 +1143,29 @@ static inline void interleave_2x4_4_s(const T*& inptr0, const T*& inptr1, T* out
}
template <typename T>
static inline void interleave_2x8_2_h(const T*& inptr0, const T*& inptr1, T* outptr) { | |||||
static_assert(sizeof(T) == 2, "interleave_2x8_2_h only support size == 2");
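// Interleave two MK8 row blocks over 8 k-steps: for each k, emit the 8
// halfwords of the first block followed by the 8 halfwords of the second,
// giving 16 contiguous fp16 values per k-step in the packed output.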
asm volatile( | |||||
"ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [%[inptr0]], #64\n" | |||||
"ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [%[inptr0]], #64\n" | |||||
"ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [%[inptr1]], #64\n" | |||||
"ld1 {v12.8h, v13.8h, v14.8h, v15.8h}, [%[inptr1]], #64\n" | |||||
"stp q0, q8, [%[outptr]]\n" | |||||
"stp q1, q9, [%[outptr], #32]\n" | |||||
"stp q2, q10, [%[outptr], #64]\n" | |||||
"stp q3, q11, [%[outptr], #96]\n" | |||||
"stp q4, q12, [%[outptr], #128]\n" | |||||
"stp q5, q13, [%[outptr], #160]\n" | |||||
"stp q6, q14, [%[outptr], #192]\n" | |||||
"stp q7, q15, [%[outptr], #224]\n" | |||||
: [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [outptr] "+r"(outptr) | |||||
: | |||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||||
"v12", "v13", "v14", "v15", "memory"); | |||||
} | |||||
template <typename T> | |||||
static inline void interleave_1x4_4_s(const T*& inptr0, T* outptr) {
static_assert(sizeof(T) == 4, "interleave_1x4_4_s only support size == 4");
asm volatile(
@@ -1155,6 +1178,20 @@ static inline void interleave_1x4_4_s(const T*& inptr0, T* outptr) {
}
template <typename T>
static inline void interleave_1x8_2_h(const T*& inptr0, T* outptr) { | |||||
static_assert(sizeof(T) == 2, "interleave_1x8_2_h only support size == 2");
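// Single MK8 row block: copy 8 k-steps (64 fp16 values) to the packed output unchanged.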
asm volatile( | |||||
"ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [%[inptr0]], #64\n" | |||||
"ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [%[inptr0]], #64\n" | |||||
"st1 {v0.8h, v1.8h, v2.8h, v3.8h}, [%[outptr]], #64\n" | |||||
"st1 {v4.8h, v5.8h, v6.8h, v7.8h}, [%[outptr]], #64\n" | |||||
: [inptr0] "+r"(inptr0), [outptr] "+r"(outptr) | |||||
: | |||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "memory"); | |||||
} | |||||
template <typename T> | |||||
static inline void interleave_4x8_2_b(
const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3,
T*& outptr) {
@@ -1549,6 +1586,49 @@ static inline void transpose_1x12_4_s(const T*& inptr0, T* outptr) {
}
template <typename T>
static inline void transpose_1x12_2_h(const T*& inptr, T* outptr) { | |||||
static_assert(sizeof(T) == 2, "transpose_1x12_2_h only support sizeof(T) == 2");
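// Repack a 12-column x 8-k MK8 block of B from column-major order (8 k-steps
// per column) to k-major order (12 columns per k-step), so the kernel can read
// 12 B values for each k.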
asm volatile( | |||||
"ld4 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[inptr]], #64\n" | |||||
"ld4 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[inptr]], #64\n" | |||||
"ld4 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[inptr]], #64\n" | |||||
"uzp1 v12.8h, v0.8h, v4.8h\n" | |||||
"st1 {v12.8h}, [%[outptr]], #16\n" // line[0][0-7] | |||||
"uzp1 v14.8h, v8.8h, v9.8h\n" | |||||
"st1 {v14.d}[0], [%[outptr]], #8\n" // line[0][8-11] | |||||
"uzp2 v13.8h, v0.8h, v4.8h\n" | |||||
"st1 {v13.8h}, [%[outptr]], #16\n" // line[1][0-7] | |||||
"uzp2 v15.8h, v8.8h, v9.8h\n" | |||||
"st1 {v15.d}[0], [%[outptr]], #8\n" // line[1][8-11] | |||||
"uzp1 v16.8h, v1.8h, v5.8h\n" | |||||
"st1 {v16.8h}, [%[outptr]], #16\n" // line[2][0-7] | |||||
"st1 {v14.d}[1], [%[outptr]], #8\n" // line[2][8-11] | |||||
"uzp2 v17.8h, v1.8h, v5.8h\n" | |||||
"st1 {v17.8h}, [%[outptr]], #16\n" // line[3][0-7] | |||||
"st1 {v15.d}[1], [%[outptr]], #8\n" // line[3][8-11] | |||||
"uzp1 v18.8h, v2.8h, v6.8h\n" | |||||
"st1 {v18.8h}, [%[outptr]], #16\n" // line[4][0-7] | |||||
"uzp1 v19.8h, v10.8h, v11.8h\n" | |||||
"st1 {v19.d}[0], [%[outptr]], #8\n" // line[4][8-11] | |||||
"uzp2 v20.8h, v2.8h, v6.8h\n" | |||||
"st1 {v20.8h}, [%[outptr]], #16\n" // line[5][0-7] | |||||
"uzp2 v21.8h, v10.8h, v11.8h\n" | |||||
"st1 {v21.d}[0], [%[outptr]], #8\n" // line[5][8-11] | |||||
"uzp1 v22.8h, v3.8h, v7.8h\n" | |||||
"st1 {v22.8h}, [%[outptr]], #16\n" // line[6][0-7] | |||||
"st1 {v19.d}[1], [%[outptr]], #8\n" // line[6][8-11] | |||||
"uzp2 v23.8h, v3.8h, v7.8h\n" | |||||
"st1 {v23.8h}, [%[outptr]], #16\n" // line[7][0-7] | |||||
"st1 {v21.d}[1], [%[outptr]], #8\n" // line[7][8-11] | |||||
: [inptr] "+r"(inptr), [outptr] "+r"(outptr) | |||||
: | |||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||||
"v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", | |||||
"v22", "v23", "memory"); | |||||
} | |||||
template <typename T> | |||||
static inline void transpose_1x4_4_s(const T*& inptr0, T* outptr) {
static_assert(sizeof(T) == 4, "transpose_1x4_4_s only support sizeof(T) == 4");
@@ -0,0 +1,160 @@ | |||||
#pragma once | |||||
#include "src/aarch64/matrix_mul/asm/common.h" | |||||
#include "src/aarch64/matrix_mul/fp16/strategy.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/common/utils.h" | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
namespace megdnn { | |||||
namespace aarch64 { | |||||
struct matmul_mk8_16x12 { | |||||
template <size_t M_BLOCK, size_t N_BLOCK> | |||||
static void kern( | |||||
const dt_float16* packedA, const dt_float16* packedB, int K, | |||||
dt_float16* out, int LDC, bool is_first_k); | |||||
static void hgemm_16x12_pack_A( | |||||
dt_float16* outptr, const dt_float16* inptr, int ldin, int y0, int ymax, | |||||
int k0, int kmax) { | |||||
megdnn_assert(y0 % 8 == 0 && ymax % 8 == 0, "M must be a multiple of 8");
megdnn_assert(k0 % 8 == 0 && kmax % 8 == 0, "K must be a multiple of 8");
constexpr int PACK_SIZE_128 = 16 * 8; | |||||
constexpr int PACK_SIZE_64 = 8 * 8; | |||||
constexpr int PACK_C_SIZE = 8; | |||||
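// Pack 16 rows (two MK8 blocks) at a time by interleaving the two blocks per
// k-step; a trailing 8-row block, if any, is copied as a single MK8 block.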
int y = y0; | |||||
for (; y + 15 < ymax; y += 16) { | |||||
const dt_float16* inptr0 = inptr + y / PACK_C_SIZE * ldin + k0; | |||||
const dt_float16* inptr1 = inptr0 + ldin; | |||||
prefetch_4x(inptr0); | |||||
prefetch_4x(inptr1); | |||||
for (int k = k0; k < kmax; k += 8) { | |||||
interleave_2x8_2_h(inptr0, inptr1, outptr); | |||||
outptr += PACK_SIZE_128; | |||||
} | |||||
} | |||||
for (; y < ymax; y += 8) { | |||||
const dt_float16* inptr0 = inptr + y / PACK_C_SIZE * ldin + k0; | |||||
prefetch_4x(inptr0); | |||||
for (int k = k0; k < kmax; k += 8) { | |||||
interleave_1x8_2_h(inptr0, outptr); | |||||
outptr += PACK_SIZE_64; | |||||
} | |||||
} | |||||
} | |||||
static void hgemm_16x12_pack_B( | |||||
dt_float16* out, const dt_float16* in, int ldin, int x0, int xmax, int k0, | |||||
int kmax) { | |||||
megdnn_assert(k0 % 8 == 0 && kmax % 8 == 0, "K must be a multiple of 8");
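// Zero-initialized stack buffer (12 columns x 8 k-steps) used to pad the last
// partial group of B columns so the 12-wide transpose below can still be used.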
dt_float16 tmpbuff[96] = {static_cast<dt_float16>(0.0)}; | |||||
constexpr int PACK_C_SIZE = 8; | |||||
int ksize = kmax - k0; | |||||
int ksize12 = ksize * 12; | |||||
dt_float16* outptr_base = out; | |||||
for (int k = k0; k < kmax; k += 8) { | |||||
const dt_float16* inptr = in + k / PACK_C_SIZE * ldin + x0 * PACK_C_SIZE; | |||||
prefetch_3x(inptr); | |||||
int x = x0; | |||||
auto outptr = outptr_base; | |||||
for (; x + 12 <= xmax; x += 12) { | |||||
auto outptr_interleave = outptr; | |||||
transpose_1x12_2_h(inptr, outptr_interleave); | |||||
outptr += ksize12; | |||||
} | |||||
if (x < xmax) { | |||||
std::memcpy( | |||||
tmpbuff, inptr, sizeof(dt_float16) * (xmax - x) * PACK_C_SIZE); | |||||
auto outptr_interleave = outptr; | |||||
inptr = tmpbuff; | |||||
transpose_1x12_2_h(inptr, outptr_interleave); | |||||
} | |||||
outptr_base += 12 * 8; | |||||
} | |||||
} | |||||
}; | |||||
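// Instantiate matmul_mk8_16x12::kern<M_BLOCK, N_BLOCK> for every supported tile
// shape by re-including the kernel body with different macro values.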
#define M_BLOCK 1 | |||||
#define N_BLOCK 1 | |||||
#include "mk8_16x12_kern.inc" | |||||
#undef N_BLOCK | |||||
#define N_BLOCK 2 | |||||
#include "mk8_16x12_kern.inc" | |||||
#undef N_BLOCK | |||||
#define N_BLOCK 3 | |||||
#include "mk8_16x12_kern.inc" | |||||
#undef N_BLOCK | |||||
#define N_BLOCK 4 | |||||
#include "mk8_16x12_kern.inc" | |||||
#undef N_BLOCK | |||||
#define N_BLOCK 5 | |||||
#include "mk8_16x12_kern.inc" | |||||
#undef N_BLOCK | |||||
#define N_BLOCK 6 | |||||
#include "mk8_16x12_kern.inc" | |||||
#undef N_BLOCK | |||||
#define N_BLOCK 7 | |||||
#include "mk8_16x12_kern.inc" | |||||
#undef N_BLOCK | |||||
#define N_BLOCK 8 | |||||
#include "mk8_16x12_kern.inc" | |||||
#undef N_BLOCK | |||||
#define N_BLOCK 9 | |||||
#include "mk8_16x12_kern.inc" | |||||
#undef N_BLOCK | |||||
#define N_BLOCK 10 | |||||
#include "mk8_16x12_kern.inc" | |||||
#undef N_BLOCK | |||||
#define N_BLOCK 11 | |||||
#include "mk8_16x12_kern.inc" | |||||
#undef N_BLOCK | |||||
#define N_BLOCK 12 | |||||
#include "mk8_16x12_kern.inc" | |||||
#undef N_BLOCK | |||||
#undef M_BLOCK | |||||
#define M_BLOCK 2 | |||||
#define N_BLOCK 1 | |||||
#include "mk8_16x12_kern.inc" | |||||
#undef N_BLOCK | |||||
#define N_BLOCK 2 | |||||
#include "mk8_16x12_kern.inc" | |||||
#undef N_BLOCK | |||||
#define N_BLOCK 3 | |||||
#include "mk8_16x12_kern.inc" | |||||
#undef N_BLOCK | |||||
#define N_BLOCK 4 | |||||
#include "mk8_16x12_kern.inc" | |||||
#undef N_BLOCK | |||||
#define N_BLOCK 5 | |||||
#include "mk8_16x12_kern.inc" | |||||
#undef N_BLOCK | |||||
#define N_BLOCK 6 | |||||
#include "mk8_16x12_kern.inc" | |||||
#undef N_BLOCK | |||||
#define N_BLOCK 7 | |||||
#include "mk8_16x12_kern.inc" | |||||
#undef N_BLOCK | |||||
#define N_BLOCK 8 | |||||
#include "mk8_16x12_kern.inc" | |||||
#undef N_BLOCK | |||||
#define N_BLOCK 9 | |||||
#include "mk8_16x12_kern.inc" | |||||
#undef N_BLOCK | |||||
#define N_BLOCK 10 | |||||
#include "mk8_16x12_kern.inc" | |||||
#undef N_BLOCK | |||||
#define N_BLOCK 11 | |||||
#include "mk8_16x12_kern.inc" | |||||
#undef N_BLOCK | |||||
#define N_BLOCK 12 | |||||
#include "mk8_16x12_kern.inc" | |||||
#undef N_BLOCK | |||||
#undef M_BLOCK | |||||
} // namespace aarch64 | |||||
} // namespace megdnn | |||||
#endif |
@@ -0,0 +1,321 @@ | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
#ifndef _STR | |||||
#define _STR(X) #X | |||||
#endif | |||||
#ifndef STR | |||||
#define STR(X) _STR(X) | |||||
#endif | |||||
template <> | |||||
void matmul_mk8_16x12::kern<M_BLOCK, N_BLOCK>( | |||||
const dt_float16* packedA, const dt_float16* packedB, int K, | |||||
dt_float16* out, int LDC, bool is_first_k) { | |||||
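// The IF_* macros below wrap each instruction in GAS ".if/.endif" conditional
// assembly, so this single asm body emits only the instructions needed for the
// (M_BLOCK, N_BLOCK) tile being instantiated.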
#define IF_M_GT(M, INSTRUC) ".if " STR(M_BLOCK) " > " #M "\n" INSTRUC ".endif\n" | |||||
#define IF_N_GT(N, INSTRUC) ".if " STR(N_BLOCK) " > " #N "\n" INSTRUC ".endif\n" | |||||
// clang-format off | |||||
#define IF_MN_GT(M, N, INSTRUC) \ | |||||
".if " STR(M_BLOCK) " > " #M "\n" \ | |||||
".if " STR(N_BLOCK) " > " #N "\n" \ | |||||
INSTRUC \ | |||||
".endif\n" \ | |||||
".endif\n" | |||||
const dt_float16* a_ptr = packedA; | |||||
const dt_float16* b_ptr = packedB; | |||||
dt_float16* outptr0 = out; | |||||
dt_float16* outptr1 = out + LDC; | |||||
int oddK = (K & 1); | |||||
K = ((K + 1) / 2) - 1; | |||||
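// The main loop (label 3) consumes two k-steps per iteration; K becomes the
// number of main-loop iterations, and the remaining one or two k-steps are
// handled by the odd/even tail paths (labels 5 and 4).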
asm volatile( | |||||
"cmp %w[is_first_k], #1\n" | |||||
"beq 1f\n" | |||||
IF_M_GT(0, "mov x1, %[outptr0]\n") | |||||
IF_M_GT(1, "mov x2, %[outptr1]\n") | |||||
IF_MN_GT(0, 0, "ld1 {v8.8h}, [x1], #16\n") | |||||
IF_MN_GT(0, 1, "ld1 {v9.8h}, [x1], #16\n") | |||||
IF_MN_GT(0, 2, "ld1 {v10.8h}, [x1], #16\n") | |||||
IF_MN_GT(0, 3, "ld1 {v11.8h}, [x1], #16\n") | |||||
IF_MN_GT(0, 4, "ld1 {v12.8h}, [x1], #16\n") | |||||
IF_MN_GT(0, 5, "ld1 {v13.8h}, [x1], #16\n") | |||||
IF_MN_GT(0, 6, "ld1 {v14.8h}, [x1], #16\n") | |||||
IF_MN_GT(0, 7, "ld1 {v15.8h}, [x1], #16\n") | |||||
IF_MN_GT(0, 8, "ld1 {v16.8h}, [x1], #16\n") | |||||
IF_MN_GT(0, 9, "ld1 {v17.8h}, [x1], #16\n") | |||||
IF_MN_GT(0, 10, "ld1 {v18.8h}, [x1], #16\n") | |||||
IF_MN_GT(0, 11, "ld1 {v19.8h}, [x1], #16\n") | |||||
IF_MN_GT(1, 0, "ld1 {v20.8h}, [x2], #16\n") | |||||
IF_MN_GT(1, 1, "ld1 {v21.8h}, [x2], #16\n") | |||||
IF_MN_GT(1, 2, "ld1 {v22.8h}, [x2], #16\n") | |||||
IF_MN_GT(1, 3, "ld1 {v23.8h}, [x2], #16\n") | |||||
IF_MN_GT(1, 4, "ld1 {v24.8h}, [x2], #16\n") | |||||
IF_MN_GT(1, 5, "ld1 {v25.8h}, [x2], #16\n") | |||||
IF_MN_GT(1, 6, "ld1 {v26.8h}, [x2], #16\n") | |||||
IF_MN_GT(1, 7, "ld1 {v27.8h}, [x2], #16\n") | |||||
IF_MN_GT(1, 8, "ld1 {v28.8h}, [x2], #16\n") | |||||
IF_MN_GT(1, 9, "ld1 {v29.8h}, [x2], #16\n") | |||||
IF_MN_GT(1, 10, "ld1 {v30.8h}, [x2], #16\n") | |||||
IF_MN_GT(1, 11, "ld1 {v31.8h}, [x2], #16\n") | |||||
IF_M_GT(0, "ld1 {v0.8h}, [%[a_ptr]], #16\n") | |||||
IF_N_GT(0, "ld1 {v2.8h}, [%[b_ptr]], #16\n") | |||||
"b 2f\n" | |||||
"1:\n" | |||||
IF_MN_GT(0, 0, "eor v8.16b, v8.16b, v8.16b\n") | |||||
IF_MN_GT(0, 1, "eor v9.16b, v9.16b, v9.16b\n") | |||||
IF_MN_GT(0, 2, "eor v10.16b, v10.16b, v10.16b\n") | |||||
"prfm pstl1keep, [%[outptr0]]\n" | |||||
IF_MN_GT(0, 3, "eor v11.16b, v11.16b, v11.16b\n") | |||||
IF_MN_GT(0, 4, "eor v12.16b, v12.16b, v12.16b\n") | |||||
IF_MN_GT(0, 5, "eor v13.16b, v13.16b, v13.16b\n") | |||||
"prfm pstl1keep, [%[outptr1]]\n" | |||||
IF_MN_GT(0, 6, "eor v14.16b, v14.16b, v14.16b\n") | |||||
IF_MN_GT(0, 7, "eor v15.16b, v15.16b, v15.16b\n") | |||||
IF_MN_GT(0, 8, "eor v16.16b, v16.16b, v16.16b\n") | |||||
IF_N_GT(0, "ld1 {v2.8h}, [%[b_ptr]], #16\n") | |||||
IF_MN_GT(0, 9, "eor v17.16b, v17.16b, v17.16b\n") | |||||
IF_MN_GT(0, 10, "eor v18.16b, v18.16b, v18.16b\n") | |||||
IF_MN_GT(0, 11, "eor v19.16b, v19.16b, v19.16b\n") | |||||
IF_MN_GT(1, 0, "eor v20.16b, v20.16b, v20.16b\n") | |||||
IF_MN_GT(1, 1, "eor v21.16b, v21.16b, v21.16b\n") | |||||
IF_M_GT(0, "ld1 {v0.8h}, [%[a_ptr]], #16\n") | |||||
IF_MN_GT(1, 2, "eor v22.16b, v22.16b, v22.16b\n") | |||||
IF_MN_GT(1, 3, "eor v23.16b, v23.16b, v23.16b\n") | |||||
IF_MN_GT(1, 4, "eor v24.16b, v24.16b, v24.16b\n") | |||||
IF_MN_GT(1, 5, "eor v25.16b, v25.16b, v25.16b\n") | |||||
IF_MN_GT(1, 6, "eor v26.16b, v26.16b, v26.16b\n") | |||||
IF_MN_GT(1, 7, "eor v27.16b, v27.16b, v27.16b\n") | |||||
IF_MN_GT(1, 8, "eor v28.16b, v28.16b, v28.16b\n") | |||||
IF_MN_GT(1, 9, "eor v29.16b, v29.16b, v29.16b\n") | |||||
IF_MN_GT(1, 10, "eor v30.16b, v30.16b, v30.16b\n") | |||||
IF_MN_GT(1, 11, "eor v31.16b, v31.16b, v31.16b\n") | |||||
"2:\n" | |||||
"cmp %w[K], #0\n" | |||||
"beq 4f\n" | |||||
"3:\n" | |||||
"ld1 {v3.8h}, [%[b_ptr]], #16\n" | |||||
IF_MN_GT(0, 0, "fmla v8.8h, v0.8h, v2.h[0]\n") | |||||
IF_MN_GT(0, 1, "fmla v9.8h, v0.8h, v2.h[1]\n") | |||||
IF_MN_GT(0, 2, "fmla v10.8h, v0.8h, v2.h[2]\n") | |||||
IF_MN_GT(0, 3, "fmla v11.8h, v0.8h, v2.h[3]\n") | |||||
IF_M_GT(1, "ld1 {v1.8h}, [%[a_ptr]], #16\n") | |||||
IF_MN_GT(0, 4, "fmla v12.8h, v0.8h, v2.h[4]\n") | |||||
IF_MN_GT(0, 5, "fmla v13.8h, v0.8h, v2.h[5]\n") | |||||
IF_MN_GT(0, 6, "fmla v14.8h, v0.8h, v2.h[6]\n") | |||||
IF_MN_GT(0, 7, "fmla v15.8h, v0.8h, v2.h[7]\n") | |||||
IF_M_GT(0, "ld1 {v5.8h}, [%[a_ptr]], #16\n") | |||||
IF_MN_GT(0, 8, "fmla v16.8h, v0.8h, v3.h[0]\n") | |||||
IF_MN_GT(0, 9, "fmla v17.8h, v0.8h, v3.h[1]\n") | |||||
IF_MN_GT(0, 10, "fmla v18.8h, v0.8h, v3.h[2]\n") | |||||
IF_MN_GT(0, 11, "fmla v19.8h, v0.8h, v3.h[3]\n") | |||||
IF_MN_GT(1, 0, "fmla v20.8h, v1.8h, v2.h[0]\n") | |||||
IF_MN_GT(1, 1, "fmla v21.8h, v1.8h, v2.h[1]\n") | |||||
IF_MN_GT(1, 2, "fmla v22.8h, v1.8h, v2.h[2]\n") | |||||
IF_MN_GT(1, 3, "fmla v23.8h, v1.8h, v2.h[3]\n") | |||||
IF_MN_GT(1, 4, "fmla v24.8h, v1.8h, v2.h[4]\n") | |||||
IF_MN_GT(1, 5, "fmla v25.8h, v1.8h, v2.h[5]\n") | |||||
"ld1 {v4.8h}, [%[b_ptr]], #16\n" | |||||
IF_MN_GT(1, 6, "fmla v26.8h, v1.8h, v2.h[6]\n") | |||||
IF_MN_GT(1, 7, "fmla v27.8h, v1.8h, v2.h[7]\n") | |||||
IF_MN_GT(1, 8, "fmla v28.8h, v1.8h, v3.h[0]\n") | |||||
IF_MN_GT(1, 9, "fmla v29.8h, v1.8h, v3.h[1]\n") | |||||
IF_MN_GT(1, 10, "fmla v30.8h, v1.8h, v3.h[2]\n") | |||||
IF_MN_GT(1, 11, "fmla v31.8h, v1.8h, v3.h[3]\n") | |||||
IF_M_GT(1, "ld1 {v6.8h}, [%[a_ptr]], #16\n") | |||||
IF_MN_GT(0, 0, "fmla v8.8h, v5.8h, v3.h[4]\n") | |||||
IF_MN_GT(0, 1, "fmla v9.8h, v5.8h, v3.h[5]\n") | |||||
IF_M_GT(0, "ld1 {v0.8h}, [%[a_ptr]], #16\n") | |||||
IF_MN_GT(0, 2, "fmla v10.8h, v5.8h, v3.h[6]\n") | |||||
IF_MN_GT(0, 3, "fmla v11.8h, v5.8h, v3.h[7]\n") | |||||
IF_MN_GT(0, 4, "fmla v12.8h, v5.8h, v4.h[0]\n") | |||||
IF_MN_GT(0, 5, "fmla v13.8h, v5.8h, v4.h[1]\n") | |||||
"ld1 {v2.8h}, [%[b_ptr]], #16\n" | |||||
IF_MN_GT(0, 6, "fmla v14.8h, v5.8h, v4.h[2]\n") | |||||
IF_MN_GT(0, 7, "fmla v15.8h, v5.8h, v4.h[3]\n") | |||||
IF_MN_GT(0, 8, "fmla v16.8h, v5.8h, v4.h[4]\n") | |||||
IF_MN_GT(0, 9, "fmla v17.8h, v5.8h, v4.h[5]\n") | |||||
IF_MN_GT(0, 10, "fmla v18.8h, v5.8h, v4.h[6]\n") | |||||
IF_MN_GT(0, 11, "fmla v19.8h, v5.8h, v4.h[7]\n") | |||||
IF_MN_GT(1, 0, "fmla v20.8h, v6.8h, v3.h[4]\n") | |||||
IF_MN_GT(1, 1, "fmla v21.8h, v6.8h, v3.h[5]\n") | |||||
IF_MN_GT(1, 2, "fmla v22.8h, v6.8h, v3.h[6]\n") | |||||
IF_MN_GT(1, 3, "fmla v23.8h, v6.8h, v3.h[7]\n") | |||||
IF_MN_GT(1, 4, "fmla v24.8h, v6.8h, v4.h[0]\n") | |||||
IF_MN_GT(1, 5, "fmla v25.8h, v6.8h, v4.h[1]\n") | |||||
"subs %w[K], %w[K], #1\n" | |||||
IF_MN_GT(1, 6, "fmla v26.8h, v6.8h, v4.h[2]\n") | |||||
IF_MN_GT(1, 7, "fmla v27.8h, v6.8h, v4.h[3]\n") | |||||
IF_MN_GT(1, 8, "fmla v28.8h, v6.8h, v4.h[4]\n") | |||||
IF_MN_GT(1, 9, "fmla v29.8h, v6.8h, v4.h[5]\n") | |||||
IF_MN_GT(1, 10, "fmla v30.8h, v6.8h, v4.h[6]\n") | |||||
IF_MN_GT(1, 11, "fmla v31.8h, v6.8h, v4.h[7]\n") | |||||
"bne 3b\n" | |||||
"4:\n" | |||||
"cmp %w[oddK], #1\n" | |||||
"beq 5f\n" | |||||
// even tail | |||||
"ld1 {v3.8h}, [%[b_ptr]], #16\n" | |||||
IF_MN_GT(0, 0, "fmla v8.8h, v0.8h, v2.h[0]\n") | |||||
IF_MN_GT(0, 1, "fmla v9.8h, v0.8h, v2.h[1]\n") | |||||
IF_MN_GT(0, 2, "fmla v10.8h, v0.8h, v2.h[2]\n") | |||||
IF_MN_GT(0, 3, "fmla v11.8h, v0.8h, v2.h[3]\n") | |||||
IF_M_GT(1, "ld1 {v1.8h}, [%[a_ptr]], #16\n") | |||||
IF_MN_GT(0, 4, "fmla v12.8h, v0.8h, v2.h[4]\n") | |||||
IF_MN_GT(0, 5, "fmla v13.8h, v0.8h, v2.h[5]\n") | |||||
IF_MN_GT(0, 6, "fmla v14.8h, v0.8h, v2.h[6]\n") | |||||
IF_MN_GT(0, 7, "fmla v15.8h, v0.8h, v2.h[7]\n") | |||||
IF_M_GT(0, "ld1 {v5.8h}, [%[a_ptr]], #16\n") | |||||
IF_MN_GT(0, 8, "fmla v16.8h, v0.8h, v3.h[0]\n") | |||||
IF_MN_GT(0, 9, "fmla v17.8h, v0.8h, v3.h[1]\n") | |||||
IF_MN_GT(0, 10, "fmla v18.8h, v0.8h, v3.h[2]\n") | |||||
IF_MN_GT(0, 11, "fmla v19.8h, v0.8h, v3.h[3]\n") | |||||
IF_MN_GT(1, 0, "fmla v20.8h, v1.8h, v2.h[0]\n") | |||||
IF_MN_GT(1, 1, "fmla v21.8h, v1.8h, v2.h[1]\n") | |||||
IF_MN_GT(1, 2, "fmla v22.8h, v1.8h, v2.h[2]\n") | |||||
IF_MN_GT(1, 3, "fmla v23.8h, v1.8h, v2.h[3]\n") | |||||
IF_MN_GT(1, 4, "fmla v24.8h, v1.8h, v2.h[4]\n") | |||||
IF_MN_GT(1, 5, "fmla v25.8h, v1.8h, v2.h[5]\n") | |||||
"ld1 {v4.8h}, [%[b_ptr]], #16\n" | |||||
IF_MN_GT(1, 6, "fmla v26.8h, v1.8h, v2.h[6]\n") | |||||
IF_MN_GT(1, 7, "fmla v27.8h, v1.8h, v2.h[7]\n") | |||||
IF_MN_GT(1, 8, "fmla v28.8h, v1.8h, v3.h[0]\n") | |||||
IF_MN_GT(1, 9, "fmla v29.8h, v1.8h, v3.h[1]\n") | |||||
IF_MN_GT(1, 10, "fmla v30.8h, v1.8h, v3.h[2]\n") | |||||
IF_MN_GT(1, 11, "fmla v31.8h, v1.8h, v3.h[3]\n") | |||||
IF_M_GT(1, "ld1 {v6.8h}, [%[a_ptr]], #16\n") | |||||
IF_MN_GT(0, 0, "fmla v8.8h, v5.8h, v3.h[4]\n") | |||||
IF_MN_GT(0, 1, "fmla v9.8h, v5.8h, v3.h[5]\n") | |||||
IF_MN_GT(0, 2, "fmla v10.8h, v5.8h, v3.h[6]\n") | |||||
IF_MN_GT(0, 3, "fmla v11.8h, v5.8h, v3.h[7]\n") | |||||
IF_MN_GT(0, 4, "fmla v12.8h, v5.8h, v4.h[0]\n") | |||||
IF_MN_GT(0, 5, "fmla v13.8h, v5.8h, v4.h[1]\n") | |||||
IF_MN_GT(0, 6, "fmla v14.8h, v5.8h, v4.h[2]\n") | |||||
IF_MN_GT(0, 7, "fmla v15.8h, v5.8h, v4.h[3]\n") | |||||
IF_MN_GT(0, 0, "st1 {v8.8h}, [%[outptr0]], #16\n") | |||||
IF_MN_GT(0, 8, "fmla v16.8h, v5.8h, v4.h[4]\n") | |||||
IF_MN_GT(0, 1, "st1 {v9.8h}, [%[outptr0]], #16\n") | |||||
IF_MN_GT(0, 9, "fmla v17.8h, v5.8h, v4.h[5]\n") | |||||
IF_MN_GT(0, 2, "st1 {v10.8h}, [%[outptr0]], #16\n") | |||||
IF_MN_GT(0, 10, "fmla v18.8h, v5.8h, v4.h[6]\n") | |||||
IF_MN_GT(0, 3, "st1 {v11.8h}, [%[outptr0]], #16\n") | |||||
IF_MN_GT(0, 11, "fmla v19.8h, v5.8h, v4.h[7]\n") | |||||
IF_MN_GT(0, 4, "st1 {v12.8h}, [%[outptr0]], #16\n") | |||||
IF_MN_GT(1, 0, "fmla v20.8h, v6.8h, v3.h[4]\n") | |||||
IF_MN_GT(0, 5, "st1 {v13.8h}, [%[outptr0]], #16\n") | |||||
IF_MN_GT(1, 1, "fmla v21.8h, v6.8h, v3.h[5]\n") | |||||
IF_MN_GT(0, 6, "st1 {v14.8h}, [%[outptr0]], #16\n") | |||||
IF_MN_GT(1, 2, "fmla v22.8h, v6.8h, v3.h[6]\n") | |||||
IF_MN_GT(0, 7, "st1 {v15.8h}, [%[outptr0]], #16\n") | |||||
IF_MN_GT(1, 3, "fmla v23.8h, v6.8h, v3.h[7]\n") | |||||
IF_MN_GT(0, 8, "st1 {v16.8h}, [%[outptr0]], #16\n") | |||||
IF_MN_GT(1, 4, "fmla v24.8h, v6.8h, v4.h[0]\n") | |||||
IF_MN_GT(0, 9, "st1 {v17.8h}, [%[outptr0]], #16\n") | |||||
IF_MN_GT(1, 5, "fmla v25.8h, v6.8h, v4.h[1]\n") | |||||
IF_MN_GT(0, 10, "st1 {v18.8h}, [%[outptr0]], #16\n") | |||||
IF_MN_GT(1, 6, "fmla v26.8h, v6.8h, v4.h[2]\n") | |||||
IF_MN_GT(0, 11, "st1 {v19.8h}, [%[outptr0]], #16\n") | |||||
IF_MN_GT(1, 7, "fmla v27.8h, v6.8h, v4.h[3]\n") | |||||
IF_MN_GT(1, 0, "st1 {v20.8h}, [%[outptr1]], #16\n") | |||||
IF_MN_GT(1, 8, "fmla v28.8h, v6.8h, v4.h[4]\n") | |||||
IF_MN_GT(1, 1, "st1 {v21.8h}, [%[outptr1]], #16\n") | |||||
IF_MN_GT(1, 9, "fmla v29.8h, v6.8h, v4.h[5]\n") | |||||
IF_MN_GT(1, 2, "st1 {v22.8h}, [%[outptr1]], #16\n") | |||||
IF_MN_GT(1, 10, "fmla v30.8h, v6.8h, v4.h[6]\n") | |||||
IF_MN_GT(1, 3, "st1 {v23.8h}, [%[outptr1]], #16\n") | |||||
IF_MN_GT(1, 11, "fmla v31.8h, v6.8h, v4.h[7]\n") | |||||
IF_MN_GT(1, 4, "st1 {v24.8h}, [%[outptr1]], #16\n") | |||||
IF_MN_GT(1, 5, "st1 {v25.8h}, [%[outptr1]], #16\n") | |||||
IF_MN_GT(1, 6, "st1 {v26.8h}, [%[outptr1]], #16\n") | |||||
IF_MN_GT(1, 7, "st1 {v27.8h}, [%[outptr1]], #16\n") | |||||
IF_MN_GT(1, 8, "st1 {v28.8h}, [%[outptr1]], #16\n") | |||||
IF_MN_GT(1, 9, "st1 {v29.8h}, [%[outptr1]], #16\n") | |||||
IF_MN_GT(1, 10, "st1 {v30.8h}, [%[outptr1]], #16\n") | |||||
IF_MN_GT(1, 11, "st1 {v31.8h}, [%[outptr1]], #16\n") | |||||
"b 6f\n" | |||||
"5:\n" | |||||
// odd tail | |||||
"ld1 {v3.4h}, [%[b_ptr]], #8\n" | |||||
IF_MN_GT(0, 0, "fmla v8.8h, v0.8h, v2.h[0]\n") | |||||
IF_MN_GT(0, 1, "fmla v9.8h, v0.8h, v2.h[1]\n") | |||||
IF_MN_GT(0, 2, "fmla v10.8h, v0.8h, v2.h[2]\n") | |||||
IF_MN_GT(0, 3, "fmla v11.8h, v0.8h, v2.h[3]\n") | |||||
IF_MN_GT(0, 4, "fmla v12.8h, v0.8h, v2.h[4]\n") | |||||
IF_MN_GT(0, 5, "fmla v13.8h, v0.8h, v2.h[5]\n") | |||||
IF_M_GT(1, "ld1 {v1.8h}, [%[a_ptr]], #16\n") | |||||
IF_MN_GT(0, 6, "fmla v14.8h, v0.8h, v2.h[6]\n") | |||||
IF_MN_GT(0, 7, "fmla v15.8h, v0.8h, v2.h[7]\n") | |||||
IF_MN_GT(0, 0, "st1 {v8.8h}, [%[outptr0]], #16\n") | |||||
IF_MN_GT(0, 8, "fmla v16.8h, v0.8h, v3.h[0]\n") | |||||
IF_MN_GT(0, 1, "st1 {v9.8h}, [%[outptr0]], #16\n") | |||||
IF_MN_GT(0, 9, "fmla v17.8h, v0.8h, v3.h[1]\n") | |||||
IF_MN_GT(0, 2, "st1 {v10.8h}, [%[outptr0]], #16\n") | |||||
IF_MN_GT(0, 10, "fmla v18.8h, v0.8h, v3.h[2]\n") | |||||
IF_MN_GT(0, 3, "st1 {v11.8h}, [%[outptr0]], #16\n") | |||||
IF_MN_GT(0, 11, "fmla v19.8h, v0.8h, v3.h[3]\n") | |||||
IF_MN_GT(0, 4, "st1 {v12.8h}, [%[outptr0]], #16\n") | |||||
IF_MN_GT(1, 0, "fmla v20.8h, v1.8h, v2.h[0]\n") | |||||
IF_MN_GT(0, 5, "st1 {v13.8h}, [%[outptr0]], #16\n") | |||||
IF_MN_GT(1, 1, "fmla v21.8h, v1.8h, v2.h[1]\n") | |||||
IF_MN_GT(0, 6, "st1 {v14.8h}, [%[outptr0]], #16\n") | |||||
IF_MN_GT(1, 2, "fmla v22.8h, v1.8h, v2.h[2]\n") | |||||
IF_MN_GT(0, 7, "st1 {v15.8h}, [%[outptr0]], #16\n") | |||||
IF_MN_GT(1, 3, "fmla v23.8h, v1.8h, v2.h[3]\n") | |||||
IF_MN_GT(0, 8, "st1 {v16.8h}, [%[outptr0]], #16\n") | |||||
IF_MN_GT(1, 4, "fmla v24.8h, v1.8h, v2.h[4]\n") | |||||
IF_MN_GT(0, 9, "st1 {v17.8h}, [%[outptr0]], #16\n") | |||||
IF_MN_GT(1, 5, "fmla v25.8h, v1.8h, v2.h[5]\n") | |||||
IF_MN_GT(0, 10, "st1 {v18.8h}, [%[outptr0]], #16\n") | |||||
IF_MN_GT(1, 6, "fmla v26.8h, v1.8h, v2.h[6]\n") | |||||
IF_MN_GT(0, 11, "st1 {v19.8h}, [%[outptr0]], #16\n") | |||||
IF_MN_GT(1, 7, "fmla v27.8h, v1.8h, v2.h[7]\n") | |||||
IF_MN_GT(1, 0, "st1 {v20.8h}, [%[outptr1]], #16\n") | |||||
IF_MN_GT(1, 8, "fmla v28.8h, v1.8h, v3.h[0]\n") | |||||
IF_MN_GT(1, 1, "st1 {v21.8h}, [%[outptr1]], #16\n") | |||||
IF_MN_GT(1, 9, "fmla v29.8h, v1.8h, v3.h[1]\n") | |||||
IF_MN_GT(1, 2, "st1 {v22.8h}, [%[outptr1]], #16\n") | |||||
IF_MN_GT(1, 10, "fmla v30.8h, v1.8h, v3.h[2]\n") | |||||
IF_MN_GT(1, 3, "st1 {v23.8h}, [%[outptr1]], #16\n") | |||||
IF_MN_GT(1, 11, "fmla v31.8h, v1.8h, v3.h[3]\n") | |||||
IF_MN_GT(1, 4, "st1 {v24.8h}, [%[outptr1]], #16\n") | |||||
IF_MN_GT(1, 5, "st1 {v25.8h}, [%[outptr1]], #16\n") | |||||
IF_MN_GT(1, 6, "st1 {v26.8h}, [%[outptr1]], #16\n") | |||||
IF_MN_GT(1, 7, "st1 {v27.8h}, [%[outptr1]], #16\n") | |||||
IF_MN_GT(1, 8, "st1 {v28.8h}, [%[outptr1]], #16\n") | |||||
IF_MN_GT(1, 9, "st1 {v29.8h}, [%[outptr1]], #16\n") | |||||
IF_MN_GT(1, 10, "st1 {v30.8h}, [%[outptr1]], #16\n") | |||||
IF_MN_GT(1, 11, "st1 {v31.8h}, [%[outptr1]], #16\n") | |||||
"6:\n" | |||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||||
[is_first_k] "+r"(is_first_k), [oddK] "+r"(oddK), | |||||
[outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1) | |||||
: | |||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v8", "v9", "v10", "v11", "v12", "v13", | |||||
"v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", | |||||
"v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x1", "x2", | |||||
"cc", "memory"); | |||||
#undef IF_MN_GT | |||||
#undef IF_N_GT | |||||
#undef IF_M_GT | |||||
} | |||||
#endif |
@@ -9,6 +9,9 @@ namespace matmul {
MEGDNN_REG_GEMM_STRATEGY(
dt_float16, dt_float16, dt_float16, 8, 24, 1, false, true, hgemm_8x24);
MEGDNN_REG_GEMM_STRATEGY( | |||||
dt_float16, dt_float16, dt_float16, 16, 12, 1, false, false, hgemm_mk8_16x12); | |||||
MEGDNN_REG_GEMM_STRATEGY_NOPACK(
dt_float16, dt_float16, dt_float16, 8, 8, 1, false, true, gemm_nopack_f16_8x8);
@@ -0,0 +1,107 @@
#include "src/aarch64/matrix_mul/fp16/kernel_mk8_16x12.h" | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
using namespace megdnn; | |||||
using namespace aarch64; | |||||
using namespace aarch64::matmul; | |||||
typedef void (*kern_func)( | |||||
const dt_float16*, const dt_float16*, int, dt_float16*, int, bool); | |||||
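// Kernel dispatch table: row index = number of 8-row M blocks in the tile minus
// one (8 or 16 rows), column index = number of output columns minus one (1-12).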
static kern_func kern_func_table[2][12] = { | |||||
{matmul_mk8_16x12::kern<1, 1>, matmul_mk8_16x12::kern<1, 2>, | |||||
matmul_mk8_16x12::kern<1, 3>, matmul_mk8_16x12::kern<1, 4>, | |||||
matmul_mk8_16x12::kern<1, 5>, matmul_mk8_16x12::kern<1, 6>, | |||||
matmul_mk8_16x12::kern<1, 7>, matmul_mk8_16x12::kern<1, 8>, | |||||
matmul_mk8_16x12::kern<1, 9>, matmul_mk8_16x12::kern<1, 10>, | |||||
matmul_mk8_16x12::kern<1, 11>, matmul_mk8_16x12::kern<1, 12>}, | |||||
{matmul_mk8_16x12::kern<2, 1>, matmul_mk8_16x12::kern<2, 2>, | |||||
matmul_mk8_16x12::kern<2, 3>, matmul_mk8_16x12::kern<2, 4>, | |||||
matmul_mk8_16x12::kern<2, 5>, matmul_mk8_16x12::kern<2, 6>, | |||||
matmul_mk8_16x12::kern<2, 7>, matmul_mk8_16x12::kern<2, 8>, | |||||
matmul_mk8_16x12::kern<2, 9>, matmul_mk8_16x12::kern<2, 10>, | |||||
matmul_mk8_16x12::kern<2, 11>, matmul_mk8_16x12::kern<2, 12>}}; | |||||
MEGDNN_REG_GEMM_STRATEGY_IMPL(hgemm_mk8_16x12); | |||||
void hgemm_mk8_16x12::pack_A( | |||||
dt_float16* out, const dt_float16* in, int ldin, int y0, int ymax, int k0, | |||||
int kmax, bool transpose_A) const { | |||||
megdnn_assert(!transpose_A, "mk8 float16 matmul does not support transposed A");
matmul_mk8_16x12::hgemm_16x12_pack_A(out, in, ldin, y0, ymax, k0, kmax); | |||||
} | |||||
void hgemm_mk8_16x12::pack_B( | |||||
dt_float16* out, const dt_float16* in, int ldin, int x0, int xmax, int k0, | |||||
int kmax, bool transpose_B) const { | |||||
megdnn_assert(!transpose_B, "mk8 float16 matmul does not support transposed B");
matmul_mk8_16x12::hgemm_16x12_pack_B(out, in, ldin, x0, xmax, k0, kmax); | |||||
} | |||||
// Overview of register layout:
//
// A 12x2 cell of Rhs is stored in 16bit in q2-q4.
// A 2x16 cell of Lhs is stored in 16bit in q0-q1 and q5-q6.
// A 12x16 block of accumulators is stored in 16bit in q8-q31.
//
//                                             +----+----+
//                                             | v0 | v1 |
//                                        Lhs  +----+----+
//                                             | v5 | v6 |
//                                             +----+----+
//
//                                             |    |    |
//
//                                        Rhs  |    |    |
//
//  +---------------+---------------+ - - - - +----+----+
//  |v2[0-7] v3[0-3]|v3[4-7] v4[0-7]|         | v8 | v20|
//  |v2[0-7] v3[0-3]|v3[4-7] v4[0-7]|         | v9 | v21|
//  |v2[0-7] v3[0-3]|v3[4-7] v4[0-7]|         | v10| v22|
//  |v2[0-7] v3[0-3]|v3[4-7] v4[0-7]|         | v11| v23|
//  |v2[0-7] v3[0-3]|v3[4-7] v4[0-7]|         | v12| v24|
//  |v2[0-7] v3[0-3]|v3[4-7] v4[0-7]|         | v13| v25|
//  |v2[0-7] v3[0-3]|v3[4-7] v4[0-7]|         | v14| v26|
//  |v2[0-7] v3[0-3]|v3[4-7] v4[0-7]|         | v15| v27|
//  |v2[0-7] v3[0-3]|v3[4-7] v4[0-7]|         | v16| v28|
//  |v2[0-7] v3[0-3]|v3[4-7] v4[0-7]|         | v17| v29|
//  |v2[0-7] v3[0-3]|v3[4-7] v4[0-7]|         | v18| v30|
//  |v2[0-7] v3[0-3]|v3[4-7] v4[0-7]|         | v19| v31|
//  +---------------+---------------+ - - - - +----+----+
//
//                                            Accumulator
void hgemm_mk8_16x12::kern( | |||||
const dt_float16* packedA, const dt_float16* packedB, size_t M, size_t N, | |||||
size_t K, dt_float16* C, size_t LDC, bool is_first_k, const dt_float16*, | |||||
dt_float16*) const { | |||||
megdnn_assert( | |||||
A_dtype.enumv() == B_dtype.enumv() && A_dtype.enumv() == C_dtype.enumv() && | |||||
A_dtype.enumv() == DTypeEnum::Float16); | |||||
const size_t K16 = K * 16; | |||||
const size_t K8 = K * 8; | |||||
const size_t K12 = K * 12; | |||||
constexpr size_t PACK_C_SIZE = 8; | |||||
constexpr size_t A_BLOCK = 16; | |||||
constexpr size_t B_BLOCK = 12; | |||||
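// Walk the output in tiles of up to 16 rows (two MK8 blocks) x 12 columns;
// edge tiles dispatch to the smaller kernels in kern_func_table.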
size_t m = 0; | |||||
for (; m < M; m += A_BLOCK) { | |||||
dt_float16* outptr = C + (m / PACK_C_SIZE * LDC); | |||||
const size_t m_func_idx = std::min<size_t>(M - m, A_BLOCK) / 8 - 1; | |||||
size_t n = 0; | |||||
const dt_float16* cur_packedB = packedB; | |||||
for (; n < N; n += B_BLOCK) { | |||||
const size_t n_func_idx = std::min<size_t>(N - n, B_BLOCK) - 1; | |||||
kern_func_table[m_func_idx][n_func_idx]( | |||||
packedA, cur_packedB, K, outptr, LDC, is_first_k); | |||||
cur_packedB += K12; | |||||
outptr += B_BLOCK * PACK_C_SIZE; | |||||
} | |||||
packedA += (m_func_idx ? K16 : K8); | |||||
} | |||||
} | |||||
#endif |
@@ -15,6 +15,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
AlgoF16K8x24x1 f16_k8x24x1;
AlgoF16MK8_8x8 f16_mk8_8x8;
AlgoF16MK8_16x12x1 f16_mk8_16x12x1;
#endif
#if MGB_ENABLE_DOT
AlgoInt8x8x32K8x12x4DotProd int8x8x32_k8x12x4_dotprod;
@@ -52,6 +53,7 @@ public:
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
m_all_algos.emplace_back(&f16_k8x24x1);
m_all_algos.emplace_back(&f16_mk8_8x8);
m_all_algos.emplace_back(&f16_mk8_16x12x1);
#endif
#if MGB_ENABLE_DOT
m_all_algos.emplace_back(&int8x8x32_k8x12x4_dotprod);
@@ -25,8 +25,9 @@ private:
class AlgoF32MK4_4x16;     // Aarch64 F32 Format MK4 block 16x4
class AlgoF32Gemv;         // Aarch64 F32 Gemv
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
class AlgoF16K8x24x1;      // Aarch64 F16 Kernel 8x24x1
class AlgoF16MK8_8x8;      // Aarch64 F16 Format MK8 block 16x8
class AlgoF16MK8_16x12x1;  // Aarch64 F16 Format MK8 block 16x12x1
#endif
#if MGB_ENABLE_DOT
@@ -136,6 +136,7 @@ public:
AARCH64_F32_GEMV,
AARCH64_F16_K8X24X1,
AARCH64_F16_MK8_8X8,
AARCH64_F16_MK8_16X12X1,
AARCH64_INT8X8X32_K8X12X4_DOTPROD,
AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD,
AARCH64_INT8X8X32_MK4_4X4X16,
@@ -74,6 +74,12 @@ TEST_F(AARCH64, MATRIX_MUL_F16_MK8) {
dtype::Float16{}, dtype::Float16{}, dtype::Float16{}, handle(),
"AARCH64_F16_MK8_8X8", param::MatrixMul::Format::MK8, 1);
}
TEST_F(AARCH64, MATRIX_MUL_F16_MK8_16x12x1) { | |||||
matrix_mul::check_matrix_mul( | |||||
dtype::Float16{}, dtype::Float16{}, dtype::Float16{}, handle(), | |||||
"AARCH64_F16_MK8_16X12X1", param::MatrixMul::Format::MK8, 1); | |||||
} | |||||
#endif
#if MGB_ENABLE_DOT
@@ -790,6 +796,14 @@ TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_F16_MK8) {
"AARCH64_F16_MK8_8X8", param::MatrixMul::Format::MK8, dtype::Float16{},
dtype::Float16{}, dtype::Float16{}, "AARCH64_F16_K8X24X1");
}
TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_F16_MK8_16x12) { | |||||
auto args = matrix_mul::get_benchmark_matmul_mk_packed_args(8); | |||||
matrix_mul::benchmark_with_contrast( | |||||
handle(), args, dtype::Float16{}, dtype::Float16{}, dtype::Float16{}, | |||||
"AARCH64_F16_MK8_16X12X1", param::MatrixMul::Format::MK8, dtype::Float16{}, | |||||
dtype::Float16{}, dtype::Float16{}, "AARCH64_F16_K8X24X1"); | |||||
} | |||||
#endif
TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_INT16x16x32) {