GitOrigin-RevId: 3370cdc57a
release-1.10
@@ -140,6 +140,8 @@ public:
     AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; };
     const char* name() const override { return "FALLBACK_POOLING"; }
    bool usable(const PoolingKernSizeParam&) const override { return true; }
+    //! falls back to the algo defined in:
+    //! dnn/src/fallback/pooling/gi/algo.h
     void exec(const PoolingKernParam&) const override {
         megdnn_assert(false, "code issue happened!!");
     }
@@ -30,18 +30,16 @@ void ResizeImpl::exec(
     bool is_contiguous = src.layout.is_contiguous() && dst.layout.is_contiguous();
     bool is_dtype_same = src.layout.dtype == dst.layout.dtype;
-    bool is_dtype_fp32 = src.layout.dtype == dtype::Float32();
     bool is_dtype_fp16 =
             DNN_FLOAT16_SELECT(src.layout.dtype == dtype::Float16(), false);
-    bool is_dtype_supported = is_dtype_same && (is_dtype_fp32 || is_dtype_fp16);
+    bool is_dtype_supported = is_dtype_same && is_dtype_fp16;
-    bool is_nchw = param().format == param::Resize::Format::NCHW &&
-                   (is_dtype_fp32 || is_dtype_fp16);
-    bool is_nchw44_fp32 =
-            param().format == param::Resize::Format::NCHW44 && is_dtype_fp32;
 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+    bool is_nchw = param().format == param::Resize::Format::NCHW && is_dtype_fp16;
     bool is_nchw88_fp16 =
             param().format == param::Resize::Format::NCHW88 && is_dtype_fp16;
+    bool is_upsample2 = src.layout.shape[2] * 2 == dst.layout.shape[2] &&
+                        src.layout.shape[3] * 2 == dst.layout.shape[3];
 #endif
     bool is_imode_nearest =
@@ -50,8 +48,6 @@ void ResizeImpl::exec(
             param().imode == param::Resize::InterpolationMode::INTER_LINEAR;
     bool is_imode_supported = is_imode_nearest || is_imode_linear;
-    bool is_upsample2 = src.layout.shape[2] * 2 == dst.layout.shape[2] &&
-                        src.layout.shape[3] * 2 == dst.layout.shape[3];
     bool usable = is_contiguous && is_dtype_supported && is_imode_supported;
     if (param().format == param::Resize::Format::NHWC &&
@@ -59,63 +55,6 @@ void ResizeImpl::exec(
         MEGDNN_DISPATCH_CPU_KERN_OPR(resize_cv_exec(src, dst, param().imode));
     } else if (!usable) {
         fallback::ResizeImpl::exec(src, dst, workspace);
-    } else if (is_dtype_fp32) {
-        auto kern_param = KernParam<float>::from_tensors(
-                param().format, param().imode, src, dst, workspace);
-        if (is_nchw44_fp32) {
-            if (is_upsample2) {
-                if (is_imode_nearest) {
-                    MIDOUT_BEGIN(megdnn_arm_resize, midout_iv(0)) {
-                        MEGDNN_DISPATCH_CPU_KERN_OPR(
-                                resize_nearest_upsample2_nchw44_fp32(kern_param));
-                    }
-                    MIDOUT_END();
-                } else {
-                    megdnn_assert(is_imode_linear, "invalid imode");
-                    MIDOUT_BEGIN(megdnn_arm_resize, midout_iv(1)) {
-                        MEGDNN_DISPATCH_CPU_KERN_OPR(
-                                resize_linear_upsample2_nchw44_fp32(kern_param));
-                    }
-                    MIDOUT_END();
-                }
-            } else {
-                if (is_imode_nearest) {
-                    MIDOUT_BEGIN(megdnn_arm_resize, midout_iv(2)) {
-                        MEGDNN_DISPATCH_CPU_KERN_OPR(
-                                resize_direct_nearest_nchw44_fp32(kern_param));
-                    }
-                    MIDOUT_END();
-                } else {
-                    megdnn_assert(is_imode_linear, "invalid imode");
-                    MIDOUT_BEGIN(megdnn_arm_resize, midout_iv(3)) {
-                        MEGDNN_DISPATCH_CPU_KERN_OPR(
-                                resize_direct_linear_nchw44_fp32(kern_param));
-                    }
-                    MIDOUT_END();
-                }
-            }
-        } else if (is_nchw) {
-            if (is_upsample2) {
-                if (is_imode_nearest) {
-                    MIDOUT_BEGIN(megdnn_arm_resize, midout_iv(4)) {
-                        MEGDNN_DISPATCH_CPU_KERN_OPR(
-                                resize_nearest_upsample2_nchw_fp32(kern_param));
-                    }
-                    MIDOUT_END();
-                } else {
-                    megdnn_assert(is_imode_linear, "invalid imode");
-                    MIDOUT_BEGIN(megdnn_arm_resize, midout_iv(5)) {
-                        MEGDNN_DISPATCH_CPU_KERN_OPR(
-                                resize_linear_upsample2_nchw_fp32(kern_param));
-                    }
-                    MIDOUT_END();
-                }
-            } else {
-                fallback::ResizeImpl::exec(src, dst, workspace);
-            }
-        } else {
-            fallback::ResizeImpl::exec(src, dst, workspace);
-        }
 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
     } else if (is_dtype_fp16) {
         auto kern_param = KernParam<dt_float16>::from_tensors(
@@ -96,28 +96,10 @@ struct Vector<float, 8> {
     Vector(const GI_FLOAT32_V2_t& v) { value = v; }
     static Vector load(const float* addr) {
         Vector v;
-#if defined(GI_TEST_NAIVE)
-        v.value.val[0] = GiLoadFloat32(addr);
-        v.value.val[1] = GiLoadFloat32(addr + 4);
-#elif defined(__arm__) || defined(__aarch64__)
-        v.value = vld1q_f32_x2(addr);
-#else
-        v.value.val[0] = GiLoadFloat32(addr);
-        v.value.val[1] = GiLoadFloat32(addr + 4);
-#endif
+        v.value = GiLoadFloat32V2(addr);
         return v;
     }
-    static void save(float* addr, const Vector& v) {
-#if defined(GI_TEST_NAIVE)
-        GiStoreFloat32(addr, v.value.val[0]);
-        GiStoreFloat32(addr + 4, v.value.val[1]);
-#elif defined(__arm__) || defined(__aarch64__)
-        vst1q_f32_x2(addr, v.value);
-#else
-        GiStoreFloat32(addr, v.value.val[0]);
-        GiStoreFloat32(addr + 4, v.value.val[1]);
-#endif
-    }
+    static void save(float* addr, const Vector& v) { GiStoreFloat32V2(addr, v.value); }
     void save(float* addr) { save(addr, *this); }
     Vector operator+(const Vector& lr) {
@@ -143,6 +143,7 @@ typedef int16x8_t GI_INT16_t;
 typedef int32x4_t GI_INT32_t;
 typedef uint32x4_t GI_UINT32_t;
 typedef float32x4x2_t GI_FLOAT32_V2_t;
+typedef float32x4x3_t GI_FLOAT32_V3_t;
 typedef float32x4x4_t GI_FLOAT32_V4_t;
 typedef int32x4x2_t GI_INT32_V2_t;
 typedef int32x4x4_t GI_INT32_V4_t;
@@ -167,6 +168,7 @@ typedef __m128i GI_INT16_t;
 typedef __m128i GI_INT32_t;
 typedef __m128i GI_UINT32_t;
 typedef __m128i GI_INT64_t;
+#define _SWAP_HI_LOW32 (2 | (3 << 2) | (0 << 4) | (1 << 6))
 #define _INSERTPS_NDX(srcField, dstField) (((srcField) << 6) | ((dstField) << 4))
 #define _M64(out, inp) _mm_storel_epi64((__m128i*)&(out), inp)
 #define _pM128i(a) _mm_loadl_epi64((__m128i*)&(a))
@@ -295,6 +297,10 @@ typedef struct {
 } GI_FLOAT32_V2_NAIVE_t;
 typedef struct {
+    GI_FLOAT32_NAIVE_t val[3];
+} GI_FLOAT32_V3_NAIVE_t;
+
+typedef struct {
     GI_FLOAT32_NAIVE_t val[4];
 } GI_FLOAT32_V4_NAIVE_t;
@@ -335,6 +341,10 @@ typedef struct {
 } GI_FLOAT32_V2_t;
 typedef struct {
+    GI_FLOAT32_t val[3];
+} GI_FLOAT32_V3_t;
+
+typedef struct {
     GI_FLOAT32_t val[4];
 } GI_FLOAT32_V4_t;
@@ -157,6 +157,19 @@ GI_FLOAT32_t GiLoadFloat32(const float* Buffer) {
 }
 GI_FORCEINLINE
+GI_FLOAT32_V2_t GiLoadFloat32V2(const float* Buffer) {
+#if defined(GI_NEON_INTRINSICS)
+    return vld1q_f32_x2(Buffer);
+#else
+    GI_FLOAT32_V2_t v;
+    v.val[0] = GiLoadFloat32(Buffer);
+    v.val[1] = GiLoadFloat32(Buffer + GI_SIMD_LEN_BYTE / sizeof(float));
+    return v;
+#endif
+}
+
+GI_FORCEINLINE
 GI_FLOAT32_t GiLoadFloat32LowHalf(const float* Buffer) {
 #if defined(GI_NEON_INTRINSICS)
     return vcombine_f32(vld1_f32(Buffer), vdup_n_f32(0.f));
@@ -519,6 +532,16 @@ void GiStoreFloat32(float* Buffer, GI_FLOAT32_t Vector) {
 #endif
 }
+GI_FORCEINLINE
+void GiStoreFloat32V2(float* Buffer, GI_FLOAT32_V2_t Vector) {
+#if defined(GI_NEON_INTRINSICS)
+    vst1q_f32_x2(Buffer, Vector);
+#else
+    GiStoreFloat32(Buffer, Vector.val[0]);
+    GiStoreFloat32(Buffer + GI_SIMD_LEN_BYTE / sizeof(float), Vector.val[1]);
+#endif
+}
+
 #if defined(GI_NEON_INTRINSICS)
 #define GISTORELANEFLOAT32(i)                                                          \
     GI_FORCEINLINE void GiStoreLane##i##Float32(float* Buffer, GI_FLOAT32_t Vector) {  \
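The V2 load/store pair above subsumes the per-platform `#if` ladders removed from `Vector<float, 8>` earlier in this diff. A minimal round-trip sketch, using only the two functions just defined:

```cpp
// Round-trip through the new V2 wrappers: 8 contiguous floats in, 8 out.
// On NEON these should compile to vld1q_f32_x2 / vst1q_f32_x2; elsewhere they
// fall back to two 4-lane GiLoadFloat32 / GiStoreFloat32 operations.
float in[8] = {0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f};
GI_FLOAT32_V2_t v = GiLoadFloat32V2(in);
float out[8];
GiStoreFloat32V2(out, v);  // out now holds the same 8 values as in
```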
@@ -593,6 +616,18 @@ GI_FLOAT32_V2_t GiZipqFloat32(GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) {
 }
 GI_FORCEINLINE
+void GiStoreZipFloat32V2(float* Buffer, GI_FLOAT32_V2_t Vector) {
+#if defined(GI_NEON_INTRINSICS)
+    vst2q_f32(Buffer, Vector);
+#else
+    GI_FLOAT32_V2_t tmp;
+    tmp = GiZipqFloat32(Vector.val[0], Vector.val[1]);
+    GiStoreFloat32(Buffer, tmp.val[0]);
+    GiStoreFloat32(Buffer + GI_SIMD_LEN_BYTE / sizeof(float), tmp.val[1]);
+#endif
+}
+
+GI_FORCEINLINE
 GI_FLOAT32_t GiInterleaveLowFloat32(GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) {
 #if defined(GI_NEON64_INTRINSICS)
     return vzip1q_f32(Vector1, Vector2);
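`GiStoreZipFloat32V2` performs an interleaved (zip) store, which is exactly what the resize upsample kernels later in this diff use (via `store2_interleave`) to write two output pixels per input pixel. A small sketch of the expected lane order:

```cpp
// With val[0] = {a0, a1, a2, a3} and val[1] = {b0, b1, b2, b3}, the zip store
// writes {a0, b0, a1, b1, a2, b2, a3, b3} to the destination buffer.
GI_FLOAT32_V2_t pair;
pair.val[0] = GiBroadcastFloat32(1.f);  // {1, 1, 1, 1}
pair.val[1] = GiBroadcastFloat32(2.f);  // {2, 2, 2, 2}
float out[8];
GiStoreZipFloat32V2(out, pair);  // out = {1, 2, 1, 2, 1, 2, 1, 2}
```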
@@ -1357,3 +1392,70 @@ GI_FORCEINLINE float32x2_t GiPmaxFloat32(float32x2_t a, float32x2_t b) {
     return res;
 #endif
 }
+
+GI_FORCEINLINE
+GI_FLOAT32_V3_t GiLoadUzipFloat32V3(const float* ptr) {
+#if defined(GI_NEON_INTRINSICS)
+    return vld3q_f32(ptr);
+#elif defined(GI_SSE2_INTRINSICS)
+    GI_FLOAT32_V3_t v;
+    __m128 tmp0, tmp1, tmp2, tmp3;
+    v.val[0] = GiLoadFloat32(ptr);
+    v.val[1] = GiLoadFloat32((ptr + 4));
+    v.val[2] = GiLoadFloat32((ptr + 8));
+    tmp0 = _mm_castsi128_ps(_mm_shuffle_epi32(
+            _mm_castps_si128(v.val[0]), 0 | (3 << 2) | (1 << 4) | (2 << 6)));
+    tmp1 = _mm_castsi128_ps(
+            _mm_shuffle_epi32(_mm_castps_si128(v.val[1]), _SWAP_HI_LOW32));
+    tmp2 = _mm_castsi128_ps(_mm_shuffle_epi32(
+            _mm_castps_si128(v.val[2]), 1 | (2 << 2) | (0 << 4) | (3 << 6)));
+    tmp3 = _mm_unpacklo_ps(tmp1, tmp2);
+    v.val[0] = _mm_movelh_ps(tmp0, tmp3);
+    tmp0 = _mm_unpackhi_ps(tmp0, tmp1);
+    v.val[1] =
+            _mm_castsi128_ps(_mm_shuffle_epi32(_mm_castps_si128(tmp0), _SWAP_HI_LOW32));
+    v.val[1] = _mm_movehl_ps(tmp3, v.val[1]);
+    v.val[2] = _mm_movehl_ps(tmp2, tmp0);
+    return v;
+#else
+    GI_FLOAT32_V3_t ret;
+    for (size_t i = 0; i < 3; i++) {
+        ret.val[i][0] = ptr[0 + i];
+        ret.val[i][1] = ptr[3 + i];
+        ret.val[i][2] = ptr[6 + i];
+        ret.val[i][3] = ptr[9 + i];
+    }
+    return ret;
+#endif
+}
+
+GI_FORCEINLINE
+void GiStoreZipFloat32V3(float* ptr, GI_FLOAT32_V3_t val) {
+#if defined(GI_NEON_INTRINSICS)
+    vst3q_f32(ptr, val);
+#elif defined(GI_SSE2_INTRINSICS)
+    GI_FLOAT32_V3_t v;
+    __m128 tmp0, tmp1, tmp2;
+    tmp0 = _mm_unpacklo_ps(val.val[0], val.val[1]);
+    tmp1 = _mm_unpackhi_ps(val.val[0], val.val[1]);
+    tmp2 = _mm_unpacklo_ps(val.val[1], val.val[2]);
+    v.val[1] = _mm_shuffle_ps(tmp2, tmp1, _MM_SHUFFLE(1, 0, 3, 2));
+    v.val[2] = _mm_movehl_ps(val.val[2], tmp1);
+    v.val[2] = _mm_shuffle_ps(v.val[2], v.val[2], _MM_SHUFFLE(3, 1, 0, 2));
+    tmp1 = _mm_unpacklo_ps(tmp2, val.val[0]);
+    v.val[0] = _mm_shuffle_ps(tmp0, tmp1, _MM_SHUFFLE(3, 2, 1, 0));
+    GiStoreFloat32(ptr, v.val[0]);
+    GiStoreFloat32((ptr + 4), v.val[1]);
+    GiStoreFloat32((ptr + 8), v.val[2]);
+#else
+    for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(float); i++) {
+        *ptr++ = val.val[0][i];
+        *ptr++ = val.val[1][i];
+        *ptr++ = val.val[2][i];
+    }
+#endif
+}
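`GiLoadUzipFloat32V3` and `GiStoreZipFloat32V3` are mutual inverses over 12 floats (four 3-channel pixels), presumably in service of the NHWC `resize_cv_gi_exec` path added later in this diff. A usage sketch restricted to functions that appear in this change:

```cpp
// De-interleave four packed RGB pixels into planar registers, then zip them
// back; the final store round-trips to the original interleaved layout.
float rgb[12];
for (int i = 0; i < 12; ++i)
    rgb[i] = static_cast<float>(i);
GI_FLOAT32_V3_t px = GiLoadUzipFloat32V3(rgb);
// px.val[0] = {rgb[0], rgb[3], rgb[6], rgb[9]}: the R plane; val[1]/val[2]: G/B.
float planar[12];
GiStoreFloat32(planar, px.val[0]);
GiStoreFloat32(planar + 4, px.val[1]);
GiStoreFloat32(planar + 8, px.val[2]);
GiStoreZipFloat32V3(rgb, px);  // rgb is unchanged
```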
@@ -45,17 +45,7 @@ void sgemv_gi_naive_n_mk4(
         while (k < K) {
             GI_FLOAT32_t b = GiLoadFloat32(Bptr);
             GI_FLOAT32_V2_t a[2];
-#if defined(GI_TEST_NAIVE)
-#define LOAD_A(step)                                  \
-    a[step].val[0] = GiLoadFloat32(Aptr0 + step * 8); \
-    a[step].val[1] = GiLoadFloat32(Aptr0 + step * 8 + 4);
-#elif defined(__arm__) || defined(__aarch64__)
-#define LOAD_A(step) a[step] = vld1q_f32_x2(Aptr0 + step * 8);
-#else
-#define LOAD_A(step)                                  \
-    a[step].val[0] = GiLoadFloat32(Aptr0 + step * 8); \
-    a[step].val[1] = GiLoadFloat32(Aptr0 + step * 8 + 4);
-#endif
+#define LOAD_A(step) a[step] = GiLoadFloat32V2(Aptr0 + step * 8);
             UNROLL_CALL_RAW(2, LOAD_A)
 #undef LOAD_A
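With the simplified macro, `UNROLL_CALL_RAW(2, LOAD_A)` should expand (per megdnn's `unroll_macro.h`, which invokes the callback with 0, 1, ...) to two plain V2 loads, the same code the removed NEON branch spelled out by hand:

```cpp
// Effective expansion of UNROLL_CALL_RAW(2, LOAD_A) after this change:
a[0] = GiLoadFloat32V2(Aptr0 + 0 * 8);
a[1] = GiLoadFloat32V2(Aptr0 + 1 * 8);
```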
@@ -34,6 +34,8 @@ private:
             _megdnn_tensor_in src, _megdnn_tensor_out dst, const Param& param);
     void exec_w2x2_s2x2_int8(_megdnn_tensor_in src, _megdnn_tensor_out dst);
     void exec_w2x2_s2x2_avg_int8(_megdnn_tensor_in src, _megdnn_tensor_out dst);
+    void exec_fallback(
+            _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace);
 public:
     using naive::PoolingForwardImpl::PoolingForwardImpl;
@@ -43,9 +45,6 @@ public:
             _megdnn_tensor_in src, _megdnn_tensor_out dst,
             _megdnn_workspace workspace) override;
-    void exec_fallback(
-            _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace);
     size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&) override;
     static size_t constexpr MAX_SPATIAL_DIM = 2;
@@ -0,0 +1,69 @@
+#include "src/fallback/resize/gi/direct_nchwxx.h"
+
+using namespace megdnn;
+using namespace fallback;
+using namespace resize;
+
+namespace {
+
+template <typename ctype, InterpolationMode imode>
+void resize_direct_nchwxx(
+        const ctype* sptr, ctype* dptr, size_t N, size_t IH, size_t IW, size_t OH,
+        size_t OW) {
+    using simd_helper = SIMDHelper<ctype>;
+    constexpr size_t PC = simd_helper::simd_width;
+    using simd_type = typename simd_helper::simd_type;
+
+    float scale_h = static_cast<float>(OH) / IH;
+    float scale_w = static_cast<float>(OW) / IW;
+
+    for (size_t n = 0; n < N; ++n) {
+        for (size_t oh = 0; oh < OH; ++oh) {
+            for (size_t ow = 0; ow < OW; ++ow) {
+                int ih0, ih1, iw0, iw1;
+                float ah0, ah1, aw0, aw1;
+
+                std::tie(ah0, ih0, ah1, ih1) =
+                        get_nearest_linear_coord(imode, scale_h, IH, oh);
+                std::tie(aw0, iw0, aw1, iw1) =
+                        get_nearest_linear_coord(imode, scale_w, IW, ow);
+
+                simd_type r0 = simd_helper::load(sptr + (ih0 * IW + iw0) * PC);
+                simd_type r1 = simd_helper::load(sptr + (ih0 * IW + iw1) * PC);
+                simd_type r2 = simd_helper::load(sptr + (ih1 * IW + iw0) * PC);
+                simd_type r3 = simd_helper::load(sptr + (ih1 * IW + iw1) * PC);
+
+                // FIXME: fp16 weights may cause precision problems
+                ctype a0 = ah0 * aw0;
+                ctype a1 = ah0 * aw1;
+                ctype a2 = ah1 * aw0;
+                ctype a3 = ah1 * aw1;
+
+                simd_type c = simd_helper::dup(0);
+                c = simd_helper::fma(c, r0, a0);
+                c = simd_helper::fma(c, r1, a1);
+                c = simd_helper::fma(c, r2, a2);
+                c = simd_helper::fma(c, r3, a3);
+
+                simd_helper::store(dptr + (oh * OW + ow) * PC, c);
+            }
+        }
+        sptr += IH * IW * PC;
+        dptr += OH * OW * PC;
+    }
+}
+}  // namespace
+
+void megdnn::fallback::resize_direct_nearest_nchw44_gi_fp32(
+        const ResizeImpl::KernParam<float>& kern_param) {
+    resize_direct_nchwxx<float, InterpolationMode::INTER_NEAREST>(
+            kern_param.src(), kern_param.dst(), kern_param.n * kern_param.c / 4,
+            kern_param.ih, kern_param.iw, kern_param.oh, kern_param.ow);
+}
+
+void megdnn::fallback::resize_direct_linear_nchw44_gi_fp32(
+        const ResizeImpl::KernParam<float>& kern_param) {
+    resize_direct_nchwxx<float, InterpolationMode::INTER_LINEAR>(
+            kern_param.src(), kern_param.dst(), kern_param.n * kern_param.c / 4,
+            kern_param.ih, kern_param.iw, kern_param.oh, kern_param.ow);
+}
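The `N = kern_param.n * kern_param.c / 4` argument reflects the NCHW44 layout: channels are packed in groups of `PC = 4`, so each (batch, channel-group) slice is one contiguous `IH x IW` plane of 4-float pixels. A sketch of the assumed addressing (`nchw44_offset` is a hypothetical helper, not part of this diff):

```cpp
// Assumed NCHW44 addressing: element (n, c, h, w) of a float tensor with C
// channels (C % 4 == 0) lives at
//   (((n * (C / 4) + c / 4) * H + h) * W + w) * 4 + c % 4
// which is why the kernels above can walk the tensor as N * C / 4 independent
// planes, advancing sptr/dptr by IH * IW * PC per plane.
inline size_t nchw44_offset(
        size_t n, size_t c, size_t h, size_t w, size_t C, size_t H, size_t W) {
    return (((n * (C / 4) + c / 4) * H + h) * W + w) * 4 + c % 4;
}
```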
@@ -0,0 +1,15 @@
+#pragma once
+
+#include "src/fallback/resize/gi/helper.h"
+#include "src/fallback/resize/opr_impl.h"
+
+namespace megdnn {
+namespace fallback {
+
+void resize_direct_linear_nchw44_gi_fp32(
+        const ResizeImpl::KernParam<float>& kern_param);
+
+void resize_direct_nearest_nchw44_gi_fp32(
+        const ResizeImpl::KernParam<float>& kern_param);
+
+}  // namespace fallback
+}  // namespace megdnn
@@ -0,0 +1,76 @@
+#pragma once
+
+#include "src/fallback/general_intrinsic/gi_float.h"
+#include "src/fallback/resize/opr_impl.h"
+
+namespace megdnn {
+namespace fallback {
+namespace resize {
+
+using InterpolationMode = Resize::InterpolationMode;
+
+template <typename ctype>
+struct SIMDHelper {};
+
+template <>
+struct SIMDHelper<float> {
+    using simd_type = GI_FLOAT32_t;
+    using simd_type_x2 = GI_FLOAT32_V2_t;
+    using ctype = float;
+    static constexpr size_t simd_width = 4;
+
+    static GI_FORCEINLINE simd_type load(const ctype* src_ptr) {
+        return GiLoadFloat32(src_ptr);
+    }
+    static GI_FORCEINLINE void store(ctype* dst_ptr, const simd_type& rdst) {
+        GiStoreFloat32(dst_ptr, rdst);
+    }
+    static GI_FORCEINLINE void store2_interleave(
+            ctype* dst_ptr, const simd_type& rdst1, const simd_type& rdst2) {
+        simd_type_x2 rdst;
+        rdst.val[0] = rdst1;
+        rdst.val[1] = rdst2;
+        GiStoreZipFloat32V2(dst_ptr, rdst);
+    }
+    static GI_FORCEINLINE simd_type
+    fma(const simd_type& a, const simd_type& b, ctype n) {
+        return GiMultiplyAddScalarFloat32(a, b, n);
+    }
+    static GI_FORCEINLINE simd_type
+    fma(const simd_type& a, const simd_type& b, const simd_type& c) {
+        return GiMlaqFloat32(a, b, c);
+    }
+    static GI_FORCEINLINE simd_type dup(float val) { return GiBroadcastFloat32(val); }
+};
+
+static GI_FORCEINLINE int get_nearest_src(float scale, int size, int idx) {
+    return std::min(static_cast<int>(idx / scale), size - 1);
+}
+
+static GI_FORCEINLINE std::tuple<float, int, float, int> get_nearest_linear_coord(
+        InterpolationMode imode, float scale, int size, int idx) {
+    if (size == 1) {
+        return std::make_tuple(1.0f, 0, 0.0f, 0);
+    }
+
+    float alpha = (idx + 0.5f) / scale - 0.5f;
+    int origin_idx = static_cast<int>(floor(alpha));
+    alpha -= origin_idx;
+
+    if (imode == InterpolationMode::INTER_NEAREST) {
+        origin_idx = get_nearest_src(scale, size, idx);
+        alpha = 0;
+    }
+
+    if (origin_idx < 0) {
+        origin_idx = 0;
+        alpha = 0;
+    } else if (origin_idx + 1 >= size) {
+        origin_idx = size - 2;
+        alpha = 1;
+    }
+
+    return std::make_tuple(1 - alpha, origin_idx, alpha, origin_idx + 1);
+}
+
+}  // namespace resize
+}  // namespace fallback
+}  // namespace megdnn
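A worked example of the coordinate helper, following directly from the code above. For a 2x linear upsample of a length-4 row, output index 3 resolves as:

```cpp
float a0, a1;
int i0, i1;
std::tie(a0, i0, a1, i1) = get_nearest_linear_coord(
        InterpolationMode::INTER_LINEAR, /*scale=*/2.f, /*size=*/4, /*idx=*/3);
// alpha = (3 + 0.5) / 2 - 0.5 = 1.25 -> origin_idx = 1, alpha = 0.25, so the
// result is (0.75f, 1, 0.25f, 2): output 3 blends inputs 1 and 2 at 3:1.
// With INTER_NEAREST the index snaps to min(int(3 / 2.f), 3) = 1 and the
// weights collapse to (1.0f, 1, 0.0f, 2).
```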
@@ -0,0 +1,20 @@
+#include <megdnn/oprs.h>
+
+#include "src/common/cv/helper.h"
+#include "src/fallback/resize/gi/helper.h"
+
+namespace megdnn {
+namespace fallback {
+
+/**
+ * \fn resize_cv_gi_exec
+ * \brief used when the format is NHWC; ported from megcv
+ */
+void resize_cv_gi_exec(
+        _megdnn_tensor_in src, _megdnn_tensor_out dst,
+        param::Resize::InterpolationMode imode);
+
+}  // namespace fallback
+}  // namespace megdnn
+
+// vim: syntax=cpp.doxygen
@@ -0,0 +1,189 @@
+#include "src/fallback/resize/gi/upsample2_nchw.h"
+
+using namespace megdnn;
+using namespace fallback;
+using namespace resize;
+
+namespace {
+
+template <typename ctype, size_t fh, size_t fw>
+static GI_FORCEINLINE ctype
+compute_linear_element(const ctype src[4], const ctype alpha[2]) {
+    return src[0] * alpha[0 ^ fh] * alpha[0 ^ fw] +
+           src[1] * alpha[0 ^ fh] * alpha[1 ^ fw] +
+           src[2] * alpha[1 ^ fh] * alpha[0 ^ fw] +
+           src[3] * alpha[1 ^ fh] * alpha[1 ^ fw];
+}
+
+template <typename simd_helper, size_t fh, size_t fw>
+static GI_FORCEINLINE typename simd_helper::simd_type compute_linear_element_simd(
+        const typename simd_helper::simd_type src[4],
+        const typename simd_helper::simd_type alpha[2][2]) {
+    typename simd_helper::simd_type c = simd_helper::dup(0);
+    c = simd_helper::fma(c, src[0], alpha[0 ^ fh][0 ^ fw]);
+    c = simd_helper::fma(c, src[1], alpha[0 ^ fh][1 ^ fw]);
+    c = simd_helper::fma(c, src[2], alpha[1 ^ fh][0 ^ fw]);
+    c = simd_helper::fma(c, src[3], alpha[1 ^ fh][1 ^ fw]);
+    return c;
+}
+
+template <typename ctype, bool has_right, bool has_bottom>
+static GI_FORCEINLINE void compute_linear_2x2_element(
+        const ctype* src, ctype* dst, size_t IW, size_t OW, const ctype alpha[2]) {
+    const ctype* src_ptr[4] = {src, src, src, src};
+    if (has_right) {
+        src_ptr[1] += 1;
+        src_ptr[3] += 1;
+    }
+    if (has_bottom) {
+        src_ptr[2] += IW;
+        src_ptr[3] += IW;
+    }
+
+    ctype rsrc[4];
+    rsrc[0] = *src_ptr[0];
+    rsrc[1] = *src_ptr[1];
+    rsrc[2] = *src_ptr[2];
+    rsrc[3] = *src_ptr[3];
+
+    dst[0] = compute_linear_element<ctype, 0, 0>(rsrc, alpha);
+    if (has_right) {
+        dst[1] = compute_linear_element<ctype, 0, 1>(rsrc, alpha);
+    }
+    if (has_bottom) {
+        dst[OW] = compute_linear_element<ctype, 1, 0>(rsrc, alpha);
+    }
+    if (has_right && has_bottom) {
+        dst[OW + 1] = compute_linear_element<ctype, 1, 1>(rsrc, alpha);
+    }
+}
+
+template <typename simd_helper>
+static GI_FORCEINLINE void compute_linear_2x2_element_simd(
+        const typename simd_helper::ctype* src, typename simd_helper::ctype* dst,
+        size_t IW, size_t OW, const typename simd_helper::simd_type alpha[2][2]) {
+    using simd_type = typename simd_helper::simd_type;
+
+    simd_type rsrc[4];
+    rsrc[0] = simd_helper::load(src);
+    rsrc[1] = simd_helper::load(src + 1);
+    rsrc[2] = simd_helper::load(src + IW);
+    rsrc[3] = simd_helper::load(src + IW + 1);
+
+    simd_type rdst[4];
+    rdst[0] = compute_linear_element_simd<simd_helper, 0, 0>(rsrc, alpha);
+    rdst[1] = compute_linear_element_simd<simd_helper, 0, 1>(rsrc, alpha);
+    rdst[2] = compute_linear_element_simd<simd_helper, 1, 0>(rsrc, alpha);
+    rdst[3] = compute_linear_element_simd<simd_helper, 1, 1>(rsrc, alpha);
+
+    simd_helper::store2_interleave(dst, rdst[0], rdst[1]);
+    simd_helper::store2_interleave(dst + OW, rdst[2], rdst[3]);
+}
+
+template <typename ctype>
+void linear_upsample2_nchw(
+        const ctype* src_ptr, ctype* dst_ptr, size_t N, size_t IH, size_t IW) {
+    using simd_helper = SIMDHelper<ctype>;
+    size_t OW = IW * 2;
+    constexpr size_t PC = simd_helper::simd_width;
+
+    ctype alpha[2] = {0.75, 0.25};
+    typename simd_helper::simd_type simd_alpha[2][2];
+    simd_alpha[0][0] = simd_helper::dup(0.75 * 0.75);
+    simd_alpha[0][1] = simd_helper::dup(0.75 * 0.25);
+    simd_alpha[1][0] = simd_helper::dup(0.25 * 0.75);
+    simd_alpha[1][1] = simd_helper::dup(0.25 * 0.25);
+
+    for (size_t i = 0; i < N; ++i) {
+        // first output row: no row above, so only horizontal blending
+        compute_linear_2x2_element<ctype, false, false>(
+                src_ptr, dst_ptr, IW, OW, alpha);
+        {
+            for (size_t iw = 0; iw + 1 < IW; ++iw) {
+                compute_linear_2x2_element<ctype, true, false>(
+                        src_ptr + iw, dst_ptr + (iw * 2 + 1), IW, OW, alpha);
+            }
+        }
+        compute_linear_2x2_element<ctype, false, false>(
+                src_ptr + (IW - 1), dst_ptr + (OW - 1), IW, OW, alpha);
+        dst_ptr += OW;
+
+        for (size_t ih = 0; ih + 1 < IH; ++ih) {
+            compute_linear_2x2_element<ctype, false, true>(
+                    src_ptr, dst_ptr, IW, OW, alpha);
+            size_t iw = 0;
+            for (; iw + PC < IW; iw += PC) {
+                compute_linear_2x2_element_simd<simd_helper>(
+                        src_ptr + iw, dst_ptr + (iw * 2 + 1), IW, OW, simd_alpha);
+            }
+            for (; iw + 1 < IW; ++iw) {
+                compute_linear_2x2_element<ctype, true, true>(
+                        src_ptr + iw, dst_ptr + (iw * 2 + 1), IW, OW, alpha);
+            }
+            compute_linear_2x2_element<ctype, false, true>(
+                    src_ptr + (IW - 1), dst_ptr + (OW - 1), IW, OW, alpha);
+            src_ptr += IW;
+            dst_ptr += 2 * OW;
+        }
+
+        // last output row mirrors the first
+        compute_linear_2x2_element<ctype, false, false>(
+                src_ptr, dst_ptr, IW, OW, alpha);
+        {
+            for (size_t iw = 0; iw + 1 < IW; ++iw) {
+                compute_linear_2x2_element<ctype, true, false>(
+                        src_ptr + iw, dst_ptr + (iw * 2 + 1), IW, OW, alpha);
+            }
+        }
+        compute_linear_2x2_element<ctype, false, false>(
+                src_ptr + (IW - 1), dst_ptr + (OW - 1), IW, OW, alpha);
+        src_ptr += IW;
+        dst_ptr += OW;
+    }
+}
+
+template <typename ctype>
+void nearest_upsample2_nchw(
+        const ctype* src_ptr, ctype* dst_ptr, size_t N, size_t IH, size_t IW) {
+    using simd_helper = SIMDHelper<ctype>;
+    size_t OW = IW * 2;
+    constexpr size_t PC = simd_helper::simd_width;
+
+    for (size_t i = 0; i < N; ++i) {
+        for (size_t ih = 0; ih < IH; ++ih) {
+            size_t iw = 0;
+            for (; iw + PC - 1 < IW; iw += PC) {
+                typename simd_helper::simd_type r0 = simd_helper::load(src_ptr + iw);
+                simd_helper::store2_interleave(dst_ptr + (iw * 2), r0, r0);
+                simd_helper::store2_interleave(dst_ptr + (OW + iw * 2), r0, r0);
+            }
+            for (; iw < IW; iw += 1) {
+                ctype v = src_ptr[iw];
+                dst_ptr[iw * 2] = v;
+                dst_ptr[iw * 2 + 1] = v;
+                dst_ptr[OW + iw * 2] = v;
+                dst_ptr[OW + iw * 2 + 1] = v;
+            }
+            src_ptr += IW;
+            dst_ptr += 2 * OW;
+        }
+    }
+}
+}  // namespace
+
+void megdnn::fallback::resize_linear_upsample2_nchw_gi_fp32(
+        const ResizeImpl::KernParam<float>& kern_param) {
+    linear_upsample2_nchw(
+            kern_param.src(), kern_param.dst(), kern_param.n * kern_param.c,
+            kern_param.ih, kern_param.iw);
+}
+
+void megdnn::fallback::resize_nearest_upsample2_nchw_gi_fp32(
+        const ResizeImpl::KernParam<float>& kern_param) {
+    nearest_upsample2_nchw(
+            kern_param.src(), kern_param.dst(), kern_param.n * kern_param.c,
+            kern_param.ih, kern_param.iw);
+}
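The constant weights `{0.75, 0.25}` follow from the same half-pixel mapping `get_nearest_linear_coord` uses: for a 2x upsample, output row `2i` lands at input coordinate `i - 0.25` and row `2i + 1` at `i + 0.25`, so every interior output is a fixed 3:1 blend of its two nearest inputs. A 1D scalar sketch (hypothetical helper, interior samples only):

```cpp
// 1D reference for one interior input sample i: out[2i] sits 0.25 below input
// i, out[2i + 1] sits 0.25 above it, giving the fixed weight pair {0.75, 0.25}.
inline void upsample2_linear_1d_point(const float* in, float* out, size_t i) {
    out[2 * i] = 0.75f * in[i] + 0.25f * in[i - 1];
    out[2 * i + 1] = 0.75f * in[i] + 0.25f * in[i + 1];
}
// The 2D kernel multiplies two such 1D weights, which yields the four products
// {0.75, 0.25} x {0.75, 0.25} pre-broadcast into simd_alpha above.
```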
@@ -0,0 +1,14 @@
+#pragma once
+
+#include "src/fallback/resize/gi/helper.h"
+
+namespace megdnn {
+namespace fallback {
+
+void resize_linear_upsample2_nchw_gi_fp32(
+        const ResizeImpl::KernParam<float>& kern_param);
+
+void resize_nearest_upsample2_nchw_gi_fp32(
+        const ResizeImpl::KernParam<float>& kern_param);
+
+}  // namespace fallback
+}  // namespace megdnn
@@ -0,0 +1,155 @@
+#include "src/fallback/resize/gi/upsample2_nchwxx.h"
+
+using namespace megdnn;
+using namespace fallback;
+using namespace resize;
+
+namespace {
+
+template <typename simd_helper, size_t fh, size_t fw>
+static GI_FORCEINLINE typename simd_helper::simd_type compute_linear_element(
+        const typename simd_helper::simd_type src[4],
+        const typename simd_helper::simd_type alpha[2][2]) {
+    typename simd_helper::simd_type c = simd_helper::dup(0);
+    c = simd_helper::fma(c, src[0], alpha[0 ^ fh][0 ^ fw]);
+    c = simd_helper::fma(c, src[1], alpha[0 ^ fh][1 ^ fw]);
+    c = simd_helper::fma(c, src[2], alpha[1 ^ fh][0 ^ fw]);
+    c = simd_helper::fma(c, src[3], alpha[1 ^ fh][1 ^ fw]);
+    return c;
+}
+
+template <typename simd_helper, bool has_right, bool has_bottom>
+static GI_FORCEINLINE void compute_linear_2x2_element(
+        const typename simd_helper::ctype* src, typename simd_helper::ctype* dst,
+        size_t IW, size_t OW, const typename simd_helper::simd_type alpha[2][2]) {
+    constexpr size_t PC = simd_helper::simd_width;
+    const typename simd_helper::ctype* src_ptr[4] = {src, src, src, src};
+
+    if (has_right) {
+        src_ptr[1] += PC;
+        src_ptr[3] += PC;
+    }
+    if (has_bottom) {
+        src_ptr[2] += IW * PC;
+        src_ptr[3] += IW * PC;
+    }
+
+    typename simd_helper::simd_type rsrc[4];
+    rsrc[0] = simd_helper::load(src_ptr[0]);
+    rsrc[1] = simd_helper::load(src_ptr[1]);
+    rsrc[2] = simd_helper::load(src_ptr[2]);
+    rsrc[3] = simd_helper::load(src_ptr[3]);
+
+    typename simd_helper::simd_type rdst[4];
+    rdst[0] = compute_linear_element<simd_helper, 0, 0>(rsrc, alpha);
+    rdst[1] = compute_linear_element<simd_helper, 0, 1>(rsrc, alpha);
+    rdst[2] = compute_linear_element<simd_helper, 1, 0>(rsrc, alpha);
+    rdst[3] = compute_linear_element<simd_helper, 1, 1>(rsrc, alpha);
+
+    simd_helper::store(dst, rdst[0]);
+    if (has_right) {
+        simd_helper::store(dst + PC, rdst[1]);
+    }
+    if (has_bottom) {
+        simd_helper::store(dst + OW * PC, rdst[2]);
+    }
+    if (has_right && has_bottom) {
+        simd_helper::store(dst + (OW + 1) * PC, rdst[3]);
+    }
+}
+
+template <typename ctype>
+void linear_upsample2_nchwxx(
+        const ctype* src_ptr, ctype* dst_ptr, size_t N, size_t IH, size_t IW) {
+    using simd_helper = SIMDHelper<ctype>;
+    size_t OW = IW * 2;
+    constexpr size_t PC = simd_helper::simd_width;
+
+    typename simd_helper::simd_type alpha[2][2];
+    alpha[0][0] = simd_helper::dup(0.75 * 0.75);
+    alpha[0][1] = simd_helper::dup(0.75 * 0.25);
+    alpha[1][0] = simd_helper::dup(0.25 * 0.75);
+    alpha[1][1] = simd_helper::dup(0.25 * 0.25);
+
+    for (size_t i = 0; i < N; ++i) {
+        compute_linear_2x2_element<simd_helper, false, false>(
+                src_ptr, dst_ptr, IW, OW, alpha);
+        {
+            for (size_t iw = 0; iw + 1 < IW; ++iw) {
+                compute_linear_2x2_element<simd_helper, true, false>(
+                        src_ptr + iw * PC, dst_ptr + (iw * 2 + 1) * PC, IW, OW, alpha);
+            }
+        }
+        compute_linear_2x2_element<simd_helper, false, false>(
+                src_ptr + (IW - 1) * PC, dst_ptr + (OW - 1) * PC, IW, OW, alpha);
+        dst_ptr += OW * PC;
+
+        for (size_t ih = 0; ih + 1 < IH; ++ih) {
+            compute_linear_2x2_element<simd_helper, false, true>(
+                    src_ptr, dst_ptr, IW, OW, alpha);
+            for (size_t iw = 0; iw + 1 < IW; ++iw) {
+                compute_linear_2x2_element<simd_helper, true, true>(
+                        src_ptr + iw * PC, dst_ptr + (iw * 2 + 1) * PC, IW, OW, alpha);
+            }
+            compute_linear_2x2_element<simd_helper, false, true>(
+                    src_ptr + (IW - 1) * PC, dst_ptr + (OW - 1) * PC, IW, OW, alpha);
+            src_ptr += IW * PC;
+            dst_ptr += 2 * OW * PC;
+        }
+
+        compute_linear_2x2_element<simd_helper, false, false>(
+                src_ptr, dst_ptr, IW, OW, alpha);
+        {
+            for (size_t iw = 0; iw + 1 < IW; ++iw) {
+                compute_linear_2x2_element<simd_helper, true, false>(
+                        src_ptr + iw * PC, dst_ptr + (iw * 2 + 1) * PC, IW, OW, alpha);
+            }
+        }
+        compute_linear_2x2_element<simd_helper, false, false>(
+                src_ptr + (IW - 1) * PC, dst_ptr + (OW - 1) * PC, IW, OW, alpha);
+        src_ptr += IW * PC;
+        dst_ptr += OW * PC;
+    }
+}
+
+template <typename ctype>
+void nearest_upsample2_nchwxx(
+        const ctype* src_ptr, ctype* dst_ptr, size_t N, size_t IH, size_t IW) {
+    using simd_helper = SIMDHelper<ctype>;
+    size_t OW = IW * 2;
+    constexpr size_t PC = simd_helper::simd_width;
+
+    for (size_t i = 0; i < N; ++i) {
+        for (size_t ih = 0; ih < IH; ++ih) {
+            for (size_t iw = 0; iw < IW; ++iw) {
+                typename simd_helper::simd_type r0 =
+                        simd_helper::load(src_ptr + iw * PC);
+                simd_helper::store(dst_ptr + (iw * 2) * PC, r0);
+                simd_helper::store(dst_ptr + (iw * 2 + 1) * PC, r0);
+                simd_helper::store(dst_ptr + (OW + iw * 2) * PC, r0);
+                simd_helper::store(dst_ptr + (OW + iw * 2 + 1) * PC, r0);
+            }
+            src_ptr += IW * PC;
+            dst_ptr += 2 * OW * PC;
+        }
+    }
+}
+}  // namespace
+
+void megdnn::fallback::resize_linear_upsample2_nchw44_gi_fp32(
+        const ResizeImpl::KernParam<float>& kern_param) {
+    linear_upsample2_nchwxx(
+            kern_param.src(), kern_param.dst(), kern_param.n * kern_param.c / 4,
+            kern_param.ih, kern_param.iw);
+}
+
+void megdnn::fallback::resize_nearest_upsample2_nchw44_gi_fp32(
+        const ResizeImpl::KernParam<float>& kern_param) {
+    nearest_upsample2_nchwxx(
+            kern_param.src(), kern_param.dst(), kern_param.n * kern_param.c / 4,
+            kern_param.ih, kern_param.iw);
+}
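In the nearest kernel above, each input vector-pixel fills a 2x2 block of output pixels; since a NCHW44 pixel is already a whole vector, duplication is four plain stores rather than the zip store the NCHW kernel needed. A scalar restatement of those four stores (a sketch, not code from the diff):

```cpp
// Scalar reference for one vector-pixel of nearest_upsample2_nchwxx; offsets
// match the four simd_helper::store calls (indices in floats, OW in pixels).
void nearest_upsample2_one_pixel(
        const float* src, float* dst, size_t iw, size_t OW, size_t PC) {
    for (size_t lane = 0; lane < PC; ++lane) {
        float v = src[iw * PC + lane];
        dst[(iw * 2) * PC + lane] = v;
        dst[(iw * 2 + 1) * PC + lane] = v;
        dst[(OW + iw * 2) * PC + lane] = v;
        dst[(OW + iw * 2 + 1) * PC + lane] = v;
    }
}
```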
@@ -0,0 +1,14 @@
+#pragma once
+
+#include "src/fallback/resize/gi/helper.h"
+
+namespace megdnn {
+namespace fallback {
+
+void resize_linear_upsample2_nchw44_gi_fp32(
+        const ResizeImpl::KernParam<float>& kern_param);
+
+void resize_nearest_upsample2_nchw44_gi_fp32(
+        const ResizeImpl::KernParam<float>& kern_param);
+
+}  // namespace fallback
+}  // namespace megdnn
@@ -15,6 +15,14 @@
 #include "src/common/rounding_converter.cuh"
 #include "src/fallback/handle.h"
+#include "src/fallback/resize/gi/direct_nchwxx.h"
+#include "src/fallback/resize/gi/resize_cv.h"
+#include "src/fallback/resize/gi/upsample2_nchw.h"
+#include "src/fallback/resize/gi/upsample2_nchwxx.h"
+
+#include "midout.h"
+
+MIDOUT_DECL(megdnn_fallback_resize)
 using namespace megdnn;
 using namespace fallback;
@@ -108,6 +116,11 @@ void ResizeImpl::kern_fallback_nhwc(const KernParam<ctype>& kern_param) {
 void ResizeImpl::exec(
         _megdnn_tensor_in src, _megdnn_tensor_in dst, _megdnn_workspace workspace) {
     check_exec(src.layout, dst.layout, workspace.size);
+    exec_gi(src, dst, workspace);
+}
+
+void ResizeImpl::exec_fallback(
+        _megdnn_tensor_in src, _megdnn_tensor_in dst, _megdnn_workspace workspace) {
     if (param().format == param::Resize::Format::NCHW4 ||
         param().format == param::Resize::Format::NCHW44 ||
         param().format == param::Resize::Format::NCHW88 ||
@@ -148,4 +161,92 @@ void ResizeImpl::exec(
     naive::ResizeImpl::exec(src, dst, workspace);
 }
+void ResizeImpl::exec_gi(
+        _megdnn_tensor_in src, _megdnn_tensor_in dst, _megdnn_workspace workspace) {
+    bool is_contiguous = src.layout.is_contiguous() && dst.layout.is_contiguous();
+    bool is_dtype_same = src.layout.dtype == dst.layout.dtype;
+    bool is_dtype_fp32 = src.layout.dtype == dtype::Float32();
+    bool is_dtype_supported = is_dtype_same && is_dtype_fp32;
+    bool is_nchw_fp32 = param().format == param::Resize::Format::NCHW && is_dtype_fp32;
+    bool is_nchw44_fp32 =
+            param().format == param::Resize::Format::NCHW44 && is_dtype_fp32;
+    bool is_imode_nearest =
+            param().imode == param::Resize::InterpolationMode::INTER_NEAREST;
+    bool is_imode_linear =
+            param().imode == param::Resize::InterpolationMode::INTER_LINEAR;
+    bool is_imode_supported = is_imode_nearest || is_imode_linear;
+    bool is_upsample2 = src.layout.shape[2] * 2 == dst.layout.shape[2] &&
+                        src.layout.shape[3] * 2 == dst.layout.shape[3];
+    bool usable = is_contiguous && is_dtype_supported && is_imode_supported;
+
+    if (param().format == param::Resize::Format::NHWC &&
+        (src.layout[3] == 1 || src.layout[3] == 3) && is_nhwc_contig_wc(src.layout) &&
+        is_dtype_fp32) {
+        MEGDNN_DISPATCH_CPU_KERN_OPR(resize_cv_gi_exec(src, dst, param().imode));
+    } else if (!usable) {
+        exec_fallback(src, dst, workspace);
+    } else if (is_dtype_fp32) {
+        auto kern_param = KernParam<float>::from_tensors(
+                param().format, param().imode, src, dst, workspace);
+        if (is_nchw44_fp32) {
+            if (is_upsample2) {
+                if (is_imode_nearest) {
+                    MIDOUT_BEGIN(megdnn_fallback_resize, midout_iv(0)) {
+                        MEGDNN_DISPATCH_CPU_KERN_OPR(
+                                resize_nearest_upsample2_nchw44_gi_fp32(kern_param));
+                    }
+                    MIDOUT_END();
+                } else {
+                    megdnn_assert(is_imode_linear, "invalid imode");
+                    MIDOUT_BEGIN(megdnn_fallback_resize, midout_iv(1)) {
+                        MEGDNN_DISPATCH_CPU_KERN_OPR(
+                                resize_linear_upsample2_nchw44_gi_fp32(kern_param));
+                    }
+                    MIDOUT_END();
+                }
+            } else {
+                if (is_imode_nearest) {
+                    MIDOUT_BEGIN(megdnn_fallback_resize, midout_iv(2)) {
+                        MEGDNN_DISPATCH_CPU_KERN_OPR(
+                                resize_direct_nearest_nchw44_gi_fp32(kern_param));
+                    }
+                    MIDOUT_END();
+                } else {
+                    megdnn_assert(is_imode_linear, "invalid imode");
+                    MIDOUT_BEGIN(megdnn_fallback_resize, midout_iv(3)) {
+                        MEGDNN_DISPATCH_CPU_KERN_OPR(
+                                resize_direct_linear_nchw44_gi_fp32(kern_param));
+                    }
+                    MIDOUT_END();
+                }
+            }
+        } else if (is_nchw_fp32) {
+            if (is_upsample2) {
+                if (is_imode_nearest) {
+                    MIDOUT_BEGIN(megdnn_fallback_resize, midout_iv(4)) {
+                        MEGDNN_DISPATCH_CPU_KERN_OPR(
+                                resize_nearest_upsample2_nchw_gi_fp32(kern_param));
+                    }
+                    MIDOUT_END();
+                } else {
+                    megdnn_assert(is_imode_linear, "invalid imode");
+                    MIDOUT_BEGIN(megdnn_fallback_resize, midout_iv(5)) {
+                        MEGDNN_DISPATCH_CPU_KERN_OPR(
+                                resize_linear_upsample2_nchw_gi_fp32(kern_param));
+                    }
+                    MIDOUT_END();
+                }
+            } else {
+                exec_fallback(src, dst, workspace);
+            }
+        } else {
+            exec_fallback(src, dst, workspace);
+        }
+    } else {
+        exec_fallback(src, dst, workspace);
+    }
+}
+
 // vim: syntax=cpp.doxygen
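Condensed, `exec_gi` only takes a GI fast path for contiguous fp32 tensors with a supported interpolation mode; everything else routes to `exec_fallback`. A simplified restatement of the predicate (a sketch, not code from the diff):

```cpp
// Which requests reach a GI kernel instead of exec_fallback (fp32 assumed;
// the NHWC 1/3-channel case is handled separately by resize_cv_gi_exec).
bool hits_gi_fast_path(
        bool contiguous, param::Resize::Format fmt,
        param::Resize::InterpolationMode imode, size_t ih, size_t iw, size_t oh,
        size_t ow) {
    bool imode_ok = imode == param::Resize::InterpolationMode::INTER_NEAREST ||
                    imode == param::Resize::InterpolationMode::INTER_LINEAR;
    bool upsample2 = oh == ih * 2 && ow == iw * 2;
    if (!contiguous || !imode_ok)
        return false;
    if (fmt == param::Resize::Format::NCHW44)
        return true;  // upsample2 kernels, or the direct_nchwxx kernels otherwise
    if (fmt == param::Resize::Format::NCHW)
        return upsample2;  // only the 2x kernels exist for plain NCHW
    return false;  // everything else falls back
}
```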
@@ -35,6 +35,11 @@ private:
     template <typename ctype>
     void kern_fallback_nhwc(const KernParam<ctype>& kern_param);
+    void exec_fallback(
+            _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace);
+
+    void exec_gi(
+            _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace);
+
 };  // class ResizeImpl
 }  // namespace fallback
@@ -358,6 +358,21 @@ TEST_F(FALLBACK, GiLoadFloat32) {
     assert_eq((float*)&ret, naive);
 }
+TEST_F(FALLBACK, GiLoadFloat32V2) {
+    GI_FLOAT32_V2_t ret;
+    std::vector<float> s0{2.3f, 4.7f, -1.4f, 1223.6f, 1.1f, 4.0f, 99.7f, 1234.9f};
+    s0.resize(SIMD_LEN * 2);
+    ret = GiLoadFloat32V2(s0.data());
+
+    std::vector<float> naive;
+    for (size_t i = 0; i < SIMD_LEN * 2; i++) {
+        naive.push_back(s0[i]);
+    }
+
+    assert_eq((float*)&ret, naive, SIMD_LEN * 2);
+}
+
 TEST_F(FALLBACK, GiLoadFloat32LowHalf) {
     GI_FLOAT32_t ret;
     std::vector<float> s0{2.3f, 4.7f, -1.4f, 1223.6f};
@@ -701,6 +716,18 @@ TEST_F(FALLBACK, GiStoreFloat32) {
     assert_eq(ret.data(), s0);
 }
+TEST_F(FALLBACK, GiStoreFloat32V2) {
+    GI_FLOAT32_V2_t src0;
+    std::vector<float> s0{1.1f, 2.2f, 3.5f, 4.9f, -1.1f, -2.2f, -3.5f, -4.9f};
+    s0.resize(SIMD_LEN * 2);
+    init((float*)&src0, s0, SIMD_LEN * 2);
+
+    std::vector<float> ret{0};
+    ret.resize(SIMD_LEN * 2);
+    GiStoreFloat32V2(ret.data(), src0);
+
+    assert_eq(ret.data(), s0, SIMD_LEN * 2);
+}
+
 TEST_F(FALLBACK, GiStoreLaneXXFloat32) {
     GI_FLOAT32_t src0;
     std::vector<float> s0{1.1f, 2.2f, 3.5f, 4.9f};
@@ -1226,7 +1253,7 @@ TEST_F(FALLBACK, GiBSLFloat32) {
         naive.resize(SIMD_LEN);
         memcpy(naive.data(), &na, sizeof(GI_FLOAT32_t));
-        assert_eq((float*)&ret, naive);
+        assert_eq_and_nan((float*)&ret, naive);
     }
 }
@@ -3199,6 +3226,65 @@ TEST_F(FALLBACK, GiPmaxFloat32) {
     ASSERT_LT(std::abs(naive[1] - r[1]), 1e-3);
 }
+TEST_F(FALLBACK, GiStoreZipFloat32V2) {
+    GI_FLOAT32_V2_t src0;
+    std::vector<float> s0{1.1f, 2.2f, 3.5f, 4.9f, 2312.1f, 345.244f, 3.59f, -12.8f};
+    s0.resize(SIMD_LEN * 2);
+    init((float*)&src0, s0, SIMD_LEN * 2);
+
+    std::vector<float> ret;
+    ret.resize(SIMD_LEN * 2);
+    std::vector<float> ret_cmp;
+    ret_cmp.resize(SIMD_LEN * 2);
+
+    GiStoreZipFloat32V2(ret.data(), src0);
+
+    GI_FLOAT32_V2_t tmp;
+    tmp = GiZipqFloat32(src0.val[0], src0.val[1]);
+    GiStoreFloat32(ret_cmp.data(), tmp.val[0]);
+    GiStoreFloat32(ret_cmp.data() + SIMD_LEN, tmp.val[1]);
+
+    assert_eq(ret.data(), ret_cmp, SIMD_LEN * 2);
+}
+
+TEST_F(FALLBACK, GiLoadUzipFloat32V3) {
+    GI_FLOAT32_V3_t ret;
+    std::vector<float> s0{1.1f,  2.2f,   3.5f, 4.9f, 2312.1f, 345.244f,
+                          3.59f, -12.8f, 2.2f, 6.0f, 90.0f,   89.3f};
+    s0.resize(SIMD_LEN * 3);
+    ret = GiLoadUzipFloat32V3(s0.data());
+
+    std::vector<float> naive;
+    for (size_t i = 0; i < 3; i++) {
+        naive.push_back(s0[0 + i]);
+        naive.push_back(s0[3 + i]);
+        naive.push_back(s0[6 + i]);
+        naive.push_back(s0[9 + i]);
+    }
+
+    assert_eq((float*)&ret, naive);
+}
+
+TEST_F(FALLBACK, GiStoreZipFloat32V3) {
+    GI_FLOAT32_V3_t src0;
+    std::vector<float> s0{1.1f,  2.2f,   3.5f,  4.9f,   2312.1f, 345.244f,
+                          3.59f, -12.8f, 3.59f, -12.8f, 2.2f,    6.0f};
+    s0.resize(SIMD_LEN * 3);
+    init((float*)&src0, s0, SIMD_LEN * 3);
+
+    std::vector<float> ret;
+    ret.resize(SIMD_LEN * 3);
+    GiStoreZipFloat32V3(ret.data(), src0);
+
+    std::vector<float> ret_cmp;
+    for (size_t i = 0; i < SIMD_LEN; i++) {
+        ret_cmp.push_back(s0[0 + i]);
+        ret_cmp.push_back(s0[4 + i]);
+        ret_cmp.push_back(s0[8 + i]);
+    }
+
+    assert_eq(ret.data(), ret_cmp, SIMD_LEN * 3);
+}
+
 }  // namespace test
 }  // namespace megdnn
@@ -167,6 +167,48 @@ TEST_F(FALLBACK, RESIZE_NCHW4_RECORD) {
     }
 }
+namespace {
+static void set_nchw_args(resize::IMode imode, std::vector<resize::TestArg>& args) {
+    param::Resize param;
+    param.format = param::Resize::Format::NCHW;
+    param.imode = imode;
+    rep(n, 4ul) rep(c, 4ul) rep(ih, 4ul) rep(iw, 4ul) rep(oh, 4ul) rep(ow, 4ul)
+            args.emplace_back(
+                    param, TensorShape{n + 1ul, c + 1ul, ih + 1ul, iw + 1ul},
+                    TensorShape{n + 1ul, c + 1ul, oh + 1ul, ow + 1ul});
+    args.emplace_back(param, TensorShape{1, 1, 10, 10}, TensorShape{1, 1, 20, 20});
+    args.emplace_back(param, TensorShape{1, 1, 10, 10}, TensorShape{1, 1, 7, 9});
+    args.emplace_back(param, TensorShape{2, 2, 3, 4}, TensorShape{2, 2, 6, 8});
+    args.emplace_back(param, TensorShape{1, 2, 6, 8}, TensorShape{1, 2, 3, 4});
+}
+}  // namespace
+
+TEST_F(FALLBACK, RESIZE_NCHW_FP32) {
+    std::vector<resize::TestArg> args;
+    set_nchw_args(resize::IMode::INTER_LINEAR, args);
+    set_nchw_args(resize::IMode::INTER_NEAREST, args);
+    Checker<Resize> checker(handle());
+
+    for (auto&& arg : args) {
+        checker.set_param(arg.param)
+                .set_dtype(0, dtype::Float32())
+                .set_dtype(1, dtype::Float32())
+                .execs({arg.src, arg.dst});
+    }
+}
+
+TEST_F(FALLBACK, RESIZE_NCHW44_FP32) {
+    std::vector<resize::TestArg> args = resize::get_nchw44_args();
+    Checker<Resize> checker(handle());
+
+    for (auto&& arg : args) {
+        checker.set_param(arg.param)
+                .set_dtype(0, dtype::Float32())
+                .set_dtype(1, dtype::Float32())
+                .execs({arg.src, arg.dst});
+    }
+}
+
 }  // namespace test
 }  // namespace megdnn
 
 // vim: syntax=cpp.doxygen