diff --git a/dnn/src/arm_common/pooling/algo.h b/dnn/src/arm_common/pooling/algo.h index df932d1d..5952af95 100644 --- a/dnn/src/arm_common/pooling/algo.h +++ b/dnn/src/arm_common/pooling/algo.h @@ -140,6 +140,8 @@ public: AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }; const char* name() const override { return "FALLBACK_POOLING"; } bool usable(const PoolingKernSizeParam&) const override { return true; } + //! use to fallback to algo define at: + //! dnn/src/fallback/pooling/gi/algo.h void exec(const PoolingKernParam&) const override { megdnn_assert(false, "code issue happened!!"); } diff --git a/dnn/src/arm_common/resize/opr_impl.cpp b/dnn/src/arm_common/resize/opr_impl.cpp index c9225da9..df837690 100644 --- a/dnn/src/arm_common/resize/opr_impl.cpp +++ b/dnn/src/arm_common/resize/opr_impl.cpp @@ -30,18 +30,16 @@ void ResizeImpl::exec( bool is_contiguous = src.layout.is_contiguous() && dst.layout.is_contiguous(); bool is_dtype_same = src.layout.dtype == dst.layout.dtype; - bool is_dtype_fp32 = src.layout.dtype == dtype::Float32(); bool is_dtype_fp16 = DNN_FLOAT16_SELECT(src.layout.dtype == dtype::Float16(), false); - bool is_dtype_supported = is_dtype_same && (is_dtype_fp32 || is_dtype_fp16); + bool is_dtype_supported = is_dtype_same && is_dtype_fp16; - bool is_nchw = param().format == param::Resize::Format::NCHW && - (is_dtype_fp32 || is_dtype_fp16); - bool is_nchw44_fp32 = - param().format == param::Resize::Format::NCHW44 && is_dtype_fp32; #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + bool is_nchw = param().format == param::Resize::Format::NCHW && is_dtype_fp16; bool is_nchw88_fp16 = param().format == param::Resize::Format::NCHW88 && is_dtype_fp16; + bool is_upsample2 = src.layout.shape[2] * 2 == dst.layout.shape[2] && + src.layout.shape[3] * 2 == dst.layout.shape[3]; #endif bool is_imode_nearest = @@ -50,8 +48,6 @@ void ResizeImpl::exec( param().imode == param::Resize::InterpolationMode::INTER_LINEAR; bool is_imode_supported = is_imode_nearest || is_imode_linear; - bool is_upsample2 = src.layout.shape[2] * 2 == dst.layout.shape[2] && - src.layout.shape[3] * 2 == dst.layout.shape[3]; bool usable = is_contiguous && is_dtype_supported && is_imode_supported; if (param().format == param::Resize::Format::NHWC && @@ -59,63 +55,6 @@ void ResizeImpl::exec( MEGDNN_DISPATCH_CPU_KERN_OPR(resize_cv_exec(src, dst, param().imode)); } else if (!usable) { fallback::ResizeImpl::exec(src, dst, workspace); - } else if (is_dtype_fp32) { - auto kern_param = KernParam::from_tensors( - param().format, param().imode, src, dst, workspace); - if (is_nchw44_fp32) { - if (is_upsample2) { - if (is_imode_nearest) { - MIDOUT_BEGIN(megdnn_arm_resize, midout_iv(0)) { - MEGDNN_DISPATCH_CPU_KERN_OPR( - resize_nearest_upsample2_nchw44_fp32(kern_param)); - } - MIDOUT_END(); - } else { - megdnn_assert(is_imode_linear, "invalid imode"); - MIDOUT_BEGIN(megdnn_arm_resize, midout_iv(1)) { - MEGDNN_DISPATCH_CPU_KERN_OPR( - resize_linear_upsample2_nchw44_fp32(kern_param)); - } - MIDOUT_END(); - } - } else { - if (is_imode_nearest) { - MIDOUT_BEGIN(megdnn_arm_resize, midout_iv(2)) { - MEGDNN_DISPATCH_CPU_KERN_OPR( - resize_direct_nearest_nchw44_fp32(kern_param)); - } - MIDOUT_END(); - } else { - megdnn_assert(is_imode_linear, "invalid imode"); - MIDOUT_BEGIN(megdnn_arm_resize, midout_iv(3)) { - MEGDNN_DISPATCH_CPU_KERN_OPR( - resize_direct_linear_nchw44_fp32(kern_param)); - } - MIDOUT_END(); - } - } - } else if (is_nchw) { - if (is_upsample2) { - if (is_imode_nearest) { - 
MIDOUT_BEGIN(megdnn_arm_resize, midout_iv(4)) { - MEGDNN_DISPATCH_CPU_KERN_OPR( - resize_nearest_upsample2_nchw_fp32(kern_param)); - } - MIDOUT_END(); - } else { - megdnn_assert(is_imode_linear, "invalid imode"); - MIDOUT_BEGIN(megdnn_arm_resize, midout_iv(5)) { - MEGDNN_DISPATCH_CPU_KERN_OPR( - resize_linear_upsample2_nchw_fp32(kern_param)); - } - MIDOUT_END(); - } - } else { - fallback::ResizeImpl::exec(src, dst, workspace); - } - } else { - fallback::ResizeImpl::exec(src, dst, workspace); - } #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC } else if (is_dtype_fp16) { auto kern_param = KernParam::from_tensors( diff --git a/dnn/src/fallback/conv_bias/gi/utils.h b/dnn/src/fallback/conv_bias/gi/utils.h index 504b5697..b2f6503b 100644 --- a/dnn/src/fallback/conv_bias/gi/utils.h +++ b/dnn/src/fallback/conv_bias/gi/utils.h @@ -96,28 +96,10 @@ struct Vector { Vector(const GI_FLOAT32_V2_t& v) { value = v; } static Vector load(const float* addr) { Vector v; -#if defined(GI_TEST_NAIVE) - v.value.val[0] = GiLoadFloat32(addr); - v.value.val[1] = GiLoadFloat32(addr + 4); -#elif defined(__arm__) || defined(__aarch64__) - v.value = vld1q_f32_x2(addr); -#else - v.value.val[0] = GiLoadFloat32(addr); - v.value.val[1] = GiLoadFloat32(addr + 4); -#endif + v.value = GiLoadFloat32V2(addr); return v; } - static void save(float* addr, const Vector& v) { -#if defined(GI_TEST_NAIVE) - GiStoreFloat32(addr, v.value.val[0]); - GiStoreFloat32(addr + 4, v.value.val[1]); -#elif defined(__arm__) || defined(__aarch64__) - vst1q_f32_x2(addr, v.value); -#else - GiStoreFloat32(addr, v.value.val[0]); - GiStoreFloat32(addr + 4, v.value.val[1]); -#endif - } + static void save(float* addr, const Vector& v) { GiStoreFloat32V2(addr, v.value); } void save(float* addr) { save(addr, *this); } Vector operator+(const Vector& lr) { diff --git a/dnn/src/fallback/general_intrinsic/gi_common.h b/dnn/src/fallback/general_intrinsic/gi_common.h index c8b06dca..37496e85 100644 --- a/dnn/src/fallback/general_intrinsic/gi_common.h +++ b/dnn/src/fallback/general_intrinsic/gi_common.h @@ -143,6 +143,7 @@ typedef int16x8_t GI_INT16_t; typedef int32x4_t GI_INT32_t; typedef uint32x4_t GI_UINT32_t; typedef float32x4x2_t GI_FLOAT32_V2_t; +typedef float32x4x3_t GI_FLOAT32_V3_t; typedef float32x4x4_t GI_FLOAT32_V4_t; typedef int32x4x2_t GI_INT32_V2_t; typedef int32x4x4_t GI_INT32_V4_t; @@ -167,6 +168,7 @@ typedef __m128i GI_INT16_t; typedef __m128i GI_INT32_t; typedef __m128i GI_UINT32_t; typedef __m128i GI_INT64_t; +#define _SWAP_HI_LOW32 (2 | (3 << 2) | (0 << 4) | (1 << 6)) #define _INSERTPS_NDX(srcField, dstField) (((srcField) << 6) | ((dstField) << 4)) #define _M64(out, inp) _mm_storel_epi64((__m128i*)&(out), inp) #define _pM128i(a) _mm_loadl_epi64((__m128i*)&(a)) @@ -295,6 +297,10 @@ typedef struct { } GI_FLOAT32_V2_NAIVE_t; typedef struct { + GI_FLOAT32_NAIVE_t val[3]; +} GI_FLOAT32_V3_NAIVE_t; + +typedef struct { GI_FLOAT32_NAIVE_t val[4]; } GI_FLOAT32_V4_NAIVE_t; @@ -335,6 +341,10 @@ typedef struct { } GI_FLOAT32_V2_t; typedef struct { + GI_FLOAT32_t val[3]; +} GI_FLOAT32_V3_t; + +typedef struct { GI_FLOAT32_t val[4]; } GI_FLOAT32_V4_t; diff --git a/dnn/src/fallback/general_intrinsic/gi_float.h b/dnn/src/fallback/general_intrinsic/gi_float.h index 86f81882..e8c657f0 100644 --- a/dnn/src/fallback/general_intrinsic/gi_float.h +++ b/dnn/src/fallback/general_intrinsic/gi_float.h @@ -157,6 +157,19 @@ GI_FLOAT32_t GiLoadFloat32(const float* Buffer) { } GI_FORCEINLINE +GI_FLOAT32_V2_t GiLoadFloat32V2(const float* Buffer) { +#if defined(GI_NEON_INTRINSICS) + 
return vld1q_f32_x2(Buffer); +#else + GI_FLOAT32_V2_t v; + v.val[0] = GiLoadFloat32(Buffer); + v.val[1] = GiLoadFloat32(Buffer + GI_SIMD_LEN_BYTE / sizeof(float)); + + return v; +#endif +} + +GI_FORCEINLINE GI_FLOAT32_t GiLoadFloat32LowHalf(const float* Buffer) { #if defined(GI_NEON_INTRINSICS) return vcombine_f32(vld1_f32(Buffer), vdup_n_f32(0.f)); @@ -519,6 +532,16 @@ void GiStoreFloat32(float* Buffer, GI_FLOAT32_t Vector) { #endif } +GI_FORCEINLINE +void GiStoreFloat32V2(float* Buffer, GI_FLOAT32_V2_t Vector) { +#if defined(GI_NEON_INTRINSICS) + vst1q_f32_x2(Buffer, Vector); +#else + GiStoreFloat32(Buffer, Vector.val[0]); + GiStoreFloat32(Buffer + GI_SIMD_LEN_BYTE / sizeof(float), Vector.val[1]); +#endif +} + #if defined(GI_NEON_INTRINSICS) #define GISTORELANEFLOAT32(i) \ GI_FORCEINLINE void GiStoreLane##i##Float32(float* Buffer, GI_FLOAT32_t Vector) { \ @@ -593,6 +616,18 @@ GI_FLOAT32_V2_t GiZipqFloat32(GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { } GI_FORCEINLINE +void GiStoreZipFloat32V2(float* Buffer, GI_FLOAT32_V2_t Vector) { +#if defined(GI_NEON_INTRINSICS) + vst2q_f32(Buffer, Vector); +#else + GI_FLOAT32_V2_t tmp; + tmp = GiZipqFloat32(Vector.val[0], Vector.val[1]); + GiStoreFloat32(Buffer, tmp.val[0]); + GiStoreFloat32(Buffer + GI_SIMD_LEN_BYTE / sizeof(float), tmp.val[1]); +#endif +} + +GI_FORCEINLINE GI_FLOAT32_t GiInterleaveLowFloat32(GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { #if defined(GI_NEON64_INTRINSICS) return vzip1q_f32(Vector1, Vector2); @@ -1357,3 +1392,70 @@ GI_FORCEINLINE float32x2_t GiPmaxFloat32(float32x2_t a, float32x2_t b) { return res; #endif } + +GI_FORCEINLINE +GI_FLOAT32_V3_t GiLoadUzipFloat32V3(const float* ptr) { +#if defined(GI_NEON_INTRINSICS) + return vld3q_f32(ptr); +#elif defined(GI_SSE2_INTRINSICS) + GI_FLOAT32_V3_t v; + __m128 tmp0, tmp1, tmp2, tmp3; + v.val[0] = GiLoadFloat32(ptr); + v.val[1] = GiLoadFloat32((ptr + 4)); + v.val[2] = GiLoadFloat32((ptr + 8)); + + tmp0 = _mm_castsi128_ps(_mm_shuffle_epi32( + _mm_castps_si128(v.val[0]), 0 | (3 << 2) | (1 << 4) | (2 << 6))); + tmp1 = _mm_castsi128_ps( + _mm_shuffle_epi32(_mm_castps_si128(v.val[1]), _SWAP_HI_LOW32)); + tmp2 = _mm_castsi128_ps(_mm_shuffle_epi32( + _mm_castps_si128(v.val[2]), 1 | (2 << 2) | (0 << 4) | (3 << 6))); + tmp3 = _mm_unpacklo_ps(tmp1, tmp2); + + v.val[0] = _mm_movelh_ps(tmp0, tmp3); + tmp0 = _mm_unpackhi_ps(tmp0, tmp1); + v.val[1] = + _mm_castsi128_ps(_mm_shuffle_epi32(_mm_castps_si128(tmp0), _SWAP_HI_LOW32)); + v.val[1] = _mm_movehl_ps(tmp3, v.val[1]); + v.val[2] = _mm_movehl_ps(tmp2, tmp0); + return v; +#else + GI_FLOAT32_V3_t ret; + for (size_t i = 0; i < 3; i++) { + ret.val[i][0] = ptr[0 + i]; + ret.val[i][1] = ptr[3 + i]; + ret.val[i][2] = ptr[6 + i]; + ret.val[i][3] = ptr[9 + i]; + } + + return ret; +#endif +} + +GI_FORCEINLINE +void GiStoreZipFloat32V3(float* ptr, GI_FLOAT32_V3_t val) { +#if defined(GI_NEON_INTRINSICS) + vst3q_f32(ptr, val); +#elif defined(GI_SSE2_INTRINSICS) + GI_FLOAT32_V3_t v; + __m128 tmp0, tmp1, tmp2; + tmp0 = _mm_unpacklo_ps(val.val[0], val.val[1]); + tmp1 = _mm_unpackhi_ps(val.val[0], val.val[1]); + tmp2 = _mm_unpacklo_ps(val.val[1], val.val[2]); + v.val[1] = _mm_shuffle_ps(tmp2, tmp1, _MM_SHUFFLE(1, 0, 3, 2)); + v.val[2] = _mm_movehl_ps(val.val[2], tmp1); + v.val[2] = _mm_shuffle_ps(v.val[2], v.val[2], _MM_SHUFFLE(3, 1, 0, 2)); + tmp1 = _mm_unpacklo_ps(tmp2, val.val[0]); + v.val[0] = _mm_shuffle_ps(tmp0, tmp1, _MM_SHUFFLE(3, 2, 1, 0)); + + GiStoreFloat32(ptr, v.val[0]); + GiStoreFloat32((ptr + 4), v.val[1]); + GiStoreFloat32((ptr + 8), 
v.val[2]); +#else + for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(float); i++) { + *ptr++ = val.val[0][i]; + *ptr++ = val.val[1][i]; + *ptr++ = val.val[2][i]; + } +#endif +} diff --git a/dnn/src/fallback/matrix_mul/gi/fp32/exec_sgemv.cpp b/dnn/src/fallback/matrix_mul/gi/fp32/exec_sgemv.cpp index 2034a0fc..098da14c 100644 --- a/dnn/src/fallback/matrix_mul/gi/fp32/exec_sgemv.cpp +++ b/dnn/src/fallback/matrix_mul/gi/fp32/exec_sgemv.cpp @@ -45,17 +45,7 @@ void sgemv_gi_naive_n_mk4( while (k < K) { GI_FLOAT32_t b = GiLoadFloat32(Bptr); GI_FLOAT32_V2_t a[2]; -#if defined(GI_TEST_NAIVE) -#define LOAD_A(step) \ - a[step].val[0] = GiLoadFloat32(Aptr0 + step * 8); \ - a[step].val[1] = GiLoadFloat32(Aptr0 + step * 8 + 4); -#elif defined(__arm__) || defined(__aarch64__) -#define LOAD_A(step) a[step] = vld1q_f32_x2(Aptr0 + step * 8); -#else -#define LOAD_A(step) \ - a[step].val[0] = GiLoadFloat32(Aptr0 + step * 8); \ - a[step].val[1] = GiLoadFloat32(Aptr0 + step * 8 + 4); -#endif +#define LOAD_A(step) a[step] = GiLoadFloat32V2(Aptr0 + step * 8); UNROLL_CALL_RAW(2, LOAD_A) #undef LOAD_A diff --git a/dnn/src/fallback/pooling/opr_impl.h b/dnn/src/fallback/pooling/opr_impl.h index de0d12ff..aa07b6c2 100644 --- a/dnn/src/fallback/pooling/opr_impl.h +++ b/dnn/src/fallback/pooling/opr_impl.h @@ -34,6 +34,8 @@ private: _megdnn_tensor_in src, _megdnn_tensor_out dst, const Param& param); void exec_w2x2_s2x2_int8(_megdnn_tensor_in src, _megdnn_tensor_out dst); void exec_w2x2_s2x2_avg_int8(_megdnn_tensor_in src, _megdnn_tensor_out dst); + void exec_fallback( + _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace); public: using naive::PoolingForwardImpl::PoolingForwardImpl; @@ -43,9 +45,6 @@ public: _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) override; - void exec_fallback( - _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace); - size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&) override; static size_t constexpr MAX_SPATIAL_DIM = 2; diff --git a/dnn/src/fallback/resize/gi/direct_nchwxx.cpp b/dnn/src/fallback/resize/gi/direct_nchwxx.cpp new file mode 100644 index 00000000..efcb7551 --- /dev/null +++ b/dnn/src/fallback/resize/gi/direct_nchwxx.cpp @@ -0,0 +1,69 @@ +#include "src/fallback/resize/gi/direct_nchwxx.h" + +using namespace megdnn; +using namespace fallback; +using namespace resize; + +namespace { + +template +void resize_direct_nchwxx( + const ctype* sptr, ctype* dptr, size_t N, size_t IH, size_t IW, size_t OH, + size_t OW) { + using simd_helper = SIMDHelper; + constexpr size_t PC = simd_helper::simd_width; + using simd_type = typename simd_helper::simd_type; + + float scale_h = static_cast(OH) / IH; + float scale_w = static_cast(OW) / IW; + + for (size_t n = 0; n < N; ++n) { + for (size_t oh = 0; oh < OH; ++oh) { + for (size_t ow = 0; ow < OW; ++ow) { + int ih0, ih1, iw0, iw1; + float ah0, ah1, aw0, aw1; + + std::tie(ah0, ih0, ah1, ih1) = + get_nearest_linear_coord(imode, scale_h, IH, oh); + std::tie(aw0, iw0, aw1, iw1) = + get_nearest_linear_coord(imode, scale_w, IW, ow); + + simd_type r0 = simd_helper::load(sptr + (ih0 * IW + iw0) * PC); + simd_type r1 = simd_helper::load(sptr + (ih0 * IW + iw1) * PC); + simd_type r2 = simd_helper::load(sptr + (ih1 * IW + iw0) * PC); + simd_type r3 = simd_helper::load(sptr + (ih1 * IW + iw1) * PC); + + // FIXME: weight fp16 may cause precision problem + ctype a0 = ah0 * aw0; + ctype a1 = ah0 * aw1; + ctype a2 = ah1 * aw0; + ctype a3 = ah1 * aw1; + + simd_type c = 
simd_helper::dup(0); + c = simd_helper::fma(c, r0, a0); + c = simd_helper::fma(c, r1, a1); + c = simd_helper::fma(c, r2, a2); + c = simd_helper::fma(c, r3, a3); + + simd_helper::store(dptr + (oh * OW + ow) * PC, c); + } + } + sptr += IH * IW * PC; + dptr += OH * OW * PC; + } +} +} // namespace + +void megdnn::fallback::resize_direct_nearest_nchw44_gi_fp32( + const ResizeImpl::KernParam& kern_param) { + resize_direct_nchwxx( + kern_param.src(), kern_param.dst(), kern_param.n * kern_param.c / 4, + kern_param.ih, kern_param.iw, kern_param.oh, kern_param.ow); +} + +void megdnn::fallback::resize_direct_linear_nchw44_gi_fp32( + const ResizeImpl::KernParam& kern_param) { + resize_direct_nchwxx( + kern_param.src(), kern_param.dst(), kern_param.n * kern_param.c / 4, + kern_param.ih, kern_param.iw, kern_param.oh, kern_param.ow); +} diff --git a/dnn/src/fallback/resize/gi/direct_nchwxx.h b/dnn/src/fallback/resize/gi/direct_nchwxx.h new file mode 100644 index 00000000..507b22ce --- /dev/null +++ b/dnn/src/fallback/resize/gi/direct_nchwxx.h @@ -0,0 +1,15 @@ +#pragma once +#include "src/fallback/resize/gi/helper.h" +#include "src/fallback/resize/opr_impl.h" + +namespace megdnn { +namespace fallback { + +void resize_direct_linear_nchw44_gi_fp32( + const ResizeImpl::KernParam& kern_param); + +void resize_direct_nearest_nchw44_gi_fp32( + const ResizeImpl::KernParam& kern_param); + +} // namespace fallback +} // namespace megdnn diff --git a/dnn/src/fallback/resize/gi/helper.h b/dnn/src/fallback/resize/gi/helper.h new file mode 100644 index 00000000..7ce5bc53 --- /dev/null +++ b/dnn/src/fallback/resize/gi/helper.h @@ -0,0 +1,76 @@ +#pragma once +#include "src/fallback/general_intrinsic/gi_float.h" +#include "src/fallback/resize/opr_impl.h" + +namespace megdnn { +namespace fallback { +namespace resize { + +using InterpolationMode = Resize::InterpolationMode; + +template +struct SIMDHelper {}; + +template <> +struct SIMDHelper { + using simd_type = GI_FLOAT32_t; + using simd_type_x2 = GI_FLOAT32_V2_t; + using ctype = float; + static constexpr size_t simd_width = 4; + + static GI_FORCEINLINE simd_type load(const ctype* src_ptr) { + return GiLoadFloat32(src_ptr); + } + static GI_FORCEINLINE void store(ctype* dst_ptr, const simd_type& rdst) { + GiStoreFloat32(dst_ptr, rdst); + } + static GI_FORCEINLINE void store2_interleave( + ctype* dst_ptr, const simd_type& rdst1, const simd_type& rdst2) { + simd_type_x2 rdst; + rdst.val[0] = rdst1; + rdst.val[1] = rdst2; + GiStoreZipFloat32V2(dst_ptr, rdst); + } + static GI_FORCEINLINE simd_type + fma(const simd_type& a, const simd_type& b, ctype n) { + return GiMultiplyAddScalarFloat32(a, b, n); + } + static GI_FORCEINLINE simd_type + fma(const simd_type& a, const simd_type& b, const simd_type& c) { + return GiMlaqFloat32(a, b, c); + } + static GI_FORCEINLINE simd_type dup(float val) { return GiBroadcastFloat32(val); } +}; + +static GI_FORCEINLINE int get_nearest_src(float scale, int size, int idx) { + return std::min(static_cast(idx / scale), size - 1); +} + +static GI_FORCEINLINE std::tuple get_nearest_linear_coord( + InterpolationMode imode, float scale, int size, int idx) { + if (size == 1) { + return std::make_tuple(1.0f, 0, 0.0f, 0); + } + + float alpha = (idx + 0.5f) / scale - 0.5f; + int origin_idx = static_cast(floor(alpha)); + alpha -= origin_idx; + + if (imode == InterpolationMode::INTER_NEAREST) { + origin_idx = get_nearest_src(scale, size, idx); + alpha = 0; + } + + if (origin_idx < 0) { + origin_idx = 0; + alpha = 0; + } else if (origin_idx + 1 >= size) { + 
origin_idx = size - 2; + alpha = 1; + } + + return std::make_tuple(1 - alpha, origin_idx, alpha, origin_idx + 1); +} +}; // namespace resize +}; // namespace fallback +}; // namespace megdnn diff --git a/dnn/src/fallback/resize/gi/resize_cv.cpp b/dnn/src/fallback/resize/gi/resize_cv.cpp new file mode 100644 index 00000000..17b2867f --- /dev/null +++ b/dnn/src/fallback/resize/gi/resize_cv.cpp @@ -0,0 +1,1434 @@ +/** + * By downloading, copying, installing or using the software you agree to this license. + * If you do not agree to this license, do not download, install, + * copy or use the software. + * + * + * License Agreement + * For Open Source Computer Vision Library + * (3-clause BSD License) + * + * Copyright (C) 2000-2020, Intel Corporation, all rights reserved. + * Copyright (C) 2009-2011, Willow Garage Inc., all rights reserved. + * Copyright (C) 2009-2016, NVIDIA Corporation, all rights reserved. + * Copyright (C) 2010-2013, Advanced Micro Devices, Inc., all rights reserved. + * Copyright (C) 2015-2016, OpenCV Foundation, all rights reserved. + * Copyright (C) 2015-2016, Itseez Inc., all rights reserved. + * Copyright (C) 2019-2020, Xperience AI, all rights reserved. + * Third party copyrights are property of their respective owners. + * + * Redistribution and use in source and binary forms, with or without modification, + * are permitted provided that the following conditions are met: + * + * * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * * Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * * Neither the names of the copyright holders nor the names of the contributors + * may be used to endorse or promote products derived from this software + * without specific prior written permission. + * + * This software is provided by the copyright holders and contributors "as is" and + * any express or implied warranties, including, but not limited to, the implied + * warranties of merchantability and fitness for a particular purpose are disclaimed. + * In no event shall copyright holders or contributors be liable for any direct, + * indirect, incidental, special, exemplary, or consequential damages + * (including, but not limited to, procurement of substitute goods or services; + * loss of use, data, or profits; or business interruption) however caused + * and on any theory of liability, whether in contract, strict liability, + * or tort (including negligence or otherwise) arising in any way out of + * the use of this software, even if advised of the possibility of such damage. 
+ * + * --------------------------------------------------------------------------- + */ +#include "src/fallback/resize/gi/resize_cv.h" +#include +#include "src/common/cv/common.h" +#include "src/common/cv/helper.h" +#include "src/common/utils.h" +#include "src/fallback/handle.h" +#include "src/fallback/resize/opr_impl.h" + +#include "midout.h" + +MIDOUT_DECL(megdnn_fallback_resizecv_imode) +MIDOUT_DECL(megdnn_fallback_resizecv_dtype) + +using namespace megdnn; +using namespace fallback; +using namespace megcv; + +namespace { + +using InterpolationMode = param::Resize::InterpolationMode; +using IMode = InterpolationMode; + +void resize_nearest_32f(const Mat32f& src, Mat32f& dst) { + AlignedVector tabx(dst.rows()); + AlignedVector taby(dst.cols()); + const double fx = static_cast(dst.rows()) / src.rows(); + const double fy = static_cast(dst.cols()) / src.cols(); + const double ifx = 1.0f / fx; + const double ify = 1.0f / fy; + const size_t ch = src.channels(); + for (size_t dx = 0; dx < tabx.size(); ++dx) { + double rx = dx * ifx; + int sx = static_cast(floor(rx)); + sx = megcv::saturate(sx, 0, static_cast(src.rows())); + tabx[dx] = sx; + } + for (size_t dy = 0; dy < taby.size(); ++dy) { + double ry = dy * ify; + int sy = static_cast(floor(ry)); + sy = megcv::saturate(sy, 0, static_cast(src.cols())); + taby[dy] = sy; + } + // taby[taby.size() - 1] = src.cols() - 1; + size_t tabxsize = tabx.size(); + size_t tabysize = taby.size(); + if (ch == 1) { + for (size_t dx = 0; dx < tabxsize; ++dx) { + float* pdst = dst.ptr(dx); + const float* psrc = src.ptr(tabx[dx]); + size_t dy = 0; + for (; dy < tabysize; dy++) { + const float* pcsrc = psrc + taby[dy]; + pdst[dy] = pcsrc[0]; + } + } + } else if (ch == 3) { + for (size_t dx = 0; dx < tabxsize; ++dx) { + float* pdst = dst.ptr(dx); + const float* psrc = src.ptr(tabx[dx]); + size_t dy3 = 0; + for (size_t dy = 0; dy < tabysize; ++dy, dy3 += 3) { + float* pcdst = pdst + dy3; + const float* pcsrc = psrc + taby[dy] * 3; + pcdst[0] = pcsrc[0]; + pcdst[1] = pcsrc[1]; + pcdst[2] = pcsrc[2]; + } + } + } +} + +// linear 32f +void build_tabs_linear_32f( + const Mat32f& src, const Mat32f& dst, AlignedVector& tabsx, + AlignedVector& tabsy, AlignedVector& tabrx, + AlignedVector& tabry) { + megdnn_assert(src.rows() >= 2); + megdnn_assert(src.cols() >= 2); + megdnn_assert(dst.rows() >= 2); + megdnn_assert(dst.cols() >= 2); + const float fx = static_cast(dst.rows()) / src.rows(); + const float fy = static_cast(dst.cols()) / src.cols(); + const float ifx = 1.0f / fx; + const float ify = 1.0f / fy; + for (size_t dx = 0; dx < dst.rows(); ++dx) { + float rx = (dx + 0.5f) * ifx - 0.5f; + int sx = static_cast(floor(rx)); + rx -= sx; + if (sx < 0) { + sx = 0; + rx = 0; + } else if (sx + 1 >= static_cast(src.rows())) { + sx = src.rows() - 2; + rx = 1; + } + tabsx[dx] = sx; + tabrx[dx] = rx; + } + for (size_t dy = 0; dy < dst.cols(); ++dy) { + float ry = (dy + 0.5f) * ify - 0.5f; + int sy = static_cast(floor(ry)); + ry -= sy; + if (sy < 0) { + sy = 0; + ry = 0; + } else if (sy + 1 >= static_cast(src.cols())) { + sy = src.cols() - 2; + ry = 1; + } + tabsy[dy] = sy; + tabry[dy] = ry; + } +} + +void calc_cache_linear_32fc1_1( + const Mat32f& src, const Mat32f& dst, const AlignedVector& tabsx, + const AlignedVector& tabsy, const AlignedVector& tabrx, + const AlignedVector& tabry, int dx, AlignedVector& cache0, + AlignedVector& cache1) { + (void)tabrx; + const float* psrc1 = src.ptr(tabsx[dx] + 1); + size_t dstcols = dst.cols(); + size_t dy = 0; + + // cache0 = cache1; + 
std::swap(cache0, cache1); + for (; dy < dstcols; ++dy) { + const float* pcsrc10 = psrc1 + (tabsy[dy] + 0); + const float* pcsrc11 = psrc1 + (tabsy[dy] + 1); + float ry = tabry[dy]; + float iry = 1.0f - ry; + cache1[dy] = pcsrc11[0] * ry + pcsrc10[0] * iry; + } +} + +void calc_cache_linear_32fc1_2( + const Mat32f& src, const Mat32f& dst, const AlignedVector& tabsx, + const AlignedVector& tabsy, const AlignedVector& tabrx, + const AlignedVector& tabry, int dx, AlignedVector& cache0, + AlignedVector& cache1) { + (void)tabrx; + const float* psrc0 = src.ptr(tabsx[dx] + 0); + const float* psrc1 = src.ptr(tabsx[dx] + 1); + int dstcols = dst.cols(); + int dy = 0; + + // 4 pixels each time + float* cache0_ptr = cache0.data(); + float* cache1_ptr = cache1.data(); + const float* tabry_ptr = tabry.data(); + for (; dy + 4 <= dstcols; dy += 4) { +#define EXPAND(dy) \ + { \ + int t0 = tabsy[dy + 0]; \ + int t1 = tabsy[dy + 1]; \ + int t2 = tabsy[dy + 2]; \ + int t3 = tabsy[dy + 3]; \ + const float pcsrc00[4] = { \ + psrc0[t0 + 0], psrc0[t1 + 0], psrc0[t2 + 0], psrc0[t3 + 0]}; \ + const float pcsrc01[4] = { \ + psrc0[t0 + 1], \ + psrc0[t1 + 1], \ + psrc0[t2 + 1], \ + psrc0[t3 + 1], \ + }; \ + const float pcsrc10[4] = { \ + psrc1[t0 + 0], \ + psrc1[t1 + 0], \ + psrc1[t2 + 0], \ + psrc1[t3 + 0], \ + }; \ + const float pcsrc11[4] = { \ + psrc1[t0 + 1], \ + psrc1[t1 + 1], \ + psrc1[t2 + 1], \ + psrc1[t3 + 1], \ + }; \ + GI_FLOAT32_t v_pcsrc00 = GiLoadFloat32(pcsrc00); \ + GI_FLOAT32_t v_pcsrc01 = GiLoadFloat32(pcsrc01); \ + GI_FLOAT32_t v_pcsrc10 = GiLoadFloat32(pcsrc10); \ + GI_FLOAT32_t v_pcsrc11 = GiLoadFloat32(pcsrc11); \ + GI_FLOAT32_t v_ry = GiLoadFloat32(tabry_ptr + dy); \ + GI_FLOAT32_t v_iry = GiSubtractFloat32(GiBroadcastFloat32(1.0f), v_ry); \ + GiStoreFloat32( \ + cache0_ptr + dy, \ + GiMlaqFloat32(GiMultiplyFloat32(v_pcsrc01, v_ry), v_pcsrc00, v_iry)); \ + GiStoreFloat32( \ + cache1_ptr + dy, \ + GiMlaqFloat32(GiMultiplyFloat32(v_pcsrc11, v_ry), v_pcsrc10, v_iry)); \ + } \ + while (0) + + EXPAND(dy); +#undef EXPAND + } + for (; dy < dstcols; ++dy) { + const float* pcsrc00 = psrc0 + (tabsy[dy] + 0); + const float* pcsrc01 = psrc0 + (tabsy[dy] + 1); + const float* pcsrc10 = psrc1 + (tabsy[dy] + 0); + const float* pcsrc11 = psrc1 + (tabsy[dy] + 1); + float ry = tabry[dy]; + float iry = 1.0f - ry; + cache0[dy] = pcsrc01[0] * ry + pcsrc00[0] * iry; + cache1[dy] = pcsrc11[0] * ry + pcsrc10[0] * iry; + } +} + +void calc_cache_linear_32fc3_1( + const Mat32f& src, const Mat32f& dst, const AlignedVector& tabsx, + const AlignedVector& tabsy, const AlignedVector& tabrx, + const AlignedVector& tabry, int dx, AlignedVector& cache0, + AlignedVector& cache1) { + (void)tabrx; + const float* psrc1 = src.ptr(tabsx[dx] + 1); + const size_t dstcols = dst.cols(); + size_t dy = 0, dy3 = 0; + + // cache0 = cache1; + std::swap(cache0, cache1); + for (; dy < dstcols; ++dy, dy3 += 3) { + const float* pcsrc10 = psrc1 + (tabsy[dy] + 0) * 3; + const float* pcsrc11 = psrc1 + (tabsy[dy] + 1) * 3; + float ry = tabry[dy]; + float iry = 1.0f - ry; + cache1[dy3 + 0] = pcsrc11[0] * ry + pcsrc10[0] * iry; + cache1[dy3 + 1] = pcsrc11[1] * ry + pcsrc10[1] * iry; + cache1[dy3 + 2] = pcsrc11[2] * ry + pcsrc10[2] * iry; + } +} + +void calc_cache_linear_32fc3_2( + const Mat32f& src, const Mat32f& dst, const AlignedVector& tabsx, + const AlignedVector& tabsy, const AlignedVector& tabrx, + const AlignedVector& tabry, int dx, AlignedVector& cache0, + AlignedVector& cache1) { + (void)tabrx; + const float* psrc0 = src.ptr(tabsx[dx] + 0); + 
const float* psrc1 = src.ptr(tabsx[dx] + 1); + int dstcols = dst.cols(); + int dy = 0, dy3 = 0; + + for (; dy < dstcols; ++dy, dy3 += 3) { + const float* pcsrc00 = psrc0 + (tabsy[dy] + 0) * 3; + const float* pcsrc01 = psrc0 + (tabsy[dy] + 1) * 3; + const float* pcsrc10 = psrc1 + (tabsy[dy] + 0) * 3; + const float* pcsrc11 = psrc1 + (tabsy[dy] + 1) * 3; + float ry = tabry[dy]; + float iry = 1.0f - ry; + cache0[dy3 + 0] = pcsrc01[0] * ry + pcsrc00[0] * iry; + cache1[dy3 + 0] = pcsrc11[0] * ry + pcsrc10[0] * iry; + cache0[dy3 + 1] = pcsrc01[1] * ry + pcsrc00[1] * iry; + cache1[dy3 + 1] = pcsrc11[1] * ry + pcsrc10[1] * iry; + cache0[dy3 + 2] = pcsrc01[2] * ry + pcsrc00[2] * iry; + cache1[dy3 + 2] = pcsrc11[2] * ry + pcsrc10[2] * iry; + } +} +void resize_linear_32f_gi(const Mat32f& src, Mat32f& dst) { + AlignedVector tabsx(dst.rows()); + AlignedVector tabsy(dst.cols()); + AlignedVector tabrx(dst.rows()); + AlignedVector tabry(dst.cols()); + build_tabs_linear_32f(src, dst, tabsx, tabsy, tabrx, tabry); + + if (src.channels() == 1) { + AlignedVector cache0(dst.cols()), cache1(dst.cols()); + int dstrows = dst.rows(); + int dstcols = dst.cols(); + for (int dx = 0; dx < dstrows; ++dx) { + if (dx == 0 || tabsx[dx] != tabsx[dx - 1]) { + if (dx > 0 && tabsx[dx] == tabsx[dx - 1] + 1) { + calc_cache_linear_32fc1_1( + src, dst, tabsx, tabsy, tabrx, tabry, dx, cache0, cache1); + } else { + calc_cache_linear_32fc1_2( + src, dst, tabsx, tabsy, tabrx, tabry, dx, cache0, cache1); + } + } + const float* cache0_ptr = cache0.data(); + const float* cache1_ptr = cache1.data(); + float rx = tabrx[dx]; + float irx = 1.0f - rx; + float* pdst = dst.ptr(dx); + int dy = 0; +#define EXPAND(x) \ + v_cache0 = GiLoadFloat32(cache0_ptr + dy + x); \ + v_cache1 = GiLoadFloat32(cache1_ptr + dy + x); \ + GiStoreFloat32( \ + pdst + dy + x, \ + GiMlaqFloat32(GiMultiplyFloat32(v_rx, v_cache1), v_irx, v_cache0)); + GI_FLOAT32_t v_rx = GiBroadcastFloat32(rx); + GI_FLOAT32_t v_irx = GiBroadcastFloat32(irx); + for (; dy + 8 <= dstcols; dy += 8) { + GI_FLOAT32_t v_cache0; + GI_FLOAT32_t v_cache1; + EXPAND(0); + EXPAND(4); + } + if (dy + 4 <= dstcols) { + GI_FLOAT32_t v_cache0; + GI_FLOAT32_t v_cache1; + EXPAND(0); + dy += 4; + } +#undef EXPAND + for (; dy < dstcols; ++dy) { + float* pcdst = pdst + dy; + pcdst[0] = rx * cache1[dy] + irx * cache0[dy]; + } + } + } else if (src.channels() == 3) { + int dstrows = dst.rows(); + int dstcols = dst.cols() * 3; + AlignedVector cache0(dstcols), cache1(dstcols); + for (int dx = 0; dx < dstrows; ++dx) { + if (dx == 0 || tabsx[dx] != tabsx[dx - 1]) { + if (dx > 0 && tabsx[dx] == tabsx[dx - 1] + 1) { + calc_cache_linear_32fc3_1( + src, dst, tabsx, tabsy, tabrx, tabry, dx, cache0, cache1); + } else { + calc_cache_linear_32fc3_2( + src, dst, tabsx, tabsy, tabrx, tabry, dx, cache0, cache1); + } + } + const float* cache0_ptr = cache0.data(); + const float* cache1_ptr = cache1.data(); + float rx = tabrx[dx]; + float irx = 1.0f - rx; + float* pdst = dst.ptr(dx); + int dy = 0; + GI_FLOAT32_t v_rx = GiBroadcastFloat32(rx); + GI_FLOAT32_t v_irx = GiBroadcastFloat32(irx); +#define EXPAND(x) \ + v_cache0 = GiLoadUzipFloat32V3(cache0_ptr + dy + (x)*3); \ + v_cache1 = GiLoadUzipFloat32V3(cache1_ptr + dy + (x)*3); \ + v_dst.val[0] = GiMlaqFloat32( \ + GiMultiplyFloat32(v_rx, v_cache1.val[0]), v_irx, v_cache0.val[0]); \ + v_dst.val[1] = GiMlaqFloat32( \ + GiMultiplyFloat32(v_rx, v_cache1.val[1]), v_irx, v_cache0.val[1]); \ + v_dst.val[2] = GiMlaqFloat32( \ + GiMultiplyFloat32(v_rx, v_cache1.val[2]), v_irx, 
v_cache0.val[2]); \ + GiStoreZipFloat32V3(pdst + dy + (x)*3, v_dst); + + for (; dy + 8 * 3 <= dstcols; dy += 8 * 3) { + GI_FLOAT32_V3_t v_cache0; + GI_FLOAT32_V3_t v_cache1; + GI_FLOAT32_V3_t v_dst; + + EXPAND(0); + EXPAND(4); + } + + if (dy + 4 * 3 <= dstcols) { + GI_FLOAT32_V3_t v_cache0; + GI_FLOAT32_V3_t v_cache1; + GI_FLOAT32_V3_t v_dst; + + EXPAND(0); + + dy += 4 * 3; + } +#undef EXPAND + for (; dy < dstcols; dy += 3) { + float* pcdst = pdst + dy; + pcdst[0] = rx * cache1[dy + 0] + irx * cache0[dy + 0]; + pcdst[1] = rx * cache1[dy + 1] + irx * cache0[dy + 1]; + pcdst[2] = rx * cache1[dy + 2] + irx * cache0[dy + 2]; + } + } + } else { + megdnn_throw(("nr. of channels must be 1 or 3.")); + } +} + +void resize_linear_32f(const Mat32f& src, Mat32f& dst) { + return resize_linear_32f_gi(src, dst); +} + +const int INTER_RESIZE_COEF_BITS = 11; +const int INTER_RESIZE_COEF_SCALE = 1 << INTER_RESIZE_COEF_BITS; +const float MEGCV_PI = acos(-1); +struct HResizeNoVec { + int operator()( + const float**, float**, int, const int*, const float*, int, int, int, int, + int) const { + return 0; + } +}; +template +struct ResizeAreaFastNoVec { + ResizeAreaFastNoVec(int, int) {} + ResizeAreaFastNoVec(int, int, int, int) {} + int operator()(const T*, T*, int) const { return 0; } +}; + +struct VResizeCubicVec_32f { + int operator()(const float** src, float* dst, const float* beta, int width) const { + const float *S0 = src[0], *S1 = src[1], *S2 = src[2], *S3 = src[3]; + int x = 0; + GI_FLOAT32_t v_b0 = GiBroadcastFloat32(beta[0]), + v_b1 = GiBroadcastFloat32(beta[1]), + v_b2 = GiBroadcastFloat32(beta[2]), + v_b3 = GiBroadcastFloat32(beta[3]); + + for (; x <= width - 8; x += 8) { + GI_FLOAT32_t s0_t = GiLoadFloat32(S0 + x); + GI_FLOAT32_t s1_t = GiLoadFloat32(S1 + x); + GI_FLOAT32_t s2_t = GiLoadFloat32(S2 + x); + GI_FLOAT32_t s3_t = GiLoadFloat32(S3 + x); + GI_FLOAT32_t tmp = GiMlaqFloat32( + GiMlaqFloat32( + GiMlaqFloat32(GiMultiplyFloat32(v_b0, s0_t), v_b1, s1_t), + v_b2, s2_t), + v_b3, s3_t); + GiStoreFloat32(dst + x, tmp); + + s0_t = GiLoadFloat32(S0 + x + 4); + s1_t = GiLoadFloat32(S1 + x + 4); + s2_t = GiLoadFloat32(S2 + x + 4); + s3_t = GiLoadFloat32(S3 + x + 4); + tmp = GiMlaqFloat32( + GiMlaqFloat32( + GiMlaqFloat32(GiMultiplyFloat32(v_b0, s0_t), v_b1, s1_t), + v_b2, s2_t), + v_b3, s3_t); + GiStoreFloat32(dst + x + 4, tmp); + } + + return x; + } +}; + +struct VResizeLanczos4Vec_32f { + int operator()(const float** src, float* dst, const float* beta, int width) const { + const float *S0 = src[0], *S1 = src[1], *S2 = src[2], *S3 = src[3], + *S4 = src[4], *S5 = src[5], *S6 = src[6], *S7 = src[7]; + int x = 0; + GI_FLOAT32_t v_b0 = GiBroadcastFloat32(beta[0]), + v_b1 = GiBroadcastFloat32(beta[1]), + v_b2 = GiBroadcastFloat32(beta[2]), + v_b3 = GiBroadcastFloat32(beta[3]), + v_b4 = GiBroadcastFloat32(beta[4]), + v_b5 = GiBroadcastFloat32(beta[5]), + v_b6 = GiBroadcastFloat32(beta[6]), + v_b7 = GiBroadcastFloat32(beta[7]); + + for (; x <= width - 4; x += 4) { + GI_FLOAT32_t v_dst0 = GiMlaqFloat32( + GiMlaqFloat32( + GiMlaqFloat32( + GiMultiplyFloat32(v_b0, GiLoadFloat32(S0 + x)), + v_b1, GiLoadFloat32(S1 + x)), + v_b2, GiLoadFloat32(S2 + x)), + v_b3, GiLoadFloat32(S3 + x)); + GI_FLOAT32_t v_dst1 = GiMlaqFloat32( + GiMlaqFloat32( + GiMlaqFloat32( + GiMultiplyFloat32(v_b4, GiLoadFloat32(S4 + x)), + v_b5, GiLoadFloat32(S5 + x)), + v_b6, GiLoadFloat32(S6 + x)), + v_b7, GiLoadFloat32(S7 + x)); + GiStoreFloat32(dst + x, GiAddFloat32(v_dst0, v_dst1)); + } + + return x; + } +}; +struct VResizeLinearVec_32f { + 
int operator()(const float** src, float* dst, const float* beta, int width) const { + const float *S0 = src[0], *S1 = src[1]; + int x = 0; + + GI_FLOAT32_t v_b0 = GiBroadcastFloat32(beta[0]), + v_b1 = GiBroadcastFloat32(beta[1]); + + for (; x <= width - 8; x += 8) { + GI_FLOAT32_t v_src00 = GiLoadFloat32(S0 + x), + v_src01 = GiLoadFloat32(S0 + x + 4); + GI_FLOAT32_t v_src10 = GiLoadFloat32(S1 + x), + v_src11 = GiLoadFloat32(S1 + x + 4); + + GiStoreFloat32( + dst + x, + GiMlaqFloat32(GiMultiplyFloat32(v_src00, v_b0), v_src10, v_b1)); + GiStoreFloat32( + dst + x + 4, + GiMlaqFloat32(GiMultiplyFloat32(v_src01, v_b0), v_src11, v_b1)); + } + + return x; + } +}; + +typedef HResizeNoVec HResizeLinearVec_32f; + +struct ResizeAreaFastVec_SIMD_32f { + ResizeAreaFastVec_SIMD_32f(int _scale_x, int _scale_y, int _cn, int _step) + : scale_x(_scale_x), + scale_y(_scale_y), + cn(_cn), + step(_step * sizeof(float)) { + fast_mode = scale_x == 2 && scale_y == 2 && (cn == 1 || cn == 3 || cn == 4); + } + + int operator()(const float* S, float* D, int w) const { + if (!fast_mode) + return 0; + + const float *S0 = S, *S1 = (const float*)((const float*)(S0) + step); + int dx = 0; + + GI_FLOAT32_t v_025 = GiBroadcastFloat32(0.25f); + + if (cn == 1) { + for (; dx <= w - 4; dx += 4, S0 += 8, S1 += 8, D += 4) { + GI_FLOAT32_V2_t v_row0 = GiLd2qFloat32(S0), v_row1 = GiLd2qFloat32(S1); + + GI_FLOAT32_t v_dst0 = GiAddFloat32(v_row0.val[0], v_row0.val[1]); + GI_FLOAT32_t v_dst1 = GiAddFloat32(v_row1.val[0], v_row1.val[1]); + + GiStoreFloat32( + D, GiMultiplyFloat32(GiAddFloat32(v_dst0, v_dst1), v_025)); + } + } + + return dx; + } + +private: + int scale_x, scale_y; + int cn; + bool fast_mode; + int step; +}; + +struct DecimateAlpha { + int si, di; + float alpha; +}; +template +using ResizeFunc = void (*)( + const Mat& src, Mat& dst, const int* xofs, const void* alpha, + const int* yofs, const void* beta, int xmin, int xmax, int ksize); +template +using ResizeAreaFastFunc = void (*)( + const Mat& src, Mat& dst, const int* ofs, const int* xofs, int scale_x, + int scale_y); +template +using ResizeAreaFunc = void (*)( + const Mat& src, Mat& dst, const DecimateAlpha* xtab, int xtab_size, + const DecimateAlpha* ytab, int ytab_size, const int* yofs); + +static GI_FORCEINLINE void interpolate_cubic(float x, float* coeffs) { + const float A = -0.75f; + + coeffs[0] = ((A * (x + 1) - 5 * A) * (x + 1) + 8 * A) * (x + 1) - 4 * A; + coeffs[1] = ((A + 2) * x - (A + 3)) * x * x + 1; + coeffs[2] = ((A + 2) * (1 - x) - (A + 3)) * (1 - x) * (1 - x) + 1; + coeffs[3] = 1.f - coeffs[0] - coeffs[1] - coeffs[2]; +} +static GI_FORCEINLINE void interpolate_lanczos4(float x, float* coeffs) { + static const double s45 = 0.70710678118654752440084436210485; + static const double cs[][2] = {{1, 0}, {-s45, -s45}, {0, 1}, {s45, -s45}, + {-1, 0}, {s45, s45}, {0, -1}, {-s45, s45}}; + + if (x < FLT_EPSILON) { + for (int i = 0; i < 8; i++) + coeffs[i] = 0; + coeffs[3] = 1; + return; + } + + float sum = 0; + double y0 = -(x + 3) * MEGCV_PI * 0.25, s0 = sin(y0), c0 = cos(y0); + for (int i = 0; i < 8; i++) { + double y = -(x + 3 - i) * MEGCV_PI * 0.25; + coeffs[i] = (float)((cs[i][0] * s0 + cs[i][1] * c0) / (y * y)); + sum += coeffs[i]; + } + + sum = 1.f / sum; + for (int i = 0; i < 8; i++) + coeffs[i] *= sum; +} + +template +struct HResizeLanczos4 { + typedef T value_type; + typedef WT buf_type; + typedef AT alpha_type; + + void operator()( + const T** src, WT** dst, int count, const int* xofs, const AT* alpha, + int swidth, int dwidth, int cn, int xmin, int 
xmax) const { + for (int k = 0; k < count; k++) { + const T* S = src[k]; + WT* D = dst[k]; + int dx = 0, limit = xmin; + if (cn == 1) { + for (;;) { + for (; dx < limit; dx++, alpha += 8) { + int j, sx = xofs[dx] - 1 * 3; + WT v = 0; + for (j = 0; j < 8; j++) { + int sxj = sx + j * 1; + if ((unsigned)sxj >= (unsigned)swidth) { + while (sxj < 0) + sxj += 1; + while (sxj >= swidth) + sxj -= 1; + } + v += S[sxj] * alpha[j]; + } + D[dx] = v; + } + if (limit == dwidth) + break; + for (; dx < xmax; dx++, alpha += 8) { + int sx = xofs[dx]; + D[dx] = S[sx - 1 * 3] * alpha[0] + S[sx - 1 * 2] * alpha[1] + + S[sx - 1] * alpha[2] + S[sx] * alpha[3] + + S[sx + 1] * alpha[4] + S[sx + 1 * 2] * alpha[5] + + S[sx + 1 * 3] * alpha[6] + S[sx + 1 * 4] * alpha[7]; + } + limit = dwidth; + } + } else { + megdnn_assert(cn == 3); + for (;;) { + for (; dx < limit; dx++, alpha += 8) { + int j, sx = xofs[dx] - 3 * 3; + WT v = 0; + for (j = 0; j < 8; j++) { + int sxj = sx + j * 3; + if ((unsigned)sxj >= (unsigned)swidth) { + while (sxj < 0) + sxj += 3; + while (sxj >= swidth) + sxj -= 3; + } + v += S[sxj] * alpha[j]; + } + D[dx] = v; + } + if (limit == dwidth) + break; + for (; dx < xmax; dx++, alpha += 8) { + int sx = xofs[dx]; + D[dx] = S[sx - 3 * 3] * alpha[0] + S[sx - 3 * 2] * alpha[1] + + S[sx - 3] * alpha[2] + S[sx] * alpha[3] + + S[sx + 3] * alpha[4] + S[sx + 3 * 2] * alpha[5] + + S[sx + 3 * 3] * alpha[6] + S[sx + 3 * 4] * alpha[7]; + } + limit = dwidth; + } + } + alpha -= dwidth * 8; + } + } +}; +template +struct HResizeLinear { + typedef T value_type; + typedef WT buf_type; + typedef AT alpha_type; + + void operator()( + const T** src, WT** dst, int count, const int* xofs, const AT* alpha, + int swidth, int dwidth, int cn, int xmin, int xmax) const { + int dx, k; + VecOp vecOp; + + int dx0 = + vecOp((const float**)src, (float**)dst, count, xofs, + (const float*)alpha, swidth, dwidth, cn, xmin, xmax); + + for (k = 0; k <= count - 2; k++) { + const T *S0 = src[k], *S1 = src[k + 1]; + WT *D0 = dst[k], *D1 = dst[k + 1]; + for (dx = dx0; dx < xmax; dx++) { + int sx = xofs[dx]; + WT a0 = alpha[dx * 2], a1 = alpha[dx * 2 + 1]; + WT t0 = S0[sx] * a0 + S0[sx + cn] * a1; + WT t1 = S1[sx] * a0 + S1[sx + cn] * a1; + D0[dx] = t0; + D1[dx] = t1; + } + + for (; dx < dwidth; dx++) { + int sx = xofs[dx]; + D0[dx] = WT(S0[sx] * ONE); + D1[dx] = WT(S1[sx] * ONE); + } + } + + for (; k < count; k++) { + const T* S = src[k]; + WT* D = dst[k]; + for (dx = 0; dx < xmax; dx++) { + int sx = xofs[dx]; + D[dx] = S[sx] * alpha[dx * 2] + S[sx + cn] * alpha[dx * 2 + 1]; + } + + for (; dx < dwidth; dx++) + D[dx] = WT(S[xofs[dx]] * ONE); + } + } +}; +template +struct HResizeCubic { + typedef T value_type; + typedef WT buf_type; + typedef AT alpha_type; + + void operator()( + const T** src, WT** dst, int count, const int* xofs, const AT* alpha, + int swidth, int dwidth, int cn, int xmin, int xmax) const { + for (int k = 0; k < count; k++) { + const T* S = src[k]; + WT* D = dst[k]; + int dx = 0, limit = xmin; + if (cn == 1) { + for (;;) { + for (; dx < limit; dx++, alpha += 4) { + int j, sx = xofs[dx] - 1; + WT v = 0; + for (j = 0; j < 4; j++) { + int sxj = sx + j * 1; + if ((unsigned)sxj >= (unsigned)swidth) { + while (sxj < 0) + sxj += 1; + while (sxj >= swidth) + sxj -= 1; + } + v += S[sxj] * alpha[j]; + } + D[dx] = v; + } + if (limit == dwidth) + break; + for (; dx < xmax; dx++, alpha += 4) { + int sx = xofs[dx]; + D[dx] = S[sx - 1] * alpha[0] + S[sx] * alpha[1] + + S[sx + 1] * alpha[2] + S[sx + 1 * 2] * alpha[3]; + } + limit = dwidth; + } 
+ } else { + megdnn_assert(cn == 3); + for (;;) { + for (; dx < limit; dx++, alpha += 4) { + int j, sx = xofs[dx] - 3; + WT v = 0; + for (j = 0; j < 4; j++) { + int sxj = sx + j * 3; + if ((unsigned)sxj >= (unsigned)swidth) { + while (sxj < 0) + sxj += 3; + while (sxj >= swidth) + sxj -= 3; + } + v += S[sxj] * alpha[j]; + } + D[dx] = v; + } + if (limit == dwidth) + break; + for (; dx < xmax; dx++, alpha += 4) { + int sx = xofs[dx]; + D[dx] = S[sx - 3] * alpha[0] + S[sx] * alpha[1] + + S[sx + 3] * alpha[2] + S[sx + 3 * 2] * alpha[3]; + } + limit = dwidth; + } + } + alpha -= dwidth * 4; + } + } +}; + +template +struct VResizeLanczos4 { + typedef T value_type; + typedef WT buf_type; + typedef AT alpha_type; + + void operator()(const WT** src, T* dst, const AT* beta, int width) const { + CastOp castOp; + VecOp vecOp; + int k, x = vecOp((const float**)src, (float*)dst, (const float*)beta, width); +#if MEGCV_ENABLE_UNROLLED + for (; x <= width - 4; x += 4) { + WT b = beta[0]; + const WT* S = src[0]; + WT s0 = S[x] * b, s1 = S[x + 1] * b, s2 = S[x + 2] * b, s3 = S[x + 3] * b; + + for (k = 1; k < 8; k++) { + b = beta[k]; + S = src[k]; + s0 += S[x] * b; + s1 += S[x + 1] * b; + s2 += S[x + 2] * b; + s3 += S[x + 3] * b; + } + + dst[x] = castOp(s0); + dst[x + 1] = castOp(s1); + dst[x + 2] = castOp(s2); + dst[x + 3] = castOp(s3); + } +#endif + + for (; x < width; x++) { + dst[x] = castOp( + src[0][x] * beta[0] + src[1][x] * beta[1] + src[2][x] * beta[2] + + src[3][x] * beta[3] + src[4][x] * beta[4] + src[5][x] * beta[5] + + src[6][x] * beta[6] + src[7][x] * beta[7]); + } + } +}; +template +struct VResizeLinear { + typedef T value_type; + typedef WT buf_type; + typedef AT alpha_type; + + void operator()(const WT** src, T* dst, const AT* beta, int width) const { + WT b0 = beta[0], b1 = beta[1]; + const WT *S0 = src[0], *S1 = src[1]; + CastOp castOp; + VecOp vecOp; + int x = vecOp((const float**)src, (float*)dst, (const float*)beta, width); +#if MEGCV_ENABLE_UNROLLED + for (; x <= width - 4; x += 4) { + WT t0, t1; + t0 = S0[x] * b0 + S1[x] * b1; + t1 = S0[x + 1] * b0 + S1[x + 1] * b1; + dst[x] = castOp(t0); + dst[x + 1] = castOp(t1); + t0 = S0[x + 2] * b0 + S1[x + 2] * b1; + t1 = S0[x + 3] * b0 + S1[x + 3] * b1; + dst[x + 2] = castOp(t0); + dst[x + 3] = castOp(t1); + } +#endif + for (; x < width; x++) + dst[x] = castOp(S0[x] * b0 + S1[x] * b1); + } +}; +template +struct VResizeCubic { + typedef T value_type; + typedef WT buf_type; + typedef AT alpha_type; + + void operator()(const WT** src, T* dst, const AT* beta, int width) const { + WT b0 = beta[0], b1 = beta[1], b2 = beta[2], b3 = beta[3]; + const WT *S0 = src[0], *S1 = src[1], *S2 = src[2], *S3 = src[3]; + CastOp castOp; + VecOp vecOp; + + int x = vecOp((const float**)src, (float*)dst, (const float*)beta, width); + for (; x < width; x++) + dst[x] = castOp(S0[x] * b0 + S1[x] * b1 + S2[x] * b2 + S3[x] * b3); + } +}; + +template +void resizeGeneric_( + const Mat& src, Mat& dst, const int* xofs, const void* _alpha, + const int* yofs, const void* _beta, int xmin, int xmax, int ksize) { + typedef typename HResize::value_type T; + typedef typename HResize::buf_type WT; + typedef typename HResize::alpha_type AT; + + const AT* beta = static_cast(_beta); + const AT* alpha = static_cast(_alpha); + int swidth = src.width(); + int sheight = src.height(); + int dwidth = dst.width(); + int dheight = dst.height(); + int cn = src.channels(); + swidth *= cn; + dwidth *= cn; + xmin *= cn; + xmax *= cn; + // image resize is a separable operation. 
In case of not too strong + // dsize.height + int dy; + HResize hresize; + VResize vresize; + + int bufstep = static_cast(align_size(dwidth, 16)); + AlignedVector _buffer(bufstep * ksize); + WT* buffer = _buffer.data(); + const T* srows[16] = {0}; + WT* rows[16] = {0}; + int prev_sy[16]; + + for (int k = 0; k < ksize; ++k) { + prev_sy[k] = -1; + rows[k] = buffer + bufstep * k; + } + + for (dy = 0; dy < dheight; ++dy, beta += ksize) { + int sy0 = yofs[dy], k0 = ksize, k1 = 0, ksize2 = ksize / 2; + + for (int k = 0; k < ksize; ++k) { + int sy = saturate(sy0 - ksize2 + 1 + k, 0, sheight); + for (k1 = std::max(k1, k); k1 < ksize; ++k1) { + if (sy == prev_sy[k1]) { + if (k1 > k) + memcpy(rows[k], rows[k1], bufstep * sizeof(rows[0][0])); + break; + } + } + if (k1 == ksize) + k0 = std::min(k0, k); + srows[k] = src.ptr(sy); + prev_sy[k] = sy; + } + if (k0 < ksize) + hresize(srows + k0, rows + k0, ksize - k0, xofs, alpha, swidth, dwidth, cn, + xmin, xmax); + vresize((const WT**)(rows), dst.ptr(dy), beta, dwidth); + } +} + +template +void setup_resize_env( + InterpolationMode /* ip */, int& /* ksize */, bool& /* fixedpt */, + ResizeFunc& /* func */) { + megdnn_throw(("unimplemented")); +} +template <> +void setup_resize_env( + InterpolationMode ip, int& ksize, bool& fixedpt, ResizeFunc& func) { + fixedpt = false; + switch (ip) { + case IMode::INTER_CUBIC: + ksize = 4; + func = resizeGeneric_< + HResizeCubic, + VResizeCubic< + float, float, float, Cast, + VResizeCubicVec_32f>, + float>; + break; + case IMode::INTER_LANCZOS4: + ksize = 8; + func = resizeGeneric_< + HResizeLanczos4, + VResizeLanczos4< + float, float, float, Cast, + VResizeLanczos4Vec_32f>, + float>; + break; + case IMode::INTER_LINEAR: + case IMode::INTER_AREA: + ksize = 2; + func = resizeGeneric_< + HResizeLinear, + VResizeLinear< + float, float, float, Cast, + VResizeLinearVec_32f>, + float>; + break; + default: + megdnn_throw(("unknown interpolation method")); + } +} + +int compute_resize_area_tab( + int ssize, int dsize, int cn, double scale, DecimateAlpha* tab) { + int k = 0; + for (int dx = 0; dx < dsize; dx++) { + double fsx1 = dx * scale; + double fsx2 = fsx1 + scale; + double cellWidth = std::min(scale, ssize - fsx1); + + int sx1 = ceil(fsx1), sx2 = floor(fsx2); + + sx2 = std::min(sx2, ssize - 1); + sx1 = std::min(sx1, sx2); + + if (sx1 - fsx1 > 1e-3) { + megdnn_assert(k < ssize * 2); + tab[k].di = dx * cn; + tab[k].si = (sx1 - 1) * cn; + tab[k++].alpha = (float)((sx1 - fsx1) / cellWidth); + } + + for (int sx = sx1; sx < sx2; sx++) { + megdnn_assert(k < ssize * 2); + tab[k].di = dx * cn; + tab[k].si = sx * cn; + tab[k++].alpha = float(1.0 / cellWidth); + } + + if (fsx2 - sx2 > 1e-3) { + megdnn_assert(k < ssize * 2); + tab[k].di = dx * cn; + tab[k].si = sx2 * cn; + tab[k++].alpha = + (float)(std::min(std::min(fsx2 - sx2, 1.), cellWidth) / cellWidth); + } + } + return k; +} + +// resize Area Fast +template +void resizeAreaFast_( + const Mat& src, Mat& dst, const int* ofs, const int* xofs, int scale_x, + int scale_y) { + // Range range(0, dst.rows); + int swidth = src.width(); + int sheight = src.height(); + int dwidth = dst.width(); + int dheight = dst.height(); + int cn = src.channels(); + int area = scale_x * scale_y; + float scale = 1.f / (area); + int dwidth1 = (swidth / scale_x) * cn; + dwidth *= cn; + swidth *= cn; + int dy, dx, k = 0; + + VecOp vop(scale_x, scale_y, src.channels(), (int)src.step()); + + for (dy = 0; dy < dheight; dy++) { + T* D = (T*)(dst.ptr(dy)); + int sy0 = dy * scale_y; + int w = sy0 + scale_y <= 
sheight ? dwidth1 : 0; + + if (sy0 >= sheight) { + for (dx = 0; dx < dwidth; dx++) + D[dx] = 0; + continue; + } + + dx = vop((const T*)(src.ptr(sy0)), D, w); + for (; dx < w; dx++) { + const T* S = (const T*)(src.ptr(sy0)) + xofs[dx]; + WT sum = 0; + k = 0; +#if MEGCV_ENABLE_UNROLLED + for (; k <= area - 4; k += 4) + sum += S[ofs[k]] + S[ofs[k + 1]] + S[ofs[k + 2]] + S[ofs[k + 3]]; +#endif + for (; k < area; k++) + sum += S[ofs[k]]; + + D[dx] = saturate_cast(sum * scale); + } + + for (; dx < dwidth; dx++) { + WT sum = 0; + int count = 0, sx0 = xofs[dx]; + if (sx0 >= swidth) + D[dx] = 0; + + for (int sy = 0; sy < scale_y; sy++) { + if (sy0 + sy >= sheight) + break; + const T* S = (const T*)(src.ptr(sy0 + sy)) + sx0; + for (int sx = 0; sx < scale_x * cn; sx += cn) { + if (sx0 + sx >= swidth) + break; + sum += S[sx]; + count++; + } + } + + D[dx] = saturate_cast((float)sum / count); + } + } +} + +template +ResizeAreaFastFunc get_resize_area_fast_func() { + megdnn_throw(("unknown type")); +} + +template <> +ResizeAreaFastFunc get_resize_area_fast_func() { + return resizeAreaFast_; +} + +// Resize Area +template +static void resizeArea_( + const Mat& src, Mat& dst, const DecimateAlpha* xtab, int xtab_size, + const DecimateAlpha* ytab, int ytab_size, const int* tabofs) { + // parallel_for_(Range(0, dst.rows), + // ResizeArea_Invoker(src, dst, xtab, xtab_size, ytab, ytab_size, + // tabofs), dst.total()/((double)(1 << 16))); + (void)ytab_size; + int dwidth = dst.width(), dheight = dst.height(); + int cn = dst.channels(); + dwidth *= cn; + AlignedVector _buffer(dwidth * 2); + WT *buf = _buffer.data(), *sum = buf + dwidth; + int j_start = tabofs[0], j_end = tabofs[dheight], j, k, dx, + prev_dy = ytab[j_start].di; + + for (dx = 0; dx < dwidth; dx++) + sum[dx] = (WT)0; + + for (j = j_start; j < j_end; j++) { + WT beta = ytab[j].alpha; + int dy = ytab[j].di; + int sy = ytab[j].si; + + { + const T* S = (const T*)(src.ptr(sy)); + for (dx = 0; dx < dwidth; dx++) + buf[dx] = (WT)0; + + if (cn == 1) + for (k = 0; k < xtab_size; k++) { + int dxn = xtab[k].di; + WT alpha = xtab[k].alpha; + buf[dxn] += S[xtab[k].si] * alpha; + } + else if (cn == 3) + for (k = 0; k < xtab_size; k++) { + int sxn = xtab[k].si; + int dxn = xtab[k].di; + WT alpha = xtab[k].alpha; + WT t0 = buf[dxn] + S[sxn] * alpha; + WT t1 = buf[dxn + 1] + S[sxn + 1] * alpha; + WT t2 = buf[dxn + 2] + S[sxn + 2] * alpha; + buf[dxn] = t0; + buf[dxn + 1] = t1; + buf[dxn + 2] = t2; + } + else { + megdnn_throw(("nr. 
of channels must be 1 or 3")); + } + } + + if (dy != prev_dy) { + T* D = dst.ptr(prev_dy); + + for (dx = 0; dx < dwidth; dx++) { + D[dx] = saturate_cast(sum[dx]); + sum[dx] = beta * buf[dx]; + } + prev_dy = dy; + } else { + for (dx = 0; dx < dwidth; dx++) + sum[dx] += beta * buf[dx]; + } + } + + { + T* D = dst.ptr(prev_dy); + for (dx = 0; dx < dwidth; dx++) + D[dx] = saturate_cast(sum[dx]); + } +} + +template +ResizeAreaFunc get_resize_area_func() { + megdnn_throw(("unknown type")); +} + +template <> +ResizeAreaFunc get_resize_area_func() { + return resizeArea_; +} + +template +void resize_opencv(const Mat& src, Mat& dst, InterpolationMode ip) { + // fake area mode missing here + int dwidth = dst.width(); + int dheight = dst.height(); + int swidth = src.width(); + int sheight = src.height(); + int xmin = 0, xmax = dwidth, width = dwidth * dst.channels(); + double inv_scale_x = static_cast(dwidth) / swidth; + double inv_scale_y = static_cast(dheight) / sheight; + double scale_x = 1.0 / inv_scale_x; + double scale_y = 1.0 / inv_scale_y; + int dx, sx, dy, sy, k; + float fx, fy; + int cn = src.channels(); + { + int iscale_x = saturate_cast(scale_x); + int iscale_y = saturate_cast(scale_y); + + bool is_area_fast = std::abs(scale_x - iscale_x) < DBL_EPSILON && + std::abs(scale_y - iscale_y) < DBL_EPSILON; + if (ip == IMode::INTER_LINEAR && is_area_fast && iscale_x == 2 && + iscale_y == 2) { + ip = IMode::INTER_AREA; + } + if (ip == IMode::INTER_AREA && scale_x >= 1 && scale_y >= 1) { + if (is_area_fast) { + int area = iscale_x * iscale_y; + size_t srcstep = src.step(); + AlignedVector _ofs(area + dwidth * cn); + int* ofs = _ofs.data(); + int* xofs = ofs + area; + ResizeAreaFastFunc func = + get_resize_area_fast_func(); /// need change + for (sy = 0, k = 0; sy < iscale_y; ++sy) + for (sx = 0; sx < iscale_x; ++sx) + ofs[k++] = static_cast(sy * srcstep + sx * cn); + for (dx = 0; dx < dwidth; ++dx) { + int j = dx * cn; + sx = iscale_x * j; + for (k = 0; k < cn; ++k) + xofs[j + k] = sx + k; + } + func(src, dst, ofs, xofs, iscale_x, iscale_y); + return; + } + ResizeAreaFunc func = get_resize_area_func(); + AlignedVector _xytab((swidth + sheight) * 2); + DecimateAlpha *xtab = _xytab.data(), *ytab = xtab + swidth * 2; + int xtab_size = compute_resize_area_tab(swidth, dwidth, cn, scale_x, xtab); + int ytab_size = compute_resize_area_tab(sheight, dheight, 1, scale_y, ytab); + AlignedVector _tabofs(dheight + 1); + int* tabofs = _tabofs.data(); + for (k = 0, dy = 0; k < ytab_size; ++k) { + if (k == 0 || ytab[k].di != ytab[k - 1].di) { + megdnn_assert(ytab[k].di == dy); + tabofs[dy++] = k; + } + } + tabofs[dy] = ytab_size; + func(src, dst, xtab, xtab_size, ytab, ytab_size, tabofs); + return; + } + } + bool area_mode = (ip == IMode::INTER_AREA); + int ksize, ksize2; + ResizeFunc func; + bool fixedpt; + setup_resize_env(ip, ksize, fixedpt, func); + ksize2 = ksize / 2; + AlignedVector _buffer( + (width + dst.height()) * (sizeof(int) + sizeof(float) * ksize)); + uchar* buffer = _buffer.data(); + int* xofs = static_cast(static_cast(buffer)); + int* yofs = xofs + width; + float* alpha = static_cast(static_cast(yofs + dst.height())); + short* ialpha = static_cast(static_cast(alpha)); + float* beta = alpha + width * ksize; + short* ibeta = static_cast(static_cast(beta)); + // float cbuf[16]; + float cbuf[16] = {0}; + for (dx = 0; dx < dwidth; ++dx) { + if (!area_mode) { + fx = (float)((dx + 0.5) * scale_x - 0.5); + sx = floor(fx); + fx -= sx; + } else { + sx = floor(dx * scale_x); + fx = (float)((dx + 1) - (sx + 1) 
* inv_scale_x); + fx = (fx <= 0 ? 0.0f : fx - floor(fx)); + } + + if (sx < ksize2 - 1) { + xmin = dx + 1; + if (sx < 0 && (ip != IMode::INTER_CUBIC && ip != IMode::INTER_LANCZOS4)) { + fx = 0; + sx = 0; + } + } + if (sx + ksize2 >= swidth) { + xmax = std::min(xmax, dx); + if (sx >= swidth - 1 && ip != IMode::INTER_CUBIC && + ip != IMode::INTER_LANCZOS4) { + fx = 0; + sx = swidth - 1; + } + } + int k; + for (k = 0, sx *= cn; k < cn; ++k) + xofs[dx * cn + k] = sx + k; + if (ip == IMode::INTER_CUBIC) { + interpolate_cubic(fx, cbuf); + } else if (ip == IMode::INTER_LANCZOS4) { + interpolate_lanczos4(fx, cbuf); + } else { + cbuf[0] = 1.0f - fx; + cbuf[1] = fx; + } + if (fixedpt) { + for (k = 0; k < ksize; ++k) { + ialpha[dx * cn * ksize + k] = + saturate_cast(cbuf[k] * INTER_RESIZE_COEF_SCALE); + } + for (; k < cn * ksize; ++k) { + ialpha[dx * cn * ksize + k] = ialpha[dx * cn * ksize + k - ksize]; + } + } else { + for (k = 0; k < ksize; ++k) { + alpha[dx * cn * ksize + k] = cbuf[k]; + } + for (; k < cn * ksize; ++k) { + alpha[dx * cn * ksize + k] = alpha[dx * cn * ksize + k - ksize]; + } + } + } + for (dy = 0; dy < dheight; ++dy) { + if (!area_mode) { + fy = static_cast((dy + 0.5) * scale_y - 0.5); + sy = floor(fy); + fy -= sy; + } else { + sy = floor(dy * scale_y); + fy = static_cast((dy + 1) - (sy + 1) * inv_scale_y); + fy = (fy <= 0 ? 0.0f : fy - floor(fy)); + } + yofs[dy] = sy; + if (ip == IMode::INTER_CUBIC) { + interpolate_cubic(fy, cbuf); + } else if (ip == IMode::INTER_LANCZOS4) { + interpolate_lanczos4(fy, cbuf); + } else { + cbuf[0] = 1.0f - fy; + cbuf[1] = fy; + } + if (fixedpt) { + for (int k = 0; k < ksize; ++k) { + ibeta[dy * ksize + k] = + saturate_cast(cbuf[k] * INTER_RESIZE_COEF_SCALE); + } + } else { + for (int k = 0; k < ksize; ++k) { + beta[dy * ksize + k] = cbuf[k]; + } + } + } + func(src, dst, xofs, + fixedpt ? static_cast(ialpha) : static_cast(alpha), yofs, + fixedpt ? 
static_cast(ibeta) : static_cast(beta), xmin, xmax, + ksize); +} + +} // anonymous namespace + +void megdnn::fallback::resize_cv_gi_exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, + param::Resize::InterpolationMode imode) { + megdnn_assert(src.layout[3] == 1 || src.layout[3] == 3, "unsupported src channel"); + for (size_t i = 0; i < src.layout.shape[0]; ++i) { + if (dst.layout.dtype == dtype::Float32()) { + MIDOUT_BEGIN(megdnn_fallback_resizecv_dtype, midout_iv(0)) { + Mat src_mat = TensorND2Mat(src, i); + Mat dst_mat = TensorND2Mat(dst, i); + switch (imode) { + case IMode::INTER_NEAREST: + MIDOUT_BEGIN(megdnn_fallback_resizecv_imode, midout_iv(0)) { + resize_nearest_32f(src_mat, dst_mat); + } + MIDOUT_END(); + break; + case IMode::INTER_LINEAR: + MIDOUT_BEGIN(megdnn_fallback_resizecv_imode, midout_iv(1)) { + resize_linear_32f(src_mat, dst_mat); + } + MIDOUT_END(); + break; + case IMode::INTER_CUBIC: + case IMode::INTER_LANCZOS4: + case IMode::INTER_AREA: + MIDOUT_BEGIN(megdnn_fallback_resizecv_imode, midout_iv(2)) { + resize_opencv(src_mat, dst_mat, imode); + } + MIDOUT_END(); + break; + default: + megdnn_throw("unsupported interpolation mode"); + break; + } + } + MIDOUT_END(); + } else { + megdnn_throw("Unsupported datatype of resize optr."); + } + } +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/resize/gi/resize_cv.h b/dnn/src/fallback/resize/gi/resize_cv.h new file mode 100644 index 00000000..503379b2 --- /dev/null +++ b/dnn/src/fallback/resize/gi/resize_cv.h @@ -0,0 +1,20 @@ +#include + +#include "src/common/cv/helper.h" +#include "src/fallback/resize/gi/helper.h" + +namespace megdnn { +namespace fallback { + +/** + * \fn resize_cv_exec + * \brief Used if the format is NHWC, transfer from megcv + */ +void resize_cv_gi_exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, + param::Resize::InterpolationMode imode); + +} // namespace fallback +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/resize/gi/upsample2_nchw.cpp b/dnn/src/fallback/resize/gi/upsample2_nchw.cpp new file mode 100644 index 00000000..126884b2 --- /dev/null +++ b/dnn/src/fallback/resize/gi/upsample2_nchw.cpp @@ -0,0 +1,189 @@ +#include "src/fallback/resize/gi/upsample2_nchw.h" + +using namespace megdnn; +using namespace fallback; +using namespace resize; + +namespace { + +template +static GI_FORCEINLINE ctype +compute_linear_element(const ctype src[4], const ctype alpha[2]) { + return src[0] * alpha[0 ^ fh] * alpha[0 ^ fw] + + src[1] * alpha[0 ^ fh] * alpha[1 ^ fw] + + src[2] * alpha[1 ^ fh] * alpha[0 ^ fw] + + src[3] * alpha[1 ^ fh] * alpha[1 ^ fw]; +} + +template +static GI_FORCEINLINE typename simd_helper::simd_type compute_linear_element_simd( + const typename simd_helper::simd_type src[4], + const typename simd_helper::simd_type alpha[2][2]) { + typename simd_helper::simd_type c = simd_helper::dup(0); + c = simd_helper::fma(c, src[0], alpha[0 ^ fh][0 ^ fw]); + c = simd_helper::fma(c, src[1], alpha[0 ^ fh][1 ^ fw]); + c = simd_helper::fma(c, src[2], alpha[1 ^ fh][0 ^ fw]); + c = simd_helper::fma(c, src[3], alpha[1 ^ fh][1 ^ fw]); + return c; +} + +template +static GI_FORCEINLINE void compute_linear_2x2_element( + const ctype* src, ctype* dst, size_t IW, size_t OW, const ctype alpha[2]) { + const ctype* src_ptr[4] = {src, src, src, src}; + + if (has_right) { + src_ptr[1] += 1; + src_ptr[3] += 1; + } + if (has_bottom) { + src_ptr[2] += IW; + src_ptr[3] += IW; + } + + ctype rsrc[4]; + rsrc[0] = *src_ptr[0]; + rsrc[1] = *src_ptr[1]; + rsrc[2] = *src_ptr[2]; + 
rsrc[3] = *src_ptr[3]; + + dst[0] = compute_linear_element(rsrc, alpha); + if (has_right) { + dst[1] = compute_linear_element(rsrc, alpha); + } + if (has_bottom) { + dst[OW] = compute_linear_element(rsrc, alpha); + } + if (has_right && has_bottom) { + dst[OW + 1] = compute_linear_element(rsrc, alpha); + } +} + +template +static GI_FORCEINLINE void compute_linear_2x2_element_simd( + const typename simd_helper::ctype* src, typename simd_helper::ctype* dst, + size_t IW, size_t OW, const typename simd_helper::simd_type alpha[2][2]) { + using simd_type = typename simd_helper::simd_type; + + simd_type rsrc[4]; + rsrc[0] = simd_helper::load(src); + rsrc[1] = simd_helper::load(src + 1); + rsrc[2] = simd_helper::load(src + IW); + rsrc[3] = simd_helper::load(src + IW + 1); + + simd_type rdst[4]; + rdst[0] = compute_linear_element_simd(rsrc, alpha); + rdst[1] = compute_linear_element_simd(rsrc, alpha); + rdst[2] = compute_linear_element_simd(rsrc, alpha); + rdst[3] = compute_linear_element_simd(rsrc, alpha); + + simd_helper::store2_interleave(dst, rdst[0], rdst[1]); + simd_helper::store2_interleave(dst + OW, rdst[2], rdst[3]); +} + +template +void linear_upsample2_nchw( + const ctype* src_ptr, ctype* dst_ptr, size_t N, size_t IH, size_t IW) { + using simd_helper = SIMDHelper; + size_t OW = IW * 2; + constexpr size_t PC = simd_helper::simd_width; + + ctype alpha[2] = {0.75, 0.25}; + + typename simd_helper::simd_type simd_alpha[2][2]; + simd_alpha[0][0] = simd_helper::dup(0.75 * 0.75); + simd_alpha[0][1] = simd_helper::dup(0.75 * 0.25); + simd_alpha[1][0] = simd_helper::dup(0.25 * 0.75); + simd_alpha[1][1] = simd_helper::dup(0.25 * 0.25); + + for (size_t i = 0; i < N; ++i) { + compute_linear_2x2_element( + src_ptr, dst_ptr, IW, OW, alpha); + { + for (size_t iw = 0; iw + 1 < IW; ++iw) { + compute_linear_2x2_element( + src_ptr + iw, dst_ptr + (iw * 2 + 1), IW, OW, alpha); + } + } + compute_linear_2x2_element( + src_ptr + (IW - 1), dst_ptr + (OW - 1), IW, OW, alpha); + dst_ptr += OW; + + for (size_t ih = 0; ih + 1 < IH; ++ih) { + compute_linear_2x2_element( + src_ptr, dst_ptr, IW, OW, alpha); + size_t iw = 0; + for (; iw + PC < IW; iw += PC) { + compute_linear_2x2_element_simd( + src_ptr + iw, dst_ptr + (iw * 2 + 1), IW, OW, simd_alpha); + } + for (; iw + 1 < IW; ++iw) { + compute_linear_2x2_element( + src_ptr + iw, dst_ptr + (iw * 2 + 1), IW, OW, alpha); + } + compute_linear_2x2_element( + src_ptr + (IW - 1), dst_ptr + (OW - 1), IW, OW, alpha); + + src_ptr += IW; + dst_ptr += 2 * OW; + } + + compute_linear_2x2_element( + src_ptr, dst_ptr, IW, OW, alpha); + { + for (size_t iw = 0; iw + 1 < IW; ++iw) { + compute_linear_2x2_element( + src_ptr + iw, dst_ptr + (iw * 2 + 1), IW, OW, alpha); + } + } + compute_linear_2x2_element( + src_ptr + (IW - 1), dst_ptr + (OW - 1), IW, OW, alpha); + src_ptr += IW; + dst_ptr += OW; + } +} + +template +void nearest_upsample2_nchw( + const ctype* src_ptr, ctype* dst_ptr, size_t N, size_t IH, size_t IW) { + using simd_helper = SIMDHelper; + size_t OW = IW * 2; + constexpr size_t PC = simd_helper::simd_width; + + for (size_t i = 0; i < N; ++i) { + for (size_t ih = 0; ih < IH; ++ih) { + size_t iw = 0; + for (; iw + PC - 1 < IW; iw += PC) { + typename simd_helper::simd_type r0 = simd_helper::load(src_ptr + iw); + + simd_helper::store2_interleave(dst_ptr + (iw * 2), r0, r0); + simd_helper::store2_interleave(dst_ptr + (OW + iw * 2), r0, r0); + } + for (; iw < IW; iw += 1) { + ctype v = src_ptr[iw]; + dst_ptr[iw * 2] = v; + dst_ptr[iw * 2 + 1] = v; + dst_ptr[OW + iw * 2] = v; + 
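// Editor's note (illustration only): why alpha[] is {0.75, 0.25}. With the
// half-pixel mapping, output sample od sits at source coordinate
// (od + 0.5) / 2 - 0.5, so its fractional offset is always 0.25 or 0.75 and
// every output is a 0.75/0.25 blend of its two nearest source samples.
// Minimal 1-D scalar reference; border clamping is assumed, which corresponds
// to the separate first/last-column calls in linear_upsample2_nchw above.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

std::vector<float> upsample2_linear_1d(const std::vector<float>& src) {
    const long IW = static_cast<long>(src.size());
    std::vector<float> dst(IW * 2);
    for (long od = 0; od < IW * 2; ++od) {
        float fx = (od + 0.5f) * 0.5f - 0.5f;        // source coordinate
        long sx = static_cast<long>(std::floor(fx));
        float w1 = fx - sx;                          // 0.25 or 0.75
        long s0 = std::max(sx, 0L);                  // clamp left border
        long s1 = std::min(sx + 1, IW - 1);          // clamp right border
        dst[od] = src[s0] * (1.0f - w1) + src[s1] * w1;
    }
    return dst;
}

int main() {
    for (float v : upsample2_linear_1d({0.f, 1.f, 2.f, 3.f}))
        std::printf("%.2f ", v);  // 0.00 0.25 0.75 1.25 1.75 2.25 2.75 3.00
    std::printf("\n");
    return 0;
}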
dst_ptr[OW + iw * 2 + 1] = v; + } + src_ptr += IW; + dst_ptr += 2 * OW; + } + } +} + +} // namespace + +void megdnn::fallback::resize_linear_upsample2_nchw_gi_fp32( + const ResizeImpl::KernParam& kern_param) { + linear_upsample2_nchw( + kern_param.src(), kern_param.dst(), kern_param.n * kern_param.c, + kern_param.ih, kern_param.iw); +} + +void megdnn::fallback::resize_nearest_upsample2_nchw_gi_fp32( + const ResizeImpl::KernParam& kern_param) { + nearest_upsample2_nchw( + kern_param.src(), kern_param.dst(), kern_param.n * kern_param.c, + kern_param.ih, kern_param.iw); +} diff --git a/dnn/src/fallback/resize/gi/upsample2_nchw.h b/dnn/src/fallback/resize/gi/upsample2_nchw.h new file mode 100644 index 00000000..ec0430d9 --- /dev/null +++ b/dnn/src/fallback/resize/gi/upsample2_nchw.h @@ -0,0 +1,14 @@ +#pragma once +#include "src/fallback/resize/gi/helper.h" + +namespace megdnn { +namespace fallback { + +void resize_linear_upsample2_nchw_gi_fp32( + const ResizeImpl::KernParam& kern_param); + +void resize_nearest_upsample2_nchw_gi_fp32( + const ResizeImpl::KernParam& kern_param); + +} // namespace fallback +} // namespace megdnn diff --git a/dnn/src/fallback/resize/gi/upsample2_nchwxx.cpp b/dnn/src/fallback/resize/gi/upsample2_nchwxx.cpp new file mode 100644 index 00000000..6648e6cd --- /dev/null +++ b/dnn/src/fallback/resize/gi/upsample2_nchwxx.cpp @@ -0,0 +1,155 @@ +#include "src/fallback/resize/gi/upsample2_nchwxx.h" + +using namespace megdnn; +using namespace fallback; +using namespace resize; + +namespace { + +template +static GI_FORCEINLINE typename simd_helper::simd_type compute_linear_element( + const typename simd_helper::simd_type src[4], + const typename simd_helper::simd_type alpha[2][2]) { + typename simd_helper::simd_type c = simd_helper::dup(0); + c = simd_helper::fma(c, src[0], alpha[0 ^ fh][0 ^ fw]); + c = simd_helper::fma(c, src[1], alpha[0 ^ fh][1 ^ fw]); + c = simd_helper::fma(c, src[2], alpha[1 ^ fh][0 ^ fw]); + c = simd_helper::fma(c, src[3], alpha[1 ^ fh][1 ^ fw]); + return c; +} + +template +static GI_FORCEINLINE void compute_linear_2x2_element( + const typename simd_helper::ctype* src, typename simd_helper::ctype* dst, + size_t IW, size_t OW, const typename simd_helper::simd_type alpha[2][2]) { + constexpr size_t PC = simd_helper::simd_width; + const typename simd_helper::ctype* src_ptr[4] = {src, src, src, src}; + + if (has_right) { + src_ptr[1] += PC; + src_ptr[3] += PC; + } + if (has_bottom) { + src_ptr[2] += IW * PC; + src_ptr[3] += IW * PC; + } + + typename simd_helper::simd_type rsrc[4]; + rsrc[0] = simd_helper::load(src_ptr[0]); + rsrc[1] = simd_helper::load(src_ptr[1]); + rsrc[2] = simd_helper::load(src_ptr[2]); + rsrc[3] = simd_helper::load(src_ptr[3]); + + typename simd_helper::simd_type rdst[4]; + rdst[0] = compute_linear_element(rsrc, alpha); + rdst[1] = compute_linear_element(rsrc, alpha); + rdst[2] = compute_linear_element(rsrc, alpha); + rdst[3] = compute_linear_element(rsrc, alpha); + + simd_helper::store(dst, rdst[0]); + if (has_right) { + simd_helper::store(dst + PC, rdst[1]); + } + if (has_bottom) { + simd_helper::store(dst + OW * PC, rdst[2]); + } + if (has_right && has_bottom) { + simd_helper::store(dst + (OW + 1) * PC, rdst[3]); + } +} + +template +void linear_upsample2_nchwxx( + const ctype* src_ptr, ctype* dst_ptr, size_t N, size_t IH, size_t IW) { + using simd_helper = SIMDHelper; + size_t OW = IW * 2; + constexpr size_t PC = simd_helper::simd_width; + + typename simd_helper::simd_type alpha[2][2]; + alpha[0][0] = simd_helper::dup(0.75 * 0.75); + 
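// Editor's note (illustration only): the nchwxx kernels treat each (h, w)
// position as a packed float[PC] (PC == 4 for NCHW44), which is why every
// pointer offset above is scaled by PC and each separable weight
// (0.75*0.75, 0.75*0.25, ...) is broadcast across the vector lanes with
// simd_helper::dup. Plain indexing sketch, not MegDNN API:
#include <cstddef>
#include <cstdio>

// offset (in floats) of element (n, c, h, w) in an NCHW44 buffer,
// assuming the layout [N, C/4, H, W, 4]
size_t nchw44_offset(size_t n, size_t c, size_t h, size_t w,
                     size_t C, size_t H, size_t W) {
    const size_t PC = 4;              // pack size
    size_t cb = c / PC, cl = c % PC;  // channel block / lane within the block
    return (((n * (C / PC) + cb) * H + h) * W + w) * PC + cl;
}

int main() {
    // a 1x8x2x3 tensor: channel 5 lives in block 1, lane 1
    std::printf("%zu\n", nchw44_offset(0, 5, 1, 2, 8, 2, 3));  // prints 45
    return 0;
}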
alpha[0][1] = simd_helper::dup(0.75 * 0.25); + alpha[1][0] = simd_helper::dup(0.25 * 0.75); + alpha[1][1] = simd_helper::dup(0.25 * 0.25); + + for (size_t i = 0; i < N; ++i) { + compute_linear_2x2_element( + src_ptr, dst_ptr, IW, OW, alpha); + + { + for (size_t iw = 0; iw + 1 < IW; ++iw) { + compute_linear_2x2_element( + src_ptr + iw * PC, dst_ptr + (iw * 2 + 1) * PC, IW, OW, alpha); + } + } + compute_linear_2x2_element( + src_ptr + (IW - 1) * PC, dst_ptr + (OW - 1) * PC, IW, OW, alpha); + dst_ptr += OW * PC; + + for (size_t ih = 0; ih + 1 < IH; ++ih) { + compute_linear_2x2_element( + src_ptr, dst_ptr, IW, OW, alpha); + for (size_t iw = 0; iw + 1 < IW; ++iw) { + compute_linear_2x2_element( + src_ptr + iw * PC, dst_ptr + (iw * 2 + 1) * PC, IW, OW, alpha); + } + compute_linear_2x2_element( + src_ptr + (IW - 1) * PC, dst_ptr + (OW - 1) * PC, IW, OW, alpha); + + src_ptr += IW * PC; + dst_ptr += 2 * OW * PC; + } + + compute_linear_2x2_element( + src_ptr, dst_ptr, IW, OW, alpha); + { + for (size_t iw = 0; iw + 1 < IW; ++iw) { + compute_linear_2x2_element( + src_ptr + iw * PC, dst_ptr + (iw * 2 + 1) * PC, IW, OW, alpha); + } + } + + compute_linear_2x2_element( + src_ptr + (IW - 1) * PC, dst_ptr + (OW - 1) * PC, IW, OW, alpha); + src_ptr += IW * PC; + dst_ptr += OW * PC; + } +} + +template +void nearest_upsample2_nchwxx( + const ctype* src_ptr, ctype* dst_ptr, size_t N, size_t IH, size_t IW) { + using simd_helper = SIMDHelper; + size_t OW = IW * 2; + constexpr size_t PC = simd_helper::simd_width; + + for (size_t i = 0; i < N; ++i) { + for (size_t ih = 0; ih < IH; ++ih) { + for (size_t iw = 0; iw < IW; ++iw) { + typename simd_helper::simd_type r0 = + simd_helper::load(src_ptr + iw * PC); + + simd_helper::store(dst_ptr + (iw * 2) * PC, r0); + simd_helper::store(dst_ptr + (iw * 2 + 1) * PC, r0); + simd_helper::store(dst_ptr + (OW + iw * 2) * PC, r0); + simd_helper::store(dst_ptr + (OW + iw * 2 + 1) * PC, r0); + } + src_ptr += IW * PC; + dst_ptr += 2 * OW * PC; + } + } +} +} // namespace + +void megdnn::fallback::resize_linear_upsample2_nchw44_gi_fp32( + const ResizeImpl::KernParam& kern_param) { + linear_upsample2_nchwxx( + kern_param.src(), kern_param.dst(), kern_param.n * kern_param.c / 4, + kern_param.ih, kern_param.iw); +} + +void megdnn::fallback::resize_nearest_upsample2_nchw44_gi_fp32( + const ResizeImpl::KernParam& kern_param) { + nearest_upsample2_nchwxx( + kern_param.src(), kern_param.dst(), kern_param.n * kern_param.c / 4, + kern_param.ih, kern_param.iw); +} diff --git a/dnn/src/fallback/resize/gi/upsample2_nchwxx.h b/dnn/src/fallback/resize/gi/upsample2_nchwxx.h new file mode 100644 index 00000000..a4c66540 --- /dev/null +++ b/dnn/src/fallback/resize/gi/upsample2_nchwxx.h @@ -0,0 +1,14 @@ +#pragma once +#include "src/fallback/resize/gi/helper.h" + +namespace megdnn { +namespace fallback { + +void resize_linear_upsample2_nchw44_gi_fp32( + const ResizeImpl::KernParam& kern_param); + +void resize_nearest_upsample2_nchw44_gi_fp32( + const ResizeImpl::KernParam& kern_param); + +} // namespace fallback +} // namespace megdnn diff --git a/dnn/src/fallback/resize/opr_impl.cpp b/dnn/src/fallback/resize/opr_impl.cpp index 8921008f..14bc0f3e 100644 --- a/dnn/src/fallback/resize/opr_impl.cpp +++ b/dnn/src/fallback/resize/opr_impl.cpp @@ -15,6 +15,14 @@ #include "src/common/rounding_converter.cuh" #include "src/fallback/handle.h" +#include "src/fallback/resize/gi/direct_nchwxx.h" +#include "src/fallback/resize/gi/resize_cv.h" +#include "src/fallback/resize/gi/upsample2_nchw.h" +#include 
"src/fallback/resize/gi/upsample2_nchwxx.h" + +#include "midout.h" +MIDOUT_DECL(megdnn_fallback_resize) + using namespace megdnn; using namespace fallback; @@ -108,6 +116,11 @@ void ResizeImpl::kern_fallback_nhwc(const KernParam& kern_param) { void ResizeImpl::exec( _megdnn_tensor_in src, _megdnn_tensor_in dst, _megdnn_workspace workspace) { check_exec(src.layout, dst.layout, workspace.size); + exec_gi(src, dst, workspace); +} + +void ResizeImpl::exec_fallback( + _megdnn_tensor_in src, _megdnn_tensor_in dst, _megdnn_workspace workspace) { if (param().format == param::Resize::Format::NCHW4 || param().format == param::Resize::Format::NCHW44 || param().format == param::Resize::Format::NCHW88 || @@ -148,4 +161,92 @@ void ResizeImpl::exec( naive::ResizeImpl::exec(src, dst, workspace); } +void ResizeImpl::exec_gi( + _megdnn_tensor_in src, _megdnn_tensor_in dst, _megdnn_workspace workspace) { + bool is_contiguous = src.layout.is_contiguous() && dst.layout.is_contiguous(); + bool is_dtype_same = src.layout.dtype == dst.layout.dtype; + bool is_dtype_fp32 = src.layout.dtype == dtype::Float32(); + bool is_dtype_supported = is_dtype_same && is_dtype_fp32; + + bool is_nchw_fp32 = param().format == param::Resize::Format::NCHW && is_dtype_fp32; + bool is_nchw44_fp32 = + param().format == param::Resize::Format::NCHW44 && is_dtype_fp32; + bool is_imode_nearest = + param().imode == param::Resize::InterpolationMode::INTER_NEAREST; + bool is_imode_linear = + param().imode == param::Resize::InterpolationMode::INTER_LINEAR; + bool is_imode_supported = is_imode_nearest || is_imode_linear; + + bool is_upsample2 = src.layout.shape[2] * 2 == dst.layout.shape[2] && + src.layout.shape[3] * 2 == dst.layout.shape[3]; + bool usable = is_contiguous && is_dtype_supported && is_imode_supported; + + if (param().format == param::Resize::Format::NHWC && + (src.layout[3] == 1 || src.layout[3] == 3) && is_nhwc_contig_wc(src.layout) && + is_dtype_fp32) { + MEGDNN_DISPATCH_CPU_KERN_OPR(resize_cv_gi_exec(src, dst, param().imode)); + } else if (!usable) { + exec_fallback(src, dst, workspace); + } else if (is_dtype_fp32) { + auto kern_param = KernParam::from_tensors( + param().format, param().imode, src, dst, workspace); + if (is_nchw44_fp32) { + if (is_upsample2) { + if (is_imode_nearest) { + MIDOUT_BEGIN(megdnn_fallback_resize, midout_iv(0)) { + MEGDNN_DISPATCH_CPU_KERN_OPR( + resize_nearest_upsample2_nchw44_gi_fp32(kern_param)); + } + MIDOUT_END(); + } else { + megdnn_assert(is_imode_linear, "invalid imode"); + MIDOUT_BEGIN(megdnn_fallback_resize, midout_iv(1)) { + MEGDNN_DISPATCH_CPU_KERN_OPR( + resize_linear_upsample2_nchw44_gi_fp32(kern_param)); + } + MIDOUT_END(); + } + } else { + if (is_imode_nearest) { + MIDOUT_BEGIN(megdnn_fallback_resize, midout_iv(2)) { + MEGDNN_DISPATCH_CPU_KERN_OPR( + resize_direct_nearest_nchw44_gi_fp32(kern_param)); + } + MIDOUT_END(); + } else { + megdnn_assert(is_imode_linear, "invalid imode"); + MIDOUT_BEGIN(megdnn_fallback_resize, midout_iv(3)) { + MEGDNN_DISPATCH_CPU_KERN_OPR( + resize_direct_linear_nchw44_gi_fp32(kern_param)); + } + MIDOUT_END(); + } + } + } else if (is_nchw_fp32) { + if (is_upsample2) { + if (is_imode_nearest) { + MIDOUT_BEGIN(megdnn_fallback_resize, midout_iv(4)) { + MEGDNN_DISPATCH_CPU_KERN_OPR( + resize_nearest_upsample2_nchw_gi_fp32(kern_param)); + } + MIDOUT_END(); + } else { + megdnn_assert(is_imode_linear, "invalid imode"); + MIDOUT_BEGIN(megdnn_fallback_resize, midout_iv(5)) { + MEGDNN_DISPATCH_CPU_KERN_OPR( + resize_linear_upsample2_nchw_gi_fp32(kern_param)); + } + 
MIDOUT_END(); + } + } else { + exec_fallback(src, dst, workspace); + } + } else { + exec_fallback(src, dst, workspace); + } + } else { + exec_fallback(src, dst, workspace); + } +} +// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/resize/opr_impl.h b/dnn/src/fallback/resize/opr_impl.h index 2880b0db..82282012 100644 --- a/dnn/src/fallback/resize/opr_impl.h +++ b/dnn/src/fallback/resize/opr_impl.h @@ -35,6 +35,11 @@ private: template void kern_fallback_nhwc(const KernParam& kern_param); + void exec_fallback( + _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace); + + void exec_gi( + _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace); }; // class ResizeImpl } // namespace fallback diff --git a/dnn/test/fallback/gi.cpp b/dnn/test/fallback/gi.cpp index 4a431912..5dd97bf4 100644 --- a/dnn/test/fallback/gi.cpp +++ b/dnn/test/fallback/gi.cpp @@ -358,6 +358,21 @@ TEST_F(FALLBACK, GiLoadFloat32) { assert_eq((float*)&ret, naive); } +TEST_F(FALLBACK, GiLoadFloat32V2) { + GI_FLOAT32_V2_t ret; + std::vector s0{2.3f, 4.7f, -1.4f, 1223.6f, 1.1f, 4.0f, 99.7f, 1234.9f}; + s0.resize(SIMD_LEN * 2); + + ret = GiLoadFloat32V2(s0.data()); + + std::vector naive; + for (size_t i = 0; i < SIMD_LEN * 2; i++) { + naive.push_back(s0[i]); + } + + assert_eq((float*)&ret, naive, SIMD_LEN * 2); +} + TEST_F(FALLBACK, GiLoadFloat32LowHalf) { GI_FLOAT32_t ret; std::vector s0{2.3f, 4.7f, -1.4f, 1223.6f}; @@ -701,6 +716,18 @@ TEST_F(FALLBACK, GiStoreFloat32) { assert_eq(ret.data(), s0); } +TEST_F(FALLBACK, GiStoreFloat32V2) { + GI_FLOAT32_V2_t src0; + std::vector s0{1.1f, 2.2f, 3.5f, 4.9f, -1.1f, -2.2f, -3.5f, -4.9}; + s0.resize(SIMD_LEN * 2); + init((float*)&src0, s0, SIMD_LEN * 2); + std::vector ret{0}; + ret.resize(SIMD_LEN * 2); + + GiStoreFloat32V2(ret.data(), src0); + assert_eq(ret.data(), s0, SIMD_LEN * 2); +} + TEST_F(FALLBACK, GiStoreLaneXXFloat32) { GI_FLOAT32_t src0; std::vector s0{1.1f, 2.2f, 3.5f, 4.9f}; @@ -1226,7 +1253,7 @@ TEST_F(FALLBACK, GiBSLFloat32) { naive.resize(SIMD_LEN); memcpy(naive.data(), &na, sizeof(GI_FLOAT32_t)); - assert_eq((float*)&ret, naive); + assert_eq_and_nan((float*)&ret, naive); } } @@ -3199,6 +3226,65 @@ TEST_F(FALLBACK, GiPmaxFloat32) { ASSERT_LT(std::abs(naive[1] - r[1]), 1e-3); } +TEST_F(FALLBACK, GiStoreZipFloat32V2) { + GI_FLOAT32_V2_t src0; + std::vector s0{1.1f, 2.2f, 3.5f, 4.9f, 2312.1f, 345.244f, 3.59f, -12.8f}; + s0.resize(SIMD_LEN * 2); + init((float*)&src0, s0, SIMD_LEN * 2); + std::vector ret; + ret.resize(SIMD_LEN * 2); + std::vector ret_cmp; + ret_cmp.resize(SIMD_LEN * 2); + + GiStoreZipFloat32V2(ret.data(), src0); + + GI_FLOAT32_V2_t tmp; + tmp = GiZipqFloat32(src0.val[0], src0.val[1]); + GiStoreFloat32(ret_cmp.data(), tmp.val[0]); + GiStoreFloat32(ret_cmp.data() + SIMD_LEN, tmp.val[1]); + + assert_eq(ret.data(), ret_cmp, SIMD_LEN * 2); +} + +TEST_F(FALLBACK, GiLoadUzipFloat32V3) { + GI_FLOAT32_V3_t ret; + std::vector s0{1.1f, 2.2f, 3.5f, 4.9f, 2312.1f, 345.244f, + 3.59f, -12.8f, 2.2f, 6.0f, 90.0f, 89.3f}; + s0.resize(SIMD_LEN * 3); + + ret = GiLoadUzipFloat32V3(s0.data()); + std::vector naive; + for (size_t i = 0; i < 3; i++) { + naive.push_back(s0[0 + i]); + naive.push_back(s0[3 + i]); + naive.push_back(s0[6 + i]); + naive.push_back(s0[9 + i]); + } + + assert_eq((float*)&ret, naive); +} + +TEST_F(FALLBACK, GiStoreZipFloat32V3) { + GI_FLOAT32_V3_t src0; + std::vector s0{1.1f, 2.2f, 3.5f, 4.9f, 2312.1f, 345.244f, + 3.59f, -12.8f, 3.59f, -12.8f, 2.2f, 6.0}; + s0.resize(SIMD_LEN * 3); + 
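// Editor's note (illustration only): a scalar reference for the semantics the
// two V3 tests above pin down; GiLoadUzipFloat32V3 de-interleaves stride-3 data
// and GiStoreZipFloat32V3 re-interleaves it (on Arm this corresponds to
// vld3q_f32 / vst3q_f32-style behaviour). Sketch only, not the GI API.
#include <array>
#include <cstdio>

constexpr size_t SIMD_LEN = 4;

// load-uzip: val[v][lane] = src[lane * 3 + v]
std::array<std::array<float, SIMD_LEN>, 3> load_uzip_v3(const float* s) {
    std::array<std::array<float, SIMD_LEN>, 3> out{};
    for (size_t lane = 0; lane < SIMD_LEN; ++lane)
        for (size_t v = 0; v < 3; ++v)
            out[v][lane] = s[lane * 3 + v];
    return out;
}

// store-zip: the inverse, writing lane-interleaved triples
void store_zip_v3(float* d, const std::array<std::array<float, SIMD_LEN>, 3>& in) {
    for (size_t lane = 0; lane < SIMD_LEN; ++lane)
        for (size_t v = 0; v < 3; ++v)
            d[lane * 3 + v] = in[v][lane];
}

int main() {
    float s[12], d[12];
    for (int i = 0; i < 12; ++i) s[i] = float(i);
    store_zip_v3(d, load_uzip_v3(s));  // round-trips to the original order
    for (float v : d) std::printf("%.0f ", v);
    std::printf("\n");
    return 0;
}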
init((float*)&src0, s0, SIMD_LEN * 3); + std::vector ret; + ret.resize(SIMD_LEN * 3); + + GiStoreZipFloat32V3(ret.data(), src0); + + std::vector ret_cmp; + for (size_t i = 0; i < SIMD_LEN; i++) { + ret_cmp.push_back(s0[0 + i]); + ret_cmp.push_back(s0[4 + i]); + ret_cmp.push_back(s0[8 + i]); + } + + assert_eq(ret.data(), ret_cmp, SIMD_LEN * 3); +} + } // namespace test } // namespace megdnn diff --git a/dnn/test/fallback/resize.cpp b/dnn/test/fallback/resize.cpp index f77d410f..d9b2b48e 100644 --- a/dnn/test/fallback/resize.cpp +++ b/dnn/test/fallback/resize.cpp @@ -167,6 +167,48 @@ TEST_F(FALLBACK, RESIZE_NCHW4_RECORD) { } } +namespace { +static void set_nchw_args(resize::IMode imode, std::vector& args) { + param::Resize param; + param.format = param::Resize::Format::NCHW; + param.imode = imode; + rep(n, 4ul) rep(c, 4ul) rep(ih, 4ul) rep(iw, 4ul) rep(oh, 4ul) rep(ow, 4ul) + args.emplace_back( + param, TensorShape{n + 1ul, c + 1ul, ih + 1ul, iw + 1ul}, + TensorShape{n + 1ul, c + 1ul, oh + 1ul, ow + 1ul}); + args.emplace_back(param, TensorShape{1, 1, 10, 10}, TensorShape{1, 1, 20, 20}); + args.emplace_back(param, TensorShape{1, 1, 10, 10}, TensorShape{1, 1, 7, 9}); + args.emplace_back(param, TensorShape{2, 2, 3, 4}, TensorShape{2, 2, 6, 8}); + args.emplace_back(param, TensorShape{1, 2, 6, 8}, TensorShape{1, 2, 3, 4}); +} +} // namespace + +TEST_F(FALLBACK, RESIZE_NCHW_FP32) { + std::vector args; + set_nchw_args(resize::IMode::INTER_LINEAR, args); + set_nchw_args(resize::IMode::INTER_NEAREST, args); + Checker checker(handle()); + + for (auto&& arg : args) { + checker.set_param(arg.param) + .set_dtype(0, dtype::Float32()) + .set_dtype(1, dtype::Float32()) + .execs({arg.src, arg.dst}); + } +} + +TEST_F(FALLBACK, RESIZE_NCHW44_FP32) { + std::vector args = resize::get_nchw44_args(); + Checker checker(handle()); + + for (auto&& arg : args) { + checker.set_param(arg.param) + .set_dtype(0, dtype::Float32()) + .set_dtype(1, dtype::Float32()) + .execs({arg.src, arg.dst}); + } +} + } // namespace test } // namespace megdnn // vim: syntax=cpp.doxygen
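Editor's note: the two new FALLBACK resize tests exercise both sides of the exec_gi routing. Shapes such as {1, 1, 10, 10} -> {1, 1, 20, 20} and {2, 2, 3, 4} -> {2, 2, 6, 8} satisfy the is_upsample2 check and hit the new GI upsample2 kernels, while {1, 1, 10, 10} -> {1, 1, 7, 9} and the downscale case {1, 2, 6, 8} -> {1, 2, 3, 4} fall through to exec_fallback, so the checker compares both paths against the naive reference. A minimal sketch of the routing predicate (names are illustrative, mirroring the check in opr_impl.cpp):

#include <cstdio>

// true when both spatial dims are exactly doubled, as in exec_gi()
bool is_upsample2(size_t ih, size_t iw, size_t oh, size_t ow) {
    return ih * 2 == oh && iw * 2 == ow;
}

int main() {
    std::printf("%d %d %d\n",
                is_upsample2(10, 10, 20, 20),  // 1: routed to the upsample2 kernel
                is_upsample2(10, 10, 7, 9),    // 0: generic fallback path
                is_upsample2(3, 4, 6, 8));     // 1: routed to the upsample2 kernel
    return 0;
}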