diff --git a/dnn/src/fallback/matrix_mul/gi/fp32/strategy_mk4_4x8.cpp b/dnn/src/fallback/matrix_mul/gi/fp32/strategy_mk4_4x8.cpp index dc97b8c1..3173c287 100644 --- a/dnn/src/fallback/matrix_mul/gi/fp32/strategy_mk4_4x8.cpp +++ b/dnn/src/fallback/matrix_mul/gi/fp32/strategy_mk4_4x8.cpp @@ -7,6 +7,9 @@ using namespace matmul::fallback; namespace { +//! x86 and rvv GiSimdFmaLane API is slowly, as an alternate, use +//! GiMultiplyAddScalarFloat32 +#define MLA GiMultiplyAddScalarFloat32 void kern_4x1(const float* A, const float* B, size_t LDB, size_t K, float* C) { LDB = LDB - 4; K = K - 4; @@ -24,34 +27,32 @@ void kern_4x1(const float* A, const float* B, size_t LDB, size_t K, float* C) { GI_FLOAT32_t d20d21 = GiBroadcastFloat32(0.0f); GI_FLOAT32_t d22d23 = GiBroadcastFloat32(0.0f); - GI_FLOAT32_t d0d1 = GiLoadFloat32(B); + d16d17 = MLA(d16d17, d8d9, *(B)); + d18d19 = MLA(d18d19, d10d11, *(B + 1)); B = B + 4; - d16d17 = GiSimdFmaLane(d16d17, d8d9, d0d1, 0); - d18d19 = GiSimdFmaLane(d18d19, d10d11, d0d1, 1); - for (; K > 0; K -= 4) { d8d9 = GiLoadFloat32(A); A = A + 4; d10d11 = GiLoadFloat32(A); A = A + 4; - d20d21 = GiSimdFmaLane(d20d21, d12d13, d0d1, 2); - d22d23 = GiSimdFmaLane(d22d23, d14d15, d0d1, 3); + d20d21 = MLA(d20d21, d12d13, *(B + 2 - 4)); + d22d23 = MLA(d22d23, d14d15, *(B + 3 - 4)); B = B + LDB; - d0d1 = GiLoadFloat32(B); - B = B + 4; + d12d13 = GiLoadFloat32(A); A = A + 4; d14d15 = GiLoadFloat32(A); A = A + 4; - d16d17 = GiSimdFmaLane(d16d17, d8d9, d0d1, 0); - d18d19 = GiSimdFmaLane(d18d19, d10d11, d0d1, 1); + d16d17 = MLA(d16d17, d8d9, *(B)); + d18d19 = MLA(d18d19, d10d11, *(B + 1)); + B = B + 4; } - d20d21 = GiSimdFmaLane(d20d21, d12d13, d0d1, 2); - d22d23 = GiSimdFmaLane(d22d23, d14d15, d0d1, 3); + d20d21 = MLA(d20d21, d12d13, *(B + 2 - 4)); + d22d23 = MLA(d22d23, d14d15, *(B + 3 - 4)); d16d17 = GiAddFloat32(d16d17, d20d21); d18d19 = GiAddFloat32(d18d19, d22d23); d16d17 = GiAddFloat32(d16d17, d18d19); @@ -73,25 +74,19 @@ void kern_4x4(const float* A, const float* B, size_t LDB, size_t K, float* C) { GI_FLOAT32_t d14d15 = GiLoadFloat32(A); A = A + 4; - GI_FLOAT32_t d0d1 = GiLoadFloat32(B); - B = B + 4; - GI_FLOAT32_t d2d3 = GiLoadFloat32(B); - B = B + 4; - GI_FLOAT32_t d4d5 = GiLoadFloat32(B); - B = B + 4; - GI_FLOAT32_t d6d7 = GiLoadFloat32(B); - B = B + 4; - GI_FLOAT32_t vfzero = GiBroadcastFloat32(0.0f); - GI_FLOAT32_t d16d17 = GiSimdFmaLane(vfzero, d8d9, d0d1, 0); - GI_FLOAT32_t d18d19 = GiSimdFmaLane(vfzero, d8d9, d2d3, 0); - GI_FLOAT32_t d20d21 = GiSimdFmaLane(vfzero, d8d9, d4d5, 0); - GI_FLOAT32_t d22d23 = GiSimdFmaLane(vfzero, d8d9, d6d7, 0); + GI_FLOAT32_t d16d17 = MLA(vfzero, d8d9, *(B)); + d16d17 = MLA(d16d17, d10d11, *(B + 1)); + + GI_FLOAT32_t d18d19 = MLA(vfzero, d8d9, *(B + 4)); + d18d19 = MLA(d18d19, d10d11, *(B + 5)); + + GI_FLOAT32_t d20d21 = MLA(vfzero, d8d9, *(B + 8)); + d20d21 = MLA(d20d21, d10d11, *(B + 9)); - d16d17 = GiSimdFmaLane(d16d17, d10d11, d0d1, 1); - d18d19 = GiSimdFmaLane(d18d19, d10d11, d2d3, 1); - d20d21 = GiSimdFmaLane(d20d21, d10d11, d4d5, 1); - d22d23 = GiSimdFmaLane(d22d23, d10d11, d6d7, 1); + GI_FLOAT32_t d22d23 = MLA(vfzero, d8d9, *(B + 12)); + d22d23 = MLA(d22d23, d10d11, *(B + 13)); + B = B + 16; for (; K > 0; K -= 4) { d8d9 = GiLoadFloat32(A); @@ -99,51 +94,50 @@ void kern_4x4(const float* A, const float* B, size_t LDB, size_t K, float* C) { d10d11 = GiLoadFloat32(A); A = A + 4; - d16d17 = GiSimdFmaLane(d16d17, d12d13, d0d1, 2); - d18d19 = GiSimdFmaLane(d18d19, d12d13, d2d3, 2); - d20d21 = GiSimdFmaLane(d20d21, d12d13, d4d5, 2); - d22d23 = GiSimdFmaLane(d22d23, d12d13, d6d7, 2); + d16d17 = MLA(d16d17, d12d13, *(B + 2 - 16)); + d16d17 = MLA(d16d17, d14d15, *(B + 3 - 16)); + + d18d19 = MLA(d18d19, d12d13, *(B + 6 - 16)); + d18d19 = MLA(d18d19, d14d15, *(B + 7 - 16)); + + d20d21 = MLA(d20d21, d12d13, *(B + 10 - 16)); + d20d21 = MLA(d20d21, d14d15, *(B + 11 - 16)); + + d22d23 = MLA(d22d23, d12d13, *(B + 14 - 16)); + d22d23 = MLA(d22d23, d14d15, *(B + 15 - 16)); B = B + LDB; - d16d17 = GiSimdFmaLane(d16d17, d14d15, d0d1, 3); - d18d19 = GiSimdFmaLane(d18d19, d14d15, d2d3, 3); - d0d1 = GiLoadFloat32(B); - B = B + 4; - d20d21 = GiSimdFmaLane(d20d21, d14d15, d4d5, 3); - d2d3 = GiLoadFloat32(B); - B = B + 4; - d22d23 = GiSimdFmaLane(d22d23, d14d15, d6d7, 3); - d4d5 = GiLoadFloat32(B); - B = B + 4; + d16d17 = MLA(d16d17, d8d9, *(B)); + d16d17 = MLA(d16d17, d10d11, *(B + 1)); - d16d17 = GiSimdFmaLane(d16d17, d8d9, d0d1, 0); - d6d7 = GiLoadFloat32(B); - B = B + 4; - d18d19 = GiSimdFmaLane(d18d19, d8d9, d2d3, 0); - d20d21 = GiSimdFmaLane(d20d21, d8d9, d4d5, 0); - d22d23 = GiSimdFmaLane(d22d23, d8d9, d6d7, 0); + d18d19 = MLA(d18d19, d8d9, *(B + 4)); + d18d19 = MLA(d18d19, d10d11, *(B + 5)); + + d20d21 = MLA(d20d21, d8d9, *(B + 8)); + d20d21 = MLA(d20d21, d10d11, *(B + 9)); + + d22d23 = MLA(d22d23, d8d9, *(B + 12)); + d22d23 = MLA(d22d23, d10d11, *(B + 13)); d12d13 = GiLoadFloat32(A); A = A + 4; d14d15 = GiLoadFloat32(A); A = A + 4; - - d16d17 = GiSimdFmaLane(d16d17, d10d11, d0d1, 1); - d18d19 = GiSimdFmaLane(d18d19, d10d11, d2d3, 1); - d20d21 = GiSimdFmaLane(d20d21, d10d11, d4d5, 1); - d22d23 = GiSimdFmaLane(d22d23, d10d11, d6d7, 1); + B = B + 16; } - d16d17 = GiSimdFmaLane(d16d17, d12d13, d0d1, 2); - d18d19 = GiSimdFmaLane(d18d19, d12d13, d2d3, 2); - d20d21 = GiSimdFmaLane(d20d21, d12d13, d4d5, 2); - d22d23 = GiSimdFmaLane(d22d23, d12d13, d6d7, 2); + d16d17 = MLA(d16d17, d12d13, *(B + 2 - 16)); + d16d17 = MLA(d16d17, d14d15, *(B + 3 - 16)); - d16d17 = GiSimdFmaLane(d16d17, d14d15, d0d1, 3); - d18d19 = GiSimdFmaLane(d18d19, d14d15, d2d3, 3); - d20d21 = GiSimdFmaLane(d20d21, d14d15, d4d5, 3); - d22d23 = GiSimdFmaLane(d22d23, d14d15, d6d7, 3); + d18d19 = MLA(d18d19, d12d13, *(B + 6 - 16)); + d18d19 = MLA(d18d19, d14d15, *(B + 7 - 16)); + + d20d21 = MLA(d20d21, d12d13, *(B + 10 - 16)); + d20d21 = MLA(d20d21, d14d15, *(B + 11 - 16)); + + d22d23 = MLA(d22d23, d12d13, *(B + 14 - 16)); + d22d23 = MLA(d22d23, d14d15, *(B + 15 - 16)); GiStoreFloat32(C, d16d17); C = C + 4; @@ -166,56 +160,55 @@ void kern_4x8(const float* A, const float* B, size_t LDB, size_t K, float* C) { GI_FLOAT32_t d14d15 = GiLoadFloat32(A); A = A + 4; - GI_FLOAT32_t d0d1 = GiLoadFloat32(B); + GI_FLOAT32_t vfzero = GiBroadcastFloat32(0.0f); + + GI_FLOAT32_t d16d17 = MLA(vfzero, d8d9, *(B)); + d16d17 = MLA(d16d17, d10d11, *(B + 1)); + d16d17 = MLA(d16d17, d12d13, *(B + 2)); + d16d17 = MLA(d16d17, d14d15, *(B + 3)); B = B + 4; - GI_FLOAT32_t d2d3 = GiLoadFloat32(B); + + GI_FLOAT32_t d18d19 = MLA(vfzero, d8d9, *(B)); + d18d19 = MLA(d18d19, d10d11, *(B + 1)); + d18d19 = MLA(d18d19, d12d13, *(B + 2)); + d18d19 = MLA(d18d19, d14d15, *(B + 3)); B = B + 4; - GI_FLOAT32_t d4d5 = GiLoadFloat32(B); + + GI_FLOAT32_t d20d21 = MLA(vfzero, d8d9, *(B)); + d20d21 = MLA(d20d21, d10d11, *(B + 1)); + d20d21 = MLA(d20d21, d12d13, *(B + 2)); + d20d21 = MLA(d20d21, d14d15, *(B + 3)); B = B + 4; - GI_FLOAT32_t d6d7 = GiLoadFloat32(B); + + GI_FLOAT32_t d22d23 = MLA(vfzero, d8d9, *(B)); + d22d23 = MLA(d22d23, d10d11, *(B + 1)); + d22d23 = MLA(d22d23, d12d13, *(B + 2)); + d22d23 = MLA(d22d23, d14d15, *(B + 3)); B = B + 4; - GI_FLOAT32_t vfzero = GiBroadcastFloat32(0.0f); - GI_FLOAT32_t d16d17 = GiSimdFmaLane(vfzero, d8d9, d0d1, 0); - d16d17 = GiSimdFmaLane(d16d17, d10d11, d0d1, 1); - GI_FLOAT32_t d18d19 = GiSimdFmaLane(vfzero, d8d9, d2d3, 0); - d16d17 = GiSimdFmaLane(d16d17, d12d13, d0d1, 2); - d18d19 = GiSimdFmaLane(d18d19, d10d11, d2d3, 1); - d16d17 = GiSimdFmaLane(d16d17, d14d15, d0d1, 3); - d18d19 = GiSimdFmaLane(d18d19, d12d13, d2d3, 2); - d18d19 = GiSimdFmaLane(d18d19, d14d15, d2d3, 3); - d0d1 = GiLoadFloat32(B); + + GI_FLOAT32_t d24d25 = MLA(vfzero, d8d9, *(B)); + d24d25 = MLA(d24d25, d10d11, *(B + 1)); + d24d25 = MLA(d24d25, d12d13, *(B + 2)); + d24d25 = MLA(d24d25, d14d15, *(B + 3)); B = B + 4; - d2d3 = GiLoadFloat32(B); + + GI_FLOAT32_t d26d27 = MLA(vfzero, d8d9, *(B)); + d26d27 = MLA(d26d27, d10d11, *(B + 1)); + d26d27 = MLA(d26d27, d12d13, *(B + 2)); + d26d27 = MLA(d26d27, d14d15, *(B + 3)); B = B + 4; - GI_FLOAT32_t d20d21 = GiSimdFmaLane(vfzero, d8d9, d4d5, 0); - d20d21 = GiSimdFmaLane(d20d21, d10d11, d4d5, 1); - GI_FLOAT32_t d22d23 = GiSimdFmaLane(vfzero, d8d9, d6d7, 0); - d20d21 = GiSimdFmaLane(d20d21, d12d13, d4d5, 2); - d22d23 = GiSimdFmaLane(d22d23, d10d11, d6d7, 1); - d20d21 = GiSimdFmaLane(d20d21, d14d15, d4d5, 3); - d22d23 = GiSimdFmaLane(d22d23, d12d13, d6d7, 2); - d22d23 = GiSimdFmaLane(d22d23, d14d15, d6d7, 3); - - d4d5 = GiLoadFloat32(B); + + GI_FLOAT32_t d28d29 = MLA(vfzero, d8d9, *(B)); + d28d29 = MLA(d28d29, d10d11, *(B + 1)); + d28d29 = MLA(d28d29, d12d13, *(B + 2)); + d28d29 = MLA(d28d29, d14d15, *(B + 3)); B = B + 4; - d6d7 = GiLoadFloat32(B); + + GI_FLOAT32_t d30d31 = MLA(vfzero, d8d9, *(B)); + d30d31 = MLA(d30d31, d10d11, *(B + 1)); + d30d31 = MLA(d30d31, d12d13, *(B + 2)); + d30d31 = MLA(d30d31, d14d15, *(B + 3)); B = B + 4; - GI_FLOAT32_t d24d25 = GiSimdFmaLane(vfzero, d8d9, d0d1, 0); - d24d25 = GiSimdFmaLane(d24d25, d10d11, d0d1, 1); - GI_FLOAT32_t d26d27 = GiSimdFmaLane(vfzero, d8d9, d2d3, 0); - d24d25 = GiSimdFmaLane(d24d25, d12d13, d0d1, 2); - d26d27 = GiSimdFmaLane(d26d27, d10d11, d2d3, 1); - d24d25 = GiSimdFmaLane(d24d25, d14d15, d0d1, 3); - d26d27 = GiSimdFmaLane(d26d27, d12d13, d2d3, 2); - d26d27 = GiSimdFmaLane(d26d27, d14d15, d2d3, 3); - GI_FLOAT32_t d28d29 = GiSimdFmaLane(vfzero, d8d9, d4d5, 0); - d28d29 = GiSimdFmaLane(d28d29, d10d11, d4d5, 1); - GI_FLOAT32_t d30d31 = GiSimdFmaLane(vfzero, d8d9, d6d7, 0); - d28d29 = GiSimdFmaLane(d28d29, d12d13, d4d5, 2); - d30d31 = GiSimdFmaLane(d30d31, d10d11, d6d7, 1); - d28d29 = GiSimdFmaLane(d28d29, d14d15, d4d5, 3); - d30d31 = GiSimdFmaLane(d30d31, d12d13, d6d7, 2); - d30d31 = GiSimdFmaLane(d30d31, d14d15, d6d7, 3); B = B + LDB; K = K - 4; @@ -229,56 +222,53 @@ void kern_4x8(const float* A, const float* B, size_t LDB, size_t K, float* C) { d14d15 = GiLoadFloat32(A); A = A + 4; - d0d1 = GiLoadFloat32(B); - B = B + 4; - d2d3 = GiLoadFloat32(B); + d16d17 = MLA(d16d17, d8d9, *(B)); + d16d17 = MLA(d16d17, d10d11, *(B + 1)); + d16d17 = MLA(d16d17, d12d13, *(B + 2)); + d16d17 = MLA(d16d17, d14d15, *(B + 3)); B = B + 4; - d4d5 = GiLoadFloat32(B); + + d18d19 = MLA(d18d19, d8d9, *(B)); + d18d19 = MLA(d18d19, d10d11, *(B + 1)); + d18d19 = MLA(d18d19, d12d13, *(B + 2)); + d18d19 = MLA(d18d19, d14d15, *(B + 3)); B = B + 4; - d6d7 = GiLoadFloat32(B); + + d20d21 = MLA(d20d21, d8d9, *(B)); + d20d21 = MLA(d20d21, d10d11, *(B + 1)); + d20d21 = MLA(d20d21, d12d13, *(B + 2)); + d20d21 = MLA(d20d21, d14d15, *(B + 3)); B = B + 4; - d16d17 = GiSimdFmaLane(d16d17, d8d9, d0d1, 0); - d16d17 = GiSimdFmaLane(d16d17, d10d11, d0d1, 1); - d18d19 = GiSimdFmaLane(d18d19, d8d9, d2d3, 0); - d16d17 = GiSimdFmaLane(d16d17, d12d13, d0d1, 2); - d18d19 = GiSimdFmaLane(d18d19, d10d11, d2d3, 1); - d16d17 = GiSimdFmaLane(d16d17, d14d15, d0d1, 3); - d18d19 = GiSimdFmaLane(d18d19, d12d13, d2d3, 2); - d18d19 = GiSimdFmaLane(d18d19, d14d15, d2d3, 3); - d0d1 = GiLoadFloat32(B); + + d22d23 = MLA(d22d23, d8d9, *(B)); + d22d23 = MLA(d22d23, d10d11, *(B + 1)); + d22d23 = MLA(d22d23, d12d13, *(B + 2)); + d22d23 = MLA(d22d23, d14d15, *(B + 3)); B = B + 4; - d2d3 = GiLoadFloat32(B); + + d24d25 = MLA(d24d25, d8d9, *(B)); + d24d25 = MLA(d24d25, d10d11, *(B + 1)); + d24d25 = MLA(d24d25, d12d13, *(B + 2)); + d24d25 = MLA(d24d25, d14d15, *(B + 3)); B = B + 4; - d20d21 = GiSimdFmaLane(d20d21, d8d9, d4d5, 0); - d20d21 = GiSimdFmaLane(d20d21, d10d11, d4d5, 1); - d22d23 = GiSimdFmaLane(d22d23, d8d9, d6d7, 0); - d20d21 = GiSimdFmaLane(d20d21, d12d13, d4d5, 2); - d22d23 = GiSimdFmaLane(d22d23, d10d11, d6d7, 1); - d20d21 = GiSimdFmaLane(d20d21, d14d15, d4d5, 3); - d22d23 = GiSimdFmaLane(d22d23, d12d13, d6d7, 2); - d22d23 = GiSimdFmaLane(d22d23, d14d15, d6d7, 3); - - d4d5 = GiLoadFloat32(B); + + d26d27 = MLA(d26d27, d8d9, *(B)); + d26d27 = MLA(d26d27, d10d11, *(B + 1)); + d26d27 = MLA(d26d27, d12d13, *(B + 2)); + d26d27 = MLA(d26d27, d14d15, *(B + 3)); B = B + 4; - d6d7 = GiLoadFloat32(B); + + d28d29 = MLA(d28d29, d8d9, *(B)); + d28d29 = MLA(d28d29, d10d11, *(B + 1)); + d28d29 = MLA(d28d29, d12d13, *(B + 2)); + d28d29 = MLA(d28d29, d14d15, *(B + 3)); B = B + 4; - d24d25 = GiSimdFmaLane(d24d25, d8d9, d0d1, 0); - d24d25 = GiSimdFmaLane(d24d25, d10d11, d0d1, 1); - d26d27 = GiSimdFmaLane(d26d27, d8d9, d2d3, 0); - d24d25 = GiSimdFmaLane(d24d25, d12d13, d0d1, 2); - d26d27 = GiSimdFmaLane(d26d27, d10d11, d2d3, 1); - d24d25 = GiSimdFmaLane(d24d25, d14d15, d0d1, 3); - d26d27 = GiSimdFmaLane(d26d27, d12d13, d2d3, 2); - d26d27 = GiSimdFmaLane(d26d27, d14d15, d2d3, 3); - d28d29 = GiSimdFmaLane(d28d29, d8d9, d4d5, 0); - d28d29 = GiSimdFmaLane(d28d29, d10d11, d4d5, 1); - d30d31 = GiSimdFmaLane(d30d31, d8d9, d6d7, 0); - d28d29 = GiSimdFmaLane(d28d29, d12d13, d4d5, 2); - d30d31 = GiSimdFmaLane(d30d31, d10d11, d6d7, 1); - d28d29 = GiSimdFmaLane(d28d29, d14d15, d4d5, 3); - d30d31 = GiSimdFmaLane(d30d31, d12d13, d6d7, 2); - d30d31 = GiSimdFmaLane(d30d31, d14d15, d6d7, 3); - B = B + LDB; + + d30d31 = MLA(d30d31, d8d9, *(B)); + d30d31 = MLA(d30d31, d10d11, *(B + 1)); + d30d31 = MLA(d30d31, d12d13, *(B + 2)); + d30d31 = MLA(d30d31, d14d15, *(B + 3)); + B = B + 4 + LDB; } GiStoreFloat32(C, d16d17); C = C + 4; @@ -298,6 +288,7 @@ void kern_4x8(const float* A, const float* B, size_t LDB, size_t K, float* C) { C = C + 4; } +#undef MLA } // namespace MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(gi_sgemm_nopack_4x8); diff --git a/dnn/test/fallback/matrix_mul.cpp b/dnn/test/fallback/matrix_mul.cpp index 197804ea..3a8531db 100644 --- a/dnn/test/fallback/matrix_mul.cpp +++ b/dnn/test/fallback/matrix_mul.cpp @@ -176,6 +176,13 @@ TEST_F(FALLBACK, BENCHMARK_MATRIX_MUL_GI_PACK_MK4) { "FB_GI_F32_MK4_PACK_4x12", param::MatrixMul::Format::MK4); } +TEST_F(FALLBACK, BENCHMARK_MATRIX_FB_GI_F32_MK4_4x8) { + auto args = matrix_mul::get_benchmark_matmul_args(); + matrix_mul::benchmark_single_algo( + handle(), args, dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, + "FB_GI_F32_MK4_4x8", param::MatrixMul::Format::MK4); +} + #endif } // namespace test } // namespace megdnn