Browse Source

feat(gi): make matrix_mul apply gi class type

GitOrigin-RevId: 0c0029ee60
release-1.10
Megvii Engine Team 3 years ago
parent
commit
74fb63db29
3 changed files with 29 additions and 17 deletions
  1. +8
    -8
      dnn/src/fallback/matrix_mul/gi/fp32/common.h
  2. +19
    -9
      dnn/src/fallback/matrix_mul/gi/fp32/exec_sgemv.cpp
  3. +2
    -0
      dnn/src/fallback/matrix_mul/gi/fp32/strategy_mk4_4x8.cpp

+ 8
- 8
dnn/src/fallback/matrix_mul/gi/fp32/common.h View File

@@ -193,24 +193,24 @@ static GI_FORCEINLINE void transpose_4x4_1_s(
GI_FLOAT32_V2_t q0q1 = GiZipqFloat32(d0d1, d2d3); GI_FLOAT32_V2_t q0q1 = GiZipqFloat32(d0d1, d2d3);
GI_FLOAT32_V2_t q2q3 = GiZipqFloat32(d4d5, d6d7); GI_FLOAT32_V2_t q2q3 = GiZipqFloat32(d4d5, d6d7);


GiSt1Float32(outptr, GiGetLowFloat32(q0q1.val[0]));
GiSt1Float32(outptr, GiGetLowFloat32(GiGetSubVectorFloat32V2(q0q1, 0)));
outptr += 2; outptr += 2;
GiSt1Float32(outptr, GiGetLowFloat32(q2q3.val[0]));
GiSt1Float32(outptr, GiGetLowFloat32(GiGetSubVectorFloat32V2(q2q3, 0)));
outptr += stride; outptr += stride;


GiSt1Float32(outptr, GiGetHighFloat32(q0q1.val[0]));
GiSt1Float32(outptr, GiGetHighFloat32(GiGetSubVectorFloat32V2(q0q1, 0)));
outptr += 2; outptr += 2;
GiSt1Float32(outptr, GiGetHighFloat32(q2q3.val[0]));
GiSt1Float32(outptr, GiGetHighFloat32(GiGetSubVectorFloat32V2(q2q3, 0)));
outptr += stride; outptr += stride;


GiSt1Float32(outptr, GiGetLowFloat32(q0q1.val[1]));
GiSt1Float32(outptr, GiGetLowFloat32(GiGetSubVectorFloat32V2(q0q1, 1)));
outptr += 2; outptr += 2;
GiSt1Float32(outptr, GiGetLowFloat32(q2q3.val[1]));
GiSt1Float32(outptr, GiGetLowFloat32(GiGetSubVectorFloat32V2(q2q3, 1)));
outptr += stride; outptr += stride;


GiSt1Float32(outptr, GiGetHighFloat32(q0q1.val[1]));
GiSt1Float32(outptr, GiGetHighFloat32(GiGetSubVectorFloat32V2(q0q1, 1)));
outptr += 2; outptr += 2;
GiSt1Float32(outptr, GiGetHighFloat32(q2q3.val[1]));
GiSt1Float32(outptr, GiGetHighFloat32(GiGetSubVectorFloat32V2(q2q3, 1)));
outptr += stride; outptr += stride;
} }




+ 19
- 9
dnn/src/fallback/matrix_mul/gi/fp32/exec_sgemv.cpp View File

