|
|
@@ -0,0 +1,950 @@ |
|
|
|
#include "src/fallback/matrix_mul/generic_strategy.h" |
|
|
|
#include "src/fallback/matrix_mul/gi/fp32/common.h" |
|
|
|
|
|
|
|
using namespace megdnn; |
|
|
|
using namespace matmul::fallback; |
|
|
|
|
|
|
|
namespace { |
|
|
|
|
|
|
|
#pragma GCC diagnostic push |
|
|
|
#pragma GCC diagnostic ignored "-Wuninitialized" |
|
|
|
|
|
|
|
#ifdef __GNUC__ |
|
|
|
#ifndef __has_warning |
|
|
|
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" |
|
|
|
#else |
|
|
|
#if __has_warning("-Wmaybe-uninitialized") |
|
|
|
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" |
|
|
|
#endif |
|
|
|
#endif |
|
|
|
#endif |
|
|
|
void kern_4x12( |
|
|
|
const float* packA, const float* packB, int K, float* output, int LDC, |
|
|
|
bool is_first_k, int m_remain) { |
|
|
|
const float* a_ptr = packA; |
|
|
|
const float* b_ptr = packB; |
|
|
|
int oddk = (K & 1); |
|
|
|
K = ((K + 1) / 2) - 1; |
|
|
|
|
|
|
|
float* r0 = output; |
|
|
|
float* r1 = r0 + LDC; |
|
|
|
float* r2 = r1 + LDC; |
|
|
|
float* r3 = r2 + LDC; |
|
|
|
|
|
|
|
GI_FLOAT32_t d0d1, d2d3, d4d5, d6d7, d8d9, d10d11, d12d13, d14d15, d16d17, d18d19, |
|
|
|
d20d21, d22d23, d24d25, d26d27, d28d29, d30d31; |
|
|
|
|
|
|
|
if (is_first_k) { |
|
|
|
d8d9 = GiBroadcastFloat32(0.0f); |
|
|
|
d10d11 = GiBroadcastFloat32(0.0f); |
|
|
|
d12d13 = GiBroadcastFloat32(0.0f); |
|
|
|
d14d15 = GiBroadcastFloat32(0.0f); |
|
|
|
d16d17 = GiBroadcastFloat32(0.0f); |
|
|
|
d18d19 = GiBroadcastFloat32(0.0f); |
|
|
|
d20d21 = GiBroadcastFloat32(0.0f); |
|
|
|
d22d23 = GiBroadcastFloat32(0.0f); |
|
|
|
d24d25 = GiBroadcastFloat32(0.0f); |
|
|
|
d26d27 = GiBroadcastFloat32(0.0f); |
|
|
|
d28d29 = GiBroadcastFloat32(0.0f); |
|
|
|
d30d31 = GiBroadcastFloat32(0.0f); |
|
|
|
} else { |
|
|
|
if (m_remain == 4) { |
|
|
|
d8d9 = GiLoadFloat32(r0); |
|
|
|
d10d11 = GiLoadFloat32(r0 + 4); |
|
|
|
d12d13 = GiLoadFloat32(r0 + 8); |
|
|
|
|
|
|
|
d14d15 = GiLoadFloat32(r1); |
|
|
|
d16d17 = GiLoadFloat32(r1 + 4); |
|
|
|
d18d19 = GiLoadFloat32(r1 + 8); |
|
|
|
|
|
|
|
d20d21 = GiLoadFloat32(r2); |
|
|
|
d22d23 = GiLoadFloat32(r2 + 4); |
|
|
|
d24d25 = GiLoadFloat32(r2 + 8); |
|
|
|
|
|
|
|
d26d27 = GiLoadFloat32(r3); |
|
|
|
d28d29 = GiLoadFloat32(r3 + 4); |
|
|
|
d30d31 = GiLoadFloat32(r3 + 8); |
|
|
|
} else if (m_remain == 3) { |
|
|
|
d8d9 = GiLoadFloat32(r0); |
|
|
|
d10d11 = GiLoadFloat32(r0 + 4); |
|
|
|
d12d13 = GiLoadFloat32(r0 + 8); |
|
|
|
|
|
|
|
d14d15 = GiLoadFloat32(r1); |
|
|
|
d16d17 = GiLoadFloat32(r1 + 4); |
|
|
|
d18d19 = GiLoadFloat32(r1 + 8); |
|
|
|
|
|
|
|
d20d21 = GiLoadFloat32(r2); |
|
|
|
d22d23 = GiLoadFloat32(r2 + 4); |
|
|
|
d24d25 = GiLoadFloat32(r2 + 8); |
|
|
|
} else if (m_remain == 2) { |
|
|
|
d8d9 = GiLoadFloat32(r0); |
|
|
|
d10d11 = GiLoadFloat32(r0 + 4); |
|
|
|
d12d13 = GiLoadFloat32(r0 + 8); |
|
|
|
|
|
|
|
d14d15 = GiLoadFloat32(r1); |
|
|
|
d16d17 = GiLoadFloat32(r1 + 4); |
|
|
|
d18d19 = GiLoadFloat32(r1 + 8); |
|
|
|
} else if (m_remain == 1) { |
|
|
|
d8d9 = GiLoadFloat32(r0); |
|
|
|
d10d11 = GiLoadFloat32(r0 + 4); |
|
|
|
d12d13 = GiLoadFloat32(r0 + 8); |
|
|
|
} |
|
|
|
} |
|
|
|
d2d3 = GiLoadFloat32(b_ptr); |
|
|
|
b_ptr = b_ptr + 4; |
|
|
|
d4d5 = GiLoadFloat32(b_ptr); |
|
|
|
b_ptr = b_ptr + 4; |
|
|
|
d6d7 = GiLoadFloat32(b_ptr); |
|
|
|
b_ptr = b_ptr + 4; |
|
|
|
|
|
|
|
for (; K > 0; K--) { |
|
|
|
d0d1 = GiLoadFloat32(a_ptr); |
|
|
|
a_ptr = a_ptr + 4; |
|
|
|
|
|
|
|
d8d9 = GiSimdFmaLane(d8d9, d2d3, d0d1, 0); |
|
|
|
d10d11 = GiSimdFmaLane(d10d11, d4d5, d0d1, 0); |
|
|
|
d12d13 = GiSimdFmaLane(d12d13, d6d7, d0d1, 0); |
|
|
|
d14d15 = GiSimdFmaLane(d14d15, d2d3, d0d1, 1); |
|
|
|
d16d17 = GiSimdFmaLane(d16d17, d4d5, d0d1, 1); |
|
|
|
d18d19 = GiSimdFmaLane(d18d19, d6d7, d0d1, 1); |
|
|
|
d20d21 = GiSimdFmaLane(d20d21, d2d3, d0d1, 2); |
|
|
|
d22d23 = GiSimdFmaLane(d22d23, d4d5, d0d1, 2); |
|
|
|
d24d25 = GiSimdFmaLane(d24d25, d6d7, d0d1, 2); |
|
|
|
d26d27 = GiSimdFmaLane(d26d27, d2d3, d0d1, 3); |
|
|
|
d28d29 = GiSimdFmaLane(d28d29, d4d5, d0d1, 3); |
|
|
|
d30d31 = GiSimdFmaLane(d30d31, d6d7, d0d1, 3); |
|
|
|
|
|
|
|
d0d1 = GiLoadFloat32(a_ptr); |
|
|
|
a_ptr = a_ptr + 4; |
|
|
|
d2d3 = GiLoadFloat32(b_ptr); |
|
|
|
b_ptr = b_ptr + 4; |
|
|
|
d4d5 = GiLoadFloat32(b_ptr); |
|
|
|
b_ptr = b_ptr + 4; |
|
|
|
d6d7 = GiLoadFloat32(b_ptr); |
|
|
|
b_ptr = b_ptr + 4; |
|
|
|
|
|
|
|
d8d9 = GiSimdFmaLane(d8d9, d2d3, d0d1, 0); |
|
|
|
d10d11 = GiSimdFmaLane(d10d11, d4d5, d0d1, 0); |
|
|
|
d12d13 = GiSimdFmaLane(d12d13, d6d7, d0d1, 0); |
|
|
|
d14d15 = GiSimdFmaLane(d14d15, d2d3, d0d1, 1); |
|
|
|
d16d17 = GiSimdFmaLane(d16d17, d4d5, d0d1, 1); |
|
|
|
d18d19 = GiSimdFmaLane(d18d19, d6d7, d0d1, 1); |
|
|
|
d20d21 = GiSimdFmaLane(d20d21, d2d3, d0d1, 2); |
|
|
|
d22d23 = GiSimdFmaLane(d22d23, d4d5, d0d1, 2); |
|
|
|
d24d25 = GiSimdFmaLane(d24d25, d6d7, d0d1, 2); |
|
|
|
d26d27 = GiSimdFmaLane(d26d27, d2d3, d0d1, 3); |
|
|
|
d28d29 = GiSimdFmaLane(d28d29, d4d5, d0d1, 3); |
|
|
|
d30d31 = GiSimdFmaLane(d30d31, d6d7, d0d1, 3); |
|
|
|
|
|
|
|
d2d3 = GiLoadFloat32(b_ptr); |
|
|
|
b_ptr = b_ptr + 4; |
|
|
|
d4d5 = GiLoadFloat32(b_ptr); |
|
|
|
b_ptr = b_ptr + 4; |
|
|
|
d6d7 = GiLoadFloat32(b_ptr); |
|
|
|
b_ptr = b_ptr + 4; |
|
|
|
} |
|
|
|
|
|
|
|
if (1 == oddk) { |
|
|
|
d0d1 = GiLoadFloat32(a_ptr); |
|
|
|
a_ptr = a_ptr + 4; |
|
|
|
|
|
|
|
d8d9 = GiSimdFmaLane(d8d9, d2d3, d0d1, 0); |
|
|
|
d10d11 = GiSimdFmaLane(d10d11, d4d5, d0d1, 0); |
|
|
|
d12d13 = GiSimdFmaLane(d12d13, d6d7, d0d1, 0); |
|
|
|
d14d15 = GiSimdFmaLane(d14d15, d2d3, d0d1, 1); |
|
|
|
d16d17 = GiSimdFmaLane(d16d17, d4d5, d0d1, 1); |
|
|
|
d18d19 = GiSimdFmaLane(d18d19, d6d7, d0d1, 1); |
|
|
|
d20d21 = GiSimdFmaLane(d20d21, d2d3, d0d1, 2); |
|
|
|
d22d23 = GiSimdFmaLane(d22d23, d4d5, d0d1, 2); |
|
|
|
d24d25 = GiSimdFmaLane(d24d25, d6d7, d0d1, 2); |
|
|
|
d26d27 = GiSimdFmaLane(d26d27, d2d3, d0d1, 3); |
|
|
|
d28d29 = GiSimdFmaLane(d28d29, d4d5, d0d1, 3); |
|
|
|
d30d31 = GiSimdFmaLane(d30d31, d6d7, d0d1, 3); |
|
|
|
|
|
|
|
} else { |
|
|
|
d0d1 = GiLoadFloat32(a_ptr); |
|
|
|
a_ptr = a_ptr + 4; |
|
|
|
|
|
|
|
d8d9 = GiSimdFmaLane(d8d9, d2d3, d0d1, 0); |
|
|
|
d10d11 = GiSimdFmaLane(d10d11, d4d5, d0d1, 0); |
|
|
|
d12d13 = GiSimdFmaLane(d12d13, d6d7, d0d1, 0); |
|
|
|
d14d15 = GiSimdFmaLane(d14d15, d2d3, d0d1, 1); |
|
|
|
d16d17 = GiSimdFmaLane(d16d17, d4d5, d0d1, 1); |
|
|
|
d18d19 = GiSimdFmaLane(d18d19, d6d7, d0d1, 1); |
|
|
|
d20d21 = GiSimdFmaLane(d20d21, d2d3, d0d1, 2); |
|
|
|
d22d23 = GiSimdFmaLane(d22d23, d4d5, d0d1, 2); |
|
|
|
d24d25 = GiSimdFmaLane(d24d25, d6d7, d0d1, 2); |
|
|
|
d26d27 = GiSimdFmaLane(d26d27, d2d3, d0d1, 3); |
|
|
|
d28d29 = GiSimdFmaLane(d28d29, d4d5, d0d1, 3); |
|
|
|
d30d31 = GiSimdFmaLane(d30d31, d6d7, d0d1, 3); |
|
|
|
|
|
|
|
d0d1 = GiLoadFloat32(a_ptr); |
|
|
|
a_ptr = a_ptr + 4; |
|
|
|
d2d3 = GiLoadFloat32(b_ptr); |
|
|
|
b_ptr = b_ptr + 4; |
|
|
|
d4d5 = GiLoadFloat32(b_ptr); |
|
|
|
b_ptr = b_ptr + 4; |
|
|
|
d6d7 = GiLoadFloat32(b_ptr); |
|
|
|
b_ptr = b_ptr + 4; |
|
|
|
|
|
|
|
d8d9 = GiSimdFmaLane(d8d9, d2d3, d0d1, 0); |
|
|
|
d10d11 = GiSimdFmaLane(d10d11, d4d5, d0d1, 0); |
|
|
|
d12d13 = GiSimdFmaLane(d12d13, d6d7, d0d1, 0); |
|
|
|
d14d15 = GiSimdFmaLane(d14d15, d2d3, d0d1, 1); |
|
|
|
d16d17 = GiSimdFmaLane(d16d17, d4d5, d0d1, 1); |
|
|
|
d18d19 = GiSimdFmaLane(d18d19, d6d7, d0d1, 1); |
|
|
|
d20d21 = GiSimdFmaLane(d20d21, d2d3, d0d1, 2); |
|
|
|
d22d23 = GiSimdFmaLane(d22d23, d4d5, d0d1, 2); |
|
|
|
d24d25 = GiSimdFmaLane(d24d25, d6d7, d0d1, 2); |
|
|
|
d26d27 = GiSimdFmaLane(d26d27, d2d3, d0d1, 3); |
|
|
|
d28d29 = GiSimdFmaLane(d28d29, d4d5, d0d1, 3); |
|
|
|
d30d31 = GiSimdFmaLane(d30d31, d6d7, d0d1, 3); |
|
|
|
} |
|
|
|
|
|
|
|
if (m_remain == 4) { |
|
|
|
GiStoreFloat32(r0, d8d9); |
|
|
|
GiStoreFloat32(r0 + 4, d10d11); |
|
|
|
GiStoreFloat32(r0 + 8, d12d13); |
|
|
|
|
|
|
|
GiStoreFloat32(r1, d14d15); |
|
|
|
GiStoreFloat32(r1 + 4, d16d17); |
|
|
|
GiStoreFloat32(r1 + 8, d18d19); |
|
|
|
|
|
|
|
GiStoreFloat32(r2, d20d21); |
|
|
|
GiStoreFloat32(r2 + 4, d22d23); |
|
|
|
GiStoreFloat32(r2 + 8, d24d25); |
|
|
|
|
|
|
|
GiStoreFloat32(r3, d26d27); |
|
|
|
GiStoreFloat32(r3 + 4, d28d29); |
|
|
|
GiStoreFloat32(r3 + 8, d30d31); |
|
|
|
} else if (m_remain == 3) { |
|
|
|
GiStoreFloat32(r0, d8d9); |
|
|
|
GiStoreFloat32(r0 + 4, d10d11); |
|
|
|
GiStoreFloat32(r0 + 8, d12d13); |
|
|
|
|
|
|
|
GiStoreFloat32(r1, d14d15); |
|
|
|
GiStoreFloat32(r1 + 4, d16d17); |
|
|
|
GiStoreFloat32(r1 + 8, d18d19); |
|
|
|
|
|
|
|
GiStoreFloat32(r2, d20d21); |
|
|
|
GiStoreFloat32(r2 + 4, d22d23); |
|
|
|
GiStoreFloat32(r2 + 8, d24d25); |
|
|
|
} else if (m_remain == 2) { |
|
|
|
GiStoreFloat32(r0, d8d9); |
|
|
|
GiStoreFloat32(r0 + 4, d10d11); |
|
|
|
GiStoreFloat32(r0 + 8, d12d13); |
|
|
|
|
|
|
|
GiStoreFloat32(r1, d14d15); |
|
|
|
GiStoreFloat32(r1 + 4, d16d17); |
|
|
|
GiStoreFloat32(r1 + 8, d18d19); |
|
|
|
} else if (m_remain == 1) { |
|
|
|
GiStoreFloat32(r0, d8d9); |
|
|
|
GiStoreFloat32(r0 + 4, d10d11); |
|
|
|
GiStoreFloat32(r0 + 8, d12d13); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void kern_4x4( |
|
|
|
const float* packA, const float* packB, int K, float* output, int LDC, |
|
|
|
bool is_first_k, int m_remain, int n_remain) { |
|
|
|
const float* a_ptr = packA; |
|
|
|
const float* b_ptr = packB; |
|
|
|
int oddk = (K & 1); |
|
|
|
K = ((K + 1) / 2) - 1; |
|
|
|
|
|
|
|
float* r0 = output; |
|
|
|
float* r1 = r0 + LDC; |
|
|
|
float* r2 = r1 + LDC; |
|
|
|
float* r3 = r2 + LDC; |
|
|
|
size_t d_size = sizeof(float); |
|
|
|
|
|
|
|
GI_FLOAT32_t d0d1, d2d3, d4d5, d6d7, d8d9, d10d11, d12d13, d14d15; |
|
|
|
float tmp[4]; |
|
|
|
if (is_first_k) { |
|
|
|
d8d9 = GiBroadcastFloat32(0.0f); |
|
|
|
d10d11 = GiBroadcastFloat32(0.0f); |
|
|
|
d12d13 = GiBroadcastFloat32(0.0f); |
|
|
|
d14d15 = GiBroadcastFloat32(0.0f); |
|
|
|
} else { |
|
|
|
if (m_remain == 4) { |
|
|
|
if (n_remain == 4) { |
|
|
|
d8d9 = GiLoadFloat32(r0); |
|
|
|
d10d11 = GiLoadFloat32(r1); |
|
|
|
d12d13 = GiLoadFloat32(r2); |
|
|
|
d14d15 = GiLoadFloat32(r3); |
|
|
|
} else if (n_remain == 3) { |
|
|
|
memcpy(tmp, r0, d_size * 3); |
|
|
|
r0 += 3; |
|
|
|
d8d9 = GiLoadFloat32(tmp); |
|
|
|
|
|
|
|
memcpy(tmp, r1, d_size * 3); |
|
|
|
r1 += 3; |
|
|
|
d10d11 = GiLoadFloat32(tmp); |
|
|
|
|
|
|
|
memcpy(tmp, r2, d_size * 3); |
|
|
|
r2 += 3; |
|
|
|
d12d13 = GiLoadFloat32(tmp); |
|
|
|
|
|
|
|
memcpy(tmp, r3, d_size * 3); |
|
|
|
r3 += 3; |
|
|
|
d14d15 = GiLoadFloat32(tmp); |
|
|
|
} else if (n_remain == 2) { |
|
|
|
memcpy(tmp, r0, d_size * 2); |
|
|
|
r0 += 2; |
|
|
|
d8d9 = GiLoadFloat32(tmp); |
|
|
|
|
|
|
|
memcpy(tmp, r1, d_size * 2); |
|
|
|
r1 += 2; |
|
|
|
d10d11 = GiLoadFloat32(tmp); |
|
|
|
|
|
|
|
memcpy(tmp, r2, d_size * 2); |
|
|
|
r2 += 2; |
|
|
|
d12d13 = GiLoadFloat32(tmp); |
|
|
|
|
|
|
|
memcpy(tmp, r3, d_size * 2); |
|
|
|
r3 += 2; |
|
|
|
d14d15 = GiLoadFloat32(tmp); |
|
|
|
} else if (n_remain == 1) { |
|
|
|
tmp[0] = *r0; |
|
|
|
r0++; |
|
|
|
d8d9 = GiLoadFloat32(tmp); |
|
|
|
|
|
|
|
tmp[0] = *r1; |
|
|
|
r1++; |
|
|
|
d10d11 = GiLoadFloat32(tmp); |
|
|
|
|
|
|
|
tmp[0] = *r2; |
|
|
|
r2++; |
|
|
|
d12d13 = GiLoadFloat32(tmp); |
|
|
|
|
|
|
|
tmp[0] = *r3; |
|
|
|
r3++; |
|
|
|
d14d15 = GiLoadFloat32(tmp); |
|
|
|
} |
|
|
|
} else if (m_remain == 3) { |
|
|
|
if (n_remain == 4) { |
|
|
|
d8d9 = GiLoadFloat32(r0); |
|
|
|
d10d11 = GiLoadFloat32(r1); |
|
|
|
d12d13 = GiLoadFloat32(r2); |
|
|
|
} else if (n_remain == 3) { |
|
|
|
memcpy(tmp, r0, d_size * 3); |
|
|
|
r0 += 3; |
|
|
|
d8d9 = GiLoadFloat32(tmp); |
|
|
|
|
|
|
|
memcpy(tmp, r1, d_size * 3); |
|
|
|
r1 += 3; |
|
|
|
d10d11 = GiLoadFloat32(tmp); |
|
|
|
|
|
|
|
memcpy(tmp, r2, d_size * 3); |
|
|
|
r2 += 3; |
|
|
|
d12d13 = GiLoadFloat32(tmp); |
|
|
|
} else if (n_remain == 2) { |
|
|
|
memcpy(tmp, r0, d_size * 2); |
|
|
|
r0 += 2; |
|
|
|
d8d9 = GiLoadFloat32(tmp); |
|
|
|
|
|
|
|
memcpy(tmp, r1, d_size * 2); |
|
|
|
r1 += 2; |
|
|
|
d10d11 = GiLoadFloat32(tmp); |
|
|
|
|
|
|
|
memcpy(tmp, r2, d_size * 2); |
|
|
|
r2 += 2; |
|
|
|
d12d13 = GiLoadFloat32(tmp); |
|
|
|
} else if (n_remain == 1) { |
|
|
|
tmp[0] = *r0; |
|
|
|
r0++; |
|
|
|
d8d9 = GiLoadFloat32(tmp); |
|
|
|
|
|
|
|
tmp[0] = *r1; |
|
|
|
r1++; |
|
|
|
d10d11 = GiLoadFloat32(tmp); |
|
|
|
|
|
|
|
tmp[0] = *r2; |
|
|
|
r2++; |
|
|
|
d12d13 = GiLoadFloat32(tmp); |
|
|
|
} |
|
|
|
} else if (m_remain == 2) { |
|
|
|
if (n_remain == 4) { |
|
|
|
d8d9 = GiLoadFloat32(r0); |
|
|
|
d10d11 = GiLoadFloat32(r1); |
|
|
|
} else if (n_remain == 3) { |
|
|
|
memcpy(tmp, r0, d_size * 3); |
|
|
|
r0 += 3; |
|
|
|
d8d9 = GiLoadFloat32(tmp); |
|
|
|
|
|
|
|
memcpy(tmp, r1, d_size * 3); |
|
|
|
r1 += 3; |
|
|
|
d10d11 = GiLoadFloat32(tmp); |
|
|
|
} else if (n_remain == 2) { |
|
|
|
memcpy(tmp, r0, d_size * 2); |
|
|
|
r0 += 2; |
|
|
|
d8d9 = GiLoadFloat32(tmp); |
|
|
|
|
|
|
|
memcpy(tmp, r1, d_size * 2); |
|
|
|
r1 += 2; |
|
|
|
d10d11 = GiLoadFloat32(tmp); |
|
|
|
} else if (n_remain == 1) { |
|
|
|
tmp[0] = *r0; |
|
|
|
r0++; |
|
|
|
d8d9 = GiLoadFloat32(tmp); |
|
|
|
|
|
|
|
tmp[0] = *r1; |
|
|
|
r1++; |
|
|
|
d10d11 = GiLoadFloat32(tmp); |
|
|
|
} |
|
|
|
} else if (m_remain == 1) { |
|
|
|
if (n_remain == 4) { |
|
|
|
d8d9 = GiLoadFloat32(r0); |
|
|
|
} else if (n_remain == 3) { |
|
|
|
memcpy(tmp, r0, d_size * 3); |
|
|
|
r0 += 3; |
|
|
|
d8d9 = GiLoadFloat32(tmp); |
|
|
|
} else if (n_remain == 2) { |
|
|
|
memcpy(tmp, r0, d_size * 2); |
|
|
|
r0 += 2; |
|
|
|
d8d9 = GiLoadFloat32(tmp); |
|
|
|
} else if (n_remain == 1) { |
|
|
|
tmp[0] = *r0; |
|
|
|
r0++; |
|
|
|
d8d9 = GiLoadFloat32(tmp); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
d0d1 = GiLoadFloat32(a_ptr); |
|
|
|
a_ptr = a_ptr + 4; |
|
|
|
d4d5 = GiLoadFloat32(b_ptr); |
|
|
|
b_ptr = b_ptr + 4; |
|
|
|
|
|
|
|
for (; K > 0; K--) { |
|
|
|
d2d3 = GiLoadFloat32(a_ptr); |
|
|
|
a_ptr = a_ptr + 4; |
|
|
|
d6d7 = GiLoadFloat32(b_ptr); |
|
|
|
b_ptr = b_ptr + 4; |
|
|
|
|
|
|
|
d8d9 = GiSimdFmaLane(d8d9, d4d5, d0d1, 0); |
|
|
|
d10d11 = GiSimdFmaLane(d10d11, d4d5, d0d1, 1); |
|
|
|
d12d13 = GiSimdFmaLane(d12d13, d4d5, d0d1, 2); |
|
|
|
d14d15 = GiSimdFmaLane(d14d15, d4d5, d0d1, 3); |
|
|
|
|
|
|
|
d0d1 = GiLoadFloat32(a_ptr); |
|
|
|
a_ptr = a_ptr + 4; |
|
|
|
d4d5 = GiLoadFloat32(b_ptr); |
|
|
|
b_ptr = b_ptr + 4; |
|
|
|
|
|
|
|
d8d9 = GiSimdFmaLane(d8d9, d6d7, d2d3, 0); |
|
|
|
d10d11 = GiSimdFmaLane(d10d11, d6d7, d2d3, 1); |
|
|
|
d12d13 = GiSimdFmaLane(d12d13, d6d7, d2d3, 2); |
|
|
|
d14d15 = GiSimdFmaLane(d14d15, d6d7, d2d3, 3); |
|
|
|
} |
|
|
|
|
|
|
|
if (1 == oddk) { |
|
|
|
d8d9 = GiSimdFmaLane(d8d9, d4d5, d0d1, 0); |
|
|
|
d10d11 = GiSimdFmaLane(d10d11, d4d5, d0d1, 1); |
|
|
|
d12d13 = GiSimdFmaLane(d12d13, d4d5, d0d1, 2); |
|
|
|
d14d15 = GiSimdFmaLane(d14d15, d4d5, d0d1, 3); |
|
|
|
|
|
|
|
} else { |
|
|
|
d2d3 = GiLoadFloat32(a_ptr); |
|
|
|
a_ptr = a_ptr + 4; |
|
|
|
d6d7 = GiLoadFloat32(b_ptr); |
|
|
|
b_ptr = b_ptr + 4; |
|
|
|
|
|
|
|
d8d9 = GiSimdFmaLane(d8d9, d4d5, d0d1, 0); |
|
|
|
d10d11 = GiSimdFmaLane(d10d11, d4d5, d0d1, 1); |
|
|
|
d12d13 = GiSimdFmaLane(d12d13, d4d5, d0d1, 2); |
|
|
|
d14d15 = GiSimdFmaLane(d14d15, d4d5, d0d1, 3); |
|
|
|
|
|
|
|
d8d9 = GiSimdFmaLane(d8d9, d6d7, d2d3, 0); |
|
|
|
d10d11 = GiSimdFmaLane(d10d11, d6d7, d2d3, 1); |
|
|
|
d12d13 = GiSimdFmaLane(d12d13, d6d7, d2d3, 2); |
|
|
|
d14d15 = GiSimdFmaLane(d14d15, d6d7, d2d3, 3); |
|
|
|
} |
|
|
|
|
|
|
|
if (m_remain == 4) { |
|
|
|
if (n_remain == 4) { |
|
|
|
GiStoreFloat32(r0, d8d9); |
|
|
|
r0 = r0 + 4; |
|
|
|
GiStoreFloat32(r1, d10d11); |
|
|
|
r1 = r1 + 4; |
|
|
|
GiStoreFloat32(r2, d12d13); |
|
|
|
r2 = r2 + 4; |
|
|
|
GiStoreFloat32(r3, d14d15); |
|
|
|
r3 = r3 + 4; |
|
|
|
} else if (n_remain == 3) { |
|
|
|
GiStoreFloat32(tmp, d8d9); |
|
|
|
memcpy(r0, tmp, d_size * 3); |
|
|
|
r0 += 3; |
|
|
|
|
|
|
|
GiStoreFloat32(tmp, d10d11); |
|
|
|
memcpy(r1, tmp, d_size * 3); |
|
|
|
r1 += 3; |
|
|
|
|
|
|
|
GiStoreFloat32(tmp, d12d13); |
|
|
|
memcpy(r2, tmp, d_size * 3); |
|
|
|
r2 += 3; |
|
|
|
|
|
|
|
GiStoreFloat32(tmp, d14d15); |
|
|
|
memcpy(r3, tmp, d_size * 3); |
|
|
|
r3 += 3; |
|
|
|
} else if (n_remain == 2) { |
|
|
|
GiStoreFloat32(tmp, d8d9); |
|
|
|
memcpy(r0, tmp, d_size * 2); |
|
|
|
r0 += 2; |
|
|
|
|
|
|
|
GiStoreFloat32(tmp, d10d11); |
|
|
|
memcpy(r1, tmp, d_size * 2); |
|
|
|
r1 += 2; |
|
|
|
|
|
|
|
GiStoreFloat32(tmp, d12d13); |
|
|
|
memcpy(r2, tmp, d_size * 2); |
|
|
|
r2 += 2; |
|
|
|
|
|
|
|
GiStoreFloat32(tmp, d14d15); |
|
|
|
memcpy(r3, tmp, d_size * 2); |
|
|
|
r3 += 2; |
|
|
|
} else if (n_remain == 1) { |
|
|
|
GiStoreFloat32(tmp, d8d9); |
|
|
|
*r0 = tmp[0]; |
|
|
|
r0++; |
|
|
|
|
|
|
|
GiStoreFloat32(tmp, d10d11); |
|
|
|
*r1 = tmp[0]; |
|
|
|
r1++; |
|
|
|
|
|
|
|
GiStoreFloat32(tmp, d12d13); |
|
|
|
*r2 = tmp[0]; |
|
|
|
r2++; |
|
|
|
|
|
|
|
GiStoreFloat32(tmp, d14d15); |
|
|
|
*r3 = tmp[0]; |
|
|
|
r3++; |
|
|
|
} |
|
|
|
} else if (m_remain == 3) { |
|
|
|
if (n_remain == 4) { |
|
|
|
GiStoreFloat32(r0, d8d9); |
|
|
|
r0 = r0 + 4; |
|
|
|
GiStoreFloat32(r1, d10d11); |
|
|
|
r1 = r1 + 4; |
|
|
|
GiStoreFloat32(r2, d12d13); |
|
|
|
r2 = r2 + 4; |
|
|
|
} else if (n_remain == 3) { |
|
|
|
GiStoreFloat32(tmp, d8d9); |
|
|
|
memcpy(r0, tmp, d_size * 3); |
|
|
|
r0 += 3; |
|
|
|
|
|
|
|
GiStoreFloat32(tmp, d10d11); |
|
|
|
memcpy(r1, tmp, d_size * 3); |
|
|
|
r1 += 3; |
|
|
|
|
|
|
|
GiStoreFloat32(tmp, d12d13); |
|
|
|
memcpy(r2, tmp, d_size * 3); |
|
|
|
r2 += 3; |
|
|
|
} else if (n_remain == 2) { |
|
|
|
GiStoreFloat32(tmp, d8d9); |
|
|
|
memcpy(r0, tmp, d_size * 2); |
|
|
|
r0 += 2; |
|
|
|
|
|
|
|
GiStoreFloat32(tmp, d10d11); |
|
|
|
memcpy(r1, tmp, d_size * 2); |
|
|
|
r1 += 2; |
|
|
|
|
|
|
|
GiStoreFloat32(tmp, d12d13); |
|
|
|
memcpy(r2, tmp, d_size * 2); |
|
|
|
r2 += 2; |
|
|
|
} else if (n_remain == 1) { |
|
|
|
GiStoreFloat32(tmp, d8d9); |
|
|
|
*r0 = tmp[0]; |
|
|
|
r0++; |
|
|
|
|
|
|
|
GiStoreFloat32(tmp, d10d11); |
|
|
|
*r1 = tmp[0]; |
|
|
|
r1++; |
|
|
|
|
|
|
|
GiStoreFloat32(tmp, d12d13); |
|
|
|
*r2 = tmp[0]; |
|
|
|
r2++; |
|
|
|
} |
|
|
|
} else if (m_remain == 2) { |
|
|
|
if (n_remain == 4) { |
|
|
|
GiStoreFloat32(r0, d8d9); |
|
|
|
r0 = r0 + 4; |
|
|
|
GiStoreFloat32(r1, d10d11); |
|
|
|
r1 = r1 + 4; |
|
|
|
} else if (n_remain == 3) { |
|
|
|
GiStoreFloat32(tmp, d8d9); |
|
|
|
memcpy(r0, tmp, d_size * 3); |
|
|
|
r0 += 3; |
|
|
|
|
|
|
|
GiStoreFloat32(tmp, d10d11); |
|
|
|
memcpy(r1, tmp, d_size * 3); |
|
|
|
r1 += 3; |
|
|
|
} else if (n_remain == 2) { |
|
|
|
GiStoreFloat32(tmp, d8d9); |
|
|
|
memcpy(r0, tmp, d_size * 2); |
|
|
|
r0 += 2; |
|
|
|
|
|
|
|
GiStoreFloat32(tmp, d10d11); |
|
|
|
memcpy(r1, tmp, d_size * 2); |
|
|
|
r1 += 2; |
|
|
|
} else if (n_remain == 1) { |
|
|
|
GiStoreFloat32(tmp, d8d9); |
|
|
|
*r0 = tmp[0]; |
|
|
|
r0++; |
|
|
|
|
|
|
|
GiStoreFloat32(tmp, d10d11); |
|
|
|
*r1 = tmp[0]; |
|
|
|
r1++; |
|
|
|
} |
|
|
|
} else if (m_remain == 1) { |
|
|
|
if (n_remain == 4) { |
|
|
|
GiStoreFloat32(r0, d8d9); |
|
|
|
r0 = r0 + 4; |
|
|
|
} else if (n_remain == 3) { |
|
|
|
GiStoreFloat32(tmp, d8d9); |
|
|
|
memcpy(r0, tmp, d_size * 3); |
|
|
|
r0 += 3; |
|
|
|
} else if (n_remain == 2) { |
|
|
|
GiStoreFloat32(tmp, d8d9); |
|
|
|
memcpy(r0, tmp, d_size * 2); |
|
|
|
r0 += 2; |
|
|
|
} else if (n_remain == 1) { |
|
|
|
GiStoreFloat32(tmp, d8d9); |
|
|
|
*r0 = tmp[0]; |
|
|
|
r0++; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
#pragma GCC diagnostic pop |
|
|
|
|
|
|
|
void gi_sgemm_4x12_pack_A_n( |
|
|
|
float* outptr, const float* inptr, int ldin, int y0, int ymax, int k0, |
|
|
|
int kmax) { |
|
|
|
float zerobuff[4]; |
|
|
|
std::memset(zerobuff, 0, sizeof(float) * 4); |
|
|
|
|
|
|
|
int y = y0; |
|
|
|
for (; y + 3 < ymax; y += 4) { |
|
|
|
const float* inptr0 = inptr + y * ldin + k0; |
|
|
|
const float* inptr1 = inptr0 + ldin; |
|
|
|
const float* inptr2 = inptr1 + ldin; |
|
|
|
const float* inptr3 = inptr2 + ldin; |
|
|
|
|
|
|
|
int K = (kmax - k0); |
|
|
|
for (; K > 3; K -= 4) { |
|
|
|
transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr); |
|
|
|
} |
|
|
|
|
|
|
|
interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 1, K); |
|
|
|
} |
|
|
|
|
|
|
|
for (; y < ymax; y += 4) { |
|
|
|
const float* inptr0 = inptr + y * ldin + k0; |
|
|
|
const float* inptr1 = inptr0 + ldin; |
|
|
|
const float* inptr2 = inptr1 + ldin; |
|
|
|
const float* inptr3 = inptr2 + ldin; |
|
|
|
|
|
|
|
int K = (kmax - k0); |
|
|
|
for (; K > 3; K -= 4) { |
|
|
|
if ((y + 3) >= ymax) { |
|
|
|
switch ((y + 3) - ymax) { |
|
|
|
/* Everything falls through in here */ |
|
|
|
case 2: |
|
|
|
inptr1 = zerobuff; |
|
|
|
MEGDNN_FALLTHRU |
|
|
|
case 1: |
|
|
|
inptr2 = zerobuff; |
|
|
|
MEGDNN_FALLTHRU |
|
|
|
case 0: |
|
|
|
inptr3 = zerobuff; |
|
|
|
break; |
|
|
|
default: |
|
|
|
megdnn_assert(0); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr); |
|
|
|
} |
|
|
|
|
|
|
|
if (K > 0) { |
|
|
|
if ((y + 3) >= ymax) { |
|
|
|
switch ((y + 3) - ymax) { |
|
|
|
/* Everything falls through in here */ |
|
|
|
case 2: |
|
|
|
inptr1 = zerobuff; |
|
|
|
MEGDNN_FALLTHRU |
|
|
|
case 1: |
|
|
|
inptr2 = zerobuff; |
|
|
|
MEGDNN_FALLTHRU |
|
|
|
case 0: |
|
|
|
inptr3 = zerobuff; |
|
|
|
break; |
|
|
|
default: |
|
|
|
megdnn_assert(0); |
|
|
|
} |
|
|
|
} |
|
|
|
interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 1, K); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void gi_sgemm_4x12_pack_A_t( |
|
|
|
float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax) { |
|
|
|
int ksize = kmax - k0; |
|
|
|
int ksize4 = (ksize << 2); |
|
|
|
float* outptr_base = out; |
|
|
|
|
|
|
|
int k = k0; |
|
|
|
for (; k + 3 < kmax; k += 4) { |
|
|
|
const float* inptr = in + k * ldin + x0; |
|
|
|
const float* inptr1 = inptr + ldin; |
|
|
|
const float* inptr2 = inptr1 + ldin; |
|
|
|
const float* inptr3 = inptr2 + ldin; |
|
|
|
|
|
|
|
int x = x0; |
|
|
|
auto outptr = outptr_base; |
|
|
|
for (; x + 4 <= xmax; x += 4) { |
|
|
|
auto outptr_interleave = outptr; |
|
|
|
interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave); |
|
|
|
outptr += ksize4; |
|
|
|
} |
|
|
|
|
|
|
|
if (x < xmax) { |
|
|
|
interleave_4(inptr, inptr1, inptr2, inptr3, outptr, 4, xmax - x); |
|
|
|
} |
|
|
|
|
|
|
|
outptr_base += 4 * 4; |
|
|
|
} |
|
|
|
|
|
|
|
for (; k < kmax; k++) { |
|
|
|
const float* inptr = in + k * ldin + x0; |
|
|
|
int x = x0; |
|
|
|
auto outptr = outptr_base; |
|
|
|
for (; x + 4 <= xmax; x += 4) { |
|
|
|
auto outptr_interleave = outptr; |
|
|
|
interleave_1x4_1_s(inptr, outptr_interleave); |
|
|
|
outptr += ksize4; |
|
|
|
} |
|
|
|
|
|
|
|
if (x < xmax) { |
|
|
|
interleave_1(inptr, outptr, 4, xmax - x); |
|
|
|
} |
|
|
|
|
|
|
|
outptr_base += 4; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void gi_sgemm_4x12_pack_B_n( |
|
|
|
float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax) { |
|
|
|
int ksize = kmax - k0; |
|
|
|
int ksize12 = ksize * 12; |
|
|
|
int ksize4 = (ksize << 2); |
|
|
|
float* outptr_base = out; |
|
|
|
float* outptr_base4 = outptr_base + (xmax - x0) / 12 * ksize12; |
|
|
|
|
|
|
|
int k = k0; |
|
|
|
for (; k + 3 < kmax; k += 4) { |
|
|
|
const float* inptr = in + k * ldin + x0; |
|
|
|
const float* inptr1 = inptr + ldin; |
|
|
|
const float* inptr2 = inptr1 + ldin; |
|
|
|
const float* inptr3 = inptr2 + ldin; |
|
|
|
|
|
|
|
int x = x0; |
|
|
|
auto outptr = outptr_base; |
|
|
|
for (; x + 12 <= xmax; x += 12) { |
|
|
|
auto outptr_interleave = outptr; |
|
|
|
interleave_4x12_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave); |
|
|
|
outptr += ksize12; |
|
|
|
} |
|
|
|
outptr = outptr_base4; |
|
|
|
for (; x + 4 <= xmax; x += 4) { |
|
|
|
auto outptr_interleave = outptr; |
|
|
|
interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave); |
|
|
|
outptr += ksize4; |
|
|
|
} |
|
|
|
|
|
|
|
if (x < xmax) { |
|
|
|
interleave_4(inptr, inptr1, inptr2, inptr3, outptr, 4, xmax - x); |
|
|
|
} |
|
|
|
|
|
|
|
outptr_base += 12 * 4; |
|
|
|
outptr_base4 += 4 * 4; |
|
|
|
} |
|
|
|
|
|
|
|
for (; k < kmax; k++) { |
|
|
|
const float* inptr = in + k * ldin + x0; |
|
|
|
int x = x0; |
|
|
|
auto outptr = outptr_base; |
|
|
|
for (; x + 12 <= xmax; x += 12) { |
|
|
|
auto outptr_interleave = outptr; |
|
|
|
interleave_1x12_1_s(inptr, outptr_interleave); |
|
|
|
outptr += ksize12; |
|
|
|
} |
|
|
|
outptr = outptr_base4; |
|
|
|
for (; x + 4 <= xmax; x += 4) { |
|
|
|
auto outptr_interleave = outptr; |
|
|
|
interleave_1x4_1_s(inptr, outptr_interleave); |
|
|
|
outptr += ksize4; |
|
|
|
} |
|
|
|
|
|
|
|
if (x < xmax) { |
|
|
|
interleave_1(inptr, outptr, 4, xmax - x); |
|
|
|
} |
|
|
|
|
|
|
|
outptr_base += 12; |
|
|
|
outptr_base4 += 4; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void gi_sgemm_4x12_pack_B_t( |
|
|
|
float* out, const float* in, int ldin, int y0, int ymax, int k0, int kmax) { |
|
|
|
float* outptr = out; |
|
|
|
const float* inptr = in; |
|
|
|
float zerobuff[4]; |
|
|
|
std::memset(zerobuff, 0, sizeof(float) * 4); |
|
|
|
int K12 = 12 * (kmax - k0); |
|
|
|
|
|
|
|
int y = y0; |
|
|
|
|
|
|
|
for (; y + 12 <= ymax; y += 12) { |
|
|
|
int yi = y; |
|
|
|
for (; yi < y + 12; yi += 4) { |
|
|
|
const float* inptr0 = inptr + yi * ldin + k0; |
|
|
|
const float* inptr1 = inptr0 + ldin; |
|
|
|
const float* inptr2 = inptr1 + ldin; |
|
|
|
const float* inptr3 = inptr2 + ldin; |
|
|
|
float* outptr_inner = outptr + yi - y; |
|
|
|
|
|
|
|
int x = (kmax - k0); |
|
|
|
for (; x > 3; x -= 4) { |
|
|
|
transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr_inner, 48); |
|
|
|
} |
|
|
|
for (; x > 0; x--) { |
|
|
|
*outptr_inner++ = *inptr0++; |
|
|
|
*outptr_inner++ = *inptr1++; |
|
|
|
*outptr_inner++ = *inptr2++; |
|
|
|
*outptr_inner++ = *inptr3++; |
|
|
|
outptr_inner += 8; |
|
|
|
} |
|
|
|
} |
|
|
|
outptr += K12; |
|
|
|
} |
|
|
|
|
|
|
|
for (; y < ymax; y += 4) { |
|
|
|
const float* inptr0 = inptr + y * ldin + k0; |
|
|
|
const float* inptr1 = inptr0 + ldin; |
|
|
|
const float* inptr2 = inptr1 + ldin; |
|
|
|
const float* inptr3 = inptr2 + ldin; |
|
|
|
|
|
|
|
/* Cope with ragged cases by copying from a buffer of zeroes instead |
|
|
|
*/ |
|
|
|
int x = (kmax - k0); |
|
|
|
for (; x > 3; x -= 4) { |
|
|
|
if ((y + 3) >= ymax) { |
|
|
|
switch ((y + 3) - ymax) { |
|
|
|
/* Everything falls through in here */ |
|
|
|
case 2: |
|
|
|
inptr1 = zerobuff; |
|
|
|
MEGDNN_FALLTHRU |
|
|
|
case 1: |
|
|
|
inptr2 = zerobuff; |
|
|
|
MEGDNN_FALLTHRU |
|
|
|
case 0: |
|
|
|
inptr3 = zerobuff; |
|
|
|
break; |
|
|
|
default: |
|
|
|
megdnn_assert(0); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr); |
|
|
|
} |
|
|
|
|
|
|
|
if (x > 0) { |
|
|
|
if ((y + 3) >= ymax) { |
|
|
|
switch ((y + 3) - ymax) { |
|
|
|
/* Everything falls through in here */ |
|
|
|
case 2: |
|
|
|
inptr1 = zerobuff; |
|
|
|
MEGDNN_FALLTHRU |
|
|
|
case 1: |
|
|
|
inptr2 = zerobuff; |
|
|
|
MEGDNN_FALLTHRU |
|
|
|
case 0: |
|
|
|
inptr3 = zerobuff; |
|
|
|
break; |
|
|
|
default: |
|
|
|
megdnn_assert(0); |
|
|
|
} |
|
|
|
} |
|
|
|
interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 1, x); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
} // namespace |
|
|
|
|
|
|
|
MEGDNN_REG_GEMM_STRATEGY_IMPL(gi_sgemm_4x12); |
|
|
|
|
|
|
|
void gi_sgemm_4x12::pack_A( |
|
|
|
float* out, const float* in, int ldin, int y0, int ymax, int k0, int kmax, |
|
|
|
bool transpose_A) const { |
|
|
|
if (transpose_A) { |
|
|
|
gi_sgemm_4x12_pack_A_t(out, in, ldin, y0, ymax, k0, kmax); |
|
|
|
} else { |
|
|
|
gi_sgemm_4x12_pack_A_n(out, in, ldin, y0, ymax, k0, kmax); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void gi_sgemm_4x12::pack_B( |
|
|
|
float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax, |
|
|
|
bool transpose_B) const { |
|
|
|
if (transpose_B) { |
|
|
|
gi_sgemm_4x12_pack_B_t(out, in, ldin, x0, xmax, k0, kmax); |
|
|
|
} else { |
|
|
|
gi_sgemm_4x12_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void gi_sgemm_4x12::kern( |
|
|
|
const float* packA, const float* packB, size_t M, size_t N, size_t K, float* C, |
|
|
|
size_t LDC, bool is_first_k, const float*, float*) const { |
|
|
|
megdnn_assert( |
|
|
|
A_dtype.enumv() == B_dtype.enumv() && A_dtype.enumv() == C_dtype.enumv() && |
|
|
|
A_dtype.enumv() == DTypeEnum::Float32); |
|
|
|
MEGDNN_MARK_USED_VAR(A_dtype); |
|
|
|
MEGDNN_MARK_USED_VAR(B_dtype); |
|
|
|
MEGDNN_MARK_USED_VAR(C_dtype); |
|
|
|
|
|
|
|
constexpr size_t A_INTERLEAVE = 4; |
|
|
|
constexpr size_t B_INTERLEAVE = 12; |
|
|
|
const int K12 = K * 12; |
|
|
|
const int K4 = K * 4; |
|
|
|
|
|
|
|
size_t m = 0; |
|
|
|
for (; m < M; m += A_INTERLEAVE) { |
|
|
|
float* output = C + (m * LDC); |
|
|
|
|
|
|
|
size_t n = 0; |
|
|
|
const float* cur_packB = packB; |
|
|
|
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { |
|
|
|
kern_4x12( |
|
|
|
packA, cur_packB, K, output, LDC, is_first_k, |
|
|
|
std::min<size_t>(M - m, 4)); |
|
|
|
output += B_INTERLEAVE; |
|
|
|
cur_packB += K12; |
|
|
|
} |
|
|
|
|
|
|
|
for (; n < N; n += 4) { |
|
|
|
kern_4x4( |
|
|
|
packA, cur_packB, K, output, LDC, is_first_k, |
|
|
|
std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4)); |
|
|
|
output += 4; |
|
|
|
cur_packB += K4; |
|
|
|
} |
|
|
|
|
|
|
|
packA += K4; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// vim: syntax=cpp.doxygen |