diff --git a/dnn/src/fallback/general_intrinsic/gi_common.h b/dnn/src/fallback/general_intrinsic/gi_common.h index 4beed132..c8b06dca 100644 --- a/dnn/src/fallback/general_intrinsic/gi_common.h +++ b/dnn/src/fallback/general_intrinsic/gi_common.h @@ -38,8 +38,10 @@ #ifdef _WIN32 //! GI stand for general intrinsic +#define _GI_ALIGN_16 __declspec(align(16)) #define GI_DECLSPEC_ALIGN(variable, alignment) DECLSPEC_ALIGN(alignment) variable #else +#define _GI_ALIGN_16 __attribute__((aligned(16))) #define GI_DECLSPEC_ALIGN(variable, alignment) \ variable __attribute__((aligned(alignment))) #endif @@ -82,8 +84,50 @@ #endif #endif +#if defined(GI_TEST_NAIVE) +#undef GI_NEON_INTRINSICS +#undef GI_NEON64_INTRINSICS +#undef GI_NEON32_INTRINSICS +#undef GI_FMA_INTRINSICS +#undef GI_AVX2_INTRINSICS +#undef GI_AVX_INTRINSICS +#undef GI_SSE42_INTRINSICS +#undef GI_SSE2_INTRINSICS +#endif + +//! general intrinsic support dynamic length simd, if avx or avx2 the simd +//! length is 256 +#if defined(GI_AVX_INTRINSICS) || defined(GI_AVX2_INTRINSICS) || \ + defined(GI_FMA_INTRINSICS) +//! if neon and sse the simd lenght is 128 +#define GI_SIMD_LEN 256 +#define GI_SIMD_LEN_BYTE 32 +#elif defined(GI_NEON_INTRINSICS) || defined(GI_SSE2_INTRINSICS) || \ + defined(GI_SSE42_INTRINSICS) +#define GI_SIMD_LEN 128 +#define GI_SIMD_LEN_BYTE 16 +#else +//! if no simd hardware support, the simd is implemented by C, default set to +//! 128 +#define GI_SIMD_LEN 128 +#define GI_SIMD_LEN_BYTE 16 +#endif + +#define gi_trap() __builtin_trap() + +//! for ci test now +enum GiSimdType { + GI_UNKNOWN, + GI_NAIVE, + GI_AVX, + GI_SSE42, + GI_SSE2, + GI_NEON, +}; + #if defined(GI_AVX_INTRINSICS) || defined(GI_AVX2_INTRINSICS) || \ defined(GI_FMA_INTRINSICS) +#define __gi_simd_type GI_AVX typedef __m256 GI_FLOAT32_t; typedef __m256i GI_UINT8_t; typedef __m256i GI_INT8_t; @@ -91,46 +135,177 @@ typedef __m256i GI_INT16_t; typedef __m256i GI_INT32_t; typedef __m256i GI_UINT32_t; #elif defined(GI_NEON_INTRINSICS) +#define __gi_simd_type GI_NEON typedef float32x4_t GI_FLOAT32_t; typedef uint8x16_t GI_UINT8_t; typedef int8x16_t GI_INT8_t; typedef int16x8_t GI_INT16_t; typedef int32x4_t GI_INT32_t; typedef uint32x4_t GI_UINT32_t; +typedef float32x4x2_t GI_FLOAT32_V2_t; +typedef float32x4x4_t GI_FLOAT32_V4_t; +typedef int32x4x2_t GI_INT32_V2_t; +typedef int32x4x4_t GI_INT32_V4_t; +typedef int16x8x2_t GI_INT16_V2_t; +typedef int8x16x2_t GI_INT8_V2_t; +typedef int64x2_t GI_INT64_t; #elif defined(GI_SSE2_INTRINSICS) || defined(GI_SSE42_INTRINSICS) + +#if defined(GI_SSE42_INTRINSICS) +#define __gi_simd_type GI_SSE42 +#elif defined(GI_SSE2_INTRINSICS) +#define __gi_simd_type GI_SSE2 +#else +#define __gi_simd_type GI_UNKNOWN +#error "code issue happened!!" 
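/*
 * Illustrative sketch only, not part of this patch: every fallback ("naive")
 * branch in gi_float.h / gi_int.h follows the same pattern -- iterate over
 * GI_SIMD_LEN_BYTE / sizeof(element) lanes of the vector typedefs declared in
 * this header. The helper name gi_naive_add_f32 is hypothetical and assumes
 * the GCC vector_size typedef of GI_FLOAT32_t defined further below.
 *
 *     GI_FORCEINLINE GI_FLOAT32_t gi_naive_add_f32(GI_FLOAT32_t a, GI_FLOAT32_t b) {
 *         GI_FLOAT32_t ret;
 *         for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(float); i++) {
 *             ret[i] = a[i] + b[i];  // per-lane scalar fallback
 *         }
 *         return ret;
 *     }
 */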
+#endif + typedef __m128 GI_FLOAT32_t; typedef __m128i GI_UINT8_t; typedef __m128i GI_INT8_t; typedef __m128i GI_INT16_t; typedef __m128i GI_INT32_t; typedef __m128i GI_UINT32_t; +typedef __m128i GI_INT64_t; +#define _INSERTPS_NDX(srcField, dstField) (((srcField) << 6) | ((dstField) << 4)) +#define _M64(out, inp) _mm_storel_epi64((__m128i*)&(out), inp) +#define _pM128i(a) _mm_loadl_epi64((__m128i*)&(a)) +#define _pM128(a) _mm_castsi128_ps(_pM128i(a)) +#define _M128i(a) _mm_castps_si128(a) +#define _M128(a) _mm_castsi128_ps(a) +#if defined(__x86_64__) +#define _M64f(out, inp) out.m64_i64[0] = _mm_cvtsi128_si64(_M128i(inp)); +#else +#define _M64f(out, inp) _mm_storel_epi64((__m128i*)&(out), _M128i(inp)) +#endif +#define _SSE_SWITCH16(NAME, a, b, LANE) \ + switch (LANE) { \ + case 0: \ + return NAME(a b, 0); \ + case 1: \ + return NAME(a b, 1); \ + case 2: \ + return NAME(a b, 2); \ + case 3: \ + return NAME(a b, 3); \ + case 4: \ + return NAME(a b, 4); \ + case 5: \ + return NAME(a b, 5); \ + case 6: \ + return NAME(a b, 6); \ + case 7: \ + return NAME(a b, 7); \ + case 8: \ + return NAME(a b, 8); \ + case 9: \ + return NAME(a b, 9); \ + case 10: \ + return NAME(a b, 10); \ + case 11: \ + return NAME(a b, 11); \ + case 12: \ + return NAME(a b, 12); \ + case 13: \ + return NAME(a b, 13); \ + case 14: \ + return NAME(a b, 14); \ + case 15: \ + return NAME(a b, 15); \ + default: \ + gi_trap(); \ + return NAME(a b, 0); \ + } +#if !defined(__SSE3__) +GI_FORCEINLINE __m128i _sse2_mm_alignr_epi8(__m128i b, __m128i a, int imm8) { + int imm2 = sizeof(__m128i) - imm8; + return _mm_or_si128(_mm_srli_si128(a, imm8), _mm_slli_si128(b, imm2)); +} +#endif + +#define _SSE_COMMA , +GI_FORCEINLINE __m128i _MM_ALIGNR_EPI8(__m128i a, __m128i b, int LANE) { +#if !defined(__SSE3__) + _SSE_SWITCH16(_sse2_mm_alignr_epi8, a, _SSE_COMMA b, LANE) +#else + _SSE_SWITCH16(_mm_alignr_epi8, a, _SSE_COMMA b, LANE) +#endif +} +typedef float float32_t; +typedef double float64_t; +typedef union __m64_128 { + uint64_t m64_u64[1]; + int64_t m64_i64[1]; + float64_t m64_d64[1]; + uint32_t m64_u32[2]; + int32_t m64_i32[2]; + float32_t m64_f32[2]; + int16_t m64_i16[4]; + uint16_t m64_u16[4]; + int8_t m64_i8[8]; + uint8_t m64_u8[8]; +} __m64_128; +typedef __m64_128 float32x2_t; + +#define return64(a) \ + _M64(res64, a); \ + return res64; +#define return64f(a) \ + _M64f(res64, a); \ + return res64; +#define _sse_vextq_s32(a, b, c) _MM_ALIGNR_EPI8(b, a, c * 4) +#define _sse_vget_lane_f32(vec, lane) vec.m64_f32[lane] #else +#define __gi_simd_type GI_NAIVE typedef float GI_FLOAT32_t __attribute__((vector_size(16))); typedef uint8_t GI_UINT8_t __attribute__((vector_size(16))); typedef int8_t GI_INT8_t __attribute__((vector_size(16))); typedef int16_t GI_INT16_t __attribute__((vector_size(16))); typedef int32_t GI_INT32_t __attribute__((vector_size(16))); typedef uint32_t GI_UINT32_t __attribute__((vector_size(16))); +typedef int64_t GI_INT64_t __attribute__((vector_size(16))); +#if !defined(__arm__) && !defined(__aarch64__) +typedef float float32x2_t __attribute__((vector_size(8))); #endif - -//! general intrinsic support dynamic length simd, if avx or avx2 the simd -//! length is 256 -#if defined(GI_AVX_INTRINSICS) || defined(GI_AVX2_INTRINSICS) || \ - defined(GI_FMA_INTRINSICS) -//! 
if neon and sse the simd lenght is 128 -#define GI_SIMD_LEN 256 -#define GI_SIMD_LEN_BYTE 32 -#elif defined(GI_NEON_INTRINSICS) || defined(GI_SSE2_INTRINSICS) || \ - defined(GI_SSE42_INTRINSICS) -#define GI_SIMD_LEN 128 -#define GI_SIMD_LEN_BYTE 16 -#else -//! if no simd hardware support, the simd is implemented by C, default set to -//! 128 -#define GI_SIMD_LEN 128 -#define GI_SIMD_LEN_BYTE 16 +typedef float float32_t; #endif +//! some GI api do not support full GiSimdType +//! for example: GiAbsInt32 do not imp SSE2 case +//! when *_t will define as _m128*(may be long long) +//! vector index do not have same logic as naive vector +typedef float GI_FLOAT32_NAIVE_t __attribute__((vector_size(16))); +typedef uint8_t GI_UINT8_NAIVE_t __attribute__((vector_size(16))); +typedef int8_t GI_INT8_NAIVE_t __attribute__((vector_size(16))); +typedef int16_t GI_INT16_NAIVE_t __attribute__((vector_size(16))); +typedef int32_t GI_INT32_NAIVE_t __attribute__((vector_size(16))); +typedef uint32_t GI_UINT32_NAIVE_t __attribute__((vector_size(16))); +typedef int64_t GI_INT64_NAIVE_t __attribute__((vector_size(16))); +typedef float float32x2_NAIVE_t __attribute__((vector_size(8))); +typedef struct { + GI_INT32_NAIVE_t val[2]; +} GI_INT32_V2_NAIVE_t; + +typedef struct { + GI_INT32_NAIVE_t val[4]; +} GI_INT32_V4_NAIVE_t; + +typedef struct { + GI_FLOAT32_NAIVE_t val[2]; +} GI_FLOAT32_V2_NAIVE_t; + +typedef struct { + GI_FLOAT32_NAIVE_t val[4]; +} GI_FLOAT32_V4_NAIVE_t; + +typedef struct { + GI_INT16_NAIVE_t val[2]; +} GI_INT16_V2_NAIVE_t; + +typedef struct { + GI_INT8_NAIVE_t val[2]; +} GI_INT8_V2_NAIVE_t; + #define Max(a, b) (a) > (b) ? (a) : (b) #define Min(a, b) (a) < (b) ? (a) : (b) @@ -146,6 +321,7 @@ typedef uint32_t GI_UINT32_t __attribute__((vector_size(16))); #endif #endif +#if !defined(GI_NEON_INTRINSICS) typedef struct { GI_INT32_t val[2]; } GI_INT32_V2_t; @@ -169,6 +345,7 @@ typedef struct { typedef struct { GI_INT8_t val[2]; } GI_INT8_V2_t; +#endif GI_FORCEINLINE GI_INT32_t GiAndInt32(GI_INT32_t Vector1, GI_INT32_t Vector2) { @@ -259,6 +436,34 @@ GI_INT8_t GiBroadcastInt8(int8_t Value) { #endif } +GI_FORCEINLINE +GiSimdType GiGetSimdType() { + //! override by special macro to insure ci have test naive and sse2 + //! now we do not imp GI_AVX to now and x64 ci device will test GI_SSE42 + //! now arm ci device will test GI_NEON + //! insure test GI_SSE2 by command: + //! --copt -march=core2 --copt -mno-sse4.2 + //! --copt -mno-sse3 --copt -DGI_TEST_SSE2 + //! insure test GI_NAIVE by command: + //! --copt -DGI_TEST_SSE2 + //! DNN code at least need sse2 at x86 + //! so we can not test GI_NAIVE by + //! --copt -march=core2 --copt -mno-sse4.2 + //! --copt -mno-sse3 --copt -mno-sse2 + //! --copt -DGI_TEST_NAIVE + //! about CMake, can override build flags to CMAKE_CXX_FLAGS/CMAKE_C_FLAGS by + //! 
EXTRA_CMAKE_ARGS when use scripts/cmake-build/*.sh +#if defined(GI_TEST_NAIVE) +#undef __gi_simd_type +#define __gi_simd_type GI_NAIVE +#elif defined(GI_TEST_SSE2) +#undef __gi_simd_type +#define __gi_simd_type GI_SSE2 +#endif + + return __gi_simd_type; +} + __attribute__((unused)) const GI_INT8_t vzero_int8 = GiBroadcastInt8(0); __attribute__((unused)) const GI_INT32_t vzero = GiBroadcastInt32(0); __attribute__((unused)) const GI_FLOAT32_t vfzero = GiBroadcastFloat32(0.0f); diff --git a/dnn/src/fallback/general_intrinsic/gi_float.h b/dnn/src/fallback/general_intrinsic/gi_float.h index 8af5c231..e2910fa6 100644 --- a/dnn/src/fallback/general_intrinsic/gi_float.h +++ b/dnn/src/fallback/general_intrinsic/gi_float.h @@ -71,9 +71,13 @@ GI_INT32_t GiRoundAsInt32(GI_FLOAT32_t Vector) { return _mm_cvttps_epi32(_mm_add_ps(Vector, vinc0)); #else GI_INT32_t ret; + GI_INT32_NAIVE_t tmp_ret; + GI_FLOAT32_NAIVE_t s0; + memcpy(&s0, &Vector, sizeof(GI_FLOAT32_NAIVE_t)); for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(float); i++) { - ret[i] = (int32_t)round(Vector[i]); + tmp_ret[i] = (int32_t)round(s0[i]); } + memcpy(&ret, &tmp_ret, sizeof(GI_INT32_t)); return ret; #endif } @@ -139,7 +143,10 @@ GI_FLOAT32_t GiLoadFloat32(const float* Buffer) { #if defined(GI_NEON_INTRINSICS) return vld1q_f32(Buffer); #elif defined(GI_SSE2_INTRINSICS) - return _mm_loadu_ps(Buffer); + if ((((uintptr_t)(Buffer)) & 15) == 0) + return _mm_load_ps(Buffer); + else + return _mm_loadu_ps(Buffer); #else GI_FLOAT32_t ret; for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(float); i++) { @@ -150,6 +157,356 @@ GI_FLOAT32_t GiLoadFloat32(const float* Buffer) { } GI_FORCEINLINE +GI_FLOAT32_t GiLoadFloat32LowHalf(const float* Buffer) { +#if defined(GI_NEON_INTRINSICS) + return vcombine_f32(vld1_f32(Buffer), vdup_n_f32(0.f)); +#elif defined(GI_SSE2_INTRINSICS) + typedef __m64_128 float32x2_t; + float32x2_t low, high; + low.m64_f32[0] = Buffer[0]; + low.m64_f32[1] = Buffer[1]; + high.m64_f32[0] = 0; + high.m64_f32[1] = 0; + __m128i res = _mm_unpacklo_epi64(_pM128i(low), _pM128i(high)); + return _M128(res); +#else + GI_FLOAT32_t ret; + memset(&ret, 0, sizeof(GI_FLOAT32_t)); + for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(float) / 2; i++) { + ret[i] = Buffer[i]; + } + return ret; +#endif +} + +GI_FORCEINLINE +GI_FLOAT32_t GiMlaqFloat32(GI_FLOAT32_t a, GI_FLOAT32_t b, GI_FLOAT32_t c) { +#if defined(GI_NEON_INTRINSICS) +#if defined(__ARM_FEATURE_FMA) + return vfmaq_f32(a, b, c); +#else + return vmlaq_f32(a, b, c); +#endif +#elif defined(GI_SSE2_INTRINSICS) + // fma is coming soon, but right now: + __m128 res; + res = _mm_mul_ps(c, b); + return _mm_add_ps(a, res); +#else + GI_FLOAT32_t ret; + for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(float); i++) { + ret[i] = a[i] + (b[i] * c[i]); + } + return ret; +#endif +} + +GI_FORCEINLINE GI_FLOAT32_V2_t GiUzpqFloat32(GI_FLOAT32_t a, GI_FLOAT32_t b) { +#if defined(GI_NEON_INTRINSICS) + return vuzpq_f32(a, b); +#elif defined(GI_SSE2_INTRINSICS) + GI_FLOAT32_V2_t v32x4; + v32x4.val[0] = _mm_shuffle_ps(a, b, _MM_SHUFFLE(2, 0, 2, 0)); + v32x4.val[1] = _mm_shuffle_ps(a, b, _MM_SHUFFLE(3, 1, 3, 1)); + return v32x4; +#else + GI_FLOAT32_V2_t ret; + ret.val[0][0] = a[0]; + ret.val[0][1] = a[2]; + ret.val[0][2] = b[0]; + ret.val[0][3] = b[2]; + ret.val[1][0] = a[1]; + ret.val[1][1] = a[3]; + ret.val[1][2] = b[1]; + ret.val[1][3] = b[3]; + return ret; +#endif +} + +GI_FORCEINLINE float32x2_t GiDupFloat32(float a) { +#if defined(GI_NEON_INTRINSICS) + return vdup_n_f32(a); +#elif defined(GI_SSE2_INTRINSICS) + 
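/*
 * Context note (illustration, not part of this patch): SSE has no 64-bit
 * "D register", so on the SSE2 path float32x2_t is the __m64_128 union from
 * gi_common.h and its two lanes are written through the m64_f32[] member, as
 * done just below. A hedged usage sketch of the surrounding helpers, assuming
 * the GI API declared in this header (buf is a hypothetical float* with room
 * for two lanes):
 *
 *     float32x2_t d = GiDupFloat32(3.0f);   // both lanes hold 3.0f
 *     float hi = GiGetLaneFloat32(d, 1);    // reads lane 1 (m64_f32[1] here)
 *     GiSt1Float32(buf, d);                 // stores both lanes to buf[0..1]
 */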
float32x2_t res; + res.m64_f32[0] = a; + res.m64_f32[1] = a; + return res; +#else + float32x2_t res; + res[0] = a; + res[1] = a; + return res; +#endif +} + +GI_FORCEINLINE float32x2_t GiLdFloat32(float const* ptr) { +#if defined(GI_NEON_INTRINSICS) + return vld1_f32(ptr); +#elif defined(GI_SSE2_INTRINSICS) + float32x2_t res; + res.m64_f32[0] = *(ptr); + res.m64_f32[1] = *(ptr + 1); + return res; +#else + float32x2_t res; + res[0] = *(ptr); + res[1] = *(ptr + 1); + return res; +#endif +} + +GI_FORCEINLINE float32x2_t GiAddDFloat32(float32x2_t a, float32x2_t b) { +#if defined(GI_NEON_INTRINSICS) + return vadd_f32(a, b); +#elif defined(GI_SSE2_INTRINSICS) + __m128 res; + __m64_128 res64; + res = _mm_add_ps(_pM128(a), _pM128(b)); // SSE, use only low 64 bits + _M64f(res64, res); + return res64; +#else + float32x2_t res; + res[0] = a[0] + b[0]; + res[1] = a[1] + b[1]; + return res; +#endif +} + +#if defined(GI_NEON_INTRINSICS) +#define GiGetLaneFloat32(v, lane) vget_lane_f32(v, lane) +#else +GI_FORCEINLINE float __gi_vget_lane_f32(float32x2_t v, const int lane) { +#if defined(GI_SSE2_INTRINSICS) + return _sse_vget_lane_f32(v, lane); +#else + return v[lane]; +#endif +} +#define GiGetLaneFloat32(v, lane) __gi_vget_lane_f32(v, lane) +#endif + +#if defined(GI_NEON_INTRINSICS) +#define GiSetLaneFloat32(value, vec, lane) vset_lane_f32(value, vec, lane) +#else +GI_FORCEINLINE float32x2_t +__gi_vset_lane_f32(float32_t value, float32x2_t vec, int lane) { +#if defined(GI_SSE2_INTRINSICS) + float32x2_t res; + res = vec; + res.m64_f32[lane] = value; + return res; +#else + float32x2_t res; + res = vec; + res[lane] = value; + return res; +#endif +} +#define GiSetLaneFloat32(value, vec, lane) __gi_vset_lane_f32(value, vec, lane) +#endif + +GI_FORCEINLINE void GiSt1Float32(float* ptr, float32x2_t val) { +#if defined(GI_NEON_INTRINSICS) + return vst1_f32(ptr, val); +#elif defined(GI_SSE2_INTRINSICS) + *(ptr) = val.m64_f32[0]; + *(ptr + 1) = val.m64_f32[1]; + return; +#else + *(ptr) = val[0]; + *(ptr + 1) = val[1]; + return; +#endif +} + +GI_FORCEINLINE GI_FLOAT32_V2_t GiLd2qFloat32(const float* Buffer) { +#if defined(GI_NEON_INTRINSICS) + return vld2q_f32(Buffer); +#elif defined(GI_SSE2_INTRINSICS) + GI_FLOAT32_V2_t v; + v.val[0] = GiLoadFloat32(Buffer); + v.val[1] = GiLoadFloat32((Buffer + 4)); + v = GiUzpqFloat32(v.val[0], v.val[1]); + return v; +#else + GI_FLOAT32_V2_t ret; + ret.val[0][0] = Buffer[0]; + ret.val[0][1] = Buffer[2]; + ret.val[0][2] = Buffer[4]; + ret.val[0][3] = Buffer[6]; + ret.val[1][0] = Buffer[1]; + ret.val[1][1] = Buffer[3]; + ret.val[1][2] = Buffer[5]; + ret.val[1][3] = Buffer[7]; + return ret; +#endif +} + +#if defined(GI_NEON_INTRINSICS) +#define GiExtqFloat32(a, b, n) vextq_f32(a, b, n) +#elif defined(GI_SSE2_INTRINSICS) +#define GiExtqFloat32(a, b, n) _M128(_sse_vextq_s32(_M128i(a), _M128i(b), n)); +#else +GI_FORCEINLINE GI_FLOAT32_t +__naive_gi_vextq_f32(GI_FLOAT32_t a, GI_FLOAT32_t b, const int n) { + GI_FLOAT32_t ret; + int t_count = GI_SIMD_LEN_BYTE / sizeof(float); + int a_count = t_count - n; + for (int i = 0; i < a_count; i++) { + ret[i] = a[i + n]; + } + for (int i = 0; i < n; i++) { + ret[i + a_count] = b[i]; + } + return ret; +} +#define GiExtqFloat32(a, b, n) __naive_gi_vextq_f32(a, b, n) +#endif + +GI_FORCEINLINE +GI_FLOAT32_t GiMultiplySubFloat32( + GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { +#if defined(GI_NEON_INTRINSICS) + return vmlsq_f32(VectorSum, Vector1, Vector2); +#elif defined(GI_SSE2_INTRINSICS) + return _mm_sub_ps(VectorSum, 
_mm_mul_ps(Vector1, Vector2)); +#else + GI_FLOAT32_t ret; + for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(float); i++) { + ret[i] = VectorSum[i] - Vector1[i] * Vector2[i]; + } + + return ret; +#endif +} + +#if defined(GI_SSE2_INTRINSICS) +GI_FORCEINLINE GI_FLOAT32_t +_MM_INSERT_PS(GI_FLOAT32_t vec, GI_FLOAT32_t p, const int LANE) { + _GI_ALIGN_16 uint32_t mask[4] = {0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff}; + __m128 tmp, vec_masked, p_masked; + mask[LANE >> 4] = 0x0; + vec_masked = _mm_and_ps(*(__m128*)mask, vec); + p_masked = _mm_andnot_ps(*(__m128*)mask, p); + tmp = _mm_or_ps(vec_masked, p_masked); + return tmp; +} + +GI_FORCEINLINE float32x2_t sse_vget_high_f32(GI_FLOAT32_t a) { + __m128i res; + __m64_128 res64; + res = _mm_unpackhi_epi64(_M128i(a), _M128i(a)); + return64(res); +} + +GI_FORCEINLINE float32x2_t sse_vget_low_f32(GI_FLOAT32_t a) { + float32x2_t res64; + _M64f(res64, a); + return res64; +} + +GI_FORCEINLINE GI_FLOAT32_t +sse_vmlaq_lane_f32(GI_FLOAT32_t a, GI_FLOAT32_t b, float32x2_t v, int l) { + float32_t vlane; + GI_FLOAT32_t c; + vlane = _sse_vget_lane_f32(v, l); + c = _mm_set1_ps(vlane); + return GiMlaqFloat32(a, b, c); +} + +GI_FORCEINLINE int _MM_EXTRACT_PS(__m128 vec, const int LANE) { + _GI_ALIGN_16 int32_t tmp[4]; + _mm_store_si128((__m128i*)tmp, _M128i(vec)); + return tmp[LANE]; +} + +GI_FORCEINLINE float32_t sse_vgetq_lane_f32(GI_FLOAT32_t vec, int lane) { + float32_t floatVal; + char* const floatVal_c = (char*)&floatVal; + *((int32_t*)floatVal_c) = _MM_EXTRACT_PS(vec, lane); + return floatVal; +} + +GI_FORCEINLINE GI_FLOAT32_t +sse_vmlsq_lane_f32(GI_FLOAT32_t a, GI_FLOAT32_t b, float32x2_t v, int l) { + float32_t vlane; + GI_FLOAT32_t c; + vlane = (float)GiGetLaneFloat32(v, l); + c = GiBroadcastFloat32(vlane); + return GiMultiplySubFloat32(a, b, c); +} + +#endif + +#if defined(GI_NEON_INTRINSICS) +#define GiLd1qLaneFloat32(Buffer, src, n) vld1q_lane_f32(Buffer, src, n) +#else +GI_FORCEINLINE GI_FLOAT32_t +__gi_vld1q_lane_f32(const float* Buffer, GI_FLOAT32_t src, const int n) { +#if defined(GI_SSE2_INTRINSICS) + GI_FLOAT32_t p; + p = _mm_set1_ps(*(Buffer)); + return _MM_INSERT_PS(src, p, _INSERTPS_NDX(0, n)); +#else + GI_FLOAT32_t ret; + memcpy(&ret, &src, sizeof(GI_FLOAT32_t)); + ret[n] = *Buffer; + return ret; +#endif +} +#define GiLd1qLaneFloat32(Buffer, src, n) __gi_vld1q_lane_f32(Buffer, src, n) +#endif + +#if defined(GI_NEON_INTRINSICS) +#define GiSetqLaneFloat32(value, vec, lane) vsetq_lane_f32(value, vec, lane) +#else +GI_FORCEINLINE GI_FLOAT32_t +__gi_vsetq_lane_f32(float value, GI_FLOAT32_t vec, const int lane) { + float val = value; + return GiLd1qLaneFloat32(&val, vec, lane); +} +#define GiSetqLaneFloat32(value, vec, lane) __gi_vsetq_lane_f32(value, vec, lane) +#endif + +#if defined(GI_NEON_INTRINSICS) +#define GiMlaqLaneFloat32HighHalf(a, b, v, lane) \ + vmlaq_lane_f32(a, b, vget_high_f32(v), lane) +#elif defined(GI_SSE2_INTRINSICS) +#define GiMlaqLaneFloat32HighHalf(a, b, v, lane) \ + sse_vmlaq_lane_f32(a, b, sse_vget_high_f32(v), lane) +#else +GI_FORCEINLINE GI_FLOAT32_t __naive_gi_vmlaq_lane_f32_high_half( + GI_FLOAT32_t a, GI_FLOAT32_t b, GI_FLOAT32_t v, const int lane) { + GI_FLOAT32_t ret; + for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(float); i++) { + ret[i] = a[i] + (b[i] * v[lane + 2]); + } + return ret; +} +#define GiMlaqLaneFloat32HighHalf(a, b, v, lane) \ + __naive_gi_vmlaq_lane_f32_high_half(a, b, v, lane) +#endif + +#if defined(GI_NEON_INTRINSICS) +#define GiVmlaqLaneFloat32LowHalf(a, b, v, lane) \ + vmlaq_lane_f32(a, b, 
vget_low_f32(v), lane) +#elif defined(GI_SSE2_INTRINSICS) +#define GiVmlaqLaneFloat32LowHalf(a, b, v, lane) \ + sse_vmlaq_lane_f32(a, b, sse_vget_low_f32(v), lane) +#else +GI_FORCEINLINE GI_FLOAT32_t __naive_gi_vmlaq_lane_f32_low_half( + GI_FLOAT32_t a, GI_FLOAT32_t b, GI_FLOAT32_t v, const int lane) { + GI_FLOAT32_t ret; + for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(float); i++) { + ret[i] = a[i] + (b[i] * v[lane]); + } + return ret; +} +#define GiVmlaqLaneFloat32LowHalf(a, b, v, lane) \ + __naive_gi_vmlaq_lane_f32_low_half(a, b, v, lane) +#endif + +GI_FORCEINLINE void GiStoreFloat32(float* Buffer, GI_FLOAT32_t Vector) { #if defined(GI_NEON_INTRINSICS) vst1q_f32(Buffer, Vector); @@ -213,6 +570,29 @@ GIEXTRACTLANEFLOAT32(3) #undef GIEXTRACTLANEFLOAT32 GI_FORCEINLINE +GI_FLOAT32_V2_t GiZipqFloat32(GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { +#if defined(GI_NEON_INTRINSICS) + return vzipq_f32(Vector1, Vector2); +#elif defined(GI_SSE2_INTRINSICS) + GI_FLOAT32_V2_t f32x4; + f32x4.val[0] = _mm_unpacklo_ps(Vector1, Vector2); + f32x4.val[1] = _mm_unpackhi_ps(Vector1, Vector2); + return f32x4; +#else + GI_FLOAT32_V2_t ret; + ret.val[0][0] = Vector1[0]; + ret.val[0][1] = Vector2[0]; + ret.val[0][2] = Vector1[1]; + ret.val[0][3] = Vector2[1]; + ret.val[1][0] = Vector1[2]; + ret.val[1][1] = Vector2[2]; + ret.val[1][2] = Vector1[3]; + ret.val[1][3] = Vector2[3]; + return ret; +#endif +} + +GI_FORCEINLINE GI_FLOAT32_t GiInterleaveLowFloat32(GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { #if defined(GI_NEON64_INTRINSICS) return vzip1q_f32(Vector1, Vector2); @@ -243,8 +623,8 @@ GI_FLOAT32_t GiInterleaveHighFloat32(GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) #else GI_FLOAT32_t ret; for (size_t i = 0; i < GI_SIMD_LEN_BYTE / 2 / sizeof(float); i++) { - ret[2 * i] = Vector1[GI_SIMD_LEN_BYTE / 2 + i]; - ret[2 * i + 1] = Vector2[GI_SIMD_LEN_BYTE / 2 + i]; + ret[2 * i] = Vector1[GI_SIMD_LEN_BYTE / 2 / sizeof(float) + i]; + ret[2 * i + 1] = Vector2[GI_SIMD_LEN_BYTE / 2 / sizeof(float) + i]; } return ret; #endif @@ -310,18 +690,6 @@ GI_FLOAT32_t GiMultiplyAddFloat32( } GI_FORCEINLINE -GI_FLOAT32_t GiMultiplySubFloat32( - GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { -#if defined(GI_NEON_INTRINSICS) - return vmlsq_f32(VectorSum, Vector1, Vector2); -#elif defined(GI_SSE2_INTRINSICS) - return _mm_sub_ps(VectorSum, _mm_mul_ps(Vector1, Vector2)); -#else - return VectorSum - Vector1 * Vector2; -#endif -} - -GI_FORCEINLINE GI_FLOAT32_t GiMultiplyAddScalarFloat32( GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector, float Scalar) { #if defined(GI_NEON_INTRINSICS) @@ -350,7 +718,7 @@ GIMULTIPLYADDLANFLOAT32(1) GIMULTIPLYADDLANFLOAT32(2) GIMULTIPLYADDLANFLOAT32(3) #undef GIMULTIPLYADDLANFLOAT32 -#elif defined(GI_SSE2_INTRINSICS) +#else #define GIMULTIPLYADDLANFLOAT32(i) \ GI_FORCEINLINE GI_FLOAT32_t GiMultiplyAddLan##i##Float32( \ @@ -363,17 +731,6 @@ GIMULTIPLYADDLANFLOAT32(1) GIMULTIPLYADDLANFLOAT32(2) GIMULTIPLYADDLANFLOAT32(3) #undef GIMULTIPLYADDLANFLOAT32 -#else -#define GIMULTIPLYADDLANFLOAT32(i) \ - GI_FORCEINLINE GI_FLOAT32_t GiMultiplyAddLan##i##Float32( \ - GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { \ - return VectorSum + Vector1 * Vector2[i]; \ - } -GIMULTIPLYADDLANFLOAT32(0) -GIMULTIPLYADDLANFLOAT32(1) -GIMULTIPLYADDLANFLOAT32(2) -GIMULTIPLYADDLANFLOAT32(3) -#undef GIMULTIPLYADDLANFLOAT32 #endif GI_FORCEINLINE @@ -411,6 +768,7 @@ GI_FLOAT32_t GiRecpeFloat32(GI_FLOAT32_t Vector) { GI_FLOAT32_t ones = _mm_set1_ps(1.0f); return _mm_div_ps(ones, Vector); #else + //! 
FIXME: neon or sse always have low accuracy than 1/x return 1 / Vector; #endif } @@ -516,7 +874,7 @@ GI_FORCEINLINE GI_FLOAT32_t GiBlendFloat32( GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2, GI_FLOAT32_t Selection) { return GiOrFloat32( - GiAndFloat32(Vector2, Selection), GiAndNotFloat32(Selection, Vector1)); + GiAndFloat32(Vector1, Selection), GiAndNotFloat32(Selection, Vector2)); } #define MIN_NAN(a, b) (isnan(a) || (a) < (b)) ? (a) : (b); @@ -569,7 +927,6 @@ GI_FLOAT32_t GiMaxNanFloat32(GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { #else //! _mm_max_ps does not fellow the IEEE standard when input is NAN, so //! implement by C code -#define MAX_NAN(a, b) (isnan(a) || (a) > (b)) ? (a) : (b); GI_FLOAT32_t max; for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(float); i++) { max[i] = MAX_NAN(Vector1[i], Vector2[i]); @@ -585,7 +942,6 @@ GI_FLOAT32_t GiMinNanFloat32(GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { #else //! _mm_min_ps does not fellow the IEEE standard when input is NAN, so //! implement by C code -#define MIN_NAN(a, b) (isnan(a) || (a) < (b)) ? (a) : (b); GI_FLOAT32_t min; for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(float); i++) { min[i] = MIN_NAN(Vector1[i], Vector2[i]); @@ -723,4 +1079,249 @@ GI_FLOAT32_t GiAbsFloat32(GI_FLOAT32_t Vector1) { #endif } -// vim: syntax=cpp.doxygen +#if defined(GI_SSE2_INTRINSICS) +typedef __m128i int8x16_t; +typedef __m64_128 int8x8_t; +GI_FORCEINLINE int8x16_t vcombine_s8(int8x8_t low, int8x8_t high) { + return _mm_unpacklo_epi64(_pM128i(low), _pM128i(high)); +} + +typedef __m64_128 int64x1_t; +GI_FORCEINLINE int64x1_t vget_low_s64(GI_INT64_t a) { + int64x1_t res64; + return64(a); +} +GI_FORCEINLINE int64x1_t vget_high_s64(GI_INT64_t a) { + int64x1_t res64; + __m128i res; + res = _mm_unpackhi_epi64(a, a); + return64(res); +} +#endif + +GI_FORCEINLINE GI_INT64_t GiZip1qS64(GI_INT64_t __p0, GI_INT64_t __p1) { +#if defined(GI_NEON_INTRINSICS) + return vzip1q_s64(__p0, __p1); +#elif defined(GI_SSE2_INTRINSICS) +#define vcombine_s64 vcombine_s8 + return vcombine_s64(vget_low_s64(__p0), vget_low_s64(__p1)); +#else + GI_INT64_t ret; + ret[0] = __p0[0]; + ret[1] = __p1[0]; + return ret; +#endif +} + +GI_FORCEINLINE GI_INT64_t GiZip2qS64(GI_INT64_t __p0, GI_INT64_t __p1) { +#if defined(GI_NEON_INTRINSICS) + return vzip2q_s64(__p0, __p1); +#elif defined(GI_SSE2_INTRINSICS) +#define vcombine_s64 vcombine_s8 + return vcombine_s64(vget_high_s64(__p0), vget_high_s64(__p1)); +#else + GI_INT64_t ret; + ret[0] = __p0[1]; + ret[1] = __p1[1]; + return ret; +#endif +} + +GI_FORCEINLINE GI_FLOAT32_t GiReinterpretqS64ToFloat32(GI_INT64_t a) { +#if defined(GI_NEON_INTRINSICS) + return vreinterpretq_f32_s64(a); +#elif defined(GI_SSE2_INTRINSICS) + return _M128(a); +#else + GI_FLOAT32_t ret; + memcpy(&ret, &a, sizeof(GI_FLOAT32_t)); + return ret; +#endif +} + +GI_FORCEINLINE GI_INT64_t GiReinterpretqFloat32ToS64(GI_FLOAT32_t a) { +#if defined(GI_NEON_INTRINSICS) + return vreinterpretq_s64_f32(a); +#elif defined(GI_SSE2_INTRINSICS) + return _M128i(a); +#else + GI_INT64_t ret; + memcpy(&ret, &a, sizeof(GI_INT64_t)); + return ret; +#endif +} + +#if defined(GI_NEON_INTRINSICS) +#define GiSimdFmaLane(a, b, c, d) vfmaq_laneq_f32(a, b, c, d) +#else +GI_FORCEINLINE GI_FLOAT32_t +___gi_vmlaq_lane_f32(GI_FLOAT32_t a, GI_FLOAT32_t b, float32x2_t v, int l) { + float vlane; + GI_FLOAT32_t c; + vlane = (float)GiGetLaneFloat32(v, l); + c = GiBroadcastFloat32(vlane); + return GiMlaqFloat32(a, b, c); +} +GI_FORCEINLINE float32x2_t ___gi_vget_low_f32(GI_FLOAT32_t a) { +#if 
defined(GI_SSE2_INTRINSICS) + float32x2_t res64; + _M64f(res64, a); + return res64; +#else + float32x2_t ret; + ret[0] = a[0]; + ret[1] = a[1]; + return ret; +#endif +} +GI_FORCEINLINE float32x2_t ___gi_vget_high_f32(GI_FLOAT32_t a) { +#if defined(GI_SSE2_INTRINSICS) + __m128i res; + __m64_128 res64; + res = _mm_unpackhi_epi64(_M128i(a), _M128i(a)); + return64(res); +#else + float32x2_t ret; + ret[0] = a[2]; + ret[1] = a[3]; + return ret; +#endif +} +GI_FORCEINLINE GI_FLOAT32_t +___gi_vfmaq_laneq_f32(GI_FLOAT32_t a, GI_FLOAT32_t b, GI_FLOAT32_t v, int l) { + if (l < 2) { + return ___gi_vmlaq_lane_f32(a, b, ___gi_vget_low_f32(v), l); + } else { + return ___gi_vmlaq_lane_f32(a, b, ___gi_vget_high_f32(v), l - 2); + } +} +#define GiSimdFmaLane(a, b, c, d) ___gi_vfmaq_laneq_f32(a, b, c, d) +#endif + +#if defined(GI_NEON_INTRINSICS) +#if MEGDNN_AARCH64 +#define GiMlaqLowLaneFloat32(__a, __b, __v, __lane) \ + vmlaq_laneq_f32(__a, __b, __v, __lane) + +#define GiMlaqHighLaneFloat32(__a, __b, __v, __lane) \ + vmlaq_laneq_f32(__a, __b, __v, __lane) + +#else +#define GiMlaqLowLaneFloat32(__a, __b, __v, __lane) \ + __extension__({ \ + float32x2_t c = vget_low_f32(__v); \ + GI_FLOAT32_t __ret = vmlaq_lane_f32(__a, __b, c, __lane); \ + __ret; \ + }) + +#define GiMlaqHighLaneFloat32(__a, __b, __v, __lane) \ + __extension__({ \ + float32x2_t c = vget_high_f32(__v); \ + GI_FLOAT32_t __ret = vmlaq_lane_f32(__a, __b, c, (__lane - 2)); \ + __ret; \ + }) + +#endif + +#elif defined(GI_SSE2_INTRINSICS) +#define GiMlaqLowLaneFloat32(__a, __b, __v, __lane) \ + __extension__({ \ + float32x2_t c = sse_vget_low_f32(__v); \ + GI_FLOAT32_t __ret = sse_vmlaq_lane_f32(__a, __b, c, __lane); \ + __ret; \ + }) + +#define GiMlaqHighLaneFloat32(__a, __b, __v, __lane) \ + __extension__({ \ + float32x2_t c = sse_vget_high_f32(__v); \ + GI_FLOAT32_t __ret = sse_vmlaq_lane_f32(__a, __b, c, (__lane - 2)); \ + __ret; \ + }) + +#else +//! naive +#define GiMlaqLowLaneFloat32(__a, __b, __v, __lane) \ + __extension__({ \ + GI_FLOAT32_t __ret; \ + for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(float); i++) { \ + __ret[i] = __a[i] + (__b[i] * __v[__lane]); \ + } \ + __ret; \ + }) + +#define GiMlaqHighLaneFloat32(__a, __b, __v, __lane) \ + __extension__({ \ + GI_FLOAT32_t __ret; \ + for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(float); i++) { \ + __ret[i] = __a[i] + (__b[i] * __v[__lane]); \ + } \ + __ret; \ + }) +#endif + +#if defined(GI_NEON_INTRINSICS) +#define GiFmsqLaneQFloat32(a, b, v, lane) vfmsq_laneq_f32(a, b, v, lane) +#elif defined(GI_SSE2_INTRINSICS) +#define SSE_VFMSQ_LANEQ_F32(lane) \ + GI_FORCEINLINE GI_FLOAT32_t sse_vfmsq_lane_##lane##_q_f32( \ + GI_FLOAT32_t a, GI_FLOAT32_t b, GI_FLOAT32_t v) { \ + return sse_vmlsq_lane_f32(a, b, sse_vget_low_f32(v), lane); \ + } +SSE_VFMSQ_LANEQ_F32(0) +SSE_VFMSQ_LANEQ_F32(1) +#undef SSE_VFMSQ_LANEQ_F32 +#define SSE_VFMSQ_LANEQ_F32(lane) \ + GI_FORCEINLINE GI_FLOAT32_t sse_vfmsq_lane_##lane##_q_f32( \ + GI_FLOAT32_t a, GI_FLOAT32_t b, GI_FLOAT32_t v) { \ + return sse_vmlsq_lane_f32(a, b, sse_vget_high_f32(v), lane - 2); \ + } +SSE_VFMSQ_LANEQ_F32(2) +SSE_VFMSQ_LANEQ_F32(3) +#undef SSE_VFMSQ_LANEQ_F32 +#define GiFmsqLaneQFloat32(a, b, v, lane) sse_vfmsq_lane_##lane##_q_f32(a, b, v) +#else +//! 
naive +GI_FORCEINLINE GI_FLOAT32_t __naive_GiFmsqLaneQFloat32( + GI_FLOAT32_t a, GI_FLOAT32_t b, GI_FLOAT32_t v, const int lane) { + GI_FLOAT32_t ret; + for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(float); i++) { + ret[i] = a[i] - (b[i] * v[lane]); + } + + return ret; +} +#define GiFmsqLaneQFloat32(a, b, v, lane) __naive_GiFmsqLaneQFloat32(a, b, v, lane) +#endif + +GI_FORCEINLINE GI_FLOAT32_t GiCombineFloat32(float32x2_t a, float32x2_t b) { +#if defined(GI_NEON_INTRINSICS) + return vcombine_f32(a, b); +#elif defined(GI_SSE2_INTRINSICS) + __m128i res; + res = _mm_unpacklo_epi64(_pM128i(a), _pM128i(b)); + return _M128(res); +#else + GI_FLOAT32_t res; + res[0] = a[0]; + res[1] = a[1]; + res[2] = b[0]; + res[3] = b[1]; + return res; +#endif +} + +GI_FORCEINLINE float32x2_t GiGetLowFloat32(GI_FLOAT32_t a) { +#if defined(GI_NEON_INTRINSICS) + return vget_low_f32(a); +#else + return ___gi_vget_low_f32(a); +#endif +} + +GI_FORCEINLINE float32x2_t GiGetHighFloat32(GI_FLOAT32_t a) { +#if defined(GI_NEON_INTRINSICS) + return vget_high_f32(a); +#else + return ___gi_vget_high_f32(a); +#endif +} diff --git a/dnn/src/fallback/general_intrinsic/gi_int.h b/dnn/src/fallback/general_intrinsic/gi_int.h index 97181862..3808640d 100644 --- a/dnn/src/fallback/general_intrinsic/gi_int.h +++ b/dnn/src/fallback/general_intrinsic/gi_int.h @@ -214,8 +214,12 @@ GI_UINT32_t GiTestAndSetUint32(GI_UINT32_t Vector1, GI_UINT32_t Vector2) { #if defined(GI_NEON_INTRINSICS) return vtstq_u32(Vector1, Vector2); #elif defined(GI_SSE2_INTRINSICS) - GI_UINT32_t tmp = _mm_and_si128(Vector1, Vector2); - return _mm_cmpeq_epi32(tmp, _mm_setzero_si128()); + __m128i zero, one, res; + zero = _mm_setzero_si128(); + one = _mm_cmpeq_epi8(zero, zero); + res = _mm_and_si128(Vector1, Vector2); + res = _mm_cmpeq_epi32(res, zero); + return _mm_xor_si128(res, one); #else GI_UINT32_t ret; for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int32_t); i++) { @@ -451,9 +455,15 @@ GI_INT32_t GiAbsInt32(GI_INT32_t Vector) { return _mm_abs_epi32(Vector); #else GI_INT32_t ret; + GI_INT32_NAIVE_t tmp_ret; + GI_INT32_NAIVE_t s0; + + memcpy(&s0, &Vector, sizeof(GI_INT32_t)); for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int32_t); i++) { - ret[i] = Vector[i] > 0 ? Vector[i] : -Vector[i]; + tmp_ret[i] = s0[i] > 0 ? s0[i] : -s0[i]; } + + memcpy(&ret, &tmp_ret, sizeof(GI_INT32_t)); return ret; #endif } @@ -466,9 +476,14 @@ GI_INT16_t GiAbsInt16(GI_INT16_t Vector) { return _mm_abs_epi16(Vector); #else GI_INT16_t ret; + GI_INT16_NAIVE_t tmp_ret; + GI_INT16_NAIVE_t s0; + + memcpy(&s0, &Vector, sizeof(GI_INT16_t)); for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int16_t); i++) { - ret[i] = Vector[i] > 0 ? Vector[i] : -Vector[i]; + tmp_ret[i] = s0[i] > 0 ? s0[i] : -s0[i]; } + memcpy(&ret, &tmp_ret, sizeof(GI_INT16_t)); return ret; #endif } @@ -481,9 +496,14 @@ GI_INT8_t GiAbsInt8(GI_INT8_t Vector) { return _mm_abs_epi8(Vector); #else GI_INT8_t ret; + GI_INT8_NAIVE_t tmp_ret; + GI_INT8_NAIVE_t s0; + + memcpy(&s0, &Vector, sizeof(GI_INT8_t)); for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int8_t); i++) { - ret[i] = Vector[i] > 0 ? Vector[i] : -Vector[i]; + tmp_ret[i] = s0[i] > 0 ? 
s0[i] : -s0[i]; } + memcpy(&ret, &tmp_ret, sizeof(GI_INT8_t)); return ret; #endif } @@ -497,7 +517,11 @@ GI_INT32_t GiMaximumInt32(GI_INT32_t Vector1, GI_INT32_t Vector2) { #elif defined(GI_SSE2_INTRINSICS) return GiBlendInt32(Vector2, Vector1, _mm_cmpgt_epi32(Vector1, Vector2)); #else - return GiBlendInt32(Vector2, Vector1, Vector1 > Vector2); + GI_INT32_t tmp; + for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int32_t); i++) { + tmp[i] = Vector1[i] > Vector2[i] ? 0xFFFFFFFF : 0; + } + return GiBlendInt32(Vector2, Vector1, tmp); #endif } @@ -510,7 +534,11 @@ GI_INT32_t GiMinimumInt32(GI_INT32_t Vector1, GI_INT32_t Vector2) { #elif defined(GI_SSE2_INTRINSICS) return GiBlendInt32(Vector2, Vector1, _mm_cmpgt_epi32(Vector2, Vector1)); #else - return GiBlendInt32(Vector2, Vector1, Vector2 > Vector1); + GI_INT32_t tmp; + for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int32_t); i++) { + tmp[i] = Vector2[i] > Vector1[i] ? 0xFFFFFFFF : 0; + } + return GiBlendInt32(Vector2, Vector1, tmp); #endif } @@ -528,7 +556,11 @@ GI_INT8_t GiMaximumInt8(GI_INT8_t Vector1, GI_INT8_t Vector2) { #elif defined(GI_SSE2_INTRINSICS) return GiBlendInt8(Vector2, Vector1, _mm_cmpgt_epi8(Vector1, Vector2)); #else - return GiBlendInt8(Vector2, Vector1, Vector1 > Vector2); + GI_INT8_t tmp; + for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int8_t); i++) { + tmp[i] = Vector1[i] > Vector2[i] ? 0xFF : 0; + } + return GiBlendInt8(Vector2, Vector1, tmp); #endif } @@ -541,7 +573,11 @@ GI_INT8_t GiMinimumInt8(GI_INT8_t Vector1, GI_INT8_t Vector2) { #elif defined(GI_SSE2_INTRINSICS) return GiBlendInt8(Vector2, Vector1, _mm_cmpgt_epi8(Vector2, Vector1)); #else - return GiBlendInt8(Vector2, Vector1, Vector2 > Vector1); + GI_INT8_t tmp; + for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int8_t); i++) { + tmp[i] = Vector2[i] > Vector1[i] ? 
0xFF : 0; + } + return GiBlendInt8(Vector2, Vector1, tmp); #endif } @@ -813,14 +849,18 @@ GI_INT8_t GiCvtFromFloat32ToInt8(GI_FLOAT32_t src) { return vepi8; #else GI_INT8_t ret; + GI_INT8_NAIVE_t tmp_ret; + GI_FLOAT32_NAIVE_t s0; + memcpy(&s0, &src, sizeof(GI_INT32_t)); int length = GI_SIMD_LEN_BYTE / sizeof(float); for (int i = 0; i < length; i++) { - int8_t data = Saturate(round(src[i]), -128, 127); - ret[i] = data; - ret[length + i] = data; - ret[2 * length + i] = data; - ret[3 * length + i] = data; + int8_t data = Saturate(round(s0[i]), -128, 127); + tmp_ret[i] = data; + tmp_ret[length + i] = data; + tmp_ret[2 * length + i] = data; + tmp_ret[3 * length + i] = data; } + memcpy(&ret, &tmp_ret, sizeof(GI_INT8_t)); return ret; #endif } @@ -861,10 +901,16 @@ GI_INT8_t GiCvtFromFloat32V2ToInt8(GI_FLOAT32_V2_t vsrc) { return vepi8; #else GI_INT8_t ret; + GI_INT8_NAIVE_t tmp_ret; + GI_FLOAT32_V2_NAIVE_t s0; + memcpy(&s0, &vsrc, sizeof(GI_FLOAT32_V2_NAIVE_t)); int length = GI_SIMD_LEN_BYTE / sizeof(float); for (int i = 0; i < 2 * length; i++) { - ret[i] = Saturate(round(vsrc.val[i / length][i % length]), -128, 127); + int8_t data = Saturate(round(s0.val[i / length][i % length]), -128, 127); + tmp_ret[i] = data; + tmp_ret[i + length * 2] = data; } + memcpy(&ret, &tmp_ret, sizeof(GI_INT8_t)); return ret; #endif } @@ -875,8 +921,8 @@ GI_INT8_t GiCvtFromFloat32V4ToInt8(GI_FLOAT32_V4_t vsrc) { #if __ARM_ARCH >= 8 int32x4_t vres0 = vcvtaq_s32_f32(vsrc.val[0]); int32x4_t vres1 = vcvtaq_s32_f32(vsrc.val[1]); - int32x4_t vres2 = vcvtaq_s32_f32(vsrc.val[1]); - int32x4_t vres3 = vcvtaq_s32_f32(vsrc.val[1]); + int32x4_t vres2 = vcvtaq_s32_f32(vsrc.val[2]); + int32x4_t vres3 = vcvtaq_s32_f32(vsrc.val[3]); int8x8_t mid1 = vqmovn_s16(vcombine_s16(vqmovn_s32(vres0), vqmovn_s32(vres1))); int8x8_t mid2 = vqmovn_s16(vcombine_s16(vqmovn_s32(vres2), vqmovn_s32(vres3))); return vcombine_s8(mid1, mid2); @@ -910,7 +956,7 @@ GI_INT8_t GiCvtFromFloat32V4ToInt8(GI_FLOAT32_V4_t vsrc) { vres0 = _mm_round_ps(vres0, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC); vres1 = _mm_round_ps(vres1, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC); vres2 = _mm_round_ps(vres2, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC); - vres3 = _mm_round_ps(vres1, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC); + vres3 = _mm_round_ps(vres3, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC); vres0 = _mm_min_ps(_mm_max_ps(vres0, vfmin_int8), vfmax_int8); vres1 = _mm_min_ps(_mm_max_ps(vres1, vfmin_int8), vfmax_int8); @@ -927,10 +973,14 @@ GI_INT8_t GiCvtFromFloat32V4ToInt8(GI_FLOAT32_V4_t vsrc) { return vepi8; #else GI_INT8_t ret; + GI_INT8_NAIVE_t tmp_ret; + GI_FLOAT32_V4_NAIVE_t s0; + memcpy(&s0, &vsrc, sizeof(GI_FLOAT32_V4_NAIVE_t)); int length = GI_SIMD_LEN_BYTE / sizeof(float); for (int i = 0; i < 4 * length; i++) { - ret[i] = Saturate(round(vsrc.val[i / length][i % length]), -128, 127); + tmp_ret[i] = Saturate(round(s0.val[i / length][i % length]), -128, 127); } + memcpy(&ret, &tmp_ret, sizeof(GI_INT8_t)); return ret; #endif } diff --git a/dnn/test/fallback/gi.cpp b/dnn/test/fallback/gi.cpp new file mode 100644 index 00000000..e37a872e --- /dev/null +++ b/dnn/test/fallback/gi.cpp @@ -0,0 +1,3167 @@ +/** + * \file dnn/test/fallback/gi.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2022 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include +#include "test/fallback/fixture.h" + +#include "src/fallback/general_intrinsic/gi_float.h" +#include "src/fallback/general_intrinsic/gi_int.h" + +namespace megdnn { +namespace test { + +#define SIMD_LEN GI_SIMD_LEN_BYTE / sizeof(float) +#define SIMD_LEN_16 GI_SIMD_LEN_BYTE / sizeof(int16_t) +#define SIMD_LEN_8 GI_SIMD_LEN_BYTE / sizeof(int8_t) +template +static void init( + T* dst, const std::vector& value, const size_t simd_len = SIMD_LEN) { + for (size_t i = 0; i < simd_len; i++) { + dst[i] = value[i]; + } +} + +template +static void assert_eq(T* a, const std::vector& b, const size_t simd_len = SIMD_LEN) { + for (size_t i = 0; i < simd_len; i++) { + ASSERT_EQ(a[i], b[i]); + } +} + +template +static void assert_eq_and_nan( + T* a, const std::vector& b, const size_t simd_len = SIMD_LEN) { + for (size_t i = 0; i < simd_len; i++) { + if (isnan(a[i]) && isnan(b[i])) { + continue; + } + ASSERT_EQ(a[i], b[i]); + } +} + +static void assert_lt( + float* a, const std::vector& b, const float eps, + const size_t simd_len = SIMD_LEN) { + for (size_t i = 0; i < simd_len; i++) { + ASSERT_LT(std::abs(a[i] - b[i]), eps); + } +} + +TEST_F(FALLBACK, GiGetSimdType) { + auto t = GiGetSimdType(); + auto should_type = GI_UNKNOWN; +#if defined(GI_AVX_INTRINSICS) || defined(GI_AVX2_INTRINSICS) || \ + defined(GI_FMA_INTRINSICS) + should_type = GI_AVX; +#elif defined(GI_NEON_INTRINSICS) + should_type = GI_NEON; +#elif defined(GI_SSE2_INTRINSICS) || defined(GI_SSE42_INTRINSICS) + +#if defined(GI_SSE42_INTRINSICS) + should_type = GI_SSE42; +#elif defined(GI_SSE2_INTRINSICS) + should_type = GI_SSE2; +#else + should_type = GI_UNKNOWN; +#error "code issue happened!!" 
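/*
 * Illustration only, not part of this patch: this test pins GiGetSimdType()
 * to whichever backend macro the build actually enabled, so CI can verify
 * that flag combinations such as -DGI_TEST_SSE2 or -DGI_TEST_NAIVE took
 * effect. A hedged sketch of how a caller could branch on the result (the
 * kernel function names are hypothetical):
 *
 *     switch (GiGetSimdType()) {
 *         case GI_NEON:  run_neon_kernel();     break;
 *         case GI_SSE42:
 *         case GI_SSE2:  run_sse_kernel();      break;
 *         default:       run_fallback_kernel(); break;
 *     }
 */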
+#endif + +#else + should_type = GI_NAIVE; +#endif + + printf("test GiGetSimdType: %d, should_type: %d\n", t, should_type); + + ASSERT_EQ(t, should_type); +} + +TEST_F(FALLBACK, GiAndInt32) { + GI_INT32_t src0, src1, ret; + std::vector s0{1, 2, 3, 4}; + s0.resize(SIMD_LEN); + std::vector s1{5, 6, 7, 8}; + s1.resize(SIMD_LEN); + init((int32_t*)&src0, s0); + init((int32_t*)&src1, s1); + + ret = GiAndInt32(src0, src1); + + std::vector naive; + for (size_t i = 0; i < SIMD_LEN; i++) { + naive.push_back(s0[i] & s1[i]); + } + + assert_eq((int32_t*)&ret, naive); +} + +TEST_F(FALLBACK, GiOrInt32) { + GI_INT32_t src0, src1, ret; + std::vector s0{1, 2, 3, 4}; + s0.resize(SIMD_LEN); + std::vector s1{5, 6, 7, 8}; + s1.resize(SIMD_LEN); + init((int32_t*)&src0, s0); + init((int32_t*)&src1, s1); + + ret = GiOrInt32(src0, src1); + + std::vector naive; + for (size_t i = 0; i < SIMD_LEN; i++) { + naive.push_back(s0[i] | s1[i]); + } + + assert_eq((int*)&ret, naive); +} + +TEST_F(FALLBACK, GiAndNotInt32) { + GI_INT32_t src0, src1, ret; + std::vector s0{1, 2, 3, 4}; + s0.resize(SIMD_LEN); + std::vector s1{5, 6, 7, 8}; + s1.resize(SIMD_LEN); + init((int32_t*)&src0, s0); + init((int32_t*)&src1, s1); + + ret = GiAndNotInt32(src0, src1); + + std::vector naive; + for (size_t i = 0; i < SIMD_LEN; i++) { + naive.push_back(~s0[i] & s1[i]); + } + + assert_eq((int32_t*)&ret, naive); +} + +TEST_F(FALLBACK, GiXorInt32) { + GI_INT32_t src0, src1, ret; + std::vector s0{1, 2, 3, 4}; + s0.resize(SIMD_LEN); + std::vector s1{5, 6, 7, 8}; + s1.resize(SIMD_LEN); + init((int32_t*)&src0, s0); + init((int32_t*)&src1, s1); + + ret = GiXorInt32(src0, src1); + + std::vector naive; + for (size_t i = 0; i < SIMD_LEN; i++) { + naive.push_back(s0[i] ^ s1[i]); + } + + assert_eq((int32_t*)&ret, naive); +} + +TEST_F(FALLBACK, GiBroadcastFloat32) { + GI_FLOAT32_t ret; + float b = 2022.0420; + + ret = GiBroadcastFloat32(b); + + std::vector naive; + for (size_t i = 0; i < SIMD_LEN; i++) { + naive.push_back(b); + } + + assert_eq((float*)&ret, naive); +} + +TEST_F(FALLBACK, GiBroadcastInt32) { + GI_INT32_t ret; + int32_t b = 20220420; + + ret = GiBroadcastInt32(b); + + std::vector naive; + for (size_t i = 0; i < SIMD_LEN; i++) { + naive.push_back(b); + } + + assert_eq((int32_t*)&ret, naive); +} + +TEST_F(FALLBACK, GiReinterpretAsInt32) { + GI_INT32_t ret; + GI_FLOAT32_t src0; + std::vector s0{1.0f, 2.2f, 3.4f, 4.5f}; + s0.resize(SIMD_LEN); + init((float*)&src0, s0); + + ret = GiReinterpretAsInt32(src0); + + std::vector naive; + for (size_t i = 0; i < SIMD_LEN; i++) { + int32_t tmp; + memcpy(&tmp, &s0[i], sizeof(int32_t)); + naive.push_back(tmp); + } + + assert_eq((int32_t*)&ret, naive); +} + +TEST_F(FALLBACK, GiReinterpretAsUint32) { + GI_UINT32_t ret; + GI_FLOAT32_t src0; + std::vector s0{1.0f, 2.2f, 3.4f, 4.5f}; + s0.resize(SIMD_LEN); + init((float*)&src0, s0); + + ret = GiReinterpretAsUint32(src0); + + std::vector naive; + for (size_t i = 0; i < SIMD_LEN; i++) { + uint32_t tmp; + memcpy(&tmp, &s0[i], sizeof(uint32_t)); + naive.push_back(tmp); + } + + assert_eq((uint32_t*)&ret, naive); +} + +TEST_F(FALLBACK, GiReintInt32ToFloat32) { + GI_FLOAT32_t ret; + GI_INT32_t src0; + std::vector s0{1, 2, 3, 4}; + s0.resize(SIMD_LEN); + init((int32_t*)&src0, s0); + + ret = GiReintInt32ToFloat32(src0); + + std::vector naive; + for (size_t i = 0; i < SIMD_LEN; i++) { + float tmp; + memcpy(&tmp, &s0[i], sizeof(float)); + naive.push_back(tmp); + } + + assert_eq((float*)&ret, naive); +} + +TEST_F(FALLBACK, GiReintUint32ToFloat32) { + GI_FLOAT32_t ret; + 
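    // Illustration of the idiom the expected values below rely on (a sketch,
    // not part of the original test): bit patterns are reinterpreted with
    // memcpy instead of pointer casts, which avoids strict-aliasing UB.
    {
        uint32_t u_bits = 0x3f800000u;  // IEEE-754 bit pattern of 1.0f
        float f_val;
        memcpy(&f_val, &u_bits, sizeof(f_val));
        ASSERT_EQ(f_val, 1.0f);  // same bits, now read as a float
    }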
GI_UINT32_t src0; + std::vector s0{1, 2, 3, 4}; + s0.resize(SIMD_LEN); + init((uint32_t*)&src0, s0); + + ret = GiReintUint32ToFloat32(src0); + + std::vector naive; + for (size_t i = 0; i < SIMD_LEN; i++) { + float tmp; + memcpy(&tmp, &s0[i], sizeof(float)); + naive.push_back(tmp); + } + + assert_eq((float*)&ret, naive); +} + +TEST_F(FALLBACK, GiRoundAsInt32) { + GI_FLOAT32_t src0; + GI_INT32_t ret; + std::vector s0{1.1f, 2.2f, 3.5f, 4.9f}; + s0.resize(SIMD_LEN); + init((float*)&src0, s0); + + ret = GiRoundAsInt32(src0); + + std::vector naive; + for (size_t i = 0; i < SIMD_LEN; i++) { + naive.push_back((int32_t)round(s0[i])); + } + + assert_eq((int*)&ret, naive); +} + +TEST_F(FALLBACK, GiCastToInt32) { + GI_FLOAT32_t src0; + GI_INT32_t ret; + std::vector s0{1.1f, 2.2f, 3.5f, 4.9f}; + s0.resize(SIMD_LEN); + init((float*)&src0, s0); + + ret = GiCastToInt32(src0); + + std::vector naive; + for (size_t i = 0; i < SIMD_LEN; i++) { + naive.push_back((int32_t)(s0[i])); + } + + assert_eq((int*)&ret, naive); +} + +TEST_F(FALLBACK, GiCastToFloat32) { + GI_INT32_t src0; + GI_FLOAT32_t ret; + std::vector s0{100, 200, 300, 400}; + s0.resize(SIMD_LEN); + init((int32_t*)&src0, s0); + + ret = GiCastToFloat32(src0); + + std::vector naive; + for (size_t i = 0; i < SIMD_LEN; i++) { + naive.push_back((float)s0[i]); + } + + assert_eq((float*)&ret, naive); +} + +TEST_F(FALLBACK, GiLoadBroadcastFloat32) { + GI_FLOAT32_t ret; + float p = 2022.0420; + + ret = GiLoadBroadcastFloat32(&p); + + std::vector naive; + for (size_t i = 0; i < SIMD_LEN; i++) { + naive.push_back(p); + } + + assert_eq((float*)&ret, naive); +} + +TEST_F(FALLBACK, GiZeroFloat32) { + GI_FLOAT32_t ret; + memset(&ret, 'f', sizeof(GI_FLOAT32_t)); + float p = 0; + + ret = GiZeroFloat32(); + + std::vector naive; + for (size_t i = 0; i < SIMD_LEN; i++) { + naive.push_back(p); + } + + assert_eq((float*)&ret, naive); +} + +TEST_F(FALLBACK, GiLoadFloat32) { + GI_FLOAT32_t ret; + std::vector s0{2.3f, 4.7f, -1.4f, 1223.6f}; + s0.resize(SIMD_LEN); + + ret = GiLoadFloat32(s0.data()); + + std::vector naive; + for (size_t i = 0; i < SIMD_LEN; i++) { + naive.push_back(s0[i]); + } + + assert_eq((float*)&ret, naive); +} + +TEST_F(FALLBACK, GiLoadFloat32LowHalf) { + GI_FLOAT32_t ret; + std::vector s0{2.3f, 4.7f, -1.4f, 1223.6f}; + s0.resize(SIMD_LEN); + + ret = GiLoadFloat32LowHalf(s0.data()); + + std::vector naive; + for (size_t i = 0; i < SIMD_LEN; i++) { + if (i < SIMD_LEN / 2) { + naive.push_back(s0[i]); + } else { + naive.push_back(0); + } + } + + assert_eq((float*)&ret, naive); +} + +TEST_F(FALLBACK, GiMlaqFloat32) { + GI_FLOAT32_t src0, src1, src2, ret; + std::vector s0{1.1f, 2.2f, 3.5f, 4.9f}; + std::vector s1{2312.1f, 345.244f, 3.59f, -12.8f}; + std::vector s2{1.2f, -3.1f, 9.0f, 11.2f}; + s0.resize(SIMD_LEN); + s1.resize(SIMD_LEN); + s2.resize(SIMD_LEN); + init((float*)&src0, s0); + init((float*)&src1, s1); + init((float*)&src2, s2); + + ret = GiMlaqFloat32(src0, src1, src2); + + std::vector naive; + for (size_t i = 0; i < SIMD_LEN; i++) { + naive.push_back(s0[i] + (s1[i] * s2[i])); + } + + assert_eq((float*)&ret, naive); +} + +TEST_F(FALLBACK, GiUzpqFloat32) { + GI_FLOAT32_t src0, src1; + GI_FLOAT32_V2_t ret; + std::vector s0{1.1f, 2.2f, 3.5f, 4.9f}; + std::vector s1{2312.1f, 345.244f, 3.59f, -12.8f}; + s0.resize(SIMD_LEN); + s1.resize(SIMD_LEN); + init((float*)&src0, s0); + init((float*)&src1, s1); + + ret = GiUzpqFloat32(src0, src1); + + std::vector naive0; + std::vector naive1; + naive0.push_back(s0[0]); + naive0.push_back(s0[2]); + 
naive0.push_back(s1[0]); + naive0.push_back(s1[2]); + naive1.push_back(s0[1]); + naive1.push_back(s0[3]); + naive1.push_back(s1[1]); + naive1.push_back(s1[3]); + + assert_eq((float*)&ret, naive0); + assert_eq((float*)&ret + SIMD_LEN, naive1); +} + +TEST_F(FALLBACK, GiDupFloat32) { + float32x2_t ret; + float t = 3.1415; + + ret = GiDupFloat32(t); + + auto r = (float*)&ret; + ASSERT_EQ(*r, t); + ASSERT_EQ(*(r + 1), t); +} + +TEST_F(FALLBACK, GiLdFloat32) { + float32x2_t ret; + std::vector s0{1.1f, -3.1415f}; + + ret = GiLdFloat32(s0.data()); + + auto r = (float*)&ret; + ASSERT_EQ(*r, s0[0]); + ASSERT_EQ(*(r + 1), s0[1]); +} + +TEST_F(FALLBACK, GiAddDFloat32) { + float32x2_t src0, src1, ret; + std::vector s0{1.1f, -3.1415f}; + std::vector s1{2.3f, 3.14777f}; + memcpy(&src0, s0.data(), sizeof(float32x2_t)); + memcpy(&src1, s1.data(), sizeof(float32x2_t)); + + ret = GiAddDFloat32(src0, src1); + + auto r = (float*)&ret; + + auto naive0 = s0[0] + s1[0]; + auto naive1 = s0[1] + s1[1]; + ASSERT_EQ(*r, naive0); + ASSERT_EQ(*(r + 1), naive1); +} + +TEST_F(FALLBACK, GiGetLaneFloat32) { + float32x2_t src0; + std::vector s0{1.1f, -3.1415f}; + memcpy(&src0, s0.data(), sizeof(float32x2_t)); + + auto ret = GiGetLaneFloat32(src0, 0); + ASSERT_EQ(ret, s0[0]); + + ret = GiGetLaneFloat32(src0, 1); + ASSERT_EQ(ret, s0[1]); +} + +TEST_F(FALLBACK, GiSetLaneFloat32) { + float32x2_t src0, ret; + std::vector s0{2.1f, -3.1415f}; + memcpy(&src0, s0.data(), sizeof(float32x2_t)); + float p = 2022.0420; + + auto r = (float*)&ret; + ret = GiSetLaneFloat32(p, src0, 0); + ASSERT_EQ(*r, p); + ASSERT_EQ(*(r + 1), s0[1]); + + ret = GiSetLaneFloat32(p, src0, 1); + ASSERT_EQ(*r, s0[0]); + ASSERT_EQ(*(r + 1), p); +} + +TEST_F(FALLBACK, GiSt1Float32) { + float32x2_t src0; + std::vector s0{2.1f, -3.1415f}; + memcpy(&src0, s0.data(), sizeof(float32x2_t)); + + std::vector ret{0, 0}; + GiSt1Float32(ret.data(), src0); + ASSERT_EQ(ret[0], s0[0]); + ASSERT_EQ(ret[1], s0[1]); +} + +TEST_F(FALLBACK, GiLd2qFloat32) { + GI_FLOAT32_V2_t ret; + std::vector s0{1.1f, 2.2f, 3.5f, 4.9f, 2312.1f, 345.244f, 3.59f, -12.8f}; + + ret = GiLd2qFloat32(s0.data()); + + std::vector naive0; + std::vector naive1; + naive0.push_back(s0[0]); + naive0.push_back(s0[2]); + naive0.push_back(s0[4]); + naive0.push_back(s0[6]); + naive1.push_back(s0[1]); + naive1.push_back(s0[3]); + naive1.push_back(s0[5]); + naive1.push_back(s0[7]); + + assert_eq((float*)&ret, naive0); + assert_eq((float*)&ret + SIMD_LEN, naive1); +} + +TEST_F(FALLBACK, GiExtqFloat32) { + GI_FLOAT32_t src0, src1, ret; + std::vector s0{1.1f, 2.2f, 3.5f, 4.9f}; + std::vector s1{-9.1f, 34234.6f, 9.0f, 34.1f}; + s0.resize(SIMD_LEN); + s1.resize(SIMD_LEN); + init((float*)&src0, s0); + init((float*)&src1, s1); + std::vector naive = {0, 0, 0, 0}; + + auto compare = [&](const size_t n) { + size_t t_count = SIMD_LEN; + size_t a_count = t_count - n; + for (size_t i = 0; i < a_count; i++) { + naive[i] = s0[i + n]; + } + for (size_t i = 0; i < n; i++) { + naive[i + a_count] = s1[i]; + } + assert_eq((float*)&ret, naive); + }; + +#define CB(n) \ + ret = GiExtqFloat32(src0, src1, n); \ + compare(n); + + CB(0) + CB(1) + CB(2) + CB(3) +#undef CB +} + +TEST_F(FALLBACK, GiMultiplySubFloat32) { + GI_FLOAT32_t src0, src1, src2, ret; + std::vector s0{1.1f, 2.2f, 3.5f, 4.9f}; + std::vector s1{-9.1f, 34234.6f, 9.0f, 34.1f}; + std::vector s2{0.4f, 9.9f, 4.3f, 6.2f}; + s0.resize(SIMD_LEN); + s1.resize(SIMD_LEN); + s2.resize(SIMD_LEN); + init((float*)&src0, s0); + init((float*)&src1, s1); + init((float*)&src2, s2); + + ret = 
GiMultiplySubFloat32(src0, src1, src2); + std::vector naive; + for (size_t i = 0; i < SIMD_LEN; i++) { + naive.push_back(s0[i] - (s1[i] * s2[i])); + } + + assert_eq((float*)&ret, naive); +} + +TEST_F(FALLBACK, GiLd1qLaneFloat32) { + GI_FLOAT32_t src0, ret; + std::vector s0{1.1f, 2.2f, 3.5f, 4.9f}; + s0.resize(SIMD_LEN); + init((float*)&src0, s0); + std::vector naive = {0, 0, 0, 0}; + + float buffer = 3.14159; + + auto compare = [&](const size_t n) { + memcpy(naive.data(), s0.data(), sizeof(GI_FLOAT32_t)); + naive[n] = buffer; + assert_eq((float*)&ret, naive); + }; + +#define CB(n) \ + ret = GiLd1qLaneFloat32(&buffer, src0, n); \ + compare(n); + + CB(0) + CB(1) + CB(2) + CB(3) +#undef CB +} + +TEST_F(FALLBACK, GiSetqLaneFloat32) { + GI_FLOAT32_t src0, ret; + std::vector s0{2.1f, 6.2f, -9.5f, 2.9f}; + s0.resize(SIMD_LEN); + init((float*)&src0, s0); + std::vector naive = {0, 0, 0, 0}; + + float buffer = 6.14159; + + auto compare = [&](const size_t n) { + memcpy(naive.data(), s0.data(), sizeof(GI_FLOAT32_t)); + naive[n] = buffer; + assert_eq((float*)&ret, naive); + }; + +#define CB(n) \ + ret = GiSetqLaneFloat32(buffer, src0, n); \ + compare(n); + + CB(0) + CB(1) + CB(2) + CB(3) +#undef CB +} + +TEST_F(FALLBACK, GiMlaqLaneFloat32HighHalf) { + GI_FLOAT32_t src0, src1, src2, ret; + std::vector s0{1.1f, 2.2f, 3.5f, 4.9f}; + std::vector s1{-9.1f, 34234.6f, 9.0f, 34.1f}; + std::vector s2{0.4f, 9.9f, 4.3f, 6.2f}; + s0.resize(SIMD_LEN); + s1.resize(SIMD_LEN); + s2.resize(SIMD_LEN); + init((float*)&src0, s0); + init((float*)&src1, s1); + init((float*)&src2, s2); + std::vector naive = {0, 0, 0, 0}; + + auto compare = [&](const size_t n) { + for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(float); i++) { + naive[i] = s0[i] + (s1[i] * s2[n + 2]); + } + assert_eq((float*)&ret, naive); + }; + +#define CB(n) \ + ret = GiMlaqLaneFloat32HighHalf(src0, src1, src2, n); \ + compare(n); + + CB(0) + CB(1) +#undef CB +} + +TEST_F(FALLBACK, GiVmlaqLaneFloat32LowHalf) { + GI_FLOAT32_t src0, src1, src2, ret; + std::vector s0{1.1f, 2.2f, 3.5f, 4.9f}; + std::vector s1{-9.1f, 34234.6f, 9.0f, 34.1f}; + std::vector s2{0.4f, 9.9f, 4.3f, 6.2f}; + s0.resize(SIMD_LEN); + s1.resize(SIMD_LEN); + s2.resize(SIMD_LEN); + init((float*)&src0, s0); + init((float*)&src1, s1); + init((float*)&src2, s2); + std::vector naive = {0, 0, 0, 0}; + + auto compare = [&](const size_t n) { + for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(float); i++) { + naive[i] = s0[i] + (s1[i] * s2[n]); + } + assert_eq((float*)&ret, naive); + }; + +#define CB(n) \ + ret = GiVmlaqLaneFloat32LowHalf(src0, src1, src2, n); \ + compare(n); + + CB(0) + CB(1) +#undef CB +} + +TEST_F(FALLBACK, GiStoreFloat32) { + GI_FLOAT32_t src0; + std::vector s0{1.1f, 2.2f, 3.5f, 4.9f}; + s0.resize(SIMD_LEN); + init((float*)&src0, s0); + std::vector ret{0}; + ret.resize(SIMD_LEN); + + GiStoreFloat32(ret.data(), src0); + assert_eq(ret.data(), s0); +} + +TEST_F(FALLBACK, GiStoreLaneXXFloat32) { + GI_FLOAT32_t src0; + std::vector s0{1.1f, 2.2f, 3.5f, 4.9f}; + s0.resize(SIMD_LEN); + init((float*)&src0, s0); + float ret{0}; + +#define CB(n) \ + GiStoreLane##n##Float32(&ret, src0); \ + ASSERT_EQ(ret, s0[n]); + + CB(0) + CB(1) + CB(2) + CB(3) +#undef CB +} + +TEST_F(FALLBACK, GiExtractLaneXXFloat32) { + GI_FLOAT32_t src0; + std::vector s0{1.1f, 2.2f, 3.5f, 4.9f}; + s0.resize(SIMD_LEN); + init((float*)&src0, s0); + float ret{0}; + +#define CB(n) \ + ret = GiExtractLane##n##Float32(src0); \ + ASSERT_EQ(ret, s0[n]); + + CB(0) + CB(1) + CB(2) + CB(3) +#undef CB +} + +TEST_F(FALLBACK, 
GiZipqFloat32) { + GI_FLOAT32_t src0, src1; + GI_FLOAT32_V2_t ret; + std::vector s0{1.1f, 2.2f, 3.5f, 4.9f}; + std::vector s1{2312.1f, 345.244f, 3.59f, -12.8f}; + s0.resize(SIMD_LEN); + s1.resize(SIMD_LEN); + init((float*)&src0, s0); + init((float*)&src1, s1); + + ret = GiZipqFloat32(src0, src1); + + std::vector naive0; + std::vector naive1; + naive0.push_back(s0[0]); + naive0.push_back(s1[0]); + naive0.push_back(s0[1]); + naive0.push_back(s1[1]); + naive1.push_back(s0[2]); + naive1.push_back(s1[2]); + naive1.push_back(s0[3]); + naive1.push_back(s1[3]); + + assert_eq((float*)&ret, naive0); + assert_eq((float*)&ret + SIMD_LEN, naive1); +} + +TEST_F(FALLBACK, GiInterleaveLowFloat32) { + GI_FLOAT32_t src0, src1, ret; + std::vector s0{1.1f, 2.2f, 3.5f, 4.9f}; + std::vector s1{2312.1f, 345.244f, 3.59f, -12.8f}; + s0.resize(SIMD_LEN); + s1.resize(SIMD_LEN); + init((float*)&src0, s0); + init((float*)&src1, s1); + + ret = GiInterleaveLowFloat32(src0, src1); + + std::vector naive; + naive.resize(SIMD_LEN); + + for (size_t i = 0; i < SIMD_LEN / 2; i++) { + naive[2 * i] = s0[i]; + naive[2 * i + 1] = s1[i]; + } + assert_eq((float*)&ret, naive); +} + +TEST_F(FALLBACK, GiInterleaveHighFloat32) { + GI_FLOAT32_t src0, src1, ret; + std::vector s0{1.1f, 2.2f, 3.5f, 4.9f}; + std::vector s1{2312.1f, 345.244f, 3.59f, -12.8f}; + s0.resize(SIMD_LEN); + s1.resize(SIMD_LEN); + init((float*)&src0, s0); + init((float*)&src1, s1); + + ret = GiInterleaveHighFloat32(src0, src1); + + std::vector naive; + naive.resize(SIMD_LEN); + + for (size_t i = 0; i < SIMD_LEN / 2; i++) { + naive[2 * i] = s0[i + SIMD_LEN / 2]; + naive[2 * i + 1] = s1[i + SIMD_LEN / 2]; + } + assert_eq((float*)&ret, naive); +} + +TEST_F(FALLBACK, GiAddFloat32) { + GI_FLOAT32_t src0, src1, ret; + std::vector s0{1.1f, 2.2f, 3.5f, 4.9f}; + std::vector s1{2312.1f, 345.244f, 3.59f, -12.8f}; + s0.resize(SIMD_LEN); + s1.resize(SIMD_LEN); + init((float*)&src0, s0); + init((float*)&src1, s1); + + ret = GiAddFloat32(src0, src1); + + std::vector naive; + + for (size_t i = 0; i < SIMD_LEN; i++) { + naive.push_back(s0[i] + s1[i]); + } + + assert_eq((float*)&ret, naive); +} + +TEST_F(FALLBACK, GiSubtractFloat32) { + GI_FLOAT32_t src0, src1, ret; + std::vector s0{1.1f, 2.2f, 3.5f, 4.9f}; + std::vector s1{2312.1f, 345.244f, 3.59f, -12.8f}; + s0.resize(SIMD_LEN); + s1.resize(SIMD_LEN); + init((float*)&src0, s0); + init((float*)&src1, s1); + + ret = GiSubtractFloat32(src0, src1); + + std::vector naive; + + for (size_t i = 0; i < SIMD_LEN; i++) { + naive.push_back(s0[i] - s1[i]); + } + + assert_eq((float*)&ret, naive); +} + +TEST_F(FALLBACK, GiMultiplyFloat32) { + GI_FLOAT32_t src0, src1, ret; + std::vector s0{1.1f, 2.2f, 3.5f, 4.9f}; + std::vector s1{2312.1f, 345.244f, 3.59f, -12.8f}; + s0.resize(SIMD_LEN); + s1.resize(SIMD_LEN); + init((float*)&src0, s0); + init((float*)&src1, s1); + + ret = GiMultiplyFloat32(src0, src1); + + std::vector naive; + + for (size_t i = 0; i < SIMD_LEN; i++) { + naive.push_back(s0[i] * s1[i]); + } + + assert_eq((float*)&ret, naive); +} + +TEST_F(FALLBACK, GiMultiplyScalerFloat32) { + GI_FLOAT32_t src0, ret; + std::vector s0{1.1f, 2.2f, 3.5f, 4.9f}; + s0.resize(SIMD_LEN); + init((float*)&src0, s0); + + float scalar = 3.1415; + + ret = GiMultiplyScalerFloat32(src0, scalar); + + std::vector naive; + + for (size_t i = 0; i < SIMD_LEN; i++) { + naive.push_back(s0[i] * scalar); + } + + assert_eq((float*)&ret, naive); +} + +TEST_F(FALLBACK, GiMultiplyAddFloat32) { + GI_FLOAT32_t src0, src1, src2, ret; + std::vector s0{1.1f, 2.2f, 3.5f, 4.9f}; + 
std::vector s1{2312.1f, 345.244f, 3.59f, -12.8f}; + std::vector s2{12.1f, 35.244f, 23.59f, -112.8f}; + s0.resize(SIMD_LEN); + s1.resize(SIMD_LEN); + s2.resize(SIMD_LEN); + init((float*)&src0, s0); + init((float*)&src1, s1); + init((float*)&src2, s2); + + ret = GiMultiplyAddFloat32(src0, src1, src2); + + std::vector naive; + + for (size_t i = 0; i < SIMD_LEN; i++) { + naive.push_back(s1[i] * s2[i] + s0[i]); + } + + assert_eq((float*)&ret, naive); +} + +TEST_F(FALLBACK, GiMultiplyAddScalarFloat32) { + GI_FLOAT32_t src0, src1, ret; + std::vector s0{1.1f, 2.2f, 3.5f, 4.9f}; + std::vector s1{2312.1f, 345.244f, 3.59f, -12.8f}; + s0.resize(SIMD_LEN); + s1.resize(SIMD_LEN); + init((float*)&src0, s0); + init((float*)&src1, s1); + + float scalar = 3.1415; + + ret = GiMultiplyAddScalarFloat32(src0, src1, scalar); + + std::vector naive; + + for (size_t i = 0; i < SIMD_LEN; i++) { + naive.push_back(s1[i] * scalar + s0[i]); + } + + assert_eq((float*)&ret, naive); +} + +TEST_F(FALLBACK, GiMultiplyAddLanXXFloat32) { + GI_FLOAT32_t src0, src1, src2, ret; + std::vector s0{1.1f, 2.2f, 3.5f, 4.9f}; + std::vector s1{2312.1f, 345.244f, 3.59f, -12.8f}; + std::vector s2{12.1f, 35.244f, 23.59f, -112.8f}; + s0.resize(SIMD_LEN); + s1.resize(SIMD_LEN); + s2.resize(SIMD_LEN); + init((float*)&src0, s0); + init((float*)&src1, s1); + init((float*)&src2, s2); + + std::vector naive = {0, 0, 0, 0}; + + auto compare = [&](const size_t n) { + for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(float); i++) { + naive[i] = s0[i] + (s1[i] * s2[n]); + } + assert_eq((float*)&ret, naive); + }; + +#define CB(n) \ + ret = GiMultiplyAddLan##n##Float32(src0, src1, src2); \ + compare(n); + + CB(0) + CB(1) + CB(2) + CB(3) +#undef CB +} + +TEST_F(FALLBACK, GiDivideFloat32) { + GI_FLOAT32_t src0, src1, ret; + std::vector s0{1.1f, 2.2f, 3.5f, 4.9f}; + std::vector s1{2312.1f, 345.244f, 3.59f, -12.8f}; + s0.resize(SIMD_LEN); + s1.resize(SIMD_LEN); + init((float*)&src0, s0); + init((float*)&src1, s1); + + ret = GiDivideFloat32(src0, src1); + + std::vector naive; + + for (size_t i = 0; i < SIMD_LEN; i++) { + naive.push_back(s0[i] / s1[i]); + } + + assert_lt((float*)&ret, naive, 1e-3); +} + +TEST_F(FALLBACK, GiRecpeSFloat32) { + GI_FLOAT32_t src0, src1, ret; + std::vector s0{1.1f, 2.2f, 3.5f, 4.9f}; + std::vector s1{2312.1f, 345.244f, 3.59f, -12.8f}; + s0.resize(SIMD_LEN); + s1.resize(SIMD_LEN); + init((float*)&src0, s0); + init((float*)&src1, s1); + + ret = GiRecpeSFloat32(src0, src1); + + std::vector naive; + + for (size_t i = 0; i < SIMD_LEN; i++) { + naive.push_back(2.0f - s0[i] * s1[i]); + } + + assert_eq((float*)&ret, naive); +} + +TEST_F(FALLBACK, GiRecpeFloat32) { + GI_FLOAT32_t src0, ret; + std::vector s0{100.1f, 2.2f, 3.5f, 4.9f}; + s0.resize(SIMD_LEN); + init((float*)&src0, s0); + + ret = GiRecpeFloat32(src0); + + std::vector naive; + + for (size_t i = 0; i < SIMD_LEN; i++) { + naive.push_back(1.0f / s0[i]); + } + + assert_lt((float*)&ret, naive, 1e-3); +} + +TEST_F(FALLBACK, GiNegFloat32) { + GI_FLOAT32_t src0, ret; + std::vector s0{-1.1f, 2.2f, 3.5f, 4.9f}; + s0.resize(SIMD_LEN); + init((float*)&src0, s0); + + ret = GiNegFloat32(src0); + + std::vector naive; + + for (size_t i = 0; i < SIMD_LEN; i++) { + naive.push_back(-s0[i]); + } + + assert_eq((float*)&ret, naive); +} + +TEST_F(FALLBACK, GiGreaterThanFloat32) { + GI_FLOAT32_t src0, src1; + GI_UINT32_t ret; + std::vector s0{1.1f, 2.2f, 3.5f, 4.9f}; + std::vector s1{2312.1f, 0.1f, 3.59f, -12.8f}; + s0.resize(SIMD_LEN); + s1.resize(SIMD_LEN); + init((float*)&src0, s0); + 
init((float*)&src1, s1); + + ret = GiGreaterThanFloat32(src0, src1); + + std::vector naive; + + for (size_t i = 0; i < SIMD_LEN; i++) { + naive.push_back(s0[i] > s1[i] ? 0xFFFFFFFF : 0); + } + + assert_eq((int32_t*)&ret, naive); +} + +TEST_F(FALLBACK, GiLessThanEqFloat32) { + GI_FLOAT32_t src0, src1; + GI_UINT32_t ret; + std::vector s0{1.1f, 2.2f, 3.5f, 4.9f}; + std::vector s1{2312.1f, 0.1f, 3.59f, -12.8f}; + s0.resize(SIMD_LEN); + s1.resize(SIMD_LEN); + init((float*)&src0, s0); + init((float*)&src1, s1); + + ret = GiLessThanEqFloat32(src0, src1); + + std::vector naive; + + for (size_t i = 0; i < SIMD_LEN; i++) { + naive.push_back(s0[i] <= s1[i] ? 0xFFFFFFFF : 0); + } + + assert_eq((int32_t*)&ret, naive); +} + +TEST_F(FALLBACK, GiLessThanFloat32) { + GI_FLOAT32_t src0, src1; + GI_UINT32_t ret; + std::vector s0{1.1f, 2.2f, 3.5f, 4.9f}; + std::vector s1{1.1f, 0.1f, 3.59f, -12.8f}; + s0.resize(SIMD_LEN); + s1.resize(SIMD_LEN); + init((float*)&src0, s0); + init((float*)&src1, s1); + + ret = GiLessThanFloat32(src0, src1); + + std::vector naive; + + for (size_t i = 0; i < SIMD_LEN; i++) { + naive.push_back(s0[i] < s1[i] ? 0xFFFFFFFF : 0); + } + + assert_eq((int32_t*)&ret, naive); +} + +TEST_F(FALLBACK, GiAndFloat32) { + GI_FLOAT32_t src0, src1, ret; + std::vector s0{1.1f, 2.2f, 3.5f, 4.9f}; + std::vector s1{2312.1f, 345.244f, 3.59f, -12.8f}; + s0.resize(SIMD_LEN); + s1.resize(SIMD_LEN); + init((float*)&src0, s0); + init((float*)&src1, s1); + + ret = GiAndFloat32(src0, src1); + + std::vector naive; + + for (size_t i = 0; i < SIMD_LEN; i++) { + int32_t tmp0, tmp1, tmp; + float tmp2; + memcpy(&tmp0, &s0[i], sizeof(int32_t)); + memcpy(&tmp1, &s1[i], sizeof(int32_t)); + tmp = tmp0 & tmp1; + memcpy(&tmp2, &tmp, sizeof(float)); + naive.push_back(tmp2); + } + + assert_eq((float*)&ret, naive); +} + +TEST_F(FALLBACK, GiOrFloat32) { + GI_FLOAT32_t src0, src1, ret; + std::vector s0{2, 2, 3, 4}; + std::vector s1{6, 6, 7, 8}; + s0.resize(SIMD_LEN); + s1.resize(SIMD_LEN); + init((float*)&src0, s0); + init((float*)&src1, s1); + + ret = GiOrFloat32(src0, src1); + + std::vector naive; + + for (size_t i = 0; i < SIMD_LEN; i++) { + int32_t tmp0, tmp1, tmp; + float tmp2; + memcpy(&tmp0, &s0[i], sizeof(int32_t)); + memcpy(&tmp1, &s1[i], sizeof(int32_t)); + tmp = tmp0 | tmp1; + memcpy(&tmp2, &tmp, sizeof(float)); + naive.push_back(tmp2); + } + + assert_eq((float*)&ret, naive); +} + +TEST_F(FALLBACK, GiAndNotFloat32) { + GI_FLOAT32_t src0, src1, ret; + std::vector s0{1.1f, 2.2f, 3.5f, 4.9f}; + std::vector s1{2312.1f, 345.244f, 3.59f, -12.8f}; + s0.resize(SIMD_LEN); + s1.resize(SIMD_LEN); + init((float*)&src0, s0); + init((float*)&src1, s1); + + ret = GiAndNotFloat32(src0, src1); + + std::vector naive; + + for (size_t i = 0; i < SIMD_LEN; i++) { + int32_t tmp0, tmp1, tmp; + float tmp2; + memcpy(&tmp0, &s0[i], sizeof(int32_t)); + memcpy(&tmp1, &s1[i], sizeof(int32_t)); + tmp = ~tmp0 & tmp1; + memcpy(&tmp2, &tmp, sizeof(float)); + naive.push_back(tmp2); + } + + assert_eq((float*)&ret, naive); +} + +TEST_F(FALLBACK, GiXorFloat32) { + GI_FLOAT32_t src0, src1, ret; + std::vector s0{1.1f, 2.2f, 3.5f, 4.9f}; + std::vector s1{2312.1f, 345.244f, 3.59f, -12.8f}; + s0.resize(SIMD_LEN); + s1.resize(SIMD_LEN); + init((float*)&src0, s0); + init((float*)&src1, s1); + + ret = GiXorFloat32(src0, src1); + + std::vector naive; + + for (size_t i = 0; i < SIMD_LEN; i++) { + int32_t tmp0, tmp1, tmp; + float tmp2; + memcpy(&tmp0, &s0[i], sizeof(int32_t)); + memcpy(&tmp1, &s1[i], sizeof(int32_t)); + tmp = tmp0 ^ tmp1; + memcpy(&tmp2, &tmp, 
sizeof(float));
+        naive.push_back(tmp2);
+    }
+
+    assert_eq((float*)&ret, naive);
+}
+
+TEST_F(FALLBACK, GiBSLFloat32) {
+    GI_FLOAT32_t src0, src1, ret, na;
+    GI_UINT32_t mask;
+    std::vector<float> s0{1.1f, 2.2f, 4.5f, 4.9f};
+    std::vector<float> s1{2312.1f, 345.244f, 3.59f, -12.8f};
+    std::vector<std::vector<uint32_t>> s2s = {
+            {1, 2, 3, 0},       {0u, 0u, 0u, 0u},    {~0u, 0u, 0u, 0u},
+            {~0u, ~0u, 0u, 0u}, {~0u, ~0u, ~0u, 0u}, {~0u, ~0u, ~0u, ~0u}};
+    s0.resize(SIMD_LEN);
+    s1.resize(SIMD_LEN);
+    init((float*)&src0, s0);
+    init((float*)&src1, s1);
+
+    for (auto& s2 : s2s) {
+        init((uint32_t*)&mask, s2);
+
+        ret = GiBSLFloat32(mask, src0, src1);
+        na = GiBlendFloat32(src0, src1, GiReintUint32ToFloat32(mask));
+
+        std::vector<float> naive;
+        naive.resize(SIMD_LEN);
+        memcpy(naive.data(), &na, sizeof(GI_FLOAT32_t));
+
+        assert_eq((float*)&ret, naive);
+    }
+}
+
+TEST_F(FALLBACK, GiMaximumFloat32) {
+    GI_FLOAT32_t src0, src1, ret;
+    std::vector<float> s0{1.1f, 2.2f, 4.5f, 4.9f};
+    std::vector<float> s1{2312.1f, 345.244f, 3.59f, -12.8f};
+    s0.resize(SIMD_LEN);
+    s1.resize(SIMD_LEN);
+    init((float*)&src0, s0);
+    init((float*)&src1, s1);
+
+    ret = GiMaximumFloat32(src0, src1);
+
+    std::vector<float> naive;
+    for (size_t i = 0; i < SIMD_LEN; i++) {
+        naive.push_back(Max(s0[i], s1[i]));
+    }
+
+    assert_eq((float*)&ret, naive);
+}
+
+TEST_F(FALLBACK, GiMinimumFloat32) {
+    GI_FLOAT32_t src0, src1, ret;
+    std::vector<float> s0{1.1f, 2.2f, 4.5f, 4.9f};
+    std::vector<float> s1{2312.1f, 345.244f, 3.59f, -12.8f};
+    s0.resize(SIMD_LEN);
+    s1.resize(SIMD_LEN);
+    init((float*)&src0, s0);
+    init((float*)&src1, s1);
+
+    ret = GiMinimumFloat32(src0, src1);
+
+    std::vector<float> naive;
+    for (size_t i = 0; i < SIMD_LEN; i++) {
+        naive.push_back(Min(s0[i], s1[i]));
+    }
+
+    assert_eq((float*)&ret, naive);
+}
+
+TEST_F(FALLBACK, GiMaxNanFloat32) {
+    GI_FLOAT32_t src0, src1, ret;
+    std::vector<float> s0{1.1f, 2.2f, 4.5f, NAN};
+    std::vector<float> s1{2312.1f, 345.244f, NAN, -12.8f};
+    s0.resize(SIMD_LEN);
+    s1.resize(SIMD_LEN);
+    init((float*)&src0, s0);
+    init((float*)&src1, s1);
+
+    ret = GiMaxNanFloat32(src0, src1);
+
+    std::vector<float> naive;
+    for (size_t i = 0; i < SIMD_LEN; i++) {
+        auto t = MAX_NAN(s0[i], s1[i]);
+        naive.push_back(t);
+    }
+
+    assert_eq_and_nan((float*)&ret, naive);
+}
+
+TEST_F(FALLBACK, GiMinNanFloat32) {
+    GI_FLOAT32_t src0, src1, ret;
+    std::vector<float> s0{NAN, 2.2f, NAN, 4.9f};
+    std::vector<float> s1{2312.1f, 345.244f, 3.59f, -12.8f};
+    s0.resize(SIMD_LEN);
+    s1.resize(SIMD_LEN);
+    init((float*)&src0, s0);
+    init((float*)&src1, s1);
+
+    ret = GiMinNanFloat32(src0, src1);
+
+    std::vector<float> naive;
+    for (size_t i = 0; i < SIMD_LEN; i++) {
+        auto t = MIN_NAN(s0[i], s1[i]);
+        naive.push_back(t);
+    }
+
+    assert_eq_and_nan((float*)&ret, naive);
+}
+
+TEST_F(FALLBACK, GiClampFloat32) {
+    GI_FLOAT32_t src0, src1, ret, na;
+    std::vector<float> s0{1.1f, 2.2f, 4.5f, 4.9f};
+    std::vector<float> s1{1.1f, 2.2f, 4.5f, 4.9f};
+    s0.resize(SIMD_LEN);
+    s1.resize(SIMD_LEN);
+    init((float*)&src0, s0);
+    init((float*)&src1, s1);
+    float LowerRange = 3.1415;
+    float UpperRange = 4.876;
+
+    auto naive_c = [](GI_FLOAT32_t Value, float LowerRange,
+                      float UpperRange) -> GI_FLOAT32_t {
+        Value = GiMaximumFloat32(GiBroadcastFloat32(LowerRange), Value);
+        Value = GiMinimumFloat32(GiBroadcastFloat32(UpperRange), Value);
+        return Value;
+    };
+    ret = GiClampFloat32(src0, LowerRange, UpperRange);
+    na = naive_c(src1, LowerRange, UpperRange);
+
+    std::vector<float> naive;
+    naive.resize(SIMD_LEN);
+    memcpy(naive.data(), &na, sizeof(GI_FLOAT32_t));
+
+    assert_eq((float*)&ret, naive);
+}
+
+TEST_F(FALLBACK, GiReduceAddFloat32) {
+    GI_FLOAT32_t src0;
+    float
ret{0};
+    std::vector<float> s0{1.1f, 2.2f, 4.5f, -4.9f};
+    s0.resize(SIMD_LEN);
+    init((float*)&src0, s0);
+
+    ret = GiReduceAddFloat32(src0);
+
+    float naive{0};
+    for (size_t i = 0; i < SIMD_LEN; i++) {
+        naive += s0[i];
+    }
+
+    ASSERT_LT(std::abs(ret - naive), 1e-3);
+}
+
+TEST_F(FALLBACK, GiReduceMultiplyFloat32) {
+    GI_FLOAT32_t src0;
+    float ret{0};
+    std::vector<float> s0{1.1f, 2.2f, 4.5f, -4.9f};
+    s0.resize(SIMD_LEN);
+    init((float*)&src0, s0);
+
+    ret = GiReduceMultiplyFloat32(src0);
+
+    float naive{1};
+    for (size_t i = 0; i < SIMD_LEN; i++) {
+        naive *= s0[i];
+    }
+
+    ASSERT_LT(std::abs(ret - naive), 1e-3);
+}
+
+TEST_F(FALLBACK, GiReduceMaxNanFloat32) {
+    GI_FLOAT32_t src0;
+    float ret{0};
+    std::vector<float> s0{1.1f, 2.2f, 4.9f, -4.9f};
+    s0.resize(SIMD_LEN);
+    init((float*)&src0, s0);
+
+    ret = GiReduceMaxNanFloat32(src0);
+
+    float naive = s0[0];
+    for (size_t i = 0; i < SIMD_LEN; i++) {
+        naive = MAX_NAN(naive, s0[i]);
+    }
+
+    ASSERT_EQ(ret, naive);
+    ret = 0;
+    s0 = {1.1f, 2.2f, 4.9f, NAN};
+    init((float*)&src0, s0);
+
+    ret = GiReduceMaxNanFloat32(src0);
+    ASSERT_TRUE(isnan(ret));
+}
+
+TEST_F(FALLBACK, GiReduceMinNanFloat32) {
+    GI_FLOAT32_t src0;
+    float ret{0};
+    std::vector<float> s0{1.1f, 2.2f, 4.5f, -4.9f};
+    s0.resize(SIMD_LEN);
+    init((float*)&src0, s0);
+
+    ret = GiReduceMinNanFloat32(src0);
+
+    float naive = s0[0];
+    for (size_t i = 0; i < SIMD_LEN; i++) {
+        naive = MIN_NAN(naive, s0[i]);
+    }
+
+    ASSERT_EQ(ret, naive);
+    ret = 0;
+    s0 = {-1.1f, 2.2f, 4.9f, NAN};
+    init((float*)&src0, s0);
+
+    ret = GiReduceMinNanFloat32(src0);
+    ASSERT_TRUE(isnan(ret));
+}
+
+TEST_F(FALLBACK, GiAbsFloat32) {
+    GI_FLOAT32_t src0, ret;
+    std::vector<float> s0{2312.1f, 345.244f, 3.59f, -12.8f};
+    s0.resize(SIMD_LEN);
+    init((float*)&src0, s0);
+
+    ret = GiAbsFloat32(src0);
+
+    std::vector<float> naive;
+    for (size_t i = 0; i < SIMD_LEN; i++) {
+        naive.push_back(s0[i] > 0 ?
s0[i] : -s0[i]); + } + + assert_eq((float*)&ret, naive); +} + +TEST_F(FALLBACK, GiZip1qS64) { + GI_INT64_t src0, src1, ret; + std::vector s0{234242423424245, 42342342422323}; + std::vector s1{23424245, -4234234242232}; + s0.resize(SIMD_LEN / 2); + s1.resize(SIMD_LEN / 2); + memcpy(&src0, s0.data(), sizeof(GI_INT64_t)); + memcpy(&src1, s1.data(), sizeof(GI_INT64_t)); + + ret = GiZip1qS64(src0, src1); + + std::vector naive; + naive.push_back(s0[0]); + naive.push_back(s1[0]); + auto p = (int64_t*)&ret; + ASSERT_EQ(naive[0], p[0]); + ASSERT_EQ(naive[1], p[1]); +} + +TEST_F(FALLBACK, GiZip2qS64) { + GI_INT64_t src0, src1, ret; + std::vector s0{234242423424245, 42342342422323}; + std::vector s1{23424245, -4234234242232}; + s0.resize(SIMD_LEN / 2); + s1.resize(SIMD_LEN / 2); + memcpy(&src0, s0.data(), sizeof(GI_INT64_t)); + memcpy(&src1, s1.data(), sizeof(GI_INT64_t)); + + ret = GiZip2qS64(src0, src1); + + std::vector naive; + naive.push_back(s0[1]); + naive.push_back(s1[1]); + auto p = (int64_t*)&ret; + ASSERT_EQ(naive[0], p[0]); + ASSERT_EQ(naive[1], p[1]); +} + +TEST_F(FALLBACK, GiReinterpretqS64ToFloat32) { + GI_INT64_t src0; + GI_FLOAT32_t ret; + std::vector s0{234242423424245, 42342342422323}; + s0.resize(SIMD_LEN / 2); + memcpy(&src0, s0.data(), sizeof(GI_INT64_t)); + + ret = GiReinterpretqS64ToFloat32(src0); + + std::vector naive; + naive.resize(SIMD_LEN); + memcpy(naive.data(), s0.data(), sizeof(GI_FLOAT32_t)); + + assert_eq((float*)&ret, naive); +} + +TEST_F(FALLBACK, GiReinterpretqFloat32ToS64) { + GI_FLOAT32_t src0; + GI_INT64_t ret; + std::vector s0{2312.1f, 345.244f, 3.59f, -12.8f}; + s0.resize(SIMD_LEN); + init((float*)&src0, s0); + + ret = GiReinterpretqFloat32ToS64(src0); + + std::vector naive; + naive.resize(SIMD_LEN); + memcpy(naive.data(), s0.data(), sizeof(GI_INT64_t)); + + assert_eq((float*)&ret, naive); +} + +TEST_F(FALLBACK, GiSimdFmaLane) { + GI_FLOAT32_t src0, src1, src2, ret; + std::vector s0{1.1f, 2.2f, 3.5f, 4.9f}; + std::vector s1{2312.1f, 345.244f, 3.59f, -12.8f}; + std::vector s2{12.1f, 2.2f, 89.0f, -112.8f}; + s0.resize(SIMD_LEN); + s1.resize(SIMD_LEN); + s2.resize(SIMD_LEN); + init((float*)&src0, s0); + init((float*)&src1, s1); + init((float*)&src2, s2); + + std::vector naive = {0, 0, 0, 0}; + + auto compare = [&](const size_t n) { + for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(float); i++) { + naive[i] = s0[i] + (s1[i] * s2[n]); + } + assert_eq((float*)&ret, naive); + }; + +#define CB(n) \ + ret = GiSimdFmaLane(src0, src1, src2, n); \ + compare(n); + + CB(0) + CB(1) + CB(2) + CB(3) +#undef CB +} + +TEST_F(FALLBACK, GiMlaqLowLaneFloat32) { + GI_FLOAT32_t src0, src1, src2, ret; + std::vector s0{1.1f, 2.2f, 3.5f, 4.9f}; + std::vector s1{2312.1f, 345.244f, 3.59f, -12.8f}; + std::vector s2{12.1f, 2.2f, 89.0f, -112.8f}; + s0.resize(SIMD_LEN); + s1.resize(SIMD_LEN); + s2.resize(SIMD_LEN); + init((float*)&src0, s0); + init((float*)&src1, s1); + init((float*)&src2, s2); + + std::vector naive = {0, 0, 0, 0}; + + auto compare = [&](const size_t n) { + for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(float); i++) { + naive[i] = s0[i] + (s1[i] * s2[n]); + } + assert_eq((float*)&ret, naive); + }; + +#define CB(n) \ + ret = GiMlaqLowLaneFloat32(src0, src1, src2, n); \ + compare(n); + + CB(0) + CB(1) +#undef CB +} + +TEST_F(FALLBACK, GiMlaqHighLaneFloat32) { + GI_FLOAT32_t src0, src1, src2, ret; + std::vector s0{1.1f, 2.2f, 3.5f, 4.9f}; + std::vector s1{2312.1f, 345.244f, 3.59f, -12.8f}; + std::vector s2{12.1f, 2.2f, 89.0f, -112.8f}; + s0.resize(SIMD_LEN); + 
s1.resize(SIMD_LEN); + s2.resize(SIMD_LEN); + init((float*)&src0, s0); + init((float*)&src1, s1); + init((float*)&src2, s2); + + std::vector naive = {0, 0, 0, 0}; + + auto compare = [&](const size_t n) { + for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(float); i++) { + naive[i] = s0[i] + (s1[i] * s2[n]); + } + assert_eq((float*)&ret, naive); + }; + +#define CB(n) \ + ret = GiMlaqHighLaneFloat32(src0, src1, src2, n); \ + compare(n); + + CB(2) + CB(3) +#undef CB +} + +TEST_F(FALLBACK, GiFmsqLaneQFloat32) { + GI_FLOAT32_t src0, src1, src2, ret; + std::vector s0{1.1f, 2.2f, 3.5f, 4.9f}; + std::vector s1{2312.1f, 345.244f, 3.59f, -12.8f}; + std::vector s2{12.1f, 2.2f, 89.0f, -112.8f}; + s0.resize(SIMD_LEN); + s1.resize(SIMD_LEN); + s2.resize(SIMD_LEN); + init((float*)&src0, s0); + init((float*)&src1, s1); + init((float*)&src2, s2); + + std::vector naive = {0, 0, 0, 0}; + + auto compare = [&](const size_t n) { + for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(float); i++) { + naive[i] = s0[i] - (s1[i] * s2[n]); + } + assert_eq((float*)&ret, naive); + }; + +#define CB(n) \ + ret = GiFmsqLaneQFloat32(src0, src1, src2, n); \ + compare(n); + + CB(0) + CB(1) + CB(2) + CB(3) +#undef CB +} + +TEST_F(FALLBACK, GiBroadcastUint32) { + int32_t src0 = 20220422; + GI_UINT32_t ret; + + ret = GiBroadcastUint32(src0); + + std::vector naive; + for (size_t i = 0; i < SIMD_LEN; i++) { + naive.push_back(src0); + } + + assert_eq((uint32_t*)&ret, naive); +} + +TEST_F(FALLBACK, GiLoadInt32) { + std::vector s0{1, 2, -200, 999}; + GI_INT32_t ret; + + ret = GiLoadInt32(s0.data()); + + std::vector naive; + for (size_t i = 0; i < SIMD_LEN; i++) { + naive.push_back(s0[i]); + } + + assert_eq((uint32_t*)&ret, naive); +} + +TEST_F(FALLBACK, GiLoadInt16) { + std::vector s0{1, 2, -200, 32767, -32768, 45, 3, 0}; + GI_INT16_t ret; + + ret = GiLoadInt16(s0.data()); + + auto p = (int16_t*)&ret; + for (size_t i = 0; i < SIMD_LEN_16; i++) { + ASSERT_EQ(p[i], s0[i]); + } +} + +TEST_F(FALLBACK, GiLoadInt8) { + std::vector s0{9, 2, -128, 127, 2, 45, 3, 0, + 11, 2, -128, 127, 2, 55, 3, -1}; + GI_INT8_t ret; + + ret = GiLoadInt8(s0.data()); + + auto p = (int8_t*)&ret; + for (size_t i = 0; i < SIMD_LEN_8; i++) { + ASSERT_EQ(p[i], s0[i]); + } +} + +TEST_F(FALLBACK, GiStoreInt32) { + GI_INT32_t src0; + std::vector s0{1, 2, -200, 999}; + s0.resize(SIMD_LEN); + init((int32_t*)&src0, s0); + + std::vector ret; + ret.resize(SIMD_LEN); + GiStoreInt32(ret.data(), src0); + + assert_eq(ret.data(), s0); +} + +TEST_F(FALLBACK, GiStoreLaneXXInt32) { + GI_INT32_t src0; + std::vector s0{1, 2, -200, 999}; + s0.resize(SIMD_LEN); + init((int32_t*)&src0, s0); + + int32_t ret = 8888; + +#define CB(n) \ + GiStoreLane##n##Int32(&ret, src0); \ + ASSERT_EQ(s0[n], ret); + + CB(0) + CB(1) + CB(2) + CB(3) +} + +TEST_F(FALLBACK, GiReinterInt32ToInt8) { + GI_INT32_t src0; + GI_INT8_t ret, naive; + std::vector s0{65536, 2, -200, 999}; + s0.resize(SIMD_LEN); + init((int32_t*)&src0, s0); + + ret = GiReinterInt32ToInt8(src0); + naive = (GI_INT8_t)src0; + + ASSERT_FALSE(memcmp(&ret, &naive, sizeof(GI_INT8_t))); +} + +TEST_F(FALLBACK, GiStoreInt16) { + GI_INT16_t src0; + std::vector s0{32767, 2, -200, -32768, 1, 2, 3, 4}; + s0.resize(SIMD_LEN_16); + init((int16_t*)&src0, s0, SIMD_LEN_16); + + std::vector ret; + ret.resize(SIMD_LEN_16); + GiStoreInt16(ret.data(), src0); + + assert_eq(ret.data(), s0, SIMD_LEN_16); +} + +TEST_F(FALLBACK, GiStoreInt8) { + GI_INT8_t src0; + std::vector s0{127, 2, 56, -128, 1, 2, 3, 4, 127, 2, 56, -128, 1, 2, 3, 4}; + s0.resize(SIMD_LEN_8); + 
init((int8_t*)&src0, s0, SIMD_LEN_8); + + std::vector ret; + ret.resize(SIMD_LEN_8); + GiStoreInt8(ret.data(), src0); + + assert_eq(ret.data(), s0, SIMD_LEN_8); +} + +TEST_F(FALLBACK, GiStoreLowInt8) { + GI_INT8_t src0; + std::vector s0{127, 2, 56, -128, 1, 2, 3, 4, 127, 2, 56, -128, 1, 2, 3, 4}; + s0.resize(SIMD_LEN_8); + init((int8_t*)&src0, s0, SIMD_LEN_8); + + std::vector ret; + ret.resize(SIMD_LEN_8 / 2); + GiStoreLowInt8(ret.data(), src0); + + assert_eq(ret.data(), s0, SIMD_LEN_8 / 2); +} + +TEST_F(FALLBACK, GiStoreHihgInt8) { + GI_INT8_t src0; + std::vector s0{127, 2, 56, -128, 1, 2, 3, 4, 127, 2, 56, -128, 1, 2, 3, 4}; + s0.resize(SIMD_LEN_8); + init((int8_t*)&src0, s0, SIMD_LEN_8); + + std::vector ret; + ret.resize(SIMD_LEN_8 / 2); + GiStoreHihgInt8(ret.data(), src0); + + std::vector naive; + for (size_t i = 0; i < SIMD_LEN_8 / 2; i++) { + naive.push_back(s0[SIMD_LEN_8 / 2 + i]); + } + + assert_eq(ret.data(), naive, SIMD_LEN_8 / 2); +} + +TEST_F(FALLBACK, GiNegInt32) { + GI_INT32_t src0, ret; + std::vector s0{ + std::numeric_limits::max(), std::numeric_limits::min(), + -3, 4}; + s0.resize(SIMD_LEN); + init((int32_t*)&src0, s0); + + ret = GiNegInt32(src0); + + std::vector naive; + for (size_t i = 0; i < SIMD_LEN; i++) { + naive.push_back(-s0[i]); + } + + assert_eq((int32_t*)&ret, naive); +} + +TEST_F(FALLBACK, GiNegInt8) { + GI_INT8_t src0, ret; + std::vector s0{ + std::numeric_limits::max(), + std::numeric_limits::min(), + 56, + -128, + 1, + 2, + 3, + 4, + 127, + 2, + 56, + -128, + 1, + 2, + 3, + 4}; + s0.resize(SIMD_LEN_8); + init((int8_t*)&src0, s0, SIMD_LEN_8); + + ret = GiNegInt8(src0); + + std::vector naive; + for (size_t i = 0; i < SIMD_LEN_8; i++) { + naive.push_back(-s0[i]); + } + + assert_eq((int8_t*)&ret, naive, SIMD_LEN_8); +} + +TEST_F(FALLBACK, GiTestAndSetUint32) { + GI_UINT32_t src0, src1, ret; + std::vector s0{ + 8, 2, std::numeric_limits::max(), + std::numeric_limits::min()}; + std::vector s1{ + 8, 4, std::numeric_limits::max(), + std::numeric_limits::max()}; + s0.resize(SIMD_LEN); + s1.resize(SIMD_LEN); + init((uint32_t*)&src0, s0); + init((uint32_t*)&src1, s1); + + ret = GiTestAndSetUint32(src0, src1); + + std::vector naive; + for (size_t i = 0; i < SIMD_LEN; i++) { + naive.push_back(s0[i] & s1[i] ? 
0xFFFFFFFF : 0); + } + + assert_eq((uint32_t*)&ret, naive); +} + +TEST_F(FALLBACK, GiAddInt32) { + GI_INT32_t src0, src1, ret; + std::vector s0{127, 2, std::numeric_limits::max(), 9999}; + std::vector s1{1, 2, std::numeric_limits::max(), -9}; + s0.resize(SIMD_LEN); + s1.resize(SIMD_LEN); + init((int32_t*)&src0, s0); + init((int32_t*)&src1, s1); + + ret = GiAddInt32(src0, src1); + + std::vector naive; + for (size_t i = 0; i < SIMD_LEN; i++) { + naive.push_back(s0[i] + s1[i]); + } + + assert_eq((int32_t*)&ret, naive); +} + +TEST_F(FALLBACK, GiAddUint32) { + GI_UINT32_t src0, src1, ret; + std::vector s0{127, 2, std::numeric_limits::max(), 9999}; + std::vector s1{1, 2, std::numeric_limits::max(), 9}; + s0.resize(SIMD_LEN); + s1.resize(SIMD_LEN); + init((uint32_t*)&src0, s0); + init((uint32_t*)&src1, s1); + + ret = GiAddUint32(src0, src1); + + std::vector naive; + for (size_t i = 0; i < SIMD_LEN; i++) { + naive.push_back(s0[i] + s1[i]); + } + + assert_eq((uint32_t*)&ret, naive); +} + +TEST_F(FALLBACK, GiAddInt16) { + GI_INT16_t src0, src1, ret; + std::vector s0{-127, 2, std::numeric_limits::max(), 9999, 1, 2, + 3, 4}; + std::vector s1{1, + 2, + std::numeric_limits::max(), + std::numeric_limits::min(), + -1, + 23, + -3, + -5}; + s0.resize(SIMD_LEN_16); + s1.resize(SIMD_LEN_16); + init((int16_t*)&src0, s0, SIMD_LEN_16); + init((int16_t*)&src1, s1, SIMD_LEN_16); + + ret = GiAddInt16(src0, src1); + + std::vector naive; + for (size_t i = 0; i < SIMD_LEN_16; i++) { + naive.push_back(s0[i] + s1[i]); + } + + assert_eq((int16_t*)&ret, naive, SIMD_LEN_16); +} + +TEST_F(FALLBACK, GiAddInt8) { + GI_INT8_t src0, src1, ret; + std::vector s0{ + std::numeric_limits::max(), + std::numeric_limits::min(), + 56, + -128, + 1, + 2, + 3, + 4, + 127, + 2, + 56, + -128, + 1, + 2, + 3, + 4}; + std::vector s1{ + 3, + std::numeric_limits::max(), + std::numeric_limits::min(), + 56, + -128, + 1, + 2, + 3, + 4, + 127, + 2, + 56, + -128, + 1, + 2, + 4}; + s0.resize(SIMD_LEN_8); + s1.resize(SIMD_LEN_8); + init((int8_t*)&src0, s0, SIMD_LEN_8); + init((int8_t*)&src1, s1, SIMD_LEN_8); + + ret = GiAddInt8(src0, src1); + + std::vector naive; + for (size_t i = 0; i < SIMD_LEN_8; i++) { + naive.push_back(s0[i] + s1[i]); + } + + assert_eq((int8_t*)&ret, naive, SIMD_LEN_8); +} + +TEST_F(FALLBACK, GiSubtractInt32) { + GI_INT32_t src0, src1, ret; + std::vector s0{127, 2, std::numeric_limits::max(), 9999}; + std::vector s1{1, 2, std::numeric_limits::max(), -9}; + s0.resize(SIMD_LEN); + s1.resize(SIMD_LEN); + init((int32_t*)&src0, s0); + init((int32_t*)&src1, s1); + + ret = GiSubtractInt32(src0, src1); + + std::vector naive; + for (size_t i = 0; i < SIMD_LEN; i++) { + naive.push_back(s0[i] - s1[i]); + } + + assert_eq((int32_t*)&ret, naive); +} + +TEST_F(FALLBACK, GiSubtractUint32) { + GI_UINT32_t src0, src1, ret; + std::vector s0{127, 2, std::numeric_limits::max(), 9999}; + std::vector s1{1, 2, std::numeric_limits::max(), 9}; + s0.resize(SIMD_LEN); + s1.resize(SIMD_LEN); + init((uint32_t*)&src0, s0); + init((uint32_t*)&src1, s1); + + ret = GiSubtractUint32(src0, src1); + + std::vector naive; + for (size_t i = 0; i < SIMD_LEN; i++) { + naive.push_back(s0[i] - s1[i]); + } + + assert_eq((uint32_t*)&ret, naive); +} + +TEST_F(FALLBACK, GiSubtractInt8) { + GI_INT8_t src0, src1, ret; + std::vector s0{ + std::numeric_limits::max(), + std::numeric_limits::min(), + 56, + -128, + 1, + 2, + 3, + 4, + 127, + 2, + 56, + -128, + 1, + 2, + 3, + 4}; + std::vector s1{ + 3, + std::numeric_limits::max(), + std::numeric_limits::min(), + 56, + -128, + 1, + 2, + 
3, + 4, + 127, + 2, + 56, + -128, + 1, + 2, + 4}; + s0.resize(SIMD_LEN_8); + s1.resize(SIMD_LEN_8); + init((int8_t*)&src0, s0, SIMD_LEN_8); + init((int8_t*)&src1, s1, SIMD_LEN_8); + + ret = GiSubtractInt8(src0, src1); + + std::vector naive; + for (size_t i = 0; i < SIMD_LEN_8; i++) { + naive.push_back(s0[i] - s1[i]); + } + + assert_eq((int8_t*)&ret, naive, SIMD_LEN_8); +} + +TEST_F(FALLBACK, GiMultiplyInt32) { + GI_INT32_t src0, src1, ret; + std::vector s0{127, 2, 202204, 99}; + std::vector s1{1, 2, -4, -9}; + s0.resize(SIMD_LEN); + s1.resize(SIMD_LEN); + init((int32_t*)&src0, s0); + init((int32_t*)&src1, s1); + + ret = GiMultiplyInt32(src0, src1); + + std::vector naive; + for (size_t i = 0; i < SIMD_LEN; i++) { + naive.push_back(s0[i] * s1[i]); + } + + assert_eq((int32_t*)&ret, naive); +} + +TEST_F(FALLBACK, GiMultiplyInt8) { + GI_INT8_t src0, src1, ret; + std::vector s0{ + std::numeric_limits::max(), + std::numeric_limits::min(), + 56, + -128, + 1, + 2, + 3, + 4, + 127, + 2, + 56, + -128, + 1, + 2, + 3, + 4}; + std::vector s1{ + 3, + std::numeric_limits::max(), + std::numeric_limits::min(), + 56, + -128, + 1, + 2, + 3, + 4, + 127, + 2, + 56, + -128, + 1, + 2, + 4}; + s0.resize(SIMD_LEN_8); + s1.resize(SIMD_LEN_8); + init((int8_t*)&src0, s0, SIMD_LEN_8); + init((int8_t*)&src1, s1, SIMD_LEN_8); + + ret = GiMultiplyInt8(src0, src1); + + std::vector naive; + for (size_t i = 0; i < SIMD_LEN_8; i++) { + naive.push_back(s0[i] * s1[i]); + } + + assert_eq((int8_t*)&ret, naive, SIMD_LEN_8); +} + +TEST_F(FALLBACK, GiMultiplyAddInt32) { + GI_INT32_t src0, src1, src2, ret; + std::vector s0{127, 2, 67, 9999}; + std::vector s1{1, 2, 90, -9}; + std::vector s2{-1, 12, 4, -9}; + s0.resize(SIMD_LEN); + s1.resize(SIMD_LEN); + s2.resize(SIMD_LEN); + init((int32_t*)&src0, s0); + init((int32_t*)&src1, s1); + init((int32_t*)&src2, s2); + + ret = GiMultiplyAddInt32(src0, src1, src2); + + std::vector naive; + for (size_t i = 0; i < SIMD_LEN; i++) { + naive.push_back(s0[i] + s1[i] * s2[i]); + } + + assert_eq((int32_t*)&ret, naive); +} + +TEST_F(FALLBACK, GiMultiplyAddInt8) { + GI_INT8_t src0, src1, src2, ret; + std::vector s0{ + std::numeric_limits::max(), + std::numeric_limits::min(), + 56, + -128, + 1, + 2, + 3, + 4, + 127, + 2, + 56, + -128, + 1, + 2, + 3, + 4}; + std::vector s1{ + 3, + std::numeric_limits::max(), + std::numeric_limits::min(), + 56, + -128, + 1, + 2, + 3, + 4, + 127, + 2, + 56, + -128, + 1, + 2, + 4}; + std::vector s2{ + std::numeric_limits::min(), + 56, + -128, + 1, + 2, + 3, + 4, + 127, + 2, + 56, + -128, + 1, + 2, + 5, + 8, + 4}; + s0.resize(SIMD_LEN_8); + s1.resize(SIMD_LEN_8); + s2.resize(SIMD_LEN_8); + init((int8_t*)&src0, s0, SIMD_LEN_8); + init((int8_t*)&src1, s1, SIMD_LEN_8); + init((int8_t*)&src2, s2, SIMD_LEN_8); + + ret = GiMultiplyAddInt8(src0, src1, src2); + + std::vector naive; + for (size_t i = 0; i < SIMD_LEN_8; i++) { + naive.push_back(s0[i] + s1[i] * s2[i]); + } + + assert_eq((int8_t*)&ret, naive, SIMD_LEN_8); +} + +TEST_F(FALLBACK, GiAndInt8) { + GI_INT8_t src0, src1, ret; + std::vector s0{ + std::numeric_limits::max(), + std::numeric_limits::min(), + 56, + -128, + 1, + 2, + 3, + 4, + 127, + 2, + 56, + -128, + 1, + 2, + 3, + 4}; + std::vector s1{ + 3, + std::numeric_limits::max(), + std::numeric_limits::min(), + 56, + -128, + 1, + 2, + 3, + 4, + 127, + 2, + 56, + -128, + 1, + 2, + 4}; + s0.resize(SIMD_LEN_8); + s1.resize(SIMD_LEN_8); + init((int8_t*)&src0, s0, SIMD_LEN_8); + init((int8_t*)&src1, s1, SIMD_LEN_8); + + ret = GiAndInt8(src0, src1); + + std::vector naive; + for 
(size_t i = 0; i < SIMD_LEN_8; i++) { + naive.push_back(s0[i] & s1[i]); + } + + assert_eq((int8_t*)&ret, naive, SIMD_LEN_8); +} + +TEST_F(FALLBACK, GiEOrUint32) { + GI_UINT32_t src0, src1, ret; + std::vector s0{127, 2, std::numeric_limits::max(), 9999}; + std::vector s1{1, 2, std::numeric_limits::max(), 9}; + s0.resize(SIMD_LEN); + s1.resize(SIMD_LEN); + init((uint32_t*)&src0, s0); + init((uint32_t*)&src1, s1); + + ret = GiEOrUint32(src0, src1); + + std::vector naive; + for (size_t i = 0; i < SIMD_LEN; i++) { + naive.push_back(s0[i] ^ s1[i]); + } + + assert_eq((uint32_t*)&ret, naive); +} + +TEST_F(FALLBACK, GiOrInt8) { + GI_INT8_t src0, src1, ret; + std::vector s0{ + std::numeric_limits::max(), + std::numeric_limits::min(), + 56, + -128, + 1, + 2, + 3, + 4, + 127, + 2, + 56, + -128, + 1, + 2, + 3, + 4}; + std::vector s1{ + 3, + std::numeric_limits::max(), + std::numeric_limits::min(), + 56, + -128, + 1, + 2, + 3, + 4, + 127, + 2, + 56, + -128, + 1, + 2, + 4}; + s0.resize(SIMD_LEN_8); + s1.resize(SIMD_LEN_8); + init((int8_t*)&src0, s0, SIMD_LEN_8); + init((int8_t*)&src1, s1, SIMD_LEN_8); + + ret = GiOrInt8(src0, src1); + + std::vector naive; + for (size_t i = 0; i < SIMD_LEN_8; i++) { + naive.push_back(s0[i] | s1[i]); + } + + assert_eq((int8_t*)&ret, naive, SIMD_LEN_8); +} + +TEST_F(FALLBACK, GiAndNotInt8) { + GI_INT8_t src0, src1, ret; + std::vector s0{ + std::numeric_limits::max(), + std::numeric_limits::min(), + 56, + -128, + 1, + 2, + 3, + 4, + 127, + 2, + 56, + -128, + 1, + 2, + 3, + 4}; + std::vector s1{ + 3, + std::numeric_limits::max(), + std::numeric_limits::min(), + 56, + -128, + 1, + 2, + 3, + 4, + 127, + 2, + 56, + -128, + 1, + 2, + 4}; + s0.resize(SIMD_LEN_8); + s1.resize(SIMD_LEN_8); + init((int8_t*)&src0, s0, SIMD_LEN_8); + init((int8_t*)&src1, s1, SIMD_LEN_8); + + ret = GiAndNotInt8(src0, src1); + + std::vector naive; + for (size_t i = 0; i < SIMD_LEN_8; i++) { + naive.push_back((~s0[i]) & s1[i]); + } + + assert_eq((int8_t*)&ret, naive, SIMD_LEN_8); +} + +TEST_F(FALLBACK, GiXorInt8) { + GI_INT8_t src0, src1, ret; + std::vector s0{ + std::numeric_limits::max(), + std::numeric_limits::min(), + 56, + -128, + 1, + 2, + 3, + 4, + 127, + 2, + 56, + -128, + 1, + 2, + 3, + 4}; + std::vector s1{ + 3, + std::numeric_limits::max(), + std::numeric_limits::min(), + 56, + -128, + 1, + 2, + 3, + 4, + 127, + 2, + 56, + -128, + 1, + 2, + 4}; + s0.resize(SIMD_LEN_8); + s1.resize(SIMD_LEN_8); + init((int8_t*)&src0, s0, SIMD_LEN_8); + init((int8_t*)&src1, s1, SIMD_LEN_8); + + ret = GiXorInt8(src0, src1); + + std::vector naive; + for (size_t i = 0; i < SIMD_LEN_8; i++) { + naive.push_back((s0[i]) ^ s1[i]); + } + + assert_eq((int8_t*)&ret, naive, SIMD_LEN_8); +} + +TEST_F(FALLBACK, GiShiftRight23Int32) { + GI_INT32_t src0, ret; + std::vector s0{1, 2, 3, -4}; + s0.resize(SIMD_LEN); + init((int32_t*)&src0, s0); + + ret = GiShiftRight23Int32(src0); + + std::vector naive; + for (size_t i = 0; i < SIMD_LEN; i++) { + naive.push_back(s0[i] >> 23); + } + + assert_eq((int32_t*)&ret, naive); +} + +TEST_F(FALLBACK, GiBlendInt32) { + GI_INT32_t src0, src1, src2, ret, na; + std::vector s0{1, 2, 3, -4}; + std::vector s1{12, 22, 32, -43}; + std::vector s2{-1, 21, 34, 4}; + s0.resize(SIMD_LEN); + s1.resize(SIMD_LEN); + s2.resize(SIMD_LEN); + init((int32_t*)&src0, s0); + init((int32_t*)&src1, s1); + init((int32_t*)&src2, s2); + + ret = GiBlendInt32(src0, src1, src2); + + na = GiOrInt32(GiAndInt32(src1, src2), GiAndNotInt32(src2, src0)); + + std::vector naive; + auto p = (int32_t*)&na; + for (size_t i = 0; i < 
SIMD_LEN; i++) { + naive.push_back(p[i]); + } + + assert_eq((int32_t*)&ret, naive); +} + +TEST_F(FALLBACK, GiBlendInt8) { + GI_INT8_t src0, src1, src2, ret, na; + std::vector s0{ + std::numeric_limits::max(), + std::numeric_limits::min(), + 56, + -128, + 1, + 2, + 3, + 4, + 127, + 2, + 56, + -128, + 1, + 2, + 3, + 4}; + std::vector s1{ + 3, + std::numeric_limits::max(), + std::numeric_limits::min(), + 56, + -128, + 1, + 2, + 3, + 4, + 127, + 2, + 56, + -128, + 1, + 2, + 4}; + std::vector s2{ + std::numeric_limits::min(), + 56, + -128, + 1, + 2, + 3, + 4, + 127, + 2, + 56, + -128, + 1, + 2, + 5, + 8, + 4}; + s0.resize(SIMD_LEN_8); + s1.resize(SIMD_LEN_8); + s2.resize(SIMD_LEN_8); + init((int8_t*)&src0, s0, SIMD_LEN_8); + init((int8_t*)&src1, s1, SIMD_LEN_8); + init((int8_t*)&src2, s2, SIMD_LEN_8); + + ret = GiBlendInt8(src0, src1, src2); + na = GiOrInt8(GiAndInt8(src1, src2), GiAndNotInt8(src2, src0)); + + std::vector naive; + auto p = (int8_t*)&na; + for (size_t i = 0; i < SIMD_LEN_8; i++) { + naive.push_back(p[i]); + } + + assert_eq((int8_t*)&ret, naive, SIMD_LEN_8); +} + +TEST_F(FALLBACK, GiAbsInt32) { + GI_INT32_t src0, ret; + std::vector s0{-1, 2, -3, 4}; + s0.resize(SIMD_LEN); + init((int32_t*)&src0, s0); + + ret = GiAbsInt32(src0); + + std::vector naive; + for (size_t i = 0; i < SIMD_LEN; i++) { + naive.push_back(s0[i] > 0 ? s0[i] : -s0[i]); + } + + assert_eq((int32_t*)&ret, naive); +} + +TEST_F(FALLBACK, GiAbsInt16) { + GI_INT16_t src0, ret; + std::vector s0{-127, 2, std::numeric_limits::max(), 9999, 1, 2, + 3, 4}; + s0.resize(SIMD_LEN_16); + init((int16_t*)&src0, s0, SIMD_LEN_16); + + ret = GiAbsInt16(src0); + + std::vector naive; + for (size_t i = 0; i < SIMD_LEN_16; i++) { + naive.push_back(s0[i] > 0 ? s0[i] : -s0[i]); + } + + assert_eq((int16_t*)&ret, naive, SIMD_LEN_16); +} + +TEST_F(FALLBACK, GiAbsInt8) { + GI_INT8_t src0, ret; + std::vector s0{ + std::numeric_limits::max(), + std::numeric_limits::min(), + 56, + -128, + 1, + 2, + 3, + 4, + 127, + 2, + 56, + -128, + 1, + 2, + 3, + 4}; + s0.resize(SIMD_LEN_8); + init((int8_t*)&src0, s0, SIMD_LEN_8); + + ret = GiAbsInt8(src0); + + std::vector naive; + for (size_t i = 0; i < SIMD_LEN_8; i++) { + naive.push_back(s0[i] > 0 ? s0[i] : -s0[i]); + } + + assert_eq((int8_t*)&ret, naive, SIMD_LEN_8); +} + +TEST_F(FALLBACK, GiMaximumInt32) { + GI_INT32_t src0, src1, src2, ret, na; + std::vector s0{1, -2, 3, 4}; + s0.resize(SIMD_LEN); + std::vector s1{5, 6, 7, -8}; + s1.resize(SIMD_LEN); + init((int32_t*)&src0, s0); + init((int32_t*)&src1, s1); + + std::vector s2; + for (size_t i = 0; i < SIMD_LEN; i++) { + s2.push_back(s0[i] > s1[i] ? 0xFFFFFFFF : 0); + } + s2.resize(SIMD_LEN); + init((int32_t*)&src2, s2); + + ret = GiMaximumInt32(src0, src1); + + na = GiBlendInt32(src1, src0, src2); + std::vector naive; + auto p = (int32_t*)&na; + for (size_t i = 0; i < SIMD_LEN; i++) { + naive.push_back(p[i]); + } + + assert_eq((int32_t*)&ret, naive); +} + +TEST_F(FALLBACK, GiMinimumInt32) { + GI_INT32_t src0, src1, src2, ret, na; + std::vector s0{1, -2, 3, 4}; + s0.resize(SIMD_LEN); + std::vector s1{5, 6, 7, -8}; + s1.resize(SIMD_LEN); + init((int32_t*)&src0, s0); + init((int32_t*)&src1, s1); + + std::vector s2; + for (size_t i = 0; i < SIMD_LEN; i++) { + s2.push_back(s1[i] > s0[i] ? 
0xFFFFFFFF : 0); + } + s2.resize(SIMD_LEN); + init((int32_t*)&src2, s2); + + ret = GiMinimumInt32(src0, src1); + + na = GiBlendInt32(src1, src0, src2); + std::vector naive; + auto p = (int32_t*)&na; + for (size_t i = 0; i < SIMD_LEN; i++) { + naive.push_back(p[i]); + } + + assert_eq((int32_t*)&ret, naive); +} + +TEST_F(FALLBACK, GiBlendInt8x16) { + GI_INT8_t src0, src1, src2, ret, na; + std::vector s0{ + std::numeric_limits::max(), + std::numeric_limits::min(), + 56, + -128, + 1, + 2, + 3, + 4, + 127, + 2, + 56, + -128, + 1, + 2, + 3, + 4}; + std::vector s1{ + 3, + std::numeric_limits::max(), + std::numeric_limits::min(), + 56, + -128, + 1, + 2, + 3, + 4, + 127, + 2, + 56, + -128, + 1, + 2, + 4}; + std::vector s2{ + std::numeric_limits::min(), + 56, + -128, + 1, + 2, + 3, + 4, + 127, + 2, + 56, + -128, + 1, + 2, + 5, + 8, + 4}; + s0.resize(SIMD_LEN_8); + s1.resize(SIMD_LEN_8); + s2.resize(SIMD_LEN_8); + init((int8_t*)&src0, s0, SIMD_LEN_8); + init((int8_t*)&src1, s1, SIMD_LEN_8); + init((int8_t*)&src2, s2, SIMD_LEN_8); + + ret = GiBlendInt8x16(src0, src1, src2); + na = GiOrInt8(GiAndInt8(src1, src2), GiAndNotInt8(src2, src0)); + + std::vector naive; + auto p = (int8_t*)&na; + for (size_t i = 0; i < SIMD_LEN_8; i++) { + naive.push_back(p[i]); + } + + assert_eq((int8_t*)&ret, naive, SIMD_LEN_8); +} + +TEST_F(FALLBACK, GiMaximumInt8) { + GI_INT8_t src0, src1, src2, ret, na; + std::vector s0{ + std::numeric_limits::max(), + std::numeric_limits::min(), + 56, + -128, + 1, + 2, + 3, + 4, + 127, + 2, + 56, + -128, + 1, + 2, + 3, + 4}; + std::vector s1{ + 3, + std::numeric_limits::max(), + std::numeric_limits::min(), + 56, + -128, + 1, + 2, + 3, + 4, + 127, + 2, + 56, + -128, + 1, + 2, + 4}; + s0.resize(SIMD_LEN_8); + s1.resize(SIMD_LEN_8); + init((int8_t*)&src0, s0, SIMD_LEN_8); + init((int8_t*)&src1, s1, SIMD_LEN_8); + + std::vector s2; + for (size_t i = 0; i < SIMD_LEN_8; i++) { + s2.push_back(s1[i] < s0[i] ? 0xFF : 0); + } + s2.resize(SIMD_LEN); + init((int8_t*)&src2, s2, SIMD_LEN_8); + ret = GiMaximumInt8(src0, src1); + + na = GiBlendInt8(src1, src0, src2); + + std::vector naive; + auto p = (int8_t*)&na; + for (size_t i = 0; i < SIMD_LEN_8; i++) { + naive.push_back(p[i]); + } + + assert_eq((int8_t*)&ret, naive, SIMD_LEN_8); +} + +TEST_F(FALLBACK, GiMinimumInt8) { + GI_INT8_t src0, src1, src2, ret, na; + std::vector s0{ + std::numeric_limits::max(), + std::numeric_limits::min(), + 56, + -128, + 1, + 2, + 3, + 4, + 127, + 2, + 56, + -128, + 1, + 2, + 3, + 4}; + std::vector s1{ + 3, + std::numeric_limits::max(), + std::numeric_limits::min(), + 56, + -128, + 1, + 2, + 3, + 4, + 127, + 2, + 56, + -128, + 1, + 2, + 4}; + s0.resize(SIMD_LEN_8); + s1.resize(SIMD_LEN_8); + init((int8_t*)&src0, s0, SIMD_LEN_8); + init((int8_t*)&src1, s1, SIMD_LEN_8); + + std::vector s2; + for (size_t i = 0; i < SIMD_LEN_8; i++) { + s2.push_back(s1[i] > s0[i] ? 
0xFF : 0); + } + s2.resize(SIMD_LEN); + init((int8_t*)&src2, s2, SIMD_LEN_8); + ret = GiMinimumInt8(src0, src1); + + na = GiBlendInt8(src1, src0, src2); + + std::vector naive; + auto p = (int8_t*)&na; + for (size_t i = 0; i < SIMD_LEN_8; i++) { + naive.push_back(p[i]); + } + + assert_eq((int8_t*)&ret, naive, SIMD_LEN_8); +} + +TEST_F(FALLBACK, GiMoveHighLongInt8) { + GI_INT8_t src0; + GI_INT16_t ret; + + std::vector s0{ + std::numeric_limits::max(), + std::numeric_limits::min(), + 56, + -128, + 1, + 2, + 3, + 4, + 127, + 2, + 56, + -128, + std::numeric_limits::max(), + std::numeric_limits::min(), + 3, + 4}; + + s0.resize(SIMD_LEN_8); + init((int8_t*)&src0, s0, SIMD_LEN_8); + + ret = GiMoveHighLongInt8(src0); + + std::vector naive; + for (size_t i = 0; i < SIMD_LEN_8 / 2; i++) { + naive.push_back(s0[i + SIMD_LEN_8 / 2]); + } + + assert_eq((int16_t*)&ret, naive, SIMD_LEN_16); +} + +TEST_F(FALLBACK, GiMoveLowLongInt8) { + GI_INT8_t src0; + GI_INT16_t ret; + + std::vector s0{ + std::numeric_limits::max(), + std::numeric_limits::min(), + 56, + -128, + 1, + 2, + 3, + 4, + 127, + 2, + 56, + -128, + std::numeric_limits::max(), + std::numeric_limits::min(), + 3, + 4}; + + s0.resize(SIMD_LEN_8); + init((int8_t*)&src0, s0, SIMD_LEN_8); + + ret = GiMoveLowLongInt8(src0); + + std::vector naive; + for (size_t i = 0; i < SIMD_LEN_8 / 2; i++) { + naive.push_back(s0[i]); + } + + assert_eq((int16_t*)&ret, naive, SIMD_LEN_16); +} + +TEST_F(FALLBACK, GiMoveHighLongInt16) { + GI_INT16_t src0; + GI_INT32_t ret; + std::vector s0{-127, 2, std::numeric_limits::max(), 9999, 1, 2, + 3, 4}; + s0.resize(SIMD_LEN_16); + init((int16_t*)&src0, s0, SIMD_LEN_16); + + ret = GiMoveHighLongInt16(src0); + + std::vector naive; + for (size_t i = 0; i < SIMD_LEN_16 / 2; i++) { + naive.push_back(s0[i + SIMD_LEN_16 / 2]); + } + + assert_eq((int32_t*)&ret, naive, SIMD_LEN); +} + +TEST_F(FALLBACK, GiMoveLowLongInt16) { + GI_INT16_t src0; + GI_INT32_t ret; + std::vector s0{-127, 2, std::numeric_limits::max(), 9999, 1, 2, + 3, 4}; + s0.resize(SIMD_LEN_16); + init((int16_t*)&src0, s0, SIMD_LEN_16); + + ret = GiMoveLowLongInt16(src0); + + std::vector naive; + for (size_t i = 0; i < SIMD_LEN_16 / 2; i++) { + naive.push_back(s0[i]); + } + + assert_eq((int32_t*)&ret, naive, SIMD_LEN); +} + +TEST_F(FALLBACK, GiReduceAddInt8) { + GI_INT8_t src0; + int32_t ret{0}; + std::vector s0{127, 2, 56, -128, 1, 2, 3, 4, 127, 2, 56, -128, 1, 2, 3, 4}; + s0.resize(SIMD_LEN_8); + init((int8_t*)&src0, s0, SIMD_LEN_8); + + ret = GiReduceAddInt8(src0); + + int32_t naive{0}; + for (auto i : s0) { + naive += i; + } + + ASSERT_EQ(ret, naive); +} + +TEST_F(FALLBACK, GiReduceMaxInt8) { + GI_INT8_t src0; + int8_t ret{0}; + std::vector s0{127, 2, 56, -128, 1, 2, 3, 4, 127, 2, 56, -128, 1, 2, 3, 4}; + s0.resize(SIMD_LEN_8); + init((int8_t*)&src0, s0, SIMD_LEN_8); + + ret = GiReduceMaxInt8(src0); + + int8_t naive{s0[0]}; + for (size_t i = 0; i < SIMD_LEN_8; i++) { + naive = Max(naive, s0[i]); + } + + ASSERT_EQ(ret, naive); +} + +TEST_F(FALLBACK, GiReduceMinInt8) { + GI_INT8_t src0; + int8_t ret{0}; + std::vector s0{127, 2, 56, -128, 1, 2, 3, 4, 127, 2, 56, -128, 1, 2, 3, 4}; + s0.resize(SIMD_LEN_8); + init((int8_t*)&src0, s0, SIMD_LEN_8); + + ret = GiReduceMinInt8(src0); + + int8_t naive{s0[0]}; + for (size_t i = 0; i < SIMD_LEN_8; i++) { + naive = Min(naive, s0[i]); + } + + ASSERT_EQ(ret, naive); +} + +TEST_F(FALLBACK, GiCvtFromFloat32ToInt8) { + GI_INT8_t ret; + GI_FLOAT32_t src0; + std::vector s0{ + 1.0f, -2.2f, std::numeric_limits::max(), + 
std::numeric_limits<float>::min()};
+    s0.resize(SIMD_LEN);
+    init((float*)&src0, s0);
+
+    ret = GiCvtFromFloat32ToInt8(src0);
+
+    std::vector<int8_t> naive;
+    naive.resize(SIMD_LEN_8);
+
+    for (size_t i = 0; i < SIMD_LEN; i++) {
+        int8_t data = Saturate(round(s0[i]), -128, 127);
+        naive[i] = data;
+        naive[SIMD_LEN + i] = data;
+        naive[2 * SIMD_LEN + i] = data;
+        naive[3 * SIMD_LEN + i] = data;
+    }
+
+    assert_eq((int8_t*)&ret, naive, SIMD_LEN_8);
+}
+
+TEST_F(FALLBACK, GiCvtFromFloat32V2ToInt8) {
+    GI_INT8_t ret;
+    GI_FLOAT32_V2_t src0;
+    std::vector<float> s0{
+            1.0f,
+            -2.2f,
+            std::numeric_limits<float>::max(),
+            std::numeric_limits<float>::min(),
+            1.1f,
+            2.2f,
+            -9.0f,
+            899999.0f};
+    s0.resize(SIMD_LEN * 2);
+    init((float*)&src0, s0, SIMD_LEN * 2);
+
+    ret = GiCvtFromFloat32V2ToInt8(src0);
+
+    std::vector<int8_t> naive;
+
+    for (size_t i = 0; i < SIMD_LEN * 2; i++) {
+        naive.push_back(Saturate(round(s0[i]), -128, 127));
+    }
+
+    for (size_t i = 0; i < SIMD_LEN * 2; i++) {
+        naive.push_back(Saturate(round(s0[i]), -128, 127));
+    }
+
+    assert_eq((int8_t*)&ret, naive, SIMD_LEN_8);
+}
+
+TEST_F(FALLBACK, GiCvtFromFloat32V4ToInt8) {
+    GI_INT8_t ret;
+    GI_FLOAT32_V4_t src0;
+    std::vector<float> s0{
+            std::numeric_limits<float>::max(),
+            std::numeric_limits<float>::min(),
+            1.0f,
+            -2.2f,
+            3.1f,
+            4.2f,
+            -5.0f,
+            6.0f,
+            7.0f,
+            8.0f,
+            -9.9f,
+            10.9f,
+            -11.9f,
+            12.9f,
+            13.9f,
+            -14.9f};
+    s0.resize(SIMD_LEN * 4);
+    init((float*)&src0, s0, SIMD_LEN * 4);
+
+    ret = GiCvtFromFloat32V4ToInt8(src0);
+
+    std::vector<int8_t> naive;
+
+    for (size_t i = 0; i < SIMD_LEN * 4; i++) {
+        naive.push_back(Saturate(round(s0[i]), -128, 127));
+    }
+
+    assert_eq((int8_t*)&ret, naive, SIMD_LEN_8);
+}
+
+TEST_F(FALLBACK, GiCombineFloat32) {
+    float32x2_t src0, src1;
+    GI_FLOAT32_t ret;
+    std::vector<float> s0{1.1f, -3.1415f};
+    std::vector<float> s1{2.3f, 3.14777f};
+    memcpy(&src0, s0.data(), sizeof(float32x2_t));
+    memcpy(&src1, s1.data(), sizeof(float32x2_t));
+
+    ret = GiCombineFloat32(src0, src1);
+
+    std::vector<float> naive;
+    naive.push_back(s0[0]);
+    naive.push_back(s0[1]);
+    naive.push_back(s1[0]);
+    naive.push_back(s1[1]);
+
+    assert_eq((float*)&ret, naive);
+}
+
+TEST_F(FALLBACK, GiGetLowFloat32) {
+    float32x2_t ret;
+    GI_FLOAT32_t src0;
+    std::vector<float> s0{1.0f, 2.2f, 3.4f, 4.5f};
+    s0.resize(SIMD_LEN);
+    init((float*)&src0, s0);
+
+    ret = GiGetLowFloat32(src0);
+    auto r = (float*)&ret;
+
+    ASSERT_EQ(*r, s0[0]);
+    ASSERT_EQ(*(r + 1), s0[1]);
+}
+
+TEST_F(FALLBACK, GiGetHighFloat32) {
+    float32x2_t ret;
+    GI_FLOAT32_t src0;
+    std::vector<float> s0{1.0f, 2.2f, 3.4f, 4.5f};
+    s0.resize(SIMD_LEN);
+    init((float*)&src0, s0);
+
+    ret = GiGetHighFloat32(src0);
+    auto r = (float*)&ret;
+
+    ASSERT_EQ(*r, s0[2]);
+    ASSERT_EQ(*(r + 1), s0[3]);
+}
+
+}  // namespace test
+}  // namespace megdnn
+
+// vim: syntax=cpp.doxygen