GitOrigin-RevId: 0f64b9f70f
tags/v1.0.0-rc1
@@ -214,8 +214,7 @@ bool MatrixMulImpl::AlgoF32MK4_4x16::usable( | |||
kern_size_param.B_type == dtype::Float32() && | |||
kern_size_param.A_type == dtype::Float32() && | |||
kern_size_param.format == param::MatrixMul::Format::MK4 && | |||
!kern_size_param.trA && !kern_size_param.trB && | |||
kern_size_param.N % 4 == 0; | |||
!kern_size_param.trA && !kern_size_param.trB; | |||
} | |||
size_t MatrixMulImpl::AlgoF32MK4_4x16::get_workspace( | |||
@@ -330,8 +329,7 @@ bool MatrixMulImpl::AlgoF16MK8_8x8::usable( | |||
kern_size_param.B_type == kern_size_param.A_type && | |||
kern_size_param.A_type == dtype::Float16() && | |||
kern_size_param.format == param::MatrixMul::Format::MK8 && | |||
!kern_size_param.trA && !kern_size_param.trB && | |||
kern_size_param.N % 4 == 0; | |||
!kern_size_param.trA && !kern_size_param.trB; | |||
} | |||
size_t MatrixMulImpl::AlgoF16MK8_8x8::get_workspace( | |||
@@ -918,8 +916,7 @@ bool MatrixMulImpl::AlgoInt16x16x32MK8_8x8::usable( | |||
kern_size_param.B_type == dtype::Int16() && | |||
kern_size_param.A_type == dtype::Int16() && | |||
kern_size_param.format == param::MatrixMul::Format::MK8 && | |||
!kern_size_param.trA && !kern_size_param.trB && | |||
kern_size_param.N % 4 == 0; | |||
!kern_size_param.trA && !kern_size_param.trB; | |||
} | |||
size_t MatrixMulImpl::AlgoInt16x16x32MK8_8x8::get_workspace( | |||
@@ -21,6 +21,76 @@ using namespace aarch64::matmul; | |||
namespace { | |||
void kern_8x1(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||
dt_float16* output) { | |||
LDB *= sizeof(dt_float16); | |||
asm volatile( | |||
".arch armv8.2-a+fp16\n" | |||
"subs %w[K], %w[K], #8\n" | |||
"ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[a_ptr]], 64\n" | |||
"ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%[a_ptr]], 64\n" | |||
"eor v24.16b, v24.16b, v24.16b\n" | |||
"eor v25.16b, v25.16b, v25.16b\n" | |||
"eor v26.16b, v26.16b, v26.16b\n" | |||
"eor v27.16b, v27.16b, v27.16b\n" | |||
"eor v28.16b, v28.16b, v28.16b\n" | |||
"eor v29.16b, v29.16b, v29.16b\n" | |||
"eor v30.16b, v30.16b, v30.16b\n" | |||
"eor v31.16b, v31.16b, v31.16b\n" | |||
"ld1 {v0.4s}, [%[b_ptr]], %x[LDB]\n" | |||
"fmla v24.8h, v16.8h, v0.h[0]\n" | |||
"fmla v25.8h, v17.8h, v0.h[1]\n" | |||
"fmla v26.8h, v18.8h, v0.h[2]\n" | |||
"fmla v27.8h, v19.8h, v0.h[3]\n" | |||
"beq 2f\n" | |||
"1:\n" | |||
"ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[a_ptr]], 64\n" | |||
"fmla v28.8h, v20.8h, v0.h[4]\n" | |||
"fmla v29.8h, v21.8h, v0.h[5]\n" | |||
"fmla v30.8h, v22.8h, v0.h[6]\n" | |||
"fmla v31.8h, v23.8h, v0.h[7]\n" | |||
"ld1 {v0.4s}, [%[b_ptr]], %x[LDB]\n" | |||
"ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%[a_ptr]], 64\n" | |||
"fmla v24.8h, v16.8h, v0.h[0]\n" | |||
"fmla v25.8h, v17.8h, v0.h[1]\n" | |||
"fmla v26.8h, v18.8h, v0.h[2]\n" | |||
"fmla v27.8h, v19.8h, v0.h[3]\n" | |||
"subs %w[K], %w[K], #8\n" | |||
"bne 1b\n" | |||
"2:\n" | |||
"fmla v28.8h, v20.8h, v0.h[4]\n" | |||
"fmla v29.8h, v21.8h, v0.h[5]\n" | |||
"fmla v30.8h, v22.8h, v0.h[6]\n" | |||
"fmla v31.8h, v23.8h, v0.h[7]\n" | |||
"fadd v24.8h, v24.8h, v25.8h\n" | |||
"fadd v26.8h, v26.8h, v27.8h\n" | |||
"fadd v28.8h, v28.8h, v29.8h\n" | |||
"fadd v30.8h, v30.8h, v31.8h\n" | |||
"fadd v24.8h, v24.8h, v26.8h\n" | |||
"fadd v28.8h, v28.8h, v30.8h\n" | |||
"fadd v24.8h, v24.8h, v28.8h\n" | |||
"st1 {v24.4s}, [%[output]], 16\n" | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
[output] "+r"(output), [LDB] "+r"(LDB) | |||
: | |||
: "v0", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", | |||
"v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", | |||
"memory"); | |||
} | |||
// Overview of register layout: | |||
// | |||
// A 8x1 cell of Rhs is stored in 16bit in v0-v3 | |||
@@ -416,7 +486,7 @@ void gemm_nopack_f16_8x8::kern(const dt_float16* A, size_t LDA, | |||
constexpr static size_t NB = 8; | |||
constexpr static size_t CALCBLK = 4; | |||
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0 && N % CALCBLK == 0); | |||
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0); | |||
//! (m/8, k/8, 8, 8) * (k/8, n, 8) = (m/8, n, 8) | |||
for (size_t m = 0; m < M; m += MB) { | |||
@@ -428,8 +498,17 @@ void gemm_nopack_f16_8x8::kern(const dt_float16* A, size_t LDA, | |||
cur_B += KB * NB; | |||
output += MB * NB; | |||
} | |||
if (n < N) { | |||
if (N - n >= 4) { | |||
kern_8x4(A, cur_B, LDB, K, output); | |||
cur_B += KB * CALCBLK; | |||
output += MB * CALCBLK; | |||
n += 4; | |||
} | |||
while (n < N) { | |||
kern_8x1(A, cur_B, LDB, K, output); | |||
cur_B += KB; | |||
output += MB; | |||
n++; | |||
} | |||
A += LDA; | |||
} | |||
@@ -20,6 +20,54 @@ using namespace aarch64::matmul; | |||
namespace { | |||
void kern_4x1(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, | |||
float* output) { | |||
LDB *= sizeof(float); | |||
asm volatile( | |||
"subs %w[K], %w[K], #4\n" | |||
"ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[a_ptr]], 64\n" | |||
"eor v16.16b, v16.16b, v16.16b\n" | |||
"eor v17.16b, v17.16b, v17.16b\n" | |||
"eor v18.16b, v18.16b, v18.16b\n" | |||
"eor v19.16b, v19.16b, v19.16b\n" | |||
"ld1 {v0.4s}, [%[b_ptr]], %x[LDB]\n" | |||
"prfm pstl1keep, [%[b_ptr]]\n" | |||
"fmla v16.4s, v4.4s, v0.s[0]\n" | |||
"fmla v17.4s, v5.4s, v0.s[1]\n" | |||
"beq 2f\n" | |||
"1:\n" | |||
"ld1 {v4.4s, v5.4s}, [%[a_ptr]], 32\n" | |||
"fmla v18.4s, v6.4s, v0.s[2]\n" | |||
"fmla v19.4s, v7.4s, v0.s[3]\n" | |||
"ld1 {v0.4s}, [%[b_ptr]], %x[LDB]\n" | |||
"prfm pstl1keep, [%[b_ptr]]\n" | |||
"ld1 {v6.4s, v7.4s}, [%[a_ptr]], 32\n" | |||
"fmla v16.4s, v4.4s, v0.s[0]\n" | |||
"fmla v17.4s, v5.4s, v0.s[1]\n" | |||
"subs %w[K], %w[K], #4\n" | |||
"bne 1b\n" | |||
"2:\n" | |||
"fmla v18.4s, v6.4s, v0.s[2]\n" | |||
"fmla v19.4s, v7.4s, v0.s[3]\n" | |||
"fadd v16.4s, v16.4s, v18.4s\n" | |||
"fadd v17.4s, v17.4s, v19.4s\n" | |||
"fadd v16.4s, v16.4s, v17.4s\n" | |||
"st1 {v16.4s}, [%[output]], 16\n" | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
[output] "+r"(output), [LDB] "+r"(LDB) | |||
: | |||
: "v0", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "cc", | |||
"memory"); | |||
} | |||
// Overview of register layout: | |||
// | |||
// A 4x4 block of A is stored in register v4-v7 | |||
@@ -117,7 +165,8 @@ void kern_4x4(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
[output] "+r"(output), [LDB] "+r"(LDB) | |||
: | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "cc", "memory"); | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", | |||
"v18", "v19", "cc", "memory"); | |||
} | |||
// Overview of register layout: | |||
@@ -535,7 +584,7 @@ void sgemm_nopack_4x16::kern(const float* A, size_t LDA, const float* B, | |||
constexpr static size_t NB = 16; | |||
constexpr static size_t CALCBLK = 4; | |||
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0 && N % CALCBLK == 0); | |||
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0); | |||
//! (m/4, k/4, 4, 4) * (k/4, n, 4) = (m/4, n, 4) | |||
for (size_t m = 0; m < M; m += MB) { | |||
@@ -547,21 +596,23 @@ void sgemm_nopack_4x16::kern(const float* A, size_t LDA, const float* B, | |||
cur_B += KB * NB; | |||
output += MB * NB; | |||
} | |||
switch (N - n) { | |||
case 4: | |||
kern_4x4(A, cur_B, LDB, K, output); | |||
break; | |||
case 8: | |||
kern_4x8(A, cur_B, LDB, K, output); | |||
break; | |||
case 12: | |||
kern_4x8(A, cur_B, LDB, K, output); | |||
cur_B += KB * CALCBLK * 2; | |||
output += MB * CALCBLK * 2; | |||
kern_4x4(A, cur_B, LDB, K, output); | |||
break; | |||
default: | |||
break; | |||
if (N - n >= 8) { | |||
kern_4x8(A, cur_B, LDB, K, output); | |||
cur_B += KB * CALCBLK * 2; | |||
output += MB * CALCBLK * 2; | |||
n += 8; | |||
} | |||
if (N - n >= 4) { | |||
kern_4x4(A, cur_B, LDB, K, output); | |||
cur_B += KB * CALCBLK; | |||
output += MB * CALCBLK; | |||
n += 4; | |||
} | |||
while (n < N) { | |||
kern_4x1(A, cur_B, LDB, K, output); | |||
cur_B += KB; | |||
output += MB; | |||
n++; | |||
} | |||
A += LDA; | |||
} | |||
@@ -20,6 +20,82 @@ using namespace aarch64::matmul; | |||
namespace { | |||
void kern_8x1(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||
dt_int32* output) { | |||
//! As each load 32 number from B, but the pos add 24 * 2, so we minus 24 | |||
//! here. | |||
LDB *= sizeof(dt_int16); | |||
asm volatile( | |||
"subs %w[K], %w[K], #8\n" | |||
"ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%[a_ptr]], 64\n" | |||
"ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%[a_ptr]], 64\n" | |||
"ld1 {v0.4s}, [%[b_ptr]], %x[LDB]\n" | |||
"smull v16.4s, v24.4h, v0.h[0]\n" | |||
"smull2 v17.4s, v24.8h, v0.h[0]\n" | |||
"smull v18.4s, v25.4h, v0.h[1]\n" | |||
"smull2 v19.4s, v25.8h, v0.h[1]\n" | |||
"smull v20.4s, v26.4h, v0.h[2]\n" | |||
"smull2 v21.4s, v26.8h, v0.h[2]\n" | |||
"smull v22.4s, v27.4h, v0.h[3]\n" | |||
"smull2 v23.4s, v27.8h, v0.h[3]\n" | |||
"beq 2f\n" | |||
"1:\n" | |||
"ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%[a_ptr]], 64\n" | |||
"smlal v16.4s, v28.4h, v0.h[4]\n" | |||
"smlal2 v17.4s, v28.8h, v0.h[4]\n" | |||
"smlal v18.4s, v29.4h, v0.h[5]\n" | |||
"smlal2 v19.4s, v29.8h, v0.h[5]\n" | |||
"smlal v20.4s, v30.4h, v0.h[6]\n" | |||
"smlal2 v21.4s, v30.8h, v0.h[6]\n" | |||
"smlal v22.4s, v31.4h, v0.h[7]\n" | |||
"smlal2 v23.4s, v31.8h, v0.h[7]\n" | |||
"ld1 {v0.4s}, [%[b_ptr]], %x[LDB]\n" | |||
"ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%[a_ptr]], 64\n" | |||
"smlal v16.4s, v24.4h, v0.h[0]\n" | |||
"smlal2 v17.4s, v24.8h, v0.h[0]\n" | |||
"smlal v18.4s, v25.4h, v0.h[1]\n" | |||
"smlal2 v19.4s, v25.8h, v0.h[1]\n" | |||
"smlal v20.4s, v26.4h, v0.h[2]\n" | |||
"smlal2 v21.4s, v26.8h, v0.h[2]\n" | |||
"smlal v22.4s, v27.4h, v0.h[3]\n" | |||
"smlal2 v23.4s, v27.8h, v0.h[3]\n" | |||
"subs %w[K], %w[K], #8\n" | |||
"bne 1b\n" | |||
"2:\n" | |||
"smlal v16.4s, v28.4h, v0.h[4]\n" | |||
"smlal2 v17.4s, v28.8h, v0.h[4]\n" | |||
"smlal v18.4s, v29.4h, v0.h[5]\n" | |||
"smlal2 v19.4s, v29.8h, v0.h[5]\n" | |||
"smlal v20.4s, v30.4h, v0.h[6]\n" | |||
"smlal2 v21.4s, v30.8h, v0.h[6]\n" | |||
"smlal v22.4s, v31.4h, v0.h[7]\n" | |||
"smlal2 v23.4s, v31.8h, v0.h[7]\n" | |||
"add v16.4s, v16.4s, v18.4s\n" | |||
"add v20.4s, v20.4s, v22.4s\n" | |||
"add v17.4s, v17.4s, v19.4s\n" | |||
"add v21.4s, v21.4s, v23.4s\n" | |||
"add v16.4s, v16.4s, v20.4s\n" | |||
"add v17.4s, v17.4s, v21.4s\n" | |||
"st1 {v16.4s, v17.4s}, [%[output]], 32\n" | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
[output] "+r"(output), [LDB] "+r"(LDB) | |||
: | |||
: "v0", "v16", "v17", "v18", "v19", "v20", "v21", | |||
"v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||
"v29", "v30", "v31", "cc", "memory"); | |||
} | |||
// Overview of register layout: | |||
// | |||
// A 8x1 cell of Lhs is stored in 16bit in v24-v27 | |||
@@ -636,7 +712,7 @@ void gemm_nopack_s16_8x8::kern(const dt_int16* A, size_t LDA, const dt_int16* B, | |||
constexpr static size_t NB = 8; | |||
constexpr static size_t CALCBLK = 4; | |||
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0 && N % CALCBLK == 0); | |||
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0); | |||
//! (m/8, k/8, 8, 8) * (k/8, n, 8) = (m/8, n, 8) | |||
for (size_t m = 0; m < M; m += MB) { | |||
@@ -648,8 +724,17 @@ void gemm_nopack_s16_8x8::kern(const dt_int16* A, size_t LDA, const dt_int16* B, | |||
cur_B += KB * NB; | |||
output += MB * NB; | |||
} | |||
if (n < N) { | |||
if (N - n >= 4) { | |||
kern_8x4(A, cur_B, LDB, K, output); | |||
cur_B += KB * CALCBLK; | |||
output += MB * CALCBLK; | |||
n += 4; | |||
} | |||
while (n < N) { | |||
kern_8x1(A, cur_B, LDB, K, output); | |||
cur_B += KB; | |||
output += MB; | |||
n++; | |||
} | |||
A += LDA; | |||
} | |||
@@ -390,7 +390,7 @@ void winograd_2x3_8x8_f16::output(const dt_float16* output_transform_buf, | |||
size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | |||
size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | |||
DISPATCH_CONV_WINOGRAD_BIAS( | |||
megdnn_arm_common_winograd_fp16_F23_8x8, cb, __fp16, __fp16, | |||
megdnn_arm_common_winograd_f16_F23_8x8, cb, __fp16, __fp16, | |||
bmode, nonline_mode, output_transform_buf, bias, output, | |||
transform_mid_buf, oh_start, ow_start, OH, OW, oc_start, | |||
oc_end, oc_index, unit_idx, nr_units_in_tile, src_dtype, dst_dtype); | |||
@@ -875,8 +875,7 @@ bool MatrixMulImpl::AlgoF32MK4_4x8::usable( | |||
kern_size_param.B_type == kern_size_param.A_type && | |||
kern_size_param.C_type == kern_size_param.A_type && | |||
kern_size_param.A_type == dtype::Float32() && | |||
kern_size_param.N % 4 == 0 && !kern_size_param.trA && | |||
!kern_size_param.trB; | |||
!kern_size_param.trA && !kern_size_param.trB; | |||
} | |||
size_t MatrixMulImpl::AlgoF32MK4_4x8::get_workspace( | |||
@@ -911,8 +910,7 @@ bool MatrixMulImpl::AlgoInt16x16x32MK8_4x8::usable( | |||
kern_size_param.A_type == dtype::Int16() && | |||
kern_size_param.B_type == dtype::Int16() && | |||
kern_size_param.C_type == dtype::Int32() && | |||
kern_size_param.N % 4 == 0 && !kern_size_param.trA && | |||
!kern_size_param.trB; | |||
!kern_size_param.trA && !kern_size_param.trB; | |||
} | |||
size_t MatrixMulImpl::AlgoInt16x16x32MK8_4x8::get_workspace( | |||
@@ -969,8 +967,7 @@ bool MatrixMulImpl::AlgoF16MK8_4x8::usable( | |||
kern_size_param.B_type == kern_size_param.A_type && | |||
kern_size_param.A_type == dtype::Float16() && | |||
kern_size_param.format == param::MatrixMul::Format::MK8 && | |||
!kern_size_param.trA && !kern_size_param.trB && | |||
kern_size_param.N % 4 == 0; | |||
!kern_size_param.trA && !kern_size_param.trB; | |||
} | |||
size_t MatrixMulImpl::AlgoF16MK8_4x8::get_workspace( | |||
@@ -21,6 +21,66 @@ using namespace armv7::matmul; | |||
namespace { | |||
void kern_8x1(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||
dt_float16* output) { | |||
LDB = (LDB - 4) * sizeof(dt_float16); | |||
asm volatile( | |||
"subs %[K], #8\n" | |||
"vld1.32 {d0}, [%[b_ptr]]!\n" | |||
"vld1.32 {d1}, [%[b_ptr]], %[LDB]\n" | |||
"vld1.32 {d8, d9, d10, d11}, [%[a_ptr]]!\n" | |||
"vld1.32 {d12, d13, d14, d15}, [%[a_ptr]]!\n" | |||
"vld1.32 {d16, d17, d18, d19}, [%[a_ptr]]!\n" | |||
"vld1.32 {d20, d21, d22, d23}, [%[a_ptr]]!\n" | |||
"vmul.f16 q12, q4, d0[0]\n" | |||
"vmul.f16 q13, q5, d0[1]\n" | |||
"vmul.f16 q14, q6, d0[2]\n" | |||
"vmul.f16 q15, q7, d0[3]\n" | |||
"beq 2f\n" | |||
"1:\n" | |||
"vmla.f16 q12, q8, d1[0]\n" | |||
"vld1.32 {d0}, [%[b_ptr]]!\n" | |||
"vmla.f16 q13, q9, d1[1]\n" | |||
"vld1.32 {d8, d9, d10, d11}, [%[a_ptr]]!\n" | |||
"vmla.f16 q14, q10, d1[2]\n" | |||
"vld1.32 {d12, d13, d14, d15}, [%[a_ptr]]!\n" | |||
"vmla.f16 q15, q11, d1[3]\n" | |||
"vmla.f16 q12, q4, d0[0]\n" | |||
"vld1.32 {d1}, [%[b_ptr]], %[LDB]\n" | |||
"vmla.f16 q13, q5, d0[1]\n" | |||
"vld1.32 {d16, d17, d18, d19}, [%[a_ptr]]!\n" | |||
"vmla.f16 q14, q6, d0[2]\n" | |||
"vld1.32 {d20, d21, d22, d23}, [%[a_ptr]]!\n" | |||
"vmla.f16 q15, q7, d0[3]\n" | |||
"subs %[K], #8\n" | |||
"bne 1b\n" | |||
"2:\n" | |||
"vmla.f16 q12, q8, d1[0]\n" | |||
"vmla.f16 q13, q9, d1[1]\n" | |||
"vmla.f16 q14, q10, d1[2]\n" | |||
"vmla.f16 q15, q11, d1[3]\n" | |||
"vadd.f16 q12, q12, q14\n" | |||
"vadd.f16 q13, q13, q15\n" | |||
"vadd.f16 q12, q12, q13\n" | |||
"vst1.32 {d24, d25}, [%[output]]!\n" | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
[output] "+r"(output), [LDB] "+r"(LDB) | |||
: | |||
: "d0", "d1", "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", | |||
"d16", "d17", "d18", "d19", "d20", "d21", "d22", "d23", "d24", | |||
"d25", "d26", "d27", "d28", "d29", "d30", "d31", "cc", "memory"); | |||
} | |||
// Overview of register layout: | |||
// | |||
// A 8x1 cell of Rhs is stored in 16bit in v4-v11 | |||
@@ -45,7 +105,7 @@ namespace { | |||
// | v3[0-7]| |v15[0-7]| | |||
// +--------+ +--------+--------+ | |||
// Accumulator | |||
void kern_4x8(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||
void kern_8x4(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||
dt_float16* output) { | |||
//! As each load 64 number from B, but the pos add 48 * 2, so we minus 48 | |||
//! here. | |||
@@ -179,19 +239,25 @@ void gemm_nopack_f16_4x8::kern(const dt_float16* A, size_t LDA, | |||
constexpr static size_t MB = 8; | |||
constexpr static size_t KB = 8; | |||
constexpr static size_t NB = 4; | |||
constexpr static size_t CALCBLK = 4; | |||
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0 && N % CALCBLK == 0); | |||
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0); | |||
//! (m/8, k/8, 8, 8) * (k/8, n, 8) = (m/8, n, 8) | |||
for (size_t m = 0; m < M; m += MB) { | |||
dt_float16* output = C + (m / MB) * LDC; | |||
const dt_float16* cur_B = B; | |||
for (size_t n = 0; n < N; n += NB) { | |||
kern_4x8(A, cur_B, LDB, K, output); | |||
size_t n = 0; | |||
for (; n + NB - 1 < N; n += NB) { | |||
kern_8x4(A, cur_B, LDB, K, output); | |||
cur_B += KB * NB; | |||
output += MB * NB; | |||
} | |||
while (n < N) { | |||
kern_8x1(A, cur_B, LDB, K, output); | |||
cur_B += KB; | |||
output += MB; | |||
n++; | |||
} | |||
A += LDA; | |||
} | |||
} | |||
@@ -20,6 +20,58 @@ using namespace armv7::matmul; | |||
namespace { | |||
void kern_4x1(const float* A, const float* B, size_t LDB, size_t K, float* C) { | |||
LDB = (LDB - 4) * sizeof(float); | |||
asm volatile( | |||
"subs %[K], %[K], #4\n" | |||
"vld1.32 {d8-d11}, [%[A]]!\n" | |||
"vld1.32 {d12-d15}, [%[A]]!\n" | |||
"veor q8, q8 \n" | |||
"veor q9, q9 \n" | |||
"veor q10, q10 \n" | |||
"veor q11, q11 \n" | |||
"vld1.32 {d0-d1}, [%[B]]!\n" | |||
"vmla.f32 q8, q4, d0[0]\n" | |||
"vmla.f32 q9, q5, d0[1]\n" | |||
"beq 2f\n" | |||
"1:\n" | |||
"vld1.32 {d8-d11}, [%[A]]!\n" | |||
"vmla.f32 q10, q6, d1[0]\n" | |||
"vmla.f32 q11, q7, d1[1]\n" | |||
"add %[B], %[B], %[LDB]\n" | |||
"vld1.32 {d0-d1}, [%[B]]!\n" | |||
"vld1.32 {d12-d15}, [%[A]]!\n" | |||
"vmla.f32 q8, q4, d0[0]\n" | |||
"vmla.f32 q9, q5, d0[1]\n" | |||
"subs %[K], %[K], #4\n" | |||
"bne 1b\n" | |||
"2:\n" | |||
"vmla.f32 q10, q6, d1[0]\n" | |||
"vmla.f32 q11, q7, d1[1]\n" | |||
"vadd.f32 q8, q8, q10\n" | |||
"vadd.f32 q9, q9, q11\n" | |||
"vadd.f32 q8, q8, q9\n" | |||
"vst1.32 {d16, d17}, [%[C]]!\n" | |||
: [ A ] "+r"(A), [ B ] "+r"(B), [ K ] "+r"(K), [ C ] "+r"(C) | |||
: [ LDB ] "r"(LDB) | |||
: "d0", "d1", "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", | |||
"d16", "d17", "d18", "d19", "d20", "d21", "d22", "d23", "cc", | |||
"memory"); | |||
} | |||
// Overview of register layout: | |||
// | |||
// A 8x4 cell of Rhs is stored in 32bit in q0-q3, load 4 register each time | |||
@@ -268,9 +320,9 @@ void sgemm_nopack_4x8::kern(const float* A, size_t LDA, const float* B, | |||
constexpr size_t MB = 4; | |||
constexpr size_t KB = 4; | |||
constexpr size_t NB = 8; | |||
constexpr size_t CALCBLK = 4; | |||
constexpr size_t NB_HALF = 4; | |||
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0 && N % CALCBLK == 0); | |||
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0); | |||
//! (m/8, k/8, 8, 8) * (k/8, n, 8) = (m/8, n, 8) | |||
for (size_t m = 0; m < M; m += MB) { | |||
@@ -282,8 +334,17 @@ void sgemm_nopack_4x8::kern(const float* A, size_t LDA, const float* B, | |||
cur_B += KB * NB; | |||
output += MB * NB; | |||
} | |||
if (n < N) { | |||
if (N - n >= 4) { | |||
kern_4x4(A, cur_B, LDB, K, output); | |||
cur_B += KB * NB_HALF; | |||
output += MB * NB_HALF; | |||
n += 4; | |||
} | |||
while (n < N) { | |||
kern_4x1(A, cur_B, LDB, K, output); | |||
cur_B += KB; | |||
output += MB; | |||
n++; | |||
} | |||
A += LDA; | |||
} | |||
@@ -20,6 +20,91 @@ using namespace armv7::matmul; | |||
namespace { | |||
void kern_8x1(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||
dt_int32* output) { | |||
//! As each load 16 number from B, but the pos add 16 * 2, so we minus 16 | |||
//! here. | |||
LDB = (LDB - 4) * sizeof(dt_int16); | |||
asm volatile( | |||
"subs %[K], #8\n" | |||
"vld1.32 {d8, d9, d10, d11}, [%[a_ptr]]!\n" | |||
"vld1.32 {d12, d13, d14, d15}, [%[a_ptr]]!\n" | |||
"vld1.32 {d16, d17, d18, d19}, [%[a_ptr]]!\n" | |||
"vld1.32 {d20, d21, d22, d23}, [%[a_ptr]]!\n" | |||
"vld1.32 {d0}, [%[b_ptr]]!\n" | |||
"vld1.32 {d1}, [%[b_ptr]], %[LDB]\n" | |||
"vmull.s16 q12, d8, d0[0]\n" | |||
"vmull.s16 q13, d9, d0[0]\n" | |||
"vmull.s16 q14, d10, d0[1]\n" | |||
"vmull.s16 q15, d11, d0[1]\n" | |||
"vmlal.s16 q12, d12, d0[2]\n" | |||
"vmlal.s16 q13, d13, d0[2]\n" | |||
"vmlal.s16 q14, d14, d0[3]\n" | |||
"vmlal.s16 q15, d15, d0[3]\n" | |||
"beq 2f\n" | |||
"1:\n" | |||
"vld1.32 {d8, d9, d10, d11}, [%[a_ptr]]!\n" | |||
"vld1.32 {d12, d13, d14, d15}, [%[a_ptr]]!\n" | |||
"vld1.32 {d0}, [%[b_ptr]]!\n" | |||
"vmlal.s16 q12, d16, d1[0]\n" | |||
"vmlal.s16 q13, d17, d1[0]\n" | |||
"vmlal.s16 q14, d18, d1[1]\n" | |||
"vmlal.s16 q15, d19, d1[1]\n" | |||
"vmlal.s16 q12, d20, d1[2]\n" | |||
"vmlal.s16 q13, d21, d1[2]\n" | |||
"vmlal.s16 q14, d22, d1[3]\n" | |||
"vmlal.s16 q15, d23, d1[3]\n" | |||
"vld1.32 {d1}, [%[b_ptr]], %[LDB]\n" | |||
"vld1.32 {d16, d17, d18, d19}, [%[a_ptr]]!\n" | |||
"vld1.32 {d20, d21, d22, d23}, [%[a_ptr]]!\n" | |||
"vmlal.s16 q12, d8, d0[0]\n" | |||
"vmlal.s16 q13, d9, d0[0]\n" | |||
"vmlal.s16 q14, d10, d0[1]\n" | |||
"vmlal.s16 q15, d11, d0[1]\n" | |||
"vmlal.s16 q12, d12, d0[2]\n" | |||
"vmlal.s16 q13, d13, d0[2]\n" | |||
"vmlal.s16 q14, d14, d0[3]\n" | |||
"vmlal.s16 q15, d15, d0[3]\n" | |||
"subs %[K], %[K], #8\n" | |||
"bne 1b\n" | |||
"2:\n" | |||
"vmlal.s16 q12, d16, d1[0]\n" | |||
"vmlal.s16 q13, d17, d1[0]\n" | |||
"vmlal.s16 q14, d18, d1[1]\n" | |||
"vmlal.s16 q15, d19, d1[1]\n" | |||
"vmlal.s16 q12, d20, d1[2]\n" | |||
"vmlal.s16 q13, d21, d1[2]\n" | |||
"vmlal.s16 q14, d22, d1[3]\n" | |||
"vmlal.s16 q15, d23, d1[3]\n" | |||
"vadd.s32 q12, q12, q14\n" | |||
"vadd.s32 q13, q13, q15\n" | |||
"vst1.32 {d24, d25, d26, d27}, [%[output]]!\n" | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
[output] "+r"(output), [LDB] "+r"(LDB) | |||
: | |||
: "d0", "d1", "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", | |||
"d16", "d17", "d18", "d19", "d20", "d21", "d22", "d23", "d24", | |||
"d25", "d26", "d27", "d28", "d29", "d30", "d31", "cc", "memory"); | |||
} | |||
// Overview of register layout: | |||
// | |||
// A 4x8 cell of Rhs is stored in 16bit in q0-q3 | |||
@@ -40,7 +125,7 @@ namespace { | |||
// | q3[0-7]| |q14[0-3]|v15[0-3]| | |||
// +--------+ +--------+--------+ | |||
// Accumulator | |||
void kern_4x8(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||
void kern_8x4(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||
dt_int32* output) { | |||
//! As each load 16 number from B, but the pos add 16 * 2, so we minus 16 | |||
//! here. | |||
@@ -247,19 +332,25 @@ void gemm_nopack_s16_4x8::kern(const dt_int16* A, size_t LDA, const dt_int16* B, | |||
constexpr static size_t MB = 8; | |||
constexpr static size_t KB = 8; | |||
constexpr static size_t NB = 4; | |||
constexpr static size_t CALCBLK = 4; | |||
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0 && N % CALCBLK == 0); | |||
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0); | |||
//! (m/8, k/8, 8, 8) * (k/8, n, 8) = (m/8, n, 8) | |||
for (size_t m = 0; m < M; m += MB) { | |||
dt_int32* output = C + (m / MB) * LDC; | |||
const dt_int16* cur_B = B; | |||
for (size_t n = 0; n < N; n += NB) { | |||
kern_4x8(A, cur_B, LDB, K, output); | |||
size_t n = 0; | |||
for (; n + NB - 1 < N; n += NB) { | |||
kern_8x4(A, cur_B, LDB, K, output); | |||
cur_B += KB * NB; | |||
output += MB * NB; | |||
} | |||
while (n < N) { | |||
kern_8x1(A, cur_B, LDB, K, output); | |||
cur_B += KB; | |||
output += MB; | |||
n++; | |||
} | |||
A += LDA; | |||
} | |||
} | |||
@@ -427,9 +427,6 @@ public: | |||
"The winograd remain oc is not times of OC_BLOCK_SIZE"); | |||
if (format == param::MatrixMul::Format::MK4 || | |||
format == param::MatrixMul::Format::MK8) { | |||
#if !MEGDNN_X86 | |||
nr_tiles_in_unit = round_up<size_t>(nr_tiles_in_unit, 4); | |||
#endif | |||
megdnn_assert(nr_tiles_in_unit <= unit_tile_size, | |||
"nr_tiles_in_unit: %zu TILE_SIZE:%zu", | |||
nr_tiles_in_unit, unit_tile_size); | |||
@@ -38,10 +38,9 @@ TEST_F(AARCH64, MATRIX_MUL_FP32_PACK_MK4) { | |||
} | |||
TEST_F(AARCH64, MATRIX_MUL_FP32_MK4) { | |||
//! nbase should be 4 in order to test the last rest 4 in N dim | |||
matrix_mul::check_matrix_mul( | |||
dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), | |||
"AARCH64_F32_MK4_4x16", param::MatrixMul::Format::MK4, 4); | |||
"AARCH64_F32_MK4_4x16", param::MatrixMul::Format::MK4, 1); | |||
} | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
@@ -52,10 +51,9 @@ TEST_F(AARCH64, MATRIX_MUL_F16_K8X24X1) { | |||
} | |||
TEST_F(AARCH64, MATRIX_MUL_F16_MK8) { | |||
//! nbase should be 4 in order to test the last rest 4 in N dim | |||
matrix_mul::check_matrix_mul( | |||
dtype::Float16{}, dtype::Float16{}, dtype::Float16{}, handle(), | |||
"AARCH64_F16_MK8_8X8", param::MatrixMul::Format::MK8, 4); | |||
"AARCH64_F16_MK8_8X8", param::MatrixMul::Format::MK8, 1); | |||
} | |||
#endif | |||
@@ -116,10 +114,9 @@ TEST_F(AARCH64, MATRIX_MUL_INT16x16x32_K12X8X1) { | |||
} | |||
TEST_F(AARCH64, MATRIX_MUL_INT16x16x32_MK8) { | |||
//! nbase should be 4 in order to test the last rest 4 in N dim | |||
matrix_mul::check_matrix_mul(dtype::Int16{}, dtype::Int16{}, dtype::Int32{}, | |||
handle(), "AARCH64_INT16X16X32_MK8_8X8", | |||
param::MatrixMul::Format::MK8, 4); | |||
param::MatrixMul::Format::MK8, 1); | |||
} | |||
//! FIXME: need to add tests of GEMV and QUINT8 | |||
@@ -26,7 +26,7 @@ TEST_F(ARMV7, MATRIX_MUL) { | |||
TEST_F(ARMV7, MATRIX_MUL_MK4) { | |||
matrix_mul::check_matrix_mul( | |||
dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), | |||
"ARMV7_F32_MK4_4x8", param::MatrixMul::Format::MK4, 4); | |||
"ARMV7_F32_MK4_4x8", param::MatrixMul::Format::MK4, 1); | |||
} | |||
TEST_F(ARMV7, MATRIX_MUL_PACK_MK4) { | |||
@@ -66,7 +66,7 @@ TEST_F(ARMV7, MATRIX_MUL_INT16x16x32) { | |||
TEST_F(ARMV7, MATRIX_MUL_INT16x16x32_MK8) { | |||
matrix_mul::check_matrix_mul(dtype::Int16{}, dtype::Int16{}, dtype::Int32{}, | |||
handle(), "ARMV7_INT16X16X32_MK8_4X8", | |||
param::MatrixMul::Format::MK8, 4); | |||
param::MatrixMul::Format::MK8, 1); | |||
} | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
@@ -78,7 +78,7 @@ TEST_F(ARMV7, MATRIX_MUL_FP16) { | |||
TEST_F(ARMV7, MATRIX_MUL_F16_MK8) { | |||
matrix_mul::check_matrix_mul( | |||
dtype::Float16{}, dtype::Float16{}, dtype::Float16{}, handle(), | |||
"AARCH32_F16_MK8_4X8", param::MatrixMul::Format::MK8, 4); | |||
"AARCH32_F16_MK8_4X8", param::MatrixMul::Format::MK8, 1); | |||
} | |||
#endif | |||