Browse Source

feat(gi): make pooling apply gi class type

GitOrigin-RevId: e60c6a2e76
release-1.10
Megvii Engine Team 3 years ago
parent
commit
19d36fa03c
3 changed files with 91 additions and 59 deletions
  1. +5
    -5
      dnn/src/fallback/pooling/gi/do_max_pooling_3x3_s2x2_float.cpp
  2. +59
    -35
      dnn/src/fallback/pooling/gi/kern_fp32_pooling_nchw44.h
  3. +27
    -19
      dnn/src/fallback/pooling/gi/pooling_helper.h

+ 5
- 5
dnn/src/fallback/pooling/gi/do_max_pooling_3x3_s2x2_float.cpp View File

@@ -8,11 +8,11 @@
namespace megdnn { namespace megdnn {
namespace fallback { namespace fallback {


#define GI_UZP(s0, s1, d0, d1) \
do { \
auto tmp__ = GiUzpqFloat32(s0, s1); \
d0 = tmp__.val[0]; \
d1 = tmp__.val[1]; \
#define GI_UZP(s0, s1, d0, d1) \
do { \
auto tmp__ = GiUzpqFloat32(s0, s1); \
d0 = GiGetSubVectorFloat32V2(tmp__, 0); \
d1 = GiGetSubVectorFloat32V2(tmp__, 1); \
} while (0) } while (0)


void do_max_pooling_3x3_s2x2_float_gi( void do_max_pooling_3x3_s2x2_float_gi(


+ 59
- 35
dnn/src/fallback/pooling/gi/kern_fp32_pooling_nchw44.h View File

@@ -29,17 +29,33 @@ void calculate_xsx_nchw44(T1 result, T2 src) {
CalXsXNchw44<filter, stride, ow_step, mode, T1, T2>::impl(result, src); CalXsXNchw44<filter, stride, ow_step, mode, T1, T2>::impl(result, src);
}; };


#define CALCULATE_MAX_CB(step) \
result[0] = GiMaximumFloat32(result[0], src[0 * stride + step]); \
result[1] = GiMaximumFloat32(result[1], src[1 * stride + step]); \
result[2] = GiMaximumFloat32(result[2], src[2 * stride + step]); \
result[3] = GiMaximumFloat32(result[3], src[3 * stride + step]);
#define CALCULATE_MAX_CB(step) \
result[0] = GiFloat32Type2FixLenType(GiMaximumFloat32( \
GiFixLenType2GiFloat32Type(result[0]), \
GiFixLenType2GiFloat32Type(src[0 * stride + step]))); \
result[1] = GiFloat32Type2FixLenType(GiMaximumFloat32( \
GiFixLenType2GiFloat32Type(result[1]), \
GiFixLenType2GiFloat32Type(src[1 * stride + step]))); \
result[2] = GiFloat32Type2FixLenType(GiMaximumFloat32( \
GiFixLenType2GiFloat32Type(result[2]), \
GiFixLenType2GiFloat32Type(src[2 * stride + step]))); \
result[3] = GiFloat32Type2FixLenType(GiMaximumFloat32( \
GiFixLenType2GiFloat32Type(result[3]), \
GiFixLenType2GiFloat32Type(src[3 * stride + step])));


#define CALCULATE_AVG_CB(step) \
result[0] = GiAddFloat32(result[0], src[0 * stride + step]); \
result[1] = GiAddFloat32(result[1], src[1 * stride + step]); \
result[2] = GiAddFloat32(result[2], src[2 * stride + step]); \
result[3] = GiAddFloat32(result[3], src[3 * stride + step]);
#define CALCULATE_AVG_CB(step) \
result[0] = GiFloat32Type2FixLenType(GiAddFloat32( \
GiFixLenType2GiFloat32Type(result[0]), \
GiFixLenType2GiFloat32Type(src[0 * stride + step]))); \
result[1] = GiFloat32Type2FixLenType(GiAddFloat32( \
GiFixLenType2GiFloat32Type(result[1]), \
GiFixLenType2GiFloat32Type(src[1 * stride + step]))); \
result[2] = GiFloat32Type2FixLenType(GiAddFloat32( \
GiFixLenType2GiFloat32Type(result[2]), \
GiFixLenType2GiFloat32Type(src[2 * stride + step]))); \
result[3] = GiFloat32Type2FixLenType(GiAddFloat32( \
GiFixLenType2GiFloat32Type(result[3]), \
GiFixLenType2GiFloat32Type(src[3 * stride + step])));


#define INSTANCE_CAL(filter) \ #define INSTANCE_CAL(filter) \
template <int stride, typename T1, typename T2> \ template <int stride, typename T1, typename T2> \
@@ -78,13 +94,13 @@ struct KerPoolingFilterXStrideXNchw44<filter, stride, ow_step, PoolingBase::Mode
constexpr int packed_ic = 4; constexpr int packed_ic = 4;
constexpr int simd_len = 4; constexpr int simd_len = 4;
constexpr float default_float = std::numeric_limits<float>::lowest(); constexpr float default_float = std::numeric_limits<float>::lowest();
GI_FLOAT32_t result[ow_step];
GI_FLOAT32_t src[src_reg_size];
GI_FLOAT32_FIXLEN_t result[ow_step];
GI_FLOAT32_FIXLEN_t src[src_reg_size];


result[0] = GiBroadcastFloat32(default_float);
result[1] = GiBroadcastFloat32(default_float);
result[2] = GiBroadcastFloat32(default_float);
result[3] = GiBroadcastFloat32(default_float);
result[0] = GiFloat32Type2FixLenType(GiBroadcastFloat32(default_float));
result[1] = GiFloat32Type2FixLenType(GiBroadcastFloat32(default_float));
result[2] = GiFloat32Type2FixLenType(GiBroadcastFloat32(default_float));
result[3] = GiFloat32Type2FixLenType(GiBroadcastFloat32(default_float));


for (int fh_idx = 0; fh_idx < filter; ++fh_idx) { for (int fh_idx = 0; fh_idx < filter; ++fh_idx) {
load_helper<src_reg_size, 0, simd_len, 0, GiD1Qf32>( load_helper<src_reg_size, 0, simd_len, 0, GiD1Qf32>(
@@ -93,10 +109,10 @@ struct KerPoolingFilterXStrideXNchw44<filter, stride, ow_step, PoolingBase::Mode
result, src); result, src);
} }


GiStoreFloat32(dst_ptr + 0 * packed_ic, result[0]);
GiStoreFloat32(dst_ptr + 1 * packed_ic, result[1]);
GiStoreFloat32(dst_ptr + 2 * packed_ic, result[2]);
GiStoreFloat32(dst_ptr + 3 * packed_ic, result[3]);
GiStoreFloat32(dst_ptr + 0 * packed_ic, GiFixLenType2GiFloat32Type(result[0]));
GiStoreFloat32(dst_ptr + 1 * packed_ic, GiFixLenType2GiFloat32Type(result[1]));
GiStoreFloat32(dst_ptr + 2 * packed_ic, GiFixLenType2GiFloat32Type(result[2]));
GiStoreFloat32(dst_ptr + 3 * packed_ic, GiFixLenType2GiFloat32Type(result[3]));
} }
}; };


@@ -110,28 +126,36 @@ struct KerPoolingFilterXStrideXNchw44<
constexpr float default_float = 0; constexpr float default_float = 0;
constexpr float div_filter_size = 1.f / (filter * filter); constexpr float div_filter_size = 1.f / (filter * filter);
const GI_FLOAT32_t div_filter_size_vec = GiBroadcastFloat32(div_filter_size); const GI_FLOAT32_t div_filter_size_vec = GiBroadcastFloat32(div_filter_size);
GI_FLOAT32_t result[ow_step];
GI_FLOAT32_t src[src_reg_size];
GI_FLOAT32_FIXLEN_t result[ow_step];
GI_FLOAT32_FIXLEN_t src[src_reg_size];


result[0] = GiBroadcastFloat32(default_float);
result[1] = GiBroadcastFloat32(default_float);
result[2] = GiBroadcastFloat32(default_float);
result[3] = GiBroadcastFloat32(default_float);
result[0] = GiFloat32Type2FixLenType(GiBroadcastFloat32(default_float));
result[1] = GiFloat32Type2FixLenType(GiBroadcastFloat32(default_float));
result[2] = GiFloat32Type2FixLenType(GiBroadcastFloat32(default_float));
result[3] = GiFloat32Type2FixLenType(GiBroadcastFloat32(default_float));


for (int fh_idx = 0; fh_idx < filter; ++fh_idx) { for (int fh_idx = 0; fh_idx < filter; ++fh_idx) {
load_helper<src_reg_size, 0, simd_len, 0, GiD1Qf32>( load_helper<src_reg_size, 0, simd_len, 0, GiD1Qf32>(
src, src_ptr + fh_idx * iw * packed_ic, 0); src, src_ptr + fh_idx * iw * packed_ic, 0);
calculate_xsx_nchw44<filter, stride, ow_step, PoolingBase::Mode::AVERAGE>( calculate_xsx_nchw44<filter, stride, ow_step, PoolingBase::Mode::AVERAGE>(
result, src); result, src);
}
result[0] = GiMultiplyFloat32(result[0], div_filter_size_vec);
result[1] = GiMultiplyFloat32(result[1], div_filter_size_vec);
result[2] = GiMultiplyFloat32(result[2], div_filter_size_vec);
result[3] = GiMultiplyFloat32(result[3], div_filter_size_vec);
GiStoreFloat32(dst_ptr + 0 * packed_ic, result[0]);
GiStoreFloat32(dst_ptr + 1 * packed_ic, result[1]);
GiStoreFloat32(dst_ptr + 2 * packed_ic, result[2]);
GiStoreFloat32(dst_ptr + 3 * packed_ic, result[3]);
};
GiStoreFloat32(
dst_ptr + 0 * packed_ic,
GiMultiplyFloat32(
GiFixLenType2GiFloat32Type(result[0]), div_filter_size_vec));
GiStoreFloat32(
dst_ptr + 1 * packed_ic,
GiMultiplyFloat32(
GiFixLenType2GiFloat32Type(result[1]), div_filter_size_vec));
GiStoreFloat32(
dst_ptr + 2 * packed_ic,
GiMultiplyFloat32(
GiFixLenType2GiFloat32Type(result[2]), div_filter_size_vec));
GiStoreFloat32(
dst_ptr + 3 * packed_ic,
GiMultiplyFloat32(
GiFixLenType2GiFloat32Type(result[3]), div_filter_size_vec));
} }
}; };




