Browse Source

feat(gi/rvv): remove winograd rvv do not use FIXLEN workaround

GitOrigin-RevId: fce5103088
HuaHua404-patch-4
Megvii Engine Team 2 years ago
parent
commit
25e89d68b0
15 changed files with 1175 additions and 908 deletions
  1. +0
    -1
      dnn/src/fallback/conv_bias/gi/fp32/channel_wise_3x3_s1p1_nchw44_kern.cpp
  2. +0
    -1
      dnn/src/fallback/conv_bias/gi/fp32/channel_wise_5x5_s1p2_nchw44_kern.cpp
  3. +0
    -1
      dnn/src/fallback/conv_bias/gi/fp32/channel_wise_nchw44_kern.cpp
  4. +48
    -36
      dnn/src/fallback/conv_bias/gi/fp32/filter_transform.h
  5. +162
    -146
      dnn/src/fallback/conv_bias/gi/fp32/helper.h
  6. +36
    -37
      dnn/src/fallback/conv_bias/gi/fp32/strategy_2x3_4x4.cpp
  7. +201
    -96
      dnn/src/fallback/conv_bias/gi/fp32/strategy_4x5.cpp
  8. +133
    -72
      dnn/src/fallback/conv_bias/gi/fp32/strategy_5x4.cpp
  9. +60
    -47
      dnn/src/fallback/conv_bias/gi/fp32/strategy_6x3.cpp
  10. +161
    -63
      dnn/src/fallback/conv_bias/gi/fp32/strategy_6x3_4x4.cpp
  11. +55
    -51
      dnn/src/fallback/conv_bias/gi/fp32/strategy_f23_mk4_nchw44.cpp
  12. +125
    -74
      dnn/src/fallback/conv_bias/gi/fp32/strategy_f63_mk4_nchw44.cpp
  13. +145
    -68
      dnn/src/fallback/conv_bias/gi/fp32/strategy_f73_mk4_nchw44.cpp
  14. +0
    -215
      dnn/src/fallback/conv_bias/gi/utils.h
  15. +49
    -0
      dnn/src/fallback/general_intrinsic/gi_float.h

+ 0
- 1
dnn/src/fallback/conv_bias/gi/fp32/channel_wise_3x3_s1p1_nchw44_kern.cpp View File

@@ -1,7 +1,6 @@
#include "src/fallback/conv_bias/gi/fp32/channel_wise_3x3_s1p1_nchw44_kern.h" #include "src/fallback/conv_bias/gi/fp32/channel_wise_3x3_s1p1_nchw44_kern.h"
#include "src/common/utils.h" #include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h" #include "src/fallback/conv_bias/common.h"
#include "src/fallback/conv_bias/gi/utils.h"
#include "src/fallback/elemwise_helper/elemwise_op.h" #include "src/fallback/elemwise_helper/elemwise_op.h"


#pragma GCC diagnostic ignored "-Wunused-parameter" #pragma GCC diagnostic ignored "-Wunused-parameter"


+ 0
- 1
dnn/src/fallback/conv_bias/gi/fp32/channel_wise_5x5_s1p2_nchw44_kern.cpp View File

@@ -1,7 +1,6 @@
#include "src/fallback/conv_bias/gi/fp32/channel_wise_5x5_s1p2_nchw44_kern.h" #include "src/fallback/conv_bias/gi/fp32/channel_wise_5x5_s1p2_nchw44_kern.h"
#include "src/common/utils.h" #include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h" #include "src/fallback/conv_bias/common.h"
#include "src/fallback/conv_bias/gi/utils.h"
#include "src/fallback/elemwise_helper/elemwise_op.h" #include "src/fallback/elemwise_helper/elemwise_op.h"


#pragma GCC diagnostic ignored "-Wunused-parameter" #pragma GCC diagnostic ignored "-Wunused-parameter"


+ 0
- 1
dnn/src/fallback/conv_bias/gi/fp32/channel_wise_nchw44_kern.cpp View File

@@ -4,7 +4,6 @@
#include "src/fallback/conv_bias/common.h" #include "src/fallback/conv_bias/common.h"
#include "src/fallback/conv_bias/gi/fp32/channel_wise_3x3_s1p1_nchw44_kern.h" #include "src/fallback/conv_bias/gi/fp32/channel_wise_3x3_s1p1_nchw44_kern.h"
#include "src/fallback/conv_bias/gi/fp32/channel_wise_5x5_s1p2_nchw44_kern.h" #include "src/fallback/conv_bias/gi/fp32/channel_wise_5x5_s1p2_nchw44_kern.h"
#include "src/fallback/conv_bias/gi/utils.h"
#include "src/fallback/elemwise_helper/elemwise_op.h" #include "src/fallback/elemwise_helper/elemwise_op.h"


using namespace megdnn; using namespace megdnn;


+ 48
- 36
dnn/src/fallback/conv_bias/gi/fp32/filter_transform.h View File

@@ -3,29 +3,44 @@
#include "src/common/unroll_macro.h" #include "src/common/unroll_macro.h"
#include "src/common/utils.h" #include "src/common/utils.h"
#include "src/fallback/conv_bias/gi/fp32/helper.h" #include "src/fallback/conv_bias/gi/fp32/helper.h"
#include "src/fallback/conv_bias/gi/utils.h"


namespace megdnn { namespace megdnn {
namespace fallback { namespace fallback {


/*
* wd##0 = d##0;
* tmp0 = (d##0 + d##2) * -0.2222222f
* tmp1 = d##1 * -0.2222222
* wd##1 = tmp0 + tmp1
* wd##2 = tmp0 - tmp1
* tmp0 = d##0 * 0.0111111f + d##2 * 0.0444444f
* tmp1 = d##1 * 0.0222222f
* wd##3 = tmp0 + tmp1
* wd##4 = tmp0 - tmp1
* tmp0 = d##0 * 0.7111111f + d##2 * 0.1777778f
* tmp1 = d##1 * 0.3555556f
* wd##5 = tmp0 + tmp1
* wd##6 = tmp0 - tmp1
* wd##7 = d##2
*/
template <param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT> template <param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT>
struct FilterTransform6X3 { struct FilterTransform6X3 {
#define FILTER_TRANSFORM(d, wd) \
do { \
wd##0 = d##0; \
auto tmp0 = (d##0 + d##2) * -0.2222222f; \
auto tmp1 = d##1 * -0.2222222f; \
wd##1 = tmp0 + tmp1; \
wd##2 = tmp0 - tmp1; \
tmp0 = d##0 * 0.0111111f + d##2 * 0.0444444f; \
tmp1 = d##1 * 0.0222222f; \
wd##3 = tmp0 + tmp1; \
wd##4 = tmp0 - tmp1; \
tmp0 = d##0 * 0.7111111f + d##2 * 0.1777778f; \
tmp1 = d##1 * 0.3555556f; \
wd##5 = tmp0 + tmp1; \
wd##6 = tmp0 - tmp1; \
wd##7 = d##2; \
#define FILTER_TRANSFORM(d, wd, ADDC, SUBC, MULC) \
do { \
wd##0 = d##0; \
auto tmp0 = MULC(ADDC(d##0, d##2), -0.2222222f); \
auto tmp1 = MULC(d##1, -0.2222222f); \
wd##1 = ADDC(tmp0, tmp1); \
wd##2 = SUBC(tmp0, tmp1); \
tmp0 = ADDC(MULC(d##0, 0.0111111f), MULC(d##2, 0.0444444f)); \
tmp1 = MULC(d##1, 0.0222222f); \
wd##3 = ADDC(tmp0, tmp1); \
wd##4 = SUBC(tmp0, tmp1); \
tmp0 = ADDC(MULC(d##0, 0.7111111f), MULC(d##2, 0.1777778f)); \
tmp1 = MULC(d##1, 0.3555556f); \
wd##5 = ADDC(tmp0, tmp1); \
wd##6 = SUBC(tmp0, tmp1); \
wd##7 = d##2; \
} while (0); } while (0);


static void transform( static void transform(
@@ -49,37 +64,35 @@ struct FilterTransform6X3 {
rep(ic, IC) { rep(ic, IC) {
const float* fptr = filter + (oc * IC + ic) * 3 * 3; const float* fptr = filter + (oc * IC + ic) * 3 * 3;


Vector<float, 4> g0 = Vector<float, 4>::load(fptr);
Vector<float, 4> g1 = Vector<float, 4>::load(fptr + 3);
GI_FLOAT32_t g0 = GiLoadFloat32(fptr);
GI_FLOAT32_t g1 = GiLoadFloat32(fptr + 3);


Vector<float, 4> g2 = Vector<float, 4>::load(fptr + 6 - 1);
GI_FLOAT32_t g2 = GiLoadFloat32(fptr + 6 - 1);
GI_FLOAT32_t zeros = GiZeroFloat32(); GI_FLOAT32_t zeros = GiZeroFloat32();
g2.value = GiFloat32Type2FixLenType(
GiExtqFloat32(GiFixLenType2GiFloat32Type(g2.value), zeros, 1));
g2 = GiExtqFloat32(g2, zeros, 1);


#define cb(i) Vector<float, 4> wd##i;
#define cb(i) GI_FLOAT32_t wd##i;
UNROLL_CALL_NOWRAPPER(8, cb); UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb #undef cb


#define cb(i) Vector<float, 8> wdt##i;
UNROLL_CALL_NOWRAPPER(3, cb);
#undef cb

#define cb(i) Vector<float, 8> ret##i;
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb

FILTER_TRANSFORM(g, wd);
FILTER_TRANSFORM(g, wd, ADDF, SUBF, MULSF);


size_t ocb = oc / 4; size_t ocb = oc / 4;
size_t oc4 = oc % 4; size_t oc4 = oc % 4;
size_t icb = ic / 4; size_t icb = ic / 4;
size_t ic4 = ic % 4; size_t ic4 = ic % 4;
#if MEGDNN_AARCH64 #if MEGDNN_AARCH64
#define cb(i) GI_FLOAT32_V2_t wdt##i;
UNROLL_CALL_NOWRAPPER(3, cb);
#undef cb

#define cb(i) GI_FLOAT32_V2_t ret##i;
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
TRANSPOSE_8x3(wd, wdt); TRANSPOSE_8x3(wd, wdt);
FILTER_TRANSFORM(wdt, ret);
FILTER_TRANSFORM(wdt, ret, ADDFV2, SUBFV2, MULSFV2);


#define cb(i) ret##i.save(transform_mid_buf + i * alpha);
#define cb(i) GiStoreFloat32V2(transform_mid_buf + i * alpha, ret##i);
UNROLL_CALL_NOWRAPPER(8, cb); UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb #undef cb
rep(i, alpha) rep(j, alpha) { rep(i, alpha) rep(j, alpha) {
@@ -116,8 +129,7 @@ struct FilterTransform6X3 {
mid_buf1[7] = GET_VECTOR_ELEM(wd, i, 2); \ mid_buf1[7] = GET_VECTOR_ELEM(wd, i, 2); \
mid_buf1 += 8; \ mid_buf1 += 8; \
} while (0); } while (0);
#define GET_VECTOR_ELEM(s, i, idx) \
GiExtractLane##idx##Float32(GiFixLenType2GiFloat32Type(CONCAT(s, i).value))
#define GET_VECTOR_ELEM(s, i, idx) GiExtractLane##idx##Float32(CONCAT(s, i))


float* mid_buf1 = transform_mid_buf; float* mid_buf1 = transform_mid_buf;
UNROLL_CALL_NOWRAPPER(8, cb); UNROLL_CALL_NOWRAPPER(8, cb);


+ 162
- 146
dnn/src/fallback/conv_bias/gi/fp32/helper.h View File

@@ -2,6 +2,15 @@
#include "src/common/unroll_macro.h" #include "src/common/unroll_macro.h"
#include "src/fallback/general_intrinsic/gi_float.h" #include "src/fallback/general_intrinsic/gi_float.h"


#define ADDF GiAddFloat32
#define ADDFV2 GiAddFloat32V2
#define SUBF GiSubtractFloat32
#define SUBFV2 GiSubtractFloat32V2
#define MULF GiMultiplyFloat32
#define MULFV2 GiMultiplyFloat32V2
#define MULSF GiMultiplyScalerFloat32
#define MULSFV2 GiMultiplyScalerFloat32V2

namespace megdnn { namespace megdnn {
namespace fallback { namespace fallback {
inline void transpose_4x4(const float* src, float* dst, int lda, int ldb) { inline void transpose_4x4(const float* src, float* dst, int lda, int ldb) {
@@ -47,159 +56,166 @@ inline void transpose_4x4(const float* src, float* dst, int lda, int ldb) {
#define CONCAT(a, idx) a##idx #define CONCAT(a, idx) a##idx


#if MEGDNN_AARCH64 #if MEGDNN_AARCH64
//! ret and a are type Vector<float, 8>
#define TRANSPOSE_8x8(a, ret) \
do { \
auto b0 = GiZipqFloat32(CONCAT(a, 0).value.val[0], CONCAT(a, 1).value.val[0]); \
auto b1 = GiZipqFloat32(CONCAT(a, 0).value.val[1], CONCAT(a, 1).value.val[1]); \
auto b2 = GiZipqFloat32(CONCAT(a, 2).value.val[0], CONCAT(a, 3).value.val[0]); \
auto b3 = GiZipqFloat32(CONCAT(a, 2).value.val[1], CONCAT(a, 3).value.val[1]); \
auto b4 = GiZipqFloat32(CONCAT(a, 4).value.val[0], CONCAT(a, 5).value.val[0]); \
auto b5 = GiZipqFloat32(CONCAT(a, 4).value.val[1], CONCAT(a, 5).value.val[1]); \
auto b6 = GiZipqFloat32(CONCAT(a, 6).value.val[0], CONCAT(a, 7).value.val[0]); \
auto b7 = GiZipqFloat32(CONCAT(a, 6).value.val[1], CONCAT(a, 7).value.val[1]); \
CONCAT(ret, 0).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
GiReinterpretqFloat32ToS64(b0.val[0]), \
GiReinterpretqFloat32ToS64(b2.val[0]))); \
CONCAT(ret, 0).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
GiReinterpretqFloat32ToS64(b4.val[0]), \
GiReinterpretqFloat32ToS64(b6.val[0]))); \
CONCAT(ret, 1).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \
GiReinterpretqFloat32ToS64(b0.val[0]), \
GiReinterpretqFloat32ToS64(b2.val[0]))); \
CONCAT(ret, 1).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \
GiReinterpretqFloat32ToS64(b4.val[0]), \
GiReinterpretqFloat32ToS64(b6.val[0]))); \
CONCAT(ret, 2).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
GiReinterpretqFloat32ToS64(b0.val[1]), \
GiReinterpretqFloat32ToS64(b2.val[1]))); \
CONCAT(ret, 2).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
GiReinterpretqFloat32ToS64(b4.val[1]), \
GiReinterpretqFloat32ToS64(b6.val[1]))); \
CONCAT(ret, 3).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \
GiReinterpretqFloat32ToS64(b0.val[1]), \
GiReinterpretqFloat32ToS64(b2.val[1]))); \
CONCAT(ret, 3).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \
GiReinterpretqFloat32ToS64(b4.val[1]), \
GiReinterpretqFloat32ToS64(b6.val[1]))); \
CONCAT(ret, 4).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
GiReinterpretqFloat32ToS64(b1.val[0]), \
GiReinterpretqFloat32ToS64(b3.val[0]))); \
CONCAT(ret, 4).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
GiReinterpretqFloat32ToS64(b5.val[0]), \
GiReinterpretqFloat32ToS64(b7.val[0]))); \
CONCAT(ret, 5).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \
GiReinterpretqFloat32ToS64(b1.val[0]), \
GiReinterpretqFloat32ToS64(b3.val[0]))); \
CONCAT(ret, 5).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \
GiReinterpretqFloat32ToS64(b5.val[0]), \
GiReinterpretqFloat32ToS64(b7.val[0]))); \
CONCAT(ret, 6).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
GiReinterpretqFloat32ToS64(b1.val[1]), \
GiReinterpretqFloat32ToS64(b3.val[1]))); \
CONCAT(ret, 6).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
GiReinterpretqFloat32ToS64(b5.val[1]), \
GiReinterpretqFloat32ToS64(b7.val[1]))); \
CONCAT(ret, 7).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \
GiReinterpretqFloat32ToS64(b1.val[1]), \
GiReinterpretqFloat32ToS64(b3.val[1]))); \
CONCAT(ret, 7).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \
GiReinterpretqFloat32ToS64(b5.val[1]), \
GiReinterpretqFloat32ToS64(b7.val[1]))); \
#define TRANSPOSE_8x8(a, ret) \
do { \
auto b0 = GiZipqFloat32(CONCAT(a, 0).val[0], CONCAT(a, 1).val[0]); \
auto b1 = GiZipqFloat32(CONCAT(a, 0).val[1], CONCAT(a, 1).val[1]); \
auto b2 = GiZipqFloat32(CONCAT(a, 2).val[0], CONCAT(a, 3).val[0]); \
auto b3 = GiZipqFloat32(CONCAT(a, 2).val[1], CONCAT(a, 3).val[1]); \
auto b4 = GiZipqFloat32(CONCAT(a, 4).val[0], CONCAT(a, 5).val[0]); \
auto b5 = GiZipqFloat32(CONCAT(a, 4).val[1], CONCAT(a, 5).val[1]); \
auto b6 = GiZipqFloat32(CONCAT(a, 6).val[0], CONCAT(a, 7).val[0]); \
auto b7 = GiZipqFloat32(CONCAT(a, 6).val[1], CONCAT(a, 7).val[1]); \
CONCAT(ret, 0).val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
GiReinterpretqFloat32ToS64(b0.val[0]), \
GiReinterpretqFloat32ToS64(b2.val[0]))); \
CONCAT(ret, 0).val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
GiReinterpretqFloat32ToS64(b4.val[0]), \
GiReinterpretqFloat32ToS64(b6.val[0]))); \
CONCAT(ret, 1).val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \
GiReinterpretqFloat32ToS64(b0.val[0]), \
GiReinterpretqFloat32ToS64(b2.val[0]))); \
CONCAT(ret, 1).val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \
GiReinterpretqFloat32ToS64(b4.val[0]), \
GiReinterpretqFloat32ToS64(b6.val[0]))); \
CONCAT(ret, 2).val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
GiReinterpretqFloat32ToS64(b0.val[1]), \
GiReinterpretqFloat32ToS64(b2.val[1]))); \
CONCAT(ret, 2).val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
GiReinterpretqFloat32ToS64(b4.val[1]), \
GiReinterpretqFloat32ToS64(b6.val[1]))); \
CONCAT(ret, 3).val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \
GiReinterpretqFloat32ToS64(b0.val[1]), \
GiReinterpretqFloat32ToS64(b2.val[1]))); \
CONCAT(ret, 3).val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \
GiReinterpretqFloat32ToS64(b4.val[1]), \
GiReinterpretqFloat32ToS64(b6.val[1]))); \
CONCAT(ret, 4).val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
GiReinterpretqFloat32ToS64(b1.val[0]), \
GiReinterpretqFloat32ToS64(b3.val[0]))); \
CONCAT(ret, 4).val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
GiReinterpretqFloat32ToS64(b5.val[0]), \
GiReinterpretqFloat32ToS64(b7.val[0]))); \
CONCAT(ret, 5).val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \
GiReinterpretqFloat32ToS64(b1.val[0]), \
GiReinterpretqFloat32ToS64(b3.val[0]))); \
CONCAT(ret, 5).val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \
GiReinterpretqFloat32ToS64(b5.val[0]), \
GiReinterpretqFloat32ToS64(b7.val[0]))); \
CONCAT(ret, 6).val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
GiReinterpretqFloat32ToS64(b1.val[1]), \
GiReinterpretqFloat32ToS64(b3.val[1]))); \
CONCAT(ret, 6).val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
GiReinterpretqFloat32ToS64(b5.val[1]), \
GiReinterpretqFloat32ToS64(b7.val[1]))); \
CONCAT(ret, 7).val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \
GiReinterpretqFloat32ToS64(b1.val[1]), \
GiReinterpretqFloat32ToS64(b3.val[1]))); \
CONCAT(ret, 7).val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \
GiReinterpretqFloat32ToS64(b5.val[1]), \
GiReinterpretqFloat32ToS64(b7.val[1]))); \
} while (0); } while (0);


#define TRANSPOSE_8x3(a, ret) \
auto b0 = GiZipqFloat32(CONCAT(a, 0).value, CONCAT(a, 1).value); \
auto b1 = GiZipqFloat32(CONCAT(a, 2).value, CONCAT(a, 3).value); \
auto b2 = GiZipqFloat32(CONCAT(a, 4).value, CONCAT(a, 5).value); \
auto b3 = GiZipqFloat32(CONCAT(a, 6).value, CONCAT(a, 7).value); \
CONCAT(ret, 0).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
GiReinterpretqFloat32ToS64(b0.val[0]), \
GiReinterpretqFloat32ToS64(b1.val[0]))); \
CONCAT(ret, 0).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
GiReinterpretqFloat32ToS64(b2.val[0]), \
GiReinterpretqFloat32ToS64(b3.val[0]))); \
CONCAT(ret, 1).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \
GiReinterpretqFloat32ToS64(b0.val[0]), \
GiReinterpretqFloat32ToS64(b1.val[0]))); \
CONCAT(ret, 1).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \
GiReinterpretqFloat32ToS64(b2.val[0]), \
GiReinterpretqFloat32ToS64(b3.val[0]))); \
CONCAT(ret, 2).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
GiReinterpretqFloat32ToS64(b0.val[1]), \
GiReinterpretqFloat32ToS64(b1.val[1]))); \
CONCAT(ret, 2).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
GiReinterpretqFloat32ToS64(b2.val[1]), \
#define TRANSPOSE_8x3(a, ret) \
auto b0 = GiZipqFloat32(CONCAT(a, 0), CONCAT(a, 1)); \
auto b1 = GiZipqFloat32(CONCAT(a, 2), CONCAT(a, 3)); \
auto b2 = GiZipqFloat32(CONCAT(a, 4), CONCAT(a, 5)); \
auto b3 = GiZipqFloat32(CONCAT(a, 6), CONCAT(a, 7)); \
CONCAT(ret, 0).val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
GiReinterpretqFloat32ToS64(b0.val[0]), \
GiReinterpretqFloat32ToS64(b1.val[0]))); \
CONCAT(ret, 0).val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
GiReinterpretqFloat32ToS64(b2.val[0]), \
GiReinterpretqFloat32ToS64(b3.val[0]))); \
CONCAT(ret, 1).val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \
GiReinterpretqFloat32ToS64(b0.val[0]), \
GiReinterpretqFloat32ToS64(b1.val[0]))); \
CONCAT(ret, 1).val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \
GiReinterpretqFloat32ToS64(b2.val[0]), \
GiReinterpretqFloat32ToS64(b3.val[0]))); \
CONCAT(ret, 2).val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
GiReinterpretqFloat32ToS64(b0.val[1]), \
GiReinterpretqFloat32ToS64(b1.val[1]))); \
CONCAT(ret, 2).val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
GiReinterpretqFloat32ToS64(b2.val[1]), \
GiReinterpretqFloat32ToS64(b3.val[1]))); GiReinterpretqFloat32ToS64(b3.val[1])));


