GitOrigin-RevId: fce5103088
HuaHua404-patch-4
@@ -1,7 +1,6 @@ | |||||
#include "src/fallback/conv_bias/gi/fp32/channel_wise_3x3_s1p1_nchw44_kern.h" | #include "src/fallback/conv_bias/gi/fp32/channel_wise_3x3_s1p1_nchw44_kern.h" | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
#include "src/fallback/conv_bias/gi/utils.h" | |||||
#include "src/fallback/elemwise_helper/elemwise_op.h" | #include "src/fallback/elemwise_helper/elemwise_op.h" | ||||
#pragma GCC diagnostic ignored "-Wunused-parameter" | #pragma GCC diagnostic ignored "-Wunused-parameter" | ||||
@@ -1,7 +1,6 @@ | |||||
#include "src/fallback/conv_bias/gi/fp32/channel_wise_5x5_s1p2_nchw44_kern.h" | #include "src/fallback/conv_bias/gi/fp32/channel_wise_5x5_s1p2_nchw44_kern.h" | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
#include "src/fallback/conv_bias/gi/utils.h" | |||||
#include "src/fallback/elemwise_helper/elemwise_op.h" | #include "src/fallback/elemwise_helper/elemwise_op.h" | ||||
#pragma GCC diagnostic ignored "-Wunused-parameter" | #pragma GCC diagnostic ignored "-Wunused-parameter" | ||||
@@ -4,7 +4,6 @@ | |||||
#include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
#include "src/fallback/conv_bias/gi/fp32/channel_wise_3x3_s1p1_nchw44_kern.h" | #include "src/fallback/conv_bias/gi/fp32/channel_wise_3x3_s1p1_nchw44_kern.h" | ||||
#include "src/fallback/conv_bias/gi/fp32/channel_wise_5x5_s1p2_nchw44_kern.h" | #include "src/fallback/conv_bias/gi/fp32/channel_wise_5x5_s1p2_nchw44_kern.h" | ||||
#include "src/fallback/conv_bias/gi/utils.h" | |||||
#include "src/fallback/elemwise_helper/elemwise_op.h" | #include "src/fallback/elemwise_helper/elemwise_op.h" | ||||
using namespace megdnn; | using namespace megdnn; | ||||
@@ -3,29 +3,44 @@ | |||||
#include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "src/fallback/conv_bias/gi/fp32/helper.h" | #include "src/fallback/conv_bias/gi/fp32/helper.h" | ||||
#include "src/fallback/conv_bias/gi/utils.h" | |||||
namespace megdnn { | namespace megdnn { | ||||
namespace fallback { | namespace fallback { | ||||
/* | |||||
* wd##0 = d##0; | |||||
* tmp0 = (d##0 + d##2) * -0.2222222f | |||||
* tmp1 = d##1 * -0.2222222 | |||||
* wd##1 = tmp0 + tmp1 | |||||
* wd##2 = tmp0 - tmp1 | |||||
* tmp0 = d##0 * 0.0111111f + d##2 * 0.0444444f | |||||
* tmp1 = d##1 * 0.0222222f | |||||
* wd##3 = tmp0 + tmp1 | |||||
* wd##4 = tmp0 - tmp1 | |||||
* tmp0 = d##0 * 0.7111111f + d##2 * 0.1777778f | |||||
* tmp1 = d##1 * 0.3555556f | |||||
* wd##5 = tmp0 + tmp1 | |||||
* wd##6 = tmp0 - tmp1 | |||||
* wd##7 = d##2 | |||||
*/ | |||||
template <param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT> | template <param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT> | ||||
struct FilterTransform6X3 { | struct FilterTransform6X3 { | ||||
#define FILTER_TRANSFORM(d, wd) \ | |||||
do { \ | |||||
wd##0 = d##0; \ | |||||
auto tmp0 = (d##0 + d##2) * -0.2222222f; \ | |||||
auto tmp1 = d##1 * -0.2222222f; \ | |||||
wd##1 = tmp0 + tmp1; \ | |||||
wd##2 = tmp0 - tmp1; \ | |||||
tmp0 = d##0 * 0.0111111f + d##2 * 0.0444444f; \ | |||||
tmp1 = d##1 * 0.0222222f; \ | |||||
wd##3 = tmp0 + tmp1; \ | |||||
wd##4 = tmp0 - tmp1; \ | |||||
tmp0 = d##0 * 0.7111111f + d##2 * 0.1777778f; \ | |||||
tmp1 = d##1 * 0.3555556f; \ | |||||
wd##5 = tmp0 + tmp1; \ | |||||
wd##6 = tmp0 - tmp1; \ | |||||
wd##7 = d##2; \ | |||||
#define FILTER_TRANSFORM(d, wd, ADDC, SUBC, MULC) \ | |||||
do { \ | |||||
wd##0 = d##0; \ | |||||
auto tmp0 = MULC(ADDC(d##0, d##2), -0.2222222f); \ | |||||
auto tmp1 = MULC(d##1, -0.2222222f); \ | |||||
wd##1 = ADDC(tmp0, tmp1); \ | |||||
wd##2 = SUBC(tmp0, tmp1); \ | |||||
tmp0 = ADDC(MULC(d##0, 0.0111111f), MULC(d##2, 0.0444444f)); \ | |||||
tmp1 = MULC(d##1, 0.0222222f); \ | |||||
wd##3 = ADDC(tmp0, tmp1); \ | |||||
wd##4 = SUBC(tmp0, tmp1); \ | |||||
tmp0 = ADDC(MULC(d##0, 0.7111111f), MULC(d##2, 0.1777778f)); \ | |||||
tmp1 = MULC(d##1, 0.3555556f); \ | |||||
wd##5 = ADDC(tmp0, tmp1); \ | |||||
wd##6 = SUBC(tmp0, tmp1); \ | |||||
wd##7 = d##2; \ | |||||
} while (0); | } while (0); | ||||
static void transform( | static void transform( | ||||
@@ -49,37 +64,35 @@ struct FilterTransform6X3 { | |||||
rep(ic, IC) { | rep(ic, IC) { | ||||
const float* fptr = filter + (oc * IC + ic) * 3 * 3; | const float* fptr = filter + (oc * IC + ic) * 3 * 3; | ||||
Vector<float, 4> g0 = Vector<float, 4>::load(fptr); | |||||
Vector<float, 4> g1 = Vector<float, 4>::load(fptr + 3); | |||||
GI_FLOAT32_t g0 = GiLoadFloat32(fptr); | |||||
GI_FLOAT32_t g1 = GiLoadFloat32(fptr + 3); | |||||
Vector<float, 4> g2 = Vector<float, 4>::load(fptr + 6 - 1); | |||||
GI_FLOAT32_t g2 = GiLoadFloat32(fptr + 6 - 1); | |||||
GI_FLOAT32_t zeros = GiZeroFloat32(); | GI_FLOAT32_t zeros = GiZeroFloat32(); | ||||
g2.value = GiFloat32Type2FixLenType( | |||||
GiExtqFloat32(GiFixLenType2GiFloat32Type(g2.value), zeros, 1)); | |||||
g2 = GiExtqFloat32(g2, zeros, 1); | |||||
#define cb(i) Vector<float, 4> wd##i; | |||||
#define cb(i) GI_FLOAT32_t wd##i; | |||||
UNROLL_CALL_NOWRAPPER(8, cb); | UNROLL_CALL_NOWRAPPER(8, cb); | ||||
#undef cb | #undef cb | ||||
#define cb(i) Vector<float, 8> wdt##i; | |||||
UNROLL_CALL_NOWRAPPER(3, cb); | |||||
#undef cb | |||||
#define cb(i) Vector<float, 8> ret##i; | |||||
UNROLL_CALL_NOWRAPPER(8, cb); | |||||
#undef cb | |||||
FILTER_TRANSFORM(g, wd); | |||||
FILTER_TRANSFORM(g, wd, ADDF, SUBF, MULSF); | |||||
size_t ocb = oc / 4; | size_t ocb = oc / 4; | ||||
size_t oc4 = oc % 4; | size_t oc4 = oc % 4; | ||||
size_t icb = ic / 4; | size_t icb = ic / 4; | ||||
size_t ic4 = ic % 4; | size_t ic4 = ic % 4; | ||||
#if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
#define cb(i) GI_FLOAT32_V2_t wdt##i; | |||||
UNROLL_CALL_NOWRAPPER(3, cb); | |||||
#undef cb | |||||
#define cb(i) GI_FLOAT32_V2_t ret##i; | |||||
UNROLL_CALL_NOWRAPPER(8, cb); | |||||
#undef cb | |||||
TRANSPOSE_8x3(wd, wdt); | TRANSPOSE_8x3(wd, wdt); | ||||
FILTER_TRANSFORM(wdt, ret); | |||||
FILTER_TRANSFORM(wdt, ret, ADDFV2, SUBFV2, MULSFV2); | |||||
#define cb(i) ret##i.save(transform_mid_buf + i * alpha); | |||||
#define cb(i) GiStoreFloat32V2(transform_mid_buf + i * alpha, ret##i); | |||||
UNROLL_CALL_NOWRAPPER(8, cb); | UNROLL_CALL_NOWRAPPER(8, cb); | ||||
#undef cb | #undef cb | ||||
rep(i, alpha) rep(j, alpha) { | rep(i, alpha) rep(j, alpha) { | ||||
@@ -116,8 +129,7 @@ struct FilterTransform6X3 { | |||||
mid_buf1[7] = GET_VECTOR_ELEM(wd, i, 2); \ | mid_buf1[7] = GET_VECTOR_ELEM(wd, i, 2); \ | ||||
mid_buf1 += 8; \ | mid_buf1 += 8; \ | ||||
} while (0); | } while (0); | ||||
#define GET_VECTOR_ELEM(s, i, idx) \ | |||||
GiExtractLane##idx##Float32(GiFixLenType2GiFloat32Type(CONCAT(s, i).value)) | |||||
#define GET_VECTOR_ELEM(s, i, idx) GiExtractLane##idx##Float32(CONCAT(s, i)) | |||||
float* mid_buf1 = transform_mid_buf; | float* mid_buf1 = transform_mid_buf; | ||||
UNROLL_CALL_NOWRAPPER(8, cb); | UNROLL_CALL_NOWRAPPER(8, cb); | ||||
@@ -2,6 +2,15 @@ | |||||
#include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
#include "src/fallback/general_intrinsic/gi_float.h" | #include "src/fallback/general_intrinsic/gi_float.h" | ||||
#define ADDF GiAddFloat32 | |||||
#define ADDFV2 GiAddFloat32V2 | |||||
#define SUBF GiSubtractFloat32 | |||||
#define SUBFV2 GiSubtractFloat32V2 | |||||
#define MULF GiMultiplyFloat32 | |||||
#define MULFV2 GiMultiplyFloat32V2 | |||||
#define MULSF GiMultiplyScalerFloat32 | |||||
#define MULSFV2 GiMultiplyScalerFloat32V2 | |||||
namespace megdnn { | namespace megdnn { | ||||
namespace fallback { | namespace fallback { | ||||
inline void transpose_4x4(const float* src, float* dst, int lda, int ldb) { | inline void transpose_4x4(const float* src, float* dst, int lda, int ldb) { | ||||
@@ -47,159 +56,166 @@ inline void transpose_4x4(const float* src, float* dst, int lda, int ldb) { | |||||
#define CONCAT(a, idx) a##idx | #define CONCAT(a, idx) a##idx | ||||
#if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
//! ret and a are type Vector<float, 8> | |||||
#define TRANSPOSE_8x8(a, ret) \ | |||||
do { \ | |||||
auto b0 = GiZipqFloat32(CONCAT(a, 0).value.val[0], CONCAT(a, 1).value.val[0]); \ | |||||
auto b1 = GiZipqFloat32(CONCAT(a, 0).value.val[1], CONCAT(a, 1).value.val[1]); \ | |||||
auto b2 = GiZipqFloat32(CONCAT(a, 2).value.val[0], CONCAT(a, 3).value.val[0]); \ | |||||
auto b3 = GiZipqFloat32(CONCAT(a, 2).value.val[1], CONCAT(a, 3).value.val[1]); \ | |||||
auto b4 = GiZipqFloat32(CONCAT(a, 4).value.val[0], CONCAT(a, 5).value.val[0]); \ | |||||
auto b5 = GiZipqFloat32(CONCAT(a, 4).value.val[1], CONCAT(a, 5).value.val[1]); \ | |||||
auto b6 = GiZipqFloat32(CONCAT(a, 6).value.val[0], CONCAT(a, 7).value.val[0]); \ | |||||
auto b7 = GiZipqFloat32(CONCAT(a, 6).value.val[1], CONCAT(a, 7).value.val[1]); \ | |||||
CONCAT(ret, 0).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||||
GiReinterpretqFloat32ToS64(b0.val[0]), \ | |||||
GiReinterpretqFloat32ToS64(b2.val[0]))); \ | |||||
CONCAT(ret, 0).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||||
GiReinterpretqFloat32ToS64(b4.val[0]), \ | |||||
GiReinterpretqFloat32ToS64(b6.val[0]))); \ | |||||
CONCAT(ret, 1).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||||
GiReinterpretqFloat32ToS64(b0.val[0]), \ | |||||
GiReinterpretqFloat32ToS64(b2.val[0]))); \ | |||||
CONCAT(ret, 1).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||||
GiReinterpretqFloat32ToS64(b4.val[0]), \ | |||||
GiReinterpretqFloat32ToS64(b6.val[0]))); \ | |||||
CONCAT(ret, 2).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||||
GiReinterpretqFloat32ToS64(b0.val[1]), \ | |||||
GiReinterpretqFloat32ToS64(b2.val[1]))); \ | |||||
CONCAT(ret, 2).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||||
GiReinterpretqFloat32ToS64(b4.val[1]), \ | |||||
GiReinterpretqFloat32ToS64(b6.val[1]))); \ | |||||
CONCAT(ret, 3).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||||
GiReinterpretqFloat32ToS64(b0.val[1]), \ | |||||
GiReinterpretqFloat32ToS64(b2.val[1]))); \ | |||||
CONCAT(ret, 3).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||||
GiReinterpretqFloat32ToS64(b4.val[1]), \ | |||||
GiReinterpretqFloat32ToS64(b6.val[1]))); \ | |||||
CONCAT(ret, 4).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||||
GiReinterpretqFloat32ToS64(b1.val[0]), \ | |||||
GiReinterpretqFloat32ToS64(b3.val[0]))); \ | |||||
CONCAT(ret, 4).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||||
GiReinterpretqFloat32ToS64(b5.val[0]), \ | |||||
GiReinterpretqFloat32ToS64(b7.val[0]))); \ | |||||
CONCAT(ret, 5).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||||
GiReinterpretqFloat32ToS64(b1.val[0]), \ | |||||
GiReinterpretqFloat32ToS64(b3.val[0]))); \ | |||||
CONCAT(ret, 5).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||||
GiReinterpretqFloat32ToS64(b5.val[0]), \ | |||||
GiReinterpretqFloat32ToS64(b7.val[0]))); \ | |||||
CONCAT(ret, 6).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||||
GiReinterpretqFloat32ToS64(b1.val[1]), \ | |||||
GiReinterpretqFloat32ToS64(b3.val[1]))); \ | |||||
CONCAT(ret, 6).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||||
GiReinterpretqFloat32ToS64(b5.val[1]), \ | |||||
GiReinterpretqFloat32ToS64(b7.val[1]))); \ | |||||
CONCAT(ret, 7).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||||
GiReinterpretqFloat32ToS64(b1.val[1]), \ | |||||
GiReinterpretqFloat32ToS64(b3.val[1]))); \ | |||||
CONCAT(ret, 7).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||||
GiReinterpretqFloat32ToS64(b5.val[1]), \ | |||||
GiReinterpretqFloat32ToS64(b7.val[1]))); \ | |||||
#define TRANSPOSE_8x8(a, ret) \ | |||||
do { \ | |||||
auto b0 = GiZipqFloat32(CONCAT(a, 0).val[0], CONCAT(a, 1).val[0]); \ | |||||
auto b1 = GiZipqFloat32(CONCAT(a, 0).val[1], CONCAT(a, 1).val[1]); \ | |||||
auto b2 = GiZipqFloat32(CONCAT(a, 2).val[0], CONCAT(a, 3).val[0]); \ | |||||
auto b3 = GiZipqFloat32(CONCAT(a, 2).val[1], CONCAT(a, 3).val[1]); \ | |||||
auto b4 = GiZipqFloat32(CONCAT(a, 4).val[0], CONCAT(a, 5).val[0]); \ | |||||
auto b5 = GiZipqFloat32(CONCAT(a, 4).val[1], CONCAT(a, 5).val[1]); \ | |||||
auto b6 = GiZipqFloat32(CONCAT(a, 6).val[0], CONCAT(a, 7).val[0]); \ | |||||
auto b7 = GiZipqFloat32(CONCAT(a, 6).val[1], CONCAT(a, 7).val[1]); \ | |||||
CONCAT(ret, 0).val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||||
GiReinterpretqFloat32ToS64(b0.val[0]), \ | |||||
GiReinterpretqFloat32ToS64(b2.val[0]))); \ | |||||
CONCAT(ret, 0).val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||||
GiReinterpretqFloat32ToS64(b4.val[0]), \ | |||||
GiReinterpretqFloat32ToS64(b6.val[0]))); \ | |||||
CONCAT(ret, 1).val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||||
GiReinterpretqFloat32ToS64(b0.val[0]), \ | |||||
GiReinterpretqFloat32ToS64(b2.val[0]))); \ | |||||
CONCAT(ret, 1).val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||||
GiReinterpretqFloat32ToS64(b4.val[0]), \ | |||||
GiReinterpretqFloat32ToS64(b6.val[0]))); \ | |||||
CONCAT(ret, 2).val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||||
GiReinterpretqFloat32ToS64(b0.val[1]), \ | |||||
GiReinterpretqFloat32ToS64(b2.val[1]))); \ | |||||
CONCAT(ret, 2).val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||||
GiReinterpretqFloat32ToS64(b4.val[1]), \ | |||||
GiReinterpretqFloat32ToS64(b6.val[1]))); \ | |||||
CONCAT(ret, 3).val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||||
GiReinterpretqFloat32ToS64(b0.val[1]), \ | |||||
GiReinterpretqFloat32ToS64(b2.val[1]))); \ | |||||
CONCAT(ret, 3).val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||||
GiReinterpretqFloat32ToS64(b4.val[1]), \ | |||||
GiReinterpretqFloat32ToS64(b6.val[1]))); \ | |||||
CONCAT(ret, 4).val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||||
GiReinterpretqFloat32ToS64(b1.val[0]), \ | |||||
GiReinterpretqFloat32ToS64(b3.val[0]))); \ | |||||
CONCAT(ret, 4).val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||||
GiReinterpretqFloat32ToS64(b5.val[0]), \ | |||||
GiReinterpretqFloat32ToS64(b7.val[0]))); \ | |||||
CONCAT(ret, 5).val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||||
GiReinterpretqFloat32ToS64(b1.val[0]), \ | |||||
GiReinterpretqFloat32ToS64(b3.val[0]))); \ | |||||
CONCAT(ret, 5).val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||||
GiReinterpretqFloat32ToS64(b5.val[0]), \ | |||||
GiReinterpretqFloat32ToS64(b7.val[0]))); \ | |||||
CONCAT(ret, 6).val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||||
GiReinterpretqFloat32ToS64(b1.val[1]), \ | |||||
GiReinterpretqFloat32ToS64(b3.val[1]))); \ | |||||
CONCAT(ret, 6).val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||||
GiReinterpretqFloat32ToS64(b5.val[1]), \ | |||||
GiReinterpretqFloat32ToS64(b7.val[1]))); \ | |||||
CONCAT(ret, 7).val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||||
GiReinterpretqFloat32ToS64(b1.val[1]), \ | |||||
GiReinterpretqFloat32ToS64(b3.val[1]))); \ | |||||
CONCAT(ret, 7).val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||||
GiReinterpretqFloat32ToS64(b5.val[1]), \ | |||||
GiReinterpretqFloat32ToS64(b7.val[1]))); \ | |||||
} while (0); | } while (0); | ||||
#define TRANSPOSE_8x3(a, ret) \ | |||||
auto b0 = GiZipqFloat32(CONCAT(a, 0).value, CONCAT(a, 1).value); \ | |||||
auto b1 = GiZipqFloat32(CONCAT(a, 2).value, CONCAT(a, 3).value); \ | |||||
auto b2 = GiZipqFloat32(CONCAT(a, 4).value, CONCAT(a, 5).value); \ | |||||
auto b3 = GiZipqFloat32(CONCAT(a, 6).value, CONCAT(a, 7).value); \ | |||||
CONCAT(ret, 0).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||||
GiReinterpretqFloat32ToS64(b0.val[0]), \ | |||||
GiReinterpretqFloat32ToS64(b1.val[0]))); \ | |||||
CONCAT(ret, 0).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||||
GiReinterpretqFloat32ToS64(b2.val[0]), \ | |||||
GiReinterpretqFloat32ToS64(b3.val[0]))); \ | |||||
CONCAT(ret, 1).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||||
GiReinterpretqFloat32ToS64(b0.val[0]), \ | |||||
GiReinterpretqFloat32ToS64(b1.val[0]))); \ | |||||
CONCAT(ret, 1).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||||
GiReinterpretqFloat32ToS64(b2.val[0]), \ | |||||
GiReinterpretqFloat32ToS64(b3.val[0]))); \ | |||||
CONCAT(ret, 2).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||||
GiReinterpretqFloat32ToS64(b0.val[1]), \ | |||||
GiReinterpretqFloat32ToS64(b1.val[1]))); \ | |||||
CONCAT(ret, 2).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||||
GiReinterpretqFloat32ToS64(b2.val[1]), \ | |||||
#define TRANSPOSE_8x3(a, ret) \ | |||||
auto b0 = GiZipqFloat32(CONCAT(a, 0), CONCAT(a, 1)); \ | |||||
auto b1 = GiZipqFloat32(CONCAT(a, 2), CONCAT(a, 3)); \ | |||||
auto b2 = GiZipqFloat32(CONCAT(a, 4), CONCAT(a, 5)); \ | |||||
auto b3 = GiZipqFloat32(CONCAT(a, 6), CONCAT(a, 7)); \ | |||||
CONCAT(ret, 0).val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||||
GiReinterpretqFloat32ToS64(b0.val[0]), \ | |||||
GiReinterpretqFloat32ToS64(b1.val[0]))); \ | |||||
CONCAT(ret, 0).val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||||
GiReinterpretqFloat32ToS64(b2.val[0]), \ | |||||
GiReinterpretqFloat32ToS64(b3.val[0]))); \ | |||||
CONCAT(ret, 1).val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||||
GiReinterpretqFloat32ToS64(b0.val[0]), \ | |||||
GiReinterpretqFloat32ToS64(b1.val[0]))); \ | |||||
CONCAT(ret, 1).val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||||
GiReinterpretqFloat32ToS64(b2.val[0]), \ | |||||
GiReinterpretqFloat32ToS64(b3.val[0]))); \ | |||||
CONCAT(ret, 2).val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||||
GiReinterpretqFloat32ToS64(b0.val[1]), \ | |||||
GiReinterpretqFloat32ToS64(b1.val[1]))); \ | |||||
CONCAT(ret, 2).val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||||
GiReinterpretqFloat32ToS64(b2.val[1]), \ | |||||
GiReinterpretqFloat32ToS64(b3.val[1]))); | GiReinterpretqFloat32ToS64(b3.val[1]))); | ||||
#define TRANSPOSE_8x4(a, ret) \ | |||||
auto b0 = GiZipqFloat32(CONCAT(a, 0).value, CONCAT(a, 1).value); \ | |||||
auto b1 = GiZipqFloat32(CONCAT(a, 2).value, CONCAT(a, 3).value); \ | |||||
auto b2 = GiZipqFloat32(CONCAT(a, 4).value, CONCAT(a, 5).value); \ | |||||
auto b3 = GiZipqFloat32(CONCAT(a, 6).value, CONCAT(a, 7).value); \ | |||||
CONCAT(ret, 0).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||||
GiReinterpretqFloat32ToS64(b0.val[0]), \ | |||||
GiReinterpretqFloat32ToS64(b1.val[0]))); \ | |||||
CONCAT(ret, 0).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||||
GiReinterpretqFloat32ToS64(b2.val[0]), \ | |||||
GiReinterpretqFloat32ToS64(b3.val[0]))); \ | |||||
CONCAT(ret, 1).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||||
GiReinterpretqFloat32ToS64(b0.val[0]), \ | |||||
GiReinterpretqFloat32ToS64(b1.val[0]))); \ | |||||
CONCAT(ret, 1).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||||
GiReinterpretqFloat32ToS64(b2.val[0]), \ | |||||
GiReinterpretqFloat32ToS64(b3.val[0]))); \ | |||||
CONCAT(ret, 2).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||||
GiReinterpretqFloat32ToS64(b0.val[1]), \ | |||||
GiReinterpretqFloat32ToS64(b1.val[1]))); \ | |||||
CONCAT(ret, 2).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||||
GiReinterpretqFloat32ToS64(b2.val[1]), \ | |||||
GiReinterpretqFloat32ToS64(b3.val[1]))); \ | |||||
CONCAT(ret, 3).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||||
GiReinterpretqFloat32ToS64(b0.val[1]), \ | |||||
GiReinterpretqFloat32ToS64(b1.val[1]))); \ | |||||
CONCAT(ret, 3).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||||
GiReinterpretqFloat32ToS64(b2.val[1]), \ | |||||
#define TRANSPOSE_8x4(a, ret) \ | |||||
auto b0 = GiZipqFloat32(CONCAT(a, 0), CONCAT(a, 1)); \ | |||||
auto b1 = GiZipqFloat32(CONCAT(a, 2), CONCAT(a, 3)); \ | |||||
auto b2 = GiZipqFloat32(CONCAT(a, 4), CONCAT(a, 5)); \ | |||||
auto b3 = GiZipqFloat32(CONCAT(a, 6), CONCAT(a, 7)); \ | |||||
CONCAT(ret, 0).val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||||
GiReinterpretqFloat32ToS64(b0.val[0]), \ | |||||
GiReinterpretqFloat32ToS64(b1.val[0]))); \ | |||||
CONCAT(ret, 0).val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||||
GiReinterpretqFloat32ToS64(b2.val[0]), \ | |||||
GiReinterpretqFloat32ToS64(b3.val[0]))); \ | |||||
CONCAT(ret, 1).val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||||
GiReinterpretqFloat32ToS64(b0.val[0]), \ | |||||
GiReinterpretqFloat32ToS64(b1.val[0]))); \ | |||||
CONCAT(ret, 1).val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||||
GiReinterpretqFloat32ToS64(b2.val[0]), \ | |||||
GiReinterpretqFloat32ToS64(b3.val[0]))); \ | |||||
CONCAT(ret, 2).val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||||
GiReinterpretqFloat32ToS64(b0.val[1]), \ | |||||
GiReinterpretqFloat32ToS64(b1.val[1]))); \ | |||||
CONCAT(ret, 2).val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||||
GiReinterpretqFloat32ToS64(b2.val[1]), \ | |||||
GiReinterpretqFloat32ToS64(b3.val[1]))); \ | |||||
CONCAT(ret, 3).val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||||
GiReinterpretqFloat32ToS64(b0.val[1]), \ | |||||
GiReinterpretqFloat32ToS64(b1.val[1]))); \ | |||||
CONCAT(ret, 3).val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||||
GiReinterpretqFloat32ToS64(b2.val[1]), \ | |||||
GiReinterpretqFloat32ToS64(b3.val[1]))); | GiReinterpretqFloat32ToS64(b3.val[1]))); | ||||
#else | #else | ||||
#define TRANSPOSE_8x4(a, ret) \ | |||||
auto b0 = GiZipqFloat32( \ | |||||
GiFixLenType2GiFloat32Type(CONCAT(a, 0).value), \ | |||||
GiFixLenType2GiFloat32Type(CONCAT(a, 1).value)); \ | |||||
auto b1 = GiZipqFloat32( \ | |||||
GiFixLenType2GiFloat32Type(CONCAT(a, 2).value), \ | |||||
GiFixLenType2GiFloat32Type(CONCAT(a, 3).value)); \ | |||||
auto b2 = GiZipqFloat32( \ | |||||
GiFixLenType2GiFloat32Type(CONCAT(a, 4).value), \ | |||||
GiFixLenType2GiFloat32Type(CONCAT(a, 5).value)); \ | |||||
auto b3 = GiZipqFloat32( \ | |||||
GiFixLenType2GiFloat32Type(CONCAT(a, 6).value), \ | |||||
GiFixLenType2GiFloat32Type(CONCAT(a, 7).value)); \ | |||||
CONCAT(ret, 0).value.val[0] = GiFloat32Type2FixLenType(GiCombineFloat32( \ | |||||
GiGetLowFloat32(GiGetSubVectorFloat32V2(b0, 0)), \ | |||||
GiGetLowFloat32(GiGetSubVectorFloat32V2(b1, 0)))); \ | |||||
CONCAT(ret, 1).value.val[0] = GiFloat32Type2FixLenType(GiCombineFloat32( \ | |||||
GiGetHighFloat32(GiGetSubVectorFloat32V2(b0, 0)), \ | |||||
GiGetHighFloat32(GiGetSubVectorFloat32V2(b1, 0)))); \ | |||||
CONCAT(ret, 2).value.val[0] = GiFloat32Type2FixLenType(GiCombineFloat32( \ | |||||
GiGetLowFloat32(GiGetSubVectorFloat32V2(b0, 1)), \ | |||||
GiGetLowFloat32(GiGetSubVectorFloat32V2(b1, 1)))); \ | |||||
CONCAT(ret, 3).value.val[0] = GiFloat32Type2FixLenType(GiCombineFloat32( \ | |||||
GiGetHighFloat32(GiGetSubVectorFloat32V2(b0, 1)), \ | |||||
GiGetHighFloat32(GiGetSubVectorFloat32V2(b1, 1)))); \ | |||||
CONCAT(ret, 0).value.val[1] = GiFloat32Type2FixLenType(GiCombineFloat32( \ | |||||
GiGetLowFloat32(GiGetSubVectorFloat32V2(b2, 0)), \ | |||||
GiGetLowFloat32(GiGetSubVectorFloat32V2(b3, 0)))); \ | |||||
CONCAT(ret, 1).value.val[1] = GiFloat32Type2FixLenType(GiCombineFloat32( \ | |||||
GiGetHighFloat32(GiGetSubVectorFloat32V2(b2, 0)), \ | |||||
GiGetHighFloat32(GiGetSubVectorFloat32V2(b3, 0)))); \ | |||||
CONCAT(ret, 2).value.val[1] = GiFloat32Type2FixLenType(GiCombineFloat32( \ | |||||
GiGetLowFloat32(GiGetSubVectorFloat32V2(b2, 1)), \ | |||||
GiGetLowFloat32(GiGetSubVectorFloat32V2(b3, 1)))); \ | |||||
CONCAT(ret, 3).value.val[1] = GiFloat32Type2FixLenType(GiCombineFloat32( \ | |||||
GiGetHighFloat32(GiGetSubVectorFloat32V2(b2, 1)), \ | |||||
GiGetHighFloat32(GiGetSubVectorFloat32V2(b3, 1)))); | |||||
#define TRANSPOSE_8x4(a, ret) \ | |||||
auto b0 = GiZipqFloat32(CONCAT(a, 0), CONCAT(a, 1)); \ | |||||
auto b1 = GiZipqFloat32(CONCAT(a, 2), CONCAT(a, 3)); \ | |||||
auto b2 = GiZipqFloat32(CONCAT(a, 4), CONCAT(a, 5)); \ | |||||
auto b3 = GiZipqFloat32(CONCAT(a, 6), CONCAT(a, 7)); \ | |||||
GiSetSubVectorFloat32V2( \ | |||||
CONCAT(ret, 0), 0, \ | |||||
GiCombineFloat32( \ | |||||
GiGetLowFloat32(GiGetSubVectorFloat32V2(b0, 0)), \ | |||||
GiGetLowFloat32(GiGetSubVectorFloat32V2(b1, 0)))); \ | |||||
GiSetSubVectorFloat32V2( \ | |||||
CONCAT(ret, 1), 0, \ | |||||
GiCombineFloat32( \ | |||||
GiGetHighFloat32(GiGetSubVectorFloat32V2(b0, 0)), \ | |||||
GiGetHighFloat32(GiGetSubVectorFloat32V2(b1, 0)))); \ | |||||
GiSetSubVectorFloat32V2( \ | |||||
CONCAT(ret, 2), 0, \ | |||||
GiCombineFloat32( \ | |||||
GiGetLowFloat32(GiGetSubVectorFloat32V2(b0, 1)), \ | |||||
GiGetLowFloat32(GiGetSubVectorFloat32V2(b1, 1)))); \ | |||||
GiSetSubVectorFloat32V2( \ | |||||
CONCAT(ret, 3), 0, \ | |||||
GiCombineFloat32( \ | |||||
GiGetHighFloat32(GiGetSubVectorFloat32V2(b0, 1)), \ | |||||
GiGetHighFloat32(GiGetSubVectorFloat32V2(b1, 1)))); \ | |||||
GiSetSubVectorFloat32V2( \ | |||||
CONCAT(ret, 0), 1, \ | |||||
GiCombineFloat32( \ | |||||
GiGetLowFloat32(GiGetSubVectorFloat32V2(b2, 0)), \ | |||||
GiGetLowFloat32(GiGetSubVectorFloat32V2(b3, 0)))); \ | |||||
GiSetSubVectorFloat32V2( \ | |||||
CONCAT(ret, 1), 1, \ | |||||
GiCombineFloat32( \ | |||||
GiGetHighFloat32(GiGetSubVectorFloat32V2(b2, 0)), \ | |||||
GiGetHighFloat32(GiGetSubVectorFloat32V2(b3, 0)))); \ | |||||
GiSetSubVectorFloat32V2( \ | |||||
CONCAT(ret, 2), 1, \ | |||||
GiCombineFloat32( \ | |||||
GiGetLowFloat32(GiGetSubVectorFloat32V2(b2, 1)), \ | |||||
GiGetLowFloat32(GiGetSubVectorFloat32V2(b3, 1)))); \ | |||||
GiSetSubVectorFloat32V2( \ | |||||
CONCAT(ret, 3), 1, \ | |||||
GiCombineFloat32( \ | |||||
GiGetHighFloat32(GiGetSubVectorFloat32V2(b2, 1)), \ | |||||
GiGetHighFloat32(GiGetSubVectorFloat32V2(b3, 1)))); | |||||
#endif | #endif | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,7 +1,6 @@ | |||||
#include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "src/fallback/conv_bias/gi/fp32/strategy.h" | #include "src/fallback/conv_bias/gi/fp32/strategy.h" | ||||
#include "src/fallback/conv_bias/gi/utils.h" | |||||
#include "src/fallback/conv_bias/winograd/winograd.h" | #include "src/fallback/conv_bias/winograd/winograd.h" | ||||
#include "src/fallback/conv_bias/gi/fp32/helper.h" | #include "src/fallback/conv_bias/gi/fp32/helper.h" | ||||
@@ -70,8 +69,7 @@ struct InputTransform2X3 { | |||||
size_t nr_units_in_tile, size_t ic, size_t IC) { | size_t nr_units_in_tile, size_t ic, size_t IC) { | ||||
constexpr size_t alpha = 2 + 3 - 1; | constexpr size_t alpha = 2 + 3 - 1; | ||||
// BT * d * B | // BT * d * B | ||||
#define cb(m, n) \ | |||||
Vector<float, 4> d##m##n = Vector<float, 4>::load(patchT + m * 4 * 4 + n * 4); | |||||
#define cb(m, n) GI_FLOAT32_t d##m##n = GiLoadFloat32(patchT + m * 4 * 4 + n * 4); | |||||
UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); | UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); | ||||
#undef cb | #undef cb | ||||
@@ -80,20 +78,20 @@ struct InputTransform2X3 { | |||||
//! 0 1 1 0 d10 d11 d12 d13 0 1 -1 -1 | //! 0 1 1 0 d10 d11 d12 d13 0 1 -1 -1 | ||||
//! 0 -1 1 0 d20 d21 d22 d23 -1 1 1 0 | //! 0 -1 1 0 d20 d21 d22 d23 -1 1 1 0 | ||||
//! 0 -1 0 1 d30 d31 d32 d33 0 0 0 1 | //! 0 -1 0 1 d30 d31 d32 d33 0 0 0 1 | ||||
#define cb(m) \ | |||||
auto t0##m = d0##m - d2##m; \ | |||||
auto t1##m = d1##m + d2##m; \ | |||||
auto t2##m = d2##m - d1##m; \ | |||||
auto t3##m = d3##m - d1##m; | |||||
#define cb(m) \ | |||||
auto t0##m = SUBF(d0##m, d2##m); \ | |||||
auto t1##m = ADDF(d1##m, d2##m); \ | |||||
auto t2##m = SUBF(d2##m, d1##m); \ | |||||
auto t3##m = SUBF(d3##m, d1##m); | |||||
UNROLL_CALL_NOWRAPPER(4, cb); | UNROLL_CALL_NOWRAPPER(4, cb); | ||||
#undef cb | #undef cb | ||||
#define cb(m) \ | |||||
d##m##0 = t##m##0 - t##m##2; \ | |||||
d##m##1 = t##m##1 + t##m##2; \ | |||||
d##m##2 = t##m##2 - t##m##1; \ | |||||
d##m##3 = t##m##3 - t##m##1; | |||||
#define cb(m) \ | |||||
d##m##0 = SUBF(t##m##0, t##m##2); \ | |||||
d##m##1 = ADDF(t##m##1, t##m##2); \ | |||||
d##m##2 = SUBF(t##m##2, t##m##1); \ | |||||
d##m##3 = SUBF(t##m##3, t##m##1); | |||||
UNROLL_CALL_NOWRAPPER(4, cb); | UNROLL_CALL_NOWRAPPER(4, cb); | ||||
#undef cb | #undef cb | ||||
@@ -101,9 +99,10 @@ struct InputTransform2X3 { | |||||
size_t ICB = IC / 4; | size_t ICB = IC / 4; | ||||
size_t icb = ic / 4; | size_t icb = ic / 4; | ||||
#define cb(m, n) \ | #define cb(m, n) \ | ||||
d##m##n.save( \ | |||||
GiStoreFloat32( \ | |||||
input_transform_buf + (m * alpha + n) * ICB * nr_units_in_tile * 4 + \ | input_transform_buf + (m * alpha + n) * ICB * nr_units_in_tile * 4 + \ | ||||
icb * nr_units_in_tile * 4 + unit_idx * 4); | |||||
icb * nr_units_in_tile * 4 + unit_idx * 4, \ | |||||
d##m##n); | |||||
UNROLL_CALL_NOWRAPPER_D2(4, 4, cb) | UNROLL_CALL_NOWRAPPER_D2(4, 4, cb) | ||||
#undef cb | #undef cb | ||||
} | } | ||||
@@ -125,7 +124,7 @@ struct OutputTransform2X3 { | |||||
size_t ocb = oc_index / 4; | size_t ocb = oc_index / 4; | ||||
#define cb(m, n) \ | #define cb(m, n) \ | ||||
auto v##m##n = Vector<float, 4>::load( \ | |||||
auto v##m##n = GiLoadFloat32( \ | |||||
output_transform_buf + (m * alpha + n) * OCB * nr_units_in_tile * 4 + \ | output_transform_buf + (m * alpha + n) * OCB * nr_units_in_tile * 4 + \ | ||||
ocb * nr_units_in_tile * 4 + unit_idx * 4); | ocb * nr_units_in_tile * 4 + unit_idx * 4); | ||||
UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); | UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); | ||||
@@ -134,37 +133,37 @@ struct OutputTransform2X3 { | |||||
//! 0 1 -1 1 v10 v11 v12 v13 1 1 | //! 0 1 -1 1 v10 v11 v12 v13 1 1 | ||||
//! v20 v21 v22 v23 1 -1 | //! v20 v21 v22 v23 1 -1 | ||||
//! v30 v31 v32 v33 0 1 | //! v30 v31 v32 v33 0 1 | ||||
#define cb(m) \ | |||||
auto t0##m = v0##m + v1##m + v2##m; \ | |||||
auto t1##m = v1##m - v2##m + v3##m; | |||||
#define cb(m) \ | |||||
auto t0##m = ADDF(ADDF(v0##m, v1##m), v2##m); \ | |||||
auto t1##m = ADDF(SUBF(v1##m, v2##m), v3##m); | |||||
UNROLL_CALL_NOWRAPPER(4, cb); | UNROLL_CALL_NOWRAPPER(4, cb); | ||||
#undef cb | #undef cb | ||||
v00 = t00 + t01 + t02; | |||||
v10 = t10 + t11 + t12; | |||||
v01 = t01 - t02 + t03; | |||||
v11 = t11 - t12 + t13; | |||||
v00 = ADDF(ADDF(t00, t01), t02); | |||||
v10 = ADDF(ADDF(t10, t11), t12); | |||||
v01 = ADDF(SUBF(t01, t02), t03); | |||||
v11 = ADDF(SUBF(t11, t12), t13); | |||||
Vector<float, 4> vbias; | |||||
GI_FLOAT32_t vbias; | |||||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | ||||
vbias = Vector<float, 4>::load(bias + oc); | |||||
vbias = GiLoadFloat32(bias + oc); | |||||
v00 += vbias; | |||||
v10 += vbias; | |||||
v01 += vbias; | |||||
v11 += vbias; | |||||
v00 = ADDF(v00, vbias); | |||||
v10 = ADDF(v10, vbias); | |||||
v01 = ADDF(v01, vbias); | |||||
v11 = ADDF(v11, vbias); | |||||
} | } | ||||
if (bmode != BiasMode::BIAS) { | if (bmode != BiasMode::BIAS) { | ||||
v00 = op(GiFixLenType2GiFloat32Type(v00.value)); | |||||
v01 = op(GiFixLenType2GiFloat32Type(v01.value)); | |||||
v10 = op(GiFixLenType2GiFloat32Type(v10.value)); | |||||
v11 = op(GiFixLenType2GiFloat32Type(v11.value)); | |||||
v00 = op(v00); | |||||
v01 = op(v01); | |||||
v10 = op(v10); | |||||
v11 = op(v11); | |||||
} | } | ||||
v00.save(transform_mid_buf + (0 * 2 + 0) * 4); | |||||
v10.save(transform_mid_buf + (1 * 2 + 0) * 4); | |||||
v01.save(transform_mid_buf + (0 * 2 + 1) * 4); | |||||
v11.save(transform_mid_buf + (1 * 2 + 1) * 4); | |||||
GiStoreFloat32(transform_mid_buf + (0 * 2 + 0) * 4, v00); | |||||
GiStoreFloat32(transform_mid_buf + (1 * 2 + 0) * 4, v10); | |||||
GiStoreFloat32(transform_mid_buf + (0 * 2 + 1) * 4, v01); | |||||
GiStoreFloat32(transform_mid_buf + (1 * 2 + 1) * 4, v11); | |||||
for (size_t oco = 0; oco < 4 && oc + oco < oc_end; ++oco) { | for (size_t oco = 0; oco < 4 && oc + oco < oc_end; ++oco) { | ||||
for (size_t oho = 0; oho < 2 && oh_start + oho < OH; ++oho) { | for (size_t oho = 0; oho < 2 && oh_start + oho < OH; ++oho) { | ||||
@@ -1,7 +1,6 @@ | |||||
#include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "src/fallback/conv_bias/gi/fp32/strategy.h" | #include "src/fallback/conv_bias/gi/fp32/strategy.h" | ||||
#include "src/fallback/conv_bias/gi/utils.h" | |||||
#include "src/fallback/conv_bias/winograd/winograd.h" | #include "src/fallback/conv_bias/winograd/winograd.h" | ||||
#include "src/fallback/conv_bias/gi/fp32/helper.h" | #include "src/fallback/conv_bias/gi/fp32/helper.h" | ||||
@@ -15,61 +14,124 @@ using namespace megdnn; | |||||
using namespace fallback; | using namespace fallback; | ||||
namespace { | namespace { | ||||
/* | |||||
* wd##0 = d##0; | |||||
* wd##r0 = d##r0; | |||||
* wd##1 = (d##0 + d##1 + d##2 + d##3 + d##4) * -0.2222222; | |||||
* wd##r1 = (d##r0 + d##r1 + d##r2 + d##r3 + d##r4) * -0.2222222; | |||||
* wd##2 = (d##0 - d##1 + d##2 - d##3 + d##4) * -0.2222222; | |||||
* wd##r2 = (d##r0 - d##r1 + d##r2 - d##r3 + d##r4) * -0.2222222; | |||||
* auto tmpd0 = d##0 * 0.7111111; | |||||
* auto tmpd1 = d##1 * 0.3555556; | |||||
* auto tmpd2 = d##2 * 0.1777778; | |||||
* auto tmpd3 = d##3 * 0.0888889; | |||||
* auto tmpd4 = d##4 * 0.0444444; | |||||
* auto tmpdr0 = d##r0 * 0.7111111; | |||||
* auto tmpdr1 = d##r1 * 0.3555556; | |||||
* auto tmpdr2 = d##r2 * 0.1777778; | |||||
* auto tmpdr3 = d##r3 * 0.0888889; | |||||
* auto tmpdr4 = d##r4 * 0.0444444; | |||||
* wd##3 = tmpd0 + tmpd1 + tmpd2 + tmpd3 + tmpd4; | |||||
* wd##r3 = tmpdr0 + tmpdr1 + tmpdr2 + tmpdr3 + tmpdr4; | |||||
* wd##4 = tmpd0 - tmpd1 + tmpd2 - tmpd3 + tmpd4; | |||||
* wd##r4 = tmpdr0 - tmpdr1 + tmpdr2 - tmpdr3 + tmpdr4; | |||||
* tmpd0 = d##0 * 0.0111111; | |||||
* tmpd1 = d##1 * 0.0222222; | |||||
* tmpd2 = d##2 * 0.0444444; | |||||
* tmpd3 = d##3 * 0.0888889; | |||||
* tmpd4 = d##4 * 0.1777778; | |||||
* tmpdr0 = d##r0 * 0.0111111; | |||||
* tmpdr1 = d##r1 * 0.0222222; | |||||
* tmpdr2 = d##r2 * 0.0444444; | |||||
* tmpdr3 = d##r3 * 0.0888889; | |||||
* tmpdr4 = d##r4 * 0.1777778; | |||||
* wd##5 = tmpd0 + tmpd1 + tmpd2 + tmpd3 + tmpd4; | |||||
* wd##r5 = tmpdr0 + tmpdr1 + tmpdr2 + tmpdr3 + tmpdr4; | |||||
* wd##6 = tmpd0 - tmpd1 + tmpd2 - tmpd3 + tmpd4; | |||||
* wd##r6 = tmpdr0 - tmpdr1 + tmpdr2 - tmpdr3 + tmpdr4; | |||||
* wd##7 = d##4; | |||||
* wd##r7 = d##r4; | |||||
* | |||||
* */ | |||||
struct FilterTransform4X5 { | struct FilterTransform4X5 { | ||||
#define FILTER_TRANSFORM(d, wd) \ | |||||
do { \ | |||||
wd##0 = d##0; \ | |||||
wd##r0 = d##r0; \ | |||||
wd##1 = (d##0 + d##1 + d##2 + d##3 + d##4) * -0.2222222; \ | |||||
wd##r1 = (d##r0 + d##r1 + d##r2 + d##r3 + d##r4) * -0.2222222; \ | |||||
wd##2 = (d##0 - d##1 + d##2 - d##3 + d##4) * -0.2222222; \ | |||||
wd##r2 = (d##r0 - d##r1 + d##r2 - d##r3 + d##r4) * -0.2222222; \ | |||||
auto tmpd0 = d##0 * 0.7111111; \ | |||||
auto tmpd1 = d##1 * 0.3555556; \ | |||||
auto tmpd2 = d##2 * 0.1777778; \ | |||||
auto tmpd3 = d##3 * 0.0888889; \ | |||||
auto tmpd4 = d##4 * 0.0444444; \ | |||||
auto tmpdr0 = d##r0 * 0.7111111; \ | |||||
auto tmpdr1 = d##r1 * 0.3555556; \ | |||||
auto tmpdr2 = d##r2 * 0.1777778; \ | |||||
auto tmpdr3 = d##r3 * 0.0888889; \ | |||||
auto tmpdr4 = d##r4 * 0.0444444; \ | |||||
wd##3 = tmpd0 + tmpd1 + tmpd2 + tmpd3 + tmpd4; \ | |||||
wd##r3 = tmpdr0 + tmpdr1 + tmpdr2 + tmpdr3 + tmpdr4; \ | |||||
wd##4 = tmpd0 - tmpd1 + tmpd2 - tmpd3 + tmpd4; \ | |||||
wd##r4 = tmpdr0 - tmpdr1 + tmpdr2 - tmpdr3 + tmpdr4; \ | |||||
tmpd0 = d##0 * 0.0111111; \ | |||||
tmpd1 = d##1 * 0.0222222; \ | |||||
tmpd2 = d##2 * 0.0444444; \ | |||||
tmpd3 = d##3 * 0.0888889; \ | |||||
tmpd4 = d##4 * 0.1777778; \ | |||||
tmpdr0 = d##r0 * 0.0111111; \ | |||||
tmpdr1 = d##r1 * 0.0222222; \ | |||||
tmpdr2 = d##r2 * 0.0444444; \ | |||||
tmpdr3 = d##r3 * 0.0888889; \ | |||||
tmpdr4 = d##r4 * 0.1777778; \ | |||||
wd##5 = tmpd0 + tmpd1 + tmpd2 + tmpd3 + tmpd4; \ | |||||
wd##r5 = tmpdr0 + tmpdr1 + tmpdr2 + tmpdr3 + tmpdr4; \ | |||||
wd##6 = tmpd0 - tmpd1 + tmpd2 - tmpd3 + tmpd4; \ | |||||
wd##r6 = tmpdr0 - tmpdr1 + tmpdr2 - tmpdr3 + tmpdr4; \ | |||||
wd##7 = d##4; \ | |||||
wd##r7 = d##r4; \ | |||||
#define FILTER_TRANSFORM(d, wd) \ | |||||
do { \ | |||||
wd##0 = d##0; \ | |||||
wd##r0 = d##r0; \ | |||||
wd##1 = MULSF( \ | |||||
ADDF(ADDF(ADDF(ADDF(d##0, d##1), d##2), d##3), d##4), -0.2222222); \ | |||||
wd##r1 = (d##r0 + d##r1 + d##r2 + d##r3 + d##r4) * -0.2222222; \ | |||||
wd##2 = MULSF( \ | |||||
ADDF(SUBF(ADDF(SUBF(d##0, d##1), d##2), d##3), d##4), -0.2222222); \ | |||||
wd##r2 = (d##r0 - d##r1 + d##r2 - d##r3 + d##r4) * -0.2222222; \ | |||||
auto tmpd0 = MULSF(d##0, 0.7111111); \ | |||||
auto tmpd1 = MULSF(d##1, 0.3555556); \ | |||||
auto tmpd2 = MULSF(d##2, 0.1777778); \ | |||||
auto tmpd3 = MULSF(d##3, 0.0888889); \ | |||||
auto tmpd4 = MULSF(d##4, 0.0444444); \ | |||||
auto tmpdr0 = d##r0 * 0.7111111; \ | |||||
auto tmpdr1 = d##r1 * 0.3555556; \ | |||||
auto tmpdr2 = d##r2 * 0.1777778; \ | |||||
auto tmpdr3 = d##r3 * 0.0888889; \ | |||||
auto tmpdr4 = d##r4 * 0.0444444; \ | |||||
wd##3 = ADDF(ADDF(ADDF(ADDF(tmpd0, tmpd1), tmpd2), tmpd3), tmpd4); \ | |||||
wd##r3 = tmpdr0 + tmpdr1 + tmpdr2 + tmpdr3 + tmpdr4; \ | |||||
wd##4 = ADDF(SUBF(ADDF(SUBF(tmpd0, tmpd1), tmpd2), tmpd3), tmpd4); \ | |||||
wd##r4 = tmpdr0 - tmpdr1 + tmpdr2 - tmpdr3 + tmpdr4; \ | |||||
tmpd0 = MULSF(d##0, 0.0111111); \ | |||||
tmpd1 = MULSF(d##1, 0.0222222); \ | |||||
tmpd2 = MULSF(d##2, 0.0444444); \ | |||||
tmpd3 = MULSF(d##3, 0.0888889); \ | |||||
tmpd4 = MULSF(d##4, 0.1777778); \ | |||||
tmpdr0 = d##r0 * 0.0111111; \ | |||||
tmpdr1 = d##r1 * 0.0222222; \ | |||||
tmpdr2 = d##r2 * 0.0444444; \ | |||||
tmpdr3 = d##r3 * 0.0888889; \ | |||||
tmpdr4 = d##r4 * 0.1777778; \ | |||||
wd##5 = ADDF(ADDF(ADDF(ADDF(tmpd0, tmpd1), tmpd2), tmpd3), tmpd4); \ | |||||
wd##r5 = tmpdr0 + tmpdr1 + tmpdr2 + tmpdr3 + tmpdr4; \ | |||||
wd##6 = ADDF(SUBF(ADDF(SUBF(tmpd0, tmpd1), tmpd2), tmpd3), tmpd4); \ | |||||
wd##r6 = tmpdr0 - tmpdr1 + tmpdr2 - tmpdr3 + tmpdr4; \ | |||||
wd##7 = d##4; \ | |||||
wd##r7 = d##r4; \ | |||||
} while (0); | } while (0); | ||||
#define FILTER_TRANSFORM_FINAL(d, wd) \ | |||||
do { \ | |||||
wd##0 = d##0; \ | |||||
wd##1 = (d##0 + d##1 + d##2 + d##3 + d##4) * -0.2222222; \ | |||||
wd##2 = (d##0 - d##1 + d##2 - d##3 + d##4) * -0.2222222; \ | |||||
auto tmp0 = d##0 * 0.7111111 + d##2 * 0.1777778 + d##4 * 0.0444444; \ | |||||
auto tmp1 = d##1 * 0.3555556 + d##3 * 0.0888889; \ | |||||
wd##3 = tmp0 + tmp1; \ | |||||
wd##4 = tmp0 - tmp1; \ | |||||
tmp0 = d##0 * 0.0111111 + d##2 * 0.0444444 + d##4 * 0.1777778; \ | |||||
tmp1 = d##1 * 0.0222222 + d##3 * 0.0888889; \ | |||||
wd##5 = tmp0 + tmp1; \ | |||||
wd##6 = tmp0 - tmp1; \ | |||||
wd##7 = d##4; \ | |||||
/* | |||||
*wd##0 = d##0; | |||||
*wd##1 = (d##0 + d##1 + d##2 + d##3 + d##4) * -0.2222222; | |||||
*wd##2 = (d##0 - d##1 + d##2 - d##3 + d##4) * -0.2222222; | |||||
*auto tmp0 = d##0 * 0.7111111 + d##2 * 0.1777778 + d##4 * 0.0444444; | |||||
*auto tmp1 = d##1 * 0.3555556 + d##3 * 0.0888889; | |||||
*wd##3 = tmp0 + tmp1; | |||||
*wd##4 = tmp0 - tmp1; | |||||
*tmp0 = d##0 * 0.0111111 + d##2 * 0.0444444 + d##4 * 0.1777778; | |||||
*tmp1 = d##1 * 0.0222222 + d##3 * 0.0888889; | |||||
*wd##5 = tmp0 + tmp1; | |||||
*wd##6 = tmp0 - tmp1; | |||||
*wd##7 = d##4; | |||||
*/ | |||||
#define FILTER_TRANSFORM_FINAL(d, wd) \ | |||||
do { \ | |||||
wd##0 = d##0; \ | |||||
wd##1 = \ | |||||
MULSFV2(ADDFV2(ADDFV2(ADDFV2(ADDFV2(d##0, d##1), d##2), d##3), d##4), \ | |||||
-0.2222222); \ | |||||
wd##2 = \ | |||||
MULSFV2(ADDFV2(SUBFV2(ADDFV2(SUBFV2(d##0, d##1), d##2), d##3), d##4), \ | |||||
-0.2222222); \ | |||||
auto tmp0 = \ | |||||
ADDFV2(ADDFV2(MULSFV2(d##0, 0.7111111), MULSFV2(d##2, 0.1777778)), \ | |||||
MULSFV2(d##4, 0.0444444)); \ | |||||
auto tmp1 = ADDFV2(MULSFV2(d##1, 0.3555556), MULSFV2(d##3, 0.0888889)); \ | |||||
wd##3 = ADDFV2(tmp0, tmp1); \ | |||||
wd##4 = SUBFV2(tmp0, tmp1); \ | |||||
tmp0 = \ | |||||
ADDFV2(ADDFV2(MULSFV2(d##0, 0.0111111), MULSFV2(d##2, 0.0444444)), \ | |||||
MULSFV2(d##4, 0.1777778)); \ | |||||
tmp1 = ADDFV2(MULSFV2(d##1, 0.0222222), MULSFV2(d##3, 0.0888889)); \ | |||||
wd##5 = ADDFV2(tmp0, tmp1); \ | |||||
wd##6 = SUBFV2(tmp0, tmp1); \ | |||||
wd##7 = d##4; \ | |||||
} while (0); | } while (0); | ||||
static void transform( | static void transform( | ||||
const float* filter, float* filter_transform_buf, float* transform_mid_buf, | const float* filter, float* filter_transform_buf, float* transform_mid_buf, | ||||
@@ -89,7 +151,7 @@ struct FilterTransform4X5 { | |||||
rep(ic, IC) { | rep(ic, IC) { | ||||
const float* fptr = filter + (oc * IC + ic) * 5 * 5; | const float* fptr = filter + (oc * IC + ic) * 5 * 5; | ||||
#define cb(i) Vector<float, 4> g##i = Vector<float, 4>::load(fptr + 5 * i); | |||||
#define cb(i) GI_FLOAT32_t g##i = GiLoadFloat32(fptr + 5 * i); | |||||
UNROLL_CALL_NOWRAPPER(5, cb); | UNROLL_CALL_NOWRAPPER(5, cb); | ||||
#undef cb | #undef cb | ||||
@@ -97,7 +159,7 @@ struct FilterTransform4X5 { | |||||
UNROLL_CALL_NOWRAPPER(5, cb); | UNROLL_CALL_NOWRAPPER(5, cb); | ||||
#undef cb | #undef cb | ||||
#define cb(i) Vector<float, 4> Gg##i; | |||||
#define cb(i) GI_FLOAT32_t Gg##i; | |||||
UNROLL_CALL_NOWRAPPER(8, cb); | UNROLL_CALL_NOWRAPPER(8, cb); | ||||
#undef cb | #undef cb | ||||
@@ -105,11 +167,11 @@ struct FilterTransform4X5 { | |||||
UNROLL_CALL_NOWRAPPER(8, cb); | UNROLL_CALL_NOWRAPPER(8, cb); | ||||
#undef cb | #undef cb | ||||
#define cb(i) Vector<float, 8> Ggt##i; | |||||
#define cb(i) GI_FLOAT32_V2_t Ggt##i; | |||||
UNROLL_CALL_NOWRAPPER(4, cb); | UNROLL_CALL_NOWRAPPER(4, cb); | ||||
#undef cb | #undef cb | ||||
#define cb(i) Vector<float, 8> result##i; | |||||
#define cb(i) GI_FLOAT32_V2_t result##i; | |||||
UNROLL_CALL_NOWRAPPER(8, cb); | UNROLL_CALL_NOWRAPPER(8, cb); | ||||
#undef cb | #undef cb | ||||
@@ -128,11 +190,11 @@ struct FilterTransform4X5 { | |||||
GI_FLOAT32_t vgr1 = GiLoadFloat32(tmp); | GI_FLOAT32_t vgr1 = GiLoadFloat32(tmp); | ||||
GiSetSubVectorFloat32V2(vgr, 0, vgr0); //{Ggr0, Ggr1, Ggr2, Ggr3}; | GiSetSubVectorFloat32V2(vgr, 0, vgr0); //{Ggr0, Ggr1, Ggr2, Ggr3}; | ||||
GiSetSubVectorFloat32V2(vgr, 1, vgr1); //{Ggr4, Ggr5, Ggr6, Ggr7}; | GiSetSubVectorFloat32V2(vgr, 1, vgr1); //{Ggr4, Ggr5, Ggr6, Ggr7}; | ||||
Vector<float, 8> Ggt4(vgr); | |||||
GI_FLOAT32_V2_t Ggt4(vgr); | |||||
TRANSPOSE_8x4(Gg, Ggt); | TRANSPOSE_8x4(Gg, Ggt); | ||||
FILTER_TRANSFORM_FINAL(Ggt, result); | FILTER_TRANSFORM_FINAL(Ggt, result); | ||||
#define cb(i) result##i.save(transform_mid_buf + i * alpha); | |||||
#define cb(i) GiStoreFloat32V2(transform_mid_buf + i * alpha, result##i); | |||||
UNROLL_CALL_NOWRAPPER(8, cb); | UNROLL_CALL_NOWRAPPER(8, cb); | ||||
#undef cb | #undef cb | ||||
rep(i, alpha) rep(j, alpha) { | rep(i, alpha) rep(j, alpha) { | ||||
@@ -145,31 +207,49 @@ struct FilterTransform4X5 { | |||||
#undef FILTER_TRANSFORM | #undef FILTER_TRANSFORM | ||||
#undef FILTER_TRANSFORM_FINAL | #undef FILTER_TRANSFORM_FINAL | ||||
/* | |||||
* wd##0 = (d##0 - d##6) + (d##4 - d##2) * 5.25f; | |||||
* auto tmp0 = d##2 - d##4 * 4.25f + d##6; | |||||
* auto tmp1 = d##1 - d##3 * 4.25f + d##5; | |||||
* wd##1 = tmp0 + tmp1; | |||||
* wd##2 = tmp0 - tmp1; | |||||
* tmp0 = d##2 * 4.0f - d##4 * 5.0f + d##6; | |||||
* tmp1 = d##1 * 2.0f - d##3 * 2.5f + d##5 * 0.5f; | |||||
* wd##3 = tmp0 + tmp1; | |||||
* wd##4 = tmp0 - tmp1; | |||||
* tmp0 = d##2 * 0.25f - d##4 * 1.25f + d##6; | |||||
* tmp1 = d##1 * 0.5f - d##3 * 2.5f + d##5 * 2.0f; | |||||
* wd##5 = tmp0 + tmp1; | |||||
* wd##6 = tmp0 - tmp1; | |||||
* wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f; | |||||
*/ | |||||
struct InputTransform4X5 { | struct InputTransform4X5 { | ||||
#define INPUT_TRANSFORM(d, wd) \ | |||||
do { \ | |||||
wd##0 = (d##0 - d##6) + (d##4 - d##2) * 5.25f; \ | |||||
auto tmp0 = d##2 - d##4 * 4.25f + d##6; \ | |||||
auto tmp1 = d##1 - d##3 * 4.25f + d##5; \ | |||||
wd##1 = tmp0 + tmp1; \ | |||||
wd##2 = tmp0 - tmp1; \ | |||||
tmp0 = d##2 * 4.0f - d##4 * 5.0f + d##6; \ | |||||
tmp1 = d##1 * 2.0f - d##3 * 2.5f + d##5 * 0.5f; \ | |||||
wd##3 = tmp0 + tmp1; \ | |||||
wd##4 = tmp0 - tmp1; \ | |||||
tmp0 = d##2 * 0.25f - d##4 * 1.25f + d##6; \ | |||||
tmp1 = d##1 * 0.5f - d##3 * 2.5f + d##5 * 2.0f; \ | |||||
wd##5 = tmp0 + tmp1; \ | |||||
wd##6 = tmp0 - tmp1; \ | |||||
wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f; \ | |||||
#define INPUT_TRANSFORM(d, wd) \ | |||||
do { \ | |||||
wd##0 = ADDFV2(SUBFV2(d##0, d##6), MULSFV2(SUBFV2(d##4, d##2), 5.25f)); \ | |||||
auto tmp0 = ADDFV2(SUBFV2(d##2, MULSFV2(d##4, 4.25f)), d##6); \ | |||||
auto tmp1 = ADDFV2(SUBFV2(d##1, MULSFV2(d##3, 4.25f)), d##5); \ | |||||
wd##1 = ADDFV2(tmp0, tmp1); \ | |||||
wd##2 = SUBFV2(tmp0, tmp1); \ | |||||
tmp0 = ADDFV2(SUBFV2(MULSFV2(d##2, 4.0f), MULSFV2(d##4, 5.0f)), d##6); \ | |||||
tmp1 = \ | |||||
ADDFV2(SUBFV2(MULSFV2(d##1, 2.0f), MULSFV2(d##3, 2.5f)), \ | |||||
MULSFV2(d##5, 0.5f)); \ | |||||
wd##3 = ADDFV2(tmp0, tmp1); \ | |||||
wd##4 = SUBFV2(tmp0, tmp1); \ | |||||
tmp0 = ADDFV2(SUBFV2(MULSFV2(d##2, 0.25f), MULSFV2(d##4, 1.25f)), d##6); \ | |||||
tmp1 = \ | |||||
ADDFV2(SUBFV2(MULSFV2(d##1, 0.5f), MULSFV2(d##3, 2.5f)), \ | |||||
MULSFV2(d##5, 2.0f)); \ | |||||
wd##5 = ADDFV2(tmp0, tmp1); \ | |||||
wd##6 = SUBFV2(tmp0, tmp1); \ | |||||
wd##7 = ADDFV2(SUBFV2(d##7, d##1), MULSFV2(SUBFV2(d##3, d##5), 5.25f)); \ | |||||
} while (0) | } while (0) | ||||
#define GET_VECTOR_HIGH_ELEM(s, i, idx) \ | |||||
GiExtractLane##idx##Float32(GiGetSubVectorFloat32V2( \ | |||||
GiFixLenType2GiFloat32V2Type(CONCAT(s, i).value), 1)) | |||||
#define GET_VECTOR_LOW_ELEM(s, i, idx) \ | |||||
GiExtractLane##idx##Float32(GiGetSubVectorFloat32V2( \ | |||||
GiFixLenType2GiFloat32V2Type(CONCAT(s, i).value), 0)) | |||||
#define GET_VECTOR_HIGH_ELEM(s, i, idx) \ | |||||
GiExtractLane##idx##Float32(GiGetSubVectorFloat32V2(CONCAT(s, i), 1)) | |||||
#define GET_VECTOR_LOW_ELEM(s, i, idx) \ | |||||
GiExtractLane##idx##Float32(GiGetSubVectorFloat32V2(CONCAT(s, i), 0)) | |||||
template <bool inner> | template <bool inner> | ||||
static void transform( | static void transform( | ||||
@@ -191,13 +271,13 @@ struct InputTransform4X5 { | |||||
memset(transform_mid_buf, 0, sizeof(float) * alpha * alpha); | memset(transform_mid_buf, 0, sizeof(float) * alpha * alpha); | ||||
} | } | ||||
#define cb(i) Vector<float, 8> d##i; | |||||
#define cb(i) GI_FLOAT32_V2_t d##i; | |||||
UNROLL_CALL_NOWRAPPER(8, cb); | UNROLL_CALL_NOWRAPPER(8, cb); | ||||
#undef cb | #undef cb | ||||
if (inner) { | if (inner) { | ||||
const float* input_ptr = input + ic * IH * IW + ih_start * IW + iw_start; | const float* input_ptr = input + ic * IH * IW + ih_start * IW + iw_start; | ||||
#define cb(i) d##i = Vector<float, 8>::load(input_ptr + IW * i); | |||||
#define cb(i) d##i = GiLoadFloat32V2(input_ptr + IW * i); | |||||
UNROLL_CALL_NOWRAPPER(8, cb); | UNROLL_CALL_NOWRAPPER(8, cb); | ||||
#undef cb | #undef cb | ||||
} else { | } else { | ||||
@@ -212,21 +292,25 @@ struct InputTransform4X5 { | |||||
input[ic * IH * IW + ih * IW + iw]; | input[ic * IH * IW + ih * IW + iw]; | ||||
} | } | ||||
} | } | ||||
#define cb(i) d##i = Vector<float, 8>::load(transform_mid_buf + alpha * i); | |||||
#define cb(i) d##i = GiLoadFloat32V2(transform_mid_buf + alpha * i); | |||||
UNROLL_CALL_NOWRAPPER(8, cb); | UNROLL_CALL_NOWRAPPER(8, cb); | ||||
#undef cb | #undef cb | ||||
} | } | ||||
#define cb(i) Vector<float, 8> wd##i, ret##i; | |||||
#define cb(i) GI_FLOAT32_V2_t wd##i; | |||||
UNROLL_CALL_NOWRAPPER(8, cb); | UNROLL_CALL_NOWRAPPER(8, cb); | ||||
#undef cb | #undef cb | ||||
INPUT_TRANSFORM(d, wd); | INPUT_TRANSFORM(d, wd); | ||||
#if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
#define cb(i) GI_FLOAT32_V2_t ret##i; | |||||
UNROLL_CALL_NOWRAPPER(8, cb); | |||||
#undef cb | |||||
TRANSPOSE_8x8(wd, d); | TRANSPOSE_8x8(wd, d); | ||||
INPUT_TRANSFORM(d, ret); | INPUT_TRANSFORM(d, ret); | ||||
#define cb(i) ret##i.save(transform_mid_buf + i * alpha); | |||||
#define cb(i) GiStoreFloat32V2(transform_mid_buf + i * alpha, ret##i); | |||||
UNROLL_CALL_NOWRAPPER(8, cb); | UNROLL_CALL_NOWRAPPER(8, cb); | ||||
#undef cb | #undef cb | ||||
rep(i, alpha) rep(j, alpha) { | rep(i, alpha) rep(j, alpha) { | ||||
@@ -283,12 +367,32 @@ struct InputTransform4X5 { | |||||
}; | }; | ||||
#undef INPUT_TRANSFORM | #undef INPUT_TRANSFORM | ||||
#define OUTPUT_TRANSFORM(m, s) \ | |||||
do { \ | |||||
s0 = m0 + m1 + m2 + m3 + m4 + m5 + m6; \ | |||||
s1 = m1 - m2 + m3 * 0.5 - m4 * 0.5 + m5 * 2.0 - m6 * 2.0; \ | |||||
s2 = m1 + m2 + m3 * 0.25 + m4 * 0.25 + m5 * 4.0 + m6 * 4.0; \ | |||||
s3 = m1 - m2 + m3 * 0.125 - m4 * 0.125 + m5 * 8.0 - m6 * 8.0 + m7; \ | |||||
/* | |||||
* s0 = m0 + m1 + m2 + m3 + m4 + m5 + m6; | |||||
* s1 = m1 - m2 + m3 * 0.5 - m4 * 0.5 + m5 * 2.0 - m6 * 2.0; | |||||
* s2 = m1 + m2 + m3 * 0.25 + m4 * 0.25 + m5 * 4.0 + m6 * 4.0; | |||||
* s3 = m1 - m2 + m3 * 0.125 - m4 * 0.125 + m5 * 8.0 - m6 * 8.0 + m7; | |||||
*/ | |||||
#define OUTPUT_TRANSFORM(m, s) \ | |||||
do { \ | |||||
s0 = ADDFV2( \ | |||||
ADDFV2(ADDFV2(ADDFV2(ADDFV2(ADDFV2(m0, m1), m2), m3), m4), m5), m6); \ | |||||
s1 = \ | |||||
SUBFV2(ADDFV2(SUBFV2(ADDFV2(SUBFV2(m1, m2), MULSFV2(m3, 0.5)), \ | |||||
MULSFV2(m4, 0.5)), \ | |||||
MULSFV2(m5, 2.0)), \ | |||||
MULSFV2(m6, 2.0)); \ | |||||
s2 = \ | |||||
ADDFV2(ADDFV2(ADDFV2(ADDFV2(ADDFV2(m1, m2), MULSFV2(m3, 0.25)), \ | |||||
MULSFV2(m4, 0.25)), \ | |||||
MULSFV2(m5, 4.0)), \ | |||||
MULSFV2(m6, 4.0)); \ | |||||
s3 = ADDFV2( \ | |||||
SUBFV2(ADDFV2(SUBFV2(ADDFV2(SUBFV2(m1, m2), MULSFV2(m3, 0.125)), \ | |||||
MULSFV2(m4, 0.125)), \ | |||||
MULSFV2(m5, 8.0)), \ | |||||
MULSFV2(m6, 8.0)), \ | |||||
m7); \ | |||||
} while (0) | } while (0) | ||||
template <BiasMode bmode, typename Op> | template <BiasMode bmode, typename Op> | ||||
struct OutputTransform4X5 { | struct OutputTransform4X5 { | ||||
@@ -316,14 +420,15 @@ struct OutputTransform4X5 { | |||||
UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); | UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); | ||||
#undef cb | #undef cb | ||||
#define cb(i) auto m##i = Vector<float, 8>::load(transform_mid_buf + alpha * i); | |||||
#define cb(i) auto m##i = GiLoadFloat32V2(transform_mid_buf + alpha * i); | |||||
UNROLL_CALL_NOWRAPPER(8, cb); | UNROLL_CALL_NOWRAPPER(8, cb); | ||||
#undef cb | #undef cb | ||||
#define cb(i) Vector<float, 8> s##i; | |||||
#define cb(i) GI_FLOAT32_V2_t s##i; | |||||
UNROLL_CALL_NOWRAPPER(4, cb); | UNROLL_CALL_NOWRAPPER(4, cb); | ||||
#undef cb | #undef cb | ||||
OUTPUT_TRANSFORM(m, s); | OUTPUT_TRANSFORM(m, s); | ||||
#define cb(i) \ | #define cb(i) \ | ||||
do { \ | do { \ | ||||
auto add12 = GET_VECTOR_LOW_ELEM(s, i, 1) + GET_VECTOR_LOW_ELEM(s, i, 2); \ | auto add12 = GET_VECTOR_LOW_ELEM(s, i, 1) + GET_VECTOR_LOW_ELEM(s, i, 2); \ | ||||
@@ -355,10 +460,10 @@ struct OutputTransform4X5 { | |||||
GI_FLOAT32_t item0 = GiLoadFloat32(mid_buf1); | GI_FLOAT32_t item0 = GiLoadFloat32(mid_buf1); | ||||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | ||||
item0 = GiAddFloat32(item0, bias0); | |||||
item0 = ADDF(item0, bias0); | |||||
} else if (bmode == BiasMode::BIAS) { | } else if (bmode == BiasMode::BIAS) { | ||||
bias0 = GiLoadFloat32(bias + oc * OH * OW + oh * OW + ow_start); | bias0 = GiLoadFloat32(bias + oc * OH * OW + oh * OW + ow_start); | ||||
item0 = GiAddFloat32(item0, bias0); | |||||
item0 = ADDF(item0, bias0); | |||||
} | } | ||||
item0 = op(item0); | item0 = op(item0); | ||||
GiStoreFloat32(output + oc * OH * OW + oh * OW + ow_start, item0); | GiStoreFloat32(output + oc * OH * OW + oh * OW + ow_start, item0); | ||||
@@ -1,7 +1,6 @@ | |||||
#include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "src/fallback/conv_bias/gi/fp32/strategy.h" | #include "src/fallback/conv_bias/gi/fp32/strategy.h" | ||||
#include "src/fallback/conv_bias/gi/utils.h" | |||||
#include "src/fallback/conv_bias/winograd/winograd.h" | #include "src/fallback/conv_bias/winograd/winograd.h" | ||||
#include "src/fallback/conv_bias/gi/fp32/helper.h" | #include "src/fallback/conv_bias/gi/fp32/helper.h" | ||||
@@ -15,23 +14,39 @@ using namespace megdnn; | |||||
using namespace fallback; | using namespace fallback; | ||||
namespace { | namespace { | ||||
/* | |||||
* wd##0 = d##0; | |||||
* auto tmp0 = d##0 * 0.7111111f + d##2 * 0.1777778f; | |||||
* auto tmp1 = d##1 * 0.3555556f + d##3 * 0.0888889f; | |||||
* wd##1 = tmp0 + tmp1; | |||||
* wd##2 = tmp0 - tmp1; | |||||
* tmp0 = (d##0 + d##2) * -0.2222222f; | |||||
* tmp1 = (d##1 + d##3) * -0.2222222f; | |||||
* wd##3 = tmp0 + tmp1; | |||||
* wd##4 = tmp0 - tmp1; | |||||
* tmp0 = d##0 * 0.0111111f + d##2 * 0.0444444f; | |||||
* tmp1 = d##1 * 0.0222222f + d##3 * 0.0888889f; | |||||
* wd##5 = tmp0 + tmp1; | |||||
* wd##6 = tmp0 - tmp1; | |||||
* wd##7 = d##3; | |||||
*/ | |||||
struct FilterTransform5X4 { | struct FilterTransform5X4 { | ||||
#define FILTER_TRANSFORM(d, wd) \ | |||||
do { \ | |||||
wd##0 = d##0; \ | |||||
auto tmp0 = d##0 * 0.7111111f + d##2 * 0.1777778f; \ | |||||
auto tmp1 = d##1 * 0.3555556f + d##3 * 0.0888889f; \ | |||||
wd##1 = tmp0 + tmp1; \ | |||||
wd##2 = tmp0 - tmp1; \ | |||||
tmp0 = (d##0 + d##2) * -0.2222222f; \ | |||||
tmp1 = (d##1 + d##3) * -0.2222222f; \ | |||||
wd##3 = tmp0 + tmp1; \ | |||||
wd##4 = tmp0 - tmp1; \ | |||||
tmp0 = d##0 * 0.0111111f + d##2 * 0.0444444f; \ | |||||
tmp1 = d##1 * 0.0222222f + d##3 * 0.0888889f; \ | |||||
wd##5 = tmp0 + tmp1; \ | |||||
wd##6 = tmp0 - tmp1; \ | |||||
wd##7 = d##3; \ | |||||
#define FILTER_TRANSFORM(d, wd, ADDC, SUBC, MULC) \ | |||||
do { \ | |||||
wd##0 = d##0; \ | |||||
auto tmp0 = ADDC(MULC(d##0, 0.7111111f), MULC(d##2, 0.1777778f)); \ | |||||
auto tmp1 = ADDC(MULC(d##1, 0.3555556f), MULC(d##3, 0.0888889f)); \ | |||||
wd##1 = ADDC(tmp0, tmp1); \ | |||||
wd##2 = SUBC(tmp0, tmp1); \ | |||||
tmp0 = MULC(ADDC(d##0, d##2), -0.2222222f); \ | |||||
tmp1 = MULC(ADDC(d##1, d##3), -0.2222222f); \ | |||||
wd##3 = ADDC(tmp0, tmp1); \ | |||||
wd##4 = SUBC(tmp0, tmp1); \ | |||||
tmp0 = ADDC(MULC(d##0, 0.0111111f), MULC(d##2, 0.0444444f)); \ | |||||
tmp1 = ADDC(MULC(d##1, 0.0222222f), MULC(d##3, 0.0888889f)); \ | |||||
wd##5 = ADDC(tmp0, tmp1); \ | |||||
wd##6 = SUBC(tmp0, tmp1); \ | |||||
wd##7 = d##3; \ | |||||
} while (0) | } while (0) | ||||
static void transform( | static void transform( | ||||
@@ -53,28 +68,27 @@ struct FilterTransform5X4 { | |||||
rep(ic, IC) { | rep(ic, IC) { | ||||
const float* fptr = filter + (oc * IC + ic) * 4 * 4; | const float* fptr = filter + (oc * IC + ic) * 4 * 4; | ||||
#define cb(i) Vector<float, 4> g##i = Vector<float, 4>::load(fptr + 4 * i); | |||||
#define cb(i) GI_FLOAT32_t g##i = GiLoadFloat32(fptr + 4 * i); | |||||
UNROLL_CALL_NOWRAPPER(4, cb); | UNROLL_CALL_NOWRAPPER(4, cb); | ||||
#undef cb | #undef cb | ||||
#define cb(i) Vector<float, 4> wd##i; | |||||
#define cb(i) GI_FLOAT32_t wd##i; | |||||
UNROLL_CALL_NOWRAPPER(8, cb); | UNROLL_CALL_NOWRAPPER(8, cb); | ||||
#undef cb | #undef cb | ||||
#define cb(i) Vector<float, 8> wdt##i; | |||||
FILTER_TRANSFORM(g, wd, ADDF, SUBF, MULSF); | |||||
#if MEGDNN_AARCH64 | |||||
#define cb(i) GI_FLOAT32_V2_t wdt##i; | |||||
UNROLL_CALL_NOWRAPPER(4, cb); | UNROLL_CALL_NOWRAPPER(4, cb); | ||||
#undef cb | #undef cb | ||||
#define cb(i) Vector<float, 8> ret##i; | |||||
#define cb(i) GI_FLOAT32_V2_t ret##i; | |||||
UNROLL_CALL_NOWRAPPER(8, cb); | UNROLL_CALL_NOWRAPPER(8, cb); | ||||
#undef cb | #undef cb | ||||
FILTER_TRANSFORM(g, wd); | |||||
#if MEGDNN_AARCH64 | |||||
TRANSPOSE_8x4(wd, wdt); | TRANSPOSE_8x4(wd, wdt); | ||||
FILTER_TRANSFORM(wdt, ret); | |||||
FILTER_TRANSFORM(wdt, ret, ADDFV2, SUBFV2, MULSFV2); | |||||
#define cb(i) ret##i.save(transform_mid_buf + i * alpha); | |||||
#define cb(i) GiStoreFloat32V2(transform_mid_buf + i * alpha, ret##i); | |||||
UNROLL_CALL_NOWRAPPER(8, cb); | UNROLL_CALL_NOWRAPPER(8, cb); | ||||
#undef cb | #undef cb | ||||
rep(i, alpha) rep(j, alpha) { | rep(i, alpha) rep(j, alpha) { | ||||
@@ -104,8 +118,7 @@ struct FilterTransform5X4 { | |||||
mid_buf1[7] = GET_VECTOR_ELEM(wd, i, 3); \ | mid_buf1[7] = GET_VECTOR_ELEM(wd, i, 3); \ | ||||
mid_buf1 += 8; \ | mid_buf1 += 8; \ | ||||
} while (0); | } while (0); | ||||
#define GET_VECTOR_ELEM(s, i, idx) \ | |||||
GiExtractLane##idx##Float32(GiFixLenType2GiFloat32Type(CONCAT(s, i).value)) | |||||
#define GET_VECTOR_ELEM(s, i, idx) GiExtractLane##idx##Float32(CONCAT(s, i)) | |||||
float* mid_buf1 = transform_mid_buf; | float* mid_buf1 = transform_mid_buf; | ||||
UNROLL_CALL_NOWRAPPER(8, cb); | UNROLL_CALL_NOWRAPPER(8, cb); | ||||
@@ -123,29 +136,49 @@ struct FilterTransform5X4 { | |||||
#undef FILTER_TRANSFORM | #undef FILTER_TRANSFORM | ||||
#undef GET_VECTOR_ELEM | #undef GET_VECTOR_ELEM | ||||
/* | |||||
* wd##0 = (d##0 - d##6) + (d##4 - d##2) * 5.25f; | |||||
* auto tmp0 = d##2 * 4.0f - d##4 * 5.0f + d##6; | |||||
* auto tmp1 = d##1 * 2.0f - d##3 * 2.5f + d##5 * 0.5f; | |||||
* wd##1 = tmp0 + tmp1; | |||||
* wd##2 = tmp0 - tmp1; | |||||
* tmp0 = d##2 - d##4 * 4.25f + d##6; | |||||
* tmp1 = d##1 - d##3 * 4.25f + d##5; | |||||
* wd##3 = tmp0 + tmp1; | |||||
* wd##4 = tmp0 - tmp1; | |||||
* tmp0 = d##2 * 0.25f - d##4 * 1.25f + d##6; | |||||
* tmp1 = d##1 * 0.5f - d##3 * 2.5f + d##5 * 2.0f; | |||||
* wd##5 = tmp0 + tmp1; | |||||
* wd##6 = tmp0 - tmp1; | |||||
* wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f; | |||||
*/ | |||||
struct InputTransform5X4 { | struct InputTransform5X4 { | ||||
#define INPUT_TRANSFORM(d, wd) \ | |||||
do { \ | |||||
wd##0 = (d##0 - d##6) + (d##4 - d##2) * 5.25f; \ | |||||
auto tmp0 = d##2 * 4.0f - d##4 * 5.0f + d##6; \ | |||||
auto tmp1 = d##1 * 2.0f - d##3 * 2.5f + d##5 * 0.5f; \ | |||||
wd##1 = tmp0 + tmp1; \ | |||||
wd##2 = tmp0 - tmp1; \ | |||||
tmp0 = d##2 - d##4 * 4.25f + d##6; \ | |||||
tmp1 = d##1 - d##3 * 4.25f + d##5; \ | |||||
wd##3 = tmp0 + tmp1; \ | |||||
wd##4 = tmp0 - tmp1; \ | |||||
tmp0 = d##2 * 0.25f - d##4 * 1.25f + d##6; \ | |||||
tmp1 = d##1 * 0.5f - d##3 * 2.5f + d##5 * 2.0f; \ | |||||
wd##5 = tmp0 + tmp1; \ | |||||
wd##6 = tmp0 - tmp1; \ | |||||
wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f; \ | |||||
#define INPUT_TRANSFORM(d, wd) \ | |||||
do { \ | |||||
wd##0 = ADDFV2(SUBFV2(d##0, d##6), MULSFV2(SUBFV2(d##4, d##2), 5.25f)); \ | |||||
auto tmp0 = ADDFV2(SUBFV2(MULSFV2(d##2, 4.0f), MULSFV2(d##4, 5.0f)), d##6); \ | |||||
auto tmp1 = \ | |||||
ADDFV2(SUBFV2(MULSFV2(d##1, 2.0f), MULSFV2(d##3, 2.5f)), \ | |||||
MULSFV2(d##5, 0.5f)); \ | |||||
wd##1 = ADDFV2(tmp0, tmp1); \ | |||||
wd##2 = SUBFV2(tmp0, tmp1); \ | |||||
tmp0 = ADDFV2(SUBFV2(d##2, MULSFV2(d##4, 4.25f)), d##6); \ | |||||
tmp1 = ADDFV2(SUBFV2(d##1, MULSFV2(d##3, 4.25f)), d##5); \ | |||||
wd##3 = ADDFV2(tmp0, tmp1); \ | |||||
wd##4 = SUBFV2(tmp0, tmp1); \ | |||||
tmp0 = ADDFV2(SUBFV2(MULSFV2(d##2, 0.25f), MULSFV2(d##4, 1.25f)), d##6); \ | |||||
tmp1 = \ | |||||
ADDFV2(SUBFV2(MULSFV2(d##1, 0.5f), MULSFV2(d##3, 2.5f)), \ | |||||
MULSFV2(d##5, 2.0f)); \ | |||||
wd##5 = ADDFV2(tmp0, tmp1); \ | |||||
wd##6 = SUBFV2(tmp0, tmp1); \ | |||||
wd##7 = ADDFV2(SUBFV2(d##7, d##1), MULSFV2(SUBFV2(d##3, d##5), 5.25f)); \ | |||||
} while (0) | } while (0) | ||||
#define GET_VECTOR_HIGH_ELEM(s, i, idx) \ | #define GET_VECTOR_HIGH_ELEM(s, i, idx) \ | ||||
GiExtractLane##idx##Float32(GiFixLenType2GiFloat32Type(CONCAT(s, i).value.val[1])) | |||||
GiExtractLane##idx##Float32(GiGetSubVectorFloat32V2(CONCAT(s, i), 1)) | |||||
#define GET_VECTOR_LOW_ELEM(s, i, idx) \ | #define GET_VECTOR_LOW_ELEM(s, i, idx) \ | ||||
GiExtractLane##idx##Float32(GiFixLenType2GiFloat32Type(CONCAT(s, i).value.val[0])) | |||||
GiExtractLane##idx##Float32(GiGetSubVectorFloat32V2(CONCAT(s, i), 0)) | |||||
template <bool inner> | template <bool inner> | ||||
static void transform( | static void transform( | ||||
@@ -168,13 +201,13 @@ struct InputTransform5X4 { | |||||
memset(transform_mid_buf, 0, sizeof(float) * alpha * alpha); | memset(transform_mid_buf, 0, sizeof(float) * alpha * alpha); | ||||
} | } | ||||
#define cb(i) Vector<float, 8> d##i; | |||||
#define cb(i) GI_FLOAT32_V2_t d##i; | |||||
UNROLL_CALL_NOWRAPPER(8, cb); | UNROLL_CALL_NOWRAPPER(8, cb); | ||||
#undef cb | #undef cb | ||||
if (inner) { | if (inner) { | ||||
const float* input_ptr = input + ic * IH * IW + ih_start * IW + iw_start; | const float* input_ptr = input + ic * IH * IW + ih_start * IW + iw_start; | ||||
#define cb(i) d##i = Vector<float, 8>::load(input_ptr + IW * i); | |||||
#define cb(i) d##i = GiLoadFloat32V2(input_ptr + IW * i); | |||||
UNROLL_CALL_NOWRAPPER(8, cb); | UNROLL_CALL_NOWRAPPER(8, cb); | ||||
#undef cb | #undef cb | ||||
} else { | } else { | ||||
@@ -189,21 +222,25 @@ struct InputTransform5X4 { | |||||
input[ic * IH * IW + ih * IW + iw]; | input[ic * IH * IW + ih * IW + iw]; | ||||
} | } | ||||
} | } | ||||
#define cb(i) d##i = Vector<float, 8>::load(transform_mid_buf + alpha * i); | |||||
#define cb(i) d##i = GiLoadFloat32V2(transform_mid_buf + alpha * i); | |||||
UNROLL_CALL_NOWRAPPER(8, cb); | UNROLL_CALL_NOWRAPPER(8, cb); | ||||
#undef cb | #undef cb | ||||
} | } | ||||
#define cb(i) Vector<float, 8> wd##i, ret##i; | |||||
#define cb(i) GI_FLOAT32_V2_t wd##i; | |||||
UNROLL_CALL_NOWRAPPER(8, cb); | UNROLL_CALL_NOWRAPPER(8, cb); | ||||
#undef cb | #undef cb | ||||
INPUT_TRANSFORM(d, wd); | INPUT_TRANSFORM(d, wd); | ||||
#if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
#define cb(i) GI_FLOAT32_V2_t ret##i; | |||||
UNROLL_CALL_NOWRAPPER(8, cb); | |||||
#undef cb | |||||
TRANSPOSE_8x8(wd, d); | TRANSPOSE_8x8(wd, d); | ||||
INPUT_TRANSFORM(d, ret); | INPUT_TRANSFORM(d, ret); | ||||
#define cb(i) ret##i.save(transform_mid_buf + i * alpha); | |||||
#define cb(i) GiStoreFloat32V2(transform_mid_buf + i * alpha, ret##i); | |||||
UNROLL_CALL_NOWRAPPER(8, cb); | UNROLL_CALL_NOWRAPPER(8, cb); | ||||
#undef cb | #undef cb | ||||
rep(i, alpha) rep(j, alpha) { | rep(i, alpha) rep(j, alpha) { | ||||
@@ -260,24 +297,48 @@ struct InputTransform5X4 { | |||||
}; | }; | ||||
#undef INPUT_TRANSFORM | #undef INPUT_TRANSFORM | ||||
/*
 * OUTPUT_TRANSFORM(m, s): consumes the eight vectors m0..m7 (prefix `m`) and
 * fills the five vectors s0..s4 (prefix `s`). Written in the order in which
 * the terms are accumulated:
 *
 *   s0 = m0 + (m1 + m2) + (m3 + m4) + (m5 + m6)
 *   s1 = (m3 - m4) + 0.5    * (m1 - m2) +  2 * (m5 - m6)
 *   s2 = (m3 + m4) + 0.25   * (m1 + m2) +  4 * (m5 + m6)
 *   s3 = (m3 - m4) + 0.125  * (m1 - m2) +  8 * (m5 - m6)
 *   s4 = m7 + 0.0625 * (m1 + m2) + (m3 + m4) + 16 * (m5 + m6)
 *
 * NOTE(review): the formulas assume GiMultiplyAddScalarFloat32V2(a, b, c)
 * evaluates a + b * c, as the Vector::mla calls it replaced suggest --
 * confirm against the GI helper header.
 */
#define OUTPUT_TRANSFORM(m, s)                                                    \
    do {                                                                          \
        auto add12 = ADDFV2(m##1, m##2);                                          \
        auto sub12 = SUBFV2(m##1, m##2);                                          \
        auto add34 = ADDFV2(m##3, m##4);                                          \
        auto sub34 = SUBFV2(m##3, m##4);                                          \
        auto add56 = ADDFV2(m##5, m##6);                                          \
        auto sub56 = SUBFV2(m##5, m##6);                                          \
        CONCAT(s, 0) = ADDFV2(ADDFV2(ADDFV2(m##0, add12), add34), add56);         \
        CONCAT(s, 1) = GiMultiplyAddScalarFloat32V2(                              \
                GiMultiplyAddScalarFloat32V2(sub34, sub12, 0.5f), sub56, 2.0f);   \
        CONCAT(s, 2) = GiMultiplyAddScalarFloat32V2(                              \
                GiMultiplyAddScalarFloat32V2(add34, add12, 0.25f), add56, 4.0f);  \
        CONCAT(s, 3) = GiMultiplyAddScalarFloat32V2(                              \
                GiMultiplyAddScalarFloat32V2(sub34, sub12, 0.125f), sub56, 8.0f); \
        CONCAT(s, 4) = GiMultiplyAddScalarFloat32V2(                              \
                ADDFV2(GiMultiplyAddScalarFloat32V2(m##7, add12, 0.0625f),        \
                       add34),                                                    \
                add56, 16.0f);                                                    \
    } while (0)
#if defined(__GNUC__) && !defined(__llvm__) && !defined(_MSC_VER) | #if defined(__GNUC__) && !defined(__llvm__) && !defined(_MSC_VER) | ||||
@@ -314,11 +375,11 @@ struct OutputTransform5X4 { | |||||
UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); | UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); | ||||
#undef cb | #undef cb | ||||
#define cb(i) auto m##i = Vector<float, 8>::load(transform_mid_buf + alpha * i); | |||||
#define cb(i) auto m##i = GiLoadFloat32V2(transform_mid_buf + alpha * i); | |||||
UNROLL_CALL_NOWRAPPER(8, cb); | UNROLL_CALL_NOWRAPPER(8, cb); | ||||
#undef cb | #undef cb | ||||
#define cb(i) Vector<float, 8> s##i, ret##i; | |||||
UNROLL_CALL_NOWRAPPER(8, cb); | |||||
#define cb(i) GI_FLOAT32_V2_t s##i; | |||||
UNROLL_CALL_NOWRAPPER(5, cb); | |||||
#undef cb | #undef cb | ||||
OUTPUT_TRANSFORM(m, s); | OUTPUT_TRANSFORM(m, s); | ||||
@@ -3,7 +3,6 @@ | |||||
#include "src/fallback/conv_bias/gi/fp32/filter_transform.h" | #include "src/fallback/conv_bias/gi/fp32/filter_transform.h" | ||||
#include "src/fallback/conv_bias/gi/fp32/helper.h" | #include "src/fallback/conv_bias/gi/fp32/helper.h" | ||||
#include "src/fallback/conv_bias/gi/fp32/strategy.h" | #include "src/fallback/conv_bias/gi/fp32/strategy.h" | ||||
#include "src/fallback/conv_bias/gi/utils.h" | |||||
#include "src/fallback/conv_bias/winograd/winograd.h" | #include "src/fallback/conv_bias/winograd/winograd.h" | ||||
#include "src/fallback/elemwise_helper/op_unary.h" | #include "src/fallback/elemwise_helper/op_unary.h" | ||||
#include "src/naive/matrix_mul/matrix_mul_helper.h" | #include "src/naive/matrix_mul/matrix_mul_helper.h" | ||||
@@ -27,28 +26,31 @@ namespace { | |||||
* wd6 = (d6 - 5.0 * d4 + 4.0 * d2) - 2.0 * (d1 + 0.25 * d5 - 1.25 * d3) | * wd6 = (d6 - 5.0 * d4 + 4.0 * d2) - 2.0 * (d1 + 0.25 * d5 - 1.25 * d3) | ||||
* wd7 = (d7 - d1) + 5.25 * (d3 - d5) | * wd7 = (d7 - d1) + 5.25 * (d3 - d5) | ||||
*/ | */ | ||||
/*
 * INPUT_TRANSFORM(d, wd): reads the eight vectors d0..d7 (prefix `d`) and
 * writes the eight vectors wd0..wd7 (prefix `wd`).
 *
 * Every operand is formed with `##` from the macro parameters. Earlier
 * revisions spelled a few operands literally (d6, d4, d2, d1, d5, d3), which
 * only worked when the caller's input prefix happened to be `d`; with any
 * other prefix the macro either failed to compile or read unrelated
 * variables.
 *
 * tmp0 / tmp1 are scratch values scoped to the do-block and are reused for
 * each pair of outputs (wd1/wd2, wd3/wd4, wd5/wd6).
 */
#define INPUT_TRANSFORM(d, wd)                                                     \
    do {                                                                           \
        wd##0 = ADDFV2(SUBFV2(d##0, d##6), MULSFV2(SUBFV2(d##4, d##2), 5.25f));    \
        auto tmp0 = SUBFV2(ADDFV2(d##6, d##2), MULSFV2(d##4, 4.25f));              \
        auto tmp1 = SUBFV2(ADDFV2(d##1, d##5), MULSFV2(d##3, 4.25f));              \
        wd##1 = ADDFV2(tmp0, tmp1);                                                \
        wd##2 = SUBFV2(tmp0, tmp1);                                                \
        tmp0 = SUBFV2(ADDFV2(d##6, MULSFV2(d##2, 0.25f)), MULSFV2(d##4, 1.25f));   \
        tmp1 = MULSFV2(                                                            \
                (SUBFV2(ADDFV2(d##5, MULSFV2(d##1, 0.25f)), MULSFV2(d##3, 1.25f))), \
                2.0f);                                                             \
        wd##3 = ADDFV2(tmp0, tmp1);                                                \
        wd##4 = SUBFV2(tmp0, tmp1);                                                \
        tmp0 = ADDFV2(SUBFV2(d##6, MULSFV2(d##4, 5.0f)), MULSFV2(d##2, 4.0f));     \
        tmp1 = MULSFV2(                                                            \
                (SUBFV2(ADDFV2(d##1, MULSFV2(d##5, 0.25f)), MULSFV2(d##3, 1.25f))), \
                2.0f);                                                             \
        wd##5 = ADDFV2(tmp0, tmp1);                                                \
        wd##6 = SUBFV2(tmp0, tmp1);                                                \
        wd##7 = ADDFV2(SUBFV2(d##7, d##1), MULSFV2(SUBFV2(d##3, d##5), 5.25f));    \
    } while (0);
/* Lane `lane` of sub-vector 1 of the V2 vector named <vec><n>. */
#define GET_VECTOR_HIGH_ELEM(vec, n, lane) \
    GiExtractLane##lane##Float32(GiGetSubVectorFloat32V2(CONCAT(vec, n), 1))
/* Lane `lane` of sub-vector 0 of the V2 vector named <vec><n>. */
#define GET_VECTOR_LOW_ELEM(vec, n, lane) \
    GiExtractLane##lane##Float32(GiGetSubVectorFloat32V2(CONCAT(vec, n), 0))
struct InputTransform6X3 { | struct InputTransform6X3 { | ||||
template <bool inner> | template <bool inner> | ||||
static void transform( | static void transform( | ||||
@@ -60,13 +62,13 @@ struct InputTransform6X3 { | |||||
memset(transform_mid_buf, 0, sizeof(float) * alpha * alpha); | memset(transform_mid_buf, 0, sizeof(float) * alpha * alpha); | ||||
} | } | ||||
#define cb(i) Vector<float, 8> d##i; | |||||
#define cb(i) GI_FLOAT32_V2_t d##i; | |||||
UNROLL_CALL_NOWRAPPER(8, cb); | UNROLL_CALL_NOWRAPPER(8, cb); | ||||
#undef cb | #undef cb | ||||
if (inner) { | if (inner) { | ||||
const float* input_ptr = input + ic * IH * IW + ih_start * IW + iw_start; | const float* input_ptr = input + ic * IH * IW + ih_start * IW + iw_start; | ||||
#define cb(i) d##i = Vector<float, 8>::load(input_ptr + IW * i); | |||||
#define cb(i) d##i = GiLoadFloat32V2(input_ptr + IW * i); | |||||
UNROLL_CALL_NOWRAPPER(8, cb); | UNROLL_CALL_NOWRAPPER(8, cb); | ||||
#undef cb | #undef cb | ||||
} else { | } else { | ||||
@@ -81,22 +83,25 @@ struct InputTransform6X3 { | |||||
input[ic * IH * IW + ih * IW + iw]; | input[ic * IH * IW + ih * IW + iw]; | ||||
} | } | ||||
} | } | ||||
#define cb(i) d##i = Vector<float, 8>::load(transform_mid_buf + alpha * i); | |||||
#define cb(i) d##i = GiLoadFloat32V2(transform_mid_buf + alpha * i); | |||||
UNROLL_CALL_NOWRAPPER(8, cb); | UNROLL_CALL_NOWRAPPER(8, cb); | ||||
#undef cb | #undef cb | ||||
} | } | ||||
#define cb(i) Vector<float, 8> wd##i, ret##i; | |||||
#define cb(i) GI_FLOAT32_V2_t wd##i; | |||||
UNROLL_CALL_NOWRAPPER(8, cb); | UNROLL_CALL_NOWRAPPER(8, cb); | ||||
#undef cb | #undef cb | ||||
INPUT_TRANSFORM(d, wd); | INPUT_TRANSFORM(d, wd); | ||||
#if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
#define cb(i) GI_FLOAT32_V2_t ret##i; | |||||
UNROLL_CALL_NOWRAPPER(8, cb); | |||||
#undef cb | |||||
TRANSPOSE_8x8(wd, d); | TRANSPOSE_8x8(wd, d); | ||||
INPUT_TRANSFORM(d, ret); | INPUT_TRANSFORM(d, ret); | ||||
#define cb(i) ret##i.save(transform_mid_buf + i * alpha); | |||||
#define cb(i) GiStoreFloat32V2(transform_mid_buf + i * alpha, ret##i); | |||||
UNROLL_CALL_NOWRAPPER(8, cb); | UNROLL_CALL_NOWRAPPER(8, cb); | ||||
#undef cb | #undef cb | ||||
@@ -174,26 +179,34 @@ struct InputTransform6X3 { | |||||
* s4 = (m1 + m2) + 16 * (m3 + m4) + 2 * (m5 + m6) / 32 | * s4 = (m1 + m2) + 16 * (m3 + m4) + 2 * (m5 + m6) / 32 | ||||
* s5 = (m1 - m2) + 32 * (m3 - m4) + (m5 - m6) / 32 + m7 | * s5 = (m1 - m2) + 32 * (m3 - m4) + (m5 - m6) / 32 + m7 | ||||
*/ | */ | ||||
/*
 * OUTPUT_TRANSFORM(m, s): consumes the eight vectors m0..m7 (prefix `m`) and
 * fills the six vectors s0..s5 (prefix `s`). With
 *   p56 = (m5 + m6) * 0.03125   and   q56 = (m5 - m6) * 0.03125
 * the outputs are accumulated in this order:
 *
 *   s0 = m0 + 32 * p56 + (m3 + m4) + (m1 + m2)
 *   s1 = (m1 - m2) +  2 * (m3 - m4) + 16 * q56
 *   s2 = (m1 + m2) +  4 * (m3 + m4) +  8 * p56
 *   s3 = (m1 - m2) +  8 * (m3 - m4) +  4 * q56
 *   s4 = (m1 + m2) + 16 * (m3 + m4) +  2 * p56
 *   s5 = (m1 - m2) + 32 * (m3 - m4) + q56 + m7
 *
 * NOTE(review): the formulas assume GiMultiplyAddScalarFloat32V2(a, b, c)
 * evaluates a + b * c, as the Vector::mla calls it replaced suggest --
 * confirm against the GI helper header.
 */
#define OUTPUT_TRANSFORM(m, s)                                                   \
    do {                                                                         \
        auto add12 = ADDFV2(m##1, m##2);                                         \
        auto sub12 = SUBFV2(m##1, m##2);                                         \
        auto add34 = ADDFV2(m##3, m##4);                                         \
        auto sub34 = SUBFV2(m##3, m##4);                                         \
        auto add56 = MULSFV2((ADDFV2(m##5, m##6)), 0.03125f);                    \
        auto sub56 = MULSFV2((SUBFV2(m##5, m##6)), 0.03125f);                    \
        CONCAT(s, 0) = ADDFV2(                                                   \
                ADDFV2(GiMultiplyAddScalarFloat32V2(m##0, add56, 32.f), add34),  \
                add12);                                                          \
        CONCAT(s, 1) = GiMultiplyAddScalarFloat32V2(                             \
                GiMultiplyAddScalarFloat32V2(sub12, sub34, 2.f), sub56, 16.f);   \
        CONCAT(s, 2) = GiMultiplyAddScalarFloat32V2(                             \
                GiMultiplyAddScalarFloat32V2(add12, add34, 4.f), add56, 8.f);    \
        CONCAT(s, 3) = GiMultiplyAddScalarFloat32V2(                             \
                GiMultiplyAddScalarFloat32V2(sub12, sub34, 8.f), sub56, 4.f);    \
        CONCAT(s, 4) = GiMultiplyAddScalarFloat32V2(                             \
                GiMultiplyAddScalarFloat32V2(add12, add34, 16.f), add56, 2.f);   \
        CONCAT(s, 5) = ADDFV2(                                                   \
                ADDFV2(GiMultiplyAddScalarFloat32V2(sub12, sub34, 32.f), sub56), \
                m##7);                                                           \
    } while (0);
#if defined(__GNUC__) && !defined(__llvm__) && !defined(_MSC_VER) | #if defined(__GNUC__) && !defined(__llvm__) && !defined(_MSC_VER) | ||||
@@ -224,11 +237,11 @@ struct OutputTransform6X3 { | |||||
UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); | UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); | ||||
#undef cb | #undef cb | ||||
#define cb(i) auto m##i = Vector<float, 8>::load(transform_mid_buf + alpha * i); | |||||
#define cb(i) auto m##i = GiLoadFloat32V2(transform_mid_buf + alpha * i); | |||||
UNROLL_CALL_NOWRAPPER(8, cb); | UNROLL_CALL_NOWRAPPER(8, cb); | ||||
#undef cb | #undef cb | ||||
#define cb(i) Vector<float, 8> s##i, ret##i; | |||||
UNROLL_CALL_NOWRAPPER(8, cb); | |||||
#define cb(i) GI_FLOAT32_V2_t s##i; | |||||
UNROLL_CALL_NOWRAPPER(6, cb); | |||||
#undef cb | #undef cb | ||||
OUTPUT_TRANSFORM(m, s); | OUTPUT_TRANSFORM(m, s); | ||||
@@ -4,7 +4,6 @@ | |||||
#include "src/fallback/conv_bias/gi/fp32/filter_transform.h" | #include "src/fallback/conv_bias/gi/fp32/filter_transform.h" | ||||
#include "src/fallback/conv_bias/gi/fp32/helper.h" | #include "src/fallback/conv_bias/gi/fp32/helper.h" | ||||
#include "src/fallback/conv_bias/gi/fp32/strategy.h" | #include "src/fallback/conv_bias/gi/fp32/strategy.h" | ||||
#include "src/fallback/conv_bias/gi/utils.h" | |||||
#include "src/fallback/conv_bias/winograd/winograd.h" | #include "src/fallback/conv_bias/winograd/winograd.h" | ||||
#include "src/fallback/elemwise_helper/op_unary.h" | #include "src/fallback/elemwise_helper/op_unary.h" | ||||
@@ -76,8 +75,7 @@ struct InputTransform6X3 { | |||||
size_t nr_units_in_tile, size_t ic, size_t IC) { | size_t nr_units_in_tile, size_t ic, size_t IC) { | ||||
constexpr size_t alpha = 6 + 3 - 1; | constexpr size_t alpha = 6 + 3 - 1; | ||||
// BT * d * B | // BT * d * B | ||||
#define cb(m, n) \ | |||||
Vector<float, 4> d##m##n = Vector<float, 4>::load(patchT + m * 8 * 4 + n * 4); | |||||
#define cb(m, n) GI_FLOAT32_t d##m##n = GiLoadFloat32(patchT + m * 8 * 4 + n * 4); | |||||
UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); | UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); | ||||
#undef cb | #undef cb | ||||
@@ -91,36 +89,103 @@ struct InputTransform6X3 { | |||||
//! 0 1 -1 2 -2 0.5 -0.5 -5.25 | //! 0 1 -1 2 -2 0.5 -0.5 -5.25 | ||||
//! -1 1 1 1 1 1 1 0 | //! -1 1 1 1 1 1 1 0 | ||||
//! 0 0 0 0 0 0 0 1 | //! 0 0 0 0 0 0 0 1 | ||||
/* | |||||
* auto t0##m = d0##m + (d4##m - d2##m) * 5.25f - d6##m; | |||||
* auto t1##m = d1##m + d2##m + d5##m + d6##m - (d3##m + d4##m) * 4.25f; | |||||
* auto t2##m = d2##m + d6##m - (d1##m + d5##m) + (d3##m - d4##m) * 4.25f; | |||||
* auto t3##m = d1##m * 0.5f + d2##m * 0.25f - d3##m * 2.5f - d4##m * 1.25f + | |||||
* d5##m * 2.f + d6##m; | |||||
* auto t4##m = d1##m * (-0.5f) + d2##m * 0.25f + d3##m * 2.5f - d4##m * 1.25f - | |||||
* d5##m * 2.f + d6##m; | |||||
* auto t5##m = d1##m * 2.f + d2##m * 4.f - d3##m * 2.5f - d4##m * 5.f + | |||||
* d5##m * 0.5f + d6##m; | |||||
* auto t6##m = d1##m * (-2.f) + d2##m * 4.f + d3##m * 2.5f - d4##m * 5.f - | |||||
* d5##m * 0.5f + d6##m; | |||||
* auto t7##m = (d7##m - d1##m) + (d3##m - d5##m) * 5.25f; | |||||
*/ | |||||
#define cb(m) \ | #define cb(m) \ | ||||
auto t0##m = d0##m + (d4##m - d2##m) * 5.25f - d6##m; \ | |||||
auto t1##m = d1##m + d2##m + d5##m + d6##m - (d3##m + d4##m) * 4.25f; \ | |||||
auto t2##m = d2##m + d6##m - (d1##m + d5##m) + (d3##m - d4##m) * 4.25f; \ | |||||
auto t3##m = d1##m * 0.5f + d2##m * 0.25f - d3##m * 2.5f - d4##m * 1.25f + \ | |||||
d5##m * 2.f + d6##m; \ | |||||
auto t4##m = d1##m * (-0.5f) + d2##m * 0.25f + d3##m * 2.5f - d4##m * 1.25f - \ | |||||
d5##m * 2.f + d6##m; \ | |||||
auto t5##m = d1##m * 2.f + d2##m * 4.f - d3##m * 2.5f - d4##m * 5.f + \ | |||||
d5##m * 0.5f + d6##m; \ | |||||
auto t6##m = d1##m * (-2.f) + d2##m * 4.f + d3##m * 2.5f - d4##m * 5.f - \ | |||||
d5##m * 0.5f + d6##m; \ | |||||
auto t7##m = (d7##m - d1##m) + (d3##m - d5##m) * 5.25f; | |||||
auto t0##m = SUBF(ADDF(d0##m, MULSF(SUBF(d4##m, d2##m), 5.25f)), d6##m); \ | |||||
auto t1##m = \ | |||||
SUBF(ADDF(ADDF(ADDF(d1##m, d2##m), d5##m), d6##m), \ | |||||
MULSF(ADDF(d3##m, d4##m), 4.25f)); \ | |||||
auto t2##m = \ | |||||
ADDF(SUBF(ADDF(d2##m, d6##m), ADDF(d1##m, d5##m)), \ | |||||
MULSF(SUBF(d3##m, d4##m), 4.25f)); \ | |||||
auto t3##m = \ | |||||
ADDF(ADDF(SUBF(SUBF(ADDF(MULSF(d1##m, 0.5f), MULSF(d2##m, 0.25f)), \ | |||||
MULSF(d3##m, 2.5f)), \ | |||||
MULSF(d4##m, 1.25f)), \ | |||||
MULSF(d5##m, 2.f)), \ | |||||
d6##m); \ | |||||
auto t4##m = \ | |||||
ADDF(SUBF(SUBF(ADDF(ADDF(MULSF(d1##m, (-0.5f)), MULSF(d2##m, 0.25f)), \ | |||||
MULSF(d3##m, 2.5f)), \ | |||||
MULSF(d4##m, 1.25f)), \ | |||||
MULSF(d5##m, 2.f)), \ | |||||
d6##m); \ | |||||
auto t5##m = \ | |||||
ADDF(ADDF(SUBF(SUBF(ADDF(MULSF(d1##m, 2.f), MULSF(d2##m, 4.f)), \ | |||||
MULSF(d3##m, 2.5f)), \ | |||||
MULSF(d4##m, 5.f)), \ | |||||
MULSF(d5##m, 0.5f)), \ | |||||
d6##m); \ | |||||
auto t6##m = \ | |||||
ADDF(SUBF(SUBF(ADDF(ADDF(MULSF(d1##m, (-2.f)), MULSF(d2##m, 4.f)), \ | |||||
MULSF(d3##m, 2.5f)), \ | |||||
MULSF(d4##m, 5.f)), \ | |||||
MULSF(d5##m, 0.5f)), \ | |||||
d6##m); \ | |||||
auto t7##m = ADDF(SUBF(d7##m, d1##m), MULSF(SUBF(d3##m, d5##m), 5.25f)); | |||||
UNROLL_CALL_NOWRAPPER(8, cb); | UNROLL_CALL_NOWRAPPER(8, cb); | ||||
#undef cb | #undef cb | ||||
#define cb(m) \ | |||||
d##m##0 = t##m##0 + (t##m##4 - t##m##2) * 5.25f - t##m##6; \ | |||||
d##m##1 = t##m##1 + t##m##2 + t##m##5 + t##m##6 - (t##m##3 + t##m##4) * 4.25f; \ | |||||
d##m##2 = t##m##2 + t##m##6 - (t##m##1 + t##m##5) + (t##m##3 - t##m##4) * 4.25f; \ | |||||
d##m##3 = t##m##1 * 0.5f + t##m##2 * 0.25f - t##m##3 * 2.5f - t##m##4 * 1.25f + \ | |||||
t##m##5 * 2.f + t##m##6; \ | |||||
d##m##4 = t##m##1 * (-0.5f) + t##m##2 * 0.25f + t##m##3 * 2.5f - t##m##4 * 1.25f - \ | |||||
t##m##5 * 2.f + t##m##6; \ | |||||
d##m##5 = t##m##1 * 2.f + t##m##2 * 4.f - t##m##3 * 2.5f - t##m##4 * 5.f + \ | |||||
t##m##5 * 0.5f + t##m##6; \ | |||||
d##m##6 = t##m##1 * (-2.f) + t##m##2 * 4.f + t##m##3 * 2.5f - t##m##4 * 5.f - \ | |||||
t##m##5 * 0.5f + t##m##6; \ | |||||
d##m##7 = (t##m##7 - t##m##1) + (t##m##3 - t##m##5) * 5.25f; | |||||
/*
 * Reference form of the second pass (one statement per output):
 * d##m##0 = t##m##0 + (t##m##4 - t##m##2) * 5.25f - t##m##6;
 * d##m##1 = t##m##1 + t##m##2 + t##m##5 + t##m##6 -
 *           (t##m##3 + t##m##4) * 4.25f;
 * d##m##2 = t##m##2 + t##m##6 - (t##m##1 + t##m##5) +
 *           (t##m##3 - t##m##4) * 4.25f;
 * d##m##3 = t##m##1 * 0.5f + t##m##2 * 0.25f - t##m##3 * 2.5f -
 *           t##m##4 * 1.25f + t##m##5 * 2.f + t##m##6;
 * d##m##4 = t##m##1 * (-0.5f) + t##m##2 * 0.25f + t##m##3 * 2.5f -
 *           t##m##4 * 1.25f - t##m##5 * 2.f + t##m##6;
 * d##m##5 = t##m##1 * 2.f + t##m##2 * 4.f - t##m##3 * 2.5f - t##m##4 * 5.f +
 *           t##m##5 * 0.5f + t##m##6;
 * d##m##6 = t##m##1 * (-2.f) + t##m##2 * 4.f + t##m##3 * 2.5f - t##m##4 * 5.f -
 *           t##m##5 * 0.5f + t##m##6;
 * d##m##7 = (t##m##7 - t##m##1) + (t##m##3 - t##m##5) * 5.25f;
 */
#define cb(m) \ | |||||
d##m##0 = SUBF(ADDF(t##m##0, MULSF(SUBF(t##m##4, t##m##2), 5.25f)), t##m##6); \ | |||||
d##m##1 = \ | |||||
SUBF(ADDF(ADDF(ADDF(t##m##1, t##m##2), t##m##5), t##m##6), \ | |||||
MULSF(ADDF(t##m##3, t##m##4), 4.25f)); \ | |||||
d##m##2 = \ | |||||
ADDF(SUBF(ADDF(t##m##2, t##m##6), ADDF(t##m##1, t##m##5)), \ | |||||
MULSF(SUBF(t##m##3, t##m##4), 4.25f)); \ | |||||
d##m##3 = \ | |||||
ADDF(ADDF(SUBF(SUBF(ADDF(MULSF(t##m##1, 0.5f), MULSF(t##m##2, 0.25f)), \ | |||||
MULSF(t##m##3, 2.5f)), \ | |||||
MULSF(t##m##4, 1.25f)), \ | |||||
MULSF(t##m##5, 2.f)), \ | |||||
t##m##6); \ | |||||
d##m##4 = \ | |||||
ADDF(SUBF(SUBF(ADDF(ADDF(MULSF(t##m##1, (-0.5f)), MULSF(t##m##2, 0.25f)), \ | |||||
MULSF(t##m##3, 2.5f)), \ | |||||
MULSF(t##m##4, 1.25f)), \ | |||||
MULSF(t##m##5, 2.f)), \ | |||||
t##m##6); \ | |||||
d##m##5 = \ | |||||
ADDF(ADDF(SUBF(SUBF(ADDF(MULSF(t##m##1, 2.f), MULSF(t##m##2, 4.f)), \ | |||||
MULSF(t##m##3, 2.5f)), \ | |||||
MULSF(t##m##4, 5.f)), \ | |||||
MULSF(t##m##5, 0.5f)), \ | |||||
t##m##6); \ | |||||
d##m##6 = \ | |||||
ADDF(SUBF(SUBF(ADDF(ADDF(MULSF(t##m##1, (-2.f)), MULSF(t##m##2, 4.f)), \ | |||||
MULSF(t##m##3, 2.5f)), \ | |||||
MULSF(t##m##4, 5.f)), \ | |||||
MULSF(t##m##5, 0.5f)), \ | |||||
t##m##6); \ | |||||
d##m##7 = ADDF(SUBF(t##m##7, t##m##1), MULSF(SUBF(t##m##3, t##m##5), 5.25f)); | |||||
UNROLL_CALL_NOWRAPPER(8, cb); | UNROLL_CALL_NOWRAPPER(8, cb); | ||||
#undef cb | #undef cb | ||||
@@ -128,9 +193,10 @@ struct InputTransform6X3 { | |||||
size_t ICB = IC / 4; | size_t ICB = IC / 4; | ||||
size_t icb = ic / 4; | size_t icb = ic / 4; | ||||
#define cb(m, n) \ | #define cb(m, n) \ | ||||
d##m##n.save( \ | |||||
GiStoreFloat32( \ | |||||
input_transform_buf + (m * alpha + n) * ICB * nr_units_in_tile * 4 + \ | input_transform_buf + (m * alpha + n) * ICB * nr_units_in_tile * 4 + \ | ||||
icb * nr_units_in_tile * 4 + unit_idx * 4); | |||||
icb * nr_units_in_tile * 4 + unit_idx * 4, \ | |||||
d##m##n); | |||||
UNROLL_CALL_NOWRAPPER_D2(8, 8, cb) | UNROLL_CALL_NOWRAPPER_D2(8, 8, cb) | ||||
#undef cb | #undef cb | ||||
} | } | ||||
@@ -152,7 +218,7 @@ struct OutputTransform6X3 { | |||||
size_t ocb = oc_index / 4; | size_t ocb = oc_index / 4; | ||||
#define cb(m, n) \ | #define cb(m, n) \ | ||||
auto v##m##n = Vector<float, 4>::load( \ | |||||
auto v##m##n = GiLoadFloat32( \ | |||||
output_transform_buf + (m * alpha + n) * OCB * nr_units_in_tile * 4 + \ | output_transform_buf + (m * alpha + n) * OCB * nr_units_in_tile * 4 + \ | ||||
ocb * nr_units_in_tile * 4 + unit_idx * 4); | ocb * nr_units_in_tile * 4 + unit_idx * 4); | ||||
UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); | UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); | ||||
@@ -171,56 +237,88 @@ struct OutputTransform6X3 { | |||||
* 0 0.0 0 0 0 1 | * 0 0.0 0 0 0 1 | ||||
*/ | */ | ||||
Vector<float, 4> v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, v5subv6; | |||||
#define cb(m) \ | |||||
v1addv2 = v1##m + v2##m; \ | |||||
v1subv2 = v1##m - v2##m; \ | |||||
v3addv4 = v3##m + v4##m; \ | |||||
v3subv4 = v3##m - v4##m; \ | |||||
v5addv6 = v5##m + v6##m; \ | |||||
v5subv6 = v5##m - v6##m; \ | |||||
auto t0##m = v0##m + v1addv2 + v3addv4 + v5addv6; \ | |||||
auto t1##m = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f; \ | |||||
auto t2##m = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f; \ | |||||
auto t3##m = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f; \ | |||||
auto t4##m = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; \ | |||||
auto t5##m = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + v7##m; | |||||
/* | |||||
* v1addv2 = v1##m + v2##m; | |||||
* v1subv2 = v1##m - v2##m; | |||||
* v3addv4 = v3##m + v4##m; | |||||
* v3subv4 = v3##m - v4##m; | |||||
* v5addv6 = v5##m + v6##m; | |||||
* v5subv6 = v5##m - v6##m; | |||||
* auto t0##m = v0##m + v1addv2 + v3addv4 + v5addv6; | |||||
* auto t1##m = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f; | |||||
* auto t2##m = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f; | |||||
* auto t3##m = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f; | |||||
* auto t4##m = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; | |||||
* auto t5##m = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + v7##m; | |||||
*/ | |||||
GI_FLOAT32_t v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, v5subv6; | |||||
#define cb(m) \ | |||||
v1addv2 = ADDF(v1##m, v2##m); \ | |||||
v1subv2 = SUBF(v1##m, v2##m); \ | |||||
v3addv4 = ADDF(v3##m, v4##m); \ | |||||
v3subv4 = SUBF(v3##m, v4##m); \ | |||||
v5addv6 = ADDF(v5##m, v6##m); \ | |||||
v5subv6 = SUBF(v5##m, v6##m); \ | |||||
auto t0##m = ADDF(ADDF(ADDF(v0##m, v1addv2), v3addv4), v5addv6); \ | |||||
auto t1##m = ADDF(ADDF(v1subv2, MULSF(v3subv4, 2.f)), MULSF(v5subv6, 0.5f)); \ | |||||
auto t2##m = ADDF(ADDF(v1addv2, MULSF(v3addv4, 4.f)), MULSF(v5addv6, 0.25f)); \ | |||||
auto t3##m = ADDF(ADDF(v1subv2, MULSF(v3subv4, 8.f)), MULSF(v5subv6, 0.125f)); \ | |||||
auto t4##m = ADDF(ADDF(v1addv2, MULSF(v3addv4, 16.f)), MULSF(v5addv6, 0.0625f)); \ | |||||
auto t5##m = \ | |||||
ADDF(ADDF(ADDF(v1subv2, MULSF(v3subv4, 32.f)), MULSF(v5subv6, 0.03125f)), \ | |||||
v7##m); | |||||
UNROLL_CALL_NOWRAPPER(8, cb); | UNROLL_CALL_NOWRAPPER(8, cb); | ||||
#undef cb | #undef cb | ||||
#define cb(m) \ | |||||
v1addv2 = t##m##1 + t##m##2; \ | |||||
v1subv2 = t##m##1 - t##m##2; \ | |||||
v3addv4 = t##m##3 + t##m##4; \ | |||||
v3subv4 = t##m##3 - t##m##4; \ | |||||
v5addv6 = t##m##5 + t##m##6; \ | |||||
v5subv6 = t##m##5 - t##m##6; \ | |||||
v##m##0 = t##m##0 + v1addv2 + v3addv4 + v5addv6; \ | |||||
v##m##1 = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f; \ | |||||
v##m##2 = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f; \ | |||||
v##m##3 = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f; \ | |||||
v##m##4 = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; \ | |||||
v##m##5 = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + t##m##7; | |||||
/* | |||||
* v1addv2 = t##m##1 + t##m##2; | |||||
* v1subv2 = t##m##1 - t##m##2; | |||||
* v3addv4 = t##m##3 + t##m##4; | |||||
* v3subv4 = t##m##3 - t##m##4; | |||||
* v5addv6 = t##m##5 + t##m##6; | |||||
* v5subv6 = t##m##5 - t##m##6; | |||||
* v##m##0 = t##m##0 + v1addv2 + v3addv4 + v5addv6; | |||||
* v##m##1 = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f; | |||||
* v##m##2 = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f; | |||||
* v##m##3 = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f; | |||||
* v##m##4 = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; | |||||
* v##m##5 = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + t##m##7; | |||||
*/ | |||||
#define cb(m) \ | |||||
v1addv2 = ADDF(t##m##1, t##m##2); \ | |||||
v1subv2 = SUBF(t##m##1, t##m##2); \ | |||||
v3addv4 = ADDF(t##m##3, t##m##4); \ | |||||
v3subv4 = SUBF(t##m##3, t##m##4); \ | |||||
v5addv6 = ADDF(t##m##5, t##m##6); \ | |||||
v5subv6 = SUBF(t##m##5, t##m##6); \ | |||||
v##m##0 = ADDF(ADDF(ADDF(t##m##0, v1addv2), v3addv4), v5addv6); \ | |||||
v##m##1 = ADDF(ADDF(v1subv2, MULSF(v3subv4, 2.f)), MULSF(v5subv6, 0.5f)); \ | |||||
v##m##2 = ADDF(ADDF(v1addv2, MULSF(v3addv4, 4.f)), MULSF(v5addv6, 0.25f)); \ | |||||
v##m##3 = ADDF(ADDF(v1subv2, MULSF(v3subv4, 8.f)), MULSF(v5subv6, 0.125f)); \ | |||||
v##m##4 = ADDF(ADDF(v1addv2, MULSF(v3addv4, 16.f)), MULSF(v5addv6, 0.0625f)); \ | |||||
v##m##5 = \ | |||||
ADDF(ADDF(ADDF(v1subv2, MULSF(v3subv4, 32.f)), MULSF(v5subv6, 0.03125f)), \ | |||||
t##m##7); | |||||
UNROLL_CALL_NOWRAPPER(6, cb); | UNROLL_CALL_NOWRAPPER(6, cb); | ||||
#undef cb | #undef cb | ||||
Vector<float, 4> vbias; | |||||
GI_FLOAT32_t vbias; | |||||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | ||||
vbias = Vector<float, 4>::load(bias + oc); | |||||
vbias = GiLoadFloat32(bias + oc); | |||||
#define cb(m, n) v##m##n += vbias; | |||||
#define cb(m, n) v##m##n = GiAddFloat32(v##m##n, vbias); | |||||
UNROLL_CALL_RAW_D2(6, 6, cb); | UNROLL_CALL_RAW_D2(6, 6, cb); | ||||
#undef cb | #undef cb | ||||
} | } | ||||
if (bmode != BiasMode::BIAS) { | if (bmode != BiasMode::BIAS) { | ||||
#define cb(m, n) v##m##n = op(GiFixLenType2GiFloat32Type(CONCAT(v##m, n).value)); | |||||
#define cb(m, n) v##m##n = op(CONCAT(v##m, n)); | |||||
UNROLL_CALL_RAW_D2(6, 6, cb); | UNROLL_CALL_RAW_D2(6, 6, cb); | ||||
#undef cb | #undef cb | ||||
} | } | ||||
#define cb(m, n) CONCAT(v##m, n).save(transform_mid_buf + (m * 6 + n) * 4); | |||||
#define cb(m, n) GiStoreFloat32(transform_mid_buf + (m * 6 + n) * 4, CONCAT(v##m, n)); | |||||
UNROLL_CALL_RAW_D2(6, 6, cb); | UNROLL_CALL_RAW_D2(6, 6, cb); | ||||
#undef cb | #undef cb | ||||
@@ -1,7 +1,6 @@ | |||||
#include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "src/fallback/conv_bias/gi/fp32/strategy.h" | #include "src/fallback/conv_bias/gi/fp32/strategy.h" | ||||
#include "src/fallback/conv_bias/gi/utils.h" | |||||
#include "src/fallback/conv_bias/winograd/winograd.h" | #include "src/fallback/conv_bias/winograd/winograd.h" | ||||
#include "src/fallback/conv_bias/gi/fp32/helper.h" | #include "src/fallback/conv_bias/gi/fp32/helper.h" | ||||
@@ -29,7 +28,7 @@ struct InputTransformF23_NCHW44 { | |||||
size_t iw4_start = iw_start * pack_size; | size_t iw4_start = iw_start * pack_size; | ||||
size_t ICB = IC / pack_size; | size_t ICB = IC / pack_size; | ||||
#define cb(m, n) Vector<float, 4> d##m##n; | |||||
#define cb(m, n) GI_FLOAT32_t d##m##n; | |||||
UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); | UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); | ||||
#undef cb | #undef cb | ||||
@@ -40,7 +39,7 @@ struct InputTransformF23_NCHW44 { | |||||
MEGDNN_MARK_USED_VAR(patchT); | MEGDNN_MARK_USED_VAR(patchT); | ||||
const float* input_ptr = | const float* input_ptr = | ||||
input + icb * IH * IW4 + ih_start * IW4 + iw4_start; | input + icb * IH * IW4 + ih_start * IW4 + iw4_start; | ||||
#define cb(n, m) d##m##n = Vector<float, 4>::load(input_ptr + pack_size * n); | |||||
#define cb(n, m) d##m##n = GiLoadFloat32(input_ptr + pack_size * n); | |||||
UNROLL_CALL_RAW(4, cb, 0); | UNROLL_CALL_RAW(4, cb, 0); | ||||
input_ptr += IW4; | input_ptr += IW4; | ||||
@@ -66,7 +65,7 @@ struct InputTransformF23_NCHW44 { | |||||
} | } | ||||
} | } | ||||
#define cb(m, n) \ | #define cb(m, n) \ | ||||
d##m##n = Vector<float, 4>::load(patchT + m * alpha * pack_size + n * pack_size); | |||||
d##m##n = GiLoadFloat32(patchT + m * alpha * pack_size + n * pack_size); | |||||
UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); | UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); | ||||
#undef cb | #undef cb | ||||
} | } | ||||
@@ -74,29 +73,30 @@ struct InputTransformF23_NCHW44 { | |||||
//! 0 1 1 0 d10 d11 d12 d13 0 1 -1 -1 | //! 0 1 1 0 d10 d11 d12 d13 0 1 -1 -1 | ||||
//! 0 -1 1 0 d20 d21 d22 d23 -1 1 1 0 | //! 0 -1 1 0 d20 d21 d22 d23 -1 1 1 0 | ||||
//! 0 -1 0 1 d30 d31 d32 d33 0 0 0 1 | //! 0 -1 0 1 d30 d31 d32 d33 0 0 0 1 | ||||
#define cb(m) \ | |||||
auto t0##m = d0##m - d2##m; \ | |||||
auto t1##m = d1##m + d2##m; \ | |||||
auto t2##m = d2##m - d1##m; \ | |||||
auto t3##m = d3##m - d1##m; | |||||
#define cb(m) \ | |||||
auto t0##m = SUBF(d0##m, d2##m); \ | |||||
auto t1##m = ADDF(d1##m, d2##m); \ | |||||
auto t2##m = SUBF(d2##m, d1##m); \ | |||||
auto t3##m = SUBF(d3##m, d1##m); | |||||
UNROLL_CALL_NOWRAPPER(4, cb); | UNROLL_CALL_NOWRAPPER(4, cb); | ||||
#undef cb | #undef cb | ||||
#define cb(m) \ | |||||
d##m##0 = t##m##0 - t##m##2; \ | |||||
d##m##1 = t##m##1 + t##m##2; \ | |||||
d##m##2 = t##m##2 - t##m##1; \ | |||||
d##m##3 = t##m##3 - t##m##1; | |||||
#define cb(m) \ | |||||
d##m##0 = SUBF(t##m##0, t##m##2); \ | |||||
d##m##1 = ADDF(t##m##1, t##m##2); \ | |||||
d##m##2 = SUBF(t##m##2, t##m##1); \ | |||||
d##m##3 = SUBF(t##m##3, t##m##1); | |||||
UNROLL_CALL_NOWRAPPER(4, cb); | UNROLL_CALL_NOWRAPPER(4, cb); | ||||
#undef cb | #undef cb | ||||
#define cb(m, n) \ | |||||
d##m##n.save( \ | |||||
input_transform_buf + \ | |||||
(m * alpha + n) * ICB * nr_units_in_tile * pack_size + \ | |||||
icb * nr_units_in_tile * pack_size + unit_idx * pack_size); | |||||
#define cb(m, n) \ | |||||
GiStoreFloat32( \ | |||||
input_transform_buf + \ | |||||
(m * alpha + n) * ICB * nr_units_in_tile * pack_size + \ | |||||
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ | |||||
d##m##n); | |||||
UNROLL_CALL_NOWRAPPER_D2(4, 4, cb) | UNROLL_CALL_NOWRAPPER_D2(4, 4, cb) | ||||
#undef cb | #undef cb | ||||
} | } | ||||
@@ -118,7 +118,7 @@ struct OutputTransformF23_NCHW44 { | |||||
size_t ocb = oc_index / pack_size; | size_t ocb = oc_index / pack_size; | ||||
#define cb(m, n) \ | #define cb(m, n) \ | ||||
auto v##m##n = Vector<float, 4>::load( \ | |||||
auto v##m##n = GiLoadFloat32( \ | |||||
output_transform_buf + \ | output_transform_buf + \ | ||||
(m * alpha + n) * OCB * nr_units_in_tile * pack_size + \ | (m * alpha + n) * OCB * nr_units_in_tile * pack_size + \ | ||||
ocb * nr_units_in_tile * pack_size + unit_idx * pack_size); | ocb * nr_units_in_tile * pack_size + unit_idx * pack_size); | ||||
@@ -130,30 +130,30 @@ struct OutputTransformF23_NCHW44 { | |||||
//! v20 v21 v22 v23 1 -1 | //! v20 v21 v22 v23 1 -1 | ||||
//! v30 v31 v32 v33 0 1 | //! v30 v31 v32 v33 0 1 | ||||
#define cb(m) \ | |||||
auto t0##m = v0##m + v1##m + v2##m; \ | |||||
auto t1##m = v1##m - v2##m + v3##m; | |||||
#define cb(m) \ | |||||
auto t0##m = ADDF(ADDF(v0##m, v1##m), v2##m); \ | |||||
auto t1##m = ADDF(SUBF(v1##m, v2##m), v3##m); | |||||
UNROLL_CALL_NOWRAPPER(4, cb); | UNROLL_CALL_NOWRAPPER(4, cb); | ||||
#undef cb | #undef cb | ||||
#define cb(m) \ | |||||
v##m##0 = t##m##0 + t##m##1 + t##m##2; \ | |||||
v##m##1 = t##m##1 - t##m##2 + t##m##3; | |||||
#define cb(m) \ | |||||
v##m##0 = ADDF(ADDF(t##m##0, t##m##1), t##m##2); \ | |||||
v##m##1 = ADDF(SUBF(t##m##1, t##m##2), t##m##3); | |||||
UNROLL_CALL_NOWRAPPER(2, cb); | UNROLL_CALL_NOWRAPPER(2, cb); | ||||
#undef cb | #undef cb | ||||
Vector<float, 4> vbias; | |||||
GI_FLOAT32_t vbias; | |||||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | ||||
vbias = Vector<float, 4>::load(bias + oc); | |||||
vbias = GiLoadFloat32(bias + oc); | |||||
#define cb(m, n) v##m##n += vbias; | |||||
#define cb(m, n) v##m##n = ADDF(v##m##n, vbias); | |||||
UNROLL_CALL_RAW_D2(2, 2, cb); | UNROLL_CALL_RAW_D2(2, 2, cb); | ||||
#undef cb | #undef cb | ||||
} | } | ||||
if (bmode != BiasMode::BIAS) { | if (bmode != BiasMode::BIAS) { | ||||
#define cb(m, n) v##m##n = op(GiFixLenType2GiFloat32Type(CONCAT(v##m, n).value)); | |||||
#define cb(m, n) v##m##n = op(CONCAT(v##m, n)); | |||||
UNROLL_CALL_RAW_D2(2, 2, cb); | UNROLL_CALL_RAW_D2(2, 2, cb); | ||||
#undef cb | #undef cb | ||||
} | } | ||||
@@ -163,12 +163,15 @@ struct OutputTransformF23_NCHW44 { | |||||
size_t ow = ow_start + owo; \ | size_t ow = ow_start + owo; \ | ||||
if (oh < OH && ow < OW) { \ | if (oh < OH && ow < OW) { \ | ||||
if (bmode == BiasMode::BIAS) { \ | if (bmode == BiasMode::BIAS) { \ | ||||
v##oho##owo += Vector<float, 4>::load( \ | |||||
bias + oc * OH * OW + oh * OW * pack_size + ow * pack_size); \ | |||||
v##oho##owo = op(GiFixLenType2GiFloat32Type(v##oho##owo.value)); \ | |||||
v##oho##owo = ADDF( \ | |||||
v##oho##owo, GiLoadFloat32( \ | |||||
bias + oc * OH * OW + \ | |||||
oh * OW * pack_size + ow * pack_size)); \ | |||||
v##oho##owo = op(v##oho##owo); \ | |||||
} \ | } \ | ||||
v##oho##owo.save( \ | |||||
output + oc * OH * OW + oh * OW * pack_size + ow * pack_size); \ | |||||
GiStoreFloat32( \ | |||||
output + oc * OH * OW + oh * OW * pack_size + ow * pack_size, \ | |||||
v##oho##owo); \ | |||||
} \ | } \ | ||||
} while (0); | } while (0); | ||||
UNROLL_CALL_RAW_D2(2, 2, out_save); | UNROLL_CALL_RAW_D2(2, 2, out_save); | ||||
@@ -211,29 +214,30 @@ void winograd_F23_mk4_f_nchw44::filter( | |||||
pack_size * pack_size + | pack_size * pack_size + | ||||
ic_inner * pack_size; | ic_inner * pack_size; | ||||
#define cb(m, n) \ | |||||
Vector<float, 4> g##m##n = Vector<float, 4>::load( \ | |||||
fptr + (m * KERNEL_SIZE + n) * pack_size * pack_size); | |||||
#define cb(m, n) \ | |||||
GI_FLOAT32_t g##m##n = \ | |||||
GiLoadFloat32(fptr + (m * KERNEL_SIZE + n) * pack_size * pack_size); | |||||
UNROLL_CALL_NOWRAPPER_D2(3, 3, cb) | UNROLL_CALL_NOWRAPPER_D2(3, 3, cb) | ||||
#undef cb | #undef cb | ||||
#define FILTER_TRANSFORM(n, wd, g) \ | |||||
auto wd##n##0 = g##0##n; \ | |||||
tmp0 = (g##0##n + g##2##n) * 0.5; \ | |||||
tmp1 = g##1##n * 0.5; \ | |||||
auto wd##n##1 = tmp0 + tmp1; \ | |||||
auto wd##n##2 = tmp0 - tmp1; \ | |||||
#define FILTER_TRANSFORM(n, wd, g) \ | |||||
auto wd##n##0 = g##0##n; \ | |||||
tmp0 = MULSF(ADDF(g##0##n, g##2##n), 0.5); \ | |||||
tmp1 = MULSF(g##1##n, 0.5); \ | |||||
auto wd##n##1 = ADDF(tmp0, tmp1); \ | |||||
auto wd##n##2 = SUBF(tmp0, tmp1); \ | |||||
auto wd##n##3 = g##2##n; | auto wd##n##3 = g##2##n; | ||||
Vector<float, 4> tmp0, tmp1; | |||||
GI_FLOAT32_t tmp0, tmp1; | |||||
UNROLL_CALL_RAW(3, FILTER_TRANSFORM, wd, g); | UNROLL_CALL_RAW(3, FILTER_TRANSFORM, wd, g); | ||||
UNROLL_CALL_RAW(4, FILTER_TRANSFORM, ret, wd); | UNROLL_CALL_RAW(4, FILTER_TRANSFORM, ret, wd); | ||||
#undef FILTER_TRANSFORM | #undef FILTER_TRANSFORM | ||||
#define cb_save(m, n) \ | |||||
ret##m##n.save( \ | |||||
filter_transform_buf + \ | |||||
(m * ALPHA + n) * OCB * ICB * pack_size * pack_size + \ | |||||
ocb * ICB * pack_size * pack_size + icb * pack_size * pack_size + \ | |||||
ic_inner * pack_size); | |||||
#define cb_save(m, n) \ | |||||
GiStoreFloat32( \ | |||||
filter_transform_buf + \ | |||||
(m * ALPHA + n) * OCB * ICB * pack_size * pack_size + \ | |||||
ocb * ICB * pack_size * pack_size + icb * pack_size * pack_size + \ | |||||
ic_inner * pack_size, \ | |||||
ret##m##n); | |||||
UNROLL_CALL_NOWRAPPER_D2(4, 4, cb_save) | UNROLL_CALL_NOWRAPPER_D2(4, 4, cb_save) | ||||
#undef cb_save | #undef cb_save | ||||
} | } | ||||
@@ -4,7 +4,6 @@ | |||||
#include "src/fallback/conv_bias/gi/fp32/filter_transform.h" | #include "src/fallback/conv_bias/gi/fp32/filter_transform.h" | ||||
#include "src/fallback/conv_bias/gi/fp32/helper.h" | #include "src/fallback/conv_bias/gi/fp32/helper.h" | ||||
#include "src/fallback/conv_bias/gi/fp32/strategy.h" | #include "src/fallback/conv_bias/gi/fp32/strategy.h" | ||||
#include "src/fallback/conv_bias/gi/utils.h" | |||||
#include "src/fallback/conv_bias/winograd/winograd.h" | #include "src/fallback/conv_bias/winograd/winograd.h" | ||||
#include "src/fallback/elemwise_helper/op_unary.h" | #include "src/fallback/elemwise_helper/op_unary.h" | ||||
@@ -114,17 +113,17 @@ struct InputTransformF63_NCHW44 { | |||||
auto t##i##4 = d6; \ | auto t##i##4 = d6; \ | ||||
auto t##i##5 = d6; \ | auto t##i##5 = d6; \ | ||||
auto t##i##6 = d6; \ | auto t##i##6 = d6; \ | ||||
t##i##0 = GiSubtractFloat32(t##i##0, d6); \ | |||||
t##i##1 = GiAddFloat32(t##i##1, d1); \ | |||||
t##i##2 = GiSubtractFloat32(t##i##2, d1); \ | |||||
t##i##0 = SUBF(t##i##0, d6); \ | |||||
t##i##1 = ADDF(t##i##1, d1); \ | |||||
t##i##2 = SUBF(t##i##2, d1); \ | |||||
t##i##3 = MADD(t##i##3, d1, v0, 2); \ | t##i##3 = MADD(t##i##3, d1, v0, 2); \ | ||||
t##i##4 = MSUB(t##i##4, d1, v0, 2); \ | t##i##4 = MSUB(t##i##4, d1, v0, 2); \ | ||||
t##i##5 = MADD(t##i##5, d1, v1, 2); \ | t##i##5 = MADD(t##i##5, d1, v1, 2); \ | ||||
t##i##6 = MSUB(t##i##6, d1, v1, 2); \ | t##i##6 = MSUB(t##i##6, d1, v1, 2); \ | ||||
t##i##7 = GiSubtractFloat32(t##i##7, d1); \ | |||||
t##i##7 = SUBF(t##i##7, d1); \ | |||||
t##i##0 = MSUB(t##i##0, d2, v0, 0); \ | t##i##0 = MSUB(t##i##0, d2, v0, 0); \ | ||||
t##i##1 = GiAddFloat32(t##i##1, d2); \ | |||||
t##i##2 = GiAddFloat32(t##i##2, d2); \ | |||||
t##i##1 = ADDF(t##i##1, d2); \ | |||||
t##i##2 = ADDF(t##i##2, d2); \ | |||||
t##i##3 = MADD(t##i##3, d2, v0, 3); \ | t##i##3 = MADD(t##i##3, d2, v0, 3); \ | ||||
t##i##4 = MADD(t##i##4, d2, v0, 3); \ | t##i##4 = MADD(t##i##4, d2, v0, 3); \ | ||||
t##i##5 = MADD(t##i##5, d2, v1, 3); \ | t##i##5 = MADD(t##i##5, d2, v1, 3); \ | ||||
@@ -143,8 +142,8 @@ struct InputTransformF63_NCHW44 { | |||||
t##i##4 = MSUB(t##i##4, d4, v1, 1); \ | t##i##4 = MSUB(t##i##4, d4, v1, 1); \ | ||||
t##i##5 = MSUB(t##i##5, d4, v2, 0); \ | t##i##5 = MSUB(t##i##5, d4, v2, 0); \ | ||||
t##i##6 = MSUB(t##i##6, d4, v2, 0); \ | t##i##6 = MSUB(t##i##6, d4, v2, 0); \ | ||||
t##i##1 = GiAddFloat32(t##i##1, d5); \ | |||||
t##i##2 = GiSubtractFloat32(t##i##2, d5); \ | |||||
t##i##1 = ADDF(t##i##1, d5); \ | |||||
t##i##2 = SUBF(t##i##2, d5); \ | |||||
t##i##3 = MADD(t##i##3, d5, v1, 2); \ | t##i##3 = MADD(t##i##3, d5, v1, 2); \ | ||||
t##i##4 = MSUB(t##i##4, d5, v1, 2); \ | t##i##4 = MSUB(t##i##4, d5, v1, 2); \ | ||||
t##i##5 = MADD(t##i##5, d5, v0, 2); \ | t##i##5 = MADD(t##i##5, d5, v0, 2); \ | ||||
@@ -162,17 +161,17 @@ struct InputTransformF63_NCHW44 { | |||||
d5 = t6##i; \ | d5 = t6##i; \ | ||||
d6 = t6##i; \ | d6 = t6##i; \ | ||||
d7 = t7##i; \ | d7 = t7##i; \ | ||||
d0 = GiSubtractFloat32(d0, t6##i); \ | |||||
d1 = GiAddFloat32(d1, t1##i); \ | |||||
d2 = GiSubtractFloat32(d2, t1##i); \ | |||||
d0 = SUBF(d0, t6##i); \ | |||||
d1 = ADDF(d1, t1##i); \ | |||||
d2 = SUBF(d2, t1##i); \ | |||||
d3 = MADD(d3, t1##i, v0, 2); \ | d3 = MADD(d3, t1##i, v0, 2); \ | ||||
d4 = MSUB(d4, t1##i, v0, 2); \ | d4 = MSUB(d4, t1##i, v0, 2); \ | ||||
d5 = MADD(d5, t1##i, v1, 2); \ | d5 = MADD(d5, t1##i, v1, 2); \ | ||||
d6 = MSUB(d6, t1##i, v1, 2); \ | d6 = MSUB(d6, t1##i, v1, 2); \ | ||||
d7 = GiSubtractFloat32(d7, t1##i); \ | |||||
d7 = SUBF(d7, t1##i); \ | |||||
d0 = MSUB(d0, t2##i, v0, 0); \ | d0 = MSUB(d0, t2##i, v0, 0); \ | ||||
d1 = GiAddFloat32(d1, t2##i); \ | |||||
d2 = GiAddFloat32(d2, t2##i); \ | |||||
d1 = ADDF(d1, t2##i); \ | |||||
d2 = ADDF(d2, t2##i); \ | |||||
d3 = MADD(d3, t2##i, v0, 3); \ | d3 = MADD(d3, t2##i, v0, 3); \ | ||||
d4 = MADD(d4, t2##i, v0, 3); \ | d4 = MADD(d4, t2##i, v0, 3); \ | ||||
d5 = MADD(d5, t2##i, v1, 3); \ | d5 = MADD(d5, t2##i, v1, 3); \ | ||||
@@ -191,8 +190,8 @@ struct InputTransformF63_NCHW44 { | |||||
d4 = MSUB(d4, t4##i, v1, 1); \ | d4 = MSUB(d4, t4##i, v1, 1); \ | ||||
d5 = MSUB(d5, t4##i, v2, 0); \ | d5 = MSUB(d5, t4##i, v2, 0); \ | ||||
d6 = MSUB(d6, t4##i, v2, 0); \ | d6 = MSUB(d6, t4##i, v2, 0); \ | ||||
d1 = GiAddFloat32(d1, t5##i); \ | |||||
d2 = GiSubtractFloat32(d2, t5##i); \ | |||||
d1 = ADDF(d1, t5##i); \ | |||||
d2 = SUBF(d2, t5##i); \ | |||||
d3 = MADD(d3, t5##i, v1, 2); \ | d3 = MADD(d3, t5##i, v1, 2); \ | ||||
d4 = MSUB(d4, t5##i, v1, 2); \ | d4 = MSUB(d4, t5##i, v1, 2); \ | ||||
d5 = MADD(d5, t5##i, v0, 2); \ | d5 = MADD(d5, t5##i, v0, 2); \ | ||||
@@ -261,7 +260,7 @@ struct OutputTransformF63_NCHW44 { | |||||
size_t ocb = oc_index / pack_size; | size_t ocb = oc_index / pack_size; | ||||
#define cb(m, n) \ | #define cb(m, n) \ | ||||
auto v##m##n = Vector<float, 4>::load( \ | |||||
auto v##m##n = GiLoadFloat32( \ | |||||
output_transform_buf + \ | output_transform_buf + \ | ||||
(m * alpha + n) * OCB * nr_units_in_tile * pack_size + \ | (m * alpha + n) * OCB * nr_units_in_tile * pack_size + \ | ||||
ocb * nr_units_in_tile * pack_size + unit_idx * pack_size); | ocb * nr_units_in_tile * pack_size + unit_idx * pack_size); | ||||
@@ -281,51 +280,83 @@ struct OutputTransformF63_NCHW44 { | |||||
* 0 0 0 0 0 1 | * 0 0 0 0 0 1 | ||||
*/ | */ | ||||
Vector<float, 4> v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, v5subv6; | |||||
#define cb(m) \ | |||||
v1addv2 = v1##m + v2##m; \ | |||||
v1subv2 = v1##m - v2##m; \ | |||||
v3addv4 = v3##m + v4##m; \ | |||||
v3subv4 = v3##m - v4##m; \ | |||||
v5addv6 = v5##m + v6##m; \ | |||||
v5subv6 = v5##m - v6##m; \ | |||||
auto t0##m = v0##m + v1addv2 + v3addv4 + v5addv6; \ | |||||
auto t1##m = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f; \ | |||||
auto t2##m = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f; \ | |||||
auto t3##m = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f; \ | |||||
auto t4##m = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; \ | |||||
auto t5##m = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + v7##m; | |||||
/* | |||||
* v1addv2 = v1##m + v2##m; | |||||
* v1subv2 = v1##m - v2##m; | |||||
* v3addv4 = v3##m + v4##m; | |||||
* v3subv4 = v3##m - v4##m; | |||||
* v5addv6 = v5##m + v6##m; | |||||
* v5subv6 = v5##m - v6##m; | |||||
* auto t0##m = v0##m + v1addv2 + v3addv4 + v5addv6; | |||||
* auto t1##m = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f; | |||||
* auto t2##m = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f; | |||||
* auto t3##m = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f; | |||||
* auto t4##m = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; | |||||
* auto t5##m = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + v7##m; | |||||
*/ | |||||
GI_FLOAT32_t v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, v5subv6; | |||||
#define cb(m) \ | |||||
v1addv2 = ADDF(v1##m, v2##m); \ | |||||
v1subv2 = SUBF(v1##m, v2##m); \ | |||||
v3addv4 = ADDF(v3##m, v4##m); \ | |||||
v3subv4 = SUBF(v3##m, v4##m); \ | |||||
v5addv6 = ADDF(v5##m, v6##m); \ | |||||
v5subv6 = SUBF(v5##m, v6##m); \ | |||||
auto t0##m = ADDF(ADDF(ADDF(v0##m, v1addv2), v3addv4), v5addv6); \ | |||||
auto t1##m = ADDF(ADDF(v1subv2, MULSF(v3subv4, 2.f)), MULSF(v5subv6, 0.5f)); \ | |||||
auto t2##m = ADDF(ADDF(v1addv2, MULSF(v3addv4, 4.f)), MULSF(v5addv6, 0.25f)); \ | |||||
auto t3##m = ADDF(ADDF(v1subv2, MULSF(v3subv4, 8.f)), MULSF(v5subv6, 0.125f)); \ | |||||
auto t4##m = ADDF(ADDF(v1addv2, MULSF(v3addv4, 16.f)), MULSF(v5addv6, 0.0625f)); \ | |||||
auto t5##m = \ | |||||
ADDF(ADDF(ADDF(v1subv2, MULSF(v3subv4, 32.f)), MULSF(v5subv6, 0.03125f)), \ | |||||
v7##m); | |||||
UNROLL_CALL_NOWRAPPER(8, cb); | UNROLL_CALL_NOWRAPPER(8, cb); | ||||
#undef cb | #undef cb | ||||
#define cb(m) \ | |||||
v1addv2 = t##m##1 + t##m##2; \ | |||||
v1subv2 = t##m##1 - t##m##2; \ | |||||
v3addv4 = t##m##3 + t##m##4; \ | |||||
v3subv4 = t##m##3 - t##m##4; \ | |||||
v5addv6 = t##m##5 + t##m##6; \ | |||||
v5subv6 = t##m##5 - t##m##6; \ | |||||
v##m##0 = t##m##0 + v1addv2 + v3addv4 + v5addv6; \ | |||||
v##m##1 = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f; \ | |||||
v##m##2 = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f; \ | |||||
v##m##3 = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f; \ | |||||
v##m##4 = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; \ | |||||
v##m##5 = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + t##m##7; | |||||
/* | |||||
* v1addv2 = t##m##1 + t##m##2; | |||||
* v1subv2 = t##m##1 - t##m##2; | |||||
* v3addv4 = t##m##3 + t##m##4; | |||||
* v3subv4 = t##m##3 - t##m##4; | |||||
* v5addv6 = t##m##5 + t##m##6; | |||||
* v5subv6 = t##m##5 - t##m##6; | |||||
* v##m##0 = t##m##0 + v1addv2 + v3addv4 + v5addv6; | |||||
* v##m##1 = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f; | |||||
* v##m##2 = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f; | |||||
* v##m##3 = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f; | |||||
* v##m##4 = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; | |||||
* v##m##5 = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + t##m##7; | |||||
*/ | |||||
#define cb(m) \ | |||||
v1addv2 = ADDF(t##m##1, t##m##2); \ | |||||
v1subv2 = SUBF(t##m##1, t##m##2); \ | |||||
v3addv4 = ADDF(t##m##3, t##m##4); \ | |||||
v3subv4 = SUBF(t##m##3, t##m##4); \ | |||||
v5addv6 = ADDF(t##m##5, t##m##6); \ | |||||
v5subv6 = SUBF(t##m##5, t##m##6); \ | |||||
v##m##0 = ADDF(ADDF(ADDF(t##m##0, v1addv2), v3addv4), v5addv6); \ | |||||
v##m##1 = ADDF(ADDF(v1subv2, MULSF(v3subv4, 2.f)), MULSF(v5subv6, 0.5f)); \ | |||||
v##m##2 = ADDF(ADDF(v1addv2, MULSF(v3addv4, 4.f)), MULSF(v5addv6, 0.25f)); \ | |||||
v##m##3 = ADDF(ADDF(v1subv2, MULSF(v3subv4, 8.f)), MULSF(v5subv6, 0.125f)); \ | |||||
v##m##4 = ADDF(ADDF(v1addv2, MULSF(v3addv4, 16.f)), MULSF(v5addv6, 0.0625f)); \ | |||||
v##m##5 = \ | |||||
ADDF(ADDF(ADDF(v1subv2, MULSF(v3subv4, 32.f)), MULSF(v5subv6, 0.03125f)), \ | |||||
t##m##7); | |||||
UNROLL_CALL_NOWRAPPER(6, cb); | UNROLL_CALL_NOWRAPPER(6, cb); | ||||
#undef cb | #undef cb | ||||
Vector<float, 4> vbias; | |||||
GI_FLOAT32_t vbias; | |||||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | ||||
vbias = Vector<float, 4>::load(bias + oc); | |||||
vbias = GiLoadFloat32(bias + oc); | |||||
#define cb(m, n) v##m##n += vbias; | |||||
#define cb(m, n) v##m##n = ADDF(v##m##n, vbias); | |||||
UNROLL_CALL_RAW_D2(6, 6, cb); | UNROLL_CALL_RAW_D2(6, 6, cb); | ||||
#undef cb | #undef cb | ||||
} | } | ||||
if (bmode != BiasMode::BIAS) { | if (bmode != BiasMode::BIAS) { | ||||
#define cb(m, n) v##m##n = op(GiFixLenType2GiFloat32Type(CONCAT(v##m, n).value)); | |||||
#define cb(m, n) v##m##n = op(CONCAT(v##m, n)); | |||||
UNROLL_CALL_RAW_D2(6, 6, cb); | UNROLL_CALL_RAW_D2(6, 6, cb); | ||||
#undef cb | #undef cb | ||||
} | } | ||||
@@ -335,12 +366,15 @@ struct OutputTransformF63_NCHW44 { | |||||
size_t ow = ow_start + owo; \ | size_t ow = ow_start + owo; \ | ||||
if (oh < OH && ow < OW) { \ | if (oh < OH && ow < OW) { \ | ||||
if (bmode == BiasMode::BIAS) { \ | if (bmode == BiasMode::BIAS) { \ | ||||
v##oho##owo += Vector<float, 4>::load( \ | |||||
bias + oc * OH * OW + oh * OW * pack_size + ow * pack_size); \ | |||||
v##oho##owo = op(GiFixLenType2GiFloat32Type(v##oho##owo.value)); \ | |||||
v##oho##owo = ADDF( \ | |||||
v##oho##owo, GiLoadFloat32( \ | |||||
bias + oc * OH * OW + \ | |||||
oh * OW * pack_size + ow * pack_size)); \ | |||||
v##oho##owo = op(v##oho##owo); \ | |||||
} \ | } \ | ||||
v##oho##owo.save( \ | |||||
output + oc * OH * OW + oh * OW * pack_size + ow * pack_size); \ | |||||
GiStoreFloat32( \ | |||||
output + oc * OH * OW + oh * OW * pack_size + ow * pack_size, \ | |||||
v##oho##owo); \ | |||||
} \ | } \ | ||||
} while (0); | } while (0); | ||||
UNROLL_CALL_RAW_D2(6, 6, out_save); | UNROLL_CALL_RAW_D2(6, 6, out_save); | ||||
@@ -387,35 +421,52 @@ void winograd_F63_mk4_f_nchw44::filter( | |||||
pack_size * pack_size + | pack_size * pack_size + | ||||
ic_inner * pack_size; | ic_inner * pack_size; | ||||
#define cb(m, n) \ | |||||
Vector<float, 4> g##m##n = Vector<float, 4>::load( \ | |||||
fptr + (m * KERNEL_SIZE + n) * pack_size * pack_size); | |||||
#define cb(m, n) \ | |||||
GI_FLOAT32_t g##m##n = \ | |||||
GiLoadFloat32(fptr + (m * KERNEL_SIZE + n) * pack_size * pack_size); | |||||
UNROLL_CALL_NOWRAPPER_D2(3, 3, cb) | UNROLL_CALL_NOWRAPPER_D2(3, 3, cb) | ||||
#undef cb | #undef cb | ||||
#define FILTER_TRANSFORM(n, wd, g) \ | |||||
auto wd##n##0 = g##0##n; \ | |||||
tmp0 = (g##0##n + g##2##n) * -0.2222222f; \ | |||||
tmp1 = g##1##n * -0.2222222f; \ | |||||
auto wd##n##1 = tmp0 + tmp1; \ | |||||
auto wd##n##2 = tmp0 - tmp1; \ | |||||
tmp0 = g##0##n * 0.0111111f + g##2##n * 0.0444444f; \ | |||||
tmp1 = g##1##n * 0.0222222f; \ | |||||
auto wd##n##3 = tmp0 + tmp1; \ | |||||
auto wd##n##4 = tmp0 - tmp1; \ | |||||
tmp0 = g##0##n * 0.7111111f + g##2##n * 0.1777778f; \ | |||||
tmp1 = g##1##n * 0.3555556f; \ | |||||
auto wd##n##5 = tmp0 + tmp1; \ | |||||
auto wd##n##6 = tmp0 - tmp1; \ | |||||
/* | |||||
* auto wd##n##0 = g##0##n; | |||||
* tmp0 = (g##0##n + g##2##n) * -0.2222222f; | |||||
* tmp1 = g##1##n * -0.2222222f; | |||||
* auto wd##n##1 = tmp0 + tmp1; | |||||
* auto wd##n##2 = tmp0 - tmp1; | |||||
* tmp0 = g##0##n * 0.0111111f + g##2##n * 0.0444444f; | |||||
* tmp1 = g##1##n * 0.0222222f; | |||||
* auto wd##n##3 = tmp0 + tmp1; | |||||
* auto wd##n##4 = tmp0 - tmp1; | |||||
* tmp0 = g##0##n * 0.7111111f + g##2##n * 0.1777778f; | |||||
* tmp1 = g##1##n * 0.3555556f; | |||||
* auto wd##n##5 = tmp0 + tmp1; | |||||
* auto wd##n##6 = tmp0 - tmp1; | |||||
* auto wd##n##7 = g##2##n; | |||||
*/ | |||||
#define FILTER_TRANSFORM(n, wd, g) \ | |||||
auto wd##n##0 = g##0##n; \ | |||||
tmp0 = MULSF(ADDF(g##0##n, g##2##n), -0.2222222f); \ | |||||
tmp1 = MULSF(g##1##n, -0.2222222f); \ | |||||
auto wd##n##1 = ADDF(tmp0, tmp1); \ | |||||
auto wd##n##2 = SUBF(tmp0, tmp1); \ | |||||
tmp0 = ADDF(MULSF(g##0##n, 0.0111111f), MULSF(g##2##n, 0.0444444f)); \ | |||||
tmp1 = MULSF(g##1##n, 0.0222222f); \ | |||||
auto wd##n##3 = ADDF(tmp0, tmp1); \ | |||||
auto wd##n##4 = SUBF(tmp0, tmp1); \ | |||||
tmp0 = ADDF(MULSF(g##0##n, 0.7111111f), MULSF(g##2##n, 0.1777778f)); \ | |||||
tmp1 = MULSF(g##1##n, 0.3555556f); \ | |||||
auto wd##n##5 = ADDF(tmp0, tmp1); \ | |||||
auto wd##n##6 = SUBF(tmp0, tmp1); \ | |||||
auto wd##n##7 = g##2##n; | auto wd##n##7 = g##2##n; | ||||
Vector<float, 4> tmp0, tmp1; | |||||
GI_FLOAT32_t tmp0, tmp1; | |||||
UNROLL_CALL_RAW(3, FILTER_TRANSFORM, wd, g); | UNROLL_CALL_RAW(3, FILTER_TRANSFORM, wd, g); | ||||
UNROLL_CALL_RAW(8, FILTER_TRANSFORM, ret, wd); | UNROLL_CALL_RAW(8, FILTER_TRANSFORM, ret, wd); | ||||
#undef FILTER_TRANSFORM | #undef FILTER_TRANSFORM | ||||
#define cb_save(m, n) \ | #define cb_save(m, n) \ | ||||
ret##m##n.save( \ | |||||
GiStoreFloat32( \ | |||||
filter_transform_buf + (m * alpha + n) * OC * IC + ocb * IC * pack_size + \ | filter_transform_buf + (m * alpha + n) * OC * IC + ocb * IC * pack_size + \ | ||||
icb * pack_size * pack_size + ic_inner * pack_size); | |||||
icb * pack_size * pack_size + ic_inner * pack_size, \ | |||||
ret##m##n); | |||||
UNROLL_CALL_NOWRAPPER_D2(8, 8, cb_save) | UNROLL_CALL_NOWRAPPER_D2(8, 8, cb_save) | ||||
#undef cb_save | #undef cb_save | ||||
} | } | ||||
@@ -4,7 +4,6 @@ | |||||
#include "src/fallback/conv_bias/gi/fp32/filter_transform.h" | #include "src/fallback/conv_bias/gi/fp32/filter_transform.h" | ||||
#include "src/fallback/conv_bias/gi/fp32/helper.h" | #include "src/fallback/conv_bias/gi/fp32/helper.h" | ||||
#include "src/fallback/conv_bias/gi/fp32/strategy.h" | #include "src/fallback/conv_bias/gi/fp32/strategy.h" | ||||
#include "src/fallback/conv_bias/gi/utils.h" | |||||
#include "src/fallback/conv_bias/winograd/winograd.h" | #include "src/fallback/conv_bias/winograd/winograd.h" | ||||
#include "src/fallback/elemwise_helper/op_unary.h" | #include "src/fallback/elemwise_helper/op_unary.h" | ||||
@@ -137,14 +136,14 @@ struct InputTransformF73_NCHW44 { | |||||
auto t##i##6 = d7; \ | auto t##i##6 = d7; \ | ||||
auto t##i##7 = d7; \ | auto t##i##7 = d7; \ | ||||
t##i##8 = MSUB(t##i##8, d7, v0, 0); \ | t##i##8 = MSUB(t##i##8, d7, v0, 0); \ | ||||
t##i##0 = GiSubtractFloat32(t##i##0, d1); \ | |||||
t##i##0 = SUBF(t##i##0, d1); \ | |||||
t##i##1 = MSUB(t##i##1, d1, v0, 0); \ | t##i##1 = MSUB(t##i##1, d1, v0, 0); \ | ||||
t##i##2 = MADD(t##i##2, d1, v0, 0); \ | t##i##2 = MADD(t##i##2, d1, v0, 0); \ | ||||
t##i##3 = MSUB(t##i##3, d1, v0, 1); \ | t##i##3 = MSUB(t##i##3, d1, v0, 1); \ | ||||
t##i##4 = MADD(t##i##4, d1, v0, 1); \ | t##i##4 = MADD(t##i##4, d1, v0, 1); \ | ||||
t##i##5 = MSUB(t##i##5, d1, v0, 2); \ | t##i##5 = MSUB(t##i##5, d1, v0, 2); \ | ||||
t##i##6 = MADD(t##i##6, d1, v0, 2); \ | t##i##6 = MADD(t##i##6, d1, v0, 2); \ | ||||
t##i##7 = GiSubtractFloat32(t##i##7, d1); \ | |||||
t##i##7 = SUBF(t##i##7, d1); \ | |||||
t##i##8 = MADD(t##i##8, d1, v0, 0); \ | t##i##8 = MADD(t##i##8, d1, v0, 0); \ | ||||
t##i##0 = MSUB(t##i##0, d2, v0, 3); \ | t##i##0 = MSUB(t##i##0, d2, v0, 3); \ | ||||
t##i##1 = MSUB(t##i##1, d2, v1, 0); \ | t##i##1 = MSUB(t##i##1, d2, v1, 0); \ | ||||
@@ -153,7 +152,7 @@ struct InputTransformF73_NCHW44 { | |||||
t##i##4 = MSUB(t##i##4, d2, v1, 3); \ | t##i##4 = MSUB(t##i##4, d2, v1, 3); \ | ||||
t##i##5 = MSUB(t##i##5, d2, v2, 0); \ | t##i##5 = MSUB(t##i##5, d2, v2, 0); \ | ||||
t##i##6 = MSUB(t##i##6, d2, v2, 1); \ | t##i##6 = MSUB(t##i##6, d2, v2, 1); \ | ||||
t##i##8 = GiSubtractFloat32(t##i##8, d2); \ | |||||
t##i##8 = SUBF(t##i##8, d2); \ | |||||
t##i##0 = MADD(t##i##0, d3, v2, 2); \ | t##i##0 = MADD(t##i##0, d3, v2, 2); \ | ||||
t##i##1 = MADD(t##i##1, d3, v2, 3); \ | t##i##1 = MADD(t##i##1, d3, v2, 3); \ | ||||
t##i##2 = MSUB(t##i##2, d3, v3, 0); \ | t##i##2 = MSUB(t##i##2, d3, v3, 0); \ | ||||
@@ -185,7 +184,7 @@ struct InputTransformF73_NCHW44 { | |||||
t##i##2 = MSUB(t##i##2, d6, v1, 1); \ | t##i##2 = MSUB(t##i##2, d6, v1, 1); \ | ||||
t##i##3 = MADD(t##i##3, d6, v1, 0); \ | t##i##3 = MADD(t##i##3, d6, v1, 0); \ | ||||
t##i##4 = MSUB(t##i##4, d6, v3, 1); \ | t##i##4 = MSUB(t##i##4, d6, v3, 1); \ | ||||
t##i##5 = GiSubtractFloat32(t##i##5, d6); \ | |||||
t##i##5 = SUBF(t##i##5, d6); \ | |||||
t##i##6 = MSUB(t##i##6, d6, v6, 2); \ | t##i##6 = MSUB(t##i##6, d6, v6, 2); \ | ||||
t##i##8 = MSUB(t##i##8, d6, v2, 2); \ | t##i##8 = MSUB(t##i##8, d6, v2, 2); \ | ||||
t##i##0 = MADD(t##i##0, d0, v0, 0); | t##i##0 = MADD(t##i##0, d0, v0, 0); | ||||
@@ -204,14 +203,14 @@ struct InputTransformF73_NCHW44 { | |||||
d6 = t7##i; \ | d6 = t7##i; \ | ||||
d7 = t7##i; \ | d7 = t7##i; \ | ||||
d8 = MSUB(d8, t7##i, v0, 0); \ | d8 = MSUB(d8, t7##i, v0, 0); \ | ||||
d0 = GiSubtractFloat32(d0, t1##i); \ | |||||
d0 = SUBF(d0, t1##i); \ | |||||
d1 = MSUB(d1, t1##i, v0, 0); \ | d1 = MSUB(d1, t1##i, v0, 0); \ | ||||
d2 = MADD(d2, t1##i, v0, 0); \ | d2 = MADD(d2, t1##i, v0, 0); \ | ||||
d3 = MSUB(d3, t1##i, v0, 1); \ | d3 = MSUB(d3, t1##i, v0, 1); \ | ||||
d4 = MADD(d4, t1##i, v0, 1); \ | d4 = MADD(d4, t1##i, v0, 1); \ | ||||
d5 = MSUB(d5, t1##i, v0, 2); \ | d5 = MSUB(d5, t1##i, v0, 2); \ | ||||
d6 = MADD(d6, t1##i, v0, 2); \ | d6 = MADD(d6, t1##i, v0, 2); \ | ||||
d7 = GiSubtractFloat32(d7, t1##i); \ | |||||
d7 = SUBF(d7, t1##i); \ | |||||
d8 = MADD(d8, t1##i, v0, 0); \ | d8 = MADD(d8, t1##i, v0, 0); \ | ||||
d0 = MSUB(d0, t2##i, v0, 3); \ | d0 = MSUB(d0, t2##i, v0, 3); \ | ||||
d1 = MSUB(d1, t2##i, v1, 0); \ | d1 = MSUB(d1, t2##i, v1, 0); \ | ||||
@@ -220,7 +219,7 @@ struct InputTransformF73_NCHW44 { | |||||
d4 = MSUB(d4, t2##i, v1, 3); \ | d4 = MSUB(d4, t2##i, v1, 3); \ | ||||
d5 = MSUB(d5, t2##i, v2, 0); \ | d5 = MSUB(d5, t2##i, v2, 0); \ | ||||
d6 = MSUB(d6, t2##i, v2, 1); \ | d6 = MSUB(d6, t2##i, v2, 1); \ | ||||
d8 = GiSubtractFloat32(d8, t2##i); \ | |||||
d8 = SUBF(d8, t2##i); \ | |||||
d0 = MADD(d0, t3##i, v2, 2); \ | d0 = MADD(d0, t3##i, v2, 2); \ | ||||
d1 = MADD(d1, t3##i, v2, 3); \ | d1 = MADD(d1, t3##i, v2, 3); \ | ||||
d2 = MSUB(d2, t3##i, v3, 0); \ | d2 = MSUB(d2, t3##i, v3, 0); \ | ||||
@@ -252,7 +251,7 @@ struct InputTransformF73_NCHW44 { | |||||
d2 = MSUB(d2, t6##i, v1, 1); \ | d2 = MSUB(d2, t6##i, v1, 1); \ | ||||
d3 = MADD(d3, t6##i, v1, 0); \ | d3 = MADD(d3, t6##i, v1, 0); \ | ||||
d4 = MSUB(d4, t6##i, v3, 1); \ | d4 = MSUB(d4, t6##i, v3, 1); \ | ||||
d5 = GiSubtractFloat32(d5, t6##i); \ | |||||
d5 = SUBF(d5, t6##i); \ | |||||
d6 = MSUB(d6, t6##i, v6, 2); \ | d6 = MSUB(d6, t6##i, v6, 2); \ | ||||
d8 = MSUB(d8, t6##i, v2, 2); \ | d8 = MSUB(d8, t6##i, v2, 2); \ | ||||
d0 = MADD(d0, t0##i, v0, 0); \ | d0 = MADD(d0, t0##i, v0, 0); \ | ||||
@@ -325,7 +324,7 @@ struct OutputTransformF73_NCHW44 { | |||||
size_t ocb = oc_index / pack_size; | size_t ocb = oc_index / pack_size; | ||||
#define cb(m, n) \ | #define cb(m, n) \ | ||||
auto v##m##n = Vector<float, 4>::load( \ | |||||
auto v##m##n = GiLoadFloat32( \ | |||||
output_transform_buf + \ | output_transform_buf + \ | ||||
(m * alpha + n) * OCB * nr_units_in_tile * pack_size + \ | (m * alpha + n) * OCB * nr_units_in_tile * pack_size + \ | ||||
ocb * nr_units_in_tile * pack_size + unit_idx * pack_size); | ocb * nr_units_in_tile * pack_size + unit_idx * pack_size); | ||||
@@ -346,56 +345,112 @@ struct OutputTransformF73_NCHW44 { | |||||
* 1 1.5 2.25 3.375 5.0625 7.59375 11.390625 | * 1 1.5 2.25 3.375 5.0625 7.59375 11.390625 | ||||
* 0 0 0 0 0 0 1 | * 0 0 0 0 0 0 1 | ||||
*/ | */ | ||||
/* | |||||
* v1addv2 = v1##m + v2##m; | |||||
* v1subv2 = v1##m - v2##m; | |||||
* v3addv4 = v3##m + v4##m; | |||||
* v3subv4 = v3##m - v4##m; | |||||
* v5addv6 = v5##m + v6##m; | |||||
* v5subv6 = v5##m - v6##m; | |||||
* auto t0##m = v0##m + v1addv2 + v3addv4 + v5addv6 + v7##m; | |||||
* auto t1##m = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f + v7##m * 1.5f; | |||||
* auto t2##m = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f + v7##m * 2.25f; | |||||
* auto t3##m = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f + v7##m * 3.375f; | |||||
* auto t4##m = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f + v7##m * 5.0625f; | |||||
* auto t5##m = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + v7##m | |||||
* * 7.59375f; auto t6##m = v1addv2 + v3addv4 * 64.f + v5addv6 * 0.015625f + | |||||
* v7##m * 11.390625f + v8##m; | |||||
*/ | |||||
Vector<float, 4> v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, v5subv6; | |||||
GI_FLOAT32_t v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, v5subv6; | |||||
#define cb(m) \ | #define cb(m) \ | ||||
v1addv2 = v1##m + v2##m; \ | |||||
v1subv2 = v1##m - v2##m; \ | |||||
v3addv4 = v3##m + v4##m; \ | |||||
v3subv4 = v3##m - v4##m; \ | |||||
v5addv6 = v5##m + v6##m; \ | |||||
v5subv6 = v5##m - v6##m; \ | |||||
auto t0##m = v0##m + v1addv2 + v3addv4 + v5addv6 + v7##m; \ | |||||
auto t1##m = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f + v7##m * 1.5f; \ | |||||
auto t2##m = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f + v7##m * 2.25f; \ | |||||
auto t3##m = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f + v7##m * 3.375f; \ | |||||
auto t4##m = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f + v7##m * 5.0625f; \ | |||||
auto t5##m = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + v7##m * 7.59375f; \ | |||||
auto t6##m = v1addv2 + v3addv4 * 64.f + v5addv6 * 0.015625f + v7##m * 11.390625f + \ | |||||
v8##m; | |||||
v1addv2 = ADDF(v1##m, v2##m); \ | |||||
v1subv2 = SUBF(v1##m, v2##m); \ | |||||
v3addv4 = ADDF(v3##m, v4##m); \ | |||||
v3subv4 = SUBF(v3##m, v4##m); \ | |||||
v5addv6 = ADDF(v5##m, v6##m); \ | |||||
v5subv6 = SUBF(v5##m, v6##m); \ | |||||
auto t0##m = ADDF(ADDF(ADDF(ADDF(v0##m, v1addv2), v3addv4), v5addv6), v7##m); \ | |||||
auto t1##m = \ | |||||
ADDF(ADDF(ADDF(v1subv2, MULSF(v3subv4, 2.f)), MULSF(v5subv6, 0.5f)), \ | |||||
MULSF(v7##m, 1.5f)); \ | |||||
auto t2##m = \ | |||||
ADDF(ADDF(ADDF(v1addv2, MULSF(v3addv4, 4.f)), MULSF(v5addv6, 0.25f)), \ | |||||
MULSF(v7##m, 2.25f)); \ | |||||
auto t3##m = \ | |||||
ADDF(ADDF(ADDF(v1subv2, MULSF(v3subv4, 8.f)), MULSF(v5subv6, 0.125f)), \ | |||||
MULSF(v7##m, 3.375f)); \ | |||||
auto t4##m = \ | |||||
ADDF(ADDF(ADDF(v1addv2, MULSF(v3addv4, 16.f)), MULSF(v5addv6, 0.0625f)), \ | |||||
MULSF(v7##m, 5.0625f)); \ | |||||
auto t5##m = \ | |||||
ADDF(ADDF(ADDF(v1subv2, MULSF(v3subv4, 32.f)), MULSF(v5subv6, 0.03125f)), \ | |||||
MULSF(v7##m, 7.59375f)); \ | |||||
auto t6##m = ADDF( \ | |||||
ADDF(ADDF(ADDF(v1addv2, MULSF(v3addv4, 64.f)), MULSF(v5addv6, 0.015625f)), \ | |||||
MULSF(v7##m, 11.390625f)), \ | |||||
v8##m); | |||||
UNROLL_CALL_NOWRAPPER(9, cb); | UNROLL_CALL_NOWRAPPER(9, cb); | ||||
#undef cb | #undef cb | ||||
#define cb(m) \ | |||||
v1addv2 = t##m##1 + t##m##2; \ | |||||
v1subv2 = t##m##1 - t##m##2; \ | |||||
v3addv4 = t##m##3 + t##m##4; \ | |||||
v3subv4 = t##m##3 - t##m##4; \ | |||||
v5addv6 = t##m##5 + t##m##6; \ | |||||
v5subv6 = t##m##5 - t##m##6; \ | |||||
v##m##0 = t##m##0 + v1addv2 + v3addv4 + v5addv6 + t##m##7; \ | |||||
v##m##1 = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f + t##m##7 * 1.5f; \ | |||||
v##m##2 = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f + t##m##7 * 2.25f; \ | |||||
v##m##3 = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f + t##m##7 * 3.375; \ | |||||
v##m##4 = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f + t##m##7 * 5.0625f; \ | |||||
v##m##5 = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + t##m##7 * 7.59375f; \ | |||||
v##m##6 = v1addv2 + v3addv4 * 64.f + v5addv6 * 0.015625f + t##m##7 * 11.390625f + \ | |||||
t##m##8; | |||||
/* | |||||
* v1addv2 = t##m##1 + t##m##2; | |||||
* v1subv2 = t##m##1 - t##m##2; | |||||
* v3addv4 = t##m##3 + t##m##4; | |||||
* v3subv4 = t##m##3 - t##m##4; | |||||
* v5addv6 = t##m##5 + t##m##6; | |||||
* v5subv6 = t##m##5 - t##m##6; | |||||
* v##m##0 = t##m##0 + v1addv2 + v3addv4 + v5addv6 + t##m##7; | |||||
* v##m##1 = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f + t##m##7 * 1.5f; | |||||
* v##m##2 = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f + t##m##7 * 2.25f; | |||||
* v##m##3 = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f + t##m##7 * 3.375; | |||||
* v##m##4 = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f + t##m##7 * 5.0625f; | |||||
* v##m##5 = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + t##m##7 * 7.59375f; | |||||
* v##m##6 = v1addv2 + v3addv4 * 64.f + v5addv6 * 0.015625f + t##m##7 | |||||
* * 11.390625f + t##m##8; | |||||
*/ | |||||
#define cb(m) \ | |||||
v1addv2 = ADDF(t##m##1, t##m##2); \ | |||||
v1subv2 = SUBF(t##m##1, t##m##2); \ | |||||
v3addv4 = ADDF(t##m##3, t##m##4); \ | |||||
v3subv4 = SUBF(t##m##3, t##m##4); \ | |||||
v5addv6 = ADDF(t##m##5, t##m##6); \ | |||||
v5subv6 = SUBF(t##m##5, t##m##6); \ | |||||
v##m##0 = ADDF(ADDF(ADDF(ADDF(t##m##0, v1addv2), v3addv4), v5addv6), t##m##7); \ | |||||
v##m##1 = \ | |||||
ADDF(ADDF(ADDF(v1subv2, MULSF(v3subv4, 2.f)), MULSF(v5subv6, 0.5f)), \ | |||||
MULSF(t##m##7, 1.5f)); \ | |||||
v##m##2 = \ | |||||
ADDF(ADDF(ADDF(v1addv2, MULSF(v3addv4, 4.f)), MULSF(v5addv6, 0.25f)), \ | |||||
MULSF(t##m##7, 2.25f)); \ | |||||
v##m##3 = \ | |||||
ADDF(ADDF(ADDF(v1subv2, MULSF(v3subv4, 8.f)), MULSF(v5subv6, 0.125f)), \ | |||||
MULSF(t##m##7, 3.375)); \ | |||||
v##m##4 = \ | |||||
ADDF(ADDF(ADDF(v1addv2, MULSF(v3addv4, 16.f)), MULSF(v5addv6, 0.0625f)), \ | |||||
MULSF(t##m##7, 5.0625f)); \ | |||||
v##m##5 = \ | |||||
ADDF(ADDF(ADDF(v1subv2, MULSF(v3subv4, 32.f)), MULSF(v5subv6, 0.03125f)), \ | |||||
MULSF(t##m##7, 7.59375f)); \ | |||||
v##m##6 = ADDF( \ | |||||
ADDF(ADDF(ADDF(v1addv2, MULSF(v3addv4, 64.f)), MULSF(v5addv6, 0.015625f)), \ | |||||
MULSF(t##m##7, 11.390625f)), \ | |||||
t##m##8); | |||||
UNROLL_CALL_NOWRAPPER(7, cb); | UNROLL_CALL_NOWRAPPER(7, cb); | ||||
#undef cb | #undef cb | ||||
Vector<float, 4> vbias; | |||||
GI_FLOAT32_t vbias; | |||||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | ||||
vbias = Vector<float, 4>::load(bias + oc); | |||||
vbias = GiLoadFloat32(bias + oc); | |||||
#define cb(m, n) v##m##n += vbias; | |||||
#define cb(m, n) v##m##n = ADDF(v##m##n, vbias); | |||||
UNROLL_CALL_RAW_D2(7, 7, cb); | UNROLL_CALL_RAW_D2(7, 7, cb); | ||||
#undef cb | #undef cb | ||||
} | } | ||||
if (bmode != BiasMode::BIAS) { | if (bmode != BiasMode::BIAS) { | ||||
#define cb(m, n) v##m##n = op(GiFixLenType2GiFloat32Type(CONCAT(v##m, n).value)); | |||||
#define cb(m, n) v##m##n = op(CONCAT(v##m, n)); | |||||
UNROLL_CALL_RAW_D2(7, 7, cb); | UNROLL_CALL_RAW_D2(7, 7, cb); | ||||
#undef cb | #undef cb | ||||
} | } | ||||
@@ -405,12 +460,15 @@ struct OutputTransformF73_NCHW44 { | |||||
size_t ow = ow_start + owo; \ | size_t ow = ow_start + owo; \ | ||||
if (oh < OH && ow < OW) { \ | if (oh < OH && ow < OW) { \ | ||||
if (bmode == BiasMode::BIAS) { \ | if (bmode == BiasMode::BIAS) { \ | ||||
v##oho##owo += Vector<float, 4>::load( \ | |||||
bias + oc * OH * OW + oh * OW * pack_size + ow * pack_size); \ | |||||
v##oho##owo = op(GiFixLenType2GiFloat32Type(v##oho##owo.value)); \ | |||||
v##oho##owo = ADDF( \ | |||||
v##oho##owo, GiLoadFloat32( \ | |||||
bias + oc * OH * OW + \ | |||||
oh * OW * pack_size + ow * pack_size)); \ | |||||
v##oho##owo = op(v##oho##owo); \ | |||||
} \ | } \ | ||||
v##oho##owo.save( \ | |||||
output + oc * OH * OW + oh * OW * pack_size + ow * pack_size); \ | |||||
GiStoreFloat32( \ | |||||
output + oc * OH * OW + oh * OW * pack_size + ow * pack_size, \ | |||||
v##oho##owo); \ | |||||
} \ | } \ | ||||
} while (0); | } while (0); | ||||
UNROLL_CALL_RAW_D2(7, 7, out_save); | UNROLL_CALL_RAW_D2(7, 7, out_save); | ||||
@@ -458,34 +516,53 @@ void winograd_F73_mk4_f_nchw44::filter( | |||||
pack_size * pack_size + | pack_size * pack_size + | ||||
ic_inner * pack_size; | ic_inner * pack_size; | ||||
#define cb(m, n) \ | |||||
Vector<float, 4> g##m##n = Vector<float, 4>::load( \ | |||||
fptr + (m * KERNEL_SIZE + n) * pack_size * pack_size); | |||||
#define cb(m, n) \ | |||||
GI_FLOAT32_t g##m##n = \ | |||||
GiLoadFloat32(fptr + (m * KERNEL_SIZE + n) * pack_size * pack_size); | |||||
UNROLL_CALL_NOWRAPPER_D2(3, 3, cb) | UNROLL_CALL_NOWRAPPER_D2(3, 3, cb) | ||||
#undef cb | #undef cb | ||||
#define FILTER_TRANSFORM(n, wd, g) \ | |||||
auto wd##n##0 = g##0##n * 0.6666667f; \ | |||||
auto wd##n##1 = (g##0##n + g##1##n + g##2##n) * 0.4444444f; \ | |||||
auto wd##n##2 = (g##0##n - g##1##n + g##2##n) * 0.0888889f; \ | |||||
auto wd##n##3 = \ | |||||
g##0##n * 0.0222222f + g##1##n * 0.0444444f + g##2##n * 0.0888889f; \ | |||||
auto wd##n##4 = \ | |||||
g##0##n * -0.0031746f + g##1##n * 0.0063492f + g##2##n * -0.0126984f; \ | |||||
auto wd##n##5 = \ | |||||
g##0##n * -0.7111111f + g##1##n * -0.3555556f + g##2##n * -0.1777778f; \ | |||||
auto wd##n##6 = \ | |||||
g##0##n * -0.3555556f + g##1##n * 0.1777778f + g##2##n * -0.0888889f; \ | |||||
auto wd##n##7 = \ | |||||
g##0##n * -0.1523810f + g##1##n * -0.2285714f + g##2##n * -0.3428572f; \ | |||||
/*
 * auto wd##n##0 = g##0##n * 0.6666667f;
 * auto wd##n##1 = (g##0##n + g##1##n + g##2##n) * 0.4444444f;
 * auto wd##n##2 = (g##0##n - g##1##n + g##2##n) * 0.0888889f;
 * auto wd##n##3 =
 *         g##0##n * 0.0222222f + g##1##n * 0.0444444f + g##2##n * 0.0888889f;
 * auto wd##n##4 =
 *         g##0##n * -0.0031746f + g##1##n * 0.0063492f + g##2##n * -0.0126984f;
 * auto wd##n##5 =
 *         g##0##n * -0.7111111f + g##1##n * -0.3555556f + g##2##n * -0.1777778f;
 * auto wd##n##6 =
 *         g##0##n * -0.3555556f + g##1##n * 0.1777778f + g##2##n * -0.0888889f;
 * auto wd##n##7 =
 *         g##0##n * -0.1523810f + g##1##n * -0.2285714f + g##2##n * -0.3428572f;
 */
#define FILTER_TRANSFORM(n, wd, g) \ | |||||
auto wd##n##0 = MULSF(g##0##n, 0.6666667f); \ | |||||
auto wd##n##1 = MULSF(ADDF(ADDF(g##0##n, g##1##n), g##2##n), 0.4444444f); \ | |||||
auto wd##n##2 = MULSF(ADDF(SUBF(g##0##n, g##1##n), g##2##n), 0.0888889f); \ | |||||
auto wd##n##3 = \ | |||||
ADDF(ADDF(MULSF(g##0##n, 0.0222222f), MULSF(g##1##n, 0.0444444f)), \ | |||||
MULSF(g##2##n, 0.0888889f)); \ | |||||
auto wd##n##4 = \ | |||||
ADDF(ADDF(MULSF(g##0##n, -0.0031746f), MULSF(g##1##n, 0.0063492f)), \ | |||||
MULSF(g##2##n, -0.0126984f)); \ | |||||
auto wd##n##5 = \ | |||||
ADDF(ADDF(MULSF(g##0##n, -0.7111111f), MULSF(g##1##n, -0.3555556f)), \ | |||||
MULSF(g##2##n, -0.1777778f)); \ | |||||
auto wd##n##6 = \ | |||||
ADDF(ADDF(MULSF(g##0##n, -0.3555556f), MULSF(g##1##n, 0.1777778f)), \ | |||||
MULSF(g##2##n, -0.0888889f)); \ | |||||
auto wd##n##7 = \ | |||||
ADDF(ADDF(MULSF(g##0##n, -0.1523810f), MULSF(g##1##n, -0.2285714f)), \ | |||||
MULSF(g##2##n, -0.3428572f)); \ | |||||
auto wd##n##8 = g##2##n; | auto wd##n##8 = g##2##n; | ||||
UNROLL_CALL_RAW(3, FILTER_TRANSFORM, wd, g); | UNROLL_CALL_RAW(3, FILTER_TRANSFORM, wd, g); | ||||
UNROLL_CALL_RAW(9, FILTER_TRANSFORM, ret, wd); | UNROLL_CALL_RAW(9, FILTER_TRANSFORM, ret, wd); | ||||
#undef FILTER_TRANSFORM | #undef FILTER_TRANSFORM | ||||
#define cb_save(m, n) \ | #define cb_save(m, n) \ | ||||
ret##m##n.save( \ | |||||
GiStoreFloat32( \ | |||||
filter_transform_buf + (m * alpha + n) * OC * IC + ocb * IC * pack_size + \ | filter_transform_buf + (m * alpha + n) * OC * IC + ocb * IC * pack_size + \ | ||||
icb * pack_size * pack_size + ic_inner * pack_size); | |||||
icb * pack_size * pack_size + ic_inner * pack_size, \ | |||||
ret##m##n); | |||||
UNROLL_CALL_NOWRAPPER_D2(9, 9, cb_save) | UNROLL_CALL_NOWRAPPER_D2(9, 9, cb_save) | ||||
#undef cb_save | #undef cb_save | ||||
} | } | ||||
@@ -1,215 +0,0 @@ | |||||
#pragma once | |||||
#include <cstring> | |||||
#include "src/common/utils.h" | |||||
#include "src/fallback/general_intrinsic/gi_float.h" | |||||
namespace megdnn { | |||||
namespace fallback { | |||||
template <typename ctype, size_t len> | |||||
struct Vector; | |||||
template <> | |||||
struct Vector<float, 4> { | |||||
GI_FLOAT32_FIXLEN_t value; | |||||
Vector() {} | |||||
Vector(const float v) { value = GiFloat32Type2FixLenType(GiBroadcastFloat32(v)); } | |||||
Vector(const Vector& lr) { value = lr.value; } | |||||
Vector(const Vector&& lr) { value = std::move(lr.value); } | |||||
Vector(const GI_FLOAT32_t& v) { value = GiFloat32Type2FixLenType(v); } | |||||
static Vector load(const float* addr) { | |||||
Vector v; | |||||
v.value = GiFloat32Type2FixLenType(GiLoadFloat32(addr)); | |||||
return v; | |||||
} | |||||
static void save(float* addr, const Vector& v) { | |||||
GiStoreFloat32(addr, GiFixLenType2GiFloat32Type(v.value)); | |||||
} | |||||
void save(float* addr) { save(addr, *this); } | |||||
Vector operator+(const Vector& lr) { | |||||
Vector dst; | |||||
dst.value = GiFloat32Type2FixLenType(GiAddFloat32( | |||||
GiFixLenType2GiFloat32Type(value), | |||||
GiFixLenType2GiFloat32Type(lr.value))); | |||||
return dst; | |||||
} | |||||
Vector& operator+=(const Vector& lr) { | |||||
value = GiFloat32Type2FixLenType(GiAddFloat32( | |||||
GiFixLenType2GiFloat32Type(value), | |||||
GiFixLenType2GiFloat32Type(lr.value))); | |||||
return *this; | |||||
} | |||||
Vector operator-(const Vector& lr) { | |||||
Vector dst; | |||||
dst.value = GiFloat32Type2FixLenType(GiSubtractFloat32( | |||||
GiFixLenType2GiFloat32Type(value), | |||||
GiFixLenType2GiFloat32Type(lr.value))); | |||||
return dst; | |||||
} | |||||
Vector& operator-=(const Vector& lr) { | |||||
value = GiFloat32Type2FixLenType(GiSubtractFloat32( | |||||
GiFixLenType2GiFloat32Type(value), | |||||
GiFixLenType2GiFloat32Type(lr.value))); | |||||
return *this; | |||||
} | |||||
Vector operator*(float lr) { | |||||
Vector dst; | |||||
dst.value = GiFloat32Type2FixLenType( | |||||
GiMultiplyScalerFloat32(GiFixLenType2GiFloat32Type(value), lr)); | |||||
return dst; | |||||
} | |||||
Vector operator*(const Vector& lr) { | |||||
Vector dst; | |||||
dst.value = GiFloat32Type2FixLenType(GiMultiplyFloat32( | |||||
GiFixLenType2GiFloat32Type(value), | |||||
GiFixLenType2GiFloat32Type(lr.value))); | |||||
return dst; | |||||
} | |||||
Vector& operator*=(const Vector& lr) { | |||||
value = GiFloat32Type2FixLenType(GiMultiplyFloat32( | |||||
GiFixLenType2GiFloat32Type(value), | |||||
GiFixLenType2GiFloat32Type(lr.value))); | |||||
return *this; | |||||
} | |||||
Vector& operator=(const Vector& lr) { | |||||
value = lr.value; | |||||
return *this; | |||||
} | |||||
Vector& operator=(const Vector&& lr) { | |||||
value = std::move(lr.value); | |||||
return *this; | |||||
} | |||||
Vector operator-() { | |||||
Vector dst; | |||||
dst.value = -value; | |||||
return dst; | |||||
} | |||||
}; | |||||
template <> | |||||
struct Vector<float, 8> { | |||||
GI_FLOAT32_FIXLEN_V2_t value; | |||||
Vector() {} | |||||
Vector(const float v) { | |||||
value.val[0] = GiFloat32Type2FixLenType(GiBroadcastFloat32(v)); | |||||
value.val[1] = GiFloat32Type2FixLenType(GiBroadcastFloat32(v)); | |||||
} | |||||
Vector(const Vector& lr) { value = lr.value; } | |||||
Vector(const Vector&& lr) { value = std::move(lr.value); } | |||||
Vector(const GI_FLOAT32_V2_t& v) { value = GiFloat32Type2FixLenV2Type(v); } | |||||
static Vector load(const float* addr) { | |||||
Vector v; | |||||
v.value = GiFloat32Type2FixLenV2Type(GiLoadFloat32V2(addr)); | |||||
return v; | |||||
} | |||||
static void save(float* addr, const Vector& v) { | |||||
GiStoreFloat32V2(addr, GiFixLenType2GiFloat32V2Type(v.value)); | |||||
} | |||||
void save(float* addr) { save(addr, *this); } | |||||
Vector operator+(const Vector& lr) { | |||||
Vector dst; | |||||
dst.value.val[0] = GiFloat32Type2FixLenType(GiAddFloat32( | |||||
GiFixLenType2GiFloat32Type(value.val[0]), | |||||
GiFixLenType2GiFloat32Type(lr.value.val[0]))); | |||||
dst.value.val[1] = GiFloat32Type2FixLenType(GiAddFloat32( | |||||
GiFixLenType2GiFloat32Type(value.val[1]), | |||||
GiFixLenType2GiFloat32Type(lr.value.val[1]))); | |||||
return dst; | |||||
} | |||||
Vector& operator+=(const Vector& lr) { | |||||
value.val[0] = GiFloat32Type2FixLenType(GiAddFloat32( | |||||
GiFixLenType2GiFloat32Type(value.val[0]), | |||||
GiFixLenType2GiFloat32Type(lr.value.val[0]))); | |||||
value.val[1] = GiFloat32Type2FixLenType(GiAddFloat32( | |||||
GiFixLenType2GiFloat32Type(value.val[1]), | |||||
GiFixLenType2GiFloat32Type(lr.value.val[1]))); | |||||
return *this; | |||||
} | |||||
Vector& add(const Vector& lr) { | |||||
value.val[0] = GiFloat32Type2FixLenType(GiAddFloat32( | |||||
GiFixLenType2GiFloat32Type(value.val[0]), | |||||
GiFixLenType2GiFloat32Type(lr.value.val[0]))); | |||||
value.val[1] = GiFloat32Type2FixLenType(GiAddFloat32( | |||||
GiFixLenType2GiFloat32Type(value.val[1]), | |||||
GiFixLenType2GiFloat32Type(lr.value.val[1]))); | |||||
return *this; | |||||
} | |||||
Vector operator-(const Vector& lr) { | |||||
Vector dst; | |||||
dst.value.val[0] = GiFloat32Type2FixLenType(GiSubtractFloat32( | |||||
GiFixLenType2GiFloat32Type(value.val[0]), | |||||
GiFixLenType2GiFloat32Type(lr.value.val[0]))); | |||||
dst.value.val[1] = GiFloat32Type2FixLenType(GiSubtractFloat32( | |||||
GiFixLenType2GiFloat32Type(value.val[1]), | |||||
GiFixLenType2GiFloat32Type(lr.value.val[1]))); | |||||
return dst; | |||||
} | |||||
Vector& operator-=(const Vector& lr) { | |||||
value.val[0] = GiFloat32Type2FixLenType(GiSubtractFloat32( | |||||
GiFixLenType2GiFloat32Type(value.val[0]), | |||||
GiFixLenType2GiFloat32Type(lr.value.val[0]))); | |||||
value.val[1] = GiFloat32Type2FixLenType(GiSubtractFloat32( | |||||
GiFixLenType2GiFloat32Type(value.val[1]), | |||||
GiFixLenType2GiFloat32Type(lr.value.val[1]))); | |||||
return *this; | |||||
} | |||||
Vector operator*(float lr) { | |||||
Vector dst; | |||||
dst.value.val[0] = GiFloat32Type2FixLenType( | |||||
GiMultiplyScalerFloat32(GiFixLenType2GiFloat32Type(value.val[0]), lr)); | |||||
dst.value.val[1] = GiFloat32Type2FixLenType( | |||||
GiMultiplyScalerFloat32(GiFixLenType2GiFloat32Type(value.val[1]), lr)); | |||||
return dst; | |||||
} | |||||
//! val + lr * n | |||||
Vector& mla(const Vector& lr, float n) { | |||||
value.val[0] = GiFloat32Type2FixLenType(GiMultiplyAddScalarFloat32( | |||||
GiFixLenType2GiFloat32Type(value.val[0]), | |||||
GiFixLenType2GiFloat32Type(lr.value.val[0]), n)); | |||||
value.val[1] = GiFloat32Type2FixLenType(GiMultiplyAddScalarFloat32( | |||||
GiFixLenType2GiFloat32Type(value.val[1]), | |||||
GiFixLenType2GiFloat32Type(lr.value.val[1]), n)); | |||||
return *this; | |||||
} | |||||
Vector operator*(const Vector& lr) { | |||||
Vector dst; | |||||
dst.value.val[0] = GiFloat32Type2FixLenType(GiMultiplyFloat32( | |||||
GiFixLenType2GiFloat32Type(value.val[0]), | |||||
GiFixLenType2GiFloat32Type(lr.value.val[0]))); | |||||
dst.value.val[1] = GiFloat32Type2FixLenType(GiMultiplyFloat32( | |||||
GiFixLenType2GiFloat32Type(value.val[1]), | |||||
GiFixLenType2GiFloat32Type(lr.value.val[1]))); | |||||
return dst; | |||||
} | |||||
Vector& operator*=(const Vector& lr) { | |||||
value.val[0] = GiFloat32Type2FixLenType(GiMultiplyFloat32( | |||||
GiFixLenType2GiFloat32Type(value.val[0]), | |||||
GiFixLenType2GiFloat32Type(lr.value.val[0]))); | |||||
value.val[1] = GiFloat32Type2FixLenType(GiMultiplyFloat32( | |||||
GiFixLenType2GiFloat32Type(value.val[1]), | |||||
GiFixLenType2GiFloat32Type(lr.value.val[1]))); | |||||
return *this; | |||||
} | |||||
Vector& operator=(const Vector& lr) { | |||||
value = lr.value; | |||||
return *this; | |||||
} | |||||
Vector& operator=(const Vector&& lr) { | |||||
value = std::move(lr.value); | |||||
return *this; | |||||
} | |||||
Vector operator-() { | |||||
Vector dst; | |||||
dst.value.val[0] = -value.val[0]; | |||||
dst.value.val[1] = -value.val[1]; | |||||
return dst; | |||||
} | |||||
}; | |||||
} // namespace fallback | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -857,6 +857,18 @@ GI_FLOAT32_t GiMultiplyScalerFloat32(GI_FLOAT32_t Vector1, float Scaler) { | |||||
} | } | ||||
//! Two-register variant of GiMultiplyScalerFloat32: multiplies sub-vector 0
//! and sub-vector 1 of Vector1 by Scaler and packs both products into the
//! returned V2 value.
GI_FORCEINLINE
GI_FLOAT32_V2_t GiMultiplyScalerFloat32V2(GI_FLOAT32_V2_t Vector1, float Scaler) {
    GI_FLOAT32_V2_t ret;
    GI_FLOAT32_t scaled0 =
            GiMultiplyScalerFloat32(GiGetSubVectorFloat32V2(Vector1, 0), Scaler);
    GI_FLOAT32_t scaled1 =
            GiMultiplyScalerFloat32(GiGetSubVectorFloat32V2(Vector1, 1), Scaler);
    GiSetSubVectorFloat32V2(ret, 0, scaled0);
    GiSetSubVectorFloat32V2(ret, 1, scaled1);
    return ret;
}
GI_FORCEINLINE | |||||
GI_FLOAT32_t GiMultiplyAddFloat32( | GI_FLOAT32_t GiMultiplyAddFloat32( | ||||
GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { | GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { | ||||
#if defined(GI_NEON_INTRINSICS) | #if defined(GI_NEON_INTRINSICS) | ||||
@@ -888,6 +900,23 @@ GI_FLOAT32_t GiMultiplyAddScalarFloat32( | |||||
} | } | ||||
//! Two-register variant of GiMultiplyAddScalarFloat32: applies it to
//! sub-vector 0 and sub-vector 1 of (VectorSum, Vector) with the same Scalar
//! and packs both results into the returned V2 value.
GI_FORCEINLINE
GI_FLOAT32_V2_t GiMultiplyAddScalarFloat32V2(
        GI_FLOAT32_V2_t VectorSum, GI_FLOAT32_V2_t Vector, float Scalar) {
    GI_FLOAT32_V2_t ret;
    GI_FLOAT32_t acc0 = GiMultiplyAddScalarFloat32(
            GiGetSubVectorFloat32V2(VectorSum, 0), GiGetSubVectorFloat32V2(Vector, 0),
            Scalar);
    GI_FLOAT32_t acc1 = GiMultiplyAddScalarFloat32(
            GiGetSubVectorFloat32V2(VectorSum, 1), GiGetSubVectorFloat32V2(Vector, 1),
            Scalar);
    GiSetSubVectorFloat32V2(ret, 0, acc0);
    GiSetSubVectorFloat32V2(ret, 1, acc1);
    return ret;
}
GI_FORCEINLINE | |||||
GI_FLOAT32_t GiMultiplySubScalarFloat32( | GI_FLOAT32_t GiMultiplySubScalarFloat32( | ||||
GI_FLOAT32_t VectorSub, GI_FLOAT32_t Vector, float Scalar) { | GI_FLOAT32_t VectorSub, GI_FLOAT32_t Vector, float Scalar) { | ||||
#if defined(GI_NEON_INTRINSICS) | #if defined(GI_NEON_INTRINSICS) | ||||
@@ -951,6 +980,26 @@ GI_FLOAT32_t GiDivideFloat32(GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { | |||||
#endif | #endif | ||||
} | } | ||||
//! OPV2(f) defines f##V2, the two-register counterpart of the binary
//! GI_FLOAT32_t intrinsic f: f is applied to sub-vector 0 and sub-vector 1 of
//! the two operands and both results are packed into the returned V2 value.
#define OPV2(gi_binary_op)                                                      \
    GI_FORCEINLINE                                                              \
    GI_FLOAT32_V2_t gi_binary_op##V2(                                           \
            GI_FLOAT32_V2_t Vector1, GI_FLOAT32_V2_t Vector2) {                 \
        GI_FLOAT32_V2_t ret;                                                    \
        GI_FLOAT32_t part0 = gi_binary_op(                                      \
                GiGetSubVectorFloat32V2(Vector1, 0),                            \
                GiGetSubVectorFloat32V2(Vector2, 0));                           \
        GI_FLOAT32_t part1 = gi_binary_op(                                      \
                GiGetSubVectorFloat32V2(Vector1, 1),                            \
                GiGetSubVectorFloat32V2(Vector2, 1));                           \
        GiSetSubVectorFloat32V2(ret, 0, part0);                                 \
        GiSetSubVectorFloat32V2(ret, 1, part1);                                 \
        return ret;                                                             \
    }
OPV2(GiAddFloat32);
OPV2(GiSubtractFloat32);
OPV2(GiMultiplyFloat32);
OPV2(GiDivideFloat32);
#undef OPV2
GI_FORCEINLINE | GI_FORCEINLINE | ||||
GI_FLOAT32_t GiRecpeSFloat32(GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { | GI_FLOAT32_t GiRecpeSFloat32(GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { | ||||
#if defined(GI_NEON64_INTRINSICS) | #if defined(GI_NEON64_INTRINSICS) | ||||