+ 27
- 19
dnn/src/fallback/pooling/gi/pooling_helper.h View File

@@ -56,18 +56,20 @@ struct GiMeanPooler<area, dt_float32, float, float> {
static constexpr int MIDOUT_CASE_NUM = 1; static constexpr int MIDOUT_CASE_NUM = 1;
static constexpr int SIMD_WIDTH = 4; static constexpr int SIMD_WIDTH = 4;


static const GI_FLOAT32_t coef;
GI_FLOAT32_t res;
GiMeanPooler(DType) : res(GiBroadcastFloat32(0.0f)) {}
void feed(const float* val) { res = GiAddFloat32(res, GiLoadFloat32(val)); }
GI_FLOAT32_FIXLEN_t res, coef;
GiMeanPooler(DType)
: res(GiFloat32Type2FixLenType(GiBroadcastFloat32(0.0f))),
coef(GiFloat32Type2FixLenType(GiBroadcastFloat32(1.0f / area))) {}
void feed(const float* val) {
res = GiFloat32Type2FixLenType(
GiAddFloat32(GiFixLenType2GiFloat32Type(res), GiLoadFloat32(val)));
}
void post(float* dst) { void post(float* dst) {
res = GiMultiplyFloat32(res, coef);
GiStoreFloat32(dst, res);
res = GiFloat32Type2FixLenType(GiMultiplyFloat32(
GiFixLenType2GiFloat32Type(res), GiFixLenType2GiFloat32Type(coef)));
GiStoreFloat32(dst, GiFixLenType2GiFloat32Type(res));
} }
}; };
template <int area>
const GI_FLOAT32_t GiMeanPooler<area, dt_float32, float, float>::coef =
GiBroadcastFloat32(1.0f / area);