#define TRANSPOSE_8x4(a, ret) \
auto b0 = GiZipqFloat32(CONCAT(a, 0).value, CONCAT(a, 1).value); \
auto b1 = GiZipqFloat32(CONCAT(a, 2).value, CONCAT(a, 3).value); \
auto b2 = GiZipqFloat32(CONCAT(a, 4).value, CONCAT(a, 5).value); \
auto b3 = GiZipqFloat32(CONCAT(a, 6).value, CONCAT(a, 7).value); \
CONCAT(ret, 0).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
GiReinterpretqFloat32ToS64(b0.val[0]), \
GiReinterpretqFloat32ToS64(b1.val[0]))); \
CONCAT(ret, 0).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
GiReinterpretqFloat32ToS64(b2.val[0]), \
GiReinterpretqFloat32ToS64(b3.val[0]))); \
CONCAT(ret, 1).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \
GiReinterpretqFloat32ToS64(b0.val[0]), \
GiReinterpretqFloat32ToS64(b1.val[0]))); \
CONCAT(ret, 1).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \
GiReinterpretqFloat32ToS64(b2.val[0]), \
GiReinterpretqFloat32ToS64(b3.val[0]))); \
CONCAT(ret, 2).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
GiReinterpretqFloat32ToS64(b0.val[1]), \
GiReinterpretqFloat32ToS64(b1.val[1]))); \
CONCAT(ret, 2).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
GiReinterpretqFloat32ToS64(b2.val[1]), \
GiReinterpretqFloat32ToS64(b3.val[1]))); \
CONCAT(ret, 3).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \
GiReinterpretqFloat32ToS64(b0.val[1]), \
GiReinterpretqFloat32ToS64(b1.val[1]))); \
CONCAT(ret, 3).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \
GiReinterpretqFloat32ToS64(b2.val[1]), \
#define TRANSPOSE_8x4(a, ret) \
auto b0 = GiZipqFloat32(CONCAT(a, 0), CONCAT(a, 1)); \
auto b1 = GiZipqFloat32(CONCAT(a, 2), CONCAT(a, 3)); \
auto b2 = GiZipqFloat32(CONCAT(a, 4), CONCAT(a, 5)); \
auto b3 = GiZipqFloat32(CONCAT(a, 6), CONCAT(a, 7)); \
CONCAT(ret, 0).val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
GiReinterpretqFloat32ToS64(b0.val[0]), \
GiReinterpretqFloat32ToS64(b1.val[0]))); \
CONCAT(ret, 0).val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
GiReinterpretqFloat32ToS64(b2.val[0]), \
GiReinterpretqFloat32ToS64(b3.val[0]))); \
CONCAT(ret, 1).val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \
GiReinterpretqFloat32ToS64(b0.val[0]), \
GiReinterpretqFloat32ToS64(b1.val[0]))); \
CONCAT(ret, 1).val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \
GiReinterpretqFloat32ToS64(b2.val[0]), \
GiReinterpretqFloat32ToS64(b3.val[0]))); \
CONCAT(ret, 2).val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
GiReinterpretqFloat32ToS64(b0.val[1]), \
GiReinterpretqFloat32ToS64(b1.val[1]))); \
CONCAT(ret, 2).val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
GiReinterpretqFloat32ToS64(b2.val[1]), \
GiReinterpretqFloat32ToS64(b3.val[1]))); \
CONCAT(ret, 3).val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \
GiReinterpretqFloat32ToS64(b0.val[1]), \
GiReinterpretqFloat32ToS64(b1.val[1]))); \
CONCAT(ret, 3).val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \
GiReinterpretqFloat32ToS64(b2.val[1]), \
GiReinterpretqFloat32ToS64(b3.val[1]))); GiReinterpretqFloat32ToS64(b3.val[1])));


#else #else
#define TRANSPOSE_8x4(a, ret) \
auto b0 = GiZipqFloat32( \
GiFixLenType2GiFloat32Type(CONCAT(a, 0).value), \
GiFixLenType2GiFloat32Type(CONCAT(a, 1).value)); \
auto b1 = GiZipqFloat32( \
GiFixLenType2GiFloat32Type(CONCAT(a, 2).value), \
GiFixLenType2GiFloat32Type(CONCAT(a, 3).value)); \
auto b2 = GiZipqFloat32( \
GiFixLenType2GiFloat32Type(CONCAT(a, 4).value), \
GiFixLenType2GiFloat32Type(CONCAT(a, 5).value)); \
auto b3 = GiZipqFloat32( \
GiFixLenType2GiFloat32Type(CONCAT(a, 6).value), \
GiFixLenType2GiFloat32Type(CONCAT(a, 7).value)); \
CONCAT(ret, 0).value.val[0] = GiFloat32Type2FixLenType(GiCombineFloat32( \
GiGetLowFloat32(GiGetSubVectorFloat32V2(b0, 0)), \
GiGetLowFloat32(GiGetSubVectorFloat32V2(b1, 0)))); \
CONCAT(ret, 1).value.val[0] = GiFloat32Type2FixLenType(GiCombineFloat32( \
GiGetHighFloat32(GiGetSubVectorFloat32V2(b0, 0)), \
GiGetHighFloat32(GiGetSubVectorFloat32V2(b1, 0)))); \
CONCAT(ret, 2).value.val[0] = GiFloat32Type2FixLenType(GiCombineFloat32( \
GiGetLowFloat32(GiGetSubVectorFloat32V2(b0, 1)), \
GiGetLowFloat32(GiGetSubVectorFloat32V2(b1, 1)))); \
CONCAT(ret, 3).value.val[0] = GiFloat32Type2FixLenType(GiCombineFloat32( \
GiGetHighFloat32(GiGetSubVectorFloat32V2(b0, 1)), \
GiGetHighFloat32(GiGetSubVectorFloat32V2(b1, 1)))); \
CONCAT(ret, 0).value.val[1] = GiFloat32Type2FixLenType(GiCombineFloat32( \
GiGetLowFloat32(GiGetSubVectorFloat32V2(b2, 0)), \
GiGetLowFloat32(GiGetSubVectorFloat32V2(b3, 0)))); \
CONCAT(ret, 1).value.val[1] = GiFloat32Type2FixLenType(GiCombineFloat32( \
GiGetHighFloat32(GiGetSubVectorFloat32V2(b2, 0)), \
GiGetHighFloat32(GiGetSubVectorFloat32V2(b3, 0)))); \
CONCAT(ret, 2).value.val[1] = GiFloat32Type2FixLenType(GiCombineFloat32( \
GiGetLowFloat32(GiGetSubVectorFloat32V2(b2, 1)), \
GiGetLowFloat32(GiGetSubVectorFloat32V2(b3, 1)))); \
CONCAT(ret, 3).value.val[1] = GiFloat32Type2FixLenType(GiCombineFloat32( \
GiGetHighFloat32(GiGetSubVectorFloat32V2(b2, 1)), \
GiGetHighFloat32(GiGetSubVectorFloat32V2(b3, 1))));
#define TRANSPOSE_8x4(a, ret) \
auto b0 = GiZipqFloat32(CONCAT(a, 0), CONCAT(a, 1)); \
auto b1 = GiZipqFloat32(CONCAT(a, 2), CONCAT(a, 3)); \
auto b2 = GiZipqFloat32(CONCAT(a, 4), CONCAT(a, 5)); \
auto b3 = GiZipqFloat32(CONCAT(a, 6), CONCAT(a, 7)); \
GiSetSubVectorFloat32V2( \
CONCAT(ret, 0), 0, \
GiCombineFloat32( \
GiGetLowFloat32(GiGetSubVectorFloat32V2(b0, 0)), \
GiGetLowFloat32(GiGetSubVectorFloat32V2(b1, 0)))); \
GiSetSubVectorFloat32V2( \
CONCAT(ret, 1), 0, \
GiCombineFloat32( \
GiGetHighFloat32(GiGetSubVectorFloat32V2(b0, 0)), \
GiGetHighFloat32(GiGetSubVectorFloat32V2(b1, 0)))); \
GiSetSubVectorFloat32V2( \
CONCAT(ret, 2), 0, \
GiCombineFloat32( \
GiGetLowFloat32(GiGetSubVectorFloat32V2(b0, 1)), \
GiGetLowFloat32(GiGetSubVectorFloat32V2(b1, 1)))); \
GiSetSubVectorFloat32V2( \
CONCAT(ret, 3), 0, \
GiCombineFloat32( \
GiGetHighFloat32(GiGetSubVectorFloat32V2(b0, 1)), \
GiGetHighFloat32(GiGetSubVectorFloat32V2(b1, 1)))); \
GiSetSubVectorFloat32V2( \
CONCAT(ret, 0), 1, \
GiCombineFloat32( \
GiGetLowFloat32(GiGetSubVectorFloat32V2(b2, 0)), \
GiGetLowFloat32(GiGetSubVectorFloat32V2(b3, 0)))); \
GiSetSubVectorFloat32V2( \
CONCAT(ret, 1), 1, \
GiCombineFloat32( \
GiGetHighFloat32(GiGetSubVectorFloat32V2(b2, 0)), \
GiGetHighFloat32(GiGetSubVectorFloat32V2(b3, 0)))); \
GiSetSubVectorFloat32V2( \
CONCAT(ret, 2), 1, \
GiCombineFloat32( \
GiGetLowFloat32(GiGetSubVectorFloat32V2(b2, 1)), \
GiGetLowFloat32(GiGetSubVectorFloat32V2(b3, 1)))); \
GiSetSubVectorFloat32V2( \
CONCAT(ret, 3), 1, \
GiCombineFloat32( \
GiGetHighFloat32(GiGetSubVectorFloat32V2(b2, 1)), \
GiGetHighFloat32(GiGetSubVectorFloat32V2(b3, 1))));


#endif #endif
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen

+ 36
- 37
dnn/src/fallback/conv_bias/gi/fp32/strategy_2x3_4x4.cpp View File

@@ -1,7 +1,6 @@
#include "src/common/unroll_macro.h" #include "src/common/unroll_macro.h"
#include "src/common/utils.h" #include "src/common/utils.h"
#include "src/fallback/conv_bias/gi/fp32/strategy.h" #include "src/fallback/conv_bias/gi/fp32/strategy.h"
#include "src/fallback/conv_bias/gi/utils.h"
#include "src/fallback/conv_bias/winograd/winograd.h" #include "src/fallback/conv_bias/winograd/winograd.h"


#include "src/fallback/conv_bias/gi/fp32/helper.h" #include "src/fallback/conv_bias/gi/fp32/helper.h"
@@ -70,8 +69,7 @@ struct InputTransform2X3 {
size_t nr_units_in_tile, size_t ic, size_t IC) { size_t nr_units_in_tile, size_t ic, size_t IC) {
constexpr size_t alpha = 2 + 3 - 1; constexpr size_t alpha = 2 + 3 - 1;
// BT * d * B // BT * d * B
#define cb(m, n) \
Vector<float, 4> d##m##n = Vector<float, 4>::load(patchT + m * 4 * 4 + n * 4);
#define cb(m, n) GI_FLOAT32_t d##m##n = GiLoadFloat32(patchT + m * 4 * 4 + n * 4);


UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); UNROLL_CALL_NOWRAPPER_D2(4, 4, cb);
#undef cb #undef cb
@@ -80,20 +78,20 @@ struct InputTransform2X3 {
//! 0 1 1 0 d10 d11 d12 d13 0 1 -1 -1 //! 0 1 1 0 d10 d11 d12 d13 0 1 -1 -1
//! 0 -1 1 0 d20 d21 d22 d23 -1 1 1 0 //! 0 -1 1 0 d20 d21 d22 d23 -1 1 1 0
//! 0 -1 0 1 d30 d31 d32 d33 0 0 0 1 //! 0 -1 0 1 d30 d31 d32 d33 0 0 0 1
#define cb(m) \
auto t0##m = d0##m - d2##m; \
auto t1##m = d1##m + d2##m; \
auto t2##m = d2##m - d1##m; \
auto t3##m = d3##m - d1##m;
#define cb(m) \
auto t0##m = SUBF(d0##m, d2##m); \
auto t1##m = ADDF(d1##m, d2##m); \
auto t2##m = SUBF(d2##m, d1##m); \
auto t3##m = SUBF(d3##m, d1##m);


UNROLL_CALL_NOWRAPPER(4, cb); UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb #undef cb


#define cb(m) \
d##m##0 = t##m##0 - t##m##2; \
d##m##1 = t##m##1 + t##m##2; \
d##m##2 = t##m##2 - t##m##1; \
d##m##3 = t##m##3 - t##m##1;
#define cb(m) \
d##m##0 = SUBF(t##m##0, t##m##2); \
d##m##1 = ADDF(t##m##1, t##m##2); \
d##m##2 = SUBF(t##m##2, t##m##1); \
d##m##3 = SUBF(t##m##3, t##m##1);


UNROLL_CALL_NOWRAPPER(4, cb); UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb #undef cb
@@ -101,9 +99,10 @@ struct InputTransform2X3 {
size_t ICB = IC / 4; size_t ICB = IC / 4;
size_t icb = ic / 4; size_t icb = ic / 4;
#define cb(m, n) \ #define cb(m, n) \
d##m##n.save( \
GiStoreFloat32( \
input_transform_buf + (m * alpha + n) * ICB * nr_units_in_tile * 4 + \ input_transform_buf + (m * alpha + n) * ICB * nr_units_in_tile * 4 + \
icb * nr_units_in_tile * 4 + unit_idx * 4);
icb * nr_units_in_tile * 4 + unit_idx * 4, \
d##m##n);
UNROLL_CALL_NOWRAPPER_D2(4, 4, cb) UNROLL_CALL_NOWRAPPER_D2(4, 4, cb)
#undef cb #undef cb
} }
@@ -125,7 +124,7 @@ struct OutputTransform2X3 {
size_t ocb = oc_index / 4; size_t ocb = oc_index / 4;


#define cb(m, n) \ #define cb(m, n) \
auto v##m##n = Vector<float, 4>::load( \
auto v##m##n = GiLoadFloat32( \
output_transform_buf + (m * alpha + n) * OCB * nr_units_in_tile * 4 + \ output_transform_buf + (m * alpha + n) * OCB * nr_units_in_tile * 4 + \
ocb * nr_units_in_tile * 4 + unit_idx * 4); ocb * nr_units_in_tile * 4 + unit_idx * 4);
UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); UNROLL_CALL_NOWRAPPER_D2(4, 4, cb);
@@ -134,37 +133,37 @@ struct OutputTransform2X3 {
//! 0 1 -1 1 v10 v11 v12 v13 1 1 //! 0 1 -1 1 v10 v11 v12 v13 1 1
//! v20 v21 v22 v23 1 -1 //! v20 v21 v22 v23 1 -1
//! v30 v31 v32 v33 0 1 //! v30 v31 v32 v33 0 1
#define cb(m) \
auto t0##m = v0##m + v1##m + v2##m; \
auto t1##m = v1##m - v2##m + v3##m;
#define cb(m) \
auto t0##m = ADDF(ADDF(v0##m, v1##m), v2##m); \
auto t1##m = ADDF(SUBF(v1##m, v2##m), v3##m);


UNROLL_CALL_NOWRAPPER(4, cb); UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb #undef cb
v00 = t00 + t01 + t02;
v10 = t10 + t11 + t12;
v01 = t01 - t02 + t03;
v11 = t11 - t12 + t13;
v00 = ADDF(ADDF(t00, t01), t02);
v10 = ADDF(ADDF(t10, t11), t12);
v01 = ADDF(SUBF(t01, t02), t03);
v11 = ADDF(SUBF(t11, t12), t13);


Vector<float, 4> vbias;
GI_FLOAT32_t vbias;
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
vbias = Vector<float, 4>::load(bias + oc);
vbias = GiLoadFloat32(bias + oc);


v00 += vbias;
v10 += vbias;
v01 += vbias;
v11 += vbias;
v00 = ADDF(v00, vbias);
v10 = ADDF(v10, vbias);
v01 = ADDF(v01, vbias);
v11 = ADDF(v11, vbias);
} }
if (bmode != BiasMode::BIAS) { if (bmode != BiasMode::BIAS) {
v00 = op(GiFixLenType2GiFloat32Type(v00.value));
v01 = op(GiFixLenType2GiFloat32Type(v01.value));
v10 = op(GiFixLenType2GiFloat32Type(v10.value));
v11 = op(GiFixLenType2GiFloat32Type(v11.value));
v00 = op(v00);
v01 = op(v01);
v10 = op(v10);
v11 = op(v11);
} }


v00.save(transform_mid_buf + (0 * 2 + 0) * 4);
v10.save(transform_mid_buf + (1 * 2 + 0) * 4);
v01.save(transform_mid_buf + (0 * 2 + 1) * 4);
v11.save(transform_mid_buf + (1 * 2 + 1) * 4);
GiStoreFloat32(transform_mid_buf + (0 * 2 + 0) * 4, v00);
GiStoreFloat32(transform_mid_buf + (1 * 2 + 0) * 4, v10);
GiStoreFloat32(transform_mid_buf + (0 * 2 + 1) * 4, v01);
GiStoreFloat32(transform_mid_buf + (1 * 2 + 1) * 4, v11);


