GitOrigin-RevId: 0f64b9f70f
tags/v1.0.0-rc1
@@ -214,8 +214,7 @@ bool MatrixMulImpl::AlgoF32MK4_4x16::usable( | |||||
kern_size_param.B_type == dtype::Float32() && | kern_size_param.B_type == dtype::Float32() && | ||||
kern_size_param.A_type == dtype::Float32() && | kern_size_param.A_type == dtype::Float32() && | ||||
kern_size_param.format == param::MatrixMul::Format::MK4 && | kern_size_param.format == param::MatrixMul::Format::MK4 && | ||||
!kern_size_param.trA && !kern_size_param.trB && | |||||
kern_size_param.N % 4 == 0; | |||||
!kern_size_param.trA && !kern_size_param.trB; | |||||
} | } | ||||
size_t MatrixMulImpl::AlgoF32MK4_4x16::get_workspace( | size_t MatrixMulImpl::AlgoF32MK4_4x16::get_workspace( | ||||
@@ -330,8 +329,7 @@ bool MatrixMulImpl::AlgoF16MK8_8x8::usable( | |||||
kern_size_param.B_type == kern_size_param.A_type && | kern_size_param.B_type == kern_size_param.A_type && | ||||
kern_size_param.A_type == dtype::Float16() && | kern_size_param.A_type == dtype::Float16() && | ||||
kern_size_param.format == param::MatrixMul::Format::MK8 && | kern_size_param.format == param::MatrixMul::Format::MK8 && | ||||
!kern_size_param.trA && !kern_size_param.trB && | |||||
kern_size_param.N % 4 == 0; | |||||
!kern_size_param.trA && !kern_size_param.trB; | |||||
} | } | ||||
size_t MatrixMulImpl::AlgoF16MK8_8x8::get_workspace( | size_t MatrixMulImpl::AlgoF16MK8_8x8::get_workspace( | ||||
@@ -918,8 +916,7 @@ bool MatrixMulImpl::AlgoInt16x16x32MK8_8x8::usable( | |||||
kern_size_param.B_type == dtype::Int16() && | kern_size_param.B_type == dtype::Int16() && | ||||
kern_size_param.A_type == dtype::Int16() && | kern_size_param.A_type == dtype::Int16() && | ||||
kern_size_param.format == param::MatrixMul::Format::MK8 && | kern_size_param.format == param::MatrixMul::Format::MK8 && | ||||
!kern_size_param.trA && !kern_size_param.trB && | |||||
kern_size_param.N % 4 == 0; | |||||
!kern_size_param.trA && !kern_size_param.trB; | |||||
} | } | ||||
size_t MatrixMulImpl::AlgoInt16x16x32MK8_8x8::get_workspace( | size_t MatrixMulImpl::AlgoInt16x16x32MK8_8x8::get_workspace( | ||||
@@ -21,6 +21,76 @@ using namespace aarch64::matmul; | |||||
namespace { | namespace { | ||||
// Single-column fp16 micro-kernel (AArch64, needs armv8.2-a+fp16).
//
// Each K step of 8 consumes 128 bytes of A (eight 8-lane fp16 vectors held in
// v16-v23) and one 8-lane fp16 vector of B (v0). A vector i is multiplied by
// lane i of v0 and accumulated into its own register (v24-v31); the eight
// partial sums are added together at the end and 8 fp16 results (16 bytes)
// are written to `output`.
//
// a_ptr  : A data, read sequentially through post-incremented loads.
// b_ptr  : B data; advanced by LDB elements after every load.
// LDB    : distance between consecutive B loads, in fp16 elements.
// K      : reduction length. It is decremented by 8 per step and the loop
//          only exits when it reaches exactly 0, so it has to be a positive
//          multiple of 8.
// output : destination for the 8 results.
//
// NOTE(review): LDB is an `int` but is referenced as `%x[LDB]` (64-bit view
// of the register) in the loads below — confirm the upper 32 bits are
// guaranteed to be clean, or widen the parameter.
void kern_8x1(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||||
dt_float16* output) { | |||||
//! LDB is used as a byte post-increment for b_ptr, so convert it from
//! elements to bytes.
LDB *= sizeof(dt_float16); | |||||
asm volatile( | |||||
".arch armv8.2-a+fp16\n" | |||||
//! Prologue: load the first 8 A vectors, clear the accumulators, load the
//! first B vector and issue the multiply-accumulates for lanes 0-3.
"subs %w[K], %w[K], #8\n" | |||||
"ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[a_ptr]], 64\n" | |||||
"ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%[a_ptr]], 64\n" | |||||
"eor v24.16b, v24.16b, v24.16b\n" | |||||
"eor v25.16b, v25.16b, v25.16b\n" | |||||
"eor v26.16b, v26.16b, v26.16b\n" | |||||
"eor v27.16b, v27.16b, v27.16b\n" | |||||
"eor v28.16b, v28.16b, v28.16b\n" | |||||
"eor v29.16b, v29.16b, v29.16b\n" | |||||
"eor v30.16b, v30.16b, v30.16b\n" | |||||
"eor v31.16b, v31.16b, v31.16b\n" | |||||
"ld1 {v0.4s}, [%[b_ptr]], %x[LDB]\n" | |||||
"fmla v24.8h, v16.8h, v0.h[0]\n" | |||||
"fmla v25.8h, v17.8h, v0.h[1]\n" | |||||
"fmla v26.8h, v18.8h, v0.h[2]\n" | |||||
"fmla v27.8h, v19.8h, v0.h[3]\n" | |||||
//! The flags tested here are still the ones set by the initial subs:
//! when K was exactly 8 the loop body is skipped.
"beq 2f\n" | |||||
"1:\n" | |||||
//! Steady state: finish lanes 4-7 of the previous step with the old v0
//! before v0 is overwritten, then start lanes 0-3 of the next step.
"ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[a_ptr]], 64\n" | |||||
"fmla v28.8h, v20.8h, v0.h[4]\n" | |||||
"fmla v29.8h, v21.8h, v0.h[5]\n" | |||||
"fmla v30.8h, v22.8h, v0.h[6]\n" | |||||
"fmla v31.8h, v23.8h, v0.h[7]\n" | |||||
"ld1 {v0.4s}, [%[b_ptr]], %x[LDB]\n" | |||||
"ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%[a_ptr]], 64\n" | |||||
"fmla v24.8h, v16.8h, v0.h[0]\n" | |||||
"fmla v25.8h, v17.8h, v0.h[1]\n" | |||||
"fmla v26.8h, v18.8h, v0.h[2]\n" | |||||
"fmla v27.8h, v19.8h, v0.h[3]\n" | |||||
"subs %w[K], %w[K], #8\n" | |||||
"bne 1b\n" | |||||
"2:\n" | |||||
//! Epilogue: lanes 4-7 of the last step, then reduce v24-v31 into v24.
"fmla v28.8h, v20.8h, v0.h[4]\n" | |||||
"fmla v29.8h, v21.8h, v0.h[5]\n" | |||||
"fmla v30.8h, v22.8h, v0.h[6]\n" | |||||
"fmla v31.8h, v23.8h, v0.h[7]\n" | |||||
"fadd v24.8h, v24.8h, v25.8h\n" | |||||
"fadd v26.8h, v26.8h, v27.8h\n" | |||||
"fadd v28.8h, v28.8h, v29.8h\n" | |||||
"fadd v30.8h, v30.8h, v31.8h\n" | |||||
"fadd v24.8h, v24.8h, v26.8h\n" | |||||
"fadd v28.8h, v28.8h, v30.8h\n" | |||||
"fadd v24.8h, v24.8h, v28.8h\n" | |||||
"st1 {v24.4s}, [%[output]], 16\n" | |||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||||
[output] "+r"(output), [LDB] "+r"(LDB) | |||||
: | |||||
: "v0", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", | |||||
"v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", | |||||
"memory"); | |||||
} | |||||
// Overview of register layout: | // Overview of register layout: | ||||
// | // | ||||
// A 8x1 cell of Rhs is stored in 16bit in v0-v3 | // A 8x1 cell of Rhs is stored in 16bit in v0-v3 | ||||
@@ -416,7 +486,7 @@ void gemm_nopack_f16_8x8::kern(const dt_float16* A, size_t LDA, | |||||
constexpr static size_t NB = 8; | constexpr static size_t NB = 8; | ||||
constexpr static size_t CALCBLK = 4; | constexpr static size_t CALCBLK = 4; | ||||
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0 && N % CALCBLK == 0); | |||||
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0); | |||||
//! (m/8, k/8, 8, 8) * (k/8, n, 8) = (m/8, n, 8) | //! (m/8, k/8, 8, 8) * (k/8, n, 8) = (m/8, n, 8) | ||||
for (size_t m = 0; m < M; m += MB) { | for (size_t m = 0; m < M; m += MB) { | ||||
@@ -428,8 +498,17 @@ void gemm_nopack_f16_8x8::kern(const dt_float16* A, size_t LDA, | |||||
cur_B += KB * NB; | cur_B += KB * NB; | ||||
output += MB * NB; | output += MB * NB; | ||||
} | } | ||||
if (n < N) { | |||||
if (N - n >= 4) { | |||||
kern_8x4(A, cur_B, LDB, K, output); | kern_8x4(A, cur_B, LDB, K, output); | ||||
cur_B += KB * CALCBLK; | |||||
output += MB * CALCBLK; | |||||
n += 4; | |||||
} | |||||
while (n < N) { | |||||
kern_8x1(A, cur_B, LDB, K, output); | |||||
cur_B += KB; | |||||
output += MB; | |||||
n++; | |||||
} | } | ||||
A += LDA; | A += LDA; | ||||
} | } | ||||
@@ -20,6 +20,54 @@ using namespace aarch64::matmul; | |||||
namespace { | namespace { | ||||
// Single-column fp32 micro-kernel (AArch64).
//
// Each K step of 4 consumes 64 bytes of A (four 4-lane float vectors in
// v4-v7) and one 4-lane float vector of B (v0). A vector i is multiplied by
// lane i of v0 and accumulated into v16+i; the four partial sums are added
// together at the end and 4 floats (16 bytes) are written to `output`.
//
// a_ptr  : A data, read sequentially through post-incremented loads.
// b_ptr  : B data; advanced by LDB elements after every load.
// LDB    : distance between consecutive B loads, in float elements.
// K      : reduction length. It is decremented by 4 per step and the loop
//          only exits when it reaches exactly 0, so it has to be a positive
//          multiple of 4.
// output : destination for the 4 results.
//
// NOTE(review): K is a `size_t` but the asm operates on `%w[K]`, i.e. only
// the low 32 bits take part in the countdown — confirm K always fits.
void kern_4x1(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, | |||||
float* output) { | |||||
//! LDB is used as a byte post-increment for b_ptr, so convert it from
//! elements to bytes.
LDB *= sizeof(float); | |||||
asm volatile( | |||||
//! Prologue: load the first 4 A vectors, clear the accumulators, load the
//! first B vector and issue the multiply-accumulates for lanes 0-1.
"subs %w[K], %w[K], #4\n" | |||||
"ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[a_ptr]], 64\n" | |||||
"eor v16.16b, v16.16b, v16.16b\n" | |||||
"eor v17.16b, v17.16b, v17.16b\n" | |||||
"eor v18.16b, v18.16b, v18.16b\n" | |||||
"eor v19.16b, v19.16b, v19.16b\n" | |||||
"ld1 {v0.4s}, [%[b_ptr]], %x[LDB]\n" | |||||
//! NOTE(review): this is a prefetch-for-store hint on b_ptr, which this
//! kernel only reads from; pldl1keep may have been intended — confirm.
"prfm pstl1keep, [%[b_ptr]]\n" | |||||
"fmla v16.4s, v4.4s, v0.s[0]\n" | |||||
"fmla v17.4s, v5.4s, v0.s[1]\n" | |||||
//! The flags tested here are still the ones set by the initial subs:
//! when K was exactly 4 the loop body is skipped.
"beq 2f\n" | |||||
"1:\n" | |||||
//! Steady state: finish lanes 2-3 of the previous step with the old v0
//! before v0 is overwritten, then start lanes 0-1 of the next step.
"ld1 {v4.4s, v5.4s}, [%[a_ptr]], 32\n" | |||||
"fmla v18.4s, v6.4s, v0.s[2]\n" | |||||
"fmla v19.4s, v7.4s, v0.s[3]\n" | |||||
"ld1 {v0.4s}, [%[b_ptr]], %x[LDB]\n" | |||||
"prfm pstl1keep, [%[b_ptr]]\n" | |||||
"ld1 {v6.4s, v7.4s}, [%[a_ptr]], 32\n" | |||||
"fmla v16.4s, v4.4s, v0.s[0]\n" | |||||
"fmla v17.4s, v5.4s, v0.s[1]\n" | |||||
"subs %w[K], %w[K], #4\n" | |||||
"bne 1b\n" | |||||
"2:\n" | |||||
//! Epilogue: lanes 2-3 of the last step, then reduce v16-v19 into v16.
"fmla v18.4s, v6.4s, v0.s[2]\n" | |||||
"fmla v19.4s, v7.4s, v0.s[3]\n" | |||||
"fadd v16.4s, v16.4s, v18.4s\n" | |||||
"fadd v17.4s, v17.4s, v19.4s\n" | |||||
"fadd v16.4s, v16.4s, v17.4s\n" | |||||
"st1 {v16.4s}, [%[output]], 16\n" | |||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||||
[output] "+r"(output), [LDB] "+r"(LDB) | |||||
: | |||||
: "v0", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "cc", | |||||
"memory"); | |||||
} | |||||
// Overview of register layout: | // Overview of register layout: | ||||
// | // | ||||
// A 4x4 block of A is stored in register v4-v7 | // A 4x4 block of A is stored in register v4-v7 | ||||
@@ -117,7 +165,8 @@ void kern_4x4(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, | |||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | ||||
[output] "+r"(output), [LDB] "+r"(LDB) | [output] "+r"(output), [LDB] "+r"(LDB) | ||||
: | : | ||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "cc", "memory"); | |||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", | |||||
"v18", "v19", "cc", "memory"); | |||||
} | } | ||||
// Overview of register layout: | // Overview of register layout: | ||||
@@ -535,7 +584,7 @@ void sgemm_nopack_4x16::kern(const float* A, size_t LDA, const float* B, | |||||
constexpr static size_t NB = 16; | constexpr static size_t NB = 16; | ||||
constexpr static size_t CALCBLK = 4; | constexpr static size_t CALCBLK = 4; | ||||
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0 && N % CALCBLK == 0); | |||||
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0); | |||||
//! (m/4, k/4, 4, 4) * (k/4, n, 4) = (m/4, n, 4) | //! (m/4, k/4, 4, 4) * (k/4, n, 4) = (m/4, n, 4) | ||||
for (size_t m = 0; m < M; m += MB) { | for (size_t m = 0; m < M; m += MB) { | ||||
@@ -547,21 +596,23 @@ void sgemm_nopack_4x16::kern(const float* A, size_t LDA, const float* B, | |||||
cur_B += KB * NB; | cur_B += KB * NB; | ||||
output += MB * NB; | output += MB * NB; | ||||
} | } | ||||
switch (N - n) { | |||||
case 4: | |||||
kern_4x4(A, cur_B, LDB, K, output); | |||||
break; | |||||
case 8: | |||||
kern_4x8(A, cur_B, LDB, K, output); | |||||
break; | |||||
case 12: | |||||
kern_4x8(A, cur_B, LDB, K, output); | |||||
cur_B += KB * CALCBLK * 2; | |||||
output += MB * CALCBLK * 2; | |||||
kern_4x4(A, cur_B, LDB, K, output); | |||||
break; | |||||
default: | |||||
break; | |||||
if (N - n >= 8) { | |||||
kern_4x8(A, cur_B, LDB, K, output); | |||||
cur_B += KB * CALCBLK * 2; | |||||
output += MB * CALCBLK * 2; | |||||
n += 8; | |||||
} | |||||
if (N - n >= 4) { | |||||
kern_4x4(A, cur_B, LDB, K, output); | |||||
cur_B += KB * CALCBLK; | |||||
output += MB * CALCBLK; | |||||
n += 4; | |||||
} | |||||
while (n < N) { | |||||
kern_4x1(A, cur_B, LDB, K, output); | |||||
cur_B += KB; | |||||
output += MB; | |||||
n++; | |||||
} | } | ||||
A += LDA; | A += LDA; | ||||
} | } | ||||
@@ -20,6 +20,82 @@ using namespace aarch64::matmul; | |||||
namespace { | namespace { | ||||
// Single-column int16 x int16 -> int32 micro-kernel (AArch64).
//
// Each K step of 8 consumes 128 bytes of A (eight 8-lane int16 vectors in
// v24-v31) and one 8-lane int16 vector of B (v0). A vector i is multiplied by
// lane i of v0 with widening to int32: smull/smlal handle the low four
// elements and smull2/smlal2 the high four, so every A vector feeds a pair of
// accumulators. The accumulator pairs (v16..v23) are reduced into v16 (low
// four results) and v17 (high four results) and 8 int32 values (32 bytes)
// are written to `output`.
//
// a_ptr  : A data, read sequentially through post-incremented loads.
// b_ptr  : B data; advanced by LDB elements after every load.
// LDB    : distance between consecutive B loads, in int16 elements.
// K      : reduction length. It is decremented by 8 per step and the loop
//          only exits when it reaches exactly 0, so it has to be a positive
//          multiple of 8.
// output : destination for the 8 int32 results.
//
// NOTE(review): LDB is an `int` but is referenced as `%x[LDB]` (64-bit view
// of the register) in the loads below — confirm the upper 32 bits are
// guaranteed to be clean, or widen the parameter.
void kern_8x1(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||||
dt_int32* output) { | |||||
//! Every B load post-increments b_ptr by LDB, so only a conversion from
//! elements to bytes is needed here (no extra offset is subtracted).
LDB *= sizeof(dt_int16); | |||||
asm volatile( | |||||
//! Prologue: load the first 8 A vectors and the first B vector; the
//! smull/smull2 for lanes 0-3 initialise the accumulators, so no explicit
//! zeroing is required.
"subs %w[K], %w[K], #8\n" | |||||
"ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%[a_ptr]], 64\n" | |||||
"ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%[a_ptr]], 64\n" | |||||
"ld1 {v0.4s}, [%[b_ptr]], %x[LDB]\n" | |||||
"smull v16.4s, v24.4h, v0.h[0]\n" | |||||
"smull2 v17.4s, v24.8h, v0.h[0]\n" | |||||
"smull v18.4s, v25.4h, v0.h[1]\n" | |||||
"smull2 v19.4s, v25.8h, v0.h[1]\n" | |||||
"smull v20.4s, v26.4h, v0.h[2]\n" | |||||
"smull2 v21.4s, v26.8h, v0.h[2]\n" | |||||
"smull v22.4s, v27.4h, v0.h[3]\n" | |||||
"smull2 v23.4s, v27.8h, v0.h[3]\n" | |||||
//! The flags tested here are still the ones set by the initial subs:
//! when K was exactly 8 the loop body is skipped.
"beq 2f\n" | |||||
"1:\n" | |||||
//! Steady state: finish lanes 4-7 of the previous step with the old v0
//! before v0 is overwritten, then start lanes 0-3 of the next step.
"ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%[a_ptr]], 64\n" | |||||
"smlal v16.4s, v28.4h, v0.h[4]\n" | |||||
"smlal2 v17.4s, v28.8h, v0.h[4]\n" | |||||
"smlal v18.4s, v29.4h, v0.h[5]\n" | |||||
"smlal2 v19.4s, v29.8h, v0.h[5]\n" | |||||
"smlal v20.4s, v30.4h, v0.h[6]\n" | |||||
"smlal2 v21.4s, v30.8h, v0.h[6]\n" | |||||
"smlal v22.4s, v31.4h, v0.h[7]\n" | |||||
"smlal2 v23.4s, v31.8h, v0.h[7]\n" | |||||
"ld1 {v0.4s}, [%[b_ptr]], %x[LDB]\n" | |||||
"ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%[a_ptr]], 64\n" | |||||
"smlal v16.4s, v24.4h, v0.h[0]\n" | |||||
"smlal2 v17.4s, v24.8h, v0.h[0]\n" | |||||
"smlal v18.4s, v25.4h, v0.h[1]\n" | |||||
"smlal2 v19.4s, v25.8h, v0.h[1]\n" | |||||
"smlal v20.4s, v26.4h, v0.h[2]\n" | |||||
"smlal2 v21.4s, v26.8h, v0.h[2]\n" | |||||
"smlal v22.4s, v27.4h, v0.h[3]\n" | |||||
"smlal2 v23.4s, v27.8h, v0.h[3]\n" | |||||
"subs %w[K], %w[K], #8\n" | |||||
"bne 1b\n" | |||||
"2:\n" | |||||
//! Epilogue: lanes 4-7 of the last step, then reduce the even-numbered
//! accumulators into v16 and the odd-numbered ones into v17.
"smlal v16.4s, v28.4h, v0.h[4]\n" | |||||
"smlal2 v17.4s, v28.8h, v0.h[4]\n" | |||||
"smlal v18.4s, v29.4h, v0.h[5]\n" | |||||
"smlal2 v19.4s, v29.8h, v0.h[5]\n" | |||||
"smlal v20.4s, v30.4h, v0.h[6]\n" | |||||
"smlal2 v21.4s, v30.8h, v0.h[6]\n" | |||||
"smlal v22.4s, v31.4h, v0.h[7]\n" | |||||
"smlal2 v23.4s, v31.8h, v0.h[7]\n" | |||||
"add v16.4s, v16.4s, v18.4s\n" | |||||
"add v20.4s, v20.4s, v22.4s\n" | |||||
"add v17.4s, v17.4s, v19.4s\n" | |||||
"add v21.4s, v21.4s, v23.4s\n" | |||||
"add v16.4s, v16.4s, v20.4s\n" | |||||
"add v17.4s, v17.4s, v21.4s\n" | |||||
"st1 {v16.4s, v17.4s}, [%[output]], 32\n" | |||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||||
[output] "+r"(output), [LDB] "+r"(LDB) | |||||
: | |||||
: "v0", "v16", "v17", "v18", "v19", "v20", "v21", | |||||
"v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||||
"v29", "v30", "v31", "cc", "memory"); | |||||
} | |||||
// Overview of register layout: | // Overview of register layout: | ||||
// | // | ||||
// A 8x1 cell of Lhs is stored in 16bit in v24-v27 | // A 8x1 cell of Lhs is stored in 16bit in v24-v27 | ||||
@@ -636,7 +712,7 @@ void gemm_nopack_s16_8x8::kern(const dt_int16* A, size_t LDA, const dt_int16* B, | |||||
constexpr static size_t NB = 8; | constexpr static size_t NB = 8; | ||||
constexpr static size_t CALCBLK = 4; | constexpr static size_t CALCBLK = 4; | ||||
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0 && N % CALCBLK == 0); | |||||
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0); | |||||
//! (m/8, k/8, 8, 8) * (k/8, n, 8) = (m/8, n, 8) | //! (m/8, k/8, 8, 8) * (k/8, n, 8) = (m/8, n, 8) | ||||
for (size_t m = 0; m < M; m += MB) { | for (size_t m = 0; m < M; m += MB) { | ||||
@@ -648,8 +724,17 @@ void gemm_nopack_s16_8x8::kern(const dt_int16* A, size_t LDA, const dt_int16* B, | |||||
cur_B += KB * NB; | cur_B += KB * NB; | ||||
output += MB * NB; | output += MB * NB; | ||||
} | } | ||||
if (n < N) { | |||||
if (N - n >= 4) { | |||||
kern_8x4(A, cur_B, LDB, K, output); | kern_8x4(A, cur_B, LDB, K, output); | ||||
cur_B += KB * CALCBLK; | |||||
output += MB * CALCBLK; | |||||
n += 4; | |||||
} | |||||
while (n < N) { | |||||
kern_8x1(A, cur_B, LDB, K, output); | |||||
cur_B += KB; | |||||
output += MB; | |||||
n++; | |||||
} | } | ||||
A += LDA; | A += LDA; | ||||
} | } | ||||
@@ -390,7 +390,7 @@ void winograd_2x3_8x8_f16::output(const dt_float16* output_transform_buf, | |||||
size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | ||||
size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | ||||
DISPATCH_CONV_WINOGRAD_BIAS( | DISPATCH_CONV_WINOGRAD_BIAS( | ||||
megdnn_arm_common_winograd_fp16_F23_8x8, cb, __fp16, __fp16, | |||||
megdnn_arm_common_winograd_f16_F23_8x8, cb, __fp16, __fp16, | |||||
bmode, nonline_mode, output_transform_buf, bias, output, | bmode, nonline_mode, output_transform_buf, bias, output, | ||||
transform_mid_buf, oh_start, ow_start, OH, OW, oc_start, | transform_mid_buf, oh_start, ow_start, OH, OW, oc_start, | ||||
oc_end, oc_index, unit_idx, nr_units_in_tile, src_dtype, dst_dtype); | oc_end, oc_index, unit_idx, nr_units_in_tile, src_dtype, dst_dtype); | ||||
@@ -875,8 +875,7 @@ bool MatrixMulImpl::AlgoF32MK4_4x8::usable( | |||||
kern_size_param.B_type == kern_size_param.A_type && | kern_size_param.B_type == kern_size_param.A_type && | ||||
kern_size_param.C_type == kern_size_param.A_type && | kern_size_param.C_type == kern_size_param.A_type && | ||||
kern_size_param.A_type == dtype::Float32() && | kern_size_param.A_type == dtype::Float32() && | ||||
kern_size_param.N % 4 == 0 && !kern_size_param.trA && | |||||
!kern_size_param.trB; | |||||
!kern_size_param.trA && !kern_size_param.trB; | |||||
} | } | ||||
size_t MatrixMulImpl::AlgoF32MK4_4x8::get_workspace( | size_t MatrixMulImpl::AlgoF32MK4_4x8::get_workspace( | ||||
@@ -911,8 +910,7 @@ bool MatrixMulImpl::AlgoInt16x16x32MK8_4x8::usable( | |||||
kern_size_param.A_type == dtype::Int16() && | kern_size_param.A_type == dtype::Int16() && | ||||
kern_size_param.B_type == dtype::Int16() && | kern_size_param.B_type == dtype::Int16() && | ||||
kern_size_param.C_type == dtype::Int32() && | kern_size_param.C_type == dtype::Int32() && | ||||
kern_size_param.N % 4 == 0 && !kern_size_param.trA && | |||||
!kern_size_param.trB; | |||||
!kern_size_param.trA && !kern_size_param.trB; | |||||
} | } | ||||
size_t MatrixMulImpl::AlgoInt16x16x32MK8_4x8::get_workspace( | size_t MatrixMulImpl::AlgoInt16x16x32MK8_4x8::get_workspace( | ||||
@@ -969,8 +967,7 @@ bool MatrixMulImpl::AlgoF16MK8_4x8::usable( | |||||
kern_size_param.B_type == kern_size_param.A_type && | kern_size_param.B_type == kern_size_param.A_type && | ||||
kern_size_param.A_type == dtype::Float16() && | kern_size_param.A_type == dtype::Float16() && | ||||
kern_size_param.format == param::MatrixMul::Format::MK8 && | kern_size_param.format == param::MatrixMul::Format::MK8 && | ||||
!kern_size_param.trA && !kern_size_param.trB && | |||||
kern_size_param.N % 4 == 0; | |||||
!kern_size_param.trA && !kern_size_param.trB; | |||||
} | } | ||||
size_t MatrixMulImpl::AlgoF16MK8_4x8::get_workspace( | size_t MatrixMulImpl::AlgoF16MK8_4x8::get_workspace( | ||||
@@ -21,6 +21,66 @@ using namespace armv7::matmul; | |||||
namespace { | namespace { | ||||
// Single-column fp16 micro-kernel (ARMv7 NEON with fp16 arithmetic).
//
// Each K step of 8 consumes 128 bytes of A (eight 8-lane fp16 vectors in
// q4-q11) and 8 fp16 values of B (d0 and d1). A vectors q4-q7 are scaled by
// the four lanes of d0 and q8-q11 by the four lanes of d1, accumulating into
// q12-q15. The four accumulators are added together at the end and 8 fp16
// results (16 bytes, d24-d25) are written to `output`.
//
// a_ptr  : A data, read sequentially through write-back loads.
// b_ptr  : B data; advanced by LDB elements in total per K step.
// LDB    : distance between consecutive B groups, in fp16 elements.
// K      : reduction length. It is decremented by 8 per step and the loop
//          only exits when it reaches exactly 0, so it has to be a positive
//          multiple of 8.
// output : destination for the 8 results.
void kern_8x1(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||||
dt_float16* output) { | |||||
//! B is fetched with two loads per step: the load of d0 advances b_ptr by
//! 4 fp16 through write-back, and the load of d1 post-increments by this
//! register. Subtracting those 4 elements (and converting to bytes) makes
//! the total advance per step equal to the original LDB.
LDB = (LDB - 4) * sizeof(dt_float16); | |||||
asm volatile( | |||||
//! Prologue: load B (d0, d1) and all eight A vectors; vmul with the d0
//! lanes initialises q12-q15, so no explicit zeroing is required.
"subs %[K], #8\n" | |||||
"vld1.32 {d0}, [%[b_ptr]]!\n" | |||||
"vld1.32 {d1}, [%[b_ptr]], %[LDB]\n" | |||||
"vld1.32 {d8, d9, d10, d11}, [%[a_ptr]]!\n" | |||||
"vld1.32 {d12, d13, d14, d15}, [%[a_ptr]]!\n" | |||||
"vld1.32 {d16, d17, d18, d19}, [%[a_ptr]]!\n" | |||||
"vld1.32 {d20, d21, d22, d23}, [%[a_ptr]]!\n" | |||||
"vmul.f16 q12, q4, d0[0]\n" | |||||
"vmul.f16 q13, q5, d0[1]\n" | |||||
"vmul.f16 q14, q6, d0[2]\n" | |||||
"vmul.f16 q15, q7, d0[3]\n" | |||||
//! The flags tested here are still the ones set by the initial subs:
//! when K was exactly 8 the loop body is skipped.
"beq 2f\n" | |||||
"1:\n" | |||||
//! Steady state: the d1 half of the previous step is interleaved with
//! the loads for the next step. d1 is only reloaded after its last use
//! (the q15 update), and q8-q11 are only reloaded after that point too.
"vmla.f16 q12, q8, d1[0]\n" | |||||
"vld1.32 {d0}, [%[b_ptr]]!\n" | |||||
"vmla.f16 q13, q9, d1[1]\n" | |||||
"vld1.32 {d8, d9, d10, d11}, [%[a_ptr]]!\n" | |||||
"vmla.f16 q14, q10, d1[2]\n" | |||||
"vld1.32 {d12, d13, d14, d15}, [%[a_ptr]]!\n" | |||||
"vmla.f16 q15, q11, d1[3]\n" | |||||
"vmla.f16 q12, q4, d0[0]\n" | |||||
"vld1.32 {d1}, [%[b_ptr]], %[LDB]\n" | |||||
"vmla.f16 q13, q5, d0[1]\n" | |||||
"vld1.32 {d16, d17, d18, d19}, [%[a_ptr]]!\n" | |||||
"vmla.f16 q14, q6, d0[2]\n" | |||||
"vld1.32 {d20, d21, d22, d23}, [%[a_ptr]]!\n" | |||||
"vmla.f16 q15, q7, d0[3]\n" | |||||
"subs %[K], #8\n" | |||||
"bne 1b\n" | |||||
"2:\n" | |||||
//! Epilogue: the d1 half of the last step, then reduce q12-q15 into q12.
"vmla.f16 q12, q8, d1[0]\n" | |||||
"vmla.f16 q13, q9, d1[1]\n" | |||||
"vmla.f16 q14, q10, d1[2]\n" | |||||
"vmla.f16 q15, q11, d1[3]\n" | |||||
"vadd.f16 q12, q12, q14\n" | |||||
"vadd.f16 q13, q13, q15\n" | |||||
"vadd.f16 q12, q12, q13\n" | |||||
"vst1.32 {d24, d25}, [%[output]]!\n" | |||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||||
[output] "+r"(output), [LDB] "+r"(LDB) | |||||
: | |||||
: "d0", "d1", "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", | |||||
"d16", "d17", "d18", "d19", "d20", "d21", "d22", "d23", "d24", | |||||
"d25", "d26", "d27", "d28", "d29", "d30", "d31", "cc", "memory"); | |||||
} | |||||
// Overview of register layout: | // Overview of register layout: | ||||
// | // | ||||
// A 8x1 cell of Rhs is stored in 16bit in v4-v11 | // A 8x1 cell of Rhs is stored in 16bit in v4-v11 | ||||
@@ -45,7 +105,7 @@ namespace { | |||||
// | v3[0-7]| |v15[0-7]| | // | v3[0-7]| |v15[0-7]| | ||||
// +--------+ +--------+--------+ | // +--------+ +--------+--------+ | ||||
// Accumulator | // Accumulator | ||||
void kern_4x8(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||||
void kern_8x4(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||||
dt_float16* output) { | dt_float16* output) { | ||||
//! As each load 64 number from B, but the pos add 48 * 2, so we minus 48 | //! As each load 64 number from B, but the pos add 48 * 2, so we minus 48 | ||||
//! here. | //! here. | ||||
@@ -179,19 +239,25 @@ void gemm_nopack_f16_4x8::kern(const dt_float16* A, size_t LDA, | |||||
constexpr static size_t MB = 8; | constexpr static size_t MB = 8; | ||||
constexpr static size_t KB = 8; | constexpr static size_t KB = 8; | ||||
constexpr static size_t NB = 4; | constexpr static size_t NB = 4; | ||||
constexpr static size_t CALCBLK = 4; | |||||
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0 && N % CALCBLK == 0); | |||||
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0); | |||||
//! (m/8, k/8, 8, 8) * (k/8, n, 8) = (m/8, n, 8) | //! (m/8, k/8, 8, 8) * (k/8, n, 8) = (m/8, n, 8) | ||||
for (size_t m = 0; m < M; m += MB) { | for (size_t m = 0; m < M; m += MB) { | ||||
dt_float16* output = C + (m / MB) * LDC; | dt_float16* output = C + (m / MB) * LDC; | ||||
const dt_float16* cur_B = B; | const dt_float16* cur_B = B; | ||||
for (size_t n = 0; n < N; n += NB) { | |||||
kern_4x8(A, cur_B, LDB, K, output); | |||||
size_t n = 0; | |||||
for (; n + NB - 1 < N; n += NB) { | |||||
kern_8x4(A, cur_B, LDB, K, output); | |||||
cur_B += KB * NB; | cur_B += KB * NB; | ||||
output += MB * NB; | output += MB * NB; | ||||
} | } | ||||
while (n < N) { | |||||
kern_8x1(A, cur_B, LDB, K, output); | |||||
cur_B += KB; | |||||
output += MB; | |||||
n++; | |||||
} | |||||
A += LDA; | A += LDA; | ||||
} | } | ||||
} | } | ||||
@@ -20,6 +20,58 @@ using namespace armv7::matmul; | |||||
namespace { | namespace { | ||||
// Single-column fp32 micro-kernel (ARMv7 NEON).
//
// Each K step of 4 consumes 64 bytes of A (four 4-lane float vectors in
// q4-q7) and 4 floats of B (d0-d1). q4/q5 are scaled by the two lanes of d0
// and q6/q7 by the two lanes of d1, accumulating into q8-q11. The four
// accumulators are added together at the end and 4 floats (16 bytes,
// d16-d17) are written to C.
//
// A   : A data, read sequentially through write-back loads.
// B   : B data; advanced by LDB elements in total per K step.
// LDB : distance between consecutive B groups, in float elements.
// K   : reduction length. It is decremented by 4 per step and the loop only
//       exits when it reaches exactly 0, so it has to be a positive multiple
//       of 4.
// C   : destination for the 4 results.
void kern_4x1(const float* A, const float* B, size_t LDB, size_t K, float* C) { | |||||
//! Every B load advances the pointer by 4 floats through write-back; the
//! loop then adds this register on top, so subtract those 4 elements and
//! convert to bytes to land on the next B group.
LDB = (LDB - 4) * sizeof(float); | |||||
asm volatile( | |||||
//! Prologue: load the first 4 A vectors, clear the accumulators, load the
//! first B vector and issue the multiply-accumulates for the d0 lanes.
"subs %[K], %[K], #4\n" | |||||
"vld1.32 {d8-d11}, [%[A]]!\n" | |||||
"vld1.32 {d12-d15}, [%[A]]!\n" | |||||
"veor q8, q8 \n" | |||||
"veor q9, q9 \n" | |||||
"veor q10, q10 \n" | |||||
"veor q11, q11 \n" | |||||
"vld1.32 {d0-d1}, [%[B]]!\n" | |||||
"vmla.f32 q8, q4, d0[0]\n" | |||||
"vmla.f32 q9, q5, d0[1]\n" | |||||
//! The flags tested here are still the ones set by the initial subs:
//! when K was exactly 4 the loop body is skipped.
"beq 2f\n" | |||||
"1:\n" | |||||
//! Steady state: finish the d1 lanes of the previous step before d0-d1
//! are overwritten, step B to the next group, then start the d0 lanes of
//! the next step.
"vld1.32 {d8-d11}, [%[A]]!\n" | |||||
"vmla.f32 q10, q6, d1[0]\n" | |||||
"vmla.f32 q11, q7, d1[1]\n" | |||||
"add %[B], %[B], %[LDB]\n" | |||||
"vld1.32 {d0-d1}, [%[B]]!\n" | |||||
"vld1.32 {d12-d15}, [%[A]]!\n" | |||||
"vmla.f32 q8, q4, d0[0]\n" | |||||
"vmla.f32 q9, q5, d0[1]\n" | |||||
"subs %[K], %[K], #4\n" | |||||
"bne 1b\n" | |||||
"2:\n" | |||||
//! Epilogue: the d1 lanes of the last step, then reduce q8-q11 into q8.
"vmla.f32 q10, q6, d1[0]\n" | |||||
"vmla.f32 q11, q7, d1[1]\n" | |||||
"vadd.f32 q8, q8, q10\n" | |||||
"vadd.f32 q9, q9, q11\n" | |||||
"vadd.f32 q8, q8, q9\n" | |||||
"vst1.32 {d16, d17}, [%[C]]!\n" | |||||
//! LDB is never written by the asm, so it is passed as an input operand.
: [ A ] "+r"(A), [ B ] "+r"(B), [ K ] "+r"(K), [ C ] "+r"(C) | |||||
: [ LDB ] "r"(LDB) | |||||
: "d0", "d1", "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", | |||||
"d16", "d17", "d18", "d19", "d20", "d21", "d22", "d23", "cc", | |||||
"memory"); | |||||
} | |||||
// Overview of register layout: | // Overview of register layout: | ||||
// | // | ||||
// A 8x4 cell of Rhs is stored in 32bit in q0-q3, load 4 register each time | // A 8x4 cell of Rhs is stored in 32bit in q0-q3, load 4 register each time | ||||
@@ -268,9 +320,9 @@ void sgemm_nopack_4x8::kern(const float* A, size_t LDA, const float* B, | |||||
constexpr size_t MB = 4; | constexpr size_t MB = 4; | ||||
constexpr size_t KB = 4; | constexpr size_t KB = 4; | ||||
constexpr size_t NB = 8; | constexpr size_t NB = 8; | ||||
constexpr size_t CALCBLK = 4; | |||||
constexpr size_t NB_HALF = 4; | |||||
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0 && N % CALCBLK == 0); | |||||
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0); | |||||
//! (m/8, k/8, 8, 8) * (k/8, n, 8) = (m/8, n, 8) | //! (m/8, k/8, 8, 8) * (k/8, n, 8) = (m/8, n, 8) | ||||
for (size_t m = 0; m < M; m += MB) { | for (size_t m = 0; m < M; m += MB) { | ||||
@@ -282,8 +334,17 @@ void sgemm_nopack_4x8::kern(const float* A, size_t LDA, const float* B, | |||||
cur_B += KB * NB; | cur_B += KB * NB; | ||||
output += MB * NB; | output += MB * NB; | ||||
} | } | ||||
if (n < N) { | |||||
if (N - n >= 4) { | |||||
kern_4x4(A, cur_B, LDB, K, output); | kern_4x4(A, cur_B, LDB, K, output); | ||||
cur_B += KB * NB_HALF; | |||||
output += MB * NB_HALF; | |||||
n += 4; | |||||
} | |||||
while (n < N) { | |||||
kern_4x1(A, cur_B, LDB, K, output); | |||||
cur_B += KB; | |||||
output += MB; | |||||
n++; | |||||
} | } | ||||
A += LDA; | A += LDA; | ||||
} | } | ||||
@@ -20,6 +20,91 @@ using namespace armv7::matmul; | |||||
namespace { | namespace { | ||||
// Single-column int16 x int16 -> int32 micro-kernel (ARMv7 NEON).
//
// Each K step of 8 consumes 128 bytes of A (sixteen 4-lane int16 d-registers,
// d8-d23) and 8 int16 values of B (d0 and d1). Consecutive pairs of A
// d-registers are multiplied by the same B lane with widening to int32: the
// first register of each pair accumulates into q12 or q14, the second into
// q13 or q15. At the end q14 is added into q12 and q15 into q13, and
// 8 int32 values (32 bytes, d24-d27) are written to `output`.
//
// a_ptr  : A data, read sequentially through write-back loads.
// b_ptr  : B data; advanced by LDB elements in total per K step.
// LDB    : distance between consecutive B groups, in int16 elements.
// K      : reduction length. It is decremented by 8 per step and the loop
//          only exits when it reaches exactly 0, so it has to be a positive
//          multiple of 8.
// output : destination for the 8 int32 results.
void kern_8x1(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||||
dt_int32* output) { | |||||
//! B is fetched with two loads per step: the load of d0 advances b_ptr by
//! 4 int16 through write-back, and the load of d1 post-increments by this
//! register. Subtracting those 4 elements (and converting to bytes) makes
//! the total advance per step equal to the original LDB.
LDB = (LDB - 4) * sizeof(dt_int16); | |||||
asm volatile( | |||||
//! Prologue: load all sixteen A d-registers and B (d0, d1); vmull with
//! d0[0]/d0[1] initialises q12-q15, so no explicit zeroing is required.
"subs %[K], #8\n" | |||||
"vld1.32 {d8, d9, d10, d11}, [%[a_ptr]]!\n" | |||||
"vld1.32 {d12, d13, d14, d15}, [%[a_ptr]]!\n" | |||||
"vld1.32 {d16, d17, d18, d19}, [%[a_ptr]]!\n" | |||||
"vld1.32 {d20, d21, d22, d23}, [%[a_ptr]]!\n" | |||||
"vld1.32 {d0}, [%[b_ptr]]!\n" | |||||
"vld1.32 {d1}, [%[b_ptr]], %[LDB]\n" | |||||
"vmull.s16 q12, d8, d0[0]\n" | |||||
"vmull.s16 q13, d9, d0[0]\n" | |||||
"vmull.s16 q14, d10, d0[1]\n" | |||||
"vmull.s16 q15, d11, d0[1]\n" | |||||
"vmlal.s16 q12, d12, d0[2]\n" | |||||
"vmlal.s16 q13, d13, d0[2]\n" | |||||
"vmlal.s16 q14, d14, d0[3]\n" | |||||
"vmlal.s16 q15, d15, d0[3]\n" | |||||
//! The flags tested here are still the ones set by the initial subs:
//! when K was exactly 8 the loop body is skipped.
"beq 2f\n" | |||||
"1:\n" | |||||
//! Steady state: reload d8-d15 and d0 for the next step, finish the d1
//! half of the previous step, and only then overwrite d1 and d16-d23.
"vld1.32 {d8, d9, d10, d11}, [%[a_ptr]]!\n" | |||||
"vld1.32 {d12, d13, d14, d15}, [%[a_ptr]]!\n" | |||||
"vld1.32 {d0}, [%[b_ptr]]!\n" | |||||
"vmlal.s16 q12, d16, d1[0]\n" | |||||
"vmlal.s16 q13, d17, d1[0]\n" | |||||
"vmlal.s16 q14, d18, d1[1]\n" | |||||
"vmlal.s16 q15, d19, d1[1]\n" | |||||
"vmlal.s16 q12, d20, d1[2]\n" | |||||
"vmlal.s16 q13, d21, d1[2]\n" | |||||
"vmlal.s16 q14, d22, d1[3]\n" | |||||
"vmlal.s16 q15, d23, d1[3]\n" | |||||
"vld1.32 {d1}, [%[b_ptr]], %[LDB]\n" | |||||
"vld1.32 {d16, d17, d18, d19}, [%[a_ptr]]!\n" | |||||
"vld1.32 {d20, d21, d22, d23}, [%[a_ptr]]!\n" | |||||
"vmlal.s16 q12, d8, d0[0]\n" | |||||
"vmlal.s16 q13, d9, d0[0]\n" | |||||
"vmlal.s16 q14, d10, d0[1]\n" | |||||
"vmlal.s16 q15, d11, d0[1]\n" | |||||
"vmlal.s16 q12, d12, d0[2]\n" | |||||
"vmlal.s16 q13, d13, d0[2]\n" | |||||
"vmlal.s16 q14, d14, d0[3]\n" | |||||
"vmlal.s16 q15, d15, d0[3]\n" | |||||
"subs %[K], %[K], #8\n" | |||||
"bne 1b\n" | |||||
"2:\n" | |||||
//! Epilogue: the d1 half of the last step, then fold q14/q15 into
//! q12/q13 and store the 8 int32 results.
"vmlal.s16 q12, d16, d1[0]\n" | |||||
"vmlal.s16 q13, d17, d1[0]\n" | |||||
"vmlal.s16 q14, d18, d1[1]\n" | |||||
"vmlal.s16 q15, d19, d1[1]\n" | |||||
"vmlal.s16 q12, d20, d1[2]\n" | |||||
"vmlal.s16 q13, d21, d1[2]\n" | |||||
"vmlal.s16 q14, d22, d1[3]\n" | |||||
"vmlal.s16 q15, d23, d1[3]\n" | |||||
"vadd.s32 q12, q12, q14\n" | |||||
"vadd.s32 q13, q13, q15\n" | |||||
"vst1.32 {d24, d25, d26, d27}, [%[output]]!\n" | |||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||||
[output] "+r"(output), [LDB] "+r"(LDB) | |||||
: | |||||
: "d0", "d1", "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", | |||||
"d16", "d17", "d18", "d19", "d20", "d21", "d22", "d23", "d24", | |||||
"d25", "d26", "d27", "d28", "d29", "d30", "d31", "cc", "memory"); | |||||
} | |||||
// Overview of register layout: | // Overview of register layout: | ||||
// | // | ||||
// A 4x8 cell of Rhs is stored in 16bit in q0-q3 | // A 4x8 cell of Rhs is stored in 16bit in q0-q3 | ||||
@@ -40,7 +125,7 @@ namespace { | |||||
// | q3[0-7]| |q14[0-3]|v15[0-3]| | // | q3[0-7]| |q14[0-3]|v15[0-3]| | ||||
// +--------+ +--------+--------+ | // +--------+ +--------+--------+ | ||||
// Accumulator | // Accumulator | ||||
void kern_4x8(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||||
void kern_8x4(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||||
dt_int32* output) { | dt_int32* output) { | ||||
//! As each load 16 number from B, but the pos add 16 * 2, so we minus 16 | //! As each load 16 number from B, but the pos add 16 * 2, so we minus 16 | ||||
//! here. | //! here. | ||||
@@ -247,19 +332,25 @@ void gemm_nopack_s16_4x8::kern(const dt_int16* A, size_t LDA, const dt_int16* B, | |||||
constexpr static size_t MB = 8; | constexpr static size_t MB = 8; | ||||
constexpr static size_t KB = 8; | constexpr static size_t KB = 8; | ||||
constexpr static size_t NB = 4; | constexpr static size_t NB = 4; | ||||
constexpr static size_t CALCBLK = 4; | |||||
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0 && N % CALCBLK == 0); | |||||
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0); | |||||
//! (m/8, k/8, 8, 8) * (k/8, n, 8) = (m/8, n, 8) | //! (m/8, k/8, 8, 8) * (k/8, n, 8) = (m/8, n, 8) | ||||
for (size_t m = 0; m < M; m += MB) { | for (size_t m = 0; m < M; m += MB) { | ||||
dt_int32* output = C + (m / MB) * LDC; | dt_int32* output = C + (m / MB) * LDC; | ||||
const dt_int16* cur_B = B; | const dt_int16* cur_B = B; | ||||
for (size_t n = 0; n < N; n += NB) { | |||||
kern_4x8(A, cur_B, LDB, K, output); | |||||
size_t n = 0; | |||||
for (; n + NB - 1 < N; n += NB) { | |||||
kern_8x4(A, cur_B, LDB, K, output); | |||||
cur_B += KB * NB; | cur_B += KB * NB; | ||||
output += MB * NB; | output += MB * NB; | ||||
} | } | ||||
while (n < N) { | |||||
kern_8x1(A, cur_B, LDB, K, output); | |||||
cur_B += KB; | |||||
output += MB; | |||||
n++; | |||||
} | |||||
A += LDA; | A += LDA; | ||||
} | } | ||||
} | } | ||||
@@ -427,9 +427,6 @@ public: | |||||
"The winograd remain oc is not times of OC_BLOCK_SIZE"); | "The winograd remain oc is not times of OC_BLOCK_SIZE"); | ||||
if (format == param::MatrixMul::Format::MK4 || | if (format == param::MatrixMul::Format::MK4 || | ||||
format == param::MatrixMul::Format::MK8) { | format == param::MatrixMul::Format::MK8) { | ||||
#if !MEGDNN_X86 | |||||
nr_tiles_in_unit = round_up<size_t>(nr_tiles_in_unit, 4); | |||||
#endif | |||||
megdnn_assert(nr_tiles_in_unit <= unit_tile_size, | megdnn_assert(nr_tiles_in_unit <= unit_tile_size, | ||||
"nr_tiles_in_unit: %zu TILE_SIZE:%zu", | "nr_tiles_in_unit: %zu TILE_SIZE:%zu", | ||||
nr_tiles_in_unit, unit_tile_size); | nr_tiles_in_unit, unit_tile_size); | ||||
@@ -38,10 +38,9 @@ TEST_F(AARCH64, MATRIX_MUL_FP32_PACK_MK4) { | |||||
} | } | ||||
TEST_F(AARCH64, MATRIX_MUL_FP32_MK4) { | TEST_F(AARCH64, MATRIX_MUL_FP32_MK4) { | ||||
//! nbase should be 4 in order to test the last rest 4 in N dim | |||||
matrix_mul::check_matrix_mul( | matrix_mul::check_matrix_mul( | ||||
dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), | dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), | ||||
"AARCH64_F32_MK4_4x16", param::MatrixMul::Format::MK4, 4); | |||||
"AARCH64_F32_MK4_4x16", param::MatrixMul::Format::MK4, 1); | |||||
} | } | ||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
@@ -52,10 +51,9 @@ TEST_F(AARCH64, MATRIX_MUL_F16_K8X24X1) { | |||||
} | } | ||||
TEST_F(AARCH64, MATRIX_MUL_F16_MK8) { | TEST_F(AARCH64, MATRIX_MUL_F16_MK8) { | ||||
//! nbase should be 4 in order to test the last rest 4 in N dim | |||||
matrix_mul::check_matrix_mul( | matrix_mul::check_matrix_mul( | ||||
dtype::Float16{}, dtype::Float16{}, dtype::Float16{}, handle(), | dtype::Float16{}, dtype::Float16{}, dtype::Float16{}, handle(), | ||||
"AARCH64_F16_MK8_8X8", param::MatrixMul::Format::MK8, 4); | |||||
"AARCH64_F16_MK8_8X8", param::MatrixMul::Format::MK8, 1); | |||||
} | } | ||||
#endif | #endif | ||||
@@ -116,10 +114,9 @@ TEST_F(AARCH64, MATRIX_MUL_INT16x16x32_K12X8X1) { | |||||
} | } | ||||
TEST_F(AARCH64, MATRIX_MUL_INT16x16x32_MK8) { | TEST_F(AARCH64, MATRIX_MUL_INT16x16x32_MK8) { | ||||
//! nbase should be 4 in order to test the last rest 4 in N dim | |||||
matrix_mul::check_matrix_mul(dtype::Int16{}, dtype::Int16{}, dtype::Int32{}, | matrix_mul::check_matrix_mul(dtype::Int16{}, dtype::Int16{}, dtype::Int32{}, | ||||
handle(), "AARCH64_INT16X16X32_MK8_8X8", | handle(), "AARCH64_INT16X16X32_MK8_8X8", | ||||
param::MatrixMul::Format::MK8, 4); | |||||
param::MatrixMul::Format::MK8, 1); | |||||
} | } | ||||
//! FIXME: need to add tests of GEMV and QUINT8 | //! FIXME: need to add tests of GEMV and QUINT8 | ||||
@@ -26,7 +26,7 @@ TEST_F(ARMV7, MATRIX_MUL) { | |||||
TEST_F(ARMV7, MATRIX_MUL_MK4) { | TEST_F(ARMV7, MATRIX_MUL_MK4) { | ||||
matrix_mul::check_matrix_mul( | matrix_mul::check_matrix_mul( | ||||
dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), | dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), | ||||
"ARMV7_F32_MK4_4x8", param::MatrixMul::Format::MK4, 4); | |||||
"ARMV7_F32_MK4_4x8", param::MatrixMul::Format::MK4, 1); | |||||
} | } | ||||
TEST_F(ARMV7, MATRIX_MUL_PACK_MK4) { | TEST_F(ARMV7, MATRIX_MUL_PACK_MK4) { | ||||
@@ -66,7 +66,7 @@ TEST_F(ARMV7, MATRIX_MUL_INT16x16x32) { | |||||
TEST_F(ARMV7, MATRIX_MUL_INT16x16x32_MK8) { | TEST_F(ARMV7, MATRIX_MUL_INT16x16x32_MK8) { | ||||
matrix_mul::check_matrix_mul(dtype::Int16{}, dtype::Int16{}, dtype::Int32{}, | matrix_mul::check_matrix_mul(dtype::Int16{}, dtype::Int16{}, dtype::Int32{}, | ||||
handle(), "ARMV7_INT16X16X32_MK8_4X8", | handle(), "ARMV7_INT16X16X32_MK8_4X8", | ||||
param::MatrixMul::Format::MK8, 4); | |||||
param::MatrixMul::Format::MK8, 1); | |||||
} | } | ||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
@@ -78,7 +78,7 @@ TEST_F(ARMV7, MATRIX_MUL_FP16) { | |||||
TEST_F(ARMV7, MATRIX_MUL_F16_MK8) { | TEST_F(ARMV7, MATRIX_MUL_F16_MK8) { | ||||
matrix_mul::check_matrix_mul( | matrix_mul::check_matrix_mul( | ||||
dtype::Float16{}, dtype::Float16{}, dtype::Float16{}, handle(), | dtype::Float16{}, dtype::Float16{}, dtype::Float16{}, handle(), | ||||
"AARCH32_F16_MK8_4X8", param::MatrixMul::Format::MK8, 4); | |||||
"AARCH32_F16_MK8_4X8", param::MatrixMul::Format::MK8, 1); | |||||
} | } | ||||
#endif | #endif | ||||