@@ -24,21 +24,26 @@ void sgemv_gi_naive_n_mk4(
while (m < M) { while (m < M) {
auto Aptr0 = Aptr; auto Aptr0 = Aptr;
auto Cptr0 = Cptr; auto Cptr0 = Cptr;
GI_FLOAT32_t c[4];
#define INIT(step) c[step] = GiBroadcastFloat32(0.0f);
GI_FLOAT32_V4_t c;
#define INIT(step) GiSetSubVectorFloat32V4(c, step, GiBroadcastFloat32(0.0f));
UNROLL_CALL_RAW(4, INIT) UNROLL_CALL_RAW(4, INIT)
#undef INIT #undef INIT
auto Bptr = B; auto Bptr = B;
size_t k = 0; size_t k = 0;
while (k < K) { while (k < K) {
GI_FLOAT32_t b = GiLoadFloat32(Bptr); GI_FLOAT32_t b = GiLoadFloat32(Bptr);
GI_FLOAT32_V2_t a[2];
#define LOAD_A(step) a[step] = GiLoadFloat32V2(Aptr0 + step * 8);
UNROLL_CALL_RAW(2, LOAD_A)
GI_FLOAT32_V4_t a;
#define LOAD_A(step) GiSetSubVectorFloat32V4(a, step, GiLoadFloat32(Aptr0 + step * 4));
UNROLL_CALL_RAW(4, LOAD_A)
#undef LOAD_A #undef LOAD_A


#define COMPT(step) \
c[step] = GiSimdFmaLane(c[step], a[step / 2].val[step % 2], b, step % 4);
#define COMPT(step) \
t = GiSimdFmaLane( \
GiGetSubVectorFloat32V4(c, step), GiGetSubVectorFloat32V4(a, step), b, \
step % 4); \
GiSetSubVectorFloat32V4(c, step, t);

GI_FLOAT32_t t;
UNROLL_CALL_RAW(4, COMPT) UNROLL_CALL_RAW(4, COMPT)
#undef COMPT #undef COMPT
Bptr += Bstride; Bptr += Bstride;
@@ -46,11 +51,16 @@ void sgemv_gi_naive_n_mk4(
k += PACK_SIZE; k += PACK_SIZE;
} }


#define ADD_C(step, stride) c[step] = GiAddFloat32(c[step], c[step + stride]);
#define ADD_C(step, stride) \
t = GiAddFloat32( \
GiGetSubVectorFloat32V4(c, step), \
GiGetSubVectorFloat32V4(c, step + stride)); \
GiSetSubVectorFloat32V4(c, step, t);
GI_FLOAT32_t t;
UNROLL_CALL_RAW(2, ADD_C, 2) UNROLL_CALL_RAW(2, ADD_C, 2)
UNROLL_CALL_RAW(1, ADD_C, 1) UNROLL_CALL_RAW(1, ADD_C, 1)
#undef ADD_C #undef ADD_C
GiStoreFloat32(Cptr0, c[0]);
GiStoreFloat32(Cptr0, GiGetSubVectorFloat32V4(c, 0));


Aptr += Astride; Aptr += Astride;
Cptr += Cstride; Cptr += Cstride;


+ 2
- 0
dnn/src/fallback/matrix_mul/gi/fp32/strategy_mk4_4x8.cpp View File

@@ -82,6 +82,7 @@ void kern_4x4(const float* A, const float* B, size_t LDB, size_t K, float* C) {
GI_FLOAT32_t d6d7 = GiLoadFloat32(B); GI_FLOAT32_t d6d7 = GiLoadFloat32(B);
B = B + 4; B = B + 4;


GI_FLOAT32_t vfzero = GiBroadcastFloat32(0.0f);
GI_FLOAT32_t d16d17 = GiSimdFmaLane(vfzero, d8d9, d0d1, 0); GI_FLOAT32_t d16d17 = GiSimdFmaLane(vfzero, d8d9, d0d1, 0);
GI_FLOAT32_t d18d19 = GiSimdFmaLane(vfzero, d8d9, d2d3, 0); GI_FLOAT32_t d18d19 = GiSimdFmaLane(vfzero, d8d9, d2d3, 0);
GI_FLOAT32_t d20d21 = GiSimdFmaLane(vfzero, d8d9, d4d5, 0); GI_FLOAT32_t d20d21 = GiSimdFmaLane(vfzero, d8d9, d4d5, 0);
@@ -173,6 +174,7 @@ void kern_4x8(const float* A, const float* B, size_t LDB, size_t K, float* C) {
B = B + 4; B = B + 4;
GI_FLOAT32_t d6d7 = GiLoadFloat32(B); GI_FLOAT32_t d6d7 = GiLoadFloat32(B);
B = B + 4; B = B + 4;
GI_FLOAT32_t vfzero = GiBroadcastFloat32(0.0f);
GI_FLOAT32_t d16d17 = GiSimdFmaLane(vfzero, d8d9, d0d1, 0); GI_FLOAT32_t d16d17 = GiSimdFmaLane(vfzero, d8d9, d0d1, 0);
d16d17 = GiSimdFmaLane(d16d17, d10d11, d0d1, 1); d16d17 = GiSimdFmaLane(d16d17, d10d11, d0d1, 1);
GI_FLOAT32_t d18d19 = GiSimdFmaLane(vfzero, d8d9, d2d3, 0); GI_FLOAT32_t d18d19 = GiSimdFmaLane(vfzero, d8d9, d2d3, 0);


Loading…
Cancel
Save