for (size_t oco = 0; oco < 4 && oc + oco < oc_end; ++oco) { for (size_t oco = 0; oco < 4 && oc + oco < oc_end; ++oco) {
for (size_t oho = 0; oho < 2 && oh_start + oho < OH; ++oho) { for (size_t oho = 0; oho < 2 && oh_start + oho < OH; ++oho) {


+ 201
- 96
dnn/src/fallback/conv_bias/gi/fp32/strategy_4x5.cpp View File

@@ -1,7 +1,6 @@
#include "src/common/unroll_macro.h" #include "src/common/unroll_macro.h"
#include "src/common/utils.h" #include "src/common/utils.h"
#include "src/fallback/conv_bias/gi/fp32/strategy.h" #include "src/fallback/conv_bias/gi/fp32/strategy.h"
#include "src/fallback/conv_bias/gi/utils.h"
#include "src/fallback/conv_bias/winograd/winograd.h" #include "src/fallback/conv_bias/winograd/winograd.h"


#include "src/fallback/conv_bias/gi/fp32/helper.h" #include "src/fallback/conv_bias/gi/fp32/helper.h"
@@ -15,61 +14,124 @@ using namespace megdnn;
using namespace fallback; using namespace fallback;
namespace { namespace {


/*
* wd##0 = d##0;
* wd##r0 = d##r0;
* wd##1 = (d##0 + d##1 + d##2 + d##3 + d##4) * -0.2222222;
* wd##r1 = (d##r0 + d##r1 + d##r2 + d##r3 + d##r4) * -0.2222222;
* wd##2 = (d##0 - d##1 + d##2 - d##3 + d##4) * -0.2222222;
* wd##r2 = (d##r0 - d##r1 + d##r2 - d##r3 + d##r4) * -0.2222222;
* auto tmpd0 = d##0 * 0.7111111;
* auto tmpd1 = d##1 * 0.3555556;
* auto tmpd2 = d##2 * 0.1777778;
* auto tmpd3 = d##3 * 0.0888889;
* auto tmpd4 = d##4 * 0.0444444;
* auto tmpdr0 = d##r0 * 0.7111111;
* auto tmpdr1 = d##r1 * 0.3555556;
* auto tmpdr2 = d##r2 * 0.1777778;
* auto tmpdr3 = d##r3 * 0.0888889;
* auto tmpdr4 = d##r4 * 0.0444444;
* wd##3 = tmpd0 + tmpd1 + tmpd2 + tmpd3 + tmpd4;
* wd##r3 = tmpdr0 + tmpdr1 + tmpdr2 + tmpdr3 + tmpdr4;
* wd##4 = tmpd0 - tmpd1 + tmpd2 - tmpd3 + tmpd4;
* wd##r4 = tmpdr0 - tmpdr1 + tmpdr2 - tmpdr3 + tmpdr4;
* tmpd0 = d##0 * 0.0111111;
* tmpd1 = d##1 * 0.0222222;
* tmpd2 = d##2 * 0.0444444;
* tmpd3 = d##3 * 0.0888889;
* tmpd4 = d##4 * 0.1777778;
* tmpdr0 = d##r0 * 0.0111111;
* tmpdr1 = d##r1 * 0.0222222;
* tmpdr2 = d##r2 * 0.0444444;
* tmpdr3 = d##r3 * 0.0888889;
* tmpdr4 = d##r4 * 0.1777778;
* wd##5 = tmpd0 + tmpd1 + tmpd2 + tmpd3 + tmpd4;
* wd##r5 = tmpdr0 + tmpdr1 + tmpdr2 + tmpdr3 + tmpdr4;
* wd##6 = tmpd0 - tmpd1 + tmpd2 - tmpd3 + tmpd4;
* wd##r6 = tmpdr0 - tmpdr1 + tmpdr2 - tmpdr3 + tmpdr4;
* wd##7 = d##4;
* wd##r7 = d##r4;
*
* */
struct FilterTransform4X5 { struct FilterTransform4X5 {
#define FILTER_TRANSFORM(d, wd) \
do { \
wd##0 = d##0; \
wd##r0 = d##r0; \
wd##1 = (d##0 + d##1 + d##2 + d##3 + d##4) * -0.2222222; \
wd##r1 = (d##r0 + d##r1 + d##r2 + d##r3 + d##r4) * -0.2222222; \
wd##2 = (d##0 - d##1 + d##2 - d##3 + d##4) * -0.2222222; \
wd##r2 = (d##r0 - d##r1 + d##r2 - d##r3 + d##r4) * -0.2222222; \
auto tmpd0 = d##0 * 0.7111111; \
auto tmpd1 = d##1 * 0.3555556; \
auto tmpd2 = d##2 * 0.1777778; \
auto tmpd3 = d##3 * 0.0888889; \
auto tmpd4 = d##4 * 0.0444444; \
auto tmpdr0 = d##r0 * 0.7111111; \
auto tmpdr1 = d##r1 * 0.3555556; \
auto tmpdr2 = d##r2 * 0.1777778; \
auto tmpdr3 = d##r3 * 0.0888889; \
auto tmpdr4 = d##r4 * 0.0444444; \
wd##3 = tmpd0 + tmpd1 + tmpd2 + tmpd3 + tmpd4; \
wd##r3 = tmpdr0 + tmpdr1 + tmpdr2 + tmpdr3 + tmpdr4; \
wd##4 = tmpd0 - tmpd1 + tmpd2 - tmpd3 + tmpd4; \
wd##r4 = tmpdr0 - tmpdr1 + tmpdr2 - tmpdr3 + tmpdr4; \
tmpd0 = d##0 * 0.0111111; \
tmpd1 = d##1 * 0.0222222; \
tmpd2 = d##2 * 0.0444444; \
tmpd3 = d##3 * 0.0888889; \
tmpd4 = d##4 * 0.1777778; \
tmpdr0 = d##r0 * 0.0111111; \
tmpdr1 = d##r1 * 0.0222222; \
tmpdr2 = d##r2 * 0.0444444; \
tmpdr3 = d##r3 * 0.0888889; \
tmpdr4 = d##r4 * 0.1777778; \
wd##5 = tmpd0 + tmpd1 + tmpd2 + tmpd3 + tmpd4; \
wd##r5 = tmpdr0 + tmpdr1 + tmpdr2 + tmpdr3 + tmpdr4; \
wd##6 = tmpd0 - tmpd1 + tmpd2 - tmpd3 + tmpd4; \
wd##r6 = tmpdr0 - tmpdr1 + tmpdr2 - tmpdr3 + tmpdr4; \
wd##7 = d##4; \
wd##r7 = d##r4; \
#define FILTER_TRANSFORM(d, wd) \
do { \
wd##0 = d##0; \
wd##r0 = d##r0; \
wd##1 = MULSF( \
ADDF(ADDF(ADDF(ADDF(d##0, d##1), d##2), d##3), d##4), -0.2222222); \
wd##r1 = (d##r0 + d##r1 + d##r2 + d##r3 + d##r4) * -0.2222222; \
wd##2 = MULSF( \
ADDF(SUBF(ADDF(SUBF(d##0, d##1), d##2), d##3), d##4), -0.2222222); \
wd##r2 = (d##r0 - d##r1 + d##r2 - d##r3 + d##r4) * -0.2222222; \
auto tmpd0 = MULSF(d##0, 0.7111111); \
auto tmpd1 = MULSF(d##1, 0.3555556); \
auto tmpd2 = MULSF(d##2, 0.1777778); \
auto tmpd3 = MULSF(d##3, 0.0888889); \
auto tmpd4 = MULSF(d##4, 0.0444444); \
auto tmpdr0 = d##r0 * 0.7111111; \
auto tmpdr1 = d##r1 * 0.3555556; \
auto tmpdr2 = d##r2 * 0.1777778; \
auto tmpdr3 = d##r3 * 0.0888889; \
auto tmpdr4 = d##r4 * 0.0444444; \
wd##3 = ADDF(ADDF(ADDF(ADDF(tmpd0, tmpd1), tmpd2), tmpd3), tmpd4); \
wd##r3 = tmpdr0 + tmpdr1 + tmpdr2 + tmpdr3 + tmpdr4; \
wd##4 = ADDF(SUBF(ADDF(SUBF(tmpd0, tmpd1), tmpd2), tmpd3), tmpd4); \
wd##r4 = tmpdr0 - tmpdr1 + tmpdr2 - tmpdr3 + tmpdr4; \
tmpd0 = MULSF(d##0, 0.0111111); \
tmpd1 = MULSF(d##1, 0.0222222); \
tmpd2 = MULSF(d##2, 0.0444444); \
tmpd3 = MULSF(d##3, 0.0888889); \
tmpd4 = MULSF(d##4, 0.1777778); \
tmpdr0 = d##r0 * 0.0111111; \
tmpdr1 = d##r1 * 0.0222222; \
tmpdr2 = d##r2 * 0.0444444; \
tmpdr3 = d##r3 * 0.0888889; \
tmpdr4 = d##r4 * 0.1777778; \
wd##5 = ADDF(ADDF(ADDF(ADDF(tmpd0, tmpd1), tmpd2), tmpd3), tmpd4); \
wd##r5 = tmpdr0 + tmpdr1 + tmpdr2 + tmpdr3 + tmpdr4; \
wd##6 = ADDF(SUBF(ADDF(SUBF(tmpd0, tmpd1), tmpd2), tmpd3), tmpd4); \
wd##r6 = tmpdr0 - tmpdr1 + tmpdr2 - tmpdr3 + tmpdr4; \
wd##7 = d##4; \
wd##r7 = d##r4; \
} while (0); } while (0);


#define FILTER_TRANSFORM_FINAL(d, wd) \
do { \
wd##0 = d##0; \
wd##1 = (d##0 + d##1 + d##2 + d##3 + d##4) * -0.2222222; \
wd##2 = (d##0 - d##1 + d##2 - d##3 + d##4) * -0.2222222; \
auto tmp0 = d##0 * 0.7111111 + d##2 * 0.1777778 + d##4 * 0.0444444; \
auto tmp1 = d##1 * 0.3555556 + d##3 * 0.0888889; \
wd##3 = tmp0 + tmp1; \
wd##4 = tmp0 - tmp1; \
tmp0 = d##0 * 0.0111111 + d##2 * 0.0444444 + d##4 * 0.1777778; \
tmp1 = d##1 * 0.0222222 + d##3 * 0.0888889; \
wd##5 = tmp0 + tmp1; \
wd##6 = tmp0 - tmp1; \
wd##7 = d##4; \
/*
*wd##0 = d##0;
*wd##1 = (d##0 + d##1 + d##2 + d##3 + d##4) * -0.2222222;
*wd##2 = (d##0 - d##1 + d##2 - d##3 + d##4) * -0.2222222;
*auto tmp0 = d##0 * 0.7111111 + d##2 * 0.1777778 + d##4 * 0.0444444;
*auto tmp1 = d##1 * 0.3555556 + d##3 * 0.0888889;
*wd##3 = tmp0 + tmp1;
*wd##4 = tmp0 - tmp1;
*tmp0 = d##0 * 0.0111111 + d##2 * 0.0444444 + d##4 * 0.1777778;
*tmp1 = d##1 * 0.0222222 + d##3 * 0.0888889;
*wd##5 = tmp0 + tmp1;
*wd##6 = tmp0 - tmp1;
*wd##7 = d##4;
*/
#define FILTER_TRANSFORM_FINAL(d, wd) \
do { \
wd##0 = d##0; \
wd##1 = \
MULSFV2(ADDFV2(ADDFV2(ADDFV2(ADDFV2(d##0, d##1), d##2), d##3), d##4), \
-0.2222222); \
wd##2 = \
MULSFV2(ADDFV2(SUBFV2(ADDFV2(SUBFV2(d##0, d##1), d##2), d##3), d##4), \
-0.2222222); \
auto tmp0 = \
ADDFV2(ADDFV2(MULSFV2(d##0, 0.7111111), MULSFV2(d##2, 0.1777778)), \
MULSFV2(d##4, 0.0444444)); \
auto tmp1 = ADDFV2(MULSFV2(d##1, 0.3555556), MULSFV2(d##3, 0.0888889)); \
wd##3 = ADDFV2(tmp0, tmp1); \
wd##4 = SUBFV2(tmp0, tmp1); \
tmp0 = \
ADDFV2(ADDFV2(MULSFV2(d##0, 0.0111111), MULSFV2(d##2, 0.0444444)), \
MULSFV2(d##4, 0.1777778)); \
tmp1 = ADDFV2(MULSFV2(d##1, 0.0222222), MULSFV2(d##3, 0.0888889)); \
wd##5 = ADDFV2(tmp0, tmp1); \
wd##6 = SUBFV2(tmp0, tmp1); \
wd##7 = d##4; \
} while (0); } while (0);
static void transform( static void transform(
const float* filter, float* filter_transform_buf, float* transform_mid_buf, const float* filter, float* filter_transform_buf, float* transform_mid_buf,
@@ -89,7 +151,7 @@ struct FilterTransform4X5 {
rep(ic, IC) { rep(ic, IC) {
const float* fptr = filter + (oc * IC + ic) * 5 * 5; const float* fptr = filter + (oc * IC + ic) * 5 * 5;


#define cb(i) Vector<float, 4> g##i = Vector<float, 4>::load(fptr + 5 * i);
#define cb(i) GI_FLOAT32_t g##i = GiLoadFloat32(fptr + 5 * i);
UNROLL_CALL_NOWRAPPER(5, cb); UNROLL_CALL_NOWRAPPER(5, cb);
#undef cb #undef cb


@@ -97,7 +159,7 @@ struct FilterTransform4X5 {
UNROLL_CALL_NOWRAPPER(5, cb); UNROLL_CALL_NOWRAPPER(5, cb);


#undef cb #undef cb
#define cb(i) Vector<float, 4> Gg##i;
#define cb(i) GI_FLOAT32_t Gg##i;
UNROLL_CALL_NOWRAPPER(8, cb); UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb #undef cb


@@ -105,11 +167,11 @@ struct FilterTransform4X5 {
UNROLL_CALL_NOWRAPPER(8, cb); UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb #undef cb


#define cb(i) Vector<float, 8> Ggt##i;
#define cb(i) GI_FLOAT32_V2_t Ggt##i;
UNROLL_CALL_NOWRAPPER(4, cb); UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb #undef cb


#define cb(i) Vector<float, 8> result##i;
#define cb(i) GI_FLOAT32_V2_t result##i;
UNROLL_CALL_NOWRAPPER(8, cb); UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb #undef cb


@@ -128,11 +190,11 @@ struct FilterTransform4X5 {
GI_FLOAT32_t vgr1 = GiLoadFloat32(tmp); GI_FLOAT32_t vgr1 = GiLoadFloat32(tmp);
GiSetSubVectorFloat32V2(vgr, 0, vgr0); //{Ggr0, Ggr1, Ggr2, Ggr3}; GiSetSubVectorFloat32V2(vgr, 0, vgr0); //{Ggr0, Ggr1, Ggr2, Ggr3};
GiSetSubVectorFloat32V2(vgr, 1, vgr1); //{Ggr4, Ggr5, Ggr6, Ggr7}; GiSetSubVectorFloat32V2(vgr, 1, vgr1); //{Ggr4, Ggr5, Ggr6, Ggr7};
Vector<float, 8> Ggt4(vgr);
GI_FLOAT32_V2_t Ggt4(vgr);
TRANSPOSE_8x4(Gg, Ggt); TRANSPOSE_8x4(Gg, Ggt);
FILTER_TRANSFORM_FINAL(Ggt, result); FILTER_TRANSFORM_FINAL(Ggt, result);


#define cb(i) result##i.save(transform_mid_buf + i * alpha);
#define cb(i) GiStoreFloat32V2(transform_mid_buf + i * alpha, result##i);
UNROLL_CALL_NOWRAPPER(8, cb); UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb #undef cb
rep(i, alpha) rep(j, alpha) { rep(i, alpha) rep(j, alpha) {
@@ -145,31 +207,49 @@ struct FilterTransform4X5 {
#undef FILTER_TRANSFORM #undef FILTER_TRANSFORM
#undef FILTER_TRANSFORM_FINAL #undef FILTER_TRANSFORM_FINAL


/*
* wd##0 = (d##0 - d##6) + (d##4 - d##2) * 5.25f;
* auto tmp0 = d##2 - d##4 * 4.25f + d##6;
* auto tmp1 = d##1 - d##3 * 4.25f + d##5;
* wd##1 = tmp0 + tmp1;
* wd##2 = tmp0 - tmp1;
* tmp0 = d##2 * 4.0f - d##4 * 5.0f + d##6;
* tmp1 = d##1 * 2.0f - d##3 * 2.5f + d##5 * 0.5f;
* wd##3 = tmp0 + tmp1;
* wd##4 = tmp0 - tmp1;
* tmp0 = d##2 * 0.25f - d##4 * 1.25f + d##6;
* tmp1 = d##1 * 0.5f - d##3 * 2.5f + d##5 * 2.0f;
* wd##5 = tmp0 + tmp1;
* wd##6 = tmp0 - tmp1;
* wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f;
*/
struct InputTransform4X5 { struct InputTransform4X5 {
#define INPUT_TRANSFORM(d, wd) \
do { \
wd##0 = (d##0 - d##6) + (d##4 - d##2) * 5.25f; \
auto tmp0 = d##2 - d##4 * 4.25f + d##6; \
auto tmp1 = d##1 - d##3 * 4.25f + d##5; \
wd##1 = tmp0 + tmp1; \
wd##2 = tmp0 - tmp1; \
tmp0 = d##2 * 4.0f - d##4 * 5.0f + d##6; \
tmp1 = d##1 * 2.0f - d##3 * 2.5f + d##5 * 0.5f; \
wd##3 = tmp0 + tmp1; \
wd##4 = tmp0 - tmp1; \
tmp0 = d##2 * 0.25f - d##4 * 1.25f + d##6; \
tmp1 = d##1 * 0.5f - d##3 * 2.5f + d##5 * 2.0f; \
wd##5 = tmp0 + tmp1; \
wd##6 = tmp0 - tmp1; \
wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f; \
#define INPUT_TRANSFORM(d, wd) \
do { \
wd##0 = ADDFV2(SUBFV2(d##0, d##6), MULSFV2(SUBFV2(d##4, d##2), 5.25f)); \
auto tmp0 = ADDFV2(SUBFV2(d##2, MULSFV2(d##4, 4.25f)), d##6); \
auto tmp1 = ADDFV2(SUBFV2(d##1, MULSFV2(d##3, 4.25f)), d##5); \
wd##1 = ADDFV2(tmp0, tmp1); \
wd##2 = SUBFV2(tmp0, tmp1); \
tmp0 = ADDFV2(SUBFV2(MULSFV2(d##2, 4.0f), MULSFV2(d##4, 5.0f)), d##6); \
tmp1 = \
ADDFV2(SUBFV2(MULSFV2(d##1, 2.0f), MULSFV2(d##3, 2.5f)), \
MULSFV2(d##5, 0.5f)); \
wd##3 = ADDFV2(tmp0, tmp1); \
wd##4 = SUBFV2(tmp0, tmp1); \
tmp0 = ADDFV2(SUBFV2(MULSFV2(d##2, 0.25f), MULSFV2(d##4, 1.25f)), d##6); \
tmp1 = \
ADDFV2(SUBFV2(MULSFV2(d##1, 0.5f), MULSFV2(d##3, 2.5f)), \
MULSFV2(d##5, 2.0f)); \
wd##5 = ADDFV2(tmp0, tmp1); \
wd##6 = SUBFV2(tmp0, tmp1); \
wd##7 = ADDFV2(SUBFV2(d##7, d##1), MULSFV2(SUBFV2(d##3, d##5), 5.25f)); \
} while (0) } while (0)


#define GET_VECTOR_HIGH_ELEM(s, i, idx) \
GiExtractLane##idx##Float32(GiGetSubVectorFloat32V2( \
GiFixLenType2GiFloat32V2Type(CONCAT(s, i).value), 1))
#define GET_VECTOR_LOW_ELEM(s, i, idx) \
GiExtractLane##idx##Float32(GiGetSubVectorFloat32V2( \
GiFixLenType2GiFloat32V2Type(CONCAT(s, i).value), 0))
#define GET_VECTOR_HIGH_ELEM(s, i, idx) \
GiExtractLane##idx##Float32(GiGetSubVectorFloat32V2(CONCAT(s, i), 1))
#define GET_VECTOR_LOW_ELEM(s, i, idx) \
GiExtractLane##idx##Float32(GiGetSubVectorFloat32V2(CONCAT(s, i), 0))


template <bool inner> template <bool inner>
static void transform( static void transform(
@@ -191,13 +271,13 @@ struct InputTransform4X5 {
memset(transform_mid_buf, 0, sizeof(float) * alpha * alpha); memset(transform_mid_buf, 0, sizeof(float) * alpha * alpha);
} }


#define cb(i) Vector<float, 8> d##i;
#define cb(i) GI_FLOAT32_V2_t d##i;
UNROLL_CALL_NOWRAPPER(8, cb); UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb #undef cb


if (inner) { if (inner) {
const float* input_ptr = input + ic * IH * IW + ih_start * IW + iw_start; const float* input_ptr = input + ic * IH * IW + ih_start * IW + iw_start;
#define cb(i) d##i = Vector<float, 8>::load(input_ptr + IW * i);
#define cb(i) d##i = GiLoadFloat32V2(input_ptr + IW * i);
UNROLL_CALL_NOWRAPPER(8, cb); UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb #undef cb
} else { } else {
@@ -212,21 +292,25 @@ struct InputTransform4X5 {
input[ic * IH * IW + ih * IW + iw]; input[ic * IH * IW + ih * IW + iw];
} }
} }
#define cb(i) d##i = Vector<float, 8>::load(transform_mid_buf + alpha * i);
#define cb(i) d##i = GiLoadFloat32V2(transform_mid_buf + alpha * i);
UNROLL_CALL_NOWRAPPER(8, cb); UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb #undef cb
} }


#define cb(i) Vector<float, 8> wd##i, ret##i;
#define cb(i) GI_FLOAT32_V2_t wd##i;
UNROLL_CALL_NOWRAPPER(8, cb); UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb #undef cb


INPUT_TRANSFORM(d, wd); INPUT_TRANSFORM(d, wd);
#if MEGDNN_AARCH64 #if MEGDNN_AARCH64
#define cb(i) GI_FLOAT32_V2_t ret##i;
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb

TRANSPOSE_8x8(wd, d); TRANSPOSE_8x8(wd, d);
INPUT_TRANSFORM(d, ret); INPUT_TRANSFORM(d, ret);


#define cb(i) ret##i.save(transform_mid_buf + i * alpha);
#define cb(i) GiStoreFloat32V2(transform_mid_buf + i * alpha, ret##i);
UNROLL_CALL_NOWRAPPER(8, cb); UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb #undef cb
rep(i, alpha) rep(j, alpha) { rep(i, alpha) rep(j, alpha) {
@@ -283,12 +367,32 @@ struct InputTransform4X5 {
}; };
#undef INPUT_TRANSFORM #undef INPUT_TRANSFORM


#define OUTPUT_TRANSFORM(m, s) \
do { \
s0 = m0 + m1 + m2 + m3 + m4 + m5 + m6; \
s1 = m1 - m2 + m3 * 0.5 - m4 * 0.5 + m5 * 2.0 - m6 * 2.0; \
s2 = m1 + m2 + m3 * 0.25 + m4 * 0.25 + m5 * 4.0 + m6 * 4.0; \
s3 = m1 - m2 + m3 * 0.125 - m4 * 0.125 + m5 * 8.0 - m6 * 8.0 + m7; \
/*
* s0 = m0 + m1 + m2 + m3 + m4 + m5 + m6;
* s1 = m1 - m2 + m3 * 0.5 - m4 * 0.5 + m5 * 2.0 - m6 * 2.0;
* s2 = m1 + m2 + m3 * 0.25 + m4 * 0.25 + m5 * 4.0 + m6 * 4.0;
* s3 = m1 - m2 + m3 * 0.125 - m4 * 0.125 + m5 * 8.0 - m6 * 8.0 + m7;
*/
#define OUTPUT_TRANSFORM(m, s) \
do { \
s0 = ADDFV2( \
ADDFV2(ADDFV2(ADDFV2(ADDFV2(ADDFV2(m0, m1), m2), m3), m4), m5), m6); \
s1 = \
SUBFV2(ADDFV2(SUBFV2(ADDFV2(SUBFV2(m1, m2), MULSFV2(m3, 0.5)), \
MULSFV2(m4, 0.5)), \
MULSFV2(m5, 2.0)), \
MULSFV2(m6, 2.0)); \
s2 = \
ADDFV2(ADDFV2(ADDFV2(ADDFV2(ADDFV2(m1, m2), MULSFV2(m3, 0.25)), \
MULSFV2(m4, 0.25)), \
MULSFV2(m5, 4.0)), \
MULSFV2(m6, 4.0)); \
s3 = ADDFV2( \
SUBFV2(ADDFV2(SUBFV2(ADDFV2(SUBFV2(m1, m2), MULSFV2(m3, 0.125)), \
MULSFV2(m4, 0.125)), \
MULSFV2(m5, 8.0)), \
MULSFV2(m6, 8.0)), \
m7); \
} while (0) } while (0)
template <BiasMode bmode, typename Op> template <BiasMode bmode, typename Op>
struct OutputTransform4X5 { struct OutputTransform4X5 {
@@ -316,14 +420,15 @@ struct OutputTransform4X5 {
UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); UNROLL_CALL_NOWRAPPER_D2(8, 8, cb);
#undef cb #undef cb


#define cb(i) auto m##i = Vector<float, 8>::load(transform_mid_buf + alpha * i);
#define cb(i) auto m##i = GiLoadFloat32V2(transform_mid_buf + alpha * i);
UNROLL_CALL_NOWRAPPER(8, cb); UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb #undef cb
#define cb(i) Vector<float, 8> s##i;
#define cb(i) GI_FLOAT32_V2_t s##i;
UNROLL_CALL_NOWRAPPER(4, cb); UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb #undef cb


OUTPUT_TRANSFORM(m, s); OUTPUT_TRANSFORM(m, s);

#define cb(i) \ #define cb(i) \
do { \ do { \
auto add12 = GET_VECTOR_LOW_ELEM(s, i, 1) + GET_VECTOR_LOW_ELEM(s, i, 2); \ auto add12 = GET_VECTOR_LOW_ELEM(s, i, 1) + GET_VECTOR_LOW_ELEM(s, i, 2); \
@@ -355,10 +460,10 @@ struct OutputTransform4X5 {
GI_FLOAT32_t item0 = GiLoadFloat32(mid_buf1); GI_FLOAT32_t item0 = GiLoadFloat32(mid_buf1);


if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
item0 = GiAddFloat32(item0, bias0);
item0 = ADDF(item0, bias0);
} else if (bmode == BiasMode::BIAS) { } else if (bmode == BiasMode::BIAS) {
bias0 = GiLoadFloat32(bias + oc * OH * OW + oh * OW + ow_start); bias0 = GiLoadFloat32(bias + oc * OH * OW + oh * OW + ow_start);
item0 = GiAddFloat32(item0, bias0);
item0 = ADDF(item0, bias0);
} }
item0 = op(item0); item0 = op(item0);
GiStoreFloat32(output + oc * OH * OW + oh * OW + ow_start, item0); GiStoreFloat32(output + oc * OH * OW + oh * OW + ow_start, item0);


+ 133
- 72
dnn/src/fallback/conv_bias/gi/fp32/strategy_5x4.cpp View File

@@ -1,7 +1,6 @@
#include "src/common/unroll_macro.h" #include "src/common/unroll_macro.h"
#include "src/common/utils.h" #include "src/common/utils.h"
#include "src/fallback/conv_bias/gi/fp32/strategy.h" #include "src/fallback/conv_bias/gi/fp32/strategy.h"
#include "src/fallback/conv_bias/gi/utils.h"
#include "src/fallback/conv_bias/winograd/winograd.h" #include "src/fallback/conv_bias/winograd/winograd.h"


#include "src/fallback/conv_bias/gi/fp32/helper.h" #include "src/fallback/conv_bias/gi/fp32/helper.h"
@@ -15,23 +14,39 @@ using namespace megdnn;
using namespace fallback; using namespace fallback;
namespace { namespace {


/*
* wd##0 = d##0;
* auto tmp0 = d##0 * 0.7111111f + d##2 * 0.1777778f;
* auto tmp1 = d##1 * 0.3555556f + d##3 * 0.0888889f;
* wd##1 = tmp0 + tmp1;
* wd##2 = tmp0 - tmp1;
* tmp0 = (d##0 + d##2) * -0.2222222f;
* tmp1 = (d##1 + d##3) * -0.2222222f;
* wd##3 = tmp0 + tmp1;
* wd##4 = tmp0 - tmp1;
* tmp0 = d##0 * 0.0111111f + d##2 * 0.0444444f;
* tmp1 = d##1 * 0.0222222f + d##3 * 0.0888889f;
* wd##5 = tmp0 + tmp1;
* wd##6 = tmp0 - tmp1;
* wd##7 = d##3;
*/
struct FilterTransform5X4 { struct FilterTransform5X4 {
#define FILTER_TRANSFORM(d, wd) \
do { \
wd##0 = d##0; \
auto tmp0 = d##0 * 0.7111111f + d##2 * 0.1777778f; \
auto tmp1 = d##1 * 0.3555556f + d##3 * 0.0888889f; \
wd##1 = tmp0 + tmp1; \
wd##2 = tmp0 - tmp1; \
tmp0 = (d##0 + d##2) * -0.2222222f; \
tmp1 = (d##1 + d##3) * -0.2222222f; \
wd##3 = tmp0 + tmp1; \
wd##4 = tmp0 - tmp1; \
tmp0 = d##0 * 0.0111111f + d##2 * 0.0444444f; \
tmp1 = d##1 * 0.0222222f + d##3 * 0.0888889f; \
wd##5 = tmp0 + tmp1; \
wd##6 = tmp0 - tmp1; \
wd##7 = d##3; \
#define FILTER_TRANSFORM(d, wd, ADDC, SUBC, MULC) \
do { \
wd##0 = d##0; \
auto tmp0 = ADDC(MULC(d##0, 0.7111111f), MULC(d##2, 0.1777778f)); \
auto tmp1 = ADDC(MULC(d##1, 0.3555556f), MULC(d##3, 0.0888889f)); \
wd##1 = ADDC(tmp0, tmp1); \
wd##2 = SUBC(tmp0, tmp1); \
tmp0 = MULC(ADDC(d##0, d##2), -0.2222222f); \
tmp1 = MULC(ADDC(d##1, d##3), -0.2222222f); \
wd##3 = ADDC(tmp0, tmp1); \
wd##4 = SUBC(tmp0, tmp1); \
tmp0 = ADDC(MULC(d##0, 0.0111111f), MULC(d##2, 0.0444444f)); \
tmp1 = ADDC(MULC(d##1, 0.0222222f), MULC(d##3, 0.0888889f)); \
wd##5 = ADDC(tmp0, tmp1); \
wd##6 = SUBC(tmp0, tmp1); \
wd##7 = d##3; \
} while (0) } while (0)


static void transform( static void transform(
@@ -53,28 +68,27 @@ struct FilterTransform5X4 {
rep(ic, IC) { rep(ic, IC) {
const float* fptr = filter + (oc * IC + ic) * 4 * 4; const float* fptr = filter + (oc * IC + ic) * 4 * 4;


#define cb(i) Vector<float, 4> g##i = Vector<float, 4>::load(fptr + 4 * i);
#define cb(i) GI_FLOAT32_t g##i = GiLoadFloat32(fptr + 4 * i);
UNROLL_CALL_NOWRAPPER(4, cb); UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb #undef cb


#define cb(i) Vector<float, 4> wd##i;
#define cb(i) GI_FLOAT32_t wd##i;
UNROLL_CALL_NOWRAPPER(8, cb); UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb #undef cb


#define cb(i) Vector<float, 8> wdt##i;
FILTER_TRANSFORM(g, wd, ADDF, SUBF, MULSF);
#if MEGDNN_AARCH64
#define cb(i) GI_FLOAT32_V2_t wdt##i;
UNROLL_CALL_NOWRAPPER(4, cb); UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb #undef cb


#define cb(i) Vector<float, 8> ret##i;
#define cb(i) GI_FLOAT32_V2_t ret##i;
UNROLL_CALL_NOWRAPPER(8, cb); UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb #undef cb

FILTER_TRANSFORM(g, wd);
#if MEGDNN_AARCH64
TRANSPOSE_8x4(wd, wdt); TRANSPOSE_8x4(wd, wdt);
FILTER_TRANSFORM(wdt, ret);
FILTER_TRANSFORM(wdt, ret, ADDFV2, SUBFV2, MULSFV2);


#define cb(i) ret##i.save(transform_mid_buf + i * alpha);
#define cb(i) GiStoreFloat32V2(transform_mid_buf + i * alpha, ret##i);
UNROLL_CALL_NOWRAPPER(8, cb); UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb #undef cb
rep(i, alpha) rep(j, alpha) { rep(i, alpha) rep(j, alpha) {
@@ -104,8 +118,7 @@ struct FilterTransform5X4 {
mid_buf1[7] = GET_VECTOR_ELEM(wd, i, 3); \ mid_buf1[7] = GET_VECTOR_ELEM(wd, i, 3); \
mid_buf1 += 8; \ mid_buf1 += 8; \
} while (0); } while (0);
#define GET_VECTOR_ELEM(s, i, idx) \
GiExtractLane##idx##Float32(GiFixLenType2GiFloat32Type(CONCAT(s, i).value))
#define GET_VECTOR_ELEM(s, i, idx) GiExtractLane##idx##Float32(CONCAT(s, i))


float* mid_buf1 = transform_mid_buf; float* mid_buf1 = transform_mid_buf;
UNROLL_CALL_NOWRAPPER(8, cb); UNROLL_CALL_NOWRAPPER(8, cb);
@@ -123,29 +136,49 @@ struct FilterTransform5X4 {
#undef FILTER_TRANSFORM #undef FILTER_TRANSFORM
#undef GET_VECTOR_ELEM #undef GET_VECTOR_ELEM


/*
* wd##0 = (d##0 - d##6) + (d##4 - d##2) * 5.25f;
* auto tmp0 = d##2 * 4.0f - d##4 * 5.0f + d##6;
* auto tmp1 = d##1 * 2.0f - d##3 * 2.5f + d##5 * 0.5f;
* wd##1 = tmp0 + tmp1;
* wd##2 = tmp0 - tmp1;
* tmp0 = d##2 - d##4 * 4.25f + d##6;
* tmp1 = d##1 - d##3 * 4.25f + d##5;
* wd##3 = tmp0 + tmp1;
* wd##4 = tmp0 - tmp1;
* tmp0 = d##2 * 0.25f - d##4 * 1.25f + d##6;
* tmp1 = d##1 * 0.5f - d##3 * 2.5f + d##5 * 2.0f;
* wd##5 = tmp0 + tmp1;
* wd##6 = tmp0 - tmp1;
* wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f;
*/
struct InputTransform5X4 { struct InputTransform5X4 {
#define INPUT_TRANSFORM(d, wd) \
do { \
wd##0 = (d##0 - d##6) + (d##4 - d##2) * 5.25f; \
auto tmp0 = d##2 * 4.0f - d##4 * 5.0f + d##6; \
auto tmp1 = d##1 * 2.0f - d##3 * 2.5f + d##5 * 0.5f; \
wd##1 = tmp0 + tmp1; \
wd##2 = tmp0 - tmp1; \
tmp0 = d##2 - d##4 * 4.25f + d##6; \
tmp1 = d##1 - d##3 * 4.25f + d##5; \
wd##3 = tmp0 + tmp1; \
wd##4 = tmp0 - tmp1; \
tmp0 = d##2 * 0.25f - d##4 * 1.25f + d##6; \
tmp1 = d##1 * 0.5f - d##3 * 2.5f + d##5 * 2.0f; \
wd##5 = tmp0 + tmp1; \
wd##6 = tmp0 - tmp1; \
wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f; \
#define INPUT_TRANSFORM(d, wd) \
do { \
wd##0 = ADDFV2(SUBFV2(d##0, d##6), MULSFV2(SUBFV2(d##4, d##2), 5.25f)); \
auto tmp0 = ADDFV2(SUBFV2(MULSFV2(d##2, 4.0f), MULSFV2(d##4, 5.0f)), d##6); \
auto tmp1 = \
ADDFV2(SUBFV2(MULSFV2(d##1, 2.0f), MULSFV2(d##3, 2.5f)), \
MULSFV2(d##5, 0.5f)); \
wd##1 = ADDFV2(tmp0, tmp1); \
wd##2 = SUBFV2(tmp0, tmp1); \
tmp0 = ADDFV2(SUBFV2(d##2, MULSFV2(d##4, 4.25f)), d##6); \
tmp1 = ADDFV2(SUBFV2(d##1, MULSFV2(d##3, 4.25f)), d##5); \
wd##3 = ADDFV2(tmp0, tmp1); \
wd##4 = SUBFV2(tmp0, tmp1); \
tmp0 = ADDFV2(SUBFV2(MULSFV2(d##2, 0.25f), MULSFV2(d##4, 1.25f)), d##6); \
tmp1 = \
ADDFV2(SUBFV2(MULSFV2(d##1, 0.5f), MULSFV2(d##3, 2.5f)), \
MULSFV2(d##5, 2.0f)); \
wd##5 = ADDFV2(tmp0, tmp1); \
wd##6 = SUBFV2(tmp0, tmp1); \
wd##7 = ADDFV2(SUBFV2(d##7, d##1), MULSFV2(SUBFV2(d##3, d##5), 5.25f)); \
} while (0) } while (0)


#define GET_VECTOR_HIGH_ELEM(s, i, idx) \ #define GET_VECTOR_HIGH_ELEM(s, i, idx) \
GiExtractLane##idx##Float32(GiFixLenType2GiFloat32Type(CONCAT(s, i).value.val[1]))
GiExtractLane##idx##Float32(GiGetSubVectorFloat32V2(CONCAT(s, i), 1))
#define GET_VECTOR_LOW_ELEM(s, i, idx) \ #define GET_VECTOR_LOW_ELEM(s, i, idx) \
GiExtractLane##idx##Float32(GiFixLenType2GiFloat32Type(CONCAT(s, i).value.val[0]))
GiExtractLane##idx##Float32(GiGetSubVectorFloat32V2(CONCAT(s, i), 0))


template <bool inner> template <bool inner>
static void transform( static void transform(
@@ -168,13 +201,13 @@ struct InputTransform5X4 {
memset(transform_mid_buf, 0, sizeof(float) * alpha * alpha); memset(transform_mid_buf, 0, sizeof(float) * alpha * alpha);
} }


#define cb(i) Vector<float, 8> d##i;
#define cb(i) GI_FLOAT32_V2_t d##i;
UNROLL_CALL_NOWRAPPER(8, cb); UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb #undef cb


if (inner) { if (inner) {
const float* input_ptr = input + ic * IH * IW + ih_start * IW + iw_start; const float* input_ptr = input + ic * IH * IW + ih_start * IW + iw_start;
#define cb(i) d##i = Vector<float, 8>::load(input_ptr + IW * i);
#define cb(i) d##i = GiLoadFloat32V2(input_ptr + IW * i);
UNROLL_CALL_NOWRAPPER(8, cb); UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb #undef cb
} else { } else {
@@ -189,21 +222,25 @@ struct InputTransform5X4 {
input[ic * IH * IW + ih * IW + iw]; input[ic * IH * IW + ih * IW + iw];
} }
} }
#define cb(i) d##i = Vector<float, 8>::load(transform_mid_buf + alpha * i);
#define cb(i) d##i = GiLoadFloat32V2(transform_mid_buf + alpha * i);
UNROLL_CALL_NOWRAPPER(8, cb); UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb #undef cb
} }


#define cb(i) Vector<float, 8> wd##i, ret##i;
#define cb(i) GI_FLOAT32_V2_t wd##i;
UNROLL_CALL_NOWRAPPER(8, cb); UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb #undef cb


INPUT_TRANSFORM(d, wd); INPUT_TRANSFORM(d, wd);
#if MEGDNN_AARCH64 #if MEGDNN_AARCH64
#define cb(i) GI_FLOAT32_V2_t ret##i;
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb

TRANSPOSE_8x8(wd, d); TRANSPOSE_8x8(wd, d);
INPUT_TRANSFORM(d, ret); INPUT_TRANSFORM(d, ret);


#define cb(i) ret##i.save(transform_mid_buf + i * alpha);
#define cb(i) GiStoreFloat32V2(transform_mid_buf + i * alpha, ret##i);
UNROLL_CALL_NOWRAPPER(8, cb); UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb #undef cb
rep(i, alpha) rep(j, alpha) { rep(i, alpha) rep(j, alpha) {
@@ -260,24 +297,48 @@ struct InputTransform5X4 {
}; };
#undef INPUT_TRANSFORM #undef INPUT_TRANSFORM


#define OUTPUT_TRANSFORM(m, s) \
do { \
auto m1addm2 = m##1 + m##2; \
auto m1subm2 = m##1 - m##2; \
auto m3addm4 = m##3 + m##4; \
auto m3subm4 = m##3 - m##4; \
auto m5addm6 = (m##5 + m##6); \
auto m5subm6 = (m##5 - m##6); \
s##0 = m##0; \
CONCAT(s, 0).add(m1addm2).add(m3addm4).add(m5addm6); \
CONCAT(s, 1) = m3subm4; \
CONCAT(s, 1).mla(m1subm2, 0.5f).mla(m5subm6, 2.0f); \
CONCAT(s, 2) = m3addm4; \
CONCAT(s, 2).mla(m1addm2, 0.25f).mla(m5addm6, 4.0f); \
CONCAT(s, 3) = m3subm4; \
CONCAT(s, 3).mla(m1subm2, 0.125f).mla(m5subm6, 8.0f); \
CONCAT(s, 4) = m##7; \
CONCAT(s, 4).mla(m1addm2, 0.0625f).add(m3addm4).mla(m5addm6, 16.0f); \
/*
* auto m1addm2 = m##1 + m##2;
* auto m1subm2 = m##1 - m##2;
* auto m3addm4 = m##3 + m##4;
* auto m3subm4 = m##3 - m##4;
* auto m5addm6 = (m##5 + m##6);
* auto m5subm6 = (m##5 - m##6);
* s##0 = m##0;
* CONCAT(s, 0).add(m1addm2).add(m3addm4).add(m5addm6);
* CONCAT(s, 1) = m3subm4;
* CONCAT(s, 1).mla(m1subm2, 0.5f).mla(m5subm6, 2.0f);
* CONCAT(s, 2) = m3addm4;
* CONCAT(s, 2).mla(m1addm2, 0.25f).mla(m5addm6, 4.0f);
* CONCAT(s, 3) = m3subm4;
* CONCAT(s, 3).mla(m1subm2, 0.125f).mla(m5subm6, 8.0f);
* CONCAT(s, 4) = m##7;
* CONCAT(s, 4).mla(m1addm2, 0.0625f).add(m3addm4).mla(m5addm6, 16.0f);
*/
#define OUTPUT_TRANSFORM(m, s) \
do { \
auto m1addm2 = ADDFV2(m##1, m##2); \
auto m1subm2 = SUBFV2(m##1, m##2); \
auto m3addm4 = ADDFV2(m##3, m##4); \
auto m3subm4 = SUBFV2(m##3, m##4); \
auto m5addm6 = ADDFV2(m##5, m##6); \
auto m5subm6 = SUBFV2(m##5, m##6); \
s##0 = m##0; \
CONCAT(s, 0) = \
ADDFV2(ADDFV2(ADDFV2(CONCAT(s, 0), m1addm2), m3addm4), m5addm6); \
CONCAT(s, 1) = m3subm4; \
CONCAT(s, 1) = GiMultiplyAddScalarFloat32V2(CONCAT(s, 1), m1subm2, 0.5f); \
CONCAT(s, 1) = GiMultiplyAddScalarFloat32V2(CONCAT(s, 1), m5subm6, 2.0f); \
CONCAT(s, 2) = m3addm4; \
CONCAT(s, 2) = GiMultiplyAddScalarFloat32V2(CONCAT(s, 2), m1addm2, 0.25f); \
CONCAT(s, 2) = GiMultiplyAddScalarFloat32V2(CONCAT(s, 2), m5addm6, 4.0f); \
CONCAT(s, 3) = m3subm4; \
CONCAT(s, 3) = GiMultiplyAddScalarFloat32V2(CONCAT(s, 3), m1subm2, 0.125f); \
CONCAT(s, 3) = GiMultiplyAddScalarFloat32V2(CONCAT(s, 3), m5subm6, 8.0f); \
CONCAT(s, 4) = m##7; \
CONCAT(s, 4) = GiMultiplyAddScalarFloat32V2(CONCAT(s, 4), m1addm2, 0.0625f); \
CONCAT(s, 4) = ADDFV2(CONCAT(s, 4), m3addm4); \
CONCAT(s, 4) = GiMultiplyAddScalarFloat32V2(CONCAT(s, 4), m5addm6, 16.0f); \
} while (0) } while (0)


#if defined(__GNUC__) && !defined(__llvm__) && !defined(_MSC_VER) #if defined(__GNUC__) && !defined(__llvm__) && !defined(_MSC_VER)
@@ -314,11 +375,11 @@ struct OutputTransform5X4 {
UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); UNROLL_CALL_NOWRAPPER_D2(8, 8, cb);
#undef cb #undef cb


#define cb(i) auto m##i = Vector<float, 8>::load(transform_mid_buf + alpha * i);
#define cb(i) auto m##i = GiLoadFloat32V2(transform_mid_buf + alpha * i);
UNROLL_CALL_NOWRAPPER(8, cb); UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb #undef cb
#define cb(i) Vector<float, 8> s##i, ret##i;
UNROLL_CALL_NOWRAPPER(8, cb);
#define cb(i) GI_FLOAT32_V2_t s##i;
UNROLL_CALL_NOWRAPPER(5, cb);
#undef cb #undef cb


OUTPUT_TRANSFORM(m, s); OUTPUT_TRANSFORM(m, s);


+ 60
- 47
dnn/src/fallback/conv_bias/gi/fp32/strategy_6x3.cpp View File

@@ -3,7 +3,6 @@
#include "src/fallback/conv_bias/gi/fp32/filter_transform.h" #include "src/fallback/conv_bias/gi/fp32/filter_transform.h"
#include "src/fallback/conv_bias/gi/fp32/helper.h" #include "src/fallback/conv_bias/gi/fp32/helper.h"
#include "src/fallback/conv_bias/gi/fp32/strategy.h" #include "src/fallback/conv_bias/gi/fp32/strategy.h"
#include "src/fallback/conv_bias/gi/utils.h"
#include "src/fallback/conv_bias/winograd/winograd.h" #include "src/fallback/conv_bias/winograd/winograd.h"
#include "src/fallback/elemwise_helper/op_unary.h" #include "src/fallback/elemwise_helper/op_unary.h"
#include "src/naive/matrix_mul/matrix_mul_helper.h" #include "src/naive/matrix_mul/matrix_mul_helper.h"
@@ -27,28 +26,31 @@ namespace {
* wd6 = (d6 - 5.0 * d4 + 4.0 * d2) - 2.0 * (d1 + 0.25 * d5 - 1.25 * d3) * wd6 = (d6 - 5.0 * d4 + 4.0 * d2) - 2.0 * (d1 + 0.25 * d5 - 1.25 * d3)
* wd7 = (d7 - d1) + 5.25 * (d3 - d5) * wd7 = (d7 - d1) + 5.25 * (d3 - d5)
*/ */
#define INPUT_TRANSFORM(d, wd) \
do { \
wd##0 = (d##0 - d##6) + (d##4 - d##2) * 5.25f; \
auto tmp0 = d##6 + d##2 - d##4 * 4.25f; \
auto tmp1 = d##1 + d##5 - d##3 * 4.25f; \
wd##1 = tmp0 + tmp1; \
wd##2 = tmp0 - tmp1; \
tmp0 = d##6 + d##2 * 0.25f - d##4 * 1.25f; \
tmp1 = (d##5 + d##1 * 0.25f - d##3 * 1.25f) * 2.0f; \
wd##3 = tmp0 + tmp1; \
wd##4 = tmp0 - tmp1; \
tmp0 = d6 - d4 * 5.0f + d2 * 4.0f; \
tmp1 = (d1 + d5 * 0.25f - d3 * 1.25f) * 2.0f; \
wd##5 = tmp0 + tmp1; \
wd##6 = tmp0 - tmp1; \
wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f; \
#define INPUT_TRANSFORM(d, wd) \
do { \
wd##0 = ADDFV2(SUBFV2(d##0, d##6), MULSFV2(SUBFV2(d##4, d##2), 5.25f)); \
auto tmp0 = SUBFV2(ADDFV2(d##6, d##2), MULSFV2(d##4, 4.25f)); \
auto tmp1 = SUBFV2(ADDFV2(d##1, d##5), MULSFV2(d##3, 4.25f)); \
wd##1 = ADDFV2(tmp0, tmp1); \
wd##2 = SUBFV2(tmp0, tmp1); \
tmp0 = SUBFV2(ADDFV2(d##6, MULSFV2(d##2, 0.25f)), MULSFV2(d##4, 1.25f)); \
tmp1 = MULSFV2( \
(SUBFV2(ADDFV2(d##5, MULSFV2(d##1, 0.25f)), MULSFV2(d##3, 1.25f))), \
2.0f); \
wd##3 = ADDFV2(tmp0, tmp1); \
wd##4 = SUBFV2(tmp0, tmp1); \
tmp0 = ADDFV2(SUBFV2(d6, MULSFV2(d4, 5.0f)), MULSFV2(d2, 4.0f)); \
tmp1 = MULSFV2( \
(SUBFV2(ADDFV2(d1, MULSFV2(d5, 0.25f)), MULSFV2(d3, 1.25f))), 2.0f); \
wd##5 = ADDFV2(tmp0, tmp1); \
wd##6 = SUBFV2(tmp0, tmp1); \
wd##7 = ADDFV2(SUBFV2(d##7, d##1), MULSFV2(SUBFV2(d##3, d##5), 5.25f)); \
} while (0); } while (0);


#define GET_VECTOR_HIGH_ELEM(s, i, idx) \ #define GET_VECTOR_HIGH_ELEM(s, i, idx) \
GiExtractLane##idx##Float32(GiFixLenType2GiFloat32Type(CONCAT(s, i).value.val[1]))
GiExtractLane##idx##Float32(GiGetSubVectorFloat32V2(CONCAT(s, i), 1))
#define GET_VECTOR_LOW_ELEM(s, i, idx) \ #define GET_VECTOR_LOW_ELEM(s, i, idx) \
GiExtractLane##idx##Float32(GiFixLenType2GiFloat32Type(CONCAT(s, i).value.val[0]))
GiExtractLane##idx##Float32(GiGetSubVectorFloat32V2(CONCAT(s, i), 0))
struct InputTransform6X3 { struct InputTransform6X3 {
template <bool inner> template <bool inner>
static void transform( static void transform(
@@ -60,13 +62,13 @@ struct InputTransform6X3 {
memset(transform_mid_buf, 0, sizeof(float) * alpha * alpha); memset(transform_mid_buf, 0, sizeof(float) * alpha * alpha);
} }


#define cb(i) Vector<float, 8> d##i;
#define cb(i) GI_FLOAT32_V2_t d##i;
UNROLL_CALL_NOWRAPPER(8, cb); UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb #undef cb


if (inner) { if (inner) {
const float* input_ptr = input + ic * IH * IW + ih_start * IW + iw_start; const float* input_ptr = input + ic * IH * IW + ih_start * IW + iw_start;
#define cb(i) d##i = Vector<float, 8>::load(input_ptr + IW * i);
#define cb(i) d##i = GiLoadFloat32V2(input_ptr + IW * i);
UNROLL_CALL_NOWRAPPER(8, cb); UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb #undef cb
} else { } else {
@@ -81,22 +83,25 @@ struct InputTransform6X3 {
input[ic * IH * IW + ih * IW + iw]; input[ic * IH * IW + ih * IW + iw];
} }
} }
#define cb(i) d##i = Vector<float, 8>::load(transform_mid_buf + alpha * i);
#define cb(i) d##i = GiLoadFloat32V2(transform_mid_buf + alpha * i);
UNROLL_CALL_NOWRAPPER(8, cb); UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb #undef cb
} }


#define cb(i) Vector<float, 8> wd##i, ret##i;
#define cb(i) GI_FLOAT32_V2_t wd##i;
UNROLL_CALL_NOWRAPPER(8, cb); UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb #undef cb


INPUT_TRANSFORM(d, wd); INPUT_TRANSFORM(d, wd);


#if MEGDNN_AARCH64 #if MEGDNN_AARCH64
#define cb(i) GI_FLOAT32_V2_t ret##i;
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
TRANSPOSE_8x8(wd, d); TRANSPOSE_8x8(wd, d);
INPUT_TRANSFORM(d, ret); INPUT_TRANSFORM(d, ret);


#define cb(i) ret##i.save(transform_mid_buf + i * alpha);
#define cb(i) GiStoreFloat32V2(transform_mid_buf + i * alpha, ret##i);
UNROLL_CALL_NOWRAPPER(8, cb); UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb #undef cb


@@ -174,26 +179,34 @@ struct InputTransform6X3 {
* s4 = (m1 + m2) + 16 * (m3 + m4) + 2 * (m5 + m6) / 32 * s4 = (m1 + m2) + 16 * (m3 + m4) + 2 * (m5 + m6) / 32
* s5 = (m1 - m2) + 32 * (m3 - m4) + (m5 - m6) / 32 + m7 * s5 = (m1 - m2) + 32 * (m3 - m4) + (m5 - m6) / 32 + m7
*/ */
#define OUTPUT_TRANSFORM(m, s) \
do { \
auto m1addm2 = m##1 + m##2; \
auto m1subm2 = m##1 - m##2; \
auto m3addm4 = m##3 + m##4; \
auto m3subm4 = m##3 - m##4; \
auto m5addm6 = (m##5 + m##6) * 0.03125f; \
auto m5subm6 = (m##5 - m##6) * 0.03125f; \
s##0 = m##0; \
CONCAT(s, 0).mla(m5addm6, 32.f).add(m3addm4).add(m1addm2); \
CONCAT(s, 1) = m1subm2; \
CONCAT(s, 1).mla(m3subm4, 2.f).mla(m5subm6, 16.f); \
CONCAT(s, 2) = m1addm2; \
CONCAT(s, 2).mla(m3addm4, 4.f).mla(m5addm6, 8.f); \
CONCAT(s, 3) = m1subm2; \
CONCAT(s, 3).mla(m3subm4, 8.f).mla(m5subm6, 4.f); \
CONCAT(s, 4) = m1addm2; \
CONCAT(s, 4).mla(m3addm4, 16.f).mla(m5addm6, 2.f); \
CONCAT(s, 5) = m1subm2; \
CONCAT(s, 5).mla(m3subm4, 32.f).add(m5subm6).add(m##7); \
#define OUTPUT_TRANSFORM(m, s) \
do { \
auto m1addm2 = ADDFV2(m##1, m##2); \
auto m1subm2 = SUBFV2(m##1, m##2); \
auto m3addm4 = ADDFV2(m##3, m##4); \
auto m3subm4 = SUBFV2(m##3, m##4); \
auto m5addm6 = MULSFV2((ADDFV2(m##5, m##6)), 0.03125f); \
auto m5subm6 = MULSFV2((SUBFV2(m##5, m##6)), 0.03125f); \
s##0 = m##0; \
CONCAT(s, 0) = GiMultiplyAddScalarFloat32V2(CONCAT(s, 0), m5addm6, 32.f); \
CONCAT(s, 0) = ADDFV2(CONCAT(s, 0), m3addm4); \
CONCAT(s, 0) = ADDFV2(CONCAT(s, 0), m1addm2); \
CONCAT(s, 1) = m1subm2; \
CONCAT(s, 1) = GiMultiplyAddScalarFloat32V2(CONCAT(s, 1), m3subm4, 2.f); \
CONCAT(s, 1) = GiMultiplyAddScalarFloat32V2(CONCAT(s, 1), m5subm6, 16.f); \
CONCAT(s, 2) = m1addm2; \
CONCAT(s, 2) = GiMultiplyAddScalarFloat32V2(CONCAT(s, 2), m3addm4, 4.f); \
CONCAT(s, 2) = GiMultiplyAddScalarFloat32V2(CONCAT(s, 2), m5addm6, 8.f); \
CONCAT(s, 3) = m1subm2; \
CONCAT(s, 3) = GiMultiplyAddScalarFloat32V2(CONCAT(s, 3), m3subm4, 8.f); \
CONCAT(s, 3) = GiMultiplyAddScalarFloat32V2(CONCAT(s, 3), m5subm6, 4.f); \
CONCAT(s, 4) = m1addm2; \
CONCAT(s, 4) = GiMultiplyAddScalarFloat32V2(CONCAT(s, 4), m3addm4, 16.f); \
CONCAT(s, 4) = GiMultiplyAddScalarFloat32V2(CONCAT(s, 4), m5addm6, 2.f); \
CONCAT(s, 5) = m1subm2; \
CONCAT(s, 5) = GiMultiplyAddScalarFloat32V2(CONCAT(s, 5), m3subm4, 32.f); \
CONCAT(s, 5) = ADDFV2(CONCAT(s, 5), m5subm6); \
CONCAT(s, 5) = ADDFV2(CONCAT(s, 5), m##7); \
} while (0); } while (0);


#if defined(__GNUC__) && !defined(__llvm__) && !defined(_MSC_VER) #if defined(__GNUC__) && !defined(__llvm__) && !defined(_MSC_VER)
@@ -224,11 +237,11 @@ struct OutputTransform6X3 {
UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); UNROLL_CALL_NOWRAPPER_D2(8, 8, cb);
#undef cb #undef cb


#define cb(i) auto m##i = Vector<float, 8>::load(transform_mid_buf + alpha * i);
#define cb(i) auto m##i = GiLoadFloat32V2(transform_mid_buf + alpha * i);
UNROLL_CALL_NOWRAPPER(8, cb); UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb #undef cb
#define cb(i) Vector<float, 8> s##i, ret##i;
UNROLL_CALL_NOWRAPPER(8, cb);
#define cb(i) GI_FLOAT32_V2_t s##i;
UNROLL_CALL_NOWRAPPER(6, cb);
#undef cb #undef cb


OUTPUT_TRANSFORM(m, s); OUTPUT_TRANSFORM(m, s);


+ 161
- 63
dnn/src/fallback/conv_bias/gi/fp32/strategy_6x3_4x4.cpp View File

@@ -4,7 +4,6 @@
#include "src/fallback/conv_bias/gi/fp32/filter_transform.h" #include "src/fallback/conv_bias/gi/fp32/filter_transform.h"
#include "src/fallback/conv_bias/gi/fp32/helper.h" #include "src/fallback/conv_bias/gi/fp32/helper.h"
#include "src/fallback/conv_bias/gi/fp32/strategy.h" #include "src/fallback/conv_bias/gi/fp32/strategy.h"
#include "src/fallback/conv_bias/gi/utils.h"
#include "src/fallback/conv_bias/winograd/winograd.h" #include "src/fallback/conv_bias/winograd/winograd.h"
#include "src/fallback/elemwise_helper/op_unary.h" #include "src/fallback/elemwise_helper/op_unary.h"


@@ -76,8 +75,7 @@ struct InputTransform6X3 {
size_t nr_units_in_tile, size_t ic, size_t IC) { size_t nr_units_in_tile, size_t ic, size_t IC) {
constexpr size_t alpha = 6 + 3 - 1; constexpr size_t alpha = 6 + 3 - 1;
// BT * d * B // BT * d * B
#define cb(m, n) \
Vector<float, 4> d##m##n = Vector<float, 4>::load(patchT + m * 8 * 4 + n * 4);
#define cb(m, n) GI_FLOAT32_t d##m##n = GiLoadFloat32(patchT + m * 8 * 4 + n * 4);


UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); UNROLL_CALL_NOWRAPPER_D2(8, 8, cb);
#undef cb #undef cb
@@ -91,36 +89,103 @@ struct InputTransform6X3 {
//! 0 1 -1 2 -2 0.5 -0.5 -5.25 //! 0 1 -1 2 -2 0.5 -0.5 -5.25
//! -1 1 1 1 1 1 1 0 //! -1 1 1 1 1 1 1 0
//! 0 0 0 0 0 0 0 1 //! 0 0 0 0 0 0 0 1
/*
* auto t0##m = d0##m + (d4##m - d2##m) * 5.25f - d6##m;
* auto t1##m = d1##m + d2##m + d5##m + d6##m - (d3##m + d4##m) * 4.25f;
* auto t2##m = d2##m + d6##m - (d1##m + d5##m) + (d3##m - d4##m) * 4.25f;
* auto t3##m = d1##m * 0.5f + d2##m * 0.25f - d3##m * 2.5f - d4##m * 1.25f +
* d5##m * 2.f + d6##m;
* auto t4##m = d1##m * (-0.5f) + d2##m * 0.25f + d3##m * 2.5f - d4##m * 1.25f -
* d5##m * 2.f + d6##m;
* auto t5##m = d1##m * 2.f + d2##m * 4.f - d3##m * 2.5f - d4##m * 5.f +
* d5##m * 0.5f + d6##m;
* auto t6##m = d1##m * (-2.f) + d2##m * 4.f + d3##m * 2.5f - d4##m * 5.f -
* d5##m * 0.5f + d6##m;
* auto t7##m = (d7##m - d1##m) + (d3##m - d5##m) * 5.25f;
*/
#define cb(m) \ #define cb(m) \
auto t0##m = d0##m + (d4##m - d2##m) * 5.25f - d6##m; \
auto t1##m = d1##m + d2##m + d5##m + d6##m - (d3##m + d4##m) * 4.25f; \
auto t2##m = d2##m + d6##m - (d1##m + d5##m) + (d3##m - d4##m) * 4.25f; \
auto t3##m = d1##m * 0.5f + d2##m * 0.25f - d3##m * 2.5f - d4##m * 1.25f + \
d5##m * 2.f + d6##m; \
auto t4##m = d1##m * (-0.5f) + d2##m * 0.25f + d3##m * 2.5f - d4##m * 1.25f - \
d5##m * 2.f + d6##m; \
auto t5##m = d1##m * 2.f + d2##m * 4.f - d3##m * 2.5f - d4##m * 5.f + \
d5##m * 0.5f + d6##m; \
auto t6##m = d1##m * (-2.f) + d2##m * 4.f + d3##m * 2.5f - d4##m * 5.f - \
d5##m * 0.5f + d6##m; \
auto t7##m = (d7##m - d1##m) + (d3##m - d5##m) * 5.25f;
auto t0##m = SUBF(ADDF(d0##m, MULSF(SUBF(d4##m, d2##m), 5.25f)), d6##m); \
auto t1##m = \
SUBF(ADDF(ADDF(ADDF(d1##m, d2##m), d5##m), d6##m), \
MULSF(ADDF(d3##m, d4##m), 4.25f)); \
auto t2##m = \
ADDF(SUBF(ADDF(d2##m, d6##m), ADDF(d1##m, d5##m)), \
MULSF(SUBF(d3##m, d4##m), 4.25f)); \
auto t3##m = \
ADDF(ADDF(SUBF(SUBF(ADDF(MULSF(d1##m, 0.5f), MULSF(d2##m, 0.25f)), \
MULSF(d3##m, 2.5f)), \
MULSF(d4##m, 1.25f)), \
MULSF(d5##m, 2.f)), \
d6##m); \
auto t4##m = \
ADDF(SUBF(SUBF(ADDF(ADDF(MULSF(d1##m, (-0.5f)), MULSF(d2##m, 0.25f)), \
MULSF(d3##m, 2.5f)), \
MULSF(d4##m, 1.25f)), \
MULSF(d5##m, 2.f)), \
d6##m); \
auto t5##m = \
ADDF(ADDF(SUBF(SUBF(ADDF(MULSF(d1##m, 2.f), MULSF(d2##m, 4.f)), \
MULSF(d3##m, 2.5f)), \
MULSF(d4##m, 5.f)), \
MULSF(d5##m, 0.5f)), \
d6##m); \
auto t6##m = \
ADDF(SUBF(SUBF(ADDF(ADDF(MULSF(d1##m, (-2.f)), MULSF(d2##m, 4.f)), \
MULSF(d3##m, 2.5f)), \
MULSF(d4##m, 5.f)), \
MULSF(d5##m, 0.5f)), \
d6##m); \
auto t7##m = ADDF(SUBF(d7##m, d1##m), MULSF(SUBF(d3##m, d5##m), 5.25f));


UNROLL_CALL_NOWRAPPER(8, cb); UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb #undef cb


#define cb(m) \
d##m##0 = t##m##0 + (t##m##4 - t##m##2) * 5.25f - t##m##6; \
d##m##1 = t##m##1 + t##m##2 + t##m##5 + t##m##6 - (t##m##3 + t##m##4) * 4.25f; \
d##m##2 = t##m##2 + t##m##6 - (t##m##1 + t##m##5) + (t##m##3 - t##m##4) * 4.25f; \
d##m##3 = t##m##1 * 0.5f + t##m##2 * 0.25f - t##m##3 * 2.5f - t##m##4 * 1.25f + \
t##m##5 * 2.f + t##m##6; \
d##m##4 = t##m##1 * (-0.5f) + t##m##2 * 0.25f + t##m##3 * 2.5f - t##m##4 * 1.25f - \
t##m##5 * 2.f + t##m##6; \
d##m##5 = t##m##1 * 2.f + t##m##2 * 4.f - t##m##3 * 2.5f - t##m##4 * 5.f + \
t##m##5 * 0.5f + t##m##6; \
d##m##6 = t##m##1 * (-2.f) + t##m##2 * 4.f + t##m##3 * 2.5f - t##m##4 * 5.f - \
t##m##5 * 0.5f + t##m##6; \
d##m##7 = (t##m##7 - t##m##1) + (t##m##3 - t##m##5) * 5.25f;
/*
* d##m##0 = t##m##0 + (t##m##4 - t##m##2) * 5.25f - t##m##6;
* d##m##1 = t##m##1 + t##m##2 + t##m##5 + t##m##6 - (t##m##3 + t##m##4)
* * 4.25f; d##m##2 = t##m##2 + t##m##6 - (t##m##1 + t##m##5) + (t##m##3 -
* t##m##4) * 4.25f; d##m##3 = t##m##1 * 0.5f + t##m##2 * 0.25f - t##m##3 * 2.5f
* - t##m##4 * 1.25f + t##m##5 * 2.f + t##m##6; d##m##4 = t##m##1 * (-0.5f) +
* t##m##2 * 0.25f + t##m##3 * 2.5f - t##m##4 * 1.25f - t##m##5 * 2.f + t##m##6;
* d##m##5 = t##m##1 * 2.f + t##m##2 * 4.f - t##m##3 * 2.5f - t##m##4 * 5.f +
* t##m##5 * 0.5f + t##m##6;
* d##m##6 = t##m##1 * (-2.f) + t##m##2 * 4.f + t##m##3 * 2.5f - t##m##4 * 5.f -
* t##m##5 * 0.5f + t##m##6;
* d##m##7 = (t##m##7 - t##m##1) + (t##m##3 - t##m##5) * 5.25f;
*/
#define cb(m) \
d##m##0 = SUBF(ADDF(t##m##0, MULSF(SUBF(t##m##4, t##m##2), 5.25f)), t##m##6); \
d##m##1 = \
SUBF(ADDF(ADDF(ADDF(t##m##1, t##m##2), t##m##5), t##m##6), \
MULSF(ADDF(t##m##3, t##m##4), 4.25f)); \
d##m##2 = \
ADDF(SUBF(ADDF(t##m##2, t##m##6), ADDF(t##m##1, t##m##5)), \
MULSF(SUBF(t##m##3, t##m##4), 4.25f)); \
d##m##3 = \
ADDF(ADDF(SUBF(SUBF(ADDF(MULSF(t##m##1, 0.5f), MULSF(t##m##2, 0.25f)), \
MULSF(t##m##3, 2.5f)), \
MULSF(t##m##4, 1.25f)), \
MULSF(t##m##5, 2.f)), \
t##m##6); \
d##m##4 = \
ADDF(SUBF(SUBF(ADDF(ADDF(MULSF(t##m##1, (-0.5f)), MULSF(t##m##2, 0.25f)), \
MULSF(t##m##3, 2.5f)), \
MULSF(t##m##4, 1.25f)), \
MULSF(t##m##5, 2.f)), \
t##m##6); \
d##m##5 = \
ADDF(ADDF(SUBF(SUBF(ADDF(MULSF(t##m##1, 2.f), MULSF(t##m##2, 4.f)), \
MULSF(t##m##3, 2.5f)), \
MULSF(t##m##4, 5.f)), \
MULSF(t##m##5, 0.5f)), \
t##m##6); \
d##m##6 = \
ADDF(SUBF(SUBF(ADDF(ADDF(MULSF(t##m##1, (-2.f)), MULSF(t##m##2, 4.f)), \
MULSF(t##m##3, 2.5f)), \
MULSF(t##m##4, 5.f)), \
MULSF(t##m##5, 0.5f)), \
t##m##6); \
d##m##7 = ADDF(SUBF(t##m##7, t##m##1), MULSF(SUBF(t##m##3, t##m##5), 5.25f));


UNROLL_CALL_NOWRAPPER(8, cb); UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb #undef cb
@@ -128,9 +193,10 @@ struct InputTransform6X3 {
size_t ICB = IC / 4; size_t ICB = IC / 4;
size_t icb = ic / 4; size_t icb = ic / 4;
#define cb(m, n) \ #define cb(m, n) \
d##m##n.save( \
GiStoreFloat32( \
input_transform_buf + (m * alpha + n) * ICB * nr_units_in_tile * 4 + \ input_transform_buf + (m * alpha + n) * ICB * nr_units_in_tile * 4 + \
icb * nr_units_in_tile * 4 + unit_idx * 4);
icb * nr_units_in_tile * 4 + unit_idx * 4, \
d##m##n);
UNROLL_CALL_NOWRAPPER_D2(8, 8, cb) UNROLL_CALL_NOWRAPPER_D2(8, 8, cb)
#undef cb #undef cb
} }
@@ -152,7 +218,7 @@ struct OutputTransform6X3 {
size_t ocb = oc_index / 4; size_t ocb = oc_index / 4;


#define cb(m, n) \ #define cb(m, n) \
auto v##m##n = Vector<float, 4>::load( \
auto v##m##n = GiLoadFloat32( \
output_transform_buf + (m * alpha + n) * OCB * nr_units_in_tile * 4 + \ output_transform_buf + (m * alpha + n) * OCB * nr_units_in_tile * 4 + \
ocb * nr_units_in_tile * 4 + unit_idx * 4); ocb * nr_units_in_tile * 4 + unit_idx * 4);
UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); UNROLL_CALL_NOWRAPPER_D2(8, 8, cb);
@@ -171,56 +237,88 @@ struct OutputTransform6X3 {
* 0 0.0 0 0 0 1 * 0 0.0 0 0 0 1
*/ */


Vector<float, 4> v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, v5subv6;
#define cb(m) \
v1addv2 = v1##m + v2##m; \
v1subv2 = v1##m - v2##m; \
v3addv4 = v3##m + v4##m; \
v3subv4 = v3##m - v4##m; \
v5addv6 = v5##m + v6##m; \
v5subv6 = v5##m - v6##m; \
auto t0##m = v0##m + v1addv2 + v3addv4 + v5addv6; \
auto t1##m = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f; \
auto t2##m = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f; \
auto t3##m = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f; \
auto t4##m = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; \
auto t5##m = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + v7##m;
/*
* v1addv2 = v1##m + v2##m;
* v1subv2 = v1##m - v2##m;
* v3addv4 = v3##m + v4##m;
* v3subv4 = v3##m - v4##m;
* v5addv6 = v5##m + v6##m;
* v5subv6 = v5##m - v6##m;
* auto t0##m = v0##m + v1addv2 + v3addv4 + v5addv6;
* auto t1##m = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f;
* auto t2##m = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f;
* auto t3##m = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f;
* auto t4##m = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f;
* auto t5##m = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + v7##m;
*/
GI_FLOAT32_t v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, v5subv6;
#define cb(m) \
v1addv2 = ADDF(v1##m, v2##m); \
v1subv2 = SUBF(v1##m, v2##m); \
v3addv4 = ADDF(v3##m, v4##m); \
v3subv4 = SUBF(v3##m, v4##m); \
v5addv6 = ADDF(v5##m, v6##m); \
v5subv6 = SUBF(v5##m, v6##m); \
auto t0##m = ADDF(ADDF(ADDF(v0##m, v1addv2), v3addv4), v5addv6); \
auto t1##m = ADDF(ADDF(v1subv2, MULSF(v3subv4, 2.f)), MULSF(v5subv6, 0.5f)); \
auto t2##m = ADDF(ADDF(v1addv2, MULSF(v3addv4, 4.f)), MULSF(v5addv6, 0.25f)); \
auto t3##m = ADDF(ADDF(v1subv2, MULSF(v3subv4, 8.f)), MULSF(v5subv6, 0.125f)); \
auto t4##m = ADDF(ADDF(v1addv2, MULSF(v3addv4, 16.f)), MULSF(v5addv6, 0.0625f)); \
auto t5##m = \
ADDF(ADDF(ADDF(v1subv2, MULSF(v3subv4, 32.f)), MULSF(v5subv6, 0.03125f)), \
v7##m);


UNROLL_CALL_NOWRAPPER(8, cb); UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb #undef cb


#define cb(m) \
v1addv2 = t##m##1 + t##m##2; \
v1subv2 = t##m##1 - t##m##2; \
v3addv4 = t##m##3 + t##m##4; \
v3subv4 = t##m##3 - t##m##4; \
v5addv6 = t##m##5 + t##m##6; \
v5subv6 = t##m##5 - t##m##6; \
v##m##0 = t##m##0 + v1addv2 + v3addv4 + v5addv6; \
v##m##1 = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f; \
v##m##2 = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f; \
v##m##3 = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f; \
v##m##4 = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; \
v##m##5 = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + t##m##7;
/*
* v1addv2 = t##m##1 + t##m##2;
* v1subv2 = t##m##1 - t##m##2;
* v3addv4 = t##m##3 + t##m##4;
* v3subv4 = t##m##3 - t##m##4;
* v5addv6 = t##m##5 + t##m##6;
* v5subv6 = t##m##5 - t##m##6;
* v##m##0 = t##m##0 + v1addv2 + v3addv4 + v5addv6;
* v##m##1 = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f;
* v##m##2 = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f;
* v##m##3 = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f;
* v##m##4 = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f;
* v##m##5 = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + t##m##7;
*/
#define cb(m) \
v1addv2 = ADDF(t##m##1, t##m##2); \
v1subv2 = SUBF(t##m##1, t##m##2); \
v3addv4 = ADDF(t##m##3, t##m##4); \
v3subv4 = SUBF(t##m##3, t##m##4); \
v5addv6 = ADDF(t##m##5, t##m##6); \
v5subv6 = SUBF(t##m##5, t##m##6); \
v##m##0 = ADDF(ADDF(ADDF(t##m##0, v1addv2), v3addv4), v5addv6); \
v##m##1 = ADDF(ADDF(v1subv2, MULSF(v3subv4, 2.f)), MULSF(v5subv6, 0.5f)); \
v##m##2 = ADDF(ADDF(v1addv2, MULSF(v3addv4, 4.f)), MULSF(v5addv6, 0.25f)); \
v##m##3 = ADDF(ADDF(v1subv2, MULSF(v3subv4, 8.f)), MULSF(v5subv6, 0.125f)); \
v##m##4 = ADDF(ADDF(v1addv2, MULSF(v3addv4, 16.f)), MULSF(v5addv6, 0.0625f)); \
v##m##5 = \
ADDF(ADDF(ADDF(v1subv2, MULSF(v3subv4, 32.f)), MULSF(v5subv6, 0.03125f)), \
t##m##7);


UNROLL_CALL_NOWRAPPER(6, cb); UNROLL_CALL_NOWRAPPER(6, cb);
#undef cb #undef cb


Vector<float, 4> vbias;
GI_FLOAT32_t vbias;
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
vbias = Vector<float, 4>::load(bias + oc);
vbias = GiLoadFloat32(bias + oc);


#define cb(m, n) v##m##n += vbias;
#define cb(m, n) v##m##n = GiAddFloat32(v##m##n, vbias);
UNROLL_CALL_RAW_D2(6, 6, cb); UNROLL_CALL_RAW_D2(6, 6, cb);
#undef cb #undef cb
} }
if (bmode != BiasMode::BIAS) { if (bmode != BiasMode::BIAS) {
#define cb(m, n) v##m##n = op(GiFixLenType2GiFloat32Type(CONCAT(v##m, n).value));
#define cb(m, n) v##m##n = op(CONCAT(v##m, n));
UNROLL_CALL_RAW_D2(6, 6, cb); UNROLL_CALL_RAW_D2(6, 6, cb);
#undef cb #undef cb
} }


#define cb(m, n) CONCAT(v##m, n).save(transform_mid_buf + (m * 6 + n) * 4);
#define cb(m, n) GiStoreFloat32(transform_mid_buf + (m * 6 + n) * 4, CONCAT(v##m, n));
UNROLL_CALL_RAW_D2(6, 6, cb); UNROLL_CALL_RAW_D2(6, 6, cb);
#undef cb #undef cb




+ 55
- 51
dnn/src/fallback/conv_bias/gi/fp32/strategy_f23_mk4_nchw44.cpp View File

@@ -1,7 +1,6 @@
#include "src/common/unroll_macro.h" #include "src/common/unroll_macro.h"
#include "src/common/utils.h" #include "src/common/utils.h"
#include "src/fallback/conv_bias/gi/fp32/strategy.h" #include "src/fallback/conv_bias/gi/fp32/strategy.h"
#include "src/fallback/conv_bias/gi/utils.h"
#include "src/fallback/conv_bias/winograd/winograd.h" #include "src/fallback/conv_bias/winograd/winograd.h"


#include "src/fallback/conv_bias/gi/fp32/helper.h" #include "src/fallback/conv_bias/gi/fp32/helper.h"
@@ -29,7 +28,7 @@ struct InputTransformF23_NCHW44 {
size_t iw4_start = iw_start * pack_size; size_t iw4_start = iw_start * pack_size;
size_t ICB = IC / pack_size; size_t ICB = IC / pack_size;


#define cb(m, n) Vector<float, 4> d##m##n;
#define cb(m, n) GI_FLOAT32_t d##m##n;
UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); UNROLL_CALL_NOWRAPPER_D2(4, 4, cb);
#undef cb #undef cb


@@ -40,7 +39,7 @@ struct InputTransformF23_NCHW44 {
MEGDNN_MARK_USED_VAR(patchT); MEGDNN_MARK_USED_VAR(patchT);
const float* input_ptr = const float* input_ptr =
input + icb * IH * IW4 + ih_start * IW4 + iw4_start; input + icb * IH * IW4 + ih_start * IW4 + iw4_start;
#define cb(n, m) d##m##n = Vector<float, 4>::load(input_ptr + pack_size * n);
#define cb(n, m) d##m##n = GiLoadFloat32(input_ptr + pack_size * n);


UNROLL_CALL_RAW(4, cb, 0); UNROLL_CALL_RAW(4, cb, 0);
input_ptr += IW4; input_ptr += IW4;
@@ -66,7 +65,7 @@ struct InputTransformF23_NCHW44 {
} }
} }
#define cb(m, n) \ #define cb(m, n) \
d##m##n = Vector<float, 4>::load(patchT + m * alpha * pack_size + n * pack_size);
d##m##n = GiLoadFloat32(patchT + m * alpha * pack_size + n * pack_size);
UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); UNROLL_CALL_NOWRAPPER_D2(4, 4, cb);
#undef cb #undef cb
} }
@@ -74,29 +73,30 @@ struct InputTransformF23_NCHW44 {
//! 0 1 1 0 d10 d11 d12 d13 0 1 -1 -1 //! 0 1 1 0 d10 d11 d12 d13 0 1 -1 -1
//! 0 -1 1 0 d20 d21 d22 d23 -1 1 1 0 //! 0 -1 1 0 d20 d21 d22 d23 -1 1 1 0
//! 0 -1 0 1 d30 d31 d32 d33 0 0 0 1 //! 0 -1 0 1 d30 d31 d32 d33 0 0 0 1
#define cb(m) \
auto t0##m = d0##m - d2##m; \
auto t1##m = d1##m + d2##m; \
auto t2##m = d2##m - d1##m; \
auto t3##m = d3##m - d1##m;
#define cb(m) \
auto t0##m = SUBF(d0##m, d2##m); \
auto t1##m = ADDF(d1##m, d2##m); \
auto t2##m = SUBF(d2##m, d1##m); \
auto t3##m = SUBF(d3##m, d1##m);


UNROLL_CALL_NOWRAPPER(4, cb); UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb #undef cb


#define cb(m) \
d##m##0 = t##m##0 - t##m##2; \
d##m##1 = t##m##1 + t##m##2; \
d##m##2 = t##m##2 - t##m##1; \
d##m##3 = t##m##3 - t##m##1;
#define cb(m) \
d##m##0 = SUBF(t##m##0, t##m##2); \
d##m##1 = ADDF(t##m##1, t##m##2); \
d##m##2 = SUBF(t##m##2, t##m##1); \
d##m##3 = SUBF(t##m##3, t##m##1);


UNROLL_CALL_NOWRAPPER(4, cb); UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb #undef cb


#define cb(m, n) \
d##m##n.save( \
input_transform_buf + \
(m * alpha + n) * ICB * nr_units_in_tile * pack_size + \
icb * nr_units_in_tile * pack_size + unit_idx * pack_size);
#define cb(m, n) \
GiStoreFloat32( \
input_transform_buf + \
(m * alpha + n) * ICB * nr_units_in_tile * pack_size + \
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \
d##m##n);
UNROLL_CALL_NOWRAPPER_D2(4, 4, cb) UNROLL_CALL_NOWRAPPER_D2(4, 4, cb)
#undef cb #undef cb
} }
@@ -118,7 +118,7 @@ struct OutputTransformF23_NCHW44 {
size_t ocb = oc_index / pack_size; size_t ocb = oc_index / pack_size;


#define cb(m, n) \ #define cb(m, n) \
auto v##m##n = Vector<float, 4>::load( \
auto v##m##n = GiLoadFloat32( \
output_transform_buf + \ output_transform_buf + \
(m * alpha + n) * OCB * nr_units_in_tile * pack_size + \ (m * alpha + n) * OCB * nr_units_in_tile * pack_size + \
ocb * nr_units_in_tile * pack_size + unit_idx * pack_size); ocb * nr_units_in_tile * pack_size + unit_idx * pack_size);
@@ -130,30 +130,30 @@ struct OutputTransformF23_NCHW44 {
//! v20 v21 v22 v23 1 -1 //! v20 v21 v22 v23 1 -1
//! v30 v31 v32 v33 0 1 //! v30 v31 v32 v33 0 1


#define cb(m) \
auto t0##m = v0##m + v1##m + v2##m; \
auto t1##m = v1##m - v2##m + v3##m;
#define cb(m) \
auto t0##m = ADDF(ADDF(v0##m, v1##m), v2##m); \
auto t1##m = ADDF(SUBF(v1##m, v2##m), v3##m);


UNROLL_CALL_NOWRAPPER(4, cb); UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb #undef cb


#define cb(m) \
v##m##0 = t##m##0 + t##m##1 + t##m##2; \
v##m##1 = t##m##1 - t##m##2 + t##m##3;
#define cb(m) \
v##m##0 = ADDF(ADDF(t##m##0, t##m##1), t##m##2); \
v##m##1 = ADDF(SUBF(t##m##1, t##m##2), t##m##3);


UNROLL_CALL_NOWRAPPER(2, cb); UNROLL_CALL_NOWRAPPER(2, cb);
#undef cb #undef cb


Vector<float, 4> vbias;
GI_FLOAT32_t vbias;
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
vbias = Vector<float, 4>::load(bias + oc);
vbias = GiLoadFloat32(bias + oc);


#define cb(m, n) v##m##n += vbias;
#define cb(m, n) v##m##n = ADDF(v##m##n, vbias);
UNROLL_CALL_RAW_D2(2, 2, cb); UNROLL_CALL_RAW_D2(2, 2, cb);
#undef cb #undef cb
} }
if (bmode != BiasMode::BIAS) { if (bmode != BiasMode::BIAS) {
#define cb(m, n) v##m##n = op(GiFixLenType2GiFloat32Type(CONCAT(v##m, n).value));
#define cb(m, n) v##m##n = op(CONCAT(v##m, n));
UNROLL_CALL_RAW_D2(2, 2, cb); UNROLL_CALL_RAW_D2(2, 2, cb);
#undef cb #undef cb
} }
@@ -163,12 +163,15 @@ struct OutputTransformF23_NCHW44 {
size_t ow = ow_start + owo; \ size_t ow = ow_start + owo; \
if (oh < OH && ow < OW) { \ if (oh < OH && ow < OW) { \
if (bmode == BiasMode::BIAS) { \ if (bmode == BiasMode::BIAS) { \
v##oho##owo += Vector<float, 4>::load( \
bias + oc * OH * OW + oh * OW * pack_size + ow * pack_size); \
v##oho##owo = op(GiFixLenType2GiFloat32Type(v##oho##owo.value)); \
v##oho##owo = ADDF( \
v##oho##owo, GiLoadFloat32( \
bias + oc * OH * OW + \
oh * OW * pack_size + ow * pack_size)); \
v##oho##owo = op(v##oho##owo); \
} \ } \
v##oho##owo.save( \
output + oc * OH * OW + oh * OW * pack_size + ow * pack_size); \
GiStoreFloat32( \
output + oc * OH * OW + oh * OW * pack_size + ow * pack_size, \
v##oho##owo); \
} \ } \
} while (0); } while (0);
UNROLL_CALL_RAW_D2(2, 2, out_save); UNROLL_CALL_RAW_D2(2, 2, out_save);
@@ -211,29 +214,30 @@ void winograd_F23_mk4_f_nchw44::filter(
pack_size * pack_size + pack_size * pack_size +
ic_inner * pack_size; ic_inner * pack_size;


#define cb(m, n) \
Vector<float, 4> g##m##n = Vector<float, 4>::load( \
fptr + (m * KERNEL_SIZE + n) * pack_size * pack_size);
#define cb(m, n) \
GI_FLOAT32_t g##m##n = \
GiLoadFloat32(fptr + (m * KERNEL_SIZE + n) * pack_size * pack_size);
UNROLL_CALL_NOWRAPPER_D2(3, 3, cb) UNROLL_CALL_NOWRAPPER_D2(3, 3, cb)
#undef cb #undef cb


#define FILTER_TRANSFORM(n, wd, g) \
auto wd##n##0 = g##0##n; \
tmp0 = (g##0##n + g##2##n) * 0.5; \
tmp1 = g##1##n * 0.5; \
auto wd##n##1 = tmp0 + tmp1; \
auto wd##n##2 = tmp0 - tmp1; \
#define FILTER_TRANSFORM(n, wd, g) \
auto wd##n##0 = g##0##n; \
tmp0 = MULSF(ADDF(g##0##n, g##2##n), 0.5); \
tmp1 = MULSF(g##1##n, 0.5); \
auto wd##n##1 = ADDF(tmp0, tmp1); \
auto wd##n##2 = SUBF(tmp0, tmp1); \
auto wd##n##3 = g##2##n; auto wd##n##3 = g##2##n;
Vector<float, 4> tmp0, tmp1;
GI_FLOAT32_t tmp0, tmp1;
UNROLL_CALL_RAW(3, FILTER_TRANSFORM, wd, g); UNROLL_CALL_RAW(3, FILTER_TRANSFORM, wd, g);
UNROLL_CALL_RAW(4, FILTER_TRANSFORM, ret, wd); UNROLL_CALL_RAW(4, FILTER_TRANSFORM, ret, wd);
#undef FILTER_TRANSFORM #undef FILTER_TRANSFORM
#define cb_save(m, n) \
ret##m##n.save( \
filter_transform_buf + \
(m * ALPHA + n) * OCB * ICB * pack_size * pack_size + \
ocb * ICB * pack_size * pack_size + icb * pack_size * pack_size + \
ic_inner * pack_size);
#define cb_save(m, n) \
GiStoreFloat32( \
filter_transform_buf + \
(m * ALPHA + n) * OCB * ICB * pack_size * pack_size + \
ocb * ICB * pack_size * pack_size + icb * pack_size * pack_size + \
ic_inner * pack_size, \
ret##m##n);
UNROLL_CALL_NOWRAPPER_D2(4, 4, cb_save) UNROLL_CALL_NOWRAPPER_D2(4, 4, cb_save)
#undef cb_save #undef cb_save
} }


+ 125
- 74
dnn/src/fallback/conv_bias/gi/fp32/strategy_f63_mk4_nchw44.cpp View File

@@ -4,7 +4,6 @@
#include "src/fallback/conv_bias/gi/fp32/filter_transform.h" #include "src/fallback/conv_bias/gi/fp32/filter_transform.h"
#include "src/fallback/conv_bias/gi/fp32/helper.h" #include "src/fallback/conv_bias/gi/fp32/helper.h"
#include "src/fallback/conv_bias/gi/fp32/strategy.h" #include "src/fallback/conv_bias/gi/fp32/strategy.h"
#include "src/fallback/conv_bias/gi/utils.h"
#include "src/fallback/conv_bias/winograd/winograd.h" #include "src/fallback/conv_bias/winograd/winograd.h"
#include "src/fallback/elemwise_helper/op_unary.h" #include "src/fallback/elemwise_helper/op_unary.h"


@@ -114,17 +113,17 @@ struct InputTransformF63_NCHW44 {
auto t##i##4 = d6; \ auto t##i##4 = d6; \
auto t##i##5 = d6; \ auto t##i##5 = d6; \
auto t##i##6 = d6; \ auto t##i##6 = d6; \
t##i##0 = GiSubtractFloat32(t##i##0, d6); \
t##i##1 = GiAddFloat32(t##i##1, d1); \
t##i##2 = GiSubtractFloat32(t##i##2, d1); \
t##i##0 = SUBF(t##i##0, d6); \
t##i##1 = ADDF(t##i##1, d1); \
t##i##2 = SUBF(t##i##2, d1); \
t##i##3 = MADD(t##i##3, d1, v0, 2); \ t##i##3 = MADD(t##i##3, d1, v0, 2); \
t##i##4 = MSUB(t##i##4, d1, v0, 2); \ t##i##4 = MSUB(t##i##4, d1, v0, 2); \
t##i##5 = MADD(t##i##5, d1, v1, 2); \ t##i##5 = MADD(t##i##5, d1, v1, 2); \
t##i##6 = MSUB(t##i##6, d1, v1, 2); \ t##i##6 = MSUB(t##i##6, d1, v1, 2); \
t##i##7 = GiSubtractFloat32(t##i##7, d1); \
t##i##7 = SUBF(t##i##7, d1); \
t##i##0 = MSUB(t##i##0, d2, v0, 0); \ t##i##0 = MSUB(t##i##0, d2, v0, 0); \
t##i##1 = GiAddFloat32(t##i##1, d2); \
t##i##2 = GiAddFloat32(t##i##2, d2); \
t##i##1 = ADDF(t##i##1, d2); \
t##i##2 = ADDF(t##i##2, d2); \
t##i##3 = MADD(t##i##3, d2, v0, 3); \ t##i##3 = MADD(t##i##3, d2, v0, 3); \
t##i##4 = MADD(t##i##4, d2, v0, 3); \ t##i##4 = MADD(t##i##4, d2, v0, 3); \
t##i##5 = MADD(t##i##5, d2, v1, 3); \ t##i##5 = MADD(t##i##5, d2, v1, 3); \
@@ -143,8 +142,8 @@ struct InputTransformF63_NCHW44 {
t##i##4 = MSUB(t##i##4, d4, v1, 1); \ t##i##4 = MSUB(t##i##4, d4, v1, 1); \
t##i##5 = MSUB(t##i##5, d4, v2, 0); \ t##i##5 = MSUB(t##i##5, d4, v2, 0); \
t##i##6 = MSUB(t##i##6, d4, v2, 0); \ t##i##6 = MSUB(t##i##6, d4, v2, 0); \
t##i##1 = GiAddFloat32(t##i##1, d5); \
t##i##2 = GiSubtractFloat32(t##i##2, d5); \
t##i##1 = ADDF(t##i##1, d5); \
t##i##2 = SUBF(t##i##2, d5); \
t##i##3 = MADD(t##i##3, d5, v1, 2); \ t##i##3 = MADD(t##i##3, d5, v1, 2); \
t##i##4 = MSUB(t##i##4, d5, v1, 2); \ t##i##4 = MSUB(t##i##4, d5, v1, 2); \
t##i##5 = MADD(t##i##5, d5, v0, 2); \ t##i##5 = MADD(t##i##5, d5, v0, 2); \
@@ -162,17 +161,17 @@ struct InputTransformF63_NCHW44 {
d5 = t6##i; \ d5 = t6##i; \
d6 = t6##i; \ d6 = t6##i; \
d7 = t7##i; \ d7 = t7##i; \
d0 = GiSubtractFloat32(d0, t6##i); \
d1 = GiAddFloat32(d1, t1##i); \
d2 = GiSubtractFloat32(d2, t1##i); \
d0 = SUBF(d0, t6##i); \
d1 = ADDF(d1, t1##i); \
d2 = SUBF(d2, t1##i); \
d3 = MADD(d3, t1##i, v0, 2); \ d3 = MADD(d3, t1##i, v0, 2); \
d4 = MSUB(d4, t1##i, v0, 2); \ d4 = MSUB(d4, t1##i, v0, 2); \
d5 = MADD(d5, t1##i, v1, 2); \ d5 = MADD(d5, t1##i, v1, 2); \
d6 = MSUB(d6, t1##i, v1, 2); \ d6 = MSUB(d6, t1##i, v1, 2); \
d7 = GiSubtractFloat32(d7, t1##i); \
d7 = SUBF(d7, t1##i); \
d0 = MSUB(d0, t2##i, v0, 0); \ d0 = MSUB(d0, t2##i, v0, 0); \
d1 = GiAddFloat32(d1, t2##i); \
d2 = GiAddFloat32(d2, t2##i); \
d1 = ADDF(d1, t2##i); \
d2 = ADDF(d2, t2##i); \
d3 = MADD(d3, t2##i, v0, 3); \ d3 = MADD(d3, t2##i, v0, 3); \
d4 = MADD(d4, t2##i, v0, 3); \ d4 = MADD(d4, t2##i, v0, 3); \
d5 = MADD(d5, t2##i, v1, 3); \ d5 = MADD(d5, t2##i, v1, 3); \
@@ -191,8 +190,8 @@ struct InputTransformF63_NCHW44 {
d4 = MSUB(d4, t4##i, v1, 1); \ d4 = MSUB(d4, t4##i, v1, 1); \
d5 = MSUB(d5, t4##i, v2, 0); \ d5 = MSUB(d5, t4##i, v2, 0); \
d6 = MSUB(d6, t4##i, v2, 0); \ d6 = MSUB(d6, t4##i, v2, 0); \
d1 = GiAddFloat32(d1, t5##i); \
d2 = GiSubtractFloat32(d2, t5##i); \
d1 = ADDF(d1, t5##i); \
d2 = SUBF(d2, t5##i); \
d3 = MADD(d3, t5##i, v1, 2); \ d3 = MADD(d3, t5##i, v1, 2); \
d4 = MSUB(d4, t5##i, v1, 2); \ d4 = MSUB(d4, t5##i, v1, 2); \
d5 = MADD(d5, t5##i, v0, 2); \ d5 = MADD(d5, t5##i, v0, 2); \
@@ -261,7 +260,7 @@ struct OutputTransformF63_NCHW44 {
size_t ocb = oc_index / pack_size; size_t ocb = oc_index / pack_size;


#define cb(m, n) \ #define cb(m, n) \
auto v##m##n = Vector<float, 4>::load( \
auto v##m##n = GiLoadFloat32( \
output_transform_buf + \ output_transform_buf + \
(m * alpha + n) * OCB * nr_units_in_tile * pack_size + \ (m * alpha + n) * OCB * nr_units_in_tile * pack_size + \
ocb * nr_units_in_tile * pack_size + unit_idx * pack_size); ocb * nr_units_in_tile * pack_size + unit_idx * pack_size);
@@ -281,51 +280,83 @@ struct OutputTransformF63_NCHW44 {
* 0 0 0 0 0 1 * 0 0 0 0 0 1
*/ */


Vector<float, 4> v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, v5subv6;
#define cb(m) \
v1addv2 = v1##m + v2##m; \
v1subv2 = v1##m - v2##m; \
v3addv4 = v3##m + v4##m; \
v3subv4 = v3##m - v4##m; \
v5addv6 = v5##m + v6##m; \
v5subv6 = v5##m - v6##m; \
auto t0##m = v0##m + v1addv2 + v3addv4 + v5addv6; \
auto t1##m = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f; \
auto t2##m = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f; \
auto t3##m = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f; \
auto t4##m = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; \
auto t5##m = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + v7##m;
/*
* v1addv2 = v1##m + v2##m;
* v1subv2 = v1##m - v2##m;
* v3addv4 = v3##m + v4##m;
* v3subv4 = v3##m - v4##m;
* v5addv6 = v5##m + v6##m;
* v5subv6 = v5##m - v6##m;
* auto t0##m = v0##m + v1addv2 + v3addv4 + v5addv6;
* auto t1##m = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f;
* auto t2##m = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f;
* auto t3##m = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f;
* auto t4##m = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f;
* auto t5##m = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + v7##m;
*/
GI_FLOAT32_t v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, v5subv6;
#define cb(m) \
v1addv2 = ADDF(v1##m, v2##m); \
v1subv2 = SUBF(v1##m, v2##m); \
v3addv4 = ADDF(v3##m, v4##m); \
v3subv4 = SUBF(v3##m, v4##m); \
v5addv6 = ADDF(v5##m, v6##m); \
v5subv6 = SUBF(v5##m, v6##m); \
auto t0##m = ADDF(ADDF(ADDF(v0##m, v1addv2), v3addv4), v5addv6); \
auto t1##m = ADDF(ADDF(v1subv2, MULSF(v3subv4, 2.f)), MULSF(v5subv6, 0.5f)); \
auto t2##m = ADDF(ADDF(v1addv2, MULSF(v3addv4, 4.f)), MULSF(v5addv6, 0.25f)); \
auto t3##m = ADDF(ADDF(v1subv2, MULSF(v3subv4, 8.f)), MULSF(v5subv6, 0.125f)); \
auto t4##m = ADDF(ADDF(v1addv2, MULSF(v3addv4, 16.f)), MULSF(v5addv6, 0.0625f)); \
auto t5##m = \
ADDF(ADDF(ADDF(v1subv2, MULSF(v3subv4, 32.f)), MULSF(v5subv6, 0.03125f)), \
v7##m);


UNROLL_CALL_NOWRAPPER(8, cb); UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb #undef cb


#define cb(m) \
v1addv2 = t##m##1 + t##m##2; \
v1subv2 = t##m##1 - t##m##2; \
v3addv4 = t##m##3 + t##m##4; \
v3subv4 = t##m##3 - t##m##4; \
v5addv6 = t##m##5 + t##m##6; \
v5subv6 = t##m##5 - t##m##6; \
v##m##0 = t##m##0 + v1addv2 + v3addv4 + v5addv6; \
v##m##1 = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f; \
v##m##2 = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f; \
v##m##3 = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f; \
v##m##4 = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; \
v##m##5 = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + t##m##7;
/*
* v1addv2 = t##m##1 + t##m##2;
* v1subv2 = t##m##1 - t##m##2;
* v3addv4 = t##m##3 + t##m##4;
* v3subv4 = t##m##3 - t##m##4;
* v5addv6 = t##m##5 + t##m##6;
* v5subv6 = t##m##5 - t##m##6;
* v##m##0 = t##m##0 + v1addv2 + v3addv4 + v5addv6;
* v##m##1 = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f;
* v##m##2 = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f;
* v##m##3 = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f;
* v##m##4 = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f;
* v##m##5 = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + t##m##7;
*/
#define cb(m) \
v1addv2 = ADDF(t##m##1, t##m##2); \
v1subv2 = SUBF(t##m##1, t##m##2); \
v3addv4 = ADDF(t##m##3, t##m##4); \
v3subv4 = SUBF(t##m##3, t##m##4); \
v5addv6 = ADDF(t##m##5, t##m##6); \
v5subv6 = SUBF(t##m##5, t##m##6); \
v##m##0 = ADDF(ADDF(ADDF(t##m##0, v1addv2), v3addv4), v5addv6); \
v##m##1 = ADDF(ADDF(v1subv2, MULSF(v3subv4, 2.f)), MULSF(v5subv6, 0.5f)); \
v##m##2 = ADDF(ADDF(v1addv2, MULSF(v3addv4, 4.f)), MULSF(v5addv6, 0.25f)); \
v##m##3 = ADDF(ADDF(v1subv2, MULSF(v3subv4, 8.f)), MULSF(v5subv6, 0.125f)); \
v##m##4 = ADDF(ADDF(v1addv2, MULSF(v3addv4, 16.f)), MULSF(v5addv6, 0.0625f)); \
v##m##5 = \
ADDF(ADDF(ADDF(v1subv2, MULSF(v3subv4, 32.f)), MULSF(v5subv6, 0.03125f)), \
t##m##7);


UNROLL_CALL_NOWRAPPER(6, cb); UNROLL_CALL_NOWRAPPER(6, cb);
#undef cb #undef cb


Vector<float, 4> vbias;
GI_FLOAT32_t vbias;
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
vbias = Vector<float, 4>::load(bias + oc);
vbias = GiLoadFloat32(bias + oc);


#define cb(m, n) v##m##n += vbias;
#define cb(m, n) v##m##n = ADDF(v##m##n, vbias);
UNROLL_CALL_RAW_D2(6, 6, cb); UNROLL_CALL_RAW_D2(6, 6, cb);
#undef cb #undef cb
} }
if (bmode != BiasMode::BIAS) { if (bmode != BiasMode::BIAS) {
#define cb(m, n) v##m##n = op(GiFixLenType2GiFloat32Type(CONCAT(v##m, n).value));
#define cb(m, n) v##m##n = op(CONCAT(v##m, n));
UNROLL_CALL_RAW_D2(6, 6, cb); UNROLL_CALL_RAW_D2(6, 6, cb);
#undef cb #undef cb
} }
@@ -335,12 +366,15 @@ struct OutputTransformF63_NCHW44 {
size_t ow = ow_start + owo; \ size_t ow = ow_start + owo; \
if (oh < OH && ow < OW) { \ if (oh < OH && ow < OW) { \
if (bmode == BiasMode::BIAS) { \ if (bmode == BiasMode::BIAS) { \
v##oho##owo += Vector<float, 4>::load( \
bias + oc * OH * OW + oh * OW * pack_size + ow * pack_size); \
v##oho##owo = op(GiFixLenType2GiFloat32Type(v##oho##owo.value)); \
v##oho##owo = ADDF( \
v##oho##owo, GiLoadFloat32( \
bias + oc * OH * OW + \
oh * OW * pack_size + ow * pack_size)); \
v##oho##owo = op(v##oho##owo); \
} \ } \
v##oho##owo.save( \
output + oc * OH * OW + oh * OW * pack_size + ow * pack_size); \
GiStoreFloat32( \
output + oc * OH * OW + oh * OW * pack_size + ow * pack_size, \
v##oho##owo); \
} \ } \
} while (0); } while (0);
UNROLL_CALL_RAW_D2(6, 6, out_save); UNROLL_CALL_RAW_D2(6, 6, out_save);
@@ -387,35 +421,52 @@ void winograd_F63_mk4_f_nchw44::filter(
pack_size * pack_size + pack_size * pack_size +
ic_inner * pack_size; ic_inner * pack_size;


#define cb(m, n) \
Vector<float, 4> g##m##n = Vector<float, 4>::load( \
fptr + (m * KERNEL_SIZE + n) * pack_size * pack_size);
#define cb(m, n) \
GI_FLOAT32_t g##m##n = \
GiLoadFloat32(fptr + (m * KERNEL_SIZE + n) * pack_size * pack_size);
UNROLL_CALL_NOWRAPPER_D2(3, 3, cb) UNROLL_CALL_NOWRAPPER_D2(3, 3, cb)
#undef cb #undef cb


#define FILTER_TRANSFORM(n, wd, g) \
auto wd##n##0 = g##0##n; \
tmp0 = (g##0##n + g##2##n) * -0.2222222f; \
tmp1 = g##1##n * -0.2222222f; \
auto wd##n##1 = tmp0 + tmp1; \
auto wd##n##2 = tmp0 - tmp1; \
tmp0 = g##0##n * 0.0111111f + g##2##n * 0.0444444f; \
tmp1 = g##1##n * 0.0222222f; \
auto wd##n##3 = tmp0 + tmp1; \
auto wd##n##4 = tmp0 - tmp1; \
tmp0 = g##0##n * 0.7111111f + g##2##n * 0.1777778f; \
tmp1 = g##1##n * 0.3555556f; \
auto wd##n##5 = tmp0 + tmp1; \
auto wd##n##6 = tmp0 - tmp1; \
/*
* auto wd##n##0 = g##0##n;
* tmp0 = (g##0##n + g##2##n) * -0.2222222f;
* tmp1 = g##1##n * -0.2222222f;
* auto wd##n##1 = tmp0 + tmp1;
* auto wd##n##2 = tmp0 - tmp1;
* tmp0 = g##0##n * 0.0111111f + g##2##n * 0.0444444f;
* tmp1 = g##1##n * 0.0222222f;
* auto wd##n##3 = tmp0 + tmp1;
* auto wd##n##4 = tmp0 - tmp1;
* tmp0 = g##0##n * 0.7111111f + g##2##n * 0.1777778f;
* tmp1 = g##1##n * 0.3555556f;
* auto wd##n##5 = tmp0 + tmp1;
* auto wd##n##6 = tmp0 - tmp1;
* auto wd##n##7 = g##2##n;
*/
#define FILTER_TRANSFORM(n, wd, g) \
auto wd##n##0 = g##0##n; \
tmp0 = MULSF(ADDF(g##0##n, g##2##n), -0.2222222f); \
tmp1 = MULSF(g##1##n, -0.2222222f); \
auto wd##n##1 = ADDF(tmp0, tmp1); \
auto wd##n##2 = SUBF(tmp0, tmp1); \
tmp0 = ADDF(MULSF(g##0##n, 0.0111111f), MULSF(g##2##n, 0.0444444f)); \
tmp1 = MULSF(g##1##n, 0.0222222f); \
auto wd##n##3 = ADDF(tmp0, tmp1); \
auto wd##n##4 = SUBF(tmp0, tmp1); \
tmp0 = ADDF(MULSF(g##0##n, 0.7111111f), MULSF(g##2##n, 0.1777778f)); \
tmp1 = MULSF(g##1##n, 0.3555556f); \
auto wd##n##5 = ADDF(tmp0, tmp1); \
auto wd##n##6 = SUBF(tmp0, tmp1); \
auto wd##n##7 = g##2##n; auto wd##n##7 = g##2##n;
Vector<float, 4> tmp0, tmp1;
GI_FLOAT32_t tmp0, tmp1;
UNROLL_CALL_RAW(3, FILTER_TRANSFORM, wd, g); UNROLL_CALL_RAW(3, FILTER_TRANSFORM, wd, g);
UNROLL_CALL_RAW(8, FILTER_TRANSFORM, ret, wd); UNROLL_CALL_RAW(8, FILTER_TRANSFORM, ret, wd);
#undef FILTER_TRANSFORM #undef FILTER_TRANSFORM
#define cb_save(m, n) \ #define cb_save(m, n) \
ret##m##n.save( \
GiStoreFloat32( \
filter_transform_buf + (m * alpha + n) * OC * IC + ocb * IC * pack_size + \ filter_transform_buf + (m * alpha + n) * OC * IC + ocb * IC * pack_size + \
icb * pack_size * pack_size + ic_inner * pack_size);
icb * pack_size * pack_size + ic_inner * pack_size, \
ret##m##n);
UNROLL_CALL_NOWRAPPER_D2(8, 8, cb_save) UNROLL_CALL_NOWRAPPER_D2(8, 8, cb_save)
#undef cb_save #undef cb_save
} }


+ 145
- 68
dnn/src/fallback/conv_bias/gi/fp32/strategy_f73_mk4_nchw44.cpp View File

@@ -4,7 +4,6 @@
#include "src/fallback/conv_bias/gi/fp32/filter_transform.h" #include "src/fallback/conv_bias/gi/fp32/filter_transform.h"
#include "src/fallback/conv_bias/gi/fp32/helper.h" #include "src/fallback/conv_bias/gi/fp32/helper.h"
#include "src/fallback/conv_bias/gi/fp32/strategy.h" #include "src/fallback/conv_bias/gi/fp32/strategy.h"
#include "src/fallback/conv_bias/gi/utils.h"
#include "src/fallback/conv_bias/winograd/winograd.h" #include "src/fallback/conv_bias/winograd/winograd.h"
#include "src/fallback/elemwise_helper/op_unary.h" #include "src/fallback/elemwise_helper/op_unary.h"


@@ -137,14 +136,14 @@ struct InputTransformF73_NCHW44 {
auto t##i##6 = d7; \ auto t##i##6 = d7; \
auto t##i##7 = d7; \ auto t##i##7 = d7; \
t##i##8 = MSUB(t##i##8, d7, v0, 0); \ t##i##8 = MSUB(t##i##8, d7, v0, 0); \
t##i##0 = GiSubtractFloat32(t##i##0, d1); \
t##i##0 = SUBF(t##i##0, d1); \
t##i##1 = MSUB(t##i##1, d1, v0, 0); \ t##i##1 = MSUB(t##i##1, d1, v0, 0); \
t##i##2 = MADD(t##i##2, d1, v0, 0); \ t##i##2 = MADD(t##i##2, d1, v0, 0); \
t##i##3 = MSUB(t##i##3, d1, v0, 1); \ t##i##3 = MSUB(t##i##3, d1, v0, 1); \
t##i##4 = MADD(t##i##4, d1, v0, 1); \ t##i##4 = MADD(t##i##4, d1, v0, 1); \
t##i##5 = MSUB(t##i##5, d1, v0, 2); \ t##i##5 = MSUB(t##i##5, d1, v0, 2); \
t##i##6 = MADD(t##i##6, d1, v0, 2); \ t##i##6 = MADD(t##i##6, d1, v0, 2); \
t##i##7 = GiSubtractFloat32(t##i##7, d1); \
t##i##7 = SUBF(t##i##7, d1); \
t##i##8 = MADD(t##i##8, d1, v0, 0); \ t##i##8 = MADD(t##i##8, d1, v0, 0); \
t##i##0 = MSUB(t##i##0, d2, v0, 3); \ t##i##0 = MSUB(t##i##0, d2, v0, 3); \
t##i##1 = MSUB(t##i##1, d2, v1, 0); \ t##i##1 = MSUB(t##i##1, d2, v1, 0); \
@@ -153,7 +152,7 @@ struct InputTransformF73_NCHW44 {
t##i##4 = MSUB(t##i##4, d2, v1, 3); \ t##i##4 = MSUB(t##i##4, d2, v1, 3); \
t##i##5 = MSUB(t##i##5, d2, v2, 0); \ t##i##5 = MSUB(t##i##5, d2, v2, 0); \
t##i##6 = MSUB(t##i##6, d2, v2, 1); \ t##i##6 = MSUB(t##i##6, d2, v2, 1); \
t##i##8 = GiSubtractFloat32(t##i##8, d2); \
t##i##8 = SUBF(t##i##8, d2); \
t##i##0 = MADD(t##i##0, d3, v2, 2); \ t##i##0 = MADD(t##i##0, d3, v2, 2); \
t##i##1 = MADD(t##i##1, d3, v2, 3); \ t##i##1 = MADD(t##i##1, d3, v2, 3); \
t##i##2 = MSUB(t##i##2, d3, v3, 0); \ t##i##2 = MSUB(t##i##2, d3, v3, 0); \
@@ -185,7 +184,7 @@ struct InputTransformF73_NCHW44 {
t##i##2 = MSUB(t##i##2, d6, v1, 1); \ t##i##2 = MSUB(t##i##2, d6, v1, 1); \
t##i##3 = MADD(t##i##3, d6, v1, 0); \ t##i##3 = MADD(t##i##3, d6, v1, 0); \
t##i##4 = MSUB(t##i##4, d6, v3, 1); \ t##i##4 = MSUB(t##i##4, d6, v3, 1); \
t##i##5 = GiSubtractFloat32(t##i##5, d6); \
t##i##5 = SUBF(t##i##5, d6); \
t##i##6 = MSUB(t##i##6, d6, v6, 2); \ t##i##6 = MSUB(t##i##6, d6, v6, 2); \
t##i##8 = MSUB(t##i##8, d6, v2, 2); \ t##i##8 = MSUB(t##i##8, d6, v2, 2); \
t##i##0 = MADD(t##i##0, d0, v0, 0); t##i##0 = MADD(t##i##0, d0, v0, 0);
@@ -204,14 +203,14 @@ struct InputTransformF73_NCHW44 {
d6 = t7##i; \ d6 = t7##i; \
d7 = t7##i; \ d7 = t7##i; \
d8 = MSUB(d8, t7##i, v0, 0); \ d8 = MSUB(d8, t7##i, v0, 0); \
d0 = GiSubtractFloat32(d0, t1##i); \
d0 = SUBF(d0, t1##i); \
d1 = MSUB(d1, t1##i, v0, 0); \ d1 = MSUB(d1, t1##i, v0, 0); \
d2 = MADD(d2, t1##i, v0, 0); \ d2 = MADD(d2, t1##i, v0, 0); \
d3 = MSUB(d3, t1##i, v0, 1); \ d3 = MSUB(d3, t1##i, v0, 1); \
d4 = MADD(d4, t1##i, v0, 1); \ d4 = MADD(d4, t1##i, v0, 1); \
d5 = MSUB(d5, t1##i, v0, 2); \ d5 = MSUB(d5, t1##i, v0, 2); \
d6 = MADD(d6, t1##i, v0, 2); \ d6 = MADD(d6, t1##i, v0, 2); \
d7 = GiSubtractFloat32(d7, t1##i); \
d7 = SUBF(d7, t1##i); \
d8 = MADD(d8, t1##i, v0, 0); \ d8 = MADD(d8, t1##i, v0, 0); \
d0 = MSUB(d0, t2##i, v0, 3); \ d0 = MSUB(d0, t2##i, v0, 3); \
d1 = MSUB(d1, t2##i, v1, 0); \ d1 = MSUB(d1, t2##i, v1, 0); \
@@ -220,7 +219,7 @@ struct InputTransformF73_NCHW44 {
d4 = MSUB(d4, t2##i, v1, 3); \ d4 = MSUB(d4, t2##i, v1, 3); \
d5 = MSUB(d5, t2##i, v2, 0); \ d5 = MSUB(d5, t2##i, v2, 0); \
d6 = MSUB(d6, t2##i, v2, 1); \ d6 = MSUB(d6, t2##i, v2, 1); \
d8 = GiSubtractFloat32(d8, t2##i); \
d8 = SUBF(d8, t2##i); \
d0 = MADD(d0, t3##i, v2, 2); \ d0 = MADD(d0, t3##i, v2, 2); \
d1 = MADD(d1, t3##i, v2, 3); \ d1 = MADD(d1, t3##i, v2, 3); \
d2 = MSUB(d2, t3##i, v3, 0); \ d2 = MSUB(d2, t3##i, v3, 0); \
@@ -252,7 +251,7 @@ struct InputTransformF73_NCHW44 {
d2 = MSUB(d2, t6##i, v1, 1); \ d2 = MSUB(d2, t6##i, v1, 1); \
d3 = MADD(d3, t6##i, v1, 0); \ d3 = MADD(d3, t6##i, v1, 0); \
d4 = MSUB(d4, t6##i, v3, 1); \ d4 = MSUB(d4, t6##i, v3, 1); \
d5 = GiSubtractFloat32(d5, t6##i); \
d5 = SUBF(d5, t6##i); \
d6 = MSUB(d6, t6##i, v6, 2); \ d6 = MSUB(d6, t6##i, v6, 2); \
d8 = MSUB(d8, t6##i, v2, 2); \ d8 = MSUB(d8, t6##i, v2, 2); \
d0 = MADD(d0, t0##i, v0, 0); \ d0 = MADD(d0, t0##i, v0, 0); \
@@ -325,7 +324,7 @@ struct OutputTransformF73_NCHW44 {
size_t ocb = oc_index / pack_size; size_t ocb = oc_index / pack_size;


#define cb(m, n) \ #define cb(m, n) \
auto v##m##n = Vector<float, 4>::load( \
auto v##m##n = GiLoadFloat32( \
output_transform_buf + \ output_transform_buf + \
(m * alpha + n) * OCB * nr_units_in_tile * pack_size + \ (m * alpha + n) * OCB * nr_units_in_tile * pack_size + \
ocb * nr_units_in_tile * pack_size + unit_idx * pack_size); ocb * nr_units_in_tile * pack_size + unit_idx * pack_size);
@@ -346,56 +345,112 @@ struct OutputTransformF73_NCHW44 {
* 1 1.5 2.25 3.375 5.0625 7.59375 11.390625 * 1 1.5 2.25 3.375 5.0625 7.59375 11.390625
* 0 0 0 0 0 0 1 * 0 0 0 0 0 0 1
*/ */
/*
* v1addv2 = v1##m + v2##m;
* v1subv2 = v1##m - v2##m;
* v3addv4 = v3##m + v4##m;
* v3subv4 = v3##m - v4##m;
* v5addv6 = v5##m + v6##m;
* v5subv6 = v5##m - v6##m;
* auto t0##m = v0##m + v1addv2 + v3addv4 + v5addv6 + v7##m;
* auto t1##m = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f + v7##m * 1.5f;
* auto t2##m = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f + v7##m * 2.25f;
* auto t3##m = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f + v7##m * 3.375f;
* auto t4##m = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f + v7##m * 5.0625f;
* auto t5##m = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + v7##m
* * 7.59375f; auto t6##m = v1addv2 + v3addv4 * 64.f + v5addv6 * 0.015625f +
* v7##m * 11.390625f + v8##m;
*/


Vector<float, 4> v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, v5subv6;
GI_FLOAT32_t v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, v5subv6;
#define cb(m) \ #define cb(m) \
v1addv2 = v1##m + v2##m; \
v1subv2 = v1##m - v2##m; \
v3addv4 = v3##m + v4##m; \
v3subv4 = v3##m - v4##m; \
v5addv6 = v5##m + v6##m; \
v5subv6 = v5##m - v6##m; \
auto t0##m = v0##m + v1addv2 + v3addv4 + v5addv6 + v7##m; \
auto t1##m = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f + v7##m * 1.5f; \
auto t2##m = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f + v7##m * 2.25f; \
auto t3##m = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f + v7##m * 3.375f; \
auto t4##m = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f + v7##m * 5.0625f; \
auto t5##m = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + v7##m * 7.59375f; \
auto t6##m = v1addv2 + v3addv4 * 64.f + v5addv6 * 0.015625f + v7##m * 11.390625f + \
v8##m;
v1addv2 = ADDF(v1##m, v2##m); \
v1subv2 = SUBF(v1##m, v2##m); \
v3addv4 = ADDF(v3##m, v4##m); \
v3subv4 = SUBF(v3##m, v4##m); \
v5addv6 = ADDF(v5##m, v6##m); \
v5subv6 = SUBF(v5##m, v6##m); \
auto t0##m = ADDF(ADDF(ADDF(ADDF(v0##m, v1addv2), v3addv4), v5addv6), v7##m); \
auto t1##m = \
ADDF(ADDF(ADDF(v1subv2, MULSF(v3subv4, 2.f)), MULSF(v5subv6, 0.5f)), \
MULSF(v7##m, 1.5f)); \
auto t2##m = \
ADDF(ADDF(ADDF(v1addv2, MULSF(v3addv4, 4.f)), MULSF(v5addv6, 0.25f)), \
MULSF(v7##m, 2.25f)); \
auto t3##m = \
ADDF(ADDF(ADDF(v1subv2, MULSF(v3subv4, 8.f)), MULSF(v5subv6, 0.125f)), \
MULSF(v7##m, 3.375f)); \
auto t4##m = \
ADDF(ADDF(ADDF(v1addv2, MULSF(v3addv4, 16.f)), MULSF(v5addv6, 0.0625f)), \
MULSF(v7##m, 5.0625f)); \
auto t5##m = \
ADDF(ADDF(ADDF(v1subv2, MULSF(v3subv4, 32.f)), MULSF(v5subv6, 0.03125f)), \
MULSF(v7##m, 7.59375f)); \
auto t6##m = ADDF( \
ADDF(ADDF(ADDF(v1addv2, MULSF(v3addv4, 64.f)), MULSF(v5addv6, 0.015625f)), \
MULSF(v7##m, 11.390625f)), \
v8##m);


UNROLL_CALL_NOWRAPPER(9, cb); UNROLL_CALL_NOWRAPPER(9, cb);
#undef cb #undef cb


#define cb(m) \
v1addv2 = t##m##1 + t##m##2; \
v1subv2 = t##m##1 - t##m##2; \
v3addv4 = t##m##3 + t##m##4; \
v3subv4 = t##m##3 - t##m##4; \
v5addv6 = t##m##5 + t##m##6; \
v5subv6 = t##m##5 - t##m##6; \
v##m##0 = t##m##0 + v1addv2 + v3addv4 + v5addv6 + t##m##7; \
v##m##1 = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f + t##m##7 * 1.5f; \
v##m##2 = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f + t##m##7 * 2.25f; \
v##m##3 = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f + t##m##7 * 3.375; \
v##m##4 = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f + t##m##7 * 5.0625f; \
v##m##5 = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + t##m##7 * 7.59375f; \
v##m##6 = v1addv2 + v3addv4 * 64.f + v5addv6 * 0.015625f + t##m##7 * 11.390625f + \
t##m##8;
/*
* v1addv2 = t##m##1 + t##m##2;
* v1subv2 = t##m##1 - t##m##2;
* v3addv4 = t##m##3 + t##m##4;
* v3subv4 = t##m##3 - t##m##4;
* v5addv6 = t##m##5 + t##m##6;
* v5subv6 = t##m##5 - t##m##6;
* v##m##0 = t##m##0 + v1addv2 + v3addv4 + v5addv6 + t##m##7;
* v##m##1 = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f + t##m##7 * 1.5f;
* v##m##2 = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f + t##m##7 * 2.25f;
* v##m##3 = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f + t##m##7 * 3.375;
* v##m##4 = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f + t##m##7 * 5.0625f;
* v##m##5 = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + t##m##7 * 7.59375f;
* v##m##6 = v1addv2 + v3addv4 * 64.f + v5addv6 * 0.015625f + t##m##7
* * 11.390625f + t##m##8;
*/
#define cb(m) \
v1addv2 = ADDF(t##m##1, t##m##2); \
v1subv2 = SUBF(t##m##1, t##m##2); \
v3addv4 = ADDF(t##m##3, t##m##4); \
v3subv4 = SUBF(t##m##3, t##m##4); \
v5addv6 = ADDF(t##m##5, t##m##6); \
v5subv6 = SUBF(t##m##5, t##m##6); \
v##m##0 = ADDF(ADDF(ADDF(ADDF(t##m##0, v1addv2), v3addv4), v5addv6), t##m##7); \
v##m##1 = \
ADDF(ADDF(ADDF(v1subv2, MULSF(v3subv4, 2.f)), MULSF(v5subv6, 0.5f)), \
MULSF(t##m##7, 1.5f)); \
v##m##2 = \
ADDF(ADDF(ADDF(v1addv2, MULSF(v3addv4, 4.f)), MULSF(v5addv6, 0.25f)), \
MULSF(t##m##7, 2.25f)); \
v##m##3 = \
ADDF(ADDF(ADDF(v1subv2, MULSF(v3subv4, 8.f)), MULSF(v5subv6, 0.125f)), \
MULSF(t##m##7, 3.375)); \
v##m##4 = \
ADDF(ADDF(ADDF(v1addv2, MULSF(v3addv4, 16.f)), MULSF(v5addv6, 0.0625f)), \
MULSF(t##m##7, 5.0625f)); \
v##m##5 = \
ADDF(ADDF(ADDF(v1subv2, MULSF(v3subv4, 32.f)), MULSF(v5subv6, 0.03125f)), \
MULSF(t##m##7, 7.59375f)); \
v##m##6 = ADDF( \
ADDF(ADDF(ADDF(v1addv2, MULSF(v3addv4, 64.f)), MULSF(v5addv6, 0.015625f)), \
MULSF(t##m##7, 11.390625f)), \
t##m##8);


UNROLL_CALL_NOWRAPPER(7, cb); UNROLL_CALL_NOWRAPPER(7, cb);
#undef cb #undef cb


Vector<float, 4> vbias;
GI_FLOAT32_t vbias;
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
vbias = Vector<float, 4>::load(bias + oc);
vbias = GiLoadFloat32(bias + oc);


#define cb(m, n) v##m##n += vbias;
#define cb(m, n) v##m##n = ADDF(v##m##n, vbias);
UNROLL_CALL_RAW_D2(7, 7, cb); UNROLL_CALL_RAW_D2(7, 7, cb);
#undef cb #undef cb
} }
if (bmode != BiasMode::BIAS) { if (bmode != BiasMode::BIAS) {
#define cb(m, n) v##m##n = op(GiFixLenType2GiFloat32Type(CONCAT(v##m, n).value));
#define cb(m, n) v##m##n = op(CONCAT(v##m, n));
UNROLL_CALL_RAW_D2(7, 7, cb); UNROLL_CALL_RAW_D2(7, 7, cb);
#undef cb #undef cb
} }
@@ -405,12 +460,15 @@ struct OutputTransformF73_NCHW44 {
size_t ow = ow_start + owo; \ size_t ow = ow_start + owo; \
if (oh < OH && ow < OW) { \ if (oh < OH && ow < OW) { \
if (bmode == BiasMode::BIAS) { \ if (bmode == BiasMode::BIAS) { \
v##oho##owo += Vector<float, 4>::load( \
bias + oc * OH * OW + oh * OW * pack_size + ow * pack_size); \
v##oho##owo = op(GiFixLenType2GiFloat32Type(v##oho##owo.value)); \
v##oho##owo = ADDF( \
v##oho##owo, GiLoadFloat32( \
bias + oc * OH * OW + \
oh * OW * pack_size + ow * pack_size)); \
v##oho##owo = op(v##oho##owo); \
} \ } \
v##oho##owo.save( \
output + oc * OH * OW + oh * OW * pack_size + ow * pack_size); \
GiStoreFloat32( \
output + oc * OH * OW + oh * OW * pack_size + ow * pack_size, \
v##oho##owo); \
} \ } \
} while (0); } while (0);
UNROLL_CALL_RAW_D2(7, 7, out_save); UNROLL_CALL_RAW_D2(7, 7, out_save);
@@ -458,34 +516,53 @@ void winograd_F73_mk4_f_nchw44::filter(
pack_size * pack_size + pack_size * pack_size +
ic_inner * pack_size; ic_inner * pack_size;


#define cb(m, n) \
Vector<float, 4> g##m##n = Vector<float, 4>::load( \
fptr + (m * KERNEL_SIZE + n) * pack_size * pack_size);
#define cb(m, n) \
GI_FLOAT32_t g##m##n = \
GiLoadFloat32(fptr + (m * KERNEL_SIZE + n) * pack_size * pack_size);
UNROLL_CALL_NOWRAPPER_D2(3, 3, cb) UNROLL_CALL_NOWRAPPER_D2(3, 3, cb)
#undef cb #undef cb


#define FILTER_TRANSFORM(n, wd, g) \
auto wd##n##0 = g##0##n * 0.6666667f; \
auto wd##n##1 = (g##0##n + g##1##n + g##2##n) * 0.4444444f; \
auto wd##n##2 = (g##0##n - g##1##n + g##2##n) * 0.0888889f; \
auto wd##n##3 = \
g##0##n * 0.0222222f + g##1##n * 0.0444444f + g##2##n * 0.0888889f; \
auto wd##n##4 = \
g##0##n * -0.0031746f + g##1##n * 0.0063492f + g##2##n * -0.0126984f; \
auto wd##n##5 = \
g##0##n * -0.7111111f + g##1##n * -0.3555556f + g##2##n * -0.1777778f; \
auto wd##n##6 = \
g##0##n * -0.3555556f + g##1##n * 0.1777778f + g##2##n * -0.0888889f; \
auto wd##n##7 = \
g##0##n * -0.1523810f + g##1##n * -0.2285714f + g##2##n * -0.3428572f; \
/*
* auto wd##n##0 = g##0##n * 0.6666667f;
* auto wd##n##1 = (g##0##n + g##1##n + g##2##n) * 0.4444444f;
* auto wd##n##2 = (g##0##n - g##1##n + g##2##n) * 0.0888889f;
* auto wd##n##3 =
* g##0##n * 0.0222222f + g##1##n * 0.0444444f + g##2##n *
* 0.0888889f; auto wd##n##4 = g##0##n * -0.0031746f + g##1##n *
* 0.0063492f + g##2##n * -0.0126984f; auto wd##n##5 = g##0##n *
* -0.7111111f + g##1##n * -0.3555556f + g##2##n * -0.1777778f; auto
* wd##n##6 = g##0##n * -0.3555556f + g##1##n * 0.1777778f + g##2##n *
* -0.0888889f; auto wd##n##7 = g##0##n * -0.1523810f + g##1##n *
* -0.2285714f + g##2##n * -0.3428572f;
*/
#define FILTER_TRANSFORM(n, wd, g) \
auto wd##n##0 = MULSF(g##0##n, 0.6666667f); \
auto wd##n##1 = MULSF(ADDF(ADDF(g##0##n, g##1##n), g##2##n), 0.4444444f); \
auto wd##n##2 = MULSF(ADDF(SUBF(g##0##n, g##1##n), g##2##n), 0.0888889f); \
auto wd##n##3 = \
ADDF(ADDF(MULSF(g##0##n, 0.0222222f), MULSF(g##1##n, 0.0444444f)), \
MULSF(g##2##n, 0.0888889f)); \
auto wd##n##4 = \
ADDF(ADDF(MULSF(g##0##n, -0.0031746f), MULSF(g##1##n, 0.0063492f)), \
MULSF(g##2##n, -0.0126984f)); \
auto wd##n##5 = \
ADDF(ADDF(MULSF(g##0##n, -0.7111111f), MULSF(g##1##n, -0.3555556f)), \
MULSF(g##2##n, -0.1777778f)); \
auto wd##n##6 = \
ADDF(ADDF(MULSF(g##0##n, -0.3555556f), MULSF(g##1##n, 0.1777778f)), \
MULSF(g##2##n, -0.0888889f)); \
auto wd##n##7 = \
ADDF(ADDF(MULSF(g##0##n, -0.1523810f), MULSF(g##1##n, -0.2285714f)), \
MULSF(g##2##n, -0.3428572f)); \
auto wd##n##8 = g##2##n; auto wd##n##8 = g##2##n;
UNROLL_CALL_RAW(3, FILTER_TRANSFORM, wd, g); UNROLL_CALL_RAW(3, FILTER_TRANSFORM, wd, g);
UNROLL_CALL_RAW(9, FILTER_TRANSFORM, ret, wd); UNROLL_CALL_RAW(9, FILTER_TRANSFORM, ret, wd);
#undef FILTER_TRANSFORM #undef FILTER_TRANSFORM
#define cb_save(m, n) \ #define cb_save(m, n) \
ret##m##n.save( \
GiStoreFloat32( \
filter_transform_buf + (m * alpha + n) * OC * IC + ocb * IC * pack_size + \ filter_transform_buf + (m * alpha + n) * OC * IC + ocb * IC * pack_size + \
icb * pack_size * pack_size + ic_inner * pack_size);
icb * pack_size * pack_size + ic_inner * pack_size, \
ret##m##n);
UNROLL_CALL_NOWRAPPER_D2(9, 9, cb_save) UNROLL_CALL_NOWRAPPER_D2(9, 9, cb_save)
#undef cb_save #undef cb_save
} }


+ 0
- 215
dnn/src/fallback/conv_bias/gi/utils.h View File

@@ -1,215 +0,0 @@
#pragma once

#include <cstring>
#include "src/common/utils.h"
#include "src/fallback/general_intrinsic/gi_float.h"

namespace megdnn {
namespace fallback {

template <typename ctype, size_t len>
struct Vector;

template <>
struct Vector<float, 4> {
GI_FLOAT32_FIXLEN_t value;
Vector() {}
Vector(const float v) { value = GiFloat32Type2FixLenType(GiBroadcastFloat32(v)); }
Vector(const Vector& lr) { value = lr.value; }
Vector(const Vector&& lr) { value = std::move(lr.value); }
Vector(const GI_FLOAT32_t& v) { value = GiFloat32Type2FixLenType(v); }
static Vector load(const float* addr) {
Vector v;
v.value = GiFloat32Type2FixLenType(GiLoadFloat32(addr));
return v;
}
static void save(float* addr, const Vector& v) {
GiStoreFloat32(addr, GiFixLenType2GiFloat32Type(v.value));
}
void save(float* addr) { save(addr, *this); }
Vector operator+(const Vector& lr) {
Vector dst;
dst.value = GiFloat32Type2FixLenType(GiAddFloat32(
GiFixLenType2GiFloat32Type(value),
GiFixLenType2GiFloat32Type(lr.value)));
return dst;
}
Vector& operator+=(const Vector& lr) {
value = GiFloat32Type2FixLenType(GiAddFloat32(
GiFixLenType2GiFloat32Type(value),
GiFixLenType2GiFloat32Type(lr.value)));
return *this;
}
Vector operator-(const Vector& lr) {
Vector dst;
dst.value = GiFloat32Type2FixLenType(GiSubtractFloat32(
GiFixLenType2GiFloat32Type(value),
GiFixLenType2GiFloat32Type(lr.value)));
return dst;
}
Vector& operator-=(const Vector& lr) {
value = GiFloat32Type2FixLenType(GiSubtractFloat32(
GiFixLenType2GiFloat32Type(value),
GiFixLenType2GiFloat32Type(lr.value)));
return *this;
}
Vector operator*(float lr) {
Vector dst;
dst.value = GiFloat32Type2FixLenType(
GiMultiplyScalerFloat32(GiFixLenType2GiFloat32Type(value), lr));
return dst;
}
Vector operator*(const Vector& lr) {
Vector dst;
dst.value = GiFloat32Type2FixLenType(GiMultiplyFloat32(
GiFixLenType2GiFloat32Type(value),
GiFixLenType2GiFloat32Type(lr.value)));
return dst;
}
Vector& operator*=(const Vector& lr) {
value = GiFloat32Type2FixLenType(GiMultiplyFloat32(
GiFixLenType2GiFloat32Type(value),
GiFixLenType2GiFloat32Type(lr.value)));
return *this;
}
Vector& operator=(const Vector& lr) {
value = lr.value;
return *this;
}
Vector& operator=(const Vector&& lr) {
value = std::move(lr.value);
return *this;
}
Vector operator-() {
Vector dst;
dst.value = -value;
return dst;
}
};

template <>
struct Vector<float, 8> {
GI_FLOAT32_FIXLEN_V2_t value;
Vector() {}
Vector(const float v) {
value.val[0] = GiFloat32Type2FixLenType(GiBroadcastFloat32(v));
value.val[1] = GiFloat32Type2FixLenType(GiBroadcastFloat32(v));
}
Vector(const Vector& lr) { value = lr.value; }
Vector(const Vector&& lr) { value = std::move(lr.value); }
Vector(const GI_FLOAT32_V2_t& v) { value = GiFloat32Type2FixLenV2Type(v); }
static Vector load(const float* addr) {
Vector v;
v.value = GiFloat32Type2FixLenV2Type(GiLoadFloat32V2(addr));
return v;
}
static void save(float* addr, const Vector& v) {
GiStoreFloat32V2(addr, GiFixLenType2GiFloat32V2Type(v.value));
}

void save(float* addr) { save(addr, *this); }
Vector operator+(const Vector& lr) {
Vector dst;
dst.value.val[0] = GiFloat32Type2FixLenType(GiAddFloat32(
GiFixLenType2GiFloat32Type(value.val[0]),
GiFixLenType2GiFloat32Type(lr.value.val[0])));
dst.value.val[1] = GiFloat32Type2FixLenType(GiAddFloat32(
GiFixLenType2GiFloat32Type(value.val[1]),
GiFixLenType2GiFloat32Type(lr.value.val[1])));
return dst;
}
Vector& operator+=(const Vector& lr) {
value.val[0] = GiFloat32Type2FixLenType(GiAddFloat32(
GiFixLenType2GiFloat32Type(value.val[0]),
GiFixLenType2GiFloat32Type(lr.value.val[0])));
value.val[1] = GiFloat32Type2FixLenType(GiAddFloat32(
GiFixLenType2GiFloat32Type(value.val[1]),
GiFixLenType2GiFloat32Type(lr.value.val[1])));
return *this;
}
Vector& add(const Vector& lr) {
value.val[0] = GiFloat32Type2FixLenType(GiAddFloat32(
GiFixLenType2GiFloat32Type(value.val[0]),
GiFixLenType2GiFloat32Type(lr.value.val[0])));
value.val[1] = GiFloat32Type2FixLenType(GiAddFloat32(
GiFixLenType2GiFloat32Type(value.val[1]),
GiFixLenType2GiFloat32Type(lr.value.val[1])));
return *this;
}
Vector operator-(const Vector& lr) {
Vector dst;
dst.value.val[0] = GiFloat32Type2FixLenType(GiSubtractFloat32(
GiFixLenType2GiFloat32Type(value.val[0]),
GiFixLenType2GiFloat32Type(lr.value.val[0])));
dst.value.val[1] = GiFloat32Type2FixLenType(GiSubtractFloat32(
GiFixLenType2GiFloat32Type(value.val[1]),
GiFixLenType2GiFloat32Type(lr.value.val[1])));
return dst;
}
Vector& operator-=(const Vector& lr) {
value.val[0] = GiFloat32Type2FixLenType(GiSubtractFloat32(
GiFixLenType2GiFloat32Type(value.val[0]),
GiFixLenType2GiFloat32Type(lr.value.val[0])));
value.val[1] = GiFloat32Type2FixLenType(GiSubtractFloat32(
GiFixLenType2GiFloat32Type(value.val[1]),
GiFixLenType2GiFloat32Type(lr.value.val[1])));
return *this;
}
Vector operator*(float lr) {
Vector dst;
dst.value.val[0] = GiFloat32Type2FixLenType(
GiMultiplyScalerFloat32(GiFixLenType2GiFloat32Type(value.val[0]), lr));
dst.value.val[1] = GiFloat32Type2FixLenType(
GiMultiplyScalerFloat32(GiFixLenType2GiFloat32Type(value.val[1]), lr));
return dst;
}
//! val + lr * n
Vector& mla(const Vector& lr, float n) {
value.val[0] = GiFloat32Type2FixLenType(GiMultiplyAddScalarFloat32(
GiFixLenType2GiFloat32Type(value.val[0]),
GiFixLenType2GiFloat32Type(lr.value.val[0]), n));
value.val[1] = GiFloat32Type2FixLenType(GiMultiplyAddScalarFloat32(
GiFixLenType2GiFloat32Type(value.val[1]),
GiFixLenType2GiFloat32Type(lr.value.val[1]), n));
return *this;
}

Vector operator*(const Vector& lr) {
Vector dst;
dst.value.val[0] = GiFloat32Type2FixLenType(GiMultiplyFloat32(
GiFixLenType2GiFloat32Type(value.val[0]),
GiFixLenType2GiFloat32Type(lr.value.val[0])));
dst.value.val[1] = GiFloat32Type2FixLenType(GiMultiplyFloat32(
GiFixLenType2GiFloat32Type(value.val[1]),
GiFixLenType2GiFloat32Type(lr.value.val[1])));
return dst;
}
Vector& operator*=(const Vector& lr) {
value.val[0] = GiFloat32Type2FixLenType(GiMultiplyFloat32(
GiFixLenType2GiFloat32Type(value.val[0]),
GiFixLenType2GiFloat32Type(lr.value.val[0])));
value.val[1] = GiFloat32Type2FixLenType(GiMultiplyFloat32(
GiFixLenType2GiFloat32Type(value.val[1]),
GiFixLenType2GiFloat32Type(lr.value.val[1])));
return *this;
}
Vector& operator=(const Vector& lr) {
value = lr.value;
return *this;
}
Vector& operator=(const Vector&& lr) {
value = std::move(lr.value);
return *this;
}
Vector operator-() {
Vector dst;
dst.value.val[0] = -value.val[0];
dst.value.val[1] = -value.val[1];
return dst;
}
};

} // namespace fallback
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 49
- 0
dnn/src/fallback/general_intrinsic/gi_float.h View File

@@ -857,6 +857,18 @@ GI_FLOAT32_t GiMultiplyScalerFloat32(GI_FLOAT32_t Vector1, float Scaler) {
} }


GI_FORCEINLINE GI_FORCEINLINE
GI_FLOAT32_V2_t GiMultiplyScalerFloat32V2(GI_FLOAT32_V2_t Vector1, float Scaler) {
GI_FLOAT32_V2_t ret;
GiSetSubVectorFloat32V2(
ret, 0,
GiMultiplyScalerFloat32(GiGetSubVectorFloat32V2(Vector1, 0), Scaler));
GiSetSubVectorFloat32V2(
ret, 1,
GiMultiplyScalerFloat32(GiGetSubVectorFloat32V2(Vector1, 1), Scaler));
return ret;
}

GI_FORCEINLINE
GI_FLOAT32_t GiMultiplyAddFloat32( GI_FLOAT32_t GiMultiplyAddFloat32(
GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) {
#if defined(GI_NEON_INTRINSICS) #if defined(GI_NEON_INTRINSICS)
@@ -888,6 +900,23 @@ GI_FLOAT32_t GiMultiplyAddScalarFloat32(
} }


GI_FORCEINLINE GI_FORCEINLINE
GI_FLOAT32_V2_t GiMultiplyAddScalarFloat32V2(
GI_FLOAT32_V2_t VectorSum, GI_FLOAT32_V2_t Vector, float Scalar) {
GI_FLOAT32_V2_t ret;
GiSetSubVectorFloat32V2(
ret, 0,
GiMultiplyAddScalarFloat32(
GiGetSubVectorFloat32V2(VectorSum, 0),
GiGetSubVectorFloat32V2(Vector, 0), Scalar));
GiSetSubVectorFloat32V2(
ret, 1,
GiMultiplyAddScalarFloat32(
GiGetSubVectorFloat32V2(VectorSum, 1),
GiGetSubVectorFloat32V2(Vector, 1), Scalar));
return ret;
}

GI_FORCEINLINE
GI_FLOAT32_t GiMultiplySubScalarFloat32( GI_FLOAT32_t GiMultiplySubScalarFloat32(
GI_FLOAT32_t VectorSub, GI_FLOAT32_t Vector, float Scalar) { GI_FLOAT32_t VectorSub, GI_FLOAT32_t Vector, float Scalar) {
#if defined(GI_NEON_INTRINSICS) #if defined(GI_NEON_INTRINSICS)
@@ -951,6 +980,26 @@ GI_FLOAT32_t GiDivideFloat32(GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) {
#endif #endif
} }


#define OPV2(op) \
GI_FORCEINLINE \
GI_FLOAT32_V2_t op##V2(GI_FLOAT32_V2_t Vector1, GI_FLOAT32_V2_t Vector2) { \
GI_FLOAT32_V2_t ret; \
GiSetSubVectorFloat32V2( \
ret, 0, \
op(GiGetSubVectorFloat32V2(Vector1, 0), \
GiGetSubVectorFloat32V2(Vector2, 0))); \
GiSetSubVectorFloat32V2( \
ret, 1, \
op(GiGetSubVectorFloat32V2(Vector1, 1), \
GiGetSubVectorFloat32V2(Vector2, 1))); \
return ret; \
}
OPV2(GiAddFloat32);
OPV2(GiSubtractFloat32);
OPV2(GiMultiplyFloat32);
OPV2(GiDivideFloat32);
#undef OPV2

GI_FORCEINLINE GI_FORCEINLINE
GI_FLOAT32_t GiRecpeSFloat32(GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { GI_FLOAT32_t GiRecpeSFloat32(GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) {
#if defined(GI_NEON64_INTRINSICS) #if defined(GI_NEON64_INTRINSICS)


Loading…
Cancel
Save