GitOrigin-RevId: f250aa7b2a
tags/v1.9.0
@@ -95,8 +95,8 @@ typedef __m128i GI_INT16;
 typedef __m128i GI_INT32;
 #else
 typedef float GI_FLOAT32 __attribute__((vector_size(16)));
-typedef uint16_t GI_UINT8 __attribute__((vector_size(16)));
-typedef int16_t GI_INT8 __attribute__((vector_size(16)));
+typedef uint8_t GI_UINT8 __attribute__((vector_size(16)));
+typedef int8_t GI_INT8 __attribute__((vector_size(16)));
 typedef int16_t GI_INT16 __attribute__((vector_size(16)));
 typedef int32_t GI_INT32 __attribute__((vector_size(16)));
 #endif
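With GCC/Clang `vector_size` extensions the lane count follows from the element type, so the old `uint16_t`/`int16_t` element types silently gave `GI_UINT8`/`GI_INT8` eight 16-bit lanes instead of sixteen 8-bit ones. A minimal standalone sketch (assuming only the compiler vector extension, not the GI headers) that pins down the difference:

    #include <cstdint>
    // 16 bytes either way, but the element type decides the lane count.
    typedef int8_t fixed_gi_int8 __attribute__((vector_size(16)));    // 16 x int8
    typedef int16_t broken_gi_int8 __attribute__((vector_size(16)));  // 8 x int16
    static_assert(sizeof(fixed_gi_int8) / sizeof(int8_t) == 16, "16 lanes after the fix");
    static_assert(sizeof(broken_gi_int8) / sizeof(int16_t) == 8, "only 8 lanes before");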
@@ -119,6 +119,9 @@ typedef int32_t GI_INT32 __attribute__((vector_size(16)));
 #define GI_SIMD_LEN_BYTE 16
 #endif
+#define Max(a, b) (a) > (b) ? (a) : (b)
+#define Min(a, b) (a) < (b) ? (a) : (b)
 typedef struct {
     GI_INT32 val[2];
 } GI_INT32_V2;
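The `Max`/`Min` helpers (moved here from the reduce code, see the removal in the `@@ -447,16 +453,13 @@` hunk below) expand to a bare ternary with no outer parentheses, so they only compose safely when used as a full expression. A small sketch of the pitfall, using the macro exactly as defined above:

    #include <cstdio>
    #define Max(a, b) (a) > (b) ? (a) : (b)
    int main() {
        int v = Max(3, 2) + 10;  // parses as (3) > (2) ? (3) : ((2) + 10)
        std::printf("%d\n", v);  // prints 3, not the expected 13
        return 0;
    }

Every call site in this patch uses the macros as standalone expressions or as macro arguments (which are re-parenthesized on expansion), where this is harmless.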
@@ -223,7 +223,7 @@ GiInterleaveLowFloat32(GI_FLOAT32 Vector1, GI_FLOAT32 Vector2) {
 #if defined(GI_NEON64_INTRINSICS)
     return vzip1q_f32(Vector1, Vector2);
 #elif defined(GI_NEON32_INTRINSICS)
-    float32x2_t zipped = vzipq_f32(Vector1, Vector2);
+    float32x4x2_t zipped = vzipq_f32(Vector1, Vector2);
     return zipped.val[0];
 #elif defined(GI_SSE2_INTRINSICS)
     return _mm_unpacklo_ps(Vector1, Vector2);
@@ -243,7 +243,7 @@ GiInterleaveHighFloat32(GI_FLOAT32 Vector1, GI_FLOAT32 Vector2) {
 #if defined(GI_NEON64_INTRINSICS)
     return vzip2q_f32(Vector1, Vector2);
 #elif defined(GI_NEON32_INTRINSICS)
-    float32x2_t zipped = vzipq_f32(Vector1, Vector2);
+    float32x4x2_t zipped = vzipq_f32(Vector1, Vector2);
     return zipped.val[1];
 #elif defined(GI_SSE2_INTRINSICS)
     return _mm_unpackhi_ps(Vector1, Vector2);
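On AArch32, `vzipq_f32` returns a `float32x4x2_t` (a pair of 4-lane vectors), so the old `float32x2_t` declaration did not even compile. For reference, the interleave semantics the two functions implement, written as a plain C++ sketch (hypothetical helper names, no intrinsics):

    #include <array>
    using Vec4 = std::array<float, 4>;
    // GiInterleaveLowFloat32  == zipped.val[0] == _mm_unpacklo_ps
    Vec4 interleave_low(const Vec4& a, const Vec4& b) { return {a[0], b[0], a[1], b[1]}; }
    // GiInterleaveHighFloat32 == zipped.val[1] == _mm_unpackhi_ps
    Vec4 interleave_high(const Vec4& a, const Vec4& b) { return {a[2], b[2], a[3], b[3]}; }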
@@ -460,7 +460,14 @@ GiMaximumFloat32(GI_FLOAT32 Vector1, GI_FLOAT32 Vector2) {
 #if defined(GI_NEON_INTRINSICS)
     return vmaxq_f32(Vector1, Vector2);
 #elif defined(GI_SSE2_INTRINSICS)
-    return _mm_max_ps(Vector1, Vector2);
+    //! _mm_max_ps does not follow the IEEE standard when an input is NaN, so
+    //! implement it in plain C code
+#define MAX_NAN(a, b) (std::isnan(a) || (a) > (b)) ? (a) : (b);
+    GI_FLOAT32 max;
+    for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(float); i++) {
+        max[i] = MAX_NAN(Vector1[i], Vector2[i]);
+    }
+    return max;
 #else
     return GiBlendFloat32(Vector2, Vector1, Vector1 > Vector2);
 #endif
@@ -473,6 +480,14 @@ GiMinimumFloat32(GI_FLOAT32 Vector1, GI_FLOAT32 Vector2) {
     return vminq_f32(Vector1, Vector2);
 #elif defined(GI_SSE2_INTRINSICS)
-    return _mm_min_ps(Vector1, Vector2);
+    //! _mm_min_ps does not follow the IEEE standard when an input is NaN, so
+    //! implement it in plain C code
+#define MIN_NAN(a, b) (std::isnan(a) || (a) < (b)) ? (a) : (b);
+    GI_FLOAT32 min;
+    for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(float); i++) {
+        min[i] = MIN_NAN(Vector1[i], Vector2[i]);
+    }
+    return min;
 #else
     return GiBlendFloat32(Vector2, Vector1, Vector2 > Vector1);
 #endif
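`_mm_max_ps`/`_mm_min_ps` compute `a OP b ? a : b` per lane, and every ordered comparison involving a NaN is false, so the intrinsics return the second operand whenever either input is NaN. The scalar loops above instead let a NaN in either operand propagate to the result, matching the NEON `vmaxq_f32`/`vminq_f32` behavior the GI layer emulates. A standalone sketch of the difference:

    #include <cmath>
    #include <cstdio>
    int main() {
        float a = std::nanf(""), b = 1.0f;
        float sse_like = (a > b) ? a : b;                  // what _mm_max_ps does: 1.0
        float gi_like = (std::isnan(a) || a > b) ? a : b;  // what MAX_NAN does: nan
        std::printf("%f vs %f\n", sse_like, gi_like);
        return 0;
    }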
@@ -97,7 +97,7 @@ void GiStoreInt8(int8_t* Buffer, GI_INT8 Vector) {
 #elif defined(GI_SSE2_INTRINSICS)
     _mm_storeu_si128((__m128i*)Buffer, Vector);
 #else
-    for (int i = 0; i < 16; i++) {
+    for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int8_t); i++) {
         Buffer[i] = Vector[i];
     }
 #endif
@@ -197,7 +197,8 @@ GiAndNotInt8(GI_INT8 VectorNot, GI_INT8 Vector) {
 #elif defined(GI_SSE2_INTRINSICS)
     return _mm_andnot_si128(VectorNot, Vector);
 #else
-    return (~VectorNot) & Vector;
+    GI_INT8 Not = ~VectorNot;
+    return (Not & Vector);
 #endif
 }
@@ -327,11 +328,13 @@ GiMoveHighLongInt8(GI_INT8 Vector) {
     for (int i = 0; i < 8; i++) {
         data[i] = o_data[8 + i];
     }
-    return _mm_loadu_si16(data);
+    return _mm_loadu_si128((__m128i*)data);
 #else
     GI_INT16 ret;
-    for (size_t i = 0; i < GI_SIMD_LEN_BYTE / 2 / sizeof(int8_t); i++) {
-        ret[i] = Vector[GI_SIMD_LEN_BYTE / 2 + i];
+    int8_t* data = (int8_t*)&Vector;
+    size_t half_length = GI_SIMD_LEN_BYTE / 2 / sizeof(int8_t);
+    for (size_t i = 0; i < half_length; i++) {
+        ret[i] = data[i + half_length];
     }
     return ret;
 #endif
@@ -351,10 +354,11 @@ GiMoveLowLongInt8(GI_INT8 Vector) {
     for (int i = 0; i < 8; i++) {
         data[i] = o_data[i];
     }
-    return _mm_loadu_si16(data);
+    return _mm_loadu_si128((__m128i*)data);
 #else
     GI_INT16 ret;
-    for (size_t i = 0; i < GI_SIMD_LEN_BYTE / 2 / sizeof(int8_t); i++) {
+    size_t half_length = GI_SIMD_LEN_BYTE / 2 / sizeof(int8_t);
+    for (size_t i = 0; i < half_length; i++) {
         ret[i] = Vector[i];
     }
     return ret;
@@ -375,11 +379,12 @@ GiMoveHighLongInt16(GI_INT16 Vector) {
     for (int i = 0; i < 4; i++) {
         data[i] = o_data[4 + i];
     }
-    return _mm_loadu_si32(data);
+    return _mm_loadu_si128((__m128i*)data);
 #else
     GI_INT32 ret;
-    for (size_t i = 0; i < GI_SIMD_LEN_BYTE / 2 / sizeof(int16_t); i++) {
-        ret[i] = Vector[GI_SIMD_LEN_BYTE / 2 + i];
+    size_t half_length = GI_SIMD_LEN_BYTE / 2 / sizeof(int16_t);
+    for (size_t i = 0; i < half_length; i++) {
+        ret[i] = Vector[half_length + i];
     }
     return ret;
 #endif
@@ -399,10 +404,11 @@ GiMoveLowLongInt16(GI_INT16 Vector) {
     for (int i = 0; i < 4; i++) {
         data[i] = o_data[i];
     }
-    return _mm_loadu_si32(data);
+    return _mm_loadu_si128((__m128i*)data);
 #else
     GI_INT32 ret;
-    for (size_t i = 0; i < GI_SIMD_LEN_BYTE / 2 / sizeof(int16_t); i++) {
+    size_t half_length = GI_SIMD_LEN_BYTE / 2 / sizeof(int16_t);
+    for (size_t i = 0; i < half_length; i++) {
         ret[i] = Vector[i];
     }
     return ret;
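The four `GiMove{High,Low}Long*` hunks above fix two separate bugs: the SSE2 paths loaded back only 2 or 4 bytes (`_mm_loadu_si16`/`_mm_loadu_si32`) of the 16-byte widened result, and the generic `Int16` paths indexed by the byte offset (8) instead of the element offset (4), reading past the vector. The intended semantics, equivalent to NEON's `vmovl_s8(vget_high_s8(v))`, as a plain sketch:

    #include <cstdint>
    // Widen the upper eight int8 lanes to int16 (sign-extending).
    void move_high_long_i8(const int8_t src[16], int16_t dst[8]) {
        for (int i = 0; i < 8; i++)
            dst[i] = src[8 + i];
    }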
@@ -414,7 +420,7 @@ int16_t GiReduceAddInt8(GI_INT8 Vector) {
 #if defined(GI_NEON64_INTRINSICS)
     return vaddlvq_s8(Vector);
 #elif defined(GI_NEON32_INTRINSICS)
-    int32_t sum = vpaddlq_s16(vpaddlq_s8(Vector));
+    int32x4_t sum = vpaddlq_s16(vpaddlq_s8(Vector));
     return (vgetq_lane_s32(sum, 0) + vgetq_lane_s32(sum, 1) + vgetq_lane_s32(sum, 2) +
             vgetq_lane_s32(sum, 3));
 #elif defined(GI_SSE42_INTRINSICS)
@@ -431,8 +437,8 @@ int16_t GiReduceAddInt8(GI_INT8 Vector) {
     return (int16_t)(ret);
 #elif defined(GI_SSE2_INTRINSICS)
-    __m64 low = GiGetLowInt8x16(Vector);
-    __m64 high = GiGetHighInt8x16(Vector);
+    __m64 low = _mm_movepi64_pi64(Vector);
+    __m64 high = _mm_movepi64_pi64(_mm_unpackhi_epi64(Vector, Vector));
     __m128 v0 = _mm_cvtpi8_ps(low);
     __m128 v1 = _mm_cvtpi8_ps(_mm_unpackhi_pi32(low, low));
     __m128 v2 = _mm_cvtpi8_ps(high);
@@ -447,16 +453,13 @@ int16_t GiReduceAddInt8(GI_INT8 Vector) {
     return (int16_t)(ret0 + ret1 + ret2 + ret3);
 #else
     int32_t sum = 0;
-    for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int32_t); i++) {
+    for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int8_t); i++) {
         sum += Vector[i];
     }
     return sum;
 #endif
 }
-#define Max(a, b) (a) > (b) ? (a) : (b)
-#define Min(a, b) (a) < (b) ? (a) : (b)
 GI_FORCEINLINE
 int8_t GiReduceMaxInt8(GI_INT8 Vector) {
 #if defined(GI_NEON64_INTRINSICS)
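Two fixes in the `GiReduceAddInt8` hunks: on AArch32 the widening chain `vpaddlq_s16(vpaddlq_s8(v))` yields an `int32x4_t`, not a scalar `int32_t`, and the generic loop must visit all `GI_SIMD_LEN_BYTE / sizeof(int8_t)` = 16 lanes (dividing by `sizeof(int32_t)` summed only the first four). A sketch of one `vpaddl` step, which shows why the intermediate stays a vector:

    #include <cstdint>
    // vpaddlq_s8: adjacent int8 pairs are summed into widened int16 lanes,
    // so 16 x i8 -> 8 x i16; the vpaddlq_s16 step then gives 4 x i32.
    void paddl_s8(const int8_t in[16], int16_t out[8]) {
        for (int i = 0; i < 8; i++)
            out[i] = (int16_t)in[2 * i] + (int16_t)in[2 * i + 1];
    }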
@@ -480,23 +483,23 @@ int8_t GiReduceMaxInt8(GI_INT8 Vector) {
     ret = Max(_mm_extract_epi32(sum, 3), ret);
     return (int8_t)ret;
 #elif defined(GI_SSE2_INTRINSICS)
-    __m64 low = GiGetLowInt8x16(Vector);
-    __m64 high = GiGetHighInt8x16(Vector);
+    __m64 low = _mm_movepi64_pi64(Vector);
+    __m64 high = _mm_movepi64_pi64(_mm_unpackhi_epi64(Vector, Vector));
     __m128 v0 = _mm_cvtpi8_ps(low);
     __m128 v1 = _mm_cvtpi8_ps(_mm_unpackhi_pi32(low, low));
     __m128 v2 = _mm_cvtpi8_ps(high);
     __m128 v3 = _mm_cvtpi8_ps(_mm_unpackhi_pi32(high, high));
-    __m128 sum0 = _mm_add_ps(v0, v1);
-    __m128 sum1 = _mm_add_ps(v2, v3);
-    __m128 sum = _mm_add_ps(sum0, sum1);
-    float ret0 = _mm_cvtss_f32(sum);
-    float ret1 = _mm_cvtss_f32(_mm_shuffle_ps(sum, sum, _MM_SHUFFLE(1, 1, 1, 1)));
-    float ret2 = _mm_cvtss_f32(_mm_shuffle_ps(sum, sum, _MM_SHUFFLE(2, 2, 2, 2)));
-    float ret3 = _mm_cvtss_f32(_mm_shuffle_ps(sum, sum, _MM_SHUFFLE(3, 3, 3, 3)));
+    __m128 max0 = _mm_max_ps(v0, v1);
+    __m128 max1 = _mm_max_ps(v2, v3);
+    __m128 max = _mm_max_ps(max0, max1);
+    float ret0 = _mm_cvtss_f32(max);
+    float ret1 = _mm_cvtss_f32(_mm_shuffle_ps(max, max, _MM_SHUFFLE(1, 1, 1, 1)));
+    float ret2 = _mm_cvtss_f32(_mm_shuffle_ps(max, max, _MM_SHUFFLE(2, 2, 2, 2)));
+    float ret3 = _mm_cvtss_f32(_mm_shuffle_ps(max, max, _MM_SHUFFLE(3, 3, 3, 3)));
     return (int8_t)(Max(Max(ret0, ret1), Max(ret2, ret3)));
 #else
     int8_t max = Vector[0];
-    for (size_t i = 1; i < GI_SIMD_LEN_BYTE / sizeof(int32_t); i++) {
+    for (size_t i = 1; i < GI_SIMD_LEN_BYTE / sizeof(int8_t); i++) {
         max = Max(max, Vector[i]);
     }
     return max;
@@ -526,23 +529,23 @@ int8_t GiReduceMinInt8(GI_INT8 Vector) {
     ret = Min(_mm_extract_epi32(sum, 3), ret);
     return (int8_t)ret;
 #elif defined(GI_SSE2_INTRINSICS)
-    __m64 low = GiGetLowInt8x16(Vector);
-    __m64 high = GiGetHighInt8x16(Vector);
+    __m64 low = _mm_movepi64_pi64(Vector);
+    __m64 high = _mm_movepi64_pi64(_mm_unpackhi_epi64(Vector, Vector));
     __m128 v0 = _mm_cvtpi8_ps(low);
     __m128 v1 = _mm_cvtpi8_ps(_mm_unpackhi_pi32(low, low));
     __m128 v2 = _mm_cvtpi8_ps(high);
     __m128 v3 = _mm_cvtpi8_ps(_mm_unpackhi_pi32(high, high));
-    __m128 sum0 = _mm_add_ps(v0, v1);
-    __m128 sum1 = _mm_add_ps(v2, v3);
-    __m128 sum = _mm_add_ps(sum0, sum1);
-    float ret0 = _mm_cvtss_f32(sum);
-    float ret1 = _mm_cvtss_f32(_mm_shuffle_ps(sum, sum, _MM_SHUFFLE(1, 1, 1, 1)));
-    float ret2 = _mm_cvtss_f32(_mm_shuffle_ps(sum, sum, _MM_SHUFFLE(2, 2, 2, 2)));
-    float ret3 = _mm_cvtss_f32(_mm_shuffle_ps(sum, sum, _MM_SHUFFLE(3, 3, 3, 3)));
+    __m128 min0 = _mm_min_ps(v0, v1);
+    __m128 min1 = _mm_min_ps(v2, v3);
+    __m128 min = _mm_min_ps(min0, min1);
+    float ret0 = _mm_cvtss_f32(min);
+    float ret1 = _mm_cvtss_f32(_mm_shuffle_ps(min, min, _MM_SHUFFLE(1, 1, 1, 1)));
+    float ret2 = _mm_cvtss_f32(_mm_shuffle_ps(min, min, _MM_SHUFFLE(2, 2, 2, 2)));
+    float ret3 = _mm_cvtss_f32(_mm_shuffle_ps(min, min, _MM_SHUFFLE(3, 3, 3, 3)));
     return (int8_t)(Min(Min(ret0, ret1), Min(ret2, ret3)));
 #else
     int8_t min = Vector[0];
-    for (size_t i = 1; i < GI_SIMD_LEN_BYTE / sizeof(int32_t); i++) {
+    for (size_t i = 1; i < GI_SIMD_LEN_BYTE / sizeof(int8_t); i++) {
         min = Min(min, Vector[i]);
     }
     return min;
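The SSE2 fallbacks of `GiReduceMaxInt8`/`GiReduceMinInt8` had been copy-pasted from `GiReduceAddInt8` and still summed the lanes; the fix swaps `_mm_add_ps` for `_mm_max_ps`/`_mm_min_ps` and folds the final four lanes with `Max`/`Min`. The generic loop bound gets the same `sizeof(int32_t)` to `sizeof(int8_t)` correction, so it now scans all 16 lanes; its reference semantics:

    #include <algorithm>
    #include <cstdint>
    // 16 == GI_SIMD_LEN_BYTE / sizeof(int8_t)
    int8_t reduce_max_i8(const int8_t v[16]) {
        int8_t m = v[0];
        for (int i = 1; i < 16; i++)
            m = std::max(m, v[i]);
        return m;
    }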
@@ -561,8 +564,7 @@ GiCvtFromFloat32ToInt8(GI_FLOAT32 src) {
 #if __ARM_ARCH >= 8
     int32x4_t vres0 = vcvtaq_s32_f32(src);
     int16x8_t mid_s16 = vcombine_s16(vqmovn_s32(vres0), vqmovn_s32(vres0));
-    int8x8_t ret = vqmovn_s16(vcombine_s16(vqmovn_s32(mid_s16), vqmovn_s32(mid_s16)));
-    return vcombine_s16(ret, ret);
+    return vcombine_s8(vqmovn_s16(mid_s16), vqmovn_s16(mid_s16));
 #else
     float32x4_t vzero = vdupq_n_f32(0.f);
     float32x4_t vfhalf = vdupq_n_f32(0.5f);
@@ -570,8 +572,7 @@ GiCvtFromFloat32ToInt8(GI_FLOAT32 src) {
     float32x4_t vinc0 = vbslq_f32(vcgeq_f32(src, vzero), vfhalf, vfneg_half);
     int32x4_t vres0 = vcvtq_s32_f32(vaddq_f32(src, vinc0));
     int16x8_t mid_s16 = vcombine_s16(vqmovn_s32(vres0), vqmovn_s32(vres0));
-    int8x8_t ret = vqmovn_s16(vcombine_s16(vqmovn_s32(mid_s16), vqmovn_s32(mid_s16)));
-    return vcombine_s16(ret, ret);
+    return vcombine_s8(vqmovn_s16(mid_s16), vqmovn_s16(mid_s16));
 #endif
 #elif defined(GI_SSE42_INTRINSICS)
     __m128 vfzero = _mm_set1_ps(0.f);
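The old epilogue fed an `int16x8_t` to `vqmovn_s32` (a type error) and returned through `vcombine_s16` where an `int8x16_t` is expected; the rewrite narrows once with `vqmovn_s16` and combines with `vcombine_s8`. The overall conversion is float, then round-to-nearest (ties away from zero), then saturating narrow to i16, then to i8; a scalar sketch of that chain:

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    int8_t cvt_f32_to_i8(float x) {
        long r = std::lround(x);                     // rounds ties away from zero
        r = std::min(32767L, std::max(-32768L, r));  // vqmovn_s32 saturation
        r = std::min(127L, std::max(-128L, r));      // vqmovn_s16 saturation
        return (int8_t)r;
    }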
@@ -0,0 +1,81 @@
+/**
+ * \file dnn/src/fallback/quantized_converter.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#pragma once
+
+#include "megdnn/dtype.h"
+#include "megdnn/oprs.h"
+#include "src/common/utils.h"
+#include "src/fallback/general_intrinsic/gi_float.h"
+#include "src/fallback/general_intrinsic/gi_int.h"
+
+namespace megdnn {
+namespace fallback {
+
+struct QConverterBase {
+    inline static GI_INT32 vzero() { return GiBroadcastInt32(0); }
+    inline static GI_FLOAT32 vfzero() { return GiBroadcastFloat32(0.f); }
+    inline static GI_FLOAT32 vfhalf() { return GiBroadcastFloat32(0.5f); }
+    inline static GI_FLOAT32 vfneg_half() { return GiBroadcastFloat32(-0.5f); }
+};
+
+struct QConverter {
+    template <typename dst_type, typename... src_type>
+    static inline dst_type convert(const src_type&... src);
+
+    template <typename dst_type, typename... src_type>
+    static inline dst_type round(const src_type&... src);
+};
+
+template <>
+inline dt_qint8 QConverter::convert(const float& src) {
+    return dt_qint8(saturate<int8_t, float>(std::round(src), -128, 127));
+}
+
+template <>
+inline dt_quint8 QConverter::convert(const float& src, const uint8_t& zp) {
+    return dt_quint8(saturate<uint8_t, float>(std::round(src) + zp, 0, 255));
+}
+
+template <>
+inline dt_qint32 QConverter::convert(const float& src) {
+    return dt_qint32(saturate<int32_t, float>(
+            std::round(src), static_cast<float>(std::numeric_limits<int32_t>::min()),
+            static_cast<float>(std::numeric_limits<int32_t>::max())));
+}
+
+template <>
+inline GI_FLOAT32_V2 QConverter::convert(const GI_INT16& vsrc) {
+    GI_INT32 vhi = GiMoveHighLongInt16(vsrc);
+    GI_INT32 vlo = GiMoveLowLongInt16(vsrc);
+    return {{GiCastToFloat32(vlo), GiCastToFloat32(vhi)}};
+}
+
+template <>
+inline GI_INT8 QConverter::convert(const GI_FLOAT32_V2& vsrc) {
+    return GiCvtFromFloat32V2ToInt8(vsrc);
+}
+
+template <>
+inline GI_INT8 QConverter::convert(const GI_FLOAT32& src) {
+    return GiCvtFromFloat32ToInt8(src);
+}
+
+template <>
+inline GI_INT32 QConverter::round(const GI_FLOAT32& vsrc) {
+    return GiRoundAsInt32(vsrc);
+}
+
+} // namespace fallback
+} // namespace megdnn
+
+// vim: syntax=cpp.doxygen
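`QConverter` is a thin tag-dispatch wrapper: the destination type is chosen via the explicit template argument and the sources are deduced. A hypothetical call-site sketch (the values and variable names are invented for illustration; requires the megdnn fallback headers):

    #include "src/fallback/quantized_converter.h"
    using namespace megdnn;
    using namespace fallback;
    // Scalar: round to nearest, then saturate into the quantized range.
    dt_qint8 q = QConverter::convert<dt_qint8>(300.2f);               // -> 127
    dt_quint8 u = QConverter::convert<dt_quint8>(-3.7f, uint8_t(3));  // zp 3 -> 0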
@@ -14,11 +14,13 @@
 #include "src/naive/handle.h"
 #include "midout.h"
+#include "reducer.h"
 #include "src/common/reduce_helper.h"
 MIDOUT_DECL(megdnn_fb_reduce_op)
 MIDOUT_DECL(megdnn_fb_reduce_c)
 MIDOUT_DECL(megdnn_fb_reduce_dtype)
+MIDOUT_DECL(megdnn_fallback_reduce_optimized)
 namespace {
@@ -77,11 +79,20 @@ namespace fallback {
 void ReduceImpl::exec(
         _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
     check_exec(src.layout, dst.layout, workspace.size);
+    if (!exec_optimized(src, dst, workspace)) {
+        return exec_fallback(src, dst, workspace);
+    }
+}
+
+void ReduceImpl::exec_fallback(
+        _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
     using namespace reduce;
     using Mode = Param::Mode;
+    check_exec(src.layout, dst.layout, workspace.size);
     size_t A, B, C;
     get_ABC(src.layout, A, B, C, param().axis);
 #define cb_by_op(src_type, dst_type, _wtype, mode_, Op_, kern_func) \
     if (param().mode == mode_) { \
         typedef DTypeTrait<src_type>::ctype src_ctype; \
@@ -176,6 +187,101 @@ void ReduceImpl::exec(
     naive::ReduceForwardImpl::exec(src, dst, workspace);
 }
+
+bool ReduceImpl::exec_optimized(
+        _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace) {
+    size_t A, B, C;
+    reduce::get_ABC(src.layout, A, B, C, param().axis);
+    bool execed = false;
+    using Mode = param::Reduce::Mode;
+#define DISPATCH_FUNC(Reducer, dtype, ctype, comp_type) \
+    if (C == 1) { \
+        using _Reducer = Reducer<dtype, ctype, comp_type, true>; \
+        std::function<void(const ctype*, ctype*, DType, size_t, size_t, size_t)> \
+                do_reduce = Exec<_Reducer, true>::do_reduce; \
+        MIDOUT_BEGIN( \
+                megdnn_fallback_reduce_optimized, ctype, dtype, comp_type, \
+                midout_iv(0)) { \
+            MEGDNN_DISPATCH_CPU_KERN_OPR(do_reduce( \
+                    reinterpret_cast<ctype*>(src.raw_ptr()), \
+                    reinterpret_cast<ctype*>(dst.raw_ptr()), src_type, A, B, C)); \
+            execed = true; \
+        } \
+        MIDOUT_END(); \
+    } else { \
+        using _Reducer = Reducer<dtype, ctype, comp_type, false>; \
+        std::function<void(const ctype*, ctype*, DType, size_t, size_t, size_t)> \
+                do_reduce = Exec<_Reducer, false>::do_reduce; \
+        MIDOUT_BEGIN( \
+                megdnn_fallback_reduce_optimized, ctype, dtype, comp_type, \
+                midout_iv(1)) { \
+            MEGDNN_DISPATCH_CPU_KERN_OPR(do_reduce( \
+                    reinterpret_cast<ctype*>(src.raw_ptr()), \
+                    reinterpret_cast<ctype*>(dst.raw_ptr()), src_type, A, B, C)); \
+            execed = true; \
+        } \
+        MIDOUT_END(); \
+    }
+#define DISPATCH_MODE_QUANTIZED(dtype, ctype, comp_type) \
+    switch (param().mode) { \
+        case Mode::MEAN: \
+            DISPATCH_FUNC(MeanReducer, dtype, ctype, comp_type); \
+            break; \
+        case Mode::MAX: \
+            DISPATCH_FUNC(maxReducer, dtype, ctype, ctype); \
+            break; \
+        case Mode::MIN: \
+            DISPATCH_FUNC(minReducer, dtype, ctype, ctype); \
+            break; \
+        default: \
+            break; \
+    }
+#define DISPATCH_MODE_FLOAT(dtype, ctype, comp_type) \
+    switch (param().mode) { \
+        case Mode::MEAN: \
+            DISPATCH_FUNC(MeanReducer, dtype, ctype, comp_type); \
+            break; \
+        case Mode::MAX: \
+            DISPATCH_FUNC(maxReducer, dtype, ctype, ctype); \
+            break; \
+        case Mode::MIN: \
+            DISPATCH_FUNC(minReducer, dtype, ctype, ctype); \
+            break; \
+        case Mode::SUM: \
+            DISPATCH_FUNC(SumReducer, dtype, ctype, ctype); \
+            break; \
+        case Mode::SUM_SQR: \
+            DISPATCH_FUNC(SumSqrReducer, dtype, ctype, ctype); \
+            break; \
+        case Mode::PRODUCT: \
+            DISPATCH_FUNC(ProductReducer, dtype, ctype, ctype); \
+            break; \
+        default: \
+            break; \
+    }
+
+    if (src.layout.is_contiguous() &&
+        src.layout.dtype.category() == DTypeCategory::QUANTIZED &&
+        param().data_type == param::Reduce::DataType::DEFAULT) {
+        DType src_type = src.layout.dtype;
+        if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS8) {
+            DISPATCH_MODE_QUANTIZED(dt_qint8, int8_t, int32_t)
+        }
+    } else if (
+            src.layout.is_contiguous() &&
+            src.layout.dtype.category() == DTypeCategory::FLOAT &&
+            param().data_type == param::Reduce::DataType::DEFAULT) {
+        DType src_type = src.layout.dtype;
+        if (src.layout.dtype.enumv() == DTypeEnum::Float32) {
+            DISPATCH_MODE_FLOAT(dt_float32, float, float)
+        }
+    }
+    return execed;
+#undef DISPATCH_FUNC
+#undef DISPATCH_MODE_QUANTIZED
+#undef DISPATCH_MODE_FLOAT
+}
+
 } // namespace fallback
 } // namespace megdnn
 // vim: syntax=cpp.doxygen
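The `exec`/`exec_optimized`/`exec_fallback` split is the usual try-fast-path-first pattern: `exec_optimized` returns true only when a contiguous QuantizedS8 or Float32 layout with `DataType::DEFAULT` matches one of the dispatched reducers; otherwise `exec` falls through to the original generic path (which itself still ends in `naive::ReduceForwardImpl::exec`). Schematically, with free functions standing in for the `ReduceImpl` members:

    // Control-flow sketch of the dispatch (simplified).
    bool exec_optimized();  // false: no fast path matched
    void exec_fallback();   // the pre-existing generic kernel
    void exec() {
        if (!exec_optimized())
            exec_fallback();
    }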
@@ -19,6 +19,10 @@ public:
     using ReduceForwardImpl::ReduceForwardImpl;
     void exec(
             _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace) override;
+    bool exec_optimized(
+            _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace);
+    void exec_fallback(
+            _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace);
 };
 } // namespace fallback
@@ -0,0 +1,417 @@
+/**
+ * \file dnn/src/fallback/reduce/reducer.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#pragma once
+
+#include "src/common/utils.h"
+#include "src/fallback/general_intrinsic/gi_float.h"
+#include "src/fallback/general_intrinsic/gi_int.h"
+#include "src/fallback/quantized_converter.h"
+
+using namespace megdnn;
+using namespace fallback;
+
+namespace {
+
+/*****************************Mean Reducer***********************/
+template <typename dtype, typename ctype, typename comp_type, bool C1>
+struct MeanReducer;
+
+template <>
+struct MeanReducer<dt_qint8, int8_t, int32_t, true> {
+    using ctype = int8_t;
+    static constexpr int SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t);
+
+    int32_t res;
+    float coef;
+    MeanReducer(DType, size_t cnt) : res(0), coef(1.0 / cnt) {}
+    MeanReducer() = default;
+    void feed(const int8_t* val) { res += GiReduceAddInt8(GiLoadInt8(val)); }
+    void feed_remain(const int8_t* val) { res += *val; }
+    void post(int8_t* dst) {
+        float sum = res * coef;
+        *dst = std::round(sum);
+    }
+};
+
+template <>
+struct MeanReducer<dt_qint8, int8_t, int32_t, false> {
+    using ctype = int8_t;
+    static constexpr int SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t);
+
+    GI_INT32 res[4];
+    int32_t remain;
+    int32_t cnt;
+    float coef;
+    GI_FLOAT32 vcoef;
+    MeanReducer(DType, size_t cnt) : remain(0), cnt(cnt), coef(1.0 / cnt) {
+        memset(res, 0, sizeof(res));
+        vcoef = GiBroadcastFloat32(coef);
+    }
+    MeanReducer() = default;
+    void feed(const int8_t* val) {
+        const GI_INT8 vval = GiLoadInt8(val);
+        const GI_INT16 vval_low = GiMoveLowLongInt8(vval);
+        const GI_INT16 vval_high = GiMoveHighLongInt8(vval);
+
+        const GI_INT32 vval_low_low = GiMoveLowLongInt16(vval_low);
+        const GI_INT32 vval_low_high = GiMoveHighLongInt16(vval_low);
+        const GI_INT32 vval_high_low = GiMoveLowLongInt16(vval_high);
+        const GI_INT32 vval_high_high = GiMoveHighLongInt16(vval_high);
+
+        res[0] = GiAddInt32(res[0], vval_low_low);
+        res[1] = GiAddInt32(res[1], vval_low_high);
+        res[2] = GiAddInt32(res[2], vval_high_low);
+        res[3] = GiAddInt32(res[3], vval_high_high);
+    }
+    void feed_remain(const int8_t* val) { remain += *val; }
+    void post(int8_t* dst) {
+        for (int i = 0; i < 4; i += 2) {
+            GI_FLOAT32 vitem0 = GiMultiplyFloat32(GiCastToFloat32(res[i]), vcoef);
+            GI_FLOAT32 vitem1 = GiMultiplyFloat32(GiCastToFloat32(res[i + 1]), vcoef);
+            GiStoreLowInt8(
+                    dst,
+                    (QConverter::convert<GI_INT8, GI_FLOAT32_V2>({{vitem0, vitem1}})));
+            dst += 8;
+        }
+    }
+    void post_remain(int8_t* dst) {
+        float sum = remain * coef;
+        *dst = std::round(sum);
+    }
+};
+
+template <>
+struct MeanReducer<dt_float32, float, float, true> {
+    using ctype = float;
+    static constexpr int SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(float);
+
+    GI_FLOAT32 res;
+    float result;
+    float coef;
+    MeanReducer(DType, size_t cnt) : result(0.0f), coef(1.0 / cnt) {
+        res = GiBroadcastFloat32(0.0f);
+    }
+    MeanReducer() = default;
+    void feed(const float* val) { res = GiAddFloat32(GiLoadFloat32(val), res); }
+    void feed_remain(const float* val) { result += *val; }
+    void post(float* dst) {
+        result += GiReduceAddFloat32(res);
+        *dst = result * coef;
+    }
+};
+
+template <>
+struct MeanReducer<dt_float32, float, float, false> {
+    using ctype = float;
+    static constexpr int SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(float);
+
+    GI_FLOAT32 res;
+    float remain;
+    float coef;
+    MeanReducer(DType, size_t cnt) : remain(0.0f), coef(1.0 / cnt) {
+        res = GiBroadcastFloat32(0.0f);
+    }
+    MeanReducer() = default;
+    void feed(const float* val) { res = GiAddFloat32(GiLoadFloat32(val), res); }
+    void feed_remain(const float* val) { remain += *val; }
+    void post(float* dst) {
+        res = GiMultiplyScalerFloat32(res, coef);
+        GiStoreFloat32(dst, res);
+    }
+    void post_remain(float* dst) { *dst = remain * coef; }
+};
+
+/******************************max min Reducer****************************/
+template <typename dtype, typename ctype, typename comp_type, bool C1>
+struct maxReducer;
+template <typename dtype, typename ctype, typename comp_type, bool C1>
+struct minReducer;
+
+#define REDUCER_MAX_MIN_C1(_mode, _Mode, _init) \
+    template <> \
+    struct _mode##Reducer<dt_float32, float, float, true> { \
+        using ctype = float; \
+        static constexpr int SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(float); \
+        GI_FLOAT32 res; \
+        _mode##Reducer(DType, size_t) { res = GiBroadcastFloat32(_init); } \
+        _mode##Reducer() = default; \
+        void feed(const float* val) { \
+            auto vval = GiLoadFloat32(val); \
+            res = Gi##_Mode##imumFloat32(vval, res); \
+        } \
+        void feed_remain(const float* val) { \
+            auto vval = GiBroadcastFloat32(*val); \
+            res = Gi##_Mode##imumFloat32(vval, res); \
+        } \
+        void post(float* dst) { *dst = GiReduce##_Mode##imumFloat32(res); } \
+    }
+
+REDUCER_MAX_MIN_C1(max, Max, std::numeric_limits<dt_float32>::lowest());
+REDUCER_MAX_MIN_C1(min, Min, std::numeric_limits<dt_float32>::max());
+#undef REDUCER_MAX_MIN_C1
+
+#define REDUCER_MAX_MIN_C(_mode, _Mode, _init) \
+    template <> \
+    struct _mode##Reducer<dt_float32, float, float, false> { \
+        using ctype = float; \
+        static constexpr int SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(float); \
+        GI_FLOAT32 res; \
+        float remain; \
+        _mode##Reducer(DType, size_t) { \
+            res = GiBroadcastFloat32(_init); \
+            remain = _init; \
+        } \
+        _mode##Reducer() = default; \
+        void feed(const float* val) { \
+            GI_FLOAT32 vval = GiLoadFloat32(val); \
+            res = Gi##_Mode##imumFloat32(vval, res); \
+        } \
+        void feed_remain(const float* val) { \
+            using namespace std; \
+            remain = _mode(*val, remain); \
+        } \
+        void post(float* dst) { GiStoreFloat32(dst, res); } \
+        void post_remain(float* dst) { *dst = remain; } \
+    }
+
+REDUCER_MAX_MIN_C(max, Max, std::numeric_limits<dt_float32>::lowest());
+REDUCER_MAX_MIN_C(min, Min, std::numeric_limits<dt_float32>::max());
+#undef REDUCER_MAX_MIN_C
+
+#define REDUCER_MAX_MIN_C1(_mode, _Mode, _init) \
+    template <> \
+    struct _mode##Reducer<dt_qint8, int8_t, int8_t, true> { \
+        using ctype = int8_t; \
+        static constexpr int SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t); \
+        GI_INT8 res; \
+        _mode##Reducer(DType, size_t) { res = GiBroadcastInt8(_init); } \
+        _mode##Reducer() = default; \
+        void feed(const int8_t* val) { \
+            GI_INT8 vval = GiLoadInt8(val); \
+            res = Gi##_Mode##imumInt8(vval, res); \
+        } \
+        void feed_remain(const int8_t* val) { \
+            GI_INT8 vval = GiBroadcastInt8(*val); \
+            res = Gi##_Mode##imumInt8(vval, res); \
+        } \
+        void post(int8_t* dst) { *dst = GiReduce##_Mode##Int8(res); } \
+    }
+
+REDUCER_MAX_MIN_C1(max, Max, -128);
+REDUCER_MAX_MIN_C1(min, Min, 127);
+#undef REDUCER_MAX_MIN_C1
+
+#define REDUCER_MAX_MIN_C(_mode, _Mode, _init) \
+    template <> \
+    struct _mode##Reducer<dt_qint8, int8_t, int8_t, false> { \
+        using ctype = int8_t; \
+        static constexpr int SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t); \
+        GI_INT8 res; \
+        int8_t remain; \
+        _mode##Reducer(DType, size_t) { \
+            res = GiBroadcastInt8(_init); \
+            remain = _init; \
+        } \
+        _mode##Reducer() = default; \
+        void feed(const int8_t* val) { \
+            GI_INT8 vval = GiLoadInt8(val); \
+            res = Gi##_Mode##imumInt8(vval, res); \
+        } \
+        void feed_remain(const int8_t* val) { \
+            using namespace std; \
+            remain = _mode(*val, remain); \
+        } \
+        void post(int8_t* dst) { GiStoreInt8(dst, res); } \
+        void post_remain(int8_t* dst) { *dst = remain; } \
+    }
+
+REDUCER_MAX_MIN_C(max, Max, -128);
+REDUCER_MAX_MIN_C(min, Min, 127);
+#undef REDUCER_MAX_MIN_C
+
+/***************************Sum Product Reducer***************************/
+template <typename dtype, typename ctype, typename comp_type, bool C1>
+struct SumReducer;
+template <typename dtype, typename ctype, typename comp_type, bool C1>
+struct ProductReducer;
+
+#define REDUCER_SUM_PRODUCT_C1(_mode, _Mode, _op, _init) \
+    template <> \
+    struct _mode##Reducer<dt_float32, float, float, true> { \
+        using ctype = float; \
+        static constexpr int SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(float); \
+        GI_FLOAT32 res; \
+        float remain; \
+        _mode##Reducer(DType, size_t) { \
+            res = GiBroadcastFloat32(_init); \
+            remain = _init; \
+        } \
+        _mode##Reducer() = default; \
+        void feed(const float* val) { \
+            GI_FLOAT32 vval = GiLoadFloat32(val); \
+            res = Gi##_Mode##Float32(vval, res); \
+        } \
+        void feed_remain(const float* val) { \
+            using namespace std; \
+            auto op = _op<float>(); \
+            remain = op(remain, *val); \
+        } \
+        void post(float* dst) { \
+            using namespace std; \
+            auto op = _op<float>(); \
+            *dst = op(remain, GiReduce##_Mode##Float32(res)); \
+        } \
+    }
+
+REDUCER_SUM_PRODUCT_C1(Sum, Add, plus, 0.0f);
+REDUCER_SUM_PRODUCT_C1(Product, Multiply, multiplies, 1.0f);
+#undef REDUCER_SUM_PRODUCT_C1
+
+#define REDUCER_SUM_PRODUCT_C(_mode, _Mode, _op, _init) \
+    template <> \
+    struct _mode##Reducer<dt_float32, float, float, false> { \
+        using ctype = float; \
+        static constexpr int SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(float); \
+        GI_FLOAT32 res; \
+        float remain; \
+        _mode##Reducer(DType, size_t) { \
+            res = GiBroadcastFloat32(_init); \
+            remain = _init; \
+        } \
+        _mode##Reducer() = default; \
+        void feed(const float* val) { \
+            GI_FLOAT32 vval = GiLoadFloat32(val); \
+            res = Gi##_Mode##Float32(vval, res); \
+        } \
+        void feed_remain(const float* val) { \
+            using namespace std; \
+            auto op = _op<float>(); \
+            remain = op(remain, (*val)); \
+        } \
+        void post(float* dst) { GiStoreFloat32(dst, res); } \
+        void post_remain(float* dst) { *dst = remain; } \
+    }
+
+REDUCER_SUM_PRODUCT_C(Sum, Add, plus, 0.0f);
+REDUCER_SUM_PRODUCT_C(Product, Multiply, multiplies, 1.0f);
+#undef REDUCER_SUM_PRODUCT_C
+
+/***************************SumSqr Reducer***************************/
+template <typename dtype, typename ctype, typename comp_type, bool C1>
+struct SumSqrReducer;
+
+template <>
+struct SumSqrReducer<dt_float32, float, float, true> {
+    using ctype = float;
+    static constexpr int SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(float);
+
+    GI_FLOAT32 res;
+    float result;
+    SumSqrReducer(DType, size_t cnt) : result(0.0f) {
+        MEGDNN_MARK_USED_VAR(cnt);
+        res = GiBroadcastFloat32(0.0f);
+    }
+    SumSqrReducer() = default;
+    void feed(const float* val) {
+        GI_FLOAT32 vval = GiLoadFloat32(val);
+        res = GiAddFloat32(GiMultiplyFloat32(vval, vval), res);
+    }
+    void feed_remain(const float* val) {
+        float vval = *val;
+        result += vval * vval;
+    }
+    void post(float* dst) {
+        result += GiReduceAddFloat32(res);
+        *dst = result;
+    }
+};
+
+template <>
+struct SumSqrReducer<dt_float32, float, float, false> {
+    using ctype = float;
+    static constexpr int SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(float);
+
+    GI_FLOAT32 res;
+    float remain;
+    SumSqrReducer(DType, size_t cnt) : remain(0.0f) {
+        MEGDNN_MARK_USED_VAR(cnt);
+        res = GiBroadcastFloat32(0.0f);
+    }
+    SumSqrReducer() = default;
+    void feed(const float* val) {
+        GI_FLOAT32 vval = GiLoadFloat32(val);
+        res = GiAddFloat32(GiMultiplyFloat32(vval, vval), res);
+    }
+    void feed_remain(const float* val) { remain += (*val) * (*val); }
+    void post(float* dst) { GiStoreFloat32(dst, res); }
+    void post_remain(float* dst) { *dst = remain; }
+};
+
+/**************************************do reduce*************************/
+template <typename Reducer, bool C1>
+struct Exec {
+    static void do_reduce(
+            const typename Reducer::ctype* src, typename Reducer::ctype* dst,
+            DType src_dtype, size_t A, size_t B, size_t C);
+};
+
+template <typename Reducer>
+struct Exec<Reducer, true> {
+    static void do_reduce(
+            const typename Reducer::ctype* src, typename Reducer::ctype* dst,
+            DType src_dtype, size_t A, size_t B, size_t) {
+        size_t a = 0;
+        for (; a < A; a++) {
+            Reducer reducer0(src_dtype, B);
+            auto temp_src0 = src + a * B;
+            size_t b = 0;
+            for (; b + Reducer::SIMD_WIDTH <= B; b += Reducer::SIMD_WIDTH) {
+                reducer0.feed(temp_src0);
+                temp_src0 += Reducer::SIMD_WIDTH;
+            }
+            for (; b < B; b++) {
+                reducer0.feed_remain(temp_src0);
+                temp_src0++;
+            }
+            reducer0.post(dst);
+            dst++;
+        }
+    }
+};
+
+template <typename Reducer>
+struct Exec<Reducer, false> {
+    static void do_reduce(
+            const typename Reducer::ctype* src, typename Reducer::ctype* dst,
+            DType src_dtype, size_t A, size_t B, size_t C) {
+        for (size_t a = 0; a < A; a++) {
+            size_t c = 0;
+            for (; c + Reducer::SIMD_WIDTH <= C; c += Reducer::SIMD_WIDTH) {
+                Reducer reducer(src_dtype, B);
+                for (size_t b = 0; b < B; b++)
+                    reducer.feed(src + c + C * b);
+                reducer.post(dst);
+                dst += Reducer::SIMD_WIDTH;
+            }
+            for (; c < C; c++) {
+                Reducer reducer(src_dtype, B);
+                for (size_t b = 0; b < B; b++)
+                    reducer.feed_remain(src + c + C * b);
+                reducer.post_remain(dst);
+                dst++;
+            }
+            src += B * C;
+        }
+    }
+};
+
+} // namespace
+
+// vim: syntax=cpp.doxygen
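Each reducer exposes the same small protocol (`feed`, `feed_remain`, `post`, `post_remain`, `SIMD_WIDTH`), and `Exec` only decides how to walk the (A, B, C) space: the `C == 1` specialization vectorizes along B with a scalar tail, the general one vectorizes along C. A compilable sketch of the `C == 1` loop shape over a hypothetical reducer type satisfying that protocol:

    #include <cstddef>
    template <typename Reducer>  // Reducer: hypothetical, mirrors the structs above
    void reduce_rows(const float* src, float* dst, size_t A, size_t B) {
        for (size_t a = 0; a < A; a++, src += B) {
            Reducer r(B);
            size_t b = 0;
            for (; b + Reducer::SIMD_WIDTH <= B; b += Reducer::SIMD_WIDTH)
                r.feed(src + b);         // vector body
            for (; b < B; b++)
                r.feed_remain(src + b);  // scalar tail
            r.post(dst++);               // one output per row when C == 1
        }
    }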
@@ -181,7 +181,6 @@ TEST_F(ARM_COMMON, LSTM_FORWARD_RECORD) {
 TEST_F(ARM_COMMON, BENCHMARK_LSTM_FORWARD) {
     Benchmarker<LSTM> optimized_bench(handle());
     constexpr size_t RUNS = 20;
-
     auto run = [&](size_t hidden_size, size_t input_size) {
         optimized_bench.set_times(20).set_display(true);
         size_t gate_hidden_size = 4 * hidden_size;
@@ -18,6 +18,75 @@
 using namespace megdnn;
 using namespace test;
+
+TEST_F(FALLBACK, REDUCE_FULL) {
+    using Param = Reduce::Param;
+    using Mode = Param::Mode;
+    Checker<Reduce> checker(handle());
+    UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
+    checker.set_rng(0, &rng);
+    struct Config {
+        Param param;
+        DType dtype;
+        TensorShape shape;
+        Config(Param param, DType dtype, TensorShape shape)
+                : param(param), dtype(dtype), shape(shape) {}
+    };
+    std::vector<Config> configs;
+    for (auto mode : {Mode::MEAN, Mode::MAX, Mode::MIN})
+        for (auto dtype : std::vector<DType>{
+                     dtype::Float32(), dtype::Float16(), dtype::QuantizedS8(1.3f),
+                     dtype::Quantized8Asymm(1.3f, static_cast<uint8_t>(3))})
+            for (int32_t axis : {0, 1, 2}) {
+                for (size_t A : {1, 3, 5}) {
+                    for (size_t B : {4, 6, 9, 16, 33, 45}) {
+                        for (size_t C : {4, 6, 9, 16, 33, 45}) {
+                            TensorShape shape{A, B, C};
+                            Param param(mode, axis);
+                            Config config(param, dtype, shape);
+                            configs.push_back(config);
+                        }
+                    }
+                }
+            }
+
+    for (auto&& config : configs) {
+        auto&& dtype = config.dtype;
+        auto&& param = config.param;
+        auto&& shape = config.shape;
+        checker.set_dtype(0, dtype).set_param(param).execs({shape, {}});
+    }
+    configs.clear();
+    for (auto mode : {Mode::SUM, Mode::PRODUCT, Mode::SUM_SQR})
+        for (auto dtype : std::vector<DType>{dtype::Float32(), dtype::Float16()})
+            for (int32_t axis : {0, 1, 2}) {
+                for (size_t A : {1, 3, 5}) {
+                    for (size_t B : {4, 6, 9, 16, 33, 45}) {
+                        for (size_t C : {4, 6, 9, 16, 33, 45}) {
+                            TensorShape shape{A, B, C};
+                            Param param(mode, axis);
+                            Config config(param, dtype, shape);
+                            configs.push_back(config);
+                        }
+                    }
+                }
+            }
+
+    UniformFloatRNG rng_float(-2, 2);
+    checker.set_rng(0, &rng_float);
+    checker.set_epsilon(1e-1);
+    for (auto&& config : configs) {
+        auto&& dtype = config.dtype;
+        auto&& param = config.param;
+        auto&& shape = config.shape;
+        if (dtype == dtype::Float16())
+            checker.set_epsilon(1e-1);
+        else
+            checker.set_epsilon(1e-3);
+        checker.set_dtype(0, dtype).set_param(param).execs({shape, {}});
+    }
+}
+
 TEST_F(FALLBACK, REDUCE) {
     using Param = Reduce::Param;
     using Mode = Param::Mode;