|
|
@@ -19,6 +19,17 @@ using namespace matmul::fallback; |
|
|
|
|
|
|
|
namespace { |
|
|
|
|
|
|
|
#undef PREFER_VF |
|
|
|
#if defined(GI_TARGET_X86) || defined(GI_RVV_INTRINSICS) |
|
|
|
#define PREFER_VF |
|
|
|
#endif |
|
|
|
|
|
|
|
#if defined(PREFER_VF) |
|
|
|
#define MLA(a, b, c, d) GiMultiplyAddScalarFloat32(a, b, *(c + d)) |
|
|
|
#else |
|
|
|
#define MLA(a, b, c, d) GiSimdFmaLane(a, b, c, d) |
|
|
|
#endif |
|
|
|
|
|
|
|
void kern_4x12( |
|
|
|
const float* packA, const float* packB, int K, float* output, int LDC, |
|
|
|
bool is_first_k, int m_remain) { |
|
|
@@ -32,8 +43,13 @@ void kern_4x12( |
|
|
|
float* r2 = r1 + LDC; |
|
|
|
float* r3 = r2 + LDC; |
|
|
|
|
|
|
|
GI_FLOAT32_t d0d1, d2d3, d4d5, d6d7, d8d9, d10d11, d12d13, d14d15, d16d17, d18d19, |
|
|
|
d20d21, d22d23, d24d25, d26d27, d28d29, d30d31; |
|
|
|
#if defined(PREFER_VF) |
|
|
|
const float* d0d1; |
|
|
|
#else |
|
|
|
GI_FLOAT32_t d0d1; |
|
|
|
#endif |
|
|
|
GI_FLOAT32_t d2d3, d4d5, d6d7, d8d9, d10d11, d12d13, d14d15, d16d17, d18d19, d20d21, |
|
|
|
d22d23, d24d25, d26d27, d28d29, d30d31; |
|
|
|
|
|
|
|
if (is_first_k) { |
|
|
|
d8d9 = GiBroadcastFloat32(0.0f); |
|
|
@@ -99,23 +115,31 @@ void kern_4x12( |
|
|
|
b_ptr = b_ptr + 4; |
|
|
|
|
|
|
|
for (; K > 0; K--) { |
|
|
|
#if defined(PREFER_VF) |
|
|
|
d0d1 = a_ptr; |
|
|
|
#else |
|
|
|
d0d1 = GiLoadFloat32(a_ptr); |
|
|
|
#endif |
|
|
|
a_ptr = a_ptr + 4; |
|
|
|
|
|
|
|
d8d9 = GiSimdFmaLane(d8d9, d2d3, d0d1, 0); |
|
|
|
d10d11 = GiSimdFmaLane(d10d11, d4d5, d0d1, 0); |
|
|
|
d12d13 = GiSimdFmaLane(d12d13, d6d7, d0d1, 0); |
|
|
|
d14d15 = GiSimdFmaLane(d14d15, d2d3, d0d1, 1); |
|
|
|
d16d17 = GiSimdFmaLane(d16d17, d4d5, d0d1, 1); |
|
|
|
d18d19 = GiSimdFmaLane(d18d19, d6d7, d0d1, 1); |
|
|
|
d20d21 = GiSimdFmaLane(d20d21, d2d3, d0d1, 2); |
|
|
|
d22d23 = GiSimdFmaLane(d22d23, d4d5, d0d1, 2); |
|
|
|
d24d25 = GiSimdFmaLane(d24d25, d6d7, d0d1, 2); |
|
|
|
d26d27 = GiSimdFmaLane(d26d27, d2d3, d0d1, 3); |
|
|
|
d28d29 = GiSimdFmaLane(d28d29, d4d5, d0d1, 3); |
|
|
|
d30d31 = GiSimdFmaLane(d30d31, d6d7, d0d1, 3); |
|
|
|
|
|
|
|
d8d9 = MLA(d8d9, d2d3, d0d1, 0); |
|
|
|
d10d11 = MLA(d10d11, d4d5, d0d1, 0); |
|
|
|
d12d13 = MLA(d12d13, d6d7, d0d1, 0); |
|
|
|
d14d15 = MLA(d14d15, d2d3, d0d1, 1); |
|
|
|
d16d17 = MLA(d16d17, d4d5, d0d1, 1); |
|
|
|
d18d19 = MLA(d18d19, d6d7, d0d1, 1); |
|
|
|
d20d21 = MLA(d20d21, d2d3, d0d1, 2); |
|
|
|
d22d23 = MLA(d22d23, d4d5, d0d1, 2); |
|
|
|
d24d25 = MLA(d24d25, d6d7, d0d1, 2); |
|
|
|
d26d27 = MLA(d26d27, d2d3, d0d1, 3); |
|
|
|
d28d29 = MLA(d28d29, d4d5, d0d1, 3); |
|
|
|
d30d31 = MLA(d30d31, d6d7, d0d1, 3); |
|
|
|
|
|
|
|
#if defined(PREFER_VF) |
|
|
|
d0d1 = a_ptr; |
|
|
|
#else |
|
|
|
d0d1 = GiLoadFloat32(a_ptr); |
|
|
|
#endif |
|
|
|
a_ptr = a_ptr + 4; |
|
|
|
d2d3 = GiLoadFloat32(b_ptr); |
|
|
|
b_ptr = b_ptr + 4; |
|
|
@@ -124,18 +148,18 @@ void kern_4x12( |
|
|
|
d6d7 = GiLoadFloat32(b_ptr); |
|
|
|
b_ptr = b_ptr + 4; |
|
|
|
|
|
|
|
d8d9 = GiSimdFmaLane(d8d9, d2d3, d0d1, 0); |
|
|
|
d10d11 = GiSimdFmaLane(d10d11, d4d5, d0d1, 0); |
|
|
|
d12d13 = GiSimdFmaLane(d12d13, d6d7, d0d1, 0); |
|
|
|
d14d15 = GiSimdFmaLane(d14d15, d2d3, d0d1, 1); |
|
|
|
d16d17 = GiSimdFmaLane(d16d17, d4d5, d0d1, 1); |
|
|
|
d18d19 = GiSimdFmaLane(d18d19, d6d7, d0d1, 1); |
|
|
|
d20d21 = GiSimdFmaLane(d20d21, d2d3, d0d1, 2); |
|
|
|
d22d23 = GiSimdFmaLane(d22d23, d4d5, d0d1, 2); |
|
|
|
d24d25 = GiSimdFmaLane(d24d25, d6d7, d0d1, 2); |
|
|
|
d26d27 = GiSimdFmaLane(d26d27, d2d3, d0d1, 3); |
|
|
|
d28d29 = GiSimdFmaLane(d28d29, d4d5, d0d1, 3); |
|
|
|
d30d31 = GiSimdFmaLane(d30d31, d6d7, d0d1, 3); |
|
|
|
d8d9 = MLA(d8d9, d2d3, d0d1, 0); |
|
|
|
d10d11 = MLA(d10d11, d4d5, d0d1, 0); |
|
|
|
d12d13 = MLA(d12d13, d6d7, d0d1, 0); |
|
|
|
d14d15 = MLA(d14d15, d2d3, d0d1, 1); |
|
|
|
d16d17 = MLA(d16d17, d4d5, d0d1, 1); |
|
|
|
d18d19 = MLA(d18d19, d6d7, d0d1, 1); |
|
|
|
d20d21 = MLA(d20d21, d2d3, d0d1, 2); |
|
|
|
d22d23 = MLA(d22d23, d4d5, d0d1, 2); |
|
|
|
d24d25 = MLA(d24d25, d6d7, d0d1, 2); |
|
|
|
d26d27 = MLA(d26d27, d2d3, d0d1, 3); |
|
|
|
d28d29 = MLA(d28d29, d4d5, d0d1, 3); |
|
|
|
d30d31 = MLA(d30d31, d6d7, d0d1, 3); |
|
|
|
|
|
|
|
d2d3 = GiLoadFloat32(b_ptr); |
|
|
|
b_ptr = b_ptr + 4; |
|
|
@@ -146,40 +170,52 @@ void kern_4x12( |
|
|
|
} |
|
|
|
|
|
|
|
if (1 == oddk) { |
|
|
|
#if defined(PREFER_VF) |
|
|
|
d0d1 = a_ptr; |
|
|
|
#else |
|
|
|
d0d1 = GiLoadFloat32(a_ptr); |
|
|
|
#endif |
|
|
|
a_ptr = a_ptr + 4; |
|
|
|
|
|
|
|
d8d9 = GiSimdFmaLane(d8d9, d2d3, d0d1, 0); |
|
|
|
d10d11 = GiSimdFmaLane(d10d11, d4d5, d0d1, 0); |
|
|
|
d12d13 = GiSimdFmaLane(d12d13, d6d7, d0d1, 0); |
|
|
|
d14d15 = GiSimdFmaLane(d14d15, d2d3, d0d1, 1); |
|
|
|
d16d17 = GiSimdFmaLane(d16d17, d4d5, d0d1, 1); |
|
|
|
d18d19 = GiSimdFmaLane(d18d19, d6d7, d0d1, 1); |
|
|
|
d20d21 = GiSimdFmaLane(d20d21, d2d3, d0d1, 2); |
|
|
|
d22d23 = GiSimdFmaLane(d22d23, d4d5, d0d1, 2); |
|
|
|
d24d25 = GiSimdFmaLane(d24d25, d6d7, d0d1, 2); |
|
|
|
d26d27 = GiSimdFmaLane(d26d27, d2d3, d0d1, 3); |
|
|
|
d28d29 = GiSimdFmaLane(d28d29, d4d5, d0d1, 3); |
|
|
|
d30d31 = GiSimdFmaLane(d30d31, d6d7, d0d1, 3); |
|
|
|
d8d9 = MLA(d8d9, d2d3, d0d1, 0); |
|
|
|
d10d11 = MLA(d10d11, d4d5, d0d1, 0); |
|
|
|
d12d13 = MLA(d12d13, d6d7, d0d1, 0); |
|
|
|
d14d15 = MLA(d14d15, d2d3, d0d1, 1); |
|
|
|
d16d17 = MLA(d16d17, d4d5, d0d1, 1); |
|
|
|
d18d19 = MLA(d18d19, d6d7, d0d1, 1); |
|
|
|
d20d21 = MLA(d20d21, d2d3, d0d1, 2); |
|
|
|
d22d23 = MLA(d22d23, d4d5, d0d1, 2); |
|
|
|
d24d25 = MLA(d24d25, d6d7, d0d1, 2); |
|
|
|
d26d27 = MLA(d26d27, d2d3, d0d1, 3); |
|
|
|
d28d29 = MLA(d28d29, d4d5, d0d1, 3); |
|
|
|
d30d31 = MLA(d30d31, d6d7, d0d1, 3); |
|
|
|
|
|
|
|
} else { |
|
|
|
#if defined(PREFER_VF) |
|
|
|
d0d1 = a_ptr; |
|
|
|
#else |
|
|
|
d0d1 = GiLoadFloat32(a_ptr); |
|
|
|
#endif |
|
|
|
a_ptr = a_ptr + 4; |
|
|
|
|
|
|
|
d8d9 = GiSimdFmaLane(d8d9, d2d3, d0d1, 0); |
|
|
|
d10d11 = GiSimdFmaLane(d10d11, d4d5, d0d1, 0); |
|
|
|
d12d13 = GiSimdFmaLane(d12d13, d6d7, d0d1, 0); |
|
|
|
d14d15 = GiSimdFmaLane(d14d15, d2d3, d0d1, 1); |
|
|
|
d16d17 = GiSimdFmaLane(d16d17, d4d5, d0d1, 1); |
|
|
|
d18d19 = GiSimdFmaLane(d18d19, d6d7, d0d1, 1); |
|
|
|
d20d21 = GiSimdFmaLane(d20d21, d2d3, d0d1, 2); |
|
|
|
d22d23 = GiSimdFmaLane(d22d23, d4d5, d0d1, 2); |
|
|
|
d24d25 = GiSimdFmaLane(d24d25, d6d7, d0d1, 2); |
|
|
|
d26d27 = GiSimdFmaLane(d26d27, d2d3, d0d1, 3); |
|
|
|
d28d29 = GiSimdFmaLane(d28d29, d4d5, d0d1, 3); |
|
|
|
d30d31 = GiSimdFmaLane(d30d31, d6d7, d0d1, 3); |
|
|
|
|
|
|
|
d8d9 = MLA(d8d9, d2d3, d0d1, 0); |
|
|
|
d10d11 = MLA(d10d11, d4d5, d0d1, 0); |
|
|
|
d12d13 = MLA(d12d13, d6d7, d0d1, 0); |
|
|
|
d14d15 = MLA(d14d15, d2d3, d0d1, 1); |
|
|
|
d16d17 = MLA(d16d17, d4d5, d0d1, 1); |
|
|
|
d18d19 = MLA(d18d19, d6d7, d0d1, 1); |
|
|
|
d20d21 = MLA(d20d21, d2d3, d0d1, 2); |
|
|
|
d22d23 = MLA(d22d23, d4d5, d0d1, 2); |
|
|
|
d24d25 = MLA(d24d25, d6d7, d0d1, 2); |
|
|
|
d26d27 = MLA(d26d27, d2d3, d0d1, 3); |
|
|
|
d28d29 = MLA(d28d29, d4d5, d0d1, 3); |
|
|
|
d30d31 = MLA(d30d31, d6d7, d0d1, 3); |
|
|
|
|
|
|
|
#if defined(PREFER_VF) |
|
|
|
d0d1 = a_ptr; |
|
|
|
#else |
|
|
|
d0d1 = GiLoadFloat32(a_ptr); |
|
|
|
#endif |
|
|
|
a_ptr = a_ptr + 4; |
|
|
|
d2d3 = GiLoadFloat32(b_ptr); |
|
|
|
b_ptr = b_ptr + 4; |
|
|
@@ -188,18 +224,18 @@ void kern_4x12( |
|
|
|
d6d7 = GiLoadFloat32(b_ptr); |
|
|
|
b_ptr = b_ptr + 4; |
|
|
|
|
|
|
|
d8d9 = GiSimdFmaLane(d8d9, d2d3, d0d1, 0); |
|
|
|
d10d11 = GiSimdFmaLane(d10d11, d4d5, d0d1, 0); |
|
|
|
d12d13 = GiSimdFmaLane(d12d13, d6d7, d0d1, 0); |
|
|
|
d14d15 = GiSimdFmaLane(d14d15, d2d3, d0d1, 1); |
|
|
|
d16d17 = GiSimdFmaLane(d16d17, d4d5, d0d1, 1); |
|
|
|
d18d19 = GiSimdFmaLane(d18d19, d6d7, d0d1, 1); |
|
|
|
d20d21 = GiSimdFmaLane(d20d21, d2d3, d0d1, 2); |
|
|
|
d22d23 = GiSimdFmaLane(d22d23, d4d5, d0d1, 2); |
|
|
|
d24d25 = GiSimdFmaLane(d24d25, d6d7, d0d1, 2); |
|
|
|
d26d27 = GiSimdFmaLane(d26d27, d2d3, d0d1, 3); |
|
|
|
d28d29 = GiSimdFmaLane(d28d29, d4d5, d0d1, 3); |
|
|
|
d30d31 = GiSimdFmaLane(d30d31, d6d7, d0d1, 3); |
|
|
|
d8d9 = MLA(d8d9, d2d3, d0d1, 0); |
|
|
|
d10d11 = MLA(d10d11, d4d5, d0d1, 0); |
|
|
|
d12d13 = MLA(d12d13, d6d7, d0d1, 0); |
|
|
|
d14d15 = MLA(d14d15, d2d3, d0d1, 1); |
|
|
|
d16d17 = MLA(d16d17, d4d5, d0d1, 1); |
|
|
|
d18d19 = MLA(d18d19, d6d7, d0d1, 1); |
|
|
|
d20d21 = MLA(d20d21, d2d3, d0d1, 2); |
|
|
|
d22d23 = MLA(d22d23, d4d5, d0d1, 2); |
|
|
|
d24d25 = MLA(d24d25, d6d7, d0d1, 2); |
|
|
|
d26d27 = MLA(d26d27, d2d3, d0d1, 3); |
|
|
|
d28d29 = MLA(d28d29, d4d5, d0d1, 3); |
|
|
|
d30d31 = MLA(d30d31, d6d7, d0d1, 3); |
|
|
|
} |
|
|
|
|
|
|
|
if (m_remain == 4) { |
|
|
@@ -259,7 +295,13 @@ void kern_4x4( |
|
|
|
float* r3 = r2 + LDC; |
|
|
|
size_t d_size = sizeof(float); |
|
|
|
|
|
|
|
GI_FLOAT32_t d0d1, d2d3, d4d5, d6d7, d8d9, d10d11, d12d13, d14d15; |
|
|
|
#if defined(PREFER_VF) |
|
|
|
const float* d0d1; |
|
|
|
const float* d2d3; |
|
|
|
#else |
|
|
|
GI_FLOAT32_t d0d1, d2d3; |
|
|
|
#endif |
|
|
|
GI_FLOAT32_t d4d5, d6d7, d8d9, d10d11, d12d13, d14d15; |
|
|
|
float tmp[4]; |
|
|
|
if (is_first_k) { |
|
|
|
d8d9 = GiBroadcastFloat32(0.0f); |
|
|
@@ -412,54 +454,70 @@ void kern_4x4( |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
#if defined(PREFER_VF) |
|
|
|
d0d1 = a_ptr; |
|
|
|
#else |
|
|
|
d0d1 = GiLoadFloat32(a_ptr); |
|
|
|
#endif |
|
|
|
a_ptr = a_ptr + 4; |
|
|
|
d4d5 = GiLoadFloat32(b_ptr); |
|
|
|
b_ptr = b_ptr + 4; |
|
|
|
|
|
|
|
for (; K > 0; K--) { |
|
|
|
#if defined(PREFER_VF) |
|
|
|
d2d3 = a_ptr; |
|
|
|
#else |
|
|
|
d2d3 = GiLoadFloat32(a_ptr); |
|
|
|
#endif |
|
|
|
a_ptr = a_ptr + 4; |
|
|
|
d6d7 = GiLoadFloat32(b_ptr); |
|
|
|
b_ptr = b_ptr + 4; |
|
|
|
|
|
|
|
d8d9 = GiSimdFmaLane(d8d9, d4d5, d0d1, 0); |
|
|
|
d10d11 = GiSimdFmaLane(d10d11, d4d5, d0d1, 1); |
|
|
|
d12d13 = GiSimdFmaLane(d12d13, d4d5, d0d1, 2); |
|
|
|
d14d15 = GiSimdFmaLane(d14d15, d4d5, d0d1, 3); |
|
|
|
d8d9 = MLA(d8d9, d4d5, d0d1, 0); |
|
|
|
d10d11 = MLA(d10d11, d4d5, d0d1, 1); |
|
|
|
d12d13 = MLA(d12d13, d4d5, d0d1, 2); |
|
|
|
d14d15 = MLA(d14d15, d4d5, d0d1, 3); |
|
|
|
|
|
|
|
#if defined(PREFER_VF) |
|
|
|
d0d1 = a_ptr; |
|
|
|
#else |
|
|
|
d0d1 = GiLoadFloat32(a_ptr); |
|
|
|
#endif |
|
|
|
a_ptr = a_ptr + 4; |
|
|
|
d4d5 = GiLoadFloat32(b_ptr); |
|
|
|
b_ptr = b_ptr + 4; |
|
|
|
|
|
|
|
d8d9 = GiSimdFmaLane(d8d9, d6d7, d2d3, 0); |
|
|
|
d10d11 = GiSimdFmaLane(d10d11, d6d7, d2d3, 1); |
|
|
|
d12d13 = GiSimdFmaLane(d12d13, d6d7, d2d3, 2); |
|
|
|
d14d15 = GiSimdFmaLane(d14d15, d6d7, d2d3, 3); |
|
|
|
d8d9 = MLA(d8d9, d6d7, d2d3, 0); |
|
|
|
d10d11 = MLA(d10d11, d6d7, d2d3, 1); |
|
|
|
d12d13 = MLA(d12d13, d6d7, d2d3, 2); |
|
|
|
d14d15 = MLA(d14d15, d6d7, d2d3, 3); |
|
|
|
} |
|
|
|
|
|
|
|
if (1 == oddk) { |
|
|
|
d8d9 = GiSimdFmaLane(d8d9, d4d5, d0d1, 0); |
|
|
|
d10d11 = GiSimdFmaLane(d10d11, d4d5, d0d1, 1); |
|
|
|
d12d13 = GiSimdFmaLane(d12d13, d4d5, d0d1, 2); |
|
|
|
d14d15 = GiSimdFmaLane(d14d15, d4d5, d0d1, 3); |
|
|
|
d8d9 = MLA(d8d9, d4d5, d0d1, 0); |
|
|
|
d10d11 = MLA(d10d11, d4d5, d0d1, 1); |
|
|
|
d12d13 = MLA(d12d13, d4d5, d0d1, 2); |
|
|
|
d14d15 = MLA(d14d15, d4d5, d0d1, 3); |
|
|
|
|
|
|
|
} else { |
|
|
|
#if defined(PREFER_VF) |
|
|
|
d2d3 = a_ptr; |
|
|
|
#else |
|
|
|
d2d3 = GiLoadFloat32(a_ptr); |
|
|
|
#endif |
|
|
|
a_ptr = a_ptr + 4; |
|
|
|
d6d7 = GiLoadFloat32(b_ptr); |
|
|
|
b_ptr = b_ptr + 4; |
|
|
|
|
|
|
|
d8d9 = GiSimdFmaLane(d8d9, d4d5, d0d1, 0); |
|
|
|
d10d11 = GiSimdFmaLane(d10d11, d4d5, d0d1, 1); |
|
|
|
d12d13 = GiSimdFmaLane(d12d13, d4d5, d0d1, 2); |
|
|
|
d14d15 = GiSimdFmaLane(d14d15, d4d5, d0d1, 3); |
|
|
|
d8d9 = MLA(d8d9, d4d5, d0d1, 0); |
|
|
|
d10d11 = MLA(d10d11, d4d5, d0d1, 1); |
|
|
|
d12d13 = MLA(d12d13, d4d5, d0d1, 2); |
|
|
|
d14d15 = MLA(d14d15, d4d5, d0d1, 3); |
|
|
|
|
|
|
|
d8d9 = GiSimdFmaLane(d8d9, d6d7, d2d3, 0); |
|
|
|
d10d11 = GiSimdFmaLane(d10d11, d6d7, d2d3, 1); |
|
|
|
d12d13 = GiSimdFmaLane(d12d13, d6d7, d2d3, 2); |
|
|
|
d14d15 = GiSimdFmaLane(d14d15, d6d7, d2d3, 3); |
|
|
|
d8d9 = MLA(d8d9, d6d7, d2d3, 0); |
|
|
|
d10d11 = MLA(d10d11, d6d7, d2d3, 1); |
|
|
|
d12d13 = MLA(d12d13, d6d7, d2d3, 2); |
|
|
|
d14d15 = MLA(d14d15, d6d7, d2d3, 3); |
|
|
|
} |
|
|
|
|
|
|
|
if (m_remain == 4) { |
|
|
@@ -882,6 +940,7 @@ void gi_sgemm_4x12_pack_B_t( |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
#undef MLA |
|
|
|
} // namespace |
|
|
|
|
|
|
|
MEGDNN_REG_GEMM_STRATEGY_IMPL(gi_sgemm_4x12); |
|
|
|