diff --git a/dnn/src/fallback/conv_bias/gi/fp32/channel_wise_3x3_s1p1_nchw44_kern.cpp b/dnn/src/fallback/conv_bias/gi/fp32/channel_wise_3x3_s1p1_nchw44_kern.cpp index 03bb0cd4..ae2f8aa8 100644 --- a/dnn/src/fallback/conv_bias/gi/fp32/channel_wise_3x3_s1p1_nchw44_kern.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/channel_wise_3x3_s1p1_nchw44_kern.cpp @@ -1,7 +1,6 @@ #include "src/fallback/conv_bias/gi/fp32/channel_wise_3x3_s1p1_nchw44_kern.h" #include "src/common/utils.h" #include "src/fallback/conv_bias/common.h" -#include "src/fallback/conv_bias/gi/utils.h" #include "src/fallback/elemwise_helper/elemwise_op.h" #pragma GCC diagnostic ignored "-Wunused-parameter" diff --git a/dnn/src/fallback/conv_bias/gi/fp32/channel_wise_5x5_s1p2_nchw44_kern.cpp b/dnn/src/fallback/conv_bias/gi/fp32/channel_wise_5x5_s1p2_nchw44_kern.cpp index 54e1df8e..b2f30207 100644 --- a/dnn/src/fallback/conv_bias/gi/fp32/channel_wise_5x5_s1p2_nchw44_kern.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/channel_wise_5x5_s1p2_nchw44_kern.cpp @@ -1,7 +1,6 @@ #include "src/fallback/conv_bias/gi/fp32/channel_wise_5x5_s1p2_nchw44_kern.h" #include "src/common/utils.h" #include "src/fallback/conv_bias/common.h" -#include "src/fallback/conv_bias/gi/utils.h" #include "src/fallback/elemwise_helper/elemwise_op.h" #pragma GCC diagnostic ignored "-Wunused-parameter" diff --git a/dnn/src/fallback/conv_bias/gi/fp32/channel_wise_nchw44_kern.cpp b/dnn/src/fallback/conv_bias/gi/fp32/channel_wise_nchw44_kern.cpp index b64b2557..72d0863a 100644 --- a/dnn/src/fallback/conv_bias/gi/fp32/channel_wise_nchw44_kern.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/channel_wise_nchw44_kern.cpp @@ -4,7 +4,6 @@ #include "src/fallback/conv_bias/common.h" #include "src/fallback/conv_bias/gi/fp32/channel_wise_3x3_s1p1_nchw44_kern.h" #include "src/fallback/conv_bias/gi/fp32/channel_wise_5x5_s1p2_nchw44_kern.h" -#include "src/fallback/conv_bias/gi/utils.h" #include "src/fallback/elemwise_helper/elemwise_op.h" using namespace megdnn; diff --git a/dnn/src/fallback/conv_bias/gi/fp32/filter_transform.h b/dnn/src/fallback/conv_bias/gi/fp32/filter_transform.h index de1d15ce..74871894 100644 --- a/dnn/src/fallback/conv_bias/gi/fp32/filter_transform.h +++ b/dnn/src/fallback/conv_bias/gi/fp32/filter_transform.h @@ -3,29 +3,44 @@ #include "src/common/unroll_macro.h" #include "src/common/utils.h" #include "src/fallback/conv_bias/gi/fp32/helper.h" -#include "src/fallback/conv_bias/gi/utils.h" namespace megdnn { namespace fallback { +/* + * wd##0 = d##0; + * tmp0 = (d##0 + d##2) * -0.2222222f + * tmp1 = d##1 * -0.2222222 + * wd##1 = tmp0 + tmp1 + * wd##2 = tmp0 - tmp1 + * tmp0 = d##0 * 0.0111111f + d##2 * 0.0444444f + * tmp1 = d##1 * 0.0222222f + * wd##3 = tmp0 + tmp1 + * wd##4 = tmp0 - tmp1 + * tmp0 = d##0 * 0.7111111f + d##2 * 0.1777778f + * tmp1 = d##1 * 0.3555556f + * wd##5 = tmp0 + tmp1 + * wd##6 = tmp0 - tmp1 + * wd##7 = d##2 + */ template struct FilterTransform6X3 { -#define FILTER_TRANSFORM(d, wd) \ - do { \ - wd##0 = d##0; \ - auto tmp0 = (d##0 + d##2) * -0.2222222f; \ - auto tmp1 = d##1 * -0.2222222f; \ - wd##1 = tmp0 + tmp1; \ - wd##2 = tmp0 - tmp1; \ - tmp0 = d##0 * 0.0111111f + d##2 * 0.0444444f; \ - tmp1 = d##1 * 0.0222222f; \ - wd##3 = tmp0 + tmp1; \ - wd##4 = tmp0 - tmp1; \ - tmp0 = d##0 * 0.7111111f + d##2 * 0.1777778f; \ - tmp1 = d##1 * 0.3555556f; \ - wd##5 = tmp0 + tmp1; \ - wd##6 = tmp0 - tmp1; \ - wd##7 = d##2; \ +#define FILTER_TRANSFORM(d, wd, ADDC, SUBC, MULC) \ + do { \ + wd##0 = d##0; \ + auto tmp0 = MULC(ADDC(d##0, d##2), -0.2222222f); \ + auto tmp1 = MULC(d##1, -0.2222222f); \ + wd##1 = ADDC(tmp0, tmp1); \ + wd##2 = SUBC(tmp0, tmp1); \ + tmp0 = ADDC(MULC(d##0, 0.0111111f), MULC(d##2, 0.0444444f)); \ + tmp1 = MULC(d##1, 0.0222222f); \ + wd##3 = ADDC(tmp0, tmp1); \ + wd##4 = SUBC(tmp0, tmp1); \ + tmp0 = ADDC(MULC(d##0, 0.7111111f), MULC(d##2, 0.1777778f)); \ + tmp1 = MULC(d##1, 0.3555556f); \ + wd##5 = ADDC(tmp0, tmp1); \ + wd##6 = SUBC(tmp0, tmp1); \ + wd##7 = d##2; \ } while (0); static void transform( @@ -49,37 +64,35 @@ struct FilterTransform6X3 { rep(ic, IC) { const float* fptr = filter + (oc * IC + ic) * 3 * 3; - Vector g0 = Vector::load(fptr); - Vector g1 = Vector::load(fptr + 3); + GI_FLOAT32_t g0 = GiLoadFloat32(fptr); + GI_FLOAT32_t g1 = GiLoadFloat32(fptr + 3); - Vector g2 = Vector::load(fptr + 6 - 1); + GI_FLOAT32_t g2 = GiLoadFloat32(fptr + 6 - 1); GI_FLOAT32_t zeros = GiZeroFloat32(); - g2.value = GiFloat32Type2FixLenType( - GiExtqFloat32(GiFixLenType2GiFloat32Type(g2.value), zeros, 1)); + g2 = GiExtqFloat32(g2, zeros, 1); -#define cb(i) Vector wd##i; +#define cb(i) GI_FLOAT32_t wd##i; UNROLL_CALL_NOWRAPPER(8, cb); #undef cb -#define cb(i) Vector wdt##i; - UNROLL_CALL_NOWRAPPER(3, cb); -#undef cb - -#define cb(i) Vector ret##i; - UNROLL_CALL_NOWRAPPER(8, cb); -#undef cb - - FILTER_TRANSFORM(g, wd); + FILTER_TRANSFORM(g, wd, ADDF, SUBF, MULSF); size_t ocb = oc / 4; size_t oc4 = oc % 4; size_t icb = ic / 4; size_t ic4 = ic % 4; #if MEGDNN_AARCH64 +#define cb(i) GI_FLOAT32_V2_t wdt##i; + UNROLL_CALL_NOWRAPPER(3, cb); +#undef cb + +#define cb(i) GI_FLOAT32_V2_t ret##i; + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb TRANSPOSE_8x3(wd, wdt); - FILTER_TRANSFORM(wdt, ret); + FILTER_TRANSFORM(wdt, ret, ADDFV2, SUBFV2, MULSFV2); -#define cb(i) ret##i.save(transform_mid_buf + i * alpha); +#define cb(i) GiStoreFloat32V2(transform_mid_buf + i * alpha, ret##i); UNROLL_CALL_NOWRAPPER(8, cb); #undef cb rep(i, alpha) rep(j, alpha) { @@ -116,8 +129,7 @@ struct FilterTransform6X3 { mid_buf1[7] = GET_VECTOR_ELEM(wd, i, 2); \ mid_buf1 += 8; \ } while (0); -#define GET_VECTOR_ELEM(s, i, idx) \ - GiExtractLane##idx##Float32(GiFixLenType2GiFloat32Type(CONCAT(s, i).value)) +#define GET_VECTOR_ELEM(s, i, idx) GiExtractLane##idx##Float32(CONCAT(s, i)) float* mid_buf1 = transform_mid_buf; UNROLL_CALL_NOWRAPPER(8, cb); diff --git a/dnn/src/fallback/conv_bias/gi/fp32/helper.h b/dnn/src/fallback/conv_bias/gi/fp32/helper.h index 2b38f8ef..00ecce48 100644 --- a/dnn/src/fallback/conv_bias/gi/fp32/helper.h +++ b/dnn/src/fallback/conv_bias/gi/fp32/helper.h @@ -2,6 +2,15 @@ #include "src/common/unroll_macro.h" #include "src/fallback/general_intrinsic/gi_float.h" +#define ADDF GiAddFloat32 +#define ADDFV2 GiAddFloat32V2 +#define SUBF GiSubtractFloat32 +#define SUBFV2 GiSubtractFloat32V2 +#define MULF GiMultiplyFloat32 +#define MULFV2 GiMultiplyFloat32V2 +#define MULSF GiMultiplyScalerFloat32 +#define MULSFV2 GiMultiplyScalerFloat32V2 + namespace megdnn { namespace fallback { inline void transpose_4x4(const float* src, float* dst, int lda, int ldb) { @@ -47,159 +56,166 @@ inline void transpose_4x4(const float* src, float* dst, int lda, int ldb) { #define CONCAT(a, idx) a##idx #if MEGDNN_AARCH64 -//! ret and a are type Vector -#define TRANSPOSE_8x8(a, ret) \ - do { \ - auto b0 = GiZipqFloat32(CONCAT(a, 0).value.val[0], CONCAT(a, 1).value.val[0]); \ - auto b1 = GiZipqFloat32(CONCAT(a, 0).value.val[1], CONCAT(a, 1).value.val[1]); \ - auto b2 = GiZipqFloat32(CONCAT(a, 2).value.val[0], CONCAT(a, 3).value.val[0]); \ - auto b3 = GiZipqFloat32(CONCAT(a, 2).value.val[1], CONCAT(a, 3).value.val[1]); \ - auto b4 = GiZipqFloat32(CONCAT(a, 4).value.val[0], CONCAT(a, 5).value.val[0]); \ - auto b5 = GiZipqFloat32(CONCAT(a, 4).value.val[1], CONCAT(a, 5).value.val[1]); \ - auto b6 = GiZipqFloat32(CONCAT(a, 6).value.val[0], CONCAT(a, 7).value.val[0]); \ - auto b7 = GiZipqFloat32(CONCAT(a, 6).value.val[1], CONCAT(a, 7).value.val[1]); \ - CONCAT(ret, 0).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ - GiReinterpretqFloat32ToS64(b0.val[0]), \ - GiReinterpretqFloat32ToS64(b2.val[0]))); \ - CONCAT(ret, 0).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ - GiReinterpretqFloat32ToS64(b4.val[0]), \ - GiReinterpretqFloat32ToS64(b6.val[0]))); \ - CONCAT(ret, 1).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ - GiReinterpretqFloat32ToS64(b0.val[0]), \ - GiReinterpretqFloat32ToS64(b2.val[0]))); \ - CONCAT(ret, 1).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ - GiReinterpretqFloat32ToS64(b4.val[0]), \ - GiReinterpretqFloat32ToS64(b6.val[0]))); \ - CONCAT(ret, 2).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ - GiReinterpretqFloat32ToS64(b0.val[1]), \ - GiReinterpretqFloat32ToS64(b2.val[1]))); \ - CONCAT(ret, 2).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ - GiReinterpretqFloat32ToS64(b4.val[1]), \ - GiReinterpretqFloat32ToS64(b6.val[1]))); \ - CONCAT(ret, 3).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ - GiReinterpretqFloat32ToS64(b0.val[1]), \ - GiReinterpretqFloat32ToS64(b2.val[1]))); \ - CONCAT(ret, 3).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ - GiReinterpretqFloat32ToS64(b4.val[1]), \ - GiReinterpretqFloat32ToS64(b6.val[1]))); \ - CONCAT(ret, 4).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ - GiReinterpretqFloat32ToS64(b1.val[0]), \ - GiReinterpretqFloat32ToS64(b3.val[0]))); \ - CONCAT(ret, 4).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ - GiReinterpretqFloat32ToS64(b5.val[0]), \ - GiReinterpretqFloat32ToS64(b7.val[0]))); \ - CONCAT(ret, 5).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ - GiReinterpretqFloat32ToS64(b1.val[0]), \ - GiReinterpretqFloat32ToS64(b3.val[0]))); \ - CONCAT(ret, 5).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ - GiReinterpretqFloat32ToS64(b5.val[0]), \ - GiReinterpretqFloat32ToS64(b7.val[0]))); \ - CONCAT(ret, 6).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ - GiReinterpretqFloat32ToS64(b1.val[1]), \ - GiReinterpretqFloat32ToS64(b3.val[1]))); \ - CONCAT(ret, 6).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ - GiReinterpretqFloat32ToS64(b5.val[1]), \ - GiReinterpretqFloat32ToS64(b7.val[1]))); \ - CONCAT(ret, 7).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ - GiReinterpretqFloat32ToS64(b1.val[1]), \ - GiReinterpretqFloat32ToS64(b3.val[1]))); \ - CONCAT(ret, 7).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ - GiReinterpretqFloat32ToS64(b5.val[1]), \ - GiReinterpretqFloat32ToS64(b7.val[1]))); \ +#define TRANSPOSE_8x8(a, ret) \ + do { \ + auto b0 = GiZipqFloat32(CONCAT(a, 0).val[0], CONCAT(a, 1).val[0]); \ + auto b1 = GiZipqFloat32(CONCAT(a, 0).val[1], CONCAT(a, 1).val[1]); \ + auto b2 = GiZipqFloat32(CONCAT(a, 2).val[0], CONCAT(a, 3).val[0]); \ + auto b3 = GiZipqFloat32(CONCAT(a, 2).val[1], CONCAT(a, 3).val[1]); \ + auto b4 = GiZipqFloat32(CONCAT(a, 4).val[0], CONCAT(a, 5).val[0]); \ + auto b5 = GiZipqFloat32(CONCAT(a, 4).val[1], CONCAT(a, 5).val[1]); \ + auto b6 = GiZipqFloat32(CONCAT(a, 6).val[0], CONCAT(a, 7).val[0]); \ + auto b7 = GiZipqFloat32(CONCAT(a, 6).val[1], CONCAT(a, 7).val[1]); \ + CONCAT(ret, 0).val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ + GiReinterpretqFloat32ToS64(b0.val[0]), \ + GiReinterpretqFloat32ToS64(b2.val[0]))); \ + CONCAT(ret, 0).val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ + GiReinterpretqFloat32ToS64(b4.val[0]), \ + GiReinterpretqFloat32ToS64(b6.val[0]))); \ + CONCAT(ret, 1).val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ + GiReinterpretqFloat32ToS64(b0.val[0]), \ + GiReinterpretqFloat32ToS64(b2.val[0]))); \ + CONCAT(ret, 1).val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ + GiReinterpretqFloat32ToS64(b4.val[0]), \ + GiReinterpretqFloat32ToS64(b6.val[0]))); \ + CONCAT(ret, 2).val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ + GiReinterpretqFloat32ToS64(b0.val[1]), \ + GiReinterpretqFloat32ToS64(b2.val[1]))); \ + CONCAT(ret, 2).val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ + GiReinterpretqFloat32ToS64(b4.val[1]), \ + GiReinterpretqFloat32ToS64(b6.val[1]))); \ + CONCAT(ret, 3).val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ + GiReinterpretqFloat32ToS64(b0.val[1]), \ + GiReinterpretqFloat32ToS64(b2.val[1]))); \ + CONCAT(ret, 3).val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ + GiReinterpretqFloat32ToS64(b4.val[1]), \ + GiReinterpretqFloat32ToS64(b6.val[1]))); \ + CONCAT(ret, 4).val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ + GiReinterpretqFloat32ToS64(b1.val[0]), \ + GiReinterpretqFloat32ToS64(b3.val[0]))); \ + CONCAT(ret, 4).val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ + GiReinterpretqFloat32ToS64(b5.val[0]), \ + GiReinterpretqFloat32ToS64(b7.val[0]))); \ + CONCAT(ret, 5).val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ + GiReinterpretqFloat32ToS64(b1.val[0]), \ + GiReinterpretqFloat32ToS64(b3.val[0]))); \ + CONCAT(ret, 5).val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ + GiReinterpretqFloat32ToS64(b5.val[0]), \ + GiReinterpretqFloat32ToS64(b7.val[0]))); \ + CONCAT(ret, 6).val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ + GiReinterpretqFloat32ToS64(b1.val[1]), \ + GiReinterpretqFloat32ToS64(b3.val[1]))); \ + CONCAT(ret, 6).val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ + GiReinterpretqFloat32ToS64(b5.val[1]), \ + GiReinterpretqFloat32ToS64(b7.val[1]))); \ + CONCAT(ret, 7).val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ + GiReinterpretqFloat32ToS64(b1.val[1]), \ + GiReinterpretqFloat32ToS64(b3.val[1]))); \ + CONCAT(ret, 7).val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ + GiReinterpretqFloat32ToS64(b5.val[1]), \ + GiReinterpretqFloat32ToS64(b7.val[1]))); \ } while (0); -#define TRANSPOSE_8x3(a, ret) \ - auto b0 = GiZipqFloat32(CONCAT(a, 0).value, CONCAT(a, 1).value); \ - auto b1 = GiZipqFloat32(CONCAT(a, 2).value, CONCAT(a, 3).value); \ - auto b2 = GiZipqFloat32(CONCAT(a, 4).value, CONCAT(a, 5).value); \ - auto b3 = GiZipqFloat32(CONCAT(a, 6).value, CONCAT(a, 7).value); \ - CONCAT(ret, 0).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ - GiReinterpretqFloat32ToS64(b0.val[0]), \ - GiReinterpretqFloat32ToS64(b1.val[0]))); \ - CONCAT(ret, 0).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ - GiReinterpretqFloat32ToS64(b2.val[0]), \ - GiReinterpretqFloat32ToS64(b3.val[0]))); \ - CONCAT(ret, 1).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ - GiReinterpretqFloat32ToS64(b0.val[0]), \ - GiReinterpretqFloat32ToS64(b1.val[0]))); \ - CONCAT(ret, 1).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ - GiReinterpretqFloat32ToS64(b2.val[0]), \ - GiReinterpretqFloat32ToS64(b3.val[0]))); \ - CONCAT(ret, 2).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ - GiReinterpretqFloat32ToS64(b0.val[1]), \ - GiReinterpretqFloat32ToS64(b1.val[1]))); \ - CONCAT(ret, 2).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ - GiReinterpretqFloat32ToS64(b2.val[1]), \ +#define TRANSPOSE_8x3(a, ret) \ + auto b0 = GiZipqFloat32(CONCAT(a, 0), CONCAT(a, 1)); \ + auto b1 = GiZipqFloat32(CONCAT(a, 2), CONCAT(a, 3)); \ + auto b2 = GiZipqFloat32(CONCAT(a, 4), CONCAT(a, 5)); \ + auto b3 = GiZipqFloat32(CONCAT(a, 6), CONCAT(a, 7)); \ + CONCAT(ret, 0).val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ + GiReinterpretqFloat32ToS64(b0.val[0]), \ + GiReinterpretqFloat32ToS64(b1.val[0]))); \ + CONCAT(ret, 0).val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ + GiReinterpretqFloat32ToS64(b2.val[0]), \ + GiReinterpretqFloat32ToS64(b3.val[0]))); \ + CONCAT(ret, 1).val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ + GiReinterpretqFloat32ToS64(b0.val[0]), \ + GiReinterpretqFloat32ToS64(b1.val[0]))); \ + CONCAT(ret, 1).val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ + GiReinterpretqFloat32ToS64(b2.val[0]), \ + GiReinterpretqFloat32ToS64(b3.val[0]))); \ + CONCAT(ret, 2).val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ + GiReinterpretqFloat32ToS64(b0.val[1]), \ + GiReinterpretqFloat32ToS64(b1.val[1]))); \ + CONCAT(ret, 2).val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ + GiReinterpretqFloat32ToS64(b2.val[1]), \ GiReinterpretqFloat32ToS64(b3.val[1]))); -#define TRANSPOSE_8x4(a, ret) \ - auto b0 = GiZipqFloat32(CONCAT(a, 0).value, CONCAT(a, 1).value); \ - auto b1 = GiZipqFloat32(CONCAT(a, 2).value, CONCAT(a, 3).value); \ - auto b2 = GiZipqFloat32(CONCAT(a, 4).value, CONCAT(a, 5).value); \ - auto b3 = GiZipqFloat32(CONCAT(a, 6).value, CONCAT(a, 7).value); \ - CONCAT(ret, 0).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ - GiReinterpretqFloat32ToS64(b0.val[0]), \ - GiReinterpretqFloat32ToS64(b1.val[0]))); \ - CONCAT(ret, 0).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ - GiReinterpretqFloat32ToS64(b2.val[0]), \ - GiReinterpretqFloat32ToS64(b3.val[0]))); \ - CONCAT(ret, 1).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ - GiReinterpretqFloat32ToS64(b0.val[0]), \ - GiReinterpretqFloat32ToS64(b1.val[0]))); \ - CONCAT(ret, 1).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ - GiReinterpretqFloat32ToS64(b2.val[0]), \ - GiReinterpretqFloat32ToS64(b3.val[0]))); \ - CONCAT(ret, 2).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ - GiReinterpretqFloat32ToS64(b0.val[1]), \ - GiReinterpretqFloat32ToS64(b1.val[1]))); \ - CONCAT(ret, 2).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ - GiReinterpretqFloat32ToS64(b2.val[1]), \ - GiReinterpretqFloat32ToS64(b3.val[1]))); \ - CONCAT(ret, 3).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ - GiReinterpretqFloat32ToS64(b0.val[1]), \ - GiReinterpretqFloat32ToS64(b1.val[1]))); \ - CONCAT(ret, 3).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ - GiReinterpretqFloat32ToS64(b2.val[1]), \ +#define TRANSPOSE_8x4(a, ret) \ + auto b0 = GiZipqFloat32(CONCAT(a, 0), CONCAT(a, 1)); \ + auto b1 = GiZipqFloat32(CONCAT(a, 2), CONCAT(a, 3)); \ + auto b2 = GiZipqFloat32(CONCAT(a, 4), CONCAT(a, 5)); \ + auto b3 = GiZipqFloat32(CONCAT(a, 6), CONCAT(a, 7)); \ + CONCAT(ret, 0).val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ + GiReinterpretqFloat32ToS64(b0.val[0]), \ + GiReinterpretqFloat32ToS64(b1.val[0]))); \ + CONCAT(ret, 0).val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ + GiReinterpretqFloat32ToS64(b2.val[0]), \ + GiReinterpretqFloat32ToS64(b3.val[0]))); \ + CONCAT(ret, 1).val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ + GiReinterpretqFloat32ToS64(b0.val[0]), \ + GiReinterpretqFloat32ToS64(b1.val[0]))); \ + CONCAT(ret, 1).val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ + GiReinterpretqFloat32ToS64(b2.val[0]), \ + GiReinterpretqFloat32ToS64(b3.val[0]))); \ + CONCAT(ret, 2).val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ + GiReinterpretqFloat32ToS64(b0.val[1]), \ + GiReinterpretqFloat32ToS64(b1.val[1]))); \ + CONCAT(ret, 2).val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ + GiReinterpretqFloat32ToS64(b2.val[1]), \ + GiReinterpretqFloat32ToS64(b3.val[1]))); \ + CONCAT(ret, 3).val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ + GiReinterpretqFloat32ToS64(b0.val[1]), \ + GiReinterpretqFloat32ToS64(b1.val[1]))); \ + CONCAT(ret, 3).val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ + GiReinterpretqFloat32ToS64(b2.val[1]), \ GiReinterpretqFloat32ToS64(b3.val[1]))); #else -#define TRANSPOSE_8x4(a, ret) \ - auto b0 = GiZipqFloat32( \ - GiFixLenType2GiFloat32Type(CONCAT(a, 0).value), \ - GiFixLenType2GiFloat32Type(CONCAT(a, 1).value)); \ - auto b1 = GiZipqFloat32( \ - GiFixLenType2GiFloat32Type(CONCAT(a, 2).value), \ - GiFixLenType2GiFloat32Type(CONCAT(a, 3).value)); \ - auto b2 = GiZipqFloat32( \ - GiFixLenType2GiFloat32Type(CONCAT(a, 4).value), \ - GiFixLenType2GiFloat32Type(CONCAT(a, 5).value)); \ - auto b3 = GiZipqFloat32( \ - GiFixLenType2GiFloat32Type(CONCAT(a, 6).value), \ - GiFixLenType2GiFloat32Type(CONCAT(a, 7).value)); \ - CONCAT(ret, 0).value.val[0] = GiFloat32Type2FixLenType(GiCombineFloat32( \ - GiGetLowFloat32(GiGetSubVectorFloat32V2(b0, 0)), \ - GiGetLowFloat32(GiGetSubVectorFloat32V2(b1, 0)))); \ - CONCAT(ret, 1).value.val[0] = GiFloat32Type2FixLenType(GiCombineFloat32( \ - GiGetHighFloat32(GiGetSubVectorFloat32V2(b0, 0)), \ - GiGetHighFloat32(GiGetSubVectorFloat32V2(b1, 0)))); \ - CONCAT(ret, 2).value.val[0] = GiFloat32Type2FixLenType(GiCombineFloat32( \ - GiGetLowFloat32(GiGetSubVectorFloat32V2(b0, 1)), \ - GiGetLowFloat32(GiGetSubVectorFloat32V2(b1, 1)))); \ - CONCAT(ret, 3).value.val[0] = GiFloat32Type2FixLenType(GiCombineFloat32( \ - GiGetHighFloat32(GiGetSubVectorFloat32V2(b0, 1)), \ - GiGetHighFloat32(GiGetSubVectorFloat32V2(b1, 1)))); \ - CONCAT(ret, 0).value.val[1] = GiFloat32Type2FixLenType(GiCombineFloat32( \ - GiGetLowFloat32(GiGetSubVectorFloat32V2(b2, 0)), \ - GiGetLowFloat32(GiGetSubVectorFloat32V2(b3, 0)))); \ - CONCAT(ret, 1).value.val[1] = GiFloat32Type2FixLenType(GiCombineFloat32( \ - GiGetHighFloat32(GiGetSubVectorFloat32V2(b2, 0)), \ - GiGetHighFloat32(GiGetSubVectorFloat32V2(b3, 0)))); \ - CONCAT(ret, 2).value.val[1] = GiFloat32Type2FixLenType(GiCombineFloat32( \ - GiGetLowFloat32(GiGetSubVectorFloat32V2(b2, 1)), \ - GiGetLowFloat32(GiGetSubVectorFloat32V2(b3, 1)))); \ - CONCAT(ret, 3).value.val[1] = GiFloat32Type2FixLenType(GiCombineFloat32( \ - GiGetHighFloat32(GiGetSubVectorFloat32V2(b2, 1)), \ - GiGetHighFloat32(GiGetSubVectorFloat32V2(b3, 1)))); +#define TRANSPOSE_8x4(a, ret) \ + auto b0 = GiZipqFloat32(CONCAT(a, 0), CONCAT(a, 1)); \ + auto b1 = GiZipqFloat32(CONCAT(a, 2), CONCAT(a, 3)); \ + auto b2 = GiZipqFloat32(CONCAT(a, 4), CONCAT(a, 5)); \ + auto b3 = GiZipqFloat32(CONCAT(a, 6), CONCAT(a, 7)); \ + GiSetSubVectorFloat32V2( \ + CONCAT(ret, 0), 0, \ + GiCombineFloat32( \ + GiGetLowFloat32(GiGetSubVectorFloat32V2(b0, 0)), \ + GiGetLowFloat32(GiGetSubVectorFloat32V2(b1, 0)))); \ + GiSetSubVectorFloat32V2( \ + CONCAT(ret, 1), 0, \ + GiCombineFloat32( \ + GiGetHighFloat32(GiGetSubVectorFloat32V2(b0, 0)), \ + GiGetHighFloat32(GiGetSubVectorFloat32V2(b1, 0)))); \ + GiSetSubVectorFloat32V2( \ + CONCAT(ret, 2), 0, \ + GiCombineFloat32( \ + GiGetLowFloat32(GiGetSubVectorFloat32V2(b0, 1)), \ + GiGetLowFloat32(GiGetSubVectorFloat32V2(b1, 1)))); \ + GiSetSubVectorFloat32V2( \ + CONCAT(ret, 3), 0, \ + GiCombineFloat32( \ + GiGetHighFloat32(GiGetSubVectorFloat32V2(b0, 1)), \ + GiGetHighFloat32(GiGetSubVectorFloat32V2(b1, 1)))); \ + GiSetSubVectorFloat32V2( \ + CONCAT(ret, 0), 1, \ + GiCombineFloat32( \ + GiGetLowFloat32(GiGetSubVectorFloat32V2(b2, 0)), \ + GiGetLowFloat32(GiGetSubVectorFloat32V2(b3, 0)))); \ + GiSetSubVectorFloat32V2( \ + CONCAT(ret, 1), 1, \ + GiCombineFloat32( \ + GiGetHighFloat32(GiGetSubVectorFloat32V2(b2, 0)), \ + GiGetHighFloat32(GiGetSubVectorFloat32V2(b3, 0)))); \ + GiSetSubVectorFloat32V2( \ + CONCAT(ret, 2), 1, \ + GiCombineFloat32( \ + GiGetLowFloat32(GiGetSubVectorFloat32V2(b2, 1)), \ + GiGetLowFloat32(GiGetSubVectorFloat32V2(b3, 1)))); \ + GiSetSubVectorFloat32V2( \ + CONCAT(ret, 3), 1, \ + GiCombineFloat32( \ + GiGetHighFloat32(GiGetSubVectorFloat32V2(b2, 1)), \ + GiGetHighFloat32(GiGetSubVectorFloat32V2(b3, 1)))); #endif // vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/conv_bias/gi/fp32/strategy_2x3_4x4.cpp b/dnn/src/fallback/conv_bias/gi/fp32/strategy_2x3_4x4.cpp index 5951f6f4..ee1be8ab 100644 --- a/dnn/src/fallback/conv_bias/gi/fp32/strategy_2x3_4x4.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/strategy_2x3_4x4.cpp @@ -1,7 +1,6 @@ #include "src/common/unroll_macro.h" #include "src/common/utils.h" #include "src/fallback/conv_bias/gi/fp32/strategy.h" -#include "src/fallback/conv_bias/gi/utils.h" #include "src/fallback/conv_bias/winograd/winograd.h" #include "src/fallback/conv_bias/gi/fp32/helper.h" @@ -70,8 +69,7 @@ struct InputTransform2X3 { size_t nr_units_in_tile, size_t ic, size_t IC) { constexpr size_t alpha = 2 + 3 - 1; // BT * d * B -#define cb(m, n) \ - Vector d##m##n = Vector::load(patchT + m * 4 * 4 + n * 4); +#define cb(m, n) GI_FLOAT32_t d##m##n = GiLoadFloat32(patchT + m * 4 * 4 + n * 4); UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); #undef cb @@ -80,20 +78,20 @@ struct InputTransform2X3 { //! 0 1 1 0 d10 d11 d12 d13 0 1 -1 -1 //! 0 -1 1 0 d20 d21 d22 d23 -1 1 1 0 //! 0 -1 0 1 d30 d31 d32 d33 0 0 0 1 -#define cb(m) \ - auto t0##m = d0##m - d2##m; \ - auto t1##m = d1##m + d2##m; \ - auto t2##m = d2##m - d1##m; \ - auto t3##m = d3##m - d1##m; +#define cb(m) \ + auto t0##m = SUBF(d0##m, d2##m); \ + auto t1##m = ADDF(d1##m, d2##m); \ + auto t2##m = SUBF(d2##m, d1##m); \ + auto t3##m = SUBF(d3##m, d1##m); UNROLL_CALL_NOWRAPPER(4, cb); #undef cb -#define cb(m) \ - d##m##0 = t##m##0 - t##m##2; \ - d##m##1 = t##m##1 + t##m##2; \ - d##m##2 = t##m##2 - t##m##1; \ - d##m##3 = t##m##3 - t##m##1; +#define cb(m) \ + d##m##0 = SUBF(t##m##0, t##m##2); \ + d##m##1 = ADDF(t##m##1, t##m##2); \ + d##m##2 = SUBF(t##m##2, t##m##1); \ + d##m##3 = SUBF(t##m##3, t##m##1); UNROLL_CALL_NOWRAPPER(4, cb); #undef cb @@ -101,9 +99,10 @@ struct InputTransform2X3 { size_t ICB = IC / 4; size_t icb = ic / 4; #define cb(m, n) \ - d##m##n.save( \ + GiStoreFloat32( \ input_transform_buf + (m * alpha + n) * ICB * nr_units_in_tile * 4 + \ - icb * nr_units_in_tile * 4 + unit_idx * 4); + icb * nr_units_in_tile * 4 + unit_idx * 4, \ + d##m##n); UNROLL_CALL_NOWRAPPER_D2(4, 4, cb) #undef cb } @@ -125,7 +124,7 @@ struct OutputTransform2X3 { size_t ocb = oc_index / 4; #define cb(m, n) \ - auto v##m##n = Vector::load( \ + auto v##m##n = GiLoadFloat32( \ output_transform_buf + (m * alpha + n) * OCB * nr_units_in_tile * 4 + \ ocb * nr_units_in_tile * 4 + unit_idx * 4); UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); @@ -134,37 +133,37 @@ struct OutputTransform2X3 { //! 0 1 -1 1 v10 v11 v12 v13 1 1 //! v20 v21 v22 v23 1 -1 //! v30 v31 v32 v33 0 1 -#define cb(m) \ - auto t0##m = v0##m + v1##m + v2##m; \ - auto t1##m = v1##m - v2##m + v3##m; +#define cb(m) \ + auto t0##m = ADDF(ADDF(v0##m, v1##m), v2##m); \ + auto t1##m = ADDF(SUBF(v1##m, v2##m), v3##m); UNROLL_CALL_NOWRAPPER(4, cb); #undef cb - v00 = t00 + t01 + t02; - v10 = t10 + t11 + t12; - v01 = t01 - t02 + t03; - v11 = t11 - t12 + t13; + v00 = ADDF(ADDF(t00, t01), t02); + v10 = ADDF(ADDF(t10, t11), t12); + v01 = ADDF(SUBF(t01, t02), t03); + v11 = ADDF(SUBF(t11, t12), t13); - Vector vbias; + GI_FLOAT32_t vbias; if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { - vbias = Vector::load(bias + oc); + vbias = GiLoadFloat32(bias + oc); - v00 += vbias; - v10 += vbias; - v01 += vbias; - v11 += vbias; + v00 = ADDF(v00, vbias); + v10 = ADDF(v10, vbias); + v01 = ADDF(v01, vbias); + v11 = ADDF(v11, vbias); } if (bmode != BiasMode::BIAS) { - v00 = op(GiFixLenType2GiFloat32Type(v00.value)); - v01 = op(GiFixLenType2GiFloat32Type(v01.value)); - v10 = op(GiFixLenType2GiFloat32Type(v10.value)); - v11 = op(GiFixLenType2GiFloat32Type(v11.value)); + v00 = op(v00); + v01 = op(v01); + v10 = op(v10); + v11 = op(v11); } - v00.save(transform_mid_buf + (0 * 2 + 0) * 4); - v10.save(transform_mid_buf + (1 * 2 + 0) * 4); - v01.save(transform_mid_buf + (0 * 2 + 1) * 4); - v11.save(transform_mid_buf + (1 * 2 + 1) * 4); + GiStoreFloat32(transform_mid_buf + (0 * 2 + 0) * 4, v00); + GiStoreFloat32(transform_mid_buf + (1 * 2 + 0) * 4, v10); + GiStoreFloat32(transform_mid_buf + (0 * 2 + 1) * 4, v01); + GiStoreFloat32(transform_mid_buf + (1 * 2 + 1) * 4, v11); for (size_t oco = 0; oco < 4 && oc + oco < oc_end; ++oco) { for (size_t oho = 0; oho < 2 && oh_start + oho < OH; ++oho) { diff --git a/dnn/src/fallback/conv_bias/gi/fp32/strategy_4x5.cpp b/dnn/src/fallback/conv_bias/gi/fp32/strategy_4x5.cpp index 4eabc4c4..a7af3421 100644 --- a/dnn/src/fallback/conv_bias/gi/fp32/strategy_4x5.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/strategy_4x5.cpp @@ -1,7 +1,6 @@ #include "src/common/unroll_macro.h" #include "src/common/utils.h" #include "src/fallback/conv_bias/gi/fp32/strategy.h" -#include "src/fallback/conv_bias/gi/utils.h" #include "src/fallback/conv_bias/winograd/winograd.h" #include "src/fallback/conv_bias/gi/fp32/helper.h" @@ -15,61 +14,124 @@ using namespace megdnn; using namespace fallback; namespace { +/* + * wd##0 = d##0; + * wd##r0 = d##r0; + * wd##1 = (d##0 + d##1 + d##2 + d##3 + d##4) * -0.2222222; + * wd##r1 = (d##r0 + d##r1 + d##r2 + d##r3 + d##r4) * -0.2222222; + * wd##2 = (d##0 - d##1 + d##2 - d##3 + d##4) * -0.2222222; + * wd##r2 = (d##r0 - d##r1 + d##r2 - d##r3 + d##r4) * -0.2222222; + * auto tmpd0 = d##0 * 0.7111111; + * auto tmpd1 = d##1 * 0.3555556; + * auto tmpd2 = d##2 * 0.1777778; + * auto tmpd3 = d##3 * 0.0888889; + * auto tmpd4 = d##4 * 0.0444444; + * auto tmpdr0 = d##r0 * 0.7111111; + * auto tmpdr1 = d##r1 * 0.3555556; + * auto tmpdr2 = d##r2 * 0.1777778; + * auto tmpdr3 = d##r3 * 0.0888889; + * auto tmpdr4 = d##r4 * 0.0444444; + * wd##3 = tmpd0 + tmpd1 + tmpd2 + tmpd3 + tmpd4; + * wd##r3 = tmpdr0 + tmpdr1 + tmpdr2 + tmpdr3 + tmpdr4; + * wd##4 = tmpd0 - tmpd1 + tmpd2 - tmpd3 + tmpd4; + * wd##r4 = tmpdr0 - tmpdr1 + tmpdr2 - tmpdr3 + tmpdr4; + * tmpd0 = d##0 * 0.0111111; + * tmpd1 = d##1 * 0.0222222; + * tmpd2 = d##2 * 0.0444444; + * tmpd3 = d##3 * 0.0888889; + * tmpd4 = d##4 * 0.1777778; + * tmpdr0 = d##r0 * 0.0111111; + * tmpdr1 = d##r1 * 0.0222222; + * tmpdr2 = d##r2 * 0.0444444; + * tmpdr3 = d##r3 * 0.0888889; + * tmpdr4 = d##r4 * 0.1777778; + * wd##5 = tmpd0 + tmpd1 + tmpd2 + tmpd3 + tmpd4; + * wd##r5 = tmpdr0 + tmpdr1 + tmpdr2 + tmpdr3 + tmpdr4; + * wd##6 = tmpd0 - tmpd1 + tmpd2 - tmpd3 + tmpd4; + * wd##r6 = tmpdr0 - tmpdr1 + tmpdr2 - tmpdr3 + tmpdr4; + * wd##7 = d##4; + * wd##r7 = d##r4; + * + * */ struct FilterTransform4X5 { -#define FILTER_TRANSFORM(d, wd) \ - do { \ - wd##0 = d##0; \ - wd##r0 = d##r0; \ - wd##1 = (d##0 + d##1 + d##2 + d##3 + d##4) * -0.2222222; \ - wd##r1 = (d##r0 + d##r1 + d##r2 + d##r3 + d##r4) * -0.2222222; \ - wd##2 = (d##0 - d##1 + d##2 - d##3 + d##4) * -0.2222222; \ - wd##r2 = (d##r0 - d##r1 + d##r2 - d##r3 + d##r4) * -0.2222222; \ - auto tmpd0 = d##0 * 0.7111111; \ - auto tmpd1 = d##1 * 0.3555556; \ - auto tmpd2 = d##2 * 0.1777778; \ - auto tmpd3 = d##3 * 0.0888889; \ - auto tmpd4 = d##4 * 0.0444444; \ - auto tmpdr0 = d##r0 * 0.7111111; \ - auto tmpdr1 = d##r1 * 0.3555556; \ - auto tmpdr2 = d##r2 * 0.1777778; \ - auto tmpdr3 = d##r3 * 0.0888889; \ - auto tmpdr4 = d##r4 * 0.0444444; \ - wd##3 = tmpd0 + tmpd1 + tmpd2 + tmpd3 + tmpd4; \ - wd##r3 = tmpdr0 + tmpdr1 + tmpdr2 + tmpdr3 + tmpdr4; \ - wd##4 = tmpd0 - tmpd1 + tmpd2 - tmpd3 + tmpd4; \ - wd##r4 = tmpdr0 - tmpdr1 + tmpdr2 - tmpdr3 + tmpdr4; \ - tmpd0 = d##0 * 0.0111111; \ - tmpd1 = d##1 * 0.0222222; \ - tmpd2 = d##2 * 0.0444444; \ - tmpd3 = d##3 * 0.0888889; \ - tmpd4 = d##4 * 0.1777778; \ - tmpdr0 = d##r0 * 0.0111111; \ - tmpdr1 = d##r1 * 0.0222222; \ - tmpdr2 = d##r2 * 0.0444444; \ - tmpdr3 = d##r3 * 0.0888889; \ - tmpdr4 = d##r4 * 0.1777778; \ - wd##5 = tmpd0 + tmpd1 + tmpd2 + tmpd3 + tmpd4; \ - wd##r5 = tmpdr0 + tmpdr1 + tmpdr2 + tmpdr3 + tmpdr4; \ - wd##6 = tmpd0 - tmpd1 + tmpd2 - tmpd3 + tmpd4; \ - wd##r6 = tmpdr0 - tmpdr1 + tmpdr2 - tmpdr3 + tmpdr4; \ - wd##7 = d##4; \ - wd##r7 = d##r4; \ +#define FILTER_TRANSFORM(d, wd) \ + do { \ + wd##0 = d##0; \ + wd##r0 = d##r0; \ + wd##1 = MULSF( \ + ADDF(ADDF(ADDF(ADDF(d##0, d##1), d##2), d##3), d##4), -0.2222222); \ + wd##r1 = (d##r0 + d##r1 + d##r2 + d##r3 + d##r4) * -0.2222222; \ + wd##2 = MULSF( \ + ADDF(SUBF(ADDF(SUBF(d##0, d##1), d##2), d##3), d##4), -0.2222222); \ + wd##r2 = (d##r0 - d##r1 + d##r2 - d##r3 + d##r4) * -0.2222222; \ + auto tmpd0 = MULSF(d##0, 0.7111111); \ + auto tmpd1 = MULSF(d##1, 0.3555556); \ + auto tmpd2 = MULSF(d##2, 0.1777778); \ + auto tmpd3 = MULSF(d##3, 0.0888889); \ + auto tmpd4 = MULSF(d##4, 0.0444444); \ + auto tmpdr0 = d##r0 * 0.7111111; \ + auto tmpdr1 = d##r1 * 0.3555556; \ + auto tmpdr2 = d##r2 * 0.1777778; \ + auto tmpdr3 = d##r3 * 0.0888889; \ + auto tmpdr4 = d##r4 * 0.0444444; \ + wd##3 = ADDF(ADDF(ADDF(ADDF(tmpd0, tmpd1), tmpd2), tmpd3), tmpd4); \ + wd##r3 = tmpdr0 + tmpdr1 + tmpdr2 + tmpdr3 + tmpdr4; \ + wd##4 = ADDF(SUBF(ADDF(SUBF(tmpd0, tmpd1), tmpd2), tmpd3), tmpd4); \ + wd##r4 = tmpdr0 - tmpdr1 + tmpdr2 - tmpdr3 + tmpdr4; \ + tmpd0 = MULSF(d##0, 0.0111111); \ + tmpd1 = MULSF(d##1, 0.0222222); \ + tmpd2 = MULSF(d##2, 0.0444444); \ + tmpd3 = MULSF(d##3, 0.0888889); \ + tmpd4 = MULSF(d##4, 0.1777778); \ + tmpdr0 = d##r0 * 0.0111111; \ + tmpdr1 = d##r1 * 0.0222222; \ + tmpdr2 = d##r2 * 0.0444444; \ + tmpdr3 = d##r3 * 0.0888889; \ + tmpdr4 = d##r4 * 0.1777778; \ + wd##5 = ADDF(ADDF(ADDF(ADDF(tmpd0, tmpd1), tmpd2), tmpd3), tmpd4); \ + wd##r5 = tmpdr0 + tmpdr1 + tmpdr2 + tmpdr3 + tmpdr4; \ + wd##6 = ADDF(SUBF(ADDF(SUBF(tmpd0, tmpd1), tmpd2), tmpd3), tmpd4); \ + wd##r6 = tmpdr0 - tmpdr1 + tmpdr2 - tmpdr3 + tmpdr4; \ + wd##7 = d##4; \ + wd##r7 = d##r4; \ } while (0); -#define FILTER_TRANSFORM_FINAL(d, wd) \ - do { \ - wd##0 = d##0; \ - wd##1 = (d##0 + d##1 + d##2 + d##3 + d##4) * -0.2222222; \ - wd##2 = (d##0 - d##1 + d##2 - d##3 + d##4) * -0.2222222; \ - auto tmp0 = d##0 * 0.7111111 + d##2 * 0.1777778 + d##4 * 0.0444444; \ - auto tmp1 = d##1 * 0.3555556 + d##3 * 0.0888889; \ - wd##3 = tmp0 + tmp1; \ - wd##4 = tmp0 - tmp1; \ - tmp0 = d##0 * 0.0111111 + d##2 * 0.0444444 + d##4 * 0.1777778; \ - tmp1 = d##1 * 0.0222222 + d##3 * 0.0888889; \ - wd##5 = tmp0 + tmp1; \ - wd##6 = tmp0 - tmp1; \ - wd##7 = d##4; \ + /* + *wd##0 = d##0; + *wd##1 = (d##0 + d##1 + d##2 + d##3 + d##4) * -0.2222222; + *wd##2 = (d##0 - d##1 + d##2 - d##3 + d##4) * -0.2222222; + *auto tmp0 = d##0 * 0.7111111 + d##2 * 0.1777778 + d##4 * 0.0444444; + *auto tmp1 = d##1 * 0.3555556 + d##3 * 0.0888889; + *wd##3 = tmp0 + tmp1; + *wd##4 = tmp0 - tmp1; + *tmp0 = d##0 * 0.0111111 + d##2 * 0.0444444 + d##4 * 0.1777778; + *tmp1 = d##1 * 0.0222222 + d##3 * 0.0888889; + *wd##5 = tmp0 + tmp1; + *wd##6 = tmp0 - tmp1; + *wd##7 = d##4; + */ +#define FILTER_TRANSFORM_FINAL(d, wd) \ + do { \ + wd##0 = d##0; \ + wd##1 = \ + MULSFV2(ADDFV2(ADDFV2(ADDFV2(ADDFV2(d##0, d##1), d##2), d##3), d##4), \ + -0.2222222); \ + wd##2 = \ + MULSFV2(ADDFV2(SUBFV2(ADDFV2(SUBFV2(d##0, d##1), d##2), d##3), d##4), \ + -0.2222222); \ + auto tmp0 = \ + ADDFV2(ADDFV2(MULSFV2(d##0, 0.7111111), MULSFV2(d##2, 0.1777778)), \ + MULSFV2(d##4, 0.0444444)); \ + auto tmp1 = ADDFV2(MULSFV2(d##1, 0.3555556), MULSFV2(d##3, 0.0888889)); \ + wd##3 = ADDFV2(tmp0, tmp1); \ + wd##4 = SUBFV2(tmp0, tmp1); \ + tmp0 = \ + ADDFV2(ADDFV2(MULSFV2(d##0, 0.0111111), MULSFV2(d##2, 0.0444444)), \ + MULSFV2(d##4, 0.1777778)); \ + tmp1 = ADDFV2(MULSFV2(d##1, 0.0222222), MULSFV2(d##3, 0.0888889)); \ + wd##5 = ADDFV2(tmp0, tmp1); \ + wd##6 = SUBFV2(tmp0, tmp1); \ + wd##7 = d##4; \ } while (0); static void transform( const float* filter, float* filter_transform_buf, float* transform_mid_buf, @@ -89,7 +151,7 @@ struct FilterTransform4X5 { rep(ic, IC) { const float* fptr = filter + (oc * IC + ic) * 5 * 5; -#define cb(i) Vector g##i = Vector::load(fptr + 5 * i); +#define cb(i) GI_FLOAT32_t g##i = GiLoadFloat32(fptr + 5 * i); UNROLL_CALL_NOWRAPPER(5, cb); #undef cb @@ -97,7 +159,7 @@ struct FilterTransform4X5 { UNROLL_CALL_NOWRAPPER(5, cb); #undef cb -#define cb(i) Vector Gg##i; +#define cb(i) GI_FLOAT32_t Gg##i; UNROLL_CALL_NOWRAPPER(8, cb); #undef cb @@ -105,11 +167,11 @@ struct FilterTransform4X5 { UNROLL_CALL_NOWRAPPER(8, cb); #undef cb -#define cb(i) Vector Ggt##i; +#define cb(i) GI_FLOAT32_V2_t Ggt##i; UNROLL_CALL_NOWRAPPER(4, cb); #undef cb -#define cb(i) Vector result##i; +#define cb(i) GI_FLOAT32_V2_t result##i; UNROLL_CALL_NOWRAPPER(8, cb); #undef cb @@ -128,11 +190,11 @@ struct FilterTransform4X5 { GI_FLOAT32_t vgr1 = GiLoadFloat32(tmp); GiSetSubVectorFloat32V2(vgr, 0, vgr0); //{Ggr0, Ggr1, Ggr2, Ggr3}; GiSetSubVectorFloat32V2(vgr, 1, vgr1); //{Ggr4, Ggr5, Ggr6, Ggr7}; - Vector Ggt4(vgr); + GI_FLOAT32_V2_t Ggt4(vgr); TRANSPOSE_8x4(Gg, Ggt); FILTER_TRANSFORM_FINAL(Ggt, result); -#define cb(i) result##i.save(transform_mid_buf + i * alpha); +#define cb(i) GiStoreFloat32V2(transform_mid_buf + i * alpha, result##i); UNROLL_CALL_NOWRAPPER(8, cb); #undef cb rep(i, alpha) rep(j, alpha) { @@ -145,31 +207,49 @@ struct FilterTransform4X5 { #undef FILTER_TRANSFORM #undef FILTER_TRANSFORM_FINAL +/* + * wd##0 = (d##0 - d##6) + (d##4 - d##2) * 5.25f; + * auto tmp0 = d##2 - d##4 * 4.25f + d##6; + * auto tmp1 = d##1 - d##3 * 4.25f + d##5; + * wd##1 = tmp0 + tmp1; + * wd##2 = tmp0 - tmp1; + * tmp0 = d##2 * 4.0f - d##4 * 5.0f + d##6; + * tmp1 = d##1 * 2.0f - d##3 * 2.5f + d##5 * 0.5f; + * wd##3 = tmp0 + tmp1; + * wd##4 = tmp0 - tmp1; + * tmp0 = d##2 * 0.25f - d##4 * 1.25f + d##6; + * tmp1 = d##1 * 0.5f - d##3 * 2.5f + d##5 * 2.0f; + * wd##5 = tmp0 + tmp1; + * wd##6 = tmp0 - tmp1; + * wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f; + */ struct InputTransform4X5 { -#define INPUT_TRANSFORM(d, wd) \ - do { \ - wd##0 = (d##0 - d##6) + (d##4 - d##2) * 5.25f; \ - auto tmp0 = d##2 - d##4 * 4.25f + d##6; \ - auto tmp1 = d##1 - d##3 * 4.25f + d##5; \ - wd##1 = tmp0 + tmp1; \ - wd##2 = tmp0 - tmp1; \ - tmp0 = d##2 * 4.0f - d##4 * 5.0f + d##6; \ - tmp1 = d##1 * 2.0f - d##3 * 2.5f + d##5 * 0.5f; \ - wd##3 = tmp0 + tmp1; \ - wd##4 = tmp0 - tmp1; \ - tmp0 = d##2 * 0.25f - d##4 * 1.25f + d##6; \ - tmp1 = d##1 * 0.5f - d##3 * 2.5f + d##5 * 2.0f; \ - wd##5 = tmp0 + tmp1; \ - wd##6 = tmp0 - tmp1; \ - wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f; \ +#define INPUT_TRANSFORM(d, wd) \ + do { \ + wd##0 = ADDFV2(SUBFV2(d##0, d##6), MULSFV2(SUBFV2(d##4, d##2), 5.25f)); \ + auto tmp0 = ADDFV2(SUBFV2(d##2, MULSFV2(d##4, 4.25f)), d##6); \ + auto tmp1 = ADDFV2(SUBFV2(d##1, MULSFV2(d##3, 4.25f)), d##5); \ + wd##1 = ADDFV2(tmp0, tmp1); \ + wd##2 = SUBFV2(tmp0, tmp1); \ + tmp0 = ADDFV2(SUBFV2(MULSFV2(d##2, 4.0f), MULSFV2(d##4, 5.0f)), d##6); \ + tmp1 = \ + ADDFV2(SUBFV2(MULSFV2(d##1, 2.0f), MULSFV2(d##3, 2.5f)), \ + MULSFV2(d##5, 0.5f)); \ + wd##3 = ADDFV2(tmp0, tmp1); \ + wd##4 = SUBFV2(tmp0, tmp1); \ + tmp0 = ADDFV2(SUBFV2(MULSFV2(d##2, 0.25f), MULSFV2(d##4, 1.25f)), d##6); \ + tmp1 = \ + ADDFV2(SUBFV2(MULSFV2(d##1, 0.5f), MULSFV2(d##3, 2.5f)), \ + MULSFV2(d##5, 2.0f)); \ + wd##5 = ADDFV2(tmp0, tmp1); \ + wd##6 = SUBFV2(tmp0, tmp1); \ + wd##7 = ADDFV2(SUBFV2(d##7, d##1), MULSFV2(SUBFV2(d##3, d##5), 5.25f)); \ } while (0) -#define GET_VECTOR_HIGH_ELEM(s, i, idx) \ - GiExtractLane##idx##Float32(GiGetSubVectorFloat32V2( \ - GiFixLenType2GiFloat32V2Type(CONCAT(s, i).value), 1)) -#define GET_VECTOR_LOW_ELEM(s, i, idx) \ - GiExtractLane##idx##Float32(GiGetSubVectorFloat32V2( \ - GiFixLenType2GiFloat32V2Type(CONCAT(s, i).value), 0)) +#define GET_VECTOR_HIGH_ELEM(s, i, idx) \ + GiExtractLane##idx##Float32(GiGetSubVectorFloat32V2(CONCAT(s, i), 1)) +#define GET_VECTOR_LOW_ELEM(s, i, idx) \ + GiExtractLane##idx##Float32(GiGetSubVectorFloat32V2(CONCAT(s, i), 0)) template static void transform( @@ -191,13 +271,13 @@ struct InputTransform4X5 { memset(transform_mid_buf, 0, sizeof(float) * alpha * alpha); } -#define cb(i) Vector d##i; +#define cb(i) GI_FLOAT32_V2_t d##i; UNROLL_CALL_NOWRAPPER(8, cb); #undef cb if (inner) { const float* input_ptr = input + ic * IH * IW + ih_start * IW + iw_start; -#define cb(i) d##i = Vector::load(input_ptr + IW * i); +#define cb(i) d##i = GiLoadFloat32V2(input_ptr + IW * i); UNROLL_CALL_NOWRAPPER(8, cb); #undef cb } else { @@ -212,21 +292,25 @@ struct InputTransform4X5 { input[ic * IH * IW + ih * IW + iw]; } } -#define cb(i) d##i = Vector::load(transform_mid_buf + alpha * i); +#define cb(i) d##i = GiLoadFloat32V2(transform_mid_buf + alpha * i); UNROLL_CALL_NOWRAPPER(8, cb); #undef cb } -#define cb(i) Vector wd##i, ret##i; +#define cb(i) GI_FLOAT32_V2_t wd##i; UNROLL_CALL_NOWRAPPER(8, cb); #undef cb INPUT_TRANSFORM(d, wd); #if MEGDNN_AARCH64 +#define cb(i) GI_FLOAT32_V2_t ret##i; + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + TRANSPOSE_8x8(wd, d); INPUT_TRANSFORM(d, ret); -#define cb(i) ret##i.save(transform_mid_buf + i * alpha); +#define cb(i) GiStoreFloat32V2(transform_mid_buf + i * alpha, ret##i); UNROLL_CALL_NOWRAPPER(8, cb); #undef cb rep(i, alpha) rep(j, alpha) { @@ -283,12 +367,32 @@ struct InputTransform4X5 { }; #undef INPUT_TRANSFORM -#define OUTPUT_TRANSFORM(m, s) \ - do { \ - s0 = m0 + m1 + m2 + m3 + m4 + m5 + m6; \ - s1 = m1 - m2 + m3 * 0.5 - m4 * 0.5 + m5 * 2.0 - m6 * 2.0; \ - s2 = m1 + m2 + m3 * 0.25 + m4 * 0.25 + m5 * 4.0 + m6 * 4.0; \ - s3 = m1 - m2 + m3 * 0.125 - m4 * 0.125 + m5 * 8.0 - m6 * 8.0 + m7; \ +/* + * s0 = m0 + m1 + m2 + m3 + m4 + m5 + m6; + * s1 = m1 - m2 + m3 * 0.5 - m4 * 0.5 + m5 * 2.0 - m6 * 2.0; + * s2 = m1 + m2 + m3 * 0.25 + m4 * 0.25 + m5 * 4.0 + m6 * 4.0; + * s3 = m1 - m2 + m3 * 0.125 - m4 * 0.125 + m5 * 8.0 - m6 * 8.0 + m7; + */ +#define OUTPUT_TRANSFORM(m, s) \ + do { \ + s0 = ADDFV2( \ + ADDFV2(ADDFV2(ADDFV2(ADDFV2(ADDFV2(m0, m1), m2), m3), m4), m5), m6); \ + s1 = \ + SUBFV2(ADDFV2(SUBFV2(ADDFV2(SUBFV2(m1, m2), MULSFV2(m3, 0.5)), \ + MULSFV2(m4, 0.5)), \ + MULSFV2(m5, 2.0)), \ + MULSFV2(m6, 2.0)); \ + s2 = \ + ADDFV2(ADDFV2(ADDFV2(ADDFV2(ADDFV2(m1, m2), MULSFV2(m3, 0.25)), \ + MULSFV2(m4, 0.25)), \ + MULSFV2(m5, 4.0)), \ + MULSFV2(m6, 4.0)); \ + s3 = ADDFV2( \ + SUBFV2(ADDFV2(SUBFV2(ADDFV2(SUBFV2(m1, m2), MULSFV2(m3, 0.125)), \ + MULSFV2(m4, 0.125)), \ + MULSFV2(m5, 8.0)), \ + MULSFV2(m6, 8.0)), \ + m7); \ } while (0) template struct OutputTransform4X5 { @@ -316,14 +420,15 @@ struct OutputTransform4X5 { UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); #undef cb -#define cb(i) auto m##i = Vector::load(transform_mid_buf + alpha * i); +#define cb(i) auto m##i = GiLoadFloat32V2(transform_mid_buf + alpha * i); UNROLL_CALL_NOWRAPPER(8, cb); #undef cb -#define cb(i) Vector s##i; +#define cb(i) GI_FLOAT32_V2_t s##i; UNROLL_CALL_NOWRAPPER(4, cb); #undef cb OUTPUT_TRANSFORM(m, s); + #define cb(i) \ do { \ auto add12 = GET_VECTOR_LOW_ELEM(s, i, 1) + GET_VECTOR_LOW_ELEM(s, i, 2); \ @@ -355,10 +460,10 @@ struct OutputTransform4X5 { GI_FLOAT32_t item0 = GiLoadFloat32(mid_buf1); if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { - item0 = GiAddFloat32(item0, bias0); + item0 = ADDF(item0, bias0); } else if (bmode == BiasMode::BIAS) { bias0 = GiLoadFloat32(bias + oc * OH * OW + oh * OW + ow_start); - item0 = GiAddFloat32(item0, bias0); + item0 = ADDF(item0, bias0); } item0 = op(item0); GiStoreFloat32(output + oc * OH * OW + oh * OW + ow_start, item0); diff --git a/dnn/src/fallback/conv_bias/gi/fp32/strategy_5x4.cpp b/dnn/src/fallback/conv_bias/gi/fp32/strategy_5x4.cpp index 1b3611fa..5c3aef6e 100644 --- a/dnn/src/fallback/conv_bias/gi/fp32/strategy_5x4.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/strategy_5x4.cpp @@ -1,7 +1,6 @@ #include "src/common/unroll_macro.h" #include "src/common/utils.h" #include "src/fallback/conv_bias/gi/fp32/strategy.h" -#include "src/fallback/conv_bias/gi/utils.h" #include "src/fallback/conv_bias/winograd/winograd.h" #include "src/fallback/conv_bias/gi/fp32/helper.h" @@ -15,23 +14,39 @@ using namespace megdnn; using namespace fallback; namespace { +/* + * wd##0 = d##0; + * auto tmp0 = d##0 * 0.7111111f + d##2 * 0.1777778f; + * auto tmp1 = d##1 * 0.3555556f + d##3 * 0.0888889f; + * wd##1 = tmp0 + tmp1; + * wd##2 = tmp0 - tmp1; + * tmp0 = (d##0 + d##2) * -0.2222222f; + * tmp1 = (d##1 + d##3) * -0.2222222f; + * wd##3 = tmp0 + tmp1; + * wd##4 = tmp0 - tmp1; + * tmp0 = d##0 * 0.0111111f + d##2 * 0.0444444f; + * tmp1 = d##1 * 0.0222222f + d##3 * 0.0888889f; + * wd##5 = tmp0 + tmp1; + * wd##6 = tmp0 - tmp1; + * wd##7 = d##3; + */ struct FilterTransform5X4 { -#define FILTER_TRANSFORM(d, wd) \ - do { \ - wd##0 = d##0; \ - auto tmp0 = d##0 * 0.7111111f + d##2 * 0.1777778f; \ - auto tmp1 = d##1 * 0.3555556f + d##3 * 0.0888889f; \ - wd##1 = tmp0 + tmp1; \ - wd##2 = tmp0 - tmp1; \ - tmp0 = (d##0 + d##2) * -0.2222222f; \ - tmp1 = (d##1 + d##3) * -0.2222222f; \ - wd##3 = tmp0 + tmp1; \ - wd##4 = tmp0 - tmp1; \ - tmp0 = d##0 * 0.0111111f + d##2 * 0.0444444f; \ - tmp1 = d##1 * 0.0222222f + d##3 * 0.0888889f; \ - wd##5 = tmp0 + tmp1; \ - wd##6 = tmp0 - tmp1; \ - wd##7 = d##3; \ +#define FILTER_TRANSFORM(d, wd, ADDC, SUBC, MULC) \ + do { \ + wd##0 = d##0; \ + auto tmp0 = ADDC(MULC(d##0, 0.7111111f), MULC(d##2, 0.1777778f)); \ + auto tmp1 = ADDC(MULC(d##1, 0.3555556f), MULC(d##3, 0.0888889f)); \ + wd##1 = ADDC(tmp0, tmp1); \ + wd##2 = SUBC(tmp0, tmp1); \ + tmp0 = MULC(ADDC(d##0, d##2), -0.2222222f); \ + tmp1 = MULC(ADDC(d##1, d##3), -0.2222222f); \ + wd##3 = ADDC(tmp0, tmp1); \ + wd##4 = SUBC(tmp0, tmp1); \ + tmp0 = ADDC(MULC(d##0, 0.0111111f), MULC(d##2, 0.0444444f)); \ + tmp1 = ADDC(MULC(d##1, 0.0222222f), MULC(d##3, 0.0888889f)); \ + wd##5 = ADDC(tmp0, tmp1); \ + wd##6 = SUBC(tmp0, tmp1); \ + wd##7 = d##3; \ } while (0) static void transform( @@ -53,28 +68,27 @@ struct FilterTransform5X4 { rep(ic, IC) { const float* fptr = filter + (oc * IC + ic) * 4 * 4; -#define cb(i) Vector g##i = Vector::load(fptr + 4 * i); +#define cb(i) GI_FLOAT32_t g##i = GiLoadFloat32(fptr + 4 * i); UNROLL_CALL_NOWRAPPER(4, cb); #undef cb -#define cb(i) Vector wd##i; +#define cb(i) GI_FLOAT32_t wd##i; UNROLL_CALL_NOWRAPPER(8, cb); #undef cb -#define cb(i) Vector wdt##i; + FILTER_TRANSFORM(g, wd, ADDF, SUBF, MULSF); +#if MEGDNN_AARCH64 +#define cb(i) GI_FLOAT32_V2_t wdt##i; UNROLL_CALL_NOWRAPPER(4, cb); #undef cb -#define cb(i) Vector ret##i; +#define cb(i) GI_FLOAT32_V2_t ret##i; UNROLL_CALL_NOWRAPPER(8, cb); #undef cb - - FILTER_TRANSFORM(g, wd); -#if MEGDNN_AARCH64 TRANSPOSE_8x4(wd, wdt); - FILTER_TRANSFORM(wdt, ret); + FILTER_TRANSFORM(wdt, ret, ADDFV2, SUBFV2, MULSFV2); -#define cb(i) ret##i.save(transform_mid_buf + i * alpha); +#define cb(i) GiStoreFloat32V2(transform_mid_buf + i * alpha, ret##i); UNROLL_CALL_NOWRAPPER(8, cb); #undef cb rep(i, alpha) rep(j, alpha) { @@ -104,8 +118,7 @@ struct FilterTransform5X4 { mid_buf1[7] = GET_VECTOR_ELEM(wd, i, 3); \ mid_buf1 += 8; \ } while (0); -#define GET_VECTOR_ELEM(s, i, idx) \ - GiExtractLane##idx##Float32(GiFixLenType2GiFloat32Type(CONCAT(s, i).value)) +#define GET_VECTOR_ELEM(s, i, idx) GiExtractLane##idx##Float32(CONCAT(s, i)) float* mid_buf1 = transform_mid_buf; UNROLL_CALL_NOWRAPPER(8, cb); @@ -123,29 +136,49 @@ struct FilterTransform5X4 { #undef FILTER_TRANSFORM #undef GET_VECTOR_ELEM +/* + * wd##0 = (d##0 - d##6) + (d##4 - d##2) * 5.25f; + * auto tmp0 = d##2 * 4.0f - d##4 * 5.0f + d##6; + * auto tmp1 = d##1 * 2.0f - d##3 * 2.5f + d##5 * 0.5f; + * wd##1 = tmp0 + tmp1; + * wd##2 = tmp0 - tmp1; + * tmp0 = d##2 - d##4 * 4.25f + d##6; + * tmp1 = d##1 - d##3 * 4.25f + d##5; + * wd##3 = tmp0 + tmp1; + * wd##4 = tmp0 - tmp1; + * tmp0 = d##2 * 0.25f - d##4 * 1.25f + d##6; + * tmp1 = d##1 * 0.5f - d##3 * 2.5f + d##5 * 2.0f; + * wd##5 = tmp0 + tmp1; + * wd##6 = tmp0 - tmp1; + * wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f; + */ struct InputTransform5X4 { -#define INPUT_TRANSFORM(d, wd) \ - do { \ - wd##0 = (d##0 - d##6) + (d##4 - d##2) * 5.25f; \ - auto tmp0 = d##2 * 4.0f - d##4 * 5.0f + d##6; \ - auto tmp1 = d##1 * 2.0f - d##3 * 2.5f + d##5 * 0.5f; \ - wd##1 = tmp0 + tmp1; \ - wd##2 = tmp0 - tmp1; \ - tmp0 = d##2 - d##4 * 4.25f + d##6; \ - tmp1 = d##1 - d##3 * 4.25f + d##5; \ - wd##3 = tmp0 + tmp1; \ - wd##4 = tmp0 - tmp1; \ - tmp0 = d##2 * 0.25f - d##4 * 1.25f + d##6; \ - tmp1 = d##1 * 0.5f - d##3 * 2.5f + d##5 * 2.0f; \ - wd##5 = tmp0 + tmp1; \ - wd##6 = tmp0 - tmp1; \ - wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f; \ +#define INPUT_TRANSFORM(d, wd) \ + do { \ + wd##0 = ADDFV2(SUBFV2(d##0, d##6), MULSFV2(SUBFV2(d##4, d##2), 5.25f)); \ + auto tmp0 = ADDFV2(SUBFV2(MULSFV2(d##2, 4.0f), MULSFV2(d##4, 5.0f)), d##6); \ + auto tmp1 = \ + ADDFV2(SUBFV2(MULSFV2(d##1, 2.0f), MULSFV2(d##3, 2.5f)), \ + MULSFV2(d##5, 0.5f)); \ + wd##1 = ADDFV2(tmp0, tmp1); \ + wd##2 = SUBFV2(tmp0, tmp1); \ + tmp0 = ADDFV2(SUBFV2(d##2, MULSFV2(d##4, 4.25f)), d##6); \ + tmp1 = ADDFV2(SUBFV2(d##1, MULSFV2(d##3, 4.25f)), d##5); \ + wd##3 = ADDFV2(tmp0, tmp1); \ + wd##4 = SUBFV2(tmp0, tmp1); \ + tmp0 = ADDFV2(SUBFV2(MULSFV2(d##2, 0.25f), MULSFV2(d##4, 1.25f)), d##6); \ + tmp1 = \ + ADDFV2(SUBFV2(MULSFV2(d##1, 0.5f), MULSFV2(d##3, 2.5f)), \ + MULSFV2(d##5, 2.0f)); \ + wd##5 = ADDFV2(tmp0, tmp1); \ + wd##6 = SUBFV2(tmp0, tmp1); \ + wd##7 = ADDFV2(SUBFV2(d##7, d##1), MULSFV2(SUBFV2(d##3, d##5), 5.25f)); \ } while (0) #define GET_VECTOR_HIGH_ELEM(s, i, idx) \ - GiExtractLane##idx##Float32(GiFixLenType2GiFloat32Type(CONCAT(s, i).value.val[1])) + GiExtractLane##idx##Float32(GiGetSubVectorFloat32V2(CONCAT(s, i), 1)) #define GET_VECTOR_LOW_ELEM(s, i, idx) \ - GiExtractLane##idx##Float32(GiFixLenType2GiFloat32Type(CONCAT(s, i).value.val[0])) + GiExtractLane##idx##Float32(GiGetSubVectorFloat32V2(CONCAT(s, i), 0)) template static void transform( @@ -168,13 +201,13 @@ struct InputTransform5X4 { memset(transform_mid_buf, 0, sizeof(float) * alpha * alpha); } -#define cb(i) Vector d##i; +#define cb(i) GI_FLOAT32_V2_t d##i; UNROLL_CALL_NOWRAPPER(8, cb); #undef cb if (inner) { const float* input_ptr = input + ic * IH * IW + ih_start * IW + iw_start; -#define cb(i) d##i = Vector::load(input_ptr + IW * i); +#define cb(i) d##i = GiLoadFloat32V2(input_ptr + IW * i); UNROLL_CALL_NOWRAPPER(8, cb); #undef cb } else { @@ -189,21 +222,25 @@ struct InputTransform5X4 { input[ic * IH * IW + ih * IW + iw]; } } -#define cb(i) d##i = Vector::load(transform_mid_buf + alpha * i); +#define cb(i) d##i = GiLoadFloat32V2(transform_mid_buf + alpha * i); UNROLL_CALL_NOWRAPPER(8, cb); #undef cb } -#define cb(i) Vector wd##i, ret##i; +#define cb(i) GI_FLOAT32_V2_t wd##i; UNROLL_CALL_NOWRAPPER(8, cb); #undef cb INPUT_TRANSFORM(d, wd); #if MEGDNN_AARCH64 +#define cb(i) GI_FLOAT32_V2_t ret##i; + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + TRANSPOSE_8x8(wd, d); INPUT_TRANSFORM(d, ret); -#define cb(i) ret##i.save(transform_mid_buf + i * alpha); +#define cb(i) GiStoreFloat32V2(transform_mid_buf + i * alpha, ret##i); UNROLL_CALL_NOWRAPPER(8, cb); #undef cb rep(i, alpha) rep(j, alpha) { @@ -260,24 +297,48 @@ struct InputTransform5X4 { }; #undef INPUT_TRANSFORM -#define OUTPUT_TRANSFORM(m, s) \ - do { \ - auto m1addm2 = m##1 + m##2; \ - auto m1subm2 = m##1 - m##2; \ - auto m3addm4 = m##3 + m##4; \ - auto m3subm4 = m##3 - m##4; \ - auto m5addm6 = (m##5 + m##6); \ - auto m5subm6 = (m##5 - m##6); \ - s##0 = m##0; \ - CONCAT(s, 0).add(m1addm2).add(m3addm4).add(m5addm6); \ - CONCAT(s, 1) = m3subm4; \ - CONCAT(s, 1).mla(m1subm2, 0.5f).mla(m5subm6, 2.0f); \ - CONCAT(s, 2) = m3addm4; \ - CONCAT(s, 2).mla(m1addm2, 0.25f).mla(m5addm6, 4.0f); \ - CONCAT(s, 3) = m3subm4; \ - CONCAT(s, 3).mla(m1subm2, 0.125f).mla(m5subm6, 8.0f); \ - CONCAT(s, 4) = m##7; \ - CONCAT(s, 4).mla(m1addm2, 0.0625f).add(m3addm4).mla(m5addm6, 16.0f); \ +/* + * auto m1addm2 = m##1 + m##2; + * auto m1subm2 = m##1 - m##2; + * auto m3addm4 = m##3 + m##4; + * auto m3subm4 = m##3 - m##4; + * auto m5addm6 = (m##5 + m##6); + * auto m5subm6 = (m##5 - m##6); + * s##0 = m##0; + * CONCAT(s, 0).add(m1addm2).add(m3addm4).add(m5addm6); + * CONCAT(s, 1) = m3subm4; + * CONCAT(s, 1).mla(m1subm2, 0.5f).mla(m5subm6, 2.0f); + * CONCAT(s, 2) = m3addm4; + * CONCAT(s, 2).mla(m1addm2, 0.25f).mla(m5addm6, 4.0f); + * CONCAT(s, 3) = m3subm4; + * CONCAT(s, 3).mla(m1subm2, 0.125f).mla(m5subm6, 8.0f); + * CONCAT(s, 4) = m##7; + * CONCAT(s, 4).mla(m1addm2, 0.0625f).add(m3addm4).mla(m5addm6, 16.0f); + */ +#define OUTPUT_TRANSFORM(m, s) \ + do { \ + auto m1addm2 = ADDFV2(m##1, m##2); \ + auto m1subm2 = SUBFV2(m##1, m##2); \ + auto m3addm4 = ADDFV2(m##3, m##4); \ + auto m3subm4 = SUBFV2(m##3, m##4); \ + auto m5addm6 = ADDFV2(m##5, m##6); \ + auto m5subm6 = SUBFV2(m##5, m##6); \ + s##0 = m##0; \ + CONCAT(s, 0) = \ + ADDFV2(ADDFV2(ADDFV2(CONCAT(s, 0), m1addm2), m3addm4), m5addm6); \ + CONCAT(s, 1) = m3subm4; \ + CONCAT(s, 1) = GiMultiplyAddScalarFloat32V2(CONCAT(s, 1), m1subm2, 0.5f); \ + CONCAT(s, 1) = GiMultiplyAddScalarFloat32V2(CONCAT(s, 1), m5subm6, 2.0f); \ + CONCAT(s, 2) = m3addm4; \ + CONCAT(s, 2) = GiMultiplyAddScalarFloat32V2(CONCAT(s, 2), m1addm2, 0.25f); \ + CONCAT(s, 2) = GiMultiplyAddScalarFloat32V2(CONCAT(s, 2), m5addm6, 4.0f); \ + CONCAT(s, 3) = m3subm4; \ + CONCAT(s, 3) = GiMultiplyAddScalarFloat32V2(CONCAT(s, 3), m1subm2, 0.125f); \ + CONCAT(s, 3) = GiMultiplyAddScalarFloat32V2(CONCAT(s, 3), m5subm6, 8.0f); \ + CONCAT(s, 4) = m##7; \ + CONCAT(s, 4) = GiMultiplyAddScalarFloat32V2(CONCAT(s, 4), m1addm2, 0.0625f); \ + CONCAT(s, 4) = ADDFV2(CONCAT(s, 4), m3addm4); \ + CONCAT(s, 4) = GiMultiplyAddScalarFloat32V2(CONCAT(s, 4), m5addm6, 16.0f); \ } while (0) #if defined(__GNUC__) && !defined(__llvm__) && !defined(_MSC_VER) @@ -314,11 +375,11 @@ struct OutputTransform5X4 { UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); #undef cb -#define cb(i) auto m##i = Vector::load(transform_mid_buf + alpha * i); +#define cb(i) auto m##i = GiLoadFloat32V2(transform_mid_buf + alpha * i); UNROLL_CALL_NOWRAPPER(8, cb); #undef cb -#define cb(i) Vector s##i, ret##i; - UNROLL_CALL_NOWRAPPER(8, cb); +#define cb(i) GI_FLOAT32_V2_t s##i; + UNROLL_CALL_NOWRAPPER(5, cb); #undef cb OUTPUT_TRANSFORM(m, s); diff --git a/dnn/src/fallback/conv_bias/gi/fp32/strategy_6x3.cpp b/dnn/src/fallback/conv_bias/gi/fp32/strategy_6x3.cpp index 16af210d..4adf2af0 100644 --- a/dnn/src/fallback/conv_bias/gi/fp32/strategy_6x3.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/strategy_6x3.cpp @@ -3,7 +3,6 @@ #include "src/fallback/conv_bias/gi/fp32/filter_transform.h" #include "src/fallback/conv_bias/gi/fp32/helper.h" #include "src/fallback/conv_bias/gi/fp32/strategy.h" -#include "src/fallback/conv_bias/gi/utils.h" #include "src/fallback/conv_bias/winograd/winograd.h" #include "src/fallback/elemwise_helper/op_unary.h" #include "src/naive/matrix_mul/matrix_mul_helper.h" @@ -27,28 +26,31 @@ namespace { * wd6 = (d6 - 5.0 * d4 + 4.0 * d2) - 2.0 * (d1 + 0.25 * d5 - 1.25 * d3) * wd7 = (d7 - d1) + 5.25 * (d3 - d5) */ -#define INPUT_TRANSFORM(d, wd) \ - do { \ - wd##0 = (d##0 - d##6) + (d##4 - d##2) * 5.25f; \ - auto tmp0 = d##6 + d##2 - d##4 * 4.25f; \ - auto tmp1 = d##1 + d##5 - d##3 * 4.25f; \ - wd##1 = tmp0 + tmp1; \ - wd##2 = tmp0 - tmp1; \ - tmp0 = d##6 + d##2 * 0.25f - d##4 * 1.25f; \ - tmp1 = (d##5 + d##1 * 0.25f - d##3 * 1.25f) * 2.0f; \ - wd##3 = tmp0 + tmp1; \ - wd##4 = tmp0 - tmp1; \ - tmp0 = d6 - d4 * 5.0f + d2 * 4.0f; \ - tmp1 = (d1 + d5 * 0.25f - d3 * 1.25f) * 2.0f; \ - wd##5 = tmp0 + tmp1; \ - wd##6 = tmp0 - tmp1; \ - wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f; \ +#define INPUT_TRANSFORM(d, wd) \ + do { \ + wd##0 = ADDFV2(SUBFV2(d##0, d##6), MULSFV2(SUBFV2(d##4, d##2), 5.25f)); \ + auto tmp0 = SUBFV2(ADDFV2(d##6, d##2), MULSFV2(d##4, 4.25f)); \ + auto tmp1 = SUBFV2(ADDFV2(d##1, d##5), MULSFV2(d##3, 4.25f)); \ + wd##1 = ADDFV2(tmp0, tmp1); \ + wd##2 = SUBFV2(tmp0, tmp1); \ + tmp0 = SUBFV2(ADDFV2(d##6, MULSFV2(d##2, 0.25f)), MULSFV2(d##4, 1.25f)); \ + tmp1 = MULSFV2( \ + (SUBFV2(ADDFV2(d##5, MULSFV2(d##1, 0.25f)), MULSFV2(d##3, 1.25f))), \ + 2.0f); \ + wd##3 = ADDFV2(tmp0, tmp1); \ + wd##4 = SUBFV2(tmp0, tmp1); \ + tmp0 = ADDFV2(SUBFV2(d6, MULSFV2(d4, 5.0f)), MULSFV2(d2, 4.0f)); \ + tmp1 = MULSFV2( \ + (SUBFV2(ADDFV2(d1, MULSFV2(d5, 0.25f)), MULSFV2(d3, 1.25f))), 2.0f); \ + wd##5 = ADDFV2(tmp0, tmp1); \ + wd##6 = SUBFV2(tmp0, tmp1); \ + wd##7 = ADDFV2(SUBFV2(d##7, d##1), MULSFV2(SUBFV2(d##3, d##5), 5.25f)); \ } while (0); #define GET_VECTOR_HIGH_ELEM(s, i, idx) \ - GiExtractLane##idx##Float32(GiFixLenType2GiFloat32Type(CONCAT(s, i).value.val[1])) + GiExtractLane##idx##Float32(GiGetSubVectorFloat32V2(CONCAT(s, i), 1)) #define GET_VECTOR_LOW_ELEM(s, i, idx) \ - GiExtractLane##idx##Float32(GiFixLenType2GiFloat32Type(CONCAT(s, i).value.val[0])) + GiExtractLane##idx##Float32(GiGetSubVectorFloat32V2(CONCAT(s, i), 0)) struct InputTransform6X3 { template static void transform( @@ -60,13 +62,13 @@ struct InputTransform6X3 { memset(transform_mid_buf, 0, sizeof(float) * alpha * alpha); } -#define cb(i) Vector d##i; +#define cb(i) GI_FLOAT32_V2_t d##i; UNROLL_CALL_NOWRAPPER(8, cb); #undef cb if (inner) { const float* input_ptr = input + ic * IH * IW + ih_start * IW + iw_start; -#define cb(i) d##i = Vector::load(input_ptr + IW * i); +#define cb(i) d##i = GiLoadFloat32V2(input_ptr + IW * i); UNROLL_CALL_NOWRAPPER(8, cb); #undef cb } else { @@ -81,22 +83,25 @@ struct InputTransform6X3 { input[ic * IH * IW + ih * IW + iw]; } } -#define cb(i) d##i = Vector::load(transform_mid_buf + alpha * i); +#define cb(i) d##i = GiLoadFloat32V2(transform_mid_buf + alpha * i); UNROLL_CALL_NOWRAPPER(8, cb); #undef cb } -#define cb(i) Vector wd##i, ret##i; +#define cb(i) GI_FLOAT32_V2_t wd##i; UNROLL_CALL_NOWRAPPER(8, cb); #undef cb INPUT_TRANSFORM(d, wd); #if MEGDNN_AARCH64 +#define cb(i) GI_FLOAT32_V2_t ret##i; + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb TRANSPOSE_8x8(wd, d); INPUT_TRANSFORM(d, ret); -#define cb(i) ret##i.save(transform_mid_buf + i * alpha); +#define cb(i) GiStoreFloat32V2(transform_mid_buf + i * alpha, ret##i); UNROLL_CALL_NOWRAPPER(8, cb); #undef cb @@ -174,26 +179,34 @@ struct InputTransform6X3 { * s4 = (m1 + m2) + 16 * (m3 + m4) + 2 * (m5 + m6) / 32 * s5 = (m1 - m2) + 32 * (m3 - m4) + (m5 - m6) / 32 + m7 */ -#define OUTPUT_TRANSFORM(m, s) \ - do { \ - auto m1addm2 = m##1 + m##2; \ - auto m1subm2 = m##1 - m##2; \ - auto m3addm4 = m##3 + m##4; \ - auto m3subm4 = m##3 - m##4; \ - auto m5addm6 = (m##5 + m##6) * 0.03125f; \ - auto m5subm6 = (m##5 - m##6) * 0.03125f; \ - s##0 = m##0; \ - CONCAT(s, 0).mla(m5addm6, 32.f).add(m3addm4).add(m1addm2); \ - CONCAT(s, 1) = m1subm2; \ - CONCAT(s, 1).mla(m3subm4, 2.f).mla(m5subm6, 16.f); \ - CONCAT(s, 2) = m1addm2; \ - CONCAT(s, 2).mla(m3addm4, 4.f).mla(m5addm6, 8.f); \ - CONCAT(s, 3) = m1subm2; \ - CONCAT(s, 3).mla(m3subm4, 8.f).mla(m5subm6, 4.f); \ - CONCAT(s, 4) = m1addm2; \ - CONCAT(s, 4).mla(m3addm4, 16.f).mla(m5addm6, 2.f); \ - CONCAT(s, 5) = m1subm2; \ - CONCAT(s, 5).mla(m3subm4, 32.f).add(m5subm6).add(m##7); \ +#define OUTPUT_TRANSFORM(m, s) \ + do { \ + auto m1addm2 = ADDFV2(m##1, m##2); \ + auto m1subm2 = SUBFV2(m##1, m##2); \ + auto m3addm4 = ADDFV2(m##3, m##4); \ + auto m3subm4 = SUBFV2(m##3, m##4); \ + auto m5addm6 = MULSFV2((ADDFV2(m##5, m##6)), 0.03125f); \ + auto m5subm6 = MULSFV2((SUBFV2(m##5, m##6)), 0.03125f); \ + s##0 = m##0; \ + CONCAT(s, 0) = GiMultiplyAddScalarFloat32V2(CONCAT(s, 0), m5addm6, 32.f); \ + CONCAT(s, 0) = ADDFV2(CONCAT(s, 0), m3addm4); \ + CONCAT(s, 0) = ADDFV2(CONCAT(s, 0), m1addm2); \ + CONCAT(s, 1) = m1subm2; \ + CONCAT(s, 1) = GiMultiplyAddScalarFloat32V2(CONCAT(s, 1), m3subm4, 2.f); \ + CONCAT(s, 1) = GiMultiplyAddScalarFloat32V2(CONCAT(s, 1), m5subm6, 16.f); \ + CONCAT(s, 2) = m1addm2; \ + CONCAT(s, 2) = GiMultiplyAddScalarFloat32V2(CONCAT(s, 2), m3addm4, 4.f); \ + CONCAT(s, 2) = GiMultiplyAddScalarFloat32V2(CONCAT(s, 2), m5addm6, 8.f); \ + CONCAT(s, 3) = m1subm2; \ + CONCAT(s, 3) = GiMultiplyAddScalarFloat32V2(CONCAT(s, 3), m3subm4, 8.f); \ + CONCAT(s, 3) = GiMultiplyAddScalarFloat32V2(CONCAT(s, 3), m5subm6, 4.f); \ + CONCAT(s, 4) = m1addm2; \ + CONCAT(s, 4) = GiMultiplyAddScalarFloat32V2(CONCAT(s, 4), m3addm4, 16.f); \ + CONCAT(s, 4) = GiMultiplyAddScalarFloat32V2(CONCAT(s, 4), m5addm6, 2.f); \ + CONCAT(s, 5) = m1subm2; \ + CONCAT(s, 5) = GiMultiplyAddScalarFloat32V2(CONCAT(s, 5), m3subm4, 32.f); \ + CONCAT(s, 5) = ADDFV2(CONCAT(s, 5), m5subm6); \ + CONCAT(s, 5) = ADDFV2(CONCAT(s, 5), m##7); \ } while (0); #if defined(__GNUC__) && !defined(__llvm__) && !defined(_MSC_VER) @@ -224,11 +237,11 @@ struct OutputTransform6X3 { UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); #undef cb -#define cb(i) auto m##i = Vector::load(transform_mid_buf + alpha * i); +#define cb(i) auto m##i = GiLoadFloat32V2(transform_mid_buf + alpha * i); UNROLL_CALL_NOWRAPPER(8, cb); #undef cb -#define cb(i) Vector s##i, ret##i; - UNROLL_CALL_NOWRAPPER(8, cb); +#define cb(i) GI_FLOAT32_V2_t s##i; + UNROLL_CALL_NOWRAPPER(6, cb); #undef cb OUTPUT_TRANSFORM(m, s); diff --git a/dnn/src/fallback/conv_bias/gi/fp32/strategy_6x3_4x4.cpp b/dnn/src/fallback/conv_bias/gi/fp32/strategy_6x3_4x4.cpp index 46a6c312..d61e6a8e 100644 --- a/dnn/src/fallback/conv_bias/gi/fp32/strategy_6x3_4x4.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/strategy_6x3_4x4.cpp @@ -4,7 +4,6 @@ #include "src/fallback/conv_bias/gi/fp32/filter_transform.h" #include "src/fallback/conv_bias/gi/fp32/helper.h" #include "src/fallback/conv_bias/gi/fp32/strategy.h" -#include "src/fallback/conv_bias/gi/utils.h" #include "src/fallback/conv_bias/winograd/winograd.h" #include "src/fallback/elemwise_helper/op_unary.h" @@ -76,8 +75,7 @@ struct InputTransform6X3 { size_t nr_units_in_tile, size_t ic, size_t IC) { constexpr size_t alpha = 6 + 3 - 1; // BT * d * B -#define cb(m, n) \ - Vector d##m##n = Vector::load(patchT + m * 8 * 4 + n * 4); +#define cb(m, n) GI_FLOAT32_t d##m##n = GiLoadFloat32(patchT + m * 8 * 4 + n * 4); UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); #undef cb @@ -91,36 +89,103 @@ struct InputTransform6X3 { //! 0 1 -1 2 -2 0.5 -0.5 -5.25 //! -1 1 1 1 1 1 1 0 //! 0 0 0 0 0 0 0 1 +/* + * auto t0##m = d0##m + (d4##m - d2##m) * 5.25f - d6##m; + * auto t1##m = d1##m + d2##m + d5##m + d6##m - (d3##m + d4##m) * 4.25f; + * auto t2##m = d2##m + d6##m - (d1##m + d5##m) + (d3##m - d4##m) * 4.25f; + * auto t3##m = d1##m * 0.5f + d2##m * 0.25f - d3##m * 2.5f - d4##m * 1.25f + + * d5##m * 2.f + d6##m; + * auto t4##m = d1##m * (-0.5f) + d2##m * 0.25f + d3##m * 2.5f - d4##m * 1.25f - + * d5##m * 2.f + d6##m; + * auto t5##m = d1##m * 2.f + d2##m * 4.f - d3##m * 2.5f - d4##m * 5.f + + * d5##m * 0.5f + d6##m; + * auto t6##m = d1##m * (-2.f) + d2##m * 4.f + d3##m * 2.5f - d4##m * 5.f - + * d5##m * 0.5f + d6##m; + * auto t7##m = (d7##m - d1##m) + (d3##m - d5##m) * 5.25f; + */ #define cb(m) \ - auto t0##m = d0##m + (d4##m - d2##m) * 5.25f - d6##m; \ - auto t1##m = d1##m + d2##m + d5##m + d6##m - (d3##m + d4##m) * 4.25f; \ - auto t2##m = d2##m + d6##m - (d1##m + d5##m) + (d3##m - d4##m) * 4.25f; \ - auto t3##m = d1##m * 0.5f + d2##m * 0.25f - d3##m * 2.5f - d4##m * 1.25f + \ - d5##m * 2.f + d6##m; \ - auto t4##m = d1##m * (-0.5f) + d2##m * 0.25f + d3##m * 2.5f - d4##m * 1.25f - \ - d5##m * 2.f + d6##m; \ - auto t5##m = d1##m * 2.f + d2##m * 4.f - d3##m * 2.5f - d4##m * 5.f + \ - d5##m * 0.5f + d6##m; \ - auto t6##m = d1##m * (-2.f) + d2##m * 4.f + d3##m * 2.5f - d4##m * 5.f - \ - d5##m * 0.5f + d6##m; \ - auto t7##m = (d7##m - d1##m) + (d3##m - d5##m) * 5.25f; + auto t0##m = SUBF(ADDF(d0##m, MULSF(SUBF(d4##m, d2##m), 5.25f)), d6##m); \ + auto t1##m = \ + SUBF(ADDF(ADDF(ADDF(d1##m, d2##m), d5##m), d6##m), \ + MULSF(ADDF(d3##m, d4##m), 4.25f)); \ + auto t2##m = \ + ADDF(SUBF(ADDF(d2##m, d6##m), ADDF(d1##m, d5##m)), \ + MULSF(SUBF(d3##m, d4##m), 4.25f)); \ + auto t3##m = \ + ADDF(ADDF(SUBF(SUBF(ADDF(MULSF(d1##m, 0.5f), MULSF(d2##m, 0.25f)), \ + MULSF(d3##m, 2.5f)), \ + MULSF(d4##m, 1.25f)), \ + MULSF(d5##m, 2.f)), \ + d6##m); \ + auto t4##m = \ + ADDF(SUBF(SUBF(ADDF(ADDF(MULSF(d1##m, (-0.5f)), MULSF(d2##m, 0.25f)), \ + MULSF(d3##m, 2.5f)), \ + MULSF(d4##m, 1.25f)), \ + MULSF(d5##m, 2.f)), \ + d6##m); \ + auto t5##m = \ + ADDF(ADDF(SUBF(SUBF(ADDF(MULSF(d1##m, 2.f), MULSF(d2##m, 4.f)), \ + MULSF(d3##m, 2.5f)), \ + MULSF(d4##m, 5.f)), \ + MULSF(d5##m, 0.5f)), \ + d6##m); \ + auto t6##m = \ + ADDF(SUBF(SUBF(ADDF(ADDF(MULSF(d1##m, (-2.f)), MULSF(d2##m, 4.f)), \ + MULSF(d3##m, 2.5f)), \ + MULSF(d4##m, 5.f)), \ + MULSF(d5##m, 0.5f)), \ + d6##m); \ + auto t7##m = ADDF(SUBF(d7##m, d1##m), MULSF(SUBF(d3##m, d5##m), 5.25f)); UNROLL_CALL_NOWRAPPER(8, cb); #undef cb -#define cb(m) \ - d##m##0 = t##m##0 + (t##m##4 - t##m##2) * 5.25f - t##m##6; \ - d##m##1 = t##m##1 + t##m##2 + t##m##5 + t##m##6 - (t##m##3 + t##m##4) * 4.25f; \ - d##m##2 = t##m##2 + t##m##6 - (t##m##1 + t##m##5) + (t##m##3 - t##m##4) * 4.25f; \ - d##m##3 = t##m##1 * 0.5f + t##m##2 * 0.25f - t##m##3 * 2.5f - t##m##4 * 1.25f + \ - t##m##5 * 2.f + t##m##6; \ - d##m##4 = t##m##1 * (-0.5f) + t##m##2 * 0.25f + t##m##3 * 2.5f - t##m##4 * 1.25f - \ - t##m##5 * 2.f + t##m##6; \ - d##m##5 = t##m##1 * 2.f + t##m##2 * 4.f - t##m##3 * 2.5f - t##m##4 * 5.f + \ - t##m##5 * 0.5f + t##m##6; \ - d##m##6 = t##m##1 * (-2.f) + t##m##2 * 4.f + t##m##3 * 2.5f - t##m##4 * 5.f - \ - t##m##5 * 0.5f + t##m##6; \ - d##m##7 = (t##m##7 - t##m##1) + (t##m##3 - t##m##5) * 5.25f; + /* + * d##m##0 = t##m##0 + (t##m##4 - t##m##2) * 5.25f - t##m##6; + * d##m##1 = t##m##1 + t##m##2 + t##m##5 + t##m##6 - (t##m##3 + t##m##4) + * * 4.25f; d##m##2 = t##m##2 + t##m##6 - (t##m##1 + t##m##5) + (t##m##3 - + * t##m##4) * 4.25f; d##m##3 = t##m##1 * 0.5f + t##m##2 * 0.25f - t##m##3 * 2.5f + * - t##m##4 * 1.25f + t##m##5 * 2.f + t##m##6; d##m##4 = t##m##1 * (-0.5f) + + * t##m##2 * 0.25f + t##m##3 * 2.5f - t##m##4 * 1.25f - t##m##5 * 2.f + t##m##6; + * d##m##5 = t##m##1 * 2.f + t##m##2 * 4.f - t##m##3 * 2.5f - t##m##4 * 5.f + + * t##m##5 * 0.5f + t##m##6; + * d##m##6 = t##m##1 * (-2.f) + t##m##2 * 4.f + t##m##3 * 2.5f - t##m##4 * 5.f - + * t##m##5 * 0.5f + t##m##6; + * d##m##7 = (t##m##7 - t##m##1) + (t##m##3 - t##m##5) * 5.25f; + */ +#define cb(m) \ + d##m##0 = SUBF(ADDF(t##m##0, MULSF(SUBF(t##m##4, t##m##2), 5.25f)), t##m##6); \ + d##m##1 = \ + SUBF(ADDF(ADDF(ADDF(t##m##1, t##m##2), t##m##5), t##m##6), \ + MULSF(ADDF(t##m##3, t##m##4), 4.25f)); \ + d##m##2 = \ + ADDF(SUBF(ADDF(t##m##2, t##m##6), ADDF(t##m##1, t##m##5)), \ + MULSF(SUBF(t##m##3, t##m##4), 4.25f)); \ + d##m##3 = \ + ADDF(ADDF(SUBF(SUBF(ADDF(MULSF(t##m##1, 0.5f), MULSF(t##m##2, 0.25f)), \ + MULSF(t##m##3, 2.5f)), \ + MULSF(t##m##4, 1.25f)), \ + MULSF(t##m##5, 2.f)), \ + t##m##6); \ + d##m##4 = \ + ADDF(SUBF(SUBF(ADDF(ADDF(MULSF(t##m##1, (-0.5f)), MULSF(t##m##2, 0.25f)), \ + MULSF(t##m##3, 2.5f)), \ + MULSF(t##m##4, 1.25f)), \ + MULSF(t##m##5, 2.f)), \ + t##m##6); \ + d##m##5 = \ + ADDF(ADDF(SUBF(SUBF(ADDF(MULSF(t##m##1, 2.f), MULSF(t##m##2, 4.f)), \ + MULSF(t##m##3, 2.5f)), \ + MULSF(t##m##4, 5.f)), \ + MULSF(t##m##5, 0.5f)), \ + t##m##6); \ + d##m##6 = \ + ADDF(SUBF(SUBF(ADDF(ADDF(MULSF(t##m##1, (-2.f)), MULSF(t##m##2, 4.f)), \ + MULSF(t##m##3, 2.5f)), \ + MULSF(t##m##4, 5.f)), \ + MULSF(t##m##5, 0.5f)), \ + t##m##6); \ + d##m##7 = ADDF(SUBF(t##m##7, t##m##1), MULSF(SUBF(t##m##3, t##m##5), 5.25f)); UNROLL_CALL_NOWRAPPER(8, cb); #undef cb @@ -128,9 +193,10 @@ struct InputTransform6X3 { size_t ICB = IC / 4; size_t icb = ic / 4; #define cb(m, n) \ - d##m##n.save( \ + GiStoreFloat32( \ input_transform_buf + (m * alpha + n) * ICB * nr_units_in_tile * 4 + \ - icb * nr_units_in_tile * 4 + unit_idx * 4); + icb * nr_units_in_tile * 4 + unit_idx * 4, \ + d##m##n); UNROLL_CALL_NOWRAPPER_D2(8, 8, cb) #undef cb } @@ -152,7 +218,7 @@ struct OutputTransform6X3 { size_t ocb = oc_index / 4; #define cb(m, n) \ - auto v##m##n = Vector::load( \ + auto v##m##n = GiLoadFloat32( \ output_transform_buf + (m * alpha + n) * OCB * nr_units_in_tile * 4 + \ ocb * nr_units_in_tile * 4 + unit_idx * 4); UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); @@ -171,56 +237,88 @@ struct OutputTransform6X3 { * 0 0.0 0 0 0 1 */ - Vector v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, v5subv6; -#define cb(m) \ - v1addv2 = v1##m + v2##m; \ - v1subv2 = v1##m - v2##m; \ - v3addv4 = v3##m + v4##m; \ - v3subv4 = v3##m - v4##m; \ - v5addv6 = v5##m + v6##m; \ - v5subv6 = v5##m - v6##m; \ - auto t0##m = v0##m + v1addv2 + v3addv4 + v5addv6; \ - auto t1##m = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f; \ - auto t2##m = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f; \ - auto t3##m = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f; \ - auto t4##m = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; \ - auto t5##m = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + v7##m; + /* + * v1addv2 = v1##m + v2##m; + * v1subv2 = v1##m - v2##m; + * v3addv4 = v3##m + v4##m; + * v3subv4 = v3##m - v4##m; + * v5addv6 = v5##m + v6##m; + * v5subv6 = v5##m - v6##m; + * auto t0##m = v0##m + v1addv2 + v3addv4 + v5addv6; + * auto t1##m = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f; + * auto t2##m = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f; + * auto t3##m = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f; + * auto t4##m = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; + * auto t5##m = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + v7##m; + */ + GI_FLOAT32_t v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, v5subv6; +#define cb(m) \ + v1addv2 = ADDF(v1##m, v2##m); \ + v1subv2 = SUBF(v1##m, v2##m); \ + v3addv4 = ADDF(v3##m, v4##m); \ + v3subv4 = SUBF(v3##m, v4##m); \ + v5addv6 = ADDF(v5##m, v6##m); \ + v5subv6 = SUBF(v5##m, v6##m); \ + auto t0##m = ADDF(ADDF(ADDF(v0##m, v1addv2), v3addv4), v5addv6); \ + auto t1##m = ADDF(ADDF(v1subv2, MULSF(v3subv4, 2.f)), MULSF(v5subv6, 0.5f)); \ + auto t2##m = ADDF(ADDF(v1addv2, MULSF(v3addv4, 4.f)), MULSF(v5addv6, 0.25f)); \ + auto t3##m = ADDF(ADDF(v1subv2, MULSF(v3subv4, 8.f)), MULSF(v5subv6, 0.125f)); \ + auto t4##m = ADDF(ADDF(v1addv2, MULSF(v3addv4, 16.f)), MULSF(v5addv6, 0.0625f)); \ + auto t5##m = \ + ADDF(ADDF(ADDF(v1subv2, MULSF(v3subv4, 32.f)), MULSF(v5subv6, 0.03125f)), \ + v7##m); UNROLL_CALL_NOWRAPPER(8, cb); #undef cb -#define cb(m) \ - v1addv2 = t##m##1 + t##m##2; \ - v1subv2 = t##m##1 - t##m##2; \ - v3addv4 = t##m##3 + t##m##4; \ - v3subv4 = t##m##3 - t##m##4; \ - v5addv6 = t##m##5 + t##m##6; \ - v5subv6 = t##m##5 - t##m##6; \ - v##m##0 = t##m##0 + v1addv2 + v3addv4 + v5addv6; \ - v##m##1 = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f; \ - v##m##2 = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f; \ - v##m##3 = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f; \ - v##m##4 = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; \ - v##m##5 = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + t##m##7; +/* + * v1addv2 = t##m##1 + t##m##2; + * v1subv2 = t##m##1 - t##m##2; + * v3addv4 = t##m##3 + t##m##4; + * v3subv4 = t##m##3 - t##m##4; + * v5addv6 = t##m##5 + t##m##6; + * v5subv6 = t##m##5 - t##m##6; + * v##m##0 = t##m##0 + v1addv2 + v3addv4 + v5addv6; + * v##m##1 = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f; + * v##m##2 = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f; + * v##m##3 = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f; + * v##m##4 = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; + * v##m##5 = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + t##m##7; + */ +#define cb(m) \ + v1addv2 = ADDF(t##m##1, t##m##2); \ + v1subv2 = SUBF(t##m##1, t##m##2); \ + v3addv4 = ADDF(t##m##3, t##m##4); \ + v3subv4 = SUBF(t##m##3, t##m##4); \ + v5addv6 = ADDF(t##m##5, t##m##6); \ + v5subv6 = SUBF(t##m##5, t##m##6); \ + v##m##0 = ADDF(ADDF(ADDF(t##m##0, v1addv2), v3addv4), v5addv6); \ + v##m##1 = ADDF(ADDF(v1subv2, MULSF(v3subv4, 2.f)), MULSF(v5subv6, 0.5f)); \ + v##m##2 = ADDF(ADDF(v1addv2, MULSF(v3addv4, 4.f)), MULSF(v5addv6, 0.25f)); \ + v##m##3 = ADDF(ADDF(v1subv2, MULSF(v3subv4, 8.f)), MULSF(v5subv6, 0.125f)); \ + v##m##4 = ADDF(ADDF(v1addv2, MULSF(v3addv4, 16.f)), MULSF(v5addv6, 0.0625f)); \ + v##m##5 = \ + ADDF(ADDF(ADDF(v1subv2, MULSF(v3subv4, 32.f)), MULSF(v5subv6, 0.03125f)), \ + t##m##7); UNROLL_CALL_NOWRAPPER(6, cb); #undef cb - Vector vbias; + GI_FLOAT32_t vbias; if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { - vbias = Vector::load(bias + oc); + vbias = GiLoadFloat32(bias + oc); -#define cb(m, n) v##m##n += vbias; +#define cb(m, n) v##m##n = GiAddFloat32(v##m##n, vbias); UNROLL_CALL_RAW_D2(6, 6, cb); #undef cb } if (bmode != BiasMode::BIAS) { -#define cb(m, n) v##m##n = op(GiFixLenType2GiFloat32Type(CONCAT(v##m, n).value)); +#define cb(m, n) v##m##n = op(CONCAT(v##m, n)); UNROLL_CALL_RAW_D2(6, 6, cb); #undef cb } -#define cb(m, n) CONCAT(v##m, n).save(transform_mid_buf + (m * 6 + n) * 4); +#define cb(m, n) GiStoreFloat32(transform_mid_buf + (m * 6 + n) * 4, CONCAT(v##m, n)); UNROLL_CALL_RAW_D2(6, 6, cb); #undef cb diff --git a/dnn/src/fallback/conv_bias/gi/fp32/strategy_f23_mk4_nchw44.cpp b/dnn/src/fallback/conv_bias/gi/fp32/strategy_f23_mk4_nchw44.cpp index c0413335..800ee78b 100644 --- a/dnn/src/fallback/conv_bias/gi/fp32/strategy_f23_mk4_nchw44.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/strategy_f23_mk4_nchw44.cpp @@ -1,7 +1,6 @@ #include "src/common/unroll_macro.h" #include "src/common/utils.h" #include "src/fallback/conv_bias/gi/fp32/strategy.h" -#include "src/fallback/conv_bias/gi/utils.h" #include "src/fallback/conv_bias/winograd/winograd.h" #include "src/fallback/conv_bias/gi/fp32/helper.h" @@ -29,7 +28,7 @@ struct InputTransformF23_NCHW44 { size_t iw4_start = iw_start * pack_size; size_t ICB = IC / pack_size; -#define cb(m, n) Vector d##m##n; +#define cb(m, n) GI_FLOAT32_t d##m##n; UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); #undef cb @@ -40,7 +39,7 @@ struct InputTransformF23_NCHW44 { MEGDNN_MARK_USED_VAR(patchT); const float* input_ptr = input + icb * IH * IW4 + ih_start * IW4 + iw4_start; -#define cb(n, m) d##m##n = Vector::load(input_ptr + pack_size * n); +#define cb(n, m) d##m##n = GiLoadFloat32(input_ptr + pack_size * n); UNROLL_CALL_RAW(4, cb, 0); input_ptr += IW4; @@ -66,7 +65,7 @@ struct InputTransformF23_NCHW44 { } } #define cb(m, n) \ - d##m##n = Vector::load(patchT + m * alpha * pack_size + n * pack_size); + d##m##n = GiLoadFloat32(patchT + m * alpha * pack_size + n * pack_size); UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); #undef cb } @@ -74,29 +73,30 @@ struct InputTransformF23_NCHW44 { //! 0 1 1 0 d10 d11 d12 d13 0 1 -1 -1 //! 0 -1 1 0 d20 d21 d22 d23 -1 1 1 0 //! 0 -1 0 1 d30 d31 d32 d33 0 0 0 1 -#define cb(m) \ - auto t0##m = d0##m - d2##m; \ - auto t1##m = d1##m + d2##m; \ - auto t2##m = d2##m - d1##m; \ - auto t3##m = d3##m - d1##m; +#define cb(m) \ + auto t0##m = SUBF(d0##m, d2##m); \ + auto t1##m = ADDF(d1##m, d2##m); \ + auto t2##m = SUBF(d2##m, d1##m); \ + auto t3##m = SUBF(d3##m, d1##m); UNROLL_CALL_NOWRAPPER(4, cb); #undef cb -#define cb(m) \ - d##m##0 = t##m##0 - t##m##2; \ - d##m##1 = t##m##1 + t##m##2; \ - d##m##2 = t##m##2 - t##m##1; \ - d##m##3 = t##m##3 - t##m##1; +#define cb(m) \ + d##m##0 = SUBF(t##m##0, t##m##2); \ + d##m##1 = ADDF(t##m##1, t##m##2); \ + d##m##2 = SUBF(t##m##2, t##m##1); \ + d##m##3 = SUBF(t##m##3, t##m##1); UNROLL_CALL_NOWRAPPER(4, cb); #undef cb -#define cb(m, n) \ - d##m##n.save( \ - input_transform_buf + \ - (m * alpha + n) * ICB * nr_units_in_tile * pack_size + \ - icb * nr_units_in_tile * pack_size + unit_idx * pack_size); +#define cb(m, n) \ + GiStoreFloat32( \ + input_transform_buf + \ + (m * alpha + n) * ICB * nr_units_in_tile * pack_size + \ + icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ + d##m##n); UNROLL_CALL_NOWRAPPER_D2(4, 4, cb) #undef cb } @@ -118,7 +118,7 @@ struct OutputTransformF23_NCHW44 { size_t ocb = oc_index / pack_size; #define cb(m, n) \ - auto v##m##n = Vector::load( \ + auto v##m##n = GiLoadFloat32( \ output_transform_buf + \ (m * alpha + n) * OCB * nr_units_in_tile * pack_size + \ ocb * nr_units_in_tile * pack_size + unit_idx * pack_size); @@ -130,30 +130,30 @@ struct OutputTransformF23_NCHW44 { //! v20 v21 v22 v23 1 -1 //! v30 v31 v32 v33 0 1 -#define cb(m) \ - auto t0##m = v0##m + v1##m + v2##m; \ - auto t1##m = v1##m - v2##m + v3##m; +#define cb(m) \ + auto t0##m = ADDF(ADDF(v0##m, v1##m), v2##m); \ + auto t1##m = ADDF(SUBF(v1##m, v2##m), v3##m); UNROLL_CALL_NOWRAPPER(4, cb); #undef cb -#define cb(m) \ - v##m##0 = t##m##0 + t##m##1 + t##m##2; \ - v##m##1 = t##m##1 - t##m##2 + t##m##3; +#define cb(m) \ + v##m##0 = ADDF(ADDF(t##m##0, t##m##1), t##m##2); \ + v##m##1 = ADDF(SUBF(t##m##1, t##m##2), t##m##3); UNROLL_CALL_NOWRAPPER(2, cb); #undef cb - Vector vbias; + GI_FLOAT32_t vbias; if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { - vbias = Vector::load(bias + oc); + vbias = GiLoadFloat32(bias + oc); -#define cb(m, n) v##m##n += vbias; +#define cb(m, n) v##m##n = ADDF(v##m##n, vbias); UNROLL_CALL_RAW_D2(2, 2, cb); #undef cb } if (bmode != BiasMode::BIAS) { -#define cb(m, n) v##m##n = op(GiFixLenType2GiFloat32Type(CONCAT(v##m, n).value)); +#define cb(m, n) v##m##n = op(CONCAT(v##m, n)); UNROLL_CALL_RAW_D2(2, 2, cb); #undef cb } @@ -163,12 +163,15 @@ struct OutputTransformF23_NCHW44 { size_t ow = ow_start + owo; \ if (oh < OH && ow < OW) { \ if (bmode == BiasMode::BIAS) { \ - v##oho##owo += Vector::load( \ - bias + oc * OH * OW + oh * OW * pack_size + ow * pack_size); \ - v##oho##owo = op(GiFixLenType2GiFloat32Type(v##oho##owo.value)); \ + v##oho##owo = ADDF( \ + v##oho##owo, GiLoadFloat32( \ + bias + oc * OH * OW + \ + oh * OW * pack_size + ow * pack_size)); \ + v##oho##owo = op(v##oho##owo); \ } \ - v##oho##owo.save( \ - output + oc * OH * OW + oh * OW * pack_size + ow * pack_size); \ + GiStoreFloat32( \ + output + oc * OH * OW + oh * OW * pack_size + ow * pack_size, \ + v##oho##owo); \ } \ } while (0); UNROLL_CALL_RAW_D2(2, 2, out_save); @@ -211,29 +214,30 @@ void winograd_F23_mk4_f_nchw44::filter( pack_size * pack_size + ic_inner * pack_size; -#define cb(m, n) \ - Vector g##m##n = Vector::load( \ - fptr + (m * KERNEL_SIZE + n) * pack_size * pack_size); +#define cb(m, n) \ + GI_FLOAT32_t g##m##n = \ + GiLoadFloat32(fptr + (m * KERNEL_SIZE + n) * pack_size * pack_size); UNROLL_CALL_NOWRAPPER_D2(3, 3, cb) #undef cb -#define FILTER_TRANSFORM(n, wd, g) \ - auto wd##n##0 = g##0##n; \ - tmp0 = (g##0##n + g##2##n) * 0.5; \ - tmp1 = g##1##n * 0.5; \ - auto wd##n##1 = tmp0 + tmp1; \ - auto wd##n##2 = tmp0 - tmp1; \ +#define FILTER_TRANSFORM(n, wd, g) \ + auto wd##n##0 = g##0##n; \ + tmp0 = MULSF(ADDF(g##0##n, g##2##n), 0.5); \ + tmp1 = MULSF(g##1##n, 0.5); \ + auto wd##n##1 = ADDF(tmp0, tmp1); \ + auto wd##n##2 = SUBF(tmp0, tmp1); \ auto wd##n##3 = g##2##n; - Vector tmp0, tmp1; + GI_FLOAT32_t tmp0, tmp1; UNROLL_CALL_RAW(3, FILTER_TRANSFORM, wd, g); UNROLL_CALL_RAW(4, FILTER_TRANSFORM, ret, wd); #undef FILTER_TRANSFORM -#define cb_save(m, n) \ - ret##m##n.save( \ - filter_transform_buf + \ - (m * ALPHA + n) * OCB * ICB * pack_size * pack_size + \ - ocb * ICB * pack_size * pack_size + icb * pack_size * pack_size + \ - ic_inner * pack_size); +#define cb_save(m, n) \ + GiStoreFloat32( \ + filter_transform_buf + \ + (m * ALPHA + n) * OCB * ICB * pack_size * pack_size + \ + ocb * ICB * pack_size * pack_size + icb * pack_size * pack_size + \ + ic_inner * pack_size, \ + ret##m##n); UNROLL_CALL_NOWRAPPER_D2(4, 4, cb_save) #undef cb_save } diff --git a/dnn/src/fallback/conv_bias/gi/fp32/strategy_f63_mk4_nchw44.cpp b/dnn/src/fallback/conv_bias/gi/fp32/strategy_f63_mk4_nchw44.cpp index d507c712..1811ec19 100644 --- a/dnn/src/fallback/conv_bias/gi/fp32/strategy_f63_mk4_nchw44.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/strategy_f63_mk4_nchw44.cpp @@ -4,7 +4,6 @@ #include "src/fallback/conv_bias/gi/fp32/filter_transform.h" #include "src/fallback/conv_bias/gi/fp32/helper.h" #include "src/fallback/conv_bias/gi/fp32/strategy.h" -#include "src/fallback/conv_bias/gi/utils.h" #include "src/fallback/conv_bias/winograd/winograd.h" #include "src/fallback/elemwise_helper/op_unary.h" @@ -114,17 +113,17 @@ struct InputTransformF63_NCHW44 { auto t##i##4 = d6; \ auto t##i##5 = d6; \ auto t##i##6 = d6; \ - t##i##0 = GiSubtractFloat32(t##i##0, d6); \ - t##i##1 = GiAddFloat32(t##i##1, d1); \ - t##i##2 = GiSubtractFloat32(t##i##2, d1); \ + t##i##0 = SUBF(t##i##0, d6); \ + t##i##1 = ADDF(t##i##1, d1); \ + t##i##2 = SUBF(t##i##2, d1); \ t##i##3 = MADD(t##i##3, d1, v0, 2); \ t##i##4 = MSUB(t##i##4, d1, v0, 2); \ t##i##5 = MADD(t##i##5, d1, v1, 2); \ t##i##6 = MSUB(t##i##6, d1, v1, 2); \ - t##i##7 = GiSubtractFloat32(t##i##7, d1); \ + t##i##7 = SUBF(t##i##7, d1); \ t##i##0 = MSUB(t##i##0, d2, v0, 0); \ - t##i##1 = GiAddFloat32(t##i##1, d2); \ - t##i##2 = GiAddFloat32(t##i##2, d2); \ + t##i##1 = ADDF(t##i##1, d2); \ + t##i##2 = ADDF(t##i##2, d2); \ t##i##3 = MADD(t##i##3, d2, v0, 3); \ t##i##4 = MADD(t##i##4, d2, v0, 3); \ t##i##5 = MADD(t##i##5, d2, v1, 3); \ @@ -143,8 +142,8 @@ struct InputTransformF63_NCHW44 { t##i##4 = MSUB(t##i##4, d4, v1, 1); \ t##i##5 = MSUB(t##i##5, d4, v2, 0); \ t##i##6 = MSUB(t##i##6, d4, v2, 0); \ - t##i##1 = GiAddFloat32(t##i##1, d5); \ - t##i##2 = GiSubtractFloat32(t##i##2, d5); \ + t##i##1 = ADDF(t##i##1, d5); \ + t##i##2 = SUBF(t##i##2, d5); \ t##i##3 = MADD(t##i##3, d5, v1, 2); \ t##i##4 = MSUB(t##i##4, d5, v1, 2); \ t##i##5 = MADD(t##i##5, d5, v0, 2); \ @@ -162,17 +161,17 @@ struct InputTransformF63_NCHW44 { d5 = t6##i; \ d6 = t6##i; \ d7 = t7##i; \ - d0 = GiSubtractFloat32(d0, t6##i); \ - d1 = GiAddFloat32(d1, t1##i); \ - d2 = GiSubtractFloat32(d2, t1##i); \ + d0 = SUBF(d0, t6##i); \ + d1 = ADDF(d1, t1##i); \ + d2 = SUBF(d2, t1##i); \ d3 = MADD(d3, t1##i, v0, 2); \ d4 = MSUB(d4, t1##i, v0, 2); \ d5 = MADD(d5, t1##i, v1, 2); \ d6 = MSUB(d6, t1##i, v1, 2); \ - d7 = GiSubtractFloat32(d7, t1##i); \ + d7 = SUBF(d7, t1##i); \ d0 = MSUB(d0, t2##i, v0, 0); \ - d1 = GiAddFloat32(d1, t2##i); \ - d2 = GiAddFloat32(d2, t2##i); \ + d1 = ADDF(d1, t2##i); \ + d2 = ADDF(d2, t2##i); \ d3 = MADD(d3, t2##i, v0, 3); \ d4 = MADD(d4, t2##i, v0, 3); \ d5 = MADD(d5, t2##i, v1, 3); \ @@ -191,8 +190,8 @@ struct InputTransformF63_NCHW44 { d4 = MSUB(d4, t4##i, v1, 1); \ d5 = MSUB(d5, t4##i, v2, 0); \ d6 = MSUB(d6, t4##i, v2, 0); \ - d1 = GiAddFloat32(d1, t5##i); \ - d2 = GiSubtractFloat32(d2, t5##i); \ + d1 = ADDF(d1, t5##i); \ + d2 = SUBF(d2, t5##i); \ d3 = MADD(d3, t5##i, v1, 2); \ d4 = MSUB(d4, t5##i, v1, 2); \ d5 = MADD(d5, t5##i, v0, 2); \ @@ -261,7 +260,7 @@ struct OutputTransformF63_NCHW44 { size_t ocb = oc_index / pack_size; #define cb(m, n) \ - auto v##m##n = Vector::load( \ + auto v##m##n = GiLoadFloat32( \ output_transform_buf + \ (m * alpha + n) * OCB * nr_units_in_tile * pack_size + \ ocb * nr_units_in_tile * pack_size + unit_idx * pack_size); @@ -281,51 +280,83 @@ struct OutputTransformF63_NCHW44 { * 0 0 0 0 0 1 */ - Vector v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, v5subv6; -#define cb(m) \ - v1addv2 = v1##m + v2##m; \ - v1subv2 = v1##m - v2##m; \ - v3addv4 = v3##m + v4##m; \ - v3subv4 = v3##m - v4##m; \ - v5addv6 = v5##m + v6##m; \ - v5subv6 = v5##m - v6##m; \ - auto t0##m = v0##m + v1addv2 + v3addv4 + v5addv6; \ - auto t1##m = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f; \ - auto t2##m = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f; \ - auto t3##m = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f; \ - auto t4##m = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; \ - auto t5##m = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + v7##m; + /* + * v1addv2 = v1##m + v2##m; + * v1subv2 = v1##m - v2##m; + * v3addv4 = v3##m + v4##m; + * v3subv4 = v3##m - v4##m; + * v5addv6 = v5##m + v6##m; + * v5subv6 = v5##m - v6##m; + * auto t0##m = v0##m + v1addv2 + v3addv4 + v5addv6; + * auto t1##m = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f; + * auto t2##m = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f; + * auto t3##m = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f; + * auto t4##m = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; + * auto t5##m = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + v7##m; + */ + GI_FLOAT32_t v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, v5subv6; +#define cb(m) \ + v1addv2 = ADDF(v1##m, v2##m); \ + v1subv2 = SUBF(v1##m, v2##m); \ + v3addv4 = ADDF(v3##m, v4##m); \ + v3subv4 = SUBF(v3##m, v4##m); \ + v5addv6 = ADDF(v5##m, v6##m); \ + v5subv6 = SUBF(v5##m, v6##m); \ + auto t0##m = ADDF(ADDF(ADDF(v0##m, v1addv2), v3addv4), v5addv6); \ + auto t1##m = ADDF(ADDF(v1subv2, MULSF(v3subv4, 2.f)), MULSF(v5subv6, 0.5f)); \ + auto t2##m = ADDF(ADDF(v1addv2, MULSF(v3addv4, 4.f)), MULSF(v5addv6, 0.25f)); \ + auto t3##m = ADDF(ADDF(v1subv2, MULSF(v3subv4, 8.f)), MULSF(v5subv6, 0.125f)); \ + auto t4##m = ADDF(ADDF(v1addv2, MULSF(v3addv4, 16.f)), MULSF(v5addv6, 0.0625f)); \ + auto t5##m = \ + ADDF(ADDF(ADDF(v1subv2, MULSF(v3subv4, 32.f)), MULSF(v5subv6, 0.03125f)), \ + v7##m); UNROLL_CALL_NOWRAPPER(8, cb); #undef cb -#define cb(m) \ - v1addv2 = t##m##1 + t##m##2; \ - v1subv2 = t##m##1 - t##m##2; \ - v3addv4 = t##m##3 + t##m##4; \ - v3subv4 = t##m##3 - t##m##4; \ - v5addv6 = t##m##5 + t##m##6; \ - v5subv6 = t##m##5 - t##m##6; \ - v##m##0 = t##m##0 + v1addv2 + v3addv4 + v5addv6; \ - v##m##1 = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f; \ - v##m##2 = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f; \ - v##m##3 = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f; \ - v##m##4 = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; \ - v##m##5 = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + t##m##7; + /* + * v1addv2 = t##m##1 + t##m##2; + * v1subv2 = t##m##1 - t##m##2; + * v3addv4 = t##m##3 + t##m##4; + * v3subv4 = t##m##3 - t##m##4; + * v5addv6 = t##m##5 + t##m##6; + * v5subv6 = t##m##5 - t##m##6; + * v##m##0 = t##m##0 + v1addv2 + v3addv4 + v5addv6; + * v##m##1 = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f; + * v##m##2 = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f; + * v##m##3 = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f; + * v##m##4 = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; + * v##m##5 = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + t##m##7; + */ +#define cb(m) \ + v1addv2 = ADDF(t##m##1, t##m##2); \ + v1subv2 = SUBF(t##m##1, t##m##2); \ + v3addv4 = ADDF(t##m##3, t##m##4); \ + v3subv4 = SUBF(t##m##3, t##m##4); \ + v5addv6 = ADDF(t##m##5, t##m##6); \ + v5subv6 = SUBF(t##m##5, t##m##6); \ + v##m##0 = ADDF(ADDF(ADDF(t##m##0, v1addv2), v3addv4), v5addv6); \ + v##m##1 = ADDF(ADDF(v1subv2, MULSF(v3subv4, 2.f)), MULSF(v5subv6, 0.5f)); \ + v##m##2 = ADDF(ADDF(v1addv2, MULSF(v3addv4, 4.f)), MULSF(v5addv6, 0.25f)); \ + v##m##3 = ADDF(ADDF(v1subv2, MULSF(v3subv4, 8.f)), MULSF(v5subv6, 0.125f)); \ + v##m##4 = ADDF(ADDF(v1addv2, MULSF(v3addv4, 16.f)), MULSF(v5addv6, 0.0625f)); \ + v##m##5 = \ + ADDF(ADDF(ADDF(v1subv2, MULSF(v3subv4, 32.f)), MULSF(v5subv6, 0.03125f)), \ + t##m##7); UNROLL_CALL_NOWRAPPER(6, cb); #undef cb - Vector vbias; + GI_FLOAT32_t vbias; if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { - vbias = Vector::load(bias + oc); + vbias = GiLoadFloat32(bias + oc); -#define cb(m, n) v##m##n += vbias; +#define cb(m, n) v##m##n = ADDF(v##m##n, vbias); UNROLL_CALL_RAW_D2(6, 6, cb); #undef cb } if (bmode != BiasMode::BIAS) { -#define cb(m, n) v##m##n = op(GiFixLenType2GiFloat32Type(CONCAT(v##m, n).value)); +#define cb(m, n) v##m##n = op(CONCAT(v##m, n)); UNROLL_CALL_RAW_D2(6, 6, cb); #undef cb } @@ -335,12 +366,15 @@ struct OutputTransformF63_NCHW44 { size_t ow = ow_start + owo; \ if (oh < OH && ow < OW) { \ if (bmode == BiasMode::BIAS) { \ - v##oho##owo += Vector::load( \ - bias + oc * OH * OW + oh * OW * pack_size + ow * pack_size); \ - v##oho##owo = op(GiFixLenType2GiFloat32Type(v##oho##owo.value)); \ + v##oho##owo = ADDF( \ + v##oho##owo, GiLoadFloat32( \ + bias + oc * OH * OW + \ + oh * OW * pack_size + ow * pack_size)); \ + v##oho##owo = op(v##oho##owo); \ } \ - v##oho##owo.save( \ - output + oc * OH * OW + oh * OW * pack_size + ow * pack_size); \ + GiStoreFloat32( \ + output + oc * OH * OW + oh * OW * pack_size + ow * pack_size, \ + v##oho##owo); \ } \ } while (0); UNROLL_CALL_RAW_D2(6, 6, out_save); @@ -387,35 +421,52 @@ void winograd_F63_mk4_f_nchw44::filter( pack_size * pack_size + ic_inner * pack_size; -#define cb(m, n) \ - Vector g##m##n = Vector::load( \ - fptr + (m * KERNEL_SIZE + n) * pack_size * pack_size); +#define cb(m, n) \ + GI_FLOAT32_t g##m##n = \ + GiLoadFloat32(fptr + (m * KERNEL_SIZE + n) * pack_size * pack_size); UNROLL_CALL_NOWRAPPER_D2(3, 3, cb) #undef cb -#define FILTER_TRANSFORM(n, wd, g) \ - auto wd##n##0 = g##0##n; \ - tmp0 = (g##0##n + g##2##n) * -0.2222222f; \ - tmp1 = g##1##n * -0.2222222f; \ - auto wd##n##1 = tmp0 + tmp1; \ - auto wd##n##2 = tmp0 - tmp1; \ - tmp0 = g##0##n * 0.0111111f + g##2##n * 0.0444444f; \ - tmp1 = g##1##n * 0.0222222f; \ - auto wd##n##3 = tmp0 + tmp1; \ - auto wd##n##4 = tmp0 - tmp1; \ - tmp0 = g##0##n * 0.7111111f + g##2##n * 0.1777778f; \ - tmp1 = g##1##n * 0.3555556f; \ - auto wd##n##5 = tmp0 + tmp1; \ - auto wd##n##6 = tmp0 - tmp1; \ + /* + * auto wd##n##0 = g##0##n; + * tmp0 = (g##0##n + g##2##n) * -0.2222222f; + * tmp1 = g##1##n * -0.2222222f; + * auto wd##n##1 = tmp0 + tmp1; + * auto wd##n##2 = tmp0 - tmp1; + * tmp0 = g##0##n * 0.0111111f + g##2##n * 0.0444444f; + * tmp1 = g##1##n * 0.0222222f; + * auto wd##n##3 = tmp0 + tmp1; + * auto wd##n##4 = tmp0 - tmp1; + * tmp0 = g##0##n * 0.7111111f + g##2##n * 0.1777778f; + * tmp1 = g##1##n * 0.3555556f; + * auto wd##n##5 = tmp0 + tmp1; + * auto wd##n##6 = tmp0 - tmp1; + * auto wd##n##7 = g##2##n; + */ +#define FILTER_TRANSFORM(n, wd, g) \ + auto wd##n##0 = g##0##n; \ + tmp0 = MULSF(ADDF(g##0##n, g##2##n), -0.2222222f); \ + tmp1 = MULSF(g##1##n, -0.2222222f); \ + auto wd##n##1 = ADDF(tmp0, tmp1); \ + auto wd##n##2 = SUBF(tmp0, tmp1); \ + tmp0 = ADDF(MULSF(g##0##n, 0.0111111f), MULSF(g##2##n, 0.0444444f)); \ + tmp1 = MULSF(g##1##n, 0.0222222f); \ + auto wd##n##3 = ADDF(tmp0, tmp1); \ + auto wd##n##4 = SUBF(tmp0, tmp1); \ + tmp0 = ADDF(MULSF(g##0##n, 0.7111111f), MULSF(g##2##n, 0.1777778f)); \ + tmp1 = MULSF(g##1##n, 0.3555556f); \ + auto wd##n##5 = ADDF(tmp0, tmp1); \ + auto wd##n##6 = SUBF(tmp0, tmp1); \ auto wd##n##7 = g##2##n; - Vector tmp0, tmp1; + GI_FLOAT32_t tmp0, tmp1; UNROLL_CALL_RAW(3, FILTER_TRANSFORM, wd, g); UNROLL_CALL_RAW(8, FILTER_TRANSFORM, ret, wd); #undef FILTER_TRANSFORM #define cb_save(m, n) \ - ret##m##n.save( \ + GiStoreFloat32( \ filter_transform_buf + (m * alpha + n) * OC * IC + ocb * IC * pack_size + \ - icb * pack_size * pack_size + ic_inner * pack_size); + icb * pack_size * pack_size + ic_inner * pack_size, \ + ret##m##n); UNROLL_CALL_NOWRAPPER_D2(8, 8, cb_save) #undef cb_save } diff --git a/dnn/src/fallback/conv_bias/gi/fp32/strategy_f73_mk4_nchw44.cpp b/dnn/src/fallback/conv_bias/gi/fp32/strategy_f73_mk4_nchw44.cpp index 087a738c..419df03e 100644 --- a/dnn/src/fallback/conv_bias/gi/fp32/strategy_f73_mk4_nchw44.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/strategy_f73_mk4_nchw44.cpp @@ -4,7 +4,6 @@ #include "src/fallback/conv_bias/gi/fp32/filter_transform.h" #include "src/fallback/conv_bias/gi/fp32/helper.h" #include "src/fallback/conv_bias/gi/fp32/strategy.h" -#include "src/fallback/conv_bias/gi/utils.h" #include "src/fallback/conv_bias/winograd/winograd.h" #include "src/fallback/elemwise_helper/op_unary.h" @@ -137,14 +136,14 @@ struct InputTransformF73_NCHW44 { auto t##i##6 = d7; \ auto t##i##7 = d7; \ t##i##8 = MSUB(t##i##8, d7, v0, 0); \ - t##i##0 = GiSubtractFloat32(t##i##0, d1); \ + t##i##0 = SUBF(t##i##0, d1); \ t##i##1 = MSUB(t##i##1, d1, v0, 0); \ t##i##2 = MADD(t##i##2, d1, v0, 0); \ t##i##3 = MSUB(t##i##3, d1, v0, 1); \ t##i##4 = MADD(t##i##4, d1, v0, 1); \ t##i##5 = MSUB(t##i##5, d1, v0, 2); \ t##i##6 = MADD(t##i##6, d1, v0, 2); \ - t##i##7 = GiSubtractFloat32(t##i##7, d1); \ + t##i##7 = SUBF(t##i##7, d1); \ t##i##8 = MADD(t##i##8, d1, v0, 0); \ t##i##0 = MSUB(t##i##0, d2, v0, 3); \ t##i##1 = MSUB(t##i##1, d2, v1, 0); \ @@ -153,7 +152,7 @@ struct InputTransformF73_NCHW44 { t##i##4 = MSUB(t##i##4, d2, v1, 3); \ t##i##5 = MSUB(t##i##5, d2, v2, 0); \ t##i##6 = MSUB(t##i##6, d2, v2, 1); \ - t##i##8 = GiSubtractFloat32(t##i##8, d2); \ + t##i##8 = SUBF(t##i##8, d2); \ t##i##0 = MADD(t##i##0, d3, v2, 2); \ t##i##1 = MADD(t##i##1, d3, v2, 3); \ t##i##2 = MSUB(t##i##2, d3, v3, 0); \ @@ -185,7 +184,7 @@ struct InputTransformF73_NCHW44 { t##i##2 = MSUB(t##i##2, d6, v1, 1); \ t##i##3 = MADD(t##i##3, d6, v1, 0); \ t##i##4 = MSUB(t##i##4, d6, v3, 1); \ - t##i##5 = GiSubtractFloat32(t##i##5, d6); \ + t##i##5 = SUBF(t##i##5, d6); \ t##i##6 = MSUB(t##i##6, d6, v6, 2); \ t##i##8 = MSUB(t##i##8, d6, v2, 2); \ t##i##0 = MADD(t##i##0, d0, v0, 0); @@ -204,14 +203,14 @@ struct InputTransformF73_NCHW44 { d6 = t7##i; \ d7 = t7##i; \ d8 = MSUB(d8, t7##i, v0, 0); \ - d0 = GiSubtractFloat32(d0, t1##i); \ + d0 = SUBF(d0, t1##i); \ d1 = MSUB(d1, t1##i, v0, 0); \ d2 = MADD(d2, t1##i, v0, 0); \ d3 = MSUB(d3, t1##i, v0, 1); \ d4 = MADD(d4, t1##i, v0, 1); \ d5 = MSUB(d5, t1##i, v0, 2); \ d6 = MADD(d6, t1##i, v0, 2); \ - d7 = GiSubtractFloat32(d7, t1##i); \ + d7 = SUBF(d7, t1##i); \ d8 = MADD(d8, t1##i, v0, 0); \ d0 = MSUB(d0, t2##i, v0, 3); \ d1 = MSUB(d1, t2##i, v1, 0); \ @@ -220,7 +219,7 @@ struct InputTransformF73_NCHW44 { d4 = MSUB(d4, t2##i, v1, 3); \ d5 = MSUB(d5, t2##i, v2, 0); \ d6 = MSUB(d6, t2##i, v2, 1); \ - d8 = GiSubtractFloat32(d8, t2##i); \ + d8 = SUBF(d8, t2##i); \ d0 = MADD(d0, t3##i, v2, 2); \ d1 = MADD(d1, t3##i, v2, 3); \ d2 = MSUB(d2, t3##i, v3, 0); \ @@ -252,7 +251,7 @@ struct InputTransformF73_NCHW44 { d2 = MSUB(d2, t6##i, v1, 1); \ d3 = MADD(d3, t6##i, v1, 0); \ d4 = MSUB(d4, t6##i, v3, 1); \ - d5 = GiSubtractFloat32(d5, t6##i); \ + d5 = SUBF(d5, t6##i); \ d6 = MSUB(d6, t6##i, v6, 2); \ d8 = MSUB(d8, t6##i, v2, 2); \ d0 = MADD(d0, t0##i, v0, 0); \ @@ -325,7 +324,7 @@ struct OutputTransformF73_NCHW44 { size_t ocb = oc_index / pack_size; #define cb(m, n) \ - auto v##m##n = Vector::load( \ + auto v##m##n = GiLoadFloat32( \ output_transform_buf + \ (m * alpha + n) * OCB * nr_units_in_tile * pack_size + \ ocb * nr_units_in_tile * pack_size + unit_idx * pack_size); @@ -346,56 +345,112 @@ struct OutputTransformF73_NCHW44 { * 1 1.5 2.25 3.375 5.0625 7.59375 11.390625 * 0 0 0 0 0 0 1 */ + /* + * v1addv2 = v1##m + v2##m; + * v1subv2 = v1##m - v2##m; + * v3addv4 = v3##m + v4##m; + * v3subv4 = v3##m - v4##m; + * v5addv6 = v5##m + v6##m; + * v5subv6 = v5##m - v6##m; + * auto t0##m = v0##m + v1addv2 + v3addv4 + v5addv6 + v7##m; + * auto t1##m = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f + v7##m * 1.5f; + * auto t2##m = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f + v7##m * 2.25f; + * auto t3##m = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f + v7##m * 3.375f; + * auto t4##m = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f + v7##m * 5.0625f; + * auto t5##m = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + v7##m + * * 7.59375f; auto t6##m = v1addv2 + v3addv4 * 64.f + v5addv6 * 0.015625f + + * v7##m * 11.390625f + v8##m; + */ - Vector v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, v5subv6; + GI_FLOAT32_t v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, v5subv6; #define cb(m) \ - v1addv2 = v1##m + v2##m; \ - v1subv2 = v1##m - v2##m; \ - v3addv4 = v3##m + v4##m; \ - v3subv4 = v3##m - v4##m; \ - v5addv6 = v5##m + v6##m; \ - v5subv6 = v5##m - v6##m; \ - auto t0##m = v0##m + v1addv2 + v3addv4 + v5addv6 + v7##m; \ - auto t1##m = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f + v7##m * 1.5f; \ - auto t2##m = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f + v7##m * 2.25f; \ - auto t3##m = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f + v7##m * 3.375f; \ - auto t4##m = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f + v7##m * 5.0625f; \ - auto t5##m = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + v7##m * 7.59375f; \ - auto t6##m = v1addv2 + v3addv4 * 64.f + v5addv6 * 0.015625f + v7##m * 11.390625f + \ - v8##m; + v1addv2 = ADDF(v1##m, v2##m); \ + v1subv2 = SUBF(v1##m, v2##m); \ + v3addv4 = ADDF(v3##m, v4##m); \ + v3subv4 = SUBF(v3##m, v4##m); \ + v5addv6 = ADDF(v5##m, v6##m); \ + v5subv6 = SUBF(v5##m, v6##m); \ + auto t0##m = ADDF(ADDF(ADDF(ADDF(v0##m, v1addv2), v3addv4), v5addv6), v7##m); \ + auto t1##m = \ + ADDF(ADDF(ADDF(v1subv2, MULSF(v3subv4, 2.f)), MULSF(v5subv6, 0.5f)), \ + MULSF(v7##m, 1.5f)); \ + auto t2##m = \ + ADDF(ADDF(ADDF(v1addv2, MULSF(v3addv4, 4.f)), MULSF(v5addv6, 0.25f)), \ + MULSF(v7##m, 2.25f)); \ + auto t3##m = \ + ADDF(ADDF(ADDF(v1subv2, MULSF(v3subv4, 8.f)), MULSF(v5subv6, 0.125f)), \ + MULSF(v7##m, 3.375f)); \ + auto t4##m = \ + ADDF(ADDF(ADDF(v1addv2, MULSF(v3addv4, 16.f)), MULSF(v5addv6, 0.0625f)), \ + MULSF(v7##m, 5.0625f)); \ + auto t5##m = \ + ADDF(ADDF(ADDF(v1subv2, MULSF(v3subv4, 32.f)), MULSF(v5subv6, 0.03125f)), \ + MULSF(v7##m, 7.59375f)); \ + auto t6##m = ADDF( \ + ADDF(ADDF(ADDF(v1addv2, MULSF(v3addv4, 64.f)), MULSF(v5addv6, 0.015625f)), \ + MULSF(v7##m, 11.390625f)), \ + v8##m); UNROLL_CALL_NOWRAPPER(9, cb); #undef cb -#define cb(m) \ - v1addv2 = t##m##1 + t##m##2; \ - v1subv2 = t##m##1 - t##m##2; \ - v3addv4 = t##m##3 + t##m##4; \ - v3subv4 = t##m##3 - t##m##4; \ - v5addv6 = t##m##5 + t##m##6; \ - v5subv6 = t##m##5 - t##m##6; \ - v##m##0 = t##m##0 + v1addv2 + v3addv4 + v5addv6 + t##m##7; \ - v##m##1 = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f + t##m##7 * 1.5f; \ - v##m##2 = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f + t##m##7 * 2.25f; \ - v##m##3 = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f + t##m##7 * 3.375; \ - v##m##4 = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f + t##m##7 * 5.0625f; \ - v##m##5 = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + t##m##7 * 7.59375f; \ - v##m##6 = v1addv2 + v3addv4 * 64.f + v5addv6 * 0.015625f + t##m##7 * 11.390625f + \ - t##m##8; + /* + * v1addv2 = t##m##1 + t##m##2; + * v1subv2 = t##m##1 - t##m##2; + * v3addv4 = t##m##3 + t##m##4; + * v3subv4 = t##m##3 - t##m##4; + * v5addv6 = t##m##5 + t##m##6; + * v5subv6 = t##m##5 - t##m##6; + * v##m##0 = t##m##0 + v1addv2 + v3addv4 + v5addv6 + t##m##7; + * v##m##1 = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f + t##m##7 * 1.5f; + * v##m##2 = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f + t##m##7 * 2.25f; + * v##m##3 = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f + t##m##7 * 3.375; + * v##m##4 = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f + t##m##7 * 5.0625f; + * v##m##5 = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + t##m##7 * 7.59375f; + * v##m##6 = v1addv2 + v3addv4 * 64.f + v5addv6 * 0.015625f + t##m##7 + * * 11.390625f + t##m##8; + */ +#define cb(m) \ + v1addv2 = ADDF(t##m##1, t##m##2); \ + v1subv2 = SUBF(t##m##1, t##m##2); \ + v3addv4 = ADDF(t##m##3, t##m##4); \ + v3subv4 = SUBF(t##m##3, t##m##4); \ + v5addv6 = ADDF(t##m##5, t##m##6); \ + v5subv6 = SUBF(t##m##5, t##m##6); \ + v##m##0 = ADDF(ADDF(ADDF(ADDF(t##m##0, v1addv2), v3addv4), v5addv6), t##m##7); \ + v##m##1 = \ + ADDF(ADDF(ADDF(v1subv2, MULSF(v3subv4, 2.f)), MULSF(v5subv6, 0.5f)), \ + MULSF(t##m##7, 1.5f)); \ + v##m##2 = \ + ADDF(ADDF(ADDF(v1addv2, MULSF(v3addv4, 4.f)), MULSF(v5addv6, 0.25f)), \ + MULSF(t##m##7, 2.25f)); \ + v##m##3 = \ + ADDF(ADDF(ADDF(v1subv2, MULSF(v3subv4, 8.f)), MULSF(v5subv6, 0.125f)), \ + MULSF(t##m##7, 3.375)); \ + v##m##4 = \ + ADDF(ADDF(ADDF(v1addv2, MULSF(v3addv4, 16.f)), MULSF(v5addv6, 0.0625f)), \ + MULSF(t##m##7, 5.0625f)); \ + v##m##5 = \ + ADDF(ADDF(ADDF(v1subv2, MULSF(v3subv4, 32.f)), MULSF(v5subv6, 0.03125f)), \ + MULSF(t##m##7, 7.59375f)); \ + v##m##6 = ADDF( \ + ADDF(ADDF(ADDF(v1addv2, MULSF(v3addv4, 64.f)), MULSF(v5addv6, 0.015625f)), \ + MULSF(t##m##7, 11.390625f)), \ + t##m##8); UNROLL_CALL_NOWRAPPER(7, cb); #undef cb - Vector vbias; + GI_FLOAT32_t vbias; if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { - vbias = Vector::load(bias + oc); + vbias = GiLoadFloat32(bias + oc); -#define cb(m, n) v##m##n += vbias; +#define cb(m, n) v##m##n = ADDF(v##m##n, vbias); UNROLL_CALL_RAW_D2(7, 7, cb); #undef cb } if (bmode != BiasMode::BIAS) { -#define cb(m, n) v##m##n = op(GiFixLenType2GiFloat32Type(CONCAT(v##m, n).value)); +#define cb(m, n) v##m##n = op(CONCAT(v##m, n)); UNROLL_CALL_RAW_D2(7, 7, cb); #undef cb } @@ -405,12 +460,15 @@ struct OutputTransformF73_NCHW44 { size_t ow = ow_start + owo; \ if (oh < OH && ow < OW) { \ if (bmode == BiasMode::BIAS) { \ - v##oho##owo += Vector::load( \ - bias + oc * OH * OW + oh * OW * pack_size + ow * pack_size); \ - v##oho##owo = op(GiFixLenType2GiFloat32Type(v##oho##owo.value)); \ + v##oho##owo = ADDF( \ + v##oho##owo, GiLoadFloat32( \ + bias + oc * OH * OW + \ + oh * OW * pack_size + ow * pack_size)); \ + v##oho##owo = op(v##oho##owo); \ } \ - v##oho##owo.save( \ - output + oc * OH * OW + oh * OW * pack_size + ow * pack_size); \ + GiStoreFloat32( \ + output + oc * OH * OW + oh * OW * pack_size + ow * pack_size, \ + v##oho##owo); \ } \ } while (0); UNROLL_CALL_RAW_D2(7, 7, out_save); @@ -458,34 +516,53 @@ void winograd_F73_mk4_f_nchw44::filter( pack_size * pack_size + ic_inner * pack_size; -#define cb(m, n) \ - Vector g##m##n = Vector::load( \ - fptr + (m * KERNEL_SIZE + n) * pack_size * pack_size); +#define cb(m, n) \ + GI_FLOAT32_t g##m##n = \ + GiLoadFloat32(fptr + (m * KERNEL_SIZE + n) * pack_size * pack_size); UNROLL_CALL_NOWRAPPER_D2(3, 3, cb) #undef cb -#define FILTER_TRANSFORM(n, wd, g) \ - auto wd##n##0 = g##0##n * 0.6666667f; \ - auto wd##n##1 = (g##0##n + g##1##n + g##2##n) * 0.4444444f; \ - auto wd##n##2 = (g##0##n - g##1##n + g##2##n) * 0.0888889f; \ - auto wd##n##3 = \ - g##0##n * 0.0222222f + g##1##n * 0.0444444f + g##2##n * 0.0888889f; \ - auto wd##n##4 = \ - g##0##n * -0.0031746f + g##1##n * 0.0063492f + g##2##n * -0.0126984f; \ - auto wd##n##5 = \ - g##0##n * -0.7111111f + g##1##n * -0.3555556f + g##2##n * -0.1777778f; \ - auto wd##n##6 = \ - g##0##n * -0.3555556f + g##1##n * 0.1777778f + g##2##n * -0.0888889f; \ - auto wd##n##7 = \ - g##0##n * -0.1523810f + g##1##n * -0.2285714f + g##2##n * -0.3428572f; \ +/* + * auto wd##n##0 = g##0##n * 0.6666667f; + * auto wd##n##1 = (g##0##n + g##1##n + g##2##n) * 0.4444444f; + * auto wd##n##2 = (g##0##n - g##1##n + g##2##n) * 0.0888889f; + * auto wd##n##3 = + * g##0##n * 0.0222222f + g##1##n * 0.0444444f + g##2##n * + * 0.0888889f; auto wd##n##4 = g##0##n * -0.0031746f + g##1##n * + * 0.0063492f + g##2##n * -0.0126984f; auto wd##n##5 = g##0##n * + * -0.7111111f + g##1##n * -0.3555556f + g##2##n * -0.1777778f; auto + * wd##n##6 = g##0##n * -0.3555556f + g##1##n * 0.1777778f + g##2##n * + * -0.0888889f; auto wd##n##7 = g##0##n * -0.1523810f + g##1##n * + * -0.2285714f + g##2##n * -0.3428572f; + */ +#define FILTER_TRANSFORM(n, wd, g) \ + auto wd##n##0 = MULSF(g##0##n, 0.6666667f); \ + auto wd##n##1 = MULSF(ADDF(ADDF(g##0##n, g##1##n), g##2##n), 0.4444444f); \ + auto wd##n##2 = MULSF(ADDF(SUBF(g##0##n, g##1##n), g##2##n), 0.0888889f); \ + auto wd##n##3 = \ + ADDF(ADDF(MULSF(g##0##n, 0.0222222f), MULSF(g##1##n, 0.0444444f)), \ + MULSF(g##2##n, 0.0888889f)); \ + auto wd##n##4 = \ + ADDF(ADDF(MULSF(g##0##n, -0.0031746f), MULSF(g##1##n, 0.0063492f)), \ + MULSF(g##2##n, -0.0126984f)); \ + auto wd##n##5 = \ + ADDF(ADDF(MULSF(g##0##n, -0.7111111f), MULSF(g##1##n, -0.3555556f)), \ + MULSF(g##2##n, -0.1777778f)); \ + auto wd##n##6 = \ + ADDF(ADDF(MULSF(g##0##n, -0.3555556f), MULSF(g##1##n, 0.1777778f)), \ + MULSF(g##2##n, -0.0888889f)); \ + auto wd##n##7 = \ + ADDF(ADDF(MULSF(g##0##n, -0.1523810f), MULSF(g##1##n, -0.2285714f)), \ + MULSF(g##2##n, -0.3428572f)); \ auto wd##n##8 = g##2##n; UNROLL_CALL_RAW(3, FILTER_TRANSFORM, wd, g); UNROLL_CALL_RAW(9, FILTER_TRANSFORM, ret, wd); #undef FILTER_TRANSFORM #define cb_save(m, n) \ - ret##m##n.save( \ + GiStoreFloat32( \ filter_transform_buf + (m * alpha + n) * OC * IC + ocb * IC * pack_size + \ - icb * pack_size * pack_size + ic_inner * pack_size); + icb * pack_size * pack_size + ic_inner * pack_size, \ + ret##m##n); UNROLL_CALL_NOWRAPPER_D2(9, 9, cb_save) #undef cb_save } diff --git a/dnn/src/fallback/conv_bias/gi/utils.h b/dnn/src/fallback/conv_bias/gi/utils.h deleted file mode 100644 index 5e8ec5fe..00000000 --- a/dnn/src/fallback/conv_bias/gi/utils.h +++ /dev/null @@ -1,215 +0,0 @@ -#pragma once - -#include -#include "src/common/utils.h" -#include "src/fallback/general_intrinsic/gi_float.h" - -namespace megdnn { -namespace fallback { - -template -struct Vector; - -template <> -struct Vector { - GI_FLOAT32_FIXLEN_t value; - Vector() {} - Vector(const float v) { value = GiFloat32Type2FixLenType(GiBroadcastFloat32(v)); } - Vector(const Vector& lr) { value = lr.value; } - Vector(const Vector&& lr) { value = std::move(lr.value); } - Vector(const GI_FLOAT32_t& v) { value = GiFloat32Type2FixLenType(v); } - static Vector load(const float* addr) { - Vector v; - v.value = GiFloat32Type2FixLenType(GiLoadFloat32(addr)); - return v; - } - static void save(float* addr, const Vector& v) { - GiStoreFloat32(addr, GiFixLenType2GiFloat32Type(v.value)); - } - void save(float* addr) { save(addr, *this); } - Vector operator+(const Vector& lr) { - Vector dst; - dst.value = GiFloat32Type2FixLenType(GiAddFloat32( - GiFixLenType2GiFloat32Type(value), - GiFixLenType2GiFloat32Type(lr.value))); - return dst; - } - Vector& operator+=(const Vector& lr) { - value = GiFloat32Type2FixLenType(GiAddFloat32( - GiFixLenType2GiFloat32Type(value), - GiFixLenType2GiFloat32Type(lr.value))); - return *this; - } - Vector operator-(const Vector& lr) { - Vector dst; - dst.value = GiFloat32Type2FixLenType(GiSubtractFloat32( - GiFixLenType2GiFloat32Type(value), - GiFixLenType2GiFloat32Type(lr.value))); - return dst; - } - Vector& operator-=(const Vector& lr) { - value = GiFloat32Type2FixLenType(GiSubtractFloat32( - GiFixLenType2GiFloat32Type(value), - GiFixLenType2GiFloat32Type(lr.value))); - return *this; - } - Vector operator*(float lr) { - Vector dst; - dst.value = GiFloat32Type2FixLenType( - GiMultiplyScalerFloat32(GiFixLenType2GiFloat32Type(value), lr)); - return dst; - } - Vector operator*(const Vector& lr) { - Vector dst; - dst.value = GiFloat32Type2FixLenType(GiMultiplyFloat32( - GiFixLenType2GiFloat32Type(value), - GiFixLenType2GiFloat32Type(lr.value))); - return dst; - } - Vector& operator*=(const Vector& lr) { - value = GiFloat32Type2FixLenType(GiMultiplyFloat32( - GiFixLenType2GiFloat32Type(value), - GiFixLenType2GiFloat32Type(lr.value))); - return *this; - } - Vector& operator=(const Vector& lr) { - value = lr.value; - return *this; - } - Vector& operator=(const Vector&& lr) { - value = std::move(lr.value); - return *this; - } - Vector operator-() { - Vector dst; - dst.value = -value; - return dst; - } -}; - -template <> -struct Vector { - GI_FLOAT32_FIXLEN_V2_t value; - Vector() {} - Vector(const float v) { - value.val[0] = GiFloat32Type2FixLenType(GiBroadcastFloat32(v)); - value.val[1] = GiFloat32Type2FixLenType(GiBroadcastFloat32(v)); - } - Vector(const Vector& lr) { value = lr.value; } - Vector(const Vector&& lr) { value = std::move(lr.value); } - Vector(const GI_FLOAT32_V2_t& v) { value = GiFloat32Type2FixLenV2Type(v); } - static Vector load(const float* addr) { - Vector v; - v.value = GiFloat32Type2FixLenV2Type(GiLoadFloat32V2(addr)); - return v; - } - static void save(float* addr, const Vector& v) { - GiStoreFloat32V2(addr, GiFixLenType2GiFloat32V2Type(v.value)); - } - - void save(float* addr) { save(addr, *this); } - Vector operator+(const Vector& lr) { - Vector dst; - dst.value.val[0] = GiFloat32Type2FixLenType(GiAddFloat32( - GiFixLenType2GiFloat32Type(value.val[0]), - GiFixLenType2GiFloat32Type(lr.value.val[0]))); - dst.value.val[1] = GiFloat32Type2FixLenType(GiAddFloat32( - GiFixLenType2GiFloat32Type(value.val[1]), - GiFixLenType2GiFloat32Type(lr.value.val[1]))); - return dst; - } - Vector& operator+=(const Vector& lr) { - value.val[0] = GiFloat32Type2FixLenType(GiAddFloat32( - GiFixLenType2GiFloat32Type(value.val[0]), - GiFixLenType2GiFloat32Type(lr.value.val[0]))); - value.val[1] = GiFloat32Type2FixLenType(GiAddFloat32( - GiFixLenType2GiFloat32Type(value.val[1]), - GiFixLenType2GiFloat32Type(lr.value.val[1]))); - return *this; - } - Vector& add(const Vector& lr) { - value.val[0] = GiFloat32Type2FixLenType(GiAddFloat32( - GiFixLenType2GiFloat32Type(value.val[0]), - GiFixLenType2GiFloat32Type(lr.value.val[0]))); - value.val[1] = GiFloat32Type2FixLenType(GiAddFloat32( - GiFixLenType2GiFloat32Type(value.val[1]), - GiFixLenType2GiFloat32Type(lr.value.val[1]))); - return *this; - } - Vector operator-(const Vector& lr) { - Vector dst; - dst.value.val[0] = GiFloat32Type2FixLenType(GiSubtractFloat32( - GiFixLenType2GiFloat32Type(value.val[0]), - GiFixLenType2GiFloat32Type(lr.value.val[0]))); - dst.value.val[1] = GiFloat32Type2FixLenType(GiSubtractFloat32( - GiFixLenType2GiFloat32Type(value.val[1]), - GiFixLenType2GiFloat32Type(lr.value.val[1]))); - return dst; - } - Vector& operator-=(const Vector& lr) { - value.val[0] = GiFloat32Type2FixLenType(GiSubtractFloat32( - GiFixLenType2GiFloat32Type(value.val[0]), - GiFixLenType2GiFloat32Type(lr.value.val[0]))); - value.val[1] = GiFloat32Type2FixLenType(GiSubtractFloat32( - GiFixLenType2GiFloat32Type(value.val[1]), - GiFixLenType2GiFloat32Type(lr.value.val[1]))); - return *this; - } - Vector operator*(float lr) { - Vector dst; - dst.value.val[0] = GiFloat32Type2FixLenType( - GiMultiplyScalerFloat32(GiFixLenType2GiFloat32Type(value.val[0]), lr)); - dst.value.val[1] = GiFloat32Type2FixLenType( - GiMultiplyScalerFloat32(GiFixLenType2GiFloat32Type(value.val[1]), lr)); - return dst; - } - //! val + lr * n - Vector& mla(const Vector& lr, float n) { - value.val[0] = GiFloat32Type2FixLenType(GiMultiplyAddScalarFloat32( - GiFixLenType2GiFloat32Type(value.val[0]), - GiFixLenType2GiFloat32Type(lr.value.val[0]), n)); - value.val[1] = GiFloat32Type2FixLenType(GiMultiplyAddScalarFloat32( - GiFixLenType2GiFloat32Type(value.val[1]), - GiFixLenType2GiFloat32Type(lr.value.val[1]), n)); - return *this; - } - - Vector operator*(const Vector& lr) { - Vector dst; - dst.value.val[0] = GiFloat32Type2FixLenType(GiMultiplyFloat32( - GiFixLenType2GiFloat32Type(value.val[0]), - GiFixLenType2GiFloat32Type(lr.value.val[0]))); - dst.value.val[1] = GiFloat32Type2FixLenType(GiMultiplyFloat32( - GiFixLenType2GiFloat32Type(value.val[1]), - GiFixLenType2GiFloat32Type(lr.value.val[1]))); - return dst; - } - Vector& operator*=(const Vector& lr) { - value.val[0] = GiFloat32Type2FixLenType(GiMultiplyFloat32( - GiFixLenType2GiFloat32Type(value.val[0]), - GiFixLenType2GiFloat32Type(lr.value.val[0]))); - value.val[1] = GiFloat32Type2FixLenType(GiMultiplyFloat32( - GiFixLenType2GiFloat32Type(value.val[1]), - GiFixLenType2GiFloat32Type(lr.value.val[1]))); - return *this; - } - Vector& operator=(const Vector& lr) { - value = lr.value; - return *this; - } - Vector& operator=(const Vector&& lr) { - value = std::move(lr.value); - return *this; - } - Vector operator-() { - Vector dst; - dst.value.val[0] = -value.val[0]; - dst.value.val[1] = -value.val[1]; - return dst; - } -}; - -} // namespace fallback -} // namespace megdnn - -// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/general_intrinsic/gi_float.h b/dnn/src/fallback/general_intrinsic/gi_float.h index 8f9edde3..236e669d 100644 --- a/dnn/src/fallback/general_intrinsic/gi_float.h +++ b/dnn/src/fallback/general_intrinsic/gi_float.h @@ -857,6 +857,18 @@ GI_FLOAT32_t GiMultiplyScalerFloat32(GI_FLOAT32_t Vector1, float Scaler) { } GI_FORCEINLINE +GI_FLOAT32_V2_t GiMultiplyScalerFloat32V2(GI_FLOAT32_V2_t Vector1, float Scaler) { + GI_FLOAT32_V2_t ret; + GiSetSubVectorFloat32V2( + ret, 0, + GiMultiplyScalerFloat32(GiGetSubVectorFloat32V2(Vector1, 0), Scaler)); + GiSetSubVectorFloat32V2( + ret, 1, + GiMultiplyScalerFloat32(GiGetSubVectorFloat32V2(Vector1, 1), Scaler)); + return ret; +} + +GI_FORCEINLINE GI_FLOAT32_t GiMultiplyAddFloat32( GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { #if defined(GI_NEON_INTRINSICS) @@ -888,6 +900,23 @@ GI_FLOAT32_t GiMultiplyAddScalarFloat32( } GI_FORCEINLINE +GI_FLOAT32_V2_t GiMultiplyAddScalarFloat32V2( + GI_FLOAT32_V2_t VectorSum, GI_FLOAT32_V2_t Vector, float Scalar) { + GI_FLOAT32_V2_t ret; + GiSetSubVectorFloat32V2( + ret, 0, + GiMultiplyAddScalarFloat32( + GiGetSubVectorFloat32V2(VectorSum, 0), + GiGetSubVectorFloat32V2(Vector, 0), Scalar)); + GiSetSubVectorFloat32V2( + ret, 1, + GiMultiplyAddScalarFloat32( + GiGetSubVectorFloat32V2(VectorSum, 1), + GiGetSubVectorFloat32V2(Vector, 1), Scalar)); + return ret; +} + +GI_FORCEINLINE GI_FLOAT32_t GiMultiplySubScalarFloat32( GI_FLOAT32_t VectorSub, GI_FLOAT32_t Vector, float Scalar) { #if defined(GI_NEON_INTRINSICS) @@ -951,6 +980,26 @@ GI_FLOAT32_t GiDivideFloat32(GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { #endif } +#define OPV2(op) \ + GI_FORCEINLINE \ + GI_FLOAT32_V2_t op##V2(GI_FLOAT32_V2_t Vector1, GI_FLOAT32_V2_t Vector2) { \ + GI_FLOAT32_V2_t ret; \ + GiSetSubVectorFloat32V2( \ + ret, 0, \ + op(GiGetSubVectorFloat32V2(Vector1, 0), \ + GiGetSubVectorFloat32V2(Vector2, 0))); \ + GiSetSubVectorFloat32V2( \ + ret, 1, \ + op(GiGetSubVectorFloat32V2(Vector1, 1), \ + GiGetSubVectorFloat32V2(Vector2, 1))); \ + return ret; \ + } +OPV2(GiAddFloat32); +OPV2(GiSubtractFloat32); +OPV2(GiMultiplyFloat32); +OPV2(GiDivideFloat32); +#undef OPV2 + GI_FORCEINLINE GI_FLOAT32_t GiRecpeSFloat32(GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { #if defined(GI_NEON64_INTRINSICS)