/* ======================= MaxPooler ======================== */ /* ======================= MaxPooler ======================== */


@@ -96,10 +98,15 @@ struct GiMaxPooler<area, dt_float32, float, float> {
static constexpr int MIDOUT_CASE_NUM = 11; static constexpr int MIDOUT_CASE_NUM = 11;
static constexpr int SIMD_WIDTH = 4; static constexpr int SIMD_WIDTH = 4;


GI_FLOAT32_t res;
GiMaxPooler(DType) : res(GiBroadcastFloat32(DTypeTrait<dt_float32>::min())) {}
void feed(const float* val) { res = GiMaximumFloat32(res, GiLoadFloat32(val)); }
void post(float* dst) { GiStoreFloat32(dst, res); }
GI_FLOAT32_FIXLEN_t res;
GiMaxPooler(DType)
: res(GiFloat32Type2FixLenType(
GiBroadcastFloat32(DTypeTrait<dt_float32>::min()))) {}
void feed(const float* val) {
res = GiFloat32Type2FixLenType(
GiMaximumFloat32(GiFixLenType2GiFloat32Type(res), GiLoadFloat32(val)));
}
void post(float* dst) { GiStoreFloat32(dst, GiFixLenType2GiFloat32Type(res)); }
}; };


template <typename Pooler, int window> template <typename Pooler, int window>
@@ -137,7 +144,8 @@ struct do_pxl_2x2_pack_proxy<
const int IW, const int OH, const int OW, const int PH, const int PW) { const int IW, const int OH, const int OW, const int PH, const int PW) {
MEGDNN_MARK_USED_VAR(IH); MEGDNN_MARK_USED_VAR(IH);
MEGDNN_MARK_USED_VAR(OH); MEGDNN_MARK_USED_VAR(OH);
static const auto avg_coef = GiBroadcastFloat32(0.25f);
static const auto avg_coef =
GiFloat32Type2FixLenType(GiBroadcastFloat32(0.25f));
int ih = -PH + 2 * oh; int ih = -PH + 2 * oh;
int iw = -PW + 2 * ow; int iw = -PW + 2 * ow;
auto i00 = GiLoadFloat32(src + (ih + 0) * IW + (iw + 0)), auto i00 = GiLoadFloat32(src + (ih + 0) * IW + (iw + 0)),
@@ -148,7 +156,7 @@ struct do_pxl_2x2_pack_proxy<
auto vlow = GiPaddFloat32(GiGetLowFloat32(sum0), GiGetHighFloat32(sum0)); auto vlow = GiPaddFloat32(GiGetLowFloat32(sum0), GiGetHighFloat32(sum0));
auto vhigh = GiPaddFloat32(GiGetLowFloat32(sum1), GiGetHighFloat32(sum1)); auto vhigh = GiPaddFloat32(GiGetLowFloat32(sum1), GiGetHighFloat32(sum1));
auto comb = GiCombineFloat32(vlow, vhigh); auto comb = GiCombineFloat32(vlow, vhigh);
auto result = GiMultiplyFloat32(comb, avg_coef);
auto result = GiMultiplyFloat32(comb, GiFixLenType2GiFloat32Type(avg_coef));
GiStoreFloat32(dst + oh * OW + ow, result); GiStoreFloat32(dst + oh * OW + ow, result);
} }
}; };
@@ -327,8 +335,8 @@ void do_max_pooling_w5x5_s2x2_gi(
auto s0 = GiLoadFloat32(sptr + iw + 0); auto s0 = GiLoadFloat32(sptr + iw + 0);
auto s1 = GiLoadFloat32(sptr + iw + MEGDNN_SIMD_WIDTH); auto s1 = GiLoadFloat32(sptr + iw + MEGDNN_SIMD_WIDTH);
auto d = GiUzpqFloat32(s0, s1); auto d = GiUzpqFloat32(s0, s1);
GiStoreFloat32(even + even_offset, d.val[0]);
GiStoreFloat32(odd + odd_offset, d.val[1]);
GiStoreFloat32(even + even_offset, GiGetSubVectorFloat32V2(d, 0));
GiStoreFloat32(odd + odd_offset, GiGetSubVectorFloat32V2(d, 1));
even_offset += MEGDNN_SIMD_WIDTH; even_offset += MEGDNN_SIMD_WIDTH;
odd_offset += MEGDNN_SIMD_WIDTH; odd_offset += MEGDNN_SIMD_WIDTH;
} }
@@ -464,8 +472,8 @@ void do_average_pooling_3x3_s2x2_gi(


for (; iw + 2 * MEGDNN_SIMD_WIDTH <= IW; iw += 2 * MEGDNN_SIMD_WIDTH) { for (; iw + 2 * MEGDNN_SIMD_WIDTH <= IW; iw += 2 * MEGDNN_SIMD_WIDTH) {
auto s0 = GiLd2qFloat32(sptr + iw); auto s0 = GiLd2qFloat32(sptr + iw);
GiStoreFloat32(even + even_offset, s0.val[0]);
GiStoreFloat32(odd + odd_offset, s0.val[1]);
GiStoreFloat32(even + even_offset, GiGetSubVectorFloat32V2(s0, 0));
GiStoreFloat32(odd + odd_offset, GiGetSubVectorFloat32V2(s0, 1));
even_offset += MEGDNN_SIMD_WIDTH; even_offset += MEGDNN_SIMD_WIDTH;
odd_offset += MEGDNN_SIMD_WIDTH; odd_offset += MEGDNN_SIMD_WIDTH;
} }


Loading…
Cancel
Save