GitOrigin-RevId: fce5103088
HuaHua404-patch-4
@@ -1,7 +1,6 @@ | |||
#include "src/fallback/conv_bias/gi/fp32/channel_wise_3x3_s1p1_nchw44_kern.h" | |||
#include "src/common/utils.h" | |||
#include "src/fallback/conv_bias/common.h" | |||
#include "src/fallback/conv_bias/gi/utils.h" | |||
#include "src/fallback/elemwise_helper/elemwise_op.h" | |||
#pragma GCC diagnostic ignored "-Wunused-parameter" | |||
@@ -1,7 +1,6 @@ | |||
#include "src/fallback/conv_bias/gi/fp32/channel_wise_5x5_s1p2_nchw44_kern.h" | |||
#include "src/common/utils.h" | |||
#include "src/fallback/conv_bias/common.h" | |||
#include "src/fallback/conv_bias/gi/utils.h" | |||
#include "src/fallback/elemwise_helper/elemwise_op.h" | |||
#pragma GCC diagnostic ignored "-Wunused-parameter" | |||
@@ -4,7 +4,6 @@ | |||
#include "src/fallback/conv_bias/common.h" | |||
#include "src/fallback/conv_bias/gi/fp32/channel_wise_3x3_s1p1_nchw44_kern.h" | |||
#include "src/fallback/conv_bias/gi/fp32/channel_wise_5x5_s1p2_nchw44_kern.h" | |||
#include "src/fallback/conv_bias/gi/utils.h" | |||
#include "src/fallback/elemwise_helper/elemwise_op.h" | |||
using namespace megdnn; | |||
@@ -3,29 +3,44 @@ | |||
#include "src/common/unroll_macro.h" | |||
#include "src/common/utils.h" | |||
#include "src/fallback/conv_bias/gi/fp32/helper.h" | |||
#include "src/fallback/conv_bias/gi/utils.h" | |||
namespace megdnn { | |||
namespace fallback { | |||
/* | |||
* wd##0 = d##0; | |||
* tmp0 = (d##0 + d##2) * -0.2222222f | |||
* tmp1 = d##1 * -0.2222222 | |||
* wd##1 = tmp0 + tmp1 | |||
* wd##2 = tmp0 - tmp1 | |||
* tmp0 = d##0 * 0.0111111f + d##2 * 0.0444444f | |||
* tmp1 = d##1 * 0.0222222f | |||
* wd##3 = tmp0 + tmp1 | |||
* wd##4 = tmp0 - tmp1 | |||
* tmp0 = d##0 * 0.7111111f + d##2 * 0.1777778f | |||
* tmp1 = d##1 * 0.3555556f | |||
* wd##5 = tmp0 + tmp1 | |||
* wd##6 = tmp0 - tmp1 | |||
* wd##7 = d##2 | |||
*/ | |||
template <param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT> | |||
struct FilterTransform6X3 { | |||
#define FILTER_TRANSFORM(d, wd) \ | |||
do { \ | |||
wd##0 = d##0; \ | |||
auto tmp0 = (d##0 + d##2) * -0.2222222f; \ | |||
auto tmp1 = d##1 * -0.2222222f; \ | |||
wd##1 = tmp0 + tmp1; \ | |||
wd##2 = tmp0 - tmp1; \ | |||
tmp0 = d##0 * 0.0111111f + d##2 * 0.0444444f; \ | |||
tmp1 = d##1 * 0.0222222f; \ | |||
wd##3 = tmp0 + tmp1; \ | |||
wd##4 = tmp0 - tmp1; \ | |||
tmp0 = d##0 * 0.7111111f + d##2 * 0.1777778f; \ | |||
tmp1 = d##1 * 0.3555556f; \ | |||
wd##5 = tmp0 + tmp1; \ | |||
wd##6 = tmp0 - tmp1; \ | |||
wd##7 = d##2; \ | |||
#define FILTER_TRANSFORM(d, wd, ADDC, SUBC, MULC) \ | |||
do { \ | |||
wd##0 = d##0; \ | |||
auto tmp0 = MULC(ADDC(d##0, d##2), -0.2222222f); \ | |||
auto tmp1 = MULC(d##1, -0.2222222f); \ | |||
wd##1 = ADDC(tmp0, tmp1); \ | |||
wd##2 = SUBC(tmp0, tmp1); \ | |||
tmp0 = ADDC(MULC(d##0, 0.0111111f), MULC(d##2, 0.0444444f)); \ | |||
tmp1 = MULC(d##1, 0.0222222f); \ | |||
wd##3 = ADDC(tmp0, tmp1); \ | |||
wd##4 = SUBC(tmp0, tmp1); \ | |||
tmp0 = ADDC(MULC(d##0, 0.7111111f), MULC(d##2, 0.1777778f)); \ | |||
tmp1 = MULC(d##1, 0.3555556f); \ | |||
wd##5 = ADDC(tmp0, tmp1); \ | |||
wd##6 = SUBC(tmp0, tmp1); \ | |||
wd##7 = d##2; \ | |||
} while (0); | |||
static void transform( | |||
@@ -49,37 +64,35 @@ struct FilterTransform6X3 { | |||
rep(ic, IC) { | |||
const float* fptr = filter + (oc * IC + ic) * 3 * 3; | |||
Vector<float, 4> g0 = Vector<float, 4>::load(fptr); | |||
Vector<float, 4> g1 = Vector<float, 4>::load(fptr + 3); | |||
GI_FLOAT32_t g0 = GiLoadFloat32(fptr); | |||
GI_FLOAT32_t g1 = GiLoadFloat32(fptr + 3); | |||
Vector<float, 4> g2 = Vector<float, 4>::load(fptr + 6 - 1); | |||
GI_FLOAT32_t g2 = GiLoadFloat32(fptr + 6 - 1); | |||
GI_FLOAT32_t zeros = GiZeroFloat32(); | |||
g2.value = GiFloat32Type2FixLenType( | |||
GiExtqFloat32(GiFixLenType2GiFloat32Type(g2.value), zeros, 1)); | |||
g2 = GiExtqFloat32(g2, zeros, 1); | |||
#define cb(i) Vector<float, 4> wd##i; | |||
#define cb(i) GI_FLOAT32_t wd##i; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
#define cb(i) Vector<float, 8> wdt##i; | |||
UNROLL_CALL_NOWRAPPER(3, cb); | |||
#undef cb | |||
#define cb(i) Vector<float, 8> ret##i; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
FILTER_TRANSFORM(g, wd); | |||
FILTER_TRANSFORM(g, wd, ADDF, SUBF, MULSF); | |||
size_t ocb = oc / 4; | |||
size_t oc4 = oc % 4; | |||
size_t icb = ic / 4; | |||
size_t ic4 = ic % 4; | |||
#if MEGDNN_AARCH64 | |||
#define cb(i) GI_FLOAT32_V2_t wdt##i; | |||
UNROLL_CALL_NOWRAPPER(3, cb); | |||
#undef cb | |||
#define cb(i) GI_FLOAT32_V2_t ret##i; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
TRANSPOSE_8x3(wd, wdt); | |||
FILTER_TRANSFORM(wdt, ret); | |||
FILTER_TRANSFORM(wdt, ret, ADDFV2, SUBFV2, MULSFV2); | |||
#define cb(i) ret##i.save(transform_mid_buf + i * alpha); | |||
#define cb(i) GiStoreFloat32V2(transform_mid_buf + i * alpha, ret##i); | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
rep(i, alpha) rep(j, alpha) { | |||
@@ -116,8 +129,7 @@ struct FilterTransform6X3 { | |||
mid_buf1[7] = GET_VECTOR_ELEM(wd, i, 2); \ | |||
mid_buf1 += 8; \ | |||
} while (0); | |||
#define GET_VECTOR_ELEM(s, i, idx) \ | |||
GiExtractLane##idx##Float32(GiFixLenType2GiFloat32Type(CONCAT(s, i).value)) | |||
#define GET_VECTOR_ELEM(s, i, idx) GiExtractLane##idx##Float32(CONCAT(s, i)) | |||
float* mid_buf1 = transform_mid_buf; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
@@ -2,6 +2,15 @@ | |||
#include "src/common/unroll_macro.h" | |||
#include "src/fallback/general_intrinsic/gi_float.h" | |||
#define ADDF GiAddFloat32 | |||
#define ADDFV2 GiAddFloat32V2 | |||
#define SUBF GiSubtractFloat32 | |||
#define SUBFV2 GiSubtractFloat32V2 | |||
#define MULF GiMultiplyFloat32 | |||
#define MULFV2 GiMultiplyFloat32V2 | |||
#define MULSF GiMultiplyScalerFloat32 | |||
#define MULSFV2 GiMultiplyScalerFloat32V2 | |||
namespace megdnn { | |||
namespace fallback { | |||
inline void transpose_4x4(const float* src, float* dst, int lda, int ldb) { | |||
@@ -47,159 +56,166 @@ inline void transpose_4x4(const float* src, float* dst, int lda, int ldb) { | |||
#define CONCAT(a, idx) a##idx | |||
#if MEGDNN_AARCH64 | |||
//! ret and a are type Vector<float, 8> | |||
#define TRANSPOSE_8x8(a, ret) \ | |||
do { \ | |||
auto b0 = GiZipqFloat32(CONCAT(a, 0).value.val[0], CONCAT(a, 1).value.val[0]); \ | |||
auto b1 = GiZipqFloat32(CONCAT(a, 0).value.val[1], CONCAT(a, 1).value.val[1]); \ | |||
auto b2 = GiZipqFloat32(CONCAT(a, 2).value.val[0], CONCAT(a, 3).value.val[0]); \ | |||
auto b3 = GiZipqFloat32(CONCAT(a, 2).value.val[1], CONCAT(a, 3).value.val[1]); \ | |||
auto b4 = GiZipqFloat32(CONCAT(a, 4).value.val[0], CONCAT(a, 5).value.val[0]); \ | |||
auto b5 = GiZipqFloat32(CONCAT(a, 4).value.val[1], CONCAT(a, 5).value.val[1]); \ | |||
auto b6 = GiZipqFloat32(CONCAT(a, 6).value.val[0], CONCAT(a, 7).value.val[0]); \ | |||
auto b7 = GiZipqFloat32(CONCAT(a, 6).value.val[1], CONCAT(a, 7).value.val[1]); \ | |||
CONCAT(ret, 0).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||
GiReinterpretqFloat32ToS64(b0.val[0]), \ | |||
GiReinterpretqFloat32ToS64(b2.val[0]))); \ | |||
CONCAT(ret, 0).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||
GiReinterpretqFloat32ToS64(b4.val[0]), \ | |||
GiReinterpretqFloat32ToS64(b6.val[0]))); \ | |||
CONCAT(ret, 1).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||
GiReinterpretqFloat32ToS64(b0.val[0]), \ | |||
GiReinterpretqFloat32ToS64(b2.val[0]))); \ | |||
CONCAT(ret, 1).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||
GiReinterpretqFloat32ToS64(b4.val[0]), \ | |||
GiReinterpretqFloat32ToS64(b6.val[0]))); \ | |||
CONCAT(ret, 2).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||
GiReinterpretqFloat32ToS64(b0.val[1]), \ | |||
GiReinterpretqFloat32ToS64(b2.val[1]))); \ | |||
CONCAT(ret, 2).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||
GiReinterpretqFloat32ToS64(b4.val[1]), \ | |||
GiReinterpretqFloat32ToS64(b6.val[1]))); \ | |||
CONCAT(ret, 3).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||
GiReinterpretqFloat32ToS64(b0.val[1]), \ | |||
GiReinterpretqFloat32ToS64(b2.val[1]))); \ | |||
CONCAT(ret, 3).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||
GiReinterpretqFloat32ToS64(b4.val[1]), \ | |||
GiReinterpretqFloat32ToS64(b6.val[1]))); \ | |||
CONCAT(ret, 4).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||
GiReinterpretqFloat32ToS64(b1.val[0]), \ | |||
GiReinterpretqFloat32ToS64(b3.val[0]))); \ | |||
CONCAT(ret, 4).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||
GiReinterpretqFloat32ToS64(b5.val[0]), \ | |||
GiReinterpretqFloat32ToS64(b7.val[0]))); \ | |||
CONCAT(ret, 5).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||
GiReinterpretqFloat32ToS64(b1.val[0]), \ | |||
GiReinterpretqFloat32ToS64(b3.val[0]))); \ | |||
CONCAT(ret, 5).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||
GiReinterpretqFloat32ToS64(b5.val[0]), \ | |||
GiReinterpretqFloat32ToS64(b7.val[0]))); \ | |||
CONCAT(ret, 6).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||
GiReinterpretqFloat32ToS64(b1.val[1]), \ | |||
GiReinterpretqFloat32ToS64(b3.val[1]))); \ | |||
CONCAT(ret, 6).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||
GiReinterpretqFloat32ToS64(b5.val[1]), \ | |||
GiReinterpretqFloat32ToS64(b7.val[1]))); \ | |||
CONCAT(ret, 7).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||
GiReinterpretqFloat32ToS64(b1.val[1]), \ | |||
GiReinterpretqFloat32ToS64(b3.val[1]))); \ | |||
CONCAT(ret, 7).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||
GiReinterpretqFloat32ToS64(b5.val[1]), \ | |||
GiReinterpretqFloat32ToS64(b7.val[1]))); \ | |||
#define TRANSPOSE_8x8(a, ret) \ | |||
do { \ | |||
auto b0 = GiZipqFloat32(CONCAT(a, 0).val[0], CONCAT(a, 1).val[0]); \ | |||
auto b1 = GiZipqFloat32(CONCAT(a, 0).val[1], CONCAT(a, 1).val[1]); \ | |||
auto b2 = GiZipqFloat32(CONCAT(a, 2).val[0], CONCAT(a, 3).val[0]); \ | |||
auto b3 = GiZipqFloat32(CONCAT(a, 2).val[1], CONCAT(a, 3).val[1]); \ | |||
auto b4 = GiZipqFloat32(CONCAT(a, 4).val[0], CONCAT(a, 5).val[0]); \ | |||
auto b5 = GiZipqFloat32(CONCAT(a, 4).val[1], CONCAT(a, 5).val[1]); \ | |||
auto b6 = GiZipqFloat32(CONCAT(a, 6).val[0], CONCAT(a, 7).val[0]); \ | |||
auto b7 = GiZipqFloat32(CONCAT(a, 6).val[1], CONCAT(a, 7).val[1]); \ | |||
CONCAT(ret, 0).val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||
GiReinterpretqFloat32ToS64(b0.val[0]), \ | |||
GiReinterpretqFloat32ToS64(b2.val[0]))); \ | |||
CONCAT(ret, 0).val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||
GiReinterpretqFloat32ToS64(b4.val[0]), \ | |||
GiReinterpretqFloat32ToS64(b6.val[0]))); \ | |||
CONCAT(ret, 1).val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||
GiReinterpretqFloat32ToS64(b0.val[0]), \ | |||
GiReinterpretqFloat32ToS64(b2.val[0]))); \ | |||
CONCAT(ret, 1).val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||
GiReinterpretqFloat32ToS64(b4.val[0]), \ | |||
GiReinterpretqFloat32ToS64(b6.val[0]))); \ | |||
CONCAT(ret, 2).val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||
GiReinterpretqFloat32ToS64(b0.val[1]), \ | |||
GiReinterpretqFloat32ToS64(b2.val[1]))); \ | |||
CONCAT(ret, 2).val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||
GiReinterpretqFloat32ToS64(b4.val[1]), \ | |||
GiReinterpretqFloat32ToS64(b6.val[1]))); \ | |||
CONCAT(ret, 3).val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||
GiReinterpretqFloat32ToS64(b0.val[1]), \ | |||
GiReinterpretqFloat32ToS64(b2.val[1]))); \ | |||
CONCAT(ret, 3).val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||
GiReinterpretqFloat32ToS64(b4.val[1]), \ | |||
GiReinterpretqFloat32ToS64(b6.val[1]))); \ | |||
CONCAT(ret, 4).val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||
GiReinterpretqFloat32ToS64(b1.val[0]), \ | |||
GiReinterpretqFloat32ToS64(b3.val[0]))); \ | |||
CONCAT(ret, 4).val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||
GiReinterpretqFloat32ToS64(b5.val[0]), \ | |||
GiReinterpretqFloat32ToS64(b7.val[0]))); \ | |||
CONCAT(ret, 5).val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||
GiReinterpretqFloat32ToS64(b1.val[0]), \ | |||
GiReinterpretqFloat32ToS64(b3.val[0]))); \ | |||
CONCAT(ret, 5).val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||
GiReinterpretqFloat32ToS64(b5.val[0]), \ | |||
GiReinterpretqFloat32ToS64(b7.val[0]))); \ | |||
CONCAT(ret, 6).val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||
GiReinterpretqFloat32ToS64(b1.val[1]), \ | |||
GiReinterpretqFloat32ToS64(b3.val[1]))); \ | |||
CONCAT(ret, 6).val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||
GiReinterpretqFloat32ToS64(b5.val[1]), \ | |||
GiReinterpretqFloat32ToS64(b7.val[1]))); \ | |||
CONCAT(ret, 7).val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||
GiReinterpretqFloat32ToS64(b1.val[1]), \ | |||
GiReinterpretqFloat32ToS64(b3.val[1]))); \ | |||
CONCAT(ret, 7).val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||
GiReinterpretqFloat32ToS64(b5.val[1]), \ | |||
GiReinterpretqFloat32ToS64(b7.val[1]))); \ | |||
} while (0); | |||
#define TRANSPOSE_8x3(a, ret) \ | |||
auto b0 = GiZipqFloat32(CONCAT(a, 0).value, CONCAT(a, 1).value); \ | |||
auto b1 = GiZipqFloat32(CONCAT(a, 2).value, CONCAT(a, 3).value); \ | |||
auto b2 = GiZipqFloat32(CONCAT(a, 4).value, CONCAT(a, 5).value); \ | |||
auto b3 = GiZipqFloat32(CONCAT(a, 6).value, CONCAT(a, 7).value); \ | |||
CONCAT(ret, 0).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||
GiReinterpretqFloat32ToS64(b0.val[0]), \ | |||
GiReinterpretqFloat32ToS64(b1.val[0]))); \ | |||
CONCAT(ret, 0).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||
GiReinterpretqFloat32ToS64(b2.val[0]), \ | |||
GiReinterpretqFloat32ToS64(b3.val[0]))); \ | |||
CONCAT(ret, 1).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||
GiReinterpretqFloat32ToS64(b0.val[0]), \ | |||
GiReinterpretqFloat32ToS64(b1.val[0]))); \ | |||
CONCAT(ret, 1).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||
GiReinterpretqFloat32ToS64(b2.val[0]), \ | |||
GiReinterpretqFloat32ToS64(b3.val[0]))); \ | |||
CONCAT(ret, 2).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||
GiReinterpretqFloat32ToS64(b0.val[1]), \ | |||
GiReinterpretqFloat32ToS64(b1.val[1]))); \ | |||
CONCAT(ret, 2).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||
GiReinterpretqFloat32ToS64(b2.val[1]), \ | |||
#define TRANSPOSE_8x3(a, ret) \ | |||
auto b0 = GiZipqFloat32(CONCAT(a, 0), CONCAT(a, 1)); \ | |||
auto b1 = GiZipqFloat32(CONCAT(a, 2), CONCAT(a, 3)); \ | |||
auto b2 = GiZipqFloat32(CONCAT(a, 4), CONCAT(a, 5)); \ | |||
auto b3 = GiZipqFloat32(CONCAT(a, 6), CONCAT(a, 7)); \ | |||
CONCAT(ret, 0).val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||
GiReinterpretqFloat32ToS64(b0.val[0]), \ | |||
GiReinterpretqFloat32ToS64(b1.val[0]))); \ | |||
CONCAT(ret, 0).val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||
GiReinterpretqFloat32ToS64(b2.val[0]), \ | |||
GiReinterpretqFloat32ToS64(b3.val[0]))); \ | |||
CONCAT(ret, 1).val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||
GiReinterpretqFloat32ToS64(b0.val[0]), \ | |||
GiReinterpretqFloat32ToS64(b1.val[0]))); \ | |||
CONCAT(ret, 1).val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||
GiReinterpretqFloat32ToS64(b2.val[0]), \ | |||
GiReinterpretqFloat32ToS64(b3.val[0]))); \ | |||
CONCAT(ret, 2).val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||
GiReinterpretqFloat32ToS64(b0.val[1]), \ | |||
GiReinterpretqFloat32ToS64(b1.val[1]))); \ | |||
CONCAT(ret, 2).val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||
GiReinterpretqFloat32ToS64(b2.val[1]), \ | |||
GiReinterpretqFloat32ToS64(b3.val[1]))); | |||
#define TRANSPOSE_8x4(a, ret) \ | |||
auto b0 = GiZipqFloat32(CONCAT(a, 0).value, CONCAT(a, 1).value); \ | |||
auto b1 = GiZipqFloat32(CONCAT(a, 2).value, CONCAT(a, 3).value); \ | |||
auto b2 = GiZipqFloat32(CONCAT(a, 4).value, CONCAT(a, 5).value); \ | |||
auto b3 = GiZipqFloat32(CONCAT(a, 6).value, CONCAT(a, 7).value); \ | |||
CONCAT(ret, 0).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||
GiReinterpretqFloat32ToS64(b0.val[0]), \ | |||
GiReinterpretqFloat32ToS64(b1.val[0]))); \ | |||
CONCAT(ret, 0).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||
GiReinterpretqFloat32ToS64(b2.val[0]), \ | |||
GiReinterpretqFloat32ToS64(b3.val[0]))); \ | |||
CONCAT(ret, 1).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||
GiReinterpretqFloat32ToS64(b0.val[0]), \ | |||
GiReinterpretqFloat32ToS64(b1.val[0]))); \ | |||
CONCAT(ret, 1).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||
GiReinterpretqFloat32ToS64(b2.val[0]), \ | |||
GiReinterpretqFloat32ToS64(b3.val[0]))); \ | |||
CONCAT(ret, 2).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||
GiReinterpretqFloat32ToS64(b0.val[1]), \ | |||
GiReinterpretqFloat32ToS64(b1.val[1]))); \ | |||
CONCAT(ret, 2).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||
GiReinterpretqFloat32ToS64(b2.val[1]), \ | |||
GiReinterpretqFloat32ToS64(b3.val[1]))); \ | |||
CONCAT(ret, 3).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||
GiReinterpretqFloat32ToS64(b0.val[1]), \ | |||
GiReinterpretqFloat32ToS64(b1.val[1]))); \ | |||
CONCAT(ret, 3).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||
GiReinterpretqFloat32ToS64(b2.val[1]), \ | |||
#define TRANSPOSE_8x4(a, ret) \ | |||
auto b0 = GiZipqFloat32(CONCAT(a, 0), CONCAT(a, 1)); \ | |||
auto b1 = GiZipqFloat32(CONCAT(a, 2), CONCAT(a, 3)); \ | |||
auto b2 = GiZipqFloat32(CONCAT(a, 4), CONCAT(a, 5)); \ | |||
auto b3 = GiZipqFloat32(CONCAT(a, 6), CONCAT(a, 7)); \ | |||
CONCAT(ret, 0).val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||
GiReinterpretqFloat32ToS64(b0.val[0]), \ | |||
GiReinterpretqFloat32ToS64(b1.val[0]))); \ | |||
CONCAT(ret, 0).val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||
GiReinterpretqFloat32ToS64(b2.val[0]), \ | |||
GiReinterpretqFloat32ToS64(b3.val[0]))); \ | |||
CONCAT(ret, 1).val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||
GiReinterpretqFloat32ToS64(b0.val[0]), \ | |||
GiReinterpretqFloat32ToS64(b1.val[0]))); \ | |||
CONCAT(ret, 1).val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||
GiReinterpretqFloat32ToS64(b2.val[0]), \ | |||
GiReinterpretqFloat32ToS64(b3.val[0]))); \ | |||
CONCAT(ret, 2).val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||
GiReinterpretqFloat32ToS64(b0.val[1]), \ | |||
GiReinterpretqFloat32ToS64(b1.val[1]))); \ | |||
CONCAT(ret, 2).val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||
GiReinterpretqFloat32ToS64(b2.val[1]), \ | |||
GiReinterpretqFloat32ToS64(b3.val[1]))); \ | |||
CONCAT(ret, 3).val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||
GiReinterpretqFloat32ToS64(b0.val[1]), \ | |||
GiReinterpretqFloat32ToS64(b1.val[1]))); \ | |||
CONCAT(ret, 3).val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||
GiReinterpretqFloat32ToS64(b2.val[1]), \ | |||
GiReinterpretqFloat32ToS64(b3.val[1]))); | |||
#else | |||
#define TRANSPOSE_8x4(a, ret) \ | |||
auto b0 = GiZipqFloat32( \ | |||
GiFixLenType2GiFloat32Type(CONCAT(a, 0).value), \ | |||
GiFixLenType2GiFloat32Type(CONCAT(a, 1).value)); \ | |||
auto b1 = GiZipqFloat32( \ | |||
GiFixLenType2GiFloat32Type(CONCAT(a, 2).value), \ | |||
GiFixLenType2GiFloat32Type(CONCAT(a, 3).value)); \ | |||
auto b2 = GiZipqFloat32( \ | |||
GiFixLenType2GiFloat32Type(CONCAT(a, 4).value), \ | |||
GiFixLenType2GiFloat32Type(CONCAT(a, 5).value)); \ | |||
auto b3 = GiZipqFloat32( \ | |||
GiFixLenType2GiFloat32Type(CONCAT(a, 6).value), \ | |||
GiFixLenType2GiFloat32Type(CONCAT(a, 7).value)); \ | |||
CONCAT(ret, 0).value.val[0] = GiFloat32Type2FixLenType(GiCombineFloat32( \ | |||
GiGetLowFloat32(GiGetSubVectorFloat32V2(b0, 0)), \ | |||
GiGetLowFloat32(GiGetSubVectorFloat32V2(b1, 0)))); \ | |||
CONCAT(ret, 1).value.val[0] = GiFloat32Type2FixLenType(GiCombineFloat32( \ | |||
GiGetHighFloat32(GiGetSubVectorFloat32V2(b0, 0)), \ | |||
GiGetHighFloat32(GiGetSubVectorFloat32V2(b1, 0)))); \ | |||
CONCAT(ret, 2).value.val[0] = GiFloat32Type2FixLenType(GiCombineFloat32( \ | |||
GiGetLowFloat32(GiGetSubVectorFloat32V2(b0, 1)), \ | |||
GiGetLowFloat32(GiGetSubVectorFloat32V2(b1, 1)))); \ | |||
CONCAT(ret, 3).value.val[0] = GiFloat32Type2FixLenType(GiCombineFloat32( \ | |||
GiGetHighFloat32(GiGetSubVectorFloat32V2(b0, 1)), \ | |||
GiGetHighFloat32(GiGetSubVectorFloat32V2(b1, 1)))); \ | |||
CONCAT(ret, 0).value.val[1] = GiFloat32Type2FixLenType(GiCombineFloat32( \ | |||
GiGetLowFloat32(GiGetSubVectorFloat32V2(b2, 0)), \ | |||
GiGetLowFloat32(GiGetSubVectorFloat32V2(b3, 0)))); \ | |||
CONCAT(ret, 1).value.val[1] = GiFloat32Type2FixLenType(GiCombineFloat32( \ | |||
GiGetHighFloat32(GiGetSubVectorFloat32V2(b2, 0)), \ | |||
GiGetHighFloat32(GiGetSubVectorFloat32V2(b3, 0)))); \ | |||
CONCAT(ret, 2).value.val[1] = GiFloat32Type2FixLenType(GiCombineFloat32( \ | |||
GiGetLowFloat32(GiGetSubVectorFloat32V2(b2, 1)), \ | |||
GiGetLowFloat32(GiGetSubVectorFloat32V2(b3, 1)))); \ | |||
CONCAT(ret, 3).value.val[1] = GiFloat32Type2FixLenType(GiCombineFloat32( \ | |||
GiGetHighFloat32(GiGetSubVectorFloat32V2(b2, 1)), \ | |||
GiGetHighFloat32(GiGetSubVectorFloat32V2(b3, 1)))); | |||
#define TRANSPOSE_8x4(a, ret) \ | |||
auto b0 = GiZipqFloat32(CONCAT(a, 0), CONCAT(a, 1)); \ | |||
auto b1 = GiZipqFloat32(CONCAT(a, 2), CONCAT(a, 3)); \ | |||
auto b2 = GiZipqFloat32(CONCAT(a, 4), CONCAT(a, 5)); \ | |||
auto b3 = GiZipqFloat32(CONCAT(a, 6), CONCAT(a, 7)); \ | |||
GiSetSubVectorFloat32V2( \ | |||
CONCAT(ret, 0), 0, \ | |||
GiCombineFloat32( \ | |||
GiGetLowFloat32(GiGetSubVectorFloat32V2(b0, 0)), \ | |||
GiGetLowFloat32(GiGetSubVectorFloat32V2(b1, 0)))); \ | |||
GiSetSubVectorFloat32V2( \ | |||
CONCAT(ret, 1), 0, \ | |||
GiCombineFloat32( \ | |||
GiGetHighFloat32(GiGetSubVectorFloat32V2(b0, 0)), \ | |||
GiGetHighFloat32(GiGetSubVectorFloat32V2(b1, 0)))); \ | |||
GiSetSubVectorFloat32V2( \ | |||
CONCAT(ret, 2), 0, \ | |||
GiCombineFloat32( \ | |||
GiGetLowFloat32(GiGetSubVectorFloat32V2(b0, 1)), \ | |||
GiGetLowFloat32(GiGetSubVectorFloat32V2(b1, 1)))); \ | |||
GiSetSubVectorFloat32V2( \ | |||
CONCAT(ret, 3), 0, \ | |||
GiCombineFloat32( \ | |||
GiGetHighFloat32(GiGetSubVectorFloat32V2(b0, 1)), \ | |||
GiGetHighFloat32(GiGetSubVectorFloat32V2(b1, 1)))); \ | |||
GiSetSubVectorFloat32V2( \ | |||
CONCAT(ret, 0), 1, \ | |||
GiCombineFloat32( \ | |||
GiGetLowFloat32(GiGetSubVectorFloat32V2(b2, 0)), \ | |||
GiGetLowFloat32(GiGetSubVectorFloat32V2(b3, 0)))); \ | |||
GiSetSubVectorFloat32V2( \ | |||
CONCAT(ret, 1), 1, \ | |||
GiCombineFloat32( \ | |||
GiGetHighFloat32(GiGetSubVectorFloat32V2(b2, 0)), \ | |||
GiGetHighFloat32(GiGetSubVectorFloat32V2(b3, 0)))); \ | |||
GiSetSubVectorFloat32V2( \ | |||
CONCAT(ret, 2), 1, \ | |||
GiCombineFloat32( \ | |||
GiGetLowFloat32(GiGetSubVectorFloat32V2(b2, 1)), \ | |||
GiGetLowFloat32(GiGetSubVectorFloat32V2(b3, 1)))); \ | |||
GiSetSubVectorFloat32V2( \ | |||
CONCAT(ret, 3), 1, \ | |||
GiCombineFloat32( \ | |||
GiGetHighFloat32(GiGetSubVectorFloat32V2(b2, 1)), \ | |||
GiGetHighFloat32(GiGetSubVectorFloat32V2(b3, 1)))); | |||
#endif | |||
// vim: syntax=cpp.doxygen |
@@ -1,7 +1,6 @@ | |||
#include "src/common/unroll_macro.h" | |||
#include "src/common/utils.h" | |||
#include "src/fallback/conv_bias/gi/fp32/strategy.h" | |||
#include "src/fallback/conv_bias/gi/utils.h" | |||
#include "src/fallback/conv_bias/winograd/winograd.h" | |||
#include "src/fallback/conv_bias/gi/fp32/helper.h" | |||
@@ -70,8 +69,7 @@ struct InputTransform2X3 { | |||
size_t nr_units_in_tile, size_t ic, size_t IC) { | |||
constexpr size_t alpha = 2 + 3 - 1; | |||
// BT * d * B | |||
#define cb(m, n) \ | |||
Vector<float, 4> d##m##n = Vector<float, 4>::load(patchT + m * 4 * 4 + n * 4); | |||
#define cb(m, n) GI_FLOAT32_t d##m##n = GiLoadFloat32(patchT + m * 4 * 4 + n * 4); | |||
UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); | |||
#undef cb | |||
@@ -80,20 +78,20 @@ struct InputTransform2X3 { | |||
//! 0 1 1 0 d10 d11 d12 d13 0 1 -1 -1 | |||
//! 0 -1 1 0 d20 d21 d22 d23 -1 1 1 0 | |||
//! 0 -1 0 1 d30 d31 d32 d33 0 0 0 1 | |||
#define cb(m) \ | |||
auto t0##m = d0##m - d2##m; \ | |||
auto t1##m = d1##m + d2##m; \ | |||
auto t2##m = d2##m - d1##m; \ | |||
auto t3##m = d3##m - d1##m; | |||
#define cb(m) \ | |||
auto t0##m = SUBF(d0##m, d2##m); \ | |||
auto t1##m = ADDF(d1##m, d2##m); \ | |||
auto t2##m = SUBF(d2##m, d1##m); \ | |||
auto t3##m = SUBF(d3##m, d1##m); | |||
UNROLL_CALL_NOWRAPPER(4, cb); | |||
#undef cb | |||
#define cb(m) \ | |||
d##m##0 = t##m##0 - t##m##2; \ | |||
d##m##1 = t##m##1 + t##m##2; \ | |||
d##m##2 = t##m##2 - t##m##1; \ | |||
d##m##3 = t##m##3 - t##m##1; | |||
#define cb(m) \ | |||
d##m##0 = SUBF(t##m##0, t##m##2); \ | |||
d##m##1 = ADDF(t##m##1, t##m##2); \ | |||
d##m##2 = SUBF(t##m##2, t##m##1); \ | |||
d##m##3 = SUBF(t##m##3, t##m##1); | |||
UNROLL_CALL_NOWRAPPER(4, cb); | |||
#undef cb | |||
@@ -101,9 +99,10 @@ struct InputTransform2X3 { | |||
size_t ICB = IC / 4; | |||
size_t icb = ic / 4; | |||
#define cb(m, n) \ | |||
d##m##n.save( \ | |||
GiStoreFloat32( \ | |||
input_transform_buf + (m * alpha + n) * ICB * nr_units_in_tile * 4 + \ | |||
icb * nr_units_in_tile * 4 + unit_idx * 4); | |||
icb * nr_units_in_tile * 4 + unit_idx * 4, \ | |||
d##m##n); | |||
UNROLL_CALL_NOWRAPPER_D2(4, 4, cb) | |||
#undef cb | |||
} | |||
@@ -125,7 +124,7 @@ struct OutputTransform2X3 { | |||
size_t ocb = oc_index / 4; | |||
#define cb(m, n) \ | |||
auto v##m##n = Vector<float, 4>::load( \ | |||
auto v##m##n = GiLoadFloat32( \ | |||
output_transform_buf + (m * alpha + n) * OCB * nr_units_in_tile * 4 + \ | |||
ocb * nr_units_in_tile * 4 + unit_idx * 4); | |||
UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); | |||
@@ -134,37 +133,37 @@ struct OutputTransform2X3 { | |||
//! 0 1 -1 1 v10 v11 v12 v13 1 1 | |||
//! v20 v21 v22 v23 1 -1 | |||
//! v30 v31 v32 v33 0 1 | |||
#define cb(m) \ | |||
auto t0##m = v0##m + v1##m + v2##m; \ | |||
auto t1##m = v1##m - v2##m + v3##m; | |||
#define cb(m) \ | |||
auto t0##m = ADDF(ADDF(v0##m, v1##m), v2##m); \ | |||
auto t1##m = ADDF(SUBF(v1##m, v2##m), v3##m); | |||
UNROLL_CALL_NOWRAPPER(4, cb); | |||
#undef cb | |||
v00 = t00 + t01 + t02; | |||
v10 = t10 + t11 + t12; | |||
v01 = t01 - t02 + t03; | |||
v11 = t11 - t12 + t13; | |||
v00 = ADDF(ADDF(t00, t01), t02); | |||
v10 = ADDF(ADDF(t10, t11), t12); | |||
v01 = ADDF(SUBF(t01, t02), t03); | |||
v11 = ADDF(SUBF(t11, t12), t13); | |||
Vector<float, 4> vbias; | |||
GI_FLOAT32_t vbias; | |||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
vbias = Vector<float, 4>::load(bias + oc); | |||
vbias = GiLoadFloat32(bias + oc); | |||
v00 += vbias; | |||
v10 += vbias; | |||
v01 += vbias; | |||
v11 += vbias; | |||
v00 = ADDF(v00, vbias); | |||
v10 = ADDF(v10, vbias); | |||
v01 = ADDF(v01, vbias); | |||
v11 = ADDF(v11, vbias); | |||
} | |||
if (bmode != BiasMode::BIAS) { | |||
v00 = op(GiFixLenType2GiFloat32Type(v00.value)); | |||
v01 = op(GiFixLenType2GiFloat32Type(v01.value)); | |||
v10 = op(GiFixLenType2GiFloat32Type(v10.value)); | |||
v11 = op(GiFixLenType2GiFloat32Type(v11.value)); | |||
v00 = op(v00); | |||
v01 = op(v01); | |||
v10 = op(v10); | |||
v11 = op(v11); | |||
} | |||
v00.save(transform_mid_buf + (0 * 2 + 0) * 4); | |||
v10.save(transform_mid_buf + (1 * 2 + 0) * 4); | |||
v01.save(transform_mid_buf + (0 * 2 + 1) * 4); | |||
v11.save(transform_mid_buf + (1 * 2 + 1) * 4); | |||
GiStoreFloat32(transform_mid_buf + (0 * 2 + 0) * 4, v00); | |||
GiStoreFloat32(transform_mid_buf + (1 * 2 + 0) * 4, v10); | |||
GiStoreFloat32(transform_mid_buf + (0 * 2 + 1) * 4, v01); | |||
GiStoreFloat32(transform_mid_buf + (1 * 2 + 1) * 4, v11); | |||
for (size_t oco = 0; oco < 4 && oc + oco < oc_end; ++oco) { | |||
for (size_t oho = 0; oho < 2 && oh_start + oho < OH; ++oho) { | |||
@@ -1,7 +1,6 @@ | |||
#include "src/common/unroll_macro.h" | |||
#include "src/common/utils.h" | |||
#include "src/fallback/conv_bias/gi/fp32/strategy.h" | |||
#include "src/fallback/conv_bias/gi/utils.h" | |||
#include "src/fallback/conv_bias/winograd/winograd.h" | |||
#include "src/fallback/conv_bias/gi/fp32/helper.h" | |||
@@ -15,61 +14,124 @@ using namespace megdnn; | |||
using namespace fallback; | |||
namespace { | |||
/* | |||
* wd##0 = d##0; | |||
* wd##r0 = d##r0; | |||
* wd##1 = (d##0 + d##1 + d##2 + d##3 + d##4) * -0.2222222; | |||
* wd##r1 = (d##r0 + d##r1 + d##r2 + d##r3 + d##r4) * -0.2222222; | |||
* wd##2 = (d##0 - d##1 + d##2 - d##3 + d##4) * -0.2222222; | |||
* wd##r2 = (d##r0 - d##r1 + d##r2 - d##r3 + d##r4) * -0.2222222; | |||
* auto tmpd0 = d##0 * 0.7111111; | |||
* auto tmpd1 = d##1 * 0.3555556; | |||
* auto tmpd2 = d##2 * 0.1777778; | |||
* auto tmpd3 = d##3 * 0.0888889; | |||
* auto tmpd4 = d##4 * 0.0444444; | |||
* auto tmpdr0 = d##r0 * 0.7111111; | |||
* auto tmpdr1 = d##r1 * 0.3555556; | |||
* auto tmpdr2 = d##r2 * 0.1777778; | |||
* auto tmpdr3 = d##r3 * 0.0888889; | |||
* auto tmpdr4 = d##r4 * 0.0444444; | |||
* wd##3 = tmpd0 + tmpd1 + tmpd2 + tmpd3 + tmpd4; | |||
* wd##r3 = tmpdr0 + tmpdr1 + tmpdr2 + tmpdr3 + tmpdr4; | |||
* wd##4 = tmpd0 - tmpd1 + tmpd2 - tmpd3 + tmpd4; | |||
* wd##r4 = tmpdr0 - tmpdr1 + tmpdr2 - tmpdr3 + tmpdr4; | |||
* tmpd0 = d##0 * 0.0111111; | |||
* tmpd1 = d##1 * 0.0222222; | |||
* tmpd2 = d##2 * 0.0444444; | |||
* tmpd3 = d##3 * 0.0888889; | |||
* tmpd4 = d##4 * 0.1777778; | |||
* tmpdr0 = d##r0 * 0.0111111; | |||
* tmpdr1 = d##r1 * 0.0222222; | |||
* tmpdr2 = d##r2 * 0.0444444; | |||
* tmpdr3 = d##r3 * 0.0888889; | |||
* tmpdr4 = d##r4 * 0.1777778; | |||
* wd##5 = tmpd0 + tmpd1 + tmpd2 + tmpd3 + tmpd4; | |||
* wd##r5 = tmpdr0 + tmpdr1 + tmpdr2 + tmpdr3 + tmpdr4; | |||
* wd##6 = tmpd0 - tmpd1 + tmpd2 - tmpd3 + tmpd4; | |||
* wd##r6 = tmpdr0 - tmpdr1 + tmpdr2 - tmpdr3 + tmpdr4; | |||
* wd##7 = d##4; | |||
* wd##r7 = d##r4; | |||
* | |||
* */ | |||
struct FilterTransform4X5 { | |||
#define FILTER_TRANSFORM(d, wd) \ | |||
do { \ | |||
wd##0 = d##0; \ | |||
wd##r0 = d##r0; \ | |||
wd##1 = (d##0 + d##1 + d##2 + d##3 + d##4) * -0.2222222; \ | |||
wd##r1 = (d##r0 + d##r1 + d##r2 + d##r3 + d##r4) * -0.2222222; \ | |||
wd##2 = (d##0 - d##1 + d##2 - d##3 + d##4) * -0.2222222; \ | |||
wd##r2 = (d##r0 - d##r1 + d##r2 - d##r3 + d##r4) * -0.2222222; \ | |||
auto tmpd0 = d##0 * 0.7111111; \ | |||
auto tmpd1 = d##1 * 0.3555556; \ | |||
auto tmpd2 = d##2 * 0.1777778; \ | |||
auto tmpd3 = d##3 * 0.0888889; \ | |||
auto tmpd4 = d##4 * 0.0444444; \ | |||
auto tmpdr0 = d##r0 * 0.7111111; \ | |||
auto tmpdr1 = d##r1 * 0.3555556; \ | |||
auto tmpdr2 = d##r2 * 0.1777778; \ | |||
auto tmpdr3 = d##r3 * 0.0888889; \ | |||
auto tmpdr4 = d##r4 * 0.0444444; \ | |||
wd##3 = tmpd0 + tmpd1 + tmpd2 + tmpd3 + tmpd4; \ | |||
wd##r3 = tmpdr0 + tmpdr1 + tmpdr2 + tmpdr3 + tmpdr4; \ | |||
wd##4 = tmpd0 - tmpd1 + tmpd2 - tmpd3 + tmpd4; \ | |||
wd##r4 = tmpdr0 - tmpdr1 + tmpdr2 - tmpdr3 + tmpdr4; \ | |||
tmpd0 = d##0 * 0.0111111; \ | |||
tmpd1 = d##1 * 0.0222222; \ | |||
tmpd2 = d##2 * 0.0444444; \ | |||
tmpd3 = d##3 * 0.0888889; \ | |||
tmpd4 = d##4 * 0.1777778; \ | |||
tmpdr0 = d##r0 * 0.0111111; \ | |||
tmpdr1 = d##r1 * 0.0222222; \ | |||
tmpdr2 = d##r2 * 0.0444444; \ | |||
tmpdr3 = d##r3 * 0.0888889; \ | |||
tmpdr4 = d##r4 * 0.1777778; \ | |||
wd##5 = tmpd0 + tmpd1 + tmpd2 + tmpd3 + tmpd4; \ | |||
wd##r5 = tmpdr0 + tmpdr1 + tmpdr2 + tmpdr3 + tmpdr4; \ | |||
wd##6 = tmpd0 - tmpd1 + tmpd2 - tmpd3 + tmpd4; \ | |||
wd##r6 = tmpdr0 - tmpdr1 + tmpdr2 - tmpdr3 + tmpdr4; \ | |||
wd##7 = d##4; \ | |||
wd##r7 = d##r4; \ | |||
#define FILTER_TRANSFORM(d, wd) \ | |||
do { \ | |||
wd##0 = d##0; \ | |||
wd##r0 = d##r0; \ | |||
wd##1 = MULSF( \ | |||
ADDF(ADDF(ADDF(ADDF(d##0, d##1), d##2), d##3), d##4), -0.2222222); \ | |||
wd##r1 = (d##r0 + d##r1 + d##r2 + d##r3 + d##r4) * -0.2222222; \ | |||
wd##2 = MULSF( \ | |||
ADDF(SUBF(ADDF(SUBF(d##0, d##1), d##2), d##3), d##4), -0.2222222); \ | |||
wd##r2 = (d##r0 - d##r1 + d##r2 - d##r3 + d##r4) * -0.2222222; \ | |||
auto tmpd0 = MULSF(d##0, 0.7111111); \ | |||
auto tmpd1 = MULSF(d##1, 0.3555556); \ | |||
auto tmpd2 = MULSF(d##2, 0.1777778); \ | |||
auto tmpd3 = MULSF(d##3, 0.0888889); \ | |||
auto tmpd4 = MULSF(d##4, 0.0444444); \ | |||
auto tmpdr0 = d##r0 * 0.7111111; \ | |||
auto tmpdr1 = d##r1 * 0.3555556; \ | |||
auto tmpdr2 = d##r2 * 0.1777778; \ | |||
auto tmpdr3 = d##r3 * 0.0888889; \ | |||
auto tmpdr4 = d##r4 * 0.0444444; \ | |||
wd##3 = ADDF(ADDF(ADDF(ADDF(tmpd0, tmpd1), tmpd2), tmpd3), tmpd4); \ | |||
wd##r3 = tmpdr0 + tmpdr1 + tmpdr2 + tmpdr3 + tmpdr4; \ | |||
wd##4 = ADDF(SUBF(ADDF(SUBF(tmpd0, tmpd1), tmpd2), tmpd3), tmpd4); \ | |||
wd##r4 = tmpdr0 - tmpdr1 + tmpdr2 - tmpdr3 + tmpdr4; \ | |||
tmpd0 = MULSF(d##0, 0.0111111); \ | |||
tmpd1 = MULSF(d##1, 0.0222222); \ | |||
tmpd2 = MULSF(d##2, 0.0444444); \ | |||
tmpd3 = MULSF(d##3, 0.0888889); \ | |||
tmpd4 = MULSF(d##4, 0.1777778); \ | |||
tmpdr0 = d##r0 * 0.0111111; \ | |||
tmpdr1 = d##r1 * 0.0222222; \ | |||
tmpdr2 = d##r2 * 0.0444444; \ | |||
tmpdr3 = d##r3 * 0.0888889; \ | |||
tmpdr4 = d##r4 * 0.1777778; \ | |||
wd##5 = ADDF(ADDF(ADDF(ADDF(tmpd0, tmpd1), tmpd2), tmpd3), tmpd4); \ | |||
wd##r5 = tmpdr0 + tmpdr1 + tmpdr2 + tmpdr3 + tmpdr4; \ | |||
wd##6 = ADDF(SUBF(ADDF(SUBF(tmpd0, tmpd1), tmpd2), tmpd3), tmpd4); \ | |||
wd##r6 = tmpdr0 - tmpdr1 + tmpdr2 - tmpdr3 + tmpdr4; \ | |||
wd##7 = d##4; \ | |||
wd##r7 = d##r4; \ | |||
} while (0); | |||
#define FILTER_TRANSFORM_FINAL(d, wd) \ | |||
do { \ | |||
wd##0 = d##0; \ | |||
wd##1 = (d##0 + d##1 + d##2 + d##3 + d##4) * -0.2222222; \ | |||
wd##2 = (d##0 - d##1 + d##2 - d##3 + d##4) * -0.2222222; \ | |||
auto tmp0 = d##0 * 0.7111111 + d##2 * 0.1777778 + d##4 * 0.0444444; \ | |||
auto tmp1 = d##1 * 0.3555556 + d##3 * 0.0888889; \ | |||
wd##3 = tmp0 + tmp1; \ | |||
wd##4 = tmp0 - tmp1; \ | |||
tmp0 = d##0 * 0.0111111 + d##2 * 0.0444444 + d##4 * 0.1777778; \ | |||
tmp1 = d##1 * 0.0222222 + d##3 * 0.0888889; \ | |||
wd##5 = tmp0 + tmp1; \ | |||
wd##6 = tmp0 - tmp1; \ | |||
wd##7 = d##4; \ | |||
/* | |||
*wd##0 = d##0; | |||
*wd##1 = (d##0 + d##1 + d##2 + d##3 + d##4) * -0.2222222; | |||
*wd##2 = (d##0 - d##1 + d##2 - d##3 + d##4) * -0.2222222; | |||
*auto tmp0 = d##0 * 0.7111111 + d##2 * 0.1777778 + d##4 * 0.0444444; | |||
*auto tmp1 = d##1 * 0.3555556 + d##3 * 0.0888889; | |||
*wd##3 = tmp0 + tmp1; | |||
*wd##4 = tmp0 - tmp1; | |||
*tmp0 = d##0 * 0.0111111 + d##2 * 0.0444444 + d##4 * 0.1777778; | |||
*tmp1 = d##1 * 0.0222222 + d##3 * 0.0888889; | |||
*wd##5 = tmp0 + tmp1; | |||
*wd##6 = tmp0 - tmp1; | |||
*wd##7 = d##4; | |||
*/ | |||
#define FILTER_TRANSFORM_FINAL(d, wd) \ | |||
do { \ | |||
wd##0 = d##0; \ | |||
wd##1 = \ | |||
MULSFV2(ADDFV2(ADDFV2(ADDFV2(ADDFV2(d##0, d##1), d##2), d##3), d##4), \ | |||
-0.2222222); \ | |||
wd##2 = \ | |||
MULSFV2(ADDFV2(SUBFV2(ADDFV2(SUBFV2(d##0, d##1), d##2), d##3), d##4), \ | |||
-0.2222222); \ | |||
auto tmp0 = \ | |||
ADDFV2(ADDFV2(MULSFV2(d##0, 0.7111111), MULSFV2(d##2, 0.1777778)), \ | |||
MULSFV2(d##4, 0.0444444)); \ | |||
auto tmp1 = ADDFV2(MULSFV2(d##1, 0.3555556), MULSFV2(d##3, 0.0888889)); \ | |||
wd##3 = ADDFV2(tmp0, tmp1); \ | |||
wd##4 = SUBFV2(tmp0, tmp1); \ | |||
tmp0 = \ | |||
ADDFV2(ADDFV2(MULSFV2(d##0, 0.0111111), MULSFV2(d##2, 0.0444444)), \ | |||
MULSFV2(d##4, 0.1777778)); \ | |||
tmp1 = ADDFV2(MULSFV2(d##1, 0.0222222), MULSFV2(d##3, 0.0888889)); \ | |||
wd##5 = ADDFV2(tmp0, tmp1); \ | |||
wd##6 = SUBFV2(tmp0, tmp1); \ | |||
wd##7 = d##4; \ | |||
} while (0); | |||
static void transform( | |||
const float* filter, float* filter_transform_buf, float* transform_mid_buf, | |||
@@ -89,7 +151,7 @@ struct FilterTransform4X5 { | |||
rep(ic, IC) { | |||
const float* fptr = filter + (oc * IC + ic) * 5 * 5; | |||
#define cb(i) Vector<float, 4> g##i = Vector<float, 4>::load(fptr + 5 * i); | |||
#define cb(i) GI_FLOAT32_t g##i = GiLoadFloat32(fptr + 5 * i); | |||
UNROLL_CALL_NOWRAPPER(5, cb); | |||
#undef cb | |||
@@ -97,7 +159,7 @@ struct FilterTransform4X5 { | |||
UNROLL_CALL_NOWRAPPER(5, cb); | |||
#undef cb | |||
#define cb(i) Vector<float, 4> Gg##i; | |||
#define cb(i) GI_FLOAT32_t Gg##i; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
@@ -105,11 +167,11 @@ struct FilterTransform4X5 { | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
#define cb(i) Vector<float, 8> Ggt##i; | |||
#define cb(i) GI_FLOAT32_V2_t Ggt##i; | |||
UNROLL_CALL_NOWRAPPER(4, cb); | |||
#undef cb | |||
#define cb(i) Vector<float, 8> result##i; | |||
#define cb(i) GI_FLOAT32_V2_t result##i; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
@@ -128,11 +190,11 @@ struct FilterTransform4X5 { | |||
GI_FLOAT32_t vgr1 = GiLoadFloat32(tmp); | |||
GiSetSubVectorFloat32V2(vgr, 0, vgr0); //{Ggr0, Ggr1, Ggr2, Ggr3}; | |||
GiSetSubVectorFloat32V2(vgr, 1, vgr1); //{Ggr4, Ggr5, Ggr6, Ggr7}; | |||
Vector<float, 8> Ggt4(vgr); | |||
GI_FLOAT32_V2_t Ggt4(vgr); | |||
TRANSPOSE_8x4(Gg, Ggt); | |||
FILTER_TRANSFORM_FINAL(Ggt, result); | |||
#define cb(i) result##i.save(transform_mid_buf + i * alpha); | |||
#define cb(i) GiStoreFloat32V2(transform_mid_buf + i * alpha, result##i); | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
rep(i, alpha) rep(j, alpha) { | |||
@@ -145,31 +207,49 @@ struct FilterTransform4X5 { | |||
#undef FILTER_TRANSFORM | |||
#undef FILTER_TRANSFORM_FINAL | |||
/* | |||
* wd##0 = (d##0 - d##6) + (d##4 - d##2) * 5.25f; | |||
* auto tmp0 = d##2 - d##4 * 4.25f + d##6; | |||
* auto tmp1 = d##1 - d##3 * 4.25f + d##5; | |||
* wd##1 = tmp0 + tmp1; | |||
* wd##2 = tmp0 - tmp1; | |||
* tmp0 = d##2 * 4.0f - d##4 * 5.0f + d##6; | |||
* tmp1 = d##1 * 2.0f - d##3 * 2.5f + d##5 * 0.5f; | |||
* wd##3 = tmp0 + tmp1; | |||
* wd##4 = tmp0 - tmp1; | |||
* tmp0 = d##2 * 0.25f - d##4 * 1.25f + d##6; | |||
* tmp1 = d##1 * 0.5f - d##3 * 2.5f + d##5 * 2.0f; | |||
* wd##5 = tmp0 + tmp1; | |||
* wd##6 = tmp0 - tmp1; | |||
* wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f; | |||
*/ | |||
struct InputTransform4X5 { | |||
#define INPUT_TRANSFORM(d, wd) \ | |||
do { \ | |||
wd##0 = (d##0 - d##6) + (d##4 - d##2) * 5.25f; \ | |||
auto tmp0 = d##2 - d##4 * 4.25f + d##6; \ | |||
auto tmp1 = d##1 - d##3 * 4.25f + d##5; \ | |||
wd##1 = tmp0 + tmp1; \ | |||
wd##2 = tmp0 - tmp1; \ | |||
tmp0 = d##2 * 4.0f - d##4 * 5.0f + d##6; \ | |||
tmp1 = d##1 * 2.0f - d##3 * 2.5f + d##5 * 0.5f; \ | |||
wd##3 = tmp0 + tmp1; \ | |||
wd##4 = tmp0 - tmp1; \ | |||
tmp0 = d##2 * 0.25f - d##4 * 1.25f + d##6; \ | |||
tmp1 = d##1 * 0.5f - d##3 * 2.5f + d##5 * 2.0f; \ | |||
wd##5 = tmp0 + tmp1; \ | |||
wd##6 = tmp0 - tmp1; \ | |||
wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f; \ | |||
#define INPUT_TRANSFORM(d, wd) \ | |||
do { \ | |||
wd##0 = ADDFV2(SUBFV2(d##0, d##6), MULSFV2(SUBFV2(d##4, d##2), 5.25f)); \ | |||
auto tmp0 = ADDFV2(SUBFV2(d##2, MULSFV2(d##4, 4.25f)), d##6); \ | |||
auto tmp1 = ADDFV2(SUBFV2(d##1, MULSFV2(d##3, 4.25f)), d##5); \ | |||
wd##1 = ADDFV2(tmp0, tmp1); \ | |||
wd##2 = SUBFV2(tmp0, tmp1); \ | |||
tmp0 = ADDFV2(SUBFV2(MULSFV2(d##2, 4.0f), MULSFV2(d##4, 5.0f)), d##6); \ | |||
tmp1 = \ | |||
ADDFV2(SUBFV2(MULSFV2(d##1, 2.0f), MULSFV2(d##3, 2.5f)), \ | |||
MULSFV2(d##5, 0.5f)); \ | |||
wd##3 = ADDFV2(tmp0, tmp1); \ | |||
wd##4 = SUBFV2(tmp0, tmp1); \ | |||
tmp0 = ADDFV2(SUBFV2(MULSFV2(d##2, 0.25f), MULSFV2(d##4, 1.25f)), d##6); \ | |||
tmp1 = \ | |||
ADDFV2(SUBFV2(MULSFV2(d##1, 0.5f), MULSFV2(d##3, 2.5f)), \ | |||
MULSFV2(d##5, 2.0f)); \ | |||
wd##5 = ADDFV2(tmp0, tmp1); \ | |||
wd##6 = SUBFV2(tmp0, tmp1); \ | |||
wd##7 = ADDFV2(SUBFV2(d##7, d##1), MULSFV2(SUBFV2(d##3, d##5), 5.25f)); \ | |||
} while (0) | |||
#define GET_VECTOR_HIGH_ELEM(s, i, idx) \ | |||
GiExtractLane##idx##Float32(GiGetSubVectorFloat32V2( \ | |||
GiFixLenType2GiFloat32V2Type(CONCAT(s, i).value), 1)) | |||
#define GET_VECTOR_LOW_ELEM(s, i, idx) \ | |||
GiExtractLane##idx##Float32(GiGetSubVectorFloat32V2( \ | |||
GiFixLenType2GiFloat32V2Type(CONCAT(s, i).value), 0)) | |||
#define GET_VECTOR_HIGH_ELEM(s, i, idx) \ | |||
GiExtractLane##idx##Float32(GiGetSubVectorFloat32V2(CONCAT(s, i), 1)) | |||
#define GET_VECTOR_LOW_ELEM(s, i, idx) \ | |||
GiExtractLane##idx##Float32(GiGetSubVectorFloat32V2(CONCAT(s, i), 0)) | |||
template <bool inner> | |||
static void transform( | |||
@@ -191,13 +271,13 @@ struct InputTransform4X5 { | |||
memset(transform_mid_buf, 0, sizeof(float) * alpha * alpha); | |||
} | |||
#define cb(i) Vector<float, 8> d##i; | |||
#define cb(i) GI_FLOAT32_V2_t d##i; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
if (inner) { | |||
const float* input_ptr = input + ic * IH * IW + ih_start * IW + iw_start; | |||
#define cb(i) d##i = Vector<float, 8>::load(input_ptr + IW * i); | |||
#define cb(i) d##i = GiLoadFloat32V2(input_ptr + IW * i); | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
} else { | |||
@@ -212,21 +292,25 @@ struct InputTransform4X5 { | |||
input[ic * IH * IW + ih * IW + iw]; | |||
} | |||
} | |||
#define cb(i) d##i = Vector<float, 8>::load(transform_mid_buf + alpha * i); | |||
#define cb(i) d##i = GiLoadFloat32V2(transform_mid_buf + alpha * i); | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
} | |||
#define cb(i) Vector<float, 8> wd##i, ret##i; | |||
#define cb(i) GI_FLOAT32_V2_t wd##i; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
INPUT_TRANSFORM(d, wd); | |||
#if MEGDNN_AARCH64 | |||
#define cb(i) GI_FLOAT32_V2_t ret##i; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
TRANSPOSE_8x8(wd, d); | |||
INPUT_TRANSFORM(d, ret); | |||
#define cb(i) ret##i.save(transform_mid_buf + i * alpha); | |||
#define cb(i) GiStoreFloat32V2(transform_mid_buf + i * alpha, ret##i); | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
rep(i, alpha) rep(j, alpha) { | |||
@@ -283,12 +367,32 @@ struct InputTransform4X5 { | |||
}; | |||
#undef INPUT_TRANSFORM | |||
#define OUTPUT_TRANSFORM(m, s) \ | |||
do { \ | |||
s0 = m0 + m1 + m2 + m3 + m4 + m5 + m6; \ | |||
s1 = m1 - m2 + m3 * 0.5 - m4 * 0.5 + m5 * 2.0 - m6 * 2.0; \ | |||
s2 = m1 + m2 + m3 * 0.25 + m4 * 0.25 + m5 * 4.0 + m6 * 4.0; \ | |||
s3 = m1 - m2 + m3 * 0.125 - m4 * 0.125 + m5 * 8.0 - m6 * 8.0 + m7; \ | |||
/* | |||
* s0 = m0 + m1 + m2 + m3 + m4 + m5 + m6; | |||
* s1 = m1 - m2 + m3 * 0.5 - m4 * 0.5 + m5 * 2.0 - m6 * 2.0; | |||
* s2 = m1 + m2 + m3 * 0.25 + m4 * 0.25 + m5 * 4.0 + m6 * 4.0; | |||
* s3 = m1 - m2 + m3 * 0.125 - m4 * 0.125 + m5 * 8.0 - m6 * 8.0 + m7; | |||
*/ | |||
#define OUTPUT_TRANSFORM(m, s) \ | |||
do { \ | |||
s0 = ADDFV2( \ | |||
ADDFV2(ADDFV2(ADDFV2(ADDFV2(ADDFV2(m0, m1), m2), m3), m4), m5), m6); \ | |||
s1 = \ | |||
SUBFV2(ADDFV2(SUBFV2(ADDFV2(SUBFV2(m1, m2), MULSFV2(m3, 0.5)), \ | |||
MULSFV2(m4, 0.5)), \ | |||
MULSFV2(m5, 2.0)), \ | |||
MULSFV2(m6, 2.0)); \ | |||
s2 = \ | |||
ADDFV2(ADDFV2(ADDFV2(ADDFV2(ADDFV2(m1, m2), MULSFV2(m3, 0.25)), \ | |||
MULSFV2(m4, 0.25)), \ | |||
MULSFV2(m5, 4.0)), \ | |||
MULSFV2(m6, 4.0)); \ | |||
s3 = ADDFV2( \ | |||
SUBFV2(ADDFV2(SUBFV2(ADDFV2(SUBFV2(m1, m2), MULSFV2(m3, 0.125)), \ | |||
MULSFV2(m4, 0.125)), \ | |||
MULSFV2(m5, 8.0)), \ | |||
MULSFV2(m6, 8.0)), \ | |||
m7); \ | |||
} while (0) | |||
template <BiasMode bmode, typename Op> | |||
struct OutputTransform4X5 { | |||
@@ -316,14 +420,15 @@ struct OutputTransform4X5 { | |||
UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); | |||
#undef cb | |||
#define cb(i) auto m##i = Vector<float, 8>::load(transform_mid_buf + alpha * i); | |||
#define cb(i) auto m##i = GiLoadFloat32V2(transform_mid_buf + alpha * i); | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
#define cb(i) Vector<float, 8> s##i; | |||
#define cb(i) GI_FLOAT32_V2_t s##i; | |||
UNROLL_CALL_NOWRAPPER(4, cb); | |||
#undef cb | |||
OUTPUT_TRANSFORM(m, s); | |||
#define cb(i) \ | |||
do { \ | |||
auto add12 = GET_VECTOR_LOW_ELEM(s, i, 1) + GET_VECTOR_LOW_ELEM(s, i, 2); \ | |||
@@ -355,10 +460,10 @@ struct OutputTransform4X5 { | |||
GI_FLOAT32_t item0 = GiLoadFloat32(mid_buf1); | |||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
item0 = GiAddFloat32(item0, bias0); | |||
item0 = ADDF(item0, bias0); | |||
} else if (bmode == BiasMode::BIAS) { | |||
bias0 = GiLoadFloat32(bias + oc * OH * OW + oh * OW + ow_start); | |||
item0 = GiAddFloat32(item0, bias0); | |||
item0 = ADDF(item0, bias0); | |||
} | |||
item0 = op(item0); | |||
GiStoreFloat32(output + oc * OH * OW + oh * OW + ow_start, item0); | |||
@@ -1,7 +1,6 @@ | |||
#include "src/common/unroll_macro.h" | |||
#include "src/common/utils.h" | |||
#include "src/fallback/conv_bias/gi/fp32/strategy.h" | |||
#include "src/fallback/conv_bias/gi/utils.h" | |||
#include "src/fallback/conv_bias/winograd/winograd.h" | |||
#include "src/fallback/conv_bias/gi/fp32/helper.h" | |||
@@ -15,23 +14,39 @@ using namespace megdnn; | |||
using namespace fallback; | |||
namespace { | |||
/* | |||
* wd##0 = d##0; | |||
* auto tmp0 = d##0 * 0.7111111f + d##2 * 0.1777778f; | |||
* auto tmp1 = d##1 * 0.3555556f + d##3 * 0.0888889f; | |||
* wd##1 = tmp0 + tmp1; | |||
* wd##2 = tmp0 - tmp1; | |||
* tmp0 = (d##0 + d##2) * -0.2222222f; | |||
* tmp1 = (d##1 + d##3) * -0.2222222f; | |||
* wd##3 = tmp0 + tmp1; | |||
* wd##4 = tmp0 - tmp1; | |||
* tmp0 = d##0 * 0.0111111f + d##2 * 0.0444444f; | |||
* tmp1 = d##1 * 0.0222222f + d##3 * 0.0888889f; | |||
* wd##5 = tmp0 + tmp1; | |||
* wd##6 = tmp0 - tmp1; | |||
* wd##7 = d##3; | |||
*/ | |||
struct FilterTransform5X4 { | |||
#define FILTER_TRANSFORM(d, wd) \ | |||
do { \ | |||
wd##0 = d##0; \ | |||
auto tmp0 = d##0 * 0.7111111f + d##2 * 0.1777778f; \ | |||
auto tmp1 = d##1 * 0.3555556f + d##3 * 0.0888889f; \ | |||
wd##1 = tmp0 + tmp1; \ | |||
wd##2 = tmp0 - tmp1; \ | |||
tmp0 = (d##0 + d##2) * -0.2222222f; \ | |||
tmp1 = (d##1 + d##3) * -0.2222222f; \ | |||
wd##3 = tmp0 + tmp1; \ | |||
wd##4 = tmp0 - tmp1; \ | |||
tmp0 = d##0 * 0.0111111f + d##2 * 0.0444444f; \ | |||
tmp1 = d##1 * 0.0222222f + d##3 * 0.0888889f; \ | |||
wd##5 = tmp0 + tmp1; \ | |||
wd##6 = tmp0 - tmp1; \ | |||
wd##7 = d##3; \ | |||
#define FILTER_TRANSFORM(d, wd, ADDC, SUBC, MULC) \ | |||
do { \ | |||
wd##0 = d##0; \ | |||
auto tmp0 = ADDC(MULC(d##0, 0.7111111f), MULC(d##2, 0.1777778f)); \ | |||
auto tmp1 = ADDC(MULC(d##1, 0.3555556f), MULC(d##3, 0.0888889f)); \ | |||
wd##1 = ADDC(tmp0, tmp1); \ | |||
wd##2 = SUBC(tmp0, tmp1); \ | |||
tmp0 = MULC(ADDC(d##0, d##2), -0.2222222f); \ | |||
tmp1 = MULC(ADDC(d##1, d##3), -0.2222222f); \ | |||
wd##3 = ADDC(tmp0, tmp1); \ | |||
wd##4 = SUBC(tmp0, tmp1); \ | |||
tmp0 = ADDC(MULC(d##0, 0.0111111f), MULC(d##2, 0.0444444f)); \ | |||
tmp1 = ADDC(MULC(d##1, 0.0222222f), MULC(d##3, 0.0888889f)); \ | |||
wd##5 = ADDC(tmp0, tmp1); \ | |||
wd##6 = SUBC(tmp0, tmp1); \ | |||
wd##7 = d##3; \ | |||
} while (0) | |||
static void transform( | |||
@@ -53,28 +68,27 @@ struct FilterTransform5X4 { | |||
rep(ic, IC) { | |||
const float* fptr = filter + (oc * IC + ic) * 4 * 4; | |||
#define cb(i) Vector<float, 4> g##i = Vector<float, 4>::load(fptr + 4 * i); | |||
#define cb(i) GI_FLOAT32_t g##i = GiLoadFloat32(fptr + 4 * i); | |||
UNROLL_CALL_NOWRAPPER(4, cb); | |||
#undef cb | |||
#define cb(i) Vector<float, 4> wd##i; | |||
#define cb(i) GI_FLOAT32_t wd##i; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
#define cb(i) Vector<float, 8> wdt##i; | |||
FILTER_TRANSFORM(g, wd, ADDF, SUBF, MULSF); | |||
#if MEGDNN_AARCH64 | |||
#define cb(i) GI_FLOAT32_V2_t wdt##i; | |||
UNROLL_CALL_NOWRAPPER(4, cb); | |||
#undef cb | |||
#define cb(i) Vector<float, 8> ret##i; | |||
#define cb(i) GI_FLOAT32_V2_t ret##i; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
FILTER_TRANSFORM(g, wd); | |||
#if MEGDNN_AARCH64 | |||
TRANSPOSE_8x4(wd, wdt); | |||
FILTER_TRANSFORM(wdt, ret); | |||
FILTER_TRANSFORM(wdt, ret, ADDFV2, SUBFV2, MULSFV2); | |||
#define cb(i) ret##i.save(transform_mid_buf + i * alpha); | |||
#define cb(i) GiStoreFloat32V2(transform_mid_buf + i * alpha, ret##i); | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
rep(i, alpha) rep(j, alpha) { | |||
@@ -104,8 +118,7 @@ struct FilterTransform5X4 { | |||
mid_buf1[7] = GET_VECTOR_ELEM(wd, i, 3); \ | |||
mid_buf1 += 8; \ | |||
} while (0); | |||
#define GET_VECTOR_ELEM(s, i, idx) \ | |||
GiExtractLane##idx##Float32(GiFixLenType2GiFloat32Type(CONCAT(s, i).value)) | |||
#define GET_VECTOR_ELEM(s, i, idx) GiExtractLane##idx##Float32(CONCAT(s, i)) | |||
float* mid_buf1 = transform_mid_buf; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
@@ -123,29 +136,49 @@ struct FilterTransform5X4 { | |||
#undef FILTER_TRANSFORM | |||
#undef GET_VECTOR_ELEM | |||
/* | |||
* wd##0 = (d##0 - d##6) + (d##4 - d##2) * 5.25f; | |||
* auto tmp0 = d##2 * 4.0f - d##4 * 5.0f + d##6; | |||
* auto tmp1 = d##1 * 2.0f - d##3 * 2.5f + d##5 * 0.5f; | |||
* wd##1 = tmp0 + tmp1; | |||
* wd##2 = tmp0 - tmp1; | |||
* tmp0 = d##2 - d##4 * 4.25f + d##6; | |||
* tmp1 = d##1 - d##3 * 4.25f + d##5; | |||
* wd##3 = tmp0 + tmp1; | |||
* wd##4 = tmp0 - tmp1; | |||
* tmp0 = d##2 * 0.25f - d##4 * 1.25f + d##6; | |||
* tmp1 = d##1 * 0.5f - d##3 * 2.5f + d##5 * 2.0f; | |||
* wd##5 = tmp0 + tmp1; | |||
* wd##6 = tmp0 - tmp1; | |||
* wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f; | |||
*/ | |||
struct InputTransform5X4 { | |||
#define INPUT_TRANSFORM(d, wd) \ | |||
do { \ | |||
wd##0 = (d##0 - d##6) + (d##4 - d##2) * 5.25f; \ | |||
auto tmp0 = d##2 * 4.0f - d##4 * 5.0f + d##6; \ | |||
auto tmp1 = d##1 * 2.0f - d##3 * 2.5f + d##5 * 0.5f; \ | |||
wd##1 = tmp0 + tmp1; \ | |||
wd##2 = tmp0 - tmp1; \ | |||
tmp0 = d##2 - d##4 * 4.25f + d##6; \ | |||
tmp1 = d##1 - d##3 * 4.25f + d##5; \ | |||
wd##3 = tmp0 + tmp1; \ | |||
wd##4 = tmp0 - tmp1; \ | |||
tmp0 = d##2 * 0.25f - d##4 * 1.25f + d##6; \ | |||
tmp1 = d##1 * 0.5f - d##3 * 2.5f + d##5 * 2.0f; \ | |||
wd##5 = tmp0 + tmp1; \ | |||
wd##6 = tmp0 - tmp1; \ | |||
wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f; \ | |||
#define INPUT_TRANSFORM(d, wd) \ | |||
do { \ | |||
wd##0 = ADDFV2(SUBFV2(d##0, d##6), MULSFV2(SUBFV2(d##4, d##2), 5.25f)); \ | |||
auto tmp0 = ADDFV2(SUBFV2(MULSFV2(d##2, 4.0f), MULSFV2(d##4, 5.0f)), d##6); \ | |||
auto tmp1 = \ | |||
ADDFV2(SUBFV2(MULSFV2(d##1, 2.0f), MULSFV2(d##3, 2.5f)), \ | |||
MULSFV2(d##5, 0.5f)); \ | |||
wd##1 = ADDFV2(tmp0, tmp1); \ | |||
wd##2 = SUBFV2(tmp0, tmp1); \ | |||
tmp0 = ADDFV2(SUBFV2(d##2, MULSFV2(d##4, 4.25f)), d##6); \ | |||
tmp1 = ADDFV2(SUBFV2(d##1, MULSFV2(d##3, 4.25f)), d##5); \ | |||
wd##3 = ADDFV2(tmp0, tmp1); \ | |||
wd##4 = SUBFV2(tmp0, tmp1); \ | |||
tmp0 = ADDFV2(SUBFV2(MULSFV2(d##2, 0.25f), MULSFV2(d##4, 1.25f)), d##6); \ | |||
tmp1 = \ | |||
ADDFV2(SUBFV2(MULSFV2(d##1, 0.5f), MULSFV2(d##3, 2.5f)), \ | |||
MULSFV2(d##5, 2.0f)); \ | |||
wd##5 = ADDFV2(tmp0, tmp1); \ | |||
wd##6 = SUBFV2(tmp0, tmp1); \ | |||
wd##7 = ADDFV2(SUBFV2(d##7, d##1), MULSFV2(SUBFV2(d##3, d##5), 5.25f)); \ | |||
} while (0) | |||
#define GET_VECTOR_HIGH_ELEM(s, i, idx) \ | |||
GiExtractLane##idx##Float32(GiFixLenType2GiFloat32Type(CONCAT(s, i).value.val[1])) | |||
GiExtractLane##idx##Float32(GiGetSubVectorFloat32V2(CONCAT(s, i), 1)) | |||
#define GET_VECTOR_LOW_ELEM(s, i, idx) \ | |||
GiExtractLane##idx##Float32(GiFixLenType2GiFloat32Type(CONCAT(s, i).value.val[0])) | |||
GiExtractLane##idx##Float32(GiGetSubVectorFloat32V2(CONCAT(s, i), 0)) | |||
template <bool inner> | |||
static void transform( | |||
@@ -168,13 +201,13 @@ struct InputTransform5X4 { | |||
memset(transform_mid_buf, 0, sizeof(float) * alpha * alpha); | |||
} | |||
#define cb(i) Vector<float, 8> d##i; | |||
#define cb(i) GI_FLOAT32_V2_t d##i; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
if (inner) { | |||
const float* input_ptr = input + ic * IH * IW + ih_start * IW + iw_start; | |||
#define cb(i) d##i = Vector<float, 8>::load(input_ptr + IW * i); | |||
#define cb(i) d##i = GiLoadFloat32V2(input_ptr + IW * i); | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
} else { | |||
@@ -189,21 +222,25 @@ struct InputTransform5X4 { | |||
input[ic * IH * IW + ih * IW + iw]; | |||
} | |||
} | |||
#define cb(i) d##i = Vector<float, 8>::load(transform_mid_buf + alpha * i); | |||
#define cb(i) d##i = GiLoadFloat32V2(transform_mid_buf + alpha * i); | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
} | |||
#define cb(i) Vector<float, 8> wd##i, ret##i; | |||
#define cb(i) GI_FLOAT32_V2_t wd##i; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
INPUT_TRANSFORM(d, wd); | |||
#if MEGDNN_AARCH64 | |||
#define cb(i) GI_FLOAT32_V2_t ret##i; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
TRANSPOSE_8x8(wd, d); | |||
INPUT_TRANSFORM(d, ret); | |||
#define cb(i) ret##i.save(transform_mid_buf + i * alpha); | |||
#define cb(i) GiStoreFloat32V2(transform_mid_buf + i * alpha, ret##i); | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
rep(i, alpha) rep(j, alpha) { | |||
@@ -260,24 +297,48 @@ struct InputTransform5X4 { | |||
}; | |||
#undef INPUT_TRANSFORM | |||
#define OUTPUT_TRANSFORM(m, s) \ | |||
do { \ | |||
auto m1addm2 = m##1 + m##2; \ | |||
auto m1subm2 = m##1 - m##2; \ | |||
auto m3addm4 = m##3 + m##4; \ | |||
auto m3subm4 = m##3 - m##4; \ | |||
auto m5addm6 = (m##5 + m##6); \ | |||
auto m5subm6 = (m##5 - m##6); \ | |||
s##0 = m##0; \ | |||
CONCAT(s, 0).add(m1addm2).add(m3addm4).add(m5addm6); \ | |||
CONCAT(s, 1) = m3subm4; \ | |||
CONCAT(s, 1).mla(m1subm2, 0.5f).mla(m5subm6, 2.0f); \ | |||
CONCAT(s, 2) = m3addm4; \ | |||
CONCAT(s, 2).mla(m1addm2, 0.25f).mla(m5addm6, 4.0f); \ | |||
CONCAT(s, 3) = m3subm4; \ | |||
CONCAT(s, 3).mla(m1subm2, 0.125f).mla(m5subm6, 8.0f); \ | |||
CONCAT(s, 4) = m##7; \ | |||
CONCAT(s, 4).mla(m1addm2, 0.0625f).add(m3addm4).mla(m5addm6, 16.0f); \ | |||
/* | |||
* auto m1addm2 = m##1 + m##2; | |||
* auto m1subm2 = m##1 - m##2; | |||
* auto m3addm4 = m##3 + m##4; | |||
* auto m3subm4 = m##3 - m##4; | |||
* auto m5addm6 = (m##5 + m##6); | |||
* auto m5subm6 = (m##5 - m##6); | |||
* s##0 = m##0; | |||
* CONCAT(s, 0).add(m1addm2).add(m3addm4).add(m5addm6); | |||
* CONCAT(s, 1) = m3subm4; | |||
* CONCAT(s, 1).mla(m1subm2, 0.5f).mla(m5subm6, 2.0f); | |||
* CONCAT(s, 2) = m3addm4; | |||
* CONCAT(s, 2).mla(m1addm2, 0.25f).mla(m5addm6, 4.0f); | |||
* CONCAT(s, 3) = m3subm4; | |||
* CONCAT(s, 3).mla(m1subm2, 0.125f).mla(m5subm6, 8.0f); | |||
* CONCAT(s, 4) = m##7; | |||
* CONCAT(s, 4).mla(m1addm2, 0.0625f).add(m3addm4).mla(m5addm6, 16.0f); | |||
*/ | |||
#define OUTPUT_TRANSFORM(m, s) \ | |||
do { \ | |||
auto m1addm2 = ADDFV2(m##1, m##2); \ | |||
auto m1subm2 = SUBFV2(m##1, m##2); \ | |||
auto m3addm4 = ADDFV2(m##3, m##4); \ | |||
auto m3subm4 = SUBFV2(m##3, m##4); \ | |||
auto m5addm6 = ADDFV2(m##5, m##6); \ | |||
auto m5subm6 = SUBFV2(m##5, m##6); \ | |||
s##0 = m##0; \ | |||
CONCAT(s, 0) = \ | |||
ADDFV2(ADDFV2(ADDFV2(CONCAT(s, 0), m1addm2), m3addm4), m5addm6); \ | |||
CONCAT(s, 1) = m3subm4; \ | |||
CONCAT(s, 1) = GiMultiplyAddScalarFloat32V2(CONCAT(s, 1), m1subm2, 0.5f); \ | |||
CONCAT(s, 1) = GiMultiplyAddScalarFloat32V2(CONCAT(s, 1), m5subm6, 2.0f); \ | |||
CONCAT(s, 2) = m3addm4; \ | |||
CONCAT(s, 2) = GiMultiplyAddScalarFloat32V2(CONCAT(s, 2), m1addm2, 0.25f); \ | |||
CONCAT(s, 2) = GiMultiplyAddScalarFloat32V2(CONCAT(s, 2), m5addm6, 4.0f); \ | |||
CONCAT(s, 3) = m3subm4; \ | |||
CONCAT(s, 3) = GiMultiplyAddScalarFloat32V2(CONCAT(s, 3), m1subm2, 0.125f); \ | |||
CONCAT(s, 3) = GiMultiplyAddScalarFloat32V2(CONCAT(s, 3), m5subm6, 8.0f); \ | |||
CONCAT(s, 4) = m##7; \ | |||
CONCAT(s, 4) = GiMultiplyAddScalarFloat32V2(CONCAT(s, 4), m1addm2, 0.0625f); \ | |||
CONCAT(s, 4) = ADDFV2(CONCAT(s, 4), m3addm4); \ | |||
CONCAT(s, 4) = GiMultiplyAddScalarFloat32V2(CONCAT(s, 4), m5addm6, 16.0f); \ | |||
} while (0) | |||
#if defined(__GNUC__) && !defined(__llvm__) && !defined(_MSC_VER) | |||
@@ -314,11 +375,11 @@ struct OutputTransform5X4 { | |||
UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); | |||
#undef cb | |||
#define cb(i) auto m##i = Vector<float, 8>::load(transform_mid_buf + alpha * i); | |||
#define cb(i) auto m##i = GiLoadFloat32V2(transform_mid_buf + alpha * i); | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
#define cb(i) Vector<float, 8> s##i, ret##i; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#define cb(i) GI_FLOAT32_V2_t s##i; | |||
UNROLL_CALL_NOWRAPPER(5, cb); | |||
#undef cb | |||
OUTPUT_TRANSFORM(m, s); | |||
@@ -3,7 +3,6 @@ | |||
#include "src/fallback/conv_bias/gi/fp32/filter_transform.h" | |||
#include "src/fallback/conv_bias/gi/fp32/helper.h" | |||
#include "src/fallback/conv_bias/gi/fp32/strategy.h" | |||
#include "src/fallback/conv_bias/gi/utils.h" | |||
#include "src/fallback/conv_bias/winograd/winograd.h" | |||
#include "src/fallback/elemwise_helper/op_unary.h" | |||
#include "src/naive/matrix_mul/matrix_mul_helper.h" | |||
@@ -27,28 +26,31 @@ namespace { | |||
* wd6 = (d6 - 5.0 * d4 + 4.0 * d2) - 2.0 * (d1 + 0.25 * d5 - 1.25 * d3) | |||
* wd7 = (d7 - d1) + 5.25 * (d3 - d5) | |||
*/ | |||
#define INPUT_TRANSFORM(d, wd) \ | |||
do { \ | |||
wd##0 = (d##0 - d##6) + (d##4 - d##2) * 5.25f; \ | |||
auto tmp0 = d##6 + d##2 - d##4 * 4.25f; \ | |||
auto tmp1 = d##1 + d##5 - d##3 * 4.25f; \ | |||
wd##1 = tmp0 + tmp1; \ | |||
wd##2 = tmp0 - tmp1; \ | |||
tmp0 = d##6 + d##2 * 0.25f - d##4 * 1.25f; \ | |||
tmp1 = (d##5 + d##1 * 0.25f - d##3 * 1.25f) * 2.0f; \ | |||
wd##3 = tmp0 + tmp1; \ | |||
wd##4 = tmp0 - tmp1; \ | |||
tmp0 = d6 - d4 * 5.0f + d2 * 4.0f; \ | |||
tmp1 = (d1 + d5 * 0.25f - d3 * 1.25f) * 2.0f; \ | |||
wd##5 = tmp0 + tmp1; \ | |||
wd##6 = tmp0 - tmp1; \ | |||
wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f; \ | |||
#define INPUT_TRANSFORM(d, wd) \ | |||
do { \ | |||
wd##0 = ADDFV2(SUBFV2(d##0, d##6), MULSFV2(SUBFV2(d##4, d##2), 5.25f)); \ | |||
auto tmp0 = SUBFV2(ADDFV2(d##6, d##2), MULSFV2(d##4, 4.25f)); \ | |||
auto tmp1 = SUBFV2(ADDFV2(d##1, d##5), MULSFV2(d##3, 4.25f)); \ | |||
wd##1 = ADDFV2(tmp0, tmp1); \ | |||
wd##2 = SUBFV2(tmp0, tmp1); \ | |||
tmp0 = SUBFV2(ADDFV2(d##6, MULSFV2(d##2, 0.25f)), MULSFV2(d##4, 1.25f)); \ | |||
tmp1 = MULSFV2( \ | |||
(SUBFV2(ADDFV2(d##5, MULSFV2(d##1, 0.25f)), MULSFV2(d##3, 1.25f))), \ | |||
2.0f); \ | |||
wd##3 = ADDFV2(tmp0, tmp1); \ | |||
wd##4 = SUBFV2(tmp0, tmp1); \ | |||
tmp0 = ADDFV2(SUBFV2(d6, MULSFV2(d4, 5.0f)), MULSFV2(d2, 4.0f)); \ | |||
tmp1 = MULSFV2( \ | |||
(SUBFV2(ADDFV2(d1, MULSFV2(d5, 0.25f)), MULSFV2(d3, 1.25f))), 2.0f); \ | |||
wd##5 = ADDFV2(tmp0, tmp1); \ | |||
wd##6 = SUBFV2(tmp0, tmp1); \ | |||
wd##7 = ADDFV2(SUBFV2(d##7, d##1), MULSFV2(SUBFV2(d##3, d##5), 5.25f)); \ | |||
} while (0); | |||
#define GET_VECTOR_HIGH_ELEM(s, i, idx) \ | |||
GiExtractLane##idx##Float32(GiFixLenType2GiFloat32Type(CONCAT(s, i).value.val[1])) | |||
GiExtractLane##idx##Float32(GiGetSubVectorFloat32V2(CONCAT(s, i), 1)) | |||
#define GET_VECTOR_LOW_ELEM(s, i, idx) \ | |||
GiExtractLane##idx##Float32(GiFixLenType2GiFloat32Type(CONCAT(s, i).value.val[0])) | |||
GiExtractLane##idx##Float32(GiGetSubVectorFloat32V2(CONCAT(s, i), 0)) | |||
struct InputTransform6X3 { | |||
template <bool inner> | |||
static void transform( | |||
@@ -60,13 +62,13 @@ struct InputTransform6X3 { | |||
memset(transform_mid_buf, 0, sizeof(float) * alpha * alpha); | |||
} | |||
#define cb(i) Vector<float, 8> d##i; | |||
#define cb(i) GI_FLOAT32_V2_t d##i; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
if (inner) { | |||
const float* input_ptr = input + ic * IH * IW + ih_start * IW + iw_start; | |||
#define cb(i) d##i = Vector<float, 8>::load(input_ptr + IW * i); | |||
#define cb(i) d##i = GiLoadFloat32V2(input_ptr + IW * i); | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
} else { | |||
@@ -81,22 +83,25 @@ struct InputTransform6X3 { | |||
input[ic * IH * IW + ih * IW + iw]; | |||
} | |||
} | |||
#define cb(i) d##i = Vector<float, 8>::load(transform_mid_buf + alpha * i); | |||
#define cb(i) d##i = GiLoadFloat32V2(transform_mid_buf + alpha * i); | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
} | |||
#define cb(i) Vector<float, 8> wd##i, ret##i; | |||
#define cb(i) GI_FLOAT32_V2_t wd##i; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
INPUT_TRANSFORM(d, wd); | |||
#if MEGDNN_AARCH64 | |||
#define cb(i) GI_FLOAT32_V2_t ret##i; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
TRANSPOSE_8x8(wd, d); | |||
INPUT_TRANSFORM(d, ret); | |||
#define cb(i) ret##i.save(transform_mid_buf + i * alpha); | |||
#define cb(i) GiStoreFloat32V2(transform_mid_buf + i * alpha, ret##i); | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
@@ -174,26 +179,34 @@ struct InputTransform6X3 { | |||
* s4 = (m1 + m2) + 16 * (m3 + m4) + 2 * (m5 + m6) / 32 | |||
* s5 = (m1 - m2) + 32 * (m3 - m4) + (m5 - m6) / 32 + m7 | |||
*/ | |||
#define OUTPUT_TRANSFORM(m, s) \ | |||
do { \ | |||
auto m1addm2 = m##1 + m##2; \ | |||
auto m1subm2 = m##1 - m##2; \ | |||
auto m3addm4 = m##3 + m##4; \ | |||
auto m3subm4 = m##3 - m##4; \ | |||
auto m5addm6 = (m##5 + m##6) * 0.03125f; \ | |||
auto m5subm6 = (m##5 - m##6) * 0.03125f; \ | |||
s##0 = m##0; \ | |||
CONCAT(s, 0).mla(m5addm6, 32.f).add(m3addm4).add(m1addm2); \ | |||
CONCAT(s, 1) = m1subm2; \ | |||
CONCAT(s, 1).mla(m3subm4, 2.f).mla(m5subm6, 16.f); \ | |||
CONCAT(s, 2) = m1addm2; \ | |||
CONCAT(s, 2).mla(m3addm4, 4.f).mla(m5addm6, 8.f); \ | |||
CONCAT(s, 3) = m1subm2; \ | |||
CONCAT(s, 3).mla(m3subm4, 8.f).mla(m5subm6, 4.f); \ | |||
CONCAT(s, 4) = m1addm2; \ | |||
CONCAT(s, 4).mla(m3addm4, 16.f).mla(m5addm6, 2.f); \ | |||
CONCAT(s, 5) = m1subm2; \ | |||
CONCAT(s, 5).mla(m3subm4, 32.f).add(m5subm6).add(m##7); \ | |||
#define OUTPUT_TRANSFORM(m, s) \ | |||
do { \ | |||
auto m1addm2 = ADDFV2(m##1, m##2); \ | |||
auto m1subm2 = SUBFV2(m##1, m##2); \ | |||
auto m3addm4 = ADDFV2(m##3, m##4); \ | |||
auto m3subm4 = SUBFV2(m##3, m##4); \ | |||
auto m5addm6 = MULSFV2((ADDFV2(m##5, m##6)), 0.03125f); \ | |||
auto m5subm6 = MULSFV2((SUBFV2(m##5, m##6)), 0.03125f); \ | |||
s##0 = m##0; \ | |||
CONCAT(s, 0) = GiMultiplyAddScalarFloat32V2(CONCAT(s, 0), m5addm6, 32.f); \ | |||
CONCAT(s, 0) = ADDFV2(CONCAT(s, 0), m3addm4); \ | |||
CONCAT(s, 0) = ADDFV2(CONCAT(s, 0), m1addm2); \ | |||
CONCAT(s, 1) = m1subm2; \ | |||
CONCAT(s, 1) = GiMultiplyAddScalarFloat32V2(CONCAT(s, 1), m3subm4, 2.f); \ | |||
CONCAT(s, 1) = GiMultiplyAddScalarFloat32V2(CONCAT(s, 1), m5subm6, 16.f); \ | |||
CONCAT(s, 2) = m1addm2; \ | |||
CONCAT(s, 2) = GiMultiplyAddScalarFloat32V2(CONCAT(s, 2), m3addm4, 4.f); \ | |||
CONCAT(s, 2) = GiMultiplyAddScalarFloat32V2(CONCAT(s, 2), m5addm6, 8.f); \ | |||
CONCAT(s, 3) = m1subm2; \ | |||
CONCAT(s, 3) = GiMultiplyAddScalarFloat32V2(CONCAT(s, 3), m3subm4, 8.f); \ | |||
CONCAT(s, 3) = GiMultiplyAddScalarFloat32V2(CONCAT(s, 3), m5subm6, 4.f); \ | |||
CONCAT(s, 4) = m1addm2; \ | |||
CONCAT(s, 4) = GiMultiplyAddScalarFloat32V2(CONCAT(s, 4), m3addm4, 16.f); \ | |||
CONCAT(s, 4) = GiMultiplyAddScalarFloat32V2(CONCAT(s, 4), m5addm6, 2.f); \ | |||
CONCAT(s, 5) = m1subm2; \ | |||
CONCAT(s, 5) = GiMultiplyAddScalarFloat32V2(CONCAT(s, 5), m3subm4, 32.f); \ | |||
CONCAT(s, 5) = ADDFV2(CONCAT(s, 5), m5subm6); \ | |||
CONCAT(s, 5) = ADDFV2(CONCAT(s, 5), m##7); \ | |||
} while (0); | |||
#if defined(__GNUC__) && !defined(__llvm__) && !defined(_MSC_VER) | |||
@@ -224,11 +237,11 @@ struct OutputTransform6X3 { | |||
UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); | |||
#undef cb | |||
#define cb(i) auto m##i = Vector<float, 8>::load(transform_mid_buf + alpha * i); | |||
#define cb(i) auto m##i = GiLoadFloat32V2(transform_mid_buf + alpha * i); | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
#define cb(i) Vector<float, 8> s##i, ret##i; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#define cb(i) GI_FLOAT32_V2_t s##i; | |||
UNROLL_CALL_NOWRAPPER(6, cb); | |||
#undef cb | |||
OUTPUT_TRANSFORM(m, s); | |||
@@ -4,7 +4,6 @@ | |||
#include "src/fallback/conv_bias/gi/fp32/filter_transform.h" | |||
#include "src/fallback/conv_bias/gi/fp32/helper.h" | |||
#include "src/fallback/conv_bias/gi/fp32/strategy.h" | |||
#include "src/fallback/conv_bias/gi/utils.h" | |||
#include "src/fallback/conv_bias/winograd/winograd.h" | |||
#include "src/fallback/elemwise_helper/op_unary.h" | |||
@@ -76,8 +75,7 @@ struct InputTransform6X3 { | |||
size_t nr_units_in_tile, size_t ic, size_t IC) { | |||
constexpr size_t alpha = 6 + 3 - 1; | |||
// BT * d * B | |||
#define cb(m, n) \ | |||
Vector<float, 4> d##m##n = Vector<float, 4>::load(patchT + m * 8 * 4 + n * 4); | |||
#define cb(m, n) GI_FLOAT32_t d##m##n = GiLoadFloat32(patchT + m * 8 * 4 + n * 4); | |||
UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); | |||
#undef cb | |||
@@ -91,36 +89,103 @@ struct InputTransform6X3 { | |||
//! 0 1 -1 2 -2 0.5 -0.5 -5.25 | |||
//! -1 1 1 1 1 1 1 0 | |||
//! 0 0 0 0 0 0 0 1 | |||
/* | |||
* auto t0##m = d0##m + (d4##m - d2##m) * 5.25f - d6##m; | |||
* auto t1##m = d1##m + d2##m + d5##m + d6##m - (d3##m + d4##m) * 4.25f; | |||
* auto t2##m = d2##m + d6##m - (d1##m + d5##m) + (d3##m - d4##m) * 4.25f; | |||
* auto t3##m = d1##m * 0.5f + d2##m * 0.25f - d3##m * 2.5f - d4##m * 1.25f + | |||
* d5##m * 2.f + d6##m; | |||
* auto t4##m = d1##m * (-0.5f) + d2##m * 0.25f + d3##m * 2.5f - d4##m * 1.25f - | |||
* d5##m * 2.f + d6##m; | |||
* auto t5##m = d1##m * 2.f + d2##m * 4.f - d3##m * 2.5f - d4##m * 5.f + | |||
* d5##m * 0.5f + d6##m; | |||
* auto t6##m = d1##m * (-2.f) + d2##m * 4.f + d3##m * 2.5f - d4##m * 5.f - | |||
* d5##m * 0.5f + d6##m; | |||
* auto t7##m = (d7##m - d1##m) + (d3##m - d5##m) * 5.25f; | |||
*/ | |||
#define cb(m) \ | |||
auto t0##m = d0##m + (d4##m - d2##m) * 5.25f - d6##m; \ | |||
auto t1##m = d1##m + d2##m + d5##m + d6##m - (d3##m + d4##m) * 4.25f; \ | |||
auto t2##m = d2##m + d6##m - (d1##m + d5##m) + (d3##m - d4##m) * 4.25f; \ | |||
auto t3##m = d1##m * 0.5f + d2##m * 0.25f - d3##m * 2.5f - d4##m * 1.25f + \ | |||
d5##m * 2.f + d6##m; \ | |||
auto t4##m = d1##m * (-0.5f) + d2##m * 0.25f + d3##m * 2.5f - d4##m * 1.25f - \ | |||
d5##m * 2.f + d6##m; \ | |||
auto t5##m = d1##m * 2.f + d2##m * 4.f - d3##m * 2.5f - d4##m * 5.f + \ | |||
d5##m * 0.5f + d6##m; \ | |||
auto t6##m = d1##m * (-2.f) + d2##m * 4.f + d3##m * 2.5f - d4##m * 5.f - \ | |||
d5##m * 0.5f + d6##m; \ | |||
auto t7##m = (d7##m - d1##m) + (d3##m - d5##m) * 5.25f; | |||
auto t0##m = SUBF(ADDF(d0##m, MULSF(SUBF(d4##m, d2##m), 5.25f)), d6##m); \ | |||
auto t1##m = \ | |||
SUBF(ADDF(ADDF(ADDF(d1##m, d2##m), d5##m), d6##m), \ | |||
MULSF(ADDF(d3##m, d4##m), 4.25f)); \ | |||
auto t2##m = \ | |||
ADDF(SUBF(ADDF(d2##m, d6##m), ADDF(d1##m, d5##m)), \ | |||
MULSF(SUBF(d3##m, d4##m), 4.25f)); \ | |||
auto t3##m = \ | |||
ADDF(ADDF(SUBF(SUBF(ADDF(MULSF(d1##m, 0.5f), MULSF(d2##m, 0.25f)), \ | |||
MULSF(d3##m, 2.5f)), \ | |||
MULSF(d4##m, 1.25f)), \ | |||
MULSF(d5##m, 2.f)), \ | |||
d6##m); \ | |||
auto t4##m = \ | |||
ADDF(SUBF(SUBF(ADDF(ADDF(MULSF(d1##m, (-0.5f)), MULSF(d2##m, 0.25f)), \ | |||
MULSF(d3##m, 2.5f)), \ | |||
MULSF(d4##m, 1.25f)), \ | |||
MULSF(d5##m, 2.f)), \ | |||
d6##m); \ | |||
auto t5##m = \ | |||
ADDF(ADDF(SUBF(SUBF(ADDF(MULSF(d1##m, 2.f), MULSF(d2##m, 4.f)), \ | |||
MULSF(d3##m, 2.5f)), \ | |||
MULSF(d4##m, 5.f)), \ | |||
MULSF(d5##m, 0.5f)), \ | |||
d6##m); \ | |||
auto t6##m = \ | |||
ADDF(SUBF(SUBF(ADDF(ADDF(MULSF(d1##m, (-2.f)), MULSF(d2##m, 4.f)), \ | |||
MULSF(d3##m, 2.5f)), \ | |||
MULSF(d4##m, 5.f)), \ | |||
MULSF(d5##m, 0.5f)), \ | |||
d6##m); \ | |||
auto t7##m = ADDF(SUBF(d7##m, d1##m), MULSF(SUBF(d3##m, d5##m), 5.25f)); | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
#define cb(m) \ | |||
d##m##0 = t##m##0 + (t##m##4 - t##m##2) * 5.25f - t##m##6; \ | |||
d##m##1 = t##m##1 + t##m##2 + t##m##5 + t##m##6 - (t##m##3 + t##m##4) * 4.25f; \ | |||
d##m##2 = t##m##2 + t##m##6 - (t##m##1 + t##m##5) + (t##m##3 - t##m##4) * 4.25f; \ | |||
d##m##3 = t##m##1 * 0.5f + t##m##2 * 0.25f - t##m##3 * 2.5f - t##m##4 * 1.25f + \ | |||
t##m##5 * 2.f + t##m##6; \ | |||
d##m##4 = t##m##1 * (-0.5f) + t##m##2 * 0.25f + t##m##3 * 2.5f - t##m##4 * 1.25f - \ | |||
t##m##5 * 2.f + t##m##6; \ | |||
d##m##5 = t##m##1 * 2.f + t##m##2 * 4.f - t##m##3 * 2.5f - t##m##4 * 5.f + \ | |||
t##m##5 * 0.5f + t##m##6; \ | |||
d##m##6 = t##m##1 * (-2.f) + t##m##2 * 4.f + t##m##3 * 2.5f - t##m##4 * 5.f - \ | |||
t##m##5 * 0.5f + t##m##6; \ | |||
d##m##7 = (t##m##7 - t##m##1) + (t##m##3 - t##m##5) * 5.25f; | |||
/* | |||
* d##m##0 = t##m##0 + (t##m##4 - t##m##2) * 5.25f - t##m##6; | |||
* d##m##1 = t##m##1 + t##m##2 + t##m##5 + t##m##6 - (t##m##3 + t##m##4) | |||
* * 4.25f; d##m##2 = t##m##2 + t##m##6 - (t##m##1 + t##m##5) + (t##m##3 - | |||
* t##m##4) * 4.25f; d##m##3 = t##m##1 * 0.5f + t##m##2 * 0.25f - t##m##3 * 2.5f | |||
* - t##m##4 * 1.25f + t##m##5 * 2.f + t##m##6; d##m##4 = t##m##1 * (-0.5f) + | |||
* t##m##2 * 0.25f + t##m##3 * 2.5f - t##m##4 * 1.25f - t##m##5 * 2.f + t##m##6; | |||
* d##m##5 = t##m##1 * 2.f + t##m##2 * 4.f - t##m##3 * 2.5f - t##m##4 * 5.f + | |||
* t##m##5 * 0.5f + t##m##6; | |||
* d##m##6 = t##m##1 * (-2.f) + t##m##2 * 4.f + t##m##3 * 2.5f - t##m##4 * 5.f - | |||
* t##m##5 * 0.5f + t##m##6; | |||
* d##m##7 = (t##m##7 - t##m##1) + (t##m##3 - t##m##5) * 5.25f; | |||
*/ | |||
#define cb(m) \ | |||
d##m##0 = SUBF(ADDF(t##m##0, MULSF(SUBF(t##m##4, t##m##2), 5.25f)), t##m##6); \ | |||
d##m##1 = \ | |||
SUBF(ADDF(ADDF(ADDF(t##m##1, t##m##2), t##m##5), t##m##6), \ | |||
MULSF(ADDF(t##m##3, t##m##4), 4.25f)); \ | |||
d##m##2 = \ | |||
ADDF(SUBF(ADDF(t##m##2, t##m##6), ADDF(t##m##1, t##m##5)), \ | |||
MULSF(SUBF(t##m##3, t##m##4), 4.25f)); \ | |||
d##m##3 = \ | |||
ADDF(ADDF(SUBF(SUBF(ADDF(MULSF(t##m##1, 0.5f), MULSF(t##m##2, 0.25f)), \ | |||
MULSF(t##m##3, 2.5f)), \ | |||
MULSF(t##m##4, 1.25f)), \ | |||
MULSF(t##m##5, 2.f)), \ | |||
t##m##6); \ | |||
d##m##4 = \ | |||
ADDF(SUBF(SUBF(ADDF(ADDF(MULSF(t##m##1, (-0.5f)), MULSF(t##m##2, 0.25f)), \ | |||
MULSF(t##m##3, 2.5f)), \ | |||
MULSF(t##m##4, 1.25f)), \ | |||
MULSF(t##m##5, 2.f)), \ | |||
t##m##6); \ | |||
d##m##5 = \ | |||
ADDF(ADDF(SUBF(SUBF(ADDF(MULSF(t##m##1, 2.f), MULSF(t##m##2, 4.f)), \ | |||
MULSF(t##m##3, 2.5f)), \ | |||
MULSF(t##m##4, 5.f)), \ | |||
MULSF(t##m##5, 0.5f)), \ | |||
t##m##6); \ | |||
d##m##6 = \ | |||
ADDF(SUBF(SUBF(ADDF(ADDF(MULSF(t##m##1, (-2.f)), MULSF(t##m##2, 4.f)), \ | |||
MULSF(t##m##3, 2.5f)), \ | |||
MULSF(t##m##4, 5.f)), \ | |||
MULSF(t##m##5, 0.5f)), \ | |||
t##m##6); \ | |||
d##m##7 = ADDF(SUBF(t##m##7, t##m##1), MULSF(SUBF(t##m##3, t##m##5), 5.25f)); | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
@@ -128,9 +193,10 @@ struct InputTransform6X3 { | |||
size_t ICB = IC / 4; | |||
size_t icb = ic / 4; | |||
#define cb(m, n) \ | |||
d##m##n.save( \ | |||
GiStoreFloat32( \ | |||
input_transform_buf + (m * alpha + n) * ICB * nr_units_in_tile * 4 + \ | |||
icb * nr_units_in_tile * 4 + unit_idx * 4); | |||
icb * nr_units_in_tile * 4 + unit_idx * 4, \ | |||
d##m##n); | |||
UNROLL_CALL_NOWRAPPER_D2(8, 8, cb) | |||
#undef cb | |||
} | |||
@@ -152,7 +218,7 @@ struct OutputTransform6X3 { | |||
size_t ocb = oc_index / 4; | |||
#define cb(m, n) \ | |||
auto v##m##n = Vector<float, 4>::load( \ | |||
auto v##m##n = GiLoadFloat32( \ | |||
output_transform_buf + (m * alpha + n) * OCB * nr_units_in_tile * 4 + \ | |||
ocb * nr_units_in_tile * 4 + unit_idx * 4); | |||
UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); | |||
@@ -171,56 +237,88 @@ struct OutputTransform6X3 { | |||
* 0 0.0 0 0 0 1 | |||
*/ | |||
Vector<float, 4> v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, v5subv6; | |||
#define cb(m) \ | |||
v1addv2 = v1##m + v2##m; \ | |||
v1subv2 = v1##m - v2##m; \ | |||
v3addv4 = v3##m + v4##m; \ | |||
v3subv4 = v3##m - v4##m; \ | |||
v5addv6 = v5##m + v6##m; \ | |||
v5subv6 = v5##m - v6##m; \ | |||
auto t0##m = v0##m + v1addv2 + v3addv4 + v5addv6; \ | |||
auto t1##m = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f; \ | |||
auto t2##m = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f; \ | |||
auto t3##m = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f; \ | |||
auto t4##m = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; \ | |||
auto t5##m = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + v7##m; | |||
/* | |||
* v1addv2 = v1##m + v2##m; | |||
* v1subv2 = v1##m - v2##m; | |||
* v3addv4 = v3##m + v4##m; | |||
* v3subv4 = v3##m - v4##m; | |||
* v5addv6 = v5##m + v6##m; | |||
* v5subv6 = v5##m - v6##m; | |||
* auto t0##m = v0##m + v1addv2 + v3addv4 + v5addv6; | |||
* auto t1##m = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f; | |||
* auto t2##m = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f; | |||
* auto t3##m = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f; | |||
* auto t4##m = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; | |||
* auto t5##m = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + v7##m; | |||
*/ | |||
GI_FLOAT32_t v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, v5subv6; | |||
#define cb(m) \ | |||
v1addv2 = ADDF(v1##m, v2##m); \ | |||
v1subv2 = SUBF(v1##m, v2##m); \ | |||
v3addv4 = ADDF(v3##m, v4##m); \ | |||
v3subv4 = SUBF(v3##m, v4##m); \ | |||
v5addv6 = ADDF(v5##m, v6##m); \ | |||
v5subv6 = SUBF(v5##m, v6##m); \ | |||
auto t0##m = ADDF(ADDF(ADDF(v0##m, v1addv2), v3addv4), v5addv6); \ | |||
auto t1##m = ADDF(ADDF(v1subv2, MULSF(v3subv4, 2.f)), MULSF(v5subv6, 0.5f)); \ | |||
auto t2##m = ADDF(ADDF(v1addv2, MULSF(v3addv4, 4.f)), MULSF(v5addv6, 0.25f)); \ | |||
auto t3##m = ADDF(ADDF(v1subv2, MULSF(v3subv4, 8.f)), MULSF(v5subv6, 0.125f)); \ | |||
auto t4##m = ADDF(ADDF(v1addv2, MULSF(v3addv4, 16.f)), MULSF(v5addv6, 0.0625f)); \ | |||
auto t5##m = \ | |||
ADDF(ADDF(ADDF(v1subv2, MULSF(v3subv4, 32.f)), MULSF(v5subv6, 0.03125f)), \ | |||
v7##m); | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
#define cb(m) \ | |||
v1addv2 = t##m##1 + t##m##2; \ | |||
v1subv2 = t##m##1 - t##m##2; \ | |||
v3addv4 = t##m##3 + t##m##4; \ | |||
v3subv4 = t##m##3 - t##m##4; \ | |||
v5addv6 = t##m##5 + t##m##6; \ | |||
v5subv6 = t##m##5 - t##m##6; \ | |||
v##m##0 = t##m##0 + v1addv2 + v3addv4 + v5addv6; \ | |||
v##m##1 = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f; \ | |||
v##m##2 = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f; \ | |||
v##m##3 = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f; \ | |||
v##m##4 = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; \ | |||
v##m##5 = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + t##m##7; | |||
/* | |||
* v1addv2 = t##m##1 + t##m##2; | |||
* v1subv2 = t##m##1 - t##m##2; | |||
* v3addv4 = t##m##3 + t##m##4; | |||
* v3subv4 = t##m##3 - t##m##4; | |||
* v5addv6 = t##m##5 + t##m##6; | |||
* v5subv6 = t##m##5 - t##m##6; | |||
* v##m##0 = t##m##0 + v1addv2 + v3addv4 + v5addv6; | |||
* v##m##1 = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f; | |||
* v##m##2 = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f; | |||
* v##m##3 = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f; | |||
* v##m##4 = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; | |||
* v##m##5 = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + t##m##7; | |||
*/ | |||
#define cb(m) \ | |||
v1addv2 = ADDF(t##m##1, t##m##2); \ | |||
v1subv2 = SUBF(t##m##1, t##m##2); \ | |||
v3addv4 = ADDF(t##m##3, t##m##4); \ | |||
v3subv4 = SUBF(t##m##3, t##m##4); \ | |||
v5addv6 = ADDF(t##m##5, t##m##6); \ | |||
v5subv6 = SUBF(t##m##5, t##m##6); \ | |||
v##m##0 = ADDF(ADDF(ADDF(t##m##0, v1addv2), v3addv4), v5addv6); \ | |||
v##m##1 = ADDF(ADDF(v1subv2, MULSF(v3subv4, 2.f)), MULSF(v5subv6, 0.5f)); \ | |||
v##m##2 = ADDF(ADDF(v1addv2, MULSF(v3addv4, 4.f)), MULSF(v5addv6, 0.25f)); \ | |||
v##m##3 = ADDF(ADDF(v1subv2, MULSF(v3subv4, 8.f)), MULSF(v5subv6, 0.125f)); \ | |||
v##m##4 = ADDF(ADDF(v1addv2, MULSF(v3addv4, 16.f)), MULSF(v5addv6, 0.0625f)); \ | |||
v##m##5 = \ | |||
ADDF(ADDF(ADDF(v1subv2, MULSF(v3subv4, 32.f)), MULSF(v5subv6, 0.03125f)), \ | |||
t##m##7); | |||
UNROLL_CALL_NOWRAPPER(6, cb); | |||
#undef cb | |||
Vector<float, 4> vbias; | |||
GI_FLOAT32_t vbias; | |||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
vbias = Vector<float, 4>::load(bias + oc); | |||
vbias = GiLoadFloat32(bias + oc); | |||
#define cb(m, n) v##m##n += vbias; | |||
#define cb(m, n) v##m##n = GiAddFloat32(v##m##n, vbias); | |||
UNROLL_CALL_RAW_D2(6, 6, cb); | |||
#undef cb | |||
} | |||
if (bmode != BiasMode::BIAS) { | |||
#define cb(m, n) v##m##n = op(GiFixLenType2GiFloat32Type(CONCAT(v##m, n).value)); | |||
#define cb(m, n) v##m##n = op(CONCAT(v##m, n)); | |||
UNROLL_CALL_RAW_D2(6, 6, cb); | |||
#undef cb | |||
} | |||
#define cb(m, n) CONCAT(v##m, n).save(transform_mid_buf + (m * 6 + n) * 4); | |||
#define cb(m, n) GiStoreFloat32(transform_mid_buf + (m * 6 + n) * 4, CONCAT(v##m, n)); | |||
UNROLL_CALL_RAW_D2(6, 6, cb); | |||
#undef cb | |||
@@ -1,7 +1,6 @@ | |||
#include "src/common/unroll_macro.h" | |||
#include "src/common/utils.h" | |||
#include "src/fallback/conv_bias/gi/fp32/strategy.h" | |||
#include "src/fallback/conv_bias/gi/utils.h" | |||
#include "src/fallback/conv_bias/winograd/winograd.h" | |||
#include "src/fallback/conv_bias/gi/fp32/helper.h" | |||
@@ -29,7 +28,7 @@ struct InputTransformF23_NCHW44 { | |||
size_t iw4_start = iw_start * pack_size; | |||
size_t ICB = IC / pack_size; | |||
#define cb(m, n) Vector<float, 4> d##m##n; | |||
#define cb(m, n) GI_FLOAT32_t d##m##n; | |||
UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); | |||
#undef cb | |||
@@ -40,7 +39,7 @@ struct InputTransformF23_NCHW44 { | |||
MEGDNN_MARK_USED_VAR(patchT); | |||
const float* input_ptr = | |||
input + icb * IH * IW4 + ih_start * IW4 + iw4_start; | |||
#define cb(n, m) d##m##n = Vector<float, 4>::load(input_ptr + pack_size * n); | |||
#define cb(n, m) d##m##n = GiLoadFloat32(input_ptr + pack_size * n); | |||
UNROLL_CALL_RAW(4, cb, 0); | |||
input_ptr += IW4; | |||
@@ -66,7 +65,7 @@ struct InputTransformF23_NCHW44 { | |||
} | |||
} | |||
#define cb(m, n) \ | |||
d##m##n = Vector<float, 4>::load(patchT + m * alpha * pack_size + n * pack_size); | |||
d##m##n = GiLoadFloat32(patchT + m * alpha * pack_size + n * pack_size); | |||
UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); | |||
#undef cb | |||
} | |||
@@ -74,29 +73,30 @@ struct InputTransformF23_NCHW44 { | |||
//! 0 1 1 0 d10 d11 d12 d13 0 1 -1 -1 | |||
//! 0 -1 1 0 d20 d21 d22 d23 -1 1 1 0 | |||
//! 0 -1 0 1 d30 d31 d32 d33 0 0 0 1 | |||
#define cb(m) \ | |||
auto t0##m = d0##m - d2##m; \ | |||
auto t1##m = d1##m + d2##m; \ | |||
auto t2##m = d2##m - d1##m; \ | |||
auto t3##m = d3##m - d1##m; | |||
#define cb(m) \ | |||
auto t0##m = SUBF(d0##m, d2##m); \ | |||
auto t1##m = ADDF(d1##m, d2##m); \ | |||
auto t2##m = SUBF(d2##m, d1##m); \ | |||
auto t3##m = SUBF(d3##m, d1##m); | |||
UNROLL_CALL_NOWRAPPER(4, cb); | |||
#undef cb | |||
#define cb(m) \ | |||
d##m##0 = t##m##0 - t##m##2; \ | |||
d##m##1 = t##m##1 + t##m##2; \ | |||
d##m##2 = t##m##2 - t##m##1; \ | |||
d##m##3 = t##m##3 - t##m##1; | |||
#define cb(m) \ | |||
d##m##0 = SUBF(t##m##0, t##m##2); \ | |||
d##m##1 = ADDF(t##m##1, t##m##2); \ | |||
d##m##2 = SUBF(t##m##2, t##m##1); \ | |||
d##m##3 = SUBF(t##m##3, t##m##1); | |||
UNROLL_CALL_NOWRAPPER(4, cb); | |||
#undef cb | |||
#define cb(m, n) \ | |||
d##m##n.save( \ | |||
input_transform_buf + \ | |||
(m * alpha + n) * ICB * nr_units_in_tile * pack_size + \ | |||
icb * nr_units_in_tile * pack_size + unit_idx * pack_size); | |||
#define cb(m, n) \ | |||
GiStoreFloat32( \ | |||
input_transform_buf + \ | |||
(m * alpha + n) * ICB * nr_units_in_tile * pack_size + \ | |||
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ | |||
d##m##n); | |||
UNROLL_CALL_NOWRAPPER_D2(4, 4, cb) | |||
#undef cb | |||
} | |||
@@ -118,7 +118,7 @@ struct OutputTransformF23_NCHW44 { | |||
size_t ocb = oc_index / pack_size; | |||
#define cb(m, n) \ | |||
auto v##m##n = Vector<float, 4>::load( \ | |||
auto v##m##n = GiLoadFloat32( \ | |||
output_transform_buf + \ | |||
(m * alpha + n) * OCB * nr_units_in_tile * pack_size + \ | |||
ocb * nr_units_in_tile * pack_size + unit_idx * pack_size); | |||
@@ -130,30 +130,30 @@ struct OutputTransformF23_NCHW44 { | |||
//! v20 v21 v22 v23 1 -1 | |||
//! v30 v31 v32 v33 0 1 | |||
#define cb(m) \ | |||
auto t0##m = v0##m + v1##m + v2##m; \ | |||
auto t1##m = v1##m - v2##m + v3##m; | |||
#define cb(m) \ | |||
auto t0##m = ADDF(ADDF(v0##m, v1##m), v2##m); \ | |||
auto t1##m = ADDF(SUBF(v1##m, v2##m), v3##m); | |||
UNROLL_CALL_NOWRAPPER(4, cb); | |||
#undef cb | |||
#define cb(m) \ | |||
v##m##0 = t##m##0 + t##m##1 + t##m##2; \ | |||
v##m##1 = t##m##1 - t##m##2 + t##m##3; | |||
#define cb(m) \ | |||
v##m##0 = ADDF(ADDF(t##m##0, t##m##1), t##m##2); \ | |||
v##m##1 = ADDF(SUBF(t##m##1, t##m##2), t##m##3); | |||
UNROLL_CALL_NOWRAPPER(2, cb); | |||
#undef cb | |||
Vector<float, 4> vbias; | |||
GI_FLOAT32_t vbias; | |||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
vbias = Vector<float, 4>::load(bias + oc); | |||
vbias = GiLoadFloat32(bias + oc); | |||
#define cb(m, n) v##m##n += vbias; | |||
#define cb(m, n) v##m##n = ADDF(v##m##n, vbias); | |||
UNROLL_CALL_RAW_D2(2, 2, cb); | |||
#undef cb | |||
} | |||
if (bmode != BiasMode::BIAS) { | |||
#define cb(m, n) v##m##n = op(GiFixLenType2GiFloat32Type(CONCAT(v##m, n).value)); | |||
#define cb(m, n) v##m##n = op(CONCAT(v##m, n)); | |||
UNROLL_CALL_RAW_D2(2, 2, cb); | |||
#undef cb | |||
} | |||
@@ -163,12 +163,15 @@ struct OutputTransformF23_NCHW44 { | |||
size_t ow = ow_start + owo; \ | |||
if (oh < OH && ow < OW) { \ | |||
if (bmode == BiasMode::BIAS) { \ | |||
v##oho##owo += Vector<float, 4>::load( \ | |||
bias + oc * OH * OW + oh * OW * pack_size + ow * pack_size); \ | |||
v##oho##owo = op(GiFixLenType2GiFloat32Type(v##oho##owo.value)); \ | |||
v##oho##owo = ADDF( \ | |||
v##oho##owo, GiLoadFloat32( \ | |||
bias + oc * OH * OW + \ | |||
oh * OW * pack_size + ow * pack_size)); \ | |||
v##oho##owo = op(v##oho##owo); \ | |||
} \ | |||
v##oho##owo.save( \ | |||
output + oc * OH * OW + oh * OW * pack_size + ow * pack_size); \ | |||
GiStoreFloat32( \ | |||
output + oc * OH * OW + oh * OW * pack_size + ow * pack_size, \ | |||
v##oho##owo); \ | |||
} \ | |||
} while (0); | |||
UNROLL_CALL_RAW_D2(2, 2, out_save); | |||
@@ -211,29 +214,30 @@ void winograd_F23_mk4_f_nchw44::filter( | |||
pack_size * pack_size + | |||
ic_inner * pack_size; | |||
#define cb(m, n) \ | |||
Vector<float, 4> g##m##n = Vector<float, 4>::load( \ | |||
fptr + (m * KERNEL_SIZE + n) * pack_size * pack_size); | |||
#define cb(m, n) \ | |||
GI_FLOAT32_t g##m##n = \ | |||
GiLoadFloat32(fptr + (m * KERNEL_SIZE + n) * pack_size * pack_size); | |||
UNROLL_CALL_NOWRAPPER_D2(3, 3, cb) | |||
#undef cb | |||
#define FILTER_TRANSFORM(n, wd, g) \ | |||
auto wd##n##0 = g##0##n; \ | |||
tmp0 = (g##0##n + g##2##n) * 0.5; \ | |||
tmp1 = g##1##n * 0.5; \ | |||
auto wd##n##1 = tmp0 + tmp1; \ | |||
auto wd##n##2 = tmp0 - tmp1; \ | |||
#define FILTER_TRANSFORM(n, wd, g) \ | |||
auto wd##n##0 = g##0##n; \ | |||
tmp0 = MULSF(ADDF(g##0##n, g##2##n), 0.5); \ | |||
tmp1 = MULSF(g##1##n, 0.5); \ | |||
auto wd##n##1 = ADDF(tmp0, tmp1); \ | |||
auto wd##n##2 = SUBF(tmp0, tmp1); \ | |||
auto wd##n##3 = g##2##n; | |||
Vector<float, 4> tmp0, tmp1; | |||
GI_FLOAT32_t tmp0, tmp1; | |||
UNROLL_CALL_RAW(3, FILTER_TRANSFORM, wd, g); | |||
UNROLL_CALL_RAW(4, FILTER_TRANSFORM, ret, wd); | |||
#undef FILTER_TRANSFORM | |||
#define cb_save(m, n) \ | |||
ret##m##n.save( \ | |||
filter_transform_buf + \ | |||
(m * ALPHA + n) * OCB * ICB * pack_size * pack_size + \ | |||
ocb * ICB * pack_size * pack_size + icb * pack_size * pack_size + \ | |||
ic_inner * pack_size); | |||
#define cb_save(m, n) \ | |||
GiStoreFloat32( \ | |||
filter_transform_buf + \ | |||
(m * ALPHA + n) * OCB * ICB * pack_size * pack_size + \ | |||
ocb * ICB * pack_size * pack_size + icb * pack_size * pack_size + \ | |||
ic_inner * pack_size, \ | |||
ret##m##n); | |||
UNROLL_CALL_NOWRAPPER_D2(4, 4, cb_save) | |||
#undef cb_save | |||
} | |||
@@ -4,7 +4,6 @@ | |||
#include "src/fallback/conv_bias/gi/fp32/filter_transform.h" | |||
#include "src/fallback/conv_bias/gi/fp32/helper.h" | |||
#include "src/fallback/conv_bias/gi/fp32/strategy.h" | |||
#include "src/fallback/conv_bias/gi/utils.h" | |||
#include "src/fallback/conv_bias/winograd/winograd.h" | |||
#include "src/fallback/elemwise_helper/op_unary.h" | |||
@@ -114,17 +113,17 @@ struct InputTransformF63_NCHW44 { | |||
auto t##i##4 = d6; \ | |||
auto t##i##5 = d6; \ | |||
auto t##i##6 = d6; \ | |||
t##i##0 = GiSubtractFloat32(t##i##0, d6); \ | |||
t##i##1 = GiAddFloat32(t##i##1, d1); \ | |||
t##i##2 = GiSubtractFloat32(t##i##2, d1); \ | |||
t##i##0 = SUBF(t##i##0, d6); \ | |||
t##i##1 = ADDF(t##i##1, d1); \ | |||
t##i##2 = SUBF(t##i##2, d1); \ | |||
t##i##3 = MADD(t##i##3, d1, v0, 2); \ | |||
t##i##4 = MSUB(t##i##4, d1, v0, 2); \ | |||
t##i##5 = MADD(t##i##5, d1, v1, 2); \ | |||
t##i##6 = MSUB(t##i##6, d1, v1, 2); \ | |||
t##i##7 = GiSubtractFloat32(t##i##7, d1); \ | |||
t##i##7 = SUBF(t##i##7, d1); \ | |||
t##i##0 = MSUB(t##i##0, d2, v0, 0); \ | |||
t##i##1 = GiAddFloat32(t##i##1, d2); \ | |||
t##i##2 = GiAddFloat32(t##i##2, d2); \ | |||
t##i##1 = ADDF(t##i##1, d2); \ | |||
t##i##2 = ADDF(t##i##2, d2); \ | |||
t##i##3 = MADD(t##i##3, d2, v0, 3); \ | |||
t##i##4 = MADD(t##i##4, d2, v0, 3); \ | |||
t##i##5 = MADD(t##i##5, d2, v1, 3); \ | |||
@@ -143,8 +142,8 @@ struct InputTransformF63_NCHW44 { | |||
t##i##4 = MSUB(t##i##4, d4, v1, 1); \ | |||
t##i##5 = MSUB(t##i##5, d4, v2, 0); \ | |||
t##i##6 = MSUB(t##i##6, d4, v2, 0); \ | |||
t##i##1 = GiAddFloat32(t##i##1, d5); \ | |||
t##i##2 = GiSubtractFloat32(t##i##2, d5); \ | |||
t##i##1 = ADDF(t##i##1, d5); \ | |||
t##i##2 = SUBF(t##i##2, d5); \ | |||
t##i##3 = MADD(t##i##3, d5, v1, 2); \ | |||
t##i##4 = MSUB(t##i##4, d5, v1, 2); \ | |||
t##i##5 = MADD(t##i##5, d5, v0, 2); \ | |||
@@ -162,17 +161,17 @@ struct InputTransformF63_NCHW44 { | |||
d5 = t6##i; \ | |||
d6 = t6##i; \ | |||
d7 = t7##i; \ | |||
d0 = GiSubtractFloat32(d0, t6##i); \ | |||
d1 = GiAddFloat32(d1, t1##i); \ | |||
d2 = GiSubtractFloat32(d2, t1##i); \ | |||
d0 = SUBF(d0, t6##i); \ | |||
d1 = ADDF(d1, t1##i); \ | |||
d2 = SUBF(d2, t1##i); \ | |||
d3 = MADD(d3, t1##i, v0, 2); \ | |||
d4 = MSUB(d4, t1##i, v0, 2); \ | |||
d5 = MADD(d5, t1##i, v1, 2); \ | |||
d6 = MSUB(d6, t1##i, v1, 2); \ | |||
d7 = GiSubtractFloat32(d7, t1##i); \ | |||
d7 = SUBF(d7, t1##i); \ | |||
d0 = MSUB(d0, t2##i, v0, 0); \ | |||
d1 = GiAddFloat32(d1, t2##i); \ | |||
d2 = GiAddFloat32(d2, t2##i); \ | |||
d1 = ADDF(d1, t2##i); \ | |||
d2 = ADDF(d2, t2##i); \ | |||
d3 = MADD(d3, t2##i, v0, 3); \ | |||
d4 = MADD(d4, t2##i, v0, 3); \ | |||
d5 = MADD(d5, t2##i, v1, 3); \ | |||
@@ -191,8 +190,8 @@ struct InputTransformF63_NCHW44 { | |||
d4 = MSUB(d4, t4##i, v1, 1); \ | |||
d5 = MSUB(d5, t4##i, v2, 0); \ | |||
d6 = MSUB(d6, t4##i, v2, 0); \ | |||
d1 = GiAddFloat32(d1, t5##i); \ | |||
d2 = GiSubtractFloat32(d2, t5##i); \ | |||
d1 = ADDF(d1, t5##i); \ | |||
d2 = SUBF(d2, t5##i); \ | |||
d3 = MADD(d3, t5##i, v1, 2); \ | |||
d4 = MSUB(d4, t5##i, v1, 2); \ | |||
d5 = MADD(d5, t5##i, v0, 2); \ | |||
@@ -261,7 +260,7 @@ struct OutputTransformF63_NCHW44 { | |||
size_t ocb = oc_index / pack_size; | |||
#define cb(m, n) \ | |||
auto v##m##n = Vector<float, 4>::load( \ | |||
auto v##m##n = GiLoadFloat32( \ | |||
output_transform_buf + \ | |||
(m * alpha + n) * OCB * nr_units_in_tile * pack_size + \ | |||
ocb * nr_units_in_tile * pack_size + unit_idx * pack_size); | |||
@@ -281,51 +280,83 @@ struct OutputTransformF63_NCHW44 { | |||
* 0 0 0 0 0 1 | |||
*/ | |||
Vector<float, 4> v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, v5subv6; | |||
#define cb(m) \ | |||
v1addv2 = v1##m + v2##m; \ | |||
v1subv2 = v1##m - v2##m; \ | |||
v3addv4 = v3##m + v4##m; \ | |||
v3subv4 = v3##m - v4##m; \ | |||
v5addv6 = v5##m + v6##m; \ | |||
v5subv6 = v5##m - v6##m; \ | |||
auto t0##m = v0##m + v1addv2 + v3addv4 + v5addv6; \ | |||
auto t1##m = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f; \ | |||
auto t2##m = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f; \ | |||
auto t3##m = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f; \ | |||
auto t4##m = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; \ | |||
auto t5##m = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + v7##m; | |||
/* | |||
* v1addv2 = v1##m + v2##m; | |||
* v1subv2 = v1##m - v2##m; | |||
* v3addv4 = v3##m + v4##m; | |||
* v3subv4 = v3##m - v4##m; | |||
* v5addv6 = v5##m + v6##m; | |||
* v5subv6 = v5##m - v6##m; | |||
* auto t0##m = v0##m + v1addv2 + v3addv4 + v5addv6; | |||
* auto t1##m = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f; | |||
* auto t2##m = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f; | |||
* auto t3##m = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f; | |||
* auto t4##m = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; | |||
* auto t5##m = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + v7##m; | |||
*/ | |||
GI_FLOAT32_t v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, v5subv6; | |||
#define cb(m) \ | |||
v1addv2 = ADDF(v1##m, v2##m); \ | |||
v1subv2 = SUBF(v1##m, v2##m); \ | |||
v3addv4 = ADDF(v3##m, v4##m); \ | |||
v3subv4 = SUBF(v3##m, v4##m); \ | |||
v5addv6 = ADDF(v5##m, v6##m); \ | |||
v5subv6 = SUBF(v5##m, v6##m); \ | |||
auto t0##m = ADDF(ADDF(ADDF(v0##m, v1addv2), v3addv4), v5addv6); \ | |||
auto t1##m = ADDF(ADDF(v1subv2, MULSF(v3subv4, 2.f)), MULSF(v5subv6, 0.5f)); \ | |||
auto t2##m = ADDF(ADDF(v1addv2, MULSF(v3addv4, 4.f)), MULSF(v5addv6, 0.25f)); \ | |||
auto t3##m = ADDF(ADDF(v1subv2, MULSF(v3subv4, 8.f)), MULSF(v5subv6, 0.125f)); \ | |||
auto t4##m = ADDF(ADDF(v1addv2, MULSF(v3addv4, 16.f)), MULSF(v5addv6, 0.0625f)); \ | |||
auto t5##m = \ | |||
ADDF(ADDF(ADDF(v1subv2, MULSF(v3subv4, 32.f)), MULSF(v5subv6, 0.03125f)), \ | |||
v7##m); | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
#define cb(m) \ | |||
v1addv2 = t##m##1 + t##m##2; \ | |||
v1subv2 = t##m##1 - t##m##2; \ | |||
v3addv4 = t##m##3 + t##m##4; \ | |||
v3subv4 = t##m##3 - t##m##4; \ | |||
v5addv6 = t##m##5 + t##m##6; \ | |||
v5subv6 = t##m##5 - t##m##6; \ | |||
v##m##0 = t##m##0 + v1addv2 + v3addv4 + v5addv6; \ | |||
v##m##1 = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f; \ | |||
v##m##2 = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f; \ | |||
v##m##3 = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f; \ | |||
v##m##4 = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; \ | |||
v##m##5 = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + t##m##7; | |||
/* | |||
* v1addv2 = t##m##1 + t##m##2; | |||
* v1subv2 = t##m##1 - t##m##2; | |||
* v3addv4 = t##m##3 + t##m##4; | |||
* v3subv4 = t##m##3 - t##m##4; | |||
* v5addv6 = t##m##5 + t##m##6; | |||
* v5subv6 = t##m##5 - t##m##6; | |||
* v##m##0 = t##m##0 + v1addv2 + v3addv4 + v5addv6; | |||
* v##m##1 = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f; | |||
* v##m##2 = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f; | |||
* v##m##3 = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f; | |||
* v##m##4 = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; | |||
* v##m##5 = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + t##m##7; | |||
*/ | |||
#define cb(m) \ | |||
v1addv2 = ADDF(t##m##1, t##m##2); \ | |||
v1subv2 = SUBF(t##m##1, t##m##2); \ | |||
v3addv4 = ADDF(t##m##3, t##m##4); \ | |||
v3subv4 = SUBF(t##m##3, t##m##4); \ | |||
v5addv6 = ADDF(t##m##5, t##m##6); \ | |||
v5subv6 = SUBF(t##m##5, t##m##6); \ | |||
v##m##0 = ADDF(ADDF(ADDF(t##m##0, v1addv2), v3addv4), v5addv6); \ | |||
v##m##1 = ADDF(ADDF(v1subv2, MULSF(v3subv4, 2.f)), MULSF(v5subv6, 0.5f)); \ | |||
v##m##2 = ADDF(ADDF(v1addv2, MULSF(v3addv4, 4.f)), MULSF(v5addv6, 0.25f)); \ | |||
v##m##3 = ADDF(ADDF(v1subv2, MULSF(v3subv4, 8.f)), MULSF(v5subv6, 0.125f)); \ | |||
v##m##4 = ADDF(ADDF(v1addv2, MULSF(v3addv4, 16.f)), MULSF(v5addv6, 0.0625f)); \ | |||
v##m##5 = \ | |||
ADDF(ADDF(ADDF(v1subv2, MULSF(v3subv4, 32.f)), MULSF(v5subv6, 0.03125f)), \ | |||
t##m##7); | |||
UNROLL_CALL_NOWRAPPER(6, cb); | |||
#undef cb | |||
Vector<float, 4> vbias; | |||
GI_FLOAT32_t vbias; | |||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
vbias = Vector<float, 4>::load(bias + oc); | |||
vbias = GiLoadFloat32(bias + oc); | |||
#define cb(m, n) v##m##n += vbias; | |||
#define cb(m, n) v##m##n = ADDF(v##m##n, vbias); | |||
UNROLL_CALL_RAW_D2(6, 6, cb); | |||
#undef cb | |||
} | |||
if (bmode != BiasMode::BIAS) { | |||
#define cb(m, n) v##m##n = op(GiFixLenType2GiFloat32Type(CONCAT(v##m, n).value)); | |||
#define cb(m, n) v##m##n = op(CONCAT(v##m, n)); | |||
UNROLL_CALL_RAW_D2(6, 6, cb); | |||
#undef cb | |||
} | |||
@@ -335,12 +366,15 @@ struct OutputTransformF63_NCHW44 { | |||
size_t ow = ow_start + owo; \ | |||
if (oh < OH && ow < OW) { \ | |||
if (bmode == BiasMode::BIAS) { \ | |||
v##oho##owo += Vector<float, 4>::load( \ | |||
bias + oc * OH * OW + oh * OW * pack_size + ow * pack_size); \ | |||
v##oho##owo = op(GiFixLenType2GiFloat32Type(v##oho##owo.value)); \ | |||
v##oho##owo = ADDF( \ | |||
v##oho##owo, GiLoadFloat32( \ | |||
bias + oc * OH * OW + \ | |||
oh * OW * pack_size + ow * pack_size)); \ | |||
v##oho##owo = op(v##oho##owo); \ | |||
} \ | |||
v##oho##owo.save( \ | |||
output + oc * OH * OW + oh * OW * pack_size + ow * pack_size); \ | |||
GiStoreFloat32( \ | |||
output + oc * OH * OW + oh * OW * pack_size + ow * pack_size, \ | |||
v##oho##owo); \ | |||
} \ | |||
} while (0); | |||
UNROLL_CALL_RAW_D2(6, 6, out_save); | |||
@@ -387,35 +421,52 @@ void winograd_F63_mk4_f_nchw44::filter( | |||
pack_size * pack_size + | |||
ic_inner * pack_size; | |||
#define cb(m, n) \ | |||
Vector<float, 4> g##m##n = Vector<float, 4>::load( \ | |||
fptr + (m * KERNEL_SIZE + n) * pack_size * pack_size); | |||
#define cb(m, n) \ | |||
GI_FLOAT32_t g##m##n = \ | |||
GiLoadFloat32(fptr + (m * KERNEL_SIZE + n) * pack_size * pack_size); | |||
UNROLL_CALL_NOWRAPPER_D2(3, 3, cb) | |||
#undef cb | |||
#define FILTER_TRANSFORM(n, wd, g) \ | |||
auto wd##n##0 = g##0##n; \ | |||
tmp0 = (g##0##n + g##2##n) * -0.2222222f; \ | |||
tmp1 = g##1##n * -0.2222222f; \ | |||
auto wd##n##1 = tmp0 + tmp1; \ | |||
auto wd##n##2 = tmp0 - tmp1; \ | |||
tmp0 = g##0##n * 0.0111111f + g##2##n * 0.0444444f; \ | |||
tmp1 = g##1##n * 0.0222222f; \ | |||
auto wd##n##3 = tmp0 + tmp1; \ | |||
auto wd##n##4 = tmp0 - tmp1; \ | |||
tmp0 = g##0##n * 0.7111111f + g##2##n * 0.1777778f; \ | |||
tmp1 = g##1##n * 0.3555556f; \ | |||
auto wd##n##5 = tmp0 + tmp1; \ | |||
auto wd##n##6 = tmp0 - tmp1; \ | |||
/* | |||
* auto wd##n##0 = g##0##n; | |||
* tmp0 = (g##0##n + g##2##n) * -0.2222222f; | |||
* tmp1 = g##1##n * -0.2222222f; | |||
* auto wd##n##1 = tmp0 + tmp1; | |||
* auto wd##n##2 = tmp0 - tmp1; | |||
* tmp0 = g##0##n * 0.0111111f + g##2##n * 0.0444444f; | |||
* tmp1 = g##1##n * 0.0222222f; | |||
* auto wd##n##3 = tmp0 + tmp1; | |||
* auto wd##n##4 = tmp0 - tmp1; | |||
* tmp0 = g##0##n * 0.7111111f + g##2##n * 0.1777778f; | |||
* tmp1 = g##1##n * 0.3555556f; | |||
* auto wd##n##5 = tmp0 + tmp1; | |||
* auto wd##n##6 = tmp0 - tmp1; | |||
* auto wd##n##7 = g##2##n; | |||
*/ | |||
#define FILTER_TRANSFORM(n, wd, g) \ | |||
auto wd##n##0 = g##0##n; \ | |||
tmp0 = MULSF(ADDF(g##0##n, g##2##n), -0.2222222f); \ | |||
tmp1 = MULSF(g##1##n, -0.2222222f); \ | |||
auto wd##n##1 = ADDF(tmp0, tmp1); \ | |||
auto wd##n##2 = SUBF(tmp0, tmp1); \ | |||
tmp0 = ADDF(MULSF(g##0##n, 0.0111111f), MULSF(g##2##n, 0.0444444f)); \ | |||
tmp1 = MULSF(g##1##n, 0.0222222f); \ | |||
auto wd##n##3 = ADDF(tmp0, tmp1); \ | |||
auto wd##n##4 = SUBF(tmp0, tmp1); \ | |||
tmp0 = ADDF(MULSF(g##0##n, 0.7111111f), MULSF(g##2##n, 0.1777778f)); \ | |||
tmp1 = MULSF(g##1##n, 0.3555556f); \ | |||
auto wd##n##5 = ADDF(tmp0, tmp1); \ | |||
auto wd##n##6 = SUBF(tmp0, tmp1); \ | |||
auto wd##n##7 = g##2##n; | |||
Vector<float, 4> tmp0, tmp1; | |||
GI_FLOAT32_t tmp0, tmp1; | |||
UNROLL_CALL_RAW(3, FILTER_TRANSFORM, wd, g); | |||
UNROLL_CALL_RAW(8, FILTER_TRANSFORM, ret, wd); | |||
#undef FILTER_TRANSFORM | |||
#define cb_save(m, n) \ | |||
ret##m##n.save( \ | |||
GiStoreFloat32( \ | |||
filter_transform_buf + (m * alpha + n) * OC * IC + ocb * IC * pack_size + \ | |||
icb * pack_size * pack_size + ic_inner * pack_size); | |||
icb * pack_size * pack_size + ic_inner * pack_size, \ | |||
ret##m##n); | |||
UNROLL_CALL_NOWRAPPER_D2(8, 8, cb_save) | |||
#undef cb_save | |||
} | |||
@@ -4,7 +4,6 @@ | |||
#include "src/fallback/conv_bias/gi/fp32/filter_transform.h" | |||
#include "src/fallback/conv_bias/gi/fp32/helper.h" | |||
#include "src/fallback/conv_bias/gi/fp32/strategy.h" | |||
#include "src/fallback/conv_bias/gi/utils.h" | |||
#include "src/fallback/conv_bias/winograd/winograd.h" | |||
#include "src/fallback/elemwise_helper/op_unary.h" | |||
@@ -137,14 +136,14 @@ struct InputTransformF73_NCHW44 { | |||
auto t##i##6 = d7; \ | |||
auto t##i##7 = d7; \ | |||
t##i##8 = MSUB(t##i##8, d7, v0, 0); \ | |||
t##i##0 = GiSubtractFloat32(t##i##0, d1); \ | |||
t##i##0 = SUBF(t##i##0, d1); \ | |||
t##i##1 = MSUB(t##i##1, d1, v0, 0); \ | |||
t##i##2 = MADD(t##i##2, d1, v0, 0); \ | |||
t##i##3 = MSUB(t##i##3, d1, v0, 1); \ | |||
t##i##4 = MADD(t##i##4, d1, v0, 1); \ | |||
t##i##5 = MSUB(t##i##5, d1, v0, 2); \ | |||
t##i##6 = MADD(t##i##6, d1, v0, 2); \ | |||
t##i##7 = GiSubtractFloat32(t##i##7, d1); \ | |||
t##i##7 = SUBF(t##i##7, d1); \ | |||
t##i##8 = MADD(t##i##8, d1, v0, 0); \ | |||
t##i##0 = MSUB(t##i##0, d2, v0, 3); \ | |||
t##i##1 = MSUB(t##i##1, d2, v1, 0); \ | |||
@@ -153,7 +152,7 @@ struct InputTransformF73_NCHW44 { | |||
t##i##4 = MSUB(t##i##4, d2, v1, 3); \ | |||
t##i##5 = MSUB(t##i##5, d2, v2, 0); \ | |||
t##i##6 = MSUB(t##i##6, d2, v2, 1); \ | |||
t##i##8 = GiSubtractFloat32(t##i##8, d2); \ | |||
t##i##8 = SUBF(t##i##8, d2); \ | |||
t##i##0 = MADD(t##i##0, d3, v2, 2); \ | |||
t##i##1 = MADD(t##i##1, d3, v2, 3); \ | |||
t##i##2 = MSUB(t##i##2, d3, v3, 0); \ | |||
@@ -185,7 +184,7 @@ struct InputTransformF73_NCHW44 { | |||
t##i##2 = MSUB(t##i##2, d6, v1, 1); \ | |||
t##i##3 = MADD(t##i##3, d6, v1, 0); \ | |||
t##i##4 = MSUB(t##i##4, d6, v3, 1); \ | |||
t##i##5 = GiSubtractFloat32(t##i##5, d6); \ | |||
t##i##5 = SUBF(t##i##5, d6); \ | |||
t##i##6 = MSUB(t##i##6, d6, v6, 2); \ | |||
t##i##8 = MSUB(t##i##8, d6, v2, 2); \ | |||
t##i##0 = MADD(t##i##0, d0, v0, 0); | |||
@@ -204,14 +203,14 @@ struct InputTransformF73_NCHW44 { | |||
d6 = t7##i; \ | |||
d7 = t7##i; \ | |||
d8 = MSUB(d8, t7##i, v0, 0); \ | |||
d0 = GiSubtractFloat32(d0, t1##i); \ | |||
d0 = SUBF(d0, t1##i); \ | |||
d1 = MSUB(d1, t1##i, v0, 0); \ | |||
d2 = MADD(d2, t1##i, v0, 0); \ | |||
d3 = MSUB(d3, t1##i, v0, 1); \ | |||
d4 = MADD(d4, t1##i, v0, 1); \ | |||
d5 = MSUB(d5, t1##i, v0, 2); \ | |||
d6 = MADD(d6, t1##i, v0, 2); \ | |||
d7 = GiSubtractFloat32(d7, t1##i); \ | |||
d7 = SUBF(d7, t1##i); \ | |||
d8 = MADD(d8, t1##i, v0, 0); \ | |||
d0 = MSUB(d0, t2##i, v0, 3); \ | |||
d1 = MSUB(d1, t2##i, v1, 0); \ | |||
@@ -220,7 +219,7 @@ struct InputTransformF73_NCHW44 { | |||
d4 = MSUB(d4, t2##i, v1, 3); \ | |||
d5 = MSUB(d5, t2##i, v2, 0); \ | |||
d6 = MSUB(d6, t2##i, v2, 1); \ | |||
d8 = GiSubtractFloat32(d8, t2##i); \ | |||
d8 = SUBF(d8, t2##i); \ | |||
d0 = MADD(d0, t3##i, v2, 2); \ | |||
d1 = MADD(d1, t3##i, v2, 3); \ | |||
d2 = MSUB(d2, t3##i, v3, 0); \ | |||
@@ -252,7 +251,7 @@ struct InputTransformF73_NCHW44 { | |||
d2 = MSUB(d2, t6##i, v1, 1); \ | |||
d3 = MADD(d3, t6##i, v1, 0); \ | |||
d4 = MSUB(d4, t6##i, v3, 1); \ | |||
d5 = GiSubtractFloat32(d5, t6##i); \ | |||
d5 = SUBF(d5, t6##i); \ | |||
d6 = MSUB(d6, t6##i, v6, 2); \ | |||
d8 = MSUB(d8, t6##i, v2, 2); \ | |||
d0 = MADD(d0, t0##i, v0, 0); \ | |||
@@ -325,7 +324,7 @@ struct OutputTransformF73_NCHW44 { | |||
size_t ocb = oc_index / pack_size; | |||
#define cb(m, n) \ | |||
auto v##m##n = Vector<float, 4>::load( \ | |||
auto v##m##n = GiLoadFloat32( \ | |||
output_transform_buf + \ | |||
(m * alpha + n) * OCB * nr_units_in_tile * pack_size + \ | |||
ocb * nr_units_in_tile * pack_size + unit_idx * pack_size); | |||
@@ -346,56 +345,112 @@ struct OutputTransformF73_NCHW44 { | |||
* 1 1.5 2.25 3.375 5.0625 7.59375 11.390625 | |||
* 0 0 0 0 0 0 1 | |||
*/ | |||
/* | |||
* v1addv2 = v1##m + v2##m; | |||
* v1subv2 = v1##m - v2##m; | |||
* v3addv4 = v3##m + v4##m; | |||
* v3subv4 = v3##m - v4##m; | |||
* v5addv6 = v5##m + v6##m; | |||
* v5subv6 = v5##m - v6##m; | |||
* auto t0##m = v0##m + v1addv2 + v3addv4 + v5addv6 + v7##m; | |||
* auto t1##m = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f + v7##m * 1.5f; | |||
* auto t2##m = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f + v7##m * 2.25f; | |||
* auto t3##m = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f + v7##m * 3.375f; | |||
* auto t4##m = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f + v7##m * 5.0625f; | |||
* auto t5##m = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + v7##m | |||
* * 7.59375f; auto t6##m = v1addv2 + v3addv4 * 64.f + v5addv6 * 0.015625f + | |||
* v7##m * 11.390625f + v8##m; | |||
*/ | |||
Vector<float, 4> v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, v5subv6; | |||
GI_FLOAT32_t v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, v5subv6; | |||
#define cb(m) \ | |||
v1addv2 = v1##m + v2##m; \ | |||
v1subv2 = v1##m - v2##m; \ | |||
v3addv4 = v3##m + v4##m; \ | |||
v3subv4 = v3##m - v4##m; \ | |||
v5addv6 = v5##m + v6##m; \ | |||
v5subv6 = v5##m - v6##m; \ | |||
auto t0##m = v0##m + v1addv2 + v3addv4 + v5addv6 + v7##m; \ | |||
auto t1##m = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f + v7##m * 1.5f; \ | |||
auto t2##m = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f + v7##m * 2.25f; \ | |||
auto t3##m = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f + v7##m * 3.375f; \ | |||
auto t4##m = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f + v7##m * 5.0625f; \ | |||
auto t5##m = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + v7##m * 7.59375f; \ | |||
auto t6##m = v1addv2 + v3addv4 * 64.f + v5addv6 * 0.015625f + v7##m * 11.390625f + \ | |||
v8##m; | |||
v1addv2 = ADDF(v1##m, v2##m); \ | |||
v1subv2 = SUBF(v1##m, v2##m); \ | |||
v3addv4 = ADDF(v3##m, v4##m); \ | |||
v3subv4 = SUBF(v3##m, v4##m); \ | |||
v5addv6 = ADDF(v5##m, v6##m); \ | |||
v5subv6 = SUBF(v5##m, v6##m); \ | |||
auto t0##m = ADDF(ADDF(ADDF(ADDF(v0##m, v1addv2), v3addv4), v5addv6), v7##m); \ | |||
auto t1##m = \ | |||
ADDF(ADDF(ADDF(v1subv2, MULSF(v3subv4, 2.f)), MULSF(v5subv6, 0.5f)), \ | |||
MULSF(v7##m, 1.5f)); \ | |||
auto t2##m = \ | |||
ADDF(ADDF(ADDF(v1addv2, MULSF(v3addv4, 4.f)), MULSF(v5addv6, 0.25f)), \ | |||
MULSF(v7##m, 2.25f)); \ | |||
auto t3##m = \ | |||
ADDF(ADDF(ADDF(v1subv2, MULSF(v3subv4, 8.f)), MULSF(v5subv6, 0.125f)), \ | |||
MULSF(v7##m, 3.375f)); \ | |||
auto t4##m = \ | |||
ADDF(ADDF(ADDF(v1addv2, MULSF(v3addv4, 16.f)), MULSF(v5addv6, 0.0625f)), \ | |||
MULSF(v7##m, 5.0625f)); \ | |||
auto t5##m = \ | |||
ADDF(ADDF(ADDF(v1subv2, MULSF(v3subv4, 32.f)), MULSF(v5subv6, 0.03125f)), \ | |||
MULSF(v7##m, 7.59375f)); \ | |||
auto t6##m = ADDF( \ | |||
ADDF(ADDF(ADDF(v1addv2, MULSF(v3addv4, 64.f)), MULSF(v5addv6, 0.015625f)), \ | |||
MULSF(v7##m, 11.390625f)), \ | |||
v8##m); | |||
UNROLL_CALL_NOWRAPPER(9, cb); | |||
#undef cb | |||
#define cb(m) \ | |||
v1addv2 = t##m##1 + t##m##2; \ | |||
v1subv2 = t##m##1 - t##m##2; \ | |||
v3addv4 = t##m##3 + t##m##4; \ | |||
v3subv4 = t##m##3 - t##m##4; \ | |||
v5addv6 = t##m##5 + t##m##6; \ | |||
v5subv6 = t##m##5 - t##m##6; \ | |||
v##m##0 = t##m##0 + v1addv2 + v3addv4 + v5addv6 + t##m##7; \ | |||
v##m##1 = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f + t##m##7 * 1.5f; \ | |||
v##m##2 = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f + t##m##7 * 2.25f; \ | |||
v##m##3 = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f + t##m##7 * 3.375; \ | |||
v##m##4 = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f + t##m##7 * 5.0625f; \ | |||
v##m##5 = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + t##m##7 * 7.59375f; \ | |||
v##m##6 = v1addv2 + v3addv4 * 64.f + v5addv6 * 0.015625f + t##m##7 * 11.390625f + \ | |||
t##m##8; | |||
/* | |||
* v1addv2 = t##m##1 + t##m##2; | |||
* v1subv2 = t##m##1 - t##m##2; | |||
* v3addv4 = t##m##3 + t##m##4; | |||
* v3subv4 = t##m##3 - t##m##4; | |||
* v5addv6 = t##m##5 + t##m##6; | |||
* v5subv6 = t##m##5 - t##m##6; | |||
* v##m##0 = t##m##0 + v1addv2 + v3addv4 + v5addv6 + t##m##7; | |||
* v##m##1 = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f + t##m##7 * 1.5f; | |||
* v##m##2 = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f + t##m##7 * 2.25f; | |||
* v##m##3 = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f + t##m##7 * 3.375; | |||
* v##m##4 = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f + t##m##7 * 5.0625f; | |||
* v##m##5 = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + t##m##7 * 7.59375f; | |||
* v##m##6 = v1addv2 + v3addv4 * 64.f + v5addv6 * 0.015625f + t##m##7 | |||
* * 11.390625f + t##m##8; | |||
*/ | |||
#define cb(m) \ | |||
v1addv2 = ADDF(t##m##1, t##m##2); \ | |||
v1subv2 = SUBF(t##m##1, t##m##2); \ | |||
v3addv4 = ADDF(t##m##3, t##m##4); \ | |||
v3subv4 = SUBF(t##m##3, t##m##4); \ | |||
v5addv6 = ADDF(t##m##5, t##m##6); \ | |||
v5subv6 = SUBF(t##m##5, t##m##6); \ | |||
v##m##0 = ADDF(ADDF(ADDF(ADDF(t##m##0, v1addv2), v3addv4), v5addv6), t##m##7); \ | |||
v##m##1 = \ | |||
ADDF(ADDF(ADDF(v1subv2, MULSF(v3subv4, 2.f)), MULSF(v5subv6, 0.5f)), \ | |||
MULSF(t##m##7, 1.5f)); \ | |||
v##m##2 = \ | |||
ADDF(ADDF(ADDF(v1addv2, MULSF(v3addv4, 4.f)), MULSF(v5addv6, 0.25f)), \ | |||
MULSF(t##m##7, 2.25f)); \ | |||
v##m##3 = \ | |||
ADDF(ADDF(ADDF(v1subv2, MULSF(v3subv4, 8.f)), MULSF(v5subv6, 0.125f)), \ | |||
MULSF(t##m##7, 3.375)); \ | |||
v##m##4 = \ | |||
ADDF(ADDF(ADDF(v1addv2, MULSF(v3addv4, 16.f)), MULSF(v5addv6, 0.0625f)), \ | |||
MULSF(t##m##7, 5.0625f)); \ | |||
v##m##5 = \ | |||
ADDF(ADDF(ADDF(v1subv2, MULSF(v3subv4, 32.f)), MULSF(v5subv6, 0.03125f)), \ | |||
MULSF(t##m##7, 7.59375f)); \ | |||
v##m##6 = ADDF( \ | |||
ADDF(ADDF(ADDF(v1addv2, MULSF(v3addv4, 64.f)), MULSF(v5addv6, 0.015625f)), \ | |||
MULSF(t##m##7, 11.390625f)), \ | |||
t##m##8); | |||
UNROLL_CALL_NOWRAPPER(7, cb); | |||
#undef cb | |||
Vector<float, 4> vbias; | |||
GI_FLOAT32_t vbias; | |||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
vbias = Vector<float, 4>::load(bias + oc); | |||
vbias = GiLoadFloat32(bias + oc); | |||
#define cb(m, n) v##m##n += vbias; | |||
#define cb(m, n) v##m##n = ADDF(v##m##n, vbias); | |||
UNROLL_CALL_RAW_D2(7, 7, cb); | |||
#undef cb | |||
} | |||
if (bmode != BiasMode::BIAS) { | |||
#define cb(m, n) v##m##n = op(GiFixLenType2GiFloat32Type(CONCAT(v##m, n).value)); | |||
#define cb(m, n) v##m##n = op(CONCAT(v##m, n)); | |||
UNROLL_CALL_RAW_D2(7, 7, cb); | |||
#undef cb | |||
} | |||
@@ -405,12 +460,15 @@ struct OutputTransformF73_NCHW44 { | |||
size_t ow = ow_start + owo; \ | |||
if (oh < OH && ow < OW) { \ | |||
if (bmode == BiasMode::BIAS) { \ | |||
v##oho##owo += Vector<float, 4>::load( \ | |||
bias + oc * OH * OW + oh * OW * pack_size + ow * pack_size); \ | |||
v##oho##owo = op(GiFixLenType2GiFloat32Type(v##oho##owo.value)); \ | |||
v##oho##owo = ADDF( \ | |||
v##oho##owo, GiLoadFloat32( \ | |||
bias + oc * OH * OW + \ | |||
oh * OW * pack_size + ow * pack_size)); \ | |||
v##oho##owo = op(v##oho##owo); \ | |||
} \ | |||
v##oho##owo.save( \ | |||
output + oc * OH * OW + oh * OW * pack_size + ow * pack_size); \ | |||
GiStoreFloat32( \ | |||
output + oc * OH * OW + oh * OW * pack_size + ow * pack_size, \ | |||
v##oho##owo); \ | |||
} \ | |||
} while (0); | |||
UNROLL_CALL_RAW_D2(7, 7, out_save); | |||
@@ -458,34 +516,53 @@ void winograd_F73_mk4_f_nchw44::filter( | |||
pack_size * pack_size + | |||
ic_inner * pack_size; | |||
#define cb(m, n) \ | |||
Vector<float, 4> g##m##n = Vector<float, 4>::load( \ | |||
fptr + (m * KERNEL_SIZE + n) * pack_size * pack_size); | |||
#define cb(m, n) \ | |||
GI_FLOAT32_t g##m##n = \ | |||
GiLoadFloat32(fptr + (m * KERNEL_SIZE + n) * pack_size * pack_size); | |||
UNROLL_CALL_NOWRAPPER_D2(3, 3, cb) | |||
#undef cb | |||
#define FILTER_TRANSFORM(n, wd, g) \ | |||
auto wd##n##0 = g##0##n * 0.6666667f; \ | |||
auto wd##n##1 = (g##0##n + g##1##n + g##2##n) * 0.4444444f; \ | |||
auto wd##n##2 = (g##0##n - g##1##n + g##2##n) * 0.0888889f; \ | |||
auto wd##n##3 = \ | |||
g##0##n * 0.0222222f + g##1##n * 0.0444444f + g##2##n * 0.0888889f; \ | |||
auto wd##n##4 = \ | |||
g##0##n * -0.0031746f + g##1##n * 0.0063492f + g##2##n * -0.0126984f; \ | |||
auto wd##n##5 = \ | |||
g##0##n * -0.7111111f + g##1##n * -0.3555556f + g##2##n * -0.1777778f; \ | |||
auto wd##n##6 = \ | |||
g##0##n * -0.3555556f + g##1##n * 0.1777778f + g##2##n * -0.0888889f; \ | |||
auto wd##n##7 = \ | |||
g##0##n * -0.1523810f + g##1##n * -0.2285714f + g##2##n * -0.3428572f; \ | |||
/* | |||
* auto wd##n##0 = g##0##n * 0.6666667f; | |||
* auto wd##n##1 = (g##0##n + g##1##n + g##2##n) * 0.4444444f; | |||
* auto wd##n##2 = (g##0##n - g##1##n + g##2##n) * 0.0888889f; | |||
* auto wd##n##3 = | |||
* g##0##n * 0.0222222f + g##1##n * 0.0444444f + g##2##n * | |||
* 0.0888889f; auto wd##n##4 = g##0##n * -0.0031746f + g##1##n * | |||
* 0.0063492f + g##2##n * -0.0126984f; auto wd##n##5 = g##0##n * | |||
* -0.7111111f + g##1##n * -0.3555556f + g##2##n * -0.1777778f; auto | |||
* wd##n##6 = g##0##n * -0.3555556f + g##1##n * 0.1777778f + g##2##n * | |||
* -0.0888889f; auto wd##n##7 = g##0##n * -0.1523810f + g##1##n * | |||
* -0.2285714f + g##2##n * -0.3428572f; | |||
*/ | |||
#define FILTER_TRANSFORM(n, wd, g) \ | |||
auto wd##n##0 = MULSF(g##0##n, 0.6666667f); \ | |||
auto wd##n##1 = MULSF(ADDF(ADDF(g##0##n, g##1##n), g##2##n), 0.4444444f); \ | |||
auto wd##n##2 = MULSF(ADDF(SUBF(g##0##n, g##1##n), g##2##n), 0.0888889f); \ | |||
auto wd##n##3 = \ | |||
ADDF(ADDF(MULSF(g##0##n, 0.0222222f), MULSF(g##1##n, 0.0444444f)), \ | |||
MULSF(g##2##n, 0.0888889f)); \ | |||
auto wd##n##4 = \ | |||
ADDF(ADDF(MULSF(g##0##n, -0.0031746f), MULSF(g##1##n, 0.0063492f)), \ | |||
MULSF(g##2##n, -0.0126984f)); \ | |||
auto wd##n##5 = \ | |||
ADDF(ADDF(MULSF(g##0##n, -0.7111111f), MULSF(g##1##n, -0.3555556f)), \ | |||
MULSF(g##2##n, -0.1777778f)); \ | |||
auto wd##n##6 = \ | |||
ADDF(ADDF(MULSF(g##0##n, -0.3555556f), MULSF(g##1##n, 0.1777778f)), \ | |||
MULSF(g##2##n, -0.0888889f)); \ | |||
auto wd##n##7 = \ | |||
ADDF(ADDF(MULSF(g##0##n, -0.1523810f), MULSF(g##1##n, -0.2285714f)), \ | |||
MULSF(g##2##n, -0.3428572f)); \ | |||
auto wd##n##8 = g##2##n; | |||
UNROLL_CALL_RAW(3, FILTER_TRANSFORM, wd, g); | |||
UNROLL_CALL_RAW(9, FILTER_TRANSFORM, ret, wd); | |||
#undef FILTER_TRANSFORM | |||
#define cb_save(m, n) \ | |||
ret##m##n.save( \ | |||
GiStoreFloat32( \ | |||
filter_transform_buf + (m * alpha + n) * OC * IC + ocb * IC * pack_size + \ | |||
icb * pack_size * pack_size + ic_inner * pack_size); | |||
icb * pack_size * pack_size + ic_inner * pack_size, \ | |||
ret##m##n); | |||
UNROLL_CALL_NOWRAPPER_D2(9, 9, cb_save) | |||
#undef cb_save | |||
} | |||
@@ -1,215 +0,0 @@ | |||
#pragma once | |||
#include <cstring> | |||
#include "src/common/utils.h" | |||
#include "src/fallback/general_intrinsic/gi_float.h" | |||
namespace megdnn { | |||
namespace fallback { | |||
template <typename ctype, size_t len> | |||
struct Vector; | |||
template <> | |||
struct Vector<float, 4> { | |||
GI_FLOAT32_FIXLEN_t value; | |||
Vector() {} | |||
Vector(const float v) { value = GiFloat32Type2FixLenType(GiBroadcastFloat32(v)); } | |||
Vector(const Vector& lr) { value = lr.value; } | |||
Vector(const Vector&& lr) { value = std::move(lr.value); } | |||
Vector(const GI_FLOAT32_t& v) { value = GiFloat32Type2FixLenType(v); } | |||
static Vector load(const float* addr) { | |||
Vector v; | |||
v.value = GiFloat32Type2FixLenType(GiLoadFloat32(addr)); | |||
return v; | |||
} | |||
static void save(float* addr, const Vector& v) { | |||
GiStoreFloat32(addr, GiFixLenType2GiFloat32Type(v.value)); | |||
} | |||
void save(float* addr) { save(addr, *this); } | |||
Vector operator+(const Vector& lr) { | |||
Vector dst; | |||
dst.value = GiFloat32Type2FixLenType(GiAddFloat32( | |||
GiFixLenType2GiFloat32Type(value), | |||
GiFixLenType2GiFloat32Type(lr.value))); | |||
return dst; | |||
} | |||
Vector& operator+=(const Vector& lr) { | |||
value = GiFloat32Type2FixLenType(GiAddFloat32( | |||
GiFixLenType2GiFloat32Type(value), | |||
GiFixLenType2GiFloat32Type(lr.value))); | |||
return *this; | |||
} | |||
Vector operator-(const Vector& lr) { | |||
Vector dst; | |||
dst.value = GiFloat32Type2FixLenType(GiSubtractFloat32( | |||
GiFixLenType2GiFloat32Type(value), | |||
GiFixLenType2GiFloat32Type(lr.value))); | |||
return dst; | |||
} | |||
Vector& operator-=(const Vector& lr) { | |||
value = GiFloat32Type2FixLenType(GiSubtractFloat32( | |||
GiFixLenType2GiFloat32Type(value), | |||
GiFixLenType2GiFloat32Type(lr.value))); | |||
return *this; | |||
} | |||
Vector operator*(float lr) { | |||
Vector dst; | |||
dst.value = GiFloat32Type2FixLenType( | |||
GiMultiplyScalerFloat32(GiFixLenType2GiFloat32Type(value), lr)); | |||
return dst; | |||
} | |||
Vector operator*(const Vector& lr) { | |||
Vector dst; | |||
dst.value = GiFloat32Type2FixLenType(GiMultiplyFloat32( | |||
GiFixLenType2GiFloat32Type(value), | |||
GiFixLenType2GiFloat32Type(lr.value))); | |||
return dst; | |||
} | |||
Vector& operator*=(const Vector& lr) { | |||
value = GiFloat32Type2FixLenType(GiMultiplyFloat32( | |||
GiFixLenType2GiFloat32Type(value), | |||
GiFixLenType2GiFloat32Type(lr.value))); | |||
return *this; | |||
} | |||
Vector& operator=(const Vector& lr) { | |||
value = lr.value; | |||
return *this; | |||
} | |||
Vector& operator=(const Vector&& lr) { | |||
value = std::move(lr.value); | |||
return *this; | |||
} | |||
Vector operator-() { | |||
Vector dst; | |||
dst.value = -value; | |||
return dst; | |||
} | |||
}; | |||
template <> | |||
struct Vector<float, 8> { | |||
GI_FLOAT32_FIXLEN_V2_t value; | |||
Vector() {} | |||
Vector(const float v) { | |||
value.val[0] = GiFloat32Type2FixLenType(GiBroadcastFloat32(v)); | |||
value.val[1] = GiFloat32Type2FixLenType(GiBroadcastFloat32(v)); | |||
} | |||
Vector(const Vector& lr) { value = lr.value; } | |||
Vector(const Vector&& lr) { value = std::move(lr.value); } | |||
Vector(const GI_FLOAT32_V2_t& v) { value = GiFloat32Type2FixLenV2Type(v); } | |||
static Vector load(const float* addr) { | |||
Vector v; | |||
v.value = GiFloat32Type2FixLenV2Type(GiLoadFloat32V2(addr)); | |||
return v; | |||
} | |||
static void save(float* addr, const Vector& v) { | |||
GiStoreFloat32V2(addr, GiFixLenType2GiFloat32V2Type(v.value)); | |||
} | |||
void save(float* addr) { save(addr, *this); } | |||
Vector operator+(const Vector& lr) { | |||
Vector dst; | |||
dst.value.val[0] = GiFloat32Type2FixLenType(GiAddFloat32( | |||
GiFixLenType2GiFloat32Type(value.val[0]), | |||
GiFixLenType2GiFloat32Type(lr.value.val[0]))); | |||
dst.value.val[1] = GiFloat32Type2FixLenType(GiAddFloat32( | |||
GiFixLenType2GiFloat32Type(value.val[1]), | |||
GiFixLenType2GiFloat32Type(lr.value.val[1]))); | |||
return dst; | |||
} | |||
Vector& operator+=(const Vector& lr) { | |||
value.val[0] = GiFloat32Type2FixLenType(GiAddFloat32( | |||
GiFixLenType2GiFloat32Type(value.val[0]), | |||
GiFixLenType2GiFloat32Type(lr.value.val[0]))); | |||
value.val[1] = GiFloat32Type2FixLenType(GiAddFloat32( | |||
GiFixLenType2GiFloat32Type(value.val[1]), | |||
GiFixLenType2GiFloat32Type(lr.value.val[1]))); | |||
return *this; | |||
} | |||
Vector& add(const Vector& lr) { | |||
value.val[0] = GiFloat32Type2FixLenType(GiAddFloat32( | |||
GiFixLenType2GiFloat32Type(value.val[0]), | |||
GiFixLenType2GiFloat32Type(lr.value.val[0]))); | |||
value.val[1] = GiFloat32Type2FixLenType(GiAddFloat32( | |||
GiFixLenType2GiFloat32Type(value.val[1]), | |||
GiFixLenType2GiFloat32Type(lr.value.val[1]))); | |||
return *this; | |||
} | |||
Vector operator-(const Vector& lr) { | |||
Vector dst; | |||
dst.value.val[0] = GiFloat32Type2FixLenType(GiSubtractFloat32( | |||
GiFixLenType2GiFloat32Type(value.val[0]), | |||
GiFixLenType2GiFloat32Type(lr.value.val[0]))); | |||
dst.value.val[1] = GiFloat32Type2FixLenType(GiSubtractFloat32( | |||
GiFixLenType2GiFloat32Type(value.val[1]), | |||
GiFixLenType2GiFloat32Type(lr.value.val[1]))); | |||
return dst; | |||
} | |||
Vector& operator-=(const Vector& lr) { | |||
value.val[0] = GiFloat32Type2FixLenType(GiSubtractFloat32( | |||
GiFixLenType2GiFloat32Type(value.val[0]), | |||
GiFixLenType2GiFloat32Type(lr.value.val[0]))); | |||
value.val[1] = GiFloat32Type2FixLenType(GiSubtractFloat32( | |||
GiFixLenType2GiFloat32Type(value.val[1]), | |||
GiFixLenType2GiFloat32Type(lr.value.val[1]))); | |||
return *this; | |||
} | |||
Vector operator*(float lr) { | |||
Vector dst; | |||
dst.value.val[0] = GiFloat32Type2FixLenType( | |||
GiMultiplyScalerFloat32(GiFixLenType2GiFloat32Type(value.val[0]), lr)); | |||
dst.value.val[1] = GiFloat32Type2FixLenType( | |||
GiMultiplyScalerFloat32(GiFixLenType2GiFloat32Type(value.val[1]), lr)); | |||
return dst; | |||
} | |||
//! val + lr * n | |||
Vector& mla(const Vector& lr, float n) { | |||
value.val[0] = GiFloat32Type2FixLenType(GiMultiplyAddScalarFloat32( | |||
GiFixLenType2GiFloat32Type(value.val[0]), | |||
GiFixLenType2GiFloat32Type(lr.value.val[0]), n)); | |||
value.val[1] = GiFloat32Type2FixLenType(GiMultiplyAddScalarFloat32( | |||
GiFixLenType2GiFloat32Type(value.val[1]), | |||
GiFixLenType2GiFloat32Type(lr.value.val[1]), n)); | |||
return *this; | |||
} | |||
Vector operator*(const Vector& lr) { | |||
Vector dst; | |||
dst.value.val[0] = GiFloat32Type2FixLenType(GiMultiplyFloat32( | |||
GiFixLenType2GiFloat32Type(value.val[0]), | |||
GiFixLenType2GiFloat32Type(lr.value.val[0]))); | |||
dst.value.val[1] = GiFloat32Type2FixLenType(GiMultiplyFloat32( | |||
GiFixLenType2GiFloat32Type(value.val[1]), | |||
GiFixLenType2GiFloat32Type(lr.value.val[1]))); | |||
return dst; | |||
} | |||
Vector& operator*=(const Vector& lr) { | |||
value.val[0] = GiFloat32Type2FixLenType(GiMultiplyFloat32( | |||
GiFixLenType2GiFloat32Type(value.val[0]), | |||
GiFixLenType2GiFloat32Type(lr.value.val[0]))); | |||
value.val[1] = GiFloat32Type2FixLenType(GiMultiplyFloat32( | |||
GiFixLenType2GiFloat32Type(value.val[1]), | |||
GiFixLenType2GiFloat32Type(lr.value.val[1]))); | |||
return *this; | |||
} | |||
Vector& operator=(const Vector& lr) { | |||
value = lr.value; | |||
return *this; | |||
} | |||
Vector& operator=(const Vector&& lr) { | |||
value = std::move(lr.value); | |||
return *this; | |||
} | |||
Vector operator-() { | |||
Vector dst; | |||
dst.value.val[0] = -value.val[0]; | |||
dst.value.val[1] = -value.val[1]; | |||
return dst; | |||
} | |||
}; | |||
} // namespace fallback | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -857,6 +857,18 @@ GI_FLOAT32_t GiMultiplyScalerFloat32(GI_FLOAT32_t Vector1, float Scaler) { | |||
} | |||
GI_FORCEINLINE | |||
GI_FLOAT32_V2_t GiMultiplyScalerFloat32V2(GI_FLOAT32_V2_t Vector1, float Scaler) { | |||
GI_FLOAT32_V2_t ret; | |||
GiSetSubVectorFloat32V2( | |||
ret, 0, | |||
GiMultiplyScalerFloat32(GiGetSubVectorFloat32V2(Vector1, 0), Scaler)); | |||
GiSetSubVectorFloat32V2( | |||
ret, 1, | |||
GiMultiplyScalerFloat32(GiGetSubVectorFloat32V2(Vector1, 1), Scaler)); | |||
return ret; | |||
} | |||
GI_FORCEINLINE | |||
GI_FLOAT32_t GiMultiplyAddFloat32( | |||
GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { | |||
#if defined(GI_NEON_INTRINSICS) | |||
@@ -888,6 +900,23 @@ GI_FLOAT32_t GiMultiplyAddScalarFloat32( | |||
} | |||
GI_FORCEINLINE | |||
GI_FLOAT32_V2_t GiMultiplyAddScalarFloat32V2( | |||
GI_FLOAT32_V2_t VectorSum, GI_FLOAT32_V2_t Vector, float Scalar) { | |||
GI_FLOAT32_V2_t ret; | |||
GiSetSubVectorFloat32V2( | |||
ret, 0, | |||
GiMultiplyAddScalarFloat32( | |||
GiGetSubVectorFloat32V2(VectorSum, 0), | |||
GiGetSubVectorFloat32V2(Vector, 0), Scalar)); | |||
GiSetSubVectorFloat32V2( | |||
ret, 1, | |||
GiMultiplyAddScalarFloat32( | |||
GiGetSubVectorFloat32V2(VectorSum, 1), | |||
GiGetSubVectorFloat32V2(Vector, 1), Scalar)); | |||
return ret; | |||
} | |||
GI_FORCEINLINE | |||
GI_FLOAT32_t GiMultiplySubScalarFloat32( | |||
GI_FLOAT32_t VectorSub, GI_FLOAT32_t Vector, float Scalar) { | |||
#if defined(GI_NEON_INTRINSICS) | |||
@@ -951,6 +980,26 @@ GI_FLOAT32_t GiDivideFloat32(GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { | |||
#endif | |||
} | |||
#define OPV2(op) \ | |||
GI_FORCEINLINE \ | |||
GI_FLOAT32_V2_t op##V2(GI_FLOAT32_V2_t Vector1, GI_FLOAT32_V2_t Vector2) { \ | |||
GI_FLOAT32_V2_t ret; \ | |||
GiSetSubVectorFloat32V2( \ | |||
ret, 0, \ | |||
op(GiGetSubVectorFloat32V2(Vector1, 0), \ | |||
GiGetSubVectorFloat32V2(Vector2, 0))); \ | |||
GiSetSubVectorFloat32V2( \ | |||
ret, 1, \ | |||
op(GiGetSubVectorFloat32V2(Vector1, 1), \ | |||
GiGetSubVectorFloat32V2(Vector2, 1))); \ | |||
return ret; \ | |||
} | |||
OPV2(GiAddFloat32); | |||
OPV2(GiSubtractFloat32); | |||
OPV2(GiMultiplyFloat32); | |||
OPV2(GiDivideFloat32); | |||
#undef OPV2 | |||
GI_FORCEINLINE | |||
GI_FLOAT32_t GiRecpeSFloat32(GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { | |||
#if defined(GI_NEON64_INTRINSICS) | |||