diff --git a/CMakeLists.txt b/CMakeLists.txt
index 0356ecc8..91b7df8d 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -19,11 +19,14 @@ CHECK_CXX_COMPILER_FLAG(-Wclass-memaccess CXX_SUPPORT_WCLASS_MEMACCESS)
 set(MGE_ARCH AUTO CACHE STRING "Architecture on which MegEngine to be built.")
 set_property(CACHE MGE_ARCH PROPERTY STRINGS
     AUTO
     x86_64 i386
+    armv7 aarch64
     naive fallback
 )
 option(MGE_WITH_JIT "Build MegEngine with JIT." ON)
 option(MGE_WITH_HALIDE "Build MegEngine with Halide JIT" ON)
+option(MGE_ARMV8_2_FEATURE_FP16 "Enable armv8.2-a+fp16 support" OFF)
+option(MGE_ARMV8_2_FEATURE_DOTPROD "enable armv8.2-a+dotprod support" OFF)
 option(MGE_DISABLE_FLOAT16 "Disable MegEngine float16 support." OFF)
 option(MGE_WITH_CUDA "Enable MegEngine CUDA support." ON)
 option(MGE_CUDA_USE_STATIC "Enable MegEngine CUDA static linking." ON)
@@ -31,12 +34,52 @@ option(MGE_WITH_TRT "Build MegEngine with TensorRT." ON)
 option(MGE_USE_SYSTEM_LIB "Build MegEngine with system libraries." OFF)
 option(MGB_WITH_FLATBUFFERS "Build MegBrain with FlatBuffers serialization support." ON)
 
+if(CMAKE_TOOLCHAIN_FILE)
+    message("We are cross compiling.")
+    message("config FLATBUFFERS_FLATC_EXECUTABLE to: ${PROJECT_SOURCE_DIR}/build_dir/host_flatc/install/bin/flatc")
+    set(FLATBUFFERS_FLATC_EXECUTABLE "${PROJECT_SOURCE_DIR}/build_dir/host_flatc/install/bin/flatc")
+    if(ANDROID_TOOLCHAIN_ROOT)
+        if(NOT "${ANDROID_ARCH_NAME}" STREQUAL "")
+            set(ANDROID_ARCH ${ANDROID_ARCH_NAME})
+        endif()
+        if(${ANDROID_ARCH} STREQUAL "arm")
+            set(MGE_ARCH "armv7")
+        elseif(${ANDROID_ARCH} STREQUAL "arm64")
+            set(MGE_ARCH "aarch64")
+        else()
+            message(FATAL_ERROR "DO NOT SUPPORT ANDROID ARCH NOW")
+        endif()
+    elseif(IOS_TOOLCHAIN_ROOT)
+        if(${IOS_ARCH} STREQUAL "armv7")
+            set(MGE_ARCH "armv7")
+        elseif(${IOS_ARCH} STREQUAL "arm64")
+            set(MGE_ARCH "aarch64")
+        elseif(${IOS_ARCH} STREQUAL "armv7k")
+            set(MGE_ARCH "armv7")
+        elseif(${IOS_ARCH} STREQUAL "arm64e")
+            set(MGE_ARCH "aarch64")
+        elseif(${IOS_ARCH} STREQUAL "armv7s")
+            set(MGE_ARCH "armv7")
+        else()
+            message(FATAL_ERROR "Unsupported IOS_ARCH.")
+        endif()
+    elseif(NOT "${ARM_CROSS_BUILD_ARCH}" STREQUAL "")
+        set(MGE_ARCH ${ARM_CROSS_BUILD_ARCH})
+    else()
+        message(FATAL_ERROR "Unknown cross-compiling settings.")
+    endif()
+    message("CONFIG MGE_ARCH TO ${MGE_ARCH}")
+endif()
 if(${MGE_ARCH} STREQUAL "AUTO")
     if(${CMAKE_SYSTEM_PROCESSOR} STREQUAL "x86_64")
         set(MGE_ARCH "x86_64")
     elseif(${CMAKE_SYSTEM_PROCESSOR} STREQUAL "i386" OR ${CMAKE_SYSTEM_PROCESSOR} STREQUAL "i686")
         set(MGE_ARCH "i386")
+    elseif(${CMAKE_SYSTEM_PROCESSOR} STREQUAL "aarch64" OR ${CMAKE_SYSTEM_PROCESSOR} STREQUAL "arm64")
+        set(MGE_ARCH "aarch64")
+    elseif(${CMAKE_SYSTEM_PROCESSOR} MATCHES "^arm")
+        set(MGE_ARCH "armv7")
     else()
         message(FATAL "Unknown machine architecture for MegEngine.")
     endif()
@@ -399,6 +442,38 @@ if(MGE_ARCH STREQUAL "x86_64" OR MGE_ARCH STREQUAL "i386")
     endif()
 endif()
 
+if(MGE_ARCH STREQUAL "armv7")
+    # -funsafe-math-optimizations to enable neon auto-vectorization (since neon is not
+    # fully IEEE 754 compatible, GCC does not turn on neon auto-vectorization by default.)
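+    # In particular, NEON on armv7 treats denormal values as zero, so GCC's auto-vectorizer
+    # emits NEON floating-point code only when -funsafe-math-optimizations waives strict
+    # IEEE 754 behaviour.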
+ if(ANDROID) + set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mfloat-abi=softfp -mfpu=neon") + endif() + set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -funsafe-math-optimizations") + set (MARCH "-march=armv7-a") + set (MEGDNN_ARMV7 1) +endif() + +if(MGE_ARCH STREQUAL "aarch64") + set(MEGDNN_AARCH64 1) + set(MEGDNN_64_BIT 1) + set(MARCH "-march=armv8-a") + if(MGE_ARMV8_2_FEATURE_FP16) + message("Enable fp16 feature support in armv8.2") + if(NOT ${MGE_DISABLE_FLOAT16}) + set(MEGDNN_ENABLE_FP16_NEON 1) + endif() + set(MARCH "-march=armv8.2-a+fp16") + endif() + + if(MGE_ARMV8_2_FEATURE_DOTPROD) + message("Enable dotprod feature support in armv8.2") + if(MGE_ARMV8_2_FEATURE_FP16) + set(MARCH "-march=armv8.2-a+fp16+dotprod") + else() + set(MARCH "-march=armv8.2-a+dotprod") + endif() + endif() + +endif() set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${MARCH}") diff --git a/dnn/include/megdnn/handle.h b/dnn/include/megdnn/handle.h index a84ac2f4..c9296ea6 100644 --- a/dnn/include/megdnn/handle.h +++ b/dnn/include/megdnn/handle.h @@ -29,6 +29,9 @@ class Handle { NAIVE = 0, FALLBACK = 1, X86 = 2, + ARM_COMMON = 3, + ARMV7 = 4, + AARCH64 = 5, CUDA = 6, }; diff --git a/dnn/src/CMakeLists.txt b/dnn/src/CMakeLists.txt index bbfc8a2f..84f83db7 100644 --- a/dnn/src/CMakeLists.txt +++ b/dnn/src/CMakeLists.txt @@ -17,6 +17,22 @@ if(NOT ${MGE_ARCH} STREQUAL "naive") set_source_files_properties(${SOURCES_} PROPERTIES LANGUAGE C) list(APPEND SOURCES ${SOURCES_}) endif() + elseif(${MGE_ARCH} STREQUAL "armv7") + file(GLOB_RECURSE SOURCES_ armv7/*.cpp) + list(APPEND SOURCES ${SOURCES_}) + file(GLOB_RECURSE SOURCES_ arm_common/*.cpp) + list(APPEND SOURCES ${SOURCES_}) + file(GLOB_RECURSE SOURCES_ armv7/*.S) + set_source_files_properties(${SOURCES_} PROPERTIES LANGUAGE C) + list(APPEND SOURCES ${SOURCES_}) + elseif(${MGE_ARCH} STREQUAL "aarch64") + file(GLOB_RECURSE SOURCES_ aarch64/*.cpp) + list(APPEND SOURCES ${SOURCES_}) + file(GLOB_RECURSE SOURCES_ arm_common/*.cpp) + list(APPEND SOURCES ${SOURCES_}) + file(GLOB_RECURSE SOURCES_ aarch64/*.S) + set_source_files_properties(${SOURCES_} PROPERTIES LANGUAGE C) + list(APPEND SOURCES ${SOURCES_}) endif() endif() diff --git a/dnn/src/aarch64/conv_bias/fp16/algos.cpp b/dnn/src/aarch64/conv_bias/fp16/algos.cpp new file mode 100644 index 00000000..4bce406b --- /dev/null +++ b/dnn/src/aarch64/conv_bias/fp16/algos.cpp @@ -0,0 +1,139 @@ +/** + * \file dnn/src/aarch64/conv_bias/fp16/algos.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#include "src/aarch64/conv_bias/fp16/algos.h" +#include "src/aarch64/conv_bias/fp16/stride2_kern.h" +#include "src/arm_common/conv_bias/direct/multi_thread_common.h" +#include "src/arm_common/conv_bias/postprocess_helper.h" + +using namespace megdnn; +using namespace aarch64; +#include "midout.h" +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +/* ===================== stride-2 algo ===================== */ +MIDOUT_DECL(megdnn_aarch64_conv_bias_stride2_conv2357_fp16) + +bool ConvBiasImpl::AlgoF16DirectStride2::usable( + FallbackConvBiasImpl*, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const { + MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 0) { + auto&& fm = param.filter_meta; + auto FH = fm.spatial[0]; + bool aviliable = + param.filter_meta.format == param::Convolution::Format::NCHW && + param.src_type.enumv() == DTypeEnum::Float16 && + param.filter_type.enumv() == DTypeEnum::Float16 && + param.dst_type.enumv() == DTypeEnum::Float16 && + !fm.should_flip && fm.spatial_ndim == 2 && + fm.dilation[0] == 1 && fm.dilation[1] == 1 && + fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] && + (FH == 2 || FH == 3 || FH == 5 || FH == 7); + if (algo_selection_strategy == AlgoSelectionStrategy::HEURISTIC) { + bool large_group = param.filter_meta.group >= param.nr_threads; + aviliable &= (large_group == m_large_group); + } + return aviliable; + } + MIDOUT_END(); + return false; +} + +size_t ConvBiasImpl::AlgoF16DirectStride2::get_workspace( + FallbackConvBiasImpl*, const NCBKernSizeParam& param) const { + MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 1) { + auto wbundle = arm_common::MultithreadDirectConvCommon< + dt_float16, __fp16>::get_bundle_stride(param, m_large_group); + return wbundle.total_size_in_bytes(); + } + MIDOUT_END(); + return false; +} + +SmallVector +ConvBiasImpl::AlgoF16DirectStride2::dispatch_kerns( + FallbackConvBiasImpl*, const NCBKernSizeParam& param) const { + MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 2) { + return get_kimpls(param); + } + MIDOUT_END(); + return {}; +} + +SmallVector +ConvBiasImpl::AlgoF16DirectStride2::get_kimpls( + const NCBKernSizeParam& param) const { + auto fm = param.filter_meta; + auto FH = fm.spatial[0]; + size_t N = param.n; + size_t IC = param.filter_meta.icpg; + size_t OC = param.filter_meta.ocpg; + size_t group = fm.group; + using Func = std::function; + Func conv = nullptr; + if (FH == 2) { + conv = fp16::conv_stride2::do_conv_2x2_stride2; + } else if (FH == 3) { + conv = fp16::conv_stride2::do_conv_3x3_stride2; + } else if (FH == 5) { + conv = fp16::conv_stride2::do_conv_5x5_stride2; + } else if (FH == 7) { + conv = fp16::conv_stride2::do_conv_7x7_stride2; + } + + WorkspaceBundle wbundle = arm_common::MultithreadDirectConvCommon< + dt_float16, __fp16>::get_bundle_stride(param, m_large_group); + SmallVector ret_kerns; + + //! Dense conv and small group + if (m_large_group) { + //! 
Channel wise conv and big groups + auto exec_one_group = [wbundle, conv](const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + auto fm = kern_param.filter_meta; + size_t IC = fm.icpg; + size_t OC = fm.ocpg; + WorkspaceBundle bundle = wbundle; + for (size_t ic = 0; ic < IC; ic++) { + arm_common::MultithreadDirectConvCommon:: + copy_padding_kern_stride(bundle, kern_param, ncb_index, + {ncb_index.thread_id, 0, ic}); + } + for (size_t oc = 0; oc < OC; oc++) { + arm_common::MultithreadDirectConvCommon:: + do_conv_kern_stride(bundle, kern_param, ncb_index, conv, + {ncb_index.thread_id, 0, oc}); + } + }; + ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); + } else { + WorkspaceBundle bundle = wbundle; + auto copy_padding = [bundle](const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + arm_common::MultithreadDirectConvCommon:: + copy_padding_kern_stride(bundle, kern_param, ncb_index, + ncb_index.ndrange_id); + }; + ret_kerns.push_back({copy_padding, {group, N, IC}}); + auto do_conv = [bundle, conv](const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + arm_common::MultithreadDirectConvCommon:: + do_conv_kern_stride(bundle, kern_param, ncb_index, conv, + ncb_index.ndrange_id); + }; + ret_kerns.push_back({do_conv, {group, N, OC}}); + } + return ret_kerns; +} + +#endif + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/conv_bias/fp16/algos.h b/dnn/src/aarch64/conv_bias/fp16/algos.h new file mode 100644 index 00000000..367006ee --- /dev/null +++ b/dnn/src/aarch64/conv_bias/fp16/algos.h @@ -0,0 +1,42 @@ +/** + * \file dnn/src/aarch64/conv_bias/fp16/algos.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include "src/aarch64/conv_bias/opr_impl.h" +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +namespace megdnn { +namespace aarch64 { +/* ===================== stride-2 algo ===================== */ +class ConvBiasImpl::AlgoF16DirectStride2 final : public AlgoBase { + SmallVector get_kimpls(const NCBKernSizeParam& param) const; + bool m_large_group; +public: + AlgoF16DirectStride2(bool large_group) : m_large_group(large_group) {} + bool is_reproducible() const override { return true; } + const char* name() const override { + return m_large_group ? "ARMV8F16STRD2_LARGE_GROUP" + : "ARMV8F16STRD2_SMALL_GROUP"; + } + + bool usable(FallbackConvBiasImpl*, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; + + size_t get_workspace(FallbackConvBiasImpl*, + const NCBKernSizeParam& param) const override; + + SmallVector dispatch_kerns(FallbackConvBiasImpl*, + const NCBKernSizeParam&) const override; +}; +} // namespace aarch64 +} // namespace megdnn +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/conv_bias/fp16/stride2_kern.h b/dnn/src/aarch64/conv_bias/fp16/stride2_kern.h new file mode 100644 index 00000000..92d64907 --- /dev/null +++ b/dnn/src/aarch64/conv_bias/fp16/stride2_kern.h @@ -0,0 +1,1037 @@ +/** + * \file dnn/src/aarch64/conv_bias/fp16/stride2_kern.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#include +#include "src/arm_common/simd_macro/neon_helper_fp16.h" +#include "src/common/utils.h" + +namespace megdnn { +namespace aarch64 { +namespace fp16 { +namespace conv_stride2 { + +static void do_conv_2x2_stride2(const __fp16* src, const __fp16* filter, + __fp16* dst, size_t IH, size_t IW, size_t OH, + size_t OW, size_t IC) { + const size_t tail_step = IW - 2 * OW + IW; + size_t width = OW >> 3; + size_t mod4_left = width & 3; + + rep(ic, IC) { + const __fp16* src_ptr = src + IW * IH * ic; + __fp16* outptr = dst; + + const __fp16* r0 = src_ptr; + const __fp16* r1 = src_ptr + IW; + + const __fp16* k0 = filter; + + MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(k0); + rep(h, OH) { + asm volatile( + "dup v28.8h, %5.h[0] \n" + "dup v29.8h, %5.h[1] \n" + "dup v30.8h, %5.h[2] \n" + "dup v31.8h, %5.h[3] \n" + "cmp %4, #2 \n" + "mov x1, xzr \n" + // mod4_left == 3 + "bgt 0f \n" + // mod4_left == 2 + "beq 1f \n" + "cmp %4, #1 \n" + // mod4_left == 1 + "beq 2f \n" + // mod4_left == 0 + "b 3f \n" + + // mod4_left == 3 + "0: \n" + "ld1 {v0.8h, v1.8h, v2.8h}, [%1] \n" + + "ld2 {v3.8h, v4.8h}, [%2], #32 \n" + "ld2 {v9.8h, v10.8h}, [%3], #32 \n" + "ld2 {v5.8h, v6.8h}, [%2], #32 \n" + "ld2 {v11.8h, v12.8h}, [%3], #32 \n" + "ld2 {v7.8h, v8.8h}, [%2], #32 \n" + "ld2 {v13.8h, v14.8h}, [%3], #32 \n" + "fmla v0.8h, v3.8h, v28.8h \n" + "fmla v1.8h, v5.8h, v28.8h \n" + "fmla v2.8h, v7.8h, v28.8h \n" + "fmla v0.8h, v4.8h, v29.8h \n" + "fmla v1.8h, v6.8h, v29.8h \n" + "fmla v2.8h, v8.8h, v29.8h \n" + + "fmla v0.8h, v9.8h, v30.8h \n" + "fmla v1.8h, v11.8h, v30.8h \n" + "fmla v2.8h, v13.8h, v30.8h \n" + "fmla v0.8h, v10.8h, v31.8h \n" + "fmla v1.8h, v12.8h, v31.8h \n" + "fmla v2.8h, v14.8h, v31.8h \n" + + "add x1, x1, #3 \n" + "st1 {v0.8h, v1.8h, v2.8h}, [%1], #48 \n" + "b 3f \n" + + // mod4_left == 2 + "1: \n" + "ld1 {v0.8h, v1.8h}, [%1] \n" + + "ld2 {v2.8h, v3.8h}, [%2], #32 \n" + "ld2 {v6.8h, v7.8h}, [%3], #32 \n" + "ld2 {v4.8h, v5.8h}, [%2], #32 \n" + "ld2 {v8.8h, v9.8h}, [%3], #32 \n" + "fmla v0.8h, v2.8h, v28.8h \n" + "fmla v1.8h, v4.8h, v28.8h \n" + "fmla v0.8h, v3.8h, v29.8h \n" + "fmla v1.8h, v5.8h, v29.8h \n" + + "fmla v0.8h, v6.8h, v30.8h \n" + "fmla v1.8h, v8.8h, v30.8h \n" + "fmla v0.8h, v7.8h, v31.8h \n" + "fmla v1.8h, v9.8h, v31.8h \n" + + "add x1, x1, #2 \n" + "st1 {v0.8h, v1.8h}, [%1], #32 \n" + "b 3f \n" + + // mod4_left == 1 + "2: \n" + "ld1 {v0.8h}, [%1] \n" + + "ld2 {v1.8h, v2.8h}, [%2], #32 \n" + "ld2 {v3.8h, v4.8h}, [%3], #32 \n" + "fmla v0.8h, v1.8h, v28.8h \n" + "fmla v0.8h, v2.8h, v29.8h \n" + + "fmla v0.8h, v3.8h, v30.8h \n" + "fmla v0.8h, v4.8h, v31.8h \n" + + "add x1, x1, #1 \n" + "st1 {v0.8h}, [%1], #16 \n" + "b 3f \n" + + // mod4_left == 0 + "3: \n" + "cmp %0, x1 \n" + "beq 5f \n" + "4: \n" + "ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [%1] \n" + + "ld2 {v4.8h, v5.8h}, [%2], #32 \n" + "ld2 {v12.8h, v13.8h}, [%3], #32 \n" + "ld2 {v6.8h, v7.8h}, [%2], #32 \n" + "ld2 {v14.8h, v15.8h}, [%3], #32 \n" + "ld2 {v8.8h, v9.8h}, [%2], #32 \n" + "ld2 {v16.8h, v17.8h}, [%3], #32 \n" + "ld2 {v10.8h, v11.8h}, [%2], #32 \n" + "ld2 {v18.8h, v19.8h}, [%3], #32 \n" + "fmla v0.8h, v4.8h, v28.8h \n" + "fmla v1.8h, v6.8h, v28.8h \n" + "fmla v2.8h, v8.8h, v28.8h \n" + "fmla v3.8h, v10.8h, v28.8h \n" + "fmla v0.8h, v5.8h, v29.8h \n" + 
"fmla v1.8h, v7.8h, v29.8h \n" + "fmla v2.8h, v9.8h, v29.8h \n" + "fmla v3.8h, v11.8h, v29.8h \n" + + "fmla v0.8h, v12.8h, v30.8h \n" + "fmla v1.8h, v14.8h, v30.8h \n" + "fmla v2.8h, v16.8h, v30.8h \n" + "fmla v3.8h, v18.8h, v30.8h \n" + "fmla v0.8h, v13.8h, v31.8h \n" + "fmla v1.8h, v15.8h, v31.8h \n" + "fmla v2.8h, v17.8h, v31.8h \n" + "fmla v3.8h, v19.8h, v31.8h \n" + + "add x1, x1, #4 \n" + "cmp %0, x1 \n" + "st1 {v0.8h, v1.8h, v2.8h, v3.8h}, [%1], #64 \n" + "bne 4b \n" + + "5: \n" + : "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1) + : "r"(mod4_left), "w"(_k0123) + : "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v28", "v29", "v30", + "v31"); + + r0 += tail_step; + r1 += tail_step; + } + + filter += 4; + } +} + +static void do_conv_3x3_stride2(const __fp16* src, const __fp16* filter, + __fp16* dst, size_t IH, size_t IW, size_t OH, + size_t OW, size_t IC) { + const size_t tail_step = IW - 2 * OW + IW; + size_t width = OW >> 3; + size_t mod3_left = width % 3; + + rep(ic, IC) { + const __fp16* src_ptr = src + IW * IH * ic; + __fp16* outptr = dst; + + const __fp16* r0 = src_ptr; + const __fp16* r1 = src_ptr + IW; + const __fp16* r2 = src_ptr + IW * 2; + + const __fp16* k0 = filter; + const __fp16* k1 = filter + 3; + const __fp16* k2 = filter + 5; + + MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(k0); + MEGDNN_SIMD_TYPE _k3456 = MEGDNN_SIMD_LOADU(k1); + MEGDNN_SIMD_TYPE _k5678 = MEGDNN_SIMD_LOADU(k2); + rep(h, OH) { + asm volatile( + "dup v21.8h, %6.h[0] \n" + "dup v22.8h, %6.h[1] \n" + "dup v23.8h, %6.h[2] \n" + "dup v24.8h, %6.h[3] \n" + "dup v25.8h, %7.h[1] \n" + "dup v26.8h, %7.h[2] \n" + "dup v27.8h, %7.h[3] \n" + "dup v28.8h, %8.h[2] \n" + "dup v29.8h, %8.h[3] \n" + "cmp %5, #1 \n" + "mov x1, xzr \n" + "bgt 0f \n" // mod3_left == 2 + "beq 1f \n" // mod3_left == 1 + "blt 2f \n" // mod3_left == 0 + + "0: \n" + "ld1 {v0.8h, v1.8h}, [%1] \n" + + "ld2 {v2.8h, v3.8h}, [%2], #32 \n" + "ld2 {v9.8h, v10.8h}, [%3], #32 \n" + "ld2 {v4.8h, v5.8h}, [%2], #32 \n" + "ld2 {v11.8h, v12.8h}, [%3], #32 \n" + "fmla v0.8h, v2.8h, v21.8h \n" + "fmla v1.8h, v4.8h, v21.8h \n" + "fmla v0.8h, v3.8h, v22.8h \n" + "fmla v1.8h, v5.8h, v22.8h \n" + "ld1 {v6.8h}, [%2] \n" + "ld1 {v13.8h}, [%3] \n" + "ext v7.16b, v2.16b, v4.16b, #2 \n" + "ext v8.16b, v4.16b, v6.16b, #2 \n" + "fmla v0.8h, v7.8h, v23.8h \n" + "fmla v1.8h, v8.8h, v23.8h \n" + + "ld2 {v2.8h, v3.8h}, [%4], #32 \n" + "fmla v0.8h, v9.8h, v24.8h \n" + "fmla v1.8h, v11.8h, v24.8h \n" + "fmla v0.8h, v10.8h, v25.8h \n" + "fmla v1.8h, v12.8h, v25.8h \n" + "ld2 {v4.8h, v5.8h}, [%4], #32 \n" + "ext v14.16b, v9.16b, v11.16b, #2 \n" + "ext v15.16b, v11.16b, v13.16b, #2 \n" + "fmla v0.8h, v14.8h, v26.8h \n" + "fmla v1.8h, v15.8h, v26.8h \n" + + "ld1 {v6.8h}, [%4] \n" + "fmla v0.8h, v2.8h, v27.8h \n" + "fmla v1.8h, v4.8h, v27.8h \n" + "fmla v0.8h, v3.8h, v28.8h \n" + "fmla v1.8h, v5.8h, v28.8h \n" + "ext v7.16b, v2.16b, v4.16b, #2 \n" + "ext v8.16b, v4.16b, v6.16b, #2 \n" + "fmla v0.8h, v7.8h, v29.8h \n" + "fmla v1.8h, v8.8h, v29.8h \n" + + "add x1, x1, #2 \n" + "cmp %0, x1 \n" + + "st1 {v0.8h, v1.8h}, [%1], #32 \n" + "bne 2f \n" // if width != 2 jump to 2 + "b 3f \n" // jump end + + "1: \n" + "ld1 {v0.8h}, [%1] \n" + "ld2 {v1.8h, v2.8h}, [%2], #32 \n" + + "ld2 {v5.8h, v6.8h}, [%3], #32 \n" + "ld1 {v3.8h}, [%2] \n" + "fmla v0.8h, v1.8h, v21.8h \n" + "ext v7.16b, v1.16b, v3.16b, #2 \n" + "fmla v0.8h, v2.8h, v22.8h \n" + "ld1 {v1.8h}, [%3] \n" + "fmla v0.8h, 
v7.8h, v23.8h \n" + "ld2 {v3.8h, v4.8h}, [%4], #32 \n" + + "fmla v0.8h, v5.8h, v24.8h \n" + "ext v7.16b, v5.16b, v1.16b, #2 \n" + "fmla v0.8h, v6.8h, v25.8h \n" + "ld1 {v5.8h}, [%4] \n" + "fmla v0.8h, v7.8h, v26.8h \n" + + "fmla v0.8h, v3.8h, v27.8h \n" + "fmla v0.8h, v4.8h, v28.8h \n" + "ext v7.16b, v3.16b, v5.16b, #2 \n" + "fmla v0.8h, v7.8h, v29.8h \n" + + "st1 {v0.8h}, [%1], #16 \n" + + "add x1, x1, #1 \n" + "cmp %0, x1 \n" + "beq 3f \n" + + "2: \n" + "ld1 {v0.8h, v1.8h, v2.8h}, [%1] \n" + + "ld2 {v3.8h, v4.8h}, [%2], #32 \n" + "ld2 {v11.8h, v12.8h}, [%3], #32 \n" + "ld2 {v5.8h, v6.8h}, [%2], #32 \n" + "ld2 {v13.8h, v14.8h}, [%3], #32 \n" + "ld2 {v7.8h, v8.8h}, [%2], #32 \n" + "ld2 {v15.8h, v16.8h}, [%3], #32 \n" + "fmla v0.8h, v3.8h, v21.8h \n" + "fmla v1.8h, v5.8h, v21.8h \n" + "fmla v2.8h, v7.8h, v21.8h \n" + "ld1 {v9.8h}, [%2] \n" + "ld1 {v17.8h}, [%3] \n" + "fmla v0.8h, v4.8h, v22.8h \n" + "fmla v1.8h, v6.8h, v22.8h \n" + "fmla v2.8h, v8.8h, v22.8h \n" + "ext v10.16b, v3.16b, v5.16b, #2 \n" + "ext v4.16b, v5.16b, v7.16b, #2 \n" + "ext v6.16b, v7.16b, v9.16b, #2 \n" + "fmla v0.8h, v10.8h, v23.8h \n" + "fmla v1.8h, v4.8h, v23.8h \n" + "fmla v2.8h, v6.8h, v23.8h \n" + + "ld2 {v3.8h, v4.8h}, [%4], #32 \n" + "fmla v0.8h, v11.8h, v24.8h \n" + "fmla v1.8h, v13.8h, v24.8h \n" + "fmla v2.8h, v15.8h, v24.8h \n" + "ld2 {v5.8h, v6.8h}, [%4], #32 \n" + "fmla v0.8h, v12.8h, v25.8h \n" + "fmla v1.8h, v14.8h, v25.8h \n" + "fmla v2.8h, v16.8h, v25.8h \n" + "ld2 {v7.8h, v8.8h}, [%4], #32 \n" + "ext v18.16b, v11.16b, v13.16b, #2 \n" + "ext v12.16b, v13.16b, v15.16b, #2 \n" + "ext v14.16b, v15.16b, v17.16b, #2 \n" + "fmla v0.8h, v18.8h, v26.8h \n" + "fmla v1.8h, v12.8h, v26.8h \n" + "fmla v2.8h, v14.8h, v26.8h \n" + + "ld1 {v9.8h}, [%4] \n" + "fmla v0.8h, v3.8h, v27.8h \n" + "fmla v1.8h, v5.8h, v27.8h \n" + "fmla v2.8h, v7.8h, v27.8h \n" + "fmla v0.8h, v4.8h, v28.8h \n" + "fmla v1.8h, v6.8h, v28.8h \n" + "fmla v2.8h, v8.8h, v28.8h \n" + "ext v10.16b, v3.16b, v5.16b, #2 \n" + "ext v4.16b, v5.16b, v7.16b, #2 \n" + "ext v6.16b, v7.16b, v9.16b, #2 \n" + "fmla v0.8h, v10.8h, v29.8h \n" + "fmla v1.8h, v4.8h, v29.8h \n" + "fmla v2.8h, v6.8h, v29.8h \n" + + "add x1, x1, #3 \n" + "cmp %0, x1 \n" + + "st1 {v0.8h, v1.8h, v2.8h}, [%1], #48 \n" + "bne 2b \n" // if + "3: \n" + : "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2) + : "r"(mod3_left), "w"(_k0123), "w"(_k3456), "w"(_k5678) + : "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v21", "v22", "v23", "v24", + "v25", "v26", "v27", "v28", "v29"); + + r0 += tail_step; + r1 += tail_step; + r2 += tail_step; + } + + filter += 9; + } +} + +static void do_conv_5x5_stride2(const __fp16* src, const __fp16* filter, + __fp16* dst, size_t IH, size_t IW, size_t OH, + size_t OW, size_t IC) { + const size_t tail_step = IW - 2 * OW + IW; + size_t width = OW >> 3; + size_t mod2_left = width & 1; + + rep(ic, IC) { + const __fp16* src_ptr = src + IW * IH * ic; + __fp16* outptr = dst; + + const __fp16* r0 = src_ptr; + const __fp16* r1 = src_ptr + IW; + const __fp16* r2 = src_ptr + IW * 2; + const __fp16* r3 = src_ptr + IW * 3; + const __fp16* r4 = src_ptr + IW * 4; + + register MEGDNN_SIMD_TYPE _k0123 asm("v0") = MEGDNN_SIMD_LOADU(filter); + register MEGDNN_SIMD_TYPE _k4567 asm("v1") = + MEGDNN_SIMD_LOADU(filter + 4); + register MEGDNN_SIMD_TYPE _k891011 asm("v2") = + MEGDNN_SIMD_LOADU(filter + 8); + register MEGDNN_SIMD_TYPE _k12131415 asm("v3") = + 
MEGDNN_SIMD_LOADU(filter + 12); + register MEGDNN_SIMD_TYPE _k16171819 asm("v4") = + MEGDNN_SIMD_LOADU(filter + 16); + register MEGDNN_SIMD_TYPE _k20212223 asm("v5") = + MEGDNN_SIMD_LOADU(filter + 20); + register MEGDNN_SIMD_TYPE _k24242424 asm("v6") = + MEGDNN_SIMD_SET1(filter[24]); + + for (size_t i = 0; i < OH; i++) { + asm volatile( + "cmp %14, #0 \n" + "mov x1, xzr \n" + "beq 1f \n" + + // mod2_left == 1 + "0: \n" + "ld1 {v7.8h}, [%1] \n" + + // v8.8h: 0 2 4 6 v9.8h: 1 3 5 7 + "ld2 {v8.8h, v9.8h}, [%2], #32 \n" + "ld2 {v15.8h, v16.8h}, [%3], #32 \n" + "fmla v7.8h, v8.8h, %7.h[0] \n" + "fmla v7.8h, v9.8h, %7.h[1] \n" + "ld2 {v10.8h, v11.8h}, [%2] \n" + "ld2 {v17.8h, v18.8h}, [%3] \n" + // v12.8h: 2 4 6 8 + "ext v12.16b, v8.16b, v10.16b, #2 \n" + // v13.8h: 3 5 7 9 + "ext v13.16b, v9.16b, v11.16b, #2 \n" + "fmla v7.8h, v12.8h, %7.h[2] \n" + "fmla v7.8h, v13.8h, %7.h[3] \n" + // v14.8h: 4 6 8 10 + "ext v14.16b, v8.16b, v10.16b, #4 \n" + "fmla v7.8h, v14.8h, %8.h[0] \n" + + "ld2 {v8.8h, v9.8h}, [%4], #32 \n" + "fmla v7.8h, v15.8h, %8.h[1] \n" + "fmla v7.8h, v16.8h, %8.h[2] \n" + "ld2 {v10.8h, v11.8h}, [%4] \n" + "ext v19.16b, v15.16b, v17.16b, #2 \n" + "ext v20.16b, v16.16b, v18.16b, #2 \n" + "fmla v7.8h, v19.8h, %8.h[3] \n" + "fmla v7.8h, v20.8h, %9.h[0] \n" + "ext v21.16b, v15.16b, v17.16b, #4 \n" + "fmla v7.8h, v21.8h, %9.h[1] \n" + + "ld2 {v15.8h, v16.8h}, [%5], #32 \n" + "fmla v7.8h, v8.8h, %9.h[2] \n" + "fmla v7.8h, v9.8h, %9.h[3] \n" + "ld2 {v17.8h, v18.8h}, [%5] \n" + "ext v12.16b, v8.16b, v10.16b, #2 \n" + "ext v13.16b, v9.16b, v11.16b, #2 \n" + "fmla v7.8h, v12.8h, %10.h[0] \n" + "fmla v7.8h, v13.8h, %10.h[1] \n" + "ext v14.16b, v8.16b, v10.16b, #4 \n" + "fmla v7.8h, v14.8h, %10.h[2] \n" + + "ld2 {v8.8h, v9.8h}, [%6], #32 \n" + "fmla v7.8h, v15.8h, %10.h[3] \n" + "fmla v7.8h, v16.8h, %11.h[0] \n" + "ld2 {v10.8h, v11.8h}, [%6] \n" + "ext v19.16b, v15.16b, v17.16b, #2 \n" + "ext v20.16b, v16.16b, v18.16b, #2 \n" + "fmla v7.8h, v19.8h, %11.h[1] \n" + "fmla v7.8h, v20.8h, %11.h[2] \n" + "ext v21.16b, v15.16b, v17.16b, #4 \n" + "fmla v7.8h, v21.8h, %11.h[3] \n" + + "fmla v7.8h, v8.8h, %12.h[0] \n" + "fmla v7.8h, v9.8h, %12.h[1] \n" + "ext v12.16b, v8.16b, v10.16b, #2 \n" + "ext v13.16b, v9.16b, v11.16b, #2 \n" + "fmla v7.8h, v12.8h, %12.h[2] \n" + "fmla v7.8h, v13.8h, %12.h[3] \n" + "ext v14.16b, v8.16b, v10.16b, #4 \n" + "fmla v7.8h, v14.8h, %13.h[0] \n" + + "add x1, x1, #1 \n" + "st1 {v7.8h}, [%1], #16 \n" + + "1: \n" + "cmp %0, x1 \n" + "beq 3f \n" + + // mod2_left == 0 + "2: \n" + "ld1 {v7.8h, v8.8h}, [%1] \n" + + // v9.8h: 0 2 4 6 v10.8h: 1 3 5 7 + "ld2 {v9.8h, v10.8h}, [%2], #32 \n" + "ld2 {v21.8h, v22.8h}, [%3], #32 \n" + // v11.8h: 8 10 12 14 v12.8h: 9 11 13 15 + "ld2 {v11.8h, v12.8h}, [%2], #32 \n" + "ld2 {v23.8h, v24.8h}, [%3], #32 \n" + // v13.8h: 16 18 20 22 v14.8h: 17 19 21 23 + "ld2 {v13.8h, v14.8h}, [%2] \n" + "ld2 {v25.8h, v26.8h}, [%3] \n" + // v15.8h: 2 4 6 8 + "ext v15.16b, v9.16b, v11.16b, #2 \n" + // v16.8h: 3 5 7 9 + "ext v16.16b, v10.16b, v12.16b, #2 \n" + // v17.8h: 4 6 8 10 + "ext v17.16b, v9.16b, v11.16b, #4 \n" + // v18.8h: 10 12 14 16 + "ext v18.16b, v11.16b, v13.16b, #2 \n" + // v19.8h: 11 13 15 17 + "ext v19.16b, v12.16b, v14.16b, #2 \n" + // v20.8h: 12 14 16 18 + "ext v20.16b, v11.16b, v13.16b, #4 \n" + "fmla v7.8h, v9.8h, %7.h[0] \n" + "fmla v7.8h, v10.8h, %7.h[1] \n" + "fmla v7.8h, v15.8h, %7.h[2] \n" + "fmla v7.8h, v16.8h, %7.h[3] \n" + "fmla v7.8h, v17.8h, %8.h[0] \n" + "fmla v8.8h, v11.8h, %7.h[0] \n" + "fmla v8.8h, v12.8h, %7.h[1] \n" + "fmla 
v8.8h, v18.8h, %7.h[2] \n" + "fmla v8.8h, v19.8h, %7.h[3] \n" + "fmla v8.8h, v20.8h, %8.h[0] \n" + + "ld2 {v9.8h, v10.8h}, [%4], #32 \n" + "ext v27.16b, v21.16b, v23.16b, #2 \n" + "ext v28.16b, v22.16b, v24.16b, #2 \n" + "ext v29.16b, v21.16b, v23.16b, #4 \n" + "fmla v7.8h, v21.8h, %8.h[1] \n" + "fmla v7.8h, v22.8h, %8.h[2] \n" + "fmla v7.8h, v27.8h, %8.h[3] \n" + "fmla v7.8h, v28.8h, %9.h[0] \n" + "fmla v7.8h, v29.8h, %9.h[1] \n" + "ld2 {v11.8h, v12.8h}, [%4], #32 \n" + "ext v30.16b, v23.16b, v25.16b, #2 \n" + "ext v31.16b, v24.16b, v26.16b, #2 \n" + "ext v21.16b, v23.16b, v25.16b, #4 \n" + "ld2 {v13.8h, v14.8h}, [%4] \n" + "fmla v8.8h, v23.8h, %8.h[1] \n" + "fmla v8.8h, v24.8h, %8.h[2] \n" + "fmla v8.8h, v30.8h, %8.h[3] \n" + "fmla v8.8h, v31.8h, %9.h[0] \n" + "fmla v8.8h, v21.8h, %9.h[1] \n" + + "ld2 {v21.8h, v22.8h}, [%5], #32 \n" + "ext v15.16b, v9.16b, v11.16b, #2 \n" + "ext v16.16b, v10.16b, v12.16b, #2 \n" + "ext v17.16b, v9.16b, v11.16b, #4 \n" + "ext v18.16b, v11.16b, v13.16b, #2 \n" + "ext v19.16b, v12.16b, v14.16b, #2 \n" + "ext v20.16b, v11.16b, v13.16b, #4 \n" + "ld2 {v23.8h, v24.8h}, [%5], #32 \n" + "fmla v7.8h, v9.8h, %9.h[2] \n" + "fmla v7.8h, v10.8h, %9.h[3] \n" + "fmla v7.8h, v15.8h, %10.h[0] \n" + "fmla v7.8h, v16.8h, %10.h[1] \n" + "fmla v7.8h, v17.8h, %10.h[2] \n" + "ld2 {v25.8h, v26.8h}, [%5] \n" + "fmla v8.8h, v11.8h, %9.h[2] \n" + "fmla v8.8h, v12.8h, %9.h[3] \n" + "fmla v8.8h, v18.8h, %10.h[0] \n" + "fmla v8.8h, v19.8h, %10.h[1] \n" + "fmla v8.8h, v20.8h, %10.h[2] \n" + + "ld2 {v9.8h, v10.8h}, [%6], #32 \n" + "ext v27.16b, v21.16b, v23.16b, #2 \n" + "ext v28.16b, v22.16b, v24.16b, #2 \n" + "ext v29.16b, v21.16b, v23.16b, #4 \n" + "fmla v7.8h, v21.8h, %10.h[3] \n" + "fmla v7.8h, v22.8h, %11.h[0] \n" + "fmla v7.8h, v27.8h, %11.h[1] \n" + "fmla v7.8h, v28.8h, %11.h[2] \n" + "fmla v7.8h, v29.8h, %11.h[3] \n" + "ld2 {v11.8h, v12.8h}, [%6], #32 \n" + "ext v30.16b, v23.16b, v25.16b, #2 \n" + "ext v31.16b, v24.16b, v26.16b, #2 \n" + "ext v21.16b, v23.16b, v25.16b, #4 \n" + "ld2 {v13.8h, v14.8h}, [%6] \n" + "fmla v8.8h, v23.8h, %10.h[3] \n" + "fmla v8.8h, v24.8h, %11.h[0] \n" + "fmla v8.8h, v30.8h, %11.h[1] \n" + "fmla v8.8h, v31.8h, %11.h[2] \n" + "fmla v8.8h, v21.8h, %11.h[3] \n" + + "ext v15.16b, v9.16b, v11.16b, #2 \n" + "ext v16.16b, v10.16b, v12.16b, #2 \n" + "ext v17.16b, v9.16b, v11.16b, #4 \n" + "ext v18.16b, v11.16b, v13.16b, #2 \n" + "ext v19.16b, v12.16b, v14.16b, #2 \n" + "ext v20.16b, v11.16b, v13.16b, #4 \n" + "fmla v7.8h, v9.8h, %12.h[0] \n" + "fmla v7.8h, v10.8h, %12.h[1] \n" + "fmla v7.8h, v15.8h, %12.h[2] \n" + "fmla v7.8h, v16.8h, %12.h[3] \n" + "fmla v7.8h, v17.8h, %13.8h \n" + "fmla v8.8h, v11.8h, %12.h[0] \n" + "fmla v8.8h, v12.8h, %12.h[1] \n" + "fmla v8.8h, v18.8h, %12.h[2] \n" + "fmla v8.8h, v19.8h, %12.h[3] \n" + "fmla v8.8h, v20.8h, %13.8h \n" + + "add x1, x1, #2 \n" + "cmp %0, x1 \n" + "st1 {v7.8h, v8.8h}, [%1], #32 \n" + "bne 2b \n" + "3: \n" + + : "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), + "+r"(r3), "+r"(r4) + : "w"(_k0123), "w"(_k4567), "w"(_k891011), "w"(_k12131415), + "w"(_k16171819), "w"(_k20212223), "w"(_k24242424), + "r"(mod2_left) + : "cc", "memory", "x1", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", + "v28", "v29", "v30", "v31"); + + r0 += tail_step; + r1 += tail_step; + r2 += tail_step; + r3 += tail_step; + r4 += tail_step; + } + + filter += 25; + } +} + +static void do_conv_7x7_stride2(const __fp16* src, const 
__fp16* filter, + __fp16* dst, size_t IH, size_t IW, size_t OH, + size_t OW, size_t IC) { + const size_t tail_step = IW - 2 * OW + IW; + size_t width = OW >> 3; + + rep(ic, IC) { + const __fp16* src_ptr = src + IW * IH * ic; + __fp16* outptr = dst; + + const __fp16* r0 = src_ptr; + const __fp16* r1 = src_ptr + IW; + const __fp16* r2 = src_ptr + IW * 2; + const __fp16* r3 = src_ptr + IW * 3; + const __fp16* r4 = src_ptr + IW * 4; + const __fp16* r5 = src_ptr + IW * 5; + const __fp16* r6 = src_ptr + IW * 6; + + register MEGDNN_SIMD_TYPE _k0123 asm("v0") = MEGDNN_SIMD_LOADU(filter); + register MEGDNN_SIMD_TYPE _k4567 asm("v1") = + MEGDNN_SIMD_LOADU(filter + 4); + register MEGDNN_SIMD_TYPE _k891011 asm("v2") = + MEGDNN_SIMD_LOADU(filter + 8); + register MEGDNN_SIMD_TYPE _k12131415 asm("v3") = + MEGDNN_SIMD_LOADU(filter + 12); + register MEGDNN_SIMD_TYPE _k16171819 asm("v4") = + MEGDNN_SIMD_LOADU(filter + 16); + register MEGDNN_SIMD_TYPE _k20212223 asm("v5") = + MEGDNN_SIMD_LOADU(filter + 20); + register MEGDNN_SIMD_TYPE _k24252627 asm("v6") = + MEGDNN_SIMD_LOADU(filter + 24); + register MEGDNN_SIMD_TYPE _k28293031 asm("v7") = + MEGDNN_SIMD_LOADU(filter + 28); + register MEGDNN_SIMD_TYPE _k32333435 asm("v8") = + MEGDNN_SIMD_LOADU(filter + 32); + register MEGDNN_SIMD_TYPE _k36373839 asm("v9") = + MEGDNN_SIMD_LOADU(filter + 36); + register MEGDNN_SIMD_TYPE _k40414243 asm("v10") = + MEGDNN_SIMD_LOADU(filter + 40); + register MEGDNN_SIMD_TYPE _k44454647 asm("v11") = + MEGDNN_SIMD_LOADU(filter + 44); + register MEGDNN_SIMD_TYPE _k48484848 asm("v12") = + MEGDNN_SIMD_SET1(filter[48]); + + for (size_t i = 0; i < OH; i++) { + asm volatile( + "and x1, %8, #1 \n" + "cmp x1, #0 \n" + "mov x1, xzr \n" + "beq 1f \n" + + // mod2_left == 1 + "0: \n" + "ld1 {v13.8h}, [%0] \n" + + // v14.8h: 0 2 4 6 v15.8h: 1 3 5 7 + "ld2 {v14.8h, v15.8h}, [%1], #32 \n" + "ld2 {v23.8h, v24.8h}, [%2], #32 \n" + "ld2 {v16.8h, v17.8h}, [%1] \n" + "ld2 {v25.8h, v26.8h}, [%2] \n" + // v18.8h: 2 4 6 8 + "ext v18.16b, v14.16b, v16.16b, #2 \n" + // v19.8h: 3 5 7 9 + "ext v19.16b, v15.16b, v17.16b, #2 \n" + // v20.8h: 4 6 8 10 + "ext v20.16b, v14.16b, v16.16b, #4 \n" + // v21.8h: 5 7 9 11 + "ext v21.16b, v15.16b, v17.16b, #4 \n" + // v22.8h: 6 8 10 12 + "ext v22.16b, v14.16b, v16.16b, #6 \n" + "fmla v13.8h, v14.8h, %9.h[0] \n" + "fmla v13.8h, v15.8h, %9.h[1] \n" + "fmla v13.8h, v18.8h, %9.h[2] \n" + "fmla v13.8h, v19.8h, %9.h[3] \n" + "fmla v13.8h, v20.8h, %10.h[0] \n" + "fmla v13.8h, v21.8h, %10.h[1] \n" + "fmla v13.8h, v22.8h, %10.h[2] \n" + + "ld2 {v14.8h, v15.8h}, [%3], #32 \n" + "ext v27.16b, v23.16b, v25.16b, #2 \n" + "ext v28.16b, v24.16b, v26.16b, #2 \n" + "ext v29.16b, v23.16b, v25.16b, #4 \n" + "ext v30.16b, v24.16b, v26.16b, #4 \n" + "ext v31.16b, v23.16b, v25.16b, #6 \n" + "ld2 {v16.8h, v17.8h}, [%3] \n" + "fmla v13.8h, v23.8h, %10.h[3] \n" + "fmla v13.8h, v24.8h, %11.h[0] \n" + "fmla v13.8h, v27.8h, %11.h[1] \n" + "fmla v13.8h, v28.8h, %11.h[2] \n" + "fmla v13.8h, v29.8h, %11.h[3] \n" + "fmla v13.8h, v30.8h, %12.h[0] \n" + "fmla v13.8h, v31.8h, %12.h[1] \n" + + "ld2 {v23.8h, v24.8h}, [%4], #32 \n" + "ext v18.16b, v14.16b, v16.16b, #2 \n" + "ext v19.16b, v15.16b, v17.16b, #2 \n" + "ext v20.16b, v14.16b, v16.16b, #4 \n" + "ext v21.16b, v15.16b, v17.16b, #4 \n" + "ext v22.16b, v14.16b, v16.16b, #6 \n" + "ld2 {v25.8h, v26.8h}, [%4] \n" + "fmla v13.8h, v14.8h, %12.h[2] \n" + "fmla v13.8h, v15.8h, %12.h[3] \n" + "fmla v13.8h, v18.8h, %13.h[0] \n" + "fmla v13.8h, v19.8h, %13.h[1] \n" + "fmla v13.8h, v20.8h, %13.h[2] \n" + "fmla 
v13.8h, v21.8h, %13.h[3] \n" + "fmla v13.8h, v22.8h, %14.h[0] \n" + + "ld2 {v14.8h, v15.8h}, [%5], #32 \n" + "ext v27.16b, v23.16b, v25.16b, #2 \n" + "ext v28.16b, v24.16b, v26.16b, #2 \n" + "ext v29.16b, v23.16b, v25.16b, #4 \n" + "ext v30.16b, v24.16b, v26.16b, #4 \n" + "ext v31.16b, v23.16b, v25.16b, #6 \n" + "ld2 {v16.8h, v17.8h}, [%5] \n" + "fmla v13.8h, v23.8h, %14.h[1] \n" + "fmla v13.8h, v24.8h, %14.h[2] \n" + "fmla v13.8h, v27.8h, %14.h[3] \n" + "fmla v13.8h, v28.8h, %15.h[0] \n" + "fmla v13.8h, v29.8h, %15.h[1] \n" + "fmla v13.8h, v30.8h, %15.h[2] \n" + "fmla v13.8h, v31.8h, %15.h[3] \n" + + "ld2 {v23.8h, v24.8h}, [%6], #32 \n" + "ext v18.16b, v14.16b, v16.16b, #2 \n" + "ext v19.16b, v15.16b, v17.16b, #2 \n" + "ext v20.16b, v14.16b, v16.16b, #4 \n" + "ext v21.16b, v15.16b, v17.16b, #4 \n" + "ext v22.16b, v14.16b, v16.16b, #6 \n" + "ld2 {v25.8h, v26.8h}, [%6] \n" + "fmla v13.8h, v14.8h, %16.h[0] \n" + "fmla v13.8h, v15.8h, %16.h[1] \n" + "fmla v13.8h, v18.8h, %16.h[2] \n" + "fmla v13.8h, v19.8h, %16.h[3] \n" + "fmla v13.8h, v20.8h, %17.h[0] \n" + "fmla v13.8h, v21.8h, %17.h[1] \n" + "fmla v13.8h, v22.8h, %17.h[2] \n" + + "ld2 {v14.8h, v15.8h}, [%7], #32 \n" + "ext v27.16b, v23.16b, v25.16b, #2 \n" + "ext v28.16b, v24.16b, v26.16b, #2 \n" + "ext v29.16b, v23.16b, v25.16b, #4 \n" + "ext v30.16b, v24.16b, v26.16b, #4 \n" + "ext v31.16b, v23.16b, v25.16b, #6 \n" + "ld2 {v16.8h, v17.8h}, [%7] \n" + "fmla v13.8h, v23.8h, %17.h[3] \n" + "fmla v13.8h, v24.8h, %18.h[0] \n" + "fmla v13.8h, v27.8h, %18.h[1] \n" + "fmla v13.8h, v28.8h, %18.h[2] \n" + "fmla v13.8h, v29.8h, %18.h[3] \n" + "fmla v13.8h, v30.8h, %19.h[0] \n" + "fmla v13.8h, v31.8h, %19.h[1] \n" + + "ext v18.16b, v14.16b, v16.16b, #2 \n" + "ext v19.16b, v15.16b, v17.16b, #2 \n" + "ext v20.16b, v14.16b, v16.16b, #4 \n" + "ext v21.16b, v15.16b, v17.16b, #4 \n" + "ext v22.16b, v14.16b, v16.16b, #6 \n" + "fmla v13.8h, v14.8h, %19.h[2] \n" + "fmla v13.8h, v15.8h, %19.h[3] \n" + "fmla v13.8h, v18.8h, %20.h[0] \n" + "fmla v13.8h, v19.8h, %20.h[1] \n" + "fmla v13.8h, v20.8h, %20.h[2] \n" + "fmla v13.8h, v21.8h, %20.h[3] \n" + "fmla v13.8h, v22.8h, %21.8h \n" + + "add x1, x1, #1 \n" + "st1 {v13.8h}, [%0], #16 \n" + + "1: \n" + "cmp %8, x1 \n" + "beq 3f \n" + + // mod2_left == 0 + "2: \n" + "ld1 {v13.8h, v14.8h}, [%0] \n" + + // v15.8h: 0 2 4 6 v16.8h: 1 3 5 7 + "ld2 {v15.8h, v16.8h}, [%1], #32 \n" + // v17.8h: 8 10 12 14 v16.8h: 9 11 13 15 + "ld2 {v17.8h, v18.8h}, [%1], #32 \n" + // v19.8h: 16 18 20 22 v20.8h: 17 19 21 23 + "ld2 {v19.8h, v20.8h}, [%1] \n" + // v21.8h: 2 4 6 8 + "ext v21.16b, v15.16b, v17.16b, #2 \n" + // v22.8h: 3 5 7 9 + "ext v22.16b, v16.16b, v18.16b, #2 \n" + // v23.8h: 4 6 8 10 + "ext v23.16b, v15.16b, v17.16b, #4 \n" + // v24.8h: 5 7 9 11 + "ext v24.16b, v16.16b, v18.16b, #4 \n" + // v25.8h: 6 8 10 12 + "ext v25.16b, v15.16b, v17.16b, #6 \n" + "fmla v13.8h, v15.8h, %9.h[0] \n" + "fmla v13.8h, v16.8h, %9.h[1] \n" + "fmla v13.8h, v21.8h, %9.h[2] \n" + "fmla v13.8h, v22.8h, %9.h[3] \n" + "fmla v13.8h, v23.8h, %10.h[0] \n" + "fmla v13.8h, v24.8h, %10.h[1] \n" + "fmla v13.8h, v25.8h, %10.h[2] \n" + // v15.8h: 10 12 14 16 + "ext v15.16b, v17.16b, v19.16b, #2 \n" + // v16.8h: 11 13 15 17 + "ext v16.16b, v18.16b, v20.16b, #2 \n" + // v21.8h: 12 14 16 18 + "ext v21.16b, v17.16b, v19.16b, #4 \n" + // v22.8h: 13 15 17 19 + "ext v22.16b, v18.16b, v20.16b, #4 \n" + // v23.8h: 14 16 18 19 + "ext v23.16b, v17.16b, v19.16b, #6 \n" + "fmla v14.8h, v17.8h, %9.h[0] \n" + "fmla v14.8h, v18.8h, %9.h[1] \n" + "fmla v14.8h, v15.8h, %9.h[2] 
\n" + "fmla v14.8h, v16.8h, %9.h[3] \n" + "fmla v14.8h, v21.8h, %10.h[0] \n" + "fmla v14.8h, v22.8h, %10.h[1] \n" + "fmla v14.8h, v23.8h, %10.h[2] \n" + + "ld2 {v26.8h, v27.8h}, [%2], #32 \n" + "ld2 {v28.8h, v29.8h}, [%2], #32 \n" + "ld2 {v30.8h, v31.8h}, [%2] \n" + "ext v21.16b, v26.16b, v28.16b, #2 \n" + "ext v22.16b, v27.16b, v29.16b, #2 \n" + "ext v23.16b, v26.16b, v28.16b, #4 \n" + "ext v24.16b, v27.16b, v29.16b, #4 \n" + "ext v25.16b, v26.16b, v28.16b, #6 \n" + "fmla v13.8h, v26.8h, %10.h[3] \n" + "fmla v13.8h, v27.8h, %11.h[0] \n" + "fmla v13.8h, v21.8h, %11.h[1] \n" + "fmla v13.8h, v22.8h, %11.h[2] \n" + "fmla v13.8h, v23.8h, %11.h[3] \n" + "fmla v13.8h, v24.8h, %12.h[0] \n" + "fmla v13.8h, v25.8h, %12.h[1] \n" + "ext v26.16b, v28.16b, v30.16b, #2 \n" + "ext v27.16b, v29.16b, v31.16b, #2 \n" + "ext v21.16b, v28.16b, v30.16b, #4 \n" + "ext v22.16b, v29.16b, v31.16b, #4 \n" + "ext v23.16b, v28.16b, v30.16b, #6 \n" + "fmla v14.8h, v28.8h, %10.h[3] \n" + "fmla v14.8h, v29.8h, %11.h[0] \n" + "fmla v14.8h, v26.8h, %11.h[1] \n" + "fmla v14.8h, v27.8h, %11.h[2] \n" + "fmla v14.8h, v21.8h, %11.h[3] \n" + "fmla v14.8h, v22.8h, %12.h[0] \n" + "fmla v14.8h, v23.8h, %12.h[1] \n" + + "ld2 {v15.8h, v16.8h}, [%3], #32 \n" + "ld2 {v17.8h, v18.8h}, [%3], #32 \n" + "ld2 {v19.8h, v20.8h}, [%3] \n" + "ext v21.16b, v15.16b, v17.16b, #2 \n" + "ext v22.16b, v16.16b, v18.16b, #2 \n" + "ext v23.16b, v15.16b, v17.16b, #4 \n" + "ext v24.16b, v16.16b, v18.16b, #4 \n" + "ext v25.16b, v15.16b, v17.16b, #6 \n" + "fmla v13.8h, v15.8h, %12.h[2] \n" + "fmla v13.8h, v16.8h, %12.h[3] \n" + "fmla v13.8h, v21.8h, %13.h[0] \n" + "fmla v13.8h, v22.8h, %13.h[1] \n" + "fmla v13.8h, v23.8h, %13.h[2] \n" + "fmla v13.8h, v24.8h, %13.h[3] \n" + "fmla v13.8h, v25.8h, %14.h[0] \n" + "ext v15.16b, v17.16b, v19.16b, #2 \n" + "ext v16.16b, v18.16b, v20.16b, #2 \n" + "ext v21.16b, v17.16b, v19.16b, #4 \n" + "ext v22.16b, v18.16b, v20.16b, #4 \n" + "ext v23.16b, v17.16b, v19.16b, #6 \n" + "fmla v14.8h, v17.8h, %12.h[2] \n" + "fmla v14.8h, v18.8h, %12.h[3] \n" + "fmla v14.8h, v15.8h, %13.h[0] \n" + "fmla v14.8h, v16.8h, %13.h[1] \n" + "fmla v14.8h, v21.8h, %13.h[2] \n" + "fmla v14.8h, v22.8h, %13.h[3] \n" + "fmla v14.8h, v23.8h, %14.h[0] \n" + + "ld2 {v26.8h, v27.8h}, [%4], #32 \n" + "ld2 {v28.8h, v29.8h}, [%4], #32 \n" + "ld2 {v30.8h, v31.8h}, [%4] \n" + "ext v21.16b, v26.16b, v28.16b, #2 \n" + "ext v22.16b, v27.16b, v29.16b, #2 \n" + "ext v23.16b, v26.16b, v28.16b, #4 \n" + "ext v24.16b, v27.16b, v29.16b, #4 \n" + "ext v25.16b, v26.16b, v28.16b, #6 \n" + "fmla v13.8h, v26.8h, %14.h[1] \n" + "fmla v13.8h, v27.8h, %14.h[2] \n" + "fmla v13.8h, v21.8h, %14.h[3] \n" + "fmla v13.8h, v22.8h, %15.h[0] \n" + "fmla v13.8h, v23.8h, %15.h[1] \n" + "fmla v13.8h, v24.8h, %15.h[2] \n" + "fmla v13.8h, v25.8h, %15.h[3] \n" + "ext v26.16b, v28.16b, v30.16b, #2 \n" + "ext v27.16b, v29.16b, v31.16b, #2 \n" + "ext v21.16b, v28.16b, v30.16b, #4 \n" + "ext v22.16b, v29.16b, v31.16b, #4 \n" + "ext v23.16b, v28.16b, v30.16b, #6 \n" + "fmla v14.8h, v28.8h, %14.h[1] \n" + "fmla v14.8h, v29.8h, %14.h[2] \n" + "fmla v14.8h, v26.8h, %14.h[3] \n" + "fmla v14.8h, v27.8h, %15.h[0] \n" + "fmla v14.8h, v21.8h, %15.h[1] \n" + "fmla v14.8h, v22.8h, %15.h[2] \n" + "fmla v14.8h, v23.8h, %15.h[3] \n" + + "ld2 {v15.8h, v16.8h}, [%5], #32 \n" + "ld2 {v17.8h, v18.8h}, [%5], #32 \n" + "ld2 {v19.8h, v20.8h}, [%5] \n" + "ext v21.16b, v15.16b, v17.16b, #2 \n" + "ext v22.16b, v16.16b, v18.16b, #2 \n" + "ext v23.16b, v15.16b, v17.16b, #4 \n" + "ext v24.16b, v16.16b, v18.16b, #4 \n" 
+ "ext v25.16b, v15.16b, v17.16b, #6 \n" + "fmla v13.8h, v15.8h, %16.h[0] \n" + "fmla v13.8h, v16.8h, %16.h[1] \n" + "fmla v13.8h, v21.8h, %16.h[2] \n" + "fmla v13.8h, v22.8h, %16.h[3] \n" + "fmla v13.8h, v23.8h, %17.h[0] \n" + "fmla v13.8h, v24.8h, %17.h[1] \n" + "fmla v13.8h, v25.8h, %17.h[2] \n" + "ext v15.16b, v17.16b, v19.16b, #2 \n" + "ext v16.16b, v18.16b, v20.16b, #2 \n" + "ext v21.16b, v17.16b, v19.16b, #4 \n" + "ext v22.16b, v18.16b, v20.16b, #4 \n" + "ext v23.16b, v17.16b, v19.16b, #6 \n" + "fmla v14.8h, v17.8h, %16.h[0] \n" + "fmla v14.8h, v18.8h, %16.h[1] \n" + "fmla v14.8h, v15.8h, %16.h[2] \n" + "fmla v14.8h, v16.8h, %16.h[3] \n" + "fmla v14.8h, v21.8h, %17.h[0] \n" + "fmla v14.8h, v22.8h, %17.h[1] \n" + "fmla v14.8h, v23.8h, %17.h[2] \n" + + "ld2 {v26.8h, v27.8h}, [%6], #32 \n" + "ld2 {v28.8h, v29.8h}, [%6], #32 \n" + "ld2 {v30.8h, v31.8h}, [%6] \n" + "ext v21.16b, v26.16b, v28.16b, #2 \n" + "ext v22.16b, v27.16b, v29.16b, #2 \n" + "ext v23.16b, v26.16b, v28.16b, #4 \n" + "ext v24.16b, v27.16b, v29.16b, #4 \n" + "ext v25.16b, v26.16b, v28.16b, #6 \n" + "fmla v13.8h, v26.8h, %17.h[3] \n" + "fmla v13.8h, v27.8h, %18.h[0] \n" + "fmla v13.8h, v21.8h, %18.h[1] \n" + "fmla v13.8h, v22.8h, %18.h[2] \n" + "fmla v13.8h, v23.8h, %18.h[3] \n" + "fmla v13.8h, v24.8h, %19.h[0] \n" + "fmla v13.8h, v25.8h, %19.h[1] \n" + "ext v26.16b, v28.16b, v30.16b, #2 \n" + "ext v27.16b, v29.16b, v31.16b, #2 \n" + "ext v21.16b, v28.16b, v30.16b, #4 \n" + "ext v22.16b, v29.16b, v31.16b, #4 \n" + "ext v23.16b, v28.16b, v30.16b, #6 \n" + "fmla v14.8h, v28.8h, %17.h[3] \n" + "fmla v14.8h, v29.8h, %18.h[0] \n" + "fmla v14.8h, v26.8h, %18.h[1] \n" + "fmla v14.8h, v27.8h, %18.h[2] \n" + "fmla v14.8h, v21.8h, %18.h[3] \n" + "fmla v14.8h, v22.8h, %19.h[0] \n" + "fmla v14.8h, v23.8h, %19.h[1] \n" + + "ld2 {v15.8h, v16.8h}, [%7], #32 \n" + "ld2 {v17.8h, v18.8h}, [%7], #32 \n" + "ld2 {v19.8h, v20.8h}, [%7] \n" + "ext v21.16b, v15.16b, v17.16b, #2 \n" + "ext v22.16b, v16.16b, v18.16b, #2 \n" + "ext v23.16b, v15.16b, v17.16b, #4 \n" + "ext v24.16b, v16.16b, v18.16b, #4 \n" + "ext v25.16b, v15.16b, v17.16b, #6 \n" + "fmla v13.8h, v15.8h, %19.h[2] \n" + "fmla v13.8h, v16.8h, %19.h[3] \n" + "fmla v13.8h, v21.8h, %20.h[0] \n" + "fmla v13.8h, v22.8h, %20.h[1] \n" + "fmla v13.8h, v23.8h, %20.h[2] \n" + "fmla v13.8h, v24.8h, %20.h[3] \n" + "fmla v13.8h, v25.8h, %21.8h \n" + "ext v15.16b, v17.16b, v19.16b, #2 \n" + "ext v16.16b, v18.16b, v20.16b, #2 \n" + "ext v21.16b, v17.16b, v19.16b, #4 \n" + "ext v22.16b, v18.16b, v20.16b, #4 \n" + "ext v23.16b, v17.16b, v19.16b, #6 \n" + "fmla v14.8h, v17.8h, %19.h[2] \n" + "fmla v14.8h, v18.8h, %19.h[3] \n" + "fmla v14.8h, v15.8h, %20.h[0] \n" + "fmla v14.8h, v16.8h, %20.h[1] \n" + "fmla v14.8h, v21.8h, %20.h[2] \n" + "fmla v14.8h, v22.8h, %20.h[3] \n" + "fmla v14.8h, v23.8h, %21.8h \n" + + "add x1, x1, #2 \n" + "st1 {v13.8h, v14.8h}, [%0], #32 \n" + "cmp %8, x1 \n" + "bne 2b \n" + "3: \n" + + : "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), "+r"(r3), + "+r"(r4), "+r"(r5), "+r"(r6) + : "r"(width), "w"(_k0123), "w"(_k4567), "w"(_k891011), + "w"(_k12131415), "w"(_k16171819), "w"(_k20212223), + "w"(_k24252627), "w"(_k28293031), "w"(_k32333435), + "w"(_k36373839), "w"(_k40414243), "w"(_k44454647), + "w"(_k48484848) + : "cc", "memory", "x1", "v13", "v14", "v15", "v16", "v17", + "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); + + r0 += tail_step; + r1 += tail_step; + r2 += tail_step; + r3 += tail_step; + r4 += tail_step; + r5 += tail_step; 
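+            // tail_step = 2 * IW - 2 * OW: each row pointer skips the unread tail of the
+            // current input row plus one whole row, because the vertical stride is 2.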
+ r6 += tail_step; + } + filter += 49; + } +} + +} // namespace conv_stride2 +} // namespace fp16 +} // namespace aarch64 +} // namespace megdnn +#endif + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/conv_bias/fp32/algos.cpp b/dnn/src/aarch64/conv_bias/fp32/algos.cpp new file mode 100644 index 00000000..300e0c2b --- /dev/null +++ b/dnn/src/aarch64/conv_bias/fp32/algos.cpp @@ -0,0 +1,137 @@ +/** + * \file dnn/src/aarch64/conv_bias/fp32/algos.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "src/aarch64/conv_bias/fp32/algos.h" +#include "src/aarch64/conv_bias/fp32/stride2_kern.h" +#include "src/arm_common/conv_bias/direct/multi_thread_common.h" +#include "src/arm_common/conv_bias/postprocess_helper.h" +#include "src/fallback/conv_bias/common.h" + +#include "midout.h" + +using namespace megdnn; +using namespace aarch64; + +MIDOUT_DECL(megdnn_aarch64_conv_bias_stride2_conv2357_fp32) +bool ConvBiasImpl::AlgoF32DirectStride2::usable( + FallbackConvBiasImpl*, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const { + MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 0) { + auto&& fm = param.filter_meta; + auto FH = fm.spatial[0]; + bool aviliable = + param.filter_meta.format == param::ConvBias::Format::NCHW && + param.src_type.enumv() == DTypeEnum::Float32 && + param.filter_type.enumv() == DTypeEnum::Float32 && + param.dst_type.enumv() == DTypeEnum::Float32 && + !fm.should_flip && fm.spatial_ndim == 2 && + fm.dilation[0] == 1 && fm.dilation[1] == 1 && + fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] && + (FH == 2 || FH == 3 || FH == 5 || FH == 7); + if (algo_selection_strategy == AlgoSelectionStrategy::HEURISTIC) { + bool large_group = param.filter_meta.group >= param.nr_threads; + aviliable &= (large_group == m_large_group); + } + return aviliable; + } + MIDOUT_END(); + return false; +} + +size_t ConvBiasImpl::AlgoF32DirectStride2::get_workspace( + FallbackConvBiasImpl*, const NCBKernSizeParam& param) const { + MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 1) { + auto wbundle = arm_common::MultithreadDirectConvCommon< + float, float>::get_bundle_stride(param, m_large_group); + return wbundle.total_size_in_bytes(); + } + MIDOUT_END(); + return 0; +} +SmallVector +ConvBiasImpl::AlgoF32DirectStride2::dispatch_kerns( + FallbackConvBiasImpl*, const NCBKernSizeParam& param) const { + MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 2) { + return get_kimpls(param); + } + MIDOUT_END(); + return {}; +} + +SmallVector +ConvBiasImpl::AlgoF32DirectStride2::get_kimpls( + const NCBKernSizeParam& param) const { + auto fm = param.filter_meta; + auto FH = fm.spatial[0]; + size_t N = param.n; + size_t IC = param.filter_meta.icpg; + size_t OC = param.filter_meta.ocpg; + size_t group = fm.group; + using Func = std::function; + Func conv = nullptr; + if (FH == 2) { + conv = fp32::conv_stride2::do_conv_2x2_stride2; + } else if (FH == 3) { + conv = fp32::conv_stride2::do_conv_3x3_stride2; + } else if (FH == 5) { + conv = fp32::conv_stride2::do_conv_5x5_stride2; + } else if (FH == 7) { + conv = fp32::conv_stride2::do_conv_7x7_stride2; + } + + WorkspaceBundle wbundle = 
arm_common::MultithreadDirectConvCommon< + float, float>::get_bundle_stride(param, m_large_group); + SmallVector ret_kerns; + + //! Dense conv and small group + if (m_large_group) { + //! Channel wise conv and big groups + auto exec_one_group = [wbundle, conv](const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + auto fm = kern_param.filter_meta; + size_t IC = fm.icpg; + size_t OC = fm.ocpg; + WorkspaceBundle bundle = wbundle; + for (size_t ic = 0; ic < IC; ic++) { + arm_common::MultithreadDirectConvCommon:: + copy_padding_kern_stride(bundle, kern_param, ncb_index, + {ncb_index.thread_id, 0, ic}); + } + for (size_t oc = 0; oc < OC; oc++) { + arm_common::MultithreadDirectConvCommon< + float, float>::do_conv_kern_stride(bundle, kern_param, + ncb_index, conv, + {ncb_index.thread_id, + 0, oc}); + } + }; + ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); + } else { + WorkspaceBundle bundle = wbundle; + auto copy_padding = [bundle](const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + arm_common::MultithreadDirectConvCommon:: + copy_padding_kern_stride(bundle, kern_param, ncb_index, + ncb_index.ndrange_id); + }; + ret_kerns.push_back({copy_padding, {group, N, IC}}); + auto do_conv = [bundle, conv](const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + arm_common::MultithreadDirectConvCommon< + float, float>::do_conv_kern_stride(bundle, kern_param, + ncb_index, conv, + ncb_index.ndrange_id); + }; + ret_kerns.push_back({do_conv, {group, N, OC}}); + } + return ret_kerns; +} diff --git a/dnn/src/aarch64/conv_bias/fp32/algos.h b/dnn/src/aarch64/conv_bias/fp32/algos.h new file mode 100644 index 00000000..24c79858 --- /dev/null +++ b/dnn/src/aarch64/conv_bias/fp32/algos.h @@ -0,0 +1,47 @@ +/** + * \file dnn/src/aarch64/conv_bias/fp32/algos.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include "src/aarch64/conv_bias/opr_impl.h" +#include "src/fallback/conv_bias/opr_impl.h" + +namespace megdnn { +namespace aarch64 { + +using FallbackConvBiasImpl = fallback::ConvBiasImpl; +/* ===================== stride-2 algo ===================== */ + +class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase { + SmallVector get_kimpls(const NCBKernSizeParam& param) const; + bool m_large_group; +public: + AlgoF32DirectStride2(bool large_group) : m_large_group(large_group) {} + bool is_reproducible() const override { return true; } + const char* name() const override { + return m_large_group ? 
"ARMV8F32STRD2_LARGE_GROUP" + : "ARMV8F32STRD2_SMALL_GROUP"; + } + + bool usable(FallbackConvBiasImpl*, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; + + size_t get_workspace(FallbackConvBiasImpl*, + const NCBKernSizeParam& param) const override; + + SmallVector dispatch_kerns(FallbackConvBiasImpl*, + const NCBKernSizeParam&) const override; +}; + +} // namespace aarch64 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/conv_bias/fp32/stride2_kern.h b/dnn/src/aarch64/conv_bias/fp32/stride2_kern.h new file mode 100644 index 00000000..a02ddfd8 --- /dev/null +++ b/dnn/src/aarch64/conv_bias/fp32/stride2_kern.h @@ -0,0 +1,1024 @@ +/** + * \file dnn/src/aarch64/conv_bias/fp32/stride2_kern.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once + +#include +#include "src/arm_common/simd_macro/neon_helper.h" +#include "src/common/utils.h" + +namespace megdnn { +namespace aarch64 { +namespace fp32{ +namespace conv_stride2 { + + +//! For the detail tune process, refer to `expr/conv_aarch64_stride2/main.cpp` + +// refer to function do_conv_2x2_stride2_asm_unroll4 +static void do_conv_2x2_stride2(const float* src, const float* filter, + float* dst, size_t IH, size_t IW, size_t OH, + size_t OW, size_t IC) { + const size_t tail_step = IW - 2 * OW + IW; + size_t width = OW >> 2; + size_t mod4_left = width & 3; + + rep(ic, IC) { + const float* src_ptr = src + IW * IH * ic; + float* outptr = dst; + + const float* r0 = src_ptr; + const float* r1 = src_ptr + IW; + + const float* k0 = filter; + + MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(k0); + rep(h, OH) { + asm volatile( + "dup v28.4s, %5.s[0] \n" + "dup v29.4s, %5.s[1] \n" + "dup v30.4s, %5.s[2] \n" + "dup v31.4s, %5.s[3] \n" + "cmp %4, #2 \n" + "mov x1, xzr \n" + // mod4_left == 3 + "bgt 0f \n" + // mod4_left == 2 + "beq 1f \n" + "cmp %4, #1 \n" + // mod4_left == 1 + "beq 2f \n" + // mod4_left == 0 + "b 3f \n" + + // mod4_left == 3 + "0: \n" + "ld1 {v0.4s, v1.4s, v2.4s}, [%1] \n" + + "ld2 {v3.4s, v4.4s}, [%2], #32 \n" + "ld2 {v9.4s, v10.4s}, [%3], #32 \n" + "ld2 {v5.4s, v6.4s}, [%2], #32 \n" + "ld2 {v11.4s, v12.4s}, [%3], #32 \n" + "ld2 {v7.4s, v8.4s}, [%2], #32 \n" + "ld2 {v13.4s, v14.4s}, [%3], #32 \n" + "fmla v0.4s, v3.4s, v28.4s \n" + "fmla v1.4s, v5.4s, v28.4s \n" + "fmla v2.4s, v7.4s, v28.4s \n" + "fmla v0.4s, v4.4s, v29.4s \n" + "fmla v1.4s, v6.4s, v29.4s \n" + "fmla v2.4s, v8.4s, v29.4s \n" + + "fmla v0.4s, v9.4s, v30.4s \n" + "fmla v1.4s, v11.4s, v30.4s \n" + "fmla v2.4s, v13.4s, v30.4s \n" + "fmla v0.4s, v10.4s, v31.4s \n" + "fmla v1.4s, v12.4s, v31.4s \n" + "fmla v2.4s, v14.4s, v31.4s \n" + + "add x1, x1, #3 \n" + "st1 {v0.4s, v1.4s, v2.4s}, [%1], #48 \n" + "b 3f \n" + + // mod4_left == 2 + "1: \n" + "ld1 {v0.4s, v1.4s}, [%1] \n" + + "ld2 {v2.4s, v3.4s}, [%2], #32 \n" + "ld2 {v6.4s, v7.4s}, [%3], #32 \n" + "ld2 {v4.4s, v5.4s}, [%2], #32 \n" + "ld2 {v8.4s, v9.4s}, [%3], #32 \n" + "fmla v0.4s, v2.4s, v28.4s \n" + "fmla v1.4s, v4.4s, v28.4s \n" + "fmla v0.4s, v3.4s, v29.4s \n" + "fmla v1.4s, v5.4s, v29.4s \n" + + "fmla v0.4s, v6.4s, v30.4s \n" + "fmla v1.4s, v8.4s, v30.4s \n" + "fmla v0.4s, v7.4s, v31.4s \n" + "fmla v1.4s, v9.4s, v31.4s 
\n" + + "add x1, x1, #2 \n" + "st1 {v0.4s, v1.4s}, [%1], #32 \n" + "b 3f \n" + + // mod4_left == 1 + "2: \n" + "ld1 {v0.4s}, [%1] \n" + + "ld2 {v1.4s, v2.4s}, [%2], #32 \n" + "ld2 {v3.4s, v4.4s}, [%3], #32 \n" + "fmla v0.4s, v1.4s, v28.4s \n" + "fmla v0.4s, v2.4s, v29.4s \n" + + "fmla v0.4s, v3.4s, v30.4s \n" + "fmla v0.4s, v4.4s, v31.4s \n" + + "add x1, x1, #1 \n" + "st1 {v0.4s}, [%1], #16 \n" + "b 3f \n" + + // mod4_left == 0 + "3: \n" + "cmp %0, x1 \n" + "beq 5f \n" + "4: \n" + "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%1] \n" + + "ld2 {v4.4s, v5.4s}, [%2], #32 \n" + "ld2 {v12.4s, v13.4s}, [%3], #32 \n" + "ld2 {v6.4s, v7.4s}, [%2], #32 \n" + "ld2 {v14.4s, v15.4s}, [%3], #32 \n" + "ld2 {v8.4s, v9.4s}, [%2], #32 \n" + "ld2 {v16.4s, v17.4s}, [%3], #32 \n" + "ld2 {v10.4s, v11.4s}, [%2], #32 \n" + "ld2 {v18.4s, v19.4s}, [%3], #32 \n" + "fmla v0.4s, v4.4s, v28.4s \n" + "fmla v1.4s, v6.4s, v28.4s \n" + "fmla v2.4s, v8.4s, v28.4s \n" + "fmla v3.4s, v10.4s, v28.4s \n" + "fmla v0.4s, v5.4s, v29.4s \n" + "fmla v1.4s, v7.4s, v29.4s \n" + "fmla v2.4s, v9.4s, v29.4s \n" + "fmla v3.4s, v11.4s, v29.4s \n" + + "fmla v0.4s, v12.4s, v30.4s \n" + "fmla v1.4s, v14.4s, v30.4s \n" + "fmla v2.4s, v16.4s, v30.4s \n" + "fmla v3.4s, v18.4s, v30.4s \n" + "fmla v0.4s, v13.4s, v31.4s \n" + "fmla v1.4s, v15.4s, v31.4s \n" + "fmla v2.4s, v17.4s, v31.4s \n" + "fmla v3.4s, v19.4s, v31.4s \n" + + "add x1, x1, #4 \n" + "cmp %0, x1 \n" + "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%1], #64 \n" + "bne 4b \n" + + "5: \n" + : "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1) + : "r"(mod4_left), "w"(_k0123) + : "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v28", "v29", "v30", + "v31"); + + r0 += tail_step; + r1 += tail_step; + } + + filter += 4; + } +} + +// refer to function do_conv_3x3_stride2_asm_unroll3 +static void do_conv_3x3_stride2(const float* src, const float* filter, + float* dst, size_t IH, size_t IW, size_t OH, + size_t OW, size_t IC) { + const size_t tail_step = IW - 2 * OW + IW; + size_t width = OW >> 2; + size_t mod3_left = width % 3; + + rep(ic, IC) { + const float* src_ptr = src + IW * IH * ic; + float* outptr = dst; + + const float* r0 = src_ptr; + const float* r1 = src_ptr + IW; + const float* r2 = src_ptr + IW * 2; + + const float* k0 = filter; + const float* k1 = filter + 3; + const float* k2 = filter + 5; + + MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(k0); + MEGDNN_SIMD_TYPE _k3456 = MEGDNN_SIMD_LOADU(k1); + MEGDNN_SIMD_TYPE _k5678 = MEGDNN_SIMD_LOADU(k2); + rep(h, OH) { + asm volatile( + "dup v21.4s, %6.s[0] \n" + "dup v22.4s, %6.s[1] \n" + "dup v23.4s, %6.s[2] \n" + "dup v24.4s, %6.s[3] \n" + "dup v25.4s, %7.s[1] \n" + "dup v26.4s, %7.s[2] \n" + "dup v27.4s, %7.s[3] \n" + "dup v28.4s, %8.s[2] \n" + "dup v29.4s, %8.s[3] \n" + "cmp %5, #1 \n" + "mov x1, xzr \n" + "bgt 0f \n" // mod3_left == 2 + "beq 1f \n" // mod3_left == 1 + "blt 2f \n" // mod3_left == 0 + + "0: \n" + "ld1 {v0.4s, v1.4s}, [%1] \n" + + "ld2 {v2.4s, v3.4s}, [%2], #32 \n" + "ld2 {v9.4s, v10.4s}, [%3], #32 \n" + "ld2 {v4.4s, v5.4s}, [%2], #32 \n" + "ld2 {v11.4s, v12.4s}, [%3], #32 \n" + "fmla v0.4s, v2.4s, v21.4s \n" + "fmla v1.4s, v4.4s, v21.4s \n" + "fmla v0.4s, v3.4s, v22.4s \n" + "fmla v1.4s, v5.4s, v22.4s \n" + "ld1 {v6.4s}, [%2] \n" + "ld1 {v13.4s}, [%3] \n" + "ext v7.16b, v2.16b, v4.16b, #4 \n" + "ext v8.16b, v4.16b, v6.16b, #4 \n" + "fmla v0.4s, v7.4s, v23.4s \n" + "fmla v1.4s, v8.4s, v23.4s \n" + + "ld2 {v2.4s, v3.4s}, [%4], #32 \n" + 
"fmla v0.4s, v9.4s, v24.4s \n" + "fmla v1.4s, v11.4s, v24.4s \n" + "fmla v0.4s, v10.4s, v25.4s \n" + "fmla v1.4s, v12.4s, v25.4s \n" + "ld2 {v4.4s, v5.4s}, [%4], #32 \n" + "ext v14.16b, v9.16b, v11.16b, #4 \n" + "ext v15.16b, v11.16b, v13.16b, #4 \n" + "fmla v0.4s, v14.4s, v26.4s \n" + "fmla v1.4s, v15.4s, v26.4s \n" + + "ld1 {v6.4s}, [%4] \n" + "fmla v0.4s, v2.4s, v27.4s \n" + "fmla v1.4s, v4.4s, v27.4s \n" + "fmla v0.4s, v3.4s, v28.4s \n" + "fmla v1.4s, v5.4s, v28.4s \n" + "ext v7.16b, v2.16b, v4.16b, #4 \n" + "ext v8.16b, v4.16b, v6.16b, #4 \n" + "fmla v0.4s, v7.4s, v29.4s \n" + "fmla v1.4s, v8.4s, v29.4s \n" + + "add x1, x1, #2 \n" + "cmp %0, x1 \n" + + "st1 {v0.4s, v1.4s}, [%1], #32 \n" + "bne 2f \n" // if width != 2 jump to 2 + "b 3f \n" // jump end + + "1: \n" + "ld1 {v0.4s}, [%1] \n" // load dst 0, 1, 2, 3 + "ld2 {v1.4s, v2.4s}, [%2], #32 \n" // 0, 2, 4, 6 + + "ld2 {v5.4s, v6.4s}, [%3], #32 \n" + "ld1 {v3.4s}, [%2] \n" // load src 8 12 ... + "fmla v0.4s, v1.4s, v21.4s \n" // src[i] * k[i] + "ext v7.16b, v1.16b, v3.16b, #4 \n" // 2, 4, 6, 8 + "fmla v0.4s, v2.4s, v22.4s \n" + "ld1 {v1.4s}, [%3] \n" // load src 8 12 ... + "fmla v0.4s, v7.4s, v23.4s \n" + "ld2 {v3.4s, v4.4s}, [%4], #32 \n" + + "fmla v0.4s, v5.4s, v24.4s \n" + "ext v7.16b, v5.16b, v1.16b, #4 \n" // 2, 4, 6, 8 + "fmla v0.4s, v6.4s, v25.4s \n" + "ld1 {v5.4s}, [%4] \n" // load src 8 12 ... + "fmla v0.4s, v7.4s, v26.4s \n" + + "fmla v0.4s, v3.4s, v27.4s \n" + "fmla v0.4s, v4.4s, v28.4s \n" + "ext v7.16b, v3.16b, v5.16b, #4 \n" // 2, 4, 6, 8 + "fmla v0.4s, v7.4s, v29.4s \n" + + "st1 {v0.4s}, [%1], #16 \n" + + "add x1, x1, #1 \n" + "cmp %0, x1 \n" + "beq 3f \n" + + "2: \n" + "ld1 {v0.4s, v1.4s, v2.4s}, [%1] \n" + + "ld2 {v3.4s, v4.4s}, [%2], #32 \n" + "ld2 {v11.4s, v12.4s}, [%3], #32 \n" + "ld2 {v5.4s, v6.4s}, [%2], #32 \n" + "ld2 {v13.4s, v14.4s}, [%3], #32 \n" + "ld2 {v7.4s, v8.4s}, [%2], #32 \n" + "ld2 {v15.4s, v16.4s}, [%3], #32 \n" + "fmla v0.4s, v3.4s, v21.4s \n" + "fmla v1.4s, v5.4s, v21.4s \n" + "fmla v2.4s, v7.4s, v21.4s \n" + "ld1 {v9.4s}, [%2] \n" + "ld1 {v17.4s}, [%3] \n" + "fmla v0.4s, v4.4s, v22.4s \n" + "fmla v1.4s, v6.4s, v22.4s \n" + "fmla v2.4s, v8.4s, v22.4s \n" + "ext v10.16b, v3.16b, v5.16b, #4 \n" + "ext v4.16b, v5.16b, v7.16b, #4 \n" + "ext v6.16b, v7.16b, v9.16b, #4 \n" + "fmla v0.4s, v10.4s, v23.4s \n" + "fmla v1.4s, v4.4s, v23.4s \n" + "fmla v2.4s, v6.4s, v23.4s \n" + + "ld2 {v3.4s, v4.4s}, [%4], #32 \n" + "fmla v0.4s, v11.4s, v24.4s \n" + "fmla v1.4s, v13.4s, v24.4s \n" + "fmla v2.4s, v15.4s, v24.4s \n" + "ld2 {v5.4s, v6.4s}, [%4], #32 \n" + "fmla v0.4s, v12.4s, v25.4s \n" + "fmla v1.4s, v14.4s, v25.4s \n" + "fmla v2.4s, v16.4s, v25.4s \n" + "ld2 {v7.4s, v8.4s}, [%4], #32 \n" + "ext v18.16b, v11.16b, v13.16b, #4 \n" + "ext v12.16b, v13.16b, v15.16b, #4 \n" + "ext v14.16b, v15.16b, v17.16b, #4 \n" + "fmla v0.4s, v18.4s, v26.4s \n" + "fmla v1.4s, v12.4s, v26.4s \n" + "fmla v2.4s, v14.4s, v26.4s \n" + + "ld1 {v9.4s}, [%4] \n" + "fmla v0.4s, v3.4s, v27.4s \n" + "fmla v1.4s, v5.4s, v27.4s \n" + "fmla v2.4s, v7.4s, v27.4s \n" + "fmla v0.4s, v4.4s, v28.4s \n" + "fmla v1.4s, v6.4s, v28.4s \n" + "fmla v2.4s, v8.4s, v28.4s \n" + "ext v10.16b, v3.16b, v5.16b, #4 \n" + "ext v4.16b, v5.16b, v7.16b, #4 \n" + "ext v6.16b, v7.16b, v9.16b, #4 \n" + "fmla v0.4s, v10.4s, v29.4s \n" + "fmla v1.4s, v4.4s, v29.4s \n" + "fmla v2.4s, v6.4s, v29.4s \n" + + "add x1, x1, #3 \n" + "cmp %0, x1 \n" + + "st1 {v0.4s, v1.4s, v2.4s}, [%1], #48 \n" + "bne 2b \n" // if + "3: \n" + : "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1), 
"+r"(r2) + : "r"(mod3_left), "w"(_k0123), "w"(_k3456), "w"(_k5678) + : "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v21", "v22", "v23", "v24", + "v25", "v26", "v27", "v28", "v29"); + + r0 += tail_step; + r1 += tail_step; + r2 += tail_step; + } + + filter += 9; + } +} + +// refer to function do_conv_5x5_stride2_asm_unroll2 +static void do_conv_5x5_stride2(const float* src, const float* filter, + float* dst, size_t IH, size_t IW, size_t OH, + size_t OW, size_t IC) { + const size_t tail_step = IW - 2 * OW + IW; + size_t width = OW >> 2; + size_t mod2_left = width & 1; + + rep(ic, IC) { + const float* src_ptr = src + IW * IH * ic; + float* outptr = dst; + + const float* r0 = src_ptr; + const float* r1 = src_ptr + IW; + const float* r2 = src_ptr + IW * 2; + const float* r3 = src_ptr + IW * 3; + const float* r4 = src_ptr + IW * 4; + + MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(filter); + MEGDNN_SIMD_TYPE _k4567 = MEGDNN_SIMD_LOADU(filter + 4); + MEGDNN_SIMD_TYPE _k891011 = MEGDNN_SIMD_LOADU(filter + 8); + MEGDNN_SIMD_TYPE _k12131415 = MEGDNN_SIMD_LOADU(filter + 12); + MEGDNN_SIMD_TYPE _k16171819 = MEGDNN_SIMD_LOADU(filter + 16); + MEGDNN_SIMD_TYPE _k20212223 = MEGDNN_SIMD_LOADU(filter + 20); + MEGDNN_SIMD_TYPE _k24242424 = MEGDNN_SIMD_SET1(filter[24]); + + for (size_t i = 0; i < OH; i++) { + asm volatile( + "cmp %14, #0 \n" + "mov x1, xzr \n" + "beq 1f \n" + + // mod2_left == 1 + "0: \n" + "ld1 {v0.4s}, [%1] \n" + + // v1.4s: 0 2 4 6 v2.4s: 1 3 5 7 + "ld2 {v1.4s, v2.4s}, [%2], #32 \n" + "ld2 {v8.4s, v9.4s}, [%3], #32 \n" + "fmla v0.4s, v1.4s, %7.s[0] \n" + "fmla v0.4s, v2.4s, %7.s[1] \n" + "ld2 {v3.4s, v4.4s}, [%2] \n" + "ld2 {v10.4s, v11.4s}, [%3] \n" + // v5.4s: 2 4 6 8 + "ext v5.16b, v1.16b, v3.16b, #4 \n" + // v6.4s: 3 5 7 9 + "ext v6.16b, v2.16b, v4.16b, #4 \n" + "fmla v0.4s, v5.4s, %7.s[2] \n" + "fmla v0.4s, v6.4s, %7.s[3] \n" + // v7.4s: 4 6 8 10 + "ext v7.16b, v1.16b, v3.16b, #8 \n" + "fmla v0.4s, v7.4s, %8.s[0] \n" + + "ld2 {v1.4s, v2.4s}, [%4], #32 \n" + "fmla v0.4s, v8.4s, %8.s[1] \n" + "fmla v0.4s, v9.4s, %8.s[2] \n" + "ld2 {v3.4s, v4.4s}, [%4] \n" + "ext v12.16b, v8.16b, v10.16b, #4 \n" + "ext v13.16b, v9.16b, v11.16b, #4 \n" + "fmla v0.4s, v12.4s, %8.s[3] \n" + "fmla v0.4s, v13.4s, %9.s[0] \n" + "ext v14.16b, v8.16b, v10.16b, #8 \n" + "fmla v0.4s, v14.4s, %9.s[1] \n" + + "ld2 {v8.4s, v9.4s}, [%5], #32 \n" + "fmla v0.4s, v1.4s, %9.s[2] \n" + "fmla v0.4s, v2.4s, %9.s[3] \n" + "ld2 {v10.4s, v11.4s}, [%5] \n" + "ext v5.16b, v1.16b, v3.16b, #4 \n" + "ext v6.16b, v2.16b, v4.16b, #4 \n" + "fmla v0.4s, v5.4s, %10.s[0] \n" + "fmla v0.4s, v6.4s, %10.s[1] \n" + "ext v7.16b, v1.16b, v3.16b, #8 \n" + "fmla v0.4s, v7.4s, %10.s[2] \n" + + "ld2 {v1.4s, v2.4s}, [%6], #32 \n" + "fmla v0.4s, v8.4s, %10.s[3] \n" + "fmla v0.4s, v9.4s, %11.s[0] \n" + "ld2 {v3.4s, v4.4s}, [%6] \n" + "ext v12.16b, v8.16b, v10.16b, #4 \n" + "ext v13.16b, v9.16b, v11.16b, #4 \n" + "fmla v0.4s, v12.4s, %11.s[1] \n" + "fmla v0.4s, v13.4s, %11.s[2] \n" + "ext v14.16b, v8.16b, v10.16b, #8 \n" + "fmla v0.4s, v14.4s, %11.s[3] \n" + + "fmla v0.4s, v1.4s, %12.s[0] \n" + "fmla v0.4s, v2.4s, %12.s[1] \n" + "ext v5.16b, v1.16b, v3.16b, #4 \n" + "ext v6.16b, v2.16b, v4.16b, #4 \n" + "fmla v0.4s, v5.4s, %12.s[2] \n" + "fmla v0.4s, v6.4s, %12.s[3] \n" + "ext v7.16b, v1.16b, v3.16b, #8 \n" + "fmla v0.4s, v7.4s, %13.s[0] \n" + + "add x1, x1, #1 \n" + "st1 {v0.4s}, [%1], #16 \n" + + "1: \n" + "cmp %0, x1 \n" + "beq 3f \n" + + // 
mod2_left == 0 + "2: \n" + "ld1 {v0.4s, v1.4s}, [%1] \n" + + // v2.4s: 0 2 4 6 v3.4s: 1 3 5 7 + "ld2 {v2.4s, v3.4s}, [%2], #32 \n" + "ld2 {v14.4s, v15.4s}, [%3], #32 \n" + // v4.4s: 8 10 12 14 v5.4s: 9 11 13 15 + "ld2 {v4.4s, v5.4s}, [%2], #32 \n" + "ld2 {v16.4s, v17.4s}, [%3], #32 \n" + // v6.4s: 16 18 20 22 v7.4s: 17 19 21 23 + "ld2 {v6.4s, v7.4s}, [%2] \n" + "ld2 {v18.4s, v19.4s}, [%3] \n" + // v8.4s: 2 4 6 8 + "ext v8.16b, v2.16b, v4.16b, #4 \n" + // v9.4s: 3 5 7 9 + "ext v9.16b, v3.16b, v5.16b, #4 \n" + // v10.4s: 4 6 8 10 + "ext v10.16b, v2.16b, v4.16b, #8 \n" + // v11.4s: 10 12 14 16 + "ext v11.16b, v4.16b, v6.16b, #4 \n" + // v12.4s: 11 13 15 17 + "ext v12.16b, v5.16b, v7.16b, #4 \n" + // v13.4s: 12 14 16 18 + "ext v13.16b, v4.16b, v6.16b, #8 \n" + "fmla v0.4s, v2.4s, %7.s[0] \n" + "fmla v0.4s, v3.4s, %7.s[1] \n" + "fmla v0.4s, v8.4s, %7.s[2] \n" + "fmla v0.4s, v9.4s, %7.s[3] \n" + "fmla v0.4s, v10.4s, %8.s[0] \n" + "fmla v1.4s, v4.4s, %7.s[0] \n" + "fmla v1.4s, v5.4s, %7.s[1] \n" + "fmla v1.4s, v11.4s, %7.s[2] \n" + "fmla v1.4s, v12.4s, %7.s[3] \n" + "fmla v1.4s, v13.4s, %8.s[0] \n" + + "ld2 {v2.4s, v3.4s}, [%4], #32 \n" + "ext v20.16b, v14.16b, v16.16b, #4 \n" + "ext v21.16b, v15.16b, v17.16b, #4 \n" + "ext v22.16b, v14.16b, v16.16b, #8 \n" + "fmla v0.4s, v14.4s, %8.s[1] \n" + "fmla v0.4s, v15.4s, %8.s[2] \n" + "fmla v0.4s, v20.4s, %8.s[3] \n" + "fmla v0.4s, v21.4s, %9.s[0] \n" + "fmla v0.4s, v22.4s, %9.s[1] \n" + "ld2 {v4.4s, v5.4s}, [%4], #32 \n" + "ext v23.16b, v16.16b, v18.16b, #4 \n" + "ext v24.16b, v17.16b, v19.16b, #4 \n" + "ext v14.16b, v16.16b, v18.16b, #8 \n" + "ld2 {v6.4s, v7.4s}, [%4] \n" + "fmla v1.4s, v16.4s, %8.s[1] \n" + "fmla v1.4s, v17.4s, %8.s[2] \n" + "fmla v1.4s, v23.4s, %8.s[3] \n" + "fmla v1.4s, v24.4s, %9.s[0] \n" + "fmla v1.4s, v14.4s, %9.s[1] \n" + + "ld2 {v14.4s, v15.4s}, [%5], #32 \n" + "ext v8.16b, v2.16b, v4.16b, #4 \n" + "ext v9.16b, v3.16b, v5.16b, #4 \n" + "ext v10.16b, v2.16b, v4.16b, #8 \n" + "ext v11.16b, v4.16b, v6.16b, #4 \n" + "ext v12.16b, v5.16b, v7.16b, #4 \n" + "ext v13.16b, v4.16b, v6.16b, #8 \n" + "ld2 {v16.4s, v17.4s}, [%5], #32 \n" + "fmla v0.4s, v2.4s, %9.s[2] \n" + "fmla v0.4s, v3.4s, %9.s[3] \n" + "fmla v0.4s, v8.4s, %10.s[0] \n" + "fmla v0.4s, v9.4s, %10.s[1] \n" + "fmla v0.4s, v10.4s, %10.s[2] \n" + "ld2 {v18.4s, v19.4s}, [%5] \n" + "fmla v1.4s, v4.4s, %9.s[2] \n" + "fmla v1.4s, v5.4s, %9.s[3] \n" + "fmla v1.4s, v11.4s, %10.s[0] \n" + "fmla v1.4s, v12.4s, %10.s[1] \n" + "fmla v1.4s, v13.4s, %10.s[2] \n" + + "ld2 {v2.4s, v3.4s}, [%6], #32 \n" + "ext v20.16b, v14.16b, v16.16b, #4 \n" + "ext v21.16b, v15.16b, v17.16b, #4 \n" + "ext v22.16b, v14.16b, v16.16b, #8 \n" + "fmla v0.4s, v14.4s, %10.s[3] \n" + "fmla v0.4s, v15.4s, %11.s[0] \n" + "fmla v0.4s, v20.4s, %11.s[1] \n" + "fmla v0.4s, v21.4s, %11.s[2] \n" + "fmla v0.4s, v22.4s, %11.s[3] \n" + "ld2 {v4.4s, v5.4s}, [%6], #32 \n" + "ext v23.16b, v16.16b, v18.16b, #4 \n" + "ext v24.16b, v17.16b, v19.16b, #4 \n" + "ext v14.16b, v16.16b, v18.16b, #8 \n" + "ld2 {v6.4s, v7.4s}, [%6] \n" + "fmla v1.4s, v16.4s, %10.s[3] \n" + "fmla v1.4s, v17.4s, %11.s[0] \n" + "fmla v1.4s, v23.4s, %11.s[1] \n" + "fmla v1.4s, v24.4s, %11.s[2] \n" + "fmla v1.4s, v14.4s, %11.s[3] \n" + + "ext v8.16b, v2.16b, v4.16b, #4 \n" + "ext v9.16b, v3.16b, v5.16b, #4 \n" + "ext v10.16b, v2.16b, v4.16b, #8 \n" + "ext v11.16b, v4.16b, v6.16b, #4 \n" + "ext v12.16b, v5.16b, v7.16b, #4 \n" + "ext v13.16b, v4.16b, v6.16b, #8 \n" + "fmla v0.4s, v2.4s, %12.s[0] \n" + "fmla v0.4s, v3.4s, %12.s[1] \n" + "fmla v0.4s, v8.4s, 
%12.s[2] \n" + "fmla v0.4s, v9.4s, %12.s[3] \n" + "fmla v0.4s, v10.4s, %13.4s \n" + "fmla v1.4s, v4.4s, %12.s[0] \n" + "fmla v1.4s, v5.4s, %12.s[1] \n" + "fmla v1.4s, v11.4s, %12.s[2] \n" + "fmla v1.4s, v12.4s, %12.s[3] \n" + "fmla v1.4s, v13.4s, %13.4s \n" + + "add x1, x1, #2 \n" + "cmp %0, x1 \n" + "st1 {v0.4s, v1.4s}, [%1], #32 \n" + "bne 2b \n" + "3: \n" + + : "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), + "+r"(r3), "+r"(r4) + : "w"(_k0123), "w"(_k4567), "w"(_k891011), "w"(_k12131415), + "w"(_k16171819), "w"(_k20212223), "w"(_k24242424), + "r"(mod2_left) + : "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", + "v23", "v24"); + + r0 += tail_step; + r1 += tail_step; + r2 += tail_step; + r3 += tail_step; + r4 += tail_step; + } + + filter += 25; + } +} + +// refer to function do_conv_7x7_stride2_asm_unroll2 +static void do_conv_7x7_stride2(const float* src, const float* filter, + float* dst, size_t IH, size_t IW, size_t OH, + size_t OW, size_t IC) { + const size_t tail_step = IW - 2 * OW + IW; + size_t width = OW >> 2; + + rep(ic, IC) { + const float* src_ptr = src + IW * IH * ic; + float* outptr = dst; + + const float* r0 = src_ptr; + const float* r1 = src_ptr + IW; + const float* r2 = src_ptr + IW * 2; + const float* r3 = src_ptr + IW * 3; + const float* r4 = src_ptr + IW * 4; + const float* r5 = src_ptr + IW * 5; + const float* r6 = src_ptr + IW * 6; + + MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(filter); + MEGDNN_SIMD_TYPE _k4567 = MEGDNN_SIMD_LOADU(filter + 4); + MEGDNN_SIMD_TYPE _k891011 = MEGDNN_SIMD_LOADU(filter + 8); + MEGDNN_SIMD_TYPE _k12131415 = MEGDNN_SIMD_LOADU(filter + 12); + MEGDNN_SIMD_TYPE _k16171819 = MEGDNN_SIMD_LOADU(filter + 16); + MEGDNN_SIMD_TYPE _k20212223 = MEGDNN_SIMD_LOADU(filter + 20); + MEGDNN_SIMD_TYPE _k24252627 = MEGDNN_SIMD_LOADU(filter + 24); + MEGDNN_SIMD_TYPE _k28293031 = MEGDNN_SIMD_LOADU(filter + 28); + MEGDNN_SIMD_TYPE _k32333435 = MEGDNN_SIMD_LOADU(filter + 32); + MEGDNN_SIMD_TYPE _k36373839 = MEGDNN_SIMD_LOADU(filter + 36); + MEGDNN_SIMD_TYPE _k40414243 = MEGDNN_SIMD_LOADU(filter + 40); + MEGDNN_SIMD_TYPE _k44454647 = MEGDNN_SIMD_LOADU(filter + 44); + MEGDNN_SIMD_TYPE _k48484848 = MEGDNN_SIMD_SET1(filter[48]); + + for (size_t i = 0; i < OH; i++) { + asm volatile( + "and x1, %8, #1 \n" + "cmp x1, #0 \n" + "mov x1, xzr \n" + "beq 1f \n" + + // mod2_left == 1 + "0: \n" + "ld1 {v0.4s}, [%0] \n" + + // v1.4s: 0 2 4 6 v2.4s: 1 3 5 7 + "ld2 {v1.4s, v2.4s}, [%1], #32 \n" + "ld2 {v10.4s, v11.4s}, [%2], #32 \n" + "ld2 {v3.4s, v4.4s}, [%1] \n" + "ld2 {v12.4s, v13.4s}, [%2] \n" + // v5.4s: 2 4 6 8 + "ext v5.16b, v1.16b, v3.16b, #4 \n" + // v6.4s: 3 5 7 9 + "ext v6.16b, v2.16b, v4.16b, #4 \n" + // v7.4s: 4 6 8 10 + "ext v7.16b, v1.16b, v3.16b, #8 \n" + // v8.4s: 5 7 9 11 + "ext v8.16b, v2.16b, v4.16b, #8 \n" + // v9.4s: 6 8 10 12 + "ext v9.16b, v1.16b, v3.16b, #12 \n" + "fmla v0.4s, v1.4s, %9.s[0] \n" + "fmla v0.4s, v2.4s, %9.s[1] \n" + "fmla v0.4s, v5.4s, %9.s[2] \n" + "fmla v0.4s, v6.4s, %9.s[3] \n" + "fmla v0.4s, v7.4s, %10.s[0] \n" + "fmla v0.4s, v8.4s, %10.s[1] \n" + "fmla v0.4s, v9.4s, %10.s[2] \n" + + "ld2 {v1.4s, v2.4s}, [%3], #32 \n" + "ext v14.16b, v10.16b, v12.16b, #4 \n" + "ext v15.16b, v11.16b, v13.16b, #4 \n" + "ext v16.16b, v10.16b, v12.16b, #8 \n" + "ext v17.16b, v11.16b, v13.16b, #8 \n" + "ext v18.16b, v10.16b, v12.16b, #12 \n" + "ld2 {v3.4s, v4.4s}, [%3] \n" + "fmla v0.4s, v10.4s, %10.s[3] \n" + "fmla 
v0.4s, v11.4s, %11.s[0] \n" + "fmla v0.4s, v14.4s, %11.s[1] \n" + "fmla v0.4s, v15.4s, %11.s[2] \n" + "fmla v0.4s, v16.4s, %11.s[3] \n" + "fmla v0.4s, v17.4s, %12.s[0] \n" + "fmla v0.4s, v18.4s, %12.s[1] \n" + + "ld2 {v10.4s, v11.4s}, [%4], #32 \n" + "ext v5.16b, v1.16b, v3.16b, #4 \n" + "ext v6.16b, v2.16b, v4.16b, #4 \n" + "ext v7.16b, v1.16b, v3.16b, #8 \n" + "ext v8.16b, v2.16b, v4.16b, #8 \n" + "ext v9.16b, v1.16b, v3.16b, #12 \n" + "ld2 {v12.4s, v13.4s}, [%4] \n" + "fmla v0.4s, v1.4s, %12.s[2] \n" + "fmla v0.4s, v2.4s, %12.s[3] \n" + "fmla v0.4s, v5.4s, %13.s[0] \n" + "fmla v0.4s, v6.4s, %13.s[1] \n" + "fmla v0.4s, v7.4s, %13.s[2] \n" + "fmla v0.4s, v8.4s, %13.s[3] \n" + "fmla v0.4s, v9.4s, %14.s[0] \n" + + "ld2 {v1.4s, v2.4s}, [%5], #32 \n" + "ext v14.16b, v10.16b, v12.16b, #4 \n" + "ext v15.16b, v11.16b, v13.16b, #4 \n" + "ext v16.16b, v10.16b, v12.16b, #8 \n" + "ext v17.16b, v11.16b, v13.16b, #8 \n" + "ext v18.16b, v10.16b, v12.16b, #12 \n" + "ld2 {v3.4s, v4.4s}, [%5] \n" + "fmla v0.4s, v10.4s, %14.s[1] \n" + "fmla v0.4s, v11.4s, %14.s[2] \n" + "fmla v0.4s, v14.4s, %14.s[3] \n" + "fmla v0.4s, v15.4s, %15.s[0] \n" + "fmla v0.4s, v16.4s, %15.s[1] \n" + "fmla v0.4s, v17.4s, %15.s[2] \n" + "fmla v0.4s, v18.4s, %15.s[3] \n" + + "ld2 {v10.4s, v11.4s}, [%6], #32 \n" + "ext v5.16b, v1.16b, v3.16b, #4 \n" + "ext v6.16b, v2.16b, v4.16b, #4 \n" + "ext v7.16b, v1.16b, v3.16b, #8 \n" + "ext v8.16b, v2.16b, v4.16b, #8 \n" + "ext v9.16b, v1.16b, v3.16b, #12 \n" + "ld2 {v12.4s, v13.4s}, [%6] \n" + "fmla v0.4s, v1.4s, %16.s[0] \n" + "fmla v0.4s, v2.4s, %16.s[1] \n" + "fmla v0.4s, v5.4s, %16.s[2] \n" + "fmla v0.4s, v6.4s, %16.s[3] \n" + "fmla v0.4s, v7.4s, %17.s[0] \n" + "fmla v0.4s, v8.4s, %17.s[1] \n" + "fmla v0.4s, v9.4s, %17.s[2] \n" + + "ld2 {v1.4s, v2.4s}, [%7], #32 \n" + "ext v14.16b, v10.16b, v12.16b, #4 \n" + "ext v15.16b, v11.16b, v13.16b, #4 \n" + "ext v16.16b, v10.16b, v12.16b, #8 \n" + "ext v17.16b, v11.16b, v13.16b, #8 \n" + "ext v18.16b, v10.16b, v12.16b, #12 \n" + "ld2 {v3.4s, v4.4s}, [%7] \n" + "fmla v0.4s, v10.4s, %17.s[3] \n" + "fmla v0.4s, v11.4s, %18.s[0] \n" + "fmla v0.4s, v14.4s, %18.s[1] \n" + "fmla v0.4s, v15.4s, %18.s[2] \n" + "fmla v0.4s, v16.4s, %18.s[3] \n" + "fmla v0.4s, v17.4s, %19.s[0] \n" + "fmla v0.4s, v18.4s, %19.s[1] \n" + + "ext v5.16b, v1.16b, v3.16b, #4 \n" + "ext v6.16b, v2.16b, v4.16b, #4 \n" + "ext v7.16b, v1.16b, v3.16b, #8 \n" + "ext v8.16b, v2.16b, v4.16b, #8 \n" + "ext v9.16b, v1.16b, v3.16b, #12 \n" + "fmla v0.4s, v1.4s, %19.s[2] \n" + "fmla v0.4s, v2.4s, %19.s[3] \n" + "fmla v0.4s, v5.4s, %20.s[0] \n" + "fmla v0.4s, v6.4s, %20.s[1] \n" + "fmla v0.4s, v7.4s, %20.s[2] \n" + "fmla v0.4s, v8.4s, %20.s[3] \n" + "fmla v0.4s, v9.4s, %21.4s \n" + + "add x1, x1, #1 \n" + "st1 {v0.4s}, [%0], #16 \n" + + "1: \n" + "cmp %8, x1 \n" + "beq 3f \n" + + // mod2_left == 0 + "2: \n" + "ld1 {v0.4s, v1.4s}, [%0] \n" + + // v2.4s: 0 2 4 6 v3.4s: 1 3 5 7 + "ld2 {v2.4s, v3.4s}, [%1], #32 \n" + // v4.4s: 8 10 12 14 v3.4s: 9 11 13 15 + "ld2 {v4.4s, v5.4s}, [%1], #32 \n" + // v6.4s: 16 18 20 22 v7.4s: 17 19 21 23 + "ld2 {v6.4s, v7.4s}, [%1] \n" + // v8.4s: 2 4 6 8 + "ext v8.16b, v2.16b, v4.16b, #4 \n" + // v9.4s: 3 5 7 9 + "ext v9.16b, v3.16b, v5.16b, #4 \n" + // v10.4s: 4 6 8 10 + "ext v10.16b, v2.16b, v4.16b, #8 \n" + // v11.4s: 5 7 9 11 + "ext v11.16b, v3.16b, v5.16b, #8 \n" + // v12.4s: 6 8 10 12 + "ext v12.16b, v2.16b, v4.16b, #12 \n" + "fmla v0.4s, v2.4s, %9.s[0] \n" + "fmla v0.4s, v3.4s, %9.s[1] \n" + "fmla v0.4s, v8.4s, %9.s[2] \n" + "fmla v0.4s, v9.4s, %9.s[3] \n" + 
"fmla v0.4s, v10.4s, %10.s[0] \n" + "fmla v0.4s, v11.4s, %10.s[1] \n" + "fmla v0.4s, v12.4s, %10.s[2] \n" + // v2.4s: 10 12 14 16 + "ext v2.16b, v4.16b, v6.16b, #4 \n" + // v3.4s: 11 13 15 17 + "ext v3.16b, v5.16b, v7.16b, #4 \n" + // v8.4s: 12 14 16 18 + "ext v8.16b, v4.16b, v6.16b, #8 \n" + // v9.4s: 13 15 17 19 + "ext v9.16b, v5.16b, v7.16b, #8 \n" + // v10.4s: 14 16 18 19 + "ext v10.16b, v4.16b, v6.16b, #12 \n" + "fmla v1.4s, v4.4s, %9.s[0] \n" + "fmla v1.4s, v5.4s, %9.s[1] \n" + "fmla v1.4s, v2.4s, %9.s[2] \n" + "fmla v1.4s, v3.4s, %9.s[3] \n" + "fmla v1.4s, v8.4s, %10.s[0] \n" + "fmla v1.4s, v9.4s, %10.s[1] \n" + "fmla v1.4s, v10.4s, %10.s[2] \n" + + "ld2 {v13.4s, v14.4s}, [%2], #32 \n" + "ld2 {v15.4s, v16.4s}, [%2], #32 \n" + "ld2 {v17.4s, v18.4s}, [%2] \n" + "ext v8.16b, v13.16b, v15.16b, #4 \n" + "ext v9.16b, v14.16b, v16.16b, #4 \n" + "ext v10.16b, v13.16b, v15.16b, #8 \n" + "ext v11.16b, v14.16b, v16.16b, #8 \n" + "ext v12.16b, v13.16b, v15.16b, #12 \n" + "fmla v0.4s, v13.4s, %10.s[3] \n" + "fmla v0.4s, v14.4s, %11.s[0] \n" + "fmla v0.4s, v8.4s, %11.s[1] \n" + "fmla v0.4s, v9.4s, %11.s[2] \n" + "fmla v0.4s, v10.4s, %11.s[3] \n" + "fmla v0.4s, v11.4s, %12.s[0] \n" + "fmla v0.4s, v12.4s, %12.s[1] \n" + "ext v13.16b, v15.16b, v17.16b, #4 \n" + "ext v14.16b, v16.16b, v18.16b, #4 \n" + "ext v8.16b, v15.16b, v17.16b, #8 \n" + "ext v9.16b, v16.16b, v18.16b, #8 \n" + "ext v10.16b, v15.16b, v17.16b, #12 \n" + "fmla v1.4s, v15.4s, %10.s[3] \n" + "fmla v1.4s, v16.4s, %11.s[0] \n" + "fmla v1.4s, v13.4s, %11.s[1] \n" + "fmla v1.4s, v14.4s, %11.s[2] \n" + "fmla v1.4s, v8.4s, %11.s[3] \n" + "fmla v1.4s, v9.4s, %12.s[0] \n" + "fmla v1.4s, v10.4s, %12.s[1] \n" + + "ld2 {v2.4s, v3.4s}, [%3], #32 \n" + "ld2 {v4.4s, v5.4s}, [%3], #32 \n" + "ld2 {v6.4s, v7.4s}, [%3] \n" + "ext v8.16b, v2.16b, v4.16b, #4 \n" + "ext v9.16b, v3.16b, v5.16b, #4 \n" + "ext v10.16b, v2.16b, v4.16b, #8 \n" + "ext v11.16b, v3.16b, v5.16b, #8 \n" + "ext v12.16b, v2.16b, v4.16b, #12 \n" + "fmla v0.4s, v2.4s, %12.s[2] \n" + "fmla v0.4s, v3.4s, %12.s[3] \n" + "fmla v0.4s, v8.4s, %13.s[0] \n" + "fmla v0.4s, v9.4s, %13.s[1] \n" + "fmla v0.4s, v10.4s, %13.s[2] \n" + "fmla v0.4s, v11.4s, %13.s[3] \n" + "fmla v0.4s, v12.4s, %14.s[0] \n" + "ext v2.16b, v4.16b, v6.16b, #4 \n" + "ext v3.16b, v5.16b, v7.16b, #4 \n" + "ext v8.16b, v4.16b, v6.16b, #8 \n" + "ext v9.16b, v5.16b, v7.16b, #8 \n" + "ext v10.16b, v4.16b, v6.16b, #12 \n" + "fmla v1.4s, v4.4s, %12.s[2] \n" + "fmla v1.4s, v5.4s, %12.s[3] \n" + "fmla v1.4s, v2.4s, %13.s[0] \n" + "fmla v1.4s, v3.4s, %13.s[1] \n" + "fmla v1.4s, v8.4s, %13.s[2] \n" + "fmla v1.4s, v9.4s, %13.s[3] \n" + "fmla v1.4s, v10.4s, %14.s[0] \n" + + "ld2 {v13.4s, v14.4s}, [%4], #32 \n" + "ld2 {v15.4s, v16.4s}, [%4], #32 \n" + "ld2 {v17.4s, v18.4s}, [%4] \n" + "ext v8.16b, v13.16b, v15.16b, #4 \n" + "ext v9.16b, v14.16b, v16.16b, #4 \n" + "ext v10.16b, v13.16b, v15.16b, #8 \n" + "ext v11.16b, v14.16b, v16.16b, #8 \n" + "ext v12.16b, v13.16b, v15.16b, #12 \n" + "fmla v0.4s, v13.4s, %14.s[1] \n" + "fmla v0.4s, v14.4s, %14.s[2] \n" + "fmla v0.4s, v8.4s, %14.s[3] \n" + "fmla v0.4s, v9.4s, %15.s[0] \n" + "fmla v0.4s, v10.4s, %15.s[1] \n" + "fmla v0.4s, v11.4s, %15.s[2] \n" + "fmla v0.4s, v12.4s, %15.s[3] \n" + "ext v13.16b, v15.16b, v17.16b, #4 \n" + "ext v14.16b, v16.16b, v18.16b, #4 \n" + "ext v8.16b, v15.16b, v17.16b, #8 \n" + "ext v9.16b, v16.16b, v18.16b, #8 \n" + "ext v10.16b, v15.16b, v17.16b, #12 \n" + "fmla v1.4s, v15.4s, %14.s[1] \n" + "fmla v1.4s, v16.4s, %14.s[2] \n" + "fmla v1.4s, v13.4s, %14.s[3] \n" + 
"fmla v1.4s, v14.4s, %15.s[0] \n" + "fmla v1.4s, v8.4s, %15.s[1] \n" + "fmla v1.4s, v9.4s, %15.s[2] \n" + "fmla v1.4s, v10.4s, %15.s[3] \n" + + "ld2 {v2.4s, v3.4s}, [%5], #32 \n" + "ld2 {v4.4s, v5.4s}, [%5], #32 \n" + "ld2 {v6.4s, v7.4s}, [%5] \n" + "ext v8.16b, v2.16b, v4.16b, #4 \n" + "ext v9.16b, v3.16b, v5.16b, #4 \n" + "ext v10.16b, v2.16b, v4.16b, #8 \n" + "ext v11.16b, v3.16b, v5.16b, #8 \n" + "ext v12.16b, v2.16b, v4.16b, #12 \n" + "fmla v0.4s, v2.4s, %16.s[0] \n" + "fmla v0.4s, v3.4s, %16.s[1] \n" + "fmla v0.4s, v8.4s, %16.s[2] \n" + "fmla v0.4s, v9.4s, %16.s[3] \n" + "fmla v0.4s, v10.4s, %17.s[0] \n" + "fmla v0.4s, v11.4s, %17.s[1] \n" + "fmla v0.4s, v12.4s, %17.s[2] \n" + "ext v2.16b, v4.16b, v6.16b, #4 \n" + "ext v3.16b, v5.16b, v7.16b, #4 \n" + "ext v8.16b, v4.16b, v6.16b, #8 \n" + "ext v9.16b, v5.16b, v7.16b, #8 \n" + "ext v10.16b, v4.16b, v6.16b, #12 \n" + "fmla v1.4s, v4.4s, %16.s[0] \n" + "fmla v1.4s, v5.4s, %16.s[1] \n" + "fmla v1.4s, v2.4s, %16.s[2] \n" + "fmla v1.4s, v3.4s, %16.s[3] \n" + "fmla v1.4s, v8.4s, %17.s[0] \n" + "fmla v1.4s, v9.4s, %17.s[1] \n" + "fmla v1.4s, v10.4s, %17.s[2] \n" + + "ld2 {v13.4s, v14.4s}, [%6], #32 \n" + "ld2 {v15.4s, v16.4s}, [%6], #32 \n" + "ld2 {v17.4s, v18.4s}, [%6] \n" + "ext v8.16b, v13.16b, v15.16b, #4 \n" + "ext v9.16b, v14.16b, v16.16b, #4 \n" + "ext v10.16b, v13.16b, v15.16b, #8 \n" + "ext v11.16b, v14.16b, v16.16b, #8 \n" + "ext v12.16b, v13.16b, v15.16b, #12 \n" + "fmla v0.4s, v13.4s, %17.s[3] \n" + "fmla v0.4s, v14.4s, %18.s[0] \n" + "fmla v0.4s, v8.4s, %18.s[1] \n" + "fmla v0.4s, v9.4s, %18.s[2] \n" + "fmla v0.4s, v10.4s, %18.s[3] \n" + "fmla v0.4s, v11.4s, %19.s[0] \n" + "fmla v0.4s, v12.4s, %19.s[1] \n" + "ext v13.16b, v15.16b, v17.16b, #4 \n" + "ext v14.16b, v16.16b, v18.16b, #4 \n" + "ext v8.16b, v15.16b, v17.16b, #8 \n" + "ext v9.16b, v16.16b, v18.16b, #8 \n" + "ext v10.16b, v15.16b, v17.16b, #12 \n" + "fmla v1.4s, v15.4s, %17.s[3] \n" + "fmla v1.4s, v16.4s, %18.s[0] \n" + "fmla v1.4s, v13.4s, %18.s[1] \n" + "fmla v1.4s, v14.4s, %18.s[2] \n" + "fmla v1.4s, v8.4s, %18.s[3] \n" + "fmla v1.4s, v9.4s, %19.s[0] \n" + "fmla v1.4s, v10.4s, %19.s[1] \n" + + "ld2 {v2.4s, v3.4s}, [%7], #32 \n" + "ld2 {v4.4s, v5.4s}, [%7], #32 \n" + "ld2 {v6.4s, v7.4s}, [%7] \n" + "ext v8.16b, v2.16b, v4.16b, #4 \n" + "ext v9.16b, v3.16b, v5.16b, #4 \n" + "ext v10.16b, v2.16b, v4.16b, #8 \n" + "ext v11.16b, v3.16b, v5.16b, #8 \n" + "ext v12.16b, v2.16b, v4.16b, #12 \n" + "fmla v0.4s, v2.4s, %19.s[2] \n" + "fmla v0.4s, v3.4s, %19.s[3] \n" + "fmla v0.4s, v8.4s, %20.s[0] \n" + "fmla v0.4s, v9.4s, %20.s[1] \n" + "fmla v0.4s, v10.4s, %20.s[2] \n" + "fmla v0.4s, v11.4s, %20.s[3] \n" + "fmla v0.4s, v12.4s, %21.4s \n" + "ext v2.16b, v4.16b, v6.16b, #4 \n" + "ext v3.16b, v5.16b, v7.16b, #4 \n" + "ext v8.16b, v4.16b, v6.16b, #8 \n" + "ext v9.16b, v5.16b, v7.16b, #8 \n" + "ext v10.16b, v4.16b, v6.16b, #12 \n" + "fmla v1.4s, v4.4s, %19.s[2] \n" + "fmla v1.4s, v5.4s, %19.s[3] \n" + "fmla v1.4s, v2.4s, %20.s[0] \n" + "fmla v1.4s, v3.4s, %20.s[1] \n" + "fmla v1.4s, v8.4s, %20.s[2] \n" + "fmla v1.4s, v9.4s, %20.s[3] \n" + "fmla v1.4s, v10.4s, %21.4s \n" + + "add x1, x1, #2 \n" + "st1 {v0.4s, v1.4s}, [%0], #32 \n" + "cmp %8, x1 \n" + "bne 2b \n" + "3: \n" + + : "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), "+r"(r3), + "+r"(r4), "+r"(r5), "+r"(r6) + : "r"(width), "w"(_k0123), "w"(_k4567), "w"(_k891011), + "w"(_k12131415), "w"(_k16171819), "w"(_k20212223), + "w"(_k24252627), "w"(_k28293031), "w"(_k32333435), + "w"(_k36373839), "w"(_k40414243), "w"(_k44454647), + 
"w"(_k48484848) + : "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18"); + + r0 += tail_step; + r1 += tail_step; + r2 += tail_step; + r3 += tail_step; + r4 += tail_step; + r5 += tail_step; + r6 += tail_step; + } + filter += 49; + } +} + +} // namespace conv_stride2 +} // namespace fp32 +} // namespace aarch64 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/conv_bias/int8/algos.cpp b/dnn/src/aarch64/conv_bias/int8/algos.cpp new file mode 100644 index 00000000..8b17cfa6 --- /dev/null +++ b/dnn/src/aarch64/conv_bias/int8/algos.cpp @@ -0,0 +1,187 @@ +/** + * \file dnn/src/aarch64/conv_bias/int8/algos.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "src/aarch64/conv_bias/int8/algos.h" +#include "src/aarch64/conv_bias/int8/strategy.h" +#include "src/arm_common/convolution/img2col_helper.h" +#include "src/arm_common/elemwise_op.h" +#include "src/common/opr_delegate.h" +#include "src/fallback/conv_bias/common.h" +#include "src/fallback/matrix_mul/gemm_impl.h" + +#include "midout.h" + +MIDOUT_DECL(megdnn_aarch64_conv_bias_int8_gemm) + +using namespace megdnn; +using namespace aarch64; +using megdnn::arm_common::HSwishOp; +using megdnn::arm_common::ReluOp; +using megdnn::arm_common::TypeCvtOp; + +/* ===================== matrix mul algo ===================== */ + +bool ConvBiasImpl::AlgoS8MatrixMul::usable( + FallbackConvBiasImpl* opr, const NCBKernSizeParam& param, + AlgoSelectionStrategy /*algo_selection_strategy*/) const { + MEGDNN_MARK_USED_VAR(opr); + auto&& fm = param.filter_meta; + return param.src_type.enumv() == DTypeEnum::QuantizedS8 && + param.dst_type.enumv() == DTypeEnum::QuantizedS8 && + fm.format == param::ConvBias::Format::NCHW && fm.spatial_ndim == 2 && + fm.dilation[0] == 1 && fm.dilation[1] == 1 && + //! As postprocess, the bias is not contigous read, make the + //! performance bad, so we do not process it in fused kernel + param.bias_mode != BiasMode::BIAS && + //! 
This algo is only support single thread + param.nr_threads == 1_z; +} + +WorkspaceBundle ConvBiasImpl::AlgoS8MatrixMul::get_bundle( + const NCBKernSizeParam& param) { + UNPACK_CONV_NCB_KERN_SIZES(param); + MEGDNN_MARK_USED_VAR(N); + auto IW2 = IH + 2 * PH; + auto IH2 = IW + 2 * PW; + bool can_matrix_mul_direct = + (FH == 1 && FW == 1 && SH == 1 && SW == 1 && PH == 0 && PW == 0); + // temp space to store padding-free src (with 16 extra int8) + // temp space to store unrolled matrix (with 16 extra int8) + // workspace for matrix mul opr + size_t part0, part1, part2; + if (can_matrix_mul_direct) { + part0 = part1 = 0; + } else { + part0 = (IC * IH2 * IW2 + 16) * sizeof(int8_t); + part1 = (IC * FH * FW * OH * OW + 16) * sizeof(int8_t); + } + { + size_t M = OC; + size_t K = IC * FH * FW; + size_t N = OH * OW; + +#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ + _bias_midout_enum, _nonline, \ + _nonline_midout_enum) \ + MIDOUT_BEGIN(megdnn_aarch64_conv_bias_int8_gemm, 0, _gemm_midout_enum, \ + _bias_midout_enum, _nonline_midout_enum) { \ + matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ + M, N, K, param.filter_type, param.src_type, param.dst_type); \ + part2 = megdnn::matmul::GemmInterleaved< \ + matmul::gemm_##_gemm##_##_bias##_##_nonline>( \ + M, N, K, false, false, strategy) \ + .get_workspace_size(); \ + } \ + MIDOUT_END() + +#if !(__ARM_FEATURE_DOTPROD) + DISPATCH_GEMM_BIAS(s8_4x4, 0) +#else + DISPATCH_GEMM_BIAS(s8_8x12, 1) +#endif +#undef DISPATCH_GEMM_STRATEGY + } + return {nullptr, {part0, part1, part2}}; +} + +void ConvBiasImpl::AlgoS8MatrixMul::kimpl(const NCBKernParam& param, + const NCBKernIndex& ncb_index) { + auto is_xcorr = !param.filter_meta.should_flip; + UNPACK_CONV_NCB_KERN_SIZES(param); + auto bundle = get_bundle(param); + bundle.set(param.workspace_ptr); + auto IH2 = IH + 2 * PH; + auto IW2 = IW + 2 * PW; + size_t group_id = ncb_index.ndrange_id[0]; + // workspace = tmp..src2 + for (size_t n = 0; n < N; ++n) { + dt_int8* src = const_cast(param.src(n, group_id)); + dt_int8* filter = const_cast(param.filter(group_id)); + dt_int8* dst = static_cast(param.dst(n, group_id)); + dt_int32* bias = const_cast(param.bias(n, group_id)); + + dt_int8 *B, *src2; + if (FH == 1 && FW == 1 && SH == 1 && SW == 1 && PH == 0 && PW == 0) { + // special case: 1x1 + B = const_cast(src); + } else { + src2 = static_cast(bundle.get(0)); + // copy src to src2; + dt_int8* src2_ptr = src2; + const dt_int8* src_ptr = src; + rep(ic, IC) { + if (PH != 0) { + std::memset(src2_ptr, 0, sizeof(dt_int8) * PH * IW2); + src2_ptr += PH * IW2; + } + rep(ih, IH) { + if (PW != 0) + rep(pw, PW) { *(src2_ptr++) = 0.0f; } + std::memcpy(src2_ptr, src_ptr, sizeof(dt_int8) * IW); + src2_ptr += IW; + src_ptr += IW; + if (PW != 0) + rep(pw, PW) { *(src2_ptr++) = 0.0f; } + } + if (PH != 0) { + std::memset(src2_ptr, 0, sizeof(dt_int8) * PH * IW2); + src2_ptr += PH * IW2; + } + } + + B = static_cast(bundle.get(1)); + if (SH == 1 && SW == 1) { + if (is_xcorr) + img2col(src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW); + else + img2col(src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW); + } else { + if (is_xcorr) + img2col_stride(src2, B, OC, OH, OW, IC, IH2, IW2, FH, + FW, SH, SW); + else + img2col_stride(src2, B, OC, OH, OW, IC, IH2, IW2, FH, + FW, SH, SW); + } + } + { + Workspace workspace(static_cast(bundle.get(2)), + bundle.get_size(2)); + size_t M = OC; + size_t K = IC * FH * FW; + size_t N = OH * OW; + +#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ + _bias_midout_enum, _nonline, \ + 
_nonline_midout_enum) \ + MIDOUT_BEGIN(megdnn_aarch64_conv_bias_int8_gemm, 1, _gemm_midout_enum, \ + _bias_midout_enum, _nonline_midout_enum) { \ + matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ + M, N, K, param.filter_type, param.src_type, param.dst_type); \ + megdnn::matmul::GemmInterleaved< \ + matmul::gemm_##_gemm##_##_bias##_##_nonline> \ + gemm_interleaved(M, N, K, false, false, strategy); \ + gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, \ + bias); \ + } \ + MIDOUT_END() + +#if !(__ARM_FEATURE_DOTPROD) + DISPATCH_GEMM_BIAS(s8_4x4, 0) +#else + DISPATCH_GEMM_BIAS(s8_8x12, 1) +#endif +#undef DISPATCH_GEMM_STRATEGY + } + } +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/conv_bias/int8/algos.h b/dnn/src/aarch64/conv_bias/int8/algos.h new file mode 100644 index 00000000..403d755d --- /dev/null +++ b/dnn/src/aarch64/conv_bias/int8/algos.h @@ -0,0 +1,52 @@ +/** + * \file dnn/src/aarch64/conv_bias/int8/algos.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include "src/aarch64/conv_bias/opr_impl.h" +#include "src/fallback/conv_bias/opr_impl.h" + +namespace megdnn { +namespace aarch64 { + +using FallbackConvBiasImpl = fallback::ConvBiasImpl; + +class ConvBiasImpl::AlgoS8MatrixMul final : public AlgoBase { + static WorkspaceBundle get_bundle(const NCBKernSizeParam& param); + static void kimpl(const NCBKernParam& param, const NCBKernIndex& ncb_index); + +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "S8MATMUL"; } + + bool usable(FallbackConvBiasImpl* opr, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; + size_t get_workspace(FallbackConvBiasImpl*, + const NCBKernSizeParam& param) const override { + return get_bundle(param).total_size_in_bytes(); + } + SmallVector dispatch_kerns( + FallbackConvBiasImpl*, const NCBKernSizeParam& param) const override { + size_t group = param.filter_meta.group; + return {{kimpl, {group, 1_z, 1_z}}}; + } + //! select matmul to the highest preference + bool is_preferred(FallbackConvBiasImpl* opr, + const NCBKernSizeParam& param) const override { + return static_cast(opr) + ->is_matmul_quantized_prefer(param); + } +}; + +} // namespace aarch64 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/conv_bias/int8/strategy.cpp b/dnn/src/aarch64/conv_bias/int8/strategy.cpp new file mode 100644 index 00000000..cc8a6761 --- /dev/null +++ b/dnn/src/aarch64/conv_bias/int8/strategy.cpp @@ -0,0 +1,309 @@ +/** + * \file dnn/src/aarch64/conv_bias/int8/strategy.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#include "src/aarch64/conv_bias/int8/strategy.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/utils.h" +#include "src/fallback/conv_bias/common.h" + +#include "src/aarch64/matrix_mul/int8/kernel_4x4x16.h" +#include "src/aarch64/matrix_mul/int8_dot/kernel_8x12x4.h" +#include "src/arm_common/conv_bias/matmul_postprocess.h" + +using namespace megdnn; +using namespace aarch64; +using namespace aarch64::matmul; + +namespace impl { +template +struct KernCaller; + +#if __ARM_FEATURE_DOTPROD +template +struct KernCaller { + static void run(const dt_int8* packA, const dt_int8* packB, size_t M, + size_t N, size_t K, dt_int8* C, size_t LDC, bool is_first_k, + Op op, const dt_int32* bias, dt_int32* workspace) { + megdnn_assert(is_first_k); + + constexpr size_t A_INTERLEAVE = 8; + constexpr size_t B_INTERLEAVE = 12; + //! K is packed to times of 4 + K = round_up(K, 4); + const int K8 = (K << 3); + const int K12 = K * 12; + const int K4 = K * 4; + + size_t m = 0; + for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) { + int8_t* output = C + (m * LDC); + + size_t n = 0; + const dt_int8* cur_packB = packB; + for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { + matmul_8x12x4::kern_8x12(packA, cur_packB, K, workspace, 12, + is_first_k); + + arm_common::ConvBiasMatmul::postprocess(bias, workspace, + output, LDC, op); + output += B_INTERLEAVE; + cur_packB += K12; + } + + for (; n < N; n += 4) { + matmul_8x12x4::kern_8x4(packA, cur_packB, K, workspace, 4, + is_first_k, std::min(N - n, 4)); + +#define cb(m, n) \ + arm_common::ConvBiasMatmul::postprocess( \ + bias, workspace, output, LDC, op); + DISPATCH_N(cb, 8, std::min(N - n, 4)); +#undef cb + output += 4; + cur_packB += K4; + } + packA += K8; + + if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + bias += A_INTERLEAVE; + } + } + + for (; m < M; m += 4) { + int8_t* output = C + (m * LDC); + const dt_int8* cur_packB = packB; + size_t n = 0; + for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { + matmul_8x12x4::kern_4x12(packA, cur_packB, K, workspace, 12, + is_first_k, + std::min(M - m, 4)); +#define cb(m, n) \ + arm_common::ConvBiasMatmul::postprocess( \ + bias, workspace, output, LDC, op); + DISPATCH_M_N(cb, std::min(M - m, 4), 12); +#undef cb + + output += B_INTERLEAVE; + cur_packB += K12; + } + + for (; n < N; n += 4) { + matmul_8x12x4::kern_4x4(packA, cur_packB, K, workspace, 4, + is_first_k, std::min(M - m, 4), + std::min(N - n, 4)); +#define cb(m, n) \ + arm_common::ConvBiasMatmul::postprocess( \ + bias, workspace, output, LDC, op); + DISPATCH_M(cb, std::min(M - m, 4), + std::min(N - n, 4)); +#undef cb + + output += 4; + cur_packB += K4; + } + packA += K4; + if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + bias += 4; + } + } + } +}; + +#else + +template +struct KernCaller { + static void run(const dt_int8* packA, const dt_int8* packB, size_t M, + size_t N, size_t K, dt_int8* C, size_t LDC, bool is_first_k, + Op op, const dt_int32* bias, dt_int32* workspace) { + megdnn_assert(is_first_k); + + constexpr size_t A_INTERLEAVE = 4; + constexpr size_t B_INTERLEAVE = 4; + //! 
K is packed to times of 4 + K = round_up(K, 16); + const int K4 = K * 4; + + size_t m = 0; + for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) { + int8_t* output = C + (m * LDC); + + size_t n = 0; + const dt_int8* cur_packB = packB; + for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { + matmul_4x4x16::kern_4x4(packA, cur_packB, K, workspace, 4, + is_first_k); + arm_common::ConvBiasMatmul::postprocess(bias, workspace, + output, LDC, op); + + output += B_INTERLEAVE; + cur_packB += K4; + } + + for (; n < N; n += B_INTERLEAVE) { + matmul_4x4x16::kern_4x4_remain(packA, cur_packB, K, workspace, + 4, is_first_k, 4, + std::min(N - n, 4)); +#define cb(m, n) \ + arm_common::ConvBiasMatmul::postprocess( \ + bias, workspace, output, LDC, op); + DISPATCH_N(cb, 4, std::min(N - n, 4)); +#undef cb + output += B_INTERLEAVE; + cur_packB += K4; + } + + packA += K4; + if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + bias += A_INTERLEAVE; + } + } + + for (; m < M; m += A_INTERLEAVE) { + int8_t* output = C + (m * LDC); + + size_t n = 0; + const dt_int8* cur_packB = packB; + for (; n < N; n += B_INTERLEAVE) { + matmul_4x4x16::kern_4x4_remain( + packA, cur_packB, K, workspace, 4, is_first_k, + std::min(M - m, 4), std::min(N - n, 4)); + +#define cb(m, n) \ + arm_common::ConvBiasMatmul::postprocess( \ + bias, workspace, output, LDC, op); + DISPATCH_M(cb, std::min(M - m, 4), + std::min(N - n, 4)); +#undef cb + output += B_INTERLEAVE; + cur_packB += K4; + } + packA += K4; + if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + bias += A_INTERLEAVE; + } + } + } +}; + +#endif + +} // namespace impl +#if !(__ARM_FEATURE_DOTPROD) +MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_4x4_nobias_identity) + +void gemm_s8_4x4_nobias_identity::pack_A(dt_int8* outptr, const dt_int8* inptr, + int ldin, int y0, int ymax, int k0, + int kmax, bool transpose) const { + if (transpose) { + matmul_4x4x16::gemm_s8_4x4_pack_B_n(outptr, inptr, ldin, y0, ymax, k0, + kmax); + } else { + matmul_4x4x16::gemm_s8_4x4_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, + kmax); + } +} + +void gemm_s8_4x4_nobias_identity::pack_B(dt_int8* out, const dt_int8* in, + int ldin, int x0, int xmax, int k0, + int kmax, bool transpose) const { + if (transpose) { + matmul_4x4x16::gemm_s8_4x4_pack_A_n(out, in, ldin, x0, xmax, k0, kmax); + } else { + matmul_4x4x16::gemm_s8_4x4_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); + } +} + +size_t gemm_s8_4x4_nobias_identity::get_workspace_size() const { + return 4 * 4 * sizeof(dt_int32); +} +#else +MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_8x12_nobias_identity) + +void gemm_s8_8x12_nobias_identity::pack_A(dt_int8* outptr, const dt_int8* inptr, + int ldin, int y0, int ymax, int k0, + int kmax, bool transpose) const { + MEGDNN_MARK_USED_VAR(matmul_8x12x4::gemm_s8_8x12_pack_A_t); + MEGDNN_MARK_USED_VAR(matmul_8x12x4::gemm_s8_8x12_pack_B_t); + if (transpose) { + matmul_8x12x4::gemm_s8_8x12_pack_B_n(outptr, inptr, ldin, y0, ymax, k0, + kmax); + } else { + matmul_8x12x4::gemm_s8_8x12_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, + kmax); + } +} + +void gemm_s8_8x12_nobias_identity::pack_B(dt_int8* out, const dt_int8* in, + int ldin, int x0, int xmax, int k0, + int kmax, bool transpose) const { + if (transpose) { + matmul_8x12x4::gemm_s8_8x12_pack_A_n(out, in, ldin, x0, xmax, k0, kmax); + } else { + matmul_8x12x4::gemm_s8_8x12_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); + } +} + +size_t gemm_s8_8x12_nobias_identity::get_workspace_size() const { + return 8 * 12 * sizeof(dt_int32); +} + +#endif + +#define KERN(_block_m, _block_n, _bias, _BIAS, 
_nonline, _OP) \ + void gemm_s8_##_block_m##x##_block_n##_##_bias##_##_nonline::kern( \ + const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, \ + size_t K, dt_int8* C, size_t LDC, bool is_first_k, \ + const dt_int32* bias, dt_int32* workspace) const { \ + float scale_A = A_dtype.param().scale; \ + float scale_B = B_dtype.param().scale; \ + float scale_C = C_dtype.param().scale; \ + DEFINE_OP(_OP); \ + impl::KernCaller<_BIAS, decltype(op), _block_m, _block_n>::run( \ + packA, packB, M, N, K, C, LDC, is_first_k, op, bias, \ + workspace); \ + } + +#define DEFINE_OP(_Op) \ + arm_common::_Op op(scale_A* scale_B, scale_C); + +#if !(__ARM_FEATURE_DOTPROD) +KERN(4, 4, nobias, BiasMode::NO_BIAS, identity, TypeCvtOp) +KERN(4, 4, nobias, BiasMode::NO_BIAS, relu, ReluOp) +KERN(4, 4, nobias, BiasMode::NO_BIAS, hswish, HSwishOp) +#else +KERN(8, 12, nobias, BiasMode::NO_BIAS, identity, TypeCvtOp) +KERN(8, 12, nobias, BiasMode::NO_BIAS, relu, ReluOp) +KERN(8, 12, nobias, BiasMode::NO_BIAS, hswish, HSwishOp) +#endif +#undef DEFINE_OP + +#define DEFINE_OP(_Op) \ + arm_common::_Op op(scale_A* scale_B, \ + scale_A* scale_B, scale_C); +#if !(__ARM_FEATURE_DOTPROD) +KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp) +KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp) +KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, + FuseAddHSwishOp) +#else +KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp) +KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp) +KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, + FuseAddHSwishOp) +#endif +#undef DEFINE_OP + +#undef KERN + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/conv_bias/int8/strategy.h b/dnn/src/aarch64/conv_bias/int8/strategy.h new file mode 100644 index 00000000..092e9301 --- /dev/null +++ b/dnn/src/aarch64/conv_bias/int8/strategy.h @@ -0,0 +1,69 @@ +/** + * \file dnn/src/aarch64/conv_bias/int8/strategy.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once +#include "src/fallback/matrix_mul/gemm_common.h" + +namespace megdnn { +namespace aarch64 { +namespace matmul { + +#if !(__ARM_FEATURE_DOTPROD) +/** + * \brief base strategy of gemm. 
+ * + * \name gemm___biasmode_nolinemode + */ +MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_int8, dt_int8, dt_int32, 4, 4, 16, + false, true, + gemm_s8_4x4_nobias_identity); + +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_nobias_relu, + gemm_s8_4x4_nobias_identity); + +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_nobias_hswish, + gemm_s8_4x4_nobias_identity); + +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_bias_channel_identity, + gemm_s8_4x4_nobias_identity); + +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_bias_channel_relu, + gemm_s8_4x4_nobias_identity); + +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_bias_channel_hswish, + gemm_s8_4x4_nobias_identity); + +#else +MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_int8, dt_int8, dt_int32, 8, 12, 4, + false, true, + gemm_s8_8x12_nobias_identity); + +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_nobias_relu, + gemm_s8_8x12_nobias_identity); + +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_nobias_hswish, + gemm_s8_8x12_nobias_identity); + +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_bias_channel_identity, + gemm_s8_8x12_nobias_identity); + +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_bias_channel_relu, + gemm_s8_8x12_nobias_identity); + +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_bias_channel_hswish, + gemm_s8_8x12_nobias_identity); + +#endif + +} // namespace matmul +} // namespace aarch64 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/conv_bias/opr_impl.cpp b/dnn/src/aarch64/conv_bias/opr_impl.cpp new file mode 100644 index 00000000..f9905dd0 --- /dev/null +++ b/dnn/src/aarch64/conv_bias/opr_impl.cpp @@ -0,0 +1,69 @@ +/** + * \file dnn/src/aarch64/conv_bias/opr_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#include "src/aarch64/conv_bias/opr_impl.h" +#include "src/aarch64/conv_bias/int8/algos.h" +#include "src/aarch64/conv_bias/quint8/algos.h" + +#include "src/naive/handle.h" +#include "src/common/utils.h" +#include "src/common/metahelper.h" + +#include "src/fallback/convolution/opr_impl.h" +#include "src/aarch64/conv_bias/fp32/algos.h" +#include "src/aarch64/conv_bias/fp16/algos.h" + +using namespace megdnn; +using namespace aarch64; + +class ConvBiasImpl::AlgoPack : NonCopyableObj { + AlgoF32DirectStride2 f32_direct_stride2_large_group{true}; + AlgoF32DirectStride2 f32_direct_stride2_small_group{false}; + AlgoS8MatrixMul s8_matrix_mul; + AlgoQU8MatrixMul qu8_matrix_mul; +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + AlgoF16DirectStride2 f16_direct_stride2_large_group{true}; + AlgoF16DirectStride2 f16_direct_stride2_small_group{false}; +#endif + +public: + AlgoPack() { + matmul_algos.emplace_back(&qu8_matrix_mul); + matmul_algos.emplace_back(&s8_matrix_mul); +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + direct_algos.emplace_back(&f16_direct_stride2_large_group); + direct_algos.emplace_back(&f16_direct_stride2_small_group); +#endif + direct_algos.emplace_back(&f32_direct_stride2_large_group); + direct_algos.emplace_back(&f32_direct_stride2_small_group); + } + SmallVector direct_algos; + SmallVector matmul_algos; +}; + +SmallVector ConvBiasImpl::algo_pack() { + static AlgoPack sl_algo_pack; + auto&& algos = arm_common::ConvBiasImpl::algo_pack(); + algos.insert(algos.begin(), sl_algo_pack.direct_algos.begin(), + sl_algo_pack.direct_algos.end()); + //! We put matmul algos at the end. Because matmul will get privilege when + //! prefer return true. See + //! fallback::ConvolutionImpl::ncb_1g_get_all_algorithms for more details. + algos.insert(algos.end(), sl_algo_pack.matmul_algos.begin(), + sl_algo_pack.matmul_algos.end()); + return std::move(algos); +} + +const char* ConvBiasImpl::get_algorithm_set_name() const { + return "AARCH64"; +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/conv_bias/opr_impl.h b/dnn/src/aarch64/conv_bias/opr_impl.h new file mode 100644 index 00000000..16c11672 --- /dev/null +++ b/dnn/src/aarch64/conv_bias/opr_impl.h @@ -0,0 +1,41 @@ +/** + * \file dnn/src/aarch64/conv_bias/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +#pragma once +#include "src/common/utils.h" +#include "src/arm_common/conv_bias/opr_impl.h" + +namespace megdnn { +namespace aarch64 { + +class ConvBiasImpl : public arm_common::ConvBiasImpl { +public: + using arm_common::ConvBiasImpl::ConvBiasImpl; + + SmallVector algo_pack() override; + +protected: + + const char* get_algorithm_set_name() const override; + +private: + class AlgoF32DirectStride2; + class AlgoS8MatrixMul; + class AlgoQU8MatrixMul; +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + class AlgoF16DirectStride2; +#endif + class AlgoPack; +}; + +} // namespace aarch64 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/conv_bias/quint8/algos.cpp b/dnn/src/aarch64/conv_bias/quint8/algos.cpp new file mode 100644 index 00000000..dc8a1233 --- /dev/null +++ b/dnn/src/aarch64/conv_bias/quint8/algos.cpp @@ -0,0 +1,181 @@ +/** + * \file dnn/src/aarch64/conv_bias/quint8/algos.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "src/aarch64/conv_bias/quint8/algos.h" +#include "src/aarch64/conv_bias/quint8/strategy.h" +#include "src/aarch64/matrix_mul/quint8_dot/gemv.h" +#include "src/aarch64/matrix_mul/quint8_dot/strategy.h" +#include "src/arm_common/convolution/img2col_helper.h" +#include "src/arm_common/elemwise_op.h" +#include "src/common/opr_delegate.h" +#include "src/fallback/conv_bias/common.h" +#include "src/fallback/matrix_mul/gemm_impl.h" + +#include "midout.h" + +MIDOUT_DECL(megdnn_aarch64_conv_bias_quint8_gemm) + +using namespace megdnn; +using namespace aarch64; +using megdnn::arm_common::HSwishOp; +using megdnn::arm_common::ReluOp; +using megdnn::arm_common::TypeCvtOp; + +/* ===================== matrix mul algo ===================== */ + +bool ConvBiasImpl::AlgoQU8MatrixMul::usable( + FallbackConvBiasImpl* opr, const NCBKernSizeParam& param, + AlgoSelectionStrategy /*algo_selection_strategy*/) const { + MEGDNN_MARK_USED_VAR(opr); + auto&& fm = param.filter_meta; + return param.src_type.enumv() == DTypeEnum::Quantized8Asymm && + param.dst_type.enumv() == DTypeEnum::Quantized8Asymm && + fm.format == param::ConvBias::Format::NCHW && fm.spatial_ndim == 2 && + fm.dilation[0] == 1 && fm.dilation[1] == 1 && + //! As postprocess, the bias is not contigous read, make the + //! performance bad, so we do not process it in fused kernel + param.bias_mode != BiasMode::BIAS && + //! 
This algo is only support single thread + param.nr_threads == 1_z; +} + +WorkspaceBundle ConvBiasImpl::AlgoQU8MatrixMul::get_bundle( + const NCBKernSizeParam& param) { + UNPACK_CONV_NCB_KERN_SIZES(param); + MEGDNN_MARK_USED_VAR(N); + auto IW2 = IH + 2 * PH; + auto IH2 = IW + 2 * PW; + bool can_matrix_mul_direct = + (FH == 1 && FW == 1 && SH == 1 && SW == 1 && PH == 0 && PW == 0); + // temp space to store padding-free src (with 16 extra int8) + // temp space to store unrolled matrix (with 16 extra int8) + // workspace for matrix mul opr + size_t part0, part1, part2; + if (can_matrix_mul_direct) { + part0 = part1 = 0; + } else { + part0 = (IC * IH2 * IW2 + 16) * sizeof(uint8_t); + part1 = (IC * FH * FW * OH * OW + 16) * sizeof(uint8_t); + } + { + size_t M = OC; + size_t K = IC * FH * FW; + size_t N = OH * OW; + +#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ + _bias_midout_enum, _nonline, \ + _nonline_midout_enum) \ + MIDOUT_BEGIN(megdnn_aarch64_conv_bias_quint8_gemm, 0, _gemm_midout_enum, \ + _bias_midout_enum, _nonline_midout_enum) { \ + matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ + M, N, K, param.filter_type, param.src_type, param.dst_type); \ + part2 = megdnn::matmul::GemmInterleaved< \ + matmul::gemm_##_gemm##_##_bias##_##_nonline>( \ + M, N, K, false, false, strategy) \ + .get_workspace_size(); \ + } \ + MIDOUT_END() + + DISPATCH_GEMM_BIAS(u8_8x8, 0) +#undef DISPATCH_GEMM_STRATEGY + } + return {nullptr, {part0, part1, part2}}; +} + +void ConvBiasImpl::AlgoQU8MatrixMul::kimpl(const NCBKernParam& param, + const NCBKernIndex& ncb_index) { + auto is_xcorr = !param.filter_meta.should_flip; + UNPACK_CONV_NCB_KERN_SIZES(param); + auto bundle = get_bundle(param); + bundle.set(param.workspace_ptr); + auto IH2 = IH + 2 * PH; + auto IW2 = IW + 2 * PW; + size_t group_id = ncb_index.ndrange_id[0]; + uint8_t src_zp = param.src_type.param().zero_point; + // workspace = tmp..src2 + for (size_t n = 0; n < N; ++n) { + uint8_t* src = const_cast(param.src(n, group_id)); + uint8_t* filter = const_cast(param.filter(group_id)); + uint8_t* dst = static_cast(param.dst(n, group_id)); + int32_t* bias = const_cast(param.bias(n, group_id)); + + uint8_t *B, *src2; + if (FH == 1 && FW == 1 && SH == 1 && SW == 1 && PH == 0 && PW == 0) { + // special case: 1x1 + B = const_cast(src); + } else { + src2 = static_cast(bundle.get(0)); + // copy src to src2; + uint8_t* src2_ptr = src2; + const uint8_t* src_ptr = src; + rep(ic, IC) { + if (PH != 0) { + std::memset(src2_ptr, src_zp, sizeof(uint8_t) * PH * IW2); + src2_ptr += PH * IW2; + } + rep(ih, IH) { + if (PW != 0) + rep(pw, PW) { *(src2_ptr++) = src_zp; } + std::memcpy(src2_ptr, src_ptr, sizeof(uint8_t) * IW); + src2_ptr += IW; + src_ptr += IW; + if (PW != 0) + rep(pw, PW) { *(src2_ptr++) = src_zp; } + } + if (PH != 0) { + std::memset(src2_ptr, src_zp, sizeof(uint8_t) * PH * IW2); + src2_ptr += PH * IW2; + } + } + + B = static_cast(bundle.get(1)); + if (SH == 1 && SW == 1) { + if (is_xcorr) + img2col(src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW); + else + img2col(src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW); + } else { + if (is_xcorr) + img2col_stride(src2, B, OC, OH, OW, IC, IH2, IW2, FH, + FW, SH, SW); + else + img2col_stride(src2, B, OC, OH, OW, IC, IH2, IW2, FH, + FW, SH, SW); + } + } + { + Workspace workspace(static_cast(bundle.get(2)), + bundle.get_size(2)); + size_t M = OC; + size_t K = IC * FH * FW; + size_t N = OH * OW; + +#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ + _bias_midout_enum, _nonline, \ + 
_nonline_midout_enum) \ + MIDOUT_BEGIN(megdnn_aarch64_conv_bias_quint8_gemm, 1, _gemm_midout_enum, \ + _bias_midout_enum, _nonline_midout_enum) { \ + matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ + M, N, K, param.filter_type, param.src_type, param.dst_type); \ + megdnn::matmul::GemmInterleaved< \ + matmul::gemm_##_gemm##_##_bias##_##_nonline> \ + gemm_interleaved(M, N, K, false, false, strategy); \ + gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, \ + bias); \ + } \ + MIDOUT_END() + + DISPATCH_GEMM_BIAS(u8_8x8, 0) +#undef DISPATCH_GEMM_STRATEGY + } + } +} +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/conv_bias/quint8/algos.h b/dnn/src/aarch64/conv_bias/quint8/algos.h new file mode 100644 index 00000000..afa58b9d --- /dev/null +++ b/dnn/src/aarch64/conv_bias/quint8/algos.h @@ -0,0 +1,52 @@ +/** + * \file dnn/src/aarch64/conv_bias/quint8/algos.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include "src/aarch64/conv_bias/opr_impl.h" +#include "src/fallback/conv_bias/opr_impl.h" + +namespace megdnn { +namespace aarch64 { + +using FallbackConvBiasImpl = fallback::ConvBiasImpl; + +class ConvBiasImpl::AlgoQU8MatrixMul final : public AlgoBase { + static WorkspaceBundle get_bundle(const NCBKernSizeParam& param); + static void kimpl(const NCBKernParam& param, const NCBKernIndex&); + +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "QU8MATMUL"; } + + bool usable(FallbackConvBiasImpl* opr, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; + size_t get_workspace(FallbackConvBiasImpl*, + const NCBKernSizeParam& param) const override { + return get_bundle(param).total_size_in_bytes(); + } + SmallVector dispatch_kerns( + FallbackConvBiasImpl*, + const NCBKernSizeParam& param) const override { + size_t group = param.filter_meta.group; + return {{kimpl, {group, 1_z, 1_z}}}; + } + //! select matmul to the highest preference + bool is_preferred(FallbackConvBiasImpl* opr, + const NCBKernSizeParam& param) const override { + return static_cast(opr) + ->is_matmul_quantized_prefer(param); + } +}; +} // namespace aarch64 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/conv_bias/quint8/strategy.cpp b/dnn/src/aarch64/conv_bias/quint8/strategy.cpp new file mode 100644 index 00000000..02fbff23 --- /dev/null +++ b/dnn/src/aarch64/conv_bias/quint8/strategy.cpp @@ -0,0 +1,319 @@ +/** + * \file dnn/src/aarch64/conv_bias/quint8/strategy.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#include "src/aarch64/conv_bias/quint8/strategy.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/utils.h" +#include "src/fallback/conv_bias/common.h" + +#include "src/aarch64/matrix_mul/quint8_dot/kernel_8x8x4.h" +#include "src/aarch64/matrix_mul/quint8/kernel_8x8x8.h" +#include "src/arm_common/conv_bias/matmul_postprocess.h" + +using namespace megdnn; +using namespace aarch64; +using namespace aarch64::matmul; + +namespace impl { +template +struct KernCaller; + +#if __ARM_FEATURE_DOTPROD +template +struct KernCaller { + static void run(const dt_uint8* packA, const dt_uint8* packB, size_t M, + size_t N, size_t K, dt_uint8* C, size_t LDC, + bool is_first_k, Op op, const dt_int32* bias, + dt_int32* workspace, uint8_t zp_A, uint8_t zp_B) { + megdnn_assert(is_first_k); + constexpr size_t A_INTERLEAVE = 8; + constexpr size_t B_INTERLEAVE = 8; + const uint32_t zAB = + static_cast(zp_A) * static_cast(zp_B) * K; + //! K is packed to times of 4 + K = round_up(K, 4); + const int K8 = (K << 3); + const int K4 = K * 4; + + size_t m = 0; + for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) { + uint8_t* output = C + (m * LDC); + + size_t n = 0; + const dt_uint8* cur_packB = packB; + for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { + matmul_8x8x4::kern_8x8(packA, cur_packB, K, workspace, 8, + is_first_k, zp_A, zp_B, zAB); + + arm_common::ConvBiasMatmul::postprocess(bias, workspace, + output, LDC, op); + output += B_INTERLEAVE; + cur_packB += K8; + } + + for (; n < N; n += 4) { + matmul_8x8x4::kern_8x4(packA, cur_packB, K, workspace, 4, + is_first_k, std::min(N - n, 4), + zp_A, zp_B, zAB); +#define cb(m, n) \ + arm_common::ConvBiasMatmul::postprocess( \ + bias, workspace, output, LDC, op); + DISPATCH_N(cb, 8, std::min(N - n, 4)); +#undef cb + + output += 4; + cur_packB += K4; + } + packA += K8; + if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + bias += A_INTERLEAVE; + } + } + + for (; m < M; m += 4) { + uint8_t* output = C + (m * LDC); + const dt_uint8* cur_packB = packB; + size_t n = 0; + for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { + matmul_8x8x4::kern_4x8(packA, cur_packB, K, workspace, 8, + is_first_k, std::min(M - m, 4), + zp_A, zp_B, zAB); +#define cb(m, n) \ + arm_common::ConvBiasMatmul::postprocess( \ + bias, workspace, output, LDC, op); + DISPATCH_M_N(cb, std::min(M - m, 4), 8); +#undef cb + + output += B_INTERLEAVE; + cur_packB += K8; + } + + for (; n < N; n += 4) { + matmul_8x8x4::kern_4x4(packA, cur_packB, K, workspace, 4, + is_first_k, std::min(M - m, 4), + std::min(N - n, 4), zp_A, zp_B, + zAB); +#define cb(m, n) \ + arm_common::ConvBiasMatmul::postprocess( \ + bias, workspace, output, LDC, op); + DISPATCH_M(cb, std::min(M - m, 4), + std::min(N - n, 4)); +#undef cb + + output += 4; + cur_packB += K4; + } + packA += K4; + if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + bias += 4; + } + } + } +}; + +#else + +template +struct KernCaller { + static void run(const dt_uint8* packA, const dt_uint8* packB, size_t M, + size_t N, size_t K, dt_uint8* C, size_t LDC, + bool is_first_k, Op op, const dt_int32* bias, + dt_int32* workspace, uint8_t zp_A, uint8_t zp_B) { + megdnn_assert(is_first_k); + + constexpr size_t A_INTERLEAVE = 8; + constexpr size_t B_INTERLEAVE = 8; + //! 
K is packed to times of 8 + K = round_up(K, 8); + const int K8 = K * 8; + const int K4 = K * 4; + + size_t m = 0; + for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) { + uint8_t* output = C + (m * LDC); + + size_t n = 0; + const dt_uint8* cur_packB = packB; + for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { + matmul_8x8x8::kern_8x8(packA, cur_packB, K, workspace, 8, + is_first_k, zp_A, zp_B); + + arm_common::ConvBiasMatmul::postprocess(bias, workspace, + output, LDC, op); + output += B_INTERLEAVE; + cur_packB += K8; + } + + for (; n < N; n += 4) { + matmul_8x8x8::kern_8x4(packA, cur_packB, K, workspace, 4, + is_first_k, std::min(N - n, 4), + zp_A, zp_B); +#define cb(m, n) \ + arm_common::ConvBiasMatmul::postprocess( \ + bias, workspace, output, LDC, op); + DISPATCH_N(cb, 8, std::min(N - n, 4)); +#undef cb + + + output += 4; + cur_packB += K4; + } + packA += K8; + if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + bias += A_INTERLEAVE; + } + } + + for (; m < M; m += 4) { + uint8_t* output = C + (m * LDC); + const dt_uint8* cur_packB = packB; + size_t n = 0; + for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { + matmul_8x8x8::kern_4x8(packA, cur_packB, K, workspace, 8, + is_first_k, std::min(M - m, 4), + zp_A, zp_B); +#define cb(m, n) \ + arm_common::ConvBiasMatmul::postprocess( \ + bias, workspace, output, LDC, op); + DISPATCH_M_N(cb, std::min(M - m, 4), 8); +#undef cb + + output += B_INTERLEAVE; + cur_packB += K8; + } + + for (; n < N; n += 4) { + matmul_8x8x8::kern_4x4(packA, cur_packB, K, workspace, 4, + is_first_k, std::min(M - m, 4), + std::min(N - n, 4), zp_A, zp_B); +#define cb(m, n) \ + arm_common::ConvBiasMatmul::postprocess( \ + bias, workspace, output, LDC, op); + DISPATCH_M(cb, std::min(M - m, 4), + std::min(N - n, 4)); +#undef cb + + + output += 4; + cur_packB += K4; + } + packA += K4; + if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + bias += 4; + } + } + } +}; + +#endif + +} // namespace impl +#if __ARM_FEATURE_DOTPROD +MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8_nobias_identity) + +void gemm_u8_8x8_nobias_identity::pack_A(uint8_t* outptr, const uint8_t* inptr, + int ldin, int y0, int ymax, int k0, + int kmax, bool transpose) const { + if (transpose) { + matmul_8x8x4::gemm_u8_8x8_transpose_pack_helper(outptr, inptr, ldin, y0, + ymax, k0, kmax); + } else { + matmul_8x8x4::gemm_u8_8x8_interleave_pack_helper(outptr, inptr, ldin, + y0, ymax, k0, kmax); + } +} + +void gemm_u8_8x8_nobias_identity::pack_B(uint8_t* out, const uint8_t* in, + int ldin, int x0, int xmax, int k0, + int kmax, bool transpose) const { + if (transpose) { + matmul_8x8x4::gemm_u8_8x8_interleave_pack_helper(out, in, ldin, x0, + xmax, k0, kmax); + } else { + matmul_8x8x4::gemm_u8_8x8_transpose_pack_helper(out, in, ldin, x0, xmax, + k0, kmax); + } +} + +#else + +MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8_nobias_identity) +void gemm_u8_8x8_nobias_identity::pack_A(dt_uint8* outptr, + const dt_uint8* inptr, int ldin, + int y0, int ymax, int k0, int kmax, + bool transpose) const { + uint8_t zA = A_dtype.param().zero_point; + if (transpose) { + matmul_8x8x8::gemm_u8_8x8_transpose_pack_A_n(outptr, inptr, ldin, y0, + ymax, k0, kmax, zA); + } else { + matmul_8x8x8::gemm_u8_8x8_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, + kmax, zA); + } +} + +void gemm_u8_8x8_nobias_identity::pack_B(dt_uint8* out, const dt_uint8* in, + int ldin, int x0, int xmax, int k0, + int kmax, bool transpose) const { + uint8_t zB = B_dtype.param().zero_point; + if (transpose) { + matmul_8x8x8::gemm_u8_8x8_transpose_pack_B_n(out, in, ldin, 
x0, xmax, + k0, kmax, zB); + } else { + matmul_8x8x8::gemm_u8_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax, + zB); + } +} + +#endif +size_t gemm_u8_8x8_nobias_identity::get_workspace_size() const { + return 8 * 8 * sizeof(dt_int32); +} + +#define KERN(_block_m, _block_n, _bias, _BIAS, _nonline, _OP) \ + void gemm_u8_##_block_m##x##_block_n##_##_bias##_##_nonline::kern( \ + const dt_uint8* packA, const dt_uint8* packB, size_t M, size_t N, \ + size_t K, dt_uint8* C, size_t LDC, bool is_first_k, \ + const dt_int32* bias, dt_int32* workspace) const { \ + float scale_A = A_dtype.param().scale; \ + uint8_t zp_A = A_dtype.param().zero_point; \ + float scale_B = B_dtype.param().scale; \ + uint8_t zp_B = B_dtype.param().zero_point; \ + float scale_C = C_dtype.param().scale; \ + uint8_t zp_C = C_dtype.param().zero_point; \ + DEFINE_OP(_OP); \ + impl::KernCaller<_BIAS, decltype(op), _block_m, _block_n>::run( \ + packA, packB, M, N, K, C, LDC, is_first_k, op, bias, \ + workspace, zp_A, zp_B); \ + } + +#define DEFINE_OP(_Op) \ + arm_common::_Op op(scale_A* scale_B, scale_C, zp_C); + +KERN(8, 8, nobias, BiasMode::NO_BIAS, identity, TypeCvtOp) +KERN(8, 8, nobias, BiasMode::NO_BIAS, relu, ReluOp) +KERN(8, 8, nobias, BiasMode::NO_BIAS, hswish, HSwishOp) +#undef DEFINE_OP + +#define DEFINE_OP(_Op) \ + arm_common::_Op op(scale_A* scale_B, \ + scale_A* scale_B, scale_C, zp_C); +KERN(8, 8, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp) +KERN(8, 8, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp) +KERN(8, 8, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, + FuseAddHSwishOp) +#undef DEFINE_OP + +#undef KERN + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/conv_bias/quint8/strategy.h b/dnn/src/aarch64/conv_bias/quint8/strategy.h new file mode 100644 index 00000000..43c53020 --- /dev/null +++ b/dnn/src/aarch64/conv_bias/quint8/strategy.h @@ -0,0 +1,48 @@ +/** + * \file dnn/src/aarch64/conv_bias/quint8/strategy.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +#pragma once +#include "src/fallback/matrix_mul/gemm_common.h" + +namespace megdnn { +namespace aarch64 { +namespace matmul { + +#if __ARM_FEATURE_DOTPROD +MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_uint8, dt_uint8, dt_int32, 8, 8, 4, + false, true, + gemm_u8_8x8_nobias_identity); +#else +MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_uint8, dt_uint8, dt_int32, 8, 8, 8, + false, true, + gemm_u8_8x8_nobias_identity); +#endif + +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nobias_relu, + gemm_u8_8x8_nobias_identity); + +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nobias_hswish, + gemm_u8_8x8_nobias_identity); + +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_bias_channel_identity, + gemm_u8_8x8_nobias_identity); + +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_bias_channel_relu, + gemm_u8_8x8_nobias_identity); + +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_bias_channel_hswish, + gemm_u8_8x8_nobias_identity); + + +} // namespace matmul +} // namespace aarch64 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/handle.cpp b/dnn/src/aarch64/handle.cpp new file mode 100644 index 00000000..9d894b9b --- /dev/null +++ b/dnn/src/aarch64/handle.cpp @@ -0,0 +1,44 @@ +/** + * \file dnn/src/aarch64/handle.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "src/common/handle_impl.h" + +#include "src/aarch64/handle.h" +#include "src/aarch64/matrix_mul/opr_impl.h" +#include "src/aarch64/rotate/opr_impl.h" +#include "src/aarch64/relayout/opr_impl.h" +#include "src/aarch64/conv_bias/opr_impl.h" +#include "src/aarch64/warp_perspective/opr_impl.h" + +namespace megdnn { +namespace aarch64 { + +template +std::unique_ptr HandleImpl::create_operator() { + return arm_common::HandleImpl::create_operator(); +} + +MEGDNN_SPECIALIZE_CREATE_OPERATOR(MatrixMul) +MEGDNN_SPECIALIZE_CREATE_OPERATOR(Rotate) +MEGDNN_SPECIALIZE_CREATE_OPERATOR(RelayoutForward) +MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvBias) +MEGDNN_SPECIALIZE_CREATE_OPERATOR(WarpPerspective) + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wpragmas" +#pragma GCC diagnostic ignored "-Winstantiation-after-specialization" +MEGDNN_FOREACH_OPR_CLASS(MEGDNN_INST_CREATE_OPERATOR) +#pragma GCC diagnostic pop + +} // namespace aarch64 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/handle.h b/dnn/src/aarch64/handle.h new file mode 100644 index 00000000..70b5fea9 --- /dev/null +++ b/dnn/src/aarch64/handle.h @@ -0,0 +1,33 @@ +/** + * \file dnn/src/aarch64/handle.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +#pragma once +#include "src/arm_common/handle.h" + +namespace megdnn { +namespace aarch64 { + +class HandleImpl: public arm_common::HandleImpl { + public: + HandleImpl(megcoreComputingHandle_t computing_handle, + HandleType type = HandleType::AARCH64): + arm_common::HandleImpl::HandleImpl(computing_handle, type) + {} + + template + std::unique_ptr create_operator(); +}; + +} // namespace aarch64 +} // namespace megdnn + +// vim: syntax=cpp.doxygen + + diff --git a/dnn/src/aarch64/matrix_mul/algos.cpp b/dnn/src/aarch64/matrix_mul/algos.cpp new file mode 100644 index 00000000..b98045c9 --- /dev/null +++ b/dnn/src/aarch64/matrix_mul/algos.cpp @@ -0,0 +1,1038 @@ +/** + * \file dnn/src/aarch64/matrix_mul/algos.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "src/aarch64/matrix_mul/algos.h" +#include "src/aarch64/matrix_mul/fp16/strategy.h" +#include "src/aarch64/matrix_mul/fp32/strategy.h" +#include "src/aarch64/matrix_mul/int16/strategy.h" +#include "src/aarch64/matrix_mul/int8/strategy.h" +#include "src/aarch64/matrix_mul/int8_dot/gemv.h" +#include "src/aarch64/matrix_mul/int8_dot/strategy.h" +#include "src/aarch64/matrix_mul/int8x8x16/strategy.h" +#include "src/aarch64/matrix_mul/quint8/strategy.h" +#include "src/aarch64/matrix_mul/quint8_dot/gemv.h" +#include "src/aarch64/matrix_mul/quint8_dot/strategy.h" +#include "src/common/utils.h" +#include "src/fallback/matrix_mul/gemm_impl.h" + +#include "midout.h" + +MIDOUT_DECL(megdnn_aarch64_matmul_kern) + +using namespace megdnn; +using namespace aarch64; + +/* ===================== F32K8X12X1 algo ===================== */ +bool MatrixMulImpl::AlgoF32K8x12x1::usable( + const KernSizeParam& kern_size_param) const { + return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && + kern_size_param.B_type == kern_size_param.A_type && + kern_size_param.C_type == kern_size_param.A_type && + kern_size_param.A_type == dtype::Float32() && + kern_size_param.format == param::MatrixMul::Format::DEFAULT; +} + +size_t MatrixMulImpl::AlgoF32K8x12x1::get_workspace( + const KernSizeParam& kern_size_param) const { + MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, + midout_iv("AlgoF32K8x12x1::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, + K = kern_size_param.K; + auto trA = kern_size_param.trA, trB = kern_size_param.trB; + auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, + C_type = kern_size_param.C_type; + aarch64::matmul::sgemm_8x12 strategy(M, N, K, A_type, B_type, C_type); + return megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .get_workspace_size(); + } + MIDOUT_END(); +} + +MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32K8x12x1::get_kern( + const KernSizeParam&) const { + auto f32_kern_8x12 = [](const MatrixMulImpl::KernParam& kern_param) { + MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, + midout_iv("AlgoF32K8x12x1::get_kern"_hash)) { + auto M = kern_param.M, N = kern_param.N, K = kern_param.K; + auto trA = kern_param.trA, trB = kern_param.trB; + auto LDA = kern_param.LDA, LDB = kern_param.LDB, + LDC = kern_param.LDC; + auto A_type = kern_param.A_type, B_type = kern_param.B_type, + C_type = kern_param.C_type; + const auto Aptr = kern_param.A(), + Bptr = 
kern_param.B(); + auto Cptr = kern_param.C(); + aarch64::matmul::sgemm_8x12 strategy(M, N, K, A_type, B_type, + C_type); + megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, + kern_param.workspace_ptr); + } + MIDOUT_END(); + }; + return f32_kern_8x12; +} +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32K8x12x1, megdnn_aarch64_matmul_kern, + "AlgoF32K8x12x1Impl"_hash, + aarch64::matmul::sgemm_8x12, float, float); + +/* ===================== F32K4X16X1 algo ===================== */ + +bool MatrixMulImpl::AlgoF32K4x16x1::usable( + const KernSizeParam& kern_size_param) const { + return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && + kern_size_param.B_type == kern_size_param.A_type && + kern_size_param.C_type == kern_size_param.A_type && + kern_size_param.A_type == dtype::Float32() && + kern_size_param.format == param::MatrixMul::Format::DEFAULT; +} + +size_t MatrixMulImpl::AlgoF32K4x16x1::get_workspace( + const KernSizeParam& kern_size_param) const { + MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, + midout_iv("AlgoF32K4x16x1::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, + K = kern_size_param.K; + auto trA = kern_size_param.trA, trB = kern_size_param.trB; + auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, + C_type = kern_size_param.C_type; + aarch64::matmul::sgemm_4x16 strategy(M, N, K, A_type, B_type, C_type); + return megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .get_workspace_size(); + } + MIDOUT_END(); +} + +MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32K4x16x1::get_kern( + const KernSizeParam&) const { + auto f32_kern_4x16 = [](const MatrixMulImpl::KernParam& kern_param) { + MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, + midout_iv("AlgoF32K4x16x1::get_kern"_hash)) { + auto M = kern_param.M, N = kern_param.N, K = kern_param.K; + auto trA = kern_param.trA, trB = kern_param.trB; + auto LDA = kern_param.LDA, LDB = kern_param.LDB, + LDC = kern_param.LDC; + auto A_type = kern_param.A_type, B_type = kern_param.B_type, + C_type = kern_param.C_type; + const auto Aptr = kern_param.A(), + Bptr = kern_param.B(); + auto Cptr = kern_param.C(); + + aarch64::matmul::sgemm_4x16 strategy(M, N, K, A_type, B_type, + C_type); + megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, + kern_param.workspace_ptr); + } + MIDOUT_END(); + }; + return f32_kern_4x16; +} +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32K4x16x1, megdnn_aarch64_matmul_kern, + "AlgoF32K4x16x1Impl"_hash, + aarch64::matmul::sgemm_4x16, float, float); + +/* ===================== F32MK4_4x16 algo ===================== */ + +bool MatrixMulImpl::AlgoF32MK4_4x16::usable( + const KernSizeParam& kern_size_param) const { + return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && + kern_size_param.C_type == dtype::Float32() && + kern_size_param.B_type == dtype::Float32() && + kern_size_param.A_type == dtype::Float32() && + kern_size_param.format == param::MatrixMul::Format::MK4 && + !kern_size_param.trA && !kern_size_param.trB && + kern_size_param.N % 4 == 0; +} + +size_t MatrixMulImpl::AlgoF32MK4_4x16::get_workspace( + const KernSizeParam& kern_size_param) const { + MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, + midout_iv("AlgoF32MK4_4x16::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, + K = kern_size_param.K; + auto trA = kern_size_param.trA, trB = kern_size_param.trB; + auto A_type = kern_size_param.A_type, B_type = 
kern_size_param.B_type, + C_type = kern_size_param.C_type; + aarch64::matmul::sgemm_nopack_4x16 strategy(A_type, B_type, C_type); + return megdnn::matmul::GemmInterleaved< + aarch64::matmul::sgemm_nopack_4x16, false>(M, N, K, trA, + trB, strategy) + .get_workspace_size(); + } + MIDOUT_END(); +} + +MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32MK4_4x16::get_kern( + const KernSizeParam&) const { + auto f32_kern_mk4_4x16 = [](const MatrixMulImpl::KernParam& kern_param) { + MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, + midout_iv("AlgoF32MK4_4x16::get_kern"_hash)) { + auto M = kern_param.M, N = kern_param.N, K = kern_param.K; + auto trA = kern_param.trA, trB = kern_param.trB; + auto LDA = kern_param.LDA, LDB = kern_param.LDB, + LDC = kern_param.LDC; + auto A_type = kern_param.A_type, B_type = kern_param.B_type, + C_type = kern_param.C_type; + const auto Aptr = kern_param.A(), + Bptr = kern_param.B(); + auto Cptr = kern_param.C(); + + aarch64::matmul::sgemm_nopack_4x16 strategy(A_type, B_type, C_type); + megdnn::matmul::GemmInterleaved(M, N, K, trA, trB, strategy) + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, + kern_param.workspace_ptr); + } + MIDOUT_END(); + }; + return f32_kern_mk4_4x16; +} + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +/* ===================== F16 K8x24x1 algo ===================== */ +namespace { +void f16_kern(const MatrixMulImpl::KernParam& kern_param) { + MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, midout_iv("f16_kern"_hash)) { + auto M = kern_param.M, N = kern_param.N, K = kern_param.K; + auto trA = kern_param.trA, trB = kern_param.trB; + auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; + auto A_type = kern_param.A_type, B_type = kern_param.B_type, + C_type = kern_param.C_type; + const auto Aptr = kern_param.A(), + Bptr = kern_param.B(); + auto Cptr = kern_param.C(); + + aarch64::matmul::hgemm_8x24 strategy(M, N, K, A_type, B_type, C_type); + megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, + kern_param.workspace_ptr); + } + MIDOUT_END(); +} +} // anonymous namespace + +bool MatrixMulImpl::AlgoF16K8x24x1::usable( + const KernSizeParam& kern_size_param) const { + return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && + kern_size_param.format == param::MatrixMul::Format::DEFAULT && + kern_size_param.C_type == kern_size_param.A_type && + kern_size_param.B_type == kern_size_param.A_type && + kern_size_param.A_type == dtype::Float16(); +} + +size_t MatrixMulImpl::AlgoF16K8x24x1::get_workspace( + const KernSizeParam& kern_size_param) const { + MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, + midout_iv("AlgoF16K8x24x1::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, + K = kern_size_param.K; + auto trA = kern_size_param.trA, trB = kern_size_param.trB; + auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, + C_type = kern_size_param.C_type; + aarch64::matmul::hgemm_8x24 strategy(M, N, K, A_type, B_type, C_type); + return megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .get_workspace_size(); + } + MIDOUT_END(); +} + +MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16K8x24x1::get_kern( + const KernSizeParam&) const { + return f16_kern; +} + +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF16K8x24x1, megdnn_aarch64_matmul_kern, + "AlogF16K8x24x1Impl"_hash, + aarch64::matmul::hgemm_8x24, dt_float16, + dt_float16); +/* ===================== F16_MK8_8x8 algo ===================== */ + +bool MatrixMulImpl::AlgoF16MK8_8x8::usable( + const KernSizeParam& 
kern_size_param) const { + return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && + kern_size_param.C_type == kern_size_param.A_type && + kern_size_param.B_type == kern_size_param.A_type && + kern_size_param.A_type == dtype::Float16() && + kern_size_param.format == param::MatrixMul::Format::MK8 && + !kern_size_param.trA && !kern_size_param.trB && + kern_size_param.N % 4 == 0; +} + +size_t MatrixMulImpl::AlgoF16MK8_8x8::get_workspace( + const KernSizeParam& kern_size_param) const { + MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, + midout_iv("AlgoF16MK8_8x8::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, + K = kern_size_param.K; + auto trA = kern_size_param.trA, trB = kern_size_param.trB; + auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, + C_type = kern_size_param.C_type; + aarch64::matmul::gemm_nopack_f16_8x8 strategy(A_type, B_type, C_type); + return megdnn::matmul::GemmInterleaved< + aarch64::matmul::gemm_nopack_f16_8x8, false>( + M, N, K, trA, trB, strategy) + .get_workspace_size(); + } + MIDOUT_END(); +} + +MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16MK8_8x8::get_kern( + const KernSizeParam&) const { + auto kern_mk8_8x8 = [](const MatrixMulImpl::KernParam& kern_param) { + MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, + midout_iv("AlgoF16MK8_8x8::get_kern"_hash)) { + auto M = kern_param.M, N = kern_param.N, K = kern_param.K; + auto trA = kern_param.trA, trB = kern_param.trB; + auto LDA = kern_param.LDA, LDB = kern_param.LDB, + LDC = kern_param.LDC; + auto A_type = kern_param.A_type, B_type = kern_param.B_type, + C_type = kern_param.C_type; + const auto Aptr = kern_param.A(), + Bptr = kern_param.B(); + auto Cptr = kern_param.C(); + + aarch64::matmul::gemm_nopack_f16_8x8 strategy(A_type, B_type, + C_type); + megdnn::matmul::GemmInterleaved< + aarch64::matmul::gemm_nopack_f16_8x8, false>(M, N, K, trA, + trB, strategy) + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, + kern_param.workspace_ptr); + } + MIDOUT_END(); + }; + return kern_mk8_8x8; +} + +#endif + +#if __ARM_FEATURE_DOTPROD +/* ==================== Int8x8x32 K8x12x4 Dotprod algo ==================== */ +namespace { +void int8x8x32_k8x12x4_dotprod_kern( + const MatrixMulImpl::KernParam& kern_param) { + MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, + midout_iv("int8x8x32_k8x12x4_dotprod_kern"_hash)) { + auto M = kern_param.M, N = kern_param.N, K = kern_param.K; + auto trA = kern_param.trA, trB = kern_param.trB; + auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; + auto A_type = kern_param.A_type, B_type = kern_param.B_type, + C_type = kern_param.C_type; + const auto Aptr = kern_param.A(), + Bptr = kern_param.B(); + auto Cptr = kern_param.C(); + + aarch64::matmul::gemm_s8_8x12 strategy(M, N, K, A_type, B_type, C_type); + megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, + kern_param.workspace_ptr); + } + MIDOUT_END(); +} +} // anonymous namespace + +bool MatrixMulImpl::AlgoInt8x8x32K8x12x4DotProd::usable( + const KernSizeParam& kern_size_param) const { + return can_be_treated_as_int8x8x32(kern_size_param); +} + +size_t MatrixMulImpl::AlgoInt8x8x32K8x12x4DotProd::get_workspace( + const KernSizeParam& kern_size_param) const { + MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, + midout_iv("AlgoInt8x8x32K8x12x4DotProd::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, + K = kern_size_param.K; + auto trA = kern_size_param.trA, trB = kern_size_param.trB; + auto A_type = kern_size_param.A_type, 
B_type = kern_size_param.B_type, + C_type = kern_size_param.C_type; + + aarch64::matmul::gemm_s8_8x12 strategy(M, N, K, A_type, B_type, C_type); + return megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .get_workspace_size(); + } + MIDOUT_END(); +} + +MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32K8x12x4DotProd::get_kern( + const KernSizeParam&) const { + return int8x8x32_k8x12x4_dotprod_kern; +} + +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K8x12x4DotProd, + megdnn_aarch64_matmul_kern, + "AlgoInt8x8x32K8x12x4DotProdImpl"_hash, + aarch64::matmul::gemm_s8_8x12, int8_t, + int32_t); +/* ===================== Int8x8x32 Gemv DotProd algo ===================== */ +namespace { +void int8x8x32_gemv_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) { + auto M = kern_param.M, N = kern_param.N, K = kern_param.K; + auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; + const auto Aptr = kern_param.A(), Bptr = kern_param.B(); + auto Cptr = kern_param.C(); + aarch64::matmul::gemv_like_int8(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC); +} +} // anonymous namespace + +bool MatrixMulImpl::AlgoInt8x8x32GemvDotProd::usable( + const KernSizeParam& kern_size_param) const { + return can_be_treated_as_int8x8x32(kern_size_param) && + !kern_size_param.trA && !kern_size_param.trB && + kern_size_param.N == 1 && kern_size_param.LDB == 1; +} + +bool MatrixMulImpl::AlgoInt8x8x32GemvDotProd::preferred( + const KernSizeParam& kern_size_param) const { + auto N = kern_size_param.N, LDB = kern_size_param.LDB; + return (N == 1 && LDB == 1); +} + +MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32GemvDotProd::get_kern( + const KernSizeParam&) const { + MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, + midout_iv("AlgoInt8x8x32GemvDotProd::get_kern"_hash)) { + return int8x8x32_gemv_dotprod_kern; + } + MIDOUT_END(); + return nullptr; +} +#else + +/* ===================== Int8x8x32 MK4 4x4x16 algo ===================== */ +namespace { +void int8x8x32_mk4_4x4x16_kern(const MatrixMulImpl::KernParam& kern_param) { + MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, + midout_iv("int8x8x32_mk4_4x4x16_kern"_hash)) { + auto M = kern_param.M, N = kern_param.N, K = kern_param.K; + auto trA = kern_param.trA, trB = kern_param.trB; + auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; + auto A_type = kern_param.A_type, B_type = kern_param.B_type, + C_type = kern_param.C_type; + const auto Aptr = kern_param.A(), + Bptr = kern_param.B(); + auto Cptr = kern_param.C(); + aarch64::matmul::gemm_mk4_s8_4x4 strategy(M, N, K, A_type, B_type, + C_type); + megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, + kern_param.workspace_ptr); + } + MIDOUT_END(); +} +} // anonymous namespace + +bool MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16::usable( + const KernSizeParam& param) const { + return param.A_type.enumv() == param.B_type.enumv() && + (param.A_type.enumv() == DTypeEnum::Int8 || + param.A_type.enumv() == DTypeEnum::QuantizedS8) && + (param.C_type.enumv() == DTypeEnum::Int32 || + param.C_type.enumv() == DTypeEnum::QuantizedS32) && + param.compute_mode == Param::ComputeMode::DEFAULT && + param.format == param::MatrixMul::Format::MK4 && !param.trA && + !param.trB && param.M % 4 == 0 && param.K % 4 == 0; +} + +bool MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16::preferred( + const KernSizeParam& kern_size_param) const { + return kern_size_param.K > 16; +} + +size_t MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16::get_workspace( + const KernSizeParam& kern_size_param) 
const { + MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, + midout_iv("AlgoInt8x8x32MK4_4x4x16::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, + K = kern_size_param.K; + auto trA = kern_size_param.trA, trB = kern_size_param.trB; + auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, + C_type = kern_size_param.C_type; + aarch64::matmul::gemm_mk4_s8_4x4 strategy(M, N, K, A_type, B_type, + C_type); + return megdnn::matmul::GemmInterleaved< + aarch64::matmul::gemm_mk4_s8_4x4>(M, N, K, trA, trB, + strategy) + .get_workspace_size(); + } + MIDOUT_END(); +} + +MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16::get_kern( + const KernSizeParam&) const { + return int8x8x32_mk4_4x4x16_kern; +} + +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32MK4_4x4x16, + megdnn_aarch64_matmul_kern, + "AlgoInt8x8x32MK4_4x4x16Impl"_hash, + aarch64::matmul::gemm_mk4_s8_4x4, int8_t, + int32_t); + +/* ===================== Int8x8x32 K4x4x16 algo ===================== */ +namespace { +void int8x8x32_k4x4x16_kern(const MatrixMulImpl::KernParam& kern_param) { + MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, + midout_iv("int8x8x32_k4x4x16_kern"_hash)) { + auto M = kern_param.M, N = kern_param.N, K = kern_param.K; + auto trA = kern_param.trA, trB = kern_param.trB; + auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; + auto A_type = kern_param.A_type, B_type = kern_param.B_type, + C_type = kern_param.C_type; + const auto Aptr = kern_param.A(), + Bptr = kern_param.B(); + auto Cptr = kern_param.C(); + + aarch64::matmul::gemm_s8_4x4 strategy(M, N, K, A_type, B_type, C_type); + megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, + kern_param.workspace_ptr); + } + MIDOUT_END(); +} +} // anonymous namespace + +bool MatrixMulImpl::AlgoInt8x8x32K4x4x16::usable( + const KernSizeParam& kern_size_param) const { + return can_be_treated_as_int8x8x32(kern_size_param); +} + +bool MatrixMulImpl::AlgoInt8x8x32K4x4x16::preferred( + const KernSizeParam& kern_size_param) const { + return kern_size_param.K > 16; +} + +size_t MatrixMulImpl::AlgoInt8x8x32K4x4x16::get_workspace( + const KernSizeParam& kern_size_param) const { + MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, + midout_iv("AlgoInt8x8x32K4x4x16::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, + K = kern_size_param.K; + auto trA = kern_size_param.trA, trB = kern_size_param.trB; + auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, + C_type = kern_size_param.C_type; + aarch64::matmul::gemm_s8_4x4 strategy(M, N, K, A_type, B_type, C_type); + return megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .get_workspace_size(); + } + MIDOUT_END(); +} + +MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32K4x4x16::get_kern( + const KernSizeParam&) const { + return int8x8x32_k4x4x16_kern; +} + +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K4x4x16, + megdnn_aarch64_matmul_kern, + "AlgoInt8x8x32K4x4x16Impl"_hash, + aarch64::matmul::gemm_s8_4x4, int8_t, + int32_t); +/* ===================== Int8x8x32 K8x8x8 algo ===================== */ +namespace { +void int8x8x32_k8x8x8_kern(const MatrixMulImpl::KernParam& kern_param) { + MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, + midout_iv("int8x8x32_k8x8x8_kern"_hash)) { + auto M = kern_param.M, N = kern_param.N, K = kern_param.K; + auto trA = kern_param.trA, trB = kern_param.trB; + auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; + auto A_type = 
kern_param.A_type, B_type = kern_param.B_type, + C_type = kern_param.C_type; + const auto Aptr = kern_param.A(), + Bptr = kern_param.B(); + auto Cptr = kern_param.C(); + + aarch64::matmul::gemm_s8_8x8 strategy(M, N, K, A_type, B_type, C_type); + megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, + kern_param.workspace_ptr); + } + MIDOUT_END(); +} +} // anonymous namespace + +bool MatrixMulImpl::AlgoInt8x8x32K8x8x8::usable( + const KernSizeParam& kern_size_param) const { + return can_be_treated_as_int8x8x32(kern_size_param); +} + +bool MatrixMulImpl::AlgoInt8x8x32K8x8x8::preferred( + const KernSizeParam& kern_size_param) const { + return kern_size_param.K <= 16; +} + +size_t MatrixMulImpl::AlgoInt8x8x32K8x8x8::get_workspace( + const KernSizeParam& kern_size_param) const { + MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, + midout_iv("AlgoInt8x8x32K8x8x8::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, + K = kern_size_param.K; + auto trA = kern_size_param.trA, trB = kern_size_param.trB; + auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, + C_type = kern_size_param.C_type; + aarch64::matmul::gemm_s8_8x8 strategy(M, N, K, A_type, B_type, C_type); + return megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .get_workspace_size(); + } + MIDOUT_END(); +} + +MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32K8x8x8::get_kern( + const KernSizeParam&) const { + return int8x8x32_k8x8x8_kern; +} +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K8x8x8, + megdnn_aarch64_matmul_kern, + "AlgoInt8x8x32K8x8x8Impl"_hash, + aarch64::matmul::gemm_s8_8x8, int8_t, + int32_t); +#endif + +/* ===================== Int8x8x16 K8x8x8 algo ===================== */ +namespace { +void int8x8x16_k8x8x8_kern(const MatrixMulImpl::KernParam& kern_param) { + MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, + midout_iv("int8x8x16_k8x8x8_kern"_hash)) { + auto M = kern_param.M, N = kern_param.N, K = kern_param.K; + auto trA = kern_param.trA, trB = kern_param.trB; + auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; + auto A_type = kern_param.A_type, B_type = kern_param.B_type, + C_type = kern_param.C_type; + const auto Aptr = kern_param.A(), + Bptr = kern_param.B(); + auto Cptr = kern_param.C(); + + aarch64::matmul::gemm_s8x8x16_8x8 strategy(M, N, K, A_type, B_type, + C_type); + megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, + kern_param.workspace_ptr); + } + MIDOUT_END(); +} +} // anonymous namespace + +bool MatrixMulImpl::AlgoInt8x8x16K8x8x8::usable( + const KernSizeParam& kern_size_param) const { + return can_be_treated_as_int8x8x16(kern_size_param); +} + +bool MatrixMulImpl::AlgoInt8x8x16K8x8x8::preferred( + const KernSizeParam& kern_size_param) const { + return kern_size_param.K <= 16; +} + +size_t MatrixMulImpl::AlgoInt8x8x16K8x8x8::get_workspace( + const KernSizeParam& kern_size_param) const { + MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, + midout_iv("AlgoInt8x8x16K8x8x8::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, + K = kern_size_param.K; + auto trA = kern_size_param.trA, trB = kern_size_param.trB; + auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, + C_type = kern_size_param.C_type; + aarch64::matmul::gemm_s8x8x16_8x8 strategy(M, N, K, A_type, B_type, + C_type); + return megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .get_workspace_size(); + } + MIDOUT_END(); +} + 
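// Illustrative sketch only, not part of the patch above: a plain scalar
// reference for what the "8x8x8"-blocked int8x8x16 strategies compute.
// The real gemm_s8x8x16_8x8 kernel packs A and B and runs NEON code, but per
// output element it reduces an int8 row of A against an int8 column of B and
// stores an int16 result; this hypothetical helper (ref_gemm_s8s8s16 is not a
// name from the codebase) spells that out for row-major A, B and C.
static inline void ref_gemm_s8s8s16(const int8_t* A, const int8_t* B,
                                    int16_t* C, size_t M, size_t N, size_t K,
                                    size_t LDA, size_t LDB, size_t LDC) {
    for (size_t m = 0; m < M; ++m) {
        for (size_t n = 0; n < N; ++n) {
            int16_t acc = 0;
            for (size_t k = 0; k < K; ++k) {
                // int8 * int8 products accumulated into an int16 result,
                // matching the 8x8x16 (A int8, B int8, C int16) convention
                acc += static_cast<int16_t>(A[m * LDA + k]) *
                       static_cast<int16_t>(B[k * LDB + n]);
            }
            C[m * LDC + n] = acc;
        }
    }
}
// A blocked kernel walks this same iteration space in 8x8 output tiles with
// K consumed in chunks of 8, which is presumably where the K8X8X8 suffix in
// the algorithm names above comes from.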
+MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16K8x8x8::get_kern( + const KernSizeParam&) const { + return int8x8x16_k8x8x8_kern; +} + +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16K8x8x8, + megdnn_aarch64_matmul_kern, + "AlgoInt8x8x16K8x8x8Impl"_hash, + aarch64::matmul::gemm_s8x8x16_8x8, int8_t, + int16_t); +/* ===================== Int8x8x16 K4x4x16 algo ===================== */ +namespace { +void int8x8x16_k4x4x16_kern(const MatrixMulImpl::KernParam& kern_param) { + MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, + midout_iv("int8x8x16_k4x4x16_kern"_hash)) { + auto M = kern_param.M, N = kern_param.N, K = kern_param.K; + auto trA = kern_param.trA, trB = kern_param.trB; + auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; + auto A_type = kern_param.A_type, B_type = kern_param.B_type, + C_type = kern_param.C_type; + const auto Aptr = kern_param.A(), + Bptr = kern_param.B(); + auto Cptr = kern_param.C(); + + aarch64::matmul::gemm_s8x8x16_4x4 strategy(M, N, K, A_type, B_type, + C_type); + megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, + kern_param.workspace_ptr); + } + MIDOUT_END(); +} +} // anonymous namespace + +bool MatrixMulImpl::AlgoInt8x8x16K4x4x16::usable( + const KernSizeParam& kern_size_param) const { + return can_be_treated_as_int8x8x16(kern_size_param); +} + +bool MatrixMulImpl::AlgoInt8x8x16K4x4x16::preferred( + const KernSizeParam& kern_size_param) const { + MEGDNN_MARK_USED_VAR(kern_size_param); + return true; +} + +size_t MatrixMulImpl::AlgoInt8x8x16K4x4x16::get_workspace( + const KernSizeParam& kern_size_param) const { + MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, + midout_iv("AlgoInt8x8x16K4x4x16::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, + K = kern_size_param.K; + auto trA = kern_size_param.trA, trB = kern_size_param.trB; + auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, + C_type = kern_size_param.C_type; + aarch64::matmul::gemm_s8x8x16_4x4 strategy(M, N, K, A_type, B_type, + C_type); + return megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .get_workspace_size(); + } + MIDOUT_END(); +} + +MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16K4x4x16::get_kern( + const KernSizeParam&) const { + return int8x8x16_k4x4x16_kern; +} + +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16K4x4x16, + megdnn_aarch64_matmul_kern, + "AlgoInt8x8x16K4x4x16Impl"_hash, + aarch64::matmul::gemm_s8x8x16_4x4, int8_t, + int16_t); + +/* ===================== Int16x16x32 K12x8x1 algo ===================== */ +namespace { +void int16x16x32_k12x8x1_kern(const MatrixMulImpl::KernParam& kern_param) { + MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, + midout_iv("int16x16x32_k12x8x1_kern"_hash)) { + auto M = kern_param.M, N = kern_param.N, K = kern_param.K; + auto trA = kern_param.trA, trB = kern_param.trB; + auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; + auto A_type = kern_param.A_type, B_type = kern_param.B_type, + C_type = kern_param.C_type; + const auto Aptr = kern_param.A(), + Bptr = kern_param.B(); + auto Cptr = kern_param.C(); + + aarch64::matmul::gemm_s16_12x8x1 strategy(M, N, K, A_type, B_type, + C_type); + megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, + kern_param.workspace_ptr); + } + MIDOUT_END(); +} +} // anonymous namespace + +bool MatrixMulImpl::AlgoInt16x16x32K12x8x1::usable( + const KernSizeParam& kern_size_param) const { + return 
kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() && + kern_size_param.format == param::MatrixMul::Format::DEFAULT && + kern_size_param.compute_mode == + param::MatrixMul::ComputeMode::DEFAULT && + kern_size_param.A_type.enumv() == DTypeEnum::Int16 && + kern_size_param.C_type.enumv() == DTypeEnum::Int32; +} + +bool MatrixMulImpl::AlgoInt16x16x32K12x8x1::preferred( + const KernSizeParam& kern_size_param) const { + MEGDNN_MARK_USED_VAR(kern_size_param); + return true; +} + +size_t MatrixMulImpl::AlgoInt16x16x32K12x8x1::get_workspace( + const KernSizeParam& kern_size_param) const { + MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, + midout_iv("AlgoInt16x16x32K12x8x1::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, + K = kern_size_param.K; + auto trA = kern_size_param.trA, trB = kern_size_param.trB; + auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, + C_type = kern_size_param.C_type; + aarch64::matmul::gemm_s16_12x8x1 strategy(M, N, K, A_type, B_type, + C_type); + return megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .get_workspace_size(); + } + MIDOUT_END(); +} + +MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt16x16x32K12x8x1::get_kern( + const KernSizeParam&) const { + return int16x16x32_k12x8x1_kern; +} + +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt16x16x32K12x8x1, + megdnn_aarch64_matmul_kern, + "AlgoInt16x16x32K12x8x1Impl"_hash, + aarch64::matmul::gemm_s16_12x8x1, int16_t, + int32_t); + +/* ===================== Int16x16x32MK8_8x8 algo ===================== */ + +bool MatrixMulImpl::AlgoInt16x16x32MK8_8x8::usable( + const KernSizeParam& kern_size_param) const { + return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && + kern_size_param.C_type == dtype::Int32() && + kern_size_param.B_type == dtype::Int16() && + kern_size_param.A_type == dtype::Int16() && + kern_size_param.format == param::MatrixMul::Format::MK8 && + !kern_size_param.trA && !kern_size_param.trB && + kern_size_param.N % 4 == 0; +} + +size_t MatrixMulImpl::AlgoInt16x16x32MK8_8x8::get_workspace( + const KernSizeParam& kern_size_param) const { + MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, + midout_iv("AlgoInt16x16x32MK8_8x8::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, + K = kern_size_param.K; + auto trA = kern_size_param.trA, trB = kern_size_param.trB; + auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, + C_type = kern_size_param.C_type; + aarch64::matmul::gemm_nopack_s16_8x8 strategy(A_type, B_type, C_type); + return megdnn::matmul::GemmInterleaved< + aarch64::matmul::gemm_nopack_s16_8x8, false>( + M, N, K, trA, trB, strategy) + .get_workspace_size(); + } + MIDOUT_END(); +} + +MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt16x16x32MK8_8x8::get_kern( + const KernSizeParam&) const { + auto kern_mk8_8x8 = [](const MatrixMulImpl::KernParam& kern_param) { + MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, + midout_iv("AlgoInt16x16x32MK8_8x8::get_kern"_hash)) { + auto M = kern_param.M, N = kern_param.N, K = kern_param.K; + auto trA = kern_param.trA, trB = kern_param.trB; + auto LDA = kern_param.LDA, LDB = kern_param.LDB, + LDC = kern_param.LDC; + auto A_type = kern_param.A_type, B_type = kern_param.B_type, + C_type = kern_param.C_type; + const auto Aptr = kern_param.A(), + Bptr = kern_param.B(); + auto Cptr = kern_param.C(); + + aarch64::matmul::gemm_nopack_s16_8x8 strategy(A_type, B_type, + C_type); + megdnn::matmul::GemmInterleaved< + aarch64::matmul::gemm_nopack_s16_8x8, false>(M, N, K, trA, + trB, 
strategy) + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, + kern_param.workspace_ptr); + } + MIDOUT_END(); + }; + return kern_mk8_8x8; +} + +#if __ARM_FEATURE_DOTPROD +/* ==================== Quint8 K8x8x4 Dotprod algo ==================== */ +namespace { +void quint8_k8x8x4_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) { + MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, + midout_iv("quint8_k8x8x4_dotprod_kern"_hash)) { + auto M = kern_param.M, N = kern_param.N, K = kern_param.K; + auto trA = kern_param.trA, trB = kern_param.trB; + auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; + auto A_type = kern_param.A_type, B_type = kern_param.B_type, + C_type = kern_param.C_type; + const auto Aptr = kern_param.A(), + Bptr = kern_param.B(); + auto Cptr = kern_param.C(); + + aarch64::matmul::gemm_u8_8x8 strategy(M, N, K, A_type, B_type, C_type); + megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, + kern_param.workspace_ptr); + } + MIDOUT_END(); +} +} // anonymous namespace + +bool MatrixMulImpl::AlgoQuint8K8x8x4DotProd::usable( + const KernSizeParam& kern_size_param) const { + return kern_size_param.A_type.enumv() == DTypeEnum::Quantized8Asymm && + kern_size_param.B_type.enumv() == DTypeEnum::Quantized8Asymm && + kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32 && + kern_size_param.format == param::MatrixMul::Format::DEFAULT && + kern_size_param.compute_mode == Param::ComputeMode::DEFAULT; +} + +size_t MatrixMulImpl::AlgoQuint8K8x8x4DotProd::get_workspace( + const KernSizeParam& kern_size_param) const { + MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, + midout_iv("AlgoQuint8K8x8x4DotProd::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, + K = kern_size_param.K; + auto trA = kern_size_param.trA, trB = kern_size_param.trB; + auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, + C_type = kern_size_param.C_type; + + aarch64::matmul::gemm_u8_8x8 strategy(M, N, K, A_type, B_type, C_type); + return megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .get_workspace_size(); + } + MIDOUT_END(); +} + +MatrixMulImpl::kern_t MatrixMulImpl::AlgoQuint8K8x8x4DotProd::get_kern( + const KernSizeParam&) const { + return quint8_k8x8x4_dotprod_kern; +} + +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8K8x8x4DotProd, + megdnn_aarch64_matmul_kern, + "AlgoQuint8K8x8x4DotProdImpl"_hash, + aarch64::matmul::gemm_u8_8x8, uint8_t, + int32_t); +/* ===================== Quint8 Gemv DotProd algo ===================== */ +namespace { +void quint8_gemv_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) { + MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, + midout_iv("quint8_gemv_dotprod_kern"_hash)) { + auto M = kern_param.M, N = kern_param.N, K = kern_param.K; + auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; + const auto Aptr = kern_param.A(), + Bptr = kern_param.B(); + auto Cptr = kern_param.C(); + auto A_type = kern_param.A_type, B_type = kern_param.B_type; + + aarch64::matmul::gemv_like_quint8( + Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC, + A_type.param().zero_point, + B_type.param().zero_point); + } + MIDOUT_END(); +} +} // anonymous namespace + +bool MatrixMulImpl::AlgoQuint8GemvDotProd::usable( + const KernSizeParam& kern_size_param) const { + return kern_size_param.A_type.enumv() == DTypeEnum::Quantized8Asymm && + kern_size_param.B_type.enumv() == DTypeEnum::Quantized8Asymm && + kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32 && + 
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && + kern_size_param.format == param::MatrixMul::Format::DEFAULT && + !kern_size_param.trA && !kern_size_param.trB && + kern_size_param.N == 1 && kern_size_param.LDB == 1; +} + +bool MatrixMulImpl::AlgoQuint8GemvDotProd::preferred( + const KernSizeParam& kern_size_param) const { + auto N = kern_size_param.N, LDB = kern_size_param.LDB; + return (N == 1 && LDB == 1); +} + +MatrixMulImpl::kern_t MatrixMulImpl::AlgoQuint8GemvDotProd::get_kern( + const KernSizeParam&) const { + return quint8_gemv_dotprod_kern; +} +#else + +/* ===================== Quint8 K8x8x8 algo ===================== */ +namespace { +void quint8_k8x8x8_kern(const MatrixMulImpl::KernParam& kern_param) { + MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, + midout_iv("quint8_gemv_dotprod_kern"_hash)) { + auto M = kern_param.M, N = kern_param.N, K = kern_param.K; + auto trA = kern_param.trA, trB = kern_param.trB; + auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; + auto A_type = kern_param.A_type, B_type = kern_param.B_type, + C_type = kern_param.C_type; + const auto Aptr = kern_param.A(), + Bptr = kern_param.B(); + auto Cptr = kern_param.C(); + + aarch64::matmul::gemm_u8_8x8 strategy(M, N, K, A_type, B_type, C_type); + megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, + kern_param.workspace_ptr); + } + MIDOUT_END(); +} +} // anonymous namespace + +bool MatrixMulImpl::AlgoQuint8K8x8x8::usable( + const KernSizeParam& kern_size_param) const { + return kern_size_param.A_type.enumv() == DTypeEnum::Quantized8Asymm && + kern_size_param.B_type.enumv() == DTypeEnum::Quantized8Asymm && + kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32 && + kern_size_param.format == param::MatrixMul::Format::DEFAULT && + kern_size_param.compute_mode == Param::ComputeMode::DEFAULT; +} + +size_t MatrixMulImpl::AlgoQuint8K8x8x8::get_workspace( + const KernSizeParam& kern_size_param) const { + MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, + midout_iv("AlgoQuint8K8x8x8::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, + K = kern_size_param.K; + auto trA = kern_size_param.trA, trB = kern_size_param.trB; + auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, + C_type = kern_size_param.C_type; + + aarch64::matmul::gemm_u8_8x8 strategy(M, N, K, A_type, B_type, C_type); + return megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .get_workspace_size(); + } + MIDOUT_END(); +} + +MatrixMulImpl::kern_t MatrixMulImpl::AlgoQuint8K8x8x8::get_kern( + const KernSizeParam&) const { + return quint8_k8x8x8_kern; +} + +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8K8x8x8, + megdnn_aarch64_matmul_kern, + "AlgoQuint8K8x8x8Impl"_hash, + aarch64::matmul::gemm_u8_8x8, uint8_t, + int32_t); +#endif + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/algos.h b/dnn/src/aarch64/matrix_mul/algos.h new file mode 100644 index 00000000..266f1e47 --- /dev/null +++ b/dnn/src/aarch64/matrix_mul/algos.h @@ -0,0 +1,250 @@ +/** + * \file dnn/src/aarch64/matrix_mul/algos.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#pragma once + +#include "src/aarch64/matrix_mul/opr_impl.h" +#include "src/arm_common/matrix_mul/algos.h" +#include "src/fallback/matrix_mul/gemm_common.h" + +namespace megdnn { +namespace aarch64 { + +class MatrixMulImpl::AlgoF32K8x12x1 final : public AlgoBase { +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "AARCH64_F32K8X12X1"; } + bool usable(const KernSizeParam&) const override; + size_t get_workspace(const KernSizeParam&) const override; + kern_t get_kern(const KernSizeParam&) const override; + void* type() const override { return sm_arm_common_algo_type; } + MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); +}; + +class MatrixMulImpl::AlgoF32K4x16x1 final : public AlgoBase { +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "AARCH64_F32K4X16X1"; } + bool usable(const KernSizeParam&) const override; + size_t get_workspace(const KernSizeParam&) const override; + kern_t get_kern(const KernSizeParam&) const override; + void* type() const override { return sm_arm_common_algo_type; } + MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); +}; + +class MatrixMulImpl::AlgoF32MK4_4x16 final : public AlgoBase { +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "AARCH64_F32_MK4_4x16"; } + bool usable(const KernSizeParam&) const override; + size_t get_workspace(const KernSizeParam&) const override; + kern_t get_kern(const KernSizeParam&) const override; + void* type() const override { return sm_arm_common_algo_type; } + PackMode packmode() const override { return PackMode::NO_PACK; } +}; + +class MatrixMulImpl::AlgoF32Gemv final + : public arm_common::MatrixMulImpl::AlgoF32Gemv {}; + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +class MatrixMulImpl::AlgoF16K8x24x1 final : public AlgoBase { +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "AARCH64_F16_K8X24X1"; } + bool usable(const KernSizeParam&) const override; + size_t get_workspace(const KernSizeParam&) const override; + kern_t get_kern(const KernSizeParam&) const override; + void* type() const override { return sm_arm_common_algo_type; } + MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); +}; + +class MatrixMulImpl::AlgoF16MK8_8x8 final : public AlgoBase { +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "AARCH64_F16_MK8_8X8"; } + bool usable(const KernSizeParam&) const override; + size_t get_workspace(const KernSizeParam&) const override; + kern_t get_kern(const KernSizeParam&) const override; + void* type() const override { return sm_arm_common_algo_type; } + PackMode packmode() const override { return PackMode::NO_PACK; } +}; + +#endif + +#if __ARM_FEATURE_DOTPROD +class MatrixMulImpl::AlgoInt8x8x32K8x12x4DotProd final : public AlgoBase { +public: + bool is_reproducible() const override { return true; } + const char* name() const override { + return "AARCH64_INT8X8X32_K8X12X4_DOTPROD"; + } + bool usable(const KernSizeParam&) const override; + size_t get_workspace(const KernSizeParam&) const override; + kern_t get_kern(const KernSizeParam&) const override; + void* type() const override { return sm_arm_common_algo_type; } + MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); +}; + +class MatrixMulImpl::AlgoInt8x8x32GemvDotProd final : public AlgoBase { +public: + bool is_reproducible() const override { return true; } + const char* name() const override { + return 
"AARCH64_INT8X8X32_GEMV_DOTPROD"; + } + bool usable(const KernSizeParam&) const override; + bool preferred(const KernSizeParam&) const override; + size_t get_workspace(const KernSizeParam&) const override { return 0; } + kern_t get_kern(const KernSizeParam&) const override; + void* type() const override { return sm_arm_common_algo_type; } + AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } + PackMode packmode() const override { return PackMode::NO_PACK; } +}; +#else + +class MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16 final : public AlgoBase { + +public: + bool is_reproducible() const override { return true; } + const char* name() const override { + return "AARCH64_INT8X8X32_MK4_4X4X16"; + } + bool usable(const KernSizeParam&) const override; + bool preferred(const KernSizeParam&) const override; + size_t get_workspace(const KernSizeParam&) const override; + kern_t get_kern(const KernSizeParam&) const override; + void* type() const override { return sm_arm_common_algo_type; } + PackMode packmode() const override { return PackMode::DEFAULT; } + + MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); +}; + +class MatrixMulImpl::AlgoInt8x8x32K4x4x16 final : public AlgoBase { +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "AARCH64_INT8X8X32_K4X4X16"; } + bool usable(const KernSizeParam&) const override; + bool preferred(const KernSizeParam&) const override; + size_t get_workspace(const KernSizeParam&) const override; + kern_t get_kern(const KernSizeParam&) const override; + void* type() const override { return sm_arm_common_algo_type; } + + MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); +}; + +class MatrixMulImpl::AlgoInt8x8x32K8x8x8 final : public AlgoBase { +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "AARCH64_INT8X8X32_K8X8X8"; } + bool usable(const KernSizeParam&) const override; + bool preferred(const KernSizeParam&) const override; + size_t get_workspace(const KernSizeParam&) const override; + kern_t get_kern(const KernSizeParam&) const override; + void* type() const override { return sm_arm_common_algo_type; } + MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); +}; + +class MatrixMulImpl::AlgoInt8x8x32Gemv final + : public arm_common::MatrixMulImpl::AlgoInt8x8x32Gemv {}; + +#endif + +class MatrixMulImpl::AlgoInt8x8x16K8x8x8 final : public AlgoBase { +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "AARCH64_INT8X8X16_K8X8X8"; } + bool usable(const KernSizeParam&) const override; + bool preferred(const KernSizeParam&) const override; + size_t get_workspace(const KernSizeParam&) const override; + kern_t get_kern(const KernSizeParam&) const override; + void* type() const override { return sm_arm_common_algo_type; } + + MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); +}; + +class MatrixMulImpl::AlgoInt8x8x16K4x4x16 final : public AlgoBase { +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "AARCH64_INT8X8X16_K4X4X16"; } + bool usable(const KernSizeParam&) const override; + bool preferred(const KernSizeParam&) const override; + size_t get_workspace(const KernSizeParam&) const override; + kern_t get_kern(const KernSizeParam&) const override; + void* type() const override { return sm_arm_common_algo_type; } + MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); +}; + +class MatrixMulImpl::AlgoInt16x16x32K12x8x1 final : public AlgoBase { +public: + bool is_reproducible() const override { return true; } + 
const char* name() const override { return "AARCH64_INT16X16X32_K12X8X1"; }
+    bool usable(const KernSizeParam&) const override;
+    bool preferred(const KernSizeParam&) const override;
+    size_t get_workspace(const KernSizeParam&) const override;
+    kern_t get_kern(const KernSizeParam&) const override;
+    void* type() const override { return sm_arm_common_algo_type; }
+    MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
+};
+
+class MatrixMulImpl::AlgoInt16x16x32MK8_8x8 final : public AlgoBase {
+public:
+    bool is_reproducible() const override { return true; }
+    const char* name() const override { return "AARCH64_INT16X16X32_MK8_8X8"; }
+    bool usable(const KernSizeParam&) const override;
+    size_t get_workspace(const KernSizeParam&) const override;
+    kern_t get_kern(const KernSizeParam&) const override;
+    void* type() const override { return sm_arm_common_algo_type; }
+    PackMode packmode() const override { return PackMode::NO_PACK; }
+};
+
+#if __ARM_FEATURE_DOTPROD
+class MatrixMulImpl::AlgoQuint8K8x8x4DotProd final : public AlgoBase {
+public:
+    bool is_reproducible() const override { return true; }
+    const char* name() const override {
+        return "AARCH64_QUINT8_K8X8X4_DOTPROD";
+    }
+    bool usable(const KernSizeParam&) const override;
+    size_t get_workspace(const KernSizeParam&) const override;
+    kern_t get_kern(const KernSizeParam&) const override;
+    void* type() const override { return sm_arm_common_algo_type; }
+    MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
+};
+
+class MatrixMulImpl::AlgoQuint8GemvDotProd final : public AlgoBase {
+public:
+    bool is_reproducible() const override { return true; }
+    const char* name() const override { return "AARCH64_QUINT8_GEMV_DOTPROD"; }
+    bool usable(const KernSizeParam&) const override;
+    bool preferred(const KernSizeParam&) const override;
+    size_t get_workspace(const KernSizeParam&) const override { return 0; }
+    kern_t get_kern(const KernSizeParam&) const override;
+    void* type() const override { return sm_arm_common_algo_type; }
+    AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
+    PackMode packmode() const override { return PackMode::NO_PACK; }
+};
+#else
+
+class MatrixMulImpl::AlgoQuint8K8x8x8 final : public AlgoBase {
+public:
+    bool is_reproducible() const override { return true; }
+    const char* name() const override { return "AARCH64_QUINT8_K8X8X8"; }
+    bool usable(const KernSizeParam&) const override;
+    size_t get_workspace(const KernSizeParam&) const override;
+    kern_t get_kern(const KernSizeParam&) const override;
+    void* type() const override { return sm_arm_common_algo_type; }
+    MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
+};
+#endif
+
+} // namespace aarch64
+} // namespace megdnn
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/aarch64/matrix_mul/asm/common.h b/dnn/src/aarch64/matrix_mul/asm/common.h
new file mode 100644
index 00000000..fc9fa537
--- /dev/null
+++ b/dnn/src/aarch64/matrix_mul/asm/common.h
@@ -0,0 +1,1888 @@
+/**
+ * \file dnn/src/aarch64/matrix_mul/asm/common.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
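// [Editorial sketch, not part of the patch] The header introduced below provides the
// packing helpers used by the aarch64 GEMM kernels; its prefetch_6x ... prefetch_1x
// functions issue AArch64 "PRFM PLDL1KEEP" prefetches for consecutive 64-byte cache
// lines ahead of the packing loops. A portable approximation of that idea, written
// with the GCC/Clang __builtin_prefetch intrinsic instead of inline assembly (an
// assumption for illustration only; the patch itself uses raw PRFM instructions):
static inline void prefetch_lines_sketch(const void* p, int nr_lines) {
    const char* cp = static_cast<const char*>(p);
    for (int i = 0; i < nr_lines; ++i) {
        // rw = 0 (read), locality = 3 (keep in L1), one 64-byte line per step
        __builtin_prefetch(cp + 64 * i, 0, 3);
    }
}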
+#pragma once
+#include <algorithm>
+#include <cstdint>
+#include <type_traits>
+#include "src/arm_common/simd_macro/marm_neon.h"
+#include "src/common/utils.h"
+#include "src/fallback/conv_bias/common.h"
+
+namespace megdnn {
+namespace aarch64 {
+
+/* ======================== Prefetch ======================== */
+#define ASM_PREFETCH(address) "PRFM PLDL1KEEP, " address "\n"
+#define ASM_PREFETCHL2(address) "PRFM PLDL2KEEP, " address "\n"
+#define ASM_PREFETCHW(address) "PRFM PSTL1KEEP, " address "\n"
+#define ASM_PREFETCHWL2(address) "PRFM PSTL2KEEP, " address "\n"
+
+static inline void prefetch_6x(const void* pfp) {
+    // clang-format off
+    asm volatile(ASM_PREFETCH("[%[pfp]]")
+                 ASM_PREFETCH("[%[pfp], #64]")
+                 ASM_PREFETCH("[%[pfp], #128]")
+                 ASM_PREFETCH("[%[pfp], #192]")
+                 ASM_PREFETCH("[%[pfp], #256]")
+                 ASM_PREFETCH("[%[pfp], #320]")
+                 :
+                 : [pfp] "r"(pfp)
+                 : "memory");
+    // clang-format on
+}
+
+static inline void prefetch_5x(const void* pfp) {
+    // clang-format off
+    asm volatile(ASM_PREFETCH("[%[pfp]]")
+                 ASM_PREFETCH("[%[pfp], #64]")
+                 ASM_PREFETCH("[%[pfp], #128]")
+                 ASM_PREFETCH("[%[pfp], #192]")
+                 ASM_PREFETCH("[%[pfp], #256]")
+                 :
+                 : [pfp] "r"(pfp)
+                 : "memory");
+    // clang-format on
+}
+
+static inline void prefetch_4x(const void* pfp) {
+    // clang-format off
+    asm volatile(ASM_PREFETCH("[%[pfp]]")
+                 ASM_PREFETCH("[%[pfp], #64]")
+                 ASM_PREFETCH("[%[pfp], #128]")
+                 ASM_PREFETCH("[%[pfp], #192]")
+                 :
+                 : [pfp] "r"(pfp)
+                 : "memory");
+    // clang-format on
+}
+
+static inline void prefetch_3x(const void* pfp) {
+    // clang-format off
+    asm volatile(ASM_PREFETCH("[%[pfp]]")
+                 ASM_PREFETCH("[%[pfp], #64]")
+                 ASM_PREFETCH("[%[pfp], #128]")
+                 :
+                 : [pfp] "r"(pfp)
+                 : "memory");
+    // clang-format on
+}
+
+static inline void prefetch_2x(const void* pfp) {
+    // clang-format off
+    asm volatile(ASM_PREFETCH("[%[pfp]]")
+                 ASM_PREFETCH("[%[pfp], #64]")
+                 :
+                 : [pfp] "r"(pfp)
+                 : "memory");
+    // clang-format on
+}
+
+static inline void prefetch_1x(const void* pfp) {
+    // clang-format off
+    asm volatile(ASM_PREFETCH("[%[pfp]]") : : [pfp] "r"(pfp) : "memory");
+    // clang-format on
+}
+
+/* ======================== interleave pack A ======================== */
+
+/**
+ * interleave_INTERLEAVE_UNROLLK_BATCH_type
+ *
+ * BATCH means process BATCH * UNROLL_K cols once, BATCH * sizeof(TYPE) *
+ * UNROLL_K = 16 bytes (128 bits, a vector size).
+ *
+ * the element traversal order:
+ * rep(j, 0, INTERLEAVE) rep(i, 0, UNROLL_K) *outptr++ = inptr[j, i]
+ */
+
+template <typename T>
+static inline void interleave_24x1_8_h_helper(
+        const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3,
+        const T*& inptr4, const T*& inptr5, const T*& inptr6, const T*& inptr7,
+        T*& outptr, int skippf = 0) {
+    static_assert(sizeof(T) == 2, "only support size == 2");
+    asm volatile(
+            // Load up 8 elements (1 vector) from each of 8 sources.
+            "cbnz %w[skippf], 1f\n"
+            ASM_PREFETCH("[%[inptr0], #128]")
+            ASM_PREFETCH("[%[inptr1], #128]")
+            ASM_PREFETCH("[%[inptr2], #128]")
+            ASM_PREFETCH("[%[inptr3], #128]")
+            "1:\n"
+
+            "ldr q0, [%[inptr0]], #16\n"  // q0=A0A1A2A3A4A5A6A7
+            "ldr q4, [%[inptr4]], #16\n"  // q8=E0E1E2E3E4E5E6E7
+            "ldr q2, [%[inptr2]], #16\n"  // q4=C0C1C2C3...
+            "ldr q6, [%[inptr6]], #16\n"
+            "zip1 v8.8h, v0.8h, v4.8h\n"   // q8=A0E0A1E1A2E2A3E3
+            "zip2 v16.8h, v0.8h, v4.8h\n"  // q16=A4E4A5E5A6E6A7E7
+            "zip1 v9.8h, v2.8h, v6.8h\n"   // q9=C0G0C1G1C2G2C3G3
+            "zip2 v17.8h, v2.8h, v6.8h\n"  // q17=C4G4C5G5C6G6C7G7
+            "ldr q1, [%[inptr1]], #16\n"  // q1=B0B1B2B3B4B5B6B7
+            "ldr q5, [%[inptr5]], #16\n"
+            "ldr q3, [%[inptr3]], #16\n"  // q3=D0D1D2D3....
+ "ldr q7, [%[inptr7]], #16\n" + "zip1 v10.8h, v1.8h, v5.8h\n" // q18=B0F0B1F1B2F2B3F3 + "zip2 v18.8h, v1.8h, v5.8h\n" // q18=B4F4B5F5B6F6B7F7 + "zip1 v11.8h, v3.8h, v7.8h\n" // q19=D0H0D1H1D2H2D3H3 + "zip2 v19.8h, v3.8h, v7.8h\n" // q19=D4H4D5H5D6H6D7H7 + + "zip1 v12.8h, v8.8h, v9.8h\n" // q20=A0C0E0G0A1C1E1G1 + "zip2 v20.8h, v8.8h, v9.8h\n" + "zip1 v13.8h, v10.8h, v11.8h\n" // q21=B0D0F0H0B1I1F1H1 + "zip2 v21.8h, v10.8h, v11.8h\n" + + "cbnz %w[skippf], 2f\n" + ASM_PREFETCH("[%[inptr4], #112]") + ASM_PREFETCH("[%[inptr5], #112]") + ASM_PREFETCH("[%[inptr6], #112]") + ASM_PREFETCH("[%[inptr7], #112]") + "2:\n" + + "zip1 v22.8h, v16.8h, v17.8h\n" + "zip2 v30.8h, v16.8h, v17.8h\n" + "zip1 v23.8h, v18.8h, v19.8h\n" + "zip2 v31.8h, v18.8h, v19.8h\n" + + "zip1 v14.8h, v12.8h, v13.8h\n" // q22=A0B0C0D0E0F0G0H0 + "zip2 v15.8h, v12.8h, v13.8h\n" // q23=A1B1C1D1E1F1G1H1 + "str q14, [%[outptr]], #48\n" + "str q15, [%[outptr]], #48\n" + + "zip1 v0.8h, v20.8h, v21.8h\n" + "zip2 v1.8h, v20.8h, v21.8h\n" + "str q0, [%[outptr]], #48\n" + "str q1, [%[outptr]], #48\n" + + "zip1 v2.8h, v22.8h, v23.8h\n" + "zip2 v3.8h, v22.8h, v23.8h\n" + "str q2, [%[outptr]], #48\n" + "str q3, [%[outptr]], #48\n" + + "zip1 v4.8h, v30.8h, v31.8h\n" + "zip2 v5.8h, v30.8h, v31.8h\n" + "str q4, [%[outptr]], #48\n" + "str q5, [%[outptr]], #48\n" + + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), + [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), + [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), + [outptr] "+r"(outptr) + : [skippf] "r"(skippf) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", + "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", + "v31", "cc", "memory"); +} + +template +static inline void interleave_16x1_8_h_helper( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, const T*& inptr6, const T*& inptr7, + T*& outptr, int skippf = 0) { + static_assert(sizeof(T) == 2, "only support size == 2"); + asm volatile( + // Load up 8 elements (1 vector) from each of 8 sources. + "cbnz %w[skippf], 1f\n" + ASM_PREFETCH("[%[inptr0], #128]") + ASM_PREFETCH("[%[inptr1], #128]") + ASM_PREFETCH("[%[inptr2], #128]") + ASM_PREFETCH("[%[inptr3], #128]") + "1:\n" + + "ldr q0, [%[inptr0]], #16\n" // q0=A0A1A2A3A4A5A6A7 + "ldr q4, [%[inptr4]], #16\n" // q8=E0E1E2E3E4E5E6E7 + "ldr q2, [%[inptr2]], #16\n" // q4=C0C1C2C3... + "ldr q6, [%[inptr6]], #16\n" + "zip1 v8.8h, v0.8h, v4.8h\n" // q8=A0E0A1E1A2E2A3E3 + "zip2 v16.8h, v0.8h, v4.8h\n" // q16=A4E4A5E5A6E6A7E7 + "zip1 v9.8h, v2.8h, v6.8h\n" // q9=C0G0C1G1C2G2C3G3 + "zip2 v17.8h, v2.8h, v6.8h\n" // q17=C4G4C5G5C6G6C7G7 + "ldr q1, [%[inptr1]], #16\n" // q1=B0B1B2B3B4B5B6B7 + "ldr q5, [%[inptr5]], #16\n" + "ldr q3, [%[inptr3]], #16\n" // q3=D0D1D2D3.... 
+ "ldr q7, [%[inptr7]], #16\n" + "zip1 v10.8h, v1.8h, v5.8h\n" // q18=B0F0B1F1B2F2B3F3 + "zip2 v18.8h, v1.8h, v5.8h\n" // q18=B4F4B5F5B6F6B7F7 + "zip1 v11.8h, v3.8h, v7.8h\n" // q19=D0H0D1H1D2H2D3H3 + "zip2 v19.8h, v3.8h, v7.8h\n" // q19=D4H4D5H5D6H6D7H7 + + "zip1 v12.8h, v8.8h, v9.8h\n" // q20=A0C0E0G0A1C1E1G1 + "zip2 v20.8h, v8.8h, v9.8h\n" + "zip1 v13.8h, v10.8h, v11.8h\n" // q21=B0D0F0H0B1I1F1H1 + "zip2 v21.8h, v10.8h, v11.8h\n" + + "cbnz %w[skippf], 2f\n" + ASM_PREFETCH("[%[inptr4], #112]") + ASM_PREFETCH("[%[inptr5], #112]") + ASM_PREFETCH("[%[inptr6], #112]") + ASM_PREFETCH("[%[inptr7], #112]") + "2:\n" + + "zip1 v22.8h, v16.8h, v17.8h\n" + "zip2 v30.8h, v16.8h, v17.8h\n" + "zip1 v23.8h, v18.8h, v19.8h\n" + "zip2 v31.8h, v18.8h, v19.8h\n" + + "zip1 v14.8h, v12.8h, v13.8h\n" // q22=A0B0C0D0E0F0G0H0 + "zip2 v15.8h, v12.8h, v13.8h\n" // q23=A1B1C1D1E1F1G1H1 + "str q14, [%[outptr]], #32\n" + "str q15, [%[outptr]], #32\n" + + "zip1 v0.8h, v20.8h, v21.8h\n" + "zip2 v1.8h, v20.8h, v21.8h\n" + "str q0, [%[outptr]], #32\n" + "str q1, [%[outptr]], #32\n" + + "zip1 v2.8h, v22.8h, v23.8h\n" + "zip2 v3.8h, v22.8h, v23.8h\n" + "str q2, [%[outptr]], #32\n" + "str q3, [%[outptr]], #32\n" + + "zip1 v4.8h, v30.8h, v31.8h\n" + "zip2 v5.8h, v30.8h, v31.8h\n" + "str q4, [%[outptr]], #32\n" + "str q5, [%[outptr]], #32\n" + + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), + [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), + [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), + [outptr] "+r"(outptr) + : [skippf] "r"(skippf) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", + "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", + "v31", "cc", "memory"); +} + +template +static inline void interleave_8x1_8_h(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, + const T*& inptr6, const T*& inptr7, + T*& outptr, int skippf = 0) { + static_assert(sizeof(T) == 2, "only support size == 2"); + asm volatile( + // Load up 8 elements (1 vector) from each of 8 sources. + "cbnz %w[skippf], 1f\n" + ASM_PREFETCH("[%[inptr0], #128]") + ASM_PREFETCH("[%[inptr1], #128]") + ASM_PREFETCH("[%[inptr2], #128]") + ASM_PREFETCH("[%[inptr3], #128]") + "1:\n" + + + "ldr q0, [%[inptr0]], #16\n" // q0=A0A1A2A3A4A5A6A7 + "ldr q4, [%[inptr4]], #16\n" // q8=E0E1E2E3E4E5E6E7 + "ldr q2, [%[inptr2]], #16\n" // q4=C0C1C2C3... + "ldr q6, [%[inptr6]], #16\n" + "zip1 v8.8h, v0.8h, v4.8h\n" // q8=A0E0A1E1A2E2A3E3 + "zip2 v16.8h, v0.8h, v4.8h\n" // q16=A4E4A5E5A6E6A7E7 + "zip1 v9.8h, v2.8h, v6.8h\n" // q9=C0G0C1G1C2G2C3G3 + "zip2 v17.8h, v2.8h, v6.8h\n" // q17=C4G4C5G5C6G6C7G7 + "ldr q1, [%[inptr1]], #16\n" // q1=B0B1B2B3B4B5B6B7 + "ldr q5, [%[inptr5]], #16\n" + "ldr q3, [%[inptr3]], #16\n" // q3=D0D1D2D3.... 
+ "ldr q7, [%[inptr7]], #16\n" + "zip1 v10.8h, v1.8h, v5.8h\n" // q18=B0F0B1F1B2F2B3F3 + "zip2 v18.8h, v1.8h, v5.8h\n" // q18=B4F4B5F5B6F6B7F7 + "zip1 v11.8h, v3.8h, v7.8h\n" // q19=D0H0D1H1D2H2D3H3 + "zip2 v19.8h, v3.8h, v7.8h\n" // q19=D4H4D5H5D6H6D7H7 + + "zip1 v12.8h, v8.8h, v9.8h\n" // q20=A0C0E0G0A1C1E1G1 + "zip2 v20.8h, v8.8h, v9.8h\n" + "zip1 v13.8h, v10.8h, v11.8h\n" // q21=B0D0F0H0B1I1F1H1 + "zip2 v21.8h, v10.8h, v11.8h\n" + + "cbnz %w[skippf], 2f\n" + ASM_PREFETCH("[%[inptr4], #112]") + ASM_PREFETCH("[%[inptr5], #112]") + ASM_PREFETCH("[%[inptr6], #112]") + ASM_PREFETCH("[%[inptr7], #112]") + "2:\n" + + "zip1 v22.8h, v16.8h, v17.8h\n" + "zip2 v30.8h, v16.8h, v17.8h\n" + "zip1 v23.8h, v18.8h, v19.8h\n" + "zip2 v31.8h, v18.8h, v19.8h\n" + + "zip1 v14.8h, v12.8h, v13.8h\n" // q22=A0B0C0D0E0F0G0H0 + "zip2 v15.8h, v12.8h, v13.8h\n" // q23=A1B1C1D1E1F1G1H1 + "stp q14, q15, [%[outptr]], #32\n" // Write back first two elements + + "zip1 v0.8h, v20.8h, v21.8h\n" + "zip2 v1.8h, v20.8h, v21.8h\n" + "stp q0, q1, [%[outptr]], #32\n" // Write back next two elements + + "zip1 v2.8h, v22.8h, v23.8h\n" + "zip2 v3.8h, v22.8h, v23.8h\n" + "stp q2, q3, [%[outptr]], #32\n" // Write back next two elements + + "zip1 v4.8h, v30.8h, v31.8h\n" + "zip2 v5.8h, v30.8h, v31.8h\n" + "stp q4, q5, [%[outptr]], #32\n" // Write back last two elements + + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), + [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), + [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), + [outptr] "+r"(outptr) + + : [skippf] "r"(skippf) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", + "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", + "v31", "cc", "memory"); +} + +template +static inline void interleave_4x1_4_h(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, + T*& outptr) { + static_assert(sizeof(T) == 2, "only support size == 2"); + asm volatile( + "ldr d0, [%[inptr0]], #8\n" // d0 = A0A1A2A3 + "ldr d1, [%[inptr1]], #8\n" // d1 = B0B1B2B3 + "ldr d2, [%[inptr2]], #8\n" // d2 = C0C1C2C3 + "ldr d3, [%[inptr3]], #8\n" // d3 = D0D1D2D3 + "zip1 v4.4h, v0.4h, v2.4h\n" // d4 = A0C0A1C1 + "zip2 v8.4h, v0.4h, v2.4h\n" // d8 = A2C2A3C3 + "zip1 v5.4h, v1.4h, v3.4h\n" // d5 = B0D0B1D1 + "zip2 v9.4h, v1.4h, v3.4h\n" // d9 = B2D2B3D3 + + "zip1 v6.4h, v4.4h, v5.4h\n" // d6 = A0B0C0D0 + "zip2 v7.4h, v4.4h, v5.4h\n" // d7 = A1B1C1D1 + "stp d6, d7, [%[outptr]], #16\n" + + "zip1 v10.4h, v8.4h, v9.4h\n" // d10 = A2B2C2D2 + "zip2 v11.4h, v8.4h, v9.4h\n" // d11 = A3B3C3D3 + "stp d10, d11, [%[outptr]], #16\n" + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), + [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), + [outptr] "+r"(outptr) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "memory"); +} + +static inline void interleave_4x1_2_d(const int64_t*& inptr0, + const int64_t*& inptr1, + const int64_t*& inptr2, + const int64_t*& inptr3, + int64_t*& outptr) { + asm volatile( + "ld1 {v0.2d}, [%[inptr0]], #16\n" // d0 = A0A1 + "ld1 {v1.2d}, [%[inptr1]], #16\n" // d1 = B0B1 + "ld1 {v2.2d}, [%[inptr2]], #16\n" // d2 = C0C1 + "ld1 {v3.2d}, [%[inptr3]], #16\n" // d3 = D0D1 + + "zip1 v4.2d, v0.2d, v1.2d\n" // d8 = A0B0 + "zip2 v5.2d, v0.2d, v1.2d\n" // d9 = A1B1 + "zip1 v6.2d, v2.2d, v3.2d\n" // d10 = C0D0 + "zip2 v7.2d, v2.2d, v3.2d\n" // d11 = C1D1 + + "st1 {v4.2d}, [%[outptr]], #16\n" + "st1 {v6.2d}, [%[outptr]], 
#16\n" + "st1 {v5.2d}, [%[outptr]], #16\n" + "st1 {v7.2d}, [%[outptr]], #16\n" + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), + [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), + [outptr] "+r"(outptr) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "cc", "memory"); +} + +static inline void interleave_4x2_2_d(const int64_t*& inptr0, + const int64_t*& inptr1, + const int64_t*& inptr2, + const int64_t*& inptr3, + int64_t*& outptr) { + asm volatile( + "ld1 {v0.2d}, [%[inptr0]], #16\n" // d0 = A0 + "ld1 {v1.2d}, [%[inptr0]], #16\n" // d1 = A1 + "ld1 {v2.2d}, [%[inptr1]], #16\n" // d2 = B0 + "ld1 {v3.2d}, [%[inptr1]], #16\n" // d3 = B1 + "ld1 {v4.2d}, [%[inptr2]], #16\n" // d4 = C0 + "ld1 {v5.2d}, [%[inptr2]], #16\n" // d5 = C1 + "ld1 {v6.2d}, [%[inptr3]], #16\n" // d6 = D0 + "ld1 {v7.2d}, [%[inptr3]], #16\n" // d7 = D1 + + "st1 {v0.2d}, [%[outptr]], #16\n" + "st1 {v2.2d}, [%[outptr]], #16\n" + "st1 {v4.2d}, [%[outptr]], #16\n" + "st1 {v6.2d}, [%[outptr]], #16\n" + "st1 {v1.2d}, [%[outptr]], #16\n" + "st1 {v3.2d}, [%[outptr]], #16\n" + "st1 {v5.2d}, [%[outptr]], #16\n" + "st1 {v7.2d}, [%[outptr]], #16\n" + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), + [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), + [outptr] "+r"(outptr) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "cc", "memory"); +} + +static inline void interleave_12x1_4_s( + const int32_t*& inptr0, const int32_t*& inptr1, const int32_t*& inptr2, + const int32_t*& inptr3, const int32_t*& inptr4, const int32_t*& inptr5, + const int32_t*& inptr6, const int32_t*& inptr7, const int32_t*& inptr8, + const int32_t*& inptr9, const int32_t*& inptr10, + const int32_t*& inptr11, int32_t*& outptr) { + asm volatile( + "ld1 {v0.4s}, [%[inptr0]], #16\n" // d0 = A0A1A2A3 + "ld1 {v1.4s}, [%[inptr1]], #16\n" // d1 = B0B1B2B3 + "ld1 {v2.4s}, [%[inptr2]], #16\n" // d2 = C0C1C2C3 + "ld1 {v3.4s}, [%[inptr3]], #16\n" // d3 = D0D1D2D3 + "zip1 v12.4s, v0.4s, v2.4s\n" // d12 = A0C0A1C1 + "zip2 v13.4s, v0.4s, v2.4s\n" // d13 = A2C2A3C3 + "zip1 v14.4s, v1.4s, v3.4s\n" // d14 = B0D0B1D1 + "zip2 v15.4s, v1.4s, v3.4s\n" // d15 = B2D2B3D3 + "zip1 v0.4s, v12.4s, v14.4s\n" // d0 = A0B0C0D0 + "zip2 v1.4s, v12.4s, v14.4s\n" // d1 = A1B1C1D1 + "zip1 v2.4s, v13.4s, v15.4s\n" // d2 = A2B2C2D2 + "zip2 v3.4s, v13.4s, v15.4s\n" // d3 = A3B3C3D3 + + "ld1 {v4.4s}, [%[inptr4]], #16\n" // d4 = E0E1E2E3 + "ld1 {v5.4s}, [%[inptr5]], #16\n" // d5 = F0F1F2F3 + "ld1 {v6.4s}, [%[inptr6]], #16\n" // d6 = G0G1G2G3 + "ld1 {v7.4s}, [%[inptr7]], #16\n" // d7 = H0H1H2H3 + "zip1 v16.4s, v4.4s, v6.4s\n" // d16 = E0G0E1G1 + "zip2 v17.4s, v4.4s, v6.4s\n" // d17 = E2G2E3G3 + "zip1 v18.4s, v5.4s, v7.4s\n" // d18 = F0H0F1H1 + "zip2 v19.4s, v5.4s, v7.4s\n" // d19 = F2H2F3H3 + "zip1 v4.4s, v16.4s, v18.4s\n" // d4 = E0F0G0H0 + "zip2 v5.4s, v16.4s, v18.4s\n" // d5 = E1F1G1H1 + "zip1 v6.4s, v17.4s, v19.4s\n" // d6 = E2F2G2H2 + "zip2 v7.4s, v17.4s, v19.4s\n" // d7 = E3F3G3H3 + + "ld1 {v8.4s}, [%[inptr8]], #16\n" // d8 = I0I1I2I3 + "ld1 {v9.4s}, [%[inptr9]], #16\n" // d9 = J0J1J2J3 + "ld1 {v10.4s}, [%[inptr10]], #16\n" // d10 = K0K1K2K3 + "ld1 {v11.4s}, [%[inptr11]], #16\n" // d11 = L0L1L2L3 + "zip1 v20.4s, v8.4s, v10.4s\n" // d20 = I0K0I1K1 + "zip2 v21.4s, v8.4s, v10.4s\n" // d21 = I2K2I3K3 + "zip1 v22.4s, v9.4s, v11.4s\n" // d22 = J0L0J1L1 + "zip2 v23.4s, v9.4s, v11.4s\n" // d23 = J2L2J3L3 + "zip1 v8.4s, v20.4s, v22.4s\n" // d8 = I0J0K0L0 + "zip2 v9.4s, v20.4s, v22.4s\n" // d9 = I1J1K1L1 + "zip1 v10.4s, v21.4s, v23.4s\n" // d10 = I2J2K2L2 + "zip2 v11.4s, v21.4s, v23.4s\n" // d11 = I3J3K3L3 
+ + "st1 {v0.4s}, [%[outptr]], #16\n" + "st1 {v4.4s}, [%[outptr]], #16\n" + "st1 {v8.4s}, [%[outptr]], #16\n" + "st1 {v1.4s}, [%[outptr]], #16\n" + "st1 {v5.4s}, [%[outptr]], #16\n" + "st1 {v9.4s}, [%[outptr]], #16\n" + "st1 {v2.4s}, [%[outptr]], #16\n" + "st1 {v6.4s}, [%[outptr]], #16\n" + "st1 {v10.4s}, [%[outptr]], #16\n" + "st1 {v3.4s}, [%[outptr]], #16\n" + "st1 {v7.4s}, [%[outptr]], #16\n" + "st1 {v11.4s}, [%[outptr]], #16\n" + + : + [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [inptr8] "+r"(inptr8), + [inptr9] "+r"(inptr9), [inptr10] "+r"(inptr10), + [inptr11] "+r"(inptr11), [outptr] "+r"(outptr) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "cc", "memory"); +} + +template +static inline void interleave_12x1_4_h( + const T*& in0, const T*& in1, const T*& in2, const T*& in3, + const T*& in4, const T*& in5, const T*& in6, const T*& in7, + const T*& in8, const T*& in9, const T*& in10, const T*& in11, T*& out) { + static_assert( + std::is_same::value || std::is_same::value, + "interleave_12x1_4_h only support uint16_t and int16_t"); + const int16_t*& inptr0 = reinterpret_cast(in0); + const int16_t*& inptr1 = reinterpret_cast(in1); + const int16_t*& inptr2 = reinterpret_cast(in2); + const int16_t*& inptr3 = reinterpret_cast(in3); + const int16_t*& inptr4 = reinterpret_cast(in4); + const int16_t*& inptr5 = reinterpret_cast(in5); + const int16_t*& inptr6 = reinterpret_cast(in6); + const int16_t*& inptr7 = reinterpret_cast(in7); + const int16_t*& inptr8 = reinterpret_cast(in8); + const int16_t*& inptr9 = reinterpret_cast(in9); + const int16_t*& inptr10 = reinterpret_cast(in10); + const int16_t*& inptr11 = reinterpret_cast(in11); + int16_t*& outptr = reinterpret_cast(out); + asm volatile( + "ld1 {v0.4h}, [%[inptr0]], #8\n" // d0 = A0A1A2A3 + "ld1 {v1.4h}, [%[inptr1]], #8\n" // d1 = B0B1B2B3 + "ld1 {v2.4h}, [%[inptr2]], #8\n" // d2 = C0C1C2C3 + "ld1 {v3.4h}, [%[inptr3]], #8\n" // d3 = D0D1D2D3 + "zip1 v12.4h, v0.4h, v2.4h\n" // d12 = A0C0A1C1 + "zip2 v13.4h, v0.4h, v2.4h\n" // d13 = A2C2A3C3 + "zip1 v14.4h, v1.4h, v3.4h\n" // d14 = B0D0B1D1 + "zip2 v15.4h, v1.4h, v3.4h\n" // d15 = B2D2B3D3 + "zip1 v0.4h, v12.4h, v14.4h\n" // d0 = A0B0C0D0 + "zip2 v1.4h, v12.4h, v14.4h\n" // d1 = A1B1C1D1 + "zip1 v2.4h, v13.4h, v15.4h\n" // d2 = A2B2C2D2 + "zip2 v3.4h, v13.4h, v15.4h\n" // d3 = A3B3C3D3 + + "ld1 {v4.4h}, [%[inptr4]], #8\n" // d4 = E0E1E2E3 + "ld1 {v5.4h}, [%[inptr5]], #8\n" // d5 = F0F1F2F3 + "ld1 {v6.4h}, [%[inptr6]], #8\n" // d6 = G0G1G2G3 + "ld1 {v7.4h}, [%[inptr7]], #8\n" // d7 = H0H1H2H3 + "zip1 v16.4h, v4.4h, v6.4h\n" // d16 = E0G0E1G1 + "zip2 v17.4h, v4.4h, v6.4h\n" // d17 = E2G2E3G3 + "zip1 v18.4h, v5.4h, v7.4h\n" // d18 = F0H0F1H1 + "zip2 v19.4h, v5.4h, v7.4h\n" // d19 = F2H2F3H3 + "zip1 v4.4h, v16.4h, v18.4h\n" // d4 = E0F0G0H0 + "zip2 v5.4h, v16.4h, v18.4h\n" // d5 = E1F1G1H1 + "zip1 v6.4h, v17.4h, v19.4h\n" // d6 = E2F2G2H2 + "zip2 v7.4h, v17.4h, v19.4h\n" // d7 = E3F3G3H3 + + "ld1 {v8.4h}, [%[inptr8]], #8\n" // d8 = I0I1I2I3 + "ld1 {v9.4h}, [%[inptr9]], #8\n" // d9 = J0J1J2J3 + "ld1 {v10.4h}, [%[inptr10]], #8\n" // d10 = K0K1K2K3 + "ld1 {v11.4h}, [%[inptr11]], #8\n" // d11 = L0L1L2L3 + "zip1 v20.4h, v8.4h, v10.4h\n" // d20 = I0K0I1K1 + "zip2 v21.4h, v8.4h, v10.4h\n" // d21 = I2K2I3K3 + "zip1 v22.4h, v9.4h, v11.4h\n" // 
d22 = J0L0J1L1 + "zip2 v23.4h, v9.4h, v11.4h\n" // d23 = J2L2J3L3 + "zip1 v8.4h, v20.4h, v22.4h\n" // d8 = I0J0K0L0 + "zip2 v9.4h, v20.4h, v22.4h\n" // d9 = I1J1K1L1 + "zip1 v10.4h, v21.4h, v23.4h\n" // d10 = I2J2K2L2 + "zip2 v11.4h, v21.4h, v23.4h\n" // d11 = I3J3K3L3 + + "st1 {v0.4h}, [%[outptr]], #8\n" // d0 = A0B0C0D0 + "st1 {v4.4h}, [%[outptr]], #8\n" // d4 = E0F0G0H0 + "st1 {v8.4h}, [%[outptr]], #8\n" // d8 = I0J0K0L0 + "st1 {v1.4h}, [%[outptr]], #8\n" // d1 = A1B1C1D1 + "st1 {v5.4h}, [%[outptr]], #8\n" // d5 = E1F1G1H1 + "st1 {v9.4h}, [%[outptr]], #8\n" // d9 = I1J1K1L1 + "st1 {v2.4h}, [%[outptr]], #8\n" // d2 = A2B2C2D2 + "st1 {v6.4h}, [%[outptr]], #8\n" // d6 = E2F2G2H2 + "st1 {v10.4h}, [%[outptr]], #8\n" // d10 = I2J2K2L2 + "st1 {v3.4h}, [%[outptr]], #8\n" // d3 = A3B3C3D3 + "st1 {v7.4h}, [%[outptr]], #8\n" // d7 = E3F3G3H3 + "st1 {v11.4h}, [%[outptr]], #8\n" // d11 = I3J3K3L3 + + : + [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [inptr8] "+r"(inptr8), + [inptr9] "+r"(inptr9), [inptr10] "+r"(inptr10), + [inptr11] "+r"(inptr11), [outptr] "+r"(outptr) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "cc", "memory"); +} + +template +static inline void interleave_12x4_4_b(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, + const T*& inptr6, const T*& inptr7, + const T*& inptr8, const T*& inptr9, + const T*& inptr10, const T*& inptr11, + T*& outptr) { + static_assert( + std::is_same::value || std::is_same::value, + "interleave_12x4_4_b only support uint8_t and int8_t"); + interleave_12x1_4_s(reinterpret_cast(inptr0), + reinterpret_cast(inptr1), + reinterpret_cast(inptr2), + reinterpret_cast(inptr3), + reinterpret_cast(inptr4), + reinterpret_cast(inptr5), + reinterpret_cast(inptr6), + reinterpret_cast(inptr7), + reinterpret_cast(inptr8), + reinterpret_cast(inptr9), + reinterpret_cast(inptr10), + reinterpret_cast(inptr11), + reinterpret_cast(outptr)); +} + +static inline void interleave_8x1_4_s( + const int32_t*& inptr0, const int32_t*& inptr1, const int32_t*& inptr2, + const int32_t*& inptr3, const int32_t*& inptr4, const int32_t*& inptr5, + const int32_t*& inptr6, const int32_t*& inptr7, int32_t*& outptr) { + asm volatile( + "ld1 {v0.4s}, [%[inptr0]], #16\n" // d0 = A0A1A2A3 + "ld1 {v1.4s}, [%[inptr1]], #16\n" // d1 = B0B1B2B3 + "ld1 {v2.4s}, [%[inptr2]], #16\n" // d2 = C0C1C2C3 + "ld1 {v3.4s}, [%[inptr3]], #16\n" // d3 = D0D1D2D3 + "zip1 v8.4s, v0.4s, v2.4s\n" // d8 = A0C0A1C1 + "zip2 v9.4s, v0.4s, v2.4s\n" // d9 = A2C2A3C3 + "zip1 v10.4s, v1.4s, v3.4s\n" // d10 = B0D0B1D1 + "zip2 v11.4s, v1.4s, v3.4s\n" // d11 = B2D2B3D3 + "zip1 v12.4s, v8.4s, v10.4s\n" // d12 = A0B0C0D0 + "zip2 v13.4s, v8.4s, v10.4s\n" // d13 = A1B1C1D1 + "zip1 v14.4s, v9.4s, v11.4s\n" // d14 = A2B2C2D2 + "zip2 v15.4s, v9.4s, v11.4s\n" // d15 = A3B3C3D3 + + "ld1 {v4.4s}, [%[inptr4]], #16\n" // d4 = E0E1E2E3 + "ld1 {v5.4s}, [%[inptr5]], #16\n" // d5 = F0F1F2F3 + "ld1 {v6.4s}, [%[inptr6]], #16\n" // d6 = G0G1G2G3 + "ld1 {v7.4s}, [%[inptr7]], #16\n" // d7 = H0H1H2H3 + "zip1 v16.4s, v4.4s, v6.4s\n" // d16 = E0G0E1G1 + "zip2 v17.4s, v4.4s, v6.4s\n" // d17 = E2G2E3G3 + "zip1 v18.4s, v5.4s, v7.4s\n" // d18 = F0H0F1H1 + "zip2 v19.4s, v5.4s, v7.4s\n" // d19 = F2H2F3H3 + "zip1 v20.4s, v16.4s, v18.4s\n" // 
d20 = E0F0G0H0 + "zip2 v21.4s, v16.4s, v18.4s\n" // d21 = E1F1G1H1 + "zip1 v22.4s, v17.4s, v19.4s\n" // d22 = E2F2G2H2 + "zip2 v23.4s, v17.4s, v19.4s\n" // d23 = E3F3G3H3 + + "st1 {v12.4s}, [%[outptr]], #16\n" + "st1 {v20.4s}, [%[outptr]], #16\n" + "st1 {v13.4s}, [%[outptr]], #16\n" + "st1 {v21.4s}, [%[outptr]], #16\n" + "st1 {v14.4s}, [%[outptr]], #16\n" + "st1 {v22.4s}, [%[outptr]], #16\n" + "st1 {v15.4s}, [%[outptr]], #16\n" + "st1 {v23.4s}, [%[outptr]], #16\n" + + : + [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [outptr] "+r"(outptr) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "cc", "memory"); +} + +static inline void interleave_8x1_2_d( + const int64_t*& inptr0, const int64_t*& inptr1, const int64_t*& inptr2, + const int64_t*& inptr3, const int64_t*& inptr4, const int64_t*& inptr5, + const int64_t*& inptr6, const int64_t*& inptr7, int64_t*& outptr) { + asm volatile( + "ld1 {v0.2d}, [%[inptr0]], #16\n" // d0 = A0A1 + "ld1 {v1.2d}, [%[inptr1]], #16\n" // d1 = B0B1 + "ld1 {v2.2d}, [%[inptr2]], #16\n" // d2 = C0C1 + "ld1 {v3.2d}, [%[inptr3]], #16\n" // d3 = D0D1 + "ld1 {v4.2d}, [%[inptr4]], #16\n" // d4 = E0E1 + "ld1 {v5.2d}, [%[inptr5]], #16\n" // d5 = F0F1 + "ld1 {v6.2d}, [%[inptr6]], #16\n" // d6 = G0G1 + "ld1 {v7.2d}, [%[inptr7]], #16\n" // d7 = H0H1 + + "zip1 v8.2d, v0.2d, v1.2d\n" // d8 = A0B0 + "zip2 v9.2d, v0.2d, v1.2d\n" // d9 = A1B1 + "zip1 v10.2d, v2.2d, v3.2d\n" // d10 = C0D0 + "zip2 v11.2d, v2.2d, v3.2d\n" // d11 = C1D1 + "zip1 v12.2d, v4.2d, v5.2d\n" // d12 = E0F0 + "zip2 v13.2d, v4.2d, v5.2d\n" // d13 = E1F1 + "zip1 v14.2d, v6.2d, v7.2d\n" // d14 = G0H0 + "zip2 v15.2d, v6.2d, v7.2d\n" // d15 = G1H1 + + "st1 {v8.2d}, [%[outptr]], #16\n" + "st1 {v10.2d}, [%[outptr]], #16\n" + "st1 {v12.2d}, [%[outptr]], #16\n" + "st1 {v14.2d}, [%[outptr]], #16\n" + "st1 {v9.2d}, [%[outptr]], #16\n" + "st1 {v11.2d}, [%[outptr]], #16\n" + "st1 {v13.2d}, [%[outptr]], #16\n" + "st1 {v15.2d}, [%[outptr]], #16\n" + : + [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [outptr] "+r"(outptr) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "cc", "memory"); +} + +static inline void interleave_8x2_2_d( + const int64_t*& inptr0, const int64_t*& inptr1, const int64_t*& inptr2, + const int64_t*& inptr3, const int64_t*& inptr4, const int64_t*& inptr5, + const int64_t*& inptr6, const int64_t*& inptr7, int64_t*& outptr) { + asm volatile( + "ld1 {v0.2d}, [%[inptr0]], #16\n" // d0 = A0 + "ld1 {v1.2d}, [%[inptr0]], #16\n" // d1 = A1 + "ld1 {v2.2d}, [%[inptr1]], #16\n" // d2 = B0 + "ld1 {v3.2d}, [%[inptr1]], #16\n" // d3 = B1 + "ld1 {v4.2d}, [%[inptr2]], #16\n" // d4 = C0 + "ld1 {v5.2d}, [%[inptr2]], #16\n" // d5 = C1 + "ld1 {v6.2d}, [%[inptr3]], #16\n" // d6 = D0 + "ld1 {v7.2d}, [%[inptr3]], #16\n" // d7 = D1 + "ld1 {v8.2d}, [%[inptr4]], #16\n" // d8 = E0 + "ld1 {v9.2d}, [%[inptr4]], #16\n" // d9 = E1 + "ld1 {v10.2d}, [%[inptr5]], #16\n" // d10 = F0 + "ld1 {v11.2d}, [%[inptr5]], #16\n" // d11 = F1 + "ld1 {v12.2d}, [%[inptr6]], #16\n" // d12 = G0 + "ld1 {v13.2d}, [%[inptr6]], #16\n" // d13 = G1 + "ld1 {v14.2d}, [%[inptr7]], #16\n" // d14 = H0 + "ld1 
{v15.2d}, [%[inptr7]], #16\n" // d15 = H1 + + "st1 {v0.2d}, [%[outptr]], #16\n" + "st1 {v2.2d}, [%[outptr]], #16\n" + "st1 {v4.2d}, [%[outptr]], #16\n" + "st1 {v6.2d}, [%[outptr]], #16\n" + "st1 {v8.2d}, [%[outptr]], #16\n" + "st1 {v10.2d}, [%[outptr]], #16\n" + "st1 {v12.2d}, [%[outptr]], #16\n" + "st1 {v14.2d}, [%[outptr]], #16\n" + "st1 {v1.2d}, [%[outptr]], #16\n" + "st1 {v3.2d}, [%[outptr]], #16\n" + "st1 {v5.2d}, [%[outptr]], #16\n" + "st1 {v7.2d}, [%[outptr]], #16\n" + "st1 {v9.2d}, [%[outptr]], #16\n" + "st1 {v11.2d}, [%[outptr]], #16\n" + "st1 {v13.2d}, [%[outptr]], #16\n" + "st1 {v15.2d}, [%[outptr]], #16\n" + : + [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [outptr] "+r"(outptr) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "cc", "memory"); +} + +template +static inline void interleave_8x4_4_b(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, + const T*& inptr6, const T*& inptr7, + T*& outptr) { + static_assert( + std::is_same::value || std::is_same::value, + "interleave_8x4_4_b only support uint8_t and int8_t"); + interleave_8x1_4_s(reinterpret_cast(inptr0), + reinterpret_cast(inptr1), + reinterpret_cast(inptr2), + reinterpret_cast(inptr3), + reinterpret_cast(inptr4), + reinterpret_cast(inptr5), + reinterpret_cast(inptr6), + reinterpret_cast(inptr7), + reinterpret_cast(outptr)); +} + +template +static inline void interleave_8x4_1_h(const T*& in0, const T*& in1, + const T*& in2, const T*& in3, T* out) { + static_assert(sizeof(T) == 2, "only support size == 2"); + asm volatile( + "ldr q0, [%[in0]], #16\n" // A1A2A3A4A5A6A7A8 + "ldr q1, [%[in1]], #16\n" // B1B2B3B4B5B6B7B8 + "ldr q2, [%[in2]], #16\n" // C1C2C3C4C5C6C7C8 + "ldr q3, [%[in3]], #16\n" // D1D2D3D4D5D6D7D8 + + "trn1 v4.8h, v0.8h, v1.8h\n" // A1B1A3B3A5B5A7B7 + "trn2 v5.8h, v0.8h, v1.8h\n" // A2B2A4B4A6B6A8B8 + "trn1 v6.8h, v2.8h, v3.8h\n" // C1D1C3D3C5D5C7D7 + "trn2 v7.8h, v2.8h, v3.8h\n" // C2D2C4D4C6D6C8D8 + + "zip1 v8.4s, v4.4s, v6.4s\n" // A1B1C1D1A3B3C3D3 + "zip2 v9.4s, v4.4s, v6.4s\n" // A5B5C5D5A7B7C7D7 + "zip1 v10.4s, v5.4s, v7.4s\n" // A2B2C2D2A4B4C4D4 + "zip2 v11.4s, v5.4s, v7.4s\n" // A6B6C6D6A8B8C8D8 + + "zip1 v12.2d, v8.2d, v10.2d\n" // A1B1C1D1A2B2C2D2 + "zip2 v13.2d, v8.2d, v10.2d\n" // A3B3C3D3A4B4C4D4 + "zip1 v14.2d, v9.2d, v11.2d\n" // A5B5C5D5A6B6C6D6 + "zip2 v15.2d, v9.2d, v11.2d\n" // A7B7C7D7A8B8C8D8 + + "st1 {v12.2d}, [%[out]], #16\n" + "st1 {v13.2d}, [%[out]], #16\n" + "st1 {v14.2d}, [%[out]], #16\n" + "st1 {v15.2d}, [%[out]], #16\n" + : [in0] "+r"(in0), [in1] "+r"(in1), [in2] "+r"(in2), + [in3] "+r"(in3), [out] "+r"(out) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "memory"); +} + +template +static inline void interleave_8x8_2_b(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, + const T*& inptr6, const T*& inptr7, + T*& outptr) { + static_assert( + std::is_same::value || std::is_same::value, + "interleave_8x8_2_b only support uint8_t and int8_t"); + interleave_8x1_2_d(reinterpret_cast(inptr0), + reinterpret_cast(inptr1), + reinterpret_cast(inptr2), + reinterpret_cast(inptr3), + reinterpret_cast(inptr4), + reinterpret_cast(inptr5), + reinterpret_cast(inptr6), + reinterpret_cast(inptr7), + 
reinterpret_cast(outptr)); +} + +template +static inline void interleave_8x8_2_h(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, + const T*& inptr6, const T*& inptr7, + T*& outptr) { + static_assert( + std::is_same::value || std::is_same::value, + "interleave_8x8_2_h only support uint16_t and int16_t"); + interleave_8x2_2_d(reinterpret_cast(inptr0), + reinterpret_cast(inptr1), + reinterpret_cast(inptr2), + reinterpret_cast(inptr3), + reinterpret_cast(inptr4), + reinterpret_cast(inptr5), + reinterpret_cast(inptr6), + reinterpret_cast(inptr7), + reinterpret_cast(outptr)); +} + +template +static inline void interleave_8x2_8_b(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, + const T*& inptr6, const T*& inptr7, + T*& outptr) { + static_assert( + std::is_same::value || std::is_same::value, + "interleave_8x2_8_b only support uint8_t and int8_t"); + interleave_8x1_8_h(reinterpret_cast(inptr0), + reinterpret_cast(inptr1), + reinterpret_cast(inptr2), + reinterpret_cast(inptr3), + reinterpret_cast(inptr4), + reinterpret_cast(inptr5), + reinterpret_cast(inptr6), + reinterpret_cast(inptr7), + reinterpret_cast(outptr)); +} + +template +static inline void interleave_8x8_1_b(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, + const T*& inptr6, const T*& inptr7, + T*& outptr) { + static_assert( + std::is_same::value || std::is_same::value, + "interleave_8x8_1_b only support uint8_t and int8_t"); + asm volatile( + "ld1 {v0.d}[0], [%[inptr0]], 8\n" // A1A2A3A4A5A6A7A8 + "ld1 {v0.d}[1], [%[inptr1]], 8\n" // B1B2B3B4B5B6B7B8 + "ld1 {v1.d}[0], [%[inptr2]], 8\n" // C1C2C3C4C5C6C7C8 + "ld1 {v1.d}[1], [%[inptr3]], 8\n" // D1D2D3D4D5D6D7D8 + "ld1 {v2.d}[0], [%[inptr4]], 8\n" // E1E2E3E4E5E6E7E8 + "ld1 {v2.d}[1], [%[inptr5]], 8\n" // F1F2F3F4F5F6F7F8 + "ld1 {v3.d}[0], [%[inptr6]], 8\n" // G1G2G3G4G5G6G7G8 + "ld1 {v3.d}[1], [%[inptr7]], 8\n" // H1H2H3H4H5H6H7H8 + + "st1 {v0.2d}, [%[outptr]], 16\n" // A1A2A3A4A5A6A7A8B1B2B3B4B5B6B7B8 + "st1 {v1.2d}, [%[outptr]], 16\n" // C1C2C3C4C5C6C7C8D1D2D3D4D5D6D7D8 + "st1 {v2.2d}, [%[outptr]], 16\n" // E1E2E3E4E5E6E7E8F1F2F3F4F5F6F7F8 + "st1 {v3.2d}, [%[outptr]], 16\n" // G1G2G3G4G5G6G7G8H1H2H3H4H5H6H7H8 + : + [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [outptr] "+r"(outptr) + : + : "v0", "v1", "v2", "v3", "memory"); +} + +template +static inline void interleave_8x8_1_h(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, + const T*& inptr6, const T*& inptr7, + T*& outptr) { + static_assert( + std::is_same::value || std::is_same::value, + "interleave_8x8_1_h only support uint16_t and int16_t"); + asm volatile( + "ld1 {v0.8h}, [%[inptr0]], #16\n" // A1A2A3A4A5A6A7A8 + "ld1 {v1.8h}, [%[inptr1]], #16\n" // B1B2B3B4B5B6B7B8 + "ld1 {v2.8h}, [%[inptr2]], #16\n" // C1C2C3C4C5C6C7C8 + "ld1 {v3.8h}, [%[inptr3]], #16\n" // D1D2D3D4D5D6D7D8 + "ld1 {v4.8h}, [%[inptr4]], #16\n" // E1E2E3E4E5E6E7E8 + "ld1 {v5.8h}, [%[inptr5]], #16\n" // F1F2F3F4F5F6F7F8 + "ld1 {v6.8h}, [%[inptr6]], #16\n" // G1G2G3G4G5G6G7G8 + "ld1 {v7.8h}, [%[inptr7]], #16\n" // H1H2H3H4H5H6H7H8 + + "st1 {v0.8h}, [%[outptr]], #16\n" // A1A2A3A4A5A6A7A8 + "st1 {v1.8h}, [%[outptr]], #16\n" // B1B2B3B4B5B6B7B8 + "st1 {v2.8h}, [%[outptr]], 
#16\n" // C1C2C3C4C5C6C7C8 + "st1 {v3.8h}, [%[outptr]], #16\n" // D1D2D3D4D5D6D7D8 + "st1 {v4.8h}, [%[outptr]], #16\n" // E1E2E3E4E5E6E7E8 + "st1 {v5.8h}, [%[outptr]], #16\n" // F1F2F3F4F5F6F7F8 + "st1 {v6.8h}, [%[outptr]], #16\n" // G1G2G3G4G5G6G7G8 + "st1 {v7.8h}, [%[outptr]], #16\n" // H1H2H3H4H5H6H7H8 + : + [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [outptr] "+r"(outptr) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "memory"); +} + +static inline void interleave_4x1_4_s(const int32_t*& inptr0, + const int32_t*& inptr1, + const int32_t*& inptr2, + const int32_t*& inptr3, + int32_t*& outptr) { + asm volatile( + "ld1 {v0.4s}, [%[inptr0]], #16\n" // d0 = A0A1A2A3 + "ld1 {v1.4s}, [%[inptr1]], #16\n" // d1 = B0B1B2B3 + "ld1 {v2.4s}, [%[inptr2]], #16\n" // d2 = C0C1C2C3 + "ld1 {v3.4s}, [%[inptr3]], #16\n" // d3 = D0D1D2D3 + "zip1 v8.4s, v0.4s, v2.4s\n" // d8 = A0C0A1C1 + "zip2 v9.4s, v0.4s, v2.4s\n" // d9 = A2C2A3C3 + "zip1 v10.4s, v1.4s, v3.4s\n" // d10 = B0D0B1D1 + "zip2 v11.4s, v1.4s, v3.4s\n" // d11 = B2D2B3D3 + "zip1 v12.4s, v8.4s, v10.4s\n" // d12 = A0B0C0D0 + "zip2 v13.4s, v8.4s, v10.4s\n" // d13 = A1B1C1D1 + "zip1 v14.4s, v9.4s, v11.4s\n" // d14 = A2B2C2D2 + "zip2 v15.4s, v9.4s, v11.4s\n" // d15 = A3B3C3D3 + + "st1 {v12.4s}, [%[outptr]], #16\n" + "st1 {v13.4s}, [%[outptr]], #16\n" + "st1 {v14.4s}, [%[outptr]], #16\n" + "st1 {v15.4s}, [%[outptr]], #16\n" + + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), + [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), + [outptr] "+r"(outptr) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "cc", "memory"); +} + +template +static inline void interleave_4x8_1_s(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, + T*& outptr) { + static_assert(sizeof(T) == 4, "only support size == 4"); + asm volatile( + "ld1 {v0.4s, v1.4s}, [%[inptr0]], #32\n" + "ld1 {v2.4s, v3.4s}, [%[inptr1]], #32\n" + "ld1 {v4.4s, v5.4s}, [%[inptr2]], #32\n" + "ld1 {v6.4s, v7.4s}, [%[inptr3]], #32\n" + "st1 {v0.4s, v1.4s}, [%[outptr]], #32\n" + "st1 {v2.4s, v3.4s}, [%[outptr]], #32\n" + "st1 {v4.4s, v5.4s}, [%[outptr]], #32\n" + "st1 {v6.4s, v7.4s}, [%[outptr]], #32\n" + + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), + [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), + [outptr] "+r"(outptr) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "cc", "memory"); +} + +template +static inline void interleave_4x12_1_s(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, + T*& outptr) { + static_assert(sizeof(T) == 4, "only support size == 4"); + asm volatile( + "ld1 {v0.4s, v1.4s, v2.4s}, [%[inptr0]], #48\n" + "ld1 {v4.4s, v5.4s, v6.4s}, [%[inptr1]], #48\n" + "ld1 {v8.4s, v9.4s, v10.4s}, [%[inptr2]], #48\n" + "ld1 {v12.4s, v13.4s, v14.4s}, [%[inptr3]], #48\n" + "st1 {v0.4s, v1.4s, v2.4s}, [%[outptr]], #48\n" + "st1 {v4.4s, v5.4s, v6.4s}, [%[outptr]], #48\n" + "st1 {v8.4s, v9.4s, v10.4s}, [%[outptr]], #48\n" + "st1 {v12.4s, v13.4s, v14.4s}, [%[outptr]], #48\n" + + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), + [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), + [outptr] "+r"(outptr) + : + : "v0", "v1", "v2", "v4", "v5", "v6", "v8", "v9", "v10", "v12", + "v13", "v14", "cc", "memory"); +} + +template +static inline void interleave_4x16_1_b(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, + T*& outptr) { + 
static_assert(sizeof(T) == 1, "only support size == 1"); + asm volatile( + "ld1 {v0.4s}, [%[inptr0]], #16\n" // d0 = A0A1A2A3 + "ld1 {v1.4s}, [%[inptr1]], #16\n" // d1 = B0B1B2B3 + "ld1 {v2.4s}, [%[inptr2]], #16\n" // d2 = C0C1C2C3 + "ld1 {v3.4s}, [%[inptr3]], #16\n" // d3 = D0D1D2D3 + "st1 {v0.4s}, [%[outptr]], #16\n" + "st1 {v1.4s}, [%[outptr]], #16\n" + "st1 {v2.4s}, [%[outptr]], #16\n" + "st1 {v3.4s}, [%[outptr]], #16\n" + + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), + [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), + [outptr] "+r"(outptr) + : + : "v0", "v1", "v2", "v3", "v4", "cc", "memory"); +} + +template +static inline void interleave_4x16_1_s(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, + T*& outptr) { + static_assert(sizeof(T) == 4, "only support size == 4"); + asm volatile( + "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[inptr0]], #64\n" + "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[inptr1]], #64\n" + "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[inptr2]], #64\n" + "ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%[inptr3]], #64\n" + "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[outptr]], #64\n" + "st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[outptr]], #64\n" + "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[outptr]], #64\n" + "st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%[outptr]], #64\n" + + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), + [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), + [outptr] "+r"(outptr) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "cc", "memory"); +} + +template +static inline void interleave_4x2_4_b(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, + T*& outptr) { + static_assert( + std::is_same::value || std::is_same::value, + "interleave_4x2_4_b only support uint8_t and int8_t"); + interleave_4x1_4_h(reinterpret_cast(inptr0), + reinterpret_cast(inptr1), + reinterpret_cast(inptr2), + reinterpret_cast(inptr3), + reinterpret_cast(outptr)); +} + +template +static inline void interleave_4x4_4_b(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, + T*& outptr) { + static_assert( + std::is_same::value || std::is_same::value, + "interleave_4x4_4_b only support uint8_t and int8_t"); + interleave_4x1_4_s(reinterpret_cast(inptr0), + reinterpret_cast(inptr1), + reinterpret_cast(inptr2), + reinterpret_cast(inptr3), + reinterpret_cast(outptr)); +} + +template +static inline void interleave_4x4_1_s(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, + T*& outptr) { + static_assert(sizeof(T) == 4, "interleave_4x4_1_s only support size == 4"); + asm volatile( + "ld1 {v0.4s}, [%[inptr0]], #16\n" + "ld1 {v1.4s}, [%[inptr1]], #16\n" + "ld1 {v2.4s}, [%[inptr2]], #16\n" + "ld1 {v3.4s}, [%[inptr3]], #16\n" + "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[outptr]], #64\n" + + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), + [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), + [outptr] "+r"(outptr) + : + : "v0", "v1", "v2", "v3", "cc", "memory"); +} + +template +static inline void interleave_4x8_2_b(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, + T*& outptr) { + static_assert( + std::is_same::value || std::is_same::value, + "interleave_4x8_2_b only support uint8_t and int8_t"); + interleave_4x1_2_d(reinterpret_cast(inptr0), + reinterpret_cast(inptr1), + reinterpret_cast(inptr2), + reinterpret_cast(inptr3), + reinterpret_cast(outptr)); +} + +template +static inline void interleave_4x8_2_h(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& 
inptr3, + T*& outptr) { + static_assert( + std::is_same::value || std::is_same::value, + "interleave_4x8_2_h only support uint16_t and int16_t"); + interleave_4x2_2_d(reinterpret_cast(inptr0), + reinterpret_cast(inptr1), + reinterpret_cast(inptr2), + reinterpret_cast(inptr3), + reinterpret_cast(outptr)); +} + +template +static inline void interleave_1x16_1_s(const T*& inptr0, T*& outptr) { + static_assert(sizeof(T) == 4, "interleave_1x16_1_s only support size == 4"); + asm volatile( + "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[inptr0]], #64\n" + "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[outptr]], #64\n" + + : [inptr0] "+r"(inptr0), [outptr] "+r"(outptr) + : + : "v0", "v1", "v2", "v3", "cc", "memory"); +} +template +static inline void interleave_1x12_1_s(const T*& inptr0, T*& outptr) { + static_assert(sizeof(T) == 4, "interleave_1x12_1_s only support size == 4"); + asm volatile( + "ld1 {v0.4s, v1.4s, v2.4s}, [%[inptr0]], #48\n" + "st1 {v0.4s, v1.4s, v2.4s}, [%[outptr]], #48\n" + + : [inptr0] "+r"(inptr0), [outptr] "+r"(outptr) + : + : "v0", "v1", "v2", "cc", "memory"); +} + +template +static inline void interleave_1x8_1_s(const T*& inptr0, T*& outptr) { + static_assert(sizeof(T) == 4, "interleave_1x8_1_s only support size == 4"); + asm volatile( + "ld1 {v0.4s, v1.4s}, [%[inptr0]], #32\n" + "st1 {v0.4s, v1.4s}, [%[outptr]], #32\n" + + : [inptr0] "+r"(inptr0), [outptr] "+r"(outptr) + : + : "v0", "v1", "cc", "memory"); +} + +template +static inline void interleave_1x4_1_s(const T*& inptr0, T*& outptr) { + static_assert(sizeof(T) == 4, "interleave_1x4_1_s only support size == 4"); + asm volatile( + "ld1 {v0.4s}, [%[inptr0]], #16\n" + "st1 {v0.4s}, [%[outptr]], #16\n" + + : [inptr0] "+r"(inptr0), [outptr] "+r"(outptr) + : + : "v0", "cc", "memory"); +} + +template +static inline void interleave_helper(const T*& inptr, T*& outptr, int unroll_k, + int ksize, T val = 0) { + int k = 0; + for (; k < ksize; k++) { + *outptr++ = *inptr++; + } + for (; k < unroll_k; k++) { + *outptr++ = val; + } +} + +template +static inline void interleave_1(const T*& inptr0, T*& outptr, int unroll_k, + int ksize, T val = 0) { + for (int k = 0; k < ksize; k += unroll_k) { + int size = std::min(unroll_k, ksize - k); + interleave_helper(inptr0, outptr, unroll_k, size, val); + } +} + +template +static inline void interleave_4(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, T*& outptr, + int unroll_k, int ksize, T val = 0) { + for (int k = 0; k < ksize; k += unroll_k) { + int size = std::min(unroll_k, ksize - k); + interleave_helper(inptr0, outptr, unroll_k, size, val); + interleave_helper(inptr1, outptr, unroll_k, size, val); + interleave_helper(inptr2, outptr, unroll_k, size, val); + interleave_helper(inptr3, outptr, unroll_k, size, val); + } +} + +template +static inline void interleave_8(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, + const T*& inptr6, const T*& inptr7, T*& outptr, + int unroll_k, int ksize, T val = 0) { + for (int k = 0; k < ksize; k += unroll_k) { + int size = std::min(unroll_k, ksize - k); + interleave_helper(inptr0, outptr, unroll_k, size, val); + interleave_helper(inptr1, outptr, unroll_k, size, val); + interleave_helper(inptr2, outptr, unroll_k, size, val); + interleave_helper(inptr3, outptr, unroll_k, size, val); + interleave_helper(inptr4, outptr, unroll_k, size, val); + interleave_helper(inptr5, outptr, unroll_k, size, val); + interleave_helper(inptr6, outptr, unroll_k, size, val); + interleave_helper(inptr7, 
outptr, unroll_k, size, val); + } +} + +template +static inline void interleave_12(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, + const T*& inptr6, const T*& inptr7, + const T*& inptr8, const T*& inptr9, + const T*& inptr10, const T*& inptr11, + T*& outptr, int unroll_k, int ksize) { + for (int k = 0; k < ksize; k += unroll_k) { + int size = std::min(unroll_k, ksize - k); + interleave_helper(inptr0, outptr, unroll_k, size); + interleave_helper(inptr1, outptr, unroll_k, size); + interleave_helper(inptr2, outptr, unroll_k, size); + interleave_helper(inptr3, outptr, unroll_k, size); + interleave_helper(inptr4, outptr, unroll_k, size); + interleave_helper(inptr5, outptr, unroll_k, size); + interleave_helper(inptr6, outptr, unroll_k, size); + interleave_helper(inptr7, outptr, unroll_k, size); + interleave_helper(inptr8, outptr, unroll_k, size); + interleave_helper(inptr9, outptr, unroll_k, size); + interleave_helper(inptr10, outptr, unroll_k, size); + interleave_helper(inptr11, outptr, unroll_k, size); + } +} +/* ======================== transpose pack B ======================== */ +/** + * transpose_INTERLEAVE_UNROLLK_BATCH_type + * + * BATCH means process BATCH * INTERLEAVE cols once, BATCH * sizeof(TYPE) * + * INTERLEAVE = 16bytes(128bits, a vector size). + * + * the elements traverse order: + * rep(j, 0, INTERLEAVE) rep(i, 0, UNROLL_K) *ouptr++ = inptr[i, j] + */ +template +static inline void transpose_24x4_1_h(const T*& in0, const T*& in1, + const T*& in2, const T*& in3, T* out) { + static_assert(sizeof(T) == 2, "only support size == 2"); + asm volatile( + "ldp q0, q1, [%[in0]], #32\n" + "stp q0, q1, [%[out]]\n" + "ldr q2, [%[in0]], #16\n" + ASM_PREFETCH("[%[in0], #192]") + "ldp q3, q4, [%[in1]], #32\n" + "stp q2, q3, [%[out], #32]\n" + "ldr q5, [%[in1]], #16\n" + ASM_PREFETCH("[%[in1], #192]") + "stp q4, q5, [%[out], #64]\n" + "ldp q6, q7, [%[in2]], #32\n" + "stp q6, q7, [%[out], #96]\n" + "ldr q8, [%[in2]], #16\n" + ASM_PREFETCH("[%[in2], #192]") + "ldp q9, q10, [%[in3]], #32\n" + "stp q8, q9, [%[out], #128]\n" + "ldr q11, [%[in3]], #16\n" + "stp q10, q11, [%[out], #160]\n" + ASM_PREFETCH("[%[in3], #192]") + + : [in0] "+r"(in0), [in1] "+r"(in1), [in2] "+r"(in2), + [in3] "+r"(in3), [out] "+r"(out) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "memory"); +} + +template +static inline void transpose_16x4_1_h(const T*& in0, const T*& in1, + const T*& in2, const T*& in3, T* out) { + static_assert(sizeof(T) == 2, "only support size == 2"); + asm volatile( + "ldp q0, q1, [%[in0]], #32\n" + "stp q0, q1, [%[out]]\n" + "ldp q2, q3, [%[in1]], #32\n" + "stp q2, q3, [%[out], #32]\n" + "ldp q4, q5, [%[in2]], #32\n" + "stp q4, q5, [%[out], #64]\n" + "ldp q6, q7, [%[in3]], #32\n" + "stp q6, q7, [%[out], #96]\n" + : [in0] "+r"(in0), [in1] "+r"(in1), [in2] "+r"(in2), + [in3] "+r"(in3), [out] "+r"(out) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "memory"); +} + +template +static inline void transpose_8x4_1_h(const T*& in0, const T*& in1, + const T*& in2, const T*& in3, T* out) { + static_assert(sizeof(T) == 2, "only support size == 2"); + asm volatile( + "ldr q0, [%[in0]], #16\n" + "str q0, [%[out]]\n" + "ldr q1, [%[in1]], #16\n" + "str q1, [%[out], #16]\n" + "ldr q2, [%[in2]], #16\n" + "str q2, [%[out], #32]\n" + "ldr q3, [%[in3]], #16\n" + "str q3, [%[out], #48]\n" + : [in0] "+r"(in0), [in1] "+r"(in1), [in2] "+r"(in2), + [in3] "+r"(in3), [out] "+r"(out) + : + : "v0", "v1", "v2", "v3", 
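// [Editorial sketch, not part of the patch] The "transpose pack B" comment earlier in
// this section describes the traversal as rep(j, 0, INTERLEAVE) rep(i, 0, UNROLL_K)
// *outptr++ = inptr[i, j], i.e. columns in the outer loop and rows in the inner loop,
// the mirror of the pack-A interleave order. A scalar rendering of that description
// for a row-major source tile (function name, ldin stride parameter and explicit loop
// bounds are assumptions for this sketch, not code from the patch):
template <typename T>
static inline void transpose_pack_b_ref(const T* in, int ldin, int unroll_k,
                                        int interleave, T* out) {
    for (int j = 0; j < interleave; ++j)     // columns of the source tile
        for (int i = 0; i < unroll_k; ++i)   // rows of the source tile
            *out++ = in[i * ldin + j];       // element (i, j) of the tile
}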
"memory"); +} + +template +static inline void transpose_24x2_1_h(const T*& in0, const T*& in1, T* out) { + static_assert(sizeof(T) == 2, "only support size == 2"); + asm volatile( + "ldp q0, q1, [%[in0]], #32\n" + "stp q0, q1, [%[out]]\n" + "ldr q2, [%[in0]], #16\n" + ASM_PREFETCH("[%[in0], #192]") + "ldp q3, q4, [%[in1]], #32\n" + "stp q2, q3, [%[out], #32]\n" + "ldr q5, [%[in1]], #16\n" + ASM_PREFETCH("[%[in1], #192]") + "stp q4, q5, [%[out], #64]\n" + : [in0] "+r"(in0), [in1] "+r"(in1), [out] "+r"(out) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "memory"); +} + +template +static inline void transpose_16x2_1_h(const T*& in0, const T*& in1, T* out) { + static_assert(sizeof(T) == 2, "only support size == 2"); + asm volatile( + "ldp q0, q1, [%[in0]], #32\n" + "stp q0, q1, [%[out]]\n" + "ldp q2, q3, [%[in1]], #32\n" + "stp q2, q3, [%[out], #32]\n" + : [in0] "+r"(in0), [in1] "+r"(in1), [out] "+r"(out) + : + : "v0", "v1", "v2", "v3", "memory"); +} + +template +static inline void transpose_8x2_1_h(const T*& in0, const T*& in1, T* out) { + static_assert(sizeof(T) == 2, "only support size == 2"); + asm volatile( + "ldr q0, [%[in0]], #16\n" + "str q0, [%[out]]\n" + "ldr q1, [%[in1]], #16\n" + "str q1, [%[out], #16]\n" + : [in0] "+r"(in0), [in1] "+r"(in1), [out] "+r"(out) + : + : "v0", "v1", "memory"); +} + +template +static inline void transpose_24x1_1_h(const T*& in0, T* out) { + static_assert(sizeof(T) == 2, "only support size == 2"); + // clang-format off + asm volatile( + "ldp q0, q1, [%[in0]], #32\n" + "stp q0, q1, [%[out]] \n" + "ldr q2, [%[in0]], #16 \n" + ASM_PREFETCH("[%[in0], #192]") + "str q2, [%[out], #32] \n" + : [in0] "+r"(in0), [out] "+r"(out) + : + : "v0", "v1", "v2", "memory"); + // clang-format on +} + +template +static inline void transpose_16x1_1_h(const T*& in0, T* out) { + static_assert(sizeof(T) == 2, "only support size == 2"); + asm volatile( + "ldp q0, q1, [%[in0]], #32\n" + "stp q0, q1, [%[out]]\n" + : [in0] "+r"(in0), [out] "+r"(out) + : + : "v0", "v1", "memory"); +} + +template +static inline void transpose_12x1_1_h(const T*& in0, T* out) { + static_assert(sizeof(T) == 2, "only support size == 2"); + // clang-format off + asm volatile( + "ld1 {v0.8h}, [%[in0]], #16\n" + "ld1 {v1.4h}, [%[in0]], #8\n" + "st1 {v0.8h}, [%[out]], #16\n" + "st1 {v1.4h}, [%[out]], #8\n" + : [in0] "+r"(in0), [out] "+r"(out) + : + : "v0", "v1", "memory"); + // clang-format on +} + +template +static inline void transpose_8x1_1_h(const T*& in0, T* out) { + static_assert(sizeof(T) == 2, "only support size == 2"); + asm volatile( + "ldr q0, [%[in0]], #16\n" + "str q0, [%[out]]\n" + : [in0] "+r"(in0), [out] "+r"(out) + : + : "v0", "memory"); +} + +template +static inline void transpose_4x1_1_h(const T*& in0, T* out) { + static_assert(sizeof(T) == 2, "only support size == 2"); + // clang-format off + asm volatile( + "ld1 {v0.4h}, [%[in0]], #8\n" + "st1 {v0.4h}, [%[out]], #8\n" + : [in0] "+r"(in0), [out] "+r"(out) + : + : "v0", "memory"); + // clang-format on +} + +template +static inline void transpose_4x4_1_s(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, + T* outptr, int stride = 16) { + static_assert(sizeof(T) == 4, + "transpose_4x4_1_s only support sizeof(T) == 4"); + + asm volatile( + "ld1 {v0.4s}, [%[inptr0]], 16\n" // A0A1A2A3 + "ld1 {v1.4s}, [%[inptr1]], 16\n" // B0B1B2B3 + "ld1 {v2.4s}, [%[inptr2]], 16\n" // C0C1C2C3 + "ld1 {v3.4s}, [%[inptr3]], 16\n" // D0D1D2D3 + + "zip1 v4.4s, v0.4s, v1.4s\n" + "zip1 v5.4s, v2.4s, v3.4s\n" + "zip2 v6.4s, v0.4s, v1.4s\n" + "zip2 
v7.4s, v2.4s, v3.4s\n" + + "zip1 v8.2d, v4.2d, v5.2d\n" + "zip1 v9.2d, v6.2d, v7.2d\n" + "zip2 v10.2d, v4.2d, v5.2d\n" + "zip2 v11.2d, v6.2d, v7.2d\n" + + "st1 {v8.4s}, [%[outptr]], %x[stride]\n" + "st1 {v10.4s}, [%[outptr]], %x[stride]\n" + "st1 {v9.4s}, [%[outptr]], %x[stride]\n" + "st1 {v11.4s}, [%[outptr]], %x[stride]\n" + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), + [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), + [outptr] "+r"(outptr), [stride] "+r"(stride) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "memory"); +} + +template +static inline void transpose_8x4_1_s(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, + const T*& inptr6, const T*& inptr7, + T* outptr) { + static_assert(sizeof(T) == 4, + "transpose_8x4_1_s only support sizeof(T) == 4"); + + asm volatile( + "ld1 {v0.4s}, [%[inptr0]], 16\n" // A0A1A2A3 + "ld1 {v1.4s}, [%[inptr1]], 16\n" // B0B1B2B3 + "ld1 {v2.4s}, [%[inptr2]], 16\n" // C0C1C2C3 + "ld1 {v3.4s}, [%[inptr3]], 16\n" // D0D1D2D3 + "ld1 {v4.4s}, [%[inptr4]], 16\n" // E0E1E2E3 + "ld1 {v5.4s}, [%[inptr5]], 16\n" // F0F1F2F3 + "ld1 {v6.4s}, [%[inptr6]], 16\n" // G0G1G2G3 + "ld1 {v7.4s}, [%[inptr7]], 16\n" // H0H1H2H3 + + "zip1 v8.4s, v0.4s, v1.4s\n" // A0B0A1B1 + "zip2 v9.4s, v0.4s, v1.4s\n" // A2B2A3B3 + "zip1 v10.4s, v2.4s, v3.4s\n" // C0D0C1D1 + "zip2 v11.4s, v2.4s, v3.4s\n" // C2D2C3D3 + "zip1 v12.4s, v4.4s, v5.4s\n" // E0F0E1F1 + "zip2 v13.4s, v4.4s, v5.4s\n" // E2F2E3F3 + "zip1 v14.4s, v6.4s, v7.4s\n" // G0H0G1H1 + "zip2 v15.4s, v6.4s, v7.4s\n" // G2H2G3H3 + + "zip1 v0.2d, v8.2d, v10.2d\n" // A0B0C0D0 + "zip2 v2.2d, v8.2d, v10.2d\n" // A1B1C1D1 + + "zip1 v4.2d, v9.2d, v11.2d\n" // A2B2C2D2 + "zip2 v6.2d, v9.2d, v11.2d\n" // A3B3C3D3 + + "zip1 v1.2d, v12.2d, v14.2d\n" // E0F0G0H0 + "zip2 v3.2d, v12.2d, v14.2d\n" // E1F1G1H1 + + "zip1 v5.2d, v13.2d, v15.2d\n" // E2F2G2H2 + "zip2 v7.2d, v13.2d, v15.2d\n" // E3F3G3H3 + + "st1 {v0.4s,v1.4s,v2.4s,v3.4s}, [%[outptr]], #64\n" + "st1 {v4.4s,v5.4s,v6.4s,v7.4s}, [%[outptr]], #64\n" + : + [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [outptr] "+r"(outptr) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "memory"); +} + +template +static inline void transpose_12x4_1_s(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, + const T*& inptr6, const T*& inptr7, + const T*& inptr8, const T*& inptr9, + const T*& inptr10, const T*& inptr11, + T* outptr) { + static_assert(sizeof(T) == 4, + "transpose_12x4_1_s only support sizeof(T) == 4"); + asm volatile( + "ld1 {v0.4s}, [%[inptr0]], 16\n" // A0A1A2A3 + "ld1 {v1.4s}, [%[inptr1]], 16\n" // B0B1B2B3 + "ld1 {v2.4s}, [%[inptr2]], 16\n" // C0C1C2C3 + "ld1 {v3.4s}, [%[inptr3]], 16\n" // D0D1D2D3 + "ld1 {v4.4s}, [%[inptr4]], 16\n" // E0E1E2E3 + "ld1 {v5.4s}, [%[inptr5]], 16\n" // F0F1F2F3 + "ld1 {v6.4s}, [%[inptr6]], 16\n" // G0G1G2G3 + "ld1 {v7.4s}, [%[inptr7]], 16\n" // H0H1H2H3 + "ld1 {v16.4s}, [%[inptr8]], 16\n" // I0I1I2I3 + "ld1 {v17.4s}, [%[inptr9]], 16\n" // J0J1J2J3 + "ld1 {v18.4s}, [%[inptr10]], 16\n" // K0K1K2K3 + "ld1 {v19.4s}, [%[inptr11]], 16\n" // L0L1L2L3 + + "zip1 v8.4s, v0.4s, v1.4s\n" // A0B0A1B1 + "zip2 v9.4s, v0.4s, v1.4s\n" // A2B2A3B3 + "zip1 v10.4s, v2.4s, v3.4s\n" // C0D0C1D1 + "zip2 v11.4s, v2.4s, 
v3.4s\n" // C2D2C3D3 + + "zip1 v12.4s, v4.4s, v5.4s\n" // E0F0E1F1 + "zip2 v13.4s, v4.4s, v5.4s\n" // E2F2E3F3 + "zip1 v14.4s, v6.4s, v7.4s\n" // G0H0G1H1 + "zip2 v15.4s, v6.4s, v7.4s\n" // G2H2G3H3 + + "zip1 v20.4s, v16.4s, v17.4s\n" // I0J0I1J1 + "zip2 v21.4s, v16.4s, v17.4s\n" // I2J2I3J3 + "zip1 v22.4s, v18.4s, v19.4s\n" // K0L0K1L1 + "zip2 v23.4s, v18.4s, v19.4s\n" // K2L2K3L3 + + "zip1 v0.2d, v8.2d, v10.2d\n" // A0B0C0D0 + "zip2 v3.2d, v8.2d, v10.2d\n" // A1B1C1D1 + + "zip1 v6.2d, v9.2d, v11.2d\n" // A2B2C2D2 + "zip2 v24.2d, v9.2d, v11.2d\n" // A3B3C3D3 + + "zip1 v1.2d, v12.2d, v14.2d\n" // E0F0G0H0 + "zip2 v4.2d, v12.2d, v14.2d\n" // E1F1G1H1 + + "zip1 v7.2d, v13.2d, v15.2d\n" // E2F2G2H2 + "zip2 v25.2d, v13.2d, v15.2d\n" // E3F3G3H3 + + "zip1 v2.2d, v20.2d, v22.2d\n" // I0J0K0L0 + "zip2 v5.2d, v20.2d, v22.2d\n" // I1J1K1L1 + + "zip1 v8.2d, v21.2d, v23.2d\n" // I2J2K2L2 + "zip2 v26.2d, v21.2d, v23.2d\n" // I3J3K3L3 + + "st1 {v0.4s,v1.4s,v2.4s}, [%[outptr]], #48\n" + "st1 {v3.4s,v4.4s,v5.4s}, [%[outptr]], #48\n" + "st1 {v6.4s,v7.4s,v8.4s}, [%[outptr]], #48\n" + "st1 {v24.4s,v25.4s,v26.4s}, [%[outptr]], #48\n" + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), + [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), + [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), + [inptr8] "+r"(inptr8), [inptr9] "+r"(inptr9), + [inptr10] "+r"(inptr10), [inptr11] "+r"(inptr11), + [outptr] "+r"(outptr) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "memory"); +} + +template +static inline void transpose_12x4_1_b(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, + T* outptr) { + static_assert( + std::is_same::value || std::is_same::value, + "transpose_12x4_1_b only support uint8_t and int8_t"); + asm volatile( + "ldr q0, [%[inptr0]], #12\n" // A1A2A3A4A5A6A7A8A9A10A11A12A13A14A15A16 + "ldr q1, [%[inptr1]], #12\n" // B1B2B3B4B5B6B7B8B9B10B11B12B13B14B15B16 + "ldr q2, [%[inptr2]], #12\n" // C1C2C3C4C5C6C7C8C9C10C11C12C13C14C15C16 + //! \warning the last inptr3 may less than 16bytes, so we should + //! split read it + "ldr d3, [%[inptr3]], #8\n" // D1D2D3D4D5D6D7D8D9D10D11D12D13D14D15D16 + "ldr w1, [%[inptr3]], #4\n" + "ins v3.s[2], w1\n" + + "trn1 v4.16b, v0.16b, v1.16b\n" // v4: A1B1A3B3.... + "trn2 v5.16b, v0.16b, v1.16b\n" // v5: A2B2A4B4.... + "trn1 v6.16b, v2.16b, v3.16b\n" // v6: C1D1C3D3.... + "trn2 v7.16b, v2.16b, v3.16b\n" // v7: C2D2C4D4.... + + "trn1 v8.8h, v4.8h, v6.8h\n" // v8: A1B1C1D1A5B5C5D5... + "trn2 v9.8h, v4.8h, v6.8h\n" // v9: A3B3C3D3A7B7C7D7... + "trn1 v10.8h, v5.8h, v7.8h\n" // v10: A2B2C2D2A6B6C6D6... + "trn2 v11.8h, v5.8h, v7.8h\n" // v11: A4B4C4D4A8B8C8D8... + + //! ABCD=E then + //! v8: E1E5E9E13 v10: E2E6E10E14 v9: E3E7E11E15 v11: + //! 
E4E8E12E16 + "zip1 v12.4s, v8.4s, v10.4s\n" // v12: E1E2E5E6 + "zip2 v13.4s, v8.4s, v10.4s\n" // v13: E9E10E13E14 + "zip1 v14.4s, v9.4s, v11.4s\n" // v14: E3E4E7E8 + "zip2 v15.4s, v9.4s, v11.4s\n" // v15: E11E12E15E16 + "zip1 v17.2d, v12.2d, v14.2d\n" // v17: E1E2E3E4 + "zip2 v18.2d, v12.2d, v14.2d\n" // v18: E5E6E7E8 + "zip1 v19.2d, v13.2d, v15.2d\n" // v19: E8E10E11E12 + "zip2 v20.2d, v13.2d, v15.2d\n" // v19: E13E14E15E16 + + "stp q17, q18, [%[outptr]], #32\n" + "str q19, [%[outptr]], #16\n" + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), + [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), + [outptr] "+r"(outptr) + : + : "w1", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", + "v19", "memory"); +} + +template +static inline void transpose_8x4_1_b(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, + T* outptr) { + static_assert( + std::is_same::value || std::is_same::value, + "transpose_8x4_1_b only support uint8_t and int8_t"); + asm volatile( + "ld1 {v0.d}[0], [%[inptr0]], #8\n" // A1A2A3A4A5A6A7A8 + "ld1 {v1.d}[0], [%[inptr1]], #8\n" // B1B2B3B4B5B6B7B8 + "ld1 {v0.d}[1], [%[inptr2]], #8\n" // C1C2C3C4C5C6C7C8 + "ld1 {v1.d}[1], [%[inptr3]], #8\n" // D1D2D3D4D5D6D7D8 + + "zip1 v2.16b, v0.16b, v1.16b\n" // A1B1A2B2A3B3A4B4A5B5A6B6A7B7A8B8 + "zip2 v3.16b, v0.16b, v1.16b\n" // C1D1C2D2C3D3C4D4C5D5C6D6C7D7C8D8 + + "zip1 v4.8h, v2.8h, v3.8h\n" // A1B1C1D1A2B2C2D2A3B3C3D3A4B4C4D4 + "zip2 v5.8h, v2.8h, v3.8h\n" // A5B5C5D5A6B6C6D6A7B7C7D7A8B8C8D8 + + "st1 {v4.2d}, [%[outptr]], #16\n" + "st1 {v5.2d}, [%[outptr]], #16\n" + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), + [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), + [outptr] "+r"(outptr) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "memory"); +} + +template +static inline void transpose_8x8_1_b(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, + const T*& inptr6, const T*& inptr7, + T* outptr) { + static_assert( + std::is_same::value || std::is_same::value, + "transpose_8x8_1_b only support uint8_t and int8_t"); + asm volatile( + "ld1 {v0.8b}, [%[inptr0]], #8\n" // A1A2A3A4A5A6A7A8 + "ld1 {v1.8b}, [%[inptr1]], #8\n" // B1B2B3B4B5B6B7B8 + "ld1 {v2.8b}, [%[inptr2]], #8\n" // C1C2C3C4C5C6C7C8 + "ld1 {v3.8b}, [%[inptr3]], #8\n" // D1D2D3D4D5D6D7D8 + "ld1 {v4.8b}, [%[inptr4]], #8\n" // E1E2E3E4E5E6E7E8 + "ld1 {v5.8b}, [%[inptr5]], #8\n" // F1F2F3F4F5F6F7F8 + "ld1 {v6.8b}, [%[inptr6]], #8\n" // G1G2G3G4G5G6G7G8 + "ld1 {v7.8b}, [%[inptr7]], #8\n" // H1H2H3H4H5H6H7H8 + + "zip1 v8.16b, v0.16b, v1.16b\n" // A1B1A2B2A3B3A4B4 + // A5B5A6B6A7B7A8B8 + "zip1 v9.16b, v2.16b, v3.16b\n" // C1D1C2D2C3D3C4D4 + // C5D5C6D6C7D7C8D8 + "zip1 v10.16b, v4.16b, v5.16b\n" // E1F1E2F2E3F3E4F4 + // E5F5E6F6E7F7E8F8 + "zip1 v11.16b, v6.16b, v7.16b\n" // G1H1G2H2G3H3G4H4 + // G5H5G6H6G7H7G8H8 + + "zip1 v12.8h, v8.8h, v9.8h\n" // A1B1C1D1A2B2C2D2 + // A3B3C3D3A4B4C4D4 + "zip1 v13.8h, v10.8h, v11.8h\n" // E1F1G1H1E2F2G2H2 + // E3F3G3H3E4F4G4H4 + "zip2 v14.8h, v8.8h, v9.8h\n" // A5B5C5D5A6B6C6D6 + // A7B7C7D7A8B8C8D8 + "zip2 v15.8h, v10.8h, v11.8h\n" // E5F5G5H5E6F6G6H6 + // E7F7G7H7E8F8G8H8 + + "zip1 v16.4s, v12.4s, v13.4s\n" // A1B1C1D1E1F1G1H1 + // A2B2C2D2E2F2G2H2 + "zip1 v18.4s, v14.4s, v15.4s\n" // A5B5C5D5E5F5G5H5 + // A6B6C6D6E6F6G6H6 + "zip2 v17.4s, v12.4s, v13.4s\n" // A3B3C3D3E3F3G3H3 + // A4B4C4D4E4F4G4H4 + "zip2 v19.4s, v14.4s, v15.4s\n" // A7B7C7D7E7F7G7H7 + // A8B8C8D8E8F8G8H8 + + "st1 {v16.16b}, 
[%[outptr]], #16\n" // A1B1C1D1E1F1G1H1 + // A2B2C2D2E2F2G2H2 + "st1 {v17.16b}, [%[outptr]], #16\n" // A3B3C3D3E3F3G3H3 + // A4B4C4D4E4F4G4H4 + "st1 {v18.16b}, [%[outptr]], #16\n" // A5B5C5D5E5F5G5H5 + // A6B6C6D6E6F6G6H6 + "st1 {v19.16b}, [%[outptr]], #16\n" // A7B7C7D7E7F7G7H7 + // A8B8C8D8E8F8G8H8 + : + [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [outptr] "+r"(outptr) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "memory"); +} + +template +static inline void transpose_4x16_1_b_helper(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, + const T*& inptr6, const T*& inptr7, + T* outptr) { + static_assert(sizeof(T) == 1, "only support size == 1"); + static int8x16_t shuffle_idx = {0, 4, 8, 12, 1, 5, 9, 13, + 2, 6, 10, 14, 3, 7, 11, 15}; + asm volatile( + "ld1 {v0.s}[0], [%[inptr0]], #4\n" + "ld1 {v0.s}[1], [%[inptr1]], #4\n" + "ld1 {v0.s}[2], [%[inptr2]], #4\n" + "ld1 {v0.s}[3], [%[inptr3]], #4\n" + "ld1 {v1.s}[0], [%[inptr4]], #4\n" + "ld1 {v1.s}[1], [%[inptr5]], #4\n" + "ld1 {v1.s}[2], [%[inptr6]], #4\n" + "ld1 {v1.s}[3], [%[inptr7]], #4\n" + + "tbl v2.16b, {v0.16b}, %[shuffle_idx].16b\n" + "tbl v3.16b, {v1.16b}, %[shuffle_idx].16b\n" + + "zip1 v4.4s, v2.4s, v3.4s\n" + "zip2 v5.4s, v2.4s, v3.4s\n" + + "dup v6.2d, v4.d[1]\n" + "dup v7.2d, v5.d[1]\n" + + "str d4, [%[outptr]], #16\n" + "str d6, [%[outptr]], #16\n" + "str d5, [%[outptr]], #16\n" + "str d7, [%[outptr]], #16\n" + + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), + [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), + [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), + [outptr] "+r"(outptr), [shuffle_idx] "+w"(shuffle_idx) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "memory"); +} + +template +static inline void transpose_4(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, T* outptr, + int interleave, int size, T val = 0) { + megdnn_assert(size <= interleave); + int i = 0; + for (; i < size; i++) { + *outptr++ = *inptr0++; + *outptr++ = *inptr1++; + *outptr++ = *inptr2++; + *outptr++ = *inptr3++; + } + for (; i < interleave; i++) { + *outptr++ = val; + *outptr++ = val; + *outptr++ = val; + *outptr++ = val; + } +} + +template +static inline void transpose_8(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, + const T*& inptr6, const T*& inptr7, T* outptr, + int interleave, int size, T val = 0) { + megdnn_assert(size <= interleave); + int i = 0; + for (; i < size; i++) { + *outptr++ = *inptr0++; + *outptr++ = *inptr1++; + *outptr++ = *inptr2++; + *outptr++ = *inptr3++; + *outptr++ = *inptr4++; + *outptr++ = *inptr5++; + *outptr++ = *inptr6++; + *outptr++ = *inptr7++; + } + for (; i < interleave; i++) { + *outptr++ = val; + *outptr++ = val; + *outptr++ = val; + *outptr++ = val; + *outptr++ = val; + *outptr++ = val; + *outptr++ = val; + *outptr++ = val; + } +} +/***************************** Transpose then interleave ********************/ + +//! 
pack form {1, 4(icb), 4(ic), 4(oc)} to {1, 1, 4(oc), 16(ic)} +template <typename T> +static inline void transpose_interleave_4x4_4_b(const T*& inptr0, + const T*& inptr1, + const T*& inptr2, + const T*& inptr3, T* outptr, + int stride = 64) { + static_assert(sizeof(T) == 1, + "transpose_interleave_4x4_4_b only support sizeof(T) == 1"); + + asm volatile( + "ld4 {v0.16b, v1.16b, v2.16b, v3.16b},[%[inptr0]], 64\n" + "ld4 {v4.16b, v5.16b, v6.16b, v7.16b},[%[inptr1]], 64\n" + "ld4 {v8.16b, v9.16b, v10.16b, v11.16b},[%[inptr2]], 64\n" + "ld4 {v12.16b, v13.16b, v14.16b, v15.16b},[%[inptr3]], 64\n" + + "st1 {v0.16b, v1.16b, v2.16b, v3.16b},[%[outptr]], %x[stride]\n" + "st1 {v4.16b, v5.16b, v6.16b, v7.16b},[%[outptr]], %x[stride]\n" + "st1 {v8.16b, v9.16b, v10.16b, v11.16b},[%[outptr]], %x[stride]\n" + "st1 {v12.16b, v13.16b, v14.16b, v15.16b},[%[outptr]], %x[stride]\n" + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), + [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), + [outptr] "+r"(outptr), [stride] "+r"(stride) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "memory"); +} + +template <typename T> +static inline void transpose_interleave_1x4_4_b(const T*& inptr0, T* outptr, + int stride = 64) { + static_assert(sizeof(T) == 1, + "transpose_interleave_1x4_4_b only support sizeof(T) == 1"); + + asm volatile( + "ld4 {v0.16b, v1.16b, v2.16b, v3.16b},[%[inptr0]], 64\n" + "st1 {v0.16b, v1.16b, v2.16b, v3.16b},[%[outptr]], %x[stride]\n" + : + [inptr0] "+r"(inptr0), [outptr] "+r"(outptr), [stride] "+r"(stride) + : + : "v0", "v1", "v2", "v3", "v4", "memory"); +} + +} // namespace aarch64 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/fp16/strategy.cpp b/dnn/src/aarch64/matrix_mul/fp16/strategy.cpp new file mode 100644 index 00000000..76be2eda --- /dev/null +++ b/dnn/src/aarch64/matrix_mul/fp16/strategy.cpp @@ -0,0 +1,2589 @@ +/** + * \file dnn/src/aarch64/matrix_mul/fp16/strategy.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */ + +#include "src/aarch64/matrix_mul/fp16/strategy.h" +#include "src/aarch64/matrix_mul/asm/common.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/utils.h" + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +using namespace megdnn; +using namespace aarch64; +using namespace aarch64::matmul; + +namespace { + +void interleave_8x1(__fp16* out, const __fp16* in, int ldin, int y0, int ymax, + int k0, int kmax) { + __fp16* outptr = out; + const __fp16* inptr = in; + __fp16 zerobuff[24]; + std::memset(zerobuff, 0, sizeof(__fp16) * 24); + + int y = y0; + for (; y + 8 <= ymax; y += 8) { + const __fp16* inptr0 = inptr + y * ldin + k0; + const __fp16* inptr1 = inptr0 + ldin; + const __fp16* inptr2 = inptr1 + ldin; + const __fp16* inptr3 = inptr2 + ldin; + const __fp16* inptr4 = inptr3 + ldin; + const __fp16* inptr5 = inptr4 + ldin; + const __fp16* inptr6 = inptr5 + ldin; + const __fp16* inptr7 = inptr6 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + prefetch_2x(inptr4); + prefetch_2x(inptr5); + prefetch_2x(inptr6); + prefetch_2x(inptr7); + + int x = (kmax - k0); + for (; x > 7; x -= 8) { + int skippf = (x & 31); + interleave_8x1_8_h(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, + inptr6, inptr7, outptr, skippf); + } + + for (; x > 0; x--) { + *outptr++ = *inptr0++; + *outptr++ = *inptr1++; + *outptr++ = *inptr2++; + *outptr++ = *inptr3++; + *outptr++ = *inptr4++; + *outptr++ = *inptr5++; + *outptr++ = *inptr6++; + *outptr++ = *inptr7++; + } + } + + for (; y < ymax; y += 4) { + const __fp16* inptr0 = inptr + y * ldin + k0; + const __fp16* inptr1 = inptr0 + ldin; + const __fp16* inptr2 = inptr1 + ldin; + const __fp16* inptr3 = inptr2 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + + /* Cope with ragged cases by copying from a buffer of zeroes instead + */ + int x = (kmax - k0); + for (; x > 3; x -= 4) { + if ((y + 3) >= ymax) { + switch ((y + 3) - ymax) { + /* Everything falls through in here */ + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + interleave_4x1_4_h(inptr0, inptr1, inptr2, inptr3, outptr); + } + + for (; x > 0; x--) { + *outptr++ = *inptr0++; + *outptr++ = *inptr1++; + *outptr++ = *inptr2++; + *outptr++ = *inptr3++; + } + } +} + +void interleave_24x1(__fp16* out, const __fp16* in, const int ldin, const int y0, + const int ymax, const int k0, const int kmax) { + __fp16* outptr = out; + const __fp16* inptr = in; + __fp16 zerobuff[24]; + std::memset(zerobuff, 0, sizeof(__fp16) * 24); + int K16 = 16 * (kmax - k0); + int K24 = 24 * (kmax - k0); + + int y = y0; + for (; y + 24 <= ymax; y += 24) { + int yi = y; + for (; yi < y + 24; yi += 8) { + const __fp16* inptr0 = inptr + yi * ldin + k0; + const __fp16* inptr1 = inptr0 + ldin; + const __fp16* inptr2 = inptr1 + ldin; + const __fp16* inptr3 = inptr2 + ldin; + const __fp16* inptr4 = inptr3 + ldin; + const __fp16* inptr5 = inptr4 + ldin; + const __fp16* inptr6 = inptr5 + ldin; + const __fp16* inptr7 = inptr6 + ldin; + __fp16* outptr_inner = outptr + yi - y; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + prefetch_2x(inptr4); + prefetch_2x(inptr5); + prefetch_2x(inptr6); + prefetch_2x(inptr7); + + int x = (kmax - k0); + for (; x > 7; x -= 8) { + int skippf = (x & 31); + interleave_24x1_8_h_helper(inptr0, inptr1, inptr2, inptr3, + inptr4, inptr5, inptr6, inptr7, + 
outptr_inner, skippf); + } + for (; x > 0; x--) { + *outptr_inner++ = *inptr0++; + *outptr_inner++ = *inptr1++; + *outptr_inner++ = *inptr2++; + *outptr_inner++ = *inptr3++; + *outptr_inner++ = *inptr4++; + *outptr_inner++ = *inptr5++; + *outptr_inner++ = *inptr6++; + *outptr_inner++ = *inptr7++; + outptr_inner += 16; + } + } + outptr += K24; + } + + for (; y + 16 <= ymax; y += 16) { + int yi = y; + for (; yi < y + 16; yi += 8) { + const __fp16* inptr0 = inptr + yi * ldin + k0; + const __fp16* inptr1 = inptr0 + ldin; + const __fp16* inptr2 = inptr1 + ldin; + const __fp16* inptr3 = inptr2 + ldin; + const __fp16* inptr4 = inptr3 + ldin; + const __fp16* inptr5 = inptr4 + ldin; + const __fp16* inptr6 = inptr5 + ldin; + const __fp16* inptr7 = inptr6 + ldin; + __fp16* outptr_inner = outptr + yi - y; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + prefetch_2x(inptr4); + prefetch_2x(inptr5); + prefetch_2x(inptr6); + prefetch_2x(inptr7); + + int x = (kmax - k0); + for (; x > 7; x -= 8) { + int skippf = (x & 31); + interleave_16x1_8_h_helper(inptr0, inptr1, inptr2, inptr3, + inptr4, inptr5, inptr6, inptr7, + outptr_inner, skippf); + } + for (; x > 0; x--) { + *outptr_inner++ = *inptr0++; + *outptr_inner++ = *inptr1++; + *outptr_inner++ = *inptr2++; + *outptr_inner++ = *inptr3++; + *outptr_inner++ = *inptr4++; + *outptr_inner++ = *inptr5++; + *outptr_inner++ = *inptr6++; + *outptr_inner++ = *inptr7++; + outptr_inner += 8; + } + } + outptr += K16; + } + + for (; y + 8 <= ymax; y += 8) { + const __fp16* inptr0 = inptr + y * ldin + k0; + const __fp16* inptr1 = inptr0 + ldin; + const __fp16* inptr2 = inptr1 + ldin; + const __fp16* inptr3 = inptr2 + ldin; + const __fp16* inptr4 = inptr3 + ldin; + const __fp16* inptr5 = inptr4 + ldin; + const __fp16* inptr6 = inptr5 + ldin; + const __fp16* inptr7 = inptr6 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + prefetch_2x(inptr4); + prefetch_2x(inptr5); + prefetch_2x(inptr6); + prefetch_2x(inptr7); + + int x = (kmax - k0); + for (; x > 7; x -= 8) { + int skippf = (x & 31); + interleave_8x1_8_h(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, + inptr6, inptr7, outptr, skippf); + } + + for (; x > 0; x--) { + *outptr++ = *inptr0++; + *outptr++ = *inptr1++; + *outptr++ = *inptr2++; + *outptr++ = *inptr3++; + *outptr++ = *inptr4++; + *outptr++ = *inptr5++; + *outptr++ = *inptr6++; + *outptr++ = *inptr7++; + } + } + + for (; y < ymax; y += 4) { + const __fp16* inptr0 = inptr + y * ldin + k0; + const __fp16* inptr1 = inptr0 + ldin; + const __fp16* inptr2 = inptr1 + ldin; + const __fp16* inptr3 = inptr2 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + + /* Cope with ragged cases by copying from a buffer of zeroes instead + */ + int x = (kmax - k0); + for (; x > 3; x -= 4) { + if ((y + 3) >= ymax) { + switch ((y + 3) - ymax) { + /* Everything falls through in here */ + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + interleave_4x1_4_h(inptr0, inptr1, inptr2, inptr3, outptr); + } + + for (; x > 0; x--) { + *outptr++ = *inptr0++; + *outptr++ = *inptr1++; + *outptr++ = *inptr2++; + *outptr++ = *inptr3++; + } + } +} + +void transpose_1x8(__fp16* out, const __fp16* in, int ldin, int x0, int xmax, + int k0, int kmax) { + int ksize = kmax - k0; + int ksize8 = (ksize << 3); + int ksize4 = (ksize << 2); + int k = ksize; + __fp16* 
outptr_base8 = out; + __fp16* outptr_base4 = out; + + const __fp16* inptr_base = in + x0 + k0 * ldin; + + for (; k > 3; k -= 4) { + __fp16* outptr = outptr_base8; + + const __fp16* inptr = inptr_base; + const __fp16* inptr1 = inptr + ldin; + const __fp16* inptr2 = inptr1 + ldin; + const __fp16* inptr3 = inptr2 + ldin; + + prefetch_3x(inptr); + prefetch_3x(inptr1); + prefetch_3x(inptr2); + prefetch_3x(inptr3); + + int x = x0; + for (; x + 8 <= xmax; x += 8) { + transpose_8x4_1_h(inptr, inptr1, inptr2, inptr3, outptr); + outptr += ksize8; + } + outptr += outptr_base4 - outptr_base8; + for (; x < xmax; x += 4) { + for (int i = 0; i < 4; i++) { + __fp16 val = (x + i < xmax) ? (*inptr++) : (__fp16)(0); + *outptr++ = val; + } + for (int i = 0; i < 4; i++) { + __fp16 val = (x + i < xmax) ? (*inptr1++) : (__fp16)(0); + *outptr++ = val; + } + for (int i = 0; i < 4; i++) { + __fp16 val = (x + i < xmax) ? (*inptr2++) : (__fp16)(0); + *outptr++ = val; + } + for (int i = 0; i < 4; i++) { + __fp16 val = (x + i < xmax) ? (*inptr3++) : (__fp16)(0); + *outptr++ = val; + } + outptr -= 16; + outptr += ksize4; + } + + inptr_base += ldin * 4; + outptr_base8 += 8 * 4; + outptr_base4 += 4 * 4; + } + + if (k) { + __fp16* outptr = outptr_base8; + const __fp16* inptr = inptr_base; + const __fp16* inptr1 = inptr + ldin; + const __fp16* inptr2 = inptr1 + ldin; + + prefetch_3x(inptr); + prefetch_3x(inptr1); + prefetch_3x(inptr2); + + int x = x0; + for (; x + 8 <= xmax; x += 8) { + switch (k) { + case 3: + transpose_8x2_1_h(inptr, inptr1, outptr); + transpose_8x1_1_h(inptr2, outptr + 8 * 2); + break; + + case 2: + transpose_8x2_1_h(inptr, inptr1, outptr); + break; + + case 1: + transpose_8x1_1_h(inptr, outptr); + break; + + default: + megdnn_assert(0); + } + outptr += ksize8; + } + + outptr += outptr_base4 - outptr_base8; + for (; x < xmax; x += 4) { + switch (k) { + case 3: + for (int i = 0; i < 4; i++) { + __fp16 val = (x + i < xmax) ? (*inptr++) : (__fp16)(0); + *outptr++ = val; + } + for (int i = 0; i < 4; i++) { + __fp16 val = (x + i < xmax) ? (*inptr1++) : (__fp16)(0); + *outptr++ = val; + } + for (int i = 0; i < 4; i++) { + __fp16 val = (x + i < xmax) ? (*inptr2++) : (__fp16)(0); + *outptr++ = val; + } + outptr -= 12; + break; + case 2: + for (int i = 0; i < 4; i++) { + __fp16 val = (x + i < xmax) ? (*inptr++) : (__fp16)(0); + *outptr++ = val; + } + for (int i = 0; i < 4; i++) { + __fp16 val = (x + i < xmax) ? (*inptr1++) : (__fp16)(0); + *outptr++ = val; + } + outptr -= 8; + break; + + case 1: + for (int i = 0; i < 4; i++) { + __fp16 val = (x + i < xmax) ? 
(*inptr++) : (__fp16)(0); + *outptr++ = val; + } + outptr -= 4; + break; + + default: + megdnn_assert(0); + } + outptr += ksize4; + } + } +} + +void transpose_1x24(__fp16* out, const __fp16* in, const int ldin, const int x0, + const int xmax, const int k0, const int kmax) { + int ksize = kmax - k0; + int ksize24 = ksize * 24; + int ksize16 = (ksize << 4); + int ksize8 = (ksize << 3); + int ksize4 = (ksize << 2); + int k = ksize; + __fp16* outptr_base = out; + __fp16* outptr_base16 = out; + __fp16* outptr_base8 = out; + __fp16* outptr_base4 = out; + + const __fp16* inptr_base = in + x0 + k0 * ldin; + + for (; k > 3; k -= 4) { + __fp16* outptr = outptr_base; + + const __fp16* inptr = inptr_base; + const __fp16* inptr1 = inptr + ldin; + const __fp16* inptr2 = inptr1 + ldin; + const __fp16* inptr3 = inptr2 + ldin; + + prefetch_3x(inptr); + prefetch_3x(inptr1); + prefetch_3x(inptr2); + prefetch_3x(inptr3); + + int x = x0; + for (; x + 24 <= xmax; x += 24) { + transpose_24x4_1_h(inptr, inptr1, inptr2, inptr3, outptr); + outptr += ksize24; + } + outptr += outptr_base16 - outptr_base; + for (; x + 16 <= xmax; x += 16) { + transpose_16x4_1_h(inptr, inptr1, inptr2, inptr3, outptr); + outptr += ksize16; + } + outptr += outptr_base8 - outptr_base16; + for (; x + 8 <= xmax; x += 8) { + transpose_8x4_1_h(inptr, inptr1, inptr2, inptr3, outptr); + outptr += ksize8; + } + outptr += outptr_base4 - outptr_base8; + for (; x < xmax; x += 4) { + for (int i = 0; i < 4; i++) { + __fp16 val = (x + i < xmax) ? (*inptr++) : (__fp16)(0); + *outptr++ = val; + } + for (int i = 0; i < 4; i++) { + __fp16 val = (x + i < xmax) ? (*inptr1++) : (__fp16)(0); + *outptr++ = val; + } + for (int i = 0; i < 4; i++) { + __fp16 val = (x + i < xmax) ? (*inptr2++) : (__fp16)(0); + *outptr++ = val; + } + for (int i = 0; i < 4; i++) { + __fp16 val = (x + i < xmax) ? 
(*inptr3++) : (__fp16)(0); + *outptr++ = val; + } + outptr -= 16; + outptr += ksize4; + } + + inptr_base += ldin * 4; + outptr_base += 24 * 4; + outptr_base16 += 16 * 4; + outptr_base8 += 8 * 4; + outptr_base4 += 4 * 4; + } + + if (k) { + __fp16* outptr = outptr_base; + const __fp16* inptr = inptr_base; + const __fp16* inptr1 = inptr + ldin; + const __fp16* inptr2 = inptr1 + ldin; + + prefetch_3x(inptr); + prefetch_3x(inptr1); + prefetch_3x(inptr2); + + int x = x0; + for (; x + 24 <= xmax; x += 24) { + switch (k) { + case 3: + transpose_24x2_1_h(inptr, inptr1, outptr); + transpose_24x1_1_h(inptr2, outptr + 24 * 2); + break; + + case 2: + transpose_24x2_1_h(inptr, inptr1, outptr); + break; + + case 1: + transpose_24x1_1_h(inptr, outptr); + break; + + default: + megdnn_assert(0); + } + outptr += ksize24; + } + + outptr += outptr_base16 - outptr_base; + for (; x + 16 <= xmax; x += 16) { + switch (k) { + case 3: + transpose_16x2_1_h(inptr, inptr1, outptr); + transpose_16x1_1_h(inptr2, outptr + 16 * 2); + break; + + case 2: + transpose_16x2_1_h(inptr, inptr1, outptr); + break; + + case 1: + transpose_16x1_1_h(inptr, outptr); + break; + + default: + megdnn_assert(0); + } + outptr += ksize16; + } + + outptr += outptr_base8 - outptr_base16; + for (; x + 8 <= xmax; x += 8) { + switch (k) { + case 3: + transpose_8x2_1_h(inptr, inptr1, outptr); + transpose_8x1_1_h(inptr2, outptr + 8 * 2); + break; + + case 2: + transpose_8x2_1_h(inptr, inptr1, outptr); + break; + + case 1: + transpose_8x1_1_h(inptr, outptr); + break; + + default: + megdnn_assert(0); + } + outptr += ksize8; + } + + outptr += outptr_base4 - outptr_base8; + for (; x < xmax; x += 4) { + switch (k) { + case 3: + for (int i = 0; i < 4; i++) { + __fp16 val = (x + i < xmax) ? (*inptr++) : (__fp16)(0); + *outptr++ = val; + } + for (int i = 0; i < 4; i++) { + __fp16 val = (x + i < xmax) ? (*inptr1++) : (__fp16)(0); + *outptr++ = val; + } + for (int i = 0; i < 4; i++) { + __fp16 val = (x + i < xmax) ? (*inptr2++) : (__fp16)(0); + *outptr++ = val; + } + outptr -= 12; + break; + case 2: + for (int i = 0; i < 4; i++) { + __fp16 val = (x + i < xmax) ? (*inptr++) : (__fp16)(0); + *outptr++ = val; + } + for (int i = 0; i < 4; i++) { + __fp16 val = (x + i < xmax) ? (*inptr1++) : (__fp16)(0); + *outptr++ = val; + } + outptr -= 8; + break; + + case 1: + for (int i = 0; i < 4; i++) { + __fp16 val = (x + i < xmax) ? (*inptr++) : (__fp16)(0); + *outptr++ = val; + } + outptr -= 4; + break; + + default: + megdnn_assert(0); + } + outptr += ksize4; + } + } +} + +// Overview of register layout: +// +// A 2x24 cell of Rhs is stored in 16bit in q2-q7. +// A 8x2 cell of Lhs is stored in 16bit in q0-q1 +// A 8x24 block of accumulators is stored in 16bit in q8--q31. 
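+// K is consumed two steps per loop iteration: v0 (a0) holds the packed Lhs column for
+// the current k-step and v1 (a0a) the next one, while v2-v4 (b0-b2) and v5-v7 (b0a-b2a)
+// hold the matching 8-wide Rhs vectors. Each fmla broadcasts a single Lhs half against
+// an 8-wide Rhs vector, so q8-q31 together form the 8x24 accumulator tile of C.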
+// +// +--------+--------+--------+ +// | v2[0-7]| v3[0-7]| v4[0-7]| +// Rhs +--------+--------+--------+ +// | v5[0-7]| v6[0-7]| v7[0-7]| +// +--------+--------+--------+ +// +// | | | | +// +// Lhs | | | | +// +// +--+--+ - - - - +--------+--------+--------+ +// |v0|v1| | v8[0-7]|v16[0-7]|v24[0-7]| +// |v0|v1| | v9[0-7]|v17[0-7]|v25[0-7]| +// |v0|v1| |v10[0-7]|v18[0-7]|v26[0-7]| +// |v0|v1| |v11[0-7]|v19[0-7]|v27[0-7]| +// |v0|v1| |v12[0-7]|v20[0-7]|v28[0-7]| +// |v0|v1| |v13[0-7]|v21[0-7]|v29[0-7]| +// |v0|v1| |v14[0-7]|v22[0-7]|v30[0-7]| +// |v0|v1| |v15[0-7]|v23[0-7]|v31[0-7]| +// +--+--+ - - - - +--------+--------+--------+ +// +// Accumulator + +void aarch64_hgemm_assembly_kernel_24x8(const __fp16* a_ptr, + const __fp16*& b_ptr, int K, + __fp16* outptr0, int ldout, int type) { + int oddk = (K & 1); + int k = ((K + 1) / 2) - 1; + + register float16x8_t a0 asm("v0"); + register float16x8_t a0a asm("v1"); + register float16x8_t b0 asm("v2"); + register float16x8_t b1 asm("v3"); + register float16x8_t b2 asm("v4"); + register float16x8_t b0a asm("v5"); + register float16x8_t b1a asm("v6"); + register float16x8_t b2a asm("v7"); + + __fp16* outptr1 = outptr0 + ldout; + __fp16* outptr2 = outptr1 + ldout; + __fp16* outptr3 = outptr2 + ldout; + __fp16* outptr4 = outptr3 + ldout; + __fp16* outptr5 = outptr4 + ldout; + __fp16* outptr6 = outptr5 + ldout; + __fp16* outptr7 = outptr6 + ldout; + + asm volatile( + ".arch armv8.2-a+fp16\n" + + // load accumulator C + "cmp %w[type], #0\n" + "beq 5f\n" + "ldp q8, q16, [%[outptr0]]\n" + "ldr q24, [%[outptr0], #32]\n" + "ldp q9, q17, [%[outptr1]]\n" + "ldr q25, [%[outptr1], #32]\n" + "ldp q10, q18, [%[outptr2]]\n" + "ldr q26, [%[outptr2], #32]\n" + "ldp q11, q19, [%[outptr3]]\n" + "ldr q27, [%[outptr3], #32]\n" + "ldp q12, q20, [%[outptr4]]\n" + "ldr q28, [%[outptr4], #32]\n" + "ldp q13, q21, [%[outptr5]]\n" + "ldr q29, [%[outptr5], #32]\n" + "ldp q14, q22, [%[outptr6]]\n" + "ldr q30, [%[outptr6], #32]\n" + "ldp q15, q23, [%[outptr7]]\n" + "ldr q31, [%[outptr7], #32]\n" + "b 6f\n" + + "5:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "eor v9.16b, v9.16b, v9.16b\n" + "eor v10.16b, v10.16b, v10.16b\n" + "eor v11.16b, v11.16b, v11.16b\n" + "eor v12.16b, v12.16b, v12.16b\n" + "eor v13.16b, v13.16b, v13.16b\n" + "eor v14.16b, v14.16b, v14.16b\n" + "eor v15.16b, v15.16b, v15.16b\n" + + "eor v16.16b, v16.16b, v16.16b\n" + "eor v17.16b, v17.16b, v17.16b\n" + "eor v18.16b, v18.16b, v18.16b\n" + "eor v19.16b, v19.16b, v19.16b\n" + "eor v20.16b, v20.16b, v20.16b\n" + "eor v21.16b, v21.16b, v21.16b\n" + "eor v22.16b, v22.16b, v22.16b\n" + "eor v23.16b, v23.16b, v23.16b\n" + + "eor v24.16b, v24.16b, v24.16b\n" + "eor v25.16b, v25.16b, v25.16b\n" + "eor v26.16b, v26.16b, v26.16b\n" + "eor v27.16b, v27.16b, v27.16b\n" + "eor v28.16b, v28.16b, v28.16b\n" + "eor v29.16b, v29.16b, v29.16b\n" + "eor v30.16b, v30.16b, v30.16b\n" + "eor v31.16b, v31.16b, v31.16b\n" + + "6:\n" + "ldr %q[a0], [%[a_ptr]]\n" + "ldr %q[b0], [%[b_ptr]]\n" + "ldr %q[b1], [%[b_ptr], #16]\n" + "ldr %q[b2], [%[b_ptr], #32]\n" + "ldr %q[b0a], [%[b_ptr], #48]\n" + "ldr %q[b1a], [%[b_ptr], #64]\n" + + ASM_PREFETCH("[%[b_ptr], #64]") + ASM_PREFETCH("[%[b_ptr], #128]") + ASM_PREFETCH("[%[b_ptr], #192]") + ASM_PREFETCH("[%[b_ptr], #256]") + ASM_PREFETCH("[%[b_ptr], #320]") + + "cbz %w[k], 4f\n" + + "1:\n" + "fmla v8.8h , %[b0].8h, %[a0].h[0]\n" + "fmla v9.8h , %[b0].8h, %[a0].h[1]\n" + "ldr %q[a0a], [%[a_ptr], #16]\n" + "fmla v10.8h, %[b0].8h, %[a0].h[2]\n" + "fmla v11.8h, %[b0].8h, %[a0].h[3]\n" + "ldr %q[b2a], 
[%[b_ptr], #80]\n" + "fmla v12.8h, %[b0].8h, %[a0].h[4]\n" + "fmla v13.8h, %[b0].8h, %[a0].h[5]\n" + "fmla v14.8h, %[b0].8h, %[a0].h[6]\n" + "fmla v15.8h, %[b0].8h, %[a0].h[7]\n" + "ldr %q[b0], [%[b_ptr], #96]\n" + + "fmla v16.8h, %[b1].8h, %[a0].h[0]\n" + "fmla v17.8h, %[b1].8h, %[a0].h[1]\n" + ASM_PREFETCH("[%[a_ptr], #128]") + "fmla v18.8h, %[b1].8h, %[a0].h[2]\n" + "fmla v19.8h, %[b1].8h, %[a0].h[3]\n" + "add %[b_ptr], %[b_ptr], #96\n" + "fmla v20.8h, %[b1].8h, %[a0].h[4]\n" + "fmla v21.8h, %[b1].8h, %[a0].h[5]\n" + "fmla v22.8h, %[b1].8h, %[a0].h[6]\n" + "fmla v23.8h, %[b1].8h, %[a0].h[7]\n" + "ldr %q[b1], [%[b_ptr], #16]\n" + + "fmla v24.8h, %[b2].8h, %[a0].h[0]\n" + "fmla v25.8h, %[b2].8h, %[a0].h[1]\n" + ASM_PREFETCH("[%[b_ptr], #288]") + "fmla v26.8h, %[b2].8h, %[a0].h[2]\n" + "fmla v27.8h, %[b2].8h, %[a0].h[3]\n" + "fmla v28.8h, %[b2].8h, %[a0].h[4]\n" + "fmla v29.8h, %[b2].8h, %[a0].h[5]\n" + "fmla v30.8h, %[b2].8h, %[a0].h[6]\n" + "fmla v31.8h, %[b2].8h, %[a0].h[7]\n" + "ldr %q[a0], [%[a_ptr], #32]\n" + + "fmla v8.8h , %[b0a].8h, %[a0a].h[0]\n" + "fmla v9.8h , %[b0a].8h, %[a0a].h[1]\n" + "ldr %q[b2], [%[b_ptr], #32]\n" + + "fmla v10.8h, %[b0a].8h, %[a0a].h[2]\n" + "fmla v11.8h, %[b0a].8h, %[a0a].h[3]\n" + "fmla v12.8h, %[b0a].8h, %[a0a].h[4]\n" + "fmla v13.8h, %[b0a].8h, %[a0a].h[5]\n" + "fmla v14.8h, %[b0a].8h, %[a0a].h[6]\n" + "fmla v15.8h, %[b0a].8h, %[a0a].h[7]\n" + "ldr %q[b0a], [%[b_ptr], #48]\n" + + "fmla v16.8h, %[b1a].8h, %[a0a].h[0]\n" + "fmla v17.8h, %[b1a].8h, %[a0a].h[1]\n" + ASM_PREFETCH("[%[b_ptr], #352]") + "fmla v18.8h, %[b1a].8h, %[a0a].h[2]\n" + "fmla v19.8h, %[b1a].8h, %[a0a].h[3]\n" + "fmla v20.8h, %[b1a].8h, %[a0a].h[4]\n" + "fmla v21.8h, %[b1a].8h, %[a0a].h[5]\n" + "fmla v22.8h, %[b1a].8h, %[a0a].h[6]\n" + "fmla v23.8h, %[b1a].8h, %[a0a].h[7]\n" + "ldr %q[b1a], [%[b_ptr], #64]\n" + + "fmla v24.8h, %[b2a].8h, %[a0a].h[0]\n" + "fmla v25.8h, %[b2a].8h, %[a0a].h[1]\n" + "add %[a_ptr], %[a_ptr], #32\n" + "fmla v26.8h, %[b2a].8h, %[a0a].h[2]\n" + "fmla v27.8h, %[b2a].8h, %[a0a].h[3]\n" + "fmla v28.8h, %[b2a].8h, %[a0a].h[4]\n" + "fmla v29.8h, %[b2a].8h, %[a0a].h[5]\n" + "subs %w[k], %w[k], #1\n" + "fmla v30.8h, %[b2a].8h, %[a0a].h[6]\n" + "fmla v31.8h, %[b2a].8h, %[a0a].h[7]\n" + + "bne 1b\n" + "4:\n" + // Jump to odd tail if necessary. 
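+            // K was consumed two steps per iteration above; an odd K leaves one step
+            // for the odd tail at label 2, while an even K finishes with the two-step
+            // even tail below. Both tails fuse the final FMAs with the stores of C.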
+ "cbnz %w[oddk], 2f\n" + + // Even tail + "fmla v8.8h , %[b0].8h, %[a0].h[0]\n" + "fmla v9.8h , %[b0].8h, %[a0].h[1]\n" + "ldr %q[a0a], [%[a_ptr], #16]\n" + "fmla v10.8h, %[b0].8h, %[a0].h[2]\n" + "fmla v11.8h, %[b0].8h, %[a0].h[3]\n" + "ldr %q[b2a], [%[b_ptr], #80]\n" + "fmla v12.8h, %[b0].8h, %[a0].h[4]\n" + "fmla v13.8h, %[b0].8h, %[a0].h[5]\n" + "fmla v14.8h, %[b0].8h, %[a0].h[6]\n" + "fmla v15.8h, %[b0].8h, %[a0].h[7]\n" + + "fmla v16.8h, %[b1].8h, %[a0].h[0]\n" + "fmla v17.8h, %[b1].8h, %[a0].h[1]\n" + "add %[b_ptr], %[b_ptr], #96\n" + "fmla v18.8h, %[b1].8h, %[a0].h[2]\n" + "fmla v19.8h, %[b1].8h, %[a0].h[3]\n" + "fmla v20.8h, %[b1].8h, %[a0].h[4]\n" + "fmla v21.8h, %[b1].8h, %[a0].h[5]\n" + "add %[a_ptr], %[a_ptr], #32\n" + "fmla v22.8h, %[b1].8h, %[a0].h[6]\n" + "fmla v23.8h, %[b1].8h, %[a0].h[7]\n" + + "fmla v24.8h, %[b2].8h, %[a0].h[0]\n" + "fmla v25.8h, %[b2].8h, %[a0].h[1]\n" + "fmla v26.8h, %[b2].8h, %[a0].h[2]\n" + "fmla v27.8h, %[b2].8h, %[a0].h[3]\n" + "fmla v28.8h, %[b2].8h, %[a0].h[4]\n" + "fmla v29.8h, %[b2].8h, %[a0].h[5]\n" + "fmla v30.8h, %[b2].8h, %[a0].h[6]\n" + "fmla v31.8h, %[b2].8h, %[a0].h[7]\n" + + "fmla v8.8h , %[b0a].8h, %[a0a].h[0]\n" + "fmla v16.8h, %[b1a].8h, %[a0a].h[0]\n" + "str q8, [%[outptr0]]\n" + "fmla v24.8h, %[b2a].8h, %[a0a].h[0]\n" + "str q16, [%[outptr0], #16]\n" + + "fmla v9.8h , %[b0a].8h, %[a0a].h[1]\n" + "str q24, [%[outptr0], #32]\n" + "fmla v17.8h, %[b1a].8h, %[a0a].h[1]\n" + "str q9, [%[outptr1]]\n" + "fmla v25.8h, %[b2a].8h, %[a0a].h[1]\n" + "str q17, [%[outptr1], #16]\n" + + "fmla v10.8h, %[b0a].8h, %[a0a].h[2]\n" + "str q25, [%[outptr1], #32]\n" + "fmla v18.8h, %[b1a].8h, %[a0a].h[2]\n" + "str q10, [%[outptr2]]\n" + "fmla v26.8h, %[b2a].8h, %[a0a].h[2]\n" + "str q18, [%[outptr2], #16]\n" + + "fmla v11.8h, %[b0a].8h, %[a0a].h[3]\n" + "str q26, [%[outptr2], #32]\n" + "fmla v19.8h, %[b1a].8h, %[a0a].h[3]\n" + "str q11, [%[outptr3]]\n" + "fmla v27.8h, %[b2a].8h, %[a0a].h[3]\n" + "str q19, [%[outptr3], #16]\n" + + "fmla v12.8h, %[b0a].8h, %[a0a].h[4]\n" + "str q27, [%[outptr3], #32]\n" + "fmla v20.8h, %[b1a].8h, %[a0a].h[4]\n" + "str q12, [%[outptr4]]\n" + "fmla v28.8h, %[b2a].8h, %[a0a].h[4]\n" + "str q20, [%[outptr4], #16]\n" + + "fmla v13.8h, %[b0a].8h, %[a0a].h[5]\n" + "str q28, [%[outptr4], #32]\n" + "fmla v21.8h, %[b1a].8h, %[a0a].h[5]\n" + "str q13, [%[outptr5]]\n" + "fmla v29.8h, %[b2a].8h, %[a0a].h[5]\n" + "str q21, [%[outptr5], #16]\n" + + "fmla v14.8h, %[b0a].8h, %[a0a].h[6]\n" + "str q29, [%[outptr5], #32]\n" + "fmla v22.8h, %[b1a].8h, %[a0a].h[6]\n" + "str q14, [%[outptr6]]\n" + "fmla v30.8h, %[b2a].8h, %[a0a].h[6]\n" + "str q22, [%[outptr6], #16]\n" + + "fmla v15.8h, %[b0a].8h, %[a0a].h[7]\n" + "str q30, [%[outptr6], #32]\n" + "fmla v23.8h, %[b1a].8h, %[a0a].h[7]\n" + "str q15, [%[outptr7]]\n" + "fmla v31.8h, %[b2a].8h, %[a0a].h[7]\n" + "b 3f\n" + + // Odd tail + "2:\n" + "add %[a_ptr], %[a_ptr], #16\n" + "fmla v8.8h , %[b0].8h, %[a0].h[0]\n" + "add %[b_ptr], %[b_ptr], #48\n" + "fmla v16.8h, %[b1].8h, %[a0].h[0]\n" + "str q8, [%[outptr0]]\n" + "fmla v24.8h, %[b2].8h, %[a0].h[0]\n" + "str q16, [%[outptr0], #16]\n" + + "fmla v9.8h , %[b0].8h, %[a0].h[1]\n" + "str q24, [%[outptr0], #32]\n" + "fmla v17.8h, %[b1].8h, %[a0].h[1]\n" + "str q9, [%[outptr1]]\n" + "fmla v25.8h, %[b2].8h, %[a0].h[1]\n" + "str q17, [%[outptr1], #16]\n" + + "fmla v10.8h, %[b0].8h, %[a0].h[2]\n" + "str q25, [%[outptr1], #32]\n" + "fmla v18.8h, %[b1].8h, %[a0].h[2]\n" + "str q10, [%[outptr2]]\n" + "fmla v26.8h, %[b2].8h, %[a0].h[2]\n" + "str q18, 
[%[outptr2], #16]\n" + + "fmla v11.8h, %[b0].8h, %[a0].h[3]\n" + "str q26, [%[outptr2], #32]\n" + "fmla v19.8h, %[b1].8h, %[a0].h[3]\n" + "str q11, [%[outptr3]]\n" + "fmla v27.8h, %[b2].8h, %[a0].h[3]\n" + "str q19, [%[outptr3], #16]\n" + + "fmla v12.8h, %[b0].8h, %[a0].h[4]\n" + "str q27, [%[outptr3], #32]\n" + "fmla v20.8h, %[b1].8h, %[a0].h[4]\n" + "str q12, [%[outptr4]]\n" + "fmla v28.8h, %[b2].8h, %[a0].h[4]\n" + "str q20, [%[outptr4], #16]\n" + + "fmla v13.8h, %[b0].8h, %[a0].h[5]\n" + "str q28, [%[outptr4], #32]\n" + "fmla v21.8h, %[b1].8h, %[a0].h[5]\n" + "str q13, [%[outptr5]]\n" + "fmla v29.8h, %[b2].8h, %[a0].h[5]\n" + "str q21, [%[outptr5], #16]\n" + + "fmla v14.8h, %[b0].8h, %[a0].h[6]\n" + "str q29, [%[outptr5], #32]\n" + "fmla v22.8h, %[b1].8h, %[a0].h[6]\n" + "str q14, [%[outptr6]]\n" + "fmla v30.8h, %[b2].8h, %[a0].h[6]\n" + "str q22, [%[outptr6], #16]\n" + + "fmla v15.8h, %[b0].8h, %[a0].h[7]\n" + "str q30, [%[outptr6], #32]\n" + "fmla v23.8h, %[b1].8h, %[a0].h[7]\n" + "str q15, [%[outptr7]]\n" + "fmla v31.8h, %[b2].8h, %[a0].h[7]\n" + + "3:\n" + "str q23, [%[outptr7], #16]\n" + "str q31, [%[outptr7], #32]\n" + : [a0] "+w"(a0), [a0a] "+w"(a0a), [b0] "+w"(b0), [b1] "+w"(b1), + [b2] "+w"(b2), [k] "+r"(k), [b0a] "+w"(b0a), [b1a] "+w"(b1a), + [b2a] "+w"(b2a), [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), + [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1), + [outptr2] "+r"(outptr2), [outptr3] "+r"(outptr3), + [outptr4] "+r"(outptr4), [outptr5] "+r"(outptr5), + [outptr6] "+r"(outptr6), [outptr7] "+r"(outptr7) + : [oddk] "r"(oddk), [type] "r"(type) + : "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", + "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory"); +} + +// Overview of register layout: +// +// A 2x16 cell of Rhs is stored in 16bit in q2,q3,q5,q6. +// A 8x2 cell of Lhs is stored in 16bit in q0-q1 +// A 8x16 block of accumulators is stored in 16bit in q8-q15, q16-q23. 
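+// As in the 24x8 kernel, the `type` argument selects how the accumulators start:
+// type == 0 clears q8-q23 with eor, any other value reloads the existing C tile so
+// that successive K-blocks of the packed panels accumulate in place.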
+// +// +--------+--------+ +// | v2[0-7]| v3[0-7]| +// Rhs +--------+--------+ +// | v5[0-7]| v6[0-7]| +// +--------+--------+ +// +// | | | +// +// Lhs | | | +// +// +--+--+ - - - - +--------+--------+ +// |v0|v1| | v8[0-7]|v16[0-7]| +// |v0|v1| | v9[0-7]|v17[0-7]| +// |v0|v1| |v10[0-7]|v18[0-7]| +// |v0|v1| |v11[0-7]|v19[0-7]| +// |v0|v1| |v12[0-7]|v20[0-7]| +// |v0|v1| |v13[0-7]|v21[0-7]| +// |v0|v1| |v14[0-7]|v22[0-7]| +// |v0|v1| |v15[0-7]|v23[0-7]| +// +--+--+ - - - - +--------+--------+ +// +// Accumulator +void aarch64_hgemm_assembly_kernel_16x8(const __fp16* a_ptr, + const __fp16*& b_ptr, int K, + __fp16* outptr0, int ldout, int type) { + int oddk = (K & 1); + int k = ((K + 1) / 2) - 1; + + register float16x8_t a0 asm("v0"); + register float16x8_t a0a asm("v1"); + register float16x8_t b0 asm("v2"); + register float16x8_t b1 asm("v3"); + register float16x8_t b0a asm("v5"); + register float16x8_t b1a asm("v6"); + + __fp16* outptr1 = outptr0 + ldout; + __fp16* outptr2 = outptr1 + ldout; + __fp16* outptr3 = outptr2 + ldout; + __fp16* outptr4 = outptr3 + ldout; + __fp16* outptr5 = outptr4 + ldout; + __fp16* outptr6 = outptr5 + ldout; + __fp16* outptr7 = outptr6 + ldout; + + asm volatile( + ".arch armv8.2-a+fp16\n" + + // load accumulator C + "cmp %w[type], #0\n" + "beq 5f\n" + "ldp q8, q16, [%[outptr0]]\n" + "ldp q9, q17, [%[outptr1]]\n" + "ldp q10, q18, [%[outptr2]]\n" + "ldp q11, q19, [%[outptr3]]\n" + "ldp q12, q20, [%[outptr4]]\n" + "ldp q13, q21, [%[outptr5]]\n" + "ldp q14, q22, [%[outptr6]]\n" + "ldp q15, q23, [%[outptr7]]\n" + "b 6f\n" + + "5:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "eor v9.16b, v9.16b, v9.16b\n" + "eor v10.16b, v10.16b, v10.16b\n" + "eor v11.16b, v11.16b, v11.16b\n" + "eor v12.16b, v12.16b, v12.16b\n" + "eor v13.16b, v13.16b, v13.16b\n" + "eor v14.16b, v14.16b, v14.16b\n" + "eor v15.16b, v15.16b, v15.16b\n" + + "eor v16.16b, v16.16b, v16.16b\n" + "eor v17.16b, v17.16b, v17.16b\n" + "eor v18.16b, v18.16b, v18.16b\n" + "eor v19.16b, v19.16b, v19.16b\n" + "eor v20.16b, v20.16b, v20.16b\n" + "eor v21.16b, v21.16b, v21.16b\n" + "eor v22.16b, v22.16b, v22.16b\n" + "eor v23.16b, v23.16b, v23.16b\n" + + "6:\n" + "ldr %q[a0], [%[a_ptr]]\n" + "ldr %q[b0], [%[b_ptr]]\n" + "ldr %q[b1], [%[b_ptr], #16]\n" + "ldr %q[b0a], [%[b_ptr], #32]\n" + "ldr %q[b1a], [%[b_ptr], #48]\n" + + "cbz %w[k], 4f\n" + + "1:\n" + "fmla v8.8h , %[b0].8h, %[a0].h[0]\n" + "fmla v9.8h , %[b0].8h, %[a0].h[1]\n" + "ldr %q[a0a], [%[a_ptr], #16]\n" + "fmla v10.8h, %[b0].8h, %[a0].h[2]\n" + "fmla v11.8h, %[b0].8h, %[a0].h[3]\n" + "fmla v12.8h, %[b0].8h, %[a0].h[4]\n" + "fmla v13.8h, %[b0].8h, %[a0].h[5]\n" + "fmla v14.8h, %[b0].8h, %[a0].h[6]\n" + "fmla v15.8h, %[b0].8h, %[a0].h[7]\n" + "ldr %q[b0], [%[b_ptr], #64]\n" + + "fmla v16.8h, %[b1].8h, %[a0].h[0]\n" + "fmla v17.8h, %[b1].8h, %[a0].h[1]\n" + "fmla v18.8h, %[b1].8h, %[a0].h[2]\n" + "fmla v19.8h, %[b1].8h, %[a0].h[3]\n" + "add %[b_ptr], %[b_ptr], #64\n" + "fmla v20.8h, %[b1].8h, %[a0].h[4]\n" + "fmla v21.8h, %[b1].8h, %[a0].h[5]\n" + "fmla v22.8h, %[b1].8h, %[a0].h[6]\n" + "fmla v23.8h, %[b1].8h, %[a0].h[7]\n" + "ldr %q[b1], [%[b_ptr], #16]\n" + + "ldr %q[a0], [%[a_ptr], #32]\n" + + "fmla v8.8h , %[b0a].8h, %[a0a].h[0]\n" + "fmla v9.8h , %[b0a].8h, %[a0a].h[1]\n" + "fmla v10.8h, %[b0a].8h, %[a0a].h[2]\n" + "fmla v11.8h, %[b0a].8h, %[a0a].h[3]\n" + "fmla v12.8h, %[b0a].8h, %[a0a].h[4]\n" + "fmla v13.8h, %[b0a].8h, %[a0a].h[5]\n" + "fmla v14.8h, %[b0a].8h, %[a0a].h[6]\n" + "fmla v15.8h, %[b0a].8h, %[a0a].h[7]\n" + "ldr %q[b0a], [%[b_ptr], #32]\n" + + 
"fmla v16.8h, %[b1a].8h, %[a0a].h[0]\n" + "fmla v17.8h, %[b1a].8h, %[a0a].h[1]\n" + "fmla v18.8h, %[b1a].8h, %[a0a].h[2]\n" + "fmla v19.8h, %[b1a].8h, %[a0a].h[3]\n" + "fmla v20.8h, %[b1a].8h, %[a0a].h[4]\n" + "fmla v21.8h, %[b1a].8h, %[a0a].h[5]\n" + "fmla v22.8h, %[b1a].8h, %[a0a].h[6]\n" + "fmla v23.8h, %[b1a].8h, %[a0a].h[7]\n" + "ldr %q[b1a], [%[b_ptr], #48]\n" + + "add %[a_ptr], %[a_ptr], #32\n" + "subs %w[k], %w[k], #1\n" + + "bne 1b\n" + "4:\n" + // Jump to odd tail if necessary. + "cbnz %w[oddk], 2f\n" + + // Even tail + "fmla v8.8h , %[b0].8h, %[a0].h[0]\n" + "fmla v9.8h , %[b0].8h, %[a0].h[1]\n" + "ldr %q[a0a], [%[a_ptr], #16]\n" + "fmla v10.8h, %[b0].8h, %[a0].h[2]\n" + "fmla v11.8h, %[b0].8h, %[a0].h[3]\n" + "fmla v12.8h, %[b0].8h, %[a0].h[4]\n" + "fmla v13.8h, %[b0].8h, %[a0].h[5]\n" + "fmla v14.8h, %[b0].8h, %[a0].h[6]\n" + "fmla v15.8h, %[b0].8h, %[a0].h[7]\n" + + "fmla v16.8h, %[b1].8h, %[a0].h[0]\n" + "fmla v17.8h, %[b1].8h, %[a0].h[1]\n" + "add %[b_ptr], %[b_ptr], #64\n" + "fmla v18.8h, %[b1].8h, %[a0].h[2]\n" + "fmla v19.8h, %[b1].8h, %[a0].h[3]\n" + "fmla v20.8h, %[b1].8h, %[a0].h[4]\n" + "fmla v21.8h, %[b1].8h, %[a0].h[5]\n" + "add %[a_ptr], %[a_ptr], #32\n" + "fmla v22.8h, %[b1].8h, %[a0].h[6]\n" + "fmla v23.8h, %[b1].8h, %[a0].h[7]\n" + + "fmla v8.8h , %[b0a].8h, %[a0a].h[0]\n" + "fmla v16.8h, %[b1a].8h, %[a0a].h[0]\n" + "str q8, [%[outptr0]]\n" + "str q16, [%[outptr0], #16]\n" + + "fmla v9.8h , %[b0a].8h, %[a0a].h[1]\n" + "fmla v17.8h, %[b1a].8h, %[a0a].h[1]\n" + "str q9, [%[outptr1]]\n" + "str q17, [%[outptr1], #16]\n" + + "fmla v10.8h, %[b0a].8h, %[a0a].h[2]\n" + "fmla v18.8h, %[b1a].8h, %[a0a].h[2]\n" + "str q10, [%[outptr2]]\n" + "str q18, [%[outptr2], #16]\n" + + "fmla v11.8h, %[b0a].8h, %[a0a].h[3]\n" + "fmla v19.8h, %[b1a].8h, %[a0a].h[3]\n" + "str q11, [%[outptr3]]\n" + "str q19, [%[outptr3], #16]\n" + + "fmla v12.8h, %[b0a].8h, %[a0a].h[4]\n" + "fmla v20.8h, %[b1a].8h, %[a0a].h[4]\n" + "str q12, [%[outptr4]]\n" + "str q20, [%[outptr4], #16]\n" + + "fmla v13.8h, %[b0a].8h, %[a0a].h[5]\n" + "fmla v21.8h, %[b1a].8h, %[a0a].h[5]\n" + "str q13, [%[outptr5]]\n" + "str q21, [%[outptr5], #16]\n" + + "fmla v14.8h, %[b0a].8h, %[a0a].h[6]\n" + "fmla v22.8h, %[b1a].8h, %[a0a].h[6]\n" + "str q14, [%[outptr6]]\n" + "str q22, [%[outptr6], #16]\n" + + "fmla v15.8h, %[b0a].8h, %[a0a].h[7]\n" + "fmla v23.8h, %[b1a].8h, %[a0a].h[7]\n" + "str q15, [%[outptr7]]\n" + "b 3f\n" + + // Odd tail + "2:\n" + "add %[a_ptr], %[a_ptr], #16\n" + "fmla v8.8h , %[b0].8h, %[a0].h[0]\n" + "add %[b_ptr], %[b_ptr], #32\n" + "fmla v16.8h, %[b1].8h, %[a0].h[0]\n" + "str q8, [%[outptr0]]\n" + "str q16, [%[outptr0], #16]\n" + + "fmla v9.8h , %[b0].8h, %[a0].h[1]\n" + "fmla v17.8h, %[b1].8h, %[a0].h[1]\n" + "str q9, [%[outptr1]]\n" + "str q17, [%[outptr1], #16]\n" + + "fmla v10.8h, %[b0].8h, %[a0].h[2]\n" + "fmla v18.8h, %[b1].8h, %[a0].h[2]\n" + "str q10, [%[outptr2]]\n" + "str q18, [%[outptr2], #16]\n" + + "fmla v11.8h, %[b0].8h, %[a0].h[3]\n" + "fmla v19.8h, %[b1].8h, %[a0].h[3]\n" + "str q11, [%[outptr3]]\n" + "str q19, [%[outptr3], #16]\n" + + "fmla v12.8h, %[b0].8h, %[a0].h[4]\n" + "fmla v20.8h, %[b1].8h, %[a0].h[4]\n" + "str q12, [%[outptr4]]\n" + "str q20, [%[outptr4], #16]\n" + + "fmla v13.8h, %[b0].8h, %[a0].h[5]\n" + "fmla v21.8h, %[b1].8h, %[a0].h[5]\n" + "str q13, [%[outptr5]]\n" + "str q21, [%[outptr5], #16]\n" + + "fmla v14.8h, %[b0].8h, %[a0].h[6]\n" + "fmla v22.8h, %[b1].8h, %[a0].h[6]\n" + "str q14, [%[outptr6]]\n" + "str q22, [%[outptr6], #16]\n" + + "fmla v15.8h, %[b0].8h, 
%[a0].h[7]\n" + "fmla v23.8h, %[b1].8h, %[a0].h[7]\n" + "str q15, [%[outptr7]]\n" + + "3:\n" + "str q23, [%[outptr7], #16]\n" + : [a0] "+w"(a0), [a0a] "+w"(a0a), [b0] "+w"(b0), [b1] "+w"(b1), + [k] "+r"(k), [b0a] "+w"(b0a), [b1a] "+w"(b1a), + [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [outptr0] "+r"(outptr0), + [outptr1] "+r"(outptr1), [outptr2] "+r"(outptr2), + [outptr3] "+r"(outptr3), [outptr4] "+r"(outptr4), + [outptr5] "+r"(outptr5), [outptr6] "+r"(outptr6), + [outptr7] "+r"(outptr7) + : [oddk] "r"(oddk), [type] "r"(type) + : "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", + "v17", "v18", "v19", "v20", "v21", "v22", "v23", "cc", "memory"); +} + +// Overview of register layout: +// +// A 2x8 cell of Rhs is stored in 16bit in q2,q5 +// A 8x2 cell of Lhs is stored in 16bit in q0-q1 +// A 8x8 block of accumulators is stored in 16bit in q8-q15. +// +// +--------+ +// | v2[0-7]| +// Rhs +--------+ +// | v5[0-7]| +// +--------+ +// +// | | +// +// Lhs | | +// +// +--+--+ - - - - +--------+ +// |v0|v1| | v8[0-7]| +// |v0|v1| | v9[0-7]| +// |v0|v1| |v10[0-7]| +// |v0|v1| |v11[0-7]| +// |v0|v1| |v12[0-7]| +// |v0|v1| |v13[0-7]| +// |v0|v1| |v14[0-7]| +// |v0|v1| |v15[0-7]| +// +--+--+ - - - - +--------+ +// +// Accumulator +void aarch64_hgemm_assembly_kernel_8x8(const __fp16* a_ptr, + const __fp16*& b_ptr, int K, + __fp16* outptr0, int ldout, int type) { + int oddk = (K & 1); + int k = ((K + 1) / 2) - 1; + + register float16x8_t a0 asm("v0"); + register float16x8_t a0a asm("v1"); + register float16x8_t b0 asm("v2"); + register float16x8_t b0a asm("v5"); + + __fp16* outptr1 = outptr0 + ldout; + __fp16* outptr2 = outptr1 + ldout; + __fp16* outptr3 = outptr2 + ldout; + __fp16* outptr4 = outptr3 + ldout; + __fp16* outptr5 = outptr4 + ldout; + __fp16* outptr6 = outptr5 + ldout; + __fp16* outptr7 = outptr6 + ldout; + + asm volatile( + ".arch armv8.2-a+fp16\n" + + // load accumulator C + "cmp %w[type], #0\n" + "beq 5f\n" + "ldr q8, [%[outptr0]]\n" + "ldr q9, [%[outptr1]]\n" + "ldr q10, [%[outptr2]]\n" + "ldr q11, [%[outptr3]]\n" + "ldr q12, [%[outptr4]]\n" + "ldr q13, [%[outptr5]]\n" + "ldr q14, [%[outptr6]]\n" + "ldr q15, [%[outptr7]]\n" + "b 6f\n" + + "5:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "eor v9.16b, v9.16b, v9.16b\n" + "eor v10.16b, v10.16b, v10.16b\n" + "eor v11.16b, v11.16b, v11.16b\n" + "eor v12.16b, v12.16b, v12.16b\n" + "eor v13.16b, v13.16b, v13.16b\n" + "eor v14.16b, v14.16b, v14.16b\n" + "eor v15.16b, v15.16b, v15.16b\n" + + "6:\n" + "ldr %q[a0], [%[a_ptr]]\n" + "ldr %q[b0], [%[b_ptr]]\n" + "ldr %q[b0a], [%[b_ptr], #16]\n" + + "cbz %w[k], 4f\n" + + "1:\n" + "fmla v8.8h , %[b0].8h, %[a0].h[0]\n" + "fmla v9.8h , %[b0].8h, %[a0].h[1]\n" + "ldr %q[a0a], [%[a_ptr], #16]\n" + "fmla v10.8h, %[b0].8h, %[a0].h[2]\n" + "fmla v11.8h, %[b0].8h, %[a0].h[3]\n" + "fmla v12.8h, %[b0].8h, %[a0].h[4]\n" + "fmla v13.8h, %[b0].8h, %[a0].h[5]\n" + "fmla v14.8h, %[b0].8h, %[a0].h[6]\n" + "fmla v15.8h, %[b0].8h, %[a0].h[7]\n" + "ldr %q[b0], [%[b_ptr], #32]\n" + + "add %[b_ptr], %[b_ptr], #32\n" + "ldr %q[a0], [%[a_ptr], #32]\n" + + "fmla v8.8h , %[b0a].8h, %[a0a].h[0]\n" + "fmla v9.8h , %[b0a].8h, %[a0a].h[1]\n" + + "fmla v10.8h, %[b0a].8h, %[a0a].h[2]\n" + "fmla v11.8h, %[b0a].8h, %[a0a].h[3]\n" + "fmla v12.8h, %[b0a].8h, %[a0a].h[4]\n" + "fmla v13.8h, %[b0a].8h, %[a0a].h[5]\n" + "fmla v14.8h, %[b0a].8h, %[a0a].h[6]\n" + "fmla v15.8h, %[b0a].8h, %[a0a].h[7]\n" + "ldr %q[b0a], [%[b_ptr], #16]\n" + + "add %[a_ptr], %[a_ptr], #32\n" + "subs %w[k], %w[k], #1\n" + + "bne 1b\n" + "4:\n" + // Jump to 
odd tail if necessary. + "cbnz %w[oddk], 2f\n" + + // Even tail + "fmla v8.8h , %[b0].8h, %[a0].h[0]\n" + "fmla v9.8h , %[b0].8h, %[a0].h[1]\n" + "ldr %q[a0a], [%[a_ptr], #16]\n" + "fmla v10.8h, %[b0].8h, %[a0].h[2]\n" + "fmla v11.8h, %[b0].8h, %[a0].h[3]\n" + "fmla v12.8h, %[b0].8h, %[a0].h[4]\n" + "fmla v13.8h, %[b0].8h, %[a0].h[5]\n" + "fmla v14.8h, %[b0].8h, %[a0].h[6]\n" + "fmla v15.8h, %[b0].8h, %[a0].h[7]\n" + + "add %[b_ptr], %[b_ptr], #32\n" + "add %[a_ptr], %[a_ptr], #32\n" + + "fmla v8.8h , %[b0a].8h, %[a0a].h[0]\n" + "str q8, [%[outptr0]]\n" + + "fmla v9.8h , %[b0a].8h, %[a0a].h[1]\n" + "str q9, [%[outptr1]]\n" + + "fmla v10.8h, %[b0a].8h, %[a0a].h[2]\n" + "str q10, [%[outptr2]]\n" + + "fmla v11.8h, %[b0a].8h, %[a0a].h[3]\n" + "str q11, [%[outptr3]]\n" + + "fmla v12.8h, %[b0a].8h, %[a0a].h[4]\n" + "str q12, [%[outptr4]]\n" + + "fmla v13.8h, %[b0a].8h, %[a0a].h[5]\n" + "str q13, [%[outptr5]]\n" + + "fmla v14.8h, %[b0a].8h, %[a0a].h[6]\n" + "str q14, [%[outptr6]]\n" + + "fmla v15.8h, %[b0a].8h, %[a0a].h[7]\n" + "str q15, [%[outptr7]]\n" + "b 3f\n" + + // Odd tail + "2:\n" + "add %[a_ptr], %[a_ptr], #16\n" + "fmla v8.8h , %[b0].8h, %[a0].h[0]\n" + "add %[b_ptr], %[b_ptr], #16\n" + "str q8, [%[outptr0]]\n" + + "fmla v9.8h , %[b0].8h, %[a0].h[1]\n" + "str q9, [%[outptr1]]\n" + + "fmla v10.8h, %[b0].8h, %[a0].h[2]\n" + "str q10, [%[outptr2]]\n" + + "fmla v11.8h, %[b0].8h, %[a0].h[3]\n" + "str q11, [%[outptr3]]\n" + + "fmla v12.8h, %[b0].8h, %[a0].h[4]\n" + "str q12, [%[outptr4]]\n" + + "fmla v13.8h, %[b0].8h, %[a0].h[5]\n" + "str q13, [%[outptr5]]\n" + + "fmla v14.8h, %[b0].8h, %[a0].h[6]\n" + "str q14, [%[outptr6]]\n" + + "fmla v15.8h, %[b0].8h, %[a0].h[7]\n" + "str q15, [%[outptr7]]\n" + + "3:\n" + : [a0] "+w"(a0), [a0a] "+w"(a0a), [b0] "+w"(b0), [k] "+r"(k), + [b0a] "+w"(b0a), [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), + [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1), + [outptr2] "+r"(outptr2), [outptr3] "+r"(outptr3), + [outptr4] "+r"(outptr4), [outptr5] "+r"(outptr5), + [outptr6] "+r"(outptr6), [outptr7] "+r"(outptr7) + : [oddk] "r"(oddk), [type] "r"(type) + : "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "cc", + "memory"); +} + +// Overview of register layout: +// +// A 2x8 cell of Rhs is stored in 16bit in d2, d5 +// A 8x2 cell of Lhs is stored in 16bit in q0 - q1 +// A 8x8 block of accumulators is stored in 16bit in d8 - d15. 
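+// (Each d register holds four halves, so the Rhs cell is effectively 2x4 here and the
+// accumulators form an 8-row by 4-column tile; x_remain trims the loads and stores of
+// C when fewer than 4 columns remain.)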
+// +// +--------+ +// | d2[0-3]| +// Rhs +--------+ +// | d5[0-3]| +// +--------+ +// +// | | +// +// Lhs | | +// +// +--+--+ - - - - +--------+ +// |v0|v1| | d8[0-3]| +// |v0|v1| | d9[0-3]| +// |v0|v1| |d10[0-3]| +// |v0|v1| |d11[0-3]| +// |v0|v1| |d12[0-3]| +// |v0|v1| |d13[0-3]| +// |v0|v1| |d14[0-3]| +// |v0|v1| |d15[0-3]| +// +--+--+ - - - - +--------+ +// +// Accumulator +void aarch64_hgemm_assembly_kernel_4x8(const __fp16* a_ptr, + const __fp16*& b_ptr, int K, + __fp16* outptr0, int ldout, int x_remain, + int type) { + int oddk = (K & 1); + int k = ((K + 1) / 2) - 1; + + register float16x8_t a0 asm("v0"); + register float16x8_t a0a asm("v1"); + register float16x8_t b0 asm("v2"); + register float16x8_t b0a asm("v5"); + + __fp16* outptr1 = outptr0 + ldout; + __fp16* outptr2 = outptr1 + ldout; + __fp16* outptr3 = outptr2 + ldout; + __fp16* outptr4 = outptr3 + ldout; + __fp16* outptr5 = outptr4 + ldout; + __fp16* outptr6 = outptr5 + ldout; + __fp16* outptr7 = outptr6 + ldout; + +#define LOAD_LINE(reg_index, n) \ + "mov x0, %[outptr" n \ + "]\n" \ + "cmp %w[x_remain], #4\n" \ + "b.lt REMAIN_LOAD_LINE_LESS_THAN_4_" n \ + "\n" \ + "ldr d" reg_index \ + ", [x0]\n" \ + "b LOAD_LINE_END_" n \ + "\n" \ + \ + "REMAIN_LOAD_LINE_LESS_THAN_4_" n \ + ":\n" \ + "cmp %w[x_remain], #0\n" \ + "beq LOAD_LINE_END_" n \ + "\n" \ + "ld1 {v" reg_index \ + ".h}[0], [x0], #2\n" \ + "cmp %w[x_remain], #1\n" \ + "beq LOAD_LINE_END_" n \ + "\n" \ + "ld1 {v" reg_index \ + ".h}[1], [x0], #2\n" \ + "cmp %w[x_remain], #2\n" \ + "beq LOAD_LINE_END_" n \ + "\n" \ + "ld1 {v" reg_index \ + ".h}[2], [x0], #2\n" \ + "LOAD_LINE_END_" n ":\n" + +#define LOAD_C \ + LOAD_LINE("8", "0") \ + LOAD_LINE("9", "1") \ + LOAD_LINE("10", "2") \ + LOAD_LINE("11", "3") \ + LOAD_LINE("12", "4") \ + LOAD_LINE("13", "5") \ + LOAD_LINE("14", "6") \ + LOAD_LINE("15", "7") + +#define STORE_LINE(reg_index, n) \ + "mov x0, %[outptr" n \ + "]\n" \ + "cmp %w[x_remain], #4\n" \ + "b.lt REMAIN_STORE_LINE_LESS_THAN_4_" n \ + "\n" \ + "str d" reg_index \ + ", [x0]\n" \ + "b STORE_LINE_END_" n \ + "\n" \ + \ + "REMAIN_STORE_LINE_LESS_THAN_4_" n \ + ":\n" \ + "cmp %w[x_remain], #0\n" \ + "beq STORE_LINE_END_" n \ + "\n" \ + "st1 {v" reg_index \ + ".h}[0], [x0], #2\n" \ + "cmp %w[x_remain], #1\n" \ + "beq STORE_LINE_END_" n \ + "\n" \ + "st1 {v" reg_index \ + ".h}[1], [x0], #2\n" \ + "cmp %w[x_remain], #2\n" \ + "beq STORE_LINE_END_" n \ + "\n" \ + "st1 {v" reg_index \ + ".h}[2], [x0], #2\n" \ + "STORE_LINE_END_" n ":\n" + +#define STORE_C \ + STORE_LINE("8", "0") \ + STORE_LINE("9", "1") \ + STORE_LINE("10", "2") \ + STORE_LINE("11", "3") \ + STORE_LINE("12", "4") \ + STORE_LINE("13", "5") \ + STORE_LINE("14", "6") \ + STORE_LINE("15", "7") + + asm volatile( + ".arch armv8.2-a+fp16\n" + + // load accumulator C + "cmp %w[type], #0\n" + "beq 5f\n" LOAD_C + "b 6f\n" + + "5:\n" + "eor v8.8b, v8.8b, v8.8b\n" + "eor v9.8b, v9.8b, v9.8b\n" + "eor v10.8b, v10.8b, v10.8b\n" + "eor v11.8b, v11.8b, v11.8b\n" + "eor v12.8b, v12.8b, v12.8b\n" + "eor v13.8b, v13.8b, v13.8b\n" + "eor v14.8b, v14.8b, v14.8b\n" + "eor v15.8b, v15.8b, v15.8b\n" + + "6:\n" + "ldr %q[a0], [%[a_ptr]]\n" + + "cbz %w[k], 4f\n" + + "1:\n" + "ldp %d[b0], %d[b0a], [%[b_ptr]]\n" + "fmla v8.4h , %[b0].4h, %[a0].h[0]\n" + "fmla v9.4h , %[b0].4h, %[a0].h[1]\n" + "ldr %q[a0a], [%[a_ptr], #16]\n" + "fmla v10.4h, %[b0].4h, %[a0].h[2]\n" + "fmla v11.4h, %[b0].4h, %[a0].h[3]\n" + "fmla v12.4h, %[b0].4h, %[a0].h[4]\n" + "fmla v13.4h, %[b0].4h, %[a0].h[5]\n" + "fmla v14.4h, %[b0].4h, %[a0].h[6]\n" + 
"fmla v15.4h, %[b0].4h, %[a0].h[7]\n" + + "add %[b_ptr], %[b_ptr], #16\n" + "ldr %q[a0], [%[a_ptr], #32]\n" + + "fmla v8.4h , %[b0a].4h, %[a0a].h[0]\n" + "fmla v9.4h , %[b0a].4h, %[a0a].h[1]\n" + "fmla v10.4h, %[b0a].4h, %[a0a].h[2]\n" + "fmla v11.4h, %[b0a].4h, %[a0a].h[3]\n" + "fmla v12.4h, %[b0a].4h, %[a0a].h[4]\n" + "fmla v13.4h, %[b0a].4h, %[a0a].h[5]\n" + "fmla v14.4h, %[b0a].4h, %[a0a].h[6]\n" + "fmla v15.4h, %[b0a].4h, %[a0a].h[7]\n" + + "add %[a_ptr], %[a_ptr], #32\n" + "subs %w[k], %w[k], #1\n" + + "bne 1b\n" + "4:\n" + // Jump to odd tail if necessary. + "cbnz %w[oddk], 2f\n" + + // Even tail + "ldp %d[b0], %d[b0a], [%[b_ptr]]\n" + "fmla v8.4h , %[b0].4h, %[a0].h[0]\n" + "fmla v9.4h , %[b0].4h, %[a0].h[1]\n" + "ldr %q[a0a], [%[a_ptr], #16]\n" + "fmla v10.4h, %[b0].4h, %[a0].h[2]\n" + "fmla v11.4h, %[b0].4h, %[a0].h[3]\n" + "fmla v12.4h, %[b0].4h, %[a0].h[4]\n" + "fmla v13.4h, %[b0].4h, %[a0].h[5]\n" + "fmla v14.4h, %[b0].4h, %[a0].h[6]\n" + "fmla v15.4h, %[b0].4h, %[a0].h[7]\n" + + "add %[b_ptr], %[b_ptr], #16\n" + "add %[a_ptr], %[a_ptr], #32\n" + + "fmla v8.4h , %[b0a].4h, %[a0a].h[0]\n" + "fmla v9.4h , %[b0a].4h, %[a0a].h[1]\n" + "fmla v10.4h, %[b0a].4h, %[a0a].h[2]\n" + "fmla v11.4h, %[b0a].4h, %[a0a].h[3]\n" + "fmla v12.4h, %[b0a].4h, %[a0a].h[4]\n" + "fmla v13.4h, %[b0a].4h, %[a0a].h[5]\n" + "fmla v14.4h, %[b0a].4h, %[a0a].h[6]\n" + "fmla v15.4h, %[b0a].4h, %[a0a].h[7]\n" + "b 3f\n" + + // Odd tail + "2:\n" + "ldr %d[b0], [%[b_ptr]]\n" + "add %[a_ptr], %[a_ptr], #16\n" + "fmla v8.4h , %[b0].4h, %[a0].h[0]\n" + "add %[b_ptr], %[b_ptr], #8\n" + "fmla v9.4h , %[b0].4h, %[a0].h[1]\n" + "fmla v10.4h, %[b0].4h, %[a0].h[2]\n" + "fmla v11.4h, %[b0].4h, %[a0].h[3]\n" + "fmla v12.4h, %[b0].4h, %[a0].h[4]\n" + "fmla v13.4h, %[b0].4h, %[a0].h[5]\n" + "fmla v14.4h, %[b0].4h, %[a0].h[6]\n" + "fmla v15.4h, %[b0].4h, %[a0].h[7]\n" + + "3:\n" STORE_C + : [a0] "+w"(a0), [a0a] "+w"(a0a), [b0] "+w"(b0), [k] "+r"(k), + [b0a] "+w"(b0a), [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), + [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1), + [outptr2] "+r"(outptr2), [outptr3] "+r"(outptr3), + [outptr4] "+r"(outptr4), [outptr5] "+r"(outptr5), + [outptr6] "+r"(outptr6), [outptr7] "+r"(outptr7) + : [oddk] "r"(oddk), [x_remain] "r"(x_remain), [type] "r"(type) + : "x0", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "cc", + "memory"); +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +} + +// Overview of register layout: +// +// A 2x24 cell of Rhs is stored in 16bit in q2 - q7 +// A 4x2 cell of Lhs is stored in 16bit in d0, d1 +// A 4x24 block of accumulators is stored in 16bit in q8-q11, q16-q19, q24-q27. +// +// +--------+--------+--------+ +// | v2[0-7]| v3[0-7]| v4[0-7]| +// Rhs +--------+--------+--------+ +// | v5[0-7]| v6[0-7]| v7[0-7]| +// +--------+--------+--------+ +// +// | | | | +// +// Lhs | | | | +// +// +--+--+ - - - - +--------+--------+--------+ +// |v0|v1| | v8[0-7]|v16[0-7]|v24[0-7]| +// |v0|v1| | v9[0-7]|v17[0-7]|v25[0-7]| +// |v0|v1| |v10[0-7]|v18[0-7]|v26[0-7]| +// |v0|v1| |v11[0-7]|v19[0-7]|v27[0-7]| +// +--+--+ - - - - +--------+--------+--------+ +// +// Accumulator +//! cannot load %[a0] and %[a0a] at same time! 
+void aarch64_hgemm_assembly_kernel_24x4(const __fp16* a_ptr, + const __fp16*& b_ptr, int K, + __fp16* outptr0, int ldout, + int y_remain, int type) { + int oddk = (K & 1); + int k = ((K + 1) / 2) - 1; + + register float16x8_t a0 asm("v0"); + register float16x8_t a0a asm("v1"); + register float16x8_t b0 asm("v2"); + register float16x8_t b1 asm("v3"); + register float16x8_t b2 asm("v4"); + register float16x8_t b0a asm("v5"); + register float16x8_t b1a asm("v6"); + register float16x8_t b2a asm("v7"); + + __fp16* outptr1 = outptr0 + ldout; + __fp16* outptr2 = outptr1 + ldout; + __fp16* outptr3 = outptr2 + ldout; + +// clang-format off +#define LOAD_LINE(v1, v2, v3, n) \ + "cbz w0, LOAD_24x4_C_END\n" \ + "ldp q" v1 ", q" v2 ", [%[outptr" n \ + "]]\n" \ + "ldr q" v3 ", [%[outptr" n \ + "], #32]\n" \ + "subs w0, w0, #1\n" + +#define LOAD_C \ + "mov w0, %w[y_remain]\n" \ + LOAD_LINE("8", "16", "24", "0") \ + LOAD_LINE("9", "17", "25", "1") \ + LOAD_LINE("10", "18", "26", "2") \ + LOAD_LINE("11", "19", "27", "3") \ + "LOAD_24x4_C_END:\n" + +#define STORE_LINE(v1, v2, v3, n) \ + "cbz w0, STORE_24x4_C_END\n" \ + "stp q" v1 ", q" v2 ", [%[outptr" n \ + "]]\n" \ + "str q" v3 ", [%[outptr" n \ + "], #32]\n" \ + "subs w0, w0, #1\n" + +#define STORE_C "mov w0, %w[y_remain]\n" \ + STORE_LINE("8", "16", "24", "0") \ + STORE_LINE("9", "17", "25", "1") \ + STORE_LINE("10", "18", "26", "2") \ + STORE_LINE("11", "19", "27", "3") \ + "STORE_24x4_C_END:\n" +// clang-format on + + asm volatile( + ".arch armv8.2-a+fp16\n" + + // load accumulator C + "cmp %w[type], #0\n" + "beq 5f\n" + LOAD_C + "b 6f\n" + "5:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "eor v9.16b, v9.16b, v9.16b\n" + "eor v10.16b, v10.16b, v10.16b\n" + "eor v11.16b, v11.16b, v11.16b\n" + + "eor v16.16b, v16.16b, v16.16b\n" + "eor v17.16b, v17.16b, v17.16b\n" + "eor v18.16b, v18.16b, v18.16b\n" + "eor v19.16b, v19.16b, v19.16b\n" + + "eor v24.16b, v24.16b, v24.16b\n" + "eor v25.16b, v25.16b, v25.16b\n" + "eor v26.16b, v26.16b, v26.16b\n" + "eor v27.16b, v27.16b, v27.16b\n" + + "6:\n" + "ldr %d[a0], [%[a_ptr]]\n" + "ldr %q[b0], [%[b_ptr]]\n" + "ldr %q[b1], [%[b_ptr], #16]\n" + "ldr %q[b2], [%[b_ptr], #32]\n" + "ldr %q[b0a], [%[b_ptr], #48]\n" + "ldr %q[b1a], [%[b_ptr], #64]\n" + + "cbz %w[k], 4f\n" + + "1:\n" + "fmla v8.8h , %[b0].8h, %[a0].h[0]\n" + "fmla v9.8h , %[b0].8h, %[a0].h[1]\n" + "ldr %d[a0a], [%[a_ptr], #8]\n" + "fmla v10.8h, %[b0].8h, %[a0].h[2]\n" + "fmla v11.8h, %[b0].8h, %[a0].h[3]\n" + "ldr %q[b2a], [%[b_ptr], #80]\n" + "ldr %q[b0], [%[b_ptr], #96]\n" + + "fmla v16.8h, %[b1].8h, %[a0].h[0]\n" + "fmla v17.8h, %[b1].8h, %[a0].h[1]\n" + "fmla v18.8h, %[b1].8h, %[a0].h[2]\n" + "fmla v19.8h, %[b1].8h, %[a0].h[3]\n" + "add %[b_ptr], %[b_ptr], #96\n" + "ldr %q[b1], [%[b_ptr], #16]\n" + + "fmla v24.8h, %[b2].8h, %[a0].h[0]\n" + "fmla v25.8h, %[b2].8h, %[a0].h[1]\n" + "fmla v26.8h, %[b2].8h, %[a0].h[2]\n" + "fmla v27.8h, %[b2].8h, %[a0].h[3]\n" + "ldr %d[a0], [%[a_ptr], #16]\n" + + "fmla v8.8h , %[b0a].8h, %[a0a].h[0]\n" + "fmla v9.8h , %[b0a].8h, %[a0a].h[1]\n" + "ldr %q[b2], [%[b_ptr], #32]\n" + + "fmla v10.8h, %[b0a].8h, %[a0a].h[2]\n" + "fmla v11.8h, %[b0a].8h, %[a0a].h[3]\n" + "ldr %q[b0a], [%[b_ptr], #48]\n" + + "fmla v16.8h, %[b1a].8h, %[a0a].h[0]\n" + "fmla v17.8h, %[b1a].8h, %[a0a].h[1]\n" + "fmla v18.8h, %[b1a].8h, %[a0a].h[2]\n" + "fmla v19.8h, %[b1a].8h, %[a0a].h[3]\n" + "ldr %q[b1a], [%[b_ptr], #64]\n" + + "fmla v24.8h, %[b2a].8h, %[a0a].h[0]\n" + "fmla v25.8h, %[b2a].8h, %[a0a].h[1]\n" + "add %[a_ptr], %[a_ptr], #16\n" + "fmla v26.8h, 
%[b2a].8h, %[a0a].h[2]\n" + "fmla v27.8h, %[b2a].8h, %[a0a].h[3]\n" + "subs %w[k], %w[k], #1\n" + + "bne 1b\n" + "4:\n" + // Jump to odd tail if necessary. + "cbnz %w[oddk], 2f\n" + + // Even tail + "fmla v8.8h , %[b0].8h, %[a0].h[0]\n" + "fmla v9.8h , %[b0].8h, %[a0].h[1]\n" + "ldr %d[a0a], [%[a_ptr], #8]\n" + "fmla v10.8h, %[b0].8h, %[a0].h[2]\n" + "fmla v11.8h, %[b0].8h, %[a0].h[3]\n" + "ldr %q[b2a], [%[b_ptr], #80]\n" + + "fmla v16.8h, %[b1].8h, %[a0].h[0]\n" + "fmla v17.8h, %[b1].8h, %[a0].h[1]\n" + "add %[b_ptr], %[b_ptr], #96\n" + "fmla v18.8h, %[b1].8h, %[a0].h[2]\n" + "fmla v19.8h, %[b1].8h, %[a0].h[3]\n" + "add %[a_ptr], %[a_ptr], #16\n" + + "fmla v24.8h, %[b2].8h, %[a0].h[0]\n" + "fmla v25.8h, %[b2].8h, %[a0].h[1]\n" + "fmla v26.8h, %[b2].8h, %[a0].h[2]\n" + "fmla v27.8h, %[b2].8h, %[a0].h[3]\n" + + "fmla v8.8h, %[b0a].8h, %[a0a].h[0]\n" + "fmla v9.8h, %[b0a].8h, %[a0a].h[1]\n" + "fmla v10.8h, %[b0a].8h, %[a0a].h[2]\n" + "fmla v11.8h, %[b0a].8h, %[a0a].h[3]\n" + + "fmla v16.8h, %[b1a].8h, %[a0a].h[0]\n" + "fmla v17.8h, %[b1a].8h, %[a0a].h[1]\n" + "fmla v18.8h, %[b1a].8h, %[a0a].h[2]\n" + "fmla v19.8h, %[b1a].8h, %[a0a].h[3]\n" + + "fmla v24.8h, %[b2a].8h, %[a0a].h[0]\n" + "fmla v25.8h, %[b2a].8h, %[a0a].h[1]\n" + "fmla v26.8h, %[b2a].8h, %[a0a].h[2]\n" + "fmla v27.8h, %[b2a].8h, %[a0a].h[3]\n" + "b 3f\n" + + // Odd tail + "2:\n" + "add %[a_ptr], %[a_ptr], #8\n" + "add %[b_ptr], %[b_ptr], #48\n" + + "fmla v8.8h, %[b0].8h, %[a0].h[0]\n" + "fmla v9.8h, %[b0].8h, %[a0].h[1]\n" + "fmla v10.8h, %[b0].8h, %[a0].h[2]\n" + "fmla v11.8h, %[b0].8h, %[a0].h[3]\n" + + "fmla v16.8h, %[b1].8h, %[a0].h[0]\n" + "fmla v17.8h, %[b1].8h, %[a0].h[1]\n" + "fmla v18.8h, %[b1].8h, %[a0].h[2]\n" + "fmla v19.8h, %[b1].8h, %[a0].h[3]\n" + + "fmla v24.8h, %[b2].8h, %[a0].h[0]\n" + "fmla v25.8h, %[b2].8h, %[a0].h[1]\n" + "fmla v26.8h, %[b2].8h, %[a0].h[2]\n" + "fmla v27.8h, %[b2].8h, %[a0].h[3]\n" + + "3:\n" STORE_C + : [a0] "+w"(a0), [a0a] "+w"(a0a), [b0] "+w"(b0), + [b1] "+w"(b1), [b2] "+w"(b2), [k] "+r"(k), + [b0a] "+w"(b0a), [b1a] "+w"(b1a), [b2a] "+w"(b2a), + [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), + [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1), + [outptr2] "+r"(outptr2), [outptr3] "+r"(outptr3) + : + [oddk] "r"(oddk), [y_remain] "r"(y_remain), [type] "r"(type) + : "w0", "v8", "v9", "v10", "v11", "v16", "v17", "v18", + "v19", "v24", "v25", "v26", "v27", "cc", "memory"); +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +} + +// Overview of register layout: +// +// A 2x16 cell of Rhs is stored in 16bit in q2, q3, q5, q6 +// A 4x2 cell of Lhs is stored in 16bit in d0, d1 +// A 4x16 block of accumulators is stored in 16bit in q8-q11, q16-q19. 
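+// (This 16-column variant serves the 4-row M tail of aarch64_hgemm_asimd_8x24
+// once fewer than 24 but at least 16 columns of N remain; y_remain masks the
+// unused rows.)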
+// +// +--------+--------+ +// | v2[0-7]| v3[0-7]| +// Rhs +--------+--------+ +// | v5[0-7]| v6[0-7]| +// +--------+--------+ +// +// | | | +// +// Lhs | | | +// +// +--+--+ - - - - +--------+--------+ +// |v0|v1| | v8[0-7]|v16[0-7]| +// |v0|v1| | v9[0-7]|v17[0-7]| +// |v0|v1| |v10[0-7]|v18[0-7]| +// |v0|v1| |v11[0-7]|v19[0-7]| +// +--+--+ - - - - +--------+--------+ +// +// Accumulator +void aarch64_hgemm_assembly_kernel_16x4(const __fp16* a_ptr, + const __fp16*& b_ptr, int K, + __fp16* outptr0, int ldout, + int y_remain, int type) { + int oddk = (K & 1); + int k = ((K + 1) / 2) - 1; + + register float16x8_t a0 asm("v0"); + register float16x8_t a0a asm("v1"); + register float16x8_t b0 asm("v2"); + register float16x8_t b1 asm("v3"); + register float16x8_t b0a asm("v5"); + register float16x8_t b1a asm("v6"); + + __fp16* outptr1 = outptr0 + ldout; + __fp16* outptr2 = outptr1 + ldout; + __fp16* outptr3 = outptr2 + ldout; + +// clang-format off + +#define LOAD_LINE(v1, v2, n) \ + "cbz w0, LOAD_16x4_C_END\n" \ + "ldp q" v1 ", q" v2 ", [%[outptr" n \ + "]]\n" \ + "subs w0, w0, #1\n" + +#define LOAD_C "mov w0, %w[y_remain]\n" \ + LOAD_LINE("8", "16", "0") \ + LOAD_LINE("9", "17", "1") \ + LOAD_LINE("10", "18", "2") \ + LOAD_LINE("11", "19", "3") \ + "LOAD_16x4_C_END:\n" + +#define STORE_LINE(v1, v2, n) \ + "cbz w0, STORE_16x4_C_END\n" \ + "stp q" v1 ", q" v2 ", [%[outptr" n \ + "]]\n" \ + "subs w0, w0, #1\n" + +#define STORE_C "mov w0, %w[y_remain]\n" \ + STORE_LINE("8", "16", "0") \ + STORE_LINE("9", "17", "1") \ + STORE_LINE("10", "18", "2") \ + STORE_LINE("11", "19", "3") \ + "STORE_16x4_C_END:\n" + +// clang-format on + + asm volatile( + ".arch armv8.2-a+fp16\n" + + // load accumulator C + "cmp %w[type], #0\n" + "beq 5f\n" LOAD_C + "b 6f\n" + + "5:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "eor v9.16b, v9.16b, v9.16b\n" + "eor v10.16b, v10.16b, v10.16b\n" + "eor v11.16b, v11.16b, v11.16b\n" + + "eor v16.16b, v16.16b, v16.16b\n" + "eor v17.16b, v17.16b, v17.16b\n" + "eor v18.16b, v18.16b, v18.16b\n" + "eor v19.16b, v19.16b, v19.16b\n" + + "6:\n" + "ldr %d[a0], [%[a_ptr]]\n" + "ldr %q[b0], [%[b_ptr]]\n" + "ldr %q[b1], [%[b_ptr], #16]\n" + "ldr %q[b0a], [%[b_ptr], #32]\n" + "ldr %q[b1a], [%[b_ptr], #48]\n" + + "cbz %w[k], 4f\n" + + "1:\n" + "fmla v8.8h , %[b0].8h, %[a0].h[0]\n" + "fmla v9.8h , %[b0].8h, %[a0].h[1]\n" + "ldr %d[a0a], [%[a_ptr], #8]\n" + "fmla v10.8h, %[b0].8h, %[a0].h[2]\n" + "fmla v11.8h, %[b0].8h, %[a0].h[3]\n" + "ldr %q[b0], [%[b_ptr], #64]\n" + + "fmla v16.8h, %[b1].8h, %[a0].h[0]\n" + "fmla v17.8h, %[b1].8h, %[a0].h[1]\n" + "fmla v18.8h, %[b1].8h, %[a0].h[2]\n" + "fmla v19.8h, %[b1].8h, %[a0].h[3]\n" + "add %[b_ptr], %[b_ptr], #64\n" + "ldr %q[b1], [%[b_ptr], #16]\n" + + "fmla v8.8h , %[b0a].8h, %[a0a].h[0]\n" + "fmla v9.8h , %[b0a].8h, %[a0a].h[1]\n" + "ldr %d[a0], [%[a_ptr], #16]\n" + "fmla v10.8h, %[b0a].8h, %[a0a].h[2]\n" + "fmla v11.8h, %[b0a].8h, %[a0a].h[3]\n" + "ldr %q[b0a], [%[b_ptr], #32]\n" + + "fmla v16.8h, %[b1a].8h, %[a0a].h[0]\n" + "fmla v17.8h, %[b1a].8h, %[a0a].h[1]\n" + "fmla v18.8h, %[b1a].8h, %[a0a].h[2]\n" + "fmla v19.8h, %[b1a].8h, %[a0a].h[3]\n" + "ldr %q[b1a], [%[b_ptr], #48]\n" + + "add %[a_ptr], %[a_ptr], #16\n" + "subs %w[k], %w[k], #1\n" + + "bne 1b\n" + "4:\n" + // Jump to odd tail if necessary. 
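+            // K was decomposed as K = 2 * k + (oddk ? 1 : 2): the loop above
+            // ran k double steps, so either one K step (odd tail) or two
+            // K steps (even tail) remain below.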
+ "cbnz %w[oddk], 2f\n" + + // Even tail + "fmla v8.8h , %[b0].8h, %[a0].h[0]\n" + "fmla v9.8h , %[b0].8h, %[a0].h[1]\n" + "ldr %d[a0a], [%[a_ptr], #8]\n" + "fmla v10.8h, %[b0].8h, %[a0].h[2]\n" + "fmla v11.8h, %[b0].8h, %[a0].h[3]\n" + + "fmla v16.8h, %[b1].8h, %[a0].h[0]\n" + "fmla v17.8h, %[b1].8h, %[a0].h[1]\n" + "add %[b_ptr], %[b_ptr], #64\n" + "fmla v18.8h, %[b1].8h, %[a0].h[2]\n" + "fmla v19.8h, %[b1].8h, %[a0].h[3]\n" + "add %[a_ptr], %[a_ptr], #16\n" + + "fmla v8.8h, %[b0a].8h, %[a0a].h[0]\n" + "fmla v9.8h, %[b0a].8h, %[a0a].h[1]\n" + "fmla v10.8h, %[b0a].8h, %[a0a].h[2]\n" + "fmla v11.8h, %[b0a].8h, %[a0a].h[3]\n" + + "fmla v16.8h, %[b1a].8h, %[a0a].h[0]\n" + "fmla v17.8h, %[b1a].8h, %[a0a].h[1]\n" + "fmla v18.8h, %[b1a].8h, %[a0a].h[2]\n" + "fmla v19.8h, %[b1a].8h, %[a0a].h[3]\n" + + "b 3f\n" + + // Odd tail + "2:\n" + "add %[a_ptr], %[a_ptr], #8\n" + "add %[b_ptr], %[b_ptr], #32\n" + + "fmla v8.8h, %[b0].8h, %[a0].h[0]\n" + "fmla v9.8h, %[b0].8h, %[a0].h[1]\n" + "fmla v10.8h, %[b0].8h, %[a0].h[2]\n" + "fmla v11.8h, %[b0].8h, %[a0].h[3]\n" + + "fmla v16.8h, %[b1].8h, %[a0].h[0]\n" + "fmla v17.8h, %[b1].8h, %[a0].h[1]\n" + "fmla v18.8h, %[b1].8h, %[a0].h[2]\n" + "fmla v19.8h, %[b1].8h, %[a0].h[3]\n" + + "3:\n" STORE_C + : [a0] "+w"(a0), [a0a] "+w"(a0a), [b0] "+w"(b0), + [b1] "+w"(b1), [k] "+r"(k), [b0a] "+w"(b0a), + [b1a] "+w"(b1a), [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), + [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1), + [outptr2] "+r"(outptr2), [outptr3] "+r"(outptr3) + : + [oddk] "r"(oddk), [y_remain] "r"(y_remain), [type] "r"(type) + : "w0", "v8", "v9", "v10", "v11", "v16", "v17", "v18", + "v19", "cc", "memory"); +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +} + +// Overview of register layout: +// +// A 2x8 cell of Rhs is stored in 16bit in q2, q5 +// A 4x2 cell of Lhs is stored in 16bit in d0, d1 +// A 4x8 block of accumulators is stored in 16bit in q8-q11. 
+// +// +--------+ +// | v2[0-7]| +// Rhs +--------+ +// | v5[0-7]| +// +--------+ +// +// | | +// +// Lhs | | +// +// +--+--+ - - - - +--------+ +// |v0|v1| | v8[0-7]| +// |v0|v1| | v9[0-7]| +// |v0|v1| |v10[0-7]| +// |v0|v1| |v11[0-7]| +// +--+--+ - - - - +--------+ +// +// Accumulator +void aarch64_hgemm_assembly_kernel_8x4(const __fp16* a_ptr, + const __fp16*& b_ptr, int K, + __fp16* outptr0, int ldout, int y_remain, + int type) { + int oddk = (K & 1); + int k = ((K + 1) / 2) - 1; + + register float16x8_t a0 asm("v0"); + register float16x8_t a0a asm("v1"); + register float16x8_t b0 asm("v2"); + register float16x8_t b0a asm("v5"); + + __fp16* outptr1 = outptr0 + ldout; + __fp16* outptr2 = outptr1 + ldout; + __fp16* outptr3 = outptr2 + ldout; + +// clang-format off +#define LOAD_LINE(v1, n) \ + "cbz w0, LOAD_8x4_C_END\n" \ + "ldr q" v1 ", [%[outptr" n \ + "]]\n" \ + "subs w0, w0, #1\n" + +#define LOAD_C \ + "mov w0, %w[y_remain]\n" \ + LOAD_LINE("8", "0") \ + LOAD_LINE("9", "1") \ + LOAD_LINE("10", "2") \ + LOAD_LINE("11", "3") \ + "LOAD_8x4_C_END:\n" + +#define STORE_LINE(v1, n) \ + "cbz w0, STORE_8x4_C_END\n" \ + "str q" v1 ", [%[outptr" n \ + "]]\n" \ + "subs w0, w0, #1\n" + +#define STORE_C \ + "mov w0, %w[y_remain]\n" \ + STORE_LINE("8", "0") \ + STORE_LINE("9", "1") \ + STORE_LINE("10", "2") \ + STORE_LINE("11", "3") \ + "STORE_8x4_C_END:\n" +// clang-format on + + asm volatile( + ".arch armv8.2-a+fp16\n" + + // load accumulator C + "cmp %w[type], #0\n" + "beq 5f\n" LOAD_C + "b 6f\n" + "5:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "eor v9.16b, v9.16b, v9.16b\n" + "eor v10.16b, v10.16b, v10.16b\n" + "eor v11.16b, v11.16b, v11.16b\n" + + "6:\n" + "ldr %d[a0], [%[a_ptr]]\n" + "ldr %q[b0], [%[b_ptr]]\n" + "ldr %q[b0a], [%[b_ptr], #16]\n" + + "cbz %w[k], 4f\n" + + "1:\n" + "fmla v8.8h , %[b0].8h, %[a0].h[0]\n" + "fmla v9.8h , %[b0].8h, %[a0].h[1]\n" + "ldr %d[a0a], [%[a_ptr], #8]\n" + "fmla v10.8h, %[b0].8h, %[a0].h[2]\n" + "fmla v11.8h, %[b0].8h, %[a0].h[3]\n" + "ldr %q[b0], [%[b_ptr], #32]\n" + + "add %[b_ptr], %[b_ptr], #32\n" + "ldr %d[a0], [%[a_ptr], #16]\n" + + "fmla v8.8h , %[b0a].8h, %[a0a].h[0]\n" + "fmla v9.8h , %[b0a].8h, %[a0a].h[1]\n" + "fmla v10.8h, %[b0a].8h, %[a0a].h[2]\n" + "fmla v11.8h, %[b0a].8h, %[a0a].h[3]\n" + "ldr %q[b0a], [%[b_ptr], #16]\n" + + "add %[a_ptr], %[a_ptr], #16\n" + "subs %w[k], %w[k], #1\n" + + "bne 1b\n" + "4:\n" + // Jump to odd tail if necessary. 
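+            // If K was odd, branch to "2:" and finish with a single b0 step;
+            // otherwise fall through and consume the b0/b0a pair that is
+            // already loaded.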
+ "cbnz %w[oddk], 2f\n" + + // Even tail + "fmla v8.8h , %[b0].8h, %[a0].h[0]\n" + "fmla v9.8h , %[b0].8h, %[a0].h[1]\n" + "ldr %d[a0a], [%[a_ptr], #8]\n" + "fmla v10.8h, %[b0].8h, %[a0].h[2]\n" + "fmla v11.8h, %[b0].8h, %[a0].h[3]\n" + + "add %[b_ptr], %[b_ptr], #32\n" + "add %[a_ptr], %[a_ptr], #16\n" + + "fmla v8.8h, %[b0a].8h, %[a0a].h[0]\n" + "fmla v9.8h, %[b0a].8h, %[a0a].h[1]\n" + "fmla v10.8h, %[b0a].8h, %[a0a].h[2]\n" + "fmla v11.8h, %[b0a].8h, %[a0a].h[3]\n" + + "b 3f\n" + + // Odd tail + "2:\n" + "add %[a_ptr], %[a_ptr], #8\n" + "add %[b_ptr], %[b_ptr], #16\n" + + "fmla v8.8h, %[b0].8h, %[a0].h[0]\n" + "fmla v9.8h, %[b0].8h, %[a0].h[1]\n" + "fmla v10.8h, %[b0].8h, %[a0].h[2]\n" + "fmla v11.8h, %[b0].8h, %[a0].h[3]\n" + + "3:\n" STORE_C + : [a0] "+w"(a0), [a0a] "+w"(a0a), [b0] "+w"(b0), + [k] "+r"(k), [b0a] "+w"(b0a), [a_ptr] "+r"(a_ptr), + [b_ptr] "+r"(b_ptr), [outptr0] "+r"(outptr0), + [outptr1] "+r"(outptr1), [outptr2] "+r"(outptr2), + [outptr3] "+r"(outptr3) + : + [oddk] "r"(oddk), [y_remain] "r"(y_remain), [type] "r"(type) + : "w0", "v8", "v9", "v10", "v11", "cc", "memory"); +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +} + +// Overview of register layout: +// +// A 2x8 cell of Rhs is stored in 16bit in d2, d5 +// A 4x2 cell of Lhs is stored in 16bit in d0, d1 +// A 4x8 block of accumulators is stored in 16bit in d8-d11. +// +// +--------+ +// | d2[0-3]| +// Rhs +--------+ +// | d5[0-3]| +// +--------+ +// +// | | +// +// Lhs | | +// +// +--+--+ - - - - +--------+ +// |d0|d1| | d8[0-3]| +// |d0|d1| | d9[0-3]| +// |d0|d1| |d10[0-3]| +// |d0|d1| |d11[0-3]| +// +--+--+ - - - - +--------+ +// +// Accumulator +void aarch64_hgemm_assembly_kernel_4x4(const __fp16* a_ptr, + const __fp16*& b_ptr, int K, + __fp16* outptr0, int ldout, int x_remain, + int y_remain, int type) { + int oddk = (K & 1); + int k = ((K + 1) / 2) - 1; + + register float16x8_t a0 asm("v0"); + register float16x8_t a0a asm("v1"); + register float16x8_t b0 asm("v2"); + register float16x8_t b0a asm("v5"); + + __fp16* outptr1 = outptr0 + ldout; + __fp16* outptr2 = outptr1 + ldout; + __fp16* outptr3 = outptr2 + ldout; + +#define LOAD_LINE(reg_index, n) \ + "cbz w1, LOAD_4x4_C_END\n" \ + "mov x0, %[outptr" n \ + "]\n" \ + "cmp %w[x_remain], #4\n" \ + "b.lt REMAIN_LOAD_4x4_LINE_LESS_THAN_4_" n \ + "\n" \ + "ldr d" reg_index \ + ", [x0]\n" \ + "b LOAD_4x4_LINE_END_" n \ + "\n" \ + \ + "REMAIN_LOAD_4x4_LINE_LESS_THAN_4_" n \ + ":\n" \ + "cmp %w[x_remain], #0\n" \ + "beq LOAD_4x4_LINE_END_" n \ + "\n" \ + "ld1 {v" reg_index \ + ".h}[0], [x0], #2\n" \ + "cmp %w[x_remain], #1\n" \ + "beq LOAD_4x4_LINE_END_" n \ + "\n" \ + "ld1 {v" reg_index \ + ".h}[1], [x0], #2\n" \ + "cmp %w[x_remain], #2\n" \ + "beq LOAD_4x4_LINE_END_" n \ + "\n" \ + "ld1 {v" reg_index \ + ".h}[2], [x0], #2\n" \ + "LOAD_4x4_LINE_END_" n \ + ":\n" \ + "subs w1, w1, #1\n" + +#define LOAD_C \ + "mov w1, %w[y_remain]\n" \ + LOAD_LINE("8", "0") \ + LOAD_LINE("9", "1") \ + LOAD_LINE("10", "2") \ + LOAD_LINE("11", "3") \ + "LOAD_4x4_C_END:\n" + +#define STORE_LINE(reg_index, n) \ + "cbz w1, STORE_4x4_C_END\n" \ + "mov x0, %[outptr" n \ + "]\n" \ + "cmp %w[x_remain], #4\n" \ + "b.lt REMAIN_STORE_4x4_LINE_LESS_THAN_4_" n \ + "\n" \ + "str d" reg_index \ + ", [x0]\n" \ + "b STORE_4x4_LINE_END_" n \ + "\n" \ + \ + "REMAIN_STORE_4x4_LINE_LESS_THAN_4_" n \ + ":\n" \ + "cmp %w[x_remain], #0\n" \ + "beq STORE_4x4_LINE_END_" n \ + "\n" \ + "st1 {v" reg_index \ + ".h}[0], [x0], #2\n" \ + "cmp %w[x_remain], #1\n" \ + "beq STORE_4x4_LINE_END_" n \ + 
"\n" \ + "st1 {v" reg_index \ + ".h}[1], [x0], #2\n" \ + "cmp %w[x_remain], #2\n" \ + "beq STORE_4x4_LINE_END_" n \ + "\n" \ + "st1 {v" reg_index \ + ".h}[2], [x0], #2\n" \ + "STORE_4x4_LINE_END_" n \ + ":\n" \ + "subs w1, w1, #1\n" + +#define STORE_C "mov w1, %w[y_remain]\n" \ + STORE_LINE("8", "0") \ + STORE_LINE("9", "1") \ + STORE_LINE("10", "2") \ + STORE_LINE("11", "3") \ + "STORE_4x4_C_END:\n" + + asm volatile( + ".arch armv8.2-a+fp16\n" + + // load accumulator C + "cmp %w[type], #0\n" + "beq 5f\n" LOAD_C + "b 6f\n" + + "5:\n" + "eor v8.8b, v8.8b, v8.8b\n" + "eor v9.8b, v9.8b, v9.8b\n" + "eor v10.8b, v10.8b, v10.8b\n" + "eor v11.8b, v11.8b, v11.8b\n" + + "6:\n" + "ldr %d[a0], [%[a_ptr]]\n" + + "cbz %w[k], 4f\n" + + "1:\n" + "ldp %d[b0], %d[b0a], [%[b_ptr]]\n" + "fmla v8.4h , %[b0].4h, %[a0].h[0]\n" + "fmla v9.4h , %[b0].4h, %[a0].h[1]\n" + "ldr %d[a0a], [%[a_ptr], #8]\n" + "fmla v10.4h, %[b0].4h, %[a0].h[2]\n" + "fmla v11.4h, %[b0].4h, %[a0].h[3]\n" + + "add %[b_ptr], %[b_ptr], #16\n" + "ldr %d[a0], [%[a_ptr], #16]\n" + + "fmla v8.4h , %[b0a].4h, %[a0a].h[0]\n" + "fmla v9.4h , %[b0a].4h, %[a0a].h[1]\n" + "fmla v10.4h, %[b0a].4h, %[a0a].h[2]\n" + "fmla v11.4h, %[b0a].4h, %[a0a].h[3]\n" + + "add %[a_ptr], %[a_ptr], #16\n" + "subs %w[k], %w[k], #1\n" + + "bne 1b\n" + "4:\n" + // Jump to odd tail if necessary. + "cbnz %w[oddk], 2f\n" + + // Even tail + "ldp %d[b0], %d[b0a], [%[b_ptr]]\n" + "fmla v8.4h , %[b0].4h, %[a0].h[0]\n" + "fmla v9.4h , %[b0].4h, %[a0].h[1]\n" + "ldr %d[a0a], [%[a_ptr], #8]\n" + "fmla v10.4h, %[b0].4h, %[a0].h[2]\n" + "fmla v11.4h, %[b0].4h, %[a0].h[3]\n" + + "add %[b_ptr], %[b_ptr], #16\n" + "add %[a_ptr], %[a_ptr], #16\n" + + "fmla v8.4h, %[b0a].4h, %[a0a].h[0]\n" + "fmla v9.4h, %[b0a].4h, %[a0a].h[1]\n" + "fmla v10.4h, %[b0a].4h, %[a0a].h[2]\n" + "fmla v11.4h, %[b0a].4h, %[a0a].h[3]\n" + "b 3f\n" + + // Odd tail + "2:\n" + "ldr %d[b0], [%[b_ptr]]\n" + "add %[a_ptr], %[a_ptr], #8\n" + "add %[b_ptr], %[b_ptr], #8\n" + + "fmla v8.4h, %[b0].4h, %[a0].h[0]\n" + "fmla v9.4h, %[b0].4h, %[a0].h[1]\n" + "fmla v10.4h, %[b0].4h, %[a0].h[2]\n" + "fmla v11.4h, %[b0].4h, %[a0].h[3]\n" + + "3:\n" STORE_C + : [a0] "+w"(a0), [a0a] "+w"(a0a), [b0] "+w"(b0), + [k] "+r"(k), [b0a] "+w"(b0a), [a_ptr] "+r"(a_ptr), + [b_ptr] "+r"(b_ptr), [outptr0] "+r"(outptr0), + [outptr1] "+r"(outptr1), [outptr2] "+r"(outptr2), + [outptr3] "+r"(outptr3) + : [oddk] "r"(oddk), [x_remain] "r"(x_remain), + [y_remain] "r"(y_remain), [type] "r"(type) + : "x0", "w1", "v8", "v9", "v10", "v11", "cc", "memory"); + +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +} + +void aarch64_hgemm_asimd_8x24(const __fp16* Apanel, const __fp16* Bpanel, + __fp16* out, int ldout, int x0, int xmax, int y0, + int ymax, int K, bool is_first_k) { + const __fp16* a_ptr = Apanel; + const int A_interleave = 8; + const int B_transpose1xW = 24; + const int K8 = (K << 3); + const int K4 = (K << 2); + int type = is_first_k ? 
0 : 1; + + int y = y0; + for (; y + A_interleave <= ymax; y += A_interleave) { + const __fp16* a_ptr0 = a_ptr; + const __fp16* b_ptr = Bpanel; + + __fp16* outptr0 = out + (y * ldout) + x0; + + int x = x0; + + for (; x + B_transpose1xW <= xmax; x += B_transpose1xW) { + a_ptr = a_ptr0; + aarch64_hgemm_assembly_kernel_24x8(a_ptr, b_ptr, K, outptr0, ldout, + type); + outptr0 += B_transpose1xW; + } + + for (; x + 16 <= xmax; x += 16) { + a_ptr = a_ptr0; + aarch64_hgemm_assembly_kernel_16x8(a_ptr, b_ptr, K, outptr0, ldout, + type); + outptr0 += 16; + } + for (; x + 8 <= xmax; x += 8) { + a_ptr = a_ptr0; + aarch64_hgemm_assembly_kernel_8x8(a_ptr, b_ptr, K, outptr0, ldout, + type); + outptr0 += 8; + } + for (; x < xmax; x += 4) { + int x_remain = xmax - x; + a_ptr = a_ptr0; + aarch64_hgemm_assembly_kernel_4x8(a_ptr, b_ptr, K, outptr0, ldout, + x_remain, type); + outptr0 += 4; + } + a_ptr = a_ptr0 + K8; + } + + for (; y < ymax; y += 4) { + const __fp16* a_ptr0 = a_ptr; + const __fp16* b_ptr = Bpanel; + + __fp16* outptr0 = out + (y * ldout) + x0; + + int x = x0; + for (; x + B_transpose1xW <= xmax; x += B_transpose1xW) { + a_ptr = a_ptr0; + aarch64_hgemm_assembly_kernel_24x4(a_ptr, b_ptr, K, outptr0, ldout, + ymax - y, type); + outptr0 += B_transpose1xW; + } + + for (; x + 16 <= xmax; x += 16) { + a_ptr = a_ptr0; + aarch64_hgemm_assembly_kernel_16x4(a_ptr, b_ptr, K, outptr0, ldout, + ymax - y, type); + outptr0 += 16; + } + for (; x + 8 <= xmax; x += 8) { + a_ptr = a_ptr0; + aarch64_hgemm_assembly_kernel_8x4(a_ptr, b_ptr, K, outptr0, ldout, + ymax - y, type); + outptr0 += 8; + } + for (; x < xmax; x += 4) { + a_ptr = a_ptr0; + aarch64_hgemm_assembly_kernel_4x4(a_ptr, b_ptr, K, outptr0, ldout, + xmax - x, ymax - y, type); + outptr0 += 4; + } + a_ptr = a_ptr0 + K4; + } +} +} // namespace + +MEGDNN_REG_GEMM_STRATEGY_IMPL(hgemm_8x24); + +void hgemm_8x24::pack_A(dt_float16* out, const dt_float16* in, int ldin, int y0, + int ymax, int k0, int kmax, bool transpose_A) const { + if (transpose_A) { + transpose_1x8(reinterpret_cast<__fp16*>(out), + reinterpret_cast(in), ldin, y0, ymax, k0, + kmax); + } else { + interleave_8x1(reinterpret_cast<__fp16*>(out), + reinterpret_cast(in), ldin, y0, ymax, k0, + kmax); + } +} + +void hgemm_8x24::pack_B(dt_float16* out, const dt_float16* in, int ldin, int x0, + int xmax, int k0, int kmax, bool transpose_B) const { + if (transpose_B) { + interleave_24x1(reinterpret_cast<__fp16*>(out), + reinterpret_cast(in), ldin, x0, xmax, k0, + kmax); + } else { + transpose_1x24(reinterpret_cast<__fp16*>(out), + reinterpret_cast(in), ldin, x0, xmax, k0, + kmax); + } +} + +void hgemm_8x24::kern(const dt_float16* packA, const dt_float16* packB, + size_t M, size_t N, size_t K, dt_float16* C, size_t LDC, + bool is_first_k, const dt_float16*, dt_float16*) const { + megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && + A_dtype.enumv() == C_dtype.enumv() && + A_dtype.enumv() == DTypeEnum::Float16); + MEGDNN_MARK_USED_VAR(A_dtype); + MEGDNN_MARK_USED_VAR(B_dtype); + MEGDNN_MARK_USED_VAR(C_dtype); + aarch64_hgemm_asimd_8x24(reinterpret_cast(packA), + reinterpret_cast(packB), + reinterpret_cast<__fp16*>(C), LDC, 0, N, 0, M, K, + is_first_k); +} +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/fp16/strategy.h b/dnn/src/aarch64/matrix_mul/fp16/strategy.h new file mode 100644 index 00000000..f3358b69 --- /dev/null +++ b/dnn/src/aarch64/matrix_mul/fp16/strategy.h @@ -0,0 +1,29 @@ +/** + * \file dnn/src/aarch64/matrix_mul/fp16/strategy.h + * MegEngine is Licensed under the 
Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once +#include "src/fallback/matrix_mul/gemm_common.h" + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +namespace megdnn { +namespace aarch64 { +namespace matmul { + +MEGDNN_REG_GEMM_STRATEGY(dt_float16, dt_float16, dt_float16, 8, 24, 1, false, + true, hgemm_8x24); + +MEGDNN_REG_GEMM_STRATEGY_NOPACK(dt_float16, dt_float16, dt_float16, 8, 8, 1, + false, true, gemm_nopack_f16_8x8); + +} // namespace matmul +} // namespace aarch64 +} // namespace megdnn +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/fp16/strategy_mk8_8x8.cpp b/dnn/src/aarch64/matrix_mul/fp16/strategy_mk8_8x8.cpp new file mode 100644 index 00000000..fd3d7e30 --- /dev/null +++ b/dnn/src/aarch64/matrix_mul/fp16/strategy_mk8_8x8.cpp @@ -0,0 +1,439 @@ +/** + * \file dnn/src/aarch64/matrix_mul/fp16/strategy_mk8_8x8.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "src/aarch64/matrix_mul/fp16/strategy.h" +#include "src/aarch64/matrix_mul/asm/common.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/utils.h" + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +using namespace megdnn; +using namespace aarch64; +using namespace aarch64::matmul; + +namespace { + +// Overview of register layout: +// +// A 8x1 cell of Rhs is stored in 16bit in v0-v3 +// A 8x1 cell of Lhs is stored in 16bit in v16-v23 +// A 8x1 block of accumulators is stored in 16bit in v24-v27. +// +// Rhs +-------+ +// |v0[0-7]| +// |v1[0-7]| +// |v2[0-7]| +// |v3[0-7]| +// +-------+ +// Lhs +// +--------+ +// |v16[0-7]| +// |v17[0-7]| +// |v18[0-7]| +// |v19[0-7]| +--------+ +// |v20[0-7]| |v24[0-7]| +// |v21[0-7]| |v25[0-7]| +// |v22[0-7]| |v26[0-7]| +// |v23[0-7]| |v27[0-7]| +// +--------+ +--------+ +// Accumulator +void kern_8x4(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, + dt_float16* output) { + //! LDB means number of elements in one block in B. we will read 24 numbers + //! first. so minus 24 * 2 bytes here. 
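+    //! Worked out in bytes: every 8 steps of K the kernel advances b_ptr by
+    //! 3 * 16 bytes (the 24 fp16 values loaded into v0-v2) plus this adjusted
+    //! LDB, i.e. 48 + (LDB - 24) * 2 = 2 * LDB bytes, which is exactly one
+    //! full B block of LDB fp16 elements.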
+ LDB = (LDB - 24) * sizeof(dt_float16); + + asm volatile( + ".arch armv8.2-a+fp16\n" + + "ld1 {v16.4s, v17.4s}, [%[a_ptr]], 32\n" + + "subs %w[K], %w[K], #8\n" + "ld1 {v0.4s}, [%[b_ptr]], 16\n" + + "ld1 {v1.4s}, [%[b_ptr]], 16\n" + "fmul v24.8h, v16.8h, v0.h[0]\n" + + "ld1 {v2.4s}, [%[b_ptr]], 16\n" + "fmul v25.8h, v16.8h, v1.h[0]\n" + + "ld1 {v3.4s}, [%[b_ptr]], %x[LDB]\n" + "fmul v26.8h, v16.8h, v2.h[0]\n" + + "ld1 {v18.4s}, [%[a_ptr]], 16\n" + "fmul v27.8h, v16.8h, v3.h[0]\n" + + "fmla v24.8h, v17.8h, v0.h[1]\n" + "fmla v25.8h, v17.8h, v1.h[1]\n" + "fmla v26.8h, v17.8h, v2.h[1]\n" + "fmla v27.8h, v17.8h, v3.h[1]\n" + + "ld1 {v19.4s}, [%[a_ptr]], 16\n" + + "fmla v24.8h, v18.8h, v0.h[2]\n" + "fmla v25.8h, v18.8h, v1.h[2]\n" + "fmla v26.8h, v18.8h, v2.h[2]\n" + "fmla v27.8h, v18.8h, v3.h[2]\n" + + "ld1 {v20.4s}, [%[a_ptr]], 16\n" + + "fmla v24.8h, v19.8h, v0.h[3]\n" + "fmla v25.8h, v19.8h, v1.h[3]\n" + "fmla v26.8h, v19.8h, v2.h[3]\n" + "fmla v27.8h, v19.8h, v3.h[3]\n" + + "ld1 {v21.4s}, [%[a_ptr]], 16\n" + + "fmla v24.8h, v20.8h, v0.h[4]\n" + "fmla v25.8h, v20.8h, v1.h[4]\n" + "fmla v26.8h, v20.8h, v2.h[4]\n" + "fmla v27.8h, v20.8h, v3.h[4]\n" + + "ld1 {v22.4s}, [%[a_ptr]], 16\n" + + "fmla v24.8h, v21.8h, v0.h[5]\n" + "fmla v25.8h, v21.8h, v1.h[5]\n" + "fmla v26.8h, v21.8h, v2.h[5]\n" + "fmla v27.8h, v21.8h, v3.h[5]\n" + + "ld1 {v23.4s}, [%[a_ptr]], 16\n" + + "fmla v24.8h, v22.8h, v0.h[6]\n" + "fmla v25.8h, v22.8h, v1.h[6]\n" + "fmla v26.8h, v22.8h, v2.h[6]\n" + "fmla v27.8h, v22.8h, v3.h[6]\n" + + "beq 2f\n" + + "1:\n" + + "ld1 {v16.4s}, [%[a_ptr]], 16\n" + + "fmla v24.8h, v23.8h, v0.h[7]\n" + "ld1 {v0.4s}, [%[b_ptr]], 16\n" + + "fmla v25.8h, v23.8h, v1.h[7]\n" + "ld1 {v1.4s}, [%[b_ptr]], 16\n" + + "fmla v26.8h, v23.8h, v2.h[7]\n" + "ld1 {v2.4s}, [%[b_ptr]], 16\n" + + "fmla v27.8h, v23.8h, v3.h[7]\n" + "ld1 {v3.4s}, [%[b_ptr]], %x[LDB]\n" + + "ld1 {v17.4s}, [%[a_ptr]], 16\n" + + "fmla v24.8h, v16.8h, v0.h[0]\n" + "fmla v25.8h, v16.8h, v1.h[0]\n" + "fmla v26.8h, v16.8h, v2.h[0]\n" + "fmla v27.8h, v16.8h, v3.h[0]\n" + + "ld1 {v18.4s}, [%[a_ptr]], 16\n" + + "fmla v24.8h, v17.8h, v0.h[1]\n" + "fmla v25.8h, v17.8h, v1.h[1]\n" + "fmla v26.8h, v17.8h, v2.h[1]\n" + "fmla v27.8h, v17.8h, v3.h[1]\n" + + "ld1 {v19.4s}, [%[a_ptr]], 16\n" + + "fmla v24.8h, v18.8h, v0.h[2]\n" + "fmla v25.8h, v18.8h, v1.h[2]\n" + "fmla v26.8h, v18.8h, v2.h[2]\n" + "fmla v27.8h, v18.8h, v3.h[2]\n" + + "ld1 {v20.4s}, [%[a_ptr]], 16\n" + + "fmla v24.8h, v19.8h, v0.h[3]\n" + "fmla v25.8h, v19.8h, v1.h[3]\n" + "fmla v26.8h, v19.8h, v2.h[3]\n" + "fmla v27.8h, v19.8h, v3.h[3]\n" + + "ld1 {v21.4s}, [%[a_ptr]], 16\n" + + "fmla v24.8h, v20.8h, v0.h[4]\n" + "fmla v25.8h, v20.8h, v1.h[4]\n" + "fmla v26.8h, v20.8h, v2.h[4]\n" + "fmla v27.8h, v20.8h, v3.h[4]\n" + + "ld1 {v22.4s}, [%[a_ptr]], 16\n" + + "fmla v24.8h, v21.8h, v0.h[5]\n" + "fmla v25.8h, v21.8h, v1.h[5]\n" + "fmla v26.8h, v21.8h, v2.h[5]\n" + "fmla v27.8h, v21.8h, v3.h[5]\n" + + "ld1 {v23.4s}, [%[a_ptr]], 16\n" + + "fmla v24.8h, v22.8h, v0.h[6]\n" + "fmla v25.8h, v22.8h, v1.h[6]\n" + "fmla v26.8h, v22.8h, v2.h[6]\n" + "fmla v27.8h, v22.8h, v3.h[6]\n" + + "subs %w[K], %w[K], #8\n" + "bne 1b\n" + + "2:\n" + + "fmla v24.8h, v23.8h, v0.h[7]\n" + "fmla v25.8h, v23.8h, v1.h[7]\n" + "fmla v26.8h, v23.8h, v2.h[7]\n" + "fmla v27.8h, v23.8h, v3.h[7]\n" + + "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%[output]], 64\n" + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [output] "+r"(output), [LDB] "+r"(LDB) + : + : "v0", "v1", "v2", "v3", "v16", "v17", "v18", "v19", "v20", 
"v21", + "v22", "v23", "v24", "v25", "v26", "v27", "cc", "memory"); +} + +// Overview of register layout: +// +// A 8x1 cell of Rhs is stored in 16bit in v8-v15 +// A 8x1 cell of Lhs is stored in 16bit in v0-v7 +// A 8x1 block of accumulators is stored in 16bit in v24-v31. +// +// Rhs +--------+ +// | v8[0-7]| +// | v9[0-7]| +// |v10[0-7]| +// |v11[0-7]| +// |v12[0-7]| +// |v13[0-7]| +// |v14[0-7]| +// |v15[0-7]| +// +--------+ +// Lhs +// +--------+ - - - - -+--------+ +// | v0[0-7]| |v24[0-7]| +// | v1[0-7]| |v25[0-7]| +// | v2[0-7]| |v26[0-7]| +// | v3[0-7]| |v27[0-7]| +// | v4[0-7]| |v28[0-7]| +// | v5[0-7]| |v29[0-7]| +// | v6[0-7]| |v30[0-7]| +// | v7[0-7]| |v31[0-7]| +// +--------+ +--------+ +// Accumulator +void kern_8x8(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, + dt_float16* output) { + //! As each load 128 number from B, but the pos add 112 * 2, so we minus 112 + //! here. + LDB = (LDB - 32) * sizeof(dt_float16); + + asm volatile( + ".arch armv8.2-a+fp16\n" + + "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[a_ptr]], 64\n" + "subs %w[K], %w[K], #8\n" + "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b_ptr]], 64\n" + "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[b_ptr]], %x[LDB]\n" + + "fmul v24.8h, v8.8h, v0.h[0]\n" + "fmul v25.8h, v8.8h, v1.h[0]\n" + "fmul v26.8h, v8.8h, v2.h[0]\n" + "fmul v27.8h, v8.8h, v3.h[0]\n" + "fmul v28.8h, v8.8h, v4.h[0]\n" + "fmul v29.8h, v8.8h, v5.h[0]\n" + "fmul v30.8h, v8.8h, v6.h[0]\n" + "fmul v31.8h, v8.8h, v7.h[0]\n" + + "fmla v24.8h, v9.8h, v0.h[1]\n" + "fmla v25.8h, v9.8h, v1.h[1]\n" + "fmla v26.8h, v9.8h, v2.h[1]\n" + "fmla v27.8h, v9.8h, v3.h[1]\n" + "fmla v28.8h, v9.8h, v4.h[1]\n" + "fmla v29.8h, v9.8h, v5.h[1]\n" + "fmla v30.8h, v9.8h, v6.h[1]\n" + "fmla v31.8h, v9.8h, v7.h[1]\n" + + "ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%[a_ptr]], 64\n" + "fmla v24.8h, v10.8h, v0.h[2]\n" + "fmla v25.8h, v10.8h, v1.h[2]\n" + "fmla v26.8h, v10.8h, v2.h[2]\n" + "fmla v27.8h, v10.8h, v3.h[2]\n" + "fmla v28.8h, v10.8h, v4.h[2]\n" + "fmla v29.8h, v10.8h, v5.h[2]\n" + "fmla v30.8h, v10.8h, v6.h[2]\n" + "fmla v31.8h, v10.8h, v7.h[2]\n" + + "fmla v24.8h, v11.8h, v0.h[3]\n" + "fmla v25.8h, v11.8h, v1.h[3]\n" + "fmla v26.8h, v11.8h, v2.h[3]\n" + "fmla v27.8h, v11.8h, v3.h[3]\n" + "fmla v28.8h, v11.8h, v4.h[3]\n" + "fmla v29.8h, v11.8h, v5.h[3]\n" + "fmla v30.8h, v11.8h, v6.h[3]\n" + "fmla v31.8h, v11.8h, v7.h[3]\n" + + "fmla v24.8h, v12.8h, v0.h[4]\n" + "fmla v25.8h, v12.8h, v1.h[4]\n" + "fmla v26.8h, v12.8h, v2.h[4]\n" + "fmla v27.8h, v12.8h, v3.h[4]\n" + "fmla v24.8h, v13.8h, v0.h[5]\n" + "fmla v25.8h, v13.8h, v1.h[5]\n" + "fmla v26.8h, v13.8h, v2.h[5]\n" + "fmla v27.8h, v13.8h, v3.h[5]\n" + + "beq 2f\n" + + "1:\n" + + "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[a_ptr]], 64\n" + "fmla v24.8h, v15.8h, v0.h[7]\n" + "fmla v25.8h, v15.8h, v1.h[7]\n" + "fmla v26.8h, v15.8h, v2.h[7]\n" + "fmla v27.8h, v15.8h, v3.h[7]\n" + "fmla v24.8h, v14.8h, v0.h[6]\n" + "fmla v25.8h, v14.8h, v1.h[6]\n" + "fmla v26.8h, v14.8h, v2.h[6]\n" + "fmla v27.8h, v14.8h, v3.h[6]\n" + "fmla v28.8h, v12.8h, v4.h[4]\n" + "fmla v29.8h, v12.8h, v5.h[4]\n" + "fmla v30.8h, v12.8h, v6.h[4]\n" + "fmla v31.8h, v12.8h, v7.h[4]\n" + + "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b_ptr]], 64\n" + "fmla v28.8h, v13.8h, v4.h[5]\n" + "fmla v29.8h, v13.8h, v5.h[5]\n" + "fmla v30.8h, v13.8h, v6.h[5]\n" + "fmla v31.8h, v13.8h, v7.h[5]\n" + "fmla v28.8h, v14.8h, v4.h[6]\n" + "fmla v29.8h, v14.8h, v5.h[6]\n" + "fmla v30.8h, v14.8h, v6.h[6]\n" + "fmla v31.8h, v14.8h, v7.h[6]\n" + "fmla v28.8h, v15.8h, v4.h[7]\n" + "fmla v29.8h, 
v15.8h, v5.h[7]\n" + "fmla v30.8h, v15.8h, v6.h[7]\n" + "fmla v31.8h, v15.8h, v7.h[7]\n" + "fmla v24.8h, v8.8h, v0.h[0]\n" + "fmla v25.8h, v8.8h, v1.h[0]\n" + "fmla v26.8h, v8.8h, v2.h[0]\n" + "fmla v27.8h, v8.8h, v3.h[0]\n" + + "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[b_ptr]], %x[LDB]\n" + "fmla v24.8h, v9.8h, v0.h[1]\n" + "fmla v25.8h, v9.8h, v1.h[1]\n" + "fmla v26.8h, v9.8h, v2.h[1]\n" + "fmla v27.8h, v9.8h, v3.h[1]\n" + "fmla v24.8h, v10.8h, v0.h[2]\n" + "fmla v25.8h, v10.8h, v1.h[2]\n" + "fmla v26.8h, v10.8h, v2.h[2]\n" + "fmla v27.8h, v10.8h, v3.h[2]\n" + "fmla v24.8h, v11.8h, v0.h[3]\n" + "fmla v25.8h, v11.8h, v1.h[3]\n" + "fmla v26.8h, v11.8h, v2.h[3]\n" + "fmla v27.8h, v11.8h, v3.h[3]\n" + + "ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%[a_ptr]], 64\n" + "fmla v28.8h, v10.8h, v4.h[2]\n" + "fmla v29.8h, v10.8h, v5.h[2]\n" + "fmla v30.8h, v10.8h, v6.h[2]\n" + "fmla v31.8h, v10.8h, v7.h[2]\n" + "fmla v28.8h, v8.8h, v4.h[0]\n" + "fmla v29.8h, v8.8h, v5.h[0]\n" + "fmla v30.8h, v8.8h, v6.h[0]\n" + "fmla v31.8h, v8.8h, v7.h[0]\n" + "fmla v28.8h, v9.8h, v4.h[1]\n" + "fmla v29.8h, v9.8h, v5.h[1]\n" + "fmla v30.8h, v9.8h, v6.h[1]\n" + "fmla v31.8h, v9.8h, v7.h[1]\n" + + "fmla v28.8h, v11.8h, v4.h[3]\n" + "fmla v29.8h, v11.8h, v5.h[3]\n" + "fmla v30.8h, v11.8h, v6.h[3]\n" + "fmla v31.8h, v11.8h, v7.h[3]\n" + + "fmla v24.8h, v12.8h, v0.h[4]\n" + "fmla v25.8h, v12.8h, v1.h[4]\n" + "fmla v26.8h, v12.8h, v2.h[4]\n" + "fmla v27.8h, v12.8h, v3.h[4]\n" + "fmla v24.8h, v13.8h, v0.h[5]\n" + "fmla v25.8h, v13.8h, v1.h[5]\n" + "fmla v26.8h, v13.8h, v2.h[5]\n" + "fmla v27.8h, v13.8h, v3.h[5]\n" + + "subs %w[K], %w[K], #8\n" + "bne 1b\n" + + "2:\n" + "fmla v24.8h, v14.8h, v0.h[6]\n" + "fmla v25.8h, v14.8h, v1.h[6]\n" + "fmla v26.8h, v14.8h, v2.h[6]\n" + "fmla v27.8h, v14.8h, v3.h[6]\n" + "fmla v24.8h, v15.8h, v0.h[7]\n" + "fmla v25.8h, v15.8h, v1.h[7]\n" + "fmla v26.8h, v15.8h, v2.h[7]\n" + "fmla v27.8h, v15.8h, v3.h[7]\n" + "fmla v28.8h, v12.8h, v4.h[4]\n" + "fmla v29.8h, v12.8h, v5.h[4]\n" + "fmla v28.8h, v13.8h, v4.h[5]\n" + "fmla v29.8h, v13.8h, v5.h[5]\n" + "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%[output]], 64\n" + "fmla v28.8h, v14.8h, v4.h[6]\n" + "fmla v29.8h, v14.8h, v5.h[6]\n" + "fmla v28.8h, v15.8h, v4.h[7]\n" + "fmla v29.8h, v15.8h, v5.h[7]\n" + "fmla v30.8h, v12.8h, v6.h[4]\n" + "fmla v31.8h, v12.8h, v7.h[4]\n" + "fmla v30.8h, v13.8h, v6.h[5]\n" + "fmla v31.8h, v13.8h, v7.h[5]\n" + "fmla v30.8h, v14.8h, v6.h[6]\n" + "fmla v31.8h, v14.8h, v7.h[6]\n" + "fmla v30.8h, v15.8h, v6.h[7]\n" + "fmla v31.8h, v15.8h, v7.h[7]\n" + "st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%[output]], 64\n" + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [output] "+r"(output), [LDB] "+r"(LDB) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v24", "v25", "v26", "v27", + "v28", "v29", "v30", "v31", "cc", "memory"); +} + +} // anonymous namespace + +MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(gemm_nopack_f16_8x8); + +void gemm_nopack_f16_8x8::kern(const dt_float16* A, size_t LDA, + const dt_float16* B, size_t LDB, dt_float16* C, + size_t LDC, size_t M, size_t K, size_t N, + const dt_float16*, void*, bool trA, + bool trB) const { + constexpr static size_t MB = 8; + constexpr static size_t KB = 8; + constexpr static size_t NB = 8; + constexpr static size_t CALCBLK = 4; + + megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0 && N % CALCBLK == 0); + + //! 
(m/8, k/8, 8, 8) * (k/8, n, 8) = (m/8, n, 8) + for (size_t m = 0; m < M; m += MB) { + dt_float16* output = C + (m / MB) * LDC; + const dt_float16* cur_B = B; + size_t n = 0; + for (; n + NB - 1 < N; n += NB) { + kern_8x8(A, cur_B, LDB, K, output); + cur_B += KB * NB; + output += MB * NB; + } + if (n < N) { + kern_8x4(A, cur_B, LDB, K, output); + } + A += LDA; + } +} + +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/fp32/common.h b/dnn/src/aarch64/matrix_mul/fp32/common.h new file mode 100644 index 00000000..50120684 --- /dev/null +++ b/dnn/src/aarch64/matrix_mul/fp32/common.h @@ -0,0 +1,39 @@ +/** + * \file dnn/src/aarch64/matrix_mul/fp32/common.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once + +#include +#include "megdnn/arch.h" +#include "src/common/utils.h" + +namespace megdnn { +namespace aarch64 { + +MEGDNN_NOINLINE void sgemm_packA_n(const float* A, float* Apacked, size_t M, + size_t K, size_t LDA, const float* alpha); + +MEGDNN_NOINLINE void sgemm_packA_t(const float* A, float* Apacked, size_t M, + size_t K, size_t LDA, const float* alpha); + +MEGDNN_NOINLINE void sgemm_packB_n(const float* B, float* Bpacked, size_t K, + size_t N, size_t LDB); + +MEGDNN_NOINLINE void sgemm_packB_t(const float* B, float* Bpacked, size_t K, + size_t N, size_t LDB); + +MEGDNN_NOINLINE void sgemm_kernel12x8(const float* A, const float* B, float* C, + size_t LDC, size_t M, size_t N, size_t K, + int type, const float* beta); + +} // namespace aarch64 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/fp32/kernel_general_4x16.h b/dnn/src/aarch64/matrix_mul/fp32/kernel_general_4x16.h new file mode 100644 index 00000000..fcec5fb9 --- /dev/null +++ b/dnn/src/aarch64/matrix_mul/fp32/kernel_general_4x16.h @@ -0,0 +1,718 @@ +/** + * \file dnn/src/aarch64/matrix_mul/fp32/kernel_general_4x16.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "src/aarch64/matrix_mul/asm/common.h" +#include "src/arm_common/simd_macro/marm_neon.h" + + +namespace megdnn { +namespace aarch64 { +namespace matmul_general_4x16 { + +// Overview of register layout: +// +// A 1x16 cell of Rhs is stored in 32bit in v1-v4 +// A 4x1 cell of Lhs is stored in 32bit in v0 +// A 4x16 block of accumulators is stored in 32bit in v10-v25. 
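+// (m_remain masks the trailing rows of the 4-row block; each stored row always
+// covers the full 16 columns.)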
+// +// +--------+--------+--------+--------+ +// | v2[0-3]| v3[0-3]| v4[0-3]| v5[0-3]| +// Rhs +--------+--------+--------+--------+ +// +// | | | | | +// +// Lhs | | | | | +// +// +--+ - - - - +--------+--------+--------+--------+ +// |v0| |v10[0-3]|v11[0-3]|v12[0-3]|v13[0-3]| +// |v0| |v14[0-3]|v15[0-3]|v16[0-3]|v17[0-3]| +// |v0| |v18[0-3]|v19[0-3]|v20[0-3]|v21[0-3]| +// |v0| |v22[0-3]|v23[0-3]|v24[0-3]|v25[0-3]| +// +--+ - - - - +--------+--------+--------+--------+ +// +// Accumulator +void kern_4x16(const float* packA, const float* packB, int K, + float* output, int LDC, bool is_first_k, int m_remain) { + const float* a_ptr = packA; + const float* b_ptr = packB; + int oddk = (K & 1); + K = ((K + 1) / 2) - 1; + + LDC = LDC * sizeof(float); + register float* outptr asm("x0") = reinterpret_cast(output); + +// clang-format off +#define LOAD_LINE(v0, v1, v2, v3, n) \ + "cmp x10, #0\n" \ + "beq 100f\n" \ + "mov x9, x" n "\n" \ + "ld1 {v" v0 ".4s, v" v1 ".4s, v" v2 ".4s, v" v3 ".4s}, [x9], 64\n" \ + "subs x10, x10, #1\n" + +#define LOAD_C \ + "mov x10, %x[m_remain]\n" \ + LOAD_LINE("10", "11", "12", "13", "0") \ + LOAD_LINE("14", "15", "16", "17", "1") \ + LOAD_LINE("18", "19", "20", "21", "2") \ + LOAD_LINE("22", "23", "24", "25", "3") \ + "100:\n" + +#define STORE_LINE(v0, v1, v2, v3, n) \ + "cmp x10, #0\n" \ + "beq 101f\n" \ + "mov x9, x" n "\n" \ + "st1 {v" v0 ".4s, v" v1 ".4s, v" v2 ".4s, v" v3 ".4s}, [x9], 64\n" \ + "subs x10, x10, #1\n" + +#define STORE_C \ + "mov x10, %x[m_remain]\n" \ + STORE_LINE("10", "11", "12", "13", "0") \ + STORE_LINE("14", "15", "16", "17", "1") \ + STORE_LINE("18", "19", "20", "21", "2") \ + STORE_LINE("22", "23", "24", "25", "3") \ + "101:\n" + // clang-format on + + asm volatile( + // load accumulator C + "add x1, x0, %x[LDC]\n" + "add x2, x1, %x[LDC]\n" + "add x3, x2, %x[LDC]\n" + + "cmp %w[is_first_k], #1\n" + "beq 1f\n" LOAD_C + + "b 2f\n" + + "1:\n" + "eor v10.16b, v10.16b, v10.16b\n" + "eor v11.16b, v11.16b, v11.16b\n" + "eor v12.16b, v12.16b, v12.16b\n" + "eor v13.16b, v13.16b, v13.16b\n" + "eor v14.16b, v14.16b, v14.16b\n" + "eor v15.16b, v15.16b, v15.16b\n" + "eor v16.16b, v16.16b, v16.16b\n" + "eor v17.16b, v17.16b, v17.16b\n" + "eor v18.16b, v18.16b, v18.16b\n" + "eor v19.16b, v19.16b, v19.16b\n" + "eor v20.16b, v20.16b, v20.16b\n" + "eor v21.16b, v21.16b, v21.16b\n" + "eor v22.16b, v22.16b, v22.16b\n" + "eor v23.16b, v23.16b, v23.16b\n" + "eor v24.16b, v24.16b, v24.16b\n" + "eor v25.16b, v25.16b, v25.16b\n" + + "2: \n" + "ld1 {v2.4s, v3.4s, v4.4s, v5.4s}, [%[b_ptr]], 64\n" + + "cmp %w[K], #0\n" + "beq 4f\n" + + "3:\n" + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "fmla v10.4s, v2.4s, v0.s[0]\n" + "fmla v11.4s, v3.4s, v0.s[0]\n" + "fmla v12.4s, v4.4s, v0.s[0]\n" + "fmla v13.4s, v5.4s, v0.s[0]\n" + "ld1 {v6.4s, v7.4s, v8.4s, v9.4s}, [%[b_ptr]], 64\n" + "fmla v14.4s, v2.4s, v0.s[1]\n" + "fmla v15.4s, v3.4s, v0.s[1]\n" + "fmla v16.4s, v4.4s, v0.s[1]\n" + "fmla v17.4s, v5.4s, v0.s[1]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v18.4s, v2.4s, v0.s[2]\n" + "fmla v19.4s, v3.4s, v0.s[2]\n" + "fmla v20.4s, v4.4s, v0.s[2]\n" + "fmla v21.4s, v5.4s, v0.s[2]\n" + "fmla v22.4s, v2.4s, v0.s[3]\n" + "fmla v23.4s, v3.4s, v0.s[3]\n" + "fmla v24.4s, v4.4s, v0.s[3]\n" + "fmla v25.4s, v5.4s, v0.s[3]\n" + + "ld1 {v2.4s, v3.4s, v4.4s, v5.4s}, [%[b_ptr]], 64\n" + "fmla v10.4s, v6.4s, v1.s[0]\n" + "fmla v11.4s, v7.4s, v1.s[0]\n" + "fmla v12.4s, v8.4s, v1.s[0]\n" + "fmla v13.4s, v9.4s, v1.s[0]\n" + "fmla v14.4s, v6.4s, v1.s[1]\n" + "fmla v15.4s, v7.4s, v1.s[1]\n" + "fmla v16.4s, 
v8.4s, v1.s[1]\n" + "fmla v17.4s, v9.4s, v1.s[1]\n" + "fmla v18.4s, v6.4s, v1.s[2]\n" + "fmla v19.4s, v7.4s, v1.s[2]\n" + "fmla v20.4s, v8.4s, v1.s[2]\n" + "fmla v21.4s, v9.4s, v1.s[2]\n" + "fmla v22.4s, v6.4s, v1.s[3]\n" + "fmla v23.4s, v7.4s, v1.s[3]\n" + "fmla v24.4s, v8.4s, v1.s[3]\n" + "fmla v25.4s, v9.4s, v1.s[3]\n" + + "subs %w[K], %w[K], #1\n" + "bne 3b\n" + + "4:\n" + "cmp %w[oddk], #1\n" + "beq 5f\n" + + // Even tail + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "fmla v10.4s, v2.4s, v0.s[0]\n" + "fmla v11.4s, v3.4s, v0.s[0]\n" + "fmla v12.4s, v4.4s, v0.s[0]\n" + "fmla v13.4s, v5.4s, v0.s[0]\n" + "ld1 {v6.4s, v7.4s, v8.4s, v9.4s}, [%[b_ptr]], 64\n" + "fmla v14.4s, v2.4s, v0.s[1]\n" + "fmla v15.4s, v3.4s, v0.s[1]\n" + "fmla v16.4s, v4.4s, v0.s[1]\n" + "fmla v17.4s, v5.4s, v0.s[1]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v18.4s, v2.4s, v0.s[2]\n" + "fmla v19.4s, v3.4s, v0.s[2]\n" + "fmla v20.4s, v4.4s, v0.s[2]\n" + "fmla v21.4s, v5.4s, v0.s[2]\n" + "fmla v22.4s, v2.4s, v0.s[3]\n" + "fmla v23.4s, v3.4s, v0.s[3]\n" + "fmla v24.4s, v4.4s, v0.s[3]\n" + "fmla v25.4s, v5.4s, v0.s[3]\n" + + "fmla v10.4s, v6.4s, v1.s[0]\n" + "fmla v11.4s, v7.4s, v1.s[0]\n" + "fmla v12.4s, v8.4s, v1.s[0]\n" + "fmla v13.4s, v9.4s, v1.s[0]\n" + "fmla v14.4s, v6.4s, v1.s[1]\n" + "fmla v15.4s, v7.4s, v1.s[1]\n" + "fmla v16.4s, v8.4s, v1.s[1]\n" + "fmla v17.4s, v9.4s, v1.s[1]\n" + "fmla v18.4s, v6.4s, v1.s[2]\n" + "fmla v19.4s, v7.4s, v1.s[2]\n" + "fmla v20.4s, v8.4s, v1.s[2]\n" + "fmla v21.4s, v9.4s, v1.s[2]\n" + "fmla v22.4s, v6.4s, v1.s[3]\n" + "fmla v23.4s, v7.4s, v1.s[3]\n" + "fmla v24.4s, v8.4s, v1.s[3]\n" + "fmla v25.4s, v9.4s, v1.s[3]\n" + + "b 6f\n" + + // odd tail + "5:\n" + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "fmla v10.4s, v2.4s, v0.s[0]\n" + "fmla v11.4s, v3.4s, v0.s[0]\n" + "fmla v12.4s, v4.4s, v0.s[0]\n" + "fmla v13.4s, v5.4s, v0.s[0]\n" + "fmla v14.4s, v2.4s, v0.s[1]\n" + "fmla v15.4s, v3.4s, v0.s[1]\n" + "fmla v16.4s, v4.4s, v0.s[1]\n" + "fmla v17.4s, v5.4s, v0.s[1]\n" + "fmla v18.4s, v2.4s, v0.s[2]\n" + "fmla v19.4s, v3.4s, v0.s[2]\n" + "fmla v20.4s, v4.4s, v0.s[2]\n" + "fmla v21.4s, v5.4s, v0.s[2]\n" + "fmla v22.4s, v2.4s, v0.s[3]\n" + "fmla v23.4s, v3.4s, v0.s[3]\n" + "fmla v24.4s, v4.4s, v0.s[3]\n" + "fmla v25.4s, v5.4s, v0.s[3]\n" + + "6:\n" STORE_C + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), + [m_remain] "+r"(m_remain), [outptr] "+r"(outptr) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "x1", "x2", "x3", "x9", + "x10", "cc", "memory"); + +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +} + +// Overview of register layout: +// +// A 2x4 cell of Rhs is stored in 32bit in v2 - v3 +// A 4x2 cell of Lhs is stored in 32bit in v0 - v1 +// A 4x4 block of accumulators is stored in 32bit in v4-v6 +// +// +--------+ +// | v2[0-3]| +// Rhs +--------+ +// | v3[0-3]| +// +--------+ +// +// | | +// +// Lhs | | +// +// +--+--+ - - - - +--------+ +// |v0|v1| | v4[0-3]| +// |v0|v1| | v5[0-3]| +// |v0|v1| | v6[0-3]| +// |v0|v1| | v7[0-3]| +// +--+--+ - - - - +--------+ +// +// Accumulator +void kern_4x4(const float* packA, const float* packB, int K, float* output, + int LDC, bool is_first_k, int m_remain, int n_remain) { + const float* a_ptr = packA; + const float* b_ptr = packB; + int oddk = (K & 1); + K = ((K + 1) / 2) - 1; + + LDC = LDC * sizeof(float); + register float* 
outptr asm("x0") = output; + +// clang-format off +#define LOAD_LINE(v0, n) \ + "cmp x10, #0\n" \ + "beq 102f\n" \ + "cmp %w[n_remain], #4\n" \ + "blt 100" n "f\n" \ + "ld1 {v" v0 ".4s}, [x" n "], 16\n" \ + "b 101" n "f\n" \ + "100" n ":\n" \ + "cmp %w[n_remain], #0\n" \ + "beq 101" n "f\n" \ + "ld1 {v" v0 ".s}[0], [x" n "], 4\n" \ + "cmp %w[n_remain], #1\n" \ + "beq 101" n "f\n" \ + "ld1 {v" v0 ".s}[1], [x" n "], 4\n" \ + "cmp %w[n_remain], #2\n" \ + "beq 101" n "f\n" \ + "ld1 {v" v0 ".s}[2], [x" n "], 4\n" \ + "101" n ":\n" \ + "subs x10, x10, #1\n" + +#define LOAD_C \ + "mov x10, %x[m_remain]\n" \ + LOAD_LINE("4", "0") \ + LOAD_LINE("5", "1") \ + LOAD_LINE("6", "2") \ + LOAD_LINE("7", "3") \ + "102:\n" + +#define STORE_LINE(v0, n) \ + "cmp x10, #0 \n" \ + "beq 105f\n" \ + "cmp %w[n_remain], #4\n" \ + "blt 103" n "f\n" \ + "st1 {v" v0 ".4s}, [x" n " ], 16\n" \ + "b 104" n "f\n" \ + "103" n ":\n" \ + "cmp %w[n_remain], #0\n" \ + "beq 104" n "f\n" \ + "st1 {v" v0 ".s}[0], [x" n "], 4\n" \ + "cmp %w[n_remain], #1\n" \ + "beq 104" n "f\n" \ + "st1 {v" v0 ".s}[1], [x" n "], 4\n" \ + "cmp %w[n_remain], #2\n" \ + "beq 104" n "f\n" \ + "st1 {v" v0 ".s}[2], [x" n "], 4\n" \ + "104" n ":\n" \ + "subs x10, x10, #1\n" + + +#define STORE_C \ + "mov x10, %x[m_remain]\n" \ + STORE_LINE("4", "0") \ + STORE_LINE("5", "1") \ + STORE_LINE("6", "2") \ + STORE_LINE("7", "3") \ + "105:\n" +// clang-format on + + asm volatile( + // load accumulator C + "add x1, x0, %x[LDC]\n" + "add x2, x1, %x[LDC]\n" + "add x3, x2, %x[LDC]\n" + + "cmp %w[is_first_k], #1\n" + "beq 1f\n" LOAD_C + + "b 2f\n" + + "1:\n" + "eor v4.16b, v4.16b, v4.16b\n" + "eor v5.16b, v5.16b, v5.16b\n" + "eor v6.16b, v6.16b, v6.16b\n" + "eor v7.16b, v7.16b, v7.16b\n" + + "2: \n" + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "ld1 {v2.4s}, [%[b_ptr]], 16\n" + "cmp %w[K], #0\n" + "beq 4f\n" + + "3:\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "ld1 {v3.4s}, [%[b_ptr]], 16\n" + "fmla v4.4s, v2.4s, v0.s[0]\n" + "fmla v5.4s, v2.4s, v0.s[1]\n" + "fmla v6.4s, v2.4s, v0.s[2]\n" + "fmla v7.4s, v2.4s, v0.s[3]\n" + + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "ld1 {v2.4s}, [%[b_ptr]], 16\n" + "fmla v4.4s, v3.4s, v1.s[0]\n" + "fmla v5.4s, v3.4s, v1.s[1]\n" + "fmla v6.4s, v3.4s, v1.s[2]\n" + "fmla v7.4s, v3.4s, v1.s[3]\n" + + "subs %w[K], %w[K], #1\n" + "bne 3b\n" + + "4:\n" + "cmp %w[oddk], #1\n" + "beq 5f\n" + + // Even tail + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "ld1 {v3.4s}, [%[b_ptr]], 16\n" + "fmla v4.4s, v2.4s, v0.s[0]\n" + "fmla v5.4s, v2.4s, v0.s[1]\n" + "fmla v6.4s, v2.4s, v0.s[2]\n" + "fmla v7.4s, v2.4s, v0.s[3]\n" + + "fmla v4.4s, v3.4s, v1.s[0]\n" + "fmla v5.4s, v3.4s, v1.s[1]\n" + "fmla v6.4s, v3.4s, v1.s[2]\n" + "fmla v7.4s, v3.4s, v1.s[3]\n" + + "b 6f\n" + + // odd tail + "5:\n" + "fmla v4.4s, v2.4s, v0.s[0]\n" + "fmla v5.4s, v2.4s, v0.s[1]\n" + "fmla v6.4s, v2.4s, v0.s[2]\n" + "fmla v7.4s, v2.4s, v0.s[3]\n" + + "6:\n" STORE_C + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), + [oddk] "+r"(oddk), [m_remain] "+r"(m_remain), + [n_remain] "+r"(n_remain), [outptr] "+r"(outptr) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "x1", + "x2", "x3", "x10", "cc", "memory"); +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +} + +void sgemm_4x16_pack_A_n(float * outptr, const float * inptr, int ldin, int y0, + int ymax, int k0, int kmax) { + float zerobuff[4]; + std::memset(zerobuff, 0, sizeof(float) * 4); + constexpr int PACK_SIZE = 4*4; + + int y = y0; + for (; y + 3 < ymax; y += 4) { + // 
printf("main loop pack_a_n %p \n",outptr); + const float* inptr0 = inptr + y * ldin + k0; + const float* inptr1 = inptr0 + ldin; + const float* inptr2 = inptr1 + ldin; + const float* inptr3 = inptr2 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + + int K = (kmax - k0); + for (; K > 3; K -= 4) { + transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr); + outptr += PACK_SIZE; + } + + interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 1, K); + } + + for (; y < ymax; y += 4) { + const float* inptr0 = inptr + y * ldin + k0; + const float* inptr1 = inptr0 + ldin; + const float* inptr2 = inptr1 + ldin; + const float* inptr3 = inptr2 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + + int K = (kmax - k0); + for (; K > 3; K -= 4) { + if ((y + 3) >= ymax) { + switch ((y + 3) - ymax) { + /* Everything falls through in here */ + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr); + outptr += PACK_SIZE; + } + + if (K > 0) { + if (y + 3 >= ymax) { + switch (y + 3 - ymax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 1, K); + } + } +} + +void sgemm_4x16_pack_A_t(float* out, const float* in, int ldin, int x0, + int xmax, int k0, int kmax) { + int ksize = kmax - k0; + int ksize4 = (ksize << 2); + float* outptr_base = out; + + int k = k0; + for (; k + 3 < kmax; k += 4) { + const float* inptr = in + k * ldin + x0; + const float* inptr1 = inptr + ldin; + const float* inptr2 = inptr1 + ldin; + const float* inptr3 = inptr2 + ldin; + + prefetch_3x(inptr); + prefetch_3x(inptr1); + prefetch_3x(inptr2); + prefetch_3x(inptr3); + + int x = x0; + auto outptr = outptr_base; + for (; x + 4 <= xmax; x += 4) { + auto outptr_interleave = outptr; + interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, + outptr_interleave); + outptr += ksize4; + } + + if (x < xmax) { + interleave_4(inptr, inptr1, inptr2, inptr3, outptr, 4, xmax - x); + } + + outptr_base += 4 * 4; + } + + for (; k < kmax; k++) { + const float* inptr = in + k * ldin + x0; + prefetch_3x(inptr); + int x = x0; + auto outptr = outptr_base; + for (; x + 4 <= xmax; x += 4) { + auto outptr_interleave = outptr; + interleave_1x4_1_s(inptr, outptr_interleave); + outptr += ksize4; + } + + if (x < xmax) { + interleave_1(inptr, outptr, 4, xmax - x); + } + + outptr_base += 4; + } +} + +void sgemm_4x16_pack_B_n(float* out, const float* in, int ldin, + int x0, int xmax, int k0, int kmax) { + int ksize = kmax - k0; + int ksize16 = ksize * 16; + int ksize4 = (ksize << 2); + float* outptr_base = out; + float* outptr_base4 = outptr_base + (xmax - x0) / 16 * ksize16; + + int k = k0; + for (; k + 3 < kmax; k += 4) { + const float* inptr = in + k * ldin + x0; + const float* inptr1 = inptr + ldin; + const float* inptr2 = inptr1 + ldin; + const float* inptr3 = inptr2 + ldin; + + prefetch_3x(inptr); + prefetch_3x(inptr1); + prefetch_3x(inptr2); + prefetch_3x(inptr3); + + int x = x0; + auto outptr = outptr_base; + for (; x + 16 <= xmax; x += 16) { + auto outptr_interleave = outptr; + interleave_4x16_1_s(inptr, inptr1, inptr2, inptr3, + outptr_interleave); + outptr += ksize16; + } + outptr = outptr_base4; + for (; x + 4 <= xmax; x += 4) { + auto outptr_interleave = outptr; + 
interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, + outptr_interleave); + outptr += ksize4; + } + + if (x < xmax) { + interleave_4(inptr, inptr1, inptr2, inptr3, outptr, 4, xmax - x); + } + + outptr_base += 16 * 4; + outptr_base4 += 4 * 4; + } + + for (; k < kmax; k++) { + const float* inptr = in + k * ldin + x0; + prefetch_3x(inptr); + int x = x0; + auto outptr = outptr_base; + for (; x + 16 <= xmax; x += 16) { + auto outptr_interleave = outptr; + interleave_1x16_1_s(inptr, outptr_interleave); + outptr += ksize16; + } + outptr = outptr_base4; + for (; x + 4 <= xmax; x += 4) { + auto outptr_interleave = outptr; + interleave_1x4_1_s(inptr, outptr_interleave); + outptr += ksize4; + } + + if (x < xmax) { + interleave_1(inptr, outptr, 4, xmax - x); + } + + outptr_base += 16; + outptr_base4 += 4; + } +} + +void sgemm_4x16_pack_B_t(float* out, const float* in, int ldin, + int y0, int ymax, int k0, int kmax) { + float* outptr = out; + const float* inptr = in; + float zerobuff[4]; + std::memset(zerobuff, 0, sizeof(float) * 4); + int K16 = 16 * (kmax - k0); + + int y = y0; + + for (; y + 16 <= ymax; y += 16) { + int yi = y; + for (; yi < y + 16; yi += 4) { + const float* inptr0 = inptr + yi * ldin + k0; + const float* inptr1 = inptr0 + ldin; + const float* inptr2 = inptr1 + ldin; + const float* inptr3 = inptr2 + ldin; + float* outptr_inner = outptr + yi - y; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + + int x = (kmax - k0); + for (; x > 3; x -= 4) { + transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr_inner, + 64); + outptr_inner += 64; + } + for (; x > 0; x--) { + *outptr_inner++ = *inptr0++; + *outptr_inner++ = *inptr1++; + *outptr_inner++ = *inptr2++; + *outptr_inner++ = *inptr3++; + outptr_inner += 12; + } + } + outptr += K16; + } + + for (; y < ymax; y += 4) { + const float* inptr0 = inptr + y * ldin + k0; + const float* inptr1 = inptr0 + ldin; + const float* inptr2 = inptr1 + ldin; + const float* inptr3 = inptr2 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + + /* Cope with ragged cases by copying from a buffer of zeroes instead + */ + int x = (kmax - k0); + for (; x > 3; x -= 4) { + if ((y + 3) >= ymax) { + switch ((y + 3) - ymax) { + /* Everything falls through in here */ + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr); + outptr += 16; + } + + if (x > 0) { + if ((y + 3) >= ymax) { + switch ((y + 3) - ymax) { + /* Everything falls through in here */ + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 1, x); + } + } +} + +} // matmul_general_4x16 +} // aarch64 +} // megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12.h b/dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12.h new file mode 100644 index 00000000..076c63b9 --- /dev/null +++ b/dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12.h @@ -0,0 +1,1242 @@ +/** + * \file dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "src/aarch64/matrix_mul/asm/common.h" +#include "src/arm_common/simd_macro/marm_neon.h" + + +namespace megdnn { +namespace aarch64 { +namespace matmul_general_8x12 { + +// Overview of register layout: +// +// A 1x12 cell of Rhs is stored in 32bit in v2-v7 +// A 8x1 cell of Lhs is stored in 32bit in (v0-v1) +// A 8x12 block of accumulators is stored in 32bit in v8-v31. +// +// +--------+--------+--------+ +// | v2[0-3]| v3[0-3]| v4[0-3]| +// | v5[0-3]| v6[0-3]| v7[0-3]| +// Rhs +--------+--------+--------+ +// +// | | | | +// +// Lhs | | | | +// +// +--+ --- - +--------+--------+--------+ +// |v0| | v8[0-3]| v9[0-3]|v10[0-3]| +// |v0| |v11[0-3]|v12[0-3]|v13[0-3]| +// |v0| |v14[0-3]|v15[0-3]|v16[0-3]| +// |v0| |v17[0-3]|v18[0-3]|v19[0-3]| +// |v1| |v20[0-3]|v21[0-3]|v22[0-3]| +// |v1| |v23[0-3]|v24[0-3]|v25[0-3]| +// |v1| |v26[0-3]|v27[0-3]|v28[0-3]| +// |v1| |v29[0-3]|v30[0-3]|v31[0-3]| +// +--+ --- - +--------+--------+--------+ +// +// Accumulator +void kern_8x12(const float* packA, const float* packB, int K, float* output, + int LDC, bool is_first_k) { + const float* a_ptr = packA; + const float* b_ptr = packB; + int oddk = (K & 1); + K = ((K + 1) / 2) - 1; + + LDC = LDC * sizeof(float); + register float* outptr asm("x0") = reinterpret_cast(output); + +// clang-format off +#define LOAD_LINE(v0, v1, v2, n) \ + "ld1 {v" v0 ".4s, v" v1 ".4s, v" v2 ".4s}, [x" n "]\n" \ + +#define LOAD_C \ + LOAD_LINE("8", "9", "10", "0") \ + LOAD_LINE("11", "12", "13", "1") \ + LOAD_LINE("14", "15", "16", "2") \ + LOAD_LINE("17", "18", "19", "3") \ + LOAD_LINE("20", "21", "22", "4") \ + LOAD_LINE("23", "24", "25", "5") \ + LOAD_LINE("26", "27", "28", "6") \ + LOAD_LINE("29", "30", "31", "7") \ + +#define STORE_LINE(v0, v1, v2, n) \ + "st1 {v" v0 ".4s, v" v1 ".4s, v" v2 ".4s}, [x" n "]\n" \ + +#define STORE_C \ + STORE_LINE("8", "9", "10", "0") \ + STORE_LINE("11", "12", "13", "1") \ + STORE_LINE("14", "15", "16", "2") \ + STORE_LINE("17", "18", "19", "3") \ + STORE_LINE("20", "21", "22", "4") \ + STORE_LINE("23", "24", "25", "5") \ + STORE_LINE("26", "27", "28", "6") \ + STORE_LINE("29", "30", "31", "7") \ + // clang-format on + + asm volatile( + // load accumulator C + "add x1, x0, %x[LDC]\n" + "add x2, x1, %x[LDC]\n" + "add x3, x2, %x[LDC]\n" + "add x4, x3, %x[LDC]\n" + "add x5, x4, %x[LDC]\n" + "add x6, x5, %x[LDC]\n" + "add x7, x6, %x[LDC]\n" + + "cmp %w[is_first_k], #1\n" + "beq 1f\n" LOAD_C + "ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], 48\n" + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "b 2f\n" + + "1:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "eor v9.16b, v9.16b, v9.16b\n" + "eor v10.16b, v10.16b, v10.16b\n" + "prfm pstl1keep, [x0]\n" + "eor v11.16b, v11.16b, v11.16b\n" + "eor v12.16b, v12.16b, v12.16b\n" + "eor v13.16b, v13.16b, v13.16b\n" + "prfm pstl1keep, [x1]\n" + "eor v14.16b, v14.16b, v14.16b\n" + "eor v15.16b, v15.16b, v15.16b\n" + "ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], 48\n" + "eor v16.16b, v16.16b, v16.16b\n" + "eor v17.16b, v17.16b, v17.16b\n" + "prfm pstl1keep, [x2]\n" + "eor v18.16b, v18.16b, v18.16b\n" + "eor v19.16b, v19.16b, v19.16b\n" + "eor v20.16b, v20.16b, v20.16b\n" + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "eor v21.16b, v21.16b, v21.16b\n" + "eor v22.16b, v22.16b, v22.16b\n" + "prfm pstl1keep, [x3]\n" + "eor v23.16b, v23.16b, v23.16b\n" + "eor v24.16b, v24.16b, 
v24.16b\n" + "eor v25.16b, v25.16b, v25.16b\n" + "prfm pstl1keep, [x4]\n" + "eor v26.16b, v26.16b, v26.16b\n" + "eor v27.16b, v27.16b, v27.16b\n" + "eor v28.16b, v28.16b, v28.16b\n" + "prfm pstl1keep, [x5]\n" + "eor v29.16b, v29.16b, v29.16b\n" + "eor v30.16b, v30.16b, v30.16b\n" + "eor v31.16b, v31.16b, v31.16b\n" + "prfm pstl1keep, [x6]\n" + + "2: \n" + "cmp %w[K], #0\n" + "beq 4f\n" + + "3:\n" + "fmla v8.4s, v2.4s, v0.s[0]\n" + "fmla v9.4s, v3.4s, v0.s[0]\n" + "fmla v10.4s, v4.4s, v0.s[0]\n" + "fmla v11.4s, v2.4s, v0.s[1]\n" + "fmla v12.4s, v3.4s, v0.s[1]\n" + "fmla v13.4s, v4.4s, v0.s[1]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v14.4s, v2.4s, v0.s[2]\n" + "fmla v15.4s, v3.4s, v0.s[2]\n" + "fmla v16.4s, v4.4s, v0.s[2]\n" + "fmla v17.4s, v2.4s, v0.s[3]\n" + "prfm pldl1keep, [%[a_ptr], #64]\n" + "fmla v18.4s, v3.4s, v0.s[3]\n" + "fmla v19.4s, v4.4s, v0.s[3]\n" + "fmla v20.4s, v2.4s, v1.s[0]\n" + "fmla v21.4s, v3.4s, v1.s[0]\n" + "fmla v22.4s, v4.4s, v1.s[0]\n" + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "fmla v23.4s, v2.4s, v1.s[1]\n" + "fmla v24.4s, v3.4s, v1.s[1]\n" + "fmla v25.4s, v4.4s, v1.s[1]\n" + "fmla v26.4s, v2.4s, v1.s[2]\n" + "ld1 {v5.4s, v6.4s, v7.4s}, [%[b_ptr]], 48\n" + "fmla v27.4s, v3.4s, v1.s[2]\n" + "fmla v28.4s, v4.4s, v1.s[2]\n" + "fmla v29.4s, v2.4s, v1.s[3]\n" + "prfm pldl1keep, [%[b_ptr], #64]\n" + "fmla v30.4s, v3.4s, v1.s[3]\n" + "fmla v31.4s, v4.4s, v1.s[3]\n" + + "fmla v8.4s, v5.4s, v0.s[0]\n" + "fmla v9.4s, v6.4s, v0.s[0]\n" + "fmla v10.4s, v7.4s, v0.s[0]\n" + "fmla v11.4s, v5.4s, v0.s[1]\n" + "fmla v12.4s, v6.4s, v0.s[1]\n" + "fmla v13.4s, v7.4s, v0.s[1]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v14.4s, v5.4s, v0.s[2]\n" + "fmla v15.4s, v6.4s, v0.s[2]\n" + "fmla v16.4s, v7.4s, v0.s[2]\n" + "fmla v17.4s, v5.4s, v0.s[3]\n" + "fmla v18.4s, v6.4s, v0.s[3]\n" + "fmla v19.4s, v7.4s, v0.s[3]\n" + "fmla v20.4s, v5.4s, v1.s[0]\n" + "fmla v21.4s, v6.4s, v1.s[0]\n" + "fmla v22.4s, v7.4s, v1.s[0]\n" + "fmla v23.4s, v5.4s, v1.s[1]\n" + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "fmla v24.4s, v6.4s, v1.s[1]\n" + "fmla v25.4s, v7.4s, v1.s[1]\n" + "fmla v26.4s, v5.4s, v1.s[2]\n" + "ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], 48\n" + "fmla v27.4s, v6.4s, v1.s[2]\n" + "fmla v28.4s, v7.4s, v1.s[2]\n" + "fmla v29.4s, v5.4s, v1.s[3]\n" + "fmla v30.4s, v6.4s, v1.s[3]\n" + "prfm pldl1keep, [%[b_ptr], #64]\n" + "fmla v31.4s, v7.4s, v1.s[3]\n" + + "subs %w[K], %w[K], #1\n" + "bne 3b\n" + + "4:\n" + "cmp %w[oddk], #1\n" + "beq 5f\n" + + // Even tail + "fmla v8.4s, v2.4s, v0.s[0]\n" + "fmla v9.4s, v3.4s, v0.s[0]\n" + "fmla v10.4s, v4.4s, v0.s[0]\n" + "fmla v11.4s, v2.4s, v0.s[1]\n" + "fmla v12.4s, v3.4s, v0.s[1]\n" + "fmla v13.4s, v4.4s, v0.s[1]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v14.4s, v2.4s, v0.s[2]\n" + "fmla v15.4s, v3.4s, v0.s[2]\n" + "fmla v16.4s, v4.4s, v0.s[2]\n" + "fmla v17.4s, v2.4s, v0.s[3]\n" + "fmla v18.4s, v3.4s, v0.s[3]\n" + "fmla v19.4s, v4.4s, v0.s[3]\n" + "fmla v20.4s, v2.4s, v1.s[0]\n" + "fmla v21.4s, v3.4s, v1.s[0]\n" + "fmla v22.4s, v4.4s, v1.s[0]\n" + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "fmla v23.4s, v2.4s, v1.s[1]\n" + "fmla v24.4s, v3.4s, v1.s[1]\n" + "fmla v25.4s, v4.4s, v1.s[1]\n" + "fmla v26.4s, v2.4s, v1.s[2]\n" + "ld1 {v5.4s, v6.4s, v7.4s}, [%[b_ptr]], 48\n" + "fmla v27.4s, v3.4s, v1.s[2]\n" + "fmla v28.4s, v4.4s, v1.s[2]\n" + "fmla v29.4s, v2.4s, v1.s[3]\n" + "fmla v30.4s, v3.4s, v1.s[3]\n" + "fmla v31.4s, v4.4s, v1.s[3]\n" + + "fmla v8.4s, v5.4s, v0.s[0]\n" + "fmla v9.4s, v6.4s, v0.s[0]\n" + "fmla v10.4s, v7.4s, v0.s[0]\n" + "fmla v11.4s, v5.4s, v0.s[1]\n" + 
"fmla v12.4s, v6.4s, v0.s[1]\n" + "fmla v13.4s, v7.4s, v0.s[1]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v14.4s, v5.4s, v0.s[2]\n" + "fmla v15.4s, v6.4s, v0.s[2]\n" + "fmla v16.4s, v7.4s, v0.s[2]\n" + "fmla v17.4s, v5.4s, v0.s[3]\n" + "st1 {v8.4s, v9.4s, v10.4s}, [x0]\n" + "fmla v18.4s, v6.4s, v0.s[3]\n" + "fmla v19.4s, v7.4s, v0.s[3]\n" + "fmla v20.4s, v5.4s, v1.s[0]\n" + "fmla v21.4s, v6.4s, v1.s[0]\n" + "st1 {v11.4s, v12.4s, v13.4s}, [x1]\n" + "fmla v22.4s, v7.4s, v1.s[0]\n" + "fmla v23.4s, v5.4s, v1.s[1]\n" + "fmla v24.4s, v6.4s, v1.s[1]\n" + "fmla v25.4s, v7.4s, v1.s[1]\n" + "st1 {v14.4s, v15.4s, v16.4s}, [x2]\n" + "fmla v26.4s, v5.4s, v1.s[2]\n" + "fmla v27.4s, v6.4s, v1.s[2]\n" + "fmla v28.4s, v7.4s, v1.s[2]\n" + "fmla v29.4s, v5.4s, v1.s[3]\n" + "fmla v30.4s, v6.4s, v1.s[3]\n" + "fmla v31.4s, v7.4s, v1.s[3]\n" + "st1 {v17.4s, v18.4s, v19.4s}, [x3]\n" + "st1 {v20.4s, v21.4s, v22.4s}, [x4]\n" + "st1 {v23.4s, v24.4s, v25.4s}, [x5]\n" + "st1 {v26.4s, v27.4s, v28.4s}, [x6]\n" + "st1 {v29.4s, v30.4s, v31.4s}, [x7]\n" + "b 6f\n" + + // odd tail + "5:\n" + "fmla v8.4s, v2.4s, v0.s[0]\n" + "fmla v9.4s, v3.4s, v0.s[0]\n" + "fmla v10.4s, v4.4s, v0.s[0]\n" + "fmla v11.4s, v2.4s, v0.s[1]\n" + "fmla v12.4s, v3.4s, v0.s[1]\n" + "fmla v13.4s, v4.4s, v0.s[1]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v14.4s, v2.4s, v0.s[2]\n" + "fmla v15.4s, v3.4s, v0.s[2]\n" + "fmla v16.4s, v4.4s, v0.s[2]\n" + "st1 {v8.4s, v9.4s, v10.4s}, [x0]\n" + "fmla v17.4s, v2.4s, v0.s[3]\n" + "fmla v18.4s, v3.4s, v0.s[3]\n" + "fmla v19.4s, v4.4s, v0.s[3]\n" + "fmla v20.4s, v2.4s, v1.s[0]\n" + "st1 {v11.4s, v12.4s, v13.4s}, [x1]\n" + "fmla v21.4s, v3.4s, v1.s[0]\n" + "fmla v22.4s, v4.4s, v1.s[0]\n" + "fmla v23.4s, v2.4s, v1.s[1]\n" + "fmla v24.4s, v3.4s, v1.s[1]\n" + "st1 {v14.4s, v15.4s, v16.4s}, [x2]\n" + "fmla v25.4s, v4.4s, v1.s[1]\n" + "fmla v26.4s, v2.4s, v1.s[2]\n" + "fmla v27.4s, v3.4s, v1.s[2]\n" + "fmla v28.4s, v4.4s, v1.s[2]\n" + "fmla v29.4s, v2.4s, v1.s[3]\n" + "st1 {v17.4s, v18.4s, v19.4s}, [x3]\n" + "fmla v30.4s, v3.4s, v1.s[3]\n" + "fmla v31.4s, v4.4s, v1.s[3]\n" + "st1 {v20.4s, v21.4s, v22.4s}, [x4]\n" + "st1 {v23.4s, v24.4s, v25.4s}, [x5]\n" + "st1 {v26.4s, v27.4s, v28.4s}, [x6]\n" + "st1 {v29.4s, v30.4s, v31.4s}, [x7]\n" + + "6:\n" + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), + [outptr] "+r"(outptr) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", + "v29", "v30", "v31", "x1", "x2", "x3", "x4", "x5", "x6", "x7", + "cc", "memory"); + +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +} + +// Overview of register layout: +// +// A 1x12 cell of Rhs is stored in 32bit in v2-v7 +// A 8x1 cell of Lhs is stored in 32bit in (v0-v1) +// A 8x12 block of accumulators is stored in 32bit in v8-v31. 
+// +// +--------+ +// | v2[0-3]| +// | v5[0-3]| +// Rhs +--------+ +// +// | | +// +// Lhs | | +// +// +--+ --- - +--------+ +// |v0| | v8[0-3]| +// |v0| |v11[0-3]| +// |v0| |v14[0-3]| +// |v0| |v17[0-3]| +// |v1| |v20[0-3]| +// |v1| |v23[0-3]| +// |v1| |v26[0-3]| +// |v1| |v29[0-3]| +// +--+ --- - +--------+ +// +// Accumulator +void kern_8x4(const float* packA, const float* packB, int K, float* output, + int LDC, bool is_first_k, int n_remain) { + const float* a_ptr = packA; + const float* b_ptr = packB; + int oddk = (K & 1); + K = ((K + 1) / 2) - 1; + + LDC = LDC * sizeof(float); + register float* outptr asm("x0") = reinterpret_cast(output); + +// clang-format off +#define LOAD_LINE(v0, n) \ + "cmp %w[n_remain], #4\n" \ + "blt 100" n "f\n" \ + "ld1 {v" v0 ".4s}, [x" n "],#16\n" \ + "b 101" n "f\n" \ + "100" n ":\n" \ + "cmp %w[n_remain], #0\n" \ + "beq 101" n "f\n" \ + "ld1 {v" v0 ".s}[0], [x" n "],#4\n" \ + "cmp %w[n_remain], #1\n" \ + "beq 101" n "f\n" \ + "ld1 {v" v0 ".s}[1], [x" n "],#4\n" \ + "cmp %w[n_remain], #2\n" \ + "beq 101" n "f\n" \ + "ld1 {v" v0 ".s}[2], [x" n "],#4\n" \ + "101" n ":\n" \ + +#define LOAD_C \ + LOAD_LINE("8", "0") \ + LOAD_LINE("11", "1") \ + LOAD_LINE("14", "2") \ + LOAD_LINE("17", "3") \ + LOAD_LINE("20", "4") \ + LOAD_LINE("23", "5") \ + LOAD_LINE("26", "6") \ + LOAD_LINE("29", "7") \ + + +#define STORE_LINE(v0, n) \ + "cmp %w[n_remain], #4\n" \ + "blt 103" n "f\n" \ + "st1 {v" v0 ".4s}, [x" n " ],#16\n" \ + "b 104" n "f\n" \ + "103" n ":\n" \ + "cmp %w[n_remain], #0\n" \ + "beq 104" n "f\n" \ + "st1 {v" v0 ".s}[0], [x" n "],#4\n" \ + "cmp %w[n_remain], #1\n" \ + "beq 104" n "f\n" \ + "st1 {v" v0 ".s}[1], [x" n "],#4\n" \ + "cmp %w[n_remain], #2\n" \ + "beq 104" n "f\n" \ + "st1 {v" v0 ".s}[2], [x" n "],#4\n" \ + "104" n ":\n" \ + + +#define STORE_C \ + STORE_LINE("8", "0") \ + STORE_LINE("11", "1") \ + STORE_LINE("14", "2") \ + STORE_LINE("17", "3") \ + STORE_LINE("20", "4") \ + STORE_LINE("23", "5") \ + STORE_LINE("26", "6") \ + STORE_LINE("29", "7") \ + // clang-format on + + asm volatile( + // load accumulator C + "add x1, x0, %x[LDC]\n" + "add x2, x1, %x[LDC]\n" + "add x3, x2, %x[LDC]\n" + "add x4, x3, %x[LDC]\n" + "add x5, x4, %x[LDC]\n" + "add x6, x5, %x[LDC]\n" + "add x7, x6, %x[LDC]\n" + + "cmp %w[is_first_k], #1\n" + "beq 1f\n" LOAD_C + + "b 2f\n" + + "1:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "eor v11.16b, v11.16b, v11.16b\n" + "eor v14.16b, v14.16b, v14.16b\n" + "eor v17.16b, v17.16b, v17.16b\n" + "eor v20.16b, v20.16b, v20.16b\n" + "eor v23.16b, v23.16b, v23.16b\n" + "eor v26.16b, v26.16b, v26.16b\n" + "eor v29.16b, v29.16b, v29.16b\n" + + "2: \n" + "ld1 {v2.4s}, [%[b_ptr]], 16\n" + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "cmp %w[K], #0\n" + "beq 4f\n" + + "3:\n" + "fmla v8.4s, v2.4s, v0.s[0]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v11.4s, v2.4s, v0.s[1]\n" + "fmla v14.4s, v2.4s, v0.s[2]\n" + "fmla v17.4s, v2.4s, v0.s[3]\n" + "ld1 {v5.4s}, [%[b_ptr]], 16\n" + "fmla v20.4s, v2.4s, v1.s[0]\n" + "fmla v23.4s, v2.4s, v1.s[1]\n" + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "fmla v26.4s, v2.4s, v1.s[2]\n" + "fmla v29.4s, v2.4s, v1.s[3]\n" + + "fmla v8.4s, v5.4s, v0.s[0]\n" + "fmla v11.4s, v5.4s, v0.s[1]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v14.4s, v5.4s, v0.s[2]\n" + "fmla v17.4s, v5.4s, v0.s[3]\n" + "fmla v20.4s, v5.4s, v1.s[0]\n" + "fmla v23.4s, v5.4s, v1.s[1]\n" + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "fmla v26.4s, v5.4s, v1.s[2]\n" + "ld1 {v2.4s}, [%[b_ptr]], 16\n" + "fmla v29.4s, v5.4s, v1.s[3]\n" + + "subs %w[K], %w[K], #1\n" + "bne 3b\n" + + 
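+            // End of the main unrolled loop: the code below runs either the
+            // even tail (one final pair of K iterations) or, when oddk is
+            // set, the single-iteration odd tail at label 5.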
"4:\n" + "cmp %w[oddk], #1\n" + "beq 5f\n" + + // Even tail + "fmla v8.4s, v2.4s, v0.s[0]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v11.4s, v2.4s, v0.s[1]\n" + "fmla v14.4s, v2.4s, v0.s[2]\n" + "fmla v17.4s, v2.4s, v0.s[3]\n" + "ld1 {v5.4s}, [%[b_ptr]], 16\n" + "fmla v20.4s, v2.4s, v1.s[0]\n" + "fmla v23.4s, v2.4s, v1.s[1]\n" + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "fmla v26.4s, v2.4s, v1.s[2]\n" + "fmla v29.4s, v2.4s, v1.s[3]\n" + + "fmla v8.4s, v5.4s, v0.s[0]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v11.4s, v5.4s, v0.s[1]\n" + "fmla v14.4s, v5.4s, v0.s[2]\n" + "fmla v17.4s, v5.4s, v0.s[3]\n" + "fmla v20.4s, v5.4s, v1.s[0]\n" + "fmla v23.4s, v5.4s, v1.s[1]\n" + "fmla v26.4s, v5.4s, v1.s[2]\n" + "fmla v29.4s, v5.4s, v1.s[3]\n" + + "b 6f\n" + + // odd tail + "5:\n" + "fmla v8.4s, v2.4s, v0.s[0]\n" + "fmla v11.4s, v2.4s, v0.s[1]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v14.4s, v2.4s, v0.s[2]\n" + "fmla v17.4s, v2.4s, v0.s[3]\n" + "fmla v20.4s, v2.4s, v1.s[0]\n" + "fmla v23.4s, v2.4s, v1.s[1]\n" + "fmla v26.4s, v2.4s, v1.s[2]\n" + "fmla v29.4s, v2.4s, v1.s[3]\n" + + "6:\n" STORE_C + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), + [outptr] "+r"(outptr), [n_remain] "+r"(n_remain) + : + : "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "v20", "v23", + "v26", "v29", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "cc", + "memory"); + +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +} + + +// Overview of register layout: +// +// A 1x12 cell of Rhs is stored in 32bit in v2-v7 +// A 8x1 cell of Lhs is stored in 32bit in (v0-v1) +// A 8x12 block of accumulators is stored in 32bit in v8-v31. +// +// +--------+--------+--------+ +// | v2[0-3]| v3[0-3]| v4[0-3]| +// | v5[0-3]| v6[0-3]| v7[0-3]| +// Rhs +--------+--------+--------+ +// +// | | | | +// +// Lhs | | | | +// +// +--+ --- - +--------+--------+--------+ +// |v0| | v8[0-3]| v9[0-3]|v10[0-3]| +// |v0| |v11[0-3]|v12[0-3]|v13[0-3]| +// |v0| |v14[0-3]|v15[0-3]|v16[0-3]| +// |v0| |v17[0-3]|v18[0-3]|v19[0-3]| +// +--+ --- - +--------+--------+--------+ +// +// Accumulator +void kern_4x12(const float* packA, const float* packB, int K, float* output, + int LDC, bool is_first_k, int m_remain) { + const float* a_ptr = packA; + const float* b_ptr = packB; + int oddk = (K & 1); + K = ((K + 1) / 2) - 1; + + LDC = LDC * sizeof(float); + register float* outptr asm("x0") = output; + +// clang-format off +#define LOAD_LINE(v0, v1, v2, n) \ + "cmp x10, #0\n" \ + "beq 102f\n" \ + "ld1 {v" v0 ".4s, v" v1 ".4s, v" v2 ".4s}, [x" n "]\n" \ + "subs x10, x10, #1\n" + +#define LOAD_C \ + "mov x10, %x[m_remain]\n" \ + LOAD_LINE("8","9","10", "0") \ + LOAD_LINE("11","12","13", "1") \ + LOAD_LINE("14","15","16", "2") \ + LOAD_LINE("17","18","19", "3") \ + "102:\n" + +#define STORE_LINE(v0, v1, v2, n) \ + "cmp x10, #0 \n" \ + "beq 105f\n" \ + "st1 {v" v0 ".4s, v" v1 ".4s, v" v2 ".4s}, [x" n "]\n" \ + "subs x10, x10, #1\n" + + +#define STORE_C \ + "mov x10, %x[m_remain]\n" \ + STORE_LINE("8","9","10", "0") \ + STORE_LINE("11","12","13", "1") \ + STORE_LINE("14","15","16", "2") \ + STORE_LINE("17","18","19", "3") \ + "105:\n" + // clang-format on + + asm volatile( + // load accumulator C + "add x1, x0, %x[LDC]\n" + "add x2, x1, %x[LDC]\n" + "add x3, x2, %x[LDC]\n" + + "cmp %w[is_first_k], #1\n" + "beq 1f\n" LOAD_C + + "b 2f\n" + + "1:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "eor v9.16b, v9.16b, v9.16b\n" + "eor v10.16b, v10.16b, v10.16b\n" + "eor v11.16b, v11.16b, 
v11.16b\n" + "eor v12.16b, v12.16b, v12.16b\n" + "eor v13.16b, v13.16b, v13.16b\n" + "eor v14.16b, v14.16b, v14.16b\n" + "eor v15.16b, v15.16b, v15.16b\n" + "eor v16.16b, v16.16b, v16.16b\n" + "eor v17.16b, v17.16b, v17.16b\n" + "eor v18.16b, v18.16b, v18.16b\n" + "eor v19.16b, v19.16b, v19.16b\n" + + "2: \n" + "ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], 48\n" + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "cmp %w[K], #0\n" + "beq 4f\n" + + "3:\n" + "fmla v8.4s, v2.4s, v0.s[0]\n" + "fmla v9.4s, v3.4s, v0.s[0]\n" + "fmla v10.4s, v4.4s, v0.s[0]\n" + "fmla v11.4s, v2.4s, v0.s[1]\n" + "ld1 {v5.4s, v6.4s, v7.4s}, [%[b_ptr]], 48\n" + "fmla v12.4s, v3.4s, v0.s[1]\n" + "fmla v13.4s, v4.4s, v0.s[1]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v14.4s, v2.4s, v0.s[2]\n" + "fmla v15.4s, v3.4s, v0.s[2]\n" + "fmla v16.4s, v4.4s, v0.s[2]\n" + "fmla v17.4s, v2.4s, v0.s[3]\n" + "fmla v18.4s, v3.4s, v0.s[3]\n" + "fmla v19.4s, v4.4s, v0.s[3]\n" + + "fmla v8.4s, v5.4s, v1.s[0]\n" + "fmla v9.4s, v6.4s, v1.s[0]\n" + "fmla v10.4s, v7.4s, v1.s[0]\n" + "fmla v11.4s, v5.4s, v1.s[1]\n" + "ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], 48\n" + "fmla v12.4s, v6.4s, v1.s[1]\n" + "fmla v13.4s, v7.4s, v1.s[1]\n" + "fmla v14.4s, v5.4s, v1.s[2]\n" + "fmla v15.4s, v6.4s, v1.s[2]\n" + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "fmla v16.4s, v7.4s, v1.s[2]\n" + "fmla v17.4s, v5.4s, v1.s[3]\n" + "fmla v18.4s, v6.4s, v1.s[3]\n" + "fmla v19.4s, v7.4s, v1.s[3]\n" + + "subs %w[K], %w[K], #1\n" + "bne 3b\n" + + "4:\n" + "cmp %w[oddk], #1\n" + "beq 5f\n" + + // Even tail + "fmla v8.4s, v2.4s, v0.s[0]\n" + "fmla v9.4s, v3.4s, v0.s[0]\n" + "fmla v10.4s, v4.4s, v0.s[0]\n" + "fmla v11.4s, v2.4s, v0.s[1]\n" + "ld1 {v5.4s, v6.4s, v7.4s}, [%[b_ptr]], 48\n" + "fmla v12.4s, v3.4s, v0.s[1]\n" + "fmla v13.4s, v4.4s, v0.s[1]\n" + "fmla v14.4s, v2.4s, v0.s[2]\n" + "fmla v15.4s, v3.4s, v0.s[2]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v16.4s, v4.4s, v0.s[2]\n" + "fmla v17.4s, v2.4s, v0.s[3]\n" + "fmla v18.4s, v3.4s, v0.s[3]\n" + "fmla v19.4s, v4.4s, v0.s[3]\n" + + "fmla v8.4s, v5.4s, v1.s[0]\n" + "fmla v9.4s, v6.4s, v1.s[0]\n" + "fmla v10.4s, v7.4s, v1.s[0]\n" + "fmla v11.4s, v5.4s, v1.s[1]\n" + "fmla v12.4s, v6.4s, v1.s[1]\n" + "fmla v13.4s, v7.4s, v1.s[1]\n" + "fmla v14.4s, v5.4s, v1.s[2]\n" + "fmla v15.4s, v6.4s, v1.s[2]\n" + "fmla v16.4s, v7.4s, v1.s[2]\n" + "fmla v17.4s, v5.4s, v1.s[3]\n" + "fmla v18.4s, v6.4s, v1.s[3]\n" + "fmla v19.4s, v7.4s, v1.s[3]\n" + + "b 6f\n" + + // odd tail + "5:\n" + "fmla v8.4s, v2.4s, v0.s[0]\n" + "fmla v9.4s, v3.4s, v0.s[0]\n" + "fmla v10.4s, v4.4s, v0.s[0]\n" + "fmla v11.4s, v2.4s, v0.s[1]\n" + "fmla v12.4s, v3.4s, v0.s[1]\n" + "fmla v13.4s, v4.4s, v0.s[1]\n" + "fmla v14.4s, v2.4s, v0.s[2]\n" + "fmla v15.4s, v3.4s, v0.s[2]\n" + "fmla v16.4s, v4.4s, v0.s[2]\n" + "fmla v17.4s, v2.4s, v0.s[3]\n" + "fmla v18.4s, v3.4s, v0.s[3]\n" + "fmla v19.4s, v4.4s, v0.s[3]\n" + + "6:\n" STORE_C + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), + [outptr] "+r"(outptr), [m_remain] "+r"(m_remain) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "x1", "x2", "x3", "x10", "cc", "memory"); + +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +} + + +// Overview of register layout: +// +// A 2x4 cell of Rhs is stored in 32bit in v2 - v3 +// A 4x2 cell of Lhs is stored in 32bit in v0 - v1 +// A 4x4 block of accumulators is stored in 32bit in v4-v6 +// +// +--------+ +// | 
v2[0-3]| +// | v5[0-3]| +// Rhs +--------+ +// +// | | +// +// Lhs | | +// +// +--+ --- - +--------+ +// |v0| | v8[0-3]| +// |v0| |v11[0-3]| +// |v0| |v14[0-3]| +// |v0| |v17[0-3]| +// +--+ --- - +--------+ +// +// Accumulator +void kern_4x4(const float* packA, const float* packB, int K, float* output, + int LDC, bool is_first_k, int m_remain, int n_remain) { + const float* a_ptr = packA; + const float* b_ptr = packB; + int oddk = (K & 1); + K = ((K + 1) / 2) - 1; + + LDC = LDC * sizeof(float); + register float* outptr asm("x0") = output; + +// clang-format off +#define LOAD_LINE(v0, n) \ + "cmp x10, #0\n" \ + "beq 102f\n" \ + "cmp %w[n_remain], #4\n" \ + "blt 100" n "f\n" \ + "ld1 {v" v0 ".4s}, [x" n "], 16\n" \ + "b 101" n "f\n" \ + "100" n ":\n" \ + "cmp %w[n_remain], #0\n" \ + "beq 101" n "f\n" \ + "ld1 {v" v0 ".s}[0], [x" n "], 4\n" \ + "cmp %w[n_remain], #1\n" \ + "beq 101" n "f\n" \ + "ld1 {v" v0 ".s}[1], [x" n "], 4\n" \ + "cmp %w[n_remain], #2\n" \ + "beq 101" n "f\n" \ + "ld1 {v" v0 ".s}[2], [x" n "], 4\n" \ + "101" n ":\n" \ + "subs x10, x10, #1\n" + +#define LOAD_C \ + "mov x10, %x[m_remain]\n" \ + LOAD_LINE("8", "0") \ + LOAD_LINE("11", "1") \ + LOAD_LINE("14", "2") \ + LOAD_LINE("17", "3") \ + "102:\n" + +#define STORE_LINE(v0, n) \ + "cmp x10, #0 \n" \ + "beq 105f\n" \ + "cmp %w[n_remain], #4\n" \ + "blt 103" n "f\n" \ + "st1 {v" v0 ".4s}, [x" n " ], 16\n" \ + "b 104" n "f\n" \ + "103" n ":\n" \ + "cmp %w[n_remain], #0\n" \ + "beq 104" n "f\n" \ + "st1 {v" v0 ".s}[0], [x" n "], 4\n" \ + "cmp %w[n_remain], #1\n" \ + "beq 104" n "f\n" \ + "st1 {v" v0 ".s}[1], [x" n "], 4\n" \ + "cmp %w[n_remain], #2\n" \ + "beq 104" n "f\n" \ + "st1 {v" v0 ".s}[2], [x" n "], 4\n" \ + "104" n ":\n" \ + "subs x10, x10, #1\n" + + +#define STORE_C \ + "mov x10, %x[m_remain]\n" \ + STORE_LINE("8", "0") \ + STORE_LINE("11", "1") \ + STORE_LINE("14", "2") \ + STORE_LINE("17", "3") \ + "105:\n" + // clang-format on + + asm volatile( + // load accumulator C + "add x1, x0, %x[LDC]\n" + "add x2, x1, %x[LDC]\n" + "add x3, x2, %x[LDC]\n" + + "cmp %w[is_first_k], #1\n" + "beq 1f\n" LOAD_C + + "b 2f\n" + + "1:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "eor v11.16b, v11.16b, v11.16b\n" + "eor v14.16b, v14.16b, v14.16b\n" + "eor v17.16b, v17.16b, v17.16b\n" + + "2: \n" + "ld1 {v2.4s}, [%[b_ptr]], 16\n" + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "cmp %w[K], #0\n" + "beq 4f\n" + + "3:\n" + "ld1 {v5.4s}, [%[b_ptr]], 16\n" + "fmla v8.4s, v2.4s, v0.s[0]\n" + "fmla v11.4s, v2.4s, v0.s[1]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v14.4s, v2.4s, v0.s[2]\n" + "fmla v17.4s, v2.4s, v0.s[3]\n" + + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "fmla v8.4s, v5.4s, v1.s[0]\n" + "fmla v11.4s, v5.4s, v1.s[1]\n" + "ld1 {v2.4s}, [%[b_ptr]], 16\n" + "fmla v14.4s, v5.4s, v1.s[2]\n" + "fmla v17.4s, v5.4s, v1.s[3]\n" + + "subs %w[K], %w[K], #1\n" + "bne 3b\n" + + "4:\n" + "cmp %w[oddk], #1\n" + "beq 5f\n" + + // Even tail + "ld1 {v5.4s}, [%[b_ptr]], 16\n" + "fmla v8.4s, v2.4s, v0.s[0]\n" + "fmla v11.4s, v2.4s, v0.s[1]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v14.4s, v2.4s, v0.s[2]\n" + "fmla v17.4s, v2.4s, v0.s[3]\n" + + "fmla v8.4s, v5.4s, v1.s[0]\n" + "fmla v11.4s, v5.4s, v1.s[1]\n" + "fmla v14.4s, v5.4s, v1.s[2]\n" + "fmla v17.4s, v5.4s, v1.s[3]\n" + + "b 6f\n" + + // odd tail + "5:\n" + "fmla v8.4s, v2.4s, v0.s[0]\n" + "fmla v11.4s, v2.4s, v0.s[1]\n" + "fmla v14.4s, v2.4s, v0.s[2]\n" + "fmla v17.4s, v2.4s, v0.s[3]\n" + "fmla v29.4s, v2.4s, v1.s[3]\n" + + "6:\n" STORE_C + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [LDC] 
"+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), + [outptr] "+r"(outptr), [n_remain] "+r"(n_remain), + [m_remain] "+r"(m_remain) + : + : "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "x1", "x2", + "x3", "x10", "cc", "memory"); +} + +void sgemm_8x12_pack_A_n(float* outptr, const float* inptr, int ldin, int y0, + int ymax, int k0, int kmax) { + float zerobuff[8]; + std::memset(zerobuff, 0, sizeof(float) * 8); + constexpr int PACK_SIZE_32 = 4*8; + constexpr int PACK_SIZE_16 = 4*4; + int y = y0; + for (; y + 7 < ymax; y += 8) { + const float* inptr0 = inptr + y * ldin + k0; + const float* inptr1 = inptr0 + ldin; + const float* inptr2 = inptr1 + ldin; + const float* inptr3 = inptr2 + ldin; + const float* inptr4 = inptr3 + ldin; + const float* inptr5 = inptr4 + ldin; + const float* inptr6 = inptr5 + ldin; + const float* inptr7 = inptr6 + ldin; + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + prefetch_2x(inptr4); + prefetch_2x(inptr5); + prefetch_2x(inptr6); + prefetch_2x(inptr7); + int x = (kmax - k0); + for (; x > 3; x -= 4) { + transpose_8x4_1_s(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, + inptr6, inptr7, outptr); + outptr += PACK_SIZE_32; + } + for (; x > 0; x--) { + *outptr++ = *inptr0++; + *outptr++ = *inptr1++; + *outptr++ = *inptr2++; + *outptr++ = *inptr3++; + *outptr++ = *inptr4++; + *outptr++ = *inptr5++; + *outptr++ = *inptr6++; + *outptr++ = *inptr7++; + } + } + + for (; y < ymax; y += 4) { + const float* inptr0 = inptr + y * ldin + k0; + const float* inptr1 = inptr0 + ldin; + const float* inptr2 = inptr1 + ldin; + const float* inptr3 = inptr2 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + + int K = (kmax - k0); + for (; K > 3; K -= 4) { + if ((y + 3) >= ymax) { + switch ((y + 3) - ymax) { + /* Everything falls through in here */ + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr); + outptr += PACK_SIZE_16; + } + + if (K > 0) { + if (y + 3 >= ymax) { + switch (y + 3 - ymax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 1, K); + } + } +} + +void sgemm_8x12_pack_A_t(float* out, const float* in, int ldin, int x0, + int xmax, int k0, int kmax) { + int ksize = kmax - k0; + int ksize8 = (ksize << 3); + int ksize4 = (ksize << 2); + float* outptr_base = out; + float* outptr_base4 = outptr_base + (xmax - x0) / 8 * ksize8; + + int k = k0; + for (; k + 3 < kmax; k += 4) { + const float* inptr = in + k * ldin + x0; + const float* inptr1 = inptr + ldin; + const float* inptr2 = inptr1 + ldin; + const float* inptr3 = inptr2 + ldin; + + prefetch_3x(inptr); + prefetch_3x(inptr1); + prefetch_3x(inptr2); + prefetch_3x(inptr3); + + int x = x0; + auto outptr = outptr_base; + for (; x + 8 <= xmax; x += 8) { + auto outptr_interleave = outptr; + interleave_4x8_1_s(inptr, inptr1, inptr2, inptr3, + outptr_interleave); + outptr += ksize8; + } + outptr = outptr_base4; + for (; x + 4 <= xmax; x += 4) { + auto outptr_interleave = outptr; + interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, + outptr_interleave); + outptr += ksize4; + } + if (x < xmax) { + interleave_4(inptr, inptr1, inptr2, inptr3, outptr, 4, xmax - x); + } + outptr_base += 4 * 8; + outptr_base4 += 4 * 4; + } + + for (; k < kmax; 
k++) { + const float* inptr = in + k * ldin + x0; + prefetch_3x(inptr); + int x = x0; + auto outptr = outptr_base; + for (; x + 8 <= xmax; x += 8) { + auto outptr_interleave = outptr; + interleave_1x8_1_s(inptr, outptr_interleave); + outptr += ksize8; + } + outptr = outptr_base4; + for (; x + 4 <= xmax; x += 4) { + auto outptr_interleave = outptr; + interleave_1x4_1_s(inptr, outptr_interleave); + outptr += ksize4; + } + if (x < xmax) { + interleave_1(inptr, outptr, 4, xmax - x); + } + outptr_base += 8; + outptr_base4 += 4; + } +} + +void sgemm_8x12_pack_B_n(float* out, const float* in, int ldin, int x0, + int xmax, int k0, int kmax) { + int ksize = kmax - k0; + int ksize12 = ksize * 12; + int ksize4 = (ksize << 2); + float* outptr_base = out; + float* outptr_base4 = outptr_base + (xmax - x0) / 12 * ksize12; + + int k = k0; + for (; k + 3 < kmax; k += 4) { + const float* inptr = in + k * ldin + x0; + const float* inptr1 = inptr + ldin; + const float* inptr2 = inptr1 + ldin; + const float* inptr3 = inptr2 + ldin; + + prefetch_3x(inptr); + prefetch_3x(inptr1); + prefetch_3x(inptr2); + prefetch_3x(inptr3); + + int x = x0; + auto outptr = outptr_base; + for (; x + 12 <= xmax; x += 12) { + auto outptr_interleave = outptr; + interleave_4x12_1_s(inptr, inptr1, inptr2, inptr3, + outptr_interleave); + outptr += ksize12; + } + outptr = outptr_base4; + for (; x + 4 <= xmax; x += 4) { + auto outptr_interleave = outptr; + interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, + outptr_interleave); + outptr += ksize4; + } + if (x < xmax) { + interleave_4(inptr, inptr1, inptr2, inptr3, outptr, 4, xmax - x); + } + outptr_base += 12 * 4; + outptr_base4 += 4 * 4; + } + + for (; k < kmax; k++) { + const float* inptr = in + k * ldin + x0; + prefetch_3x(inptr); + int x = x0; + auto outptr = outptr_base; + for (; x + 12 <= xmax; x += 12) { + auto outptr_interleave = outptr; + interleave_1x12_1_s(inptr, outptr_interleave); + outptr += ksize12; + } + outptr = outptr_base4; + for (; x + 4 <= xmax; x += 4) { + auto outptr_interleave = outptr; + interleave_1x4_1_s(inptr, outptr_interleave); + outptr += ksize4; + } + if (x < xmax) { + interleave_1(inptr, outptr, 4, xmax - x); + } + outptr_base += 12; + outptr_base4 += 4; + } +} + +void sgemm_8x12_pack_B_t(float* out, const float* in, int ldin, + int y0, int ymax, int k0, int kmax) { + float* outptr = out; + const float* inptr = in; + float zerobuff[12]; + std::memset(zerobuff, 0, sizeof(float) * 12); + int y = y0; + for (; y + 12 <= ymax; y += 12) { + const float* inptr0 = inptr + y * ldin + k0; + const float* inptr1 = inptr0 + ldin; + const float* inptr2 = inptr1 + ldin; + const float* inptr3 = inptr2 + ldin; + const float* inptr4 = inptr3 + ldin; + const float* inptr5 = inptr4 + ldin; + const float* inptr6 = inptr5 + ldin; + const float* inptr7 = inptr6 + ldin; + const float* inptr8 = inptr7 + ldin; + const float* inptr9 = inptr8 + ldin; + const float* inptr10 = inptr9 + ldin; + const float* inptr11 = inptr10 + ldin; + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + prefetch_2x(inptr4); + prefetch_2x(inptr5); + prefetch_2x(inptr6); + prefetch_2x(inptr7); + prefetch_2x(inptr8); + prefetch_2x(inptr9); + prefetch_2x(inptr10); + prefetch_2x(inptr11); + int x = (kmax - k0); + for (; x > 3; x -= 4) { + transpose_12x4_1_s(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, + inptr6, inptr7, inptr8, inptr9, inptr10, inptr11, + outptr); + outptr += 48; + } + for (; x > 0; x--) { + *outptr++ = *inptr0++; + *outptr++ = *inptr1++; + *outptr++ 
= *inptr2++; + *outptr++ = *inptr3++; + *outptr++ = *inptr4++; + *outptr++ = *inptr5++; + *outptr++ = *inptr6++; + *outptr++ = *inptr7++; + *outptr++ = *inptr8++; + *outptr++ = *inptr9++; + *outptr++ = *inptr10++; + *outptr++ = *inptr11++; + } + } + + for (; y < ymax; y += 4) { + const float* inptr0 = inptr + y * ldin + k0; + const float* inptr1 = inptr0 + ldin; + const float* inptr2 = inptr1 + ldin; + const float* inptr3 = inptr2 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + + /* Cope with ragged cases by copying from a buffer of zeroes instead + */ + int x = (kmax - k0); + for (; x > 3; x -= 4) { + if ((y + 3) >= ymax) { + switch ((y + 3) - ymax) { + /* Everything falls through in here */ + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr); + outptr += 16; + } + + if (x > 0) { + if ((y + 3) >= ymax) { + switch ((y + 3) - ymax) { + /* Everything falls through in here */ + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 1, x); + } + } +} + +} // matmul_general_4x16 +} // aarch64 +} // megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/fp32/strategy.cpp b/dnn/src/aarch64/matrix_mul/fp32/strategy.cpp new file mode 100644 index 00000000..b180d8c7 --- /dev/null +++ b/dnn/src/aarch64/matrix_mul/fp32/strategy.cpp @@ -0,0 +1,166 @@ +/** + * \file dnn/src/aarch64/matrix_mul/fp32/strategy.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#include "src/aarch64/matrix_mul/fp32/strategy.h" +#include "src/aarch64/matrix_mul/fp32/kernel_general_4x16.h" +#include "src/aarch64/matrix_mul/fp32/kernel_general_8x12.h" +#include "src/common/utils.h" + +using namespace megdnn; +using namespace aarch64; +using namespace aarch64::matmul; + +MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_4x16); + +void sgemm_4x16::pack_A(float* out, const float* in, int ldin, int y0, + int ymax, int k0, int kmax, bool transpose_A) const { + if (transpose_A) { + matmul_general_4x16::sgemm_4x16_pack_A_t(out, in, ldin, y0, ymax, k0, kmax); + } else { + matmul_general_4x16::sgemm_4x16_pack_A_n(out, in, ldin, y0, ymax, k0, kmax); + } +} + +void sgemm_4x16::pack_B(float* out, const float* in, int ldin, int x0, int xmax, + int k0, int kmax, bool transpose_B) const { + if (transpose_B) { + matmul_general_4x16::sgemm_4x16_pack_B_t(out, in, ldin, x0, xmax, k0, kmax); + } else { + matmul_general_4x16::sgemm_4x16_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); + } +} + +void sgemm_4x16::kern(const float* packA, const float* packB, + size_t M, size_t N, size_t K, float* C, size_t LDC, + bool is_first_k, const float*, float*) const { + megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && + A_dtype.enumv() == C_dtype.enumv() && + A_dtype.enumv() == DTypeEnum::Float32); + MEGDNN_MARK_USED_VAR(A_dtype); + MEGDNN_MARK_USED_VAR(B_dtype); + MEGDNN_MARK_USED_VAR(C_dtype); + + constexpr size_t A_INTERLEAVE = 4; + constexpr size_t B_INTERLEAVE = 16; + const int K16 = K * 16; + const int K4 = K * 4; + + size_t m = 0; + for (; m < M; m += A_INTERLEAVE) { + float* output = C + (m * LDC); + + size_t n = 0; + const float* cur_packB = packB; + for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { + matmul_general_4x16::kern_4x16(packA, cur_packB, K, output, LDC, is_first_k, + std::min(M - m, 4)); + output += B_INTERLEAVE; + cur_packB += K16; + } + + for (; n < N; n += 4) { + matmul_general_4x16::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k, + std::min(M - m, 4), std::min(N - n, 4)); + output += 4; + cur_packB += K4; + } + + packA += K4; + } +} + +MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_8x12); + +void sgemm_8x12::pack_A(float* out, const float* in, int ldin, int y0, + int ymax, int k0, int kmax, bool transpose_A) const { + if (transpose_A) { + matmul_general_8x12::sgemm_8x12_pack_A_t(out, in, ldin, y0, ymax, k0, + kmax); + } else { + matmul_general_8x12::sgemm_8x12_pack_A_n(out, in, ldin, y0, ymax, k0, + kmax); + } +} + +void sgemm_8x12::pack_B(float* out, const float* in, int ldin, int x0, int xmax, + int k0, int kmax, bool transpose_B) const { + if (transpose_B) { + matmul_general_8x12::sgemm_8x12_pack_B_t(out, in, ldin, x0, xmax, k0, + kmax); + } else { + matmul_general_8x12::sgemm_8x12_pack_B_n(out, in, ldin, x0, xmax, k0, + kmax); + } +} + +void sgemm_8x12::kern(const float* packA, const float* packB, + size_t M, size_t N, size_t K, float* C, size_t LDC, + bool is_first_k, const float*, float*) const { + megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && + A_dtype.enumv() == C_dtype.enumv() && + A_dtype.enumv() == DTypeEnum::Float32); + MEGDNN_MARK_USED_VAR(A_dtype); + MEGDNN_MARK_USED_VAR(B_dtype); + MEGDNN_MARK_USED_VAR(C_dtype); + + constexpr size_t A_INTERLEAVE = 8; + constexpr size_t A_INTERLEAVE4 = 4; + constexpr size_t B_INTERLEAVE = 12; + const int K12 = K * 12; + const int K8 = K * 8; + const int K4 = K * 4; + + size_t m = 0; + for (; m + A_INTERLEAVE <= M; m += A_INTERLEAVE) { + float* output = C + (m * LDC); + + size_t n = 0; + const float* cur_packB = packB; + for (; 
n + B_INTERLEAVE <= N; n += B_INTERLEAVE) { + matmul_general_8x12::kern_8x12(packA, cur_packB, K, output, LDC, + is_first_k); + output += B_INTERLEAVE; + cur_packB += K12; + } + + for (; n < N; n += 4) { + matmul_general_8x12::kern_8x4(packA, cur_packB, K, output, LDC, + is_first_k, + std::min(N - n, 4)); + output += 4; + cur_packB += K4; + } + packA += K8; + } + for (; m < M; m += A_INTERLEAVE4) { + float* output = C + (m * LDC); + size_t n = 0; + const float* cur_packB = packB; + for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { + matmul_general_8x12::kern_4x12(packA, cur_packB, K, output, LDC, + is_first_k, + std::min(M - m, 4)); + output += B_INTERLEAVE; + cur_packB += K12; + } + + for (; n < N; n += 4) { + matmul_general_8x12::kern_4x4( + packA, cur_packB, K, output, LDC, is_first_k, + std::min(M - m, 4), std::min(N - n, 4)); + output += 4; + cur_packB += K4; + } + packA += K4; + } +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/fp32/strategy.h b/dnn/src/aarch64/matrix_mul/fp32/strategy.h new file mode 100644 index 00000000..8cd877e2 --- /dev/null +++ b/dnn/src/aarch64/matrix_mul/fp32/strategy.h @@ -0,0 +1,30 @@ +/** + * \file dnn/src/aarch64/matrix_mul/fp32/strategy.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once +#include "src/fallback/matrix_mul/gemm_common.h" + +namespace megdnn { +namespace aarch64 { +namespace matmul { +MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, true, + sgemm_8x12); + +MEGDNN_REG_GEMM_STRATEGY(float, float, float, 4, 16, 1, false, true, + sgemm_4x16); + +MEGDNN_REG_GEMM_STRATEGY_NOPACK(float, float, float, 4, 16, 1, false, true, + sgemm_nopack_4x16); + +} // namespace matmul +} // namespace aarch64 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/fp32/strategy_mk4_4x16.cpp b/dnn/src/aarch64/matrix_mul/fp32/strategy_mk4_4x16.cpp new file mode 100644 index 00000000..bfb51f6e --- /dev/null +++ b/dnn/src/aarch64/matrix_mul/fp32/strategy_mk4_4x16.cpp @@ -0,0 +1,570 @@ +/** + * \file dnn/src/aarch64/matrix_mul/fp32/strategy_mk4_4x16.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#include "src/aarch64/matrix_mul/asm/common.h" +#include "src/aarch64/matrix_mul/fp32/strategy.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/utils.h" + +using namespace megdnn; +using namespace aarch64; +using namespace aarch64::matmul; + +namespace { + +// Overview of register layout: +// +// A 4x4 block of A is stored in register v4-v7 +// A 4x4 block of B is stored in register v0-v3 +// A 8x4 block of accumulators store in v16-v19 +// +// A +--------+ +// | v4[0-3]| +// | v5[0-3]| +// | v6[0-3]| +// | v7[0-3]| +// +--------+ +// B +// +--------+ - - - - -+--------+ +// | v0[0-3]| |v16[0-3]| +// | v1[0-3]| |v17[0-3]| +// | v2[0-3]| |v18[0-3]| +// | v3[0-3]| |v19[0-3]| +// +--------+ - - - - -+--------+ +// Accumulator + +void kern_4x4(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, + float* output) { + //! As each load 16 number from B, but the pos add 12 * 4, so we minus 12 + //! here. + LDB = (LDB - 12) * sizeof(float); + asm volatile( + "subs %w[K], %w[K], #4\n" + "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[a_ptr]], 64\n" + + "ld1 {v0.4s}, [%[b_ptr]], 16\n" + "ld1 {v1.4s}, [%[b_ptr]], 16\n" + "ld1 {v2.4s}, [%[b_ptr]], 16\n" + "ld1 {v3.4s}, [%[b_ptr]], %x[LDB]\n" + + "fmul v16.4s, v4.4s, v0.s[0]\n" + "fmul v17.4s, v4.4s, v1.s[0]\n" + "fmul v18.4s, v4.4s, v2.s[0]\n" + "fmul v19.4s, v4.4s, v3.s[0]\n" + + "fmla v16.4s, v5.4s, v0.s[1]\n" + "fmla v17.4s, v5.4s, v1.s[1]\n" + "fmla v18.4s, v5.4s, v2.s[1]\n" + "fmla v19.4s, v5.4s, v3.s[1]\n" + + "beq 2f\n" + + "1:\n" + + "ld1 {v4.4s, v5.4s}, [%[a_ptr]], 32\n" + + "fmla v16.4s, v6.4s, v0.s[2]\n" + "fmla v17.4s, v6.4s, v1.s[2]\n" + "fmla v18.4s, v6.4s, v2.s[2]\n" + "fmla v19.4s, v6.4s, v3.s[2]\n" + + "fmla v16.4s, v7.4s, v0.s[3]\n" + "fmla v17.4s, v7.4s, v1.s[3]\n" + "ld1 {v0.4s}, [%[b_ptr]], 16\n" + "fmla v18.4s, v7.4s, v2.s[3]\n" + "ld1 {v1.4s}, [%[b_ptr]], 16\n" + "fmla v19.4s, v7.4s, v3.s[3]\n" + "ld1 {v2.4s}, [%[b_ptr]], 16\n" + + "fmla v16.4s, v4.4s, v0.s[0]\n" + "ld1 {v3.4s}, [%[b_ptr]], %x[LDB]\n" + "fmla v17.4s, v4.4s, v1.s[0]\n" + "fmla v18.4s, v4.4s, v2.s[0]\n" + "fmla v19.4s, v4.4s, v3.s[0]\n" + + "ld1 {v6.4s, v7.4s}, [%[a_ptr]], 32\n" + + "fmla v16.4s, v5.4s, v0.s[1]\n" + "fmla v17.4s, v5.4s, v1.s[1]\n" + "fmla v18.4s, v5.4s, v2.s[1]\n" + "fmla v19.4s, v5.4s, v3.s[1]\n" + + "subs %w[K], %w[K], #4\n" + "bne 1b\n" + + "2:\n" + + "fmla v16.4s, v6.4s, v0.s[2]\n" + "fmla v17.4s, v6.4s, v1.s[2]\n" + "fmla v18.4s, v6.4s, v2.s[2]\n" + "fmla v19.4s, v6.4s, v3.s[2]\n" + + "fmla v16.4s, v7.4s, v0.s[3]\n" + "fmla v17.4s, v7.4s, v1.s[3]\n" + "fmla v18.4s, v7.4s, v2.s[3]\n" + "fmla v19.4s, v7.4s, v3.s[3]\n" + + "st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[output]], 64\n" + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [output] "+r"(output), [LDB] "+r"(LDB) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "cc", "memory"); +} + +// Overview of register layout: +// +// A 4x4 block of A is stored in register v4-v7 +// A 4x4 block of B is stored in register v0-v3, slipping until 8x4 +// A 8x4 block of accumulators store in v16-v23. 
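+// (The second group of four B columns is loaded into v24-v27, so each pass
+//  consumes eight columns of B; accumulators v16-v23 each hold one output
+//  column of the 4x8 block.)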
+// +// A +--------+ +// | v4[0-3]| +// | v5[0-3]| +// | v6[0-3]| +// | v7[0-3]| +// +--------+ +// B +// +--------+ - - - - -+--------+ +// | v0[0-3]| |v16[0-3]| +// | v1[0-3]| |v17[0-3]| +// | v2[0-3]| |v18[0-3]| +// | v3[0-3]| |v19[0-3]| +// +--------+ - - - - -+--------+ +// | v0[0-3]| |v20[0-3]| +// | v1[0-3]| |v21[0-3]| +// | v2[0-3]| |v22[0-3]| +// | v3[0-3]| |v23[0-3]| +// +--------+ - - - - -+--------+ +// Accumulator + +void kern_4x8(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, + float* output) { + //! As each load 32 number from B, but the pos add 24 * 4, so we minus 24 + //! here. + LDB = (LDB - 24) * sizeof(float); + asm volatile( + "subs %w[K], %w[K], #4\n" + "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[a_ptr]], 64\n" + + "ld1 {v0.4s}, [%[b_ptr]], 16\n" + "fmul v16.4s, v4.4s, v0.s[0]\n" + + "ld1 {v1.4s}, [%[b_ptr]], 16\n" + "fmla v16.4s, v5.4s, v0.s[1]\n" + "fmul v17.4s, v4.4s, v1.s[0]\n" + + "ld1 {v2.4s, v3.4s}, [%[b_ptr]], 32\n" + "fmla v17.4s, v5.4s, v1.s[1]\n" + "fmla v16.4s, v6.4s, v0.s[2]\n" + "fmla v17.4s, v6.4s, v1.s[2]\n" + "fmul v18.4s, v4.4s, v2.s[0]\n" + "fmla v16.4s, v7.4s, v0.s[3]\n" + "fmla v18.4s, v5.4s, v2.s[1]\n" + "fmla v17.4s, v7.4s, v1.s[3]\n" + "fmul v19.4s, v4.4s, v3.s[0]\n" + + "ld1 {v24.4s, v25.4s}, [%[b_ptr]], 32\n" + "fmla v18.4s, v7.4s, v2.s[3]\n" + "fmla v19.4s, v5.4s, v3.s[1]\n" + "fmul v20.4s, v4.4s, v24.s[0]\n" + "fmla v19.4s, v6.4s, v3.s[2]\n" + + "ld1 {v26.4s, v27.4s}, [%[b_ptr]], %x[LDB]\n" + "fmla v18.4s, v6.4s, v2.s[2]\n" + "fmla v19.4s, v7.4s, v3.s[3]\n" + "fmul v21.4s, v4.4s, v25.s[0]\n" + "fmla v20.4s, v5.4s, v24.s[1]\n" + "fmla v21.4s, v5.4s, v25.s[1]\n" + + "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b_ptr]], 64\n" + "fmla v20.4s, v6.4s, v24.s[2]\n" + "fmul v22.4s, v4.4s, v26.s[0]\n" + "fmla v21.4s, v6.4s, v25.s[2]\n" + "fmla v22.4s, v5.4s, v26.s[1]\n" + + "fmla v21.4s, v7.4s, v25.s[3]\n" + "fmul v23.4s, v4.4s, v27.s[0]\n" + "fmla v20.4s, v7.4s, v24.s[3]\n" + "fmla v22.4s, v6.4s, v26.s[2]\n" + "fmla v23.4s, v5.4s, v27.s[1]\n" + + "beq 2f\n" + + "1:\n" + "ld1 {v4.4s, v5.4s}, [%[a_ptr]], 32\n" + "fmla v22.4s, v7.4s, v26.s[3]\n" + "fmla v23.4s, v6.4s, v27.s[2]\n" + "fmla v16.4s, v4.4s, v0.s[0]\n" + "fmla v17.4s, v4.4s, v1.s[0]\n" + "fmla v23.4s, v7.4s, v27.s[3]\n" + + "ld1 {v6.4s, v7.4s}, [%[a_ptr]], 32\n" + "fmla v16.4s, v5.4s, v0.s[1]\n" + "fmla v17.4s, v5.4s, v1.s[1]\n" + "fmla v16.4s, v6.4s, v0.s[2]\n" + "fmla v17.4s, v6.4s, v1.s[2]\n" + "fmla v18.4s, v4.4s, v2.s[0]\n" + "fmla v16.4s, v7.4s, v0.s[3]\n" + "fmla v18.4s, v5.4s, v2.s[1]\n" + "fmla v17.4s, v7.4s, v1.s[3]\n" + "fmla v19.4s, v4.4s, v3.s[0]\n" + + "ld1 {v24.4s, v25.4s}, [%[b_ptr]], 32\n" + "fmla v18.4s, v6.4s, v2.s[2]\n" + "fmla v19.4s, v5.4s, v3.s[1]\n" + "fmla v20.4s, v4.4s, v24.s[0]\n" + "fmla v19.4s, v6.4s, v3.s[2]\n" + + "ld1 {v26.4s, v27.4s}, [%[b_ptr]], %x[LDB]\n" + "fmla v18.4s, v7.4s, v2.s[3]\n" + "fmla v19.4s, v7.4s, v3.s[3]\n" + "fmla v21.4s, v4.4s, v25.s[0]\n" + "fmla v20.4s, v5.4s, v24.s[1]\n" + "fmla v21.4s, v5.4s, v25.s[1]\n" + + "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b_ptr]], 64\n" + "fmla v20.4s, v6.4s, v24.s[2]\n" + "fmla v22.4s, v4.4s, v26.s[0]\n" + "fmla v20.4s, v7.4s, v24.s[3]\n" + "fmla v23.4s, v4.4s, v27.s[0]\n" + "fmla v21.4s, v6.4s, v25.s[2]\n" + "fmla v22.4s, v5.4s, v26.s[1]\n" + "fmla v21.4s, v7.4s, v25.s[3]\n" + "fmla v23.4s, v5.4s, v27.s[1]\n" + "fmla v22.4s, v6.4s, v26.s[2]\n" + + "subs %w[K], %w[K], #4\n" + "bne 1b\n" + + "2:\n" + "fmla v22.4s, v7.4s, v26.s[3]\n" + "st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[output]], 64\n" + "fmla v23.4s, 
v6.4s, v27.s[2]\n" + "fmla v23.4s, v7.4s, v27.s[3]\n" + "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%[output]], 64\n" + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [output] "+r"(output), [LDB] "+r"(LDB) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", + "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", + "v27", "cc", "memory"); +} + +// Overview of register layout: +// +// A 4x1 cell of Rhs is stored in 32bit in v4-v7 (v8-v11 for ping pong) +// A 4x1 cell of Lhs is stored in 32bit in v0-v3 +// A 16x1 block of accumulators is stored in 32bit in v16-v31. +// +// Rhs +--------+ +// | v4[0-3]| +// | v5[0-3]| +// | v6[0-3]| +// | v7[0-3]| +// +--------+ +// Lhs +// +--------+ - - - - -+--------+ +// | v0[0-3] | |v16[0-3]| +// | v1[0-3] | |v17[0-3]| +// | v2[0-3] | |v18[0-3]| +// | v3[0-3] | |v19[0-3]| +// | v8[0-3] | |v20[0-3]| +// | v9[0-3] | |v21[0-3]| +// | v10[0-3]| |v22[0-3]| +// | v11[0-3]| |v23[0-3]| +// +--------+ |v24[0-3]| +// |v25[0-3]| +// |v26[0-3]| +// |v27[0-3]| +// |v28[0-3]| +// |v29[0-3]| +// |v30[0-3]| +// |v31[0-3]| +// +--------+ +// Accumulator + +void kern_4x16(const float* a_ptr, const float* b_ptr, int LDB, int K, + float* output) { + //! As each load 64 number from B, but the pos add 56 * 4, so we minus 56 + //! here. + LDB = (LDB - 56) * sizeof(float); + + asm volatile( + "stp d8, d9, [sp, #-16]!\n" + "stp d10, d11, [sp, #-16]!\n" + + "subs %w[K], %w[K], #4\n" + "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[a_ptr]], 64\n" + "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b_ptr]], 64\n" + + "fmul v16.4s, v4.4s, v0.s[0]\n" + "fmul v17.4s, v4.4s, v1.s[0]\n" + "fmul v18.4s, v4.4s, v2.s[0]\n" + "fmul v19.4s, v4.4s, v3.s[0]\n" + + "fmla v16.4s, v5.4s, v0.s[1]\n" + "fmla v17.4s, v5.4s, v1.s[1]\n" + "fmla v18.4s, v5.4s, v2.s[1]\n" + "fmla v19.4s, v5.4s, v3.s[1]\n" + + "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[b_ptr]], 64\n" + + "fmla v16.4s, v6.4s, v0.s[2]\n" + "fmla v17.4s, v6.4s, v1.s[2]\n" + "fmla v18.4s, v6.4s, v2.s[2]\n" + "fmla v19.4s, v6.4s, v3.s[2]\n" + + "fmla v16.4s, v7.4s, v0.s[3]\n" + "fmla v17.4s, v7.4s, v1.s[3]\n" + "fmla v18.4s, v7.4s, v2.s[3]\n" + "fmla v19.4s, v7.4s, v3.s[3]\n" + + "fmul v20.4s, v4.4s, v8.s[0]\n" + "fmul v21.4s, v4.4s, v9.s[0]\n" + "fmul v22.4s, v4.4s, v10.s[0]\n" + "fmul v23.4s, v4.4s, v11.s[0]\n" + + "fmla v20.4s, v5.4s, v8.s[1]\n" + "fmla v21.4s, v5.4s, v9.s[1]\n" + "fmla v22.4s, v5.4s, v10.s[1]\n" + "fmla v23.4s, v5.4s, v11.s[1]\n" + + "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b_ptr]], 64\n" + + "fmla v20.4s, v6.4s, v8.s[2]\n" + "fmla v21.4s, v6.4s, v9.s[2]\n" + "fmla v22.4s, v6.4s, v10.s[2]\n" + "fmla v23.4s, v6.4s, v11.s[2]\n" + + "fmla v20.4s, v7.4s, v8.s[3]\n" + "fmla v21.4s, v7.4s, v9.s[3]\n" + "fmla v22.4s, v7.4s, v10.s[3]\n" + "fmla v23.4s, v7.4s, v11.s[3]\n" + + "fmul v24.4s, v4.4s, v0.s[0]\n" + "fmul v25.4s, v4.4s, v1.s[0]\n" + "fmul v26.4s, v4.4s, v2.s[0]\n" + "fmul v27.4s, v4.4s, v3.s[0]\n" + + "fmla v24.4s, v5.4s, v0.s[1]\n" + "fmla v25.4s, v5.4s, v1.s[1]\n" + "fmla v26.4s, v5.4s, v2.s[1]\n" + "fmla v27.4s, v5.4s, v3.s[1]\n" + + "ld1 {v8.4s, v9.4s}, [%[b_ptr]], 32\n" + + "fmla v24.4s, v6.4s, v0.s[2]\n" + "fmla v25.4s, v6.4s, v1.s[2]\n" + "fmla v26.4s, v6.4s, v2.s[2]\n" + "fmla v27.4s, v6.4s, v3.s[2]\n" + + "ld1 {v10.4s, v11.4s}, [%[b_ptr]], %x[LDB]\n" + + "fmla v24.4s, v7.4s, v0.s[3]\n" + "fmla v25.4s, v7.4s, v1.s[3]\n" + "fmla v26.4s, v7.4s, v2.s[3]\n" + "fmla v27.4s, v7.4s, v3.s[3]\n" + + "fmul v28.4s, v4.4s, v8.s[0]\n" + "fmul v29.4s, v4.4s, v9.s[0]\n" + "fmul v30.4s, v4.4s, v10.s[0]\n" + "fmul 
v31.4s, v4.4s, v11.s[0]\n" + + "beq 2f\n" + + "1:\n" + + "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b_ptr]], 64\n" + + "fmla v28.4s, v5.4s, v8.s[1]\n" + "fmla v29.4s, v5.4s, v9.s[1]\n" + "fmla v30.4s, v5.4s, v10.s[1]\n" + "fmla v31.4s, v5.4s, v11.s[1]\n" + + "ld1 {v4.4s}, [%[a_ptr]], 16\n" + + "fmla v28.4s, v6.4s, v8.s[2]\n" + "fmla v29.4s, v6.4s, v9.s[2]\n" + "fmla v30.4s, v6.4s, v10.s[2]\n" + "fmla v31.4s, v6.4s, v11.s[2]\n" + + "ld1 {v5.4s}, [%[a_ptr]], 16\n" + + "fmla v28.4s, v7.4s, v8.s[3]\n" + "fmla v29.4s, v7.4s, v9.s[3]\n" + "fmla v30.4s, v7.4s, v10.s[3]\n" + "fmla v31.4s, v7.4s, v11.s[3]\n" + + "ld1 {v6.4s}, [%[a_ptr]], 16\n" + + "fmla v16.4s, v4.4s, v0.s[0]\n" + "fmla v17.4s, v4.4s, v1.s[0]\n" + "fmla v18.4s, v4.4s, v2.s[0]\n" + "fmla v19.4s, v4.4s, v3.s[0]\n" + + "ld1 {v7.4s}, [%[a_ptr]], 16\n" + + "fmla v16.4s, v5.4s, v0.s[1]\n" + "fmla v17.4s, v5.4s, v1.s[1]\n" + "fmla v18.4s, v5.4s, v2.s[1]\n" + "fmla v19.4s, v5.4s, v3.s[1]\n" + + "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[b_ptr]], 64\n" + + "fmla v16.4s, v6.4s, v0.s[2]\n" + "fmla v17.4s, v6.4s, v1.s[2]\n" + "fmla v18.4s, v6.4s, v2.s[2]\n" + "fmla v19.4s, v6.4s, v3.s[2]\n" + + "fmla v16.4s, v7.4s, v0.s[3]\n" + "fmla v17.4s, v7.4s, v1.s[3]\n" + "fmla v18.4s, v7.4s, v2.s[3]\n" + "fmla v19.4s, v7.4s, v3.s[3]\n" + + "fmla v20.4s, v4.4s, v8.s[0]\n" + "fmla v21.4s, v4.4s, v9.s[0]\n" + "fmla v22.4s, v4.4s, v10.s[0]\n" + "fmla v23.4s, v4.4s, v11.s[0]\n" + + "fmla v20.4s, v5.4s, v8.s[1]\n" + "fmla v21.4s, v5.4s, v9.s[1]\n" + "fmla v22.4s, v5.4s, v10.s[1]\n" + "fmla v23.4s, v5.4s, v11.s[1]\n" + + "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b_ptr]], 64\n" + + "fmla v20.4s, v6.4s, v8.s[2]\n" + "fmla v21.4s, v6.4s, v9.s[2]\n" + "fmla v22.4s, v6.4s, v10.s[2]\n" + "fmla v23.4s, v6.4s, v11.s[2]\n" + + "fmla v20.4s, v7.4s, v8.s[3]\n" + "fmla v21.4s, v7.4s, v9.s[3]\n" + "fmla v22.4s, v7.4s, v10.s[3]\n" + "fmla v23.4s, v7.4s, v11.s[3]\n" + + "fmla v24.4s, v4.4s, v0.s[0]\n" + "fmla v25.4s, v4.4s, v1.s[0]\n" + "fmla v26.4s, v4.4s, v2.s[0]\n" + "fmla v27.4s, v4.4s, v3.s[0]\n" + + "fmla v24.4s, v5.4s, v0.s[1]\n" + "fmla v25.4s, v5.4s, v1.s[1]\n" + "fmla v26.4s, v5.4s, v2.s[1]\n" + "fmla v27.4s, v5.4s, v3.s[1]\n" + + "ld1 {v8.4s, v9.4s}, [%[b_ptr]], 32\n" + + "fmla v24.4s, v6.4s, v0.s[2]\n" + "fmla v25.4s, v6.4s, v1.s[2]\n" + "fmla v26.4s, v6.4s, v2.s[2]\n" + "fmla v27.4s, v6.4s, v3.s[2]\n" + + "ld1 {v10.4s, v11.4s}, [%[b_ptr]], %x[LDB]\n" + + "fmla v24.4s, v7.4s, v0.s[3]\n" + "fmla v25.4s, v7.4s, v1.s[3]\n" + "fmla v26.4s, v7.4s, v2.s[3]\n" + "fmla v27.4s, v7.4s, v3.s[3]\n" + + "fmla v28.4s, v4.4s, v8.s[0]\n" + "fmla v29.4s, v4.4s, v9.s[0]\n" + "fmla v30.4s, v4.4s, v10.s[0]\n" + "fmla v31.4s, v4.4s, v11.s[0]\n" + + "subs %w[K], %w[K], #4\n" + "bne 1b\n" + + "2:\n" + + "st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[output]], 64\n" + + "fmla v28.4s, v5.4s, v8.s[1]\n" + "fmla v29.4s, v5.4s, v9.s[1]\n" + "fmla v30.4s, v5.4s, v10.s[1]\n" + "fmla v31.4s, v5.4s, v11.s[1]\n" + + "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%[output]], 64\n" + + "fmla v28.4s, v6.4s, v8.s[2]\n" + "fmla v29.4s, v6.4s, v9.s[2]\n" + "fmla v30.4s, v6.4s, v10.s[2]\n" + "fmla v31.4s, v6.4s, v11.s[2]\n" + + "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%[output]], 64\n" + + "fmla v28.4s, v7.4s, v8.s[3]\n" + "fmla v29.4s, v7.4s, v9.s[3]\n" + "fmla v30.4s, v7.4s, v10.s[3]\n" + "fmla v31.4s, v7.4s, v11.s[3]\n" + + "st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%[output]], 64\n" + + "ldp d10, d11, [sp], #16\n" + "ldp d8, d9, [sp], #16\n" + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [output] "+r"(output), 
[LDB] "+r"(LDB) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", + "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", + "memory"); +} + +} // namespace + +MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(sgemm_nopack_4x16); + +void sgemm_nopack_4x16::kern(const float* A, size_t LDA, const float* B, + size_t LDB, float* C, size_t LDC, size_t M, + size_t K, size_t N, const float*, void*, bool trA, + bool trB) const { + constexpr static size_t MB = 4; + constexpr static size_t KB = 4; + constexpr static size_t NB = 16; + constexpr static size_t CALCBLK = 4; + + megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0 && N % CALCBLK == 0); + + //! (m/4, k/4, 4, 4) * (k/4, n, 4) = (m/4, n, 4) + for (size_t m = 0; m < M; m += MB) { + float* output = C + (m / MB) * LDC; + const float* cur_B = B; + size_t n = 0; + for (; n + NB - 1 < N; n += NB) { + kern_4x16(A, cur_B, LDB, K, output); + cur_B += KB * NB; + output += MB * NB; + } + switch (N - n) { + case 4: + kern_4x4(A, cur_B, LDB, K, output); + break; + case 8: + kern_4x8(A, cur_B, LDB, K, output); + break; + case 12: + kern_4x8(A, cur_B, LDB, K, output); + cur_B += KB * CALCBLK * 2; + output += MB * CALCBLK * 2; + kern_4x4(A, cur_B, LDB, K, output); + break; + default: + break; + } + A += LDA; + } +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/int16/kernel_12x8x1.h b/dnn/src/aarch64/matrix_mul/int16/kernel_12x8x1.h new file mode 100644 index 00000000..370af5ec --- /dev/null +++ b/dnn/src/aarch64/matrix_mul/int16/kernel_12x8x1.h @@ -0,0 +1,1175 @@ +/** + * \file dnn/src/aarch64/matrix_mul/int16/kernel_12x8x1.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#include "src/aarch64/matrix_mul/asm/common.h" +#include "src/arm_common/simd_macro/marm_neon.h" + +namespace megdnn { +namespace aarch64 { +namespace matmul_12x8x1 { + +/** + * Overview of register layout: + * + * A 1x8 cell of Rhs is stored in 16bit in q2 + * A 12x1 cell of Lhs is stored in 16bit in q0-q1 + * A 12x8 block of accumulators is stored in 32bit in q7-q30 + * + * +--------+--------+ + * | v2[0-3]| v2[4-7]| + * Rhs +--------+--------+ + * Lhs | | | + * + * +--------+ - - - - +-----------------+ + * |v0[0]| | v7[0-3]| v8[0-3]| + * |v0[1]| | v9[0-3]|v10[0-3]| + * |v0[2]| |v11[0-3]|v12[0-3]| + * |v0[3]| |v13[0-3]|v14[0-3]| + * |v0[4]| |v15[0-3]|v16[0-3]| + * |v0[5]| |v17[0-3]|v18[0-3]| + * |v0[6]| |v19[0-3]|v20[0-3]| + * |v0[7]| |v21[0-3]|v22[0-3]| + * |v1[0]| |v23[0-3]|v24[0-3]| + * |v1[1]| |v25[0-3]|v26[0-3]| + * |v1[2]| |v27[0-3]|v28[0-3]| + * |v1[3]| |v29[0-3]|v30[0-3]| + * +--------+ - - - - +-----------------+ + * + * Accumulator + */ + +static void kern_12x8(const int16_t* packA, const int16_t* packB, int K, + int32_t* output, int LDC, bool is_first_k) { + const int16_t* a_ptr = packA; + const int16_t* b_ptr = packB; + + LDC = LDC * sizeof(int32_t); + + asm volatile( + // load accumulator C + "add x1, %[output], %x[LDC]\n" + "add x2, x1, %x[LDC]\n" + "add x3, x2, %x[LDC]\n" + "add x4, x3, %x[LDC]\n" + "add x5, x4, %x[LDC]\n" + "add x6, x5, %x[LDC]\n" + "add x7, x6, %x[LDC]\n" + "add x8, x7, %x[LDC]\n" + "add x9, x8, %x[LDC]\n" + "add x10, x9, %x[LDC]\n" + "add x11, x10, %x[LDC]\n" + "cmp %w[is_first_k], #1\n" + "beq 1f\n" + + "ldp q7, q8, [%[output]]\n" + "ldp q9, q10, [x1]\n" + "ldp q11, q12, [x2]\n" + "ldp q13, q14, [x3]\n" + "ldp q15, q16, [x4]\n" + "ldp q17, q18, [x5]\n" + "ldp q19, q20, [x6]\n" + "ldp q21, q22, [x7]\n" + "ldp q23, q24, [x7]\n" + "ldp q25, q26, [x7]\n" + "ldp q27, q28, [x7]\n" + "ldp q29, q30, [x7]\n" + "b 2f\n" + + "1:\n" + "eor v7.16b, v7.16b, v7.16b\n" + "eor v8.16b, v8.16b, v8.16b\n" + "eor v9.16b, v9.16b, v9.16b\n" + "eor v10.16b, v10.16b, v10.16b\n" + "eor v11.16b, v11.16b, v11.16b\n" + "eor v12.16b, v12.16b, v12.16b\n" + "eor v13.16b, v13.16b, v13.16b\n" + "eor v14.16b, v14.16b, v14.16b\n" + "eor v15.16b, v15.16b, v15.16b\n" + "eor v16.16b, v16.16b, v16.16b\n" + "eor v17.16b, v17.16b, v17.16b\n" + "eor v18.16b, v18.16b, v18.16b\n" + "eor v19.16b, v19.16b, v19.16b\n" + "eor v20.16b, v20.16b, v20.16b\n" + "eor v21.16b, v21.16b, v21.16b\n" + "eor v22.16b, v22.16b, v22.16b\n" + "eor v23.16b, v23.16b, v23.16b\n" + "eor v24.16b, v24.16b, v24.16b\n" + "eor v25.16b, v25.16b, v25.16b\n" + "eor v26.16b, v26.16b, v26.16b\n" + "eor v27.16b, v27.16b, v27.16b\n" + "eor v28.16b, v28.16b, v28.16b\n" + "eor v29.16b, v29.16b, v29.16b\n" + "eor v30.16b, v30.16b, v30.16b\n" + + "2: \n" + "ld1 {v2.8h}, [%[b_ptr]], 16\n" + "ld1 {v0.8h}, [%[a_ptr]], 16\n" + "ld1 {v1.4h}, [%[a_ptr]], 8\n" + + "smlal v7.4s, v2.4h, v0.h[0]\n" + "smlal v9.4s, v2.4h, v0.h[1]\n" + "smlal v11.4s, v2.4h, v0.h[2]\n" + "smlal v13.4s, v2.4h, v0.h[3]\n" + "smlal v15.4s, v2.4h, v0.h[4]\n" + "smlal v17.4s, v2.4h, v0.h[5]\n" + "smlal v19.4s, v2.4h, v0.h[6]\n" + "smlal v21.4s, v2.4h, v0.h[7]\n" + "smlal v23.4s, v2.4h, v1.h[0]\n" + "smlal v25.4s, v2.4h, v1.h[1]\n" + "smlal v27.4s, v2.4h, v1.h[2]\n" + "smlal v29.4s, v2.4h, v1.h[3]\n" + "smlal2 v8.4s, v2.8h, v0.h[0]\n" + "smlal2 v10.4s, v2.8h, v0.h[1]\n" + "smlal2 v12.4s, v2.8h, v0.h[2]\n" + "smlal2 v14.4s, v2.8h, v0.h[3]\n" + "smlal2 v16.4s, v2.8h, v0.h[4]\n" + "smlal2 v18.4s, v2.8h, v0.h[5]\n" + "smlal2 v20.4s, v2.8h, v0.h[6]\n" + "smlal2 
v22.4s, v2.8h, v0.h[7]\n" + "smlal2 v24.4s, v2.8h, v1.h[0]\n" + "smlal2 v26.4s, v2.8h, v1.h[1]\n" + "smlal2 v28.4s, v2.8h, v1.h[2]\n" + "smlal2 v30.4s, v2.8h, v1.h[3]\n" + + "subs %w[K], %w[K], #1\n" + "cbnz %w[K], 2b\n" + + "3:\n" + "stp q7, q8, [%[output]]\n" + "stp q9, q10, [x1]\n" + "stp q11, q12, [x2]\n" + "stp q13, q14, [x3]\n" + "stp q15, q16, [x4]\n" + "stp q17, q18, [x5]\n" + "stp q19, q20, [x6]\n" + "stp q21, q22, [x7]\n" + "stp q23, q24, [x8]\n" + "stp q25, q26, [x9]\n" + "stp q27, q28, [x10]\n" + "stp q29, q30, [x11]\n" + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), + [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), + [output] "+r"(output) + : + : "v0", "v1", "v2", "v7", "v8", "v9", "v10", "v11", "v12", "v13", + "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", + "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "x1", + "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", + "cc", "memory"); +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +} + +/** + * Overview of register layout: + * + * A 1x8 cell of Rhs is stored in 16bit in q2 + * A 8x1 cell of Lhs is stored in 16bit in q0 + * A 8x8 block of accumulators is stored in 32bit in q7-q22 + * + * +--------+--------+ + * | v2[0-3]| v2[4-7]| + * Rhs +--------+--------+ + * Lhs | | | + * + * +--------+ - - - - +-----------------+ + * |v0[0]| | v7[0-3]| v8[0-3]| + * |v0[1]| | v9[0-3]|v10[0-3]| + * |v0[2]| |v11[0-3]|v12[0-3]| + * |v0[3]| |v13[0-3]|v14[0-3]| + * |v0[4]| |v15[0-3]|v16[0-3]| + * |v0[5]| |v17[0-3]|v18[0-3]| + * |v0[6]| |v19[0-3]|v20[0-3]| + * |v0[7]| |v21[0-3]|v22[0-3]| + * +--------+ - - - - +-----------------+ + * + * Accumulator + */ + +static void kern_8x8(const int16_t* packA, const int16_t* packB, int K, + int32_t* output, int LDC, bool is_first_k) { + const int16_t* a_ptr = packA; + const int16_t* b_ptr = packB; + + LDC = LDC * sizeof(int32_t); + + asm volatile( + // load accumulator C + "add x1, %[output], %x[LDC]\n" + "add x2, x1, %x[LDC]\n" + "add x3, x2, %x[LDC]\n" + "add x4, x3, %x[LDC]\n" + "add x5, x4, %x[LDC]\n" + "add x6, x5, %x[LDC]\n" + "add x7, x6, %x[LDC]\n" + "cmp %w[is_first_k], #1\n" + "beq 1f\n" + + "ldp q7, q8, [%[output]]\n" + "ldp q9, q10, [x1]\n" + "ldp q11, q12, [x2]\n" + "ldp q13, q14, [x3]\n" + "ldp q15, q16, [x4]\n" + "ldp q17, q18, [x5]\n" + "ldp q19, q20, [x6]\n" + "ldp q21, q22, [x7]\n" + "b 2f\n" + + "1:\n" + "eor v7.16b, v7.16b, v7.16b\n" + "eor v8.16b, v8.16b, v8.16b\n" + "eor v9.16b, v9.16b, v9.16b\n" + "eor v10.16b, v10.16b, v10.16b\n" + "eor v11.16b, v11.16b, v11.16b\n" + "eor v12.16b, v12.16b, v12.16b\n" + "eor v13.16b, v13.16b, v13.16b\n" + "eor v14.16b, v14.16b, v14.16b\n" + "eor v15.16b, v15.16b, v15.16b\n" + "eor v16.16b, v16.16b, v16.16b\n" + "eor v17.16b, v17.16b, v17.16b\n" + "eor v18.16b, v18.16b, v18.16b\n" + "eor v19.16b, v19.16b, v19.16b\n" + "eor v20.16b, v20.16b, v20.16b\n" + "eor v21.16b, v21.16b, v21.16b\n" + "eor v22.16b, v22.16b, v22.16b\n" + + "2: \n" + "ld1 {v2.8h}, [%[b_ptr]], 16\n" + "ld1 {v0.8h}, [%[a_ptr]], 16\n" + + "smlal v7.4s, v2.4h, v0.h[0]\n" + "smlal v9.4s, v2.4h, v0.h[1]\n" + "smlal v11.4s, v2.4h, v0.h[2]\n" + "smlal v13.4s, v2.4h, v0.h[3]\n" + "smlal v15.4s, v2.4h, v0.h[4]\n" + "smlal v17.4s, v2.4h, v0.h[5]\n" + "smlal v19.4s, v2.4h, v0.h[6]\n" + "smlal v21.4s, v2.4h, v0.h[7]\n" + "smlal2 v8.4s, v2.8h, v0.h[0]\n" + "smlal2 v10.4s, v2.8h, v0.h[1]\n" + "smlal2 v12.4s, v2.8h, v0.h[2]\n" + "smlal2 v14.4s, v2.8h, v0.h[3]\n" + "smlal2 v16.4s, v2.8h, v0.h[4]\n" + "smlal2 v18.4s, v2.8h, v0.h[5]\n" + 
"smlal2 v20.4s, v2.8h, v0.h[6]\n" + "smlal2 v22.4s, v2.8h, v0.h[7]\n" + + "subs %w[K], %w[K], #1\n" + "cbnz %w[K], 2b\n" + + "3:\n" + "stp q7, q8, [%[output]]\n" + "stp q9, q10, [x1]\n" + "stp q11, q12, [x2]\n" + "stp q13, q14, [x3]\n" + "stp q15, q16, [x4]\n" + "stp q17, q18, [x5]\n" + "stp q19, q20, [x6]\n" + "stp q21, q22, [x7]\n" + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), + [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), + [output] "+r"(output) + : + : "v0", "v2", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "x1", + "x2", "x3", "x4", "x5", "x6", "x7", "cc", "memory"); +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +} + +/** + * Overview of register layout: + * + * A 1x8 cell of Rhs is stored in 16bit in q2 + * A 4x1 cell of Lhs is stored in 16bit in q0 + * A 4x8 block of accumulators is stored in 32bit in q7-q14 + * + * +--------+--------+ + * | v2[0-3]| v2[4-7]| + * Rhs +--------+--------+ + * Lhs | | | + * + * +--------+ - - - - +-----------------+ + * |v0[0]| | v7[0-3]| v8[0-3]| + * |v0[1]| | v9[0-3]|v10[0-3]| + * |v0[2]| |v11[0-3]|v12[0-3]| + * |v0[3]| |v13[0-3]|v14[0-3]| + * +--------+ - - - - +-----------------+ + * + * Accumulator + */ + +static void kern_4x8(const int16_t* packA, const int16_t* packB, int K, + int32_t* output, int LDC, bool is_first_k, + size_t m_remain) { + const int16_t* a_ptr = packA; + const int16_t* b_ptr = packB; + + LDC = LDC * sizeof(int32_t); + int32_t* outptr0 = output; + int32_t* outptr1; + int32_t* outptr2; + int32_t* outptr3; + size_t x0 = 0; + +// clang-format off +#define LOAD_LINE(v1, v2, m) \ + "cbz %[x0], 100f\n" \ + "ldp " v1 "," v2 ", [%[outptr" m "]]\n" \ + "subs %[x0], %[x0], #1\n" + +#define LOAD_C \ + "mov %[x0], %x[m_remain]\n" \ + LOAD_LINE("q7", "q8", "0") \ + LOAD_LINE("q9", "q10", "1") \ + LOAD_LINE("q11", "q12", "2") \ + LOAD_LINE("q13", "q14", "3") \ + "100:\n" + +#define STORE_LINE(v1, v2, m) \ + "cbz %[x0], 101f\n" \ + "stp " v1 "," v2", [%[outptr" m "]]\n" \ + "subs %[x0], %[x0], #1\n" + +#define STORE_C \ + "mov %[x0], %x[m_remain]\n" \ + STORE_LINE("q7", "q8", "0") \ + STORE_LINE("q9", "q10", "1") \ + STORE_LINE("q11", "q12", "2") \ + STORE_LINE("q13", "q14", "3") \ + "101:\n" + // clang-format on + asm volatile( + // load accumulator C + "add %[outptr1], %[outptr0], %x[LDC]\n" + "add %[outptr2], %[outptr1], %x[LDC]\n" + "add %[outptr3], %[outptr2], %x[LDC]\n" + "cmp %w[is_first_k], #1\n" + "beq 1f\n" LOAD_C + + "b 2f\n" + + "1:\n" + "eor v7.16b, v7.16b, v7.16b\n" + "eor v8.16b, v8.16b, v8.16b\n" + "eor v9.16b, v9.16b, v9.16b\n" + "eor v10.16b, v10.16b, v10.16b\n" + "eor v11.16b, v11.16b, v11.16b\n" + "eor v12.16b, v12.16b, v12.16b\n" + "eor v13.16b, v13.16b, v13.16b\n" + "eor v14.16b, v14.16b, v14.16b\n" + + "2: \n" + "ld1 {v2.8h}, [%[b_ptr]], 16\n" + "ld1 {v0.4h}, [%[a_ptr]], 8\n" + + "smlal v7.4s, v2.4h, v0.h[0]\n" + "smlal v9.4s, v2.4h, v0.h[1]\n" + "smlal v11.4s, v2.4h, v0.h[2]\n" + "smlal v13.4s, v2.4h, v0.h[3]\n" + "smlal2 v8.4s, v2.8h, v0.h[0]\n" + "smlal2 v10.4s, v2.8h, v0.h[1]\n" + "smlal2 v12.4s, v2.8h, v0.h[2]\n" + "smlal2 v14.4s, v2.8h, v0.h[3]\n" + + "subs %w[K], %w[K], #1\n" + "cbnz %w[K], 2b\n" + + "3:\n" STORE_C + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), + [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), + [outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), + [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), [x0] "+r"(x0), + [m_remain] "+r"(m_remain) + : + : "v0", "v2", "v7", "v8", "v9", "v10", 
"v11", "v12", "v13", "v14", + "cc", "memory"); +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +} + +/** + * Overview of register layout: + * + * A 1x4 cell of Rhs is stored in 16bit in q2 + * A 12x1 cell of Lhs is stored in 16bit in q0-q1 + * A 12x4 block of accumulators is stored in 32bit in q7-q30 + * + * +--------+ + * | v2[0-3]| + * Rhs +--------+ + * Lhs | | + * + * +--------+ - - - - +--------- + * |v0[0]| | v8[0-3]| + * |v0[1]| | v9[0-3]| + * |v0[2]| |v10[0-3]| + * |v0[3]| |v11[0-3]| + * |v0[4]| |v12[0-3]| + * |v0[5]| |v13[0-3]| + * |v0[6]| |v14[0-3]| + * |v0[7]| |v15[0-3]| + * |v1[0]| |v16[0-3]| + * |v1[1]| |v17[0-3]| + * |v1[2]| |v18[0-3]| + * |v1[3]| |v19[0-3]| + * +--------+ - - - - +--------- + * + * Accumulator + */ + +static void kern_12x4(const int16_t* packA, const int16_t* packB, int K, + int32_t* output, int LDC, bool is_first_k, + size_t n_remain) { + const int16_t* a_ptr = packA; + const int16_t* b_ptr = packB; + + LDC = LDC * sizeof(int32_t); + int32_t* outptr0 = output; + int32_t* outptr1; + int32_t* outptr2; + int32_t* outptr3; + int32_t* outptr4; + int32_t* outptr5; + int32_t* outptr6; + int32_t* outptr7; + int32_t* outptr8; + int32_t* outptr9; + int32_t* outptr10; + int32_t* outptr11; + size_t x0 = 0; + +// clang-format off +#define LOAD_LINE(reg_index, n) \ + "mov %[x0], %[outptr" n "]\n" \ + "cmp %w[n_remain], #4\n" \ + "blt 100" n "f\n" \ + "ldr q" reg_index ", [%[x0]] \n" \ + "b 101" n "f\n" \ + "100" n ":\n" \ + "cmp %w[n_remain], #0\n" \ + "beq 101" n "f\n" \ + "ld1 {v" reg_index ".s}[0], [%[x0]], #4\n" \ + "cmp %w[n_remain], #1\n" \ + "beq 101" n "f\n" \ + "ld1 {v" reg_index ".s}[1], [%[x0]], #4\n" \ + "cmp %w[n_remain], #2\n" \ + "beq 101" n "f\n" \ + "ld1 {v" reg_index ".s}[2], [%[x0]], #4\n" \ + "101" n ":\n" + +#define LOAD_C \ + LOAD_LINE("8", "0") \ + LOAD_LINE("9", "1") \ + LOAD_LINE("10", "2") \ + LOAD_LINE("11", "3") \ + LOAD_LINE("12", "4") \ + LOAD_LINE("13", "5") \ + LOAD_LINE("14", "6") \ + LOAD_LINE("15", "7") \ + LOAD_LINE("16", "8") \ + LOAD_LINE("17", "9") \ + LOAD_LINE("18", "10") \ + LOAD_LINE("19", "11") + +#define STORE_LINE(reg_index, n) \ + "mov %[x0], %[outptr" n "]\n" \ + "cmp %w[n_remain], #4\n" \ + "blt 102" n "f\n" \ + "str q" reg_index ", [%[x0]]\n" \ + "b 103" n "f\n" \ + "102" n ":\n" \ + "cmp %w[n_remain], #0\n" \ + "beq 103" n "f\n" \ + "st1 {v" reg_index ".s}[0], [%[x0]], #4\n" \ + "cmp %w[n_remain], #1\n" \ + "beq 103" n "f\n" \ + "st1 {v" reg_index ".s}[1], [%[x0]], #4\n" \ + "cmp %w[n_remain], #2\n" \ + "beq 103" n "f\n" \ + "st1 {v" reg_index ".s}[2], [%[x0]], #4\n" \ + "103" n ":\n" + +#define STORE_C \ + STORE_LINE("8", "0") \ + STORE_LINE("9", "1") \ + STORE_LINE("10", "2") \ + STORE_LINE("11", "3") \ + STORE_LINE("12", "4") \ + STORE_LINE("13", "5") \ + STORE_LINE("14", "6") \ + STORE_LINE("15", "7") \ + STORE_LINE("16", "8") \ + STORE_LINE("17", "9") \ + STORE_LINE("18", "10") \ + STORE_LINE("19", "11") + // clang-format on + + asm volatile( + // load accumulator C + "add %[outptr1], %[outptr0], %x[LDC]\n" + "add %[outptr2], %[outptr1], %x[LDC]\n" + "add %[outptr3], %[outptr2], %x[LDC]\n" + "add %[outptr4], %[outptr3], %x[LDC]\n" + "add %[outptr5], %[outptr4], %x[LDC]\n" + "add %[outptr6], %[outptr5], %x[LDC]\n" + "add %[outptr7], %[outptr6], %x[LDC]\n" + "add %[outptr8], %[outptr7], %x[LDC]\n" + "add %[outptr9], %[outptr8], %x[LDC]\n" + "add %[outptr10], %[outptr9], %x[LDC]\n" + "add %[outptr11], %[outptr10], %x[LDC]\n" + "cmp %w[is_first_k], #1\n" + "beq 1f\n" LOAD_C + + "b 2f\n" + + "1:\n" + 
"eor v8.16b, v8.16b, v8.16b\n" + "eor v9.16b, v9.16b, v9.16b\n" + "eor v10.16b, v10.16b, v10.16b\n" + "eor v11.16b, v11.16b, v11.16b\n" + "eor v12.16b, v12.16b, v12.16b\n" + "eor v13.16b, v13.16b, v13.16b\n" + "eor v14.16b, v14.16b, v14.16b\n" + "eor v15.16b, v15.16b, v15.16b\n" + "eor v16.16b, v16.16b, v16.16b\n" + "eor v17.16b, v17.16b, v17.16b\n" + "eor v18.16b, v18.16b, v18.16b\n" + "eor v19.16b, v19.16b, v19.16b\n" + + "2: \n" + "ld1 {v2.4h}, [%[b_ptr]], 8\n" + "ld1 {v0.8h}, [%[a_ptr]], 16\n" + "ld1 {v1.4h}, [%[a_ptr]], 8\n" + + "smlal v8.4s, v2.4h, v0.h[0]\n" + "smlal v9.4s, v2.4h, v0.h[1]\n" + "smlal v10.4s, v2.4h, v0.h[2]\n" + "smlal v11.4s, v2.4h, v0.h[3]\n" + "smlal v12.4s, v2.4h, v0.h[4]\n" + "smlal v13.4s, v2.4h, v0.h[5]\n" + "smlal v14.4s, v2.4h, v0.h[6]\n" + "smlal v15.4s, v2.4h, v0.h[7]\n" + "smlal v16.4s, v2.4h, v1.h[0]\n" + "smlal v17.4s, v2.4h, v1.h[1]\n" + "smlal v18.4s, v2.4h, v1.h[2]\n" + "smlal v19.4s, v2.4h, v1.h[3]\n" + + "subs %w[K], %w[K], #1\n" + "cbnz %w[K], 2b\n" + + "3:\n" STORE_C + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), + [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), + [outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), + [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), + [outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), + [outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7), + [outptr8] "=r"(outptr8), [outptr9] "=r"(outptr9), + [outptr10] "=r"(outptr10), [outptr11] "=r"(outptr11), + [x0] "+r"(x0), [n_remain] "+r"(n_remain) + : + : "v0", "v1", "v2", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "cc", "memory"); + +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +} + +/** + * Overview of register layout: + * + * A 1x4 cell of Rhs is stored in 16bit in q2 + * A 12x1 cell of Lhs is stored in 16bit in q0-q1 + * A 12x4 block of accumulators is stored in 32bit in q7-q30 + * + * +--------+ + * | v2[0-3]| + * Rhs +--------+ + * Lhs | | + * + * +--------+ - - - - +--------- + * |v0[0]| | v8[0-3]| + * |v0[1]| | v9[0-3]| + * |v0[2]| |v10[0-3]| + * |v0[3]| |v11[0-3]| + * |v0[4]| |v12[0-3]| + * |v0[5]| |v13[0-3]| + * |v0[6]| |v14[0-3]| + * |v0[7]| |v15[0-3]| + * +--------+ - - - - +--------- + * + * Accumulator + */ + +static void kern_8x4(const int16_t* packA, const int16_t* packB, int K, + int32_t* output, int LDC, bool is_first_k, + size_t n_remain) { + const int16_t* a_ptr = packA; + const int16_t* b_ptr = packB; + + LDC = LDC * sizeof(int32_t); + int32_t* outptr0 = output; + int32_t* outptr1; + int32_t* outptr2; + int32_t* outptr3; + int32_t* outptr4; + int32_t* outptr5; + int32_t* outptr6; + int32_t* outptr7; + size_t x0 = 0; + +// clang-format off +#define LOAD_LINE(reg_index, n) \ + "mov %[x0], %[outptr" n "]\n" \ + "cmp %w[n_remain], #4\n" \ + "blt 100" n "f\n" \ + "ldr q" reg_index ", [%[x0]] \n" \ + "b 101" n "f\n" \ + "100" n ":\n" \ + "cmp %w[n_remain], #0\n" \ + "beq 101" n "f\n" \ + "ld1 {v" reg_index ".s}[0], [%[x0]], #4\n" \ + "cmp %w[n_remain], #1\n" \ + "beq 101" n "f\n" \ + "ld1 {v" reg_index ".s}[1], [%[x0]], #4\n" \ + "cmp %w[n_remain], #2\n" \ + "beq 101" n "f\n" \ + "ld1 {v" reg_index ".s}[2], [%[x0]], #4\n" \ + "101" n ":\n" + +#define LOAD_C \ + LOAD_LINE("8", "0") \ + LOAD_LINE("9", "1") \ + LOAD_LINE("10", "2") \ + LOAD_LINE("11", "3") \ + LOAD_LINE("12", "4") \ + LOAD_LINE("13", "5") \ + LOAD_LINE("14", "6") \ + LOAD_LINE("15", "7") + +#define STORE_LINE(reg_index, n) \ + "mov %[x0], %[outptr" n "]\n" \ + "cmp %w[n_remain], #4\n" \ + "blt 102" n "f\n" \ + "str q" reg_index 
", [%[x0]]\n" \ + "b 103" n "f\n" \ + "102" n ":\n" \ + "cmp %w[n_remain], #0\n" \ + "beq 103" n "f\n" \ + "st1 {v" reg_index ".s}[0], [%[x0]], #4\n" \ + "cmp %w[n_remain], #1\n" \ + "beq 103" n "f\n" \ + "st1 {v" reg_index ".s}[1], [%[x0]], #4\n" \ + "cmp %w[n_remain], #2\n" \ + "beq 103" n "f\n" \ + "st1 {v" reg_index ".s}[2], [%[x0]], #4\n" \ + "103" n ":\n" + +#define STORE_C \ + STORE_LINE("8", "0") \ + STORE_LINE("9", "1") \ + STORE_LINE("10", "2") \ + STORE_LINE("11", "3") \ + STORE_LINE("12", "4") \ + STORE_LINE("13", "5") \ + STORE_LINE("14", "6") \ + STORE_LINE("15", "7") + // clang-format on + + asm volatile( + // load accumulator C + "add %[outptr1], %[outptr0], %x[LDC]\n" + "add %[outptr2], %[outptr1], %x[LDC]\n" + "add %[outptr3], %[outptr2], %x[LDC]\n" + "add %[outptr4], %[outptr3], %x[LDC]\n" + "add %[outptr5], %[outptr4], %x[LDC]\n" + "add %[outptr6], %[outptr5], %x[LDC]\n" + "add %[outptr7], %[outptr6], %x[LDC]\n" + "cmp %w[is_first_k], #1\n" + "beq 1f\n" LOAD_C + + "b 2f\n" + + "1:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "eor v9.16b, v9.16b, v9.16b\n" + "eor v10.16b, v10.16b, v10.16b\n" + "eor v11.16b, v11.16b, v11.16b\n" + "eor v12.16b, v12.16b, v12.16b\n" + "eor v13.16b, v13.16b, v13.16b\n" + "eor v14.16b, v14.16b, v14.16b\n" + "eor v15.16b, v15.16b, v15.16b\n" + + "2: \n" + "ld1 {v2.4h}, [%[b_ptr]], 8\n" + "ld1 {v0.8h}, [%[a_ptr]], 16\n" + + "smlal v8.4s, v2.4h, v0.h[0]\n" + "smlal v9.4s, v2.4h, v0.h[1]\n" + "smlal v10.4s, v2.4h, v0.h[2]\n" + "smlal v11.4s, v2.4h, v0.h[3]\n" + "smlal v12.4s, v2.4h, v0.h[4]\n" + "smlal v13.4s, v2.4h, v0.h[5]\n" + "smlal v14.4s, v2.4h, v0.h[6]\n" + "smlal v15.4s, v2.4h, v0.h[7]\n" + + "subs %w[K], %w[K], #1\n" + "cbnz %w[K], 2b\n" + + "3:\n" STORE_C + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), + [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), + [outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), + [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), + [outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), + [outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7), [x0] "+r"(x0), + [n_remain] "+r"(n_remain) + : + : "v0", "v2", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "cc", "memory"); + +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +} + +/** + * Overview of register layout: + * + * A 1x4 cell of Rhs is stored in 16bit in q2 + * A 12x1 cell of Lhs is stored in 16bit in q0-q1 + * A 12x4 block of accumulators is stored in 32bit in q7-q30 + * + * +--------+ + * | v2[0-3]| + * Rhs +--------+ + * Lhs | | + * + * +--------+ - - - - +--------- + * |v0[0]| | v8[0-3]| + * |v0[1]| | v9[0-3]| + * |v0[2]| |v10[0-3]| + * |v0[3]| |v11[0-3]| + * +--------+ - - - - +--------- + * + * Accumulator + */ + +static void kern_4x4(const int16_t* packA, const int16_t* packB, int K, + int32_t* output, int LDC, bool is_first_k, size_t m_remain, + size_t n_remain) { + const int16_t* a_ptr = packA; + const int16_t* b_ptr = packB; + + LDC = LDC * sizeof(int32_t); + int32_t* outptr0 = output; + int32_t* outptr1; + int32_t* outptr2; + int32_t* outptr3; + size_t x0 = 0; + size_t x1 = 0; + +// clang-format off +#define LOAD_LINE(reg_index, n) \ + "cbz %[x1], 102f\n" \ + "mov %[x0], %[outptr" n "]\n" \ + "cmp %w[n_remain], #4\n" \ + "blt 100" n "f\n" \ + "ldr q" reg_index ", [%[x0]]\n" \ + "b 101" n "f\n" \ + "100" n ":\n" \ + "cmp %w[n_remain], #0\n" \ + "beq 101" n "f\n" \ + "ld1 {v" reg_index ".s}[0], [%[x0]], #4\n" \ + "cmp %w[n_remain], #1\n" \ + "beq 101" n "f\n" \ + "ld1 {v" reg_index ".s}[1], [%[x0]], #4\n" \ + "cmp %w[n_remain], 
#2\n" \ + "beq 101" n "f\n" \ + "ld1 {v" reg_index ".s}[2], [%[x0]], #4\n" \ + "101" n ":\n" \ + "subs %[x1], %[x1], #1\n" + +#define LOAD_C \ + "mov %[x1], %x[m_remain]\n" \ + LOAD_LINE("8", "0") \ + LOAD_LINE("9", "1") \ + LOAD_LINE("10", "2") \ + LOAD_LINE("11", "3") \ + "102:\n" + +#define STORE_LINE(reg_index, n) \ + "cbz %[x1], 105f\n" \ + "mov %[x0], %[outptr" n "]\n" \ + "cmp %w[n_remain], #4\n" \ + "blt 103" n "f\n" \ + "str q" reg_index ", [%[x0]]\n" \ + "b 104" n "f\n" \ + "103" n ":\n" \ + "cmp %w[n_remain], #0\n" \ + "beq 104" n "f\n" \ + "st1 {v" reg_index ".s}[0], [%[x0]], #4\n" \ + "cmp %w[n_remain], #1\n" \ + "beq 104" n "f\n" \ + "st1 {v" reg_index ".s}[1], [%[x0]], #4\n" \ + "cmp %w[n_remain], #2\n" \ + "beq 104" n "f\n" \ + "st1 {v" reg_index ".s}[2], [%[x0]], #4\n" \ + "104" n ":\n" \ + "subs %[x1], %[x1], #1\n" + +#define STORE_C \ + "mov %[x1], %x[m_remain]\n" \ + STORE_LINE("8", "0") \ + STORE_LINE("9", "1") \ + STORE_LINE("10", "2") \ + STORE_LINE("11", "3") \ + "105:\n" + // clang-format on + + asm volatile( + // load accumulator C + "add %[outptr1], %[outptr0], %x[LDC]\n" + "add %[outptr2], %[outptr1], %x[LDC]\n" + "add %[outptr3], %[outptr2], %x[LDC]\n" + "cmp %w[is_first_k], #1\n" + "beq 1f\n" LOAD_C + + "b 2f\n" + + "1:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "eor v9.16b, v9.16b, v9.16b\n" + "eor v10.16b, v10.16b, v10.16b\n" + "eor v11.16b, v11.16b, v11.16b\n" + + "2: \n" + "ld1 {v2.4h}, [%[b_ptr]], 8\n" + "ld1 {v0.4h}, [%[a_ptr]], 8\n" + + "smlal v8.4s, v2.4h, v0.h[0]\n" + "smlal v9.4s, v2.4h, v0.h[1]\n" + "smlal v10.4s, v2.4h, v0.h[2]\n" + "smlal v11.4s, v2.4h, v0.h[3]\n" + + "subs %w[K], %w[K], #1\n" + "cbnz %w[K], 2b\n" + + "3:\n" STORE_C + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), + [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), + [outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), + [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), [x0] "+r"(x0), + [m_remain] "+r"(m_remain), [x1] "+r"(x1), + [n_remain] "+r"(n_remain) + : + : "v0", "v2", "v8", "v9", "v10", "v11", "cc", "memory"); + +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +} + +static void gemm_s16_12x8x1_pack_A_n(int16_t* outptr, const int16_t* inptr, + int ldin, int y0, int ymax, int k0, + int kmax) { + int16_t zerobuff[4]; + std::memset(zerobuff, 0, sizeof(int16_t) * 4); + + int y = y0; + for (; y + 11 < ymax; y += 12) { + const int16_t* inptr0 = inptr + y * ldin + k0; + const int16_t* inptr1 = inptr0 + ldin; + const int16_t* inptr2 = inptr1 + ldin; + const int16_t* inptr3 = inptr2 + ldin; + const int16_t* inptr4 = inptr3 + ldin; + const int16_t* inptr5 = inptr4 + ldin; + const int16_t* inptr6 = inptr5 + ldin; + const int16_t* inptr7 = inptr6 + ldin; + const int16_t* inptr8 = inptr7 + ldin; + const int16_t* inptr9 = inptr8 + ldin; + const int16_t* inptr10 = inptr9 + ldin; + const int16_t* inptr11 = inptr10 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + prefetch_2x(inptr4); + prefetch_2x(inptr5); + prefetch_2x(inptr6); + prefetch_2x(inptr7); + prefetch_2x(inptr8); + prefetch_2x(inptr9); + prefetch_2x(inptr10); + prefetch_2x(inptr11); + + int K = kmax - k0; + for (; K > 3; K -= 4) { + interleave_12x1_4_h(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, + inptr6, inptr7, inptr8, inptr9, inptr10, + inptr11, outptr); + } + + if (K > 0) { + interleave_12(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, + inptr6, inptr7, inptr8, inptr9, inptr10, inptr11, + outptr, 1, K); + } + } + + for (; y + 7 < ymax; y += 8) { + 
const int16_t* inptr0 = inptr + y * ldin + k0; + const int16_t* inptr1 = inptr0 + ldin; + const int16_t* inptr2 = inptr1 + ldin; + const int16_t* inptr3 = inptr2 + ldin; + const int16_t* inptr4 = inptr3 + ldin; + const int16_t* inptr5 = inptr4 + ldin; + const int16_t* inptr6 = inptr5 + ldin; + const int16_t* inptr7 = inptr6 + ldin; + + int K = kmax - k0; + for (; K > 7; K -= 8) { + interleave_8x1_8_h(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, + inptr6, inptr7, outptr); + } + + if (K > 0) { + interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, + inptr7, outptr, 1, K); + } + } + + for (; y < ymax; y += 4) { + const int16_t* inptr0 = inptr + y * ldin + k0; + const int16_t* inptr1 = inptr0 + ldin; + const int16_t* inptr2 = inptr1 + ldin; + const int16_t* inptr3 = inptr2 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + + int K = kmax - k0; + for (; K > 3; K -= 4) { + if (y + 3 >= ymax) { + switch (y + 3 - ymax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + interleave_4x1_4_h(inptr0, inptr1, inptr2, inptr3, outptr); + } + + if (K > 0) { + if (y + 3 >= ymax) { + switch (y + 3 - ymax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 1, K); + } + } +} + +static void gemm_s16_12x8x1_transpose_pack_A_n(int16_t* out, const int16_t* in, + int ldin, int x0, int xmax, + int k0, int kmax) { + const int ksize = kmax - k0; + const int ksize4 = ksize * 4; + const int ksize8 = ksize4 * 2; + const int ksize12 = ksize * 12; + int16_t* outptr = out; + int16_t* outptr_base = out; + //! 1x8 block output start pos + int16_t* outptr_base8 = out + ((xmax - x0) / 12) * ksize12; + //! 1x4 block output start pos + int16_t* outptr_base4 = + outptr_base8 + (xmax - (x0 + (xmax - x0) / 12 * 12)) / 8 * ksize8; + + int k = k0; + for (; k < kmax; k++) { + const int16_t* inptr = in + k * ldin + x0; + prefetch_2x(inptr); + int x = x0; + outptr = outptr_base; + for (; x + 11 < xmax; x += 12) { + transpose_12x1_1_h(inptr, outptr); + outptr += ksize12; + } + outptr = outptr_base8; + for (; x + 7 < xmax; x += 8) { + transpose_8x1_1_h(inptr, outptr); + outptr += ksize8; + } + outptr = outptr_base4; + for (; x + 3 < xmax; x += 4) { + transpose_4x1_1_h(inptr, outptr); + outptr += ksize4; + } + int X = (4 - (xmax - x)) % 4; + for (; x < xmax; x++) { + *outptr++ = *inptr++; + } + memset(outptr, 0, sizeof(int16_t) * X); + outptr += ksize4; + outptr_base += 12; + outptr_base8 += 8; + outptr_base4 += 4; + } +} + +static void gemm_s16_12x8x1_pack_B_n(int16_t* out, const int16_t* in, int ldin, + int x0, int xmax, int k0, int kmax) { + const int ksize = kmax - k0; + const int ksize4 = ksize * 4; + const int ksize8 = ksize4 * 2; + int16_t* outptr = out; + int16_t* outptr_base = out; + //! 
1x4 block output start pos + int16_t* outptr_base4 = out + ((xmax - x0) / 8) * ksize8; + + int k = k0; + for (; k < kmax; k++) { + const int16_t* inptr = in + k * ldin + x0; + prefetch_2x(inptr); + int x = x0; + outptr = outptr_base; + for (; x + 7 < xmax; x += 8) { + transpose_8x1_1_h(inptr, outptr); + outptr += ksize8; + } + outptr = outptr_base4; + for (; x + 3 < xmax; x += 4) { + transpose_4x1_1_h(inptr, outptr); + outptr += ksize4; + } + int X = (4 - (xmax - x)) % 4; + for (; x < xmax; x++) { + *outptr++ = *inptr++; + } + memset(outptr, 0, sizeof(int16_t) * X); + outptr += ksize4; + outptr_base += 8; + outptr_base4 += 4; + } +} + +static void gemm_s16_12x8x1_transpose_pack_B_n(int16_t* outptr, + const int16_t* inptr, int ldin, + int y0, int ymax, int k0, + int kmax) { + int16_t zerobuff[4]; + std::memset(zerobuff, 0, sizeof(int16_t) * 4); + + int y = y0; + for (; y + 7 < ymax; y += 8) { + const int16_t* inptr0 = inptr + y * ldin + k0; + const int16_t* inptr1 = inptr0 + ldin; + const int16_t* inptr2 = inptr1 + ldin; + const int16_t* inptr3 = inptr2 + ldin; + const int16_t* inptr4 = inptr3 + ldin; + const int16_t* inptr5 = inptr4 + ldin; + const int16_t* inptr6 = inptr5 + ldin; + const int16_t* inptr7 = inptr6 + ldin; + + int K = kmax - k0; + for (; K > 7; K -= 8) { + interleave_8x1_8_h(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, + inptr6, inptr7, outptr); + } + + if (K > 0) { + interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, + inptr7, outptr, 1, K); + } + } + + for (; y < ymax; y += 4) { + const int16_t* inptr0 = inptr + y * ldin + k0; + const int16_t* inptr1 = inptr0 + ldin; + const int16_t* inptr2 = inptr1 + ldin; + const int16_t* inptr3 = inptr2 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + + int K = kmax - k0; + for (; K > 3; K -= 4) { + if (y + 3 >= ymax) { + switch (y + 3 - ymax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + interleave_4x1_4_h(inptr0, inptr1, inptr2, inptr3, outptr); + } + + if (K > 0) { + if (y + 3 >= ymax) { + switch (y + 3 - ymax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 1, K); + } + } +} + +} // namespace matmul_12x8x1 +} // namespace aarch64 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/int16/strategy.cpp b/dnn/src/aarch64/matrix_mul/int16/strategy.cpp new file mode 100644 index 00000000..2b798cf0 --- /dev/null +++ b/dnn/src/aarch64/matrix_mul/int16/strategy.cpp @@ -0,0 +1,132 @@ +/** + * \file dnn/src/aarch64/matrix_mul/int16/strategy.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#include "src/aarch64/matrix_mul/int16/strategy.h" +#include "src/aarch64/matrix_mul/asm/common.h" +#include "src/aarch64/matrix_mul/int16/kernel_12x8x1.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/utils.h" + +using namespace megdnn; +using namespace aarch64; +using namespace aarch64::matmul; + +///////////////////////// gemm_s16_12x8x1 //////////////////////////////////// +MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s16_12x8x1); + +void gemm_s16_12x8x1::pack_A(dt_int16* outptr, const dt_int16* inptr, int ldin, + int y0, int ymax, int k0, int kmax, + bool transpose) const { + if (transpose) { + matmul_12x8x1::gemm_s16_12x8x1_transpose_pack_A_n(outptr, inptr, ldin, + y0, ymax, k0, kmax); + } else { + matmul_12x8x1::gemm_s16_12x8x1_pack_A_n(outptr, inptr, ldin, y0, ymax, + k0, kmax); + } +} + +void gemm_s16_12x8x1::pack_B(dt_int16* out, const dt_int16* in, int ldin, + int x0, int xmax, int k0, int kmax, + bool transpose) const { + if (transpose) { + matmul_12x8x1::gemm_s16_12x8x1_transpose_pack_B_n(out, in, ldin, x0, + xmax, k0, kmax); + } else { + matmul_12x8x1::gemm_s16_12x8x1_pack_B_n(out, in, ldin, x0, xmax, k0, + kmax); + } +} + +void gemm_s16_12x8x1::kern(const dt_int16* packA, const dt_int16* packB, + size_t M, size_t N, size_t K, dt_int32* C, + size_t LDC, bool is_first_k, const dt_int32*, + dt_int32*) const { + megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && + (A_dtype.enumv() == DTypeEnum::Int16 && + C_dtype.enumv() == DTypeEnum::Int32), + "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), + C_dtype.name()); + MEGDNN_MARK_USED_VAR(A_dtype); + MEGDNN_MARK_USED_VAR(B_dtype); + MEGDNN_MARK_USED_VAR(C_dtype); + + constexpr size_t A_INTERLEAVE = 12; + constexpr size_t B_INTERLEAVE = 8; + const int K12 = K * 12; + const int K8 = K * 8; + const int K4 = K * 4; + + size_t m = 0; + for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) { + int32_t* output = C + (m * LDC); + + size_t n = 0; + const dt_int16* cur_packB = packB; + for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { + matmul_12x8x1::kern_12x8(packA, cur_packB, K, output, LDC, + is_first_k); + output += B_INTERLEAVE; + cur_packB += K8; + } + + for (; n < N; n += 4) { + matmul_12x8x1::kern_12x4(packA, cur_packB, K, output, LDC, + is_first_k, std::min(N - n, 4)); + output += 4; + cur_packB += K4; + } + packA += K12; + } + + for (; m + 7 < M; m += 8) { + int32_t* output = C + (m * LDC); + const dt_int16* cur_packB = packB; + size_t n = 0; + for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { + matmul_12x8x1::kern_8x8(packA, cur_packB, K, output, LDC, + is_first_k); + output += B_INTERLEAVE; + cur_packB += K8; + } + + for (; n < N; n += 4) { + matmul_12x8x1::kern_8x4(packA, cur_packB, K, output, LDC, + is_first_k, std::min(N - n, 4)); + output += 4; + cur_packB += K4; + } + packA += K8; + } + + for (; m < M; m += 4) { + int32_t* output = C + (m * LDC); + const dt_int16* cur_packB = packB; + size_t n = 0; + for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { + matmul_12x8x1::kern_4x8(packA, cur_packB, K, output, LDC, + is_first_k, std::min(M - m, 4)); + output += B_INTERLEAVE; + cur_packB += K8; + } + + for (; n < N; n += 4) { + matmul_12x8x1::kern_4x4(packA, cur_packB, K, output, LDC, + is_first_k, std::min(M - m, 4), + std::min(N - n, 4)); + output += 4; + cur_packB += K4; + } + packA += K4; + } +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/int16/strategy.h b/dnn/src/aarch64/matrix_mul/int16/strategy.h new file mode 100644 index 00000000..c67c25b8 --- /dev/null +++ 
b/dnn/src/aarch64/matrix_mul/int16/strategy.h @@ -0,0 +1,29 @@ +/** + * \file dnn/src/aarch64/matrix_mul/int16/strategy.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once + +#include "src/fallback/matrix_mul/gemm_common.h" + +namespace megdnn { +namespace aarch64 { +namespace matmul { + +MEGDNN_REG_GEMM_STRATEGY(dt_int16, dt_int32, dt_int32, 12, 8, 1, false, true, + gemm_s16_12x8x1); + +MEGDNN_REG_GEMM_STRATEGY_NOPACK(dt_int16, dt_int32, dt_int32, 8, 8, 1, false, + true, gemm_nopack_s16_8x8); + +} // namespace matmul +} // namespace aarch64 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/int16/strategy_mk8_8x8.cpp b/dnn/src/aarch64/matrix_mul/int16/strategy_mk8_8x8.cpp new file mode 100644 index 00000000..7a4d7d90 --- /dev/null +++ b/dnn/src/aarch64/matrix_mul/int16/strategy_mk8_8x8.cpp @@ -0,0 +1,658 @@ +/** + * \file dnn/src/aarch64/matrix_mul/int16/strategy_mk8_8x8.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "src/aarch64/matrix_mul/int16/strategy.h" +#include "src/aarch64/matrix_mul/asm/common.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/utils.h" + +using namespace megdnn; +using namespace aarch64; +using namespace aarch64::matmul; + +namespace { + +// Overview of register layout: +// +// A 8x1 cell of Lhs is stored in 16bit in v24-v27 +// A 8x1 cell of Rhs is stored in 16bit in v0-v7 +// A 8x2 block of accumulators is stored in 32bit in v16-v23. +// +// Lhs +--------+ +// |v0[0-7]| +// |v1[0-7]| +// |v2[0-7]| +// |v3[0-7]| +// +--------+ +// Rhs +// +---------+ - - - - -+--------+ +// | v24[0-7]| |v16[0-3]| +// | v25[0-7]| |v17[0-3]| +// | v26[0-7]| |v18[0-3]| +// | v27[0-7]| |v19[0-3]| +// | v28[0-7]| |v20[0-3]| +// | v29[0-7]| |v21[0-3]| +// | v30[0-7]| |v22[0-3]| +// | v31[0-7]| |v23[0-3]| +// +---------+ +--------+ +// Accumulator +void kern_8x4(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, + dt_int32* output) { + //! As each load 32 number from B, but the pos add 24 * 2, so we minus 24 + //! here. 
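+ //! To spell the arithmetic out: LDB comes in as an element stride and each + //! iteration reads 4 columns x 8 int16 = 32 elements of B, of which the + //! first three ld1 post-increments (16 bytes each) already advance b_ptr by + //! 24 elements; using LDB - 24 elements (converted to bytes below) for the + //! final post-index leaves b_ptr exactly LDB elements past the start of the + //! iteration, i.e. at the next K block of B.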
+ LDB = (LDB - 24) * sizeof(dt_int16); + + asm volatile( + "subs %w[K], %w[K], #8\n" + + "ld1 {v24.4s}, [%[a_ptr]], 16\n" + "ld1 {v0.4s}, [%[b_ptr]], 16\n" + "ld1 {v1.4s}, [%[b_ptr]], 16\n" + "ld1 {v2.4s}, [%[b_ptr]], 16\n" + "ld1 {v3.4s}, [%[b_ptr]], %x[LDB]\n" + + "smull v16.4s, v24.4h, v0.h[0]\n" + "smull2 v17.4s, v24.8h, v0.h[0]\n" + "smull v18.4s, v24.4h, v1.h[0]\n" + "smull2 v19.4s, v24.8h, v1.h[0]\n" + + "ld1 {v25.4s}, [%[a_ptr]], 16\n" + + "smull v20.4s, v24.4h, v2.h[0]\n" + "smull2 v21.4s, v24.8h, v2.h[0]\n" + "smull v22.4s, v24.4h, v3.h[0]\n" + "smull2 v23.4s, v24.8h, v3.h[0]\n" + + "smlal v16.4s, v25.4h, v0.h[1]\n" + "smlal2 v17.4s, v25.8h, v0.h[1]\n" + "smlal v18.4s, v25.4h, v1.h[1]\n" + "smlal2 v19.4s, v25.8h, v1.h[1]\n" + + "ld1 {v26.4s}, [%[a_ptr]], 16\n" + + "smlal v20.4s, v25.4h, v2.h[1]\n" + "smlal2 v21.4s, v25.8h, v2.h[1]\n" + "smlal v22.4s, v25.4h, v3.h[1]\n" + "smlal2 v23.4s, v25.8h, v3.h[1]\n" + + "smlal v16.4s, v26.4h, v0.h[2]\n" + "smlal2 v17.4s, v26.8h, v0.h[2]\n" + "smlal v18.4s, v26.4h, v1.h[2]\n" + "smlal2 v19.4s, v26.8h, v1.h[2]\n" + + "ld1 {v27.4s}, [%[a_ptr]], 16\n" + + "smlal v20.4s, v26.4h, v2.h[2]\n" + "smlal2 v21.4s, v26.8h, v2.h[2]\n" + "smlal v22.4s, v26.4h, v3.h[2]\n" + "smlal2 v23.4s, v26.8h, v3.h[2]\n" + + "smlal v16.4s, v27.4h, v0.h[3]\n" + "smlal2 v17.4s, v27.8h, v0.h[3]\n" + "smlal v18.4s, v27.4h, v1.h[3]\n" + "smlal2 v19.4s, v27.8h, v1.h[3]\n" + + "ld1 {v28.4s}, [%[a_ptr]], 16\n" + + "smlal v20.4s, v27.4h, v2.h[3]\n" + "smlal2 v21.4s, v27.8h, v2.h[3]\n" + "smlal v22.4s, v27.4h, v3.h[3]\n" + "smlal2 v23.4s, v27.8h, v3.h[3]\n" + + "smlal v16.4s, v28.4h, v0.h[4]\n" + "smlal2 v17.4s, v28.8h, v0.h[4]\n" + "smlal v18.4s, v28.4h, v1.h[4]\n" + "smlal2 v19.4s, v28.8h, v1.h[4]\n" + + "ld1 {v29.4s}, [%[a_ptr]], 16\n" + + "smlal v20.4s, v28.4h, v2.h[4]\n" + "smlal2 v21.4s, v28.8h, v2.h[4]\n" + "smlal v22.4s, v28.4h, v3.h[4]\n" + "smlal2 v23.4s, v28.8h, v3.h[4]\n" + + "smlal v16.4s, v29.4h, v0.h[5]\n" + "smlal2 v17.4s, v29.8h, v0.h[5]\n" + "smlal v18.4s, v29.4h, v1.h[5]\n" + "smlal2 v19.4s, v29.8h, v1.h[5]\n" + + "ld1 {v30.4s}, [%[a_ptr]], 16\n" + + "smlal v20.4s, v29.4h, v2.h[5]\n" + "smlal2 v21.4s, v29.8h, v2.h[5]\n" + "smlal v22.4s, v29.4h, v3.h[5]\n" + "smlal2 v23.4s, v29.8h, v3.h[5]\n" + + "smlal v16.4s, v30.4h, v0.h[6]\n" + "smlal2 v17.4s, v30.8h, v0.h[6]\n" + "smlal v18.4s, v30.4h, v1.h[6]\n" + "smlal2 v19.4s, v30.8h, v1.h[6]\n" + + "ld1 {v31.4s}, [%[a_ptr]], 16\n" + + "smlal v20.4s, v30.4h, v2.h[6]\n" + "smlal2 v21.4s, v30.8h, v2.h[6]\n" + "smlal v22.4s, v30.4h, v3.h[6]\n" + "smlal2 v23.4s, v30.8h, v3.h[6]\n" + + "beq 2f\n" + + "1:\n" + + "ld1 {v24.4s}, [%[a_ptr]], 16\n" + + "smlal v16.4s, v31.4h, v0.h[7]\n" + "smlal2 v17.4s, v31.8h, v0.h[7]\n" + + "ld1 {v0.4s}, [%[b_ptr]], 16\n" + + "smlal v18.4s, v31.4h, v1.h[7]\n" + "smlal2 v19.4s, v31.8h, v1.h[7]\n" + + "ld1 {v1.4s}, [%[b_ptr]], 16\n" + + "smlal v20.4s, v31.4h, v2.h[7]\n" + "smlal2 v21.4s, v31.8h, v2.h[7]\n" + + "ld1 {v2.4s}, [%[b_ptr]], 16\n" + + "smlal v22.4s, v31.4h, v3.h[7]\n" + "smlal2 v23.4s, v31.8h, v3.h[7]\n" + + "ld1 {v3.4s}, [%[b_ptr]], %x[LDB]\n" + + "smlal v16.4s, v24.4h, v0.h[0]\n" + "smlal2 v17.4s, v24.8h, v0.h[0]\n" + "smlal v18.4s, v24.4h, v1.h[0]\n" + "smlal2 v19.4s, v24.8h, v1.h[0]\n" + + "ld1 {v25.4s}, [%[a_ptr]], 16\n" + + "smlal v20.4s, v24.4h, v2.h[0]\n" + "smlal2 v21.4s, v24.8h, v2.h[0]\n" + "smlal v22.4s, v24.4h, v3.h[0]\n" + "smlal2 v23.4s, v24.8h, v3.h[0]\n" + + "smlal v16.4s, v25.4h, v0.h[1]\n" + "smlal2 v17.4s, v25.8h, v0.h[1]\n" + "smlal v18.4s, v25.4h, v1.h[1]\n" + 
"smlal2 v19.4s, v25.8h, v1.h[1]\n" + + "ld1 {v26.4s}, [%[a_ptr]], 16\n" + + "smlal v20.4s, v25.4h, v2.h[1]\n" + "smlal2 v21.4s, v25.8h, v2.h[1]\n" + "smlal v22.4s, v25.4h, v3.h[1]\n" + "smlal2 v23.4s, v25.8h, v3.h[1]\n" + + "smlal v16.4s, v26.4h, v0.h[2]\n" + "smlal2 v17.4s, v26.8h, v0.h[2]\n" + "smlal v18.4s, v26.4h, v1.h[2]\n" + "smlal2 v19.4s, v26.8h, v1.h[2]\n" + + "ld1 {v27.4s}, [%[a_ptr]], 16\n" + + "smlal v20.4s, v26.4h, v2.h[2]\n" + "smlal2 v21.4s, v26.8h, v2.h[2]\n" + "smlal v22.4s, v26.4h, v3.h[2]\n" + "smlal2 v23.4s, v26.8h, v3.h[2]\n" + + "smlal v16.4s, v27.4h, v0.h[3]\n" + "smlal2 v17.4s, v27.8h, v0.h[3]\n" + "smlal v18.4s, v27.4h, v1.h[3]\n" + "smlal2 v19.4s, v27.8h, v1.h[3]\n" + + "ld1 {v28.4s}, [%[a_ptr]], 16\n" + + "smlal v20.4s, v27.4h, v2.h[3]\n" + "smlal2 v21.4s, v27.8h, v2.h[3]\n" + "smlal v22.4s, v27.4h, v3.h[3]\n" + "smlal2 v23.4s, v27.8h, v3.h[3]\n" + + "smlal v16.4s, v28.4h, v0.h[4]\n" + "smlal2 v17.4s, v28.8h, v0.h[4]\n" + "smlal v18.4s, v28.4h, v1.h[4]\n" + "smlal2 v19.4s, v28.8h, v1.h[4]\n" + + "ld1 {v29.4s}, [%[a_ptr]], 16\n" + + "smlal v20.4s, v28.4h, v2.h[4]\n" + "smlal2 v21.4s, v28.8h, v2.h[4]\n" + "smlal v22.4s, v28.4h, v3.h[4]\n" + "smlal2 v23.4s, v28.8h, v3.h[4]\n" + + "smlal v16.4s, v29.4h, v0.h[5]\n" + "smlal2 v17.4s, v29.8h, v0.h[5]\n" + "smlal v18.4s, v29.4h, v1.h[5]\n" + "smlal2 v19.4s, v29.8h, v1.h[5]\n" + + "ld1 {v30.4s}, [%[a_ptr]], 16\n" + + "smlal v20.4s, v29.4h, v2.h[5]\n" + "smlal2 v21.4s, v29.8h, v2.h[5]\n" + "smlal v22.4s, v29.4h, v3.h[5]\n" + "smlal2 v23.4s, v29.8h, v3.h[5]\n" + + "smlal v16.4s, v30.4h, v0.h[6]\n" + "smlal2 v17.4s, v30.8h, v0.h[6]\n" + "smlal v18.4s, v30.4h, v1.h[6]\n" + "smlal2 v19.4s, v30.8h, v1.h[6]\n" + + "ld1 {v31.4s}, [%[a_ptr]], 16\n" + + "smlal v20.4s, v30.4h, v2.h[6]\n" + "smlal2 v21.4s, v30.8h, v2.h[6]\n" + "smlal v22.4s, v30.4h, v3.h[6]\n" + "smlal2 v23.4s, v30.8h, v3.h[6]\n" + + "subs %w[K], %w[K], #8\n" + "bne 1b\n" + + "2:\n" + + "smlal v16.4s, v31.4h, v0.h[7]\n" + "smlal2 v17.4s, v31.8h, v0.h[7]\n" + "smlal v18.4s, v31.4h, v1.h[7]\n" + "smlal2 v19.4s, v31.8h, v1.h[7]\n" + "smlal v20.4s, v31.4h, v2.h[7]\n" + "smlal2 v21.4s, v31.8h, v2.h[7]\n" + "smlal v22.4s, v31.4h, v3.h[7]\n" + "smlal2 v23.4s, v31.8h, v3.h[7]\n" + + "st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[output]], 64\n" + "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%[output]], 64\n" + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [output] "+r"(output), [LDB] "+r"(LDB) + : + : "v0", "v1", "v2", "v3", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", + "v29", "v30", "v31", "cc", "memory"); +} + +// Overview of register layout: +// +// A 8x1 cell of Rhs is stored in 16bit in v8-v15 +// A 8x1 cell of Lhs is stored in 16bit in v0-v7 +// A 8x2 block of accumulators is stored in 32bit in v16-v31. +// +// Rhs +--------+ +// | v8[0-7]| +// | v9[0-7]| +// |v10[0-7]| +// |v11[0-7]| +// |v12[0-7]| +// |v13[0-7]| +// |v14[0-7]| +// |v15[0-7]| +// +--------+ +// Lhs +// +--------+ - - - - -+--------+--------+ +// | v0[0-7]| |v16[0-3]|v17[0-3]| +// | v1[0-7]| |v18[0-3]|v19[0-3]| +// | v2[0-7]| |v20[0-3]|v21[0-3]| +// | v3[0-7]| |v22[0-3]|v23[0-3]| +// | v4[0-7]| |v24[0-3]|v25[0-3]| +// | v5[0-7]| |v26[0-3]|v27[0-3]| +// | v6[0-7]| |v28[0-3]|v29[0-3]| +// | v7[0-7]| |v30[0-3]|v31[0-3]| +// +--------+ +--------+--------+ +// Accumulator +void kern_8x8(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, + dt_int32* output) { + //! As each load 64 number from B, but the pos add 48 * 2, so we minus 48 + //! 
here. + LDB = (LDB - 48) * sizeof(dt_int16); + + asm volatile( + "subs %w[K], %w[K], #8\n" + "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[a_ptr]], 64\n" + "ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%[a_ptr]], 64\n" + + "ld1 {v0.4s}, [%[b_ptr]], 16\n" + "smull v16.4s, v8.4h, v0.h[0]\n" + "ld1 {v1.4s}, [%[b_ptr]], 16\n" + "smlal v16.4s, v9.4h, v0.h[1]\n" + "smull v18.4s, v8.4h, v1.h[0]\n" + "smull2 v17.4s, v8.8h, v0.h[0]\n" + "smull2 v19.4s, v8.8h, v1.h[0]\n" + "smlal v16.4s, v10.4h, v0.h[2]\n" + "smlal v18.4s, v9.4h, v1.h[1]\n" + "smlal2 v17.4s, v9.8h, v0.h[1]\n" + "smlal2 v19.4s, v9.8h, v1.h[1]\n" + "smlal v16.4s, v11.4h, v0.h[3]\n" + "smlal v18.4s, v10.4h, v1.h[2]\n" + "smlal2 v17.4s, v10.8h, v0.h[2]\n" + "smlal2 v19.4s, v10.8h, v1.h[2]\n" + "smlal v16.4s, v12.4h, v0.h[4]\n" + "smlal v18.4s, v11.4h, v1.h[3]\n" + "smlal2 v17.4s, v11.8h, v0.h[3]\n" + "smlal2 v19.4s, v11.8h, v1.h[3]\n" + "smlal v16.4s, v13.4h, v0.h[5]\n" + "smlal v18.4s, v12.4h, v1.h[4]\n" + "smlal2 v17.4s, v12.8h, v0.h[4]\n" + "smlal2 v19.4s, v12.8h, v1.h[4]\n" + "smlal2 v17.4s, v13.8h, v0.h[5]\n" + + "ld1 {v2.4s, v3.4s}, [%[b_ptr]], 32\n" + "smlal v16.4s, v14.4h, v0.h[6]\n" + "smlal v18.4s, v13.4h, v1.h[5]\n" + "smlal2 v17.4s, v14.8h, v0.h[6]\n" + "smlal2 v19.4s, v13.8h, v1.h[5]\n" + "smull v20.4s, v8.4h, v2.h[0]\n" + "smull v22.4s, v8.4h, v3.h[0]\n" + "smull2 v21.4s, v8.8h, v2.h[0]\n" + "smull2 v23.4s, v8.8h, v3.h[0]\n" + "smlal v16.4s, v15.4h, v0.h[7]\n" + "smlal v18.4s, v14.4h, v1.h[6]\n" + "smlal2 v17.4s, v15.8h, v0.h[7]\n" + "smlal2 v19.4s, v14.8h, v1.h[6]\n" + "smlal v20.4s, v9.4h, v2.h[1]\n" + "smlal v22.4s, v9.4h, v3.h[1]\n" + "smlal2 v21.4s, v9.8h, v2.h[1]\n" + "smlal2 v23.4s, v9.8h, v3.h[1]\n" + "smlal v18.4s, v15.4h, v1.h[7]\n" + "smlal v20.4s, v10.4h, v2.h[2]\n" + "smlal v22.4s, v10.4h, v3.h[2]\n" + "smlal2 v21.4s, v10.8h, v2.h[2]\n" + "smlal2 v23.4s, v10.8h, v3.h[2]\n" + "smlal2 v19.4s, v15.8h, v1.h[7]\n" + "smlal v20.4s, v11.4h, v2.h[3]\n" + "smlal v22.4s, v11.4h, v3.h[3]\n" + "smlal2 v21.4s, v11.8h, v2.h[3]\n" + "smlal2 v23.4s, v11.8h, v3.h[3]\n" + "smlal v20.4s, v12.4h, v2.h[4]\n" + "smlal v22.4s, v12.4h, v3.h[4]\n" + "smlal2 v21.4s, v12.8h, v2.h[4]\n" + "smlal2 v23.4s, v12.8h, v3.h[4]\n" + "smlal v20.4s, v13.4h, v2.h[5]\n" + "smlal v22.4s, v13.4h, v3.h[5]\n" + "smlal2 v21.4s, v13.8h, v2.h[5]\n" + "smlal2 v23.4s, v13.8h, v3.h[5]\n" + + "ld1 {v4.4s, v5.4s}, [%[b_ptr]], 32\n" + "smlal v20.4s, v14.4h, v2.h[6]\n" + "smlal v22.4s, v14.4h, v3.h[6]\n" + "smlal2 v21.4s, v14.8h, v2.h[6]\n" + "smlal2 v23.4s, v14.8h, v3.h[6]\n" + "smull v24.4s, v8.4h, v4.h[0]\n" + "smull v26.4s, v8.4h, v5.h[0]\n" + "smull2 v25.4s, v8.8h, v4.h[0]\n" + "smull2 v27.4s, v8.8h, v5.h[0]\n" + "smlal v20.4s, v15.4h, v2.h[7]\n" + "smlal v22.4s, v15.4h, v3.h[7]\n" + "smlal2 v21.4s, v15.8h, v2.h[7]\n" + "smlal2 v23.4s, v15.8h, v3.h[7]\n" + "smlal v24.4s, v9.4h, v4.h[1]\n" + "smlal v26.4s, v9.4h, v5.h[1]\n" + "smlal2 v25.4s, v9.8h, v4.h[1]\n" + "smlal2 v27.4s, v9.8h, v5.h[1]\n" + "smlal v24.4s, v10.4h, v4.h[2]\n" + "smlal v26.4s, v10.4h, v5.h[2]\n" + "smlal2 v25.4s, v10.8h, v4.h[2]\n" + "smlal2 v27.4s, v10.8h, v5.h[2]\n" + "smlal v24.4s, v11.4h, v4.h[3]\n" + "smlal v26.4s, v11.4h, v5.h[3]\n" + "smlal2 v25.4s, v11.8h, v4.h[3]\n" + "smlal2 v27.4s, v11.8h, v5.h[3]\n" + "smlal v24.4s, v12.4h, v4.h[4]\n" + "smlal v26.4s, v12.4h, v5.h[4]\n" + "smlal2 v25.4s, v12.8h, v4.h[4]\n" + "smlal2 v27.4s, v12.8h, v5.h[4]\n" + "smlal v24.4s, v13.4h, v4.h[5]\n" + "smlal v26.4s, v13.4h, v5.h[5]\n" + "smlal2 v25.4s, v13.8h, v4.h[5]\n" + "smlal2 v27.4s, v13.8h, v5.h[5]\n" + + 
"ld1 {v6.4s, v7.4s}, [%[b_ptr]], %x[LDB]\n" + "smlal v24.4s, v14.4h, v4.h[6]\n" + "smlal v26.4s, v14.4h, v5.h[6]\n" + "smlal2 v25.4s, v14.8h, v4.h[6]\n" + "smlal2 v27.4s, v14.8h, v5.h[6]\n" + "smull v28.4s, v8.4h, v6.h[0]\n" + "smull v30.4s, v8.4h, v7.h[0]\n" + "smull2 v29.4s, v8.8h, v6.h[0]\n" + "smull2 v31.4s, v8.8h, v7.h[0]\n" + "smlal v28.4s, v9.4h, v6.h[1]\n" + "smlal v30.4s, v9.4h, v7.h[1]\n" + "smlal2 v29.4s, v9.8h, v6.h[1]\n" + "smlal2 v31.4s, v9.8h, v7.h[1]\n" + "smlal v28.4s, v10.4h, v6.h[2]\n" + "smlal v30.4s, v10.4h, v7.h[2]\n" + "smlal2 v29.4s, v10.8h, v6.h[2]\n" + "smlal2 v31.4s, v10.8h, v7.h[2]\n" + "smlal v28.4s, v11.4h, v6.h[3]\n" + "smlal v30.4s, v11.4h, v7.h[3]\n" + "smlal2 v29.4s, v11.8h, v6.h[3]\n" + "smlal2 v31.4s, v11.8h, v7.h[3]\n" + "smlal v28.4s, v12.4h, v6.h[4]\n" + "smlal v30.4s, v12.4h, v7.h[4]\n" + "smlal2 v29.4s, v12.8h, v6.h[4]\n" + "smlal2 v31.4s, v12.8h, v7.h[4]\n" + "smlal v28.4s, v13.4h, v6.h[5]\n" + "smlal v30.4s, v13.4h, v7.h[5]\n" + "smlal2 v29.4s, v13.8h, v6.h[5]\n" + "smlal2 v31.4s, v13.8h, v7.h[5]\n" + + "beq 2f\n" + + "1:\n" + + "smlal v24.4s, v15.4h, v4.h[7]\n" + "smlal v26.4s, v15.4h, v5.h[7]\n" + "smlal2 v25.4s, v15.8h, v4.h[7]\n" + + "ld1 {v8.4s, v9.4s}, [%[a_ptr]], 32\n" + "smlal2 v27.4s, v15.8h, v5.h[7]\n" + "smlal v28.4s, v14.4h, v6.h[6]\n" + "smlal v30.4s, v14.4h, v7.h[6]\n" + + "ld1 {v10.4s, v11.4s}, [%[a_ptr]], 32\n" + "smlal2 v29.4s, v15.8h, v6.h[7]\n" + "smlal2 v31.4s, v14.8h, v7.h[6]\n" + "smlal v28.4s, v15.4h, v6.h[7]\n" + + "ld1 {v12.4s, v13.4s}, [%[a_ptr]], 32\n" + "smlal v30.4s, v15.4h, v7.h[7]\n" + "smlal2 v29.4s, v14.8h, v6.h[6]\n" + + "ld1 {v0.4s}, [%[b_ptr]], 16\n" + "smlal2 v31.4s, v15.8h, v7.h[7]\n" + "smlal v16.4s, v8.4h, v0.h[0]\n" + + "ld1 {v1.4s}, [%[b_ptr]], 16\n" + "smlal v16.4s, v9.4h, v0.h[1]\n" + "smlal2 v17.4s, v8.8h, v0.h[0]\n" + "smlal v16.4s, v10.4h, v0.h[2]\n" + "smlal v18.4s, v8.4h, v1.h[0]\n" + "smlal2 v17.4s, v9.8h, v0.h[1]\n" + "smlal2 v19.4s, v8.8h, v1.h[0]\n" + + "ld1 {v14.4s, v15.4s}, [%[a_ptr]], 32\n" + "smlal v16.4s, v11.4h, v0.h[3]\n" + "smlal v18.4s, v9.4h, v1.h[1]\n" + "smlal2 v17.4s, v10.8h, v0.h[2]\n" + "smlal2 v19.4s, v9.8h, v1.h[1]\n" + "smlal v16.4s, v12.4h, v0.h[4]\n" + "smlal v18.4s, v10.4h, v1.h[2]\n" + "smlal2 v17.4s, v11.8h, v0.h[3]\n" + "smlal2 v19.4s, v10.8h, v1.h[2]\n" + "smlal v16.4s, v13.4h, v0.h[5]\n" + "smlal v18.4s, v11.4h, v1.h[3]\n" + "smlal2 v17.4s, v12.8h, v0.h[4]\n" + "smlal2 v19.4s, v11.8h, v1.h[3]\n" + "smlal v16.4s, v14.4h, v0.h[6]\n" + "smlal v18.4s, v12.4h, v1.h[4]\n" + "smlal2 v17.4s, v13.8h, v0.h[5]\n" + "smlal2 v19.4s, v12.8h, v1.h[4]\n" + "smlal v16.4s, v15.4h, v0.h[7]\n" + "smlal v18.4s, v13.4h, v1.h[5]\n" + "smlal2 v17.4s, v14.8h, v0.h[6]\n" + "smlal2 v19.4s, v13.8h, v1.h[5]\n" + + "ld1 {v2.4s, v3.4s}, [%[b_ptr]], 32\n" + "smlal v18.4s, v14.4h, v1.h[6]\n" + "smlal2 v17.4s, v15.8h, v0.h[7]\n" + "smlal2 v19.4s, v14.8h, v1.h[6]\n" + "smlal v20.4s, v8.4h, v2.h[0]\n" + "smlal v22.4s, v8.4h, v3.h[0]\n" + "smlal2 v21.4s, v8.8h, v2.h[0]\n" + "smlal2 v23.4s, v8.8h, v3.h[0]\n" + "smlal v18.4s, v15.4h, v1.h[7]\n" + "smlal v20.4s, v9.4h, v2.h[1]\n" + "smlal v22.4s, v9.4h, v3.h[1]\n" + "smlal2 v21.4s, v9.8h, v2.h[1]\n" + "smlal2 v23.4s, v9.8h, v3.h[1]\n" + "smlal2 v19.4s, v15.8h, v1.h[7]\n" + "smlal v20.4s, v10.4h, v2.h[2]\n" + "smlal v22.4s, v10.4h, v3.h[2]\n" + "smlal2 v21.4s, v10.8h, v2.h[2]\n" + "smlal2 v23.4s, v10.8h, v3.h[2]\n" + "smlal v20.4s, v11.4h, v2.h[3]\n" + "smlal v22.4s, v11.4h, v3.h[3]\n" + "smlal2 v21.4s, v11.8h, v2.h[3]\n" + "smlal2 v23.4s, v11.8h, v3.h[3]\n" + 
"smlal v20.4s, v12.4h, v2.h[4]\n" + "smlal v22.4s, v12.4h, v3.h[4]\n" + "smlal2 v21.4s, v12.8h, v2.h[4]\n" + "smlal2 v23.4s, v12.8h, v3.h[4]\n" + "smlal v20.4s, v13.4h, v2.h[5]\n" + "smlal v22.4s, v13.4h, v3.h[5]\n" + "smlal2 v21.4s, v13.8h, v2.h[5]\n" + "smlal2 v23.4s, v13.8h, v3.h[5]\n" + + "ld1 {v4.4s, v5.4s}, [%[b_ptr]], 32\n" + "smlal v20.4s, v14.4h, v2.h[6]\n" + "smlal v22.4s, v14.4h, v3.h[6]\n" + "smlal2 v21.4s, v14.8h, v2.h[6]\n" + "smlal2 v23.4s, v14.8h, v3.h[6]\n" + "smlal v24.4s, v8.4h, v4.h[0]\n" + "smlal v26.4s, v8.4h, v5.h[0]\n" + "smlal2 v25.4s, v8.8h, v4.h[0]\n" + "smlal2 v27.4s, v8.8h, v5.h[0]\n" + "smlal v20.4s, v15.4h, v2.h[7]\n" + "smlal2 v21.4s, v15.8h, v2.h[7]\n" + "smlal v22.4s, v15.4h, v3.h[7]\n" + "smlal2 v23.4s, v15.8h, v3.h[7]\n" + "smlal v24.4s, v9.4h, v4.h[1]\n" + "smlal v26.4s, v9.4h, v5.h[1]\n" + "smlal2 v25.4s, v9.8h, v4.h[1]\n" + "smlal2 v27.4s, v9.8h, v5.h[1]\n" + "smlal v24.4s, v10.4h, v4.h[2]\n" + "smlal v26.4s, v10.4h, v5.h[2]\n" + "smlal2 v25.4s, v10.8h, v4.h[2]\n" + "smlal2 v27.4s, v10.8h, v5.h[2]\n" + "smlal v24.4s, v11.4h, v4.h[3]\n" + "smlal v26.4s, v11.4h, v5.h[3]\n" + "smlal2 v25.4s, v11.8h, v4.h[3]\n" + "smlal2 v27.4s, v11.8h, v5.h[3]\n" + "smlal v24.4s, v12.4h, v4.h[4]\n" + "smlal v26.4s, v12.4h, v5.h[4]\n" + "smlal2 v25.4s, v12.8h, v4.h[4]\n" + "smlal2 v27.4s, v12.8h, v5.h[4]\n" + "smlal v24.4s, v13.4h, v4.h[5]\n" + "smlal v26.4s, v13.4h, v5.h[5]\n" + "smlal2 v25.4s, v13.8h, v4.h[5]\n" + "smlal2 v27.4s, v13.8h, v5.h[5]\n" + + "ld1 {v6.4s, v7.4s}, [%[b_ptr]], %x[LDB]\n" + "smlal v24.4s, v14.4h, v4.h[6]\n" + "smlal v26.4s, v14.4h, v5.h[6]\n" + "smlal2 v25.4s, v14.8h, v4.h[6]\n" + "smlal2 v27.4s, v14.8h, v5.h[6]\n" + "smlal v28.4s, v8.4h, v6.h[0]\n" + "smlal v30.4s, v8.4h, v7.h[0]\n" + "smlal2 v29.4s, v8.8h, v6.h[0]\n" + "smlal2 v31.4s, v8.8h, v7.h[0]\n" + "smlal v28.4s, v9.4h, v6.h[1]\n" + "smlal v30.4s, v9.4h, v7.h[1]\n" + "smlal2 v29.4s, v9.8h, v6.h[1]\n" + "smlal2 v31.4s, v9.8h, v7.h[1]\n" + "smlal v28.4s, v10.4h, v6.h[2]\n" + "smlal v30.4s, v10.4h, v7.h[2]\n" + "smlal2 v29.4s, v10.8h, v6.h[2]\n" + "smlal2 v31.4s, v10.8h, v7.h[2]\n" + "smlal v28.4s, v11.4h, v6.h[3]\n" + "smlal v30.4s, v11.4h, v7.h[3]\n" + "smlal2 v29.4s, v11.8h, v6.h[3]\n" + "smlal2 v31.4s, v11.8h, v7.h[3]\n" + "smlal v28.4s, v12.4h, v6.h[4]\n" + "smlal v30.4s, v12.4h, v7.h[4]\n" + "smlal2 v29.4s, v12.8h, v6.h[4]\n" + "smlal2 v31.4s, v12.8h, v7.h[4]\n" + "smlal v28.4s, v13.4h, v6.h[5]\n" + "smlal v30.4s, v13.4h, v7.h[5]\n" + "smlal2 v29.4s, v13.8h, v6.h[5]\n" + "smlal2 v31.4s, v13.8h, v7.h[5]\n" + + "subs %w[K], %w[K], #8\n" + "bne 1b\n" + + "2:\n" + "st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[output]], 64\n" + "smlal v24.4s, v15.4h, v4.h[7]\n" + "smlal v28.4s, v14.4h, v6.h[6]\n" + "smlal v30.4s, v14.4h, v7.h[6]\n" + "smlal v26.4s, v15.4h, v5.h[7]\n" + "smlal2 v25.4s, v15.8h, v4.h[7]\n" + "smlal2 v27.4s, v15.8h, v5.h[7]\n" + "smlal2 v29.4s, v14.8h, v6.h[6]\n" + "smlal2 v31.4s, v14.8h, v7.h[6]\n" + "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%[output]], 64\n" + "smlal v28.4s, v15.4h, v6.h[7]\n" + "smlal v30.4s, v15.4h, v7.h[7]\n" + "smlal2 v29.4s, v15.8h, v6.h[7]\n" + "smlal2 v31.4s, v15.8h, v7.h[7]\n" + "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%[output]], 64\n" + "st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%[output]], 64\n" + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [output] "+r"(output), [LDB] "+r"(LDB) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", 
"v23", "v24", "v25", "v26", "v27", "v28", + "v29", "v30", "v31", "cc", "memory"); +} + +} // anonymous namespace + +MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(gemm_nopack_s16_8x8); + +void gemm_nopack_s16_8x8::kern(const dt_int16* A, size_t LDA, const dt_int16* B, + size_t LDB, dt_int32* C, size_t LDC, size_t M, + size_t K, size_t N, const dt_int32*, void*, + bool trA, bool trB) const { + constexpr static size_t MB = 8; + constexpr static size_t KB = 8; + constexpr static size_t NB = 8; + constexpr static size_t CALCBLK = 4; + + megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0 && N % CALCBLK == 0); + + //! (m/8, k/8, 8, 8) * (k/8, n, 8) = (m/8, n, 8) + for (size_t m = 0; m < M; m += MB) { + dt_int32* output = C + (m / MB) * LDC; + const dt_int16* cur_B = B; + size_t n = 0; + for (; n + NB - 1 < N; n += NB) { + kern_8x8(A, cur_B, LDB, K, output); + cur_B += KB * NB; + output += MB * NB; + } + if (n < N) { + kern_8x4(A, cur_B, LDB, K, output); + } + A += LDA; + } +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/int8/kernel_4x4x16.h b/dnn/src/aarch64/matrix_mul/int8/kernel_4x4x16.h new file mode 100644 index 00000000..d118a9b2 --- /dev/null +++ b/dnn/src/aarch64/matrix_mul/int8/kernel_4x4x16.h @@ -0,0 +1,856 @@ +/** + * \file dnn/src/aarch64/matrix_mul/int8/kernel_4x4x16.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#if !(__ARM_FEATURE_DOTPROD) +#include "src/aarch64/matrix_mul/asm/common.h" +#include "src/arm_common/simd_macro/marm_neon.h" + +namespace megdnn { +namespace aarch64 { +namespace matmul_4x4x16 { + +/** + * Overview of register layout: + * + * A 16x2 cell of Rhs is stored in 8bit in q2-q4. + * A 8x2x2 cell of Lhs is stored in 8bit in q0-q1 + * A 8x16 block of accumulators is stored in 8bit in q8--q31. + * + * \warning Fast kernel operating on int8 operands. + * It is assumed that one of the two int8 operands only takes values + * in [-127, 127], while the other may freely range in [-128, 127]. + * The issue with both operands taking the value -128 is that: + * -128*-128 + -128*-128 == -32768 overflows int16. + * Every other expression a*b + c*d, for any int8 a,b,c,d, fits in int16 + * range. That is the basic idea of this kernel. + * + * + * +--------+--------+---------+---------+ + * |v4[0-16]|v5[0-16]| v6[0-16]| v7[0-16]| + * Rhs +--------+--------+---------+---------+ + * |v8[0-16]|v9[0-16]|v10[0-16]|v11[0-16]| + * +--------+--------+---------+---------+ + * | | | | | + * + * Lhs | | | | | + * + * +--------+ - - - - +-------------------------------------+ + * |v0[0-16]| |v16[0-4]|v17[0-4]| v18[0-4]| v19[0-4]| + * |v1[0-16]| |v20[0-4]|v21[0-4]| v22[0-4]| v23[0-4]| + * |v2[0-16]| |v24[0-4]|v25[0-4]| v26[0-4]| v27[0-4]| + * |v3[0-16]| |v28[0-4]|v29[0-4]| v30[0-4]| v31[0-4]| + * +--------+ - - - - +-------------------------------------+ + * + * Accumulator + */ + +static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, + int32_t* output, int LDC, bool is_first_k) { + K /= 16; + const int8_t* a_ptr = packA; + const int8_t* b_ptr = packB; + // Fix up for odd lengths - set a flag if K is odd, but make + // sure we round up the iteration count. 
+ int oddk = (K & 1); + int k = ((K + 1) / 2) - 1; + + LDC = LDC * sizeof(int32_t); + + asm volatile ( + // load accumulator C + "add x1, %[output], %x[LDC]\n" + "add x2, x1, %x[LDC]\n" + "add x3, x2, %x[LDC]\n" + "cmp %w[is_first_k], #1\n" + "beq 1f\n" + + "ldr q16, [%[output]]\n" + "ldr q17, [x1]\n" + "ldr q18, [x2]\n" + "ldr q19, [x3]\n" + "b 2f\n" + + "1:\n" + "eor v16.16b, v16.16b, v16.16b\n" + "eor v17.16b, v17.16b, v17.16b\n" + "eor v18.16b, v18.16b, v18.16b\n" + "eor v19.16b, v19.16b, v19.16b\n" + + "2: \n" + "ldr q0, [%[a_ptr]]\n" + "ldr q4, [%[b_ptr]]\n" + "ldr q5, [%[b_ptr], #16]\n" + "ldr q6, [%[b_ptr], #32]\n" + "movi v20.4s, #0x0\n" + "ldr q7, [%[b_ptr], #48]\n" + "movi v21.4s, #0x0\n" + "ldr q1, [%[a_ptr], #16]\n" + "movi v22.4s, #0x0\n" + "ldr q2, [%[a_ptr], #32]\n" + "movi v23.4s, #0x0\n" + "ldr q3, [%[a_ptr], #48]\n" + "movi v24.4s, #0x0\n" + ASM_PREFETCH("[%[b_ptr], #64]") + "movi v25.4s, #0x0\n" + ASM_PREFETCH("[%[a_ptr], #64]") + "movi v26.4s, #0x0\n" + ASM_PREFETCH("[%[b_ptr], #128]") + "movi v27.4s, #0x0\n" + ASM_PREFETCH("[%[a_ptr], #128]") + "movi v28.4s, #0x0\n" + ASM_PREFETCH("[%[b_ptr], #192]") + "movi v29.4s, #0x0\n" + ASM_PREFETCH("[%[a_ptr], #192]") + "movi v30.4s, #0x0\n" + ASM_PREFETCH("[%[b_ptr], #256]") + "movi v31.4s, #0x0\n" + ASM_PREFETCH("[%[a_ptr], #256]") + + // Start of unroll 0 (first iteration) + "smull v12.8h, v0.8b, v4.8b\n" + "smull v13.8h, v0.8b, v5.8b\n" + + // Skip loop if we are doing zero iterations of it. + "cbz %w[k], 4f\n" + + // Unroll 0 continuation (branch target) + "3:\n" + "smull v14.8h, v0.8b, v6.8b\n" + "subs %w[k], %w[k], #1\n" + "smull v15.8h, v0.8b, v7.8b\n" + "ldr q8, [%[b_ptr], #64]\n" + "smlal2 v12.8h, v0.16b, v4.16b\n" + "smlal2 v13.8h, v0.16b, v5.16b\n" + "ldr q9, [%[b_ptr], #80]\n" + "smlal2 v14.8h, v0.16b, v6.16b\n" + "smlal2 v15.8h, v0.16b, v7.16b\n" + "ldr q0, [%[a_ptr], #64]\n" + + "sadalp v16.4s, v12.8h\n" + "smull v12.8h, v1.8b, v4.8b\n" + "sadalp v17.4s, v13.8h\n" + "sadalp v18.4s, v14.8h\n" + "smull v13.8h, v1.8b, v5.8b\n" + "sadalp v19.4s, v15.8h\n" + "smull v14.8h, v1.8b, v6.8b\n" + "ldr q10, [%[b_ptr], #96]\n" + "smull v15.8h, v1.8b, v7.8b\n" + "smlal2 v12.8h, v1.16b, v4.16b\n" + "ldr q11, [%[b_ptr], #112]\n" + "smlal2 v13.8h, v1.16b, v5.16b\n" + "add %[b_ptr], %[b_ptr], #128\n" + "smlal2 v14.8h, v1.16b, v6.16b\n" + "smlal2 v15.8h, v1.16b, v7.16b\n" + "ldr q1, [%[a_ptr], #80]\n" + + "sadalp v20.4s, v12.8h\n" + "smull v12.8h, v2.8b, v4.8b\n" + "sadalp v21.4s, v13.8h\n" + "sadalp v22.4s, v14.8h\n" + "smull v13.8h, v2.8b, v5.8b\n" + "sadalp v23.4s, v15.8h\n" + "smull v14.8h, v2.8b, v6.8b\n" + "smull v15.8h, v2.8b, v7.8b\n" + "smlal2 v12.8h, v2.16b, v4.16b\n" + ASM_PREFETCH("[%[b_ptr], #192]") + "smlal2 v13.8h, v2.16b, v5.16b\n" + "smlal2 v14.8h, v2.16b, v6.16b\n" + ASM_PREFETCH("[%[a_ptr], #320]") + "smlal2 v15.8h, v2.16b, v7.16b\n" + "ldr q2, [%[a_ptr], #96]\n" + + "sadalp v24.4s, v12.8h\n" + "smull v12.8h, v3.8b, v4.8b\n" + "sadalp v25.4s, v13.8h\n" + "sadalp v26.4s, v14.8h\n" + "smull v13.8h, v3.8b, v5.8b\n" + "sadalp v27.4s, v15.8h\n" + "smull v14.8h, v3.8b, v6.8b\n" + "smull v15.8h, v3.8b, v7.8b\n" + "smlal2 v12.8h, v3.16b, v4.16b\n" + "ldr q4, [%[b_ptr], #0]\n" + "smlal2 v13.8h, v3.16b, v5.16b\n" + "smlal2 v14.8h, v3.16b, v6.16b\n" + "smlal2 v15.8h, v3.16b, v7.16b\n" + "ldr q3, [%[a_ptr], #112]\n" + + // Unroll 1 + "sadalp v28.4s, v12.8h\n" + "smull v12.8h, v0.8b, v8.8b\n" + "sadalp v29.4s, v13.8h\n" + "sadalp v30.4s, v14.8h\n" + "smull v13.8h, v0.8b, v9.8b\n" + "sadalp v31.4s, v15.8h\n" + "smull v14.8h, 
v0.8b, v10.8b\n" + "smull v15.8h, v0.8b, v11.8b\n" + "ldr q5, [%[b_ptr], #16]\n" + "smlal2 v12.8h, v0.16b, v8.16b\n" + "smlal2 v13.8h, v0.16b, v9.16b\n" + "ldr q6, [%[b_ptr], #32]\n" + "smlal2 v14.8h, v0.16b, v10.16b\n" + "smlal2 v15.8h, v0.16b, v11.16b\n" + "ldr q0, [%[a_ptr], #128]\n" + + "sadalp v16.4s, v12.8h\n" + "smull v12.8h, v1.8b, v8.8b\n" + "sadalp v17.4s, v13.8h\n" + "sadalp v18.4s, v14.8h\n" + "smull v13.8h, v1.8b, v9.8b\n" + "sadalp v19.4s, v15.8h\n" + "add %[a_ptr], %[a_ptr], #128\n" + "smull v14.8h, v1.8b, v10.8b\n" + "smull v15.8h, v1.8b, v11.8b\n" + "ldr q7, [%[b_ptr], #48]\n" + "smlal2 v12.8h, v1.16b, v8.16b\n" + "smlal2 v13.8h, v1.16b, v9.16b\n" + "smlal2 v14.8h, v1.16b, v10.16b\n" + "smlal2 v15.8h, v1.16b, v11.16b\n" + "ldr q1, [%[a_ptr], #16]\n" + + "sadalp v20.4s, v12.8h\n" + "smull v12.8h, v2.8b, v8.8b\n" + "sadalp v21.4s, v13.8h\n" + "sadalp v22.4s, v14.8h\n" + "smull v13.8h, v2.8b, v9.8b\n" + "sadalp v23.4s, v15.8h\n" + "smull v14.8h, v2.8b, v10.8b\n" + "smull v15.8h, v2.8b, v11.8b\n" + "smlal2 v12.8h, v2.16b, v8.16b\n" + ASM_PREFETCH("[%[b_ptr], #256]") + "smlal2 v13.8h, v2.16b, v9.16b\n" + "smlal2 v14.8h, v2.16b, v10.16b\n" + ASM_PREFETCH("[%[a_ptr], #256]") + "smlal2 v15.8h, v2.16b, v11.16b\n" + "ldr q2, [%[a_ptr], #32]\n" + + "sadalp v24.4s, v12.8h\n" + "smull v12.8h, v3.8b, v8.8b\n" + "sadalp v25.4s, v13.8h\n" + "sadalp v26.4s, v14.8h\n" + "smull v13.8h, v3.8b, v9.8b\n" + "sadalp v27.4s, v15.8h\n" + "smull v14.8h, v3.8b, v10.8b\n" + "smull v15.8h, v3.8b, v11.8b\n" + "smlal2 v12.8h, v3.16b, v8.16b\n" + "smlal2 v13.8h, v3.16b, v9.16b\n" + "smlal2 v14.8h, v3.16b, v10.16b\n" + "smlal2 v15.8h, v3.16b, v11.16b\n" + "ldr q3, [%[a_ptr], #48]\n" + + // Start of unroll 0 for next iteration. + "sadalp v28.4s, v12.8h\n" + "smull v12.8h, v0.8b, v4.8b\n" + "sadalp v29.4s, v13.8h\n" + "sadalp v30.4s, v14.8h\n" + "smull v13.8h, v0.8b, v5.8b\n" + "sadalp v31.4s, v15.8h\n" + "bne 3b\n" + + // Target to use when K=1 or 2 (i.e. 
zero iterations of main loop) + "4:\n" + + // Branch to alternative tail for odd K + "cbnz %w[oddk], 5f\n" + + // Detached final iteration (even K) + "smull v14.8h, v0.8b, v6.8b\n" + "smull v15.8h, v0.8b, v7.8b\n" + "ldr q8, [%[b_ptr], #64]\n" + "smlal2 v12.8h, v0.16b, v4.16b\n" + "smlal2 v13.8h, v0.16b, v5.16b\n" + "ldr q9, [%[b_ptr], #80]\n" + "smlal2 v14.8h, v0.16b, v6.16b\n" + "smlal2 v15.8h, v0.16b, v7.16b\n" + "ldr q0, [%[a_ptr], #64]\n" + + "sadalp v16.4s, v12.8h\n" + "smull v12.8h, v1.8b, v4.8b\n" + "sadalp v17.4s, v13.8h\n" + "sadalp v18.4s, v14.8h\n" + "smull v13.8h, v1.8b, v5.8b\n" + "sadalp v19.4s, v15.8h\n" + "smull v14.8h, v1.8b, v6.8b\n" + "ldr q10, [%[b_ptr], #96]\n" + "smull v15.8h, v1.8b, v7.8b\n" + "smlal2 v12.8h, v1.16b, v4.16b\n" + "ldr q11, [%[b_ptr], #112]\n" + "smlal2 v13.8h, v1.16b, v5.16b\n" + "add %[b_ptr], %[b_ptr], #128\n" + "smlal2 v14.8h, v1.16b, v6.16b\n" + "smlal2 v15.8h, v1.16b, v7.16b\n" + "ldr q1, [%[a_ptr], #80]\n" + + "sadalp v20.4s, v12.8h\n" + "smull v12.8h, v2.8b, v4.8b\n" + "sadalp v21.4s, v13.8h\n" + "sadalp v22.4s, v14.8h\n" + "smull v13.8h, v2.8b, v5.8b\n" + "sadalp v23.4s, v15.8h\n" + "smull v14.8h, v2.8b, v6.8b\n" + "smull v15.8h, v2.8b, v7.8b\n" + "smlal2 v12.8h, v2.16b, v4.16b\n" + "smlal2 v13.8h, v2.16b, v5.16b\n" + "smlal2 v14.8h, v2.16b, v6.16b\n" + "smlal2 v15.8h, v2.16b, v7.16b\n" + "ldr q2, [%[a_ptr], #96]\n" + + "sadalp v24.4s, v12.8h\n" + "smull v12.8h, v3.8b, v4.8b\n" + "sadalp v25.4s, v13.8h\n" + "sadalp v26.4s, v14.8h\n" + "smull v13.8h, v3.8b, v5.8b\n" + "sadalp v27.4s, v15.8h\n" + "smull v14.8h, v3.8b, v6.8b\n" + "smull v15.8h, v3.8b, v7.8b\n" + "smlal2 v12.8h, v3.16b, v4.16b\n" + "smlal2 v13.8h, v3.16b, v5.16b\n" + "smlal2 v14.8h, v3.16b, v6.16b\n" + "smlal2 v15.8h, v3.16b, v7.16b\n" + "ldr q3, [%[a_ptr], #112]\n" + + // Unroll 1 + "sadalp v28.4s, v12.8h\n" + "smull v12.8h, v0.8b, v8.8b\n" + "sadalp v29.4s, v13.8h\n" + "sadalp v30.4s, v14.8h\n" + "smull v13.8h, v0.8b, v9.8b\n" + "sadalp v31.4s, v15.8h\n" + "smull v14.8h, v0.8b, v10.8b\n" + "add %[a_ptr], %[a_ptr], #128\n" + "smull v15.8h, v0.8b, v11.8b\n" + "smlal2 v12.8h, v0.16b, v8.16b\n" + "smlal2 v13.8h, v0.16b, v9.16b\n" + "smlal2 v14.8h, v0.16b, v10.16b\n" + "smlal2 v15.8h, v0.16b, v11.16b\n" + + "sadalp v16.4s, v12.8h\n" + "smull v12.8h, v1.8b, v8.8b\n" + "sadalp v17.4s, v13.8h\n" + "sadalp v18.4s, v14.8h\n" + "smull v13.8h, v1.8b, v9.8b\n" + "sadalp v19.4s, v15.8h\n" + "smull v14.8h, v1.8b, v10.8b\n" + "smull v15.8h, v1.8b, v11.8b\n" + "smlal2 v12.8h, v1.16b, v8.16b\n" + "addp v16.4s, v16.4s, v17.4s\n" + "smlal2 v13.8h, v1.16b, v9.16b\n" + "addp v17.4s, v18.4s, v19.4s\n" + "smlal2 v14.8h, v1.16b, v10.16b\n" + "smlal2 v15.8h, v1.16b, v11.16b\n" + + "sadalp v20.4s, v12.8h\n" + "smull v12.8h, v2.8b, v8.8b\n" + "sadalp v21.4s, v13.8h\n" + "sadalp v22.4s, v14.8h\n" + "smull v13.8h, v2.8b, v9.8b\n" + "sadalp v23.4s, v15.8h\n" + "addp v16.4s, v16.4s, v17.4s\n" + "smull v14.8h, v2.8b, v10.8b\n" + "addp v18.4s, v20.4s, v21.4s\n" + "addp v19.4s, v22.4s, v23.4s\n" + "smull v15.8h, v2.8b, v11.8b\n" + "smlal2 v12.8h, v2.16b, v8.16b\n" + "str q16, [%[output]]\n" + "smlal2 v13.8h, v2.16b, v9.16b\n" + "smlal2 v14.8h, v2.16b, v10.16b\n" + "smlal2 v15.8h, v2.16b, v11.16b\n" + + "sadalp v24.4s, v12.8h\n" + "smull v12.8h, v3.8b, v8.8b\n" + "sadalp v25.4s, v13.8h\n" + "sadalp v26.4s, v14.8h\n" + "smull v13.8h, v3.8b, v9.8b\n" + "sadalp v27.4s, v15.8h\n" + "addp v17.4s, v18.4s, v19.4s\n" + "smull v14.8h, v3.8b, v10.8b\n" + "addp v20.4s, v24.4s, v25.4s\n" + "addp v21.4s, v26.4s, v27.4s\n" + 
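            // Note on the instruction mix here: the horizontal reduction is
            // interleaved with the last row's multiplies. Each output row owns
            // four int32 accumulators (v16-v19 for row 0, v20-v23 for row 1,
            // v24-v27 for row 2, v28-v31 for row 3); two rounds of addp collapse
            // them into a single q register holding that row's four column sums,
            // which is then stored to the matching row of C (q16 -> output,
            // q17 -> x1, q18 -> x2, q19 -> x3).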
"smull v15.8h, v3.8b, v11.8b\n" + "smlal2 v12.8h, v3.16b, v8.16b\n" + "str q17, [x1]\n" + "smlal2 v13.8h, v3.16b, v9.16b\n" + "smlal2 v14.8h, v3.16b, v10.16b\n" + "addp v18.4s, v20.4s, v21.4s\n" + "smlal2 v15.8h, v3.16b, v11.16b\n" + "b 6f\n" + + // Detached final iteration (odd K) + "5:\n" + "smull v14.8h, v0.8b, v6.8b\n" + "add %[a_ptr], %[a_ptr], #64\n" + "smull v15.8h, v0.8b, v7.8b\n" + "add %[b_ptr], %[b_ptr], #64\n" + "smlal2 v12.8h, v0.16b, v4.16b\n" + "smlal2 v13.8h, v0.16b, v5.16b\n" + "smlal2 v14.8h, v0.16b, v6.16b\n" + "smlal2 v15.8h, v0.16b, v7.16b\n" + + "sadalp v16.4s, v12.8h\n" + "smull v12.8h, v1.8b, v4.8b\n" + "sadalp v17.4s, v13.8h\n" + "sadalp v18.4s, v14.8h\n" + "smull v13.8h, v1.8b, v5.8b\n" + "sadalp v19.4s, v15.8h\n" + "smull v14.8h, v1.8b, v6.8b\n" + "smull v15.8h, v1.8b, v7.8b\n" + "smlal2 v12.8h, v1.16b, v4.16b\n" + "addp v16.4s, v16.4s, v17.4s\n" + "smlal2 v13.8h, v1.16b, v5.16b\n" + "addp v17.4s, v18.4s, v19.4s\n" + "smlal2 v14.8h, v1.16b, v6.16b\n" + "smlal2 v15.8h, v1.16b, v7.16b\n" + + "sadalp v20.4s, v12.8h\n" + "smull v12.8h, v2.8b, v4.8b\n" + "sadalp v21.4s, v13.8h\n" + "sadalp v22.4s, v14.8h\n" + "smull v13.8h, v2.8b, v5.8b\n" + "sadalp v23.4s, v15.8h\n" + "addp v16.4s, v16.4s, v17.4s\n" + "smull v14.8h, v2.8b, v6.8b\n" + "addp v18.4s, v20.4s, v21.4s\n" + "addp v19.4s, v22.4s, v23.4s\n" + "smull v15.8h, v2.8b, v7.8b\n" + "smlal2 v12.8h, v2.16b, v4.16b\n" + "str q16, [%[output]]\n" + "smlal2 v13.8h, v2.16b, v5.16b\n" + "smlal2 v14.8h, v2.16b, v6.16b\n" + "smlal2 v15.8h, v2.16b, v7.16b\n" + + "sadalp v24.4s, v12.8h\n" + "smull v12.8h, v3.8b, v4.8b\n" + "sadalp v25.4s, v13.8h\n" + "sadalp v26.4s, v14.8h\n" + "smull v13.8h, v3.8b, v5.8b\n" + "sadalp v27.4s, v15.8h\n" + "addp v17.4s, v18.4s, v19.4s\n" + "smull v14.8h, v3.8b, v6.8b\n" + "addp v20.4s, v24.4s, v25.4s\n" + "addp v21.4s, v26.4s, v27.4s\n" + "smull v15.8h, v3.8b, v7.8b\n" + "smlal2 v12.8h, v3.16b, v4.16b\n" + "str q17, [x1]\n" + "smlal2 v13.8h, v3.16b, v5.16b\n" + "smlal2 v14.8h, v3.16b, v6.16b\n" + "addp v18.4s, v20.4s, v21.4s\n" + "smlal2 v15.8h, v3.16b, v7.16b\n" + + "6:\n" + + // Final additions + "sadalp v28.4s, v12.8h\n" + "str q18, [x2]\n" + "sadalp v29.4s, v13.8h\n" + "sadalp v30.4s, v14.8h\n" + "sadalp v31.4s, v15.8h\n" + + // Horizontal reduction, phase 1 + "addp v22.4s, v28.4s, v29.4s\n" + "addp v23.4s, v30.4s, v31.4s\n" + + // Horizontal reduction, phase 2 + "addp v19.4s, v22.4s, v23.4s\n" + "str q19, [x3]\n" + + : + [a_ptr] "+r" (a_ptr), [b_ptr] "+r" (b_ptr), [oddk] "+r" (oddk), + [is_first_k] "+r" (is_first_k), [k] "+r" (k), [LDC] "+r" (LDC), + [output] "+r"(output) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", + "v17", "v18", "v19", "v20","v21","v22","v23","v24","v25","v26", + "v27","v28","v29","v30","v31", "x1", "x2", "x3", + "cc", "memory" + ); +} + +static void kern_4x4_remain(const int8_t* packA, const int8_t* packB, int K, + int32_t* output, int LDC, bool is_first_k, + int m_remain, int n_remain) { + megdnn_assert(K > 0); + K /= 16; + const int8_t* a_ptr = packA; + const int8_t* b_ptr = packB; + + LDC = LDC * sizeof(int32_t); + +// clang-format off +#define LOAD_LINE(reg_index, n) \ + "cbz x4, 102f\n" \ + "mov x5, x" n "\n" \ + "cmp %w[n_remain], #4\n" \ + "blt 100" n "f\n" \ + "ldr q" reg_index ", [x5]\n" \ + "b 101" n "f\n" \ + "100" n ":\n" \ + "cmp %w[n_remain], #0\n" \ + "beq 101" n "f\n" \ + "ld1 {v" reg_index ".s}[0], [x5], #4\n" \ + "cmp %w[n_remain], #1\n" \ + "beq 101" n "f\n" \ + "ld1 {v" 
reg_index ".s}[1], [x5], #4\n" \ + "cmp %w[n_remain], #2\n" \ + "beq 101" n "f\n" \ + "ld1 {v" reg_index ".s}[2], [x5], #4\n" \ + "101" n ":\n" \ + "subs x4, x4, #1\n" + +#define LOAD_C \ + "mov x4, %x[m_remain]\n" \ + LOAD_LINE("16", "0") \ + LOAD_LINE("17", "1") \ + LOAD_LINE("18", "2") \ + LOAD_LINE("19", "3") \ + "102:\n" + +#define STORE_LINE(reg_index, n) \ + "cbz x4, 105f\n" \ + "mov x5, x" n "\n" \ + "cmp %w[n_remain], #4\n" \ + "blt 103" n "f\n" \ + "str q" reg_index ", [x5]\n" \ + "b 104" n "f\n" \ + "103" n ":\n" \ + "cmp %w[n_remain], #0\n" \ + "beq 104" n "f\n" \ + "st1 {v" reg_index ".s}[0], [x5], #4\n" \ + "cmp %w[n_remain], #1\n" \ + "beq 104" n "f\n" \ + "st1 {v" reg_index ".s}[1], [x5], #4\n" \ + "cmp %w[n_remain], #2\n" \ + "beq 104" n "f\n" \ + "st1 {v" reg_index ".s}[2], [x5], #4\n" \ + "104" n ":\n" \ + "subs x4, x4, #1\n" + +#define STORE_C \ + "mov x4, %x[m_remain]\n" \ + STORE_LINE("16", "0") \ + STORE_LINE("17", "1") \ + STORE_LINE("18", "2") \ + STORE_LINE("19", "3") \ + "105:\n" + + // clang-format on + + asm volatile( + // load accumulator C + "mov x0, %[output]\n" + "add x1, x0, %x[LDC]\n" + "add x2, x1, %x[LDC]\n" + "add x3, x2, %x[LDC]\n" + "cmp %w[is_first_k], #1\n" + "beq 1f\n" + + LOAD_C // + "b 2f\n" + + "1:\n" + "eor v16.16b, v16.16b, v16.16b\n" + "eor v17.16b, v17.16b, v17.16b\n" + "eor v18.16b, v18.16b, v18.16b\n" + "eor v19.16b, v19.16b, v19.16b\n" + "eor v20.16b, v20.16b, v20.16b\n" + "eor v21.16b, v21.16b, v21.16b\n" + "eor v22.16b, v22.16b, v22.16b\n" + "eor v23.16b, v23.16b, v23.16b\n" + "eor v24.16b, v24.16b, v24.16b\n" + "eor v25.16b, v25.16b, v25.16b\n" + "eor v26.16b, v26.16b, v26.16b\n" + "eor v27.16b, v27.16b, v27.16b\n" + "eor v28.16b, v28.16b, v28.16b\n" + "eor v29.16b, v29.16b, v29.16b\n" + "eor v30.16b, v30.16b, v30.16b\n" + "eor v31.16b, v31.16b, v31.16b\n" + + "2: \n" + "ldr q4, [%[b_ptr]]\n" + "ldr q5, [%[b_ptr], #16]\n" + "ldr q6, [%[b_ptr], #32]\n" + "ldr q7, [%[b_ptr], #48]\n" + "ldr q0, [%[a_ptr]]\n" + "ldr q1, [%[a_ptr], #16]\n" + "ldr q2, [%[a_ptr], #32]\n" + "ldr q3, [%[a_ptr], #48]\n" + + "smull v12.8h, v0.8b, v4.8b\n" + "smull v13.8h, v0.8b, v5.8b\n" + "smull v14.8h, v0.8b, v6.8b\n" + "smull v15.8h, v0.8b, v7.8b\n" + "smlal2 v12.8h, v0.16b, v4.16b\n" + "smlal2 v13.8h, v0.16b, v5.16b\n" + "smlal2 v14.8h, v0.16b, v6.16b\n" + "smlal2 v15.8h, v0.16b, v7.16b\n" + "sadalp v16.4s, v12.8h\n" + "sadalp v17.4s, v13.8h\n" + "sadalp v18.4s, v14.8h\n" + "sadalp v19.4s, v15.8h\n" + + "smull v12.8h, v1.8b, v4.8b\n" + "smull v13.8h, v1.8b, v5.8b\n" + "smull v14.8h, v1.8b, v6.8b\n" + "smull v15.8h, v1.8b, v7.8b\n" + "smlal2 v12.8h, v1.16b, v4.16b\n" + "smlal2 v13.8h, v1.16b, v5.16b\n" + "smlal2 v14.8h, v1.16b, v6.16b\n" + "smlal2 v15.8h, v1.16b, v7.16b\n" + "sadalp v20.4s, v12.8h\n" + "sadalp v21.4s, v13.8h\n" + "sadalp v22.4s, v14.8h\n" + "sadalp v23.4s, v15.8h\n" + + "smull v12.8h, v2.8b, v4.8b\n" + "smull v13.8h, v2.8b, v5.8b\n" + "smull v14.8h, v2.8b, v6.8b\n" + "smull v15.8h, v2.8b, v7.8b\n" + "smlal2 v12.8h, v2.16b, v4.16b\n" + "smlal2 v13.8h, v2.16b, v5.16b\n" + "smlal2 v14.8h, v2.16b, v6.16b\n" + "smlal2 v15.8h, v2.16b, v7.16b\n" + "sadalp v24.4s, v12.8h\n" + "sadalp v25.4s, v13.8h\n" + "sadalp v26.4s, v14.8h\n" + "sadalp v27.4s, v15.8h\n" + + "smull v12.8h, v3.8b, v4.8b\n" + "smull v13.8h, v3.8b, v5.8b\n" + "smull v14.8h, v3.8b, v6.8b\n" + "smull v15.8h, v3.8b, v7.8b\n" + "smlal2 v12.8h, v3.16b, v4.16b\n" + "smlal2 v13.8h, v3.16b, v5.16b\n" + "smlal2 v14.8h, v3.16b, v6.16b\n" + "smlal2 v15.8h, v3.16b, v7.16b\n" + "sadalp v28.4s, 
v12.8h\n" + "sadalp v29.4s, v13.8h\n" + "sadalp v30.4s, v14.8h\n" + "sadalp v31.4s, v15.8h\n" + "add %[a_ptr], %[a_ptr], #64\n" + "add %[b_ptr], %[b_ptr], #64\n" + + "subs %w[K], %w[K], #1\n" + "cbnz %w[K], 2b\n" + + "3:\n" + // reduction + "addp v16.4s, v16.4s, v17.4s\n" + "addp v17.4s, v18.4s, v19.4s\n" + "addp v16.4s, v16.4s, v17.4s\n" + "addp v18.4s, v20.4s, v21.4s\n" + "addp v19.4s, v22.4s, v23.4s\n" + "addp v17.4s, v18.4s, v19.4s\n" + "addp v20.4s, v24.4s, v25.4s\n" + "addp v21.4s, v26.4s, v27.4s\n" + "addp v18.4s, v20.4s, v21.4s\n" + "addp v22.4s, v28.4s, v29.4s\n" + "addp v23.4s, v30.4s, v31.4s\n" + "addp v19.4s, v22.4s, v23.4s\n" + + STORE_C + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), + [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), + [output] "+r"(output), [m_remain] "+r"(m_remain), + [n_remain] "+r"(n_remain) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", + "v29", "v30", "v31", "x0", "x1", "x2", "x3", "x4", "x5", "cc", + "memory"); + +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +} + +static void gemm_s8_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr, + int ldin, int y0, int ymax, int k0, int kmax) { + int8_t zerobuff[16]; + std::memset(zerobuff, 0, sizeof(int8_t) * 16); + + int y = y0; + for (; y + 3 < ymax; y += 4) { + const int8_t* inptr0 = inptr + y * ldin + k0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + + int K = kmax - k0; + //! read 16 * 4 in each row + for (; K > 15; K -= 16) { + interleave_4x16_1_b(inptr0, inptr1, inptr2, inptr3, outptr); + } + + if (K > 0) { + interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 16, K); + } + } + for (; y < ymax; y += 4) { + const int8_t* inptr0 = inptr + y * ldin + k0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + + int K = kmax - k0; + //! 
read 4 * 4 in each row + for (; K > 15; K -= 16) { + if (y + 3 >= ymax) { + switch (y + 3 - ymax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + interleave_4x16_1_b(inptr0, inptr1, inptr2, inptr3, outptr); + } + + if (K > 0) { + if (y + 3 >= ymax) { + switch (y + 3 - ymax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 16, K); + } + } +} + +static void gemm_s8_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, + int x0, int xmax, int k0, int kmax) { + int8_t zerobuff[16]; + std::memset(zerobuff, 0, sizeof(int8_t) * 16); + const int ksize = kmax - k0; + const int ksize4 = round_up(ksize, 16) * 4; + int8_t* outptr = out; + + int k = k0; + for (; k < kmax; k += 16) { + int ki = k; + for (int cnt = 0; cnt < 2; ki += 8, cnt++) { + const int8_t* inptr0 = in + ki * ldin + x0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + const int8_t* inptr4 = inptr3 + ldin; + const int8_t* inptr5 = inptr4 + ldin; + const int8_t* inptr6 = inptr5 + ldin; + const int8_t* inptr7 = inptr6 + ldin; + int8_t* outptr_inner = outptr + ki - k; + + int remain = std::min(ki + 7 - kmax, 7); + int x = x0; + for (; x + 3 < xmax; x += 4) { + if (remain >= 0) { + switch (remain) { + case 7: + inptr0 = zerobuff; + case 6: + inptr1 = zerobuff; + case 5: + inptr2 = zerobuff; + case 4: + inptr3 = zerobuff; + case 3: + inptr4 = zerobuff; + case 2: + inptr5 = zerobuff; + case 1: + inptr6 = zerobuff; + case 0: + inptr7 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + transpose_4x16_1_b_helper(inptr0, inptr1, inptr2, inptr3, + inptr4, inptr5, inptr6, inptr7, + outptr_inner); + outptr_inner += ksize4; + } + + if (x < xmax) { + if (remain >= 0) { + switch (remain) { + case 7: + inptr0 = zerobuff; + case 6: + inptr1 = zerobuff; + case 5: + inptr2 = zerobuff; + case 4: + inptr3 = zerobuff; + case 3: + inptr4 = zerobuff; + case 2: + inptr5 = zerobuff; + case 1: + inptr6 = zerobuff; + case 0: + inptr7 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + for (; x < xmax; x++) { + *outptr_inner++ = *inptr0++; + *outptr_inner++ = *inptr1++; + *outptr_inner++ = *inptr2++; + *outptr_inner++ = *inptr3++; + *outptr_inner++ = *inptr4++; + *outptr_inner++ = *inptr5++; + *outptr_inner++ = *inptr6++; + *outptr_inner++ = *inptr7++; + outptr_inner += 8; + } + } + } + + outptr += 16 * 4; + } +} + +} // namespace matmul_4x4x16 +} // namespace aarch64 +} // namespace megdnn +#endif + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/int8/kernel_8x8x8.h b/dnn/src/aarch64/matrix_mul/int8/kernel_8x8x8.h new file mode 100644 index 00000000..6330b115 --- /dev/null +++ b/dnn/src/aarch64/matrix_mul/int8/kernel_8x8x8.h @@ -0,0 +1,1375 @@ +/** + * \file dnn/src/aarch64/matrix_mul/int8/kernel_8x8x8.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#if !(__ARM_FEATURE_DOTPROD) +#include "src/aarch64/matrix_mul/asm/common.h" +#include "src/arm_common/simd_macro/marm_neon.h" + +namespace megdnn { +namespace aarch64 { +namespace matmul_8x8x8 { + +/** + * Overview of register layout: + * + * A 8x8x8 cell of Rhs is stored in 8bit in q26-q27 + * A 8x8x8 cell of Lhs is stored in 8bit in q0-q7 + * A 8x8 block of accumulators is stored in 32bit in q8-q23 + * + * +--------+--------+ + * |v26[0-8]|v27[0-8]| + * Rhs +--------+--------+ + * Lhs | | | + * + * +--------+ - - - - +-----------------+ + * |v0[0-8]| | v8[0-4]| v9[0-4]| + * |v1[0-8]| |v10[0-4]|v11[0-4]| + * |v2[0-8]| |v12[0-4]|v13[0-4]| + * |v3[0-8]| |v14[0-4]|v15[0-4]| + * |v4[0-8]| |v16[0-4]|v17[0-4]| + * |v5[0-8]| |v18[0-4]|v19[0-4]| + * |v6[0-8]| |v20[0-4]|v21[0-4]| + * |v7[0-8]| |v22[0-4]|v23[0-4]| + * +--------+ - - - - +-----------------+ + * + * Accumulator + */ + +static void kern_8x8(const int8_t* packA, const int8_t* packB, int K, + int32_t* output, int LDC, bool is_first_k) { + K /= 8; + const int8_t* a_ptr = packA; + const int8_t* b_ptr = packB; + + LDC = LDC * sizeof(int32_t); + + asm volatile( + // load accumulator C + "add x1, %[output], %x[LDC]\n" + "add x2, x1, %x[LDC]\n" + "add x3, x2, %x[LDC]\n" + "add x4, x3, %x[LDC]\n" + "add x5, x4, %x[LDC]\n" + "add x6, x5, %x[LDC]\n" + "add x7, x6, %x[LDC]\n" + "cmp %w[is_first_k], #1\n" + "beq 1f\n" + + "ldp q8, q9, [%[output]]\n" + "ldp q10, q11, [x1]\n" + "ldp q12, q13, [x2]\n" + "ldp q14, q15, [x3]\n" + "ldp q16, q17, [x4]\n" + "ldp q18, q19, [x5]\n" + "ldp q20, q21, [x6]\n" + "ldp q22, q23, [x7]\n" + "b 2f\n" + + "1:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "eor v9.16b, v9.16b, v9.16b\n" + "eor v10.16b, v10.16b, v10.16b\n" + "eor v11.16b, v11.16b, v11.16b\n" + "eor v12.16b, v12.16b, v12.16b\n" + "eor v13.16b, v13.16b, v13.16b\n" + "eor v14.16b, v14.16b, v14.16b\n" + "eor v15.16b, v15.16b, v15.16b\n" + "eor v16.16b, v16.16b, v16.16b\n" + "eor v17.16b, v17.16b, v17.16b\n" + "eor v18.16b, v18.16b, v18.16b\n" + "eor v19.16b, v19.16b, v19.16b\n" + "eor v20.16b, v20.16b, v20.16b\n" + "eor v21.16b, v21.16b, v21.16b\n" + "eor v22.16b, v22.16b, v22.16b\n" + "eor v23.16b, v23.16b, v23.16b\n" + + "2: \n" + "ld1 {v26.8b}, [%[b_ptr]], 8\n" + "ld1 {v0.8b}, [%[a_ptr]], 8\n" + "ld1 {v1.8b}, [%[a_ptr]], 8\n" + "ld1 {v2.8b}, [%[a_ptr]], 8\n" + "ld1 {v3.8b}, [%[a_ptr]], 8\n" + "ld1 {v4.8b}, [%[a_ptr]], 8\n" + "ld1 {v5.8b}, [%[a_ptr]], 8\n" + "ld1 {v6.8b}, [%[a_ptr]], 8\n" + "ld1 {v7.8b}, [%[a_ptr]], 8\n" + "sshll v26.8h, v26.8b, #0\n" + "sshll v0.8h, v0.8b, #0\n" + "sshll v1.8h, v1.8b, #0\n" + "sshll v2.8h, v2.8b, #0\n" + "sshll v3.8h, v3.8b, #0\n" + "sshll v4.8h, v4.8b, #0\n" + "sshll v5.8h, v5.8b, #0\n" + "sshll v6.8h, v6.8b, #0\n" + "sshll v7.8h, v7.8b, #0\n" + + "ld1 {v27.8b}, [%[b_ptr]], 8\n" + "smlal v8.4s, v26.4h, v0.h[0]\n" + "smlal v10.4s, v26.4h, v1.h[0]\n" + "smlal v12.4s, v26.4h, v2.h[0]\n" + "smlal v14.4s, v26.4h, v3.h[0]\n" + "smlal v16.4s, v26.4h, v4.h[0]\n" + "smlal v18.4s, v26.4h, v5.h[0]\n" + "smlal v20.4s, v26.4h, v6.h[0]\n" + "smlal v22.4s, v26.4h, v7.h[0]\n" + "sshll v27.8h, v27.8b, #0\n" + "smlal2 v9.4s, v26.8h, v0.h[0]\n" + "smlal2 v11.4s, v26.8h, v1.h[0]\n" + "smlal2 v13.4s, v26.8h, v2.h[0]\n" + "smlal2 v15.4s, v26.8h, v3.h[0]\n" + "smlal2 v17.4s, v26.8h, v4.h[0]\n" + "smlal2 v19.4s, v26.8h, v5.h[0]\n" + "smlal2 v21.4s, v26.8h, v6.h[0]\n" + "smlal2 v23.4s, v26.8h, v7.h[0]\n" + + "ld1 {v26.8b}, [%[b_ptr]], 8\n" + "smlal v8.4s, v27.4h, v0.h[1]\n" + "smlal v10.4s, v27.4h, v1.h[1]\n" + "smlal v12.4s, v27.4h, 
v2.h[1]\n" + "smlal v14.4s, v27.4h, v3.h[1]\n" + "smlal v16.4s, v27.4h, v4.h[1]\n" + "smlal v18.4s, v27.4h, v5.h[1]\n" + "smlal v20.4s, v27.4h, v6.h[1]\n" + "smlal v22.4s, v27.4h, v7.h[1]\n" + "sshll v26.8h, v26.8b, #0\n" + "smlal2 v9.4s, v27.8h, v0.h[1]\n" + "smlal2 v11.4s, v27.8h, v1.h[1]\n" + "smlal2 v13.4s, v27.8h, v2.h[1]\n" + "smlal2 v15.4s, v27.8h, v3.h[1]\n" + "smlal2 v17.4s, v27.8h, v4.h[1]\n" + "smlal2 v19.4s, v27.8h, v5.h[1]\n" + "smlal2 v21.4s, v27.8h, v6.h[1]\n" + "smlal2 v23.4s, v27.8h, v7.h[1]\n" + + "ld1 {v27.8b}, [%[b_ptr]], 8\n" + "smlal v8.4s, v26.4h, v0.h[2]\n" + "smlal v10.4s, v26.4h, v1.h[2]\n" + "smlal v12.4s, v26.4h, v2.h[2]\n" + "smlal v14.4s, v26.4h, v3.h[2]\n" + "smlal v16.4s, v26.4h, v4.h[2]\n" + "smlal v18.4s, v26.4h, v5.h[2]\n" + "smlal v20.4s, v26.4h, v6.h[2]\n" + "smlal v22.4s, v26.4h, v7.h[2]\n" + "sshll v27.8h, v27.8b, #0\n" + "smlal2 v9.4s, v26.8h, v0.h[2]\n" + "smlal2 v11.4s, v26.8h, v1.h[2]\n" + "smlal2 v13.4s, v26.8h, v2.h[2]\n" + "smlal2 v15.4s, v26.8h, v3.h[2]\n" + "smlal2 v17.4s, v26.8h, v4.h[2]\n" + "smlal2 v19.4s, v26.8h, v5.h[2]\n" + "smlal2 v21.4s, v26.8h, v6.h[2]\n" + "smlal2 v23.4s, v26.8h, v7.h[2]\n" + + "ld1 {v26.8b}, [%[b_ptr]], 8\n" + "smlal v8.4s, v27.4h, v0.h[3]\n" + "smlal v10.4s, v27.4h, v1.h[3]\n" + "smlal v12.4s, v27.4h, v2.h[3]\n" + "smlal v14.4s, v27.4h, v3.h[3]\n" + "smlal v16.4s, v27.4h, v4.h[3]\n" + "smlal v18.4s, v27.4h, v5.h[3]\n" + "smlal v20.4s, v27.4h, v6.h[3]\n" + "smlal v22.4s, v27.4h, v7.h[3]\n" + "sshll v26.8h, v26.8b, #0\n" + "smlal2 v9.4s, v27.8h, v0.h[3]\n" + "smlal2 v11.4s, v27.8h, v1.h[3]\n" + "smlal2 v13.4s, v27.8h, v2.h[3]\n" + "smlal2 v15.4s, v27.8h, v3.h[3]\n" + "smlal2 v17.4s, v27.8h, v4.h[3]\n" + "smlal2 v19.4s, v27.8h, v5.h[3]\n" + "smlal2 v21.4s, v27.8h, v6.h[3]\n" + "smlal2 v23.4s, v27.8h, v7.h[3]\n" + + "ld1 {v27.8b}, [%[b_ptr]], 8\n" + "smlal v8.4s, v26.4h, v0.h[4]\n" + "smlal v10.4s, v26.4h, v1.h[4]\n" + "smlal v12.4s, v26.4h, v2.h[4]\n" + "smlal v14.4s, v26.4h, v3.h[4]\n" + "smlal v16.4s, v26.4h, v4.h[4]\n" + "smlal v18.4s, v26.4h, v5.h[4]\n" + "smlal v20.4s, v26.4h, v6.h[4]\n" + "smlal v22.4s, v26.4h, v7.h[4]\n" + "sshll v27.8h, v27.8b, #0\n" + "smlal2 v9.4s, v26.8h, v0.h[4]\n" + "smlal2 v11.4s, v26.8h, v1.h[4]\n" + "smlal2 v13.4s, v26.8h, v2.h[4]\n" + "smlal2 v15.4s, v26.8h, v3.h[4]\n" + "smlal2 v17.4s, v26.8h, v4.h[4]\n" + "smlal2 v19.4s, v26.8h, v5.h[4]\n" + "smlal2 v21.4s, v26.8h, v6.h[4]\n" + "smlal2 v23.4s, v26.8h, v7.h[4]\n" + + "ld1 {v26.8b}, [%[b_ptr]], 8\n" + "smlal v8.4s, v27.4h, v0.h[5]\n" + "smlal v10.4s, v27.4h, v1.h[5]\n" + "smlal v12.4s, v27.4h, v2.h[5]\n" + "smlal v14.4s, v27.4h, v3.h[5]\n" + "smlal v16.4s, v27.4h, v4.h[5]\n" + "smlal v18.4s, v27.4h, v5.h[5]\n" + "smlal v20.4s, v27.4h, v6.h[5]\n" + "smlal v22.4s, v27.4h, v7.h[5]\n" + "sshll v26.8h, v26.8b, #0\n" + "smlal2 v9.4s, v27.8h, v0.h[5]\n" + "smlal2 v11.4s, v27.8h, v1.h[5]\n" + "smlal2 v13.4s, v27.8h, v2.h[5]\n" + "smlal2 v15.4s, v27.8h, v3.h[5]\n" + "smlal2 v17.4s, v27.8h, v4.h[5]\n" + "smlal2 v19.4s, v27.8h, v5.h[5]\n" + "smlal2 v21.4s, v27.8h, v6.h[5]\n" + "smlal2 v23.4s, v27.8h, v7.h[5]\n" + + "ld1 {v27.8b}, [%[b_ptr]], 8\n" + "smlal v8.4s, v26.4h, v0.h[6]\n" + "smlal v10.4s, v26.4h, v1.h[6]\n" + "smlal v12.4s, v26.4h, v2.h[6]\n" + "smlal v14.4s, v26.4h, v3.h[6]\n" + "smlal v16.4s, v26.4h, v4.h[6]\n" + "smlal v18.4s, v26.4h, v5.h[6]\n" + "smlal v20.4s, v26.4h, v6.h[6]\n" + "smlal v22.4s, v26.4h, v7.h[6]\n" + "sshll v27.8h, v27.8b, #0\n" + "smlal2 v9.4s, v26.8h, v0.h[6]\n" + "smlal2 v11.4s, v26.8h, v1.h[6]\n" + "smlal2 
v13.4s, v26.8h, v2.h[6]\n" + "smlal2 v15.4s, v26.8h, v3.h[6]\n" + "smlal2 v17.4s, v26.8h, v4.h[6]\n" + "smlal2 v19.4s, v26.8h, v5.h[6]\n" + "smlal2 v21.4s, v26.8h, v6.h[6]\n" + "smlal2 v23.4s, v26.8h, v7.h[6]\n" + + "smlal v8.4s, v27.4h, v0.h[7]\n" + "smlal v10.4s, v27.4h, v1.h[7]\n" + "smlal v12.4s, v27.4h, v2.h[7]\n" + "smlal v14.4s, v27.4h, v3.h[7]\n" + "smlal v16.4s, v27.4h, v4.h[7]\n" + "smlal v18.4s, v27.4h, v5.h[7]\n" + "smlal v20.4s, v27.4h, v6.h[7]\n" + "smlal v22.4s, v27.4h, v7.h[7]\n" + "smlal2 v9.4s, v27.8h, v0.h[7]\n" + "smlal2 v11.4s, v27.8h, v1.h[7]\n" + "smlal2 v13.4s, v27.8h, v2.h[7]\n" + "smlal2 v15.4s, v27.8h, v3.h[7]\n" + "smlal2 v17.4s, v27.8h, v4.h[7]\n" + "smlal2 v19.4s, v27.8h, v5.h[7]\n" + "smlal2 v21.4s, v27.8h, v6.h[7]\n" + "smlal2 v23.4s, v27.8h, v7.h[7]\n" + + "subs %w[K], %w[K], #1\n" + "cbnz %w[K], 2b\n" + + "3:\n" + "stp q8, q9, [%[output]]\n" + "stp q10, q11, [x1]\n" + "stp q12, q13, [x2]\n" + "stp q14, q15, [x3]\n" + "stp q16, q17, [x4]\n" + "stp q18, q19, [x5]\n" + "stp q20, q21, [x6]\n" + "stp q22, q23, [x7]\n" + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), + [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), + [output] "+r"(output) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v26", "v27", "x1", + "x2", "x3", "x4", "x5", "x6", "x7", "cc", "memory"); +} + +/** + * Overview of register layout: + * + * A 8x4x8 cell of Rhs is stored in 8bit in q16-q17 + * A 8x8x8 cell of Lhs is stored in 8bit in q0-q7 + * A 8x4 block of accumulators is stored in 32bit in q8-q15 + * + * +--------+ + * |v16[0-4]| + * Rhs +--------+ + * |v17[0-4]| + * Lhs +--------+ + * + * +--------+ - - - - +--------+ + * |v0[0-8]| | v8[0-4]| + * |v1[0-8]| | v9[0-4]| + * |v2[0-8]| |v10[0-4]| + * |v3[0-8]| |v11[0-4]| + * |v4[0-8]| |v12[0-4]| + * |v5[0-8]| |v13[0-4]| + * |v6[0-8]| |v14[0-4]| + * |v7[0-8]| |v15[0-4]| + * +--------+ - - - - +--------+ + * + * Accumulator + */ + +static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, + int32_t* output, int LDC, bool is_first_k, + size_t n_remain) { + K /= 8; + const int8_t* a_ptr = packA; + const int8_t* b_ptr = packB; + + LDC = LDC * sizeof(int32_t); + int32_t* outptr0 = output; + int32_t* outptr1; + int32_t* outptr2; + int32_t* outptr3; + int32_t* outptr4; + int32_t* outptr5; + int32_t* outptr6; + int32_t* outptr7; + size_t x0 = 0; + +// clang-format off +#define LOAD_LINE(reg_index, n) \ + "mov %[x0], %[outptr" n "]\n" \ + "cmp %w[n_remain], #4\n" \ + "blt 100" n "f\n" \ + "ldr q" reg_index ", [%[x0]] \n" \ + "b 101" n "f\n" \ + "100" n ":\n" \ + "cmp %w[n_remain], #0\n" \ + "beq 101" n "f\n" \ + "ld1 {v" reg_index ".s}[0], [%[x0]], #4\n" \ + "cmp %w[n_remain], #1\n" \ + "beq 101" n "f\n" \ + "ld1 {v" reg_index ".s}[1], [%[x0]], #4\n" \ + "cmp %w[n_remain], #2\n" \ + "beq 101" n "f\n" \ + "ld1 {v" reg_index ".s}[2], [%[x0]], #4\n" \ + "101" n ":\n" + +#define LOAD_C \ + LOAD_LINE("8", "0") \ + LOAD_LINE("9", "1") \ + LOAD_LINE("10", "2") \ + LOAD_LINE("11", "3") \ + LOAD_LINE("12", "4") \ + LOAD_LINE("13", "5") \ + LOAD_LINE("14", "6") \ + LOAD_LINE("15", "7") + +#define STORE_LINE(reg_index, n) \ + "mov %[x0], %[outptr" n "]\n" \ + "cmp %w[n_remain], #4\n" \ + "blt 102" n "f\n" \ + "str q" reg_index ", [%[x0]]\n" \ + "b 103" n "f\n" \ + "102" n ":\n" \ + "cmp %w[n_remain], #0\n" \ + "beq 103" n "f\n" \ + "st1 {v" reg_index ".s}[0], [%[x0]], #4\n" \ + "cmp %w[n_remain], #1\n" \ + "beq 103" n "f\n" 
\ + "st1 {v" reg_index ".s}[1], [%[x0]], #4\n" \ + "cmp %w[n_remain], #2\n" \ + "beq 103" n "f\n" \ + "st1 {v" reg_index ".s}[2], [%[x0]], #4\n" \ + "103" n ":\n" + +#define STORE_C \ + STORE_LINE("8", "0") \ + STORE_LINE("9", "1") \ + STORE_LINE("10", "2") \ + STORE_LINE("11", "3") \ + STORE_LINE("12", "4") \ + STORE_LINE("13", "5") \ + STORE_LINE("14", "6") \ + STORE_LINE("15", "7") + + // clang-format on + + asm volatile( + // load accumulator C + "add %[outptr1], %[outptr0], %x[LDC]\n" + "add %[outptr2], %[outptr1], %x[LDC]\n" + "add %[outptr3], %[outptr2], %x[LDC]\n" + "add %[outptr4], %[outptr3], %x[LDC]\n" + "add %[outptr5], %[outptr4], %x[LDC]\n" + "add %[outptr6], %[outptr5], %x[LDC]\n" + "add %[outptr7], %[outptr6], %x[LDC]\n" + "cmp %w[is_first_k], #1\n" + "beq 1f\n" LOAD_C + + "b 2f\n" + + "1:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "eor v9.16b, v9.16b, v9.16b\n" + "eor v10.16b, v10.16b, v10.16b\n" + "eor v11.16b, v11.16b, v11.16b\n" + "eor v12.16b, v12.16b, v12.16b\n" + "eor v13.16b, v13.16b, v13.16b\n" + "eor v14.16b, v14.16b, v14.16b\n" + "eor v15.16b, v15.16b, v15.16b\n" + + "2: \n" + "ld1 {v16.s}[0], [%[b_ptr]], 4\n" + "ld1 {v0.8b}, [%[a_ptr]], 8\n" + "ld1 {v1.8b}, [%[a_ptr]], 8\n" + "ld1 {v2.8b}, [%[a_ptr]], 8\n" + "ld1 {v3.8b}, [%[a_ptr]], 8\n" + "ld1 {v4.8b}, [%[a_ptr]], 8\n" + "ld1 {v5.8b}, [%[a_ptr]], 8\n" + "ld1 {v6.8b}, [%[a_ptr]], 8\n" + "ld1 {v7.8b}, [%[a_ptr]], 8\n" + "sshll v16.8h, v16.8b, #0\n" + "sshll v0.8h, v0.8b, #0\n" + "sshll v1.8h, v1.8b, #0\n" + "sshll v2.8h, v2.8b, #0\n" + "sshll v3.8h, v3.8b, #0\n" + "sshll v4.8h, v4.8b, #0\n" + "sshll v5.8h, v5.8b, #0\n" + "sshll v6.8h, v6.8b, #0\n" + "sshll v7.8h, v7.8b, #0\n" + + "ld1 {v17.s}[0], [%[b_ptr]], 4\n" + "smlal v8.4s, v16.4h, v0.h[0]\n" + "smlal v9.4s, v16.4h, v1.h[0]\n" + "smlal v10.4s, v16.4h, v2.h[0]\n" + "smlal v11.4s, v16.4h, v3.h[0]\n" + "sshll v17.8h, v17.8b, #0\n" + "smlal v12.4s, v16.4h, v4.h[0]\n" + "smlal v13.4s, v16.4h, v5.h[0]\n" + "smlal v14.4s, v16.4h, v6.h[0]\n" + "smlal v15.4s, v16.4h, v7.h[0]\n" + + "ld1 {v16.s}[0], [%[b_ptr]], 4\n" + "smlal v8.4s, v17.4h, v0.h[1]\n" + "smlal v9.4s, v17.4h, v1.h[1]\n" + "smlal v10.4s, v17.4h, v2.h[1]\n" + "smlal v11.4s, v17.4h, v3.h[1]\n" + "sshll v16.8h, v16.8b, #0\n" + "smlal v12.4s, v17.4h, v4.h[1]\n" + "smlal v13.4s, v17.4h, v5.h[1]\n" + "smlal v14.4s, v17.4h, v6.h[1]\n" + "smlal v15.4s, v17.4h, v7.h[1]\n" + + "ld1 {v17.s}[0], [%[b_ptr]], 4\n" + "smlal v8.4s, v16.4h, v0.h[2]\n" + "smlal v9.4s, v16.4h, v1.h[2]\n" + "smlal v10.4s, v16.4h, v2.h[2]\n" + "smlal v11.4s, v16.4h, v3.h[2]\n" + "sshll v17.8h, v17.8b, #0\n" + "smlal v12.4s, v16.4h, v4.h[2]\n" + "smlal v13.4s, v16.4h, v5.h[2]\n" + "smlal v14.4s, v16.4h, v6.h[2]\n" + "smlal v15.4s, v16.4h, v7.h[2]\n" + + "ld1 {v16.s}[0], [%[b_ptr]], 4\n" + "smlal v8.4s, v17.4h, v0.h[3]\n" + "smlal v9.4s, v17.4h, v1.h[3]\n" + "smlal v10.4s, v17.4h, v2.h[3]\n" + "smlal v11.4s, v17.4h, v3.h[3]\n" + "sshll v16.8h, v16.8b, #0\n" + "smlal v12.4s, v17.4h, v4.h[3]\n" + "smlal v13.4s, v17.4h, v5.h[3]\n" + "smlal v14.4s, v17.4h, v6.h[3]\n" + "smlal v15.4s, v17.4h, v7.h[3]\n" + + "ld1 {v17.s}[0], [%[b_ptr]], 4\n" + "smlal v8.4s, v16.4h, v0.h[4]\n" + "smlal v9.4s, v16.4h, v1.h[4]\n" + "smlal v10.4s, v16.4h, v2.h[4]\n" + "smlal v11.4s, v16.4h, v3.h[4]\n" + "sshll v17.8h, v17.8b, #0\n" + "smlal v12.4s, v16.4h, v4.h[4]\n" + "smlal v13.4s, v16.4h, v5.h[4]\n" + "smlal v14.4s, v16.4h, v6.h[4]\n" + "smlal v15.4s, v16.4h, v7.h[4]\n" + + "ld1 {v16.s}[0], [%[b_ptr]], 4\n" + "smlal v8.4s, v17.4h, v0.h[5]\n" + "smlal v9.4s, v17.4h, 
v1.h[5]\n" + "smlal v10.4s, v17.4h, v2.h[5]\n" + "smlal v11.4s, v17.4h, v3.h[5]\n" + "sshll v16.8h, v16.8b, #0\n" + "smlal v12.4s, v17.4h, v4.h[5]\n" + "smlal v13.4s, v17.4h, v5.h[5]\n" + "smlal v14.4s, v17.4h, v6.h[5]\n" + "smlal v15.4s, v17.4h, v7.h[5]\n" + + "ld1 {v17.s}[0], [%[b_ptr]], 4\n" + "smlal v8.4s, v16.4h, v0.h[6]\n" + "smlal v9.4s, v16.4h, v1.h[6]\n" + "smlal v10.4s, v16.4h, v2.h[6]\n" + "smlal v11.4s, v16.4h, v3.h[6]\n" + "sshll v17.8h, v17.8b, #0\n" + "smlal v12.4s, v16.4h, v4.h[6]\n" + "smlal v13.4s, v16.4h, v5.h[6]\n" + "smlal v14.4s, v16.4h, v6.h[6]\n" + "smlal v15.4s, v16.4h, v7.h[6]\n" + + "smlal v8.4s, v17.4h, v0.h[7]\n" + "smlal v9.4s, v17.4h, v1.h[7]\n" + "smlal v10.4s, v17.4h, v2.h[7]\n" + "smlal v11.4s, v17.4h, v3.h[7]\n" + "smlal v12.4s, v17.4h, v4.h[7]\n" + "smlal v13.4s, v17.4h, v5.h[7]\n" + "smlal v14.4s, v17.4h, v6.h[7]\n" + "smlal v15.4s, v17.4h, v7.h[7]\n" + + "subs %w[K], %w[K], #1\n" + "cbnz %w[K], 2b\n" + + "3:\n" STORE_C + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), + [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), + [outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), + [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), + [outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), + [outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7), [x0] "+r"(x0), + [n_remain] "+r"(n_remain) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "cc", "memory"); + +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +} + +/** + * Overview of register layout: + * + * A 8x8x8 cell of Rhs is stored in 8bit in q12-q13 + * A 8x8x4 cell of Lhs is stored in 8bit in q0-q3 + * A 4x8 block of accumulators is stored in 32bit in q4-q11 + * + * +--------+--------+ + * |v12[0-8]|v13[0-8]| + * Rhs +--------+--------+ + * Lhs | | | + * + * +--------+ - - - - +-----------------+ + * |v0[0-8]| | v4[0-4]| v5[0-4]| + * |v1[0-8]| | v6[0-4]| v7[0-4]| + * |v2[0-8]| | v8[0-4]| v9[0-4]| + * |v3[0-8]| |v10[0-4]|v11[0-4]| + * +--------+ - - - - +-----------------+ + * + * Accumulator + */ + +static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, + int32_t* output, int LDC, bool is_first_k, + size_t m_remain) { + K /= 8; + const int8_t* a_ptr = packA; + const int8_t* b_ptr = packB; + + LDC = LDC * sizeof(int32_t); + int32_t* outptr0 = output; + int32_t* outptr1; + int32_t* outptr2; + int32_t* outptr3; + size_t x0 = 0; + +// clang-format off +#define LOAD_LINE(v1, v2, m) \ + "cbz %[x0], 100f\n" \ + "ldp " v1 "," v2 ", [%[outptr" m "]]\n" \ + "subs %[x0], %[x0], #1\n" + +#define LOAD_C \ + "mov %[x0], %x[m_remain]\n" \ + LOAD_LINE("q4", "q5", "0") \ + LOAD_LINE("q6", "q7", "1") \ + LOAD_LINE("q8", "q9", "2") \ + LOAD_LINE("q10", "q11", "3") \ + "100:\n" + +#define STORE_LINE(v1, v2, m) \ + "cbz %[x0], 101f\n" \ + "stp " v1 "," v2", [%[outptr" m "]]\n" \ + "subs %[x0], %[x0], #1\n" + +#define STORE_C \ + "mov %[x0], %x[m_remain]\n" \ + STORE_LINE("q4", "q5", "0") \ + STORE_LINE("q6", "q7", "1") \ + STORE_LINE("q8", "q9", "2") \ + STORE_LINE("q10", "q11", "3") \ + "101:\n" + + // clang-format on + + asm volatile( + // load accumulator C + "add %[outptr1], %[outptr0], %x[LDC]\n" + "add %[outptr2], %[outptr1], %x[LDC]\n" + "add %[outptr3], %[outptr2], %x[LDC]\n" + "cmp %w[is_first_k], #1\n" + "beq 1f\n" LOAD_C + + "b 2f\n" + + "1:\n" + "eor v4.16b, v4.16b, v4.16b\n" + "eor v5.16b, v5.16b, v5.16b\n" + "eor v6.16b, v6.16b, v6.16b\n" + "eor v7.16b, v7.16b, v7.16b\n" + "eor v8.16b, v8.16b, v8.16b\n" + "eor 
v9.16b, v9.16b, v9.16b\n" + "eor v10.16b, v10.16b, v10.16b\n" + "eor v11.16b, v11.16b, v11.16b\n" + + "2: \n" + "ld1 {v12.8b}, [%[b_ptr]], 8\n" + "ld1 {v0.8b}, [%[a_ptr]], 8\n" + "ld1 {v1.8b}, [%[a_ptr]], 8\n" + "ld1 {v2.8b}, [%[a_ptr]], 8\n" + "ld1 {v3.8b}, [%[a_ptr]], 8\n" + "sshll v12.8h, v12.8b, #0\n" + "sshll v0.8h, v0.8b, #0\n" + "sshll v1.8h, v1.8b, #0\n" + "sshll v2.8h, v2.8b, #0\n" + "sshll v3.8h, v3.8b, #0\n" + + "ld1 {v13.8b}, [%[b_ptr]], 8\n" + "smlal v4.4s, v12.4h, v0.h[0]\n" + "smlal v6.4s, v12.4h, v1.h[0]\n" + "smlal v8.4s, v12.4h, v2.h[0]\n" + "smlal v10.4s, v12.4h, v3.h[0]\n" + "sshll v13.8h, v13.8b, #0\n" + "smlal2 v5.4s, v12.8h, v0.h[0]\n" + "smlal2 v7.4s, v12.8h, v1.h[0]\n" + "smlal2 v9.4s, v12.8h, v2.h[0]\n" + "smlal2 v11.4s, v12.8h, v3.h[0]\n" + + "ld1 {v12.8b}, [%[b_ptr]], 8\n" + "smlal v4.4s, v13.4h, v0.h[1]\n" + "smlal v6.4s, v13.4h, v1.h[1]\n" + "smlal v8.4s, v13.4h, v2.h[1]\n" + "smlal v10.4s, v13.4h, v3.h[1]\n" + "sshll v12.8h, v12.8b, #0\n" + "smlal2 v5.4s, v13.8h, v0.h[1]\n" + "smlal2 v7.4s, v13.8h, v1.h[1]\n" + "smlal2 v9.4s, v13.8h, v2.h[1]\n" + "smlal2 v11.4s, v13.8h, v3.h[1]\n" + + "ld1 {v13.8b}, [%[b_ptr]], 8\n" + "smlal v4.4s, v12.4h, v0.h[2]\n" + "smlal v6.4s, v12.4h, v1.h[2]\n" + "smlal v8.4s, v12.4h, v2.h[2]\n" + "smlal v10.4s, v12.4h, v3.h[2]\n" + "sshll v13.8h, v13.8b, #0\n" + "smlal2 v5.4s, v12.8h, v0.h[2]\n" + "smlal2 v7.4s, v12.8h, v1.h[2]\n" + "smlal2 v9.4s, v12.8h, v2.h[2]\n" + "smlal2 v11.4s, v12.8h, v3.h[2]\n" + + "ld1 {v12.8b}, [%[b_ptr]], 8\n" + "smlal v4.4s, v13.4h, v0.h[3]\n" + "smlal v6.4s, v13.4h, v1.h[3]\n" + "smlal v8.4s, v13.4h, v2.h[3]\n" + "smlal v10.4s, v13.4h, v3.h[3]\n" + "sshll v12.8h, v12.8b, #0\n" + "smlal2 v5.4s, v13.8h, v0.h[3]\n" + "smlal2 v7.4s, v13.8h, v1.h[3]\n" + "smlal2 v9.4s, v13.8h, v2.h[3]\n" + "smlal2 v11.4s, v13.8h, v3.h[3]\n" + + "ld1 {v13.8b}, [%[b_ptr]], 8\n" + "smlal v4.4s, v12.4h, v0.h[4]\n" + "smlal v6.4s, v12.4h, v1.h[4]\n" + "smlal v8.4s, v12.4h, v2.h[4]\n" + "smlal v10.4s, v12.4h, v3.h[4]\n" + "sshll v13.8h, v13.8b, #0\n" + "smlal2 v5.4s, v12.8h, v0.h[4]\n" + "smlal2 v7.4s, v12.8h, v1.h[4]\n" + "smlal2 v9.4s, v12.8h, v2.h[4]\n" + "smlal2 v11.4s, v12.8h, v3.h[4]\n" + + "ld1 {v12.8b}, [%[b_ptr]], 8\n" + "smlal v4.4s, v13.4h, v0.h[5]\n" + "smlal v6.4s, v13.4h, v1.h[5]\n" + "smlal v8.4s, v13.4h, v2.h[5]\n" + "smlal v10.4s, v13.4h, v3.h[5]\n" + "sshll v12.8h, v12.8b, #0\n" + "smlal2 v5.4s, v13.8h, v0.h[5]\n" + "smlal2 v7.4s, v13.8h, v1.h[5]\n" + "smlal2 v9.4s, v13.8h, v2.h[5]\n" + "smlal2 v11.4s, v13.8h, v3.h[5]\n" + + "ld1 {v13.8b}, [%[b_ptr]], 8\n" + "smlal v4.4s, v12.4h, v0.h[6]\n" + "smlal v6.4s, v12.4h, v1.h[6]\n" + "smlal v8.4s, v12.4h, v2.h[6]\n" + "smlal v10.4s, v12.4h, v3.h[6]\n" + "sshll v13.8h, v13.8b, #0\n" + "smlal2 v5.4s, v12.8h, v0.h[6]\n" + "smlal2 v7.4s, v12.8h, v1.h[6]\n" + "smlal2 v9.4s, v12.8h, v2.h[6]\n" + "smlal2 v11.4s, v12.8h, v3.h[6]\n" + + "smlal v4.4s, v13.4h, v0.h[7]\n" + "smlal v6.4s, v13.4h, v1.h[7]\n" + "smlal v8.4s, v13.4h, v2.h[7]\n" + "smlal v10.4s, v13.4h, v3.h[7]\n" + "smlal2 v5.4s, v13.8h, v0.h[7]\n" + "smlal2 v7.4s, v13.8h, v1.h[7]\n" + "smlal2 v9.4s, v13.8h, v2.h[7]\n" + "smlal2 v11.4s, v13.8h, v3.h[7]\n" + + "subs %w[K], %w[K], #1\n" + "cbnz %w[K], 2b\n" + + "3:\n" STORE_C + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), + [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), + [outptr0] "+r"(outptr0), + [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), + [outptr3] "=r"(outptr3), [x0] "+r"(x0), [m_remain] "+r"(m_remain) + : + : "v0", "v1", "v2", "v3", "v4", "v5", 
"v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "cc", "memory"); + +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +} + +/** + * Overview of register layout: + * + * A 8x4x8 cell of Rhs is stored in 8bit in q8-q9 + * A 8x8x4 cell of Lhs is stored in 8bit in q0-q3 + * A 4x4 block of accumulators is stored in 32bit in q4-q7 + * + * +--------+ + * | v8[0-4]| + * Rhs +--------+ + * | v9[0-4]| + * Lhs +--------+ + * + * +--------+ - - - - +--------+ + * |v0[0-8]| | v4[0-4]| + * |v1[0-8]| | v5[0-4]| + * |v2[0-8]| | v6[0-4]| + * |v3[0-8]| | v7[0-4]| + * +--------+ - - - - +--------+ + * + * Accumulator + */ + +static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, + int32_t* output, int LDC, bool is_first_k, size_t m_remain, + size_t n_remain) { + K /= 8; + const int8_t* a_ptr = packA; + const int8_t* b_ptr = packB; + + LDC = LDC * sizeof(int32_t); + int32_t* outptr0 = output; + int32_t* outptr1; + int32_t* outptr2; + int32_t* outptr3; + size_t x0 = 0; + size_t x1 = 0; + +// clang-format off +#define LOAD_LINE(reg_index, n) \ + "cbz %[x1], 102f\n" \ + "mov %[x0], %[outptr" n "]\n" \ + "cmp %w[n_remain], #4\n" \ + "blt 100" n "f\n" \ + "ldr q" reg_index ", [%[x0]]\n" \ + "b 101" n "f\n" \ + "100" n ":\n" \ + "cmp %w[n_remain], #0\n" \ + "beq 101" n "f\n" \ + "ld1 {v" reg_index ".s}[0], [%[x0]], #4\n" \ + "cmp %w[n_remain], #1\n" \ + "beq 101" n "f\n" \ + "ld1 {v" reg_index ".s}[1], [%[x0]], #4\n" \ + "cmp %w[n_remain], #2\n" \ + "beq 101" n "f\n" \ + "ld1 {v" reg_index ".s}[2], [%[x0]], #4\n" \ + "101" n ":\n" \ + "subs %[x1], %[x1], #1\n" + +#define LOAD_C \ + "mov %[x1], %x[m_remain]\n" \ + LOAD_LINE("4", "0") \ + LOAD_LINE("5", "1") \ + LOAD_LINE("6", "2") \ + LOAD_LINE("7", "3") \ + "102:\n" + +#define STORE_LINE(reg_index, n) \ + "cbz %[x1], 105f\n" \ + "mov %[x0], %[outptr" n "]\n" \ + "cmp %w[n_remain], #4\n" \ + "blt 103" n "f\n" \ + "str q" reg_index ", [%[x0]]\n" \ + "b 104" n "f\n" \ + "103" n ":\n" \ + "cmp %w[n_remain], #0\n" \ + "beq 104" n "f\n" \ + "st1 {v" reg_index ".s}[0], [%[x0]], #4\n" \ + "cmp %w[n_remain], #1\n" \ + "beq 104" n "f\n" \ + "st1 {v" reg_index ".s}[1], [%[x0]], #4\n" \ + "cmp %w[n_remain], #2\n" \ + "beq 104" n "f\n" \ + "st1 {v" reg_index ".s}[2], [%[x0]], #4\n" \ + "104" n ":\n" \ + "subs %[x1], %[x1], #1\n" + +#define STORE_C \ + "mov %[x1], %x[m_remain]\n" \ + STORE_LINE("4", "0") \ + STORE_LINE("5", "1") \ + STORE_LINE("6", "2") \ + STORE_LINE("7", "3") \ + "105:\n" + + // clang-format on + + asm volatile( + // load accumulator C + "add %[outptr1], %[outptr0], %x[LDC]\n" + "add %[outptr2], %[outptr1], %x[LDC]\n" + "add %[outptr3], %[outptr2], %x[LDC]\n" + "cmp %w[is_first_k], #1\n" + "beq 1f\n" LOAD_C + + "b 2f\n" + + "1:\n" + "eor v4.16b, v4.16b, v4.16b\n" + "eor v5.16b, v5.16b, v5.16b\n" + "eor v6.16b, v6.16b, v6.16b\n" + "eor v7.16b, v7.16b, v7.16b\n" + + "2: \n" + "ld1 {v8.s}[0], [%[b_ptr]], 4\n" + "ld1 {v0.8b}, [%[a_ptr]], 8\n" + "ld1 {v1.8b}, [%[a_ptr]], 8\n" + "ld1 {v2.8b}, [%[a_ptr]], 8\n" + "ld1 {v3.8b}, [%[a_ptr]], 8\n" + "sshll v8.8h, v8.8b, #0\n" + "sshll v0.8h, v0.8b, #0\n" + "sshll v1.8h, v1.8b, #0\n" + "sshll v2.8h, v2.8b, #0\n" + "sshll v3.8h, v3.8b, #0\n" + + "ld1 {v9.s}[0], [%[b_ptr]], 4\n" + "smlal v4.4s, v8.4h, v0.h[0]\n" + "smlal v5.4s, v8.4h, v1.h[0]\n" + "sshll v9.8h, v9.8b, #0\n" + "smlal v6.4s, v8.4h, v2.h[0]\n" + "smlal v7.4s, v8.4h, v3.h[0]\n" + + "ld1 {v8.s}[0], [%[b_ptr]], 4\n" + "smlal v4.4s, v9.4h, v0.h[1]\n" + "smlal v5.4s, v9.4h, v1.h[1]\n" + "sshll v8.8h, v8.8b, #0\n" + "smlal v6.4s, 
v9.4h, v2.h[1]\n" + "smlal v7.4s, v9.4h, v3.h[1]\n" + + "ld1 {v9.s}[0], [%[b_ptr]], 4\n" + "smlal v4.4s, v8.4h, v0.h[2]\n" + "smlal v5.4s, v8.4h, v1.h[2]\n" + "sshll v9.8h, v9.8b, #0\n" + "smlal v6.4s, v8.4h, v2.h[2]\n" + "smlal v7.4s, v8.4h, v3.h[2]\n" + + "ld1 {v8.s}[0], [%[b_ptr]], 4\n" + "smlal v4.4s, v9.4h, v0.h[3]\n" + "smlal v5.4s, v9.4h, v1.h[3]\n" + "sshll v8.8h, v8.8b, #0\n" + "smlal v6.4s, v9.4h, v2.h[3]\n" + "smlal v7.4s, v9.4h, v3.h[3]\n" + + "ld1 {v9.s}[0], [%[b_ptr]], 4\n" + "smlal v4.4s, v8.4h, v0.h[4]\n" + "smlal v5.4s, v8.4h, v1.h[4]\n" + "sshll v9.8h, v9.8b, #0\n" + "smlal v6.4s, v8.4h, v2.h[4]\n" + "smlal v7.4s, v8.4h, v3.h[4]\n" + + "ld1 {v8.s}[0], [%[b_ptr]], 4\n" + "smlal v4.4s, v9.4h, v0.h[5]\n" + "smlal v5.4s, v9.4h, v1.h[5]\n" + "sshll v8.8h, v8.8b, #0\n" + "smlal v6.4s, v9.4h, v2.h[5]\n" + "smlal v7.4s, v9.4h, v3.h[5]\n" + + "ld1 {v9.s}[0], [%[b_ptr]], 4\n" + "smlal v4.4s, v8.4h, v0.h[6]\n" + "smlal v5.4s, v8.4h, v1.h[6]\n" + "sshll v9.8h, v9.8b, #0\n" + "smlal v6.4s, v8.4h, v2.h[6]\n" + "smlal v7.4s, v8.4h, v3.h[6]\n" + + "smlal v4.4s, v9.4h, v0.h[7]\n" + "smlal v5.4s, v9.4h, v1.h[7]\n" + "smlal v6.4s, v9.4h, v2.h[7]\n" + "smlal v7.4s, v9.4h, v3.h[7]\n" + + "subs %w[K], %w[K], #1\n" + "cbnz %w[K], 2b\n" + + "3:\n" STORE_C + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), + [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), + [outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), + [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), [x0] "+r"(x0), + [x1] "+r"(x1), [m_remain] "+r"(m_remain), + [n_remain] "+r"(n_remain) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v11", "cc", + "memory"); + +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +} + +static void gemm_s8_8x8_pack_A_n(int8_t* outptr, const int8_t* inptr, int ldin, + int y0, int ymax, int k0, int kmax) { + int8_t zerobuff[16]; + std::memset(zerobuff, 0, sizeof(int8_t) * 16); + + int y = y0; + for (; y + 7 < ymax; y += 8) { + const int8_t* inptr0 = inptr + y * ldin + k0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + const int8_t* inptr4 = inptr3 + ldin; + const int8_t* inptr5 = inptr4 + ldin; + const int8_t* inptr6 = inptr5 + ldin; + const int8_t* inptr7 = inptr6 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + prefetch_2x(inptr4); + prefetch_2x(inptr5); + prefetch_2x(inptr6); + prefetch_2x(inptr7); + + int K = kmax - k0; + for (; K > 15; K -= 16) { + interleave_8x8_2_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, + inptr6, inptr7, outptr); + } + + if (K > 0) { + interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, + inptr7, outptr, 8, K); + } + } + + for (; y < ymax; y += 4) { + const int8_t* inptr0 = inptr + y * ldin + k0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + + int K = kmax - k0; + for (; K > 15; K -= 16) { + if (y + 3 >= ymax) { + switch (y + 3 - ymax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + interleave_4x8_2_b(inptr0, inptr1, inptr2, inptr3, outptr); + } + + if (K > 0) { + if (y + 3 >= ymax) { + switch (y + 3 - ymax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + 
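                        // The cases above fall through intentionally: for
                        // y + 3 - ymax == r the last r + 1 row pointers are
                        // redirected to zerobuff, so the packing below reads
                        // zeros for the missing rows. Since y < ymax here,
                        // r can only be 0, 1 or 2, hence the assert.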
megdnn_assert(0); + } + } + interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 8, K); + } + } +} + +static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in, + int ldin, int x0, int xmax, int k0, + int kmax) { + int8_t zerobuff[16]; + std::memset(zerobuff, 0, sizeof(int8_t) * 16); + const int ksize = kmax - k0; + const int ksize4 = round_up(ksize, 8) * 4; + const int ksize8 = ksize4 * 2; + int8_t* outptr = out; + int8_t* outptr_base = out; + //! 4x4 block output start pos + int8_t* outptr_base4 = out + ((xmax - x0) / 8) * ksize8; + + int k = k0; + for (; k < kmax; k += 8) { + const int8_t* inptr0 = in + k * ldin + x0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + const int8_t* inptr4 = inptr3 + ldin; + const int8_t* inptr5 = inptr4 + ldin; + const int8_t* inptr6 = inptr5 + ldin; + const int8_t* inptr7 = inptr6 + ldin; + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + prefetch_2x(inptr4); + prefetch_2x(inptr5); + prefetch_2x(inptr6); + prefetch_2x(inptr7); + + int x = x0; + outptr = outptr_base; + + for (; x + 7 < xmax; x += 8) { + if (k + 7 >= kmax) { + switch (k + 7 - kmax) { + case 6: + inptr1 = zerobuff; + case 5: + inptr2 = zerobuff; + case 4: + inptr3 = zerobuff; + case 3: + inptr4 = zerobuff; + case 2: + inptr5 = zerobuff; + case 1: + inptr6 = zerobuff; + case 0: + inptr7 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + transpose_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, + inptr6, inptr7, outptr); + outptr += ksize8; + } + + outptr = outptr_base4; + for (; x + 3 < xmax; x += 4) { + if (k + 7 >= kmax) { + switch (k + 7 - kmax) { + case 6: + inptr1 = zerobuff; + case 5: + inptr2 = zerobuff; + case 4: + inptr3 = zerobuff; + case 3: + inptr4 = zerobuff; + case 2: + inptr5 = zerobuff; + case 1: + inptr6 = zerobuff; + case 0: + inptr7 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, + inptr7, outptr, 4, 4); + outptr += ksize4; + } + + if (x < xmax) { + if (k + 7 >= kmax) { + switch (k + 7 - kmax) { + case 6: + inptr1 = zerobuff; + case 5: + inptr2 = zerobuff; + case 4: + inptr3 = zerobuff; + case 3: + inptr4 = zerobuff; + case 2: + inptr5 = zerobuff; + case 1: + inptr6 = zerobuff; + case 0: + inptr7 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, + inptr7, outptr, 4, xmax - x); + } + + outptr_base += 8 * 8; + outptr_base4 += 4 * 8; + } +} + +static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin, + int x0, int xmax, int k0, int kmax) { + int8_t zerobuff[16]; + std::memset(zerobuff, 0, sizeof(int8_t) * 16); + const int ksize = kmax - k0; + const int ksize4 = round_up(ksize, 8) * 4; + const int ksize8 = ksize4 * 2; + int8_t* outptr = out; + int8_t* outptr_base = out; + int8_t* outptr_interleave = nullptr; + //! 
4x4 block output start pos + int8_t* outptr_base4 = out + ((xmax - x0) / 8) * ksize8; + + int k = k0; + for (; k < kmax; k += 8) { + const int8_t* inptr0 = in + k * ldin + x0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + const int8_t* inptr4 = inptr3 + ldin; + const int8_t* inptr5 = inptr4 + ldin; + const int8_t* inptr6 = inptr5 + ldin; + const int8_t* inptr7 = inptr6 + ldin; + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + prefetch_2x(inptr4); + prefetch_2x(inptr5); + prefetch_2x(inptr6); + prefetch_2x(inptr7); + + int x = x0; + outptr = outptr_base; + + for (; x + 7 < xmax; x += 8) { + if (k + 7 >= kmax) { + switch (k + 7 - kmax) { + case 6: + inptr1 = zerobuff; + case 5: + inptr2 = zerobuff; + case 4: + inptr3 = zerobuff; + case 3: + inptr4 = zerobuff; + case 2: + inptr5 = zerobuff; + case 1: + inptr6 = zerobuff; + case 0: + inptr7 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + outptr_interleave = outptr; + interleave_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, + inptr6, inptr7, outptr_interleave); + outptr += ksize8; + } + + outptr = outptr_base4; + for (; x + 3 < xmax; x += 4) { + if (k + 7 >= kmax) { + switch (k + 7 - kmax) { + case 6: + inptr1 = zerobuff; + case 5: + inptr2 = zerobuff; + case 4: + inptr3 = zerobuff; + case 3: + inptr4 = zerobuff; + case 2: + inptr5 = zerobuff; + case 1: + inptr6 = zerobuff; + case 0: + inptr7 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + outptr_interleave = outptr; + interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, + inptr7, outptr_interleave, 4, 4); + outptr += ksize4; + } + + if (x < xmax) { + if (k + 7 >= kmax) { + switch (k + 7 - kmax) { + case 6: + inptr1 = zerobuff; + case 5: + inptr2 = zerobuff; + case 4: + inptr3 = zerobuff; + case 3: + inptr4 = zerobuff; + case 2: + inptr5 = zerobuff; + case 1: + inptr6 = zerobuff; + case 0: + inptr7 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + outptr_interleave = outptr; + interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, + inptr7, outptr_interleave, 4, xmax - x); + } + + outptr_base += 8 * 8; + outptr_base4 += 4 * 8; + } +} + +static void gemm_s8_8x8_transpose_pack_B_n(int8_t* outptr, const int8_t* inptr, + int ldin, int y0, int ymax, int k0, + int kmax) { + int8_t zerobuff[16]; + std::memset(zerobuff, 0, sizeof(int8_t) * 16); + constexpr int interleave4 = 32; + constexpr int interleave8 = 64; + + int y = y0; + for (; y + 7 < ymax; y += 8) { + const int8_t* inptr0 = inptr + y * ldin + k0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + const int8_t* inptr4 = inptr3 + ldin; + const int8_t* inptr5 = inptr4 + ldin; + const int8_t* inptr6 = inptr5 + ldin; + const int8_t* inptr7 = inptr6 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + prefetch_2x(inptr4); + prefetch_2x(inptr5); + prefetch_2x(inptr6); + prefetch_2x(inptr7); + + int K = kmax - k0; + for (; K > 7; K -= 8) { + transpose_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, + inptr6, inptr7, outptr); + outptr += interleave8; + } + + if (K > 0) { + transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, + inptr7, outptr, 8, K); + outptr += interleave8; + } + } + + for (; y < ymax; y += 4) { + const int8_t* inptr0 = inptr + y * ldin + k0; + const int8_t* inptr1 = inptr0 + ldin; + const 
int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + + int K = kmax - k0; + for (; K > 7; K -= 8) { + if (y + 3 >= ymax) { + switch (y + 3 - ymax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + transpose_8x4_1_b(inptr0, inptr1, inptr2, inptr3, outptr); + outptr += interleave4; + } + + if (K > 0) { + if (y + 3 >= ymax) { + switch (y + 3 - ymax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + transpose_4(inptr0, inptr1, inptr2, inptr3, outptr, 8, K); + outptr += interleave4; + } + } +} + +} // namespace matmul_8x8x8 +} // namespace aarch64 +} // namespace megdnn + +// vim: syntax=cpp.doxygen +#endif diff --git a/dnn/src/aarch64/matrix_mul/int8/kernel_mk4_4x4x16.h b/dnn/src/aarch64/matrix_mul/int8/kernel_mk4_4x4x16.h new file mode 100644 index 00000000..5fd94935 --- /dev/null +++ b/dnn/src/aarch64/matrix_mul/int8/kernel_mk4_4x4x16.h @@ -0,0 +1,892 @@ +/** + * \file dnn/src/aarch64/matrix_mul/int8/kernel_mk4_4x4x16.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include +#if !(__ARM_FEATURE_DOTPROD) +#include "src/aarch64/matrix_mul/asm/common.h" +#include "src/arm_common/simd_macro/marm_neon.h" + +namespace megdnn { +namespace aarch64 { +namespace matmul_mk4_4x4x16 { + +/** + * Overview of register layout: + * + * A 16x4 cell of Rhs is stored in 8bit in v0-q3. + * B 16x4 cell of Lhs is stored in 8bit in q4-q7 + * C 8x16 block of accumulators is stored in 8bit in q8--q31. + * + * \warning Fast kernel operating on int8 operands. + * It is assumed that one of the two int8 operands only takes values + * in [-127, 127], while the other may freely range in [-128, 127]. + * The issue with both operands taking the value -128 is that: + * -128*-128 + -128*-128 == -32768 overflows int16. + * Every other expression a*b + c*d, for any int8 a,b,c,d, fits in int16 + * range. That is the basic idea of this kernel. 
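 *
 * Concretely, INT16_MAX is 32767 while (-128) * (-128) + (-128) * (-128)
 * = 16384 + 16384 = 32768, which wraps to -32768; keeping one operand in
 * [-127, 127] bounds each pair sum by 127 * 128 * 2 = 32512, which fits.
 * A scalar sketch of the smull/smlal2 + sadalp sequence below, where a and b
 * stand for the 16 packed int8 values feeding one accumulator register:
 *
 *   for (int i = 0; i < 4; i++) {                  // the 4 int32 lanes
 *       int16_t p0 = a[2*i]     * b[2*i]     + a[2*i + 8] * b[2*i + 8];
 *       int16_t p1 = a[2*i + 1] * b[2*i + 1] + a[2*i + 9] * b[2*i + 9];
 *       acc[i] += p0 + p1;                         // sadalp widens to int32
 *   }
 *
 * Only p0 and p1 must fit in int16; acc accumulates in 32 bits.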
+ * + * + * +--------+--------+---------+---------+ + * |v4[0-16]|v5[0-16]| v6[0-16]| v7[0-16]| + * Rhs +--------+--------+---------+---------+ + * | | | | | + * + * Lhs | | | | | + * + * +--------+ - - - - +-------------------------------------+ + * |v0[0-16]| |v16[0-4]|v17[0-4]| v18[0-4]| v19[0-4]| + * |v1[0-16]| |v20[0-4]|v21[0-4]| v22[0-4]| v23[0-4]| + * |v2[0-16]| |v24[0-4]|v25[0-4]| v26[0-4]| v27[0-4]| + * |v3[0-16]| |v28[0-4]|v29[0-4]| v30[0-4]| v31[0-4]| + * +--------+ - - - - +-------------------------------------+ + * + * Accumulator + */ + +static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, + int32_t* output, bool is_first_k) { + K = div_ceil(K, 16); + const int8_t* a_ptr = packA; + const int8_t* b_ptr = packB; + + asm volatile( + // load accumulator C + "ld1 {v0.16b}, [%[a_ptr]], #16\n" + "eor v16.16b, v16.16b, v16.16b\n" + "eor v17.16b, v17.16b, v17.16b\n" + "eor v18.16b, v18.16b, v18.16b\n" + "ld1 {v1.16b}, [%[a_ptr]], #16\n" + "eor v19.16b, v19.16b, v19.16b\n" + "eor v20.16b, v19.16b, v19.16b\n" + "eor v21.16b, v19.16b, v19.16b\n" + "ld1 {v4.16b, v5.16b}, [%[b_ptr]], #32\n" + "eor v22.16b, v19.16b, v19.16b\n" + "PRFM PLDL1KEEP, [%[a_ptr], #32]\n" + "eor v23.16b, v19.16b, v19.16b\n" + "eor v24.16b, v19.16b, v19.16b\n" + "PRFM PLDL1KEEP, [%[b_ptr], #32]\n" + "eor v25.16b, v19.16b, v19.16b\n" + "eor v26.16b, v19.16b, v19.16b\n" + "PRFM PLDL1KEEP, [%[b_ptr], #64]\n" + "eor v27.16b, v19.16b, v19.16b\n" + "eor v28.16b, v19.16b, v19.16b\n" + "PRFM PLDL1KEEP, [%[a_ptr], #64]\n" + "eor v29.16b, v19.16b, v19.16b\n" + "eor v30.16b, v19.16b, v19.16b\n" + "PRFM PLDL1KEEP, [%[b_ptr], #128]\n" + "eor v31.16b, v19.16b, v19.16b\n" + + //! if K==1 jump to compute last K + "cmp %w[k], #2\n" + "beq 2f\n" + "blt 3f\n" + + //! K>2 + "1:\n" + //! First k + "smull v8.8h, v0.8b, v4.8b\n" + "smull v9.8h, v0.8b, v5.8b\n" + "ld1 {v6.16b}, [%[b_ptr]], #16\n" + "smull v12.8h, v1.8b, v4.8b\n" + "smull v13.8h, v1.8b, v5.8b\n" + "ld1 {v7.16b}, [%[b_ptr]], #16\n" + "smlal2 v8.8h, v0.16b, v4.16b\n" + "smlal2 v9.8h, v0.16b, v5.16b\n" + "smlal2 v12.8h, v1.16b, v4.16b\n" + "smlal2 v13.8h, v1.16b, v5.16b\n" + + "smull v10.8h, v0.8b, v6.8b\n" + "ld1 {v2.16b}, [%[a_ptr]], #16\n" + "smull v11.8h, v0.8b, v7.8b\n" + "smull v14.8h, v1.8b, v6.8b\n" + "ld1 {v3.16b}, [%[a_ptr]], #16\n" + "smull v15.8h, v1.8b, v7.8b\n" + "sadalp v16.4s, v8.8h\n" + "smlal2 v10.8h, v0.16b, v6.16b\n" + "sadalp v17.4s, v9.8h\n" + "smlal2 v11.8h, v0.16b, v7.16b\n" + "sadalp v20.4s, v12.8h\n" + "smlal2 v14.8h, v1.16b, v6.16b\n" + "sadalp v21.4s, v13.8h\n" + "smlal2 v15.8h, v1.16b, v7.16b\n" + + "smull v8.8h, v2.8b, v4.8b\n" + "smull v9.8h, v2.8b, v5.8b\n" + "ld1 {v0.16b}, [%[a_ptr]], #16\n" + "smull v12.8h, v3.8b, v4.8b\n" + "ld1 {v1.16b}, [%[a_ptr]], #16\n" + "smull v13.8h, v3.8b, v5.8b\n" + "sadalp v18.4s, v10.8h\n" + "smlal2 v8.8h, v2.16b, v4.16b\n" + "sadalp v19.4s, v11.8h\n" + "smlal2 v9.8h, v2.16b, v5.16b\n" + "sadalp v22.4s, v14.8h\n" + "smlal2 v12.8h, v3.16b, v4.16b\n" + "sadalp v23.4s, v15.8h\n" + "smlal2 v13.8h, v3.16b, v5.16b\n" + + "smull v10.8h, v2.8b, v6.8b\n" + "smull v11.8h, v2.8b, v7.8b\n" + "ld1 {v4.16b}, [%[b_ptr]], #16\n" + "smull v14.8h, v3.8b, v6.8b\n" + "ld1 {v5.16b}, [%[b_ptr]], #16\n" + "smull v15.8h, v3.8b, v7.8b\n" + "sadalp v24.4s, v8.8h\n" + "smlal2 v10.8h, v2.16b, v6.16b\n" + "sadalp v25.4s, v9.8h\n" + "smlal2 v11.8h, v2.16b, v7.16b\n" + "sadalp v28.4s, v12.8h\n" + "smlal2 v14.8h, v3.16b, v6.16b\n" + "sadalp v29.4s, v13.8h\n" + "smlal2 v15.8h, v3.16b, v7.16b\n" + + //! 
Second k + "smull v8.8h, v0.8b, v4.8b\n" + "smull v9.8h, v0.8b, v5.8b\n" + "ld1 {v6.16b}, [%[b_ptr]], #16\n" + "smull v12.8h, v1.8b, v4.8b\n" + "smull v13.8h, v1.8b, v5.8b\n" + "ld1 {v7.16b}, [%[b_ptr]], #16\n" + "smlal2 v8.8h, v0.16b, v4.16b\n" + "sadalp v26.4s, v10.8h\n" + "smlal2 v9.8h, v0.16b, v5.16b\n" + "sadalp v27.4s, v11.8h\n" + "smlal2 v12.8h, v1.16b, v4.16b\n" + "sadalp v30.4s, v14.8h\n" + "smlal2 v13.8h, v1.16b, v5.16b\n" + "sadalp v31.4s, v15.8h\n" + + "smull v10.8h, v0.8b, v6.8b\n" + "ld1 {v2.16b}, [%[a_ptr]], #16\n" + "smull v11.8h, v0.8b, v7.8b\n" + "smull v14.8h, v1.8b, v6.8b\n" + "ld1 {v3.16b}, [%[a_ptr]], #16\n" + "smull v15.8h, v1.8b, v7.8b\n" + "sadalp v16.4s, v8.8h\n" + "smlal2 v10.8h, v0.16b, v6.16b\n" + "sadalp v17.4s, v9.8h\n" + "smlal2 v11.8h, v0.16b, v7.16b\n" + "sadalp v20.4s, v12.8h\n" + "smlal2 v14.8h, v1.16b, v6.16b\n" + "sadalp v21.4s, v13.8h\n" + "smlal2 v15.8h, v1.16b, v7.16b\n" + + "smull v8.8h, v2.8b, v4.8b\n" + "smull v9.8h, v2.8b, v5.8b\n" + "ld1 {v0.16b}, [%[a_ptr]], #16\n" + "smull v12.8h, v3.8b, v4.8b\n" + "ld1 {v1.16b}, [%[a_ptr]], #16\n" + "smull v13.8h, v3.8b, v5.8b\n" + "sadalp v18.4s, v10.8h\n" + "smlal2 v8.8h, v2.16b, v4.16b\n" + "sadalp v19.4s, v11.8h\n" + "smlal2 v9.8h, v2.16b, v5.16b\n" + "sadalp v22.4s, v14.8h\n" + "smlal2 v12.8h, v3.16b, v4.16b\n" + "sadalp v23.4s, v15.8h\n" + "smlal2 v13.8h, v3.16b, v5.16b\n" + + "sub %w[k], %w[k], #2\n" + "cmp %w[k], #2\n" + + "smull v10.8h, v2.8b, v6.8b\n" + "ld1 {v4.16b}, [%[b_ptr]], #16\n" + "smull v11.8h, v2.8b, v7.8b\n" + "ld1 {v5.16b}, [%[b_ptr]], #16\n" + "smull v14.8h, v3.8b, v6.8b\n" + "sadalp v24.4s, v8.8h\n" + "smull v15.8h, v3.8b, v7.8b\n" + "sadalp v25.4s, v9.8h\n" + "smlal2 v10.8h, v2.16b, v6.16b\n" + "sadalp v28.4s, v12.8h\n" + "smlal2 v11.8h, v2.16b, v7.16b\n" + "sadalp v29.4s, v13.8h\n" + "smlal2 v14.8h, v3.16b, v6.16b\n" + "sadalp v26.4s, v10.8h\n" + "smlal2 v15.8h, v3.16b, v7.16b\n" + + "sadalp v27.4s, v11.8h\n" + "sadalp v30.4s, v14.8h\n" + "sadalp v31.4s, v15.8h\n" + + "bgt 1b\n" + "blt 3f\n" + + //! 
K==2 + "2:\n" + "smull v8.8h, v0.8b, v4.8b\n" + "smull v9.8h, v0.8b, v5.8b\n" + "ld1 {v6.16b}, [%[b_ptr]], #16\n" + "smull v12.8h, v1.8b, v4.8b\n" + "smull v13.8h, v1.8b, v5.8b\n" + "ld1 {v7.16b}, [%[b_ptr]], #16\n" + "smlal2 v8.8h, v0.16b, v4.16b\n" + "smlal2 v9.8h, v0.16b, v5.16b\n" + "smlal2 v12.8h, v1.16b, v4.16b\n" + "smlal2 v13.8h, v1.16b, v5.16b\n" + + "smull v10.8h, v0.8b, v6.8b\n" + "ld1 {v2.16b}, [%[a_ptr]], #16\n" + "smull v11.8h, v0.8b, v7.8b\n" + "smull v14.8h, v1.8b, v6.8b\n" + "ld1 {v3.16b}, [%[a_ptr]], #16\n" + "smull v15.8h, v1.8b, v7.8b\n" + "sadalp v16.4s, v8.8h\n" + "smlal2 v10.8h, v0.16b, v6.16b\n" + "sadalp v17.4s, v9.8h\n" + "smlal2 v11.8h, v0.16b, v7.16b\n" + "sadalp v20.4s, v12.8h\n" + "smlal2 v14.8h, v1.16b, v6.16b\n" + "sadalp v21.4s, v13.8h\n" + "smlal2 v15.8h, v1.16b, v7.16b\n" + + "smull v8.8h, v2.8b, v4.8b\n" + "smull v9.8h, v2.8b, v5.8b\n" + "ld1 {v0.16b}, [%[a_ptr]], #16\n" + "smull v12.8h, v3.8b, v4.8b\n" + "ld1 {v1.16b}, [%[a_ptr]], #16\n" + "smull v13.8h, v3.8b, v5.8b\n" + "sadalp v18.4s, v10.8h\n" + "smlal2 v8.8h, v2.16b, v4.16b\n" + "sadalp v19.4s, v11.8h\n" + "smlal2 v9.8h, v2.16b, v5.16b\n" + "sadalp v22.4s, v14.8h\n" + "smlal2 v12.8h, v3.16b, v4.16b\n" + "sadalp v23.4s, v15.8h\n" + "smlal2 v13.8h, v3.16b, v5.16b\n" + + "smull v10.8h, v2.8b, v6.8b\n" + "smull v11.8h, v2.8b, v7.8b\n" + "ld1 {v4.16b}, [%[b_ptr]], #16\n" + "smull v14.8h, v3.8b, v6.8b\n" + "ld1 {v5.16b}, [%[b_ptr]], #16\n" + "smull v15.8h, v3.8b, v7.8b\n" + "sadalp v24.4s, v8.8h\n" + "smlal2 v10.8h, v2.16b, v6.16b\n" + "sadalp v25.4s, v9.8h\n" + "smlal2 v11.8h, v2.16b, v7.16b\n" + "sadalp v28.4s, v12.8h\n" + "smlal2 v14.8h, v3.16b, v6.16b\n" + "sadalp v29.4s, v13.8h\n" + "smlal2 v15.8h, v3.16b, v7.16b\n" + "sadalp v26.4s, v10.8h\n" + "sadalp v27.4s, v11.8h\n" + "sadalp v30.4s, v14.8h\n" + "sadalp v31.4s, v15.8h\n" + + //! 
K==1 + "3:\n" + "smull v8.8h, v0.8b, v4.8b\n" + "smull v9.8h, v0.8b, v5.8b\n" + "ld1 {v6.16b}, [%[b_ptr]], #16\n" + "smull v12.8h, v1.8b, v4.8b\n" + "smull v13.8h, v1.8b, v5.8b\n" + "ld1 {v7.16b}, [%[b_ptr]], #16\n" + "smlal2 v8.8h, v0.16b, v4.16b\n" + "smlal2 v9.8h, v0.16b, v5.16b\n" + "smlal2 v12.8h, v1.16b, v4.16b\n" + "smlal2 v13.8h, v1.16b, v5.16b\n" + + "smull v10.8h, v0.8b, v6.8b\n" + "ld1 {v2.16b}, [%[a_ptr]], #16\n" + "smull v11.8h, v0.8b, v7.8b\n" + "smull v14.8h, v1.8b, v6.8b\n" + "ld1 {v3.16b}, [%[a_ptr]], #16\n" + "smull v15.8h, v1.8b, v7.8b\n" + "sadalp v16.4s, v8.8h\n" + "smlal2 v10.8h, v0.16b, v6.16b\n" + "sadalp v17.4s, v9.8h\n" + "smlal2 v11.8h, v0.16b, v7.16b\n" + "sadalp v20.4s, v12.8h\n" + "smlal2 v14.8h, v1.16b, v6.16b\n" + "sadalp v21.4s, v13.8h\n" + "smlal2 v15.8h, v1.16b, v7.16b\n" + + "smull v8.8h, v2.8b, v4.8b\n" + "smull v9.8h, v2.8b, v5.8b\n" + "smull v12.8h, v3.8b, v4.8b\n" + "smull v13.8h, v3.8b, v5.8b\n" + "sadalp v18.4s, v10.8h\n" + "smlal2 v8.8h, v2.16b, v4.16b\n" + "sadalp v19.4s, v11.8h\n" + "smlal2 v9.8h, v2.16b, v5.16b\n" + "sadalp v22.4s, v14.8h\n" + "smlal2 v12.8h, v3.16b, v4.16b\n" + "sadalp v23.4s, v15.8h\n" + "smlal2 v13.8h, v3.16b, v5.16b\n" + + "smull v10.8h, v2.8b, v6.8b\n" + "sadalp v24.4s, v8.8h\n" + "smull v11.8h, v2.8b, v7.8b\n" + "sadalp v25.4s, v9.8h\n" + "smull v14.8h, v3.8b, v6.8b\n" + "sadalp v28.4s, v12.8h\n" + "smull v15.8h, v3.8b, v7.8b\n" + "sadalp v29.4s, v13.8h\n" + "smlal2 v10.8h, v2.16b, v6.16b\n" + "smlal2 v11.8h, v2.16b, v7.16b\n" + "sadalp v26.4s, v10.8h\n" + "smlal2 v14.8h, v3.16b, v6.16b\n" + "sadalp v27.4s, v11.8h\n" + "smlal2 v15.8h, v3.16b, v7.16b\n" + "sadalp v30.4s, v14.8h\n" + "sadalp v31.4s, v15.8h\n" + + "addp v4.4s, v16.4s, v20.4s\n" + "addp v5.4s, v24.4s, v28.4s\n" + "addp v6.4s, v17.4s, v21.4s\n" + "addp v7.4s, v25.4s, v29.4s\n" + "addp v8.4s, v18.4s, v22.4s\n" + "addp v9.4s, v26.4s, v30.4s\n" + "addp v10.4s, v19.4s, v23.4s\n" + "addp v11.4s, v27.4s, v31.4s\n" + + "cmp %w[is_first_k], #1\n" + + "addp v0.4s, v4.4s, v5.4s\n" + "addp v1.4s, v6.4s, v7.4s\n" + "addp v2.4s, v8.4s, v9.4s\n" + "addp v3.4s, v10.4s, v11.4s\n" + + "beq 6f\n" + + "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[output]]\n" + "add v0.4s, v0.4s, v8.4s\n" + "add v1.4s, v1.4s, v9.4s\n" + "add v2.4s, v2.4s, v10.4s\n" + "add v3.4s, v3.4s, v11.4s\n" + + "6:\n" + "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[output]], #64\n" + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), + [is_first_k] "+r"(is_first_k), [k] "+r"(K), [output] "+r"(output) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", + "v29", "v30", "v31", "cc", "memory"); +} + +static void kern_4x4_remain(const int8_t* packA, const int8_t* packB, int K, + int32_t* output, bool is_first_k, size_t remain_n) { + K = div_ceil(K, 16); + const int8_t* a_ptr = packA; + const int8_t* b_ptr = packB; + + asm volatile( + // load accumulator C + "ld1 {v0.16b}, [%[a_ptr]], #16\n" + "eor v16.16b, v16.16b, v16.16b\n" + "eor v17.16b, v17.16b, v17.16b\n" + "eor v18.16b, v18.16b, v18.16b\n" + "ld1 {v1.16b}, [%[a_ptr]], #16\n" + "eor v19.16b, v19.16b, v19.16b\n" + "eor v20.16b, v19.16b, v19.16b\n" + "eor v21.16b, v19.16b, v19.16b\n" + "ld1 {v4.16b, v5.16b}, [%[b_ptr]], #32\n" + "eor v22.16b, v19.16b, v19.16b\n" + "PRFM PLDL1KEEP, [%[a_ptr], #32]\n" + "eor v23.16b, v19.16b, v19.16b\n" + "eor v24.16b, v19.16b, v19.16b\n" + "PRFM PLDL1KEEP, [%[b_ptr], #32]\n" + "eor v25.16b, 
v19.16b, v19.16b\n" + "eor v26.16b, v19.16b, v19.16b\n" + "PRFM PLDL1KEEP, [%[b_ptr], #64]\n" + "eor v27.16b, v19.16b, v19.16b\n" + "eor v28.16b, v19.16b, v19.16b\n" + "PRFM PLDL1KEEP, [%[a_ptr], #64]\n" + "eor v29.16b, v19.16b, v19.16b\n" + "eor v30.16b, v19.16b, v19.16b\n" + "PRFM PLDL1KEEP, [%[b_ptr], #128]\n" + "eor v31.16b, v19.16b, v19.16b\n" + + //! if K==1 jump to compute last K + "cmp %w[k], #2\n" + "beq 2f\n" + "blt 3f\n" + + //! K>2 + "1:\n" + //! First k + "smull v8.8h, v0.8b, v4.8b\n" + "smull v9.8h, v0.8b, v5.8b\n" + "ld1 {v6.16b}, [%[b_ptr]], #16\n" + "smull v12.8h, v1.8b, v4.8b\n" + "smull v13.8h, v1.8b, v5.8b\n" + "ld1 {v7.16b}, [%[b_ptr]], #16\n" + "smlal2 v8.8h, v0.16b, v4.16b\n" + "smlal2 v9.8h, v0.16b, v5.16b\n" + "smlal2 v12.8h, v1.16b, v4.16b\n" + "smlal2 v13.8h, v1.16b, v5.16b\n" + + "smull v10.8h, v0.8b, v6.8b\n" + "ld1 {v2.16b}, [%[a_ptr]], #16\n" + "smull v11.8h, v0.8b, v7.8b\n" + "smull v14.8h, v1.8b, v6.8b\n" + "ld1 {v3.16b}, [%[a_ptr]], #16\n" + "smull v15.8h, v1.8b, v7.8b\n" + "sadalp v16.4s, v8.8h\n" + "smlal2 v10.8h, v0.16b, v6.16b\n" + "sadalp v17.4s, v9.8h\n" + "smlal2 v11.8h, v0.16b, v7.16b\n" + "sadalp v20.4s, v12.8h\n" + "smlal2 v14.8h, v1.16b, v6.16b\n" + "sadalp v21.4s, v13.8h\n" + "smlal2 v15.8h, v1.16b, v7.16b\n" + + "smull v8.8h, v2.8b, v4.8b\n" + "smull v9.8h, v2.8b, v5.8b\n" + "ld1 {v0.16b}, [%[a_ptr]], #16\n" + "smull v12.8h, v3.8b, v4.8b\n" + "ld1 {v1.16b}, [%[a_ptr]], #16\n" + "smull v13.8h, v3.8b, v5.8b\n" + "sadalp v18.4s, v10.8h\n" + "smlal2 v8.8h, v2.16b, v4.16b\n" + "sadalp v19.4s, v11.8h\n" + "smlal2 v9.8h, v2.16b, v5.16b\n" + "sadalp v22.4s, v14.8h\n" + "smlal2 v12.8h, v3.16b, v4.16b\n" + "sadalp v23.4s, v15.8h\n" + "smlal2 v13.8h, v3.16b, v5.16b\n" + + "smull v10.8h, v2.8b, v6.8b\n" + "smull v11.8h, v2.8b, v7.8b\n" + "ld1 {v4.16b}, [%[b_ptr]], #16\n" + "smull v14.8h, v3.8b, v6.8b\n" + "ld1 {v5.16b}, [%[b_ptr]], #16\n" + "smull v15.8h, v3.8b, v7.8b\n" + "sadalp v24.4s, v8.8h\n" + "smlal2 v10.8h, v2.16b, v6.16b\n" + "sadalp v25.4s, v9.8h\n" + "smlal2 v11.8h, v2.16b, v7.16b\n" + "sadalp v28.4s, v12.8h\n" + "smlal2 v14.8h, v3.16b, v6.16b\n" + "sadalp v29.4s, v13.8h\n" + "smlal2 v15.8h, v3.16b, v7.16b\n" + + //! 
Second k + "smull v8.8h, v0.8b, v4.8b\n" + "smull v9.8h, v0.8b, v5.8b\n" + "ld1 {v6.16b}, [%[b_ptr]], #16\n" + "smull v12.8h, v1.8b, v4.8b\n" + "smull v13.8h, v1.8b, v5.8b\n" + "ld1 {v7.16b}, [%[b_ptr]], #16\n" + "smlal2 v8.8h, v0.16b, v4.16b\n" + "sadalp v26.4s, v10.8h\n" + "smlal2 v9.8h, v0.16b, v5.16b\n" + "sadalp v27.4s, v11.8h\n" + "smlal2 v12.8h, v1.16b, v4.16b\n" + "sadalp v30.4s, v14.8h\n" + "smlal2 v13.8h, v1.16b, v5.16b\n" + "sadalp v31.4s, v15.8h\n" + + "smull v10.8h, v0.8b, v6.8b\n" + "ld1 {v2.16b}, [%[a_ptr]], #16\n" + "smull v11.8h, v0.8b, v7.8b\n" + "smull v14.8h, v1.8b, v6.8b\n" + "ld1 {v3.16b}, [%[a_ptr]], #16\n" + "smull v15.8h, v1.8b, v7.8b\n" + "sadalp v16.4s, v8.8h\n" + "smlal2 v10.8h, v0.16b, v6.16b\n" + "sadalp v17.4s, v9.8h\n" + "smlal2 v11.8h, v0.16b, v7.16b\n" + "sadalp v20.4s, v12.8h\n" + "smlal2 v14.8h, v1.16b, v6.16b\n" + "sadalp v21.4s, v13.8h\n" + "smlal2 v15.8h, v1.16b, v7.16b\n" + + "smull v8.8h, v2.8b, v4.8b\n" + "smull v9.8h, v2.8b, v5.8b\n" + "ld1 {v0.16b}, [%[a_ptr]], #16\n" + "smull v12.8h, v3.8b, v4.8b\n" + "ld1 {v1.16b}, [%[a_ptr]], #16\n" + "smull v13.8h, v3.8b, v5.8b\n" + "sadalp v18.4s, v10.8h\n" + "smlal2 v8.8h, v2.16b, v4.16b\n" + "sadalp v19.4s, v11.8h\n" + "smlal2 v9.8h, v2.16b, v5.16b\n" + "sadalp v22.4s, v14.8h\n" + "smlal2 v12.8h, v3.16b, v4.16b\n" + "sadalp v23.4s, v15.8h\n" + "smlal2 v13.8h, v3.16b, v5.16b\n" + + "sub %w[k], %w[k], #2\n" + "cmp %w[k], #2\n" + + "smull v10.8h, v2.8b, v6.8b\n" + "ld1 {v4.16b}, [%[b_ptr]], #16\n" + "smull v11.8h, v2.8b, v7.8b\n" + "ld1 {v5.16b}, [%[b_ptr]], #16\n" + "smull v14.8h, v3.8b, v6.8b\n" + "sadalp v24.4s, v8.8h\n" + "smull v15.8h, v3.8b, v7.8b\n" + "sadalp v25.4s, v9.8h\n" + "smlal2 v10.8h, v2.16b, v6.16b\n" + "sadalp v28.4s, v12.8h\n" + "smlal2 v11.8h, v2.16b, v7.16b\n" + "sadalp v29.4s, v13.8h\n" + "smlal2 v14.8h, v3.16b, v6.16b\n" + "sadalp v26.4s, v10.8h\n" + "smlal2 v15.8h, v3.16b, v7.16b\n" + + "sadalp v27.4s, v11.8h\n" + "sadalp v30.4s, v14.8h\n" + "sadalp v31.4s, v15.8h\n" + + "bgt 1b\n" + "blt 3f\n" + + //! 
K==2 + "2:\n" + "smull v8.8h, v0.8b, v4.8b\n" + "smull v9.8h, v0.8b, v5.8b\n" + "ld1 {v6.16b}, [%[b_ptr]], #16\n" + "smull v12.8h, v1.8b, v4.8b\n" + "smull v13.8h, v1.8b, v5.8b\n" + "ld1 {v7.16b}, [%[b_ptr]], #16\n" + "smlal2 v8.8h, v0.16b, v4.16b\n" + "smlal2 v9.8h, v0.16b, v5.16b\n" + "smlal2 v12.8h, v1.16b, v4.16b\n" + "smlal2 v13.8h, v1.16b, v5.16b\n" + + "smull v10.8h, v0.8b, v6.8b\n" + "ld1 {v2.16b}, [%[a_ptr]], #16\n" + "smull v11.8h, v0.8b, v7.8b\n" + "smull v14.8h, v1.8b, v6.8b\n" + "ld1 {v3.16b}, [%[a_ptr]], #16\n" + "smull v15.8h, v1.8b, v7.8b\n" + "sadalp v16.4s, v8.8h\n" + "smlal2 v10.8h, v0.16b, v6.16b\n" + "sadalp v17.4s, v9.8h\n" + "smlal2 v11.8h, v0.16b, v7.16b\n" + "sadalp v20.4s, v12.8h\n" + "smlal2 v14.8h, v1.16b, v6.16b\n" + "sadalp v21.4s, v13.8h\n" + "smlal2 v15.8h, v1.16b, v7.16b\n" + + "smull v8.8h, v2.8b, v4.8b\n" + "smull v9.8h, v2.8b, v5.8b\n" + "ld1 {v0.16b}, [%[a_ptr]], #16\n" + "smull v12.8h, v3.8b, v4.8b\n" + "ld1 {v1.16b}, [%[a_ptr]], #16\n" + "smull v13.8h, v3.8b, v5.8b\n" + "sadalp v18.4s, v10.8h\n" + "smlal2 v8.8h, v2.16b, v4.16b\n" + "sadalp v19.4s, v11.8h\n" + "smlal2 v9.8h, v2.16b, v5.16b\n" + "sadalp v22.4s, v14.8h\n" + "smlal2 v12.8h, v3.16b, v4.16b\n" + "sadalp v23.4s, v15.8h\n" + "smlal2 v13.8h, v3.16b, v5.16b\n" + + "smull v10.8h, v2.8b, v6.8b\n" + "smull v11.8h, v2.8b, v7.8b\n" + "ld1 {v4.16b}, [%[b_ptr]], #16\n" + "smull v14.8h, v3.8b, v6.8b\n" + "ld1 {v5.16b}, [%[b_ptr]], #16\n" + "smull v15.8h, v3.8b, v7.8b\n" + "sadalp v24.4s, v8.8h\n" + "smlal2 v10.8h, v2.16b, v6.16b\n" + "sadalp v25.4s, v9.8h\n" + "smlal2 v11.8h, v2.16b, v7.16b\n" + "sadalp v28.4s, v12.8h\n" + "smlal2 v14.8h, v3.16b, v6.16b\n" + "sadalp v29.4s, v13.8h\n" + "smlal2 v15.8h, v3.16b, v7.16b\n" + "sadalp v26.4s, v10.8h\n" + "sadalp v27.4s, v11.8h\n" + "sadalp v30.4s, v14.8h\n" + "sadalp v31.4s, v15.8h\n" + + //! 
K==1 + "3:\n" + "smull v8.8h, v0.8b, v4.8b\n" + "smull v9.8h, v0.8b, v5.8b\n" + "ld1 {v6.16b}, [%[b_ptr]], #16\n" + "smull v12.8h, v1.8b, v4.8b\n" + "smull v13.8h, v1.8b, v5.8b\n" + "ld1 {v7.16b}, [%[b_ptr]], #16\n" + "smlal2 v8.8h, v0.16b, v4.16b\n" + "smlal2 v9.8h, v0.16b, v5.16b\n" + "smlal2 v12.8h, v1.16b, v4.16b\n" + "smlal2 v13.8h, v1.16b, v5.16b\n" + + "smull v10.8h, v0.8b, v6.8b\n" + "ld1 {v2.16b}, [%[a_ptr]], #16\n" + "smull v11.8h, v0.8b, v7.8b\n" + "smull v14.8h, v1.8b, v6.8b\n" + "ld1 {v3.16b}, [%[a_ptr]], #16\n" + "smull v15.8h, v1.8b, v7.8b\n" + "sadalp v16.4s, v8.8h\n" + "smlal2 v10.8h, v0.16b, v6.16b\n" + "sadalp v17.4s, v9.8h\n" + "smlal2 v11.8h, v0.16b, v7.16b\n" + "sadalp v20.4s, v12.8h\n" + "smlal2 v14.8h, v1.16b, v6.16b\n" + "sadalp v21.4s, v13.8h\n" + "smlal2 v15.8h, v1.16b, v7.16b\n" + + "smull v8.8h, v2.8b, v4.8b\n" + "smull v9.8h, v2.8b, v5.8b\n" + "smull v12.8h, v3.8b, v4.8b\n" + "smull v13.8h, v3.8b, v5.8b\n" + "sadalp v18.4s, v10.8h\n" + "smlal2 v8.8h, v2.16b, v4.16b\n" + "sadalp v19.4s, v11.8h\n" + "smlal2 v9.8h, v2.16b, v5.16b\n" + "sadalp v22.4s, v14.8h\n" + "smlal2 v12.8h, v3.16b, v4.16b\n" + "sadalp v23.4s, v15.8h\n" + "smlal2 v13.8h, v3.16b, v5.16b\n" + + "smull v10.8h, v2.8b, v6.8b\n" + "sadalp v24.4s, v8.8h\n" + "smull v11.8h, v2.8b, v7.8b\n" + "sadalp v25.4s, v9.8h\n" + "smull v14.8h, v3.8b, v6.8b\n" + "sadalp v28.4s, v12.8h\n" + "smull v15.8h, v3.8b, v7.8b\n" + "sadalp v29.4s, v13.8h\n" + "smlal2 v10.8h, v2.16b, v6.16b\n" + "smlal2 v11.8h, v2.16b, v7.16b\n" + "sadalp v26.4s, v10.8h\n" + "smlal2 v14.8h, v3.16b, v6.16b\n" + "sadalp v27.4s, v11.8h\n" + "smlal2 v15.8h, v3.16b, v7.16b\n" + "sadalp v30.4s, v14.8h\n" + "sadalp v31.4s, v15.8h\n" + + "addp v4.4s, v16.4s, v20.4s\n" + "addp v5.4s, v24.4s, v28.4s\n" + "addp v6.4s, v17.4s, v21.4s\n" + "addp v7.4s, v25.4s, v29.4s\n" + "addp v8.4s, v18.4s, v22.4s\n" + "addp v9.4s, v26.4s, v30.4s\n" + "addp v10.4s, v19.4s, v23.4s\n" + "addp v11.4s, v27.4s, v31.4s\n" + + "addp v0.4s, v4.4s, v5.4s\n" + "addp v1.4s, v6.4s, v7.4s\n" + "addp v2.4s, v8.4s, v9.4s\n" + "addp v3.4s, v10.4s, v11.4s\n" + + "cmp %w[is_first_k], #1\n" + "beq 6f\n" + + "cmp %w[remain_n], #3\n" + "beq 1003f\n" + "cmp %w[remain_n], #2\n" + "beq 1002f\n" + "cmp %w[remain_n], #1\n" + "beq 1001f\n" + "1003:\n" + "ld1 {v8.4s, v9.4s, v10.4s}, [%[output]]\n" + "add v0.4s, v0.4s, v8.4s\n" + "add v1.4s, v1.4s, v9.4s\n" + "add v2.4s, v2.4s, v10.4s\n" + "b 6f\n" + "1002:\n" + "ld1 {v8.4s, v9.4s}, [%[output]]\n" + "add v0.4s, v0.4s, v8.4s\n" + "add v1.4s, v1.4s, v9.4s\n" + "b 6f\n" + "1001:\n" + "ld1 {v8.4s}, [%[output]]\n" + "add v0.4s, v0.4s, v8.4s\n" + + "6:\n" + "cmp %w[remain_n], #3\n" + "beq 10003f\n" + "cmp %w[remain_n], #2\n" + "beq 10002f\n" + "cmp %w[remain_n], #1\n" + "beq 10001f\n" + "10003:\n" + "str q2, [%[output], #32]\n" + "10002:\n" + "str q1, [%[output], #16]\n" + "10001:\n" + "str q0, [%[output]]\n" + + "7:\n" + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), + [remain_n] "+r"(remain_n), [is_first_k] "+r"(is_first_k), + [k] "+r"(K), [output] "+r"(output) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", + "v29", "v30", "v31", "cc", "memory"); +} + +static void gemm_mk4_s8_4x4_pack_A(dt_int8* outptr, const dt_int8* inptr, + int ldin, int y0, int ymax, int k0, + int kmax) { + //! 
pack form {oc/4, ic/4, 4(ic), 4(oc)} to {oc/4, ic/16, 4(oc), 16(ic)} + int8_t zerobuff[4][64]; + std::memset(zerobuff, 0, sizeof(int8_t) * 64 * 4); + megdnn_assert(ymax % 4 == 0 && y0 % 4 == 0 && (ymax - y0) % 4 == 0, + "mk4 matmul with m is not times of 4"); + megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0 && (kmax - k0) % 4 == 0, + "mk4 matmul with k is not times of 4"); + size_t roundk = round_up(kmax - k0, 16); + size_t out_offset = roundk * 4; + int y = y0; + int start_y = y0 / 4; + for (; y + 15 < ymax; y += 16, start_y += 4) { + const int8_t* inptr0 = inptr + start_y * ldin + k0 * 4; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + int8_t* output = outptr + start_y * out_offset; + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + int K = kmax - k0; + for (; K > 15; K -= 16) { + transpose_interleave_4x4_4_b(inptr0, inptr1, inptr2, inptr3, output, + out_offset); + output += 64; + } + if (K > 0) { + std::memcpy(zerobuff[0], inptr0, sizeof(int8_t) * K * 4); + std::memcpy(zerobuff[1], inptr1, sizeof(int8_t) * K * 4); + std::memcpy(zerobuff[2], inptr2, sizeof(int8_t) * K * 4); + std::memcpy(zerobuff[3], inptr3, sizeof(int8_t) * K * 4); + inptr0 = zerobuff[0]; + inptr1 = zerobuff[1]; + inptr2 = zerobuff[2]; + inptr3 = zerobuff[3]; + transpose_interleave_4x4_4_b(inptr0, inptr1, inptr2, inptr3, output, + out_offset); + output += 64; + } + } + for (; y + 3 < ymax; y += 4, start_y++) { + const int8_t* inptr0 = inptr + start_y * ldin + k0 * 4; + int8_t* output = outptr + start_y * out_offset; + prefetch_2x(inptr0); + int K = kmax - k0; + for (; K > 15; K -= 16) { + transpose_interleave_1x4_4_b(inptr0, output); + output += 64; + } + if (K > 0) { + std::memcpy(zerobuff[0], inptr0, sizeof(int8_t) * K * 4); + inptr0 = zerobuff[0]; + transpose_interleave_1x4_4_b(inptr0, output); + output += 64; + } + } +} + +static void gemm_mk4_s8_4x4_pack_B(dt_int8* out, const dt_int8* in, int ldin, + int x0, int xmax, int k0, int kmax) { + int32_t zerobuff[4]; + std::memset(zerobuff, 0, sizeof(int8_t) * 16); + const int ksize = kmax - k0; + const int ICB = (ksize) / 4; + const int ksize4 = round_up(ICB, 4) * 4; + int32_t* outptr = reinterpret_cast(out); + megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0 && ksize % 4 == 0, + "mk4 matmul with k is not times of 4"); + + int k = k0 / 4; + for (; k + 3 < ICB; k += 4) { + const int32_t* inptr0 = + reinterpret_cast(in + k * ldin + x0); + const int32_t* inptr1 = + reinterpret_cast(in + (k + 1) * ldin + x0); + const int32_t* inptr2 = + reinterpret_cast(in + (k + 2) * ldin + x0); + const int32_t* inptr3 = + reinterpret_cast(in + (k + 3) * ldin + x0); + int32_t* outptr_inner = outptr; + + int x = x0; + for (; x + 3 < xmax; x += 4) { + transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr_inner); + outptr_inner += ksize4; + } + if (x < xmax) { + for (; x < xmax; x++) { + *outptr_inner++ = *inptr0++; + *outptr_inner++ = *inptr1++; + *outptr_inner++ = *inptr2++; + *outptr_inner++ = *inptr3++; + } + } + outptr += 4 * 4; + } + if (k < ICB) { + const int32_t* inptr0 = + reinterpret_cast(in + k * ldin + x0); + const int32_t* inptr1 = + reinterpret_cast(in + (k + 1) * ldin + x0); + const int32_t* inptr2 = + reinterpret_cast(in + (k + 2) * ldin + x0); + const int32_t* inptr3 = + reinterpret_cast(in + (k + 3) * ldin + x0); + int32_t* outptr_inner = outptr; + + int x = x0; + for (; x + 3 < xmax; x += 4) { + if (k + 3 >= ICB) { + switch (k + 3 - ICB) { + case 2: + inptr1 = 
zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr_inner); + outptr_inner += ksize4; + } + if (x < xmax) { + if (k + 3 >= ICB) { + switch (k + 3 - ICB) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + for (; x < xmax; x++) { + *outptr_inner++ = *inptr0++; + *outptr_inner++ = *inptr1++; + *outptr_inner++ = *inptr2++; + *outptr_inner++ = *inptr3++; + } + } + outptr += 4 * 4; + } +} + +} // namespace matmul_4x4x16 +} // namespace aarch64 +} // namespace megdnn +#endif + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/int8/strategy.cpp b/dnn/src/aarch64/matrix_mul/int8/strategy.cpp new file mode 100644 index 00000000..882e8da0 --- /dev/null +++ b/dnn/src/aarch64/matrix_mul/int8/strategy.cpp @@ -0,0 +1,263 @@ +/** + * \file dnn/src/aarch64/matrix_mul/int8/strategy.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#if !(__ARM_FEATURE_DOTPROD) +#include "src/aarch64/matrix_mul/int8/strategy.h" +#include "src/aarch64/matrix_mul/asm/common.h" +#include "src/aarch64/matrix_mul/int8/kernel_4x4x16.h" +#include "src/aarch64/matrix_mul/int8/kernel_8x8x8.h" +#include "src/aarch64/matrix_mul/int8/kernel_mk4_4x4x16.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/utils.h" + +using namespace megdnn; +using namespace aarch64; +using namespace aarch64::matmul; + +///////////////////////// gemm_s8_4x4 //////////////////////////////////// +MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_4x4); + +void gemm_s8_4x4::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin, + int y0, int ymax, int k0, int kmax, + bool transpose) const { + if (transpose) { + matmul_4x4x16::gemm_s8_4x4_pack_B_n(outptr, inptr, ldin, y0, ymax, k0, + kmax); + } else { + matmul_4x4x16::gemm_s8_4x4_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, + kmax); + } +} + +void gemm_s8_4x4::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, + int xmax, int k0, int kmax, bool transpose) const { + if (transpose) { + matmul_4x4x16::gemm_s8_4x4_pack_A_n(out, in, ldin, x0, xmax, k0, kmax); + } else { + matmul_4x4x16::gemm_s8_4x4_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); + } +} + +void gemm_s8_4x4::kern(const dt_int8* packA, const dt_int8* packB, size_t M, + size_t N, size_t K, dt_int32* C, size_t LDC, + bool is_first_k, const dt_int32*, dt_int32*) const { + megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && + ((A_dtype.enumv() == DTypeEnum::Int8 && + C_dtype.enumv() == DTypeEnum::Int32) || + (A_dtype.enumv() == DTypeEnum::QuantizedS8 && + C_dtype.enumv() == DTypeEnum::QuantizedS32)), + "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), + C_dtype.name()); + MEGDNN_MARK_USED_VAR(A_dtype); + MEGDNN_MARK_USED_VAR(B_dtype); + MEGDNN_MARK_USED_VAR(C_dtype); + + constexpr size_t A_INTERLEAVE = 4; + constexpr size_t B_INTERLEAVE = 4; + //! 
K is packed to times of 4 + K = round_up(K, 16); + const int K4 = K * 4; + + size_t m = 0; + for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) { + int32_t* output = C + (m * LDC); + + size_t n = 0; + const dt_int8* cur_packB = packB; + for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { + matmul_4x4x16::kern_4x4(packA, cur_packB, K, output, LDC, + is_first_k); + output += B_INTERLEAVE; + cur_packB += K4; + } + + for (; n < N; n += B_INTERLEAVE) { + matmul_4x4x16::kern_4x4_remain(packA, cur_packB, K, output, LDC, + is_first_k, 4, + std::min(N - n, 4)); + output += B_INTERLEAVE; + cur_packB += K4; + } + + packA += K4; + } + + for (; m < M; m += 4) { + int32_t* output = C + (m * LDC); + + size_t n = 0; + const dt_int8* cur_packB = packB; + for (; n < N; n += B_INTERLEAVE) { + matmul_4x4x16::kern_4x4_remain( + packA, cur_packB, K, output, LDC, is_first_k, + std::min(M - m, 4), std::min(N - n, 4)); + output += B_INTERLEAVE; + cur_packB += K4; + } + packA += K4; + } +} + +///////////////////////// gemm_mk4_s8_4x4 //////////////////////////////////// +MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_mk4_s8_4x4); + +void gemm_mk4_s8_4x4::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin, + int y0, int ymax, int k0, int kmax, + bool transpose) const { + megdnn_assert(!transpose, + "the gemm_mk4_s8_4x4 strategy is not support transpose A"); + matmul_mk4_4x4x16::gemm_mk4_s8_4x4_pack_A(outptr, inptr, ldin, y0, ymax, k0, + kmax); +} + +void gemm_mk4_s8_4x4::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, + int xmax, int k0, int kmax, bool transpose) const { + megdnn_assert(!transpose, + "the gemm_mk4_s8_4x4 strategy is not support transpose B"); + matmul_mk4_4x4x16::gemm_mk4_s8_4x4_pack_B(out, in, ldin, x0, xmax, k0, + kmax); +} + +void gemm_mk4_s8_4x4::kern(const dt_int8* packA, const dt_int8* packB, size_t M, + size_t N, size_t K, dt_int32* C, size_t LDC, + bool is_first_k, const dt_int32*, dt_int32*) const { + megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && + ((A_dtype.enumv() == DTypeEnum::Int8 && + C_dtype.enumv() == DTypeEnum::Int32) || + (A_dtype.enumv() == DTypeEnum::QuantizedS8 && + C_dtype.enumv() == DTypeEnum::QuantizedS32)), + "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), + C_dtype.name()); + MEGDNN_MARK_USED_VAR(A_dtype); + MEGDNN_MARK_USED_VAR(B_dtype); + MEGDNN_MARK_USED_VAR(C_dtype); + + constexpr size_t A_INTERLEAVE = 4; + constexpr size_t B_INTERLEAVE = 4; + //! 
K is packed to times of 4 + megdnn_assert(K % 4 == 0, "K is not time of 4"); + const size_t K4 = round_up(K, 16) * 4; + + size_t m = 0; + for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) { + int32_t* output = C + (m / 4 * LDC); + + size_t n = 0; + const dt_int8* cur_packB = packB; + for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { + matmul_mk4_4x4x16::kern_4x4(packA, cur_packB, K, output, + is_first_k); + output += B_INTERLEAVE * 4; + cur_packB += K4; + } + + if (n < N) { + matmul_mk4_4x4x16::kern_4x4_remain(packA, cur_packB, K, output, + is_first_k, N - n); + } + + packA += K4; + } +} + + +///////////////////////// gemm_s8_8x8 //////////////////////////////////// +MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_8x8); + +void gemm_s8_8x8::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin, + int y0, int ymax, int k0, int kmax, + bool transpose) const { + if (transpose) { + matmul_8x8x8::gemm_s8_8x8_transpose_pack_A_n(outptr, inptr, ldin, y0, + ymax, k0, kmax); + } else { + matmul_8x8x8::gemm_s8_8x8_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, + kmax); + } +} + +void gemm_s8_8x8::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, + int xmax, int k0, int kmax, bool transpose) const { + if (transpose) { + matmul_8x8x8::gemm_s8_8x8_transpose_pack_B_n(out, in, ldin, x0, xmax, + k0, kmax); + } else { + matmul_8x8x8::gemm_s8_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); + } +} + +void gemm_s8_8x8::kern(const dt_int8* packA, const dt_int8* packB, size_t M, + size_t N, size_t K, dt_int32* C, size_t LDC, + bool is_first_k, const dt_int32*, dt_int32*) const { + megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && + ((A_dtype.enumv() == DTypeEnum::Int8 && + C_dtype.enumv() == DTypeEnum::Int32) || + (A_dtype.enumv() == DTypeEnum::QuantizedS8 && + C_dtype.enumv() == DTypeEnum::QuantizedS32)), + "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), + C_dtype.name()); + MEGDNN_MARK_USED_VAR(A_dtype); + MEGDNN_MARK_USED_VAR(B_dtype); + MEGDNN_MARK_USED_VAR(C_dtype); + + constexpr size_t A_INTERLEAVE = 8; + constexpr size_t B_INTERLEAVE = 8; + //! 
K is packed to times of 4 + K = round_up(K, 8); + const int K8 = K * 8; + const int K4 = K * 4; + + size_t m = 0; + for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) { + int32_t* output = C + (m * LDC); + + size_t n = 0; + const dt_int8* cur_packB = packB; + for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { + matmul_8x8x8::kern_8x8(packA, cur_packB, K, output, LDC, + is_first_k); + output += B_INTERLEAVE; + cur_packB += K8; + } + + for (; n < N; n += 4) { + matmul_8x8x8::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k, + std::min(N - n, 4)); + output += 4; + cur_packB += K4; + } + packA += K8; + } + + for (; m < M; m += 4) { + int32_t* output = C + (m * LDC); + const dt_int8* cur_packB = packB; + size_t n = 0; + for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { + matmul_8x8x8::kern_4x8(packA, cur_packB, K, output, LDC, is_first_k, + std::min(M - m, 4)); + output += B_INTERLEAVE; + cur_packB += K8; + } + + for (; n < N; n += 4) { + matmul_8x8x8::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k, + std::min(M - m, 4), + std::min(N - n, 4)); + output += 4; + cur_packB += K4; + } + packA += K4; + } +} +#endif + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/int8/strategy.h b/dnn/src/aarch64/matrix_mul/int8/strategy.h new file mode 100644 index 00000000..933b67e1 --- /dev/null +++ b/dnn/src/aarch64/matrix_mul/int8/strategy.h @@ -0,0 +1,34 @@ +/** + * \file dnn/src/aarch64/matrix_mul/int8/strategy.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once + +#if !(__ARM_FEATURE_DOTPROD) +#include "src/fallback/matrix_mul/gemm_common.h" + +namespace megdnn { +namespace aarch64 { +namespace matmul { + +MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 4, 4, 16, false, true, + gemm_s8_4x4); + +MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 4, 4, 16, false, false, + gemm_mk4_s8_4x4); + +MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 8, 8, 8, false, true, + gemm_s8_8x8); + +} // namespace matmul +} // namespace aarch64 +} // namespace megdnn + +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/int8_dot/gemv.cpp b/dnn/src/aarch64/matrix_mul/int8_dot/gemv.cpp new file mode 100644 index 00000000..cb1ac71b --- /dev/null +++ b/dnn/src/aarch64/matrix_mul/int8_dot/gemv.cpp @@ -0,0 +1,116 @@ +/** + * \file dnn/src/aarch64/matrix_mul/int8_dot/gemv.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#include "src/aarch64/matrix_mul/int8_dot/gemv.h" +#include +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/utils.h" +#include "src/common/unroll_macro.h" + +#if __ARM_FEATURE_DOTPROD + +namespace { + +void gemv_naive_n(const int8_t* __restrict A, const int8_t* __restrict B, + int32_t* __restrict C, size_t M, size_t N, size_t K, + size_t Astride, size_t Bstride, size_t Cstride) { + megdnn_assert(N == 1 && Bstride == 1); + size_t m = 0; + for (; m + 2 <= M; m += 2) { + int32_t acc[4]; + int32x4_t acc_neon = vdupq_n_s32(0); + size_t k = 0; + for (; k + 16 <= K; k += 16) { + int64x2_t a0 = vreinterpretq_s64_s8(vld1q_s8(A + m * Astride + k)); + int64x2_t a1 = + vreinterpretq_s64_s8(vld1q_s8(A + (m + 1) * Astride + k)); + //! the first 8 elements is m, the last 8 elements is m + 1 + int8x16_t a2 = vreinterpretq_s8_s64(vzip1q_s64(a0, a1)); + int8x16_t a3 = vreinterpretq_s8_s64(vzip2q_s64(a0, a1)); + + int64x2_t b0 = vreinterpretq_s64_s8(vld1q_s8(B + k)); + int8x16_t b2 = vreinterpretq_s8_s64(vzip1q_s64(b0, b0)); + int8x16_t b3 = vreinterpretq_s8_s64(vzip2q_s64(b0, b0)); + + acc_neon = vdotq_s32(acc_neon, a2, b2); + acc_neon = vdotq_s32(acc_neon, a3, b3); + } + vst1q_s32(acc, acc_neon); + + for (; k + 8 <= K; k += 8) { + int8x8_t a0 = vld1_s8(A + m * Astride + k); + int8x8_t a1 = vld1_s8(A + (m + 1) * Astride + k); + int8x8_t b0 = vld1_s8(B + k); + uint32x2_t zero = vdup_n_s32(0); + acc[0] += vaddv_s32(vdot_s32(zero, a0, b0)); + zero = vdup_n_s32(0); + acc[3] += vaddv_s32(vdot_s32(zero, a1, b0)); + } + + for (; k < K; ++k) { + acc[0] += static_cast(A[m * Astride + k]) * B[k]; + acc[3] += static_cast(A[(m + 1) * Astride + k]) * B[k]; + } + C[m * Cstride] = acc[0] + acc[1]; + C[(m + 1) * Cstride] = acc[2] + acc[3]; + } + + for (; m < M; ++m) { + int32_t acc[4]; + int32x4_t acc_neon = vdupq_n_s32(0); + size_t k = 0; + for (; k + 16 <= K; k += 16) { + int8x16_t a0 = vld1q_s8(A + m * Astride + k); + int8x16_t b0 = vld1q_s8(B + k); + acc_neon = vdotq_s32(acc_neon, a0, b0); + } + vst1q_s32(acc, acc_neon); + + for (; k + 8 <= K; k += 8) { + int8x8_t a0 = vld1_s8(A + m * Astride + k); + int8x8_t b0 = vld1_s8(B + k); + uint32x2_t zero = vdup_n_s32(0); + acc[0] += vaddv_s32(vdot_s32(zero, a0, b0)); + } + + for (; k < K; ++k) { + acc[0] += static_cast(A[m * Astride + k]) * B[k]; + } + C[m * Cstride] = acc[0] + acc[1] + acc[2] + acc[3]; + } +} + +} // namespace + +bool megdnn::aarch64::matmul::is_gemv_like_preferred_int8( + bool transposeA, bool transposeB, size_t M, size_t N, size_t K, + size_t /* LDA */, size_t LDB, size_t /* LDC */) { + if (transposeA) + return false; + if (transposeB) + return false; + MEGDNN_MARK_USED_VAR(K); + MEGDNN_MARK_USED_VAR(M); + return (N == 1 && LDB == 1); +} + +void megdnn::aarch64::matmul::gemv_like_int8(const int8_t* __restrict A, + const int8_t* __restrict B, + int32_t* __restrict C, size_t M, + size_t N, size_t K, size_t Astride, + size_t Bstride, size_t Cstride) { + megdnn_assert(N == 1); + return gemv_naive_n(A, B, C, M, N, K, Astride, Bstride, Cstride); +} + +#endif + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/int8_dot/gemv.h b/dnn/src/aarch64/matrix_mul/int8_dot/gemv.h new file mode 100644 index 00000000..61041ab1 --- /dev/null +++ b/dnn/src/aarch64/matrix_mul/int8_dot/gemv.h @@ -0,0 +1,34 @@ +/** + * \file dnn/src/aarch64/matrix_mul/int8_dot/gemv.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once + +#include +#include + +#if __ARM_FEATURE_DOTPROD +namespace megdnn { +namespace aarch64 { +namespace matmul { + +bool is_gemv_like_preferred_int8(bool transposeA, bool transposeB, size_t M, + size_t N, size_t K, size_t LDA, size_t LDB, + size_t LDC); + +void gemv_like_int8(const int8_t* __restrict A, const int8_t* __restrict B, + int32_t* __restrict C, size_t M, size_t N, size_t K, + size_t Astride, size_t Bstride, size_t Cstride); + +} // namespace matmul +} // namespace aarch64 +} // namespace megdnn +#endif + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/int8_dot/kernel_8x12x4.h b/dnn/src/aarch64/matrix_mul/int8_dot/kernel_8x12x4.h new file mode 100644 index 00000000..1f08d31b --- /dev/null +++ b/dnn/src/aarch64/matrix_mul/int8_dot/kernel_8x12x4.h @@ -0,0 +1,1552 @@ +/** + * \file dnn/src/aarch64/matrix_mul/int8_dot/kernel_8x12x4.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#if __ARM_FEATURE_DOTPROD + +#include "src/aarch64/matrix_mul/asm/common.h" +#include "src/arm_common/simd_macro/marm_neon.h" + +namespace megdnn { +namespace aarch64 { +namespace matmul_8x12x4 { + +// Overview of register layout: +// +// A 12x4 cell of Rhs is stored in 8bit in q2-q4. +// A 8x4x2 cell of Lhs is stored in 8bit in q0-q1,q5-q6 +// A 8x12 block of accumulators is stored in 8bit in q8--q31. +// +// +--------+--------+--------+ +// |v2[0-16]|v3[0-16]|v4[0-16]| +// Rhs +--------+--------+--------+ +// +// | | | | +// +// Lhs | | | | +// +// +-------+-------+ - - - - +--------+--------+--------+ +// |v0[0-4]|v5[0-4]| | v8[0-4]|v16[0-4]|v24[0-4]| +// |v0[0-4]|v5[0-4]| | v9[0-4]|v17[0-4]|v25[0-4]| +// |v0[0-4]|v5[0-4]| |v10[0-4]|v18[0-4]|v26[0-4]| +// |v0[0-4]|v5[0-4]| |v11[0-4]|v19[0-4]|v27[0-4]| +// |v1[0-4]|v6[0-4]| |v12[0-4]|v20[0-4]|v28[0-4]| +// |v1[0-4]|v6[0-4]| |v13[0-4]|v21[0-4]|v29[0-4]| +// |v1[0-4]|v6[0-4]| |v14[0-4]|v22[0-4]|v30[0-4]| +// |v1[0-4]|v6[0-4]| |v15[0-4]|v23[0-4]|v31[0-4]| +// +-------+-------+ - - - - +--------+--------+--------+ +// +// Accumulator + +/** + * \note The performance of reorder instruction and use prefetch is almost the + * same, I test in kirin980 with small and big core, here i just keep both the + * implementation. + */ +#if 1 +static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, + int32_t* output, int LDC, bool is_first_k) { + K /= 4; + const int8_t* a_ptr = packA; + const int8_t* b_ptr = packB; + // Fix up for odd lengths - set a flag if K is odd, but make + // sure we round up the iteration count. 
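+    // e.g. after "K /= 4" above, K == 5 blocks gives oddk == 1 and
+    // k == ((5 + 1) / 2) - 1 == 2: the main loop runs twice (four blocks) and
+    // the detached odd-K tail below handles the remaining block.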
+ int oddk = (K & 1); + int k = ((K + 1) / 2) - 1; + int32x4_t a0; + int32x4_t a1; + int32x4_t b0; + int32x4_t b1; + int32x4_t b2; + int32x4_t a0a; + int32x4_t a1a; + LDC = LDC * sizeof(int32_t); + + int32_t* outptr0 = output; + int32_t* outptr1; + int32_t* outptr2; + int32_t* outptr3; + int32_t* outptr4; + int32_t* outptr5; + int32_t* outptr6; + int32_t* outptr7; + + asm volatile ( + // load accumulator C + "add %[outptr1], %[outptr0], %x[LDC]\n" + "add %[outptr2], %[outptr1], %x[LDC]\n" + "add %[outptr3], %[outptr2], %x[LDC]\n" + "add %[outptr4], %[outptr3], %x[LDC]\n" + "add %[outptr5], %[outptr4], %x[LDC]\n" + "add %[outptr6], %[outptr5], %x[LDC]\n" + "add %[outptr7], %[outptr6], %x[LDC]\n" + "cmp %w[is_first_k], #1\n" + "beq 5f\n" + // we can not use ld1, as it can not encode {v8, v16, v24} + "ldp q8, q16, [%[outptr0]]\n" + "ldr q24, [%[outptr0], #32]\n" + "ldp q9, q17, [%[outptr1]]\n" + "ldr q25, [%[outptr1], #32]\n" + "ldp q10, q18, [%[outptr2]]\n" + "ldr q26, [%[outptr2], #32]\n" + "ldp q11, q19, [%[outptr3]]\n" + "ldr q27, [%[outptr3], #32]\n" + "ldp q12, q20, [%[outptr4]]\n" + "ldr q28, [%[outptr4], #32]\n" + "ldp q13, q21, [%[outptr5]]\n" + "ldr q29, [%[outptr5], #32]\n" + "ldp q14, q22, [%[outptr6]]\n" + "ldr q30, [%[outptr6], #32]\n" + "ldp q15, q23, [%[outptr7]]\n" + "ldr q31, [%[outptr7], #32]\n" + "b 6f\n" + + "5:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "eor v9.16b, v9.16b, v9.16b\n" + "eor v10.16b, v10.16b, v10.16b\n" + "eor v11.16b, v11.16b, v11.16b\n" + "eor v12.16b, v12.16b, v12.16b\n" + "eor v13.16b, v13.16b, v13.16b\n" + "eor v14.16b, v14.16b, v14.16b\n" + "eor v15.16b, v15.16b, v15.16b\n" + + "eor v16.16b, v16.16b, v16.16b\n" + "eor v17.16b, v17.16b, v17.16b\n" + "eor v18.16b, v18.16b, v18.16b\n" + "eor v19.16b, v19.16b, v19.16b\n" + "eor v20.16b, v20.16b, v20.16b\n" + "eor v21.16b, v21.16b, v21.16b\n" + "eor v22.16b, v22.16b, v22.16b\n" + "eor v23.16b, v23.16b, v23.16b\n" + + "eor v24.16b, v24.16b, v24.16b\n" + "eor v25.16b, v25.16b, v25.16b\n" + "eor v26.16b, v26.16b, v26.16b\n" + "eor v27.16b, v27.16b, v27.16b\n" + "eor v28.16b, v28.16b, v28.16b\n" + "eor v29.16b, v29.16b, v29.16b\n" + "eor v30.16b, v30.16b, v30.16b\n" + "eor v31.16b, v31.16b, v31.16b\n" + + "6: \n" + // Initialize result registers, load initial operands, prime prefetches. + "ldr %q[a0], [%[a_ptr]]\n" + "ldr %q[b0], [%[b_ptr]]\n" + "ldr %q[a1], [%[a_ptr], #16]\n" + "ldr %q[b1], [%[b_ptr], #16]\n" + ASM_PREFETCH("[%[b_ptr], #64]") + ASM_PREFETCH("[%[a_ptr], #64]") + ASM_PREFETCH("[%[b_ptr], #128]") + ASM_PREFETCH("[%[a_ptr], #128]") + ASM_PREFETCH("[%[b_ptr], #192]") + ASM_PREFETCH("[%[b_ptr], #256]") + ASM_PREFETCH("[%[a_ptr], #192]") + ASM_PREFETCH("[%[b_ptr], #320]") + ASM_PREFETCH("[%[a_ptr], #256]") + ASM_PREFETCH("[%[b_ptr], #384]") + + // Skip loop if we are doing zero iterations of it. 
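+            // k == 0 when K (the block count after "K /= 4") is 1 or 2, so the
+            // main loop is skipped and only the detached tail at label 4 runs.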
+ "cbz %w[k], 4f\n" + + // Loop proper + "1:\n" + "sdot v8.4s , %[b0].16b, %[a0].4b[0]\n" + "sdot v9.4s , %[b0].16b, %[a0].4b[1]\n" + + "ldr %q[b2], [%[b_ptr], #32]\n" + "sdot v10.4s, %[b0].16b, %[a0].4b[2]\n" + "sdot v11.4s, %[b0].16b, %[a0].4b[3]\n" + "ldr %q[a0a], [%[a_ptr], #32]\n" + "sdot v12.4s, %[b0].16b, %[a1].4b[0]\n" + "sdot v13.4s, %[b0].16b, %[a1].4b[1]\n" + "ldr %q[a1a], [%[a_ptr], #48]\n" + "sdot v14.4s, %[b0].16b, %[a1].4b[2]\n" + "sdot v15.4s, %[b0].16b, %[a1].4b[3]\n" + "ldr %q[b0], [%[b_ptr], #48]\n" + + "sdot v16.4s, %[b1].16b, %[a0].4b[0]\n" + "sdot v17.4s, %[b1].16b, %[a0].4b[1]\n" + ASM_PREFETCH("[%[a_ptr], #320]") + "sdot v18.4s, %[b1].16b, %[a0].4b[2]\n" + "sdot v19.4s, %[b1].16b, %[a0].4b[3]\n" + "sdot v20.4s, %[b1].16b, %[a1].4b[0]\n" + "sdot v21.4s, %[b1].16b, %[a1].4b[1]\n" + "sdot v22.4s, %[b1].16b, %[a1].4b[2]\n" + "sdot v23.4s, %[b1].16b, %[a1].4b[3]\n" + "ldr %q[b1], [%[b_ptr], #64]\n" + + "sdot v24.4s, %[b2].16b, %[a0].4b[0]\n" + "sdot v25.4s, %[b2].16b, %[a0].4b[1]\n" + ASM_PREFETCH("[%[b_ptr], #448]") + "sdot v26.4s, %[b2].16b, %[a0].4b[2]\n" + "sdot v27.4s, %[b2].16b, %[a0].4b[3]\n" + "sdot v28.4s, %[b2].16b, %[a1].4b[0]\n" + "sdot v29.4s, %[b2].16b, %[a1].4b[1]\n" + "sdot v30.4s, %[b2].16b, %[a1].4b[2]\n" + "sdot v31.4s, %[b2].16b, %[a1].4b[3]\n" + "ldr %q[b2], [%[b_ptr], #80]\n" + + "sdot v8.4s , %[b0].16b, %[a0a].4b[0]\n" + "sdot v9.4s , %[b0].16b, %[a0a].4b[1]\n" + "ldr %q[a0], [%[a_ptr], #64]\n" + "sdot v10.4s, %[b0].16b, %[a0a].4b[2]\n" + "sdot v11.4s, %[b0].16b, %[a0a].4b[3]\n" + "sdot v12.4s, %[b0].16b, %[a1a].4b[0]\n" + "ldr %q[a1], [%[a_ptr], #80]\n" + "sdot v13.4s, %[b0].16b, %[a1a].4b[1]\n" + "sdot v14.4s, %[b0].16b, %[a1a].4b[2]\n" + "sdot v15.4s, %[b0].16b, %[a1a].4b[3]\n" + "ldr %q[b0], [%[b_ptr], #96]\n" + + "sdot v16.4s, %[b1].16b, %[a0a].4b[0]\n" + "sdot v17.4s, %[b1].16b, %[a0a].4b[1]\n" + ASM_PREFETCH("[%[b_ptr], #512]") + "sdot v18.4s, %[b1].16b, %[a0a].4b[2]\n" + "sdot v19.4s, %[b1].16b, %[a0a].4b[3]\n" + "sdot v20.4s, %[b1].16b, %[a1a].4b[0]\n" + "sdot v21.4s, %[b1].16b, %[a1a].4b[1]\n" + "sdot v22.4s, %[b1].16b, %[a1a].4b[2]\n" + "sdot v23.4s, %[b1].16b, %[a1a].4b[3]\n" + "ldr %q[b1], [%[b_ptr], #112]\n" + + "sdot v24.4s, %[b2].16b, %[a0a].4b[0]\n" + "sdot v25.4s, %[b2].16b, %[a0a].4b[1]\n" + "add %[a_ptr], %[a_ptr], #64\n" + "sdot v26.4s, %[b2].16b, %[a0a].4b[2]\n" + "sdot v27.4s, %[b2].16b, %[a0a].4b[3]\n" + "add %[b_ptr], %[b_ptr], #96\n" + "sdot v28.4s, %[b2].16b, %[a1a].4b[0]\n" + "sdot v29.4s, %[b2].16b, %[a1a].4b[1]\n" + "subs %w[k], %w[k], #1\n" + "sdot v30.4s, %[b2].16b, %[a1a].4b[2]\n" + "sdot v31.4s, %[b2].16b, %[a1a].4b[3]\n" + "bne 1b\n" + + // Target to use when K is 1 or 2 (i.e. 
zero iterations of main loop) + "4:\n" + + // Branch to alternative tail for odd K + "cbnz %w[oddk], 2f\n" + + // Detached final iteration (even K) + "sdot v8.4s , %[b0].16b, %[a0].4b[0]\n" + "sdot v9.4s , %[b0].16b, %[a0].4b[1]\n" + "ldr %q[b2], [%[b_ptr], #32]\n" + "sdot v10.4s, %[b0].16b, %[a0].4b[2]\n" + "sdot v11.4s, %[b0].16b, %[a0].4b[3]\n" + "ldr %q[a0a], [%[a_ptr], #32]\n" + "sdot v12.4s, %[b0].16b, %[a1].4b[0]\n" + "sdot v13.4s, %[b0].16b, %[a1].4b[1]\n" + "ldr %q[a1a], [%[a_ptr], #48]\n" + "sdot v14.4s, %[b0].16b, %[a1].4b[2]\n" + "sdot v15.4s, %[b0].16b, %[a1].4b[3]\n" + "ldr %q[b0], [%[b_ptr], #48]\n" + + "sdot v16.4s, %[b1].16b, %[a0].4b[0]\n" + "sdot v17.4s, %[b1].16b, %[a0].4b[1]\n" + "sdot v18.4s, %[b1].16b, %[a0].4b[2]\n" + "sdot v19.4s, %[b1].16b, %[a0].4b[3]\n" + "sdot v20.4s, %[b1].16b, %[a1].4b[0]\n" + "sdot v21.4s, %[b1].16b, %[a1].4b[1]\n" + "sdot v22.4s, %[b1].16b, %[a1].4b[2]\n" + "sdot v23.4s, %[b1].16b, %[a1].4b[3]\n" + "ldr %q[b1], [%[b_ptr], #64]\n" + + "sdot v24.4s, %[b2].16b, %[a0].4b[0]\n" + "sdot v25.4s, %[b2].16b, %[a0].4b[1]\n" + "add %[a_ptr], %[a_ptr], #64\n" + "sdot v26.4s, %[b2].16b, %[a0].4b[2]\n" + "sdot v27.4s, %[b2].16b, %[a0].4b[3]\n" + "sdot v28.4s, %[b2].16b, %[a1].4b[0]\n" + "sdot v29.4s, %[b2].16b, %[a1].4b[1]\n" + "sdot v30.4s, %[b2].16b, %[a1].4b[2]\n" + "sdot v31.4s, %[b2].16b, %[a1].4b[3]\n" + "ldr %q[b2], [%[b_ptr], #80]\n" + + "sdot v8.4s , %[b0].16b, %[a0a].4b[0]\n" + + "sdot v16.4s, %[b1].16b, %[a0a].4b[0]\n" + "add %[b_ptr], %[b_ptr], #96\n" + "sdot v9.4s , %[b0].16b, %[a0a].4b[1]\n" + "str q8, [%[outptr0], #0]\n" + "sdot v17.4s, %[b1].16b, %[a0a].4b[1]\n" + "str q16, [%[outptr0], #16]\n" + "sdot v24.4s, %[b2].16b, %[a0a].4b[0]\n" + "str q24, [%[outptr0], #32]\n" + + "sdot v25.4s, %[b2].16b, %[a0a].4b[1]\n" + "str q9, [%[outptr1], #0]\n" + "sdot v10.4s, %[b0].16b, %[a0a].4b[2]\n" + "str q17, [%[outptr1], #16]\n" + "sdot v18.4s, %[b1].16b, %[a0a].4b[2]\n" + "str q25, [%[outptr1], #32]\n" + "sdot v26.4s, %[b2].16b, %[a0a].4b[2]\n" + "str q10, [%[outptr2], #0]\n" + + "sdot v11.4s, %[b0].16b, %[a0a].4b[3]\n" + "str q18, [%[outptr2], #16]\n" + "sdot v19.4s, %[b1].16b, %[a0a].4b[3]\n" + "str q26, [%[outptr2], #32]\n" + "sdot v27.4s, %[b2].16b, %[a0a].4b[3]\n" + "str q11, [%[outptr3], #0]\n" + + "sdot v12.4s, %[b0].16b, %[a1a].4b[0]\n" + "str q19, [%[outptr3], #16]\n" + "sdot v20.4s, %[b1].16b, %[a1a].4b[0]\n" + "str q27, [%[outptr3], #32]\n" + "sdot v28.4s, %[b2].16b, %[a1a].4b[0]\n" + "str q12, [%[outptr4], #0]\n" + + "sdot v13.4s, %[b0].16b, %[a1a].4b[1]\n" + "str q20, [%[outptr4], #16]\n" + "sdot v21.4s, %[b1].16b, %[a1a].4b[1]\n" + "str q28, [%[outptr4], #32]\n" + "sdot v29.4s, %[b2].16b, %[a1a].4b[1]\n" + "str q13, [%[outptr5], #0]\n" + + "sdot v14.4s, %[b0].16b, %[a1a].4b[2]\n" + "str q21, [%[outptr5], #16]\n" + "sdot v22.4s, %[b1].16b, %[a1a].4b[2]\n" + "str q29, [%[outptr5], #32]\n" + "sdot v30.4s, %[b2].16b, %[a1a].4b[2]\n" + "str q14, [%[outptr6], #0]\n" + + "sdot v15.4s, %[b0].16b, %[a1a].4b[3]\n" + "str q22, [%[outptr6], #16]\n" + "sdot v23.4s, %[b1].16b, %[a1a].4b[3]\n" + "str q30, [%[outptr6], #32]\n" + "sdot v31.4s, %[b2].16b, %[a1a].4b[3]\n" + "str q15, [%[outptr7], #0]\n" + + "b 3f\n" + + // Detached final iteration (odd K) + "2:\n" + "sdot v8.4s , %[b0].16b, %[a0].4b[0]\n" + "ldr %q[b2], [%[b_ptr], #32]\n" + "sdot v16.4s, %[b1].16b, %[a0].4b[0]\n" + "sdot v9.4s , %[b0].16b, %[a0].4b[1]\n" + "str q8, [%[outptr0], #0]\n" + "sdot v17.4s, %[b1].16b, %[a0].4b[1]\n" + "str q16, [%[outptr0], #16]\n" + "sdot v24.4s, %[b2].16b, 
%[a0].4b[0]\n" + "add %[b_ptr], %[b_ptr], #48\n" + "add %[a_ptr], %[a_ptr], #32\n" + "str q24, [%[outptr0], #32]\n" + "sdot v25.4s, %[b2].16b, %[a0].4b[1]\n" + "str q9, [%[outptr1], #0]\n" + + "sdot v10.4s, %[b0].16b, %[a0].4b[2]\n" + "str q17, [%[outptr1], #16]\n" + "sdot v18.4s, %[b1].16b, %[a0].4b[2]\n" + "str q25, [%[outptr1], #32]\n" + "sdot v26.4s, %[b2].16b, %[a0].4b[2]\n" + "str q10, [%[outptr2], #0]\n" + + "sdot v11.4s, %[b0].16b, %[a0].4b[3]\n" + "str q18, [%[outptr2], #16]\n" + "sdot v19.4s, %[b1].16b, %[a0].4b[3]\n" + "str q26, [%[outptr2], #32]\n" + "sdot v27.4s, %[b2].16b, %[a0].4b[3]\n" + "str q11, [%[outptr3], #0]\n" + + "sdot v12.4s, %[b0].16b, %[a1].4b[0]\n" + "str q19, [%[outptr3], #16]\n" + "sdot v20.4s, %[b1].16b, %[a1].4b[0]\n" + "str q27, [%[outptr3], #32]\n" + "sdot v28.4s, %[b2].16b, %[a1].4b[0]\n" + "str q12, [%[outptr4], #0]\n" + + "sdot v13.4s, %[b0].16b, %[a1].4b[1]\n" + "str q20, [%[outptr4], #16]\n" + "sdot v21.4s, %[b1].16b, %[a1].4b[1]\n" + "str q28, [%[outptr4], #32]\n" + "sdot v29.4s, %[b2].16b, %[a1].4b[1]\n" + "str q13, [%[outptr5], #0]\n" + + "sdot v14.4s, %[b0].16b, %[a1].4b[2]\n" + "str q21, [%[outptr5], #16]\n" + "sdot v22.4s, %[b1].16b, %[a1].4b[2]\n" + "str q29, [%[outptr5], #32]\n" + "sdot v30.4s, %[b2].16b, %[a1].4b[2]\n" + "str q14, [%[outptr6], #0]\n" + + "sdot v15.4s, %[b0].16b, %[a1].4b[3]\n" + "str q22, [%[outptr6], #16]\n" + "sdot v23.4s, %[b1].16b, %[a1].4b[3]\n" + "str q30, [%[outptr6], #32]\n" + "sdot v31.4s, %[b2].16b, %[a1].4b[3]\n" + "str q15, [%[outptr7], #0]\n" + + + // Common tail + "3:\n" + "str q23, [%[outptr7], #16]\n" + "str q31, [%[outptr7], #32]\n" + : + [a_ptr] "+r" (a_ptr), [b_ptr] "+r" (b_ptr),[oddk] "+r" (oddk), + [is_first_k] "+r" (is_first_k), [k] "+r" (k), [LDC] "+r" (LDC), + [a0] "=w" (a0), [a1] "=w" (a1), [a0a] "=w" (a0a), [a1a] "=w" (a1a), + [b0] "=w" (b0), [b1] "=w" (b1), [b2] "=w" (b2), + [outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), + [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), + [outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), + [outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7) + : + : "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", + "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", + "memory" + ); +} +#else +static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, + int32_t* output, int LDC, bool is_first_k) { + K /= 4; + const int8_t* a_ptr = packA; + const int8_t* b_ptr = packB; + // Fix up for odd lengths - set a flag if K is odd, but make + // sure we round up the iteration count. 
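+    // In this variant the odd block (if any) is consumed up front in the
+    // "parse the oddk" section after label 2, so k == K / 2 counts the
+    // remaining block pairs; e.g. K == 5 blocks gives oddk == 1 and k == 2.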
+ int oddk = (K & 1); + int k = K / 2; + + int32x4_t a0; + int32x4_t a1; + int32x4_t b0; + int32x4_t b1; + int32x4_t b2; + int32x4_t a0a; + int32x4_t a1a; + LDC = LDC * sizeof(int32_t); + + int32_t* outptr0 = output; + int32_t* outptr1; + int32_t* outptr2; + int32_t* outptr3; + int32_t* outptr4; + int32_t* outptr5; + int32_t* outptr6; + int32_t* outptr7; + + asm volatile( + // load accumulator C + "add %[outptr1], %[outptr0], %x[LDC]\n" + "add %[outptr2], %[outptr1], %x[LDC]\n" + "add %[outptr3], %[outptr2], %x[LDC]\n" + "add %[outptr4], %[outptr3], %x[LDC]\n" + "add %[outptr5], %[outptr4], %x[LDC]\n" + "add %[outptr6], %[outptr5], %x[LDC]\n" + "add %[outptr7], %[outptr6], %x[LDC]\n" + "cmp %w[is_first_k], #1\n" + "beq 1f\n" + // we can not use ld1, as it can not encode {v8, v16, v24} + "ldp q8, q16, [%[outptr0]]\n" + "ldr q24, [%[outptr0], #32]\n" + "ldp q9, q17, [%[outptr1]]\n" + "ldr q25, [%[outptr1], #32]\n" + "ldp q10, q18, [%[outptr2]]\n" + "ldr q26, [%[outptr2], #32]\n" + "ldp q11, q19, [%[outptr3]]\n" + "ldr q27, [%[outptr3], #32]\n" + "ldp q12, q20, [%[outptr4]]\n" + "ldr q28, [%[outptr4], #32]\n" + "ldp q13, q21, [%[outptr5]]\n" + "ldr q29, [%[outptr5], #32]\n" + "ldp q14, q22, [%[outptr6]]\n" + "ldr q30, [%[outptr6], #32]\n" + "ldp q15, q23, [%[outptr7]]\n" + "ldr q31, [%[outptr7], #32]\n" + "b 2f\n" + + "1:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "eor v9.16b, v9.16b, v9.16b\n" + "eor v10.16b, v10.16b, v10.16b\n" + "eor v11.16b, v11.16b, v11.16b\n" + "eor v12.16b, v12.16b, v12.16b\n" + "eor v13.16b, v13.16b, v13.16b\n" + "eor v14.16b, v14.16b, v14.16b\n" + "eor v15.16b, v15.16b, v15.16b\n" + + "eor v16.16b, v16.16b, v16.16b\n" + "eor v17.16b, v17.16b, v17.16b\n" + "eor v18.16b, v18.16b, v18.16b\n" + "eor v19.16b, v19.16b, v19.16b\n" + "eor v20.16b, v20.16b, v20.16b\n" + "eor v21.16b, v21.16b, v21.16b\n" + "eor v22.16b, v22.16b, v22.16b\n" + "eor v23.16b, v23.16b, v23.16b\n" + + "eor v24.16b, v24.16b, v24.16b\n" + "eor v25.16b, v25.16b, v25.16b\n" + "eor v26.16b, v26.16b, v26.16b\n" + "eor v27.16b, v27.16b, v27.16b\n" + "eor v28.16b, v28.16b, v28.16b\n" + "eor v29.16b, v29.16b, v29.16b\n" + "eor v30.16b, v30.16b, v30.16b\n" + "eor v31.16b, v31.16b, v31.16b\n" + + "2: \n" + "cbz %w[oddk], 3f\n" + // parse the oddk + "ldr %q[a0], [%[a_ptr]], #16\n" + "ldr %q[a1], [%[a_ptr]], #16\n" + "ldr %q[b0], [%[b_ptr]], #16\n" + "ldr %q[b1], [%[b_ptr]], #16\n" + "ldr %q[b2], [%[b_ptr]], #16\n" + "sdot v8.4s, %[b0].16b, %[a0].4b[0]\n" + "sdot v9.4s, %[b0].16b, %[a0].4b[1]\n" + "sdot v10.4s, %[b0].16b, %[a0].4b[2]\n" + "sdot v11.4s, %[b0].16b, %[a0].4b[3]\n" + "sdot v12.4s, %[b0].16b, %[a1].4b[0]\n" + "sdot v13.4s, %[b0].16b, %[a1].4b[1]\n" + "sdot v14.4s, %[b0].16b, %[a1].4b[2]\n" + "sdot v15.4s, %[b0].16b, %[a1].4b[3]\n" + "sdot v16.4s, %[b1].16b, %[a0].4b[0]\n" + "sdot v17.4s, %[b1].16b, %[a0].4b[1]\n" + "sdot v18.4s, %[b1].16b, %[a0].4b[2]\n" + "sdot v19.4s, %[b1].16b, %[a0].4b[3]\n" + "sdot v20.4s, %[b1].16b, %[a1].4b[0]\n" + "sdot v21.4s, %[b1].16b, %[a1].4b[1]\n" + "sdot v22.4s, %[b1].16b, %[a1].4b[2]\n" + "sdot v23.4s, %[b1].16b, %[a1].4b[3]\n" + "sdot v24.4s, %[b2].16b, %[a0].4b[0]\n" + "sdot v25.4s, %[b2].16b, %[a0].4b[1]\n" + "sdot v26.4s, %[b2].16b, %[a0].4b[2]\n" + "sdot v27.4s, %[b2].16b, %[a0].4b[3]\n" + "sdot v28.4s, %[b2].16b, %[a1].4b[0]\n" + "sdot v29.4s, %[b2].16b, %[a1].4b[1]\n" + "sdot v30.4s, %[b2].16b, %[a1].4b[2]\n" + "sdot v31.4s, %[b2].16b, %[a1].4b[3]\n" + + "cbz %w[k], 4f\n" + // Loop proper + "3:\n" + "ldr %q[a0], [%[a_ptr]], #16\n" + "ldr %q[a1], [%[a_ptr]], #16\n" + 
"ldr %q[a0a], [%[a_ptr]], #16\n" + "ldr %q[a1a], [%[a_ptr]], #16\n" + "ldr %q[b0], [%[b_ptr]], #16\n" + "ldr %q[b1], [%[b_ptr]], #16\n" + "ldr %q[b2], [%[b_ptr]], #16\n" + "sdot v8.4s, %[b0].16b, %[a0].4b[0]\n" + "sdot v9.4s, %[b0].16b, %[a0].4b[1]\n" + "sdot v10.4s, %[b0].16b, %[a0].4b[2]\n" + "sdot v11.4s, %[b0].16b, %[a0].4b[3]\n" + "sdot v12.4s, %[b0].16b, %[a1].4b[0]\n" + "sdot v13.4s, %[b0].16b, %[a1].4b[1]\n" + "sdot v14.4s, %[b0].16b, %[a1].4b[2]\n" + "sdot v15.4s, %[b0].16b, %[a1].4b[3]\n" + "sdot v16.4s, %[b1].16b, %[a0].4b[0]\n" + "sdot v17.4s, %[b1].16b, %[a0].4b[1]\n" + "sdot v18.4s, %[b1].16b, %[a0].4b[2]\n" + "sdot v19.4s, %[b1].16b, %[a0].4b[3]\n" + "sdot v20.4s, %[b1].16b, %[a1].4b[0]\n" + "sdot v21.4s, %[b1].16b, %[a1].4b[1]\n" + "sdot v22.4s, %[b1].16b, %[a1].4b[2]\n" + "sdot v23.4s, %[b1].16b, %[a1].4b[3]\n" + "sdot v24.4s, %[b2].16b, %[a0].4b[0]\n" + "sdot v25.4s, %[b2].16b, %[a0].4b[1]\n" + "sdot v26.4s, %[b2].16b, %[a0].4b[2]\n" + "sdot v27.4s, %[b2].16b, %[a0].4b[3]\n" + "sdot v28.4s, %[b2].16b, %[a1].4b[0]\n" + "sdot v29.4s, %[b2].16b, %[a1].4b[1]\n" + "sdot v30.4s, %[b2].16b, %[a1].4b[2]\n" + "sdot v31.4s, %[b2].16b, %[a1].4b[3]\n" + "ldr %q[b0], [%[b_ptr]], #16\n" + "ldr %q[b1], [%[b_ptr]], #16\n" + "ldr %q[b2], [%[b_ptr]], #16\n" + "sdot v8.4s, %[b0].16b, %[a0a].4b[0]\n" + "sdot v9.4s, %[b0].16b, %[a0a].4b[1]\n" + "sdot v10.4s, %[b0].16b, %[a0a].4b[2]\n" + "sdot v11.4s, %[b0].16b, %[a0a].4b[3]\n" + "sdot v12.4s, %[b0].16b, %[a1a].4b[0]\n" + "sdot v13.4s, %[b0].16b, %[a1a].4b[1]\n" + "sdot v14.4s, %[b0].16b, %[a1a].4b[2]\n" + "sdot v15.4s, %[b0].16b, %[a1a].4b[3]\n" + "sdot v16.4s, %[b1].16b, %[a0a].4b[0]\n" + "sdot v17.4s, %[b1].16b, %[a0a].4b[1]\n" + "sdot v18.4s, %[b1].16b, %[a0a].4b[2]\n" + "sdot v19.4s, %[b1].16b, %[a0a].4b[3]\n" + "sdot v20.4s, %[b1].16b, %[a1a].4b[0]\n" + "sdot v21.4s, %[b1].16b, %[a1a].4b[1]\n" + "sdot v22.4s, %[b1].16b, %[a1a].4b[2]\n" + "sdot v23.4s, %[b1].16b, %[a1a].4b[3]\n" + "sdot v24.4s, %[b2].16b, %[a0a].4b[0]\n" + "sdot v25.4s, %[b2].16b, %[a0a].4b[1]\n" + "sdot v26.4s, %[b2].16b, %[a0a].4b[2]\n" + "sdot v27.4s, %[b2].16b, %[a0a].4b[3]\n" + "sdot v28.4s, %[b2].16b, %[a1a].4b[0]\n" + "sdot v29.4s, %[b2].16b, %[a1a].4b[1]\n" + "sdot v30.4s, %[b2].16b, %[a1a].4b[2]\n" + "sdot v31.4s, %[b2].16b, %[a1a].4b[3]\n" + + "subs %w[k], %w[k], #1\n" + "bne 3b\n" + + "4:\n" + "stp q8, q16, [%[outptr0]]\n" + "str q24, [%[outptr0], #32]\n" + "stp q9, q17, [%[outptr1]]\n" + "str q25, [%[outptr1], #32]\n" + "stp q10, q18, [%[outptr2]]\n" + "str q26, [%[outptr2], #32]\n" + "stp q11, q19, [%[outptr3]]\n" + "str q27, [%[outptr3], #32]\n" + "stp q12, q20, [%[outptr4]]\n" + "str q28, [%[outptr4], #32]\n" + "stp q13, q21, [%[outptr5]]\n" + "str q29, [%[outptr5], #32]\n" + "stp q14, q22, [%[outptr6]]\n" + "str q30, [%[outptr6], #32]\n" + "stp q15, q23, [%[outptr7]]\n" + "str q31, [%[outptr7], #32]\n" + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [a0] "+w"(a0), + [a1] "+w"(a1), [a0a] "+w"(a0a), [a1a] "+w"(a1a), [b0] "+w"(b0), + [b1] "+w"(b1), [b2] "+w"(b2), [k] "+r"(k), [LDC] "+r"(LDC), + [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), + [outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), + [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), + [outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), + [outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7) + : + : "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", + "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory"); +} + +#endif + +// 
Overview of register layout: +// +// A 12x4 cell of Rhs is stored in 8bit in q2-q4. +// A 8x4x2 cell of Lhs is stored in 8bit in q0-q1,q5-q6 +// A 8x12 block of accumulators is stored in 8bit in q8--q31. +// +// +--------+--------+--------+ +// |v1[0-16]|v2[0-16]|v3[0-16]| +// Rhs +--------+--------+--------+ +// |v5[0-16]|v6[0-16]|v7[0-16]| +// +--------+--------+--------+ +// +// | | | | +// +// Lhs | | | | +// +// +-------+-------+ - - - - +--------+--------+--------+ +// |v0[0-4]|v4[0-4]| | v8[0-4]|v12[0-4]|v16[0-4]| +// |v0[0-4]|v4[0-4]| | v9[0-4]|v13[0-4]|v17[0-4]| +// |v0[0-4]|v4[0-4]| |v10[0-4]|v14[0-4]|v18[0-4]| +// |v0[0-4]|v4[0-4]| |v11[0-4]|v15[0-4]|v19[0-4]| +// +-------+-------+ - - - - +--------+--------+--------+ +// +// Accumulator + +static void kern_4x12(const int8_t* packA, const int8_t* packB, int K, + int32_t* output, int LDC, bool is_first_k, int m_remain) { + K /= 4; + const int8_t* a_ptr = packA; + const int8_t* b_ptr = packB; + // Fix up for odd lengths - set a flag if K is odd, but make + // sure we round up the iteration count. + int oddk = (K & 1); + int k = K / 2; + int32x4_t a0; + int32x4_t b0; + int32x4_t b1; + int32x4_t b2; + int32x4_t a0a; + int32x4_t b0a; + int32x4_t b1a; + int32x4_t b2a; + + LDC = LDC * sizeof(int32_t); + int32_t* outptr0 = output; + int32_t* outptr1; + int32_t* outptr2; + int32_t* outptr3; + size_t x0; + +// clang-format off +#define LOAD_LINE(v1, v2, v3, m) \ + "cbz %[x0], 100f\n" \ + "ldp " v1 "," v2 ", [%[outptr" m "]]\n" \ + "ldr " v3 ", [%[outptr" m "], #32]\n" \ + "subs %[x0], %[x0], #1\n" + +#define LOAD_C \ + "mov %[x0], %x[m_remain]\n" \ + LOAD_LINE("q8", "q12", "q16", "0") \ + LOAD_LINE("q9", "q13", "q17", "1") \ + LOAD_LINE("q10", "q14", "q18", "2") \ + LOAD_LINE("q11", "q15", "q19", "3") \ + "100:\n" + +#define STORE_LINE(v1, v2, v3, m) \ + "cbz %[x0], 101f\n" \ + "stp " v1 "," v2", [%[outptr" m "]]\n" \ + "str " v3 ", [%[outptr" m "], #32]\n" \ + "subs %[x0], %[x0], #1\n" + +#define STORE_C \ + "mov %[x0], %x[m_remain]\n" \ + STORE_LINE("q8", "q12", "q16", "0") \ + STORE_LINE("q9", "q13", "q17", "1") \ + STORE_LINE("q10", "q14", "q18", "2") \ + STORE_LINE("q11", "q15", "q19", "3") \ + "101:\n" + + // clang-format on + + asm volatile( + // load accumulator C + "add %[outptr1], %[outptr0], %x[LDC]\n" + "add %[outptr2], %[outptr1], %x[LDC]\n" + "add %[outptr3], %[outptr2], %x[LDC]\n" + "cmp %w[is_first_k], #1\n" + "beq 1f\n" LOAD_C + + "b 2f\n" + + "1:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "eor v9.16b, v9.16b, v9.16b\n" + "eor v10.16b, v10.16b, v10.16b\n" + "eor v11.16b, v11.16b, v11.16b\n" + "eor v12.16b, v12.16b, v12.16b\n" + "eor v13.16b, v13.16b, v13.16b\n" + "eor v14.16b, v14.16b, v14.16b\n" + "eor v15.16b, v15.16b, v15.16b\n" + "eor v16.16b, v16.16b, v16.16b\n" + "eor v17.16b, v17.16b, v17.16b\n" + "eor v18.16b, v18.16b, v18.16b\n" + "eor v19.16b, v19.16b, v19.16b\n" + + "2: \n" + "cbz %w[oddk], 3f\n" + + // parse the oddk + "ldr %q[a0], [%[a_ptr]], #16\n" + "ldr %q[b0], [%[b_ptr]], #16\n" + "ldr %q[b1], [%[b_ptr]], #16\n" + "ldr %q[b2], [%[b_ptr]], #16\n" + "sdot v8.4s, %[b0].16b, %[a0].4b[0]\n" + "sdot v9.4s, %[b0].16b, %[a0].4b[1]\n" + "sdot v10.4s, %[b0].16b, %[a0].4b[2]\n" + "sdot v11.4s, %[b0].16b, %[a0].4b[3]\n" + "sdot v12.4s, %[b1].16b, %[a0].4b[0]\n" + "sdot v13.4s, %[b1].16b, %[a0].4b[1]\n" + "sdot v14.4s, %[b1].16b, %[a0].4b[2]\n" + "sdot v15.4s, %[b1].16b, %[a0].4b[3]\n" + "sdot v16.4s, %[b2].16b, %[a0].4b[0]\n" + "sdot v17.4s, %[b2].16b, %[a0].4b[1]\n" + "sdot v18.4s, %[b2].16b, %[a0].4b[2]\n" + "sdot 
v19.4s, %[b2].16b, %[a0].4b[3]\n" + + "cbz %w[k], 4f\n" + // Loop proper + "3:\n" + "ldr %q[a0], [%[a_ptr]], #16\n" + "ldr %q[b0], [%[b_ptr]], #16\n" + "ldr %q[b1], [%[b_ptr]], #16\n" + "ldr %q[b2], [%[b_ptr]], #16\n" + "ldr %q[a0a], [%[a_ptr]], #16\n" + "ldr %q[b0a], [%[b_ptr]], #16\n" + "ldr %q[b1a], [%[b_ptr]], #16\n" + "ldr %q[b2a], [%[b_ptr]], #16\n" + + "sdot v8.4s, %[b0].16b, %[a0].4b[0]\n" + "sdot v9.4s, %[b0].16b, %[a0].4b[1]\n" + "sdot v10.4s, %[b0].16b, %[a0].4b[2]\n" + "sdot v11.4s, %[b0].16b, %[a0].4b[3]\n" + "sdot v12.4s, %[b1].16b, %[a0].4b[0]\n" + "sdot v13.4s, %[b1].16b, %[a0].4b[1]\n" + "sdot v14.4s, %[b1].16b, %[a0].4b[2]\n" + "sdot v15.4s, %[b1].16b, %[a0].4b[3]\n" + "sdot v16.4s, %[b2].16b, %[a0].4b[0]\n" + "sdot v17.4s, %[b2].16b, %[a0].4b[1]\n" + "sdot v18.4s, %[b2].16b, %[a0].4b[2]\n" + "sdot v19.4s, %[b2].16b, %[a0].4b[3]\n" + "sdot v8.4s , %[b0a].16b, %[a0a].4b[0]\n" + "sdot v9.4s , %[b0a].16b, %[a0a].4b[1]\n" + "sdot v10.4s, %[b0a].16b, %[a0a].4b[2]\n" + "sdot v11.4s, %[b0a].16b, %[a0a].4b[3]\n" + "sdot v12.4s, %[b1a].16b, %[a0a].4b[0]\n" + "sdot v13.4s, %[b1a].16b, %[a0a].4b[1]\n" + "sdot v14.4s, %[b1a].16b, %[a0a].4b[2]\n" + "sdot v15.4s, %[b1a].16b, %[a0a].4b[3]\n" + "sdot v16.4s, %[b2a].16b, %[a0a].4b[0]\n" + "sdot v17.4s, %[b2a].16b, %[a0a].4b[1]\n" + "sdot v18.4s, %[b2a].16b, %[a0a].4b[2]\n" + "sdot v19.4s, %[b2a].16b, %[a0a].4b[3]\n" + + "subs %w[k], %w[k], #1\n" + "bne 3b\n" + + "4:\n" STORE_C + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [k] "+r"(k), + [outptr0] "+r"(outptr0), [oddk] "+r"(oddk), + [is_first_k] "+r"(is_first_k), [m_remain] "+r"(m_remain), + [LDC] "+r"(LDC), [a0] "=w"(a0), [a0a] "=w"(a0a), [b0] "=w"(b0), + [b1] "=w"(b1), [b2] "=w"(b2), [b0a] "=w"(b0a), [b1a] "=w"(b1a), + [b2a] "=w"(b2a), [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), + [outptr3] "=r"(outptr3), [x0] "=r"(x0) + : + : "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", + "v17", "v18", "v19", "memory", "cc"); + +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +} + +// Overview of register layout: +// +// A (4x4)x2 cell of Rhs is stored in 8bit in q2-q3. +// A 4x4x2 cell of Lhs is stored in 8bit in q0-q1, q4-a5 +// A 8x4 block of accumulators is stored in 8bit in q4--q7. +// +// +--------+ +// |v2[0-16]| +// Rhs +--------+ +// |v3[0-16]| +// +--------+ +// | | +// +// Lhs | | +// +// +-------+-------+ - - - - +--------+ +// |v0[0-4]|v4[0-4]| | v6[0-4]| +// |v0[0-4]|v4[0-4]| | v7[0-4]| +// |v0[0-4]|v4[0-4]| | v8[0-4]| +// |v0[0-4]|v4[0-4]| | v9[0-4]| +// |v1[0-4]|v5[0-4]| |v10[0-4]| +// |v1[0-4]|v5[0-4]| |v11[0-4]| +// |v1[0-4]|v5[0-4]| |v12[0-4]| +// |v1[0-4]|v5[0-4]| |v13[0-4]| +// +-------+-------+ - - - - +---------+ +// +// Accumulator + +static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, + int32_t* output, int LDC, bool is_first_k, int n_remain) { + K /= 4; + const int8_t* a_ptr = packA; + const int8_t* b_ptr = packB; + // Fix up for odd lengths - set a flag if K is odd, but make + // sure we round up the iteration count. 
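+ // e.g. for an original depth of 20 int8 values, K /= 4 above leaves K = 5
+ // dot-product groups, so oddk = 1 and k = 2: the prologue after "2:" consumes the
+ // single odd group and each pass of the loop at "3:" consumes two more groups.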
+ int oddk = (K & 1); + int k = K / 2; + int32x4_t a0; + int32x4_t a1; + int32x4_t b0; + int32x4_t b0a; + int32x4_t a0a; + int32x4_t a1a; + + LDC = LDC * sizeof(int32_t); + int32_t* outptr0 = output; + int32_t* outptr1; + int32_t* outptr2; + int32_t* outptr3; + int32_t* outptr4; + int32_t* outptr5; + int32_t* outptr6; + int32_t* outptr7; + + size_t x0; + +// clang-format off +#define LOAD_LINE(reg_index, n) \ + "mov %[x0], %[outptr" n "]\n" \ + "cmp %w[n_remain], #4\n" \ + "blt 100" n "f\n" \ + "ldr q" reg_index ", [%[x0]] \n" \ + "b 101" n "f\n" \ + "100" n ":\n" \ + "cmp %w[n_remain], #0\n" \ + "beq 101" n "f\n" \ + "ld1 {v" reg_index ".s}[0], [%[x0]], #4\n" \ + "cmp %w[n_remain], #1\n" \ + "beq 101" n "f\n" \ + "ld1 {v" reg_index ".s}[1], [%[x0]], #4\n" \ + "cmp %w[n_remain], #2\n" \ + "beq 101" n "f\n" \ + "ld1 {v" reg_index ".s}[2], [%[x0]], #4\n" \ + "101" n ":\n" + + +#define LOAD_C \ + LOAD_LINE("6", "0") \ + LOAD_LINE("7", "1") \ + LOAD_LINE("8", "2") \ + LOAD_LINE("9", "3") \ + LOAD_LINE("10", "4") \ + LOAD_LINE("11", "5") \ + LOAD_LINE("12", "6") \ + LOAD_LINE("13", "7") + +#define STORE_LINE(reg_index, n) \ + "mov %[x0], %[outptr" n "]\n" \ + "cmp %w[n_remain], #4\n" \ + "blt 102" n "f\n" \ + "str q" reg_index ", [%[x0]]\n" \ + "b 103" n "f\n" \ + "102" n ":\n" \ + "cmp %w[n_remain], #0\n" \ + "beq 103" n "f\n" \ + "st1 {v" reg_index ".s}[0], [%[x0]], #4\n" \ + "cmp %w[n_remain], #1\n" \ + "beq 103" n "f\n" \ + "st1 {v" reg_index ".s}[1], [%[x0]], #4\n" \ + "cmp %w[n_remain], #2\n" \ + "beq 103" n "f\n" \ + "st1 {v" reg_index ".s}[2], [%[x0]], #4\n" \ + "103" n ":\n" + +#define STORE_C \ + STORE_LINE("6", "0") \ + STORE_LINE("7", "1") \ + STORE_LINE("8", "2") \ + STORE_LINE("9", "3") \ + STORE_LINE("10", "4") \ + STORE_LINE("11", "5") \ + STORE_LINE("12", "6") \ + STORE_LINE("13", "7") + + // clang-format on + + asm volatile( + // load accumulator C + "add %[outptr1], %[outptr0], %x[LDC]\n" + "add %[outptr2], %[outptr1], %x[LDC]\n" + "add %[outptr3], %[outptr2], %x[LDC]\n" + "add %[outptr4], %[outptr3], %x[LDC]\n" + "add %[outptr5], %[outptr4], %x[LDC]\n" + "add %[outptr6], %[outptr5], %x[LDC]\n" + "add %[outptr7], %[outptr6], %x[LDC]\n" + "cmp %w[is_first_k], #1\n" + "beq 1f\n" LOAD_C + + "b 2f\n" + + "1:\n" + "eor v6.16b, v6.16b, v6.16b\n" + "eor v7.16b, v7.16b, v7.16b\n" + "eor v8.16b, v8.16b, v8.16b\n" + "eor v9.16b, v9.16b, v9.16b\n" + "eor v10.16b, v10.16b, v10.16b\n" + "eor v11.16b, v11.16b, v11.16b\n" + "eor v12.16b, v12.16b, v12.16b\n" + "eor v13.16b, v13.16b, v13.16b\n" + + "2: \n" + "cbz %w[oddk], 3f\n" + + // parse the oddk + "ldr %q[a0], [%[a_ptr]], #16\n" + "ldr %q[b0], [%[b_ptr]], #16\n" + "ldr %q[a1], [%[a_ptr]], #16\n" + "sdot v6.4s , %[b0].16b, %[a0].4b[0]\n" + "sdot v7.4s , %[b0].16b, %[a0].4b[1]\n" + "sdot v8.4s, %[b0].16b, %[a0].4b[2]\n" + "sdot v9.4s, %[b0].16b, %[a0].4b[3]\n" + "sdot v10.4s, %[b0].16b, %[a1].4b[0]\n" + "sdot v11.4s, %[b0].16b, %[a1].4b[1]\n" + "sdot v12.4s, %[b0].16b, %[a1].4b[2]\n" + "sdot v13.4s, %[b0].16b, %[a1].4b[3]\n" + + "cbz %w[k], 4f\n" + // Loop proper + "3:\n" + "ldr %q[a0], [%[a_ptr]], #16\n" + "ldr %q[b0], [%[b_ptr]], #16\n" + "ldr %q[a1], [%[a_ptr]], #16\n" + "ldr %q[a0a], [%[a_ptr]], #16\n" + "ldr %q[a1a], [%[a_ptr]], #16\n" + "ldr %q[b0a], [%[b_ptr]], #16\n" + "sdot v6.4s , %[b0].16b, %[a0].4b[0]\n" + "sdot v7.4s , %[b0].16b, %[a0].4b[1]\n" + "sdot v8.4s, %[b0].16b, %[a0].4b[2]\n" + "sdot v9.4s, %[b0].16b, %[a0].4b[3]\n" + "sdot v10.4s, %[b0].16b, %[a1].4b[0]\n" + "sdot v11.4s, %[b0].16b, %[a1].4b[1]\n" + "sdot v12.4s, 
%[b0].16b, %[a1].4b[2]\n" + "sdot v13.4s, %[b0].16b, %[a1].4b[3]\n" + + "sdot v6.4s , %[b0a].16b, %[a0a].4b[0]\n" + "sdot v7.4s , %[b0a].16b, %[a0a].4b[1]\n" + "sdot v8.4s, %[b0a].16b, %[a0a].4b[2]\n" + "sdot v9.4s, %[b0a].16b, %[a0a].4b[3]\n" + "sdot v10.4s, %[b0a].16b, %[a1a].4b[0]\n" + "sdot v11.4s, %[b0a].16b, %[a1a].4b[1]\n" + "sdot v12.4s, %[b0a].16b, %[a1a].4b[2]\n" + "sdot v13.4s, %[b0a].16b, %[a1a].4b[3]\n" + + "subs %w[k], %w[k], #1\n" + "bne 3b\n" + + "4:\n" STORE_C + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [LDC] "+r"(LDC), + [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), + [n_remain] "+r"(n_remain), [k] "+r"(k), [outptr0] "+r"(outptr0), + [a0] "=w"(a0), [a1] "=w"(a1), [a0a] "=w"(a0a), [a1a] "=w"(a1a), + [b0] "=w"(b0), [b0a] "=w"(b0a), [outptr1] "=r"(outptr1), + [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), + [outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), + [outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7), [x0] "=r"(x0) + : + : "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "memory", + "cc"); + +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +} + +// Overview of register layout: +// +// A 4x4x2 cell of Rhs is stored in 8bit in q2-q3. +// A 4x4x2 cell of Lhs is stored in 8bit in q0-q1 +// A 4x4x2 block of accumulators is stored in 8bit in q4--q7. +// +// +--------+ +// | v2[0-7]| +// Rhs +--------+ +// | v3[0-7]| +// +--------+ +// | | +// +// Lhs | | +// +// +-------+-------+ - - - - +--------+ +// |v0[0-4]|v1[0-4]| | v4[0-7]| +// |v0[0-4]|v1[0-4]| | v5[0-7]| +// |v0[0-4]|v1[0-4]| | v6[0-7]| +// |v0[0-4]|v1[0-4]| | v7[0-7]| +// +-------+-------+ - - - - +--------+ +// +// Accumulator + +static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, + int32_t* output, int LDC, bool is_first_k, int m_remain, + int n_remain) { + K /= 4; + const int32_t* a_ptr = reinterpret_cast<const int32_t*>(packA); + const int32_t* b_ptr = reinterpret_cast<const int32_t*>(packB); + // Fix up for odd lengths - set a flag if K is odd, but make + // sure we round up the iteration count.
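+ // This kernel handles output tiles where both the row count (m_remain) and the
+ // column count (n_remain) may be less than 4; the LOAD_C/STORE_C macros below
+ // touch only that many rows and lanes, e.g. with m_remain = 3 and n_remain = 2
+ // only rows 0-2 (v4-v6) and lanes [0]-[1] are read and written.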
+ int oddk = (K & 1); + int k = K / 2; + int32x4_t a0; + int32x4_t a0a; + int32x4_t b0; + int32x4_t b0a; + LDC = LDC * sizeof(int32_t); + + int32_t* outptr0 = output; + int32_t* outptr1; + int32_t* outptr2; + int32_t* outptr3; + size_t x0, x1; + +// clang-format off +#define LOAD_LINE(reg_index, n) \ + "cbz %[x1], 102f\n" \ + "mov %[x0], %[outptr" n "]\n" \ + "cmp %w[n_remain], #4\n" \ + "blt 100" n "f\n" \ + "ldr q" reg_index ", [%[x0]]\n" \ + "b 101" n "f\n" \ + "100" n ":\n" \ + "cmp %w[n_remain], #0\n" \ + "beq 101" n "f\n" \ + "ld1 {v" reg_index ".s}[0], [%[x0]], #4\n" \ + "cmp %w[n_remain], #1\n" \ + "beq 101" n "f\n" \ + "ld1 {v" reg_index ".s}[1], [%[x0]], #4\n" \ + "cmp %w[n_remain], #2\n" \ + "beq 101" n "f\n" \ + "ld1 {v" reg_index ".s}[2], [%[x0]], #4\n" \ + "101" n ":\n" \ + "subs %[x1], %[x1], #1\n" + +#define LOAD_C \ + "mov %[x1], %x[m_remain]\n" \ + LOAD_LINE("4", "0") \ + LOAD_LINE("5", "1") \ + LOAD_LINE("6", "2") \ + LOAD_LINE("7", "3") \ + "102:\n" + +#define STORE_LINE(reg_index, n) \ + "cbz %[x1], 105f\n" \ + "mov %[x0], %[outptr" n "]\n" \ + "cmp %w[n_remain], #4\n" \ + "blt 103" n "f\n" \ + "str q" reg_index ", [%[x0]]\n" \ + "b 104" n "f\n" \ + "103" n ":\n" \ + "cmp %w[n_remain], #0\n" \ + "beq 104" n "f\n" \ + "st1 {v" reg_index ".s}[0], [%[x0]], #4\n" \ + "cmp %w[n_remain], #1\n" \ + "beq 104" n "f\n" \ + "st1 {v" reg_index ".s}[1], [%[x0]], #4\n" \ + "cmp %w[n_remain], #2\n" \ + "beq 104" n "f\n" \ + "st1 {v" reg_index ".s}[2], [%[x0]], #4\n" \ + "104" n ":\n" \ + "subs %[x1], %[x1], #1\n" + +#define STORE_C \ + "mov %[x1], %x[m_remain]\n" \ + STORE_LINE("4", "0") \ + STORE_LINE("5", "1") \ + STORE_LINE("6", "2") \ + STORE_LINE("7", "3") \ + "105:\n" + + // clang-format on + + asm volatile( + // load accumulator C + "add %[outptr1], %[outptr0], %x[LDC]\n" + "add %[outptr2], %[outptr1], %x[LDC]\n" + "add %[outptr3], %[outptr2], %x[LDC]\n" + "cmp %w[is_first_k], #1\n" + "beq 1f\n" // + LOAD_C // + + "b 2f\n" + + "1:\n" + "eor v4.16b, v4.16b, v4.16b\n" + "eor v5.16b, v5.16b, v5.16b\n" + "eor v6.16b, v6.16b, v6.16b\n" + "eor v7.16b, v7.16b, v7.16b\n" + + "2: \n" + "cbz %w[oddk], 3f\n" + + // parse the oddk + "ldr %q[a0], [%[a_ptr]], #16\n" + "ldr %q[b0], [%[b_ptr]], #16\n" + "sdot v4.4s , %[b0].16b, %[a0].4b[0]\n" + "sdot v5.4s , %[b0].16b, %[a0].4b[1]\n" + "sdot v6.4s, %[b0].16b, %[a0].4b[2]\n" + "sdot v7.4s, %[b0].16b, %[a0].4b[3]\n" + + "cbz %w[k], 4f\n" + // Loop proper + "3:\n" + "ldr %q[a0], [%[a_ptr]], #16\n" + "ldr %q[b0], [%[b_ptr]], #16\n" + "ldr %q[a0a], [%[a_ptr]], #16\n" + "ldr %q[b0a], [%[b_ptr]], #16\n" + "sdot v4.4s , %[b0].16b, %[a0].4b[0]\n" + "sdot v5.4s , %[b0].16b, %[a0].4b[1]\n" + "sdot v6.4s, %[b0].16b, %[a0].4b[2]\n" + "sdot v7.4s, %[b0].16b, %[a0].4b[3]\n" + "sdot v4.4s , %[b0a].16b, %[a0a].4b[0]\n" + "sdot v5.4s , %[b0a].16b, %[a0a].4b[1]\n" + "sdot v6.4s, %[b0a].16b, %[a0a].4b[2]\n" + "sdot v7.4s, %[b0a].16b, %[a0a].4b[3]\n" + + "subs %w[k], %w[k], #1\n" + "bne 3b\n" + + "4:\n" STORE_C + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [oddk] "+r"(oddk), + [is_first_k] "+r"(is_first_k), [n_remain] "+r"(n_remain), + [m_remain] "+r"(m_remain), [LDC] "+r"(LDC), + [outptr0] "+r"(outptr0), [k] "+r"(k), [a0] "=w"(a0), + [a0a] "=w"(a0a), [b0] "=w"(b0), [b0a] "=w"(b0a), + [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), + [outptr3] "=r"(outptr3), [x0] "=r"(x0), [x1] "=r"(x1) + : + : "v4", "v5", "v6", "v7", "memory", "cc"); + +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +} + +static void gemm_s8_8x12_pack_A_n(dt_int8* outptr, 
const dt_int8* inptr, + int ldin, int y0, int ymax, int k0, + int kmax) { + int8_t zerobuff[16]; + std::memset(zerobuff, 0, sizeof(int8_t) * 16); + + int y = y0; + for (; y + 7 < ymax; y += 8) { + const int8_t* inptr0 = inptr + y * ldin + k0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + const int8_t* inptr4 = inptr3 + ldin; + const int8_t* inptr5 = inptr4 + ldin; + const int8_t* inptr6 = inptr5 + ldin; + const int8_t* inptr7 = inptr6 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + prefetch_2x(inptr4); + prefetch_2x(inptr5); + prefetch_2x(inptr6); + prefetch_2x(inptr7); + + int K = kmax - k0; + //! read 8 * 4 in each row + for (; K > 15; K -= 16) { + interleave_8x4_4_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, + inptr6, inptr7, outptr); + } + + if (K > 0) { + interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, + inptr7, outptr, 4, K); + } + } + for (; y < ymax; y += 4) { + const int8_t* inptr0 = inptr + y * ldin + k0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + + int K = kmax - k0; + //! read 4 * 4 in each row + for (; K > 15; K -= 16) { + if (y + 3 >= ymax) { + switch (y + 3 - ymax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + interleave_4x4_4_b(inptr0, inptr1, inptr2, inptr3, outptr); + } + + if (K > 0) { + if (y + 3 >= ymax) { + switch (y + 3 - ymax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 4, K); + } + } +} + +static void gemm_s8_8x12_pack_A_t(dt_int8* out, const dt_int8* in, int ldin, + int x0, int xmax, int k0, int kmax) { + int8_t zerobuff[16]; + std::memset(zerobuff, 0, sizeof(int8_t) * 16); + const int ksize = kmax - k0; + const int ksize8 = round_up(ksize, 4) * 8; + const int ksize4 = round_up(ksize, 4) * 4; + int8_t* outptr = out; + int8_t* outptr_base = out; + //! 
4x4 block output start pos + int8_t* outptr_base4 = out + ((xmax - x0) / 8) * ksize8; + + int k = k0; + for (; k < kmax; k += 4) { + const int8_t* inptr0 = in + k * ldin + x0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + + int x = x0; + outptr = outptr_base; + for (; x + 7 < xmax; x += 8) { + if (k + 3 >= kmax) { + switch (k + 3 - kmax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + transpose_8x4_1_b(inptr0, inptr1, inptr2, inptr3, outptr); + outptr += ksize8; + } + + outptr = outptr_base4; + for (; x + 3 < xmax; x += 4) { + if (k + 3 >= kmax) { + switch (k + 3 - kmax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + transpose_4(inptr0, inptr1, inptr2, inptr3, outptr, 4, 4); + outptr += ksize4; + } + + if (x < xmax) { + if (k + 3 >= kmax) { + switch (k + 3 - kmax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + transpose_4(inptr0, inptr1, inptr2, inptr3, outptr, 4, xmax - x); + } + + outptr_base += 8 * 4; + outptr_base4 += 4 * 4; + } +} + +static void gemm_s8_8x12_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, + int x0, int xmax, int k0, int kmax) { + int8_t zerobuff[16]; + std::memset(zerobuff, 0, sizeof(int8_t) * 16); + const int ksize = kmax - k0; + const int ksize12 = round_up(ksize, 4) * 12; + const int ksize4 = round_up(ksize, 4) * 4; + int8_t* outptr = out; + int8_t* outptr_base = out; + //! 
4x4 block output start pos + int8_t* outptr_base4 = out + ((xmax - x0) / 12) * ksize12; + + int k = k0; + for (; k < kmax; k += 4) { + const int8_t* inptr0 = in + k * ldin + x0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + + int x = x0; + outptr = outptr_base; + for (; x + 11 < xmax; x += 12) { + if (k + 3 >= kmax) { + switch (k + 3 - kmax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + transpose_12x4_1_b(inptr0, inptr1, inptr2, inptr3, outptr); + outptr += ksize12; + } + + outptr = outptr_base4; + for (; x + 3 < xmax; x += 4) { + if (k + 3 >= kmax) { + switch (k + 3 - kmax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + transpose_4(inptr0, inptr1, inptr2, inptr3, outptr, 4, 4); + outptr += ksize4; + } + + if (x < xmax) { + if (k + 3 >= kmax) { + switch (k + 3 - kmax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + transpose_4(inptr0, inptr1, inptr2, inptr3, outptr, 4, xmax - x); + } + + outptr_base += 12 * 4; + outptr_base4 += 4 * 4; + } +} + +static void gemm_s8_8x12_pack_B_t(dt_int8* outptr, const dt_int8* inptr, + int ldin, int y0, int ymax, int k0, + int kmax) { + int8_t zerobuff[16]; + std::memset(zerobuff, 0, sizeof(int8_t) * 16); + + int y = y0; + for (; y + 11 < ymax; y += 12) { + const int8_t* inptr0 = inptr + y * ldin + k0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + const int8_t* inptr4 = inptr3 + ldin; + const int8_t* inptr5 = inptr4 + ldin; + const int8_t* inptr6 = inptr5 + ldin; + const int8_t* inptr7 = inptr6 + ldin; + const int8_t* inptr8 = inptr7 + ldin; + const int8_t* inptr9 = inptr8 + ldin; + const int8_t* inptr10 = inptr9 + ldin; + const int8_t* inptr11 = inptr10 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + prefetch_2x(inptr4); + prefetch_2x(inptr5); + prefetch_2x(inptr6); + prefetch_2x(inptr7); + prefetch_2x(inptr8); + prefetch_2x(inptr9); + prefetch_2x(inptr10); + prefetch_2x(inptr11); + + int K = kmax - k0; + //! read 12 * 4 in each row + for (; K > 15; K -= 16) { + interleave_12x4_4_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, + inptr6, inptr7, inptr8, inptr9, inptr10, + inptr11, outptr); + } + + if (K > 0) { + interleave_12(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, + inptr6, inptr7, inptr8, inptr9, inptr10, inptr11, + outptr, 4, K); + } + } + for (; y < ymax; y += 4) { + const int8_t* inptr0 = inptr + y * ldin + k0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + + int K = kmax - k0; + //! 
read 4 * 4 in each row + for (; K > 15; K -= 16) { + if (y + 3 >= ymax) { + switch (y + 3 - ymax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + interleave_4x4_4_b(inptr0, inptr1, inptr2, inptr3, outptr); + } + + if (K > 0) { + if (y + 3 >= ymax) { + switch (y + 3 - ymax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 4, K); + } + } +} + +} // namespace matmul_8x12x4 +} // namespace aarch64 +} // namespace megdnn + +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/int8_dot/strategy.cpp b/dnn/src/aarch64/matrix_mul/int8_dot/strategy.cpp new file mode 100644 index 00000000..ca3d0061 --- /dev/null +++ b/dnn/src/aarch64/matrix_mul/int8_dot/strategy.cpp @@ -0,0 +1,113 @@ +/** + * \file dnn/src/aarch64/matrix_mul/int8_dot/strategy.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "src/aarch64/matrix_mul/int8_dot/strategy.h" +#include "src/aarch64/matrix_mul/asm/common.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/utils.h" +#include "src/aarch64/matrix_mul/int8_dot/kernel_8x12x4.h" + +#if __ARM_FEATURE_DOTPROD +using namespace megdnn; +using namespace aarch64; +using namespace aarch64::matmul; + +MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_8x12); + +void gemm_s8_8x12::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin, + int y0, int ymax, int k0, int kmax, + bool transpose) const { + if (transpose) { + matmul_8x12x4::gemm_s8_8x12_pack_A_t(outptr, inptr, ldin, y0, ymax, k0, + kmax); + } else { + matmul_8x12x4::gemm_s8_8x12_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, + kmax); + } +} + +void gemm_s8_8x12::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, + int xmax, int k0, int kmax, bool transpose) const { + if (transpose) { + matmul_8x12x4::gemm_s8_8x12_pack_B_t(out, in, ldin, x0, xmax, k0, kmax); + } else { + matmul_8x12x4::gemm_s8_8x12_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); + } +} + +void gemm_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB, size_t M, + size_t N, size_t K, dt_int32* C, size_t LDC, + bool is_first_k, const dt_int32*, dt_int32*) const { + megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && + ((A_dtype.enumv() == DTypeEnum::Int8 && + C_dtype.enumv() == DTypeEnum::Int32) || + (A_dtype.enumv() == DTypeEnum::QuantizedS8 && + C_dtype.enumv() == DTypeEnum::QuantizedS32)), + "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), + C_dtype.name()); + + MEGDNN_MARK_USED_VAR(A_dtype); + MEGDNN_MARK_USED_VAR(B_dtype); + MEGDNN_MARK_USED_VAR(C_dtype); + + constexpr size_t A_INTERLEAVE = 8; + constexpr size_t B_INTERLEAVE = 12; + //! 
K is packed to times of 4 + K = round_up(K, 4); + const int K8 = (K << 3); + const int K12 = K * 12; + const int K4 = K * 4; + + size_t m = 0; + for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) { + int32_t* output = C + (m * LDC); + + size_t n = 0; + const dt_int8* cur_packB = packB; + for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { + matmul_8x12x4::kern_8x12(packA, cur_packB, K, output, LDC, + is_first_k); + output += B_INTERLEAVE; + cur_packB += K12; + } + + for (; n < N; n += 4) { + matmul_8x12x4::kern_8x4(packA, cur_packB, K, output, LDC, + is_first_k, std::min(N - n, 4)); + output += 4; + cur_packB += K4; + } + packA += K8; + } + + for (; m < M; m += 4) { + int32_t* output = C + (m * LDC); + const dt_int8* cur_packB = packB; + size_t n = 0; + for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { + matmul_8x12x4::kern_4x12(packA, cur_packB, K, output, LDC, + is_first_k, std::min(M - m, 4)); + output += B_INTERLEAVE; + cur_packB += K12; + } + + for (; n < N; n += 4) { + matmul_8x12x4::kern_4x4(packA, cur_packB, K, output, LDC, + is_first_k, std::min(M - m, 4), + std::min(N - n, 4)); + output += 4; + cur_packB += K4; + } + packA += K4; + } +} +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/int8_dot/strategy.h b/dnn/src/aarch64/matrix_mul/int8_dot/strategy.h new file mode 100644 index 00000000..c633d9dc --- /dev/null +++ b/dnn/src/aarch64/matrix_mul/int8_dot/strategy.h @@ -0,0 +1,26 @@ +/** + * \file dnn/src/aarch64/matrix_mul/int8_dot/strategy.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once +#include "src/fallback/matrix_mul/gemm_common.h" + +#if __ARM_FEATURE_DOTPROD +namespace megdnn { +namespace aarch64 { +namespace matmul { + +MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 8, 12, 4, false, true, + gemm_s8_8x12); + +} // namespace aarch64 +} // namespace matmul +} // namespace megdnn +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/int8x8x16/kernel_4x4x16.h b/dnn/src/aarch64/matrix_mul/int8x8x16/kernel_4x4x16.h new file mode 100644 index 00000000..c91daa93 --- /dev/null +++ b/dnn/src/aarch64/matrix_mul/int8x8x16/kernel_4x4x16.h @@ -0,0 +1,439 @@ +/** + * \file dnn/src/aarch64/matrix_mul/int8x8x16/kernel_4x4x16.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#include +#include "src/aarch64/matrix_mul/asm/common.h" +#include "src/arm_common/simd_macro/marm_neon.h" + +namespace megdnn { +namespace aarch64 { +namespace matmul_4x4x16 { + +/** + * Overview of register layout: + * + * +---------+---------+---------+---------+ + * |v20[0-15]|v21[0-15]|v22[0-15]|v23[0-15]| + * Rhs +---------+---------+---------+---------+ + * Lhs | | | + * + * +--------+ - - - - +---------+---------+---------+---------+ + * |v0[0-15]| | v4[0-8] | v8[0-8]| v12[0-8]| v16[0-8]| + * |v1[0-15]| | v5[0-8] | v9[0-8]| v13[0-8]| v17[0-8]| + * |v2[0-15]| | v6[0-8] | v10[0-8]| v14[0-8]| v18[0-8]| + * |v3[0-15]| | v7[0-8] | v11[0-8]| v15[0-8]| v19[0-8]| + * +--------+ - - - - +---------+---------+---------+---------+ + * + * Accumulator + */ +static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, + int16_t* output, int LDC, bool is_first_k, int m_remain, + int n_remain) { + K /= 16; + const int8_t* a_ptr = packA; + const int8_t* b_ptr = packB; + + LDC = LDC * sizeof(int16_t); +// clang-format off +#define LOAD_LINE(reg_index, n) \ + "cmp x5, #0 \n" \ + "beq 105f\n" \ + "cmp %w[n_remain], #4\n" \ + "blt 100" n "f\n" \ + "ld1 {v" reg_index ".4h}, [x" n "], #8\n" \ + "b 101" n "f\n" \ + "100" n ":\n" \ + "cmp %w[n_remain], #0\n" \ + "blt 101" n "f\n" \ + "ld1 {v" reg_index ".h}[0], [x" n "], #2\n" \ + "cmp %w[n_remain], #1\n" \ + "beq 101" n "f\n" \ + "ld1 {v" reg_index ".h}[1], [x" n "], #2\n" \ + "cmp %w[n_remain], #2\n" \ + "beq 101" n "f\n" \ + "ld1 {v" reg_index ".h}[2], [x" n "], #2\n" \ + "101" n ":\n" \ + "sub x5, x5, #1\n" + +#define LOAD_C \ + "mov x5, %x[m_remain]\n" \ + LOAD_LINE("24", "0") \ + LOAD_LINE("25", "1") \ + LOAD_LINE("26", "2") \ + LOAD_LINE("27", "3") \ + "105:\n" + + +#define STORE_LINE(reg_index, n) \ + "cmp x5, #0 \n" \ + "beq 105f\n" \ + "cmp %w[n_remain], #4\n" \ + "blt 102" n "f\n" \ + "st1 {v" reg_index ".4h}, [x" n "], #8\n" \ + "b 103" n "f\n" \ + "102" n ":\n" \ + "cmp %w[n_remain], #0\n" \ + "beq 103" n "f\n" \ + "st1 {v" reg_index ".h}[0], [x" n "], #2\n" \ + "cmp %w[n_remain], #1\n" \ + "beq 103" n "f\n" \ + "st1 {v" reg_index ".h}[1], [x" n "], #2\n" \ + "cmp %w[n_remain], #2\n" \ + "beq 103" n "f\n" \ + "st1 {v" reg_index ".h}[2], [x" n "], #2\n" \ + "103" n ":\n" \ + "sub x5, x5, #1\n" + +#define STORE_C \ + "mov x5, %x[m_remain]\n" \ + STORE_LINE("24", "0") \ + STORE_LINE("25", "1") \ + STORE_LINE("26", "2") \ + STORE_LINE("27", "3") \ + "105:\n" + // clang-format on + + register int16_t* outptr asm("x0") = output; + asm volatile( + "add x1, x0, %x[LDC]\n" + "add x2, x1, %x[LDC]\n" + "add x3, x2, %x[LDC]\n" + + // Clear accumulators + "eor v4.16b, v4.16b, v4.16b\n" + "eor v5.16b, v5.16b, v5.16b\n" + "eor v6.16b, v6.16b, v6.16b\n" + "eor v7.16b, v7.16b, v7.16b\n" + "eor v8.16b, v8.16b, v8.16b\n" + "eor v9.16b, v9.16b, v9.16b\n" + "eor v10.16b, v10.16b, v10.16b\n" + "eor v11.16b, v11.16b, v11.16b\n" + "eor v12.16b, v12.16b, v12.16b\n" + "eor v13.16b, v13.16b, v13.16b\n" + "eor v14.16b, v14.16b, v14.16b\n" + "eor v15.16b, v15.16b, v15.16b\n" + "eor v16.16b, v16.16b, v16.16b\n" + "eor v17.16b, v17.16b, v17.16b\n" + "eor v18.16b, v18.16b, v18.16b\n" + "eor v19.16b, v19.16b, v19.16b\n" + + // General loop. 
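+ // Each pass of this loop consumes 16 values along K for every row/column pair:
+ // smlal multiplies the low 8 int8 lanes and smlal2 the high 8, accumulating
+ // int16 partial sums in v4-v19 that are reduced with addv into the 4x4 output
+ // tile (v24-v27) after the loop.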
+ "1:\n" + "ld1 {v20.16b}, [%[b_ptr]], 16\n" + "ld1 {v0.16b}, [%[a_ptr]], 16\n" + "ld1 {v1.16b}, [%[a_ptr]], 16\n" + "ld1 {v2.16b}, [%[a_ptr]], 16\n" + "ld1 {v3.16b}, [%[a_ptr]], 16\n" + + "ld1 {v21.16b}, [%[b_ptr]], 16\n" + "smlal v4.8h, v0.8b, v20.8b\n" + "smlal v5.8h, v1.8b, v20.8b\n" + "smlal v6.8h, v2.8b, v20.8b\n" + "smlal v7.8h, v3.8b, v20.8b\n" + "smlal2 v4.8h, v0.16b, v20.16b\n" + "smlal2 v5.8h, v1.16b, v20.16b\n" + "smlal2 v6.8h, v2.16b, v20.16b\n" + "smlal2 v7.8h, v3.16b, v20.16b\n" + + "ld1 {v22.16b}, [%[b_ptr]], 16\n" + "smlal v8.8h, v0.8b, v21.8b\n" + "smlal v9.8h, v1.8b, v21.8b\n" + "smlal v10.8h, v2.8b, v21.8b\n" + "smlal v11.8h, v3.8b, v21.8b\n" + "smlal2 v8.8h, v0.16b, v21.16b\n" + "smlal2 v9.8h, v1.16b, v21.16b\n" + "smlal2 v10.8h, v2.16b, v21.16b\n" + "smlal2 v11.8h, v3.16b, v21.16b\n" + + "ld1 {v23.16b}, [%[b_ptr]], 16\n" + "smlal v12.8h, v0.8b, v22.8b\n" + "smlal v13.8h, v1.8b, v22.8b\n" + "smlal v14.8h, v2.8b, v22.8b\n" + "smlal v15.8h, v3.8b, v22.8b\n" + "smlal2 v12.8h, v0.16b, v22.16b\n" + "smlal2 v13.8h, v1.16b, v22.16b\n" + "smlal2 v14.8h, v2.16b, v22.16b\n" + "smlal2 v15.8h, v3.16b, v22.16b\n" + + "smlal v16.8h, v0.8b, v23.8b\n" + "smlal v17.8h, v1.8b, v23.8b\n" + "smlal v18.8h, v2.8b, v23.8b\n" + "smlal v19.8h, v3.8b, v23.8b\n" + "smlal2 v16.8h, v0.16b, v23.16b\n" + "smlal2 v17.8h, v1.16b, v23.16b\n" + "smlal2 v18.8h, v2.16b, v23.16b\n" + "smlal2 v19.8h, v3.16b, v23.16b\n" + + "subs %w[K], %w[K], #1\n" + "cbnz %w[K], 1b\n" + + "cmp %w[is_first_k], #1\n" + "beq 2f\n" LOAD_C + "b 3f\n" + + "2:\n" // Clear the C regs. + "eor v24.16b, v24.16b, v24.16b\n" + "eor v25.16b, v25.16b, v25.16b\n" + "eor v26.16b, v26.16b, v26.16b\n" + "eor v27.16b, v27.16b, v27.16b\n" + + "3:\n" + // Reduce v4-v19 to v0-v3 + "addv h20, v4.8h\n" + "addv h21, v8.8h\n" + "addv h22, v12.8h\n" + "addv h23, v16.8h\n" + "ins v0.h[0], v20.h[0]\n" + "ins v0.h[1], v21.h[0]\n" + "ins v0.h[2], v22.h[0]\n" + "ins v0.h[3], v23.h[0]\n" + "add v24.4h, v24.4h, v0.4h\n" + + "addv h28, v5.8h\n" + "addv h29, v9.8h\n" + "addv h30, v13.8h\n" + "addv h31, v17.8h\n" + "ins v1.h[0], v28.h[0]\n" + "ins v1.h[1], v29.h[0]\n" + "ins v1.h[2], v30.h[0]\n" + "ins v1.h[3], v31.h[0]\n" + "add v25.4h, v25.4h, v1.4h\n" + + "addv h20, v6.8h\n" + "addv h21, v10.8h\n" + "addv h22, v14.8h\n" + "addv h23, v18.8h\n" + "ins v2.h[0], v20.h[0]\n" + "ins v2.h[1], v21.h[0]\n" + "ins v2.h[2], v22.h[0]\n" + "ins v2.h[3], v23.h[0]\n" + "add v26.4h, v26.4h, v2.4h\n" + + "addv h28, v7.8h\n" + "addv h29, v11.8h\n" + "addv h30, v15.8h\n" + "addv h31, v19.8h\n" + "ins v3.h[0], v28.h[0]\n" + "ins v3.h[1], v29.h[0]\n" + "ins v3.h[2], v30.h[0]\n" + "ins v3.h[3], v31.h[0]\n" + "add v27.4h, v27.4h, v3.4h\n" + + // Store back into memory + STORE_C + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), + [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), + [outptr] "+r"(outptr), [m_remain] "+r"(m_remain), + [n_remain] "+r"(n_remain) + : + : "cc", "memory", "x1", "x2", "x3", "x4", "x5", "v0", "v1", "v2", + "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", + "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", + "v31"); + +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +} + +static void gemm_s8x8x16_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr, + int ldin, int y0, int ymax, int k0, + int kmax) { + int8_t zerobuff[16]; + std::memset(zerobuff, 0, sizeof(int8_t) * 16); + + int y = y0; + for (; y + 3 < ymax; y += 4) { + const int8_t* inptr0 
= inptr + y * ldin + k0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + + int K = kmax - k0; + //! read 4 * 16 in each row + for (; K > 15; K -= 16) { + interleave_4x16_1_b(inptr0, inptr1, inptr2, inptr3, outptr); + } + + if (K > 0) { + interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 16, K); + } + } + for (; y < ymax; y += 4) { + const int8_t* inptr0 = inptr + y * ldin + k0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + + int K = kmax - k0; + //! read 4 * 16 in each row + for (; K > 15; K -= 16) { + if (y + 3 >= ymax) { + switch (y + 3 - ymax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + interleave_4x16_1_b(inptr0, inptr1, inptr2, inptr3, outptr); + } + + if (K > 0) { + if (y + 3 >= ymax) { + switch (y + 3 - ymax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 16, K); + } + } +} + +static void gemm_s8x8x16_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, + int x0, int xmax, int k0, int kmax) { + int8_t zerobuff[16]; + std::memset(zerobuff, 0, sizeof(int8_t) * 16); + const int ksize = kmax - k0; + const int ksize4 = round_up(ksize, 16) * 4; + int8_t* outptr = out; + + int k = k0; + for (; k < kmax; k += 16) { + int ki = k; + for (int cnt = 0; cnt < 2; ki += 8, cnt++) { + const int8_t* inptr0 = in + ki * ldin + x0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + const int8_t* inptr4 = inptr3 + ldin; + const int8_t* inptr5 = inptr4 + ldin; + const int8_t* inptr6 = inptr5 + ldin; + const int8_t* inptr7 = inptr6 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + prefetch_2x(inptr4); + prefetch_2x(inptr5); + prefetch_2x(inptr6); + prefetch_2x(inptr7); + + int8_t* outptr_inner = outptr + ki - k; + + int remain = std::min(ki + 7 - kmax, 7); + int x = x0; + for (; x + 3 < xmax; x += 4) { + if (remain >= 0) { + switch (remain) { + case 7: + inptr0 = zerobuff; + case 6: + inptr1 = zerobuff; + case 5: + inptr2 = zerobuff; + case 4: + inptr3 = zerobuff; + case 3: + inptr4 = zerobuff; + case 2: + inptr5 = zerobuff; + case 1: + inptr6 = zerobuff; + case 0: + inptr7 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + transpose_4x16_1_b_helper(inptr0, inptr1, inptr2, inptr3, + inptr4, inptr5, inptr6, inptr7, + outptr_inner); + outptr_inner += ksize4; + } + + if (x < xmax) { + if (remain >= 0) { + switch (remain) { + case 7: + inptr0 = zerobuff; + case 6: + inptr1 = zerobuff; + case 5: + inptr2 = zerobuff; + case 4: + inptr3 = zerobuff; + case 3: + inptr4 = zerobuff; + case 2: + inptr5 = zerobuff; + case 1: + inptr6 = zerobuff; + case 0: + inptr7 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + for (; x < xmax; x++) { + *outptr_inner++ = *inptr0++; + *outptr_inner++ = *inptr1++; + *outptr_inner++ = *inptr2++; + *outptr_inner++ = *inptr3++; + *outptr_inner++ = *inptr4++; + *outptr_inner++ = *inptr5++; + *outptr_inner++ = *inptr6++; + *outptr_inner++ = *inptr7++; + 
outptr_inner += 8; + } + } + } + + outptr += 16 * 4; + } +} + +} // namespace matmul_4x4x16 +} // namespace aarch64 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/int8x8x16/kernel_8x8x8.h b/dnn/src/aarch64/matrix_mul/int8x8x16/kernel_8x8x8.h new file mode 100644 index 00000000..98f925c1 --- /dev/null +++ b/dnn/src/aarch64/matrix_mul/int8x8x16/kernel_8x8x8.h @@ -0,0 +1,1300 @@ +/** + * \file dnn/src/aarch64/matrix_mul/int8x8x16/kernel_8x8x8.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include +#include "src/aarch64/matrix_mul/asm/common.h" +#include "src/arm_common/simd_macro/marm_neon.h" + +namespace megdnn { +namespace aarch64 { +namespace matmul_8x8x8 { + +/** + * Overview of register layout: + * + * A 8x8x8 cell of Rhs is stored in 8bit in v16 + * A 8x8x8 cell of Lhs is stored in 8bit in v0-v7 + * A 8x8 block of accumulators is stored in 16bit in v8-v15 + * + * +---------+ + * |v16[0-8] | + * Rhs +---------+ + * Lhs | | + * + * +--------+ - - - - +---------+ + * |v0[0-8]| | v8[0-8] | + * |v1[0-8]| | v9[0-8] | + * |v2[0-8]| | v10[0-8]| + * |v3[0-8]| | v11[0-8]| + * |v4[0-8]| | v12[0-8]| + * |v5[0-8]| | v13[0-8]| + * |v6[0-8]| | v14[0-8]| + * |v7[0-8]| | v15[0-8]| + * +--------+ - - - - +---------+ + * + * Accumulator + */ +static void kern_8x8(const int8_t* packA, const int8_t* packB, int K, + int16_t* output, int LDC, bool is_first_k) { + K /= 8; + const int8_t* a_ptr = packA; + const int8_t* b_ptr = packB; + + LDC = LDC * sizeof(int16_t); + +// clang-format off +#define LOAD_LINE(reg_index, n) \ + "ld1 {v" reg_index ".8h}, [x" n "]\n" +#define LOAD_C \ + LOAD_LINE("8", "0") \ + LOAD_LINE("9", "1") \ + LOAD_LINE("10", "2") \ + LOAD_LINE("11", "3") \ + LOAD_LINE("12", "4") \ + LOAD_LINE("13", "5") \ + LOAD_LINE("14", "6") \ + LOAD_LINE("15", "7") + +#define STORE_LINE(reg_index, num) \ + "st1 {v" reg_index ".8h}, [x" num "]\n" + +#define STORE_C \ + STORE_LINE("8", "0") \ + STORE_LINE("9", "1") \ + STORE_LINE("10", "2") \ + STORE_LINE("11", "3") \ + STORE_LINE("12", "4") \ + STORE_LINE("13", "5") \ + STORE_LINE("14", "6") \ + STORE_LINE("15", "7") + +#define CLEAR_8_INT16(reg) \ + "eor v" reg ".16b, v" reg ".16b, v" reg ".16b\n" +#define CLEAR_8_REGS \ + CLEAR_8_INT16("8") \ + CLEAR_8_INT16("9") \ + CLEAR_8_INT16("10") \ + CLEAR_8_INT16("11") \ + CLEAR_8_INT16("12") \ + CLEAR_8_INT16("13") \ + CLEAR_8_INT16("14") \ + CLEAR_8_INT16("15") + + // clang-format on + + register int16_t* outptr asm("x0") = output; + asm volatile( + // load accumulator C + "add x1, x0, %x[LDC]\n" + "add x2, x1, %x[LDC]\n" + "add x3, x2, %x[LDC]\n" + "add x4, x3, %x[LDC]\n" + "add x5, x4, %x[LDC]\n" + "add x6, x5, %x[LDC]\n" + "add x7, x6, %x[LDC]\n" + "cmp %w[is_first_k], #1\n" + "beq 1f\n" LOAD_C + "b 2f\n" + + "1:\n" CLEAR_8_REGS + + "2: \n" + "ld1 {v16.8b}, [%[b_ptr]], 8\n" + "ld1 {v0.8b}, [%[a_ptr]], 8\n" + "ld1 {v1.8b}, [%[a_ptr]], 8\n" + "ld1 {v2.8b}, [%[a_ptr]], 8\n" + "ld1 {v3.8b}, [%[a_ptr]], 8\n" + "ld1 {v4.8b}, [%[a_ptr]], 8\n" + "ld1 {v5.8b}, [%[a_ptr]], 8\n" + "ld1 {v6.8b}, [%[a_ptr]], 8\n" + "ld1 {v7.8b}, [%[a_ptr]], 8\n" + + "sshll v16.8h, v16.8b, #0\n" + "sshll v0.8h, v0.8b, #0\n" + "sshll v1.8h, v1.8b, #0\n" + "sshll v2.8h, 
v2.8b, #0\n" + "sshll v3.8h, v3.8b, #0\n" + "sshll v4.8h, v4.8b, #0\n" + "sshll v5.8h, v5.8b, #0\n" + "sshll v6.8h, v6.8b, #0\n" + "sshll v7.8h, v7.8b, #0\n" + + "ld1 {v17.8b}, [%[b_ptr]], 8\n" + "mla v8.8h, v16.8h, v0.h[0]\n" + "mla v9.8h, v16.8h, v1.h[0]\n" + "mla v10.8h, v16.8h, v2.h[0]\n" + "mla v11.8h, v16.8h, v3.h[0]\n" + "mla v12.8h, v16.8h, v4.h[0]\n" + "mla v13.8h, v16.8h, v5.h[0]\n" + "mla v14.8h, v16.8h, v6.h[0]\n" + "mla v15.8h, v16.8h, v7.h[0]\n" + "sshll v17.8h, v17.8b, #0\n" + + "ld1 {v16.8b}, [%[b_ptr]], 8\n" + "mla v8.8h, v17.8h, v0.h[1]\n" + "mla v9.8h, v17.8h, v1.h[1]\n" + "mla v10.8h, v17.8h, v2.h[1]\n" + "mla v11.8h, v17.8h, v3.h[1]\n" + "mla v12.8h, v17.8h, v4.h[1]\n" + "mla v13.8h, v17.8h, v5.h[1]\n" + "mla v14.8h, v17.8h, v6.h[1]\n" + "mla v15.8h, v17.8h, v7.h[1]\n" + "sshll v16.8h, v16.8b, #0\n" + + "ld1 {v17.8b}, [%[b_ptr]], 8\n" + "mla v8.8h, v16.8h, v0.h[2]\n" + "mla v9.8h, v16.8h, v1.h[2]\n" + "mla v10.8h, v16.8h, v2.h[2]\n" + "mla v11.8h, v16.8h, v3.h[2]\n" + "mla v12.8h, v16.8h, v4.h[2]\n" + "mla v13.8h, v16.8h, v5.h[2]\n" + "mla v14.8h, v16.8h, v6.h[2]\n" + "mla v15.8h, v16.8h, v7.h[2]\n" + "sshll v17.8h, v17.8b, #0\n" + + "ld1 {v16.8b}, [%[b_ptr]], 8\n" + "mla v8.8h, v17.8h, v0.h[3]\n" + "mla v9.8h, v17.8h, v1.h[3]\n" + "mla v10.8h, v17.8h, v2.h[3]\n" + "mla v11.8h, v17.8h, v3.h[3]\n" + "mla v12.8h, v17.8h, v4.h[3]\n" + "mla v13.8h, v17.8h, v5.h[3]\n" + "mla v14.8h, v17.8h, v6.h[3]\n" + "mla v15.8h, v17.8h, v7.h[3]\n" + "sshll v16.8h, v16.8b, #0\n" + + "ld1 {v17.8b}, [%[b_ptr]], 8\n" + "mla v8.8h, v16.8h, v0.h[4]\n" + "mla v9.8h, v16.8h, v1.h[4]\n" + "mla v10.8h, v16.8h, v2.h[4]\n" + "mla v11.8h, v16.8h, v3.h[4]\n" + "mla v12.8h, v16.8h, v4.h[4]\n" + "mla v13.8h, v16.8h, v5.h[4]\n" + "mla v14.8h, v16.8h, v6.h[4]\n" + "mla v15.8h, v16.8h, v7.h[4]\n" + "sshll v17.8h, v17.8b, #0\n" + + "ld1 {v16.8b}, [%[b_ptr]], 8\n" + "mla v8.8h, v17.8h, v0.h[5]\n" + "mla v9.8h, v17.8h, v1.h[5]\n" + "mla v10.8h, v17.8h, v2.h[5]\n" + "mla v11.8h, v17.8h, v3.h[5]\n" + "mla v12.8h, v17.8h, v4.h[5]\n" + "mla v13.8h, v17.8h, v5.h[5]\n" + "mla v14.8h, v17.8h, v6.h[5]\n" + "mla v15.8h, v17.8h, v7.h[5]\n" + "sshll v16.8h, v16.8b, #0\n" + + "ld1 {v17.8b}, [%[b_ptr]], 8\n" + "mla v8.8h, v16.8h, v0.h[6]\n" + "mla v9.8h, v16.8h, v1.h[6]\n" + "mla v10.8h, v16.8h, v2.h[6]\n" + "mla v11.8h, v16.8h, v3.h[6]\n" + "mla v12.8h, v16.8h, v4.h[6]\n" + "mla v13.8h, v16.8h, v5.h[6]\n" + "mla v14.8h, v16.8h, v6.h[6]\n" + "mla v15.8h, v16.8h, v7.h[6]\n" + "sshll v17.8h, v17.8b, #0\n" + + "mla v8.8h, v17.8h, v0.h[7]\n" + "mla v9.8h, v17.8h, v1.h[7]\n" + "mla v10.8h, v17.8h, v2.h[7]\n" + "mla v11.8h, v17.8h, v3.h[7]\n" + "mla v12.8h, v17.8h, v4.h[7]\n" + "mla v13.8h, v17.8h, v5.h[7]\n" + "mla v14.8h, v17.8h, v6.h[7]\n" + "mla v15.8h, v17.8h, v7.h[7]\n" + + "subs %w[K], %w[K], #1\n" + "bne 2b\n" + + "3:\n" STORE_C + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), + [outptr] "+r"(outptr) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "x1", "x2", "x3", + "x4", "x5", "x6", "x7", "cc", "memory"); +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +#undef CLEAR_8_INT16 +#undef CLEAR_8_REGS +} + +/** + * Overview of register layout: + * + * A 8x4x8 cell of Rhs is stored in 8bit in v16-v17 + * A 8x8x8 cell of Lhs is stored in 8bit in v0-v7 + * A 8x4 block of accumulators is stored in 16bit in v8-v15 + * + * +--------+ + * |v16[0-4]| + * Rhs +--------+ + * | | 
+ * Lhs + * +--------+ - - - - +--------+ + * |v0[0-8]| | v8[0-4]| + * |v1[0-8]| | v9[0-4]| + * |v2[0-8]| |v10[0-4]| + * |v3[0-8]| |v11[0-4]| + * |v4[0-8]| |v12[0-4]| + * |v5[0-8]| |v13[0-4]| + * |v6[0-8]| |v14[0-4]| + * |v7[0-8]| |v15[0-4]| + * +--------+ - - - - +--------+ + * + * Accumulator + */ + +static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, + int16_t* output, int LDC, bool is_first_k, + size_t n_remain) { + K /= 8; + const int8_t* a_ptr = packA; + const int8_t* b_ptr = packB; + + LDC = LDC * sizeof(int16_t); + int16_t* outptr0 = output; + int16_t* outptr1; + int16_t* outptr2; + int16_t* outptr3; + int16_t* outptr4; + int16_t* outptr5; + int16_t* outptr6; + int16_t* outptr7; + size_t x0 = 0; + +// clang-format off +#define LOAD_LINE(reg_index, n) \ + "mov %[x0], %[outptr" n "]\n" \ + "cmp %w[n_remain], #4\n" \ + "blt 100" n "f\n" \ + "ld1 {v" reg_index ".4h}, [%[x0]]\n" \ + "b 101" n "f\n" \ + "100" n ":\n" \ + "cmp %w[n_remain], #0\n" \ + "beq 101" n "f\n" \ + "ld1 {v" reg_index ".h}[0], [%[x0]], #2\n" \ + "cmp %w[n_remain], #1\n" \ + "beq 101" n "f\n" \ + "ld1 {v" reg_index ".h}[1], [%[x0]], #2\n" \ + "cmp %w[n_remain], #2\n" \ + "beq 101" n "f\n" \ + "ld1 {v" reg_index ".h}[2], [%[x0]], #2\n" \ + "101" n ":\n" + +#define LOAD_C \ + LOAD_LINE("8", "0") \ + LOAD_LINE("9", "1") \ + LOAD_LINE("10", "2") \ + LOAD_LINE("11", "3") \ + LOAD_LINE("12", "4") \ + LOAD_LINE("13", "5") \ + LOAD_LINE("14", "6") \ + LOAD_LINE("15", "7") + +#define STORE_LINE(reg_index, n) \ + "mov %[x0], %[outptr" n "]\n" \ + "cmp %w[n_remain], #4\n" \ + "blt 102" n "f\n" \ + "st1 {v" reg_index ".4h}, [%[x0]]\n" \ + "b 103" n "f\n" \ + "102" n ":\n" \ + "cmp %w[n_remain], #0\n" \ + "beq 103" n "f\n" \ + "st1 {v" reg_index ".h}[0], [%[x0]], #2\n" \ + "cmp %w[n_remain], #1\n" \ + "beq 103" n "f\n" \ + "st1 {v" reg_index ".h}[1], [%[x0]], #2\n" \ + "cmp %w[n_remain], #2\n" \ + "beq 103" n "f\n" \ + "st1 {v" reg_index ".h}[2], [%[x0]], #2\n" \ + "103" n ":\n" + +#define STORE_C \ + STORE_LINE("8", "0") \ + STORE_LINE("9", "1") \ + STORE_LINE("10", "2") \ + STORE_LINE("11", "3") \ + STORE_LINE("12", "4") \ + STORE_LINE("13", "5") \ + STORE_LINE("14", "6") \ + STORE_LINE("15", "7") + +#define CLEAR_8_INT16(reg) \ + "eor v" reg ".16b, v" reg ".16b, v" reg ".16b\n" +#define CLEAR_8_REGS \ + CLEAR_8_INT16("8") \ + CLEAR_8_INT16("9") \ + CLEAR_8_INT16("10") \ + CLEAR_8_INT16("11") \ + CLEAR_8_INT16("12") \ + CLEAR_8_INT16("13") \ + CLEAR_8_INT16("14") \ + CLEAR_8_INT16("15") + // clang-format on + + asm volatile( + // load accumulator C + "add %[outptr1], %[outptr0], %x[LDC]\n" + "add %[outptr2], %[outptr1], %x[LDC]\n" + "add %[outptr3], %[outptr2], %x[LDC]\n" + "add %[outptr4], %[outptr3], %x[LDC]\n" + "add %[outptr5], %[outptr4], %x[LDC]\n" + "add %[outptr6], %[outptr5], %x[LDC]\n" + "add %[outptr7], %[outptr6], %x[LDC]\n" + "cmp %w[is_first_k], #1\n" + "beq 1f\n" LOAD_C + "b 2f\n" + + "1:\n" CLEAR_8_REGS + + "2: \n" + "ld1 {v16.s}[0], [%[b_ptr]], 4\n" + "ld1 {v0.8b}, [%[a_ptr]], 8\n" + "ld1 {v1.8b}, [%[a_ptr]], 8\n" + "ld1 {v2.8b}, [%[a_ptr]], 8\n" + "ld1 {v3.8b}, [%[a_ptr]], 8\n" + "ld1 {v4.8b}, [%[a_ptr]], 8\n" + "ld1 {v5.8b}, [%[a_ptr]], 8\n" + "ld1 {v6.8b}, [%[a_ptr]], 8\n" + "ld1 {v7.8b}, [%[a_ptr]], 8\n" + "sshll v16.8h, v16.8b, #0\n" + "sshll v0.8h, v0.8b, #0\n" + "sshll v1.8h, v1.8b, #0\n" + "sshll v2.8h, v2.8b, #0\n" + "sshll v3.8h, v3.8b, #0\n" + "sshll v4.8h, v4.8b, #0\n" + "sshll v5.8h, v5.8b, #0\n" + "sshll v6.8h, v6.8b, #0\n" + "sshll v7.8h, v7.8b, #0\n" + + "ld1 {v17.s}[0], 
[%[b_ptr]], 4\n" + "mla v8.4h, v16.4h, v0.h[0]\n" + "mla v9.4h, v16.4h, v1.h[0]\n" + "mla v10.4h, v16.4h, v2.h[0]\n" + "mla v11.4h, v16.4h, v3.h[0]\n" + "sshll v17.8h, v17.8b, #0\n" + "mla v12.4h, v16.4h, v4.h[0]\n" + "mla v13.4h, v16.4h, v5.h[0]\n" + "mla v14.4h, v16.4h, v6.h[0]\n" + "mla v15.4h, v16.4h, v7.h[0]\n" + + "ld1 {v16.s}[0], [%[b_ptr]], 4\n" + "mla v8.4h, v17.4h, v0.h[1]\n" + "mla v9.4h, v17.4h, v1.h[1]\n" + "mla v10.4h, v17.4h, v2.h[1]\n" + "mla v11.4h, v17.4h, v3.h[1]\n" + "sshll v16.8h, v16.8b, #0\n" + "mla v12.4h, v17.4h, v4.h[1]\n" + "mla v13.4h, v17.4h, v5.h[1]\n" + "mla v14.4h, v17.4h, v6.h[1]\n" + "mla v15.4h, v17.4h, v7.h[1]\n" + + "ld1 {v17.s}[0], [%[b_ptr]], 4\n" + "mla v8.4h, v16.4h, v0.h[2]\n" + "mla v9.4h, v16.4h, v1.h[2]\n" + "mla v10.4h, v16.4h, v2.h[2]\n" + "mla v11.4h, v16.4h, v3.h[2]\n" + "sshll v17.8h, v17.8b, #0\n" + "mla v12.4h, v16.4h, v4.h[2]\n" + "mla v13.4h, v16.4h, v5.h[2]\n" + "mla v14.4h, v16.4h, v6.h[2]\n" + "mla v15.4h, v16.4h, v7.h[2]\n" + + "ld1 {v16.s}[0], [%[b_ptr]], 4\n" + "mla v8.4h, v17.4h, v0.h[3]\n" + "mla v9.4h, v17.4h, v1.h[3]\n" + "mla v10.4h, v17.4h, v2.h[3]\n" + "mla v11.4h, v17.4h, v3.h[3]\n" + "sshll v16.8h, v16.8b, #0\n" + "mla v12.4h, v17.4h, v4.h[3]\n" + "mla v13.4h, v17.4h, v5.h[3]\n" + "mla v14.4h, v17.4h, v6.h[3]\n" + "mla v15.4h, v17.4h, v7.h[3]\n" + + "ld1 {v17.s}[0], [%[b_ptr]], 4\n" + "mla v8.4h, v16.4h, v0.h[4]\n" + "mla v9.4h, v16.4h, v1.h[4]\n" + "mla v10.4h, v16.4h, v2.h[4]\n" + "mla v11.4h, v16.4h, v3.h[4]\n" + "sshll v17.8h, v17.8b, #0\n" + "mla v12.4h, v16.4h, v4.h[4]\n" + "mla v13.4h, v16.4h, v5.h[4]\n" + "mla v14.4h, v16.4h, v6.h[4]\n" + "mla v15.4h, v16.4h, v7.h[4]\n" + + "ld1 {v16.s}[0], [%[b_ptr]], 4\n" + "mla v8.4h, v17.4h, v0.h[5]\n" + "mla v9.4h, v17.4h, v1.h[5]\n" + "mla v10.4h, v17.4h, v2.h[5]\n" + "mla v11.4h, v17.4h, v3.h[5]\n" + "sshll v16.8h, v16.8b, #0\n" + "mla v12.4h, v17.4h, v4.h[5]\n" + "mla v13.4h, v17.4h, v5.h[5]\n" + "mla v14.4h, v17.4h, v6.h[5]\n" + "mla v15.4h, v17.4h, v7.h[5]\n" + + "ld1 {v17.s}[0], [%[b_ptr]], 4\n" + "mla v8.4h, v16.4h, v0.h[6]\n" + "mla v9.4h, v16.4h, v1.h[6]\n" + "mla v10.4h, v16.4h, v2.h[6]\n" + "mla v11.4h, v16.4h, v3.h[6]\n" + "sshll v17.8h, v17.8b, #0\n" + "mla v12.4h, v16.4h, v4.h[6]\n" + "mla v13.4h, v16.4h, v5.h[6]\n" + "mla v14.4h, v16.4h, v6.h[6]\n" + "mla v15.4h, v16.4h, v7.h[6]\n" + + "mla v8.4h, v17.4h, v0.h[7]\n" + "mla v9.4h, v17.4h, v1.h[7]\n" + "mla v10.4h, v17.4h, v2.h[7]\n" + "mla v11.4h, v17.4h, v3.h[7]\n" + "mla v12.4h, v17.4h, v4.h[7]\n" + "mla v13.4h, v17.4h, v5.h[7]\n" + "mla v14.4h, v17.4h, v6.h[7]\n" + "mla v15.4h, v17.4h, v7.h[7]\n" + + "subs %w[K], %w[K], #1\n" + "cbnz %w[K], 2b\n" + + "3:\n" STORE_C + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), + [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), + [outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), + [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), + [outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), + [outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7), [x0] "+r"(x0), + [n_remain] "+r"(n_remain) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "cc", "memory"); + +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +#undef CLEAR_8_INT16 +#undef CLEAR_8_REGS +} + +/** + * Overview of register layout: + * + * A 8x8x8 cell of Rhs is stored in 8bit in v8-v9 + * A 8x8x4 cell of Lhs is stored in 8bit in v0-v3 + * A 4x8 block of accumulators is stored in 16bit in v4-v7 + * + * +--------+ + * | v8[0-8]| + 
* +--------+ + * | v9[0-8]| + * Rhs +--------+ + * Lhs | | + * + * +-------+ - - - - -+--------+ + * |v0[0-8]| | v4[0-8]| + * |v1[0-8]| | v5[0-8]| + * |v2[0-8]| | v6[0-8]| + * |v3[0-8]| | v7[0-8]| + * +-------+ - - - - -+--------+ + * + * Accumulator + */ + +static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, + int16_t* output, int LDC, bool is_first_k, + size_t m_remain) { + K /= 8; + const int8_t* a_ptr = packA; + const int8_t* b_ptr = packB; + + LDC = LDC * sizeof(int16_t); + int16_t* outptr0 = output; + int16_t* outptr1; + int16_t* outptr2; + int16_t* outptr3; + size_t x0 = 0; + +// clang-format off +#define LOAD_LINE(reg_index, m) \ + "cbz %[x0], 100f\n" \ + "ld1 {v" reg_index ".8h}, [%[outptr" m "]], #16\n" \ + "subs %[x0], %[x0], #1\n" + +#define LOAD_C \ + "mov %[x0], %x[m_remain]\n" \ + LOAD_LINE("4", "0") \ + LOAD_LINE("5", "1") \ + LOAD_LINE("6", "2") \ + LOAD_LINE("7", "3") \ + "100:\n" + +#define STORE_LINE(reg_index, m) \ + "cbz %[x0], 101f\n" \ + "st1 {v" reg_index ".8h}, [%[outptr" m "]]\n" \ + "subs %[x0], %[x0], #1\n" + +#define STORE_C \ + "mov %[x0], %x[m_remain]\n" \ + STORE_LINE("4", "0") \ + STORE_LINE("5", "1") \ + STORE_LINE("6", "2") \ + STORE_LINE("7", "3") \ + "101:\n" + +#define CLEAR_8_INT16(reg_index) \ + "eor v" reg_index ".16b, v" reg_index ".16b, v" reg_index ".16b\n" +#define CLEAR_4_REGS \ + CLEAR_8_INT16("4") \ + CLEAR_8_INT16("5") \ + CLEAR_8_INT16("6") \ + CLEAR_8_INT16("7") + + // clang-format on + + asm volatile( + // load accumulator C + "add %[outptr1], %[outptr0], %x[LDC]\n" + "add %[outptr2], %[outptr1], %x[LDC]\n" + "add %[outptr3], %[outptr2], %x[LDC]\n" + "cmp %w[is_first_k], #1\n" + "beq 1f\n" LOAD_C + "b 2f\n" + + "1:\n" CLEAR_4_REGS + + "2: \n" + "ld1 {v8.8b}, [%[b_ptr]], 8\n" + "ld1 {v0.8b}, [%[a_ptr]], 8\n" + "ld1 {v1.8b}, [%[a_ptr]], 8\n" + "ld1 {v2.8b}, [%[a_ptr]], 8\n" + "ld1 {v3.8b}, [%[a_ptr]], 8\n" + "sshll v8.8h, v8.8b, #0\n" + "sshll v0.8h, v0.8b, #0\n" + "sshll v1.8h, v1.8b, #0\n" + "sshll v2.8h, v2.8b, #0\n" + "sshll v3.8h, v3.8b, #0\n" + + "ld1 {v9.8b}, [%[b_ptr]], 8\n" + "mla v4.8h, v8.8h, v0.h[0]\n" + "mla v5.8h, v8.8h, v1.h[0]\n" + "mla v6.8h, v8.8h, v2.h[0]\n" + "mla v7.8h, v8.8h, v3.h[0]\n" + "sshll v9.8h, v9.8b, #0\n" + + "ld1 {v8.8b}, [%[b_ptr]], 8\n" + "mla v4.8h, v9.8h, v0.h[1]\n" + "mla v5.8h, v9.8h, v1.h[1]\n" + "mla v6.8h, v9.8h, v2.h[1]\n" + "mla v7.8h, v9.8h, v3.h[1]\n" + "sshll v8.8h, v8.8b, #0\n" + + "ld1 {v9.8b}, [%[b_ptr]], 8\n" + "mla v4.8h, v8.8h, v0.h[2]\n" + "mla v5.8h, v8.8h, v1.h[2]\n" + "mla v6.8h, v8.8h, v2.h[2]\n" + "mla v7.8h, v8.8h, v3.h[2]\n" + "sshll v9.8h, v9.8b, #0\n" + + "ld1 {v8.8b}, [%[b_ptr]], 8\n" + "mla v4.8h, v9.8h, v0.h[3]\n" + "mla v5.8h, v9.8h, v1.h[3]\n" + "mla v6.8h, v9.8h, v2.h[3]\n" + "mla v7.8h, v9.8h, v3.h[3]\n" + "sshll v8.8h, v8.8b, #0\n" + + "ld1 {v9.8b}, [%[b_ptr]], 8\n" + "mla v4.8h, v8.8h, v0.h[4]\n" + "mla v5.8h, v8.8h, v1.h[4]\n" + "mla v6.8h, v8.8h, v2.h[4]\n" + "mla v7.8h, v8.8h, v3.h[4]\n" + "sshll v9.8h, v9.8b, #0\n" + + "ld1 {v8.8b}, [%[b_ptr]], 8\n" + "mla v4.8h, v9.8h, v0.h[5]\n" + "mla v5.8h, v9.8h, v1.h[5]\n" + "mla v6.8h, v9.8h, v2.h[5]\n" + "mla v7.8h, v9.8h, v3.h[5]\n" + "sshll v8.8h, v8.8b, #0\n" + + "ld1 {v9.8b}, [%[b_ptr]], 8\n" + "mla v4.8h, v8.8h, v0.h[6]\n" + "mla v5.8h, v8.8h, v1.h[6]\n" + "mla v6.8h, v8.8h, v2.h[6]\n" + "mla v7.8h, v8.8h, v3.h[6]\n" + "sshll v9.8h, v9.8b, #0\n" + + "mla v4.8h, v9.8h, v0.h[7]\n" + "mla v5.8h, v9.8h, v1.h[7]\n" + "mla v6.8h, v9.8h, v2.h[7]\n" + "mla v7.8h, v9.8h, v3.h[7]\n" + + "subs %w[K], %w[K], 
#1\n" + "cbnz %w[K], 2b\n" + + "3:\n" STORE_C + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), + [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), + [outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), + [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), [x0] "+r"(x0), + [m_remain] "+r"(m_remain) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "cc", + "memory"); + +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +#undef CLEAR_8_INT16 +#undef CLEAR_4_REGS +} + +/** + * Overview of register layout: + * + * A 8x4x8 cell of Rhs is stored in 8bit in v8-v9 + * A 4x8x8 cell of Lhs is stored in 8bit in q0-q3 + * A 8x8 block of accumulators is stored in 16bit in q4-q7 + * + * +--------+ + * | q8[0-4]| + * +--------+ + * | q9[0-4]| + * Rhs +--------+ + * Lhs | | + * + * +--------+ - - - - +--------- + * |q0[0-8]| | q4[0-4]| + * |q1[0-8]| | q5[0-4]| + * |q2[0-8]| | q6[0-4]| + * |q3[0-8]| | q7[0-4]| + * +--------+ - - - - +--------- + * + * Accumulator + */ +static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, + int16_t* output, int LDC, bool is_first_k, size_t m_remain, + size_t n_remain) { + K /= 8; + const int8_t* a_ptr = packA; + const int8_t* b_ptr = packB; + + LDC = LDC * sizeof(int16_t); + size_t x0 = 0; + +// clang-format off +#define LOAD_LINE(reg_index, n) \ + "cmp %[x0], #0 \n" \ + "beq 102f\n" \ + "cmp %[n_remain], #4\n" \ + "blt 100" n "f\n" \ + "ld1 {v" reg_index ".4h}, [x" n " ], #8\n" \ + "b 101" n "f\n" \ + "100" n ":\n" \ + "cmp %[n_remain], #0\n" \ + "beq 101" n "f\n" \ + "ld1 {v" reg_index ".h}[0], [x" n "], #2\n" \ + "cmp %[n_remain], #1\n" \ + "beq 101" n "f\n" \ + "ld1 {v" reg_index ".h}[1], [x" n "], #2\n" \ + "cmp %[n_remain], #2\n" \ + "beq 101" n "f\n" \ + "ld1 {v" reg_index ".h}[2], [x" n "], #2\n" \ + "101" n ":\n" \ + "subs %[x0], %[x0], #1\n" + +#define LOAD_C \ + "mov %[x0], %[m_remain]\n" \ + "mov x1, x0\n" \ + LOAD_LINE("4", "1") \ + "add x1, x0, %x[LDC]\n" \ + "add x0, x0, %x[LDC]\n" \ + LOAD_LINE("5", "1") \ + "add x1, x0, %x[LDC]\n" \ + "add x0, x0, %x[LDC]\n" \ + LOAD_LINE("6", "1") \ + "add x1, x0, %x[LDC]\n" \ + LOAD_LINE("7", "1") \ + "102:\n" + +#define STORE_LINE(reg_index, n) \ + "cmp %[x0], #0 \n" \ + "beq 105f\n" \ + "cmp %[n_remain], #4\n" \ + "blt 103" n "f\n" \ + "st1 {v" reg_index ".4h}, [x" n "]\n" \ + "b 104" n "f\n" \ + "103" n ":\n" \ + "cmp %[n_remain], #0\n" \ + "beq 104" n "f\n" \ + "st1 {v" reg_index ".h}[0], [x" n "], #2\n" \ + "cmp %[n_remain], #1\n" \ + "beq 104" n "f\n" \ + "st1 {v" reg_index ".h}[1], [x" n "], #2\n" \ + "cmp %[n_remain], #2\n" \ + "beq 104" n "f\n" \ + "st1 {v" reg_index ".h}[2], [x" n "], #2\n" \ + "104" n ":\n" \ + "subs %[x0], %[x0], #1\n" + +#define STORE_C \ + "mov %[x0], %[m_remain]\n" \ + "mov x1, x0\n" \ + STORE_LINE("4", "1") \ + "add x1, x0, %x[LDC]\n" \ + "add x0, x0, %x[LDC]\n" \ + STORE_LINE("5", "1") \ + "add x1, x0, %x[LDC]\n" \ + "add x0, x0, %x[LDC]\n" \ + STORE_LINE("6", "1") \ + "add x1, x0, %x[LDC]\n" \ + STORE_LINE("7", "1") \ + "105:\n" + + // clang-format on + + register int16_t* outptr asm("x0") = output; + asm volatile( + // load accumulator C + "cmp %w[is_first_k], #1\n" + "beq 1f\n" LOAD_C + + "b 2f\n" + + "1:\n" + "eor v4.16b, v4.16b, v4.16b\n" + "eor v5.16b, v5.16b, v5.16b\n" + "eor v6.16b, v6.16b, v6.16b\n" + "eor v7.16b, v7.16b, v7.16b\n" + + "2: \n" + "ld1 {v8.s}[0], [%[b_ptr]], 4\n" + "ld1 {v0.8b}, [%[a_ptr]], 8\n" + "ld1 {v1.8b}, [%[a_ptr]], 8\n" + "ld1 {v2.8b}, [%[a_ptr]], 8\n" + "ld1 {v3.8b}, [%[a_ptr]], 8\n" + "sshll v8.8h, 
v8.8b, #0\n" + "sshll v0.8h, v0.8b, #0\n" + "sshll v1.8h, v1.8b, #0\n" + "sshll v2.8h, v2.8b, #0\n" + "sshll v3.8h, v3.8b, #0\n" + + "ld1 {v9.s}[0], [%[b_ptr]], 4\n" + "mla v4.4h, v8.4h, v0.h[0]\n" + "mla v5.4h, v8.4h, v1.h[0]\n" + "mla v6.4h, v8.4h, v2.h[0]\n" + "mla v7.4h, v8.4h, v3.h[0]\n" + "sshll v9.8h, v9.8b, #0\n" + + "ld1 {v8.s}[0], [%[b_ptr]], 4\n" + "mla v4.4h, v9.4h, v0.h[1]\n" + "mla v5.4h, v9.4h, v1.h[1]\n" + "mla v6.4h, v9.4h, v2.h[1]\n" + "mla v7.4h, v9.4h, v3.h[1]\n" + "sshll v8.8h, v8.8b, #0\n" + + "ld1 {v9.s}[0], [%[b_ptr]], 4\n" + "mla v4.4h, v8.4h, v0.h[2]\n" + "mla v5.4h, v8.4h, v1.h[2]\n" + "mla v6.4h, v8.4h, v2.h[2]\n" + "mla v7.4h, v8.4h, v3.h[2]\n" + "sshll v9.8h, v9.8b, #0\n" + + "ld1 {v8.s}[0], [%[b_ptr]], 4\n" + "mla v4.4h, v9.4h, v0.h[3]\n" + "mla v5.4h, v9.4h, v1.h[3]\n" + "mla v6.4h, v9.4h, v2.h[3]\n" + "mla v7.4h, v9.4h, v3.h[3]\n" + "sshll v8.8h, v8.8b, #0\n" + + "ld1 {v9.s}[0], [%[b_ptr]], 4\n" + "mla v4.4h, v8.4h, v0.h[4]\n" + "mla v5.4h, v8.4h, v1.h[4]\n" + "mla v6.4h, v8.4h, v2.h[4]\n" + "mla v7.4h, v8.4h, v3.h[4]\n" + "sshll v9.8h, v9.8b, #0\n" + + "ld1 {v8.s}[0], [%[b_ptr]], 4\n" + "mla v4.4h, v9.4h, v0.h[5]\n" + "mla v5.4h, v9.4h, v1.h[5]\n" + "mla v6.4h, v9.4h, v2.h[5]\n" + "mla v7.4h, v9.4h, v3.h[5]\n" + "sshll v8.8h, v8.8b, #0\n" + + "ld1 {v9.s}[0], [%[b_ptr]], 4\n" + "mla v4.4h, v8.4h, v0.h[6]\n" + "mla v5.4h, v8.4h, v1.h[6]\n" + "mla v6.4h, v8.4h, v2.h[6]\n" + "mla v7.4h, v8.4h, v3.h[6]\n" + "sshll v9.8h, v9.8b, #0\n" + + "mla v4.4h, v9.4h, v0.h[7]\n" + "mla v5.4h, v9.4h, v1.h[7]\n" + "mla v6.4h, v9.4h, v2.h[7]\n" + "mla v7.4h, v9.4h, v3.h[7]\n" + + "subs %w[K], %w[K], #1\n" + "bne 2b\n" + + "3:\n" STORE_C + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [outptr] "+r"(outptr), + [K] "+r"(K), [is_first_k] "+r"(is_first_k), [LDC] "+r"(LDC), + [x0] "+r"(x0), [m_remain] "+r"(m_remain), + [n_remain] "+r"(n_remain) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "x1", + "cc", "memory"); + +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +} + +static void gemm_s8x8x16_8x8_pack_A_n(dt_int8* outptr, const dt_int8* inptr, + int ldin, int y0, int ymax, int k0, + int kmax) { + int8_t zerobuff[16]; + std::memset(zerobuff, 0, sizeof(int8_t) * 16); + + int y = y0; + for (; y + 7 < ymax; y += 8) { + const int8_t* inptr0 = inptr + y * ldin + k0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + const int8_t* inptr4 = inptr3 + ldin; + const int8_t* inptr5 = inptr4 + ldin; + const int8_t* inptr6 = inptr5 + ldin; + const int8_t* inptr7 = inptr6 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + prefetch_2x(inptr4); + prefetch_2x(inptr5); + prefetch_2x(inptr6); + prefetch_2x(inptr7); + + int K = kmax - k0; + for (; K > 15; K -= 16) { + interleave_8x8_2_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, + inptr6, inptr7, outptr); + } + + if (K > 0) { + interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, + inptr7, outptr, 8, K); + } + } + for (; y < ymax; y += 4) { + const int8_t* inptr0 = inptr + y * ldin + k0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + + int K = kmax - k0; + for (; K > 15; K -= 16) { + if (y + 3 >= ymax) { + switch (y + 3 - ymax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + 
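+                    // (cases above fall through to here on purpose: every row pointer past
+                    //  ymax, inptr1..inptr3, is redirected to the zero-filled zerobuff, e.g.
+                    //  ymax - y == 2 leaves inptr2 and inptr3 reading zeros, so the packed
+                    //  4x8 panel is always fully populated)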
inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + interleave_4x8_2_b(inptr0, inptr1, inptr2, inptr3, outptr); + } + + if (K > 0) { + if (y + 3 >= ymax) { + switch (y + 3 - ymax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 8, K); + } + } +} + +static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in, + int ldin, int x0, int xmax, + int k0, int kmax) { + int8_t zerobuff[16]; + std::memset(zerobuff, 0, sizeof(int8_t) * 16); + + const int ksize = kmax - k0; + const int ksize4 = round_up(ksize, 8) * 4; + const int ksize8 = ksize4 * 2; + int8_t* outptr = out; + int8_t* outptr_base = out; + //! 4x4 block output start pos + int8_t* outptr_base4 = out + ((xmax - x0) / 8) * ksize8; + + int k = k0; + for (; k < kmax; k += 8) { + const int8_t* inptr0 = in + k * ldin + x0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + const int8_t* inptr4 = inptr3 + ldin; + const int8_t* inptr5 = inptr4 + ldin; + const int8_t* inptr6 = inptr5 + ldin; + const int8_t* inptr7 = inptr6 + ldin; + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + prefetch_2x(inptr4); + prefetch_2x(inptr5); + prefetch_2x(inptr6); + prefetch_2x(inptr7); + + int x = x0; + outptr = outptr_base; + + for (; x + 7 < xmax; x += 8) { + if (k + 7 >= kmax) { + switch (k + 7 - kmax) { + case 6: + inptr1 = zerobuff; + case 5: + inptr2 = zerobuff; + case 4: + inptr3 = zerobuff; + case 3: + inptr4 = zerobuff; + case 2: + inptr5 = zerobuff; + case 1: + inptr6 = zerobuff; + case 0: + inptr7 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + transpose_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, + inptr6, inptr7, outptr); + outptr += ksize8; + } + + outptr = outptr_base4; + for (; x + 3 < xmax; x += 4) { + if (k + 7 >= kmax) { + switch (k + 7 - kmax) { + case 6: + inptr1 = zerobuff; + case 5: + inptr2 = zerobuff; + case 4: + inptr3 = zerobuff; + case 3: + inptr4 = zerobuff; + case 2: + inptr5 = zerobuff; + case 1: + inptr6 = zerobuff; + case 0: + inptr7 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, + inptr7, outptr, 4, 4); + outptr += ksize4; + } + + if (x < xmax) { + if (k + 7 >= kmax) { + switch (k + 7 - kmax) { + case 6: + inptr1 = zerobuff; + case 5: + inptr2 = zerobuff; + case 4: + inptr3 = zerobuff; + case 3: + inptr4 = zerobuff; + case 2: + inptr5 = zerobuff; + case 1: + inptr6 = zerobuff; + case 0: + inptr7 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, + inptr7, outptr, 4, xmax - x); + } + + outptr_base += 8 * 8; + outptr_base4 += 4 * 8; + } +} + +static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, + int x0, int xmax, int k0, int kmax) { + int8_t zerobuff[16]; + std::memset(zerobuff, 0, sizeof(int8_t) * 16); + const int ksize = kmax - k0; + const int ksize4 = round_up(ksize, 8) * 4; + const int ksize8 = ksize4 * 2; + int8_t* outptr = out; + int8_t* outptr_base = out; + int8_t* outptr_interleave = nullptr; + int8_t* outptr_base4 = out + ((xmax - x0) / 8) * ksize8; + + int k = k0; + for (; k < kmax; k += 8) { + const int8_t* inptr0 = in + k * ldin + x0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 
+ ldin; + const int8_t* inptr3 = inptr2 + ldin; + const int8_t* inptr4 = inptr3 + ldin; + const int8_t* inptr5 = inptr4 + ldin; + const int8_t* inptr6 = inptr5 + ldin; + const int8_t* inptr7 = inptr6 + ldin; + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + prefetch_2x(inptr4); + prefetch_2x(inptr5); + prefetch_2x(inptr6); + prefetch_2x(inptr7); + + int x = x0; + outptr = outptr_base; + + for (; x + 7 < xmax; x += 8) { + if (k + 7 >= kmax) { + switch (k + 7 - kmax) { + case 6: + inptr1 = zerobuff; + case 5: + inptr2 = zerobuff; + case 4: + inptr3 = zerobuff; + case 3: + inptr4 = zerobuff; + case 2: + inptr5 = zerobuff; + case 1: + inptr6 = zerobuff; + case 0: + inptr7 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + outptr_interleave = outptr; + interleave_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, + inptr6, inptr7, outptr_interleave); + outptr += ksize8; + } + + outptr = outptr_base4; + for (; x + 3 < xmax; x += 4) { + if (k + 7 >= kmax) { + switch (k + 7 - kmax) { + case 6: + inptr1 = zerobuff; + case 5: + inptr2 = zerobuff; + case 4: + inptr3 = zerobuff; + case 3: + inptr4 = zerobuff; + case 2: + inptr5 = zerobuff; + case 1: + inptr6 = zerobuff; + case 0: + inptr7 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + outptr_interleave = outptr; + interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, + inptr7, outptr_interleave, 4, 4); + outptr += ksize4; + } + + if (x < xmax) { + if (k + 7 >= kmax) { + switch (k + 7 - kmax) { + case 6: + inptr1 = zerobuff; + case 5: + inptr2 = zerobuff; + case 4: + inptr3 = zerobuff; + case 3: + inptr4 = zerobuff; + case 2: + inptr5 = zerobuff; + case 1: + inptr6 = zerobuff; + case 0: + inptr7 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + outptr_interleave = outptr; + interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, + inptr7, outptr_interleave, 4, xmax - x); + } + + outptr_base += 8 * 8; + outptr_base4 += 4 * 8; + } +} + +static void gemm_s8x8x16_8x8_transpose_pack_B_n(dt_int8* outptr, + const dt_int8* inptr, int ldin, + int y0, int ymax, int k0, + int kmax) { + int8_t zerobuff[16]; + std::memset(zerobuff, 0, sizeof(int8_t) * 16); + constexpr int interleave4 = 32; + constexpr int interleave8 = 64; + + int y = y0; + for (; y + 7 < ymax; y += 8) { + const int8_t* inptr0 = inptr + y * ldin + k0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + const int8_t* inptr4 = inptr3 + ldin; + const int8_t* inptr5 = inptr4 + ldin; + const int8_t* inptr6 = inptr5 + ldin; + const int8_t* inptr7 = inptr6 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + prefetch_2x(inptr4); + prefetch_2x(inptr5); + prefetch_2x(inptr6); + prefetch_2x(inptr7); + + int K = kmax - k0; + for (; K > 7; K -= 8) { + transpose_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, + inptr6, inptr7, outptr); + outptr += interleave8; + } + + if (K > 0) { + transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, + inptr7, outptr, 8, K); + outptr += interleave8; + } + } + + for (; y < ymax; y += 4) { + const int8_t* inptr0 = inptr + y * ldin + k0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + + int K = kmax - k0; + for (; K > 7; K -= 8) { + if (y + 3 >= ymax) { + 
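+                // Same edge handling as the pack_A path: the fall-through switch below points
+                // any row beyond ymax at zerobuff before transpose_8x4_1_b packs the 4-row panel.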
switch (y + 3 - ymax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + transpose_8x4_1_b(inptr0, inptr1, inptr2, inptr3, outptr); + outptr += interleave4; + } + + if (K > 0) { + if (y + 3 >= ymax) { + switch (y + 3 - ymax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + transpose_4(inptr0, inptr1, inptr2, inptr3, outptr, 8, K); + outptr += interleave4; + } + } +} +} // namespace matmul_8x8x8 +} // namespace aarch64 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/int8x8x16/strategy.cpp b/dnn/src/aarch64/matrix_mul/int8x8x16/strategy.cpp new file mode 100644 index 00000000..ff495ab6 --- /dev/null +++ b/dnn/src/aarch64/matrix_mul/int8x8x16/strategy.cpp @@ -0,0 +1,200 @@ +/** + * \file dnn/src/aarch64/matrix_mul/int8x8x16/strategy.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "src/aarch64/matrix_mul/asm/common.h" +#include "src/aarch64/matrix_mul/int8x8x16/kernel_4x4x16.h" +#include "src/aarch64/matrix_mul/int8x8x16/kernel_8x8x8.h" +#include "src/aarch64/matrix_mul/int8x8x16/strategy.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/utils.h" +#include "src/fallback/matrix_mul/gemm_common.h" + +using namespace megdnn; +using namespace aarch64; +using namespace aarch64::matmul; + +// ===========================gemm_s8x8x16_4x4================================== +MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_8x8); + +void gemm_s8x8x16_8x8::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0, + int ymax, int k0, int kmax, + bool transpose) const { + if (transpose) { + matmul_8x8x8::gemm_s8x8x16_8x8_transpose_pack_A_n(out, in, ldin, y0, + ymax, k0, kmax); + } else { + matmul_8x8x8::gemm_s8x8x16_8x8_pack_A_n(out, in, ldin, y0, ymax, k0, + kmax); + } +} + +void gemm_s8x8x16_8x8::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, + int xmax, int k0, int kmax, + bool transpose) const { + if (transpose) { + matmul_8x8x8::gemm_s8x8x16_8x8_transpose_pack_B_n(out, in, ldin, x0, + xmax, k0, kmax); + } else { + matmul_8x8x8::gemm_s8x8x16_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, + kmax); + } +} + +void gemm_s8x8x16_8x8::kern(const dt_int8* packA, const dt_int8* packB, + size_t M, size_t N, size_t K, dt_int16* C, + size_t LDC, bool is_first_k, const dt_int16*, + dt_int16*) const { + megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && + (A_dtype.enumv() == DTypeEnum::Int8 && + C_dtype.enumv() == DTypeEnum::Int16), + "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), + C_dtype.name()); + MEGDNN_MARK_USED_VAR(A_dtype); + MEGDNN_MARK_USED_VAR(B_dtype); + MEGDNN_MARK_USED_VAR(C_dtype); + + constexpr size_t A_INTERLEAVE = 8; + constexpr size_t B_INTERLEAVE = 8; + //! 
K is packed to times of 4 + K = round_up(K, 8); + const int K8 = K * 8; + const int K4 = K * 4; + + size_t m = 0; + for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) { + int16_t* output = C + (m * LDC); + + size_t n = 0; + const dt_int8* cur_packB = packB; + for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { + matmul_8x8x8::kern_8x8(packA, cur_packB, K, output, LDC, + is_first_k); + output += B_INTERLEAVE; + cur_packB += K8; + } + + for (; n < N; n += 4) { + matmul_8x8x8::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k, + std::min(N - n, 4)); + output += 4; + cur_packB += K4; + } + packA += K8; + } + + for (; m < M; m += 4) { + int16_t* output = C + (m * LDC); + const dt_int8* cur_packB = packB; + size_t n = 0; + for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { + matmul_8x8x8::kern_4x8(packA, cur_packB, K, output, LDC, is_first_k, + std::min(M - m, 4)); + output += B_INTERLEAVE; + cur_packB += K8; + } + + for (; n < N; n += 4) { + matmul_8x8x8::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k, + std::min(M - m, 4), + std::min(N - n, 4)); + output += 4; + cur_packB += K4; + } + packA += K4; + } +} + +// ===========================gemm_s8x8x16_4x4================================== +MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_4x4); + +void gemm_s8x8x16_4x4::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0, + int ymax, int k0, int kmax, + bool transpose) const { + if (transpose) { + matmul_4x4x16::gemm_s8x8x16_4x4_pack_B_n(out, in, ldin, y0, ymax, k0, + kmax); + } else { + matmul_4x4x16::gemm_s8x8x16_4x4_pack_A_n(out, in, ldin, y0, ymax, k0, + kmax); + } +} + +void gemm_s8x8x16_4x4::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, + int xmax, int k0, int kmax, + bool transpose) const { + if (transpose) { + matmul_4x4x16::gemm_s8x8x16_4x4_pack_A_n(out, in, ldin, x0, xmax, k0, + kmax); + } else { + matmul_4x4x16::gemm_s8x8x16_4x4_pack_B_n(out, in, ldin, x0, xmax, k0, + kmax); + } +} + +void gemm_s8x8x16_4x4::kern(const dt_int8* packA, const dt_int8* packB, + size_t M, size_t N, size_t K, dt_int16* C, + size_t LDC, bool is_first_k, const dt_int16*, + dt_int16*) const { + megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && + (A_dtype.enumv() == DTypeEnum::Int8 && + C_dtype.enumv() == DTypeEnum::Int16), + "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), + C_dtype.name()); + MEGDNN_MARK_USED_VAR(A_dtype); + MEGDNN_MARK_USED_VAR(B_dtype); + MEGDNN_MARK_USED_VAR(C_dtype); + + constexpr size_t A_INTERLEAVE = 4; + constexpr size_t B_INTERLEAVE = 4; + //! 
K is packed to times of 4 + K = round_up(K, 16); + const int K4 = K * 4; + + size_t m = 0; + for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) { + int16_t* output = C + (m * LDC); + + size_t n = 0; + const dt_int8* cur_packB = packB; + for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { + matmul_4x4x16::kern_4x4(packA, cur_packB, K, output, LDC, + is_first_k, A_INTERLEAVE, B_INTERLEAVE); + output += B_INTERLEAVE; + cur_packB += K4; + } + + for (; n < N; n += B_INTERLEAVE) { + matmul_4x4x16::kern_4x4(packA, cur_packB, K, output, LDC, + is_first_k, A_INTERLEAVE, + std::min(N - n, B_INTERLEAVE)); + output += B_INTERLEAVE; + cur_packB += K4; + } + + packA += K4; + } + + for (; m < M; m += A_INTERLEAVE) { + int16_t* output = C + (m * LDC); + size_t n = 0; + const dt_int8* cur_packB = packB; + for (; n < N; n += B_INTERLEAVE) { + matmul_4x4x16::kern_4x4(packA, cur_packB, K, output, LDC, + is_first_k, + std::min(M - m, A_INTERLEAVE), + std::min(N - n, B_INTERLEAVE)); + output += B_INTERLEAVE; + cur_packB += K4; + } + packA += K4; + } +} +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/int8x8x16/strategy.h b/dnn/src/aarch64/matrix_mul/int8x8x16/strategy.h new file mode 100644 index 00000000..2b06ceef --- /dev/null +++ b/dnn/src/aarch64/matrix_mul/int8x8x16/strategy.h @@ -0,0 +1,27 @@ +/** + * \file dnn/src/aarch64/matrix_mul/int8x8x16/strategy.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once + +#include "src/fallback/matrix_mul/gemm_common.h" + +namespace megdnn { +namespace aarch64 { +namespace matmul { + +MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int16, dt_int16, 8, 8, 8, false, true, + gemm_s8x8x16_8x8); +MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int16, dt_int16, 4, 4, 16, false, true, + gemm_s8x8x16_4x4); + +} // namespace matmul +} // namespace aarch64 +} // namespace megdnn +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/opr_impl.cpp b/dnn/src/aarch64/matrix_mul/opr_impl.cpp new file mode 100644 index 00000000..384b2e12 --- /dev/null +++ b/dnn/src/aarch64/matrix_mul/opr_impl.cpp @@ -0,0 +1,93 @@ +/** + * \file dnn/src/aarch64/matrix_mul/opr_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +#include "src/aarch64/matrix_mul/opr_impl.h" +#include "src/aarch64/matrix_mul/algos.h" +#include "src/common/metahelper.h" +#include "src/common/utils.h" + +using namespace megdnn; +using namespace aarch64; + +class MatrixMulImpl::AlgoPack : NonCopyableObj { + AlgoF32K8x12x1 f32K8x12x1; + AlgoF32K4x16x1 f32k4x16x1; + AlgoF32MK4_4x16 f32mk4_4x16; + AlgoF32Gemv f32_gemv; +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + AlgoF16K8x24x1 f16_k8x24x1; + AlgoF16MK8_8x8 f16_mk8_8x8; +#endif +#if __ARM_FEATURE_DOTPROD + AlgoInt8x8x32K8x12x4DotProd int8x8x32_k8x12x4_dotprod; + AlgoInt8x8x32GemvDotProd int8x8x32_gemv_dotprod; +#else + AlgoInt8x8x32MK4_4x4x16 int8x8x32_mk4_4x4x16; + AlgoInt8x8x32K4x4x16 int8x8x32_k4x4x16; + AlgoInt8x8x32K8x8x8 int8x8x32_k8x8x8; + AlgoInt8x8x32Gemv int8x8x32_gemv; +#endif + AlgoInt8x8x16K8x8x8 int8x8x16_k8x8x8; + AlgoInt8x8x16K4x4x16 int8x8x16_k4x4x16; + + AlgoInt16x16x32K12x8x1 int16x16x32_k12x8x1; + AlgoInt16x16x32MK8_8x8 int16x16x32_mk8_8x8; + +#if __ARM_FEATURE_DOTPROD + AlgoQuint8K8x8x4DotProd quint8_k8x8x4_dotprod; + AlgoQuint8GemvDotProd quint8_gemv_dotprod; +#else + AlgoQuint8K8x8x8 quint8_k8x8x8; +#endif + +public: + SmallVector all_algos; + + AlgoPack() { + all_algos.emplace_back(&f32_gemv); + all_algos.emplace_back(&f32K8x12x1); + all_algos.emplace_back(&f32k4x16x1); + all_algos.emplace_back(&f32mk4_4x16); +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + all_algos.emplace_back(&f16_k8x24x1); + all_algos.emplace_back(&f16_mk8_8x8); +#endif +#if __ARM_FEATURE_DOTPROD + all_algos.emplace_back(&int8x8x32_gemv_dotprod); + all_algos.emplace_back(&int8x8x32_k8x12x4_dotprod); +#else + all_algos.emplace_back(&int8x8x32_gemv); + all_algos.emplace_back(&int8x8x32_k8x8x8); + all_algos.emplace_back(&int8x8x32_k4x4x16); + all_algos.emplace_back(&int8x8x32_mk4_4x4x16); +#endif + all_algos.emplace_back(&int8x8x16_k4x4x16); + all_algos.emplace_back(&int8x8x16_k8x8x8); + + all_algos.emplace_back(&int16x16x32_k12x8x1); + all_algos.emplace_back(&int16x16x32_mk8_8x8); +#if __ARM_FEATURE_DOTPROD + all_algos.emplace_back(&quint8_gemv_dotprod); + all_algos.emplace_back(&quint8_k8x8x4_dotprod); +#else + all_algos.emplace_back(&quint8_k8x8x8); +#endif + } +}; + +SmallVector MatrixMulImpl::algo_pack() { + static AlgoPack s_algo_pack; + auto&& algos = arm_common::MatrixMulImpl::algo_pack(); + algos.insert(algos.begin(), s_algo_pack.all_algos.begin(), + s_algo_pack.all_algos.end()); + return std::move(algos); +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/opr_impl.h b/dnn/src/aarch64/matrix_mul/opr_impl.h new file mode 100644 index 00000000..01ec762f --- /dev/null +++ b/dnn/src/aarch64/matrix_mul/opr_impl.h @@ -0,0 +1,63 @@ +/** + * \file dnn/src/aarch64/matrix_mul/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +#pragma once +#include "src/arm_common/matrix_mul/opr_impl.h" + +namespace megdnn { +namespace aarch64 { + +class MatrixMulImpl : public arm_common::MatrixMulImpl { +public: + using arm_common::MatrixMulImpl::MatrixMulImpl; + + SmallVector algo_pack() override; + +private: + class AlgoF32K8x12x1; // Aarch64 F32 Kernel 8X12X1 + class AlgoF32K4x16x1; // Aarch64 F32 Kernel 4x16x1 + class AlgoF32MK4_4x16; // Aarch64 F32 Format MK4 block 16x4 + class AlgoF32Gemv; // Aarch64 F32 Gemv +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + class AlgoF16K8x24x1; // Aarch64 F16 Kernel 8x24x1 + class AlgoF16MK8_8x8; // Aarch64 F16 Format MK8 block 16x8 +#endif + +#if __ARM_FEATURE_DOTPROD + class AlgoInt8x8x32K8x12x4DotProd; // Aarch64 Int8x8x32 Kernel + // 8x12x4 DotProduct + class AlgoInt8x8x32GemvDotProd; // Aarch64 Int8x8x32 Gemv DotProduct +#else + class AlgoInt8x8x32MK4_4x4x16; // Aarch64 nchw44 Int8x8x32 Kernel 4x4x16 + class AlgoInt8x8x32K4x4x16; // Aarch64 Int8x8x32 Kernel 4x4x16 + class AlgoInt8x8x32K8x8x8; // Aarch64 Int8x8x32 Kernel 8x8x8 + class AlgoInt8x8x32Gemv; // Aarch64 Int8x8x32 Gemv +#endif + class AlgoInt8x8x16K8x8x8; // Aarch64 Int8x8x16 Kernel 8x8x8 + class AlgoInt8x8x16K4x4x16; // Aarch64 Int8x8x16 Kernel 4x4x16 + + class AlgoInt16x16x32K12x8x1; // Aarch64 Int16x16x32 Kernel 12x8x1 + class AlgoInt16x16x32MK8_8x8; // Aarch64 Int16x16x32 Format MK8 block 8x8 + +#if __ARM_FEATURE_DOTPROD + class AlgoQuint8K8x8x4DotProd; // Aarch64 Quint8 Kernel + // 8x8x4 DotProduct + class AlgoQuint8GemvDotProd; // Aarch64 Quint8 Gemv DotProduct +#else + class AlgoQuint8K8x8x8; // Aarch64 Quint8 Kernel 8x8x8 +#endif + + class AlgoPack; +}; + +} // namespace aarch64 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/quint8/kernel_8x8x8.h b/dnn/src/aarch64/matrix_mul/quint8/kernel_8x8x8.h new file mode 100644 index 00000000..d1d12a55 --- /dev/null +++ b/dnn/src/aarch64/matrix_mul/quint8/kernel_8x8x8.h @@ -0,0 +1,1398 @@ +/** + * \file dnn/src/aarch64/matrix_mul/quint8/kernel_8x8x8.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#if !(__ARM_FEATURE_DOTPROD) +#include "src/aarch64/matrix_mul/asm/common.h" +#include "src/arm_common/simd_macro/marm_neon.h" + +namespace megdnn { +namespace aarch64 { +namespace matmul_8x8x8 { + +/** + * Overview of register layout: + * + * A 8x8x8 cell of Rhs is stored in 8bit in q26-q27 + * A 8x8x8 cell of Lhs is stored in 8bit in q0-q7 + * A 8x8 block of accumulators is stored in 32bit in q8-q23 + * zero_point_A is stored in 8bit in q24 + * zero_point_B is stored in 8bit in q25. 
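+ *
+ * In scalar terms each accumulator ends up holding (a sketch of the intended
+ * semantics, with A/B taken from the packed panels and za/zb the zero points):
+ *
+ *   for (int k = 0; k < K; ++k)
+ *       C[m][n] += (int32_t(A[m][k]) - za) * (int32_t(B[k][n]) - zb);
+ *
+ * usubl performs the subtract-and-widen to 16 bit, then smlal/smlal2 do the
+ * widening multiply-accumulate into the 32-bit accumulators shown below.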
+ * + * +--------+--------+ + * |v26[0-8]|v27[0-8]| + * Rhs +--------+--------+ + * Lhs | | | + * + * +--------+ - - - - +-----------------+ + * |v0[0-8]| | v8[0-4]| v9[0-4]| + * |v1[0-8]| |v10[0-4]|v11[0-4]| + * |v2[0-8]| |v12[0-4]|v13[0-4]| + * |v3[0-8]| |v14[0-4]|v15[0-4]| + * |v4[0-8]| |v16[0-4]|v17[0-4]| + * |v5[0-8]| |v18[0-4]|v19[0-4]| + * |v6[0-8]| |v20[0-4]|v21[0-4]| + * |v7[0-8]| |v22[0-4]|v23[0-4]| + * +--------+ - - - - +-----------------+ + * + * Accumulator + */ + +static void kern_8x8(const uint8_t* packA, const uint8_t* packB, int K, + int32_t* output, int LDC, bool is_first_k, uint8_t za, + uint8_t zb) { + K /= 8; + const uint8_t* a_ptr = packA; + const uint8_t* b_ptr = packB; + + LDC = LDC * sizeof(int32_t); + + asm volatile( + // load accumulator C + "add x1, %[output], %x[LDC]\n" + "add x2, x1, %x[LDC]\n" + "add x3, x2, %x[LDC]\n" + "add x4, x3, %x[LDC]\n" + "add x5, x4, %x[LDC]\n" + "add x6, x5, %x[LDC]\n" + "add x7, x6, %x[LDC]\n" + "dup v24.8b, %w[za]\n" + "dup v25.8b, %w[zb]\n" + "cmp %w[is_first_k], #1\n" + "beq 1f\n" + + "ldp q8, q9, [%[output]]\n" + "ldp q10, q11, [x1]\n" + "ldp q12, q13, [x2]\n" + "ldp q14, q15, [x3]\n" + "ldp q16, q17, [x4]\n" + "ldp q18, q19, [x5]\n" + "ldp q20, q21, [x6]\n" + "ldp q22, q23, [x7]\n" + "b 2f\n" + + "1:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "eor v9.16b, v9.16b, v9.16b\n" + "eor v10.16b, v10.16b, v10.16b\n" + "eor v11.16b, v11.16b, v11.16b\n" + "eor v12.16b, v12.16b, v12.16b\n" + "eor v13.16b, v13.16b, v13.16b\n" + "eor v14.16b, v14.16b, v14.16b\n" + "eor v15.16b, v15.16b, v15.16b\n" + "eor v16.16b, v16.16b, v16.16b\n" + "eor v17.16b, v17.16b, v17.16b\n" + "eor v18.16b, v18.16b, v18.16b\n" + "eor v19.16b, v19.16b, v19.16b\n" + "eor v20.16b, v20.16b, v20.16b\n" + "eor v21.16b, v21.16b, v21.16b\n" + "eor v22.16b, v22.16b, v22.16b\n" + "eor v23.16b, v23.16b, v23.16b\n" + + "2: \n" + "ld1 {v26.8b}, [%[b_ptr]], 8\n" + "ld1 {v0.8b}, [%[a_ptr]], 8\n" + "ld1 {v1.8b}, [%[a_ptr]], 8\n" + "ld1 {v2.8b}, [%[a_ptr]], 8\n" + "ld1 {v3.8b}, [%[a_ptr]], 8\n" + "ld1 {v4.8b}, [%[a_ptr]], 8\n" + "ld1 {v5.8b}, [%[a_ptr]], 8\n" + "ld1 {v6.8b}, [%[a_ptr]], 8\n" + "ld1 {v7.8b}, [%[a_ptr]], 8\n" + "usubl v26.8h, v26.8b, v25.8b\n" + "usubl v0.8h, v0.8b, v24.8b\n" + "usubl v1.8h, v1.8b, v24.8b\n" + "usubl v2.8h, v2.8b, v24.8b\n" + "usubl v3.8h, v3.8b, v24.8b\n" + "usubl v4.8h, v4.8b, v24.8b\n" + "usubl v5.8h, v5.8b, v24.8b\n" + "usubl v6.8h, v6.8b, v24.8b\n" + "usubl v7.8h, v7.8b, v24.8b\n" + + "ld1 {v27.8b}, [%[b_ptr]], 8\n" + "smlal v8.4s, v26.4h, v0.h[0]\n" + "smlal v10.4s, v26.4h, v1.h[0]\n" + "smlal v12.4s, v26.4h, v2.h[0]\n" + "smlal v14.4s, v26.4h, v3.h[0]\n" + "smlal v16.4s, v26.4h, v4.h[0]\n" + "smlal v18.4s, v26.4h, v5.h[0]\n" + "smlal v20.4s, v26.4h, v6.h[0]\n" + "smlal v22.4s, v26.4h, v7.h[0]\n" + "usubl v27.8h, v27.8b, v25.8b\n" + "smlal2 v9.4s, v26.8h, v0.h[0]\n" + "smlal2 v11.4s, v26.8h, v1.h[0]\n" + "smlal2 v13.4s, v26.8h, v2.h[0]\n" + "smlal2 v15.4s, v26.8h, v3.h[0]\n" + "smlal2 v17.4s, v26.8h, v4.h[0]\n" + "smlal2 v19.4s, v26.8h, v5.h[0]\n" + "smlal2 v21.4s, v26.8h, v6.h[0]\n" + "smlal2 v23.4s, v26.8h, v7.h[0]\n" + + "ld1 {v26.8b}, [%[b_ptr]], 8\n" + "smlal v8.4s, v27.4h, v0.h[1]\n" + "smlal v10.4s, v27.4h, v1.h[1]\n" + "smlal v12.4s, v27.4h, v2.h[1]\n" + "smlal v14.4s, v27.4h, v3.h[1]\n" + "smlal v16.4s, v27.4h, v4.h[1]\n" + "smlal v18.4s, v27.4h, v5.h[1]\n" + "smlal v20.4s, v27.4h, v6.h[1]\n" + "smlal v22.4s, v27.4h, v7.h[1]\n" + "usubl v26.8h, v26.8b, v25.8b\n" + "smlal2 v9.4s, v27.8h, v0.h[1]\n" + "smlal2 v11.4s, v27.8h, v1.h[1]\n" + 
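            // smlal accumulates output columns 0-3 (low half of the widened B row), smlal2
            // columns 4-7 (high half); the same pattern repeats for the remaining unrolled
            // k-steps, with v26/v27 double-buffering one widened B row each.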
"smlal2 v13.4s, v27.8h, v2.h[1]\n" + "smlal2 v15.4s, v27.8h, v3.h[1]\n" + "smlal2 v17.4s, v27.8h, v4.h[1]\n" + "smlal2 v19.4s, v27.8h, v5.h[1]\n" + "smlal2 v21.4s, v27.8h, v6.h[1]\n" + "smlal2 v23.4s, v27.8h, v7.h[1]\n" + + "ld1 {v27.8b}, [%[b_ptr]], 8\n" + "smlal v8.4s, v26.4h, v0.h[2]\n" + "smlal v10.4s, v26.4h, v1.h[2]\n" + "smlal v12.4s, v26.4h, v2.h[2]\n" + "smlal v14.4s, v26.4h, v3.h[2]\n" + "smlal v16.4s, v26.4h, v4.h[2]\n" + "smlal v18.4s, v26.4h, v5.h[2]\n" + "smlal v20.4s, v26.4h, v6.h[2]\n" + "smlal v22.4s, v26.4h, v7.h[2]\n" + "usubl v27.8h, v27.8b, v25.8b\n" + "smlal2 v9.4s, v26.8h, v0.h[2]\n" + "smlal2 v11.4s, v26.8h, v1.h[2]\n" + "smlal2 v13.4s, v26.8h, v2.h[2]\n" + "smlal2 v15.4s, v26.8h, v3.h[2]\n" + "smlal2 v17.4s, v26.8h, v4.h[2]\n" + "smlal2 v19.4s, v26.8h, v5.h[2]\n" + "smlal2 v21.4s, v26.8h, v6.h[2]\n" + "smlal2 v23.4s, v26.8h, v7.h[2]\n" + + "ld1 {v26.8b}, [%[b_ptr]], 8\n" + "smlal v8.4s, v27.4h, v0.h[3]\n" + "smlal v10.4s, v27.4h, v1.h[3]\n" + "smlal v12.4s, v27.4h, v2.h[3]\n" + "smlal v14.4s, v27.4h, v3.h[3]\n" + "smlal v16.4s, v27.4h, v4.h[3]\n" + "smlal v18.4s, v27.4h, v5.h[3]\n" + "smlal v20.4s, v27.4h, v6.h[3]\n" + "smlal v22.4s, v27.4h, v7.h[3]\n" + "usubl v26.8h, v26.8b, v25.8b\n" + "smlal2 v9.4s, v27.8h, v0.h[3]\n" + "smlal2 v11.4s, v27.8h, v1.h[3]\n" + "smlal2 v13.4s, v27.8h, v2.h[3]\n" + "smlal2 v15.4s, v27.8h, v3.h[3]\n" + "smlal2 v17.4s, v27.8h, v4.h[3]\n" + "smlal2 v19.4s, v27.8h, v5.h[3]\n" + "smlal2 v21.4s, v27.8h, v6.h[3]\n" + "smlal2 v23.4s, v27.8h, v7.h[3]\n" + + "ld1 {v27.8b}, [%[b_ptr]], 8\n" + "smlal v8.4s, v26.4h, v0.h[4]\n" + "smlal v10.4s, v26.4h, v1.h[4]\n" + "smlal v12.4s, v26.4h, v2.h[4]\n" + "smlal v14.4s, v26.4h, v3.h[4]\n" + "smlal v16.4s, v26.4h, v4.h[4]\n" + "smlal v18.4s, v26.4h, v5.h[4]\n" + "smlal v20.4s, v26.4h, v6.h[4]\n" + "smlal v22.4s, v26.4h, v7.h[4]\n" + "usubl v27.8h, v27.8b, v25.8b\n" + "smlal2 v9.4s, v26.8h, v0.h[4]\n" + "smlal2 v11.4s, v26.8h, v1.h[4]\n" + "smlal2 v13.4s, v26.8h, v2.h[4]\n" + "smlal2 v15.4s, v26.8h, v3.h[4]\n" + "smlal2 v17.4s, v26.8h, v4.h[4]\n" + "smlal2 v19.4s, v26.8h, v5.h[4]\n" + "smlal2 v21.4s, v26.8h, v6.h[4]\n" + "smlal2 v23.4s, v26.8h, v7.h[4]\n" + + "ld1 {v26.8b}, [%[b_ptr]], 8\n" + "smlal v8.4s, v27.4h, v0.h[5]\n" + "smlal v10.4s, v27.4h, v1.h[5]\n" + "smlal v12.4s, v27.4h, v2.h[5]\n" + "smlal v14.4s, v27.4h, v3.h[5]\n" + "smlal v16.4s, v27.4h, v4.h[5]\n" + "smlal v18.4s, v27.4h, v5.h[5]\n" + "smlal v20.4s, v27.4h, v6.h[5]\n" + "smlal v22.4s, v27.4h, v7.h[5]\n" + "usubl v26.8h, v26.8b, v25.8b\n" + "smlal2 v9.4s, v27.8h, v0.h[5]\n" + "smlal2 v11.4s, v27.8h, v1.h[5]\n" + "smlal2 v13.4s, v27.8h, v2.h[5]\n" + "smlal2 v15.4s, v27.8h, v3.h[5]\n" + "smlal2 v17.4s, v27.8h, v4.h[5]\n" + "smlal2 v19.4s, v27.8h, v5.h[5]\n" + "smlal2 v21.4s, v27.8h, v6.h[5]\n" + "smlal2 v23.4s, v27.8h, v7.h[5]\n" + + "ld1 {v27.8b}, [%[b_ptr]], 8\n" + "smlal v8.4s, v26.4h, v0.h[6]\n" + "smlal v10.4s, v26.4h, v1.h[6]\n" + "smlal v12.4s, v26.4h, v2.h[6]\n" + "smlal v14.4s, v26.4h, v3.h[6]\n" + "smlal v16.4s, v26.4h, v4.h[6]\n" + "smlal v18.4s, v26.4h, v5.h[6]\n" + "smlal v20.4s, v26.4h, v6.h[6]\n" + "smlal v22.4s, v26.4h, v7.h[6]\n" + "usubl v27.8h, v27.8b, v25.8b\n" + "smlal2 v9.4s, v26.8h, v0.h[6]\n" + "smlal2 v11.4s, v26.8h, v1.h[6]\n" + "smlal2 v13.4s, v26.8h, v2.h[6]\n" + "smlal2 v15.4s, v26.8h, v3.h[6]\n" + "smlal2 v17.4s, v26.8h, v4.h[6]\n" + "smlal2 v19.4s, v26.8h, v5.h[6]\n" + "smlal2 v21.4s, v26.8h, v6.h[6]\n" + "smlal2 v23.4s, v26.8h, v7.h[6]\n" + + "smlal v8.4s, v27.4h, v0.h[7]\n" + "smlal v10.4s, v27.4h, 
v1.h[7]\n" + "smlal v12.4s, v27.4h, v2.h[7]\n" + "smlal v14.4s, v27.4h, v3.h[7]\n" + "smlal v16.4s, v27.4h, v4.h[7]\n" + "smlal v18.4s, v27.4h, v5.h[7]\n" + "smlal v20.4s, v27.4h, v6.h[7]\n" + "smlal v22.4s, v27.4h, v7.h[7]\n" + "smlal2 v9.4s, v27.8h, v0.h[7]\n" + "smlal2 v11.4s, v27.8h, v1.h[7]\n" + "smlal2 v13.4s, v27.8h, v2.h[7]\n" + "smlal2 v15.4s, v27.8h, v3.h[7]\n" + "smlal2 v17.4s, v27.8h, v4.h[7]\n" + "smlal2 v19.4s, v27.8h, v5.h[7]\n" + "smlal2 v21.4s, v27.8h, v6.h[7]\n" + "smlal2 v23.4s, v27.8h, v7.h[7]\n" + + "subs %w[K], %w[K], #1\n" + "cbnz %w[K], 2b\n" + + "3:\n" + "stp q8, q9, [%[output]]\n" + "stp q10, q11, [x1]\n" + "stp q12, q13, [x2]\n" + "stp q14, q15, [x3]\n" + "stp q16, q17, [x4]\n" + "stp q18, q19, [x5]\n" + "stp q20, q21, [x6]\n" + "stp q22, q23, [x7]\n" + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), + [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), + [output] "+r"(output), [za] "+r"(za), [zb] "+r"(zb) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "x1", + "x2", "x3", "x4", "x5", "x6", "x7", "cc", "memory"); +} + +/** + * Overview of register layout: + * + * A 8x4x8 cell of Rhs is stored in 8bit in q16-q17 + * A 8x8x8 cell of Lhs is stored in 8bit in q0-q7 + * A 8x4 block of accumulators is stored in 32bit in q8-q15 + * zero_point_A is stored in 8bit in q18 + * zero_point_B is stored in 8bit in q19. + * + * +--------+ + * |v16[0-4]| + * Rhs +--------+ + * |v17[0-4]| + * Lhs +--------+ + * + * +--------+ - - - - +--------+ + * |v0[0-8]| | v8[0-4]| + * |v1[0-8]| | v9[0-4]| + * |v2[0-8]| |v10[0-4]| + * |v3[0-8]| |v11[0-4]| + * |v4[0-8]| |v12[0-4]| + * |v5[0-8]| |v13[0-4]| + * |v6[0-8]| |v14[0-4]| + * |v7[0-8]| |v15[0-4]| + * +--------+ - - - - +--------+ + * + * Accumulator + */ + +static void kern_8x4(const uint8_t* packA, const uint8_t* packB, int K, + int32_t* output, int LDC, bool is_first_k, size_t n_remain, + uint8_t za, uint8_t zb) { + K /= 8; + const uint8_t* a_ptr = packA; + const uint8_t* b_ptr = packB; + + LDC = LDC * sizeof(int32_t); + int32_t* outptr0 = output; + int32_t* outptr1; + int32_t* outptr2; + int32_t* outptr3; + int32_t* outptr4; + int32_t* outptr5; + int32_t* outptr6; + int32_t* outptr7; + size_t x0 = 0; + +// clang-format off +#define LOAD_LINE(reg_index, n) \ + "mov %[x0], %[outptr" n "]\n" \ + "cmp %w[n_remain], #4\n" \ + "blt 100" n "f\n" \ + "ldr q" reg_index ", [%[x0]] \n" \ + "b 101" n "f\n" \ + "100" n ":\n" \ + "cmp %w[n_remain], #0\n" \ + "beq 101" n "f\n" \ + "ld1 {v" reg_index ".s}[0], [%[x0]], #4\n" \ + "cmp %w[n_remain], #1\n" \ + "beq 101" n "f\n" \ + "ld1 {v" reg_index ".s}[1], [%[x0]], #4\n" \ + "cmp %w[n_remain], #2\n" \ + "beq 101" n "f\n" \ + "ld1 {v" reg_index ".s}[2], [%[x0]], #4\n" \ + "101" n ":\n" + +#define LOAD_C \ + LOAD_LINE("8", "0") \ + LOAD_LINE("9", "1") \ + LOAD_LINE("10", "2") \ + LOAD_LINE("11", "3") \ + LOAD_LINE("12", "4") \ + LOAD_LINE("13", "5") \ + LOAD_LINE("14", "6") \ + LOAD_LINE("15", "7") + +#define STORE_LINE(reg_index, n) \ + "mov %[x0], %[outptr" n "]\n" \ + "cmp %w[n_remain], #4\n" \ + "blt 102" n "f\n" \ + "str q" reg_index ", [%[x0]]\n" \ + "b 103" n "f\n" \ + "102" n ":\n" \ + "cmp %w[n_remain], #0\n" \ + "beq 103" n "f\n" \ + "st1 {v" reg_index ".s}[0], [%[x0]], #4\n" \ + "cmp %w[n_remain], #1\n" \ + "beq 103" n "f\n" \ + "st1 {v" reg_index ".s}[1], [%[x0]], #4\n" \ + "cmp %w[n_remain], #2\n" \ + "beq 103" n "f\n" \ + "st1 {v" 
reg_index ".s}[2], [%[x0]], #4\n" \ + "103" n ":\n" + +#define STORE_C \ + STORE_LINE("8", "0") \ + STORE_LINE("9", "1") \ + STORE_LINE("10", "2") \ + STORE_LINE("11", "3") \ + STORE_LINE("12", "4") \ + STORE_LINE("13", "5") \ + STORE_LINE("14", "6") \ + STORE_LINE("15", "7") + + // clang-format on + + asm volatile( + // load accumulator C + "add %[outptr1], %[outptr0], %x[LDC]\n" + "add %[outptr2], %[outptr1], %x[LDC]\n" + "add %[outptr3], %[outptr2], %x[LDC]\n" + "add %[outptr4], %[outptr3], %x[LDC]\n" + "add %[outptr5], %[outptr4], %x[LDC]\n" + "add %[outptr6], %[outptr5], %x[LDC]\n" + "add %[outptr7], %[outptr6], %x[LDC]\n" + "dup v18.8b, %w[za]\n" + "dup v19.8b, %w[zb]\n" + "cmp %w[is_first_k], #1\n" + "beq 1f\n" LOAD_C + + "b 2f\n" + + "1:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "eor v9.16b, v9.16b, v9.16b\n" + "eor v10.16b, v10.16b, v10.16b\n" + "eor v11.16b, v11.16b, v11.16b\n" + "eor v12.16b, v12.16b, v12.16b\n" + "eor v13.16b, v13.16b, v13.16b\n" + "eor v14.16b, v14.16b, v14.16b\n" + "eor v15.16b, v15.16b, v15.16b\n" + + "2: \n" + "ld1 {v16.s}[0], [%[b_ptr]], 4\n" + "ld1 {v0.8b}, [%[a_ptr]], 8\n" + "ld1 {v1.8b}, [%[a_ptr]], 8\n" + "ld1 {v2.8b}, [%[a_ptr]], 8\n" + "ld1 {v3.8b}, [%[a_ptr]], 8\n" + "ld1 {v4.8b}, [%[a_ptr]], 8\n" + "ld1 {v5.8b}, [%[a_ptr]], 8\n" + "ld1 {v6.8b}, [%[a_ptr]], 8\n" + "ld1 {v7.8b}, [%[a_ptr]], 8\n" + "usubl v16.8h, v16.8b, v19.8b\n" + "usubl v0.8h, v0.8b, v18.8b\n" + "usubl v1.8h, v1.8b, v18.8b\n" + "usubl v2.8h, v2.8b, v18.8b\n" + "usubl v3.8h, v3.8b, v18.8b\n" + "usubl v4.8h, v4.8b, v18.8b\n" + "usubl v5.8h, v5.8b, v18.8b\n" + "usubl v6.8h, v6.8b, v18.8b\n" + "usubl v7.8h, v7.8b, v18.8b\n" + + "ld1 {v17.s}[0], [%[b_ptr]], 4\n" + "smlal v8.4s, v16.4h, v0.h[0]\n" + "smlal v9.4s, v16.4h, v1.h[0]\n" + "smlal v10.4s, v16.4h, v2.h[0]\n" + "smlal v11.4s, v16.4h, v3.h[0]\n" + "usubl v17.8h, v17.8b, v19.8b\n" + "smlal v12.4s, v16.4h, v4.h[0]\n" + "smlal v13.4s, v16.4h, v5.h[0]\n" + "smlal v14.4s, v16.4h, v6.h[0]\n" + "smlal v15.4s, v16.4h, v7.h[0]\n" + + "ld1 {v16.s}[0], [%[b_ptr]], 4\n" + "smlal v8.4s, v17.4h, v0.h[1]\n" + "smlal v9.4s, v17.4h, v1.h[1]\n" + "smlal v10.4s, v17.4h, v2.h[1]\n" + "smlal v11.4s, v17.4h, v3.h[1]\n" + "usubl v16.8h, v16.8b, v19.8b\n" + "smlal v12.4s, v17.4h, v4.h[1]\n" + "smlal v13.4s, v17.4h, v5.h[1]\n" + "smlal v14.4s, v17.4h, v6.h[1]\n" + "smlal v15.4s, v17.4h, v7.h[1]\n" + + "ld1 {v17.s}[0], [%[b_ptr]], 4\n" + "smlal v8.4s, v16.4h, v0.h[2]\n" + "smlal v9.4s, v16.4h, v1.h[2]\n" + "smlal v10.4s, v16.4h, v2.h[2]\n" + "smlal v11.4s, v16.4h, v3.h[2]\n" + "usubl v17.8h, v17.8b, v19.8b\n" + "smlal v12.4s, v16.4h, v4.h[2]\n" + "smlal v13.4s, v16.4h, v5.h[2]\n" + "smlal v14.4s, v16.4h, v6.h[2]\n" + "smlal v15.4s, v16.4h, v7.h[2]\n" + + "ld1 {v16.s}[0], [%[b_ptr]], 4\n" + "smlal v8.4s, v17.4h, v0.h[3]\n" + "smlal v9.4s, v17.4h, v1.h[3]\n" + "smlal v10.4s, v17.4h, v2.h[3]\n" + "smlal v11.4s, v17.4h, v3.h[3]\n" + "usubl v16.8h, v16.8b, v19.8b\n" + "smlal v12.4s, v17.4h, v4.h[3]\n" + "smlal v13.4s, v17.4h, v5.h[3]\n" + "smlal v14.4s, v17.4h, v6.h[3]\n" + "smlal v15.4s, v17.4h, v7.h[3]\n" + + "ld1 {v17.s}[0], [%[b_ptr]], 4\n" + "smlal v8.4s, v16.4h, v0.h[4]\n" + "smlal v9.4s, v16.4h, v1.h[4]\n" + "smlal v10.4s, v16.4h, v2.h[4]\n" + "smlal v11.4s, v16.4h, v3.h[4]\n" + "usubl v17.8h, v17.8b, v19.8b\n" + "smlal v12.4s, v16.4h, v4.h[4]\n" + "smlal v13.4s, v16.4h, v5.h[4]\n" + "smlal v14.4s, v16.4h, v6.h[4]\n" + "smlal v15.4s, v16.4h, v7.h[4]\n" + + "ld1 {v16.s}[0], [%[b_ptr]], 4\n" + "smlal v8.4s, v17.4h, v0.h[5]\n" + "smlal v9.4s, v17.4h, v1.h[5]\n" 
+ "smlal v10.4s, v17.4h, v2.h[5]\n" + "smlal v11.4s, v17.4h, v3.h[5]\n" + "usubl v16.8h, v16.8b, v19.8b\n" + "smlal v12.4s, v17.4h, v4.h[5]\n" + "smlal v13.4s, v17.4h, v5.h[5]\n" + "smlal v14.4s, v17.4h, v6.h[5]\n" + "smlal v15.4s, v17.4h, v7.h[5]\n" + + "ld1 {v17.s}[0], [%[b_ptr]], 4\n" + "smlal v8.4s, v16.4h, v0.h[6]\n" + "smlal v9.4s, v16.4h, v1.h[6]\n" + "smlal v10.4s, v16.4h, v2.h[6]\n" + "smlal v11.4s, v16.4h, v3.h[6]\n" + "usubl v17.8h, v17.8b, v19.8b\n" + "smlal v12.4s, v16.4h, v4.h[6]\n" + "smlal v13.4s, v16.4h, v5.h[6]\n" + "smlal v14.4s, v16.4h, v6.h[6]\n" + "smlal v15.4s, v16.4h, v7.h[6]\n" + + "smlal v8.4s, v17.4h, v0.h[7]\n" + "smlal v9.4s, v17.4h, v1.h[7]\n" + "smlal v10.4s, v17.4h, v2.h[7]\n" + "smlal v11.4s, v17.4h, v3.h[7]\n" + "smlal v12.4s, v17.4h, v4.h[7]\n" + "smlal v13.4s, v17.4h, v5.h[7]\n" + "smlal v14.4s, v17.4h, v6.h[7]\n" + "smlal v15.4s, v17.4h, v7.h[7]\n" + + "subs %w[K], %w[K], #1\n" + "cbnz %w[K], 2b\n" + + "3:\n" STORE_C + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), + [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), + [outptr0] "+r"(outptr0), [za] "+r"(za), [zb] "+r"(zb), + [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), + [outptr3] "=r"(outptr3), [outptr4] "=r"(outptr4), + [outptr5] "=r"(outptr5), [outptr6] "=r"(outptr6), + [outptr7] "=r"(outptr7), [x0] "+r"(x0), [n_remain] "+r"(n_remain) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "cc", "memory"); + +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +} + +/** + * Overview of register layout: + * + * A 8x8x8 cell of Rhs is stored in 8bit in q12-q13 + * A 8x8x4 cell of Lhs is stored in 8bit in q0-q3 + * A 4x8 block of accumulators is stored in 32bit in q4-q11 + * zero_point_A is stored in 8bit in q14 + * zero_point_B is stored in 8bit in q15. 
+ * + * +--------+--------+ + * |v12[0-8]|v13[0-8]| + * Rhs +--------+--------+ + * Lhs | | | + * + * +--------+ - - - - +-----------------+ + * |v0[0-8]| | v4[0-4]| v5[0-4]| + * |v1[0-8]| | v6[0-4]| v7[0-4]| + * |v2[0-8]| | v8[0-4]| v9[0-4]| + * |v3[0-8]| |v10[0-4]|v11[0-4]| + * +--------+ - - - - +-----------------+ + * + * Accumulator + */ + +static void kern_4x8(const uint8_t* packA, const uint8_t* packB, int K, + int32_t* output, int LDC, bool is_first_k, size_t m_remain, + uint8_t za, uint8_t zb) { + K /= 8; + const uint8_t* a_ptr = packA; + const uint8_t* b_ptr = packB; + + LDC = LDC * sizeof(int32_t); + int32_t* outptr0 = output; + int32_t* outptr1; + int32_t* outptr2; + int32_t* outptr3; + size_t x0 = 0; + +// clang-format off +#define LOAD_LINE(v1, v2, m) \ + "cbz %[x0], 100f\n" \ + "ldp " v1 "," v2 ", [%[outptr" m "]]\n" \ + "subs %[x0], %[x0], #1\n" + +#define LOAD_C \ + "mov %[x0], %x[m_remain]\n" \ + LOAD_LINE("q4", "q5", "0") \ + LOAD_LINE("q6", "q7", "1") \ + LOAD_LINE("q8", "q9", "2") \ + LOAD_LINE("q10", "q11", "3") \ + "100:\n" + +#define STORE_LINE(v1, v2, m) \ + "cbz %[x0], 101f\n" \ + "stp " v1 "," v2", [%[outptr" m "]]\n" \ + "subs %[x0], %[x0], #1\n" + +#define STORE_C \ + "mov %[x0], %x[m_remain]\n" \ + STORE_LINE("q4", "q5", "0") \ + STORE_LINE("q6", "q7", "1") \ + STORE_LINE("q8", "q9", "2") \ + STORE_LINE("q10", "q11", "3") \ + "101:\n" + + // clang-format on + + asm volatile( + // load accumulator C + "add %[outptr1], %[outptr0], %x[LDC]\n" + "add %[outptr2], %[outptr1], %x[LDC]\n" + "add %[outptr3], %[outptr2], %x[LDC]\n" + "dup v14.8b, %w[za]\n" + "dup v15.8b, %w[zb]\n" + "cmp %w[is_first_k], #1\n" + "beq 1f\n" LOAD_C + + "b 2f\n" + + "1:\n" + "eor v4.16b, v4.16b, v4.16b\n" + "eor v5.16b, v5.16b, v5.16b\n" + "eor v6.16b, v6.16b, v6.16b\n" + "eor v7.16b, v7.16b, v7.16b\n" + "eor v8.16b, v8.16b, v8.16b\n" + "eor v9.16b, v9.16b, v9.16b\n" + "eor v10.16b, v10.16b, v10.16b\n" + "eor v11.16b, v11.16b, v11.16b\n" + + "2: \n" + "ld1 {v12.8b}, [%[b_ptr]], 8\n" + "ld1 {v0.8b}, [%[a_ptr]], 8\n" + "ld1 {v1.8b}, [%[a_ptr]], 8\n" + "ld1 {v2.8b}, [%[a_ptr]], 8\n" + "ld1 {v3.8b}, [%[a_ptr]], 8\n" + "usubl v12.8h, v12.8b, v15.8b\n" + "usubl v0.8h, v0.8b, v14.8b\n" + "usubl v1.8h, v1.8b, v14.8b\n" + "usubl v2.8h, v2.8b, v14.8b\n" + "usubl v3.8h, v3.8b, v14.8b\n" + + "ld1 {v13.8b}, [%[b_ptr]], 8\n" + "smlal v4.4s, v12.4h, v0.h[0]\n" + "smlal v6.4s, v12.4h, v1.h[0]\n" + "smlal v8.4s, v12.4h, v2.h[0]\n" + "smlal v10.4s, v12.4h, v3.h[0]\n" + "usubl v13.8h, v13.8b, v15.8b\n" + "smlal2 v5.4s, v12.8h, v0.h[0]\n" + "smlal2 v7.4s, v12.8h, v1.h[0]\n" + "smlal2 v9.4s, v12.8h, v2.h[0]\n" + "smlal2 v11.4s, v12.8h, v3.h[0]\n" + + "ld1 {v12.8b}, [%[b_ptr]], 8\n" + "smlal v4.4s, v13.4h, v0.h[1]\n" + "smlal v6.4s, v13.4h, v1.h[1]\n" + "smlal v8.4s, v13.4h, v2.h[1]\n" + "smlal v10.4s, v13.4h, v3.h[1]\n" + "usubl v12.8h, v12.8b, v15.8b\n" + "smlal2 v5.4s, v13.8h, v0.h[1]\n" + "smlal2 v7.4s, v13.8h, v1.h[1]\n" + "smlal2 v9.4s, v13.8h, v2.h[1]\n" + "smlal2 v11.4s, v13.8h, v3.h[1]\n" + + "ld1 {v13.8b}, [%[b_ptr]], 8\n" + "smlal v4.4s, v12.4h, v0.h[2]\n" + "smlal v6.4s, v12.4h, v1.h[2]\n" + "smlal v8.4s, v12.4h, v2.h[2]\n" + "smlal v10.4s, v12.4h, v3.h[2]\n" + "usubl v13.8h, v13.8b, v15.8b\n" + "smlal2 v5.4s, v12.8h, v0.h[2]\n" + "smlal2 v7.4s, v12.8h, v1.h[2]\n" + "smlal2 v9.4s, v12.8h, v2.h[2]\n" + "smlal2 v11.4s, v12.8h, v3.h[2]\n" + + "ld1 {v12.8b}, [%[b_ptr]], 8\n" + "smlal v4.4s, v13.4h, v0.h[3]\n" + "smlal v6.4s, v13.4h, v1.h[3]\n" + "smlal v8.4s, v13.4h, v2.h[3]\n" + "smlal v10.4s, v13.4h, 
v3.h[3]\n" + "usubl v12.8h, v12.8b, v15.8b\n" + "smlal2 v5.4s, v13.8h, v0.h[3]\n" + "smlal2 v7.4s, v13.8h, v1.h[3]\n" + "smlal2 v9.4s, v13.8h, v2.h[3]\n" + "smlal2 v11.4s, v13.8h, v3.h[3]\n" + + "ld1 {v13.8b}, [%[b_ptr]], 8\n" + "smlal v4.4s, v12.4h, v0.h[4]\n" + "smlal v6.4s, v12.4h, v1.h[4]\n" + "smlal v8.4s, v12.4h, v2.h[4]\n" + "smlal v10.4s, v12.4h, v3.h[4]\n" + "usubl v13.8h, v13.8b, v15.8b\n" + "smlal2 v5.4s, v12.8h, v0.h[4]\n" + "smlal2 v7.4s, v12.8h, v1.h[4]\n" + "smlal2 v9.4s, v12.8h, v2.h[4]\n" + "smlal2 v11.4s, v12.8h, v3.h[4]\n" + + "ld1 {v12.8b}, [%[b_ptr]], 8\n" + "smlal v4.4s, v13.4h, v0.h[5]\n" + "smlal v6.4s, v13.4h, v1.h[5]\n" + "smlal v8.4s, v13.4h, v2.h[5]\n" + "smlal v10.4s, v13.4h, v3.h[5]\n" + "usubl v12.8h, v12.8b, v15.8b\n" + "smlal2 v5.4s, v13.8h, v0.h[5]\n" + "smlal2 v7.4s, v13.8h, v1.h[5]\n" + "smlal2 v9.4s, v13.8h, v2.h[5]\n" + "smlal2 v11.4s, v13.8h, v3.h[5]\n" + + "ld1 {v13.8b}, [%[b_ptr]], 8\n" + "smlal v4.4s, v12.4h, v0.h[6]\n" + "smlal v6.4s, v12.4h, v1.h[6]\n" + "smlal v8.4s, v12.4h, v2.h[6]\n" + "smlal v10.4s, v12.4h, v3.h[6]\n" + "usubl v13.8h, v13.8b, v15.8b\n" + "smlal2 v5.4s, v12.8h, v0.h[6]\n" + "smlal2 v7.4s, v12.8h, v1.h[6]\n" + "smlal2 v9.4s, v12.8h, v2.h[6]\n" + "smlal2 v11.4s, v12.8h, v3.h[6]\n" + + "smlal v4.4s, v13.4h, v0.h[7]\n" + "smlal v6.4s, v13.4h, v1.h[7]\n" + "smlal v8.4s, v13.4h, v2.h[7]\n" + "smlal v10.4s, v13.4h, v3.h[7]\n" + "smlal2 v5.4s, v13.8h, v0.h[7]\n" + "smlal2 v7.4s, v13.8h, v1.h[7]\n" + "smlal2 v9.4s, v13.8h, v2.h[7]\n" + "smlal2 v11.4s, v13.8h, v3.h[7]\n" + + "subs %w[K], %w[K], #1\n" + "cbnz %w[K], 2b\n" + + "3:\n" STORE_C + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), + [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), + [outptr0] "+r"(outptr0), [za] "+r"(za), [zb] "+r"(zb), + [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), + [outptr3] "=r"(outptr3), [x0] "+r"(x0), [m_remain] "+r"(m_remain) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "cc", "memory"); + +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +} + +/** + * Overview of register layout: + * + * A 8x4x8 cell of Rhs is stored in 8bit in q8-q9 + * A 8x8x4 cell of Lhs is stored in 8bit in q0-q3 + * A 4x4 block of accumulators is stored in 32bit in q4-q7 + * zero_point_A is stored in 8bit in q10 + * zero_point_B is stored in 8bit in q11. 
+ * + * +--------+ + * | v8[0-4]| + * Rhs +--------+ + * | v9[0-4]| + * Lhs +--------+ + * + * +--------+ - - - - +--------+ + * |v0[0-8]| | v4[0-4]| + * |v1[0-8]| | v5[0-4]| + * |v2[0-8]| | v6[0-4]| + * |v3[0-8]| | v7[0-4]| + * +--------+ - - - - +--------+ + * + * Accumulator + */ + +static void kern_4x4(const uint8_t* packA, const uint8_t* packB, int K, + int32_t* output, int LDC, bool is_first_k, size_t m_remain, + size_t n_remain, uint8_t za, uint8_t zb) { + K /= 8; + const uint8_t* a_ptr = packA; + const uint8_t* b_ptr = packB; + + LDC = LDC * sizeof(int32_t); + int32_t* outptr0 = output; + int32_t* outptr1; + int32_t* outptr2; + int32_t* outptr3; + size_t x0 = 0; + size_t x1 = 0; + +// clang-format off +#define LOAD_LINE(reg_index, n) \ + "cbz %[x1], 102f\n" \ + "mov %[x0], %[outptr" n "]\n" \ + "cmp %w[n_remain], #4\n" \ + "blt 100" n "f\n" \ + "ldr q" reg_index ", [%[x0]]\n" \ + "b 101" n "f\n" \ + "100" n ":\n" \ + "cmp %w[n_remain], #0\n" \ + "beq 101" n "f\n" \ + "ld1 {v" reg_index ".s}[0], [%[x0]], #4\n" \ + "cmp %w[n_remain], #1\n" \ + "beq 101" n "f\n" \ + "ld1 {v" reg_index ".s}[1], [%[x0]], #4\n" \ + "cmp %w[n_remain], #2\n" \ + "beq 101" n "f\n" \ + "ld1 {v" reg_index ".s}[2], [%[x0]], #4\n" \ + "101" n ":\n" \ + "subs %[x1], %[x1], #1\n" + +#define LOAD_C \ + "mov %[x1], %x[m_remain]\n" \ + LOAD_LINE("4", "0") \ + LOAD_LINE("5", "1") \ + LOAD_LINE("6", "2") \ + LOAD_LINE("7", "3") \ + "102:\n" + +#define STORE_LINE(reg_index, n) \ + "cbz %[x1], 105f\n" \ + "mov %[x0], %[outptr" n "]\n" \ + "cmp %w[n_remain], #4\n" \ + "blt 103" n "f\n" \ + "str q" reg_index ", [%[x0]]\n" \ + "b 104" n "f\n" \ + "103" n ":\n" \ + "cmp %w[n_remain], #0\n" \ + "beq 104" n "f\n" \ + "st1 {v" reg_index ".s}[0], [%[x0]], #4\n" \ + "cmp %w[n_remain], #1\n" \ + "beq 104" n "f\n" \ + "st1 {v" reg_index ".s}[1], [%[x0]], #4\n" \ + "cmp %w[n_remain], #2\n" \ + "beq 104" n "f\n" \ + "st1 {v" reg_index ".s}[2], [%[x0]], #4\n" \ + "104" n ":\n" \ + "subs %[x1], %[x1], #1\n" + +#define STORE_C \ + "mov %[x1], %x[m_remain]\n" \ + STORE_LINE("4", "0") \ + STORE_LINE("5", "1") \ + STORE_LINE("6", "2") \ + STORE_LINE("7", "3") \ + "105:\n" + + // clang-format on + + asm volatile( + // load accumulator C + "add %[outptr1], %[outptr0], %x[LDC]\n" + "add %[outptr2], %[outptr1], %x[LDC]\n" + "add %[outptr3], %[outptr2], %x[LDC]\n" + "dup v10.8b, %w[za]\n" + "dup v11.8b, %w[zb]\n" + "cmp %w[is_first_k], #1\n" + "beq 1f\n" LOAD_C + + "b 2f\n" + + "1:\n" + "eor v4.16b, v4.16b, v4.16b\n" + "eor v5.16b, v5.16b, v5.16b\n" + "eor v6.16b, v6.16b, v6.16b\n" + "eor v7.16b, v7.16b, v7.16b\n" + + "2: \n" + "ld1 {v8.s}[0], [%[b_ptr]], 4\n" + "ld1 {v0.8b}, [%[a_ptr]], 8\n" + "ld1 {v1.8b}, [%[a_ptr]], 8\n" + "ld1 {v2.8b}, [%[a_ptr]], 8\n" + "ld1 {v3.8b}, [%[a_ptr]], 8\n" + "usubl v8.8h, v8.8b, v11.8b\n" + "usubl v0.8h, v0.8b, v10.8b\n" + "usubl v1.8h, v1.8b, v10.8b\n" + "usubl v2.8h, v2.8b, v10.8b\n" + "usubl v3.8h, v3.8b, v10.8b\n" + + "ld1 {v9.s}[0], [%[b_ptr]], 4\n" + "smlal v4.4s, v8.4h, v0.h[0]\n" + "smlal v5.4s, v8.4h, v1.h[0]\n" + "usubl v9.8h, v9.8b, v11.8b\n" + "smlal v6.4s, v8.4h, v2.h[0]\n" + "smlal v7.4s, v8.4h, v3.h[0]\n" + + "ld1 {v8.s}[0], [%[b_ptr]], 4\n" + "smlal v4.4s, v9.4h, v0.h[1]\n" + "smlal v5.4s, v9.4h, v1.h[1]\n" + "usubl v8.8h, v8.8b, v11.8b\n" + "smlal v6.4s, v9.4h, v2.h[1]\n" + "smlal v7.4s, v9.4h, v3.h[1]\n" + + "ld1 {v9.s}[0], [%[b_ptr]], 4\n" + "smlal v4.4s, v8.4h, v0.h[2]\n" + "smlal v5.4s, v8.4h, v1.h[2]\n" + "usubl v9.8h, v9.8b, v11.8b\n" + "smlal v6.4s, v8.4h, v2.h[2]\n" + "smlal v7.4s, 
v8.4h, v3.h[2]\n" + + "ld1 {v8.s}[0], [%[b_ptr]], 4\n" + "smlal v4.4s, v9.4h, v0.h[3]\n" + "smlal v5.4s, v9.4h, v1.h[3]\n" + "usubl v8.8h, v8.8b, v11.8b\n" + "smlal v6.4s, v9.4h, v2.h[3]\n" + "smlal v7.4s, v9.4h, v3.h[3]\n" + + "ld1 {v9.s}[0], [%[b_ptr]], 4\n" + "smlal v4.4s, v8.4h, v0.h[4]\n" + "smlal v5.4s, v8.4h, v1.h[4]\n" + "usubl v9.8h, v9.8b, v11.8b\n" + "smlal v6.4s, v8.4h, v2.h[4]\n" + "smlal v7.4s, v8.4h, v3.h[4]\n" + + "ld1 {v8.s}[0], [%[b_ptr]], 4\n" + "smlal v4.4s, v9.4h, v0.h[5]\n" + "smlal v5.4s, v9.4h, v1.h[5]\n" + "usubl v8.8h, v8.8b, v11.8b\n" + "smlal v6.4s, v9.4h, v2.h[5]\n" + "smlal v7.4s, v9.4h, v3.h[5]\n" + + "ld1 {v9.s}[0], [%[b_ptr]], 4\n" + "smlal v4.4s, v8.4h, v0.h[6]\n" + "smlal v5.4s, v8.4h, v1.h[6]\n" + "usubl v9.8h, v9.8b, v11.8b\n" + "smlal v6.4s, v8.4h, v2.h[6]\n" + "smlal v7.4s, v8.4h, v3.h[6]\n" + + "smlal v4.4s, v9.4h, v0.h[7]\n" + "smlal v5.4s, v9.4h, v1.h[7]\n" + "smlal v6.4s, v9.4h, v2.h[7]\n" + "smlal v7.4s, v9.4h, v3.h[7]\n" + + "subs %w[K], %w[K], #1\n" + "cbnz %w[K], 2b\n" + + "3:\n" STORE_C + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), + [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), + [outptr0] "+r"(outptr0), [za] "+r"(za), [zb] "+r"(zb), + [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), + [outptr3] "=r"(outptr3), [x0] "+r"(x0), [x1] "+r"(x1), + [m_remain] "+r"(m_remain), [n_remain] "+r"(n_remain) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "cc", "memory"); + +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +} + +static void gemm_u8_8x8_pack_A_n(dt_uint8* outptr, const dt_uint8* inptr, + int ldin, int y0, int ymax, int k0, int kmax, + uint8_t zero_point) { + uint8_t zerobuff[16]; + std::fill(zerobuff, zerobuff + 16, zero_point); + + int y = y0; + for (; y + 7 < ymax; y += 8) { + const uint8_t* inptr0 = inptr + y * ldin + k0; + const uint8_t* inptr1 = inptr0 + ldin; + const uint8_t* inptr2 = inptr1 + ldin; + const uint8_t* inptr3 = inptr2 + ldin; + const uint8_t* inptr4 = inptr3 + ldin; + const uint8_t* inptr5 = inptr4 + ldin; + const uint8_t* inptr6 = inptr5 + ldin; + const uint8_t* inptr7 = inptr6 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + prefetch_2x(inptr4); + prefetch_2x(inptr5); + prefetch_2x(inptr6); + prefetch_2x(inptr7); + + int K = kmax - k0; + for (; K > 15; K -= 16) { + interleave_8x8_2_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, + inptr6, inptr7, outptr); + } + + if (K > 0) { + interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, + inptr7, outptr, 8, K, zero_point); + } + } + + for (; y < ymax; y += 4) { + const uint8_t* inptr0 = inptr + y * ldin + k0; + const uint8_t* inptr1 = inptr0 + ldin; + const uint8_t* inptr2 = inptr1 + ldin; + const uint8_t* inptr3 = inptr2 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + + int K = kmax - k0; + for (; K > 15; K -= 16) { + if (y + 3 >= ymax) { + switch (y + 3 - ymax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + interleave_4x8_2_b(inptr0, inptr1, inptr2, inptr3, outptr); + } + + if (K > 0) { + if (y + 3 >= ymax) { + switch (y + 3 - ymax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 8, K, + zero_point); + } + } +} + +static void 
gemm_u8_8x8_transpose_pack_A_n(dt_uint8* out, const dt_uint8* in, + int ldin, int x0, int xmax, int k0, + int kmax, uint8_t zero_point) { + uint8_t zerobuff[16]; + std::fill(zerobuff, zerobuff + 16, zero_point); + const int ksize = kmax - k0; + const int ksize4 = round_up(ksize, 8) * 4; + const int ksize8 = ksize4 * 2; + uint8_t* outptr = out; + uint8_t* outptr_base = out; + //! 4x4 block output start pos + uint8_t* outptr_base4 = out + ((xmax - x0) / 8) * ksize8; + + int k = k0; + for (; k < kmax; k += 8) { + const uint8_t* inptr0 = in + k * ldin + x0; + const uint8_t* inptr1 = inptr0 + ldin; + const uint8_t* inptr2 = inptr1 + ldin; + const uint8_t* inptr3 = inptr2 + ldin; + const uint8_t* inptr4 = inptr3 + ldin; + const uint8_t* inptr5 = inptr4 + ldin; + const uint8_t* inptr6 = inptr5 + ldin; + const uint8_t* inptr7 = inptr6 + ldin; + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + prefetch_2x(inptr4); + prefetch_2x(inptr5); + prefetch_2x(inptr6); + prefetch_2x(inptr7); + + int x = x0; + outptr = outptr_base; + + for (; x + 7 < xmax; x += 8) { + if (k + 7 >= kmax) { + switch (k + 7 - kmax) { + case 6: + inptr1 = zerobuff; + case 5: + inptr2 = zerobuff; + case 4: + inptr3 = zerobuff; + case 3: + inptr4 = zerobuff; + case 2: + inptr5 = zerobuff; + case 1: + inptr6 = zerobuff; + case 0: + inptr7 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + transpose_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, + inptr6, inptr7, outptr); + outptr += ksize8; + } + + outptr = outptr_base4; + for (; x + 3 < xmax; x += 4) { + if (k + 7 >= kmax) { + switch (k + 7 - kmax) { + case 6: + inptr1 = zerobuff; + case 5: + inptr2 = zerobuff; + case 4: + inptr3 = zerobuff; + case 3: + inptr4 = zerobuff; + case 2: + inptr5 = zerobuff; + case 1: + inptr6 = zerobuff; + case 0: + inptr7 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, + inptr7, outptr, 4, 4, zero_point); + outptr += ksize4; + } + + if (x < xmax) { + if (k + 7 >= kmax) { + switch (k + 7 - kmax) { + case 6: + inptr1 = zerobuff; + case 5: + inptr2 = zerobuff; + case 4: + inptr3 = zerobuff; + case 3: + inptr4 = zerobuff; + case 2: + inptr5 = zerobuff; + case 1: + inptr6 = zerobuff; + case 0: + inptr7 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, + inptr7, outptr, 4, xmax - x, zero_point); + } + + outptr_base += 8 * 8; + outptr_base4 += 4 * 8; + } +} + +static void gemm_u8_8x8_pack_B_n(dt_uint8* out, const dt_uint8* in, int ldin, + int x0, int xmax, int k0, int kmax, + uint8_t zero_point) { + uint8_t zerobuff[16]; + std::fill(zerobuff, zerobuff + 16, zero_point); + const int ksize = kmax - k0; + const int ksize4 = round_up(ksize, 8) * 4; + const int ksize8 = ksize4 * 2; + uint8_t* outptr = out; + uint8_t* outptr_base = out; + uint8_t* outptr_interleave = nullptr; + //! 
4x4 block output start pos + uint8_t* outptr_base4 = out + ((xmax - x0) / 8) * ksize8; + + int k = k0; + for (; k < kmax; k += 8) { + const uint8_t* inptr0 = in + k * ldin + x0; + const uint8_t* inptr1 = inptr0 + ldin; + const uint8_t* inptr2 = inptr1 + ldin; + const uint8_t* inptr3 = inptr2 + ldin; + const uint8_t* inptr4 = inptr3 + ldin; + const uint8_t* inptr5 = inptr4 + ldin; + const uint8_t* inptr6 = inptr5 + ldin; + const uint8_t* inptr7 = inptr6 + ldin; + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + prefetch_2x(inptr4); + prefetch_2x(inptr5); + prefetch_2x(inptr6); + prefetch_2x(inptr7); + + int x = x0; + outptr = outptr_base; + + for (; x + 7 < xmax; x += 8) { + if (k + 7 >= kmax) { + switch (k + 7 - kmax) { + case 6: + inptr1 = zerobuff; + case 5: + inptr2 = zerobuff; + case 4: + inptr3 = zerobuff; + case 3: + inptr4 = zerobuff; + case 2: + inptr5 = zerobuff; + case 1: + inptr6 = zerobuff; + case 0: + inptr7 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + outptr_interleave = outptr; + interleave_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, + inptr6, inptr7, outptr_interleave); + outptr += ksize8; + } + + outptr = outptr_base4; + for (; x + 3 < xmax; x += 4) { + if (k + 7 >= kmax) { + switch (k + 7 - kmax) { + case 6: + inptr1 = zerobuff; + case 5: + inptr2 = zerobuff; + case 4: + inptr3 = zerobuff; + case 3: + inptr4 = zerobuff; + case 2: + inptr5 = zerobuff; + case 1: + inptr6 = zerobuff; + case 0: + inptr7 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + outptr_interleave = outptr; + interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, + inptr7, outptr_interleave, 4, 4, zero_point); + outptr += ksize4; + } + + if (x < xmax) { + if (k + 7 >= kmax) { + switch (k + 7 - kmax) { + case 6: + inptr1 = zerobuff; + case 5: + inptr2 = zerobuff; + case 4: + inptr3 = zerobuff; + case 3: + inptr4 = zerobuff; + case 2: + inptr5 = zerobuff; + case 1: + inptr6 = zerobuff; + case 0: + inptr7 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + outptr_interleave = outptr; + interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, + inptr7, outptr_interleave, 4, xmax - x, zero_point); + } + + outptr_base += 8 * 8; + outptr_base4 += 4 * 8; + } +} + +static void gemm_u8_8x8_transpose_pack_B_n(dt_uint8* outptr, + const dt_uint8* inptr, int ldin, + int y0, int ymax, int k0, int kmax, + uint8_t zero_point) { + uint8_t zerobuff[16]; + std::fill(zerobuff, zerobuff + 16, zero_point); + constexpr int interleave4 = 32; + constexpr int interleave8 = 64; + + int y = y0; + for (; y + 7 < ymax; y += 8) { + const uint8_t* inptr0 = inptr + y * ldin + k0; + const uint8_t* inptr1 = inptr0 + ldin; + const uint8_t* inptr2 = inptr1 + ldin; + const uint8_t* inptr3 = inptr2 + ldin; + const uint8_t* inptr4 = inptr3 + ldin; + const uint8_t* inptr5 = inptr4 + ldin; + const uint8_t* inptr6 = inptr5 + ldin; + const uint8_t* inptr7 = inptr6 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + prefetch_2x(inptr4); + prefetch_2x(inptr5); + prefetch_2x(inptr6); + prefetch_2x(inptr7); + + int K = kmax - k0; + for (; K > 7; K -= 8) { + transpose_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, + inptr6, inptr7, outptr); + outptr += interleave8; + } + + if (K > 0) { + transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, + inptr7, outptr, 8, K, zero_point); + outptr += interleave8; + } + } + + for (; y < ymax; y += 4) { + const uint8_t* 
inptr0 = inptr + y * ldin + k0; + const uint8_t* inptr1 = inptr0 + ldin; + const uint8_t* inptr2 = inptr1 + ldin; + const uint8_t* inptr3 = inptr2 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + + int K = kmax - k0; + for (; K > 7; K -= 8) { + if (y + 3 >= ymax) { + switch (y + 3 - ymax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + transpose_8x4_1_b(inptr0, inptr1, inptr2, inptr3, outptr); + outptr += interleave4; + } + + if (K > 0) { + if (y + 3 >= ymax) { + switch (y + 3 - ymax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + transpose_4(inptr0, inptr1, inptr2, inptr3, outptr, 8, K, + zero_point); + outptr += interleave4; + } + } +} + +} // namespace matmul_8x8x8 +} // namespace aarch64 +} // namespace megdnn + +// vim: syntax=cpp.doxygen +#endif diff --git a/dnn/src/aarch64/matrix_mul/quint8/strategy.cpp b/dnn/src/aarch64/matrix_mul/quint8/strategy.cpp new file mode 100644 index 00000000..96079e6d --- /dev/null +++ b/dnn/src/aarch64/matrix_mul/quint8/strategy.cpp @@ -0,0 +1,113 @@ +/** + * \file dnn/src/aarch64/matrix_mul/quint8/strategy.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#if !(__ARM_FEATURE_DOTPROD) +#include "src/aarch64/matrix_mul/quint8/strategy.h" +#include "src/aarch64/matrix_mul/asm/common.h" +#include "src/aarch64/matrix_mul/quint8/kernel_8x8x8.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/utils.h" + +using namespace megdnn; +using namespace aarch64; +using namespace aarch64::matmul; + +MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8); + +void gemm_u8_8x8::pack_A(dt_uint8* outptr, const dt_uint8* inptr, int ldin, + int y0, int ymax, int k0, int kmax, + bool transpose) const { + uint8_t zA = A_dtype.param().zero_point; + if (transpose) { + matmul_8x8x8::gemm_u8_8x8_transpose_pack_A_n(outptr, inptr, ldin, y0, + ymax, k0, kmax, zA); + } else { + matmul_8x8x8::gemm_u8_8x8_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, + kmax, zA); + } +} + +void gemm_u8_8x8::pack_B(dt_uint8* out, const dt_uint8* in, int ldin, int x0, + int xmax, int k0, int kmax, bool transpose) const { + uint8_t zB = B_dtype.param().zero_point; + if (transpose) { + matmul_8x8x8::gemm_u8_8x8_transpose_pack_B_n(out, in, ldin, x0, xmax, + k0, kmax, zB); + } else { + matmul_8x8x8::gemm_u8_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax, + zB); + } +} + +void gemm_u8_8x8::kern(const dt_uint8* packA, const dt_uint8* packB, size_t M, + size_t N, size_t K, dt_int32* C, size_t LDC, + bool is_first_k, const dt_int32*, dt_int32*) const { + megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && + A_dtype.enumv() == DTypeEnum::Quantized8Asymm && + C_dtype.enumv() == DTypeEnum::QuantizedS32, + "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), + C_dtype.name()); + uint8_t zA = A_dtype.param().zero_point; + uint8_t zB = B_dtype.param().zero_point; + + constexpr size_t A_INTERLEAVE = 8; + constexpr size_t B_INTERLEAVE = 8; + //! 
K is packed to times of 8 + K = round_up(K, 8); + const int K8 = K * 8; + const int K4 = K * 4; + + size_t m = 0; + for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) { + int32_t* output = C + (m * LDC); + + size_t n = 0; + const dt_uint8* cur_packB = packB; + for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { + matmul_8x8x8::kern_8x8(packA, cur_packB, K, output, LDC, is_first_k, + zA, zB); + output += B_INTERLEAVE; + cur_packB += K8; + } + + for (; n < N; n += 4) { + matmul_8x8x8::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k, + std::min(N - n, 4), zA, zB); + output += 4; + cur_packB += K4; + } + packA += K8; + } + + for (; m < M; m += 4) { + int32_t* output = C + (m * LDC); + const dt_uint8* cur_packB = packB; + size_t n = 0; + for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { + matmul_8x8x8::kern_4x8(packA, cur_packB, K, output, LDC, is_first_k, + std::min(M - m, 4), zA, zB); + output += B_INTERLEAVE; + cur_packB += K8; + } + + for (; n < N; n += 4) { + matmul_8x8x8::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k, + std::min(M - m, 4), + std::min(N - n, 4), zA, zB); + output += 4; + cur_packB += K4; + } + packA += K4; + } +} +#endif + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/quint8/strategy.h b/dnn/src/aarch64/matrix_mul/quint8/strategy.h new file mode 100644 index 00000000..e76830dd --- /dev/null +++ b/dnn/src/aarch64/matrix_mul/quint8/strategy.h @@ -0,0 +1,28 @@ +/** + * \file dnn/src/aarch64/matrix_mul/quint8/strategy.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once + +#if !(__ARM_FEATURE_DOTPROD) +#include "src/fallback/matrix_mul/gemm_common.h" + +namespace megdnn { +namespace aarch64 { +namespace matmul { + +MEGDNN_REG_GEMM_STRATEGY(dt_uint8, dt_int32, dt_int32, 8, 8, 8, false, true, + gemm_u8_8x8); + +} // namespace matmul +} // namespace aarch64 +} // namespace megdnn +#endif + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/quint8_dot/gemv.cpp b/dnn/src/aarch64/matrix_mul/quint8_dot/gemv.cpp new file mode 100644 index 00000000..0497f9fb --- /dev/null +++ b/dnn/src/aarch64/matrix_mul/quint8_dot/gemv.cpp @@ -0,0 +1,177 @@ +/** + * \file dnn/src/aarch64/matrix_mul/quint8_dot/gemv.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#include "src/aarch64/matrix_mul/quint8_dot/gemv.h" +#include +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/utils.h" +#include "src/common/unroll_macro.h" + +#if __ARM_FEATURE_DOTPROD + +namespace { + +void gemv_naive_n(const uint8_t* __restrict A, const uint8_t* __restrict B, + int32_t* __restrict C, size_t M, size_t N, size_t K, + size_t Astride, size_t Bstride, size_t Cstride, + uint8_t zero_point_A, uint8_t zero_point_B) { + int32_t zAB = static_cast(zero_point_A) * + static_cast(zero_point_B) * K; + uint8x16_t zAq = vdupq_n_u8(zero_point_A); + uint8x16_t zBq = vdupq_n_u8(zero_point_B); + uint8x8_t zA = vdup_n_u8(zero_point_A); + uint8x8_t zB = vdup_n_u8(zero_point_B); + megdnn_assert(N == 1 && Bstride == 1); + size_t m = 0; + for (; m + 2 <= M; m += 2) { + int32_t acc_zA, acc_zB, acc_zB2; + int32_t acc[4]; + size_t k = 0; + uint32x4_t acc_neon = vdupq_n_u32(0); + { + uint32x4_t acc_zA_neon = vdupq_n_u32(0); + uint32x4_t acc_zB_neon = vdupq_n_u32(0); + uint32x4_t acc_zB2_neon = vdupq_n_u32(0); + for (; k + 16 <= K; k += 16) { + uint8x16_t elem = vld1q_u8(A + m * Astride + k); + acc_zB_neon = vdotq_u32(acc_zB_neon, zBq, elem); + uint64x2_t a0 = vreinterpretq_u64_u8(elem); + elem = vld1q_u8(A + (m + 1) * Astride + k); + acc_zB2_neon = vdotq_u32(acc_zB2_neon, zBq, elem); + uint64x2_t a1 = vreinterpretq_u64_u8(elem); + //! the first 8 elements is m, the last 8 elements is m + 1 + uint8x16_t a2 = vreinterpretq_u8_u64(vzip1q_u64(a0, a1)); + uint8x16_t a3 = vreinterpretq_u8_u64(vzip2q_u64(a0, a1)); + + elem = vld1q_u8(B + k); + acc_zA_neon = vdotq_u32(acc_zA_neon, zAq, elem); + uint64x2_t b0 = vreinterpretq_u64_u8(elem); + uint8x16_t b2 = vreinterpretq_u8_u64(vzip1q_u64(b0, b0)); + uint8x16_t b3 = vreinterpretq_u8_u64(vzip2q_u64(b0, b0)); + + acc_neon = vdotq_u32(acc_neon, a2, b2); + acc_neon = vdotq_u32(acc_neon, a3, b3); + } + vst1q_s32(acc, vreinterpretq_s32_u32(acc_neon)); + acc_zA = vaddvq_u32(acc_zA_neon); + acc_zB = vaddvq_u32(acc_zB_neon); + acc_zB2 = vaddvq_u32(acc_zB2_neon); + } + + { + uint32x2_t acc_zA_neon = vdup_n_u32(0); + uint32x2_t acc_zB_neon = vdup_n_u32(0); + uint32x2_t acc_zB2_neon = vdup_n_u32(0); + for (; k + 8 <= K; k += 8) { + uint8x8_t a0 = vld1_u8(A + m * Astride + k); + uint8x8_t a1 = vld1_u8(A + (m + 1) * Astride + k); + uint8x8_t b0 = vld1_u8(B + k); + uint32x2_t zero = vdup_n_u32(0); + acc[0] += vaddv_u32(vdot_u32(zero, a0, b0)); + zero = vdup_n_u32(0); + acc[3] += vaddv_u32(vdot_u32(zero, a1, b0)); + + acc_zB_neon = vdot_u32(acc_zB_neon, a0, zB); + acc_zB2_neon = vdot_u32(acc_zB2_neon, a1, zB); + acc_zA_neon = vdot_u32(acc_zA_neon, b0, zA); + } + + acc_zA += vaddv_u32(acc_zA_neon); + acc_zB += vaddv_u32(acc_zB_neon); + acc_zB2 += vaddv_u32(acc_zB2_neon); + } + + for (; k < K; ++k) { + acc[0] += static_cast(A[m * Astride + k]) * B[k]; + acc[3] += static_cast(A[(m + 1) * Astride + k]) * B[k]; + acc_zA += static_cast(B[k]) * zero_point_A; + acc_zB += static_cast(A[m * Astride + k]) * zero_point_B; + acc_zB2 += static_cast(A[(m + 1) * Astride + k]) * + zero_point_B; + } + C[m * Cstride] = acc[0] + acc[1] + zAB - acc_zA - acc_zB; + C[(m + 1) * Cstride] = acc[2] + acc[3] + zAB - acc_zA - acc_zB2; + } + + for (; m < M; ++m) { + int32_t acc[4]; + int32_t acc_zA, acc_zB; + uint32x4_t acc_neon = vdupq_n_u32(0); + size_t k = 0; + { + uint32x4_t acc_zA_neon = vdupq_n_u32(0); + uint32x4_t acc_zB_neon = vdupq_n_u32(0); + for (; k + 16 <= K; k += 16) { + uint8x16_t a0 = vld1q_u8(A + m * Astride + k); + uint8x16_t b0 = vld1q_u8(B + k); + 
acc_neon = vdotq_u32(acc_neon, a0, b0); + acc_zB_neon = vdotq_u32(acc_zB_neon, zBq, a0); + acc_zA_neon = vdotq_u32(acc_zA_neon, zAq, b0); + } + vst1q_s32(acc, vreinterpretq_s32_u32(acc_neon)); + acc_zA = vaddvq_u32(acc_zA_neon); + acc_zB = vaddvq_u32(acc_zB_neon); + } + + { + uint32x2_t acc_zA_neon = vdup_n_u32(0); + uint32x2_t acc_zB_neon = vdup_n_u32(0); + for (; k + 8 <= K; k += 8) { + uint8x8_t a0 = vld1_u8(A + m * Astride + k); + uint8x8_t b0 = vld1_u8(B + k); + uint32x2_t zero = vdup_n_u32(0); + acc[0] += vaddv_u32(vdot_u32(zero, a0, b0)); + + acc_zB_neon = vdot_u32(acc_zB_neon, a0, zB); + acc_zA_neon = vdot_u32(acc_zA_neon, b0, zA); + } + acc_zA += vaddv_u32(acc_zA_neon); + acc_zB += vaddv_u32(acc_zB_neon); + } + + for (; k < K; ++k) { + acc[0] += static_cast(A[m * Astride + k]) * B[k]; + acc_zA += static_cast(B[k]) * zero_point_A; + acc_zB += static_cast(A[m * Astride + k]) * zero_point_B; + } + C[m * Cstride] = + acc[0] + acc[1] + acc[2] + acc[3] + zAB - acc_zA - acc_zB; + } +} + +} // namespace + +bool megdnn::aarch64::matmul::is_gemv_like_preferred_quint8( + bool transposeA, bool transposeB, size_t M, size_t N, size_t K, + size_t /* LDA */, size_t LDB, size_t /* LDC */) { + if (transposeA) + return false; + if (transposeB) + return false; + MEGDNN_MARK_USED_VAR(K); + MEGDNN_MARK_USED_VAR(M); + //! rebenchmark gemv in sdm855 + return (N == 1 && LDB == 1); +} + +void megdnn::aarch64::matmul::gemv_like_quint8( + const uint8_t* __restrict A, const uint8_t* __restrict B, + int32_t* __restrict C, size_t M, size_t N, size_t K, size_t Astride, + size_t Bstride, size_t Cstride, uint8_t zero_point_A, + uint8_t zero_point_B) { + megdnn_assert(N == 1); + return gemv_naive_n(A, B, C, M, N, K, Astride, Bstride, Cstride, + zero_point_A, zero_point_B); +} + +#endif + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/quint8_dot/gemv.h b/dnn/src/aarch64/matrix_mul/quint8_dot/gemv.h new file mode 100644 index 00000000..e7d8e85b --- /dev/null +++ b/dnn/src/aarch64/matrix_mul/quint8_dot/gemv.h @@ -0,0 +1,35 @@ +/** + * \file dnn/src/aarch64/matrix_mul/quint8_dot/gemv.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once + +#include +#include + +#if __ARM_FEATURE_DOTPROD +namespace megdnn { +namespace aarch64 { +namespace matmul { + +bool is_gemv_like_preferred_quint8(bool transposeA, bool transposeB, size_t M, + size_t N, size_t K, size_t LDA, size_t LDB, + size_t LDC); + +void gemv_like_quint8(const uint8_t* __restrict A, const uint8_t* __restrict B, + int32_t* __restrict C, size_t M, size_t N, size_t K, + size_t Astride, size_t Bstride, size_t Cstride, + uint8_t zero_point_A, uint8_t zero_point_B); + +} // namespace matmul +} // namespace aarch64 +} // namespace megdnn +#endif + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/quint8_dot/kernel_8x8x4.h b/dnn/src/aarch64/matrix_mul/quint8_dot/kernel_8x8x4.h new file mode 100644 index 00000000..3cb1c1b8 --- /dev/null +++ b/dnn/src/aarch64/matrix_mul/quint8_dot/kernel_8x8x4.h @@ -0,0 +1,1092 @@ +/** + * \file dnn/src/aarch64/matrix_mul/quint8_dot/kernel_8x8x4.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. 
All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#if __ARM_FEATURE_DOTPROD + +#include "src/aarch64/matrix_mul/asm/common.h" +#include "src/arm_common/simd_macro/marm_neon.h" + +namespace megdnn { +namespace aarch64 { +namespace matmul_8x8x4 { + +//! calc v0 = v0 - v1[lane1] - v2 +#define SUB_LANE(v0, v1, lane1, v2, vtmp) \ + "dup v" #vtmp ".4s, v" #v1 ".s[" #lane1 \ + "]\n" \ + "sub v" #v0 ".4s, v" #v0 ".4s, v" #vtmp \ + ".4s\n" \ + "sub v" #v0 ".4s, v" #v0 ".4s, v" #v2 ".4s\n" + +// Overview of register layout: +// +// A 8x4 cell of Rhs is stored in 8bit in q2-q3. +// A 8x4x2 cell of Lhs is stored in 8bit in q0-q1,q4-q5 +// A 8x12 block of accumulators is stored in 32bit in q6--q21. +// +// +--------+--------+ +// |v2[0-16]|v3[0-16]| +// Rhs +--------+--------+ +// +// | | | +// +// Lhs | | | +// +// +-------+-------+ - - - - +--------+--------+ +// |v0[0-4]|v4[0-4]| | v6[0-4]|v14[0-4]| +// |v0[0-4]|v4[0-4]| | v7[0-4]|v15[0-4]| +// |v0[0-4]|v4[0-4]| | v8[0-4]|v16[0-4]| +// |v0[0-4]|v4[0-4]| | v9[0-4]|v17[0-4]| +// |v1[0-4]|v5[0-4]| |v10[0-4]|v18[0-4]| +// |v1[0-4]|v5[0-4]| |v11[0-4]|v19[0-4]| +// |v1[0-4]|v5[0-4]| |v12[0-4]|v20[0-4]| +// |v1[0-4]|v5[0-4]| |v13[0-4]|v21[0-4]| +// +-------+-------+ - - - - +--------+--------+ +// +// Accumulator +// +// C = sum((A - zA) * (B - zB)) = sum(A * B) - sum(A) * zB - sum(B) * zA + zA * +// zB * k +// A -> v27, v28 | B -> v29, v30 | zA * zB * k -> v26 + +static void kern_8x8(const uint8_t* packA, const uint8_t* packB, int K, + int32_t* output, int LDC, bool is_first_k, + uint8_t zero_point_A, uint8_t zero_point_B, uint32_t zAB) { + K /= 4; + const uint8_t* a_ptr = packA; + const uint8_t* b_ptr = packB; + // Fix up for odd lengths - set a flag if K is odd, but make + // sure we round up the iteration count. 
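+    // A scalar reference model of the zero-point handling (for illustration
+    // only, not part of the kernel): for one output element over K inputs,
+    //     int32_t ref = 0;
+    //     for (int i = 0; i < K; ++i)
+    //         ref += (int32_t(a[i]) - zA) * (int32_t(b[i]) - zB);
+    // which expands to sum(a*b) - zB*sum(a) - zA*sum(b) + zA*zB*K.  Below,
+    // v6-v21 accumulate sum(a*b), v27/v28 collect zB*sum(a) per row and
+    // v29/v30 collect zA*sum(b) per column (via the broadcast registers
+    // v24/v25), and the SUB_LANE epilogue applies the correction together
+    // with v26 = zA*zB*K.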
+ int oddk = (K & 1); + int k = K / 2; + + int32x4_t a0; + int32x4_t a1; + int32x4_t b0; + int32x4_t b1; + int32x4_t a0a; + int32x4_t a1a; + LDC = LDC * sizeof(int32_t); + + int32_t* outptr0 = output; + int32_t* outptr1; + int32_t* outptr2; + int32_t* outptr3; + int32_t* outptr4; + int32_t* outptr5; + int32_t* outptr6; + int32_t* outptr7; + + asm volatile( + // load accumulator C + "add %[outptr1], %[outptr0], %x[LDC]\n" + "add %[outptr2], %[outptr1], %x[LDC]\n" + "add %[outptr3], %[outptr2], %x[LDC]\n" + "add %[outptr4], %[outptr3], %x[LDC]\n" + "add %[outptr5], %[outptr4], %x[LDC]\n" + "add %[outptr6], %[outptr5], %x[LDC]\n" + "add %[outptr7], %[outptr6], %x[LDC]\n" + "dup v24.16b, %w[zero_point_B] \n" + "dup v25.16b, %w[zero_point_A] \n" + "dup v26.4s, %w[zAB] \n" + "cmp %w[is_first_k], #1\n" + "beq 1f\n" + + "ldp q6, q14, [%[outptr0]]\n" + "ldp q7, q15, [%[outptr1]]\n" + "ldp q8, q16, [%[outptr2]]\n" + "ldp q9, q17, [%[outptr3]]\n" + "ldp q10, q18, [%[outptr4]]\n" + "ldp q11, q19, [%[outptr5]]\n" + "ldp q12, q20, [%[outptr6]]\n" + "ldp q13, q21, [%[outptr7]]\n" + "b 2f\n" + + "1:\n" + "eor v6.16b, v6.16b, v6.16b\n" + "eor v7.16b, v7.16b, v7.16b\n" + "eor v8.16b, v8.16b, v8.16b\n" + "eor v9.16b, v9.16b, v9.16b\n" + "eor v10.16b, v10.16b, v10.16b\n" + "eor v11.16b, v11.16b, v11.16b\n" + "eor v12.16b, v12.16b, v12.16b\n" + "eor v13.16b, v13.16b, v13.16b\n" + "eor v14.16b, v14.16b, v14.16b\n" + "eor v15.16b, v15.16b, v15.16b\n" + "eor v16.16b, v16.16b, v16.16b\n" + "eor v17.16b, v17.16b, v17.16b\n" + "eor v18.16b, v18.16b, v18.16b\n" + "eor v19.16b, v19.16b, v19.16b\n" + "eor v20.16b, v20.16b, v20.16b\n" + "eor v21.16b, v21.16b, v21.16b\n" + "eor v27.16b, v27.16b, v27.16b\n" + "eor v28.16b, v28.16b, v28.16b\n" + "eor v29.16b, v29.16b, v29.16b\n" + "eor v30.16b, v30.16b, v30.16b\n" + + "2: \n" + "cbz %w[oddk], 3f\n" + // parse the oddk + "ldr %q[a0], [%[a_ptr]], #16\n" + "ldr %q[a1], [%[a_ptr]], #16\n" + "ldr %q[b0], [%[b_ptr]], #16\n" + "ldr %q[b1], [%[b_ptr]], #16\n" + "udot v27.4s, %[a0].16b, v24.16b\n" + "udot v28.4s, %[a1].16b, v24.16b\n" + "udot v29.4s, %[b0].16b, v25.16b\n" + "udot v30.4s, %[b1].16b, v25.16b\n" + "udot v6.4s, %[b0].16b, %[a0].4b[0]\n" + "udot v7.4s, %[b0].16b, %[a0].4b[1]\n" + "udot v8.4s, %[b0].16b, %[a0].4b[2]\n" + "udot v9.4s, %[b0].16b, %[a0].4b[3]\n" + "udot v10.4s, %[b0].16b, %[a1].4b[0]\n" + "udot v11.4s, %[b0].16b, %[a1].4b[1]\n" + "udot v12.4s, %[b0].16b, %[a1].4b[2]\n" + "udot v13.4s, %[b0].16b, %[a1].4b[3]\n" + "udot v14.4s, %[b1].16b, %[a0].4b[0]\n" + "udot v15.4s, %[b1].16b, %[a0].4b[1]\n" + "udot v16.4s, %[b1].16b, %[a0].4b[2]\n" + "udot v17.4s, %[b1].16b, %[a0].4b[3]\n" + "udot v18.4s, %[b1].16b, %[a1].4b[0]\n" + "udot v19.4s, %[b1].16b, %[a1].4b[1]\n" + "udot v20.4s, %[b1].16b, %[a1].4b[2]\n" + "udot v21.4s, %[b1].16b, %[a1].4b[3]\n" + + "cbz %w[k], 4f\n" + // Loop proper + "3:\n" + "ldr %q[a0], [%[a_ptr]], #16\n" + "ldr %q[a1], [%[a_ptr]], #16\n" + "ldr %q[a0a], [%[a_ptr]], #16\n" + "ldr %q[a1a], [%[a_ptr]], #16\n" + "ldr %q[b0], [%[b_ptr]], #16\n" + "ldr %q[b1], [%[b_ptr]], #16\n" + "udot v27.4s, %[a0].16b, v24.16b\n" + "udot v28.4s, %[a1].16b, v24.16b\n" + "udot v27.4s, %[a0a].16b, v24.16b\n" + "udot v28.4s, %[a1a].16b, v24.16b\n" + "udot v29.4s, %[b0].16b, v25.16b\n" + "udot v30.4s, %[b1].16b, v25.16b\n" + "udot v6.4s, %[b0].16b, %[a0].4b[0]\n" + "udot v7.4s, %[b0].16b, %[a0].4b[1]\n" + "udot v8.4s, %[b0].16b, %[a0].4b[2]\n" + "udot v9.4s, %[b0].16b, %[a0].4b[3]\n" + "udot v10.4s, %[b0].16b, %[a1].4b[0]\n" + "udot v11.4s, %[b0].16b, %[a1].4b[1]\n" 
+ "udot v12.4s, %[b0].16b, %[a1].4b[2]\n" + "udot v13.4s, %[b0].16b, %[a1].4b[3]\n" + "udot v14.4s, %[b1].16b, %[a0].4b[0]\n" + "udot v15.4s, %[b1].16b, %[a0].4b[1]\n" + "udot v16.4s, %[b1].16b, %[a0].4b[2]\n" + "udot v17.4s, %[b1].16b, %[a0].4b[3]\n" + "udot v18.4s, %[b1].16b, %[a1].4b[0]\n" + "udot v19.4s, %[b1].16b, %[a1].4b[1]\n" + "udot v20.4s, %[b1].16b, %[a1].4b[2]\n" + "udot v21.4s, %[b1].16b, %[a1].4b[3]\n" + "ldr %q[b0], [%[b_ptr]], #16\n" + "ldr %q[b1], [%[b_ptr]], #16\n" + "udot v29.4s, %[b0].16b, v25.16b\n" + "udot v30.4s, %[b1].16b, v25.16b\n" + "udot v6.4s, %[b0].16b, %[a0a].4b[0]\n" + "udot v7.4s, %[b0].16b, %[a0a].4b[1]\n" + "udot v8.4s, %[b0].16b, %[a0a].4b[2]\n" + "udot v9.4s, %[b0].16b, %[a0a].4b[3]\n" + "udot v10.4s, %[b0].16b, %[a1a].4b[0]\n" + "udot v11.4s, %[b0].16b, %[a1a].4b[1]\n" + "udot v12.4s, %[b0].16b, %[a1a].4b[2]\n" + "udot v13.4s, %[b0].16b, %[a1a].4b[3]\n" + "udot v14.4s, %[b1].16b, %[a0a].4b[0]\n" + "udot v15.4s, %[b1].16b, %[a0a].4b[1]\n" + "udot v16.4s, %[b1].16b, %[a0a].4b[2]\n" + "udot v17.4s, %[b1].16b, %[a0a].4b[3]\n" + "udot v18.4s, %[b1].16b, %[a1a].4b[0]\n" + "udot v19.4s, %[b1].16b, %[a1a].4b[1]\n" + "udot v20.4s, %[b1].16b, %[a1a].4b[2]\n" + "udot v21.4s, %[b1].16b, %[a1a].4b[3]\n" + + "subs %w[k], %w[k], #1\n" + "bne 3b\n" + + "4:\n" + //! minus zAB + "sub v27.4s, v27.4s, v26.4s\n" + "sub v28.4s, v28.4s, v26.4s\n" + + // clang-format off + SUB_LANE(6, 27, 0, 29, 23) + SUB_LANE(14, 27, 0, 30, 23) + SUB_LANE(7, 27, 1, 29, 23) + SUB_LANE(15, 27, 1, 30, 23) + SUB_LANE(8, 27, 2, 29, 23) + SUB_LANE(16, 27, 2, 30, 23) + SUB_LANE(9, 27, 3, 29, 23) + SUB_LANE(17, 27, 3, 30, 23) + SUB_LANE(10, 28, 0, 29, 23) + SUB_LANE(18, 28, 0, 30, 23) + SUB_LANE(11, 28, 1, 29, 23) + SUB_LANE(19, 28, 1, 30, 23) + SUB_LANE(12, 28, 2, 29, 23) + SUB_LANE(20, 28, 2, 30, 23) + SUB_LANE(13, 28, 3, 29, 23) + SUB_LANE(21, 28, 3, 30, 23) + // clang-format on + + "stp q6, q14, [%[outptr0]]\n" + "stp q7, q15, [%[outptr1]]\n" + "stp q8, q16, [%[outptr2]]\n" + "stp q9, q17, [%[outptr3]]\n" + "stp q10, q18, [%[outptr4]]\n" + "stp q11, q19, [%[outptr5]]\n" + "stp q12, q20, [%[outptr6]]\n" + "stp q13, q21, [%[outptr7]]\n" + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [a0] "+w"(a0), + [a1] "+w"(a1), [a0a] "+w"(a0a), [a1a] "+w"(a1a), [b0] "+w"(b0), + [b1] "+w"(b1), [k] "+r"(k), [LDC] "+r"(LDC), [oddk] "+r"(oddk), + [is_first_k] "+r"(is_first_k), [outptr0] "+r"(outptr0), + [zero_point_A] "+r"(zero_point_A), + [zero_point_B] "+r"(zero_point_B), [zAB] "+r"(zAB), + [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), + [outptr3] "=r"(outptr3), [outptr4] "=r"(outptr4), + [outptr5] "=r"(outptr5), [outptr6] "=r"(outptr6), + [outptr7] "=r"(outptr7) + : + : "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "cc", "memory"); +} + +// Overview of register layout: +// +// A 8x4 cell of Rhs is stored in 8bit in q1-q2, q4-q5. +// A 8x4x2 cell of Lhs is stored in 8bit in q0,q3 +// A 8x12 block of accumulators is stored in 8bit in q8--q31. 
+// +// +--------+--------+ +// |v1[0-16]|v2[0-16]| +// Rhs +--------+--------+ +// |v4[0-16]|v5[0-16]| +// +--------+--------+ +// +// | | | +// +// Lhs | | | +// +// +-------+-------+ - - - - +--------+--------+ +// |v0[0-4]|v3[0-4]| | v6[0-4]|v10[0-4]| +// |v0[0-4]|v3[0-4]| | v7[0-4]|v11[0-4]| +// |v0[0-4]|v3[0-4]| | v8[0-4]|v12[0-4]| +// |v0[0-4]|v3[0-4]| | v9[0-4]|v13[0-4]| +// +-------+-------+ - - - - +--------+--------+ +// +// Accumulator +// +// C = sum((A - zA) * (B - zB)) = sum(A * B) - sum(A) * zB - sum(B) * zA + zA * +// zB * k +// A -> v28 | B -> v29, v30 | zA * zB * k -> v26 + +static void kern_4x8(const uint8_t* packA, const uint8_t* packB, int K, + int32_t* output, int LDC, bool is_first_k, int m_remain, + uint8_t zero_point_A, uint8_t zero_point_B, uint32_t zAB) { + K /= 4; + const uint8_t* a_ptr = packA; + const uint8_t* b_ptr = packB; + // Fix up for odd lengths - set a flag if K is odd, but make + // sure we round up the iteration count. + int oddk = (K & 1); + int k = K / 2; + int32x4_t a0; + int32x4_t b0; + int32x4_t b1; + int32x4_t a0a; + int32x4_t b0a; + int32x4_t b1a; + + LDC = LDC * sizeof(int32_t); + int32_t* outptr0 = output; + int32_t* outptr1; + int32_t* outptr2; + int32_t* outptr3; + size_t x0; + +// clang-format off +#define LOAD_LINE(v1, v2, m) \ + "cbz %[x0], 100f\n" \ + "ldp " v1 "," v2 ", [%[outptr" m "]]\n" \ + "subs %[x0], %[x0], #1\n" + +#define LOAD_C \ + "mov %[x0], %x[m_remain]\n" \ + LOAD_LINE("q6", "q10", "0") \ + LOAD_LINE("q7", "q11", "1") \ + LOAD_LINE("q8", "q12", "2") \ + LOAD_LINE("q9", "q13", "3") \ + "100:\n" + +#define STORE_LINE(v1, v2, m) \ + "cbz %[x0], 101f\n" \ + "stp " v1 "," v2", [%[outptr" m "]]\n" \ + "subs %[x0], %[x0], #1\n" + +#define STORE_C \ + "mov %[x0], %x[m_remain]\n" \ + STORE_LINE("q6", "q10", "0") \ + STORE_LINE("q7", "q11", "1") \ + STORE_LINE("q8", "q12", "2") \ + STORE_LINE("q9", "q13", "3") \ + "101:\n" + + // clang-format on + + asm volatile( + // load accumulator C + "add %[outptr1], %[outptr0], %x[LDC]\n" + "add %[outptr2], %[outptr1], %x[LDC]\n" + "add %[outptr3], %[outptr2], %x[LDC]\n" + "dup v24.16b, %w[zero_point_B] \n" + "dup v25.16b, %w[zero_point_A] \n" + "dup v26.4s, %w[zAB] \n" + "cmp %w[is_first_k], #1\n" + "beq 1f\n" LOAD_C + + "b 2f\n" + + "1:\n" + "eor v6.16b, v6.16b, v6.16b\n" + "eor v7.16b, v7.16b, v7.16b\n" + "eor v8.16b, v8.16b, v8.16b\n" + "eor v9.16b, v9.16b, v9.16b\n" + "eor v10.16b, v10.16b, v10.16b\n" + "eor v11.16b, v11.16b, v11.16b\n" + "eor v12.16b, v12.16b, v12.16b\n" + "eor v13.16b, v13.16b, v13.16b\n" + "eor v28.16b, v28.16b, v28.16b\n" + "eor v29.16b, v29.16b, v29.16b\n" + "eor v30.16b, v30.16b, v30.16b\n" + + "2: \n" + "cbz %w[oddk], 3f\n" + + // parse the oddk + "ldr %q[a0], [%[a_ptr]], #16\n" + "ldr %q[b0], [%[b_ptr]], #16\n" + "ldr %q[b1], [%[b_ptr]], #16\n" + "udot v28.4s, %[a0].16b, v24.16b\n" + "udot v29.4s, %[b0].16b, v25.16b\n" + "udot v30.4s, %[b1].16b, v25.16b\n" + "udot v6.4s, %[b0].16b, %[a0].4b[0]\n" + "udot v7.4s, %[b0].16b, %[a0].4b[1]\n" + "udot v8.4s, %[b0].16b, %[a0].4b[2]\n" + "udot v9.4s, %[b0].16b, %[a0].4b[3]\n" + "udot v10.4s, %[b1].16b, %[a0].4b[0]\n" + "udot v11.4s, %[b1].16b, %[a0].4b[1]\n" + "udot v12.4s, %[b1].16b, %[a0].4b[2]\n" + "udot v13.4s, %[b1].16b, %[a0].4b[3]\n" + + "cbz %w[k], 4f\n" + // Loop proper + "3:\n" + "ldr %q[a0], [%[a_ptr]], #16\n" + "ldr %q[b0], [%[b_ptr]], #16\n" + "ldr %q[b1], [%[b_ptr]], #16\n" + "ldr %q[a0a], [%[a_ptr]], #16\n" + "ldr %q[b0a], [%[b_ptr]], #16\n" + "ldr %q[b1a], [%[b_ptr]], #16\n" + "udot v28.4s, %[a0].16b, 
v24.16b\n" + "udot v28.4s, %[a0a].16b, v24.16b\n" + "udot v29.4s, %[b0].16b, v25.16b\n" + "udot v30.4s, %[b1].16b, v25.16b\n" + "udot v29.4s, %[b0a].16b, v25.16b\n" + "udot v30.4s, %[b1a].16b, v25.16b\n" + + "udot v6.4s, %[b0].16b, %[a0].4b[0]\n" + "udot v7.4s, %[b0].16b, %[a0].4b[1]\n" + "udot v8.4s, %[b0].16b, %[a0].4b[2]\n" + "udot v9.4s, %[b0].16b, %[a0].4b[3]\n" + "udot v10.4s, %[b1].16b, %[a0].4b[0]\n" + "udot v11.4s, %[b1].16b, %[a0].4b[1]\n" + "udot v12.4s, %[b1].16b, %[a0].4b[2]\n" + "udot v13.4s, %[b1].16b, %[a0].4b[3]\n" + "udot v6.4s , %[b0a].16b, %[a0a].4b[0]\n" + "udot v7.4s , %[b0a].16b, %[a0a].4b[1]\n" + "udot v8.4s, %[b0a].16b, %[a0a].4b[2]\n" + "udot v9.4s, %[b0a].16b, %[a0a].4b[3]\n" + "udot v10.4s, %[b1a].16b, %[a0a].4b[0]\n" + "udot v11.4s, %[b1a].16b, %[a0a].4b[1]\n" + "udot v12.4s, %[b1a].16b, %[a0a].4b[2]\n" + "udot v13.4s, %[b1a].16b, %[a0a].4b[3]\n" + + "subs %w[k], %w[k], #1\n" + "bne 3b\n" + + "4:\n" + //! minus zAB + "sub v29.4s, v29.4s, v26.4s\n" + "sub v30.4s, v30.4s, v26.4s\n" + + // clang-format off + SUB_LANE(6, 28, 0, 29, 23) + SUB_LANE(10, 28, 0, 30, 23) + SUB_LANE(7, 28, 1, 29, 23) + SUB_LANE(11, 28, 1, 30, 23) + SUB_LANE(8, 28, 2, 29, 23) + SUB_LANE(12, 28, 2, 30, 23) + SUB_LANE(9, 28, 3, 29, 23) + SUB_LANE(13, 28, 3, 30, 23) + // clang-format on + + STORE_C + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [k] "+r"(k), + [outptr0] "+r"(outptr0), [oddk] "+r"(oddk), + [is_first_k] "+r"(is_first_k), [m_remain] "+r"(m_remain), + [zero_point_A] "+r"(zero_point_A), + [zero_point_B] "+r"(zero_point_B), [zAB] "+r"(zAB), + [LDC] "+r"(LDC), [a0] "=w"(a0), [a0a] "=w"(a0a), [b0] "=w"(b0), + [b1] "=w"(b1), [b0a] "=w"(b0a), [b1a] "=w"(b1a), + [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), + [outptr3] "=r"(outptr3), [x0] "=r"(x0) + : + : "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v23", "v24", + "v25", "v26", "v28", "v29", "v30", "memory", "cc"); + +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +} + +// Overview of register layout: +// +// A (4x4)x2 cell of Rhs is stored in 8bit in q2-q3. +// A 4x4x2 cell of Lhs is stored in 8bit in q0-q1, q4-a5 +// A 8x4 block of accumulators is stored in 8bit in q4--q7. +// +// +--------+ +// |v2[0-16]| +// Rhs +--------+ +// |v3[0-16]| +// +--------+ +// | | +// +// Lhs | | +// +// +-------+-------+ - - - - +--------+ +// |v0[0-4]|v4[0-4]| | v6[0-4]| +// |v0[0-4]|v4[0-4]| | v7[0-4]| +// |v0[0-4]|v4[0-4]| | v8[0-4]| +// |v0[0-4]|v4[0-4]| | v9[0-4]| +// |v1[0-4]|v5[0-4]| |v10[0-4]| +// |v1[0-4]|v5[0-4]| |v11[0-4]| +// |v1[0-4]|v5[0-4]| |v12[0-4]| +// |v1[0-4]|v5[0-4]| |v13[0-4]| +// +-------+-------+ - - - - +---------+ +// +// Accumulator +// +// C = sum((A - zA) * (B - zB)) = sum(A * B) - sum(A) * zB - sum(B) * zA + zA * +// zB * k +// A -> v27, v28 | B -> v29 | zA * zB * k -> v26 + +static void kern_8x4(const uint8_t* packA, const uint8_t* packB, int K, + int32_t* output, int LDC, bool is_first_k, int n_remain, + uint8_t zero_point_A, uint8_t zero_point_B, uint32_t zAB) { + K /= 4; + const uint8_t* a_ptr = packA; + const uint8_t* b_ptr = packB; + // Fix up for odd lengths - set a flag if K is odd, but make + // sure we round up the iteration count. 
+ int oddk = (K & 1); + int k = K / 2; + int32x4_t a0; + int32x4_t a1; + int32x4_t b0; + int32x4_t b0a; + int32x4_t a0a; + int32x4_t a1a; + + LDC = LDC * sizeof(int32_t); + int32_t* outptr0 = output; + int32_t* outptr1; + int32_t* outptr2; + int32_t* outptr3; + int32_t* outptr4; + int32_t* outptr5; + int32_t* outptr6; + int32_t* outptr7; + + size_t x0; + +// clang-format off +#define LOAD_LINE(reg_index, n) \ + "mov %[x0], %[outptr" n "]\n" \ + "cmp %w[n_remain], #4\n" \ + "blt 100" n "f\n" \ + "ldr q" reg_index ", [%[x0]] \n" \ + "b 101" n "f\n" \ + "100" n ":\n" \ + "cmp %w[n_remain], #0\n" \ + "beq 101" n "f\n" \ + "ld1 {v" reg_index ".s}[0], [%[x0]], #4\n" \ + "cmp %w[n_remain], #1\n" \ + "beq 101" n "f\n" \ + "ld1 {v" reg_index ".s}[1], [%[x0]], #4\n" \ + "cmp %w[n_remain], #2\n" \ + "beq 101" n "f\n" \ + "ld1 {v" reg_index ".s}[2], [%[x0]], #4\n" \ + "101" n ":\n" + + +#define LOAD_C \ + LOAD_LINE("6", "0") \ + LOAD_LINE("7", "1") \ + LOAD_LINE("8", "2") \ + LOAD_LINE("9", "3") \ + LOAD_LINE("10", "4") \ + LOAD_LINE("11", "5") \ + LOAD_LINE("12", "6") \ + LOAD_LINE("13", "7") + +#define STORE_LINE(reg_index, n) \ + "mov %[x0], %[outptr" n "]\n" \ + "cmp %w[n_remain], #4\n" \ + "blt 102" n "f\n" \ + "str q" reg_index ", [%[x0]]\n" \ + "b 103" n "f\n" \ + "102" n ":\n" \ + "cmp %w[n_remain], #0\n" \ + "beq 103" n "f\n" \ + "st1 {v" reg_index ".s}[0], [%[x0]], #4\n" \ + "cmp %w[n_remain], #1\n" \ + "beq 103" n "f\n" \ + "st1 {v" reg_index ".s}[1], [%[x0]], #4\n" \ + "cmp %w[n_remain], #2\n" \ + "beq 103" n "f\n" \ + "st1 {v" reg_index ".s}[2], [%[x0]], #4\n" \ + "103" n ":\n" + +#define STORE_C \ + STORE_LINE("6", "0") \ + STORE_LINE("7", "1") \ + STORE_LINE("8", "2") \ + STORE_LINE("9", "3") \ + STORE_LINE("10", "4") \ + STORE_LINE("11", "5") \ + STORE_LINE("12", "6") \ + STORE_LINE("13", "7") + + // clang-format on + + asm volatile( + // load accumulator C + "add %[outptr1], %[outptr0], %x[LDC]\n" + "add %[outptr2], %[outptr1], %x[LDC]\n" + "add %[outptr3], %[outptr2], %x[LDC]\n" + "add %[outptr4], %[outptr3], %x[LDC]\n" + "add %[outptr5], %[outptr4], %x[LDC]\n" + "add %[outptr6], %[outptr5], %x[LDC]\n" + "add %[outptr7], %[outptr6], %x[LDC]\n" + "dup v24.16b, %w[zero_point_B] \n" + "dup v25.16b, %w[zero_point_A] \n" + "dup v26.4s, %w[zAB] \n" + "cmp %w[is_first_k], #1\n" + "beq 1f\n" LOAD_C + + "b 2f\n" + + "1:\n" + "eor v6.16b, v6.16b, v6.16b\n" + "eor v7.16b, v7.16b, v7.16b\n" + "eor v8.16b, v8.16b, v8.16b\n" + "eor v9.16b, v9.16b, v9.16b\n" + "eor v10.16b, v10.16b, v10.16b\n" + "eor v11.16b, v11.16b, v11.16b\n" + "eor v12.16b, v12.16b, v12.16b\n" + "eor v13.16b, v13.16b, v13.16b\n" + "eor v27.16b, v27.16b, v27.16b\n" + "eor v28.16b, v28.16b, v28.16b\n" + "eor v29.16b, v29.16b, v29.16b\n" + + "2: \n" + "cbz %w[oddk], 3f\n" + + // parse the oddk + "ldr %q[a0], [%[a_ptr]], #16\n" + "ldr %q[b0], [%[b_ptr]], #16\n" + "ldr %q[a1], [%[a_ptr]], #16\n" + "udot v27.4s, %[a0].16b, v24.16b\n" + "udot v28.4s, %[a1].16b, v24.16b\n" + "udot v29.4s, %[b0].16b, v25.16b\n" + "udot v6.4s , %[b0].16b, %[a0].4b[0]\n" + "udot v7.4s , %[b0].16b, %[a0].4b[1]\n" + "udot v8.4s, %[b0].16b, %[a0].4b[2]\n" + "udot v9.4s, %[b0].16b, %[a0].4b[3]\n" + "udot v10.4s, %[b0].16b, %[a1].4b[0]\n" + "udot v11.4s, %[b0].16b, %[a1].4b[1]\n" + "udot v12.4s, %[b0].16b, %[a1].4b[2]\n" + "udot v13.4s, %[b0].16b, %[a1].4b[3]\n" + + "cbz %w[k], 4f\n" + // Loop proper + "3:\n" + "ldr %q[a0], [%[a_ptr]], #16\n" + "ldr %q[b0], [%[b_ptr]], #16\n" + "ldr %q[a1], [%[a_ptr]], #16\n" + "ldr %q[a0a], [%[a_ptr]], #16\n" + "ldr %q[a1a], 
[%[a_ptr]], #16\n" + "ldr %q[b0a], [%[b_ptr]], #16\n" + "udot v27.4s, %[a0].16b, v24.16b\n" + "udot v28.4s, %[a1].16b, v24.16b\n" + "udot v27.4s, %[a0a].16b, v24.16b\n" + "udot v28.4s, %[a1a].16b, v24.16b\n" + "udot v29.4s, %[b0].16b, v25.16b\n" + "udot v29.4s, %[b0a].16b, v25.16b\n" + "udot v6.4s , %[b0].16b, %[a0].4b[0]\n" + "udot v7.4s , %[b0].16b, %[a0].4b[1]\n" + "udot v8.4s, %[b0].16b, %[a0].4b[2]\n" + "udot v9.4s, %[b0].16b, %[a0].4b[3]\n" + "udot v10.4s, %[b0].16b, %[a1].4b[0]\n" + "udot v11.4s, %[b0].16b, %[a1].4b[1]\n" + "udot v12.4s, %[b0].16b, %[a1].4b[2]\n" + "udot v13.4s, %[b0].16b, %[a1].4b[3]\n" + "udot v6.4s , %[b0a].16b, %[a0a].4b[0]\n" + "udot v7.4s , %[b0a].16b, %[a0a].4b[1]\n" + "udot v8.4s, %[b0a].16b, %[a0a].4b[2]\n" + "udot v9.4s, %[b0a].16b, %[a0a].4b[3]\n" + "udot v10.4s, %[b0a].16b, %[a1a].4b[0]\n" + "udot v11.4s, %[b0a].16b, %[a1a].4b[1]\n" + "udot v12.4s, %[b0a].16b, %[a1a].4b[2]\n" + "udot v13.4s, %[b0a].16b, %[a1a].4b[3]\n" + + "subs %w[k], %w[k], #1\n" + "bne 3b\n" + + "4:\n" + //! minus zAB + "sub v27.4s, v27.4s, v26.4s\n" + "sub v28.4s, v28.4s, v26.4s\n" + + // clang-format off + SUB_LANE(6, 27, 0, 29, 23) + SUB_LANE(7, 27, 1, 29, 23) + SUB_LANE(8, 27, 2, 29, 23) + SUB_LANE(9, 27, 3, 29, 23) + SUB_LANE(10, 28, 0, 29, 23) + SUB_LANE(11, 28, 1, 29, 23) + SUB_LANE(12, 28, 2, 29, 23) + SUB_LANE(13, 28, 3, 29, 23) + // clang-format on + + STORE_C + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [LDC] "+r"(LDC), + [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), + [n_remain] "+r"(n_remain), [k] "+r"(k), [outptr0] "+r"(outptr0), + [zero_point_A] "+r"(zero_point_A), + [zero_point_B] "+r"(zero_point_B), [zAB] "+r"(zAB), [a0] "=w"(a0), + [a1] "=w"(a1), [a0a] "=w"(a0a), [a1a] "=w"(a1a), [b0] "=w"(b0), + [b0a] "=w"(b0a), [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), + [outptr3] "=r"(outptr3), [outptr4] "=r"(outptr4), + [outptr5] "=r"(outptr5), [outptr6] "=r"(outptr6), + [outptr7] "=r"(outptr7), [x0] "=r"(x0) + : + : "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v23", "v24", + "v25", "v26", "v27", "v28", "v29", "memory", "cc"); + +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +} + +// Overview of register layout: +// +// A 4x4x2 cell of Rhs is stored in 8bit in q2-q3. +// A 4x4x2 cell of Lhs is stored in 8bit in q0-q1 +// A 4x4x2 block of accumulators is stored in 8bit in q4--q7. +// +// +--------+ +// | v2[0-7]| +// Rhs +--------+ +// | v3[0-7]| +// +--------+ +// | | +// +// Lhs | | +// +// +-------+-------+ - - - - +--------+ +// |v0[0-4]|v1[0-4]| | v4[0-7]| +// |v0[0-4]|v1[0-4]| | v5[0-7]| +// |v0[0-4]|v1[0-4]| | v6[0-7]| +// |v0[0-4]|v1[0-4]| | v7[0-7]| +// +-------+-------+ - - - - +--------+ +// +// Accumulator +// +// C = sum((A - zA) * (B - zB)) = sum(A * B) - sum(A) * zB - sum(B) * zA + zA * +// zB * k +// A -> v28 | B -> v29 | zA * zB * k -> v26 + +static void kern_4x4(const uint8_t* packA, const uint8_t* packB, int K, + int32_t* output, int LDC, bool is_first_k, int m_remain, + int n_remain, uint8_t zero_point_A, uint8_t zero_point_B, + uint32_t zAB) { + K /= 4; + const int32_t* a_ptr = reinterpret_cast(packA); + const int32_t* b_ptr = reinterpret_cast(packB); + // Fix up for odd lengths - set a flag if K is odd, but make + // sure we round up the iteration count. 
+ int oddk = (K & 1); + int k = K / 2; + int32x4_t a0; + int32x4_t a0a; + int32x4_t b0; + int32x4_t b0a; + LDC = LDC * sizeof(int32_t); + + int32_t* outptr0 = output; + int32_t* outptr1; + int32_t* outptr2; + int32_t* outptr3; + size_t x0, x1; + +// clang-format off +#define LOAD_LINE(reg_index, n) \ + "cbz %[x1], 102f\n" \ + "mov %[x0], %[outptr" n "]\n" \ + "cmp %w[n_remain], #4\n" \ + "blt 100" n "f\n" \ + "ldr q" reg_index ", [%[x0]]\n" \ + "b 101" n "f\n" \ + "100" n ":\n" \ + "cmp %w[n_remain], #0\n" \ + "beq 101" n "f\n" \ + "ld1 {v" reg_index ".s}[0], [%[x0]], #4\n" \ + "cmp %w[n_remain], #1\n" \ + "beq 101" n "f\n" \ + "ld1 {v" reg_index ".s}[1], [%[x0]], #4\n" \ + "cmp %w[n_remain], #2\n" \ + "beq 101" n "f\n" \ + "ld1 {v" reg_index ".s}[2], [%[x0]], #4\n" \ + "101" n ":\n" \ + "subs %[x1], %[x1], #1\n" + +#define LOAD_C \ + "mov %[x1], %x[m_remain]\n" \ + LOAD_LINE("4", "0") \ + LOAD_LINE("5", "1") \ + LOAD_LINE("6", "2") \ + LOAD_LINE("7", "3") \ + "102:\n" + +#define STORE_LINE(reg_index, n) \ + "cbz %[x1], 105f\n" \ + "mov %[x0], %[outptr" n "]\n" \ + "cmp %w[n_remain], #4\n" \ + "blt 103" n "f\n" \ + "str q" reg_index ", [%[x0]]\n" \ + "b 104" n "f\n" \ + "103" n ":\n" \ + "cmp %w[n_remain], #0\n" \ + "beq 104" n "f\n" \ + "st1 {v" reg_index ".s}[0], [%[x0]], #4\n" \ + "cmp %w[n_remain], #1\n" \ + "beq 104" n "f\n" \ + "st1 {v" reg_index ".s}[1], [%[x0]], #4\n" \ + "cmp %w[n_remain], #2\n" \ + "beq 104" n "f\n" \ + "st1 {v" reg_index ".s}[2], [%[x0]], #4\n" \ + "104" n ":\n" \ + "subs %[x1], %[x1], #1\n" + +#define STORE_C \ + "mov %[x1], %x[m_remain]\n" \ + STORE_LINE("4", "0") \ + STORE_LINE("5", "1") \ + STORE_LINE("6", "2") \ + STORE_LINE("7", "3") \ + "105:\n" + + // clang-format on + + asm volatile( + // load accumulator C + "add %[outptr1], %[outptr0], %x[LDC]\n" + "add %[outptr2], %[outptr1], %x[LDC]\n" + "add %[outptr3], %[outptr2], %x[LDC]\n" + "dup v24.16b, %w[zero_point_B] \n" + "dup v25.16b, %w[zero_point_A] \n" + "dup v26.4s, %w[zAB] \n" + "cmp %w[is_first_k], #1\n" + "beq 1f\n" // + LOAD_C // + + "b 2f\n" + + "1:\n" + "eor v4.16b, v4.16b, v4.16b\n" + "eor v5.16b, v5.16b, v5.16b\n" + "eor v6.16b, v6.16b, v6.16b\n" + "eor v7.16b, v7.16b, v7.16b\n" + "eor v28.16b, v28.16b, v28.16b\n" + "eor v29.16b, v29.16b, v29.16b\n" + + "2: \n" + "cbz %w[oddk], 3f\n" + + // parse the oddk + "ldr %q[a0], [%[a_ptr]], #16\n" + "ldr %q[b0], [%[b_ptr]], #16\n" + "udot v28.4s, %[a0].16b, v24.16b\n" + "udot v29.4s, %[b0].16b, v25.16b\n" + "udot v4.4s , %[b0].16b, %[a0].4b[0]\n" + "udot v5.4s , %[b0].16b, %[a0].4b[1]\n" + "udot v6.4s, %[b0].16b, %[a0].4b[2]\n" + "udot v7.4s, %[b0].16b, %[a0].4b[3]\n" + + "cbz %w[k], 4f\n" + // Loop proper + "3:\n" + "ldr %q[a0], [%[a_ptr]], #16\n" + "ldr %q[b0], [%[b_ptr]], #16\n" + "ldr %q[a0a], [%[a_ptr]], #16\n" + "ldr %q[b0a], [%[b_ptr]], #16\n" + "udot v28.4s, %[a0].16b, v24.16b\n" + "udot v28.4s, %[a0a].16b, v24.16b\n" + "udot v29.4s, %[b0].16b, v25.16b\n" + "udot v29.4s, %[b0a].16b, v25.16b\n" + "udot v4.4s , %[b0].16b, %[a0].4b[0]\n" + "udot v5.4s , %[b0].16b, %[a0].4b[1]\n" + "udot v6.4s, %[b0].16b, %[a0].4b[2]\n" + "udot v7.4s, %[b0].16b, %[a0].4b[3]\n" + "udot v4.4s , %[b0a].16b, %[a0a].4b[0]\n" + "udot v5.4s , %[b0a].16b, %[a0a].4b[1]\n" + "udot v6.4s, %[b0a].16b, %[a0a].4b[2]\n" + "udot v7.4s, %[b0a].16b, %[a0a].4b[3]\n" + + "subs %w[k], %w[k], #1\n" + "bne 3b\n" + + "4:\n" + //! 
minus zAB + "sub v28.4s, v28.4s, v26.4s\n" + + // clang-format off + SUB_LANE(4, 28, 0, 29, 23) + SUB_LANE(5, 28, 1, 29, 23) + SUB_LANE(6, 28, 2, 29, 23) + SUB_LANE(7, 28, 3, 29, 23) + // clang-format on + + STORE_C + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [oddk] "+r"(oddk), + [is_first_k] "+r"(is_first_k), [n_remain] "+r"(n_remain), + [m_remain] "+r"(m_remain), [LDC] "+r"(LDC), + [zero_point_A] "+r"(zero_point_A), + [zero_point_B] "+r"(zero_point_B), [zAB] "+r"(zAB), + [outptr0] "+r"(outptr0), [k] "+r"(k), [a0] "=w"(a0), + [a0a] "=w"(a0a), [b0] "=w"(b0), [b0a] "=w"(b0a), + [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), + [outptr3] "=r"(outptr3), [x0] "=r"(x0), [x1] "=r"(x1) + : + : "v4", "v5", "v6", "v7", "v23", "v24", "v25", "v26", "v28", "v29", + "memory", "cc"); + +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +} + +#undef SUB_LANE + +static void gemm_u8_8x8_transpose_pack_helper(uint8_t* out, const uint8_t* in, + int ldin, int x0, int xmax, + int k0, int kmax) { + uint8_t zerobuff[16]; + std::memset(zerobuff, 0, sizeof(uint8_t) * 16); + const int ksize = kmax - k0; + const int ksize4 = round_up(ksize, 4) * 4; + const int ksize8 = ksize4 * 2; + uint8_t* outptr = out; + uint8_t* outptr_base = out; + //! 4x4 block output start pos + uint8_t* outptr_base4 = out + ((xmax - x0) / 8) * ksize8; + + int k = k0; + for (; k < kmax; k += 4) { + const uint8_t* inptr0 = in + k * ldin + x0; + const uint8_t* inptr1 = inptr0 + ldin; + const uint8_t* inptr2 = inptr1 + ldin; + const uint8_t* inptr3 = inptr2 + ldin; + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + + int x = x0; + outptr = outptr_base; + + for (; x + 7 < xmax; x += 8) { + if (k + 3 >= kmax) { + switch (k + 3 - kmax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + transpose_8x4_1_b(inptr0, inptr1, inptr2, inptr3, outptr); + outptr += ksize8; + } + + outptr = outptr_base4; + for (; x + 3 < xmax; x += 4) { + if (k + 3 >= kmax) { + switch (k + 3 - kmax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + transpose_4(inptr0, inptr1, inptr2, inptr3, outptr, 4, 4); + outptr += ksize4; + } + + if (x < xmax) { + if (k + 3 >= kmax) { + switch (k + 3 - kmax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + transpose_4(inptr0, inptr1, inptr2, inptr3, outptr, 4, xmax - x); + } + + outptr_base += 8 * 4; + outptr_base4 += 4 * 4; + } +} + +static void gemm_u8_8x8_interleave_pack_helper(uint8_t* outptr, + const uint8_t* inptr, int ldin, + int y0, int ymax, int k0, + int kmax) { + uint8_t zerobuff[16]; + std::memset(zerobuff, 0, sizeof(uint8_t) * 16); + + int y = y0; + for (; y + 7 < ymax; y += 8) { + const uint8_t* inptr0 = inptr + y * ldin + k0; + const uint8_t* inptr1 = inptr0 + ldin; + const uint8_t* inptr2 = inptr1 + ldin; + const uint8_t* inptr3 = inptr2 + ldin; + const uint8_t* inptr4 = inptr3 + ldin; + const uint8_t* inptr5 = inptr4 + ldin; + const uint8_t* inptr6 = inptr5 + ldin; + const uint8_t* inptr7 = inptr6 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + prefetch_2x(inptr4); + prefetch_2x(inptr5); + prefetch_2x(inptr6); + prefetch_2x(inptr7); + + int K = kmax - k0; + //! 
read 8 * 4 in each row + for (; K > 15; K -= 16) { + interleave_8x4_4_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, + inptr6, inptr7, outptr); + } + + if (K > 0) { + interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, + inptr7, outptr, 4, K); + } + } + for (; y < ymax; y += 4) { + const uint8_t* inptr0 = inptr + y * ldin + k0; + const uint8_t* inptr1 = inptr0 + ldin; + const uint8_t* inptr2 = inptr1 + ldin; + const uint8_t* inptr3 = inptr2 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + + int K = kmax - k0; + //! read 4 * 4 in each row + for (; K > 15; K -= 16) { + if (y + 3 >= ymax) { + switch (y + 3 - ymax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + interleave_4x4_4_b(inptr0, inptr1, inptr2, inptr3, outptr); + } + + if (K > 0) { + if (y + 3 >= ymax) { + switch (y + 3 - ymax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 4, K); + } + } +} + +} // namespace matmul_8x8x4 +} // namespace aarch64 +} // namespace megdnn + +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/quint8_dot/strategy.cpp b/dnn/src/aarch64/matrix_mul/quint8_dot/strategy.cpp new file mode 100644 index 00000000..305996b3 --- /dev/null +++ b/dnn/src/aarch64/matrix_mul/quint8_dot/strategy.cpp @@ -0,0 +1,114 @@ +/** + * \file dnn/src/aarch64/matrix_mul/quint8_dot/strategy.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#include "src/aarch64/matrix_mul/quint8_dot/strategy.h" +#include "megdnn/dtype.h" +#include "src/aarch64/matrix_mul/asm/common.h" +#include "src/aarch64/matrix_mul/quint8_dot/kernel_8x8x4.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/utils.h" + +#if __ARM_FEATURE_DOTPROD +using namespace megdnn; +using namespace aarch64; +using namespace aarch64::matmul; + +MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8); + +void gemm_u8_8x8::pack_A(uint8_t* outptr, const uint8_t* inptr, int ldin, + int y0, int ymax, int k0, int kmax, + bool transpose) const { + if (transpose) { + matmul_8x8x4::gemm_u8_8x8_transpose_pack_helper(outptr, inptr, ldin, y0, + ymax, k0, kmax); + } else { + matmul_8x8x4::gemm_u8_8x8_interleave_pack_helper(outptr, inptr, ldin, + y0, ymax, k0, kmax); + } +} + +void gemm_u8_8x8::pack_B(uint8_t* out, const uint8_t* in, int ldin, int x0, + int xmax, int k0, int kmax, bool transpose) const { + if (transpose) { + matmul_8x8x4::gemm_u8_8x8_interleave_pack_helper(out, in, ldin, x0, + xmax, k0, kmax); + } else { + matmul_8x8x4::gemm_u8_8x8_transpose_pack_helper(out, in, ldin, x0, xmax, + k0, kmax); + } +} + +void gemm_u8_8x8::kern(const uint8_t* packA, const uint8_t* packB, size_t M, + size_t N, size_t K, dt_int32* C, size_t LDC, + bool is_first_k, const dt_int32*, dt_int32*) const { + megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && + A_dtype.enumv() == DTypeEnum::Quantized8Asymm && + C_dtype.enumv() == DTypeEnum::QuantizedS32); + MEGDNN_MARK_USED_VAR(C_dtype); + size_t zero_point_A = A_dtype.param().zero_point; + size_t zero_point_B = B_dtype.param().zero_point; + constexpr size_t A_INTERLEAVE = 8; + constexpr size_t B_INTERLEAVE = 8; + const uint32_t zAB = static_cast(zero_point_A) * + static_cast(zero_point_B) * K; + //! K is packed to times of 4 + K = round_up(K, 4); + const int K8 = (K << 3); + const int K4 = K * 4; + + size_t m = 0; + for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) { + int32_t* output = C + (m * LDC); + + size_t n = 0; + const dt_uint8* cur_packB = packB; + for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { + matmul_8x8x4::kern_8x8(packA, cur_packB, K, output, LDC, is_first_k, + zero_point_A, zero_point_B, zAB); + output += B_INTERLEAVE; + cur_packB += K8; + } + + for (; n < N; n += 4) { + matmul_8x8x4::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k, + std::min(N - n, 4), zero_point_A, + zero_point_B, zAB); + output += 4; + cur_packB += K4; + } + packA += K8; + } + + for (; m < M; m += 4) { + int32_t* output = C + (m * LDC); + const dt_uint8* cur_packB = packB; + size_t n = 0; + for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { + matmul_8x8x4::kern_4x8(packA, cur_packB, K, output, LDC, is_first_k, + std::min(M - m, 4), zero_point_A, + zero_point_B, zAB); + output += B_INTERLEAVE; + cur_packB += K8; + } + + for (; n < N; n += 4) { + matmul_8x8x4::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k, + std::min(M - m, 4), + std::min(N - n, 4), zero_point_A, + zero_point_B, zAB); + output += 4; + cur_packB += K4; + } + packA += K4; + } +} +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/quint8_dot/strategy.h b/dnn/src/aarch64/matrix_mul/quint8_dot/strategy.h new file mode 100644 index 00000000..c84eee07 --- /dev/null +++ b/dnn/src/aarch64/matrix_mul/quint8_dot/strategy.h @@ -0,0 +1,26 @@ +/** + * \file dnn/src/aarch64/matrix_mul/quint8_dot/strategy.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. 
All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once +#include "src/fallback/matrix_mul/gemm_common.h" + +#if __ARM_FEATURE_DOTPROD +namespace megdnn { +namespace aarch64 { +namespace matmul { + +MEGDNN_REG_GEMM_STRATEGY(uint8_t, int32_t, int32_t, 8, 8, 4, false, true, + gemm_u8_8x8); + +} // namespace aarch64 +} // namespace matmul +} // namespace megdnn +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/relayout/opr_impl.cpp b/dnn/src/aarch64/relayout/opr_impl.cpp new file mode 100644 index 00000000..df34c812 --- /dev/null +++ b/dnn/src/aarch64/relayout/opr_impl.cpp @@ -0,0 +1,183 @@ +/** + * \file dnn/src/aarch64/relayout/opr_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "src/common/utils.h" +#include "src/common/relayout_helper.h" + +#include "src/aarch64/handle.h" +#include "src/aarch64/relayout/opr_impl.h" + +using namespace megdnn; +using namespace relayout; + +namespace { + +struct TransposeByte { + uint8_t v; +}; + +void trans_16x16_u8(const void* src, void* dst, const size_t src_step, + const size_t dst_step) { + asm volatile( + "\n" + "ld1 {v0.16b}, [%[src]], %[src_step] \n" + "ld1 {v1.16b}, [%[src]], %[src_step] \n" + "ld1 {v2.16b}, [%[src]], %[src_step] \n" + "ld1 {v3.16b}, [%[src]], %[src_step] \n" + "ld1 {v4.16b}, [%[src]], %[src_step] \n" + "ld1 {v5.16b}, [%[src]], %[src_step] \n" + "ld1 {v6.16b}, [%[src]], %[src_step] \n" + "ld1 {v7.16b}, [%[src]], %[src_step] \n" + "ld1 {v8.16b}, [%[src]], %[src_step] \n" + "ld1 {v9.16b}, [%[src]], %[src_step] \n" + "ld1 {v10.16b}, [%[src]], %[src_step] \n" + "ld1 {v11.16b}, [%[src]], %[src_step] \n" + "ld1 {v12.16b}, [%[src]], %[src_step] \n" + "ld1 {v13.16b}, [%[src]], %[src_step] \n" + "ld1 {v14.16b}, [%[src]], %[src_step] \n" + "ld1 {v15.16b}, [%[src]], %[src_step] \n" + "trn1 v16.16b, v0.16b, v1.16b \n" + "trn2 v17.16b, v0.16b, v1.16b \n" + "trn1 v18.16b, v2.16b, v3.16b \n" + "trn2 v19.16b, v2.16b, v3.16b \n" + "trn1 v20.16b, v4.16b, v5.16b \n" + "trn2 v21.16b, v4.16b, v5.16b \n" + "trn1 v22.16b, v6.16b, v7.16b \n" + "trn2 v23.16b, v6.16b, v7.16b \n" + "trn1 v24.16b, v8.16b, v9.16b \n" + "trn2 v25.16b, v8.16b, v9.16b \n" + "trn1 v26.16b, v10.16b, v11.16b \n" + "trn2 v27.16b, v10.16b, v11.16b \n" + "trn1 v28.16b, v12.16b, v13.16b \n" + "trn2 v29.16b, v12.16b, v13.16b \n" + "trn1 v30.16b, v14.16b, v15.16b \n" + "trn2 v31.16b, v14.16b, v15.16b \n" + "trn1 v0.8h, v16.8h, v18.8h \n" + "trn2 v2.8h, v16.8h, v18.8h \n" + "trn1 v4.8h, v20.8h, v22.8h \n" + "trn2 v6.8h, v20.8h, v22.8h \n" + "trn1 v8.8h, v24.8h, v26.8h \n" + "trn2 v10.8h, v24.8h, v26.8h \n" + "trn1 v12.8h, v28.8h, v30.8h \n" + "trn2 v14.8h, v28.8h, v30.8h \n" + "trn1 v1.8h, v17.8h, v19.8h \n" + "trn2 v3.8h, v17.8h, v19.8h \n" + "trn1 v5.8h, v21.8h, v23.8h \n" + "trn2 v7.8h, v21.8h, v23.8h \n" + "trn1 v9.8h, v25.8h, v27.8h \n" + "trn2 v11.8h, v25.8h, v27.8h \n" + "trn1 v13.8h, v29.8h, v31.8h \n" + "trn2 v15.8h, v29.8h, v31.8h \n" + "trn1 v16.4s, v0.4s, v4.4s \n" + "trn2 v20.4s, v0.4s, v4.4s \n" + "trn1 v24.4s, v8.4s, 
v12.4s \n" + "trn2 v28.4s, v8.4s, v12.4s \n" + "trn1 v17.4s, v1.4s, v5.4s \n" + "trn2 v21.4s, v1.4s, v5.4s \n" + "trn1 v25.4s, v9.4s, v13.4s \n" + "trn2 v29.4s, v9.4s, v13.4s \n" + "trn1 v18.4s, v2.4s, v6.4s \n" + "trn2 v22.4s, v2.4s, v6.4s \n" + "trn1 v26.4s, v10.4s, v14.4s \n" + "trn2 v30.4s, v10.4s, v14.4s \n" + "trn1 v19.4s, v3.4s, v7.4s \n" + "trn2 v23.4s, v3.4s, v7.4s \n" + "trn1 v27.4s, v11.4s, v15.4s \n" + "trn2 v31.4s, v11.4s, v15.4s \n" + "trn1 v0.2d, v16.2d, v24.2d \n" + "trn2 v8.2d, v16.2d, v24.2d \n" + "trn1 v1.2d, v17.2d, v25.2d \n" + "trn2 v9.2d, v17.2d, v25.2d \n" + "trn1 v2.2d, v18.2d, v26.2d \n" + "trn2 v10.2d, v18.2d, v26.2d \n" + "trn1 v3.2d, v19.2d, v27.2d \n" + "trn2 v11.2d, v19.2d, v27.2d \n" + "trn1 v4.2d, v20.2d, v28.2d \n" + "trn2 v12.2d, v20.2d, v28.2d \n" + "trn1 v5.2d, v21.2d, v29.2d \n" + "trn2 v13.2d, v21.2d, v29.2d \n" + "trn1 v6.2d, v22.2d, v30.2d \n" + "trn2 v14.2d, v22.2d, v30.2d \n" + "trn1 v7.2d, v23.2d, v31.2d \n" + "trn2 v15.2d, v23.2d, v31.2d \n" + "st1 {v0.16b}, [%[dst]], %[dst_step] \n" + "st1 {v1.16b}, [%[dst]], %[dst_step] \n" + "st1 {v2.16b}, [%[dst]], %[dst_step] \n" + "st1 {v3.16b}, [%[dst]], %[dst_step] \n" + "st1 {v4.16b}, [%[dst]], %[dst_step] \n" + "st1 {v5.16b}, [%[dst]], %[dst_step] \n" + "st1 {v6.16b}, [%[dst]], %[dst_step] \n" + "st1 {v7.16b}, [%[dst]], %[dst_step] \n" + "st1 {v8.16b}, [%[dst]], %[dst_step] \n" + "st1 {v9.16b}, [%[dst]], %[dst_step] \n" + "st1 {v10.16b}, [%[dst]], %[dst_step] \n" + "st1 {v11.16b}, [%[dst]], %[dst_step] \n" + "st1 {v12.16b}, [%[dst]], %[dst_step] \n" + "st1 {v13.16b}, [%[dst]], %[dst_step] \n" + "st1 {v14.16b}, [%[dst]], %[dst_step] \n" + "st1 {v15.16b}, [%[dst]], %[dst_step] \n" + : + [src] "+r" (src), + [dst] "+r" (dst) + : + [src_step] "r" (src_step), + [dst_step] "r" (dst_step) + : + "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", + "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "d20", + "d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28", "d29", "d30", + "d31"); + +} + +} // anonymous namespace + +namespace megdnn { +namespace relayout { +namespace transpose_fallback { +template <> +struct transpose_traits { + static constexpr size_t block_size = 16; +}; + +template <> +void transpose_block(const TransposeByte* src, + TransposeByte* dst, const size_t src_stride, + const size_t dst_stride) { + trans_16x16_u8(src, dst, src_stride, dst_stride); +} + +} // namespace transpose_fallback +} // namespace relayout +} // namespace megdnn + +void aarch64::RelayoutForwardImpl::exec(_megdnn_tensor_in src0, + _megdnn_tensor_out dst0, + Handle* src_handle) { + check_cpu_handle(src_handle); + TensorND src = src0, dst = dst0; + check_layout_and_canonize(src.layout, dst.layout); + + relayout::TransposeParam trans_param; + bool trans = relayout::is_transpose(src.layout, dst.layout, trans_param); + if (trans && trans_param.c == 1 && src0.layout.dtype.size() == 1) { + auto sptr = static_cast(src.raw_ptr), + dptr = static_cast(dst.raw_ptr); + MEGDNN_DISPATCH_CPU_KERN_OPR( + transpose_fallback::transpose( + trans_param.batch, trans_param.m, trans_param.n, sptr, + dptr)); + return; + } + exec_after_preprocess(src, dst, trans ? 
&trans_param : nullptr); +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/relayout/opr_impl.h b/dnn/src/aarch64/relayout/opr_impl.h new file mode 100644 index 00000000..96195cee --- /dev/null +++ b/dnn/src/aarch64/relayout/opr_impl.h @@ -0,0 +1,31 @@ +/** + * \file dnn/src/aarch64/relayout/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once +#include "megdnn/oprs.h" +#include "src/fallback/relayout/opr_impl.h" + +namespace megdnn { +namespace aarch64 { + +class RelayoutForwardImpl final : public fallback::RelayoutForwardImpl { + public: + using fallback::RelayoutForwardImpl::RelayoutForwardImpl; + + void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, + Handle *src_handle) override; + + bool is_thread_safe() const override { return true; } +}; + +} // namespace aarch64 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/rotate/opr_impl.cpp b/dnn/src/aarch64/rotate/opr_impl.cpp new file mode 100644 index 00000000..f8a0cd51 --- /dev/null +++ b/dnn/src/aarch64/rotate/opr_impl.cpp @@ -0,0 +1,393 @@ +/** + * \file dnn/src/aarch64/rotate/opr_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
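For reference (not part of this patch): the trans_16x16_u8 kernel and the transpose_block specialization above implement a plain 16x16 byte transpose. A minimal scalar sketch of the same operation, assuming the same byte strides, would be:

#include <cstddef>
#include <cstdint>

// Scalar equivalent of the NEON trn1/trn2 ladder: element (r, c) of the source
// block lands at (c, r) of the destination block.
static inline void trans_16x16_u8_ref(const uint8_t* src, uint8_t* dst,
                                      size_t src_step, size_t dst_step) {
    for (size_t r = 0; r < 16; ++r)
        for (size_t c = 0; c < 16; ++c)
            dst[c * dst_step + r] = src[r * src_step + c];
}

This is the block primitive RelayoutForwardImpl::exec relies on when it takes the fast transpose path for 1-byte, single-channel layouts.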
+ */ +#include + +#include "src/aarch64/rotate/opr_impl.h" +#include "src/aarch64/handle.h" +#include "src/common/cv/common.h" +#include "src/common/cv/helper.h" +#include "src/common/utils.h" + + +namespace megdnn { +namespace megcv { + +void rotate_8uc1_clockwise_16x16(const uchar *src, + uchar *dst, + size_t src_step, size_t dst_step) +{ + asm volatile ("\n" + "ld1 {v0.16b}, [%[src]], %[src_step] \n" + "ld1 {v1.16b}, [%[src]], %[src_step] \n" + "ld1 {v2.16b}, [%[src]], %[src_step] \n" + "ld1 {v3.16b}, [%[src]], %[src_step] \n" + "ld1 {v4.16b}, [%[src]], %[src_step] \n" + "ld1 {v5.16b}, [%[src]], %[src_step] \n" + "ld1 {v6.16b}, [%[src]], %[src_step] \n" + "ld1 {v7.16b}, [%[src]], %[src_step] \n" + "ld1 {v8.16b}, [%[src]], %[src_step] \n" + "ld1 {v9.16b}, [%[src]], %[src_step] \n" + "ld1 {v10.16b}, [%[src]], %[src_step] \n" + "ld1 {v11.16b}, [%[src]], %[src_step] \n" + "ld1 {v12.16b}, [%[src]], %[src_step] \n" + "ld1 {v13.16b}, [%[src]], %[src_step] \n" + "ld1 {v14.16b}, [%[src]], %[src_step] \n" + "ld1 {v15.16b}, [%[src]], %[src_step] \n" + + "trn1 v16.16b, v0.16b, v1.16b \n" + "trn2 v17.16b, v0.16b, v1.16b \n" + "trn1 v18.16b, v2.16b, v3.16b \n" + "trn2 v19.16b, v2.16b, v3.16b \n" + "trn1 v20.16b, v4.16b, v5.16b \n" + "trn2 v21.16b, v4.16b, v5.16b \n" + "trn1 v22.16b, v6.16b, v7.16b \n" + "trn2 v23.16b, v6.16b, v7.16b \n" + "trn1 v24.16b, v8.16b, v9.16b \n" + "trn2 v25.16b, v8.16b, v9.16b \n" + "trn1 v26.16b, v10.16b, v11.16b \n" + "trn2 v27.16b, v10.16b, v11.16b \n" + "trn1 v28.16b, v12.16b, v13.16b \n" + "trn2 v29.16b, v12.16b, v13.16b \n" + "trn1 v30.16b, v14.16b, v15.16b \n" + "trn2 v31.16b, v14.16b, v15.16b \n" + + "trn1 v0.8h, v16.8h, v18.8h \n" + "trn2 v2.8h, v16.8h, v18.8h \n" + "trn1 v4.8h, v20.8h, v22.8h \n" + "trn2 v6.8h, v20.8h, v22.8h \n" + "trn1 v8.8h, v24.8h, v26.8h \n" + "trn2 v10.8h, v24.8h, v26.8h \n" + "trn1 v12.8h, v28.8h, v30.8h \n" + "trn2 v14.8h, v28.8h, v30.8h \n" + "trn1 v1.8h, v17.8h, v19.8h \n" + "trn2 v3.8h, v17.8h, v19.8h \n" + "trn1 v5.8h, v21.8h, v23.8h \n" + "trn2 v7.8h, v21.8h, v23.8h \n" + "trn1 v9.8h, v25.8h, v27.8h \n" + "trn2 v11.8h, v25.8h, v27.8h \n" + "trn1 v13.8h, v29.8h, v31.8h \n" + "trn2 v15.8h, v29.8h, v31.8h \n" + + "trn1 v16.4s, v0.4s, v4.4s \n" + "trn2 v20.4s, v0.4s, v4.4s \n" + "trn1 v24.4s, v8.4s, v12.4s \n" + "trn2 v28.4s, v8.4s, v12.4s \n" + "trn1 v17.4s, v1.4s, v5.4s \n" + "trn2 v21.4s, v1.4s, v5.4s \n" + "trn1 v25.4s, v9.4s, v13.4s \n" + "trn2 v29.4s, v9.4s, v13.4s \n" + "trn1 v18.4s, v2.4s, v6.4s \n" + "trn2 v22.4s, v2.4s, v6.4s \n" + "trn1 v26.4s, v10.4s, v14.4s \n" + "trn2 v30.4s, v10.4s, v14.4s \n" + "trn1 v19.4s, v3.4s, v7.4s \n" + "trn2 v23.4s, v3.4s, v7.4s \n" + "trn1 v27.4s, v11.4s, v15.4s \n" + "trn2 v31.4s, v11.4s, v15.4s \n" + + "trn1 v0.2d, v16.2d, v24.2d \n" + "trn2 v8.2d, v16.2d, v24.2d \n" + "trn1 v1.2d, v17.2d, v25.2d \n" + "trn2 v9.2d, v17.2d, v25.2d \n" + "trn1 v2.2d, v18.2d, v26.2d \n" + "trn2 v10.2d, v18.2d, v26.2d \n" + "trn1 v3.2d, v19.2d, v27.2d \n" + "trn2 v11.2d, v19.2d, v27.2d \n" + "trn1 v4.2d, v20.2d, v28.2d \n" + "trn2 v12.2d, v20.2d, v28.2d \n" + "trn1 v5.2d, v21.2d, v29.2d \n" + "trn2 v13.2d, v21.2d, v29.2d \n" + "trn1 v6.2d, v22.2d, v30.2d \n" + "trn2 v14.2d, v22.2d, v30.2d \n" + "trn1 v7.2d, v23.2d, v31.2d \n" + "trn2 v15.2d, v23.2d, v31.2d \n" +// There is no rev128 instruction, so we use rev64 and ext to simulate it. 
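// (Annotation, not part of the patch.) The rev64 + ext #8 pair below is the
// standard AArch64 idiom for reversing all 16 bytes of a q register; with
// intrinsics the same effect would be, roughly:
//     uint8x16_t rev128_u8(uint8x16_t v) {
//         uint8x16_t h = vrev64q_u8(v);   // reverse bytes within each 64-bit half
//         return vextq_u8(h, h, 8);       // then swap the two halves
//     }
// Reversing each row of the transposed block is what turns the transpose into a
// clockwise 90-degree rotation before the stores.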
+ "rev64 v0.16b, v0.16b \n" + "rev64 v1.16b, v1.16b \n" + "rev64 v2.16b, v2.16b \n" + "rev64 v3.16b, v3.16b \n" + "rev64 v4.16b, v4.16b \n" + "rev64 v5.16b, v5.16b \n" + "rev64 v6.16b, v6.16b \n" + "rev64 v7.16b, v7.16b \n" + "rev64 v8.16b, v8.16b \n" + "rev64 v9.16b, v9.16b \n" + "rev64 v10.16b, v10.16b \n" + "rev64 v11.16b, v11.16b \n" + "rev64 v12.16b, v12.16b \n" + "rev64 v13.16b, v13.16b \n" + "rev64 v14.16b, v14.16b \n" + "rev64 v15.16b, v15.16b \n" + "ext v0.16b, v0.16b, v0.16b, #8 \n" + "ext v1.16b, v1.16b, v1.16b, #8 \n" + "ext v2.16b, v2.16b, v2.16b, #8 \n" + "ext v3.16b, v3.16b, v3.16b, #8 \n" + "ext v4.16b, v4.16b, v4.16b, #8 \n" + "ext v5.16b, v5.16b, v5.16b, #8 \n" + "ext v6.16b, v6.16b, v6.16b, #8 \n" + "ext v7.16b, v7.16b, v7.16b, #8 \n" + "ext v8.16b, v8.16b, v8.16b, #8 \n" + "ext v9.16b, v9.16b, v9.16b, #8 \n" + "ext v10.16b, v10.16b, v10.16b, #8 \n" + "ext v11.16b, v11.16b, v11.16b, #8 \n" + "ext v12.16b, v12.16b, v12.16b, #8 \n" + "ext v13.16b, v13.16b, v13.16b, #8 \n" + "ext v14.16b, v14.16b, v14.16b, #8 \n" + "ext v15.16b, v15.16b, v15.16b, #8 \n" + + "st1 {v0.16b}, [%[dst]], %[dst_step] \n" + "st1 {v1.16b}, [%[dst]], %[dst_step] \n" + "st1 {v2.16b}, [%[dst]], %[dst_step] \n" + "st1 {v3.16b}, [%[dst]], %[dst_step] \n" + "st1 {v4.16b}, [%[dst]], %[dst_step] \n" + "st1 {v5.16b}, [%[dst]], %[dst_step] \n" + "st1 {v6.16b}, [%[dst]], %[dst_step] \n" + "st1 {v7.16b}, [%[dst]], %[dst_step] \n" + "st1 {v8.16b}, [%[dst]], %[dst_step] \n" + "st1 {v9.16b}, [%[dst]], %[dst_step] \n" + "st1 {v10.16b}, [%[dst]], %[dst_step] \n" + "st1 {v11.16b}, [%[dst]], %[dst_step] \n" + "st1 {v12.16b}, [%[dst]], %[dst_step] \n" + "st1 {v13.16b}, [%[dst]], %[dst_step] \n" + "st1 {v14.16b}, [%[dst]], %[dst_step] \n" + "st1 {v15.16b}, [%[dst]], %[dst_step] \n" + : + [src] "+r" (src), + [dst] "+r" (dst) + : + [src_step] "r" (src_step), + [dst_step] "r" (dst_step) + : + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", + "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + ); +} + +void rotate_8uc1_counterclockwise_16x16(const uchar *src, + uchar *dst, + size_t src_step, size_t dst_step) +{ + asm volatile ("\n" + "ld1 {v0.16b}, [%[src]], %[src_step] \n" + "ld1 {v1.16b}, [%[src]], %[src_step] \n" + "ld1 {v2.16b}, [%[src]], %[src_step] \n" + "ld1 {v3.16b}, [%[src]], %[src_step] \n" + "ld1 {v4.16b}, [%[src]], %[src_step] \n" + "ld1 {v5.16b}, [%[src]], %[src_step] \n" + "ld1 {v6.16b}, [%[src]], %[src_step] \n" + "ld1 {v7.16b}, [%[src]], %[src_step] \n" + "ld1 {v8.16b}, [%[src]], %[src_step] \n" + "ld1 {v9.16b}, [%[src]], %[src_step] \n" + "ld1 {v10.16b}, [%[src]], %[src_step] \n" + "ld1 {v11.16b}, [%[src]], %[src_step] \n" + "ld1 {v12.16b}, [%[src]], %[src_step] \n" + "ld1 {v13.16b}, [%[src]], %[src_step] \n" + "ld1 {v14.16b}, [%[src]], %[src_step] \n" + "ld1 {v15.16b}, [%[src]], %[src_step] \n" + + "trn1 v16.16b, v0.16b, v1.16b \n" + "trn2 v17.16b, v0.16b, v1.16b \n" + "trn1 v18.16b, v2.16b, v3.16b \n" + "trn2 v19.16b, v2.16b, v3.16b \n" + "trn1 v20.16b, v4.16b, v5.16b \n" + "trn2 v21.16b, v4.16b, v5.16b \n" + "trn1 v22.16b, v6.16b, v7.16b \n" + "trn2 v23.16b, v6.16b, v7.16b \n" + "trn1 v24.16b, v8.16b, v9.16b \n" + "trn2 v25.16b, v8.16b, v9.16b \n" + "trn1 v26.16b, v10.16b, v11.16b \n" + "trn2 v27.16b, v10.16b, v11.16b \n" + "trn1 v28.16b, v12.16b, v13.16b \n" + "trn2 v29.16b, v12.16b, v13.16b \n" + "trn1 v30.16b, v14.16b, v15.16b \n" + "trn2 v31.16b, v14.16b, v15.16b \n" + + "trn1 v0.8h, v16.8h, v18.8h \n" + "trn2 v2.8h, v16.8h, v18.8h \n" + "trn1 v4.8h, v20.8h, v22.8h \n" + "trn2 v6.8h, v20.8h, 
v22.8h \n" + "trn1 v8.8h, v24.8h, v26.8h \n" + "trn2 v10.8h, v24.8h, v26.8h \n" + "trn1 v12.8h, v28.8h, v30.8h \n" + "trn2 v14.8h, v28.8h, v30.8h \n" + "trn1 v1.8h, v17.8h, v19.8h \n" + "trn2 v3.8h, v17.8h, v19.8h \n" + "trn1 v5.8h, v21.8h, v23.8h \n" + "trn2 v7.8h, v21.8h, v23.8h \n" + "trn1 v9.8h, v25.8h, v27.8h \n" + "trn2 v11.8h, v25.8h, v27.8h \n" + "trn1 v13.8h, v29.8h, v31.8h \n" + "trn2 v15.8h, v29.8h, v31.8h \n" + + "trn1 v16.4s, v0.4s, v4.4s \n" + "trn2 v20.4s, v0.4s, v4.4s \n" + "trn1 v24.4s, v8.4s, v12.4s \n" + "trn2 v28.4s, v8.4s, v12.4s \n" + "trn1 v17.4s, v1.4s, v5.4s \n" + "trn2 v21.4s, v1.4s, v5.4s \n" + "trn1 v25.4s, v9.4s, v13.4s \n" + "trn2 v29.4s, v9.4s, v13.4s \n" + "trn1 v18.4s, v2.4s, v6.4s \n" + "trn2 v22.4s, v2.4s, v6.4s \n" + "trn1 v26.4s, v10.4s, v14.4s \n" + "trn2 v30.4s, v10.4s, v14.4s \n" + "trn1 v19.4s, v3.4s, v7.4s \n" + "trn2 v23.4s, v3.4s, v7.4s \n" + "trn1 v27.4s, v11.4s, v15.4s \n" + "trn2 v31.4s, v11.4s, v15.4s \n" + + "trn1 v0.2d, v16.2d, v24.2d \n" + "trn2 v8.2d, v16.2d, v24.2d \n" + "trn1 v1.2d, v17.2d, v25.2d \n" + "trn2 v9.2d, v17.2d, v25.2d \n" + "trn1 v2.2d, v18.2d, v26.2d \n" + "trn2 v10.2d, v18.2d, v26.2d \n" + "trn1 v3.2d, v19.2d, v27.2d \n" + "trn2 v11.2d, v19.2d, v27.2d \n" + "trn1 v4.2d, v20.2d, v28.2d \n" + "trn2 v12.2d, v20.2d, v28.2d \n" + "trn1 v5.2d, v21.2d, v29.2d \n" + "trn2 v13.2d, v21.2d, v29.2d \n" + "trn1 v6.2d, v22.2d, v30.2d \n" + "trn2 v14.2d, v22.2d, v30.2d \n" + "trn1 v7.2d, v23.2d, v31.2d \n" + "trn2 v15.2d, v23.2d, v31.2d \n" + + "st1 {v15.16b}, [%[dst]], %[dst_step] \n" + "st1 {v14.16b}, [%[dst]], %[dst_step] \n" + "st1 {v13.16b}, [%[dst]], %[dst_step] \n" + "st1 {v12.16b}, [%[dst]], %[dst_step] \n" + "st1 {v11.16b}, [%[dst]], %[dst_step] \n" + "st1 {v10.16b}, [%[dst]], %[dst_step] \n" + "st1 {v9.16b}, [%[dst]], %[dst_step] \n" + "st1 {v8.16b}, [%[dst]], %[dst_step] \n" + "st1 {v7.16b}, [%[dst]], %[dst_step] \n" + "st1 {v6.16b}, [%[dst]], %[dst_step] \n" + "st1 {v5.16b}, [%[dst]], %[dst_step] \n" + "st1 {v4.16b}, [%[dst]], %[dst_step] \n" + "st1 {v3.16b}, [%[dst]], %[dst_step] \n" + "st1 {v2.16b}, [%[dst]], %[dst_step] \n" + "st1 {v1.16b}, [%[dst]], %[dst_step] \n" + "st1 {v0.16b}, [%[dst]], %[dst_step] \n" + : + [src] "+r" (src), + [dst] "+r" (dst) + : + [src_step] "r" (src_step), + [dst_step] "r" (dst_step) + : + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", + "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + ); +} + +void rotate_8uc1_clockwise(const uchar* src, uchar* dst, const size_t rows, + const size_t cols, const size_t src_step, + const size_t dst_step) { + const size_t block = 16; + (void)block; + size_t i = 0; + + for (; i + block <= rows; i += block) { + size_t j = 0; + for (; j + block <= cols; j += block) { + rotate_8uc1_clockwise_16x16( + src + i * src_step + j, + dst + j * dst_step + (rows - (i + block)), src_step, + dst_step); + } + for (; j < cols; ++j) { + for (size_t k = 0; k < block; ++k) { + dst[j * dst_step + (rows - 1 - (i + k))] = + src[(i + k) * src_step + j]; + } + } + } + + for (; i < rows; ++i) { + for (size_t j = 0; j < cols; ++j) { + dst[j * dst_step + (rows - 1 - i)] = src[i * src_step + j]; + } + } +} + +void rotate_8uc1_counterclockwise(const uchar* src, uchar* dst, + const size_t rows, const size_t cols, + const size_t src_step, + const size_t dst_step) { + const size_t block = 16; + (void)block; + size_t i = 0; + + for (; i + block <= rows; i += block) { + size_t j = 0; + for (; j + block <= cols; j += block) { + rotate_8uc1_counterclockwise_16x16( + src + i * src_step + j, + 
dst + (cols - (j + block)) * dst_step + i, src_step, + dst_step); + } + for (; j < cols; ++j) { + for (size_t k = 0; k < block; ++k) { + dst[(cols - 1 - j) * dst_step + (i + k)] = + src[(i + k) * src_step + j]; + } + } + } + + for (; i < rows; ++i) { + for (size_t j = 0; j < cols; ++j) { + dst[(cols - 1 - j) * dst_step + i] = src[i * src_step + j]; + } + } +} + +void rotate(const Mat& src, Mat& dst, bool clockwise) { + megdnn_assert(src.rows() == dst.cols()); + megdnn_assert(src.cols() == dst.rows()); + megdnn_assert(src.channels() == dst.channels()); + megdnn_assert(src.channels() == 1_z); + if (clockwise) { + rotate_8uc1_clockwise(src.ptr(), dst.ptr(), src.rows(), src.cols(), + src.step(), dst.step()); + } else { + rotate_8uc1_counterclockwise(src.ptr(), dst.ptr(), src.rows(), + src.cols(), src.step(), dst.step()); + } +} + +} // namespace megcv + +namespace aarch64 { + +void RotateImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, + _megdnn_workspace workspace) { + using namespace megcv; + check_exec(src.layout, dst.layout, workspace.size); + + //! rotate only support data type is uchar and the channel size is 1 + if (dst.layout.dtype != dtype::Uint8() || src.layout.shape[3] != 1) { + return fallback::RotateImpl::exec(src, dst, workspace); + } + + MEGDNN_DISPATCH_CPU_KERN_OPR({ + for (size_t i = 0; i < src.layout.shape[0]; ++i) { + Mat src_mat = TensorND2Mat(src, i); + Mat dst_mat = TensorND2Mat(dst, i); + rotate(src_mat, dst_mat, param().clockwise); + } + }); +} + +} // namespace aarch64 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/rotate/opr_impl.h b/dnn/src/aarch64/rotate/opr_impl.h new file mode 100644 index 00000000..099c0035 --- /dev/null +++ b/dnn/src/aarch64/rotate/opr_impl.h @@ -0,0 +1,35 @@ +/** + * \file dnn/src/aarch64/rotate/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once + +#include "megdnn/oprs.h" +#include "src/fallback/rotate/opr_impl.h" + +namespace megdnn { +namespace aarch64 { + +class RotateImpl : public fallback::RotateImpl { + public: + using fallback::RotateImpl::RotateImpl; + + void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; + + size_t get_workspace_in_bytes(const TensorLayout&, + const TensorLayout&) override { + return 0; + } +}; + +} // namespace aarch64 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/warp_perspective/opr_impl.cpp b/dnn/src/aarch64/warp_perspective/opr_impl.cpp new file mode 100644 index 00000000..5616462a --- /dev/null +++ b/dnn/src/aarch64/warp_perspective/opr_impl.cpp @@ -0,0 +1,43 @@ +/** + * \file dnn/src/aarch64/warp_perspective/opr_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
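As a reference for the blocked rotate kernels above (illustration only, not part of this patch): a clockwise 90-degree rotation maps src(r, c) to dst(c, rows - 1 - r), and the counter-clockwise one maps it to dst(cols - 1 - c, r), which is exactly what the scalar tail loops compute. A minimal scalar sketch of the whole operation:

#include <cstddef>
#include <cstdint>

static inline void rotate_u8_ref(const uint8_t* src, uint8_t* dst, size_t rows,
                                 size_t cols, size_t src_step, size_t dst_step,
                                 bool clockwise) {
    for (size_t r = 0; r < rows; ++r)
        for (size_t c = 0; c < cols; ++c) {
            if (clockwise)
                dst[c * dst_step + (rows - 1 - r)] = src[r * src_step + c];
            else
                dst[(cols - 1 - c) * dst_step + r] = src[r * src_step + c];
        }
}

The NEON 16x16 paths only accelerate the interior blocks; block edges and the ragged remainder fall back to this element-wise form.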
+ */ +#include "src/aarch64/warp_perspective/opr_impl.h" + +#include "src/aarch64/warp_perspective/warp_perspective_cv.h" + +#include "src/common/utils.h" +#include "src/common/warp_common.h" +#include "src/naive/handle.h" + +namespace megdnn { +namespace aarch64 { + +void WarpPerspectiveImpl::exec(_megdnn_tensor_in src, + _megdnn_tensor_in mat, + _megdnn_tensor_in mat_idx, + _megdnn_tensor_in dst, + _megdnn_workspace workspace) +{ + check_exec(src.layout, mat.layout, mat_idx.layout, dst.layout, + workspace.size); + if (warp::is_cv_available(src.layout, mat.layout, dst.layout, param().imode, + param().format) && !mat_idx.layout.ndim) { + warp_perspective_cv_exec(src, mat, dst, param().border_val, + param().bmode, param().imode, handle()); + } else { + //! Use arm_common implementation + arm_common::WarpPerspectiveImpl::exec(src, mat, mat_idx, dst, workspace); + } +} + +} // namespace aarch64 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/warp_perspective/opr_impl.h b/dnn/src/aarch64/warp_perspective/opr_impl.h new file mode 100644 index 00000000..c5fd11ce --- /dev/null +++ b/dnn/src/aarch64/warp_perspective/opr_impl.h @@ -0,0 +1,30 @@ +/** + * \file dnn/src/aarch64/warp_perspective/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once +#include "megdnn/oprs.h" +#include "src/arm_common/warp_perspective/opr_impl.h" + +namespace megdnn { +namespace aarch64 { + +class WarpPerspectiveImpl : public arm_common::WarpPerspectiveImpl { +public: + using arm_common::WarpPerspectiveImpl::WarpPerspectiveImpl; + + void exec(_megdnn_tensor_in src, _megdnn_tensor_in mat, + _megdnn_tensor_in mat_idx, _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; +}; + +} // namespace aarch64 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/warp_perspective/warp_perspective_cv.cpp b/dnn/src/aarch64/warp_perspective/warp_perspective_cv.cpp new file mode 100644 index 00000000..c20f3054 --- /dev/null +++ b/dnn/src/aarch64/warp_perspective/warp_perspective_cv.cpp @@ -0,0 +1,257 @@ +/** + * \file dnn/src/aarch64/warp_perspective/warp_perspective_cv.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +#include "src/aarch64/handle.h" +#include "src/aarch64/warp_perspective/warp_perspective_cv.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/cv/common.h" +#include "src/common/cv/helper.h" +#include "src/common/cv/interp_helper.h" +#include "src/common/utils.h" +#include "src/common/warp_common.h" + +using namespace megdnn; +using namespace aarch64; +using namespace megcv; +using namespace warp; +namespace { + +constexpr size_t BLOCK_SZ = 32u; +template +void warp_perspective_cv(const Mat& src, Mat& dst, const float* trans, + const float border_value, size_t task_id) { + // no extra padding + double M[9]; + rep(i, 9) M[i] = trans[i]; + T bvalue[3] = {(T)border_value, (T)border_value, (T)border_value}; + + size_t x1, y1, width = dst.cols(), height = dst.rows(); + size_t BLOCK_SZ_H = std::min(BLOCK_SZ / 2, height); + size_t BLOCK_SZ_W = std::min(BLOCK_SZ * BLOCK_SZ / BLOCK_SZ_H, width); + BLOCK_SZ_H = std::min(BLOCK_SZ * BLOCK_SZ / BLOCK_SZ_W, height); + + size_t width_block_size = div_ceil(width, BLOCK_SZ_W); + size_t y = (task_id / width_block_size) * BLOCK_SZ_H; + size_t x = (task_id % width_block_size) * BLOCK_SZ_W; + // start invoke + short XY[BLOCK_SZ * BLOCK_SZ * 2], A[BLOCK_SZ * BLOCK_SZ]; + + float64x2_t vM6 = vdupq_n_f64(M[6]); + float64x2_t vM0 = vdupq_n_f64(M[0]); + float64x2_t vM3 = vdupq_n_f64(M[3]); + float64x2_t v2M6 = vdupq_n_f64(M[6] * 2); + float64x2_t v2M0 = vdupq_n_f64(M[0] * 2); + float64x2_t v2M3 = vdupq_n_f64(M[3] * 2); + float64x2_t v4f = vdupq_n_f64(4); + float64x2_t v1f = vdupq_n_f64(1); + float64x2_t v0f = vdupq_n_f64(0); + float64x2_t vTABLE_SIZE = vdupq_n_f64(INTER_TAB_SIZE); + float64x2_t vmin = vdupq_n_f64((double)INT_MIN); + float64x2_t vmax = vdupq_n_f64((double)INT_MAX); + int32x4_t vtabmask = vdupq_n_s32(INTER_TAB_SIZE - 1); + + size_t bw = std::min(BLOCK_SZ_W, width - x); + size_t bh = std::min(BLOCK_SZ_H, height - y); // height + Mat _XY(bh, bw, 2, XY); + Mat dpart(dst, y, bh, x, bw); + + for (y1 = 0; y1 < bh; y1++) { + short* xy = XY + y1 * bw * 2; + double X0 = M[0] * x + M[1] * (y + y1) + M[2]; + double Y0 = M[3] * x + M[4] * (y + y1) + M[5]; + double W0 = M[6] * x + M[7] * (y + y1) + M[8]; + float64x2_t vW0 = vdupq_n_f64(W0); + float64x2_t vidx = {0.f, 1.f}; + float64x2_t vX0 = vdupq_n_f64(X0); + float64x2_t vY0 = vdupq_n_f64(Y0); + if (imode == IMode::NEAREST) { + for (x1 = 0; x1 + 4 <= bw; x1 += 4) { + float64x2_t vw0 = vaddq_f64(vW0, vmulq_f64(vM6, vidx)); + float64x2_t vw1 = vaddq_f64(vw0, v2M6); + + vw0 = vbitq_f64(vdivq_f64(v1f, vw0), v0f, vceqq_f64(vw0, v0f)); + vw1 = vbitq_f64(vdivq_f64(v1f, vw1), v0f, vceqq_f64(vw1, v0f)); + + float64x2_t vtmp0 = vmlaq_f64(vX0, vM0, vidx); + float64x2_t vtmp1 = vaddq_f64(vtmp0, v2M0); + float64x2_t vfx0 = vmulq_f64(vtmp0, vw0); + float64x2_t vfx1 = vmulq_f64(vtmp1, vw1); + vfx0 = vmaxq_f64(vminq_f64(vfx0, vmax), vmin); + vfx1 = vmaxq_f64(vminq_f64(vfx1, vmax), vmin); + + vtmp0 = vmlaq_f64(vY0, vM3, vidx); + vtmp1 = vaddq_f64(vtmp0, v2M3); + float64x2_t vfy0 = vmulq_f64(vtmp0, vw0); + float64x2_t vfy1 = vmulq_f64(vtmp1, vw1); + vfy0 = vmaxq_f64(vminq_f64(vfy0, vmax), vmin); + vfy1 = vmaxq_f64(vminq_f64(vfy1, vmax), vmin); + + int32x2_t vx0 = vqmovn_s64(vcvtaq_s64_f64(vfx0)); + int32x2_t vx1 = vqmovn_s64(vcvtaq_s64_f64(vfx1)); + int32x2_t vy0 = vqmovn_s64(vcvtaq_s64_f64(vfy0)); + int32x2_t vy1 = vqmovn_s64(vcvtaq_s64_f64(vfy1)); + + int32x4_t vx = vcombine_s32(vx0, vx1); + int32x4_t vy = vcombine_s32(vy0, vy1); + + int16x4x2_t ret = {{vqmovn_s32(vx), vqmovn_s32(vy)}}; + 
vst2_s16(xy + x1 * 2, ret); + + vidx = vaddq_f64(vidx, v4f); + } + + for (; x1 < bw; x1++) { + double W = W0 + M[6] * x1; + W = W ? 1. / W : 0; + double fX = std::max( + (double)INT_MIN, + std::min((double)INT_MAX, (X0 + M[0] * x1) * W)); + double fY = std::max( + (double)INT_MIN, + std::min((double)INT_MAX, (Y0 + M[3] * x1) * W)); + int X = saturate_cast(fX); + int Y = saturate_cast(fY); + xy[x1 * 2] = saturate_cast(X); + xy[x1 * 2 + 1] = saturate_cast(Y); + } + } else { + short* alpha = A + y1 * bw; + for (x1 = 0; x1 + 4 <= bw; x1 += 4) { + float64x2_t vw0 = vaddq_f64(vW0, vmulq_f64(vM6, vidx)); + float64x2_t vw1 = vaddq_f64(vw0, v2M6); + + vw0 = vbitq_f64(vdivq_f64(vTABLE_SIZE, vw0), v0f, + vceqq_f64(vw0, v0f)); + vw1 = vbitq_f64(vdivq_f64(vTABLE_SIZE, vw1), v0f, + vceqq_f64(vw1, v0f)); + + float64x2_t vtmp0 = vmlaq_f64(vX0, vM0, vidx); + float64x2_t vtmp1 = vaddq_f64(vtmp0, v2M0); + float64x2_t vfx0 = vmulq_f64(vtmp0, vw0); + float64x2_t vfx1 = vmulq_f64(vtmp1, vw1); + vfx0 = vmaxq_f64(vminq_f64(vfx0, vmax), vmin); + vfx1 = vmaxq_f64(vminq_f64(vfx1, vmax), vmin); + + vtmp0 = vmlaq_f64(vY0, vM3, vidx); + vtmp1 = vaddq_f64(vtmp0, v2M3); + float64x2_t vfy0 = vmulq_f64(vtmp0, vw0); + float64x2_t vfy1 = vmulq_f64(vtmp1, vw1); + vfy0 = vmaxq_f64(vminq_f64(vfy0, vmax), vmin); + vfy1 = vmaxq_f64(vminq_f64(vfy1, vmax), vmin); + + int32x2_t vx0 = vqmovn_s64(vcvtaq_s64_f64(vfx0)); + int32x2_t vx1 = vqmovn_s64(vcvtaq_s64_f64(vfx1)); + int32x2_t vy0 = vqmovn_s64(vcvtaq_s64_f64(vfy0)); + int32x2_t vy1 = vqmovn_s64(vcvtaq_s64_f64(vfy1)); + + int32x4_t vx = vcombine_s32(vx0, vx1); + int32x4_t vy = vcombine_s32(vy0, vy1); + + int16x4x2_t ret = {{vqshrn_n_s32(vx, INTER_BITS), + vqshrn_n_s32(vy, INTER_BITS)}}; + vst2_s16(xy + x1 * 2, ret); + + vidx = vaddq_f64(vidx, v4f); + + vx = vandq_s32(vx, vtabmask); + vy = vandq_s32(vy, vtabmask); + + vst1_s16(&alpha[x1], + vqmovn_s32(vmlaq_n_s32(vx, vy, INTER_TAB_SIZE))); + } + for (; x1 < bw; x1++) { + double W = W0 + M[6] * x1; + W = W ? 
INTER_TAB_SIZE / W : 0; + double fX = std::max( + (double)INT_MIN, + std::min((double)INT_MAX, (X0 + M[0] * x1) * W)); + double fY = std::max( + (double)INT_MIN, + std::min((double)INT_MAX, (Y0 + M[3] * x1) * W)); + int X = saturate_cast(fX); + int Y = saturate_cast(fY); + xy[x1 * 2] = saturate_cast(X >> INTER_BITS); + xy[x1 * 2 + 1] = saturate_cast(Y >> INTER_BITS); + alpha[x1] = + (short)((Y & (INTER_TAB_SIZE - 1)) * INTER_TAB_SIZE + + (X & (INTER_TAB_SIZE - 1))); + } + } + } + Mat _matA(bh, bw, 1, (ushort*)(A)); + remap>(src, dpart, _XY, _matA, bvalue); +} +} // anonymous namespace +void megdnn::aarch64::warp_perspective_cv_exec( + _megdnn_tensor_in src, _megdnn_tensor_in trans, _megdnn_tensor_in dst, + float border_value, BorderMode bmode, InterpolationMode imode, + Handle* handle) { + size_t ch = dst.layout[3]; + size_t width = dst.layout[2]; + size_t height = dst.layout[1]; + const size_t batch = dst.layout.shape[0]; + + size_t BLOCK_SZ_H = std::min(BLOCK_SZ / 2, height); + size_t BLOCK_SZ_W = std::min(BLOCK_SZ * BLOCK_SZ / BLOCK_SZ_H, width); + BLOCK_SZ_H = std::min(BLOCK_SZ * BLOCK_SZ / BLOCK_SZ_W, height); + + size_t parallelism_batch = div_ceil(height, BLOCK_SZ_H) * + div_ceil(width, BLOCK_SZ_W); + megdnn_assert(ch == 1 || ch == 3 || ch == 2, + "unsupported src channel: %zu, avaiable channel size: 1/2/3", + ch); + const float* trans_ptr = trans.ptr(); + if (dst.layout.dtype.enumv() == DTypeEnum::Float32) { +#define cb(_imode, _bmode, _ch) \ + auto task = [src, trans_ptr, dst, border_value, parallelism_batch]( \ + size_t index, size_t) { \ + size_t batch_id = index / parallelism_batch; \ + size_t task_id = index % parallelism_batch; \ + Mat src_mat = TensorND2Mat(src, batch_id); \ + Mat dst_mat = TensorND2Mat(dst, batch_id); \ + const float* task_trans_ptr = trans_ptr + batch_id * 3 * 3; \ + warp_perspective_cv( \ + src_mat MEGDNN_COMMA const_cast&>(dst_mat) \ + MEGDNN_COMMA task_trans_ptr MEGDNN_COMMA border_value, \ + task_id); \ + }; \ + MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ + static_cast(handle), batch* parallelism_batch, \ + task); + DISPATCH_IMODE(imode, bmode, ch, cb) +#undef cb + } else if (dst.layout.dtype.enumv() == DTypeEnum::Uint8) { +#define cb(_imode, _bmode, _ch) \ + auto task = [src, trans_ptr, dst, border_value, parallelism_batch]( \ + size_t index, size_t) { \ + size_t batch_id = index / parallelism_batch; \ + size_t task_id = index % parallelism_batch; \ + Mat src_mat = TensorND2Mat(src, batch_id); \ + Mat dst_mat = TensorND2Mat(dst, batch_id); \ + const float* task_trans_ptr = trans_ptr + batch_id * 3 * 3; \ + warp_perspective_cv( \ + src_mat MEGDNN_COMMA const_cast&>(dst_mat) \ + MEGDNN_COMMA task_trans_ptr MEGDNN_COMMA border_value, \ + task_id); \ + }; \ + MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ + static_cast(handle), batch* parallelism_batch, \ + task); + DISPATCH_IMODE(imode, bmode, ch, cb) +#undef cb + } else { + megdnn_throw( + megdnn_mangle("Unsupported datatype of WarpAffine optr.")); + } + +} +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/warp_perspective/warp_perspective_cv.h b/dnn/src/aarch64/warp_perspective/warp_perspective_cv.h new file mode 100644 index 00000000..62fe31c4 --- /dev/null +++ b/dnn/src/aarch64/warp_perspective/warp_perspective_cv.h @@ -0,0 +1,32 @@ +/** + * \file dnn/src/aarch64/warp_perspective/warp_perspective_cv.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. 
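The bilinear branch above packs each mapped coordinate into an integer pixel index plus a sub-pixel index into the interpolation table. A self-contained sketch of that split (INTER_BITS = 5 is an assumption for illustration; the real constants come from src/common/cv/interp_helper.h):

constexpr int INTER_BITS_EX = 5;                       // assumed value
constexpr int INTER_TAB_SIZE_EX = 1 << INTER_BITS_EX;  // 32

inline void split_coord(int X, int Y, int& xi, int& yi, short& alpha) {
    xi = X >> INTER_BITS_EX;                           // integer source column
    yi = Y >> INTER_BITS_EX;                           // integer source row
    alpha = (short)((Y & (INTER_TAB_SIZE_EX - 1)) * INTER_TAB_SIZE_EX +
                    (X & (INTER_TAB_SIZE_EX - 1)));    // interpolation-table index
}

remap then uses the XY map for addressing and alpha for the weights, while warp_perspective_cv_exec dispatches one task per (batch, output tile) pair across the thread pool.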
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include + +#include "src/common/cv/helper.h" + +namespace megdnn { +namespace aarch64 { + +/** + * \fn warp_perspective_cv + * \brief Used if the format is NHWC, transfer from megcv + */ +void warp_perspective_cv_exec(_megdnn_tensor_in src, _megdnn_tensor_in trans, + _megdnn_tensor_in dst, float border_value, + param::WarpPerspective::BorderMode border_mode, + param::WarpPerspective::InterpolationMode imode, + Handle* handle); + +} // namespace aarch64 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/direct/multi_thread_common.cpp b/dnn/src/arm_common/conv_bias/direct/multi_thread_common.cpp new file mode 100644 index 00000000..e833b7fb --- /dev/null +++ b/dnn/src/arm_common/conv_bias/direct/multi_thread_common.cpp @@ -0,0 +1,380 @@ +/** + * \file dnn/src/arm_common/conv_bias/direct/multi_thread_common.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "src/arm_common/conv_bias/direct/multi_thread_common.h" +#include "src/arm_common/conv_bias/postprocess_helper.h" +#include "src/fallback/matrix_mul/opr_impl.h" + +using namespace megdnn; +using namespace arm_common; + +namespace { +bool need_dst_copy( + const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { + auto align = param.src_type.enumv() == DTypeEnum::Float32 ? 4 : 8; + return param.osz[1] % align; +} +bool need_src_copy( + const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { + if (param.filter_meta.padding[0] || param.filter_meta.padding[1]) { + return true; + } + return need_dst_copy(param); +} +void get_rectified_size( + const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, + size_t IH, size_t IW, size_t OH, size_t OW, size_t FH, size_t FW, + size_t PH, size_t PW, size_t& IH2, size_t& IW2, size_t& OW2) { + MEGDNN_MARK_USED_VAR(PW); + MEGDNN_MARK_USED_VAR(PH); + auto&& fm = param.filter_meta; + auto SW = fm.stride[1]; + + auto Align = param.src_type.enumv() == DTypeEnum::Float32 ? 3 : 7; + OW2 = (OW + Align) & ~Align; + IH2 = SW * OH + FH - SW; + IW2 = SW * OW2 + FW - SW; + // Because stride is 2, sometimes IW == IW2+1. Do a max update to + // handle this case. + IH2 = std::max(IH2, IH); + IW2 = std::max(IW2, IW); +} + +} // namespace + +template +WorkspaceBundle MultithreadDirectConvCommon::get_bundle( + const ConvBiasImpl::NCBKernSizeParam& param, bool m_large_group) { + auto&& fm = param.filter_meta; + size_t nr_threads = param.nr_threads; + size_t group = fm.group, batch = param.n; + size_t IH2 = param.isz[0] + 2 * fm.padding[0]; + size_t IW2 = param.isz[1] + 2 * fm.padding[1]; + // part0: copied src + // part1: copied filter + size_t part0, part1; + if (fm.padding[0] == 0 && fm.padding[1] == 0) { + //! only the last plane need to be copied, add 16 Byte extra space in + //! case of invalid read and write + part0 = (param.isz[0] * param.isz[1]) * sizeof(io_ctype) + 16; + } else if (m_large_group) { + //! 
Serial in group, each thread process one group, parallel by group + part0 = (IH2 * IW2 * fm.icpg * nr_threads) * sizeof(io_ctype) + 16; + } else { + //! Parallel in group, Then should copy every inputs to workspace + part0 = (IH2 * IW2 * fm.icpg * group * batch) * sizeof(io_ctype) + 16; + } + if (param.filter_meta.should_flip) { + if (m_large_group) { + //! Serial in group, each thread has own workspace and then reuse + part1 = fm.spatial[0] * fm.spatial[1] * fm.ocpg * fm.icpg * + nr_threads * sizeof(io_ctype); + } else { + part1 = fm.spatial[0] * fm.spatial[1] * fm.ocpg * fm.icpg * group * + sizeof(io_ctype); + } + } else { + part1 = 0; + } + return {nullptr, {part0, part1}}; +} +template +WorkspaceBundle +MultithreadDirectConvCommon::get_bundle_stride( + const ConvBiasImpl::NCBKernSizeParam& param, bool m_large_group) { + UNPACK_CONV_F32_NCB_KERN_SIZES(param); + MEGDNN_MARK_USED_VAR(N); + MEGDNN_MARK_USED_VAR(OC); + MEGDNN_MARK_USED_VAR(SH); + MEGDNN_MARK_USED_VAR(SW); + auto&& fm = param.filter_meta; + size_t nr_threads = param.nr_threads; + size_t group = fm.group, batch = param.n; + size_t IH2, IW2, OW2; + get_rectified_size(param, IH, IW, OH, OW, FH, FW, PH, PW, IH2, IW2, OW2); + + size_t src_size = 0, dst_size = 0; + // src_size: copied src + // dst_size: copied dst + if (need_src_copy(param)) { + src_size = m_large_group + ? IC * IH2 * IW2 * sizeof(io_ctype) * nr_threads + : IC * IH2 * IW2 * sizeof(io_ctype) * group * batch; + }; + if (need_dst_copy(param)) { + //! add 16 Byte extra space in case of invalid read and write + dst_size = OH * OW2 * sizeof(io_ctype) * nr_threads + 16; + } + return {nullptr, {src_size, dst_size}}; +} + +//! Process one output channel weight flip +template +void MultithreadDirectConvCommon::weight_flip_kern( + WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids) { + size_t FH = kern_param.filter_meta.spatial[0]; + size_t FW = kern_param.filter_meta.spatial[1]; + size_t IC = kern_param.filter_meta.icpg; + size_t OC = kern_param.filter_meta.ocpg; + //! Used for get the workspace offset + size_t workspace_group_id = workspace_ids[0], channel_id = workspace_ids[2], + group_id = ncb_index.ndrange_id[0]; + const io_ctype* filter = + kern_param.filter(group_id) + channel_id * FH * FW * IC; + bundle.set(kern_param.workspace_ptr); + io_ctype* filter_flip = + static_cast(bundle.get(1)) + + (workspace_group_id * IC * OC + channel_id * IC) * FH * FW; + rep(ic, IC) { + const io_ctype* filter_plane = filter + ic * FH * FW; + io_ctype* filter_flip_plane = filter_flip + ic * FH * FW; + rep(fh, FH) rep(fw, FW) { + filter_flip_plane[fh * FW + fw] = + filter_plane[(FH - fh - 1) * FW + (FW - fw - 1)]; + } + } +} + +//! Process one input channel copy padding +template +void MultithreadDirectConvCommon::copy_padding_kern( + WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids) { + size_t IH = kern_param.isz[0]; + size_t IW = kern_param.isz[1]; + size_t IC = kern_param.filter_meta.icpg; + size_t PH = kern_param.filter_meta.padding[0]; + size_t PW = kern_param.filter_meta.padding[1]; + size_t IH2 = IH + 2 * PH; + size_t IW2 = IW + 2 * PW; + size_t padding_group_size = IH2 * IW2 * IC; + size_t N = kern_param.n; + size_t GROUP = kern_param.filter_meta.group; + bundle.set(kern_param.workspace_ptr); + + //! 
Used for get the workspace offset + size_t workspace_group_id = workspace_ids[0], + workspace_batch_id = workspace_ids[1], channel_id = workspace_ids[2]; + size_t batch_id = ncb_index.ndrange_id[1], + group_id = ncb_index.ndrange_id[0]; + const io_ctype* sptr = static_cast( + kern_param.src(batch_id, group_id, channel_id)); + if (PH > 0 || PW > 0) { + //! copy to sptr_base to eliminate padding effect + io_ctype* sptr_base = static_cast(bundle.get(0)) + + workspace_group_id * padding_group_size + + workspace_batch_id * GROUP * padding_group_size + + channel_id * IH2 * IW2; + std::memset(sptr_base, 0, sizeof(io_ctype) * IH2 * IW2); + rep(ih, IH) { + std::memcpy(sptr_base + (ih + PH) * IW2 + PW, sptr + ih * IW, + sizeof(io_ctype) * IW); + } + } else if (batch_id + 1 == N && channel_id + 1 == IC && + group_id + 1 == GROUP) { + //! copy last plane + io_ctype* sptr_last_c = static_cast(bundle.get(0)); + std::memcpy(sptr_last_c, sptr, sizeof(io_ctype) * IH2 * IW2); + } +}; +//! Process one input channel copy padding +template +void MultithreadDirectConvCommon:: + copy_padding_kern_stride(WorkspaceBundle bundle, + const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids) { + size_t IH = kern_param.isz[0]; + size_t IW = kern_param.isz[1]; + size_t IC = kern_param.filter_meta.icpg; + size_t PH = kern_param.filter_meta.padding[0]; + size_t PW = kern_param.filter_meta.padding[1]; + size_t FH = kern_param.filter_meta.spatial[0]; + size_t FW = kern_param.filter_meta.spatial[1]; + size_t OW = kern_param.osz[1]; + size_t OH = kern_param.osz[0]; + size_t IH2, IW2, OW2; + size_t GROUP = kern_param.filter_meta.group; + get_rectified_size(kern_param, IH, IW, OH, OW, FH, FW, PH, PW, IH2, IW2, OW2); + size_t padding_group_size = IH2 * IW2 * IC; + bundle.set(kern_param.workspace_ptr); + + //! Used for get the workspace offset + size_t workspace_group_id = workspace_ids[0], + workspace_batch_id = workspace_ids[1]; + size_t channel_id = workspace_ids[2], batch_id = ncb_index.ndrange_id[1], + group_id = ncb_index.ndrange_id[0]; + + const io_ctype* sptr = static_cast( + kern_param.src(batch_id, group_id, channel_id)); + if (need_src_copy(kern_param)) { + //! copy to sptr_base to eliminate padding effect + io_ctype* sptr_base = static_cast(bundle.get(0)) + + workspace_group_id * padding_group_size + + workspace_batch_id * GROUP * padding_group_size + + channel_id * IH2 * IW2; + std::memset(sptr_base, 0, sizeof(io_ctype) * IH2 * IW2); + rep(ih, IH) { + std::memcpy(sptr_base + (ih + PH) * IW2 + PW, sptr + ih * IW, + sizeof(io_ctype) * IW); + } + } +}; + +//! 
compute one output channel +template +void MultithreadDirectConvCommon::do_conv_kern( + WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, + const kern_direct_conv_f32& fun, const CpuNDRange& workspace_ids) { + size_t OH = kern_param.osz[0]; + size_t OW = kern_param.osz[1]; + size_t FH = kern_param.filter_meta.spatial[0]; + size_t FW = kern_param.filter_meta.spatial[1]; + size_t IC = kern_param.filter_meta.icpg; + size_t OC = kern_param.filter_meta.ocpg; + size_t PH = kern_param.filter_meta.padding[0]; + size_t PW = kern_param.filter_meta.padding[1]; + size_t IH2 = kern_param.isz[0] + 2 * PH; + size_t IW2 = kern_param.isz[1] + 2 * PW; + size_t padding_group_size = IH2 * IW2 * IC; + size_t N = kern_param.n; + size_t GROUP = kern_param.filter_meta.group; + bundle.set(kern_param.workspace_ptr); + + size_t group_id = ncb_index.ndrange_id[0], + batch_id = ncb_index.ndrange_id[1]; + size_t channel_id = workspace_ids[2]; + + const io_ctype* sptr = kern_param.src(batch_id, group_id); + const io_ctype* filter = kern_param.filter(group_id); + const io_ctype* bias_ptr = + kern_param.bias(batch_id, group_id, channel_id); + io_ctype* dptr = kern_param.dst(batch_id, group_id, channel_id); + + //! Used for get the workspace offset + size_t workspace_batch_id = workspace_ids[1]; + size_t workspace_group_id = workspace_ids[0]; + + io_ctype* sptr_base; + io_ctype* sptr_last_c; + auto fptr = + kern_param.filter_meta.should_flip + ? static_cast(bundle.get(1)) + + (workspace_group_id * OC * IC + channel_id * IC) * + FH * FW + : filter + channel_id * FH * FW * IC; + if (PH > 0 || PW > 0) { + sptr_base = static_cast(bundle.get(0)) + + workspace_group_id * padding_group_size + + workspace_batch_id * GROUP * padding_group_size; + sptr_last_c = sptr_base + (IC - 1) * IH2 * IW2; + //! Last batch, last group + } else if (batch_id + 1 == N && group_id + 1 == GROUP) { + sptr_base = const_cast(sptr); + sptr_last_c = static_cast(bundle.get(0)); + } else { + sptr_base = const_cast(sptr); + sptr_last_c = sptr_base + (IC - 1) * IH2 * IW2; + } + std::memset(dptr, 0, sizeof(io_ctype) * (OH * OW)); + rep(ic, IC) { + io_ctype* sptr_cur = + (ic + 1 == IC ? sptr_last_c : sptr_base + ic * IH2 * IW2); + fun(reinterpret_cast(sptr_cur), + reinterpret_cast(fptr + ic * FH * FW), + reinterpret_cast(dptr), IH2, IW2, OH, OW, FH, FW); + } + PostProcess::run(dptr, const_cast(bias_ptr), dptr, + kern_param.bias_mode, kern_param.nonlineMode, + kern_param.bias_type, kern_param.dst_type, 1_z, + 1_z, OH, OW); +}; + +//! compute one output channel +template +void MultithreadDirectConvCommon::do_conv_kern_stride( + WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, + const kern_direct_conv_f32_stride& fun, + const CpuNDRange& workspace_ids) { + size_t IH = kern_param.isz[0]; + size_t IW = kern_param.isz[1]; + size_t OH = kern_param.osz[0]; + size_t OW = kern_param.osz[1]; + size_t FH = kern_param.filter_meta.spatial[0]; + size_t FW = kern_param.filter_meta.spatial[1]; + size_t IC = kern_param.filter_meta.icpg; + size_t PH = kern_param.filter_meta.padding[0]; + size_t PW = kern_param.filter_meta.padding[1]; + size_t IH2, IW2, OW2; + get_rectified_size(kern_param, IH, IW, OH, OW, FH, FW, PH, PW, IH2, IW2, OW2); + + size_t padding_group_size = IH2 * IW2 * IC; + size_t GROUP = kern_param.filter_meta.group; + bundle.set(kern_param.workspace_ptr); + + //! 
Used for get the workspace offset + size_t group_id = ncb_index.ndrange_id[0], + batch_id = ncb_index.ndrange_id[1]; + size_t channel_id = workspace_ids[2]; + + const io_ctype* sptr = kern_param.src(batch_id, group_id); + const io_ctype* fptr = + kern_param.filter(group_id) + channel_id * FH * FW * IC; + const io_ctype* bias_ptr = + kern_param.bias(batch_id, group_id, channel_id); + io_ctype* dptr = kern_param.dst(batch_id, group_id, channel_id); + + size_t workspace_batch_id = workspace_ids[1]; + size_t workspace_group_id = workspace_ids[0]; + + io_ctype* sptr_base; + io_ctype* dptr_base; + if (need_src_copy(kern_param)) { + sptr_base = static_cast(bundle.get(0)) + + workspace_group_id * padding_group_size + + workspace_batch_id * GROUP * padding_group_size; + } else { + sptr_base = const_cast(sptr); + } + if (need_dst_copy(kern_param)) { + dptr_base = static_cast(bundle.get(1)) + + ncb_index.thread_id * OH * OW2; + } else { + dptr_base = dptr; + } + if (need_dst_copy(kern_param)) { + std::memset(dptr_base, 0, sizeof(io_ctype) * (OH * OW2)); + fun(reinterpret_cast(sptr_base), + reinterpret_cast(fptr), + reinterpret_cast(dptr_base), IH2, IW2, OH, OW2, IC); + copy_plane_in_bytes(dptr, dptr_base, OH, OW * sizeof(io_ctype), + OW * sizeof(io_ctype), OW2 * sizeof(io_ctype)); + } else { + std::memset(dptr_base, 0, sizeof(io_ctype) * (OH * OW)); + fun(reinterpret_cast(sptr_base), + reinterpret_cast(fptr), + reinterpret_cast(dptr_base), IH2, IW2, OH, OW, IC); + } + PostProcess::run(dptr, const_cast(bias_ptr), dptr, + kern_param.bias_mode, kern_param.nonlineMode, + kern_param.bias_type, kern_param.dst_type, 1_z, + 1_z, OH, OW); +}; +template class megdnn::arm_common::MultithreadDirectConvCommon; +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +template class megdnn::arm_common::MultithreadDirectConvCommon; +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/direct/multi_thread_common.h b/dnn/src/arm_common/conv_bias/direct/multi_thread_common.h new file mode 100644 index 00000000..55c1ea91 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/direct/multi_thread_common.h @@ -0,0 +1,65 @@ +/** + * \file dnn/src/arm_common/conv_bias/direct/multi_thread_common.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
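A worked example of the size rectification used by the stride kernels above (annotation only; it follows need_dst_copy() and get_rectified_size() at the top of the .cpp): for float32 the copied output row is padded to a multiple of 4 (8 for fp16), so with OW == 30

#include <cstddef>

static_assert(((30 + 3) & ~std::size_t(3)) == 32,
              "OW2 rounds 30 up to 32 on the fp32 path");

need_dst_copy() returns true (30 % 4 != 0), the kernel writes an OH x OW2 scratch plane per thread, and copy_plane_in_bytes() trims it back to the real OH x OW destination.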
+ */ +#pragma once + +#include "src/arm_common/conv_bias/opr_impl.h" +#include "src/fallback/matrix_mul/opr_impl.h" + +namespace megdnn { +namespace arm_common { + +template +class MultithreadDirectConvCommon { +public: + using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam; + using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam; + using NCBKernIndex = fallback::ConvBiasImpl::NCBKernIndex; + + using kern_direct_conv_f32 = + std::function; + using kern_direct_conv_f32_stride = std::function; + + static WorkspaceBundle get_bundle(const NCBKernSizeParam& param, + bool m_large_group); + static WorkspaceBundle get_bundle_stride(const NCBKernSizeParam& param, + bool m_large_group); + static void weight_flip_kern(WorkspaceBundle bundle, + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids); + static void copy_padding_kern(WorkspaceBundle bundle, + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids); + static void copy_padding_kern_stride(WorkspaceBundle bundle, + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids); + static void do_conv_kern(WorkspaceBundle bundle, + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, + const kern_direct_conv_f32& fun, + const CpuNDRange& workspace_ids); + static void do_conv_kern_stride(WorkspaceBundle bundle, + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, + const kern_direct_conv_f32_stride& fun, + const CpuNDRange& workspace_ids); +}; + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/f16/algos.cpp b/dnn/src/arm_common/conv_bias/f16/algos.cpp new file mode 100644 index 00000000..c24af258 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/f16/algos.cpp @@ -0,0 +1,561 @@ +/** + * \file dnn/src/arm_common/conv_bias/f16/algos.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
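For orientation in the Winograd algorithms that follow (annotation only; the naming convention is an assumption inferred from how the strategies are used here): winograd_MxR_PxP_f16 computes F(MxM, RxR) with channels packed in groups of P, which is why AlgoFP16WinogradF23 requires icpg % 4 == 0 && ocpg % 4 == 0 while the 8x8 variant requires multiples of 8 together with the MK8 matmul format. The matmul work grows with the number of output tiles, roughly:

#include <cstddef>

// hypothetical helper; m is the output block size of the chosen strategy
static inline std::size_t nr_winograd_tiles(std::size_t OH, std::size_t OW,
                                            std::size_t m) {
    return ((OH + m - 1) / m) * ((OW + m - 1) / m);
}

m_tile_size then presumably bounds how many of these tiles one kernel invocation handles at a time (also an assumption, not stated explicitly in this patch).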
+ */ + +#include "src/arm_common/conv_bias/f16/algos.h" +#include "src/arm_common/conv_bias/direct/multi_thread_common.h" +#include "src/arm_common/conv_bias/f16/direct.h" +#include "src/arm_common/conv_bias/f16/do_conv_stride1.h" +#include "src/arm_common/conv_bias/f16/strategy.h" +#include "src/arm_common/conv_bias/img2col_helper.h" +#include "src/arm_common/conv_bias/postprocess_helper.h" +#include "src/common/opr_delegate.h" +#include "src/fallback/conv_bias/common.h" +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#include "midout.h" +MIDOUT_DECL(megdnn_arm_common_winograd_fp16) +using namespace megdnn; +using namespace arm_common; + +/* ======================= AlgoFP16WinogradF23 ======================== */ + +bool ConvBiasImpl::AlgoFP16WinogradF23::usable( + fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, + AlgoSelectionStrategy /*algo_selection_strategy*/) const { + MEGDNN_MARK_USED_VAR(param); + MEGDNN_MARK_USED_VAR(opr); + MIDOUT_BEGIN(megdnn_arm_common_winograd_fp16, 0, 0) { + using Strategy = winograd::winograd_2x3_4x4_f16; + Strategy strategy(param.src_type, param.filter_type, param.dst_type); + auto&& matmul_param = + megdnn::winograd::ConvBias( + strategy, m_tile_size, param.nr_threads, param.osz[0], + param.osz[1], param.filter_meta.ocpg) + .get_matmul_kern_param(param); + return m_matmul_algo->usable(matmul_param) && + (opr->param().format == param::ConvBias::Format::NCHW || + (opr->param().format == + param::ConvBias::Format::NCHW_WINOGRAD && + opr->param().output_block_size == 2 && + param.winograd_matmul_format == + param::MatrixMul::Format::DEFAULT)) && + opr->param().mode == param::ConvBias::Mode::CROSS_CORRELATION && + (param.filter_meta.spatial[0] == param.filter_meta.spatial[1] && + param.filter_meta.spatial[0] == 3) && + (param.filter_meta.stride[0] == param.filter_meta.stride[1] && + param.filter_meta.stride[0] == 1) && + (param.filter_meta.dilation[0] == + param.filter_meta.dilation[1] && + param.filter_meta.dilation[0] == 1) && + param.compute_mode == param::ConvBias::ComputeMode::DEFAULT && + param.src_type.enumv() == DTypeEnum::Float16 && + param.filter_meta.icpg % 4 == 0 && + param.filter_meta.ocpg % 4 == 0; + } + MIDOUT_END(); + return false; +} + +size_t ConvBiasImpl::AlgoFP16WinogradF23::get_workspace( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + MEGDNN_MARK_USED_VAR(param); + MIDOUT_BEGIN(megdnn_arm_common_winograd_fp16, 0, 1) { + winograd::winograd_2x3_4x4_f16 strategy( + param.src_type, param.filter_type, param.dst_type); + return megdnn::winograd::ConvBias( + strategy, m_tile_size, param.nr_threads, param.osz[0], + param.osz[1], param.filter_meta.ocpg) + .get_workspace_size(param, m_matmul_algo); + } + MIDOUT_END(); + return 0; +} + +SmallVector +ConvBiasImpl::AlgoFP16WinogradF23::dispatch_kerns( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + MEGDNN_MARK_USED_VAR(param); + MIDOUT_BEGIN(megdnn_arm_common_winograd_fp16, 0, 2) { + winograd::winograd_2x3_4x4_f16 strategy( + param.src_type, param.filter_type, param.dst_type); + + auto winograd_impl = + megdnn::winograd::ConvBias( + strategy, m_tile_size, param.nr_threads, param.osz[0], + param.osz[1], param.filter_meta.ocpg); + return winograd_impl.get_kerns(param, m_matmul_algo); + } + MIDOUT_END(); + return {}; +} + +/* ======================= AlgoFP16WinogradF45 ======================== */ + +bool ConvBiasImpl::AlgoFP16WinogradF45::usable( + fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, + AlgoSelectionStrategy /*algo_selection_strategy*/) 
const { + MEGDNN_MARK_USED_VAR(param); + MEGDNN_MARK_USED_VAR(opr); + MIDOUT_BEGIN(megdnn_arm_common_winograd_fp16, 1, 0) { + using Strategy = winograd::winograd_4x5_1x1_f16; + Strategy strategy(param.src_type, param.filter_type, param.dst_type); + auto&& matmul_param = + megdnn::winograd::ConvBias( + strategy, m_tile_size, param.nr_threads, param.osz[0], + param.osz[1], param.filter_meta.ocpg) + .get_matmul_kern_param(param); + return m_matmul_algo->usable(matmul_param) && + (opr->param().format == param::ConvBias::Format::NCHW || + (opr->param().format == + param::ConvBias::Format::NCHW_WINOGRAD && + opr->param().output_block_size == 4 && + param.winograd_matmul_format == + param::MatrixMul::Format::DEFAULT)) && + opr->param().mode == param::ConvBias::Mode::CROSS_CORRELATION && + (param.filter_meta.spatial[0] == param.filter_meta.spatial[1] && + param.filter_meta.spatial[0] == 5) && + (param.filter_meta.stride[0] == param.filter_meta.stride[1] && + param.filter_meta.stride[0] == 1) && + (param.filter_meta.dilation[0] == + param.filter_meta.dilation[1] && + param.filter_meta.dilation[0] == 1) && + param.compute_mode == param::ConvBias::ComputeMode::DEFAULT && + param.src_type.enumv() == DTypeEnum::Float16; + } + MIDOUT_END(); + return false; +} + +size_t ConvBiasImpl::AlgoFP16WinogradF45::get_workspace( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + MEGDNN_MARK_USED_VAR(param); + winograd::winograd_4x5_1x1_f16 strategy(param.src_type, param.filter_type, + param.dst_type); + MIDOUT_BEGIN(megdnn_arm_common_winograd_fp16, 1, 1) { + return megdnn::winograd::ConvBias( + strategy, m_tile_size, param.nr_threads, param.osz[0], + param.osz[1], param.filter_meta.ocpg) + .get_workspace_size(param, m_matmul_algo); + } + MIDOUT_END(); + return 0; +} + +SmallVector +ConvBiasImpl::AlgoFP16WinogradF45::dispatch_kerns( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + MEGDNN_MARK_USED_VAR(param); + MIDOUT_BEGIN(megdnn_arm_common_winograd_fp16, 1, 2) { + winograd::winograd_4x5_1x1_f16 strategy( + param.src_type, param.filter_type, param.dst_type); + auto winograd_impl = + megdnn::winograd::ConvBias( + strategy, m_tile_size, param.nr_threads, param.osz[0], + param.osz[1], param.filter_meta.ocpg); + return winograd_impl.get_kerns(param, m_matmul_algo); + } + MIDOUT_END(); + return {}; +} +/* ======================= AlgoFP16WinogradF63 ======================== */ + +bool ConvBiasImpl::AlgoFP16WinogradF63::usable( + fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, + AlgoSelectionStrategy /*algo_selection_strategy*/) const { + MEGDNN_MARK_USED_VAR(param); + MEGDNN_MARK_USED_VAR(opr); + MIDOUT_BEGIN(megdnn_arm_common_winograd_fp16, 2, 0) { + using Strategy = winograd::winograd_6x3_1x1_f16; + Strategy strategy(param.src_type, param.filter_type, param.dst_type); + auto&& matmul_param = + megdnn::winograd::ConvBias( + strategy, m_tile_size, param.nr_threads, param.osz[0], + param.osz[1], param.filter_meta.ocpg) + .get_matmul_kern_param(param); + return m_matmul_algo->usable(matmul_param) && + (opr->param().format == param::ConvBias::Format::NCHW || + (opr->param().format == + param::ConvBias::Format::NCHW_WINOGRAD && + opr->param().output_block_size == 6 && + param.winograd_matmul_format == + param::MatrixMul::Format::DEFAULT)) && + opr->param().mode == param::ConvBias::Mode::CROSS_CORRELATION && + (param.filter_meta.spatial[0] == param.filter_meta.spatial[1] && + param.filter_meta.spatial[0] == 3) && + (param.filter_meta.stride[0] == param.filter_meta.stride[1] && 
+ param.filter_meta.stride[0] == 1) && + (param.filter_meta.dilation[0] == + param.filter_meta.dilation[1] && + param.filter_meta.dilation[0] == 1) && + param.compute_mode == param::ConvBias::ComputeMode::DEFAULT && + param.src_type.enumv() == DTypeEnum::Float16; + } + MIDOUT_END(); + return false; +} + +size_t ConvBiasImpl::AlgoFP16WinogradF63::get_workspace( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + MEGDNN_MARK_USED_VAR(param); + winograd::winograd_6x3_1x1_f16 strategy(param.src_type, param.filter_type, + param.dst_type); + MIDOUT_BEGIN(megdnn_arm_common_winograd_fp16, 2, 1) { + return megdnn::winograd::ConvBias( + strategy, m_tile_size, param.nr_threads, param.osz[0], + param.osz[1], param.filter_meta.ocpg) + .get_workspace_size(param, m_matmul_algo); + } + MIDOUT_END(); + return 0; +} + +SmallVector +ConvBiasImpl::AlgoFP16WinogradF63::dispatch_kerns( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + MEGDNN_MARK_USED_VAR(param); + MIDOUT_BEGIN(megdnn_arm_common_winograd_fp16, 2, 2) { + winograd::winograd_6x3_1x1_f16 strategy( + param.src_type, param.filter_type, param.dst_type); + auto winograd_impl = + megdnn::winograd::ConvBias( + strategy, m_tile_size, param.nr_threads, param.osz[0], + param.osz[1], param.filter_meta.ocpg); + return winograd_impl.get_kerns(param, m_matmul_algo); + } + MIDOUT_END(); + return {}; +} + +/* ======================= AlgoFP16WinogradF23_8x8 ======================== */ + +bool ConvBiasImpl::AlgoFP16WinogradF23_8x8::usable( + fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, + AlgoSelectionStrategy /*algo_selection_strategy*/) const { + MEGDNN_MARK_USED_VAR(param); + MEGDNN_MARK_USED_VAR(opr); + MIDOUT_BEGIN(megdnn_arm_common_winograd_fp16, 3, 0) { + if (param.filter_meta.icpg % 8 != 0 || param.filter_meta.ocpg % 8 != 0) + return false; + using Strategy = winograd::winograd_2x3_8x8_f16; + Strategy strategy(param.src_type, param.filter_type, param.dst_type); + auto&& matmul_param = + megdnn::winograd::ConvBias( + strategy, m_tile_size, param.nr_threads, param.osz[0], + param.osz[1], param.filter_meta.ocpg) + .get_matmul_kern_param(param); + return m_matmul_algo->usable(matmul_param) && + (opr->param().format == param::ConvBias::Format::NCHW || + (opr->param().format == + param::ConvBias::Format::NCHW_WINOGRAD && + opr->param().output_block_size == 2 && + param.winograd_matmul_format == + param::MatrixMul::Format::MK8)) && + opr->param().mode == param::ConvBias::Mode::CROSS_CORRELATION && + (param.filter_meta.spatial[0] == param.filter_meta.spatial[1] && + param.filter_meta.spatial[0] == 3) && + (param.filter_meta.stride[0] == param.filter_meta.stride[1] && + param.filter_meta.stride[0] == 1) && + (param.filter_meta.dilation[0] == + param.filter_meta.dilation[1] && + param.filter_meta.dilation[0] == 1) && + param.compute_mode == param::ConvBias::ComputeMode::DEFAULT && + param.src_type.enumv() == DTypeEnum::Float16; + } + MIDOUT_END(); + return false; +} + +size_t ConvBiasImpl::AlgoFP16WinogradF23_8x8::get_workspace( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + MEGDNN_MARK_USED_VAR(param); + MIDOUT_BEGIN(megdnn_arm_common_winograd_fp16, 3, 1) { + winograd::winograd_2x3_8x8_f16 strategy( + param.src_type, param.filter_type, param.dst_type); + return megdnn::winograd::ConvBias( + strategy, m_tile_size, param.nr_threads, param.osz[0], + param.osz[1], param.filter_meta.ocpg) + .get_workspace_size(param, m_matmul_algo); + } + MIDOUT_END(); + return 0; +} + +SmallVector 
+ConvBiasImpl::AlgoFP16WinogradF23_8x8::dispatch_kerns( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + MEGDNN_MARK_USED_VAR(param); + MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 3, 2) { + winograd::winograd_2x3_8x8_f16 strategy( + param.src_type, param.filter_type, param.dst_type); + auto winograd_impl = + megdnn::winograd::ConvBias( + strategy, m_tile_size, param.nr_threads, param.osz[0], + param.osz[1], param.filter_meta.ocpg); + return winograd_impl.get_kerns(param, m_matmul_algo); + } + MIDOUT_END(); + return {}; +} + +/*========================from Convolution=============================*/ + +MIDOUT_DECL(megdnn_arm_common_conv_bias_fp16_kimpl) + +bool ConvBiasImpl::AlgoF16Direct::usable( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const { + MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_kimpl, 0, 0) { + auto&& fm = param.filter_meta; + auto FH = fm.spatial[0]; + auto SH = fm.stride[0], SW = fm.stride[1]; + // the condition ``param.isz[0]*param.isz[1] >= 8'' and + // ``param.osz[0]*param.osz[1] >= 8'' comes from the fact that the + // kernel may have access to up to 8 fp16 after the end of the memory + // chunk. + bool aviliable = fm.format == param::ConvBias::Format::NCHW && + param.src_type.enumv() == DTypeEnum::Float16 && + param.filter_type.enumv() == DTypeEnum::Float16 && + param.dst_type.enumv() == DTypeEnum::Float16 && + fm.spatial_ndim == 2 && fm.dilation[0] == 1 && + fm.dilation[1] == 1 && + param.isz[0] * param.isz[1] >= 8 && + param.osz[0] * param.osz[1] >= 8 && FH <= 7 && + SH == 1 && SW == 1; + if (algo_selection_strategy == AlgoSelectionStrategy::HEURISTIC) { + bool large_group = param.filter_meta.group >= param.nr_threads; + aviliable &= (large_group == m_large_group); + } + return aviliable; + } + MIDOUT_END(); + return false; +} + +size_t ConvBiasImpl::AlgoF16Direct::get_workspace( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_kimpl, 0, 1) { + auto wbundle = + MultithreadDirectConvCommon::get_bundle( + param, m_large_group); + return wbundle.total_size_in_bytes(); + } + MIDOUT_END(); + return 0; +} + +SmallVector ConvBiasImpl::AlgoF16Direct::get_kimpls( + const NCBKernSizeParam& param) const { + auto fm = param.filter_meta; + size_t N = param.n; + size_t IC = param.filter_meta.icpg; + size_t OC = param.filter_meta.ocpg; + size_t group = fm.group; + WorkspaceBundle wbundle = + MultithreadDirectConvCommon::get_bundle( + param, m_large_group); + SmallVector ret_kerns; + //! When group >= nr_threads, treat it as large_group, each thread process + //! one group for better performance + if (m_large_group) { + //! 
Channel wise conv and big groups + auto exec_one_group = [wbundle](const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + auto fm = kern_param.filter_meta; + size_t IC = fm.icpg; + size_t OC = fm.ocpg; + WorkspaceBundle bundle = wbundle; + if (fm.should_flip) { + for (size_t oc = 0; oc < OC; oc++) { + MultithreadDirectConvCommon:: + weight_flip_kern(bundle, kern_param, ncb_index, + {ncb_index.thread_id, 0, oc}); + } + } + for (size_t ic = 0; ic < IC; ic++) { + MultithreadDirectConvCommon:: + copy_padding_kern(bundle, kern_param, ncb_index, + {ncb_index.thread_id, 0, ic}); + } + for (size_t oc = 0; oc < OC; oc++) { + MultithreadDirectConvCommon::do_conv_kern( + bundle, kern_param, ncb_index, + fp16::conv_bias::kern_direct_f16, + {ncb_index.thread_id, 0, oc}); + } + }; + ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); + } else { + WorkspaceBundle bundle = wbundle; + if (fm.should_flip) { + auto weight_flip = [bundle](const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + MultithreadDirectConvCommon:: + weight_flip_kern(bundle, kern_param, ncb_index, + ncb_index.ndrange_id); + }; + ret_kerns.push_back({weight_flip, {group, 1_z, OC}}); + } + auto copy_padding = [bundle](const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + MultithreadDirectConvCommon::copy_padding_kern( + bundle, kern_param, ncb_index, ncb_index.ndrange_id); + }; + ret_kerns.push_back({copy_padding, {group, N, IC}}); + auto do_conv = [bundle](const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + MultithreadDirectConvCommon::do_conv_kern( + bundle, kern_param, ncb_index, + fp16::conv_bias::kern_direct_f16, ncb_index.ndrange_id); + }; + ret_kerns.push_back({do_conv, {group, N, OC}}); + } + return ret_kerns; +} + +SmallVector ConvBiasImpl::AlgoF16Direct::dispatch_kerns( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_kimpl, 0, 1) { + return get_kimpls(param); + } + MIDOUT_END(); + return {}; +} + +/* ===================== stride-1 algo ===================== */ + +bool ConvBiasImpl::AlgoF16DirectStride1::usable( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const { + MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_kimpl, 1, 0) { + auto&& fm = param.filter_meta; + auto FH = fm.spatial[0]; + bool aviliable = + param.filter_meta.format == param::ConvBias::Format::NCHW && + param.src_type.enumv() == DTypeEnum::Float16 && + param.filter_type.enumv() == DTypeEnum::Float16 && + param.dst_type.enumv() == DTypeEnum::Float16 && + !fm.should_flip && fm.spatial_ndim == 2 && + fm.dilation[0] == 1 && fm.dilation[1] == 1 && + fm.stride[0] == 1 && fm.stride[1] == 1 && FH == fm.spatial[1] && + (FH == 2 || FH == 3 || FH == 5); + if (algo_selection_strategy == + ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { + bool large_group = param.filter_meta.group >= param.nr_threads; + aviliable &= (large_group == m_large_group); + } + return aviliable; + } + MIDOUT_END(); + return false; +} + +SmallVector +ConvBiasImpl::AlgoF16DirectStride1::get_kimpls( + const NCBKernSizeParam& param) const { + auto fm = param.filter_meta; + auto FH = fm.spatial[0]; + size_t N = param.n; + size_t IC = param.filter_meta.icpg; + size_t OC = param.filter_meta.ocpg; + size_t group = fm.group; + using Func = std::function; + Func conv_kern_function = nullptr; + +#define SWITCH_KERN() \ + switch (FH) { \ + case 2: \ + conv_kern_function = fp16::conv_stride1::do_conv_2x2_stride1; \ + break; 
\ + case 3: \ + conv_kern_function = fp16::conv_stride1::do_conv_3x3_stride1; \ + break; \ + case 5: \ + conv_kern_function = fp16::conv_stride1::do_conv_5x5_stride1; \ + break; \ + } + SWITCH_KERN(); + + WorkspaceBundle wbundle = + MultithreadDirectConvCommon::get_bundle_stride( + param, m_large_group); + SmallVector ret_kerns; + //! When group >= nr_threads, treat it as large_group, each thread process + //! one group for better performance + if (m_large_group) { + //! Channel wise conv and big groups + auto exec_one_group = [wbundle, conv_kern_function]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + auto fm = kern_param.filter_meta; + size_t IC = fm.icpg; + size_t OC = fm.ocpg; + WorkspaceBundle bundle = wbundle; + for (size_t ic = 0; ic < IC; ic++) { + MultithreadDirectConvCommon:: + copy_padding_kern_stride(bundle, kern_param, ncb_index, + {ncb_index.thread_id, 0, ic}); + } + for (size_t oc = 0; oc < OC; oc++) { + MultithreadDirectConvCommon:: + do_conv_kern_stride(bundle, kern_param, ncb_index, + conv_kern_function, + {ncb_index.thread_id, 0, oc}); + } + }; + ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); + } else { + WorkspaceBundle bundle = wbundle; + auto copy_padding = [bundle](const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + MultithreadDirectConvCommon:: + copy_padding_kern_stride(bundle, kern_param, ncb_index, + ncb_index.ndrange_id); + }; + ret_kerns.push_back({copy_padding, {group, N, IC}}); + auto do_conv = [bundle, conv_kern_function]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + MultithreadDirectConvCommon:: + do_conv_kern_stride(bundle, kern_param, ncb_index, + conv_kern_function, + ncb_index.ndrange_id); + }; + ret_kerns.push_back({do_conv, {group, N, OC}}); + } + return ret_kerns; +} + +size_t ConvBiasImpl::AlgoF16DirectStride1::get_workspace( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_kimpl, 1, 1) { + auto bundle = MultithreadDirectConvCommon< + dt_float16, __fp16>::get_bundle_stride(param, m_large_group); + return bundle.total_size_in_bytes(); + } + MIDOUT_END(); + return 0; +} + +SmallVector +ConvBiasImpl::AlgoF16DirectStride1::dispatch_kerns( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_kimpl, 1, 2) { + return get_kimpls(param); + } + MIDOUT_END(); + return {}; +} + +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/f16/algos.h b/dnn/src/arm_common/conv_bias/f16/algos.h new file mode 100644 index 00000000..2dea43c9 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/f16/algos.h @@ -0,0 +1,185 @@ +/** + * \file dnn/src/arm_common/conv_bias/f16/algos.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
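The large-group/small-group split used by both get_kimpls() implementations above decides the shape of the dispatched ND-range: when group >= nr_threads each kernel instance handles a whole group and loops over its output channels internally, otherwise the channel loop is exposed to the thread pool. A minimal, self-contained sketch of that decision, using simplified stand-in types rather than MegEngine's NCBKern structures (NDRange and plan_dispatch are hypothetical names):

    #include <cstddef>
    #include <vector>

    struct NDRange { std::size_t group, batch, channel; };

    // Sketch only: models how get_kimpls() sizes its parallel ranges.
    std::vector<NDRange> plan_dispatch(std::size_t group, std::size_t N,
                                       std::size_t OC, std::size_t nr_threads) {
        std::vector<NDRange> ranges;
        if (group >= nr_threads) {
            // large group: one task per (group, batch); that task itself
            // pads the input and loops over every output channel.
            ranges.push_back({group, N, 1});
        } else {
            // small group: expose the output-channel loop so idle threads
            // can pick up per-channel work.
            ranges.push_back({group, N, OC});
        }
        return ranges;
    }

As the diff shows, the real small-group path additionally pushes separate weight-flip ({group, 1, OC}) and copy-padding ({group, N, IC}) kernels ahead of the convolution kernel.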
+ */ + +#pragma once + +#include "src/arm_common/conv_bias/opr_impl.h" +#include "src/fallback/matrix_mul/opr_impl.h" +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +namespace megdnn { +namespace arm_common { + +class ConvBiasImpl::AlgoFP16WinogradF23 final : public AlgoBase { +public: + AlgoFP16WinogradF23(fallback::MatrixMulImpl::AlgoBase* matmul_algo, + uint32_t tile_size) + : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {} + bool is_reproducible() const override { return true; } + const char* name() const override { + if (m_name.empty()) { + m_name = ConvBiasImpl::algo_name( + m_matmul_algo->name(), {1, 2, m_tile_size}); + } + return m_name.c_str(); + } + bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; + size_t get_workspace(fallback::ConvBiasImpl*, + const NCBKernSizeParam& param) const override; + virtual SmallVector dispatch_kerns( + fallback::ConvBiasImpl* opr, + const NCBKernSizeParam& param) const override; + + static std::vector + get_avaiable_matmul_algos(const NCBKernSizeParam& param); + +private: + fallback::MatrixMulImpl::AlgoBase* m_matmul_algo; + mutable std::string m_name; + + uint32_t m_tile_size; +}; + +class ConvBiasImpl::AlgoFP16WinogradF45 final : public AlgoBase { +public: + AlgoFP16WinogradF45(fallback::MatrixMulImpl::AlgoBase* matmul_algo, + uint32_t tile_size) + : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {} + bool is_reproducible() const override { return true; } + const char* name() const override { + if (m_name.empty()) { + m_name = ConvBiasImpl::algo_name( + m_matmul_algo->name(), {1, 4, m_tile_size}); + } + return m_name.c_str(); + } + bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; + size_t get_workspace(fallback::ConvBiasImpl*, + const NCBKernSizeParam& param) const override; + + virtual SmallVector dispatch_kerns( + fallback::ConvBiasImpl* opr, + const NCBKernSizeParam& param) const override; + + static std::vector + get_avaiable_matmul_algos(const NCBKernSizeParam& param); + +private: + fallback::MatrixMulImpl::AlgoBase* m_matmul_algo; + mutable std::string m_name; + + uint32_t m_tile_size; +}; +class ConvBiasImpl::AlgoFP16WinogradF63 final : public AlgoBase { +public: + AlgoFP16WinogradF63(fallback::MatrixMulImpl::AlgoBase* matmul_algo, + uint32_t tile_size) + : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {} + bool is_reproducible() const override { return true; } + const char* name() const override { + if (m_name.empty()) { + m_name = ConvBiasImpl::algo_name( + m_matmul_algo->name(), {1, 6, m_tile_size}); + } + return m_name.c_str(); + } + + bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; + size_t get_workspace(fallback::ConvBiasImpl*, + const NCBKernSizeParam& param) const override; + virtual SmallVector dispatch_kerns( + fallback::ConvBiasImpl* opr, + const NCBKernSizeParam& param) const override; + + static std::vector + get_avaiable_matmul_algos(const NCBKernSizeParam& param); + +private: + fallback::MatrixMulImpl::AlgoBase* m_matmul_algo; + mutable std::string m_name; + + uint32_t m_tile_size; +}; +class ConvBiasImpl::AlgoFP16WinogradF23_8x8 final : public AlgoBase { +public: + AlgoFP16WinogradF23_8x8(fallback::MatrixMulImpl::AlgoBase* matmul_algo, + uint32_t tile_size) + : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {} + bool is_reproducible() const override { 
return true; } + const char* name() const override { + if (m_name.empty()) { + m_name = ConvBiasImpl::algo_name( + m_matmul_algo->name(), {8, 2, m_tile_size}); + } + return m_name.c_str(); + } + bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; + size_t get_workspace(fallback::ConvBiasImpl*, + const NCBKernSizeParam& param) const override; + + virtual SmallVector dispatch_kerns( + fallback::ConvBiasImpl* opr, + const NCBKernSizeParam& param) const override; + +private: + fallback::MatrixMulImpl::AlgoBase* m_matmul_algo; + mutable std::string m_name; + uint32_t m_tile_size; +}; + +class ConvBiasImpl::AlgoF16Direct final : public AlgoBase { + SmallVector get_kimpls(const NCBKernSizeParam& param) const; + bool m_large_group; + +public: + AlgoF16Direct(bool is_large_group) : m_large_group{is_large_group} {} + bool is_reproducible() const override { return true; } + const char* name() const override { + return m_large_group ? "F16DIRECT_LARGE_GROUP" + : "F16DIRECT_SMALL_GROUP"; + } + bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; + + size_t get_workspace(fallback::ConvBiasImpl* opr, + const NCBKernSizeParam& param) const override; + + virtual SmallVector dispatch_kerns( + fallback::ConvBiasImpl* opr, + const NCBKernSizeParam& param) const override; +}; + +class ConvBiasImpl::AlgoF16DirectStride1 final : public AlgoBase { + SmallVector get_kimpls(const NCBKernSizeParam& param) const; + bool m_large_group; + +public: + AlgoF16DirectStride1(bool is_large_group) : m_large_group{is_large_group} {} + bool is_reproducible() const override { return true; } + const char* name() const override { + return m_large_group ? "F16STRD1_LARGE_GROUP" : "F16STRD1_SMALL_GROUP"; + } + bool usable(fallback::ConvBiasImpl*, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; + size_t get_workspace(fallback::ConvBiasImpl*, + const NCBKernSizeParam& param) const override; + virtual SmallVector dispatch_kerns( + fallback::ConvBiasImpl* opr, + const NCBKernSizeParam& param) const override; +}; + +} // namespace arm_common +} // namespace megdnn +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/f16/direct.cpp b/dnn/src/arm_common/conv_bias/f16/direct.cpp new file mode 100644 index 00000000..b291e04c --- /dev/null +++ b/dnn/src/arm_common/conv_bias/f16/direct.cpp @@ -0,0 +1,799 @@ +/** + * \file dnn/src/arm_common/conv_bias/f16/direct.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#include "./direct.h" +#include "include/megdnn/oprs.h" +#include "midout.h" +#include "src/common/unroll_macro.h" +#include "src/common/utils.h" + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#include +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/arm_common/simd_macro/neon_helper.h" + +MIDOUT_DECL(megdnn_arm_conv_f16) + +using namespace megdnn; +using namespace arm_common; +using namespace fp16; +using namespace conv_bias; +namespace { + +#define BLOCK_H 8 + +#define LOAD_RESULT_VAL \ + if (width < 8) { \ + auto load_less_8 = [](__fp16* dst, float16x8_t& data) { \ + if (width == 1u) { \ + data = vld1q_lane_f16(dst, data, 0); \ + } else if (width == 2u) { \ + data = vld1q_lane_f16(dst + 0, data, 0); \ + data = vld1q_lane_f16(dst + 1, data, 1); \ + } else if (width == 3u) { \ + data = vld1q_lane_f16(dst + 0, data, 0); \ + data = vld1q_lane_f16(dst + 1, data, 1); \ + data = vld1q_lane_f16(dst + 2, data, 2); \ + } else if (width == 4u) { \ + data = vld1q_lane_u64(dst, data, 0); \ + } else if (width == 5u) { \ + data = vld1q_lane_u64(dst, data, 0); \ + data = vld1q_lane_f16(dst + 4, data, 4); \ + } else if (width == 6u) { \ + data = vld1q_lane_u64(dst, data, 0); \ + data = vld1q_lane_f16(dst + 4, data, 4); \ + data = vld1q_lane_f16(dst + 5, data, 5); \ + } else if (width == 7u) { \ + data = vld1q_lane_u64(dst, data, 0); \ + data = vld1q_lane_f16(dst + 4, data, 4); \ + data = vld1q_lane_f16(dst + 5, data, 5); \ + data = vld1q_lane_f16(dst + 6, data, 6); \ + } \ + }; \ + if (height >= 1) \ + load_less_8(dst + 0 * OW, out0); \ + if (height >= 2) \ + load_less_8(dst + 1 * OW, out1); \ + if (height >= 3) \ + load_less_8(dst + 2 * OW, out2); \ + if (height >= 4) \ + load_less_8(dst + 3 * OW, out3); \ + if (height >= 5) \ + load_less_8(dst + 4 * OW, out4); \ + if (height >= 6) \ + load_less_8(dst + 5 * OW, out5); \ + if (height >= 7) \ + load_less_8(dst + 6 * OW, out6); \ + if (height >= 8) \ + load_less_8(dst + 7 * OW, out7); \ + } else { \ + if (height >= 1) \ + out0 = vld1q_f16(dst + 0 * OW); \ + if (height >= 2) \ + out1 = vld1q_f16(dst + 1 * OW); \ + if (height >= 3) \ + out2 = vld1q_f16(dst + 2 * OW); \ + if (height >= 4) \ + out3 = vld1q_f16(dst + 3 * OW); \ + if (height >= 5) \ + out4 = vld1q_f16(dst + 4 * OW); \ + if (height >= 6) \ + out5 = vld1q_f16(dst + 5 * OW); \ + if (height >= 7) \ + out6 = vld1q_f16(dst + 6 * OW); \ + if (height >= 8) \ + out7 = vld1q_f16(dst + 7 * OW); \ + } + +#define STORE_RESULT_VAL \ + if (width < 8) { \ + auto store_less_8 = [](__fp16* dst, float16x8_t& data) { \ + if (width == 1u) { \ + vst1q_lane_f16(dst, data, 0); \ + } else if (width == 2u) { \ + vst1q_lane_f16(dst + 0, data, 0); \ + vst1q_lane_f16(dst + 1, data, 1); \ + } else if (width == 3u) { \ + vst1q_lane_f16(dst + 0, data, 0); \ + vst1q_lane_f16(dst + 1, data, 1); \ + vst1q_lane_f16(dst + 2, data, 2); \ + } else if (width == 4u) { \ + vst1_f16(dst, vget_low_f16(data)); \ + } else if (width == 5u) { \ + vst1_f16(dst, vget_low_f16(data)); \ + vst1q_lane_f16(dst + 4, data, 4); \ + } else if (width == 6u) { \ + vst1_f16(dst, vget_low_f16(data)); \ + vst1q_lane_f16(dst + 4, data, 4); \ + vst1q_lane_f16(dst + 5, data, 5); \ + } else if (width == 7u) { \ + vst1_f16(dst, vget_low_f16(data)); \ + vst1q_lane_f16(dst + 4, data, 4); \ + vst1q_lane_f16(dst + 5, data, 5); \ + vst1q_lane_f16(dst + 6, data, 6); \ + } \ + }; \ + if (height >= 1) \ + store_less_8(dst + 0 * OW, out0); \ + if (height >= 2) \ + store_less_8(dst + 1 * OW, out1); \ + if (height >= 3) \ + store_less_8(dst + 2 * OW, 
out2); \ + if (height >= 4) \ + store_less_8(dst + 3 * OW, out3); \ + if (height >= 5) \ + store_less_8(dst + 4 * OW, out4); \ + if (height >= 6) \ + store_less_8(dst + 5 * OW, out5); \ + if (height >= 7) \ + store_less_8(dst + 6 * OW, out6); \ + if (height >= 8) \ + store_less_8(dst + 7 * OW, out7); \ + } else { \ + if (height >= 1) \ + vst1q_f16(dst + 0 * OW, out0); \ + if (height >= 2) \ + vst1q_f16(dst + 1 * OW, out1); \ + if (height >= 3) \ + vst1q_f16(dst + 2 * OW, out2); \ + if (height >= 4) \ + vst1q_f16(dst + 3 * OW, out3); \ + if (height >= 5) \ + vst1q_f16(dst + 4 * OW, out4); \ + if (height >= 6) \ + vst1q_f16(dst + 5 * OW, out5); \ + if (height >= 7) \ + vst1q_f16(dst + 6 * OW, out6); \ + if (height >= 8) \ + vst1q_f16(dst + 7 * OW, out7); \ + } + +template +struct do_pixel_proxy { + static void exec(const __fp16* src, const __fp16* filter, __fp16* dst, + const int IH, const int IW, const int OH, const int OW, + const int FW, const int oh, const int ow); +}; + +template +struct do_pixel_proxy<1, height, width> { + static void exec(const __fp16* src, const __fp16* filter, __fp16* dst, + const int IH, const int IW, const int OH, const int OW, + const int FW, const int oh, const int ow) { + MEGDNN_MARK_USED_VAR(IH); + MEGDNN_MARK_USED_VAR(OH); + const int ih = oh, iw = ow; +#define cb(i) float16x8_t out##i{0}; + UNROLL_CALL_NOWRAPPER(BLOCK_H, cb); +#undef cb + float16x8_t kr0, inp; + src += ih * IW + iw; + dst += oh * OW + ow; + LOAD_RESULT_VAL; + for (int fw = 0; fw < FW; ++fw) { + const __fp16* src_dd = src + fw; + kr0 = vdupq_n_f16(filter[0 * FW + fw]); + +#define cb(i) \ + if (height > i) { \ + inp = vld1q_f16(src_dd + i * IW); \ + out##i = vmlaq_f16(out##i, inp, kr0); \ + } + UNROLL_CALL_NOWRAPPER(BLOCK_H, cb); + +#undef cb + } + STORE_RESULT_VAL; + } +}; + +template +struct do_pixel_proxy<2, height, width> { + static void exec(const __fp16* src, const __fp16* filter, __fp16* dst, + const int IH, const int IW, const int OH, const int OW, + const int FW, const int oh, const int ow) { + MEGDNN_MARK_USED_VAR(IH); + MEGDNN_MARK_USED_VAR(OH); + const int ih = oh, iw = ow; +#define cb(i) float16x8_t out##i{0}; + UNROLL_CALL_NOWRAPPER(BLOCK_H, cb); +#undef cb + float16x8_t kr0, kr1, inp; + src += ih * IW + iw; + dst += oh * OW + ow; + LOAD_RESULT_VAL; + for (int fw = 0; fw < FW; ++fw) { + const __fp16* src_dd = src + fw; + kr0 = vdupq_n_f16(filter[0 * FW + fw]); + kr1 = vdupq_n_f16(filter[1 * FW + fw]); + +#define cb(i) \ + if (height > i) { \ + inp = vld1q_f16(src_dd + i * IW); \ + out##i = vmlaq_f16(out##i, inp, kr0); \ + inp = vld1q_f16(src_dd + (i + 1) * IW); \ + out##i = vmlaq_f16(out##i, inp, kr1); \ + } + UNROLL_CALL_NOWRAPPER(BLOCK_H, cb); +#undef cb + } + STORE_RESULT_VAL; + } +}; + +template +struct do_pixel_proxy<3, height, width> { + static void exec(const __fp16* src, const __fp16* filter, __fp16* dst, + const int IH, const int IW, const int OH, const int OW, + const int FW, const int oh, const int ow) { + MEGDNN_MARK_USED_VAR(IH); + MEGDNN_MARK_USED_VAR(OH); + const int ih = oh, iw = ow; +#define cb(i) float16x8_t out##i{0}; + UNROLL_CALL_NOWRAPPER(BLOCK_H, cb); +#undef cb + float16x8_t kr0, kr1, kr2, inp; + src += ih * IW + iw; + dst += oh * OW + ow; + LOAD_RESULT_VAL; + for (int fw = 0; fw < FW; ++fw) { + const __fp16* src_dd = src + fw; + kr0 = vdupq_n_f16(filter[0 * FW + fw]); + kr1 = vdupq_n_f16(filter[1 * FW + fw]); + kr2 = vdupq_n_f16(filter[2 * FW + fw]); +#define cb(i) \ + if (height > i) { \ + inp = vld1q_f16(src_dd + i * IW); \ + out##i = vmlaq_f16(out##i, 
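LOAD_RESULT_VAL/STORE_RESULT_VAL exist because the kernel always operates on full 8-lane fp16 vectors; the ragged right edge of a row has to be loaded and stored lane by lane so nothing past the row end is touched. A functionally equivalent, simplified sketch (assuming an armv8.2-a+fp16 toolchain; store_f16_tail is a hypothetical helper and spills through a stack buffer instead of the macro's compile-time vst1q_lane_f16 chain):

    #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    #include <arm_neon.h>

    static inline void store_f16_tail(__fp16* dst, float16x8_t v, int width) {
        if (width >= 8) {           // full tile: one 8-lane store
            vst1q_f16(dst, v);
            return;
        }
        __fp16 tmp[8];
        vst1q_f16(tmp, v);          // spill the whole vector ...
        for (int i = 0; i < width; ++i) {
            dst[i] = tmp[i];        // ... but copy back only the valid lanes
        }
    }
    #endif

The macro can avoid the spill because `width` is a compile-time template parameter there, so each branch collapses into a fixed sequence of lane stores.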
inp, kr0); \ + inp = vld1q_f16(src_dd + (i + 1) * IW); \ + out##i = vmlaq_f16(out##i, inp, kr1); \ + inp = vld1q_f16(src_dd + (i + 2) * IW); \ + out##i = vmlaq_f16(out##i, inp, kr2); \ + } + UNROLL_CALL_NOWRAPPER(BLOCK_H, cb); + +#undef cb + } + STORE_RESULT_VAL; + } +}; + +template +struct do_pixel_proxy<4, height, width> { + static void exec(const __fp16* src, const __fp16* filter, __fp16* dst, + const int IH, const int IW, const int OH, const int OW, + const int FW, const int oh, const int ow) { + MEGDNN_MARK_USED_VAR(IH); + MEGDNN_MARK_USED_VAR(OH); + const int ih = oh, iw = ow; +#define cb(i) float16x8_t out##i{0}; + UNROLL_CALL_NOWRAPPER(BLOCK_H, cb); +#undef cb + float16x8_t kr0, kr1, kr2, kr3, inp; + src += ih * IW + iw; + dst += oh * OW + ow; + LOAD_RESULT_VAL; + for (int fw = 0; fw < FW; ++fw) { + const __fp16* src_dd = src + fw; + kr0 = vdupq_n_f16(filter[0 * FW + fw]); + kr1 = vdupq_n_f16(filter[1 * FW + fw]); + kr2 = vdupq_n_f16(filter[2 * FW + fw]); + kr3 = vdupq_n_f16(filter[3 * FW + fw]); +#define cb(i) \ + if (height > i) { \ + inp = vld1q_f16(src_dd + i * IW); \ + out##i = vmlaq_f16(out##i, inp, kr0); \ + inp = vld1q_f16(src_dd + (i + 1) * IW); \ + out##i = vmlaq_f16(out##i, inp, kr1); \ + inp = vld1q_f16(src_dd + (i + 2) * IW); \ + out##i = vmlaq_f16(out##i, inp, kr2); \ + inp = vld1q_f16(src_dd + (i + 3) * IW); \ + out##i = vmlaq_f16(out##i, inp, kr3); \ + } + UNROLL_CALL_NOWRAPPER(BLOCK_H, cb); + +#undef cb + } + STORE_RESULT_VAL; + } +}; + +template +struct do_pixel_proxy<5, height, width> { + static void exec(const __fp16* src, const __fp16* filter, __fp16* dst, + const int IH, const int IW, const int OH, const int OW, + const int FW, const int oh, const int ow) { + MEGDNN_MARK_USED_VAR(IH); + MEGDNN_MARK_USED_VAR(OH); + const int ih = oh, iw = ow; +#define cb(i) float16x8_t out##i{0}; + UNROLL_CALL_NOWRAPPER(BLOCK_H, cb); +#undef cb + float16x8_t kr0, kr1, kr2, kr3, kr4, inp; + src += ih * IW + iw; + dst += oh * OW + ow; + LOAD_RESULT_VAL; + for (int fw = 0; fw < FW; ++fw) { + const __fp16* src_dd = src + fw; + kr0 = vdupq_n_f16(filter[0 * FW + fw]); + kr1 = vdupq_n_f16(filter[1 * FW + fw]); + kr2 = vdupq_n_f16(filter[2 * FW + fw]); + kr3 = vdupq_n_f16(filter[3 * FW + fw]); + kr4 = vdupq_n_f16(filter[4 * FW + fw]); +#define cb(i) \ + if (height > i) { \ + inp = vld1q_f16(src_dd + i * IW); \ + out##i = vmlaq_f16(out##i, inp, kr0); \ + inp = vld1q_f16(src_dd + (i + 1) * IW); \ + out##i = vmlaq_f16(out##i, inp, kr1); \ + inp = vld1q_f16(src_dd + (i + 2) * IW); \ + out##i = vmlaq_f16(out##i, inp, kr2); \ + inp = vld1q_f16(src_dd + (i + 3) * IW); \ + out##i = vmlaq_f16(out##i, inp, kr3); \ + inp = vld1q_f16(src_dd + (i + 4) * IW); \ + out##i = vmlaq_f16(out##i, inp, kr4); \ + } + UNROLL_CALL_NOWRAPPER(BLOCK_H, cb); +#undef cb + } + STORE_RESULT_VAL; + } +}; + +template +struct do_pixel_proxy<6, height, width> { + static void exec(const __fp16* src, const __fp16* filter, __fp16* dst, + const int IH, const int IW, const int OH, const int OW, + const int FW, const int oh, const int ow) { + MEGDNN_MARK_USED_VAR(IH); + MEGDNN_MARK_USED_VAR(OH); + const int ih = oh, iw = ow; +#define cb(i) float16x8_t out##i{0}; + UNROLL_CALL_NOWRAPPER(BLOCK_H, cb); +#undef cb + float16x8_t kr0, kr1, kr2, kr3, kr4, kr5, inp; + src += ih * IW + iw; + dst += oh * OW + ow; + LOAD_RESULT_VAL; + for (int fw = 0; fw < FW; ++fw) { + const __fp16* src_dd = src + fw; + kr0 = vdupq_n_f16(filter[0 * FW + fw]); + kr1 = vdupq_n_f16(filter[1 * FW + fw]); + kr2 = vdupq_n_f16(filter[2 * FW + fw]); + kr3 = 
vdupq_n_f16(filter[3 * FW + fw]); + kr4 = vdupq_n_f16(filter[4 * FW + fw]); + kr5 = vdupq_n_f16(filter[5 * FW + fw]); +#define cb(i) \ + if (height > i) { \ + inp = vld1q_f16(src_dd + i * IW); \ + out##i = vmlaq_f16(out##i, inp, kr0); \ + inp = vld1q_f16(src_dd + (i + 1) * IW); \ + out##i = vmlaq_f16(out##i, inp, kr1); \ + inp = vld1q_f16(src_dd + (i + 2) * IW); \ + out##i = vmlaq_f16(out##i, inp, kr2); \ + inp = vld1q_f16(src_dd + (i + 3) * IW); \ + out##i = vmlaq_f16(out##i, inp, kr3); \ + inp = vld1q_f16(src_dd + (i + 4) * IW); \ + out##i = vmlaq_f16(out##i, inp, kr4); \ + inp = vld1q_f16(src_dd + (i + 5) * IW); \ + out##i = vmlaq_f16(out##i, inp, kr5); \ + } + UNROLL_CALL_NOWRAPPER(BLOCK_H, cb); +#undef cb + } + STORE_RESULT_VAL; + } +}; + +template +struct do_pixel_proxy<7, height, width> { + static void exec(const __fp16* src, const __fp16* filter, __fp16* dst, + const int IH, const int IW, const int OH, const int OW, + const int FW, const int oh, const int ow) { + MEGDNN_MARK_USED_VAR(IH); + MEGDNN_MARK_USED_VAR(OH); + const int ih = oh, iw = ow; +#define cb(i) float16x8_t out##i{0}; + UNROLL_CALL_NOWRAPPER(BLOCK_H, cb); +#undef cb + float16x8_t kr0, kr1, kr2, kr3, kr4, kr5, kr6, inp; + src += ih * IW + iw; + dst += oh * OW + ow; + LOAD_RESULT_VAL; + for (int fw = 0; fw < FW; ++fw) { + const __fp16* src_dd = src + fw; + kr0 = vdupq_n_f16(filter[0 * FW + fw]); + kr1 = vdupq_n_f16(filter[1 * FW + fw]); + kr2 = vdupq_n_f16(filter[2 * FW + fw]); + kr3 = vdupq_n_f16(filter[3 * FW + fw]); + kr4 = vdupq_n_f16(filter[4 * FW + fw]); + kr5 = vdupq_n_f16(filter[5 * FW + fw]); + kr6 = vdupq_n_f16(filter[6 * FW + fw]); +#define cb(i) \ + if (height > i) { \ + inp = vld1q_f16(src_dd + i * IW); \ + out##i = vmlaq_f16(out##i, inp, kr0); \ + inp = vld1q_f16(src_dd + (i + 1) * IW); \ + out##i = vmlaq_f16(out##i, inp, kr1); \ + inp = vld1q_f16(src_dd + (i + 2) * IW); \ + out##i = vmlaq_f16(out##i, inp, kr2); \ + inp = vld1q_f16(src_dd + (i + 3) * IW); \ + out##i = vmlaq_f16(out##i, inp, kr3); \ + inp = vld1q_f16(src_dd + (i + 4) * IW); \ + out##i = vmlaq_f16(out##i, inp, kr4); \ + inp = vld1q_f16(src_dd + (i + 5) * IW); \ + out##i = vmlaq_f16(out##i, inp, kr5); \ + inp = vld1q_f16(src_dd + (i + 6) * IW); \ + out##i = vmlaq_f16(out##i, inp, kr6); \ + } + UNROLL_CALL_NOWRAPPER(BLOCK_H, cb); +#undef cb + } + STORE_RESULT_VAL; + } +}; + +#undef STORE_RESULT_VAL +#undef LOAD_RESULT_VAL + +template +void do_pixel(const __fp16* src, const __fp16* filter, __fp16* dst, + const int IH, const int IW, const int OH, const int OW, + const int FW, const int oh, const int ow) { + do_pixel_proxy::exec(src, filter, dst, IH, IW, OH, OW, + FW, oh, ow); +} + +template +void do_conv_tpl_enable_prefetch(const __fp16* src, + const __fp16* filter, __fp16* dst, + const int IH, const int IW, const int OH, + const int OW, const int FW) { + const int hbeg = 0, hend = OH; + const int wbeg = 0, wend = OW; + int i, j; + for (i = hbeg; i + BLOCK_H <= hend; i += BLOCK_H) { + for (j = wbeg; j + 8 <= wend; j += 8) { + // do prefetch + const int prefetch_index_input = + (j + 16) < wend + ? i * IW + j + 16 + : (i + 8) * IW + (((j + 16 - wend) >> 2) << 2); + const int prefetch_index_output = + (j + 16) < wend + ? 
i * OW + j + 16 + : (i + 8) * OW + (((j + 16 - wend) >> 2) << 2); + const __fp16* src_prefetch = src + prefetch_index_input; + const __fp16* dst_prefetch = dst + prefetch_index_output; + for (int iw_id = 0; iw_id < FH + 3; ++iw_id) { + __builtin_prefetch(src_prefetch + iw_id * IW, 0, 3); + } +#define unroll_prefetch_cb(i) __builtin_prefetch(dst_prefetch + i * OW, 1, 3); + UNROLL_CALL_NOWRAPPER(BLOCK_H, unroll_prefetch_cb); + do_pixel(src, filter, dst, IH, IW, OH, OW, FW, i, + j); + } +#define DISPATCH(width) \ + do { \ + const int prefetch_index_input = (i + 8) * IW + 12; \ + const int prefetch_index_output = (i + 8) * OW + 12; \ + const __fp16* src_prefetch = src + prefetch_index_input; \ + const __fp16* dst_prefetch = dst + prefetch_index_output; \ + for (int iw_id = 0; iw_id < FH + 3; ++iw_id) { \ + __builtin_prefetch(src_prefetch + iw_id * IW, 0, 3); \ + } \ + UNROLL_CALL_NOWRAPPER(BLOCK_H, unroll_prefetch_cb); \ + do_pixel(src, filter, dst, IH, IW, OH, OW, FW, i, \ + j); \ + } while (0) + switch (wend - j) { + case 1: + DISPATCH(1); + break; + case 2: + DISPATCH(2); + break; + case 3: + DISPATCH(3); + break; + case 4: + DISPATCH(4); + break; + case 5: + DISPATCH(5); + break; + case 6: + DISPATCH(6); + break; + case 7: + DISPATCH(7); + break; + } +#undef DISPATCH + } + +#define DISPATCH2(height, width) \ + do { \ + const int prefetch_index_input = IH * IW + 12; \ + const __fp16* src_prefetch = src + prefetch_index_input; \ + for (int iw_id = 0; iw_id < FH + 3; ++iw_id) { \ + __builtin_prefetch(src_prefetch + iw_id * IW, 0, 3); \ + } \ + do_pixel(src, filter, dst, IH, IW, OH, OW, FW, i, \ + j); \ + } while (0) + +#define DISPATCH1(height) \ + do { \ + for (j = wbeg; j + 8 <= wend; j += 8) { \ + const int prefetch_index_input = \ + (j + 16) < wend \ + ? i * IW + j + 16 \ + : (i + 8) * IW + (((j + 16 - wend) >> 2) << 2); \ + const int prefetch_index_output = \ + (j + 16) < wend \ + ? 
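The prefetching variant issues software prefetches one output tile ahead: the next tile's FH + 3 input rows are prefetched for reading and its BLOCK_H destination rows for writing, both with maximal temporal locality. A minimal sketch of that pattern using GCC/Clang's __builtin_prefetch (prefetch_next_tile is a hypothetical helper; the index arithmetic that wraps to the next row block is omitted):

    #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    #include <arm_neon.h>

    static inline void prefetch_next_tile(const __fp16* src_next,
                                          __fp16* dst_next, int fh,
                                          int IW, int OW) {
        for (int r = 0; r < fh + 3; ++r) {
            // rw = 0 (read), locality = 3 (keep in all cache levels)
            __builtin_prefetch(src_next + r * IW, 0, 3);
        }
        for (int r = 0; r < 8; ++r) {   // BLOCK_H output rows
            // rw = 1: the destination is read-modify-written
            __builtin_prefetch(dst_next + r * OW, 1, 3);
        }
    }
    #endif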
i * OW + j + 16 \ + : (i + 8) * OW + (((j + 16 - wend) >> 2) << 2); \ + const __fp16* src_prefetch = src + prefetch_index_input; \ + const __fp16* dst_prefetch = dst + prefetch_index_output; \ + for (int iw_id = 0; iw_id < FH + 3; ++iw_id) { \ + __builtin_prefetch(src_prefetch + iw_id * IW, 0, 3); \ + } \ + UNROLL_CALL_NOWRAPPER(BLOCK_H, unroll_prefetch_cb); \ + do_pixel(src, filter, dst, IH, IW, OH, OW, FW, i, \ + j); \ + } \ + switch (wend - j) { \ + case 1: \ + DISPATCH2(height, 1); \ + break; \ + case 2: \ + DISPATCH2(height, 2); \ + break; \ + case 3: \ + DISPATCH2(height, 3); \ + break; \ + case 4: \ + DISPATCH2(height, 4); \ + break; \ + case 5: \ + DISPATCH2(height, 5); \ + break; \ + case 6: \ + DISPATCH2(height, 6); \ + break; \ + case 7: \ + DISPATCH2(height, 7); \ + break; \ + } \ + } while (0) + switch (hend - i) { + case 1: + DISPATCH1(1); + break; + case 2: + DISPATCH1(2); + break; + case 3: + DISPATCH1(3); + break; +#if BLOCK_H == 8 + case 4: + DISPATCH1(4); + break; + case 5: + DISPATCH1(5); + break; + case 6: + DISPATCH1(6); + break; + case 7: + DISPATCH1(7); + break; +#endif + } +#undef DISPATCH1 +#undef DISPATCH2 +#undef unroll_prefetch_cb +} +template +void do_conv_tpl_disable_prefetch(const __fp16* src, + const __fp16* filter, __fp16* dst, + const int IH, const int IW, const int OH, + const int OW, const int FW) { + const int hbeg = 0, hend = OH; + const int wbeg = 0, wend = OW; + int i, j; + for (i = hbeg; i + BLOCK_H <= hend; i += BLOCK_H) { + for (j = wbeg; j + 8 <= wend; j += 8) { + do_pixel(src, filter, dst, IH, IW, OH, OW, FW, i, + j); + } +#define DISPATCH(width) \ + do { \ + do_pixel(src, filter, dst, IH, IW, OH, OW, FW, i, \ + j); \ + } while (0) + switch (wend - j) { + case 1: + DISPATCH(1); + break; + case 2: + DISPATCH(2); + break; + case 3: + DISPATCH(3); + break; + case 4: + DISPATCH(4); + break; + case 5: + DISPATCH(5); + break; + case 6: + DISPATCH(6); + break; + case 7: + DISPATCH(7); + break; + } +#undef DISPATCH + } +#define DISPATCH2(height, width) \ + do { \ + do_pixel(src, filter, dst, IH, IW, OH, OW, FW, i, \ + j); \ + } while (0) +#define DISPATCH1(height) \ + do { \ + for (j = wbeg; j + 8 <= wend; j += 8) { \ + do_pixel(src, filter, dst, IH, IW, OH, OW, FW, i, \ + j); \ + } \ + switch (wend - j) { \ + case 1: \ + DISPATCH2(height, 1); \ + break; \ + case 2: \ + DISPATCH2(height, 2); \ + break; \ + case 3: \ + DISPATCH2(height, 3); \ + break; \ + case 4: \ + DISPATCH2(height, 4); \ + break; \ + case 5: \ + DISPATCH2(height, 5); \ + break; \ + case 6: \ + DISPATCH2(height, 6); \ + break; \ + case 7: \ + DISPATCH2(height, 7); \ + break; \ + } \ + } while (0) + switch (hend - i) { + case 1: + DISPATCH1(1); + break; + case 2: + DISPATCH1(2); + break; + case 3: + DISPATCH1(3); + break; +#if BLOCK_H == 8 + case 4: + DISPATCH1(4); + break; + case 5: + DISPATCH1(5); + break; + case 6: + DISPATCH1(6); + break; + case 7: + DISPATCH1(7); + break; +#endif + } +#undef DISPATCH1 +#undef DISPATCH2 +} +} // anonymous namespace + +void conv_bias::kern_direct_f16(const __fp16* src, + const __fp16* filter, __fp16* dst, + const int IH, const int IW, const int OH, + const int OW, const int FH, const int FW) { + megdnn_assert_internal(FH <= 7); + if (IH > 100 && IW > 100) { +#define GAO(FH) \ + do { \ + return do_conv_tpl_enable_prefetch(src, filter, dst, IH, IW, OH, \ + OW, FW); \ + } while (0) + switch (FH) { + case 1: + MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(0)) { GAO(1); } + MIDOUT_END(); + break; + case 2: + MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(1)) 
{ GAO(2); } + MIDOUT_END(); + break; + case 3: + MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(2)) { GAO(3); } + MIDOUT_END(); + break; + case 4: + MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(3)) { GAO(4); } + MIDOUT_END(); + break; + case 5: + MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(4)) { GAO(5); } + MIDOUT_END(); + break; + case 6: + MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(5)) { GAO(6); } + MIDOUT_END(); + break; + case 7: + MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(6)) { GAO(7); } + MIDOUT_END(); + break; + } +#undef GAO + } else { +#define GAO(FH) \ + do { \ + return do_conv_tpl_disable_prefetch(src, filter, dst, IH, IW, OH, \ + OW, FW); \ + } while (0) + switch (FH) { + case 1: + MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(0)) { GAO(1); } + MIDOUT_END(); + break; + case 2: + MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(1)) { GAO(2); } + MIDOUT_END(); + break; + case 3: + MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(2)) { GAO(3); } + MIDOUT_END(); + break; + case 4: + MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(3)) { GAO(4); } + MIDOUT_END(); + break; + case 5: + MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(4)) { GAO(5); } + MIDOUT_END(); + break; + case 6: + MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(5)) { GAO(6); } + MIDOUT_END(); + break; + case 7: + MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(6)) { GAO(7); } + MIDOUT_END(); + break; + } +#undef GAO + } + megdnn_assert_internal(0); +} +#endif + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/f16/direct.h b/dnn/src/arm_common/conv_bias/f16/direct.h new file mode 100644 index 00000000..7a556eec --- /dev/null +++ b/dnn/src/arm_common/conv_bias/f16/direct.h @@ -0,0 +1,32 @@ +/** + * \file dnn/src/arm_common/conv_bias/f16/direct.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once + +#include +#include "megdnn/dtype.h" +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +namespace megdnn { +namespace arm_common { +namespace fp16{ +namespace conv_bias { + +void kern_direct_f16(const __fp16* src, const __fp16* filter, + __fp16* dst, const int IH, const int IW, const int OH, + const int OW, const int FH, const int FW); + +} // namespace convolution +} // namespace fp16 +} // namespace arm_common +} // namespace megdnn +#endif + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/f16/do_conv_stride1.cpp b/dnn/src/arm_common/conv_bias/f16/do_conv_stride1.cpp new file mode 100644 index 00000000..4ddbebf7 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/f16/do_conv_stride1.cpp @@ -0,0 +1,522 @@ +/** + * \file dnn/src/arm_common/conv_bias/f16/do_conv_stride1.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
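kern_direct_f16 turns the runtime filter height into a compile-time template argument through an explicit switch (and only picks the prefetching variant for inputs larger than 100x100), so every kernel body is unrolled for its exact FH. A generic sketch of that dispatch idiom, independent of the NEON details (conv_rows and dispatch_on_fh are hypothetical names, shown with plain float for brevity):

    template <int FH>
    static void conv_rows(const float* src, const float* filter, float* dst,
                          int n) {
        // FH is a constant here, so the tap loop unrolls completely.
        for (int i = 0; i + FH <= n; ++i) {
            float acc = dst[i];
            for (int k = 0; k < FH; ++k)
                acc += src[i + k] * filter[k];
            dst[i] = acc;
        }
    }

    static void dispatch_on_fh(int fh, const float* src, const float* filter,
                               float* dst, int n) {
        switch (fh) {
            case 1: conv_rows<1>(src, filter, dst, n); break;
            case 2: conv_rows<2>(src, filter, dst, n); break;
            case 3: conv_rows<3>(src, filter, dst, n); break;
            case 4: conv_rows<4>(src, filter, dst, n); break;
            case 5: conv_rows<5>(src, filter, dst, n); break;
            case 6: conv_rows<6>(src, filter, dst, n); break;
            case 7: conv_rows<7>(src, filter, dst, n); break;
            default: break;   // usable() already rejects FH > 7
        }
    }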
+ */ + +#include + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#include "./do_conv_stride1.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/arm_common/conv_bias/postprocess_helper.h" + +using namespace megdnn; +using namespace arm_common; +using namespace fp16; +using namespace conv_stride1; + +using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam; +using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam; + + +void conv_stride1::do_conv_2x2_stride1(const __fp16* src, const __fp16* filter, __fp16* dst, + size_t IH, size_t IW, size_t OH, size_t OW, + size_t IC) { + const size_t tail_step = IW - OW; + //! unroll of 2 + size_t ic = 0; + for (; ic + 1 < IC; ic += 2) { + const __fp16* src_ptr = src + IW * IH * ic; + const __fp16* src_ptr1 = src_ptr + IW * IH; + __fp16* outptr = dst; + + const __fp16* r00 = src_ptr; + const __fp16* r01 = src_ptr + IW; + const __fp16* r10 = src_ptr1; + const __fp16* r11 = src_ptr1 + IW; + + const __fp16* k0 = filter + ic * 4; + + float16x8_t _k0 = vld1q_f16(k0); + rep(h, OH) { + int width = OW >> 3; + + rep(i, width) { + float16x8_t _r000 = vld1q_f16(r00); + float16x8_t _r010 = vld1q_f16(r01); + float16x8_t _r001 = vld1q_f16(r00 + 1); + float16x8_t _r011 = vld1q_f16(r01 + 1); + + float16x8_t _r100 = vld1q_f16(r10); + float16x8_t _r110 = vld1q_f16(r11); + float16x8_t _r101 = vld1q_f16(r10 + 1); + float16x8_t _r111 = vld1q_f16(r11 + 1); + + float16x8_t _sum = vld1q_f16(outptr); + + _sum = vmlaq_lane_f16(_sum, _r000, vget_low_f16(_k0), 0); + _sum = vmlaq_lane_f16(_sum, _r001, vget_low_f16(_k0), 1); + _sum = vmlaq_lane_f16(_sum, _r010, vget_low_f16(_k0), 2); + _sum = vmlaq_lane_f16(_sum, _r011, vget_low_f16(_k0), 3); + + _sum = vmlaq_lane_f16(_sum, _r100, vget_high_f16(_k0), 0); + _sum = vmlaq_lane_f16(_sum, _r101, vget_high_f16(_k0), 1); + _sum = vmlaq_lane_f16(_sum, _r110, vget_high_f16(_k0), 2); + _sum = vmlaq_lane_f16(_sum, _r111, vget_high_f16(_k0), 3); + + vst1q_f16(outptr, _sum); + + r00 += 8; + r01 += 8; + r10 += 8; + r11 += 8; + outptr += 8; + } + + r00 += tail_step; + r01 += tail_step; + r10 += tail_step; + r11 += tail_step; + } + } + for (; ic < IC; ic++) { + const __fp16* src_ptr = src + IW * IH * ic; + __fp16* outptr = dst; + + const __fp16* r0 = src_ptr; + const __fp16* r1 = src_ptr + IW; + + const __fp16* k0 = filter + ic * 4; + + float16x8_t _k0 = vdupq_n_f16(k0[0]); + float16x8_t _k1 = vdupq_n_f16(k0[1]); + float16x8_t _k2 = vdupq_n_f16(k0[2]); + float16x8_t _k3 = vdupq_n_f16(k0[3]); + rep(h, OH) { + int width = OW >> 3; + + rep(i, width) { + float16x8_t _r00 = vld1q_f16(r0); + float16x8_t _r10 = vld1q_f16(r1); + float16x8_t _r01 = vld1q_f16(r0 + 1); + float16x8_t _r11 = vld1q_f16(r1 + 1); + + float16x8_t _sum = vld1q_f16(outptr); + float16x8_t _sum2; + + _sum = vmlaq_f16(_sum, _r00, _k0); + _sum2 = vmulq_f16(_r01, _k1); + _sum = vmlaq_f16(_sum, _r10, _k2); + _sum2 = vmlaq_f16(_sum2, _r11, _k3); + + _sum = vaddq_f16(_sum, _sum2); + + vst1q_f16(outptr, _sum); + + r0 += 8; + r1 += 8; + outptr += 8; + } + + r0 += tail_step; + r1 += tail_step; + } + } +} + +void conv_stride1::do_conv_3x3_stride1(const __fp16* src, const __fp16* filter, __fp16* dst, + size_t IH, size_t IW, size_t OH, size_t OW, + size_t IC) { + const size_t tail_step = IW - OW; + + rep(ic, IC) { + const __fp16* src_ptr = src + IW * IH * ic; + __fp16* outptr = dst; + __fp16* outptr2 = outptr + OW; + + const __fp16* r0 = src_ptr; + const __fp16* r1 = src_ptr + IW; + const __fp16* r2 = src_ptr + IW * 2; + const __fp16* r3 = src_ptr + IW * 3; + + float16x8_t 
_k01234567 = vld1q_f16(filter); + float16x8_t _k12345678 = vld1q_f16(filter + 1); + + size_t h = 0; + for (; h + 1 < OH; h += 2) { + int width = OW >> 3; + + rep(i, width) { + float16x8_t _sum1 = vld1q_f16(outptr); + float16x8_t _sum2 = vdupq_n_f16(0.f); + float16x8_t _sum3 = vld1q_f16(outptr2); + float16x8_t _sum4 = vdupq_n_f16(0.f); + + float16x8_t _r00 = vld1q_f16(r0); + float16x8_t _r00n = vld1q_f16(r0 + 8); + float16x8_t _r01 = vextq_f16(_r00, _r00n, 1); + float16x8_t _r02 = vextq_f16(_r00, _r00n, 2); + + float16x8_t _r10 = vld1q_f16(r1); + float16x8_t _r10n = vld1q_f16(r1 + 8); + float16x8_t _r11 = vextq_f16(_r10, _r10n, 1); + float16x8_t _r12 = vextq_f16(_r10, _r10n, 2); + + float16x8_t _r20 = vld1q_f16(r2); + float16x8_t _r20n = vld1q_f16(r2 + 8); + float16x8_t _r21 = vextq_f16(_r20, _r20n, 1); + float16x8_t _r22 = vextq_f16(_r20, _r20n, 2); + + float16x8_t _r30 = vld1q_f16(r3); + float16x8_t _r30n = vld1q_f16(r3 + 8); + float16x8_t _r31 = vextq_f16(_r30, _r30n, 1); + float16x8_t _r32 = vextq_f16(_r30, _r30n, 2); + + _sum1 = vmlaq_low_lane_f16(_sum1, _r00, _k01234567, 0); + _sum2 = vmlaq_low_lane_f16(_sum2, _r01, _k01234567, 1); + _sum1 = vmlaq_low_lane_f16(_sum1, _r02, _k01234567, 2); + _sum2 = vmlaq_low_lane_f16(_sum2, _r10, _k01234567, 3); + _sum1 = vmlaq_high_lane_f16(_sum1, _r11, _k01234567, 4); + _sum2 = vmlaq_high_lane_f16(_sum2, _r12, _k01234567, 5); + _sum1 = vmlaq_high_lane_f16(_sum1, _r20, _k01234567, 6); + _sum2 = vmlaq_high_lane_f16(_sum2, _r21, _k01234567, 7); + _sum1 = vmlaq_high_lane_f16(_sum1, _r22, _k12345678, 7); + + _sum3 = vmlaq_low_lane_f16(_sum3, _r10, _k01234567, 0); + _sum4 = vmlaq_low_lane_f16(_sum4, _r11, _k01234567, 1); + _sum3 = vmlaq_low_lane_f16(_sum3, _r12, _k01234567, 2); + _sum4 = vmlaq_low_lane_f16(_sum4, _r20, _k01234567, 3); + _sum3 = vmlaq_high_lane_f16(_sum3, _r21, _k01234567, 4); + _sum4 = vmlaq_high_lane_f16(_sum4, _r22, _k01234567, 5); + _sum3 = vmlaq_high_lane_f16(_sum3, _r30, _k01234567, 6); + _sum4 = vmlaq_high_lane_f16(_sum4, _r31, _k01234567, 7); + _sum3 = vmlaq_high_lane_f16(_sum3, _r32, _k12345678, 7); + + _sum1 = vaddq_f16(_sum1, _sum2); + _sum3 = vaddq_f16(_sum3, _sum4); + + vst1q_f16(outptr, _sum1); + vst1q_f16(outptr2, _sum3); + + r0 += 8; + r1 += 8; + r2 += 8; + r3 += 8; + outptr += 8; + outptr2 += 8; + } + + r0 += tail_step + IW; + r1 += tail_step + IW; + r2 += tail_step + IW; + r3 += tail_step + IW; + + outptr += OW; + outptr2 += OW; + } + + for (; h < OH; h++) { + int width = OW >> 3; + + rep(i, width) { + float16x8_t _sum1 = vld1q_f16(outptr); + float16x8_t _sum2 = vdupq_n_f16(0.f); + + float16x8_t _r00 = vld1q_f16(r0); + float16x8_t _r00n = vld1q_f16(r0 + 8); + float16x8_t _r01 = vextq_f16(_r00, _r00n, 1); + float16x8_t _r02 = vextq_f16(_r00, _r00n, 2); + + float16x8_t _r10 = vld1q_f16(r1); + float16x8_t _r10n = vld1q_f16(r1 + 8); + float16x8_t _r11 = vextq_f16(_r10, _r10n, 1); + float16x8_t _r12 = vextq_f16(_r10, _r10n, 2); + + float16x8_t _r20 = vld1q_f16(r2); + float16x8_t _r20n = vld1q_f16(r2 + 8); + float16x8_t _r21 = vextq_f16(_r20, _r20n, 1); + float16x8_t _r22 = vextq_f16(_r20, _r20n, 2); + + _sum1 = vmlaq_low_lane_f16(_sum1, _r00, _k01234567, 0); + _sum2 = vmlaq_low_lane_f16(_sum2, _r01, _k01234567, 1); + _sum1 = vmlaq_low_lane_f16(_sum1, _r02, _k01234567, 2); + _sum2 = vmlaq_low_lane_f16(_sum2, _r10, _k01234567, 3); + _sum1 = vmlaq_high_lane_f16(_sum1, _r11, _k01234567, 4); + _sum2 = vmlaq_high_lane_f16(_sum2, _r12, _k01234567, 5); + _sum1 = vmlaq_high_lane_f16(_sum1, _r20, _k01234567, 6); + _sum2 = 
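The 3x3 stride-1 kernel loads sixteen contiguous inputs per row and derives the shift-by-one and shift-by-two windows with vextq_f16 instead of issuing unaligned reloads; it also computes two output rows per pass so rows r1 and r2 are loaded once but used twice. A minimal sketch of the horizontal part (assuming armv8.2-a+fp16; hconv3_stride1 is a hypothetical helper and uses the ACLE vfmaq_f16 where the original goes through the vmlaq_*_lane_f16 helpers from marm_neon.h):

    #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    #include <arm_neon.h>

    static inline float16x8_t hconv3_stride1(const __fp16* row, float16x8_t k0,
                                             float16x8_t k1, float16x8_t k2,
                                             float16x8_t acc) {
        float16x8_t a  = vld1q_f16(row);        // x[0..7]
        float16x8_t b  = vld1q_f16(row + 8);    // x[8..15]
        float16x8_t s1 = vextq_f16(a, b, 1);    // x[1..8]
        float16x8_t s2 = vextq_f16(a, b, 2);    // x[2..9]
        acc = vfmaq_f16(acc, a,  k0);           // tap 0
        acc = vfmaq_f16(acc, s1, k1);           // tap 1
        acc = vfmaq_f16(acc, s2, k2);           // tap 2
        return acc;                             // 8 outputs of one filter row
    }
    #endif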
vmlaq_high_lane_f16(_sum2, _r21, _k01234567, 7); + _sum1 = vmlaq_high_lane_f16(_sum1, _r22, _k12345678, 7); + + _sum1 = vaddq_f16(_sum1, _sum2); + + vst1q_f16(outptr, _sum1); + + r0 += 8; + r1 += 8; + r2 += 8; + outptr += 8; + } + r0 += tail_step; + r1 += tail_step; + r2 += tail_step; + } + + filter += 9; + } +} + +void conv_stride1::do_conv_5x5_stride1(const __fp16* src, const __fp16* filter, __fp16* dst, + size_t IH, size_t IW, size_t OH, size_t OW, + size_t IC) { + const size_t tail_step = IW - OW; + + rep(ic, IC) { + const __fp16* src_ptr = src + IW * IH * ic; + __fp16* outptr = dst; + __fp16* outptr2 = outptr + OW; + + const __fp16* r0 = src_ptr; + const __fp16* r1 = src_ptr + IW; + const __fp16* r2 = src_ptr + IW * 2; + const __fp16* r3 = src_ptr + IW * 3; + const __fp16* r4 = src_ptr + IW * 4; + const __fp16* r5 = src_ptr + IW * 5; + + float16x8_t _k0 = vld1q_f16(filter); + float16x8_t _k1 = vld1q_f16(filter + 8); + float16x8_t _k2 = vld1q_f16(filter + 16); + float16x8_t _k3 = vld1q_f16(filter + 17); + + size_t h = 0; + for (; h + 1 < OH; h += 2) { + int width = OW >> 3; + + rep(i, width) { + float16x8_t _sum = vld1q_f16(outptr); + float16x8_t _sum2 = vld1q_f16(outptr2); + + float16x8_t _r00 = vld1q_f16(r0); + float16x8_t _r05 = vld1q_f16(r0 + 8); + float16x8_t _r01 = vextq_f16(_r00, _r05, 1); + float16x8_t _r02 = vextq_f16(_r00, _r05, 2); + float16x8_t _r03 = vextq_f16(_r00, _r05, 3); + float16x8_t _r04 = vextq_f16(_r00, _r05, 4); + + float16x8_t _r10 = vld1q_f16(r1); + float16x8_t _r15 = vld1q_f16(r1 + 8); + float16x8_t _r11 = vextq_f16(_r10, _r15, 1); + float16x8_t _r12 = vextq_f16(_r10, _r15, 2); + float16x8_t _r13 = vextq_f16(_r10, _r15, 3); + float16x8_t _r14 = vextq_f16(_r10, _r15, 4); + + float16x8_t _r20 = vld1q_f16(r2); + float16x8_t _r25 = vld1q_f16(r2 + 8); + float16x8_t _r21 = vextq_f16(_r20, _r25, 1); + float16x8_t _r22 = vextq_f16(_r20, _r25, 2); + float16x8_t _r23 = vextq_f16(_r20, _r25, 3); + float16x8_t _r24 = vextq_f16(_r20, _r25, 4); + + float16x8_t _r30 = vld1q_f16(r3); + float16x8_t _r35 = vld1q_f16(r3 + 8); + float16x8_t _r31 = vextq_f16(_r30, _r35, 1); + float16x8_t _r32 = vextq_f16(_r30, _r35, 2); + float16x8_t _r33 = vextq_f16(_r30, _r35, 3); + float16x8_t _r34 = vextq_f16(_r30, _r35, 4); + + float16x8_t _r40 = vld1q_f16(r4); + float16x8_t _r45 = vld1q_f16(r4 + 8); + float16x8_t _r41 = vextq_f16(_r40, _r45, 1); + float16x8_t _r42 = vextq_f16(_r40, _r45, 2); + float16x8_t _r43 = vextq_f16(_r40, _r45, 3); + float16x8_t _r44 = vextq_f16(_r40, _r45, 4); + + float16x8_t _r50 = vld1q_f16(r5); + float16x8_t _r55 = vld1q_f16(r5 + 8); + float16x8_t _r51 = vextq_f16(_r50, _r55, 1); + float16x8_t _r52 = vextq_f16(_r50, _r55, 2); + float16x8_t _r53 = vextq_f16(_r50, _r55, 3); + float16x8_t _r54 = vextq_f16(_r50, _r55, 4); + + _sum = vmlaq_low_lane_f16(_sum, _r00, _k0, 0); + _sum = vmlaq_low_lane_f16(_sum, _r01, _k0, 1); + _sum = vmlaq_low_lane_f16(_sum, _r02, _k0, 2); + _sum = vmlaq_low_lane_f16(_sum, _r03, _k0, 3); + _sum = vmlaq_high_lane_f16(_sum, _r04, _k0, 4); + + _sum = vmlaq_high_lane_f16(_sum, _r10, _k0, 5); + _sum = vmlaq_high_lane_f16(_sum, _r11, _k0, 6); + _sum = vmlaq_high_lane_f16(_sum, _r12, _k0, 7); + _sum = vmlaq_low_lane_f16(_sum, _r13, _k1, 0); + _sum = vmlaq_low_lane_f16(_sum, _r14, _k1, 1); + + _sum = vmlaq_low_lane_f16(_sum, _r20, _k1, 2); + _sum = vmlaq_low_lane_f16(_sum, _r21, _k1, 3); + _sum = vmlaq_high_lane_f16(_sum, _r22, _k1, 4); + _sum = vmlaq_high_lane_f16(_sum, _r23, _k1, 5); + _sum = vmlaq_high_lane_f16(_sum, _r24, _k1, 6); + + _sum = 
vmlaq_high_lane_f16(_sum, _r30, _k1, 7); + _sum = vmlaq_low_lane_f16(_sum, _r31, _k2, 0); + _sum = vmlaq_low_lane_f16(_sum, _r32, _k2, 1); + _sum = vmlaq_low_lane_f16(_sum, _r33, _k2, 2); + _sum = vmlaq_low_lane_f16(_sum, _r34, _k2, 3); + + _sum = vmlaq_high_lane_f16(_sum, _r40, _k2, 4); + _sum = vmlaq_high_lane_f16(_sum, _r41, _k2, 5); + _sum = vmlaq_high_lane_f16(_sum, _r42, _k2, 6); + _sum = vmlaq_high_lane_f16(_sum, _r43, _k2, 7); + _sum = vmlaq_high_lane_f16(_sum, _r44, _k3, 7); + + _sum2 = vmlaq_low_lane_f16(_sum2, _r10, _k0, 0); + _sum2 = vmlaq_low_lane_f16(_sum2, _r11, _k0, 1); + _sum2 = vmlaq_low_lane_f16(_sum2, _r12, _k0, 2); + _sum2 = vmlaq_low_lane_f16(_sum2, _r13, _k0, 3); + _sum2 = vmlaq_high_lane_f16(_sum2, _r14, _k0, 4); + + _sum2 = vmlaq_high_lane_f16(_sum2, _r20, _k0, 5); + _sum2 = vmlaq_high_lane_f16(_sum2, _r21, _k0, 6); + _sum2 = vmlaq_high_lane_f16(_sum2, _r22, _k0, 7); + _sum2 = vmlaq_low_lane_f16(_sum2, _r23, _k1, 0); + _sum2 = vmlaq_low_lane_f16(_sum2, _r24, _k1, 1); + + _sum2 = vmlaq_low_lane_f16(_sum2, _r30, _k1, 2); + _sum2 = vmlaq_low_lane_f16(_sum2, _r31, _k1, 3); + _sum2 = vmlaq_high_lane_f16(_sum2, _r32, _k1, 4); + _sum2 = vmlaq_high_lane_f16(_sum2, _r33, _k1, 5); + _sum2 = vmlaq_high_lane_f16(_sum2, _r34, _k1, 6); + + _sum2 = vmlaq_high_lane_f16(_sum2, _r40, _k1, 7); + _sum2 = vmlaq_low_lane_f16(_sum2, _r41, _k2, 0); + _sum2 = vmlaq_low_lane_f16(_sum2, _r42, _k2, 1); + _sum2 = vmlaq_low_lane_f16(_sum2, _r43, _k2, 2); + _sum2 = vmlaq_low_lane_f16(_sum2, _r44, _k2, 3); + + _sum2 = vmlaq_high_lane_f16(_sum2, _r50, _k2, 4); + _sum2 = vmlaq_high_lane_f16(_sum2, _r51, _k2, 5); + _sum2 = vmlaq_high_lane_f16(_sum2, _r52, _k2, 6); + _sum2 = vmlaq_high_lane_f16(_sum2, _r53, _k2, 7); + _sum2 = vmlaq_high_lane_f16(_sum2, _r54, _k3, 7); + + vst1q_f16(outptr, _sum); + vst1q_f16(outptr2, _sum2); + + r0 += 8; + r1 += 8; + r2 += 8; + r3 += 8; + r4 += 8; + r5 += 8; + outptr += 8; + outptr2 += 8; + } + + r0 += tail_step + IW; + r1 += tail_step + IW; + r2 += tail_step + IW; + r3 += tail_step + IW; + r4 += tail_step + IW; + r5 += tail_step + IW; + + outptr += OW; + outptr2 += OW; + } + + for (; h < OH; h++) { + int width = OW >> 3; + + rep(i, width) { + float16x8_t _sum = vld1q_f16(outptr); + + float16x8_t _r00 = vld1q_f16(r0); + float16x8_t _r05 = vld1q_f16(r0 + 8); + float16x8_t _r01 = vextq_f16(_r00, _r05, 1); + float16x8_t _r02 = vextq_f16(_r00, _r05, 2); + float16x8_t _r03 = vextq_f16(_r00, _r05, 3); + float16x8_t _r04 = vextq_f16(_r00, _r05, 4); + + float16x8_t _r10 = vld1q_f16(r1); + float16x8_t _r15 = vld1q_f16(r1 + 8); + float16x8_t _r11 = vextq_f16(_r10, _r15, 1); + float16x8_t _r12 = vextq_f16(_r10, _r15, 2); + float16x8_t _r13 = vextq_f16(_r10, _r15, 3); + float16x8_t _r14 = vextq_f16(_r10, _r15, 4); + + float16x8_t _r20 = vld1q_f16(r2); + float16x8_t _r25 = vld1q_f16(r2 + 8); + float16x8_t _r21 = vextq_f16(_r20, _r25, 1); + float16x8_t _r22 = vextq_f16(_r20, _r25, 2); + float16x8_t _r23 = vextq_f16(_r20, _r25, 3); + float16x8_t _r24 = vextq_f16(_r20, _r25, 4); + + float16x8_t _r30 = vld1q_f16(r3); + float16x8_t _r35 = vld1q_f16(r3 + 8); + float16x8_t _r31 = vextq_f16(_r30, _r35, 1); + float16x8_t _r32 = vextq_f16(_r30, _r35, 2); + float16x8_t _r33 = vextq_f16(_r30, _r35, 3); + float16x8_t _r34 = vextq_f16(_r30, _r35, 4); + + float16x8_t _r40 = vld1q_f16(r4); + float16x8_t _r45 = vld1q_f16(r4 + 8); + float16x8_t _r41 = vextq_f16(_r40, _r45, 1); + float16x8_t _r42 = vextq_f16(_r40, _r45, 2); + float16x8_t _r43 = vextq_f16(_r40, _r45, 3); + float16x8_t _r44 = 
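Throughout the 5x5 kernel the 25 filter taps stay resident in the q registers _k0.._k3, and each tap is broadcast from a register lane inside the multiply-accumulate, so no per-tap vdupq_n_f16 is needed in the inner loop. A small sketch of that lane-broadcast accumulate (AArch64 with fp16 assumed; mac_row5 is a hypothetical helper and uses the ACLE vfmaq_laneq_f16 where the original relies on MegEngine's vmlaq_low/high_lane_f16 wrappers):

    #if defined(__aarch64__) && __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    #include <arm_neon.h>

    static inline float16x8_t mac_row5(float16x8_t acc, float16x8_t x0,
                                       float16x8_t x1, float16x8_t x2,
                                       float16x8_t x3, float16x8_t x4,
                                       float16x8_t k /* taps in lanes 0..4 */) {
        acc = vfmaq_laneq_f16(acc, x0, k, 0);
        acc = vfmaq_laneq_f16(acc, x1, k, 1);
        acc = vfmaq_laneq_f16(acc, x2, k, 2);
        acc = vfmaq_laneq_f16(acc, x3, k, 3);
        acc = vfmaq_laneq_f16(acc, x4, k, 4);
        return acc;
    }
    #endif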
vextq_f16(_r40, _r45, 4); + + _sum = vmlaq_low_lane_f16(_sum, _r00, _k0, 0); + _sum = vmlaq_low_lane_f16(_sum, _r01, _k0, 1); + _sum = vmlaq_low_lane_f16(_sum, _r02, _k0, 2); + _sum = vmlaq_low_lane_f16(_sum, _r03, _k0, 3); + _sum = vmlaq_high_lane_f16(_sum, _r04, _k0, 4); + + _sum = vmlaq_high_lane_f16(_sum, _r10, _k0, 5); + _sum = vmlaq_high_lane_f16(_sum, _r11, _k0, 6); + _sum = vmlaq_high_lane_f16(_sum, _r12, _k0, 7); + _sum = vmlaq_low_lane_f16(_sum, _r13, _k1, 0); + _sum = vmlaq_low_lane_f16(_sum, _r14, _k1, 1); + + _sum = vmlaq_low_lane_f16(_sum, _r20, _k1, 2); + _sum = vmlaq_low_lane_f16(_sum, _r21, _k1, 3); + _sum = vmlaq_high_lane_f16(_sum, _r22, _k1, 4); + _sum = vmlaq_high_lane_f16(_sum, _r23, _k1, 5); + _sum = vmlaq_high_lane_f16(_sum, _r24, _k1, 6); + + _sum = vmlaq_high_lane_f16(_sum, _r30, _k1, 7); + _sum = vmlaq_low_lane_f16(_sum, _r31, _k2, 0); + _sum = vmlaq_low_lane_f16(_sum, _r32, _k2, 1); + _sum = vmlaq_low_lane_f16(_sum, _r33, _k2, 2); + _sum = vmlaq_low_lane_f16(_sum, _r34, _k2, 3); + + _sum = vmlaq_high_lane_f16(_sum, _r40, _k2, 4); + _sum = vmlaq_high_lane_f16(_sum, _r41, _k2, 5); + _sum = vmlaq_high_lane_f16(_sum, _r42, _k2, 6); + _sum = vmlaq_high_lane_f16(_sum, _r43, _k2, 7); + _sum = vmlaq_high_lane_f16(_sum, _r44, _k3, 7); + + vst1q_f16(outptr, _sum); + + r0 += 8; + r1 += 8; + r2 += 8; + r3 += 8; + r4 += 8; + outptr += 8; + } + + r0 += tail_step; + r1 += tail_step; + r2 += tail_step; + r3 += tail_step; + r4 += tail_step; + } + + filter += 25; + } +} +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/f16/do_conv_stride1.h b/dnn/src/arm_common/conv_bias/f16/do_conv_stride1.h new file mode 100644 index 00000000..95159691 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/f16/do_conv_stride1.h @@ -0,0 +1,32 @@ +/** + * \file dnn/src/arm_common/conv_bias/f16/do_conv_stride1.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once + +#include "src/fallback/conv_bias/opr_impl.h" + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +namespace megdnn { +namespace arm_common { +namespace fp16 { +namespace conv_stride1 { +void do_conv_2x2_stride1(const __fp16* src, const __fp16* filter, __fp16* dst, + size_t IH, size_t IW, size_t OH, size_t OW, size_t IC); +void do_conv_3x3_stride1(const __fp16* src, const __fp16* filter, __fp16* dst, + size_t IH, size_t IW, size_t OH, size_t OW, size_t IC); +void do_conv_5x5_stride1(const __fp16* src, const __fp16* filter, __fp16* dst, + size_t IH, size_t IW, size_t OH, size_t OW, size_t IC); +} // namespace conv_stride1 +} // namespace fp16 +} // namespace arm_common +} // namespace megdnn +#endif + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/f16/helper.h b/dnn/src/arm_common/conv_bias/f16/helper.h new file mode 100644 index 00000000..8e88e3e2 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/f16/helper.h @@ -0,0 +1,341 @@ +/** + * \file dnn/src/arm_common/conv_bias/f16/helper.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once +#include "src/common/unroll_macro.h" +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#define MATRIX_MUL4x4_fp16(sum, a, b) \ + sum##0 = vmul_lane_f16(b##0, a##0, 0); \ + sum##1 = vmul_lane_f16(b##0, a##1, 0); \ + sum##2 = vmul_lane_f16(b##0, a##2, 0); \ + sum##3 = vmul_lane_f16(b##0, a##3, 0); \ + sum##0 = vadd_f16(sum##0, vmul_lane_f16(b##1, a##0, 1)); \ + sum##1 = vadd_f16(sum##1, vmul_lane_f16(b##1, a##1, 1)); \ + sum##2 = vadd_f16(sum##2, vmul_lane_f16(b##1, a##2, 1)); \ + sum##3 = vadd_f16(sum##3, vmul_lane_f16(b##1, a##3, 1)); \ + sum##0 = vadd_f16(sum##0, vmul_lane_f16(b##2, a##0, 2)); \ + sum##1 = vadd_f16(sum##1, vmul_lane_f16(b##2, a##1, 2)); \ + sum##2 = vadd_f16(sum##2, vmul_lane_f16(b##2, a##2, 2)); \ + sum##3 = vadd_f16(sum##3, vmul_lane_f16(b##2, a##3, 2)); \ + sum##0 = vadd_f16(sum##0, vmul_lane_f16(b##3, a##0, 3)); \ + sum##1 = vadd_f16(sum##1, vmul_lane_f16(b##3, a##1, 3)); \ + sum##2 = vadd_f16(sum##2, vmul_lane_f16(b##3, a##2, 3)); \ + sum##3 = vadd_f16(sum##3, vmul_lane_f16(b##3, a##3, 3)); + +#define CONCAT(a, id) a##id + +#if MEGDNN_AARCH64 + +#define TRANSPOSE_4x4(a, ret) \ + do { \ + auto b00 = vzip1_f16(CONCAT(a, 0).value, \ + CONCAT(a, 1).value); /*a1b1a2b2*/ \ + auto b01 = vzip2_f16(CONCAT(a, 0).value, \ + CONCAT(a, 1).value); /*a3b3a4b4*/ \ + auto b10 = vzip1_f16(CONCAT(a, 2).value, \ + CONCAT(a, 3).value); /*c1d1c2d2*/ \ + auto b11 = vzip2_f16(CONCAT(a, 2).value, \ + CONCAT(a, 3).value); /*c3d3c4d4*/ \ + auto s32b00 = vreinterpret_s32_f16(b00); \ + auto s32b01 = vreinterpret_s32_f16(b01); \ + auto s32b10 = vreinterpret_s32_f16(b10); \ + auto s32b11 = vreinterpret_s32_f16(b11); \ + CONCAT(ret, 0).value = \ + vreinterpret_f16_s32(vzip1_s32(s32b00, s32b10)); \ + CONCAT(ret, 1).value = \ + vreinterpret_f16_s32(vzip2_s32(s32b00, s32b10)); \ + CONCAT(ret, 2).value = \ + vreinterpret_f16_s32(vzip1_s32(s32b01, s32b11)); \ + CONCAT(ret, 3).value = \ + vreinterpret_f16_s32(vzip2_s32(s32b01, s32b11)); \ + } while (0); + +#define TRANSPOSE_4x8(a, ret) \ + do { \ + auto b00 = vzip1q_f16(CONCAT(a, 0).value, \ + CONCAT(a, 1).value); /*a1b1a2b2a3b3a4b4*/ \ + auto b01 = vzip2q_f16(CONCAT(a, 0).value, \ + CONCAT(a, 1).value); /*a5b5a6b6a7b7a8b8*/ \ + auto b10 = vzip1q_f16(CONCAT(a, 2).value, \ + CONCAT(a, 3).value); /*c1d1c2d2c3d3c4d4*/ \ + auto b11 = vzip2q_f16(CONCAT(a, 2).value, \ + CONCAT(a, 3).value); /*c5d5c6d6c7d7c8d8*/ \ + auto s32b00 = vreinterpretq_s32_f16(b00); \ + auto s32b01 = vreinterpretq_s32_f16(b01); \ + auto s32b10 = vreinterpretq_s32_f16(b10); \ + auto s32b11 = vreinterpretq_s32_f16(b11); \ + auto f16b00 = vreinterpretq_f16_s32( \ + vzip1q_s32(s32b00, s32b10)); /*a1b1c1d1a2b2c2d2*/ \ + auto f16b01 = vreinterpretq_f16_s32( \ + vzip2q_s32(s32b00, s32b10)); /*a3b3c3d3a4b4a4d4*/ \ + auto f16b10 = vreinterpretq_f16_s32( \ + vzip1q_s32(s32b01, s32b11)); /*a5b5c5d5a6b6c6d6*/ \ + auto f16b11 = vreinterpretq_f16_s32( \ + vzip2q_s32(s32b01, s32b11)); /*a7b7c7d7a8b8c8d8*/ \ + CONCAT(ret, 0).value = vget_low_f16(f16b00); \ + CONCAT(ret, 1).value = vget_high_f16(f16b00); \ + CONCAT(ret, 2).value = vget_low_f16(f16b01); \ + CONCAT(ret, 3).value = vget_high_f16(f16b01); \ + CONCAT(ret, 4).value = vget_low_f16(f16b10); \ + CONCAT(ret, 5).value = vget_high_f16(f16b10); \ + CONCAT(ret, 6).value = vget_low_f16(f16b11); \ + CONCAT(ret, 7).value 
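TRANSPOSE_4x4 is a zip-based transpose: interleave row pairs at 16-bit granularity, reinterpret the results as 32-bit lanes, and interleave again, which lands every column in its own register. The same idea written as a plain function instead of a macro (AArch64 with fp16 assumed; transpose4x4_f16 is a hypothetical name):

    #if defined(__aarch64__) && __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    #include <arm_neon.h>

    // In-place 4x4 transpose of __fp16 elements held in four d registers.
    static inline void transpose4x4_f16(float16x4_t r[4]) {
        float16x4_t ab_lo = vzip1_f16(r[0], r[1]);   // a0 b0 a1 b1
        float16x4_t ab_hi = vzip2_f16(r[0], r[1]);   // a2 b2 a3 b3
        float16x4_t cd_lo = vzip1_f16(r[2], r[3]);   // c0 d0 c1 d1
        float16x4_t cd_hi = vzip2_f16(r[2], r[3]);   // c2 d2 c3 d3
        int32x2_t ab_lo32 = vreinterpret_s32_f16(ab_lo);
        int32x2_t ab_hi32 = vreinterpret_s32_f16(ab_hi);
        int32x2_t cd_lo32 = vreinterpret_s32_f16(cd_lo);
        int32x2_t cd_hi32 = vreinterpret_s32_f16(cd_hi);
        r[0] = vreinterpret_f16_s32(vzip1_s32(ab_lo32, cd_lo32)); // a0 b0 c0 d0
        r[1] = vreinterpret_f16_s32(vzip2_s32(ab_lo32, cd_lo32)); // a1 b1 c1 d1
        r[2] = vreinterpret_f16_s32(vzip1_s32(ab_hi32, cd_hi32)); // a2 b2 c2 d2
        r[3] = vreinterpret_f16_s32(vzip2_s32(ab_hi32, cd_hi32)); // a3 b3 c3 d3
    }
    #endif

TRANSPOSE_4x8, TRANSPOSE_8x4 and TRANSPOSE_8x8 in the diff extend the same two-level zip scheme to q registers, adding a final 64-bit interleave for the 8x8 case.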
= vget_high_f16(f16b11); \ + } while (0); + +#define TRANSPOSE_8x4(a, ret) \ + do { \ + auto b00 = vzip1_f16(CONCAT(a, 0).value, \ + CONCAT(a, 1).value); /*a1b1a2b2*/ \ + auto b01 = vzip2_f16(CONCAT(a, 0).value, \ + CONCAT(a, 1).value); /*a3b3a4b4*/ \ + auto b10 = vzip1_f16(CONCAT(a, 2).value, \ + CONCAT(a, 3).value); /*c1d1c2d2*/ \ + auto b11 = vzip2_f16(CONCAT(a, 2).value, \ + CONCAT(a, 3).value); /*c3d3c4d4*/ \ + auto b20 = vzip1_f16(CONCAT(a, 4).value, \ + CONCAT(a, 5).value); /*e1f1e2f2*/ \ + auto b21 = vzip2_f16(CONCAT(a, 4).value, \ + CONCAT(a, 5).value); /*e3f3e4f4*/ \ + auto b30 = vzip1_f16(CONCAT(a, 6).value, \ + CONCAT(a, 7).value); /*g1h1g2h2*/ \ + auto b31 = vzip2_f16(CONCAT(a, 6).value, \ + CONCAT(a, 7).value); /*g3h3g4h4*/ \ + auto s32b00 = vreinterpret_s32_f16(b00); \ + auto s32b01 = vreinterpret_s32_f16(b01); \ + auto s32b10 = vreinterpret_s32_f16(b10); \ + auto s32b11 = vreinterpret_s32_f16(b11); \ + auto s32b20 = vreinterpret_s32_f16(b20); \ + auto s32b21 = vreinterpret_s32_f16(b21); \ + auto s32b30 = vreinterpret_s32_f16(b30); \ + auto s32b31 = vreinterpret_s32_f16(b31); \ + CONCAT(ret, 0).value = \ + vcombine_f16(vreinterpret_f16_s32(vzip1_s32(s32b00, s32b10)), \ + vreinterpret_f16_s32(vzip1_s32(s32b20, s32b30))); \ + CONCAT(ret, 1).value = \ + vcombine_f16(vreinterpret_f16_s32(vzip2_s32(s32b00, s32b10)), \ + vreinterpret_f16_s32(vzip2_s32(s32b20, s32b30))); \ + CONCAT(ret, 2).value = \ + vcombine_f16(vreinterpret_f16_s32(vzip1_s32(s32b01, s32b11)), \ + vreinterpret_f16_s32(vzip1_s32(s32b21, s32b31))); \ + CONCAT(ret, 3).value = \ + vcombine_f16(vreinterpret_f16_s32(vzip2_s32(s32b01, s32b11)), \ + vreinterpret_f16_s32(vzip2_s32(s32b21, s32b31))); \ + } while (0); + +#define TRANSPOSE_8x8(a, ret) \ + do { \ + auto b00 = vzip1q_f16(CONCAT(a, 0).value, \ + CONCAT(a, 1).value); /*a1b1a2b2 a3b3a4b4*/ \ + auto b01 = vzip2q_f16(CONCAT(a, 0).value, \ + CONCAT(a, 1).value); /*a5b5a6b6 a7b7a8b8*/ \ + auto b10 = vzip1q_f16(CONCAT(a, 2).value, \ + CONCAT(a, 3).value); /*c1d1c2d2 c3d3c4d4*/ \ + auto b11 = vzip2q_f16(CONCAT(a, 2).value, \ + CONCAT(a, 3).value); /*c5d5c6d6 c7d7c8d8*/ \ + auto b20 = vzip1q_f16(CONCAT(a, 4).value, \ + CONCAT(a, 5).value); /*e1f1e2f2 e3f3e4f4*/ \ + auto b21 = vzip2q_f16(CONCAT(a, 4).value, \ + CONCAT(a, 5).value); /*e5f5e6f6 e7f7e8f8*/ \ + auto b30 = vzip1q_f16(CONCAT(a, 6).value, \ + CONCAT(a, 7).value); /*g1h1g2h2 g3h3g4h4*/ \ + auto b31 = vzip2q_f16(CONCAT(a, 6).value, \ + CONCAT(a, 7).value); /*g5h5g6h6 g7h7g8h8*/ \ + auto s32b00 = vreinterpretq_s32_f16(b00); \ + auto s32b01 = vreinterpretq_s32_f16(b01); \ + auto s32b10 = vreinterpretq_s32_f16(b10); \ + auto s32b11 = vreinterpretq_s32_f16(b11); \ + auto s32b20 = vreinterpretq_s32_f16(b20); \ + auto s32b21 = vreinterpretq_s32_f16(b21); \ + auto s32b30 = vreinterpretq_s32_f16(b30); \ + auto s32b31 = vreinterpretq_s32_f16(b31); \ + auto s64b00 = vreinterpretq_s64_s32( \ + vzip1q_s32(s32b00, s32b10)); /*a1b1c1d1 a2b2c2d2*/ \ + auto s64b01 = vreinterpretq_s64_s32( \ + vzip2q_s32(s32b00, s32b10)); /*a3b3c3d3 a4b4c4d4*/ \ + auto s64b10 = vreinterpretq_s64_s32( \ + vzip1q_s32(s32b01, s32b11)); /*a5b5c5d5 a6b6c6d6*/ \ + auto s64b11 = vreinterpretq_s64_s32( \ + vzip2q_s32(s32b01, s32b11)); /*a7b7c7d7 a8b8c8d8*/ \ + auto s64b20 = vreinterpretq_s64_s32( \ + vzip1q_s32(s32b20, s32b30)); /*e1f1g1h1 e2f2g2h2*/ \ + auto s64b21 = vreinterpretq_s64_s32( \ + vzip2q_s32(s32b20, s32b30)); /*e3f3g3h3 e4f4g4h4*/ \ + auto s64b30 = vreinterpretq_s64_s32( \ + vzip1q_s32(s32b21, s32b31)); /*e5f5g5h5 e6f6g6h6*/ \ + auto s64b31 
= vreinterpretq_s64_s32( \ + vzip2q_s32(s32b21, s32b31)); /*e7f7g7h7 e8f8g8h8*/ \ + CONCAT(ret, 0).value = \ + vreinterpretq_f16_s64(vzip1q_s64(s64b00, s64b20)); \ + CONCAT(ret, 1).value = \ + vreinterpretq_f16_s64(vzip2q_s64(s64b00, s64b20)); \ + CONCAT(ret, 2).value = \ + vreinterpretq_f16_s64(vzip1q_s64(s64b01, s64b21)); \ + CONCAT(ret, 3).value = \ + vreinterpretq_f16_s64(vzip2q_s64(s64b01, s64b21)); \ + CONCAT(ret, 4).value = \ + vreinterpretq_f16_s64(vzip1q_s64(s64b10, s64b30)); \ + CONCAT(ret, 5).value = \ + vreinterpretq_f16_s64(vzip2q_s64(s64b10, s64b30)); \ + CONCAT(ret, 6).value = \ + vreinterpretq_f16_s64(vzip1q_s64(s64b11, s64b31)); \ + CONCAT(ret, 7).value = \ + vreinterpretq_f16_s64(vzip2q_s64(s64b11, s64b31)); \ + } while (0); + +#else + +#define TRANSPOSE_4x4(a, ret) \ + do { \ + auto b0_01 = vzip_f16(CONCAT(a, 0).value, \ + CONCAT(a, 1).value); /*a1b1a2b2 a3b3a4b4*/ \ + auto b1_01 = vzip_f16(CONCAT(a, 2).value, \ + CONCAT(a, 3).value); /*c1d1c2d2 c3d3c4d4*/ \ + auto s32b00 = vreinterpret_s32_f16(b0_01.val[0]); \ + auto s32b01 = vreinterpret_s32_f16(b0_01.val[1]); \ + auto s32b10 = vreinterpret_s32_f16(b1_01.val[0]); \ + auto s32b11 = vreinterpret_s32_f16(b1_01.val[1]); \ + auto s32b00b10 = vzip_s32(s32b00, s32b10); /*a1b1c1d1 a2b2c2d2*/ \ + auto s32b01b11 = vzip_s32(s32b01, s32b11); /*a3b3c3d3 a4b4c4d4*/ \ + CONCAT(ret, 0).value = vreinterpret_f16_s32(s32b00b10.val[0]); \ + CONCAT(ret, 1).value = vreinterpret_f16_s32(s32b00b10.val[1]); \ + CONCAT(ret, 2).value = vreinterpret_f16_s32(s32b01b11.val[0]); \ + CONCAT(ret, 3).value = vreinterpret_f16_s32(s32b01b11.val[1]); \ + } while (0); + +#define TRANSPOSE_4x8(a, ret) \ + do { \ + auto b0_01 = vzipq_f16( \ + CONCAT(a, 0).value, \ + CONCAT(a, 1).value); /*a1b1a2b2a3b3a4b4 a5b5a6b6a7b7a8b8*/ \ + auto b1_01 = vzipq_f16( \ + CONCAT(a, 2).value, \ + CONCAT(a, 3).value); /*c1d1c2d2c3d3c4d4 c5d6c6d6c7d7c8d8*/ \ + auto s32b00 = vreinterpretq_s32_f16(b0_01.val[0]); \ + auto s32b01 = vreinterpretq_s32_f16(b0_01.val[1]); \ + auto s32b10 = vreinterpretq_s32_f16(b1_01.val[0]); \ + auto s32b11 = vreinterpretq_s32_f16(b1_01.val[1]); \ + auto s32b00b10 = vzipq_s32( \ + s32b00, s32b10); /*a1b1c1d1a2b2c2d2 a3b3c3d3a4b4c4d4*/ \ + auto s32b01b11 = vzipq_s32( \ + s32b01, s32b11); /*a5b5c5d5a6b6c6d6 a7b7c7d7a8b8c8d8*/ \ + CONCAT(ret, 0).value = \ + vreinterpret_f16_s32(vget_low_f16(s32b00b10.val[0])); \ + CONCAT(ret, 1).value = \ + vreinterpret_f16_s32(vget_high_f16(s32b00b10.val[0])); \ + CONCAT(ret, 2).value = \ + vreinterpret_f16_s32(vget_low_f16(s32b00b10.val[1])); \ + CONCAT(ret, 3).value = \ + vreinterpret_f16_s32(vget_high_f16(s32b00b10.val[1])); \ + CONCAT(ret, 4).value = \ + vreinterpret_f16_s32(vget_low_f16(s32b01b11.val[0])); \ + CONCAT(ret, 5).value = \ + vreinterpret_f16_s32(vget_high_f16(s32b01b11.val[0])); \ + CONCAT(ret, 6).value = \ + vreinterpret_f16_s32(vget_low_f16(s32b01b11.val[1])); \ + CONCAT(ret, 7).value = \ + vreinterpret_f16_s32(vget_high_f16(s32b01b11.val[1])); \ + } while (0); + +#define TRANSPOSE_8x4(a, ret) \ + do { \ + auto b0_01 = vzip_f16(CONCAT(a, 0).value, \ + CONCAT(a, 1).value); /*a1b1a2b2 a3b3a4b4*/ \ + auto b1_01 = vzip_f16(CONCAT(a, 2).value, \ + CONCAT(a, 3).value); /*c1d1c2d2 c3d3c4d4*/ \ + auto b2_01 = vzip_f16(CONCAT(a, 4).value, \ + CONCAT(a, 5).value); /*e1f1e2f2 e3f3e4f4*/ \ + auto b3_01 = vzip_f16(CONCAT(a, 6).value, \ + CONCAT(a, 7).value); /*g1h1g2h2 g3h3g4h4*/ \ + auto s32b00 = vreinterpret_s32_f16(b0_01.val[0]); \ + auto s32b01 = vreinterpret_s32_f16(b0_01.val[1]); \ + auto s32b10 = 
vreinterpret_s32_f16(b1_01.val[0]); \ + auto s32b11 = vreinterpret_s32_f16(b1_01.val[1]); \ + auto s32b20 = vreinterpret_s32_f16(b2_01.val[0]); \ + auto s32b21 = vreinterpret_s32_f16(b2_01.val[1]); \ + auto s32b30 = vreinterpret_s32_f16(b3_01.val[0]); \ + auto s32b31 = vreinterpret_s32_f16(b3_01.val[1]); \ + auto s32b00b10 = vzip_s32(s32b00, s32b10); \ + auto s32b01b11 = vzip_s32(s32b01, s32b11); \ + auto s32b20b30 = vzip_s32(s32b20, s32b30); \ + auto s32b21b31 = vzip_s32(s32b21, s32b31); \ + CONCAT(ret, 0).value = \ + vcombine_f16(vreinterpret_f16_s32(s32b00b10.val[0]), \ + vreinterpret_f16_s32(s32b20b30.val[0])); \ + CONCAT(ret, 1).value = \ + vcombine_f16(vreinterpret_f16_s32(s32b00b10.val[1]), \ + vreinterpret_f16_s32(s32b20b30.val[1])); \ + CONCAT(ret, 2).value = \ + vcombine_f16(vreinterpret_f16_s32(s32b01b11.val[0]), \ + vreinterpret_f16_s32(s32b21b31.val[0])); \ + CONCAT(ret, 3).value = \ + vcombine_f16(vreinterpret_f16_s32(s32b01b11.val[1]), \ + vreinterpret_f16_s32(s32b21b31.val[1])); \ + } while (0); + +#define TRANSPOSE_8x8(a, ret) \ + do { \ + auto b00 = vzipq_f16(CONCAT(a, 0).value, \ + CONCAT(a, 1).value); /*a1b1a2b2 a3b3a4b4*/ \ + auto b01 = vzipq_f16(CONCAT(a, 0).value, \ + CONCAT(a, 1).value); /*a5b5a6b6 a7b7a8b8*/ \ + auto b10 = vzipq_f16(CONCAT(a, 2).value, \ + CONCAT(a, 3).value); /*c1d1c2d2 c3d3c4d4*/ \ + auto b11 = vzipq_f16(CONCAT(a, 2).value, \ + CONCAT(a, 3).value); /*c5d5c6d6 c7d7c8d8*/ \ + auto b20 = vzipq_f16(CONCAT(a, 4).value, \ + CONCAT(a, 5).value); /*e1f1e2f2 e3f3e4f4*/ \ + auto b21 = vzipq_f16(CONCAT(a, 4).value, \ + CONCAT(a, 5).value); /*e5f5e6f6 e7f7e8f8*/ \ + auto b30 = vzipq_f16(CONCAT(a, 6).value, \ + CONCAT(a, 7).value); /*g1h1g2h2 g3h3g4h4*/ \ + auto b31 = vzipq_f16(CONCAT(a, 6).value, \ + CONCAT(a, 7).value); /*g5h5g6h6 g7h7g8h8*/ \ + auto s32b00 = vreinterpretq_s32_f16(b00.val[0]); \ + auto s32b01 = vreinterpretq_s32_f16(b01.val[1]); \ + auto s32b10 = vreinterpretq_s32_f16(b10.val[0]); \ + auto s32b11 = vreinterpretq_s32_f16(b11.val[1]); \ + auto s32b20 = vreinterpretq_s32_f16(b20.val[0]); \ + auto s32b21 = vreinterpretq_s32_f16(b21.val[1]); \ + auto s32b30 = vreinterpretq_s32_f16(b30.val[0]); \ + auto s32b31 = vreinterpretq_s32_f16(b31.val[1]); \ + auto s32b00b10 = vzipq_s32(s32b00, s32b10); \ + auto s32b01b11 = vzipq_s32(s32b01, s32b11); \ + auto s32b20b30 = vzipq_s32(s32b20, s32b30); \ + auto s32b21b31 = vzipq_s32(s32b21, s32b31); \ + CONCAT(ret, 0).value = vreinterpretq_f16_s32( \ + vcombine_s32(vget_low_s32(s32b00b10.val[0]), \ + vget_low_s32(s32b20b30.val[0]))); \ + CONCAT(ret, 1).value = vreinterpretq_f16_s32( \ + vcombine_s32(vget_high_s32(s32b00b10.val[0]), \ + vget_high_s32(s32b20b30.val[0]))); \ + CONCAT(ret, 2).value = vreinterpretq_f16_s32( \ + vcombine_s32(vget_low_s32(s32b00b10.val[1]), \ + vget_low_s32(s32b20b30.val[1]))); \ + CONCAT(ret, 3).value = vreinterpretq_f16_s32( \ + vcombine_s32(vget_high_s32(s32b00b10.val[1]), \ + vget_high_s32(s32b20b30.val[1]))); \ + CONCAT(ret, 4).value = vreinterpretq_f16_s32( \ + vcombine_s32(vget_low_s32(s32b01b11.val[0]), \ + vget_low_s32(s32b21b31.val[0]))); \ + CONCAT(ret, 5).value = vreinterpretq_f16_s32( \ + vcombine_s32(vget_high_s32(s32b01b11.val[0]), \ + vget_high_s32(s32b21b31.val[0]))); \ + CONCAT(ret, 6).value = vreinterpretq_f16_s32( \ + vcombine_s32(vget_low_s32(s32b01b11.val[1]), \ + vget_low_s32(s32b21b31.val[1]))); \ + CONCAT(ret, 7).value = vreinterpretq_f16_s32( \ + vcombine_s32(vget_high_s32(s32b01b11.val[1]), \ + vget_high_s32(s32b21b31.val[1]))); \ + } while (0); + +#endif 
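[Reviewer note, not part of the patch] The TRANSPOSE_4x4/4x8/8x4/8x8 macros above all use the same trick: zip adjacent fp16 rows at 16-bit granularity, then reinterpret the results as wider lanes (32-bit, plus a 64-bit stage for the 8x8 case) so each further zip moves whole pairs or quads at once. A minimal standalone sketch of the 4x4 case on AArch64 follows; the helper name and the plain-array interface are illustrative assumptions, not code from this patch.

#include <arm_neon.h>

#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(__aarch64__)
// Sketch only: transpose a 4x4 fp16 tile held in four 64-bit registers,
// rows in r[0..3] = {a0..a3}, {b0..b3}, {c0..c3}, {d0..d3}.
static inline void transpose_4x4_f16_sketch(float16x4_t r[4]) {
    // Step 1: interleave row pairs at fp16 granularity:
    // {a0,b0,a1,b1}, {a2,b2,a3,b3}, {c0,d0,c1,d1}, {c2,d2,c3,d3}
    float16x4_t ab_lo = vzip1_f16(r[0], r[1]);
    float16x4_t ab_hi = vzip2_f16(r[0], r[1]);
    float16x4_t cd_lo = vzip1_f16(r[2], r[3]);
    float16x4_t cd_hi = vzip2_f16(r[2], r[3]);
    // Step 2: reinterpret as 32-bit lanes so one more zip moves whole
    // {x,y} pairs, producing the transposed rows {a0,b0,c0,d0}, ...
    int32x2_t ab_lo32 = vreinterpret_s32_f16(ab_lo);
    int32x2_t ab_hi32 = vreinterpret_s32_f16(ab_hi);
    int32x2_t cd_lo32 = vreinterpret_s32_f16(cd_lo);
    int32x2_t cd_hi32 = vreinterpret_s32_f16(cd_hi);
    r[0] = vreinterpret_f16_s32(vzip1_s32(ab_lo32, cd_lo32));
    r[1] = vreinterpret_f16_s32(vzip2_s32(ab_lo32, cd_lo32));
    r[2] = vreinterpret_f16_s32(vzip1_s32(ab_hi32, cd_hi32));
    r[3] = vreinterpret_f16_s32(vzip2_s32(ab_hi32, cd_hi32));
}
#endif

The 8x8 macro is the same idea widened to 128-bit registers with an extra 64-bit zip stage; on armv7, where vzip1/vzip2 are unavailable, the two-output vzip_f16/vzipq_f16 forms used in the #else branch above play the same role.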
+#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/f16/strategy.h b/dnn/src/arm_common/conv_bias/f16/strategy.h new file mode 100644 index 00000000..3e47097f --- /dev/null +++ b/dnn/src/arm_common/conv_bias/f16/strategy.h @@ -0,0 +1,33 @@ +/** + * \file dnn/src/arm_common/conv_bias/f16/strategy.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include "src/arm_common/conv_bias/postprocess_helper.h" +#include "src/fallback/conv_bias/winograd/winograd.h" +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +namespace megdnn { +namespace arm_common { +namespace winograd { + +MEGDNN_REG_WINOGRAD_STRATEGY(dt_float16, dt_float16, dt_float16, dt_float16, 2, + 3, 4, 4, winograd_2x3_4x4_f16) +MEGDNN_REG_WINOGRAD_STRATEGY(dt_float16, dt_float16, dt_float16, dt_float16, 4, + 5, 1, 1, winograd_4x5_1x1_f16) +MEGDNN_REG_WINOGRAD_STRATEGY(dt_float16, dt_float16, dt_float16, dt_float16, 6, + 3, 1, 1, winograd_6x3_1x1_f16) +MEGDNN_REG_WINOGRAD_STRATEGY(dt_float16, dt_float16, dt_float16, dt_float16, 2, + 3, 8, 8, winograd_2x3_8x8_f16) +} // namespace winograd +} // namespace arm_common +} // namespace megdnn +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/f16/strategy_2x3.cpp b/dnn/src/arm_common/conv_bias/f16/strategy_2x3.cpp new file mode 100644 index 00000000..a084ba45 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/f16/strategy_2x3.cpp @@ -0,0 +1,373 @@ +/** + * \file dnn/src/arm_common/conv_bias/f16/strategy_2x3.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#include "src/arm_common/conv_bias/f16/strategy.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/arm_common/utils.h" +#include "src/common/unroll_macro.h" +#include "src/common/utils.h" +#include "src/fallback/conv_bias/winograd/winograd.h" + +#include "src/arm_common/conv_bias/f16/helper.h" +#include "src/arm_common/elemwise_helper/op_unary.h" +#include "src/naive/matrix_mul/matrix_mul_helper.h" + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#include "midout.h" +MIDOUT_DECL(megdnn_arm_common_winograd_fp16_F23) + +using namespace megdnn; +using namespace arm_common; + +namespace { +void transpose_4x4(const __fp16* src, __fp16* dst, int lda, int ldb) { + float16x4x2_t a0, a1; + a0.val[0] = vld1_f16(src + 0 * lda); // a0a1a2a3 + a0.val[1] = vld1_f16(src + 1 * lda); // b0b1b2b3 + a1.val[0] = vld1_f16(src + 2 * lda); // c0c1c2c3 + a1.val[1] = vld1_f16(src + 3 * lda); // d0d1d2d3 + float16x4x2_t b0 = vzip_f16(a0.val[0], a1.val[0]); // a0c0a1c1a2c2a3c3 + float16x4x2_t b1 = vzip_f16(a0.val[1], a1.val[1]); // b0d0b1d1b2d2b3d3 + float16x4x2_t c0 = vzip_f16(b0.val[0], b1.val[0]); // a0b0c0d0a1b1c1d1 + float16x4x2_t c1 = vzip_f16(b0.val[1], b1.val[1]); // a2b2c2d2a3b3c3d3 + vst1_f16(dst + 0 * ldb, c0.val[0]); + vst1_f16(dst + 1 * ldb, c0.val[1]); + vst1_f16(dst + 2 * ldb, c1.val[0]); + vst1_f16(dst + 3 * ldb, c1.val[1]); +} + +struct InputTransform2X3 { + template + static void prepare(const __fp16* input, __fp16* patch, __fp16* patchT, + int ih_start, int iw_start, size_t IH, size_t IW, + size_t ic, size_t IC) { + constexpr size_t alpha = 2 + 3 - 1; + constexpr size_t alpha4 = alpha * 4; + if (!(inner && ic + 4 < IC)) { + memset(patch, 0, sizeof(__fp16) * 4 * alpha * alpha); + } + if (inner) { + const __fp16* input_ptr = + input + ic * IH * IW + ih_start * IW + iw_start; + for (size_t ico = 0; ico < 4; ++ico) { + if (ic + ico < IC) { + auto v0 = vld1_f16(input_ptr); + auto v1 = vld1_f16(input_ptr + IW); + auto v2 = vld1_f16(input_ptr + IW * 2); + auto v3 = vld1_f16(input_ptr + IW * 3); + + vst1_f16(patch + ico * alpha4 + 0 * 4, v0); + vst1_f16(patch + ico * alpha4 + 1 * 4, v1); + vst1_f16(patch + ico * alpha4 + 2 * 4, v2); + vst1_f16(patch + ico * alpha4 + 3 * 4, v3); + input_ptr += IH * IW; + } + } + } else { + int ih0_act = std::max(ih_start, 0), + ih1_act = std::min(ih_start + alpha, IH), + iw0_act = std::max(iw_start, 0), + iw1_act = std::min(iw_start + alpha, IW); + // partial copy + for (size_t ico = 0; ico < 4; ++ico) { + if (ic + ico < IC) { + for (int ih = ih0_act; ih < ih1_act; ++ih) { + for (int iw = iw0_act; iw < iw1_act; ++iw) { + size_t iho = ih - ih_start, iwo = iw - iw_start; + patch[ico * alpha4 + iho * 4 + iwo] = + input[(ic + ico) * IH * IW + ih * IW + iw]; + } + } + } + } + } + + transpose_4x4(patch + 0 * 1, patchT + 0 * 4, 16, 4); + transpose_4x4(patch + 4 * 1, patchT + 4 * 4, 16, 4); + transpose_4x4(patch + 8 * 1, patchT + 8 * 4, 16, 4); + transpose_4x4(patch + 12 * 1, patchT + 12 * 4, 16, 4); + } + + static void transform(const __fp16* patchT, __fp16* input_transform_buf, + size_t unit_idx, size_t nr_units_in_tile, size_t ic, + size_t IC) { + constexpr size_t alpha = 2 + 3 - 1; + // BT * d * B +#define cb(m, n) \ + Vector<__fp16, 4> d##m##n = \ + Vector<__fp16, 4>::load(patchT + m * 4 * 4 + n * 4); + + UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); +#undef cb + + //! 1 0 -1 0 d00 d01 d02 d03 1 0 0 0 + //! 0 1 1 0 d10 d11 d12 d13 0 1 -1 -1 + //! 0 -1 1 0 d20 d21 d22 d23 -1 1 1 0 + //! 
0 -1 0 1 d30 d31 d32 d33 0 0 0 1 +#define cb(m) \ + auto t0##m = d0##m - d2##m; \ + auto t1##m = d1##m + d2##m; \ + auto t2##m = d2##m - d1##m; \ + auto t3##m = d3##m - d1##m; + + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + +#define cb(m) \ + d##m##0 = t##m##0 - t##m##2; \ + d##m##1 = t##m##1 + t##m##2; \ + d##m##2 = t##m##2 - t##m##1; \ + d##m##3 = t##m##3 - t##m##1; + + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + +#define cb(m, n) \ + d##m##n.save(input_transform_buf + \ + (m * alpha + n) * nr_units_in_tile * IC + unit_idx * IC + \ + ic); + UNROLL_CALL_NOWRAPPER_D2(4, 4, cb) +#undef cb + } +}; + +template +struct OutputTransform2X3 { + static void transform(const dt_float16* output_transform_buf, + const dt_float16* bias, dt_float16* output, + dt_float16* transform_mid_buf, size_t oh_start, + size_t ow_start, size_t OH, size_t OW, + size_t oc_start, size_t oc_end, size_t oc_index, + size_t unit_idx, size_t nr_units_in_tile, + const DType& src_dtype, const DType& dst_dtype) { + Op op(src_dtype, dst_dtype); + const __fp16* output_transform_ptr = + reinterpret_cast(output_transform_buf); + const __fp16* bias_ptr = reinterpret_cast(bias); + __fp16* output_ptr = reinterpret_cast<__fp16*>(output); + __fp16* transform_mid_ptr = + reinterpret_cast<__fp16*>(transform_mid_buf); + + //! AT * m * A + constexpr size_t alpha = 2 + 3 - 1; + size_t OC = oc_end - oc_start; + size_t oc = oc_start + oc_index; + +#define cb(m, n) \ + auto v##m##n = Vector<__fp16, 4>::load( \ + output_transform_ptr + (m * alpha + n) * nr_units_in_tile * OC + \ + unit_idx * OC + oc_index); + UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); +#undef cb + //! 1 1 1 0 v00 v01 v02 v03 1 0 + //! 0 1 -1 1 v10 v11 v12 v13 1 1 + //! v20 v21 v22 v23 1 -1 + //! v30 v31 v32 v33 0 1 +#define cb(m) \ + auto t0##m = v0##m + v1##m + v2##m; \ + auto t1##m = v1##m - v2##m + v3##m; + + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + v00 = t00 + t01 + t02; + v10 = t10 + t11 + t12; + v01 = t01 - t02 + t03; + v11 = t11 - t12 + t13; + + Vector<__fp16, 4> vbias; + if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + vbias = Vector<__fp16, 4>::load(bias_ptr + oc); + + v00 += vbias; + v10 += vbias; + v01 += vbias; + v11 += vbias; + } + float16x8_t result01, result23; + result01 = vcombine_f16(v00.value, v01.value); + result23 = vcombine_f16(v10.value, v11.value); + if (bmode != BiasMode::BIAS) { + result01 = op(result01); + result23 = op(result23); + } + vst1q_f16(transform_mid_ptr, result01); + vst1q_f16(transform_mid_ptr + 8, result23); + + for (size_t oco = 0; oco < 4 && oc + oco < oc_end; ++oco) { + for (size_t oho = 0; oho < 2 && oh_start + oho < OH; ++oho) { + for (size_t owo = 0; owo < 2 && ow_start + owo < OW; ++owo) { + size_t oh = oh_start + oho; + size_t ow = ow_start + owo; + __fp16 res = transform_mid_ptr[oho * 2 * 4 + owo * 4 + oco]; + if (bmode == BiasMode::BIAS) { + res += bias_ptr[(oc + oco) * OH * OW + oh * OW + ow]; + res = op(res); + } + output_ptr[(oc + oco) * OH * OW + oh * OW + ow] = res; + } + } + } + } +}; +} // namespace + +namespace megdnn { +namespace arm_common { +namespace winograd { + +MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_2x3_4x4_f16) + +void winograd_2x3_4x4_f16::filter(const dt_float16* filter, + dt_float16* filter_transform_buf, + dt_float16* transform_mid_buf, size_t OC, + size_t IC, size_t oc_start, size_t oc_end) { + constexpr int alpha = 2 + 3 - 1; + //! 
G * g * GT + __fp16* filter_transbuf_ptr = + reinterpret_cast<__fp16*>(filter_transform_buf); + __fp16* filter_transmid_ptr = reinterpret_cast<__fp16*>(transform_mid_buf); + + for (size_t oc = oc_start; oc < oc_end; oc++) { + rep(ic, IC) { + const __fp16* filter_ptr = reinterpret_cast(filter) + + (oc * IC + ic) * 3 * 3; + /** + * origin: (4x3) * (3 x 3) * (3 x 4) + * pack to G and g to times of 4 + * now: (4x4) * (4 x 4) * (4 x 4) + */ + //! 1 0 0 0 v00 v01 v02 0 1 0.5 0.5 0 + //! 0.5 0.5 0.5 0 v10 v11 v12 0 0 0.5 -0.5 0 + //! 0.5 -0.5 0.5 0 v20 v21 v22 0 0 0.5 0.5 1 + //! 0 0 1 0 0 0 0 0 0 0 0 0 + float16x4_t v0 = vld1_f16(filter_ptr); // 0 1 2 3 + float16x4_t v1 = vld1_f16(filter_ptr + 3); // 3 4 5 6 + float16x4_t v2 = vld1_f16(filter_ptr + 5); // 5678 + float16x4_t v3 = vdup_n_f16(0); + v2 = vext_f16(v2, v3, 1); + v0 = vset_lane_f16(0, v0, 3); + v1 = vset_lane_f16(0, v1, 3); +#define cb(i) float16x4_t vsum##i; + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + vsum0 = v0; + float16x4_t v0addv2 = vadd_f16(v0, v2); + float16x4_t v02addv1 = vadd_f16(v0addv2, v1); + float16x4_t v02subv1 = vsub_f16(v0addv2, v1); + vsum1 = vmul_n_f16(v02addv1, 0.5); + vsum2 = vmul_n_f16(v02subv1, 0.5); + vsum3 = v2; + +#define cb(i) \ + do { \ + mid_buf1[0] = vget_lane_f16(vsum##i, 0); \ + __fp16 a0a2 = vget_lane_f16(vsum##i, 0) + vget_lane_f16(vsum##i, 2); \ + __fp16 a0a2adda1 = a0a2 + vget_lane_f16(vsum##i, 1); \ + __fp16 a0a2suba1 = a0a2 - vget_lane_f16(vsum##i, 1); \ + mid_buf1[1] = a0a2adda1 * 0.5; \ + mid_buf1[2] = a0a2suba1 * 0.5; \ + mid_buf1[3] = vget_lane_f16(vsum##i, 2); \ + mid_buf1 += 4; \ + } while (0); + + __fp16* mid_buf1 = filter_transmid_ptr; + UNROLL_CALL_NOWRAPPER(4, cb); + mid_buf1 = filter_transmid_ptr; +#undef cb + rep(i, alpha) rep(j, alpha) { + filter_transbuf_ptr[(i * alpha + j) * OC * IC + ic * OC + oc] = + filter_transmid_ptr[i * alpha + j]; + } + } + } +} + +void winograd_2x3_4x4_f16::input(const dt_float16* input, + dt_float16* input_transform_buf, + dt_float16* transform_mid_buf, size_t IH, + size_t IW, size_t IC, size_t PH, size_t PW, + size_t unit_start_idx, + size_t nr_units_in_tile) { + megdnn_assert(IC % 4 == 0); + constexpr int alpha = 3 + 2 - 1; + + // OW = IW + 2 * PW - KERNEL_SIZE + 1 + auto units_w = div_ceil(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); + dt_float16* patch = transform_mid_buf; + dt_float16* patchT = transform_mid_buf + 4 * alpha * alpha; + + for (size_t ic = 0; ic < IC; ic += 4) { + rep(unit_idx, nr_units_in_tile) { + size_t index = unit_start_idx + unit_idx; + size_t nh = index / units_w; + size_t nw = index % units_w; + int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; + int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; + if (ih_start >= 0 && ih_start + alpha <= static_cast(IH) && + iw_start >= 0 && iw_start + alpha <= static_cast(IW)) { + InputTransform2X3::prepare( + reinterpret_cast(input), + reinterpret_cast<__fp16*>(patch), + reinterpret_cast<__fp16*>(patchT), ih_start, iw_start, + IH, IW, ic, IC); + InputTransform2X3::transform( + reinterpret_cast(patchT), + reinterpret_cast<__fp16*>(input_transform_buf), + unit_idx, nr_units_in_tile, ic, IC); + } else { + InputTransform2X3::prepare( + reinterpret_cast(input), + reinterpret_cast<__fp16*>(patch), + reinterpret_cast<__fp16*>(patchT), ih_start, iw_start, + IH, IW, ic, IC); + InputTransform2X3::transform( + reinterpret_cast(patchT), + reinterpret_cast<__fp16*>(input_transform_buf), + unit_idx, nr_units_in_tile, ic, IC); + } + } + } +} + +void winograd_2x3_4x4_f16::output(const dt_float16* output_transform_buf, 
+ const dt_float16* bias, dt_float16* output, + dt_float16* transform_mid_buf, BiasMode bmode, + NonlineMode nonline_mode, size_t OH, size_t OW, + size_t oc_start, size_t oc_end, + size_t unit_start_idx, size_t nr_units_in_tile) { +#define cb(_bmode, _nonline_op, ...) \ + OutputTransform2X3<_bmode MEGDNN_COMMA _nonline_op>::transform(__VA_ARGS__); + + auto units_w = div_ceil(OW, OUTPUT_BLOCK_SIZE); + + for (size_t oc = oc_start; oc < oc_end; oc += 4) { + size_t oc_index = oc - oc_start; + rep(unit_idx, nr_units_in_tile) { + size_t index = unit_start_idx + unit_idx; + auto nh = index / units_w; + auto nw = index % units_w; + size_t oh_start = nh * OUTPUT_BLOCK_SIZE; + size_t ow_start = nw * OUTPUT_BLOCK_SIZE; + DISPATCH_CONV_WINOGRAD_BIAS( + megdnn_arm_common_winograd_fp16_F23, cb, __fp16, __fp16, bmode, + nonline_mode, output_transform_buf, bias, output, transform_mid_buf, + oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx, + nr_units_in_tile, src_dtype, dst_dtype); + } + } +#undef cb +} + +} // namespace winograd +} // namespace arm_common +} // namespace megdnn +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/f16/strategy_2x3_8x8.cpp b/dnn/src/arm_common/conv_bias/f16/strategy_2x3_8x8.cpp new file mode 100644 index 00000000..e25e8f11 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/f16/strategy_2x3_8x8.cpp @@ -0,0 +1,407 @@ +/** + * \file dnn/src/arm_common/conv_bias/f16/strategy_2x3_8x8.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#include "src/fallback/conv_bias/winograd/winograd.h" +#include "src/naive/matrix_mul/matrix_mul_helper.h" + +#include "src/arm_common/elemwise_helper/op_unary.h" +#include "src/arm_common/conv_bias/f16/strategy.h" +#include "src/arm_common/conv_bias/f16/helper.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/arm_common/utils.h" + +#include "src/common/winograd/winograd_generator.h" +#include "src/common/unroll_macro.h" +#include "src/common/utils.h" +#include "midout.h" + +#include "src/common/winograd/winograd_helper.h" + +MIDOUT_DECL(megdnn_arm_common_winograd_f16_F23_8x8) + +using namespace megdnn; +using namespace arm_common; +namespace { +void transpose_8x4(const __fp16* src, __fp16* dst, int lda, int ldb) { + float16x4x2_t a0, a1, a2, a3; + a0.val[0] = vld1_f16(src + 0 * lda); + a0.val[1] = vld1_f16(src + 1 * lda); + a1.val[0] = vld1_f16(src + 2 * lda); + a1.val[1] = vld1_f16(src + 3 * lda); + a2.val[0] = vld1_f16(src + 4 * lda); + a2.val[1] = vld1_f16(src + 5 * lda); + a3.val[0] = vld1_f16(src + 6 * lda); + a3.val[1] = vld1_f16(src + 7 * lda); + float16x4x2_t b0 = vzip_f16(a0.val[0], a1.val[0]); + float16x4x2_t b1 = vzip_f16(a0.val[1], a1.val[1]); + float16x4x2_t b2 = vzip_f16(a2.val[0], a3.val[0]); + float16x4x2_t b3 = vzip_f16(a2.val[1], a3.val[1]); + + float16x4x2_t c0 = vzip_f16(b0.val[0], b1.val[0]); + float16x4x2_t c1 = vzip_f16(b0.val[1], b1.val[1]); + float16x4x2_t c2 = vzip_f16(b2.val[0], b3.val[0]); + float16x4x2_t c3 = vzip_f16(b2.val[1], b3.val[1]); + + vst1_f16(dst + 0 * ldb, c0.val[0]); + vst1_f16(dst + 1 * ldb, c2.val[0]); + vst1_f16(dst + 2 * ldb, c0.val[1]); + vst1_f16(dst + 3 * ldb, c2.val[1]); + vst1_f16(dst + 4 * ldb, c1.val[0]); + vst1_f16(dst + 5 * ldb, c3.val[0]); + vst1_f16(dst + 6 * ldb, c1.val[1]); + vst1_f16(dst + 7 * ldb, c3.val[1]); +} + +struct InputTransform2X3_8x8 { + template + static void prepare(const __fp16* input, __fp16* patch, __fp16* patchT, + int ih_start, int iw_start, size_t IH, size_t IW, + size_t ic, size_t IC) { + constexpr size_t alpha = 2 + 3 - 1; + if (!(inner && ic + 8 < IC)) { + memset(patch, 0, sizeof(__fp16) * 8 * alpha * alpha); + } + if (inner) { + const __fp16* input_ptr = + input + ic * IH * IW + ih_start * IW + iw_start; + for (size_t ico = 0; ico < 8; ++ico) { + if (ic + ico < IC) { + auto v0 = vld1_f16(input_ptr); + auto v1 = vld1_f16(input_ptr + IW); + auto v2 = vld1_f16(input_ptr + IW * 2); + auto v3 = vld1_f16(input_ptr + IW * 3); + + vst1_f16(patch + ico * alpha * 4 + 0 * 4, v0); + vst1_f16(patch + ico * alpha * 4 + 1 * 4, v1); + vst1_f16(patch + ico * alpha * 4 + 2 * 4, v2); + vst1_f16(patch + ico * alpha * 4 + 3 * 4, v3); + input_ptr += IH * IW; + } + } + } else { + int ih0_act = std::max(ih_start, 0), + ih1_act = std::min(ih_start + alpha, IH), + iw0_act = std::max(iw_start, 0), + iw1_act = std::min(iw_start + alpha, IW); + // partial copy + for (size_t ico = 0; ico < 8; ++ico) { + if (ic + ico < IC) { + for (int ih = ih0_act; ih < ih1_act; ++ih) { + for (int iw = iw0_act; iw < iw1_act; ++iw) { + size_t iho = ih - ih_start, iwo = iw - iw_start; + patch[ico * alpha * alpha + iho * alpha + iwo] = + + input[(ic + ico) * IH * IW + ih * IW + iw]; + } + } + } + } + } + transpose_8x4(patch + 4 * 0, patchT + 32 * 0, 16, 4); + transpose_8x4(patch + 4 * 1, patchT + 32 * 1, 16, 4); + transpose_8x4(patch + 4 * 2, patchT + 32 * 2, 16, 4); + transpose_8x4(patch + 4 * 3, patchT + 32 * 3, 16, 4); + } + + static void transform(const __fp16* patchT, __fp16* 
input_transform_buf, + size_t unit_idx, size_t nr_units_in_tile, size_t ic, + size_t IC) { + constexpr size_t alpha = 2 + 3 - 1; + // BT * d * B +#define cb(m, n) \ + Vector<__fp16, 8> d##m##n = \ + Vector<__fp16, 8>::load(patchT + 8 * (m * 4 + n)); + + UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); +#undef cb + + //! 1 0 -1 0 d00 d01 d02 d03 1 0 0 0 + //! 0 1 1 0 d10 d11 d12 d13 0 1 -1 -1 + //! 0 -1 1 0 d20 d21 d22 d23 -1 1 1 0 + //! 0 -1 0 1 d30 d31 d32 d33 0 0 0 1 +#define cb(m) \ + auto t0##m = d0##m - d2##m; \ + auto t1##m = d1##m + d2##m; \ + auto t2##m = d2##m - d1##m; \ + auto t3##m = d3##m - d1##m; + + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + +#define cb(m) \ + d##m##0 = t##m##0 - t##m##2; \ + d##m##1 = t##m##1 + t##m##2; \ + d##m##2 = t##m##2 - t##m##1; \ + d##m##3 = t##m##3 - t##m##1; + + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + + size_t ICB = IC / 8; + size_t icb = ic / 8; +#define cb(m, n) \ + d##m##n.save(input_transform_buf + \ + (m * alpha + n) * nr_units_in_tile * ICB * 8 + \ + icb * nr_units_in_tile * 8 + unit_idx * 8); + UNROLL_CALL_NOWRAPPER_D2(4, 4, cb) +#undef cb + } +}; + +template +struct OutputTransform2X3_8x8 { + static void transform(const dt_float16* output_transform_buf, + const dt_float16* bias, dt_float16* output, + dt_float16* transform_mid_buf, size_t oh_start, + size_t ow_start, size_t OH, size_t OW, + size_t oc_start, size_t oc_end, size_t oc_index, + size_t unit_idx, size_t nr_units_in_tile, + const DType& src_dtype, const DType& dst_dtype) { + Op op(src_dtype, dst_dtype); + const __fp16* output_transform_ptr = + reinterpret_cast(output_transform_buf); + const __fp16* bias_ptr = reinterpret_cast(bias); + __fp16* output_ptr = reinterpret_cast<__fp16*>(output); + __fp16* transform_mid_ptr = + reinterpret_cast<__fp16*>(transform_mid_buf); + + //! AT * m * A + constexpr size_t alpha = 2 + 3 - 1; + + size_t oc = oc_start + oc_index; + size_t OCB = (oc_end - oc_start) / 8; + size_t ocb = oc_index / 8; + +#define cb(m, n) \ + auto v##m##n = Vector<__fp16, 8>::load( \ + output_transform_ptr + \ + (m * alpha + n) * OCB * nr_units_in_tile * 8 + \ + ocb * nr_units_in_tile * 8 + unit_idx * 8); + UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); +#undef cb + + //! 1 1 1 0 v00 v01 v02 v03 1 0 + //! 0 1 -1 1 v10 v11 v12 v13 1 1 + //! v20 v21 v22 v23 1 -1 + //! 
v30 v31 v32 v33 0 1 +#define cb(m) \ + auto t0##m = v0##m + v1##m + v2##m; \ + auto t1##m = v1##m - v2##m + v3##m; + + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + v00 = t00 + t01 + t02; + v10 = t10 + t11 + t12; + v01 = t01 - t02 + t03; + v11 = t11 - t12 + t13; + + if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + Vector<__fp16, 8> vbias; + vbias = Vector<__fp16, 8>::load(bias_ptr + oc); + + v00 += vbias; + v10 += vbias; + v01 += vbias; + v11 += vbias; + } + if (bmode != BiasMode::BIAS) { + v00.value = op(v00.value); + v01.value = op(v01.value); + v10.value = op(v10.value); + v11.value = op(v11.value); + } + + v00.save(transform_mid_ptr + (0 * 2 + 0) * 8); + v10.save(transform_mid_ptr + (1 * 2 + 0) * 8); + v01.save(transform_mid_ptr + (0 * 2 + 1) * 8); + v11.save(transform_mid_ptr + (1 * 2 + 1) * 8); + + for (size_t oco = 0; oco < 8 && oc + oco < oc_end; ++oco) { + for (size_t oho = 0; oho < 2 && oh_start + oho < OH; ++oho) { + for (size_t owo = 0; owo < 2 && ow_start + owo < OW; ++owo) { + size_t oh = oh_start + oho; + size_t ow = ow_start + owo; + __fp16 res = transform_mid_ptr[oho * 2 * 8 + owo * 8 + oco]; + if (bmode == BiasMode::BIAS) { + res += bias_ptr[(oc + oco) * OH * OW + oh * OW + ow]; + res = op(res); + } + output_ptr[(oc + oco) * OH * OW + oh * OW + ow] = res; + } + } + } + } +}; +} // namespace + +namespace megdnn { +namespace arm_common { +namespace winograd { + +MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_2x3_8x8_f16) + +void winograd_2x3_8x8_f16::filter(const dt_float16* filter, + dt_float16* filter_transform_buf, + dt_float16* transform_mid_buf, size_t OC, + size_t IC, size_t oc_start, size_t oc_end) { + constexpr int alpha = 2 + 3 - 1; + //! G * g * GT + __fp16* filter_transbuf_ptr = + reinterpret_cast<__fp16*>(filter_transform_buf); + __fp16* filter_transmid_ptr = reinterpret_cast<__fp16*>(transform_mid_buf); + + size_t OCB = OC / 8; + size_t ICB = IC / 8; + for (size_t oc = oc_start; oc < oc_end; oc++) { + rep(ic, IC) { + const __fp16* filter_ptr = reinterpret_cast(filter) + + (oc * IC + ic) * 3 * 3; + /** + * origin: (4x3) * (3 x 3) * (3 x 4) + * pack to G and g to times of 4 + * now: (4x4) * (4 x 4) * (4 x 4) + */ + //! 1 0 0 0 v00 v01 v02 0 1 0.5 0.5 0 + //! 0.5 0.5 0.5 0 v10 v11 v12 0 0 0.5 -0.5 0 + //! 0.5 -0.5 0.5 0 v20 v21 v22 0 0 0.5 0.5 1 + //! 
0 0 1 0 0 0 0 0 0 0 0 0 + float16x4_t v0 = vld1_f16(filter_ptr); // 0 1 2 3 + float16x4_t v1 = vld1_f16(filter_ptr + 3); // 3 4 5 6 + float16x4_t v2 = vld1_f16(filter_ptr + 5); // 5678 + float16x4_t v3 = vdup_n_f16(0); + v2 = vext_f16(v2, v3, 1); + v0 = vset_lane_f16(0, v0, 3); + v1 = vset_lane_f16(0, v1, 3); +#define cb(i) float16x4_t vsum##i; + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + vsum0 = v0; + float16x4_t v0addv2 = vadd_f16(v0, v2); + float16x4_t v02addv1 = vadd_f16(v0addv2, v1); + float16x4_t v02subv1 = vsub_f16(v0addv2, v1); + vsum1 = vmul_n_f16(v02addv1, 0.5); + vsum2 = vmul_n_f16(v02subv1, 0.5); + vsum3 = v2; + +#define cb(i) \ + do { \ + mid_buf1[0] = vget_lane_f16(vsum##i, 0); \ + __fp16 a0a2 = vget_lane_f16(vsum##i, 0) + vget_lane_f16(vsum##i, 2); \ + __fp16 a0a2adda1 = a0a2 + vget_lane_f16(vsum##i, 1); \ + __fp16 a0a2suba1 = a0a2 - vget_lane_f16(vsum##i, 1); \ + mid_buf1[1] = a0a2adda1 * 0.5; \ + mid_buf1[2] = a0a2suba1 * 0.5; \ + mid_buf1[3] = vget_lane_f16(vsum##i, 2); \ + mid_buf1 += 4; \ + } while (0); + + __fp16* mid_buf1 = filter_transmid_ptr; + UNROLL_CALL_NOWRAPPER(4, cb); + mid_buf1 = filter_transmid_ptr; +#undef cb + size_t ocb = (oc) / 8; + size_t oc8 = (oc) % 8; + size_t icb = (ic) / 8; + size_t ic8 = (ic) % 8; + rep(i, alpha) rep(j, alpha) { + filter_transbuf_ptr[(i * alpha + j) * OCB * ICB * 8 * 8 + + ocb * ICB * 8 * 8 + icb * 8 * 8 + ic8 * 8 + + oc8] = filter_transmid_ptr[i * alpha + j]; + } + } + } +} + +void winograd_2x3_8x8_f16::input(const dt_float16* input, + dt_float16* input_transform_buf, + dt_float16* transform_mid_buf, size_t IH, + size_t IW, size_t IC, size_t PH, size_t PW, + size_t unit_start_idx, size_t nr_units_in_tile) { + megdnn_assert(IC % 8 == 0); + constexpr int alpha = 3 + 2 - 1; + + // OW = IW + 2 * PW - KERNEL_SIZE + 1 + auto units_w = div_ceil(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); + dt_float16* patch = transform_mid_buf; + dt_float16* patchT = transform_mid_buf + 8 * alpha * alpha; + + for (size_t ic = 0; ic < IC; ic += 8) { + rep(unit_idx, nr_units_in_tile) { + size_t index = unit_start_idx + unit_idx; + size_t nh = index / units_w; + size_t nw = index % units_w; + int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; + int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; + if (ih_start >= 0 && ih_start + alpha <= static_cast(IH) && + iw_start >= 0 && iw_start + alpha <= static_cast(IW)) { + InputTransform2X3_8x8::prepare( + reinterpret_cast(input), + reinterpret_cast<__fp16*>(patch), + reinterpret_cast<__fp16*>(patchT), ih_start, iw_start, + IH, IW, ic, IC); + InputTransform2X3_8x8::transform( + reinterpret_cast(patchT), + reinterpret_cast<__fp16*>(input_transform_buf), + unit_idx, nr_units_in_tile, ic, IC); + + } else { + InputTransform2X3_8x8::prepare( + reinterpret_cast(input), + reinterpret_cast<__fp16*>(patch), + reinterpret_cast<__fp16*>(patchT), ih_start, iw_start, + IH, IW, ic, IC); + InputTransform2X3_8x8::transform( + reinterpret_cast(patchT), + reinterpret_cast<__fp16*>(input_transform_buf), + unit_idx, nr_units_in_tile, ic, IC); + } + } + } +} + +void winograd_2x3_8x8_f16::output(const dt_float16* output_transform_buf, + const dt_float16* bias, dt_float16* output, + dt_float16* transform_mid_buf, BiasMode bmode, + NonlineMode nonline_mode, size_t OH, + size_t OW, size_t oc_start, size_t oc_end, + size_t unit_start_idx, + size_t nr_units_in_tile) { +#define cb(_bmode, _nonline_op, ...) 
\ + OutputTransform2X3_8x8<_bmode MEGDNN_COMMA _nonline_op>::transform( \ + __VA_ARGS__); + + auto units_w = div_ceil(OW, OUTPUT_BLOCK_SIZE); + + for (size_t oc = oc_start; oc < oc_end; oc += 8) { + size_t oc_index = oc - oc_start; + rep(unit_idx, nr_units_in_tile) { + size_t index = unit_start_idx + unit_idx; + auto nh = index / units_w; + auto nw = index % units_w; + size_t oh_start = nh * OUTPUT_BLOCK_SIZE; + size_t ow_start = nw * OUTPUT_BLOCK_SIZE; + DISPATCH_CONV_WINOGRAD_BIAS( + megdnn_arm_common_winograd_fp16_F23_8x8, cb, __fp16, __fp16, + bmode, nonline_mode, output_transform_buf, bias, output, + transform_mid_buf, oh_start, ow_start, OH, OW, oc_start, + oc_end, oc_index, unit_idx, nr_units_in_tile, src_dtype, dst_dtype); + } + } +#undef cb +} + +} // namespace winograd +} // namespace arm_common +} // namespace megdnn +#endif + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/f16/strategy_4x5.cpp b/dnn/src/arm_common/conv_bias/f16/strategy_4x5.cpp new file mode 100644 index 00000000..d36d2184 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/f16/strategy_4x5.cpp @@ -0,0 +1,488 @@ +/** + * \file dnn/src/arm_common/conv_bias/f16/strategy_4x5.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "src/arm_common/conv_bias/f16/helper.h" +#include "src/arm_common/conv_bias/f16/strategy.h" +#include "src/arm_common/elemwise_helper/op_unary.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/arm_common/utils.h" +#include "src/common/unroll_macro.h" +#include "src/common/utils.h" +#include "src/fallback/conv_bias/winograd/winograd.h" +#include "src/naive/matrix_mul/matrix_mul_helper.h" + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#include "midout.h" +MIDOUT_DECL(megdnn_arm_common_winograd_fp16_F45) +using namespace megdnn; +using namespace arm_common; +namespace { + +struct FilterTransform4X5 { +#define FILTER_TRANSFORM(d, wd) \ + do { \ + wd##0 = d##0; \ + wd##r0 = d##r0; \ + wd##1 = (d##0 + d##1 + d##2 + d##3 + d##4) * -0.222168; \ + wd##r1 = (d##r0 + d##r1 + d##r2 + d##r3 + d##r4) * -0.222168; \ + wd##2 = (d##0 - d##1 + d##2 - d##3 + d##4) * -0.222168; \ + wd##r2 = (d##r0 - d##r1 + d##r2 - d##r3 + d##r4) * -0.222168; \ + auto tmpd0 = d##0 * 0.710938; \ + auto tmpd1 = d##1 * 0.355469; \ + auto tmpd2 = d##2 * 0.177734; \ + auto tmpd3 = d##3 * 0.088867; \ + auto tmpd4 = d##4 * 0.044434; \ + auto tmpdr0 = d##r0 * 0.710938; \ + auto tmpdr1 = d##r1 * 0.355469; \ + auto tmpdr2 = d##r2 * 0.177734; \ + auto tmpdr3 = d##r3 * 0.088867; \ + auto tmpdr4 = d##r4 * 0.044434; \ + wd##3 = tmpd0 + tmpd1 + tmpd2 + tmpd3 + tmpd4; \ + wd##r3 = tmpdr0 + tmpdr1 + tmpdr2 + tmpdr3 + tmpdr4; \ + wd##4 = tmpd0 - tmpd1 + tmpd2 - tmpd3 + tmpd4; \ + wd##r4 = tmpdr0 - tmpdr1 + tmpdr2 - tmpdr3 + tmpdr4; \ + tmpd0 = d##0 * 0.011108; \ + tmpd1 = d##1 * 0.022217; \ + tmpd2 = d##2 * 0.044434; \ + tmpd3 = d##3 * 0.088867; \ + tmpd4 = d##4 * 0.177734; \ + tmpdr0 = d##r0 * 0.011108; \ + tmpdr1 = d##r1 * 0.022217; \ + tmpdr2 = d##r2 * 0.044434; \ + tmpdr3 = d##r3 * 0.088867; \ + tmpdr4 = d##r4 * 0.177734; \ + wd##5 = tmpd0 + tmpd1 + tmpd2 + tmpd3 + tmpd4; \ + ; \ + wd##r5 = tmpdr0 + tmpdr1 + tmpdr2 + tmpdr3 + tmpdr4; \ + ; \ + wd##6 = tmpd0 - tmpd1 + tmpd2 - 
tmpd3 + tmpd4; \ + ; \ + wd##r6 = tmpdr0 - tmpdr1 + tmpdr2 - tmpdr3 + tmpdr4; \ + ; \ + wd##7 = d##4; \ + wd##r7 = d##r4; \ + } while (0); + +#define FILTER_TRANSFORM_FINAL(d, wd) \ + do { \ + wd##0 = d##0; \ + wd##1 = (d##0 + d##1 + d##2 + d##3 + d##4) * -0.222168; \ + wd##2 = (d##0 - d##1 + d##2 - d##3 + d##4) * -0.222168; \ + auto tmp0 = d##0 * 0.710938 + d##2 * 0.177734 + d##4 * 0.044434; \ + auto tmp1 = d##1 * 0.355469 + d##3 * 0.088867; \ + wd##3 = tmp0 + tmp1; \ + wd##4 = tmp0 - tmp1; \ + tmp0 = d##0 * 0.011108 + d##2 * 0.044434 + d##4 * 0.177734; \ + tmp1 = d##1 * 0.022217 + d##3 * 0.088867; \ + wd##5 = tmp0 + tmp1; \ + wd##6 = tmp0 - tmp1; \ + wd##7 = d##4; \ + } while (0); + static void transform(const __fp16* filter, __fp16* filter_transform_buf, + __fp16* transform_mid_buf, size_t OC, size_t IC, + size_t oc_start, size_t oc_end) { + // Gg * GT + // G + //[[ 1. 0. 0. 0. 0. ] + // [-0.2222222 -0.2222222 -0.2222222 -0.2222222 -0.2222222] + // [-0.2222222 0.2222222 -0.2222222 0.2222222 -0.2222222] + // [ 0.7111111 0.3555556 0.1777778 0.0888889 0.0444444] + // [ 0.7111111 -0.3555556 0.1777778 -0.0888889 0.0444444] + // [ 0.0111111 0.0222222 0.0444444 0.0888889 0.1777778] + // [ 0.0111111 -0.0222222 0.0444444 -0.0888889 0.1777778] + // [ 0. 0. 0. 0. 1. ]] + constexpr size_t alpha = 4 + 5 - 1; + for (size_t oc = oc_start; oc < oc_end; oc++) { + rep(ic, IC) { + const __fp16* fptr = filter + (oc * IC + ic) * 5 * 5; + +#define cb(i) Vector<__fp16, 4> g##i = Vector<__fp16, 4>::load(fptr + 5 * i); + UNROLL_CALL_NOWRAPPER(5, cb); +#undef cb + +#define cb(i) __fp16 gr##i = *(fptr + 5 * i + 4); + UNROLL_CALL_NOWRAPPER(5, cb); + +#undef cb +#define cb(i) Vector<__fp16, 4> Gg##i; + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + +#define cb(i) __fp16 Ggr##i; + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + +#define cb(i) Vector<__fp16, 8> Ggt##i; + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + +#define cb(i) Vector<__fp16, 8> result##i; + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + FILTER_TRANSFORM(g, Gg) +#if MEGDNN_AARCH64 + float16x8_t vgr = {Ggr0, Ggr1, Ggr2, Ggr3, + Ggr4, Ggr5, Ggr6, Ggr7}; + Vector<__fp16, 8> Ggt4(vgr); + TRANSPOSE_8x4(Gg, Ggt); + FILTER_TRANSFORM_FINAL(Ggt, result); +#define cb(i) result##i.save(transform_mid_buf + i * alpha); + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + rep(i, alpha) rep(j, alpha) { + filter_transform_buf[(i * alpha + j) * OC * IC + ic * OC + + oc] = transform_mid_buf[j * alpha + i]; + } +#else + +#define GET_VECTOR_FP16D_ELEM(s, i, idx) vget_lane_f16(CONCAT(s, i).value, idx) + +#define cb(i) \ + do { \ + mid_buf1[0] = GET_VECTOR_FP16D_ELEM(Gg, i, 0); \ + auto tmp024 = GET_VECTOR_FP16D_ELEM(Gg, i, 0) + \ + GET_VECTOR_FP16D_ELEM(Gg, i, 2) + Ggr##i; \ + auto tmp13 = GET_VECTOR_FP16D_ELEM(Gg, i, 1) + \ + GET_VECTOR_FP16D_ELEM(Gg, i, 3); \ + mid_buf1[1] = (tmp024 + tmp13) * -0.2222222; \ + mid_buf1[2] = (tmp024 - tmp13) * -0.2222222; \ + auto tmp0 = GET_VECTOR_FP16D_ELEM(Gg, i, 0) * 0.7111111; \ + auto tmp1 = GET_VECTOR_FP16D_ELEM(Gg, i, 1) * 0.3555556; \ + auto tmp2 = GET_VECTOR_FP16D_ELEM(Gg, i, 2) * 0.1777778; \ + auto tmp3 = GET_VECTOR_FP16D_ELEM(Gg, i, 3) * 0.0888889; \ + auto tmp4 = Ggr##i * 0.0444444; \ + tmp024 = tmp0 + tmp2 + tmp4; \ + tmp13 = tmp1 + tmp3; \ + mid_buf1[3] = tmp024 + tmp13; \ + mid_buf1[4] = tmp024 - tmp13; \ + tmp0 = GET_VECTOR_FP16D_ELEM(Gg, i, 0) * 0.0111111; \ + tmp1 = GET_VECTOR_FP16D_ELEM(Gg, i, 1) * 0.0222222; \ + tmp2 = GET_VECTOR_FP16D_ELEM(Gg, i, 2) * 0.0444444; \ + tmp3 = GET_VECTOR_FP16D_ELEM(Gg, i, 3) * 0.0888889; \ + tmp4 = Ggr##i 
* 0.1777778; \ + tmp024 = tmp0 + tmp2 + tmp4; \ + tmp13 = tmp1 + tmp3; \ + mid_buf1[5] = tmp024 + tmp13; \ + mid_buf1[6] = tmp024 - tmp13; \ + mid_buf1[7] = Ggr##i; \ + mid_buf1 += 8; \ + } while (0); + __fp16* mid_buf1 = transform_mid_buf; + UNROLL_CALL_NOWRAPPER(8, cb); + mid_buf1 = transform_mid_buf; +#undef cb +#undef GET_VECTOR_FP16D_ELEM + rep(i, alpha) rep(j, alpha) { + filter_transform_buf[(i * alpha + j) * OC * IC + ic * OC + + oc] = transform_mid_buf[i * alpha + j]; + } +#endif + } + } + } +}; +#undef FILTER_TRANSFORM +#undef FILTER_TRANSFORM_FINAL + +struct InputTransform4X5 { +#define INPUT_TRANSFORM(d, wd) \ + do { \ + wd##0 = (d##0 - d##6) + (d##4 - d##2) * 5.25f; \ + auto tmp0 = d##2 - d##4 * 4.25f + d##6; \ + auto tmp1 = d##1 - d##3 * 4.25f + d##5; \ + wd##1 = tmp0 + tmp1; \ + wd##2 = tmp0 - tmp1; \ + tmp0 = d##2 * 4.0f - d##4 * 5.0f + d##6; \ + tmp1 = d##1 * 2.0f - d##3 * 2.5f + d##5 * 0.5f; \ + wd##3 = tmp0 + tmp1; \ + wd##4 = tmp0 - tmp1; \ + tmp0 = d##2 * 0.25f - d##4 * 1.25f + d##6; \ + tmp1 = d##1 * 0.5f - d##3 * 2.5f + d##5 * 2.0f; \ + wd##5 = tmp0 + tmp1; \ + wd##6 = tmp0 - tmp1; \ + wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f; \ + } while (0) + +#define GET_VECTOR_FP16Q_ELEM(s, i, idx) vgetq_lane_f16(CONCAT(s, i).value, idx) + + template + static void transform(const __fp16* input, __fp16* input_transform_buf, + __fp16* transform_mid_buf, int ih_start, int iw_start, + size_t ic, size_t IH, size_t IW, size_t IC, + size_t unit_idx, size_t nr_units_in_tile) { + // BTd * B + //([[ 1. , 0. , -5.25, 0. , 5.25, 0. , -1. , 0. ], + // [ 0. , 1. , 1. , -4.25, -4.25, 1. , 1. , 0. ], + // [ 0. , -1. , 1. , 4.25, -4.25, -1. , 1. , 0. ], + // [ 0. , 2. , 4. , -2.5 , -5. , 0.5 , 1. , 0. ], + // [ 0. , -2. , 4. , 2.5 , -5. , -0.5 , 1. , 0. ], + // [ 0. , 0.5 , 0.25, -2.5 , -1.25, 2. , 1. , 0. ], + // [ 0. , -0.5 , 0.25, 2.5 , -1.25, -2. , 1. , 0. ], + // [ 0. , -1. , 0. , 5.25, 0. , -5.25, 0. , 1. 
]])) + + constexpr size_t alpha = 4 + 5 - 1; + if (!inner) { + memset(transform_mid_buf, 0, sizeof(__fp16) * alpha * alpha); + } + +#define cb(i) Vector<__fp16, 8> d##i; + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + + if (inner) { + const __fp16* input_ptr = + input + ic * IH * IW + ih_start * IW + iw_start; +#define cb(i) d##i = Vector<__fp16, 8>::load(input_ptr + IW * i); + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + } else { + int ih0_act = std::max(ih_start, 0), + ih1_act = std::min(ih_start + alpha, IH), + iw0_act = std::max(iw_start, 0), + iw1_act = std::min(iw_start + alpha, IW); + for (int ih = ih0_act; ih < ih1_act; ++ih) { + for (int iw = iw0_act; iw < iw1_act; ++iw) { + size_t iho = ih - ih_start, iwo = iw - iw_start; + transform_mid_buf[iho * alpha + iwo] = + input[ic * IH * IW + ih * IW + iw]; + } + } +#define cb(i) d##i = Vector<__fp16, 8>::load(transform_mid_buf + alpha * i); + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + } + +#define cb(i) Vector<__fp16, 8> wd##i, ret##i; + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + + INPUT_TRANSFORM(d, wd); + TRANSPOSE_8x8(wd, d); + INPUT_TRANSFORM(d, ret); + +#define cb(i) ret##i.save(transform_mid_buf + i * alpha); + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + + rep(i, alpha) rep(j, alpha) { + input_transform_buf[(i * alpha + j) * nr_units_in_tile * IC + + unit_idx * IC + ic] = + transform_mid_buf[j * alpha + i]; + } + } +}; +#undef INPUT_TRANSFORM + +#define OUTPUT_TRANSFORM(m, s) \ + do { \ + s##0 = m##0 + m##1 + m##2 + m##3 + m##4 + m##5 + m##6; \ + s##1 = m##1 - m##2 + m##3 * 0.5 - m##4 * 0.5 + m##5 * 2.0 - \ + m##6 * 2.0; \ + s##2 = m##1 + m##2 + m##3 * 0.25 + m##4 * 0.25 + m##5 * 4.0 + \ + m##6 * 4.0; \ + s##3 = m##1 - m##2 + m##3 * 0.125 - m##4 * 0.125 + m##5 * 8.0 - \ + m##6 * 8.0 + m##7; \ + } while (0) +template +struct OutputTransform4X5 { + static void transform(const dt_float16* output_transform_buf, + const dt_float16* bias, dt_float16* output, + dt_float16* transform_mid_buf, size_t oh_start, + size_t ow_start, size_t OH, size_t OW, + size_t oc_start, size_t oc_end, size_t oc_index, + size_t unit_idx, size_t nr_units_in_tile, + const DType& src_dtype, const DType& dst_dtype) { + Op op(src_dtype, dst_dtype); + //! 
AT * m * A + // AT f45 + // 1.0 1.0 1.0 1.000 1.000 1.0 1.0 0.0 + // 0.0 1.0 -1.0 0.500 -0.500 2.0 -2.0 0.0 + // 0.0 1.0 1.0 0.250 0.250 4.0 4.0 0.0 + // 0.0 1.0 -1.0 0.125 -0.125 8.0 -8.0 1.0 + constexpr size_t alpha = 5 + 4 - 1; + const __fp16* fp16_output_transform_buf = + reinterpret_cast(output_transform_buf); + const __fp16* fp16_bias = reinterpret_cast(bias); + __fp16* fp16_output = reinterpret_cast<__fp16*>(output); + __fp16* fp16_transform_mid_buf = + reinterpret_cast<__fp16*>(transform_mid_buf); + + __fp16* mid_buf1 = fp16_transform_mid_buf; + + size_t OC = oc_end - oc_start; + size_t oc = oc_start + oc_index; + +#define cb(m, n) \ + fp16_transform_mid_buf[m * alpha + n] = \ + fp16_output_transform_buf[(m * alpha + n) * nr_units_in_tile * \ + OC + \ + unit_idx * OC + oc_index]; + UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); +#undef cb + +#define cb(i) \ + auto m##i = Vector<__fp16, 8>::load(fp16_transform_mid_buf + alpha * i); + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb +#define cb(i) Vector<__fp16, 8> s##i; + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb +#define cb(i) Vector<__fp16, 4> st##i; + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb +#define cb(i) Vector<__fp16, 4> result##i; + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + + OUTPUT_TRANSFORM(m, s); + TRANSPOSE_4x8(s, st); + OUTPUT_TRANSFORM(st, result); + TRANSPOSE_4x4(result, result); + + if (oh_start + 4 <= OH && ow_start + 4 <= OW) { + int index = (oc * OH + oh_start) * OW + ow_start; + if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + float16x4_t bias0 = vdup_n_f16(fp16_bias[oc]); + result0.value = vadd_f16(result0.value, bias0); + result1.value = vadd_f16(result1.value, bias0); + result2.value = vadd_f16(result2.value, bias0); + result3.value = vadd_f16(result3.value, bias0); + } else if (bmode == BiasMode::BIAS) { + float16x4_t bmbias0 = vld1_f16(fp16_bias + index); + float16x4_t bmbias1 = vld1_f16(fp16_bias + index + OW); + float16x4_t bmbias2 = vld1_f16(fp16_bias + index + OW * 2); + float16x4_t bmbias3 = vld1_f16(fp16_bias + index + OW * 3); + result0.value = vadd_f16(result0.value, bmbias0); + result1.value = vadd_f16(result1.value, bmbias1); + result2.value = vadd_f16(result2.value, bmbias2); + result3.value = vadd_f16(result3.value, bmbias3); + } + + float16x8_t item01 = op(vcombine_f16(result0.value, result1.value)); + float16x8_t item23 = op(vcombine_f16(result2.value, result3.value)); + + vst1_f16(fp16_output + index, vget_low_f16(item01)); + vst1_f16(fp16_output + index + OW, vget_high_f16(item01)); + vst1_f16(fp16_output + index + OW * 2, vget_low_f16(item23)); + vst1_f16(fp16_output + index + OW * 3, vget_high_f16(item23)); + } else { +#define cb(i) result##i.save(mid_buf1 + i * 4); + mid_buf1 = fp16_transform_mid_buf; + UNROLL_CALL_NOWRAPPER(4, cb); + mid_buf1 = fp16_transform_mid_buf; +#undef cb + for (size_t oho = 0; oho < 4 && oh_start + oho < OH; ++oho) { + for (size_t owo = 0; owo < 4 && ow_start + owo < OW; ++owo) { + size_t oh = oh_start + oho; + size_t ow = ow_start + owo; + __fp16 res = mid_buf1[oho * 4 + owo]; + if (bmode == BiasMode::BIAS) { + res += fp16_bias[oc * OH * OW + oh * OW + ow]; + } else if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + res += fp16_bias[oc]; + } + res = op(res); + fp16_output[oc * OH * OW + oh * OW + ow] = res; + } + } + } + } +}; +#undef OUTPUT_TRANSFORM +#undef GET_VECTOR_FP16Q_ELEM +} // namespace + +namespace megdnn { +namespace arm_common { +namespace winograd { + +MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_4x5_1x1_f16) + +void winograd_4x5_1x1_f16::filter(const dt_float16* 
filter, + dt_float16* filter_transform_buf, + dt_float16* transform_mid_buf, size_t OC, + size_t IC, size_t oc_start, size_t oc_end) { + FilterTransform4X5::transform( + reinterpret_cast(filter), + reinterpret_cast<__fp16*>(filter_transform_buf), + reinterpret_cast<__fp16*>(transform_mid_buf), OC, IC, oc_start, + oc_end); +} + +void winograd_4x5_1x1_f16::input(const dt_float16* input, + dt_float16* input_transform_buf, + dt_float16* transform_mid_buf, size_t IH, + size_t IW, size_t IC, size_t PH, size_t PW, + size_t unit_start_idx, + size_t nr_units_in_tile) { + constexpr int alpha = 4 + 5 - 1; + // OW = IW + 2 * PW - KERNEL_SIZE + 1 + auto units_w = div_ceil(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); + rep(ic, IC) { + rep(unit_idx, nr_units_in_tile) { + size_t index = unit_start_idx + unit_idx; + size_t nh = index / units_w; + size_t nw = index % units_w; + int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; + int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; + if (ih_start >= 0 && ih_start + alpha <= static_cast(IH) && + iw_start >= 0 && iw_start + alpha <= static_cast(IW)) { + InputTransform4X5::transform( + reinterpret_cast(input), + reinterpret_cast<__fp16*>(input_transform_buf), + reinterpret_cast<__fp16*>(transform_mid_buf), ih_start, + iw_start, ic, IH, IW, IC, unit_idx, nr_units_in_tile); + + } else { + InputTransform4X5::transform( + reinterpret_cast(input), + reinterpret_cast<__fp16*>(input_transform_buf), + reinterpret_cast<__fp16*>(transform_mid_buf), ih_start, + iw_start, ic, IH, IW, IC, unit_idx, nr_units_in_tile); + } + } + } +} + +void winograd_4x5_1x1_f16::output(const dt_float16* output_transform_buf, + const dt_float16* bias, dt_float16* output, + dt_float16* transform_mid_buf, BiasMode bmode, + NonlineMode nonline_mode, size_t OH, size_t OW, + size_t oc_start, size_t oc_end, + size_t unit_start_idx, size_t nr_units_in_tile) { +#define cb(_bmode, _nonline_op, ...) \ + OutputTransform4X5<_bmode MEGDNN_COMMA _nonline_op>::transform(__VA_ARGS__); + + auto units_w = div_ceil(OW, OUTPUT_BLOCK_SIZE); + + for (size_t oc = oc_start; oc < oc_end; oc++) { + size_t oc_index = oc - oc_start; + rep(unit_idx, nr_units_in_tile) { + size_t index = unit_start_idx + unit_idx; + auto nh = index / units_w; + auto nw = index % units_w; + size_t oh_start = nh * OUTPUT_BLOCK_SIZE; + size_t ow_start = nw * OUTPUT_BLOCK_SIZE; + DISPATCH_CONV_WINOGRAD_BIAS( + megdnn_arm_common_winograd_fp16_F45, cb, __fp16, __fp16, bmode, + nonline_mode, output_transform_buf, bias, output, transform_mid_buf, + oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx, + nr_units_in_tile, src_dtype, dst_dtype); + } + } +#undef cb +} + +} // namespace winograd +} // namespace arm_common +} // namespace megdnn +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/f16/strategy_6x3.cpp b/dnn/src/arm_common/conv_bias/f16/strategy_6x3.cpp new file mode 100644 index 00000000..adbb9b7e --- /dev/null +++ b/dnn/src/arm_common/conv_bias/f16/strategy_6x3.cpp @@ -0,0 +1,608 @@ +/** + * \file dnn/src/arm_common/conv_bias/f16/strategy_6x3.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#include "src/arm_common/conv_bias/f16/helper.h" +#include "src/arm_common/conv_bias/f16/strategy.h" +#include "src/arm_common/elemwise_helper/op_unary.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/arm_common/utils.h" +#include "src/common/unroll_macro.h" +#include "src/common/utils.h" +#include "src/fallback/conv_bias/winograd/winograd.h" +#include "src/naive/matrix_mul/matrix_mul_helper.h" +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#include "midout.h" +MIDOUT_DECL(megdnn_arm_common_winograd_fp16_F63) +using namespace megdnn; +using namespace arm_common; +namespace { +struct FilterTransform6X3 { + // 1.0000000 0.0000000 0.0000000 + //-0.2222222 -0.2222222 -0.2222222 + //-0.2222222 0.2222222 -0.2222222 + // 0.0111111 0.0222222 0.0444444 + // 0.0111111 -0.0222222 0.0444444 + // 0.7111111 0.3555556 0.1777778 + // 0.7111111 -0.3555556 0.1777778 + // 0.0000000 0.0000000 1.0000000 +#define FILTER_TRANSFORM(d, wd) \ + do { \ + wd##0 = d##0; \ + wd##1 = (d##0 + d##1 + d##2) * -0.222168; \ + wd##2 = (d##0 - d##1 + d##2) * -0.222168; \ + auto tmpd0 = d##0 * 0.011108; \ + auto tmpd1 = d##1 * 0.022217; \ + auto tmpd2 = d##2 * 0.044434; \ + wd##3 = tmpd0 + tmpd1 + tmpd2; \ + wd##4 = tmpd0 - tmpd1 + tmpd2; \ + tmpd0 = d##0 * 0.710938; \ + tmpd1 = d##1 * 0.355469; \ + tmpd2 = d##2 * 0.177734; \ + wd##5 = tmpd0 + tmpd1 + tmpd2; \ + wd##6 = tmpd0 - tmpd1 + tmpd2; \ + wd##7 = d##2; \ + } while (0); + +#define FILTER_TRANSFORM_FINAL(d, wd) \ + do { \ + wd##0 = d##0; \ + wd##1 = (d##0 + d##1 + d##2) * -0.222168; \ + wd##2 = (d##0 - d##1 + d##2) * -0.222168; \ + auto tmp0 = d##0 * 0.011108 + d##2 * 0.044434; \ + auto tmp1 = d##1 * 0.022217; \ + wd##3 = tmp0 + tmp1; \ + wd##4 = tmp0 - tmp1; \ + tmp0 = d##0 * 0.710938 + d##2 * 0.177734; \ + tmp1 = d##1 * 0.355469; \ + wd##5 = tmp0 + tmp1; \ + wd##6 = tmp0 - tmp1; \ + wd##7 = d##2; \ + } while (0); + static void transform(const __fp16* filter, __fp16* filter_transform_buf, + __fp16* transform_mid_buf, size_t OC, size_t IC, + size_t oc_start, size_t oc_end) { + // Gg * GT + // G + // 1.0000000 0.0000000 0.0000000 + //-0.2222222 -0.2222222 -0.2222222 + //-0.2222222 0.2222222 -0.2222222 + // 0.0111111 0.0222222 0.0444444 + // 0.0111111 -0.0222222 0.0444444 + // 0.7111111 0.3555556 0.1777778 + // 0.7111111 -0.3555556 0.1777778 + // 0.0000000 0.0000000 1.0000000 + constexpr size_t alpha = 6 + 3 - 1; + for (size_t oc = oc_start; oc < oc_end; oc++) { + rep(ic, IC) { + const __fp16* fptr = filter + (oc * IC + ic) * 3 * 3; + + float16x4_t v0, v1, v2, v3; + v0 = vld1_f16(fptr); // 0 1 2 3 + v1 = vld1_f16(fptr + 3); // 3 4 5 6 + v2 = vld1_f16(fptr + 5); // 5 6 7 8 + v3 = vdup_n_f16(0); + v2 = vext_f16(v2, v3, 1); + v0 = vset_lane_f16(0, v0, 3); + v1 = vset_lane_f16(0, v1, 3); +#define cb(i) Vector<__fp16, 4> g##i(v##i); + UNROLL_CALL_NOWRAPPER(3, cb); +#undef cb + +#define cb(i) Vector<__fp16, 4> Gg##i; + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + +#define cb(i) Vector<__fp16, 8> Ggt##i; + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + +#define cb(i) Vector<__fp16, 8> result##i; + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + FILTER_TRANSFORM(g, Gg) +#if MEGDNN_AARCH64 + TRANSPOSE_8x4(Gg, Ggt); + FILTER_TRANSFORM_FINAL(Ggt, result); +#define cb(i) result##i.save(transform_mid_buf + i * alpha); + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + rep(i, alpha) rep(j, alpha) { + filter_transform_buf[(i * alpha + j) * OC * IC + ic * OC + + oc] = transform_mid_buf[j * alpha + i]; + } +#else + /* 1.0000000 -0.2222222 -0.2222222 0.0111111 0.0111111 + 0.7111111 
0.7111111 0.0000000 0.0000000 -0.2222222 + 0.2222222 0.0222222 -0.0222222 0.3555556 -0.3555556 + 0.0000000 0.0000000 -0.2222222 -0.2222222 0.0444444 + 0.0444444 0.1777778 0.1777778 1.0000000*/ + +#define GET_VECTOR_FP16D_ELEM(s, i, idx) vget_lane_f16(CONCAT(s, i).value, idx) +#define cb(i) \ + do { \ + mid_buf1[0] = GET_VECTOR_FP16D_ELEM(Gg, i, 0); \ + auto tmp02 = GET_VECTOR_FP16D_ELEM(Gg, i, 0) + \ + GET_VECTOR_FP16D_ELEM(Gg, i, 2); \ + mid_buf1[1] = (tmp02 + GET_VECTOR_FP16D_ELEM(Gg, i, 1)) * -0.2222222; \ + mid_buf1[2] = (tmp02 - GET_VECTOR_FP16D_ELEM(Gg, i, 1)) * -0.2222222; \ + auto tmp0 = GET_VECTOR_FP16D_ELEM(Gg, i, 0) * 0.0111111; \ + auto tmp1 = GET_VECTOR_FP16D_ELEM(Gg, i, 1) * 0.0222222; \ + auto tmp2 = GET_VECTOR_FP16D_ELEM(Gg, i, 2) * 0.0444444; \ + tmp02 = tmp0 + tmp2; \ + mid_buf1[3] = tmp02 + tmp1; \ + mid_buf1[4] = tmp02 - tmp1; \ + tmp0 = GET_VECTOR_FP16D_ELEM(Gg, i, 0) * 0.7111111; \ + tmp1 = GET_VECTOR_FP16D_ELEM(Gg, i, 1) * 0.3555556; \ + tmp2 = GET_VECTOR_FP16D_ELEM(Gg, i, 2) * 0.1777778; \ + tmp02 = tmp0 + tmp2; \ + mid_buf1[5] = tmp02 + tmp1; \ + mid_buf1[6] = tmp02 - tmp1; \ + mid_buf1[7] = GET_VECTOR_FP16D_ELEM(Gg, i, 2); \ + mid_buf1 += 8; \ + } while (0); + __fp16* mid_buf1 = transform_mid_buf; + UNROLL_CALL_NOWRAPPER(8, cb); + mid_buf1 = transform_mid_buf; +#undef cb + rep(i, alpha) rep(j, alpha) { + filter_transform_buf[(i * alpha + j) * OC * IC + ic * OC + + oc] = transform_mid_buf[i * alpha + j]; + } +#undef GET_VECTOR_FP16D_ELEM +#endif + } + } + } +}; +#undef FILTER_TRANSFORM +#undef FILTER_TRANSFORM_FINAL +/** + * input transform + * + * wd0 = (d0 - d6) + 5.25 * (d4 - d2) + * wd1 = (d6 + d2 - 4.25 * d4) + (d1 + d5 - 4.25 * d3) + * wd2 = (d6 + d2 - 4.25 * d4) - (d1 + d5 - 4.25 * d3) + * wd3 = (d6 + 0.25 * d2 - 1.25 * d4) + 2.0 * (d5 + 0.25 * d1 - 1.25 * d3) + * wd4 = (d6 + 0.25 * d2 - 1.25 * d4) - 2.0 * (d5 + 0.25 * d1 - 1.25 * d3) + * wd5 = (d6 - 5.0 * d4 + 4.0 * d2) + 2.0 * (d1 + 0.25 * d5 - 1.25 * d3) + * wd6 = (d6 - 5.0 * d4 + 4.0 * d2) - 2.0 * (d1 + 0.25 * d5 - 1.25 * d3) + * wd7 = (d7 - d1) + 5.25 * (d3 - d5) + */ +#define INPUT_TRANSFORM(d, wd) \ + do { \ + wd##0 = (d##0 - d##6) + (d##4 - d##2) * 5.25; \ + auto tmp0 = d##6 + d##2 - d##4 * 4.25; \ + auto tmp1 = d##1 + d##5 - d##3 * 4.25; \ + wd##1 = tmp0 + tmp1; \ + wd##2 = tmp0 - tmp1; \ + tmp0 = d##6 + d##2 * 0.25 - d##4 * 1.25; \ + tmp1 = (d##5 + d##1 * 0.25 - d##3 * 1.25) * 2.0; \ + wd##3 = tmp0 + tmp1; \ + wd##4 = tmp0 - tmp1; \ + tmp0 = d6 - d4 * 5.0 + d2 * 4.0; \ + tmp1 = (d1 + d5 * 0.25 - d3 * 1.25) * 2.0; \ + wd##5 = tmp0 + tmp1; \ + wd##6 = tmp0 - tmp1; \ + wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25; \ + } while (0); + +#define GET_VECTOR_FP16Q_ELEM(s, i, idx) vgetq_lane_f16(CONCAT(s, i).value, idx) +struct InputTransform6x3 { + template + static void transform(const __fp16* input, __fp16* input_transform_buf, + __fp16* transform_mid_buf, int ih_start, int iw_start, + size_t ic, size_t IH, size_t IW, size_t IC, + size_t unit_idx, size_t nr_units_in_tile) { + // BTd * B + // 1.000 0.000 -5.25 0.000 5.250 0.000 -1.0 0.00 + // -0.00 1.000 1.000 -4.25 -4.25 1.000 1.00 -0.0 + // -0.00 -1.00 1.000 4.250 -4.25 -1.00 1.00 -0.0 + // 0.000 0.500 0.250 -2.50 -1.25 2.000 1.00 0.00 + // 0.000 -0.50 0.250 2.500 -1.25 -2.00 1.00 0.00 + // 0.000 2.000 4.000 -2.50 -5.00 0.500 1.00 0.00 + // 0.000 -2.00 4.000 2.500 -5.00 -0.50 1.00 0.00 + // 0.000 -1.00 0.000 5.250 0.000 -5.25 0.00 1.00 + constexpr size_t alpha = 6 + 3 - 1; + if (!inner) { + memset(transform_mid_buf, 0, sizeof(__fp16) * alpha * alpha); + 
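            // Border tile: the alpha x alpha staging buffer is zero-filled
            // first, so taps that fall outside the image contribute zero to
            // the transform; only the in-bounds sub-rectangle is copied in
            // later.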
} + +#define cb(i) Vector<__fp16, 8> d##i; + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + + if (inner) { + const __fp16* input_ptr = + input + ic * IH * IW + ih_start * IW + iw_start; +#define cb(i) d##i = Vector<__fp16, 8>::load(input_ptr + IW * i); + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + } else { + int ih0_act = std::max(ih_start, 0), + ih1_act = std::min(ih_start + alpha, IH), + iw0_act = std::max(iw_start, 0), + iw1_act = std::min(iw_start + alpha, IW); + for (int ih = ih0_act; ih < ih1_act; ++ih) { + for (int iw = iw0_act; iw < iw1_act; ++iw) { + size_t iho = ih - ih_start, iwo = iw - iw_start; + transform_mid_buf[iho * alpha + iwo] = + input[ic * IH * IW + ih * IW + iw]; + } + } +#define cb(i) d##i = Vector<__fp16, 8>::load(transform_mid_buf + alpha * i); + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + } + +#define cb(i) Vector<__fp16, 8> wd##i, ret##i; + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + + INPUT_TRANSFORM(d, wd); + +#if MEGDNN_AARCH64 + TRANSPOSE_8x8(wd, d); + INPUT_TRANSFORM(d, ret); + +#define cb(i) ret##i.save(transform_mid_buf + i * alpha); + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + + rep(i, alpha) rep(j, alpha) { + input_transform_buf[(i * alpha + j) * nr_units_in_tile * IC + + unit_idx * IC + ic] = + transform_mid_buf[j * alpha + i]; + } +#else + //! 1 0 0 0 0 0 0 0 + //! 0 1 -1 0.5 -0.5 2 -2 -1 + //! -5.25 1 1 0.25 0.25 4 4 0 + //! 0 -4.25 4.25 -2.5 2.5 -2.5 2.5 5.25 + //! 5.25 -4.25 -4.25 -1.25 -1.25 -5 -5 0 + //! 0 1 -1 2 -2 0.5 -0.5 -5.25 + //! -1 1 1 1 1 1 1 0 + //! 0 0 0 0 0 0 0 1 +#define cb(i) \ + do { \ + mid_buf1[0] = GET_VECTOR_FP16Q_ELEM(wd, i, 0) - \ + GET_VECTOR_FP16Q_ELEM(wd, i, 6) + \ + 5.25 * (GET_VECTOR_FP16Q_ELEM(wd, i, 4) - \ + GET_VECTOR_FP16Q_ELEM(wd, i, 2)); \ + mid_buf1[7] = GET_VECTOR_FP16Q_ELEM(wd, i, 7) - \ + GET_VECTOR_FP16Q_ELEM(wd, i, 1) + \ + 5.25 * (GET_VECTOR_FP16Q_ELEM(wd, i, 3) - \ + GET_VECTOR_FP16Q_ELEM(wd, i, 5)); \ + auto tmp0 = GET_VECTOR_FP16Q_ELEM(wd, i, 2) + \ + GET_VECTOR_FP16Q_ELEM(wd, i, 6) - \ + 4.25 * GET_VECTOR_FP16Q_ELEM(wd, i, 4); \ + auto tmp1 = GET_VECTOR_FP16Q_ELEM(wd, i, 1) - \ + GET_VECTOR_FP16Q_ELEM(wd, i, 3) * 4.25 + \ + GET_VECTOR_FP16Q_ELEM(wd, i, 5); \ + mid_buf1[1] = tmp0 + tmp1; \ + mid_buf1[2] = tmp0 - tmp1; \ + tmp0 = GET_VECTOR_FP16Q_ELEM(wd, i, 2) * 0.25 + \ + GET_VECTOR_FP16Q_ELEM(wd, i, 6) - \ + GET_VECTOR_FP16Q_ELEM(wd, i, 4) * 1.25; \ + tmp1 = GET_VECTOR_FP16Q_ELEM(wd, i, 1) * 0.5 - \ + GET_VECTOR_FP16Q_ELEM(wd, i, 3) * 2.5 + \ + GET_VECTOR_FP16Q_ELEM(wd, i, 5) * 2; \ + mid_buf1[3] = tmp0 + tmp1; \ + mid_buf1[4] = tmp0 - tmp1; \ + tmp0 = GET_VECTOR_FP16Q_ELEM(wd, i, 6) + \ + GET_VECTOR_FP16Q_ELEM(wd, i, 2) * 4.0 - \ + GET_VECTOR_FP16Q_ELEM(wd, i, 4) * 5.0; \ + tmp1 = GET_VECTOR_FP16Q_ELEM(wd, i, 1) * 2 - \ + GET_VECTOR_FP16Q_ELEM(wd, i, 3) * 2.5 + \ + GET_VECTOR_FP16Q_ELEM(wd, i, 5) * 0.5; \ + mid_buf1[5] = tmp0 + tmp1; \ + mid_buf1[6] = tmp0 - tmp1; \ + mid_buf1 += 8; \ + } while (0); + + __fp16* mid_buf1 = transform_mid_buf; + UNROLL_CALL_NOWRAPPER(8, cb); + mid_buf1 = transform_mid_buf; + +#undef cb + rep(i, alpha) rep(j, alpha) { + input_transform_buf[(i * alpha + j) * nr_units_in_tile * IC + + unit_idx * IC + ic] = + transform_mid_buf[i * alpha + j]; + } +#endif + } +}; + +#undef INPUT_TRANSFORM + +#define OUTPUT_TRANSFORM(m, r) \ + do { \ + auto m1addm2 = m##1 + m##2; \ + auto m1subm2 = m##1 - m##2; \ + auto m3addm4 = m##3 + m##4; \ + auto m3subm4 = m##3 - m##4; \ + auto m5addm6 = m##5 + m##6; \ + auto m5subm6 = m##5 - m##6; \ + r##0 = m##0 + m1addm2 + m3addm4 + m5addm6; \ + r##1 = m1subm2 
+ m3subm4 * 2.0 + m5subm6 * 0.5; \ + r##2 = m1addm2 + m3addm4 * 4.0 + m5addm6 * 0.25; \ + r##3 = m1subm2 + m3subm4 * 8.0 + m5subm6 * 0.125; \ + r##4 = m1addm2 + m3addm4 * 16.0 + m5addm6 * 0.0625; \ + r##5 = m1subm2 + m3subm4 * 32.0 + m5subm6 * 0.03125 + m##7; \ + } while (0) +template +struct OutputTransform6X3 { + static void transform(const dt_float16* output_transform_buf, + const dt_float16* bias, dt_float16* output, + dt_float16* transform_mid_buf, size_t oh_start, + size_t ow_start, size_t OH, size_t OW, + size_t oc_start, size_t oc_end, size_t oc_index, + size_t unit_idx, size_t nr_units_in_tile, + const DType& src_dtype, const DType& dst_dtype) { + Op op(src_dtype, dst_dtype); + //! AT * m * A + // AT f45 + // 1.0 1.0 1.0 1.0 1.0 1.00000 1.00000 0.0 + // 0.0 1.0 -1.0 2.0 -2.0 0.50000 -0.50000 0.0 + // 0.0 1.0 1.0 4.0 4.0 0.25000 0.25000 0.0 + // 0.0 1.0 -1.0 8.0 -8.0 0.12500 -0.12500 0.0 + // 0.0 1.0 1.0 16.0 16.0 0.06250 0.06250 0.0 + // 0.0 1.0 -1.0 32.0 -32.0 0.03125 -0.03125 1.0 + constexpr size_t alpha = 3 + 6 - 1; + const __fp16* fp16_output_transform_buf = + reinterpret_cast(output_transform_buf); + const __fp16* fp16_bias = reinterpret_cast(bias); + __fp16* fp16_output = reinterpret_cast<__fp16*>(output); + __fp16* fp16_transform_mid_buf = + reinterpret_cast<__fp16*>(transform_mid_buf); + + __fp16* mid_buf1 = fp16_transform_mid_buf; + + size_t OC = oc_end - oc_start; + size_t oc = oc_start + oc_index; + +#define cb(m, n) \ + fp16_transform_mid_buf[m * alpha + n] = \ + fp16_output_transform_buf[(m * alpha + n) * nr_units_in_tile * \ + OC + \ + unit_idx * OC + oc_index]; + UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); +#undef cb + +#define cb(i) \ + auto m##i = Vector<__fp16, 8>::load(fp16_transform_mid_buf + alpha * i); + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb +#define cb(i) Vector<__fp16, 8> s##i; + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + /* 1.0 0.0 0.00 0.000 0.0000 0.00000 + 1.0 1.0 1.00 1.000 1.0000 1.00000 + 1.0 -1.0 1.00 -1.000 1.0000 -1.00000 + 1.0 2.0 4.00 8.000 16.000 32.00000 + 1.0 -2.0 4.00 -8.000 16.000 -32.00000 + 1.0 0.5 0.25 0.125 0.0625 0.03125 + 1.0 -0.5 0.25 -0.125 0.0625 -0.03125 + 0.0 0.0 0.00 0.000 0.0000 1.00000*/ + + OUTPUT_TRANSFORM(m, s); + mid_buf1 = fp16_transform_mid_buf; + +#define cb(i) \ + do { \ + auto m1addm2 = GET_VECTOR_FP16Q_ELEM(s, i, 1) + \ + GET_VECTOR_FP16Q_ELEM(s, i, 2); \ + auto m1subm2 = GET_VECTOR_FP16Q_ELEM(s, i, 1) - \ + GET_VECTOR_FP16Q_ELEM(s, i, 2); \ + auto m3addm4 = GET_VECTOR_FP16Q_ELEM(s, i, 3) + \ + GET_VECTOR_FP16Q_ELEM(s, i, 4); \ + auto m3subm4 = GET_VECTOR_FP16Q_ELEM(s, i, 3) - \ + GET_VECTOR_FP16Q_ELEM(s, i, 4); \ + auto m5addm6 = GET_VECTOR_FP16Q_ELEM(s, i, 5) + \ + GET_VECTOR_FP16Q_ELEM(s, i, 6); \ + auto m5subm6 = GET_VECTOR_FP16Q_ELEM(s, i, 5) - \ + GET_VECTOR_FP16Q_ELEM(s, i, 6); \ + mid_buf1[0] = \ + GET_VECTOR_FP16Q_ELEM(s, i, 0) + m1addm2 + m3addm4 + m5addm6; \ + mid_buf1[1] = m1subm2 + m3subm4 * 2 + m5subm6 * 0.5; \ + mid_buf1[2] = m1addm2 + m3addm4 * 4 + m5addm6 * 0.25; \ + mid_buf1[3] = m1subm2 + m3subm4 * 8 + m5subm6 * 0.125; \ + mid_buf1[4] = m1addm2 + m3addm4 * 16 + m5addm6 * 0.0625; \ + mid_buf1[5] = m1subm2 + m3subm4 * 32 + m5subm6 * 0.03125 + \ + GET_VECTOR_FP16Q_ELEM(s, i, 7); \ + mid_buf1 += 6; \ + } while (0); + mid_buf1 = fp16_transform_mid_buf; + UNROLL_CALL_NOWRAPPER(6, cb); + mid_buf1 = fp16_transform_mid_buf; + +#undef cb + + if (oh_start + 6 <= OH && ow_start + 6 <= OW) { + int index = (oc * OH + oh_start) * OW + ow_start; + +#define cb(i) float16x4_t vr##i = vld1_f16(mid_buf1 + i * 6); + + 
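            // Fast path: the whole 6x6 output block lies inside (OH, OW).
            // Columns 0-3 of each of the six rows are kept in vr0..vr5,
            // while columns 4-5 of rows 0-3 are packed into vr0123_45 and
            // those of rows 4-5 into vr45_45, so the bias add and the
            // activation op below stay fully vectorized.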
UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + float16x8_t vr0123_45 = {mid_buf1[4], mid_buf1[5], mid_buf1[10], + mid_buf1[11], mid_buf1[16], mid_buf1[17], + mid_buf1[22], mid_buf1[23]}; + float16x4_t vr45_45 = {mid_buf1[28], mid_buf1[29], mid_buf1[34], + mid_buf1[35]}; + + if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + float16x4_t bias0 = vdup_n_f16(fp16_bias[oc]); +#define cb(i) vr##i = vadd_f16(vr##i, bias0); + + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + vr45_45 = vadd_f16(vr45_45, bias0); + vr0123_45 = vaddq_f16(vr0123_45, vcombine_f16(bias0, bias0)); + + } else if (bmode == BiasMode::BIAS) { +#define cb(i) float16x4_t bmbias##i = vld1_f16(fp16_bias + index + OW * i); + + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + +#define cb(i) vr##i = vadd_f16(vr##i, bmbias##i); + + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + float16x8_t vb0123_45 = {fp16_bias[index + 0 * OW + 4], + fp16_bias[index + 0 * OW + 5], + fp16_bias[index + 1 * OW + 4], + fp16_bias[index + 1 * OW + 5], + fp16_bias[index + 2 * OW + 4], + fp16_bias[index + 2 * OW + 5], + fp16_bias[index + 3 * OW + 4], + fp16_bias[index + 3 * OW + 5]}; + float16x4_t vb45_45 = {fp16_bias[index + 4 * OW + 4], + fp16_bias[index + 4 * OW + 5], + fp16_bias[index + 5 * OW + 4], + fp16_bias[index + 5 * OW + 5]}; + vr45_45 = vadd_f16(vr45_45, vb45_45); + vr0123_45 = vaddq_f16(vr0123_45, vb0123_45); + } + + float16x8_t item01 = op(vcombine_f16(vr0, vr1)); + float16x8_t item23 = op(vcombine_f16(vr2, vr3)); + float16x8_t item45 = op(vcombine_f16(vr4, vr5)); + + vst1_f16(fp16_output + index, vget_low_f16(item01)); + vst1_f16(fp16_output + index + OW, vget_high_f16(item01)); + vst1_f16(fp16_output + index + OW * 2, vget_low_f16(item23)); + vst1_f16(fp16_output + index + OW * 3, vget_high_f16(item23)); + vst1_f16(fp16_output + index + OW * 4, vget_low_f16(item45)); + vst1_f16(fp16_output + index + OW * 5, vget_high_f16(item45)); + vr0123_45 = op(vr0123_45); + float16x8_t vr45 = op(vcombine_f16(vr45_45, vr45_45)); + + fp16_output[index + OW * 0 + 4] = vgetq_lane_f16(vr0123_45, 0); + fp16_output[index + OW * 0 + 5] = vgetq_lane_f16(vr0123_45, 1); + fp16_output[index + OW * 1 + 4] = vgetq_lane_f16(vr0123_45, 2); + fp16_output[index + OW * 1 + 5] = vgetq_lane_f16(vr0123_45, 3); + fp16_output[index + OW * 2 + 4] = vgetq_lane_f16(vr0123_45, 4); + fp16_output[index + OW * 2 + 5] = vgetq_lane_f16(vr0123_45, 5); + fp16_output[index + OW * 3 + 4] = vgetq_lane_f16(vr0123_45, 6); + fp16_output[index + OW * 3 + 5] = vgetq_lane_f16(vr0123_45, 7); + fp16_output[index + OW * 4 + 4] = vgetq_lane_f16(vr45, 0); + fp16_output[index + OW * 4 + 5] = vgetq_lane_f16(vr45, 1); + fp16_output[index + OW * 5 + 4] = vgetq_lane_f16(vr45, 2); + fp16_output[index + OW * 5 + 5] = vgetq_lane_f16(vr45, 3); + } else { + for (size_t oho = 0; oho < 6 && oh_start + oho < OH; ++oho) { + for (size_t owo = 0; owo < 6 && ow_start + owo < OW; ++owo) { + size_t oh = oh_start + oho; + size_t ow = ow_start + owo; + __fp16 res = mid_buf1[oho * 6 + owo]; + if (bmode == BiasMode::BIAS) { + res += fp16_bias[oc * OH * OW + oh * OW + ow]; + } else if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + res += fp16_bias[oc]; + } + res = op(res); + fp16_output[oc * OH * OW + oh * OW + ow] = res; + } + } + } + } +}; +#undef GET_VECTOR_FP16Q_ELEM +#undef OUTPUT_TRANSFORM +} // namespace + +namespace megdnn { +namespace arm_common { +namespace winograd { + +MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_6x3_1x1_f16) + +void winograd_6x3_1x1_f16::filter(const dt_float16* filter, + dt_float16* filter_transform_buf, + 
dt_float16* transform_mid_buf, size_t OC, + size_t IC, size_t oc_start, size_t oc_end) { + FilterTransform6X3::transform( + reinterpret_cast(filter), + reinterpret_cast<__fp16*>(filter_transform_buf), + reinterpret_cast<__fp16*>(transform_mid_buf), OC, IC, oc_start, + oc_end); +} + +void winograd_6x3_1x1_f16::input(const dt_float16* input, + dt_float16* input_transform_buf, + dt_float16* transform_mid_buf, size_t IH, + size_t IW, size_t IC, size_t PH, size_t PW, + size_t unit_start_idx, + size_t nr_units_in_tile) { + constexpr int alpha = 6 + 3 - 1; + // OW = IW + 2 * PW - KERNEL_SIZE + 1 + auto units_w = div_ceil(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); + rep(ic, IC) { + rep(unit_idx, nr_units_in_tile) { + size_t index = unit_start_idx + unit_idx; + size_t nh = index / units_w; + size_t nw = index % units_w; + int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; + int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; + if (ih_start >= 0 && ih_start + alpha <= static_cast(IH) && + iw_start >= 0 && iw_start + alpha <= static_cast(IW)) { + InputTransform6x3::transform( + reinterpret_cast(input), + reinterpret_cast<__fp16*>(input_transform_buf), + reinterpret_cast<__fp16*>(transform_mid_buf), ih_start, + iw_start, ic, IH, IW, IC, unit_idx, nr_units_in_tile); + + } else { + InputTransform6x3::transform( + reinterpret_cast(input), + reinterpret_cast<__fp16*>(input_transform_buf), + reinterpret_cast<__fp16*>(transform_mid_buf), ih_start, + iw_start, ic, IH, IW, IC, unit_idx, nr_units_in_tile); + } + } + } +} + +void winograd_6x3_1x1_f16::output(const dt_float16* output_transform_buf, + const dt_float16* bias, dt_float16* output, + dt_float16* transform_mid_buf, BiasMode bmode, + NonlineMode nonline_mode, size_t OH, size_t OW, + size_t oc_start, size_t oc_end, + size_t unit_start_idx, size_t nr_units_in_tile) { +#define cb(_bmode, _nonline_op, ...) \ + OutputTransform6X3<_bmode MEGDNN_COMMA _nonline_op>::transform(__VA_ARGS__); + + auto units_w = div_ceil(OW, OUTPUT_BLOCK_SIZE); + + for (size_t oc = oc_start; oc < oc_end; oc++) { + size_t oc_index = oc - oc_start; + rep(unit_idx, nr_units_in_tile) { + size_t index = unit_start_idx + unit_idx; + auto nh = index / units_w; + auto nw = index % units_w; + size_t oh_start = nh * OUTPUT_BLOCK_SIZE; + size_t ow_start = nw * OUTPUT_BLOCK_SIZE; + DISPATCH_CONV_WINOGRAD_BIAS( + megdnn_arm_common_winograd_fp16_F63, cb, __fp16, __fp16, bmode, + nonline_mode, output_transform_buf, bias, output, transform_mid_buf, + oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx, + nr_units_in_tile, src_dtype, dst_dtype); + } + } +#undef cb +} +} // namespace winograd +} // namespace arm_common +} // namespace megdnn +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/algos.cpp b/dnn/src/arm_common/conv_bias/fp32/algos.cpp new file mode 100644 index 00000000..28e08032 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/algos.cpp @@ -0,0 +1,754 @@ +/** + * \file dnn/src/arm_common/conv_bias/fp32/algos.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
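 *
 * Entry points (usable / get_workspace / dispatch_kerns) for the arm_common
 * fp32 conv_bias algorithms: the Winograd variants F(2,3) 4x4, F(6,3),
 * F(6,3) 4x4, F(5,4) and F(4,5) built on top of a fallback matmul algorithm,
 * plus the NEON direct, stride-1 and stride-2 kernels.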
+ */ + +#include "src/arm_common/conv_bias/fp32/algos.h" +#include "src/arm_common/conv_bias/direct/multi_thread_common.h" +#include "src/arm_common/conv_bias/fp32/direct.h" +#include "src/arm_common/conv_bias/fp32/do_conv_stride1.h" +#include "src/arm_common/conv_bias/fp32/do_conv_stride2.h" +#include "src/arm_common/conv_bias/fp32/strategy.h" +#include "src/arm_common/conv_bias/img2col_helper.h" +#include "src/arm_common/conv_bias/postprocess_helper.h" +#include "src/common/opr_delegate.h" +#include "src/fallback/conv_bias/common.h" + +#include "midout.h" + +MIDOUT_DECL(megdnn_arm_common_winograd_fp32) + +using namespace megdnn; +using namespace arm_common; + +/* ======================= AlgoFP32WinogradF23_4x4 ======================== */ + +bool ConvBiasImpl::AlgoFP32WinogradF23_4x4::usable( + fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, + AlgoSelectionStrategy /*algo_selection_strategy*/) const { + MEGDNN_MARK_USED_VAR(opr); + MEGDNN_MARK_USED_VAR(param); + MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 0, 0) { + if (param.filter_meta.icpg % 4 != 0 || param.filter_meta.ocpg % 4 != 0) + return false; + using Strategy = winograd::winograd_2x3_4x4_f; + Strategy strategy(param.src_type, param.filter_type, param.dst_type); + auto&& matmul_param = + megdnn::winograd::ConvBias( + strategy, m_tile_size, param.nr_threads, param.osz[0], + param.osz[1], param.filter_meta.ocpg) + .get_matmul_kern_param(param); + return m_matmul_algo->usable(matmul_param) && + (opr->param().format == param::ConvBias::Format::NCHW || + (opr->param().format == + param::ConvBias::Format::NCHW_WINOGRAD && + opr->param().output_block_size == 2 && + param.winograd_matmul_format == + param::MatrixMul::Format::MK4)) && + opr->param().mode == param::ConvBias::Mode::CROSS_CORRELATION && + (param.filter_meta.spatial[0] == param.filter_meta.spatial[1] && + param.filter_meta.spatial[0] == 3) && + (param.filter_meta.stride[0] == param.filter_meta.stride[1] && + param.filter_meta.stride[0] == 1) && + (param.filter_meta.dilation[0] == + param.filter_meta.dilation[1] && + param.filter_meta.dilation[0] == 1) && + param.compute_mode == param::ConvBias::ComputeMode::DEFAULT && + param.src_type.enumv() == DTypeEnum::Float32; + } + MIDOUT_END(); + return false; +} + +size_t ConvBiasImpl::AlgoFP32WinogradF23_4x4::get_workspace( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + MEGDNN_MARK_USED_VAR(param); + MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 0, 1) { + winograd::winograd_2x3_4x4_f strategy(param.src_type, param.filter_type, + param.dst_type); + return megdnn::winograd::ConvBias( + strategy, m_tile_size, param.nr_threads, param.osz[0], + param.osz[1], param.filter_meta.ocpg) + .get_workspace_size(param, m_matmul_algo); + } + MIDOUT_END(); + return 0; +} + +SmallVector +ConvBiasImpl::AlgoFP32WinogradF23_4x4::dispatch_kerns( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + MEGDNN_MARK_USED_VAR(param); + MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 0, 2) { + winograd::winograd_2x3_4x4_f strategy(param.src_type, param.filter_type, + param.dst_type); + auto winograd_impl = + megdnn::winograd::ConvBias( + strategy, m_tile_size, param.nr_threads, param.osz[0], + param.osz[1], param.filter_meta.ocpg); + return winograd_impl.get_kerns(param, m_matmul_algo); + } + MIDOUT_END(); + return {}; +} + +/* ======================= AlgoFP32WinogradF63 ======================== */ + +bool ConvBiasImpl::AlgoFP32WinogradF63::usable( + fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, + 
AlgoSelectionStrategy /*algo_selection_strategy*/) const { + MEGDNN_MARK_USED_VAR(param); + MEGDNN_MARK_USED_VAR(opr); + MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 1, 0) { + using Strategy = winograd::winograd_6x3_1x1_f; + Strategy strategy(param.src_type, param.filter_type, param.dst_type); + auto&& matmul_param = + megdnn::winograd::ConvBias( + strategy, m_tile_size, param.nr_threads, param.osz[0], + param.osz[1], param.filter_meta.ocpg) + .get_matmul_kern_param(param); + return m_matmul_algo->usable(matmul_param) && + (opr->param().format == param::ConvBias::Format::NCHW || + (opr->param().format == + param::ConvBias::Format::NCHW_WINOGRAD && + opr->param().output_block_size == 6 && + param.winograd_matmul_format == + param::MatrixMul::Format::DEFAULT)) && + opr->param().mode == param::ConvBias::Mode::CROSS_CORRELATION && + (param.filter_meta.spatial[0] == param.filter_meta.spatial[1] && + param.filter_meta.spatial[0] == 3) && + (param.filter_meta.stride[0] == param.filter_meta.stride[1] && + param.filter_meta.stride[0] == 1) && + (param.filter_meta.dilation[0] == + param.filter_meta.dilation[1] && + param.filter_meta.dilation[0] == 1) && + param.compute_mode == param::ConvBias::ComputeMode::DEFAULT && + param.src_type.enumv() == DTypeEnum::Float32; + } + MIDOUT_END(); + return false; +} + +size_t ConvBiasImpl::AlgoFP32WinogradF63::get_workspace( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + MEGDNN_MARK_USED_VAR(param); + MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 1, 1) { + winograd::winograd_6x3_1x1_f strategy(param.src_type, param.filter_type, + param.dst_type); + return megdnn::winograd::ConvBias( + strategy, m_tile_size, param.nr_threads, param.osz[0], + param.osz[1], param.filter_meta.ocpg) + .get_workspace_size(param, m_matmul_algo); + } + MIDOUT_END(); + return 0; +} + +SmallVector +ConvBiasImpl::AlgoFP32WinogradF63::dispatch_kerns( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + MEGDNN_MARK_USED_VAR(param); + MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 1, 2) { + winograd::winograd_6x3_1x1_f strategy(param.src_type, param.filter_type, + param.dst_type); + auto winograd_impl = + megdnn::winograd::ConvBias( + strategy, m_tile_size, param.nr_threads, param.osz[0], + param.osz[1], param.filter_meta.ocpg); + return winograd_impl.get_kerns(param, m_matmul_algo); + } + MIDOUT_END(); + return {}; +} + +/* ======================= AlgoFP32WinogradF54 ======================== */ + +bool ConvBiasImpl::AlgoFP32WinogradF54::usable( + fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, + AlgoSelectionStrategy /*algo_selection_strategy*/) const { + MEGDNN_MARK_USED_VAR(param); + MEGDNN_MARK_USED_VAR(opr); + MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 2, 0) { + using Strategy = winograd::winograd_5x4_1x1_f; + Strategy strategy(param.src_type, param.filter_type, param.dst_type); + auto&& matmul_param = + megdnn::winograd::ConvBias( + strategy, m_tile_size, param.nr_threads, param.osz[0], + param.osz[1], param.filter_meta.ocpg) + .get_matmul_kern_param(param); + return m_matmul_algo->usable(matmul_param) && + (opr->param().format == param::ConvBias::Format::NCHW || + (opr->param().format == + param::ConvBias::Format::NCHW_WINOGRAD && + opr->param().output_block_size == 5 && + param.winograd_matmul_format == + param::MatrixMul::Format::DEFAULT)) && + opr->param().mode == param::ConvBias::Mode::CROSS_CORRELATION && + (param.filter_meta.spatial[0] == param.filter_meta.spatial[1] && + param.filter_meta.spatial[0] == 4) && + 
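                   // F(5,4) consumes 4x4 filters: output block 5 + kernel 4 - 1
                   // gives an 8-point transform, hence the spatial == 4 check.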
(param.filter_meta.stride[0] == param.filter_meta.stride[1] && + param.filter_meta.stride[0] == 1) && + (param.filter_meta.dilation[0] == + param.filter_meta.dilation[1] && + param.filter_meta.dilation[0] == 1) && + param.compute_mode == param::ConvBias::ComputeMode::DEFAULT && + param.src_type.enumv() == DTypeEnum::Float32; + } + MIDOUT_END(); + return false; +} + +size_t ConvBiasImpl::AlgoFP32WinogradF54::get_workspace( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + MEGDNN_MARK_USED_VAR(param); + MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 2, 1) { + winograd::winograd_5x4_1x1_f strategy(param.src_type, param.filter_type, + param.dst_type); + return megdnn::winograd::ConvBias( + strategy, m_tile_size, param.nr_threads, param.osz[0], + param.osz[1], param.filter_meta.ocpg) + .get_workspace_size(param, m_matmul_algo); + } + MIDOUT_END(); + return 0; +} + +SmallVector +ConvBiasImpl::AlgoFP32WinogradF54::dispatch_kerns( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + MEGDNN_MARK_USED_VAR(param); + MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 2, 2) { + winograd::winograd_5x4_1x1_f strategy(param.src_type, param.filter_type, + param.dst_type); + auto winograd_impl = + megdnn::winograd::ConvBias( + strategy, m_tile_size, param.nr_threads, param.osz[0], + param.osz[1], param.filter_meta.ocpg); + return winograd_impl.get_kerns(param, m_matmul_algo); + } + MIDOUT_END(); + return {}; +} + +/* ======================= AlgoFP32WinogradF45 ======================== */ + +bool ConvBiasImpl::AlgoFP32WinogradF45::usable( + fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, + AlgoSelectionStrategy /*algo_selection_strategy*/) const { + MEGDNN_MARK_USED_VAR(param); + MEGDNN_MARK_USED_VAR(opr); + MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 3, 0) { + using Strategy = winograd::winograd_4x5_1x1_f; + Strategy strategy(param.src_type, param.filter_type, param.dst_type); + auto&& matmul_param = + megdnn::winograd::ConvBias( + strategy, m_tile_size, param.nr_threads, param.osz[0], + param.osz[1], param.filter_meta.ocpg) + .get_matmul_kern_param(param); + return m_matmul_algo->usable(matmul_param) && + (opr->param().format == param::ConvBias::Format::NCHW || + (opr->param().format == + param::ConvBias::Format::NCHW_WINOGRAD && + opr->param().output_block_size == 4 && + param.winograd_matmul_format == + param::MatrixMul::Format::DEFAULT)) && + opr->param().mode == param::ConvBias::Mode::CROSS_CORRELATION && + (param.filter_meta.spatial[0] == param.filter_meta.spatial[1] && + param.filter_meta.spatial[0] == 5) && + (param.filter_meta.stride[0] == param.filter_meta.stride[1] && + param.filter_meta.stride[0] == 1) && + (param.filter_meta.dilation[0] == + param.filter_meta.dilation[1] && + param.filter_meta.dilation[0] == 1) && + param.compute_mode == param::ConvBias::ComputeMode::DEFAULT && + param.src_type.enumv() == DTypeEnum::Float32; + } + MIDOUT_END(); + return false; +} + +size_t ConvBiasImpl::AlgoFP32WinogradF45::get_workspace( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + MEGDNN_MARK_USED_VAR(param); + MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 3, 1) { + winograd::winograd_4x5_1x1_f strategy(param.src_type, param.filter_type, + param.dst_type); + return megdnn::winograd::ConvBias( + strategy, m_tile_size, param.nr_threads, param.osz[0], + param.osz[1], param.filter_meta.ocpg) + .get_workspace_size(param, m_matmul_algo); + } + MIDOUT_END(); + return 0; +} + +SmallVector +ConvBiasImpl::AlgoFP32WinogradF45::dispatch_kerns( + 
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + MEGDNN_MARK_USED_VAR(param); + MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 3, 2) { + winograd::winograd_4x5_1x1_f strategy(param.src_type, param.filter_type, + param.dst_type); + auto winograd_impl = + megdnn::winograd::ConvBias( + strategy, m_tile_size, param.nr_threads, param.osz[0], + param.osz[1], param.filter_meta.ocpg); + return winograd_impl.get_kerns(param, m_matmul_algo); + } + MIDOUT_END(); + return {}; +} + +/* ======================= AlgoFP32WinogradF63_4x4 ======================== */ + +bool ConvBiasImpl::AlgoFP32WinogradF63_4x4::usable( + fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, + AlgoSelectionStrategy /*algo_selection_strategy*/) const { + MEGDNN_MARK_USED_VAR(param); + MEGDNN_MARK_USED_VAR(opr); + MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 4, 0) { + if (param.filter_meta.icpg % 4 != 0 || param.filter_meta.ocpg % 4 != 0) + return false; + using Strategy = winograd::winograd_6x3_4x4_f; + Strategy strategy(param.src_type, param.filter_type, param.dst_type); + auto&& matmul_param = + megdnn::winograd::ConvBias( + strategy, m_tile_size, param.nr_threads, param.osz[0], + param.osz[1], param.filter_meta.ocpg) + .get_matmul_kern_param(param); + return m_matmul_algo->usable(matmul_param) && + (opr->param().format == param::ConvBias::Format::NCHW || + (opr->param().format == + param::ConvBias::Format::NCHW_WINOGRAD && + opr->param().output_block_size == 6 && + param.winograd_matmul_format == + param::MatrixMul::Format::MK4)) && + opr->param().mode == param::ConvBias::Mode::CROSS_CORRELATION && + (param.filter_meta.spatial[0] == param.filter_meta.spatial[1] && + param.filter_meta.spatial[0] == 3) && + (param.filter_meta.stride[0] == param.filter_meta.stride[1] && + param.filter_meta.stride[0] == 1) && + (param.filter_meta.dilation[0] == + param.filter_meta.dilation[1] && + param.filter_meta.dilation[0] == 1) && + param.compute_mode == param::ConvBias::ComputeMode::DEFAULT && + param.src_type.enumv() == DTypeEnum::Float32 && + param.filter_meta.icpg % 4 == 0 && + param.filter_meta.ocpg % 4 == 0; + } + MIDOUT_END(); + return false; +} + +size_t ConvBiasImpl::AlgoFP32WinogradF63_4x4::get_workspace( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + MEGDNN_MARK_USED_VAR(param); + MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 4, 1) { + winograd::winograd_6x3_4x4_f strategy(param.src_type, param.filter_type, + param.dst_type); + return megdnn::winograd::ConvBias( + strategy, m_tile_size, param.nr_threads, param.osz[0], + param.osz[1], param.filter_meta.ocpg) + .get_workspace_size(param, m_matmul_algo); + } + MIDOUT_END(); + return 0; +} + +SmallVector +ConvBiasImpl::AlgoFP32WinogradF63_4x4::dispatch_kerns( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + MEGDNN_MARK_USED_VAR(param); + MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 4, 2) { + winograd::winograd_6x3_4x4_f strategy(param.src_type, param.filter_type, + param.dst_type); + auto winograd_impl = + megdnn::winograd::ConvBias( + strategy, m_tile_size, param.nr_threads, param.osz[0], + param.osz[1], param.filter_meta.ocpg); + return winograd_impl.get_kerns(param, m_matmul_algo); + } + MIDOUT_END(); + return {}; +} + +/* ===================== direct algo ===================== */ +MIDOUT_DECL(megdnn_arm_common_conv_bias_f32_kimpl); + +bool ConvBiasImpl::AlgoF32Direct::usable( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const { + 
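    // Under HEURISTIC selection the large-group and small-group variants of
    // this algorithm are made mutually exclusive by matching m_large_group
    // against (group >= nr_threads) below.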
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 0, 0) { + auto&& fm = param.filter_meta; + auto FH = fm.spatial[0]; + auto SH = fm.stride[0], SW = fm.stride[1]; + // the condition ``param.isz[0]*param.isz[1] >= 4'' and + // ``param.osz[0]*param.osz[1] >= 4'' comes from the fact that the + // kernel may have access to up to 4 floats after the end of the memory + // chunk. + bool aviliable = fm.format == param::ConvBias::Format::NCHW && + param.src_type.enumv() == DTypeEnum::Float32 && + param.filter_type.enumv() == DTypeEnum::Float32 && + param.dst_type.enumv() == DTypeEnum::Float32 && + fm.spatial_ndim == 2 && fm.dilation[0] == 1 && + fm.dilation[1] == 1 && + param.isz[0] * param.isz[1] >= 4 && + param.osz[0] * param.osz[1] >= 4 && FH <= 7 && + SH == 1 && SW == 1; + if (algo_selection_strategy == AlgoSelectionStrategy::HEURISTIC) { + bool large_group = param.filter_meta.group >= param.nr_threads; + aviliable &= (large_group == m_large_group); + } + return aviliable; + } + MIDOUT_END(); + return false; +} +size_t ConvBiasImpl::AlgoF32Direct::get_workspace( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 0, 1) { + auto wbundle = MultithreadDirectConvCommon::get_bundle( + param, m_large_group); + return wbundle.total_size_in_bytes(); + } + MIDOUT_END(); + return 0; +} +SmallVector ConvBiasImpl::AlgoF32Direct::get_kimpls( + const NCBKernSizeParam& param) const { + auto fm = param.filter_meta; + size_t N = param.n; + size_t IC = param.filter_meta.icpg; + size_t OC = param.filter_meta.ocpg; + size_t group = fm.group; + WorkspaceBundle wbundle = + MultithreadDirectConvCommon::get_bundle( + param, m_large_group); + SmallVector ret_kerns; + //! When group >= nr_threads, treat it as large_group, each thread process + //! one group for better performance + if (m_large_group) { + //! 
Channel wise conv and big groups + auto exec_one_group = [wbundle](const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + auto fm = kern_param.filter_meta; + size_t IC = fm.icpg; + size_t OC = fm.ocpg; + WorkspaceBundle bundle = wbundle; + if (fm.should_flip) { + for (size_t oc = 0; oc < OC; oc++) { + MultithreadDirectConvCommon::weight_flip_kern( + bundle, kern_param, ncb_index, + {ncb_index.thread_id, 0, oc}); + } + } + for (size_t ic = 0; ic < IC; ic++) { + MultithreadDirectConvCommon::copy_padding_kern( + bundle, kern_param, ncb_index, {ncb_index.thread_id, 0, ic}); + } + for (size_t oc = 0; oc < OC; oc++) { + MultithreadDirectConvCommon::do_conv_kern( + bundle, kern_param, ncb_index, + fp32::conv_bias::kern_direct, + {ncb_index.thread_id, 0, oc}); + } + }; + ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); + } else { + WorkspaceBundle bundle = wbundle; + if (fm.should_flip) { + auto weight_flip = [bundle](const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + MultithreadDirectConvCommon::weight_flip_kern( + bundle, kern_param, ncb_index, ncb_index.ndrange_id); + }; + ret_kerns.push_back({weight_flip, {group, 1_z, OC}}); + } + auto copy_padding = [bundle](const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + MultithreadDirectConvCommon::copy_padding_kern( + bundle, kern_param, ncb_index, ncb_index.ndrange_id); + }; + ret_kerns.push_back({copy_padding, {group, N, IC}}); + auto do_conv = [bundle](const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + MultithreadDirectConvCommon::do_conv_kern( + bundle, kern_param, ncb_index, fp32::conv_bias::kern_direct, + ncb_index.ndrange_id); + }; + ret_kerns.push_back({do_conv, {group, N, OC}}); + } + return ret_kerns; +} + +SmallVector ConvBiasImpl::AlgoF32Direct::dispatch_kerns( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 0, 1) { + return get_kimpls(param); + } + MIDOUT_END(); + return {}; +} +/* ===================== stride-1 algo ===================== */ +bool ConvBiasImpl::AlgoF32DirectStride1::usable( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const { + MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 1, 1) { + auto&& fm = param.filter_meta; + auto FH = fm.spatial[0]; + bool aviliable = + param.filter_meta.format == param::ConvBias::Format::NCHW && + param.src_type.enumv() == DTypeEnum::Float32 && + param.filter_type.enumv() == DTypeEnum::Float32 && + param.dst_type.enumv() == DTypeEnum::Float32 && + !fm.should_flip && fm.spatial_ndim == 2 && + fm.dilation[0] == 1 && fm.dilation[1] == 1 && + fm.stride[0] == 1 && fm.stride[1] == 1 && FH == fm.spatial[1] && + (FH == 2 || FH == 3 || FH == 5 || FH == 7); + if (algo_selection_strategy == + ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { + bool large_group = param.filter_meta.group >= param.nr_threads; + aviliable &= (large_group == m_large_group); + } + return aviliable; + } + MIDOUT_END(); + return false; +} + +size_t ConvBiasImpl::AlgoF32DirectStride1::get_workspace( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 1, 1) { + auto bundle = + MultithreadDirectConvCommon::get_bundle_stride( + param, m_large_group); + return bundle.total_size_in_bytes(); + } + MIDOUT_END(); + return 0; +} + +SmallVector +ConvBiasImpl::AlgoF32DirectStride1::get_kimpls( + const NCBKernSizeParam& param) const { + auto fm = param.filter_meta; + 
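    // Select the FH x FH stride-1 kernel, then emit either one fused kernel
    // per group (large-group case: the same thread does padding and conv for
    // its group) or separate copy_padding / do_conv kernels parallelized over
    // (group, batch, channel).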
auto FH = fm.spatial[0]; + size_t N = param.n; + size_t IC = param.filter_meta.icpg; + size_t OC = param.filter_meta.ocpg; + size_t group = fm.group; + using Func = std::function; + Func conv_kern_function = nullptr; + +#define SWITCH_KERN_STR1() \ + switch (FH) { \ + case 2: \ + conv_kern_function = fp32::conv_stride1::do_conv_2x2_stride1; \ + break; \ + case 3: \ + conv_kern_function = fp32::conv_stride1::do_conv_3x3_stride1; \ + break; \ + case 5: \ + conv_kern_function = fp32::conv_stride1::do_conv_5x5_stride1; \ + break; \ + case 7: \ + conv_kern_function = fp32::conv_stride1::do_conv_7x7_stride1; \ + break; \ + } + SWITCH_KERN_STR1(); + + WorkspaceBundle wbundle = + MultithreadDirectConvCommon::get_bundle_stride( + param, m_large_group); + SmallVector ret_kerns; + //! When group >= nr_threads, treat it as large_group, each thread process + //! one group for better performance + if (m_large_group) { + //! Channel wise conv and big groups + auto exec_one_group = [wbundle, conv_kern_function]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + auto fm = kern_param.filter_meta; + size_t IC = fm.icpg; + size_t OC = fm.ocpg; + WorkspaceBundle bundle = wbundle; + for (size_t ic = 0; ic < IC; ic++) { + MultithreadDirectConvCommon:: + copy_padding_kern_stride(bundle, kern_param, ncb_index, + {ncb_index.thread_id, 0, ic}); + } + for (size_t oc = 0; oc < OC; oc++) { + MultithreadDirectConvCommon::do_conv_kern_stride( + bundle, kern_param, ncb_index, conv_kern_function, + {ncb_index.thread_id, 0, oc}); + } + }; + ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); + } else { + WorkspaceBundle bundle = wbundle; + auto copy_padding = [bundle](const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + MultithreadDirectConvCommon::copy_padding_kern_stride( + bundle, kern_param, ncb_index, ncb_index.ndrange_id); + }; + ret_kerns.push_back({copy_padding, {group, N, IC}}); + auto do_conv = [bundle, conv_kern_function]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + MultithreadDirectConvCommon::do_conv_kern_stride( + bundle, kern_param, ncb_index, conv_kern_function, + ncb_index.ndrange_id); + }; + ret_kerns.push_back({do_conv, {group, N, OC}}); + } + return ret_kerns; +} + +SmallVector +ConvBiasImpl::AlgoF32DirectStride1::dispatch_kerns( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 1, 2) { + return get_kimpls(param); + } + MIDOUT_END(); + return {}; +} + +/* ===================== stride-2 algo ===================== */ + +bool ConvBiasImpl::AlgoF32DirectStride2::usable( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const { + MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 2, 0) { + auto&& fm = param.filter_meta; + auto FH = fm.spatial[0]; + bool aviliable = + param.filter_meta.format == param::ConvBias::Format::NCHW && + param.src_type.enumv() == DTypeEnum::Float32 && + param.filter_type.enumv() == DTypeEnum::Float32 && + param.dst_type.enumv() == DTypeEnum::Float32 && + !fm.should_flip && fm.spatial_ndim == 2 && + fm.dilation[0] == 1 && fm.dilation[1] == 1 && + fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] && + (FH == 2 || FH == 3 || FH == 5 || FH == 7); + if (algo_selection_strategy == + ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { + bool large_group = param.filter_meta.group >= param.nr_threads; + aviliable &= (large_group == m_large_group); + } + return aviliable; + } + MIDOUT_END(); 
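    // Reached only when the MIDOUT region above is trimmed out at compile
    // time (midout code trimming); the algorithm is then reported as
    // unusable.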
+ return false; +} +size_t ConvBiasImpl::AlgoF32DirectStride2::get_workspace( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 2, 1) { + auto bundle = + MultithreadDirectConvCommon::get_bundle_stride( + param, m_large_group); + return bundle.total_size_in_bytes(); + } + MIDOUT_END(); + return 0; +} +SmallVector +ConvBiasImpl::AlgoF32DirectStride2::get_kimpls( + const NCBKernSizeParam& param) const { + auto fm = param.filter_meta; + auto FH = fm.spatial[0]; + size_t N = param.n; + size_t IC = param.filter_meta.icpg; + size_t OC = param.filter_meta.ocpg; + size_t group = fm.group; + using Func = std::function; + Func conv_kern_function = nullptr; + +#define SWITCH_KERN_STR2() \ + switch (FH) { \ + case 2: \ + conv_kern_function = fp32::conv_stride2::do_conv_2x2_stride2; \ + break; \ + case 3: \ + conv_kern_function = fp32::conv_stride2::do_conv_3x3_stride2; \ + break; \ + case 5: \ + conv_kern_function = fp32::conv_stride2::do_conv_5x5_stride2; \ + break; \ + case 7: \ + conv_kern_function = fp32::conv_stride2::do_conv_7x7_stride2; \ + break; \ + } + SWITCH_KERN_STR2(); + + WorkspaceBundle wbundle = + MultithreadDirectConvCommon::get_bundle_stride( + param, m_large_group); + SmallVector ret_kerns; + //! When group >= nr_threads, treat it as large_group, each thread process + //! one group for better performance + if (m_large_group) { + //! Channel wise conv and big groups + auto exec_one_group = [wbundle, conv_kern_function]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + auto fm = kern_param.filter_meta; + size_t IC = fm.icpg; + size_t OC = fm.ocpg; + WorkspaceBundle bundle = wbundle; + for (size_t ic = 0; ic < IC; ic++) { + MultithreadDirectConvCommon:: + copy_padding_kern_stride(bundle, kern_param, ncb_index, + {ncb_index.thread_id, 0, ic}); + } + for (size_t oc = 0; oc < OC; oc++) { + MultithreadDirectConvCommon::do_conv_kern_stride( + bundle, kern_param, ncb_index, conv_kern_function, + {ncb_index.thread_id, 0, oc}); + } + }; + ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); + } else { + WorkspaceBundle bundle = wbundle; + auto copy_padding = [bundle](const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + MultithreadDirectConvCommon::copy_padding_kern_stride( + bundle, kern_param, ncb_index, ncb_index.ndrange_id); + }; + ret_kerns.push_back({copy_padding, {group, N, IC}}); + auto do_conv = [bundle, conv_kern_function]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + MultithreadDirectConvCommon::do_conv_kern_stride( + bundle, kern_param, ncb_index, conv_kern_function, + ncb_index.ndrange_id); + }; + ret_kerns.push_back({do_conv, {group, N, OC}}); + } + return ret_kerns; +} + +SmallVector +ConvBiasImpl::AlgoF32DirectStride2::dispatch_kerns( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 2, 2) { + return get_kimpls(param); + } + MIDOUT_END(); + return {}; +} +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/algos.h b/dnn/src/arm_common/conv_bias/fp32/algos.h new file mode 100644 index 00000000..c97e7fc3 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/algos.h @@ -0,0 +1,223 @@ +/** + * \file dnn/src/arm_common/conv_bias/fp32/algos.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. 
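 *
 * Declares the arm_common fp32 conv_bias algorithm classes (Winograd
 * F23_4x4 / F63 / F63_4x4 / F54 / F45 and the direct, stride-1 and
 * stride-2 kernels) as nested classes of ConvBiasImpl.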
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include "src/arm_common/conv_bias/opr_impl.h" +#include "src/fallback/matrix_mul/opr_impl.h" + +namespace megdnn { +namespace arm_common { + +class ConvBiasImpl::AlgoFP32WinogradF23_4x4 final : public AlgoBase { +public: + AlgoFP32WinogradF23_4x4(fallback::MatrixMulImpl::AlgoBase* matmul_algo, + uint32_t tile_size) + : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {} + bool is_reproducible() const override { return true; } + const char* name() const override { + if (m_name.empty()) { + m_name = ConvBiasImpl::algo_name( + m_matmul_algo->name(), {4, 2, m_tile_size}); + } + return m_name.c_str(); + } + bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; + size_t get_workspace(fallback::ConvBiasImpl*, + const NCBKernSizeParam& param) const override; + virtual SmallVector dispatch_kerns( + fallback::ConvBiasImpl* opr, + const NCBKernSizeParam& param) const override; + +private: + fallback::MatrixMulImpl::AlgoBase* m_matmul_algo; + mutable std::string m_name; + uint32_t m_tile_size; +}; + +class ConvBiasImpl::AlgoFP32WinogradF63 final : public AlgoBase { +public: + AlgoFP32WinogradF63(fallback::MatrixMulImpl::AlgoBase* matmul_algo, + uint32_t tile_size) + : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {} + bool is_reproducible() const override { return true; } + const char* name() const override { + if (m_name.empty()) { + m_name = ConvBiasImpl::algo_name( + m_matmul_algo->name(), {1, 6, m_tile_size}); + } + return m_name.c_str(); + } + bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; + size_t get_workspace(fallback::ConvBiasImpl*, + const NCBKernSizeParam& param) const override; + virtual SmallVector dispatch_kerns( + fallback::ConvBiasImpl* opr, + const NCBKernSizeParam& param) const override; + +private: + fallback::MatrixMulImpl::AlgoBase* m_matmul_algo; + mutable std::string m_name; + + uint32_t m_tile_size; +}; + +class ConvBiasImpl::AlgoFP32WinogradF63_4x4 final : public AlgoBase { +public: + AlgoFP32WinogradF63_4x4(fallback::MatrixMulImpl::AlgoBase* matmul_algo, + uint32_t tile_size) + : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {} + bool is_reproducible() const override { return true; } + const char* name() const override { + if (m_name.empty()) { + m_name = ConvBiasImpl::algo_name( + m_matmul_algo->name(), {4, 6, m_tile_size}); + } + return m_name.c_str(); + } + bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; + size_t get_workspace(fallback::ConvBiasImpl*, + const NCBKernSizeParam& param) const override; + virtual SmallVector dispatch_kerns( + fallback::ConvBiasImpl* opr, + const NCBKernSizeParam& param) const override; + +private: + fallback::MatrixMulImpl::AlgoBase* m_matmul_algo; + mutable std::string m_name; + + uint32_t m_tile_size; +}; + +class ConvBiasImpl::AlgoFP32WinogradF54 final : public AlgoBase { +public: + AlgoFP32WinogradF54(fallback::MatrixMulImpl::AlgoBase* matmul_algo, + uint32_t tile_size) + : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {} + bool is_reproducible() const override { return true; } + const char* name() const override { + if 
(m_name.empty()) { + m_name = ConvBiasImpl::algo_name( + m_matmul_algo->name(), {1, 5, m_tile_size}); + } + return m_name.c_str(); + } + bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; + size_t get_workspace(fallback::ConvBiasImpl*, + const NCBKernSizeParam& param) const override; + virtual SmallVector dispatch_kerns( + fallback::ConvBiasImpl* opr, + const NCBKernSizeParam& param) const override; + +private: + fallback::MatrixMulImpl::AlgoBase* m_matmul_algo; + mutable std::string m_name; + + uint32_t m_tile_size; +}; + +class ConvBiasImpl::AlgoFP32WinogradF45 final : public AlgoBase { +public: + AlgoFP32WinogradF45(fallback::MatrixMulImpl::AlgoBase* matmul_algo, + uint32_t tile_size) + : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {} + bool is_reproducible() const override { return true; } + const char* name() const override { + if (m_name.empty()) { + m_name = ConvBiasImpl::algo_name( + m_matmul_algo->name(), {1, 4, m_tile_size}); + } + return m_name.c_str(); + } + bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; + size_t get_workspace(fallback::ConvBiasImpl*, + const NCBKernSizeParam& param) const override; + virtual SmallVector dispatch_kerns( + fallback::ConvBiasImpl* opr, + const NCBKernSizeParam& param) const override; + +private: + fallback::MatrixMulImpl::AlgoBase* m_matmul_algo; + mutable std::string m_name; + + uint32_t m_tile_size; +}; + + +class ConvBiasImpl::AlgoF32Direct final : public AlgoBase { + SmallVector get_kimpls(const NCBKernSizeParam& param) const; + bool m_large_group; + +public: + AlgoF32Direct(bool is_large_group) : m_large_group{is_large_group} {} + bool is_reproducible() const override { return true; } + const char* name() const override { + return m_large_group ? "F32DIRECT_LARGE_GROUP" + : "F32DIRECT_SMALL_GROUP"; + } + bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; + + size_t get_workspace(fallback::ConvBiasImpl* opr, + const NCBKernSizeParam& param) const override; + virtual SmallVector dispatch_kerns( + fallback::ConvBiasImpl* opr, + const NCBKernSizeParam& param) const override; +}; + +class ConvBiasImpl::AlgoF32DirectStride1 final : public AlgoBase { + SmallVector get_kimpls(const NCBKernSizeParam& param) const; + bool m_large_group; + +public: + AlgoF32DirectStride1(bool is_large_group) : m_large_group{is_large_group} {} + bool is_reproducible() const override { return true; } + const char* name() const override { + return m_large_group ? "F32STRD1_LARGE_GROUP" : "F32STRD1_SMALL_GROUP"; + } + bool usable(fallback::ConvBiasImpl*, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; + + size_t get_workspace(fallback::ConvBiasImpl*, + const NCBKernSizeParam& param) const override; + virtual SmallVector dispatch_kerns( + fallback::ConvBiasImpl* opr, + const NCBKernSizeParam& param) const override; +}; + +class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase { + SmallVector get_kimpls(const NCBKernSizeParam& param) const; + bool m_large_group; + +public: + AlgoF32DirectStride2(bool is_large_group) : m_large_group{is_large_group} {} + bool is_reproducible() const override { return true; } + const char* name() const override { + return m_large_group ? 
"F32STRD2_LARGE_GROUP" : "F32STRD2_SMALL_GROUP"; + } + bool usable(fallback::ConvBiasImpl*, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; + + size_t get_workspace(fallback::ConvBiasImpl*, + const NCBKernSizeParam& param) const override; + virtual SmallVector dispatch_kerns( + fallback::ConvBiasImpl* opr, + const NCBKernSizeParam& param) const override; +}; +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct.cpp b/dnn/src/arm_common/conv_bias/fp32/direct.cpp new file mode 100644 index 00000000..b55d21f8 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct.cpp @@ -0,0 +1,911 @@ +/** + * \file dnn/src/arm_common/conv_bias/fp32/direct.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include +#include "include/megdnn/oprs.h" +#include "midout.h" +#include "src/arm_common/conv_bias/fp32/direct.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/utils.h" +#include "src/arm_common/conv_bias/postprocess_helper.h" +#include "src/common/unroll_macro.h" +MIDOUT_DECL(megdnn_arm_conv_f32) + +using namespace megdnn; +using namespace arm_common; +using namespace fp32; +using namespace conv_bias; + +namespace { + +template +struct do_pixel_proxy { + static void exec(const float* src, const float* filter, float* dst, + const int IH, const int IW, const int OH, const int OW, + const int FW, const int oh, const int ow); +}; + +#define cb_load(i) data = vld1q_lane_f32(dst + i, data, i); +#define LOAD_OUT \ + if (width < 4) { \ + auto load_less_4 = [](float* dst, float32x4_t& data) { \ + if (width == 1u) { \ + UNROLL_CALL_NOWRAPPER(1, cb_load); \ + } else if (width == 2u) { \ + UNROLL_CALL_NOWRAPPER(2, cb_load); \ + } else if (width == 3u) { \ + UNROLL_CALL_NOWRAPPER(3, cb_load); \ + } \ + }; \ + if (height >= 1) \ + load_less_4(dst + 0 * OW, out0); \ + if (height >= 2) \ + load_less_4(dst + 1 * OW, out1); \ + if (height >= 3) \ + load_less_4(dst + 2 * OW, out2); \ + if (height >= 4) \ + load_less_4(dst + 3 * OW, out3); \ + } else { \ + if (height > 0) \ + out0 = vld1q_f32(dst + 0 * OW); \ + if (height > 1) \ + out1 = vld1q_f32(dst + 1 * OW); \ + if (height > 2) \ + out2 = vld1q_f32(dst + 2 * OW); \ + if (height > 3) \ + out3 = vld1q_f32(dst + 3 * OW); \ + } +#define cb_store(i) vst1q_lane_f32(dst + i, data, i); +#define STORE_OUT \ + if (width < 4) { \ + auto store_less_4 = [](float* dst, float32x4_t& data) { \ + if (width == 1u) { \ + UNROLL_CALL_NOWRAPPER(1, cb_store); \ + } else if (width == 2u) { \ + UNROLL_CALL_NOWRAPPER(2, cb_store); \ + } else if (width == 3u) { \ + UNROLL_CALL_NOWRAPPER(3, cb_store); \ + } \ + }; \ + if (height >= 1) \ + store_less_4(dst + 0 * OW, out0); \ + if (height >= 2) \ + store_less_4(dst + 1 * OW, out1); \ + if (height >= 3) \ + store_less_4(dst + 2 * OW, out2); \ + if (height >= 4) \ + store_less_4(dst + 3 * OW, out3); \ + } else { \ + if (height >= 1) \ + vst1q_f32(dst + 0 * OW, out0); \ + if (height >= 2) \ + vst1q_f32(dst + 1 * OW, out1); \ + if (height >= 3) \ + vst1q_f32(dst + 2 * OW, out2); \ + if (height >= 4) \ + vst1q_f32(dst + 3 * OW, out3); \ + } + +template +struct 
do_pixel_proxy<1, height, width> { + static void exec(const float* src, const float* filter, float* dst, + const int IH, const int IW, const int OH, const int OW, + const int FW, const int oh, const int ow) { + (void)IH; + (void)OH; + const int ih = oh, iw = ow; + float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, inp; + src += ih * IW + iw; + dst += oh * OW + ow; + LOAD_OUT; + for (int fw = 0; fw < FW; ++fw) { + const float* src_dd = src + fw; + kr0 = vdupq_n_f32(filter[0 * FW + fw]); + + if (height > 0) + inp = vld1q_f32(src_dd + 0 * IW); + if (height > 0) + out0 = vmlaq_f32(out0, inp, kr0); + + if (height > 1) + inp = vld1q_f32(src_dd + 1 * IW); + if (height > 1) + out1 = vmlaq_f32(out1, inp, kr0); + + if (height > 2) + inp = vld1q_f32(src_dd + 2 * IW); + if (height > 2) + out2 = vmlaq_f32(out2, inp, kr0); + + if (height > 3) + inp = vld1q_f32(src_dd + 3 * IW); + if (height > 3) + out3 = vmlaq_f32(out3, inp, kr0); + } + STORE_OUT; + } +}; + +template +struct do_pixel_proxy<2, height, width> { + static void exec(const float* src, const float* filter, float* dst, + const int IH, const int IW, const int OH, const int OW, + const int FW, const int oh, const int ow) { + (void)IH; + (void)OH; + const int ih = oh, iw = ow; + float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, inp; + src += ih * IW + iw; + dst += oh * OW + ow; + LOAD_OUT; + for (int fw = 0; fw < FW; ++fw) { + const float* src_dd = src + fw; + kr0 = vdupq_n_f32(filter[0 * FW + fw]); + kr1 = vdupq_n_f32(filter[1 * FW + fw]); + + if (height > 0) + inp = vld1q_f32(src_dd + 0 * IW); + if (height > 0) + out0 = vmlaq_f32(out0, inp, kr0); + + if (height > 0) + inp = vld1q_f32(src_dd + 1 * IW); + if (height > 0) + out0 = vmlaq_f32(out0, inp, kr1); + if (height > 1) + out1 = vmlaq_f32(out1, inp, kr0); + + if (height > 1) + inp = vld1q_f32(src_dd + 2 * IW); + if (height > 1) + out1 = vmlaq_f32(out1, inp, kr1); + if (height > 2) + out2 = vmlaq_f32(out2, inp, kr0); + + if (height > 2) + inp = vld1q_f32(src_dd + 3 * IW); + if (height > 2) + out2 = vmlaq_f32(out2, inp, kr1); + if (height > 3) + out3 = vmlaq_f32(out3, inp, kr0); + + if (height > 3) + inp = vld1q_f32(src_dd + 4 * IW); + if (height > 3) + out3 = vmlaq_f32(out3, inp, kr1); + } + STORE_OUT; + } +}; + +template +struct do_pixel_proxy<3, height, width> { + static void exec(const float* src, const float* filter, float* dst, + const int IH, const int IW, const int OH, const int OW, + const int FW, const int oh, const int ow) { + (void)IH; + (void)OH; + const int ih = oh, iw = ow; + float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, inp; + src += ih * IW + iw; + dst += oh * OW + ow; + LOAD_OUT; + for (int fw = 0; fw < FW; ++fw) { + const float* src_dd = src + fw; + kr0 = vdupq_n_f32(filter[0 * FW + fw]); + kr1 = vdupq_n_f32(filter[1 * FW + fw]); + kr2 = vdupq_n_f32(filter[2 * FW + fw]); + + if (height > 0) + inp = vld1q_f32(src_dd + 0 * IW); + if (height > 0) + out0 = vmlaq_f32(out0, inp, kr0); + + if (height > 0) + inp = vld1q_f32(src_dd + 1 * IW); + if (height > 0) + out0 = vmlaq_f32(out0, inp, kr1); + if (height > 1) + out1 = vmlaq_f32(out1, inp, kr0); + + if (height > 0) + inp = vld1q_f32(src_dd + 2 * IW); + if (height > 0) + out0 = vmlaq_f32(out0, inp, kr2); + if (height > 1) + out1 = vmlaq_f32(out1, inp, kr1); + if (height > 2) + out2 = vmlaq_f32(out2, inp, kr0); + + if (height > 1) + inp = vld1q_f32(src_dd + 3 * IW); + if (height > 1) + out1 = vmlaq_f32(out1, inp, kr2); + if (height > 2) + out2 = vmlaq_f32(out2, inp, kr1); + if (height > 3) + out3 = 
vmlaq_f32(out3, inp, kr0); + + if (height > 2) + inp = vld1q_f32(src_dd + 4 * IW); + if (height > 2) + out2 = vmlaq_f32(out2, inp, kr2); + if (height > 3) + out3 = vmlaq_f32(out3, inp, kr1); + + if (height > 3) + inp = vld1q_f32(src_dd + 5 * IW); + if (height > 3) + out3 = vmlaq_f32(out3, inp, kr2); + } + STORE_OUT; + } +}; + +template +struct do_pixel_proxy<4, height, width> { + static void exec(const float* src, const float* filter, float* dst, + const int IH, const int IW, const int OH, const int OW, + const int FW, const int oh, const int ow) { + (void)IH; + (void)OH; + const int ih = oh, iw = ow; + float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, inp; + src += ih * IW + iw; + dst += oh * OW + ow; + LOAD_OUT; + for (int fw = 0; fw < FW; ++fw) { + const float* src_dd = src + fw; + kr0 = vdupq_n_f32(filter[0 * FW + fw]); + kr1 = vdupq_n_f32(filter[1 * FW + fw]); + kr2 = vdupq_n_f32(filter[2 * FW + fw]); + kr3 = vdupq_n_f32(filter[3 * FW + fw]); + + if (height > 0) + inp = vld1q_f32(src_dd + 0 * IW); + if (height > 0) + out0 = vmlaq_f32(out0, inp, kr0); + + if (height > 0) + inp = vld1q_f32(src_dd + 1 * IW); + if (height > 0) + out0 = vmlaq_f32(out0, inp, kr1); + if (height > 1) + out1 = vmlaq_f32(out1, inp, kr0); + + if (height > 0) + inp = vld1q_f32(src_dd + 2 * IW); + if (height > 0) + out0 = vmlaq_f32(out0, inp, kr2); + if (height > 1) + out1 = vmlaq_f32(out1, inp, kr1); + if (height > 2) + out2 = vmlaq_f32(out2, inp, kr0); + + if (height > 0) + inp = vld1q_f32(src_dd + 3 * IW); + if (height > 0) + out0 = vmlaq_f32(out0, inp, kr3); + if (height > 1) + out1 = vmlaq_f32(out1, inp, kr2); + if (height > 2) + out2 = vmlaq_f32(out2, inp, kr1); + if (height > 3) + out3 = vmlaq_f32(out3, inp, kr0); + + if (height > 1) + inp = vld1q_f32(src_dd + 4 * IW); + if (height > 1) + out1 = vmlaq_f32(out1, inp, kr3); + if (height > 2) + out2 = vmlaq_f32(out2, inp, kr2); + if (height > 3) + out3 = vmlaq_f32(out3, inp, kr1); + + if (height > 2) + inp = vld1q_f32(src_dd + 5 * IW); + if (height > 2) + out2 = vmlaq_f32(out2, inp, kr3); + if (height > 3) + out3 = vmlaq_f32(out3, inp, kr2); + + if (height > 3) + inp = vld1q_f32(src_dd + 6 * IW); + if (height > 3) + out3 = vmlaq_f32(out3, inp, kr3); + } + STORE_OUT; + } +}; + +template +struct do_pixel_proxy<5, height, width> { + static void exec(const float* src, const float* filter, float* dst, + const int IH, const int IW, const int OH, const int OW, + const int FW, const int oh, const int ow) { + (void)IH; + (void)OH; + const int ih = oh, iw = ow; + float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, kr4, + inp; + src += ih * IW + iw; + dst += oh * OW + ow; + LOAD_OUT; + for (int fw = 0; fw < FW; ++fw) { + const float* src_dd = src + fw; + kr0 = vdupq_n_f32(filter[0 * FW + fw]); + kr1 = vdupq_n_f32(filter[1 * FW + fw]); + kr2 = vdupq_n_f32(filter[2 * FW + fw]); + kr3 = vdupq_n_f32(filter[3 * FW + fw]); + kr4 = vdupq_n_f32(filter[4 * FW + fw]); + + if (height > 0) + inp = vld1q_f32(src_dd + 0 * IW); + if (height > 0) + out0 = vmlaq_f32(out0, inp, kr0); + + if (height > 0) + inp = vld1q_f32(src_dd + 1 * IW); + if (height > 0) + out0 = vmlaq_f32(out0, inp, kr1); + if (height > 1) + out1 = vmlaq_f32(out1, inp, kr0); + + if (height > 0) + inp = vld1q_f32(src_dd + 2 * IW); + if (height > 0) + out0 = vmlaq_f32(out0, inp, kr2); + if (height > 1) + out1 = vmlaq_f32(out1, inp, kr1); + if (height > 2) + out2 = vmlaq_f32(out2, inp, kr0); + + if (height > 0) + inp = vld1q_f32(src_dd + 3 * IW); + if (height > 0) + out0 = vmlaq_f32(out0, 
inp, kr3); + if (height > 1) + out1 = vmlaq_f32(out1, inp, kr2); + if (height > 2) + out2 = vmlaq_f32(out2, inp, kr1); + if (height > 3) + out3 = vmlaq_f32(out3, inp, kr0); + + if (height > 0) + inp = vld1q_f32(src_dd + 4 * IW); + if (height > 0) + out0 = vmlaq_f32(out0, inp, kr4); + if (height > 1) + out1 = vmlaq_f32(out1, inp, kr3); + if (height > 2) + out2 = vmlaq_f32(out2, inp, kr2); + if (height > 3) + out3 = vmlaq_f32(out3, inp, kr1); + + if (height > 1) + inp = vld1q_f32(src_dd + 5 * IW); + if (height > 1) + out1 = vmlaq_f32(out1, inp, kr4); + if (height > 2) + out2 = vmlaq_f32(out2, inp, kr3); + if (height > 3) + out3 = vmlaq_f32(out3, inp, kr2); + + if (height > 2) + inp = vld1q_f32(src_dd + 6 * IW); + if (height > 2) + out2 = vmlaq_f32(out2, inp, kr4); + if (height > 3) + out3 = vmlaq_f32(out3, inp, kr3); + + if (height > 3) + inp = vld1q_f32(src_dd + 7 * IW); + if (height > 3) + out3 = vmlaq_f32(out3, inp, kr4); + } + STORE_OUT; + } +}; + +template +struct do_pixel_proxy<6, height, width> { + static void exec(const float* src, const float* filter, float* dst, + const int IH, const int IW, const int OH, const int OW, + const int FW, const int oh, const int ow) { + (void)IH; + (void)OH; + const int ih = oh, iw = ow; + float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, kr4, + kr5, inp; + src += ih * IW + iw; + dst += oh * OW + ow; + LOAD_OUT; + for (int fw = 0; fw < FW; ++fw) { + const float* src_dd = src + fw; + kr0 = vdupq_n_f32(filter[0 * FW + fw]); + kr1 = vdupq_n_f32(filter[1 * FW + fw]); + kr2 = vdupq_n_f32(filter[2 * FW + fw]); + kr3 = vdupq_n_f32(filter[3 * FW + fw]); + kr4 = vdupq_n_f32(filter[4 * FW + fw]); + kr5 = vdupq_n_f32(filter[5 * FW + fw]); + + if (height > 0) + inp = vld1q_f32(src_dd + 0 * IW); + if (height > 0) + out0 = vmlaq_f32(out0, inp, kr0); + + if (height > 0) + inp = vld1q_f32(src_dd + 1 * IW); + if (height > 0) + out0 = vmlaq_f32(out0, inp, kr1); + if (height > 1) + out1 = vmlaq_f32(out1, inp, kr0); + + if (height > 0) + inp = vld1q_f32(src_dd + 2 * IW); + if (height > 0) + out0 = vmlaq_f32(out0, inp, kr2); + if (height > 1) + out1 = vmlaq_f32(out1, inp, kr1); + if (height > 2) + out2 = vmlaq_f32(out2, inp, kr0); + + if (height > 0) + inp = vld1q_f32(src_dd + 3 * IW); + if (height > 0) + out0 = vmlaq_f32(out0, inp, kr3); + if (height > 1) + out1 = vmlaq_f32(out1, inp, kr2); + if (height > 2) + out2 = vmlaq_f32(out2, inp, kr1); + if (height > 3) + out3 = vmlaq_f32(out3, inp, kr0); + + if (height > 0) + inp = vld1q_f32(src_dd + 4 * IW); + if (height > 0) + out0 = vmlaq_f32(out0, inp, kr4); + if (height > 1) + out1 = vmlaq_f32(out1, inp, kr3); + if (height > 2) + out2 = vmlaq_f32(out2, inp, kr2); + if (height > 3) + out3 = vmlaq_f32(out3, inp, kr1); + + if (height > 0) + inp = vld1q_f32(src_dd + 5 * IW); + if (height > 0) + out0 = vmlaq_f32(out0, inp, kr5); + if (height > 1) + out1 = vmlaq_f32(out1, inp, kr4); + if (height > 2) + out2 = vmlaq_f32(out2, inp, kr3); + if (height > 3) + out3 = vmlaq_f32(out3, inp, kr2); + + if (height > 1) + inp = vld1q_f32(src_dd + 6 * IW); + if (height > 1) + out1 = vmlaq_f32(out1, inp, kr5); + if (height > 2) + out2 = vmlaq_f32(out2, inp, kr4); + if (height > 3) + out3 = vmlaq_f32(out3, inp, kr3); + + if (height > 2) + inp = vld1q_f32(src_dd + 7 * IW); + if (height > 2) + out2 = vmlaq_f32(out2, inp, kr5); + if (height > 3) + out3 = vmlaq_f32(out3, inp, kr4); + + if (height > 3) + inp = vld1q_f32(src_dd + 8 * IW); + if (height > 3) + out3 = vmlaq_f32(out3, inp, kr5); + } + STORE_OUT; + } +}; + +template 
+struct do_pixel_proxy<7, height, width> { + static void exec(const float* src, const float* filter, float* dst, + const int IH, const int IW, const int OH, const int OW, + const int FW, const int oh, const int ow) { + (void)IH; + (void)OH; + const int ih = oh, iw = ow; + float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, kr4, + kr5, kr6, inp; + src += ih * IW + iw; + dst += oh * OW + ow; + LOAD_OUT; + for (int fw = 0; fw < FW; ++fw) { + const float* src_dd = src + fw; + kr0 = vdupq_n_f32(filter[0 * FW + fw]); + kr1 = vdupq_n_f32(filter[1 * FW + fw]); + kr2 = vdupq_n_f32(filter[2 * FW + fw]); + kr3 = vdupq_n_f32(filter[3 * FW + fw]); + kr4 = vdupq_n_f32(filter[4 * FW + fw]); + kr5 = vdupq_n_f32(filter[5 * FW + fw]); + kr6 = vdupq_n_f32(filter[6 * FW + fw]); + + if (height > 0) + inp = vld1q_f32(src_dd + 0 * IW); + if (height > 0) + out0 = vmlaq_f32(out0, inp, kr0); + + if (height > 0) + inp = vld1q_f32(src_dd + 1 * IW); + if (height > 0) + out0 = vmlaq_f32(out0, inp, kr1); + if (height > 1) + out1 = vmlaq_f32(out1, inp, kr0); + + if (height > 0) + inp = vld1q_f32(src_dd + 2 * IW); + if (height > 0) + out0 = vmlaq_f32(out0, inp, kr2); + if (height > 1) + out1 = vmlaq_f32(out1, inp, kr1); + if (height > 2) + out2 = vmlaq_f32(out2, inp, kr0); + + if (height > 0) + inp = vld1q_f32(src_dd + 3 * IW); + if (height > 0) + out0 = vmlaq_f32(out0, inp, kr3); + if (height > 1) + out1 = vmlaq_f32(out1, inp, kr2); + if (height > 2) + out2 = vmlaq_f32(out2, inp, kr1); + if (height > 3) + out3 = vmlaq_f32(out3, inp, kr0); + + if (height > 0) + inp = vld1q_f32(src_dd + 4 * IW); + if (height > 0) + out0 = vmlaq_f32(out0, inp, kr4); + if (height > 1) + out1 = vmlaq_f32(out1, inp, kr3); + if (height > 2) + out2 = vmlaq_f32(out2, inp, kr2); + if (height > 3) + out3 = vmlaq_f32(out3, inp, kr1); + + if (height > 0) + inp = vld1q_f32(src_dd + 5 * IW); + if (height > 0) + out0 = vmlaq_f32(out0, inp, kr5); + if (height > 1) + out1 = vmlaq_f32(out1, inp, kr4); + if (height > 2) + out2 = vmlaq_f32(out2, inp, kr3); + if (height > 3) + out3 = vmlaq_f32(out3, inp, kr2); + + if (height > 0) + inp = vld1q_f32(src_dd + 6 * IW); + if (height > 0) + out0 = vmlaq_f32(out0, inp, kr6); + if (height > 1) + out1 = vmlaq_f32(out1, inp, kr5); + if (height > 2) + out2 = vmlaq_f32(out2, inp, kr4); + if (height > 3) + out3 = vmlaq_f32(out3, inp, kr3); + + if (height > 1) + inp = vld1q_f32(src_dd + 7 * IW); + if (height > 1) + out1 = vmlaq_f32(out1, inp, kr6); + if (height > 2) + out2 = vmlaq_f32(out2, inp, kr5); + if (height > 3) + out3 = vmlaq_f32(out3, inp, kr4); + + if (height > 2) + inp = vld1q_f32(src_dd + 8 * IW); + if (height > 2) + out2 = vmlaq_f32(out2, inp, kr6); + if (height > 3) + out3 = vmlaq_f32(out3, inp, kr5); + + if (height > 3) + inp = vld1q_f32(src_dd + 9 * IW); + if (height > 3) + out3 = vmlaq_f32(out3, inp, kr6); + } + STORE_OUT; + } +}; +#undef cb_load +#undef cb_load +#undef LOAD_OUT +#undef STORE_OUT + +template +void do_pixel(const float* src, const float* filter, float* dst, const int IH, + const int IW, const int OH, const int OW, const int FW, + const int oh, const int ow) { + do_pixel_proxy::exec(src, filter, dst, IH, IW, OH, OW, + FW, oh, ow); +} + +template +void do_conv_tpl_enable_prefetch(const float* src, const float* filter, + float* dst, const int IH, const int IW, + const int OH, const int OW, const int FW) { + const int hbeg = 0, hend = OH; + const int wbeg = 0, wend = OW; + int i, j; + for (i = hbeg; i + 4 <= hend; i += 4) { + for (j = wbeg; j + 4 <= wend; j += 4) { + // do 
prefetch + const int prefetch_index_input = + (j + 16) < wend + ? i * IW + j + 16 + : (i + 4) * IW + (((j + 16 - wend) >> 2) << 2); + const int prefetch_index_output = + (j + 16) < wend + ? i * OW + j + 16 + : (i + 4) * OW + (((j + 16 - wend) >> 2) << 2); + const float* src_prefetch = src + prefetch_index_input; + const float* dst_prefetch = dst + prefetch_index_output; + for (int iw_id = 0; iw_id < FH + 3; ++iw_id) { + __builtin_prefetch(src_prefetch + iw_id * IW, 0, 3); + } + __builtin_prefetch(dst_prefetch + 0 * OW, 1, 3); + __builtin_prefetch(dst_prefetch + 1 * OW, 1, 3); + __builtin_prefetch(dst_prefetch + 2 * OW, 1, 3); + __builtin_prefetch(dst_prefetch + 3 * OW, 1, 3); + do_pixel(src, filter, dst, IH, IW, OH, OW, FW, i, j); + } +#define DISPATCH(width) \ + do { \ + const int prefetch_index_input = (i + 4) * IW + 12; \ + const int prefetch_index_output = (i + 4) * OW + 12; \ + const float* src_prefetch = src + prefetch_index_input; \ + const float* dst_prefetch = dst + prefetch_index_output; \ + for (int iw_id = 0; iw_id < FH + 3; ++iw_id) { \ + __builtin_prefetch(src_prefetch + iw_id * IW, 0, 3); \ + } \ + __builtin_prefetch(dst_prefetch + 0 * OW, 1, 3); \ + __builtin_prefetch(dst_prefetch + 1 * OW, 1, 3); \ + __builtin_prefetch(dst_prefetch + 2 * OW, 1, 3); \ + __builtin_prefetch(dst_prefetch + 3 * OW, 1, 3); \ + do_pixel(src, filter, dst, IH, IW, OH, OW, FW, i, j); \ + } while (0) + switch (wend - j) { + case 1: + DISPATCH(1); + break; + case 2: + DISPATCH(2); + break; + case 3: + DISPATCH(3); + break; + } +#undef DISPATCH + } + +#define DISPATCH2(height, width) \ + do { \ + const int prefetch_index_input = IH * IW + 12; \ + const float* src_prefetch = src + prefetch_index_input; \ + for (int iw_id = 0; iw_id < FH + 3; ++iw_id) { \ + __builtin_prefetch(src_prefetch + iw_id * IW, 0, 3); \ + } \ + do_pixel(src, filter, dst, IH, IW, OH, OW, FW, i, \ + j); \ + } while (0) + +#define DISPATCH1(height) \ + do { \ + for (j = wbeg; j + 4 <= wend; j += 4) { \ + const int prefetch_index_input = \ + (j + 16) < wend \ + ? i * IW + j + 16 \ + : (i + 4) * IW + (((j + 16 - wend) >> 2) << 2); \ + const int prefetch_index_output = \ + (j + 16) < wend \ + ? 
i * OW + j + 16 \ + : (i + 4) * OW + (((j + 16 - wend) >> 2) << 2); \ + const float* src_prefetch = src + prefetch_index_input; \ + const float* dst_prefetch = dst + prefetch_index_output; \ + for (int iw_id = 0; iw_id < FH + 3; ++iw_id) { \ + __builtin_prefetch(src_prefetch + iw_id * IW, 0, 3); \ + } \ + __builtin_prefetch(dst_prefetch + 0 * OW, 1, 3); \ + __builtin_prefetch(dst_prefetch + 1 * OW, 1, 3); \ + __builtin_prefetch(dst_prefetch + 2 * OW, 1, 3); \ + __builtin_prefetch(dst_prefetch + 3 * OW, 1, 3); \ + do_pixel(src, filter, dst, IH, IW, OH, OW, FW, i, \ + j); \ + } \ + switch (wend - j) { \ + case 1: \ + DISPATCH2(height, 1); \ + break; \ + case 2: \ + DISPATCH2(height, 2); \ + break; \ + case 3: \ + DISPATCH2(height, 3); \ + break; \ + } \ + } while (0) + switch (hend - i) { + case 1: + DISPATCH1(1); + break; + case 2: + DISPATCH1(2); + break; + case 3: + DISPATCH1(3); + break; + } +#undef DISPATCH1 +#undef DISPATCH2 +} +template +void do_conv_tpl_disable_prefetch(const float* src, const float* filter, + float* dst, const int IH, const int IW, + const int OH, const int OW, const int FW) { + const int hbeg = 0, hend = OH; + const int wbeg = 0, wend = OW; + int i, j; + for (i = hbeg; i + 4 <= hend; i += 4) { + for (j = wbeg; j + 4 <= wend; j += 4) { + do_pixel(src, filter, dst, IH, IW, OH, OW, FW, i, j); + } +#define DISPATCH(width) \ + do { \ + do_pixel(src, filter, dst, IH, IW, OH, OW, FW, i, j); \ + } while (0) + switch (wend - j) { + case 1: + DISPATCH(1); + break; + case 2: + DISPATCH(2); + break; + case 3: + DISPATCH(3); + break; + } +#undef DISPATCH + } +#define DISPATCH2(height, width) \ + do { \ + do_pixel(src, filter, dst, IH, IW, OH, OW, FW, i, \ + j); \ + } while (0) +#define DISPATCH1(height) \ + do { \ + for (j = wbeg; j + 4 <= wend; j += 4) { \ + do_pixel(src, filter, dst, IH, IW, OH, OW, FW, i, \ + j); \ + } \ + switch (wend - j) { \ + case 1: \ + DISPATCH2(height, 1); \ + break; \ + case 2: \ + DISPATCH2(height, 2); \ + break; \ + case 3: \ + DISPATCH2(height, 3); \ + break; \ + } \ + } while (0) + switch (hend - i) { + case 1: + DISPATCH1(1); + break; + case 2: + DISPATCH1(2); + break; + case 3: + DISPATCH1(3); + break; + } +#undef DISPATCH1 +#undef DISPATCH2 +} +} // anonymous namespace + +void conv_bias::kern_direct(const float* src, const float* filter, float* dst, + const int IH, const int IW, const int OH, + const int OW, const int FH, const int FW) { + megdnn_assert_internal(FH <= 7); + if (IH > 100 && IW > 100) { +#define GAO(FH) \ + do { \ + return do_conv_tpl_enable_prefetch(src, filter, dst, IH, IW, OH, \ + OW, FW); \ + } while (0) + switch (FH) { + case 1: + MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(0)) { GAO(1); } + MIDOUT_END(); + break; + case 2: + MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(1)) { GAO(2); } + MIDOUT_END(); + break; + case 3: + MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(2)) { GAO(3); } + MIDOUT_END(); + break; + case 4: + MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(3)) { GAO(4); } + MIDOUT_END(); + break; + case 5: + MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(4)) { GAO(5); } + MIDOUT_END(); + break; + case 6: + MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(5)) { GAO(6); } + MIDOUT_END(); + break; + case 7: + MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(6)) { GAO(7); } + MIDOUT_END(); + break; + } +#undef GAO + } else { +#define GAO(FH) \ + do { \ + return do_conv_tpl_disable_prefetch(src, filter, dst, IH, IW, OH, \ + OW, FW); \ + } while (0) + switch (FH) { + case 1: + MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(0)) { GAO(1); } + 
MIDOUT_END(); + break; + case 2: + MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(1)) { GAO(2); } + MIDOUT_END(); + break; + case 3: + MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(2)) { GAO(3); } + MIDOUT_END(); + break; + case 4: + MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(3)) { GAO(4); } + MIDOUT_END(); + break; + case 5: + MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(4)) { GAO(5); } + MIDOUT_END(); + break; + case 6: + MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(5)) { GAO(6); } + MIDOUT_END(); + break; + case 7: + MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(6)) { GAO(7); } + MIDOUT_END(); + break; + } +#undef GAO + } + megdnn_assert_internal(0); +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct.h b/dnn/src/arm_common/conv_bias/fp32/direct.h new file mode 100644 index 00000000..861cd0fe --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct.h @@ -0,0 +1,29 @@ +/** + * \file dnn/src/arm_common/conv_bias/fp32/direct.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once + +#include + +namespace megdnn { +namespace arm_common { +namespace fp32{ +namespace conv_bias { + +void kern_direct(const float *src, const float *filter, float *dst, + const int IH, const int IW, const int OH, const int OW, + const int FH, const int FW); + +} // namespace convolution +} // namespace fp32 +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/do_conv_stride1.cpp b/dnn/src/arm_common/conv_bias/fp32/do_conv_stride1.cpp new file mode 100644 index 00000000..75e73f91 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/do_conv_stride1.cpp @@ -0,0 +1,735 @@ +/** + * \file dnn/src/arm_common/conv_bias/fp32/do_conv_stride1.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include + +#include "src/arm_common/conv_bias/fp32/do_conv_stride1.h" +#include "src/arm_common/simd_macro/neon_helper.h" +#include "src/arm_common/conv_bias/postprocess_helper.h" + +#include "midout.h" + +MIDOUT_DECL(megdnn_arm_common_conv_bias_f32_convs1) + +using namespace megdnn; +using namespace arm_common; +using namespace fp32; +using namespace conv_stride1; + +using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam; +using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam; + + +void conv_stride1::do_conv_2x2_stride1(const float* src, const float* filter, float* dst, + size_t IH, size_t IW, size_t OH, size_t OW, + size_t IC) { + const size_t tail_step = IW - OW; + //! 
unroll of 2 + size_t ic = 0; + for (; ic + 1 < IC; ic += 2) { + const float* src_ptr = src + IW * IH * ic; + const float* src_ptr1 = src_ptr + IW * IH; + float* outptr = dst; + + const float* r00 = src_ptr; + const float* r01 = src_ptr + IW; + const float* r10 = src_ptr1; + const float* r11 = src_ptr1 + IW; + + const float* k0 = filter + ic * 4; + const float* k1 = k0 + 4; + + MEGDNN_SIMD_TYPE _k0 = MEGDNN_SIMD_LOADU(k0); + MEGDNN_SIMD_TYPE _k1 = MEGDNN_SIMD_LOADU(k1); + rep(h, OH) { + int width = OW >> 2; + + rep(i, width) { + MEGDNN_SIMD_TYPE _r000 = MEGDNN_SIMD_LOADU(r00); + MEGDNN_SIMD_TYPE _r010 = MEGDNN_SIMD_LOADU(r01); + MEGDNN_SIMD_TYPE _r001 = MEGDNN_SIMD_LOADU(r00 + 1); + MEGDNN_SIMD_TYPE _r011 = MEGDNN_SIMD_LOADU(r01 + 1); + + MEGDNN_SIMD_TYPE _r100 = MEGDNN_SIMD_LOADU(r10); + MEGDNN_SIMD_TYPE _r110 = MEGDNN_SIMD_LOADU(r11); + MEGDNN_SIMD_TYPE _r101 = MEGDNN_SIMD_LOADU(r10 + 1); + MEGDNN_SIMD_TYPE _r111 = MEGDNN_SIMD_LOADU(r11 + 1); + + MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr); + + _sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r000, + MEGDNN_SIMD_GET_LOW(_k0), 0); + _sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r001, + MEGDNN_SIMD_GET_LOW(_k0), 1); + _sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r010, + MEGDNN_SIMD_GET_HIGH(_k0), 0); + _sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r011, + MEGDNN_SIMD_GET_HIGH(_k0), 1); + + _sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r100, + MEGDNN_SIMD_GET_LOW(_k1), 0); + _sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r101, + MEGDNN_SIMD_GET_LOW(_k1), 1); + _sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r110, + MEGDNN_SIMD_GET_HIGH(_k1), 0); + _sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r111, + MEGDNN_SIMD_GET_HIGH(_k1), 1); + + MEGDNN_SIMD_STOREU(outptr, _sum); + + r00 += 4; + r01 += 4; + r10 += 4; + r11 += 4; + outptr += 4; + } + + r00 += tail_step; + r01 += tail_step; + r10 += tail_step; + r11 += tail_step; + } + } + for (; ic < IC; ic++) { + const float* src_ptr = src + IW * IH * ic; + float* outptr = dst; + + const float* r0 = src_ptr; + const float* r1 = src_ptr + IW; + + const float* k0 = filter + ic * 4; + + MEGDNN_SIMD_TYPE _k0 = MEGDNN_SIMD_SET1(k0[0]); + MEGDNN_SIMD_TYPE _k1 = MEGDNN_SIMD_SET1(k0[1]); + MEGDNN_SIMD_TYPE _k2 = MEGDNN_SIMD_SET1(k0[2]); + MEGDNN_SIMD_TYPE _k3 = MEGDNN_SIMD_SET1(k0[3]); + rep(h, OH) { + int width = OW >> 2; + + rep(i, width) { + MEGDNN_SIMD_TYPE _r00 = MEGDNN_SIMD_LOADU(r0); + MEGDNN_SIMD_TYPE _r10 = MEGDNN_SIMD_LOADU(r1); + MEGDNN_SIMD_TYPE _r01 = MEGDNN_SIMD_LOADU(r0 + 1); + MEGDNN_SIMD_TYPE _r11 = MEGDNN_SIMD_LOADU(r1 + 1); + + MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr); + MEGDNN_SIMD_TYPE _sum2; + + _sum = MEGDNN_SIMD_FMADD(_r00, _k0, _sum); + _sum2 = MEGDNN_SIMD_MUL(_r01, _k1); + _sum = MEGDNN_SIMD_FMADD(_r10, _k2, _sum); + _sum2 = MEGDNN_SIMD_FMADD(_r11, _k3, _sum2); + + _sum = MEGDNN_SIMD_ADD(_sum, _sum2); + + MEGDNN_SIMD_STOREU(outptr, _sum); + + r0 += 4; + r1 += 4; + outptr += 4; + } + + r0 += tail_step; + r1 += tail_step; + } + } +} + +void conv_stride1::do_conv_3x3_stride1(const float* src, const float* filter, float* dst, + size_t IH, size_t IW, size_t OH, size_t OW, + size_t IC) { + const size_t tail_step = IW - OW; + + rep(ic, IC) { + const float* src_ptr = src + IW * IH * ic; + float* outptr = dst; + float* outptr2 = outptr + OW; + + const float* r0 = src_ptr; + const float* r1 = src_ptr + IW; + const float* r2 = src_ptr + IW * 2; + const float* r3 = src_ptr + IW * 3; + + const float* k0 = filter; + const float* k1 = filter + 3; + const float* k2 = filter + 5; + + MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(k0); + MEGDNN_SIMD_TYPE _k3456 = 
MEGDNN_SIMD_LOADU(k1); + MEGDNN_SIMD_TYPE _k5678 = MEGDNN_SIMD_LOADU(k2); + MEGDNN_SIMD_TYPE _k6789 = MEGDNN_SIMD_EXT(_k5678, _k5678, 1); + + size_t h = 0; + for (; h + 1 < OH; h += 2) { + int width = OW >> 2; + + rep(i, width) { + MEGDNN_SIMD_TYPE _sum1 = MEGDNN_SIMD_LOADU(outptr); + MEGDNN_SIMD_TYPE _sum2 = MEGDNN_SIMD_SET1(0.f); + MEGDNN_SIMD_TYPE _sum3 = MEGDNN_SIMD_LOADU(outptr2); + MEGDNN_SIMD_TYPE _sum4 = MEGDNN_SIMD_SET1(0.f); + + MEGDNN_SIMD_TYPE _r00 = MEGDNN_SIMD_LOADU(r0); + MEGDNN_SIMD_TYPE _r00n = MEGDNN_SIMD_LOADU(r0 + 4); + MEGDNN_SIMD_TYPE _r01 = MEGDNN_SIMD_EXT(_r00, _r00n, 1); + MEGDNN_SIMD_TYPE _r02 = MEGDNN_SIMD_EXT(_r00, _r00n, 2); + + MEGDNN_SIMD_TYPE _r10 = MEGDNN_SIMD_LOADU(r1); + MEGDNN_SIMD_TYPE _r10n = MEGDNN_SIMD_LOADU(r1 + 4); + MEGDNN_SIMD_TYPE _r11 = MEGDNN_SIMD_EXT(_r10, _r10n, 1); + MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r10n, 2); + + MEGDNN_SIMD_TYPE _r20 = MEGDNN_SIMD_LOADU(r2); + MEGDNN_SIMD_TYPE _r20n = MEGDNN_SIMD_LOADU(r2 + 4); + MEGDNN_SIMD_TYPE _r21 = MEGDNN_SIMD_EXT(_r20, _r20n, 1); + MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r20n, 2); + + MEGDNN_SIMD_TYPE _r30 = MEGDNN_SIMD_LOADU(r3); + MEGDNN_SIMD_TYPE _r30n = MEGDNN_SIMD_LOADU(r3 + 4); + MEGDNN_SIMD_TYPE _r31 = MEGDNN_SIMD_EXT(_r30, _r30n, 1); + MEGDNN_SIMD_TYPE _r32 = MEGDNN_SIMD_EXT(_r30, _r30n, 2); + + _sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r00, _k0123, 0); + _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r01, _k0123, 1); + _sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r02, _k0123, 2); + _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r10, _k3456, 0); + _sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r11, _k3456, 1); + _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r12, _k3456, 2); + _sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r20, _k6789, 0); + _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r21, _k6789, 1); + _sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r22, _k6789, 2); + + _sum3 = MEGDNN_SIMD_FMA_LANE(_sum3, _r10, _k0123, 0); + _sum4 = MEGDNN_SIMD_FMA_LANE(_sum4, _r11, _k0123, 1); + _sum3 = MEGDNN_SIMD_FMA_LANE(_sum3, _r12, _k0123, 2); + _sum4 = MEGDNN_SIMD_FMA_LANE(_sum4, _r20, _k3456, 0); + _sum3 = MEGDNN_SIMD_FMA_LANE(_sum3, _r21, _k3456, 1); + _sum4 = MEGDNN_SIMD_FMA_LANE(_sum4, _r22, _k3456, 2); + _sum3 = MEGDNN_SIMD_FMA_LANE(_sum3, _r30, _k6789, 0); + _sum4 = MEGDNN_SIMD_FMA_LANE(_sum4, _r31, _k6789, 1); + _sum3 = MEGDNN_SIMD_FMA_LANE(_sum3, _r32, _k6789, 2); + + _sum1 = MEGDNN_SIMD_ADD(_sum1, _sum2); + _sum3 = MEGDNN_SIMD_ADD(_sum3, _sum4); + + MEGDNN_SIMD_STOREU(outptr, _sum1); + MEGDNN_SIMD_STOREU(outptr2, _sum3); + + r0 += 4; + r1 += 4; + r2 += 4; + r3 += 4; + outptr += 4; + outptr2 += 4; + } + + r0 += tail_step + IW; + r1 += tail_step + IW; + r2 += tail_step + IW; + r3 += tail_step + IW; + + outptr += OW; + outptr2 += OW; + } + + for (; h < OH; h++) { + int width = OW >> 2; + + rep(i, width) { + MEGDNN_SIMD_TYPE _sum1 = MEGDNN_SIMD_LOADU(outptr); + MEGDNN_SIMD_TYPE _sum2 = MEGDNN_SIMD_SET1(0.f); + + MEGDNN_SIMD_TYPE _r00 = MEGDNN_SIMD_LOADU(r0); + MEGDNN_SIMD_TYPE _r00n = MEGDNN_SIMD_LOADU(r0 + 4); + MEGDNN_SIMD_TYPE _r01 = MEGDNN_SIMD_EXT(_r00, _r00n, 1); + MEGDNN_SIMD_TYPE _r02 = MEGDNN_SIMD_EXT(_r00, _r00n, 2); + + MEGDNN_SIMD_TYPE _r10 = MEGDNN_SIMD_LOADU(r1); + MEGDNN_SIMD_TYPE _r10n = MEGDNN_SIMD_LOADU(r1 + 4); + MEGDNN_SIMD_TYPE _r11 = MEGDNN_SIMD_EXT(_r10, _r10n, 1); + MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r10n, 2); + + MEGDNN_SIMD_TYPE _r20 = MEGDNN_SIMD_LOADU(r2); + MEGDNN_SIMD_TYPE _r20n = MEGDNN_SIMD_LOADU(r2 + 4); + MEGDNN_SIMD_TYPE _r21 = MEGDNN_SIMD_EXT(_r20, _r20n, 1); + MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, 
_r20n, 2); + + _sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r00, _k0123, 0); + _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r01, _k0123, 1); + _sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r02, _k0123, 2); + _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r10, _k3456, 0); + _sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r11, _k3456, 1); + _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r12, _k3456, 2); + _sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r20, _k6789, 0); + _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r21, _k6789, 1); + _sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r22, _k6789, 2); + + _sum1 = MEGDNN_SIMD_ADD(_sum1, _sum2); + + MEGDNN_SIMD_STOREU(outptr, _sum1); + + r0 += 4; + r1 += 4; + r2 += 4; + outptr += 4; + } + r0 += tail_step; + r1 += tail_step; + r2 += tail_step; + } + + filter += 9; + } +} + +void conv_stride1::do_conv_5x5_stride1(const float* src, const float* filter, float* dst, + size_t IH, size_t IW, size_t OH, size_t OW, + size_t IC) { + const size_t tail_step = IW - OW; + + rep(ic, IC) { + const float* src_ptr = src + IW * IH * ic; + float* outptr = dst; + float* outptr2 = outptr + OW; + + const float* r0 = src_ptr; + const float* r1 = src_ptr + IW; + const float* r2 = src_ptr + IW * 2; + const float* r3 = src_ptr + IW * 3; + const float* r4 = src_ptr + IW * 4; + const float* r5 = src_ptr + IW * 5; + + MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(filter); + MEGDNN_SIMD_TYPE _k4567 = MEGDNN_SIMD_LOADU(filter + 4); + MEGDNN_SIMD_TYPE _k891011 = MEGDNN_SIMD_LOADU(filter + 8); + MEGDNN_SIMD_TYPE _k12131415 = MEGDNN_SIMD_LOADU(filter + 12); + MEGDNN_SIMD_TYPE _k16171819 = MEGDNN_SIMD_LOADU(filter + 16); + MEGDNN_SIMD_TYPE _k20212223 = MEGDNN_SIMD_LOADU(filter + 20); + MEGDNN_SIMD_TYPE _k24242424 = MEGDNN_SIMD_SET1(filter[24]); + + size_t h = 0; + for (; h + 1 < OH; h += 2) { + int width = OW >> 2; + + rep(i, width) { + MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr); + MEGDNN_SIMD_TYPE _sum2 = MEGDNN_SIMD_LOADU(outptr2); + + MEGDNN_SIMD_TYPE _r00 = MEGDNN_SIMD_LOADU(r0); + MEGDNN_SIMD_TYPE _r04 = MEGDNN_SIMD_LOADU(r0 + 4); + MEGDNN_SIMD_TYPE _r01 = MEGDNN_SIMD_EXT(_r00, _r04, 1); + MEGDNN_SIMD_TYPE _r02 = MEGDNN_SIMD_EXT(_r00, _r04, 2); + MEGDNN_SIMD_TYPE _r03 = MEGDNN_SIMD_EXT(_r00, _r04, 3); + + MEGDNN_SIMD_TYPE _r10 = MEGDNN_SIMD_LOADU(r1); + MEGDNN_SIMD_TYPE _r14 = MEGDNN_SIMD_LOADU(r1 + 4); + MEGDNN_SIMD_TYPE _r11 = MEGDNN_SIMD_EXT(_r10, _r14, 1); + MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r14, 2); + MEGDNN_SIMD_TYPE _r13 = MEGDNN_SIMD_EXT(_r10, _r14, 3); + + MEGDNN_SIMD_TYPE _r20 = MEGDNN_SIMD_LOADU(r2); + MEGDNN_SIMD_TYPE _r24 = MEGDNN_SIMD_LOADU(r2 + 4); + MEGDNN_SIMD_TYPE _r21 = MEGDNN_SIMD_EXT(_r20, _r24, 1); + MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r24, 2); + MEGDNN_SIMD_TYPE _r23 = MEGDNN_SIMD_EXT(_r20, _r24, 3); + + MEGDNN_SIMD_TYPE _r30 = MEGDNN_SIMD_LOADU(r3); + MEGDNN_SIMD_TYPE _r34 = MEGDNN_SIMD_LOADU(r3 + 4); + MEGDNN_SIMD_TYPE _r31 = MEGDNN_SIMD_EXT(_r30, _r34, 1); + MEGDNN_SIMD_TYPE _r32 = MEGDNN_SIMD_EXT(_r30, _r34, 2); + MEGDNN_SIMD_TYPE _r33 = MEGDNN_SIMD_EXT(_r30, _r34, 3); + + MEGDNN_SIMD_TYPE _r40 = MEGDNN_SIMD_LOADU(r4); + MEGDNN_SIMD_TYPE _r44 = MEGDNN_SIMD_LOADU(r4 + 4); + MEGDNN_SIMD_TYPE _r41 = MEGDNN_SIMD_EXT(_r40, _r44, 1); + MEGDNN_SIMD_TYPE _r42 = MEGDNN_SIMD_EXT(_r40, _r44, 2); + MEGDNN_SIMD_TYPE _r43 = MEGDNN_SIMD_EXT(_r40, _r44, 3); + + MEGDNN_SIMD_TYPE _r50 = MEGDNN_SIMD_LOADU(r5); + MEGDNN_SIMD_TYPE _r54 = MEGDNN_SIMD_LOADU(r5 + 4); + MEGDNN_SIMD_TYPE _r51 = MEGDNN_SIMD_EXT(_r50, _r54, 1); + MEGDNN_SIMD_TYPE _r52 = MEGDNN_SIMD_EXT(_r50, _r54, 2); + MEGDNN_SIMD_TYPE _r53 = MEGDNN_SIMD_EXT(_r50, 
_r54, 3); + + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r00, _k0123, 0); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r01, _k0123, 1); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r02, _k0123, 2); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r03, _k0123, 3); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r04, _k4567, 0); + + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r10, _k4567, 1); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r11, _k4567, 2); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r12, _k4567, 3); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r13, _k891011, 0); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r14, _k891011, 1); + + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r20, _k891011, 2); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r21, _k891011, 3); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r22, _k12131415, 0); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r23, _k12131415, 1); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r24, _k12131415, 2); + + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r30, _k12131415, 3); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r31, _k16171819, 0); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r32, _k16171819, 1); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r33, _k16171819, 2); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r34, _k16171819, 3); + + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r40, _k20212223, 0); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r41, _k20212223, 1); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r42, _k20212223, 2); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r43, _k20212223, 3); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r44, _k24242424, 0); + + _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r10, _k0123, 0); + _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r11, _k0123, 1); + _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r12, _k0123, 2); + _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r13, _k0123, 3); + _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r14, _k4567, 0); + + _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r20, _k4567, 1); + _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r21, _k4567, 2); + _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r22, _k4567, 3); + _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r23, _k891011, 0); + _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r24, _k891011, 1); + + _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r30, _k891011, 2); + _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r31, _k891011, 3); + _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r32, _k12131415, 0); + _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r33, _k12131415, 1); + _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r34, _k12131415, 2); + + _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r40, _k12131415, 3); + _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r41, _k16171819, 0); + _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r42, _k16171819, 1); + _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r43, _k16171819, 2); + _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r44, _k16171819, 3); + + _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r50, _k20212223, 0); + _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r51, _k20212223, 1); + _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r52, _k20212223, 2); + _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r53, _k20212223, 3); + _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r54, _k24242424, 0); + + MEGDNN_SIMD_STOREU(outptr, _sum); + MEGDNN_SIMD_STOREU(outptr2, _sum2); + + r0 += 4; + r1 += 4; + r2 += 4; + r3 += 4; + r4 += 4; + r5 += 4; + outptr += 4; + outptr2 += 4; + } + + r0 += tail_step + IW; + r1 += tail_step + IW; + r2 += tail_step + IW; + r3 += tail_step + IW; + r4 += tail_step + IW; + r5 += tail_step + IW; + + outptr += OW; + outptr2 += OW; + } + + for (; h < OH; h++) { + int width = OW >> 2; + + rep(i, width) { + MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr); + + MEGDNN_SIMD_TYPE _r00 = MEGDNN_SIMD_LOADU(r0); + MEGDNN_SIMD_TYPE _r04 = MEGDNN_SIMD_LOADU(r0 + 4); + MEGDNN_SIMD_TYPE _r01 = MEGDNN_SIMD_EXT(_r00, _r04, 1); + 
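+                // _r00.._r04 are this input row shifted by 0..4 floats (two unaligned loads plus vext), one vector per horizontal tap of the 5-wide filter row.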
MEGDNN_SIMD_TYPE _r02 = MEGDNN_SIMD_EXT(_r00, _r04, 2); + MEGDNN_SIMD_TYPE _r03 = MEGDNN_SIMD_EXT(_r00, _r04, 3); + + MEGDNN_SIMD_TYPE _r10 = MEGDNN_SIMD_LOADU(r1); + MEGDNN_SIMD_TYPE _r14 = MEGDNN_SIMD_LOADU(r1 + 4); + MEGDNN_SIMD_TYPE _r11 = MEGDNN_SIMD_EXT(_r10, _r14, 1); + MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r14, 2); + MEGDNN_SIMD_TYPE _r13 = MEGDNN_SIMD_EXT(_r10, _r14, 3); + + MEGDNN_SIMD_TYPE _r20 = MEGDNN_SIMD_LOADU(r2); + MEGDNN_SIMD_TYPE _r24 = MEGDNN_SIMD_LOADU(r2 + 4); + MEGDNN_SIMD_TYPE _r21 = MEGDNN_SIMD_EXT(_r20, _r24, 1); + MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r24, 2); + MEGDNN_SIMD_TYPE _r23 = MEGDNN_SIMD_EXT(_r20, _r24, 3); + + MEGDNN_SIMD_TYPE _r30 = MEGDNN_SIMD_LOADU(r3); + MEGDNN_SIMD_TYPE _r34 = MEGDNN_SIMD_LOADU(r3 + 4); + MEGDNN_SIMD_TYPE _r31 = MEGDNN_SIMD_EXT(_r30, _r34, 1); + MEGDNN_SIMD_TYPE _r32 = MEGDNN_SIMD_EXT(_r30, _r34, 2); + MEGDNN_SIMD_TYPE _r33 = MEGDNN_SIMD_EXT(_r30, _r34, 3); + + MEGDNN_SIMD_TYPE _r40 = MEGDNN_SIMD_LOADU(r4); + MEGDNN_SIMD_TYPE _r44 = MEGDNN_SIMD_LOADU(r4 + 4); + MEGDNN_SIMD_TYPE _r41 = MEGDNN_SIMD_EXT(_r40, _r44, 1); + MEGDNN_SIMD_TYPE _r42 = MEGDNN_SIMD_EXT(_r40, _r44, 2); + MEGDNN_SIMD_TYPE _r43 = MEGDNN_SIMD_EXT(_r40, _r44, 3); + + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r00, _k0123, 0); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r01, _k0123, 1); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r02, _k0123, 2); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r03, _k0123, 3); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r04, _k4567, 0); + + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r10, _k4567, 1); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r11, _k4567, 2); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r12, _k4567, 3); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r13, _k891011, 0); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r14, _k891011, 1); + + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r20, _k891011, 2); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r21, _k891011, 3); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r22, _k12131415, 0); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r23, _k12131415, 1); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r24, _k12131415, 2); + + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r30, _k12131415, 3); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r31, _k16171819, 0); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r32, _k16171819, 1); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r33, _k16171819, 2); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r34, _k16171819, 3); + + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r40, _k20212223, 0); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r41, _k20212223, 1); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r42, _k20212223, 2); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r43, _k20212223, 3); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r44, _k24242424, 0); + + MEGDNN_SIMD_STOREU(outptr, _sum); + + r0 += 4; + r1 += 4; + r2 += 4; + r3 += 4; + r4 += 4; + outptr += 4; + } + + r0 += tail_step; + r1 += tail_step; + r2 += tail_step; + r3 += tail_step; + r4 += tail_step; + } + + filter += 25; + } +} + +void conv_stride1::do_conv_7x7_stride1(const float* src, const float* filter, float* dst, + size_t IH, size_t IW, size_t OH, size_t OW, + size_t IC) { + const size_t tail_step = IW - OW; + + rep(ic, IC) { + const float* src_ptr = src + IW * IH * ic; + float* outptr = dst; + + const float* r0 = src_ptr; + const float* r1 = src_ptr + IW; + const float* r2 = src_ptr + IW * 2; + const float* r3 = src_ptr + IW * 3; + const float* r4 = src_ptr + IW * 4; + const float* r5 = src_ptr + IW * 5; + const float* r6 = src_ptr + IW * 6; + + const float* k0 = filter; + const float* k1 = filter + 7; + const float* k2 = filter + 14; + const float* k3 = filter + 21; + const 
float* k4 = filter + 28; + const float* k5 = filter + 35; + const float* k6 = filter + 42; + + for (size_t i = 0; i < OH; i++) { + int width = OW >> 2; + + rep(i, width) { + MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr); + + MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(k0); + MEGDNN_SIMD_TYPE _k4567 = MEGDNN_SIMD_LOADU(k0 + 4); + + MEGDNN_SIMD_TYPE _r00 = MEGDNN_SIMD_LOADU(r0); // 0 1 2 3 + MEGDNN_SIMD_TYPE _r04 = MEGDNN_SIMD_LOADU(r0 + 4); // 4 5 6 7 + MEGDNN_SIMD_TYPE _r00n = + MEGDNN_SIMD_LOADU(r0 + 8); // 8 9 10 11 + MEGDNN_SIMD_TYPE _r01 = + MEGDNN_SIMD_EXT(_r00, _r04, 1); // 1 2 3 4 + MEGDNN_SIMD_TYPE _r02 = + MEGDNN_SIMD_EXT(_r00, _r04, 2); // 2 3 4 5 + MEGDNN_SIMD_TYPE _r03 = + MEGDNN_SIMD_EXT(_r00, _r04, 3); // 3 4 5 6 + MEGDNN_SIMD_TYPE _r05 = + MEGDNN_SIMD_EXT(_r04, _r00n, 1); // 5 6 7 8 + MEGDNN_SIMD_TYPE _r06 = + MEGDNN_SIMD_EXT(_r04, _r00n, 2); // 6 7 8 9 + + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r00, _k0123, 0); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r01, _k0123, 1); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r02, _k0123, 2); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r03, _k0123, 3); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r04, _k4567, 0); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r05, _k4567, 1); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r06, _k4567, 2); + + MEGDNN_SIMD_TYPE _k78910 = MEGDNN_SIMD_LOADU(k1); + MEGDNN_SIMD_TYPE _k11121314 = MEGDNN_SIMD_LOADU(k1 + 4); + + MEGDNN_SIMD_TYPE _r10 = MEGDNN_SIMD_LOADU(r1); + MEGDNN_SIMD_TYPE _r14 = MEGDNN_SIMD_LOADU(r1 + 4); + MEGDNN_SIMD_TYPE _r10n = MEGDNN_SIMD_LOADU(r1 + 8); + MEGDNN_SIMD_TYPE _r11 = MEGDNN_SIMD_EXT(_r10, _r14, 1); + MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r14, 2); + MEGDNN_SIMD_TYPE _r13 = MEGDNN_SIMD_EXT(_r10, _r14, 3); + MEGDNN_SIMD_TYPE _r15 = MEGDNN_SIMD_EXT(_r14, _r10n, 1); + MEGDNN_SIMD_TYPE _r16 = MEGDNN_SIMD_EXT(_r14, _r10n, 2); + + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r10, _k78910, 0); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r11, _k78910, 1); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r12, _k78910, 2); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r13, _k78910, 3); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r14, _k11121314, 0); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r15, _k11121314, 1); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r16, _k11121314, 2); + + MEGDNN_SIMD_TYPE _k14151617 = MEGDNN_SIMD_LOADU(k2); + MEGDNN_SIMD_TYPE _k18192021 = MEGDNN_SIMD_LOADU(k2 + 4); + + MEGDNN_SIMD_TYPE _r20 = MEGDNN_SIMD_LOADU(r2); + MEGDNN_SIMD_TYPE _r24 = MEGDNN_SIMD_LOADU(r2 + 4); + MEGDNN_SIMD_TYPE _r20n = MEGDNN_SIMD_LOADU(r2 + 8); + MEGDNN_SIMD_TYPE _r21 = MEGDNN_SIMD_EXT(_r20, _r24, 1); + MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r24, 2); + MEGDNN_SIMD_TYPE _r23 = MEGDNN_SIMD_EXT(_r20, _r24, 3); + MEGDNN_SIMD_TYPE _r25 = MEGDNN_SIMD_EXT(_r24, _r20n, 1); + MEGDNN_SIMD_TYPE _r26 = MEGDNN_SIMD_EXT(_r24, _r20n, 2); + + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r20, _k14151617, 0); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r21, _k14151617, 1); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r22, _k14151617, 2); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r23, _k14151617, 3); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r24, _k18192021, 0); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r25, _k18192021, 1); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r26, _k18192021, 2); + + MEGDNN_SIMD_TYPE _k21222324 = MEGDNN_SIMD_LOADU(k3); + MEGDNN_SIMD_TYPE _k25262728 = MEGDNN_SIMD_LOADU(k3 + 4); + + MEGDNN_SIMD_TYPE _r30 = MEGDNN_SIMD_LOADU(r3); + MEGDNN_SIMD_TYPE _r34 = MEGDNN_SIMD_LOADU(r3 + 4); + MEGDNN_SIMD_TYPE _r30n = MEGDNN_SIMD_LOADU(r3 + 8); + MEGDNN_SIMD_TYPE _r31 = MEGDNN_SIMD_EXT(_r30, _r34, 1); + MEGDNN_SIMD_TYPE _r32 
= MEGDNN_SIMD_EXT(_r30, _r34, 2); + MEGDNN_SIMD_TYPE _r33 = MEGDNN_SIMD_EXT(_r30, _r34, 3); + MEGDNN_SIMD_TYPE _r35 = MEGDNN_SIMD_EXT(_r34, _r30n, 1); + MEGDNN_SIMD_TYPE _r36 = MEGDNN_SIMD_EXT(_r34, _r30n, 2); + + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r30, _k21222324, 0); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r31, _k21222324, 1); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r32, _k21222324, 2); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r33, _k21222324, 3); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r34, _k25262728, 0); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r35, _k25262728, 1); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r36, _k25262728, 2); + + MEGDNN_SIMD_TYPE _k28293031 = MEGDNN_SIMD_LOADU(k4); + MEGDNN_SIMD_TYPE _k32333435 = MEGDNN_SIMD_LOADU(k4 + 4); + + MEGDNN_SIMD_TYPE _r40 = MEGDNN_SIMD_LOADU(r4); + MEGDNN_SIMD_TYPE _r44 = MEGDNN_SIMD_LOADU(r4 + 4); + MEGDNN_SIMD_TYPE _r40n = MEGDNN_SIMD_LOADU(r4 + 8); + MEGDNN_SIMD_TYPE _r41 = MEGDNN_SIMD_EXT(_r40, _r44, 1); + MEGDNN_SIMD_TYPE _r42 = MEGDNN_SIMD_EXT(_r40, _r44, 2); + MEGDNN_SIMD_TYPE _r43 = MEGDNN_SIMD_EXT(_r40, _r44, 3); + MEGDNN_SIMD_TYPE _r45 = MEGDNN_SIMD_EXT(_r44, _r40n, 1); + MEGDNN_SIMD_TYPE _r46 = MEGDNN_SIMD_EXT(_r44, _r40n, 2); + + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r40, _k28293031, 0); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r41, _k28293031, 1); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r42, _k28293031, 2); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r43, _k28293031, 3); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r44, _k32333435, 0); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r45, _k32333435, 1); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r46, _k32333435, 2); + + MEGDNN_SIMD_TYPE _k35363738 = MEGDNN_SIMD_LOADU(k5); + MEGDNN_SIMD_TYPE _k39404142 = MEGDNN_SIMD_LOADU(k5 + 4); + + MEGDNN_SIMD_TYPE _r50 = MEGDNN_SIMD_LOADU(r5); + MEGDNN_SIMD_TYPE _r54 = MEGDNN_SIMD_LOADU(r5 + 4); + MEGDNN_SIMD_TYPE _r50n = MEGDNN_SIMD_LOADU(r5 + 8); + MEGDNN_SIMD_TYPE _r51 = MEGDNN_SIMD_EXT(_r50, _r54, 1); + MEGDNN_SIMD_TYPE _r52 = MEGDNN_SIMD_EXT(_r50, _r54, 2); + MEGDNN_SIMD_TYPE _r53 = MEGDNN_SIMD_EXT(_r50, _r54, 3); + MEGDNN_SIMD_TYPE _r55 = MEGDNN_SIMD_EXT(_r54, _r50n, 1); + MEGDNN_SIMD_TYPE _r56 = MEGDNN_SIMD_EXT(_r54, _r50n, 2); + + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r50, _k35363738, 0); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r51, _k35363738, 1); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r52, _k35363738, 2); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r53, _k35363738, 3); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r54, _k39404142, 0); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r55, _k39404142, 1); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r56, _k39404142, 2); + + MEGDNN_SIMD_TYPE _k42434445 = MEGDNN_SIMD_LOADU(k6); + MEGDNN_SIMD_TYPE _k46474849 = MEGDNN_SIMD_LOADU(k6 + 4); + + MEGDNN_SIMD_TYPE _r60 = MEGDNN_SIMD_LOADU(r6); + MEGDNN_SIMD_TYPE _r64 = MEGDNN_SIMD_LOADU(r6 + 4); + MEGDNN_SIMD_TYPE _r60n = MEGDNN_SIMD_LOADU(r6 + 8); + MEGDNN_SIMD_TYPE _r61 = MEGDNN_SIMD_EXT(_r60, _r64, 1); + MEGDNN_SIMD_TYPE _r62 = MEGDNN_SIMD_EXT(_r60, _r64, 2); + MEGDNN_SIMD_TYPE _r63 = MEGDNN_SIMD_EXT(_r60, _r64, 3); + MEGDNN_SIMD_TYPE _r65 = MEGDNN_SIMD_EXT(_r64, _r60n, 1); + MEGDNN_SIMD_TYPE _r66 = MEGDNN_SIMD_EXT(_r64, _r60n, 2); + + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r60, _k42434445, 0); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r61, _k42434445, 1); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r62, _k42434445, 2); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r63, _k42434445, 3); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r64, _k46474849, 0); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r65, _k46474849, 1); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r66, _k46474849, 2); + + 
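+                // all 49 taps of the 7x7 filter have now been accumulated for these four adjacent output pixels; write them back with a single store.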
MEGDNN_SIMD_STOREU(outptr, _sum); + + r0 += 4; + r1 += 4; + r2 += 4; + r3 += 4; + r4 += 4; + r5 += 4; + r6 += 4; + outptr += 4; + } + + r0 += tail_step; + r1 += tail_step; + r2 += tail_step; + r3 += tail_step; + r4 += tail_step; + r5 += tail_step; + r6 += tail_step; + } + filter += 49; + } +} + +#include "src/common/simd_macro/epilogue.h" +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/do_conv_stride1.h b/dnn/src/arm_common/conv_bias/fp32/do_conv_stride1.h new file mode 100644 index 00000000..f7990399 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/do_conv_stride1.h @@ -0,0 +1,34 @@ +/** + * \file dnn/src/arm_common/conv_bias/fp32/do_conv_stride1.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once + +#include + +namespace megdnn { +namespace arm_common { +namespace fp32 { +namespace conv_stride1 { + +void do_conv_2x2_stride1(const float* src, const float* filter, float* dst, + size_t IH, size_t IW, size_t OH, size_t OW, size_t IC); +void do_conv_3x3_stride1(const float* src, const float* filter, float* dst, + size_t IH, size_t IW, size_t OH, size_t OW, size_t IC); +void do_conv_5x5_stride1(const float* src, const float* filter, float* dst, + size_t IH, size_t IW, size_t OH, size_t OW, size_t IC); +void do_conv_7x7_stride1(const float* src, const float* filter, float* dst, + size_t IH, size_t IW, size_t OH, size_t OW, size_t IC); +} // namespace conv_stride1 +} // namespace fp32 +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen + diff --git a/dnn/src/arm_common/conv_bias/fp32/do_conv_stride2.cpp b/dnn/src/arm_common/conv_bias/fp32/do_conv_stride2.cpp new file mode 100644 index 00000000..e6c92653 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/do_conv_stride2.cpp @@ -0,0 +1,513 @@ +/** + * \file dnn/src/arm_common/conv_bias/fp32/do_conv_stride2.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#include + +#include "./do_conv_stride2.h" +#include "midout.h" +#include "src/arm_common/simd_macro/neon_helper.h" +#include "src/arm_common/conv_bias/postprocess_helper.h" + +MIDOUT_DECL(megdnn_arm_common_conv_bias_f32_convs2) + +using namespace megdnn; +using namespace arm_common; +using namespace fp32; +using namespace conv_stride2; + +using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam; +using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam; + + +void conv_stride2::do_conv_2x2_stride2(const float* src, const float* filter, float* dst, + size_t IH, size_t IW, size_t OH, size_t OW, + size_t IC) { + const size_t tail_step = IW - 2 * OW + IW; + + rep(ic, IC) { + const float* src_ptr = src + IW * IH * ic; + float* outptr = dst; + + const float* r0 = src_ptr; + const float* r1 = src_ptr + IW; + + const float* k0 = filter; + + MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(k0); + rep(h, OH) { + int nn = OW >> 2; + + rep(i, nn) { + MEGDNN_SIMD_TYPE _outp = MEGDNN_SIMD_LOADU(outptr); + + MEGDNN_SIMD_TYPE2 _r0 = MEGDNN_SIMD_LOAD2(r0); + + MEGDNN_SIMD_TYPE _r00 = _r0.val[0]; // 0 2 4 6 + MEGDNN_SIMD_TYPE _r01 = _r0.val[1]; // 1 3 5 7 + + _outp = MEGDNN_SIMD_FMA_LANE(_outp, _r00, _k0123, 0); + _outp = MEGDNN_SIMD_FMA_LANE(_outp, _r01, _k0123, 1); + + MEGDNN_SIMD_TYPE2 _r1 = MEGDNN_SIMD_LOAD2(r1); + + MEGDNN_SIMD_TYPE _r10 = _r1.val[0]; + MEGDNN_SIMD_TYPE _r11 = _r1.val[1]; + + _outp = MEGDNN_SIMD_FMA_LANE(_outp, _r10, _k0123, 2); + _outp = MEGDNN_SIMD_FMA_LANE(_outp, _r11, _k0123, 3); + + MEGDNN_SIMD_STOREU(outptr, _outp); + + r0 += 8; + r1 += 8; + outptr += 4; + } + + r0 += tail_step; + r1 += tail_step; + } + + filter += 4; + } +} + +void conv_stride2::do_conv_3x3_stride2(const float* src, const float* filter, float* dst, + size_t IH, size_t IW, size_t OH, size_t OW, + size_t IC) { + const size_t tail_step = IW - 2 * OW + IW; + + rep(ic, IC) { + const float* src_ptr = src + IW * IH * ic; + float* outptr = dst; + + const float* r0 = src_ptr; + const float* r1 = src_ptr + IW; + const float* r2 = src_ptr + IW * 2; + + const float* k0 = filter; + const float* k1 = filter + 3; + const float* k2 = filter + 5; + + MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(k0); + MEGDNN_SIMD_TYPE _k3456 = MEGDNN_SIMD_LOADU(k1); + MEGDNN_SIMD_TYPE _k5678 = MEGDNN_SIMD_LOADU(k2); + MEGDNN_SIMD_TYPE _k6789 = MEGDNN_SIMD_EXT(_k5678, _k5678, 1); + rep(h, OH) { + int nn = OW >> 2; + + rep(i, nn) { + MEGDNN_SIMD_TYPE _outp = MEGDNN_SIMD_LOADU(outptr); + + MEGDNN_SIMD_TYPE2 _r0 = MEGDNN_SIMD_LOAD2(r0); + MEGDNN_SIMD_TYPE2 _r0n = MEGDNN_SIMD_LOAD2(r0 + 8); + + MEGDNN_SIMD_TYPE _r00 = _r0.val[0]; // 0 2 4 6 + MEGDNN_SIMD_TYPE _r01 = _r0.val[1]; // 1 3 5 7 + MEGDNN_SIMD_TYPE _r02 = + MEGDNN_SIMD_EXT(_r00, _r0n.val[0], 1); // 2 4 6 8 + + _outp = MEGDNN_SIMD_FMA_LANE(_outp, _r00, _k0123, 0); + _outp = MEGDNN_SIMD_FMA_LANE(_outp, _r01, _k0123, 1); + _outp = MEGDNN_SIMD_FMA_LANE(_outp, _r02, _k0123, 2); + + MEGDNN_SIMD_TYPE2 _r1 = MEGDNN_SIMD_LOAD2(r1); + MEGDNN_SIMD_TYPE2 _r1n = MEGDNN_SIMD_LOAD2(r1 + 8); + + MEGDNN_SIMD_TYPE _r10 = _r1.val[0]; + MEGDNN_SIMD_TYPE _r11 = _r1.val[1]; + MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r1n.val[0], 1); + + _outp = MEGDNN_SIMD_FMA_LANE(_outp, _r10, _k3456, 0); + _outp = MEGDNN_SIMD_FMA_LANE(_outp, _r11, _k3456, 1); + _outp = MEGDNN_SIMD_FMA_LANE(_outp, _r12, _k3456, 2); + + MEGDNN_SIMD_TYPE2 _r2 = MEGDNN_SIMD_LOAD2(r2); + MEGDNN_SIMD_TYPE2 _r2n = MEGDNN_SIMD_LOAD2(r2 + 8); + + MEGDNN_SIMD_TYPE _r20 = _r2.val[0]; + MEGDNN_SIMD_TYPE _r21 = _r2.val[1]; + 
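+                // vld2 de-interleaves the stride-2 row into even (_r20) and odd (_r21) elements; _r22 shifts the even lanes by one element to reach the third filter tap.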
MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r2n.val[0], 1); + + _outp = MEGDNN_SIMD_FMA_LANE(_outp, _r20, _k6789, 0); + _outp = MEGDNN_SIMD_FMA_LANE(_outp, _r21, _k6789, 1); + _outp = MEGDNN_SIMD_FMA_LANE(_outp, _r22, _k6789, 2); + + MEGDNN_SIMD_STOREU(outptr, _outp); + + r0 += 8; + r1 += 8; + r2 += 8; + outptr += 4; + } + + r0 += tail_step; + r1 += tail_step; + r2 += tail_step; + } + + filter += 9; + } +} + +void conv_stride2::do_conv_5x5_stride2(const float* src, const float* filter, float* dst, + size_t IH, size_t IW, size_t OH, size_t OW, + size_t IC) { + const size_t tail_step = IW - 2 * OW + IW; + + rep(ic, IC) { + const float* src_ptr = src + IW * IH * ic; + float* outptr = dst; + + const float* r0 = src_ptr; + const float* r1 = src_ptr + IW; + const float* r2 = src_ptr + IW * 2; + const float* r3 = src_ptr + IW * 3; + const float* r4 = src_ptr + IW * 4; + + MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(filter); + MEGDNN_SIMD_TYPE _k4567 = MEGDNN_SIMD_LOADU(filter + 4); + MEGDNN_SIMD_TYPE _k891011 = MEGDNN_SIMD_LOADU(filter + 8); + MEGDNN_SIMD_TYPE _k12131415 = MEGDNN_SIMD_LOADU(filter + 12); + MEGDNN_SIMD_TYPE _k16171819 = MEGDNN_SIMD_LOADU(filter + 16); + MEGDNN_SIMD_TYPE _k20212223 = MEGDNN_SIMD_LOADU(filter + 20); + MEGDNN_SIMD_TYPE _k24242424 = MEGDNN_SIMD_SET1(filter[24]); + + for (size_t i = 0; i < OH; i++) { + int nn = OW >> 2; + + rep(i, nn) { + MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr); + + MEGDNN_SIMD_TYPE2 _r00_02461357 = MEGDNN_SIMD_LOAD2(r0); + MEGDNN_SIMD_TYPE2 _r00nx2 = MEGDNN_SIMD_LOAD2(r0 + 8); + MEGDNN_SIMD_TYPE _r0_8101214 = _r00nx2.val[0]; // 8 10 12 14 + MEGDNN_SIMD_TYPE _r0_9111315 = _r00nx2.val[1]; // 9 11 13 15 + MEGDNN_SIMD_TYPE _r00 = _r00_02461357.val[0]; // 0 2 4 6 + MEGDNN_SIMD_TYPE _r01 = _r00_02461357.val[1]; // 1 3 5 7 + MEGDNN_SIMD_TYPE _r02 = + MEGDNN_SIMD_EXT(_r00, _r0_8101214, 1); // 2 4 6 8 + MEGDNN_SIMD_TYPE _r03 = + MEGDNN_SIMD_EXT(_r01, _r0_9111315, 1); // 3 5 7 9 + MEGDNN_SIMD_TYPE _r04 = + MEGDNN_SIMD_EXT(_r00, _r0_8101214, 2); // 4 6 8 10 + + MEGDNN_SIMD_TYPE2 _r10_02461357 = MEGDNN_SIMD_LOAD2(r1); + MEGDNN_SIMD_TYPE2 _r10nx2 = MEGDNN_SIMD_LOAD2(r1 + 8); + MEGDNN_SIMD_TYPE _r1_8101214 = _r10nx2.val[0]; + MEGDNN_SIMD_TYPE _r1_9111315 = _r10nx2.val[1]; + MEGDNN_SIMD_TYPE _r10 = _r10_02461357.val[0]; + MEGDNN_SIMD_TYPE _r11 = _r10_02461357.val[1]; + MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r1_8101214, 1); + MEGDNN_SIMD_TYPE _r13 = MEGDNN_SIMD_EXT(_r11, _r1_9111315, 1); + MEGDNN_SIMD_TYPE _r14 = MEGDNN_SIMD_EXT(_r10, _r1_8101214, 2); + + MEGDNN_SIMD_TYPE2 _r20_02461357 = MEGDNN_SIMD_LOAD2(r2); + MEGDNN_SIMD_TYPE2 _r20nx2 = MEGDNN_SIMD_LOAD2(r2 + 8); + MEGDNN_SIMD_TYPE _r2_8101214 = _r20nx2.val[0]; + MEGDNN_SIMD_TYPE _r2_9111315 = _r20nx2.val[1]; + MEGDNN_SIMD_TYPE _r20 = _r20_02461357.val[0]; + MEGDNN_SIMD_TYPE _r21 = _r20_02461357.val[1]; + MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r2_8101214, 1); + MEGDNN_SIMD_TYPE _r23 = MEGDNN_SIMD_EXT(_r21, _r2_9111315, 1); + MEGDNN_SIMD_TYPE _r24 = MEGDNN_SIMD_EXT(_r20, _r2_8101214, 2); + + MEGDNN_SIMD_TYPE2 _r30_02461357 = MEGDNN_SIMD_LOAD2(r3); + MEGDNN_SIMD_TYPE2 _r30nx2 = MEGDNN_SIMD_LOAD2(r3 + 8); + MEGDNN_SIMD_TYPE _r3_8101214 = _r30nx2.val[0]; + MEGDNN_SIMD_TYPE _r3_9111315 = _r30nx2.val[1]; + MEGDNN_SIMD_TYPE _r30 = _r30_02461357.val[0]; + MEGDNN_SIMD_TYPE _r31 = _r30_02461357.val[1]; + MEGDNN_SIMD_TYPE _r32 = MEGDNN_SIMD_EXT(_r30, _r3_8101214, 1); + MEGDNN_SIMD_TYPE _r33 = MEGDNN_SIMD_EXT(_r31, _r3_9111315, 1); + MEGDNN_SIMD_TYPE _r34 = MEGDNN_SIMD_EXT(_r30, _r3_8101214, 2); 
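+                // last input row (r4) of the 5x5 window, de-interleaved and shifted the same way as r0..r3 above.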
+ + MEGDNN_SIMD_TYPE2 _r40_02461357 = MEGDNN_SIMD_LOAD2(r4); + MEGDNN_SIMD_TYPE2 _r40nx2 = MEGDNN_SIMD_LOAD2(r4 + 8); + MEGDNN_SIMD_TYPE _r4_8101214 = _r40nx2.val[0]; + MEGDNN_SIMD_TYPE _r4_9111315 = _r40nx2.val[1]; + MEGDNN_SIMD_TYPE _r40 = _r40_02461357.val[0]; + MEGDNN_SIMD_TYPE _r41 = _r40_02461357.val[1]; + MEGDNN_SIMD_TYPE _r42 = MEGDNN_SIMD_EXT(_r40, _r4_8101214, 1); + MEGDNN_SIMD_TYPE _r43 = MEGDNN_SIMD_EXT(_r41, _r4_9111315, 1); + MEGDNN_SIMD_TYPE _r44 = MEGDNN_SIMD_EXT(_r40, _r4_8101214, 2); + + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r00, _k0123, 0); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r01, _k0123, 1); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r02, _k0123, 2); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r03, _k0123, 3); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r04, _k4567, 0); + + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r10, _k4567, 1); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r11, _k4567, 2); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r12, _k4567, 3); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r13, _k891011, 0); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r14, _k891011, 1); + + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r20, _k891011, 2); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r21, _k891011, 3); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r22, _k12131415, 0); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r23, _k12131415, 1); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r24, _k12131415, 2); + + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r30, _k12131415, 3); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r31, _k16171819, 0); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r32, _k16171819, 1); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r33, _k16171819, 2); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r34, _k16171819, 3); + + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r40, _k20212223, 0); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r41, _k20212223, 1); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r42, _k20212223, 2); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r43, _k20212223, 3); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r44, _k24242424, 0); + + MEGDNN_SIMD_STOREU(outptr, _sum); + + r0 += 8; + r1 += 8; + r2 += 8; + r3 += 8; + r4 += 8; + outptr += 4; + } + + r0 += tail_step; + r1 += tail_step; + r2 += tail_step; + r3 += tail_step; + r4 += tail_step; + } + + filter += 25; + } +} + +void conv_stride2::do_conv_7x7_stride2(const float* src, const float* filter, float* dst, + size_t IH, size_t IW, size_t OH, size_t OW, + size_t IC) { + const size_t tail_step = IW - 2 * OW + IW; + + rep(ic, IC) { + const float* src_ptr = src + IW * IH * ic; + float* outptr = dst; + + const float* r0 = src_ptr; + const float* r1 = src_ptr + IW; + const float* r2 = src_ptr + IW * 2; + const float* r3 = src_ptr + IW * 3; + const float* r4 = src_ptr + IW * 4; + const float* r5 = src_ptr + IW * 5; + const float* r6 = src_ptr + IW * 6; + + const float* k0 = filter; + const float* k1 = filter + 7; + const float* k2 = filter + 14; + const float* k3 = filter + 21; + const float* k4 = filter + 28; + const float* k5 = filter + 35; + const float* k6 = filter + 42; + + for (size_t i = 0; i < OH; i++) { + int nn = OW >> 2; + + rep(i, nn) { + MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr); + + MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(k0); + MEGDNN_SIMD_TYPE _k4567 = MEGDNN_SIMD_LOADU(k0 + 4); + + MEGDNN_SIMD_TYPE2 _r00_02461357 = MEGDNN_SIMD_LOAD2(r0); + MEGDNN_SIMD_TYPE2 _r00nx2 = MEGDNN_SIMD_LOAD2(r0 + 8); + MEGDNN_SIMD_TYPE _r0_8101214 = _r00nx2.val[0]; // 8 10 12 14 + MEGDNN_SIMD_TYPE _r0_9111315 = _r00nx2.val[1]; // 9 11 13 15 + MEGDNN_SIMD_TYPE _r00 = _r00_02461357.val[0]; // 0 2 4 6 + MEGDNN_SIMD_TYPE _r01 = _r00_02461357.val[1]; // 1 3 5 7 + MEGDNN_SIMD_TYPE 
_r02 = + MEGDNN_SIMD_EXT(_r00, _r0_8101214, 1); // 2 4 6 8 + MEGDNN_SIMD_TYPE _r03 = + MEGDNN_SIMD_EXT(_r01, _r0_9111315, 1); // 3 5 7 9 + MEGDNN_SIMD_TYPE _r04 = + MEGDNN_SIMD_EXT(_r00, _r0_8101214, 2); // 4 6 8 10 + MEGDNN_SIMD_TYPE _r05 = + MEGDNN_SIMD_EXT(_r01, _r0_9111315, 2); // 5 7 9 11 + MEGDNN_SIMD_TYPE _r06 = + MEGDNN_SIMD_EXT(_r00, _r0_8101214, 3); // 6 8 10 12 + + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r00, _k0123, 0); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r01, _k0123, 1); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r02, _k0123, 2); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r03, _k0123, 3); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r04, _k4567, 0); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r05, _k4567, 1); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r06, _k4567, 2); + + MEGDNN_SIMD_TYPE _k78910 = MEGDNN_SIMD_LOADU(k1); + MEGDNN_SIMD_TYPE _k11121314 = MEGDNN_SIMD_LOADU(k1 + 4); + + MEGDNN_SIMD_TYPE2 _r10_02461357 = MEGDNN_SIMD_LOAD2(r1); + MEGDNN_SIMD_TYPE2 _r10nx2 = MEGDNN_SIMD_LOAD2(r1 + 8); + MEGDNN_SIMD_TYPE _r1_8101214 = _r10nx2.val[0]; + MEGDNN_SIMD_TYPE _r1_9111315 = _r10nx2.val[1]; + MEGDNN_SIMD_TYPE _r10 = _r10_02461357.val[0]; + MEGDNN_SIMD_TYPE _r11 = _r10_02461357.val[1]; + MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r1_8101214, 1); + MEGDNN_SIMD_TYPE _r13 = MEGDNN_SIMD_EXT(_r11, _r1_9111315, 1); + MEGDNN_SIMD_TYPE _r14 = MEGDNN_SIMD_EXT(_r10, _r1_8101214, 2); + MEGDNN_SIMD_TYPE _r15 = MEGDNN_SIMD_EXT(_r11, _r1_9111315, 2); + MEGDNN_SIMD_TYPE _r16 = MEGDNN_SIMD_EXT(_r10, _r1_8101214, 3); + + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r10, _k78910, 0); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r11, _k78910, 1); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r12, _k78910, 2); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r13, _k78910, 3); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r14, _k11121314, 0); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r15, _k11121314, 1); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r16, _k11121314, 2); + + MEGDNN_SIMD_TYPE _k14151617 = MEGDNN_SIMD_LOADU(k2); + MEGDNN_SIMD_TYPE _k18192021 = MEGDNN_SIMD_LOADU(k2 + 4); + + MEGDNN_SIMD_TYPE2 _r20_02461357 = MEGDNN_SIMD_LOAD2(r2); + MEGDNN_SIMD_TYPE2 _r20nx2 = MEGDNN_SIMD_LOAD2(r2 + 8); + MEGDNN_SIMD_TYPE _r2_8101214 = _r20nx2.val[0]; + MEGDNN_SIMD_TYPE _r2_9111315 = _r20nx2.val[1]; + MEGDNN_SIMD_TYPE _r20 = _r20_02461357.val[0]; + MEGDNN_SIMD_TYPE _r21 = _r20_02461357.val[1]; + MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r2_8101214, 1); + MEGDNN_SIMD_TYPE _r23 = MEGDNN_SIMD_EXT(_r21, _r2_9111315, 1); + MEGDNN_SIMD_TYPE _r24 = MEGDNN_SIMD_EXT(_r20, _r2_8101214, 2); + MEGDNN_SIMD_TYPE _r25 = MEGDNN_SIMD_EXT(_r21, _r2_9111315, 2); + MEGDNN_SIMD_TYPE _r26 = MEGDNN_SIMD_EXT(_r20, _r2_8101214, 3); + + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r20, _k14151617, 0); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r21, _k14151617, 1); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r22, _k14151617, 2); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r23, _k14151617, 3); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r24, _k18192021, 0); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r25, _k18192021, 1); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r26, _k18192021, 2); + + MEGDNN_SIMD_TYPE _k21222324 = MEGDNN_SIMD_LOADU(k3); + MEGDNN_SIMD_TYPE _k25262728 = MEGDNN_SIMD_LOADU(k3 + 4); + + MEGDNN_SIMD_TYPE2 _r30_02461357 = MEGDNN_SIMD_LOAD2(r3); + MEGDNN_SIMD_TYPE2 _r30nx2 = MEGDNN_SIMD_LOAD2(r3 + 8); + MEGDNN_SIMD_TYPE _r3_8101214 = _r30nx2.val[0]; + MEGDNN_SIMD_TYPE _r3_9111315 = _r30nx2.val[1]; + MEGDNN_SIMD_TYPE _r30 = _r30_02461357.val[0]; + MEGDNN_SIMD_TYPE _r31 = _r30_02461357.val[1]; + MEGDNN_SIMD_TYPE _r32 = MEGDNN_SIMD_EXT(_r30, _r3_8101214, 1); 
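// For the 7x7 kernel each input row needs seven stride-2 windows: _rX0/_rX1
// come straight from the even/odd halves of LOAD2, and EXT offsets 1..3
// against the next even/odd block produce _rX2.._rX6 ({x2 x4 x6 x8} up to
// {x6 x8 x10 x12}), i.e. taps 0..6 of kernel row X for four output columns.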
+ MEGDNN_SIMD_TYPE _r33 = MEGDNN_SIMD_EXT(_r31, _r3_9111315, 1); + MEGDNN_SIMD_TYPE _r34 = MEGDNN_SIMD_EXT(_r30, _r3_8101214, 2); + MEGDNN_SIMD_TYPE _r35 = MEGDNN_SIMD_EXT(_r31, _r3_9111315, 2); + MEGDNN_SIMD_TYPE _r36 = MEGDNN_SIMD_EXT(_r30, _r3_8101214, 3); + + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r30, _k21222324, 0); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r31, _k21222324, 1); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r32, _k21222324, 2); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r33, _k21222324, 3); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r34, _k25262728, 0); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r35, _k25262728, 1); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r36, _k25262728, 2); + + MEGDNN_SIMD_TYPE _k28293031 = MEGDNN_SIMD_LOADU(k4); + MEGDNN_SIMD_TYPE _k32333435 = MEGDNN_SIMD_LOADU(k4 + 4); + + MEGDNN_SIMD_TYPE2 _r40_02461357 = MEGDNN_SIMD_LOAD2(r4); + MEGDNN_SIMD_TYPE2 _r40nx2 = MEGDNN_SIMD_LOAD2(r4 + 8); + MEGDNN_SIMD_TYPE _r4_8101214 = _r40nx2.val[0]; + MEGDNN_SIMD_TYPE _r4_9111315 = _r40nx2.val[1]; + MEGDNN_SIMD_TYPE _r40 = _r40_02461357.val[0]; + MEGDNN_SIMD_TYPE _r41 = _r40_02461357.val[1]; + MEGDNN_SIMD_TYPE _r42 = MEGDNN_SIMD_EXT(_r40, _r4_8101214, 1); + MEGDNN_SIMD_TYPE _r43 = MEGDNN_SIMD_EXT(_r41, _r4_9111315, 1); + MEGDNN_SIMD_TYPE _r44 = MEGDNN_SIMD_EXT(_r40, _r4_8101214, 2); + MEGDNN_SIMD_TYPE _r45 = MEGDNN_SIMD_EXT(_r41, _r4_9111315, 2); + MEGDNN_SIMD_TYPE _r46 = MEGDNN_SIMD_EXT(_r40, _r4_8101214, 3); + + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r40, _k28293031, 0); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r41, _k28293031, 1); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r42, _k28293031, 2); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r43, _k28293031, 3); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r44, _k32333435, 0); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r45, _k32333435, 1); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r46, _k32333435, 2); + + MEGDNN_SIMD_TYPE _k35363738 = MEGDNN_SIMD_LOADU(k5); + MEGDNN_SIMD_TYPE _k39404142 = MEGDNN_SIMD_LOADU(k5 + 4); + + MEGDNN_SIMD_TYPE2 _r50_02461357 = MEGDNN_SIMD_LOAD2(r5); + MEGDNN_SIMD_TYPE2 _r50nx2 = MEGDNN_SIMD_LOAD2(r5 + 8); + MEGDNN_SIMD_TYPE _r5_8101214 = _r50nx2.val[0]; + MEGDNN_SIMD_TYPE _r5_9111315 = _r50nx2.val[1]; + MEGDNN_SIMD_TYPE _r50 = _r50_02461357.val[0]; + MEGDNN_SIMD_TYPE _r51 = _r50_02461357.val[1]; + MEGDNN_SIMD_TYPE _r52 = MEGDNN_SIMD_EXT(_r50, _r5_8101214, 1); + MEGDNN_SIMD_TYPE _r53 = MEGDNN_SIMD_EXT(_r51, _r5_9111315, 1); + MEGDNN_SIMD_TYPE _r54 = MEGDNN_SIMD_EXT(_r50, _r5_8101214, 2); + MEGDNN_SIMD_TYPE _r55 = MEGDNN_SIMD_EXT(_r51, _r5_9111315, 2); + MEGDNN_SIMD_TYPE _r56 = MEGDNN_SIMD_EXT(_r50, _r5_8101214, 3); + + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r50, _k35363738, 0); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r51, _k35363738, 1); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r52, _k35363738, 2); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r53, _k35363738, 3); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r54, _k39404142, 0); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r55, _k39404142, 1); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r56, _k39404142, 2); + + MEGDNN_SIMD_TYPE _k42434445 = MEGDNN_SIMD_LOADU(k6); + MEGDNN_SIMD_TYPE _k45464748 = MEGDNN_SIMD_LOADU(k6 + 3); + + MEGDNN_SIMD_TYPE2 _r60_02461357 = MEGDNN_SIMD_LOAD2(r6); + MEGDNN_SIMD_TYPE2 _r60nx2 = MEGDNN_SIMD_LOAD2(r6 + 8); + MEGDNN_SIMD_TYPE _r6_8101214 = _r60nx2.val[0]; + MEGDNN_SIMD_TYPE _r6_9111315 = _r60nx2.val[1]; + MEGDNN_SIMD_TYPE _r60 = _r60_02461357.val[0]; + MEGDNN_SIMD_TYPE _r61 = _r60_02461357.val[1]; + MEGDNN_SIMD_TYPE _r62 = MEGDNN_SIMD_EXT(_r60, _r6_8101214, 1); + MEGDNN_SIMD_TYPE _r63 = MEGDNN_SIMD_EXT(_r61, _r6_9111315, 1); + 
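// The last kernel row is handled with an overlapping load: _k42434445 reads
// taps 42..45 and _k45464748 reads from k6 + 3, i.e. taps 45..48, so the final
// load stays inside the 49-element filter. Tap 45 is consumed from lane 3 of
// _k42434445, which is why the _r64.._r66 FMAs below use lanes 1..3 of
// _k45464748 and skip its lane 0.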
MEGDNN_SIMD_TYPE _r64 = MEGDNN_SIMD_EXT(_r60, _r6_8101214, 2); + MEGDNN_SIMD_TYPE _r65 = MEGDNN_SIMD_EXT(_r61, _r6_9111315, 2); + MEGDNN_SIMD_TYPE _r66 = MEGDNN_SIMD_EXT(_r60, _r6_8101214, 3); + + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r60, _k42434445, 0); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r61, _k42434445, 1); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r62, _k42434445, 2); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r63, _k42434445, 3); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r64, _k45464748, 1); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r65, _k45464748, 2); + _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r66, _k45464748, 3); + + MEGDNN_SIMD_STOREU(outptr, _sum); + + r0 += 8; + r1 += 8; + r2 += 8; + r3 += 8; + r4 += 8; + r5 += 8; + r6 += 8; + outptr += 4; + } + + r0 += tail_step; + r1 += tail_step; + r2 += tail_step; + r3 += tail_step; + r4 += tail_step; + r5 += tail_step; + r6 += tail_step; + } + filter += 49; + } +} +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/do_conv_stride2.h b/dnn/src/arm_common/conv_bias/fp32/do_conv_stride2.h new file mode 100644 index 00000000..e676c652 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/do_conv_stride2.h @@ -0,0 +1,32 @@ +/** + * \file dnn/src/arm_common/conv_bias/fp32/do_conv_stride2.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once + +#include "src/fallback/conv_bias/opr_impl.h" + +namespace megdnn { +namespace arm_common { +namespace fp32 { +namespace conv_stride2 { +void do_conv_2x2_stride2(const float* src, const float* filter, float* dst, + size_t IH, size_t IW, size_t OH, size_t OW, size_t IC); +void do_conv_3x3_stride2(const float* src, const float* filter, float* dst, + size_t IH, size_t IW, size_t OH, size_t OW, size_t IC); +void do_conv_5x5_stride2(const float* src, const float* filter, float* dst, + size_t IH, size_t IW, size_t OH, size_t OW, size_t IC); +void do_conv_7x7_stride2(const float* src, const float* filter, float* dst, + size_t IH, size_t IW, size_t OH, size_t OW, size_t IC); +} // namespace conv_stride2 +} // namespace fp32 +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/filter_transform.h b/dnn/src/arm_common/conv_bias/fp32/filter_transform.h new file mode 100644 index 00000000..9ca264ca --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/filter_transform.h @@ -0,0 +1,164 @@ +/** + * \file dnn/src/arm_common/conv_bias/fp32/filter_transform.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#pragma once +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/arm_common/utils.h" +#include "src/common/unroll_macro.h" +#include "src/common/utils.h" +#include "src/arm_common/conv_bias/fp32/helper.h" +#include "megdnn/opr_param_defs.h" + +namespace megdnn { +namespace arm_common { + +template +struct FilterTransform6X3 { +#define FILTER_TRANSFORM(d, wd) \ + do { \ + wd##0 = d##0; \ + auto tmp0 = (d##0 + d##2) * -0.2222222f; \ + auto tmp1 = d##1 * -0.2222222f; \ + wd##1 = tmp0 + tmp1; \ + wd##2 = tmp0 - tmp1; \ + tmp0 = d##0 * 0.0111111f + d##2 * 0.0444444f; \ + tmp1 = d##1 * 0.0222222f; \ + wd##3 = tmp0 + tmp1; \ + wd##4 = tmp0 - tmp1; \ + tmp0 = d##0 * 0.7111111f + d##2 * 0.1777778f; \ + tmp1 = d##1 * 0.3555556f; \ + wd##5 = tmp0 + tmp1; \ + wd##6 = tmp0 - tmp1; \ + wd##7 = d##2; \ + } while (0); + + static void transform(const float* filter, float* filter_transform_buf, + float* transform_mid_buf, size_t OC, size_t IC, + size_t oc_start, size_t oc_end) { + // Gg * GT + // G + // 1.0000000 0.0000000 0.0000000 + // -0.2222222 -0.2222222 -0.2222222 + // -0.2222222 0.2222222 -0.2222222 + // 0.0111111 0.0222222 0.0444444 + // 0.0111111 -0.0222222 0.0444444 + // 0.7111111 0.3555556 0.1777778 + // 0.7111111 -0.3555556 0.1777778 + // 0.0000000 0.0000000 1.0000000 + + constexpr size_t alpha = 6 + 3 - 1; + size_t OCB = OC / 4; + size_t ICB = IC / 4; + for (size_t oc = oc_start; oc < oc_end; oc++) { + rep(ic, IC) { + const float* fptr = filter + (oc * IC + ic) * 3 * 3; + + Vector g0 = Vector::load(fptr); + Vector g1 = Vector::load(fptr + 3); + + Vector g2 = Vector::load(fptr + 6 - 1); + float32x4_t zeros = vdupq_n_f32(0.0f); + g2.value = vextq_f32(g2.value, zeros, 1); + +#define cb(i) Vector wd##i; + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + +#define cb(i) Vector wdt##i; + UNROLL_CALL_NOWRAPPER(3, cb); +#undef cb + +#define cb(i) Vector ret##i; + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + + FILTER_TRANSFORM(g, wd); + + size_t ocb = oc / 4; + size_t oc4 = oc % 4; + size_t icb = ic / 4; + size_t ic4 = ic % 4; +#if MEGDNN_AARCH64 + TRANSPOSE_8x3(wd, wdt); + FILTER_TRANSFORM(wdt, ret); + +#define cb(i) ret##i.save(transform_mid_buf + i * alpha); + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + rep(i, alpha) rep(j, alpha) { + if (format == param::MatrixMul::Format::DEFAULT) { + filter_transform_buf[(i * alpha + j) * OC * IC + + ic * OC + oc] = + transform_mid_buf[j * alpha + i]; + } else { + filter_transform_buf[(i * alpha + j) * OCB * ICB * 4 * + 4 + + ocb * ICB * 4 * 4 + icb * 4 * 4 + + ic4 * 4 + oc4] = + transform_mid_buf[j * alpha + i]; + } + } + +#else + +#define cb(i) \ + do { \ + mid_buf1[0] = GET_VECTOR_ELEM(wd, i, 0); \ + auto tmp0 = (GET_VECTOR_ELEM(wd, i, 0) + GET_VECTOR_ELEM(wd, i, 2)) * \ + -0.2222222f; \ + auto tmp1 = GET_VECTOR_ELEM(wd, i, 1) * -0.2222222f; \ + mid_buf1[1] = tmp0 + tmp1; \ + mid_buf1[2] = tmp0 - tmp1; \ + tmp0 = GET_VECTOR_ELEM(wd, i, 0) * 0.0111111f + \ + GET_VECTOR_ELEM(wd, i, 2) * 0.0444444f; \ + tmp1 = GET_VECTOR_ELEM(wd, i, 1) * 0.0222222f; \ + mid_buf1[3] = tmp0 + tmp1; \ + mid_buf1[4] = tmp0 - tmp1; \ + tmp0 = GET_VECTOR_ELEM(wd, i, 0) * 0.7111111f + \ + GET_VECTOR_ELEM(wd, i, 2) * 0.1777778f; \ + tmp1 = GET_VECTOR_ELEM(wd, i, 1) * 0.3555556f; \ + mid_buf1[5] = tmp0 + tmp1; \ + mid_buf1[6] = tmp0 - tmp1; \ + mid_buf1[7] = GET_VECTOR_ELEM(wd, i, 2); \ + mid_buf1 += 8; \ + } while (0); +#define GET_VECTOR_ELEM(s, i, idx) vgetq_lane_f32(CONCAT(s, i).value, idx) + + float* mid_buf1 = transform_mid_buf; + UNROLL_CALL_NOWRAPPER(8, cb); + mid_buf1 = 
transform_mid_buf; +#undef cb + + rep(i, alpha) rep(j, alpha) { + if (format == param::MatrixMul::Format::DEFAULT) { + filter_transform_buf[(i * alpha + j) * OC * IC + + ic * OC + oc] = + transform_mid_buf[i * alpha + j]; + } else { + filter_transform_buf[(i * alpha + j) * OCB * ICB * 4 * + 4 + + ocb * ICB * 4 * 4 + icb * 4 * 4 + + ic4 * 4 + oc4] = + transform_mid_buf[i * alpha + j]; + } + } +#endif + } + } + } +}; +#undef FILTER_TRANSFORM +#undef GET_VECTOR_ELEM + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/helper.h b/dnn/src/arm_common/conv_bias/fp32/helper.h new file mode 100644 index 00000000..40df3c4a --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/helper.h @@ -0,0 +1,204 @@ +/** + * \file dnn/src/arm_common/conv_bias/fp32/helper.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once +#include "src/common/unroll_macro.h" +#include "src/arm_common/simd_macro/marm_neon.h" + +namespace megdnn { +namespace arm_common { +inline void transpose_4x4(const float* src, float* dst, int lda, int ldb) { + float32x4x2_t a0, a1; + a0.val[0] = vld1q_f32(src + 0 * lda); + a0.val[1] = vld1q_f32(src + 1 * lda); + a1.val[0] = vld1q_f32(src + 2 * lda); + a1.val[1] = vld1q_f32(src + 3 * lda); + float32x4x2_t b0 = vzipq_f32(a0.val[0], a1.val[0]); + float32x4x2_t b1 = vzipq_f32(a0.val[1], a1.val[1]); + float32x4x2_t c0 = vzipq_f32(b0.val[0], b1.val[0]); + float32x4x2_t c1 = vzipq_f32(b0.val[1], b1.val[1]); + vst1q_f32(dst + 0 * ldb, c0.val[0]); + vst1q_f32(dst + 1 * ldb, c0.val[1]); + vst1q_f32(dst + 2 * ldb, c1.val[0]); + vst1q_f32(dst + 3 * ldb, c1.val[1]); +} +} // namespace arm_common +} // namespace megdnn + +#define MATRIX_MUL4x4(sum, a, b) \ + sum##0 = vmlaq_low_lane_f32(sum##0, b##0, a##0, 0); \ + sum##0 = vmlaq_low_lane_f32(sum##0, b##1, a##0, 1); \ + sum##0 = vmlaq_high_lane_f32(sum##0, b##2, a##0, 2); \ + sum##0 = vmlaq_high_lane_f32(sum##0, b##3, a##0, 3); \ + sum##1 = vmlaq_low_lane_f32(sum##1, b##0, a##1, 0); \ + sum##1 = vmlaq_low_lane_f32(sum##1, b##1, a##1, 1); \ + sum##1 = vmlaq_high_lane_f32(sum##1, b##2, a##1, 2); \ + sum##1 = vmlaq_high_lane_f32(sum##1, b##3, a##1, 3); \ + sum##2 = vmlaq_low_lane_f32(sum##2, b##0, a##2, 0); \ + sum##2 = vmlaq_low_lane_f32(sum##2, b##1, a##2, 1); \ + sum##2 = vmlaq_high_lane_f32(sum##2, b##2, a##2, 2); \ + sum##2 = vmlaq_high_lane_f32(sum##2, b##3, a##2, 3); \ + sum##3 = vmlaq_low_lane_f32(sum##3, b##0, a##3, 0); \ + sum##3 = vmlaq_low_lane_f32(sum##3, b##1, a##3, 1); \ + sum##3 = vmlaq_high_lane_f32(sum##3, b##2, a##3, 2); \ + sum##3 = vmlaq_high_lane_f32(sum##3, b##3, a##3, 3); + +#define CONCAT(a, idx) a##idx + +#if MEGDNN_AARCH64 +//! 
ret and a are type Vector +#define TRANSPOSE_8x8(a, ret) \ + do { \ + auto b0 = vzipq_f32(CONCAT(a, 0).value.val[0], \ + CONCAT(a, 1).value.val[0]); \ + auto b1 = vzipq_f32(CONCAT(a, 0).value.val[1], \ + CONCAT(a, 1).value.val[1]); \ + auto b2 = vzipq_f32(CONCAT(a, 2).value.val[0], \ + CONCAT(a, 3).value.val[0]); \ + auto b3 = vzipq_f32(CONCAT(a, 2).value.val[1], \ + CONCAT(a, 3).value.val[1]); \ + auto b4 = vzipq_f32(CONCAT(a, 4).value.val[0], \ + CONCAT(a, 5).value.val[0]); \ + auto b5 = vzipq_f32(CONCAT(a, 4).value.val[1], \ + CONCAT(a, 5).value.val[1]); \ + auto b6 = vzipq_f32(CONCAT(a, 6).value.val[0], \ + CONCAT(a, 7).value.val[0]); \ + auto b7 = vzipq_f32(CONCAT(a, 6).value.val[1], \ + CONCAT(a, 7).value.val[1]); \ + CONCAT(ret, 0).value.val[0] = vreinterpretq_f32_s64( \ + vzip1q_s64(vreinterpretq_s64_f32(b0.val[0]), \ + vreinterpretq_s64_f32(b2.val[0]))); \ + CONCAT(ret, 0).value.val[1] = vreinterpretq_f32_s64( \ + vzip1q_s64(vreinterpretq_s64_f32(b4.val[0]), \ + vreinterpretq_s64_f32(b6.val[0]))); \ + CONCAT(ret, 1).value.val[0] = vreinterpretq_f32_s64( \ + vzip2q_s64(vreinterpretq_s64_f32(b0.val[0]), \ + vreinterpretq_s64_f32(b2.val[0]))); \ + CONCAT(ret, 1).value.val[1] = vreinterpretq_f32_s64( \ + vzip2q_s64(vreinterpretq_s64_f32(b4.val[0]), \ + vreinterpretq_s64_f32(b6.val[0]))); \ + CONCAT(ret, 2).value.val[0] = vreinterpretq_f32_s64( \ + vzip1q_s64(vreinterpretq_s64_f32(b0.val[1]), \ + vreinterpretq_s64_f32(b2.val[1]))); \ + CONCAT(ret, 2).value.val[1] = vreinterpretq_f32_s64( \ + vzip1q_s64(vreinterpretq_s64_f32(b4.val[1]), \ + vreinterpretq_s64_f32(b6.val[1]))); \ + CONCAT(ret, 3).value.val[0] = vreinterpretq_f32_s64( \ + vzip2q_s64(vreinterpretq_s64_f32(b0.val[1]), \ + vreinterpretq_s64_f32(b2.val[1]))); \ + CONCAT(ret, 3).value.val[1] = vreinterpretq_f32_s64( \ + vzip2q_s64(vreinterpretq_s64_f32(b4.val[1]), \ + vreinterpretq_s64_f32(b6.val[1]))); \ + CONCAT(ret, 4).value.val[0] = vreinterpretq_f32_s64( \ + vzip1q_s64(vreinterpretq_s64_f32(b1.val[0]), \ + vreinterpretq_s64_f32(b3.val[0]))); \ + CONCAT(ret, 4).value.val[1] = vreinterpretq_f32_s64( \ + vzip1q_s64(vreinterpretq_s64_f32(b5.val[0]), \ + vreinterpretq_s64_f32(b7.val[0]))); \ + CONCAT(ret, 5).value.val[0] = vreinterpretq_f32_s64( \ + vzip2q_s64(vreinterpretq_s64_f32(b1.val[0]), \ + vreinterpretq_s64_f32(b3.val[0]))); \ + CONCAT(ret, 5).value.val[1] = vreinterpretq_f32_s64( \ + vzip2q_s64(vreinterpretq_s64_f32(b5.val[0]), \ + vreinterpretq_s64_f32(b7.val[0]))); \ + CONCAT(ret, 6).value.val[0] = vreinterpretq_f32_s64( \ + vzip1q_s64(vreinterpretq_s64_f32(b1.val[1]), \ + vreinterpretq_s64_f32(b3.val[1]))); \ + CONCAT(ret, 6).value.val[1] = vreinterpretq_f32_s64( \ + vzip1q_s64(vreinterpretq_s64_f32(b5.val[1]), \ + vreinterpretq_s64_f32(b7.val[1]))); \ + CONCAT(ret, 7).value.val[0] = vreinterpretq_f32_s64( \ + vzip2q_s64(vreinterpretq_s64_f32(b1.val[1]), \ + vreinterpretq_s64_f32(b3.val[1]))); \ + CONCAT(ret, 7).value.val[1] = vreinterpretq_f32_s64( \ + vzip2q_s64(vreinterpretq_s64_f32(b5.val[1]), \ + vreinterpretq_s64_f32(b7.val[1]))); \ + } while (0); + +#define TRANSPOSE_8x3(a, ret) \ + auto b0 = vzipq_f32(CONCAT(a, 0).value, CONCAT(a, 1).value); \ + auto b1 = vzipq_f32(CONCAT(a, 2).value, CONCAT(a, 3).value); \ + auto b2 = vzipq_f32(CONCAT(a, 4).value, CONCAT(a, 5).value); \ + auto b3 = vzipq_f32(CONCAT(a, 6).value, CONCAT(a, 7).value); \ + CONCAT(ret, 0).value.val[0] = vreinterpretq_f32_s64( \ + vzip1q_s64(vreinterpretq_s64_f32(b0.val[0]), \ + vreinterpretq_s64_f32(b1.val[0]))); \ + CONCAT(ret, 0).value.val[1] 
= vreinterpretq_f32_s64( \ + vzip1q_s64(vreinterpretq_s64_f32(b2.val[0]), \ + vreinterpretq_s64_f32(b3.val[0]))); \ + CONCAT(ret, 1).value.val[0] = vreinterpretq_f32_s64( \ + vzip2q_s64(vreinterpretq_s64_f32(b0.val[0]), \ + vreinterpretq_s64_f32(b1.val[0]))); \ + CONCAT(ret, 1).value.val[1] = vreinterpretq_f32_s64( \ + vzip2q_s64(vreinterpretq_s64_f32(b2.val[0]), \ + vreinterpretq_s64_f32(b3.val[0]))); \ + CONCAT(ret, 2).value.val[0] = vreinterpretq_f32_s64( \ + vzip1q_s64(vreinterpretq_s64_f32(b0.val[1]), \ + vreinterpretq_s64_f32(b1.val[1]))); \ + CONCAT(ret, 2).value.val[1] = vreinterpretq_f32_s64( \ + vzip1q_s64(vreinterpretq_s64_f32(b2.val[1]), \ + vreinterpretq_s64_f32(b3.val[1]))); + +#define TRANSPOSE_8x4(a, ret) \ + auto b0 = vzipq_f32(CONCAT(a, 0).value, CONCAT(a, 1).value); \ + auto b1 = vzipq_f32(CONCAT(a, 2).value, CONCAT(a, 3).value); \ + auto b2 = vzipq_f32(CONCAT(a, 4).value, CONCAT(a, 5).value); \ + auto b3 = vzipq_f32(CONCAT(a, 6).value, CONCAT(a, 7).value); \ + CONCAT(ret, 0).value.val[0] = vreinterpretq_f32_s64( \ + vzip1q_s64(vreinterpretq_s64_f32(b0.val[0]), \ + vreinterpretq_s64_f32(b1.val[0]))); \ + CONCAT(ret, 0).value.val[1] = vreinterpretq_f32_s64( \ + vzip1q_s64(vreinterpretq_s64_f32(b2.val[0]), \ + vreinterpretq_s64_f32(b3.val[0]))); \ + CONCAT(ret, 1).value.val[0] = vreinterpretq_f32_s64( \ + vzip2q_s64(vreinterpretq_s64_f32(b0.val[0]), \ + vreinterpretq_s64_f32(b1.val[0]))); \ + CONCAT(ret, 1).value.val[1] = vreinterpretq_f32_s64( \ + vzip2q_s64(vreinterpretq_s64_f32(b2.val[0]), \ + vreinterpretq_s64_f32(b3.val[0]))); \ + CONCAT(ret, 2).value.val[0] = vreinterpretq_f32_s64( \ + vzip1q_s64(vreinterpretq_s64_f32(b0.val[1]), \ + vreinterpretq_s64_f32(b1.val[1]))); \ + CONCAT(ret, 2).value.val[1] = vreinterpretq_f32_s64( \ + vzip1q_s64(vreinterpretq_s64_f32(b2.val[1]), \ + vreinterpretq_s64_f32(b3.val[1]))); \ + CONCAT(ret, 3).value.val[0] = vreinterpretq_f32_s64( \ + vzip2q_s64(vreinterpretq_s64_f32(b0.val[1]), \ + vreinterpretq_s64_f32(b1.val[1]))); \ + CONCAT(ret, 3).value.val[1] = vreinterpretq_f32_s64( \ + vzip2q_s64(vreinterpretq_s64_f32(b2.val[1]), \ + vreinterpretq_s64_f32(b3.val[1]))); + +#elif MEGDNN_ARMV7 +#define TRANSPOSE_8x4(a, ret) \ + auto b0 = vzipq_f32(CONCAT(a, 0).value, CONCAT(a, 1).value); \ + auto b1 = vzipq_f32(CONCAT(a, 2).value, CONCAT(a, 3).value); \ + auto b2 = vzipq_f32(CONCAT(a, 4).value, CONCAT(a, 5).value); \ + auto b3 = vzipq_f32(CONCAT(a, 6).value, CONCAT(a, 7).value); \ + CONCAT(ret, 0).value.val[0] = \ + vcombine_f32(vget_low_f32(b0.val[0]), vget_low_f32(b1.val[0])); \ + CONCAT(ret, 1).value.val[0] = \ + vcombine_f32(vget_high_f32(b0.val[0]), vget_high_f32(b1.val[0])); \ + CONCAT(ret, 2).value.val[0] = \ + vcombine_f32(vget_low_f32(b0.val[1]), vget_low_f32(b1.val[1])); \ + CONCAT(ret, 3).value.val[0] = \ + vcombine_f32(vget_high_f32(b0.val[1]), vget_high_f32(b1.val[1])); \ + CONCAT(ret, 0).value.val[1] = \ + vcombine_f32(vget_low_f32(b2.val[0]), vget_low_f32(b3.val[0])); \ + CONCAT(ret, 1).value.val[1] = \ + vcombine_f32(vget_high_f32(b2.val[0]), vget_high_f32(b3.val[0])); \ + CONCAT(ret, 2).value.val[1] = \ + vcombine_f32(vget_low_f32(b2.val[1]), vget_low_f32(b3.val[1])); \ + CONCAT(ret, 3).value.val[1] = \ + vcombine_f32(vget_high_f32(b2.val[1]), vget_high_f32(b3.val[1])); + +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/strategy.h b/dnn/src/arm_common/conv_bias/fp32/strategy.h new file mode 100644 index 00000000..43b109e9 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/strategy.h @@ 
-0,0 +1,39 @@ +/** + * \file dnn/src/arm_common/conv_bias/fp32/strategy.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include "src/arm_common/conv_bias/postprocess_helper.h" +#include "src/fallback/conv_bias/winograd/winograd.h" + +namespace megdnn { +namespace arm_common { +namespace winograd { + +MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 2, 3, 4, 4, + winograd_2x3_4x4_f) + +MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 6, 3, 1, 1, + winograd_6x3_1x1_f) + +MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 6, 3, 4, 4, + winograd_6x3_4x4_f) + +MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 5, 4, 1, 1, + winograd_5x4_1x1_f) + +MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 4, 5, 1, 1, + winograd_4x5_1x1_f) +} // namespace winograd +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/strategy_2x3_4x4.cpp b/dnn/src/arm_common/conv_bias/fp32/strategy_2x3_4x4.cpp new file mode 100644 index 00000000..5d55628e --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/strategy_2x3_4x4.cpp @@ -0,0 +1,346 @@ +/** + * \file dnn/src/arm_common/conv_bias/fp32/strategy_2x3_4x4.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#include "src/arm_common/conv_bias/fp32/strategy.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/arm_common/utils.h" +#include "src/common/unroll_macro.h" +#include "src/common/utils.h" +#include "src/fallback/conv_bias/winograd/winograd.h" + +#include "src/naive/matrix_mul/matrix_mul_helper.h" +#include "src/arm_common/elemwise_helper/op_unary.h" +#include "src/arm_common/conv_bias/fp32/helper.h" + +#include "midout.h" +MIDOUT_DECL(megdnn_arm_common_winograd_fp32_F23) + +using namespace megdnn; +using namespace arm_common; +namespace { + +struct InputTransform2X3 { + template + static void prepare(const float* input, float* patch, float* patchT, + int ih_start, int iw_start, size_t IH, size_t IW, + size_t ic, size_t IC) { + constexpr size_t alpha = 2 + 3 - 1; + if (!(inner && ic + 4 < IC)) { + memset(patch, 0, sizeof(float) * 4 * alpha * alpha); + } + if (inner) { + const float* input_ptr = + input + ic * IH * IW + ih_start * IW + iw_start; + for (size_t ico = 0; ico < 4; ++ico) { + if (ic + ico < IC) { + auto v0 = vld1q_f32(input_ptr); + auto v1 = vld1q_f32(input_ptr + IW); + auto v2 = vld1q_f32(input_ptr + IW * 2); + auto v3 = vld1q_f32(input_ptr + IW * 3); + + vst1q_f32(patch + ico * 4 * alpha + 0 * 4, v0); + vst1q_f32(patch + ico * 4 * alpha + 1 * 4, v1); + vst1q_f32(patch + ico * 4 * alpha + 2 * 4, v2); + vst1q_f32(patch + ico * 4 * alpha + 3 * 4, v3); + input_ptr += IH * IW; + } + } + } else { + int ih0_act = std::max(ih_start, 0), + ih1_act = std::min(ih_start + alpha, IH), + iw0_act = std::max(iw_start, 0), + iw1_act = std::min(iw_start + alpha, IW); + // partial copy + for (size_t ico = 0; ico < 4; ++ico) { + if (ic + ico < IC) { + for (int ih = ih0_act; ih < ih1_act; ++ih) { + for (int iw = iw0_act; iw < iw1_act; ++iw) { + size_t iho = ih - ih_start, iwo = iw - iw_start; + patch[ico * alpha * 4 + iho * 4 + iwo] = + input[(ic + ico) * IH * IW + ih * IW + iw]; + } + } + } + } + } + + transpose_4x4(patch + 0 * 1, patchT + 0 * 4, 16, 4); + transpose_4x4(patch + 4 * 1, patchT + 4 * 4, 16, 4); + transpose_4x4(patch + 8 * 1, patchT + 8 * 4, 16, 4); + transpose_4x4(patch + 12 * 1, patchT + 12 * 4, 16, 4); + } + + static void transform(const float* patchT, float* input_transform_buf, + size_t unit_idx, size_t nr_units_in_tile, size_t ic, + size_t IC) { + constexpr size_t alpha = 2 + 3 - 1; + // BT * d * B +#define cb(m, n) \ + Vector d##m##n = \ + Vector::load(patchT + m * 4 * 4 + n * 4); + + UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); +#undef cb + + //! 1 0 -1 0 d00 d01 d02 d03 1 0 0 0 + //! 0 1 1 0 d10 d11 d12 d13 0 1 -1 -1 + //! 0 -1 1 0 d20 d21 d22 d23 -1 1 1 0 + //! 
0 -1 0 1 d30 d31 d32 d33 0 0 0 1 +#define cb(m) \ + auto t0##m = d0##m - d2##m; \ + auto t1##m = d1##m + d2##m; \ + auto t2##m = d2##m - d1##m; \ + auto t3##m = d3##m - d1##m; + + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + +#define cb(m) \ + d##m##0 = t##m##0 - t##m##2; \ + d##m##1 = t##m##1 + t##m##2; \ + d##m##2 = t##m##2 - t##m##1; \ + d##m##3 = t##m##3 - t##m##1; + + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + + size_t ICB = IC / 4; + size_t icb = ic / 4; +#define cb(m, n) \ + d##m##n.save(input_transform_buf + \ + (m * alpha + n) * ICB * nr_units_in_tile * 4 + \ + icb * nr_units_in_tile * 4 + unit_idx * 4); + UNROLL_CALL_NOWRAPPER_D2(4, 4, cb) +#undef cb + } +}; + +template +struct OutputTransform2X3 { + static void transform(const float* output_transform_buf, const float* bias, + float* output, float* transform_mid_buf, + size_t oh_start, size_t ow_start, size_t OH, + size_t OW, size_t oc_start, size_t oc_end, + size_t oc_index, size_t unit_idx, + size_t nr_units_in_tile, const DType& src_dtype, + const DType& dst_dtype) { + Op op(src_dtype, dst_dtype); + //! AT * m * A + constexpr size_t alpha = 2 + 3 - 1; + + size_t oc = oc_start + oc_index; + size_t OCB = (oc_end - oc_start) / 4; + size_t ocb = oc_index / 4; + +#define cb(m, n) \ + auto v##m##n = Vector::load( \ + output_transform_buf + \ + (m * alpha + n) * OCB * nr_units_in_tile * 4 + \ + ocb * nr_units_in_tile * 4 + unit_idx * 4); + UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); +#undef cb + //! 1 1 1 0 v00 v01 v02 v03 1 0 + //! 0 1 -1 1 v10 v11 v12 v13 1 1 + //! v20 v21 v22 v23 1 -1 + //! v30 v31 v32 v33 0 1 +#define cb(m) \ + auto t0##m = v0##m + v1##m + v2##m; \ + auto t1##m = v1##m - v2##m + v3##m; + + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + v00 = t00 + t01 + t02; + v10 = t10 + t11 + t12; + v01 = t01 - t02 + t03; + v11 = t11 - t12 + t13; + + Vector vbias; + if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + vbias = Vector::load(bias + oc); + + v00 += vbias; + v10 += vbias; + v01 += vbias; + v11 += vbias; + } + if (bmode != BiasMode::BIAS) { + v00 = op(v00.value); + v01 = op(v01.value); + v10 = op(v10.value); + v11 = op(v11.value); + } + + v00.save(transform_mid_buf + (0 * 2 + 0) * 4); + v10.save(transform_mid_buf + (1 * 2 + 0) * 4); + v01.save(transform_mid_buf + (0 * 2 + 1) * 4); + v11.save(transform_mid_buf + (1 * 2 + 1) * 4); + + for (size_t oco = 0; oco < 4 && oc + oco < oc_end; ++oco) { + for (size_t oho = 0; oho < 2 && oh_start + oho < OH; ++oho) { + for (size_t owo = 0; owo < 2 && ow_start + owo < OW; ++owo) { + size_t oh = oh_start + oho; + size_t ow = ow_start + owo; + float res = transform_mid_buf[oho * 2 * 4 + owo * 4 + oco]; + if (bmode == BiasMode::BIAS) { + res += bias[(oc + oco) * OH * OW + oh * OW + ow]; + res = op(res); + } + output[(oc + oco) * OH * OW + oh * OW + ow] = res; + } + } + } + } +}; +} // namespace + +namespace megdnn { +namespace arm_common { +namespace winograd { + +MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_2x3_4x4_f) +void winograd_2x3_4x4_f::filter(const float* filter, + float* filter_transform_buf, + float* transform_mid_buf, size_t OC, size_t IC, + size_t oc_start, size_t oc_end) { + constexpr int alpha = 2 + 3 - 1; + //! 
G * g * GT + float32x4_t g0{1.f, 0, 0, 0}, g1{0.5, 0.5, 0.5, 0}, g2{0.5, -0.5, 0.5, 0}, + g3{0, 0, 1, 0}; + float32x4_t gt0{1, 0.5, 0.5, 0}, gt1{0, 0.5, -0.5, 0}, gt2{0, 0.5, 0.5, 1}, + gt3{0, 0, 0, 0}; + size_t OCB = OC / 4; + size_t ICB = IC / 4; + + for (size_t oc = oc_start; oc < oc_end; oc++) + rep(ic, IC) { + const float* filter_ptr = filter + (oc * IC + ic) * 3 * 3; + /** + * origin: (4x3) * (3 x 3) * (3 x 4) + * pack to G and g to times of 4 + * now: (4x4) * (4 x 4) * (4 x 4) + */ + //! 1 0 0 0 v00 v01 v02 0 1 0.5 0.5 0 + //! 0.5 0.5 0.5 0 v10 v11 v12 0 0 0.5 -0.5 0 + //! 0.5 -0.5 0.5 0 v20 v21 v22 0 0 0.5 0.5 1 + //! 0 0 1 0 0 0 0 0 0 0 0 0 + float32x4_t vf0 = vld1q_f32(filter_ptr); + float32x4_t vf1 = vld1q_f32(filter_ptr + 4); + float32x4_t vf2 = vdupq_n_f32(filter_ptr[8]); + + float32x4_t v3(vdupq_n_f32(0)); + auto vtmp = vextq_f32(vf1, vf2, 2); + vtmp = vsetq_lane_f32(0, vtmp, 3); + float32x4_t v2(vtmp); + vtmp = vextq_f32(vf0, vf1, 3); + vtmp = vsetq_lane_f32(0, vtmp, 3); + float32x4_t v1(vtmp); + vtmp = vsetq_lane_f32(0, vf0, 3); + float32x4_t v0(vtmp); + + float32x4_t vsum0 = vdupq_n_f32(0), vsum1 = vdupq_n_f32(0), + vsum2 = vdupq_n_f32(0), vsum3 = vdupq_n_f32(0); + + MATRIX_MUL4x4(vsum, g, v); + + float32x4_t vres0 = vdupq_n_f32(0), vres1 = vdupq_n_f32(0), + vres2 = vdupq_n_f32(0), vres3 = vdupq_n_f32(0); + MATRIX_MUL4x4(vres, vsum, gt); + + vst1q_f32(transform_mid_buf, vres0); + vst1q_f32(transform_mid_buf + 4, vres1); + vst1q_f32(transform_mid_buf + 8, vres2); + vst1q_f32(transform_mid_buf + 12, vres3); + + size_t ocb = oc / 4; + size_t oc4 = oc % 4; + size_t icb = ic / 4; + size_t ic4 = ic % 4; + rep(i, alpha) rep(j, alpha) { + filter_transform_buf[(i * alpha + j) * OCB * ICB * 4 * 4 + + ocb * ICB * 4 * 4 + icb * 4 * 4 + ic4 * 4 + + oc4] = transform_mid_buf[i * alpha + j]; + } + } +} + +void winograd_2x3_4x4_f::input(const float* input, float* input_transform_buf, + float* transform_mid_buf, size_t IH, size_t IW, + size_t IC, size_t PH, size_t PW, + size_t unit_start_idx, size_t nr_units_in_tile) { + megdnn_assert(IC % 4 == 0); + constexpr int alpha = 3 + 2 - 1; + + // OW = IW + 2 * PW - KERNEL_SIZE + 1 + auto units_w = div_ceil(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); + float* patch = transform_mid_buf; + float* patchT = transform_mid_buf + 4 * alpha * alpha; + + for (size_t ic = 0; ic < IC; ic += 4) { + rep(unit_idx, nr_units_in_tile) { + size_t index = unit_start_idx + unit_idx; + size_t nh = index / units_w; + size_t nw = index % units_w; + int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; + int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; + if (ih_start >= 0 && ih_start + alpha <= static_cast(IH) && + iw_start >= 0 && iw_start + alpha <= static_cast(IW)) { + InputTransform2X3::prepare(input, patch, patchT, ih_start, + iw_start, IH, IW, ic, IC); + InputTransform2X3::transform(patchT, input_transform_buf, + unit_idx, nr_units_in_tile, ic, + IC); + + } else { + InputTransform2X3::prepare(input, patch, patchT, + ih_start, iw_start, IH, IW, + ic, IC); + InputTransform2X3::transform(patchT, input_transform_buf, + unit_idx, nr_units_in_tile, ic, + IC); + } + } + } +} + +void winograd_2x3_4x4_f::output(const float* output_transform_buf, + const float* bias, float* output, + float* transform_mid_buf, BiasMode bmode, + NonlineMode nonline_mode, size_t OH, size_t OW, + size_t oc_start, size_t oc_end, size_t unit_start_idx, + size_t nr_units_in_tile) { +#define cb(_bmode, _nonline_op, ...) 
\ + OutputTransform2X3<_bmode MEGDNN_COMMA _nonline_op>::transform(__VA_ARGS__); + + auto units_w = div_ceil(OW, OUTPUT_BLOCK_SIZE); + + for (size_t oc = oc_start; oc < oc_end; oc += 4) { + size_t oc_index = oc - oc_start; + rep(unit_idx, nr_units_in_tile) { + size_t index = unit_start_idx + unit_idx; + auto nh = index / units_w; + auto nw = index % units_w; + size_t oh_start = nh * OUTPUT_BLOCK_SIZE; + size_t ow_start = nw * OUTPUT_BLOCK_SIZE; + DISPATCH_CONV_WINOGRAD_BIAS( + megdnn_arm_common_winograd_fp32_F23, cb, float, float, bmode, + nonline_mode, output_transform_buf, bias, output, transform_mid_buf, + oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx, + nr_units_in_tile, src_dtype, dst_dtype); + } + } +#undef cb +} + +} // namespace winograd +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/strategy_4x5.cpp b/dnn/src/arm_common/conv_bias/fp32/strategy_4x5.cpp new file mode 100644 index 00000000..bf634c3c --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/strategy_4x5.cpp @@ -0,0 +1,483 @@ +/** + * \file dnn/src/arm_common/conv_bias/fp32/strategy_4x5.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "src/arm_common/conv_bias/fp32/strategy.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/arm_common/utils.h" +#include "src/common/unroll_macro.h" +#include "src/common/utils.h" +#include "src/fallback/conv_bias/winograd/winograd.h" + +#include "src/arm_common/conv_bias/fp32/helper.h" +#include "src/arm_common/elemwise_helper/op_unary.h" +#include "src/naive/matrix_mul/matrix_mul_helper.h" + +#include "midout.h" +MIDOUT_DECL(megdnn_arm_common_winograd_fp32_F45) + +using namespace megdnn; +using namespace arm_common; +namespace { + +struct FilterTransform4X5 { +#define FILTER_TRANSFORM(d, wd) \ + do { \ + wd##0 = d##0; \ + wd##r0 = d##r0; \ + wd##1 = (d##0 + d##1 + d##2 + d##3 + d##4) * -0.2222222; \ + wd##r1 = (d##r0 + d##r1 + d##r2 + d##r3 + d##r4) * -0.2222222; \ + wd##2 = (d##0 - d##1 + d##2 - d##3 + d##4) * -0.2222222; \ + wd##r2 = (d##r0 - d##r1 + d##r2 - d##r3 + d##r4) * -0.2222222; \ + auto tmpd0 = d##0 * 0.7111111; \ + auto tmpd1 = d##1 * 0.3555556; \ + auto tmpd2 = d##2 * 0.1777778; \ + auto tmpd3 = d##3 * 0.0888889; \ + auto tmpd4 = d##4 * 0.0444444; \ + auto tmpdr0 = d##r0 * 0.7111111; \ + auto tmpdr1 = d##r1 * 0.3555556; \ + auto tmpdr2 = d##r2 * 0.1777778; \ + auto tmpdr3 = d##r3 * 0.0888889; \ + auto tmpdr4 = d##r4 * 0.0444444; \ + wd##3 = tmpd0 + tmpd1 + tmpd2 + tmpd3 + tmpd4; \ + wd##r3 = tmpdr0 + tmpdr1 + tmpdr2 + tmpdr3 + tmpdr4; \ + wd##4 = tmpd0 - tmpd1 + tmpd2 - tmpd3 + tmpd4; \ + wd##r4 = tmpdr0 - tmpdr1 + tmpdr2 - tmpdr3 + tmpdr4; \ + tmpd0 = d##0 * 0.0111111; \ + tmpd1 = d##1 * 0.0222222; \ + tmpd2 = d##2 * 0.0444444; \ + tmpd3 = d##3 * 0.0888889; \ + tmpd4 = d##4 * 0.1777778; \ + tmpdr0 = d##r0 * 0.0111111; \ + tmpdr1 = d##r1 * 0.0222222; \ + tmpdr2 = d##r2 * 0.0444444; \ + tmpdr3 = d##r3 * 0.0888889; \ + tmpdr4 = d##r4 * 0.1777778; \ + wd##5 = tmpd0 + tmpd1 + tmpd2 + tmpd3 + tmpd4; \ + wd##r5 = tmpdr0 + tmpdr1 + tmpdr2 + tmpdr3 + tmpdr4; \ + wd##6 = tmpd0 - tmpd1 + tmpd2 - tmpd3 + tmpd4; \ + wd##r6 = tmpdr0 - tmpdr1 + 
tmpdr2 - tmpdr3 + tmpdr4; \ + wd##7 = d##4; \ + wd##r7 = d##r4; \ + } while (0); + +#define FILTER_TRANSFORM_FINAL(d, wd) \ + do { \ + wd##0 = d##0; \ + wd##1 = (d##0 + d##1 + d##2 + d##3 + d##4) * -0.2222222; \ + wd##2 = (d##0 - d##1 + d##2 - d##3 + d##4) * -0.2222222; \ + auto tmp0 = d##0 * 0.7111111 + d##2 * 0.1777778 + d##4 * 0.0444444; \ + auto tmp1 = d##1 * 0.3555556 + d##3 * 0.0888889; \ + wd##3 = tmp0 + tmp1; \ + wd##4 = tmp0 - tmp1; \ + tmp0 = d##0 * 0.0111111 + d##2 * 0.0444444 + d##4 * 0.1777778; \ + tmp1 = d##1 * 0.0222222 + d##3 * 0.0888889; \ + wd##5 = tmp0 + tmp1; \ + wd##6 = tmp0 - tmp1; \ + wd##7 = d##4; \ + } while (0); + static void transform(const float* filter, float* filter_transform_buf, + float* transform_mid_buf, size_t OC, size_t IC, + size_t oc_start, size_t oc_end) { + // Gg * GT + // G + //[[ 1. 0. 0. 0. 0. ] + // [-0.2222222 -0.2222222 -0.2222222 -0.2222222 -0.2222222] + // [-0.2222222 0.2222222 -0.2222222 0.2222222 -0.2222222] + // [ 0.7111111 0.3555556 0.1777778 0.0888889 0.0444444] + // [ 0.7111111 -0.3555556 0.1777778 -0.0888889 0.0444444] + // [ 0.0111111 0.0222222 0.0444444 0.0888889 0.1777778] + // [ 0.0111111 -0.0222222 0.0444444 -0.0888889 0.1777778] + // [ 0. 0. 0. 0. 1. ]] + constexpr size_t alpha = 4 + 5 - 1; + for (size_t oc = oc_start; oc < oc_end; oc++) + rep(ic, IC) { + const float* fptr = filter + (oc * IC + ic) * 5 * 5; + +#define cb(i) Vector g##i = Vector::load(fptr + 5 * i); + UNROLL_CALL_NOWRAPPER(5, cb); +#undef cb + +#define cb(i) float gr##i = *(fptr + 5 * i + 4); + UNROLL_CALL_NOWRAPPER(5, cb); + +#undef cb +#define cb(i) Vector Gg##i; + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + +#define cb(i) float Ggr##i; + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + +#define cb(i) Vector Ggt##i; + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + +#define cb(i) Vector result##i; + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + + FILTER_TRANSFORM(g, Gg) + float32x4x2_t vgr; + float32x4_t vgr0 = {Ggr0, Ggr1, Ggr2, Ggr3}; + float32x4_t vgr1 = {Ggr4, Ggr5, Ggr6, Ggr7}; + vgr.val[0] = vgr0; //{Ggr0, Ggr1, Ggr2, Ggr3}; + vgr.val[1] = vgr1; //{Ggr4, Ggr5, Ggr6, Ggr7}; + Vector Ggt4(vgr); + TRANSPOSE_8x4(Gg, Ggt); + FILTER_TRANSFORM_FINAL(Ggt, result); + +#define cb(i) result##i.save(transform_mid_buf + i * alpha); + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + rep(i, alpha) rep(j, alpha) { + filter_transform_buf[(i * alpha + j) * OC * IC + ic * OC + + oc] = transform_mid_buf[j * alpha + i]; + } + } + } +}; +#undef FILTER_TRANSFORM +#undef FILTER_TRANSFORM_FINAL + +struct InputTransform4X5 { +#define INPUT_TRANSFORM(d, wd) \ + do { \ + wd##0 = (d##0 - d##6) + (d##4 - d##2) * 5.25f; \ + auto tmp0 = d##2 - d##4 * 4.25f + d##6; \ + auto tmp1 = d##1 - d##3 * 4.25f + d##5; \ + wd##1 = tmp0 + tmp1; \ + wd##2 = tmp0 - tmp1; \ + tmp0 = d##2 * 4.0f - d##4 * 5.0f + d##6; \ + tmp1 = d##1 * 2.0f - d##3 * 2.5f + d##5 * 0.5f; \ + wd##3 = tmp0 + tmp1; \ + wd##4 = tmp0 - tmp1; \ + tmp0 = d##2 * 0.25f - d##4 * 1.25f + d##6; \ + tmp1 = d##1 * 0.5f - d##3 * 2.5f + d##5 * 2.0f; \ + wd##5 = tmp0 + tmp1; \ + wd##6 = tmp0 - tmp1; \ + wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f; \ + } while (0) + +#define GET_VECTOR_HIGH_ELEM(s, i, idx) \ + vgetq_lane_f32(CONCAT(s, i).value.val[1], idx) +#define GET_VECTOR_LOW_ELEM(s, i, idx) \ + vgetq_lane_f32(CONCAT(s, i).value.val[0], idx) + + template + static void transform(const float* input, float* input_transform_buf, + float* transform_mid_buf, int ih_start, int iw_start, + size_t ic, size_t IH, size_t IW, size_t IC, + size_t unit_idx, size_t 
nr_units_in_tile) { + // BTd * B + //([[ 1. , 0. , -5.25, 0. , 5.25, 0. , -1. , 0. ], + // [ 0. , 1. , 1. , -4.25, -4.25, 1. , 1. , 0. ], + // [ 0. , -1. , 1. , 4.25, -4.25, -1. , 1. , 0. ], + // [ 0. , 2. , 4. , -2.5 , -5. , 0.5 , 1. , 0. ], + // [ 0. , -2. , 4. , 2.5 , -5. , -0.5 , 1. , 0. ], + // [ 0. , 0.5 , 0.25, -2.5 , -1.25, 2. , 1. , 0. ], + // [ 0. , -0.5 , 0.25, 2.5 , -1.25, -2. , 1. , 0. ], + // [ 0. , -1. , 0. , 5.25, 0. , -5.25, 0. , 1. ]])) + + constexpr size_t alpha = 4 + 5 - 1; + if (!inner) { + memset(transform_mid_buf, 0, sizeof(float) * alpha * alpha); + } + +#define cb(i) Vector d##i; + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + + if (inner) { + const float* input_ptr = + input + ic * IH * IW + ih_start * IW + iw_start; +#define cb(i) d##i = Vector::load(input_ptr + IW * i); + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + } else { + int ih0_act = std::max(ih_start, 0), + ih1_act = std::min(ih_start + alpha, IH), + iw0_act = std::max(iw_start, 0), + iw1_act = std::min(iw_start + alpha, IW); + for (int ih = ih0_act; ih < ih1_act; ++ih) { + for (int iw = iw0_act; iw < iw1_act; ++iw) { + size_t iho = ih - ih_start, iwo = iw - iw_start; + transform_mid_buf[iho * alpha + iwo] = + input[ic * IH * IW + ih * IW + iw]; + } + } +#define cb(i) d##i = Vector::load(transform_mid_buf + alpha * i); + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + } + +#define cb(i) Vector wd##i, ret##i; + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + + INPUT_TRANSFORM(d, wd); +#if MEGDNN_AARCH64 + TRANSPOSE_8x8(wd, d); + INPUT_TRANSFORM(d, ret); + +#define cb(i) ret##i.save(transform_mid_buf + i * alpha); + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + rep(i, alpha) rep(j, alpha) { + input_transform_buf[(i * alpha + j) * nr_units_in_tile * IC + + unit_idx * IC + ic] = + transform_mid_buf[j * alpha + i]; + } +#else +#define cb(i) \ + do { \ + mid_buf1[0] = GET_VECTOR_LOW_ELEM(wd, i, 0) - \ + GET_VECTOR_HIGH_ELEM(wd, i, 2) + \ + 5.25 * (GET_VECTOR_HIGH_ELEM(wd, i, 0) - \ + GET_VECTOR_LOW_ELEM(wd, i, 2)); \ + mid_buf1[7] = GET_VECTOR_HIGH_ELEM(wd, i, 3) - \ + GET_VECTOR_LOW_ELEM(wd, i, 1) + \ + 5.25 * (GET_VECTOR_LOW_ELEM(wd, i, 3) - \ + GET_VECTOR_HIGH_ELEM(wd, i, 1)); \ + auto tmp0 = 4 * GET_VECTOR_LOW_ELEM(wd, i, 2) + \ + -5 * GET_VECTOR_HIGH_ELEM(wd, i, 0) + \ + GET_VECTOR_HIGH_ELEM(wd, i, 2); \ + auto tmp1 = 2 * GET_VECTOR_LOW_ELEM(wd, i, 1) + \ + -2.5 * GET_VECTOR_LOW_ELEM(wd, i, 3) + \ + 0.5 * GET_VECTOR_HIGH_ELEM(wd, i, 1); \ + mid_buf1[3] = tmp0 + tmp1; \ + mid_buf1[4] = tmp0 - tmp1; \ + tmp0 = GET_VECTOR_LOW_ELEM(wd, i, 2) + \ + -4.25 * GET_VECTOR_HIGH_ELEM(wd, i, 0) + \ + GET_VECTOR_HIGH_ELEM(wd, i, 2); \ + tmp1 = GET_VECTOR_LOW_ELEM(wd, i, 1) + \ + GET_VECTOR_LOW_ELEM(wd, i, 3) * -4.25 + \ + GET_VECTOR_HIGH_ELEM(wd, i, 1); \ + mid_buf1[1] = tmp0 + tmp1; \ + mid_buf1[2] = tmp0 - tmp1; \ + tmp0 = GET_VECTOR_LOW_ELEM(wd, i, 2) * 0.25 + \ + GET_VECTOR_HIGH_ELEM(wd, i, 0) * -1.25 + \ + GET_VECTOR_HIGH_ELEM(wd, i, 2); \ + tmp1 = GET_VECTOR_LOW_ELEM(wd, i, 1) * 0.5 + \ + GET_VECTOR_LOW_ELEM(wd, i, 3) * -2.5 + \ + GET_VECTOR_HIGH_ELEM(wd, i, 1) * 2; \ + mid_buf1[5] = tmp0 + tmp1; \ + mid_buf1[6] = tmp0 - tmp1; \ + mid_buf1 += 8; \ + } while (0); + + float* mid_buf1 = transform_mid_buf; + UNROLL_CALL_NOWRAPPER(8, cb); + mid_buf1 = transform_mid_buf; + +#undef cb + rep(i, alpha) rep(j, alpha) { + input_transform_buf[(i * alpha + j) * nr_units_in_tile * IC + + unit_idx * IC + ic] = + transform_mid_buf[i * alpha + j]; + } +#endif + } +}; +#undef INPUT_TRANSFORM + +#define OUTPUT_TRANSFORM(m, s) \ + do { \ + s0 = m0 + 
m1 + m2 + m3 + m4 + m5 + m6; \ + s1 = m1 - m2 + m3 * 0.5 - m4 * 0.5 + m5 * 2.0 - m6 * 2.0; \ + s2 = m1 + m2 + m3 * 0.25 + m4 * 0.25 + m5 * 4.0 + m6 * 4.0; \ + s3 = m1 - m2 + m3 * 0.125 - m4 * 0.125 + m5 * 8.0 - m6 * 8.0 + m7; \ + } while (0) +template +struct OutputTransform4X5 { + static void transform(const float* output_transform_buf, const float* bias, + float* output, float* transform_mid_buf, + size_t oh_start, size_t ow_start, size_t OH, + size_t OW, size_t oc_start, size_t oc_end, + size_t oc_index, size_t unit_idx, + size_t nr_units_in_tile, const DType& src_dtype, + const DType& dst_dtype) { + Op op(src_dtype, dst_dtype); + //! AT * m * A + // AT f45 + // 1.0 1.0 1.0 1.000 1.000 1.0 1.0 0.0 + // 0.0 1.0 -1.0 0.500 -0.500 2.0 -2.0 0.0 + // 0.0 1.0 1.0 0.250 0.250 4.0 4.0 0.0 + // 0.0 1.0 -1.0 0.125 -0.125 8.0 -8.0 1.0 + constexpr size_t alpha = 5 + 4 - 1; + float* mid_buf1 = transform_mid_buf; + + size_t OC = oc_end - oc_start; + size_t oc = oc_start + oc_index; + +#define cb(m, n) \ + transform_mid_buf[m * alpha + n] = \ + output_transform_buf[(m * alpha + n) * nr_units_in_tile * OC + \ + unit_idx * OC + oc_index]; + UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); +#undef cb + +#define cb(i) auto m##i = Vector::load(transform_mid_buf + alpha * i); + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb +#define cb(i) Vector s##i; + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + + OUTPUT_TRANSFORM(m, s); +#define cb(i) \ + do { \ + auto add12 = \ + GET_VECTOR_LOW_ELEM(s, i, 1) + GET_VECTOR_LOW_ELEM(s, i, 2); \ + auto add34 = \ + GET_VECTOR_LOW_ELEM(s, i, 3) + GET_VECTOR_HIGH_ELEM(s, i, 0); \ + auto add56 = \ + GET_VECTOR_HIGH_ELEM(s, i, 1) + GET_VECTOR_HIGH_ELEM(s, i, 2); \ + auto sub12 = \ + GET_VECTOR_LOW_ELEM(s, i, 1) - GET_VECTOR_LOW_ELEM(s, i, 2); \ + auto sub34 = \ + GET_VECTOR_LOW_ELEM(s, i, 3) - GET_VECTOR_HIGH_ELEM(s, i, 0); \ + auto sub56 = \ + GET_VECTOR_HIGH_ELEM(s, i, 1) - GET_VECTOR_HIGH_ELEM(s, i, 2); \ + mid_buf1[0] = GET_VECTOR_LOW_ELEM(s, i, 0) + add12 + add34 + add56; \ + mid_buf1[1] = sub12 + sub34 * 0.5 + sub56 * 2.0; \ + mid_buf1[2] = add12 + add34 * 0.25 + add56 * 4.0; \ + mid_buf1[3] = sub12 + sub34 * 0.125 + sub56 * 8.0 + \ + GET_VECTOR_HIGH_ELEM(s, i, 3); \ + mid_buf1 += 4; \ + } while (0); + + mid_buf1 = transform_mid_buf; + UNROLL_CALL_NOWRAPPER(4, cb); + mid_buf1 = transform_mid_buf; +#undef cb + + if (oh_start + 4 <= OH && ow_start + 4 <= OW) { + float32x4_t bias0; + if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + bias0 = vdupq_n_f32(bias[oc]); + } + rep(i, 4) { + size_t oh = oh_start + i; + float32x4_t item0 = vld1q_f32(mid_buf1); + + if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + item0 = vaddq_f32(item0, bias0); + } else if (bmode == BiasMode::BIAS) { + bias0 = vld1q_f32(bias + oc * OH * OW + oh * OW + ow_start); + item0 = vaddq_f32(item0, bias0); + } + item0 = op(item0); + vst1q_f32(output + oc * OH * OW + oh * OW + ow_start, item0); + mid_buf1 += 4; + } + } else { + for (size_t oho = 0; oho < 4 && oh_start + oho < OH; ++oho) { + for (size_t owo = 0; owo < 4 && ow_start + owo < OW; ++owo) { + size_t oh = oh_start + oho; + size_t ow = ow_start + owo; + float res = mid_buf1[oho * 4 + owo]; + if (bmode == BiasMode::BIAS) { + res += bias[oc * OH * OW + oh * OW + ow]; + } else if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + res += bias[oc]; + } + res = op(res); + output[oc * OH * OW + oh * OW + ow] = res; + } + } + } + } +}; +#undef OUTPUT_TRANSFORM +#undef GET_VECTOR_HIGH_ELEM +#undef GET_VECTOR_LOW_ELEM + +} // namespace + +namespace megdnn { +namespace arm_common { 
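// winograd_4x5_1x1_f below implements Winograd F(4x4, 5x5): alpha = 4 + 5 - 1 = 8,
// so each 4x4 output tile comes from an 8x8 transformed tile using 8 * 8 = 64
// multiplies per channel in the elementwise stage, versus 4 * 4 * 5 * 5 = 400
// for direct convolution -- roughly a 6x reduction in multiplies, before
// counting the input/filter/output transform overhead.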
+namespace winograd { + +MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_4x5_1x1_f) + +void winograd_4x5_1x1_f::filter(const float* filter, + float* filter_transform_buf, + float* transform_mid_buf, size_t OC, size_t IC, + size_t oc_start, size_t oc_end) { + FilterTransform4X5::transform(filter, filter_transform_buf, + transform_mid_buf, OC, IC, oc_start, oc_end); +} + +void winograd_4x5_1x1_f::input(const float* input, float* input_transform_buf, + float* transform_mid_buf, size_t IH, size_t IW, + size_t IC, size_t PH, size_t PW, + size_t unit_start_idx, size_t nr_units_in_tile) { + constexpr int alpha = 4 + 5 - 1; + + // OW = IW + 2 * PW - KERNEL_SIZE + 1 + auto units_w = div_ceil(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); + rep(ic, IC) { + rep(unit_idx, nr_units_in_tile) { + size_t index = unit_start_idx + unit_idx; + size_t nh = index / units_w; + size_t nw = index % units_w; + int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; + int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; + if (ih_start >= 0 && ih_start + alpha <= static_cast(IH) && + iw_start >= 0 && iw_start + alpha <= static_cast(IW)) { + InputTransform4X5::transform( + input, input_transform_buf, transform_mid_buf, ih_start, + iw_start, ic, IH, IW, IC, unit_idx, nr_units_in_tile); + + } else { + InputTransform4X5::transform( + input, input_transform_buf, transform_mid_buf, ih_start, + iw_start, ic, IH, IW, IC, unit_idx, nr_units_in_tile); + } + } + } +} + +void winograd_4x5_1x1_f::output(const float* output_transform_buf, + const float* bias, float* output, + float* transform_mid_buf, BiasMode bmode, + NonlineMode nonline_mode, size_t OH, size_t OW, + size_t oc_start, size_t oc_end, size_t unit_start_idx, + size_t nr_units_in_tile) { +#define cb(_bmode, _nonline_op, ...) \ + OutputTransform4X5<_bmode MEGDNN_COMMA _nonline_op>::transform(__VA_ARGS__); + + auto units_w = div_ceil(OW, OUTPUT_BLOCK_SIZE); + + for (size_t oc = oc_start; oc < oc_end; oc++) { + size_t oc_index = oc - oc_start; + rep(unit_idx, nr_units_in_tile) { + size_t index = unit_start_idx + unit_idx; + auto nh = index / units_w; + auto nw = index % units_w; + size_t oh_start = nh * OUTPUT_BLOCK_SIZE; + size_t ow_start = nw * OUTPUT_BLOCK_SIZE; + DISPATCH_CONV_WINOGRAD_BIAS( + megdnn_arm_common_winograd_fp32_F45, cb, float, float, bmode, + nonline_mode, output_transform_buf, bias, output, transform_mid_buf, + oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx, + nr_units_in_tile, src_dtype, dst_dtype); + } + } +#undef cb +} + +} // namespace winograd +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/strategy_5x4.cpp b/dnn/src/arm_common/conv_bias/fp32/strategy_5x4.cpp new file mode 100644 index 00000000..54d35771 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/strategy_5x4.cpp @@ -0,0 +1,500 @@ +/** + * \file dnn/src/arm_common/conv_bias/fp32/strategy_5x4.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#include "src/arm_common/conv_bias/fp32/strategy.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/arm_common/utils.h" +#include "src/common/unroll_macro.h" +#include "src/common/utils.h" +#include "src/fallback/conv_bias/winograd/winograd.h" + +#include "src/arm_common/conv_bias/fp32/helper.h" +#include "src/arm_common/elemwise_helper/op_unary.h" +#include "src/naive/matrix_mul/matrix_mul_helper.h" + +#include "midout.h" +MIDOUT_DECL(megdnn_arm_common_winograd_fp32_F54) + +using namespace megdnn; +using namespace arm_common; +namespace { + +struct FilterTransform5X4 { +#define FILTER_TRANSFORM(d, wd) \ + do { \ + wd##0 = d##0; \ + auto tmp0 = d##0 * 0.7111111f + d##2 * 0.1777778f; \ + auto tmp1 = d##1 * 0.3555556f + d##3 * 0.0888889f; \ + wd##1 = tmp0 + tmp1; \ + wd##2 = tmp0 - tmp1; \ + tmp0 = (d##0 + d##2) * -0.2222222f; \ + tmp1 = (d##1 + d##3) * -0.2222222f; \ + wd##3 = tmp0 + tmp1; \ + wd##4 = tmp0 - tmp1; \ + tmp0 = d##0 * 0.0111111f + d##2 * 0.0444444f; \ + tmp1 = d##1 * 0.0222222f + d##3 * 0.0888889f; \ + wd##5 = tmp0 + tmp1; \ + wd##6 = tmp0 - tmp1; \ + wd##7 = d##3; \ + } while (0) + + static void transform(const float* filter, float* filter_transform_buf, + float* transform_mid_buf, size_t OC, size_t IC, + size_t oc_start, size_t oc_end) { + // Gg * GT + // G + // 1 0 0 0 + // 0.7111111 0.3555556 0.1777778 0.0888889 + // 0.7111111 -0.3555556 0.1777778 -0.0888889 + // -0.2222222 -0.2222222 -0.2222222 -0.2222222 + // -0.2222222 0.2222222 -0.2222222 0.2222222 + // 0.0111111 0.0222222 0.0444444 0.0888889 + // 0.0111111 -0.0222222 0.0444444 -0.0888889 + // 0 0 0 1 + + constexpr size_t alpha = 4 + 5 - 1; + for (size_t oc = oc_start; oc < oc_end; oc++) { + rep(ic, IC) { + const float* fptr = filter + (oc * IC + ic) * 4 * 4; + +#define cb(i) Vector g##i = Vector::load(fptr + 4 * i); + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + +#define cb(i) Vector wd##i; + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + +#define cb(i) Vector wdt##i; + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + +#define cb(i) Vector ret##i; + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + + FILTER_TRANSFORM(g, wd); +#if MEGDNN_AARCH64 + TRANSPOSE_8x4(wd, wdt); + FILTER_TRANSFORM(wdt, ret); + +#define cb(i) ret##i.save(transform_mid_buf + i * alpha); + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + rep(i, alpha) rep(j, alpha) { + filter_transform_buf[(i * alpha + j) * OC * IC + ic * OC + + oc] = transform_mid_buf[j * alpha + i]; + } +#else + +#define cb(i) \ + do { \ + mid_buf1[0] = GET_VECTOR_ELEM(wd, i, 0); \ + auto tmp0 = GET_VECTOR_ELEM(wd, i, 0) * 0.7111111f + \ + GET_VECTOR_ELEM(wd, i, 2) * 0.1777778f; \ + auto tmp1 = GET_VECTOR_ELEM(wd, i, 1) * 0.3555556f + \ + GET_VECTOR_ELEM(wd, i, 3) * 0.0888889f; \ + mid_buf1[1] = tmp0 + tmp1; \ + mid_buf1[2] = tmp0 - tmp1; \ + tmp0 = (GET_VECTOR_ELEM(wd, i, 0) + GET_VECTOR_ELEM(wd, i, 2)) * \ + -0.2222222f; \ + tmp1 = (GET_VECTOR_ELEM(wd, i, 1) + GET_VECTOR_ELEM(wd, i, 3)) * \ + -0.2222222f; \ + mid_buf1[3] = tmp0 + tmp1; \ + mid_buf1[4] = tmp0 - tmp1; \ + tmp0 = GET_VECTOR_ELEM(wd, i, 0) * 0.0111111f + \ + GET_VECTOR_ELEM(wd, i, 2) * 0.0444444f; \ + tmp1 = GET_VECTOR_ELEM(wd, i, 1) * 0.0222222f + \ + GET_VECTOR_ELEM(wd, i, 3) * 0.0888889f; \ + mid_buf1[5] = tmp0 + tmp1; \ + mid_buf1[6] = tmp0 - tmp1; \ + mid_buf1[7] = GET_VECTOR_ELEM(wd, i, 3); \ + mid_buf1 += 8; \ + } while (0); +#define GET_VECTOR_ELEM(s, i, idx) vgetq_lane_f32(CONCAT(s, i).value, idx) + + float* mid_buf1 = transform_mid_buf; + UNROLL_CALL_NOWRAPPER(8, cb); + mid_buf1 = transform_mid_buf; 
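// As in the other filter transforms in this directory, the second (transposed)
// pass of the Gg * GT product is split by architecture: on AArch64 the
// TRANSPOSE_8x4 macro (built on the A64-only vzip1q_s64/vzip2q_s64 forms)
// keeps the whole product vectorized, while this ARMv7 branch applies G to the
// transposed rows element-wise through the cb macro above; the two paths only
// differ in whether transform_mid_buf ends up transposed, hence the
// j * alpha + i versus i * alpha + j indexing when scattering into
// filter_transform_buf.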
+#undef cb + rep(i, alpha) rep(j, alpha) { + filter_transform_buf[(i * alpha + j) * OC * IC + ic * OC + + oc] = transform_mid_buf[i * alpha + j]; + } +#endif + } + } + } +}; +#undef FILTER_TRANSFORM +#undef GET_VECTOR_ELEM + +struct InputTransform5X4 { +#define INPUT_TRANSFORM(d, wd) \ + do { \ + wd##0 = (d##0 - d##6) + (d##4 - d##2) * 5.25f; \ + auto tmp0 = d##2 * 4.0f - d##4 * 5.0f + d##6; \ + auto tmp1 = d##1 * 2.0f - d##3 * 2.5f + d##5 * 0.5f; \ + wd##1 = tmp0 + tmp1; \ + wd##2 = tmp0 - tmp1; \ + tmp0 = d##2 - d##4 * 4.25f + d##6; \ + tmp1 = d##1 - d##3 * 4.25f + d##5; \ + wd##3 = tmp0 + tmp1; \ + wd##4 = tmp0 - tmp1; \ + tmp0 = d##2 * 0.25f - d##4 * 1.25f + d##6; \ + tmp1 = d##1 * 0.5f - d##3 * 2.5f + d##5 * 2.0f; \ + wd##5 = tmp0 + tmp1; \ + wd##6 = tmp0 - tmp1; \ + wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f; \ + } while (0) + +#define GET_VECTOR_HIGH_ELEM(s, i, idx) \ + vgetq_lane_f32(CONCAT(s, i).value.val[1], idx) +#define GET_VECTOR_LOW_ELEM(s, i, idx) \ + vgetq_lane_f32(CONCAT(s, i).value.val[0], idx) + + template + static void transform(const float* input, float* input_transform_buf, + float* transform_mid_buf, int ih_start, int iw_start, + size_t ic, size_t IH, size_t IW, size_t IC, + size_t unit_idx, size_t nr_units_in_tile) { + // BTd * B + // BT + // 1 0 -5.25 0 5.25 0 -1 0 + // 0 2 4 -2.5 -5 0.5 1 0 + // 0 -2 4 2.5 -5 -0.5 1 0 + // 0 1 1 -4.25 -4.25 1 1 0 + // 0 -1 1 4.25 -4.25 -1 1 0 + // 0 0.5 0.25 -2.5 -1.25 2 1 0 + // 0 -0.5 0.25 2.5 -1.25 -2 1 0 + // 0 -1 0 5.25 0 -5.25 0 1 + + constexpr size_t alpha = 4 + 5 - 1; + if (!inner) { + memset(transform_mid_buf, 0, sizeof(float) * alpha * alpha); + } + +#define cb(i) Vector d##i; + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + + if (inner) { + const float* input_ptr = + input + ic * IH * IW + ih_start * IW + iw_start; +#define cb(i) d##i = Vector::load(input_ptr + IW * i); + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + } else { + int ih0_act = std::max(ih_start, 0), + ih1_act = std::min(ih_start + alpha, IH), + iw0_act = std::max(iw_start, 0), + iw1_act = std::min(iw_start + alpha, IW); + for (int ih = ih0_act; ih < ih1_act; ++ih) { + for (int iw = iw0_act; iw < iw1_act; ++iw) { + size_t iho = ih - ih_start, iwo = iw - iw_start; + transform_mid_buf[iho * alpha + iwo] = + input[ic * IH * IW + ih * IW + iw]; + } + } +#define cb(i) d##i = Vector::load(transform_mid_buf + alpha * i); + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + } + +#define cb(i) Vector wd##i, ret##i; + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + + INPUT_TRANSFORM(d, wd); +#if MEGDNN_AARCH64 + TRANSPOSE_8x8(wd, d); + INPUT_TRANSFORM(d, ret); + +#define cb(i) ret##i.save(transform_mid_buf + i * alpha); + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + rep(i, alpha) rep(j, alpha) { + input_transform_buf[(i * alpha + j) * nr_units_in_tile * IC + + unit_idx * IC + ic] = + transform_mid_buf[j * alpha + i]; + } +#else +#define cb(i) \ + do { \ + mid_buf1[0] = GET_VECTOR_LOW_ELEM(wd, i, 0) - \ + GET_VECTOR_HIGH_ELEM(wd, i, 2) + \ + 5.25 * (GET_VECTOR_HIGH_ELEM(wd, i, 0) - \ + GET_VECTOR_LOW_ELEM(wd, i, 2)); \ + mid_buf1[7] = GET_VECTOR_HIGH_ELEM(wd, i, 3) - \ + GET_VECTOR_LOW_ELEM(wd, i, 1) + \ + 5.25 * (GET_VECTOR_LOW_ELEM(wd, i, 3) - \ + GET_VECTOR_HIGH_ELEM(wd, i, 1)); \ + auto tmp0 = 4 * GET_VECTOR_LOW_ELEM(wd, i, 2) + \ + -5 * GET_VECTOR_HIGH_ELEM(wd, i, 0) + \ + GET_VECTOR_HIGH_ELEM(wd, i, 2); \ + auto tmp1 = 2 * GET_VECTOR_LOW_ELEM(wd, i, 1) + \ + -2.5 * GET_VECTOR_LOW_ELEM(wd, i, 3) + \ + 0.5 * GET_VECTOR_HIGH_ELEM(wd, i, 1); \ + mid_buf1[1] = tmp0 + tmp1; \ + 
mid_buf1[2] = tmp0 - tmp1; \ + tmp0 = GET_VECTOR_LOW_ELEM(wd, i, 2) + \ + -4.25 * GET_VECTOR_HIGH_ELEM(wd, i, 0) + \ + GET_VECTOR_HIGH_ELEM(wd, i, 2); \ + tmp1 = GET_VECTOR_LOW_ELEM(wd, i, 1) + \ + GET_VECTOR_LOW_ELEM(wd, i, 3) * -4.25 + \ + GET_VECTOR_HIGH_ELEM(wd, i, 1); \ + mid_buf1[3] = tmp0 + tmp1; \ + mid_buf1[4] = tmp0 - tmp1; \ + tmp0 = GET_VECTOR_LOW_ELEM(wd, i, 2) * 0.25 + \ + GET_VECTOR_HIGH_ELEM(wd, i, 0) * -1.25 + \ + GET_VECTOR_HIGH_ELEM(wd, i, 2); \ + tmp1 = GET_VECTOR_LOW_ELEM(wd, i, 1) * 0.5 + \ + GET_VECTOR_LOW_ELEM(wd, i, 3) * -2.5 + \ + GET_VECTOR_HIGH_ELEM(wd, i, 1) * 2; \ + mid_buf1[5] = tmp0 + tmp1; \ + mid_buf1[6] = tmp0 - tmp1; \ + mid_buf1 += 8; \ + } while (0); + + float* mid_buf1 = transform_mid_buf; + UNROLL_CALL_NOWRAPPER(8, cb); + mid_buf1 = transform_mid_buf; + +#undef cb + rep(i, alpha) rep(j, alpha) { + input_transform_buf[(i * alpha + j) * nr_units_in_tile * IC + + unit_idx * IC + ic] = + transform_mid_buf[i * alpha + j]; + } +#endif + } +}; +#undef INPUT_TRANSFORM + +#define OUTPUT_TRANSFORM(m, s) \ + do { \ + auto m1addm2 = m##1 + m##2; \ + auto m1subm2 = m##1 - m##2; \ + auto m3addm4 = m##3 + m##4; \ + auto m3subm4 = m##3 - m##4; \ + auto m5addm6 = (m##5 + m##6); \ + auto m5subm6 = (m##5 - m##6); \ + s##0 = m##0; \ + CONCAT(s, 0).add(m1addm2).add(m3addm4).add(m5addm6); \ + CONCAT(s, 1) = m3subm4; \ + CONCAT(s, 1).mla(m1subm2, 0.5f).mla(m5subm6, 2.0f); \ + CONCAT(s, 2) = m3addm4; \ + CONCAT(s, 2).mla(m1addm2, 0.25f).mla(m5addm6, 4.0f); \ + CONCAT(s, 3) = m3subm4; \ + CONCAT(s, 3).mla(m1subm2, 0.125f).mla(m5subm6, 8.0f); \ + CONCAT(s, 4) = m##7; \ + CONCAT(s, 4).mla(m1addm2, 0.0625f).add(m3addm4).mla(m5addm6, 16.0f); \ + } while (0) + +template +struct OutputTransform5X4 { + static void transform(const float* output_transform_buf, const float* bias, + float* output, float* transform_mid_buf, + size_t oh_start, size_t ow_start, size_t OH, + size_t OW, size_t oc_start, size_t oc_end, + size_t oc_index, size_t unit_idx, + size_t nr_units_in_tile, const DType& src_dtype, + const DType& dst_dtype) { + Op op(src_dtype, dst_dtype); + //! 
AT * m * A + // AT + // 1 1 1 1 1 1 1 0 + // 0 0.5 -0.5 1 -1 2 -2 0 + // 0 0.25 0.25 1 1 4 4 0 + // 0 0.125 -0.125 1 -1 8 -8 0 + // 0 0.0625 0.0625 1 1 16 16 1 + constexpr size_t alpha = 5 + 4 - 1; + float* mid_buf1 = transform_mid_buf; + + size_t OC = oc_end - oc_start; + size_t oc = oc_start + oc_index; + +#define cb(m, n) \ + transform_mid_buf[m * alpha + n] = \ + output_transform_buf[(m * alpha + n) * nr_units_in_tile * OC + \ + unit_idx * OC + oc_index]; + UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); +#undef cb + +#define cb(i) auto m##i = Vector::load(transform_mid_buf + alpha * i); + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb +#define cb(i) Vector s##i, ret##i; + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + + OUTPUT_TRANSFORM(m, s); +#define cb(i) \ + do { \ + auto m1addm2 = \ + GET_VECTOR_LOW_ELEM(s, i, 1) + GET_VECTOR_LOW_ELEM(s, i, 2); \ + auto m1subm2 = \ + GET_VECTOR_LOW_ELEM(s, i, 1) - GET_VECTOR_LOW_ELEM(s, i, 2); \ + auto m3addm4 = \ + GET_VECTOR_LOW_ELEM(s, i, 3) + GET_VECTOR_HIGH_ELEM(s, i, 0); \ + auto m3subm4 = \ + GET_VECTOR_LOW_ELEM(s, i, 3) - GET_VECTOR_HIGH_ELEM(s, i, 0); \ + auto m5addm6 = \ + GET_VECTOR_HIGH_ELEM(s, i, 1) + GET_VECTOR_HIGH_ELEM(s, i, 2); \ + auto m5subm6 = \ + GET_VECTOR_HIGH_ELEM(s, i, 1) - GET_VECTOR_HIGH_ELEM(s, i, 2); \ + mid_buf1[0] = \ + GET_VECTOR_LOW_ELEM(s, i, 0) + m1addm2 + m3addm4 + m5addm6; \ + mid_buf1[1] = 0.5f * m1subm2 + m3subm4 + 2.0f * m5subm6; \ + mid_buf1[2] = 0.25f * m1addm2 + m3addm4 + 4.0f * m5addm6; \ + mid_buf1[3] = 0.125f * m1subm2 + m3subm4 + 8.0f * m5subm6; \ + mid_buf1[4] = 0.0625f * m1addm2 + m3addm4 + 16.0f * m5addm6 + \ + GET_VECTOR_HIGH_ELEM(s, i, 3); \ + mid_buf1 += 5; \ + } while (0); + + mid_buf1 = transform_mid_buf; + UNROLL_CALL_NOWRAPPER(5, cb); + mid_buf1 = transform_mid_buf; +#undef cb + + if (oh_start + 5 <= OH && ow_start + 5 <= OW) { + float32x4_t bias0; + float32_t bias1; + if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + bias0 = vdupq_n_f32(bias[oc]); + bias1 = bias[oc]; + } + rep(i, 5) { + size_t oh = oh_start + i; + float32x4_t item0 = vld1q_f32(mid_buf1); + float32_t item1 = mid_buf1[4]; + + if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + item0 = vaddq_f32(item0, bias0); + item1 = item1 + bias1; + } else if (bmode == BiasMode::BIAS) { + bias0 = vld1q_f32(bias + oc * OH * OW + oh * OW + ow_start); + bias1 = bias[oc * OH * OW + oh * OW + ow_start + 4]; + item0 = vaddq_f32(item0, bias0); + item1 = item1 + bias1; + } + item0 = op(item0); + item1 = op(item1); + vst1q_f32(output + oc * OH * OW + oh * OW + ow_start, item0); + output[oc * OH * OW + oh * OW + ow_start + 4] = item1; + + mid_buf1 += 5; + } + } else { + for (size_t oho = 0; oho < 5 && oh_start + oho < OH; ++oho) { + for (size_t owo = 0; owo < 5 && ow_start + owo < OW; ++owo) { + size_t oh = oh_start + oho; + size_t ow = ow_start + owo; + float res = mid_buf1[oho * 5 + owo]; + if (bmode == BiasMode::BIAS) { + res += bias[oc * OH * OW + oh * OW + ow]; + } else if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + res += bias[oc]; + } + res = op(res); + output[oc * OH * OW + oh * OW + ow] = res; + } + } + } + } +}; +#undef OUTPUT_TRANSFORM +#undef GET_VECTOR_HIGH_ELEM +#undef GET_VECTOR_LOW_ELEM + +} // namespace + +namespace megdnn { +namespace arm_common { +namespace winograd { + +MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_5x4_1x1_f) + +void winograd_5x4_1x1_f::filter(const float* filter, + float* filter_transform_buf, + float* transform_mid_buf, size_t OC, size_t IC, + size_t oc_start, size_t oc_end) { + FilterTransform5X4::transform(filter, 
filter_transform_buf, + transform_mid_buf, OC, IC, oc_start, oc_end); +} + +void winograd_5x4_1x1_f::input(const float* input, float* input_transform_buf, + float* transform_mid_buf, size_t IH, size_t IW, + size_t IC, size_t PH, size_t PW, + size_t unit_start_idx, size_t nr_units_in_tile) { + constexpr int alpha = 5 + 4 - 1; + + // OW = IW + 2 * PW - KERNEL_SIZE + 1 + auto units_w = div_ceil(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); + + rep(ic, IC) { + rep(unit_idx, nr_units_in_tile) { + size_t index = unit_start_idx + unit_idx; + size_t nh = index / units_w; + size_t nw = index % units_w; + int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; + int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; + if (ih_start >= 0 && ih_start + alpha <= static_cast(IH) && + iw_start >= 0 && iw_start + alpha <= static_cast(IW)) { + InputTransform5X4::transform( + input, input_transform_buf, transform_mid_buf, ih_start, + iw_start, ic, IH, IW, IC, unit_idx, nr_units_in_tile); + + } else { + InputTransform5X4::transform( + input, input_transform_buf, transform_mid_buf, ih_start, + iw_start, ic, IH, IW, IC, unit_idx, nr_units_in_tile); + } + } + } +} + +void winograd_5x4_1x1_f::output(const float* output_transform_buf, + const float* bias, float* output, + float* transform_mid_buf, BiasMode bmode, + NonlineMode nonline_mode, size_t OH, size_t OW, + size_t oc_start, size_t oc_end, + size_t unit_start_idx, + size_t nr_units_in_tile) { +#define cb(_bmode, _nonline_op, ...) \ + OutputTransform5X4<_bmode MEGDNN_COMMA _nonline_op>::transform(__VA_ARGS__); + + auto units_w = div_ceil(OW, OUTPUT_BLOCK_SIZE); + + for (size_t oc = oc_start; oc < oc_end; oc++) { + size_t oc_index = oc - oc_start; + rep(unit_idx, nr_units_in_tile) { + size_t index = unit_start_idx + unit_idx; + auto nh = index / units_w; + auto nw = index % units_w; + size_t oh_start = nh * OUTPUT_BLOCK_SIZE; + size_t ow_start = nw * OUTPUT_BLOCK_SIZE; + DISPATCH_CONV_WINOGRAD_BIAS( + megdnn_arm_common_winograd_fp32_F54, cb, float, float, bmode, + nonline_mode, output_transform_buf, bias, output, transform_mid_buf, + oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx, + nr_units_in_tile, src_dtype, dst_dtype); + } + } +#undef cb +} + +} // namespace winograd +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/strategy_6x3.cpp b/dnn/src/arm_common/conv_bias/fp32/strategy_6x3.cpp new file mode 100644 index 00000000..7cde4a7c --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/strategy_6x3.cpp @@ -0,0 +1,424 @@ +/** + * \file dnn/src/arm_common/conv_bias/fp32/strategy_6x3.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
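Editorial note, not part of the patch: the winograd_5x4_1x1_f::input driver just above maps a global tile index to an input window with units_w = div_ceil(IW + 2*PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE), nh = index / units_w, ih_start = nh * OUTPUT_BLOCK_SIZE - PH, and picks the padded or unpadded InputTransform5X4 path from the window bounds. A minimal standalone sketch of that arithmetic, assuming OUTPUT_BLOCK_SIZE == 5 and KERNEL_SIZE == 4 as registered for this strategy; all names here are illustrative, not part of the patch.

#include <cstddef>

struct TileWindow {
    int ih_start, iw_start;  // top-left corner of the alpha x alpha window
    bool inner;              // true if the window needs no zero padding
};

inline TileWindow tile_to_window(size_t index, size_t IH, size_t IW,
                                 size_t PH, size_t PW) {
    constexpr size_t OUTPUT_BLOCK_SIZE = 5;                         // m
    constexpr size_t KERNEL_SIZE = 4;                               // r
    constexpr size_t alpha = OUTPUT_BLOCK_SIZE + KERNEL_SIZE - 1;   // 8
    size_t OW = IW + 2 * PW - KERNEL_SIZE + 1;
    size_t units_w = (OW + OUTPUT_BLOCK_SIZE - 1) / OUTPUT_BLOCK_SIZE;  // div_ceil
    size_t nh = index / units_w, nw = index % units_w;
    TileWindow t;
    t.ih_start = static_cast<int>(nh * OUTPUT_BLOCK_SIZE) - static_cast<int>(PH);
    t.iw_start = static_cast<int>(nw * OUTPUT_BLOCK_SIZE) - static_cast<int>(PW);
    t.inner = t.ih_start >= 0 && t.iw_start >= 0 &&
              t.ih_start + static_cast<int>(alpha) <= static_cast<int>(IH) &&
              t.iw_start + static_cast<int>(alpha) <= static_cast<int>(IW);
    return t;
}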
+ */ + +#include "src/arm_common/conv_bias/fp32/filter_transform.h" +#include "src/arm_common/conv_bias/fp32/helper.h" +#include "src/arm_common/conv_bias/fp32/strategy.h" +#include "src/arm_common/elemwise_helper/op_unary.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/arm_common/utils.h" +#include "src/common/unroll_macro.h" +#include "src/common/utils.h" +#include "src/fallback/conv_bias/winograd/winograd.h" +#include "src/naive/matrix_mul/matrix_mul_helper.h" + +#include "midout.h" +MIDOUT_DECL(megdnn_arm_common_winograd_fp32_F63) + +using namespace megdnn; +using namespace arm_common; +namespace { + +/** + * input transform + * + * wd0 = (d0 - d6) + 5.25 * (d4 - d2) + * wd1 = (d6 + d2 - 4.25 * d4) + (d1 + d5 - 4.25 * d3) + * wd2 = (d6 + d2 - 4.25 * d4) - (d1 + d5 - 4.25 * d3) + * wd3 = (d6 + 0.25 * d2 - 1.25 * d4) + 2.0 * (d5 + 0.25 * d1 - 1.25 * d3) + * wd4 = (d6 + 0.25 * d2 - 1.25 * d4) - 2.0 * (d5 + 0.25 * d1 - 1.25 * d3) + * wd5 = (d6 - 5.0 * d4 + 4.0 * d2) + 2.0 * (d1 + 0.25 * d5 - 1.25 * d3) + * wd6 = (d6 - 5.0 * d4 + 4.0 * d2) - 2.0 * (d1 + 0.25 * d5 - 1.25 * d3) + * wd7 = (d7 - d1) + 5.25 * (d3 - d5) + */ +#define INPUT_TRANSFORM(d, wd) \ + do { \ + wd##0 = (d##0 - d##6) + (d##4 - d##2) * 5.25f; \ + auto tmp0 = d##6 + d##2 - d##4 * 4.25f; \ + auto tmp1 = d##1 + d##5 - d##3 * 4.25f; \ + wd##1 = tmp0 + tmp1; \ + wd##2 = tmp0 - tmp1; \ + tmp0 = d##6 + d##2 * 0.25f - d##4 * 1.25f; \ + tmp1 = (d##5 + d##1 * 0.25f - d##3 * 1.25f) * 2.0f; \ + wd##3 = tmp0 + tmp1; \ + wd##4 = tmp0 - tmp1; \ + tmp0 = d6 - d4 * 5.0f + d2 * 4.0f; \ + tmp1 = (d1 + d5 * 0.25f - d3 * 1.25f) * 2.0f; \ + wd##5 = tmp0 + tmp1; \ + wd##6 = tmp0 - tmp1; \ + wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f; \ + } while (0); + +#define GET_VECTOR_HIGH_ELEM(s, i, idx) \ + vgetq_lane_f32(CONCAT(s, i).value.val[1], idx) +#define GET_VECTOR_LOW_ELEM(s, i, idx) \ + vgetq_lane_f32(CONCAT(s, i).value.val[0], idx) +struct InputTransform6X3 { + template + static void transform(const float* input, float* input_transform_buf, + float* transform_mid_buf, int ih_start, int iw_start, + size_t ic, size_t IH, size_t IW, size_t IC, + size_t unit_idx, size_t nr_units_in_tile) { + constexpr size_t alpha = 6 + 3 - 1; + if (!inner) { + memset(transform_mid_buf, 0, sizeof(float) * alpha * alpha); + } + +#define cb(i) Vector d##i; + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + + if (inner) { + const float* input_ptr = + input + ic * IH * IW + ih_start * IW + iw_start; +#define cb(i) d##i = Vector::load(input_ptr + IW * i); + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + } else { + int ih0_act = std::max(ih_start, 0), + ih1_act = std::min(ih_start + alpha, IH), + iw0_act = std::max(iw_start, 0), + iw1_act = std::min(iw_start + alpha, IW); + for (int ih = ih0_act; ih < ih1_act; ++ih) { + for (int iw = iw0_act; iw < iw1_act; ++iw) { + size_t iho = ih - ih_start, iwo = iw - iw_start; + transform_mid_buf[iho * alpha + iwo] = + input[ic * IH * IW + ih * IW + iw]; + } + } +#define cb(i) d##i = Vector::load(transform_mid_buf + alpha * i); + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + } + +#define cb(i) Vector wd##i, ret##i; + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + + INPUT_TRANSFORM(d, wd); + +#if MEGDNN_AARCH64 + TRANSPOSE_8x8(wd, d); + INPUT_TRANSFORM(d, ret); + +#define cb(i) ret##i.save(transform_mid_buf + i * alpha); + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + + rep(i, alpha) rep(j, alpha) { + input_transform_buf[(i * alpha + j) * nr_units_in_tile * IC + + unit_idx * IC + ic] = + transform_mid_buf[j * alpha + i]; + } +#else + 
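// Editorial note (not in the original patch): on AArch64 the #if branch above
// transposes the 8x8 tile with TRANSPOSE_8x8 and reuses INPUT_TRANSFORM on whole
// vectors, while this #else branch is the ARMv7 fallback that applies the column
// transform in scalar form: GET_VECTOR_LOW_ELEM(wd, i, idx) reads element idx of
// the low float32x4_t half of wd##i and GET_VECTOR_HIGH_ELEM the high half.
// The likely reason is ARMv7's smaller NEON register file (16 q registers versus
// 32 on AArch64); that rationale is an assumption, the patch does not state it.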
//! 1 0 0 0 0 0 0 0 + //! 0 1 -1 0.5 -0.5 2 -2 -1 + //! -5.25 1 1 0.25 0.25 4 4 0 + //! 0 -4.25 4.25 -2.5 2.5 -2.5 2.5 5.25 + //! 5.25 -4.25 -4.25 -1.25 -1.25 -5 -5 0 + //! 0 1 -1 2 -2 0.5 -0.5 -5.25 + //! -1 1 1 1 1 1 1 0 + //! 0 0 0 0 0 0 0 1 +#define cb(i) \ + do { \ + mid_buf1[0] = GET_VECTOR_LOW_ELEM(wd, i, 0) - \ + GET_VECTOR_HIGH_ELEM(wd, i, 2) + \ + 5.25f * (GET_VECTOR_HIGH_ELEM(wd, i, 0) - \ + GET_VECTOR_LOW_ELEM(wd, i, 2)); \ + mid_buf1[7] = GET_VECTOR_HIGH_ELEM(wd, i, 3) - \ + GET_VECTOR_LOW_ELEM(wd, i, 1) + \ + 5.25f * (GET_VECTOR_LOW_ELEM(wd, i, 3) - \ + GET_VECTOR_HIGH_ELEM(wd, i, 1)); \ + auto tmp0 = GET_VECTOR_LOW_ELEM(wd, i, 2) + \ + GET_VECTOR_HIGH_ELEM(wd, i, 2) - \ + 4.25f * GET_VECTOR_HIGH_ELEM(wd, i, 0); \ + auto tmp1 = GET_VECTOR_LOW_ELEM(wd, i, 1) + \ + GET_VECTOR_HIGH_ELEM(wd, i, 1) - \ + 4.25f * GET_VECTOR_LOW_ELEM(wd, i, 3); \ + mid_buf1[1] = tmp0 + tmp1; \ + mid_buf1[2] = tmp0 - tmp1; \ + tmp0 = GET_VECTOR_HIGH_ELEM(wd, i, 2) + \ + 0.25f * GET_VECTOR_LOW_ELEM(wd, i, 2) - \ + GET_VECTOR_HIGH_ELEM(wd, i, 0) * 1.25f; \ + tmp1 = GET_VECTOR_LOW_ELEM(wd, i, 1) * 0.5f - \ + GET_VECTOR_LOW_ELEM(wd, i, 3) * 2.5f + \ + GET_VECTOR_HIGH_ELEM(wd, i, 1) * 2.f; \ + mid_buf1[3] = tmp0 + tmp1; \ + mid_buf1[4] = tmp0 - tmp1; \ + tmp0 = GET_VECTOR_HIGH_ELEM(wd, i, 2) + \ + (GET_VECTOR_LOW_ELEM(wd, i, 2) - \ + GET_VECTOR_HIGH_ELEM(wd, i, 0) * 1.25f) * \ + 4; \ + tmp1 = GET_VECTOR_LOW_ELEM(wd, i, 1) * 2.f - \ + GET_VECTOR_LOW_ELEM(wd, i, 3) * 2.5f + \ + GET_VECTOR_HIGH_ELEM(wd, i, 1) * 0.5f; \ + mid_buf1[5] = tmp0 + tmp1; \ + mid_buf1[6] = tmp0 - tmp1; \ + mid_buf1 += 8; \ + } while (0); + + float* mid_buf1 = transform_mid_buf; + UNROLL_CALL_NOWRAPPER(8, cb); + mid_buf1 = transform_mid_buf; + +#undef cb + rep(i, alpha) rep(j, alpha) { + input_transform_buf[(i * alpha + j) * nr_units_in_tile * IC + + unit_idx * IC + ic] = + transform_mid_buf[i * alpha + j]; + } +#endif + } +}; + +#undef INPUT_TRANSFORM + +/** + * Output Transform: use fma + * + * s0 = m0 + (m1 + m2) + (m3 + m4) + 32 * (m5 + m6) / 32 + * s1 = (m1 - m2) + 2 * (m3 - m4) + 16 * (m5 - m6) / 32 + * s2 = (m1 + m2) + 4 * (m3 + m4) + 8 * (m5 + m6) / 32 + * s3 = (m1 - m2) + 8 * (m3 - m4) + 4 * (m5 - m6) / 32 + * s4 = (m1 + m2) + 16 * (m3 + m4) + 2 * (m5 + m6) / 32 + * s5 = (m1 - m2) + 32 * (m3 - m4) + (m5 - m6) / 32 + m7 + */ +#define OUTPUT_TRANSFORM(m, s) \ + do { \ + auto m1addm2 = m##1 + m##2; \ + auto m1subm2 = m##1 - m##2; \ + auto m3addm4 = m##3 + m##4; \ + auto m3subm4 = m##3 - m##4; \ + auto m5addm6 = (m##5 + m##6) * 0.03125f; \ + auto m5subm6 = (m##5 - m##6) * 0.03125f; \ + s##0 = m##0; \ + CONCAT(s, 0).mla(m5addm6, 32.f).add(m3addm4).add(m1addm2); \ + CONCAT(s, 1) = m1subm2; \ + CONCAT(s, 1).mla(m3subm4, 2.f).mla(m5subm6, 16.f); \ + CONCAT(s, 2) = m1addm2; \ + CONCAT(s, 2).mla(m3addm4, 4.f).mla(m5addm6, 8.f); \ + CONCAT(s, 3) = m1subm2; \ + CONCAT(s, 3).mla(m3subm4, 8.f).mla(m5subm6, 4.f); \ + CONCAT(s, 4) = m1addm2; \ + CONCAT(s, 4).mla(m3addm4, 16.f).mla(m5addm6, 2.f); \ + CONCAT(s, 5) = m1subm2; \ + CONCAT(s, 5).mla(m3subm4, 32.f).add(m5subm6).add(m##7); \ + } while (0); + +template +struct OutputTransform6X3 { + static void transform(const float* output_transform_buf, const float* bias, + float* output, float* transform_mid_buf, + size_t oh_start, size_t ow_start, size_t OH, + size_t OW, size_t oc_start, size_t oc_end, + size_t oc_index, size_t unit_idx, + size_t nr_units_in_tile, const DType& src_dtype, + const DType& dst_dtype) { + constexpr size_t alpha = 6 + 3 - 1; + Op op(src_dtype, dst_dtype); + float* 
mid_buf1 = transform_mid_buf; + + //! AT * m * A + size_t OC = oc_end - oc_start; + size_t oc = oc_start + oc_index; + +#define cb(m, n) \ + transform_mid_buf[m * alpha + n] = \ + output_transform_buf[(m * alpha + n) * nr_units_in_tile * OC + \ + unit_idx * OC + oc_index]; + UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); +#undef cb + +#define cb(i) auto m##i = Vector::load(transform_mid_buf + alpha * i); + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb +#define cb(i) Vector s##i, ret##i; + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + + OUTPUT_TRANSFORM(m, s); + /** + * Output transform: m * A + * + * 1 0 0 0 0 0 + * 1 1 1 1 1 1 + * 1 -1 1 -1 1 -1 + * 1 2 4 8 16 32 + * 1 -2 4 -8 16 -32 + * 1 0.5 0.25 0.125 0.0625 0.03125 + * 1 -0.5 0.25 -0.125 0.0625 -0.03125 + * 0 0.0 0 0 0 1 + */ +#define cb(i) \ + do { \ + auto m1addm2 = \ + GET_VECTOR_LOW_ELEM(s, i, 1) + GET_VECTOR_LOW_ELEM(s, i, 2); \ + auto m1subm2 = \ + GET_VECTOR_LOW_ELEM(s, i, 1) - GET_VECTOR_LOW_ELEM(s, i, 2); \ + auto m3addm4 = \ + GET_VECTOR_LOW_ELEM(s, i, 3) + GET_VECTOR_HIGH_ELEM(s, i, 0); \ + auto m3subm4 = \ + GET_VECTOR_LOW_ELEM(s, i, 3) - GET_VECTOR_HIGH_ELEM(s, i, 0); \ + auto m5addm6 = \ + GET_VECTOR_HIGH_ELEM(s, i, 1) + GET_VECTOR_HIGH_ELEM(s, i, 2); \ + auto m5subm6 = \ + GET_VECTOR_HIGH_ELEM(s, i, 1) - GET_VECTOR_HIGH_ELEM(s, i, 2); \ + mid_buf1[0] = \ + GET_VECTOR_LOW_ELEM(s, i, 0) + m1addm2 + m3addm4 + m5addm6; \ + mid_buf1[1] = m1subm2 + 2.f * m3subm4 + 0.5f * m5subm6; \ + mid_buf1[2] = m1addm2 + 4.f * m3addm4 + 0.25f * m5addm6; \ + mid_buf1[3] = m1subm2 + 8.f * m3subm4 + 0.125f * m5subm6; \ + mid_buf1[4] = m1addm2 + 16.f * m3addm4 + 0.0625f * m5addm6; \ + mid_buf1[5] = m1subm2 + 32.f * m3subm4 + 0.03125f * m5subm6 + \ + GET_VECTOR_HIGH_ELEM(s, i, 3); \ + mid_buf1 += 6; \ + } while (0); + + mid_buf1 = transform_mid_buf; + UNROLL_CALL_NOWRAPPER(6, cb); + mid_buf1 = transform_mid_buf; +#undef cb + + if (oh_start + 6 <= OH && ow_start + 6 <= OW) { + float32x4_t bias0; + float32x2_t bias1; + if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + bias0 = vdupq_n_f32(bias[oc]); + bias1 = vdup_n_f32(bias[oc]); + } + rep(i, 6) { + size_t oh = oh_start + i; + float32x4_t item0 = vld1q_f32(mid_buf1); + float32x2_t item1 = vld1_f32(mid_buf1 + 4); + + if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + item0 = vaddq_f32(item0, bias0); + item1 = vadd_f32(item1, bias1); + } else if (bmode == BiasMode::BIAS) { + bias0 = vld1q_f32(bias + oc * OH * OW + oh * OW + ow_start); + bias1 = vld1_f32(bias + oc * OH * OW + oh * OW + ow_start + + 4); + item0 = vaddq_f32(item0, bias0); + item1 = vadd_f32(item1, bias1); + } + item0 = op(item0); + item1 = vset_lane_f32(op(vget_lane_f32(item1, 0)), item1, 0); + item1 = vset_lane_f32(op(vget_lane_f32(item1, 1)), item1, 1); + vst1q_f32(output + oc * OH * OW + oh * OW + ow_start, item0); + vst1_f32(output + oc * OH * OW + oh * OW + ow_start + 4, item1); + + mid_buf1 += 6; + } + } else { + for (size_t oho = 0; oho < 6 && oh_start + oho < OH; ++oho) { + for (size_t owo = 0; owo < 6 && ow_start + owo < OW; ++owo) { + size_t oh = oh_start + oho; + size_t ow = ow_start + owo; + float res = mid_buf1[oho * 6 + owo]; + if (bmode == BiasMode::BIAS) { + res += bias[oc * OH * OW + oh * OW + ow]; + } else if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + res += bias[oc]; + } + res = op(res); + output[oc * OH * OW + oh * OW + ow] = res; + } + } + } + } +}; + +#undef GET_VECTOR_HIGH_ELEM +#undef GET_VECTOR_LOW_ELEM +#undef OUTPUT_TRANSFORM + +} // namespace + +namespace megdnn { +namespace arm_common { +namespace winograd { + 
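Editorial note, not part of the patch: the strategy registered below implements F(6,3), i.e. Y = A^T [(G g G^T) .* (B^T d B)] A with alpha = 6 + 3 - 1 = 8, which is what the INPUT_TRANSFORM / OUTPUT_TRANSFORM macros above expand to. For readers new to the transform, here is the same identity in its smallest 1-D form, F(2,3), which is easy to verify by hand; this is only an illustration with made-up function names, not code from the patch.

#include <cstdio>

// 1-D Winograd F(2,3): two outputs of a 3-tap filter from four inputs,
// using 4 multiplies instead of 6. Same algebra as F(6,3), smaller matrices.
static void winograd_f23_1d(const float d[4], const float g[3], float y[2]) {
    float m1 = (d[0] - d[2]) * g[0];
    float m2 = (d[1] + d[2]) * 0.5f * (g[0] + g[1] + g[2]);
    float m3 = (d[2] - d[1]) * 0.5f * (g[0] - g[1] + g[2]);
    float m4 = (d[1] - d[3]) * g[2];
    y[0] = m1 + m2 + m3;  // == d0*g0 + d1*g1 + d2*g2
    y[1] = m2 - m3 - m4;  // == d1*g0 + d2*g1 + d3*g2
}

int main() {
    float d[4] = {1, 2, 3, 4}, g[3] = {0.5f, -1.f, 2.f}, y[2];
    winograd_f23_1d(d, g, y);
    // direct check: y0 = 1*0.5 + 2*(-1) + 3*2 = 4.5, y1 = 2*0.5 + 3*(-1) + 4*2 = 6
    std::printf("%g %g\n", y[0], y[1]);
    return 0;
}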
+MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_6x3_1x1_f) + +void winograd_6x3_1x1_f::filter(const float* filter, + float* filter_transform_buf, + float* transform_mid_buf, size_t OC, size_t IC, + size_t oc_start, size_t oc_end) { + FilterTransform6X3::transform( + filter, filter_transform_buf, transform_mid_buf, OC, IC, oc_start, + oc_end); +} + +void winograd_6x3_1x1_f::input(const float* input, float* input_transform_buf, + float* transform_mid_buf, size_t IH, size_t IW, + size_t IC, size_t PH, size_t PW, + size_t unit_start_idx, size_t nr_units_in_tile) { + constexpr int alpha = 3 + 6 - 1; + + // OW = IW + 2 * PW - KERNEL_SIZE + 1 + auto units_w = div_ceil(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); + rep(ic, IC) { + rep(unit_idx, nr_units_in_tile) { + size_t index = unit_start_idx + unit_idx; + size_t nh = index / units_w; + size_t nw = index % units_w; + int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; + int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; + if (ih_start >= 0 && ih_start + alpha <= static_cast(IH) && + iw_start >= 0 && iw_start + alpha <= static_cast(IW)) { + InputTransform6X3::transform( + input, input_transform_buf, transform_mid_buf, ih_start, + iw_start, ic, IH, IW, IC, unit_idx, nr_units_in_tile); + + } else { + InputTransform6X3::transform( + input, input_transform_buf, transform_mid_buf, ih_start, + iw_start, ic, IH, IW, IC, unit_idx, nr_units_in_tile); + } + } + } +} + +void winograd_6x3_1x1_f::output(const float* output_transform_buf, + const float* bias, float* output, + float* transform_mid_buf, BiasMode bmode, + NonlineMode nonline_mode, size_t OH, size_t OW, + size_t oc_start, size_t oc_end, + size_t unit_start_idx, + size_t nr_units_in_tile) { +#define cb(_bmode, _nonline_op, ...) \ + OutputTransform6X3<_bmode MEGDNN_COMMA _nonline_op>::transform(__VA_ARGS__); + + auto units_w = div_ceil(OW, OUTPUT_BLOCK_SIZE); + + for (size_t oc = oc_start; oc < oc_end; oc++) { + size_t oc_index = oc - oc_start; + rep(unit_idx, nr_units_in_tile) { + size_t index = unit_start_idx + unit_idx; + auto nh = index / units_w; + auto nw = index % units_w; + size_t oh_start = nh * OUTPUT_BLOCK_SIZE; + size_t ow_start = nw * OUTPUT_BLOCK_SIZE; + DISPATCH_CONV_WINOGRAD_BIAS( + megdnn_arm_common_winograd_fp32_F63, cb, float, float, bmode, + nonline_mode, output_transform_buf, bias, output, transform_mid_buf, + oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx, + nr_units_in_tile, src_dtype, dst_dtype); + } + } +#undef cb +} + +} // namespace winograd +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/strategy_6x3_4x4.cpp b/dnn/src/arm_common/conv_bias/fp32/strategy_6x3_4x4.cpp new file mode 100644 index 00000000..ca42913e --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/strategy_6x3_4x4.cpp @@ -0,0 +1,351 @@ +/** + * \file dnn/src/arm_common/conv_bias/fp32/strategy_6x3_4x4.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
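Editorial note, not part of the patch: the output transforms above support two bias layouts, selected by the bmode template parameter. BROADCAST_CHANNEL_BIAS adds one value per output channel (bias[oc]) to every pixel, while BIAS adds a full OC x OH x OW tensor element-wise (bias[oc*OH*OW + oh*OW + ow]). A minimal scalar sketch of that distinction; the enumerator names come from the code above, everything else is illustrative.

#include <cstddef>

enum class BiasMode { NO_BIAS, BROADCAST_CHANNEL_BIAS, BIAS };

inline float apply_bias(const float* bias, BiasMode bmode, size_t oc,
                        size_t oh, size_t ow, size_t OH, size_t OW, float val) {
    if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS)
        return val + bias[oc];                        // one scalar per channel
    if (bmode == BiasMode::BIAS)
        return val + bias[oc * OH * OW + oh * OW + ow];  // per-element bias
    return val;
}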
+ */ + +#include "src/arm_common/conv_bias/fp32/filter_transform.h" +#include "src/arm_common/conv_bias/fp32/helper.h" +#include "src/arm_common/conv_bias/fp32/strategy.h" +#include "src/arm_common/elemwise_helper/op_unary.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/arm_common/utils.h" +#include "src/common/unroll_macro.h" +#include "src/common/utils.h" +#include "src/common/winograd/winograd_helper.h" +#include "src/fallback/conv_bias/winograd/winograd.h" + +#include "midout.h" +MIDOUT_DECL(megdnn_arm_common_winograd_fp32_F63_4x4) + +using namespace megdnn; +using namespace arm_common; + +namespace { + +struct InputTransform6X3 { + template + static void prepare(const float* input, float* patch, float* patchT, + int ih_start, int iw_start, size_t IH, size_t IW, + size_t ic, size_t IC) { + constexpr size_t alpha = 6 + 3 - 1; + if (!(inner && ic + 4 < IC)) { + memset(patch, 0, sizeof(float) * 4 * alpha * alpha); + } + if (inner) { + const float* input_ptr = + input + ic * IH * IW + ih_start * IW + iw_start; + for (size_t ico = 0; ico < 4; ++ico) { + if (ic + ico < IC) { +#define cb(i) \ + auto v##i##0 = vld1q_f32(input_ptr + IW * i); \ + auto v##i##1 = vld1q_f32(input_ptr + IW * i + 4); + + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + +#define cb(i) \ + vst1q_f32(patch + ico * 8 * alpha + i * 8, v##i##0); \ + vst1q_f32(patch + ico * 8 * alpha + i * 8 + 4, v##i##1); + + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + input_ptr += IH * IW; + } + } + } else { + int ih0_act = std::max(ih_start, 0), + ih1_act = std::min(ih_start + alpha, IH), + iw0_act = std::max(iw_start, 0), + iw1_act = std::min(iw_start + alpha, IW); + // partial copy + for (size_t ico = 0; ico < 4; ++ico) { + if (ic + ico < IC) { + for (int ih = ih0_act; ih < ih1_act; ++ih) { + for (int iw = iw0_act; iw < iw1_act; ++iw) { + size_t iho = ih - ih_start, iwo = iw - iw_start; + patch[ico * alpha * 8 + iho * 8 + iwo] = + input[(ic + ico) * IH * IW + ih * IW + iw]; + } + } + } + } + } + +#define cb(i) \ + transpose_4x4(patch + 8 * i + 0, patchT + 8 * i * 4, 64, 4); \ + transpose_4x4(patch + 8 * i + 4, patchT + 8 * i * 4 + 4 * 4, 64, 4); + UNROLL_CALL_NOWRAPPER(8, cb) +#undef cb + } + + static void transform(const float* patchT, float* input_transform_buf, + size_t unit_idx, size_t nr_units_in_tile, size_t ic, + size_t IC) { + constexpr size_t alpha = 6 + 3 - 1; + // BT * d * B +#define cb(m, n) \ + Vector d##m##n = \ + Vector::load(patchT + m * 8 * 4 + n * 4); + + UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); +#undef cb + + //! B + //! 1 0 0 0 0 0 0 0 + //! 0 1 -1 0.5 -0.5 2 -2 -1 + //! -5.25 1 1 0.25 0.25 4 4 0 + //! 0 -4.25 4.25 -2.5 2.5 -2.5 2.5 5.25 + //! 5.25 -4.25 -4.25 -1.25 -1.25 -5 -5 0 + //! 0 1 -1 2 -2 0.5 -0.5 -5.25 + //! -1 1 1 1 1 1 1 0 + //! 
0 0 0 0 0 0 0 1 +#define cb(m) \ + auto t0##m = d0##m + (d4##m - d2##m) * 5.25f - d6##m; \ + auto t1##m = d1##m + d2##m + d5##m + d6##m - (d3##m + d4##m) * 4.25f; \ + auto t2##m = d2##m + d6##m - (d1##m + d5##m) + (d3##m - d4##m) * 4.25f; \ + auto t3##m = d1##m * 0.5f + d2##m * 0.25f - d3##m * 2.5f - d4##m * 1.25f + \ + d5##m * 2.f + d6##m; \ + auto t4##m = d1##m * (-0.5f) + d2##m * 0.25f + d3##m * 2.5f - \ + d4##m * 1.25f - d5##m * 2.f + d6##m; \ + auto t5##m = d1##m * 2.f + d2##m * 4.f - d3##m * 2.5f - d4##m * 5.f + \ + d5##m * 0.5f + d6##m; \ + auto t6##m = d1##m * (-2.f) + d2##m * 4.f + d3##m * 2.5f - d4##m * 5.f - \ + d5##m * 0.5f + d6##m; \ + auto t7##m = (d7##m - d1##m) + (d3##m - d5##m) * 5.25f; + + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + +#define cb(m) \ + d##m##0 = t##m##0 + (t##m##4 - t##m##2) * 5.25f - t##m##6; \ + d##m##1 = t##m##1 + t##m##2 + t##m##5 + t##m##6 - \ + (t##m##3 + t##m##4) * 4.25f; \ + d##m##2 = t##m##2 + t##m##6 - (t##m##1 + t##m##5) + \ + (t##m##3 - t##m##4) * 4.25f; \ + d##m##3 = t##m##1 * 0.5f + t##m##2 * 0.25f - t##m##3 * 2.5f - \ + t##m##4 * 1.25f + t##m##5 * 2.f + t##m##6; \ + d##m##4 = t##m##1 * (-0.5f) + t##m##2 * 0.25f + t##m##3 * 2.5f - \ + t##m##4 * 1.25f - t##m##5 * 2.f + t##m##6; \ + d##m##5 = t##m##1 * 2.f + t##m##2 * 4.f - t##m##3 * 2.5f - t##m##4 * 5.f + \ + t##m##5 * 0.5f + t##m##6; \ + d##m##6 = t##m##1 * (-2.f) + t##m##2 * 4.f + t##m##3 * 2.5f - \ + t##m##4 * 5.f - t##m##5 * 0.5f + t##m##6; \ + d##m##7 = (t##m##7 - t##m##1) + (t##m##3 - t##m##5) * 5.25f; + + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + + size_t ICB = IC / 4; + size_t icb = ic / 4; +#define cb(m, n) \ + d##m##n.save(input_transform_buf + \ + (m * alpha + n) * ICB * nr_units_in_tile * 4 + \ + icb * nr_units_in_tile * 4 + unit_idx * 4); + UNROLL_CALL_NOWRAPPER_D2(8, 8, cb) +#undef cb + } +}; + +template +struct OutputTransform6X3 { + static void transform(const float* output_transform_buf, const float* bias, + float* output, float* transform_mid_buf, + size_t oh_start, size_t ow_start, size_t OH, + size_t OW, size_t oc_start, size_t oc_end, + size_t oc_index, size_t unit_idx, + size_t nr_units_in_tile, const DType& src_dtype, + const DType& dst_dtype) { + Op op(src_dtype, dst_dtype); + //! 
AT * m * A + constexpr size_t alpha = 6 + 3 - 1; + + size_t oc = oc_start + oc_index; + size_t OCB = (oc_end - oc_start) / 4; + size_t ocb = oc_index / 4; + +#define cb(m, n) \ + auto v##m##n = Vector::load( \ + output_transform_buf + \ + (m * alpha + n) * OCB * nr_units_in_tile * 4 + \ + ocb * nr_units_in_tile * 4 + unit_idx * 4); + UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); +#undef cb + + /** + * A + * + * 1 0 0 0 0 0 + * 1 1 1 1 1 1 + * 1 -1 1 -1 1 -1 + * 1 2 4 8 16 32 + * 1 -2 4 -8 16 -32 + * 1 0.5 0.25 0.125 0.0625 0.03125 + * 1 -0.5 0.25 -0.125 0.0625 -0.03125 + * 0 0.0 0 0 0 1 + */ + + Vector v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, v5subv6; +#define cb(m) \ + v1addv2 = v1##m + v2##m; \ + v1subv2 = v1##m - v2##m; \ + v3addv4 = v3##m + v4##m; \ + v3subv4 = v3##m - v4##m; \ + v5addv6 = v5##m + v6##m; \ + v5subv6 = v5##m - v6##m; \ + auto t0##m = v0##m + v1addv2 + v3addv4 + v5addv6; \ + auto t1##m = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f; \ + auto t2##m = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f; \ + auto t3##m = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f; \ + auto t4##m = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; \ + auto t5##m = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + v7##m; + + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + +#define cb(m) \ + v1addv2 = t##m##1 + t##m##2; \ + v1subv2 = t##m##1 - t##m##2; \ + v3addv4 = t##m##3 + t##m##4; \ + v3subv4 = t##m##3 - t##m##4; \ + v5addv6 = t##m##5 + t##m##6; \ + v5subv6 = t##m##5 - t##m##6; \ + v##m##0 = t##m##0 + v1addv2 + v3addv4 + v5addv6; \ + v##m##1 = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f; \ + v##m##2 = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f; \ + v##m##3 = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f; \ + v##m##4 = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; \ + v##m##5 = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + t##m##7; + + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + + Vector vbias; + if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + vbias = Vector::load(bias + oc); + +#define cb(m, n) v##m##n += vbias; + UNROLL_CALL_RAW_D2(6, 6, cb); +#undef cb + } + if (bmode != BiasMode::BIAS) { +#define cb(m, n) v##m##n = op(CONCAT(v##m, n).value); + UNROLL_CALL_RAW_D2(6, 6, cb); +#undef cb + } + +#define cb(m, n) CONCAT(v##m, n).save(transform_mid_buf + (m * 6 + n) * 4); + UNROLL_CALL_RAW_D2(6, 6, cb); +#undef cb + + for (size_t oco = 0; oco < 4 && oc + oco < oc_end; ++oco) { + for (size_t oho = 0; oho < 6 && oh_start + oho < OH; ++oho) { + for (size_t owo = 0; owo < 6 && ow_start + owo < OW; ++owo) { + size_t oh = oh_start + oho; + size_t ow = ow_start + owo; + float res = transform_mid_buf[oho * 6 * 4 + owo * 4 + oco]; + if (bmode == BiasMode::BIAS) { + res += bias[(oc + oco) * OH * OW + oh * OW + ow]; + res = op(res); + } + output[(oc + oco) * OH * OW + oh * OW + ow] = res; + } + } + } + } +}; +} // namespace + +namespace megdnn { +namespace arm_common { +namespace winograd { + +MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_6x3_4x4_f) + +void winograd_6x3_4x4_f::filter(const float* filter, + float* filter_transform_buf, + float* transform_mid_buf, size_t OC, size_t IC, + size_t oc_start, size_t oc_end) { + FilterTransform6X3::transform( + filter, filter_transform_buf, transform_mid_buf, OC, IC, oc_start, + oc_end); +} + +void winograd_6x3_4x4_f::input(const float* input, float* input_transform_buf, + float* transform_mid_buf, size_t IH, size_t IW, + size_t IC, size_t PH, size_t PW, + size_t unit_start_idx, size_t nr_units_in_tile) { + megdnn_assert(IC % 4 == 0); + constexpr int alpha = 3 + 6 - 1; + + // OW = IW + 2 * PW - 
KERNEL_SIZE + 1 + auto units_w = div_ceil(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); + float* patch = transform_mid_buf; + float* patchT = transform_mid_buf + 4 * alpha * alpha; + + for (size_t ic = 0; ic < IC; ic += 4) { + rep(unit_idx, nr_units_in_tile) { + size_t index = unit_start_idx + unit_idx; + size_t nh = index / units_w; + size_t nw = index % units_w; + int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; + int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; + if (ih_start >= 0 && ih_start + alpha <= static_cast(IH) && + iw_start >= 0 && iw_start + alpha <= static_cast(IW)) { + InputTransform6X3::prepare(input, patch, patchT, ih_start, + iw_start, IH, IW, ic, IC); + InputTransform6X3::transform(patchT, input_transform_buf, + unit_idx, nr_units_in_tile, ic, + IC); + + } else { + InputTransform6X3::prepare(input, patch, patchT, + ih_start, iw_start, IH, IW, + ic, IC); + InputTransform6X3::transform(patchT, input_transform_buf, + unit_idx, nr_units_in_tile, ic, + IC); + } + } + } +} + +void winograd_6x3_4x4_f::output(const float* output_transform_buf, + const float* bias, float* output, + float* transform_mid_buf, BiasMode bmode, + NonlineMode nonline_mode, size_t OH, size_t OW, + size_t oc_start, size_t oc_end, size_t unit_start_idx, + size_t nr_units_in_tile) { +#define cb(_bmode, _nonline_op, ...) \ + OutputTransform6X3<_bmode MEGDNN_COMMA _nonline_op>::transform(__VA_ARGS__); + + auto units_w = div_ceil(OW, OUTPUT_BLOCK_SIZE); + + for (size_t oc = oc_start; oc < oc_end; oc += 4) { + size_t oc_index = oc - oc_start; + rep(unit_idx, nr_units_in_tile) { + size_t index = unit_start_idx + unit_idx; + auto nh = index / units_w; + auto nw = index % units_w; + size_t oh_start = nh * OUTPUT_BLOCK_SIZE; + size_t ow_start = nw * OUTPUT_BLOCK_SIZE; + DISPATCH_CONV_WINOGRAD_BIAS( + megdnn_arm_common_winograd_fp32_F63_4x4, cb, float, float, bmode, + nonline_mode, output_transform_buf, bias, output, transform_mid_buf, + oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx, + nr_units_in_tile, src_dtype, dst_dtype); + } + } +#undef cb +} + +} // namespace winograd +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/img2col_helper.h b/dnn/src/arm_common/conv_bias/img2col_helper.h new file mode 100644 index 00000000..84850fd2 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/img2col_helper.h @@ -0,0 +1,82 @@ +/** + * \file dnn/src/arm_common/conv_bias/img2col_helper.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
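Editorial note, not part of the patch: winograd_6x3_4x4_f above asserts IC % 4 == 0 because input channels are handled in blocks of four and the transformed tiles are stored channel-interleaved for the packed matmul. A small helper that reproduces the buffer offset used by the save() call in InputTransform6X3::transform above; the function name is illustrative only.

#include <cstddef>

// Offset (in floats) of the 4-channel vector for transform element (m, n),
// channel block icb = ic / 4 out of ICB = IC / 4, and tile unit_idx, matching
//   (m * alpha + n) * ICB * nr_units_in_tile * 4
//   + icb * nr_units_in_tile * 4 + unit_idx * 4
// with alpha = 6 + 3 - 1 = 8.
inline size_t input_transform_offset(size_t m, size_t n, size_t icb, size_t ICB,
                                     size_t unit_idx, size_t nr_units_in_tile,
                                     size_t alpha = 8) {
    return (m * alpha + n) * ICB * nr_units_in_tile * 4 +
           icb * nr_units_in_tile * 4 + unit_idx * 4;
}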
+ */ +#include +#include "src/common/utils.h" + +namespace { + +template +void img2col_stride(const dtype* __restrict src, + dtype* __restrict dst, const int OC, const int OH, + const int OW, const int IC, const int IH, const int IW, + const int FH, const int FW, const int SH, const int SW) { + (void)OC; + size_t i = 0; + rep(ic, IC) { + rep(fh, FH) { + rep(fw, FW) { + rep(oh, OH) { + rep(ow, OW) { + int fh2, fw2; + if (is_xcorr) { + fh2 = fh; + fw2 = fw; + } else { + fh2 = FH - fh - 1; + fw2 = FW - fw - 1; + } + dst[i++] = src[ic * IH * IW + (oh * SH + fh2) * IW + + (ow * SW + fw2)]; + } + } + } + } + } +} + +template +void img2col(const dtype* src, dtype* dst, size_t /* OC */, size_t OH, + size_t OW, size_t IC, size_t IH, size_t IW, size_t FH, size_t FW) { + size_t offset = (4 - OW % 4) % 4; + size_t i = 0; + rep(ic, IC) { + rep(fh, FH) { + rep(fw, FW) { + rep(oh, OH) { + size_t ow = 0; + for (; ow < OW; ow += 4) { + size_t fh2, fw2; + if (is_xcorr) { + fh2 = fh; + fw2 = fw; + } else { + fh2 = FH - fh - 1; + fw2 = FW - fw - 1; + } + dst[i++] = src[ic * IH * IW + (oh + fh2) * IW + + (ow + fw2) + 0]; + dst[i++] = src[ic * IH * IW + (oh + fh2) * IW + + (ow + fw2) + 1]; + dst[i++] = src[ic * IH * IW + (oh + fh2) * IW + + (ow + fw2) + 2]; + dst[i++] = src[ic * IH * IW + (oh + fh2) * IW + + (ow + fw2) + 3]; + } + i -= offset; + } + } + } + } +} + +} // anonymous namespace + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8/algos.cpp b/dnn/src/arm_common/conv_bias/int8/algos.cpp new file mode 100644 index 00000000..625ef9c6 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/algos.cpp @@ -0,0 +1,275 @@ +/** + * \file dnn/src/arm_common/conv_bias/int8/algos.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
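Editorial note, not part of the patch: img2col above writes its output sequentially in (ic, fh, fw, oh, ow) order, i.e. an (IC*FH*FW) x (OH*OW) matrix, so a plain matrix multiply with the filter reshaped to OC x (IC*FH*FW) produces the convolution output. A minimal scalar sketch of that pipeline for the unit-stride, no-padding, cross-correlation case (is_xcorr == true); the original also handles the flipped-kernel case and stride, and unrolls the inner loop by four.

#include <cstddef>

// dst is (IC*FH*FW) x (OH*OW), row-major, matching img2col's write order.
inline void im2col_ref(const float* src, float* dst, size_t OH, size_t OW,
                       size_t IC, size_t IH, size_t IW, size_t FH, size_t FW) {
    for (size_t ic = 0; ic < IC; ++ic)
        for (size_t fh = 0; fh < FH; ++fh)
            for (size_t fw = 0; fw < FW; ++fw)
                for (size_t oh = 0; oh < OH; ++oh)
                    for (size_t ow = 0; ow < OW; ++ow)
                        *dst++ = src[ic * IH * IW + (oh + fh) * IW + (ow + fw)];
}

// out(OC x OH*OW) = filter(OC x K) * col(K x OH*OW), with K = IC*FH*FW.
inline void matmul_ref(const float* filter, const float* col, float* out,
                       size_t OC, size_t K, size_t N) {
    for (size_t oc = 0; oc < OC; ++oc)
        for (size_t n = 0; n < N; ++n) {
            float acc = 0.f;
            for (size_t k = 0; k < K; ++k)
                acc += filter[oc * K + k] * col[k * N + n];
            out[oc * N + n] = acc;
        }
}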
+ */ + +#include "src/arm_common/conv_bias/int8/algos.h" +#include "src/arm_common/conv_bias/int8/channel_wise_nchw44.h" +#include "src/arm_common/conv_bias/int8/strategy.h" +#include "src/arm_common/conv_bias/int8/stride1.h" +#include "src/arm_common/conv_bias/int8/stride1_dotprod.h" +#include "src/arm_common/conv_bias/int8/stride2.h" +#include "src/arm_common/conv_bias/int8/stride2_dotprod.h" +#include "src/arm_common/elemwise_op.h" +#include "src/fallback/conv_bias/common.h" + +#include "midout.h" + +using namespace megdnn; +using namespace arm_common; + +MIDOUT_DECL(megdnn_arm_common_conv_bias_int8) +/* ===================== stride1 algo ===================== */ +bool ConvBiasImpl::AlgoS8DirectStride1::usable( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const { + bool avaible = direct_int8_stride1::can_conv_direct_stride1_int8(param); + auto fm = param.filter_meta; + if (algo_selection_strategy == + ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { + bool large_group = fm.group >= param.nr_threads; + avaible &= (large_group == m_large_group); + } + return avaible; +} +bool ConvBiasImpl::AlgoS8DirectStride1::is_preferred( + megdnn::fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + auto&& fm = param.filter_meta; + auto FH = fm.spatial[0]; + auto OC = fm.ocpg; + auto IC = fm.icpg; + bool preferred = ((FH == 2 && (OC <= 10 || IC <= 8)) || + ((FH == 3 || FH == 5 || FH == 7) && + (OC <= 16 || (IC <= 4 && OC <= 32)))) && + param.bias_mode != BiasMode::BIAS; + return preferred; +} + +size_t ConvBiasImpl::AlgoS8DirectStride1::get_workspace( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + auto bundle = direct_int8_stride1::get_bundle(param, m_large_group); + return bundle.total_size_in_bytes(); +} + +SmallVector +ConvBiasImpl::AlgoS8DirectStride1::dispatch_kerns( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, 1, 0) { + return direct_int8_stride1::get_kimpls(param, m_large_group); + } + MIDOUT_END(); + return {}; +} + +/* ===================== stride1 algo ===================== */ +bool ConvBiasImpl::AlgoS8ChanWiseStride1NCHW44::usable( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param, + AlgoSelectionStrategy) const { + return channel_wise_nchw44::stride1::is_available(param); +} + +size_t ConvBiasImpl::AlgoS8ChanWiseStride1NCHW44::get_workspace( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + auto bundle = channel_wise_nchw44::stride1::get_bundle(param); + return bundle.total_size_in_bytes(); +} + +SmallVector +ConvBiasImpl::AlgoS8ChanWiseStride1NCHW44::dispatch_kerns( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, + midout_iv("AlgoS8ChanWiseStride1NCHW44"_hash)) { + return channel_wise_nchw44::stride1::get_kimpls(param); + } + MIDOUT_END(); + return {}; +} + +/* ===================== stride2 algo ===================== */ +bool ConvBiasImpl::AlgoS8ChanWiseStride2NCHW44::usable( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param, + AlgoSelectionStrategy) const { + return channel_wise_nchw44::stride2::is_available(param); +} + +size_t ConvBiasImpl::AlgoS8ChanWiseStride2NCHW44::get_workspace( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + auto bundle = channel_wise_nchw44::stride2::get_bundle(param); + return bundle.total_size_in_bytes(); +} + +SmallVector 
+ConvBiasImpl::AlgoS8ChanWiseStride2NCHW44::dispatch_kerns( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, + midout_iv("AlgoS8ChanWiseStride2NCHW44"_hash)) { + return channel_wise_nchw44::stride2::get_kimpls(param); + } + MIDOUT_END(); + return {}; +} + +/* ===================== stride2 algo ===================== */ +bool ConvBiasImpl::AlgoS8DirectStride2::usable( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const { + bool avaible = direct_int8_stride2::can_conv_direct_stride2_int8(param); + if (algo_selection_strategy == + ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { + bool large_group = param.filter_meta.group >= param.nr_threads; + avaible &= (large_group == m_large_group); + } + return avaible; +} + +size_t ConvBiasImpl::AlgoS8DirectStride2::get_workspace( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + auto bundle = direct_int8_stride2::get_bundle(param, m_large_group); + return bundle.total_size_in_bytes(); +} + +SmallVector +ConvBiasImpl::AlgoS8DirectStride2::dispatch_kerns( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, 1, 1) { + return direct_int8_stride2::get_kimpls(param, m_large_group); + } + MIDOUT_END(); + return {}; +} + +#if __ARM_FEATURE_DOTPROD +/* ===================== dot stride1 algo ======================== */ +bool ConvBiasImpl::AlgoDotS8DirectStride1::usable( + FallbackConvBiasImpl*, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const { + bool avaible = + direct_dotprod_int8_stride1::can_conv_direct_stride1_int8(param); + + if (algo_selection_strategy == + ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { + bool large_group = param.filter_meta.group >= param.nr_threads; + avaible &= (large_group == m_large_group); + } + + return avaible; +} + +size_t ConvBiasImpl::AlgoDotS8DirectStride1::get_workspace( + FallbackConvBiasImpl*, const NCBKernSizeParam& param) const { + auto bundle = direct_dotprod_int8_stride1::get_bundle(param, m_large_group); + return bundle.total_size_in_bytes(); +} + +SmallVector +ConvBiasImpl::AlgoDotS8DirectStride1::dispatch_kerns( + FallbackConvBiasImpl*, const NCBKernSizeParam& param) const { + MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, 2, 1) { + return direct_dotprod_int8_stride1::get_kimpls(param, m_large_group); + } + MIDOUT_END(); + return {}; +} + +/* ===================== dot stride2 algo ======================== */ +bool ConvBiasImpl::AlgoDotS8DirectStride2::usable( + FallbackConvBiasImpl*, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const { + bool avaible = + direct_dotprod_int8_stride2::can_conv_direct_stride2_int8(param); + if (algo_selection_strategy == + ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { + bool large_group = param.filter_meta.group >= param.nr_threads; + avaible &= (large_group == m_large_group); + } + return avaible; +} + +size_t ConvBiasImpl::AlgoDotS8DirectStride2::get_workspace( + FallbackConvBiasImpl*, const NCBKernSizeParam& param) const { + auto bundle = direct_dotprod_int8_stride2::get_bundle(param, m_large_group); + return bundle.total_size_in_bytes(); +} + +SmallVector +ConvBiasImpl::AlgoDotS8DirectStride2::dispatch_kerns( + FallbackConvBiasImpl*, const NCBKernSizeParam& param) const { + MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, 2, 2) { + return direct_dotprod_int8_stride2::get_kimpls(param, m_large_group); + 
} + MIDOUT_END(); + return {}; +} +#endif + +/* ======================= AlgoS8WinogradF23_8x8 ======================== */ + +bool ConvBiasImpl::AlgoS8WinogradF23_8x8::usable( + fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, + AlgoSelectionStrategy /*algo_selection_strategy*/) const { + if (param.filter_meta.icpg % 8 != 0 || param.filter_meta.ocpg % 8 != 0) + return false; + using Strategy = winograd::winograd_2x3_8x8_s8; + Strategy strategy(param.src_type, param.filter_type, param.dst_type); + auto&& matmul_param = + megdnn::winograd::ConvBias( + strategy, m_tile_size, param.nr_threads, param.osz[0], + param.osz[1], param.filter_meta.ocpg) + .get_matmul_kern_param(param); + return m_matmul_algo->usable(matmul_param) && + ((opr->param().format == param::ConvBias::Format::NCHW && + param.filter_type.enumv() == DTypeEnum::QuantizedS8) || + (opr->param().format == param::ConvBias::Format::NCHW_WINOGRAD && + opr->param().output_block_size == 2 && + param.winograd_matmul_format == param::MatrixMul::Format::MK8 && + param.filter_type.enumv() == DTypeEnum::QuantizedS16)) && + opr->param().mode == param::ConvBias::Mode::CROSS_CORRELATION && + (param.filter_meta.spatial[0] == param.filter_meta.spatial[1] && + param.filter_meta.spatial[0] == 3) && + (param.filter_meta.stride[0] == param.filter_meta.stride[1] && + param.filter_meta.stride[0] == 1) && + (param.filter_meta.dilation[0] == param.filter_meta.dilation[1] && + param.filter_meta.dilation[0] == 1) && + param.compute_mode == param::ConvBias::ComputeMode::DEFAULT && + param.src_type.enumv() == DTypeEnum::QuantizedS8 && + param.bias_type.enumv() == DTypeEnum::QuantizedS32 && + param.dst_type.enumv() == DTypeEnum::QuantizedS8; +} + +size_t ConvBiasImpl::AlgoS8WinogradF23_8x8::get_workspace( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + winograd::winograd_2x3_8x8_s8 strategy(param.src_type, param.filter_type, + param.dst_type); + return megdnn::winograd::ConvBias( + strategy, m_tile_size, param.nr_threads, param.osz[0], + param.osz[1], param.filter_meta.ocpg) + .get_workspace_size(param, m_matmul_algo); +} + +SmallVector +ConvBiasImpl::AlgoS8WinogradF23_8x8::dispatch_kerns( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + MEGDNN_MARK_USED_VAR(param); + MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, 0, 2) { + winograd::winograd_2x3_8x8_s8 strategy( + param.src_type, param.filter_type, param.dst_type); + auto winograd_impl = + megdnn::winograd::ConvBias( + strategy, m_tile_size, param.nr_threads, param.osz[0], + param.osz[1], param.filter_meta.ocpg); + return winograd_impl.get_kerns(param, m_matmul_algo); + } + MIDOUT_END(); + return {}; +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8/algos.h b/dnn/src/arm_common/conv_bias/int8/algos.h new file mode 100644 index 00000000..b03551f9 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/algos.h @@ -0,0 +1,211 @@ +/** + * \file dnn/src/arm_common/conv_bias/int8/algos.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
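Editorial note, not part of the patch: the ARMDOTS8* stride-1 and stride-2 algorithms above are compiled only when __ARM_FEATURE_DOTPROD is defined, i.e. when the build targets a core with the armv8.2 dot-product extension. A minimal sketch of the instruction they are built around, the 4-way int8 dot product; this assumes an ACLE-conformant compiler providing vdotq_s32, and the wrapper name is illustrative.

#if __ARM_FEATURE_DOTPROD
#include <arm_neon.h>

// Each int32 lane i of acc accumulates the dot product of the matching
// 4 int8 lanes of a and b: acc[i] += a[4i]*b[4i] + ... + a[4i+3]*b[4i+3].
static inline int32x4_t dot16_s8(int32x4_t acc, int8x16_t a, int8x16_t b) {
    return vdotq_s32(acc, a, b);
}
#endif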
+ */ + +#pragma once + +#include "src/arm_common/conv_bias/opr_impl.h" + +namespace megdnn { +namespace arm_common { + +class ConvBiasImpl::AlgoS8DirectStride1 final : public AlgoBase { + bool m_large_group; + +public: + AlgoS8DirectStride1(bool large_group) : m_large_group(large_group) {} + bool is_reproducible() const override { return true; } + const char* name() const override { + return m_large_group ? "S8STRD1_LARGE_GROUP" : "S8STRD1_SMALL_GROUP"; + } + bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; + size_t get_workspace(fallback::ConvBiasImpl*, + const NCBKernSizeParam& param) const override; + virtual SmallVector dispatch_kerns( + fallback::ConvBiasImpl* opr, + const NCBKernSizeParam& param) const override; + + bool is_preferred(megdnn::fallback::ConvBiasImpl*, + const NCBKernSizeParam& param) const override; +}; + +class ConvBiasImpl::AlgoS8DirectStride1NCHW44 final : public AlgoBase { +public: + AlgoS8DirectStride1NCHW44() {} + bool is_reproducible() const override { return true; } + const char* name() const override { return "S8_NCHW44_DIRECT_STRD1"; } + bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; + size_t get_workspace(fallback::ConvBiasImpl*, + const NCBKernSizeParam& param) const override; + virtual SmallVector dispatch_kerns( + fallback::ConvBiasImpl* opr, + const NCBKernSizeParam& param) const override; + + bool is_preferred(megdnn::fallback::ConvBiasImpl*, + const NCBKernSizeParam& param) const override; +}; + +class ConvBiasImpl::AlgoS8DirectStride2 final : public AlgoBase { + bool m_large_group; + +public: + AlgoS8DirectStride2(bool large_group) : m_large_group(large_group) {} + bool is_reproducible() const override { return true; } + const char* name() const override { + return m_large_group ? 
"S8STRD2_LARGE_GROUP" : "S8STRD2_SMALL_GROUP"; + } + bool usable(fallback::ConvBiasImpl*, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; + + size_t get_workspace(fallback::ConvBiasImpl*, + const NCBKernSizeParam& param) const override; + virtual SmallVector dispatch_kerns( + fallback::ConvBiasImpl* opr, + const NCBKernSizeParam& param) const override; +}; + +class ConvBiasImpl::AlgoS8DirectStride2NCHW44 final : public AlgoBase { +public: + AlgoS8DirectStride2NCHW44() {} + bool is_reproducible() const override { return true; } + const char* name() const override { return "S8_NCHW44_DIRECT_STRD2"; } + bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; + size_t get_workspace(fallback::ConvBiasImpl*, + const NCBKernSizeParam& param) const override; + virtual SmallVector dispatch_kerns( + fallback::ConvBiasImpl* opr, + const NCBKernSizeParam& param) const override; + bool is_preferred(megdnn::fallback::ConvBiasImpl*, + const NCBKernSizeParam& param) const override; +}; + +class ConvBiasImpl::AlgoS8DirectStride2NCHWNCHW44 final : public AlgoBase { +public: + AlgoS8DirectStride2NCHWNCHW44() {} + bool is_reproducible() const override { return true; } + const char* name() const override { return "S8_CONV_NCHW_NCHW44"; } + bool usable(fallback::ConvBiasImpl*, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; + size_t get_workspace(fallback::ConvBiasImpl*, + const NCBKernSizeParam& param) const override; + virtual SmallVector dispatch_kerns( + fallback::ConvBiasImpl* opr, + const NCBKernSizeParam& param) const override; + bool is_preferred(megdnn::fallback::ConvBiasImpl*, + const NCBKernSizeParam& param) const override; +}; + +class ConvBiasImpl::AlgoS8ChanWiseStride1NCHW44 final : public AlgoBase { +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "S8_CHAN_WISE_STRD1_NCHW44"; } + bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; + size_t get_workspace(fallback::ConvBiasImpl*, + const NCBKernSizeParam& param) const override; + virtual SmallVector dispatch_kerns( + fallback::ConvBiasImpl* opr, + const NCBKernSizeParam& param) const override; +}; + +class ConvBiasImpl::AlgoS8ChanWiseStride2NCHW44 final : public AlgoBase { +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "S8_CHAN_WISE_STRD2_NCHW44"; } + bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; + size_t get_workspace(fallback::ConvBiasImpl*, + const NCBKernSizeParam& param) const override; + virtual SmallVector dispatch_kerns( + fallback::ConvBiasImpl* opr, + const NCBKernSizeParam& param) const override; +}; + +#if __ARM_FEATURE_DOTPROD +class ConvBiasImpl::AlgoDotS8DirectStride1 final : public AlgoBase { + bool m_large_group; + +public: + AlgoDotS8DirectStride1(bool large_group) : m_large_group(large_group) {} + + bool is_reproducible() const override { return true; } + const char* name() const override { + return m_large_group ? 
"ARMDOTS8STRD1_LARGE_GROUP" + : "ARMDOTS8STRD1_SMALL_GROUP"; + } + bool usable(FallbackConvBiasImpl*, const NCBKernSizeParam&, + AlgoSelectionStrategy algo_selection_strategy) const override; + + size_t get_workspace(FallbackConvBiasImpl*, + const NCBKernSizeParam&) const override; + virtual SmallVector dispatch_kerns( + fallback::ConvBiasImpl* opr, + const NCBKernSizeParam& param) const override; +}; + +class ConvBiasImpl::AlgoDotS8DirectStride2 final : public AlgoBase { + bool m_large_group; + +public: + AlgoDotS8DirectStride2(bool large_group) : m_large_group(large_group) {} + bool is_reproducible() const override { return true; } + const char* name() const override { + return m_large_group ? "ARMDOTS8STRD2_LARGE_GROUP" + : "ARMDOTS8STRD2_SMALL_GROUP"; + } + + bool usable(FallbackConvBiasImpl*, const NCBKernSizeParam&, + AlgoSelectionStrategy algo_selection_strategy) const override; + + size_t get_workspace(FallbackConvBiasImpl*, + const NCBKernSizeParam&) const override; + virtual SmallVector dispatch_kerns( + fallback::ConvBiasImpl* opr, + const NCBKernSizeParam& param) const override; +}; +#endif + +class ConvBiasImpl::AlgoS8WinogradF23_8x8 final : public AlgoBase { +public: + AlgoS8WinogradF23_8x8(fallback::MatrixMulImpl::AlgoBase* matmul_algo, + uint32_t tile_size) + : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {} + bool is_reproducible() const override { return true; } + const char* name() const override { + if (m_name.empty()) { + m_name = ConvBiasImpl::algo_name( + m_matmul_algo->name(), {8, 2, m_tile_size}); + } + return m_name.c_str(); + } + bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; + size_t get_workspace(fallback::ConvBiasImpl*, + const NCBKernSizeParam& param) const override; + virtual SmallVector dispatch_kerns( + fallback::ConvBiasImpl* opr, + const NCBKernSizeParam& param) const override; + static std::vector + get_avaiable_matmul_algos(const NCBKernSizeParam& param); + +private: + fallback::MatrixMulImpl::AlgoBase* m_matmul_algo; + mutable std::string m_name; + uint32_t m_tile_size; +}; + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8/channel_wise_kernel.cpp b/dnn/src/arm_common/conv_bias/int8/channel_wise_kernel.cpp new file mode 100644 index 00000000..b3f37cb3 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/channel_wise_kernel.cpp @@ -0,0 +1,1643 @@ +/** + * \file dnn/src/arm_common/conv_bias/int8/direct.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ + +#include "src/arm_common/conv_bias/int8/channel_wise_kernel.h" +#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/unroll_macro.h" +#include "src/common/utils.h" +#include "src/fallback/conv_bias/common.h" + +using namespace megdnn; +using namespace arm_common; + +static inline void accumulate_2_q_vector(int8x16_t& src0, int8x16_t& kern0, + int8x16_t& src1, int8x16_t& kern1, + int32x4_t* sum) { + int16x8_t tmp_sum0 = vmull_s8(vget_low_s8(src0), vget_low_s8(kern0)); + int16x8_t tmp_sum1 = vmull_high_s8(src0, kern0); + tmp_sum0 = vmlal_s8(tmp_sum0, vget_low_s8(src1), vget_low_s8(kern1)); + tmp_sum1 = vmlal_high_s8(tmp_sum1, src1, kern1); + sum[0] = vaddw_s16(sum[0], vget_low_s16(tmp_sum0)); + sum[1] = vaddw_s16(sum[1], vget_high_s16(tmp_sum0)); + sum[2] = vaddw_s16(sum[2], vget_low_s16(tmp_sum1)); + sum[3] = vaddw_s16(sum[3], vget_high_s16(tmp_sum1)); +} + +static inline void accumulate_1_q_vector(int8x16_t& src0, int8x16_t& kern0, + int32x4_t* sum) { + int16x8_t tmp_sum0 = vmull_s8(vget_low_s8(src0), vget_low_s8(kern0)); + int16x8_t tmp_sum1 = vmull_high_s8(src0, kern0); + sum[0] = vaddw_s16(sum[0], vget_low_s16(tmp_sum0)); + sum[1] = vaddw_s16(sum[1], vget_high_s16(tmp_sum0)); + sum[2] = vaddw_s16(sum[2], vget_low_s16(tmp_sum1)); + sum[3] = vaddw_s16(sum[3], vget_high_s16(tmp_sum1)); +} + +static inline void accumulate_2_d_vector(int8x16_t& src0, int8x8_t& kern0, + int8x16_t& src1, int8x8_t& kern1, + int32x4_t& sum0, int32x4_t& sum1) { + int16x8_t tmp_sum0 = vmull_s8(vget_low_s8(src0), kern0); + int16x8_t tmp_sum1 = vmull_s8(vget_high_s8(src0), kern0); + tmp_sum0 = vmlal_s8(tmp_sum0, vget_low_s8(src1), kern1); + tmp_sum1 = vmlal_s8(tmp_sum1, vget_high_s8(src1), kern1); + sum0 = vaddw_s16(sum0, vget_low_s16(tmp_sum0)); + sum1 = vaddw_s16(sum1, vget_low_s16(tmp_sum1)); + sum0 = vaddw_s16(sum0, vget_high_s16(tmp_sum0)); + sum1 = vaddw_s16(sum1, vget_high_s16(tmp_sum1)); +} + +static inline void accumulate_1_line_horizon(const int8x8_t& src0, + const int8x8_t& kern0, + const int8x8_t& src1, + const int8x8_t& kern1, + int32x4_t& sum) { + int16x8_t tmp_sum = vmull_s8(src0, kern0); + tmp_sum = vmlal_s8(tmp_sum, src1, kern1); + sum = vaddw_s16(sum, vget_low_s16(tmp_sum)); + sum = vaddw_s16(sum, vget_high_s16(tmp_sum)); +} + +static inline void accumulate_1_d_vector(const int8x8_t& src0, + const int8x8_t& kern0, + int32x4_t& sum) { + int16x8_t tmp_sum = vmull_s8(src0, kern0); + sum = vaddw_s16(sum, vget_low_s16(tmp_sum)); + sum = vaddw_s16(sum, vget_high_s16(tmp_sum)); +} + +#define ACC_S16_S32(sum, tmp_sum) \ + sum = vaddw_s16(sum, vget_low_s16(tmp_sum)); \ + sum = vaddw_s16(sum, vget_high_s16(tmp_sum)); + +#define STORE_1_LINE(dst, oh, ow, OW, sum) \ + if (quantized) { \ + dt_qint8* dptr = \ + reinterpret_cast(dst) + oh * OW * 4 + ow * 4; \ + op({{sum[0], sum[1]}}, dptr); \ + op({{sum[2], sum[3]}}, dptr + 8); \ + } else { \ + dt_int32* dptr = \ + reinterpret_cast(dst) + oh * OW * 4 + ow * 4; \ + vst1q_s32(dptr, sum[0]); \ + vst1q_s32(dptr + 4, sum[1]); \ + vst1q_s32(dptr + 8, sum[2]); \ + vst1q_s32(dptr + 12, sum[3]); \ + } + +#define STORE_1_LINE_REMAIN(dst, oh, ow, OW, sum, remain) \ + if (quantized) { \ + dt_qint8* dptr = \ + reinterpret_cast(dst) + oh * OW * 4 + ow * 4; \ + if (remain == 1) { \ + op(sum[0], dptr); \ + } else if (remain == 2) { \ + op({{sum[0], sum[1]}}, dptr); \ + } else if (remain == 3) { \ + op({{sum[0], sum[1]}}, dptr); \ + op(sum[2], dptr + 8); \ + } \ + } else { \ + dt_int32* dptr = \ + 
reinterpret_cast(dst) + oh * OW * 4 + ow * 4; \ + if (remain == 1) { \ + vst1q_s32(dptr, sum[0]); \ + } else if (remain == 2) { \ + vst1q_s32(dptr, sum[0]); \ + vst1q_s32(dptr + 4, sum[1]); \ + } else if (remain == 3) { \ + vst1q_s32(dptr, sum[0]); \ + vst1q_s32(dptr + 4, sum[1]); \ + vst1q_s32(dptr + 8, sum[2]); \ + } \ + } + +template +void channel_wise_nchw44::direct_stride1_2x2_int8( + const int8_t* src, const int8_t* filter, const int32_t* bias, void* dst, + const size_t IH, const size_t IW, const size_t OH, const size_t OW, + const Op& op) { + MEGDNN_MARK_USED_VAR(IH); + int8x8_t kern01 = vld1_s8(filter); + int8x8_t kern23 = vld1_s8(filter + 8); + size_t oh = 0_z; + for (; oh + 2 <= OH; oh += 2) { + size_t ih = oh; + size_t ow = 0_z; + for (; ow + 4 <= OW; ow += 4) { + size_t iw = ow; + const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4; + const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4; + const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4; + const int32_t* __restrict bptr = bias; + int32x4_t sum00; + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + sum00 = vld1q_s32(bptr); + } else { + sum00 = vdupq_n_s32(0); + } + int32x4_t sum01 = sum00, sum02 = sum00, sum03 = sum00, + sum10 = sum00, sum11 = sum00, sum12 = sum00, + sum13 = sum00; + int8x16_t src0 = vld1q_s8(sptr0); + int8x8_t src03 = vld1_s8(sptr0 + 3 * 4), src00 = vget_low_s8(src0), + src02 = vget_high_s8(src0); + int8x8_t src01 = vext_s8(src00, src02, 4); + + int8x16_t src1 = vld1q_s8(sptr1); + int8x8_t src13 = vld1_s8(sptr1 + 3 * 4), src10 = vget_low_s8(src1), + src12 = vget_high_s8(src1); + int8x8_t src11 = vext_s8(src10, src12, 4); + + int8x16_t src2 = vld1q_s8(sptr2); + int8x8_t src23 = vld1_s8(sptr2 + 3 * 4), src20 = vget_low_s8(src2), + src22 = vget_high_s8(src2); + int8x8_t src21 = vext_s8(src20, src22, 4); + //! first line + int16x8_t tmp_sum00 = vmull_s8(src00, kern01); + tmp_sum00 = vmlal_s8(tmp_sum00, src10, kern23); + ACC_S16_S32(sum00, tmp_sum00); + + int16x8_t tmp_sum01 = vmull_s8(src01, kern01); + tmp_sum01 = vmlal_s8(tmp_sum01, src11, kern23); + ACC_S16_S32(sum01, tmp_sum01); + + int16x8_t tmp_sum02 = vmull_s8(src02, kern01); + tmp_sum02 = vmlal_s8(tmp_sum02, src12, kern23); + ACC_S16_S32(sum02, tmp_sum02); + + int16x8_t tmp_sum03 = vmull_s8(src03, kern01); + tmp_sum03 = vmlal_s8(tmp_sum03, src13, kern23); + ACC_S16_S32(sum03, tmp_sum03); + //! 
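+            //! note: input row ih + 1 (src1x) feeds both outputs: it is the
+            //! second filter row for output row oh and the first filter row
+            //! for output row oh + 1, so three loaded input rows yield two
+            //! output rows per iteration,
+            //!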
second line + int16x8_t tmp_sum10 = vmull_s8(src10, kern01); + tmp_sum10 = vmlal_s8(tmp_sum10, src20, kern23); + ACC_S16_S32(sum10, tmp_sum10); + + int16x8_t tmp_sum11 = vmull_s8(src11, kern01); + tmp_sum11 = vmlal_s8(tmp_sum11, src21, kern23); + ACC_S16_S32(sum11, tmp_sum11); + + int16x8_t tmp_sum12 = vmull_s8(src12, kern01); + tmp_sum12 = vmlal_s8(tmp_sum12, src22, kern23); + ACC_S16_S32(sum12, tmp_sum12); + + int16x8_t tmp_sum13 = vmull_s8(src13, kern01); + tmp_sum13 = vmlal_s8(tmp_sum13, src23, kern23); + ACC_S16_S32(sum13, tmp_sum13); + if (quantized) { + dt_qint8* dptr = + reinterpret_cast(dst) + oh * OW * 4 + ow * 4; + op({{sum00, sum01}}, dptr); + op({{sum02, sum03}}, dptr + 8); + op({{sum10, sum11}}, dptr + OW * 4); + op({{sum12, sum13}}, dptr + OW * 4 + 8); + } else { + dt_int32* dptr = + reinterpret_cast(dst) + oh * OW * 4 + ow * 4; + vst1q_s32(dptr, sum00); + vst1q_s32(dptr + 4, sum01); + vst1q_s32(dptr + 8, sum02); + vst1q_s32(dptr + 12, sum03); + vst1q_s32(dptr + OW * 4, sum10); + vst1q_s32(dptr + OW * 4 + 4, sum11); + vst1q_s32(dptr + OW * 4 + 8, sum12); + vst1q_s32(dptr + OW * 4 + 12, sum13); + } + } + for (; ow < OW; ow++) { + size_t iw = ow; + const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4; + const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4; + const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4; + const int32_t* __restrict bptr = bias; + int32x4_t sum00; + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + sum00 = vld1q_s32(bptr); + } else { + sum00 = vdupq_n_s32(0); + } + int32x4_t sum10 = sum00; + int8x8_t src00 = vld1_s8(sptr0); + int8x8_t src10 = vld1_s8(sptr1); + int8x8_t src20 = vld1_s8(sptr2); + + int16x8_t tmp_sum00 = vmull_s8(src00, kern01); + tmp_sum00 = vmlal_s8(tmp_sum00, src10, kern23); + ACC_S16_S32(sum00, tmp_sum00); + + int16x8_t tmp_sum10 = vmull_s8(src10, kern01); + tmp_sum10 = vmlal_s8(tmp_sum10, src20, kern23); + ACC_S16_S32(sum10, tmp_sum10); + + if (quantized) { + dt_qint8* dptr = + reinterpret_cast(dst) + oh * OW * 4 + ow * 4; + op(sum00, dptr); + op(sum10, dptr + OW * 4); + } else { + dt_int32* dptr = + reinterpret_cast(dst) + oh * OW * 4 + ow * 4; + vst1q_s32(dptr, sum00); + vst1q_s32(dptr + OW * 4, sum10); + } + } + } + for (; oh < OH; oh++) { + size_t ih = oh; + size_t ow = 0_z; + for (; ow + 4 <= OW; ow += 4) { + size_t iw = ow; + const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4; + const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4; + const int32_t* __restrict bptr = bias; + int32x4_t sum00; + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + sum00 = vld1q_s32(bptr); + } else { + sum00 = vdupq_n_s32(0); + } + int32x4_t sum01 = sum00, sum02 = sum00, sum03 = sum00; + + int8x16_t src0 = vld1q_s8(sptr0); + int8x8_t src03 = vld1_s8(sptr0 + 3 * 4), src00 = vget_low_s8(src0), + src02 = vget_high_s8(src0); + int8x8_t src01 = vext_s8(src00, src02, 4); + + int8x16_t src1 = vld1q_s8(sptr1); + int8x8_t src13 = vld1_s8(sptr1 + 3 * 4), src10 = vget_low_s8(src1), + src12 = vget_high_s8(src1); + int8x8_t src11 = vext_s8(src10, src12, 4); + + int16x8_t tmp_sum00 = vmull_s8(src00, kern01); + tmp_sum00 = vmlal_s8(tmp_sum00, src10, kern23); + ACC_S16_S32(sum00, tmp_sum00); + + int16x8_t tmp_sum01 = vmull_s8(src01, kern01); + tmp_sum01 = vmlal_s8(tmp_sum01, src11, kern23); + ACC_S16_S32(sum01, tmp_sum01); + + int16x8_t tmp_sum02 = vmull_s8(src02, kern01); + tmp_sum02 = vmlal_s8(tmp_sum02, src12, kern23); + ACC_S16_S32(sum02, tmp_sum02); + + int16x8_t tmp_sum03 = vmull_s8(src03, kern01); 
+ tmp_sum03 = vmlal_s8(tmp_sum03, src13, kern23); + ACC_S16_S32(sum03, tmp_sum03); + + if (quantized) { + dt_qint8* dptr = + reinterpret_cast(dst) + oh * OW * 4 + ow * 4; + op({{sum00, sum01}}, dptr); + op({{sum02, sum03}}, dptr + 8); + } else { + dt_int32* dptr = + reinterpret_cast(dst) + oh * OW * 4 + ow * 4; + vst1q_s32(dptr, sum00); + vst1q_s32(dptr + 4, sum01); + vst1q_s32(dptr + 8, sum02); + vst1q_s32(dptr + 12, sum03); + } + } + for (; ow < OW; ow++) { + size_t iw = ow; + const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4; + const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4; + const int32_t* __restrict bptr = bias; + int32x4_t sum00; + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + sum00 = vld1q_s32(bptr); + } else { + sum00 = vdupq_n_s32(0); + } + int8x8_t src00 = vld1_s8(sptr0); + int8x8_t src10 = vld1_s8(sptr1); + int16x8_t tmp_sum00 = vmull_s8(src00, kern01); + tmp_sum00 = vmlal_s8(tmp_sum00, src10, kern23); + ACC_S16_S32(sum00, tmp_sum00); + if (quantized) { + dt_qint8* dptr = + reinterpret_cast(dst) + oh * OW * 4 + ow * 4; + op(sum00, dptr); + } else { + dt_int32* dptr = + reinterpret_cast(dst) + oh * OW * 4 + ow * 4; + vst1q_s32(dptr, sum00); + } + } + } +} +#undef ACC_S16_S32 + +template +void channel_wise_nchw44::direct_stride1_3x3_int8( + const int8_t* sptr, const int8_t* fptr, const int32_t* bias, void* dst, + const size_t IH, const size_t IW, const size_t OH, const size_t OW, + const Op& op) { + MEGDNN_MARK_USED_VAR(IH); + const int32_t* __restrict bptr = bias; + int32x4_t init_v; + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + init_v = vld1q_s32(bptr); + } else { + init_v = vdupq_n_s32(0); + } + const int* filter = reinterpret_cast(fptr); + int8x16_t kern[9]; +#define cb(i) kern[i] = (int8x16_t)vld1q_dup_s32(filter + i); + UNROLL_CALL_NOWRAPPER(9, cb); +#undef cb + +#define LOAD_2_LINE_SRC(sptr0, sptr1) \ + src[0][0] = vld1q_s8(sptr0); \ + src[0][2] = vld1q_s8(sptr0 + 16); \ + src[1][0] = vld1q_s8(sptr1); \ + src[1][2] = vld1q_s8(sptr1 + 16); \ + src[0][1] = vextq_s8(src[0][0], src[0][2], 4); \ + src[1][1] = vextq_s8(src[1][0], src[1][2], 4); \ + src[0][2] = vextq_s8(src[0][0], src[0][2], 8); \ + src[1][2] = vextq_s8(src[1][0], src[1][2], 8); + +#define LOAD_1_LINE_SRC(sptr0, src) \ + src[0] = vld1q_s8(sptr0); \ + src[2] = vld1q_s8(sptr0 + 16); \ + src[1] = vextq_s8(src[0], src[2], 4); \ + src[2] = vextq_s8(src[0], src[2], 8); + +#define ACC_1_LINE(src, kern0, kern1, kern2, sum) \ + accumulate_2_q_vector(src[0], kern0, src[1], kern1, sum); \ + accumulate_1_q_vector(src[2], kern2, sum); + + size_t oh = 0_z; + for (; oh + 3 <= OH; oh += 3) { + size_t ih = oh; + size_t ow = 0_z; + for (; ow + 4 <= OW; ow += 4) { + size_t iw = ow; + const int8_t* __restrict sptr0 = sptr + ih * IW * 4 + iw * 4; + const int8_t* __restrict sptr1 = sptr + (ih + 1) * IW * 4 + iw * 4; + const int8_t* __restrict sptr2 = sptr + (ih + 2) * IW * 4 + iw * 4; + const int8_t* __restrict sptr3 = sptr + (ih + 3) * IW * 4 + iw * 4; + const int8_t* __restrict sptr4 = sptr + (ih + 4) * IW * 4 + iw * 4; + int32x4_t sum0[4], sum1[4], sum2[4]; +#define cb(j) \ + sum0[j] = init_v; \ + sum1[j] = init_v; \ + sum2[j] = init_v; + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb +//! gcc will report error of "more than 30 operands in 'asm'" +#if MEGDNN_AARCH64 && defined(__clang__) + asm volatile( + //! load src 0,1 + "ldr q21, [%[sptr0]]\n" + "ldr q24, [%[sptr1]]\n" + + //! 
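+                    //! overview: the block below accumulates three output rows
+                    //! (sum0/sum1/sum2) of four NCHW44 pixels each from five
+                    //! input rows, with source data held in v21..v26/v31 and
+                    //! 16-bit partial products in v27..v30;
+                    //!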
sum0 line<0,1> + "smull v27.8h, v21.8b, %[k0].8b\n" + "ldr q23, [%[sptr0], #16]\n" + "smull2 v28.8h, v21.16b, %[k0].16b\n" + "ldr q26, [%[sptr1], #16]\n" + "smlal v27.8h, v24.8b, %[k3].8b\n" + "ext v22.16b, v21.16b, v23.16b, #4\n" + "smlal2 v28.8h, v24.16b, %[k3].16b\n" + "ext v23.16b, v21.16b, v23.16b, #8\n" + "saddw %[sum00].4s, %[sum00].4s, v27.4h\n" + "ext v25.16b, v24.16b, v26.16b, #4\n" + "saddw2 %[sum01].4s, %[sum01].4s, v27.8h\n" + "saddw %[sum02].4s, %[sum02].4s, v28.4h\n" + "ext v26.16b, v24.16b, v26.16b, #8\n" + "saddw2 %[sum03].4s, %[sum03].4s, v28.8h\n" + + "ldr q21, [%[sptr2]]\n" + "smull v29.8h, v22.8b, %[k1].8b\n" + "smull2 v30.8h, v22.16b, %[k1].16b\n" + "ldr q31, [%[sptr2], #16]\n" + "smull v27.8h, v23.8b, %[k2].8b\n" + "ext v22.16b, v21.16b, v31.16b, #4\n" + "smull2 v28.8h, v23.16b, %[k2].16b\n" + "ext v23.16b, v21.16b, v31.16b, #8\n" + "smlal v29.8h, v25.8b, %[k4].8b\n" + "smlal2 v30.8h, v25.16b, %[k4].16b\n" + "saddw %[sum00].4s, %[sum00].4s, v29.4h\n" + "smlal v27.8h, v26.8b, %[k5].8b\n" + "saddw2 %[sum01].4s, %[sum01].4s, v29.8h\n" + "smlal2 v28.8h, v26.16b, %[k5].16b\n" + "saddw %[sum02].4s, %[sum02].4s, v30.4h\n" + "saddw2 %[sum03].4s, %[sum03].4s, v30.8h\n" + //! load src 2 + + //! sum0 line<2> + "smull v29.8h, v21.8b, %[k6].8b\n" + "saddw %[sum00].4s, %[sum00].4s, v27.4h\n" + "smull2 v30.8h, v21.16b, %[k6].16b\n" + "saddw2 %[sum01].4s, %[sum01].4s, v27.8h\n" + "smull v27.8h, v23.8b, %[k8].8b\n" + "saddw %[sum02].4s, %[sum02].4s, v28.4h\n" + "smlal v29.8h, v22.8b, %[k7].8b\n" + "saddw2 %[sum03].4s, %[sum03].4s, v28.8h\n" + "smlal2 v30.8h, v22.16b, %[k7].16b\n" + "saddw %[sum00].4s, %[sum00].4s, v29.4h\n" + "smull2 v28.8h, v23.16b, %[k8].16b\n" + "saddw2 %[sum01].4s, %[sum01].4s, v29.8h\n" + "saddw %[sum02].4s, %[sum02].4s, v30.4h\n" + "saddw2 %[sum03].4s, %[sum03].4s, v30.8h\n" + + //! sum1 line<0,1> + "saddw2 %[sum03].4s, %[sum03].4s, v28.8h\n" + "smull v29.8h, v24.8b, %[k0].8b\n" + "saddw %[sum00].4s, %[sum00].4s, v27.4h\n" + "smull2 v30.8h, v24.16b, %[k0].16b\n" + "saddw2 %[sum01].4s, %[sum01].4s, v27.8h\n" + "smull v27.8h, v25.8b, %[k1].8b\n" + "saddw %[sum02].4s, %[sum02].4s, v28.4h\n" + "smull2 v28.8h, v25.16b, %[k1].16b\n" + "smlal v29.8h, v21.8b, %[k3].8b\n" + "smlal2 v30.8h, v21.16b, %[k3].16b\n" + "saddw %[sum10].4s, %[sum10].4s, v29.4h\n" + "smlal v27.8h, v22.8b, %[k4].8b\n" + "saddw2 %[sum11].4s, %[sum11].4s, v29.8h\n" + "smlal2 v28.8h, v22.16b, %[k4].16b\n" + "saddw %[sum12].4s, %[sum12].4s, v30.4h\n" + "saddw2 %[sum13].4s, %[sum13].4s, v30.8h\n" + + "ldr q24, [%[sptr3]]\n" + "smull v29.8h, v26.8b, %[k2].8b\n" + "saddw %[sum10].4s, %[sum10].4s, v27.4h\n" + "smull2 v30.8h, v26.16b, %[k2].16b\n" + "saddw2 %[sum11].4s, %[sum11].4s, v27.8h\n" + "smlal v29.8h, v23.8b, %[k5].8b\n" + "saddw %[sum12].4s, %[sum12].4s, v28.4h\n" + "smlal2 v30.8h, v23.16b, %[k5].16b\n" + "saddw2 %[sum13].4s, %[sum13].4s, v28.8h\n" + "ldr q26, [%[sptr3], #16]\n" + "saddw %[sum10].4s, %[sum10].4s, v29.4h\n" + "ext v25.16b, v24.16b, v26.16b, #4\n" + "saddw2 %[sum11].4s, %[sum11].4s, v29.8h\n" + "ext v26.16b, v24.16b, v26.16b, #8\n" + "saddw %[sum12].4s, %[sum12].4s, v30.4h\n" + //! src line 3 + + //! 
sum1 line<2> + "smull v27.8h, v24.8b, %[k6].8b\n" + "saddw2 %[sum13].4s, %[sum13].4s, v30.8h\n" + "smull2 v28.8h, v24.16b, %[k6].16b\n" + "smlal v27.8h, v25.8b, %[k7].8b\n" + "smlal2 v28.8h, v25.16b, %[k7].16b\n" + "saddw %[sum10].4s, %[sum10].4s, v27.4h\n" + "saddw2 %[sum11].4s, %[sum11].4s, v27.8h\n" + + "smull v29.8h, v26.8b, %[k8].8b\n" + "saddw %[sum12].4s, %[sum12].4s, v28.4h\n" + "smull2 v30.8h, v26.16b, %[k8].16b\n" + "saddw2 %[sum13].4s, %[sum13].4s, v28.8h\n" + + //! sum2 line<0,1> + "smull v27.8h, v21.8b, %[k0].8b\n" + "saddw %[sum10].4s, %[sum10].4s, v29.4h\n" + "smull2 v28.8h, v21.16b, %[k0].16b\n" + "saddw2 %[sum11].4s, %[sum11].4s, v29.8h\n" + "smull v29.8h, v22.8b, %[k1].8b\n" + "saddw %[sum12].4s, %[sum12].4s, v30.4h\n" + "smlal v27.8h, v24.8b, %[k3].8b\n" + "saddw2 %[sum13].4s, %[sum13].4s, v30.8h\n" + "smull2 v30.8h, v22.16b, %[k1].16b\n" + "ldr q21, [%[sptr4]]\n" + "saddw %[sum20].4s, %[sum20].4s, v27.4h\n" + "smlal2 v28.8h, v24.16b, %[k3].16b\n" + "saddw2 %[sum21].4s, %[sum21].4s, v27.8h\n" + "smlal v29.8h, v25.8b, %[k4].8b\n" + "saddw %[sum22].4s, %[sum22].4s, v28.4h\n" + "smlal2 v30.8h, v25.16b, %[k4].16b\n" + "saddw2 %[sum23].4s, %[sum23].4s, v28.8h\n" + + "smull v27.8h, v23.8b, %[k2].8b\n" + "saddw %[sum20].4s, %[sum20].4s, v29.4h\n" + "smull2 v28.8h, v23.16b, %[k2].16b\n" + "saddw2 %[sum21].4s, %[sum21].4s, v29.8h\n" + "ldr q23, [%[sptr4], #16]\n" + "smlal v27.8h, v26.8b, %[k5].8b\n" + "saddw %[sum22].4s, %[sum22].4s, v30.4h\n" + "smlal2 v28.8h, v26.16b, %[k5].16b\n" + "saddw2 %[sum23].4s, %[sum23].4s, v30.8h\n" + "ext v22.16b, v21.16b, v23.16b, #4\n" + "saddw %[sum20].4s, %[sum20].4s, v27.4h\n" + "ext v23.16b, v21.16b, v23.16b, #8\n" + "saddw2 %[sum21].4s, %[sum21].4s, v27.8h\n" + //! src line 3 + + //! sum2 line<2> + "smull v29.8h, v21.8b, %[k6].8b\n" + "saddw %[sum22].4s, %[sum22].4s, v28.4h\n" + "smull2 v30.8h, v21.16b, %[k6].16b\n" + "saddw2 %[sum23].4s, %[sum23].4s, v28.8h\n" + "smull v27.8h, v23.8b, %[k8].8b\n" + "smull2 v28.8h, v23.16b, %[k8].16b\n" + "smlal v29.8h, v22.8b, %[k7].8b\n" + "smlal2 v30.8h, v22.16b, %[k7].16b\n" + "saddw %[sum20].4s, %[sum20].4s, v29.4h\n" + "saddw2 %[sum21].4s, %[sum21].4s, v29.8h\n" + "saddw %[sum22].4s, %[sum22].4s, v30.4h\n" + "saddw2 %[sum23].4s, %[sum23].4s, v30.8h\n" + "saddw %[sum20].4s, %[sum20].4s, v27.4h\n" + "saddw2 %[sum21].4s, %[sum21].4s, v27.8h\n" + "saddw %[sum22].4s, %[sum22].4s, v28.4h\n" + "saddw2 %[sum23].4s, %[sum23].4s, v28.8h\n" + : [k0] "+w"(kern[0]), [k1] "+w"(kern[1]), + [k2] "+w"(kern[2]), [k3] "+w"(kern[3]), + [k4] "+w"(kern[4]), [k5] "+w"(kern[5]), + [k6] "+w"(kern[6]), [k7] "+w"(kern[7]), + [k8] "+w"(kern[8]), [sum00] "+w"(sum0[0]), + [sum01] "+w"(sum0[1]), [sum02] "+w"(sum0[2]), + [sum03] "+w"(sum0[3]), [sum10] "+w"(sum1[0]), + [sum11] "+w"(sum1[1]), [sum12] "+w"(sum1[2]), + [sum13] "+w"(sum1[3]), [sum20] "+w"(sum2[0]), + [sum21] "+w"(sum2[1]), [sum22] "+w"(sum2[2]), + [sum23] "+w"(sum2[3]), [sptr0] "+r"(sptr0), + [sptr1] "+r"(sptr1), [sptr2] "+r"(sptr2), + [sptr3] "+r"(sptr3), [sptr4] "+r"(sptr4) + : + : "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", + "v29", "v30", "v31", "cc", "memory"); + + STORE_1_LINE(dst, (oh), ow, OW, sum0); + STORE_1_LINE(dst, (oh + 1), ow, OW, sum1); + STORE_1_LINE(dst, (oh + 2), ow, OW, sum2); +#else + int8x16_t src[2][3]; + LOAD_2_LINE_SRC(sptr0, sptr1); + + accumulate_2_q_vector(src[0][0], kern[0], src[1][0], kern[3], sum0); + accumulate_2_q_vector(src[0][1], kern[1], src[1][1], kern[4], sum0); + accumulate_2_q_vector(src[0][2], kern[2], src[1][2], kern[5], 
sum0); + + LOAD_1_LINE_SRC(sptr2, src[0]); + + ACC_1_LINE(src[0], kern[6], kern[7], kern[8], sum0); + + accumulate_2_q_vector(src[1][0], kern[0], src[0][0], kern[3], sum1); + accumulate_2_q_vector(src[1][1], kern[1], src[0][1], kern[4], sum1); + accumulate_2_q_vector(src[1][2], kern[2], src[0][2], kern[5], sum1); + + STORE_1_LINE(dst, oh, ow, OW, sum0); + + LOAD_1_LINE_SRC(sptr3, src[1]); + ACC_1_LINE(src[1], kern[6], kern[7], kern[8], sum1); + + accumulate_2_q_vector(src[0][0], kern[0], src[1][0], kern[3], sum2); + accumulate_2_q_vector(src[0][1], kern[1], src[1][1], kern[4], sum2); + accumulate_2_q_vector(src[0][2], kern[2], src[1][2], kern[5], sum2); + + STORE_1_LINE(dst, (oh + 1), ow, OW, sum1); + LOAD_1_LINE_SRC(sptr4, src[0]); + ACC_1_LINE(src[0], kern[6], kern[7], kern[8], sum2); + + STORE_1_LINE(dst, (oh + 2), ow, OW, sum2); +#endif + } + if (ow < OW) { + size_t iw = ow; + size_t remain = OW - ow; + const int8_t* __restrict sptr0 = sptr + ih * IW * 4 + iw * 4; + const int8_t* __restrict sptr1 = sptr + (ih + 1) * IW * 4 + iw * 4; + const int8_t* __restrict sptr2 = sptr + (ih + 2) * IW * 4 + iw * 4; + const int8_t* __restrict sptr3 = sptr + (ih + 3) * IW * 4 + iw * 4; + const int8_t* __restrict sptr4 = sptr + (ih + 4) * IW * 4 + iw * 4; + int32x4_t sum0[4], sum1[4], sum2[4]; + int8x16_t src[2][3]; +#define cb(j) \ + sum0[j] = init_v; \ + sum1[j] = init_v; \ + sum2[j] = init_v; + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + LOAD_2_LINE_SRC(sptr0, sptr1); + + accumulate_2_q_vector(src[0][0], kern[0], src[1][0], kern[3], sum0); + accumulate_2_q_vector(src[0][1], kern[1], src[1][1], kern[4], sum0); + accumulate_2_q_vector(src[0][2], kern[2], src[1][2], kern[5], sum0); + + LOAD_1_LINE_SRC(sptr2, src[0]); + ACC_1_LINE(src[0], kern[6], kern[7], kern[8], sum0); + + accumulate_2_q_vector(src[1][0], kern[0], src[0][0], kern[3], sum1); + accumulate_2_q_vector(src[1][1], kern[1], src[0][1], kern[4], sum1); + accumulate_2_q_vector(src[1][2], kern[2], src[0][2], kern[5], sum1); + + STORE_1_LINE_REMAIN(dst, oh, ow, OW, sum0, remain); + + LOAD_1_LINE_SRC(sptr3, src[1]); + ACC_1_LINE(src[1], kern[6], kern[7], kern[8], sum1); + + accumulate_2_q_vector(src[0][0], kern[0], src[1][0], kern[3], sum2); + accumulate_2_q_vector(src[0][1], kern[1], src[1][1], kern[4], sum2); + accumulate_2_q_vector(src[0][2], kern[2], src[1][2], kern[5], sum2); + + STORE_1_LINE_REMAIN(dst, (oh + 1), ow, OW, sum1, remain); + LOAD_1_LINE_SRC(sptr4, src[0]); + ACC_1_LINE(src[0], kern[6], kern[7], kern[8], sum2); + + STORE_1_LINE_REMAIN(dst, (oh + 2), ow, OW, sum2, remain); + } + } + for (; oh < OH; oh++) { + size_t ih = oh; + size_t ow = 0_z; + for (; ow + 4 <= OW; ow += 4) { + size_t iw = ow; + const int8_t* __restrict sptr0 = sptr + ih * IW * 4 + iw * 4; + const int8_t* __restrict sptr1 = sptr + (ih + 1) * IW * 4 + iw * 4; + const int8_t* __restrict sptr2 = sptr + (ih + 2) * IW * 4 + iw * 4; + int32x4_t sum0[4]; + int8x16_t src[2][3]; +#define cb(i) sum0[i] = init_v; + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + LOAD_2_LINE_SRC(sptr0, sptr1); + accumulate_2_q_vector(src[0][0], kern[0], src[1][0], kern[3], sum0); + accumulate_2_q_vector(src[0][1], kern[1], src[1][1], kern[4], sum0); + accumulate_2_q_vector(src[0][2], kern[2], src[1][2], kern[5], sum0); + LOAD_1_LINE_SRC(sptr2, src[0]); + ACC_1_LINE(src[0], kern[6], kern[7], kern[8], sum0); + STORE_1_LINE(dst, oh, ow, OW, sum0); + } + if (ow < OW) { + size_t iw = ow; + const int8_t* __restrict sptr0 = sptr + ih * IW * 4 + iw * 4; + const int8_t* __restrict sptr1 = sptr + (ih + 1) * 
IW * 4 + iw * 4; + const int8_t* __restrict sptr2 = sptr + (ih + 2) * IW * 4 + iw * 4; + int32x4_t sum0[4]; + int8x16_t src[2][3]; +#define cb(i) sum0[i] = init_v; + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + LOAD_2_LINE_SRC(sptr0, sptr1); + accumulate_2_q_vector(src[0][0], kern[0], src[1][0], kern[3], sum0); + accumulate_2_q_vector(src[0][1], kern[1], src[1][1], kern[4], sum0); + accumulate_2_q_vector(src[0][2], kern[2], src[1][2], kern[5], sum0); + LOAD_1_LINE_SRC(sptr2, src[0]); + ACC_1_LINE(src[0], kern[6], kern[7], kern[8], sum0); + STORE_1_LINE_REMAIN(dst, oh, ow, OW, sum0, (OW - ow)); + } + } +#undef LOAD_1_LINE_SRC +#undef LOAD_2_LINE_SRC +#undef ACC_1_LINE +} + +template +void channel_wise_nchw44::direct_stride1_5x5_int8( + const int8_t* sptr, const int8_t* fptr, const int32_t* bias, void* dst, + const size_t IH, const size_t IW, const size_t OH, const size_t OW, + const Op& op) { + MEGDNN_MARK_USED_VAR(IH); + const int32_t* __restrict bptr = bias; + int32x4_t init_v; + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + init_v = vld1q_s32(bptr); + } else { + init_v = vdupq_n_s32(0); + } + const int* filter = reinterpret_cast(fptr); + +#define LOAD_1_LINE_SRC(sptr, src) \ + src[0] = vld1q_s8(sptr); \ + src[4] = vld1q_s8(sptr + 16); \ + src[1] = vextq_s8(src[0], src[4], 4); \ + src[2] = vextq_s8(src[0], src[4], 8); \ + src[3] = vextq_s8(src[0], src[4], 12); + +#define ACC_1_LINE(src, kern, sum) \ + accumulate_2_q_vector(src[0], kern[0], src[1], kern[1], sum); \ + accumulate_2_q_vector(src[2], kern[2], src[3], kern[3], sum); \ + accumulate_1_q_vector(src[4], kern[4], sum); + + size_t oh = 0_z; + for (; oh + 2 <= OH; oh += 2) { + size_t ih = oh; + size_t ow = 0_z; + for (; ow + 4 <= OW; ow += 4) { + size_t iw = ow; + const int8_t* __restrict sptr0 = sptr + ih * IW * 4 + iw * 4; + const int8_t* __restrict sptr1 = sptr0 + IW * 4; + const int8_t* __restrict sptr2 = sptr1 + IW * 4; + const int8_t* __restrict sptr3 = sptr2 + IW * 4; + const int8_t* __restrict sptr4 = sptr3 + IW * 4; + const int8_t* __restrict sptr5 = sptr4 + IW * 4; + int32x4_t sum0[4], sum1[4]; + int8x16_t src[2][5]; + int8x16_t kern[2][5]; +#define cb(j) \ + sum0[j] = init_v; \ + sum1[j] = init_v; + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + + //! first two line in filter +#define cb(i, kern, filter) kern[i] = (int8x16_t)vld1q_dup_s32((filter) + i); + UNROLL_CALL(5, cb, kern[0], filter); + UNROLL_CALL(5, cb, kern[1], (filter + 5)); +#undef cb + LOAD_1_LINE_SRC(sptr0, src[0]); + LOAD_1_LINE_SRC(sptr1, src[1]); +#define cb(i, sum) \ + accumulate_2_q_vector(src[0][i], kern[0][i], src[1][i], kern[1][i], sum); + UNROLL_CALL(5, cb, sum0); +#undef cb + + LOAD_1_LINE_SRC(sptr2, src[0]); + +#define cb(i, sum) \ + accumulate_2_q_vector(src[1][i], kern[0][i], src[0][i], kern[1][i], sum); + UNROLL_CALL(5, cb, sum1); +#undef cb + //! second two line in filter + LOAD_1_LINE_SRC(sptr3, src[1]); + +#define cb(i, kern, filter) kern[i] = (int8x16_t)vld1q_dup_s32((filter) + i); + UNROLL_CALL(5, cb, kern[0], filter + 10); + UNROLL_CALL(5, cb, kern[1], (filter + 15)); +#undef cb +#define cb(i, sum) \ + accumulate_2_q_vector(src[0][i], kern[0][i], src[1][i], kern[1][i], sum); + UNROLL_CALL(5, cb, sum0); +#undef cb + LOAD_1_LINE_SRC(sptr4, src[0]); + +#define cb(i, sum) \ + accumulate_2_q_vector(src[1][i], kern[0][i], src[0][i], kern[1][i], sum); + UNROLL_CALL(5, cb, sum1); +#undef cb + //! 
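+            //! src[0] now holds input row ih + 4 and src[1] (row ih + 3, no
+            //! longer needed) is about to be overwritten with row ih + 5: the
+            //! two buffers ping-pong so interior rows are loaded once and
+            //! reused by both output rows,
+            //!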
last line in filter +#define cb(i, kern, filter) kern[i] = (int8x16_t)vld1q_dup_s32((filter) + i); + UNROLL_CALL(5, cb, kern[0], filter + 20); +#undef cb + + ACC_1_LINE(src[0], kern[0], sum0); + + LOAD_1_LINE_SRC(sptr5, src[1]); + + ACC_1_LINE(src[1], kern[0], sum1); + + STORE_1_LINE(dst, oh, ow, OW, sum0); + STORE_1_LINE(dst, (oh + 1), ow, OW, sum1); + } + if (ow < OW) { + size_t remain = OW - ow; + size_t iw = ow; + const int8_t* __restrict sptr0 = sptr + ih * IW * 4 + iw * 4; + const int8_t* __restrict sptr1 = sptr0 + IW * 4; + const int8_t* __restrict sptr2 = sptr1 + IW * 4; + const int8_t* __restrict sptr3 = sptr2 + IW * 4; + const int8_t* __restrict sptr4 = sptr3 + IW * 4; + const int8_t* __restrict sptr5 = sptr4 + IW * 4; + int32x4_t sum0[4], sum1[4]; + int8x16_t src[2][5]; + int8x16_t kern[2][5]; +#define cb(j) \ + sum0[j] = init_v; \ + sum1[j] = init_v; + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + + //! first two line in filter +#define cb(i, kern, filter) kern[i] = (int8x16_t)vld1q_dup_s32((filter) + i); + UNROLL_CALL(5, cb, kern[0], filter); + UNROLL_CALL(5, cb, kern[1], (filter + 5)); +#undef cb + LOAD_1_LINE_SRC(sptr0, src[0]); + LOAD_1_LINE_SRC(sptr1, src[1]); +#define cb(i, sum) \ + accumulate_2_q_vector(src[0][i], kern[0][i], src[1][i], kern[1][i], sum); + UNROLL_CALL(5, cb, sum0); +#undef cb + + LOAD_1_LINE_SRC(sptr2, src[0]); +#define cb(i, sum) \ + accumulate_2_q_vector(src[1][i], kern[0][i], src[0][i], kern[1][i], sum); + UNROLL_CALL(5, cb, sum1); +#undef cb + //! second two line in filter + LOAD_1_LINE_SRC(sptr3, src[1]); + +#define cb(i, kern, filter) kern[i] = (int8x16_t)vld1q_dup_s32((filter) + i); + UNROLL_CALL(5, cb, kern[0], filter + 10); + UNROLL_CALL(5, cb, kern[1], (filter + 15)); +#undef cb +#define cb(i, sum) \ + accumulate_2_q_vector(src[0][i], kern[0][i], src[1][i], kern[1][i], sum); + UNROLL_CALL(5, cb, sum0); +#undef cb + LOAD_1_LINE_SRC(sptr4, src[0]); + +#define cb(i, sum) \ + accumulate_2_q_vector(src[1][i], kern[0][i], src[0][i], kern[1][i], sum); + UNROLL_CALL(5, cb, sum1); +#undef cb + //! last line in filter +#define cb(i, kern, filter) kern[i] = (int8x16_t)vld1q_dup_s32((filter) + i); + UNROLL_CALL(5, cb, kern[0], filter + 20); +#undef cb + + ACC_1_LINE(src[0], kern[0], sum0); + + LOAD_1_LINE_SRC(sptr5, src[1]); + + ACC_1_LINE(src[1], kern[0], sum1); + + STORE_1_LINE_REMAIN(dst, oh, ow, OW, sum0, remain); + STORE_1_LINE_REMAIN(dst, (oh + 1), ow, OW, sum1, remain); + } + } + for (; oh < OH; oh++) { + size_t ih = oh; + size_t ow = 0_z; + for (; ow + 4 <= OW; ow += 4) { + size_t iw = ow; + const int8_t* __restrict sptr0 = sptr + ih * IW * 4 + iw * 4; + const int8_t* __restrict sptr1 = sptr0 + IW * 4; + const int8_t* __restrict sptr2 = sptr1 + IW * 4; + const int8_t* __restrict sptr3 = sptr2 + IW * 4; + const int8_t* __restrict sptr4 = sptr3 + IW * 4; + int32x4_t sum0[4]; + int8x16_t src[2][5]; + int8x16_t kern[2][5]; +#define cb(j) sum0[j] = init_v; + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + //! first two line in filter +#define cb(i, kern, filter) kern[i] = (int8x16_t)vld1q_dup_s32((filter) + i); + UNROLL_CALL(5, cb, kern[0], filter); + UNROLL_CALL(5, cb, kern[1], (filter + 5)); +#undef cb + LOAD_1_LINE_SRC(sptr0, src[0]); + LOAD_1_LINE_SRC(sptr1, src[1]); +#define cb(i, sum) \ + accumulate_2_q_vector(src[0][i], kern[0][i], src[1][i], kern[1][i], sum); + UNROLL_CALL(5, cb, sum0); +#undef cb + //! 
second two line in filter + LOAD_1_LINE_SRC(sptr2, src[0]); + LOAD_1_LINE_SRC(sptr3, src[1]); +#define cb(i, kern, filter) kern[i] = (int8x16_t)vld1q_dup_s32((filter) + i); + UNROLL_CALL(5, cb, kern[0], filter + 10); + UNROLL_CALL(5, cb, kern[1], (filter + 15)); +#undef cb +#define cb(i, sum) \ + accumulate_2_q_vector(src[0][i], kern[0][i], src[1][i], kern[1][i], sum); + UNROLL_CALL(5, cb, sum0); +#undef cb + //! last line in filter + LOAD_1_LINE_SRC(sptr4, src[0]); +#define cb(i, kern, filter) kern[i] = (int8x16_t)vld1q_dup_s32((filter) + i); + UNROLL_CALL(5, cb, kern[0], filter + 20); +#undef cb + ACC_1_LINE(src[0], kern[0], sum0); + STORE_1_LINE(dst, oh, ow, OW, sum0); + } + if (ow < OW) { + size_t remain = OW - ow; + size_t iw = ow; + const int8_t* __restrict sptr0 = sptr + ih * IW * 4 + iw * 4; + const int8_t* __restrict sptr1 = sptr0 + IW * 4; + const int8_t* __restrict sptr2 = sptr1 + IW * 4; + const int8_t* __restrict sptr3 = sptr2 + IW * 4; + const int8_t* __restrict sptr4 = sptr3 + IW * 4; + int32x4_t sum0[4]; + int8x16_t src[2][5]; + int8x16_t kern[2][5]; +#define cb(j) sum0[j] = init_v; + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + //! first two line in filter +#define cb(i, kern, filter) kern[i] = (int8x16_t)vld1q_dup_s32((filter) + i); + UNROLL_CALL(5, cb, kern[0], filter); + UNROLL_CALL(5, cb, kern[1], (filter + 5)); +#undef cb + LOAD_1_LINE_SRC(sptr0, src[0]); + LOAD_1_LINE_SRC(sptr1, src[1]); +#define cb(i, sum) \ + accumulate_2_q_vector(src[0][i], kern[0][i], src[1][i], kern[1][i], sum); + UNROLL_CALL(5, cb, sum0); +#undef cb + //! second two line in filter + LOAD_1_LINE_SRC(sptr2, src[0]); + LOAD_1_LINE_SRC(sptr3, src[1]); +#define cb(i, kern, filter) kern[i] = (int8x16_t)vld1q_dup_s32((filter) + i); + UNROLL_CALL(5, cb, kern[0], filter + 10); + UNROLL_CALL(5, cb, kern[1], (filter + 15)); +#undef cb +#define cb(i, sum) \ + accumulate_2_q_vector(src[0][i], kern[0][i], src[1][i], kern[1][i], sum); + UNROLL_CALL(5, cb, sum0); +#undef cb + //! 
last line in filter + LOAD_1_LINE_SRC(sptr4, src[0]); +#define cb(i, kern, filter) kern[i] = (int8x16_t)vld1q_dup_s32((filter) + i); + UNROLL_CALL(5, cb, kern[0], filter + 20); +#undef cb + ACC_1_LINE(src[0], kern[0], sum0); + STORE_1_LINE_REMAIN(dst, oh, ow, OW, sum0, remain); + } + } +#undef LOAD_1_LINE_SRC +#undef LOAD_2_LINE_SRC +#undef ACC_1_LINE +} + +template +void channel_wise_nchw44::direct_stride2_2x2_int8( + const int8_t* src, const int8_t* filter, const int32_t* bias, void* dst, + const size_t IH, const size_t IW, const size_t OH, const size_t OW, + const Op& op) { + MEGDNN_MARK_USED_VAR(IH); + int32x4_t init_v; + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + init_v = vld1q_s32(bias); + } else { + init_v = vdupq_n_s32(0); + } + int8x8_t kern01 = vld1_s8(filter); + int8x8_t kern23 = vld1_s8(filter + 8); + size_t oh = 0_z; + for (; oh + 2 <= OH; oh += 2) { + size_t ih = oh * 2; + size_t ow = 0_z; + for (; ow + 4 <= OW; ow += 4) { + size_t iw = ow * 2; + const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4; + const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4; + const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4; + const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4 + iw * 4; + + int32x4_t sum[2][4]; +#define cb(i) \ + sum[0][i] = init_v; \ + sum[1][i] = init_v; + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + int8x16_t src00 = vld1q_s8(sptr0); + int8x16_t src01 = vld1q_s8(sptr0 + 16); + + int8x16_t src10 = vld1q_s8(sptr1); + int8x16_t src11 = vld1q_s8(sptr1 + 16); + + accumulate_2_d_vector(src00, kern01, src10, kern23, sum[0][0], + sum[0][1]); + accumulate_2_d_vector(src01, kern01, src11, kern23, sum[0][2], + sum[0][3]); + + int8x16_t src20 = vld1q_s8(sptr2); + int8x16_t src21 = vld1q_s8(sptr2 + 16); + + int8x16_t src30 = vld1q_s8(sptr3); + int8x16_t src31 = vld1q_s8(sptr3 + 16); + + accumulate_2_d_vector(src20, kern01, src30, kern23, sum[1][0], + sum[1][1]); + accumulate_2_d_vector(src21, kern01, src31, kern23, sum[1][2], + sum[1][3]); + + STORE_1_LINE(dst, oh, ow, OW, sum[0]); + STORE_1_LINE(dst, (oh + 1), ow, OW, sum[1]); + } + if (ow < OW) { + size_t iw = ow * 2; + size_t remain = OW - ow; + const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4; + const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4; + const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4; + const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4 + iw * 4; + + int32x4_t sum[2][4]; +#define cb(i) \ + sum[0][i] = init_v; \ + sum[1][i] = init_v; + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + + int8x16_t src00 = vld1q_s8(sptr0); + int8x16_t src01 = vld1q_s8(sptr0 + 16); + + int8x16_t src10 = vld1q_s8(sptr1); + int8x16_t src11 = vld1q_s8(sptr1 + 16); + + accumulate_2_d_vector(src00, kern01, src10, kern23, sum[0][0], + sum[0][1]); + accumulate_2_d_vector(src01, kern01, src11, kern23, sum[0][2], + sum[0][3]); + + int8x16_t src20 = vld1q_s8(sptr2); + int8x16_t src21 = vld1q_s8(sptr2 + 16); + + int8x16_t src30 = vld1q_s8(sptr3); + int8x16_t src31 = vld1q_s8(sptr3 + 16); + + accumulate_2_d_vector(src20, kern01, src30, kern23, sum[1][0], + sum[1][1]); + accumulate_2_d_vector(src21, kern01, src31, kern23, sum[1][2], + sum[1][3]); + + STORE_1_LINE_REMAIN(dst, oh, ow, OW, sum[0], remain); + STORE_1_LINE_REMAIN(dst, (oh + 1), ow, OW, sum[1], remain); + } + } + for (; oh < OH; oh++) { + size_t ih = oh * 2; + size_t ow = 0_z; + for (; ow + 4 <= OW; ow += 4) { + size_t iw = ow * 2; + const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4; + const 
int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4; + int32x4_t sum0[4]; +#define cb(i) sum0[i] = init_v; + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + //! first two line + int8x16_t src00 = vld1q_s8(sptr0); + int8x16_t src01 = vld1q_s8(sptr0 + 16); + + int8x16_t src10 = vld1q_s8(sptr1); + int8x16_t src11 = vld1q_s8(sptr1 + 16); + + accumulate_2_d_vector(src00, kern01, src10, kern23, sum0[0], + sum0[1]); + accumulate_2_d_vector(src01, kern01, src11, kern23, sum0[2], + sum0[3]); + + STORE_1_LINE(dst, oh, ow, OW, sum0); + } + if (OW > ow) { + size_t iw = ow * 2; + size_t remain = OW - ow; + const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4; + const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4; + int32x4_t sum0[4]; +#define cb(i) sum0[i] = init_v; + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + //! first two line + int8x16_t src00 = vld1q_s8(sptr0); + int8x16_t src01 = vld1q_s8(sptr0 + 16); + + int8x16_t src10 = vld1q_s8(sptr1); + int8x16_t src11 = vld1q_s8(sptr1 + 16); + + accumulate_2_d_vector(src00, kern01, src10, kern23, sum0[0], + sum0[1]); + accumulate_2_d_vector(src01, kern01, src11, kern23, sum0[2], + sum0[3]); + + STORE_1_LINE_REMAIN(dst, oh, ow, OW, sum0, remain); + } + } +} + +template +void channel_wise_nchw44::direct_stride2_3x3_int8( + const int8_t* src, const int8_t* filter, const int32_t* bias, void* dst, + const size_t IH, const size_t IW, const size_t OH, const size_t OW, + const Op& op) { + MEGDNN_MARK_USED_VAR(IH); + int32x4_t init_v; + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + init_v = vld1q_s32(bias); + } else { + init_v = vdupq_n_s32(0); + } + int32x2_t zero = vdup_n_s32(0); + int8x8_t kern01 = vld1_s8(filter); + int8x8_t kern20 = vreinterpret_s8_s32( + vzip_s32(vreinterpret_s32_s8(vld1_s8(filter + 8)), zero).val[0]); + int8x8_t kern34 = vld1_s8(filter + 12); + int8x8_t kern50 = vreinterpret_s8_s32( + vzip_s32(vreinterpret_s32_s8(vld1_s8(filter + 20)), zero).val[0]); + int8x8_t kern67 = vld1_s8(filter + 24); + //! in case of illegal read + int8x8_t kern80 = vreinterpret_s8_s32( + vzip_s32(vreinterpret_s32_s8(vld1_s8(filter + 28)), zero).val[1]); + +#define COMPUTE_ONE_LINE(src00, src01, src02, kern01, kern20, sum) \ + accumulate_1_line_horizon(vget_low_s8(src00), kern01, vget_high_s8(src00), \ + kern20, sum[0]); \ + accumulate_1_line_horizon(vget_high_s8(src00), kern01, vget_low_s8(src01), \ + kern20, sum[1]); \ + accumulate_1_line_horizon(vget_low_s8(src01), kern01, vget_high_s8(src01), \ + kern20, sum[2]); \ + accumulate_1_line_horizon(vget_high_s8(src01), kern01, src02, kern20, \ + sum[3]); + + size_t oh = 0_z; + for (; oh + 2 <= OH; oh += 2) { + size_t ih = oh * 2; + size_t ow = 0_z; + for (; ow + 4 <= OW; ow += 4) { + size_t iw = ow * 2; + const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4; + const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4; + const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4; + const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4 + iw * 4; + const int8_t* __restrict sptr4 = src + (ih + 4) * IW * 4 + iw * 4; + int32x4_t sum[2][4]; +#define cb(i) \ + sum[0][i] = init_v; \ + sum[1][i] = init_v; + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + //! line 0 + int8x16_t src00 = vld1q_s8(sptr0); + int8x16_t src01 = vld1q_s8(sptr0 + 16); + int8x8_t src02 = vld1_s8(sptr0 + 32); + COMPUTE_ONE_LINE(src00, src01, src02, kern01, kern20, sum[0]); + + //! 
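+            //! kern20/kern50/kern80 hold the third tap of each filter row in
+            //! their low s32 lane and zero in the high lane, so the paired
+            //! vmull/vmlal over pixels {p2, p3} contributes k2 * p2 while the
+            //! extra pixel is multiplied by zero,
+            //!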
line 1 + int8x16_t src10 = vld1q_s8(sptr1); + int8x16_t src11 = vld1q_s8(sptr1 + 16); + int8x8_t src12 = vld1_s8(sptr1 + 32); + COMPUTE_ONE_LINE(src10, src11, src12, kern34, kern50, sum[0]); + + //! line 2 + int8x16_t src20 = vld1q_s8(sptr2); + int8x16_t src21 = vld1q_s8(sptr2 + 16); + int8x8_t src22 = vld1_s8(sptr2 + 32); + COMPUTE_ONE_LINE(src20, src21, src22, kern67, kern80, sum[0]); + //! sum1 + COMPUTE_ONE_LINE(src20, src21, src22, kern01, kern20, sum[1]); + + //! line 3 + int8x16_t src30 = vld1q_s8(sptr3); + int8x16_t src31 = vld1q_s8(sptr3 + 16); + int8x8_t src32 = vld1_s8(sptr3 + 32); + COMPUTE_ONE_LINE(src30, src31, src32, kern34, kern50, sum[1]); + + //! line 4 + int8x16_t src40 = vld1q_s8(sptr4); + int8x16_t src41 = vld1q_s8(sptr4 + 16); + int8x8_t src42 = vld1_s8(sptr4 + 32); + COMPUTE_ONE_LINE(src40, src41, src42, kern67, kern80, sum[1]); + + STORE_1_LINE(dst, oh, ow, OW, sum[0]); + STORE_1_LINE(dst, (oh + 1), ow, OW, sum[1]); + } + if (ow < OW) { + size_t iw = ow * 2; + size_t remain = OW - ow; + const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4; + const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4; + const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4; + const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4 + iw * 4; + const int8_t* __restrict sptr4 = src + (ih + 4) * IW * 4 + iw * 4; + + int32x4_t sum[2][4]; +#define cb(i) \ + sum[0][i] = init_v; \ + sum[1][i] = init_v; + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + //! line 0 + int8x16_t src00 = vld1q_s8(sptr0); + int8x16_t src01 = vld1q_s8(sptr0 + 16); + int8x8_t src02 = vld1_s8(sptr0 + 32); + COMPUTE_ONE_LINE(src00, src01, src02, kern01, kern20, sum[0]); + + //! line 1 + int8x16_t src10 = vld1q_s8(sptr1); + int8x16_t src11 = vld1q_s8(sptr1 + 16); + int8x8_t src12 = vld1_s8(sptr1 + 32); + COMPUTE_ONE_LINE(src10, src11, src12, kern34, kern50, sum[0]); + + //! line 2 + int8x16_t src20 = vld1q_s8(sptr2); + int8x16_t src21 = vld1q_s8(sptr2 + 16); + int8x8_t src22 = vld1_s8(sptr2 + 32); + COMPUTE_ONE_LINE(src20, src21, src22, kern67, kern80, sum[0]); + //! sum1 + COMPUTE_ONE_LINE(src20, src21, src22, kern01, kern20, sum[1]); + + //! line 3 + int8x16_t src30 = vld1q_s8(sptr3); + int8x16_t src31 = vld1q_s8(sptr3 + 16); + int8x8_t src32 = vld1_s8(sptr3 + 32); + COMPUTE_ONE_LINE(src30, src31, src32, kern34, kern50, sum[1]); + + //! line 4 + int8x16_t src40 = vld1q_s8(sptr4); + int8x16_t src41 = vld1q_s8(sptr4 + 16); + int8x8_t src42 = vld1_s8(sptr4 + 32); + COMPUTE_ONE_LINE(src40, src41, src42, kern67, kern80, sum[1]); + + STORE_1_LINE_REMAIN(dst, oh, ow, OW, sum[0], remain); + STORE_1_LINE_REMAIN(dst, (oh + 1), ow, OW, sum[1], remain); + } + } + for (; oh < OH; oh++) { + size_t ih = oh * 2; + size_t ow = 0_z; + for (; ow + 4 <= OW; ow += 4) { + size_t iw = ow * 2; + const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4; + const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4; + const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4; + int32x4_t sum[4]; +#define cb(i) sum[i] = init_v; + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + //! line 0 + int8x16_t src00 = vld1q_s8(sptr0); + int8x16_t src01 = vld1q_s8(sptr0 + 16); + int8x8_t src02 = vld1_s8(sptr0 + 32); + COMPUTE_ONE_LINE(src00, src01, src02, kern01, kern20, sum); + + //! line 1 + int8x16_t src10 = vld1q_s8(sptr1); + int8x16_t src11 = vld1q_s8(sptr1 + 16); + int8x8_t src12 = vld1_s8(sptr1 + 32); + COMPUTE_ONE_LINE(src10, src11, src12, kern34, kern50, sum); + + //! 
line 2 + int8x16_t src20 = vld1q_s8(sptr2); + int8x16_t src21 = vld1q_s8(sptr2 + 16); + int8x8_t src22 = vld1_s8(sptr2 + 32); + COMPUTE_ONE_LINE(src20, src21, src22, kern67, kern80, sum); + + STORE_1_LINE(dst, oh, ow, OW, sum); + } + if (OW > ow) { + size_t iw = ow * 2; + size_t remain = OW - ow; + const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4; + const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4; + const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4; + int32x4_t sum[4]; +#define cb(i) sum[i] = init_v; + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + //! line 0 + int8x16_t src00 = vld1q_s8(sptr0); + int8x16_t src01 = vld1q_s8(sptr0 + 16); + int8x8_t src02 = vld1_s8(sptr0 + 32); + COMPUTE_ONE_LINE(src00, src01, src02, kern01, kern20, sum); + + //! line 1 + int8x16_t src10 = vld1q_s8(sptr1); + int8x16_t src11 = vld1q_s8(sptr1 + 16); + int8x8_t src12 = vld1_s8(sptr1 + 32); + COMPUTE_ONE_LINE(src10, src11, src12, kern34, kern50, sum); + + //! line 2 + int8x16_t src20 = vld1q_s8(sptr2); + int8x16_t src21 = vld1q_s8(sptr2 + 16); + int8x8_t src22 = vld1_s8(sptr2 + 32); + COMPUTE_ONE_LINE(src20, src21, src22, kern67, kern80, sum); + + STORE_1_LINE_REMAIN(dst, oh, ow, OW, sum, remain); + } + } +#undef COMPUTE_ONE_LINE +} + +template +void channel_wise_nchw44::direct_stride2_5x5_int8( + const int8_t* src, const int8_t* filter, const int32_t* bias, void* dst, + const size_t IH, const size_t IW, const size_t OH, const size_t OW, + const Op& op) { + MEGDNN_MARK_USED_VAR(IH); + int32x4_t init_v; + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + init_v = vld1q_s32(bias); + } else { + init_v = vdupq_n_s32(0); + } + int8x8_t kern0[3], kern1[3], kern2[3], kern3[3], kern4[3]; + int32x2_t zero = vdup_n_s32(0); + kern0[0] = vld1_s8(filter); + kern0[1] = vld1_s8(filter + 8); + kern0[2] = vreinterpret_s8_s32( + vzip_s32(vreinterpret_s32_s8(vld1_s8(filter + 16)), zero).val[0]); + kern1[0] = vld1_s8(filter + 20); + kern1[1] = vld1_s8(filter + 28); + kern1[2] = vreinterpret_s8_s32( + vzip_s32(vreinterpret_s32_s8(vld1_s8(filter + 36)), zero).val[0]); + kern2[0] = vld1_s8(filter + 40); + kern2[1] = vld1_s8(filter + 48); + kern2[2] = vreinterpret_s8_s32( + vzip_s32(vreinterpret_s32_s8(vld1_s8(filter + 56)), zero).val[0]); + kern3[0] = vld1_s8(filter + 60); + kern3[1] = vld1_s8(filter + 68); + kern3[2] = vreinterpret_s8_s32( + vzip_s32(vreinterpret_s32_s8(vld1_s8(filter + 76)), zero).val[0]); + kern4[0] = vld1_s8(filter + 80); + kern4[1] = vld1_s8(filter + 88); + //! 
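+    //! the 5x5 NCHW44 filter is 5 * 5 * 4 = 100 bytes and its last tap starts
+    //! at byte offset 96; an 8-byte vld1_s8 at offset 96 would read past the
+    //! end, so the load below starts at offset 92 and vzip_s32 moves bytes
+    //! 96..99 into the low lane with zeros in the high lane,
+    //!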
in case of illegal read + kern4[2] = vreinterpret_s8_s32( + vzip_s32(vreinterpret_s32_s8(vld1_s8(filter + 92)), zero).val[1]); + +#define COMPUTE_ONE_VECTOR(src00, src01, src02, src10, src11, src12, kern0, \ + kern1, sum) \ + accumulate_1_line_horizon(src00, kern0[0], src10, kern1[0], sum); \ + accumulate_1_line_horizon(src01, kern0[1], src11, kern1[1], sum); \ + accumulate_1_line_horizon(src02, kern0[2], src12, kern1[2], sum); + +#define COMPUTE_TWO_LINE(src0, src1, kern0, kern1, sum) \ + COMPUTE_ONE_VECTOR(vget_low_s8(src0[0]), vget_high_s8(src0[0]), \ + vget_low_s8(src0[1]), vget_low_s8(src1[0]), \ + vget_high_s8(src1[0]), vget_low_s8(src1[1]), kern0, \ + kern1, sum[0]) \ + COMPUTE_ONE_VECTOR(vget_high_s8(src0[0]), vget_low_s8(src0[1]), \ + vget_high_s8(src0[1]), vget_high_s8(src1[0]), \ + vget_low_s8(src1[1]), vget_high_s8(src1[1]), kern0, \ + kern1, sum[1]) \ + COMPUTE_ONE_VECTOR(vget_low_s8(src0[1]), vget_high_s8(src0[1]), \ + vget_low_s8(src0[2]), vget_low_s8(src1[1]), \ + vget_high_s8(src1[1]), vget_low_s8(src1[2]), kern0, \ + kern1, sum[2]) \ + COMPUTE_ONE_VECTOR(vget_high_s8(src0[1]), vget_low_s8(src0[2]), \ + vget_high_s8(src0[2]), vget_high_s8(src1[1]), \ + vget_low_s8(src1[2]), vget_high_s8(src1[2]), kern0, \ + kern1, sum[3]) + +#define COMPUTE_ONE_LINE(src, kern, sum) \ + accumulate_1_line_horizon(vget_low_s8(src[0]), kern[0], \ + vget_high_s8(src[0]), kern[1], sum[0]); \ + accumulate_1_line_horizon(vget_high_s8(src[0]), kern[0], \ + vget_low_s8(src[1]), kern[1], sum[1]); \ + accumulate_1_line_horizon(vget_low_s8(src[1]), kern[0], \ + vget_high_s8(src[1]), kern[1], sum[2]); \ + accumulate_1_line_horizon(vget_high_s8(src[1]), kern[0], \ + vget_low_s8(src[2]), kern[1], sum[3]); \ + accumulate_1_d_vector(vget_low_s8(src[1]), kern[2], sum[0]); \ + accumulate_1_d_vector(vget_high_s8(src[1]), kern[2], sum[1]); \ + accumulate_1_d_vector(vget_low_s8(src[2]), kern[2], sum[2]); \ + accumulate_1_d_vector(vget_high_s8(src[2]), kern[2], sum[3]) + + size_t oh = 0_z; + for (; oh + 2 <= OH; oh += 2) { + size_t ih = oh * 2; + size_t ow = 0_z; + for (; ow + 4 <= OW; ow += 4) { + size_t iw = ow * 2; + const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4; + const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4; + const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4; + const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4 + iw * 4; + const int8_t* __restrict sptr4 = src + (ih + 4) * IW * 4 + iw * 4; + const int8_t* __restrict sptr5 = src + (ih + 5) * IW * 4 + iw * 4; + const int8_t* __restrict sptr6 = src + (ih + 6) * IW * 4 + iw * 4; + int32x4_t sum[2][4]; +#define cb(i) \ + sum[0][i] = init_v; \ + sum[1][i] = init_v; + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + int8x16_t src0[3], src1[3]; + //! line 0, 1 + src0[0] = vld1q_s8(sptr0); + src0[1] = vld1q_s8(sptr0 + 16); + src0[2] = vld1q_s8(sptr0 + 32); + + src1[0] = vld1q_s8(sptr1); + src1[1] = vld1q_s8(sptr1 + 16); + src1[2] = vld1q_s8(sptr1 + 32); + + COMPUTE_TWO_LINE(src0, src1, kern0, kern1, sum[0]); + + //! line 2,3 + src0[0] = vld1q_s8(sptr2); + src0[1] = vld1q_s8(sptr2 + 16); + src0[2] = vld1q_s8(sptr2 + 32); + + src1[0] = vld1q_s8(sptr3); + src1[1] = vld1q_s8(sptr3 + 16); + src1[2] = vld1q_s8(sptr3 + 32); + + COMPUTE_TWO_LINE(src0, src1, kern2, kern3, sum[0]); + COMPUTE_TWO_LINE(src0, src1, kern0, kern1, sum[1]); + + //! 
line 4,5 + src0[0] = vld1q_s8(sptr4); + src0[1] = vld1q_s8(sptr4 + 16); + src0[2] = vld1q_s8(sptr4 + 32); + + src1[0] = vld1q_s8(sptr5); + src1[1] = vld1q_s8(sptr5 + 16); + src1[2] = vld1q_s8(sptr5 + 32); + COMPUTE_ONE_LINE(src0, kern4, sum[0]); + COMPUTE_TWO_LINE(src0, src1, kern2, kern3, sum[1]); + + //! line 6 + src0[0] = vld1q_s8(sptr6); + src0[1] = vld1q_s8(sptr6 + 16); + src0[2] = vld1q_s8(sptr6 + 32); + + COMPUTE_ONE_LINE(src0, kern4, sum[1]); + + STORE_1_LINE(dst, oh, ow, OW, sum[0]); + STORE_1_LINE(dst, (oh + 1), ow, OW, sum[1]); + } + if (ow < OW) { + size_t iw = ow * 2; + size_t remain = OW - ow; + const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4; + const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4; + const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4; + const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4 + iw * 4; + const int8_t* __restrict sptr4 = src + (ih + 4) * IW * 4 + iw * 4; + const int8_t* __restrict sptr5 = src + (ih + 5) * IW * 4 + iw * 4; + const int8_t* __restrict sptr6 = src + (ih + 6) * IW * 4 + iw * 4; + int32x4_t sum[2][4]; +#define cb(i) \ + sum[0][i] = init_v; \ + sum[1][i] = init_v; + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + int8x16_t src0[3], src1[3]; + //! line 0, 1 + src0[0] = vld1q_s8(sptr0); + src0[1] = vld1q_s8(sptr0 + 16); + src0[2] = vld1q_s8(sptr0 + 32); + + src1[0] = vld1q_s8(sptr1); + src1[1] = vld1q_s8(sptr1 + 16); + src1[2] = vld1q_s8(sptr1 + 32); + + COMPUTE_TWO_LINE(src0, src1, kern0, kern1, sum[0]); + + //! line 2,3 + src0[0] = vld1q_s8(sptr2); + src0[1] = vld1q_s8(sptr2 + 16); + src0[2] = vld1q_s8(sptr2 + 32); + + src1[0] = vld1q_s8(sptr3); + src1[1] = vld1q_s8(sptr3 + 16); + src1[2] = vld1q_s8(sptr3 + 32); + + COMPUTE_TWO_LINE(src0, src1, kern2, kern3, sum[0]); + COMPUTE_TWO_LINE(src0, src1, kern0, kern1, sum[1]); + + //! line 4,5 + src0[0] = vld1q_s8(sptr4); + src0[1] = vld1q_s8(sptr4 + 16); + src0[2] = vld1q_s8(sptr4 + 32); + + src1[0] = vld1q_s8(sptr5); + src1[1] = vld1q_s8(sptr5 + 16); + src1[2] = vld1q_s8(sptr5 + 32); + COMPUTE_ONE_LINE(src0, kern4, sum[0]); + COMPUTE_TWO_LINE(src0, src1, kern2, kern3, sum[1]); + + //! line 6 + src0[0] = vld1q_s8(sptr6); + src0[1] = vld1q_s8(sptr6 + 16); + src0[2] = vld1q_s8(sptr6 + 32); + + COMPUTE_ONE_LINE(src0, kern4, sum[1]); + + STORE_1_LINE_REMAIN(dst, oh, ow, OW, sum[0], remain); + STORE_1_LINE_REMAIN(dst, (oh + 1), ow, OW, sum[1], remain); + } + } + for (; oh < OH; oh++) { + size_t ih = oh * 2; + size_t ow = 0_z; + for (; ow + 4 <= OW; ow += 4) { + size_t iw = ow * 2; + const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4; + const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4; + const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4; + const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4 + iw * 4; + const int8_t* __restrict sptr4 = src + (ih + 4) * IW * 4 + iw * 4; + int32x4_t sum[4]; +#define cb(i) sum[i] = init_v; + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + int8x16_t src0[3], src1[3]; + //! line 0, 1 + src0[0] = vld1q_s8(sptr0); + src0[1] = vld1q_s8(sptr0 + 16); + src0[2] = vld1q_s8(sptr0 + 32); + + src1[0] = vld1q_s8(sptr1); + src1[1] = vld1q_s8(sptr1 + 16); + src1[2] = vld1q_s8(sptr1 + 32); + + COMPUTE_TWO_LINE(src0, src1, kern0, kern1, sum); + + //! 
line 2,3 + src0[0] = vld1q_s8(sptr2); + src0[1] = vld1q_s8(sptr2 + 16); + src0[2] = vld1q_s8(sptr2 + 32); + + src1[0] = vld1q_s8(sptr3); + src1[1] = vld1q_s8(sptr3 + 16); + src1[2] = vld1q_s8(sptr3 + 32); + + COMPUTE_TWO_LINE(src0, src1, kern2, kern3, sum); + + //! line 4,5 + src0[0] = vld1q_s8(sptr4); + src0[1] = vld1q_s8(sptr4 + 16); + src0[2] = vld1q_s8(sptr4 + 32); + + COMPUTE_ONE_LINE(src0, kern4, sum); + + STORE_1_LINE(dst, oh, ow, OW, sum); + } + if (OW > ow) { + size_t iw = ow * 2; + size_t remain = OW - ow; + const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4; + const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4; + const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4; + const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4 + iw * 4; + const int8_t* __restrict sptr4 = src + (ih + 4) * IW * 4 + iw * 4; + int32x4_t sum[4]; +#define cb(i) sum[i] = init_v; + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + int8x16_t src0[3], src1[3]; + //! line 0, 1 + src0[0] = vld1q_s8(sptr0); + src0[1] = vld1q_s8(sptr0 + 16); + src0[2] = vld1q_s8(sptr0 + 32); + + src1[0] = vld1q_s8(sptr1); + src1[1] = vld1q_s8(sptr1 + 16); + src1[2] = vld1q_s8(sptr1 + 32); + + COMPUTE_TWO_LINE(src0, src1, kern0, kern1, sum); + + //! line 2,3 + src0[0] = vld1q_s8(sptr2); + src0[1] = vld1q_s8(sptr2 + 16); + src0[2] = vld1q_s8(sptr2 + 32); + + src1[0] = vld1q_s8(sptr3); + src1[1] = vld1q_s8(sptr3 + 16); + src1[2] = vld1q_s8(sptr3 + 32); + + COMPUTE_TWO_LINE(src0, src1, kern2, kern3, sum); + + //! line 4,5 + src0[0] = vld1q_s8(sptr4); + src0[1] = vld1q_s8(sptr4 + 16); + src0[2] = vld1q_s8(sptr4 + 32); + + COMPUTE_ONE_LINE(src0, kern4, sum); + + STORE_1_LINE_REMAIN(dst, oh, ow, OW, sum, remain); + } + } +#undef COMPUTE_ONE_VECTOR +#undef COMPUTE_ONE_LINE +#undef COMPUTE_TWO_LINE +} + +#undef STORE_1_LINE +#undef STORE_1_LINE_REMAIN + +#define INSTANTIATION(quantized, stride, i, bias, Op) \ + template void channel_wise_nchw44::direct_##stride##_##i##x##i##_int8< \ + quantized, bias, Op>(const int8_t*, const int8_t*, const int32_t*, \ + void*, const size_t, const size_t, \ + const size_t, const size_t, const Op&); + +#define FOR_OP(stride, i, bias) \ + INSTANTIATION(true, stride, i, bias, \ + TypeCvtOp) \ + INSTANTIATION(true, stride, i, bias, \ + ReluOp) \ + INSTANTIATION(true, stride, i, bias, \ + HSwishOp) \ + INSTANTIATION(false, stride, i, bias, \ + NoneOp) + +#define FOR_BIAS(stride, i) \ + FOR_OP(stride, i, BiasMode::NO_BIAS) \ + FOR_OP(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS) + +#define FOR_FILTER(stride) \ + FOR_BIAS(stride, 2) \ + FOR_BIAS(stride, 3) \ + FOR_BIAS(stride, 5) + +#define FOR_STRIDE \ + FOR_FILTER(stride1) \ + FOR_FILTER(stride2) + +FOR_STRIDE + +#undef FOR_STRIDE +#undef FOR_FILTER +#undef FOR_BIAS +#undef FOR_OP +#undef INSTANTIATION +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8/channel_wise_kernel.h b/dnn/src/arm_common/conv_bias/int8/channel_wise_kernel.h new file mode 100644 index 00000000..f0a02791 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/channel_wise_kernel.h @@ -0,0 +1,40 @@ +/** + * \file dnn/src/arm_common/conv_bias/int8/channel_wise_kernel.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */
+
+#include "src/arm_common/conv_bias/opr_impl.h"
+#include "src/fallback/conv_bias/common.h"
+
+namespace megdnn {
+namespace arm_common {
+namespace channel_wise_nchw44 {
+
+#define KERN(stride, i)                                                     \
+    template <bool quantized, BiasMode bias_mode, typename Op>              \
+    void direct_##stride##_##i##x##i##_int8(                                \
+            const int8_t* src, const int8_t* filter, const int32_t* bias,   \
+            void* dst, const size_t IH, const size_t IW, const size_t OH,   \
+            const size_t OW, const Op& op);
+
+KERN(stride1, 2)
+KERN(stride1, 3)
+KERN(stride1, 5)
+
+KERN(stride2, 2)
+KERN(stride2, 3)
+KERN(stride2, 5)
+
+#undef KERN
+
+} // namespace channel_wise_nchw44
+} // namespace arm_common
+} // namespace megdnn
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/arm_common/conv_bias/int8/channel_wise_nchw44.cpp b/dnn/src/arm_common/conv_bias/int8/channel_wise_nchw44.cpp
new file mode 100644
index 00000000..dbe44780
--- /dev/null
+++ b/dnn/src/arm_common/conv_bias/int8/channel_wise_nchw44.cpp
@@ -0,0 +1,338 @@
+/**
+ * \file dnn/src/arm_common/conv_bias/int8/channel_wise_nchw44.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#include "src/arm_common/conv_bias/int8/channel_wise_nchw44.h"
+#include "megdnn/oprs.h"
+#include "src/arm_common/conv_bias/int8/channel_wise_kernel.h"
+#include "src/arm_common/elemwise_op.h"
+#include "src/common/opr_delegate.h"
+
+#include "midout.h"
+
+using namespace megdnn;
+using namespace arm_common;
+using namespace channel_wise_nchw44;
+
+namespace {
+void get_rectified_size(
+        const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param,
+        size_t& IH2, size_t& IW2) {
+    auto&& fm = param.filter_meta;
+    auto SW = fm.stride[1];
+    auto OH = param.osz[0];
+    auto OW = param.osz[1];
+    auto FH = fm.spatial[0];
+    auto FW = fm.spatial[1];
+
+    size_t OW2 = (OW + 3) & ~3;
+    IH2 = SW * OH + FH - SW;
+    IW2 = SW * OW2 + FW - SW;
+}
+} // namespace
+
+MIDOUT_DECL(megdnn_arm_common_conv_bias_int8_nchw44_stride1)
+MIDOUT_DECL(megdnn_arm_common_conv_bias_int8_nchw44_stride2)
+
+bool stride1::is_available(const NCBKernSizeParam& param) {
+    auto&& fm = param.filter_meta;
+    auto FH = fm.spatial[0];
+    bool avaible =
+            //! src and filter are qint8, dst is qint8 or qint32
+            ((param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
+              param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
+              (param.dst_type.enumv() == DTypeEnum::QuantizedS8 ||
+               param.dst_type.enumv() == DTypeEnum::QuantizedS32)) ||
+             //! src and filter are int8, dst is int32
+             (param.src_type.enumv() == DTypeEnum::Int8 &&
+              param.filter_type.enumv() == DTypeEnum::Int8 &&
+              param.dst_type.enumv() == DTypeEnum::Int32)) &&
+            fm.format == param::Convolution::Format::NCHW44 &&
+            !fm.should_flip && fm.spatial_ndim == 2 && fm.dilation[0] == 1 &&
+            fm.dilation[1] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1 &&
+            FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5) &&
+            fm.icpg == 1 && fm.ocpg == 1 && fm.group % 4 == 0;
+    return avaible;
+}
+
+WorkspaceBundle stride1::get_bundle(
+        const ConvBiasImpl::NCBKernSizeParam& param) {
+    size_t nr_threads = param.nr_threads;
+    size_t IH2, IW2;
+    get_rectified_size(param, IH2, IW2);
+    constexpr size_t pack_ic_size = 4_z;
+    //!
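+    //! illustrative numbers (not from any particular layer): with
+    //! OH = OW = 10, FH = FW = 3 and stride 1, OW2 = (10 + 3) & ~3 = 12,
+    //! IH2 = 1 * 10 + 3 - 1 = 12 and IW2 = 1 * 12 + 3 - 1 = 14, i.e. the
+    //! rectified width is rounded up so a block of 4 output pixels is always
+    //! backed by in-bounds source loads,
+    //!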
The extra 16B is used to void ivalid read in kernel compute + size_t src_size = IH2 * IW2 * pack_ic_size * sizeof(int8_t) + 16; + SmallVector sizes(nr_threads, src_size); + return {nullptr, sizes}; +} + +//! compute one output channel +template +void stride1::do_conv_kern(WorkspaceBundle bundle, + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + size_t PH = kern_param.filter_meta.padding[0]; + size_t PW = kern_param.filter_meta.padding[1]; + size_t OH = kern_param.osz[0]; + size_t OW = kern_param.osz[1]; + size_t IH = kern_param.isz[0]; + size_t IW = kern_param.isz[1]; + size_t IH2, IW2; + get_rectified_size(kern_param, IH2, IW2); + Op op = Op(1.0f, 1.0f); + if (quantized) { + float scale_bias = + kern_param.bias_type.param().scale; + float scale_dst = kern_param.dst_type.param().scale; + op = Op(scale_bias, scale_dst); + } + + constexpr size_t pack_group_size = 4_z; + constexpr size_t pack_ic_size = 4_z; + + size_t thread_id = ncb_index.thread_id, batch_id = ncb_index.ndrange_id[0]; + size_t group_id = ncb_index.ndrange_id[1]; + bundle.set(kern_param.workspace_ptr); + int8_t* padding_src = static_cast(bundle.get(thread_id)); + const int8_t* sptr = + kern_param.src(batch_id, group_id, 0, pack_group_size); + const int8_t* fptr = kern_param.filter(group_id, pack_group_size); + void* dst = kern_param.dst(batch_id, group_id, 0, pack_group_size); + const int32_t* bptr = + kern_param.bias(batch_id, group_id, 0, pack_group_size); + //! copy in case of illegal read src when padding is zero + std::memset(padding_src, 0, sizeof(int8_t) * IH2 * IW2 * pack_ic_size); + rep(ih, IH) { + std::memcpy(padding_src + ((ih + PH) * IW2 + PW) * pack_ic_size, + sptr + ih * IW * pack_ic_size, + sizeof(int8_t) * IW * pack_ic_size); + } + sptr = padding_src; + +#define KERN(_size) \ + direct_stride1_##_size##x##_size##_int8( \ + sptr, fptr, bptr, dst, IH2, IW2, OH, OW, op); + DISPATCH_FILTER_CHANNEL_WISE(filter, KERN); +#undef KERN +} + +SmallVector stride1::get_kimpls( + const NCBKernSizeParam& param) { + auto fm = param.filter_meta; + size_t N = param.n; + size_t group = fm.group / 4; + megdnn_assert(fm.group % 4 == 0, + "nchw44 channel wise conv with group is not times of 4"); + WorkspaceBundle wbundle = get_bundle(param); + bool quantized = param.dst_type.enumv() == DTypeEnum::QuantizedS8; + conv_fun do_conv_fun = nullptr; + +#define DO_CONV_KERN_FUN(quantized, filter, bias_mode, op) \ + MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw44_stride1, \ + midout_iv(#quantized #filter #bias_mode #op##_hash)) { \ + do_conv_fun = do_conv_kern; \ + } \ + MIDOUT_END(); + +#define GET_OP_PARAM(i, bias_mode) \ + switch (param.nonlineMode) { \ + case param::ConvBias::NonlineMode::IDENTITY: \ + if (quantized) { \ + DO_CONV_KERN_FUN(true, i, bias_mode, \ + TypeCvtOp) \ + } else { \ + DO_CONV_KERN_FUN(false, i, bias_mode, \ + NoneOp) \ + } \ + break; \ + case param::ConvBias::NonlineMode::RELU: \ + if (quantized) { \ + DO_CONV_KERN_FUN(true, i, bias_mode, \ + ReluOp) \ + } else { \ + DO_CONV_KERN_FUN(false, i, bias_mode, \ + NoneOp) \ + } \ + break; \ + case param::ConvBias::NonlineMode::H_SWISH: \ + if (quantized) { \ + DO_CONV_KERN_FUN(true, i, bias_mode, \ + HSwishOp) \ + } else { \ + DO_CONV_KERN_FUN(false, i, bias_mode, \ + NoneOp) \ + } \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } + +#define GET_BIAS_MODE_PARAM(i) \ + switch (param.bias_mode) { \ + case BiasMode::NO_BIAS: \ + GET_OP_PARAM(i, BiasMode::NO_BIAS) \ + break; \ + case BiasMode::BROADCAST_CHANNEL_BIAS: \ + GET_OP_PARAM(i, 
BiasMode::BROADCAST_CHANNEL_BIAS) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } +#define DISPATCH_CONV_KERN() \ + switch (param.filter_meta.spatial[0]) { \ + case 2: \ + GET_BIAS_MODE_PARAM(2) \ + break; \ + case 3: \ + GET_BIAS_MODE_PARAM(3) \ + break; \ + case 5: \ + GET_BIAS_MODE_PARAM(5) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } + + DISPATCH_CONV_KERN(); + megdnn_assert(do_conv_fun); + + SmallVector ret_kerns; + auto exec_one_group = [wbundle, do_conv_fun]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + do_conv_fun(wbundle, kern_param, ncb_index); + }; + ret_kerns.push_back({exec_one_group, {N, group}}); + return ret_kerns; +#undef DO_CONV_KERN_FUN +} + +bool stride2::is_available(const NCBKernSizeParam& param) { + auto&& fm = param.filter_meta; + auto FH = fm.spatial[0]; + bool avaible = + //! src and filter are qint8, dst is qint8 or qint32 + ((param.src_type.enumv() == DTypeEnum::QuantizedS8 && + param.filter_type.enumv() == DTypeEnum::QuantizedS8 && + (param.dst_type.enumv() == DTypeEnum::QuantizedS8 || + param.dst_type.enumv() == DTypeEnum::QuantizedS32)) || + //! src and filter are int8, dst is int32 + (param.src_type.enumv() == DTypeEnum::Int8 && + param.filter_type.enumv() == DTypeEnum::Int8 && + param.dst_type.enumv() == DTypeEnum::Int32)) && + fm.format == param::Convolution::Format::NCHW44 && + !fm.should_flip && fm.spatial_ndim == 2 && fm.dilation[0] == 1 && + fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 && + FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5) && + fm.icpg == 1 && fm.ocpg == 1 && fm.group % 4 == 0; + return avaible; +} + +WorkspaceBundle stride2::get_bundle( + const ConvBiasImpl::NCBKernSizeParam& param) { + size_t nr_threads = param.nr_threads; + size_t IH2, IW2; + get_rectified_size(param, IH2, IW2); + constexpr size_t pack_ic_size = 4_z; + //! The extra 16B is used to void ivalid read in kernel compute + size_t src_size = IH2 * IW2 * pack_ic_size * sizeof(int8_t) + 16; + SmallVector sizes(nr_threads, src_size); + return {nullptr, sizes}; +} + +//! compute one output channel +template +void stride2::do_conv_kern(WorkspaceBundle bundle, + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + size_t PH = kern_param.filter_meta.padding[0]; + size_t PW = kern_param.filter_meta.padding[1]; + size_t OH = kern_param.osz[0]; + size_t OW = kern_param.osz[1]; + size_t IH = kern_param.isz[0]; + size_t IW = kern_param.isz[1]; + size_t IH2, IW2; + get_rectified_size(kern_param, IH2, IW2); + Op op = Op(1.0f, 1.0f); + if (quantized) { + float scale_bias = + kern_param.bias_type.param().scale; + float scale_dst = kern_param.dst_type.param().scale; + op = Op(scale_bias, scale_dst); + } + + constexpr size_t pack_group_size = 4_z; + constexpr size_t pack_ic_size = 4_z; + + size_t thread_id = ncb_index.thread_id, batch_id = ncb_index.ndrange_id[0]; + size_t group_id = ncb_index.ndrange_id[1]; + bundle.set(kern_param.workspace_ptr); + int8_t* padding_src = static_cast(bundle.get(thread_id)); + const int8_t* sptr = + kern_param.src(batch_id, group_id, 0, pack_group_size); + const int8_t* fptr = kern_param.filter(group_id, pack_group_size); + void* dst = kern_param.dst(batch_id, group_id, 0, pack_group_size); + const int32_t* bptr = + kern_param.bias(batch_id, group_id, 0, pack_group_size); + //! 
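The staging step that get_bundle() sizes above is easier to see in isolation: each worker thread builds a zero-padded copy of its group's input so the inner loops never branch on borders. A minimal sketch, assuming get_rectified_size() returns IH2 and IW2 large enough for IH + 2*PH by IW + 2*PW pixels (the helper name below is illustrative, not the shipped code):

    // Staging sketch: zero the padded block once, then copy the valid rows in.
    // pack_ic = 4 because NCHW44 stores 4 int8 channels per pixel.  The +16
    // tail reserved by get_bundle() lets the kernel issue a full 16-byte NEON
    // load at the bottom-right corner without reading past the allocation.
    #include <cstddef>
    #include <cstdint>
    #include <cstring>

    static void stage_padded_src(int8_t* padded, const int8_t* src, size_t IH,
                                 size_t IW, size_t IH2, size_t IW2, size_t PH,
                                 size_t PW, size_t pack_ic = 4) {
        std::memset(padded, 0, IH2 * IW2 * pack_ic);   // padding value is zero
        for (size_t ih = 0; ih < IH; ++ih) {
            std::memcpy(padded + ((ih + PH) * IW2 + PW) * pack_ic,
                        src + ih * IW * pack_ic, IW * pack_ic);
        }
    }
    // Per-thread workspace, matching get_bundle():
    //   IH2 * IW2 * pack_ic * sizeof(int8_t) + 16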
copy in case of illegal read src when padding is zero + std::memset(padding_src, 0, sizeof(int8_t) * IH2 * IW2 * pack_ic_size); + rep(ih, IH) { + std::memcpy(padding_src + ((ih + PH) * IW2 + PW) * pack_ic_size, + sptr + ih * IW * pack_ic_size, + sizeof(int8_t) * IW * pack_ic_size); + } + sptr = padding_src; + +#define KERN(_size) \ + direct_stride2_##_size##x##_size##_int8( \ + sptr, fptr, bptr, dst, IH2, IW2, OH, OW, op); + DISPATCH_FILTER_CHANNEL_WISE(filter, KERN); +#undef KERN +} + +SmallVector stride2::get_kimpls( + const NCBKernSizeParam& param) { + auto fm = param.filter_meta; + size_t N = param.n; + size_t group = fm.group / 4; + megdnn_assert(fm.group % 4 == 0, + "nchw44 channel wise conv with group is not times of 4"); + WorkspaceBundle wbundle = get_bundle(param); + bool quantized = param.dst_type.enumv() == DTypeEnum::QuantizedS8; + conv_fun do_conv_fun = nullptr; + +#define DO_CONV_KERN_FUN(quantized, filter, bias_mode, op) \ + MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw44_stride2, \ + midout_iv(#quantized #filter #bias_mode #op##_hash)) { \ + do_conv_fun = do_conv_kern; \ + } \ + MIDOUT_END(); + + DISPATCH_CONV_KERN(); + megdnn_assert(do_conv_fun); + + SmallVector ret_kerns; + auto exec_one_group = [wbundle, do_conv_fun]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + do_conv_fun(wbundle, kern_param, ncb_index); + }; + ret_kerns.push_back({exec_one_group, {N, group}}); + return ret_kerns; +#undef DISPATCH_CONV_KERN +#undef GET_BIAS_MODE_PARAM +#undef GET_OP_PARAM +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8/channel_wise_nchw44.h b/dnn/src/arm_common/conv_bias/int8/channel_wise_nchw44.h new file mode 100644 index 00000000..46efaaa4 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/channel_wise_nchw44.h @@ -0,0 +1,57 @@ +/** + * \file dnn/src/arm_common/conv_bias/int8/channel_wise_nchw44.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#pragma once + +#include "src/arm_common/conv_bias/opr_impl.h" + +namespace megdnn { +namespace arm_common { +namespace channel_wise_nchw44 { + +using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam; +using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam; +using NCBKernIndex = fallback::ConvBiasImpl::NCBKernIndex; + +using conv_fun = std::function; + +namespace stride1 { + +bool is_available(const NCBKernSizeParam& param); + +WorkspaceBundle get_bundle(const NCBKernSizeParam& param); + +template +void do_conv_kern(WorkspaceBundle bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index); + +SmallVector get_kimpls(const NCBKernSizeParam& param); +} // namespace stride1 + +namespace stride2 { +bool is_available(const NCBKernSizeParam& param); + +WorkspaceBundle get_bundle(const NCBKernSizeParam& param); + +template +void do_conv_kern(WorkspaceBundle bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index); + +SmallVector get_kimpls(const NCBKernSizeParam& param); + +} // namespace stride2 +} // namespace direct_int8_stride1 +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8/direct.cpp b/dnn/src/arm_common/conv_bias/int8/direct.cpp new file mode 100644 index 00000000..debb4ab4 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/direct.cpp @@ -0,0 +1,2181 @@ +/** + * \file dnn/src/arm_common/conv_bias/int8/direct.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#include "src/arm_common/conv_bias/int8/direct.h" +#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/utils.h" +#include "src/fallback/conv_bias/common.h" + +using namespace megdnn; +using namespace arm_common; + +#define ACC_S16_S32(dst0, dst1, src) \ + dst0 = vaddw_s16(dst0, vget_low_s16(src)); \ + dst1 = vaddw_s16(dst1, vget_high_s16(src)); + +#define POSTPROCESS(dst0, dst1, tptr, dptr) \ + if (last_ic) { \ + op({{dst0, dst1}}, reinterpret_cast(dptr)); \ + } else { \ + vst1q_s32(tptr, dst0); \ + vst1q_s32(tptr + 4, dst1); \ + } + +template +void conv_bias::conv_direct_stride1_2x2_int8_nchw(const int8_t* src, + const int8_t* filter, + const int32_t* bias, int32_t* temp, + int8_t* dst, const size_t IH, + const size_t IW, const size_t OH, + const size_t OW, const Op& op) { + MEGDNN_MARK_USED_VAR(IH); + int8x8_t k00 = vdup_n_s8(filter[0]); + int8x8_t k01 = vdup_n_s8(filter[1]); + int8x8_t k10 = vdup_n_s8(filter[2]); + int8x8_t k11 = vdup_n_s8(filter[3]); + + // 4x8 block + size_t oh = 0; + for (; oh + 4 <= OH; oh += 4) { + size_t ih = oh; + for (size_t ow = 0; ow < OW; ow += 8) { + size_t iw = ow; + int32_t* __restrict tptr = temp + oh * OW + ow; + int8_t* __restrict dptr = dst + oh * OW + ow; + const int8_t* __restrict sptr = src + ih * IW + iw; + const int32_t* __restrict bptr = bias; + int32x4_t sum00, sum01, sum10, sum11, sum20, sum21, sum30, sum31; + + if (!first_ic) { + sum00 = vld1q_s32(tptr + 0 * OW); + sum01 = vld1q_s32(tptr + 0 * OW + 4); + sum10 = vld1q_s32(tptr + 1 * OW); + sum11 = vld1q_s32(tptr + 1 * OW + 4); + sum20 = vld1q_s32(tptr + 2 * OW); + sum21 = vld1q_s32(tptr + 2 * OW + 4); + sum30 = vld1q_s32(tptr + 3 * OW); + sum31 = vld1q_s32(tptr + 3 * OW + 4); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + sum00 = vdupq_n_s32(bptr[0]); + sum01 = sum00; + sum10 = sum00; + sum11 = sum00; + sum20 = sum00; + sum21 = sum00; + sum30 = sum00; + sum31 = sum00; + } else { + sum00 = vdupq_n_s32(0); + sum01 = vdupq_n_s32(0); + sum10 = vdupq_n_s32(0); + sum11 = vdupq_n_s32(0); + sum20 = vdupq_n_s32(0); + sum21 = vdupq_n_s32(0); + sum30 = vdupq_n_s32(0); + sum31 = vdupq_n_s32(0); + } + } + + int8x8_t s = vld1_s8(sptr + 0 * IW); + int16x8_t d0 = vmull_s8(k00, s); + + s = vld1_s8(sptr + 1 * IW); + d0 = vmlal_s8(d0, k10, s); + ACC_S16_S32(sum00, sum01, d0); + int16x8_t d1 = vmull_s8(k00, s); + + s = vld1_s8(sptr + 2 * IW); + d1 = vmlal_s8(d1, k10, s); + ACC_S16_S32(sum10, sum11, d1); + int16x8_t d2 = vmull_s8(k00, s); + + s = vld1_s8(sptr + 3 * IW); + d2 = vmlal_s8(d2, k10, s); + ACC_S16_S32(sum20, sum21, d2); + int16x8_t d3 = vmull_s8(k00, s); + + s = vld1_s8(sptr + 4 * IW); + d3 = vmlal_s8(d3, k10, s); + ACC_S16_S32(sum30, sum31, d3); + + ++sptr; + + s = vld1_s8(sptr + 0 * IW); + d0 = vmull_s8(k01, s); + + s = vld1_s8(sptr + 1 * IW); + d0 = vmlal_s8(d0, k11, s); + ACC_S16_S32(sum00, sum01, d0); + POSTPROCESS(sum00, sum01, tptr + 0 * OW, dptr + 0 * OW); + d1 = vmull_s8(k01, s); + + s = vld1_s8(sptr + 2 * IW); + d1 = vmlal_s8(d1, k11, s); + ACC_S16_S32(sum10, sum11, d1); + POSTPROCESS(sum10, sum11, tptr + 1 * OW, dptr + 1 * OW); + d2 = vmull_s8(k01, s); + + s = vld1_s8(sptr + 3 * IW); + d2 = vmlal_s8(d2, k11, s); + ACC_S16_S32(sum20, sum21, d2); + POSTPROCESS(sum20, sum21, tptr + 2 * OW, dptr + 2 * OW); + d3 = vmull_s8(k01, s); + + s = vld1_s8(sptr + 4 * IW); + d3 = vmlal_s8(d3, k11, s); + + ACC_S16_S32(sum30, sum31, d3); + POSTPROCESS(sum30, sum31, tptr + 3 * OW, dptr + 3 * OW); + } + } + if (oh + 3 == 
OH) { + size_t ih = oh; + for (size_t ow = 0; ow < OW; ow += 8) { + size_t iw = ow; + int32_t* __restrict tptr = temp + oh * OW + ow; + int8_t* __restrict dptr = dst + oh * OW + ow; + const int8_t* __restrict sptr = src + ih * IW + iw; + const int32_t* __restrict bptr = bias; + int32x4_t sum00, sum01, sum10, sum11, sum20, sum21; + + if (!first_ic) { + sum00 = vld1q_s32(tptr + 0 * OW); + sum01 = vld1q_s32(tptr + 0 * OW + 4); + sum10 = vld1q_s32(tptr + 1 * OW); + sum11 = vld1q_s32(tptr + 1 * OW + 4); + sum20 = vld1q_s32(tptr + 2 * OW); + sum21 = vld1q_s32(tptr + 2 * OW + 4); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + sum00 = vdupq_n_s32(bptr[0]); + sum01 = sum00; + sum10 = sum00; + sum11 = sum00; + sum20 = sum00; + sum21 = sum00; + } else { + sum00 = vdupq_n_s32(0); + sum01 = vdupq_n_s32(0); + sum10 = vdupq_n_s32(0); + sum11 = vdupq_n_s32(0); + sum20 = vdupq_n_s32(0); + sum21 = vdupq_n_s32(0); + } + } + + int8x8_t s = vld1_s8(sptr + 0 * IW); + int16x8_t d0 = vmull_s8(k00, s); + + s = vld1_s8(sptr + 1 * IW); + d0 = vmlal_s8(d0, k10, s); + ACC_S16_S32(sum00, sum01, d0); + int16x8_t d1 = vmull_s8(k00, s); + + s = vld1_s8(sptr + 2 * IW); + d1 = vmlal_s8(d1, k10, s); + ACC_S16_S32(sum10, sum11, d1); + int16x8_t d2 = vmull_s8(k00, s); + + s = vld1_s8(sptr + 3 * IW); + d2 = vmlal_s8(d2, k10, s); + ACC_S16_S32(sum20, sum21, d2); + + ++sptr; + + s = vld1_s8(sptr + 0 * IW); + d0 = vmull_s8(k01, s); + + s = vld1_s8(sptr + 1 * IW); + d0 = vmlal_s8(d0, k11, s); + ACC_S16_S32(sum00, sum01, d0); + POSTPROCESS(sum00, sum01, tptr + 0 * OW, dptr + 0 * OW); + ; + d1 = vmull_s8(k01, s); + + s = vld1_s8(sptr + 2 * IW); + d1 = vmlal_s8(d1, k11, s); + ACC_S16_S32(sum10, sum11, d1); + POSTPROCESS(sum10, sum11, tptr + 1 * OW, dptr + 1 * OW); + d2 = vmull_s8(k01, s); + + s = vld1_s8(sptr + 3 * IW); + d2 = vmlal_s8(d2, k11, s); + ACC_S16_S32(sum20, sum21, d2); + POSTPROCESS(sum20, sum21, tptr + 2 * OW, dptr + 2 * OW); + } + } else if (oh + 2 == OH) { + size_t ih = oh; + for (size_t ow = 0; ow < OW; ow += 8) { + size_t iw = ow; + int32_t* __restrict tptr = temp + oh * OW + ow; + int8_t* __restrict dptr = dst + oh * OW + ow; + const int8_t* __restrict sptr = src + ih * IW + iw; + const int32_t* __restrict bptr = bias; + int32x4_t sum00, sum01, sum10, sum11; + + if (!first_ic) { + sum00 = vld1q_s32(tptr + 0 * OW); + sum01 = vld1q_s32(tptr + 0 * OW + 4); + sum10 = vld1q_s32(tptr + 1 * OW); + sum11 = vld1q_s32(tptr + 1 * OW + 4); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + sum00 = vdupq_n_s32(bptr[0]); + sum01 = sum00; + sum10 = sum00; + sum11 = sum00; + } else { + sum00 = vdupq_n_s32(0); + sum01 = vdupq_n_s32(0); + sum10 = vdupq_n_s32(0); + sum11 = vdupq_n_s32(0); + } + } + + int8x8_t s = vld1_s8(sptr + 0 * IW); + int16x8_t d0 = vmull_s8(k00, s); + + s = vld1_s8(sptr + 1 * IW); + d0 = vmlal_s8(d0, k10, s); + int16x8_t d1 = vmull_s8(k00, s); + + s = vld1_s8(sptr + 2 * IW); + d1 = vmlal_s8(d1, k10, s); + + ACC_S16_S32(sum00, sum01, d0); + ACC_S16_S32(sum10, sum11, d1); + + ++sptr; + + s = vld1_s8(sptr + 0 * IW); + d0 = vmull_s8(k01, s); + + s = vld1_s8(sptr + 1 * IW); + d0 = vmlal_s8(d0, k11, s); + ACC_S16_S32(sum00, sum01, d0); + POSTPROCESS(sum00, sum01, tptr + 0 * OW, dptr + 0 * OW); + d1 = vmull_s8(k01, s); + + s = vld1_s8(sptr + 2 * IW); + d1 = vmlal_s8(d1, k11, s); + ACC_S16_S32(sum10, sum11, d1); + POSTPROCESS(sum10, sum11, tptr + 1 * OW, dptr + 1 * OW); + } + } else if (oh + 1 == OH) { + size_t ih = oh; + for (size_t ow = 0; ow < OW; ow += 8) { + size_t iw = ow; + 
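ACC_S16_S32 and POSTPROCESS capture the accumulation strategy every kernel in this file follows: int8 taps are multiplied into an int16x8 with vmull_s8/vmlal_s8, but only a couple of products are allowed to sit in 16 bits before the partial sum is widened into two int32x4 accumulators, since piling more int8*int8 products into an int16 lane would risk overflow. Between input channels the int32 partials are spilled to the temp buffer, and only on the last input channel does POSTPROCESS hand them to the Op. A stand-alone sketch of the pattern for one 3-tap row and 8 output columns (assumptions: src is padded, acc already holds bias or earlier-channel partials):

    // Widen-early accumulation sketch, compile for ARM with <arm_neon.h>.
    #include <arm_neon.h>
    #include <cstdint>

    static void acc_3tap_row(const int8_t* row, const int8_t* taps,
                             int32_t* acc /* 8 int32 partial sums */) {
        int32x4_t sum0 = vld1q_s32(acc);
        int32x4_t sum1 = vld1q_s32(acc + 4);

        int8x8_t r0 = vld1_s8(row);            // outputs 0..7, tap 0
        int8x8_t rn = vld1_s8(row + 8);
        int8x8_t r1 = vext_s8(r0, rn, 1);      // shifted view, tap 1
        int8x8_t r2 = vext_s8(r0, rn, 2);      // shifted view, tap 2

        int16x8_t d = vmull_s8(r0, vdup_n_s8(taps[0]));   // 1st product
        d = vmlal_s8(d, r1, vdup_n_s8(taps[1]));          // 2nd product
        sum0 = vaddw_s16(sum0, vget_low_s16(d));          // widen: ACC_S16_S32
        sum1 = vaddw_s16(sum1, vget_high_s16(d));
        d = vmull_s8(r2, vdup_n_s8(taps[2]));             // fresh int16 acc
        sum0 = vaddw_s16(sum0, vget_low_s16(d));
        sum1 = vaddw_s16(sum1, vget_high_s16(d));

        vst1q_s32(acc, sum0);       // spill, as POSTPROCESS does when !last_ic
        vst1q_s32(acc + 4, sum1);
    }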
int32_t* __restrict tptr = temp + oh * OW + ow; + int8_t* __restrict dptr = dst + oh * OW + ow; + const int8_t* __restrict sptr = src + ih * IW + iw; + const int32_t* __restrict bptr = bias; + int32x4_t sum00, sum01; + + if (!first_ic) { + sum00 = vld1q_s32(tptr + 0 * OW); + sum01 = vld1q_s32(tptr + 0 * OW + 4); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + sum00 = vdupq_n_s32(bptr[0]); + sum01 = sum00; + } else { + sum00 = vdupq_n_s32(0); + sum01 = vdupq_n_s32(0); + } + } + + int8x8_t s = vld1_s8(sptr + 0 * IW); + int16x8_t d0 = vmull_s8(k00, s); + + s = vld1_s8(sptr + 1 * IW); + d0 = vmlal_s8(d0, k10, s); + ACC_S16_S32(sum00, sum01, d0); + + ++sptr; + + s = vld1_s8(sptr + 0 * IW); + d0 = vmull_s8(k01, s); + + s = vld1_s8(sptr + 1 * IW); + d0 = vmlal_s8(d0, k11, s); + ACC_S16_S32(sum00, sum01, d0); + POSTPROCESS(sum00, sum01, tptr + 0 * OW, dptr + 0 * OW); + } + } +} + +template +void conv_bias::conv_direct_stride1_3x3_int8_nchw(const int8_t* src, + const int8_t* filter, + const int32_t* bias, int32_t* temp, + int8_t* dst, const size_t IH, + const size_t IW, const size_t OH, + const size_t OW, const Op& op) { + MEGDNN_MARK_USED_VAR(IH); + int8x8_t k00 = vdup_n_s8(filter[0]); + int8x8_t k01 = vdup_n_s8(filter[1]); + int8x8_t k02 = vdup_n_s8(filter[2]); + int8x8_t k10 = vdup_n_s8(filter[3]); + int8x8_t k11 = vdup_n_s8(filter[4]); + int8x8_t k12 = vdup_n_s8(filter[5]); + int8x8_t k20 = vdup_n_s8(filter[6]); + int8x8_t k21 = vdup_n_s8(filter[7]); + int8x8_t k22 = vdup_n_s8(filter[8]); + + // block 2x8 + size_t oh = 0; + for (; oh + 1 < OH; oh += 2) { + size_t ih = oh; + for (size_t ow = 0; ow < OW; ow += 8) { + size_t iw = ow; + int32_t* __restrict tptr = temp + oh * OW + ow; + int8_t* __restrict dptr = dst + oh * OW + ow; + const int8_t* __restrict sptr = src + ih * IW + iw; + const int32_t* __restrict bptr = bias; + int32x4_t sum00, sum01, sum10, sum11; + + if (!first_ic) { + sum00 = vld1q_s32(tptr + 0 * OW); + sum01 = vld1q_s32(tptr + 0 * OW + 4); + sum10 = vld1q_s32(tptr + 1 * OW); + sum11 = vld1q_s32(tptr + 1 * OW + 4); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + sum00 = vdupq_n_s32(bptr[0]); + sum01 = sum00; + sum10 = sum00; + sum11 = sum00; + } else { + sum00 = vdupq_n_s32(0); + sum01 = vdupq_n_s32(0); + sum10 = vdupq_n_s32(0); + sum11 = vdupq_n_s32(0); + } + } + + int8x8_t _r00 = vld1_s8(sptr + 0 * IW); + int8x8_t _r0n = vld1_s8(sptr + 0 * IW + 8); + int8x8_t _r01 = vext_s8(_r00, _r0n, 1); + int8x8_t _r02 = vext_s8(_r00, _r0n, 2); + + int16x8_t d0 = vmull_s8(_r00, k00); + d0 = vmlal_s8(d0, _r01, k01); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r02, k02); + + int8x8_t _r10 = vld1_s8(sptr + 1 * IW); + int8x8_t _r1n = vld1_s8(sptr + 1 * IW + 8); + int8x8_t _r11 = vext_s8(_r10, _r1n, 1); + int8x8_t _r12 = vext_s8(_r10, _r1n, 2); + d0 = vmlal_s8(d0, _r10, k10); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r11, k11); + d0 = vmlal_s8(d0, _r12, k12); + ACC_S16_S32(sum00, sum01, d0); + int16x8_t d1 = vmull_s8(_r10, k00); + d1 = vmlal_s8(d1, _r11, k01); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r12, k02); + + int8x8_t _r20 = vld1_s8(sptr + 2 * IW); + int8x8_t _r2n = vld1_s8(sptr + 2 * IW + 8); + int8x8_t _r21 = vext_s8(_r20, _r2n, 1); + int8x8_t _r22 = vext_s8(_r20, _r2n, 2); + d0 = vmull_s8(_r20, k20); + d0 = vmlal_s8(d0, _r21, k21); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r22, k22); + ACC_S16_S32(sum00, sum01, d0); + POSTPROCESS(sum00, sum01, tptr + 0 * OW, dptr + 0 * OW); + d1 = vmlal_s8(d1, _r20, k10); + 
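POSTPROCESS only stores qint8 on the last input channel, delegating rounding and saturation to the Op functor built from (scale_bias, scale_dst) in do_conv_kern. A scalar picture of what the TypeCvtOp / ReluOp / HSwishOp family is assumed to do (the authoritative versions live in src/arm_common/elemwise_op.h); the int32 accumulator is in the bias scale, i.e. src_scale * filter_scale:

    // Scalar post-process sketch, illustrative only.
    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    static int8_t requantize(int32_t acc, float scale_bias, float scale_dst,
                             bool relu) {
        float f = acc * scale_bias;             // back to real-valued domain
        if (relu) f = std::max(f, 0.0f);        // ReluOp; IDENTITY skips this
        int r = static_cast<int>(std::round(f / scale_dst));
        return static_cast<int8_t>(std::min(127, std::max(-128, r)));  // saturate
    }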
ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r21, k11); + d1 = vmlal_s8(d1, _r22, k12); + ACC_S16_S32(sum10, sum11, d1); + + int8x8_t _r30 = vld1_s8(sptr + 3 * IW); + int8x8_t _r3n = vld1_s8(sptr + 3 * IW + 8); + int8x8_t _r31 = vext_s8(_r30, _r3n, 1); + int8x8_t _r32 = vext_s8(_r30, _r3n, 2); + d1 = vmull_s8(_r30, k20); + d1 = vmlal_s8(d1, _r31, k21); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r32, k22); + ACC_S16_S32(sum10, sum11, d1); + POSTPROCESS(sum10, sum11, tptr + 1 * OW, dptr + 1 * OW); + } + } + + if (oh < OH) { + size_t ih = oh; + for (size_t ow = 0; ow < OW; ow += 8) { + size_t iw = ow; + int32_t* __restrict tptr = temp + oh * OW + ow; + int8_t* __restrict dptr = dst + oh * OW + ow; + const int8_t* __restrict sptr = src + ih * IW + iw; + const int32_t* __restrict bptr = bias; + int32x4_t sum00, sum01; + + if (!first_ic) { + sum00 = vld1q_s32(tptr + 0 * OW); + sum01 = vld1q_s32(tptr + 0 * OW + 4); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + sum00 = vdupq_n_s32(bptr[0]); + sum01 = sum00; + } else { + sum00 = vdupq_n_s32(0); + sum01 = vdupq_n_s32(0); + } + } + int8x8_t _r00 = vld1_s8(sptr + 0 * IW); + int8x8_t _r0n = vld1_s8(sptr + 0 * IW + 8); + int8x8_t _r01 = vext_s8(_r00, _r0n, 1); + int8x8_t _r02 = vext_s8(_r00, _r0n, 2); + + int16x8_t d0 = vmull_s8(_r00, k00); + d0 = vmlal_s8(d0, _r01, k01); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r02, k02); + + int8x8_t _r10 = vld1_s8(sptr + 1 * IW); + int8x8_t _r1n = vld1_s8(sptr + 1 * IW + 8); + int8x8_t _r11 = vext_s8(_r10, _r1n, 1); + int8x8_t _r12 = vext_s8(_r10, _r1n, 2); + d0 = vmlal_s8(d0, _r10, k10); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r11, k11); + d0 = vmlal_s8(d0, _r12, k12); + ACC_S16_S32(sum00, sum01, d0); + + int8x8_t _r20 = vld1_s8(sptr + 2 * IW); + int8x8_t _r2n = vld1_s8(sptr + 2 * IW + 8); + int8x8_t _r21 = vext_s8(_r20, _r2n, 1); + int8x8_t _r22 = vext_s8(_r20, _r2n, 2); + d0 = vmull_s8(_r20, k20); + d0 = vmlal_s8(d0, _r21, k21); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r22, k22); + ACC_S16_S32(sum00, sum01, d0); + + POSTPROCESS(sum00, sum01, tptr + 0 * OW, dptr + 0 * OW); + } + } +} + +template +void conv_bias::conv_direct_stride1_5x5_int8_nchw(const int8_t* src, + const int8_t* filter, + const int32_t* bias, int32_t* temp, + int8_t* dst, const size_t IH, + const size_t IW, const size_t OH, + const size_t OW, const Op& op) { + MEGDNN_MARK_USED_VAR(IH); + int8x8_t k00 = vdup_n_s8(filter[0]); + int8x8_t k01 = vdup_n_s8(filter[1]); + int8x8_t k02 = vdup_n_s8(filter[2]); + int8x8_t k03 = vdup_n_s8(filter[3]); + int8x8_t k04 = vdup_n_s8(filter[4]); + int8x8_t k10 = vdup_n_s8(filter[5]); + int8x8_t k11 = vdup_n_s8(filter[6]); + int8x8_t k12 = vdup_n_s8(filter[7]); + int8x8_t k13 = vdup_n_s8(filter[8]); + int8x8_t k14 = vdup_n_s8(filter[9]); + int8x8_t k20 = vdup_n_s8(filter[10]); + int8x8_t k21 = vdup_n_s8(filter[11]); + int8x8_t k22 = vdup_n_s8(filter[12]); + int8x8_t k23 = vdup_n_s8(filter[13]); + int8x8_t k24 = vdup_n_s8(filter[14]); + int8x8_t k30 = vdup_n_s8(filter[15]); + int8x8_t k31 = vdup_n_s8(filter[16]); + int8x8_t k32 = vdup_n_s8(filter[17]); + int8x8_t k33 = vdup_n_s8(filter[18]); + int8x8_t k34 = vdup_n_s8(filter[19]); + int8x8_t k40 = vdup_n_s8(filter[20]); + int8x8_t k41 = vdup_n_s8(filter[21]); + int8x8_t k42 = vdup_n_s8(filter[22]); + int8x8_t k43 = vdup_n_s8(filter[23]); + int8x8_t k44 = vdup_n_s8(filter[24]); + + // block 2x8 + size_t oh = 0; + for (; oh + 1 < OH; oh += 2) { + size_t ih = oh; + for (size_t ow = 0; ow < OW; ow += 8) { + 
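From the 3x3 case onward the stride-1 kernels compute two output rows per pass, and the interleaving of d0 and d1 above is what makes that pay off: a loaded input row finishes the lower taps of output row oh and immediately starts the upper taps of output row oh+1, so a 3-tall filter touches four input rows per 2x8 block instead of six. A scalar picture of the reuse for a single output column (names illustrative):

    // Row-reuse sketch: one loaded input row feeds both output rows.
    #include <cstddef>
    #include <cstdint>

    static void two_row_block_3tap(const int8_t* src, size_t row_pitch /* IW */,
                                   const int8_t* k /* 3 vertical taps */,
                                   int32_t& out0, int32_t& out1) {
        // 4 input rows serve 2 output rows (instead of 2 * 3 = 6 separate reads)
        out0 += src[0 * row_pitch] * k[0] + src[1 * row_pitch] * k[1] +
                src[2 * row_pitch] * k[2];                     // output row oh
        out1 += src[1 * row_pitch] * k[0] + src[2 * row_pitch] * k[1] +
                src[3 * row_pitch] * k[2];                     // output row oh+1
    }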
size_t iw = ow; + int32_t* __restrict tptr = temp + oh * OW + ow; + int8_t* __restrict dptr = dst + oh * OW + ow; + const int8_t* __restrict sptr = src + ih * IW + iw; + const int32_t* __restrict bptr = bias; + int32x4_t sum00, sum01, sum10, sum11; + + if (!first_ic) { + sum00 = vld1q_s32(tptr + 0 * OW); + sum01 = vld1q_s32(tptr + 0 * OW + 4); + sum10 = vld1q_s32(tptr + 1 * OW); + sum11 = vld1q_s32(tptr + 1 * OW + 4); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + sum00 = vdupq_n_s32(bptr[0]); + sum01 = sum00; + sum10 = sum00; + sum11 = sum00; + } else { + sum00 = vdupq_n_s32(0); + sum01 = vdupq_n_s32(0); + sum10 = vdupq_n_s32(0); + sum11 = vdupq_n_s32(0); + } + } + + int8x8_t _r00 = vld1_s8(sptr + 0 * IW); + int8x8_t _r0n = vld1_s8(sptr + 0 * IW + 8); + int8x8_t _r01 = vext_s8(_r00, _r0n, 1); + int8x8_t _r02 = vext_s8(_r00, _r0n, 2); + int8x8_t _r03 = vext_s8(_r00, _r0n, 3); + int8x8_t _r04 = vext_s8(_r00, _r0n, 4); + int16x8_t d0 = vmull_s8(_r00, k00); + d0 = vmlal_s8(d0, _r01, k01); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r02, k02); + d0 = vmlal_s8(d0, _r03, k03); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r04, k04); + + int8x8_t _r10 = vld1_s8(sptr + 1 * IW); + int8x8_t _r1n = vld1_s8(sptr + 1 * IW + 8); + int8x8_t _r11 = vext_s8(_r10, _r1n, 1); + int8x8_t _r12 = vext_s8(_r10, _r1n, 2); + int8x8_t _r13 = vext_s8(_r10, _r1n, 3); + int8x8_t _r14 = vext_s8(_r10, _r1n, 4); + d0 = vmlal_s8(d0, _r10, k10); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r11, k11); + d0 = vmlal_s8(d0, _r12, k12); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r13, k13); + d0 = vmlal_s8(d0, _r14, k14); + ACC_S16_S32(sum00, sum01, d0); + int16x8_t d1 = vmull_s8(_r10, k00); + d1 = vmlal_s8(d1, _r11, k01); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r12, k02); + d1 = vmlal_s8(d1, _r13, k03); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r14, k04); + + int8x8_t _r20 = vld1_s8(sptr + 2 * IW); + int8x8_t _r2n = vld1_s8(sptr + 2 * IW + 8); + int8x8_t _r21 = vext_s8(_r20, _r2n, 1); + int8x8_t _r22 = vext_s8(_r20, _r2n, 2); + int8x8_t _r23 = vext_s8(_r20, _r2n, 3); + int8x8_t _r24 = vext_s8(_r20, _r2n, 4); + d0 = vmull_s8(_r20, k20); + d0 = vmlal_s8(d0, _r21, k21); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r22, k22); + d0 = vmlal_s8(d0, _r23, k23); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r24, k24); + d1 = vmlal_s8(d1, _r20, k10); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r21, k11); + d1 = vmlal_s8(d1, _r22, k12); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r23, k13); + d1 = vmlal_s8(d1, _r24, k14); + ACC_S16_S32(sum10, sum11, d1); + + int8x8_t _r30 = vld1_s8(sptr + 3 * IW); + int8x8_t _r3n = vld1_s8(sptr + 3 * IW + 8); + int8x8_t _r31 = vext_s8(_r30, _r3n, 1); + int8x8_t _r32 = vext_s8(_r30, _r3n, 2); + int8x8_t _r33 = vext_s8(_r30, _r3n, 3); + int8x8_t _r34 = vext_s8(_r30, _r3n, 4); + d0 = vmlal_s8(d0, _r30, k30); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r31, k31); + d0 = vmlal_s8(d0, _r32, k32); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r33, k33); + d0 = vmlal_s8(d0, _r34, k34); + ACC_S16_S32(sum00, sum01, d0); + d1 = vmull_s8(_r30, k20); + d1 = vmlal_s8(d1, _r31, k21); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r32, k22); + d1 = vmlal_s8(d1, _r33, k23); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r34, k24); + + int8x8_t _r40 = vld1_s8(sptr + 4 * IW); + int8x8_t _r4n = vld1_s8(sptr + 4 * IW + 8); + int8x8_t _r41 = vext_s8(_r40, _r4n, 1); + int8x8_t _r42 = vext_s8(_r40, _r4n, 2); + int8x8_t _r43 = 
vext_s8(_r40, _r4n, 3); + int8x8_t _r44 = vext_s8(_r40, _r4n, 4); + d0 = vmull_s8(_r40, k40); + d0 = vmlal_s8(d0, _r41, k41); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r42, k42); + d0 = vmlal_s8(d0, _r43, k43); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r44, k44); + ACC_S16_S32(sum00, sum01, d0); + POSTPROCESS(sum00, sum01, tptr + 0 * OW, dptr + 0 * OW); + d1 = vmlal_s8(d1, _r40, k30); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r41, k31); + d1 = vmlal_s8(d1, _r42, k32); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r43, k33); + d1 = vmlal_s8(d1, _r44, k34); + ACC_S16_S32(sum10, sum11, d1); + + int8x8_t _r50 = vld1_s8(sptr + 5 * IW); + int8x8_t _r5n = vld1_s8(sptr + 5 * IW + 8); + int8x8_t _r51 = vext_s8(_r50, _r5n, 1); + int8x8_t _r52 = vext_s8(_r50, _r5n, 2); + int8x8_t _r53 = vext_s8(_r50, _r5n, 3); + int8x8_t _r54 = vext_s8(_r50, _r5n, 4); + d1 = vmull_s8(_r50, k40); + d1 = vmlal_s8(d1, _r51, k41); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r52, k42); + d1 = vmlal_s8(d1, _r53, k43); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r54, k44); + ACC_S16_S32(sum10, sum11, d1); + POSTPROCESS(sum10, sum11, tptr + 1 * OW, dptr + 1 * OW); + } + } + + if (oh < OH) { + size_t ih = oh; + for (size_t ow = 0; ow < OW; ow += 8) { + size_t iw = ow; + int32_t* __restrict tptr = temp + oh * OW + ow; + int8_t* __restrict dptr = dst + oh * OW + ow; + const int8_t* __restrict sptr = src + ih * IW + iw; + const int32_t* __restrict bptr = bias; + int32x4_t sum00, sum01; + + if (!first_ic) { + sum00 = vld1q_s32(tptr + 0 * OW); + sum01 = vld1q_s32(tptr + 0 * OW + 4); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + sum00 = vdupq_n_s32(bptr[0]); + sum01 = sum00; + } else { + sum00 = vdupq_n_s32(0); + sum01 = vdupq_n_s32(0); + } + } + + int8x8_t _r00 = vld1_s8(sptr + 0 * IW); + int8x8_t _r0n = vld1_s8(sptr + 0 * IW + 8); + int8x8_t _r01 = vext_s8(_r00, _r0n, 1); + int8x8_t _r02 = vext_s8(_r00, _r0n, 2); + int8x8_t _r03 = vext_s8(_r00, _r0n, 3); + int8x8_t _r04 = vext_s8(_r00, _r0n, 4); + int16x8_t d0 = vmull_s8(_r00, k00); + d0 = vmlal_s8(d0, _r01, k01); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r02, k02); + d0 = vmlal_s8(d0, _r03, k03); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r04, k04); + + int8x8_t _r10 = vld1_s8(sptr + 1 * IW); + int8x8_t _r1n = vld1_s8(sptr + 1 * IW + 8); + int8x8_t _r11 = vext_s8(_r10, _r1n, 1); + int8x8_t _r12 = vext_s8(_r10, _r1n, 2); + int8x8_t _r13 = vext_s8(_r10, _r1n, 3); + int8x8_t _r14 = vext_s8(_r10, _r1n, 4); + d0 = vmlal_s8(d0, _r10, k10); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r11, k11); + d0 = vmlal_s8(d0, _r12, k12); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r13, k13); + d0 = vmlal_s8(d0, _r14, k14); + ACC_S16_S32(sum00, sum01, d0); + + int8x8_t _r20 = vld1_s8(sptr + 2 * IW); + int8x8_t _r2n = vld1_s8(sptr + 2 * IW + 8); + int8x8_t _r21 = vext_s8(_r20, _r2n, 1); + int8x8_t _r22 = vext_s8(_r20, _r2n, 2); + int8x8_t _r23 = vext_s8(_r20, _r2n, 3); + int8x8_t _r24 = vext_s8(_r20, _r2n, 4); + d0 = vmull_s8(_r20, k20); + d0 = vmlal_s8(d0, _r21, k21); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r22, k22); + d0 = vmlal_s8(d0, _r23, k23); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r24, k24); + + int8x8_t _r30 = vld1_s8(sptr + 3 * IW); + int8x8_t _r3n = vld1_s8(sptr + 3 * IW + 8); + int8x8_t _r31 = vext_s8(_r30, _r3n, 1); + int8x8_t _r32 = vext_s8(_r30, _r3n, 2); + int8x8_t _r33 = vext_s8(_r30, _r3n, 3); + int8x8_t _r34 = vext_s8(_r30, _r3n, 4); + d0 = vmlal_s8(d0, _r30, k30); + 
ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r31, k31); + d0 = vmlal_s8(d0, _r32, k32); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r33, k33); + d0 = vmlal_s8(d0, _r34, k34); + ACC_S16_S32(sum00, sum01, d0); + + int8x8_t _r40 = vld1_s8(sptr + 4 * IW); + int8x8_t _r4n = vld1_s8(sptr + 4 * IW + 8); + int8x8_t _r41 = vext_s8(_r40, _r4n, 1); + int8x8_t _r42 = vext_s8(_r40, _r4n, 2); + int8x8_t _r43 = vext_s8(_r40, _r4n, 3); + int8x8_t _r44 = vext_s8(_r40, _r4n, 4); + d0 = vmull_s8(_r40, k40); + d0 = vmlal_s8(d0, _r41, k41); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r42, k42); + d0 = vmlal_s8(d0, _r43, k43); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r44, k44); + ACC_S16_S32(sum00, sum01, d0); + POSTPROCESS(sum00, sum01, tptr + 0 * OW, dptr + 0 * OW); + } + } +} + +template +void conv_bias::conv_direct_stride1_7x7_int8_nchw(const int8_t* src, + const int8_t* filter, + const int32_t* bias, int32_t* temp, + int8_t* dst, const size_t IH, + const size_t IW, const size_t OH, + const size_t OW, const Op& op) { + MEGDNN_MARK_USED_VAR(IH); + int8x8_t k00 = vdup_n_s8(filter[0]); + int8x8_t k01 = vdup_n_s8(filter[1]); + int8x8_t k02 = vdup_n_s8(filter[2]); + int8x8_t k03 = vdup_n_s8(filter[3]); + int8x8_t k04 = vdup_n_s8(filter[4]); + int8x8_t k05 = vdup_n_s8(filter[5]); + int8x8_t k06 = vdup_n_s8(filter[6]); + + int8x8_t k10 = vdup_n_s8(filter[7]); + int8x8_t k11 = vdup_n_s8(filter[8]); + int8x8_t k12 = vdup_n_s8(filter[9]); + int8x8_t k13 = vdup_n_s8(filter[10]); + int8x8_t k14 = vdup_n_s8(filter[11]); + int8x8_t k15 = vdup_n_s8(filter[12]); + int8x8_t k16 = vdup_n_s8(filter[13]); + + int8x8_t k20 = vdup_n_s8(filter[14]); + int8x8_t k21 = vdup_n_s8(filter[15]); + int8x8_t k22 = vdup_n_s8(filter[16]); + int8x8_t k23 = vdup_n_s8(filter[17]); + int8x8_t k24 = vdup_n_s8(filter[18]); + int8x8_t k25 = vdup_n_s8(filter[19]); + int8x8_t k26 = vdup_n_s8(filter[20]); + + int8x8_t k30 = vdup_n_s8(filter[21]); + int8x8_t k31 = vdup_n_s8(filter[22]); + int8x8_t k32 = vdup_n_s8(filter[23]); + int8x8_t k33 = vdup_n_s8(filter[24]); + int8x8_t k34 = vdup_n_s8(filter[25]); + int8x8_t k35 = vdup_n_s8(filter[26]); + int8x8_t k36 = vdup_n_s8(filter[27]); + + int8x8_t k40 = vdup_n_s8(filter[28]); + int8x8_t k41 = vdup_n_s8(filter[29]); + int8x8_t k42 = vdup_n_s8(filter[30]); + int8x8_t k43 = vdup_n_s8(filter[31]); + int8x8_t k44 = vdup_n_s8(filter[32]); + int8x8_t k45 = vdup_n_s8(filter[33]); + int8x8_t k46 = vdup_n_s8(filter[34]); + + int8x8_t k50 = vdup_n_s8(filter[35]); + int8x8_t k51 = vdup_n_s8(filter[36]); + int8x8_t k52 = vdup_n_s8(filter[37]); + int8x8_t k53 = vdup_n_s8(filter[38]); + int8x8_t k54 = vdup_n_s8(filter[39]); + int8x8_t k55 = vdup_n_s8(filter[40]); + int8x8_t k56 = vdup_n_s8(filter[41]); + + int8x8_t k60 = vdup_n_s8(filter[42]); + int8x8_t k61 = vdup_n_s8(filter[43]); + int8x8_t k62 = vdup_n_s8(filter[44]); + int8x8_t k63 = vdup_n_s8(filter[45]); + int8x8_t k64 = vdup_n_s8(filter[46]); + int8x8_t k65 = vdup_n_s8(filter[47]); + int8x8_t k66 = vdup_n_s8(filter[48]); + + // block 2x8 + size_t oh = 0; + for (; oh + 1 < OH; oh += 2) { + size_t ih = oh; + for (size_t ow = 0; ow < OW; ow += 8) { + size_t iw = ow; + int32_t* __restrict tptr = temp + oh * OW + ow; + int8_t* __restrict dptr = dst + oh * OW + ow; + const int8_t* __restrict sptr = src + ih * IW + iw; + const int32_t* __restrict bptr = bias; + int32x4_t sum00, sum01, sum10, sum11; + + if (!first_ic) { + sum00 = vld1q_s32(tptr + 0 * OW); + sum01 = vld1q_s32(tptr + 0 * OW + 4); + sum10 = vld1q_s32(tptr + 1 * OW); + sum11 = 
vld1q_s32(tptr + 1 * OW + 4); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + sum00 = vdupq_n_s32(bptr[0]); + sum01 = sum00; + sum10 = sum00; + sum11 = sum00; + } else { + sum00 = vdupq_n_s32(0); + sum01 = vdupq_n_s32(0); + sum10 = vdupq_n_s32(0); + sum11 = vdupq_n_s32(0); + } + } + + int8x8_t _r00 = vld1_s8(sptr + 0 * IW); + int8x8_t _r0n = vld1_s8(sptr + 0 * IW + 8); + int8x8_t _r01 = vext_s8(_r00, _r0n, 1); + int8x8_t _r02 = vext_s8(_r00, _r0n, 2); + int8x8_t _r03 = vext_s8(_r00, _r0n, 3); + int8x8_t _r04 = vext_s8(_r00, _r0n, 4); + int8x8_t _r05 = vext_s8(_r00, _r0n, 5); + int8x8_t _r06 = vext_s8(_r00, _r0n, 6); + int16x8_t d0 = vmull_s8(_r00, k00); + d0 = vmlal_s8(d0, _r01, k01); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r02, k02); + d0 = vmlal_s8(d0, _r03, k03); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r04, k04); + d0 = vmlal_s8(d0, _r05, k05); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r06, k06); + + int8x8_t _r10 = vld1_s8(sptr + 1 * IW); + int8x8_t _r1n = vld1_s8(sptr + 1 * IW + 8); + int8x8_t _r11 = vext_s8(_r10, _r1n, 1); + int8x8_t _r12 = vext_s8(_r10, _r1n, 2); + int8x8_t _r13 = vext_s8(_r10, _r1n, 3); + int8x8_t _r14 = vext_s8(_r10, _r1n, 4); + int8x8_t _r15 = vext_s8(_r10, _r1n, 5); + int8x8_t _r16 = vext_s8(_r10, _r1n, 6); + d0 = vmlal_s8(d0, _r10, k10); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r11, k11); + d0 = vmlal_s8(d0, _r12, k12); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r13, k13); + d0 = vmlal_s8(d0, _r14, k14); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r15, k15); + d0 = vmlal_s8(d0, _r16, k16); + ACC_S16_S32(sum00, sum01, d0); + int16x8_t d1 = vmull_s8(_r10, k00); + d1 = vmlal_s8(d1, _r11, k01); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r12, k02); + d1 = vmlal_s8(d1, _r13, k03); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r14, k04); + d1 = vmlal_s8(d1, _r15, k05); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r16, k06); + + int8x8_t _r20 = vld1_s8(sptr + 2 * IW); + int8x8_t _r2n = vld1_s8(sptr + 2 * IW + 8); + int8x8_t _r21 = vext_s8(_r20, _r2n, 1); + int8x8_t _r22 = vext_s8(_r20, _r2n, 2); + int8x8_t _r23 = vext_s8(_r20, _r2n, 3); + int8x8_t _r24 = vext_s8(_r20, _r2n, 4); + int8x8_t _r25 = vext_s8(_r20, _r2n, 5); + int8x8_t _r26 = vext_s8(_r20, _r2n, 6); + d0 = vmull_s8(_r20, k20); + d0 = vmlal_s8(d0, _r21, k21); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r22, k22); + d0 = vmlal_s8(d0, _r23, k23); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r24, k24); + d0 = vmlal_s8(d0, _r25, k25); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r26, k26); + d1 = vmlal_s8(d1, _r20, k10); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r21, k11); + d1 = vmlal_s8(d1, _r22, k12); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r23, k13); + d1 = vmlal_s8(d1, _r24, k14); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r25, k15); + d1 = vmlal_s8(d1, _r26, k16); + ACC_S16_S32(sum10, sum11, d1); + + int8x8_t _r30 = vld1_s8(sptr + 3 * IW); + int8x8_t _r3n = vld1_s8(sptr + 3 * IW + 8); + int8x8_t _r31 = vext_s8(_r30, _r3n, 1); + int8x8_t _r32 = vext_s8(_r30, _r3n, 2); + int8x8_t _r33 = vext_s8(_r30, _r3n, 3); + int8x8_t _r34 = vext_s8(_r30, _r3n, 4); + int8x8_t _r35 = vext_s8(_r30, _r3n, 5); + int8x8_t _r36 = vext_s8(_r30, _r3n, 6); + d0 = vmlal_s8(d0, _r30, k30); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r31, k31); + d0 = vmlal_s8(d0, _r32, k32); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r33, k33); + d0 = vmlal_s8(d0, _r34, k34); + ACC_S16_S32(sum00, sum01, d0); + d0 
= vmull_s8(_r35, k35); + d0 = vmlal_s8(d0, _r36, k36); + ACC_S16_S32(sum00, sum01, d0); + d1 = vmull_s8(_r30, k20); + d1 = vmlal_s8(d1, _r31, k21); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r32, k22); + d1 = vmlal_s8(d1, _r33, k23); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r34, k24); + d1 = vmlal_s8(d1, _r35, k25); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r36, k26); + + int8x8_t _r40 = vld1_s8(sptr + 4 * IW); + int8x8_t _r4n = vld1_s8(sptr + 4 * IW + 8); + int8x8_t _r41 = vext_s8(_r40, _r4n, 1); + int8x8_t _r42 = vext_s8(_r40, _r4n, 2); + int8x8_t _r43 = vext_s8(_r40, _r4n, 3); + int8x8_t _r44 = vext_s8(_r40, _r4n, 4); + int8x8_t _r45 = vext_s8(_r40, _r4n, 5); + int8x8_t _r46 = vext_s8(_r40, _r4n, 6); + d0 = vmull_s8(_r40, k40); + d0 = vmlal_s8(d0, _r41, k41); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r42, k42); + d0 = vmlal_s8(d0, _r43, k43); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r44, k44); + d0 = vmlal_s8(d0, _r45, k45); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r46, k46); + d1 = vmlal_s8(d1, _r40, k30); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r41, k31); + d1 = vmlal_s8(d1, _r42, k32); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r43, k33); + d1 = vmlal_s8(d1, _r44, k34); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r45, k35); + d1 = vmlal_s8(d1, _r46, k36); + ACC_S16_S32(sum10, sum11, d1); + + int8x8_t _r50 = vld1_s8(sptr + 5 * IW); + int8x8_t _r5n = vld1_s8(sptr + 5 * IW + 8); + int8x8_t _r51 = vext_s8(_r50, _r5n, 1); + int8x8_t _r52 = vext_s8(_r50, _r5n, 2); + int8x8_t _r53 = vext_s8(_r50, _r5n, 3); + int8x8_t _r54 = vext_s8(_r50, _r5n, 4); + int8x8_t _r55 = vext_s8(_r50, _r5n, 5); + int8x8_t _r56 = vext_s8(_r50, _r5n, 6); + d0 = vmlal_s8(d0, _r50, k50); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r51, k51); + d0 = vmlal_s8(d0, _r52, k52); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r53, k53); + d0 = vmlal_s8(d0, _r54, k54); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r55, k55); + d0 = vmlal_s8(d0, _r56, k56); + ACC_S16_S32(sum00, sum01, d0); + d1 = vmull_s8(_r50, k40); + d1 = vmlal_s8(d1, _r51, k41); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r52, k42); + d1 = vmlal_s8(d1, _r53, k43); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r54, k44); + d1 = vmlal_s8(d1, _r55, k45); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r56, k46); + + int8x8_t _r60 = vld1_s8(sptr + 6 * IW); + int8x8_t _r6n = vld1_s8(sptr + 6 * IW + 8); + int8x8_t _r61 = vext_s8(_r60, _r6n, 1); + int8x8_t _r62 = vext_s8(_r60, _r6n, 2); + int8x8_t _r63 = vext_s8(_r60, _r6n, 3); + int8x8_t _r64 = vext_s8(_r60, _r6n, 4); + int8x8_t _r65 = vext_s8(_r60, _r6n, 5); + int8x8_t _r66 = vext_s8(_r60, _r6n, 6); + d0 = vmull_s8(_r60, k60); + d0 = vmlal_s8(d0, _r61, k61); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r62, k62); + d0 = vmlal_s8(d0, _r63, k63); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r64, k64); + d0 = vmlal_s8(d0, _r65, k65); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r66, k66); + ACC_S16_S32(sum00, sum01, d0); + POSTPROCESS(sum00, sum01, tptr + 0 * OW, dptr + 0 * OW); + d1 = vmlal_s8(d1, _r60, k50); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r61, k51); + d1 = vmlal_s8(d1, _r62, k52); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r63, k53); + d1 = vmlal_s8(d1, _r64, k54); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r65, k55); + d1 = vmlal_s8(d1, _r66, k56); + ACC_S16_S32(sum10, sum11, d1); + + int8x8_t _r70 = vld1_s8(sptr + 7 * IW); + int8x8_t _r7n = vld1_s8(sptr + 7 * IW + 
8); + int8x8_t _r71 = vext_s8(_r70, _r7n, 1); + int8x8_t _r72 = vext_s8(_r70, _r7n, 2); + int8x8_t _r73 = vext_s8(_r70, _r7n, 3); + int8x8_t _r74 = vext_s8(_r70, _r7n, 4); + int8x8_t _r75 = vext_s8(_r70, _r7n, 5); + int8x8_t _r76 = vext_s8(_r70, _r7n, 6); + d1 = vmull_s8(_r70, k60); + d1 = vmlal_s8(d1, _r71, k61); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r72, k62); + d1 = vmlal_s8(d1, _r73, k63); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r74, k64); + d1 = vmlal_s8(d1, _r75, k65); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r76, k66); + ACC_S16_S32(sum10, sum11, d1); + POSTPROCESS(sum10, sum11, tptr + 1 * OW, dptr + 1 * OW); + } + } + + if (oh < OH) { + size_t ih = oh; + for (size_t ow = 0; ow < OW; ow += 8) { + size_t iw = ow; + int32_t* __restrict tptr = temp + oh * OW + ow; + int8_t* __restrict dptr = dst + oh * OW + ow; + const int8_t* __restrict sptr = src + ih * IW + iw; + const int32_t* __restrict bptr = bias; + int32x4_t sum00, sum01; + + if (!first_ic) { + sum00 = vld1q_s32(tptr + 0 * OW); + sum01 = vld1q_s32(tptr + 0 * OW + 4); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + sum00 = vdupq_n_s32(bptr[0]); + sum01 = sum00; + } else { + sum00 = vdupq_n_s32(0); + sum01 = vdupq_n_s32(0); + } + } + + int8x8_t _r00 = vld1_s8(sptr + 0 * IW); + int8x8_t _r0n = vld1_s8(sptr + 0 * IW + 8); + int8x8_t _r01 = vext_s8(_r00, _r0n, 1); + int8x8_t _r02 = vext_s8(_r00, _r0n, 2); + int8x8_t _r03 = vext_s8(_r00, _r0n, 3); + int8x8_t _r04 = vext_s8(_r00, _r0n, 4); + int8x8_t _r05 = vext_s8(_r00, _r0n, 5); + int8x8_t _r06 = vext_s8(_r00, _r0n, 6); + int16x8_t d0 = vmull_s8(_r00, k00); + d0 = vmlal_s8(d0, _r01, k01); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r02, k02); + d0 = vmlal_s8(d0, _r03, k03); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r04, k04); + d0 = vmlal_s8(d0, _r05, k05); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r06, k06); + + int8x8_t _r10 = vld1_s8(sptr + 1 * IW); + int8x8_t _r1n = vld1_s8(sptr + 1 * IW + 8); + int8x8_t _r11 = vext_s8(_r10, _r1n, 1); + int8x8_t _r12 = vext_s8(_r10, _r1n, 2); + int8x8_t _r13 = vext_s8(_r10, _r1n, 3); + int8x8_t _r14 = vext_s8(_r10, _r1n, 4); + int8x8_t _r15 = vext_s8(_r10, _r1n, 5); + int8x8_t _r16 = vext_s8(_r10, _r1n, 6); + d0 = vmlal_s8(d0, _r10, k10); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r11, k11); + d0 = vmlal_s8(d0, _r12, k12); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r13, k13); + d0 = vmlal_s8(d0, _r14, k14); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r15, k15); + d0 = vmlal_s8(d0, _r16, k16); + ACC_S16_S32(sum00, sum01, d0); + + int8x8_t _r20 = vld1_s8(sptr + 2 * IW); + int8x8_t _r2n = vld1_s8(sptr + 2 * IW + 8); + int8x8_t _r21 = vext_s8(_r20, _r2n, 1); + int8x8_t _r22 = vext_s8(_r20, _r2n, 2); + int8x8_t _r23 = vext_s8(_r20, _r2n, 3); + int8x8_t _r24 = vext_s8(_r20, _r2n, 4); + int8x8_t _r25 = vext_s8(_r20, _r2n, 5); + int8x8_t _r26 = vext_s8(_r20, _r2n, 6); + d0 = vmull_s8(_r20, k20); + d0 = vmlal_s8(d0, _r21, k21); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r22, k22); + d0 = vmlal_s8(d0, _r23, k23); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r24, k24); + d0 = vmlal_s8(d0, _r25, k25); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r26, k26); + + int8x8_t _r30 = vld1_s8(sptr + 3 * IW); + int8x8_t _r3n = vld1_s8(sptr + 3 * IW + 8); + int8x8_t _r31 = vext_s8(_r30, _r3n, 1); + int8x8_t _r32 = vext_s8(_r30, _r3n, 2); + int8x8_t _r33 = vext_s8(_r30, _r3n, 3); + int8x8_t _r34 = vext_s8(_r30, _r3n, 4); + int8x8_t _r35 = vext_s8(_r30, _r3n, 
5); + int8x8_t _r36 = vext_s8(_r30, _r3n, 6); + d0 = vmlal_s8(d0, _r30, k30); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r31, k31); + d0 = vmlal_s8(d0, _r32, k32); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r33, k33); + d0 = vmlal_s8(d0, _r34, k34); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r35, k35); + d0 = vmlal_s8(d0, _r36, k36); + ACC_S16_S32(sum00, sum01, d0); + + int8x8_t _r40 = vld1_s8(sptr + 4 * IW); + int8x8_t _r4n = vld1_s8(sptr + 4 * IW + 8); + int8x8_t _r41 = vext_s8(_r40, _r4n, 1); + int8x8_t _r42 = vext_s8(_r40, _r4n, 2); + int8x8_t _r43 = vext_s8(_r40, _r4n, 3); + int8x8_t _r44 = vext_s8(_r40, _r4n, 4); + int8x8_t _r45 = vext_s8(_r40, _r4n, 5); + int8x8_t _r46 = vext_s8(_r40, _r4n, 6); + d0 = vmull_s8(_r40, k40); + d0 = vmlal_s8(d0, _r41, k41); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r42, k42); + d0 = vmlal_s8(d0, _r43, k43); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r44, k44); + d0 = vmlal_s8(d0, _r45, k45); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r46, k46); + + int8x8_t _r50 = vld1_s8(sptr + 5 * IW); + int8x8_t _r5n = vld1_s8(sptr + 5 * IW + 8); + int8x8_t _r51 = vext_s8(_r50, _r5n, 1); + int8x8_t _r52 = vext_s8(_r50, _r5n, 2); + int8x8_t _r53 = vext_s8(_r50, _r5n, 3); + int8x8_t _r54 = vext_s8(_r50, _r5n, 4); + int8x8_t _r55 = vext_s8(_r50, _r5n, 5); + int8x8_t _r56 = vext_s8(_r50, _r5n, 6); + d0 = vmlal_s8(d0, _r50, k50); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r51, k51); + d0 = vmlal_s8(d0, _r52, k52); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r53, k53); + d0 = vmlal_s8(d0, _r54, k54); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r55, k55); + d0 = vmlal_s8(d0, _r56, k56); + ACC_S16_S32(sum00, sum01, d0); + + int8x8_t _r60 = vld1_s8(sptr + 6 * IW); + int8x8_t _r6n = vld1_s8(sptr + 6 * IW + 8); + int8x8_t _r61 = vext_s8(_r60, _r6n, 1); + int8x8_t _r62 = vext_s8(_r60, _r6n, 2); + int8x8_t _r63 = vext_s8(_r60, _r6n, 3); + int8x8_t _r64 = vext_s8(_r60, _r6n, 4); + int8x8_t _r65 = vext_s8(_r60, _r6n, 5); + int8x8_t _r66 = vext_s8(_r60, _r6n, 6); + d0 = vmull_s8(_r60, k60); + d0 = vmlal_s8(d0, _r61, k61); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r62, k62); + d0 = vmlal_s8(d0, _r63, k63); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r64, k64); + d0 = vmlal_s8(d0, _r65, k65); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r66, k66); + ACC_S16_S32(sum00, sum01, d0); + POSTPROCESS(sum00, sum01, tptr + 0 * OW, dptr + 0 * OW); + } + } +} + +template +void conv_bias::conv_direct_stride2_2x2_int8_nchw(const int8_t* src, + const int8_t* filter, + const int32_t* bias, int32_t* temp, + int8_t* dst, const size_t IH, + const size_t IW, const size_t OH, + const size_t OW, const Op& op) { + MEGDNN_MARK_USED_VAR(IH); +#define GET_R2(sptr) \ + _r00 = vld1_s8(sptr); \ + _r00 = vtbl1_s8(_r00, _idx); \ + _r01 = vld1_s8(sptr + 8); \ + _r01 = vtbl1_s8(_r01, _idx); \ + _rn = vzip_s32(vreinterpret_s32_s8(_r00), vreinterpret_s32_s8(_r01)); \ + _r00 = vreinterpret_s8_s32(_rn.val[0]); \ + _r01 = vreinterpret_s8_s32(_rn.val[1]); + + int8x8_t k00 = vdup_n_s8(filter[0]); + int8x8_t k01 = vdup_n_s8(filter[1]); + int8x8_t k10 = vdup_n_s8(filter[2]); + int8x8_t k11 = vdup_n_s8(filter[3]); + + int8x8_t _idx = {0, 2, 4, 6, 1, 3, 5, 7}; + size_t oh = 0; + for (; oh < OH; ++oh) { + size_t ih = oh * 2; + for (size_t ow = 0; ow < OW; ow += 8) { + size_t iw = ow * 2; + int32_t* __restrict tptr = temp + oh * OW + ow; + int8_t* __restrict dptr = dst + oh * OW + ow; + const int8_t* __restrict sptr = src + ih * IW + iw; + const 
int32_t* __restrict bptr = bias; + int32x4_t sum00, sum01; + int16x8_t d0; + int32x2x2_t _rn; + int8x8_t _r00, _r01; + + if (!first_ic) { + sum00 = vld1q_s32(tptr + 0 * OW); + sum01 = vld1q_s32(tptr + 0 * OW + 4); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + sum00 = vdupq_n_s32(bptr[0]); + sum01 = sum00; + } else { + sum00 = vdupq_n_s32(0); + sum01 = vdupq_n_s32(0); + } + } + + GET_R2(sptr); + d0 = vmull_s8(_r00, k00); + d0 = vmlal_s8(d0, _r01, k01); + ACC_S16_S32(sum00, sum01, d0); + + GET_R2(sptr + IW); + d0 = vmull_s8(_r00, k10); + d0 = vmlal_s8(d0, _r01, k11); + ACC_S16_S32(sum00, sum01, d0); + + POSTPROCESS(sum00, sum01, tptr + 0 * OW, dptr + 0 * OW); + } + } +#undef GET_R2 +} + +template +void conv_bias::conv_direct_stride2_3x3_int8_nchw(const int8_t* src, + const int8_t* filter, + const int32_t* bias, int32_t* temp, + int8_t* dst, const size_t IH, + const size_t IW, const size_t OH, + const size_t OW, const Op& op) { + MEGDNN_MARK_USED_VAR(IH); +#define GET_R3(sptr) \ + _r00 = vld1_s8(sptr); \ + _r00 = vtbl1_s8(_r00, _idx); \ + _r01 = vld1_s8(sptr + 8); \ + _r01 = vtbl1_s8(_r01, _idx); \ + _rn = vzip_s32(vreinterpret_s32_s8(_r00), vreinterpret_s32_s8(_r01)); \ + _r00 = vreinterpret_s8_s32(_rn.val[0]); \ + _r01 = vreinterpret_s8_s32(_rn.val[1]); \ + _r02 = vld1_s8(sptr + 16); \ + _r02 = vext_s8(_r00, _r02, 1); + + int8x8_t k00 = vdup_n_s8(filter[0]); + int8x8_t k01 = vdup_n_s8(filter[1]); + int8x8_t k02 = vdup_n_s8(filter[2]); + int8x8_t k10 = vdup_n_s8(filter[3]); + int8x8_t k11 = vdup_n_s8(filter[4]); + int8x8_t k12 = vdup_n_s8(filter[5]); + int8x8_t k20 = vdup_n_s8(filter[6]); + int8x8_t k21 = vdup_n_s8(filter[7]); + int8x8_t k22 = vdup_n_s8(filter[8]); + + int8x8_t _idx = {0, 2, 4, 6, 1, 3, 5, 7}; + + // block 2x8 + size_t oh = 0; + for (; oh + 1 < OH; oh += 2) { + size_t ih = oh * 2; + for (size_t ow = 0; ow < OW; ow += 8) { + size_t iw = ow * 2; + int32_t* __restrict tptr = temp + oh * OW + ow; + int8_t* __restrict dptr = dst + oh * OW + ow; + const int8_t* __restrict sptr = src + ih * IW + iw; + const int32_t* __restrict bptr = bias; + int32x4_t sum00, sum01, sum10, sum11; + int16x8_t d0, d1; + int32x2x2_t _rn; + int8x8_t _r00, _r01, _r02; + + if (!first_ic) { + sum00 = vld1q_s32(tptr + 0 * OW); + sum01 = vld1q_s32(tptr + 0 * OW + 4); + sum10 = vld1q_s32(tptr + 1 * OW); + sum11 = vld1q_s32(tptr + 1 * OW + 4); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + sum00 = vdupq_n_s32(bptr[0]); + sum01 = sum00; + sum10 = sum00; + sum11 = sum00; + } else { + sum00 = vdupq_n_s32(0); + sum01 = vdupq_n_s32(0); + sum10 = vdupq_n_s32(0); + sum11 = vdupq_n_s32(0); + } + } + GET_R3(sptr); + d0 = vmull_s8(_r00, k00); + d0 = vmlal_s8(d0, _r01, k01); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r02, k02); + + GET_R3(sptr + IW); + d0 = vmlal_s8(d0, _r00, k10); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r01, k11); + d0 = vmlal_s8(d0, _r02, k12); + ACC_S16_S32(sum00, sum01, d0); + + GET_R3(sptr + 2 * IW); + d0 = vmull_s8(_r00, k20); + d0 = vmlal_s8(d0, _r01, k21); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r02, k22); + ACC_S16_S32(sum00, sum01, d0); + POSTPROCESS(sum00, sum01, tptr + 0 * OW, dptr + 0 * OW); + d1 = vmull_s8(_r00, k00); + d1 = vmlal_s8(d1, _r01, k01); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r02, k02); + + GET_R3(sptr + 3 * IW); + d1 = vmlal_s8(d1, _r00, k10); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r01, k11); + d1 = vmlal_s8(d1, _r02, k12); + ACC_S16_S32(sum10, sum11, d1); + + GET_R3(sptr + 4 * IW); + d1 
= vmull_s8(_r00, k20); + d1 = vmlal_s8(d1, _r01, k21); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r02, k22); + ACC_S16_S32(sum10, sum11, d1); + POSTPROCESS(sum10, sum11, tptr + 1 * OW, dptr + 1 * OW); + } + } + + if (oh < OH) { + size_t ih = oh * 2; + for (size_t ow = 0; ow < OW; ow += 8) { + size_t iw = ow * 2; + int32_t* __restrict tptr = temp + oh * OW + ow; + int8_t* __restrict dptr = dst + oh * OW + ow; + const int8_t* __restrict sptr = src + ih * IW + iw; + const int32_t* __restrict bptr = bias; + int32x4_t sum00, sum01; + int16x8_t d0; + int32x2x2_t _rn; + int8x8_t _r00, _r01, _r02; + + if (!first_ic) { + sum00 = vld1q_s32(tptr + 0 * OW); + sum01 = vld1q_s32(tptr + 0 * OW + 4); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + sum00 = vdupq_n_s32(bptr[0]); + sum01 = sum00; + } else { + sum00 = vdupq_n_s32(0); + sum01 = vdupq_n_s32(0); + } + } + GET_R3(sptr); + d0 = vmull_s8(_r00, k00); + d0 = vmlal_s8(d0, _r01, k01); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r02, k02); + + GET_R3(sptr + IW); + d0 = vmlal_s8(d0, _r00, k10); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r01, k11); + d0 = vmlal_s8(d0, _r02, k12); + ACC_S16_S32(sum00, sum01, d0); + + GET_R3(sptr + 2 * IW); + d0 = vmull_s8(_r00, k20); + d0 = vmlal_s8(d0, _r01, k21); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r02, k22); + ACC_S16_S32(sum00, sum01, d0); + + POSTPROCESS(sum00, sum01, tptr + 0 * OW, dptr + 0 * OW); + } + } +#undef GET_R3 +} + +template +void conv_bias::conv_direct_stride2_5x5_int8_nchw(const int8_t* src, + const int8_t* filter, + const int32_t* bias, int32_t* temp, + int8_t* dst, const size_t IH, + const size_t IW, const size_t OH, + const size_t OW, const Op& op) { + MEGDNN_MARK_USED_VAR(IH); +#define GET_R5(sptr) \ + _r00 = vld1_s8(sptr); \ + _r00 = vtbl1_s8(_r00, _idx); \ + _r01 = vld1_s8(sptr + 8); \ + _r01 = vtbl1_s8(_r01, _idx); \ + _rn = vzip_s32(vreinterpret_s32_s8(_r00), vreinterpret_s32_s8(_r01)); \ + _r00 = vreinterpret_s8_s32(_rn.val[0]); \ + _r01 = vreinterpret_s8_s32(_rn.val[1]); \ + _r03 = vld1_s8(sptr + 16); \ + _r03 = vtbl1_s8(_r03, _idx); \ + _r02 = vext_s8(_r00, _r03, 1); \ + _r04 = vext_s8(_r00, _r03, 2); \ + _r03 = vtbl1_s8(_r03, _idxn); \ + _r03 = vext_s8(_r01, _r03, 1); + + int8x8_t k00 = vdup_n_s8(filter[0]); + int8x8_t k01 = vdup_n_s8(filter[1]); + int8x8_t k02 = vdup_n_s8(filter[2]); + int8x8_t k03 = vdup_n_s8(filter[3]); + int8x8_t k04 = vdup_n_s8(filter[4]); + int8x8_t k10 = vdup_n_s8(filter[5]); + int8x8_t k11 = vdup_n_s8(filter[6]); + int8x8_t k12 = vdup_n_s8(filter[7]); + int8x8_t k13 = vdup_n_s8(filter[8]); + int8x8_t k14 = vdup_n_s8(filter[9]); + int8x8_t k20 = vdup_n_s8(filter[10]); + int8x8_t k21 = vdup_n_s8(filter[11]); + int8x8_t k22 = vdup_n_s8(filter[12]); + int8x8_t k23 = vdup_n_s8(filter[13]); + int8x8_t k24 = vdup_n_s8(filter[14]); + int8x8_t k30 = vdup_n_s8(filter[15]); + int8x8_t k31 = vdup_n_s8(filter[16]); + int8x8_t k32 = vdup_n_s8(filter[17]); + int8x8_t k33 = vdup_n_s8(filter[18]); + int8x8_t k34 = vdup_n_s8(filter[19]); + int8x8_t k40 = vdup_n_s8(filter[20]); + int8x8_t k41 = vdup_n_s8(filter[21]); + int8x8_t k42 = vdup_n_s8(filter[22]); + int8x8_t k43 = vdup_n_s8(filter[23]); + int8x8_t k44 = vdup_n_s8(filter[24]); + + int8x8_t _idx = {0, 2, 4, 6, 1, 3, 5, 7}; + int8x8_t _idxn = {4, 5, 6, 7, 0, 1, 2, 3}; + + // block 2x8 + size_t oh = 0; + for (; oh + 1 < OH; oh += 2) { + size_t ih = oh * 2; + for (size_t ow = 0; ow < OW; ow += 8) { + size_t iw = ow * 2; + int32_t* __restrict tptr = temp + oh * OW + ow; + int8_t* 
__restrict dptr = dst + oh * OW + ow; + const int8_t* __restrict sptr = src + ih * IW + iw; + const int32_t* __restrict bptr = bias; + int32x4_t sum00, sum01, sum10, sum11; + int16x8_t d0, d1; + int32x2x2_t _rn; + int8x8_t _r00, _r01, _r02, _r03, _r04; + + if (!first_ic) { + sum00 = vld1q_s32(tptr + 0 * OW); + sum01 = vld1q_s32(tptr + 0 * OW + 4); + sum10 = vld1q_s32(tptr + 1 * OW); + sum11 = vld1q_s32(tptr + 1 * OW + 4); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + sum00 = vdupq_n_s32(bptr[0]); + sum01 = sum00; + sum10 = sum00; + sum11 = sum00; + } else { + sum00 = vdupq_n_s32(0); + sum01 = vdupq_n_s32(0); + sum10 = vdupq_n_s32(0); + sum11 = vdupq_n_s32(0); + } + } + GET_R5(sptr); + d0 = vmull_s8(_r00, k00); + d0 = vmlal_s8(d0, _r01, k01); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r02, k02); + d0 = vmlal_s8(d0, _r03, k03); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r04, k04); + + GET_R5(sptr + IW); + d0 = vmlal_s8(d0, _r00, k10); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r01, k11); + d0 = vmlal_s8(d0, _r02, k12); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r03, k13); + d0 = vmlal_s8(d0, _r04, k14); + ACC_S16_S32(sum00, sum01, d0); + + GET_R5(sptr + 2 * IW); + d0 = vmull_s8(_r00, k20); + d0 = vmlal_s8(d0, _r01, k21); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r02, k22); + d0 = vmlal_s8(d0, _r03, k23); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r04, k24); + d1 = vmull_s8(_r00, k00); + d1 = vmlal_s8(d1, _r01, k01); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r02, k02); + d1 = vmlal_s8(d1, _r03, k03); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r04, k04); + + GET_R5(sptr + 3 * IW); + d0 = vmlal_s8(d0, _r00, k30); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r01, k31); + d0 = vmlal_s8(d0, _r02, k32); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r03, k33); + d0 = vmlal_s8(d0, _r04, k34); + ACC_S16_S32(sum00, sum01, d0); + d1 = vmlal_s8(d1, _r00, k10); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r01, k11); + d1 = vmlal_s8(d1, _r02, k12); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r03, k13); + d1 = vmlal_s8(d1, _r04, k14); + ACC_S16_S32(sum10, sum11, d1); + + GET_R5(sptr + 4 * IW); + d0 = vmull_s8(_r00, k40); + d0 = vmlal_s8(d0, _r01, k41); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r02, k42); + d0 = vmlal_s8(d0, _r03, k43); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r04, k44); + ACC_S16_S32(sum00, sum01, d0); + POSTPROCESS(sum00, sum01, tptr + 0 * OW, dptr + 0 * OW); + d1 = vmull_s8(_r00, k20); + d1 = vmlal_s8(d1, _r01, k21); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r02, k22); + d1 = vmlal_s8(d1, _r03, k23); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r04, k24); + + GET_R5(sptr + 5 * IW); + d1 = vmlal_s8(d1, _r00, k30); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r01, k31); + d1 = vmlal_s8(d1, _r02, k32); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r03, k33); + d1 = vmlal_s8(d1, _r04, k34); + ACC_S16_S32(sum10, sum11, d1); + + GET_R5(sptr + 6 * IW); + d1 = vmull_s8(_r00, k40); + d1 = vmlal_s8(d1, _r01, k41); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r02, k42); + d1 = vmlal_s8(d1, _r03, k43); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r04, k44); + ACC_S16_S32(sum10, sum11, d1); + POSTPROCESS(sum10, sum11, tptr + 1 * OW, dptr + 1 * OW); + } + } + + if (oh < OH) { + size_t ih = oh * 2; + for (size_t ow = 0; ow < OW; ow += 8) { + size_t iw = ow * 2; + int32_t* __restrict tptr = temp + oh * OW + ow; + int8_t* __restrict dptr = dst + oh * OW + ow; 
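The GET_R2 / GET_R3 / GET_R5 helpers all build on the same stride-2 trick: for eight stride-2 outputs, filter column 0 needs the even input columns and filter column 1 the odd ones, so each 16-byte span is deinterleaved once and the stride-1 vmull/vmlal pattern can be reused unchanged. A self-contained sketch of that split (the larger macros then derive the remaining filter-column views with vext_s8 on these halves plus one extra byte from the next load):

    // Even/odd column split behind GET_R2 and friends.
    #include <arm_neon.h>
    #include <cstdint>

    static void split_even_odd(const int8_t* sptr, int8x8_t& even, int8x8_t& odd) {
        const int8x8_t idx = {0, 2, 4, 6, 1, 3, 5, 7};
        int8x8_t lo = vtbl1_s8(vld1_s8(sptr), idx);      // {s0,s2,s4,s6, s1,s3,s5,s7}
        int8x8_t hi = vtbl1_s8(vld1_s8(sptr + 8), idx);  // {s8,...,s14, s9,...,s15}
        int32x2x2_t z = vzip_s32(vreinterpret_s32_s8(lo), vreinterpret_s32_s8(hi));
        even = vreinterpret_s8_s32(z.val[0]);            // columns 0,2,...,14
        odd  = vreinterpret_s8_s32(z.val[1]);            // columns 1,3,...,15
    }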
+ const int8_t* __restrict sptr = src + ih * IW + iw; + const int32_t* __restrict bptr = bias; + int32x4_t sum00, sum01; + int16x8_t d0; + int32x2x2_t _rn; + int8x8_t _r00, _r01, _r02, _r03, _r04; + + if (!first_ic) { + sum00 = vld1q_s32(tptr + 0 * OW); + sum01 = vld1q_s32(tptr + 0 * OW + 4); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + sum00 = vdupq_n_s32(bptr[0]); + sum01 = sum00; + } else { + sum00 = vdupq_n_s32(0); + sum01 = vdupq_n_s32(0); + } + } + GET_R5(sptr); + d0 = vmull_s8(_r00, k00); + d0 = vmlal_s8(d0, _r01, k01); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r02, k02); + d0 = vmlal_s8(d0, _r03, k03); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r04, k04); + + GET_R5(sptr + IW); + d0 = vmlal_s8(d0, _r00, k10); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r01, k11); + d0 = vmlal_s8(d0, _r02, k12); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r03, k13); + d0 = vmlal_s8(d0, _r04, k14); + ACC_S16_S32(sum00, sum01, d0); + + GET_R5(sptr + 2 * IW); + d0 = vmull_s8(_r00, k20); + d0 = vmlal_s8(d0, _r01, k21); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r02, k22); + d0 = vmlal_s8(d0, _r03, k23); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r04, k24); + + GET_R5(sptr + 3 * IW); + d0 = vmlal_s8(d0, _r00, k30); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r01, k31); + d0 = vmlal_s8(d0, _r02, k32); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r03, k33); + d0 = vmlal_s8(d0, _r04, k34); + ACC_S16_S32(sum00, sum01, d0); + + GET_R5(sptr + 4 * IW); + d0 = vmull_s8(_r00, k40); + d0 = vmlal_s8(d0, _r01, k41); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r02, k42); + d0 = vmlal_s8(d0, _r03, k43); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r04, k44); + ACC_S16_S32(sum00, sum01, d0); + + POSTPROCESS(sum00, sum01, tptr + 0 * OW, dptr + 0 * OW); + } + } +#undef GET_R5 +} + +template +void conv_bias::conv_direct_stride2_7x7_int8_nchw(const int8_t* src, + const int8_t* filter, + const int32_t* bias, int32_t* temp, + int8_t* dst, const size_t IH, + const size_t IW, const size_t OH, + const size_t OW, const Op& op) { + MEGDNN_MARK_USED_VAR(IH); +#define GET_R7(sptr) \ + _r00 = vld1_s8(sptr); \ + _r00 = vtbl1_s8(_r00, _idx); \ + _r01 = vld1_s8(sptr + 8); \ + _r01 = vtbl1_s8(_r01, _idx); \ + _rn = vzip_s32(vreinterpret_s32_s8(_r00), vreinterpret_s32_s8(_r01)); \ + _r00 = vreinterpret_s8_s32(_rn.val[0]); \ + _r01 = vreinterpret_s8_s32(_rn.val[1]); \ + _r05 = vld1_s8(sptr + 16); \ + _r05 = vtbl1_s8(_r05, _idx); \ + _r02 = vext_s8(_r00, _r05, 1); \ + _r04 = vext_s8(_r00, _r05, 2); \ + _r06 = vext_s8(_r00, _r05, 3); \ + _r05 = vtbl1_s8(_r05, _idxn); \ + _r03 = vext_s8(_r01, _r05, 1); \ + _r05 = vext_s8(_r01, _r05, 2); + + int8x8_t k00 = vdup_n_s8(filter[0]); + int8x8_t k01 = vdup_n_s8(filter[1]); + int8x8_t k02 = vdup_n_s8(filter[2]); + int8x8_t k03 = vdup_n_s8(filter[3]); + int8x8_t k04 = vdup_n_s8(filter[4]); + int8x8_t k05 = vdup_n_s8(filter[5]); + int8x8_t k06 = vdup_n_s8(filter[6]); + + int8x8_t k10 = vdup_n_s8(filter[7]); + int8x8_t k11 = vdup_n_s8(filter[8]); + int8x8_t k12 = vdup_n_s8(filter[9]); + int8x8_t k13 = vdup_n_s8(filter[10]); + int8x8_t k14 = vdup_n_s8(filter[11]); + int8x8_t k15 = vdup_n_s8(filter[12]); + int8x8_t k16 = vdup_n_s8(filter[13]); + + int8x8_t k20 = vdup_n_s8(filter[14]); + int8x8_t k21 = vdup_n_s8(filter[15]); + int8x8_t k22 = vdup_n_s8(filter[16]); + int8x8_t k23 = vdup_n_s8(filter[17]); + int8x8_t k24 = vdup_n_s8(filter[18]); + int8x8_t k25 = vdup_n_s8(filter[19]); + int8x8_t k26 = vdup_n_s8(filter[20]); 
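+ //! GET_R7 deinterleaves 24 input bytes into the seven stride-2 column vectors _r00.._r06; every filter tap is broadcast into an int8x8_t so vmull_s8/vmlal_s8 form eight products at a time, which ACC_S16_S32 accumulates into the two int32x4 sums.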
+ + int8x8_t k30 = vdup_n_s8(filter[21]); + int8x8_t k31 = vdup_n_s8(filter[22]); + int8x8_t k32 = vdup_n_s8(filter[23]); + int8x8_t k33 = vdup_n_s8(filter[24]); + int8x8_t k34 = vdup_n_s8(filter[25]); + int8x8_t k35 = vdup_n_s8(filter[26]); + int8x8_t k36 = vdup_n_s8(filter[27]); + + int8x8_t k40 = vdup_n_s8(filter[28]); + int8x8_t k41 = vdup_n_s8(filter[29]); + int8x8_t k42 = vdup_n_s8(filter[30]); + int8x8_t k43 = vdup_n_s8(filter[31]); + int8x8_t k44 = vdup_n_s8(filter[32]); + int8x8_t k45 = vdup_n_s8(filter[33]); + int8x8_t k46 = vdup_n_s8(filter[34]); + + int8x8_t k50 = vdup_n_s8(filter[35]); + int8x8_t k51 = vdup_n_s8(filter[36]); + int8x8_t k52 = vdup_n_s8(filter[37]); + int8x8_t k53 = vdup_n_s8(filter[38]); + int8x8_t k54 = vdup_n_s8(filter[39]); + int8x8_t k55 = vdup_n_s8(filter[40]); + int8x8_t k56 = vdup_n_s8(filter[41]); + + int8x8_t k60 = vdup_n_s8(filter[42]); + int8x8_t k61 = vdup_n_s8(filter[43]); + int8x8_t k62 = vdup_n_s8(filter[44]); + int8x8_t k63 = vdup_n_s8(filter[45]); + int8x8_t k64 = vdup_n_s8(filter[46]); + int8x8_t k65 = vdup_n_s8(filter[47]); + int8x8_t k66 = vdup_n_s8(filter[48]); + + int8x8_t _idx = {0, 2, 4, 6, 1, 3, 5, 7}; + int8x8_t _idxn = {4, 5, 6, 7, 0, 1, 2, 3}; + + // block 2x8 + size_t oh = 0; + for (; oh + 1 < OH; oh += 2) { + size_t ih = oh * 2; + for (size_t ow = 0; ow < OW; ow += 8) { + size_t iw = ow * 2; + int32_t* __restrict tptr = temp + oh * OW + ow; + int8_t* __restrict dptr = dst + oh * OW + ow; + const int8_t* __restrict sptr = src + ih * IW + iw; + const int32_t* __restrict bptr = bias; + int32x4_t sum00, sum01, sum10, sum11; + int16x8_t d0, d1; + int32x2x2_t _rn; + int8x8_t _r00, _r01, _r02, _r03, _r04, _r05, _r06; + + if (!first_ic) { + sum00 = vld1q_s32(tptr + 0 * OW); + sum01 = vld1q_s32(tptr + 0 * OW + 4); + sum10 = vld1q_s32(tptr + 1 * OW); + sum11 = vld1q_s32(tptr + 1 * OW + 4); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + sum00 = vdupq_n_s32(bptr[0]); + sum01 = sum00; + sum10 = sum00; + sum11 = sum00; + } else { + sum00 = vdupq_n_s32(0); + sum01 = vdupq_n_s32(0); + sum10 = vdupq_n_s32(0); + sum11 = vdupq_n_s32(0); + } + } + GET_R7(sptr); + d0 = vmull_s8(_r00, k00); + d0 = vmlal_s8(d0, _r01, k01); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r02, k02); + d0 = vmlal_s8(d0, _r03, k03); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r04, k04); + d0 = vmlal_s8(d0, _r05, k05); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r06, k06); + + GET_R7(sptr + IW); + d0 = vmlal_s8(d0, _r00, k10); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r01, k11); + d0 = vmlal_s8(d0, _r02, k12); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r03, k13); + d0 = vmlal_s8(d0, _r04, k14); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r05, k15); + d0 = vmlal_s8(d0, _r06, k16); + ACC_S16_S32(sum00, sum01, d0); + + GET_R7(sptr + 2 * IW); + d0 = vmull_s8(_r00, k20); + d0 = vmlal_s8(d0, _r01, k21); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r02, k22); + d0 = vmlal_s8(d0, _r03, k23); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r04, k24); + d0 = vmlal_s8(d0, _r05, k25); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r06, k26); + d1 = vmull_s8(_r00, k00); + d1 = vmlal_s8(d1, _r01, k01); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r02, k02); + d1 = vmlal_s8(d1, _r03, k03); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r04, k04); + d1 = vmlal_s8(d1, _r05, k05); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r06, k06); + + GET_R7(sptr + 3 * IW); + d0 = vmlal_s8(d0, _r00, k30); + ACC_S16_S32(sum00, 
sum01, d0); + d0 = vmull_s8(_r01, k31); + d0 = vmlal_s8(d0, _r02, k32); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r03, k33); + d0 = vmlal_s8(d0, _r04, k34); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r05, k35); + d0 = vmlal_s8(d0, _r06, k36); + ACC_S16_S32(sum00, sum01, d0); + d1 = vmlal_s8(d1, _r00, k10); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r01, k11); + d1 = vmlal_s8(d1, _r02, k12); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r03, k13); + d1 = vmlal_s8(d1, _r04, k14); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r05, k15); + d1 = vmlal_s8(d1, _r06, k16); + ACC_S16_S32(sum10, sum11, d1); + + GET_R7(sptr + 4 * IW); + d0 = vmull_s8(_r00, k40); + d0 = vmlal_s8(d0, _r01, k41); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r02, k42); + d0 = vmlal_s8(d0, _r03, k43); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r04, k44); + d0 = vmlal_s8(d0, _r05, k45); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r06, k46); + d1 = vmull_s8(_r00, k20); + d1 = vmlal_s8(d1, _r01, k21); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r02, k22); + d1 = vmlal_s8(d1, _r03, k23); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r04, k24); + d1 = vmlal_s8(d1, _r05, k25); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r06, k26); + + GET_R7(sptr + 5 * IW); + d0 = vmlal_s8(d0, _r00, k50); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r01, k51); + d0 = vmlal_s8(d0, _r02, k52); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r03, k53); + d0 = vmlal_s8(d0, _r04, k54); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r05, k55); + d0 = vmlal_s8(d0, _r06, k56); + ACC_S16_S32(sum00, sum01, d0); + d1 = vmlal_s8(d1, _r00, k30); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r01, k31); + d1 = vmlal_s8(d1, _r02, k32); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r03, k33); + d1 = vmlal_s8(d1, _r04, k34); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r05, k35); + d1 = vmlal_s8(d1, _r06, k36); + ACC_S16_S32(sum10, sum11, d1); + + GET_R7(sptr + 6 * IW); + d0 = vmull_s8(_r00, k60); + d0 = vmlal_s8(d0, _r01, k61); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r02, k62); + d0 = vmlal_s8(d0, _r03, k63); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r04, k64); + d0 = vmlal_s8(d0, _r05, k65); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r06, k66); + ACC_S16_S32(sum00, sum01, d0); + POSTPROCESS(sum00, sum01, tptr + 0 * OW, dptr + 0 * OW); + d1 = vmull_s8(_r00, k40); + d1 = vmlal_s8(d1, _r01, k41); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r02, k42); + d1 = vmlal_s8(d1, _r03, k43); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r04, k44); + d1 = vmlal_s8(d1, _r05, k45); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r06, k46); + + GET_R7(sptr + 7 * IW); + d1 = vmlal_s8(d1, _r00, k50); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r01, k51); + d1 = vmlal_s8(d1, _r02, k52); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r03, k53); + d1 = vmlal_s8(d1, _r04, k54); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r05, k55); + d1 = vmlal_s8(d1, _r06, k56); + ACC_S16_S32(sum10, sum11, d1); + + GET_R7(sptr + 8 * IW); + d1 = vmull_s8(_r00, k60); + d1 = vmlal_s8(d1, _r01, k61); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r02, k62); + d1 = vmlal_s8(d1, _r03, k63); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r04, k64); + d1 = vmlal_s8(d1, _r05, k65); + ACC_S16_S32(sum10, sum11, d1); + d1 = vmull_s8(_r06, k66); + ACC_S16_S32(sum10, sum11, d1); + POSTPROCESS(sum10, sum11, tptr + 1 * OW, dptr + 1 * OW); + } + } + + if (oh < OH) { + 
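+ //! when OH is odd the 2-row block loop above leaves one output row, handled here with a single accumulator pair.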
size_t ih = oh * 2; + for (size_t ow = 0; ow < OW; ow += 8) { + size_t iw = ow * 2; + int32_t* __restrict tptr = temp + oh * OW + ow; + int8_t* __restrict dptr = dst + oh * OW + ow; + const int8_t* __restrict sptr = src + ih * IW + iw; + const int32_t* __restrict bptr = bias; + int32x4_t sum00, sum01; + int16x8_t d0; + int32x2x2_t _rn; + int8x8_t _r00, _r01, _r02, _r03, _r04, _r05, _r06; + + if (!first_ic) { + sum00 = vld1q_s32(tptr + 0 * OW); + sum01 = vld1q_s32(tptr + 0 * OW + 4); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + sum00 = vdupq_n_s32(bptr[0]); + sum01 = sum00; + } else { + sum00 = vdupq_n_s32(0); + sum01 = vdupq_n_s32(0); + } + } + + GET_R7(sptr); + d0 = vmull_s8(_r00, k00); + d0 = vmlal_s8(d0, _r01, k01); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r02, k02); + d0 = vmlal_s8(d0, _r03, k03); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r04, k04); + d0 = vmlal_s8(d0, _r05, k05); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r06, k06); + + GET_R7(sptr + IW); + d0 = vmlal_s8(d0, _r00, k10); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r01, k11); + d0 = vmlal_s8(d0, _r02, k12); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r03, k13); + d0 = vmlal_s8(d0, _r04, k14); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r05, k15); + d0 = vmlal_s8(d0, _r06, k16); + ACC_S16_S32(sum00, sum01, d0); + + GET_R7(sptr + 2 * IW); + d0 = vmull_s8(_r00, k20); + d0 = vmlal_s8(d0, _r01, k21); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r02, k22); + d0 = vmlal_s8(d0, _r03, k23); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r04, k24); + d0 = vmlal_s8(d0, _r05, k25); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r06, k26); + + GET_R7(sptr + 3 * IW); + d0 = vmlal_s8(d0, _r00, k30); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r01, k31); + d0 = vmlal_s8(d0, _r02, k32); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r03, k33); + d0 = vmlal_s8(d0, _r04, k34); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r05, k35); + d0 = vmlal_s8(d0, _r06, k36); + ACC_S16_S32(sum00, sum01, d0); + + GET_R7(sptr + 4 * IW); + d0 = vmull_s8(_r00, k40); + d0 = vmlal_s8(d0, _r01, k41); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r02, k42); + d0 = vmlal_s8(d0, _r03, k43); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r04, k44); + d0 = vmlal_s8(d0, _r05, k45); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r06, k46); + + GET_R7(sptr + 5 * IW); + d0 = vmlal_s8(d0, _r00, k50); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r01, k51); + d0 = vmlal_s8(d0, _r02, k52); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r03, k53); + d0 = vmlal_s8(d0, _r04, k54); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r05, k55); + d0 = vmlal_s8(d0, _r06, k56); + ACC_S16_S32(sum00, sum01, d0); + + GET_R7(sptr + 6 * IW); + d0 = vmull_s8(_r00, k60); + d0 = vmlal_s8(d0, _r01, k61); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r02, k62); + d0 = vmlal_s8(d0, _r03, k63); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r04, k64); + d0 = vmlal_s8(d0, _r05, k65); + ACC_S16_S32(sum00, sum01, d0); + d0 = vmull_s8(_r06, k66); + ACC_S16_S32(sum00, sum01, d0); + + POSTPROCESS(sum00, sum01, tptr + 0 * OW, dptr + 0 * OW); + } + } +#undef GET_R7 +} + +#undef POSTPROCESS +#undef ACC_S16_S32 + +#define INSTANTIATION(stride, i, first_ic, last_ic, bias, Op) \ + template void conv_bias::conv_direct_##stride##_##i##x##i##_int8_nchw< \ + first_ic, last_ic, bias, Op>( \ + const int8_t*, const int8_t*, const int32_t*, int32_t*, int8_t*, \ + const size_t, const size_t, const 
size_t, const size_t, \ + const Op&); + +#define FOR_OP(stride, i, first_ic, last_ic, bias) \ + INSTANTIATION(stride, i, first_ic, last_ic, bias, \ + TypeCvtOp) \ + INSTANTIATION(stride, i, first_ic, last_ic, bias, \ + ReluOp) \ + INSTANTIATION(stride, i, first_ic, last_ic, bias, \ + HSwishOp) + +#define FOR_BIAS(stride, i, first_ic, last_ic) \ + FOR_OP(stride, i, first_ic, last_ic, BiasMode::NO_BIAS) \ + FOR_OP(stride, i, first_ic, last_ic, BiasMode::BROADCAST_CHANNEL_BIAS) + +#define FOR_IC(stride, i) \ + FOR_BIAS(stride, i, true, true) \ + FOR_BIAS(stride, i, true, false) \ + FOR_BIAS(stride, i, false, false) \ + FOR_BIAS(stride, i, false, true) + +#define FOR_FILTER(stride) \ + FOR_IC(stride, 2) \ + FOR_IC(stride, 3) \ + FOR_IC(stride, 5) \ + FOR_IC(stride, 7) + +#define FOR_STRIDE \ + FOR_FILTER(stride1) \ + FOR_FILTER(stride2) + +FOR_STRIDE + +#undef FOR_STRIDE +#undef FOR_FILTER +#undef FOR_IC +#undef FOR_BIAS +#undef FOR_NONLINEAR +#undef INSTANTIATION + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8/direct.h b/dnn/src/arm_common/conv_bias/int8/direct.h new file mode 100644 index 00000000..1b0589e7 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/direct.h @@ -0,0 +1,63 @@ +/** + * \file dnn/src/arm_common/conv_bias/int8/direct.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "src/arm_common/conv_bias/opr_impl.h" +#include "src/fallback/conv_bias/common.h" + +namespace megdnn { +namespace arm_common { +namespace conv_bias { + +#define KERN(stride, i, layout) \ + template \ + void conv_direct_##stride##_##i##x##i##_int8_##layout( \ + const int8_t* src, const int8_t* filter, const int32_t* bias, \ + int32_t* temp, int8_t* dst, const size_t IH, const size_t IW, \ + const size_t OH, const size_t OW, const Op& op); + +KERN(stride1, 2, nchw) +KERN(stride1, 3, nchw) +KERN(stride1, 5, nchw) +KERN(stride1, 7, nchw) + +KERN(stride2, 2, nchw) +KERN(stride2, 3, nchw) +KERN(stride2, 5, nchw) +KERN(stride2, 7, nchw) + +#undef KERN + +#define KERN(stride, i, layout) \ + template \ + void conv_direct_##stride##_##i##x##i##_int8_##layout( \ + const int8_t* src, const int8_t* filter, const int32_t* bias, \ + int32_t* temp, int8_t* dst, const size_t OC, const size_t IC, \ + const size_t IH, const size_t IW, const size_t OH, \ + const size_t OW, const Op& op); +KERN(stride1, 2, nchw44) +KERN(stride1, 3, nchw44) +KERN(stride1, 5, nchw44) +KERN(stride1, 7, nchw44) + +KERN(stride2, 2, nchw44) +KERN(stride2, 3, nchw44) +KERN(stride2, 5, nchw44) +KERN(stride2, 7, nchw44) +#undef KERN +void nchw44_pack_filter(const int8_t* src, int8_t* dst, int filter); +void nchw44_pack_src(const int8_t* src, int8_t* dst, int length); + +} // namespace conv_bias +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8/direct_dotprod.cpp b/dnn/src/arm_common/conv_bias/int8/direct_dotprod.cpp new file mode 100644 index 00000000..fedc920c --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/direct_dotprod.cpp @@ -0,0 +1,2171 @@ +/** + * \file dnn/src/arm_common/conv_bias/int8/direct_dotprod.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright 
(c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#if __ARM_FEATURE_DOTPROD +#include "src/arm_common/conv_bias/int8/direct_dotprod.h" +#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/utils.h" +#include "src/fallback/conv_bias/common.h" + +using namespace megdnn; +using namespace arm_common; +using megdnn::arm_common::ReluOp; +using megdnn::arm_common::TypeCvtOp; + +inline int8x16_t vqtbl1q_s8_v7(int8x16_t a, uint8x16_t index) { + int8x8x2_t src; + src.val[0] = vget_low_s8(a); + src.val[1] = vget_high_s8(a); + uint8x8_t index_low = vget_low_u8(index); + uint8x8_t index_high = vget_high_u8(index); + int8x8_t r00 = vtbl2_s8(src, vreinterpret_s8_u8(index_low)); + int8x8_t r01 = vtbl2_s8(src, vreinterpret_s8_u8(index_high)); + int8x16_t r = vcombine_s8(r00, r01); + return r; +} + +#define ST1_S32X4(dst0, tptr) vst1q_s32(tptr, dst0); + +#define ST2_S32X4X2(dst0, tptr) vst2q_s32(tptr, dst0); + +#define POSTPROCESS_1X8(dst0, dst1, tptr, dptr) \ + if (last_ic) { \ + op({{dst0, dst1}}, reinterpret_cast(dptr)); \ + } else { \ + ST1_S32X4(dst0, tptr); \ + ST1_S32X4(dst1, tptr + 4); \ + } + +#define POSTPROCESS2_1X8(dst0, tptr, dptr) \ + if (last_ic) { \ + int32x4x2_t temp; \ + int32x4_t temp00, temp11; \ + temp = vzipq_s32(dst0.val[0], dst0.val[1]); \ + temp00 = temp.val[0]; \ + temp11 = temp.val[1]; \ + op({{temp00, temp11}}, reinterpret_cast(dptr)); \ + } else { \ + ST2_S32X4X2(dst0, tptr); \ + } + +#define POSTPROCESS_2X4(dst0, dst1, tptr1, tptr2, dptr1, dptr2) \ + if (last_ic) { \ + int32x2_t res = reinterpret_cast(op({{dst0, dst1}})); \ + vst1_lane_s32(reinterpret_cast(dptr1), res, 0); \ + vst1_lane_s32(reinterpret_cast(dptr2), res, 1); \ + } else { \ + ST1_S32X4(dst0, tptr1); \ + ST1_S32X4(dst1, tptr2); \ + } + +#define POSTPROCESS_1X4(dst0, tptr, dptr) \ + if (last_ic) { \ + int32x4_t dst1 = vdupq_n_s32(0); \ + int32x2_t res = reinterpret_cast(op({{dst0, dst1}})); \ + vst1_lane_s32(reinterpret_cast(dptr), res, 0); \ + } else { \ + ST1_S32X4(dst0, tptr); \ + } + +#define CALC_0(_k_idx, _c_idx) \ + _elem = vqtbl1q_s8_v7(_tmp, _idx##_c_idx); \ + _sum0##_c_idx = vdotq_s32(_sum0##_c_idx, _k##_k_idx, _elem); + +#define CALC_1(_k_idx, _c_idx) \ + _elem = vqtbl1q_s8_v7(_tmp, _idx##_c_idx); \ + _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k_idx, _elem); + +#define CALC_2(_k1_idx, _k2_idx, _c_idx) \ + _elem = vqtbl1q_s8_v7(_tmp, _idx##_c_idx); \ + _sum0##_c_idx = vdotq_s32(_sum0##_c_idx, _k##_k1_idx, _elem); \ + _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k2_idx, _elem); + +template +void conv_bias::conv_direct_stride1_2x2_int8_dot(const int8_t* src, + const int8_t* filter, + const int32_t* bias, int32_t* temp, + int8_t* dst, const size_t IH, + const size_t IW, const size_t OH, + const size_t OW, const Op& op) { + const size_t tail_step = IW - OW; + const uint8x16_t _idx0 = {0, 1, 16, 16, 1, 2, 16, 16, + 2, 3, 16, 16, 3, 4, 16, 16}; + const uint8x16_t _idx1 = {4, 5, 16, 16, 5, 6, 16, 16, + 6, 7, 16, 16, 7, 8, 16, 16}; + int32_t* outptr = temp; + int32_t* outptr2 = outptr + OW; + int8_t* dstptr = dst; + int8_t* dstptr2 = dstptr + OW; + const int32_t* __restrict bptr = bias; + + const int8_t* r0 = src; + const int8_t* r1 = src + IW; + const int8_t* r2 = src + 2 * IW; + + const int8_t* k0 = filter; + + int8x16_t _k = 
vreinterpretq_s8_s32( + vdupq_n_s32(*reinterpret_cast(k0))); + uint8x16_t _idx = {0, 1, 16, 16, 0, 1, 16, 16, 0, 1, 16, 16, 0, 1, 16, 16}; + int8x16_t _k1 = vqtbl1q_s8_v7(_k, _idx); + _idx = {2, 3, 16, 16, 2, 3, 16, 16, 2, 3, 16, 16, 2, 3, 16, 16}; + int8x16_t _k23 = vqtbl1q_s8_v7(_k, _idx); + + int8x16_t _tmp, _elem; + const int width = OW >> 2; + size_t h = 0; + for (; h + 1 < OH; h += 2) { + int w = 0; + for (; w + 2 < width; w += 2) { + int32x4_t _sum00, _sum01, _sum10, _sum11; + if (!first_ic) { + _sum00 = vld1q_s32(outptr); + _sum01 = vld1q_s32(outptr + 4); + _sum10 = vld1q_s32(outptr2); + _sum11 = vld1q_s32(outptr2 + 4); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + _sum00 = vdupq_n_s32(bptr[0]); + _sum01 = _sum00; + _sum10 = _sum00; + _sum11 = _sum00; + } else { + _sum00 = vdupq_n_s32(0); + _sum01 = vdupq_n_s32(0); + _sum10 = vdupq_n_s32(0); + _sum11 = vdupq_n_s32(0); + } + } + + _tmp = vld1q_s8(r0); + CALC_0(1, 0); + CALC_0(1, 1); + + _tmp = vld1q_s8(r1); + CALC_2(23, 1, 0); + CALC_2(23, 1, 1); + + _tmp = vld1q_s8(r2); + CALC_1(23, 0); + CALC_1(23, 1); + + POSTPROCESS_1X8(_sum00, _sum01, outptr, dstptr); + POSTPROCESS_1X8(_sum10, _sum11, outptr2, dstptr2); + + r0 += 8; + r1 += 8; + r2 += 8; + outptr += 8; + outptr2 += 8; + dstptr += 8; + dstptr2 += 8; + } + + for (; w < width; w++) { + int32x4_t _sum00, _sum10; + if (!first_ic) { + _sum00 = vld1q_s32(outptr); + _sum10 = vld1q_s32(outptr2); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + _sum00 = vdupq_n_s32(bptr[0]); + _sum10 = _sum00; + } else { + _sum00 = vdupq_n_s32(0); + _sum10 = vdupq_n_s32(0); + } + } + + _tmp = vtranslq_s8(vld1_s8(r0)); + CALC_0(1, 0); + + _tmp = vtranslq_s8(vld1_s8(r1)); + CALC_2(23, 1, 0); + + _tmp = vtranslq_s8(vld1_s8(r2)); + CALC_1(23, 0); + + POSTPROCESS_2X4(_sum00, _sum10, outptr, outptr2, dstptr, dstptr2); + + r0 += 4; + r1 += 4; + r2 += 4; + outptr += 4; + outptr2 += 4; + dstptr += 4; + dstptr2 += 4; + } + r0 += tail_step + IW; + r1 += tail_step + IW; + r2 += tail_step + IW; + outptr += OW; + outptr2 += OW; + dstptr += OW; + dstptr2 += OW; + } + + for (; h < OH; h++) { + int w = 0; + for (; w + 4 < width; w += 4) { + int32x4x2_t _sum0, _sum1; + if (!first_ic) { + _sum0 = vld2q_s32(outptr); + _sum1 = vld2q_s32(outptr + 8); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + _sum0.val[0] = vdupq_n_s32(bptr[0]); + _sum0.val[1] = vdupq_n_s32(bptr[0]); + _sum1 = _sum0; + } else { + _sum0.val[0] = vdupq_n_s32(0); + _sum1.val[0] = vdupq_n_s32(0); + _sum0.val[1] = vdupq_n_s32(0); + _sum1.val[1] = vdupq_n_s32(0); + } + } + + int8x16_t _r00 = vld1q_s8(r0); + //! 
here will not not read out of bound + int8x16_t _r01_ = vdupq_n_s8(r0[16]); + int8x16_t _r10 = vld1q_s8(r1); + int8x16_t _r11_ = vdupq_n_s8(r1[16]); + int8x16_t _r01 = vextq_s8(_r00, _r01_, 1); + int8x16_t _r11 = vextq_s8(_r10, _r11_, 1); + + int16x8x2_t r_20 = vzipq_s16(vreinterpretq_s16_s8(_r00), + vreinterpretq_s16_s8(_r10)); + int8x16_t _r0 = r_20.val[0]; + int8x16_t _r2 = r_20.val[1]; + + int16x8x2_t r1_21 = vzipq_s16(vreinterpretq_s16_s8(_r01), + vreinterpretq_s16_s8(_r11)); + int8x16_t _r1 = r1_21.val[0]; + int8x16_t _r3 = r1_21.val[1]; + + _sum0.val[0] = vdotq_s32(_sum0.val[0], _k, _r0); + _sum0.val[1] = vdotq_s32(_sum0.val[1], _k, _r1); + _sum1.val[0] = vdotq_s32(_sum1.val[0], _k, _r2); + _sum1.val[1] = vdotq_s32(_sum1.val[1], _k, _r3); + + POSTPROCESS2_1X8(_sum0, outptr, dstptr); + POSTPROCESS2_1X8(_sum1, outptr + 8, dstptr + 8); + + r0 += 16; + r1 += 16; + outptr += 16; + dstptr += 16; + } + for (; w + 2 < width; w += 2) { + int32x4_t _sum00, _sum01; + if (!first_ic) { + _sum00 = vld1q_s32(outptr); + _sum01 = vld1q_s32(outptr + 4); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + _sum00 = vdupq_n_s32(bptr[0]); + _sum01 = _sum00; + } else { + _sum00 = vdupq_n_s32(0); + _sum01 = vdupq_n_s32(0); + } + } + + _tmp = vld1q_s8(r0); + CALC_0(1, 0); + CALC_0(1, 1); + + _tmp = vld1q_s8(r1); + CALC_0(23, 0); + CALC_0(23, 1); + + POSTPROCESS_1X8(_sum00, _sum01, outptr, dstptr); + + r0 += 8; + r1 += 8; + outptr += 8; + dstptr += 8; + } + + for (; w < width; w++) { + int32x4_t _sum00; + if (!first_ic) { + _sum00 = vld1q_s32(outptr); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + _sum00 = vdupq_n_s32(bptr[0]); + } else { + _sum00 = vdupq_n_s32(0); + } + } + + _tmp = vtranslq_s8(vld1_s8(r0)); + CALC_0(1, 0); + + _tmp = vtranslq_s8(vld1_s8(r1)); + CALC_0(23, 0); + + POSTPROCESS_1X4(_sum00, outptr, dstptr); + + r0 += 4; + r1 += 4; + outptr += 4; + dstptr += 4; + } + r0 += tail_step; + r1 += tail_step; + } +} + +template +void conv_bias::conv_direct_stride1_3x3_int8_dot(const int8_t* src, + const int8_t* filter, + const int32_t* bias, int32_t* temp, + int8_t* dst, const size_t IH, + const size_t IW, const size_t OH, + const size_t OW, const Op& op) { + const size_t tail_step = IW - OW; + + const uint8x16_t _idx0 = {0, 1, 2, 16, 1, 2, 3, 16, + 2, 3, 4, 16, 3, 4, 5, 16}; + const uint8x16_t _idx1 = {4, 5, 6, 16, 5, 6, 7, 16, + 6, 7, 8, 16, 7, 8, 9, 16}; + const uint8x16_t _idx2 = {8, 9, 10, 16, 9, 10, 11, 16, + 10, 11, 12, 16, 11, 12, 13, 16}; + int32_t* outptr = temp; + int32_t* outptr2 = outptr + OW; + int8_t* dstptr = dst; + int8_t* dstptr2 = dstptr + OW; + const int32_t* __restrict bptr = bias; + + const int8_t* r0 = src; + const int8_t* r1 = src + IW; + const int8_t* r2 = src + IW * 2; + const int8_t* r3 = src + IW * 3; + + const int8_t* k0 = filter; + + int8x16_t _k_tmp = vcombine_s8(vld1_s8(k0), vdup_n_s8(k0[8])); + uint8x16_t _idx = {0, 1, 2, 16, 0, 1, 2, 16, 0, 1, 2, 16, 0, 1, 2, 16}; + int8x16_t _k12 = vqtbl1q_s8_v7(_k_tmp, _idx); + _idx = {3, 4, 5, 16, 3, 4, 5, 16, 3, 4, 5, 16, 3, 4, 5, 16}; + int8x16_t _k345 = vqtbl1q_s8_v7(_k_tmp, _idx); + _idx = {6, 7, 8, 16, 6, 7, 8, 16, 6, 7, 8, 16, 6, 7, 8, 16}; + int8x16_t _k678 = vqtbl1q_s8_v7(_k_tmp, _idx); + + int8x16_t _tmp, _elem; + size_t h = 0; + for (; h + 1 < OH; h += 2) { + int width = OW >> 2; + + int w = 0; + for (; w + 3 < width; w += 3) { + int32x4_t _sum00, _sum01, _sum02, _sum10, _sum11, _sum12; + + if (!first_ic) { + _sum00 = vld1q_s32(outptr); + _sum01 = vld1q_s32(outptr + 4); + _sum02 = 
vld1q_s32(outptr + 8); + _sum10 = vld1q_s32(outptr2); + _sum11 = vld1q_s32(outptr2 + 4); + _sum12 = vld1q_s32(outptr2 + 8); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + _sum00 = vdupq_n_s32(bptr[0]); + _sum01 = _sum00; + _sum02 = _sum00; + _sum10 = _sum00; + _sum11 = _sum00; + _sum12 = _sum00; + } else { + _sum00 = vdupq_n_s32(0); + _sum01 = vdupq_n_s32(0); + _sum02 = vdupq_n_s32(0); + _sum10 = vdupq_n_s32(0); + _sum11 = vdupq_n_s32(0); + _sum12 = vdupq_n_s32(0); + } + } + + _tmp = vld1q_s8(r0); + CALC_0(12, 0); + CALC_0(12, 1); + CALC_0(12, 2); + + _tmp = vld1q_s8(r1); + CALC_2(345, 12, 0); + CALC_2(345, 12, 1); + CALC_2(345, 12, 2); + + _tmp = vld1q_s8(r2); + CALC_2(678, 345, 0); + CALC_2(678, 345, 1); + CALC_2(678, 345, 2); + + _tmp = vld1q_s8(r3); + CALC_1(678, 0); + CALC_1(678, 1); + CALC_1(678, 2); + + POSTPROCESS_1X8(_sum00, _sum01, outptr, dstptr); + POSTPROCESS_1X4(_sum02, outptr + 8, dstptr + 8); + POSTPROCESS_1X8(_sum10, _sum11, outptr2, dstptr2); + POSTPROCESS_1X4(_sum12, outptr2 + 8, dstptr2 + 8); + + r0 += 12; + r1 += 12; + r2 += 12; + r3 += 12; + outptr += 12; + outptr2 += 12; + dstptr += 12; + dstptr2 += 12; + } + for (; w < width; w++) { + int32x4_t _sum00, _sum10; + if (!first_ic) { + _sum00 = vld1q_s32(outptr); + _sum10 = vld1q_s32(outptr2); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + _sum00 = vdupq_n_s32(bptr[0]); + _sum10 = _sum00; + } else { + _sum00 = vdupq_n_s32(0); + _sum10 = vdupq_n_s32(0); + } + } + + _tmp = vtranslq_s8(vld1_s8(r0)); + CALC_0(12, 0); + + _tmp = vtranslq_s8(vld1_s8(r1)); + CALC_2(345, 12, 0); + + _tmp = vtranslq_s8(vld1_s8(r2)); + CALC_2(678, 345, 0); + + _tmp = vtranslq_s8(vld1_s8(r3)); + CALC_1(678, 0); + + POSTPROCESS_2X4(_sum00, _sum10, outptr, outptr2, dstptr, dstptr2); + + r0 += 4; + r1 += 4; + r2 += 4; + r3 += 4; + outptr += 4; + outptr2 += 4; + dstptr += 4; + dstptr2 += 4; + } + + r0 += tail_step + IW; + r1 += tail_step + IW; + r2 += tail_step + IW; + r3 += tail_step + IW; + + outptr += OW; + outptr2 += OW; + dstptr += OW; + dstptr2 += OW; + } + + for (; h < OH; h++) { + int width = OW >> 2; + + int w = 0; + for (; w + 3 < width; w += 3) { + int32x4_t _sum00, _sum01, _sum02; + + if (!first_ic) { + _sum00 = vld1q_s32(outptr); + _sum01 = vld1q_s32(outptr + 4); + _sum02 = vld1q_s32(outptr + 8); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + _sum00 = vdupq_n_s32(bptr[0]); + _sum01 = _sum00; + _sum02 = _sum00; + } else { + _sum00 = vdupq_n_s32(0); + _sum01 = vdupq_n_s32(0); + _sum02 = vdupq_n_s32(0); + } + } + + _tmp = vld1q_s8(r0); + CALC_0(12, 0); + CALC_0(12, 1); + CALC_0(12, 2); + + _tmp = vld1q_s8(r1); + CALC_0(345, 0); + CALC_0(345, 1); + CALC_0(345, 2); + + _tmp = vld1q_s8(r2); + CALC_0(678, 0); + CALC_0(678, 1); + CALC_0(678, 2); + + POSTPROCESS_1X8(_sum00, _sum01, outptr, dstptr); + POSTPROCESS_1X4(_sum02, outptr + 8, dstptr + 8); + + r0 += 12; + r1 += 12; + r2 += 12; + outptr += 12; + dstptr += 12; + } + for (; w < width; w++) { + int32x4_t _sum00; + if (!first_ic) { + _sum00 = vld1q_s32(outptr); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + _sum00 = vdupq_n_s32(bptr[0]); + } else { + _sum00 = vdupq_n_s32(0); + } + } + + _tmp = vtranslq_s8(vld1_s8(r0)); + CALC_0(12, 0); + + _tmp = vtranslq_s8(vld1_s8(r1)); + CALC_0(345, 0); + + _tmp = vtranslq_s8(vld1_s8(r2)); + CALC_0(678, 0); + + POSTPROCESS_1X4(_sum00, outptr, dstptr); + + r0 += 4; + r1 += 4; + r2 += 4; + outptr += 4; + dstptr += 4; + } + r0 += tail_step; + r1 += tail_step; + r2 += tail_step; 
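+ //! tail_step (IW - OW for stride 1) skips the unprocessed right border so r0..r2 line up with the next input row.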
+ } +} + +template +void conv_bias::conv_direct_stride2_2x2_int8_dot(const int8_t* src, + const int8_t* filter, + const int32_t* bias, int32_t* temp, + int8_t* dst, const size_t IH, + const size_t IW, const size_t OH, + const size_t OW, const Op& op) { + const size_t tail_step = IW - 2 * OW + IW; + + const uint8x16_t _idx0 = {0, 1, 16, 16, 2, 3, 16, 16, + 4, 5, 16, 16, 6, 7, 16, 16}; + int32_t* outptr = temp; + int8_t* dstptr = dst; + + const int8_t* r0 = src; + const int8_t* r1 = src + IW; + const int8_t* k0 = filter; + const int32_t* __restrict bptr = bias; + + int8x16_t _k = vreinterpretq_s8_s32( + vdupq_n_s32(*reinterpret_cast(k0))); + uint8x16_t _idx = {0, 1, 16, 16, 0, 1, 16, 16, 0, 1, 16, 16, 0, 1, 16, 16}; + int8x16_t _k1 = vqtbl1q_s8_v7(_k, _idx); + _idx = {2, 3, 16, 16, 2, 3, 16, 16, 2, 3, 16, 16, 2, 3, 16, 16}; + int8x16_t _k23 = vqtbl1q_s8_v7(_k, _idx); + + int8x16_t _tmp, _elem; + const int width = OW >> 2; + size_t h = 0; + for (; h < OH; h++) { + int w = 0; + for (; w + 2 < width; w += 2) { + int32x4_t _sum0, _sum1; + if (!first_ic) { + _sum0 = vld1q_s32(outptr); + _sum1 = vld1q_s32(outptr + 4); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + _sum0 = vdupq_n_s32(bptr[0]); + _sum1 = _sum0; + } else { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + } + } + + int8x16_t _r00 = vld1q_s8(r0); + //! here will not not read out of bound + int8x16_t _r10 = vld1q_s8(r1); + + int16x8x2_t r_00 = vzipq_s16(vreinterpretq_s16_s8(_r00), + vreinterpretq_s16_s8(_r10)); + int8x16_t _r0 = r_00.val[0]; + int8x16_t _r1 = r_00.val[1]; + + _sum0 = vdotq_s32(_sum0, _k, _r0); + _sum1 = vdotq_s32(_sum1, _k, _r1); + + POSTPROCESS_1X8(_sum0, _sum1, outptr, dstptr); + + r0 += 16; + r1 += 16; + outptr += 8; + dstptr += 8; + } + + for (; w < width; w++) { + int32x4_t _sum00; + if (!first_ic) { + _sum00 = vld1q_s32(outptr); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + _sum00 = vdupq_n_s32(bptr[0]); + } else { + _sum00 = vdupq_n_s32(0); + } + } + + _tmp = vtranslq_s8(vld1_s8(r0)); + CALC_0(1, 0); + + _tmp = vtranslq_s8(vld1_s8(r1)); + CALC_0(23, 0); + + POSTPROCESS_1X4(_sum00, outptr, dstptr); + + r0 += 8; + r1 += 8; + outptr += 4; + dstptr += 4; + } + r0 += tail_step; + r1 += tail_step; + } +} + +template +void conv_bias::conv_direct_stride2_3x3_int8_dot(const int8_t* src, + const int8_t* filter, + const int32_t* bias, int32_t* temp, + int8_t* dst, const size_t IH, + const size_t IW, const size_t OH, + const size_t OW, const Op& op) { + const size_t tail_step = IW - 2 * OW + IW; + + const uint8x16_t _idx0 = {0, 1, 2, 16, 2, 3, 4, 16, + 4, 5, 6, 16, 6, 7, 8, 16}; + int32_t* outptr = temp; + int32_t* outptr2 = outptr + OW; + int8_t* dstptr = dst; + int8_t* dstptr2 = dstptr + OW; + const int32_t* __restrict bptr = bias; + + const int8_t* r0 = src; + const int8_t* r1 = src + IW; + const int8_t* r2 = src + IW * 2; + const int8_t* r3 = src + IW * 3; + const int8_t* r4 = src + IW * 4; + + const int8_t* k0 = filter; + + int8x16_t _k_tmp = vcombine_s8(vld1_s8(k0), vdup_n_s8(k0[8])); + uint8x16_t _idx = {0, 1, 2, 16, 0, 1, 2, 16, 0, 1, 2, 16, 0, 1, 2, 16}; + int8x16_t _k12 = vqtbl1q_s8_v7(_k_tmp, _idx); + _idx = {3, 4, 5, 16, 3, 4, 5, 16, 3, 4, 5, 16, 3, 4, 5, 16}; + int8x16_t _k345 = vqtbl1q_s8_v7(_k_tmp, _idx); + _idx = {6, 7, 8, 16, 6, 7, 8, 16, 6, 7, 8, 16, 6, 7, 8, 16}; + int8x16_t _k678 = vqtbl1q_s8_v7(_k_tmp, _idx); + + int8x16_t _tmp, _elem; + size_t h = 0; + for (; h + 1 < OH; h += 2) { + int width = OW >> 2; + + int w = 0; + for (; w < width; w++) { + 
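+ //! lane index 16 in vqtbl1q_s8_v7 is out of range and reads as zero, so each 4-byte dot product sums only the three real taps of a filter row.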
int32x4_t _sum00, _sum10; + if (!first_ic) { + _sum00 = vld1q_s32(outptr); + _sum10 = vld1q_s32(outptr2); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + _sum00 = vdupq_n_s32(bptr[0]); + _sum10 = _sum00; + } else { + _sum00 = vdupq_n_s32(0); + _sum10 = vdupq_n_s32(0); + } + } + + _tmp = vld1q_s8(r0); + CALC_0(12, 0); + + _tmp = vld1q_s8(r1); + CALC_0(345, 0); + + _tmp = vld1q_s8(r2); + CALC_2(678, 12, 0); + + _tmp = vld1q_s8(r3); + CALC_1(345, 0); + + _tmp = vld1q_s8(r4); + CALC_1(678, 0); + + POSTPROCESS_2X4(_sum00, _sum10, outptr, outptr2, dstptr, dstptr2); + + r0 += 8; + r1 += 8; + r2 += 8; + r3 += 8; + r4 += 8; + outptr += 4; + outptr2 += 4; + dstptr += 4; + dstptr2 += 4; + } + + r0 += tail_step + IW * 2; + r1 += tail_step + IW * 2; + r2 += tail_step + IW * 2; + r3 += tail_step + IW * 2; + r4 += tail_step + IW * 2; + + outptr += OW; + outptr2 += OW; + dstptr += OW; + dstptr2 += OW; + } + + for (; h < OH; h++) { + int width = OW >> 2; + + int w = 0; + for (; w < width; w++) { + int32x4_t _sum00; + if (!first_ic) { + _sum00 = vld1q_s32(outptr); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + _sum00 = vdupq_n_s32(bptr[0]); + } else { + _sum00 = vdupq_n_s32(0); + } + } + + _tmp = vld1q_s8(r0); + CALC_0(12, 0); + + _tmp = vld1q_s8(r1); + CALC_0(345, 0); + + _tmp = vld1q_s8(r2); + CALC_0(678, 0); + + POSTPROCESS_1X4(_sum00, outptr, dstptr); + + r0 += 8; + r1 += 8; + r2 += 8; + outptr += 4; + dstptr += 4; + } + r0 += tail_step; + r1 += tail_step; + r2 += tail_step; + } +} +#undef CALC_0 +#undef CALC_1 +#undef CALC_2 + +#define CALC_0(_k00_idx, _k01_idx, _c_idx) \ + _elem = vqtbl1q_s8_v7(_tmp, _idx##_c_idx##0); \ + _sum0##_c_idx = vdotq_s32(_sum0##_c_idx, _k##_k00_idx, _elem); \ + _elem = vqtbl1q_s8_v7(_tmp, _idx##_c_idx##1); \ + _sum0##_c_idx = vdotq_s32(_sum0##_c_idx, _k##_k01_idx, _elem); + +#define CALC_1(_k00_idx, _k01_idx, _c_idx) \ + _elem = vqtbl1q_s8_v7(_tmp, _idx##_c_idx##0); \ + _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k00_idx, _elem); \ + _elem = vqtbl1q_s8_v7(_tmp, _idx##_c_idx##1); \ + _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k01_idx, _elem); + +#define CALC_2(_k00_idx, _k01_idx, _k10_idx, _k11_idx, _c_idx) \ + _elem = vqtbl1q_s8_v7(_tmp, _idx##_c_idx##0); \ + _sum0##_c_idx = vdotq_s32(_sum0##_c_idx, _k##_k00_idx, _elem); \ + _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k10_idx, _elem); \ + _elem = vqtbl1q_s8_v7(_tmp, _idx##_c_idx##1); \ + _sum0##_c_idx = vdotq_s32(_sum0##_c_idx, _k##_k01_idx, _elem); \ + _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k11_idx, _elem); + +template +void conv_bias::conv_direct_stride2_5x5_int8_dot(const int8_t* src, + const int8_t* filter, + const int32_t* bias, int32_t* temp, + int8_t* dst, const size_t IH, + const size_t IW, const size_t OH, + const size_t OW, const Op& op) { + const size_t tail_step = IW - 2 * OW + IW; + + const uint8x16_t _idx00 = {0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7, 6, 7, 8, 9}; + const uint8x16_t _idx01 = {4, 16, 16, 16, 6, 16, 16, 16, + 8, 16, 16, 16, 10, 16, 16, 16}; + //! 
start from 8 + const uint8x16_t& _idx10 = _idx00; + const uint8x16_t& _idx11 = _idx01; + + int8x16_t _tmp, _elem; + int32_t* outptr = temp; + int32_t* outptr2 = outptr + OW; + int8_t* dstptr = dst; + int8_t* dstptr2 = dstptr + OW; + const int32_t* __restrict bptr = bias; + + const int8_t* r0 = src; + const int8_t* r1 = src + IW; + const int8_t* r2 = src + IW * 2; + const int8_t* r3 = src + IW * 3; + const int8_t* r4 = src + IW * 4; + const int8_t* r5 = src + IW * 5; + const int8_t* r6 = src + IW * 6; + + const int8_t* k0 = filter; + + int8x16_t _k = vld1q_s8(k0); + //! filter row 1 + uint8x16_t _idx = {0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3}; + int8x16_t _k123 = vqtbl1q_s8_v7(_k, _idx); + _idx = {4, 16, 16, 16, 4, 16, 16, 16, 4, 16, 16, 16, 4, 16, 16, 16}; + int8x16_t _k4 = vqtbl1q_s8_v7(_k, _idx); + //! filter row 2 + _idx = {5, 6, 7, 8, 5, 6, 7, 8, 5, 6, 7, 8, 5, 6, 7, 8}; + int8x16_t _k5678 = vqtbl1q_s8_v7(_k, _idx); + _idx = {9, 16, 16, 16, 9, 16, 16, 16, 9, 16, 16, 16, 9, 16, 16, 16}; + int8x16_t _k9 = vqtbl1q_s8_v7(_k, _idx); + //! filter row 3 + _idx = {10, 11, 12, 13, 10, 11, 12, 13, 10, 11, 12, 13, 10, 11, 12, 13}; + int8x16_t _k10111213 = vqtbl1q_s8_v7(_k, _idx); + _idx = {14, 16, 16, 16, 14, 16, 16, 16, 14, 16, 16, 16, 14, 16, 16, 16}; + int8x16_t _k14 = vqtbl1q_s8_v7(_k, _idx); + //! 9 10 11 12 -> 13 14 15 16 -> 17 18 19 20 -> 21 22 23 24 + _k = vld1q_s8(k0 + 9); + //! filter row 4 + _idx = {6, 7, 8, 9, 6, 7, 8, 9, 6, 7, 8, 9, 6, 7, 8, 9}; + int8x16_t _k15161718 = vqtbl1q_s8_v7(_k, _idx); + _idx = {10, 16, 16, 16, 10, 16, 16, 16, 10, 16, 16, 16, 10, 16, 16, 16}; + int8x16_t _k19 = vqtbl1q_s8_v7(_k, _idx); + //! filter row 5 + _idx = {11, 12, 13, 14, 11, 12, 13, 14, 11, 12, 13, 14, 11, 12, 13, 14}; + int8x16_t _k20212223 = vqtbl1q_s8_v7(_k, _idx); + _idx = {15, 16, 16, 16, 15, 16, 16, 16, 15, 16, 16, 16, 15, 16, 16, 16}; + int8x16_t _k24 = vqtbl1q_s8_v7(_k, _idx); + + const int width = OW >> 2; + size_t h = 0; + for (; h + 1 < OH; h += 2) { + int w = 0; + for (; w + 2 < width; w += 2) { + int32x4_t _sum00, _sum01, _sum10, _sum11; + if (!first_ic) { + _sum00 = vld1q_s32(outptr); + _sum01 = vld1q_s32(outptr + 4); + _sum10 = vld1q_s32(outptr2); + _sum11 = vld1q_s32(outptr2 + 4); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + _sum00 = vdupq_n_s32(bptr[0]); + _sum01 = _sum00; + _sum10 = _sum00; + _sum11 = _sum00; + } else { + _sum00 = vdupq_n_s32(0); + _sum01 = vdupq_n_s32(0); + _sum10 = vdupq_n_s32(0); + _sum11 = vdupq_n_s32(0); + } + } + + _tmp = vld1q_s8(r0); + CALC_0(123, 4, 0); + _tmp = vld1q_s8(r0 + 8); + CALC_0(123, 4, 1); + + _tmp = vld1q_s8(r1); + CALC_0(5678, 9, 0); + _tmp = vld1q_s8(r1 + 8); + CALC_0(5678, 9, 1); + + _tmp = vld1q_s8(r2); + CALC_2(10111213, 14, 123, 4, 0); + _tmp = vld1q_s8(r2 + 8); + CALC_2(10111213, 14, 123, 4, 1); + + _tmp = vld1q_s8(r3); + CALC_2(15161718, 19, 5678, 9, 0); + _tmp = vld1q_s8(r3 + 8); + CALC_2(15161718, 19, 5678, 9, 1); + + _tmp = vld1q_s8(r4); + CALC_2(20212223, 24, 10111213, 14, 0); + _tmp = vld1q_s8(r4 + 8); + CALC_2(20212223, 24, 10111213, 14, 1); + + _tmp = vld1q_s8(r5); + CALC_1(15161718, 19, 0); + _tmp = vld1q_s8(r5 + 8); + CALC_1(15161718, 19, 1); + + _tmp = vld1q_s8(r6); + CALC_1(20212223, 24, 0); + _tmp = vld1q_s8(r6 + 8); + CALC_1(20212223, 24, 1); + + POSTPROCESS_1X8(_sum00, _sum01, outptr, dstptr); + POSTPROCESS_1X8(_sum10, _sum11, outptr2, dstptr2); + + r0 += 16; + r1 += 16; + r2 += 16; + r3 += 16; + r4 += 16; + r5 += 16; + r6 += 16; + outptr += 8; + outptr2 += 8; + dstptr += 8; + dstptr2 += 8; + 
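+ //! each wide iteration emits 8 outputs per row and therefore consumes 16 input columns, since the stride is 2.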
} + for (; w < width; w++) { + int32x4_t _sum00, _sum10; + if (!first_ic) { + _sum00 = vld1q_s32(outptr); + _sum10 = vld1q_s32(outptr2); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + _sum00 = vdupq_n_s32(bptr[0]); + _sum10 = _sum00; + } else { + _sum00 = vdupq_n_s32(0); + _sum10 = vdupq_n_s32(0); + } + } + + _tmp = vld1q_s8(r0); + CALC_0(123, 4, 0); + + _tmp = vld1q_s8(r1); + CALC_0(5678, 9, 0); + + _tmp = vld1q_s8(r2); + CALC_2(10111213, 14, 123, 4, 0); + + _tmp = vld1q_s8(r3); + CALC_2(15161718, 19, 5678, 9, 0); + + _tmp = vld1q_s8(r4); + CALC_2(20212223, 24, 10111213, 14, 0); + + _tmp = vld1q_s8(r5); + CALC_1(15161718, 19, 0); + + _tmp = vld1q_s8(r6); + CALC_1(20212223, 24, 0); + + POSTPROCESS_2X4(_sum00, _sum10, outptr, outptr2, dstptr, dstptr2); + + r0 += 8; + r1 += 8; + r2 += 8; + r3 += 8; + r4 += 8; + r5 += 8; + r6 += 8; + outptr += 4; + outptr2 += 4; + dstptr += 4; + dstptr2 += 4; + } + + r0 += tail_step + IW * 2; + r1 += tail_step + IW * 2; + r2 += tail_step + IW * 2; + r3 += tail_step + IW * 2; + r4 += tail_step + IW * 2; + r5 += tail_step + IW * 2; + r6 += tail_step + IW * 2; + + outptr += OW; + outptr2 += OW; + dstptr += OW; + dstptr2 += OW; + } + + for (; h < OH; h++) { + int w = 0; + for (; w + 2 < width; w += 2) { + int32x4_t _sum00, _sum01; + if (!first_ic) { + _sum00 = vld1q_s32(outptr); + _sum01 = vld1q_s32(outptr + 4); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + _sum00 = vdupq_n_s32(bptr[0]); + _sum01 = _sum00; + } else { + _sum00 = vdupq_n_s32(0); + _sum01 = vdupq_n_s32(0); + } + } + + _tmp = vld1q_s8(r0); + CALC_0(123, 4, 0); + _tmp = vld1q_s8(r0 + 8); + CALC_0(123, 4, 1); + + _tmp = vld1q_s8(r1); + CALC_0(5678, 9, 0); + _tmp = vld1q_s8(r1 + 8); + CALC_0(5678, 9, 1); + + _tmp = vld1q_s8(r2); + CALC_0(10111213, 14, 0); + _tmp = vld1q_s8(r2 + 8); + CALC_0(10111213, 14, 1); + + _tmp = vld1q_s8(r3); + CALC_0(15161718, 19, 0); + _tmp = vld1q_s8(r3 + 8); + CALC_0(15161718, 19, 1); + + _tmp = vld1q_s8(r4); + CALC_0(20212223, 24, 0); + _tmp = vld1q_s8(r4 + 8); + CALC_0(20212223, 24, 1); + + POSTPROCESS_1X8(_sum00, _sum01, outptr, dstptr); + + r0 += 16; + r1 += 16; + r2 += 16; + r3 += 16; + r4 += 16; + outptr += 8; + dstptr += 8; + } + for (; w < width; w++) { + int32x4_t _sum00; + if (!first_ic) { + _sum00 = vld1q_s32(outptr); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + _sum00 = vdupq_n_s32(bptr[0]); + } else { + _sum00 = vdupq_n_s32(0); + } + } + + _tmp = vld1q_s8(r0); + CALC_0(123, 4, 0); + + _tmp = vld1q_s8(r1); + CALC_0(5678, 9, 0); + + _tmp = vld1q_s8(r2); + CALC_0(10111213, 14, 0); + + _tmp = vld1q_s8(r3); + CALC_0(15161718, 19, 0); + + _tmp = vld1q_s8(r4); + CALC_0(20212223, 24, 0); + + POSTPROCESS_1X4(_sum00, outptr, dstptr); + + r0 += 8; + r1 += 8; + r2 += 8; + r3 += 8; + r4 += 8; + outptr += 4; + dstptr += 4; + } + r0 += tail_step; + r1 += tail_step; + r2 += tail_step; + r3 += tail_step; + r4 += tail_step; + } +} + +template +void conv_bias::conv_direct_stride2_7x7_int8_dot(const int8_t* src, + const int8_t* filter, + const int32_t* bias, int32_t* temp, + int8_t* dst, const size_t IH, + const size_t IW, const size_t OH, + const size_t OW, const Op& op) { + const size_t tail_step = IW - 2 * OW + IW; + + const uint8x16_t _idx00 = {0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7, 6, 7, 8, 9}; + const uint8x16_t _idx01 = {4, 5, 6, 16, 6, 7, 8, 16, + 8, 9, 10, 16, 10, 11, 12, 16}; + //! 
start from 8 + const uint8x16_t& _idx10 = _idx00; + const uint8x16_t& _idx11 = _idx01; + + int8x16_t _tmp, _elem; + int32_t* outptr = temp; + int32_t* outptr2 = outptr + OW; + int8_t* dstptr = dst; + int8_t* dstptr2 = dstptr + OW; + const int32_t* __restrict bptr = bias; + + const int8_t* r0 = src; + const int8_t* r1 = src + IW; + const int8_t* r2 = src + IW * 2; + const int8_t* r3 = src + IW * 3; + const int8_t* r4 = src + IW * 4; + const int8_t* r5 = src + IW * 5; + const int8_t* r6 = src + IW * 6; + const int8_t* r7 = src + IW * 7; + const int8_t* r8 = src + IW * 8; + + const int8_t* k0 = filter; + + int8x16_t _k = vld1q_s8(k0); + //! filter row 1 + uint8x16_t _idx = {0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3}; + int8x16_t _k123 = vqtbl1q_s8_v7(_k, _idx); + _idx = {4, 5, 6, 16, 4, 5, 6, 16, 4, 5, 6, 16, 4, 5, 6, 16}; + int8x16_t _k456 = vqtbl1q_s8_v7(_k, _idx); + //! filter row 2 + _idx = {7, 8, 9, 10, 7, 8, 9, 10, 7, 8, 9, 10, 7, 8, 9, 10}; + int8x16_t _k78910 = vqtbl1q_s8_v7(_k, _idx); + _idx = {11, 12, 13, 16, 11, 12, 13, 16, 11, 12, 13, 16, 11, 12, 13, 16}; + int8x16_t _k111213 = vqtbl1q_s8_v7(_k, _idx); + + //! 12 13 14 15 -> 16 17 18 19 -> 20 21 22 23 -> 24 25 26 27 + _k = vld1q_s8(k0 + 12); + //! filter row 3 + _idx = {2, 3, 4, 5, 2, 3, 4, 5, 2, 3, 4, 5, 2, 3, 4, 5}; + int8x16_t _k14151617 = vqtbl1q_s8_v7(_k, _idx); + _idx = {6, 7, 8, 16, 6, 7, 8, 16, 6, 7, 8, 16, 6, 7, 8, 16}; + int8x16_t _k181920 = vqtbl1q_s8_v7(_k, _idx); + //! filter row 4 + _idx = {9, 10, 11, 12, 9, 10, 11, 12, 9, 10, 11, 12, 9, 10, 11, 12}; + int8x16_t _k21222324 = vqtbl1q_s8_v7(_k, _idx); + _idx = {13, 14, 15, 16, 13, 14, 15, 16, 13, 14, 15, 16, 13, 14, 15, 16}; + int8x16_t _k252627 = vqtbl1q_s8_v7(_k, _idx); + + //! 24 25 26 27->28 29 30 31 -> 32 33 34 35 -> 36 37 38 39 + _k = vld1q_s8(k0 + 24); + //! filter row 5 + _idx = {4, 5, 6, 7, 4, 5, 6, 7, 4, 5, 6, 7, 4, 5, 6, 7}; + int8x16_t _k28293031 = vqtbl1q_s8_v7(_k, _idx); + _idx = {8, 9, 10, 16, 8, 9, 10, 16, 8, 9, 10, 16, 8, 9, 10, 16}; + int8x16_t _k323334 = vqtbl1q_s8_v7(_k, _idx); + + //! 33 34 35 36 -> 37 38 39 40 -> 41 42 43 44 -> 45 46 47 48 + _k = vld1q_s8(k0 + 33); + //! filter row 6 + _idx = {2, 3, 4, 5, 2, 3, 4, 5, 2, 3, 4, 5, 2, 3, 4, 5}; + int8x16_t _k35363738 = vqtbl1q_s8_v7(_k, _idx); + _idx = {6, 7, 8, 16, 6, 7, 8, 16, 6, 7, 8, 16, 6, 7, 8, 16}; + int8x16_t _k394041 = vqtbl1q_s8_v7(_k, _idx); + + //! 
filter row 7 + _idx = {9, 10, 11, 12, 9, 10, 11, 12, 9, 10, 11, 12, 9, 10, 11, 12}; + int8x16_t _k42434445 = vqtbl1q_s8_v7(_k, _idx); + _idx = {13, 14, 15, 16, 13, 14, 15, 16, 13, 14, 15, 16, 13, 14, 15, 16}; + int8x16_t _k464748 = vqtbl1q_s8_v7(_k, _idx); + + const int width = OW >> 2; + size_t h = 0; + for (; h + 1 < OH; h += 2) { + int w = 0; + for (; w + 2 < width; w += 2) { + int32x4_t _sum00, _sum01, _sum10, _sum11; + if (!first_ic) { + _sum00 = vld1q_s32(outptr); + _sum01 = vld1q_s32(outptr + 4); + _sum10 = vld1q_s32(outptr2); + _sum11 = vld1q_s32(outptr2 + 4); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + _sum00 = vdupq_n_s32(bptr[0]); + _sum01 = _sum00; + _sum10 = _sum00; + _sum11 = _sum00; + } else { + _sum00 = vdupq_n_s32(0); + _sum01 = vdupq_n_s32(0); + _sum10 = vdupq_n_s32(0); + _sum11 = vdupq_n_s32(0); + } + } + + _tmp = vld1q_s8(r0); + CALC_0(123, 456, 0); + _tmp = vld1q_s8(r0 + 8); + CALC_0(123, 456, 1); + + _tmp = vld1q_s8(r1); + CALC_0(78910, 111213, 0); + _tmp = vld1q_s8(r1 + 8); + CALC_0(78910, 111213, 1); + + _tmp = vld1q_s8(r2); + CALC_2(14151617, 181920, 123, 456, 0); + _tmp = vld1q_s8(r2 + 8); + CALC_2(14151617, 181920, 123, 456, 1); + + _tmp = vld1q_s8(r3); + CALC_2(21222324, 252627, 78910, 111213, 0); + _tmp = vld1q_s8(r3 + 8); + CALC_2(21222324, 252627, 78910, 111213, 1); + + _tmp = vld1q_s8(r4); + CALC_2(28293031, 323334, 14151617, 181920, 0); + _tmp = vld1q_s8(r4 + 8); + CALC_2(28293031, 323334, 14151617, 181920, 1); + + _tmp = vld1q_s8(r5); + CALC_2(35363738, 394041, 21222324, 252627, 0); + _tmp = vld1q_s8(r5 + 8); + CALC_2(35363738, 394041, 21222324, 252627, 1); + + _tmp = vld1q_s8(r6); + CALC_2(42434445, 464748, 28293031, 323334, 0); + _tmp = vld1q_s8(r6 + 8); + CALC_2(42434445, 464748, 28293031, 323334, 1); + + _tmp = vld1q_s8(r7); + CALC_1(35363738, 394041, 0); + _tmp = vld1q_s8(r7 + 8); + CALC_1(35363738, 394041, 1); + + _tmp = vld1q_s8(r8); + CALC_1(42434445, 464748, 0); + _tmp = vld1q_s8(r8 + 8); + CALC_1(42434445, 464748, 1); + + POSTPROCESS_1X8(_sum00, _sum01, outptr, dstptr); + POSTPROCESS_1X8(_sum10, _sum11, outptr2, dstptr2); + + r0 += 16; + r1 += 16; + r2 += 16; + r3 += 16; + r4 += 16; + r5 += 16; + r6 += 16; + r7 += 16; + r8 += 16; + outptr += 8; + outptr2 += 8; + dstptr += 8; + dstptr2 += 8; + } + for (; w < width; w++) { + int32x4_t _sum00, _sum10; + if (!first_ic) { + _sum00 = vld1q_s32(outptr); + _sum10 = vld1q_s32(outptr2); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + _sum00 = vdupq_n_s32(bptr[0]); + _sum10 = _sum00; + } else { + _sum00 = vdupq_n_s32(0); + _sum10 = vdupq_n_s32(0); + } + } + + _tmp = vld1q_s8(r0); + CALC_0(123, 456, 0); + + _tmp = vld1q_s8(r1); + CALC_0(78910, 111213, 0); + + _tmp = vld1q_s8(r2); + CALC_2(14151617, 181920, 123, 456, 0); + + _tmp = vld1q_s8(r3); + CALC_2(21222324, 252627, 78910, 111213, 0); + + _tmp = vld1q_s8(r4); + CALC_2(28293031, 323334, 14151617, 181920, 0); + + _tmp = vld1q_s8(r5); + CALC_2(35363738, 394041, 21222324, 252627, 0); + + _tmp = vld1q_s8(r6); + CALC_2(42434445, 464748, 28293031, 323334, 0); + + _tmp = vld1q_s8(r7); + CALC_1(35363738, 394041, 0); + + _tmp = vld1q_s8(r8); + CALC_1(42434445, 464748, 0); + + POSTPROCESS_2X4(_sum00, _sum10, outptr, outptr2, dstptr, dstptr2); + + r0 += 8; + r1 += 8; + r2 += 8; + r3 += 8; + r4 += 8; + r5 += 8; + r6 += 8; + r7 += 8; + r8 += 8; + outptr += 4; + outptr2 += 4; + dstptr += 4; + dstptr2 += 4; + } + + r0 += tail_step + IW * 2; + r1 += tail_step + IW * 2; + r2 += tail_step + IW * 2; + r3 += tail_step + IW * 2; + 
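+ //! two output rows at stride 2 span four input rows, hence the extra IW * 2 on top of tail_step.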
r4 += tail_step + IW * 2; + r5 += tail_step + IW * 2; + r6 += tail_step + IW * 2; + r7 += tail_step + IW * 2; + r8 += tail_step + IW * 2; + + outptr += OW; + outptr2 += OW; + dstptr += OW; + dstptr2 += OW; + } + + for (; h < OH; h++) { + int w = 0; + for (; w + 2 < width; w += 2) { + int32x4_t _sum00, _sum01; + if (!first_ic) { + _sum00 = vld1q_s32(outptr); + _sum01 = vld1q_s32(outptr + 4); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + _sum00 = vdupq_n_s32(bptr[0]); + _sum01 = _sum00; + } else { + _sum00 = vdupq_n_s32(0); + _sum01 = vdupq_n_s32(0); + } + } + + _tmp = vld1q_s8(r0); + CALC_0(123, 456, 0); + _tmp = vld1q_s8(r0 + 8); + CALC_0(123, 456, 1); + + _tmp = vld1q_s8(r1); + CALC_0(78910, 111213, 0); + _tmp = vld1q_s8(r1 + 8); + CALC_0(78910, 111213, 1); + + _tmp = vld1q_s8(r2); + CALC_0(14151617, 181920, 0); + _tmp = vld1q_s8(r2 + 8); + CALC_0(14151617, 181920, 1); + + _tmp = vld1q_s8(r3); + CALC_0(21222324, 252627, 0); + _tmp = vld1q_s8(r3 + 8); + CALC_0(21222324, 252627, 1); + + _tmp = vld1q_s8(r4); + CALC_0(28293031, 323334, 0); + _tmp = vld1q_s8(r4 + 8); + CALC_0(28293031, 323334, 1); + + _tmp = vld1q_s8(r5); + CALC_0(35363738, 394041, 0); + _tmp = vld1q_s8(r5 + 8); + CALC_0(35363738, 394041, 1); + + _tmp = vld1q_s8(r6); + CALC_0(42434445, 464748, 0); + _tmp = vld1q_s8(r6 + 8); + CALC_0(42434445, 464748, 1); + + POSTPROCESS_1X8(_sum00, _sum01, outptr, dstptr); + + r0 += 16; + r1 += 16; + r2 += 16; + r3 += 16; + r4 += 16; + r5 += 16; + r6 += 16; + outptr += 8; + dstptr += 8; + } + for (; w < width; w++) { + int32x4_t _sum00; + if (!first_ic) { + _sum00 = vld1q_s32(outptr); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + _sum00 = vdupq_n_s32(bptr[0]); + } else { + _sum00 = vdupq_n_s32(0); + } + } + + _tmp = vld1q_s8(r0); + CALC_0(123, 456, 0); + + _tmp = vld1q_s8(r1); + CALC_0(78910, 111213, 0); + + _tmp = vld1q_s8(r2); + CALC_0(14151617, 181920, 0); + + _tmp = vld1q_s8(r3); + CALC_0(21222324, 252627, 0); + + _tmp = vld1q_s8(r4); + CALC_0(28293031, 323334, 0); + + _tmp = vld1q_s8(r5); + CALC_0(35363738, 394041, 0); + + _tmp = vld1q_s8(r6); + CALC_0(42434445, 464748, 0); + + POSTPROCESS_1X4(_sum00, outptr, dstptr); + + r0 += 8; + r1 += 8; + r2 += 8; + r3 += 8; + r4 += 8; + r5 += 8; + r6 += 8; + outptr += 4; + dstptr += 4; + } + r0 += tail_step; + r1 += tail_step; + r2 += tail_step; + r3 += tail_step; + r4 += tail_step; + r5 += tail_step; + r6 += tail_step; + } +} + +template +void conv_bias::conv_direct_stride1_5x5_int8_dot(const int8_t* src, + const int8_t* filter, + const int32_t* bias, int32_t* temp, + int8_t* dst, const size_t IH, + const size_t IW, const size_t OH, + const size_t OW, const Op& op) { + const size_t tail_step = IW - OW; + + const uint8x16_t _idx00 = {0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6}; + const uint8x16_t _idx01 = {4, 16, 16, 16, 5, 16, 16, 16, + 6, 16, 16, 16, 7, 16, 16, 16}; + const uint8x16_t _idx10 = {4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10}; + const uint8x16_t _idx11 = {8, 16, 16, 16, 9, 16, 16, 16, + 10, 16, 16, 16, 11, 16, 16, 16}; + const uint8x16_t _idx20 = {8, 9, 10, 11, 9, 10, 11, 12, + 10, 11, 12, 13, 11, 12, 13, 14}; + const uint8x16_t _idx21 = {12, 16, 16, 16, 13, 16, 16, 16, + 14, 16, 16, 16, 15, 16, 16, 16}; + int8x16_t _tmp, _elem; + int32_t* outptr = temp; + int32_t* outptr2 = outptr + OW; + int8_t* dstptr = dst; + int8_t* dstptr2 = dstptr + OW; + const int32_t* __restrict bptr = bias; + + const int8_t* r0 = src; + const int8_t* r1 = src + IW; + const int8_t* r2 = src + IW * 2; + const 
int8_t* r3 = src + IW * 3; + const int8_t* r4 = src + IW * 4; + const int8_t* r5 = src + IW * 5; + + const int8_t* k0 = filter; + + int8x16_t _k = vld1q_s8(k0); + //! filter row 1 + uint8x16_t _idx = {0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3}; + int8x16_t _k123 = vqtbl1q_s8_v7(_k, _idx); + _idx = {4, 16, 16, 16, 4, 16, 16, 16, 4, 16, 16, 16, 4, 16, 16, 16}; + int8x16_t _k4 = vqtbl1q_s8_v7(_k, _idx); + //! filter row 2 + _idx = {5, 6, 7, 8, 5, 6, 7, 8, 5, 6, 7, 8, 5, 6, 7, 8}; + int8x16_t _k5678 = vqtbl1q_s8_v7(_k, _idx); + _idx = {9, 16, 16, 16, 9, 16, 16, 16, 9, 16, 16, 16, 9, 16, 16, 16}; + int8x16_t _k9 = vqtbl1q_s8_v7(_k, _idx); + //! filter row 3 + _idx = {10, 11, 12, 13, 10, 11, 12, 13, 10, 11, 12, 13, 10, 11, 12, 13}; + int8x16_t _k10111213 = vqtbl1q_s8_v7(_k, _idx); + _idx = {14, 16, 16, 16, 14, 16, 16, 16, 14, 16, 16, 16, 14, 16, 16, 16}; + int8x16_t _k14 = vqtbl1q_s8_v7(_k, _idx); + //! 9 10 11 12 -> 13 14 15 16 -> 17 18 19 20 -> 21 22 23 24 + _k = vld1q_s8(k0 + 9); + //! filter row 4 + _idx = {6, 7, 8, 9, 6, 7, 8, 9, 6, 7, 8, 9, 6, 7, 8, 9}; + int8x16_t _k15161718 = vqtbl1q_s8_v7(_k, _idx); + _idx = {10, 16, 16, 16, 10, 16, 16, 16, 10, 16, 16, 16, 10, 16, 16, 16}; + int8x16_t _k19 = vqtbl1q_s8_v7(_k, _idx); + //! filter row 5 + _idx = {11, 12, 13, 14, 11, 12, 13, 14, 11, 12, 13, 14, 11, 12, 13, 14}; + int8x16_t _k20212223 = vqtbl1q_s8_v7(_k, _idx); + _idx = {15, 16, 16, 16, 15, 16, 16, 16, 15, 16, 16, 16, 15, 16, 16, 16}; + int8x16_t _k24 = vqtbl1q_s8_v7(_k, _idx); + + const int width = OW >> 2; + size_t h = 0; + for (; h + 1 < OH; h += 2) { + int w = 0; + for (; w + 3 < width; w += 3) { + int32x4_t _sum00, _sum01, _sum02, _sum10, _sum11, _sum12; + if (!first_ic) { + _sum00 = vld1q_s32(outptr); + _sum01 = vld1q_s32(outptr + 4); + _sum02 = vld1q_s32(outptr + 8); + _sum10 = vld1q_s32(outptr2); + _sum11 = vld1q_s32(outptr2 + 4); + _sum12 = vld1q_s32(outptr2 + 8); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + _sum00 = vdupq_n_s32(bptr[0]); + _sum01 = _sum00; + _sum02 = _sum00; + _sum10 = _sum00; + _sum11 = _sum00; + _sum12 = _sum00; + } else { + _sum00 = vdupq_n_s32(0); + _sum01 = vdupq_n_s32(0); + _sum02 = vdupq_n_s32(0); + _sum10 = vdupq_n_s32(0); + _sum11 = vdupq_n_s32(0); + _sum12 = vdupq_n_s32(0); + } + } + _tmp = vld1q_s8(r0); + CALC_0(123, 4, 0); + CALC_0(123, 4, 1); + CALC_0(123, 4, 2); + + _tmp = vld1q_s8(r1); + CALC_2(5678, 9, 123, 4, 0); + CALC_2(5678, 9, 123, 4, 1); + CALC_2(5678, 9, 123, 4, 2); + + _tmp = vld1q_s8(r2); + CALC_2(10111213, 14, 5678, 9, 0); + CALC_2(10111213, 14, 5678, 9, 1); + CALC_2(10111213, 14, 5678, 9, 2); + + _tmp = vld1q_s8(r3); + CALC_2(15161718, 19, 10111213, 14, 0); + CALC_2(15161718, 19, 10111213, 14, 1); + CALC_2(15161718, 19, 10111213, 14, 2); + + _tmp = vld1q_s8(r4); + CALC_2(20212223, 24, 15161718, 19, 0); + CALC_2(20212223, 24, 15161718, 19, 1); + CALC_2(20212223, 24, 15161718, 19, 2); + + _tmp = vld1q_s8(r5); + CALC_1(20212223, 24, 0); + CALC_1(20212223, 24, 1); + CALC_1(20212223, 24, 2); + + POSTPROCESS_1X8(_sum00, _sum01, outptr, dstptr); + POSTPROCESS_1X4(_sum02, outptr + 8, dstptr + 8); + POSTPROCESS_1X8(_sum10, _sum11, outptr2, dstptr2); + POSTPROCESS_1X4(_sum12, outptr2 + 8, dstptr2 + 8); + + r0 += 12; + r1 += 12; + r2 += 12; + r3 += 12; + r4 += 12; + r5 += 12; + outptr += 12; + outptr2 += 12; + dstptr += 12; + dstptr2 += 12; + } + for (; w < width; w++) { + int32x4_t _sum00, _sum10; + if (!first_ic) { + _sum00 = vld1q_s32(outptr); + _sum10 = vld1q_s32(outptr2); + } else { + if (bias_mode == 
BiasMode::BROADCAST_CHANNEL_BIAS) { + _sum00 = vdupq_n_s32(bptr[0]); + _sum10 = _sum00; + } else { + _sum00 = vdupq_n_s32(0); + _sum10 = vdupq_n_s32(0); + } + } + + _tmp = vtranslq_s8(vld1_s8(r0)); + CALC_0(123, 4, 0); + + _tmp = vtranslq_s8(vld1_s8(r1)); + CALC_2(5678, 9, 123, 4, 0); + + _tmp = vtranslq_s8(vld1_s8(r2)); + CALC_2(10111213, 14, 5678, 9, 0); + + _tmp = vtranslq_s8(vld1_s8(r3)); + CALC_2(15161718, 19, 10111213, 14, 0); + + _tmp = vtranslq_s8(vld1_s8(r4)); + CALC_2(20212223, 24, 15161718, 19, 0); + + _tmp = vtranslq_s8(vld1_s8(r5)); + CALC_1(20212223, 24, 0); + + POSTPROCESS_2X4(_sum00, _sum10, outptr, outptr2, dstptr, dstptr2); + + r0 += 4; + r1 += 4; + r2 += 4; + r3 += 4; + r4 += 4; + r5 += 4; + outptr += 4; + outptr2 += 4; + dstptr += 4; + dstptr2 += 4; + } + + r0 += tail_step + IW; + r1 += tail_step + IW; + r2 += tail_step + IW; + r3 += tail_step + IW; + r4 += tail_step + IW; + r5 += tail_step + IW; + + outptr += OW; + outptr2 += OW; + dstptr += OW; + dstptr2 += OW; + } + + for (; h < OH; h++) { + int w = 0; + for (; w + 3 < width; w += 3) { + int32x4_t _sum00, _sum01, _sum02; + if (!first_ic) { + _sum00 = vld1q_s32(outptr); + _sum01 = vld1q_s32(outptr + 4); + _sum02 = vld1q_s32(outptr + 8); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + _sum00 = vdupq_n_s32(bptr[0]); + _sum01 = _sum00; + _sum02 = _sum00; + } else { + _sum00 = vdupq_n_s32(0); + _sum01 = vdupq_n_s32(0); + _sum02 = vdupq_n_s32(0); + } + } + + _tmp = vld1q_s8(r0); + CALC_0(123, 4, 0); + CALC_0(123, 4, 1); + CALC_0(123, 4, 2); + + _tmp = vld1q_s8(r1); + CALC_0(5678, 9, 0); + CALC_0(5678, 9, 1); + CALC_0(5678, 9, 2); + + _tmp = vld1q_s8(r2); + CALC_0(10111213, 14, 0); + CALC_0(10111213, 14, 1); + CALC_0(10111213, 14, 2); + + _tmp = vld1q_s8(r3); + CALC_0(15161718, 19, 0); + CALC_0(15161718, 19, 1); + CALC_0(15161718, 19, 2); + + _tmp = vld1q_s8(r4); + CALC_0(20212223, 24, 0); + CALC_0(20212223, 24, 1); + CALC_0(20212223, 24, 2); + + POSTPROCESS_1X8(_sum00, _sum01, outptr, dstptr); + POSTPROCESS_1X4(_sum02, outptr + 8, dstptr + 8); + + r0 += 12; + r1 += 12; + r2 += 12; + r3 += 12; + r4 += 12; + outptr += 12; + dstptr += 12; + } + for (; w < width; w++) { + int32x4_t _sum00; + if (!first_ic) { + _sum00 = vld1q_s32(outptr); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + _sum00 = vdupq_n_s32(bptr[0]); + } else { + _sum00 = vdupq_n_s32(0); + } + } + + _tmp = vtranslq_s8(vld1_s8(r0)); + CALC_0(123, 4, 0); + + _tmp = vtranslq_s8(vld1_s8(r1)); + CALC_0(5678, 9, 0); + + _tmp = vtranslq_s8(vld1_s8(r2)); + CALC_0(10111213, 14, 0); + + _tmp = vtranslq_s8(vld1_s8(r3)); + CALC_0(15161718, 19, 0); + + _tmp = vtranslq_s8(vld1_s8(r4)); + CALC_0(20212223, 24, 0); + + POSTPROCESS_1X4(_sum00, outptr, dstptr); + + r0 += 4; + r1 += 4; + r2 += 4; + r3 += 4; + r4 += 4; + outptr += 4; + dstptr += 4; + } + r0 += tail_step; + r1 += tail_step; + r2 += tail_step; + r3 += tail_step; + r4 += tail_step; + } +} + +template +void conv_bias::conv_direct_stride1_7x7_int8_dot(const int8_t* src, + const int8_t* filter, + const int32_t* bias, int32_t* temp, + int8_t* dst, const size_t IH, + const size_t IW, const size_t OH, + const size_t OW, const Op& op) { + const size_t tail_step = IW - OW; + + const uint8x16_t _idx00 = {0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6}; + const uint8x16_t _idx01 = {4, 5, 6, 16, 5, 6, 7, 16, + 6, 7, 8, 16, 7, 8, 9, 16}; + const uint8x16_t _idx10 = {4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10}; + const uint8x16_t _idx11 = {8, 9, 10, 16, 9, 10, 11, 16, + 10, 11, 12, 16, 
11, 12, 13, 16}; + + int8x16_t _tmp, _elem; + int32_t* outptr = temp; + int32_t* outptr2 = outptr + OW; + int8_t* dstptr = dst; + int8_t* dstptr2 = dstptr + OW; + const int32_t* __restrict bptr = bias; + + const int8_t* r0 = src; + const int8_t* r1 = src + IW; + const int8_t* r2 = src + IW * 2; + const int8_t* r3 = src + IW * 3; + const int8_t* r4 = src + IW * 4; + const int8_t* r5 = src + IW * 5; + const int8_t* r6 = src + IW * 6; + const int8_t* r7 = src + IW * 7; + + const int8_t* k0 = filter; + + int8x16_t _k = vld1q_s8(k0); + //! filter row 1 + uint8x16_t _idx = {0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3}; + int8x16_t _k123 = vqtbl1q_s8_v7(_k, _idx); + _idx = {4, 5, 6, 16, 4, 5, 6, 16, 4, 5, 6, 16, 4, 5, 6, 16}; + int8x16_t _k456 = vqtbl1q_s8_v7(_k, _idx); + //! filter row 2 + _idx = {7, 8, 9, 10, 7, 8, 9, 10, 7, 8, 9, 10, 7, 8, 9, 10}; + int8x16_t _k78910 = vqtbl1q_s8_v7(_k, _idx); + _idx = {11, 12, 13, 16, 11, 12, 13, 16, 11, 12, 13, 16, 11, 12, 13, 16}; + int8x16_t _k111213 = vqtbl1q_s8_v7(_k, _idx); + + //! 12 13 14 15 -> 16 17 18 19 -> 20 21 22 23 -> 24 25 26 27 + _k = vld1q_s8(k0 + 12); + //! filter row 3 + _idx = {2, 3, 4, 5, 2, 3, 4, 5, 2, 3, 4, 5, 2, 3, 4, 5}; + int8x16_t _k14151617 = vqtbl1q_s8_v7(_k, _idx); + _idx = {6, 7, 8, 16, 6, 7, 8, 16, 6, 7, 8, 16, 6, 7, 8, 16}; + int8x16_t _k181920 = vqtbl1q_s8_v7(_k, _idx); + //! filter row 4 + _idx = {9, 10, 11, 12, 9, 10, 11, 12, 9, 10, 11, 12, 9, 10, 11, 12}; + int8x16_t _k21222324 = vqtbl1q_s8_v7(_k, _idx); + _idx = {13, 14, 15, 16, 13, 14, 15, 16, 13, 14, 15, 16, 13, 14, 15, 16}; + int8x16_t _k252627 = vqtbl1q_s8_v7(_k, _idx); + + //! 24 25 26 27->28 29 30 31 -> 32 33 34 35 -> 36 37 38 39 + _k = vld1q_s8(k0 + 24); + //! filter row 5 + _idx = {4, 5, 6, 7, 4, 5, 6, 7, 4, 5, 6, 7, 4, 5, 6, 7}; + int8x16_t _k28293031 = vqtbl1q_s8_v7(_k, _idx); + _idx = {8, 9, 10, 16, 8, 9, 10, 16, 8, 9, 10, 16, 8, 9, 10, 16}; + int8x16_t _k323334 = vqtbl1q_s8_v7(_k, _idx); + + //! 33 34 35 36 -> 37 38 39 40 -> 41 42 43 44 -> 45 46 47 48 + _k = vld1q_s8(k0 + 33); + //! filter row 6 + _idx = {2, 3, 4, 5, 2, 3, 4, 5, 2, 3, 4, 5, 2, 3, 4, 5}; + int8x16_t _k35363738 = vqtbl1q_s8_v7(_k, _idx); + _idx = {6, 7, 8, 16, 6, 7, 8, 16, 6, 7, 8, 16, 6, 7, 8, 16}; + int8x16_t _k394041 = vqtbl1q_s8_v7(_k, _idx); + + //! 
filter row 7 + _idx = {9, 10, 11, 12, 9, 10, 11, 12, 9, 10, 11, 12, 9, 10, 11, 12}; + int8x16_t _k42434445 = vqtbl1q_s8_v7(_k, _idx); + _idx = {13, 14, 15, 16, 13, 14, 15, 16, 13, 14, 15, 16, 13, 14, 15, 16}; + int8x16_t _k464748 = vqtbl1q_s8_v7(_k, _idx); + + const int width = OW >> 2; + size_t h = 0; + for (; h + 1 < OH; h += 2) { + int w = 0; + for (; w + 2 < width; w += 2) { + int32x4_t _sum00, _sum01, _sum10, _sum11; + if (!first_ic) { + _sum00 = vld1q_s32(outptr); + _sum01 = vld1q_s32(outptr + 4); + _sum10 = vld1q_s32(outptr2); + _sum11 = vld1q_s32(outptr2 + 4); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + _sum00 = vdupq_n_s32(bptr[0]); + _sum01 = _sum00; + _sum10 = _sum00; + _sum11 = _sum00; + } else { + _sum00 = vdupq_n_s32(0); + _sum01 = vdupq_n_s32(0); + _sum10 = vdupq_n_s32(0); + _sum11 = vdupq_n_s32(0); + } + } + + _tmp = vld1q_s8(r0); + CALC_0(123, 456, 0); + CALC_0(123, 456, 1); + + _tmp = vld1q_s8(r1); + CALC_2(78910, 111213, 123, 456, 0); + CALC_2(78910, 111213, 123, 456, 1); + + _tmp = vld1q_s8(r2); + CALC_2(14151617, 181920, 78910, 111213, 0); + CALC_2(14151617, 181920, 78910, 111213, 1); + + _tmp = vld1q_s8(r3); + CALC_2(21222324, 252627, 14151617, 181920, 0); + CALC_2(21222324, 252627, 14151617, 181920, 1); + + _tmp = vld1q_s8(r4); + CALC_2(28293031, 323334, 21222324, 252627, 0); + CALC_2(28293031, 323334, 21222324, 252627, 1); + + _tmp = vld1q_s8(r5); + CALC_2(35363738, 394041, 28293031, 323334, 0); + CALC_2(35363738, 394041, 28293031, 323334, 1); + + _tmp = vld1q_s8(r6); + CALC_2(42434445, 464748, 35363738, 394041, 0); + CALC_2(42434445, 464748, 35363738, 394041, 1); + + _tmp = vld1q_s8(r7); + CALC_1(42434445, 464748, 0); + CALC_1(42434445, 464748, 1); + + POSTPROCESS_1X8(_sum00, _sum01, outptr, dstptr); + POSTPROCESS_1X8(_sum10, _sum11, outptr2, dstptr2); + + r0 += 8; + r1 += 8; + r2 += 8; + r3 += 8; + r4 += 8; + r5 += 8; + r6 += 8; + r7 += 8; + outptr += 8; + outptr2 += 8; + dstptr += 8; + dstptr2 += 8; + } + for (; w < width; w++) { + int32x4_t _sum00, _sum10; + if (!first_ic) { + _sum00 = vld1q_s32(outptr); + _sum10 = vld1q_s32(outptr2); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + _sum00 = vdupq_n_s32(bptr[0]); + _sum10 = _sum00; + } else { + _sum00 = vdupq_n_s32(0); + _sum10 = vdupq_n_s32(0); + } + } + + _tmp = vld1q_s8(r0); + CALC_0(123, 456, 0); + + _tmp = vld1q_s8(r1); + CALC_2(78910, 111213, 123, 456, 0); + + _tmp = vld1q_s8(r2); + CALC_2(14151617, 181920, 78910, 111213, 0); + + _tmp = vld1q_s8(r3); + CALC_2(21222324, 252627, 14151617, 181920, 0); + + _tmp = vld1q_s8(r4); + CALC_2(28293031, 323334, 21222324, 252627, 0); + + _tmp = vld1q_s8(r5); + CALC_2(35363738, 394041, 28293031, 323334, 0); + + _tmp = vld1q_s8(r6); + CALC_2(42434445, 464748, 35363738, 394041, 0); + + _tmp = vld1q_s8(r7); + CALC_1(42434445, 464748, 0); + + POSTPROCESS_2X4(_sum00, _sum10, outptr, outptr2, dstptr, dstptr2); + + r0 += 4; + r1 += 4; + r2 += 4; + r3 += 4; + r4 += 4; + r5 += 4; + r6 += 4; + r7 += 4; + outptr += 4; + outptr2 += 4; + dstptr += 4; + dstptr2 += 4; + } + + r0 += tail_step + IW; + r1 += tail_step + IW; + r2 += tail_step + IW; + r3 += tail_step + IW; + r4 += tail_step + IW; + r5 += tail_step + IW; + r6 += tail_step + IW; + r7 += tail_step + IW; + + outptr += OW; + outptr2 += OW; + dstptr += OW; + dstptr2 += OW; + } + + for (; h < OH; h++) { + int w = 0; + for (; w + 2 < width; w += 2) { + int32x4_t _sum00, _sum01; + if (!first_ic) { + _sum00 = vld1q_s32(outptr); + _sum01 = vld1q_s32(outptr + 4); + } else { + if (bias_mode == 
BiasMode::BROADCAST_CHANNEL_BIAS) { + _sum00 = vdupq_n_s32(bptr[0]); + _sum01 = _sum00; + } else { + _sum00 = vdupq_n_s32(0); + _sum01 = vdupq_n_s32(0); + } + } + + _tmp = vld1q_s8(r0); + CALC_0(123, 456, 0); + CALC_0(123, 456, 1); + + _tmp = vld1q_s8(r1); + CALC_0(78910, 111213, 0); + CALC_0(78910, 111213, 1); + + _tmp = vld1q_s8(r2); + CALC_0(14151617, 181920, 0); + CALC_0(14151617, 181920, 1); + + _tmp = vld1q_s8(r3); + CALC_0(21222324, 252627, 0); + CALC_0(21222324, 252627, 1); + + _tmp = vld1q_s8(r4); + CALC_0(28293031, 323334, 0); + CALC_0(28293031, 323334, 1); + + _tmp = vld1q_s8(r5); + CALC_0(35363738, 394041, 0); + CALC_0(35363738, 394041, 1); + + _tmp = vld1q_s8(r6); + CALC_0(42434445, 464748, 0); + CALC_0(42434445, 464748, 1); + + POSTPROCESS_1X8(_sum00, _sum01, outptr, dstptr); + + r0 += 8; + r1 += 8; + r2 += 8; + r3 += 8; + r4 += 8; + r5 += 8; + r6 += 8; + outptr += 8; + dstptr += 8; + } + for (; w < width; w++) { + int32x4_t _sum00; + if (!first_ic) { + _sum00 = vld1q_s32(outptr); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + _sum00 = vdupq_n_s32(bptr[0]); + } else { + _sum00 = vdupq_n_s32(0); + } + } + + _tmp = vld1q_s8(r0); + CALC_0(123, 456, 0); + + _tmp = vld1q_s8(r1); + CALC_0(78910, 111213, 0); + + _tmp = vld1q_s8(r2); + CALC_0(14151617, 181920, 0); + + _tmp = vld1q_s8(r3); + CALC_0(21222324, 252627, 0); + + _tmp = vld1q_s8(r4); + CALC_0(28293031, 323334, 0); + + _tmp = vld1q_s8(r5); + CALC_0(35363738, 394041, 0); + + _tmp = vld1q_s8(r6); + CALC_0(42434445, 464748, 0); + + POSTPROCESS_1X4(_sum00, outptr, dstptr); + + r0 += 4; + r1 += 4; + r2 += 4; + r3 += 4; + r4 += 4; + r5 += 4; + r6 += 4; + outptr += 4; + dstptr += 4; + } + r0 += tail_step; + r1 += tail_step; + r2 += tail_step; + r3 += tail_step; + r4 += tail_step; + r5 += tail_step; + r6 += tail_step; + } +} + +#undef CALC_0 +#undef CALC_1 +#undef CALC_2 + +#undef POSTPROCESS_1X8 +#undef POSTPROCESS2_1X8 +#undef POSTPROCESS_2X4 +#undef POSTPROCESS_1X4 +#undef ST1_S32X4 +#undef ST2_S32X4X2 + +#define INSTANTIATION(stride, i, first_ic, last_ic, bias, Op) \ + template void conv_bias::conv_direct_##stride##_##i##x##i##_int8_dot< \ + first_ic, last_ic, bias, Op>( \ + const int8_t*, const int8_t*, const int32_t*, int32_t*, int8_t*, \ + const size_t, const size_t, const size_t, const size_t, \ + const Op&); + +#define FOR_OP(stride, i, first_ic, last_ic, bias) \ + INSTANTIATION(stride, i, first_ic, last_ic, bias, \ + TypeCvtOp) \ + INSTANTIATION(stride, i, first_ic, last_ic, bias, \ + ReluOp) \ + INSTANTIATION(stride, i, first_ic, last_ic, bias, \ + HSwishOp) + +#define FOR_BIAS(stride, i, first_ic, last_ic) \ + FOR_OP(stride, i, first_ic, last_ic, BiasMode::NO_BIAS) \ + FOR_OP(stride, i, first_ic, last_ic, BiasMode::BROADCAST_CHANNEL_BIAS) + +#define FOR_IC(stride, i) \ + FOR_BIAS(stride, i, true, true) \ + FOR_BIAS(stride, i, true, false) \ + FOR_BIAS(stride, i, false, false) \ + FOR_BIAS(stride, i, false, true) + +#define FOR_FILTER(stride) \ + FOR_IC(stride, 2) \ + FOR_IC(stride, 3) \ + FOR_IC(stride, 5) \ + FOR_IC(stride, 7) + +#define FOR_STRIDE \ + FOR_FILTER(stride1) \ + FOR_FILTER(stride2) + +FOR_STRIDE + +#undef FOR_STRIDE +#undef FOR_FILTER +#undef FOR_IC +#undef FOR_BIAS +#undef FOR_NONLINEAR +#undef INSTANTIATION + +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8/direct_dotprod.h b/dnn/src/arm_common/conv_bias/int8/direct_dotprod.h new file mode 100644 index 00000000..f5752809 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/direct_dotprod.h @@ -0,0 
+1,42 @@ +/** + * \file dnn/src/arm_common/conv_bias/int8/direct_dotprod.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#if __ARM_FEATURE_DOTPROD +#include "src/arm_common/conv_bias/opr_impl.h" +#include "src/fallback/conv_bias/common.h" + +namespace megdnn { +namespace arm_common { +namespace conv_bias { + +#define KERN(stride, i) \ + template \ + void conv_direct_##stride##_##i##x##i##_int8_dot( \ + const int8_t* src, const int8_t* filter, const int32_t* bias, \ + int32_t* temp, int8_t* dst, const size_t IH, const size_t IW, \ + const size_t OH, const size_t OW, const Op& op); + +KERN(stride1, 2) +KERN(stride1, 3) +KERN(stride1, 5) +KERN(stride1, 7) + +KERN(stride2, 2) +KERN(stride2, 3) +KERN(stride2, 5) +KERN(stride2, 7) + +#undef KERN + +} // namespace conv_bias +} // namespace arm_common +} // namespace megdnn +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8/direct_stride1_nchw44_algo.cpp b/dnn/src/arm_common/conv_bias/int8/direct_stride1_nchw44_algo.cpp new file mode 100644 index 00000000..6551776d --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/direct_stride1_nchw44_algo.cpp @@ -0,0 +1,393 @@ +/** + * \file dnn/src/arm_common/conv_bias/int8/direct_stride1_nchw44_algo.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied.
+ */ + +#include "megdnn/oprs.h" +#include "src/arm_common/conv_bias/int8/algos.h" +#include "src/arm_common/conv_bias/int8/direct.h" +#include "src/arm_common/conv_bias/int8/strategy.h" +#include "src/arm_common/elemwise_op.h" +#include "src/common/opr_delegate.h" + +#include "midout.h" + +using namespace megdnn; +using namespace arm_common; +using conv_fun = std::function; +MIDOUT_DECL(megdnn_arm_common_conv_bias_int8_nchw44_stride1) + +static void get_rectified_size( + const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, + size_t& IH2, size_t& IW2, size_t& OH2, size_t& OW2) { + auto&& fm = param.filter_meta; + auto SW = fm.stride[1]; + auto OH = param.osz[0]; + auto OW = param.osz[1]; + auto FH = fm.spatial[0]; + auto FW = fm.spatial[1]; + + OH2 = OH; + OW2 = (OW + 7) & ~7; + IH2 = SW * OH + FH - SW; + IW2 = SW * OW2 + FW - SW; +} + +static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { + constexpr size_t src_expand = 4; + auto&& fm = param.filter_meta; + size_t group = fm.group; + size_t batch = param.n; + size_t IC = fm.icpg; + size_t OC = fm.ocpg; + size_t FH = fm.spatial[0]; + size_t FW = fm.spatial[1]; + size_t IH2, IW2, OH2, OW2; + get_rectified_size(param, IH2, IW2, OH2, OW2); + if (group == 1) { + size_t src_size = + batch * group * IC * IH2 * IW2 * sizeof(int8_t) * src_expand; + size_t weight_size = group * OC * IC * FH * FW * sizeof(int8_t); + return {nullptr, {src_size, weight_size}}; + } else { + size_t src_size = + param.nr_threads * IC * IH2 * IW2 * sizeof(int8_t) * src_expand; + size_t weight_size = group * OC * IC * FH * FW * sizeof(int8_t); + return {nullptr, {src_size, weight_size}}; + } +}; + +static void copy_padding_kern(WorkspaceBundle bundle, + const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids) { + size_t IH = kern_param.isz[0]; + size_t IW = kern_param.isz[1]; + size_t IC = kern_param.filter_meta.icpg; + size_t PH = kern_param.filter_meta.padding[0]; + size_t PW = kern_param.filter_meta.padding[1]; + size_t GROUP = kern_param.filter_meta.group; + + size_t IH2, IW2, OH2, OW2; + get_rectified_size(kern_param, IH2, IW2, OH2, OW2); + size_t padding_group_size = IH2 * IW2 * IC; + bundle.set(kern_param.workspace_ptr); + //! Used for get the workspace offset + constexpr int pack_ic = 4; + constexpr int expend_element = 4; + // TODO: block dim is better to get from arg + size_t workspace_ic_block = 4; + size_t workspace_batch_id = workspace_ids[0]; + size_t workspace_group_id = workspace_ids[1]; + size_t workspace_ic_id = workspace_ids[2]; + size_t workspace_ic = workspace_ic_id * workspace_ic_block; + size_t batch_id = ncb_index.ndrange_id[0]; + size_t group_id = ncb_index.ndrange_id[1]; + size_t group_pack_size = 1; + + int nr_pad_h = PH * IW2 * pack_ic * expend_element; + int nr_pad_w = PW * pack_ic * expend_element; + int over_pad = std::max(0_z, IW2 - IW - 2 * PW) * pack_ic * expend_element; + //! 
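+ //! note: each group of pack_ic = 4 input channels is expanded by
+ //! expend_element = 4 when packed, so one input pixel occupies 16 bytes of
+ //! sptr_base, and nr_pad_h / nr_pad_w / over_pad above are byte counts of
+ //! that expanded layout. E.g. assuming FW = 3, PW = 1 and OW = IW = 28:
+ //! OW2 = 32, IW2 = 32 + 3 - 1 = 34, so over_pad = (34 - 28 - 2) * 4 * 4 = 64
+ //! extra zero bytes per row, which exist only because OW is rounded up to a
+ //! multiple of 8.
+ //!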
copy to sptr_base to eliminate padding effect + const int8_t* sptr = static_cast(kern_param.src( + batch_id, group_id, workspace_ic_id, group_pack_size, pack_ic)); + int8_t* sptr_base = static_cast(bundle.get(0)) + + (workspace_batch_id * GROUP * padding_group_size + + workspace_group_id * padding_group_size + + workspace_ic * IH2 * IW2) * + expend_element; + size_t nr_ic = workspace_ic_block; + if (GROUP > 1) { + nr_ic = IC; + } + rep_step(ic_idx, nr_ic, pack_ic) { + std::memset(sptr_base, 0, nr_pad_h * sizeof(int8_t)); + sptr_base += nr_pad_h; + rep(ih_idx, IH) { + std::memset(sptr_base, 0, nr_pad_w * sizeof(int8_t)); + sptr_base += nr_pad_w; + conv_bias::nchw44_pack_src(sptr, sptr_base, IW); + sptr_base += IW * pack_ic * expend_element; + sptr += IW * pack_ic; + std::memset(sptr_base, 0, (nr_pad_w + over_pad) * sizeof(int8_t)); + sptr_base += nr_pad_w + over_pad; + } + std::memset(sptr_base, 0, nr_pad_h * sizeof(int8_t)); + sptr_base += nr_pad_h; + } +} + +template +static void do_conv_kern(WorkspaceBundle bundle, + const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids, + const CpuNDRange& ncb_range) { + size_t OH = kern_param.osz[0]; + size_t OW = kern_param.osz[1]; + size_t FH = kern_param.filter_meta.spatial[0]; + size_t FW = kern_param.filter_meta.spatial[1]; + size_t IC = kern_param.filter_meta.icpg; + size_t OC = kern_param.filter_meta.ocpg; + size_t GROUP = kern_param.filter_meta.group; + size_t IH2, IW2, OH2, OW2; + get_rectified_size(kern_param, IH2, IW2, OH2, OW2); + bool need_post_process = + kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8; + //! if dst_type is qint32, the op is not used, just fill with (1.0f,4.0f) + Op op = Op(1.0f, 4.0f); + if (need_post_process) { + float scale_bias = + kern_param.bias_type.param().scale; + float scale_dst = kern_param.dst_type.param().scale; + op = Op(scale_bias, scale_dst); + } + size_t padding_group_size = IH2 * IW2 * IC; + bundle.set(kern_param.workspace_ptr); + + constexpr size_t pack_c = 4; + constexpr size_t src_expand_size = 4; + const size_t workspace_batch_id = workspace_ids[0]; + const size_t workspace_group_id = workspace_ids[1]; + const size_t batch_id = ncb_index.ndrange_id[0]; + const size_t group_id = ncb_index.ndrange_id[1]; + const size_t oc_id = ncb_index.ndrange_id[2]; + const size_t oc_block_num = ncb_range[2]; + size_t nr_pack_per_step = div_ceil(div_ceil(OC, pack_c), oc_block_num); + size_t oc_block = nr_pack_per_step * pack_c; + const size_t oc_idx = oc_id * oc_block; + if (oc_id == (oc_block_num - 1)) { + oc_block = OC - oc_id * nr_pack_per_step * pack_c; + } + megdnn_assert(oc_block % pack_c == 0, + "oc must be devisible by 4, but oc = %zu", oc_block); + const int8_t* sptr = + static_cast(bundle.get(0)) + + workspace_batch_id * GROUP * padding_group_size * src_expand_size + + workspace_group_id * padding_group_size * src_expand_size; + + const int8_t* fptr = + kern_param.filter(group_id) + oc_idx * FH * FW * IC; + void* dst = reinterpret_cast( + reinterpret_cast( + kern_param.dst(batch_id, group_id)) + + oc_idx * OH * OW); + const int32_t* bptr = + kern_param.bias(batch_id, group_id) + oc_idx; + auto packed_weight = reinterpret_cast(bundle.get(1)) + + group_id * OC * IC * FH * FW + oc_idx * IC * FH * FW; + conv_bias::nchw44_pack_filter(fptr, packed_weight, + oc_block / 4 * IC / 4 * FH * FW); + +#define KERN1_NCHW44_CONV(filter) \ + conv_bias::conv_direct_stride1_##filter##x##filter##_int8_nchw44< \ + bias_mode, Op, ow_remain>(sptr, 
packed_weight, bptr, nullptr, \ + static_cast(dst), oc_block, IC, \ + IH2, IW2, OH, OW, op) + DISPATCH_FILTER(filter, KERN1_NCHW44_CONV) +#undef KERN1_NCHW44_CONV +} + +/* ===================== stride1 algo ===================== */ +bool ConvBiasImpl::AlgoS8DirectStride1NCHW44::usable( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const { + MEGDNN_MARK_USED_VAR(algo_selection_strategy); + auto&& fm = param.filter_meta; + auto FH = fm.spatial[0]; + auto OC = fm.ocpg; + auto IC = fm.icpg; + bool avaible = //! src and filter are qint8, dst is qint8 or qint32 + ((param.src_type.enumv() == DTypeEnum::QuantizedS8 && + param.filter_type.enumv() == DTypeEnum::QuantizedS8 && + (param.dst_type.enumv() == DTypeEnum::QuantizedS8 || + param.dst_type.enumv() == DTypeEnum::QuantizedS32))) && + (fm.format == param::Convolution::Format::NCHW44) && + (OC % 4 == 0 && IC % 4 == 0 && OC >= 4) && !fm.should_flip && + fm.spatial_ndim == 2 && fm.dilation[0] == 1 && + fm.dilation[1] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1 && + FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5 || FH == 7) && + param.bias_mode != BiasMode::BIAS; + return avaible; +} + +bool ConvBiasImpl::AlgoS8DirectStride1NCHW44::is_preferred( + megdnn::fallback::ConvBiasImpl* conv_bias_impl_ptr, + const NCBKernSizeParam& param) const { + // TODO: benchmark and fix + MEGDNN_MARK_USED_VAR(conv_bias_impl_ptr); + MEGDNN_MARK_USED_VAR(param); + return false; +} + +size_t ConvBiasImpl::AlgoS8DirectStride1NCHW44::get_workspace( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + return get_bundle(param).total_size_in_bytes(); +} + +SmallVector +ConvBiasImpl::AlgoS8DirectStride1NCHW44::dispatch_kerns( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + auto fm = param.filter_meta; + size_t N = param.n; + size_t IC = fm.icpg; + size_t OC = fm.ocpg; + size_t OW = param.osz[1]; + size_t group = fm.group; + size_t fh = fm.spatial[0]; + size_t fw = fm.spatial[1]; + WorkspaceBundle wbundle = get_bundle(param); + conv_fun do_conv_fun = nullptr; + int ow_remain = OW % 8; +// NOTE: remain_w is not used to gen hash of midout for compatible with changing +// shape runtime +#define DO_CONV_KERN_FUN(filter, bias_mode, remain_w, op) \ + MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw44_stride1, \ + midout_iv(#filter #bias_mode #op##_hash)) { \ + do_conv_fun = do_conv_kern; \ + } \ + MIDOUT_END(); + +#define GET_OP_PARAM(filter, bias_mode, remain_w) \ + switch (param.nonlineMode) { \ + case param::ConvBias::NonlineMode::IDENTITY: \ + DO_CONV_KERN_FUN(filter, bias_mode, remain_w, \ + TypeCvtOp) \ + break; \ + case param::ConvBias::NonlineMode::RELU: \ + DO_CONV_KERN_FUN(filter, bias_mode, remain_w, \ + ReluOp) \ + break; \ + case param::ConvBias::NonlineMode::H_SWISH: \ + DO_CONV_KERN_FUN(filter, bias_mode, remain_w, \ + HSwishOp) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } + +#define GET_REMAIN_W_PARAM(filter, bias_mode) \ + switch (ow_remain) { \ + case 0: \ + GET_OP_PARAM(filter, bias_mode, 0); \ + break; \ + case 1: \ + GET_OP_PARAM(filter, bias_mode, 1); \ + break; \ + case 2: \ + GET_OP_PARAM(filter, bias_mode, 2); \ + break; \ + case 3: \ + GET_OP_PARAM(filter, bias_mode, 3); \ + break; \ + case 4: \ + GET_OP_PARAM(filter, bias_mode, 4); \ + break; \ + case 5: \ + GET_OP_PARAM(filter, bias_mode, 5); \ + break; \ + case 6: \ + GET_OP_PARAM(filter, bias_mode, 6); \ + break; \ + case 7: \ + GET_OP_PARAM(filter, bias_mode, 7); \ + break; \ + 
default: \ + megdnn_assert(0); \ + } + +#define GET_BIAS_MODE_PARAM(filter) \ + switch (param.bias_mode) { \ + case BiasMode::NO_BIAS: \ + GET_REMAIN_W_PARAM(filter, BiasMode::NO_BIAS) \ + break; \ + case BiasMode::BROADCAST_CHANNEL_BIAS: \ + GET_REMAIN_W_PARAM(filter, BiasMode::BROADCAST_CHANNEL_BIAS) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } + +#define DISPATCH_CONV_KERN() \ + switch (param.filter_meta.spatial[0]) { \ + case 2: \ + GET_BIAS_MODE_PARAM(2) \ + break; \ + case 3: \ + GET_BIAS_MODE_PARAM(3) \ + break; \ + case 5: \ + GET_BIAS_MODE_PARAM(5) \ + break; \ + case 7: \ + GET_BIAS_MODE_PARAM(7) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } + + DISPATCH_CONV_KERN(); + +#undef DO_CONV_KERN_FUN +#undef GET_REMAIN_W_PARAM +#undef GET_OP_PARAM +#undef GET_BIAS_MODE_PARAM +#undef DISPATCH_CONV_KERN + + megdnn_assert(do_conv_fun); + + SmallVector ret_kerns; + WorkspaceBundle bundle = wbundle; + + constexpr size_t pack_oc = 4; + size_t oc_step = pack_oc; + if (fh == 2 && fw == 2 && OC >= 8) { + oc_step = 8; + } + + if (group == 1) { + CpuNDRange ncb_range = {N, group, div_ceil(OC, oc_step)}; + auto copy_padding = [bundle](const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + copy_padding_kern(bundle, kern_param, ncb_index, + ncb_index.ndrange_id); + }; + constexpr size_t pack_ic = 4; + ret_kerns.push_back({copy_padding, {N, group, div_ceil(IC, pack_ic)}}); + auto do_conv = [bundle, do_conv_fun, ncb_range]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + do_conv_fun(bundle, kern_param, ncb_index, ncb_index.ndrange_id, + ncb_range); + }; + ret_kerns.push_back({do_conv, ncb_range}); + } else { + CpuNDRange ncb_range = {N, group, 1}; + auto do_conv = [bundle, do_conv_fun, ncb_range]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + copy_padding_kern(bundle, kern_param, ncb_index, + {0, ncb_index.thread_id, 0}); + do_conv_fun(bundle, kern_param, ncb_index, + {0, ncb_index.thread_id, 0}, ncb_range); + }; + ret_kerns.push_back({do_conv, ncb_range}); + } + + return ret_kerns; +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8/direct_stride1_nchw44_kern.cpp b/dnn/src/arm_common/conv_bias/int8/direct_stride1_nchw44_kern.cpp new file mode 100644 index 00000000..ca5c45a6 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/direct_stride1_nchw44_kern.cpp @@ -0,0 +1,791 @@ +/** + * \file dnn/src/arm_common/conv_bias/int8/direct_stride1_nchw44_kern.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "src/arm_common/conv_bias/int8/direct.h" +#include "src/arm_common/conv_bias/intrinsic_helper.h" +#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/utils.h" +#include "src/fallback/conv_bias/common.h" + +using namespace megdnn; +using namespace arm_common; +namespace { + +/** +dot like impl. 
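+(vdotq_s32_h is provided by src/arm_common/simd_macro/marm_neon.h; as a hedged
+sketch only, one way to obtain the 4-products-per-int32-lane accumulation the
+kernels below rely on, without the armv8.2 sdot instruction, is:
+
+    static inline int32x4_t vdotq_s32_h_sketch(int8x16_t a, int8x16_t b,
+                                               int32x4_t c, int16x8_t& temp) {
+        // widening products of the low halves: temp[i] = a[i] * b[i]
+        temp = vmull_s8(vget_low_s8(a), vget_low_s8(b));
+        // add the matching high-half products: temp[i] += a[i+8] * b[i+8]
+        temp = vmlal_s8(temp, vget_high_s8(a), vget_high_s8(b));
+        // pairwise-accumulate into int32: lane j of c gains
+        // a[2j]*b[2j] + a[2j+1]*b[2j+1] + a[2j+8]*b[2j+8] + a[2j+9]*b[2j+9]
+        return vpadalq_s16(c, temp);
+    }
+
+with the weight/src reordering described next, those four indices per lane are
+exactly the four input channels of one output channel.)
+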
dot 4 ic to 1 oc, accumale to c +example: (format like weight) +packed weight +low 64 bit <0, 0> <0, 1> <1, 2> <1, 3> | <2, 0> <2, 1> <3, 2> <3, 3> +--------------------------------------------------------------------- +high 64 bit <0, 3> <0, 2> <1, 1> <1, 0> | <2, 3> <2, 2> <3, 1> <3, 0> +dot: (<0, 0> + <0, 3>) + (<0, 1> + <0, 2>) -> <0> +**/ +// TODO: can try oh = 2 impl, oc = 8 impl +template +static void ker_neon_dirctconv_3x3s1_oc4_ow8(const int8_t* src_ptr, + const int8_t* weight_ptr, + const int32_t* bias_ptr, + int8_t* dst_ptr, int ic, int ih, + int iw, const Op& op) { + constexpr int fh = filter_size; + constexpr int fw = filter_size; + constexpr int ic_step = 4; + constexpr int loop_ic_step = 4; + constexpr int ld_weight_ic4 = 16; + constexpr int pack_iw_len = 4; + + const int ic_stride = ih * iw * pack_iw_len; + + int32x4_t c[2 * 4]; + int8x16_t weight[3]; + int8x16_t src[8 + 2]; + int16x8_t temp_c[2]; + init_oc4_ow8(c, bias_ptr); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { + const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + + fh_idx * iw * ic_step * pack_iw_len; + + src[0] = vld1q_s8(src_ic_0_3); + src[1] = vld1q_s8((src_ic_0_3 + 16)); + src[2] = vld1q_s8((src_ic_0_3 + 2 * 16)); + src[3] = vld1q_s8((src_ic_0_3 + 3 * 16)); + src[4] = vld1q_s8((src_ic_0_3 + 4 * 16)); + src[5] = vld1q_s8((src_ic_0_3 + 5 * 16)); + src[6] = vld1q_s8((src_ic_0_3 + 6 * 16)); + src[7] = vld1q_s8((src_ic_0_3 + 7 * 16)); + src[8] = vld1q_s8((src_ic_0_3 + 8 * 16)); + src[9] = vld1q_s8((src_ic_0_3 + 9 * 16)); + + // oc == 0 + const int8_t* read_weight_ptr = + weight_ptr + fh_idx * fw * ld_weight_ic4; + + weight[0] = vld1q_s8(read_weight_ptr); + weight[1] = vld1q_s8(read_weight_ptr + 16); + weight[2] = vld1q_s8(read_weight_ptr + 2 * 16); + + c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]); + c[1] = vdotq_s32_h(weight[0], src[1], c[1], temp_c[1]); + c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]); + c[1] = vdotq_s32_h(weight[1], src[2], c[1], temp_c[1]); + c[0] = vdotq_s32_h(weight[2], src[2], c[0], temp_c[0]); + c[1] = vdotq_s32_h(weight[2], src[3], c[1], temp_c[1]); + + c[2] = vdotq_s32_h(weight[0], src[2], c[2], temp_c[0]); + c[3] = vdotq_s32_h(weight[0], src[3], c[3], temp_c[1]); + c[2] = vdotq_s32_h(weight[1], src[3], c[2], temp_c[0]); + c[3] = vdotq_s32_h(weight[1], src[4], c[3], temp_c[1]); + c[2] = vdotq_s32_h(weight[2], src[4], c[2], temp_c[0]); + c[3] = vdotq_s32_h(weight[2], src[5], c[3], temp_c[1]); + + c[4] = vdotq_s32_h(weight[0], src[4], c[4], temp_c[0]); + c[5] = vdotq_s32_h(weight[0], src[5], c[5], temp_c[1]); + c[4] = vdotq_s32_h(weight[1], src[5], c[4], temp_c[0]); + c[5] = vdotq_s32_h(weight[1], src[6], c[5], temp_c[1]); + c[4] = vdotq_s32_h(weight[2], src[6], c[4], temp_c[0]); + c[5] = vdotq_s32_h(weight[2], src[7], c[5], temp_c[1]); + + c[6] = vdotq_s32_h(weight[0], src[6], c[6], temp_c[0]); + c[7] = vdotq_s32_h(weight[0], src[7], c[7], temp_c[1]); + c[6] = vdotq_s32_h(weight[1], src[7], c[6], temp_c[0]); + c[7] = vdotq_s32_h(weight[1], src[8], c[7], temp_c[1]); + c[6] = vdotq_s32_h(weight[2], src[8], c[6], temp_c[0]); + c[7] = vdotq_s32_h(weight[2], src[9], c[7], temp_c[1]); + } + weight_ptr += fh * fw * ld_weight_ic4; + } + + store_oc4_ow8_remain_static(c, op, dst_ptr); +} + +template +static void ker_neon_dirctconv_2x2s1_oc8_ow8(const int8_t* src_ptr, + const int8_t* weight_ptr, + const int32_t* bias_ptr, + int8_t* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc, + const Op& op) { + constexpr 
int fh = filter_size; + constexpr int fw = filter_size; + constexpr int ic_step = 4; + constexpr int oc_step = 4; + constexpr int loop_ic_step = 4; + constexpr int ld_weight_ic4 = 16; + constexpr int pack_iw_len = 4; + + const int ic_stride = ih * iw * pack_iw_len; + const int ld_weight_oc4 = oc_step * fh * fw * ic; + + int32x4_t c[2][8]; + int8x16_t weight[2][2]; + int8x16_t src[8 + 1]; + int16x8_t temp_c[4]; + + init_oc8_ow8(c, bias_ptr, oc_step); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { + const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + + fh_idx * iw * ic_step * pack_iw_len; + + src[0] = vld1q_s8(src_ic_0_3); + src[1] = vld1q_s8((src_ic_0_3 + 16)); + src[2] = vld1q_s8((src_ic_0_3 + 2 * 16)); + src[3] = vld1q_s8((src_ic_0_3 + 3 * 16)); + src[4] = vld1q_s8((src_ic_0_3 + 4 * 16)); + src[5] = vld1q_s8((src_ic_0_3 + 5 * 16)); + src[6] = vld1q_s8((src_ic_0_3 + 6 * 16)); + src[7] = vld1q_s8((src_ic_0_3 + 7 * 16)); + src[8] = vld1q_s8((src_ic_0_3 + 8 * 16)); + + // oc == 0 + const int8_t* read_weight_ptr = + weight_ptr + fh_idx * fw * ld_weight_ic4; + + weight[0][0] = vld1q_s8(read_weight_ptr); + weight[0][1] = vld1q_s8(read_weight_ptr + 16); + weight[1][0] = vld1q_s8(read_weight_ptr + ld_weight_oc4); + weight[1][1] = vld1q_s8(read_weight_ptr + ld_weight_oc4 + 16); + + c[0][0] = vdotq_s32_h(weight[0][0], src[0], c[0][0], temp_c[0]); + c[1][0] = vdotq_s32_h(weight[1][0], src[0], c[1][0], temp_c[1]); + c[0][1] = vdotq_s32_h(weight[0][0], src[1], c[0][1], temp_c[2]); + c[1][1] = vdotq_s32_h(weight[1][0], src[1], c[1][1], temp_c[3]); + c[0][0] = vdotq_s32_h(weight[0][1], src[1], c[0][0], temp_c[0]); + c[1][0] = vdotq_s32_h(weight[1][1], src[1], c[1][0], temp_c[1]); + c[0][1] = vdotq_s32_h(weight[0][1], src[2], c[0][1], temp_c[2]); + c[1][1] = vdotq_s32_h(weight[1][1], src[2], c[1][1], temp_c[3]); + + c[0][2] = vdotq_s32_h(weight[0][0], src[2], c[0][2], temp_c[0]); + c[1][2] = vdotq_s32_h(weight[1][0], src[2], c[1][2], temp_c[1]); + c[0][3] = vdotq_s32_h(weight[0][0], src[3], c[0][3], temp_c[2]); + c[1][3] = vdotq_s32_h(weight[1][0], src[3], c[1][3], temp_c[3]); + c[0][2] = vdotq_s32_h(weight[0][1], src[3], c[0][2], temp_c[0]); + c[1][2] = vdotq_s32_h(weight[1][1], src[3], c[1][2], temp_c[1]); + c[0][3] = vdotq_s32_h(weight[0][1], src[4], c[0][3], temp_c[2]); + c[1][3] = vdotq_s32_h(weight[1][1], src[4], c[1][3], temp_c[3]); + + c[0][4] = vdotq_s32_h(weight[0][0], src[4], c[0][4], temp_c[0]); + c[1][4] = vdotq_s32_h(weight[1][0], src[4], c[1][4], temp_c[1]); + c[0][5] = vdotq_s32_h(weight[0][0], src[5], c[0][5], temp_c[2]); + c[1][5] = vdotq_s32_h(weight[1][0], src[5], c[1][5], temp_c[3]); + c[0][4] = vdotq_s32_h(weight[0][1], src[5], c[0][4], temp_c[0]); + c[1][4] = vdotq_s32_h(weight[1][1], src[5], c[1][4], temp_c[1]); + c[0][5] = vdotq_s32_h(weight[0][1], src[6], c[0][5], temp_c[2]); + c[1][5] = vdotq_s32_h(weight[1][1], src[6], c[1][5], temp_c[3]); + + c[0][6] = vdotq_s32_h(weight[0][0], src[6], c[0][6], temp_c[0]); + c[1][6] = vdotq_s32_h(weight[1][0], src[6], c[1][6], temp_c[1]); + c[0][7] = vdotq_s32_h(weight[0][0], src[7], c[0][7], temp_c[2]); + c[1][7] = vdotq_s32_h(weight[1][0], src[7], c[1][7], temp_c[3]); + c[0][6] = vdotq_s32_h(weight[0][1], src[7], c[0][6], temp_c[0]); + c[1][6] = vdotq_s32_h(weight[1][1], src[7], c[1][6], temp_c[1]); + c[0][7] = vdotq_s32_h(weight[0][1], src[8], c[0][7], temp_c[2]); + c[1][7] = vdotq_s32_h(weight[1][1], src[8], c[1][7], temp_c[3]); + } + weight_ptr += fh * fw * 
ld_weight_ic4; + } + store_oc8_ow8_remain_static(c, op, dst_ptr, ld_dst_oc); +} + +template +static void ker_neon_dirctconv_2x2s1_oc4_ow8(const int8_t* src_ptr, + const int8_t* weight_ptr, + const int32_t* bias_ptr, + int8_t* dst_ptr, int ic, int ih, + int iw, const Op& op) { + constexpr int fh = filter_size; + constexpr int fw = filter_size; + constexpr int ic_step = 4; + constexpr int loop_ic_step = 4; + constexpr int ld_weight_ic4 = 16; + constexpr int pack_iw_len = 4; + + const int ic_stride = ih * iw * pack_iw_len; + + int32x4_t c[2 * 4]; + int8x16_t weight[2]; + int8x16_t src[8 + 1]; + int16x8_t temp_c[2]; + init_oc4_ow8(c, bias_ptr); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { + const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + + fh_idx * iw * ic_step * pack_iw_len; + + src[0] = vld1q_s8(src_ic_0_3); + src[1] = vld1q_s8((src_ic_0_3 + 16)); + src[2] = vld1q_s8((src_ic_0_3 + 2 * 16)); + src[3] = vld1q_s8((src_ic_0_3 + 3 * 16)); + src[4] = vld1q_s8((src_ic_0_3 + 4 * 16)); + src[5] = vld1q_s8((src_ic_0_3 + 5 * 16)); + src[6] = vld1q_s8((src_ic_0_3 + 6 * 16)); + src[7] = vld1q_s8((src_ic_0_3 + 7 * 16)); + src[8] = vld1q_s8((src_ic_0_3 + 8 * 16)); + + // oc == 0 + const int8_t* read_weight_ptr = + weight_ptr + fh_idx * fw * ld_weight_ic4; + + weight[0] = vld1q_s8(read_weight_ptr); + weight[1] = vld1q_s8(read_weight_ptr + 16); + + c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]); + c[1] = vdotq_s32_h(weight[0], src[1], c[1], temp_c[1]); + c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]); + c[1] = vdotq_s32_h(weight[1], src[2], c[1], temp_c[1]); + + c[2] = vdotq_s32_h(weight[0], src[2], c[2], temp_c[0]); + c[3] = vdotq_s32_h(weight[0], src[3], c[3], temp_c[1]); + c[2] = vdotq_s32_h(weight[1], src[3], c[2], temp_c[0]); + c[3] = vdotq_s32_h(weight[1], src[4], c[3], temp_c[1]); + + c[4] = vdotq_s32_h(weight[0], src[4], c[4], temp_c[0]); + c[5] = vdotq_s32_h(weight[0], src[5], c[5], temp_c[1]); + c[4] = vdotq_s32_h(weight[1], src[5], c[4], temp_c[0]); + c[5] = vdotq_s32_h(weight[1], src[6], c[5], temp_c[1]); + + c[6] = vdotq_s32_h(weight[0], src[6], c[6], temp_c[0]); + c[7] = vdotq_s32_h(weight[0], src[7], c[7], temp_c[1]); + c[6] = vdotq_s32_h(weight[1], src[7], c[6], temp_c[0]); + c[7] = vdotq_s32_h(weight[1], src[8], c[7], temp_c[1]); + } + weight_ptr += fh * fw * ld_weight_ic4; + } + + store_oc4_ow8_remain_static(c, op, dst_ptr); +} + +template +static void ker_neon_dirctconv_5x5s1_oc4_ow8(const int8_t* src_ptr, + const int8_t* weight_ptr, + const int32_t* bias_ptr, + int8_t* dst_ptr, int ic, int ih, + int iw, const Op& op) { + constexpr int fh = filter_size; + constexpr int fw = filter_size; + constexpr int ic_step = 4; + constexpr int loop_ic_step = 4; + constexpr int ld_weight_ic4 = 16; + constexpr int pack_iw_len = 4; + + const int ic_stride = ih * iw * pack_iw_len; + + int32x4_t c[2 * 4]; + int8x16_t weight[5]; + int8x16_t src[8 + 2]; + int16x8_t temp_c[2]; + init_oc4_ow8(c, bias_ptr); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { + const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + + fh_idx * iw * ic_step * pack_iw_len; + + src[0] = vld1q_s8(src_ic_0_3); + src[1] = vld1q_s8((src_ic_0_3 + 16)); + src[2] = vld1q_s8((src_ic_0_3 + 2 * 16)); + src[3] = vld1q_s8((src_ic_0_3 + 3 * 16)); + src[4] = vld1q_s8((src_ic_0_3 + 4 * 16)); + src[5] = vld1q_s8((src_ic_0_3 + 5 * 16)); + src[6] = vld1q_s8((src_ic_0_3 + 6 * 16)); + src[7] = 
vld1q_s8((src_ic_0_3 + 7 * 16)); + src[8] = vld1q_s8((src_ic_0_3 + 8 * 16)); + src[9] = vld1q_s8((src_ic_0_3 + 9 * 16)); + + // oc == 0 + const int8_t* read_weight_ptr = + weight_ptr + fh_idx * fw * ld_weight_ic4; + + weight[0] = vld1q_s8(read_weight_ptr); + weight[1] = vld1q_s8(read_weight_ptr + 16); + weight[2] = vld1q_s8(read_weight_ptr + 2 * 16); + weight[3] = vld1q_s8(read_weight_ptr + 3 * 16); + weight[4] = vld1q_s8(read_weight_ptr + 4 * 16); + + c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]); + c[1] = vdotq_s32_h(weight[0], src[1], c[1], temp_c[1]); + c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]); + c[1] = vdotq_s32_h(weight[1], src[2], c[1], temp_c[1]); + c[0] = vdotq_s32_h(weight[2], src[2], c[0], temp_c[0]); + c[1] = vdotq_s32_h(weight[2], src[3], c[1], temp_c[1]); + c[0] = vdotq_s32_h(weight[3], src[3], c[0], temp_c[0]); + c[1] = vdotq_s32_h(weight[3], src[4], c[1], temp_c[1]); + c[0] = vdotq_s32_h(weight[4], src[4], c[0], temp_c[0]); + c[1] = vdotq_s32_h(weight[4], src[5], c[1], temp_c[1]); + + c[2] = vdotq_s32_h(weight[0], src[2], c[2], temp_c[0]); + c[3] = vdotq_s32_h(weight[0], src[3], c[3], temp_c[1]); + c[2] = vdotq_s32_h(weight[1], src[3], c[2], temp_c[0]); + c[3] = vdotq_s32_h(weight[1], src[4], c[3], temp_c[1]); + c[2] = vdotq_s32_h(weight[2], src[4], c[2], temp_c[0]); + c[3] = vdotq_s32_h(weight[2], src[5], c[3], temp_c[1]); + c[2] = vdotq_s32_h(weight[3], src[5], c[2], temp_c[0]); + c[3] = vdotq_s32_h(weight[3], src[6], c[3], temp_c[1]); + c[2] = vdotq_s32_h(weight[4], src[6], c[2], temp_c[0]); + c[3] = vdotq_s32_h(weight[4], src[7], c[3], temp_c[1]); + + c[4] = vdotq_s32_h(weight[0], src[4], c[4], temp_c[0]); + c[5] = vdotq_s32_h(weight[0], src[5], c[5], temp_c[1]); + c[4] = vdotq_s32_h(weight[1], src[5], c[4], temp_c[0]); + c[5] = vdotq_s32_h(weight[1], src[6], c[5], temp_c[1]); + c[4] = vdotq_s32_h(weight[2], src[6], c[4], temp_c[0]); + c[5] = vdotq_s32_h(weight[2], src[7], c[5], temp_c[1]); + c[4] = vdotq_s32_h(weight[3], src[7], c[4], temp_c[0]); + c[5] = vdotq_s32_h(weight[3], src[8], c[5], temp_c[1]); + c[4] = vdotq_s32_h(weight[4], src[8], c[4], temp_c[0]); + c[5] = vdotq_s32_h(weight[4], src[9], c[5], temp_c[1]); + + src[0] = vld1q_s8(src_ic_0_3 + 10 * 16); + src[1] = vld1q_s8((src_ic_0_3 + 11 * 16)); + + c[6] = vdotq_s32_h(weight[0], src[6], c[6], temp_c[0]); + c[7] = vdotq_s32_h(weight[0], src[7], c[7], temp_c[1]); + c[6] = vdotq_s32_h(weight[1], src[7], c[6], temp_c[0]); + c[7] = vdotq_s32_h(weight[1], src[8], c[7], temp_c[1]); + c[6] = vdotq_s32_h(weight[2], src[8], c[6], temp_c[0]); + c[7] = vdotq_s32_h(weight[2], src[9], c[7], temp_c[1]); + c[6] = vdotq_s32_h(weight[3], src[9], c[6], temp_c[0]); + c[7] = vdotq_s32_h(weight[3], src[0], c[7], temp_c[1]); + c[6] = vdotq_s32_h(weight[4], src[0], c[6], temp_c[0]); + c[7] = vdotq_s32_h(weight[4], src[1], c[7], temp_c[1]); + } + weight_ptr += fh * fw * ld_weight_ic4; + } + + store_oc4_ow8_remain_static(c, op, dst_ptr); +} + +template +static void ker_neon_dirctconv_7x7s1_oc4_ow8(const int8_t* src_ptr, + const int8_t* weight_ptr, + const int32_t* bias_ptr, + int8_t* dst_ptr, int ic, int ih, + int iw, const Op& op) { + constexpr int fh = filter_size; + constexpr int fw = filter_size; + constexpr int ic_step = 4; + constexpr int loop_ic_step = 4; + constexpr int ld_weight_ic4 = 16; + constexpr int pack_iw_len = 4; + + const int ic_stride = ih * iw * pack_iw_len; + + int32x4_t c[2 * 4]; + int8x16_t weight[7]; + int8x16_t src[8 + 2]; + int16x8_t temp_c[2]; + init_oc4_ow8(c, bias_ptr); + for (int 
ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { + const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + + fh_idx * iw * ic_step * pack_iw_len; + + src[0] = vld1q_s8(src_ic_0_3); + src[1] = vld1q_s8((src_ic_0_3 + 16)); + src[2] = vld1q_s8((src_ic_0_3 + 2 * 16)); + src[3] = vld1q_s8((src_ic_0_3 + 3 * 16)); + src[4] = vld1q_s8((src_ic_0_3 + 4 * 16)); + src[5] = vld1q_s8((src_ic_0_3 + 5 * 16)); + src[6] = vld1q_s8((src_ic_0_3 + 6 * 16)); + src[7] = vld1q_s8((src_ic_0_3 + 7 * 16)); + src[8] = vld1q_s8((src_ic_0_3 + 8 * 16)); + src[9] = vld1q_s8((src_ic_0_3 + 9 * 16)); + + // oc == 0 + const int8_t* read_weight_ptr = + weight_ptr + fh_idx * fw * ld_weight_ic4; + + weight[0] = vld1q_s8(read_weight_ptr); + weight[1] = vld1q_s8(read_weight_ptr + 16); + weight[2] = vld1q_s8(read_weight_ptr + 2 * 16); + weight[3] = vld1q_s8(read_weight_ptr + 3 * 16); + weight[4] = vld1q_s8(read_weight_ptr + 4 * 16); + weight[5] = vld1q_s8(read_weight_ptr + 5 * 16); + weight[6] = vld1q_s8(read_weight_ptr + 6 * 16); + + c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]); + c[1] = vdotq_s32_h(weight[0], src[1], c[1], temp_c[1]); + c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]); + c[1] = vdotq_s32_h(weight[1], src[2], c[1], temp_c[1]); + c[0] = vdotq_s32_h(weight[2], src[2], c[0], temp_c[0]); + c[1] = vdotq_s32_h(weight[2], src[3], c[1], temp_c[1]); + c[0] = vdotq_s32_h(weight[3], src[3], c[0], temp_c[0]); + c[1] = vdotq_s32_h(weight[3], src[4], c[1], temp_c[1]); + c[0] = vdotq_s32_h(weight[4], src[4], c[0], temp_c[0]); + c[1] = vdotq_s32_h(weight[4], src[5], c[1], temp_c[1]); + c[0] = vdotq_s32_h(weight[5], src[5], c[0], temp_c[0]); + c[1] = vdotq_s32_h(weight[5], src[6], c[1], temp_c[1]); + c[0] = vdotq_s32_h(weight[6], src[6], c[0], temp_c[0]); + c[1] = vdotq_s32_h(weight[6], src[7], c[1], temp_c[1]); + + c[2] = vdotq_s32_h(weight[0], src[2], c[2], temp_c[0]); + c[3] = vdotq_s32_h(weight[0], src[3], c[3], temp_c[1]); + c[2] = vdotq_s32_h(weight[1], src[3], c[2], temp_c[0]); + c[3] = vdotq_s32_h(weight[1], src[4], c[3], temp_c[1]); + c[2] = vdotq_s32_h(weight[2], src[4], c[2], temp_c[0]); + c[3] = vdotq_s32_h(weight[2], src[5], c[3], temp_c[1]); + c[2] = vdotq_s32_h(weight[3], src[5], c[2], temp_c[0]); + c[3] = vdotq_s32_h(weight[3], src[6], c[3], temp_c[1]); + c[2] = vdotq_s32_h(weight[4], src[6], c[2], temp_c[0]); + c[3] = vdotq_s32_h(weight[4], src[7], c[3], temp_c[1]); + c[2] = vdotq_s32_h(weight[5], src[7], c[2], temp_c[0]); + c[3] = vdotq_s32_h(weight[5], src[8], c[3], temp_c[1]); + c[2] = vdotq_s32_h(weight[6], src[8], c[2], temp_c[0]); + c[3] = vdotq_s32_h(weight[6], src[9], c[3], temp_c[1]); + + src[0] = vld1q_s8(src_ic_0_3 + 10 * 16); + src[1] = vld1q_s8((src_ic_0_3 + 11 * 16)); + + c[4] = vdotq_s32_h(weight[0], src[4], c[4], temp_c[0]); + c[5] = vdotq_s32_h(weight[0], src[5], c[5], temp_c[1]); + c[4] = vdotq_s32_h(weight[1], src[5], c[4], temp_c[0]); + c[5] = vdotq_s32_h(weight[1], src[6], c[5], temp_c[1]); + c[4] = vdotq_s32_h(weight[2], src[6], c[4], temp_c[0]); + c[5] = vdotq_s32_h(weight[2], src[7], c[5], temp_c[1]); + c[4] = vdotq_s32_h(weight[3], src[7], c[4], temp_c[0]); + c[5] = vdotq_s32_h(weight[3], src[8], c[5], temp_c[1]); + c[4] = vdotq_s32_h(weight[4], src[8], c[4], temp_c[0]); + c[5] = vdotq_s32_h(weight[4], src[9], c[5], temp_c[1]); + c[4] = vdotq_s32_h(weight[5], src[9], c[4], temp_c[0]); + c[5] = vdotq_s32_h(weight[5], src[0], c[5], temp_c[1]); + c[4] = vdotq_s32_h(weight[6], src[0], c[4], temp_c[0]); + c[5] = 
vdotq_s32_h(weight[6], src[1], c[5], temp_c[1]); + + src[2] = vld1q_s8(src_ic_0_3 + 12 * 16); + src[3] = vld1q_s8((src_ic_0_3 + 13 * 16)); + + c[6] = vdotq_s32_h(weight[0], src[6], c[6], temp_c[0]); + c[7] = vdotq_s32_h(weight[0], src[7], c[7], temp_c[1]); + c[6] = vdotq_s32_h(weight[1], src[7], c[6], temp_c[0]); + c[7] = vdotq_s32_h(weight[1], src[8], c[7], temp_c[1]); + c[6] = vdotq_s32_h(weight[2], src[8], c[6], temp_c[0]); + c[7] = vdotq_s32_h(weight[2], src[9], c[7], temp_c[1]); + c[6] = vdotq_s32_h(weight[3], src[9], c[6], temp_c[0]); + c[7] = vdotq_s32_h(weight[3], src[0], c[7], temp_c[1]); + c[6] = vdotq_s32_h(weight[4], src[0], c[6], temp_c[0]); + c[7] = vdotq_s32_h(weight[4], src[1], c[7], temp_c[1]); + c[6] = vdotq_s32_h(weight[5], src[1], c[6], temp_c[0]); + c[7] = vdotq_s32_h(weight[5], src[2], c[7], temp_c[1]); + c[6] = vdotq_s32_h(weight[6], src[2], c[6], temp_c[0]); + c[7] = vdotq_s32_h(weight[6], src[3], c[7], temp_c[1]); + } + weight_ptr += fh * fw * ld_weight_ic4; + } + + store_oc4_ow8_remain_static(c, op, dst_ptr); +} + +} // namespace + +/** +origin weight shape +packed weight shape +example: (format like weight) +origin +<0, 0> <1, 0> <2, 0> <3, 0> +<0, 1> <1, 1> <2, 1> <3, 1> +<0, 2> <1, 2> <2, 2> <3, 2> +<0, 3> <1, 3> <2, 3> <3, 3> +packed +low 64 bit <0, 0> <0, 1> <1, 2> <1, 3> | <2, 0> <2, 1> <3, 2> <3, 3> +--------------------------------------------------------------------- +high 64 bit <0, 3> <0, 2> <1, 1> <1, 0> | <2, 3> <2, 2> <3, 1> <3, 0> +**/ +void conv_bias::nchw44_pack_filter(const int8_t* src, int8_t* dst, int length) { + static const uint8_t weight_idx_buffer[16] = {0, 4, 9, 13, 2, 6, 11, 15, + 12, 8, 5, 1, 14, 10, 7, 3}; + constexpr int simd_len = 16; + uint8x16_t weight_idx = vld1q_u8(weight_idx_buffer); + for (int i = 0; i < length; i++) { + int8x16_t result = vldq_tbl_s8(src + i * simd_len, weight_idx); + vst1q_s8(dst + i * simd_len, result); + } +} +/** +origin src shape +packed src shape +example: (format like ) +origin +<0> <0> <0> <0> +packed +low 64 bit <0> <1> <2> <3> | <0> <1> <2> <3> +--------------------------------------------------------------------- +high 64 bit <3> <2> <1> <0> | <3> <2> <1> <0> +**/ +void conv_bias::nchw44_pack_src(const int8_t* src, int8_t* dst, int length) { + static const uint8_t src_idx_buffer[16] = {0, 1, 2, 3, 0, 1, 2, 3, + 3, 2, 1, 0, 3, 2, 1, 0}; + constexpr int pack_ic = 4; + constexpr int simd_len = 16; + uint8x16_t src_idx = vld1q_u8(src_idx_buffer); + for (int i = 0; i < length; i++) { + int8x16_t result = vld_dup_tbl_s32(src + i * pack_ic, src_idx); + vst1q_s8(dst + i * simd_len, result); + } +} + +template +void conv_bias::conv_direct_stride1_2x2_int8_nchw44( + const int8_t* src, const int8_t* filter, const int32_t* bias, + int32_t* temp, int8_t* dst, const size_t oc, const size_t ic, + const size_t ih, const size_t iw, const size_t oh, const size_t ow, + const Op& op) { + MEGDNN_MARK_USED_VAR(temp); + constexpr size_t filter_size = 2; + constexpr size_t fh = filter_size; + constexpr size_t fw = filter_size; + constexpr size_t ic_step = 4; + constexpr size_t oc_step = 4; + constexpr size_t big_oc_step = 8; + constexpr size_t oh_step = 1; + constexpr size_t ow_step = 8; + constexpr int pack_iw_len = 4; + + const size_t img_stride = oh * ow; + const size_t ow_end = ow / ow_step * ow_step; + const size_t ow_remain = ow - ow_end; + const size_t oc_end = oc / big_oc_step * big_oc_step; + const size_t oc_remain = oc - oc_end; + const int ld_oc = oh * ow * ic_step; + for (size_t oc_idx = 0; oc_idx < oc_end; 
oc_idx += big_oc_step) { + const size_t weight_offset = oc_idx * ic * fh * fw; + for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { + for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const size_t src_offset = + (oh_idx * iw + ow_idx) * ic_step * pack_iw_len; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; + ker_neon_dirctconv_2x2s1_oc8_ow8( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ld_oc, op); + } + if (ow_remain > 0) { + const size_t src_offset = + (oh_idx * iw + ow_end) * ic_step * pack_iw_len; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; + ker_neon_dirctconv_2x2s1_oc8_ow8( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ld_oc, op); + } + } + } + if (oc_remain > 0) { + const size_t oc_idx = oc_end; + const size_t weight_offset = oc_idx * ic * fh * fw; + for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { + for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const size_t src_offset = + (oh_idx * iw + ow_idx) * ic_step * pack_iw_len; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; + ker_neon_dirctconv_2x2s1_oc4_ow8( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, op); + } + if (ow_remain > 0) { + const size_t src_offset = + (oh_idx * iw + ow_end) * ic_step * pack_iw_len; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; + ker_neon_dirctconv_2x2s1_oc4_ow8( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, op); + } + } + } +} +template +void conv_bias::conv_direct_stride1_3x3_int8_nchw44( + const int8_t* src, const int8_t* filter, const int32_t* bias, + int32_t* temp, int8_t* dst, const size_t oc, const size_t ic, + const size_t ih, const size_t iw, const size_t oh, const size_t ow, + const Op& op) { + MEGDNN_MARK_USED_VAR(temp); + constexpr size_t filter_size = 3; + constexpr size_t fh = filter_size; + constexpr size_t fw = filter_size; + constexpr size_t ic_step = 4; + constexpr size_t oc_step = 4; + constexpr size_t oh_step = 1; + constexpr size_t ow_step = 8; + constexpr int pack_iw_len = 4; + + const size_t img_stride = oh * ow; + const size_t ow_end = ow / ow_step * ow_step; + const size_t ow_remain = ow - ow_end; + for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) { + const size_t weight_offset = oc_idx * ic * fh * fw; + for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { + for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const size_t src_offset = + (oh_idx * iw + ow_idx) * ic_step * pack_iw_len; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; + ker_neon_dirctconv_3x3s1_oc4_ow8( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, op); + } + if (ow_remain > 0) { + const size_t src_offset = + (oh_idx * iw + ow_end) * ic_step * pack_iw_len; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; + ker_neon_dirctconv_3x3s1_oc4_ow8( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, op); + } + } + } +} +template +void conv_bias::conv_direct_stride1_5x5_int8_nchw44( + const int8_t* src, const int8_t* filter, const int32_t* bias, + int32_t* temp, int8_t* dst, const size_t oc, const size_t ic, + const size_t ih, const size_t iw, const size_t oh, const size_t 
ow, + const Op& op) { + MEGDNN_MARK_USED_VAR(temp); + constexpr size_t filter_size = 5; + constexpr size_t fh = filter_size; + constexpr size_t fw = filter_size; + constexpr size_t ic_step = 4; + constexpr size_t oc_step = 4; + constexpr size_t oh_step = 1; + constexpr size_t ow_step = 8; + constexpr int pack_iw_len = 4; + + const size_t img_stride = oh * ow; + const size_t ow_end = ow / ow_step * ow_step; + const size_t ow_remain = ow - ow_end; + for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) { + const size_t weight_offset = oc_idx * ic * fh * fw; + for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { + for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const size_t src_offset = + (oh_idx * iw + ow_idx) * ic_step * pack_iw_len; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; + ker_neon_dirctconv_5x5s1_oc4_ow8( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, op); + } + if (ow_remain > 0) { + const size_t src_offset = + (oh_idx * iw + ow_end) * ic_step * pack_iw_len; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; + ker_neon_dirctconv_5x5s1_oc4_ow8( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, op); + } + } + } +} + +template +void conv_bias::conv_direct_stride1_7x7_int8_nchw44( + const int8_t* src, const int8_t* filter, const int32_t* bias, + int32_t* temp, int8_t* dst, const size_t oc, const size_t ic, + const size_t ih, const size_t iw, const size_t oh, const size_t ow, + const Op& op) { + MEGDNN_MARK_USED_VAR(temp); + constexpr size_t filter_size = 7; + constexpr size_t fh = filter_size; + constexpr size_t fw = filter_size; + constexpr size_t ic_step = 4; + constexpr size_t oc_step = 4; + constexpr size_t oh_step = 1; + constexpr size_t ow_step = 8; + constexpr int pack_iw_len = 4; + + const size_t img_stride = oh * ow; + const size_t ow_end = ow / ow_step * ow_step; + const size_t ow_remain = ow - ow_end; + for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) { + const size_t weight_offset = oc_idx * ic * fh * fw; + for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { + for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const size_t src_offset = + (oh_idx * iw + ow_idx) * ic_step * pack_iw_len; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; + ker_neon_dirctconv_7x7s1_oc4_ow8( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, op); + } + if (ow_remain > 0) { + const size_t src_offset = + (oh_idx * iw + ow_end) * ic_step * pack_iw_len; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; + ker_neon_dirctconv_7x7s1_oc4_ow8( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, op); + } + } + } +} + +#define INSTANTIATION(stride, i, bias, remain_w, Op) \ + template void conv_bias::conv_direct_##stride##_##i##x##i##_int8_nchw44< \ + bias, Op, remain_w>(const int8_t*, const int8_t*, const int32_t*, \ + int32_t*, int8_t*, const size_t, const size_t, \ + const size_t, const size_t, const size_t, \ + const size_t, const Op&); + +#define FOR_OP(stride, i, bias, remain_w) \ + INSTANTIATION(stride, i, bias, remain_w, \ + TypeCvtOp) \ + INSTANTIATION(stride, i, bias, remain_w, \ + ReluOp) \ + INSTANTIATION(stride, i, bias, remain_w, \ + HSwishOp) + +#define FOR_REMAIN(stride, i, bias) \ + FOR_OP(stride, i, bias, 0) \ + FOR_OP(stride, i, bias, 
1) \ + FOR_OP(stride, i, bias, 2) \ + FOR_OP(stride, i, bias, 3) \ + FOR_OP(stride, i, bias, 4) \ + FOR_OP(stride, i, bias, 5) \ + FOR_OP(stride, i, bias, 6) \ + FOR_OP(stride, i, bias, 7) + +#define FOR_BIAS(stride, i) \ + FOR_REMAIN(stride, i, BiasMode::NO_BIAS) \ + FOR_REMAIN(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS) + +#define FOR_FILTER(stride) \ + FOR_BIAS(stride, 2) \ + FOR_BIAS(stride, 3) \ + FOR_BIAS(stride, 5) \ + FOR_BIAS(stride, 7) + +FOR_FILTER(stride1) + +#undef FOR_STRIDE +#undef FOR_FILTER +#undef FOR_IC +#undef FOR_BIAS +#undef FOR_NONLINEAR +#undef FOR_REMAIN +#undef INSTANTIATION diff --git a/dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw44_algo.cpp b/dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw44_algo.cpp new file mode 100644 index 00000000..9b47eefa --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw44_algo.cpp @@ -0,0 +1,404 @@ +/** + * \file dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw44_algo.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "megdnn/oprs.h" +#include "src/arm_common/conv_bias/int8/algos.h" +#include "src/arm_common/conv_bias/int8/direct.h" +#include "src/arm_common/conv_bias/int8/strategy.h" +#include "src/arm_common/elemwise_op.h" +#include "src/common/opr_delegate.h" + +#include "midout.h" + +using namespace megdnn; +using namespace arm_common; +using conv_fun = std::function; +MIDOUT_DECL(megdnn_arm_common_conv_bias_int8_nchw44_stride2) + +static void get_rectified_size( + const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, + size_t& IH2, size_t& IW2, size_t& OH2, size_t& OW2) { + auto&& fm = param.filter_meta; + size_t SW = fm.stride[1]; + size_t IH = param.isz[0]; + size_t IW = param.isz[1]; + size_t OH = param.osz[0]; + size_t OW = param.osz[1]; + size_t FH = fm.spatial[0]; + size_t FW = fm.spatial[1]; + + OH2 = OH; + OW2 = (OW + 7) & ~7; + IH2 = SW * OH + FH - SW; + IW2 = SW * OW2 + FW - SW; + // Because stride is 2, sometimes IW == IW2+1. Do a max update to + // handle this case. 
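+ // e.g. with FW = 2, PW = 0, SW = 2 and IW = 17: OW = (17 - 2) / 2 + 1 = 8 is
+ // already a multiple of 8, so OW2 = 8 and IW2 = 2 * 8 + 2 - 2 = 16, i.e.
+ // IW == IW2 + 1; without the max below the packed buffer would be one input
+ // column too narrow.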
+ IH2 = std::max(IH2, IH); + IW2 = std::max(IW2, IW); +} +static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { + constexpr size_t src_expand = 4; + auto&& fm = param.filter_meta; + size_t group = fm.group; + size_t batch = param.n; + size_t IC = fm.icpg; + size_t OC = fm.ocpg; + size_t FH = fm.spatial[0]; + size_t FW = fm.spatial[1]; + size_t IH2, IW2, OH2, OW2; + get_rectified_size(param, IH2, IW2, OH2, OW2); + if (group == 1) { + size_t src_size = + batch * group * IC * IH2 * IW2 * sizeof(int8_t) * src_expand; + size_t weight_size = group * OC * IC * FH * FW * sizeof(int8_t); + return {nullptr, {src_size, weight_size}}; + } else { + size_t src_size = + param.nr_threads * IC * IH2 * IW2 * sizeof(int8_t) * src_expand; + size_t weight_size = group * OC * IC * FH * FW * sizeof(int8_t); + return {nullptr, {src_size, weight_size}}; + } +}; + +static void copy_padding_kern(WorkspaceBundle bundle, + const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids) { + size_t IH = kern_param.isz[0]; + size_t IW = kern_param.isz[1]; + size_t IC = kern_param.filter_meta.icpg; + size_t PH = kern_param.filter_meta.padding[0]; + size_t PW = kern_param.filter_meta.padding[1]; + size_t GROUP = kern_param.filter_meta.group; + + size_t IH2, IW2, OH2, OW2; + get_rectified_size(kern_param, IH2, IW2, OH2, OW2); + size_t padding_group_size = IH2 * IW2 * IC; + bundle.set(kern_param.workspace_ptr); + //! Used for get the workspace offset + constexpr int pack_ic = 4; + constexpr int expend_element = 4; + // TODO: block dim is better to get from arg + size_t workspace_ic_block = 4; + size_t workspace_batch_id = workspace_ids[0]; + size_t workspace_group_id = workspace_ids[1]; + size_t workspace_ic_id = workspace_ids[2]; + size_t workspace_ic = workspace_ic_id * workspace_ic_block; + size_t batch_id = ncb_index.ndrange_id[0]; + size_t group_id = ncb_index.ndrange_id[1]; + size_t group_pack_size = 1; + + int nr_pad_h = PH * IW2 * pack_ic * expend_element; + int nr_pad_w = PW * pack_ic * expend_element; + int over_pad = std::max(0_z, IW2 - IW - 2 * PW) * pack_ic * expend_element; + int row_last_pad = ((int)IW2 - (int)IW - 2 * (int)PW) >= 0 + ? nr_pad_w + over_pad + : (IW2 - IW - PW) * pack_ic * expend_element; + int col_last_pad = + ((int)IH2 - (int)IH - 2 * (int)PH) >= 0 + ? nr_pad_h + : (IH2 - IH - PH) * IW2 * pack_ic * expend_element; + const int8_t* sptr = static_cast(kern_param.src( + batch_id, group_id, workspace_ic_id, group_pack_size, pack_ic)); + + //! 
copy to sptr_base to eliminate padding effect + int8_t* sptr_base = static_cast(bundle.get(0)) + + (workspace_batch_id * GROUP * padding_group_size + + workspace_group_id * padding_group_size + + workspace_ic * IH2 * IW2) * + expend_element; + size_t nr_ic = workspace_ic_block; + if (GROUP > 1) { + nr_ic = IC; + } + rep_step(ic_idx, nr_ic, pack_ic) { + std::memset(sptr_base, 0, nr_pad_h * sizeof(int8_t)); + sptr_base += nr_pad_h; + rep(ih_idx, IH) { + std::memset(sptr_base, 0, nr_pad_w * sizeof(int8_t)); + sptr_base += nr_pad_w; + conv_bias::nchw44_pack_src(sptr, sptr_base, IW); + sptr_base += IW * pack_ic * expend_element; + sptr += IW * pack_ic; + std::memset(sptr_base, 0, row_last_pad * sizeof(int8_t)); + sptr_base += row_last_pad; + } + std::memset(sptr_base, 0, col_last_pad * sizeof(int8_t)); + sptr_base += col_last_pad; + } +} + +template +static void do_conv_kern(WorkspaceBundle bundle, + const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids, + const CpuNDRange& ncb_range) { + size_t OH = kern_param.osz[0]; + size_t OW = kern_param.osz[1]; + size_t FH = kern_param.filter_meta.spatial[0]; + size_t FW = kern_param.filter_meta.spatial[1]; + size_t IC = kern_param.filter_meta.icpg; + size_t OC = kern_param.filter_meta.ocpg; + size_t GROUP = kern_param.filter_meta.group; + size_t IH2, IW2, OH2, OW2; + get_rectified_size(kern_param, IH2, IW2, OH2, OW2); + bool need_post_process = + kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8; + //! if dst_type is qint32, the op is not used, just fill with (1.0f,4.0f) + Op op = Op(1.0f, 4.0f); + if (need_post_process) { + float scale_bias = + kern_param.bias_type.param().scale; + float scale_dst = kern_param.dst_type.param().scale; + op = Op(scale_bias, scale_dst); + } + size_t padding_group_size = IH2 * IW2 * IC; + bundle.set(kern_param.workspace_ptr); + + constexpr size_t pack_c = 4; + constexpr size_t src_expand_size = 4; + const size_t workspace_batch_id = workspace_ids[0]; + const size_t workspace_group_id = workspace_ids[1]; + const size_t batch_id = ncb_index.ndrange_id[0]; + const size_t group_id = ncb_index.ndrange_id[1]; + const size_t oc_id = ncb_index.ndrange_id[2]; + const size_t oc_block_num = ncb_range[2]; + size_t nr_pack_per_step = div_ceil(div_ceil(OC, pack_c), oc_block_num); + size_t oc_block = nr_pack_per_step * pack_c; + const size_t oc_idx = oc_id * oc_block; + if (oc_id == (oc_block_num - 1)) { + oc_block = OC - oc_id * nr_pack_per_step * pack_c; + } + megdnn_assert(oc_block % pack_c == 0, + "oc must be devisible by 4, but oc = %zu", oc_block); + const int8_t* sptr = + static_cast(bundle.get(0)) + + workspace_batch_id * GROUP * padding_group_size * src_expand_size + + workspace_group_id * padding_group_size * src_expand_size; + + const int8_t* fptr = + kern_param.filter(group_id) + oc_idx * FH * FW * IC; + void* dst = reinterpret_cast( + reinterpret_cast( + kern_param.dst(batch_id, group_id)) + + oc_idx * OH * OW); + const int32_t* bptr = + kern_param.bias(batch_id, group_id) + oc_idx; + auto packed_weight = reinterpret_cast(bundle.get(1)) + + group_id * OC * IC * FH * FW + oc_idx * IC * FH * FW; + conv_bias::nchw44_pack_filter(fptr, packed_weight, + oc_block / 4 * IC / 4 * FH * FW); +#define KERN1_NCHW44_CONV(filter) \ + conv_bias::conv_direct_stride2_##filter##x##filter##_int8_nchw44< \ + bias_mode, Op, ow_remain>(sptr, packed_weight, bptr, nullptr, \ + static_cast(dst), oc_block, IC, \ + IH2, IW2, OH, OW, op) + DISPATCH_FILTER(filter, 
KERN1_NCHW44_CONV) +#undef KERN1_NCHW44_CONV +} + +/* ===================== stride2 algo ===================== */ +bool ConvBiasImpl::AlgoS8DirectStride2NCHW44::usable( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const { + MEGDNN_MARK_USED_VAR(algo_selection_strategy); + auto&& fm = param.filter_meta; + auto FH = fm.spatial[0]; + auto OC = fm.ocpg; + auto IC = fm.icpg; + bool avaible = //! src and filter are qint8, dst is qint8 or qint32 + ((param.src_type.enumv() == DTypeEnum::QuantizedS8 && + param.filter_type.enumv() == DTypeEnum::QuantizedS8 && + (param.dst_type.enumv() == DTypeEnum::QuantizedS8 || + param.dst_type.enumv() == DTypeEnum::QuantizedS32))) && + (fm.format == param::Convolution::Format::NCHW44) && + (OC % 4 == 0 && IC % 4 == 0 && OC >= 4) && !fm.should_flip && + fm.spatial_ndim == 2 && fm.dilation[0] == 1 && + fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 && + FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5 || FH == 7) && + param.bias_mode != BiasMode::BIAS; + return avaible; +} + +bool ConvBiasImpl::AlgoS8DirectStride2NCHW44::is_preferred( + megdnn::fallback::ConvBiasImpl* conv_bias_impl_ptr, + const NCBKernSizeParam& param) const { + // TODO: benchmark and fix + MEGDNN_MARK_USED_VAR(conv_bias_impl_ptr); + MEGDNN_MARK_USED_VAR(param); + return false; +} + +size_t ConvBiasImpl::AlgoS8DirectStride2NCHW44::get_workspace( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + return get_bundle(param).total_size_in_bytes(); +} + +SmallVector +ConvBiasImpl::AlgoS8DirectStride2NCHW44::dispatch_kerns( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + auto fm = param.filter_meta; + size_t N = param.n; + size_t IC = fm.icpg; + size_t OC = fm.ocpg; + size_t OW = param.osz[1]; + size_t group = fm.group; + size_t fh = fm.spatial[0]; + size_t fw = fm.spatial[1]; + WorkspaceBundle wbundle = get_bundle(param); + conv_fun do_conv_fun = nullptr; + int ow_remain = OW % 8; +// NOTE: remain_w is not used to gen hash of midout for compatible with changing +// shape runtime +#define DO_CONV_KERN_FUN(filter, bias_mode, remain_w, op) \ + MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw44_stride2, \ + midout_iv(#filter #bias_mode #op##_hash)) { \ + do_conv_fun = do_conv_kern; \ + } \ + MIDOUT_END(); + +#define GET_OP_PARAM(filter, bias_mode, remain_w) \ + switch (param.nonlineMode) { \ + case param::ConvBias::NonlineMode::IDENTITY: \ + DO_CONV_KERN_FUN(filter, bias_mode, remain_w, \ + TypeCvtOp) \ + break; \ + case param::ConvBias::NonlineMode::RELU: \ + DO_CONV_KERN_FUN(filter, bias_mode, remain_w, \ + ReluOp) \ + break; \ + case param::ConvBias::NonlineMode::H_SWISH: \ + DO_CONV_KERN_FUN(filter, bias_mode, remain_w, \ + HSwishOp) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } + +#define GET_REMAIN_W_PARAM(filter, bias_mode) \ + switch (ow_remain) { \ + case 0: \ + GET_OP_PARAM(filter, bias_mode, 0); \ + break; \ + case 1: \ + GET_OP_PARAM(filter, bias_mode, 1); \ + break; \ + case 2: \ + GET_OP_PARAM(filter, bias_mode, 2); \ + break; \ + case 3: \ + GET_OP_PARAM(filter, bias_mode, 3); \ + break; \ + case 4: \ + GET_OP_PARAM(filter, bias_mode, 4); \ + break; \ + case 5: \ + GET_OP_PARAM(filter, bias_mode, 5); \ + break; \ + case 6: \ + GET_OP_PARAM(filter, bias_mode, 6); \ + break; \ + case 7: \ + GET_OP_PARAM(filter, bias_mode, 7); \ + break; \ + default: \ + megdnn_assert(0); \ + } + +#define GET_BIAS_MODE_PARAM(filter) \ + switch (param.bias_mode) { \ + case 
BiasMode::NO_BIAS: \ + GET_REMAIN_W_PARAM(filter, BiasMode::NO_BIAS) \ + break; \ + case BiasMode::BROADCAST_CHANNEL_BIAS: \ + GET_REMAIN_W_PARAM(filter, BiasMode::BROADCAST_CHANNEL_BIAS) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } + +#define DISPATCH_CONV_KERN() \ + switch (param.filter_meta.spatial[0]) { \ + case 2: \ + GET_BIAS_MODE_PARAM(2) \ + break; \ + case 3: \ + GET_BIAS_MODE_PARAM(3) \ + break; \ + case 5: \ + GET_BIAS_MODE_PARAM(5) \ + break; \ + case 7: \ + GET_BIAS_MODE_PARAM(7) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } + + DISPATCH_CONV_KERN(); + +#undef DO_CONV_KERN_FUN +#undef GET_REMAIN_W_PARAM +#undef GET_OP_PARAM +#undef GET_BIAS_MODE_PARAM +#undef DISPATCH_CONV_KERN + + megdnn_assert(do_conv_fun); + + SmallVector ret_kerns; + WorkspaceBundle bundle = wbundle; + + constexpr size_t pack_oc = 4; + size_t oc_step = pack_oc; + if (fh == 2 && fw == 2 && OC >= 8) { + oc_step = 8; + } + if (group == 1) { + CpuNDRange ncb_range = {N, group, div_ceil(OC, oc_step)}; + auto copy_padding = [bundle](const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + copy_padding_kern(bundle, kern_param, ncb_index, + ncb_index.ndrange_id); + }; + constexpr size_t pack_ic = 4; + ret_kerns.push_back({copy_padding, {N, group, div_ceil(IC, pack_ic)}}); + auto do_conv = [bundle, do_conv_fun, ncb_range]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + do_conv_fun(bundle, kern_param, ncb_index, ncb_index.ndrange_id, + ncb_range); + }; + ret_kerns.push_back({do_conv, ncb_range}); + } else { + CpuNDRange ncb_range = {N, group, 1}; + auto do_conv = [bundle, do_conv_fun, ncb_range]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + copy_padding_kern(bundle, kern_param, ncb_index, + {0, ncb_index.thread_id, 0}); + do_conv_fun(bundle, kern_param, ncb_index, + {0, ncb_index.thread_id, 0}, ncb_range); + }; + ret_kerns.push_back({do_conv, ncb_range}); + } + + return ret_kerns; +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw44_kern.cpp b/dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw44_kern.cpp new file mode 100644 index 00000000..06d047c3 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw44_kern.cpp @@ -0,0 +1,793 @@ +/** + * \file dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw44_kern.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "src/arm_common/conv_bias/int8/direct.h" +#include "src/arm_common/conv_bias/intrinsic_helper.h" +#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/utils.h" +#include "src/fallback/conv_bias/common.h" + +using namespace megdnn; +using namespace arm_common; +namespace { + +/** +dot like impl. 
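Editor's note (illustrative sketch, not part of the original patch): the
vdotq_s32_h helper used throughout this file (pulled in through the headers
included above) behaves like a 4-wide int8 dot product emulated with a
widening multiply plus pairwise accumulate, roughly:

    // assumes <arm_neon.h>; sketch only, the real helper may differ
    static inline int32x4_t vdotq_s32_h_sketch(int8x16_t a, int8x16_t b,
                                               int32x4_t c, int16x8_t& temp) {
        temp = vmull_s8(vget_low_s8(a), vget_low_s8(b));          // 8 low-half products
        temp = vmlal_s8(temp, vget_high_s8(a), vget_high_s8(b));  // + 8 high-half products
        return vpadalq_s16(c, temp);  // c[k] += temp[2k] + temp[2k+1]
    }

With the weight packing shown here, lane k of c accumulates all four
input-channel products for output channel k (for k = 0 that is
(<0, 0> + <0, 3>) + (<0, 1> + <0, 2>)), which is what the packing note and
diagram below describe: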
dot 4 ic to 1 oc, accumale to c +example: (format like weight) +packed weight +low 64 bit <0, 0> <0, 1> <1, 2> <1, 3> | <2, 0> <2, 1> <3, 2> <3, 3> +--------------------------------------------------------------------- +high 64 bit <0, 3> <0, 2> <1, 1> <1, 0> | <2, 3> <2, 2> <3, 1> <3, 0> +dot: (<0, 0> + <0, 3>) + (<0, 1> + <0, 2>) -> <0> +**/ +// TODO: can try oh = 2 impl, oc = 8 impl +template +static void ker_neon_dirctconv_3x3s2_oc4_ow8(const int8_t* src_ptr, + const int8_t* weight_ptr, + const int32_t* bias_ptr, + int8_t* dst_ptr, int ic, int ih, + int iw, const Op& op) { + constexpr int fh = filter_size; + constexpr int fw = filter_size; + constexpr int ic_step = 4; + constexpr int loop_ic_step = 4; + constexpr int ld_weight_ic4 = 16; + constexpr int pack_iw_len = 4; + + const int ic_stride = ih * iw * pack_iw_len; + + int32x4_t c[2 * 4]; + int8x16_t weight[3]; + int8x16_t src[8 + 2]; + int16x8_t temp_c[2]; + init_oc4_ow8(c, bias_ptr); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { + const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + + fh_idx * iw * ic_step * pack_iw_len; + + src[0] = vld1q_s8(src_ic_0_3); + src[1] = vld1q_s8((src_ic_0_3 + 16)); + src[2] = vld1q_s8((src_ic_0_3 + 2 * 16)); + src[3] = vld1q_s8((src_ic_0_3 + 3 * 16)); + src[4] = vld1q_s8((src_ic_0_3 + 4 * 16)); + src[5] = vld1q_s8((src_ic_0_3 + 5 * 16)); + src[6] = vld1q_s8((src_ic_0_3 + 6 * 16)); + src[7] = vld1q_s8((src_ic_0_3 + 7 * 16)); + src[8] = vld1q_s8((src_ic_0_3 + 8 * 16)); + src[9] = vld1q_s8((src_ic_0_3 + 9 * 16)); + + // oc == 0 + const int8_t* read_weight_ptr = + weight_ptr + fh_idx * fw * ld_weight_ic4; + + weight[0] = vld1q_s8(read_weight_ptr); + weight[1] = vld1q_s8(read_weight_ptr + 16); + weight[2] = vld1q_s8(read_weight_ptr + 2 * 16); + + c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]); + c[1] = vdotq_s32_h(weight[0], src[2], c[1], temp_c[1]); + c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]); + c[1] = vdotq_s32_h(weight[1], src[3], c[1], temp_c[1]); + c[0] = vdotq_s32_h(weight[2], src[2], c[0], temp_c[0]); + c[1] = vdotq_s32_h(weight[2], src[4], c[1], temp_c[1]); + + c[2] = vdotq_s32_h(weight[0], src[4], c[2], temp_c[0]); + c[3] = vdotq_s32_h(weight[0], src[6], c[3], temp_c[1]); + c[2] = vdotq_s32_h(weight[1], src[5], c[2], temp_c[0]); + c[3] = vdotq_s32_h(weight[1], src[7], c[3], temp_c[1]); + c[2] = vdotq_s32_h(weight[2], src[6], c[2], temp_c[0]); + c[3] = vdotq_s32_h(weight[2], src[8], c[3], temp_c[1]); + + src[0] = vld1q_s8(src_ic_0_3 + 10 * 16); + src[1] = vld1q_s8((src_ic_0_3 + 11 * 16)); + src[2] = vld1q_s8((src_ic_0_3 + 12 * 16)); + c[4] = vdotq_s32_h(weight[0], src[8], c[4], temp_c[0]); + c[5] = vdotq_s32_h(weight[0], src[0], c[5], temp_c[1]); + c[4] = vdotq_s32_h(weight[1], src[9], c[4], temp_c[0]); + c[5] = vdotq_s32_h(weight[1], src[1], c[5], temp_c[1]); + c[4] = vdotq_s32_h(weight[2], src[0], c[4], temp_c[0]); + c[5] = vdotq_s32_h(weight[2], src[2], c[5], temp_c[1]); + + src[3] = vld1q_s8((src_ic_0_3 + 13 * 16)); + src[4] = vld1q_s8((src_ic_0_3 + 14 * 16)); + src[5] = vld1q_s8((src_ic_0_3 + 15 * 16)); + src[6] = vld1q_s8((src_ic_0_3 + 16 * 16)); + c[6] = vdotq_s32_h(weight[0], src[2], c[6], temp_c[0]); + c[7] = vdotq_s32_h(weight[0], src[4], c[7], temp_c[1]); + c[6] = vdotq_s32_h(weight[1], src[3], c[6], temp_c[0]); + c[7] = vdotq_s32_h(weight[1], src[5], c[7], temp_c[1]); + c[6] = vdotq_s32_h(weight[2], src[4], c[6], temp_c[0]); + c[7] = vdotq_s32_h(weight[2], src[6], c[7], temp_c[1]); + } + 
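            // Editor's note: the packed filter stores fh * fw contiguous
            // 16-byte blocks (4 ic x 4 oc) per input-channel group, so this
            // advance of fh * fw * ld_weight_ic4 bytes moves weight_ptr to the
            // next group of 4 input channels for the same 4 output channels.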
weight_ptr += fh * fw * ld_weight_ic4; + } + store_oc4_ow8_remain_static(c, op, dst_ptr); +} + +template +static void ker_neon_dirctconv_2x2s2_oc8_ow8(const int8_t* src_ptr, + const int8_t* weight_ptr, + const int32_t* bias_ptr, + int8_t* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc, + const Op& op) { + constexpr int fh = filter_size; + constexpr int fw = filter_size; + constexpr int ic_step = 4; + constexpr int oc_step = 4; + constexpr int loop_ic_step = 4; + constexpr int ld_weight_ic4 = 16; + constexpr int pack_iw_len = 4; + + const int ic_stride = ih * iw * pack_iw_len; + const int ld_weight_oc4 = oc_step * fh * fw * ic; + + int32x4_t c[2][8]; + int8x16_t weight[2][2]; + int8x16_t src[8 + 1]; + int16x8_t temp_c[4]; + + init_oc8_ow8(c, bias_ptr, oc_step); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { + const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + + fh_idx * iw * ic_step * pack_iw_len; + + src[0] = vld1q_s8(src_ic_0_3); + src[1] = vld1q_s8(src_ic_0_3 + 16); + src[2] = vld1q_s8(src_ic_0_3 + 2 * 16); + src[3] = vld1q_s8(src_ic_0_3 + 3 * 16); + src[4] = vld1q_s8(src_ic_0_3 + 4 * 16); + src[5] = vld1q_s8(src_ic_0_3 + 5 * 16); + src[6] = vld1q_s8(src_ic_0_3 + 6 * 16); + src[7] = vld1q_s8(src_ic_0_3 + 7 * 16); + src[8] = vld1q_s8(src_ic_0_3 + 8 * 16); + + // oc == 0 + const int8_t* read_weight_ptr = + weight_ptr + fh_idx * fw * ld_weight_ic4; + + weight[0][0] = vld1q_s8(read_weight_ptr); + weight[0][1] = vld1q_s8(read_weight_ptr + 16); + weight[1][0] = vld1q_s8(read_weight_ptr + ld_weight_oc4); + weight[1][1] = vld1q_s8(read_weight_ptr + ld_weight_oc4 + 16); + + c[0][0] = vdotq_s32_h(weight[0][0], src[0], c[0][0], temp_c[0]); + c[1][0] = vdotq_s32_h(weight[1][0], src[0], c[1][0], temp_c[1]); + c[0][1] = vdotq_s32_h(weight[0][0], src[2], c[0][1], temp_c[2]); + c[1][1] = vdotq_s32_h(weight[1][0], src[2], c[1][1], temp_c[3]); + c[0][0] = vdotq_s32_h(weight[0][1], src[1], c[0][0], temp_c[0]); + c[1][0] = vdotq_s32_h(weight[1][1], src[1], c[1][0], temp_c[1]); + c[0][1] = vdotq_s32_h(weight[0][1], src[3], c[0][1], temp_c[2]); + c[1][1] = vdotq_s32_h(weight[1][1], src[3], c[1][1], temp_c[3]); + + c[0][2] = vdotq_s32_h(weight[0][0], src[4], c[0][2], temp_c[0]); + c[1][2] = vdotq_s32_h(weight[1][0], src[4], c[1][2], temp_c[1]); + c[0][3] = vdotq_s32_h(weight[0][0], src[6], c[0][3], temp_c[2]); + c[1][3] = vdotq_s32_h(weight[1][0], src[6], c[1][3], temp_c[3]); + c[0][2] = vdotq_s32_h(weight[0][1], src[5], c[0][2], temp_c[0]); + c[1][2] = vdotq_s32_h(weight[1][1], src[5], c[1][2], temp_c[1]); + c[0][3] = vdotq_s32_h(weight[0][1], src[7], c[0][3], temp_c[2]); + c[1][3] = vdotq_s32_h(weight[1][1], src[7], c[1][3], temp_c[3]); + + src[0] = vld1q_s8(src_ic_0_3 + 9 * 16); + src[1] = vld1q_s8(src_ic_0_3 + 10 * 16); + src[2] = vld1q_s8(src_ic_0_3 + 11 * 16); + c[0][4] = vdotq_s32_h(weight[0][0], src[8], c[0][4], temp_c[0]); + c[1][4] = vdotq_s32_h(weight[1][0], src[8], c[1][4], temp_c[1]); + c[0][5] = vdotq_s32_h(weight[0][0], src[1], c[0][5], temp_c[2]); + c[1][5] = vdotq_s32_h(weight[1][0], src[1], c[1][5], temp_c[3]); + c[0][4] = vdotq_s32_h(weight[0][1], src[0], c[0][4], temp_c[0]); + c[1][4] = vdotq_s32_h(weight[1][1], src[0], c[1][4], temp_c[1]); + c[0][5] = vdotq_s32_h(weight[0][1], src[2], c[0][5], temp_c[2]); + c[1][5] = vdotq_s32_h(weight[1][1], src[2], c[1][5], temp_c[3]); + + src[3] = vld1q_s8(src_ic_0_3 + 12 * 16); + src[4] = vld1q_s8(src_ic_0_3 + 13 * 16); + src[5] = vld1q_s8(src_ic_0_3 + 14 * 16); + src[6] = 
vld1q_s8(src_ic_0_3 + 15 * 16); + c[0][6] = vdotq_s32_h(weight[0][0], src[3], c[0][6], temp_c[0]); + c[1][6] = vdotq_s32_h(weight[1][0], src[3], c[1][6], temp_c[1]); + c[0][7] = vdotq_s32_h(weight[0][0], src[5], c[0][7], temp_c[2]); + c[1][7] = vdotq_s32_h(weight[1][0], src[5], c[1][7], temp_c[3]); + c[0][6] = vdotq_s32_h(weight[0][1], src[4], c[0][6], temp_c[0]); + c[1][6] = vdotq_s32_h(weight[1][1], src[4], c[1][6], temp_c[1]); + c[0][7] = vdotq_s32_h(weight[0][1], src[6], c[0][7], temp_c[2]); + c[1][7] = vdotq_s32_h(weight[1][1], src[6], c[1][7], temp_c[3]); + } + weight_ptr += fh * fw * ld_weight_ic4; + } + store_oc8_ow8_remain_static(c, op, dst_ptr, ld_dst_oc); +} + +template +static void ker_neon_dirctconv_2x2s2_oc4_ow8(const int8_t* src_ptr, + const int8_t* weight_ptr, + const int32_t* bias_ptr, + int8_t* dst_ptr, int ic, int ih, + int iw, const Op& op) { + constexpr int fh = filter_size; + constexpr int fw = filter_size; + constexpr int ic_step = 4; + constexpr int loop_ic_step = 4; + constexpr int ld_weight_ic4 = 16; + constexpr int pack_iw_len = 4; + + const int ic_stride = ih * iw * pack_iw_len; + + int32x4_t c[2 * 4]; + int8x16_t weight[2]; + int8x16_t src[8 + 1]; + int16x8_t temp_c[2]; + init_oc4_ow8(c, bias_ptr); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { + const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + + fh_idx * iw * ic_step * pack_iw_len; + + src[0] = vld1q_s8(src_ic_0_3); + src[1] = vld1q_s8((src_ic_0_3 + 16)); + src[2] = vld1q_s8((src_ic_0_3 + 2 * 16)); + src[3] = vld1q_s8((src_ic_0_3 + 3 * 16)); + src[4] = vld1q_s8((src_ic_0_3 + 4 * 16)); + src[5] = vld1q_s8((src_ic_0_3 + 5 * 16)); + src[6] = vld1q_s8((src_ic_0_3 + 6 * 16)); + src[7] = vld1q_s8((src_ic_0_3 + 7 * 16)); + src[8] = vld1q_s8((src_ic_0_3 + 8 * 16)); + + // oc == 0 + const int8_t* read_weight_ptr = + weight_ptr + fh_idx * fw * ld_weight_ic4; + + weight[0] = vld1q_s8(read_weight_ptr); + weight[1] = vld1q_s8(read_weight_ptr + 16); + + c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]); + c[1] = vdotq_s32_h(weight[0], src[2], c[1], temp_c[1]); + c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]); + c[1] = vdotq_s32_h(weight[1], src[3], c[1], temp_c[1]); + + c[2] = vdotq_s32_h(weight[0], src[4], c[2], temp_c[0]); + c[3] = vdotq_s32_h(weight[0], src[6], c[3], temp_c[1]); + c[2] = vdotq_s32_h(weight[1], src[5], c[2], temp_c[0]); + c[3] = vdotq_s32_h(weight[1], src[7], c[3], temp_c[1]); + + src[0] = vld1q_s8(src_ic_0_3 + 9 * 16); + src[1] = vld1q_s8(src_ic_0_3 + 10 * 16); + src[2] = vld1q_s8(src_ic_0_3 + 11 * 16); + c[4] = vdotq_s32_h(weight[0], src[8], c[4], temp_c[0]); + c[5] = vdotq_s32_h(weight[0], src[1], c[5], temp_c[1]); + c[4] = vdotq_s32_h(weight[1], src[0], c[4], temp_c[0]); + c[5] = vdotq_s32_h(weight[1], src[2], c[5], temp_c[1]); + + src[3] = vld1q_s8(src_ic_0_3 + 12 * 16); + src[4] = vld1q_s8(src_ic_0_3 + 13 * 16); + src[5] = vld1q_s8(src_ic_0_3 + 14 * 16); + src[6] = vld1q_s8(src_ic_0_3 + 15 * 16); + c[6] = vdotq_s32_h(weight[0], src[3], c[6], temp_c[0]); + c[7] = vdotq_s32_h(weight[0], src[5], c[7], temp_c[1]); + c[6] = vdotq_s32_h(weight[1], src[4], c[6], temp_c[0]); + c[7] = vdotq_s32_h(weight[1], src[6], c[7], temp_c[1]); + } + weight_ptr += fh * fw * ld_weight_ic4; + } + + store_oc4_ow8_remain_static(c, op, dst_ptr); +} + +template +static void ker_neon_dirctconv_5x5s2_oc4_ow8(const int8_t* src_ptr, + const int8_t* weight_ptr, + const int32_t* bias_ptr, + int8_t* dst_ptr, int ic, int ih, + int iw, const Op& op) 
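// Editor's note: for 8 stride-2 outputs and fw = 5, each filter row touches
// packed source columns 0..18 (ow_step * stride + fw - stride = 19 columns);
// the body below loads them in rolling batches of at most 10 q-registers,
// pairing weight[j] with src[2*out + j] for output column `out`.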
{ + constexpr int fh = filter_size; + constexpr int fw = filter_size; + constexpr int ic_step = 4; + constexpr int loop_ic_step = 4; + constexpr int ld_weight_ic4 = 16; + constexpr int pack_iw_len = 4; + + const int ic_stride = ih * iw * pack_iw_len; + + int32x4_t c[2 * 4]; + int8x16_t weight[5]; + int8x16_t src[8 + 2]; + int16x8_t temp_c[2]; + init_oc4_ow8(c, bias_ptr); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { + const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + + fh_idx * iw * ic_step * pack_iw_len; + + src[0] = vld1q_s8(src_ic_0_3); + src[1] = vld1q_s8((src_ic_0_3 + 16)); + src[2] = vld1q_s8((src_ic_0_3 + 2 * 16)); + src[3] = vld1q_s8((src_ic_0_3 + 3 * 16)); + src[4] = vld1q_s8((src_ic_0_3 + 4 * 16)); + src[5] = vld1q_s8((src_ic_0_3 + 5 * 16)); + src[6] = vld1q_s8((src_ic_0_3 + 6 * 16)); + src[7] = vld1q_s8((src_ic_0_3 + 7 * 16)); + src[8] = vld1q_s8((src_ic_0_3 + 8 * 16)); + src[9] = vld1q_s8((src_ic_0_3 + 9 * 16)); + + // oc == 0 + const int8_t* read_weight_ptr = + weight_ptr + fh_idx * fw * ld_weight_ic4; + + weight[0] = vld1q_s8(read_weight_ptr); + weight[1] = vld1q_s8(read_weight_ptr + 16); + weight[2] = vld1q_s8(read_weight_ptr + 2 * 16); + weight[3] = vld1q_s8(read_weight_ptr + 3 * 16); + weight[4] = vld1q_s8(read_weight_ptr + 4 * 16); + + c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]); + c[1] = vdotq_s32_h(weight[0], src[2], c[1], temp_c[1]); + c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]); + c[1] = vdotq_s32_h(weight[1], src[3], c[1], temp_c[1]); + c[0] = vdotq_s32_h(weight[2], src[2], c[0], temp_c[0]); + c[1] = vdotq_s32_h(weight[2], src[4], c[1], temp_c[1]); + c[0] = vdotq_s32_h(weight[3], src[3], c[0], temp_c[0]); + c[1] = vdotq_s32_h(weight[3], src[5], c[1], temp_c[1]); + c[0] = vdotq_s32_h(weight[4], src[4], c[0], temp_c[0]); + c[1] = vdotq_s32_h(weight[4], src[6], c[1], temp_c[1]); + + src[0] = vld1q_s8(src_ic_0_3 + 10 * 16); + c[2] = vdotq_s32_h(weight[0], src[4], c[2], temp_c[0]); + c[3] = vdotq_s32_h(weight[0], src[6], c[3], temp_c[1]); + c[2] = vdotq_s32_h(weight[1], src[5], c[2], temp_c[0]); + c[3] = vdotq_s32_h(weight[1], src[7], c[3], temp_c[1]); + c[2] = vdotq_s32_h(weight[2], src[6], c[2], temp_c[0]); + c[3] = vdotq_s32_h(weight[2], src[8], c[3], temp_c[1]); + c[2] = vdotq_s32_h(weight[3], src[7], c[2], temp_c[0]); + c[3] = vdotq_s32_h(weight[3], src[9], c[3], temp_c[1]); + c[2] = vdotq_s32_h(weight[4], src[8], c[2], temp_c[0]); + c[3] = vdotq_s32_h(weight[4], src[0], c[3], temp_c[1]); + + src[1] = vld1q_s8((src_ic_0_3 + 11 * 16)); + src[2] = vld1q_s8((src_ic_0_3 + 12 * 16)); + src[3] = vld1q_s8((src_ic_0_3 + 13 * 16)); + src[4] = vld1q_s8((src_ic_0_3 + 14 * 16)); + c[4] = vdotq_s32_h(weight[0], src[8], c[4], temp_c[0]); + c[5] = vdotq_s32_h(weight[0], src[0], c[5], temp_c[1]); + c[4] = vdotq_s32_h(weight[1], src[9], c[4], temp_c[0]); + c[5] = vdotq_s32_h(weight[1], src[1], c[5], temp_c[1]); + c[4] = vdotq_s32_h(weight[2], src[0], c[4], temp_c[0]); + c[5] = vdotq_s32_h(weight[2], src[2], c[5], temp_c[1]); + c[4] = vdotq_s32_h(weight[3], src[1], c[4], temp_c[0]); + c[5] = vdotq_s32_h(weight[3], src[3], c[5], temp_c[1]); + c[4] = vdotq_s32_h(weight[4], src[2], c[4], temp_c[0]); + c[5] = vdotq_s32_h(weight[4], src[4], c[5], temp_c[1]); + + src[5] = vld1q_s8((src_ic_0_3 + 15 * 16)); + src[6] = vld1q_s8((src_ic_0_3 + 16 * 16)); + src[7] = vld1q_s8((src_ic_0_3 + 17 * 16)); + src[8] = vld1q_s8((src_ic_0_3 + 18 * 16)); + c[6] = vdotq_s32_h(weight[0], src[2], c[6], temp_c[0]); + 
c[7] = vdotq_s32_h(weight[0], src[4], c[7], temp_c[1]); + c[6] = vdotq_s32_h(weight[1], src[3], c[6], temp_c[0]); + c[7] = vdotq_s32_h(weight[1], src[5], c[7], temp_c[1]); + c[6] = vdotq_s32_h(weight[2], src[4], c[6], temp_c[0]); + c[7] = vdotq_s32_h(weight[2], src[6], c[7], temp_c[1]); + c[6] = vdotq_s32_h(weight[3], src[5], c[6], temp_c[0]); + c[7] = vdotq_s32_h(weight[3], src[7], c[7], temp_c[1]); + c[6] = vdotq_s32_h(weight[4], src[6], c[6], temp_c[0]); + c[7] = vdotq_s32_h(weight[4], src[8], c[7], temp_c[1]); + } + weight_ptr += fh * fw * ld_weight_ic4; + } + + store_oc4_ow8_remain_static(c, op, dst_ptr); +} + +template +static void ker_neon_dirctconv_7x7s2_oc4_ow8(const int8_t* src_ptr, + const int8_t* weight_ptr, + const int32_t* bias_ptr, + int8_t* dst_ptr, int ic, int ih, + int iw, const Op& op) { + constexpr int fh = filter_size; + constexpr int fw = filter_size; + constexpr int ic_step = 4; + constexpr int loop_ic_step = 4; + constexpr int ld_weight_ic4 = 16; + constexpr int pack_iw_len = 4; + + const int ic_stride = ih * iw * pack_iw_len; + + int32x4_t c[2 * 4]; + int8x16_t weight[7]; + int8x16_t src[8 + 2]; + int16x8_t temp_c[2]; + init_oc4_ow8(c, bias_ptr); + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { + const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + + fh_idx * iw * ic_step * pack_iw_len; + + src[0] = vld1q_s8(src_ic_0_3); + src[1] = vld1q_s8(src_ic_0_3 + 1 * 16); + src[2] = vld1q_s8(src_ic_0_3 + 2 * 16); + src[3] = vld1q_s8(src_ic_0_3 + 3 * 16); + src[4] = vld1q_s8(src_ic_0_3 + 4 * 16); + src[5] = vld1q_s8(src_ic_0_3 + 5 * 16); + src[6] = vld1q_s8(src_ic_0_3 + 6 * 16); + src[7] = vld1q_s8(src_ic_0_3 + 7 * 16); + src[8] = vld1q_s8(src_ic_0_3 + 8 * 16); + src[9] = vld1q_s8(src_ic_0_3 + 9 * 16); + + // oc == 0 + const int8_t* read_weight_ptr = + weight_ptr + fh_idx * fw * ld_weight_ic4; + + weight[0] = vld1q_s8(read_weight_ptr); + weight[1] = vld1q_s8(read_weight_ptr + 16); + weight[2] = vld1q_s8(read_weight_ptr + 2 * 16); + weight[3] = vld1q_s8(read_weight_ptr + 3 * 16); + weight[4] = vld1q_s8(read_weight_ptr + 4 * 16); + weight[5] = vld1q_s8(read_weight_ptr + 5 * 16); + weight[6] = vld1q_s8(read_weight_ptr + 6 * 16); + + c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]); + c[1] = vdotq_s32_h(weight[0], src[2], c[1], temp_c[1]); + c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]); + c[1] = vdotq_s32_h(weight[1], src[3], c[1], temp_c[1]); + c[0] = vdotq_s32_h(weight[2], src[2], c[0], temp_c[0]); + c[1] = vdotq_s32_h(weight[2], src[4], c[1], temp_c[1]); + c[0] = vdotq_s32_h(weight[3], src[3], c[0], temp_c[0]); + c[1] = vdotq_s32_h(weight[3], src[5], c[1], temp_c[1]); + c[0] = vdotq_s32_h(weight[4], src[4], c[0], temp_c[0]); + c[1] = vdotq_s32_h(weight[4], src[6], c[1], temp_c[1]); + c[0] = vdotq_s32_h(weight[5], src[5], c[0], temp_c[0]); + c[1] = vdotq_s32_h(weight[5], src[7], c[1], temp_c[1]); + c[0] = vdotq_s32_h(weight[6], src[6], c[0], temp_c[0]); + c[1] = vdotq_s32_h(weight[6], src[8], c[1], temp_c[1]); + + src[0] = vld1q_s8(src_ic_0_3 + 10 * 16); + src[1] = vld1q_s8(src_ic_0_3 + 11 * 16); + src[2] = vld1q_s8(src_ic_0_3 + 12 * 16); + c[2] = vdotq_s32_h(weight[0], src[4], c[2], temp_c[0]); + c[3] = vdotq_s32_h(weight[0], src[6], c[3], temp_c[1]); + c[2] = vdotq_s32_h(weight[1], src[5], c[2], temp_c[0]); + c[3] = vdotq_s32_h(weight[1], src[7], c[3], temp_c[1]); + c[2] = vdotq_s32_h(weight[2], src[6], c[2], temp_c[0]); + c[3] = vdotq_s32_h(weight[2], src[8], c[3], temp_c[1]); + c[2] = 
vdotq_s32_h(weight[3], src[7], c[2], temp_c[0]); + c[3] = vdotq_s32_h(weight[3], src[9], c[3], temp_c[1]); + c[2] = vdotq_s32_h(weight[4], src[8], c[2], temp_c[0]); + c[3] = vdotq_s32_h(weight[4], src[0], c[3], temp_c[1]); + c[2] = vdotq_s32_h(weight[5], src[9], c[2], temp_c[0]); + c[3] = vdotq_s32_h(weight[5], src[1], c[3], temp_c[1]); + c[2] = vdotq_s32_h(weight[6], src[0], c[2], temp_c[0]); + c[3] = vdotq_s32_h(weight[6], src[2], c[3], temp_c[1]); + + src[3] = vld1q_s8(src_ic_0_3 + 13 * 16); + src[4] = vld1q_s8(src_ic_0_3 + 14 * 16); + src[5] = vld1q_s8(src_ic_0_3 + 15 * 16); + src[6] = vld1q_s8(src_ic_0_3 + 16 * 16); + c[4] = vdotq_s32_h(weight[0], src[8], c[4], temp_c[0]); + c[5] = vdotq_s32_h(weight[0], src[0], c[5], temp_c[1]); + c[4] = vdotq_s32_h(weight[1], src[9], c[4], temp_c[0]); + c[5] = vdotq_s32_h(weight[1], src[1], c[5], temp_c[1]); + c[4] = vdotq_s32_h(weight[2], src[0], c[4], temp_c[0]); + c[5] = vdotq_s32_h(weight[2], src[2], c[5], temp_c[1]); + c[4] = vdotq_s32_h(weight[3], src[1], c[4], temp_c[0]); + c[5] = vdotq_s32_h(weight[3], src[3], c[5], temp_c[1]); + c[4] = vdotq_s32_h(weight[4], src[2], c[4], temp_c[0]); + c[5] = vdotq_s32_h(weight[4], src[4], c[5], temp_c[1]); + c[4] = vdotq_s32_h(weight[5], src[3], c[4], temp_c[0]); + c[5] = vdotq_s32_h(weight[5], src[5], c[5], temp_c[1]); + c[4] = vdotq_s32_h(weight[6], src[4], c[4], temp_c[0]); + c[5] = vdotq_s32_h(weight[6], src[6], c[5], temp_c[1]); + + src[7] = vld1q_s8(src_ic_0_3 + 17 * 16); + src[8] = vld1q_s8(src_ic_0_3 + 18 * 16); + src[9] = vld1q_s8(src_ic_0_3 + 19 * 16); + src[0] = vld1q_s8(src_ic_0_3 + 20 * 16); + c[6] = vdotq_s32_h(weight[0], src[2], c[6], temp_c[0]); + c[7] = vdotq_s32_h(weight[0], src[4], c[7], temp_c[1]); + c[6] = vdotq_s32_h(weight[1], src[3], c[6], temp_c[0]); + c[7] = vdotq_s32_h(weight[1], src[5], c[7], temp_c[1]); + c[6] = vdotq_s32_h(weight[2], src[4], c[6], temp_c[0]); + c[7] = vdotq_s32_h(weight[2], src[6], c[7], temp_c[1]); + c[6] = vdotq_s32_h(weight[3], src[5], c[6], temp_c[0]); + c[7] = vdotq_s32_h(weight[3], src[7], c[7], temp_c[1]); + c[6] = vdotq_s32_h(weight[4], src[6], c[6], temp_c[0]); + c[7] = vdotq_s32_h(weight[4], src[8], c[7], temp_c[1]); + c[6] = vdotq_s32_h(weight[5], src[7], c[6], temp_c[0]); + c[7] = vdotq_s32_h(weight[5], src[9], c[7], temp_c[1]); + c[6] = vdotq_s32_h(weight[6], src[8], c[6], temp_c[0]); + c[7] = vdotq_s32_h(weight[6], src[0], c[7], temp_c[1]); + } + weight_ptr += fh * fw * ld_weight_ic4; + } + + store_oc4_ow8_remain_static(c, op, dst_ptr); +} + +} // namespace + +template +void conv_bias::conv_direct_stride2_2x2_int8_nchw44( + const int8_t* src, const int8_t* filter, const int32_t* bias, + int32_t* temp, int8_t* dst, const size_t oc, const size_t ic, + const size_t ih, const size_t iw, const size_t oh, const size_t ow, + const Op& op) { + MEGDNN_MARK_USED_VAR(temp); + constexpr size_t filter_size = 2; + constexpr size_t fh = filter_size; + constexpr size_t fw = filter_size; + constexpr size_t ic_step = 4; + constexpr size_t oc_step = 4; + constexpr size_t big_oc_step = 8; + constexpr size_t oh_step = 1; + constexpr size_t ow_step = 8; + constexpr size_t stride_h = 2; + constexpr size_t stride_w = 2; + constexpr int pack_iw_len = 4; + + const size_t out_img_stride = oh * ow; + const size_t ow_end = ow / ow_step * ow_step; + const size_t ow_remain = ow - ow_end; + const size_t oc_end = oc / big_oc_step * big_oc_step; + const size_t oc_remain = oc - oc_end; + const int ld_oc = oh * ow * ic_step; + for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += 
big_oc_step) { + const size_t weight_offset = oc_idx * ic * fh * fw; + for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { + for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const size_t src_offset = + (oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step * + pack_iw_len; + const size_t dst_offset = oc_idx * out_img_stride + + (oh_idx * ow + ow_idx) * oc_step; + ker_neon_dirctconv_2x2s2_oc8_ow8( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ld_oc, op); + } + if (ow_remain > 0) { + const size_t src_offset = + (oh_idx * stride_h * iw + ow_end * stride_w) * ic_step * + pack_iw_len; + const size_t dst_offset = oc_idx * out_img_stride + + (oh_idx * ow + ow_end) * oc_step; + ker_neon_dirctconv_2x2s2_oc8_ow8( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ld_oc, op); + } + } + } + if (oc_remain > 0) { + const size_t oc_idx = oc_end; + const size_t weight_offset = oc_idx * ic * fh * fw; + for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { + for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const size_t src_offset = + (oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step * + pack_iw_len; + const size_t dst_offset = oc_idx * out_img_stride + + (oh_idx * ow + ow_idx) * oc_step; + ker_neon_dirctconv_2x2s2_oc4_ow8( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, op); + } + if (ow_remain > 0) { + const size_t src_offset = + (oh_idx * stride_h * iw + ow_end * stride_w) * ic_step * + pack_iw_len; + const size_t dst_offset = oc_idx * out_img_stride + + (oh_idx * ow + ow_end) * oc_step; + ker_neon_dirctconv_2x2s2_oc4_ow8( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, op); + } + } + } +} +template +void conv_bias::conv_direct_stride2_3x3_int8_nchw44( + const int8_t* src, const int8_t* filter, const int32_t* bias, + int32_t* temp, int8_t* dst, const size_t oc, const size_t ic, + const size_t ih, const size_t iw, const size_t oh, const size_t ow, + const Op& op) { + MEGDNN_MARK_USED_VAR(temp); + constexpr size_t filter_size = 3; + constexpr size_t fh = filter_size; + constexpr size_t fw = filter_size; + constexpr size_t ic_step = 4; + constexpr size_t oc_step = 4; + constexpr size_t oh_step = 1; + constexpr size_t ow_step = 8; + constexpr size_t stride_h = 2; + constexpr size_t stride_w = 2; + constexpr int pack_iw_len = 4; + + const size_t img_stride = oh * ow; + const size_t ow_end = ow / ow_step * ow_step; + const size_t ow_remain = ow - ow_end; + for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) { + const size_t weight_offset = oc_idx * ic * fh * fw; + for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { + for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const size_t src_offset = + (oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step * + pack_iw_len; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; + ker_neon_dirctconv_3x3s2_oc4_ow8( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, op); + } + if (ow_remain > 0) { + const size_t src_offset = + (oh_idx * stride_h * iw + ow_end * stride_w) * ic_step * + pack_iw_len; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; + ker_neon_dirctconv_3x3s2_oc4_ow8( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, op); + } + } + } +} +template +void 
conv_bias::conv_direct_stride2_5x5_int8_nchw44( + const int8_t* src, const int8_t* filter, const int32_t* bias, + int32_t* temp, int8_t* dst, const size_t oc, const size_t ic, + const size_t ih, const size_t iw, const size_t oh, const size_t ow, + const Op& op) { + MEGDNN_MARK_USED_VAR(temp); + constexpr size_t filter_size = 5; + constexpr size_t fh = filter_size; + constexpr size_t fw = filter_size; + constexpr size_t ic_step = 4; + constexpr size_t oc_step = 4; + constexpr size_t oh_step = 1; + constexpr size_t ow_step = 8; + constexpr size_t stride_h = 2; + constexpr size_t stride_w = 2; + constexpr int pack_iw_len = 4; + + const size_t img_stride = oh * ow; + const size_t ow_end = ow / ow_step * ow_step; + const size_t ow_remain = ow - ow_end; + for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) { + const size_t weight_offset = oc_idx * ic * fh * fw; + for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { + for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const size_t src_offset = + (oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step * + pack_iw_len; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; + ker_neon_dirctconv_5x5s2_oc4_ow8( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, op); + } + if (ow_remain > 0) { + const size_t src_offset = + (oh_idx * stride_h * iw + ow_end * stride_w) * ic_step * + pack_iw_len; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; + ker_neon_dirctconv_5x5s2_oc4_ow8( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, op); + } + } + } +} + +template +void conv_bias::conv_direct_stride2_7x7_int8_nchw44( + const int8_t* src, const int8_t* filter, const int32_t* bias, + int32_t* temp, int8_t* dst, const size_t oc, const size_t ic, + const size_t ih, const size_t iw, const size_t oh, const size_t ow, + const Op& op) { + MEGDNN_MARK_USED_VAR(temp); + constexpr size_t filter_size = 7; + constexpr size_t fh = filter_size; + constexpr size_t fw = filter_size; + constexpr size_t ic_step = 4; + constexpr size_t oc_step = 4; + constexpr size_t oh_step = 1; + constexpr size_t ow_step = 8; + constexpr size_t stride_h = 2; + constexpr size_t stride_w = 2; + constexpr int pack_iw_len = 4; + + const size_t img_stride = oh * ow; + const size_t ow_end = ow / ow_step * ow_step; + const size_t ow_remain = ow - ow_end; + for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) { + const size_t weight_offset = oc_idx * ic * fh * fw; + for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { + for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const size_t src_offset = + (oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step * + pack_iw_len; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; + ker_neon_dirctconv_7x7s2_oc4_ow8( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, op); + } + if (ow_remain > 0) { + const size_t src_offset = + (oh_idx * stride_h * iw + ow_end * stride_w) * ic_step * + pack_iw_len; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; + ker_neon_dirctconv_7x7s2_oc4_ow8( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, op); + } + } + } +} + +#define INSTANTIATION(stride, i, bias, remain_w, Op) \ + template void conv_bias::conv_direct_##stride##_##i##x##i##_int8_nchw44< \ + bias, Op, remain_w>(const int8_t*, const 
int8_t*, const int32_t*, \ + int32_t*, int8_t*, const size_t, const size_t, \ + const size_t, const size_t, const size_t, \ + const size_t, const Op&); + +#define FOR_OP(stride, i, bias, remain_w) \ + INSTANTIATION(stride, i, bias, remain_w, \ + TypeCvtOp) \ + INSTANTIATION(stride, i, bias, remain_w, \ + ReluOp) \ + INSTANTIATION(stride, i, bias, remain_w, \ + HSwishOp) + +#define FOR_REMAIN(stride, i, bias) \ + FOR_OP(stride, i, bias, 0) \ + FOR_OP(stride, i, bias, 1) \ + FOR_OP(stride, i, bias, 2) \ + FOR_OP(stride, i, bias, 3) \ + FOR_OP(stride, i, bias, 4) \ + FOR_OP(stride, i, bias, 5) \ + FOR_OP(stride, i, bias, 6) \ + FOR_OP(stride, i, bias, 7) + +#define FOR_BIAS(stride, i) \ + FOR_REMAIN(stride, i, BiasMode::NO_BIAS) \ + FOR_REMAIN(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS) + +#define FOR_FILTER(stride) \ + FOR_BIAS(stride, 2) \ + FOR_BIAS(stride, 3) \ + FOR_BIAS(stride, 5) \ + FOR_BIAS(stride, 7) + +FOR_FILTER(stride2) + +#undef FOR_STRIDE +#undef FOR_FILTER +#undef FOR_IC +#undef FOR_BIAS +#undef FOR_NONLINEAR +#undef FOR_REMAIN +#undef INSTANTIATION diff --git a/dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_algo.cpp b/dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_algo.cpp new file mode 100644 index 00000000..377df9f5 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_algo.cpp @@ -0,0 +1,302 @@ +/** + * \file dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_algo.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ + +#include "megdnn/oprs.h" +#include "src/arm_common/conv_bias/int8/algos.h" +#include "src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_kern.h" +#include "src/arm_common/conv_bias/int8/strategy.h" +#include "src/arm_common/elemwise_op.h" +#include "src/common/opr_delegate.h" + +#include "midout.h" + +using namespace megdnn; +using namespace arm_common; +using conv_fun = std::function; +MIDOUT_DECL(megdnn_arm_common_conv_bias_int8_nchw_nchw44_stride2) + +static void get_rectified_size( + const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, + size_t& IH2, size_t& IW2, size_t& OH2, size_t& OW2) { + auto&& fm = param.filter_meta; + size_t IH = param.isz[0]; + size_t IW = param.isz[1]; + size_t OH = param.osz[0]; + size_t OW = param.osz[1]; + + OH2 = OH; + OW2 = OW; + IH2 = round_up(IH + 2 * fm.padding[0], static_cast(2)); + IW2 = IW + 2 * fm.padding[1]; +} +static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { + constexpr size_t src_expand = 4; + auto&& fm = param.filter_meta; + size_t group = fm.group; + size_t batch = param.n; + size_t IC = fm.icpg; + size_t OC = fm.ocpg; + size_t FH = fm.spatial[0]; + size_t FW = fm.spatial[1]; + size_t IH2, IW2, OH2, OW2; + get_rectified_size(param, IH2, IW2, OH2, OW2); + megdnn_assert(group == 1, "only support group == 1 now"); + size_t src_size = + batch * group * IC * IH2 * IW2 * sizeof(int8_t) * src_expand; + size_t weight_size = group * OC * IC * FH * FW * sizeof(int8_t); + return {nullptr, {src_size, weight_size}}; +}; + +static void copy_padding_kern(WorkspaceBundle bundle, + const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids) { + size_t IH = kern_param.isz[0]; + size_t IW = kern_param.isz[1]; + size_t IC = kern_param.filter_meta.icpg; + size_t PH = kern_param.filter_meta.padding[0]; + size_t PW = kern_param.filter_meta.padding[1]; + size_t GROUP = kern_param.filter_meta.group; + + size_t IH2, IW2, OH2, OW2; + get_rectified_size(kern_param, IH2, IW2, OH2, OW2); + size_t padding_group_size = IH2 * IW2 * IC; + bundle.set(kern_param.workspace_ptr); + //! Used for get the workspace offset + constexpr int expend_element = 4; + // TODO: block dim is better to get from arg + size_t workspace_ic_block = 1; + size_t workspace_batch_id = workspace_ids[0]; + size_t workspace_group_id = workspace_ids[1]; + size_t workspace_ic_id = workspace_ids[2]; + size_t workspace_ic = workspace_ic_id * workspace_ic_block; + size_t batch_id = ncb_index.ndrange_id[0]; + size_t group_id = ncb_index.ndrange_id[1]; + + const int8_t* sptr = static_cast( + kern_param.src(batch_id, group_id, workspace_ic_id, 1, 1)); + //! 
copy to sptr_base to eliminate padding effect + int8_t* sptr_base = static_cast(bundle.get(0)) + + (workspace_batch_id * GROUP * padding_group_size + + workspace_group_id * padding_group_size + + workspace_ic * IH2 * IW2) * + expend_element; + conv_bias::pack_nchw_src_for_nchw44_conv(sptr, sptr_base, 1, PH, PH, PW, PW, + IH, IW); +} + +template +static void do_conv_kern(WorkspaceBundle bundle, + const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids, + const CpuNDRange& ncb_range) { + size_t OH = kern_param.osz[0]; + size_t OW = kern_param.osz[1]; + size_t FH = kern_param.filter_meta.spatial[0]; + size_t FW = kern_param.filter_meta.spatial[1]; + size_t IC = kern_param.filter_meta.icpg; + size_t OC = kern_param.filter_meta.ocpg; + size_t GROUP = kern_param.filter_meta.group; + size_t IH2, IW2, OH2, OW2; + get_rectified_size(kern_param, IH2, IW2, OH2, OW2); + bool need_post_process = + kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8; + //! if dst_type is qint32, the op is not used, just fill with (1.0f,4.0f) + Op op = Op(1.0f, 4.0f); + if (need_post_process) { + float scale_bias = + kern_param.bias_type.param().scale; + float scale_dst = kern_param.dst_type.param().scale; + op = Op(scale_bias, scale_dst); + } + size_t padding_group_size = IH2 * IW2 * IC; + bundle.set(kern_param.workspace_ptr); + + constexpr size_t pack_c = 4; + constexpr size_t src_expand_size = 4; + const size_t workspace_batch_id = workspace_ids[0]; + const size_t workspace_group_id = workspace_ids[1]; + const size_t batch_id = ncb_index.ndrange_id[0]; + const size_t group_id = ncb_index.ndrange_id[1]; + const size_t oc_id = ncb_index.ndrange_id[2]; + const size_t oc_block_num = ncb_range[2]; + size_t nr_pack_per_step = div_ceil(div_ceil(OC, pack_c), oc_block_num); + size_t oc_block = nr_pack_per_step * pack_c; + const size_t oc_idx = oc_id * oc_block; + if (oc_id == (oc_block_num - 1)) { + oc_block = OC - oc_id * nr_pack_per_step * pack_c; + } + megdnn_assert(oc_block % pack_c == 0, + "oc must be devisible by 4, but oc = %zu", oc_block); + const int8_t* sptr = + static_cast(bundle.get(0)) + + workspace_batch_id * GROUP * padding_group_size * src_expand_size + + workspace_group_id * padding_group_size * src_expand_size; + + const int8_t* fptr = + kern_param.filter(group_id) + oc_idx * FH * FW * IC; + void* dst = reinterpret_cast( + reinterpret_cast( + kern_param.dst(batch_id, group_id)) + + oc_idx * OH * OW); + const int32_t* bptr = + kern_param.bias(batch_id, group_id) + oc_idx; + auto packed_weight = reinterpret_cast(bundle.get(1)) + + group_id * OC * IC * FH * FW + oc_idx * IC * FH * FW; + + conv_bias::pack_nchw44_weight_for_nchw_conv(fptr, packed_weight, IC, FH, FW, + oc_block); +#define KERN1_NCHW44_CONV(filter) \ + conv_bias::conv_direct_stride2_##filter##x##filter##_int8_nchw_nchw44< \ + bias_mode, Op>(sptr, packed_weight, bptr, nullptr, \ + static_cast(dst), oc_block, IC, IH2, IW2, \ + OH, OW, op) + DISPATCH_FILTER(filter, KERN1_NCHW44_CONV); +#undef KERN1_NCHW44_CONV +} + +/* ===================== stride2 algo ===================== */ +bool ConvBiasImpl::AlgoS8DirectStride2NCHWNCHW44::usable( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const { + auto&& fm = param.filter_meta; + auto FH = fm.spatial[0]; + auto OC = fm.ocpg; + bool avaible = //! 
src and filter are qint8, dst is qint8 + fm.icpg < 4 && // must be nchw input + ((param.src_type.enumv() == DTypeEnum::QuantizedS8 && + param.filter_type.enumv() == DTypeEnum::QuantizedS8 && + (param.dst_type.enumv() == DTypeEnum::QuantizedS8))) && + (fm.format == param::Convolution::Format::NCHW44) && + (OC % 4 == 0 && OC >= 4) && !fm.should_flip && fm.group == 1 && + fm.spatial_ndim == 2 && fm.dilation[0] == 1 && + fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 && + FH == fm.spatial[1] && (FH == 3 || FH == 5 || FH == 7) && + fm.group == 1 && param.bias_mode != BiasMode::BIAS; + return avaible; +} + +bool ConvBiasImpl::AlgoS8DirectStride2NCHWNCHW44::is_preferred( + megdnn::fallback::ConvBiasImpl* conv_bias_impl_ptr, + const NCBKernSizeParam& param) const { + // TODO: benchmark and fix + return false; +} + +size_t ConvBiasImpl::AlgoS8DirectStride2NCHWNCHW44::get_workspace( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + return get_bundle(param).total_size_in_bytes(); +} + +SmallVector +ConvBiasImpl::AlgoS8DirectStride2NCHWNCHW44::dispatch_kerns( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + auto fm = param.filter_meta; + size_t N = param.n; + size_t OC = fm.ocpg; + size_t group = fm.group; + WorkspaceBundle wbundle = get_bundle(param); + conv_fun do_conv_fun = nullptr; +// NOTE: remain_w is not used to gen hash of midout for compatible with changing +// shape runtime +#define DO_CONV_KERN_FUN(filter, bias_mode, op) \ + MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw_nchw44_stride2, \ + midout_iv(#filter #bias_mode #op##_hash)) { \ + do_conv_fun = do_conv_kern; \ + } \ + MIDOUT_END(); + +#define GET_OP_PARAM(filter, bias_mode) \ + switch (param.nonlineMode) { \ + case param::ConvBias::NonlineMode::IDENTITY: \ + DO_CONV_KERN_FUN(filter, bias_mode, \ + TypeCvtOp) \ + break; \ + case param::ConvBias::NonlineMode::RELU: \ + DO_CONV_KERN_FUN(filter, bias_mode, \ + ReluOp) \ + break; \ + case param::ConvBias::NonlineMode::H_SWISH: \ + DO_CONV_KERN_FUN(filter, bias_mode, \ + HSwishOp) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } +#define GET_BIAS_MODE_PARAM(filter) \ + switch (param.bias_mode) { \ + case BiasMode::NO_BIAS: \ + GET_OP_PARAM(filter, BiasMode::NO_BIAS) \ + break; \ + case BiasMode::BROADCAST_CHANNEL_BIAS: \ + GET_OP_PARAM(filter, BiasMode::BROADCAST_CHANNEL_BIAS) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } + +#define DISPATCH_CONV_KERN() \ + switch (param.filter_meta.spatial[0]) { \ + case 3: \ + GET_BIAS_MODE_PARAM(3) \ + break; \ + case 5: \ + GET_BIAS_MODE_PARAM(5) \ + break; \ + case 7: \ + GET_BIAS_MODE_PARAM(7) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } + + DISPATCH_CONV_KERN(); + +#undef DO_CONV_KERN_FUN +#undef GET_REMAIN_W_PARAM +#undef GET_OP_PARAM +#undef GET_BIAS_MODE_PARAM +#undef DISPATCH_CONV_KERN + + megdnn_assert(do_conv_fun); + + SmallVector ret_kerns; + WorkspaceBundle bundle = wbundle; + + constexpr size_t pack_oc = 8; + size_t oc_step = pack_oc; + auto copy_padding = [bundle](const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + copy_padding_kern(bundle, kern_param, ncb_index, ncb_index.ndrange_id); + }; + ret_kerns.push_back({copy_padding, {N, group, fm.icpg}}); + + CpuNDRange ncb_range = {N, group, div_ceil(OC, oc_step)}; + auto do_conv = [bundle, do_conv_fun, ncb_range]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + do_conv_fun(bundle, kern_param, ncb_index, ncb_index.ndrange_id, + ncb_range); + }; + 
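    // Editor's note (worked example, hypothetical shapes): with N = 1,
    // group = 1, IC = 3 and OC = 32, the padding kerns above run over
    // {1, 1, 3} (one NCHW input channel per task) and ncb_range is {1, 1, 4},
    // so each conv task packs the weights for and computes 8 output channels
    // (oc_step == pack_oc == 8).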
ret_kerns.push_back({do_conv, ncb_range}); + + return ret_kerns; +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_kern.cpp b/dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_kern.cpp new file mode 100644 index 00000000..af6a759e --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_kern.cpp @@ -0,0 +1,776 @@ +/** + * \file dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw44_kern_nchw.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_kern.h" +#include "src/arm_common/conv_bias/intrinsic_helper.h" +#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/unroll_macro.h" +#include "src/common/utils.h" +#include "src/fallback/conv_bias/common.h" + +using namespace megdnn; +using namespace arm_common; +namespace { + +template +struct ShiftCalHelper { + static void impl(T& c, T2& src, T3& weight, T4& temp); + static void impl(T& c, T2& src, T3& weight); +}; +template +struct ShiftCalHelper { + static void impl(T& c, T2& src, T3& weight, T4& temp) { + c[0][0] = Func::impl(src[0 + src_idx], weight[0][weight_idx], c[0][0], + temp[0]); + c[1][0] = Func::impl(src[0 + src_idx], weight[1][weight_idx], c[1][0], + temp[1]); + c[0][1] = Func::impl(src[1 + src_idx], weight[0][weight_idx], c[0][1], + temp[2]); + c[1][1] = Func::impl(src[1 + src_idx], weight[1][weight_idx], c[1][1], + temp[3]); + c[0][2] = Func::impl(src[2 + src_idx], weight[0][weight_idx], c[0][2], + temp[0]); + c[1][2] = Func::impl(src[2 + src_idx], weight[1][weight_idx], c[1][2], + temp[1]); + c[0][3] = Func::impl(src[3 + src_idx], weight[0][weight_idx], c[0][3], + temp[2]); + c[1][3] = Func::impl(src[3 + src_idx], weight[1][weight_idx], c[1][3], + temp[3]); + } + static void impl(T& c, T2& src, T3& weight) { + c[0][0] = Func::impl(src[0 + src_idx], weight[0][weight_idx], c[0][0]); + c[1][0] = Func::impl(src[0 + src_idx], weight[1][weight_idx], c[1][0]); + c[0][1] = Func::impl(src[1 + src_idx], weight[0][weight_idx], c[0][1]); + c[1][1] = Func::impl(src[1 + src_idx], weight[1][weight_idx], c[1][1]); + c[0][2] = Func::impl(src[2 + src_idx], weight[0][weight_idx], c[0][2]); + c[1][2] = Func::impl(src[2 + src_idx], weight[1][weight_idx], c[1][2]); + c[0][3] = Func::impl(src[3 + src_idx], weight[0][weight_idx], c[0][3]); + c[1][3] = Func::impl(src[3 + src_idx], weight[1][weight_idx], c[1][3]); + } +}; +template +struct ShiftCalHelper { + static void impl(T& c, T2& src, T3& weight, T4& temp) { + c[0][0] = Func::impl(src[0 + src_idx], weight[0][weight_idx], c[0][0], + temp[0]); + c[0][1] = Func::impl(src[1 + src_idx], weight[0][weight_idx], c[0][1], + temp[2]); + c[0][2] = Func::impl(src[2 + src_idx], weight[0][weight_idx], c[0][2], + temp[0]); + c[0][3] = Func::impl(src[3 + src_idx], weight[0][weight_idx], c[0][3], + temp[2]); + } + static void impl(T& c, T2& src, T3& weight) { + c[0][0] = Func::impl(src[0 + src_idx], weight[0][weight_idx], c[0][0]); + c[0][1] = Func::impl(src[1 + src_idx], weight[0][weight_idx], c[0][1]); + c[0][2] = Func::impl(src[2 + src_idx], weight[0][weight_idx], c[0][2]); + 
c[0][3] = Func::impl(src[3 + src_idx], weight[0][weight_idx], c[0][3]); + } +}; + +template +inline void cal_helper(T& c, T2& src, T3& weight, T4& temp) { + ShiftCalHelper::impl( + c, src, weight, temp); +} +template +inline void cal_helper(T& c, T2& src, T3& weight) { + ShiftCalHelper::impl( + c, src, weight); +}; + +template +struct OCHelper { +public: + static const int val = 0; +}; +template <> +struct OCHelper<4> { +public: + static const int val = 1; +}; +template <> +struct OCHelper<8> { +public: + static const int val = 2; +}; + +template +struct KerNeonXXs2NchwNchw44 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc, const Op& op); +}; +/** + * filter shape = (oc/4, ic, 7, 7, 4), first 4 oc is f0 = filter[0, 0, :, :, :] + * calculate sequence \ + * f0[0:1, 0:1, 4] dot4, \ + * f0[0:1, 2:3, 4] dot4, \ + * f0[0:1, 4:5, 4] dot4, \ + * f0[0:1, 6, 4] dot2, \ + * ... + * f0[6, 0:1, 4] dot2, \ + * f0[6, 2:3, 4] dot2, \ + * f0[6, 4:5, 4] dot2, \ + * f0[6, 6, 4] dot1, \ + * look like: + * |---|---|---|-| + * |x x|x x|x x|x| + * |x x|x x|x x|x| + * |---|---|---|-| + * |x x|x x|x x|x| + * |x x|x x|x x|x| + * |---|---|---|-| + * |x x|x x|x x|x| + * |x x|x x|x x|x| + * |---|---|---|-| + * |x x|x x|x x|x| + * |---|---|---|-| + **/ +template +struct KerNeonXXs2NchwNchw44 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc, const Op& op) { + static const uint8_t src_idx_buffer[16] = {0, 8, 0, 8, 0, 8, 0, 8, + 0, 8, 0, 8, 0, 8, 0, 8}; + constexpr int filter_size = 7; + constexpr int ic_step = 1; + constexpr int oc_step = 4; + constexpr int pack_iw_len = 4; + constexpr int fh_step = 2; + constexpr int fh_end = filter_size / fh_step * fh_step; + constexpr int c_dim = OCHelper::val; + + const int ic_stride = ih * iw * pack_iw_len; + const int ld_dot4_weight_oc = oc_step * filter_size * filter_size * ic; + + int32x4_t c[c_dim][4]; + + init_ocx_ow4(c, bias_ptr, oc_step); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { + for (int fh_idx = 0; fh_idx < fh_end; fh_idx += fh_step) { + const int8_t* nchw_src_ptr = + src_ptr + ic_idx * ic_stride + + fh_idx * iw * ic_step * pack_iw_len; + int8x16_t src[6]; + int8x16_t dot4_weight[c_dim][3]; + int16x8_t temp_c[4]; + load_helper<3, 0, 16, c_dim, Vld1q_s8>(dot4_weight, weight_ptr, + ld_dot4_weight_oc); + load_helper<6, 0, 16, 0, Vld1q_s8>(src, nchw_src_ptr, 0); + cal_helper<0, 0, c_dim, Vdotq_s32_h>(c, src, dot4_weight, + temp_c); + cal_helper<1, 1, c_dim, Vdotq_s32_h>(c, src, dot4_weight, + temp_c); + cal_helper<2, 2, c_dim, Vdotq_s32_h>(c, src, dot4_weight, + temp_c); + + int8x8_t src_dot2[4]; + int8x8_t dot2_weight[c_dim][1]; + load_helper<1, 3 * 16, 8, c_dim, Vld1_s8>( + dot2_weight, weight_ptr, ld_dot4_weight_oc); + load_helper<4, 3 * 16, 16, 0, Vld1_s8>(src_dot2, nchw_src_ptr, + 0); + cal_helper<0, 0, c_dim, Vdot2_s32_h>(c, src_dot2, dot2_weight, + temp_c); + weight_ptr += filter_size * pack_iw_len * fh_step; + } + const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride + + 6 * iw * ic_step * pack_iw_len; + + int8x8_t dot2_weight[c_dim][3]; + int16x8_t temp_c[4]; + int8x8_t src_dot2[6]; + uint8x16_t tbl = vld1q_u8(src_idx_buffer); + load_helper<3, 0, 8, c_dim, Vld1_s8>(dot2_weight, weight_ptr, + ld_dot4_weight_oc); + load_helper_x<6, 0, 16, 0, Vldq_tbl_low_s8>(src_dot2, nchw_src_ptr, + 0, tbl); + cal_helper<0, 0, c_dim, Vdot2_s32_h>(c, 
src_dot2, dot2_weight, + temp_c); + cal_helper<1, 1, c_dim, Vdot2_s32_h>(c, src_dot2, dot2_weight, + temp_c); + cal_helper<2, 2, c_dim, Vdot2_s32_h>(c, src_dot2, dot2_weight, + temp_c); + + int16x8_t dot1_weight[c_dim][1]; + int16x8_t src_dot1[4]; + load_helper<1, 3 * 8, 8, c_dim, Vldq_dup_4s8_8s16>( + dot1_weight, weight_ptr, ld_dot4_weight_oc); + load_helper<4, 3 * 16, 16, 0, Vld1_dup_s8_s16>(src_dot1, + nchw_src_ptr, 0); + cal_helper<0, 0, c_dim, Vmlal_s16>(c, src_dot1, dot1_weight); + weight_ptr += filter_size * pack_iw_len; + } + store_ocx_ow4_remain_static(c, op, dst_ptr, ld_dst_oc); + } +}; +template +struct KerNeonXXs2NchwNchw44 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc, const Op& op) { + constexpr int filter_size = 5; + static const uint8_t src_idx_buffer[16] = {0, 8, 0, 8, 0, 8, 0, 8, + 0, 8, 0, 8, 0, 8, 0, 8}; + constexpr int ih_step = 2; + constexpr int ic_step = 1; + constexpr int oc_step = 4; + constexpr int pack_iw_len = 4; + constexpr int fh_end = filter_size / ih_step * ih_step; + + const int ic_stride = ih * iw * pack_iw_len; + const int ld_dot4_weight_oc = oc_step * filter_size * filter_size * ic; + constexpr int c_dim = OCHelper::val; + int32x4_t c[c_dim][4]; + + init_ocx_ow4(c, bias_ptr, oc_step); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { + for (int fh_idx = 0; fh_idx < fh_end; fh_idx += ih_step) { + const int8_t* nchw_src_ptr = + src_ptr + ic_idx * ic_stride + + fh_idx * iw * ic_step * pack_iw_len; + int8x16_t src[5]; + int8x16_t dot4_weight[c_dim][2]; + int16x8_t temp_c[4]; + load_helper<2, 0, 16, c_dim, Vld1q_s8>(dot4_weight, weight_ptr, + ld_dot4_weight_oc); + load_helper<5, 0, 16, 0, Vld1q_s8>(src, nchw_src_ptr, 0); + cal_helper<0, 0, c_dim, Vdotq_s32_h>(c, src, dot4_weight, + temp_c); + cal_helper<1, 1, c_dim, Vdotq_s32_h>(c, src, dot4_weight, + temp_c); + + int8x8_t src_dot2[4]; + int8x8_t dot2_weight[c_dim][1]; + load_helper<1, 2 * 16, 8, c_dim, Vld1_s8>( + dot2_weight, weight_ptr, ld_dot4_weight_oc); + load_helper<4, 2 * 16, 16, 0, Vld1_s8>(src_dot2, nchw_src_ptr, + 0); + cal_helper<0, 0, c_dim, Vdot2_s32_h>(c, src_dot2, dot2_weight, + temp_c); + weight_ptr += filter_size * pack_iw_len * ih_step; + } + const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride + + fh_end * iw * ic_step * pack_iw_len; + + int8x8_t dot2_weight[c_dim][2]; + int16x8_t temp_c[4]; + int8x8_t src_dot2[5]; + uint8x16_t tbl = vld1q_u8(src_idx_buffer); + load_helper<2, 0, 8, c_dim, Vld1_s8>(dot2_weight, weight_ptr, + ld_dot4_weight_oc); + load_helper_x<5, 0, 16, 0, Vldq_tbl_low_s8>(src_dot2, nchw_src_ptr, + 0, tbl); + + cal_helper<0, 0, c_dim, Vdot2_s32_h>(c, src_dot2, dot2_weight, + temp_c); + cal_helper<1, 1, c_dim, Vdot2_s32_h>(c, src_dot2, dot2_weight, + temp_c); + + int16x8_t dot1_weight[c_dim][1]; + int16x8_t src_dot1[4]; + load_helper<1, 2 * 8, 8, c_dim, Vldq_dup_4s8_8s16>( + dot1_weight, weight_ptr, ld_dot4_weight_oc); + load_helper<4, 2 * 16, 16, 0, Vld1_dup_s8_s16>(src_dot1, + nchw_src_ptr, 0); + + cal_helper<0, 0, c_dim, Vmlal_s16>(c, src_dot1, dot1_weight); + weight_ptr += filter_size * pack_iw_len; + } + store_ocx_ow4_remain_static(c, op, dst_ptr, ld_dst_oc); + } +}; +/** + * filter shape = (oc/4, ic, 3, 3, 4), first 4 oc is f0 = filter[0, 0, :, :, :] + * calculate sequence \ + * f0[0:1, 0:1, 4] dot4, \ + * f0[0:1, 2, 4] dot2, \ + * f0[2, 0:1, 4] dot2, \ + * f0[2, 2, 4] dot1 \ + * look like: + * |---|-| + * |x x|x| + * |x x|x| + * |-----| + * |x 
x|x| + * |-----| + **/ +template +struct KerNeonXXs2NchwNchw44 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc, const Op& op) { + constexpr int filter_size = 3; + static const uint8_t src_idx_buffer[16] = {0, 8, 0, 8, 0, 8, 0, 8, + 0, 8, 0, 8, 0, 8, 0, 8}; + constexpr int oc_step = 4; + constexpr int ic_step = 1; + constexpr int loop_ic_step = 1; + constexpr int pack_iw_len = 4; + + const int ic_stride = ih * iw * pack_iw_len; + const int ld_weight_oc = oc_step * filter_size * filter_size * ic; + constexpr int c_dim = OCHelper::val; + + int32x4_t c[c_dim][4]; + init_ocx_ow4(c, bias_ptr, oc_step); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + // first 2 line + { + const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; + int8x16_t src[4]; + int8x16_t dot4_weight[c_dim][1]; + int16x8_t temp_c[4]; + load_helper<1, 0, 16, c_dim, Vld1q_s8>(dot4_weight, weight_ptr, + ld_weight_oc); + load_helper<4, 0, 16, 0, Vld1q_s8>(src, nchw_src_ptr, 0); + cal_helper<0, 0, c_dim, Vdotq_s32_h>(c, src, dot4_weight, + temp_c); + + int8x8_t src_dot2[4]; + int8x8_t dot2_weight[c_dim][1]; + load_helper<1, 1 * 16, 8, c_dim, Vld1_s8>( + dot2_weight, weight_ptr, ld_weight_oc); + load_helper<4, 1 * 16, 16, 0, Vld1_s8>(src_dot2, nchw_src_ptr, + 0); + cal_helper<0, 0, c_dim, Vdot2_s32_h>(c, src_dot2, dot2_weight, + temp_c); + } + // last line + { + const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride + + 2 * iw * ic_step * pack_iw_len; + int16x8_t temp_c[4]; + int8x8_t src_dot2[4]; + int8x8_t dot2_weight[c_dim][1]; + uint8x16_t tbl = vld1q_u8(src_idx_buffer); + load_helper<1, 24, 8, c_dim, Vld1_s8>(dot2_weight, weight_ptr, + ld_weight_oc); + load_helper_x<4, 0, 16, 0, Vldq_tbl_low_s8>( + src_dot2, nchw_src_ptr, 0, tbl); + cal_helper<0, 0, c_dim, Vdot2_s32_h>(c, src_dot2, dot2_weight, + temp_c); + int16x8_t dot1_weight[c_dim][1]; + int16x8_t src_dot1[4]; + load_helper<1, 32, 8, c_dim, Vldq_dup_4s8_8s16>( + dot1_weight, weight_ptr, ld_weight_oc); + load_helper<4, 1 * 16, 16, 0, Vld1_dup_s8_s16>(src_dot1, + nchw_src_ptr, 0); + cal_helper<0, 0, c_dim, Vmlal_s16>(c, src_dot1, dot1_weight); + weight_ptr += filter_size * filter_size * pack_iw_len; + } + } + store_ocx_ow4_remain_static(c, op, dst_ptr, ld_dst_oc); + } +}; + +} // namespace +enum PACK_MODE { NO_PAD = 0, FIRST_PAD = 1, LAST_PAD = 2 }; +template +inline void pack_src_one_line(const int8_t* inptr, int8_t* outptr, int left_pad, + int right_pad, const int iw) { + const int8_t* src_row_0 = inptr; + const int8_t* src_row_1 = inptr + iw; + constexpr int combine_row = 2; + constexpr int iw_step = 16; + constexpr int src_expand = 4; + constexpr int out_gap = iw_step * src_expand; + const int iw_end = iw / iw_step * iw_step; + + memset(outptr, 0, combine_row * left_pad * src_expand * sizeof(int8_t)); + outptr += combine_row * left_pad * src_expand; + + for (int iw_idx = 0; iw_idx < iw_end; iw_idx += iw_step) { + int8x16_t row0 = vld1q_s8(src_row_0 + iw_idx); + int8x16_t row1 = vdupq_n_s8(0); + if (mode == PACK_MODE::NO_PAD) { + row1 = vld1q_s8(src_row_1 + iw_idx); + } else if (mode == PACK_MODE::FIRST_PAD) { + row1 = row0; + row0 = vdupq_n_s8(0); + } + int8x16x2_t pack_rows = vzipq_s8(row0, row1); +#define STORE_8S8(step) \ + vst1_s8(outptr + step * 8, \ + vreinterpret_s8_s16(vdup_laneq_s16( \ + vreinterpretq_s16_s8(pack_rows.val[0]), step))); + + UNROLL_CALL_RAW(8, STORE_8S8); +#undef STORE_8S8 +#define STORE_8S8(step) \ + vst1_s8(outptr + 
out_gap + step * 8, \ + vreinterpret_s8_s16(vdup_laneq_s16( \ + vreinterpretq_s16_s8(pack_rows.val[1]), step))); + + UNROLL_CALL_RAW(8, STORE_8S8); +#undef STORE_8S8 + outptr += out_gap * combine_row; + } + for (int iw_idx = iw_end; iw_idx < iw; iw_idx++) { + int8x8_t row0 = vld1_dup_s8(src_row_0 + iw_idx); + int8x8_t row1 = vdup_n_s8(0); + if (mode == PACK_MODE::NO_PAD) { + row1 = vld1_dup_s8(src_row_1 + iw_idx); + } else if (mode == PACK_MODE::FIRST_PAD) { + row1 = row0; + row0 = vdup_n_s8(0); + } + int8x8x2_t pack_rows = vzip_s8(row0, row1); + vst1_s8(outptr, pack_rows.val[0]); + outptr += src_expand * combine_row; + } + memset(outptr, 0, combine_row * right_pad * src_expand * sizeof(int8_t)); + outptr += combine_row * right_pad * src_expand; +} +/** + * pack (ic, h, w) to (ic, h / 2, 2 * w) + * pack interleave two adjacent row in src and repeat 4 times, store to one row + * */ +void conv_bias::pack_nchw_src_for_nchw44_conv( + const int8_t* inptr, int8_t* outptr, const int ic, const int top_pad, + const int bottom_pad, const int left_pad, const int right_pad, + const int ih, const int iw) { + constexpr int src_expand = 4; + constexpr int oh_step = 2; + const int oh = ih + top_pad + bottom_pad; + const int oh_end = div_floor(ih + top_pad, oh_step) * oh_step; + const int ow = (iw + left_pad + right_pad) * src_expand; + + for (int ic_idx = 0; ic_idx < ic; ++ic_idx) { + int oh_idx = 0; + for (; oh_idx < top_pad; oh_idx += oh_step) { + if (top_pad - oh_idx >= oh_step) { + memset(outptr, 0, oh_step * ow * sizeof(int8_t)); + } else { + pack_src_one_line(inptr, outptr, left_pad, + right_pad, iw); + inptr += iw; + } + outptr += oh_step * ow; + } + + for (; oh_idx < oh_end; oh_idx += oh_step) { + pack_src_one_line(inptr, outptr, left_pad, + right_pad, iw); + inptr += oh_step * iw; + outptr += oh_step * ow; + } + + for (; oh_idx < oh; oh_idx += oh_step) { + const int last_pad = oh_idx - ih - top_pad; + if (last_pad >= 0) { + memset(outptr, 0, oh_step * ow * sizeof(int8_t)); + } else { + pack_src_one_line(inptr, outptr, left_pad, + right_pad, iw); + inptr += iw; + } + outptr += oh_step * ow; + } + } +} + +/** + * pack {oc / 4, fh, fw, ic, 4(oc)} to {oc / 4, ic, fh * fw, 4(oc)} + * pack interleave two adjacent row in filter to one row + * */ +void conv_bias::pack_nchw44_weight_for_nchw_conv(const int8_t* inptr, + int8_t* outptr, const int ic, + const int fh, const int fw, + const int oc) { + constexpr int oc_step = 4; + constexpr int ic_step = 2; + constexpr int fh_step = 2; + constexpr int fw_step = 2; + const int ic_end = ic / ic_step * ic_step; + const int ic_remain = ic - ic_end; + const int fh_end = fh / fh_step * fh_step; + const int fh_remain = fh - fh_end; + const int fw_end = fw / fw_step * fw_step; + const int fw_remain = fw - fw_end; + const int filter_stride = ic * oc_step; + static const uint8_t ic2_idx_h_buffer[16] = {0, 8, 1, 9, 2, 10, 3, 11, + 4, 12, 5, 13, 6, 14, 7, 15}; + uint8x16_t ic2_idx_h = vld1q_u8(ic2_idx_h_buffer); + for (int oc_idx = 0; oc_idx < oc; oc_idx += oc_step) { + for (int ic_idx = 0; ic_idx < ic_end; ic_idx += ic_step) { + const int ic_offset = ic_idx * oc_step; + int8_t* output_ic0 = outptr + ic_idx * fh * fw * oc_step; + int8_t* output_ic1 = output_ic0 + fh * fw * oc_step; + for (int fh_idx = 0; fh_idx < fh_end; fh_idx += fh_step) { + const int fh_offset = fh_idx * fw * filter_stride; + for (int fw_idx = 0; fw_idx < fw; ++fw_idx) { + const int8_t* filter_ptr = inptr + fh_offset + + fw_idx * filter_stride + + ic_offset; + int8x8_t row_0 = vld1_s8(filter_ptr); + 
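To make the interleave in this loop easier to follow, here is a scalar reference of the layout it appears to produce on the main (even fh, even ic) path: for every output channel of the 4-oc block, the weights of two adjacent filter rows are placed next to each other, so they line up with the row-interleaved source produced by pack_nchw_src_for_nchw44_conv. This is a hedged sketch of one reading of the vcombine/vqtbl1q sequence, not part of the library; pack_weight_rowpair_ref is a hypothetical name and the fh/fw/ic remainder tails handled below are omitted.

#include <cstdint>

// Scalar reference for one 4-oc block of the filter:
//   in : {fh, fw, ic, 4(oc)}           raw filter block
//   out: {ic, fh/2, fw, 4(oc), 2(row)} two adjacent rows interleaved per oc
void pack_weight_rowpair_ref(const int8_t* in, int8_t* out,
                             int fh, int fw, int ic) {
    for (int c = 0; c < ic; ++c)
        for (int h = 0; h < fh; h += 2)
            for (int w = 0; w < fw; ++w)
                for (int oc4 = 0; oc4 < 4; ++oc4)
                    for (int r = 0; r < 2; ++r)
                        *out++ = in[((h + r) * fw + w) * ic * 4 + c * 4 + oc4];
}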
int8x8_t row_1 = vld1_s8(filter_ptr + fw * filter_stride); + int8x16_t combine_row = vcombine_s8(row_0, row_1); + combine_row = vqtbl1q_s8(combine_row, ic2_idx_h); + vst1_s8(output_ic0, vget_low_s8(combine_row)); + vst1_s8(output_ic1, vget_high_s8(combine_row)); + output_ic0 += 8; + output_ic1 += 8; + } + } + if (fh_remain > 0) { + const int fh_offset = fh_end * fw * filter_stride; + for (int fw_idx = 0; fw_idx < fw_end; fw_idx += fw_step) { + const int8_t* filter_ptr = inptr + fh_offset + + fw_idx * filter_stride + + ic_offset; + int8x8_t row_0 = vld1_s8(filter_ptr); + int8x8_t row_1 = vld1_s8(filter_ptr + filter_stride); + int8x16_t combine_row = vcombine_s8(row_0, row_1); + combine_row = vqtbl1q_s8(combine_row, ic2_idx_h); + vst1_s8(output_ic0, vget_low_s8(combine_row)); + vst1_s8(output_ic1, vget_high_s8(combine_row)); + output_ic0 += 8; + output_ic1 += 8; + } + if (fw_remain > 0) { + const int8_t* filter_ptr = inptr + fh_offset + + fw_end * filter_stride + + ic_offset; + int8x8_t row_0 = vld1_s8(filter_ptr); + vst1_lane_s32((int32_t*)output_ic0, + vreinterpret_s32_s8(row_0), 0); + vst1_lane_s32((int32_t*)output_ic1, + vreinterpret_s32_s8(row_0), 1); + output_ic0 += 4; + output_ic1 += 4; + } + } + } + if (ic_remain > 0) { + const int ic_offset = ic_end * oc_step; + int8_t* output_ic0 = outptr + ic_end * fh * fw * oc_step; + for (int fh_idx = 0; fh_idx < fh_end; fh_idx += fh_step) { + const int fh_offset = fh_idx * fw * filter_stride; + for (int fw_idx = 0; fw_idx < fw; ++fw_idx) { + const int8_t* filter_ptr = inptr + fh_offset + + fw_idx * filter_stride + + ic_offset; + int8x8_t row_0 = vreinterpret_s8_s32( + vld1_dup_s32((const int32_t*)(filter_ptr))); + int8x8_t row_1 = vreinterpret_s8_s32(vld1_dup_s32( + (const int32_t*)(filter_ptr + fw * filter_stride))); + int8x16_t combine_row = vcombine_s8(row_0, row_1); + combine_row = vqtbl1q_s8(combine_row, ic2_idx_h); + vst1_s8(output_ic0, vget_low_s8(combine_row)); + output_ic0 += 8; + } + } + if (fh_remain > 0) { + const int fh_offset = fh_end * fw * filter_stride; + for (int fw_idx = 0; fw_idx < fw_end; fw_idx += fw_step) { + const int8_t* filter_ptr = inptr + fh_offset + + fw_idx * filter_stride + + ic_offset; + int8x8_t row_0 = vreinterpret_s8_s32( + vld1_dup_s32((const int32_t*)(filter_ptr))); + int8x8_t row_1 = vreinterpret_s8_s32(vld1_dup_s32( + (const int32_t*)(filter_ptr + filter_stride))); + int8x16_t combine_row = vcombine_s8(row_0, row_1); + combine_row = vqtbl1q_s8(combine_row, ic2_idx_h); + vst1_s8(output_ic0, vget_low_s8(combine_row)); + output_ic0 += 8; + } + if (fw_remain > 0) { + const int8_t* filter_ptr = inptr + fh_offset + + fw_end * filter_stride + + ic_offset; + *(int32_t*)(output_ic0) = *(const int32_t*)(filter_ptr); + output_ic0 += 4; + } + } + } + inptr += oc_step * fh * fw * ic; + outptr += oc_step * fh * fw * ic; + } +} + +template +static void conv_direct_stride2_int8_nchw_nchw44( + const int8_t* src, const int8_t* filter, const int32_t* bias, + int32_t* temp, int8_t* dst, const size_t oc, const size_t ic, + const size_t ih, const size_t iw, const size_t oh, const size_t ow, + const Op& op) { + constexpr size_t fh = filter_size; + constexpr size_t fw = filter_size; + constexpr size_t ic_step = 1; + constexpr size_t big_oc_step = 8; + constexpr size_t oc_step = 4; + constexpr size_t ih_step = 2; + constexpr size_t oh_step = 1; + constexpr size_t ow_step = 4; + constexpr size_t stride_h = 2; + constexpr size_t stride_w = 2; + constexpr int pack_iw_len = 4; + + const size_t img_stride = oh * ow; + const size_t 
ow_end = ow / ow_step * ow_step; + const size_t ow_remain = ow - ow_end; + const size_t oc_end = oc / big_oc_step * big_oc_step; + const size_t oc_remain = oc - oc_end; + const int ld_dst_oc = oc_step * img_stride; + + using remain_fun = + std::function; + remain_fun kern_big_oc_remain = nullptr; + remain_fun kern_small_oc_remain = nullptr; + switch (ow_remain) { +#define cb(step) \ + case step: \ + kern_big_oc_remain = \ + KerNeonXXs2NchwNchw44::impl; \ + kern_small_oc_remain = \ + KerNeonXXs2NchwNchw44::impl; \ + break; + + UNROLL_CALL_RAW(4, cb); + default: + megdnn_assert(0, "no remain %zu for kern", ow_remain); + } + + for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) { + const size_t weight_offset = oc_idx * ic * fh * fw; + for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { + for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const size_t src_offset = + (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) * + ic_step * pack_iw_len; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; + KerNeonXXs2NchwNchw44::impl(src + src_offset, + filter + weight_offset, + bias + oc_idx, + dst + dst_offset, ic, + ih, iw, ld_dst_oc, op); + } + if (ow_remain > 0) { + const size_t src_offset = + (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) * + ic_step * pack_iw_len; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; + kern_big_oc_remain(src + src_offset, filter + weight_offset, + bias + oc_idx, dst + dst_offset, ic, ih, iw, + ld_dst_oc, op); + } + } + } + if (oc_remain > 0) { + size_t oc_idx = oc_end; + const size_t weight_offset = oc_idx * ic * fh * fw; + for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { + for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const size_t src_offset = + (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) * + ic_step * pack_iw_len; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; + KerNeonXXs2NchwNchw44::impl(src + src_offset, + filter + weight_offset, + bias + oc_idx, + dst + dst_offset, ic, ih, + iw, ld_dst_oc, op); + } + if (ow_remain > 0) { + const size_t src_offset = + (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) * + ic_step * pack_iw_len; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; + kern_small_oc_remain(src + src_offset, filter + weight_offset, + bias + oc_idx, dst + dst_offset, ic, ih, + iw, ld_dst_oc, op); + } + } + } +} +#define CONSTRUCT_FUNC(filter_size) \ + template \ + void conv_bias:: \ + conv_direct_stride2_##filter_size##x##filter_size##_int8_nchw_nchw44( \ + const int8_t* src, const int8_t* filter, \ + const int32_t* bias, int32_t* temp, int8_t* dst, \ + const size_t oc, const size_t ic, const size_t ih, \ + const size_t iw, const size_t oh, const size_t ow, \ + const Op& op) { \ + conv_direct_stride2_int8_nchw_nchw44( \ + src, filter, bias, temp, dst, oc, ic, ih, iw, oh, ow, op); \ + } + +CONSTRUCT_FUNC(3); +CONSTRUCT_FUNC(5); +CONSTRUCT_FUNC(7); +#undef CONSTRUCT_FUNC + +template +void conv_bias::conv_direct_stride2_2x2_int8_nchw_nchw44( + const int8_t* src, const int8_t* filter, const int32_t* bias, + int32_t* temp, int8_t* dst, const size_t oc, const size_t ic, + const size_t ih, const size_t iw, const size_t oh, const size_t ow, + const Op& op) { + megdnn_assert(0, "not imple nchw_nchw44 2x2s2 conv"); +} + +#define INSTANTIATION(stride, i, bias, Op) \ + template void conv_bias:: \ + 
conv_direct_##stride##_##i##x##i##_int8_nchw_nchw44( \ + const int8_t*, const int8_t*, const int32_t*, int32_t*, \ + int8_t*, const size_t, const size_t, const size_t, \ + const size_t, const size_t, const size_t, const Op&); + +#define FOR_OP(stride, i, bias) \ + INSTANTIATION(stride, i, bias, TypeCvtOp) \ + INSTANTIATION(stride, i, bias, ReluOp) \ + INSTANTIATION(stride, i, bias, HSwishOp) + +#define FOR_BIAS(stride, i) \ + FOR_OP(stride, i, BiasMode::NO_BIAS) \ + FOR_OP(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS) + +#define FOR_FILTER(stride) \ + FOR_BIAS(stride, 2) \ + FOR_BIAS(stride, 3) \ + FOR_BIAS(stride, 5) \ + FOR_BIAS(stride, 7) + +FOR_FILTER(stride2) + +#undef FOR_STRIDE +#undef FOR_FILTER +#undef FOR_IC +#undef FOR_BIAS +#undef FOR_NONLINEAR +#undef FOR_REMAIN +#undef INSTANTIATION diff --git a/dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_kern.h b/dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_kern.h new file mode 100644 index 00000000..a0f65a65 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_kern.h @@ -0,0 +1,44 @@ +/** + * \file dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_kern.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "src/arm_common/conv_bias/opr_impl.h" +#include "src/fallback/conv_bias/common.h" + +namespace megdnn { +namespace arm_common { +namespace conv_bias { +#define KERN(stride, i, layout) \ + template \ + void conv_direct_##stride##_##i##x##i##_int8_nchw_##layout( \ + const int8_t* src, const int8_t* filter, const int32_t* bias, \ + int32_t* temp, int8_t* dst, const size_t OC, const size_t IC, \ + const size_t IH, const size_t IW, const size_t OH, \ + const size_t OW, const Op& op); + +KERN(stride2, 2, nchw44) +KERN(stride2, 3, nchw44) +KERN(stride2, 5, nchw44) +KERN(stride2, 7, nchw44) +#undef KERN + +void pack_nchw44_weight_for_nchw_conv(const int8_t* inptr, int8_t* outptr, + const int ic, const int fh, const int fw, + const int oc); + +void pack_nchw_src_for_nchw44_conv(const int8_t* inptr, int8_t* outptr, + const int ic, const int top_pad, + const int bottom_pad, const int left_pad, + const int right_pad, const int ih, + const int iw); +} // namespace conv_bias +} // namespace arm_common +} // namespace megdnn \ No newline at end of file diff --git a/dnn/src/arm_common/conv_bias/int8/helper.h b/dnn/src/arm_common/conv_bias/int8/helper.h new file mode 100644 index 00000000..965b36d1 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/helper.h @@ -0,0 +1,33 @@ +/** + * \file dnn/src/arm_common/conv_bias/int8/helper.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#pragma once +#include "src/common/unroll_macro.h" + +#define MATRIX_MUL4x4(sum, a, b) \ + sum##0 = vmla_lane_s16(sum##0, b##0, a##0, 0); \ + sum##0 = vmla_lane_s16(sum##0, b##1, a##0, 1); \ + sum##0 = vmla_lane_s16(sum##0, b##2, a##0, 2); \ + sum##0 = vmla_lane_s16(sum##0, b##3, a##0, 3); \ + sum##1 = vmla_lane_s16(sum##1, b##0, a##1, 0); \ + sum##1 = vmla_lane_s16(sum##1, b##1, a##1, 1); \ + sum##1 = vmla_lane_s16(sum##1, b##2, a##1, 2); \ + sum##1 = vmla_lane_s16(sum##1, b##3, a##1, 3); \ + sum##2 = vmla_lane_s16(sum##2, b##0, a##2, 0); \ + sum##2 = vmla_lane_s16(sum##2, b##1, a##2, 1); \ + sum##2 = vmla_lane_s16(sum##2, b##2, a##2, 2); \ + sum##2 = vmla_lane_s16(sum##2, b##3, a##2, 3); \ + sum##3 = vmla_lane_s16(sum##3, b##0, a##3, 0); \ + sum##3 = vmla_lane_s16(sum##3, b##1, a##3, 1); \ + sum##3 = vmla_lane_s16(sum##3, b##2, a##3, 2); \ + sum##3 = vmla_lane_s16(sum##3, b##3, a##3, 3); + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8/strategy.h b/dnn/src/arm_common/conv_bias/int8/strategy.h new file mode 100644 index 00000000..717f99da --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/strategy.h @@ -0,0 +1,27 @@ +/** + * \file dnn/src/arm_common/conv_bias/int8/strategy.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include "src/arm_common/conv_bias/postprocess_helper.h" +#include "src/fallback/conv_bias/winograd/winograd.h" + +namespace megdnn { +namespace arm_common { +namespace winograd { + +MEGDNN_REG_WINOGRAD_STRATEGY(int8_t, int8_t, int16_t, int, 2, 3, 8, 8, + winograd_2x3_8x8_s8) +} +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8/strategy_2x3_8x8.cpp b/dnn/src/arm_common/conv_bias/int8/strategy_2x3_8x8.cpp new file mode 100644 index 00000000..d8efdc03 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/strategy_2x3_8x8.cpp @@ -0,0 +1,425 @@ +/** + * \file dnn/src/arm_common/conv_bias/int8/strategy_2x3_8x8.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#include "src/fallback/conv_bias/winograd/winograd.h" +#include "src/naive/matrix_mul/matrix_mul_helper.h" + +#include "src/arm_common/conv_bias/winograd_common/winograd_common.h" +#include "src/arm_common/elemwise_helper/op_unary.h" +#include "src/arm_common/conv_bias/int8/strategy.h" +#include "src/arm_common/conv_bias/int8/helper.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/arm_common/utils.h" + +#include "src/common/winograd/winograd_generator.h" +#include "src/common/unroll_macro.h" +#include "src/common/utils.h" +#include "midout.h" +MIDOUT_DECL(megdnn_arm_common_winograd_s8_F23_8x8) + +using namespace megdnn; +using namespace arm_common; + +namespace { +void transpose_8x4(const int16_t* src, int16_t* dst, int lda, int ldb) { + int16x4x2_t a0, a1, a2, a3; + a0.val[0] = vld1_s16(src + 0 * lda); + a0.val[1] = vld1_s16(src + 1 * lda); + a1.val[0] = vld1_s16(src + 2 * lda); + a1.val[1] = vld1_s16(src + 3 * lda); + a2.val[0] = vld1_s16(src + 4 * lda); + a2.val[1] = vld1_s16(src + 5 * lda); + a3.val[0] = vld1_s16(src + 6 * lda); + a3.val[1] = vld1_s16(src + 7 * lda); + int16x4x2_t b0 = vzip_s16(a0.val[0], a1.val[0]); + int16x4x2_t b1 = vzip_s16(a0.val[1], a1.val[1]); + int16x4x2_t b2 = vzip_s16(a2.val[0], a3.val[0]); + int16x4x2_t b3 = vzip_s16(a2.val[1], a3.val[1]); + + int16x4x2_t c0 = vzip_s16(b0.val[0], b1.val[0]); + int16x4x2_t c1 = vzip_s16(b0.val[1], b1.val[1]); + int16x4x2_t c2 = vzip_s16(b2.val[0], b3.val[0]); + int16x4x2_t c3 = vzip_s16(b2.val[1], b3.val[1]); + + vst1_s16(dst + 0 * ldb, c0.val[0]); + vst1_s16(dst + 1 * ldb, c2.val[0]); + vst1_s16(dst + 2 * ldb, c0.val[1]); + vst1_s16(dst + 3 * ldb, c2.val[1]); + vst1_s16(dst + 4 * ldb, c1.val[0]); + vst1_s16(dst + 5 * ldb, c3.val[0]); + vst1_s16(dst + 6 * ldb, c1.val[1]); + vst1_s16(dst + 7 * ldb, c3.val[1]); +} + +struct FilterTransform2X3_qs8 { + static void transform(const int8_t* filter_ptr, int16_t* filter_transform_buf, + int16_t* transform_mid_buf, size_t OC, size_t IC, + size_t oc_start, size_t oc_end) { + constexpr int alpha = 2 + 3 - 1; + //! 
G * g * GT + int16x4_t g0{2, 0, 0, 0}, g1{1, 1, 1, 0}, g2{1, -1, 1, 0}, + g3{0, 0, 2, 0}; + int16x4_t gt0{2, 1, 1, 0}, gt1{0, 1, -1, 0}, gt2{0, 1, 1, 2}, + gt3{0, 0, 0, 0}; + + size_t OCB = OC / 8; + size_t ICB = IC / 8; + +#define get_v_general \ + InputGetter getter; \ + int16x4_t v0 = getter(filter); \ + int16x4_t v1 = getter(filter + 3); \ + int16x4_t v2 = getter(filter + 6); \ + int16x4_t v3 = vdup_n_s16(0); \ + /*To avoid the unaligned opcode error on tx1.*/ \ + vset_lane_s16_fix_tx1(0, v0, 3); \ + vset_lane_s16_fix_tx1(0, v1, 3); \ + vset_lane_s16_fix_tx1(0, v2, 3); + +#define get_v_searal \ + /* To avoid the bus error on armv7(mi9).*/ \ + int8x8_t s0 = {filter[0], filter[1], filter[2], 0, 0, 0, 0, 0}; \ + int8x8_t s1 = {filter[3], filter[4], filter[5], 0, 0, 0, 0, 0}; \ + int8x8_t s2 = {filter[6], filter[7], filter[8], 0, 0, 0, 0, 0}; \ + int16x4_t v0 = vget_low_s16(vmovl_s8(s0)); \ + int16x4_t v1 = vget_low_s16(vmovl_s8(s1)); \ + int16x4_t v2 = vget_low_s16(vmovl_s8(s2)); \ + int16x4_t v3 = vdup_n_s16(0); + +#define cb(oc, ic, get_v) \ + get_v int16x4_t vsum0 = vdup_n_s16(0), vsum1 = vdup_n_s16(0), \ + vsum2 = vdup_n_s16(0), vsum3 = vdup_n_s16(0); \ + MATRIX_MUL4x4(vsum, g, v); \ + int16x4_t vres0 = vdup_n_s16(0), vres1 = vdup_n_s16(0), \ + vres2 = vdup_n_s16(0), vres3 = vdup_n_s16(0); \ + MATRIX_MUL4x4(vres, vsum, gt); \ + vst1_s16(transform_mid_buf, vres0); \ + vst1_s16(transform_mid_buf + 4, vres1); \ + vst1_s16(transform_mid_buf + 8, vres2); \ + vst1_s16(transform_mid_buf + 12, vres3); \ + size_t ocb = (oc) / 8; \ + size_t oc8 = (oc) % 8; \ + size_t icb = (ic) / 8; \ + size_t ic8 = (ic) % 8; \ + rep(i, alpha) rep(j, alpha) { \ + filter_transform_buf[(i * alpha + j) * OCB * ICB * 8 * 8 + \ + ocb * ICB * 8 * 8 + icb * 8 * 8 + ic8 * 8 + \ + oc8] = transform_mid_buf[i * alpha + j]; \ + } \ + filter += 9; + + for (size_t oc = oc_start; oc < oc_end; oc++) { + const int8_t* filter = filter_ptr + oc * IC * 3 * 3; + if (oc != OC - 1) { + rep(ic, IC) { + /** + * origin: (4x3) * (3 x 3) * (3 x 4) + * pack to G and g to times of 4 + * now: (4x4) * (4 x 4) * (4 x 4) + */ + //! 1 0 0 0 v00 v01 v02 0 1 0.5 0.5 0 + //! 0.5 0.5 0.5 0 v10 v11 v12 0 0 0.5 -0.5 0 + //! 0.5 -0.5 0.5 0 v20 v21 v22 0 0 0.5 0.5 1 + //! 
0 0 1 0 0 0 0 0 0 0 0 0 + cb(oc, ic, get_v_general); + } + } else { + rep(ic, IC - 1) { + cb(OC - 1, ic, get_v_general); + } + cb(OC - 1, IC - 1, get_v_searal); + } + } +#undef cb +#undef get_v_general +#undef get_v_searal + } +}; + +struct InputTransform2X3_qs8 { + template + static void prepare(const int8_t* input, int16_t* patch, int16_t* patchT, + int ih_start, int iw_start, size_t IH, size_t IW, + size_t ic, size_t IC) { + constexpr size_t alpha = 2 + 3 - 1; + if (!(inner && ic + 8 < IC)) { + memset(patch, 0, sizeof(int16_t) * 8 * alpha * alpha); + } + if (inner) { + const int8_t* input_ptr = + input + ic * IH * IW + ih_start * IW + iw_start; + InputGetter getter; + for (size_t ico = 0; ico < 8; ++ico) { + if (ic + ico < IC) { + int16x4_t v0 = getter(input_ptr); + int16x4_t v1 = getter(input_ptr + IW); + int16x4_t v2 = getter(input_ptr + IW * 2); + int16x4_t v3 = getter(input_ptr + IW * 3); + + vst1_s16(patch + (ico * 4 * alpha + 0 * 4), v0); + vst1_s16(patch + (ico * 4 * alpha + 1 * 4), v1); + vst1_s16(patch + (ico * 4 * alpha + 2 * 4), v2); + vst1_s16(patch + (ico * 4 * alpha + 3 * 4), v3); + input_ptr += IH * IW; + } + } + } else { + int ih0_act = std::max(ih_start, 0), + ih1_act = std::min(ih_start + alpha, IH), + iw0_act = std::max(iw_start, 0), + iw1_act = std::min(iw_start + alpha, IW); + // partial copy + for (size_t ico = 0; ico < 8; ++ico) { + if (ic + ico < IC) { + for (int ih = ih0_act; ih < ih1_act; ++ih) { + for (int iw = iw0_act; iw < iw1_act; ++iw) { + size_t iho = ih - ih_start, iwo = iw - iw_start; + patch[ico * alpha * alpha + iho * alpha + iwo] = + static_cast( + input[(ic + ico) * IH * IW + + ih * IW + iw]); + } + } + } + } + } + transpose_8x4(patch + 4 * 0, patchT + 32 * 0, 16, 4); + transpose_8x4(patch + 4 * 1, patchT + 32 * 1, 16, 4); + transpose_8x4(patch + 4 * 2, patchT + 32 * 2, 16, 4); + transpose_8x4(patch + 4 * 3, patchT + 32 * 3, 16, 4); + } + + static void transform(const int16_t* patchT, int16_t* input_transform_buf, + size_t unit_idx, size_t nr_units_in_tile, size_t ic, + size_t IC) { + constexpr size_t alpha = 2 + 3 - 1; + // BT * d * B +#define cb(m, n) \ + Vector d##m##n = \ + Vector::load(patchT + 8 * (m * 4 + n)); + + UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); +#undef cb + + //! 1 0 -1 0 d00 d01 d02 d03 1 0 0 0 + //! 0 1 1 0 d10 d11 d12 d13 0 1 -1 -1 + //! 0 -1 1 0 d20 d21 d22 d23 -1 1 1 0 + //! 
0 -1 0 1 d30 d31 d32 d33 0 0 0 1 +#define cb(m) \ + auto t0##m = d0##m - d2##m; \ + auto t1##m = d1##m + d2##m; \ + auto t2##m = d2##m - d1##m; \ + auto t3##m = d3##m - d1##m; + + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + +#define cb(m) \ + d##m##0 = t##m##0 - t##m##2; \ + d##m##1 = t##m##1 + t##m##2; \ + d##m##2 = t##m##2 - t##m##1; \ + d##m##3 = t##m##3 - t##m##1; + + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + + size_t ICB = IC / 8; + size_t icb = ic / 8; +#define cb(m, n) \ + d##m##n.save(input_transform_buf + \ + (m * alpha + n) * nr_units_in_tile * ICB * 8 + \ + icb * nr_units_in_tile * 8 + unit_idx * 8); + UNROLL_CALL_NOWRAPPER_D2(4, 4, cb) +#undef cb + } +}; + +template +struct OutputTransform2X3_qs8 { + static void transform(const int32_t* output_transform_buf, + const int32_t* bias, int8_t* output, + int32_t* transform_mid_buf, size_t oh_start, + size_t ow_start, size_t OH, size_t OW, + size_t oc_start, size_t oc_end, size_t oc_index, + size_t unit_idx, size_t nr_units_in_tile, + const DType& src_dtype, const DType& filter_dtype, + const DType& dst_dtype) { + float scale_filter = 0.f; + if (filter_dtype.enumv() == DTypeEnum::QuantizedS8) { + scale_filter = filter_dtype.param().scale; + } else { + megdnn_assert(filter_dtype.enumv() == DTypeEnum::QuantizedS16); + scale_filter = filter_dtype.param().scale; + } + float input_filter_scale = + src_dtype.param().scale * scale_filter; + DType buffer_dtype = dtype::QuantizedS32(input_filter_scale * 0.5f * + 0.5f * 1.0f * 1.0f); + Op op(buffer_dtype, dst_dtype); + //! AT * m * A + constexpr size_t alpha = 2 + 3 - 1; + + size_t oc = oc_start + oc_index; + size_t OCB = (oc_end - oc_start) / 8; + size_t ocb = oc_index / 8; + +#define cb(m, n) \ + auto v##m##n = Vector::load( \ + output_transform_buf + \ + (m * alpha + n) * OCB * nr_units_in_tile * 8 + \ + ocb * nr_units_in_tile * 8 + unit_idx * 8); + UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); +#undef cb + //! 1 1 1 0 v00 v01 v02 v03 1 0 + //! 0 1 -1 1 v10 v11 v12 v13 1 1 + //! v20 v21 v22 v23 1 -1 + //! 
v30 v31 v32 v33 0 1 +#define cb(m) \ + auto t0##m = v0##m + v1##m + v2##m; \ + auto t1##m = v1##m - v2##m + v3##m; + + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + v00 = t00 + t01 + t02; + v10 = t10 + t11 + t12; + v01 = t01 - t02 + t03; + v11 = t11 - t12 + t13; + + if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + Vector vbias; + vbias = Vector::load(bias + oc) * (2 * 2); + + v00 += vbias; + v10 += vbias; + v01 += vbias; + v11 += vbias; + } + + v00.save(transform_mid_buf + (0 * 2 + 0) * 8); + v10.save(transform_mid_buf + (1 * 2 + 0) * 8); + v01.save(transform_mid_buf + (0 * 2 + 1) * 8); + v11.save(transform_mid_buf + (1 * 2 + 1) * 8); + + for (size_t oco = 0; oco < 8 && oc + oco < oc_end; ++oco) { + for (size_t oho = 0; oho < 2 && oh_start + oho < OH; ++oho) { + for (size_t owo = 0; owo < 2 && ow_start + owo < OW; ++owo) { + dt_qint8 res_int8 = dt_qint8(0); + size_t oh = oh_start + oho; + size_t ow = ow_start + owo; + int32_t res = + transform_mid_buf[oho * 2 * 8 + owo * 8 + oco]; + if (bmode == BiasMode::BIAS) { + res += bias[(oc + oco) * OH * OW + oh * OW + ow] * 2 * + 2; + } + res_int8 = op(dt_qint32(res)); + output[(oc + oco) * OH * OW + oh * OW + ow] = + res_int8.as_int8(); + } + } + } + } +}; +} // namespace + +namespace megdnn { +namespace arm_common { +namespace winograd { + +MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_2x3_8x8_s8) + +void winograd_2x3_8x8_s8::filter(const int8_t* filter, + int16_t* filter_transform_buf, + int16_t* transform_mid_buf, size_t OC, + size_t IC, size_t oc_start, size_t oc_end) { + FilterTransform2X3_qs8::transform(filter, filter_transform_buf, + transform_mid_buf, OC, IC, oc_start, + oc_end); +} + +void winograd_2x3_8x8_s8::input(const int8_t* input, + int16_t* input_transform_buf, + int16_t* transform_mid_buf, size_t IH, + size_t IW, size_t IC, size_t PH, size_t PW, + size_t unit_start_idx, + size_t nr_units_in_tile) { + megdnn_assert(IC % 8 == 0); + constexpr int alpha = 3 + 2 - 1; + + // OW = IW + 2 * PW - KERNEL_SIZE + 1 + auto units_w = div_ceil(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); + int16_t* patch = transform_mid_buf; + int16_t* patchT = transform_mid_buf + 8 * alpha * alpha; + + for (size_t ic = 0; ic < IC; ic += 8) { + rep(unit_idx, nr_units_in_tile) { + size_t index = unit_start_idx + unit_idx; + size_t nh = index / units_w; + size_t nw = index % units_w; + int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; + int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; + if (ih_start >= 0 && ih_start + alpha <= static_cast(IH) && + iw_start >= 0 && iw_start + alpha <= static_cast(IW)) { + InputTransform2X3_qs8::prepare(input, patch, patchT, + ih_start, iw_start, IH, IW, + ic, IC); + InputTransform2X3_qs8::transform(patchT, input_transform_buf, + unit_idx, nr_units_in_tile, ic, + IC); + + } else { + InputTransform2X3_qs8::prepare(input, patch, patchT, + ih_start, iw_start, IH, + IW, ic, IC); + InputTransform2X3_qs8::transform(patchT, input_transform_buf, + unit_idx, nr_units_in_tile, ic, + IC); + } + } + } +} + +void winograd_2x3_8x8_s8::output(const int* output_transform_buf, + const int* bias, int8_t* output, + int* transform_mid_buf, BiasMode bmode, + NonlineMode nonline_mode, size_t OH, size_t OW, + size_t oc_start, size_t oc_end, + size_t unit_start_idx, + size_t nr_units_in_tile) { +#define cb(_bmode, _nonline_op, ...) 
\ + OutputTransform2X3_qs8<_bmode MEGDNN_COMMA _nonline_op>::transform( \ + __VA_ARGS__); + + auto units_w = div_ceil(OW, OUTPUT_BLOCK_SIZE); + + for (size_t oc = oc_start; oc < oc_end; oc += 8) { + size_t oc_index = oc - oc_start; + rep(unit_idx, nr_units_in_tile) { + size_t index = unit_start_idx + unit_idx; + auto nh = index / units_w; + auto nw = index % units_w; + size_t oh_start = nh * OUTPUT_BLOCK_SIZE; + size_t ow_start = nw * OUTPUT_BLOCK_SIZE; + DISPATCH_CONV_WINOGRAD_BIAS_QUANTIZED( + megdnn_arm_common_winograd_s8_F23_8x8, cb, dt_qint32, dt_qint8, + bmode, nonline_mode, output_transform_buf, bias, output, + transform_mid_buf, oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, + unit_idx, nr_units_in_tile, src_dtype, filter_dtype, dst_dtype); + } + } +#undef cb +} + +} // namespace winograd +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8/stride1.cpp b/dnn/src/arm_common/conv_bias/int8/stride1.cpp new file mode 100644 index 00000000..43ab1ff4 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/stride1.cpp @@ -0,0 +1,360 @@ +/** + * \file dnn/src/arm_common/conv_bias/int8/stride1.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "src/arm_common/conv_bias/int8/stride1.h" +#include "megdnn/oprs.h" +#include "src/arm_common/conv_bias/int8/direct.h" +#include "src/arm_common/conv_bias/int8/strategy.h" +#include "src/arm_common/elemwise_op.h" +#include "src/common/opr_delegate.h" + +using namespace megdnn; +using namespace arm_common; +using namespace direct_int8_stride1; + +namespace { +bool need_dst_copy( + const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { + return param.osz[1] % 8; +} +bool need_src_copy( + const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { + if (param.filter_meta.padding[0] || param.filter_meta.padding[1]) { + return true; + } + return need_dst_copy(param); +} +void get_rectified_size( + const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, + size_t& IH2, size_t& IW2, size_t& OH2, size_t& OW2) { + auto&& fm = param.filter_meta; + auto SW = fm.stride[1]; + auto OH = param.osz[0]; + auto OW = param.osz[1]; + auto FH = fm.spatial[0]; + auto FW = fm.spatial[1]; + + OH2 = OH; + OW2 = (OW + 7) & ~7; + IH2 = SW * OH + FH - SW; + IW2 = SW * OW2 + FW - SW; +} +} // namespace +bool direct_int8_stride1::can_conv_direct_stride1_int8( + const NCBKernSizeParam& param) { + auto&& fm = param.filter_meta; + auto FH = fm.spatial[0]; + auto OC = fm.ocpg; + auto IC = fm.icpg; + bool avaible = + //! src and filter are qint8, dst is qint8 or qint32 + ((param.src_type.enumv() == DTypeEnum::QuantizedS8 && + param.filter_type.enumv() == DTypeEnum::QuantizedS8 && + (param.dst_type.enumv() == DTypeEnum::QuantizedS8 || + param.dst_type.enumv() == DTypeEnum::QuantizedS32)) || + //! 
src and filter are int8, dst is int32 + (param.src_type.enumv() == DTypeEnum::Int8 && + param.filter_type.enumv() == DTypeEnum::Int8 && + param.dst_type.enumv() == DTypeEnum::Int32)) && + fm.format == param::Convolution::Format::NCHW && !fm.should_flip && + fm.spatial_ndim == 2 && fm.dilation[0] == 1 && + fm.dilation[1] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1 && + FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5 || FH == 7); + if (param.bias_type.valid()) { + avaible &= ((param.src_type.enumv() == DTypeEnum::QuantizedS8 && + param.bias_type.enumv() == DTypeEnum::QuantizedS32) || + (param.bias_type.enumv() == param.dst_type.enumv())); + } + bool preferred = ((FH == 2 && (OC <= 10 || IC <= 8)) || + ((FH == 3 || FH == 5 || FH == 7) && + (OC <= 16 || (IC <= 4 && OC <= 32)))) && + param.bias_mode != BiasMode::BIAS; + return avaible && preferred; +} + +WorkspaceBundle direct_int8_stride1::get_bundle( + const ConvBiasImpl::NCBKernSizeParam& param, bool m_large_group) { + auto&& fm = param.filter_meta; + size_t nr_threads = param.nr_threads; + size_t group = fm.group, batch = param.n; + size_t IC = fm.icpg; + size_t IH2, IW2, OH2, OW2; + get_rectified_size(param, IH2, IW2, OH2, OW2); + size_t src_size = 0, dst_size = 0; + if (need_src_copy(param)) { + src_size = m_large_group + ? IC * IH2 * IW2 * sizeof(int8_t) * nr_threads + : IC * IH2 * IW2 * sizeof(int8_t) * group * batch; + }; + if (need_dst_copy(param)) { + dst_size = OH2 * OW2 * param.dst_type.size() * nr_threads; + } + if (IC > 1) { + size_t temp_size = OH2 * OW2 * sizeof(int32_t) * nr_threads; + return {nullptr, {src_size, dst_size, temp_size}}; + } else { + return {nullptr, {src_size, dst_size}}; + }; +} +//! Process one input channel copy padding +void direct_int8_stride1::copy_padding_kern( + WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids) { + size_t IH = kern_param.isz[0]; + size_t IW = kern_param.isz[1]; + size_t IC = kern_param.filter_meta.icpg; + size_t PH = kern_param.filter_meta.padding[0]; + size_t PW = kern_param.filter_meta.padding[1]; + size_t GROUP = kern_param.filter_meta.group; + + size_t IH2, IW2, OH2, OW2; + get_rectified_size(kern_param, IH2, IW2, OH2, OW2); + bool need_src_copy_var = need_src_copy(kern_param); + size_t padding_group_size = IH2 * IW2 * IC; + bundle.set(kern_param.workspace_ptr); + //! Used for get the workspace offset + size_t workspace_group_id = workspace_ids[0], + workspace_batch_id = workspace_ids[1], channel_id = workspace_ids[2], + group_id = ncb_index.ndrange_id[0], + batch_id = ncb_index.ndrange_id[1]; + + const int8_t* sptr = static_cast( + kern_param.src(batch_id, group_id, channel_id)); + if (need_src_copy_var) { + //! copy to sptr_base to eliminate padding effect + int8_t* sptr_base = static_cast(bundle.get(0)) + + workspace_group_id * padding_group_size + + workspace_batch_id * GROUP * padding_group_size + + channel_id * IH2 * IW2; + std::memset(sptr_base, 0, sizeof(int8_t) * IH2 * IW2); + rep(ih, IH) { + std::memcpy(sptr_base + (ih + PH) * IW2 + PW, sptr + ih * IW, + sizeof(int8_t) * IW); + } + } +}; +//! 
compute one output channel +template +void direct_int8_stride1::do_conv_kern(WorkspaceBundle bundle, + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids) { + size_t OH = kern_param.osz[0]; + size_t OW = kern_param.osz[1]; + size_t FH = kern_param.filter_meta.spatial[0]; + size_t FW = kern_param.filter_meta.spatial[1]; + size_t IC = kern_param.filter_meta.icpg; + size_t GROUP = kern_param.filter_meta.group; + size_t IH2, IW2, OH2, OW2; + get_rectified_size(kern_param, IH2, IW2, OH2, OW2); + bool need_src_copy_var = need_src_copy(kern_param); + bool need_dst_copy_var = need_dst_copy(kern_param); + bool need_post_process = + kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8; + //! if dst_type is qint32, the op is not used, just fill with (1.0f,4.0f) + Op op = Op(1.0f, 4.0f); + if (need_post_process) { + float scale_bias = + kern_param.bias_type.param().scale; + float scale_dst = kern_param.dst_type.param().scale; + op = Op(scale_bias, scale_dst); + } + size_t padding_group_size = IH2 * IW2 * IC; + bundle.set(kern_param.workspace_ptr); + //! Used for get the workspace offset + size_t workspace_group_id = workspace_ids[0], + workspace_batch_id = workspace_ids[1], oc = workspace_ids[2]; + size_t group_id = ncb_index.ndrange_id[0], + batch_id = ncb_index.ndrange_id[1]; + //! If large group, each thread has its own worspace, set group_id with + //! thread_id + const int8_t* sptr = kern_param.src(batch_id, group_id); + const int8_t* fptr = + kern_param.filter(group_id) + oc * FH * FW * IC; + void* dst = reinterpret_cast(reinterpret_cast( + kern_param.dst(batch_id, group_id, oc))); + const int32_t* bptr = kern_param.bias(batch_id, group_id, oc); + if (need_src_copy_var) { + sptr = static_cast(bundle.get(0)) + + workspace_group_id * padding_group_size + + workspace_batch_id * GROUP * padding_group_size; + } + void* dptr = nullptr; + int32_t* tptr = nullptr; + if (need_dst_copy_var) { + dptr = reinterpret_cast( + reinterpret_cast(bundle.get(1)) + + ncb_index.thread_id * OH2 * OW2 * kern_param.dst_type.size()); + } else { + dptr = dst; + } + +#define KERN0_NEED_POST_PROCESS(filter, first_ic, last_ic) \ + conv_bias::conv_direct_stride1_##filter##x##filter##_int8_nchw< \ + first_ic, last_ic, bias_mode, Op>( \ + sptr + ic * IH2 * IW2, fptr + ic * FH * FW, bptr, tptr, \ + static_cast(dptr), IH2, IW2, OH2, OW2, op) + +#define KERN0_NO_POST_PROCESS(filter, first_ic, last_ic) \ + conv_bias::conv_direct_stride1_##filter##x##filter##_int8_nchw< \ + first_ic, last_ic, bias_mode, Op>( \ + sptr + ic * IH2 * IW2, fptr + ic * FH * FW, bptr, \ + static_cast(dptr), nullptr, IH2, IW2, OH2, OW2, op) + +#define KERN1_NEED_POST_PROCESS(filter) \ + KERN0_NEED_POST_PROCESS(filter, true, false); \ + for (ic = 1; ic < IC - 1; ++ic) { \ + KERN0_NEED_POST_PROCESS(filter, false, false); \ + } \ + KERN0_NEED_POST_PROCESS(filter, false, true); + +#define KERN1_NO_POST_PROCESS(filter) \ + KERN0_NO_POST_PROCESS(filter, true, false); \ + for (ic = 1; ic < IC; ++ic) { \ + KERN0_NO_POST_PROCESS(filter, false, false); \ + } + if (need_post_process) { + size_t ic = 0; + if (IC == 1) { + DISPATCH_FILTER(filter, KERN0_NEED_POST_PROCESS, true, true) + } else { + tptr = static_cast(bundle.get(2)) + + ncb_index.thread_id * OH2 * OW2 * kern_param.dst_type.size(); + DISPATCH_FILTER(filter, KERN1_NEED_POST_PROCESS) + } + } else { + size_t ic = 0; + if (IC == 1) { + DISPATCH_FILTER(filter, KERN0_NO_POST_PROCESS, true, false) + } else { + DISPATCH_FILTER(filter, KERN1_NO_POST_PROCESS) + } + } 
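Note that the dispatch above only selects the micro-kernel template; the int8 requantization itself is hidden inside the Op functor built from (scale_bias, scale_dst), i.e. the QuantizedS32 accumulator scale (src_scale * filter_scale) and the QuantizedS8 destination scale. As a rough scalar model of what TypeCvtOp/ReluOp presumably compute per element (an assumption: the real vectorized ops live in elemwise_op.h and their rounding/saturation details may differ; requantize_s32_to_s8 is a hypothetical name):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Hedged scalar model: rescale an int32 accumulator into the int8 output domain.
inline int8_t requantize_s32_to_s8(int32_t acc, float scale_bias,
                                   float scale_dst, bool relu) {
    float real = acc * scale_bias;          // de-quantize the accumulator
    if (relu) real = std::max(real, 0.f);   // optionally fuse the RELU nonlinearity
    long q = std::lround(real / scale_dst); // re-quantize to the dst scale
    q = std::min(127L, std::max(-128L, q)); // saturate to the int8 range
    return static_cast<int8_t>(q);
}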
+#undef KERN0 +#undef KERN1_NEED_POST_PROCESS +#undef KERN1_NO_POST_PROCESS + if (need_dst_copy_var) { + rep(oh, OH) { + std::memcpy(reinterpret_cast( + reinterpret_cast(dst) + + oh * OW * kern_param.dst_type.size()), + reinterpret_cast( + reinterpret_cast(dptr) + + oh * OW2 * kern_param.dst_type.size()), + kern_param.dst_type.size() * OW); + } + } +} + +SmallVector direct_int8_stride1::get_kimpls( + const NCBKernSizeParam& param, bool m_large_group) { + auto fm = param.filter_meta; + size_t N = param.n; + size_t IC = param.filter_meta.icpg; + size_t OC = param.filter_meta.ocpg; + size_t group = fm.group; + WorkspaceBundle wbundle = get_bundle(param, m_large_group); + conv_fun do_conv_fun = nullptr; + +#define DO_CONV_KERN_FUN(filter, bias_mode, op) \ + do_conv_fun = do_conv_kern; + +#define GET_OP_PARAM(i, bias_mode) \ + switch (param.nonlineMode) { \ + case param::ConvBias::NonlineMode::IDENTITY: \ + DO_CONV_KERN_FUN(i, bias_mode, \ + TypeCvtOp) \ + break; \ + case param::ConvBias::NonlineMode::RELU: \ + DO_CONV_KERN_FUN(i, bias_mode, \ + ReluOp) \ + break; \ + case param::ConvBias::NonlineMode::H_SWISH: \ + DO_CONV_KERN_FUN(i, bias_mode, \ + HSwishOp) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } + +#define GET_BIAS_MODE_PARAM(i) \ + switch (param.bias_mode) { \ + case BiasMode::NO_BIAS: \ + GET_OP_PARAM(i, BiasMode::NO_BIAS) \ + break; \ + case BiasMode::BROADCAST_CHANNEL_BIAS: \ + GET_OP_PARAM(i, BiasMode::BROADCAST_CHANNEL_BIAS) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } +#define DISPATCH_CONV_KERN() \ + switch (param.filter_meta.spatial[0]) { \ + case 2: \ + GET_BIAS_MODE_PARAM(2) \ + break; \ + case 3: \ + GET_BIAS_MODE_PARAM(3) \ + break; \ + case 5: \ + GET_BIAS_MODE_PARAM(5) \ + break; \ + case 7: \ + GET_BIAS_MODE_PARAM(7) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } + + DISPATCH_CONV_KERN(); + megdnn_assert(do_conv_fun); + + SmallVector ret_kerns; + if (m_large_group) { + auto exec_one_group = [wbundle, do_conv_fun]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + auto fm = kern_param.filter_meta; + size_t IC = fm.icpg; + size_t OC = fm.ocpg; + WorkspaceBundle bundle = wbundle; + for (size_t ic = 0; ic < IC; ic++) { + copy_padding_kern(bundle, kern_param, ncb_index, + {ncb_index.thread_id, 0, ic}); + } + for (size_t oc = 0; oc < OC; oc++) { + do_conv_fun(bundle, kern_param, ncb_index, + {ncb_index.thread_id, 0, oc}); + } + }; + ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); + } else { + WorkspaceBundle bundle = wbundle; + auto copy_padding = [bundle](const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + copy_padding_kern(bundle, kern_param, ncb_index, + ncb_index.ndrange_id); + }; + ret_kerns.push_back({copy_padding, {group, N, IC}}); + auto do_conv = [bundle, do_conv_fun](const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + do_conv_fun(bundle, kern_param, ncb_index, ncb_index.ndrange_id); + }; + ret_kerns.push_back({do_conv, {group, N, OC}}); + } + return ret_kerns; +} +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8/stride1.h b/dnn/src/arm_common/conv_bias/int8/stride1.h new file mode 100644 index 00000000..a56db1ed --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/stride1.h @@ -0,0 +1,45 @@ +/** + * \file dnn/src/arm_common/conv_bias/int8/stride1.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include "src/arm_common/conv_bias/opr_impl.h" + +namespace megdnn { +namespace arm_common { +namespace direct_int8_stride1 { +using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam; +using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam; +using NCBKernIndex = fallback::ConvBiasImpl::NCBKernIndex; + +using conv_fun = std::function; + +bool can_conv_direct_stride1_int8(const NCBKernSizeParam& param); + +WorkspaceBundle get_bundle(const NCBKernSizeParam& param, bool m_large_group); + +void copy_padding_kern(WorkspaceBundle bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids); + +template +void do_conv_kern(WorkspaceBundle bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids); + +SmallVector get_kimpls(const NCBKernSizeParam& param, + bool); +} // namespace direct_int8_stride1 +} // namespace arm_common +} // namespace megdnn + // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8/stride1_dotprod.cpp b/dnn/src/arm_common/conv_bias/int8/stride1_dotprod.cpp new file mode 100644 index 00000000..051cd95a --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/stride1_dotprod.cpp @@ -0,0 +1,363 @@ +/** + * \file dnn/src/arm_common/conv_bias/int8/stride1_dotprod.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#if __ARM_FEATURE_DOTPROD + +#include "src/arm_common/conv_bias/int8/stride1_dotprod.h" +#include "megdnn/oprs.h" +#include "src/arm_common/conv_bias/int8/direct_dotprod.h" +#include "src/arm_common/conv_bias/int8/strategy.h" +#include "src/arm_common/elemwise_op.h" +#include "src/common/opr_delegate.h" + +using namespace megdnn; +using namespace arm_common; +using namespace direct_dotprod_int8_stride1; + +namespace { +bool need_dst_copy( + const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { + return param.osz[1] % 8; +} +bool need_src_copy( + const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { + if (param.filter_meta.padding[0] || param.filter_meta.padding[1]) { + return true; + } + return need_dst_copy(param); +} +void get_rectified_size( + const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, + size_t& IH2, size_t& IW2, size_t& OH2, size_t& OW2) { + auto&& fm = param.filter_meta; + auto SW = fm.stride[1]; + auto OH = param.osz[0]; + auto OW = param.osz[1]; + auto FH = fm.spatial[0]; + auto FW = fm.spatial[1]; + + OH2 = OH; + OW2 = (OW + 7) & ~7; + IH2 = SW * OH + FH - SW; + IW2 = SW * OW2 + FW - SW; +} +} // namespace + +bool direct_dotprod_int8_stride1::can_conv_direct_stride1_int8( + const NCBKernSizeParam& param) { + auto&& fm = param.filter_meta; + auto FH = fm.spatial[0]; + auto OC = fm.ocpg; + auto IC = fm.icpg; + bool avaible = + //! 
src and filter are qint8, dst is qint8 or qint32 + ((param.src_type.enumv() == DTypeEnum::QuantizedS8 && + param.filter_type.enumv() == DTypeEnum::QuantizedS8 && + (param.dst_type.enumv() == DTypeEnum::QuantizedS8 || + param.dst_type.enumv() == DTypeEnum::QuantizedS32)) || + //! src and filter are int8, dst is int32 + (param.src_type.enumv() == DTypeEnum::Int8 && + param.filter_type.enumv() == DTypeEnum::Int8 && + param.dst_type.enumv() == DTypeEnum::Int32)) && + fm.format == param::Convolution::Format::NCHW && !fm.should_flip && + fm.spatial_ndim == 2 && fm.dilation[0] == 1 && + fm.dilation[1] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1 && + FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5 || FH == 7); + + bool preferred = ((FH == 2 && (OC <= 10 || IC <= 8)) || + ((FH == 3 || FH == 5 || FH == 7) && + (OC <= 16 || (IC <= 4 && OC <= 32)))) && + param.bias_mode != BiasMode::BIAS; + if (param.bias_type.valid()) { + avaible &= ((param.src_type.enumv() == DTypeEnum::QuantizedS8 && + param.bias_type.enumv() == DTypeEnum::QuantizedS32) || + (param.bias_type.enumv() == param.dst_type.enumv())); + } + return avaible && preferred; +} + +WorkspaceBundle direct_dotprod_int8_stride1::get_bundle( + const ConvBiasImpl::NCBKernSizeParam& param, bool m_large_group) { + auto&& fm = param.filter_meta; + size_t nr_threads = param.nr_threads; + size_t group = fm.group, batch = param.n; + size_t IC = fm.icpg; + size_t IH2, IW2, OH2, OW2; + get_rectified_size(param, IH2, IW2, OH2, OW2); + size_t src_size = 0, dst_size = 0; + if (need_src_copy(param)) { + src_size = m_large_group + ? IC * IH2 * IW2 * sizeof(int8_t) * nr_threads + : IC * IH2 * IW2 * sizeof(int8_t) * group * batch; + }; + if (need_dst_copy(param)) { + dst_size = OH2 * OW2 * param.dst_type.size() * nr_threads; + } + if (IC > 1) { + size_t temp_size = OH2 * OW2 * sizeof(int32_t) * nr_threads; + return {nullptr, {src_size, dst_size, temp_size}}; + } else { + return {nullptr, {src_size, dst_size}}; + }; +} +//! Process one input channel copy padding +void direct_dotprod_int8_stride1::copy_padding_kern( + WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids) { + size_t IH = kern_param.isz[0]; + size_t IW = kern_param.isz[1]; + size_t IC = kern_param.filter_meta.icpg; + size_t PH = kern_param.filter_meta.padding[0]; + size_t PW = kern_param.filter_meta.padding[1]; + size_t GROUP = kern_param.filter_meta.group; + + size_t IH2, IW2, OH2, OW2; + get_rectified_size(kern_param, IH2, IW2, OH2, OW2); + bool need_src_copy_var = need_src_copy(kern_param); + size_t padding_group_size = IH2 * IW2 * IC; + bundle.set(kern_param.workspace_ptr); + + //! Used for get the workspace offset + size_t group_id = ncb_index.ndrange_id[0], + batch_id = ncb_index.ndrange_id[1], + channel_id = workspace_ids[2]; + size_t workspace_group_id = workspace_ids[0], + workspace_batch_id = workspace_ids[1]; + const int8_t* sptr = kern_param.src(batch_id, group_id, channel_id); + if (need_src_copy_var) { + //! copy to sptr_base to eliminate padding effect + int8_t* sptr_base = static_cast(bundle.get(0)) + + workspace_group_id * padding_group_size + + workspace_batch_id * GROUP * padding_group_size + + channel_id * IH2 * IW2; + std::memset(sptr_base, 0, sizeof(int8_t) * IH2 * IW2); + rep(ih, IH) { + std::memcpy(sptr_base + (ih + PH) * IW2 + PW, sptr + ih * IW, + sizeof(int8_t) * IW); + } + } +}; +//! 
compute one output channel +template +void direct_dotprod_int8_stride1::do_conv_kern( + WorkspaceBundle bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids) { + size_t OH = kern_param.osz[0]; + size_t OW = kern_param.osz[1]; + size_t FH = kern_param.filter_meta.spatial[0]; + size_t FW = kern_param.filter_meta.spatial[1]; + size_t IC = kern_param.filter_meta.icpg; + size_t GROUP = kern_param.filter_meta.group; + size_t IH2, IW2, OH2, OW2; + get_rectified_size(kern_param, IH2, IW2, OH2, OW2); + bool need_src_copy_var = need_src_copy(kern_param); + bool need_dst_copy_var = need_dst_copy(kern_param); + bool need_post_process = + kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8; + //! if dst_type is qint32, the op is not used, just fill with (1.0f,4.0f) + Op op = Op(1.0f, 4.0f); + if (need_post_process) { + float scale_bias = + kern_param.bias_type.param().scale; + float scale_dst = kern_param.dst_type.param().scale; + op = Op(scale_bias, scale_dst); + } + size_t padding_group_size = IH2 * IW2 * IC; + + bundle.set(kern_param.workspace_ptr); + //! Used for get the workspace offset + size_t workspace_group_id = workspace_ids[0], + workspace_batch_id = workspace_ids[1], oc = workspace_ids[2]; + size_t group_id = ncb_index.ndrange_id[0], + batch_id = ncb_index.ndrange_id[1]; + //! If large group, each thread has its own worspace, set group_id + //! with thread_id + const int8_t* sptr = kern_param.src(batch_id, group_id); + const int8_t* fptr = + kern_param.filter(group_id) + oc * FH * FW * IC; + void* dst = kern_param.dst(batch_id, group_id, oc); + const int32_t* bptr = kern_param.bias(batch_id, group_id, oc); + if (need_src_copy_var) { + sptr = static_cast(bundle.get(0)) + + workspace_group_id * padding_group_size + + workspace_batch_id * GROUP * padding_group_size; + } + void* dptr = nullptr; + int32_t* tptr = nullptr; + if (need_dst_copy_var) { + dptr = reinterpret_cast( + reinterpret_cast(bundle.get(1)) + + ncb_index.thread_id * OH2 * OW2 * kern_param.dst_type.size()); + } else { + dptr = dst; + } + +#define KERN0_NEED_POST_PROCESS(filter, first_ic, last_ic) \ + conv_bias::conv_direct_stride1_##filter##x##filter##_int8_dot< \ + first_ic, last_ic, bias_mode, Op>( \ + sptr + ic * IH2 * IW2, fptr + ic * FH * FW, bptr, tptr, \ + static_cast(dptr), IH2, IW2, OH2, OW2, op) + +#define KERN0_NO_POST_PROCESS(filter, first_ic, last_ic) \ + conv_bias::conv_direct_stride1_##filter##x##filter##_int8_dot< \ + first_ic, last_ic, bias_mode, Op>( \ + sptr + ic * IH2 * IW2, fptr + ic * FH * FW, bptr, \ + static_cast(dptr), nullptr, IH2, IW2, OH2, OW2, op) + +#define KERN1_NEED_POST_PROCESS(filter) \ + KERN0_NEED_POST_PROCESS(filter, true, false); \ + for (ic = 1; ic < IC - 1; ++ic) { \ + KERN0_NEED_POST_PROCESS(filter, false, false); \ + } \ + KERN0_NEED_POST_PROCESS(filter, false, true); + +#define KERN1_NO_POST_PROCESS(filter) \ + KERN0_NO_POST_PROCESS(filter, true, false); \ + for (ic = 1; ic < IC; ++ic) { \ + KERN0_NO_POST_PROCESS(filter, false, false); \ + } + if (need_post_process) { + size_t ic = 0; + if (IC == 1) { + DISPATCH_FILTER(filter, KERN0_NEED_POST_PROCESS, true, true) + } else { + tptr = static_cast(bundle.get(2)) + + ncb_index.thread_id * OH2 * OW2 * kern_param.dst_type.size(); + DISPATCH_FILTER(filter, KERN1_NEED_POST_PROCESS) + } + } else { + size_t ic = 0; + if (IC == 1) { + DISPATCH_FILTER(filter, KERN0_NO_POST_PROCESS, true, false) + } else { + DISPATCH_FILTER(filter, KERN1_NO_POST_PROCESS) + } + } +#undef KERN0 +#undef 
KERN1_NEED_POST_PROCESS +#undef KERN1_NO_POST_PROCESS + if (need_dst_copy_var) { + rep(oh, OH) { + std::memcpy(reinterpret_cast( + reinterpret_cast(dst) + + oh * OW * kern_param.dst_type.size()), + reinterpret_cast( + reinterpret_cast(dptr) + + oh * OW2 * kern_param.dst_type.size()), + kern_param.dst_type.size() * OW); + } + } +} + +SmallVector direct_dotprod_int8_stride1::get_kimpls( + const NCBKernSizeParam& param, bool m_large_group) { + auto fm = param.filter_meta; + size_t N = param.n; + size_t IC = param.filter_meta.icpg; + size_t OC = param.filter_meta.ocpg; + size_t group = fm.group; + WorkspaceBundle wbundle = get_bundle(param, m_large_group); + conv_fun do_conv_fun = nullptr; + +#define DO_CONV_KERN_FUN(filter, bias_mode, op) \ + do_conv_fun = do_conv_kern; + +#define GET_OP_PARAM(i, bias_mode) \ + switch (param.nonlineMode) { \ + case param::ConvBias::NonlineMode::IDENTITY: \ + DO_CONV_KERN_FUN(i, bias_mode, \ + TypeCvtOp) \ + break; \ + case param::ConvBias::NonlineMode::RELU: \ + DO_CONV_KERN_FUN(i, bias_mode, \ + ReluOp) \ + break; \ + case param::ConvBias::NonlineMode::H_SWISH: \ + DO_CONV_KERN_FUN(i, bias_mode, \ + HSwishOp) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } + +#define GET_BIAS_MODE_PARAM(i) \ + switch (param.bias_mode) { \ + case BiasMode::NO_BIAS: \ + GET_OP_PARAM(i, BiasMode::NO_BIAS) \ + break; \ + case BiasMode::BROADCAST_CHANNEL_BIAS: \ + GET_OP_PARAM(i, BiasMode::BROADCAST_CHANNEL_BIAS) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } +#define DISPATCH_CONV_KERN() \ + switch (param.filter_meta.spatial[0]) { \ + case 2: \ + GET_BIAS_MODE_PARAM(2) \ + break; \ + case 3: \ + GET_BIAS_MODE_PARAM(3) \ + break; \ + case 5: \ + GET_BIAS_MODE_PARAM(5) \ + break; \ + case 7: \ + GET_BIAS_MODE_PARAM(7) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } + + DISPATCH_CONV_KERN(); + megdnn_assert(do_conv_fun); + + SmallVector ret_kerns; + if (m_large_group) { + auto exec_one_group = [wbundle, do_conv_fun]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + auto fm = kern_param.filter_meta; + size_t IC = fm.icpg; + size_t OC = fm.ocpg; + WorkspaceBundle bundle = wbundle; + for (size_t ic = 0; ic < IC; ic++) { + copy_padding_kern(bundle, kern_param, ncb_index, + {ncb_index.thread_id, 0, ic}); + } + for (size_t oc = 0; oc < OC; oc++) { + do_conv_fun(bundle, kern_param, ncb_index, + {ncb_index.thread_id, 0, oc}); + } + }; + ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); + } else { + WorkspaceBundle bundle = wbundle; + auto copy_padding = [bundle](const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + copy_padding_kern(bundle, kern_param, ncb_index, + ncb_index.ndrange_id); + }; + ret_kerns.push_back({copy_padding, {group, N, IC}}); + auto do_conv = [bundle, do_conv_fun](const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + do_conv_fun(bundle, kern_param, ncb_index, ncb_index.ndrange_id); + }; + ret_kerns.push_back({do_conv, {group, N, OC}}); + } + + return ret_kerns; +} +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8/stride1_dotprod.h b/dnn/src/arm_common/conv_bias/int8/stride1_dotprod.h new file mode 100644 index 00000000..7c32328e --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/stride1_dotprod.h @@ -0,0 +1,45 @@ +/** + * \file dnn/src/arm_common/conv_bias/int8/stride1_dotprod.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#if __ARM_FEATURE_DOTPROD +#pragma once + +#include "src/arm_common/conv_bias/opr_impl.h" +namespace megdnn { +namespace arm_common { +namespace direct_dotprod_int8_stride1 { +using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam; +using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam; +using NCBKernIndex = fallback::ConvBiasImpl::NCBKernIndex; + +using conv_fun = std::function; + +bool can_conv_direct_stride1_int8(const NCBKernSizeParam& param); + +WorkspaceBundle get_bundle(const NCBKernSizeParam& param, bool m_large_group); + +void copy_padding_kern(WorkspaceBundle bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids); + +template +void do_conv_kern(WorkspaceBundle bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids); + +SmallVector get_kimpls(const NCBKernSizeParam& param, + bool); +} // namespace direct_dotprod_int8_stride1 +} // namespace arm_common +} // namespace megdnn +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8/stride2.cpp b/dnn/src/arm_common/conv_bias/int8/stride2.cpp new file mode 100644 index 00000000..22db808d --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/stride2.cpp @@ -0,0 +1,367 @@ +/** + * \file dnn/src/arm_common/conv_bias/int8/stride2.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "src/arm_common/conv_bias/int8/stride2.h" +#include "megdnn/oprs.h" +#include "src/arm_common/conv_bias/int8/direct.h" +#include "src/arm_common/conv_bias/int8/strategy.h" +#include "src/arm_common/elemwise_op.h" +#include "src/common/opr_delegate.h" + +using namespace megdnn; +using namespace arm_common; +using namespace direct_int8_stride2; + +namespace { +bool need_dst_copy( + const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { + return param.osz[1] % 8; +} +bool need_src_copy( + const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { + if (param.filter_meta.padding[0] || param.filter_meta.padding[1]) { + return true; + } + return need_dst_copy(param); +} +void get_rectified_size( + const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, + size_t& IH2, size_t& IW2, size_t& OH2, size_t& OW2) { + auto&& fm = param.filter_meta; + size_t SW = fm.stride[1]; + size_t IH = param.isz[0]; + size_t IW = param.isz[1]; + size_t OH = param.osz[0]; + size_t OW = param.osz[1]; + size_t FH = fm.spatial[0]; + size_t FW = fm.spatial[1]; + + OH2 = OH; + OW2 = (OW + 7) & ~7; + IH2 = SW * OH + FH - SW; + IW2 = SW * OW2 + FW - SW; + // Because stride is 2, sometimes IW == IW2+1. Do a max update to + // handle this case. + IH2 = std::max(IH2, IH); + IW2 = std::max(IW2, IW); +} +} // namespace +bool direct_int8_stride2::can_conv_direct_stride2_int8( + const NCBKernSizeParam& param) { + auto&& fm = param.filter_meta; + auto FH = fm.spatial[0]; + auto OC = fm.ocpg; + auto IC = fm.icpg; + bool avaible = + //! 
src and filter are qint8, dst is qint8 or qint32 + ((param.src_type.enumv() == DTypeEnum::QuantizedS8 && + param.filter_type.enumv() == DTypeEnum::QuantizedS8 && + (param.dst_type.enumv() == DTypeEnum::QuantizedS8 || + param.dst_type.enumv() == DTypeEnum::QuantizedS32)) || + //! src and filter are int8, dst is int32 + (param.src_type.enumv() == DTypeEnum::Int8 && + param.filter_type.enumv() == DTypeEnum::Int8 && + param.dst_type.enumv() == DTypeEnum::Int32)) && + fm.format == param::Convolution::Format::NCHW && !fm.should_flip && + fm.spatial_ndim == 2 && fm.dilation[0] == 1 && + fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 && + FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5 || FH == 7); + if (param.bias_type.valid()) { + avaible &= ((param.src_type.enumv() == DTypeEnum::QuantizedS8 && + param.bias_type.enumv() == DTypeEnum::QuantizedS32) || + (param.bias_type.enumv() == param.dst_type.enumv())); + } + + bool preferred = (((FH == 2 || FH == 3) && + (IC == 1 || (IC <= 8 && OC <= 12) || OC <= 8)) || + (FH == 5 && ((IC == 1 && OC <= 16) || OC <= 12)) || + (FH == 7 && OC <= 16)) && + (param.bias_mode != BiasMode::BIAS); + return avaible && preferred; +} + +WorkspaceBundle direct_int8_stride2::get_bundle( + const ConvBiasImpl::NCBKernSizeParam& param, bool m_large_group) { + auto&& fm = param.filter_meta; + size_t nr_threads = param.nr_threads; + size_t group = fm.group, batch = param.n; + size_t IC = fm.icpg; + size_t IH2, IW2, OH2, OW2; + get_rectified_size(param, IH2, IW2, OH2, OW2); + size_t src_size = 0, dst_size = 0; + if (need_src_copy(param)) { + src_size = m_large_group + ? IC * IH2 * IW2 * sizeof(int8_t) * nr_threads + : IC * IH2 * IW2 * sizeof(int8_t) * group * batch; + }; + if (need_dst_copy(param)) { + dst_size = OH2 * OW2 * param.dst_type.size() * nr_threads; + } + if (IC > 1) { + size_t temp_size = OH2 * OW2 * sizeof(int32_t) * nr_threads; + return {nullptr, {src_size, dst_size, temp_size}}; + } else { + return {nullptr, {src_size, dst_size}}; + }; +} +//! Process one input channel copy padding +void direct_int8_stride2::copy_padding_kern( + WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids) { + size_t IH = kern_param.isz[0]; + size_t IW = kern_param.isz[1]; + size_t IC = kern_param.filter_meta.icpg; + size_t GROUP = kern_param.filter_meta.group; + size_t PH = kern_param.filter_meta.padding[0]; + size_t PW = kern_param.filter_meta.padding[1]; + + size_t IH2, IW2, OH2, OW2; + get_rectified_size(kern_param, IH2, IW2, OH2, OW2); + bool need_src_copy_var = need_src_copy(kern_param); + size_t padding_group_size = IH2 * IW2 * IC; + bundle.set(kern_param.workspace_ptr); + + //! Used for get the workspace offset + size_t group_id = ncb_index.ndrange_id[0], + batch_id = ncb_index.ndrange_id[1]; + size_t workspace_group_id = workspace_ids[0], + workspace_batch_id = workspace_ids[1], channel_id = workspace_ids[2]; + const int8_t* sptr = static_cast( + kern_param.src(batch_id, group_id, channel_id)); + if (need_src_copy_var) { + //! copy to sptr_base to eliminate padding effect + int8_t* sptr_base = static_cast(bundle.get(0)) + + workspace_group_id * padding_group_size + + workspace_batch_id * GROUP * padding_group_size + + channel_id * IH2 * IW2; + std::memset(sptr_base, 0, sizeof(int8_t) * IH2 * IW2); + rep(ih, IH) { + std::memcpy(sptr_base + (ih + PH) * IW2 + PW, sptr + ih * IW, + sizeof(int8_t) * IW); + } + } +}; +//! 
compute one output channel +template +void direct_int8_stride2::do_conv_kern(WorkspaceBundle bundle, + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids) { + size_t OH = kern_param.osz[0]; + size_t OW = kern_param.osz[1]; + size_t FH = kern_param.filter_meta.spatial[0]; + size_t FW = kern_param.filter_meta.spatial[1]; + size_t IC = kern_param.filter_meta.icpg; + size_t GROUP = kern_param.filter_meta.group; + size_t IH2, IW2, OH2, OW2; + get_rectified_size(kern_param, IH2, IW2, OH2, OW2); + bool need_src_copy_var = need_src_copy(kern_param); + bool need_dst_copy_var = need_dst_copy(kern_param); + bool need_post_process = + kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8; + //! if dst_type is qint32, the op is not used, just fill with (1.0f,4.0f) + Op op = Op(1.0f, 4.0f); + if (need_post_process) { + float scale_bias = + kern_param.bias_type.param().scale; + float scale_dst = kern_param.dst_type.param().scale; + op = Op(scale_bias, scale_dst); + } + size_t padding_group_size = IH2 * IW2 * IC; + bundle.set(kern_param.workspace_ptr); + //! Used for get the workspace offset + size_t workspace_group_id = workspace_ids[0], + workspace_batch_id = workspace_ids[1], oc = workspace_ids[2]; + size_t group_id = ncb_index.ndrange_id[0], + batch_id = ncb_index.ndrange_id[1]; + //! If large group, each thread has its own worspace, set group_id with + //! thread_id + const int8_t* sptr = kern_param.src(batch_id, group_id); + const int8_t* fptr = + kern_param.filter(group_id) + oc * FH * FW * IC; + void* dst = kern_param.dst(batch_id, group_id, oc); + const int32_t* bptr = kern_param.bias(batch_id, group_id, oc); + if (need_src_copy_var) { + sptr = static_cast(bundle.get(0)) + + workspace_group_id * padding_group_size + + workspace_batch_id * GROUP * padding_group_size; + } + void* dptr = nullptr; + int32_t* tptr = nullptr; + if (need_dst_copy_var) { + dptr = reinterpret_cast( + reinterpret_cast(bundle.get(1)) + + ncb_index.thread_id * OH2 * OW2 * kern_param.dst_type.size()); + } else { + dptr = dst; + } + +#define KERN0_NEED_POST_PROCESS(filter, first_ic, last_ic) \ + conv_bias::conv_direct_stride2_##filter##x##filter##_int8_nchw< \ + first_ic, last_ic, bias_mode, Op>( \ + sptr + ic * IH2 * IW2, fptr + ic * FH * FW, bptr, tptr, \ + static_cast(dptr), IH2, IW2, OH2, OW2, op) + +#define KERN0_NO_POST_PROCESS(filter, first_ic, last_ic) \ + conv_bias::conv_direct_stride2_##filter##x##filter##_int8_nchw< \ + first_ic, last_ic, bias_mode, Op>( \ + sptr + ic * IH2 * IW2, fptr + ic * FH * FW, bptr, \ + static_cast(dptr), nullptr, IH2, IW2, OH2, OW2, op) + +#define KERN1_NEED_POST_PROCESS(filter) \ + KERN0_NEED_POST_PROCESS(filter, true, false); \ + for (ic = 1; ic < IC - 1; ++ic) { \ + KERN0_NEED_POST_PROCESS(filter, false, false); \ + } \ + KERN0_NEED_POST_PROCESS(filter, false, true); + +#define KERN1_NO_POST_PROCESS(filter) \ + KERN0_NO_POST_PROCESS(filter, true, false); \ + for (ic = 1; ic < IC; ++ic) { \ + KERN0_NO_POST_PROCESS(filter, false, false); \ + } + if (need_post_process) { + size_t ic = 0; + if (IC == 1) { + DISPATCH_FILTER(filter, KERN0_NEED_POST_PROCESS, true, true) + } else { + tptr = static_cast(bundle.get(2)) + + ncb_index.thread_id * OH2 * OW2 * kern_param.dst_type.size(); + DISPATCH_FILTER(filter, KERN1_NEED_POST_PROCESS) + } + } else { + size_t ic = 0; + if (IC == 1) { + DISPATCH_FILTER(filter, KERN0_NO_POST_PROCESS, true, false) + } else { + DISPATCH_FILTER(filter, KERN1_NO_POST_PROCESS) + } + } +#undef KERN0 +#undef 
KERN1_NEED_POST_PROCESS +#undef KERN1_NO_POST_PROCESS + if (need_dst_copy_var) { + rep(oh, OH) { + std::memcpy(reinterpret_cast( + reinterpret_cast(dst) + + oh * OW * kern_param.dst_type.size()), + reinterpret_cast( + reinterpret_cast(dptr) + + oh * OW2 * kern_param.dst_type.size()), + kern_param.dst_type.size() * OW); + } + } +} + +SmallVector direct_int8_stride2::get_kimpls( + const NCBKernSizeParam& param, bool m_large_group) { + auto fm = param.filter_meta; + size_t N = param.n; + size_t IC = param.filter_meta.icpg; + size_t OC = param.filter_meta.ocpg; + size_t group = fm.group; + WorkspaceBundle wbundle = get_bundle(param, m_large_group); + conv_fun do_conv_fun = nullptr; + +#define DO_CONV_KERN_FUN(filter, bias_mode, op) \ + do_conv_fun = do_conv_kern; + +#define GET_OP_PARAM(i, bias_mode) \ + switch (param.nonlineMode) { \ + case param::ConvBias::NonlineMode::IDENTITY: \ + DO_CONV_KERN_FUN(i, bias_mode, \ + TypeCvtOp) \ + break; \ + case param::ConvBias::NonlineMode::RELU: \ + DO_CONV_KERN_FUN(i, bias_mode, \ + ReluOp) \ + break; \ + case param::ConvBias::NonlineMode::H_SWISH: \ + DO_CONV_KERN_FUN(i, bias_mode, \ + HSwishOp) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } + +#define GET_BIAS_MODE_PARAM(i) \ + switch (param.bias_mode) { \ + case BiasMode::NO_BIAS: \ + GET_OP_PARAM(i, BiasMode::NO_BIAS) \ + break; \ + case BiasMode::BROADCAST_CHANNEL_BIAS: \ + GET_OP_PARAM(i, BiasMode::BROADCAST_CHANNEL_BIAS) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } +#define DISPATCH_CONV_KERN() \ + switch (param.filter_meta.spatial[0]) { \ + case 2: \ + GET_BIAS_MODE_PARAM(2) \ + break; \ + case 3: \ + GET_BIAS_MODE_PARAM(3) \ + break; \ + case 5: \ + GET_BIAS_MODE_PARAM(5) \ + break; \ + case 7: \ + GET_BIAS_MODE_PARAM(7) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } + + DISPATCH_CONV_KERN(); + megdnn_assert(do_conv_fun); + + SmallVector ret_kerns; + if (m_large_group) { + auto exec_one_group = [wbundle, do_conv_fun]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + auto fm = kern_param.filter_meta; + size_t IC = fm.icpg; + size_t OC = fm.ocpg; + WorkspaceBundle bundle = wbundle; + for (size_t ic = 0; ic < IC; ic++) { + copy_padding_kern(bundle, kern_param, ncb_index, + {ncb_index.thread_id, 0, ic}); + } + for (size_t oc = 0; oc < OC; oc++) { + do_conv_fun(bundle, kern_param, ncb_index, + {ncb_index.thread_id, 0, oc}); + } + }; + ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); + } else { + WorkspaceBundle bundle = wbundle; + auto copy_padding = [bundle](const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + copy_padding_kern(bundle, kern_param, ncb_index, + ncb_index.ndrange_id); + }; + ret_kerns.push_back({copy_padding, {group, N, IC}}); + auto do_conv = [bundle, do_conv_fun](const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + do_conv_fun(bundle, kern_param, ncb_index, ncb_index.ndrange_id); + }; + ret_kerns.push_back({do_conv, {group, N, OC}}); + } + return ret_kerns; +} +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8/stride2.h b/dnn/src/arm_common/conv_bias/int8/stride2.h new file mode 100644 index 00000000..7509b425 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/stride2.h @@ -0,0 +1,44 @@ +/** + * \file dnn/src/arm_common/conv_bias/int8/stride2.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include "src/arm_common/conv_bias/opr_impl.h" + +namespace megdnn { +namespace arm_common { +namespace direct_int8_stride2 { +using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam; +using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam; +using NCBKernIndex = fallback::ConvBiasImpl::NCBKernIndex; + +using conv_fun = std::function; +bool can_conv_direct_stride2_int8(const NCBKernSizeParam& param); + +WorkspaceBundle get_bundle(const NCBKernSizeParam& param, bool m_large_group); + +void copy_padding_kern(WorkspaceBundle bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids); + +template +void do_conv_kern(WorkspaceBundle bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids); + +SmallVector get_kimpls(const NCBKernSizeParam& param, + bool); +} // namespace direct_int8_stride2 +} // namespace arm_common +} // namespace megdnn + // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8/stride2_dotprod.cpp b/dnn/src/arm_common/conv_bias/int8/stride2_dotprod.cpp new file mode 100644 index 00000000..90344fc9 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/stride2_dotprod.cpp @@ -0,0 +1,368 @@ +/** + * \file dnn/src/arm_common/conv_bias/int8/stride2_dotprod.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#if __ARM_FEATURE_DOTPROD +#include "src/arm_common/conv_bias/int8/stride2_dotprod.h" +#include "megdnn/oprs.h" +#include "src/arm_common/conv_bias/int8/direct_dotprod.h" +#include "src/arm_common/conv_bias/int8/strategy.h" +#include "src/arm_common/elemwise_op.h" +#include "src/common/opr_delegate.h" + +using namespace megdnn; +using namespace arm_common; +using namespace direct_dotprod_int8_stride2; + +namespace { +bool need_dst_copy( + const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { + return param.osz[1] % 8; +} +bool need_src_copy( + const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { + if (param.filter_meta.padding[0] || param.filter_meta.padding[1]) { + return true; + } + return need_dst_copy(param); +} +void get_rectified_size( + const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, + size_t& IH2, size_t& IW2, size_t& OH2, size_t& OW2) { + auto&& fm = param.filter_meta; + size_t SW = fm.stride[1]; + size_t IH = param.isz[0]; + size_t IW = param.isz[1]; + size_t OH = param.osz[0]; + size_t OW = param.osz[1]; + size_t FH = fm.spatial[0]; + size_t FW = fm.spatial[1]; + + OH2 = OH; + OW2 = (OW + 7) & ~7; + IH2 = SW * OH + FH - SW; + IW2 = SW * OW2 + FW - SW; + // Because stride is 2, sometimes IW == IW2+1. Do a max update to + // handle this case. + IH2 = std::max(IH2, IH); + IW2 = std::max(IW2, IW); +} +} // namespace + +bool direct_dotprod_int8_stride2::can_conv_direct_stride2_int8( + const NCBKernSizeParam& param) { + auto&& fm = param.filter_meta; + auto FH = fm.spatial[0]; + auto OC = fm.ocpg; + auto IC = fm.icpg; + bool avaible = + //! 
src and filter are qint8, dst is qint8 or qint32 + ((param.src_type.enumv() == DTypeEnum::QuantizedS8 && + param.filter_type.enumv() == DTypeEnum::QuantizedS8 && + (param.dst_type.enumv() == DTypeEnum::QuantizedS8 || + param.dst_type.enumv() == DTypeEnum::QuantizedS32)) || + //! src and filter are int8, dst is int32 + (param.src_type.enumv() == DTypeEnum::Int8 && + param.filter_type.enumv() == DTypeEnum::Int8 && + param.dst_type.enumv() == DTypeEnum::Int32)) && + fm.format == param::Convolution::Format::NCHW && !fm.should_flip && + fm.spatial_ndim == 2 && fm.dilation[0] == 1 && + fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 && + FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5 || FH == 7); + + bool preferred = (((FH == 2 || FH == 3) && + (IC == 1 || (IC <= 8 && OC <= 12) || OC <= 8)) || + (FH == 5 && ((IC == 1 && OC <= 16) || OC <= 12)) || + (FH == 7 && OC <= 16)) && + (param.bias_mode != BiasMode::BIAS); + if (param.bias_type.valid()) { + avaible &= ((param.src_type.enumv() == DTypeEnum::QuantizedS8 && + param.bias_type.enumv() == DTypeEnum::QuantizedS32) || + (param.bias_type.enumv() == param.dst_type.enumv())); + } + return avaible && preferred; +} + +WorkspaceBundle direct_dotprod_int8_stride2::get_bundle( + const ConvBiasImpl::NCBKernSizeParam& param, bool m_large_group) { + auto&& fm = param.filter_meta; + size_t nr_threads = param.nr_threads; + size_t group = fm.group, batch = param.n; + size_t IC = fm.icpg; + size_t IH2, IW2, OH2, OW2; + get_rectified_size(param, IH2, IW2, OH2, OW2); + size_t src_size = 0, dst_size = 0; + if (need_src_copy(param)) { + src_size = m_large_group + ? IC * IH2 * IW2 * sizeof(int8_t) * nr_threads + : IC * IH2 * IW2 * sizeof(int8_t) * group * batch; + }; + if (need_dst_copy(param)) { + dst_size = OH2 * OW2 * param.dst_type.size() * nr_threads; + } + if (IC > 1) { + size_t temp_size = OH2 * OW2 * sizeof(int32_t) * nr_threads; + return {nullptr, {src_size, dst_size, temp_size}}; + } else { + return {nullptr, {src_size, dst_size}}; + }; +} +//! Process one input channel copy padding +void direct_dotprod_int8_stride2::copy_padding_kern( + WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids) { + size_t IH = kern_param.isz[0]; + size_t IW = kern_param.isz[1]; + size_t IC = kern_param.filter_meta.icpg; + size_t PH = kern_param.filter_meta.padding[0]; + size_t PW = kern_param.filter_meta.padding[1]; + size_t GROUP = kern_param.filter_meta.group; + + size_t IH2, IW2, OH2, OW2; + get_rectified_size(kern_param, IH2, IW2, OH2, OW2); + bool need_src_copy_var = need_src_copy(kern_param); + size_t padding_group_size = IH2 * IW2 * IC; + bundle.set(kern_param.workspace_ptr); + + //! Used for get the workspace offset + size_t workspace_group_id = workspace_ids[0], + workspace_batch_id = workspace_ids[1], channel_id = workspace_ids[2]; + size_t group_id = ncb_index.ndrange_id[0], + batch_id = ncb_index.ndrange_id[1]; + const int8_t* sptr = kern_param.src(batch_id, group_id, channel_id); + if (need_src_copy_var) { + //! copy to sptr_base to eliminate padding effect + int8_t* sptr_base = static_cast(bundle.get(0)) + + workspace_group_id * padding_group_size + + workspace_batch_id * GROUP * padding_group_size + + channel_id * IH2 * IW2; + std::memset(sptr_base, 0, sizeof(int8_t) * IH2 * IW2); + rep(ih, IH) { + std::memcpy(sptr_base + (ih + PH) * IW2 + PW, sptr + ih * IW, + sizeof(int8_t) * IW); + } + } +}; +//! 
compute one output channel +template +void direct_dotprod_int8_stride2::do_conv_kern( + WorkspaceBundle bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids) { + size_t OH = kern_param.osz[0]; + size_t OW = kern_param.osz[1]; + size_t FH = kern_param.filter_meta.spatial[0]; + size_t FW = kern_param.filter_meta.spatial[1]; + size_t IC = kern_param.filter_meta.icpg; + size_t GROUP = kern_param.filter_meta.group; + size_t IH2, IW2, OH2, OW2; + get_rectified_size(kern_param, IH2, IW2, OH2, OW2); + bool need_src_copy_var = need_src_copy(kern_param); + bool need_dst_copy_var = need_dst_copy(kern_param); + bool need_post_process = + kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8; + //! if dst_type is qint32, the op is not used, just fill with (1.0f,4.0f) + Op op = Op(1.0f, 4.0f); + if (need_post_process) { + float scale_bias = + kern_param.bias_type.param().scale; + float scale_dst = kern_param.dst_type.param().scale; + op = Op(scale_bias, scale_dst); + } + size_t padding_group_size = IH2 * IW2 * IC; + bundle.set(kern_param.workspace_ptr); + //! Used for get the workspace offset + size_t workspace_group_id = workspace_ids[0], + workspace_batch_id = workspace_ids[1], oc = workspace_ids[2]; + size_t group_id = ncb_index.ndrange_id[0], + batch_id = ncb_index.ndrange_id[1]; + //! If large group, each thread has its own worspace, set group_id + //! with thread_id + const int8_t* sptr = kern_param.src(batch_id, group_id); + const int8_t* fptr = + kern_param.filter(group_id) + oc * FH * FW * IC; + void* dst = kern_param.dst(batch_id, group_id, oc); + const int32_t* bptr = kern_param.bias(batch_id, group_id, oc); + + if (need_src_copy_var) { + sptr = static_cast(bundle.get(0)) + + workspace_group_id * padding_group_size + + workspace_batch_id * GROUP * padding_group_size; + } + void* dptr = nullptr; + int32_t* tptr = nullptr; + if (need_dst_copy_var) { + dptr = reinterpret_cast( + reinterpret_cast(bundle.get(1)) + + ncb_index.thread_id * OH2 * OW2 * kern_param.dst_type.size()); + } else { + dptr = dst; + } + +#define KERN0_NEED_POST_PROCESS(filter, first_ic, last_ic) \ + conv_bias::conv_direct_stride2_##filter##x##filter##_int8_dot< \ + first_ic, last_ic, bias_mode, Op>( \ + sptr + ic * IH2 * IW2, fptr + ic * FH * FW, bptr, tptr, \ + static_cast(dptr), IH2, IW2, OH2, OW2, op) + +#define KERN0_NO_POST_PROCESS(filter, first_ic, last_ic) \ + conv_bias::conv_direct_stride2_##filter##x##filter##_int8_dot< \ + first_ic, last_ic, bias_mode, Op>( \ + sptr + ic * IH2 * IW2, fptr + ic * FH * FW, bptr, \ + static_cast(dptr), nullptr, IH2, IW2, OH2, OW2, op) + +#define KERN1_NEED_POST_PROCESS(filter) \ + KERN0_NEED_POST_PROCESS(filter, true, false); \ + for (ic = 1; ic < IC - 1; ++ic) { \ + KERN0_NEED_POST_PROCESS(filter, false, false); \ + } \ + KERN0_NEED_POST_PROCESS(filter, false, true); + +#define KERN1_NO_POST_PROCESS(filter) \ + KERN0_NO_POST_PROCESS(filter, true, false); \ + for (ic = 1; ic < IC; ++ic) { \ + KERN0_NO_POST_PROCESS(filter, false, false); \ + } + if (need_post_process) { + size_t ic = 0; + if (IC == 1) { + DISPATCH_FILTER(filter, KERN0_NEED_POST_PROCESS, true, true) + } else { + tptr = static_cast(bundle.get(2)) + + ncb_index.thread_id * OH2 * OW2 * kern_param.dst_type.size(); + DISPATCH_FILTER(filter, KERN1_NEED_POST_PROCESS) + } + } else { + size_t ic = 0; + if (IC == 1) { + DISPATCH_FILTER(filter, KERN0_NO_POST_PROCESS, true, false) + } else { + DISPATCH_FILTER(filter, KERN1_NO_POST_PROCESS) + } + } +#undef KERN0 +#undef 
KERN1_NEED_POST_PROCESS +#undef KERN1_NO_POST_PROCESS + if (need_dst_copy_var) { + rep(oh, OH) { + std::memcpy(reinterpret_cast( + reinterpret_cast(dst) + + oh * OW * kern_param.dst_type.size()), + reinterpret_cast( + reinterpret_cast(dptr) + + oh * OW2 * kern_param.dst_type.size()), + kern_param.dst_type.size() * OW); + } + } +} + +SmallVector direct_dotprod_int8_stride2::get_kimpls( + const NCBKernSizeParam& param, bool m_large_group) { + auto fm = param.filter_meta; + size_t N = param.n; + size_t IC = param.filter_meta.icpg; + size_t OC = param.filter_meta.ocpg; + size_t group = fm.group; + WorkspaceBundle wbundle = get_bundle(param, m_large_group); + conv_fun do_conv_fun = nullptr; + +#define DO_CONV_KERN_FUN(filter, bias_mode, op) \ + do_conv_fun = do_conv_kern; + +#define GET_OP_PARAM(i, bias_mode) \ + switch (param.nonlineMode) { \ + case param::ConvBias::NonlineMode::IDENTITY: \ + DO_CONV_KERN_FUN(i, bias_mode, \ + TypeCvtOp) \ + break; \ + case param::ConvBias::NonlineMode::RELU: \ + DO_CONV_KERN_FUN(i, bias_mode, \ + ReluOp) \ + break; \ + case param::ConvBias::NonlineMode::H_SWISH: \ + DO_CONV_KERN_FUN(i, bias_mode, \ + HSwishOp) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } + +#define GET_BIAS_MODE_PARAM(i) \ + switch (param.bias_mode) { \ + case BiasMode::NO_BIAS: \ + GET_OP_PARAM(i, BiasMode::NO_BIAS) \ + break; \ + case BiasMode::BROADCAST_CHANNEL_BIAS: \ + GET_OP_PARAM(i, BiasMode::BROADCAST_CHANNEL_BIAS) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } +#define DISPATCH_CONV_KERN() \ + switch (param.filter_meta.spatial[0]) { \ + case 2: \ + GET_BIAS_MODE_PARAM(2) \ + break; \ + case 3: \ + GET_BIAS_MODE_PARAM(3) \ + break; \ + case 5: \ + GET_BIAS_MODE_PARAM(5) \ + break; \ + case 7: \ + GET_BIAS_MODE_PARAM(7) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } + + DISPATCH_CONV_KERN(); + megdnn_assert(do_conv_fun); + + SmallVector ret_kerns; + if (m_large_group) { + auto exec_one_group = [wbundle, do_conv_fun]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + auto fm = kern_param.filter_meta; + size_t IC = fm.icpg; + size_t OC = fm.ocpg; + WorkspaceBundle bundle = wbundle; + for (size_t ic = 0; ic < IC; ic++) { + copy_padding_kern(bundle, kern_param, ncb_index, + {ncb_index.thread_id, 0, ic}); + } + for (size_t oc = 0; oc < OC; oc++) { + do_conv_fun(bundle, kern_param, ncb_index, + {ncb_index.thread_id, 0, oc}); + } + }; + ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); + } else { + WorkspaceBundle bundle = wbundle; + auto copy_padding = [bundle](const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + copy_padding_kern(bundle, kern_param, ncb_index, + ncb_index.ndrange_id); + }; + ret_kerns.push_back({copy_padding, {group, N, IC}}); + auto do_conv = [bundle, do_conv_fun](const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + do_conv_fun(bundle, kern_param, ncb_index, ncb_index.ndrange_id); + }; + ret_kerns.push_back({do_conv, {group, N, OC}}); + } + return ret_kerns; +} +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8/stride2_dotprod.h b/dnn/src/arm_common/conv_bias/int8/stride2_dotprod.h new file mode 100644 index 00000000..639cb224 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/stride2_dotprod.h @@ -0,0 +1,46 @@ +/** + * \file dnn/src/arm_common/conv_bias/int8/stride2_dotprod.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#if __ARM_FEATURE_DOTPROD +#pragma once +#include "src/arm_common/conv_bias/opr_impl.h" + +namespace megdnn { +namespace arm_common { +namespace direct_dotprod_int8_stride2 { +using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam; +using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam; +using NCBKernIndex = fallback::ConvBiasImpl::NCBKernIndex; + +using conv_fun = std::function; + +bool can_conv_direct_stride2_int8(const NCBKernSizeParam& param); + +WorkspaceBundle get_bundle(const NCBKernSizeParam& param, bool m_large_group); + +void copy_padding_kern(WorkspaceBundle bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids); + +template +void do_conv_kern(WorkspaceBundle bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids); + +SmallVector get_kimpls(const NCBKernSizeParam& param, + bool); +} // namespace direct_dotprod_int8_stride2 +} // namespace arm_common +} // namespace megdnn +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/algos.cpp b/dnn/src/arm_common/conv_bias/int8x8x16/algos.cpp new file mode 100644 index 00000000..6b99902f --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8x8x16/algos.cpp @@ -0,0 +1,565 @@ +/** + * \file dnn/src/arm_common/conv_bias/int8x8x16/algos.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "src/arm_common/conv_bias/int8x8x16/algos.h" +#include "src/arm_common/conv_bias/int8x8x16/conv_direct.h" +#include "src/arm_common/conv_bias/int8x8x16/conv_stride2.h" + +#include "midout.h" +#include "src/common/opr_delegate.h" +MIDOUT_DECL(megdnn_arm_common_conv_bias_int8816_kimpl) + +#include +#include +#include + +using namespace megdnn; +using namespace arm_common; + +namespace { +bool need_dst_copy_str1( + const megdnn::fallback::ConvolutionImpl::NCBKernSizeParam& param) { + if (param.osz[0] % 1 != 0 || param.osz[1] % 8 != 0) + return true; + return false; +} +bool need_src_copy_str1( + const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { + auto&& fm = param.filter_meta; + + if (fm.padding[0] != 0 || fm.padding[1] != 0) + return true; + + return need_dst_copy_str1(param); +} +void get_rectified_size_str1(size_t IH, size_t IW, size_t OH, size_t OW, + size_t PH, size_t PW, size_t& IH2, size_t& IW2, + size_t& OH2, size_t& OW2) { + OH2 = OH; + OW2 = (OW + 7) & ~7; + IH2 = OH2 + (IH - OH) + 2 * PH; + IW2 = OW2 + (IW - OW) + 2 * PW; +} +bool need_dst_copy_str2( + const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { + // If the size of output is not multiples of 8, we need to copy it. + if (param.osz[0] % 8 != 0 || param.osz[1] % 8 != 0) + return true; + return false; +} +bool need_src_copy_str2( + const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { + auto&& fm = param.filter_meta; + // If padding is not zero, we need to copy to eliminate padding effect. 
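+    // Besides padding, a copy is also needed whenever the output size is not a
+    // multiple of 8: the NEON kernels below write blocks of 8 outputs per store,
+    // so the result is computed into a rounded-up buffer
+    // (e.g. OW = 30 -> OW2 = (OW + 7) & ~7 = 32) and the extra columns are
+    // discarded when copying back to the real dst at the end of do_conv_kern.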
+ if (fm.padding[0] != 0 || fm.padding[1] != 0) + return true; + + return need_dst_copy_str2(param); +} +void get_rectified_size_str2(size_t IH, size_t IW, size_t OH, size_t OW, + size_t FH, size_t FW, size_t PH, size_t PW, + size_t& IH2, size_t& IW2, size_t& OH2, + size_t& OW2) { + MEGDNN_MARK_USED_VAR(PH); + MEGDNN_MARK_USED_VAR(PW); + OH2 = (OH + 7) & ~7; + OW2 = (OW + 7) & ~7; + IH2 = 2 * OH2 + FH - 2; + IW2 = 2 * OW2 + FW - 2; + // Because stride is 2, sometimes IH/W == IH/W2 + 1 + // Do a max update to handle this case. + IH2 = std::max(IH2, IH); + IW2 = std::max(IW2, IW); +} +} // namespace + +/* ===================== direct algo ===================== */ +bool ConvBiasImpl::AlgoI8x8x16Direct::usable( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const { + MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 1, 0) { + auto&& fm = param.filter_meta; + auto FH = fm.spatial[0]; + bool aviliable = + param.bias_mode == BiasMode::NO_BIAS && + param.nonlineMode == NonlineMode::IDENTITY && + fm.format == param::ConvBias::Format::NCHW && !fm.should_flip && + param.src_type.enumv() == DTypeEnum::Int8 && + param.filter_type.enumv() == DTypeEnum::Int8 && + param.dst_type.enumv() == DTypeEnum::Int16 && + fm.spatial_ndim == 2 && fm.dilation[0] == 1 && + fm.dilation[1] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1 && + FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5); + if (algo_selection_strategy == + ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { + bool large_group = param.filter_meta.group >= param.nr_threads; + aviliable &= (large_group == m_large_group); + } + return aviliable; + } + MIDOUT_END(); + return false; +} +WorkspaceBundle ConvBiasImpl::AlgoI8x8x16Direct::get_bundle( + const NCBKernSizeParam& param) const { + auto&& fm = param.filter_meta; + size_t nr_threads = param.nr_threads; + size_t group = fm.group, batch = param.n; + auto IC = fm.icpg, IH = param.isz[0], IW = param.isz[1]; + auto OH = param.osz[0], OW = param.osz[1]; + auto PH = fm.padding[0], PW = fm.padding[1]; + size_t OH2, OW2, IH2, IW2; + get_rectified_size_str1(IH, IW, OH, OW, PH, PW, IH2, IW2, OH2, OW2); + size_t part0 = 0u, part1 = 0u; + if (need_src_copy_str1(param)) { + part0 = m_large_group ? IC * IH2 * IW2 * sizeof(int8_t) * nr_threads + : IC * IH2 * IW2 * sizeof(int8_t) * group * batch; + } + if (need_dst_copy_str1(param)) { + part1 = OH2 * OW2 * sizeof(int16_t) * nr_threads + 16; + } + return {nullptr, {part0, part1}}; +} +size_t ConvBiasImpl::AlgoI8x8x16Direct::get_workspace( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 1, 1) { + auto bundle = get_bundle(param); + return bundle.total_size_in_bytes(); + } + MIDOUT_END(); + return 0; +} +//! 
Process one input channel copy padding +void ConvBiasImpl::AlgoI8x8x16Direct::copy_padding_kern( + WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids) { + size_t IH = kern_param.isz[0]; + size_t IW = kern_param.isz[1]; + size_t IC = kern_param.filter_meta.icpg; + size_t OH = kern_param.osz[0]; + size_t OW = kern_param.osz[1]; + size_t PH = kern_param.filter_meta.padding[0]; + size_t PW = kern_param.filter_meta.padding[1]; + size_t GROUP = kern_param.filter_meta.group; + size_t OH2, OW2, IH2, IW2; + get_rectified_size_str1(IH, IW, OH, OW, PH, PW, IH2, IW2, OH2, OW2); + bool need_src_copy_var = need_src_copy_str1(kern_param); + size_t padding_group_size = IH2 * IW2 * IC; + bundle.set(kern_param.workspace_ptr); + + //! Used for get the workspace offset + size_t workspace_group_id = workspace_ids[0], + workspace_batch_id = workspace_ids[1], + channel_id = workspace_ids[2]; + size_t group_id = ncb_index.ndrange_id[0], + batch_id = ncb_index.ndrange_id[1]; + const int8_t* sptr = kern_param.src(batch_id, group_id, channel_id); + if (need_src_copy_var) { + //! copy to sptr_base to eliminate padding effect + int8_t* sptr_base = static_cast(bundle.get(0)) + + workspace_group_id * padding_group_size + + workspace_batch_id * GROUP * padding_group_size + + channel_id * IH2 * IW2; + std::memset(sptr_base, 0, sizeof(int8_t) * IH2 * IW2); + rep(ih, IH) { + std::memcpy(sptr_base + (ih + PH) * IW2 + PW, sptr + ih * IW, + sizeof(int8_t) * IW); + } + } +}; +//! compute one output channel +void ConvBiasImpl::AlgoI8x8x16Direct::do_conv_kern( + WorkspaceBundle bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids) { + size_t OH = kern_param.osz[0]; + size_t OW = kern_param.osz[1]; + size_t IH = kern_param.isz[0]; + size_t IW = kern_param.isz[1]; + size_t FH = kern_param.filter_meta.spatial[0]; + size_t FW = kern_param.filter_meta.spatial[1]; + size_t IC = kern_param.filter_meta.icpg; + size_t PH = kern_param.filter_meta.padding[0]; + size_t PW = kern_param.filter_meta.padding[1]; + size_t GROUP = kern_param.filter_meta.group; + size_t OH2, OW2, IH2, IW2; + get_rectified_size_str1(IH, IW, OH, OW, PH, PW, IH2, IW2, OH2, OW2); + bool need_src_copy_var = need_src_copy_str1(kern_param); + bool need_dst_copy_var = need_dst_copy_str1(kern_param); + size_t padding_group_size = IH2 * IW2 * IC; + //! Choose the compute kernel + using Func = + std::function; + Func fun_not_add_to_dst = nullptr, fun_add_to_dst = nullptr; + if (FH == 2) { + fun_not_add_to_dst = + conv_bias::conv_direct_2x2_sc_int8_int8_int16; + fun_add_to_dst = conv_bias::conv_direct_2x2_sc_int8_int8_int16; + } else if (FH == 3) { + fun_not_add_to_dst = + conv_bias::conv_direct_3x3_sc_int8_int8_int16; + fun_add_to_dst = conv_bias::conv_direct_3x3_sc_int8_int8_int16; + } else if (FH == 5) { + fun_not_add_to_dst = + conv_bias::conv_direct_5x5_sc_int8_int8_int16; + fun_add_to_dst = conv_bias::conv_direct_5x5_sc_int8_int8_int16; + } + + bundle.set(kern_param.workspace_ptr); + //! 
Used for get the workspace offset + size_t workspace_group_id = workspace_ids[0], + workspace_batch_id = workspace_ids[1], oc = workspace_ids[2]; + + size_t group_id = ncb_index.ndrange_id[0], + batch_id = ncb_index.ndrange_id[1]; + + const int8_t* sptr = kern_param.src(batch_id, group_id); + const int8_t* filter = + kern_param.filter(group_id) + oc * FH * FW * IC; + int16_t* dst = kern_param.dst(batch_id, group_id, oc); + if (need_src_copy_var) { + sptr = static_cast(bundle.get(0)) + + workspace_group_id * padding_group_size + + workspace_batch_id * GROUP * padding_group_size; + } + int16_t* dptr = nullptr; + if (need_dst_copy_var) { + dptr = static_cast(bundle.get(1)) + + ncb_index.thread_id * OH2 * OW2; + } else { + dptr = dst; + } + fun_not_add_to_dst(sptr, filter, dptr, IH2, IW2, OH2, OW2, 0, 0); + for (size_t ic = 1; ic < IC; ++ic) { + fun_add_to_dst(sptr + ic * IH2 * IW2, filter + ic * FH * FW, dptr, IH2, + IW2, OH2, OW2, 0, 0); + } + if (need_dst_copy_var) { + rep(oh, OH) { + std::memcpy(dst + oh * OW, dptr + oh * OW2, sizeof(int16_t) * OW); + } + } +} +SmallVector ConvBiasImpl::AlgoI8x8x16Direct::get_kimpls( + const NCBKernSizeParam& param) const { + auto fm = param.filter_meta; + size_t N = param.n; + size_t IC = param.filter_meta.icpg; + size_t OC = param.filter_meta.ocpg; + size_t group = fm.group; + WorkspaceBundle wbundle = get_bundle(param); + SmallVector ret_kerns; + if (m_large_group) { + auto exec_one_group = [wbundle](const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + auto fm = kern_param.filter_meta; + size_t IC = fm.icpg; + size_t OC = fm.ocpg; + WorkspaceBundle bundle = wbundle; + for (size_t ic = 0; ic < IC; ic++) { + copy_padding_kern(bundle, kern_param, ncb_index, + {ncb_index.thread_id, 0, ic}); + } + for (size_t oc = 0; oc < OC; oc++) { + do_conv_kern(bundle, kern_param, ncb_index, + {ncb_index.thread_id, 0, oc}); + } + }; + ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); + } else { + WorkspaceBundle bundle = wbundle; + auto copy_padding = [bundle](const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + copy_padding_kern(bundle, kern_param, ncb_index, + ncb_index.ndrange_id); + }; + ret_kerns.push_back({copy_padding, {group, N, IC}}); + auto do_conv = [bundle](const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + do_conv_kern(bundle, kern_param, ncb_index, ncb_index.ndrange_id); + }; + ret_kerns.push_back({do_conv, {group, N, OC}}); + } + return ret_kerns; +} +SmallVector +ConvBiasImpl::AlgoI8x8x16Direct::dispatch_kerns( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 1, 2) { + return get_kimpls(param); + } + MIDOUT_END(); + return {}; +} + +/* ===================== stride-2 algo ===================== */ +bool ConvBiasImpl::AlgoI8x8x16Stride2::usable( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const { + MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 2, 0) { + auto&& fm = param.filter_meta; + auto FH = fm.spatial[0]; + bool aviliable = param.bias_mode == BiasMode::NO_BIAS && + param.nonlineMode == NonlineMode::IDENTITY && + fm.format == param::ConvBias::Format::NCHW && + !fm.should_flip && + param.src_type.enumv() == DTypeEnum::Int8 && + param.filter_type.enumv() == DTypeEnum::Int8 && + param.dst_type.enumv() == DTypeEnum::Int16 && + fm.dilation[0] == 1 && fm.dilation[1] == 1 && + fm.stride[0] == 2 && fm.stride[1] == 2 && + FH == fm.spatial[1] && (FH == 2 || FH == 
3 || FH == 5); + if (algo_selection_strategy == + ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { + bool large_group = param.filter_meta.group >= param.nr_threads; + aviliable &= (large_group == m_large_group); + } + return aviliable; + } + MIDOUT_END(); + return false; +} +WorkspaceBundle ConvBiasImpl::AlgoI8x8x16Stride2::get_bundle( + const NCBKernSizeParam& param) const { + auto&& fm = param.filter_meta; + size_t nr_threads = param.nr_threads; + size_t group = fm.group, batch = param.n; + auto IC = fm.icpg, IH = param.isz[0], IW = param.isz[1]; + auto OH = param.osz[0], OW = param.osz[1]; + auto PH = fm.padding[0], PW = fm.padding[1]; + auto FH = fm.spatial[0], FW = fm.spatial[1]; + size_t OH2, OW2, IH2, IW2; + get_rectified_size_str2(IH, IW, OH, OW, FH, FW, PH, PW, IH2, IW2, OH2, OW2); + size_t part0 = 0u, part1 = 0u; + if (need_src_copy_str2(param)) { + part0 = m_large_group ? IC * IH2 * IW2 * sizeof(int8_t) * nr_threads + : IC * IH2 * IW2 * sizeof(int8_t) * group * batch; + } + if (need_dst_copy_str2(param)) { + part1 = OH2 * OW2 * sizeof(int16_t) * nr_threads + 16; + } + return {nullptr, {part0, part1}}; +} +size_t ConvBiasImpl::AlgoI8x8x16Stride2::get_workspace( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 2, 1) { + auto bundle = get_bundle(param); + return bundle.total_size_in_bytes(); + } + MIDOUT_END(); + return 0; +} +//! Process one input channel copy padding +void ConvBiasImpl::AlgoI8x8x16Stride2::copy_padding_kern( + WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids) { + size_t IH = kern_param.isz[0]; + size_t IW = kern_param.isz[1]; + size_t IC = kern_param.filter_meta.icpg; + size_t OH = kern_param.osz[0]; + size_t OW = kern_param.osz[1]; + size_t PH = kern_param.filter_meta.padding[0]; + size_t PW = kern_param.filter_meta.padding[1]; + auto FH = kern_param.filter_meta.spatial[0], + FW = kern_param.filter_meta.spatial[1]; + size_t GROUP = kern_param.filter_meta.group; + size_t IH2, IW2, OH2, OW2; + get_rectified_size_str2(IH, IW, OH, OW, FH, FW, PH, PW, IH2, IW2, OH2, OW2); + bool need_src_copy_var = need_src_copy_str2(kern_param); + size_t padding_group_size = IH2 * IW2 * IC; + + bundle.set(kern_param.workspace_ptr); + size_t workspace_group_id = workspace_ids[0], + workspace_batch_id = workspace_ids[1], + channel_id = workspace_ids[2]; + size_t group_id = ncb_index.ndrange_id[0], + batch_id = ncb_index.ndrange_id[1]; + const int8_t* sptr = kern_param.src(batch_id, group_id, channel_id); + if (need_src_copy_var) { + //! copy to sptr_base to eliminate padding effect + int8_t* sptr_base = static_cast(bundle.get(0)) + + workspace_group_id * padding_group_size + + workspace_batch_id * GROUP * padding_group_size + + channel_id * IH2 * IW2; + std::memset(sptr_base, 0, sizeof(int8_t) * IH2 * IW2); + rep(ih, IH) { + std::memcpy(sptr_base + (ih + PH) * IW2 + PW, sptr + ih * IW, + sizeof(int8_t) * IW); + } + } +}; +//! 
compute one output channel +void ConvBiasImpl::AlgoI8x8x16Stride2::do_conv_kern( + WorkspaceBundle bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids) { + size_t OH = kern_param.osz[0]; + size_t OW = kern_param.osz[1]; + size_t IH = kern_param.isz[0]; + size_t IW = kern_param.isz[1]; + size_t FH = kern_param.filter_meta.spatial[0]; + size_t FW = kern_param.filter_meta.spatial[1]; + size_t IC = kern_param.filter_meta.icpg; + size_t PH = kern_param.filter_meta.padding[0]; + size_t PW = kern_param.filter_meta.padding[1]; + size_t GROUP = kern_param.filter_meta.group; + size_t IH2, IW2, OH2, OW2; + get_rectified_size_str2(IH, IW, OH, OW, FH, FW, PH, PW, IH2, IW2, OH2, OW2); + bool need_src_copy_var = need_src_copy_str2(kern_param); + bool need_dst_copy_var = need_dst_copy_str2(kern_param); + size_t padding_group_size = IH2 * IW2 * IC; + //! Choose the compute kernel + using Func = + std::function; + Func fun_not_add_to_dst = nullptr, fun_add_to_dst = nullptr; + if (FH == 2) { + fun_not_add_to_dst = + conv_bias::conv_stride2_2x2_sc_int8_int8_int16; + fun_add_to_dst = conv_bias::conv_stride2_2x2_sc_int8_int8_int16; + } else if (FH == 3) { + fun_not_add_to_dst = + conv_bias::conv_stride2_3x3_sc_int8_int8_int16; + fun_add_to_dst = conv_bias::conv_stride2_3x3_sc_int8_int8_int16; + } else if (FH == 5) { + fun_not_add_to_dst = + conv_bias::conv_stride2_5x5_sc_int8_int8_int16; + fun_add_to_dst = conv_bias::conv_stride2_5x5_sc_int8_int8_int16; + } + + bundle.set(kern_param.workspace_ptr); + //! Used for get the workspace offset + size_t workspace_group_id = workspace_ids[0], + workspace_batch_id = workspace_ids[1], oc = workspace_ids[2]; + size_t group_id = ncb_index.ndrange_id[0], + batch_id = ncb_index.ndrange_id[1]; + const int8_t* sptr = kern_param.src(batch_id, group_id); + const int8_t* filter = + kern_param.filter(group_id) + oc * FH * FW * IC; + int16_t* dst = kern_param.dst(batch_id, group_id, oc); + if (need_src_copy_var) { + sptr = static_cast(bundle.get(0)) + + workspace_group_id * padding_group_size + + workspace_batch_id * GROUP * padding_group_size; + } + int16_t* dptr = nullptr; + if (need_dst_copy_var) { + dptr = static_cast(bundle.get(1)) + + ncb_index.thread_id * OH2 * OW2; + } else { + dptr = dst; + } + fun_not_add_to_dst(sptr, filter, dptr, IH2, IW2, OH2, OW2, 0, 0); + for (size_t ic = 1; ic < IC; ++ic) { + fun_add_to_dst(sptr + ic * IH2 * IW2, filter + ic * FH * FW, dptr, IH2, + IW2, OH2, OW2, 0, 0); + } + if (need_dst_copy_var) { + rep(oh, OH) { + std::memcpy(dst + oh * OW, dptr + oh * OW2, sizeof(int16_t) * OW); + } + } +} +SmallVector ConvBiasImpl::AlgoI8x8x16Stride2::get_kimpls( + const NCBKernSizeParam& param) const { + auto fm = param.filter_meta; + size_t N = param.n; + size_t IC = param.filter_meta.icpg; + size_t OC = param.filter_meta.ocpg; + size_t group = fm.group; + WorkspaceBundle wbundle = get_bundle(param); + SmallVector ret_kerns; + if (m_large_group) { + auto exec_one_group = [wbundle](const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + auto fm = kern_param.filter_meta; + size_t IC = fm.icpg; + size_t OC = fm.ocpg; + WorkspaceBundle bundle = wbundle; + for (size_t ic = 0; ic < IC; ic++) { + copy_padding_kern(bundle, kern_param, ncb_index, + {ncb_index.thread_id, 0, ic}); + } + for (size_t oc = 0; oc < OC; oc++) { + do_conv_kern(bundle, kern_param, ncb_index, + {ncb_index.thread_id, 0, oc}); + } + }; + ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); + } else { + WorkspaceBundle bundle = 
wbundle; + auto copy_padding = [bundle](const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + copy_padding_kern(bundle, kern_param, ncb_index, + ncb_index.ndrange_id); + }; + ret_kerns.push_back({copy_padding, {group, N, IC}}); + auto do_conv = [bundle](const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + do_conv_kern(bundle, kern_param, ncb_index, ncb_index.ndrange_id); + }; + ret_kerns.push_back({do_conv, {group, N, OC}}); + } + return ret_kerns; +} +SmallVector +ConvBiasImpl::AlgoI8x8x16Stride2::dispatch_kerns( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 2, 2) { + return get_kimpls(param); + } + MIDOUT_END(); + return {}; +} +bool ConvBiasImpl::AlgoI8x8x16Stride2Filter2::usable( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param, + AlgoSelectionStrategy /*algo_selection_strategy*/) const { + MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 3, 0) { + return param.bias_mode == BiasMode::NO_BIAS && + param.nonlineMode == NonlineMode::IDENTITY && + param.nr_threads == 1_z && + conv_bias::can_conv_int8x8x16_stride2_flt2(param); + } + MIDOUT_END(); + return false; +} + +size_t ConvBiasImpl::AlgoI8x8x16Stride2Filter2::get_workspace( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 3, 1) { + return conv_bias::get_workspace_in_bytes_conv_int8x8x16_stride2_flt2( + param); + } + MIDOUT_END(); + return 0; +} + +SmallVector +ConvBiasImpl::AlgoI8x8x16Stride2Filter2::dispatch_kerns( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + // return {conv_bias::conv_int8x8x16_stride2_flt2,true}; + auto kern = [](const NCBKernParam& param, const NCBKernIndex& ncb_index) { + MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 3, 2) { + auto ncb_param = param; + ncb_param.src_ptr = param.src(0, ncb_index.ndrange_id[0]); + ncb_param.dst_ptr = param.dst(0, ncb_index.ndrange_id[0]); + ncb_param.filter_ptr = param.filter(ncb_index.ndrange_id[0]); + ncb_param.bias_ptr = param.bias(0, ncb_index.ndrange_id[0]); + conv_bias::conv_int8x8x16_stride2_flt2(ncb_param); + } + MIDOUT_END(); + }; + size_t group = param.filter_meta.group; + return {{kern, {group, 1_z, 1_z}}}; +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/algos.h b/dnn/src/arm_common/conv_bias/int8x8x16/algos.h new file mode 100644 index 00000000..c89e0a6a --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8x8x16/algos.h @@ -0,0 +1,95 @@ +/** + * \file dnn/src/arm_common/conv_bias/int8x8x16/algos.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#pragma once + +#include "../opr_impl.h" + +namespace megdnn { +namespace arm_common { +class ConvBiasImpl::AlgoI8x8x16Direct final : public AlgoBase { + SmallVector get_kimpls(const NCBKernSizeParam& param) const; + WorkspaceBundle get_bundle(const NCBKernSizeParam& param) const; + static void copy_padding_kern(WorkspaceBundle bundle, + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids); + static void do_conv_kern(WorkspaceBundle bundle, + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids); + bool m_large_group; + +public: + AlgoI8x8x16Direct(bool large_group) : m_large_group(large_group) {} + bool is_reproducible() const override { return true; } + const char* name() const override { + return m_large_group ? "I8816DIRECT_LARGE_GROUP" + : "I8816DIRECT_SMALL_GROUP"; + } + bool usable(fallback::ConvBiasImpl*, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; + size_t get_workspace(fallback::ConvBiasImpl*, + const NCBKernSizeParam& param) const override; + virtual SmallVector dispatch_kerns( + fallback::ConvBiasImpl* opr, + const NCBKernSizeParam& param) const override; +}; + +class ConvBiasImpl::AlgoI8x8x16Stride2 final : public AlgoBase { + SmallVector get_kimpls(const NCBKernSizeParam& param) const; + WorkspaceBundle get_bundle(const NCBKernSizeParam& param) const; + static void copy_padding_kern(WorkspaceBundle bundle, + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids); + static void do_conv_kern(WorkspaceBundle bundle, + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids); + bool m_large_group; + +public: + AlgoI8x8x16Stride2(bool large_group) : m_large_group(large_group) {} + bool is_reproducible() const override { return true; } + const char* name() const override { + return m_large_group ? "I8816STRD2_LARGE_GROUP" + : "I8816STRD2_SMALL_GROUP"; + } + bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; + + size_t get_workspace(fallback::ConvBiasImpl*, + const NCBKernSizeParam& param) const override; + virtual SmallVector dispatch_kerns( + fallback::ConvBiasImpl* opr, + const NCBKernSizeParam& param) const override; +}; + +class ConvBiasImpl::AlgoI8x8x16Stride2Filter2 final : public AlgoBase { +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "I8816STRD2F2"; } + + bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; + + size_t get_workspace(fallback::ConvBiasImpl*, + const NCBKernSizeParam& param) const override; + virtual SmallVector dispatch_kerns( + fallback::ConvBiasImpl* opr, + const NCBKernSizeParam& param) const override; +}; + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/conv_direct.cpp b/dnn/src/arm_common/conv_bias/int8x8x16/conv_direct.cpp new file mode 100644 index 00000000..f933116b --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8x8x16/conv_direct.cpp @@ -0,0 +1,598 @@ +/** + * \file dnn/src/arm_common/conv_bias/int8x8x16/conv_direct.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "src/arm_common/conv_bias/int8x8x16/conv_direct.h" +#include "src/common/utils.h" + +#include +#include "midout.h" +#include "src/arm_common/simd_macro/marm_neon.h" + +MIDOUT_DECL(megdnn_arm_common_conv_bias_int8816_filter) + +using namespace megdnn; +using namespace arm_common; +using namespace conv_bias; + +template +void conv_bias::conv_direct_2x2_sc_int8_int8_int16(const int8_t* src, const int8_t* filter, + int16_t* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t PH, + size_t PW) { + size_t OH_start = PH, OH_stop = OH - PH; + size_t OW_start = PW, OW_stop = OW - PW; + auto run_single = [&](size_t oh, size_t ow) { + if (!add_to_dst) { + dst[oh * OW + ow] = 0; + } + for (size_t fh = 0; fh < 2; ++fh) + for (size_t fw = 0; fw < 2; ++fw) { + size_t ih = oh + fh - PH; + size_t iw = ow + fw - PW; + if (ih < IH && iw < IW) { + dst[oh * OW + ow] += + (int16_t)src[ih * IW + iw] * filter[fh * 2 + fw]; + } + } + }; + for (size_t oh = 0; oh < OH_start; ++oh) { + for (size_t ow = 0; ow < OW; ++ow) { + run_single(oh, ow); + } + } + for (size_t oh = OH_start; oh < OH_stop; ++oh) { + for (size_t ow = 0; ow < OW_start; ++ow) + run_single(oh, ow); + for (size_t ow = OW_stop; ow < OW; ++ow) + run_single(oh, ow); + } + for (size_t oh = OH_stop; oh < OH; ++oh) { + for (size_t ow = 0; ow < OW; ++ow) { + run_single(oh, ow); + } + } + // 4x8 block + size_t oh = OH_start; + for (; oh + 4 <= OH_stop; oh += 4) { + size_t ih = oh - PH; + size_t ow = OW_start; + for (; ow < OW_stop; ow += 8) { + size_t iw = ow - PW; + int16_t* __restrict dptr = dst + oh * OW + ow; + const int8_t* __restrict sptr = src + ih * IW + iw; + const int8_t* __restrict fptr = filter; + int16x8_t d0, d1, d2, d3; + int8x8_t k0, k1, s; + if (add_to_dst) { + d0 = vld1q_s16(dptr + 0 * OW); + d1 = vld1q_s16(dptr + 1 * OW); + d2 = vld1q_s16(dptr + 2 * OW); + d3 = vld1q_s16(dptr + 3 * OW); + } else { + d0 = vdupq_n_s16(0); + d1 = vdupq_n_s16(0); + d2 = vdupq_n_s16(0); + d3 = vdupq_n_s16(0); + } + + for (size_t fw = 0; fw < 2; ++fw) { + k0 = vdup_n_s8(fptr[0 * 2 + fw]); + k1 = vdup_n_s8(fptr[1 * 2 + fw]); + + s = vld1_s8(sptr + 0 * IW); + d0 = vmlal_s8(d0, k0, s); + + s = vld1_s8(sptr + 1 * IW); + d0 = vmlal_s8(d0, k1, s); + d1 = vmlal_s8(d1, k0, s); + + s = vld1_s8(sptr + 2 * IW); + d1 = vmlal_s8(d1, k1, s); + d2 = vmlal_s8(d2, k0, s); + + s = vld1_s8(sptr + 3 * IW); + d2 = vmlal_s8(d2, k1, s); + d3 = vmlal_s8(d3, k0, s); + + s = vld1_s8(sptr + 4 * IW); + d3 = vmlal_s8(d3, k1, s); + + ++sptr; + } + vst1q_s16(dptr + 0 * OW, d0); + vst1q_s16(dptr + 1 * OW, d1); + vst1q_s16(dptr + 2 * OW, d2); + vst1q_s16(dptr + 3 * OW, d3); + } + } + if (oh + 3 == OH_stop) { + size_t ih = oh - PH; + size_t ow = OW_start; + for (; ow < OW_stop; ow += 8) { + size_t iw = ow - PW; + int16_t* __restrict dptr = dst + oh * OW + ow; + const int8_t* __restrict sptr = src + ih * IW + iw; + const int8_t* __restrict fptr = filter; + int16x8_t d0, d1, d2; + int8x8_t k0, k1, s; + if (add_to_dst) { + d0 = vld1q_s16(dptr + 0 * OW); + d1 = vld1q_s16(dptr + 1 * OW); + d2 = vld1q_s16(dptr + 2 * OW); + } else { + d0 = vdupq_n_s16(0); + d1 = vdupq_n_s16(0); + d2 = vdupq_n_s16(0); + } + for (size_t fw = 0; fw < 2; ++fw) { + k0 = vdup_n_s8(fptr[0 * 2 + fw]); + k1 = vdup_n_s8(fptr[1 * 2 + fw]); + + s = vld1_s8(sptr + 0 * IW); + d0 = vmlal_s8(d0, 
k0, s); + + s = vld1_s8(sptr + 1 * IW); + d0 = vmlal_s8(d0, k1, s); + d1 = vmlal_s8(d1, k0, s); + + s = vld1_s8(sptr + 2 * IW); + d1 = vmlal_s8(d1, k1, s); + d2 = vmlal_s8(d2, k0, s); + + s = vld1_s8(sptr + 3 * IW); + d2 = vmlal_s8(d2, k1, s); + + ++sptr; + } + vst1q_s16(dptr + 0 * OW, d0); + vst1q_s16(dptr + 1 * OW, d1); + vst1q_s16(dptr + 2 * OW, d2); + } + } else if (oh + 2 == OH_stop) { + size_t ih = oh - PH; + size_t ow = OW_start; + for (; ow < OW_stop; ow += 8) { + size_t iw = ow - PW; + int16_t* __restrict dptr = dst + oh * OW + ow; + const int8_t* __restrict sptr = src + ih * IW + iw; + const int8_t* __restrict fptr = filter; + int16x8_t d0, d1; + int8x8_t k0, k1, s; + if (add_to_dst) { + d0 = vld1q_s16(dptr + 0 * OW); + d1 = vld1q_s16(dptr + 1 * OW); + } else { + d0 = vdupq_n_s16(0); + d1 = vdupq_n_s16(0); + } + for (size_t fw = 0; fw < 2; ++fw) { + k0 = vdup_n_s8(fptr[0 * 2 + fw]); + k1 = vdup_n_s8(fptr[1 * 2 + fw]); + + s = vld1_s8(sptr + 0 * IW); + d0 = vmlal_s8(d0, k0, s); + + s = vld1_s8(sptr + 1 * IW); + d0 = vmlal_s8(d0, k1, s); + d1 = vmlal_s8(d1, k0, s); + + s = vld1_s8(sptr + 2 * IW); + d1 = vmlal_s8(d1, k1, s); + + ++sptr; + } + vst1q_s16(dptr + 0 * OW, d0); + vst1q_s16(dptr + 1 * OW, d1); + } + } else if (oh + 1 == OH_stop) { + size_t ih = oh - PH; + size_t ow = OW_start; + for (; ow < OW_stop; ow += 8) { + size_t iw = ow - PW; + int16_t* __restrict dptr = dst + oh * OW + ow; + const int8_t* __restrict sptr = src + ih * IW + iw; + const int8_t* __restrict fptr = filter; + int16x8_t d0; + int8x8_t k0, k1, s; + if (add_to_dst) { + d0 = vld1q_s16(dptr + 0 * OW); + } else { + d0 = vdupq_n_s16(0); + } + for (size_t fw = 0; fw < 2; ++fw) { + k0 = vdup_n_s8(fptr[0 * 2 + fw]); + k1 = vdup_n_s8(fptr[1 * 2 + fw]); + + s = vld1_s8(sptr + 0 * IW); + d0 = vmlal_s8(d0, k0, s); + + s = vld1_s8(sptr + 1 * IW); + d0 = vmlal_s8(d0, k1, s); + + ++sptr; + } + vst1q_s16(dptr + 0 * OW, d0); + } + } +} + +template +void conv_bias::conv_direct_3x3_sc_int8_int8_int16(const int8_t* src, const int8_t* filter, + int16_t* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t PH, + size_t PW) { + size_t OH_start = PH, OH_stop = OH - PH; + size_t OW_start = PW, OW_stop = OW - PW; + + auto run_single = [&](size_t oh, size_t ow) { + if (!add_to_dst) { + dst[oh * OW + ow] = 0; + } + for (size_t fh = 0; fh < 3; ++fh) + for (size_t fw = 0; fw < 3; ++fw) { + size_t ih = oh + fh - PH; + size_t iw = ow + fw - PW; + if (ih < IH && iw < IW) { + dst[oh * OW + ow] += + (int16_t)src[ih * IW + iw] * filter[fh * 3 + fw]; + } + } + }; + + for (size_t oh = 0; oh < OH_start; ++oh) { + for (size_t ow = 0; ow < OW; ++ow) { + run_single(oh, ow); + } + } + for (size_t oh = OH_start; oh < OH_stop; ++oh) { + for (size_t ow = 0; ow < OW_start; ++ow) + run_single(oh, ow); + for (size_t ow = OW_stop; ow < OW; ++ow) + run_single(oh, ow); + } + for (size_t oh = OH_stop; oh < OH; ++oh) { + for (size_t ow = 0; ow < OW; ++ow) { + run_single(oh, ow); + } + } + + // 4x8 block + size_t oh = OH_start; + for (; oh + 4 <= OH_stop; oh += 4) { + size_t ih = oh - PH; + size_t ow = OW_start; + for (; ow < OW_stop; ow += 8) { + size_t iw = ow - PW; + int16_t* __restrict dptr = dst + oh * OW + ow; + const int8_t* __restrict sptr = src + ih * IW + iw; + const int8_t* __restrict fptr = filter; + int16x8_t d0, d1, d2, d3; + int8x8_t k0, k1, k2, s; + if (add_to_dst) { + d0 = vld1q_s16(dptr + 0 * OW); + d1 = vld1q_s16(dptr + 1 * OW); + d2 = vld1q_s16(dptr + 2 * OW); + d3 = vld1q_s16(dptr + 3 * OW); + } else { + d0 = vdupq_n_s16(0); + d1 = 
vdupq_n_s16(0); + d2 = vdupq_n_s16(0); + d3 = vdupq_n_s16(0); + } + for (size_t fw = 0; fw < 3; ++fw) { + k0 = vdup_n_s8(fptr[0 * 3 + fw]); + k1 = vdup_n_s8(fptr[1 * 3 + fw]); + k2 = vdup_n_s8(fptr[2 * 3 + fw]); + + s = vld1_s8(sptr + 0 * IW); + d0 = vmlal_s8(d0, k0, s); + + s = vld1_s8(sptr + 1 * IW); + d0 = vmlal_s8(d0, k1, s); + d1 = vmlal_s8(d1, k0, s); + + s = vld1_s8(sptr + 2 * IW); + d0 = vmlal_s8(d0, k2, s); + d1 = vmlal_s8(d1, k1, s); + d2 = vmlal_s8(d2, k0, s); + + s = vld1_s8(sptr + 3 * IW); + d1 = vmlal_s8(d1, k2, s); + d2 = vmlal_s8(d2, k1, s); + d3 = vmlal_s8(d3, k0, s); + + s = vld1_s8(sptr + 4 * IW); + d2 = vmlal_s8(d2, k2, s); + d3 = vmlal_s8(d3, k1, s); + + s = vld1_s8(sptr + 5 * IW); + d3 = vmlal_s8(d3, k2, s); + + ++sptr; + } + vst1q_s16(dptr + 0 * OW, d0); + vst1q_s16(dptr + 1 * OW, d1); + vst1q_s16(dptr + 2 * OW, d2); + vst1q_s16(dptr + 3 * OW, d3); + } + } + + if (oh + 3 == OH_stop) { + size_t ih = oh - PH; + size_t ow = OW_start; + for (; ow < OW_stop; ow += 8) { + size_t iw = ow - PW; + int16_t* __restrict dptr = dst + oh * OW + ow; + const int8_t* __restrict sptr = src + ih * IW + iw; + const int8_t* __restrict fptr = filter; + int16x8_t d0, d1, d2; + int8x8_t k0, k1, k2, s; + if (add_to_dst) { + d0 = vld1q_s16(dptr + 0 * OW); + d1 = vld1q_s16(dptr + 1 * OW); + d2 = vld1q_s16(dptr + 2 * OW); + } else { + d0 = vdupq_n_s16(0); + d1 = vdupq_n_s16(0); + d2 = vdupq_n_s16(0); + } + for (size_t fw = 0; fw < 3; ++fw) { + k0 = vdup_n_s8(fptr[0 * 3 + fw]); + k1 = vdup_n_s8(fptr[1 * 3 + fw]); + k2 = vdup_n_s8(fptr[2 * 3 + fw]); + + s = vld1_s8(sptr + 0 * IW); + d0 = vmlal_s8(d0, k0, s); + + s = vld1_s8(sptr + 1 * IW); + d0 = vmlal_s8(d0, k1, s); + d1 = vmlal_s8(d1, k0, s); + + s = vld1_s8(sptr + 2 * IW); + d0 = vmlal_s8(d0, k2, s); + d1 = vmlal_s8(d1, k1, s); + d2 = vmlal_s8(d2, k0, s); + + s = vld1_s8(sptr + 3 * IW); + d1 = vmlal_s8(d1, k2, s); + d2 = vmlal_s8(d2, k1, s); + + s = vld1_s8(sptr + 4 * IW); + d2 = vmlal_s8(d2, k2, s); + ++sptr; + } + vst1q_s16(dptr + 0 * OW, d0); + vst1q_s16(dptr + 1 * OW, d1); + vst1q_s16(dptr + 2 * OW, d2); + } + } else if (oh + 2 == OH_stop) { + size_t ih = oh - PH; + size_t ow = OW_start; + for (; ow < OW_stop; ow += 8) { + size_t iw = ow - PW; + int16_t* __restrict dptr = dst + oh * OW + ow; + const int8_t* __restrict sptr = src + ih * IW + iw; + const int8_t* __restrict fptr = filter; + int16x8_t d0, d1; + int8x8_t k0, k1, k2, s; + if (add_to_dst) { + d0 = vld1q_s16(dptr + 0 * OW); + d1 = vld1q_s16(dptr + 1 * OW); + } else { + d0 = vdupq_n_s16(0); + d1 = vdupq_n_s16(0); + } + for (size_t fw = 0; fw < 3; ++fw) { + k0 = vdup_n_s8(fptr[0 * 3 + fw]); + k1 = vdup_n_s8(fptr[1 * 3 + fw]); + k2 = vdup_n_s8(fptr[2 * 3 + fw]); + + s = vld1_s8(sptr + 0 * IW); + d0 = vmlal_s8(d0, k0, s); + + s = vld1_s8(sptr + 1 * IW); + d0 = vmlal_s8(d0, k1, s); + d1 = vmlal_s8(d1, k0, s); + + s = vld1_s8(sptr + 2 * IW); + d0 = vmlal_s8(d0, k2, s); + d1 = vmlal_s8(d1, k1, s); + + s = vld1_s8(sptr + 3 * IW); + d1 = vmlal_s8(d1, k2, s); + + ++sptr; + } + vst1q_s16(dptr + 0 * OW, d0); + vst1q_s16(dptr + 1 * OW, d1); + } + } else if (oh + 1 == OH_stop) { + size_t ih = oh - PH; + size_t ow = OW_start; + + for (; ow < OW_stop; ow += 8) { + size_t iw = ow - PW; + int16_t* __restrict dptr = dst + oh * OW + ow; + const int8_t* __restrict sptr = src + ih * IW + iw; + const int8_t* __restrict fptr = filter; + int16x8_t d0; + int8x8_t k0, k1, k2, s; + + if (add_to_dst) { + d0 = vld1q_s16(dptr + 0 * OW); + } else { + d0 = vdupq_n_s16(0); + } + for (size_t fw = 0; fw < 3; ++fw) { 
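+                // last output row of this block: each of the three input rows in the strip is scaled by the matching filter row (k0..k2) and accumulated into d0 only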
+ k0 = vdup_n_s8(fptr[0 * 3 + fw]); + k1 = vdup_n_s8(fptr[1 * 3 + fw]); + k2 = vdup_n_s8(fptr[2 * 3 + fw]); + + s = vld1_s8(sptr + 0 * IW); + d0 = vmlal_s8(d0, k0, s); + + s = vld1_s8(sptr + 1 * IW); + d0 = vmlal_s8(d0, k1, s); + + s = vld1_s8(sptr + 2 * IW); + d0 = vmlal_s8(d0, k2, s); + + ++sptr; + } + vst1q_s16(dptr + 0 * OW, d0); + } + } +} + +template +void conv_bias::conv_direct_5x5_sc_int8_int8_int16(const int8_t* src, const int8_t* filter, + int16_t* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t PH, + size_t PW) { + size_t OH_start = PH, OH_stop = OH - PH; + size_t OW_start = PW, OW_stop = OW - PW; + auto run_single = [&](size_t oh, size_t ow) { + if (!add_to_dst) { + dst[oh * OW + ow] = 0; + } + for (size_t fh = 0; fh < 5; ++fh) + for (size_t fw = 0; fw < 5; ++fw) { + size_t ih = oh + fh - PH; + size_t iw = ow + fw - PW; + if (ih < IH && iw < IW) { + dst[oh * OW + ow] += + (int16_t)src[ih * IW + iw] * filter[fh * 5 + fw]; + } + } + }; + for (size_t oh = 0; oh < OH_start; ++oh) { + for (size_t ow = 0; ow < OW; ++ow) { + run_single(oh, ow); + } + } + for (size_t oh = OH_start; oh < OH_stop; ++oh) { + for (size_t ow = 0; ow < OW_start; ++ow) + run_single(oh, ow); + for (size_t ow = OW_stop; ow < OW; ++ow) + run_single(oh, ow); + } + for (size_t oh = OH_stop; oh < OH; ++oh) { + for (size_t ow = 0; ow < OW; ++ow) { + run_single(oh, ow); + } + } + // 4x8 block + size_t oh = OH_start; + for (; oh + 4 <= OH_stop; oh += 4) { + size_t ih = oh - PH; + size_t ow = OW_start; + for (; ow + 8 <= OW_stop; ow += 8) { + size_t iw = ow - PW; + int16_t* __restrict dptr = dst + oh * OW + ow; + const int8_t* __restrict sptr = src + ih * IW + iw; + const int8_t* __restrict fptr = filter; + int16x8_t d0, d1, d2, d3; + int8x8_t k0, k1, k2, k3, k4, s; + if (add_to_dst) { + d0 = vld1q_s16(dptr + 0 * OW); + d1 = vld1q_s16(dptr + 1 * OW); + d2 = vld1q_s16(dptr + 2 * OW); + d3 = vld1q_s16(dptr + 3 * OW); + } else { + d0 = vdupq_n_s16(0); + d1 = vdupq_n_s16(0); + d2 = vdupq_n_s16(0); + d3 = vdupq_n_s16(0); + } + for (size_t fw = 0; fw < 5; ++fw) { + k0 = vdup_n_s8(fptr[0 * 5 + fw]); + k1 = vdup_n_s8(fptr[1 * 5 + fw]); + k2 = vdup_n_s8(fptr[2 * 5 + fw]); + k3 = vdup_n_s8(fptr[3 * 5 + fw]); + k4 = vdup_n_s8(fptr[4 * 5 + fw]); + + s = vld1_s8(sptr + 0 * IW); + d0 = vmlal_s8(d0, k0, s); + + s = vld1_s8(sptr + 1 * IW); + d0 = vmlal_s8(d0, k1, s); + d1 = vmlal_s8(d1, k0, s); + + s = vld1_s8(sptr + 2 * IW); + d0 = vmlal_s8(d0, k2, s); + d1 = vmlal_s8(d1, k1, s); + d2 = vmlal_s8(d2, k0, s); + + s = vld1_s8(sptr + 3 * IW); + d0 = vmlal_s8(d0, k3, s); + d1 = vmlal_s8(d1, k2, s); + d2 = vmlal_s8(d2, k1, s); + d3 = vmlal_s8(d3, k0, s); + + s = vld1_s8(sptr + 4 * IW); + d0 = vmlal_s8(d0, k4, s); + d1 = vmlal_s8(d1, k3, s); + d2 = vmlal_s8(d2, k2, s); + d3 = vmlal_s8(d3, k1, s); + + s = vld1_s8(sptr + 5 * IW); + d1 = vmlal_s8(d1, k4, s); + d2 = vmlal_s8(d2, k3, s); + d3 = vmlal_s8(d3, k2, s); + + s = vld1_s8(sptr + 6 * IW); + d2 = vmlal_s8(d2, k4, s); + d3 = vmlal_s8(d3, k3, s); + + s = vld1_s8(sptr + 7 * IW); + d3 = vmlal_s8(d3, k4, s); + + ++sptr; + } + vst1q_s16(dptr + 0 * OW, d0); + vst1q_s16(dptr + 1 * OW, d1); + vst1q_s16(dptr + 2 * OW, d2); + vst1q_s16(dptr + 3 * OW, d3); + } + for (; ow < OW_stop; ++ow) { + run_single(oh + 0, ow); + run_single(oh + 1, ow); + run_single(oh + 2, ow); + run_single(oh + 3, ow); + } + } + for (; oh < OH_stop; ++oh) { + for (size_t ow = OW_start; ow < OW_stop; ++ow) { + run_single(oh, ow); + } + } +} + +template void conv_bias::conv_direct_2x2_sc_int8_int8_int16( + const 
int8_t* src, const int8_t* filter, int16_t* dst, size_t IH, + size_t IW, size_t OH, size_t OW, size_t PH, size_t PW); +template void conv_bias::conv_direct_2x2_sc_int8_int8_int16( + const int8_t* src, const int8_t* filter, int16_t* dst, size_t IH, + size_t IW, size_t OH, size_t OW, size_t PH, size_t PW); +template void conv_bias::conv_direct_3x3_sc_int8_int8_int16( + const int8_t* src, const int8_t* filter, int16_t* dst, size_t IH, + size_t IW, size_t OH, size_t OW, size_t PH, size_t PW); +template void conv_bias::conv_direct_3x3_sc_int8_int8_int16( + const int8_t* src, const int8_t* filter, int16_t* dst, size_t IH, + size_t IW, size_t OH, size_t OW, size_t PH, size_t PW); +template void conv_bias::conv_direct_5x5_sc_int8_int8_int16( + const int8_t* src, const int8_t* filter, int16_t* dst, size_t IH, + size_t IW, size_t OH, size_t OW, size_t PH, size_t PW); +template void conv_bias::conv_direct_5x5_sc_int8_int8_int16( + const int8_t* src, const int8_t* filter, int16_t* dst, size_t IH, + size_t IW, size_t OH, size_t OW, size_t PH, size_t PW); + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/conv_direct.h b/dnn/src/arm_common/conv_bias/int8x8x16/conv_direct.h new file mode 100644 index 00000000..8a07b25b --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8x8x16/conv_direct.h @@ -0,0 +1,41 @@ +/** + * \file dnn/src/arm_common/conv_bias/int8x8x16/conv_direct.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once + +#include "src/arm_common/conv_bias/opr_impl.h" + +#include +#include + +namespace megdnn { +namespace arm_common { +namespace conv_bias { + +template +void conv_direct_2x2_sc_int8_int8_int16(const int8_t* src, const int8_t* filter, + int16_t* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t PH, + size_t PW); +template +void conv_direct_3x3_sc_int8_int8_int16(const int8_t* src, const int8_t* filter, + int16_t* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t PH, + size_t PW); +template +void conv_direct_5x5_sc_int8_int8_int16(const int8_t* src, const int8_t* filter, + int16_t* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t PH, + size_t PW); +} // namespace conv_bias +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/conv_stride2.cpp b/dnn/src/arm_common/conv_bias/int8x8x16/conv_stride2.cpp new file mode 100644 index 00000000..55ebf206 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8x8x16/conv_stride2.cpp @@ -0,0 +1,540 @@ +/** + * \file dnn/src/arm_common/conv_bias/int8x8x16/conv_stride2.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +#include "src/arm_common/conv_bias/int8x8x16/conv_stride2.h" +#include "src/common/utils.h" + +#include +#include "midout.h" +#include "src/arm_common/simd_macro/marm_neon.h" +MIDOUT_DECL(megdnn_arm_common_conv_bias_s2_filter) + +#pragma GCC diagnostic ignored "-Wunused-parameter" + +using namespace megdnn; +using namespace arm_common; +using namespace conv_bias; + +template +void conv_bias::conv_stride2_2x2_sc_int8_int8_int16( + const int8_t* src, const int8_t* filter, int16_t* dst, size_t IH, + size_t IW, size_t OH, size_t OW, size_t PH, size_t PW) { + size_t OH_start = div_ceil(PH, 2), + OH_stop = div_floor(IH + PH - 2, 2) + 1, + OW_start = div_ceil(PW, 2), + OW_stop = div_floor(IW + PW - 2, 2) + 1; + OH_start = std::min(OH, OH_start); + OH_stop = std::min(OH, OH_stop); + OW_start = std::min(OW, OW_start); + OW_stop = std::min(OW, OW_stop); + auto run_single = [&](size_t oh, size_t ow) { + if (!add_to_dst) { + dst[oh * OW + ow] = 0; + } + for (size_t fh = 0; fh < 2; ++fh) + for (size_t fw = 0; fw < 2; ++fw) { + size_t ih = oh * 2 + fh - PH; + size_t iw = ow * 2 + fw - PW; + if (ih < IH && iw < IW) { + dst[oh * OW + ow] += + (int16_t)src[ih * IW + iw] * filter[fh * 2 + fw]; + } + } + }; + for (size_t oh = 0; oh < OH_start; ++oh) { + for (size_t ow = 0; ow < OW; ++ow) { + run_single(oh, ow); + } + } + for (size_t oh = OH_start; oh < OH_stop; ++oh) { + for (size_t ow = 0; ow < OW_start; ++ow) + run_single(oh, ow); + for (size_t ow = OW_stop; ow < OW; ++ow) + run_single(oh, ow); + } + for (size_t oh = OH_stop; oh < OH; ++oh) { + for (size_t ow = 0; ow < OW; ++ow) { + run_single(oh, ow); + } + } + // 1x4 block + int8_t workspace[16]; + for (size_t i = 0; i < 8; ++i) + workspace[i] = filter[i & 1]; + for (size_t i = 8; i < 16; ++i) + workspace[i] = filter[(i & 1) + 2]; + int8x8_t f0 = vld1_s8(workspace + 0), f1 = vld1_s8(workspace + 8); + for (size_t oh = OH_start; oh < OH_stop; ++oh) { + size_t ih = oh * 2 - PH; + const int8_t* sptr = src + ih * IW + (OW_start * 2 - PW); + int16_t* dptr = dst + oh * OW + OW_start; + size_t ow = OW_start; + for (; ow + 4 <= OW_stop; ow += 4) { + int8x8_t s0 = vld1_s8(sptr + 0 * IW), s1 = vld1_s8(sptr + 1 * IW); + int16x8_t r0 = vmull_s8(s0, f0), r1 = vmull_s8(s1, f1); + int16x8_t tmp0 = vaddq_s16(r0, r1); + int32x4_t tmp1 = vpaddlq_s16(tmp0); + int16x4_t d = vmovn_s32(tmp1); + if (add_to_dst) { + d = vadd_s16(d, vld1_s16(dptr)); + } + vst1_s16(dptr, d); + sptr += 8; + dptr += 4; + } + for (; ow < OW_stop; ++ow) { + int16_t s0 = sptr[0], s1 = sptr[1], s2 = sptr[IW + 0], + s3 = sptr[IW + 1]; + int16_t f0 = filter[0], f1 = filter[1], f2 = filter[2], + f3 = filter[3]; + int16_t d = s0 * f0 + s1 * f1 + s2 * f2 + s3 * f3; + if (add_to_dst) { + *dptr += d; + } else { + *dptr = d; + } + sptr += 2; + dptr += 1; + } + } +} + +template +void conv_bias::conv_stride2_3x3_sc_int8_int8_int16( + const int8_t* src, const int8_t* filter, int16_t* dst, size_t IH, + size_t IW, size_t OH, size_t OW, size_t PH, size_t PW) { + size_t OH_start = div_ceil(PH, 2), + OH_stop = div_floor(IH + PH - 3, 2) + 1, + OW_start = div_ceil(PW, 2), + OW_stop = div_floor(IW + PW - 3, 2) + 1; + OH_start = std::min(OH, OH_start); + OH_stop = std::min(OH, OH_stop); + OW_start = std::min(OW, OW_start); + OW_stop = std::min(OW, OW_stop); + auto run_single = [&](size_t oh, size_t ow) { + if (!add_to_dst) { + dst[oh * OW + ow] = 0; + } + for (size_t fh = 0; fh < 3; ++fh) + for (size_t fw = 0; fw < 3; ++fw) { + size_t ih = oh * 2 + fh - PH; + size_t iw = ow * 2 + fw - PW; + if (ih < IH && iw < IW) { 
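+                    // ih and iw are size_t, so a logically negative input offset wraps to a huge value and is rejected by the bounds check above; this is how the padded border rows/columns are handled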
+ dst[oh * OW + ow] += + (int16_t)src[ih * IW + iw] * filter[fh * 3 + fw]; + } + } + }; + for (size_t oh = 0; oh < OH_start; ++oh) { + for (size_t ow = 0; ow < OW; ++ow) { + run_single(oh, ow); + } + } + for (size_t oh = OH_start; oh < OH_stop; ++oh) { + for (size_t ow = 0; ow < OW_start; ++ow) + run_single(oh, ow); + for (size_t ow = OW_stop; ow < OW; ++ow) + run_single(oh, ow); + } + for (size_t oh = OH_stop; oh < OH; ++oh) { + for (size_t ow = 0; ow < OW; ++ow) { + run_single(oh, ow); + } + } + // 4x8 block + size_t oh = OH_start; + int8_t cache_even[9 * 16]; + int8_t cache_odd[9 * 16]; + const int8_t*(sptrs[3]) = {cache_even + 0, cache_odd + 0, cache_even + 1}; + for (; oh + 4 <= OH_stop; oh += 4) { + size_t ih = oh * 2 - PH; + size_t ow = OW_start; + for (; ow + 8 <= OW_stop; ow += 8) { + size_t iw = ow * 2 - PW; + int16_t* __restrict dptr = dst + oh * OW + ow; + const int8_t* __restrict sptr = src + ih * IW + iw; + const int8_t* __restrict fptr = filter; + int16x8_t d0, d1, d2, d3; + int8x8_t k0, k1, k2, s; + { + // do transpose + for (size_t i = 0; i < 9; ++i) { + int8x16_t s_full = vld1q_s8(sptr + i * IW); + int8x8_t s_low = vget_low_s8(s_full); + int8x8_t s_high = vget_high_s8(s_full); + int8x8x2_t s_result = vuzp_s8(s_low, s_high); + vst1_s8(cache_even + i * 16, s_result.val[0]); + vst1_s8(cache_odd + i * 16, s_result.val[1]); + // the 8-th elem + cache_even[i * 16 + 8] = sptr[i * IW + 16]; + } + } + if (add_to_dst) { + d0 = vld1q_s16(dptr + 0 * OW); + d1 = vld1q_s16(dptr + 1 * OW); + d2 = vld1q_s16(dptr + 2 * OW); + d3 = vld1q_s16(dptr + 3 * OW); + } else { + d0 = vdupq_n_s16(0); + d1 = vdupq_n_s16(0); + d2 = vdupq_n_s16(0); + d3 = vdupq_n_s16(0); + } + for (size_t fw = 0; fw < 3; ++fw) { + k0 = vdup_n_s8(fptr[0 * 3 + fw]); + k1 = vdup_n_s8(fptr[1 * 3 + fw]); + k2 = vdup_n_s8(fptr[2 * 3 + fw]); + + // line 0 + s = vld1_s8(sptrs[fw] + 0 * 16); + d0 = vmlal_s8(d0, k0, s); + + // line 1 + s = vld1_s8(sptrs[fw] + 1 * 16); + d0 = vmlal_s8(d0, k1, s); + + // line 2 + s = vld1_s8(sptrs[fw] + 2 * 16); + d0 = vmlal_s8(d0, k2, s); + d1 = vmlal_s8(d1, k0, s); + + // line 3 + s = vld1_s8(sptrs[fw] + 3 * 16); + d1 = vmlal_s8(d1, k1, s); + + // line 4 + s = vld1_s8(sptrs[fw] + 4 * 16); + d1 = vmlal_s8(d1, k2, s); + d2 = vmlal_s8(d2, k0, s); + + // line 5 + s = vld1_s8(sptrs[fw] + 5 * 16); + d2 = vmlal_s8(d2, k1, s); + + // line 6 + s = vld1_s8(sptrs[fw] + 6 * 16); + d2 = vmlal_s8(d2, k2, s); + d3 = vmlal_s8(d3, k0, s); + + // line 7 + s = vld1_s8(sptrs[fw] + 7 * 16); + d3 = vmlal_s8(d3, k1, s); + + // line 8 + s = vld1_s8(sptrs[fw] + 8 * 16); + d3 = vmlal_s8(d3, k2, s); + } + vst1q_s16(dptr + 0 * OW, d0); + vst1q_s16(dptr + 1 * OW, d1); + vst1q_s16(dptr + 2 * OW, d2); + vst1q_s16(dptr + 3 * OW, d3); + } + for (; ow < OW_stop; ++ow) { + run_single(oh + 0, ow); + run_single(oh + 1, ow); + run_single(oh + 2, ow); + run_single(oh + 3, ow); + } + } + for (; oh < OH_stop; ++oh) { + for (size_t ow = OW_start; ow < OW_stop; ++ow) { + run_single(oh, ow); + } + } +} + +template +void conv_bias::conv_stride2_5x5_sc_int8_int8_int16( + const int8_t* src, const int8_t* filter, int16_t* dst, size_t IH, + size_t IW, size_t OH, size_t OW, size_t PH, size_t PW) { + size_t OH_start = div_ceil(PH, 2), + OH_stop = div_floor(IH + PH - 5, 2) + 1, + OW_start = div_ceil(PW, 2), + OW_stop = div_floor(IW + PW - 5, 2) + 1; + OH_start = std::min(OH, OH_start); + OH_stop = std::min(OH, OH_stop); + OW_start = std::min(OW, OW_start); + OW_stop = std::min(OW, OW_stop); + auto run_single = [&](size_t oh, size_t ow) { + if 
(!add_to_dst) { + dst[oh * OW + ow] = 0; + } + for (size_t fh = 0; fh < 5; ++fh) + for (size_t fw = 0; fw < 5; ++fw) { + size_t ih = oh * 2 + fh - PH; + size_t iw = ow * 2 + fw - PW; + if (ih < IH && iw < IW) { + dst[oh * OW + ow] += + (int16_t)src[ih * IW + iw] * filter[fh * 5 + fw]; + } + } + }; + for (size_t oh = 0; oh < OH_start; ++oh) { + for (size_t ow = 0; ow < OW; ++ow) { + run_single(oh, ow); + } + } + for (size_t oh = OH_start; oh < OH_stop; ++oh) { + for (size_t ow = 0; ow < OW_start; ++ow) + run_single(oh, ow); + for (size_t ow = OW_stop; ow < OW; ++ow) + run_single(oh, ow); + } + for (size_t oh = OH_stop; oh < OH; ++oh) { + for (size_t ow = 0; ow < OW; ++ow) { + run_single(oh, ow); + } + } + // 4x8 block + size_t oh = OH_start; + int8_t cache_even[11 * 16]; + int8_t cache_odd[11 * 16]; + const int8_t*(sptrs[5]) = { + cache_even + 0, cache_odd + 0, cache_even + 1, + cache_odd + 1, cache_even + 2, + }; + for (; oh + 4 <= OH_stop; oh += 4) { + size_t ih = oh * 2 - PH; + size_t ow = OW_start; + for (; ow + 8 <= OW_stop; ow += 8) { + size_t iw = ow * 2 - PW; + int16_t* __restrict dptr = dst + oh * OW + ow; + const int8_t* __restrict sptr = src + ih * IW + iw; + const int8_t* __restrict fptr = filter; + int16x8_t d0, d1, d2, d3; + int8x8_t k0, k1, k2, k3, k4, s; + { + // do transpose + for (size_t i = 0; i < 11; ++i) { + int8x16_t s_full = vld1q_s8(sptr + i * IW); + int8x8_t s_low = vget_low_s8(s_full); + int8x8_t s_high = vget_high_s8(s_full); + int8x8x2_t s_result = vuzp_s8(s_low, s_high); + vst1_s8(cache_even + i * 16, s_result.val[0]); + vst1_s8(cache_odd + i * 16, s_result.val[1]); + // last elements + cache_even[i * 16 + 8] = sptr[i * IW + 16]; + cache_odd[i * 16 + 8] = sptr[i * IW + 17]; + cache_even[i * 16 + 9] = sptr[i * IW + 18]; + } + } + if (add_to_dst) { + d0 = vld1q_s16(dptr + 0 * OW); + d1 = vld1q_s16(dptr + 1 * OW); + d2 = vld1q_s16(dptr + 2 * OW); + d3 = vld1q_s16(dptr + 3 * OW); + } else { + d0 = vdupq_n_s16(0); + d1 = vdupq_n_s16(0); + d2 = vdupq_n_s16(0); + d3 = vdupq_n_s16(0); + } + for (size_t fw = 0; fw < 5; ++fw) { + k0 = vdup_n_s8(fptr[0 * 5 + fw]); + k1 = vdup_n_s8(fptr[1 * 5 + fw]); + k2 = vdup_n_s8(fptr[2 * 5 + fw]); + k3 = vdup_n_s8(fptr[3 * 5 + fw]); + k4 = vdup_n_s8(fptr[4 * 5 + fw]); + + // line 0 + s = vld1_s8(sptrs[fw] + 0 * 16); + d0 = vmlal_s8(d0, k0, s); + + // line 1 + s = vld1_s8(sptrs[fw] + 1 * 16); + d0 = vmlal_s8(d0, k1, s); + + // line 2 + s = vld1_s8(sptrs[fw] + 2 * 16); + d0 = vmlal_s8(d0, k2, s); + d1 = vmlal_s8(d1, k0, s); + + // line 3 + s = vld1_s8(sptrs[fw] + 3 * 16); + d0 = vmlal_s8(d0, k3, s); + d1 = vmlal_s8(d1, k1, s); + + // line 4 + s = vld1_s8(sptrs[fw] + 4 * 16); + d0 = vmlal_s8(d0, k4, s); + d1 = vmlal_s8(d1, k2, s); + d2 = vmlal_s8(d2, k0, s); + + // line 5 + s = vld1_s8(sptrs[fw] + 5 * 16); + d1 = vmlal_s8(d1, k3, s); + d2 = vmlal_s8(d2, k1, s); + + // line 6 + s = vld1_s8(sptrs[fw] + 6 * 16); + d1 = vmlal_s8(d1, k4, s); + d2 = vmlal_s8(d2, k2, s); + d3 = vmlal_s8(d3, k0, s); + + // line 7 + s = vld1_s8(sptrs[fw] + 7 * 16); + d2 = vmlal_s8(d2, k3, s); + d3 = vmlal_s8(d3, k1, s); + + // line 8 + s = vld1_s8(sptrs[fw] + 8 * 16); + d2 = vmlal_s8(d2, k4, s); + d3 = vmlal_s8(d3, k2, s); + + // line 9 + s = vld1_s8(sptrs[fw] + 9 * 16); + d3 = vmlal_s8(d3, k3, s); + + // line 9 + s = vld1_s8(sptrs[fw] + 10 * 16); + d3 = vmlal_s8(d3, k4, s); + } + vst1q_s16(dptr + 0 * OW, d0); + vst1q_s16(dptr + 1 * OW, d1); + vst1q_s16(dptr + 2 * OW, d2); + vst1q_s16(dptr + 3 * OW, d3); + } + for (; ow < OW_stop; ++ow) { + run_single(oh + 0, ow); 
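+            // remaining output columns that do not fill a full 8-wide vector are finished with the scalar run_single path, one call per row of the 4-row block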
+ run_single(oh + 1, ow); + run_single(oh + 2, ow); + run_single(oh + 3, ow); + } + } + for (; oh < OH_stop; ++oh) { + for (size_t ow = OW_start; ow < OW_stop; ++ow) { + run_single(oh, ow); + } + } +} +template void conv_bias::conv_stride2_2x2_sc_int8_int8_int16( + const int8_t* src, const int8_t* filter, int16_t* dst, size_t IH, + size_t IW, size_t OH, size_t OW, size_t PH, size_t PW); +template void conv_bias::conv_stride2_2x2_sc_int8_int8_int16( + const int8_t* src, const int8_t* filter, int16_t* dst, size_t IH, + size_t IW, size_t OH, size_t OW, size_t PH, size_t PW); +template void conv_bias::conv_stride2_3x3_sc_int8_int8_int16( + const int8_t* src, const int8_t* filter, int16_t* dst, size_t IH, + size_t IW, size_t OH, size_t OW, size_t PH, size_t PW); +template void conv_bias::conv_stride2_3x3_sc_int8_int8_int16( + const int8_t* src, const int8_t* filter, int16_t* dst, size_t IH, + size_t IW, size_t OH, size_t OW, size_t PH, size_t PW); +template void conv_bias::conv_stride2_5x5_sc_int8_int8_int16( + const int8_t* src, const int8_t* filter, int16_t* dst, size_t IH, + size_t IW, size_t OH, size_t OW, size_t PH, size_t PW); +template void conv_bias::conv_stride2_5x5_sc_int8_int8_int16( + const int8_t* src, const int8_t* filter, int16_t* dst, size_t IH, + size_t IW, size_t OH, size_t OW, size_t PH, size_t PW); + +namespace { +void conv_2x2_optimize_single_channel(const int8_t* src, const uint32_t IH, + const uint32_t IW, const int8_t* filter, + int16_t* dst, const uint32_t OH, + const uint32_t OW) { + int8_t workspace[16]; + workspace[0] = filter[0]; + workspace[1] = filter[1]; + workspace[2] = filter[0]; + workspace[3] = filter[1]; + workspace[4] = filter[0]; + workspace[5] = filter[1]; + workspace[6] = filter[0]; + workspace[7] = filter[1]; + workspace[8] = filter[2]; + workspace[9] = filter[3]; + workspace[10] = filter[2]; + workspace[11] = filter[3]; + workspace[12] = filter[2]; + workspace[13] = filter[3]; + workspace[14] = filter[2]; + workspace[15] = filter[3]; + int8x8_t f0 = vld1_s8(workspace), f1 = vld1_s8(workspace + 8); + + int8x8_t v0, v1; + int16x8_t r0, r1, s0; + int16x4_t s, s16; + for (uint32_t i = 0, j; i < IH; i += 2) { + for (j = 0; j + 8 <= IW; j += 8) { + v0 = vld1_s8(src), v1 = vld1_s8(src + IW); + r0 = vmull_s8(v0, f0), r1 = vmull_s8(v1, f1); + s0 = vaddq_s16(r0, r1); + s16 = vmovn_s32(vpaddlq_s16(s0)); + s = vadd_s16(vld1_s16(dst), s16); + vst1_s16(dst, s); + src += 8; + dst += 4; + } + for (; j < IW; j += 2) { + (*dst++) += static_cast(src[0]) * + static_cast(filter[0]) + + static_cast(src[1]) * + static_cast(filter[1]) + + static_cast(src[IW]) * + static_cast(filter[2]) + + static_cast(src[IW + 1]) * + static_cast(filter[3]); + src += 2; + } + src += IW; + } +} + +} // anonymous namespace + +size_t conv_bias::get_workspace_in_bytes_conv_int8x8x16_stride2_flt2( + const ConvBiasImpl::NCBKernSizeParam& param) { + return 0; +} + +bool conv_bias::can_conv_int8x8x16_stride2_flt2( + const ConvBiasImpl::NCBKernSizeParam& param) { + auto&& fm = param.filter_meta; + auto FH = fm.spatial[0]; + return fm.format == param::ConvBias::Format::NCHW && !fm.should_flip && + param.src_type.enumv() == DTypeEnum::Int8 && + param.filter_type.enumv() == DTypeEnum::Int8 && + param.dst_type.enumv() == DTypeEnum::Int16 && fm.dilation[0] == 1 && + fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 && + FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5) && + param.isz[0] % 2 == 0 && param.isz[1] % 2 == 0 && + fm.dilation[0] == 1 && fm.dilation[1] == 1 && fm.spatial[0] == 2 
&& + fm.spatial[1] == 2 && fm.padding[0] == 0 && fm.padding[1] == 0; +} + +void conv_bias::conv_int8x8x16_stride2_flt2( + const ConvBiasImpl::NCBKernParam& param) { + UNPACK_CONV_F32_NCB_KERN_SIZES(param); + megdnn_ignore(FH); + megdnn_ignore(FW); + megdnn_ignore(SH); + megdnn_ignore(SW); + megdnn_ignore(PH); + megdnn_ignore(PW); + auto src = param.src(); + auto dst_init = param.dst(); + auto filter_init = param.filter(); + const uint32_t shape = IH * IW; + for (uint32_t n = 0; n < N; ++n) { + auto fptr = filter_init; + auto dst = dst_init + n * param.out_bs; + memset(dst, 0, sizeof(dst[0]) * OC * OH * OW); + for (uint32_t j = 0; j < OC; ++j) { + for (uint32_t k = 0; k < IC; ++k) { + conv_2x2_optimize_single_channel(src + k * shape, IH, IW, fptr, + dst, OH, OW); + fptr += 4; + } + dst += OH * OW; + } + src += param.inp_bs; + } +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/conv_stride2.h b/dnn/src/arm_common/conv_bias/int8x8x16/conv_stride2.h new file mode 100644 index 00000000..7ec4af90 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8x8x16/conv_stride2.h @@ -0,0 +1,49 @@ +/** + * \file dnn/src/arm_common/conv_bias/int8x8x16/conv_stride2.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once +#include "src/arm_common/conv_bias/opr_impl.h" + +#include +#include + +namespace megdnn { +namespace arm_common { +namespace conv_bias { + +template +void conv_stride2_2x2_sc_int8_int8_int16(const int8_t* src, + const int8_t* filter, int16_t* dst, + size_t IH, size_t IW, size_t OH, + size_t OW, size_t PH, size_t PW); +template +void conv_stride2_3x3_sc_int8_int8_int16(const int8_t* src, + const int8_t* filter, int16_t* dst, + size_t IH, size_t IW, size_t OH, + size_t OW, size_t PH, size_t PW); +template +void conv_stride2_5x5_sc_int8_int8_int16(const int8_t* src, + const int8_t* filter, int16_t* dst, + size_t IH, size_t IW, size_t OH, + size_t OW, size_t PH, size_t PW); + +bool can_conv_int8x8x16_stride2_flt2( + const ConvBiasImpl::NCBKernSizeParam& param); + +void conv_int8x8x16_stride2_flt2(const ConvBiasImpl::NCBKernParam& param); + +size_t get_workspace_in_bytes_conv_int8x8x16_stride2_flt2( + const ConvBiasImpl::NCBKernSizeParam& param); + +} // namespace conv_bias +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/intrinsic_helper.h b/dnn/src/arm_common/conv_bias/intrinsic_helper.h new file mode 100644 index 00000000..2cd9c3a3 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/intrinsic_helper.h @@ -0,0 +1,500 @@ +#pragma once +/** + * \file dnn/src/arm_common/conv_bias/intrinsic_helper.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ +#include "src/arm_common/conv_bias/neon_struct.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/unroll_macro.h" +#include "src/fallback/conv_bias/common.h" +namespace megdnn { +namespace { + +////////////////////Store_OC4_OW8_Remain///////////////////////// +template +struct Store_OC4_OW8_Remain { + static void impl(int32x4_t c[8], const Op& op, int8_t* dst_ptr); +}; + +template +struct Store_OC4_OW8_Remain<0, Op> { + static void impl(int32x4_t c[8], const Op& op, int8_t* dst_ptr) { + op({{c[0], c[1]}}, reinterpret_cast(dst_ptr)); + op({{c[2], c[3]}}, reinterpret_cast(dst_ptr + 8)); + op({{c[4], c[5]}}, reinterpret_cast(dst_ptr + 16)); + op({{c[6], c[7]}}, reinterpret_cast(dst_ptr + 24)); + } +}; + +template +struct Store_OC4_OW8_Remain<7, Op> { + static void impl(int32x4_t c[8], const Op& op, int8_t* dst_ptr) { + op({{c[0], c[1]}}, reinterpret_cast(dst_ptr)); + op({{c[2], c[3]}}, reinterpret_cast(dst_ptr + 8)); + op({{c[4], c[5]}}, reinterpret_cast(dst_ptr + 16)); + op(c[6], reinterpret_cast(dst_ptr + 24)); + } +}; +template +struct Store_OC4_OW8_Remain<6, Op> { + static void impl(int32x4_t c[8], const Op& op, int8_t* dst_ptr) { + op({{c[0], c[1]}}, reinterpret_cast(dst_ptr)); + op({{c[2], c[3]}}, reinterpret_cast(dst_ptr + 8)); + op({{c[4], c[5]}}, reinterpret_cast(dst_ptr + 16)); + } +}; +template +struct Store_OC4_OW8_Remain<5, Op> { + static void impl(int32x4_t c[8], const Op& op, int8_t* dst_ptr) { + op({{c[0], c[1]}}, reinterpret_cast(dst_ptr)); + op({{c[2], c[3]}}, reinterpret_cast(dst_ptr + 8)); + op(c[4], reinterpret_cast(dst_ptr + 16)); + } +}; +template +struct Store_OC4_OW8_Remain<4, Op> { + static void impl(int32x4_t c[8], const Op& op, int8_t* dst_ptr) { + op({{c[0], c[1]}}, reinterpret_cast(dst_ptr)); + op({{c[2], c[3]}}, reinterpret_cast(dst_ptr + 8)); + } +}; +template +struct Store_OC4_OW8_Remain<3, Op> { + static void impl(int32x4_t c[8], const Op& op, int8_t* dst_ptr) { + op({{c[0], c[1]}}, reinterpret_cast(dst_ptr)); + op(c[2], reinterpret_cast(dst_ptr + 8)); + } +}; +template +struct Store_OC4_OW8_Remain<2, Op> { + static void impl(int32x4_t c[8], const Op& op, int8_t* dst_ptr) { + op({{c[0], c[1]}}, reinterpret_cast(dst_ptr)); + } +}; +template +struct Store_OC4_OW8_Remain<1, Op> { + static void impl(int32x4_t c[8], const Op& op, int8_t* dst_ptr) { + op(c[0], reinterpret_cast(dst_ptr)); + } +}; + +template +inline void store_oc4_ow8_remain_static(int32x4_t c[8], const Op& op, + int8_t* dst_ptr) { + Store_OC4_OW8_Remain::impl(c, op, dst_ptr); +} + +template +struct StoreOcxOw4Remain { + static void impl(T& c, const Op& op, int8_t* dst_ptr, int ld_dst_oc); +}; + +template +struct StoreOcxOw4Remain<2, 0, Op, T> { + static void impl(int32x4_t c[2][4], const Op& op, int8_t* dst_ptr, + int ld_dst_oc) { + op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); + op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + 8)); + + op({{c[1][0], c[1][1]}}, + reinterpret_cast(dst_ptr + ld_dst_oc)); + op({{c[1][2], c[1][3]}}, + reinterpret_cast(dst_ptr + ld_dst_oc + 8)); + } +}; + +template +struct StoreOcxOw4Remain<2, 3, Op, T> { + static void impl(T& c, const Op& op, int8_t* dst_ptr, int ld_dst_oc) { + op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); + op(c[0][2], reinterpret_cast(dst_ptr + 8)); + + op({{c[1][0], c[1][1]}}, + reinterpret_cast(dst_ptr + ld_dst_oc)); + op(c[1][2], reinterpret_cast(dst_ptr + ld_dst_oc + 8)); + } +}; +template +struct StoreOcxOw4Remain<2, 2, Op, T> { + static void impl(T& c, const Op& op, int8_t* dst_ptr, int ld_dst_oc) 
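+    // remain_w == 2: only the first two of the four column accumulators are valid for this tail, so just c[oc][0] and c[oc][1] are stored for each of the two OC blocks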
{ + op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); + op({{c[1][0], c[1][1]}}, + reinterpret_cast(dst_ptr + ld_dst_oc)); + } +}; +template +struct StoreOcxOw4Remain<2, 1, Op, T> { + static void impl(T& c, const Op& op, int8_t* dst_ptr, int ld_dst_oc) { + op(c[0][0], reinterpret_cast(dst_ptr)); + op(c[1][0], reinterpret_cast(dst_ptr + ld_dst_oc)); + } +}; + +template +struct StoreOcxOw4Remain<1, 0, Op, T> { + static void impl(int32x4_t c[2][4], const Op& op, int8_t* dst_ptr, + int ld_dst_oc) { + op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); + op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + 8)); + } +}; + +template +struct StoreOcxOw4Remain<1, 3, Op, T> { + static void impl(T& c, const Op& op, int8_t* dst_ptr, int ld_dst_oc) { + op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); + op(c[0][2], reinterpret_cast(dst_ptr + 8)); + } +}; +template +struct StoreOcxOw4Remain<1, 2, Op, T> { + static void impl(T& c, const Op& op, int8_t* dst_ptr, int ld_dst_oc) { + op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); + } +}; +template +struct StoreOcxOw4Remain<1, 1, Op, T> { + static void impl(T& c, const Op& op, int8_t* dst_ptr, int ld_dst_oc) { + op(c[0][0], reinterpret_cast(dst_ptr)); + } +}; +template +inline void store_ocx_ow4_remain_static(T& c, const Op& op, int8_t* dst_ptr, + int ld_dst_oc) { + StoreOcxOw4Remain::impl(c, op, dst_ptr, ld_dst_oc); +} + +////////////////////Store_OC8_OW8_Remain///////////////////////// + +template +struct Store_OC8_OW8_Remain { + static void impl(int32x4_t c[2][8], const Op& op, int8_t* dst_ptr, + int ld_dst_oc); +}; + +template +struct Store_OC8_OW8_Remain<0, Op> { + static void impl(int32x4_t c[2][8], const Op& op, int8_t* dst_ptr, + int ld_dst_oc) { + op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); + op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + 8)); + op({{c[0][4], c[0][5]}}, reinterpret_cast(dst_ptr + 16)); + op({{c[0][6], c[0][7]}}, reinterpret_cast(dst_ptr + 24)); + + op({{c[1][0], c[1][1]}}, + reinterpret_cast(dst_ptr + ld_dst_oc)); + op({{c[1][2], c[1][3]}}, + reinterpret_cast(dst_ptr + ld_dst_oc + 8)); + op({{c[1][4], c[1][5]}}, + reinterpret_cast(dst_ptr + ld_dst_oc + 16)); + op({{c[1][6], c[1][7]}}, + reinterpret_cast(dst_ptr + ld_dst_oc + 24)); + } +}; + +template +struct Store_OC8_OW8_Remain<7, Op> { + static void impl(int32x4_t c[2][8], const Op& op, int8_t* dst_ptr, + int ld_dst_oc) { + op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); + op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + 8)); + op({{c[0][4], c[0][5]}}, reinterpret_cast(dst_ptr + 16)); + op(c[0][6], reinterpret_cast(dst_ptr + 24)); + + op({{c[1][0], c[1][1]}}, + reinterpret_cast(dst_ptr + ld_dst_oc)); + op({{c[1][2], c[1][3]}}, + reinterpret_cast(dst_ptr + ld_dst_oc + 8)); + op({{c[1][4], c[1][5]}}, + reinterpret_cast(dst_ptr + ld_dst_oc + 16)); + op(c[1][6], reinterpret_cast(dst_ptr + ld_dst_oc + 24)); + } +}; + +template +struct Store_OC8_OW8_Remain<6, Op> { + static void impl(int32x4_t c[2][8], const Op& op, int8_t* dst_ptr, + int ld_dst_oc) { + op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); + op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + 8)); + op({{c[0][4], c[0][5]}}, reinterpret_cast(dst_ptr + 16)); + + op({{c[1][0], c[1][1]}}, + reinterpret_cast(dst_ptr + ld_dst_oc)); + op({{c[1][2], c[1][3]}}, + reinterpret_cast(dst_ptr + ld_dst_oc + 8)); + op({{c[1][4], c[1][5]}}, + reinterpret_cast(dst_ptr + ld_dst_oc + 16)); + } +}; + +template +struct Store_OC8_OW8_Remain<5, Op> { + static void impl(int32x4_t c[2][8], const Op& op, int8_t* 
dst_ptr, + int ld_dst_oc) { + op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); + op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + 8)); + op(c[0][4], reinterpret_cast(dst_ptr + 16)); + + op({{c[1][0], c[1][1]}}, + reinterpret_cast(dst_ptr + ld_dst_oc)); + op({{c[1][2], c[1][3]}}, + reinterpret_cast(dst_ptr + ld_dst_oc + 8)); + op(c[1][4], reinterpret_cast(dst_ptr + ld_dst_oc + 16)); + } +}; + +template +struct Store_OC8_OW8_Remain<4, Op> { + static void impl(int32x4_t c[2][8], const Op& op, int8_t* dst_ptr, + int ld_dst_oc) { + op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); + op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + 8)); + + op({{c[1][0], c[1][1]}}, + reinterpret_cast(dst_ptr + ld_dst_oc)); + op({{c[1][2], c[1][3]}}, + reinterpret_cast(dst_ptr + ld_dst_oc + 8)); + } +}; + +template +struct Store_OC8_OW8_Remain<3, Op> { + static void impl(int32x4_t c[2][8], const Op& op, int8_t* dst_ptr, + int ld_dst_oc) { + op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); + op(c[0][2], reinterpret_cast(dst_ptr + 8)); + + op({{c[1][0], c[1][1]}}, + reinterpret_cast(dst_ptr + ld_dst_oc)); + op(c[1][2], reinterpret_cast(dst_ptr + ld_dst_oc + 8)); + } +}; +template +struct Store_OC8_OW8_Remain<2, Op> { + static void impl(int32x4_t c[2][8], const Op& op, int8_t* dst_ptr, + int ld_dst_oc) { + op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); + op({{c[1][0], c[1][1]}}, + reinterpret_cast(dst_ptr + ld_dst_oc)); + } +}; +template +struct Store_OC8_OW8_Remain<1, Op> { + static void impl(int32x4_t c[2][8], const Op& op, int8_t* dst_ptr, + int ld_dst_oc) { + op(c[0][0], reinterpret_cast(dst_ptr)); + op(c[1][0], reinterpret_cast(dst_ptr + ld_dst_oc)); + } +}; + +template +inline void store_oc8_ow8_remain_static(int32x4_t c[2][8], const Op& op, + int8_t* dst_ptr, int ld_dst_oc) { + Store_OC8_OW8_Remain::impl(c, op, dst_ptr, ld_dst_oc); +} + +/////////////////////////////////////////////////////// + +template +inline void init_oc4_ow8(int32x4_t c[8], const int32_t* bias_ptr) { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { +#define BAIS_INIT(step) c[step] = vld1q_s32(bias_ptr); + UNROLL_CALL_RAW(8, BAIS_INIT); +#undef BAIS_INIT + } else { +#define BAIS_INIT(step) c[step] = vdupq_n_s32(0); + UNROLL_CALL_RAW(8, BAIS_INIT); +#undef BAIS_INIT + } +} + +template +inline void init_oc8_ow8(int32x4_t c[2][8], const int32_t* bias_ptr, + int oc_step) { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { +#define BAIS_INIT(step) \ + c[0][step] = vld1q_s32(bias_ptr); \ + c[1][step] = vld1q_s32(bias_ptr + oc_step); + UNROLL_CALL_RAW(8, BAIS_INIT); +#undef BAIS_INIT + } else { +#define BAIS_INIT(step) \ + c[0][step] = vdupq_n_s32(0); \ + c[1][step] = vdupq_n_s32(0); + UNROLL_CALL_RAW(8, BAIS_INIT); +#undef BAIS_INIT + } +} +template +struct InitOcxOw4 { + static void impl(T& c, const int32_t* bias_ptr, int oc_step); +}; + +template +struct InitOcxOw4<2, bias_mode, T> { + static void impl(T& c, const int32_t* bias_ptr, int oc_step) { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { +#define BAIS_INIT(step) \ + c[0][step] = vld1q_s32(bias_ptr); \ + c[1][step] = vld1q_s32(bias_ptr + oc_step); + UNROLL_CALL_RAW(4, BAIS_INIT); +#undef BAIS_INIT + } else { +#define BAIS_INIT(step) \ + c[0][step] = vdupq_n_s32(0); \ + c[1][step] = vdupq_n_s32(0); + UNROLL_CALL_RAW(4, BAIS_INIT); +#undef BAIS_INIT + } + } +}; + +template +struct InitOcxOw4<1, bias_mode, T> { + static void impl(T& c, const int32_t* bias_ptr, int oc_step) { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { +#define 
BAIS_INIT(step) c[0][step] = vld1q_s32(bias_ptr); + UNROLL_CALL_RAW(4, BAIS_INIT); +#undef BAIS_INIT + } else { +#define BAIS_INIT(step) c[0][step] = vdupq_n_s32(0); + UNROLL_CALL_RAW(4, BAIS_INIT); +#undef BAIS_INIT + } + } +}; + +template +inline void init_ocx_ow4(T& c, const int32_t* bias_ptr, int oc_step) { + InitOcxOw4::impl(c, bias_ptr, oc_step); +} +/////////////////////////////////////// +template +struct LoadHelper { + static void impl(T& weight, const int8_t* ptr, int oc_offset, XT... args); +}; + +#define WEIGHT_CB(step) \ + src[step] = Func::impl(ptr + base_offset + step * ptr_step, args...); + +template +struct LoadHelper<1, base_offset, ptr_step, 0, Func, T, XT...> { + static void impl(T& src, const int8_t* ptr, int oc_offset, XT... args) { + UNROLL_CALL_RAW(1, WEIGHT_CB); + } +}; +template +struct LoadHelper<2, base_offset, ptr_step, 0, Func, T, XT...> { + static void impl(T& src, const int8_t* ptr, int oc_offset, XT... args) { + UNROLL_CALL_RAW(2, WEIGHT_CB); + } +}; + +template +struct LoadHelper<3, base_offset, ptr_step, 0, Func, T, XT...> { + static void impl(T& src, const int8_t* ptr, int oc_offset, XT... args) { + UNROLL_CALL_RAW(3, WEIGHT_CB); + } +}; +template +struct LoadHelper<4, base_offset, ptr_step, 0, Func, T, XT...> { + static void impl(T& src, const int8_t* ptr, int oc_offset, XT... args) { + UNROLL_CALL_RAW(4, WEIGHT_CB); + } +}; +template +struct LoadHelper<5, base_offset, ptr_step, 0, Func, T, XT...> { + static void impl(T& src, const int8_t* ptr, int oc_offset, XT... args) { + UNROLL_CALL_RAW(5, WEIGHT_CB); + } +}; +template +struct LoadHelper<6, base_offset, ptr_step, 0, Func, T, XT...> { + static void impl(T& src, const int8_t* ptr, int oc_offset, XT... args) { + UNROLL_CALL_RAW(6, WEIGHT_CB); + } +}; +#undef WEIGHT_CB + +#define WEIGHT_CB(step) \ + src[0][step] = Func::impl(ptr + base_offset + step * ptr_step); +template +struct LoadHelper<1, base_offset, ptr_step, 1, Func, T> { + static void impl(T& src, const int8_t* ptr, int oc_offset) { + UNROLL_CALL_RAW(1, WEIGHT_CB); + } +}; +template +struct LoadHelper<2, base_offset, ptr_step, 1, Func, T> { + static void impl(T& src, const int8_t* ptr, int oc_offset) { + UNROLL_CALL_RAW(2, WEIGHT_CB); + } +}; + +template +struct LoadHelper<3, base_offset, ptr_step, 1, Func, T> { + static void impl(T& src, const int8_t* ptr, int oc_offset) { + UNROLL_CALL_RAW(3, WEIGHT_CB); + } +}; + +#undef WEIGHT_CB + +#define WEIGHT_CB(step) \ + src[0][step] = Func::impl(ptr + base_offset + step * ptr_step); \ + src[1][step] = Func::impl(ptr + base_offset + step * ptr_step + oc_offset); + +template +struct LoadHelper<1, base_offset, ptr_step, 2, Func, T> { + static void impl(T& src, const int8_t* ptr, int oc_offset) { + UNROLL_CALL_RAW(1, WEIGHT_CB); + } +}; +template +struct LoadHelper<2, base_offset, ptr_step, 2, Func, T> { + static void impl(T& src, const int8_t* ptr, int oc_offset) { + UNROLL_CALL_RAW(2, WEIGHT_CB); + } +}; + +template +struct LoadHelper<3, base_offset, ptr_step, 2, Func, T> { + static void impl(T& src, const int8_t* ptr, int oc_offset) { + UNROLL_CALL_RAW(3, WEIGHT_CB); + } +}; + +#undef WEIGHT_CB + +template +inline void load_helper(T& weight, const int8_t* ptr, int oc_offset) { + LoadHelper::impl( + weight, ptr, oc_offset); +} + +template +inline void load_helper_x(T& weight, const int8_t* ptr, int oc_offset, + XT... 
args) { + LoadHelper::impl(weight, ptr, oc_offset, args...); +} + +} // namespace +} // namespace megdnn \ No newline at end of file diff --git a/dnn/src/arm_common/conv_bias/matmul_postprocess.h b/dnn/src/arm_common/conv_bias/matmul_postprocess.h new file mode 100644 index 00000000..a4f0e657 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/matmul_postprocess.h @@ -0,0 +1,388 @@ +/** + * \file dnn/src/arm_common/conv_bias/matmul_postprocess.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once + +#include "megdnn/dtype.h" +#include "megdnn/oprs.h" +#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/utils.h" +#include "src/fallback/conv_bias/common.h" + +namespace megdnn { +namespace arm_common { + +#define SAVE(C, vres, n, idx) \ + switch (n) { \ + case 4: \ + vst1_lane_s32(reinterpret_cast(C), \ + vreinterpret_s32_s8(vres), idx / 4); \ + break; \ + case 3: \ + vst1_lane_s8(C + 2, vres, idx + 2); \ + case 2: \ + vst1_lane_s8(C + 1, vres, idx + 1); \ + case 1: \ + vst1_lane_s8(C + 0, vres, idx + 0); \ + break; \ + default: \ + megdnn_assert(0); \ + } + +#define SAVEU(C, vres, n, idx) \ + switch (n) { \ + case 4: \ + vst1_lane_s32(reinterpret_cast(C), \ + vreinterpret_s32_u8(vres), idx / 4); \ + break; \ + case 3: \ + vst1_lane_u8(C + 2, vres, idx + 2); \ + case 2: \ + vst1_lane_u8(C + 1, vres, idx + 1); \ + case 1: \ + vst1_lane_u8(C + 0, vres, idx + 0); \ + break; \ + default: \ + megdnn_assert(0); \ + } + +template +struct Process; + +template +struct Process, Op>::value>> { + static dst_neon_type run(const int32x4x2_t& wp, const int32x4x2_t, + const Op& op) { + return op(wp); + } +}; + +template +struct Process, Op>::value>> { + static dst_neon_type run(const int32x4x2_t& wp, const int32x4x2_t bias, + const Op& op) { + return op(wp, bias); + } +}; + +template +struct ConvBiasMatmul { + static void postprocess(const dt_int32* bias, const dt_int32* workspace, + dst_ctype* C, size_t LDC, Op op); +}; + +template +struct ConvBiasMatmul { + static void postprocess(const dt_int32* bias, const dt_int32* workspace, + dt_int8* C, size_t LDC, const Op& op) { + static_assert(m > 0 && m <= block_m, "invalid m or n"); + int32x4_t vbias0, vwp0, vwp1, vwp2; + if (bmode != BiasMode::BROADCAST_CHANNEL_BIAS) { + vbias0 = QConverterBase::vzero(); + } + for (int i = 0; i < m; i++) { + if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + vbias0 = vdupq_n_s32(*bias); + } + vwp0 = vld1q_s32(workspace); + vwp1 = vld1q_s32(workspace + 4); + vwp2 = vld1q_s32(workspace + 8); + + int8x8_t vres; + vres = Process::run({{vwp0, vwp1}}, + {{vbias0, vbias0}}, op); + vst1_s8(C, vres); + + vres = Process::run({{vwp1, vwp2}}, + {{vbias0, vbias0}}, op); + //! 
save the high half + vst1_lane_s32(reinterpret_cast(C + 8), + vreinterpret_s32_s8(vres), 1); + + bias++; + C += LDC; + workspace += 12; + } + } +}; + + +template +struct ConvBiasMatmul { + static void postprocess(const dt_int32* bias, const dt_int32* workspace, + dt_int8* C, size_t LDC, const Op& op) { + static_assert(m > 0 && m <= block_m && n > 0 && n <= 4, + "invalid m or n"); + int i = 0; + int32x4_t vbias0, vbias1, vwp0, vwp1; + if (bmode != BiasMode::BROADCAST_CHANNEL_BIAS) { + vbias0 = QConverterBase::vzero(); + vbias1 = QConverterBase::vzero(); + } + for (; i + 1 < m; i += 2) { + if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + vbias0 = vdupq_n_s32(*bias); + } + vwp0 = vld1q_s32(workspace); + if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + bias++; + vbias1 = vdupq_n_s32(*bias); + } + workspace += 4; + vwp1 = vld1q_s32(workspace); + + int8x8_t vres; + vres = Process::run({{vwp0, vwp1}}, + {{vbias0, vbias1}}, op); + SAVE(C, vres, n, 0); + C += LDC; + SAVE(C, vres, n, 4); + + bias++; + C += LDC; + workspace += 4; + } + + if (i < m) { + if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + vbias0 = vdupq_n_s32(*bias); + } + vwp0 = vld1q_s32(workspace); + if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + vbias1 = QConverterBase::vzero(); + } + vwp1 = QConverterBase::vzero(); + + int8x8_t vres; + vres = Process::run({{vwp0, vwp1}}, + {{vbias0, vbias1}}, op); + SAVE(C, vres, n, 0); + C += LDC; + } + } +}; + +template +struct ConvBiasMatmul { + static void postprocess(const dt_int32* bias, const dt_int32* workspace, + dt_int8* C, size_t LDC, const Op& op) { + static_assert(m > 0 && m <= block_m, "invalid m or n"); + int i = 0; + int32x4_t vbias0, vbias1, vwp0, vwp1; + if (bmode != BiasMode::BROADCAST_CHANNEL_BIAS) { + vbias0 = QConverterBase::vzero(); + vbias1 = QConverterBase::vzero(); + } + for (; i + 1 < m; i += 2) { + if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + vbias0 = vdupq_n_s32(*bias); + } + vwp0 = vcombine_s32(vld1_s32(workspace), vdup_n_s32(0)); + if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + bias++; + vbias1 = vdupq_n_s32(*bias); + } + workspace += 2; + vwp1 = vcombine_s32(vld1_s32(workspace), vdup_n_s32(0)); + + int8x8_t vres; + vres = Process::run({{vwp0, vwp1}}, + {{vbias0, vbias1}}, op); + SAVE(C, vres, n, 0); + C += LDC; + SAVE(C, vres, n, 4); + + bias++; + C += LDC; + workspace += 2; + } + + if (i < m) { + if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + vbias0 = vdupq_n_s32(*bias); + } + vwp0 = vcombine_s32(vld1_s32(workspace), vdup_n_s32(0)); + if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + vbias1 = QConverterBase::vzero(); + } + vwp1 = QConverterBase::vzero(); + + int8x8_t vres; + vres = Process::run({{vwp0, vwp1}}, + {{vbias0, vbias1}}, op); + SAVE(C, vres, n, 0); + C += LDC; + } + } +}; + +template +struct ConvBiasMatmul { + static void postprocess(const dt_int32* bias, const dt_int32* workspace, + dt_uint8* C, size_t LDC, const Op& op) { + static_assert(m > 0 && m <= block_m, "invalid m or n"); + int32x4_t vbias0, vwp0, vwp1; + if (bmode != BiasMode::BROADCAST_CHANNEL_BIAS) { + vbias0 = QConverterBase::vzero(); + } + for (int i = 0; i < m; i++) { + if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + vbias0 = vdupq_n_s32(*bias); + } + vwp0 = vld1q_s32(workspace); + vwp1 = vld1q_s32(workspace + 4); + + uint8x8_t vres; + vres = Process::run( + {{vwp0, vwp1}}, {{vbias0, vbias0}}, op); + vst1_u8(C, vres); + + bias++; + C += LDC; + workspace += 8; + } + } +}; + + +template +struct ConvBiasMatmul { + static void postprocess(const dt_int32* bias, const 
dt_int32* workspace, + dt_uint8* C, size_t LDC, const Op& op) { + static_assert(m > 0 && m <= block_m && n > 0 && n <= 4, + "invalid m or n"); + int i = 0; + int32x4_t vbias0, vbias1, vwp0, vwp1; + if (bmode != BiasMode::BROADCAST_CHANNEL_BIAS) { + vbias0 = QConverterBase::vzero(); + vbias1 = QConverterBase::vzero(); + } + for (; i + 1 < m; i += 2) { + if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + vbias0 = vdupq_n_s32(*bias); + } + vwp0 = vld1q_s32(workspace); + if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + bias++; + vbias1 = vdupq_n_s32(*bias); + } + workspace += 4; + vwp1 = vld1q_s32(workspace); + + uint8x8_t vres; + vres = Process::run({{vwp0, vwp1}}, + {{vbias0, vbias1}}, op); + SAVEU(C, vres, n, 0); + C += LDC; + SAVEU(C, vres, n, 4); + + bias++; + C += LDC; + workspace += 4; + } + + if (i < m) { + if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + vbias0 = vdupq_n_s32(*bias); + } + vwp0 = vld1q_s32(workspace); + if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + vbias1 = QConverterBase::vzero(); + } + vwp1 = QConverterBase::vzero(); + + uint8x8_t vres; + vres = Process::run({{vwp0, vwp1}}, + {{vbias0, vbias1}}, op); + SAVEU(C, vres, n, 0); + C += LDC; + } + } +}; + + +#define DISPATCH_M(cb, _m, _n, ...) \ + switch (_m) { \ + case 4: { \ + DISPATCH_N(cb, 4, _n, ##__VA_ARGS__); \ + break; \ + } \ + case 3: { \ + DISPATCH_N(cb, 3, _n, ##__VA_ARGS__); \ + break; \ + } \ + case 2: { \ + DISPATCH_N(cb, 2, _n, ##__VA_ARGS__); \ + break; \ + } \ + case 1: { \ + DISPATCH_N(cb, 1, _n, ##__VA_ARGS__); \ + break; \ + } \ + default: \ + megdnn_assert(0); \ + } + +#define DISPATCH_N(cb, _m, _n, ...) \ + switch (_n) { \ + case 4: { \ + cb(_m, 4, ##__VA_ARGS__); \ + break; \ + } \ + case 3: { \ + cb(_m, 3, ##__VA_ARGS__); \ + break; \ + } \ + case 2: { \ + cb(_m, 2, ##__VA_ARGS__); \ + break; \ + } \ + case 1: { \ + cb(_m, 1, ##__VA_ARGS__); \ + break; \ + } \ + default: \ + megdnn_assert(0); \ + } + +//! _n should be a compiler time constant +#define DISPATCH_M_N(cb, _m, _n, ...) \ + switch (_m) { \ + case 4: { \ + cb(4, _n, ##__VA_ARGS__); \ + break; \ + } \ + case 3: { \ + cb(3, _n, ##__VA_ARGS__); \ + break; \ + } \ + case 2: { \ + cb(2, _n, ##__VA_ARGS__); \ + break; \ + } \ + case 1: { \ + cb(1, _n, ##__VA_ARGS__); \ + break; \ + } \ + default: \ + megdnn_assert(0); \ + } + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/neon_struct.h b/dnn/src/arm_common/conv_bias/neon_struct.h new file mode 100644 index 00000000..535674ec --- /dev/null +++ b/dnn/src/arm_common/conv_bias/neon_struct.h @@ -0,0 +1,54 @@ +#pragma once +/** + * \file dnn/src/arm_common/conv_bias/neon_struct.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ +#include "src/arm_common/simd_macro/marm_neon.h" +namespace megdnn { +namespace { +struct Vdotq_s32_h { + static int32x4_t impl(int8x16_t& a, int8x16_t& b, int32x4_t& c, + int16x8_t& temp) { + return vdotq_s32_h(a, b, c, temp); + } +}; +struct Vdot2_s32_h { + static int32x4_t impl(int8x8_t a, int8x8_t b, int32x4_t c, int16x8_t temp) { + return vdot2_s32_h(a, b, c, temp); + } +}; + +struct Vmlal_s16 { + static int32x4_t impl(int16x8_t a, int16x8_t b, int32x4_t c) { + return vmlal_s16(c, vget_low_s16(a), vget_low_s16(b)); + } +}; + +struct Vld1q_s8 { + static int8x16_t impl(const int8_t* ptr) { return vld1q_s8(ptr); } +}; +struct Vld1_s8 { + static int8x8_t impl(const int8_t* ptr) { return vld1_s8(ptr); } +}; +struct Vldq_dup_4s8_8s16 { + static int16x8_t impl(const int8_t* ptr) { return vldq_dup_4s8_8s16(ptr); } +}; + +struct Vldq_tbl_low_s8 { + static int8x8_t impl(const int8_t* ptr, uint8x16_t idx) { + return vldq_tbl_low_s8(ptr, idx); + } +}; + +struct Vld1_dup_s8_s16 { + static int16x8_t impl(const int8_t* ptr) { return vld1_dup_s8_s16(ptr); } +}; +} // namespace +} // namespace megdnn \ No newline at end of file diff --git a/dnn/src/arm_common/conv_bias/opr_impl.cpp b/dnn/src/arm_common/conv_bias/opr_impl.cpp new file mode 100644 index 00000000..b1f7808e --- /dev/null +++ b/dnn/src/arm_common/conv_bias/opr_impl.cpp @@ -0,0 +1,236 @@ +/** + * \file dnn/src/arm_common/conv_bias/opr_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ + +#include "src/arm_common/conv_bias/int8/algos.h" +#include "src/arm_common/conv_bias/int8x8x16/algos.h" +#include "src/arm_common/conv_bias/quint8/algos.h" + +#include "src/arm_common/conv_bias/opr_impl.h" +#include "src/common/metahelper.h" +#include "src/common/utils.h" +#include "src/naive/handle.h" + +#include "src/arm_common/convolution/opr_impl.h" +#include "src/arm_common/matrix_mul/opr_impl.h" +#include "src/common/opr_delegate.h" + +#include "include/megdnn/oprs/nn.h" +#include "src/arm_common/conv_bias/f16/algos.h" +#include "src/arm_common/conv_bias/fp32/algos.h" +#include "src/arm_common/conv_bias/int8/stride1.h" +#include "src/arm_common/conv_bias/int8/stride2.h" +#include "src/arm_common/conv_bias/quint8/stride1.h" +#include "src/arm_common/conv_bias/quint8/stride2.h" +#include "src/arm_common/convolution/opr_impl.h" + +using namespace megdnn; +using namespace arm_common; + +namespace { +uint8_t arm_common_algo_type_storage; +} // anonymous namespace + +class ConvBiasImpl::AlgoPack : NonCopyableObj { + AlgoQU8DirectStride2 qu8_direct_stride2_large_group{true}; + AlgoQU8DirectStride2 qu8_direct_stride2_small_group{false}; + AlgoQU8DirectStride1 qu8_direct_stride1_large_group{true}; + AlgoQU8DirectStride1 qu8_direct_stride1_small_group{false}; + AlgoS8DirectStride2 s8_direct_stride2_large_group{true}; + AlgoS8DirectStride2 s8_direct_stride2_small_group{false}; + AlgoS8DirectStride2NCHW44 s8_direct_stride2_nchw44; + AlgoS8DirectStride2NCHWNCHW44 s8_direct_stride2_nchw_nchw44; + AlgoS8DirectStride1 s8_direct_stride1_large_group{true}; + AlgoS8DirectStride1 s8_direct_stride1_small_group{false}; + AlgoS8DirectStride1NCHW44 s8_direct_stride1_nchw44; + AlgoS8ChanWiseStride1NCHW44 s8_channel_wise_stride1_nchw44; + AlgoS8ChanWiseStride2NCHW44 s8_channel_wise_stride2_nchw44; + +#if __ARM_FEATURE_DOTPROD + AlgoDotS8DirectStride1 ds8_direct_stride1_large_group{true}; + AlgoDotS8DirectStride1 ds8_direct_stride1_small_group{false}; + AlgoDotS8DirectStride2 ds8_direct_stride2_large_group{true}; + AlgoDotS8DirectStride2 ds8_direct_stride2_small_group{false}; + AlgoDotU8DirectStride1 du8_direct_stride1_large_group{true}; + AlgoDotU8DirectStride1 du8_direct_stride1_small_group{false}; + AlgoDotU8DirectStride2 du8_direct_stride2_large_group{true}; + AlgoDotU8DirectStride2 du8_direct_stride2_small_group{false}; +#endif + + AlgoF32Direct f32_direct_large_group{true}; + AlgoF32Direct f32_direct_small_group{false}; + AlgoF32DirectStride2 f32_direct_stride2_large_group{true}; + AlgoF32DirectStride2 f32_direct_stride2_small_group{false}; + AlgoF32DirectStride1 f32_direct_stride1_large_group{true}; + AlgoF32DirectStride1 f32_direct_stride1_small_group{false}; + AlgoI8x8x16Direct i8x8x16_direct_large_group{true}; + AlgoI8x8x16Direct i8x8x16_direct_small_group{false}; + AlgoI8x8x16Stride2 i8x8x16_stride2_large_group{true}; + AlgoI8x8x16Stride2 i8x8x16_stride2_small_group{false}; + AlgoI8x8x16Stride2Filter2 i8x8x16_stride2_filter2; +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + AlgoF16Direct f16_direct_large_group{true}; + AlgoF16Direct f16_direct_small_group{false}; + AlgoF16DirectStride1 f16_direct_stride1_large_group{true}; + AlgoF16DirectStride1 f16_direct_stride1_small_group{false}; +#endif + + SmallVector> refhold; + +public: + AlgoPack() { +#if __ARM_FEATURE_DOTPROD + direct_algos.emplace_back(&ds8_direct_stride1_large_group); + direct_algos.emplace_back(&ds8_direct_stride1_small_group); + direct_algos.emplace_back(&ds8_direct_stride2_large_group); + 
direct_algos.emplace_back(&ds8_direct_stride2_small_group); + direct_algos.emplace_back(&du8_direct_stride1_large_group); + direct_algos.emplace_back(&du8_direct_stride1_small_group); + direct_algos.emplace_back(&du8_direct_stride2_large_group); + direct_algos.emplace_back(&du8_direct_stride2_small_group); +#endif + direct_algos.emplace_back(&qu8_direct_stride2_large_group); + direct_algos.emplace_back(&qu8_direct_stride2_small_group); + direct_algos.emplace_back(&qu8_direct_stride1_large_group); + direct_algos.emplace_back(&qu8_direct_stride1_small_group); + direct_algos.emplace_back(&s8_direct_stride2_large_group); + direct_algos.emplace_back(&s8_direct_stride2_small_group); + direct_algos.emplace_back(&s8_direct_stride2_nchw44); + direct_algos.emplace_back(&s8_direct_stride2_nchw_nchw44); + direct_algos.emplace_back(&s8_direct_stride1_large_group); + direct_algos.emplace_back(&s8_direct_stride1_small_group); + direct_algos.emplace_back(&s8_direct_stride1_nchw44); + + direct_algos.emplace_back(&s8_channel_wise_stride1_nchw44); + direct_algos.emplace_back(&s8_channel_wise_stride2_nchw44); + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + direct_algos.emplace_back(&f16_direct_stride1_large_group); + direct_algos.emplace_back(&f16_direct_stride1_small_group); + direct_algos.emplace_back(&f16_direct_large_group); + direct_algos.emplace_back(&f16_direct_small_group); +#endif + direct_algos.emplace_back(&i8x8x16_direct_large_group); + direct_algos.emplace_back(&i8x8x16_direct_small_group); + direct_algos.emplace_back(&i8x8x16_stride2_filter2); + direct_algos.emplace_back(&i8x8x16_stride2_large_group); + direct_algos.emplace_back(&i8x8x16_stride2_small_group); + direct_algos.emplace_back(&f32_direct_stride1_large_group); + direct_algos.emplace_back(&f32_direct_stride1_small_group); + direct_algos.emplace_back(&f32_direct_stride2_large_group); + direct_algos.emplace_back(&f32_direct_stride2_small_group); + direct_algos.emplace_back(&f32_direct_large_group); + direct_algos.emplace_back(&f32_direct_small_group); + + static CpuOprDelegationStorage<2> storage; + auto matmul_opr = storage.get(); + auto&& matmul_algos = + static_cast(matmul_opr) + ->algo_pack(); + for (auto&& algo : matmul_algos) { + if (algo->type() == nullptr) + continue; + for (uint32_t tile_size : {8, 16, 24, 32, 40, 48, 64, 80}) { + refhold.emplace_back(new AlgoFP32WinogradF23_4x4( + static_cast(algo), + tile_size)); + winograd_algos.emplace_back(refhold.back().get()); + refhold.emplace_back(new AlgoFP32WinogradF63( + static_cast(algo), + tile_size)); + winograd_algos.emplace_back(refhold.back().get()); + refhold.emplace_back(new AlgoFP32WinogradF63_4x4( + static_cast(algo), + tile_size)); + winograd_algos.emplace_back(refhold.back().get()); + refhold.emplace_back(new AlgoFP32WinogradF54( + static_cast(algo), + tile_size)); + winograd_algos.emplace_back(refhold.back().get()); + refhold.emplace_back(new AlgoFP32WinogradF45( + static_cast(algo), + tile_size)); + winograd_algos.emplace_back(refhold.back().get()); +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + refhold.emplace_back(new AlgoFP16WinogradF23( + static_cast(algo), + tile_size)); + winograd_algos.emplace_back(refhold.back().get()); + refhold.emplace_back(new AlgoFP16WinogradF45( + static_cast(algo), + tile_size)); + winograd_algos.emplace_back(refhold.back().get()); + refhold.emplace_back(new AlgoFP16WinogradF63( + static_cast(algo), + tile_size)); + winograd_algos.emplace_back(refhold.back().get()); + refhold.emplace_back(new AlgoFP16WinogradF23_8x8( + static_cast(algo), + 
tile_size)); + winograd_algos.emplace_back(refhold.back().get()); +#endif + refhold.emplace_back(new AlgoS8WinogradF23_8x8( + static_cast(algo), + tile_size)); + winograd_algos.emplace_back(refhold.back().get()); + } + } + } + SmallVector direct_algos; + SmallVector winograd_algos; +}; + +SmallVector ConvBiasImpl::algo_pack() { + static AlgoPack sl_algo_pack; + auto&& algos = fallback::ConvBiasImpl::algo_pack(); + algos.insert(algos.begin(), sl_algo_pack.direct_algos.begin(), + sl_algo_pack.direct_algos.end()); + algos.insert(algos.end(), sl_algo_pack.winograd_algos.begin(), + sl_algo_pack.winograd_algos.end()); + return std::move(algos); +} + +void* const ConvBiasImpl::sm_arm_common_algo_type = + &arm_common_algo_type_storage; + +bool ConvBiasImpl::is_matmul_quantized_prefer( + const ConvBiasImpl::NCBKernSizeParam& param) { + // fallback::ConvBiasImpl::NCBKernParam conv_ncb_param; + fallback::ConvBiasImpl::NCBKernSizeParam conv_ncb_param( + param, 0, param::MatrixMul::Format::DEFAULT, {}, 0, + BiasMode::NO_BIAS, param::ConvBias::NonlineMode::IDENTITY); + conv_ncb_param.dst_type = param.bias_type; + conv_ncb_param.filter_meta.group = 1; + + bool conv_direct_unusable = false; + if (param.dst_type.enumv() == DTypeEnum::QuantizedS8 || + param.dst_type.enumv() == DTypeEnum::QuantizedS32) { + conv_direct_unusable = + !arm_common::direct_int8_stride1::can_conv_direct_stride1_int8( + conv_ncb_param) && + !arm_common::direct_int8_stride2::can_conv_direct_stride2_int8( + conv_ncb_param); + } else if (param.dst_type.enumv() == DTypeEnum::Quantized8Asymm) { + conv_direct_unusable = + !arm_common::direct_quint8_stride1:: + can_conv_direct_stride1_quint8(conv_ncb_param) && + !arm_common::direct_quint8_stride2:: + can_conv_direct_stride2_quint8(conv_ncb_param); + } + return conv_direct_unusable; +} + +const char* ConvBiasImpl::get_algorithm_set_name() const { + // arm common version 0 + return "AC0"; +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/opr_impl.h b/dnn/src/arm_common/conv_bias/opr_impl.h new file mode 100644 index 00000000..f21dba87 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/opr_impl.h @@ -0,0 +1,83 @@ +/** + * \file dnn/src/arm_common/conv_bias/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ +#pragma once +#include "src/common/utils.h" +#include "src/fallback/conv_bias/opr_impl.h" + +namespace megdnn { +namespace arm_common { + +class ConvBiasImpl : public fallback::ConvBiasImpl { +public: + using fallback::ConvBiasImpl::ConvBiasImpl; + using FallbackConvBiasImpl = fallback::ConvBiasImpl; + using NCBKernIndex = fallback::ConvBiasImpl::NCBKernIndex; + + bool is_thread_safe() const override { return true; } + + SmallVector algo_pack() override; + + bool is_matmul_quantized_prefer( + const ConvBiasImpl::NCBKernSizeParam& ncb_param) override; + class AlgoPack; + +protected: + static void* const sm_arm_common_algo_type; + + const char* get_algorithm_set_name() const override; + +private: + class AlgoS8DirectStride1; + class AlgoS8DirectStride1NCHW44; + class AlgoS8DirectStride2; + class AlgoS8DirectStride2NCHW44; + class AlgoS8DirectStride2NCHWNCHW44; + class AlgoQU8DirectStride1; + class AlgoQU8DirectStride2; + class AlgoFP32WinogradF23_4x4; + class AlgoFP32WinogradF63; + class AlgoFP32WinogradF63_4x4; + class AlgoFP32WinogradF54; + class AlgoFP32WinogradF45; + + class AlgoS8ChanWiseStride1NCHW44; + class AlgoS8ChanWiseStride2NCHW44; + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + class AlgoFP16WinogradF23; + class AlgoFP16WinogradF45; + class AlgoFP16WinogradF63; + class AlgoFP16WinogradF23_8x8; +#endif +#if __ARM_FEATURE_DOTPROD + class AlgoDotS8DirectStride1; + class AlgoDotS8DirectStride2; + class AlgoDotU8DirectStride1; + class AlgoDotU8DirectStride2; +#endif + class AlgoF32Direct; + class AlgoF32DirectStride1; + class AlgoF32DirectStride2; + class AlgoI8x8x16Direct; + class AlgoI8x8x16Stride2; + class AlgoI8x8x16Stride2Filter2; + class AlgoS8WinogradF23_8x8; +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + class AlgoF16Direct; + class AlgoF16DirectStride1; +#endif +}; + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/postprocess_helper.h b/dnn/src/arm_common/conv_bias/postprocess_helper.h new file mode 100644 index 00000000..a9e93694 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/postprocess_helper.h @@ -0,0 +1,340 @@ +/** + * \file dnn/src/arm_common/conv_bias/postprocess_helper.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#pragma once + +#include "megdnn/basic_types.h" +#include "src/arm_common/elemwise_helper/kimpl/op_base.h" +#include "src/arm_common/elemwise_op.h" +#include "src/fallback/conv_bias/opr_impl.h" +namespace { + + +#define CONCAT_OP(_name) megdnn::arm_common::_name +#define CONCAT_NL(_name) megdnn::NonlineMode::_name + +#define CB(_caller, _op, _mode) \ + case _mode: \ + _caller(_op); \ + break; + +#define DEFAULT \ + default: \ + megdnn_throw("unsupported nolinemode"); \ + break; + +#define HANDLE_IDENTITY() \ + case megdnn::NonlineMode::IDENTITY: \ + break; + +#define FOR_NONLINEAR_UNARY(_op) \ + megdnn::arm_common::OpCallerUnary<_op, megdnn::arm_common::VEC>:: \ + run(static_cast(conv_dst_ptr), \ + reinterpret_cast(dst_ptr), bias_type, dst_type, \ + N* OC* OH* OW); + +#define FOR_NONLINEAR_BINARY_BROADCAST(_op) \ + megdnn::arm_common:: \ + OpCallerBinary<_op, megdnn::arm_common::VEC_BCAST101>::run( \ + static_cast(conv_dst_ptr), \ + reinterpret_cast(bias_ptr), \ + reinterpret_cast(dst_ptr), bias_type, bias_type, \ + dst_type, N, OC, OH* OW); + +#define FOR_NONLINEAR_BINARY(_op) \ + megdnn::arm_common:: \ + OpCallerBinary<_op, megdnn::arm_common::VEC_VEC>::run( \ + static_cast(conv_dst_ptr), \ + reinterpret_cast(bias_ptr), \ + reinterpret_cast(dst_ptr), bias_type, bias_type, \ + dst_type, N* OC* OH* OW); + +#define FOR_BIAS(_mode) \ + switch (_mode) { \ + case megdnn::BiasMode::NO_BIAS: \ + FOR_NONLINEAR_NOBIAS(FOR_NONLINEAR_UNARY) \ + break; \ + case megdnn::BiasMode::BROADCAST_CHANNEL_BIAS: \ + FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST) \ + break; \ + case megdnn::BiasMode::BIAS: \ + FOR_NONLINEAR(FOR_NONLINEAR_BINARY) \ + break; \ + default: \ + megdnn_throw("no quantized unsupported biasmode"); \ + break; \ + } + +#define FOR_NONLINEAR(_caller) \ + switch (nonlineMode) { \ + CB(_caller, CONCAT_OP(AddOp), CONCAT_NL(IDENTITY)) \ + CB(_caller, CONCAT_OP(FuseAddReluOp), CONCAT_NL(RELU)) \ + CB(_caller, CONCAT_OP(FuseAddSigmoidOp), CONCAT_NL(SIGMOID)) \ + CB(_caller, CONCAT_OP(FuseAddHSwishOp), CONCAT_NL(H_SWISH)) \ + DEFAULT \ + } + +#define FOR_NONLINEAR_NOBIAS(_caller) \ + switch (nonlineMode) { \ + HANDLE_IDENTITY() \ + CB(_caller, CONCAT_OP(ReluOp), CONCAT_NL(RELU)) \ + CB(_caller, CONCAT_OP(SigmoidOp), CONCAT_NL(SIGMOID)) \ + CB(_caller, CONCAT_OP(HSwishOp), CONCAT_NL(H_SWISH)) \ + DEFAULT \ + } + +template +struct PostProcess { + static void run(void* conv_dst_ptr, const void* bias_ptr, void* dst_ptr, + megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode, + megdnn::DType bias_type, megdnn::DType dst_type, size_t N, + size_t OC, size_t OH, size_t OW, size_t pack_oc_size = 1) { + MEGDNN_MARK_USED_VAR(pack_oc_size); + FOR_BIAS(bias_mode) + } +}; + +template +struct PostProcess { + static void run(void* conv_dst_ptr, void* bias_ptr, void* dst_ptr, + megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode, + megdnn::DType bias_type, megdnn::DType dst_type, size_t N, + size_t OC, size_t OH, size_t OW, size_t pack_oc_size = 1) { + MEGDNN_MARK_USED_VAR(conv_dst_ptr); + MEGDNN_MARK_USED_VAR(bias_ptr); + MEGDNN_MARK_USED_VAR(dst_ptr); + MEGDNN_MARK_USED_VAR(bias_mode); + MEGDNN_MARK_USED_VAR(nonlineMode); + MEGDNN_MARK_USED_VAR(bias_type); + MEGDNN_MARK_USED_VAR(dst_type); + MEGDNN_MARK_USED_VAR(N); + MEGDNN_MARK_USED_VAR(OC); + MEGDNN_MARK_USED_VAR(OH); + MEGDNN_MARK_USED_VAR(OW); + MEGDNN_MARK_USED_VAR(pack_oc_size); + megdnn_assert(bias_mode == megdnn::BiasMode::NO_BIAS && + nonlineMode == megdnn::NonlineMode::IDENTITY); + } +}; + +#undef FOR_NONLINEAR_UNARY +#undef 
FOR_NONLINEAR_BINARY_BROADCAST +#undef FOR_NONLINEAR_BINARY +#undef FOR_NONLINEAR_NOBIAS +#undef FOR_NONLINEAR +#undef FOR_BIAS +#undef HANDLE_IDENTITY + +#define FOR_NONLINEAR_UNARY(_op) \ + megdnn::arm_common::OpCallerUnary< \ + _op, \ + megdnn::arm_common::VEC>::run(static_cast(conv_dst_ptr), \ + reinterpret_cast(dst_ptr), \ + bias_type, dst_type, N* OC* OH* OW); + +#define FOR_NONLINEAR_BINARY_BROADCAST(_op) \ + megdnn::arm_common::OpCallerBinary<_op, \ + megdnn::arm_common::VEC_BCAST101>:: \ + run(static_cast(conv_dst_ptr), \ + reinterpret_cast(bias_ptr), \ + reinterpret_cast(dst_ptr), bias_type, bias_type, \ + dst_type, N, OC, OH* OW); + +#define FOR_NONLINEAR_BINARY_BROADCAST_NCHW44(_op) \ + megdnn::arm_common::OpCallerBinary<_op, \ + megdnn::arm_common::VEC_BCAST101x4>:: \ + run(static_cast(conv_dst_ptr), \ + reinterpret_cast(bias_ptr), \ + reinterpret_cast(dst_ptr), bias_type, bias_type, \ + dst_type, N, OC, OH* OW, pack_oc_size); + +#define HANDLE_IDENTITY(_caller, _op) \ + case megdnn::NonlineMode::IDENTITY: \ + _caller(_op) break; + +#define FOR_NONLINEAR(_caller) \ + switch (nonlineMode) { \ + HANDLE_IDENTITY(_caller, CONCAT_OP(AddOp)) \ + CB(_caller, CONCAT_OP(FuseAddReluOp), CONCAT_NL(RELU)) \ + CB(_caller, CONCAT_OP(FuseAddHSwishOp), CONCAT_NL(H_SWISH)) \ + DEFAULT \ + } + +#define FOR_NONLINEAR_NOBIAS(_caller) \ + switch (nonlineMode) { \ + HANDLE_IDENTITY(_caller, CONCAT_OP(TypeCvtOp)) \ + CB(_caller, CONCAT_OP(ReluOp), CONCAT_NL(RELU)) \ + CB(_caller, CONCAT_OP(HSwishOp), CONCAT_NL(H_SWISH)) \ + DEFAULT \ + } + +#define FOR_BIAS(_bias_mode) \ + switch (_bias_mode) { \ + case megdnn::BiasMode::NO_BIAS: \ + FOR_NONLINEAR_NOBIAS(FOR_NONLINEAR_UNARY); \ + break; \ + case megdnn::BiasMode::BROADCAST_CHANNEL_BIAS: \ + if (pack_oc_size == 1) { \ + FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST); \ + } else { \ + FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST_NCHW44); \ + } \ + break; \ + default: \ + megdnn_throw("quantized unsupported biasmode"); \ + break; \ + } + +template +struct PostProcess { + static void run(void* conv_dst_ptr, const void* bias_ptr, void* dst_ptr, + megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode, + megdnn::DType bias_type, megdnn::DType dst_type, size_t N, + size_t OC, size_t OH, size_t OW, size_t pack_oc_size = 1) { + FOR_BIAS(bias_mode); + } +}; + +#undef FOR_NONLINEAR_UNARY +#undef FOR_NONLINEAR_BINARY_BROADCAST +#undef FOR_NONLINEAR_BINARY_BROADCAST_NCHW44 +#undef FOR_NONLINEAR_BINARY +#undef FOR_NONLINEAR_NOBIAS +#undef FOR_NONLINEAR +#undef FOR_BIAS +#undef CB +#undef CONCAT_OP +#undef CONCAT_NL +#undef DEFAULT +#undef HANDLE_IDENTITY + +#define DISPATCH_CONV_WINOGRAD_NONLINE(_midout_tag, cb, _bias_id, _src_type, \ + _dst_type, _bmode, _nonline_mode, ...) 
\ + switch (_nonline_mode) { \ + case param::ConvBias::NonlineMode::IDENTITY: { \ + MIDOUT_BEGIN(_midout_tag, _bias_id, 0) { \ + cb(_bmode, NoneOp<_src_type MEGDNN_COMMA _dst_type>, \ + __VA_ARGS__); \ + } \ + MIDOUT_END(); \ + break; \ + } \ + case param::ConvBias::NonlineMode::RELU: { \ + MIDOUT_BEGIN(_midout_tag, _bias_id, 1) { \ + cb(_bmode, ReluOp<_src_type MEGDNN_COMMA _dst_type>, \ + __VA_ARGS__); \ + } \ + MIDOUT_END(); \ + break; \ + } \ + case param::ConvBias::NonlineMode::SIGMOID: { \ + MIDOUT_BEGIN(_midout_tag, _bias_id, 2) { \ + cb(_bmode, SigmoidOp<_src_type MEGDNN_COMMA _dst_type>, \ + __VA_ARGS__); \ + } \ + MIDOUT_END(); \ + break; \ + } \ + case param::ConvBias::NonlineMode::H_SWISH: { \ + MIDOUT_BEGIN(_midout_tag, _bias_id, 3) { \ + cb(_bmode, HSwishOp<_src_type MEGDNN_COMMA _dst_type>, \ + __VA_ARGS__); \ + } \ + MIDOUT_END(); \ + break; \ + } \ + default: \ + megdnn_assert(0); \ + break; \ + } + +#define DISPATCH_CONV_WINOGRAD_NONLINE_QUANTIZED(_midout_tag, cb, _bias_id, \ + _src_type, _dst_type, _bmode, \ + _nonline_mode, ...) \ + switch (_nonline_mode) { \ + case param::ConvBias::NonlineMode::IDENTITY: { \ + MIDOUT_BEGIN(_midout_tag, _bias_id, 0) { \ + cb(_bmode, TypeCvtOp<_src_type MEGDNN_COMMA _dst_type>, \ + __VA_ARGS__); \ + } \ + MIDOUT_END(); \ + break; \ + } \ + case param::ConvBias::NonlineMode::RELU: { \ + MIDOUT_BEGIN(_midout_tag, _bias_id, 1) { \ + cb(_bmode, ReluOp<_src_type MEGDNN_COMMA _dst_type>, \ + __VA_ARGS__); \ + } \ + MIDOUT_END(); \ + break; \ + } \ + default: \ + megdnn_assert(0); \ + break; \ + } + +#define DISPATCH_CONV_WINOGRAD_BIAS(_midout_tag, cb, _src_type, _dst_type, \ + _bmode, _nonline_mode, ...) \ + switch (_bmode) { \ + case BiasMode::BIAS: { \ + DISPATCH_CONV_WINOGRAD_NONLINE(_midout_tag, cb, 0, _src_type, \ + _dst_type, BiasMode::BIAS, \ + _nonline_mode, __VA_ARGS__) \ + break; \ + } \ + case BiasMode::NO_BIAS: { \ + DISPATCH_CONV_WINOGRAD_NONLINE(_midout_tag, cb, 1, _src_type, \ + _dst_type, BiasMode::NO_BIAS, \ + _nonline_mode, __VA_ARGS__) \ + break; \ + } \ + case BiasMode::BROADCAST_CHANNEL_BIAS: { \ + DISPATCH_CONV_WINOGRAD_NONLINE(_midout_tag, cb, 2, _src_type, \ + _dst_type, \ + BiasMode::BROADCAST_CHANNEL_BIAS, \ + _nonline_mode, __VA_ARGS__) \ + break; \ + } \ + default: \ + megdnn_assert(0); \ + break; \ + } + +#define DISPATCH_CONV_WINOGRAD_BIAS_QUANTIZED( \ + _midout_tag, cb, _src_type, _dst_type, _bmode, _nonline_mode, ...) \ + switch (_bmode) { \ + case BiasMode::BIAS: { \ + DISPATCH_CONV_WINOGRAD_NONLINE_QUANTIZED( \ + _midout_tag, cb, 0, _src_type, _dst_type, BiasMode::BIAS, \ + _nonline_mode, __VA_ARGS__) \ + break; \ + } \ + case BiasMode::NO_BIAS: { \ + DISPATCH_CONV_WINOGRAD_NONLINE_QUANTIZED( \ + _midout_tag, cb, 1, _src_type, _dst_type, \ + BiasMode::NO_BIAS, _nonline_mode, __VA_ARGS__) \ + break; \ + } \ + case BiasMode::BROADCAST_CHANNEL_BIAS: { \ + DISPATCH_CONV_WINOGRAD_NONLINE_QUANTIZED( \ + _midout_tag, cb, 2, _src_type, _dst_type, \ + BiasMode::BROADCAST_CHANNEL_BIAS, _nonline_mode, \ + __VA_ARGS__) \ + break; \ + } \ + default: \ + megdnn_assert(0); \ + break; \ + } + +} // namespace diff --git a/dnn/src/arm_common/conv_bias/quint8/algos.cpp b/dnn/src/arm_common/conv_bias/quint8/algos.cpp new file mode 100644 index 00000000..8c743da9 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/quint8/algos.cpp @@ -0,0 +1,149 @@ +/** + * \file dnn/src/arm_common/conv_bias/quint8/algos.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. 
All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "src/arm_common/conv_bias/quint8/algos.h" +#include "src/arm_common/conv_bias/quint8/stride1.h" +#include "src/arm_common/conv_bias/quint8/stride2.h" +#include "src/arm_common/conv_bias/quint8/stride1_dotprod.h" +#include "src/arm_common/conv_bias/quint8/stride2_dotprod.h" +#include "src/arm_common/elemwise_op.h" +#include "src/fallback/conv_bias/common.h" +#include "midout.h" + +MIDOUT_DECL(megdnn_arm_common_conv_bias_quint8) + +using namespace megdnn; +using namespace arm_common; + +/* ===================== stride1 algo ===================== */ +bool ConvBiasImpl::AlgoQU8DirectStride1::usable( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const { + bool avaible = direct_quint8_stride1::can_conv_direct_stride1_quint8(param); + if (algo_selection_strategy == + ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { + bool large_group = param.filter_meta.group >= param.nr_threads; + avaible &= (large_group == m_large_group); + } + return avaible; +} + +size_t ConvBiasImpl::AlgoQU8DirectStride1::get_workspace( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + auto bundle = direct_quint8_stride1::get_bundle(param, m_large_group); + return bundle.total_size_in_bytes(); +} + +SmallVector +ConvBiasImpl::AlgoQU8DirectStride1::dispatch_kerns( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, 0, 0) { + return direct_quint8_stride1::get_kimpls(param, m_large_group); + } + MIDOUT_END(); + return {}; +} + +/* ===================== stride2 algo ===================== */ +bool ConvBiasImpl::AlgoQU8DirectStride2::usable( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const { + bool avaible = direct_quint8_stride2::can_conv_direct_stride2_quint8(param); + if (algo_selection_strategy == + ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { + bool large_group = param.filter_meta.group >= param.nr_threads; + avaible &= (large_group == m_large_group); + } + return avaible; +} + +size_t ConvBiasImpl::AlgoQU8DirectStride2::get_workspace( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + auto bundle = direct_quint8_stride2::get_bundle(param, m_large_group); + return bundle.total_size_in_bytes(); +} + +SmallVector +ConvBiasImpl::AlgoQU8DirectStride2::dispatch_kerns( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, 0, 1) { + return direct_quint8_stride2::get_kimpls(param, m_large_group); + } + MIDOUT_END(); + return {}; +} +#if __ARM_FEATURE_DOTPROD +/* ===================== stride1 algo ===================== */ +bool ConvBiasImpl::AlgoDotU8DirectStride1::usable( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const { + bool avaible = + direct_dotprod_quint8_stride1::can_conv_direct_stride1_quint8( + param); + if (algo_selection_strategy == + ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { + bool large_group = param.filter_meta.group >= param.nr_threads; + avaible &= (large_group == m_large_group); + } + return avaible; +} + +size_t ConvBiasImpl::AlgoDotU8DirectStride1::get_workspace( + fallback::ConvBiasImpl*, const 
NCBKernSizeParam& param) const { + auto bundle = + direct_dotprod_quint8_stride1::get_bundle(param, m_large_group); + return bundle.total_size_in_bytes(); +} + +SmallVector +ConvBiasImpl::AlgoDotU8DirectStride1::dispatch_kerns( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, 1, 0) { + return direct_dotprod_quint8_stride1::get_kimpls(param, m_large_group); + } + MIDOUT_END(); + return {}; +} + +/* ===================== stride2 algo ===================== */ +bool ConvBiasImpl::AlgoDotU8DirectStride2::usable( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const { + bool avaible = + direct_dotprod_quint8_stride2::can_conv_direct_stride2_quint8( + param); + if (algo_selection_strategy == + ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { + bool large_group = param.filter_meta.group >= param.nr_threads; + avaible &= (large_group == m_large_group); + } + return avaible; +} + +size_t ConvBiasImpl::AlgoDotU8DirectStride2::get_workspace( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + auto bundle = + direct_dotprod_quint8_stride2::get_bundle(param, m_large_group); + return bundle.total_size_in_bytes(); +} + +SmallVector +ConvBiasImpl::AlgoDotU8DirectStride2::dispatch_kerns( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, 1, 1) { + return direct_dotprod_quint8_stride2::get_kimpls(param, m_large_group); + } + MIDOUT_END(); + return {}; +} + +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/quint8/algos.h b/dnn/src/arm_common/conv_bias/quint8/algos.h new file mode 100644 index 00000000..33a457e6 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/quint8/algos.h @@ -0,0 +1,101 @@ +/** + * \file dnn/src/arm_common/conv_bias/quint8/algos.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include "src/arm_common/conv_bias/opr_impl.h" + +namespace megdnn { +namespace arm_common { + +class ConvBiasImpl::AlgoQU8DirectStride1 final : public AlgoBase { + bool m_large_group; + +public: + AlgoQU8DirectStride1(bool large_group) : m_large_group(large_group) {} + bool is_reproducible() const override { return true; } + const char* name() const override { + return m_large_group ? "QU8STRD1_LARGE_GROUP" : "QU8STRD1_SMALL_GROUP"; + } + + bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; + + size_t get_workspace(fallback::ConvBiasImpl*, + const NCBKernSizeParam& param) const override; + virtual SmallVector dispatch_kerns( + fallback::ConvBiasImpl* opr, + const NCBKernSizeParam& param) const override; +}; + +class ConvBiasImpl::AlgoQU8DirectStride2 final : public AlgoBase { + bool m_large_group; + +public: + AlgoQU8DirectStride2(bool large_group) : m_large_group(large_group) {} + bool is_reproducible() const override { return true; } + const char* name() const override { + return m_large_group ? 
"QU8STRD2_LARGE_GROUP" : "QU8STRD2_SMALL_GROUP"; + } + bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; + + size_t get_workspace(fallback::ConvBiasImpl*, + const NCBKernSizeParam& param) const override; + virtual SmallVector dispatch_kerns( + fallback::ConvBiasImpl* opr, + const NCBKernSizeParam& param) const override; +}; +#if __ARM_FEATURE_DOTPROD +class ConvBiasImpl::AlgoDotU8DirectStride1 final : public AlgoBase { + bool m_large_group; + +public: + AlgoDotU8DirectStride1(bool large_group) : m_large_group(large_group) {} + bool is_reproducible() const override { return true; } + const char* name() const override { + return m_large_group ? "ARMDOTU8STRD1_LARGE_GROUP" + : "ARMDOTU8STRD1_SMALL_GROUP"; + } + + bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; + + size_t get_workspace(fallback::ConvBiasImpl*, + const NCBKernSizeParam& param) const override; + virtual SmallVector dispatch_kerns( + fallback::ConvBiasImpl* opr, + const NCBKernSizeParam& param) const override; +}; + +class ConvBiasImpl::AlgoDotU8DirectStride2 final : public AlgoBase { + bool m_large_group; + +public: + AlgoDotU8DirectStride2(bool large_group) : m_large_group(large_group) {} + bool is_reproducible() const override { return true; } + const char* name() const override { + return m_large_group ? "ARMDOTU8STRD2_LARGE_GROUP" + : "ARMDOTU8STRD2_SMALL_GROUP"; + } + bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; + + size_t get_workspace(fallback::ConvBiasImpl*, + const NCBKernSizeParam& param) const override; + virtual SmallVector dispatch_kerns( + fallback::ConvBiasImpl* opr, + const NCBKernSizeParam& param) const override; +}; +#endif +} // namespace arm_common +} // namespace megdnn + // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/quint8/direct.cpp b/dnn/src/arm_common/conv_bias/quint8/direct.cpp new file mode 100644 index 00000000..a374817c --- /dev/null +++ b/dnn/src/arm_common/conv_bias/quint8/direct.cpp @@ -0,0 +1,2130 @@ +/** + * \file dnn/src/arm_common/conv_bias/quint8/direct.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#include "src/arm_common/conv_bias/quint8/direct.h" +#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/utils.h" +#include "src/fallback/conv_bias/common.h" + +using namespace megdnn; +using namespace arm_common; + +#include "midout.h" +MIDOUT_DECL(conv_direct_stride) + +#define ACC_S16_S32(dst0, dst1, src) \ + dst0 = vaddw_s16(dst0, vget_low_s16(src)); \ + dst1 = vaddw_s16(dst1, vget_high_s16(src)); + +#define SUB128(n) static_cast(static_cast(n) - 128) + +#define SUB128VECTOR(src) \ + vqmovn_s16(vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(src)), v128)) + +#define MLSFZP(s, f) vmlsl_s8(vmull_s8(s, f), s, fzp) + +#define POSTPROCESS(dst0, dst1, tptr, dptr) \ + if (last_ic) { \ + op({{dst0, dst1}}, reinterpret_cast(dptr)); \ + } else { \ + vst1q_s32(tptr, dst0); \ + vst1q_s32(tptr + 4, dst1); \ + } + +template +void conv_bias::conv_direct_stride1_2x2_quint8( + const uint8_t* src, const uint8_t* filter, const int32_t* bias, + int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, + const size_t OH, const size_t OW, const int8_t src_zp, + const int8_t filter_zp, const int32_t src_filter_zp, const Op& op) { + MEGDNN_MARK_USED_VAR(IH); + MIDOUT_BEGIN(conv_direct_stride, 0, 0) { + int16x8_t v128 = vdupq_n_s16(128); + int32x4_t vsrc_filter_zp = vdupq_n_s32(src_filter_zp); + int8x8_t f00 = vdup_n_s8(SUB128(filter[0])); + int8x8_t f01 = vdup_n_s8(SUB128(filter[1])); + int8x8_t f10 = vdup_n_s8(SUB128(filter[2])); + int8x8_t f11 = vdup_n_s8(SUB128(filter[3])); + + int8x8_t fzp = vdup_n_s8(filter_zp); + + // get filter * src_zp for one IC + int32_t fxszp = 0; + for (size_t i = 0; i < 4; ++i) + fxszp += static_cast(filter[i]) - 128; + int32x4_t vfxszp = vdupq_n_s32(fxszp * static_cast(src_zp)); + + // 4x8 block + size_t oh = 0; + for (; oh + 4 <= OH; oh += 4) { + size_t ih = oh; + size_t ow = 0; + for (; ow < OW; ow += 8) { + size_t iw = ow; + int32_t* __restrict tptr = temp + oh * OW + ow; + uint8_t* __restrict dptr = dst + oh * OW + ow; + const uint8_t* __restrict sptr = src + ih * IW + iw; + const int32_t* __restrict bptr = bias; + int32x4_t sum00, sum01, sum10, sum11, sum20, sum21, sum30, sum31; + + if (!first_ic) { + sum00 = vld1q_s32(tptr + 0 * OW); + sum01 = vld1q_s32(tptr + 0 * OW + 4); + sum10 = vld1q_s32(tptr + 1 * OW); + sum11 = vld1q_s32(tptr + 1 * OW + 4); + sum20 = vld1q_s32(tptr + 2 * OW); + sum21 = vld1q_s32(tptr + 2 * OW + 4); + sum30 = vld1q_s32(tptr + 3 * OW); + sum31 = vld1q_s32(tptr + 3 * OW + 4); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + sum00 = vdupq_n_s32(bptr[0]); + sum01 = sum00; + sum10 = sum00; + sum11 = sum00; + sum20 = sum00; + sum21 = sum00; + sum30 = sum00; + sum31 = sum00; + } else { + sum00 = vdupq_n_s32(0); + sum01 = vdupq_n_s32(0); + sum10 = vdupq_n_s32(0); + sum11 = vdupq_n_s32(0); + sum20 = vdupq_n_s32(0); + sum21 = vdupq_n_s32(0); + sum30 = vdupq_n_s32(0); + sum31 = vdupq_n_s32(0); + } + + // src_zp * filter_zp for one OC + sum00 += vsrc_filter_zp; + sum01 += vsrc_filter_zp; + sum10 += vsrc_filter_zp; + sum11 += vsrc_filter_zp; + sum20 += vsrc_filter_zp; + sum21 += vsrc_filter_zp; + sum30 += vsrc_filter_zp; + sum31 += vsrc_filter_zp; + } + + int8x8_t s = SUB128VECTOR(vld1_u8(sptr + 0 * IW)); + ACC_S16_S32(sum00, sum01, MLSFZP(s, f00)); + + s = SUB128VECTOR(vld1_u8(sptr + 1 * IW)); + ACC_S16_S32(sum00, sum01, MLSFZP(s, f10)); + ACC_S16_S32(sum10, sum11, MLSFZP(s, f00)); + + s = SUB128VECTOR(vld1_u8(sptr + 2 * IW)); + ACC_S16_S32(sum10, sum11, MLSFZP(s, f10)); + 
ACC_S16_S32(sum20, sum21, MLSFZP(s, f00)); + + s = SUB128VECTOR(vld1_u8(sptr + 3 * IW)); + ACC_S16_S32(sum20, sum21, MLSFZP(s, f10)); + ACC_S16_S32(sum30, sum31, MLSFZP(s, f00)); + + s = SUB128VECTOR(vld1_u8(sptr + 4 * IW)); + ACC_S16_S32(sum30, sum31, MLSFZP(s, f10)); + + ++sptr; + + s = SUB128VECTOR(vld1_u8(sptr + 0 * IW)); + ACC_S16_S32(sum00, sum01, MLSFZP(s, f01)); + + s = SUB128VECTOR(vld1_u8(sptr + 1 * IW)); + ACC_S16_S32(sum00, sum01, MLSFZP(s, f11)); + sum00 = vsubq_s32(sum00, vfxszp); + sum01 = vsubq_s32(sum01, vfxszp); + POSTPROCESS(sum00, sum01, tptr + 0 * OW, dptr + 0 * OW); + ACC_S16_S32(sum10, sum11, MLSFZP(s, f01)); + + s = SUB128VECTOR(vld1_u8(sptr + 2 * IW)); + ACC_S16_S32(sum10, sum11, MLSFZP(s, f11)); + sum10 = vsubq_s32(sum10, vfxszp); + sum11 = vsubq_s32(sum11, vfxszp); + POSTPROCESS(sum10, sum11, tptr + 1 * OW, dptr + 1 * OW); + ACC_S16_S32(sum20, sum21, MLSFZP(s, f01)); + + s = SUB128VECTOR(vld1_u8(sptr + 3 * IW)); + ACC_S16_S32(sum20, sum21, MLSFZP(s, f11)); + sum20 = vsubq_s32(sum20, vfxszp); + sum21 = vsubq_s32(sum21, vfxszp); + POSTPROCESS(sum20, sum21, tptr + 2 * OW, dptr + 2 * OW); + ACC_S16_S32(sum30, sum31, MLSFZP(s, f01)); + + s = SUB128VECTOR(vld1_u8(sptr + 4 * IW)); + ACC_S16_S32(sum30, sum31, MLSFZP(s, f11)); + sum30 = vsubq_s32(sum30, vfxszp); + sum31 = vsubq_s32(sum31, vfxszp); + + POSTPROCESS(sum30, sum31, tptr + 3 * OW, dptr + 3 * OW); + } + } + + if (oh + 3 == OH) { + size_t ih = oh; + size_t ow = 0; + for (; ow < OW; ow += 8) { + size_t iw = ow; + int32_t* __restrict tptr = temp + oh * OW + ow; + uint8_t* __restrict dptr = dst + oh * OW + ow; + const uint8_t* __restrict sptr = src + ih * IW + iw; + const int32_t* __restrict bptr = bias; + int32x4_t sum00, sum01, sum10, sum11, sum20, sum21; + + if (!first_ic) { + sum00 = vld1q_s32(tptr + 0 * OW); + sum01 = vld1q_s32(tptr + 0 * OW + 4); + sum10 = vld1q_s32(tptr + 1 * OW); + sum11 = vld1q_s32(tptr + 1 * OW + 4); + sum20 = vld1q_s32(tptr + 2 * OW); + sum21 = vld1q_s32(tptr + 2 * OW + 4); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + sum00 = vdupq_n_s32(bptr[0]); + sum01 = sum00; + sum10 = sum00; + sum11 = sum00; + sum20 = sum00; + sum21 = sum00; + } else { + sum00 = vdupq_n_s32(0); + sum01 = vdupq_n_s32(0); + sum10 = vdupq_n_s32(0); + sum11 = vdupq_n_s32(0); + sum20 = vdupq_n_s32(0); + sum21 = vdupq_n_s32(0); + } + sum00 += vsrc_filter_zp; + sum01 += vsrc_filter_zp; + sum10 += vsrc_filter_zp; + sum11 += vsrc_filter_zp; + sum20 += vsrc_filter_zp; + sum21 += vsrc_filter_zp; + } + + int8x8_t s = SUB128VECTOR(vld1_u8(sptr + 0 * IW)); + ACC_S16_S32(sum00, sum01, MLSFZP(s, f00)); + + s = SUB128VECTOR(vld1_u8(sptr + 1 * IW)); + ACC_S16_S32(sum00, sum01, MLSFZP(s, f10)); + ACC_S16_S32(sum10, sum11, MLSFZP(s, f00)); + + s = SUB128VECTOR(vld1_u8(sptr + 2 * IW)); + ACC_S16_S32(sum10, sum11, MLSFZP(s, f10)); + ACC_S16_S32(sum20, sum21, MLSFZP(s, f00)); + + s = SUB128VECTOR(vld1_u8(sptr + 3 * IW)); + ACC_S16_S32(sum20, sum21, MLSFZP(s, f10)); + + ++sptr; + + s = SUB128VECTOR(vld1_u8(sptr + 0 * IW)); + ACC_S16_S32(sum00, sum01, MLSFZP(s, f01)); + + s = SUB128VECTOR(vld1_u8(sptr + 1 * IW)); + ACC_S16_S32(sum00, sum01, MLSFZP(s, f11)); + sum00 = vsubq_s32(sum00, vfxszp); + sum01 = vsubq_s32(sum01, vfxszp); + POSTPROCESS(sum00, sum01, tptr + 0 * OW, dptr + 0 * OW); + ACC_S16_S32(sum10, sum11, MLSFZP(s, f01)); + + s = SUB128VECTOR(vld1_u8(sptr + 2 * IW)); + ACC_S16_S32(sum10, sum11, MLSFZP(s, f11)); + sum10 = vsubq_s32(sum10, vfxszp); + sum11 = vsubq_s32(sum11, vfxszp); + POSTPROCESS(sum10, sum11, 
tptr + 1 * OW, dptr + 1 * OW); + ACC_S16_S32(sum20, sum21, MLSFZP(s, f01)); + + s = SUB128VECTOR(vld1_u8(sptr + 3 * IW)); + ACC_S16_S32(sum20, sum21, MLSFZP(s, f11)); + sum20 = vsubq_s32(sum20, vfxszp); + sum21 = vsubq_s32(sum21, vfxszp); + POSTPROCESS(sum20, sum21, tptr + 2 * OW, dptr + 2 * OW); + } + } else if (oh + 2 == OH) { + size_t ih = oh; + size_t ow = 0; + for (; ow < OW; ow += 8) { + size_t iw = ow; + int32_t* __restrict tptr = temp + oh * OW + ow; + uint8_t* __restrict dptr = dst + oh * OW + ow; + const uint8_t* __restrict sptr = src + ih * IW + iw; + const int32_t* __restrict bptr = bias; + int32x4_t sum00, sum01, sum10, sum11; + + if (!first_ic) { + sum00 = vld1q_s32(tptr + 0 * OW); + sum01 = vld1q_s32(tptr + 0 * OW + 4); + sum10 = vld1q_s32(tptr + 1 * OW); + sum11 = vld1q_s32(tptr + 1 * OW + 4); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + sum00 = vdupq_n_s32(bptr[0]); + sum01 = sum00; + sum10 = sum00; + sum11 = sum00; + } else { + sum00 = vdupq_n_s32(0); + sum01 = vdupq_n_s32(0); + sum10 = vdupq_n_s32(0); + sum11 = vdupq_n_s32(0); + } + sum00 += vsrc_filter_zp; + sum01 += vsrc_filter_zp; + sum10 += vsrc_filter_zp; + sum11 += vsrc_filter_zp; + } + + int8x8_t s = SUB128VECTOR(vld1_u8(sptr + 0 * IW)); + ACC_S16_S32(sum00, sum01, MLSFZP(s, f00)); + + s = SUB128VECTOR(vld1_u8(sptr + 1 * IW)); + ACC_S16_S32(sum00, sum01, MLSFZP(s, f10)); + ACC_S16_S32(sum10, sum11, MLSFZP(s, f00)); + + s = SUB128VECTOR(vld1_u8(sptr + 2 * IW)); + ACC_S16_S32(sum10, sum11, MLSFZP(s, f10)); + + ++sptr; + + s = SUB128VECTOR(vld1_u8(sptr + 0 * IW)); + ACC_S16_S32(sum00, sum01, MLSFZP(s, f01)); + + s = SUB128VECTOR(vld1_u8(sptr + 1 * IW)); + ACC_S16_S32(sum00, sum01, MLSFZP(s, f11)); + sum00 = vsubq_s32(sum00, vfxszp); + sum01 = vsubq_s32(sum01, vfxszp); + POSTPROCESS(sum00, sum01, tptr + 0 * OW, dptr + 0 * OW); + ACC_S16_S32(sum10, sum11, MLSFZP(s, f01)); + + s = SUB128VECTOR(vld1_u8(sptr + 2 * IW)); + ACC_S16_S32(sum10, sum11, MLSFZP(s, f11)); + + sum10 = vsubq_s32(sum10, vfxszp); + sum11 = vsubq_s32(sum11, vfxszp); + POSTPROCESS(sum10, sum11, tptr + 1 * OW, dptr + 1 * OW); + } + } else if (oh + 1 == OH) { + size_t ih = oh; + size_t ow = 0; + for (; ow < OW; ow += 8) { + size_t iw = ow; + int32_t* __restrict tptr = temp + oh * OW + ow; + uint8_t* __restrict dptr = dst + oh * OW + ow; + const uint8_t* __restrict sptr = src + ih * IW + iw; + const int32_t* __restrict bptr = bias; + int32x4_t sum00, sum01; + + if (!first_ic) { + sum00 = vld1q_s32(tptr + 0 * OW); + sum01 = vld1q_s32(tptr + 0 * OW + 4); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + sum00 = vdupq_n_s32(bptr[0]); + sum01 = sum00; + } else { + sum00 = vdupq_n_s32(0); + sum01 = vdupq_n_s32(0); + } + + sum00 += vsrc_filter_zp; + sum01 += vsrc_filter_zp; + } + + int8x8_t s = SUB128VECTOR(vld1_u8(sptr + 0 * IW)); + ACC_S16_S32(sum00, sum01, MLSFZP(s, f00)); + + s = SUB128VECTOR(vld1_u8(sptr + 1 * IW)); + ACC_S16_S32(sum00, sum01, MLSFZP(s, f10)); + + ++sptr; + + s = SUB128VECTOR(vld1_u8(sptr + 0 * IW)); + ACC_S16_S32(sum00, sum01, MLSFZP(s, f01)); + + s = SUB128VECTOR(vld1_u8(sptr + 1 * IW)); + ACC_S16_S32(sum00, sum01, MLSFZP(s, f11)); + sum00 = vsubq_s32(sum00, vfxszp); + sum01 = vsubq_s32(sum01, vfxszp); + POSTPROCESS(sum00, sum01, tptr + 0 * OW, dptr + 0 * OW); + } + } + } + MIDOUT_END(); +} + +template +void conv_bias::conv_direct_stride1_3x3_quint8( + const uint8_t* src, const uint8_t* filter, const int32_t* bias, + int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, + const size_t OH, 
const size_t OW, const int8_t src_zp, + const int8_t filter_zp, const int32_t src_filter_zp, const Op& op) { + MEGDNN_MARK_USED_VAR(IH); + MIDOUT_BEGIN(conv_direct_stride, 0, 1) { + int16x8_t v128 = vdupq_n_s16(128); + int32x4_t vsrc_filter_zp = vdupq_n_s32(src_filter_zp); + int8x8_t fzp = vdup_n_s8(filter_zp); + + int8x8_t f00 = vdup_n_s8(SUB128(filter[0])); + int8x8_t f01 = vdup_n_s8(SUB128(filter[1])); + int8x8_t f02 = vdup_n_s8(SUB128(filter[2])); + int8x8_t f10 = vdup_n_s8(SUB128(filter[3])); + int8x8_t f11 = vdup_n_s8(SUB128(filter[4])); + int8x8_t f12 = vdup_n_s8(SUB128(filter[5])); + int8x8_t f20 = vdup_n_s8(SUB128(filter[6])); + int8x8_t f21 = vdup_n_s8(SUB128(filter[7])); + int8x8_t f22 = vdup_n_s8(SUB128(filter[8])); + + int32_t fxszp = 0; + for (size_t i = 0; i < 9; ++i) + fxszp += static_cast(filter[i]) - 128; + int32x4_t vfxszp = vdupq_n_s32(fxszp * static_cast(src_zp)); + + // block 2x8 + size_t oh = 0; + for (; oh + 1 < OH; oh += 2) { + size_t ih = oh; + for (size_t ow = 0; ow < OW; ow += 8) { + size_t iw = ow; + int32_t* __restrict tptr = temp + oh * OW + ow; + uint8_t* __restrict dptr = dst + oh * OW + ow; + const uint8_t* __restrict sptr = src + ih * IW + iw; + const int32_t* __restrict bptr = bias; + int32x4_t sum00, sum01, sum10, sum11; + + if (!first_ic) { + sum00 = vld1q_s32(tptr + 0 * OW); + sum01 = vld1q_s32(tptr + 0 * OW + 4); + sum10 = vld1q_s32(tptr + 1 * OW); + sum11 = vld1q_s32(tptr + 1 * OW + 4); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + sum00 = vdupq_n_s32(bptr[0]); + sum01 = sum00; + sum10 = sum00; + sum11 = sum00; + } else { + sum00 = vdupq_n_s32(0); + sum01 = vdupq_n_s32(0); + sum10 = vdupq_n_s32(0); + sum11 = vdupq_n_s32(0); + } + sum00 += vsrc_filter_zp; + sum01 += vsrc_filter_zp; + sum10 += vsrc_filter_zp; + sum11 += vsrc_filter_zp; + } + + int8x8_t _r00 = SUB128VECTOR(vld1_u8(sptr + 0 * IW)); + int8x8_t _r0n = SUB128VECTOR(vld1_u8(sptr + 0 * IW + 8)); + int8x8_t _r01 = vext_s8(_r00, _r0n, 1); + int8x8_t _r02 = vext_s8(_r00, _r0n, 2); + ACC_S16_S32(sum00, sum01, MLSFZP(_r00, f00)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r01, f01)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r02, f02)); + + int8x8_t _r10 = SUB128VECTOR(vld1_u8(sptr + 1 * IW)); + int8x8_t _r1n = SUB128VECTOR(vld1_u8(sptr + 1 * IW + 8)); + int8x8_t _r11 = vext_s8(_r10, _r1n, 1); + int8x8_t _r12 = vext_s8(_r10, _r1n, 2); + ACC_S16_S32(sum00, sum01, MLSFZP(_r10, f10)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r11, f11)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r12, f12)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r10, f00)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r11, f01)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r12, f02)); + + int8x8_t _r20 = SUB128VECTOR(vld1_u8(sptr + 2 * IW)); + int8x8_t _r2n = SUB128VECTOR(vld1_u8(sptr + 2 * IW + 8)); + int8x8_t _r21 = vext_s8(_r20, _r2n, 1); + int8x8_t _r22 = vext_s8(_r20, _r2n, 2); + ACC_S16_S32(sum00, sum01, MLSFZP(_r20, f20)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r21, f21)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r22, f22)); + sum00 = vsubq_s32(sum00, vfxszp); + sum01 = vsubq_s32(sum01, vfxszp); + POSTPROCESS(sum00, sum01, tptr + 0 * OW, dptr + 0 * OW); + ACC_S16_S32(sum10, sum11, MLSFZP(_r20, f10)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r21, f11)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r22, f12)); + + int8x8_t _r30 = SUB128VECTOR(vld1_u8(sptr + 3 * IW)); + int8x8_t _r3n = SUB128VECTOR(vld1_u8(sptr + 3 * IW + 8)); + int8x8_t _r31 = vext_s8(_r30, _r3n, 1); + int8x8_t _r32 = vext_s8(_r30, _r3n, 2); + ACC_S16_S32(sum10, sum11, MLSFZP(_r30, f20)); + 
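// Each loaded input row is shared by both output rows of this 2x8 block: row r accumulates into + // output row 0 through filter row r and into output row 1 through filter row r-1, so rows 1 and 2 + // are loaded once but used twice, while row 3 here only feeds the second output row. +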
ACC_S16_S32(sum10, sum11, MLSFZP(_r31, f21)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r32, f22)); + sum10 = vsubq_s32(sum10, vfxszp); + sum11 = vsubq_s32(sum11, vfxszp); + POSTPROCESS(sum10, sum11, tptr + 1 * OW, dptr + 1 * OW); + } + } + + if (oh < OH) { + size_t ih = oh; + for (size_t ow = 0; ow < OW; ow += 8) { + size_t iw = ow; + int32_t* __restrict tptr = temp + oh * OW + ow; + uint8_t* __restrict dptr = dst + oh * OW + ow; + const uint8_t* __restrict sptr = src + ih * IW + iw; + const int32_t* __restrict bptr = bias; + int32x4_t sum00, sum01; + + if (!first_ic) { + sum00 = vld1q_s32(tptr + 0 * OW); + sum01 = vld1q_s32(tptr + 0 * OW + 4); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + sum00 = vdupq_n_s32(bptr[0]); + sum01 = sum00; + } else { + sum00 = vdupq_n_s32(0); + sum01 = vdupq_n_s32(0); + } + sum00 += vsrc_filter_zp; + sum01 += vsrc_filter_zp; + } + + int8x8_t _r00 = SUB128VECTOR(vld1_u8(sptr + 0 * IW)); + int8x8_t _r0n = SUB128VECTOR(vld1_u8(sptr + 0 * IW + 8)); + int8x8_t _r01 = vext_s8(_r00, _r0n, 1); + int8x8_t _r02 = vext_s8(_r00, _r0n, 2); + ACC_S16_S32(sum00, sum01, MLSFZP(_r00, f00)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r01, f01)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r02, f02)); + + int8x8_t _r10 = SUB128VECTOR(vld1_u8(sptr + 1 * IW)); + int8x8_t _r1n = SUB128VECTOR(vld1_u8(sptr + 1 * IW + 8)); + int8x8_t _r11 = vext_s8(_r10, _r1n, 1); + int8x8_t _r12 = vext_s8(_r10, _r1n, 2); + ACC_S16_S32(sum00, sum01, MLSFZP(_r10, f10)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r11, f11)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r12, f12)); + + int8x8_t _r20 = SUB128VECTOR(vld1_u8(sptr + 2 * IW)); + int8x8_t _r2n = SUB128VECTOR(vld1_u8(sptr + 2 * IW + 8)); + int8x8_t _r21 = vext_s8(_r20, _r2n, 1); + int8x8_t _r22 = vext_s8(_r20, _r2n, 2); + ACC_S16_S32(sum00, sum01, MLSFZP(_r20, f20)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r21, f21)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r22, f22)); + sum00 = vsubq_s32(sum00, vfxszp); + sum01 = vsubq_s32(sum01, vfxszp); + POSTPROCESS(sum00, sum01, tptr + 0 * OW, dptr + 0 * OW); + } + } + } + MIDOUT_END(); +} + +template +void conv_bias::conv_direct_stride1_5x5_quint8( + const uint8_t* src, const uint8_t* filter, const int32_t* bias, + int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, + const size_t OH, const size_t OW, const int8_t src_zp, + const int8_t filter_zp, const int32_t src_filter_zp, const Op& op) { + MEGDNN_MARK_USED_VAR(IH); + MIDOUT_BEGIN(conv_direct_stride, 0, 2) { + int16x8_t v128 = vdupq_n_s16(128); + int32x4_t vsrc_filter_zp = vdupq_n_s32(src_filter_zp); + int8x8_t fzp = vdup_n_s8(filter_zp); + + int8x8_t f00 = vdup_n_s8(SUB128(filter[0])); + int8x8_t f01 = vdup_n_s8(SUB128(filter[1])); + int8x8_t f02 = vdup_n_s8(SUB128(filter[2])); + int8x8_t f03 = vdup_n_s8(SUB128(filter[3])); + int8x8_t f04 = vdup_n_s8(SUB128(filter[4])); + int8x8_t f10 = vdup_n_s8(SUB128(filter[5])); + int8x8_t f11 = vdup_n_s8(SUB128(filter[6])); + int8x8_t f12 = vdup_n_s8(SUB128(filter[7])); + int8x8_t f13 = vdup_n_s8(SUB128(filter[8])); + int8x8_t f14 = vdup_n_s8(SUB128(filter[9])); + int8x8_t f20 = vdup_n_s8(SUB128(filter[10])); + int8x8_t f21 = vdup_n_s8(SUB128(filter[11])); + int8x8_t f22 = vdup_n_s8(SUB128(filter[12])); + int8x8_t f23 = vdup_n_s8(SUB128(filter[13])); + int8x8_t f24 = vdup_n_s8(SUB128(filter[14])); + int8x8_t f30 = vdup_n_s8(SUB128(filter[15])); + int8x8_t f31 = vdup_n_s8(SUB128(filter[16])); + int8x8_t f32 = vdup_n_s8(SUB128(filter[17])); + int8x8_t f33 = vdup_n_s8(SUB128(filter[18])); + int8x8_t f34 = 
vdup_n_s8(SUB128(filter[19])); + int8x8_t f40 = vdup_n_s8(SUB128(filter[20])); + int8x8_t f41 = vdup_n_s8(SUB128(filter[21])); + int8x8_t f42 = vdup_n_s8(SUB128(filter[22])); + int8x8_t f43 = vdup_n_s8(SUB128(filter[23])); + int8x8_t f44 = vdup_n_s8(SUB128(filter[24])); + + // get filter * src_zp for one IC + int32_t fxszp = 0; + for (size_t i = 0; i < 25; ++i) + fxszp += static_cast(filter[i]) - 128; + int32x4_t vfxszp = vdupq_n_s32(fxszp * static_cast(src_zp)); + + // block 2x8 + size_t oh = 0; + for (; oh + 1 < OH; oh += 2) { + size_t ih = oh; + for (size_t ow = 0; ow < OW; ow += 8) { + size_t iw = ow; + int32_t* __restrict tptr = temp + oh * OW + ow; + uint8_t* __restrict dptr = dst + oh * OW + ow; + const uint8_t* __restrict sptr = src + ih * IW + iw; + const int32_t* __restrict bptr = bias; + int32x4_t sum00, sum01, sum10, sum11; + + if (!first_ic) { + sum00 = vld1q_s32(tptr + 0 * OW); + sum01 = vld1q_s32(tptr + 0 * OW + 4); + sum10 = vld1q_s32(tptr + 1 * OW); + sum11 = vld1q_s32(tptr + 1 * OW + 4); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + sum00 = vdupq_n_s32(bptr[0]); + sum01 = sum00; + sum10 = sum00; + sum11 = sum00; + } else { + sum00 = vdupq_n_s32(0); + sum01 = vdupq_n_s32(0); + sum10 = vdupq_n_s32(0); + sum11 = vdupq_n_s32(0); + } + sum00 += vsrc_filter_zp; + sum01 += vsrc_filter_zp; + sum10 += vsrc_filter_zp; + sum11 += vsrc_filter_zp; + } + + int8x8_t _r00 = SUB128VECTOR(vld1_u8(sptr + 0 * IW)); + int8x8_t _r0n = SUB128VECTOR(vld1_u8(sptr + 0 * IW + 8)); + int8x8_t _r01 = vext_s8(_r00, _r0n, 1); + int8x8_t _r02 = vext_s8(_r00, _r0n, 2); + int8x8_t _r03 = vext_s8(_r00, _r0n, 3); + int8x8_t _r04 = vext_s8(_r00, _r0n, 4); + ACC_S16_S32(sum00, sum01, MLSFZP(_r00, f00)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r01, f01)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r02, f02)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r03, f03)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r04, f04)); + + int8x8_t _r10 = SUB128VECTOR(vld1_u8(sptr + 1 * IW)); + int8x8_t _r1n = SUB128VECTOR(vld1_u8(sptr + 1 * IW + 8)); + int8x8_t _r11 = vext_s8(_r10, _r1n, 1); + int8x8_t _r12 = vext_s8(_r10, _r1n, 2); + int8x8_t _r13 = vext_s8(_r10, _r1n, 3); + int8x8_t _r14 = vext_s8(_r10, _r1n, 4); + ACC_S16_S32(sum00, sum01, MLSFZP(_r10, f10)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r11, f11)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r12, f12)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r13, f13)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r14, f14)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r10, f00)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r11, f01)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r12, f02)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r13, f03)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r14, f04)) + + int8x8_t _r20 = SUB128VECTOR(vld1_u8(sptr + 2 * IW)); + int8x8_t _r2n = SUB128VECTOR(vld1_u8(sptr + 2 * IW + 8)); + int8x8_t _r21 = vext_s8(_r20, _r2n, 1); + int8x8_t _r22 = vext_s8(_r20, _r2n, 2); + int8x8_t _r23 = vext_s8(_r20, _r2n, 3); + int8x8_t _r24 = vext_s8(_r20, _r2n, 4); + ACC_S16_S32(sum00, sum01, MLSFZP(_r20, f20)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r21, f21)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r22, f22)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r23, f23)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r24, f24)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r20, f10)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r21, f11)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r22, f12)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r23, f13)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r24, f14)) + + int8x8_t _r30 = SUB128VECTOR(vld1_u8(sptr + 3 * IW)); + int8x8_t _r3n = 
SUB128VECTOR(vld1_u8(sptr + 3 * IW + 8)); + int8x8_t _r31 = vext_s8(_r30, _r3n, 1); + int8x8_t _r32 = vext_s8(_r30, _r3n, 2); + int8x8_t _r33 = vext_s8(_r30, _r3n, 3); + int8x8_t _r34 = vext_s8(_r30, _r3n, 4); + ACC_S16_S32(sum00, sum01, MLSFZP(_r30, f30)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r31, f31)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r32, f32)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r33, f33)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r34, f34)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r30, f20)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r31, f21)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r32, f22)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r33, f23)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r34, f24)); + + int8x8_t _r40 = SUB128VECTOR(vld1_u8(sptr + 4 * IW)); + int8x8_t _r4n = SUB128VECTOR(vld1_u8(sptr + 4 * IW + 8)); + int8x8_t _r41 = vext_s8(_r40, _r4n, 1); + int8x8_t _r42 = vext_s8(_r40, _r4n, 2); + int8x8_t _r43 = vext_s8(_r40, _r4n, 3); + int8x8_t _r44 = vext_s8(_r40, _r4n, 4); + ACC_S16_S32(sum00, sum01, MLSFZP(_r40, f40)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r41, f41)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r42, f42)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r43, f43)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r44, f44)); + sum00 = vsubq_s32(sum00, vfxszp); + sum01 = vsubq_s32(sum01, vfxszp); + POSTPROCESS(sum00, sum01, tptr + 0 * OW, dptr + 0 * OW); + ACC_S16_S32(sum10, sum11, MLSFZP(_r40, f30)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r41, f31)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r42, f32)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r43, f33)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r44, f34)); + + int8x8_t _r50 = SUB128VECTOR(vld1_u8(sptr + 5 * IW)); + int8x8_t _r5n = SUB128VECTOR(vld1_u8(sptr + 5 * IW + 8)); + int8x8_t _r51 = vext_s8(_r50, _r5n, 1); + int8x8_t _r52 = vext_s8(_r50, _r5n, 2); + int8x8_t _r53 = vext_s8(_r50, _r5n, 3); + int8x8_t _r54 = vext_s8(_r50, _r5n, 4); + ACC_S16_S32(sum10, sum11, MLSFZP(_r50, f40)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r51, f41)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r52, f42)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r53, f43)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r54, f44)); + sum10 = vsubq_s32(sum10, vfxszp); + sum11 = vsubq_s32(sum11, vfxszp); + POSTPROCESS(sum10, sum11, tptr + 1 * OW, dptr + 1 * OW); + } + } + + if (oh < OH) { + size_t ih = oh; + for (size_t ow = 0; ow < OW; ow += 8) { + size_t iw = ow; + int32_t* __restrict tptr = temp + oh * OW + ow; + uint8_t* __restrict dptr = dst + oh * OW + ow; + const uint8_t* __restrict sptr = src + ih * IW + iw; + const int32_t* __restrict bptr = bias; + int32x4_t sum00, sum01; + + if (!first_ic) { + sum00 = vld1q_s32(tptr + 0 * OW); + sum01 = vld1q_s32(tptr + 0 * OW + 4); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + sum00 = vdupq_n_s32(bptr[0]); + sum01 = sum00; + } else { + sum00 = vdupq_n_s32(0); + sum01 = vdupq_n_s32(0); + } + sum00 += vsrc_filter_zp; + sum01 += vsrc_filter_zp; + } + + int8x8_t _r00 = SUB128VECTOR(vld1_u8(sptr + 0 * IW)); + int8x8_t _r0n = SUB128VECTOR(vld1_u8(sptr + 0 * IW + 8)); + int8x8_t _r01 = vext_s8(_r00, _r0n, 1); + int8x8_t _r02 = vext_s8(_r00, _r0n, 2); + int8x8_t _r03 = vext_s8(_r00, _r0n, 3); + int8x8_t _r04 = vext_s8(_r00, _r0n, 4); + ACC_S16_S32(sum00, sum01, MLSFZP(_r00, f00)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r01, f01)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r02, f02)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r03, f03)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r04, f04)); + + int8x8_t _r10 = SUB128VECTOR(vld1_u8(sptr + 1 * IW)); + int8x8_t _r1n = SUB128VECTOR(vld1_u8(sptr + 1 * IW + 8)); + 
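// Two 8-byte loads per input row are enough for the five horizontal taps: vext_s8 below builds the + // shifted views _r11.._r14 from _r10/_r1n, so the 8 outputs of this block consume 12 consecutive + // input bytes without any per-lane loads. +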
int8x8_t _r11 = vext_s8(_r10, _r1n, 1); + int8x8_t _r12 = vext_s8(_r10, _r1n, 2); + int8x8_t _r13 = vext_s8(_r10, _r1n, 3); + int8x8_t _r14 = vext_s8(_r10, _r1n, 4); + ACC_S16_S32(sum00, sum01, MLSFZP(_r10, f10)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r11, f11)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r12, f12)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r13, f13)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r14, f14)); + + int8x8_t _r20 = SUB128VECTOR(vld1_u8(sptr + 2 * IW)); + int8x8_t _r2n = SUB128VECTOR(vld1_u8(sptr + 2 * IW + 8)); + int8x8_t _r21 = vext_s8(_r20, _r2n, 1); + int8x8_t _r22 = vext_s8(_r20, _r2n, 2); + int8x8_t _r23 = vext_s8(_r20, _r2n, 3); + int8x8_t _r24 = vext_s8(_r20, _r2n, 4); + ACC_S16_S32(sum00, sum01, MLSFZP(_r20, f20)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r21, f21)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r22, f22)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r23, f23)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r24, f24)); + + int8x8_t _r30 = SUB128VECTOR(vld1_u8(sptr + 3 * IW)); + int8x8_t _r3n = SUB128VECTOR(vld1_u8(sptr + 3 * IW + 8)); + int8x8_t _r31 = vext_s8(_r30, _r3n, 1); + int8x8_t _r32 = vext_s8(_r30, _r3n, 2); + int8x8_t _r33 = vext_s8(_r30, _r3n, 3); + int8x8_t _r34 = vext_s8(_r30, _r3n, 4); + ACC_S16_S32(sum00, sum01, MLSFZP(_r30, f30)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r31, f31)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r32, f32)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r33, f33)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r34, f34)); + + int8x8_t _r40 = SUB128VECTOR(vld1_u8(sptr + 4 * IW)); + int8x8_t _r4n = SUB128VECTOR(vld1_u8(sptr + 4 * IW + 8)); + int8x8_t _r41 = vext_s8(_r40, _r4n, 1); + int8x8_t _r42 = vext_s8(_r40, _r4n, 2); + int8x8_t _r43 = vext_s8(_r40, _r4n, 3); + int8x8_t _r44 = vext_s8(_r40, _r4n, 4); + ACC_S16_S32(sum00, sum01, MLSFZP(_r40, f40)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r41, f41)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r42, f42)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r43, f43)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r44, f44)); + sum00 = vsubq_s32(sum00, vfxszp); + sum01 = vsubq_s32(sum01, vfxszp); + POSTPROCESS(sum00, sum01, tptr + 0 * OW, dptr + 0 * OW); + } + } + } + MIDOUT_END(); +} + +template +void conv_bias::conv_direct_stride1_7x7_quint8( + const uint8_t* src, const uint8_t* filter, const int32_t* bias, + int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, + const size_t OH, const size_t OW, const int8_t src_zp, + const int8_t filter_zp, const int32_t src_filter_zp, const Op& op) { + MEGDNN_MARK_USED_VAR(IH); + MIDOUT_BEGIN(conv_direct_stride, 0, 3) { + int16x8_t v128 = vdupq_n_s16(128); + int32x4_t vsrc_filter_zp = vdupq_n_s32(src_filter_zp); + int8x8_t fzp = vdup_n_s8(filter_zp); + + int8x8_t f00 = vdup_n_s8(SUB128(filter[0])); + int8x8_t f01 = vdup_n_s8(SUB128(filter[1])); + int8x8_t f02 = vdup_n_s8(SUB128(filter[2])); + int8x8_t f03 = vdup_n_s8(SUB128(filter[3])); + int8x8_t f04 = vdup_n_s8(SUB128(filter[4])); + int8x8_t f05 = vdup_n_s8(SUB128(filter[5])); + int8x8_t f06 = vdup_n_s8(SUB128(filter[6])); + + int8x8_t f10 = vdup_n_s8(SUB128(filter[7])); + int8x8_t f11 = vdup_n_s8(SUB128(filter[8])); + int8x8_t f12 = vdup_n_s8(SUB128(filter[9])); + int8x8_t f13 = vdup_n_s8(SUB128(filter[10])); + int8x8_t f14 = vdup_n_s8(SUB128(filter[11])); + int8x8_t f15 = vdup_n_s8(SUB128(filter[12])); + int8x8_t f16 = vdup_n_s8(SUB128(filter[13])); + + int8x8_t f20 = vdup_n_s8(SUB128(filter[14])); + int8x8_t f21 = vdup_n_s8(SUB128(filter[15])); + int8x8_t f22 = vdup_n_s8(SUB128(filter[16])); + int8x8_t f23 = vdup_n_s8(SUB128(filter[17])); + int8x8_t f24 = 
vdup_n_s8(SUB128(filter[18])); + int8x8_t f25 = vdup_n_s8(SUB128(filter[19])); + int8x8_t f26 = vdup_n_s8(SUB128(filter[20])); + + int8x8_t f30 = vdup_n_s8(SUB128(filter[21])); + int8x8_t f31 = vdup_n_s8(SUB128(filter[22])); + int8x8_t f32 = vdup_n_s8(SUB128(filter[23])); + int8x8_t f33 = vdup_n_s8(SUB128(filter[24])); + int8x8_t f34 = vdup_n_s8(SUB128(filter[25])); + int8x8_t f35 = vdup_n_s8(SUB128(filter[26])); + int8x8_t f36 = vdup_n_s8(SUB128(filter[27])); + + int8x8_t f40 = vdup_n_s8(SUB128(filter[28])); + int8x8_t f41 = vdup_n_s8(SUB128(filter[29])); + int8x8_t f42 = vdup_n_s8(SUB128(filter[30])); + int8x8_t f43 = vdup_n_s8(SUB128(filter[31])); + int8x8_t f44 = vdup_n_s8(SUB128(filter[32])); + int8x8_t f45 = vdup_n_s8(SUB128(filter[33])); + int8x8_t f46 = vdup_n_s8(SUB128(filter[34])); + + int8x8_t f50 = vdup_n_s8(SUB128(filter[35])); + int8x8_t f51 = vdup_n_s8(SUB128(filter[36])); + int8x8_t f52 = vdup_n_s8(SUB128(filter[37])); + int8x8_t f53 = vdup_n_s8(SUB128(filter[38])); + int8x8_t f54 = vdup_n_s8(SUB128(filter[39])); + int8x8_t f55 = vdup_n_s8(SUB128(filter[40])); + int8x8_t f56 = vdup_n_s8(SUB128(filter[41])); + + int8x8_t f60 = vdup_n_s8(SUB128(filter[42])); + int8x8_t f61 = vdup_n_s8(SUB128(filter[43])); + int8x8_t f62 = vdup_n_s8(SUB128(filter[44])); + int8x8_t f63 = vdup_n_s8(SUB128(filter[45])); + int8x8_t f64 = vdup_n_s8(SUB128(filter[46])); + int8x8_t f65 = vdup_n_s8(SUB128(filter[47])); + int8x8_t f66 = vdup_n_s8(SUB128(filter[48])); + + // get filter * src_zp for one IC + int32_t fxszp = 0; + for (size_t i = 0; i < 49; ++i) + fxszp += static_cast(filter[i]) - 128; + int32x4_t vfxszp = vdupq_n_s32(fxszp * static_cast(src_zp)); + + // block 2x8 + size_t oh = 0; + for (; oh + 1 < OH; oh += 2) { + size_t ih = oh; + for (size_t ow = 0; ow < OW; ow += 8) { + size_t iw = ow; + int32_t* __restrict tptr = temp + oh * OW + ow; + uint8_t* __restrict dptr = dst + oh * OW + ow; + const uint8_t* __restrict sptr = src + ih * IW + iw; + const int32_t* __restrict bptr = bias; + int32x4_t sum00, sum01, sum10, sum11; + + if (!first_ic) { + sum00 = vld1q_s32(tptr + 0 * OW); + sum01 = vld1q_s32(tptr + 0 * OW + 4); + sum10 = vld1q_s32(tptr + 1 * OW); + sum11 = vld1q_s32(tptr + 1 * OW + 4); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + sum00 = vdupq_n_s32(bptr[0]); + sum01 = sum00; + sum10 = sum00; + sum11 = sum00; + } else { + sum00 = vdupq_n_s32(0); + sum01 = vdupq_n_s32(0); + sum10 = vdupq_n_s32(0); + sum11 = vdupq_n_s32(0); + } + sum00 += vsrc_filter_zp; + sum01 += vsrc_filter_zp; + sum10 += vsrc_filter_zp; + sum11 += vsrc_filter_zp; + } + + int8x8_t _r00 = SUB128VECTOR(vld1_u8(sptr + 0 * IW)); + int8x8_t _r0n = SUB128VECTOR(vld1_u8(sptr + 0 * IW + 8)); + int8x8_t _r01 = vext_s8(_r00, _r0n, 1); + int8x8_t _r02 = vext_s8(_r00, _r0n, 2); + int8x8_t _r03 = vext_s8(_r00, _r0n, 3); + int8x8_t _r04 = vext_s8(_r00, _r0n, 4); + int8x8_t _r05 = vext_s8(_r00, _r0n, 5); + int8x8_t _r06 = vext_s8(_r00, _r0n, 6); + ACC_S16_S32(sum00, sum01, MLSFZP(_r00, f00)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r01, f01)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r02, f02)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r03, f03)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r04, f04)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r05, f05)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r06, f06)); + + int8x8_t _r10 = SUB128VECTOR(vld1_u8(sptr + 1 * IW)); + int8x8_t _r1n = SUB128VECTOR(vld1_u8(sptr + 1 * IW + 8)); + int8x8_t _r11 = vext_s8(_r10, _r1n, 1); + int8x8_t _r12 = vext_s8(_r10, _r1n, 2); + int8x8_t _r13 = vext_s8(_r10, _r1n, 
3); + int8x8_t _r14 = vext_s8(_r10, _r1n, 4); + int8x8_t _r15 = vext_s8(_r10, _r1n, 5); + int8x8_t _r16 = vext_s8(_r10, _r1n, 6); + ACC_S16_S32(sum00, sum01, MLSFZP(_r10, f10)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r11, f11)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r12, f12)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r13, f13)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r14, f14)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r15, f15)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r16, f16)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r10, f00)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r11, f01)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r12, f02)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r13, f03)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r14, f04)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r15, f05)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r16, f06)); + + int8x8_t _r20 = SUB128VECTOR(vld1_u8(sptr + 2 * IW)); + int8x8_t _r2n = SUB128VECTOR(vld1_u8(sptr + 2 * IW + 8)); + int8x8_t _r21 = vext_s8(_r20, _r2n, 1); + int8x8_t _r22 = vext_s8(_r20, _r2n, 2); + int8x8_t _r23 = vext_s8(_r20, _r2n, 3); + int8x8_t _r24 = vext_s8(_r20, _r2n, 4); + int8x8_t _r25 = vext_s8(_r20, _r2n, 5); + int8x8_t _r26 = vext_s8(_r20, _r2n, 6); + ACC_S16_S32(sum00, sum01, MLSFZP(_r20, f20)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r21, f21)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r22, f22)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r23, f23)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r24, f24)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r25, f25)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r26, f26)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r20, f10)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r21, f11)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r22, f12)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r23, f13)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r24, f14)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r25, f15)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r26, f16)); + + int8x8_t _r30 = SUB128VECTOR(vld1_u8(sptr + 3 * IW)); + int8x8_t _r3n = SUB128VECTOR(vld1_u8(sptr + 3 * IW + 8)); + int8x8_t _r31 = vext_s8(_r30, _r3n, 1); + int8x8_t _r32 = vext_s8(_r30, _r3n, 2); + int8x8_t _r33 = vext_s8(_r30, _r3n, 3); + int8x8_t _r34 = vext_s8(_r30, _r3n, 4); + int8x8_t _r35 = vext_s8(_r30, _r3n, 5); + int8x8_t _r36 = vext_s8(_r30, _r3n, 6); + ACC_S16_S32(sum00, sum01, MLSFZP(_r30, f30)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r31, f31)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r32, f32)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r33, f33)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r34, f34)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r35, f35)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r36, f36)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r30, f20)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r31, f21)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r32, f22)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r33, f23)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r34, f24)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r35, f25)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r36, f26)); + + int8x8_t _r40 = SUB128VECTOR(vld1_u8(sptr + 4 * IW)); + int8x8_t _r4n = SUB128VECTOR(vld1_u8(sptr + 4 * IW + 8)); + int8x8_t _r41 = vext_s8(_r40, _r4n, 1); + int8x8_t _r42 = vext_s8(_r40, _r4n, 2); + int8x8_t _r43 = vext_s8(_r40, _r4n, 3); + int8x8_t _r44 = vext_s8(_r40, _r4n, 4); + int8x8_t _r45 = vext_s8(_r40, _r4n, 5); + int8x8_t _r46 = vext_s8(_r40, _r4n, 6); + ACC_S16_S32(sum00, sum01, MLSFZP(_r40, f40)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r41, f41)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r42, f42)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r43, f43)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r44, f44)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r45, f45)); + 
ACC_S16_S32(sum00, sum01, MLSFZP(_r46, f46)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r40, f30)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r41, f31)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r42, f32)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r43, f33)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r44, f34)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r45, f35)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r46, f36)); + + int8x8_t _r50 = SUB128VECTOR(vld1_u8(sptr + 5 * IW)); + int8x8_t _r5n = SUB128VECTOR(vld1_u8(sptr + 5 * IW + 8)); + int8x8_t _r51 = vext_s8(_r50, _r5n, 1); + int8x8_t _r52 = vext_s8(_r50, _r5n, 2); + int8x8_t _r53 = vext_s8(_r50, _r5n, 3); + int8x8_t _r54 = vext_s8(_r50, _r5n, 4); + int8x8_t _r55 = vext_s8(_r50, _r5n, 5); + int8x8_t _r56 = vext_s8(_r50, _r5n, 6); + ACC_S16_S32(sum00, sum01, MLSFZP(_r50, f50)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r51, f51)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r52, f52)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r53, f53)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r54, f54)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r55, f55)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r56, f56)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r50, f40)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r51, f41)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r52, f42)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r53, f43)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r54, f44)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r55, f45)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r56, f46)); + + int8x8_t _r60 = SUB128VECTOR(vld1_u8(sptr + 6 * IW)); + int8x8_t _r6n = SUB128VECTOR(vld1_u8(sptr + 6 * IW + 8)); + int8x8_t _r61 = vext_s8(_r60, _r6n, 1); + int8x8_t _r62 = vext_s8(_r60, _r6n, 2); + int8x8_t _r63 = vext_s8(_r60, _r6n, 3); + int8x8_t _r64 = vext_s8(_r60, _r6n, 4); + int8x8_t _r65 = vext_s8(_r60, _r6n, 5); + int8x8_t _r66 = vext_s8(_r60, _r6n, 6); + ACC_S16_S32(sum00, sum01, MLSFZP(_r60, f60)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r61, f61)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r62, f62)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r63, f63)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r64, f64)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r65, f65)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r66, f66)); + sum00 = vsubq_s32(sum00, vfxszp); + sum01 = vsubq_s32(sum01, vfxszp); + POSTPROCESS(sum00, sum01, tptr + 0 * OW, dptr + 0 * OW); + ACC_S16_S32(sum10, sum11, MLSFZP(_r60, f50)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r61, f51)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r62, f52)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r63, f53)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r64, f54)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r65, f55)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r66, f56)); + + int8x8_t _r70 = SUB128VECTOR(vld1_u8(sptr + 7 * IW)); + int8x8_t _r7n = SUB128VECTOR(vld1_u8(sptr + 7 * IW + 8)); + int8x8_t _r71 = vext_s8(_r70, _r7n, 1); + int8x8_t _r72 = vext_s8(_r70, _r7n, 2); + int8x8_t _r73 = vext_s8(_r70, _r7n, 3); + int8x8_t _r74 = vext_s8(_r70, _r7n, 4); + int8x8_t _r75 = vext_s8(_r70, _r7n, 5); + int8x8_t _r76 = vext_s8(_r70, _r7n, 6); + ACC_S16_S32(sum10, sum11, MLSFZP(_r70, f60)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r71, f61)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r72, f62)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r73, f63)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r74, f64)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r75, f65)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r76, f66)); + sum10 = vsubq_s32(sum10, vfxszp); + sum11 = vsubq_s32(sum11, vfxszp); + POSTPROCESS(sum10, sum11, tptr + 1 * OW, dptr + 1 * OW); + } + } + + if (oh < OH) { + size_t ih = oh; + for (size_t ow = 0; ow < OW; ow += 8) { + size_t iw = ow; + int32_t* 
__restrict tptr = temp + oh * OW + ow; + uint8_t* __restrict dptr = dst + oh * OW + ow; + const uint8_t* __restrict sptr = src + ih * IW + iw; + const int32_t* __restrict bptr = bias; + int32x4_t sum00, sum01; + + if (!first_ic) { + sum00 = vld1q_s32(tptr + 0 * OW); + sum01 = vld1q_s32(tptr + 0 * OW + 4); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + sum00 = vdupq_n_s32(bptr[0]); + sum01 = sum00; + } else { + sum00 = vdupq_n_s32(0); + sum01 = vdupq_n_s32(0); + } + sum00 += vsrc_filter_zp; + sum01 += vsrc_filter_zp; + } + + int8x8_t _r00 = SUB128VECTOR(vld1_u8(sptr + 0 * IW)); + int8x8_t _r0n = SUB128VECTOR(vld1_u8(sptr + 0 * IW + 8)); + int8x8_t _r01 = vext_s8(_r00, _r0n, 1); + int8x8_t _r02 = vext_s8(_r00, _r0n, 2); + int8x8_t _r03 = vext_s8(_r00, _r0n, 3); + int8x8_t _r04 = vext_s8(_r00, _r0n, 4); + int8x8_t _r05 = vext_s8(_r00, _r0n, 5); + int8x8_t _r06 = vext_s8(_r00, _r0n, 6); + ACC_S16_S32(sum00, sum01, MLSFZP(_r00, f00)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r01, f01)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r02, f02)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r03, f03)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r04, f04)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r05, f05)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r06, f06)); + + int8x8_t _r10 = SUB128VECTOR(vld1_u8(sptr + 1 * IW)); + int8x8_t _r1n = SUB128VECTOR(vld1_u8(sptr + 1 * IW + 8)); + int8x8_t _r11 = vext_s8(_r10, _r1n, 1); + int8x8_t _r12 = vext_s8(_r10, _r1n, 2); + int8x8_t _r13 = vext_s8(_r10, _r1n, 3); + int8x8_t _r14 = vext_s8(_r10, _r1n, 4); + int8x8_t _r15 = vext_s8(_r10, _r1n, 5); + int8x8_t _r16 = vext_s8(_r10, _r1n, 6); + ACC_S16_S32(sum00, sum01, MLSFZP(_r10, f10)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r11, f11)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r12, f12)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r13, f13)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r14, f14)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r15, f15)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r16, f16)); + + int8x8_t _r20 = SUB128VECTOR(vld1_u8(sptr + 2 * IW)); + int8x8_t _r2n = SUB128VECTOR(vld1_u8(sptr + 2 * IW + 8)); + int8x8_t _r21 = vext_s8(_r20, _r2n, 1); + int8x8_t _r22 = vext_s8(_r20, _r2n, 2); + int8x8_t _r23 = vext_s8(_r20, _r2n, 3); + int8x8_t _r24 = vext_s8(_r20, _r2n, 4); + int8x8_t _r25 = vext_s8(_r20, _r2n, 5); + int8x8_t _r26 = vext_s8(_r20, _r2n, 6); + ACC_S16_S32(sum00, sum01, MLSFZP(_r20, f20)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r21, f21)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r22, f22)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r23, f23)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r24, f24)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r25, f25)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r26, f26)); + + int8x8_t _r30 = SUB128VECTOR(vld1_u8(sptr + 3 * IW)); + int8x8_t _r3n = SUB128VECTOR(vld1_u8(sptr + 3 * IW + 8)); + int8x8_t _r31 = vext_s8(_r30, _r3n, 1); + int8x8_t _r32 = vext_s8(_r30, _r3n, 2); + int8x8_t _r33 = vext_s8(_r30, _r3n, 3); + int8x8_t _r34 = vext_s8(_r30, _r3n, 4); + int8x8_t _r35 = vext_s8(_r30, _r3n, 5); + int8x8_t _r36 = vext_s8(_r30, _r3n, 6); + ACC_S16_S32(sum00, sum01, MLSFZP(_r30, f30)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r31, f31)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r32, f32)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r33, f33)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r34, f34)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r35, f35)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r36, f36)); + + int8x8_t _r40 = SUB128VECTOR(vld1_u8(sptr + 4 * IW)); + int8x8_t _r4n = SUB128VECTOR(vld1_u8(sptr + 4 * IW + 8)); + int8x8_t _r41 = vext_s8(_r40, _r4n, 1); + int8x8_t _r42 = 
vext_s8(_r40, _r4n, 2); + int8x8_t _r43 = vext_s8(_r40, _r4n, 3); + int8x8_t _r44 = vext_s8(_r40, _r4n, 4); + int8x8_t _r45 = vext_s8(_r40, _r4n, 5); + int8x8_t _r46 = vext_s8(_r40, _r4n, 6); + ACC_S16_S32(sum00, sum01, MLSFZP(_r40, f40)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r41, f41)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r42, f42)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r43, f43)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r44, f44)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r45, f45)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r46, f46)); + + int8x8_t _r50 = SUB128VECTOR(vld1_u8(sptr + 5 * IW)); + int8x8_t _r5n = SUB128VECTOR(vld1_u8(sptr + 5 * IW + 8)); + int8x8_t _r51 = vext_s8(_r50, _r5n, 1); + int8x8_t _r52 = vext_s8(_r50, _r5n, 2); + int8x8_t _r53 = vext_s8(_r50, _r5n, 3); + int8x8_t _r54 = vext_s8(_r50, _r5n, 4); + int8x8_t _r55 = vext_s8(_r50, _r5n, 5); + int8x8_t _r56 = vext_s8(_r50, _r5n, 6); + ACC_S16_S32(sum00, sum01, MLSFZP(_r50, f50)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r51, f51)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r52, f52)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r53, f53)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r54, f54)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r55, f55)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r56, f56)); + + int8x8_t _r60 = SUB128VECTOR(vld1_u8(sptr + 6 * IW)); + int8x8_t _r6n = SUB128VECTOR(vld1_u8(sptr + 6 * IW + 8)); + int8x8_t _r61 = vext_s8(_r60, _r6n, 1); + int8x8_t _r62 = vext_s8(_r60, _r6n, 2); + int8x8_t _r63 = vext_s8(_r60, _r6n, 3); + int8x8_t _r64 = vext_s8(_r60, _r6n, 4); + int8x8_t _r65 = vext_s8(_r60, _r6n, 5); + int8x8_t _r66 = vext_s8(_r60, _r6n, 6); + ACC_S16_S32(sum00, sum01, MLSFZP(_r60, f60)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r61, f61)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r62, f62)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r63, f63)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r64, f64)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r65, f65)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r66, f66)); + + sum00 = vsubq_s32(sum00, vfxszp); + sum01 = vsubq_s32(sum01, vfxszp); + POSTPROCESS(sum00, sum01, tptr + 0 * OW, dptr + 0 * OW); + } + } + } + MIDOUT_END(); +} + +template +void conv_bias::conv_direct_stride2_2x2_quint8( + const uint8_t* src, const uint8_t* filter, const int32_t* bias, + int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, + const size_t OH, const size_t OW, const int8_t src_zp, + const int8_t filter_zp, const int32_t src_filter_zp, const Op& op) { + MEGDNN_MARK_USED_VAR(IH); +#define GET_R2(sptr) \ + _r00 = SUB128VECTOR(vld1_u8(sptr)); \ + _r00 = vtbl1_s8(_r00, _idx); \ + _r01 = SUB128VECTOR(vld1_u8(sptr + 8)); \ + _r01 = vtbl1_s8(_r01, _idx); \ + _rn = vzip_s32(vreinterpret_s32_s8(_r00), vreinterpret_s32_s8(_r01)); \ + _r00 = vreinterpret_s8_s32(_rn.val[0]); \ + _r01 = vreinterpret_s8_s32(_rn.val[1]); + + MIDOUT_BEGIN(conv_direct_stride, 0, 4) { + int16x8_t v128 = vdupq_n_s16(128); + int32x4_t vsrc_filter_zp = vdupq_n_s32(src_filter_zp); + int8x8_t fzp = vdup_n_s8(filter_zp); + + int8x8_t f00 = vdup_n_s8(SUB128(filter[0])); + int8x8_t f01 = vdup_n_s8(SUB128(filter[1])); + int8x8_t f10 = vdup_n_s8(SUB128(filter[2])); + int8x8_t f11 = vdup_n_s8(SUB128(filter[3])); + + // get filter * src_zp for one IC + int32_t fxszp = 0; + for (size_t i = 0; i < 4; ++i) + fxszp += static_cast(filter[i]) - 128; + int32x4_t vfxszp = vdupq_n_s32(fxszp * static_cast(src_zp)); + + int8x8_t _idx = {0, 2, 4, 6, 1, 3, 5, 7}; + size_t oh = 0; + for (; oh < OH; ++oh) { + size_t ih = oh * 2; + size_t ow = 0; + for (; ow < OW; ow += 8) { + size_t iw = ow * 2; + int32_t* __restrict 
tptr = temp + oh * OW + ow; + uint8_t* __restrict dptr = dst + oh * OW + ow; + const uint8_t* __restrict sptr = src + ih * IW + iw; + const int32_t* __restrict bptr = bias; + int32x4_t sum00, sum01; + int32x2x2_t _rn; + int8x8_t _r00, _r01; + + if (!first_ic) { + sum00 = vld1q_s32(tptr + 0 * OW); + sum01 = vld1q_s32(tptr + 0 * OW + 4); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + sum00 = vdupq_n_s32(bptr[0]); + sum01 = sum00; + } else { + sum00 = vdupq_n_s32(0); + sum01 = vdupq_n_s32(0); + } + sum00 += vsrc_filter_zp; + sum01 += vsrc_filter_zp; + } + + GET_R2(sptr); + ACC_S16_S32(sum00, sum01, MLSFZP(_r00, f00)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r01, f01)); + + GET_R2(sptr + IW); + ACC_S16_S32(sum00, sum01, MLSFZP(_r00, f10)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r01, f11)); + + sum00 = vsubq_s32(sum00, vfxszp); + sum01 = vsubq_s32(sum01, vfxszp); + POSTPROCESS(sum00, sum01, tptr + 0 * OW, dptr + 0 * OW); + } + } + } + MIDOUT_END(); + +#undef GET_R2 +} + +template +void conv_bias::conv_direct_stride2_3x3_quint8( + const uint8_t* src, const uint8_t* filter, const int32_t* bias, + int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, + const size_t OH, const size_t OW, const int8_t src_zp, + const int8_t filter_zp, const int32_t src_filter_zp, const Op& op) { + MEGDNN_MARK_USED_VAR(IH); +#define GET_R3(sptr) \ + _r00 = SUB128VECTOR(vld1_u8(sptr)); \ + _r00 = vtbl1_s8(_r00, _idx); \ + _r01 = SUB128VECTOR(vld1_u8(sptr + 8)); \ + _r01 = vtbl1_s8(_r01, _idx); \ + _rn = vzip_s32(vreinterpret_s32_s8(_r00), vreinterpret_s32_s8(_r01)); \ + _r00 = vreinterpret_s8_s32(_rn.val[0]); \ + _r01 = vreinterpret_s8_s32(_rn.val[1]); \ + _r02 = SUB128VECTOR(vld1_u8(sptr + 16)); \ + _r02 = vext_s8(_r00, _r02, 1); + + MIDOUT_BEGIN(conv_direct_stride, 0, 5) { + int16x8_t v128 = vdupq_n_s16(128); + int32x4_t vsrc_filter_zp = vdupq_n_s32(src_filter_zp); + int8x8_t fzp = vdup_n_s8(filter_zp); + + int8x8_t f00 = vdup_n_s8(SUB128(filter[0])); + int8x8_t f01 = vdup_n_s8(SUB128(filter[1])); + int8x8_t f02 = vdup_n_s8(SUB128(filter[2])); + int8x8_t f10 = vdup_n_s8(SUB128(filter[3])); + int8x8_t f11 = vdup_n_s8(SUB128(filter[4])); + int8x8_t f12 = vdup_n_s8(SUB128(filter[5])); + int8x8_t f20 = vdup_n_s8(SUB128(filter[6])); + int8x8_t f21 = vdup_n_s8(SUB128(filter[7])); + int8x8_t f22 = vdup_n_s8(SUB128(filter[8])); + + // get filter * src_zp for one IC + int32_t fxszp = 0; + for (size_t i = 0; i < 9; ++i) + fxszp += static_cast(filter[i]) - 128; + int32x4_t vfxszp = vdupq_n_s32(fxszp * static_cast(src_zp)); + + int8x8_t _idx = {0, 2, 4, 6, 1, 3, 5, 7}; + + // 2x8 block + size_t oh = 0; + for (; oh + 1 < OH; oh += 2) { + size_t ih = oh * 2; + size_t ow = 0; + for (; ow < OW; ow += 8) { + size_t iw = ow * 2; + int32_t* __restrict tptr = temp + oh * OW + ow; + uint8_t* __restrict dptr = dst + oh * OW + ow; + const uint8_t* __restrict sptr = src + ih * IW + iw; + const int32_t* __restrict bptr = bias; + int32x4_t sum00, sum01, sum10, sum11; + int32x2x2_t _rn; + int8x8_t _r00, _r01, _r02; + + if (!first_ic) { + sum00 = vld1q_s32(tptr + 0 * OW); + sum01 = vld1q_s32(tptr + 0 * OW + 4); + sum10 = vld1q_s32(tptr + 1 * OW); + sum11 = vld1q_s32(tptr + 1 * OW + 4); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + sum00 = vdupq_n_s32(bptr[0]); + sum01 = sum00; + sum10 = sum00; + sum11 = sum00; + } else { + sum00 = vdupq_n_s32(0); + sum01 = vdupq_n_s32(0); + sum10 = vdupq_n_s32(0); + sum11 = vdupq_n_s32(0); + } + sum00 += vsrc_filter_zp; + sum01 += vsrc_filter_zp; + sum10 += 
vsrc_filter_zp; + sum11 += vsrc_filter_zp; + } + GET_R3(sptr); + ACC_S16_S32(sum00, sum01, MLSFZP(_r00, f00)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r01, f01)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r02, f02)); + + GET_R3(sptr + IW); + ACC_S16_S32(sum00, sum01, MLSFZP(_r00, f10)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r01, f11)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r02, f12)); + + GET_R3(sptr + 2 * IW); + ACC_S16_S32(sum00, sum01, MLSFZP(_r00, f20)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r01, f21)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r02, f22)); + sum00 = vsubq_s32(sum00, vfxszp); + sum01 = vsubq_s32(sum01, vfxszp); + POSTPROCESS(sum00, sum01, tptr + 0 * OW, dptr + 0 * OW); + ACC_S16_S32(sum10, sum11, MLSFZP(_r00, f00)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r01, f01)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r02, f02)); + + GET_R3(sptr + 3 * IW); + ACC_S16_S32(sum10, sum11, MLSFZP(_r00, f10)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r01, f11)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r02, f12)); + + GET_R3(sptr + 4 * IW); + ACC_S16_S32(sum10, sum11, MLSFZP(_r00, f20)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r01, f21)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r02, f22)); + sum10 = vsubq_s32(sum10, vfxszp); + sum11 = vsubq_s32(sum11, vfxszp); + POSTPROCESS(sum10, sum11, tptr + 1 * OW, dptr + 1 * OW); + } + } + if (oh < OH) { + size_t ih = oh * 2; + size_t ow = 0; + for (; ow < OW; ow += 8) { + size_t iw = ow * 2; + int32_t* __restrict tptr = temp + oh * OW + ow; + uint8_t* __restrict dptr = dst + oh * OW + ow; + const uint8_t* __restrict sptr = src + ih * IW + iw; + const int32_t* __restrict bptr = bias; + int32x4_t sum00, sum01; + int32x2x2_t _rn; + int8x8_t _r00, _r01, _r02; + + if (!first_ic) { + sum00 = vld1q_s32(tptr + 0 * OW); + sum01 = vld1q_s32(tptr + 0 * OW + 4); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + sum00 = vdupq_n_s32(bptr[0]); + sum01 = sum00; + } else { + sum00 = vdupq_n_s32(0); + sum01 = vdupq_n_s32(0); + } + sum00 += vsrc_filter_zp; + sum01 += vsrc_filter_zp; + } + + GET_R3(sptr); + ACC_S16_S32(sum00, sum01, MLSFZP(_r00, f00)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r01, f01)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r02, f02)); + + GET_R3(sptr + IW); + ACC_S16_S32(sum00, sum01, MLSFZP(_r00, f10)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r01, f11)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r02, f12)); + + GET_R3(sptr + 2 * IW); + ACC_S16_S32(sum00, sum01, MLSFZP(_r00, f20)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r01, f21)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r02, f22)); + sum00 = vsubq_s32(sum00, vfxszp); + sum01 = vsubq_s32(sum01, vfxszp); + POSTPROCESS(sum00, sum01, tptr + 0 * OW, dptr + 0 * OW); + } + } + } + MIDOUT_END(); +#undef GET_R3 +} + +template +void conv_bias::conv_direct_stride2_5x5_quint8( + const uint8_t* src, const uint8_t* filter, const int32_t* bias, + int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, + const size_t OH, const size_t OW, const int8_t src_zp, + const int8_t filter_zp, const int32_t src_filter_zp, const Op& op) { + MEGDNN_MARK_USED_VAR(IH); +#define GET_R5(sptr) \ + _r00 = SUB128VECTOR(vld1_u8(sptr)); \ + _r00 = vtbl1_s8(_r00, _idx); \ + _r01 = SUB128VECTOR(vld1_u8(sptr + 8)); \ + _r01 = vtbl1_s8(_r01, _idx); \ + _rn = vzip_s32(vreinterpret_s32_s8(_r00), vreinterpret_s32_s8(_r01)); \ + _r00 = vreinterpret_s8_s32(_rn.val[0]); \ + _r01 = vreinterpret_s8_s32(_rn.val[1]); \ + _r03 = SUB128VECTOR(vld1_u8(sptr + 16)); \ + _r03 = vtbl1_s8(_r03, _idx); \ + _r02 = vext_s8(_r00, _r03, 1); \ + _r04 = vext_s8(_r00, _r03, 2); \ + _r03 = vtbl1_s8(_r03, _idxn); 
\ + _r03 = vext_s8(_r01, _r03, 1); + + MIDOUT_BEGIN(conv_direct_stride, 0, 6) { + int16x8_t v128 = vdupq_n_s16(128); + int32x4_t vsrc_filter_zp = vdupq_n_s32(src_filter_zp); + int8x8_t fzp = vdup_n_s8(filter_zp); + + int8x8_t f00 = vdup_n_s8(SUB128(filter[0])); + int8x8_t f01 = vdup_n_s8(SUB128(filter[1])); + int8x8_t f02 = vdup_n_s8(SUB128(filter[2])); + int8x8_t f03 = vdup_n_s8(SUB128(filter[3])); + int8x8_t f04 = vdup_n_s8(SUB128(filter[4])); + int8x8_t f10 = vdup_n_s8(SUB128(filter[5])); + int8x8_t f11 = vdup_n_s8(SUB128(filter[6])); + int8x8_t f12 = vdup_n_s8(SUB128(filter[7])); + int8x8_t f13 = vdup_n_s8(SUB128(filter[8])); + int8x8_t f14 = vdup_n_s8(SUB128(filter[9])); + int8x8_t f20 = vdup_n_s8(SUB128(filter[10])); + int8x8_t f21 = vdup_n_s8(SUB128(filter[11])); + int8x8_t f22 = vdup_n_s8(SUB128(filter[12])); + int8x8_t f23 = vdup_n_s8(SUB128(filter[13])); + int8x8_t f24 = vdup_n_s8(SUB128(filter[14])); + int8x8_t f30 = vdup_n_s8(SUB128(filter[15])); + int8x8_t f31 = vdup_n_s8(SUB128(filter[16])); + int8x8_t f32 = vdup_n_s8(SUB128(filter[17])); + int8x8_t f33 = vdup_n_s8(SUB128(filter[18])); + int8x8_t f34 = vdup_n_s8(SUB128(filter[19])); + int8x8_t f40 = vdup_n_s8(SUB128(filter[20])); + int8x8_t f41 = vdup_n_s8(SUB128(filter[21])); + int8x8_t f42 = vdup_n_s8(SUB128(filter[22])); + int8x8_t f43 = vdup_n_s8(SUB128(filter[23])); + int8x8_t f44 = vdup_n_s8(SUB128(filter[24])); + + // get filter * src_zp for one IC + int32_t fxszp = 0; + for (size_t i = 0; i < 25; ++i) + fxszp += static_cast(filter[i]) - 128; + int32x4_t vfxszp = vdupq_n_s32(fxszp * static_cast(src_zp)); + + int8x8_t _idx = {0, 2, 4, 6, 1, 3, 5, 7}; + int8x8_t _idxn = {4, 5, 6, 7, 0, 1, 2, 3}; + // 2x8 block + size_t oh = 0; + for (; oh + 1 < OH; oh += 2) { + size_t ih = oh * 2; + size_t ow = 0; + for (; ow < OW; ow += 8) { + size_t iw = ow * 2; + int32_t* __restrict tptr = temp + oh * OW + ow; + uint8_t* __restrict dptr = dst + oh * OW + ow; + const uint8_t* __restrict sptr = src + ih * IW + iw; + const int32_t* __restrict bptr = bias; + int32x4_t sum00, sum01, sum10, sum11; + int32x2x2_t _rn; + int8x8_t _r00, _r01, _r02, _r03, _r04; + + if (!first_ic) { + sum00 = vld1q_s32(tptr + 0 * OW); + sum01 = vld1q_s32(tptr + 0 * OW + 4); + sum10 = vld1q_s32(tptr + 1 * OW); + sum11 = vld1q_s32(tptr + 1 * OW + 4); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + sum00 = vdupq_n_s32(bptr[0]); + sum01 = sum00; + sum10 = sum00; + sum11 = sum00; + } else { + sum00 = vdupq_n_s32(0); + sum01 = vdupq_n_s32(0); + sum10 = vdupq_n_s32(0); + sum11 = vdupq_n_s32(0); + } + sum00 += vsrc_filter_zp; + sum01 += vsrc_filter_zp; + sum10 += vsrc_filter_zp; + sum11 += vsrc_filter_zp; + } + + GET_R5(sptr); + ACC_S16_S32(sum00, sum01, MLSFZP(_r00, f00)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r01, f01)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r02, f02)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r03, f03)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r04, f04)); + + GET_R5(sptr + IW); + ACC_S16_S32(sum00, sum01, MLSFZP(_r00, f10)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r01, f11)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r02, f12)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r03, f13)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r04, f14)); + + GET_R5(sptr + 2 * IW); + ACC_S16_S32(sum00, sum01, MLSFZP(_r00, f20)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r01, f21)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r02, f22)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r03, f23)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r04, f24)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r00, f00)); + ACC_S16_S32(sum10, 
sum11, MLSFZP(_r01, f01)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r02, f02)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r03, f03)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r04, f04)); + + GET_R5(sptr + 3 * IW); + ACC_S16_S32(sum00, sum01, MLSFZP(_r00, f30)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r01, f31)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r02, f32)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r03, f33)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r04, f34)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r00, f10)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r01, f11)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r02, f12)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r03, f13)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r04, f14)); + + GET_R5(sptr + 4 * IW); + ACC_S16_S32(sum00, sum01, MLSFZP(_r00, f40)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r01, f41)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r02, f42)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r03, f43)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r04, f44)); + sum00 = vsubq_s32(sum00, vfxszp); + sum01 = vsubq_s32(sum01, vfxszp); + POSTPROCESS(sum00, sum01, tptr + 0 * OW, dptr + 0 * OW); + ACC_S16_S32(sum10, sum11, MLSFZP(_r00, f20)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r01, f21)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r02, f22)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r03, f23)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r04, f24)); + + GET_R5(sptr + 5 * IW); + ACC_S16_S32(sum10, sum11, MLSFZP(_r00, f30)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r01, f31)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r02, f32)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r03, f33)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r04, f34)); + + GET_R5(sptr + 6 * IW); + ACC_S16_S32(sum10, sum11, MLSFZP(_r00, f40)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r01, f41)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r02, f42)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r03, f43)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r04, f44)); + sum10 = vsubq_s32(sum10, vfxszp); + sum11 = vsubq_s32(sum11, vfxszp); + POSTPROCESS(sum10, sum11, tptr + 1 * OW, dptr + 1 * OW); + } + } + if (oh < OH) { + size_t ih = oh * 2; + size_t ow = 0; + for (; ow < OW; ow += 8) { + size_t iw = ow * 2; + int32_t* __restrict tptr = temp + oh * OW + ow; + uint8_t* __restrict dptr = dst + oh * OW + ow; + const uint8_t* __restrict sptr = src + ih * IW + iw; + const int32_t* __restrict bptr = bias; + int32x4_t sum00, sum01; + int32x2x2_t _rn; + int8x8_t _r00, _r01, _r02, _r03, _r04; + + if (!first_ic) { + sum00 = vld1q_s32(tptr + 0 * OW); + sum01 = vld1q_s32(tptr + 0 * OW + 4); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + sum00 = vdupq_n_s32(bptr[0]); + sum01 = sum00; + } else { + sum00 = vdupq_n_s32(0); + sum01 = vdupq_n_s32(0); + } + sum00 += vsrc_filter_zp; + sum01 += vsrc_filter_zp; + } + + GET_R5(sptr); + ACC_S16_S32(sum00, sum01, MLSFZP(_r00, f00)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r01, f01)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r02, f02)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r03, f03)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r04, f04)); + + GET_R5(sptr + IW); + ACC_S16_S32(sum00, sum01, MLSFZP(_r00, f10)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r01, f11)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r02, f12)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r03, f13)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r04, f14)); + + GET_R5(sptr + 2 * IW); + ACC_S16_S32(sum00, sum01, MLSFZP(_r00, f20)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r01, f21)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r02, f22)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r03, f23)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r04, f24)); + + GET_R5(sptr + 3 * IW); + ACC_S16_S32(sum00, 
sum01, MLSFZP(_r00, f30)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r01, f31)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r02, f32)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r03, f33)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r04, f34)); + + GET_R5(sptr + 4 * IW); + ACC_S16_S32(sum00, sum01, MLSFZP(_r00, f40)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r01, f41)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r02, f42)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r03, f43)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r04, f44)); + sum00 = vsubq_s32(sum00, vfxszp); + sum01 = vsubq_s32(sum01, vfxszp); + POSTPROCESS(sum00, sum01, tptr + 0 * OW, dptr + 0 * OW); + } + } + } + MIDOUT_END(); +#undef GET_R5 +} + +template +void conv_bias::conv_direct_stride2_7x7_quint8( + const uint8_t* src, const uint8_t* filter, const int32_t* bias, + int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, + const size_t OH, const size_t OW, const int8_t src_zp, + const int8_t filter_zp, const int32_t src_filter_zp, const Op& op) { + MEGDNN_MARK_USED_VAR(IH); +#define GET_R7(sptr) \ + _r00 = SUB128VECTOR(vld1_u8(sptr)); \ + _r00 = vtbl1_s8(_r00, _idx); \ + _r01 = SUB128VECTOR(vld1_u8(sptr + 8)); \ + _r01 = vtbl1_s8(_r01, _idx); \ + _rn = vzip_s32(vreinterpret_s32_s8(_r00), vreinterpret_s32_s8(_r01)); \ + _r00 = vreinterpret_s8_s32(_rn.val[0]); \ + _r01 = vreinterpret_s8_s32(_rn.val[1]); \ + _r05 = SUB128VECTOR(vld1_u8(sptr + 16)); \ + _r05 = vtbl1_s8(_r05, _idx); \ + _r02 = vext_s8(_r00, _r05, 1); \ + _r04 = vext_s8(_r00, _r05, 2); \ + _r06 = vext_s8(_r00, _r05, 3); \ + _r05 = vtbl1_s8(_r05, _idxn); \ + _r03 = vext_s8(_r01, _r05, 1); \ + _r05 = vext_s8(_r01, _r05, 2); + + MIDOUT_BEGIN(conv_direct_stride, 0, 7) { + int16x8_t v128 = vdupq_n_s16(128); + int32x4_t vsrc_filter_zp = vdupq_n_s32(src_filter_zp); + int8x8_t fzp = vdup_n_s8(filter_zp); + + int8x8_t f00 = vdup_n_s8(SUB128(filter[0])); + int8x8_t f01 = vdup_n_s8(SUB128(filter[1])); + int8x8_t f02 = vdup_n_s8(SUB128(filter[2])); + int8x8_t f03 = vdup_n_s8(SUB128(filter[3])); + int8x8_t f04 = vdup_n_s8(SUB128(filter[4])); + int8x8_t f05 = vdup_n_s8(SUB128(filter[5])); + int8x8_t f06 = vdup_n_s8(SUB128(filter[6])); + + int8x8_t f10 = vdup_n_s8(SUB128(filter[7])); + int8x8_t f11 = vdup_n_s8(SUB128(filter[8])); + int8x8_t f12 = vdup_n_s8(SUB128(filter[9])); + int8x8_t f13 = vdup_n_s8(SUB128(filter[10])); + int8x8_t f14 = vdup_n_s8(SUB128(filter[11])); + int8x8_t f15 = vdup_n_s8(SUB128(filter[12])); + int8x8_t f16 = vdup_n_s8(SUB128(filter[13])); + + int8x8_t f20 = vdup_n_s8(SUB128(filter[14])); + int8x8_t f21 = vdup_n_s8(SUB128(filter[15])); + int8x8_t f22 = vdup_n_s8(SUB128(filter[16])); + int8x8_t f23 = vdup_n_s8(SUB128(filter[17])); + int8x8_t f24 = vdup_n_s8(SUB128(filter[18])); + int8x8_t f25 = vdup_n_s8(SUB128(filter[19])); + int8x8_t f26 = vdup_n_s8(SUB128(filter[20])); + + int8x8_t f30 = vdup_n_s8(SUB128(filter[21])); + int8x8_t f31 = vdup_n_s8(SUB128(filter[22])); + int8x8_t f32 = vdup_n_s8(SUB128(filter[23])); + int8x8_t f33 = vdup_n_s8(SUB128(filter[24])); + int8x8_t f34 = vdup_n_s8(SUB128(filter[25])); + int8x8_t f35 = vdup_n_s8(SUB128(filter[26])); + int8x8_t f36 = vdup_n_s8(SUB128(filter[27])); + + int8x8_t f40 = vdup_n_s8(SUB128(filter[28])); + int8x8_t f41 = vdup_n_s8(SUB128(filter[29])); + int8x8_t f42 = vdup_n_s8(SUB128(filter[30])); + int8x8_t f43 = vdup_n_s8(SUB128(filter[31])); + int8x8_t f44 = vdup_n_s8(SUB128(filter[32])); + int8x8_t f45 = vdup_n_s8(SUB128(filter[33])); + int8x8_t f46 = vdup_n_s8(SUB128(filter[34])); + + int8x8_t f50 = vdup_n_s8(SUB128(filter[35])); + int8x8_t f51 
= vdup_n_s8(SUB128(filter[36])); + int8x8_t f52 = vdup_n_s8(SUB128(filter[37])); + int8x8_t f53 = vdup_n_s8(SUB128(filter[38])); + int8x8_t f54 = vdup_n_s8(SUB128(filter[39])); + int8x8_t f55 = vdup_n_s8(SUB128(filter[40])); + int8x8_t f56 = vdup_n_s8(SUB128(filter[41])); + + int8x8_t f60 = vdup_n_s8(SUB128(filter[42])); + int8x8_t f61 = vdup_n_s8(SUB128(filter[43])); + int8x8_t f62 = vdup_n_s8(SUB128(filter[44])); + int8x8_t f63 = vdup_n_s8(SUB128(filter[45])); + int8x8_t f64 = vdup_n_s8(SUB128(filter[46])); + int8x8_t f65 = vdup_n_s8(SUB128(filter[47])); + int8x8_t f66 = vdup_n_s8(SUB128(filter[48])); + + // get filter * src_zp for one IC + int32_t fxszp = 0; + for (size_t i = 0; i < 49; ++i) + fxszp += static_cast(filter[i]) - 128; + int32x4_t vfxszp = vdupq_n_s32(fxszp * static_cast(src_zp)); + + int8x8_t _idx = {0, 2, 4, 6, 1, 3, 5, 7}; + int8x8_t _idxn = {4, 5, 6, 7, 0, 1, 2, 3}; + + // 2x8 block + size_t oh = 0; + for (; oh + 1 < OH; oh += 2) { + size_t ih = oh * 2; + size_t ow = 0; + for (; ow < OW; ow += 8) { + size_t iw = ow * 2; + int32_t* __restrict tptr = temp + oh * OW + ow; + uint8_t* __restrict dptr = dst + oh * OW + ow; + const uint8_t* __restrict sptr = src + ih * IW + iw; + const int32_t* __restrict bptr = bias; + int32x4_t sum00, sum01, sum10, sum11; + int32x2x2_t _rn; + int8x8_t _r00, _r01, _r02, _r03, _r04, _r05, _r06; + + if (!first_ic) { + sum00 = vld1q_s32(tptr + 0 * OW); + sum01 = vld1q_s32(tptr + 0 * OW + 4); + sum10 = vld1q_s32(tptr + 1 * OW); + sum11 = vld1q_s32(tptr + 1 * OW + 4); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + sum00 = vdupq_n_s32(bptr[0]); + sum01 = sum00; + sum10 = sum00; + sum11 = sum00; + } else { + sum00 = vdupq_n_s32(0); + sum01 = vdupq_n_s32(0); + sum10 = vdupq_n_s32(0); + sum11 = vdupq_n_s32(0); + } + sum00 += vsrc_filter_zp; + sum01 += vsrc_filter_zp; + sum10 += vsrc_filter_zp; + sum11 += vsrc_filter_zp; + } + + GET_R7(sptr); + ACC_S16_S32(sum00, sum01, MLSFZP(_r00, f00)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r01, f01)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r02, f02)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r03, f03)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r04, f04)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r05, f05)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r06, f06)); + + GET_R7(sptr + IW); + ACC_S16_S32(sum00, sum01, MLSFZP(_r00, f10)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r01, f11)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r02, f12)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r03, f13)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r04, f14)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r05, f15)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r06, f16)); + + GET_R7(sptr + 2 * IW); + ACC_S16_S32(sum00, sum01, MLSFZP(_r00, f20)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r01, f21)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r02, f22)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r03, f23)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r04, f24)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r05, f25)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r06, f26)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r00, f00)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r01, f01)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r02, f02)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r03, f03)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r04, f04)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r05, f05)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r06, f06)); + + GET_R7(sptr + 3 * IW); + ACC_S16_S32(sum00, sum01, MLSFZP(_r00, f30)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r01, f31)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r02, f32)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r03, f33)); + 
ACC_S16_S32(sum00, sum01, MLSFZP(_r04, f34)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r05, f35)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r06, f36)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r00, f10)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r01, f11)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r02, f12)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r03, f13)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r04, f14)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r05, f15)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r06, f16)); + + GET_R7(sptr + 4 * IW); + ACC_S16_S32(sum00, sum01, MLSFZP(_r00, f40)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r01, f41)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r02, f42)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r03, f43)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r04, f44)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r05, f45)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r06, f46)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r00, f20)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r01, f21)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r02, f22)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r03, f23)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r04, f24)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r05, f25)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r06, f26)); + + GET_R7(sptr + 5 * IW); + ACC_S16_S32(sum00, sum01, MLSFZP(_r00, f50)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r01, f51)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r02, f52)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r03, f53)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r04, f54)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r05, f55)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r06, f56)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r00, f30)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r01, f31)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r02, f32)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r03, f33)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r04, f34)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r05, f35)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r06, f36)); + + GET_R7(sptr + 6 * IW); + ACC_S16_S32(sum00, sum01, MLSFZP(_r00, f60)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r01, f61)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r02, f62)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r03, f63)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r04, f64)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r05, f65)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r06, f66)); + sum00 = vsubq_s32(sum00, vfxszp); + sum01 = vsubq_s32(sum01, vfxszp); + POSTPROCESS(sum00, sum01, tptr + 0 * OW, dptr + 0 * OW); + ACC_S16_S32(sum10, sum11, MLSFZP(_r00, f40)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r01, f41)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r02, f42)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r03, f43)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r04, f44)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r05, f45)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r06, f46)); + + GET_R7(sptr + 7 * IW); + ACC_S16_S32(sum10, sum11, MLSFZP(_r00, f50)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r01, f51)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r02, f52)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r03, f53)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r04, f54)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r05, f55)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r06, f56)); + + GET_R7(sptr + 8 * IW); + ACC_S16_S32(sum10, sum11, MLSFZP(_r00, f60)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r01, f61)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r02, f62)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r03, f63)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r04, f64)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r05, f65)); + ACC_S16_S32(sum10, sum11, MLSFZP(_r06, f66)); + sum10 = vsubq_s32(sum10, vfxszp); + sum11 = vsubq_s32(sum11, vfxszp); + POSTPROCESS(sum10, sum11, tptr + 1 * OW, dptr + 
1 * OW); + } + } + if (oh < OH) { + size_t ih = oh * 2; + size_t ow = 0; + for (; ow < OW; ow += 8) { + size_t iw = ow * 2; + int32_t* __restrict tptr = temp + oh * OW + ow; + uint8_t* __restrict dptr = dst + oh * OW + ow; + const uint8_t* __restrict sptr = src + ih * IW + iw; + const int32_t* __restrict bptr = bias; + int32x4_t sum00, sum01; + int32x2x2_t _rn; + int8x8_t _r00, _r01, _r02, _r03, _r04, _r05, _r06; + + if (!first_ic) { + sum00 = vld1q_s32(tptr + 0 * OW); + sum01 = vld1q_s32(tptr + 0 * OW + 4); + } else { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + sum00 = vdupq_n_s32(bptr[0]); + sum01 = sum00; + } else { + sum00 = vdupq_n_s32(0); + sum01 = vdupq_n_s32(0); + } + sum00 += vsrc_filter_zp; + sum01 += vsrc_filter_zp; + } + + GET_R7(sptr); + ACC_S16_S32(sum00, sum01, MLSFZP(_r00, f00)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r01, f01)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r02, f02)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r03, f03)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r04, f04)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r05, f05)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r06, f06)); + + GET_R7(sptr + IW); + ACC_S16_S32(sum00, sum01, MLSFZP(_r00, f10)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r01, f11)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r02, f12)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r03, f13)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r04, f14)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r05, f15)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r06, f16)); + + GET_R7(sptr + 2 * IW); + ACC_S16_S32(sum00, sum01, MLSFZP(_r00, f20)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r01, f21)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r02, f22)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r03, f23)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r04, f24)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r05, f25)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r06, f26)); + + GET_R7(sptr + 3 * IW); + ACC_S16_S32(sum00, sum01, MLSFZP(_r00, f30)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r01, f31)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r02, f32)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r03, f33)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r04, f34)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r05, f35)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r06, f36)); + + GET_R7(sptr + 4 * IW); + ACC_S16_S32(sum00, sum01, MLSFZP(_r00, f40)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r01, f41)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r02, f42)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r03, f43)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r04, f44)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r05, f45)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r06, f46)); + + GET_R7(sptr + 5 * IW); + ACC_S16_S32(sum00, sum01, MLSFZP(_r00, f50)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r01, f51)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r02, f52)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r03, f53)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r04, f54)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r05, f55)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r06, f56)); + + GET_R7(sptr + 6 * IW); + ACC_S16_S32(sum00, sum01, MLSFZP(_r00, f60)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r01, f61)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r02, f62)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r03, f63)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r04, f64)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r05, f65)); + ACC_S16_S32(sum00, sum01, MLSFZP(_r06, f66)); + sum00 = vsubq_s32(sum00, vfxszp); + sum01 = vsubq_s32(sum01, vfxszp); + POSTPROCESS(sum00, sum01, tptr + 0 * OW, dptr + 0 * OW); + } + } + } + MIDOUT_END(); +#undef GET_R7 +} + +#undef MLSFZP +#undef SUB128 +#undef SUB128VECTOR +#undef POSTPROCESS +#undef ACC_S16_S32 + 
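// Background note (illustration only; not part of the patch): every kernel in
// this file applies the same zero-point algebra. With s' = s - 128 and
// f' = f - 128, and assuming the int8_t src_zp / filter_zp parameters are the
// zero points already shifted by 128 and src_filter_zp == K * src_zp * filter_zp
// for the K taps of one input channel, a K-tap accumulation
//     sum_i (s_i' - src_zp) * (f_i' - filter_zp)
// is evaluated as
//     sum_i (s_i' * f_i' - filter_zp * s_i')   (ACC_S16_S32 + MLSFZP)
//   - src_zp * sum_i f_i'                      (the vfxszp correction)
//   + K * src_zp * filter_zp                   (the vsrc_filter_zp constant).
// The minimal scalar sketch below only checks that identity with raw
// (unshifted) uint8 zero points; the function name is hypothetical and the
// loops are deliberately naive. It can be dropped into a unit test to compare
// against the vectorized kernels.
#include <cstddef>
#include <cstdint>

static int32_t quant_dot_reference(const uint8_t* s, const uint8_t* f, size_t K,
                                   int32_t src_zp, int32_t filter_zp) {
    // direct form: sum_i (s_i - src_zp) * (f_i - filter_zp)
    int32_t direct = 0;
    for (size_t i = 0; i < K; ++i)
        direct += (int32_t(s[i]) - src_zp) * (int32_t(f[i]) - filter_zp);

    // decomposed form mirroring the kernels above
    int32_t zs = src_zp - 128, zf = filter_zp - 128;
    int32_t acc = 0, fsum = 0;
    for (size_t i = 0; i < K; ++i) {
        int32_t sv = int32_t(s[i]) - 128, fv = int32_t(f[i]) - 128;
        acc += sv * fv - zf * sv;  // per-element term
        fsum += fv;                // running sum of shifted filter taps
    }
    acc += int32_t(K) * zs * zf;  // constant term (src_filter_zp)
    acc -= fsum * zs;             // filter-sum correction (vfxszp)
    return acc;                   // equal to `direct`
}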
+#define INSTANTIATION(stride, i, first_ic, last_ic, bias, Op)                \
+    template void conv_bias::conv_direct_##stride##_##i##x##i##_quint8<      \
+            first_ic, last_ic, bias, Op>(                                    \
+            const uint8_t*, const uint8_t*, const int32_t*, int32_t*,        \
+            uint8_t*, const size_t, const size_t, const size_t, const size_t, \
+            const int8_t, const int8_t, const int32_t, const Op&);
+
+#define FOR_NONLINEAR(stride, i, first_ic, last_ic, bias)                    \
+    INSTANTIATION(stride, i, first_ic, last_ic, bias,                        \
+                  TypeCvtOp<dt_qint32, dt_quint8>)                           \
+    INSTANTIATION(stride, i, first_ic, last_ic, bias,                        \
+                  ReluOp<dt_qint32, dt_quint8>)                              \
+    INSTANTIATION(stride, i, first_ic, last_ic, bias,                        \
+                  HSwishOp<dt_qint32, dt_quint8>)
+
+#define FOR_BIAS(stride, i, first_ic, last_ic)                               \
+    FOR_NONLINEAR(stride, i, first_ic, last_ic, BiasMode::NO_BIAS)           \
+    FOR_NONLINEAR(stride, i, first_ic, last_ic,                              \
+                  BiasMode::BROADCAST_CHANNEL_BIAS)
+
+#define FOR_IC(stride, i)             \
+    FOR_BIAS(stride, i, true, true)   \
+    FOR_BIAS(stride, i, true, false)  \
+    FOR_BIAS(stride, i, false, false) \
+    FOR_BIAS(stride, i, false, true)
+
+#define FOR_FILTER(stride) \
+    FOR_IC(stride, 2)      \
+    FOR_IC(stride, 3)      \
+    FOR_IC(stride, 5)      \
+    FOR_IC(stride, 7)
+
+#define FOR_STRIDE      \
+    FOR_FILTER(stride1) \
+    FOR_FILTER(stride2)
+
+FOR_STRIDE
+
+#undef FOR_STRIDE
+#undef FOR_FILTER
+#undef FOR_IC
+#undef FOR_BIAS
+#undef FOR_NONLINEAR
+#undef INSTANTIATION
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/arm_common/conv_bias/quint8/direct.h b/dnn/src/arm_common/conv_bias/quint8/direct.h
new file mode 100644
index 00000000..9475bb2b
--- /dev/null
+++ b/dnn/src/arm_common/conv_bias/quint8/direct.h
@@ -0,0 +1,44 @@
+/**
+ * \file dnn/src/arm_common/conv_bias/quint8/direct.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#include "src/arm_common/conv_bias/opr_impl.h"
+#include "src/fallback/conv_bias/common.h"
+
+namespace megdnn {
+namespace arm_common {
+namespace conv_bias {
+
+#define KERN(stride, i)                                                      \
+    template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op>  \
+    void conv_direct_##stride##_##i##x##i##_quint8(                          \
+            const uint8_t* src, const uint8_t* filter, const int32_t* bias,  \
+            int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW,   \
+            const size_t OH, const size_t OW, const int8_t src_zp,           \
+            const int8_t filter_zp, const int32_t src_filter_zp,             \
+            const Op& op);
+
+KERN(stride1, 2)
+KERN(stride1, 3)
+KERN(stride1, 5)
+KERN(stride1, 7)
+
+KERN(stride2, 2)
+KERN(stride2, 3)
+KERN(stride2, 5)
+KERN(stride2, 7)
+
+#undef KERN
+
+}  // namespace conv_bias
+}  // namespace arm_common
+}  // namespace megdnn
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/arm_common/conv_bias/quint8/direct_dotprod.cpp b/dnn/src/arm_common/conv_bias/quint8/direct_dotprod.cpp
new file mode 100644
index 00000000..8005bbd4
--- /dev/null
+++ b/dnn/src/arm_common/conv_bias/quint8/direct_dotprod.cpp
@@ -0,0 +1,2522 @@
+/**
+ * \file dnn/src/arm_common/conv_bias/quint8/direct_dotprod.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */ + +#if __ARM_FEATURE_DOTPROD +#include "src/arm_common/conv_bias/quint8/direct_dotprod.h" +#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/utils.h" +#include "src/fallback/conv_bias/common.h" + +using namespace megdnn; +using namespace arm_common; +using megdnn::arm_common::ReluOp; +using megdnn::arm_common::TypeCvtOp; + +constexpr int32_t SHIFT = (1 << 30); + +inline int8x16_t vqtbl1q_s8_v7(int8x16_t a, uint8x16_t index){ + int8x8x2_t src; + src.val[0] = vget_low_s8(a); + src.val[1] = vget_high_s8(a); + uint8x8_t index_low = vget_low_u8(index); + uint8x8_t index_high = vget_high_u8(index); + int8x8_t r00 = vtbl2_s8(src,vreinterpret_s8_u8(index_low)) ; + int8x8_t r01 = vtbl2_s8(src,vreinterpret_s8_u8(index_high)); + int8x16_t r = vcombine_s8(r00,r01); + return r; +} + +#define ST1_S32X4(dst0, tptr) vst1q_u32(tptr, dst0); + +#define ST2_S32X4X2(dst0, tptr) vst2q_u32(tptr, dst0); + +#define POSTPROCESS_1X8(dst0, dst1, tptr, dptr) \ + if (last_ic && fused_kern) { \ + op({{vreinterpretq_u32_s32(dst0), vreinterpretq_u32_s32(dst1)}}, \ + reinterpret_cast(dptr)); \ + } else { \ + ST1_S32X4(dst0, tptr); \ + ST1_S32X4(dst1, tptr + 4); \ + } + +#define POSTPROCESS2_1X8(dst0, tptr, dptr) \ + if (last_ic && fused_kern) { \ + uint32x4x2_t temp; \ + uint32x4_t temp00, temp11; \ + temp = vzipq_u32(dst0.val[0], dst0.val[1]); \ + temp00 = temp.val[0]; \ + temp11 = temp.val[1]; \ + op({{temp00,temp11}},reinterpret_cast(dptr)); \ + } else { \ + ST2_S32X4X2(dst0, tptr); \ + } + +#define POSTPROCESS_2X4(dst0, dst1, tptr1, tptr2, dptr1, dptr2) \ + if (last_ic && fused_kern) { \ + uint32x2_t res = reinterpret_cast( \ + op({{vreinterpretq_u32_s32(dst0), \ + vreinterpretq_u32_s32(dst1)}})); \ + vst1_lane_u32(reinterpret_cast(dptr1), res, 0); \ + vst1_lane_u32(reinterpret_cast(dptr2), res, 1); \ + } else { \ + ST1_S32X4(dst0, tptr1); \ + ST1_S32X4(dst1, tptr2); \ + } + +#define POSTPROCESS_1X4(dst0, tptr, dptr) \ + if (last_ic && fused_kern) { \ + int32x4_t dst1 = vdupq_n_s32(0); \ + uint32x2_t res = reinterpret_cast( \ + op({{vreinterpretq_u32_s32(dst0), dst1}})); \ + vst1_lane_u32(reinterpret_cast(dptr), res, 0); \ + } else { \ + ST1_S32X4(dst0, tptr); \ + } + +#define POSTPROCESS_1X1(dst0, tptr, dptr) \ + if (last_ic && fused_kern) { \ + int32x4_t dst1 = vdupq_n_s32(0); \ + uint8x8_t res = op({{vreinterpretq_u32_s32(dst0), dst1}}); \ + dptr = vget_lane_u8(res, 0); \ + } else { \ + tptr = vgetq_lane_u32(dst0, 0); \ + } + +#define CALC_DST(_sum) \ + _sum = vreinterpretq_u32_s32( \ + vaddq_s32(vreinterpretq_s32_u32(_sum), _shift_zp)) + +#define CALC_0(_k_idx, _c_idx) \ + _elem = vqtbl1q_s8_v7(_tmp, _idx##_c_idx); \ + _sum0##_c_idx = vdotq_u32(_sum0##_c_idx, _k##_k_idx, _elem); \ + _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, vdotq2_u32(_src_zp, _k##_k_idx)); \ + _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, vdotq2_u32(_filter_zp, _elem)); + +#define CALC_1(_k_idx, _c_idx) \ + _elem = vqtbl1q_s8_v7(_tmp, _idx##_c_idx); \ + _sum1##_c_idx = vdotq_u32(_sum1##_c_idx, _k##_k_idx, _elem); \ + _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_src_zp, _k##_k_idx)); \ + _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_filter_zp, _elem)); + +#define CALC_2(_k1_idx, _k2_idx, _c_idx) \ + _elem = vqtbl1q_s8_v7(_tmp, _idx##_c_idx); \ + _sum0##_c_idx = vdotq_u32(_sum0##_c_idx, _k##_k1_idx, _elem); \ + _sum0##_c_idx = \ + vsubq_u32(_sum0##_c_idx, vdotq2_u32(_src_zp, _k##_k1_idx)); \ + _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, vdotq2_u32(_filter_zp, _elem)); \ + 
_sum1##_c_idx = vdotq_u32(_sum1##_c_idx, _k##_k2_idx, _elem); \ + _sum1##_c_idx = \ + vsubq_u32(_sum1##_c_idx, vdotq2_u32(_src_zp, _k##_k2_idx)); \ + _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_filter_zp, _elem)); + +template +void conv_bias::conv_direct_stride1_2x2_quint8_dot( + const uint8_t* src, const uint8_t* filter, const int32_t* bias, + int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, + const size_t OH, const size_t OW, const uint8_t src_zp, + const uint8_t filter_zp, const int32_t src_filter_zp, const Op& op) { + MEGDNN_MARK_USED_VAR(IH); + const size_t tail_step = IW - OW; + + uint8x16_t _src_zp = vdupq_n_u8(src_zp); + uint8x16_t _filter_zp = vdupq_n_u8(filter_zp); + int32x4_t _shift_zp; + if (bias_mode != BiasMode::NO_BIAS) { + _shift_zp = vdupq_n_s32(src_filter_zp - SHIFT + bias[0]); + } else { + _shift_zp = vdupq_n_s32(src_filter_zp - SHIFT); + } + + const uint8x16_t _idx0 = {0, 1, 16, 16, 1, 2, 16, 16, + 2, 3, 16, 16, 3, 4, 16, 16}; + const uint8x16_t _idx1 = {4, 5, 16, 16, 5, 6, 16, 16, + 6, 7, 16, 16, 7, 8, 16, 16}; + //! here we use uint32_t for calc + uint32_t* outptr = reinterpret_cast(temp); + uint32_t* outptr2 = outptr + OW; + uint8_t* dstptr = dst; + uint8_t* dstptr2 = dstptr + OW; + + const uint8_t* r0 = src; + const uint8_t* r1 = src + IW; + const uint8_t* r2 = src + 2 * IW; + + const uint8_t* k0 = filter; + + uint8x16_t _k = vreinterpretq_u8_u32( + vdupq_n_u32(*reinterpret_cast(k0))); + uint8x16_t _idx = {0, 1, 16, 16, 0, 1, 16, 16, 0, 1, 16, 16, 0, 1, 16, 16}; + uint8x16_t _k1 = vqtbl1q_s8_v7(_k, _idx); + _idx = {2, 3, 16, 16, 2, 3, 16, 16, 2, 3, 16, 16, 2, 3, 16, 16}; + uint8x16_t _k23 = vqtbl1q_s8_v7(_k, _idx); + +#define SUB_ZP(_sum, _r) \ + _sum = vdotq_u32(_sum, _k, _r); \ + _sum = vsubq_u32(_sum, vdotq2_u32(_src_zp, _k)); \ + _sum = vsubq_u32(_sum, vdotq2_u32(_filter_zp, _r)); + + uint8x16_t _tmp, _elem; + const int width = OW >> 2; + size_t h = 0; + for (; h + 1 < OH; h += 2) { + int w = 0; + for (; w + 4 < width; w += 4) { + uint32x4x2_t _sum00, _sum01, _sum10, _sum11; + if (!first_ic) { + _sum00 = vld2q_u32(outptr); + _sum01 = vld2q_u32(outptr + 8); + _sum10 = vld2q_u32(outptr2); + _sum11 = vld2q_u32(outptr2 + 8); + } else { + _sum00.val[0] = vdupq_n_u32(SHIFT); + _sum01.val[0] = vdupq_n_u32(SHIFT); + _sum10.val[0] = vdupq_n_u32(SHIFT); + _sum11.val[0] = vdupq_n_u32(SHIFT); + _sum00.val[1] = vdupq_n_u32(SHIFT); + _sum01.val[1] = vdupq_n_u32(SHIFT); + _sum10.val[1] = vdupq_n_u32(SHIFT); + _sum11.val[1] = vdupq_n_u32(SHIFT); + } + + uint8x16_t _r00 = vld1q_u8(r0); + //! 
here will not not read out of bound + uint8x16_t _r01_ = vdupq_n_u8(r0[16]); + uint8x16_t _r10 = vld1q_u8(r1); + uint8x16_t _r11_ = vdupq_n_u8(r1[16]); + uint8x16_t _r20 = vld1q_u8(r2); + uint8x16_t _r21_ = vdupq_n_u8(r2[16]); + uint8x16_t _r01 = vextq_u8(_r00, _r01_, 1); + uint8x16_t _r11 = vextq_u8(_r10, _r11_, 1); + uint8x16_t _r21 = vextq_u8(_r20, _r21_, 1); + + int16x8x2_t r_0 = vzipq_s16(vreinterpretq_s16_u8(_r00), + vreinterpretq_s16_u8(_r10)); + uint8x16_t _r0 = vreinterpretq_u8_s8(r_0.val[0]); + uint8x16_t _r2 = vreinterpretq_u8_s8(r_0.val[1]); + + int16x8x2_t r_1 = vzipq_s16(vreinterpretq_s16_u8(_r01), + vreinterpretq_s16_u8(_r11)); + int8x16_t _r1 = vreinterpretq_u8_s8(r_1.val[0]); + int8x16_t _r3 = vreinterpretq_u8_s8(r_1.val[1]); + + SUB_ZP(_sum00.val[0], _r0); + SUB_ZP(_sum00.val[1], _r1); + SUB_ZP(_sum01.val[0], _r2); + SUB_ZP(_sum01.val[1], _r3); + + r_0 = vzipq_s16(vreinterpretq_s16_u8(_r10), + vreinterpretq_s16_u8(_r20)); + _r0 = vreinterpretq_u8_s8(r_0.val[0]); + _r2 = vreinterpretq_u8_s8(r_0.val[1]); + + r_1 = vzipq_s16(vreinterpretq_s16_u8(_r11), + vreinterpretq_s16_u8(_r21)); + _r1 = vreinterpretq_u8_s8(r_1.val[0]); + _r3 = vreinterpretq_u8_s8(r_1.val[1]); + + SUB_ZP(_sum10.val[0], _r0); + SUB_ZP(_sum10.val[1], _r1); + SUB_ZP(_sum11.val[0], _r2); + SUB_ZP(_sum11.val[1], _r3); + + if (last_ic) { + CALC_DST(_sum00.val[0]); + CALC_DST(_sum00.val[1]); + CALC_DST(_sum01.val[0]); + CALC_DST(_sum01.val[1]); + CALC_DST(_sum10.val[0]); + CALC_DST(_sum10.val[1]); + CALC_DST(_sum11.val[0]); + CALC_DST(_sum11.val[1]); + } + + POSTPROCESS2_1X8(_sum00, outptr, dstptr); + POSTPROCESS2_1X8(_sum01, outptr + 8, dstptr + 8); + POSTPROCESS2_1X8(_sum10, outptr2, dstptr2); + POSTPROCESS2_1X8(_sum11, outptr2 + 8, dstptr2 + 8); + + r0 += 16; + r1 += 16; + r2 += 16; + outptr += 16; + outptr2 += 16; + dstptr += 16; + dstptr2 += 16; + } + for (; w + 2 < width; w += 2) { + uint32x4_t _sum00, _sum01, _sum10, _sum11; + if (!first_ic) { + _sum00 = vld1q_u32(outptr); + _sum01 = vld1q_u32(outptr + 4); + _sum10 = vld1q_u32(outptr2); + _sum11 = vld1q_u32(outptr2 + 4); + } else { + _sum00 = vdupq_n_u32(SHIFT); + _sum01 = vdupq_n_u32(SHIFT); + _sum10 = vdupq_n_u32(SHIFT); + _sum11 = vdupq_n_u32(SHIFT); + } + + _tmp = vld1q_u8(r0); + CALC_0(1, 0); + CALC_0(1, 1); + + _tmp = vld1q_u8(r1); + CALC_2(23, 1, 0); + CALC_2(23, 1, 1); + + _tmp = vld1q_u8(r2); + CALC_1(23, 0); + CALC_1(23, 1); + + if (last_ic) { + CALC_DST(_sum00); + CALC_DST(_sum01); + CALC_DST(_sum10); + CALC_DST(_sum11); + } + POSTPROCESS_1X8(_sum00, _sum01, outptr, dstptr); + POSTPROCESS_1X8(_sum10, _sum11, outptr2, dstptr2); + + r0 += 8; + r1 += 8; + r2 += 8; + outptr += 8; + outptr2 += 8; + dstptr += 8; + dstptr2 += 8; + } + + for (; w < width; w++) { + uint32x4_t _sum00, _sum10; + if (!first_ic) { + _sum00 = vld1q_u32(outptr); + _sum10 = vld1q_u32(outptr2); + } else { + _sum00 = vdupq_n_u32(SHIFT); + _sum10 = vdupq_n_u32(SHIFT); + } + + _tmp = vtranslq_u8(vld1_u8(r0)); + CALC_0(1, 0); + + _tmp = vtranslq_u8(vld1_u8(r1)); + CALC_2(23, 1, 0); + + _tmp = vtranslq_u8(vld1_u8(r2)); + CALC_1(23, 0); + + if (last_ic) { + CALC_DST(_sum00); + CALC_DST(_sum10); + } + POSTPROCESS_2X4(_sum00, _sum10, outptr, outptr2, dstptr, dstptr2); + + r0 += 4; + r1 += 4; + r2 += 4; + outptr += 4; + outptr2 += 4; + dstptr += 4; + dstptr2 += 4; + } + r0 += tail_step + IW; + r1 += tail_step + IW; + r2 += tail_step + IW; + outptr += OW; + outptr2 += OW; + dstptr += OW; + dstptr2 += OW; + } + + for (; h < OH; h++) { + int w = 0; + for (; w + 4 < width; w += 4) { + 
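+            //! 16 output columns per iteration: rows r0/r1 of the 2x2 window are
+            //! paired with vzipq_s16 so that every 4-byte lane holds
+            //! {r0[i], r0[i+1], r1[i], r1[i+1]} and a single vdotq_u32 lane
+            //! evaluates one whole 2x2 window; even and odd output columns are
+            //! kept in separate accumulators (val[0]/val[1]) and re-interleaved
+            //! on store by vld2q/vst2q in POSTPROCESS2_1X8.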
uint32x4x2_t _sum0, _sum1; + if (!first_ic) { + _sum0 = vld2q_u32(outptr); + _sum1 = vld2q_u32(outptr + 8); + } else { + _sum0.val[0] = vdupq_n_u32(SHIFT); + _sum1.val[0] = vdupq_n_u32(SHIFT); + _sum0.val[1] = vdupq_n_u32(SHIFT); + _sum1.val[1] = vdupq_n_u32(SHIFT); + } + + uint8x16_t _r00 = vld1q_u8(r0); + //! here will not not read out of bound + uint8x16_t _r01_ = vdupq_n_u8(r0[16]); + uint8x16_t _r10 = vld1q_u8(r1); + uint8x16_t _r11_ = vdupq_n_u8(r1[16]); + uint8x16_t _r01 = vextq_u8(_r00, _r01_, 1); + uint8x16_t _r11 = vextq_u8(_r10, _r11_, 1); + + int16x8x2_t r_0 = vzipq_s16(vreinterpretq_s16_u8(_r00), + vreinterpretq_s16_u8(_r10)); + uint8x16_t _r0 = vreinterpretq_u8_s8(r_0.val[0]); + uint8x16_t _r2 = vreinterpretq_u8_s8(r_0.val[1]); + + int16x8x2_t r_1 = vzipq_s16(vreinterpretq_s16_u8(_r01), + vreinterpretq_s16_u8(_r11)); + int8x16_t _r1 = vreinterpretq_u8_s8(r_1.val[0]); + int8x16_t _r3 = vreinterpretq_u8_s8(r_1.val[1]); + + SUB_ZP(_sum0.val[0], _r0); + SUB_ZP(_sum0.val[1], _r1); + SUB_ZP(_sum1.val[0], _r2); + SUB_ZP(_sum1.val[1], _r3); + + if (last_ic) { + CALC_DST(_sum0.val[0]); + CALC_DST(_sum0.val[1]); + CALC_DST(_sum1.val[0]); + CALC_DST(_sum1.val[1]); + } + POSTPROCESS2_1X8(_sum0, outptr, dstptr); + POSTPROCESS2_1X8(_sum1, outptr + 8, dstptr + 8); + + r0 += 16; + r1 += 16; + outptr += 16; + dstptr += 16; + } + for (; w + 2 < width; w += 2) { + uint32x4_t _sum00, _sum01; + if (!first_ic) { + _sum00 = vld1q_u32(outptr); + _sum01 = vld1q_u32(outptr + 4); + } else { + _sum00 = vdupq_n_u32(SHIFT); + _sum01 = vdupq_n_u32(SHIFT); + } + + _tmp = vld1q_u8(r0); + CALC_0(1, 0); + CALC_0(1, 1); + + _tmp = vld1q_u8(r1); + CALC_0(23, 0); + CALC_0(23, 1); + + if (last_ic) { + CALC_DST(_sum00); + CALC_DST(_sum01); + } + POSTPROCESS_1X8(_sum00, _sum01, outptr, dstptr); + + r0 += 8; + r1 += 8; + outptr += 8; + dstptr += 8; + } + + for (; w < width; w++) { + uint32x4_t _sum00; + if (!first_ic) { + _sum00 = vld1q_u32(outptr); + } else { + _sum00 = vdupq_n_u32(SHIFT); + } + + _tmp = vtranslq_u8(vld1_u8(r0)); + CALC_0(1, 0); + + _tmp = vtranslq_u8(vld1_u8(r1)); + CALC_0(23, 0); + + if (last_ic) { + CALC_DST(_sum00); + } + POSTPROCESS_1X4(_sum00, outptr, dstptr); + + r0 += 4; + r1 += 4; + outptr += 4; + dstptr += 4; + } + r0 += tail_step; + r1 += tail_step; + } +#undef SUB_ZP +} + +template +void conv_bias::conv_direct_stride1_3x3_quint8_dot( + const uint8_t* src, const uint8_t* filter, const int32_t* bias, + int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, + const size_t OH, const size_t OW, const uint8_t src_zp, + const uint8_t filter_zp, const int32_t src_filter_zp, const Op& op) { + MEGDNN_MARK_USED_VAR(IH); + const size_t tail_step = IW - OW; + + uint8x16_t _src_zp = vdupq_n_u8(src_zp); + uint8x16_t _filter_zp = vdupq_n_u8(filter_zp); + int32x4_t _shift_zp; + if (bias_mode != BiasMode::NO_BIAS) { + _shift_zp = vdupq_n_s32(src_filter_zp - SHIFT + bias[0]); + } else { + _shift_zp = vdupq_n_s32(src_filter_zp - SHIFT); + } + + const uint8x16_t _idx0 = {0, 1, 2, 16, 1, 2, 3, 16, + 2, 3, 4, 16, 3, 4, 5, 16}; + const uint8x16_t _idx1 = {4, 5, 6, 16, 5, 6, 7, 16, + 6, 7, 8, 16, 7, 8, 9, 16}; + const uint8x16_t _idx2 = {8, 9, 10, 16, 9, 10, 11, 16, + 10, 11, 12, 16, 11, 12, 13, 16}; + uint32_t* outptr = reinterpret_cast(temp); + uint32_t* outptr2 = outptr + OW; + uint8_t* dstptr = dst; + uint8_t* dstptr2 = dstptr + OW; + + const uint8_t* r0 = src; + const uint8_t* r1 = src + IW; + const uint8_t* r2 = src + IW * 2; + const uint8_t* r3 = src + IW * 3; + + const uint8_t* k0 = filter; + + 
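+    //! pack the 9 filter taps so that each vdotq_u32 lane sees one 3-tap filter
+    //! row padded with a zero: taps 0-7 come from vld1_u8(k0), tap 8 is
+    //! broadcast into the upper half, and vqtbl1q then gathers {k, k, k, pad}
+    //! per lane (table index 16 is out of range and therefore reads as 0, both
+    //! on aarch64 TBL and on the armv7 vtbl2 fallback above).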
uint8x16_t _k_tmp = vcombine_u8(vld1_u8(k0), vdup_n_u8(k0[8])); + uint8x16_t _idx = {0, 1, 2, 16, 0, 1, 2, 16, 0, 1, 2, 16, 0, 1, 2, 16}; + uint8x16_t _k12 = vqtbl1q_s8_v7(_k_tmp, _idx); + _idx = {3, 4, 5, 16, 3, 4, 5, 16, 3, 4, 5, 16, 3, 4, 5, 16}; + uint8x16_t _k345 = vqtbl1q_s8_v7(_k_tmp, _idx); + _idx = {6, 7, 8, 16, 6, 7, 8, 16, 6, 7, 8, 16, 6, 7, 8, 16}; + uint8x16_t _k678 = vqtbl1q_s8_v7(_k_tmp, _idx); + + uint8x16_t _tmp, _elem; + size_t h = 0; + for (; h + 1 < OH; h += 2) { + int width = OW >> 2; + + int w = 0; + for (; w + 3 < width; w += 3) { + uint32x4_t _sum00, _sum01, _sum02, _sum10, _sum11, _sum12; + if (!first_ic) { + _sum00 = vld1q_u32(outptr); + _sum01 = vld1q_u32(outptr + 4); + _sum02 = vld1q_u32(outptr + 8); + _sum10 = vld1q_u32(outptr2); + _sum11 = vld1q_u32(outptr2 + 4); + _sum12 = vld1q_u32(outptr2 + 8); + } else { + _sum00 = vdupq_n_u32(SHIFT); + _sum01 = vdupq_n_u32(SHIFT); + _sum02 = vdupq_n_u32(SHIFT); + _sum10 = vdupq_n_u32(SHIFT); + _sum11 = vdupq_n_u32(SHIFT); + _sum12 = vdupq_n_u32(SHIFT); + } + + _tmp = vld1q_u8(r0); + CALC_0(12, 0); + CALC_0(12, 1); + CALC_0(12, 2); + + _tmp = vld1q_u8(r1); + CALC_2(345, 12, 0); + CALC_2(345, 12, 1); + CALC_2(345, 12, 2); + + _tmp = vld1q_u8(r2); + CALC_2(678, 345, 0); + CALC_2(678, 345, 1); + CALC_2(678, 345, 2); + + _tmp = vld1q_u8(r3); + CALC_1(678, 0); + CALC_1(678, 1); + CALC_1(678, 2); + + if (last_ic) { + CALC_DST(_sum00); + CALC_DST(_sum01); + CALC_DST(_sum02); + CALC_DST(_sum10); + CALC_DST(_sum11); + CALC_DST(_sum12); + } + POSTPROCESS_1X8(_sum00, _sum01, outptr, dstptr); + POSTPROCESS_1X4(_sum02, outptr + 8, dstptr + 8); + POSTPROCESS_1X8(_sum10, _sum11, outptr2, dstptr2); + POSTPROCESS_1X4(_sum12, outptr2 + 8, dstptr2 + 8); + + r0 += 12; + r1 += 12; + r2 += 12; + r3 += 12; + outptr += 12; + outptr2 += 12; + dstptr += 12; + dstptr2 += 12; + } + for (; w < width; w++) { + uint32x4_t _sum00, _sum10; + if (!first_ic) { + _sum00 = vld1q_u32(outptr); + _sum10 = vld1q_u32(outptr2); + } else { + _sum00 = vdupq_n_u32(SHIFT); + _sum10 = vdupq_n_u32(SHIFT); + } + + _tmp = vtranslq_u8(vld1_u8(r0)); + CALC_0(12, 0); + + _tmp = vtranslq_u8(vld1_u8(r1)); + CALC_2(345, 12, 0); + + _tmp = vtranslq_u8(vld1_u8(r2)); + CALC_2(678, 345, 0); + + _tmp = vtranslq_u8(vld1_u8(r3)); + CALC_1(678, 0); + + if (last_ic) { + CALC_DST(_sum00); + CALC_DST(_sum10); + } + POSTPROCESS_2X4(_sum00, _sum10, outptr, outptr2, dstptr, dstptr2); + + r0 += 4; + r1 += 4; + r2 += 4; + r3 += 4; + outptr += 4; + outptr2 += 4; + dstptr += 4; + dstptr2 += 4; + } + + r0 += tail_step + IW; + r1 += tail_step + IW; + r2 += tail_step + IW; + r3 += tail_step + IW; + + outptr += OW; + outptr2 += OW; + dstptr += OW; + dstptr2 += OW; + } + + for (; h < OH; h++) { + int width = OW >> 2; + + int w = 0; + for (; w + 3 < width; w += 3) { + uint32x4_t _sum00, _sum01, _sum02; + if (!first_ic) { + _sum00 = vld1q_u32(outptr); + _sum01 = vld1q_u32(outptr + 4); + _sum02 = vld1q_u32(outptr + 8); + } else { + _sum00 = vdupq_n_u32(SHIFT); + _sum01 = vdupq_n_u32(SHIFT); + _sum02 = vdupq_n_u32(SHIFT); + } + + _tmp = vld1q_u8(r0); + CALC_0(12, 0); + CALC_0(12, 1); + CALC_0(12, 2); + + _tmp = vld1q_u8(r1); + CALC_0(345, 0); + CALC_0(345, 1); + CALC_0(345, 2); + + _tmp = vld1q_u8(r2); + CALC_0(678, 0); + CALC_0(678, 1); + CALC_0(678, 2); + + if (last_ic) { + CALC_DST(_sum00); + CALC_DST(_sum01); + CALC_DST(_sum02); + } + POSTPROCESS_1X8(_sum00, _sum01, outptr, dstptr); + POSTPROCESS_1X4(_sum02, outptr + 8, dstptr + 8); + + r0 += 12; + r1 += 12; + r2 += 12; + outptr += 12; + dstptr += 12; 
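+            //! partial sums stay in the int32 temp buffer across input
+            //! channels: the first channel seeds them with SHIFT (presumably to
+            //! keep the unsigned accumulators non-negative after the zero-point
+            //! subtractions), and only on the last channel does CALC_DST fold in
+            //! bias plus the src_zp*filter_zp compensation (src_filter_zp - SHIFT)
+            //! before op() requantizes the result into dstptr.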
+ } + for (; w < width; w++) { + uint32x4_t _sum00; + if (!first_ic) { + _sum00 = vld1q_u32(outptr); + } else { + _sum00 = vdupq_n_u32(SHIFT); + } + + _tmp = vtranslq_u8(vld1_u8(r0)); + CALC_0(12, 0); + + _tmp = vtranslq_u8(vld1_u8(r1)); + CALC_0(345, 0); + + _tmp = vtranslq_u8(vld1_u8(r2)); + CALC_0(678, 0); + + if (last_ic) { + CALC_DST(_sum00); + } + POSTPROCESS_1X4(_sum00, outptr, dstptr); + + r0 += 4; + r1 += 4; + r2 += 4; + outptr += 4; + dstptr += 4; + } + r0 += tail_step; + r1 += tail_step; + r2 += tail_step; + } +} + +template +void conv_bias::conv_direct_stride2_2x2_quint8_dot( + const uint8_t* src, const uint8_t* filter, const int32_t* bias, + int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, + const size_t OH, const size_t OW, const uint8_t src_zp, + const uint8_t filter_zp, const int32_t src_filter_zp, const Op& op) { + MEGDNN_MARK_USED_VAR(IH); + const size_t tail_step = IW - 2 * OW + IW; + + uint8x16_t _src_zp = vdupq_n_u8(src_zp); + uint8x16_t _filter_zp = vdupq_n_u8(filter_zp); + int32x4_t _shift_zp; + if (bias_mode != BiasMode::NO_BIAS) { + _shift_zp = vdupq_n_s32(src_filter_zp - SHIFT + bias[0]); + } else { + _shift_zp = vdupq_n_s32(src_filter_zp - SHIFT); + } + + const uint8x16_t _idx0 = {0, 1, 16, 16, 2, 3, 16, 16, + 4, 5, 16, 16, 6, 7, 16, 16}; + uint32_t* outptr = reinterpret_cast(temp); + uint8_t* dstptr = dst; + + const uint8_t* r0 = src; + const uint8_t* r1 = src + IW; + + const uint8_t* k0 = filter; + + uint8x16_t _k = vreinterpretq_u8_u32( + vdupq_n_u32(*reinterpret_cast(k0))); + uint8x16_t _idx = {0, 1, 16, 16, 0, 1, 16, 16, 0, 1, 16, 16, 0, 1, 16, 16}; + uint8x16_t _k1 = vqtbl1q_s8_v7(_k, _idx); + _idx = {2, 3, 16, 16, 2, 3, 16, 16, 2, 3, 16, 16, 2, 3, 16, 16}; + uint8x16_t _k23 = vqtbl1q_s8_v7(_k, _idx); + +#define SUB_ZP(_sum, _r) \ + _sum = vdotq_u32(_sum, _k, _r); \ + _sum = vsubq_u32(_sum, vdotq2_u32(_src_zp, _k)); \ + _sum = vsubq_u32(_sum, vdotq2_u32(_filter_zp, _r)); + + uint8x16_t _tmp, _elem; + const int width = OW >> 2; + size_t h = 0; + for (; h < OH; h++) { + int w = 0; + for (; w + 2 < width; w += 2) { + uint32x4_t _sum0, _sum1; + if (!first_ic) { + _sum0 = vld1q_u32(outptr); + _sum1 = vld1q_u32(outptr + 4); + } else { + _sum0 = vdupq_n_u32(SHIFT); + _sum1 = vdupq_n_u32(SHIFT); + } + + uint8x16_t _r00 = vld1q_u8(r0); + //! 
here will not not read out of bound + uint8x16_t _r10 = vld1q_u8(r1); + + int16x8x2_t r_0 = vzipq_s16(vreinterpretq_s16_u8(_r00), + vreinterpretq_s16_u8(_r10)); + uint8x16_t _r0 = vreinterpretq_u8_s8(r_0.val[0]); + uint8x16_t _r1 = vreinterpretq_u8_s8(r_0.val[1]); + SUB_ZP(_sum0, _r0); + SUB_ZP(_sum1, _r1); + + if (last_ic) { + CALC_DST(_sum0); + CALC_DST(_sum1); + } + + POSTPROCESS_1X8(_sum0, _sum1, outptr, dstptr); + + r0 += 16; + r1 += 16; + outptr += 8; + dstptr += 8; + } + + for (; w < width; w++) { + uint32x4_t _sum00; + if (!first_ic) { + _sum00 = vld1q_u32(outptr); + } else { + _sum00 = vdupq_n_u32(SHIFT); + } + + _tmp = vtranslq_u8(vld1_u8(r0)); + CALC_0(1, 0); + + _tmp = vtranslq_u8(vld1_u8(r1)); + CALC_0(23, 0); + + if (last_ic) { + CALC_DST(_sum00); + } + POSTPROCESS_1X4(_sum00, outptr, dstptr); + + r0 += 8; + r1 += 8; + outptr += 4; + dstptr += 4; + } + r0 += tail_step; + r1 += tail_step; + } +#undef SUB_ZP +} + +template +void conv_bias::conv_direct_stride2_3x3_quint8_dot( + const uint8_t* src, const uint8_t* filter, const int32_t* bias, + int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, + const size_t OH, const size_t OW, const uint8_t src_zp, + const uint8_t filter_zp, const int32_t src_filter_zp, const Op& op) { + MEGDNN_MARK_USED_VAR(IH); + const size_t tail_step = IW - 2 * OW + IW; + + uint8x16_t _src_zp = vdupq_n_u8(src_zp); + uint8x16_t _filter_zp = vdupq_n_u8(filter_zp); + int32x4_t _shift_zp; + if (bias_mode != BiasMode::NO_BIAS) { + _shift_zp = vdupq_n_s32(src_filter_zp - SHIFT + bias[0]); + } else { + _shift_zp = vdupq_n_s32(src_filter_zp - SHIFT); + } + + const uint8x16_t _idx0 = {0, 1, 2, 16, 2, 3, 4, 16, + 4, 5, 6, 16, 6, 7, 8, 16}; + const uint8x16_t _idx1 = {8, 9, 10, 16, 10, 11, 12, 16, + 12, 13, 14, 16, 16, 16, 16, 16}; + //! 
start from 12 13 14 15 + const uint8x16_t _idx2 = {2, 3, 4, 16, 4, 5, 6, 16, + 6, 7, 8, 16, 8, 9, 10, 16}; + const uint8x16_t _idx3 = {10, 11, 12, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 16}; + uint32_t* outptr = reinterpret_cast(temp); + uint32_t* outptr2 = outptr + OW; + uint8_t* dstptr = dst; + uint8_t* dstptr2 = dstptr + OW; + + const uint8_t* r0 = src; + const uint8_t* r1 = src + IW; + const uint8_t* r2 = src + IW * 2; + const uint8_t* r3 = src + IW * 3; + const uint8_t* r4 = src + IW * 4; + + const uint8_t* k0 = filter; + + uint8x16_t _k_tmp = vcombine_u8(vld1_u8(k0), vdup_n_u8(k0[8])); + uint8x16_t _idx = {0, 1, 2, 16, 0, 1, 2, 16, 0, 1, 2, 16, 0, 1, 2, 16}; + uint8x16_t _k12 = vqtbl1q_s8_v7(_k_tmp, _idx); + _idx = {3, 4, 5, 16, 3, 4, 5, 16, 3, 4, 5, 16, 3, 4, 5, 16}; + uint8x16_t _k345 = vqtbl1q_s8_v7(_k_tmp, _idx); + _idx = {6, 7, 8, 16, 6, 7, 8, 16, 6, 7, 8, 16, 6, 7, 8, 16}; + uint8x16_t _k678 = vqtbl1q_s8_v7(_k_tmp, _idx); + + uint8x16_t _tmp, _elem; + size_t h = 0; + for (; h + 1 < OH; h += 2) { + int width = OW >> 2; + + int w = 0; + for (; w + 3 < width; w += 3) { + uint32x4_t _sum00, _sum01, _sum02, _sum03, _sum10, _sum11, _sum12, + _sum13; + if (!first_ic) { + _sum00 = vld1q_u32(outptr); + _sum01 = vld1q_u32(outptr + 4); + _sum02 = vld1q_u32(outptr + 7); + _sum03 = vld1q_u32(outptr + 11); + _sum10 = vld1q_u32(outptr2); + _sum11 = vld1q_u32(outptr2 + 4); + _sum12 = vld1q_u32(outptr2 + 7); + _sum13 = vld1q_u32(outptr2 + 11); + } else { + _sum00 = vdupq_n_u32(SHIFT); + _sum01 = vdupq_n_u32(SHIFT); + _sum02 = vdupq_n_u32(SHIFT); + _sum03 = vdupq_n_u32(SHIFT); + _sum10 = vdupq_n_u32(SHIFT); + _sum11 = vdupq_n_u32(SHIFT); + _sum12 = vdupq_n_u32(SHIFT); + _sum13 = vdupq_n_u32(SHIFT); + } + + _tmp = vld1q_u8(r0); + CALC_0(12, 0); + CALC_0(12, 1); + _tmp = vld1q_u8(r0 + 12); + CALC_0(12, 2); + CALC_0(12, 3); + + _tmp = vld1q_u8(r1); + CALC_0(345, 0); + CALC_0(345, 1); + _tmp = vld1q_u8(r1 + 12); + CALC_0(345, 2); + CALC_0(345, 3); + + _tmp = vld1q_u8(r2); + CALC_2(678, 12, 0); + CALC_2(678, 12, 1); + _tmp = vld1q_u8(r2 + 12); + CALC_2(678, 12, 2); + CALC_2(678, 12, 3); + + _tmp = vld1q_u8(r3); + CALC_1(345, 0); + CALC_1(345, 1); + _tmp = vld1q_u8(r3 + 12); + CALC_1(345, 2); + CALC_1(345, 3); + + _tmp = vld1q_u8(r4); + CALC_1(678, 0); + CALC_1(678, 1); + _tmp = vld1q_u8(r4 + 12); + CALC_1(678, 2); + CALC_1(678, 3); + + if (last_ic) { + CALC_DST(_sum00); + CALC_DST(_sum01); + CALC_DST(_sum02); + CALC_DST(_sum03); + CALC_DST(_sum10); + CALC_DST(_sum11); + CALC_DST(_sum12); + CALC_DST(_sum13); + } + + POSTPROCESS_1X8(_sum00, _sum01, outptr, dstptr); + POSTPROCESS_1X4(_sum02, outptr + 7, dstptr + 7); + POSTPROCESS_1X1(_sum03, outptr[11], dstptr[11]); + POSTPROCESS_1X8(_sum10, _sum11, outptr2, dstptr2); + POSTPROCESS_1X4(_sum12, outptr2 + 7, dstptr2 + 7); + POSTPROCESS_1X1(_sum13, outptr2[11], dstptr2[11]); + + r0 += 24; + r1 += 24; + r2 += 24; + r3 += 24; + r4 += 24; + outptr += 12; + outptr2 += 12; + dstptr += 12; + dstptr2 += 12; + } + for (; w < width; w++) { + uint32x4_t _sum00, _sum10; + if (!first_ic) { + _sum00 = vld1q_u32(outptr); + _sum10 = vld1q_u32(outptr2); + } else { + _sum00 = vdupq_n_u32(SHIFT); + _sum10 = vdupq_n_u32(SHIFT); + } + + _tmp = vld1q_u8(r0); + CALC_0(12, 0); + + _tmp = vld1q_u8(r1); + CALC_0(345, 0); + + _tmp = vld1q_u8(r2); + CALC_2(678, 12, 0); + + _tmp = vld1q_u8(r3); + CALC_1(345, 0); + + _tmp = vld1q_u8(r4); + CALC_1(678, 0); + + if (last_ic) { + CALC_DST(_sum00); + CALC_DST(_sum10); + } + POSTPROCESS_2X4(_sum00, _sum10, outptr, outptr2, 
dstptr, dstptr2); + + r0 += 8; + r1 += 8; + r2 += 8; + r3 += 8; + r4 += 8; + outptr += 4; + outptr2 += 4; + dstptr += 4; + dstptr2 += 4; + } + + r0 += tail_step + IW * 2; + r1 += tail_step + IW * 2; + r2 += tail_step + IW * 2; + r3 += tail_step + IW * 2; + r4 += tail_step + IW * 2; + + outptr += OW; + outptr2 += OW; + dstptr += OW; + dstptr2 += OW; + } + + for (; h < OH; h++) { + int width = OW >> 2; + + int w = 0; + for (; w + 3 < width; w += 3) { + uint32x4_t _sum00, _sum01, _sum02, _sum03; + if (!first_ic) { + _sum00 = vld1q_u32(outptr); + _sum01 = vld1q_u32(outptr + 4); + _sum02 = vld1q_u32(outptr + 7); + _sum03 = vld1q_u32(outptr + 11); + } else { + _sum00 = vdupq_n_u32(SHIFT); + _sum01 = vdupq_n_u32(SHIFT); + _sum02 = vdupq_n_u32(SHIFT); + _sum03 = vdupq_n_u32(SHIFT); + } + + _tmp = vld1q_u8(r0); + CALC_0(12, 0); + CALC_0(12, 1); + _tmp = vld1q_u8(r0 + 12); + CALC_0(12, 2); + CALC_0(12, 3); + + _tmp = vld1q_u8(r1); + CALC_0(345, 0); + CALC_0(345, 1); + _tmp = vld1q_u8(r1 + 12); + CALC_0(345, 2); + CALC_0(345, 3); + + _tmp = vld1q_u8(r2); + CALC_0(678, 0); + CALC_0(678, 1); + _tmp = vld1q_u8(r2 + 12); + CALC_0(678, 2); + CALC_0(678, 3); + + if (last_ic) { + CALC_DST(_sum00); + CALC_DST(_sum01); + CALC_DST(_sum02); + CALC_DST(_sum03); + } + POSTPROCESS_1X8(_sum00, _sum01, outptr, dstptr); + POSTPROCESS_1X4(_sum02, outptr + 7, dstptr + 7); + POSTPROCESS_1X1(_sum03, outptr[11], dstptr[11]); + + r0 += 24; + r1 += 24; + r2 += 24; + outptr += 12; + dstptr += 12; + } + for (; w < width; w++) { + uint32x4_t _sum00; + if (!first_ic) { + _sum00 = vld1q_u32(outptr); + } else { + _sum00 = vdupq_n_u32(SHIFT); + } + + _tmp = vld1q_u8(r0); + CALC_0(12, 0); + + _tmp = vld1q_u8(r1); + CALC_0(345, 0); + + _tmp = vld1q_u8(r2); + CALC_0(678, 0); + + if (last_ic) { + CALC_DST(_sum00); + } + POSTPROCESS_1X4(_sum00, outptr, dstptr); + + r0 += 8; + r1 += 8; + r2 += 8; + outptr += 4; + dstptr += 4; + } + r0 += tail_step; + r1 += tail_step; + r2 += tail_step; + } +} + +#undef CALC_0 +#undef CALC_1 +#undef CALC_2 + +#define CALC_0(_k00_idx, _k01_idx, _c_idx) \ + _elem = vqtbl1q_s8_v7(_tmp, _idx##_c_idx##0); \ + _sum0##_c_idx = vdotq_u32(_sum0##_c_idx, _k##_k00_idx, _elem); \ + _sum0##_c_idx = \ + vsubq_u32(_sum0##_c_idx, vdotq2_u32(_src_zp, _k##_k00_idx)); \ + _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, vdotq2_u32(_filter_zp, _elem)); \ + _elem = vqtbl1q_s8_v7(_tmp, _idx##_c_idx##1); \ + _sum0##_c_idx = vdotq_u32(_sum0##_c_idx, _k##_k01_idx, _elem); \ + _sum0##_c_idx = \ + vsubq_u32(_sum0##_c_idx, vdotq2_u32(_src_zp, _k##_k01_idx)); \ + _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, vdotq2_u32(_filter_zp, _elem)); + +#define CALC_1(_k00_idx, _k01_idx, _c_idx) \ + _elem = vqtbl1q_s8_v7(_tmp, _idx##_c_idx##0); \ + _sum1##_c_idx = vdotq_u32(_sum1##_c_idx, _k##_k00_idx, _elem); \ + _sum1##_c_idx = \ + vsubq_u32(_sum1##_c_idx, vdotq2_u32(_src_zp, _k##_k00_idx)); \ + _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_filter_zp, _elem)); \ + _elem = vqtbl1q_s8_v7(_tmp, _idx##_c_idx##1); \ + _sum1##_c_idx = vdotq_u32(_sum1##_c_idx, _k##_k01_idx, _elem); \ + _sum1##_c_idx = \ + vsubq_u32(_sum1##_c_idx, vdotq2_u32(_src_zp, _k##_k01_idx)); \ + _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_filter_zp, _elem)); + +#define CALC_2(_k00_idx, _k01_idx, _k10_idx, _k11_idx, _c_idx) \ + _elem = vqtbl1q_s8_v7(_tmp, _idx##_c_idx##0); \ + _sum0##_c_idx = vdotq_u32(_sum0##_c_idx, _k##_k00_idx, _elem); \ + _sum0##_c_idx = \ + vsubq_u32(_sum0##_c_idx, vdotq2_u32(_src_zp, _k##_k00_idx)); \ + _elem2 = vdotq2_u32(_filter_zp, _elem); \ + 
_sum0##_c_idx = vsubq_u32(_sum0##_c_idx, _elem2); \ + _sum1##_c_idx = vdotq_u32(_sum1##_c_idx, _k##_k10_idx, _elem); \ + _sum1##_c_idx = \ + vsubq_u32(_sum1##_c_idx, vdotq2_u32(_src_zp, _k##_k10_idx)); \ + _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, _elem2); \ + _elem = vqtbl1q_s8_v7(_tmp, _idx##_c_idx##1); \ + _sum0##_c_idx = vdotq_u32(_sum0##_c_idx, _k##_k01_idx, _elem); \ + _sum0##_c_idx = \ + vsubq_u32(_sum0##_c_idx, vdotq2_u32(_src_zp, _k##_k01_idx)); \ + _elem2 = vdotq2_u32(_filter_zp, _elem); \ + _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, _elem2); \ + _sum1##_c_idx = vdotq_u32(_sum1##_c_idx, _k##_k11_idx, _elem); \ + _sum1##_c_idx = \ + vsubq_u32(_sum1##_c_idx, vdotq2_u32(_src_zp, _k##_k11_idx)); \ + _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, _elem2); + +template +void conv_bias::conv_direct_stride1_5x5_quint8_dot( + const uint8_t* src, const uint8_t* filter, const int32_t* bias, + int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, + const size_t OH, const size_t OW, const uint8_t src_zp, + const uint8_t filter_zp, const int32_t src_filter_zp, const Op& op) { + MEGDNN_MARK_USED_VAR(IH); + const size_t tail_step = IW - OW; + + uint8x16_t _src_zp = vdupq_n_u8(src_zp); + uint8x16_t _filter_zp = vdupq_n_u8(filter_zp); + int32x4_t _shift_zp; + if (bias_mode != BiasMode::NO_BIAS) { + _shift_zp = vdupq_n_s32(src_filter_zp - SHIFT + bias[0]); + } else { + _shift_zp = vdupq_n_s32(src_filter_zp - SHIFT); + } + + const uint8x16_t _idx00 = {0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6}; + const uint8x16_t _idx01 = {4, 16, 16, 16, 5, 16, 16, 16, + 6, 16, 16, 16, 7, 16, 16, 16}; + const uint8x16_t _idx10 = {4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10}; + const uint8x16_t _idx11 = {8, 16, 16, 16, 9, 16, 16, 16, + 10, 16, 16, 16, 11, 16, 16, 16}; + const uint8x16_t _idx20 = {8, 9, 10, 11, 9, 10, 11, 12, + 10, 11, 12, 13, 11, 12, 13, 14}; + const uint8x16_t _idx21 = {12, 16, 16, 16, 13, 16, 16, 16, + 14, 16, 16, 16, 15, 16, 16, 16}; + uint8x16_t _tmp, _elem; + uint32x4_t _elem2; + uint32_t* outptr = reinterpret_cast(temp); + uint32_t* outptr2 = outptr + OW; + uint8_t* dstptr = dst; + uint8_t* dstptr2 = dstptr + OW; + + const uint8_t* r0 = src; + const uint8_t* r1 = src + IW; + const uint8_t* r2 = src + IW * 2; + const uint8_t* r3 = src + IW * 3; + const uint8_t* r4 = src + IW * 4; + const uint8_t* r5 = src + IW * 5; + + const uint8_t* k0 = filter; + + uint8x16_t _k = vld1q_u8(k0); + //! filter row 1 + uint8x16_t _idx = {0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3}; + uint8x16_t _k123 = vqtbl1q_s8_v7(_k, _idx); + _idx = {4, 16, 16, 16, 4, 16, 16, 16, 4, 16, 16, 16, 4, 16, 16, 16}; + uint8x16_t _k4 = vqtbl1q_s8_v7(_k, _idx); + //! filter row 2 + _idx = {5, 6, 7, 8, 5, 6, 7, 8, 5, 6, 7, 8, 5, 6, 7, 8}; + uint8x16_t _k5678 = vqtbl1q_s8_v7(_k, _idx); + _idx = {9, 16, 16, 16, 9, 16, 16, 16, 9, 16, 16, 16, 9, 16, 16, 16}; + uint8x16_t _k9 = vqtbl1q_s8_v7(_k, _idx); + //! filter row 3 + _idx = {10, 11, 12, 13, 10, 11, 12, 13, 10, 11, 12, 13, 10, 11, 12, 13}; + uint8x16_t _k10111213 = vqtbl1q_s8_v7(_k, _idx); + _idx = {14, 16, 16, 16, 14, 16, 16, 16, 14, 16, 16, 16, 14, 16, 16, 16}; + uint8x16_t _k14 = vqtbl1q_s8_v7(_k, _idx); + //! 9 10 11 12 -> 13 14 15 16 -> 17 18 19 20 -> 21 22 23 24 + _k = vld1q_u8(k0 + 9); + //! filter row 4 + _idx = {6, 7, 8, 9, 6, 7, 8, 9, 6, 7, 8, 9, 6, 7, 8, 9}; + uint8x16_t _k15161718 = vqtbl1q_s8_v7(_k, _idx); + _idx = {10, 16, 16, 16, 10, 16, 16, 16, 10, 16, 16, 16, 10, 16, 16, 16}; + uint8x16_t _k19 = vqtbl1q_s8_v7(_k, _idx); + //! 
filter row 5 + _idx = {11, 12, 13, 14, 11, 12, 13, 14, 11, 12, 13, 14, 11, 12, 13, 14}; + uint8x16_t _k20212223 = vqtbl1q_s8_v7(_k, _idx); + _idx = {15, 16, 16, 16, 15, 16, 16, 16, 15, 16, 16, 16, 15, 16, 16, 16}; + uint8x16_t _k24 = vqtbl1q_s8_v7(_k, _idx); + + const int width = OW >> 2; + size_t h = 0; + for (; h + 1 < OH; h += 2) { + int w = 0; + for (; w + 3 < width; w += 3) { + uint32x4_t _sum00, _sum01, _sum02, _sum10, _sum11, _sum12; + if (!first_ic) { + _sum00 = vld1q_u32(outptr); + _sum01 = vld1q_u32(outptr + 4); + _sum02 = vld1q_u32(outptr + 8); + _sum10 = vld1q_u32(outptr2); + _sum11 = vld1q_u32(outptr2 + 4); + _sum12 = vld1q_u32(outptr2 + 8); + } else { + _sum00 = vdupq_n_u32(SHIFT); + _sum01 = vdupq_n_u32(SHIFT); + _sum02 = vdupq_n_u32(SHIFT); + _sum10 = vdupq_n_u32(SHIFT); + _sum11 = vdupq_n_u32(SHIFT); + _sum12 = vdupq_n_u32(SHIFT); + } + + _tmp = vld1q_u8(r0); + CALC_0(123, 4, 0); + CALC_0(123, 4, 1); + CALC_0(123, 4, 2); + + _tmp = vld1q_u8(r1); + CALC_2(5678, 9, 123, 4, 0); + CALC_2(5678, 9, 123, 4, 1); + CALC_2(5678, 9, 123, 4, 2); + + _tmp = vld1q_u8(r2); + CALC_2(10111213, 14, 5678, 9, 0); + CALC_2(10111213, 14, 5678, 9, 1); + CALC_2(10111213, 14, 5678, 9, 2); + + _tmp = vld1q_u8(r3); + CALC_2(15161718, 19, 10111213, 14, 0); + CALC_2(15161718, 19, 10111213, 14, 1); + CALC_2(15161718, 19, 10111213, 14, 2); + + _tmp = vld1q_u8(r4); + CALC_2(20212223, 24, 15161718, 19, 0); + CALC_2(20212223, 24, 15161718, 19, 1); + CALC_2(20212223, 24, 15161718, 19, 2); + + _tmp = vld1q_u8(r5); + CALC_1(20212223, 24, 0); + CALC_1(20212223, 24, 1); + CALC_1(20212223, 24, 2); + + if (last_ic) { + CALC_DST(_sum00); + CALC_DST(_sum01); + CALC_DST(_sum02); + CALC_DST(_sum10); + CALC_DST(_sum11); + CALC_DST(_sum12); + } + POSTPROCESS_1X8(_sum00, _sum01, outptr, dstptr); + POSTPROCESS_1X4(_sum02, outptr + 8, dstptr + 8); + POSTPROCESS_1X8(_sum10, _sum11, outptr2, dstptr2); + POSTPROCESS_1X4(_sum12, outptr2 + 8, dstptr2 + 8); + + r0 += 12; + r1 += 12; + r2 += 12; + r3 += 12; + r4 += 12; + r5 += 12; + outptr += 12; + outptr2 += 12; + dstptr += 12; + dstptr2 += 12; + } + for (; w < width; w++) { + uint32x4_t _sum00, _sum10; + if (!first_ic) { + _sum00 = vld1q_u32(outptr); + _sum10 = vld1q_u32(outptr2); + } else { + _sum00 = vdupq_n_u32(SHIFT); + _sum10 = vdupq_n_u32(SHIFT); + } + + _tmp = vtranslq_u8(vld1_u8(r0)); + CALC_0(123, 4, 0); + + _tmp = vtranslq_u8(vld1_u8(r1)); + CALC_2(5678, 9, 123, 4, 0); + + _tmp = vtranslq_u8(vld1_u8(r2)); + CALC_2(10111213, 14, 5678, 9, 0); + + _tmp = vtranslq_u8(vld1_u8(r3)); + CALC_2(15161718, 19, 10111213, 14, 0); + + _tmp = vtranslq_u8(vld1_u8(r4)); + CALC_2(20212223, 24, 15161718, 19, 0); + + _tmp = vtranslq_u8(vld1_u8(r5)); + CALC_1(20212223, 24, 0); + + if (last_ic) { + CALC_DST(_sum00); + CALC_DST(_sum10); + } + POSTPROCESS_2X4(_sum00, _sum10, outptr, outptr2, dstptr, dstptr2); + + r0 += 4; + r1 += 4; + r2 += 4; + r3 += 4; + r4 += 4; + r5 += 4; + outptr += 4; + outptr2 += 4; + dstptr += 4; + dstptr2 += 4; + } + + r0 += tail_step + IW; + r1 += tail_step + IW; + r2 += tail_step + IW; + r3 += tail_step + IW; + r4 += tail_step + IW; + r5 += tail_step + IW; + + outptr += OW; + outptr2 += OW; + dstptr += OW; + dstptr2 += OW; + } + + for (; h < OH; h++) { + int w = 0; + for (; w + 3 < width; w += 3) { + uint32x4_t _sum00, _sum01, _sum02; + if (!first_ic) { + _sum00 = vld1q_u32(outptr); + _sum01 = vld1q_u32(outptr + 4); + _sum02 = vld1q_u32(outptr + 8); + } else { + _sum00 = vdupq_n_u32(SHIFT); + _sum01 = vdupq_n_u32(SHIFT); + _sum02 = vdupq_n_u32(SHIFT); + } + + 
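+            //! each 5-tap filter row is split into a 4-tap dot (_idx*0 gather)
+            //! plus the remaining tap (_idx*1 gather, zero-padded via table
+            //! index 16), so CALC_0 issues two vdotq_u32 per filter row; rows
+            //! r0..r4 are accumulated below to produce 12 outputs per iteration.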
_tmp = vld1q_u8(r0); + CALC_0(123, 4, 0); + CALC_0(123, 4, 1); + CALC_0(123, 4, 2); + + _tmp = vld1q_u8(r1); + CALC_0(5678, 9, 0); + CALC_0(5678, 9, 1); + CALC_0(5678, 9, 2); + + _tmp = vld1q_u8(r2); + CALC_0(10111213, 14, 0); + CALC_0(10111213, 14, 1); + CALC_0(10111213, 14, 2); + + _tmp = vld1q_u8(r3); + CALC_0(15161718, 19, 0); + CALC_0(15161718, 19, 1); + CALC_0(15161718, 19, 2); + + _tmp = vld1q_u8(r4); + CALC_0(20212223, 24, 0); + CALC_0(20212223, 24, 1); + CALC_0(20212223, 24, 2); + + if (last_ic) { + CALC_DST(_sum00); + CALC_DST(_sum01); + CALC_DST(_sum02); + } + POSTPROCESS_1X8(_sum00, _sum01, outptr, dstptr); + POSTPROCESS_1X4(_sum02, outptr + 8, dstptr + 8); + + r0 += 12; + r1 += 12; + r2 += 12; + r3 += 12; + r4 += 12; + outptr += 12; + dstptr += 12; + } + for (; w < width; w++) { + uint32x4_t _sum00; + if (!first_ic) { + _sum00 = vld1q_u32(outptr); + } else { + _sum00 = vdupq_n_u32(SHIFT); + } + + _tmp = vtranslq_u8(vld1_u8(r0)); + CALC_0(123, 4, 0); + + _tmp = vtranslq_u8(vld1_u8(r1)); + CALC_0(5678, 9, 0); + + _tmp = vtranslq_u8(vld1_u8(r2)); + CALC_0(10111213, 14, 0); + + _tmp = vtranslq_u8(vld1_u8(r3)); + CALC_0(15161718, 19, 0); + + _tmp = vtranslq_u8(vld1_u8(r4)); + CALC_0(20212223, 24, 0); + + if (last_ic) { + CALC_DST(_sum00); + } + POSTPROCESS_1X4(_sum00, outptr, dstptr); + + r0 += 4; + r1 += 4; + r2 += 4; + r3 += 4; + r4 += 4; + outptr += 4; + dstptr += 4; + } + r0 += tail_step; + r1 += tail_step; + r2 += tail_step; + r3 += tail_step; + r4 += tail_step; + } +} + +template +void conv_bias::conv_direct_stride1_7x7_quint8_dot( + const uint8_t* src, const uint8_t* filter, const int32_t* bias, + int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, + const size_t OH, const size_t OW, const uint8_t src_zp, + const uint8_t filter_zp, const int32_t src_filter_zp, const Op& op) { + MEGDNN_MARK_USED_VAR(IH); + const size_t tail_step = IW - OW; + + uint8x16_t _src_zp = vdupq_n_u8(src_zp); + uint8x16_t _filter_zp = vdupq_n_u8(filter_zp); + int32x4_t _shift_zp; + if (bias_mode != BiasMode::NO_BIAS) { + _shift_zp = vdupq_n_s32(src_filter_zp - SHIFT + bias[0]); + } else { + _shift_zp = vdupq_n_s32(src_filter_zp - SHIFT); + } + + const uint8x16_t _idx00 = {0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6}; + const uint8x16_t _idx01 = {4, 5, 6, 16, 5, 6, 7, 16, + 6, 7, 8, 16, 7, 8, 9, 16}; + const uint8x16_t _idx10 = {4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10}; + const uint8x16_t _idx11 = {8, 9, 10, 16, 9, 10, 11, 16, + 10, 11, 12, 16, 11, 12, 13, 16}; + + uint8x16_t _tmp, _elem; + uint32x4_t _elem2; + uint32_t* outptr = reinterpret_cast(temp); + uint32_t* outptr2 = outptr + OW; + uint8_t* dstptr = dst; + uint8_t* dstptr2 = dstptr + OW; + + const uint8_t* r0 = src; + const uint8_t* r1 = src + IW; + const uint8_t* r2 = src + IW * 2; + const uint8_t* r3 = src + IW * 3; + const uint8_t* r4 = src + IW * 4; + const uint8_t* r5 = src + IW * 5; + const uint8_t* r6 = src + IW * 6; + const uint8_t* r7 = src + IW * 7; + + const uint8_t* k0 = filter; + + uint8x16_t _k = vld1q_u8(k0); + //! filter row 1 + uint8x16_t _idx = {0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3}; + uint8x16_t _k123 = vqtbl1q_s8_v7(_k, _idx); + _idx = {4, 5, 6, 16, 4, 5, 6, 16, 4, 5, 6, 16, 4, 5, 6, 16}; + uint8x16_t _k456 = vqtbl1q_s8_v7(_k, _idx); + //! filter row 2 + _idx = {7, 8, 9, 10, 7, 8, 9, 10, 7, 8, 9, 10, 7, 8, 9, 10}; + uint8x16_t _k78910 = vqtbl1q_s8_v7(_k, _idx); + _idx = {11, 12, 13, 16, 11, 12, 13, 16, 11, 12, 13, 16, 11, 12, 13, 16}; + uint8x16_t _k111213 = vqtbl1q_s8_v7(_k, _idx); + + //! 
12 13 14 15 -> 16 17 18 19 -> 20 21 22 23 -> 24 25 26 27 + _k = vld1q_u8(k0 + 12); + //! filter row 3 + _idx = {2, 3, 4, 5, 2, 3, 4, 5, 2, 3, 4, 5, 2, 3, 4, 5}; + uint8x16_t _k14151617 = vqtbl1q_s8_v7(_k, _idx); + _idx = {6, 7, 8, 16, 6, 7, 8, 16, 6, 7, 8, 16, 6, 7, 8, 16}; + uint8x16_t _k181920 = vqtbl1q_s8_v7(_k, _idx); + //! filter row 4 + _idx = {9, 10, 11, 12, 9, 10, 11, 12, 9, 10, 11, 12, 9, 10, 11, 12}; + uint8x16_t _k21222324 = vqtbl1q_s8_v7(_k, _idx); + _idx = {13, 14, 15, 16, 13, 14, 15, 16, 13, 14, 15, 16, 13, 14, 15, 16}; + uint8x16_t _k252627 = vqtbl1q_s8_v7(_k, _idx); + + //! 24 25 26 27->28 29 30 31 -> 32 33 34 35 -> 36 37 38 39 + _k = vld1q_u8(k0 + 24); + //! filter row 5 + _idx = {4, 5, 6, 7, 4, 5, 6, 7, 4, 5, 6, 7, 4, 5, 6, 7}; + uint8x16_t _k28293031 = vqtbl1q_s8_v7(_k, _idx); + _idx = {8, 9, 10, 16, 8, 9, 10, 16, 8, 9, 10, 16, 8, 9, 10, 16}; + uint8x16_t _k323334 = vqtbl1q_s8_v7(_k, _idx); + + //! 33 34 35 36 -> 37 38 39 40 -> 41 42 43 44 -> 45 46 47 48 + _k = vld1q_u8(k0 + 33); + //! filter row 6 + _idx = {2, 3, 4, 5, 2, 3, 4, 5, 2, 3, 4, 5, 2, 3, 4, 5}; + uint8x16_t _k35363738 = vqtbl1q_s8_v7(_k, _idx); + _idx = {6, 7, 8, 16, 6, 7, 8, 16, 6, 7, 8, 16, 6, 7, 8, 16}; + uint8x16_t _k394041 = vqtbl1q_s8_v7(_k, _idx); + + //! filter row 7 + _idx = {9, 10, 11, 12, 9, 10, 11, 12, 9, 10, 11, 12, 9, 10, 11, 12}; + uint8x16_t _k42434445 = vqtbl1q_s8_v7(_k, _idx); + _idx = {13, 14, 15, 16, 13, 14, 15, 16, 13, 14, 15, 16, 13, 14, 15, 16}; + uint8x16_t _k464748 = vqtbl1q_s8_v7(_k, _idx); + + const int width = OW >> 2; + size_t h = 0; + for (; h + 1 < OH; h += 2) { + int w = 0; + for (; w + 2 < width; w += 2) { + uint32x4_t _sum00, _sum01, _sum10, _sum11; + if (!first_ic) { + _sum00 = vld1q_u32(outptr); + _sum01 = vld1q_u32(outptr + 4); + _sum10 = vld1q_u32(outptr2); + _sum11 = vld1q_u32(outptr2 + 4); + } else { + _sum00 = vdupq_n_u32(SHIFT); + _sum01 = vdupq_n_u32(SHIFT); + _sum10 = vdupq_n_u32(SHIFT); + _sum11 = vdupq_n_u32(SHIFT); + } + + _tmp = vld1q_u8(r0); + CALC_0(123, 456, 0); + CALC_0(123, 456, 1); + + _tmp = vld1q_u8(r1); + CALC_2(78910, 111213, 123, 456, 0); + CALC_2(78910, 111213, 123, 456, 1); + + _tmp = vld1q_u8(r2); + CALC_2(14151617, 181920, 78910, 111213, 0); + CALC_2(14151617, 181920, 78910, 111213, 1); + + _tmp = vld1q_u8(r3); + CALC_2(21222324, 252627, 14151617, 181920, 0); + CALC_2(21222324, 252627, 14151617, 181920, 1); + + _tmp = vld1q_u8(r4); + CALC_2(28293031, 323334, 21222324, 252627, 0); + CALC_2(28293031, 323334, 21222324, 252627, 1); + + _tmp = vld1q_u8(r5); + CALC_2(35363738, 394041, 28293031, 323334, 0); + CALC_2(35363738, 394041, 28293031, 323334, 1); + + _tmp = vld1q_u8(r6); + CALC_2(42434445, 464748, 35363738, 394041, 0); + CALC_2(42434445, 464748, 35363738, 394041, 1); + + _tmp = vld1q_u8(r7); + CALC_1(42434445, 464748, 0); + CALC_1(42434445, 464748, 1); + + if (last_ic) { + CALC_DST(_sum00); + CALC_DST(_sum01); + CALC_DST(_sum10); + CALC_DST(_sum11); + } + POSTPROCESS_1X8(_sum00, _sum01, outptr, dstptr); + POSTPROCESS_1X8(_sum10, _sum11, outptr2, dstptr2); + + r0 += 8; + r1 += 8; + r2 += 8; + r3 += 8; + r4 += 8; + r5 += 8; + r6 += 8; + r7 += 8; + outptr += 8; + outptr2 += 8; + dstptr += 8; + dstptr2 += 8; + } + for (; w < width; w++) { + uint32x4_t _sum00, _sum10; + if (!first_ic) { + _sum00 = vld1q_u32(outptr); + _sum10 = vld1q_u32(outptr2); + } else { + _sum00 = vdupq_n_u32(SHIFT); + _sum10 = vdupq_n_u32(SHIFT); + } + + _tmp = vld1q_u8(r0); + CALC_0(123, 456, 0); + + _tmp = vld1q_u8(r1); + CALC_2(78910, 111213, 123, 456, 0); + + _tmp = 
vld1q_u8(r2); + CALC_2(14151617, 181920, 78910, 111213, 0); + + _tmp = vld1q_u8(r3); + CALC_2(21222324, 252627, 14151617, 181920, 0); + + _tmp = vld1q_u8(r4); + CALC_2(28293031, 323334, 21222324, 252627, 0); + + _tmp = vld1q_u8(r5); + CALC_2(35363738, 394041, 28293031, 323334, 0); + + _tmp = vld1q_u8(r6); + CALC_2(42434445, 464748, 35363738, 394041, 0); + + _tmp = vld1q_u8(r7); + CALC_1(42434445, 464748, 0); + + if (last_ic) { + CALC_DST(_sum00); + CALC_DST(_sum10); + } + POSTPROCESS_2X4(_sum00, _sum10, outptr, outptr2, dstptr, dstptr2); + + r0 += 4; + r1 += 4; + r2 += 4; + r3 += 4; + r4 += 4; + r5 += 4; + r6 += 4; + r7 += 4; + outptr += 4; + outptr2 += 4; + dstptr += 4; + dstptr2 += 4; + } + + r0 += tail_step + IW; + r1 += tail_step + IW; + r2 += tail_step + IW; + r3 += tail_step + IW; + r4 += tail_step + IW; + r5 += tail_step + IW; + r6 += tail_step + IW; + r7 += tail_step + IW; + + outptr += OW; + outptr2 += OW; + dstptr += OW; + dstptr2 += OW; + } + + for (; h < OH; h++) { + int w = 0; + for (; w + 2 < width; w += 2) { + uint32x4_t _sum00, _sum01; + if (!first_ic) { + _sum00 = vld1q_u32(outptr); + _sum01 = vld1q_u32(outptr + 4); + } else { + _sum00 = vdupq_n_u32(SHIFT); + _sum01 = vdupq_n_u32(SHIFT); + } + + _tmp = vld1q_u8(r0); + CALC_0(123, 456, 0); + CALC_0(123, 456, 1); + + _tmp = vld1q_u8(r1); + CALC_0(78910, 111213, 0); + CALC_0(78910, 111213, 1); + + _tmp = vld1q_u8(r2); + CALC_0(14151617, 181920, 0); + CALC_0(14151617, 181920, 1); + + _tmp = vld1q_u8(r3); + CALC_0(21222324, 252627, 0); + CALC_0(21222324, 252627, 1); + + _tmp = vld1q_u8(r4); + CALC_0(28293031, 323334, 0); + CALC_0(28293031, 323334, 1); + + _tmp = vld1q_u8(r5); + CALC_0(35363738, 394041, 0); + CALC_0(35363738, 394041, 1); + + _tmp = vld1q_u8(r6); + CALC_0(42434445, 464748, 0); + CALC_0(42434445, 464748, 1); + + if (last_ic) { + CALC_DST(_sum00); + CALC_DST(_sum01); + } + POSTPROCESS_1X8(_sum00, _sum01, outptr, dstptr); + + r0 += 8; + r1 += 8; + r2 += 8; + r3 += 8; + r4 += 8; + r5 += 8; + r6 += 8; + outptr += 8; + dstptr += 8; + } + for (; w < width; w++) { + uint32x4_t _sum00; + if (!first_ic) { + _sum00 = vld1q_u32(outptr); + } else { + _sum00 = vdupq_n_u32(SHIFT); + } + + _tmp = vld1q_u8(r0); + CALC_0(123, 456, 0); + + _tmp = vld1q_u8(r1); + CALC_0(78910, 111213, 0); + + _tmp = vld1q_u8(r2); + CALC_0(14151617, 181920, 0); + + _tmp = vld1q_u8(r3); + CALC_0(21222324, 252627, 0); + + _tmp = vld1q_u8(r4); + CALC_0(28293031, 323334, 0); + + _tmp = vld1q_u8(r5); + CALC_0(35363738, 394041, 0); + + _tmp = vld1q_u8(r6); + CALC_0(42434445, 464748, 0); + + if (last_ic) { + CALC_DST(_sum00); + } + POSTPROCESS_1X4(_sum00, outptr, dstptr); + + r0 += 4; + r1 += 4; + r2 += 4; + r3 += 4; + r4 += 4; + r5 += 4; + r6 += 4; + outptr += 4; + dstptr += 4; + } + r0 += tail_step; + r1 += tail_step; + r2 += tail_step; + r3 += tail_step; + r4 += tail_step; + r5 += tail_step; + r6 += tail_step; + } +} + +template +void conv_bias::conv_direct_stride2_5x5_quint8_dot( + const uint8_t* src, const uint8_t* filter, const int32_t* bias, + int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, + const size_t OH, const size_t OW, const uint8_t src_zp, + const uint8_t filter_zp, const int32_t src_filter_zp, const Op& op) { + MEGDNN_MARK_USED_VAR(IH); + const size_t tail_step = IW - 2 * OW + IW; + + uint8x16_t _src_zp = vdupq_n_u8(src_zp); + uint8x16_t _filter_zp = vdupq_n_u8(filter_zp); + int32x4_t _shift_zp; + if (bias_mode != BiasMode::NO_BIAS) { + _shift_zp = vdupq_n_s32(src_filter_zp - SHIFT + bias[0]); + } else { + _shift_zp = 
vdupq_n_s32(src_filter_zp - SHIFT); + } + + const uint8x16_t _idx00 = {0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7, 6, 7, 8, 9}; + const uint8x16_t _idx01 = {4, 16, 16, 16, 6, 16, 16, 16, + 8, 16, 16, 16, 10, 16, 16, 16}; + //! start from 8 + const uint8x16_t& _idx10 = _idx00; + const uint8x16_t& _idx11 = _idx01; + + uint8x16_t _tmp, _elem; + uint32x4_t _elem2; + uint32_t* outptr = reinterpret_cast(temp); + uint32_t* outptr2 = outptr + OW; + uint8_t* dstptr = dst; + uint8_t* dstptr2 = dstptr + OW; + + const uint8_t* r0 = src; + const uint8_t* r1 = src + IW; + const uint8_t* r2 = src + IW * 2; + const uint8_t* r3 = src + IW * 3; + const uint8_t* r4 = src + IW * 4; + const uint8_t* r5 = src + IW * 5; + const uint8_t* r6 = src + IW * 6; + + const uint8_t* k0 = filter; + + uint8x16_t _k = vld1q_u8(k0); + //! filter row 1 + uint8x16_t _idx = {0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3}; + uint8x16_t _k123 = vqtbl1q_s8_v7(_k, _idx); + _idx = {4, 16, 16, 16, 4, 16, 16, 16, 4, 16, 16, 16, 4, 16, 16, 16}; + uint8x16_t _k4 = vqtbl1q_s8_v7(_k, _idx); + //! filter row 2 + _idx = {5, 6, 7, 8, 5, 6, 7, 8, 5, 6, 7, 8, 5, 6, 7, 8}; + uint8x16_t _k5678 = vqtbl1q_s8_v7(_k, _idx); + _idx = {9, 16, 16, 16, 9, 16, 16, 16, 9, 16, 16, 16, 9, 16, 16, 16}; + uint8x16_t _k9 = vqtbl1q_s8_v7(_k, _idx); + //! filter row 3 + _idx = {10, 11, 12, 13, 10, 11, 12, 13, 10, 11, 12, 13, 10, 11, 12, 13}; + uint8x16_t _k10111213 = vqtbl1q_s8_v7(_k, _idx); + _idx = {14, 16, 16, 16, 14, 16, 16, 16, 14, 16, 16, 16, 14, 16, 16, 16}; + uint8x16_t _k14 = vqtbl1q_s8_v7(_k, _idx); + //! 9 10 11 12 -> 13 14 15 16 -> 17 18 19 20 -> 21 22 23 24 + _k = vld1q_u8(k0 + 9); + //! filter row 4 + _idx = {6, 7, 8, 9, 6, 7, 8, 9, 6, 7, 8, 9, 6, 7, 8, 9}; + uint8x16_t _k15161718 = vqtbl1q_s8_v7(_k, _idx); + _idx = {10, 16, 16, 16, 10, 16, 16, 16, 10, 16, 16, 16, 10, 16, 16, 16}; + uint8x16_t _k19 = vqtbl1q_s8_v7(_k, _idx); + //! 
filter row 5 + _idx = {11, 12, 13, 14, 11, 12, 13, 14, 11, 12, 13, 14, 11, 12, 13, 14}; + uint8x16_t _k20212223 = vqtbl1q_s8_v7(_k, _idx); + _idx = {15, 16, 16, 16, 15, 16, 16, 16, 15, 16, 16, 16, 15, 16, 16, 16}; + uint8x16_t _k24 = vqtbl1q_s8_v7(_k, _idx); + + const int width = OW >> 2; + size_t h = 0; + for (; h + 1 < OH; h += 2) { + int w = 0; + for (; w + 2 < width; w += 2) { + uint32x4_t _sum00, _sum01, _sum10, _sum11; + if (!first_ic) { + _sum00 = vld1q_u32(outptr); + _sum01 = vld1q_u32(outptr + 4); + _sum10 = vld1q_u32(outptr2); + _sum11 = vld1q_u32(outptr2 + 4); + } else { + _sum00 = vdupq_n_u32(SHIFT); + _sum01 = vdupq_n_u32(SHIFT); + _sum10 = vdupq_n_u32(SHIFT); + _sum11 = vdupq_n_u32(SHIFT); + } + + _tmp = vld1q_u8(r0); + CALC_0(123, 4, 0); + _tmp = vld1q_u8(r0 + 8); + CALC_0(123, 4, 1); + + _tmp = vld1q_u8(r1); + CALC_0(5678, 9, 0); + _tmp = vld1q_u8(r1 + 8); + CALC_0(5678, 9, 1); + + _tmp = vld1q_u8(r2); + CALC_2(10111213, 14, 123, 4, 0); + _tmp = vld1q_u8(r2 + 8); + CALC_2(10111213, 14, 123, 4, 1); + + _tmp = vld1q_u8(r3); + CALC_2(15161718, 19, 5678, 9, 0); + _tmp = vld1q_u8(r3 + 8); + CALC_2(15161718, 19, 5678, 9, 1); + + _tmp = vld1q_u8(r4); + CALC_2(20212223, 24, 10111213, 14, 0); + _tmp = vld1q_u8(r4 + 8); + CALC_2(20212223, 24, 10111213, 14, 1); + + _tmp = vld1q_u8(r5); + CALC_1(15161718, 19, 0); + _tmp = vld1q_u8(r5 + 8); + CALC_1(15161718, 19, 1); + + _tmp = vld1q_u8(r6); + CALC_1(20212223, 24, 0); + _tmp = vld1q_u8(r6 + 8); + CALC_1(20212223, 24, 1); + + if (last_ic) { + CALC_DST(_sum00); + CALC_DST(_sum01); + CALC_DST(_sum10); + CALC_DST(_sum11); + } + POSTPROCESS_1X8(_sum00, _sum01, outptr, dstptr); + POSTPROCESS_1X8(_sum10, _sum11, outptr2, dstptr2); + + r0 += 16; + r1 += 16; + r2 += 16; + r3 += 16; + r4 += 16; + r5 += 16; + r6 += 16; + outptr += 8; + outptr2 += 8; + dstptr += 8; + dstptr2 += 8; + } + for (; w < width; w++) { + uint32x4_t _sum00, _sum10; + if (!first_ic) { + _sum00 = vld1q_u32(outptr); + _sum10 = vld1q_u32(outptr2); + } else { + _sum00 = vdupq_n_u32(SHIFT); + _sum10 = vdupq_n_u32(SHIFT); + } + + _tmp = vld1q_u8(r0); + CALC_0(123, 4, 0); + + _tmp = vld1q_u8(r1); + CALC_0(5678, 9, 0); + + _tmp = vld1q_u8(r2); + CALC_2(10111213, 14, 123, 4, 0); + + _tmp = vld1q_u8(r3); + CALC_2(15161718, 19, 5678, 9, 0); + + _tmp = vld1q_u8(r4); + CALC_2(20212223, 24, 10111213, 14, 0); + + _tmp = vld1q_u8(r5); + CALC_1(15161718, 19, 0); + + _tmp = vld1q_u8(r6); + CALC_1(20212223, 24, 0); + + if (last_ic) { + CALC_DST(_sum00); + CALC_DST(_sum10); + } + POSTPROCESS_2X4(_sum00, _sum10, outptr, outptr2, dstptr, dstptr2); + + r0 += 8; + r1 += 8; + r2 += 8; + r3 += 8; + r4 += 8; + r5 += 8; + r6 += 8; + outptr += 4; + outptr2 += 4; + dstptr += 4; + dstptr2 += 4; + } + + r0 += tail_step + IW * 2; + r1 += tail_step + IW * 2; + r2 += tail_step + IW * 2; + r3 += tail_step + IW * 2; + r4 += tail_step + IW * 2; + r5 += tail_step + IW * 2; + r6 += tail_step + IW * 2; + + outptr += OW; + outptr2 += OW; + dstptr += OW; + dstptr2 += OW; + } + + for (; h < OH; h++) { + int w = 0; + for (; w + 2 < width; w += 2) { + uint32x4_t _sum00, _sum01; + if (!first_ic) { + _sum00 = vld1q_u32(outptr); + _sum01 = vld1q_u32(outptr + 4); + } else { + _sum00 = vdupq_n_u32(SHIFT); + _sum01 = vdupq_n_u32(SHIFT); + } + + _tmp = vld1q_u8(r0); + CALC_0(123, 4, 0); + _tmp = vld1q_u8(r0 + 8); + CALC_0(123, 4, 1); + + _tmp = vld1q_u8(r1); + CALC_0(5678, 9, 0); + _tmp = vld1q_u8(r1 + 8); + CALC_0(5678, 9, 1); + + _tmp = vld1q_u8(r2); + CALC_0(10111213, 14, 0); + _tmp = vld1q_u8(r2 + 8); + CALC_0(10111213, 
14, 1); + + _tmp = vld1q_u8(r3); + CALC_0(15161718, 19, 0); + _tmp = vld1q_u8(r3 + 8); + CALC_0(15161718, 19, 1); + + _tmp = vld1q_u8(r4); + CALC_0(20212223, 24, 0); + _tmp = vld1q_u8(r4 + 8); + CALC_0(20212223, 24, 1); + + if (last_ic) { + CALC_DST(_sum00); + CALC_DST(_sum01); + } + POSTPROCESS_1X8(_sum00, _sum01, outptr, dstptr); + + r0 += 16; + r1 += 16; + r2 += 16; + r3 += 16; + r4 += 16; + outptr += 8; + dstptr += 8; + } + for (; w < width; w++) { + uint32x4_t _sum00; + if (!first_ic) { + _sum00 = vld1q_u32(outptr); + } else { + _sum00 = vdupq_n_u32(SHIFT); + } + + _tmp = vld1q_u8(r0); + CALC_0(123, 4, 0); + + _tmp = vld1q_u8(r1); + CALC_0(5678, 9, 0); + + _tmp = vld1q_u8(r2); + CALC_0(10111213, 14, 0); + + _tmp = vld1q_u8(r3); + CALC_0(15161718, 19, 0); + + _tmp = vld1q_u8(r4); + CALC_0(20212223, 24, 0); + + if (last_ic) { + CALC_DST(_sum00); + } + POSTPROCESS_1X4(_sum00, outptr, dstptr); + + r0 += 8; + r1 += 8; + r2 += 8; + r3 += 8; + r4 += 8; + outptr += 4; + dstptr += 4; + } + r0 += tail_step; + r1 += tail_step; + r2 += tail_step; + r3 += tail_step; + r4 += tail_step; + } +} + +template +void conv_bias::conv_direct_stride2_7x7_quint8_dot( + const uint8_t* src, const uint8_t* filter, const int32_t* bias, + int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, + const size_t OH, const size_t OW, const uint8_t src_zp, + const uint8_t filter_zp, const int32_t src_filter_zp, const Op& op) { + MEGDNN_MARK_USED_VAR(IH); + const size_t tail_step = IW - 2 * OW + IW; + + uint8x16_t _src_zp = vdupq_n_u8(src_zp); + uint8x16_t _filter_zp = vdupq_n_u8(filter_zp); + int32x4_t _shift_zp; + if (bias_mode != BiasMode::NO_BIAS) { + _shift_zp = vdupq_n_s32(src_filter_zp - SHIFT + bias[0]); + } else { + _shift_zp = vdupq_n_s32(src_filter_zp - SHIFT); + } + + const uint8x16_t _idx00 = {0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7, 6, 7, 8, 9}; + const uint8x16_t _idx01 = {4, 5, 6, 16, 6, 7, 8, 16, + 8, 9, 10, 16, 10, 11, 12, 16}; + //! start from 8 + const uint8x16_t& _idx10 = _idx00; + const uint8x16_t& _idx11 = _idx01; + + uint8x16_t _tmp, _elem; + uint32x4_t _elem2; + uint32_t* outptr = reinterpret_cast(temp); + uint32_t* outptr2 = outptr + OW; + uint8_t* dstptr = dst; + uint8_t* dstptr2 = dstptr + OW; + + const uint8_t* r0 = src; + const uint8_t* r1 = src + IW; + const uint8_t* r2 = src + IW * 2; + const uint8_t* r3 = src + IW * 3; + const uint8_t* r4 = src + IW * 4; + const uint8_t* r5 = src + IW * 5; + const uint8_t* r6 = src + IW * 6; + const uint8_t* r7 = src + IW * 7; + const uint8_t* r8 = src + IW * 8; + + const uint8_t* k0 = filter; + + uint8x16_t _k = vld1q_u8(k0); + //! filter row 1 + uint8x16_t _idx = {0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3}; + uint8x16_t _k123 = vqtbl1q_s8_v7(_k, _idx); + _idx = {4, 5, 6, 16, 4, 5, 6, 16, 4, 5, 6, 16, 4, 5, 6, 16}; + uint8x16_t _k456 = vqtbl1q_s8_v7(_k, _idx); + //! filter row 2 + _idx = {7, 8, 9, 10, 7, 8, 9, 10, 7, 8, 9, 10, 7, 8, 9, 10}; + uint8x16_t _k78910 = vqtbl1q_s8_v7(_k, _idx); + _idx = {11, 12, 13, 16, 11, 12, 13, 16, 11, 12, 13, 16, 11, 12, 13, 16}; + uint8x16_t _k111213 = vqtbl1q_s8_v7(_k, _idx); + + //! 12 13 14 15 -> 16 17 18 19 -> 20 21 22 23 -> 24 25 26 27 + _k = vld1q_u8(k0 + 12); + //! filter row 3 + _idx = {2, 3, 4, 5, 2, 3, 4, 5, 2, 3, 4, 5, 2, 3, 4, 5}; + uint8x16_t _k14151617 = vqtbl1q_s8_v7(_k, _idx); + _idx = {6, 7, 8, 16, 6, 7, 8, 16, 6, 7, 8, 16, 6, 7, 8, 16}; + uint8x16_t _k181920 = vqtbl1q_s8_v7(_k, _idx); + //! 
filter row 4 + _idx = {9, 10, 11, 12, 9, 10, 11, 12, 9, 10, 11, 12, 9, 10, 11, 12}; + uint8x16_t _k21222324 = vqtbl1q_s8_v7(_k, _idx); + _idx = {13, 14, 15, 16, 13, 14, 15, 16, 13, 14, 15, 16, 13, 14, 15, 16}; + uint8x16_t _k252627 = vqtbl1q_s8_v7(_k, _idx); + + //! 24 25 26 27->28 29 30 31 -> 32 33 34 35 -> 36 37 38 39 + _k = vld1q_u8(k0 + 24); + //! filter row 5 + _idx = {4, 5, 6, 7, 4, 5, 6, 7, 4, 5, 6, 7, 4, 5, 6, 7}; + uint8x16_t _k28293031 = vqtbl1q_s8_v7(_k, _idx); + _idx = {8, 9, 10, 16, 8, 9, 10, 16, 8, 9, 10, 16, 8, 9, 10, 16}; + uint8x16_t _k323334 = vqtbl1q_s8_v7(_k, _idx); + + //! 33 34 35 36 -> 37 38 39 40 -> 41 42 43 44 -> 45 46 47 48 + _k = vld1q_u8(k0 + 33); + //! filter row 6 + _idx = {2, 3, 4, 5, 2, 3, 4, 5, 2, 3, 4, 5, 2, 3, 4, 5}; + uint8x16_t _k35363738 = vqtbl1q_s8_v7(_k, _idx); + _idx = {6, 7, 8, 16, 6, 7, 8, 16, 6, 7, 8, 16, 6, 7, 8, 16}; + uint8x16_t _k394041 = vqtbl1q_s8_v7(_k, _idx); + + //! filter row 7 + _idx = {9, 10, 11, 12, 9, 10, 11, 12, 9, 10, 11, 12, 9, 10, 11, 12}; + uint8x16_t _k42434445 = vqtbl1q_s8_v7(_k, _idx); + _idx = {13, 14, 15, 16, 13, 14, 15, 16, 13, 14, 15, 16, 13, 14, 15, 16}; + uint8x16_t _k464748 = vqtbl1q_s8_v7(_k, _idx); + + const int width = OW >> 2; + size_t h = 0; + for (; h + 1 < OH; h += 2) { + int w = 0; + for (; w + 2 < width; w += 2) { + uint32x4_t _sum00, _sum01, _sum10, _sum11; + if (!first_ic) { + _sum00 = vld1q_u32(outptr); + _sum01 = vld1q_u32(outptr + 4); + _sum10 = vld1q_u32(outptr2); + _sum11 = vld1q_u32(outptr2 + 4); + } else { + _sum00 = vdupq_n_u32(SHIFT); + _sum01 = vdupq_n_u32(SHIFT); + _sum10 = vdupq_n_u32(SHIFT); + _sum11 = vdupq_n_u32(SHIFT); + } + + _tmp = vld1q_u8(r0); + CALC_0(123, 456, 0); + _tmp = vld1q_u8(r0 + 8); + CALC_0(123, 456, 1); + + _tmp = vld1q_u8(r1); + CALC_0(78910, 111213, 0); + _tmp = vld1q_u8(r1 + 8); + CALC_0(78910, 111213, 1); + + _tmp = vld1q_u8(r2); + CALC_2(14151617, 181920, 123, 456, 0); + _tmp = vld1q_u8(r2 + 8); + CALC_2(14151617, 181920, 123, 456, 1); + + _tmp = vld1q_u8(r3); + CALC_2(21222324, 252627, 78910, 111213, 0); + _tmp = vld1q_u8(r3 + 8); + CALC_2(21222324, 252627, 78910, 111213, 1); + + _tmp = vld1q_u8(r4); + CALC_2(28293031, 323334, 14151617, 181920, 0); + _tmp = vld1q_u8(r4 + 8); + CALC_2(28293031, 323334, 14151617, 181920, 1); + + _tmp = vld1q_u8(r5); + CALC_2(35363738, 394041, 21222324, 252627, 0); + _tmp = vld1q_u8(r5 + 8); + CALC_2(35363738, 394041, 21222324, 252627, 1); + + _tmp = vld1q_u8(r6); + CALC_2(42434445, 464748, 28293031, 323334, 0); + _tmp = vld1q_u8(r6 + 8); + CALC_2(42434445, 464748, 28293031, 323334, 1); + + _tmp = vld1q_u8(r7); + CALC_1(35363738, 394041, 0); + _tmp = vld1q_u8(r7 + 8); + CALC_1(35363738, 394041, 1); + + _tmp = vld1q_u8(r8); + CALC_1(42434445, 464748, 0); + _tmp = vld1q_u8(r8 + 8); + CALC_1(42434445, 464748, 1); + + if (last_ic) { + CALC_DST(_sum00); + CALC_DST(_sum01); + CALC_DST(_sum10); + CALC_DST(_sum11); + } + POSTPROCESS_1X8(_sum00, _sum01, outptr, dstptr); + POSTPROCESS_1X8(_sum10, _sum11, outptr2, dstptr2); + + r0 += 16; + r1 += 16; + r2 += 16; + r3 += 16; + r4 += 16; + r5 += 16; + r6 += 16; + r7 += 16; + r8 += 16; + outptr += 8; + outptr2 += 8; + dstptr += 8; + dstptr2 += 8; + } + for (; w < width; w++) { + uint32x4_t _sum00, _sum10; + if (!first_ic) { + _sum00 = vld1q_u32(outptr); + _sum10 = vld1q_u32(outptr2); + } else { + _sum00 = vdupq_n_u32(SHIFT); + _sum10 = vdupq_n_u32(SHIFT); + } + + _tmp = vld1q_u8(r0); + CALC_0(123, 456, 0); + + _tmp = vld1q_u8(r1); + CALC_0(78910, 111213, 0); + + _tmp = vld1q_u8(r2); + CALC_2(14151617, 
181920, 123, 456, 0); + + _tmp = vld1q_u8(r3); + CALC_2(21222324, 252627, 78910, 111213, 0); + + _tmp = vld1q_u8(r4); + CALC_2(28293031, 323334, 14151617, 181920, 0); + + _tmp = vld1q_u8(r5); + CALC_2(35363738, 394041, 21222324, 252627, 0); + + _tmp = vld1q_u8(r6); + CALC_2(42434445, 464748, 28293031, 323334, 0); + + _tmp = vld1q_u8(r7); + CALC_1(35363738, 394041, 0); + + _tmp = vld1q_u8(r8); + CALC_1(42434445, 464748, 0); + + if (last_ic) { + CALC_DST(_sum00); + CALC_DST(_sum10); + } + POSTPROCESS_2X4(_sum00, _sum10, outptr, outptr2, dstptr, dstptr2); + + r0 += 8; + r1 += 8; + r2 += 8; + r3 += 8; + r4 += 8; + r5 += 8; + r6 += 8; + r7 += 8; + r8 += 8; + outptr += 4; + outptr2 += 4; + dstptr += 4; + dstptr2 += 4; + } + + r0 += tail_step + IW * 2; + r1 += tail_step + IW * 2; + r2 += tail_step + IW * 2; + r3 += tail_step + IW * 2; + r4 += tail_step + IW * 2; + r5 += tail_step + IW * 2; + r6 += tail_step + IW * 2; + r7 += tail_step + IW * 2; + r8 += tail_step + IW * 2; + + outptr += OW; + outptr2 += OW; + dstptr += OW; + dstptr2 += OW; + } + + for (; h < OH; h++) { + int w = 0; + for (; w + 2 < width; w += 2) { + uint32x4_t _sum00, _sum01; + if (!first_ic) { + _sum00 = vld1q_u32(outptr); + _sum01 = vld1q_u32(outptr + 4); + } else { + _sum00 = vdupq_n_u32(SHIFT); + _sum01 = vdupq_n_u32(SHIFT); + } + + _tmp = vld1q_u8(r0); + CALC_0(123, 456, 0); + _tmp = vld1q_u8(r0 + 8); + CALC_0(123, 456, 1); + + _tmp = vld1q_u8(r1); + CALC_0(78910, 111213, 0); + _tmp = vld1q_u8(r1 + 8); + CALC_0(78910, 111213, 1); + + _tmp = vld1q_u8(r2); + CALC_0(14151617, 181920, 0); + _tmp = vld1q_u8(r2 + 8); + CALC_0(14151617, 181920, 1); + + _tmp = vld1q_u8(r3); + CALC_0(21222324, 252627, 0); + _tmp = vld1q_u8(r3 + 8); + CALC_0(21222324, 252627, 1); + + _tmp = vld1q_u8(r4); + CALC_0(28293031, 323334, 0); + _tmp = vld1q_u8(r4 + 8); + CALC_0(28293031, 323334, 1); + + _tmp = vld1q_u8(r5); + CALC_0(35363738, 394041, 0); + _tmp = vld1q_u8(r5 + 8); + CALC_0(35363738, 394041, 1); + + _tmp = vld1q_u8(r6); + CALC_0(42434445, 464748, 0); + _tmp = vld1q_u8(r6 + 8); + CALC_0(42434445, 464748, 1); + + if (last_ic) { + CALC_DST(_sum00); + CALC_DST(_sum01); + } + POSTPROCESS_1X8(_sum00, _sum01, outptr, dstptr); + + r0 += 16; + r1 += 16; + r2 += 16; + r3 += 16; + r4 += 16; + r5 += 16; + r6 += 16; + outptr += 8; + dstptr += 8; + } + for (; w < width; w++) { + uint32x4_t _sum00; + if (!first_ic) { + _sum00 = vld1q_u32(outptr); + } else { + _sum00 = vdupq_n_u32(SHIFT); + } + + _tmp = vld1q_u8(r0); + CALC_0(123, 456, 0); + + _tmp = vld1q_u8(r1); + CALC_0(78910, 111213, 0); + + _tmp = vld1q_u8(r2); + CALC_0(14151617, 181920, 0); + + _tmp = vld1q_u8(r3); + CALC_0(21222324, 252627, 0); + + _tmp = vld1q_u8(r4); + CALC_0(28293031, 323334, 0); + + _tmp = vld1q_u8(r5); + CALC_0(35363738, 394041, 0); + + _tmp = vld1q_u8(r6); + CALC_0(42434445, 464748, 0); + + if (last_ic) { + CALC_DST(_sum00); + } + POSTPROCESS_1X4(_sum00, outptr, dstptr); + + r0 += 8; + r1 += 8; + r2 += 8; + r3 += 8; + r4 += 8; + r5 += 8; + r6 += 8; + outptr += 4; + dstptr += 4; + } + r0 += tail_step; + r1 += tail_step; + r2 += tail_step; + r3 += tail_step; + r4 += tail_step; + r5 += tail_step; + r6 += tail_step; + } +} + +#undef CALC_0 +#undef CALC_1 +#undef CALC_2 + +#undef POSTPROCESS_1X8 +#undef POSTPROCESS2_1X8 +#undef POSTPROCESS_2X4 +#undef POSTPROCESS_1X4 +#undef POSTPROCESS_1X1 +#undef ST1_S32X4 +#undef ST2_S32X4X2 + +#define INSTANTIATION(stride, i, first_ic, last_ic, fused_kern, bias, Op) \ + template void conv_bias::conv_direct_##stride##_##i##x##i##_quint8_dot< \ + 
first_ic, last_ic, fused_kern, bias, Op>( \ + const uint8_t*, const uint8_t*, const int32_t*, int32_t*, \ + uint8_t*, const size_t, const size_t, const size_t, const size_t, \ + const uint8_t, const uint8_t, const int32_t, const Op&); + +#define FOR_NONLINEAR(stride, i, first_ic, last_ic, fused_kern, bias) \ + INSTANTIATION(stride, i, first_ic, last_ic, fused_kern, bias, \ + TypeCvtOp) \ + INSTANTIATION(stride, i, first_ic, last_ic, fused_kern, bias, \ + ReluOp) \ + INSTANTIATION(stride, i, first_ic, last_ic, fused_kern, bias, \ + HSwishOp) + +#define FOR_BIAS(stride, i, first_ic, last_ic, fused_kern) \ + FOR_NONLINEAR(stride, i, first_ic, last_ic, fused_kern, BiasMode::NO_BIAS) \ + FOR_NONLINEAR(stride, i, first_ic, last_ic, fused_kern, \ + BiasMode::BROADCAST_CHANNEL_BIAS) + +#define FOR_KERN(stride, i, first_ic, last_ic) \ + FOR_BIAS(stride, i, first_ic, last_ic, true) \ + FOR_BIAS(stride, i, first_ic, last_ic, false) + +#define FOR_IC(stride, i) \ + FOR_KERN(stride, i, true, true) \ + FOR_KERN(stride, i, true, false) \ + FOR_KERN(stride, i, false, false) \ + FOR_KERN(stride, i, false, true) + +#define FOR_FILTER(stride) \ + FOR_IC(stride, 2) \ + FOR_IC(stride, 3) \ + FOR_IC(stride, 5) \ + FOR_IC(stride, 7) + +#define FOR_STRIDE \ + FOR_FILTER(stride1) \ + FOR_FILTER(stride2) + +FOR_STRIDE + +#undef FOR_STRIDE +#undef FOR_FILTER +#undef FOR_IC +#undef FOR_BIAS +#undef FOR_NONLINEAR +#undef INSTANTIATION + +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/quint8/direct_dotprod.h b/dnn/src/arm_common/conv_bias/quint8/direct_dotprod.h new file mode 100644 index 00000000..4ba49edb --- /dev/null +++ b/dnn/src/arm_common/conv_bias/quint8/direct_dotprod.h @@ -0,0 +1,46 @@ +/** + * \file dnn/src/arm_common/conv_bias/quint8/direct_dotprod.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#if __ARM_FEATURE_DOTPROD + +#include "src/arm_common/conv_bias/opr_impl.h" +#include "src/fallback/conv_bias/common.h" + +namespace megdnn { +namespace arm_common { +namespace conv_bias { + +#define KERN(stride, i) \ + template \ + void conv_direct_##stride##_##i##x##i##_quint8_dot( \ + const uint8_t* src, const uint8_t* filter, const int32_t* bias, \ + int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, \ + const size_t OH, const size_t OW, const uint8_t src_zp, \ + const uint8_t filter_zp, const int32_t src_filter_zp, \ + const Op& op); + +KERN(stride1, 2) +KERN(stride1, 3) +KERN(stride1, 5) +KERN(stride1, 7) + +KERN(stride2, 2) +KERN(stride2, 3) +KERN(stride2, 5) +KERN(stride2, 7) + +#undef KERN + +} // namesapce conv_bias +} // namespace arm_common +} // namespace megdnn +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/quint8/stride1.cpp b/dnn/src/arm_common/conv_bias/quint8/stride1.cpp new file mode 100644 index 00000000..71477490 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/quint8/stride1.cpp @@ -0,0 +1,369 @@ +/** + * \file dnn/src/arm_common/conv_bias/quint8/stride1.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "src/arm_common/conv_bias/quint8/stride1.h" +#include "megdnn/oprs.h" +#include "src/arm_common/conv_bias/quint8/direct.h" +#include "src/arm_common/elemwise_op.h" +#include "src/common/opr_delegate.h" + +using namespace megdnn; +using namespace arm_common; +using namespace direct_quint8_stride1; + +namespace { +bool need_dst_copy( + const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { + return param.osz[1] % 8; +} +bool need_src_copy( + const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { + if (param.filter_meta.padding[0] || param.filter_meta.padding[1]) { + return true; + } + return need_dst_copy(param); +} +void get_rectified_size( + const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, + size_t& IH2, size_t& IW2, size_t& OH2, size_t& OW2) { + auto&& fm = param.filter_meta; + auto SW = fm.stride[1]; + auto OH = param.osz[0]; + auto OW = param.osz[1]; + auto FH = fm.spatial[0]; + auto FW = fm.spatial[1]; + + OH2 = OH; + OW2 = (OW + 7) & ~7; + IH2 = SW * OH + FH - SW; + IW2 = SW * OW2 + FW - SW; +} +} // namespace + +bool direct_quint8_stride1::can_conv_direct_stride1_quint8( + const NCBKernSizeParam& param) { + auto&& fm = param.filter_meta; + auto FH = fm.spatial[0]; + auto OC = fm.ocpg; + auto IC = fm.icpg; + bool avaible = + param.src_type.enumv() == DTypeEnum::Quantized8Asymm && + param.filter_type.enumv() == DTypeEnum::Quantized8Asymm && + (param.dst_type.enumv() == DTypeEnum::QuantizedS32 || + param.dst_type.enumv() == DTypeEnum::Quantized8Asymm) && + fm.format == param::Convolution::Format::NCHW && !fm.should_flip && + fm.spatial_ndim == 2 && fm.dilation[0] == 1 && + fm.dilation[1] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1 && + FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5 || FH == 7); + if (param.bias_type.valid()) { + avaible &= param.bias_type.enumv() == DTypeEnum::QuantizedS32; + } + bool preferred = ((FH == 2 && (OC <= 10 || IC <= 8)) || + ((FH == 3 || FH == 5 || FH == 7) && + (OC <= 16 || (IC <= 4 && OC <= 32)))) && + param.bias_mode != BiasMode::BIAS; + return avaible && preferred; +} + +WorkspaceBundle direct_quint8_stride1::get_bundle( + const ConvBiasImpl::NCBKernSizeParam& param, bool m_large_group) { + auto&& fm = param.filter_meta; + size_t nr_threads = param.nr_threads; + size_t group = fm.group, batch = param.n; + size_t IC = fm.icpg; + size_t IH2, IW2, OH2, OW2; + get_rectified_size(param, IH2, IW2, OH2, OW2); + size_t src_size = 0, dst_size = 0; + if (need_src_copy(param)) { + src_size = m_large_group + ? IC * IH2 * IW2 * sizeof(uint8_t) * nr_threads + : IC * IH2 * IW2 * sizeof(uint8_t) * group * batch; + }; + if (need_dst_copy(param)) { + dst_size = OH2 * OW2 * param.dst_type.size() * nr_threads; + } + if (IC > 1) { + size_t temp_size = OH2 * OW2 * sizeof(int32_t) * nr_threads; + return {nullptr, {src_size, dst_size, temp_size}}; + } else { + return {nullptr, {src_size, dst_size}}; + }; +} +//! 
Process one input channel copy padding +void direct_quint8_stride1::copy_padding_kern( + WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids) { + size_t IH = kern_param.isz[0]; + size_t IW = kern_param.isz[1]; + size_t IC = kern_param.filter_meta.icpg; + size_t PH = kern_param.filter_meta.padding[0]; + size_t PW = kern_param.filter_meta.padding[1]; + size_t GROUP = kern_param.filter_meta.group; + + size_t IH2, IW2, OH2, OW2; + get_rectified_size(kern_param, IH2, IW2, OH2, OW2); + bool need_src_copy_var = need_src_copy(kern_param); + size_t padding_group_size = IH2 * IW2 * IC; + + bundle.set(kern_param.workspace_ptr); + + //! Used for get the workspace offset + size_t workspace_group_id = workspace_ids[0], + workspace_batch_id = workspace_ids[1], + channel_id = workspace_ids[2]; + size_t group_id = ncb_index.ndrange_id[0], + batch_id = ncb_index.ndrange_id[1]; + + const uint8_t* sptr = + kern_param.src(batch_id, group_id, channel_id); + if (need_src_copy_var) { + //! copy to sptr_base to eliminate padding effect + uint8_t* sptr_base = static_cast(bundle.get(0)) + + workspace_group_id * padding_group_size + + workspace_batch_id * GROUP * padding_group_size + + channel_id * IH2 * IW2; + uint8_t _src_zp = + kern_param.src_type.param().zero_point; + std::memset(sptr_base, _src_zp, sizeof(uint8_t) * IH2 * IW2); + rep(ih, IH) { + std::memcpy(sptr_base + (ih + PH) * IW2 + PW, sptr + ih * IW, + sizeof(uint8_t) * IW); + } + } +}; +//! compute one output channel +template +void direct_quint8_stride1::do_conv_kern(WorkspaceBundle bundle, + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids) { + size_t OH = kern_param.osz[0]; + size_t OW = kern_param.osz[1]; + size_t FH = kern_param.filter_meta.spatial[0]; + size_t FW = kern_param.filter_meta.spatial[1]; + size_t IC = kern_param.filter_meta.icpg; + size_t GROUP = kern_param.filter_meta.group; + size_t IH2, IW2, OH2, OW2; + get_rectified_size(kern_param, IH2, IW2, OH2, OW2); + bool need_src_copy_var = need_src_copy(kern_param); + bool need_dst_copy_var = need_dst_copy(kern_param); + bool need_post_process = + (kern_param.dst_type.enumv() == DTypeEnum::Quantized8Asymm); + +#define SUB128(n) static_cast(static_cast(n) - 128) + uint8_t _src_zp = + kern_param.src_type.param().zero_point; + int8_t src_zp = SUB128(_src_zp); + int8_t filter_zp = SUB128( + kern_param.filter_type.param().zero_point); + int32_t src_filter_zp = static_cast(filter_zp) * + static_cast(src_zp) * IC * FH * FW; +#undef SUB128 + Op op = Op(1.0f, 1.0f, 0); + if (need_post_process) { + float scale_bias = + kern_param.bias_type.param().scale; + float scale_dst = + kern_param.dst_type.param().scale; + uint8_t dst_zp = + kern_param.dst_type.param().zero_point; + op = Op(scale_bias, scale_dst, dst_zp); + } + size_t padding_group_size = IH2 * IW2 * IC; + + bundle.set(kern_param.workspace_ptr); + //! 
Used for get the workspace offset + size_t workspace_group_id = workspace_ids[0], + workspace_batch_id = workspace_ids[1], oc = workspace_ids[2], + group_id = ncb_index.ndrange_id[0], + batch_id = ncb_index.ndrange_id[1]; + + const uint8_t* sptr = kern_param.src(batch_id, group_id); + const uint8_t* fptr = kern_param.filter(group_id) + oc * FH * FW * IC; + void* dst = kern_param.dst(batch_id, group_id, oc); + const int32_t* bptr = kern_param.bias(batch_id, group_id, oc); + if (need_src_copy_var) { + sptr = static_cast(bundle.get(0)) + + workspace_group_id * padding_group_size + + workspace_batch_id * GROUP * padding_group_size; + } + void* dptr = nullptr; + int32_t* tptr = nullptr; + if (need_dst_copy_var) { + dptr = reinterpret_cast( + reinterpret_cast(bundle.get(1)) + + ncb_index.thread_id * OH2 * OW2 * kern_param.dst_type.size()); + } else { + dptr = dst; + } + +#define KERN0_NEED_POST_PROCESS(filter, first_ic, last_ic) \ + conv_bias::conv_direct_stride1_##filter##x##filter##_quint8< \ + first_ic, last_ic, bias_mode, Op>( \ + sptr + ic * IH2 * IW2, fptr + ic * FH * FW, bptr, tptr, \ + static_cast(dptr), IH2, IW2, OH2, OW2, src_zp, \ + filter_zp, src_filter_zp, op) + +#define KERN0_NO_POST_PROCESS(filter, first_ic, last_ic) \ + conv_bias::conv_direct_stride1_##filter##x##filter##_quint8< \ + first_ic, last_ic, bias_mode, Op>( \ + sptr + ic * IH2 * IW2, fptr + ic * FH * FW, bptr, \ + static_cast(dptr), nullptr, IH2, IW2, OH2, OW2, src_zp, \ + filter_zp, src_filter_zp, op) + +#define KERN1_NEED_POST_PROCESS(filter) \ + KERN0_NEED_POST_PROCESS(filter, true, false); \ + for (ic = 1; ic < IC - 1; ++ic) { \ + KERN0_NEED_POST_PROCESS(filter, false, false); \ + } \ + KERN0_NEED_POST_PROCESS(filter, false, true); + +#define KERN1_NO_POST_PROCESS(filter) \ + KERN0_NO_POST_PROCESS(filter, true, false); \ + for (ic = 1; ic < IC; ++ic) { \ + KERN0_NO_POST_PROCESS(filter, false, false); \ + } + if (need_post_process) { + size_t ic = 0; + if (IC == 1) { + DISPATCH_FILTER(filter, KERN0_NEED_POST_PROCESS, true, true) + } else { + tptr = static_cast(bundle.get(2)) + + ncb_index.thread_id * OH2 * OW2 * kern_param.dst_type.size(); + DISPATCH_FILTER(filter, KERN1_NEED_POST_PROCESS) + } + } else { + size_t ic = 0; + if (IC == 1) { + DISPATCH_FILTER(filter, KERN0_NO_POST_PROCESS, true, false) + } else { + DISPATCH_FILTER(filter, KERN1_NO_POST_PROCESS) + } + } +#undef KERN0 +#undef KERN1_NEED_POST_PROCESS +#undef KERN1_NO_POST_PROCESS + if (need_dst_copy_var) { + rep(oh, OH) { + std::memcpy(reinterpret_cast( + reinterpret_cast(dst) + + oh * OW * kern_param.dst_type.size()), + reinterpret_cast( + reinterpret_cast(dptr) + + oh * OW2 * kern_param.dst_type.size()), + kern_param.dst_type.size() * OW); + } + } +} + +SmallVector direct_quint8_stride1::get_kimpls( + const NCBKernSizeParam& param, bool m_large_group) { + auto fm = param.filter_meta; + size_t N = param.n; + size_t IC = param.filter_meta.icpg; + size_t OC = param.filter_meta.ocpg; + size_t group = fm.group; + WorkspaceBundle wbundle = get_bundle(param, m_large_group); + conv_fun do_conv_fun = nullptr; + +#define DO_CONV_KERN_FUN(filter, bias_mode, op) \ + do_conv_fun = do_conv_kern; + +#define GET_OP_PARAM(i, bias_mode) \ + switch (param.nonlineMode) { \ + case param::ConvBias::NonlineMode::IDENTITY: \ + DO_CONV_KERN_FUN(i, bias_mode, \ + TypeCvtOp) \ + break; \ + case param::ConvBias::NonlineMode::RELU: \ + DO_CONV_KERN_FUN(i, bias_mode, \ + ReluOp) \ + break; \ + case param::ConvBias::NonlineMode::H_SWISH: \ + DO_CONV_KERN_FUN(i, bias_mode, \ + 
HSwishOp) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } + +#define GET_BIAS_MODE_PARAM(i) \ + switch (param.bias_mode) { \ + case BiasMode::NO_BIAS: \ + GET_OP_PARAM(i, BiasMode::NO_BIAS) \ + break; \ + case BiasMode::BROADCAST_CHANNEL_BIAS: \ + GET_OP_PARAM(i, BiasMode::BROADCAST_CHANNEL_BIAS) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } +#define DISPATCH_CONV_KERN() \ + switch (param.filter_meta.spatial[0]) { \ + case 2: \ + GET_BIAS_MODE_PARAM(2) \ + break; \ + case 3: \ + GET_BIAS_MODE_PARAM(3) \ + break; \ + case 5: \ + GET_BIAS_MODE_PARAM(5) \ + break; \ + case 7: \ + GET_BIAS_MODE_PARAM(7) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } + + DISPATCH_CONV_KERN(); + megdnn_assert(do_conv_fun); + + SmallVector ret_kerns; + if (m_large_group) { + auto exec_one_group = [wbundle, do_conv_fun]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + auto fm = kern_param.filter_meta; + size_t IC = fm.icpg; + size_t OC = fm.ocpg; + WorkspaceBundle bundle = wbundle; + for (size_t ic = 0; ic < IC; ic++) { + copy_padding_kern(bundle, kern_param, ncb_index, + {ncb_index.thread_id, 0, ic}); + } + for (size_t oc = 0; oc < OC; oc++) { + do_conv_fun(bundle, kern_param, ncb_index, + {ncb_index.thread_id, 0, oc}); + } + }; + ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); + } else { + WorkspaceBundle bundle = wbundle; + auto copy_padding = [bundle](const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + copy_padding_kern(bundle, kern_param, ncb_index, + ncb_index.ndrange_id); + }; + ret_kerns.push_back({copy_padding, {group, N, IC}}); + auto do_conv = [bundle, do_conv_fun](const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + do_conv_fun(bundle, kern_param, ncb_index, ncb_index.ndrange_id); + }; + ret_kerns.push_back({do_conv, {group, N, OC}}); + } + return ret_kerns; +} +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/quint8/stride1.h b/dnn/src/arm_common/conv_bias/quint8/stride1.h new file mode 100644 index 00000000..b0de5c91 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/quint8/stride1.h @@ -0,0 +1,45 @@ +/** + * \file dnn/src/arm_common/conv_bias/quint8/stride1.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */
+
+#pragma once
+
+#include "src/arm_common/conv_bias/opr_impl.h"
+
+namespace megdnn {
+namespace arm_common {
+namespace direct_quint8_stride1 {
+using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam;
+using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam;
+using NCBKernIndex = fallback::ConvBiasImpl::NCBKernIndex;
+
+using conv_fun = std::function<void(WorkspaceBundle bundle,
+                                    const NCBKernParam& kern_param,
+                                    const NCBKernIndex& ncb_index,
+                                    const CpuNDRange& workspace_ids)>;
+
+bool can_conv_direct_stride1_quint8(const NCBKernSizeParam& param);
+
+WorkspaceBundle get_bundle(const NCBKernSizeParam& param, bool m_large_group);
+
+void copy_padding_kern(WorkspaceBundle bundle, const NCBKernParam& kern_param,
+                       const NCBKernIndex& ncb_index,
+                       const CpuNDRange& workspace_ids);
+
+template <size_t filter, BiasMode bias_mode, typename Op>
+void do_conv_kern(WorkspaceBundle bundle, const NCBKernParam& kern_param,
+                  const NCBKernIndex& ncb_index,
+                  const CpuNDRange& workspace_ids);
+
+SmallVector<ConvBiasImpl::NCBKern> get_kimpls(const NCBKernSizeParam& param,
+                                              bool);
+} // namespace direct_quint8_stride1
+} // namespace arm_common
+} // namespace megdnn
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/arm_common/conv_bias/quint8/stride1_dotprod.cpp b/dnn/src/arm_common/conv_bias/quint8/stride1_dotprod.cpp
new file mode 100644
index 00000000..d72805ad
--- /dev/null
+++ b/dnn/src/arm_common/conv_bias/quint8/stride1_dotprod.cpp
@@ -0,0 +1,369 @@
+/**
+ * \file dnn/src/arm_common/conv_bias/quint8/stride1_dotprod.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+#if __ARM_FEATURE_DOTPROD
+#include "src/arm_common/conv_bias/quint8/stride1_dotprod.h"
+#include "megdnn/oprs.h"
+#include "src/arm_common/conv_bias/quint8/direct_dotprod.h"
+#include "src/arm_common/elemwise_op.h"
+#include "src/common/opr_delegate.h"
+
+using namespace megdnn;
+using namespace arm_common;
+using namespace direct_dotprod_quint8_stride1;
+
+namespace {
+bool need_dst_copy(
+        const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) {
+    return param.osz[1] % 8;
+}
+bool need_src_copy(
+        const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) {
+    if (param.filter_meta.padding[0] || param.filter_meta.padding[1]) {
+        return true;
+    }
+    return need_dst_copy(param);
+}
+void get_rectified_size(
+        const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param,
+        size_t& IH2, size_t& IW2, size_t& OH2, size_t& OW2) {
+    auto&& fm = param.filter_meta;
+    auto SW = fm.stride[1];
+    auto OH = param.osz[0];
+    auto OW = param.osz[1];
+    auto FH = fm.spatial[0];
+    auto FW = fm.spatial[1];
+
+    OH2 = OH;
+    OW2 = (OW + 7) & ~7;
+    IH2 = SW * OH + FH - SW;
+    IW2 = SW * OW2 + FW - SW;
+}
+} // namespace
+
+bool direct_dotprod_quint8_stride1::can_conv_direct_stride1_quint8(
+        const NCBKernSizeParam& param) {
+    // Semantically this flag means "available"; in practice it is used as
+    // "preferred", i.e. whether this direct kernel should be picked at all.
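+    // The check below has two parts: hard constraints (Quantized8Asymm
+    // src/filter, NCHW layout, stride 1, no dilation, square 2/3/5/7 filter,
+    // QuantizedS32 bias if a bias type is given) and a size heuristic on
+    // OC/IC that decides when the direct kernel is worth using.
+    // For example (illustrative numbers, not from the original source): a
+    // 3x3 stride-1 layer with OC = 16 satisfies "OC <= 16" and is accepted,
+    // while OC = 64 with IC = 16 fails both heuristic branches and is
+    // rejected.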
+ auto&& fm = param.filter_meta; + auto FH = fm.spatial[0]; + auto OC = fm.ocpg; + auto IC = fm.icpg; + bool avaible = + param.src_type.enumv() == DTypeEnum::Quantized8Asymm && + param.filter_type.enumv() == DTypeEnum::Quantized8Asymm && + (param.dst_type.enumv() == DTypeEnum::QuantizedS32 || + param.dst_type.enumv() == DTypeEnum::Quantized8Asymm) && + fm.format == param::Convolution::Format::NCHW && !fm.should_flip && + fm.spatial_ndim == 2 && fm.dilation[0] == 1 && + fm.dilation[1] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1 && + FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5 || FH == 7); + if (param.bias_type.valid()) { + avaible &= param.bias_type.enumv() == DTypeEnum::QuantizedS32; + } + bool preferred = ((FH == 2 && (OC <= 10 || IC <= 8)) || + ((FH == 3 || FH == 5 || FH == 7) && + (OC <= 16 || (IC <= 4 && OC <= 32)))) && + param.bias_mode != BiasMode::BIAS; + return avaible && preferred; +} + +WorkspaceBundle direct_dotprod_quint8_stride1::get_bundle( + const ConvBiasImpl::NCBKernSizeParam& param, bool m_large_group) { + auto&& fm = param.filter_meta; + size_t nr_threads = param.nr_threads; + size_t group = fm.group, batch = param.n; + size_t IC = fm.icpg; + size_t IH2, IW2, OH2, OW2; + get_rectified_size(param, IH2, IW2, OH2, OW2); + size_t src_size = 0, dst_size = 0; + if (need_src_copy(param)) { + src_size = m_large_group + ? IC * IH2 * IW2 * sizeof(uint8_t) * nr_threads + : IC * IH2 * IW2 * sizeof(uint8_t) * group * batch; + }; + if (need_dst_copy(param)) { + dst_size = OH2 * OW2 * param.dst_type.size() * nr_threads; + } + if (IC > 1) { + size_t temp_size = OH2 * OW2 * sizeof(int32_t) * nr_threads; + return {nullptr, {src_size, dst_size, temp_size}}; + } else { + return {nullptr, {src_size, dst_size}}; + }; +} +//! Process one input channel copy padding +void direct_dotprod_quint8_stride1::copy_padding_kern( + WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids) { + size_t IH = kern_param.isz[0]; + size_t IW = kern_param.isz[1]; + size_t IC = kern_param.filter_meta.icpg; + size_t PH = kern_param.filter_meta.padding[0]; + size_t PW = kern_param.filter_meta.padding[1]; + size_t GROUP = kern_param.filter_meta.group; + + size_t IH2, IW2, OH2, OW2; + get_rectified_size(kern_param, IH2, IW2, OH2, OW2); + bool need_src_copy_var = need_src_copy(kern_param); + size_t padding_group_size = IH2 * IW2 * IC; + bundle.set(kern_param.workspace_ptr); + + //! Used for get the workspace offset + size_t workspace_group_id = workspace_ids[0], + workspace_batch_id = workspace_ids[1], + channel_id = workspace_ids[2]; + size_t group_id = ncb_index.ndrange_id[0], + batch_id = ncb_index.ndrange_id[1]; + const uint8_t* sptr = + kern_param.src(batch_id, group_id, channel_id); + + if (need_src_copy_var) { + //! copy to sptr_base to eliminate padding effect + uint8_t* sptr_base = static_cast(bundle.get(0)) + + workspace_group_id * padding_group_size + + workspace_batch_id * GROUP * padding_group_size + + channel_id * IH2 * IW2; + uint8_t _src_zp = + kern_param.src_type.param().zero_point; + std::memset(sptr_base, _src_zp, sizeof(uint8_t) * IH2 * IW2); + rep(ih, IH) { + std::memcpy(sptr_base + (ih + PH) * IW2 + PW, sptr + ih * IW, + sizeof(uint8_t) * IW); + } + } +}; +//! 
compute one output channel +template +void direct_dotprod_quint8_stride1::do_conv_kern( + WorkspaceBundle bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids) { + size_t OH = kern_param.osz[0]; + size_t OW = kern_param.osz[1]; + size_t FH = kern_param.filter_meta.spatial[0]; + size_t FW = kern_param.filter_meta.spatial[1]; + size_t IC = kern_param.filter_meta.icpg; + size_t GROUP = kern_param.filter_meta.group; + size_t IH2, IW2, OH2, OW2; + get_rectified_size(kern_param, IH2, IW2, OH2, OW2); + bool need_src_copy_var = need_src_copy(kern_param); + bool need_dst_copy_var = need_dst_copy(kern_param); + bool need_post_process = + (kern_param.dst_type.enumv() == DTypeEnum::Quantized8Asymm); + + uint8_t src_zp = + kern_param.src_type.param().zero_point; + uint8_t filter_zp = + kern_param.filter_type.param().zero_point; + int32_t src_filter_zp = static_cast(filter_zp) * + static_cast(src_zp) * IC * FH * FW; + Op op(1.0f, 1.0f, static_cast(0)); + if (need_post_process) { + float scale_bias = + kern_param.bias_type.param().scale; + float scale_dst = + kern_param.dst_type.param().scale; + uint8_t dst_zp = + kern_param.dst_type.param().zero_point; + op = Op(scale_bias, scale_dst, dst_zp); + } + size_t padding_group_size = IH2 * IW2 * IC; + + bundle.set(kern_param.workspace_ptr); + //! Used for get the workspace offset + size_t workspace_group_id = workspace_ids[0], + workspace_batch_id = workspace_ids[1], oc = workspace_ids[2], + group_id = ncb_index.ndrange_id[0], + batch_id = ncb_index.ndrange_id[1]; + + const uint8_t* sptr = kern_param.src(batch_id, group_id); + const uint8_t* fptr = kern_param.filter(group_id) + oc * FH * FW * IC; + void* dst = kern_param.dst(batch_id, group_id, oc); + const int32_t* bptr = kern_param.bias(batch_id, group_id, oc); + + if (need_src_copy_var) { + sptr = static_cast(bundle.get(0)) + + workspace_group_id * padding_group_size + + workspace_batch_id * GROUP * padding_group_size; + } + void* dptr = nullptr; + int32_t* tptr = nullptr; + if (need_dst_copy_var) { + dptr = reinterpret_cast( + reinterpret_cast(bundle.get(1)) + + ncb_index.thread_id * OH2 * OW2 * kern_param.dst_type.size()); + } else { + dptr = dst; + } + +#define KERN0_NEED_POST_PROCESS(filter, first_ic, last_ic) \ + conv_bias::conv_direct_stride1_##filter##x##filter##_quint8_dot< \ + first_ic, last_ic, true, bias_mode, Op>( \ + sptr + ic * IH2 * IW2, fptr + ic * FH * FW, bptr, tptr, \ + static_cast(dptr), IH2, IW2, OH2, OW2, src_zp, \ + filter_zp, src_filter_zp, op) + +#define KERN0_NO_POST_PROCESS(filter, first_ic, last_ic) \ + conv_bias::conv_direct_stride1_##filter##x##filter##_quint8_dot< \ + first_ic, last_ic, false, bias_mode, Op>( \ + sptr + ic * IH2 * IW2, fptr + ic * FH * FW, bptr, \ + static_cast(dptr), nullptr, IH2, IW2, OH2, OW2, src_zp, \ + filter_zp, src_filter_zp, op) + +#define KERN1_NEED_POST_PROCESS(filter) \ + KERN0_NEED_POST_PROCESS(filter, true, false); \ + for (ic = 1; ic < IC - 1; ++ic) { \ + KERN0_NEED_POST_PROCESS(filter, false, false); \ + } \ + KERN0_NEED_POST_PROCESS(filter, false, true); + +#define KERN1_NO_POST_PROCESS(filter) \ + KERN0_NO_POST_PROCESS(filter, true, false); \ + for (ic = 1; ic < IC - 1; ++ic) { \ + KERN0_NO_POST_PROCESS(filter, false, false); \ + } \ + KERN0_NO_POST_PROCESS(filter, false, true) + if (need_post_process) { + size_t ic = 0; + if (IC == 1) { + DISPATCH_FILTER(filter, KERN0_NEED_POST_PROCESS, true, true) + } else { + tptr = static_cast(bundle.get(2)) + + ncb_index.thread_id * OH2 * OW2 * 
kern_param.dst_type.size(); + DISPATCH_FILTER(filter, KERN1_NEED_POST_PROCESS) + } + } else { + size_t ic = 0; + if (IC == 1) { + DISPATCH_FILTER(filter, KERN0_NO_POST_PROCESS, true, true) + } else { + DISPATCH_FILTER(filter, KERN1_NO_POST_PROCESS) + } + } +#undef KERN0 +#undef KERN1_NEED_POST_PROCESS +#undef KERN1_NO_POST_PROCESS + if (need_dst_copy_var) { + rep(oh, OH) { + std::memcpy(reinterpret_cast( + reinterpret_cast(dst) + + oh * OW * kern_param.dst_type.size()), + reinterpret_cast( + reinterpret_cast(dptr) + + oh * OW2 * kern_param.dst_type.size()), + kern_param.dst_type.size() * OW); + } + } +} + +SmallVector direct_dotprod_quint8_stride1::get_kimpls( + const NCBKernSizeParam& param, bool m_large_group) { + auto fm = param.filter_meta; + size_t N = param.n; + size_t IC = param.filter_meta.icpg; + size_t OC = param.filter_meta.ocpg; + size_t group = fm.group; + WorkspaceBundle wbundle = get_bundle(param, m_large_group); + conv_fun do_conv_fun = nullptr; + +#define DO_CONV_KERN_FUN(filter, bias_mode, op) \ + do_conv_fun = do_conv_kern; + +#define GET_OP_PARAM(i, bias_mode) \ + switch (param.nonlineMode) { \ + case param::ConvBias::NonlineMode::IDENTITY: \ + DO_CONV_KERN_FUN(i, bias_mode, \ + TypeCvtOp) \ + break; \ + case param::ConvBias::NonlineMode::RELU: \ + DO_CONV_KERN_FUN(i, bias_mode, \ + ReluOp) \ + break; \ + case param::ConvBias::NonlineMode::H_SWISH: \ + DO_CONV_KERN_FUN(i, bias_mode, \ + HSwishOp) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } + +#define GET_BIAS_MODE_PARAM(i) \ + switch (param.bias_mode) { \ + case BiasMode::NO_BIAS: \ + GET_OP_PARAM(i, BiasMode::NO_BIAS) \ + break; \ + case BiasMode::BROADCAST_CHANNEL_BIAS: \ + GET_OP_PARAM(i, BiasMode::BROADCAST_CHANNEL_BIAS) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } +#define DISPATCH_CONV_KERN() \ + switch (param.filter_meta.spatial[0]) { \ + case 2: \ + GET_BIAS_MODE_PARAM(2) \ + break; \ + case 3: \ + GET_BIAS_MODE_PARAM(3) \ + break; \ + case 5: \ + GET_BIAS_MODE_PARAM(5) \ + break; \ + case 7: \ + GET_BIAS_MODE_PARAM(7) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } + + DISPATCH_CONV_KERN(); + megdnn_assert(do_conv_fun); + + SmallVector ret_kerns; + if (m_large_group) { + auto exec_one_group = [wbundle, do_conv_fun]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + auto fm = kern_param.filter_meta; + size_t IC = fm.icpg; + size_t OC = fm.ocpg; + WorkspaceBundle bundle = wbundle; + for (size_t ic = 0; ic < IC; ic++) { + copy_padding_kern(bundle, kern_param, ncb_index, + {ncb_index.thread_id, 0, ic}); + } + for (size_t oc = 0; oc < OC; oc++) { + do_conv_fun(bundle, kern_param, ncb_index, + {ncb_index.thread_id, 0, oc}); + } + }; + ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); + }else { + WorkspaceBundle bundle = wbundle; + auto copy_padding = [bundle](const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + copy_padding_kern(bundle, kern_param, ncb_index, + ncb_index.ndrange_id); + }; + ret_kerns.push_back({copy_padding, {group, N, IC}}); + auto do_conv = [bundle, do_conv_fun](const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + do_conv_fun(bundle, kern_param, ncb_index, ncb_index.ndrange_id); + }; + ret_kerns.push_back({do_conv, {group, N, OC}}); + } + return ret_kerns; +} +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/quint8/stride1_dotprod.h b/dnn/src/arm_common/conv_bias/quint8/stride1_dotprod.h new file mode 100644 index 00000000..d79a8095 --- /dev/null +++ 
b/dnn/src/arm_common/conv_bias/quint8/stride1_dotprod.h @@ -0,0 +1,46 @@ +/** + * \file dnn/src/arm_common/conv_bias/quint8/stride1_dotprod.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#if __ARM_FEATURE_DOTPROD +#pragma once + +#include "src/arm_common/conv_bias/opr_impl.h" + +namespace megdnn { +namespace arm_common { +namespace direct_dotprod_quint8_stride1 { +using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam; +using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam; +using NCBKernIndex = fallback::ConvBiasImpl::NCBKernIndex; + +using conv_fun = std::function; + +bool can_conv_direct_stride1_quint8(const NCBKernSizeParam& param); + +WorkspaceBundle get_bundle(const NCBKernSizeParam& param, bool m_large_group); + +void copy_padding_kern(WorkspaceBundle bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids); + +template +void do_conv_kern(WorkspaceBundle bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids); + +SmallVector get_kimpls(const NCBKernSizeParam& param, + bool); +} // namespace direct_dotprod_quint8_stride1 +} // namespace arm_common +} // namespace megdnn +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/quint8/stride2.cpp b/dnn/src/arm_common/conv_bias/quint8/stride2.cpp new file mode 100644 index 00000000..ffc366c4 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/quint8/stride2.cpp @@ -0,0 +1,376 @@ +/** + * \file dnn/src/arm_common/conv_bias/quint8/stride2.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "src/arm_common/conv_bias/quint8/stride2.h" +#include "megdnn/oprs.h" +#include "src/arm_common/conv_bias/quint8/direct.h" +#include "src/arm_common/elemwise_op.h" +#include "src/common/opr_delegate.h" + +using namespace megdnn; +using namespace arm_common; +using namespace direct_quint8_stride2; + +namespace { +bool need_dst_copy( + const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { + return param.osz[1] % 8; +} +bool need_src_copy( + const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { + if (param.filter_meta.padding[0] || param.filter_meta.padding[1]) { + return true; + } + return need_dst_copy(param); +} +void get_rectified_size( + const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, + size_t& IH2, size_t& IW2, size_t& OH2, size_t& OW2) { + auto&& fm = param.filter_meta; + size_t SW = fm.stride[1]; + size_t IH = param.isz[0]; + size_t IW = param.isz[1]; + size_t OH = param.osz[0]; + size_t OW = param.osz[1]; + size_t FH = fm.spatial[0]; + size_t FW = fm.spatial[1]; + + OH2 = OH; + OW2 = (OW + 7) & ~7; + IH2 = SW * OH + FH - SW; + IW2 = SW * OW2 + FW - SW; + // Because stride is 2, sometimes IW == IW2+1. Do a max update to + // handle this case. 
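+    // For example (illustrative numbers): IW = 18, FW = 3, PW = 0, SW = 2
+    // gives OW = (18 - 3) / 2 + 1 = 8, so OW2 = 8 and
+    // IW2 = 2 * 8 + 3 - 2 = 17, one column short of the real input width;
+    // the std::max below keeps the rectified size at least as large as the
+    // actual input (likewise for IH/IH2).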
+ IH2 = std::max(IH2, IH); + IW2 = std::max(IW2, IW); +} +} // namespace + +bool direct_quint8_stride2::can_conv_direct_stride2_quint8( + const NCBKernSizeParam& param) { + // Semantically it means avaiable, + // but we use it as preferred actually. + auto&& fm = param.filter_meta; + auto FH = fm.spatial[0]; + auto OC = fm.ocpg; + auto IC = fm.icpg; + bool avaible = + param.src_type.enumv() == DTypeEnum::Quantized8Asymm && + param.filter_type.enumv() == DTypeEnum::Quantized8Asymm && + (param.dst_type.enumv() == DTypeEnum::QuantizedS32 || + param.dst_type.enumv() == DTypeEnum::Quantized8Asymm) && + fm.format == param::Convolution::Format::NCHW && !fm.should_flip && + fm.spatial_ndim == 2 && fm.dilation[0] == 1 && + fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 && + FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5 || FH == 7); + if (param.bias_type.valid()) { + avaible &= param.bias_type.enumv() == DTypeEnum::QuantizedS32; + } + bool preferred = (((FH == 2 || FH == 3) && + (IC == 1 || (IC <= 8 && OC <= 12) || OC <= 8)) || + (FH == 5 && ((IC == 1 && OC <= 16) || OC <= 12)) || + (FH == 7 && OC <= 16)) && + (param.bias_mode != BiasMode::BIAS); + return avaible && preferred; +} + +WorkspaceBundle direct_quint8_stride2::get_bundle( + const ConvBiasImpl::NCBKernSizeParam& param, bool m_large_group) { + auto&& fm = param.filter_meta; + size_t nr_threads = param.nr_threads; + size_t group = fm.group, batch = param.n; + size_t IC = fm.icpg; + size_t IH2, IW2, OH2, OW2; + get_rectified_size(param, IH2, IW2, OH2, OW2); + size_t src_size = 0, dst_size = 0; + if (need_src_copy(param)) { + src_size = m_large_group + ? IC * IH2 * IW2 * sizeof(uint8_t) * nr_threads + : IC * IH2 * IW2 * sizeof(uint8_t) * group * batch; + }; + if (need_dst_copy(param)) { + dst_size = OH2 * OW2 * param.dst_type.size() * nr_threads; + } + if (IC > 1) { + size_t temp_size = OH2 * OW2 * sizeof(int32_t) * nr_threads; + return {nullptr, {src_size, dst_size, temp_size}}; + } else { + return {nullptr, {src_size, dst_size}}; + }; +} +//! Process one input channel copy padding +void direct_quint8_stride2::copy_padding_kern( + WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids) { + size_t IH = kern_param.isz[0]; + size_t IW = kern_param.isz[1]; + size_t IC = kern_param.filter_meta.icpg; + size_t PH = kern_param.filter_meta.padding[0]; + size_t PW = kern_param.filter_meta.padding[1]; + size_t GROUP = kern_param.filter_meta.group; + + size_t IH2, IW2, OH2, OW2; + get_rectified_size(kern_param, IH2, IW2, OH2, OW2); + bool need_src_copy_var = need_src_copy(kern_param); + size_t padding_group_size = IH2 * IW2 * IC; + bundle.set(kern_param.workspace_ptr); + + //! Used for get the workspace offset + size_t workspace_group_id = workspace_ids[0], + workspace_batch_id = workspace_ids[1], channel_id = workspace_ids[2], + group_id = ncb_index.ndrange_id[0], + batch_id = ncb_index.ndrange_id[1]; + + const uint8_t* sptr = + kern_param.src(batch_id, group_id, channel_id); + if (need_src_copy_var) { + //! 
copy to sptr_base to eliminate padding effect + uint8_t* sptr_base = static_cast(bundle.get(0)) + + workspace_group_id * padding_group_size + + workspace_batch_id * GROUP * padding_group_size + + channel_id * IH2 * IW2; + uint8_t _src_zp = + kern_param.src_type.param().zero_point; + std::memset(sptr_base, _src_zp, sizeof(uint8_t) * IH2 * IW2); + rep(ih, IH) { + std::memcpy(sptr_base + (ih + PH) * IW2 + PW, sptr + ih * IW, + sizeof(uint8_t) * IW); + } + } +}; +//! compute one output channel +template +void direct_quint8_stride2::do_conv_kern(WorkspaceBundle bundle, + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids) { + size_t OH = kern_param.osz[0]; + size_t OW = kern_param.osz[1]; + size_t FH = kern_param.filter_meta.spatial[0]; + size_t FW = kern_param.filter_meta.spatial[1]; + size_t IC = kern_param.filter_meta.icpg; + size_t GROUP = kern_param.filter_meta.group; + size_t IH2, IW2, OH2, OW2; + get_rectified_size(kern_param, IH2, IW2, OH2, OW2); + bool need_src_copy_var = need_src_copy(kern_param); + bool need_dst_copy_var = need_dst_copy(kern_param); + bool need_post_process = + (kern_param.dst_type.enumv() == DTypeEnum::Quantized8Asymm); + +#define SUB128(n) static_cast(static_cast(n) - 128) + uint8_t _src_zp = + kern_param.src_type.param().zero_point; + int8_t src_zp = SUB128(_src_zp); + int8_t filter_zp = SUB128( + kern_param.filter_type.param().zero_point); + int32_t src_filter_zp = static_cast(filter_zp) * + static_cast(src_zp) * IC * FH * FW; +#undef SUB128 + Op op = Op(1.0f, 1.0f, 0); + if (need_post_process) { + float scale_bias = + kern_param.bias_type.param().scale; + float scale_dst = + kern_param.dst_type.param().scale; + uint8_t dst_zp = + kern_param.dst_type.param().zero_point; + op = Op(scale_bias, scale_dst, dst_zp); + } + size_t padding_group_size = IH2 * IW2 * IC; + + bundle.set(kern_param.workspace_ptr); + //! 
Used for get the workspace offset + size_t workspace_group_id = workspace_ids[0], + workspace_batch_id = workspace_ids[1], oc = workspace_ids[2], + group_id = ncb_index.ndrange_id[0], + batch_id = ncb_index.ndrange_id[1]; + + const uint8_t* sptr = kern_param.src(batch_id, group_id); + const uint8_t* fptr = kern_param.filter(group_id) + oc * FH * FW * IC; + void* dst = kern_param.dst(batch_id, group_id, oc); + const int32_t* bptr = kern_param.bias(batch_id, group_id, oc); + if (need_src_copy_var) { + sptr = static_cast(bundle.get(0)) + + workspace_group_id * padding_group_size + + workspace_batch_id * GROUP * padding_group_size; + } + void* dptr = nullptr; + int32_t* tptr = nullptr; + if (need_dst_copy_var) { + dptr = reinterpret_cast( + reinterpret_cast(bundle.get(1)) + + ncb_index.thread_id * OH2 * OW2 * kern_param.dst_type.size()); + } else { + dptr = dst; + } + +#define KERN0_NEED_POST_PROCESS(filter, first_ic, last_ic) \ + conv_bias::conv_direct_stride2_##filter##x##filter##_quint8< \ + first_ic, last_ic, bias_mode, Op>( \ + sptr + ic * IH2 * IW2, fptr + ic * FH * FW, bptr, tptr, \ + static_cast(dptr), IH2, IW2, OH2, OW2, src_zp, \ + filter_zp, src_filter_zp, op) + +#define KERN0_NO_POST_PROCESS(filter, first_ic, last_ic) \ + conv_bias::conv_direct_stride2_##filter##x##filter##_quint8< \ + first_ic, last_ic, bias_mode, Op>( \ + sptr + ic * IH2 * IW2, fptr + ic * FH * FW, bptr, \ + static_cast(dptr), nullptr, IH2, IW2, OH2, OW2, src_zp, \ + filter_zp, src_filter_zp, op) + +#define KERN1_NEED_POST_PROCESS(filter) \ + KERN0_NEED_POST_PROCESS(filter, true, false); \ + for (ic = 1; ic < IC - 1; ++ic) { \ + KERN0_NEED_POST_PROCESS(filter, false, false); \ + } \ + KERN0_NEED_POST_PROCESS(filter, false, true); + +#define KERN1_NO_POST_PROCESS(filter) \ + KERN0_NO_POST_PROCESS(filter, true, false); \ + for (ic = 1; ic < IC; ++ic) { \ + KERN0_NO_POST_PROCESS(filter, false, false); \ + } + if (need_post_process) { + size_t ic = 0; + if (IC == 1) { + DISPATCH_FILTER(filter, KERN0_NEED_POST_PROCESS, true, true) + } else { + tptr = static_cast(bundle.get(2)) + + ncb_index.thread_id * OH2 * OW2 * kern_param.dst_type.size(); + DISPATCH_FILTER(filter, KERN1_NEED_POST_PROCESS) + } + } else { + size_t ic = 0; + if (IC == 1) { + DISPATCH_FILTER(filter, KERN0_NO_POST_PROCESS, true, false) + } else { + DISPATCH_FILTER(filter, KERN1_NO_POST_PROCESS) + } + } +#undef KERN0 +#undef KERN1_NEED_POST_PROCESS +#undef KERN1_NO_POST_PROCESS + if (need_dst_copy_var) { + rep(oh, OH) { + std::memcpy(reinterpret_cast( + reinterpret_cast(dst) + + oh * OW * kern_param.dst_type.size()), + reinterpret_cast( + reinterpret_cast(dptr) + + oh * OW2 * kern_param.dst_type.size()), + kern_param.dst_type.size() * OW); + } + } +} + +SmallVector direct_quint8_stride2::get_kimpls( + const NCBKernSizeParam& param, bool m_large_group) { + auto fm = param.filter_meta; + size_t N = param.n; + size_t IC = param.filter_meta.icpg; + size_t OC = param.filter_meta.ocpg; + size_t group = fm.group; + WorkspaceBundle wbundle = get_bundle(param, m_large_group); + conv_fun do_conv_fun = nullptr; + +#define DO_CONV_KERN_FUN(filter, bias_mode, op) \ + do_conv_fun = do_conv_kern; + +#define GET_OP_PARAM(i, bias_mode) \ + switch (param.nonlineMode) { \ + case param::ConvBias::NonlineMode::IDENTITY: \ + DO_CONV_KERN_FUN(i, bias_mode, \ + TypeCvtOp) \ + break; \ + case param::ConvBias::NonlineMode::RELU: \ + DO_CONV_KERN_FUN(i, bias_mode, \ + ReluOp) \ + break; \ + case param::ConvBias::NonlineMode::H_SWISH: \ + DO_CONV_KERN_FUN(i, bias_mode, \ + 
HSwishOp) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } + +#define GET_BIAS_MODE_PARAM(i) \ + switch (param.bias_mode) { \ + case BiasMode::NO_BIAS: \ + GET_OP_PARAM(i, BiasMode::NO_BIAS) \ + break; \ + case BiasMode::BROADCAST_CHANNEL_BIAS: \ + GET_OP_PARAM(i, BiasMode::BROADCAST_CHANNEL_BIAS) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } +#define DISPATCH_CONV_KERN() \ + switch (param.filter_meta.spatial[0]) { \ + case 2: \ + GET_BIAS_MODE_PARAM(2) \ + break; \ + case 3: \ + GET_BIAS_MODE_PARAM(3) \ + break; \ + case 5: \ + GET_BIAS_MODE_PARAM(5) \ + break; \ + case 7: \ + GET_BIAS_MODE_PARAM(7) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } + + DISPATCH_CONV_KERN(); + megdnn_assert(do_conv_fun); + + SmallVector ret_kerns; + if (m_large_group) { + auto exec_one_group = [wbundle, do_conv_fun]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + auto fm = kern_param.filter_meta; + size_t IC = fm.icpg; + size_t OC = fm.ocpg; + WorkspaceBundle bundle = wbundle; + for (size_t ic = 0; ic < IC; ic++) { + copy_padding_kern(bundle, kern_param, ncb_index, + {ncb_index.thread_id, 0, ic}); + } + for (size_t oc = 0; oc < OC; oc++) { + do_conv_fun(bundle, kern_param, ncb_index, + {ncb_index.thread_id, 0, oc}); + } + }; + ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); + }else { + WorkspaceBundle bundle = wbundle; + auto copy_padding = [bundle](const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + copy_padding_kern(bundle, kern_param, ncb_index, + ncb_index.ndrange_id); + }; + ret_kerns.push_back({copy_padding, {group, N, IC}}); + auto do_conv = [bundle, do_conv_fun](const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + do_conv_fun(bundle, kern_param, ncb_index, ncb_index.ndrange_id); + }; + ret_kerns.push_back({do_conv, {group, N, OC}}); + } + return ret_kerns; +} +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/quint8/stride2.h b/dnn/src/arm_common/conv_bias/quint8/stride2.h new file mode 100644 index 00000000..b73d02e2 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/quint8/stride2.h @@ -0,0 +1,45 @@ +/** + * \file dnn/src/arm_common/conv_bias/quint8/stride2.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */
+
+#pragma once
+
+#include "src/arm_common/conv_bias/opr_impl.h"
+
+namespace megdnn {
+namespace arm_common {
+namespace direct_quint8_stride2 {
+using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam;
+using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam;
+using NCBKernIndex = fallback::ConvBiasImpl::NCBKernIndex;
+
+using conv_fun = std::function<void(WorkspaceBundle bundle,
+                                    const NCBKernParam& kern_param,
+                                    const NCBKernIndex& ncb_index,
+                                    const CpuNDRange& workspace_ids)>;
+
+bool can_conv_direct_stride2_quint8(const NCBKernSizeParam& param);
+
+WorkspaceBundle get_bundle(const NCBKernSizeParam& param, bool m_large_group);
+
+void copy_padding_kern(WorkspaceBundle bundle, const NCBKernParam& kern_param,
+                       const NCBKernIndex& ncb_index,
+                       const CpuNDRange& workspace_ids);
+
+template <size_t filter, BiasMode bias_mode, typename Op>
+void do_conv_kern(WorkspaceBundle bundle, const NCBKernParam& kern_param,
+                  const NCBKernIndex& ncb_index,
+                  const CpuNDRange& workspace_ids);
+
+SmallVector<ConvBiasImpl::NCBKern> get_kimpls(const NCBKernSizeParam& param,
+                                              bool);
+} // namespace direct_quint8_stride2
+} // namespace arm_common
+} // namespace megdnn
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/arm_common/conv_bias/quint8/stride2_dotprod.cpp b/dnn/src/arm_common/conv_bias/quint8/stride2_dotprod.cpp
new file mode 100644
index 00000000..0ce54962
--- /dev/null
+++ b/dnn/src/arm_common/conv_bias/quint8/stride2_dotprod.cpp
@@ -0,0 +1,374 @@
+/**
+ * \file dnn/src/arm_common/conv_bias/quint8/stride2_dotprod.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+#if __ARM_FEATURE_DOTPROD
+#include "src/arm_common/conv_bias/quint8/stride2_dotprod.h"
+#include "megdnn/oprs.h"
+#include "src/arm_common/conv_bias/quint8/direct_dotprod.h"
+#include "src/arm_common/elemwise_op.h"
+#include "src/common/opr_delegate.h"
+
+using namespace megdnn;
+using namespace arm_common;
+using namespace direct_dotprod_quint8_stride2;
+
+namespace {
+bool need_dst_copy(
+        const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) {
+    return param.osz[1] % 8;
+}
+bool need_src_copy(
+        const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) {
+    if (param.filter_meta.padding[0] || param.filter_meta.padding[1]) {
+        return true;
+    }
+    return need_dst_copy(param);
+}
+void get_rectified_size(
+        const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param,
+        size_t& IH2, size_t& IW2, size_t& OH2, size_t& OW2) {
+    auto&& fm = param.filter_meta;
+    size_t SW = fm.stride[1];
+    size_t IH = param.isz[0];
+    size_t IW = param.isz[1];
+    size_t OH = param.osz[0];
+    size_t OW = param.osz[1];
+    size_t FH = fm.spatial[0];
+    size_t FW = fm.spatial[1];
+
+    OH2 = OH;
+    OW2 = (OW + 7) & ~7;
+    IH2 = SW * OH + FH - SW;
+    IW2 = SW * OW2 + FW - SW;
+    // Because stride is 2, sometimes IW == IW2 + 1. Do a max update to
+    // handle this case.
+    IH2 = std::max(IH2, IH);
+    IW2 = std::max(IW2, IW);
+}
+} // namespace
+
+bool direct_dotprod_quint8_stride2::can_conv_direct_stride2_quint8(
+        const NCBKernSizeParam& param) {
+    // Semantically this flag means "available"; in practice it is used as
+    // "preferred", i.e. whether this direct kernel should be picked at all.
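+    // The heuristic below keeps the direct stride-2 dotprod kernel for small
+    // channel counts only (e.g. FH == 7 requires OC <= 16); larger layers are
+    // expected to be better served by the im2col/matmul-based algorithms.
+    // Per-element BIAS is excluded: only NO_BIAS and BROADCAST_CHANNEL_BIAS
+    // are handled here.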
+ auto&& fm = param.filter_meta; + auto FH = fm.spatial[0]; + auto OC = fm.ocpg; + auto IC = fm.icpg; + bool avaible = + param.src_type.enumv() == DTypeEnum::Quantized8Asymm && + param.filter_type.enumv() == DTypeEnum::Quantized8Asymm && + (param.dst_type.enumv() == DTypeEnum::QuantizedS32 || + param.dst_type.enumv() == DTypeEnum::Quantized8Asymm) && + fm.format == param::Convolution::Format::NCHW && !fm.should_flip && + fm.spatial_ndim == 2 && fm.dilation[0] == 1 && + fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 && + FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5 || FH == 7); + if (param.bias_type.valid()) { + avaible &= param.bias_type.enumv() == DTypeEnum::QuantizedS32; + } + bool preferred = (((FH == 2 || FH == 3) && + (IC == 1 || (IC <= 8 && OC <= 12) || OC <= 8)) || + (FH == 5 && ((IC == 1 && OC <= 16) || OC <= 12)) || + (FH == 7 && OC <= 16)) && + (param.bias_mode != BiasMode::BIAS); + return avaible && preferred; +} + +WorkspaceBundle direct_dotprod_quint8_stride2::get_bundle( + const ConvBiasImpl::NCBKernSizeParam& param, bool m_large_group) { + auto&& fm = param.filter_meta; + size_t nr_threads = param.nr_threads; + size_t group = fm.group, batch = param.n; + size_t IC = fm.icpg; + size_t IH2, IW2, OH2, OW2; + get_rectified_size(param, IH2, IW2, OH2, OW2); + size_t src_size = 0, dst_size = 0; + if (need_src_copy(param)) { + src_size = m_large_group + ? IC * IH2 * IW2 * sizeof(uint8_t) * nr_threads + : IC * IH2 * IW2 * sizeof(uint8_t) * group * batch; + }; + if (need_dst_copy(param)) { + dst_size = OH2 * OW2 * param.dst_type.size() * nr_threads; + } + if (IC > 1) { + size_t temp_size = OH2 * OW2 * sizeof(int32_t) * nr_threads; + return {nullptr, {src_size, dst_size, temp_size}}; + } else { + return {nullptr, {src_size, dst_size}}; + }; +} +//! Process one input channel copy padding +void direct_dotprod_quint8_stride2::copy_padding_kern( + WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids) { + size_t IH = kern_param.isz[0]; + size_t IW = kern_param.isz[1]; + size_t IC = kern_param.filter_meta.icpg; + size_t PH = kern_param.filter_meta.padding[0]; + size_t PW = kern_param.filter_meta.padding[1]; + size_t GROUP = kern_param.filter_meta.group; + + size_t IH2, IW2, OH2, OW2; + get_rectified_size(kern_param, IH2, IW2, OH2, OW2); + bool need_src_copy_var = need_src_copy(kern_param); + size_t padding_group_size = IH2 * IW2 * IC; + bundle.set(kern_param.workspace_ptr); + + //! Used for get the workspace offset + size_t workspace_group_id = workspace_ids[0], + workspace_batch_id = workspace_ids[1], + channel_id = workspace_ids[2]; + size_t group_id = ncb_index.ndrange_id[0], + batch_id = ncb_index.ndrange_id[1]; + const uint8_t* sptr = + kern_param.src(batch_id, group_id, channel_id); + if (need_src_copy_var) { + //! copy to sptr_base to eliminate padding effect + uint8_t* sptr_base = static_cast(bundle.get(0)) + + workspace_group_id * padding_group_size + + workspace_batch_id * GROUP * padding_group_size + + channel_id * IH2 * IW2; + uint8_t _src_zp = + kern_param.src_type.param().zero_point; + std::memset(sptr_base, _src_zp, sizeof(uint8_t) * IH2 * IW2); + rep(ih, IH) { + std::memcpy(sptr_base + (ih + PH) * IW2 + PW, sptr + ih * IW, + sizeof(uint8_t) * IW); + } + } +}; +//! 
compute one output channel +template +void direct_dotprod_quint8_stride2::do_conv_kern( + WorkspaceBundle bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids) { + size_t OH = kern_param.osz[0]; + size_t OW = kern_param.osz[1]; + size_t FH = kern_param.filter_meta.spatial[0]; + size_t FW = kern_param.filter_meta.spatial[1]; + size_t IC = kern_param.filter_meta.icpg; + size_t GROUP = kern_param.filter_meta.group; + size_t IH2, IW2, OH2, OW2; + get_rectified_size(kern_param, IH2, IW2, OH2, OW2); + bool need_src_copy_var = need_src_copy(kern_param); + bool need_dst_copy_var = need_dst_copy(kern_param); + bool need_post_process = + (kern_param.dst_type.enumv() == DTypeEnum::Quantized8Asymm); + + uint8_t src_zp = + kern_param.src_type.param().zero_point; + uint8_t filter_zp = + kern_param.filter_type.param().zero_point; + int32_t src_filter_zp = static_cast(filter_zp) * + static_cast(src_zp) * IC * FH * FW; + Op op(1.0f, 1.0f, static_cast(0)); + if (need_post_process) { + float scale_bias = + kern_param.bias_type.param().scale; + float scale_dst = + kern_param.dst_type.param().scale; + uint8_t dst_zp = + kern_param.dst_type.param().zero_point; + op = Op(scale_bias, scale_dst, dst_zp); + } + size_t padding_group_size = IH2 * IW2 * IC; + + bundle.set(kern_param.workspace_ptr); + //! Used for get the workspace offset + size_t workspace_group_id = workspace_ids[0], + workspace_batch_id = workspace_ids[1], oc = workspace_ids[2], + group_id = ncb_index.ndrange_id[0], + batch_id = ncb_index.ndrange_id[1]; + + const uint8_t* sptr = kern_param.src(batch_id, group_id); + const uint8_t* fptr = kern_param.filter(group_id) + oc * FH * FW * IC; + void* dst = kern_param.dst(batch_id, group_id, oc); + const int32_t* bptr = kern_param.bias(batch_id, group_id, oc); + + if (need_src_copy_var) { + sptr = static_cast(bundle.get(0)) + + workspace_group_id * padding_group_size + + workspace_batch_id * GROUP * padding_group_size; + } + void* dptr = nullptr; + int32_t* tptr = nullptr; + if (need_dst_copy_var) { + dptr = reinterpret_cast( + reinterpret_cast(bundle.get(1)) + + ncb_index.thread_id * OH2 * OW2 * kern_param.dst_type.size()); + } else { + dptr = dst; + } + +#define KERN0_NEED_POST_PROCESS(filter, first_ic, last_ic) \ + conv_bias::conv_direct_stride2_##filter##x##filter##_quint8_dot< \ + first_ic, last_ic, true, bias_mode, Op>( \ + sptr + ic * IH2 * IW2, fptr + ic * FH * FW, bptr, tptr, \ + static_cast(dptr), IH2, IW2, OH2, OW2, src_zp, \ + filter_zp, src_filter_zp, op) + +#define KERN0_NO_POST_PROCESS(filter, first_ic, last_ic) \ + conv_bias::conv_direct_stride2_##filter##x##filter##_quint8_dot< \ + first_ic, last_ic, false, bias_mode, Op>( \ + sptr + ic * IH2 * IW2, fptr + ic * FH * FW, bptr, \ + static_cast(dptr), nullptr, IH2, IW2, OH2, OW2, src_zp, \ + filter_zp, src_filter_zp, op) + +#define KERN1_NEED_POST_PROCESS(filter) \ + KERN0_NEED_POST_PROCESS(filter, true, false); \ + for (ic = 1; ic < IC - 1; ++ic) { \ + KERN0_NEED_POST_PROCESS(filter, false, false); \ + } \ + KERN0_NEED_POST_PROCESS(filter, false, true); + +#define KERN1_NO_POST_PROCESS(filter) \ + KERN0_NO_POST_PROCESS(filter, true, false); \ + for (ic = 1; ic < IC - 1; ++ic) { \ + KERN0_NO_POST_PROCESS(filter, false, false); \ + } \ + KERN0_NO_POST_PROCESS(filter, false, true); + if (need_post_process) { + size_t ic = 0; + if (IC == 1) { + DISPATCH_FILTER(filter, KERN0_NEED_POST_PROCESS, true, true) + } else { + tptr = static_cast(bundle.get(2)) + + ncb_index.thread_id * OH2 * OW2 * 
kern_param.dst_type.size(); + DISPATCH_FILTER(filter, KERN1_NEED_POST_PROCESS) + } + } else { + size_t ic = 0; + if (IC == 1) { + DISPATCH_FILTER(filter, KERN0_NO_POST_PROCESS, true, true) + } else { + DISPATCH_FILTER(filter, KERN1_NO_POST_PROCESS) + } + } +#undef KERN0 +#undef KERN1_NEED_POST_PROCESS +#undef KERN1_NO_POST_PROCESS + if (need_dst_copy_var) { + rep(oh, OH) { + std::memcpy(reinterpret_cast( + reinterpret_cast(dst) + + oh * OW * kern_param.dst_type.size()), + reinterpret_cast( + reinterpret_cast(dptr) + + oh * OW2 * kern_param.dst_type.size()), + kern_param.dst_type.size() * OW); + } + } +} + +SmallVector direct_dotprod_quint8_stride2::get_kimpls( + const NCBKernSizeParam& param, bool m_large_group) { + auto fm = param.filter_meta; + size_t N = param.n; + size_t IC = param.filter_meta.icpg; + size_t OC = param.filter_meta.ocpg; + size_t group = fm.group; + WorkspaceBundle wbundle = get_bundle(param, m_large_group); + conv_fun do_conv_fun = nullptr; + +#define DO_CONV_KERN_FUN(filter, bias_mode, op) \ + do_conv_fun = do_conv_kern; + +#define GET_OP_PARAM(i, bias_mode) \ + switch (param.nonlineMode) { \ + case param::ConvBias::NonlineMode::IDENTITY: \ + DO_CONV_KERN_FUN(i, bias_mode, \ + TypeCvtOp) \ + break; \ + case param::ConvBias::NonlineMode::RELU: \ + DO_CONV_KERN_FUN(i, bias_mode, \ + ReluOp) \ + break; \ + case param::ConvBias::NonlineMode::H_SWISH: \ + DO_CONV_KERN_FUN(i, bias_mode, \ + HSwishOp) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } + +#define GET_BIAS_MODE_PARAM(i) \ + switch (param.bias_mode) { \ + case BiasMode::NO_BIAS: \ + GET_OP_PARAM(i, BiasMode::NO_BIAS) \ + break; \ + case BiasMode::BROADCAST_CHANNEL_BIAS: \ + GET_OP_PARAM(i, BiasMode::BROADCAST_CHANNEL_BIAS) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } +#define DISPATCH_CONV_KERN() \ + switch (param.filter_meta.spatial[0]) { \ + case 2: \ + GET_BIAS_MODE_PARAM(2) \ + break; \ + case 3: \ + GET_BIAS_MODE_PARAM(3) \ + break; \ + case 5: \ + GET_BIAS_MODE_PARAM(5) \ + break; \ + case 7: \ + GET_BIAS_MODE_PARAM(7) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } + + DISPATCH_CONV_KERN(); + megdnn_assert(do_conv_fun); + + SmallVector ret_kerns; + if (m_large_group) { + auto exec_one_group = [wbundle, do_conv_fun]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + auto fm = kern_param.filter_meta; + size_t IC = fm.icpg; + size_t OC = fm.ocpg; + WorkspaceBundle bundle = wbundle; + for (size_t ic = 0; ic < IC; ic++) { + copy_padding_kern(bundle, kern_param, ncb_index, + {ncb_index.thread_id, 0, ic}); + } + for (size_t oc = 0; oc < OC; oc++) { + do_conv_fun(bundle, kern_param, ncb_index, + {ncb_index.thread_id, 0, oc}); + } + }; + ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); + }else { + WorkspaceBundle bundle = wbundle; + auto copy_padding = [bundle](const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + copy_padding_kern(bundle, kern_param, ncb_index, + ncb_index.ndrange_id); + }; + ret_kerns.push_back({copy_padding, {group, N, IC}}); + auto do_conv = [bundle, do_conv_fun](const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + do_conv_fun(bundle, kern_param, ncb_index, ncb_index.ndrange_id); + }; + ret_kerns.push_back({do_conv, {group, N, OC}}); + } + return ret_kerns; +} +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/quint8/stride2_dotprod.h b/dnn/src/arm_common/conv_bias/quint8/stride2_dotprod.h new file mode 100644 index 00000000..0c8049d9 --- /dev/null +++ 
b/dnn/src/arm_common/conv_bias/quint8/stride2_dotprod.h @@ -0,0 +1,46 @@ +/** + * \file dnn/src/arm_common/conv_bias/quint8/stride2_dotprod.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#if __ARM_FEATURE_DOTPROD +#pragma once + +#include "src/arm_common/conv_bias/opr_impl.h" + +namespace megdnn { +namespace arm_common { +namespace direct_dotprod_quint8_stride2 { +using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam; +using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam; +using NCBKernIndex = fallback::ConvBiasImpl::NCBKernIndex; + +using conv_fun = std::function; + +bool can_conv_direct_stride2_quint8(const NCBKernSizeParam& param); + +WorkspaceBundle get_bundle(const NCBKernSizeParam& param, bool m_large_group); + +void copy_padding_kern(WorkspaceBundle bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids); + +template +void do_conv_kern(WorkspaceBundle bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids); + +SmallVector get_kimpls(const NCBKernSizeParam& param, + bool); +} // namespace direct_dotprod_quint8_stride2 +} // namespace arm_common +} // namespace megdnn +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/winograd_common/winograd_common.h b/dnn/src/arm_common/conv_bias/winograd_common/winograd_common.h new file mode 100644 index 00000000..d1e43466 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/winograd_common/winograd_common.h @@ -0,0 +1,44 @@ +/** + * \file dnn/src/arm_common/conv_bias/winograd_common/winograd_common.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include "src/arm_common/simd_macro/marm_neon.h" + +namespace megdnn { +namespace arm_common { +namespace { + +template +struct InputGetter; + +template <> +struct InputGetter { + int16x4_t operator()(const int8_t* ptr) { + return vget_low_s16(vmovl_s8(vld1_s8(ptr))); + } +}; + +template <> +struct InputGetter { + uint16x4_t zp; + InputGetter(uint8_t zero_point) { + zp = vdup_n_u16(static_cast(zero_point)); + } + uint16x4_t operator()(const uint8_t* ptr) { + return vget_low_u16(vmovl_u8(vld1_u8(ptr))) - zp; + } +}; +} // namespace +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/convolution/img2col_helper.h b/dnn/src/arm_common/convolution/img2col_helper.h new file mode 100644 index 00000000..f6e8e441 --- /dev/null +++ b/dnn/src/arm_common/convolution/img2col_helper.h @@ -0,0 +1,82 @@ +/** + * \file dnn/src/arm_common/convolution/img2col_helper.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include +#include "src/common/utils.h" + +namespace { + +template +void img2col_stride(const dtype* __restrict src, + dtype* __restrict dst, const int OC, const int OH, + const int OW, const int IC, const int IH, const int IW, + const int FH, const int FW, const int SH, const int SW) { + (void)OC; + size_t i = 0; + rep(ic, IC) { + rep(fh, FH) { + rep(fw, FW) { + rep(oh, OH) { + rep(ow, OW) { + int fh2, fw2; + if (is_xcorr) { + fh2 = fh; + fw2 = fw; + } else { + fh2 = FH - fh - 1; + fw2 = FW - fw - 1; + } + dst[i++] = src[ic * IH * IW + (oh * SH + fh2) * IW + + (ow * SW + fw2)]; + } + } + } + } + } +} + +template +void img2col(const dtype* src, dtype* dst, size_t /* OC */, size_t OH, + size_t OW, size_t IC, size_t IH, size_t IW, size_t FH, size_t FW) { + size_t offset = (4 - OW % 4) % 4; + size_t i = 0; + rep(ic, IC) { + rep(fh, FH) { + rep(fw, FW) { + rep(oh, OH) { + size_t ow = 0; + for (; ow < OW; ow += 4) { + size_t fh2, fw2; + if (is_xcorr) { + fh2 = fh; + fw2 = fw; + } else { + fh2 = FH - fh - 1; + fw2 = FW - fw - 1; + } + dst[i++] = src[ic * IH * IW + (oh + fh2) * IW + + (ow + fw2) + 0]; + dst[i++] = src[ic * IH * IW + (oh + fh2) * IW + + (ow + fw2) + 1]; + dst[i++] = src[ic * IH * IW + (oh + fh2) * IW + + (ow + fw2) + 2]; + dst[i++] = src[ic * IH * IW + (oh + fh2) * IW + + (ow + fw2) + 3]; + } + i -= offset; + } + } + } + } +} + +} // anonymous namespace + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/convolution/int8x8x32/algos.cpp b/dnn/src/arm_common/convolution/int8x8x32/algos.cpp new file mode 100644 index 00000000..771b2021 --- /dev/null +++ b/dnn/src/arm_common/convolution/int8x8x32/algos.cpp @@ -0,0 +1,63 @@ +/** + * \file dnn/src/arm_common/convolution/int8x8x32/algos.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
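The img2col helpers above repack one input image so that the convolution can be carried out as a plain matrix multiplication, unrolling the output-width loop by four. As a reading aid, a simplified reference version (cross-correlation only, no unrolling) is sketched below; it reproduces the same (IC, FH, FW, OH, OW) output ordering as img2col_stride but is not the optimized helper itself.

#include <cstddef>

// Reference im2col for a single image: dst is laid out as
// dst[((ic*FH + fh)*FW + fw)*OH*OW + oh*OW + ow].
template <typename T>
void img2col_reference(const T* src, T* dst, size_t IC, size_t IH, size_t IW,
                       size_t OH, size_t OW, size_t FH, size_t FW, size_t SH,
                       size_t SW) {
    size_t i = 0;
    for (size_t ic = 0; ic < IC; ++ic)
        for (size_t fh = 0; fh < FH; ++fh)
            for (size_t fw = 0; fw < FW; ++fw)
                for (size_t oh = 0; oh < OH; ++oh)
                    for (size_t ow = 0; ow < OW; ++ow)
                        dst[i++] = src[ic * IH * IW + (oh * SH + fh) * IW +
                                       ow * SW + fw];
}

With SH == SW == 1 this reduces to the unit-stride img2col variant above.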
+ */
+
+#include "src/arm_common/convolution/img2col_helper.h"
+#include "src/arm_common/convolution/int8x8x32/algos.h"
+#include "src/arm_common/convolution/int8x8x32/conv_backdata_stride1.h"
+#include "src/arm_common/convolution/int8x8x32/conv_backdata_stride2.h"
+
+#include "midout.h"
+#include "src/common/opr_delegate.h"
+
+MIDOUT_DECL(megdnn_arm_conv_int8832_kimpl)
+
+using namespace megdnn;
+using namespace arm_common;
+
+#if __ARM_FEATURE_DOTPROD
+/* ===================== ConvolutionBackwardData ===================== */
+/* ===================== direct stride 1 algo ===================== */
+bool ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1::usable(
+        ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const {
+    return deconv::can_stride1_int8x8x32_dot(param);
+}
+
+size_t ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1::get_workspace(
+        ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const {
+    return deconv::get_workspace_in_bytes_stride1_int8x8x32_dot(param);
+}
+
+ConvolutionBackwardDataImpl::ncb_kern_t
+ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1::dispatch_kern(
+        ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const {
+    return deconv::stride1_int8x8x32_dot;
+}
+
+/* ===================== direct stride 2 algo ===================== */
+bool ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2::usable(
+        ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const {
+    return deconv::can_stride2_int8x8x32_dot(param);
+}
+
+size_t ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2::get_workspace(
+        ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const {
+    return deconv::get_workspace_in_bytes_stride2_int8x8x32_dot(param);
+}
+
+ConvolutionBackwardDataImpl::ncb_kern_t
+ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2::dispatch_kern(
+        ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const {
+    return deconv::stride2_int8x8x32_dot;
+}
+
+#endif
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/arm_common/convolution/int8x8x32/algos.h b/dnn/src/arm_common/convolution/int8x8x32/algos.h
new file mode 100644
index 00000000..154cc8a5
--- /dev/null
+++ b/dnn/src/arm_common/convolution/int8x8x32/algos.h
@@ -0,0 +1,61 @@
+/**
+ * \file dnn/src/arm_common/convolution/int8x8x32/algos.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
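Each backward-data algorithm above exposes the same three hooks: usable() gates on the problem description, get_workspace() sizes the scratch memory, and dispatch_kern() hands back the kernel to run, all forwarding to the deconv:: helpers defined later in this diff. A heavily simplified, self-contained sketch of that selection pattern follows; MiniParam and MiniAlgo are invented names used only for illustration, not MegEngine's real AlgoBase interface.

#include <cstddef>
#include <functional>
#include <vector>

struct MiniParam {
    size_t filter_size;
    size_t stride;
};

struct MiniAlgo {
    const char* name;
    std::function<bool(const MiniParam&)> usable;
    std::function<size_t(const MiniParam&)> workspace_in_bytes;
    std::function<void(const MiniParam&, void* workspace)> kern;
};

// Pick the first algorithm whose usable() accepts the problem; callers fall
// back to a reference implementation when nothing matches.
inline const MiniAlgo* pick_first_usable(const std::vector<MiniAlgo>& algos,
                                         const MiniParam& param) {
    for (const auto& a : algos)
        if (a.usable(param))
            return &a;
    return nullptr;
}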
+ */ + +#pragma once + +#include "src/arm_common/convolution/opr_impl.h" + +namespace megdnn { +namespace arm_common { + +#if __ARM_FEATURE_DOTPROD +/* ===================== ConvolutionBackwardData ===================== */ + +class ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1 final : public AlgoBase { +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "AARCH32_I8x8x32_DECONV_STRIDE1"; } + + bool usable(ConvolutionBackwardDataImpl*, + const NCBKernSizeParam& param) const override; + + size_t get_workspace(ConvolutionBackwardDataImpl*, + const NCBKernSizeParam& param) const override; + + ncb_kern_t dispatch_kern(ConvolutionBackwardDataImpl*, + const NCBKernSizeParam&) const override; + + void* type() const override { return sm_arm_common_algo_type; } +}; + +class ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2 final : public AlgoBase { +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "AARCH32_I8x8x32_DECONV_STRIDE2"; } + + bool usable(ConvolutionBackwardDataImpl*, + const NCBKernSizeParam& param) const override; + + size_t get_workspace(ConvolutionBackwardDataImpl*, + const NCBKernSizeParam& param) const override; + + ncb_kern_t dispatch_kern(ConvolutionBackwardDataImpl*, + const NCBKernSizeParam&) const override; + + void* type() const override { return sm_arm_common_algo_type; } +}; + +#endif + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride1.cpp b/dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride1.cpp new file mode 100644 index 00000000..5ffed44d --- /dev/null +++ b/dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride1.cpp @@ -0,0 +1,1170 @@ +/** + * \file dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride1.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#if __ARM_FEATURE_DOTPROD +#include "src/arm_common/convolution/int8x8x32/conv_backdata_stride1.h" +#include "src/common/utils.h" + +#include +#include "src/arm_common/simd_macro/marm_neon.h" + +using namespace megdnn; +using namespace arm_common; +using namespace deconv; + +namespace { + +bool need_dst_copy(const NCBKernSizeParam& param) { + if (param.osz[1] % 4 != 0) { + // If the size of output is not multiples of 4, we need to copy it. + return true; + } + return false; +} + +bool need_src_copy(const NCBKernSizeParam& param) { + auto FH = param.filter_meta.spatial[0], FW = param.filter_meta.spatial[1], + PH = param.filter_meta.padding[0], PW = param.filter_meta.padding[1]; + return FH > PH + 1 || FW > PW + 1 || need_dst_copy(param); +} + +void get_rectified_size(size_t IH, size_t IW, size_t OH, size_t OW, size_t FH, + size_t FW, size_t PH, size_t PW, size_t& IH2, + size_t& IW2, size_t& OW2) { + MEGDNN_MARK_USED_VAR(IH); + MEGDNN_MARK_USED_VAR(OH); + MEGDNN_MARK_USED_VAR(IW); + MEGDNN_MARK_USED_VAR(PW); + //! OW should be a multiple of 4 + OW2 = (OW + 3) & ~3; + IH2 = IH + 2 * (FH - PH - 1); + //! 
OW2 - FW + 1 + 2 * PW + 2 * (FW - PW - 1); + IW2 = OW2 + FW - 1; +} + +WorkspaceBundle get_bundle(const NCBKernSizeParam& param) { + UNPACK_CONV_F32_NCB_KERN_SIZES(param); + MEGDNN_MARK_USED_VAR(N); + MEGDNN_MARK_USED_VAR(OC); + MEGDNN_MARK_USED_VAR(SH); + MEGDNN_MARK_USED_VAR(SW); + size_t src_size = 0, dst_size = 0; + size_t IH2, IW2, OW2; + get_rectified_size(IH, IW, OH, OW, FH, FW, PH, PW, IH2, IW2, OW2); + if (need_src_copy(param)) { + //! We only need one src plane + src_size = sizeof(int8_t) * IH2 * IW2; + } + if (need_dst_copy(param)) { + dst_size = sizeof(int32_t) * IC * OH * OW2; + } + return WorkspaceBundle(nullptr, {src_size, dst_size}); +} + +inline int8x16_t vqtbl1q_s8_common(int8x16_t a, uint8x16_t index) { + int8x8x2_t src; + src.val[0] = vget_low_s8(a); + src.val[1] = vget_high_s8(a); + uint8x8_t index_low = vget_low_u8(index); + uint8x8_t index_high = vget_high_u8(index); + int8x8_t r00 = vtbl2_s8(src, vreinterpret_s8_u8(index_low)); + int8x8_t r01 = vtbl2_s8(src, vreinterpret_s8_u8(index_high)); + int8x16_t r = vcombine_s8(r00, r01); + return r; +} + +#define CALC_0(_k_idx, _c_idx) \ + _elem = vqtbl1q_s8_common(_tmp, _idx##_c_idx); \ + _sum0##_c_idx = vdotq_s32(_sum0##_c_idx, _k##_k_idx, _elem); + +#define CALC_1(_k_idx, _c_idx) \ + _elem = vqtbl1q_s8_common(_tmp, _idx##_c_idx); \ + _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k_idx, _elem); + +#define CALC_2(_k1_idx, _k2_idx, _c_idx) \ + _elem = vqtbl1q_s8_common(_tmp, _idx##_c_idx); \ + _sum0##_c_idx = vdotq_s32(_sum0##_c_idx, _k##_k1_idx, _elem); \ + _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k2_idx, _elem); + +void deconv_direct_2x2(const int8_t* src, const int8_t* filter, int32_t* dst, + size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) { + MEGDNN_MARK_USED_VAR(IH); + const size_t tail_step = IW - OW; + + const uint8x16_t _idx0 = {0, 1, 16, 16, 1, 2, 16, 16, + 2, 3, 16, 16, 3, 4, 16, 16}; + const uint8x16_t _idx1 = {4, 5, 16, 16, 5, 6, 16, 16, + 6, 7, 16, 16, 7, 8, 16, 16}; + rep(ic, IC) { + const int8_t* src_ptr = src; + int32_t* dst_ptr = dst + OW * OH * ic; + int32_t* outptr = dst_ptr; + int32_t* outptr2 = dst_ptr + OW; + + const int8_t* r0 = src_ptr; + const int8_t* r1 = src_ptr + IW; + const int8_t* r2 = src_ptr + 2 * IW; + + const int8_t* k0 = filter; + + int8x16_t _k0 = vreinterpretq_s8_s32( + vdupq_n_s32(*reinterpret_cast(k0))); + uint8x16_t _idx_k = {3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0}; + int8x16_t _k = vqtbl1q_s8_common(_k0, _idx_k); + uint8x16_t _idx = {0, 1, 16, 16, 0, 1, 16, 16, + 0, 1, 16, 16, 0, 1, 16, 16}; + int8x16_t _k1 = vqtbl1q_s8_common(_k, _idx); + _idx = {2, 3, 16, 16, 2, 3, 16, 16, 2, 3, 16, 16, 2, 3, 16, 16}; + int8x16_t _k23 = vqtbl1q_s8_common(_k, _idx); + + int8x16_t _tmp, _elem; + const int width = OW >> 2; + size_t h = 0; + for (; h + 1 < OH; h += 2) { + int w = 0; + for (; w + 4 < width; w += 4) { + int32x4x2_t _sum00, _sum01, _sum10, _sum11; + _sum00 = vld2q_s32(outptr); + _sum01 = vld2q_s32(outptr + 8); + _sum10 = vld2q_s32(outptr2); + _sum11 = vld2q_s32(outptr2 + 8); + + int8x16_t _r00 = vld1q_s8(r0); + //! 
here will not not read out of bound + int8x16_t _r01_ = vdupq_n_s8(r0[16]); + int8x16_t _r10 = vld1q_s8(r1); + int8x16_t _r11_ = vdupq_n_s8(r1[16]); + int8x16_t _r20 = vld1q_s8(r2); + int8x16_t _r21_ = vdupq_n_s8(r2[16]); + int8x16_t _r01 = vextq_s8(_r00, _r01_, 1); + int8x16_t _r11 = vextq_s8(_r10, _r11_, 1); + int8x16_t _r21 = vextq_s8(_r20, _r21_, 1); + + int16x8x2_t r_00 = vzipq_s16(vreinterpretq_s16_s8(_r00), + vreinterpretq_s16_s8(_r10)); + int8x16_t _r0 = r_00.val[0]; + int8x16_t _r2 = r_00.val[1]; + + int16x8x2_t r_11 = vzipq_s16(vreinterpretq_s16_s8(_r01), + vreinterpretq_s16_s8(_r11)); + int8x16_t _r1 = r_11.val[0]; + int8x16_t _r3 = r_11.val[1]; + + _sum00.val[0] = vdotq_s32(_sum00.val[0], _k, _r0); + _sum00.val[1] = vdotq_s32(_sum00.val[1], _k, _r1); + _sum01.val[0] = vdotq_s32(_sum01.val[0], _k, _r2); + _sum01.val[1] = vdotq_s32(_sum01.val[1], _k, _r3); + + r_00 = vzipq_s16(vreinterpretq_s16_s8(_r10), + vreinterpretq_s16_s8(_r20)); + _r0 = r_00.val[0]; + _r2 = r_00.val[1]; + + r_11 = vzipq_s16(vreinterpretq_s16_s8(_r11), + vreinterpretq_s16_s8(_r21)); + _r1 = r_11.val[0]; + _r3 = r_11.val[1]; + + _sum10.val[0] = vdotq_s32(_sum10.val[0], _k, _r0); + _sum10.val[1] = vdotq_s32(_sum10.val[1], _k, _r1); + _sum11.val[0] = vdotq_s32(_sum11.val[0], _k, _r2); + _sum11.val[1] = vdotq_s32(_sum11.val[1], _k, _r3); + + vst2q_s32(outptr, _sum00); + vst2q_s32(outptr + 8, _sum01); + vst2q_s32(outptr2, _sum10); + vst2q_s32(outptr2 + 8, _sum11); + + r0 += 16; + r1 += 16; + r2 += 16; + outptr += 16; + outptr2 += 16; + } + for (; w + 2 < width; w += 2) { + int32x4_t _sum00 = vld1q_s32(outptr); + int32x4_t _sum01 = vld1q_s32(outptr + 4); + int32x4_t _sum10 = vld1q_s32(outptr2); + int32x4_t _sum11 = vld1q_s32(outptr2 + 4); + + _tmp = vld1q_s8(r0); + CALC_0(1, 0); + CALC_0(1, 1); + + _tmp = vld1q_s8(r1); + CALC_2(23, 1, 0); + CALC_2(23, 1, 1); + + _tmp = vld1q_s8(r2); + CALC_1(23, 0); + CALC_1(23, 1); + + vst1q_s32(outptr, _sum00); + vst1q_s32(outptr + 4, _sum01); + vst1q_s32(outptr2, _sum10); + vst1q_s32(outptr2 + 4, _sum11); + + r0 += 8; + r1 += 8; + r2 += 8; + outptr += 8; + outptr2 += 8; + } + + for (; w < width; w++) { + int32x4_t _sum00 = vld1q_s32(outptr); + int32x4_t _sum10 = vld1q_s32(outptr2); + + _tmp = vtranslq_s8(vld1_s8(r0)); + CALC_0(1, 0); + + _tmp = vtranslq_s8(vld1_s8(r1)); + CALC_2(23, 1, 0); + + _tmp = vtranslq_s8(vld1_s8(r2)); + CALC_1(23, 0); + + vst1q_s32(outptr, _sum00); + vst1q_s32(outptr2, _sum10); + + r0 += 4; + r1 += 4; + r2 += 4; + outptr += 4; + outptr2 += 4; + } + r0 += tail_step + IW; + r1 += tail_step + IW; + r2 += tail_step + IW; + outptr += OW; + outptr2 += OW; + } + + for (; h < OH; h++) { + int w = 0; + for (; w + 4 < width; w += 4) { + int32x4x2_t _sum0, _sum1; + _sum0 = vld2q_s32(outptr); + _sum1 = vld2q_s32(outptr + 8); + + int8x16_t _r00 = vld1q_s8(r0); + //! 
here will not not read out of bound + int8x16_t _r01_ = vdupq_n_s8(r0[16]); + int8x16_t _r10 = vld1q_s8(r1); + int8x16_t _r11_ = vdupq_n_s8(r1[16]); + int8x16_t _r01 = vextq_s8(_r00, _r01_, 1); + int8x16_t _r11 = vextq_s8(_r10, _r11_, 1); + + int16x8x2_t r_00 = vzipq_s16(vreinterpretq_s16_s8(_r00), + vreinterpretq_s16_s8(_r10)); + int8x16_t _r0 = r_00.val[0]; + int8x16_t _r2 = r_00.val[1]; + + int16x8x2_t r_11 = vzipq_s16(vreinterpretq_s16_s8(_r01), + vreinterpretq_s16_s8(_r11)); + int8x16_t _r1 = r_11.val[0]; + int8x16_t _r3 = r_11.val[1]; + + _sum0.val[0] = vdotq_s32(_sum0.val[0], _k, _r0); + _sum0.val[1] = vdotq_s32(_sum0.val[1], _k, _r1); + _sum1.val[0] = vdotq_s32(_sum1.val[0], _k, _r2); + _sum1.val[1] = vdotq_s32(_sum1.val[1], _k, _r3); + + vst2q_s32(outptr, _sum0); + vst2q_s32(outptr + 8, _sum1); + + r0 += 16; + r1 += 16; + outptr += 16; + } + for (; w + 2 < width; w += 2) { + int32x4_t _sum00 = vld1q_s32(outptr); + int32x4_t _sum01 = vld1q_s32(outptr + 4); + + _tmp = vld1q_s8(r0); + CALC_0(1, 0); + CALC_0(1, 1); + + _tmp = vld1q_s8(r1); + CALC_0(23, 0); + CALC_0(23, 1); + + vst1q_s32(outptr, _sum00); + vst1q_s32(outptr + 4, _sum01); + + r0 += 8; + r1 += 8; + outptr += 8; + } + + for (; w < width; w++) { + int32x4_t _sum00 = vld1q_s32(outptr); + + _tmp = vtranslq_s8(vld1_s8(r0)); + CALC_0(1, 0); + + _tmp = vtranslq_s8(vld1_s8(r1)); + CALC_0(23, 0); + + vst1q_s32(outptr, _sum00); + + r0 += 4; + r1 += 4; + outptr += 4; + } + r0 += tail_step; + r1 += tail_step; + } + + filter += 4; + } +} + +void deconv_direct_3x3(const int8_t* src, const int8_t* filter, int32_t* dst, + size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) { + MEGDNN_MARK_USED_VAR(IH); + const size_t tail_step = IW - OW; + + const uint8x16_t _idx0 = {0, 1, 2, 16, 1, 2, 3, 16, + 2, 3, 4, 16, 3, 4, 5, 16}; + const uint8x16_t _idx1 = {4, 5, 6, 16, 5, 6, 7, 16, + 6, 7, 8, 16, 7, 8, 9, 16}; + const uint8x16_t _idx2 = {8, 9, 10, 16, 9, 10, 11, 16, + 10, 11, 12, 16, 11, 12, 13, 16}; + rep(ic, IC) { + const int8_t* src_ptr = src; + int32_t* dst_ptr = dst + OW * OH * ic; + int32_t* outptr = dst_ptr; + int32_t* outptr2 = outptr + OW; + + const int8_t* r0 = src_ptr; + const int8_t* r1 = src_ptr + IW; + const int8_t* r2 = src_ptr + IW * 2; + const int8_t* r3 = src_ptr + IW * 3; + + const int8_t* k0 = filter; + + int8x16_t _k_tmp = vcombine_s8(vld1_s8(k0), vdup_n_s8(k0[8])); + uint8x16_t _idx = {8, 7, 6, 16, 8, 7, 6, 16, 8, 7, 6, 16, 8, 7, 6, 16}; + int8x16_t _k12 = vqtbl1q_s8_common(_k_tmp, _idx); + _idx = {5, 4, 3, 16, 5, 4, 3, 16, 5, 4, 3, 16, 5, 4, 3, 16}; + int8x16_t _k345 = vqtbl1q_s8_common(_k_tmp, _idx); + _idx = {2, 1, 0, 16, 2, 1, 0, 16, 2, 1, 0, 16, 2, 1, 0, 16}; + int8x16_t _k678 = vqtbl1q_s8_common(_k_tmp, _idx); + + int8x16_t _tmp, _elem; + size_t h = 0; + for (; h + 1 < OH; h += 2) { + int width = OW >> 2; + + int w = 0; + for (; w + 3 < width; w += 3) { + //! 
As the inner kernel read 16 elements, and IW is times of 16 + int32x4_t _sum00 = vld1q_s32(outptr); + int32x4_t _sum01 = vld1q_s32(outptr + 4); + int32x4_t _sum02 = vld1q_s32(outptr + 8); + int32x4_t _sum10 = vld1q_s32(outptr2); + int32x4_t _sum11 = vld1q_s32(outptr2 + 4); + int32x4_t _sum12 = vld1q_s32(outptr2 + 8); + + _tmp = vld1q_s8(r0); + CALC_0(12, 0); + CALC_0(12, 1); + CALC_0(12, 2); + + _tmp = vld1q_s8(r1); + CALC_2(345, 12, 0); + CALC_2(345, 12, 1); + CALC_2(345, 12, 2); + + _tmp = vld1q_s8(r2); + CALC_2(678, 345, 0); + CALC_2(678, 345, 1); + CALC_2(678, 345, 2); + + _tmp = vld1q_s8(r3); + CALC_1(678, 0); + CALC_1(678, 1); + CALC_1(678, 2); + + vst1q_s32(outptr, _sum00); + vst1q_s32(outptr + 4, _sum01); + vst1q_s32(outptr + 8, _sum02); + vst1q_s32(outptr2, _sum10); + vst1q_s32(outptr2 + 4, _sum11); + vst1q_s32(outptr2 + 8, _sum12); + + r0 += 12; + r1 += 12; + r2 += 12; + r3 += 12; + outptr += 12; + outptr2 += 12; + } + for (; w < width; w++) { + int32x4_t _sum00 = vld1q_s32(outptr); + int32x4_t _sum10 = vld1q_s32(outptr2); + + _tmp = vtranslq_s8(vld1_s8(r0)); + CALC_0(12, 0); + + _tmp = vtranslq_s8(vld1_s8(r1)); + CALC_2(345, 12, 0); + + _tmp = vtranslq_s8(vld1_s8(r2)); + CALC_2(678, 345, 0); + + _tmp = vtranslq_s8(vld1_s8(r3)); + CALC_1(678, 0); + + vst1q_s32(outptr, _sum00); + vst1q_s32(outptr2, _sum10); + + r0 += 4; + r1 += 4; + r2 += 4; + r3 += 4; + outptr += 4; + outptr2 += 4; + } + + r0 += tail_step + IW; + r1 += tail_step + IW; + r2 += tail_step + IW; + r3 += tail_step + IW; + + outptr += OW; + outptr2 += OW; + } + + for (; h < OH; h++) { + int width = OW >> 2; + + int w = 0; + for (; w + 3 < width; w += 3) { + int32x4_t _sum00 = vld1q_s32(outptr); + int32x4_t _sum01 = vld1q_s32(outptr + 4); + int32x4_t _sum02 = vld1q_s32(outptr + 8); + + _tmp = vld1q_s8(r0); + CALC_0(12, 0); + CALC_0(12, 1); + CALC_0(12, 2); + + _tmp = vld1q_s8(r1); + CALC_0(345, 0); + CALC_0(345, 1); + CALC_0(345, 2); + + _tmp = vld1q_s8(r2); + CALC_0(678, 0); + CALC_0(678, 1); + CALC_0(678, 2); + + vst1q_s32(outptr, _sum00); + vst1q_s32(outptr + 4, _sum01); + vst1q_s32(outptr + 8, _sum02); + + r0 += 12; + r1 += 12; + r2 += 12; + outptr += 12; + } + for (; w < width; w++) { + int32x4_t _sum00 = vld1q_s32(outptr); + + _tmp = vtranslq_s8(vld1_s8(r0)); + CALC_0(12, 0); + + _tmp = vtranslq_s8(vld1_s8(r1)); + CALC_0(345, 0); + + _tmp = vtranslq_s8(vld1_s8(r2)); + CALC_0(678, 0); + + vst1q_s32(outptr, _sum00); + + r0 += 4; + r1 += 4; + r2 += 4; + outptr += 4; + } + r0 += tail_step; + r1 += tail_step; + r2 += tail_step; + } + + filter += 9; + } +} + +#undef CALC_0 +#undef CALC_1 +#undef CALC_2 + +#define CALC_0(_k00_idx, _k01_idx, _c_idx) \ + _elem = vqtbl1q_s8_common(_tmp, _idx##_c_idx##0); \ + _sum0##_c_idx = vdotq_s32(_sum0##_c_idx, _k##_k00_idx, _elem); \ + _elem = vqtbl1q_s8_common(_tmp, _idx##_c_idx##1); \ + _sum0##_c_idx = vdotq_s32(_sum0##_c_idx, _k##_k01_idx, _elem); + +#define CALC_1(_k00_idx, _k01_idx, _c_idx) \ + _elem = vqtbl1q_s8_common(_tmp, _idx##_c_idx##0); \ + _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k00_idx, _elem); \ + _elem = vqtbl1q_s8_common(_tmp, _idx##_c_idx##1); \ + _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k01_idx, _elem); + +#define CALC_2(_k00_idx, _k01_idx, _k10_idx, _k11_idx, _c_idx) \ + _elem = vqtbl1q_s8_common(_tmp, _idx##_c_idx##0); \ + _sum0##_c_idx = vdotq_s32(_sum0##_c_idx, _k##_k00_idx, _elem); \ + _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k10_idx, _elem); \ + _elem = vqtbl1q_s8_common(_tmp, _idx##_c_idx##1); \ + _sum0##_c_idx = vdotq_s32(_sum0##_c_idx, 
_k##_k01_idx, _elem); \ + _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k11_idx, _elem); + +void deconv_direct_5x5(const int8_t* src, const int8_t* filter, int32_t* dst, + size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) { + MEGDNN_MARK_USED_VAR(IH); + const size_t tail_step = IW - OW; + + const uint8x16_t _idx00 = {0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6}; + const uint8x16_t _idx01 = {4, 16, 16, 16, 5, 16, 16, 16, + 6, 16, 16, 16, 7, 16, 16, 16}; + const uint8x16_t _idx10 = {4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10}; + const uint8x16_t _idx11 = {8, 16, 16, 16, 9, 16, 16, 16, + 10, 16, 16, 16, 11, 16, 16, 16}; + const uint8x16_t _idx20 = {8, 9, 10, 11, 9, 10, 11, 12, + 10, 11, 12, 13, 11, 12, 13, 14}; + const uint8x16_t _idx21 = {12, 16, 16, 16, 13, 16, 16, 16, + 14, 16, 16, 16, 15, 16, 16, 16}; + int8x16_t _tmp, _elem; + rep(ic, IC) { + const int8_t* src_ptr = src; + int32_t* dst_ptr = dst + OW * OH * ic; + int32_t* outptr = dst_ptr; + int32_t* outptr2 = outptr + OW; + + const int8_t* r0 = src_ptr; + const int8_t* r1 = src_ptr + IW; + const int8_t* r2 = src_ptr + IW * 2; + const int8_t* r3 = src_ptr + IW * 3; + const int8_t* r4 = src_ptr + IW * 4; + const int8_t* r5 = src_ptr + IW * 5; + + const int8_t* k0 = filter; + + int8x16_t _k = vld1q_s8(k0 + 9); + //! filter row 1 + uint8x16_t _idx = {15, 14, 13, 12, 15, 14, 13, 12, + 15, 14, 13, 12, 15, 14, 13, 12}; + int8x16_t _k123 = vqtbl1q_s8_common(_k, _idx); + _idx = {11, 16, 16, 16, 11, 16, 16, 16, 11, 16, 16, 16, 11, 16, 16, 16}; + int8x16_t _k4 = vqtbl1q_s8_common(_k, _idx); + //! filter row 2 + _idx = {10, 9, 8, 7, 10, 9, 8, 7, 10, 9, 8, 7, 10, 9, 8, 7}; + int8x16_t _k5678 = vqtbl1q_s8_common(_k, _idx); + _idx = {6, 16, 16, 16, 6, 16, 16, 16, 6, 16, 16, 16, 6, 16, 16, 16}; + int8x16_t _k9 = vqtbl1q_s8_common(_k, _idx); + //! filter row 3 + _idx = {5, 4, 3, 2, 5, 4, 3, 2, 5, 4, 3, 2, 5, 4, 3, 2}; + int8x16_t _k10111213 = vqtbl1q_s8_common(_k, _idx); + _idx = {1, 16, 16, 16, 1, 16, 16, 16, 1, 16, 16, 16, 1, 16, 16, 16}; + int8x16_t _k14 = vqtbl1q_s8_common(_k, _idx); + //! 9 10 11 12 -> 13 14 15 16 -> 17 18 19 20 -> 21 22 23 24 + _k = vld1q_s8(k0); + //! filter row 4 + _idx = {9, 8, 7, 6, 9, 8, 7, 6, 9, 8, 7, 6, 9, 8, 7, 6}; + int8x16_t _k15161718 = vqtbl1q_s8_common(_k, _idx); + _idx = {5, 16, 16, 16, 5, 16, 16, 16, 5, 16, 16, 16, 5, 16, 16, 16}; + int8x16_t _k19 = vqtbl1q_s8_common(_k, _idx); + //! filter row 5 + _idx = {4, 3, 2, 1, 4, 3, 2, 1, 4, 3, 2, 1, 4, 3, 2, 1}; + int8x16_t _k20212223 = vqtbl1q_s8_common(_k, _idx); + _idx = {0, 16, 16, 16, 0, 16, 16, 16, 0, 16, 16, 16, 0, 16, 16, 16}; + int8x16_t _k24 = vqtbl1q_s8_common(_k, _idx); + + const int width = OW >> 2; + size_t h = 0; + for (; h + 1 < OH; h += 2) { + int w = 0; + for (; w + 3 < width; w += 3) { + //! 
As the inner kernel read 16 elements, and IW is times of 16 + int32x4_t _sum00 = vld1q_s32(outptr); + int32x4_t _sum01 = vld1q_s32(outptr + 4); + int32x4_t _sum02 = vld1q_s32(outptr + 8); + int32x4_t _sum10 = vld1q_s32(outptr2); + int32x4_t _sum11 = vld1q_s32(outptr2 + 4); + int32x4_t _sum12 = vld1q_s32(outptr2 + 8); + + _tmp = vld1q_s8(r0); + CALC_0(123, 4, 0); + CALC_0(123, 4, 1); + CALC_0(123, 4, 2); + + _tmp = vld1q_s8(r1); + CALC_2(5678, 9, 123, 4, 0); + CALC_2(5678, 9, 123, 4, 1); + CALC_2(5678, 9, 123, 4, 2); + + _tmp = vld1q_s8(r2); + CALC_2(10111213, 14, 5678, 9, 0); + CALC_2(10111213, 14, 5678, 9, 1); + CALC_2(10111213, 14, 5678, 9, 2); + + _tmp = vld1q_s8(r3); + CALC_2(15161718, 19, 10111213, 14, 0); + CALC_2(15161718, 19, 10111213, 14, 1); + CALC_2(15161718, 19, 10111213, 14, 2); + + _tmp = vld1q_s8(r4); + CALC_2(20212223, 24, 15161718, 19, 0); + CALC_2(20212223, 24, 15161718, 19, 1); + CALC_2(20212223, 24, 15161718, 19, 2); + + _tmp = vld1q_s8(r5); + CALC_1(20212223, 24, 0); + CALC_1(20212223, 24, 1); + CALC_1(20212223, 24, 2); + + vst1q_s32(outptr, _sum00); + vst1q_s32(outptr + 4, _sum01); + vst1q_s32(outptr + 8, _sum02); + vst1q_s32(outptr2, _sum10); + vst1q_s32(outptr2 + 4, _sum11); + vst1q_s32(outptr2 + 8, _sum12); + + r0 += 12; + r1 += 12; + r2 += 12; + r3 += 12; + r4 += 12; + r5 += 12; + outptr += 12; + outptr2 += 12; + } + for (; w < width; w++) { + int32x4_t _sum00 = vld1q_s32(outptr); + int32x4_t _sum10 = vld1q_s32(outptr2); + + _tmp = vtranslq_s8(vld1_s8(r0)); + CALC_0(123, 4, 0); + + _tmp = vtranslq_s8(vld1_s8(r1)); + CALC_2(5678, 9, 123, 4, 0); + + _tmp = vtranslq_s8(vld1_s8(r2)); + CALC_2(10111213, 14, 5678, 9, 0); + + _tmp = vtranslq_s8(vld1_s8(r3)); + CALC_2(15161718, 19, 10111213, 14, 0); + + _tmp = vtranslq_s8(vld1_s8(r4)); + CALC_2(20212223, 24, 15161718, 19, 0); + + _tmp = vtranslq_s8(vld1_s8(r5)); + CALC_1(20212223, 24, 0); + + vst1q_s32(outptr, _sum00); + vst1q_s32(outptr2, _sum10); + + r0 += 4; + r1 += 4; + r2 += 4; + r3 += 4; + r4 += 4; + r5 += 4; + outptr += 4; + outptr2 += 4; + } + + r0 += tail_step + IW; + r1 += tail_step + IW; + r2 += tail_step + IW; + r3 += tail_step + IW; + r4 += tail_step + IW; + r5 += tail_step + IW; + + outptr += OW; + outptr2 += OW; + } + + for (; h < OH; h++) { + int w = 0; + for (; w + 3 < width; w += 3) { + int32x4_t _sum00 = vld1q_s32(outptr); + int32x4_t _sum01 = vld1q_s32(outptr + 4); + int32x4_t _sum02 = vld1q_s32(outptr + 8); + + _tmp = vld1q_s8(r0); + CALC_0(123, 4, 0); + CALC_0(123, 4, 1); + CALC_0(123, 4, 2); + + _tmp = vld1q_s8(r1); + CALC_0(5678, 9, 0); + CALC_0(5678, 9, 1); + CALC_0(5678, 9, 2); + + _tmp = vld1q_s8(r2); + CALC_0(10111213, 14, 0); + CALC_0(10111213, 14, 1); + CALC_0(10111213, 14, 2); + + _tmp = vld1q_s8(r3); + CALC_0(15161718, 19, 0); + CALC_0(15161718, 19, 1); + CALC_0(15161718, 19, 2); + + _tmp = vld1q_s8(r4); + CALC_0(20212223, 24, 0); + CALC_0(20212223, 24, 1); + CALC_0(20212223, 24, 2); + + vst1q_s32(outptr, _sum00); + vst1q_s32(outptr + 4, _sum01); + vst1q_s32(outptr + 8, _sum02); + + r0 += 12; + r1 += 12; + r2 += 12; + r3 += 12; + r4 += 12; + outptr += 12; + } + for (; w < width; w++) { + int32x4_t _sum00 = vld1q_s32(outptr); + + _tmp = vtranslq_s8(vld1_s8(r0)); + CALC_0(123, 4, 0); + + _tmp = vtranslq_s8(vld1_s8(r1)); + CALC_0(5678, 9, 0); + + _tmp = vtranslq_s8(vld1_s8(r2)); + CALC_0(10111213, 14, 0); + + _tmp = vtranslq_s8(vld1_s8(r3)); + CALC_0(15161718, 19, 0); + + _tmp = vtranslq_s8(vld1_s8(r4)); + CALC_0(20212223, 24, 0); + + vst1q_s32(outptr, _sum00); + + r0 += 4; + r1 += 4; + r2 += 
4; + r3 += 4; + r4 += 4; + outptr += 4; + } + r0 += tail_step; + r1 += tail_step; + r2 += tail_step; + r3 += tail_step; + r4 += tail_step; + } + + filter += 25; + } +} + +void deconv_direct_7x7(const int8_t* src, const int8_t* filter, int32_t* dst, + size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) { + MEGDNN_MARK_USED_VAR(IH); + const size_t tail_step = IW - OW; + + const uint8x16_t _idx00 = {0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6}; + const uint8x16_t _idx01 = {4, 5, 6, 16, 5, 6, 7, 16, + 6, 7, 8, 16, 7, 8, 9, 16}; + const uint8x16_t _idx10 = {4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10}; + const uint8x16_t _idx11 = {8, 9, 10, 16, 9, 10, 11, 16, + 10, 11, 12, 16, 11, 12, 13, 16}; + + int8x16_t _tmp, _elem; + rep(ic, IC) { + const int8_t* src_ptr = src; + int32_t* dst_ptr = dst + OW * OH * ic; + int32_t* outptr = dst_ptr; + int32_t* outptr2 = outptr + OW; + + const int8_t* r0 = src_ptr; + const int8_t* r1 = src_ptr + IW; + const int8_t* r2 = src_ptr + IW * 2; + const int8_t* r3 = src_ptr + IW * 3; + const int8_t* r4 = src_ptr + IW * 4; + const int8_t* r5 = src_ptr + IW * 5; + const int8_t* r6 = src_ptr + IW * 6; + const int8_t* r7 = src_ptr + IW * 7; + + const int8_t* k0 = filter; + + int8x16_t _k = vld1q_s8(k0 + 33); + //! filter row 1 + uint8x16_t _idx = {15, 14, 13, 12, 15, 14, 13, 12, + 15, 14, 13, 12, 15, 14, 13, 12}; + int8x16_t _k123 = vqtbl1q_s8_common(_k, _idx); + _idx = {11, 10, 9, 16, 11, 10, 9, 16, 11, 10, 9, 16, 11, 10, 9, 16}; + int8x16_t _k456 = vqtbl1q_s8_common(_k, _idx); + //! filter row 2 + _idx = {8, 7, 6, 5, 8, 7, 6, 5, 8, 7, 6, 5, 8, 7, 6, 5}; + int8x16_t _k78910 = vqtbl1q_s8_common(_k, _idx); + _idx = {4, 3, 2, 16, 4, 3, 2, 16, 4, 3, 2, 16, 4, 3, 2, 16}; + int8x16_t _k111213 = vqtbl1q_s8_common(_k, _idx); + + //! 12 13 14 15 -> 16 17 18 19 -> 20 21 22 23 -> 24 25 26 27 + _k = vld1q_s8(k0 + 19); + //! filter row 3 + _idx = {15, 14, 13, 12, 15, 14, 13, 12, 15, 14, 13, 12, 15, 14, 13, 12}; + int8x16_t _k14151617 = vqtbl1q_s8_common(_k, _idx); + _idx = {11, 10, 9, 16, 11, 10, 9, 16, 11, 10, 9, 16, 11, 10, 9, 16}; + int8x16_t _k181920 = vqtbl1q_s8_common(_k, _idx); + //! filter row 4 + _idx = {8, 7, 6, 5, 8, 7, 6, 5, 8, 7, 6, 5, 8, 7, 6, 5}; + int8x16_t _k21222324 = vqtbl1q_s8_common(_k, _idx); + _idx = {4, 3, 2, 16, 4, 3, 2, 16, 4, 3, 2, 16, 4, 3, 2, 16}; + int8x16_t _k252627 = vqtbl1q_s8_common(_k, _idx); + + //! 24 25 26 27->28 29 30 31 -> 32 33 34 35 -> 36 37 38 39 + _k = vld1q_s8(k0 + 5); + //! filter row 5 + _idx = {15, 14, 13, 12, 15, 14, 13, 12, 15, 14, 13, 12, 15, 14, 13, 12}; + int8x16_t _k28293031 = vqtbl1q_s8_common(_k, _idx); + _idx = {11, 10, 9, 16, 11, 10, 9, 16, 11, 10, 9, 16, 11, 10, 9, 16}; + int8x16_t _k323334 = vqtbl1q_s8_common(_k, _idx); + + //! 33 34 35 36 -> 37 38 39 40 -> 41 42 43 44 -> 45 46 47 48 + _k = vld1q_s8(k0); + //! filter row 6 + _idx = {13, 12, 11, 10, 13, 12, 11, 10, 13, 12, 11, 10, 13, 12, 11, 10}; + int8x16_t _k35363738 = vqtbl1q_s8_common(_k, _idx); + _idx = {9, 8, 7, 16, 9, 8, 7, 16, 9, 8, 7, 16, 9, 8, 7, 16}; + int8x16_t _k394041 = vqtbl1q_s8_common(_k, _idx); + + //! filter row 7 + _idx = {6, 5, 4, 3, 6, 5, 4, 3, 6, 5, 4, 3, 6, 5, 4, 3}; + int8x16_t _k42434445 = vqtbl1q_s8_common(_k, _idx); + _idx = {2, 1, 0, 16, 2, 1, 0, 16, 2, 1, 0, 16, 2, 1, 0, 16}; + int8x16_t _k464748 = vqtbl1q_s8_common(_k, _idx); + + const int width = OW >> 2; + size_t h = 0; + for (; h + 1 < OH; h += 2) { + int w = 0; + for (; w + 2 < width; w += 2) { + //! 
As the inner kernel read 16 elements, and IW is times of 16 + int32x4_t _sum00 = vld1q_s32(outptr); + int32x4_t _sum01 = vld1q_s32(outptr + 4); + int32x4_t _sum10 = vld1q_s32(outptr2); + int32x4_t _sum11 = vld1q_s32(outptr2 + 4); + + _tmp = vld1q_s8(r0); + CALC_0(123, 456, 0); + CALC_0(123, 456, 1); + + _tmp = vld1q_s8(r1); + CALC_2(78910, 111213, 123, 456, 0); + CALC_2(78910, 111213, 123, 456, 1); + + _tmp = vld1q_s8(r2); + CALC_2(14151617, 181920, 78910, 111213, 0); + CALC_2(14151617, 181920, 78910, 111213, 1); + + _tmp = vld1q_s8(r3); + CALC_2(21222324, 252627, 14151617, 181920, 0); + CALC_2(21222324, 252627, 14151617, 181920, 1); + + _tmp = vld1q_s8(r4); + CALC_2(28293031, 323334, 21222324, 252627, 0); + CALC_2(28293031, 323334, 21222324, 252627, 1); + + _tmp = vld1q_s8(r5); + CALC_2(35363738, 394041, 28293031, 323334, 0); + CALC_2(35363738, 394041, 28293031, 323334, 1); + + _tmp = vld1q_s8(r6); + CALC_2(42434445, 464748, 35363738, 394041, 0); + CALC_2(42434445, 464748, 35363738, 394041, 1); + + _tmp = vld1q_s8(r7); + CALC_1(42434445, 464748, 0); + CALC_1(42434445, 464748, 1); + + vst1q_s32(outptr, _sum00); + vst1q_s32(outptr + 4, _sum01); + vst1q_s32(outptr2, _sum10); + vst1q_s32(outptr2 + 4, _sum11); + + r0 += 8; + r1 += 8; + r2 += 8; + r3 += 8; + r4 += 8; + r5 += 8; + r6 += 8; + r7 += 8; + outptr += 8; + outptr2 += 8; + } + for (; w < width; w++) { + int32x4_t _sum00 = vld1q_s32(outptr); + int32x4_t _sum10 = vld1q_s32(outptr2); + + _tmp = vld1q_s8(r0); + CALC_0(123, 456, 0); + + _tmp = vld1q_s8(r1); + CALC_2(78910, 111213, 123, 456, 0); + + _tmp = vld1q_s8(r2); + CALC_2(14151617, 181920, 78910, 111213, 0); + + _tmp = vld1q_s8(r3); + CALC_2(21222324, 252627, 14151617, 181920, 0); + + _tmp = vld1q_s8(r4); + CALC_2(28293031, 323334, 21222324, 252627, 0); + + _tmp = vld1q_s8(r5); + CALC_2(35363738, 394041, 28293031, 323334, 0); + + _tmp = vld1q_s8(r6); + CALC_2(42434445, 464748, 35363738, 394041, 0); + + _tmp = vld1q_s8(r7); + CALC_1(42434445, 464748, 0); + + vst1q_s32(outptr, _sum00); + vst1q_s32(outptr2, _sum10); + + r0 += 4; + r1 += 4; + r2 += 4; + r3 += 4; + r4 += 4; + r5 += 4; + r6 += 4; + r7 += 4; + outptr += 4; + outptr2 += 4; + } + + r0 += tail_step + IW; + r1 += tail_step + IW; + r2 += tail_step + IW; + r3 += tail_step + IW; + r4 += tail_step + IW; + r5 += tail_step + IW; + r6 += tail_step + IW; + r7 += tail_step + IW; + + outptr += OW; + outptr2 += OW; + } + + for (; h < OH; h++) { + int w = 0; + for (; w + 2 < width; w += 2) { + int32x4_t _sum00 = vld1q_s32(outptr); + int32x4_t _sum01 = vld1q_s32(outptr + 4); + + _tmp = vld1q_s8(r0); + CALC_0(123, 456, 0); + CALC_0(123, 456, 1); + + _tmp = vld1q_s8(r1); + CALC_0(78910, 111213, 0); + CALC_0(78910, 111213, 1); + + _tmp = vld1q_s8(r2); + CALC_0(14151617, 181920, 0); + CALC_0(14151617, 181920, 1); + + _tmp = vld1q_s8(r3); + CALC_0(21222324, 252627, 0); + CALC_0(21222324, 252627, 1); + + _tmp = vld1q_s8(r4); + CALC_0(28293031, 323334, 0); + CALC_0(28293031, 323334, 1); + + _tmp = vld1q_s8(r5); + CALC_0(35363738, 394041, 0); + CALC_0(35363738, 394041, 1); + + _tmp = vld1q_s8(r6); + CALC_0(42434445, 464748, 0); + CALC_0(42434445, 464748, 1); + + vst1q_s32(outptr, _sum00); + vst1q_s32(outptr + 4, _sum01); + + r0 += 8; + r1 += 8; + r2 += 8; + r3 += 8; + r4 += 8; + r5 += 8; + r6 += 8; + outptr += 8; + } + for (; w < width; w++) { + int32x4_t _sum00 = vld1q_s32(outptr); + + _tmp = vld1q_s8(r0); + CALC_0(123, 456, 0); + + _tmp = vld1q_s8(r1); + CALC_0(78910, 111213, 0); + + _tmp = vld1q_s8(r2); + CALC_0(14151617, 181920, 0); + + _tmp = 
vld1q_s8(r3); + CALC_0(21222324, 252627, 0); + + _tmp = vld1q_s8(r4); + CALC_0(28293031, 323334, 0); + + _tmp = vld1q_s8(r5); + CALC_0(35363738, 394041, 0); + + _tmp = vld1q_s8(r6); + CALC_0(42434445, 464748, 0); + + vst1q_s32(outptr, _sum00); + + r0 += 4; + r1 += 4; + r2 += 4; + r3 += 4; + r4 += 4; + r5 += 4; + r6 += 4; + outptr += 4; + } + r0 += tail_step; + r1 += tail_step; + r2 += tail_step; + r3 += tail_step; + r4 += tail_step; + r5 += tail_step; + r6 += tail_step; + } + + filter += 49; + } +} + +#undef CALC_0 +#undef CALC_1 +#undef CALC_2 + +} // anonymous namespace + +size_t deconv::get_workspace_in_bytes_stride1_int8x8x32_dot( + const NCBKernSizeParam& param) { + return get_bundle(param).total_size_in_bytes(); +} + +bool deconv::can_stride1_int8x8x32_dot(const NCBKernSizeParam& param) { + auto&& fm = param.filter_meta; + auto FH = fm.spatial[0], FW = fm.spatial[1], OC = fm.ocpg, IC = fm.icpg, + PH = fm.padding[0], PW = fm.padding[1]; + bool avaiable = fm.format == param::Convolution::Format::NCHW && + !fm.should_flip && fm.spatial_ndim == 2 && + fm.dilation[0] == 1 && fm.dilation[1] == 1 && + fm.stride[0] == 1 && fm.stride[1] == 1 && FH == FW && + (FH == 2 || FH == 3 || FH == 5 || FH == 7) && + FH >= PH + 1 && FW >= PW + 1; + + return avaiable && + ((FH == 2 && OC <= 8) || + ((FH == 3 || FH == 5 || FH == 7) && (IC < 32 && OC <= 16))); +} + +void deconv::stride1_int8x8x32_dot(const NCBKernParam& param) { + auto bundle = get_bundle(param); + bundle.set(param.workspace_ptr); + UNPACK_CONV_F32_NCB_KERN_SIZES(param); + MEGDNN_MARK_USED_VAR(SH); + MEGDNN_MARK_USED_VAR(SW); + size_t IH2, IW2, OW2; + int padding_h = FH - PH - 1, padding_w = FW - PW - 1; + + get_rectified_size(IH, IW, OH, OW, FH, FW, PH, PW, IH2, IW2, OW2); + + using Func = std::function; + Func conv = nullptr; + if (FH == 2) { + conv = deconv_direct_2x2; + } else if (FH == 3) { + conv = deconv_direct_3x3; + } else if (FH == 5) { + conv = deconv_direct_5x5; + } else if (FH == 7) { + conv = deconv_direct_7x7; + } else { + megdnn_assert(0); + } + + bool need_src_copy_var = need_src_copy(param); + bool need_dst_copy_var = need_dst_copy(param); + int8_t* base_src_ptr = const_cast(param.diff()); + int32_t* base_dst_ptr = param.grad(); + const int8_t* fptr = param.filter(); + + for (size_t n = 0; n < N; ++n) { + int32_t* dptr_copied = static_cast(bundle.get(1)); + int32_t* dptr_ori = base_dst_ptr + n * param.out_bs; + int32_t* dptr = nullptr; + size_t OW_real = OW; + if (need_dst_copy_var) { + dptr = dptr_copied; + OW_real = OW2; + } else { + dptr = dptr_ori; + } + std::memset(dptr, 0, sizeof(int32_t) * IC * OH * OW_real); + + int8_t* sptr_ori = base_src_ptr + n * param.inp_bs; + int8_t* sptr_copied = static_cast(bundle.get(0)); + int8_t* sptr = nullptr; + + rep(oc, OC) { + if (need_src_copy_var) { + // copy sptr_ori to sptr_copied + std::memset(sptr_copied, 0, sizeof(int8_t) * IH2 * IW2); + copy_plane_in_bytes(sptr_copied + padding_h * IW2 + padding_w, + sptr_ori + oc * IH * IW, IH, + IW * sizeof(int8_t), IW2 * sizeof(int8_t), + IW * sizeof(int8_t)); + sptr = sptr_copied; + } else { + sptr = sptr_ori + oc * IH * IW; + } + conv(sptr, fptr + oc * IC * FH * FW, dptr, IH2, IW2, OH, OW_real, + IC); + } + if (need_dst_copy_var) { + for (size_t ic = 0; ic < IC; ++ic) { + copy_plane_in_bytes(dptr_ori + ic * OH * OW, + dptr + ic * OH * OW2, OH, + OW * sizeof(int32_t), OW * sizeof(int32_t), + OW2 * sizeof(int32_t)); + } + } + } +} + +#endif +// vim: syntax=cpp.doxygen diff --git 
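The stride1_int8x8x32_dot driver above splits one workspace allocation into two scratch areas: a zero-padded int8 source plane of IH2 * IW2 bytes, and an optional widened int32 destination of IC * OH * OW2 elements that is copied back whenever OW is not a multiple of four. A hedged sketch of that carving is shown below; DeconvScratch and split_workspace are invented names, not MegEngine's WorkspaceBundle.

#include <cstddef>
#include <cstdint>

struct DeconvScratch {
    int8_t* padded_src;
    int32_t* wide_dst;  // nullptr when the real output can be written directly
};

inline DeconvScratch split_workspace(void* raw, size_t src_bytes,
                                     size_t dst_bytes) {
    auto* base = static_cast<uint8_t*>(raw);
    // Round the int32 sub-buffer up to its natural alignment, something a real
    // workspace bundle has to guarantee as well.
    size_t dst_off =
            (src_bytes + alignof(int32_t) - 1) & ~(alignof(int32_t) - 1);
    DeconvScratch s{};
    s.padded_src = reinterpret_cast<int8_t*>(base);
    s.wide_dst =
            dst_bytes ? reinterpret_cast<int32_t*>(base + dst_off) : nullptr;
    return s;
}

Rounding OW2 up with (OW + 3) & ~3 is what lets the inner kernels always store four int32 results at a time; the final plane copy restores the true OW.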
a/dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride1.h b/dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride1.h new file mode 100644 index 00000000..eae4b531 --- /dev/null +++ b/dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride1.h @@ -0,0 +1,38 @@ +/** + * \file dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride1.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once + +#if __ARM_FEATURE_DOTPROD +#include "src/arm_common/convolution/opr_impl.h" + +#include +#include + +namespace megdnn { +namespace arm_common { +namespace deconv { + +using NCBKernSizeParam = ConvolutionBackwardDataImpl::NCBKernSizeParam; +using NCBKernParam = ConvolutionBackwardDataImpl::NCBKernParam; + +bool can_stride1_int8x8x32_dot(const NCBKernSizeParam& param); + +void stride1_int8x8x32_dot(const NCBKernParam& param); + +size_t get_workspace_in_bytes_stride1_int8x8x32_dot( + const NCBKernSizeParam& param); + +} // namespace deconv +} // namespace arm_common +} // namespace megdnn +#endif + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride2.cpp b/dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride2.cpp new file mode 100644 index 00000000..9a8e9774 --- /dev/null +++ b/dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride2.cpp @@ -0,0 +1,1269 @@ +/** + * \file dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride2.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#if __ARM_FEATURE_DOTPROD +#include "src/arm_common/convolution/int8x8x32/conv_backdata_stride2.h" +#include "src/common/utils.h" + +#include +#include "src/arm_common/simd_macro/marm_neon.h" + +using namespace megdnn; +using namespace arm_common; +using namespace deconv; + +namespace { + +bool need_dst_copy(const NCBKernSizeParam& param) { + if (param.osz[1] % 4 != 0) { + // If the size of output is not multiples of 4, we need to copy it. + return true; + } + return false; +} + +void get_rectified_size(size_t IH, size_t IW, size_t OH, size_t OW, size_t FH, + size_t FW, size_t PH, size_t PW, size_t& IH2, + size_t& IW2, size_t& OW2) { + MEGDNN_MARK_USED_VAR(OH); + MEGDNN_MARK_USED_VAR(IW); + //! 
OW should be a multiple of 4 + OW2 = (OW + 3) & ~3; + IH2 = 2 * IH - 1 + 2 * (FH - PH - 1); + IW2 = (OW2 - FW + 2 * PW) / 2 + 1 + (FW - PW - 1) + 16; +} + +WorkspaceBundle get_bundle(const NCBKernSizeParam& param) { + UNPACK_CONV_F32_NCB_KERN_SIZES(param); + MEGDNN_MARK_USED_VAR(N); + MEGDNN_MARK_USED_VAR(OC); + MEGDNN_MARK_USED_VAR(SH); + MEGDNN_MARK_USED_VAR(SW); + size_t src_size = 0, dst_size = 0; + size_t IH2, IW2, OW2; + get_rectified_size(IH, IW, OH, OW, FH, FW, PH, PW, IH2, IW2, OW2); + src_size = sizeof(int8_t) * IH2 * IW2; + if (need_dst_copy(param)) { + dst_size = sizeof(int32_t) * IC * OH * OW2; + } + return WorkspaceBundle(nullptr, {src_size, dst_size}); +} + +inline int8x16_t vqtbl1q_s8_common(int8x16_t a, uint8x16_t index) { + int8x8x2_t src; + src.val[0] = vget_low_s8(a); + src.val[1] = vget_high_s8(a); + uint8x8_t index_low = vget_low_u8(index); + uint8x8_t index_high = vget_high_u8(index); + int8x8_t r00 = vtbl2_s8(src, vreinterpret_s8_u8(index_low)); + int8x8_t r01 = vtbl2_s8(src, vreinterpret_s8_u8(index_high)); + int8x16_t r = vcombine_s8(r00, r01); + return r; +} + +#define CALC_0(_k_idx, _c_idx) \ + _elem = vqtbl1q_s8_common(_tmp, _idx##_c_idx); \ + _sum0##_c_idx = vdotq_s32(_sum0##_c_idx, _k##_k_idx, _elem); + +#define CALC_1(_k_idx, _c_idx) \ + _elem = vqtbl1q_s8_common(_tmp, _idx##_c_idx); \ + _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k_idx, _elem); + +#define CALC_2(_k1_idx, _k2_idx, _c_idx) \ + _elem = vqtbl1q_s8_common(_tmp, _idx##_c_idx); \ + _sum0##_c_idx = vdotq_s32(_sum0##_c_idx, _k##_k1_idx, _elem); \ + _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k2_idx, _elem); + +template +void deconv_direct_2x2(const int8_t* src, const int8_t* filter, int32_t* dst, + size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) { + MEGDNN_MARK_USED_VAR(IH); + const size_t tail_step = IW - OW / 2; + + const uint8x16_t _idx0 = {0, 1, 16, 16, 1, 2, 16, 16, + 2, 3, 16, 16, 3, 4, 16, 16}; + const uint8x16_t _idx1 = {4, 5, 16, 16, 5, 6, 16, 16, + 6, 7, 16, 16, 7, 8, 16, 16}; + uint8x16_t _idx_r_0, _idx_r_1; + if (even) { + _idx_r_0 = {0, 16, 1, 16, 2, 16, 3, 16, 4, 16, 5, 16, 6, 16, 7, 16}; + _idx_r_1 = {16, 1, 16, 2, 16, 3, 16, 4, 16, 5, 16, 6, 16, 7, 16, 8}; + } else { + _idx_r_0 = {16, 0, 16, 1, 16, 2, 16, 3, 16, 4, 16, 5, 16, 6, 16, 7}; + _idx_r_1 = {0, 16, 1, 16, 2, 16, 3, 16, 4, 16, 5, 16, 6, 16, 7, 16}; + } + rep(ic, IC) { + const int8_t* src_ptr = src; + int32_t* dst_ptr = dst + OW * OH * ic; + int32_t* outptr = dst_ptr; + int32_t* outptr2 = dst_ptr + OW; + + const int8_t* r0 = src_ptr; + const int8_t* r1 = src_ptr + IW; + const int8_t* r2 = src_ptr + 2 * IW; + + const int8_t* k0 = filter; + + int8x16_t _k0 = vreinterpretq_s8_s32( + vdupq_n_s32(*reinterpret_cast(k0))); + uint8x16_t _idx_k = {3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0}; + int8x16_t _k = vqtbl1q_s8_common(_k0, _idx_k); + uint8x16_t _idx = {0, 1, 16, 16, 0, 1, 16, 16, + 0, 1, 16, 16, 0, 1, 16, 16}; + int8x16_t _k1 = vqtbl1q_s8_common(_k, _idx); + _idx = {2, 3, 16, 16, 2, 3, 16, 16, 2, 3, 16, 16, 2, 3, 16, 16}; + int8x16_t _k23 = vqtbl1q_s8_common(_k, _idx); + + int8x16_t _tmp, _elem; + const int width = OW >> 2; + size_t h = 0; + for (; h + 1 < OH; h += 2) { + int w = 0; + for (; w + 4 < width; w += 4) { + int32x4x2_t _sum00, _sum01, _sum10, _sum11; + _sum00 = vld2q_s32(outptr); + _sum01 = vld2q_s32(outptr + 8); + _sum10 = vld2q_s32(outptr2); + _sum11 = vld2q_s32(outptr2 + 8); + + int8x16_t _r0_ori = vld1q_s8(r0); + int8x16_t _r00 = vqtbl1q_s8_common(_r0_ori, _idx_r_0); + int8x16_t _r01 = 
vqtbl1q_s8_common(_r0_ori, _idx_r_1); + int8x16_t _r1_ori = vld1q_s8(r1); + int8x16_t _r10 = vqtbl1q_s8_common(_r1_ori, _idx_r_0); + int8x16_t _r11 = vqtbl1q_s8_common(_r1_ori, _idx_r_1); + int8x16_t _r2_ori = vld1q_s8(r2); + int8x16_t _r20 = vqtbl1q_s8_common(_r2_ori, _idx_r_0); + int8x16_t _r21 = vqtbl1q_s8_common(_r2_ori, _idx_r_1); + + int16x8x2_t r_00 = vzipq_s16(vreinterpretq_s16_s8(_r00), + vreinterpretq_s16_s8(_r10)); + int8x16_t _r0 = r_00.val[0]; + int8x16_t _r2 = r_00.val[1]; + + int16x8x2_t r_11 = vzipq_s16(vreinterpretq_s16_s8(_r01), + vreinterpretq_s16_s8(_r11)); + int8x16_t _r1 = r_11.val[0]; + int8x16_t _r3 = r_11.val[1]; + + _sum00.val[0] = vdotq_s32(_sum00.val[0], _k, _r0); + _sum00.val[1] = vdotq_s32(_sum00.val[1], _k, _r1); + _sum01.val[0] = vdotq_s32(_sum01.val[0], _k, _r2); + _sum01.val[1] = vdotq_s32(_sum01.val[1], _k, _r3); + + r_00 = vzipq_s16(vreinterpretq_s16_s8(_r10), + vreinterpretq_s16_s8(_r20)); + _r0 = r_00.val[0]; + _r2 = r_00.val[1]; + + r_11 = vzipq_s16(vreinterpretq_s16_s8(_r11), + vreinterpretq_s16_s8(_r21)); + _r1 = r_11.val[0]; + _r3 = r_11.val[1]; + + _sum10.val[0] = vdotq_s32(_sum10.val[0], _k, _r0); + _sum10.val[1] = vdotq_s32(_sum10.val[1], _k, _r1); + _sum11.val[0] = vdotq_s32(_sum11.val[0], _k, _r2); + _sum11.val[1] = vdotq_s32(_sum11.val[1], _k, _r3); + + vst2q_s32(outptr, _sum00); + vst2q_s32(outptr + 8, _sum01); + vst2q_s32(outptr2, _sum10); + vst2q_s32(outptr2 + 8, _sum11); + + r0 += 8; + r1 += 8; + r2 += 8; + outptr += 16; + outptr2 += 16; + } + for (; w + 2 < width; w += 2) { + int32x4_t _sum00 = vld1q_s32(outptr); + int32x4_t _sum01 = vld1q_s32(outptr + 4); + int32x4_t _sum10 = vld1q_s32(outptr2); + int32x4_t _sum11 = vld1q_s32(outptr2 + 4); + + int8x16_t _r_ori = vld1q_s8(r0); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_0(1, 0); + CALC_0(1, 1); + + _r_ori = vld1q_s8(r1); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_2(23, 1, 0); + CALC_2(23, 1, 1); + + _r_ori = vld1q_s8(r2); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_1(23, 0); + CALC_1(23, 1); + + vst1q_s32(outptr, _sum00); + vst1q_s32(outptr + 4, _sum01); + vst1q_s32(outptr2, _sum10); + vst1q_s32(outptr2 + 4, _sum11); + + r0 += 4; + r1 += 4; + r2 += 4; + outptr += 8; + outptr2 += 8; + } + + for (; w < width; w++) { + int32x4_t _sum00 = vld1q_s32(outptr); + int32x4_t _sum10 = vld1q_s32(outptr2); + + int8x16_t _r_ori = vtranslq_s8(vld1_s8(r0)); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_0(1, 0); + + _r_ori = vtranslq_s8(vld1_s8(r1)); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_2(23, 1, 0); + + _r_ori = vtranslq_s8(vld1_s8(r2)); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_1(23, 0); + + vst1q_s32(outptr, _sum00); + vst1q_s32(outptr2, _sum10); + + r0 += 2; + r1 += 2; + r2 += 2; + outptr += 4; + outptr2 += 4; + } + r0 += tail_step + IW; + r1 += tail_step + IW; + r2 += tail_step + IW; + outptr += OW; + outptr2 += OW; + } + + for (; h < OH; h++) { + int w = 0; + for (; w + 4 < width; w += 4) { + int32x4x2_t _sum0, _sum1; + _sum0 = vld2q_s32(outptr); + _sum1 = vld2q_s32(outptr + 8); + + int8x16_t _r0_ori = vld1q_s8(r0); + int8x16_t _r00 = vqtbl1q_s8_common(_r0_ori, _idx_r_0); + int8x16_t _r01 = vqtbl1q_s8_common(_r0_ori, _idx_r_1); + int8x16_t _r1_ori = vld1q_s8(r1); + int8x16_t _r10 = vqtbl1q_s8_common(_r1_ori, _idx_r_0); + int8x16_t _r11 = vqtbl1q_s8_common(_r1_ori, _idx_r_1); + + int16x8x2_t r_00 = vzipq_s16(vreinterpretq_s16_s8(_r00), + vreinterpretq_s16_s8(_r10)); + int8x16_t _r0 = r_00.val[0]; + int8x16_t _r2 = r_00.val[1]; + + 
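// Standalone scalar reference (illustrative only, not MegEngine code) for the
// vdotq_s32 calls used throughout these kernels: every 32-bit lane of the
// accumulator gains the dot product of the four corresponding int8 pairs. The
// vqtbl/vzip shuffles around this point exist purely to line the bytes up so
// that one such instruction folds up to four filter taps into each output lane.
#include <cstdint>

inline void sdot_s32_reference(int32_t acc[4], const int8_t a[16],
                               const int8_t b[16]) {
    for (int lane = 0; lane < 4; ++lane)
        for (int k = 0; k < 4; ++k)
            acc[lane] += int32_t(a[4 * lane + k]) * int32_t(b[4 * lane + k]);
}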
int16x8x2_t r_11 = vzipq_s16(vreinterpretq_s16_s8(_r01), + vreinterpretq_s16_s8(_r11)); + int8x16_t _r1 = r_11.val[0]; + int8x16_t _r3 = r_11.val[1]; + + _sum0.val[0] = vdotq_s32(_sum0.val[0], _k, _r0); + _sum0.val[1] = vdotq_s32(_sum0.val[1], _k, _r1); + _sum1.val[0] = vdotq_s32(_sum1.val[0], _k, _r2); + _sum1.val[1] = vdotq_s32(_sum1.val[1], _k, _r3); + + vst2q_s32(outptr, _sum0); + vst2q_s32(outptr + 8, _sum1); + + r0 += 8; + r1 += 8; + outptr += 16; + } + for (; w + 2 < width; w += 2) { + int32x4_t _sum00 = vld1q_s32(outptr); + int32x4_t _sum01 = vld1q_s32(outptr + 4); + + int8x16_t _r_ori = vld1q_s8(r0); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_0(1, 0); + CALC_0(1, 1); + + _r_ori = vld1q_s8(r1); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_0(23, 0); + CALC_0(23, 1); + + vst1q_s32(outptr, _sum00); + vst1q_s32(outptr + 4, _sum01); + + r0 += 4; + r1 += 4; + outptr += 8; + } + + for (; w < width; w++) { + int32x4_t _sum00 = vld1q_s32(outptr); + + int8x16_t _r_ori = vtranslq_s8(vld1_s8(r0)); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_0(1, 0); + + _r_ori = vtranslq_s8(vld1_s8(r1)); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_0(23, 0); + + vst1q_s32(outptr, _sum00); + + r0 += 2; + r1 += 2; + outptr += 4; + } + r0 += tail_step; + r1 += tail_step; + } + + filter += 4; + } +} + +template +void deconv_direct_3x3(const int8_t* src, const int8_t* filter, int32_t* dst, + size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) { + MEGDNN_MARK_USED_VAR(IH); + const size_t tail_step = IW - OW / 2; + + const uint8x16_t _idx0 = {0, 1, 2, 16, 1, 2, 3, 16, + 2, 3, 4, 16, 3, 4, 5, 16}; + const uint8x16_t _idx1 = {4, 5, 6, 16, 5, 6, 7, 16, + 6, 7, 8, 16, 7, 8, 9, 16}; + const uint8x16_t _idx2 = {8, 9, 10, 16, 9, 10, 11, 16, + 10, 11, 12, 16, 11, 12, 13, 16}; + + uint8x16_t _idx_r_0; + if (even) { + _idx_r_0 = {0, 16, 1, 16, 2, 16, 3, 16, 4, 16, 5, 16, 6, 16, 7, 16}; + } else { + _idx_r_0 = {16, 0, 16, 1, 16, 2, 16, 3, 16, 4, 16, 5, 16, 6, 16, 7}; + } + rep(ic, IC) { + const int8_t* src_ptr = src; + int32_t* dst_ptr = dst + OW * OH * ic; + int32_t* outptr = dst_ptr; + int32_t* outptr2 = outptr + OW; + + const int8_t* r0 = src_ptr; + const int8_t* r1 = src_ptr + IW; + const int8_t* r2 = src_ptr + IW * 2; + const int8_t* r3 = src_ptr + IW * 3; + + const int8_t* k0 = filter; + + int8x16_t _k_tmp = vcombine_s8(vld1_s8(k0), vdup_n_s8(k0[8])); + uint8x16_t _idx = {8, 7, 6, 16, 8, 7, 6, 16, 8, 7, 6, 16, 8, 7, 6, 16}; + int8x16_t _k12 = vqtbl1q_s8_common(_k_tmp, _idx); + _idx = {5, 4, 3, 16, 5, 4, 3, 16, 5, 4, 3, 16, 5, 4, 3, 16}; + int8x16_t _k345 = vqtbl1q_s8_common(_k_tmp, _idx); + _idx = {2, 1, 0, 16, 2, 1, 0, 16, 2, 1, 0, 16, 2, 1, 0, 16}; + int8x16_t _k678 = vqtbl1q_s8_common(_k_tmp, _idx); + + int8x16_t _tmp, _elem; + size_t h = 0; + for (; h + 1 < OH; h += 2) { + int width = OW >> 2; + + int w = 0; + for (; w + 3 < width; w += 3) { + //! 
As the inner kernel read 16 elements, and IW is times of 16 + int32x4_t _sum00 = vld1q_s32(outptr); + int32x4_t _sum01 = vld1q_s32(outptr + 4); + int32x4_t _sum02 = vld1q_s32(outptr + 8); + int32x4_t _sum10 = vld1q_s32(outptr2); + int32x4_t _sum11 = vld1q_s32(outptr2 + 4); + int32x4_t _sum12 = vld1q_s32(outptr2 + 8); + + int8x16_t _r_ori = vld1q_s8(r0); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_0(12, 0); + CALC_0(12, 1); + CALC_0(12, 2); + + _r_ori = vld1q_s8(r1); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_2(345, 12, 0); + CALC_2(345, 12, 1); + CALC_2(345, 12, 2); + + _r_ori = vld1q_s8(r2); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_2(678, 345, 0); + CALC_2(678, 345, 1); + CALC_2(678, 345, 2); + + _r_ori = vld1q_s8(r3); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_1(678, 0); + CALC_1(678, 1); + CALC_1(678, 2); + + vst1q_s32(outptr, _sum00); + vst1q_s32(outptr + 4, _sum01); + vst1q_s32(outptr + 8, _sum02); + vst1q_s32(outptr2, _sum10); + vst1q_s32(outptr2 + 4, _sum11); + vst1q_s32(outptr2 + 8, _sum12); + + r0 += 6; + r1 += 6; + r2 += 6; + r3 += 6; + outptr += 12; + outptr2 += 12; + } + for (; w < width; w++) { + int32x4_t _sum00 = vld1q_s32(outptr); + int32x4_t _sum10 = vld1q_s32(outptr2); + + int8x16_t _r_ori = vtranslq_s8(vld1_s8(r0)); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_0(12, 0); + + _r_ori = vtranslq_s8(vld1_s8(r1)); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_2(345, 12, 0); + + _r_ori = vtranslq_s8(vld1_s8(r2)); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_2(678, 345, 0); + + _r_ori = vtranslq_s8(vld1_s8(r3)); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_1(678, 0); + + vst1q_s32(outptr, _sum00); + vst1q_s32(outptr2, _sum10); + + r0 += 2; + r1 += 2; + r2 += 2; + r3 += 2; + outptr += 4; + outptr2 += 4; + } + + r0 += tail_step + IW; + r1 += tail_step + IW; + r2 += tail_step + IW; + r3 += tail_step + IW; + + outptr += OW; + outptr2 += OW; + } + + for (; h < OH; h++) { + int width = OW >> 2; + + int w = 0; + for (; w + 3 < width; w += 3) { + int32x4_t _sum00 = vld1q_s32(outptr); + int32x4_t _sum01 = vld1q_s32(outptr + 4); + int32x4_t _sum02 = vld1q_s32(outptr + 8); + + int8x16_t _r_ori = vld1q_s8(r0); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_0(12, 0); + CALC_0(12, 1); + CALC_0(12, 2); + + _r_ori = vld1q_s8(r1); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_0(345, 0); + CALC_0(345, 1); + CALC_0(345, 2); + + _r_ori = vld1q_s8(r2); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_0(678, 0); + CALC_0(678, 1); + CALC_0(678, 2); + + vst1q_s32(outptr, _sum00); + vst1q_s32(outptr + 4, _sum01); + vst1q_s32(outptr + 8, _sum02); + + r0 += 6; + r1 += 6; + r2 += 6; + outptr += 12; + } + for (; w < width; w++) { + int32x4_t _sum00 = vld1q_s32(outptr); + + int8x16_t _r_ori = vtranslq_s8(vld1_s8(r0)); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_0(12, 0); + + _r_ori = vtranslq_s8(vld1_s8(r1)); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_0(345, 0); + + _r_ori = vtranslq_s8(vld1_s8(r2)); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_0(678, 0); + + vst1q_s32(outptr, _sum00); + + r0 += 2; + r1 += 2; + r2 += 2; + outptr += 4; + } + r0 += tail_step; + r1 += tail_step; + r2 += tail_step; + } + + filter += 9; + } +} + +#undef CALC_0 +#undef CALC_1 +#undef CALC_2 + +#define CALC_0(_k00_idx, _k01_idx, _c_idx) \ + _elem = vqtbl1q_s8_common(_tmp, _idx##_c_idx##0); \ + _sum0##_c_idx = vdotq_s32(_sum0##_c_idx, _k##_k00_idx, _elem); \ + _elem = vqtbl1q_s8_common(_tmp, _idx##_c_idx##1); \ + 
_sum0##_c_idx = vdotq_s32(_sum0##_c_idx, _k##_k01_idx, _elem); + +#define CALC_1(_k00_idx, _k01_idx, _c_idx) \ + _elem = vqtbl1q_s8_common(_tmp, _idx##_c_idx##0); \ + _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k00_idx, _elem); \ + _elem = vqtbl1q_s8_common(_tmp, _idx##_c_idx##1); \ + _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k01_idx, _elem); + +#define CALC_2(_k00_idx, _k01_idx, _k10_idx, _k11_idx, _c_idx) \ + _elem = vqtbl1q_s8_common(_tmp, _idx##_c_idx##0); \ + _sum0##_c_idx = vdotq_s32(_sum0##_c_idx, _k##_k00_idx, _elem); \ + _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k10_idx, _elem); \ + _elem = vqtbl1q_s8_common(_tmp, _idx##_c_idx##1); \ + _sum0##_c_idx = vdotq_s32(_sum0##_c_idx, _k##_k01_idx, _elem); \ + _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k11_idx, _elem); + +template +void deconv_direct_5x5(const int8_t* src, const int8_t* filter, int32_t* dst, + size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) { + MEGDNN_MARK_USED_VAR(IH); + const size_t tail_step = IW - OW / 2; + + const uint8x16_t _idx00 = {0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6}; + const uint8x16_t _idx01 = {4, 16, 16, 16, 5, 16, 16, 16, + 6, 16, 16, 16, 7, 16, 16, 16}; + const uint8x16_t _idx10 = {4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10}; + const uint8x16_t _idx11 = {8, 16, 16, 16, 9, 16, 16, 16, + 10, 16, 16, 16, 11, 16, 16, 16}; + const uint8x16_t _idx20 = {8, 9, 10, 11, 9, 10, 11, 12, + 10, 11, 12, 13, 11, 12, 13, 14}; + const uint8x16_t _idx21 = {12, 16, 16, 16, 13, 16, 16, 16, + 14, 16, 16, 16, 15, 16, 16, 16}; + + uint8x16_t _idx_r_0; + if (even) { + _idx_r_0 = {0, 16, 1, 16, 2, 16, 3, 16, 4, 16, 5, 16, 6, 16, 7, 16}; + } else { + _idx_r_0 = {16, 0, 16, 1, 16, 2, 16, 3, 16, 4, 16, 5, 16, 6, 16, 7}; + } + int8x16_t _tmp, _elem; + rep(ic, IC) { + const int8_t* src_ptr = src; + int32_t* dst_ptr = dst + OW * OH * ic; + int32_t* outptr = dst_ptr; + int32_t* outptr2 = outptr + OW; + + const int8_t* r0 = src_ptr; + const int8_t* r1 = src_ptr + IW; + const int8_t* r2 = src_ptr + IW * 2; + const int8_t* r3 = src_ptr + IW * 3; + const int8_t* r4 = src_ptr + IW * 4; + const int8_t* r5 = src_ptr + IW * 5; + + const int8_t* k0 = filter; + + int8x16_t _k = vld1q_s8(k0 + 9); + //! filter row 1 + uint8x16_t _idx = {15, 14, 13, 12, 15, 14, 13, 12, + 15, 14, 13, 12, 15, 14, 13, 12}; + int8x16_t _k123 = vqtbl1q_s8_common(_k, _idx); + _idx = {11, 16, 16, 16, 11, 16, 16, 16, 11, 16, 16, 16, 11, 16, 16, 16}; + int8x16_t _k4 = vqtbl1q_s8_common(_k, _idx); + //! filter row 2 + _idx = {10, 9, 8, 7, 10, 9, 8, 7, 10, 9, 8, 7, 10, 9, 8, 7}; + int8x16_t _k5678 = vqtbl1q_s8_common(_k, _idx); + _idx = {6, 16, 16, 16, 6, 16, 16, 16, 6, 16, 16, 16, 6, 16, 16, 16}; + int8x16_t _k9 = vqtbl1q_s8_common(_k, _idx); + //! filter row 3 + _idx = {5, 4, 3, 2, 5, 4, 3, 2, 5, 4, 3, 2, 5, 4, 3, 2}; + int8x16_t _k10111213 = vqtbl1q_s8_common(_k, _idx); + _idx = {1, 16, 16, 16, 1, 16, 16, 16, 1, 16, 16, 16, 1, 16, 16, 16}; + int8x16_t _k14 = vqtbl1q_s8_common(_k, _idx); + //! 9 10 11 12 -> 13 14 15 16 -> 17 18 19 20 -> 21 22 23 24 + _k = vld1q_s8(k0); + //! filter row 4 + _idx = {9, 8, 7, 6, 9, 8, 7, 6, 9, 8, 7, 6, 9, 8, 7, 6}; + int8x16_t _k15161718 = vqtbl1q_s8_common(_k, _idx); + _idx = {5, 16, 16, 16, 5, 16, 16, 16, 5, 16, 16, 16, 5, 16, 16, 16}; + int8x16_t _k19 = vqtbl1q_s8_common(_k, _idx); + //! 
filter row 5 + _idx = {4, 3, 2, 1, 4, 3, 2, 1, 4, 3, 2, 1, 4, 3, 2, 1}; + int8x16_t _k20212223 = vqtbl1q_s8_common(_k, _idx); + _idx = {0, 16, 16, 16, 0, 16, 16, 16, 0, 16, 16, 16, 0, 16, 16, 16}; + int8x16_t _k24 = vqtbl1q_s8_common(_k, _idx); + + const int width = OW >> 2; + size_t h = 0; + for (; h + 1 < OH; h += 2) { + int w = 0; + for (; w + 3 < width; w += 3) { + //! As the inner kernel read 16 elements, and IW is times of 16 + int32x4_t _sum00 = vld1q_s32(outptr); + int32x4_t _sum01 = vld1q_s32(outptr + 4); + int32x4_t _sum02 = vld1q_s32(outptr + 8); + int32x4_t _sum10 = vld1q_s32(outptr2); + int32x4_t _sum11 = vld1q_s32(outptr2 + 4); + int32x4_t _sum12 = vld1q_s32(outptr2 + 8); + + int8x16_t _r_ori = vld1q_s8(r0); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_0(123, 4, 0); + CALC_0(123, 4, 1); + CALC_0(123, 4, 2); + + _r_ori = vld1q_s8(r1); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_2(5678, 9, 123, 4, 0); + CALC_2(5678, 9, 123, 4, 1); + CALC_2(5678, 9, 123, 4, 2); + + _r_ori = vld1q_s8(r2); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_2(10111213, 14, 5678, 9, 0); + CALC_2(10111213, 14, 5678, 9, 1); + CALC_2(10111213, 14, 5678, 9, 2); + + _r_ori = vld1q_s8(r3); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_2(15161718, 19, 10111213, 14, 0); + CALC_2(15161718, 19, 10111213, 14, 1); + CALC_2(15161718, 19, 10111213, 14, 2); + + _r_ori = vld1q_s8(r4); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_2(20212223, 24, 15161718, 19, 0); + CALC_2(20212223, 24, 15161718, 19, 1); + CALC_2(20212223, 24, 15161718, 19, 2); + + _r_ori = vld1q_s8(r5); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_1(20212223, 24, 0); + CALC_1(20212223, 24, 1); + CALC_1(20212223, 24, 2); + + vst1q_s32(outptr, _sum00); + vst1q_s32(outptr + 4, _sum01); + vst1q_s32(outptr + 8, _sum02); + vst1q_s32(outptr2, _sum10); + vst1q_s32(outptr2 + 4, _sum11); + vst1q_s32(outptr2 + 8, _sum12); + + r0 += 6; + r1 += 6; + r2 += 6; + r3 += 6; + r4 += 6; + r5 += 6; + outptr += 12; + outptr2 += 12; + } + for (; w < width; w++) { + int32x4_t _sum00 = vld1q_s32(outptr); + int32x4_t _sum10 = vld1q_s32(outptr2); + + int8x16_t _r_ori = vtranslq_s8(vld1_s8(r0)); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_0(123, 4, 0); + + _r_ori = vtranslq_s8(vld1_s8(r1)); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_2(5678, 9, 123, 4, 0); + + _r_ori = vtranslq_s8(vld1_s8(r2)); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_2(10111213, 14, 5678, 9, 0); + + _r_ori = vtranslq_s8(vld1_s8(r3)); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_2(15161718, 19, 10111213, 14, 0); + + _r_ori = vtranslq_s8(vld1_s8(r4)); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_2(20212223, 24, 15161718, 19, 0); + + _r_ori = vtranslq_s8(vld1_s8(r5)); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_1(20212223, 24, 0); + + vst1q_s32(outptr, _sum00); + vst1q_s32(outptr2, _sum10); + + r0 += 2; + r1 += 2; + r2 += 2; + r3 += 2; + r4 += 2; + r5 += 2; + outptr += 4; + outptr2 += 4; + } + + r0 += tail_step + IW; + r1 += tail_step + IW; + r2 += tail_step + IW; + r3 += tail_step + IW; + r4 += tail_step + IW; + r5 += tail_step + IW; + + outptr += OW; + outptr2 += OW; + } + + for (; h < OH; h++) { + int w = 0; + for (; w + 3 < width; w += 3) { + int32x4_t _sum00 = vld1q_s32(outptr); + int32x4_t _sum01 = vld1q_s32(outptr + 4); + int32x4_t _sum02 = vld1q_s32(outptr + 8); + + int8x16_t _r_ori = vld1q_s8(r0); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_0(123, 4, 0); + CALC_0(123, 4, 1); + CALC_0(123, 4, 2); + + 
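+                //! the remaining input rows r1..r4 accumulate into the same three
+                //! output vectors with their own flipped filter rows (note the reversed
+                //! tbl indices used to build _k123.._k24); only _sum0x is touched here
+                //! because this tail loop produces a single output row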
_r_ori = vld1q_s8(r1); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_0(5678, 9, 0); + CALC_0(5678, 9, 1); + CALC_0(5678, 9, 2); + + _r_ori = vld1q_s8(r2); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_0(10111213, 14, 0); + CALC_0(10111213, 14, 1); + CALC_0(10111213, 14, 2); + + _r_ori = vld1q_s8(r3); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_0(15161718, 19, 0); + CALC_0(15161718, 19, 1); + CALC_0(15161718, 19, 2); + + _r_ori = vld1q_s8(r4); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_0(20212223, 24, 0); + CALC_0(20212223, 24, 1); + CALC_0(20212223, 24, 2); + + vst1q_s32(outptr, _sum00); + vst1q_s32(outptr + 4, _sum01); + vst1q_s32(outptr + 8, _sum02); + + r0 += 6; + r1 += 6; + r2 += 6; + r3 += 6; + r4 += 6; + outptr += 12; + } + for (; w < width; w++) { + int32x4_t _sum00 = vld1q_s32(outptr); + + int8x16_t _r_ori = vtranslq_s8(vld1_s8(r0)); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_0(123, 4, 0); + + _r_ori = vtranslq_s8(vld1_s8(r1)); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_0(5678, 9, 0); + + _r_ori = vtranslq_s8(vld1_s8(r2)); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_0(10111213, 14, 0); + + _r_ori = vtranslq_s8(vld1_s8(r3)); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_0(15161718, 19, 0); + + _r_ori = vtranslq_s8(vld1_s8(r4)); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_0(20212223, 24, 0); + + vst1q_s32(outptr, _sum00); + + r0 += 2; + r1 += 2; + r2 += 2; + r3 += 2; + r4 += 2; + outptr += 4; + } + r0 += tail_step; + r1 += tail_step; + r2 += tail_step; + r3 += tail_step; + r4 += tail_step; + } + + filter += 25; + } +} + +template +void deconv_direct_7x7(const int8_t* src, const int8_t* filter, int32_t* dst, + size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) { + MEGDNN_MARK_USED_VAR(IH); + const size_t tail_step = IW - OW / 2; + + const uint8x16_t _idx00 = {0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6}; + const uint8x16_t _idx01 = {4, 5, 6, 16, 5, 6, 7, 16, + 6, 7, 8, 16, 7, 8, 9, 16}; + const uint8x16_t _idx10 = {4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10}; + const uint8x16_t _idx11 = {8, 9, 10, 16, 9, 10, 11, 16, + 10, 11, 12, 16, 11, 12, 13, 16}; + + uint8x16_t _idx_r_0; + if (even) { + _idx_r_0 = {0, 16, 1, 16, 2, 16, 3, 16, 4, 16, 5, 16, 6, 16, 7, 16}; + } else { + _idx_r_0 = {16, 0, 16, 1, 16, 2, 16, 3, 16, 4, 16, 5, 16, 6, 16, 7}; + } + int8x16_t _tmp, _elem; + rep(ic, IC) { + const int8_t* src_ptr = src; + int32_t* dst_ptr = dst + OW * OH * ic; + int32_t* outptr = dst_ptr; + int32_t* outptr2 = outptr + OW; + + const int8_t* r0 = src_ptr; + const int8_t* r1 = src_ptr + IW; + const int8_t* r2 = src_ptr + IW * 2; + const int8_t* r3 = src_ptr + IW * 3; + const int8_t* r4 = src_ptr + IW * 4; + const int8_t* r5 = src_ptr + IW * 5; + const int8_t* r6 = src_ptr + IW * 6; + const int8_t* r7 = src_ptr + IW * 7; + + const int8_t* k0 = filter; + + int8x16_t _k = vld1q_s8(k0 + 33); + //! filter row 1 + uint8x16_t _idx = {15, 14, 13, 12, 15, 14, 13, 12, + 15, 14, 13, 12, 15, 14, 13, 12}; + int8x16_t _k123 = vqtbl1q_s8_common(_k, _idx); + _idx = {11, 10, 9, 16, 11, 10, 9, 16, 11, 10, 9, 16, 11, 10, 9, 16}; + int8x16_t _k456 = vqtbl1q_s8_common(_k, _idx); + //! filter row 2 + _idx = {8, 7, 6, 5, 8, 7, 6, 5, 8, 7, 6, 5, 8, 7, 6, 5}; + int8x16_t _k78910 = vqtbl1q_s8_common(_k, _idx); + _idx = {4, 3, 2, 16, 4, 3, 2, 16, 4, 3, 2, 16, 4, 3, 2, 16}; + int8x16_t _k111213 = vqtbl1q_s8_common(_k, _idx); + + //! 12 13 14 15 -> 16 17 18 19 -> 20 21 22 23 -> 24 25 26 27 + _k = vld1q_s8(k0 + 19); + //! 
filter row 3 + _idx = {15, 14, 13, 12, 15, 14, 13, 12, 15, 14, 13, 12, 15, 14, 13, 12}; + int8x16_t _k14151617 = vqtbl1q_s8_common(_k, _idx); + _idx = {11, 10, 9, 16, 11, 10, 9, 16, 11, 10, 9, 16, 11, 10, 9, 16}; + int8x16_t _k181920 = vqtbl1q_s8_common(_k, _idx); + //! filter row 4 + _idx = {8, 7, 6, 5, 8, 7, 6, 5, 8, 7, 6, 5, 8, 7, 6, 5}; + int8x16_t _k21222324 = vqtbl1q_s8_common(_k, _idx); + _idx = {4, 3, 2, 16, 4, 3, 2, 16, 4, 3, 2, 16, 4, 3, 2, 16}; + int8x16_t _k252627 = vqtbl1q_s8_common(_k, _idx); + + //! 24 25 26 27->28 29 30 31 -> 32 33 34 35 -> 36 37 38 39 + _k = vld1q_s8(k0 + 5); + //! filter row 5 + _idx = {15, 14, 13, 12, 15, 14, 13, 12, 15, 14, 13, 12, 15, 14, 13, 12}; + int8x16_t _k28293031 = vqtbl1q_s8_common(_k, _idx); + _idx = {11, 10, 9, 16, 11, 10, 9, 16, 11, 10, 9, 16, 11, 10, 9, 16}; + int8x16_t _k323334 = vqtbl1q_s8_common(_k, _idx); + + //! 33 34 35 36 -> 37 38 39 40 -> 41 42 43 44 -> 45 46 47 48 + _k = vld1q_s8(k0); + //! filter row 6 + _idx = {13, 12, 11, 10, 13, 12, 11, 10, 13, 12, 11, 10, 13, 12, 11, 10}; + int8x16_t _k35363738 = vqtbl1q_s8_common(_k, _idx); + _idx = {9, 8, 7, 16, 9, 8, 7, 16, 9, 8, 7, 16, 9, 8, 7, 16}; + int8x16_t _k394041 = vqtbl1q_s8_common(_k, _idx); + + //! filter row 7 + _idx = {6, 5, 4, 3, 6, 5, 4, 3, 6, 5, 4, 3, 6, 5, 4, 3}; + int8x16_t _k42434445 = vqtbl1q_s8_common(_k, _idx); + _idx = {2, 1, 0, 16, 2, 1, 0, 16, 2, 1, 0, 16, 2, 1, 0, 16}; + int8x16_t _k464748 = vqtbl1q_s8_common(_k, _idx); + + const int width = OW >> 2; + size_t h = 0; + for (; h + 1 < OH; h += 2) { + int w = 0; + for (; w + 2 < width; w += 2) { + //! As the inner kernel read 16 elements, and IW is times of 16 + int32x4_t _sum00 = vld1q_s32(outptr); + int32x4_t _sum01 = vld1q_s32(outptr + 4); + int32x4_t _sum10 = vld1q_s32(outptr2); + int32x4_t _sum11 = vld1q_s32(outptr2 + 4); + + int8x16_t _r_ori = vld1q_s8(r0); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_0(123, 456, 0); + CALC_0(123, 456, 1); + + _r_ori = vld1q_s8(r1); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_2(78910, 111213, 123, 456, 0); + CALC_2(78910, 111213, 123, 456, 1); + + _r_ori = vld1q_s8(r2); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_2(14151617, 181920, 78910, 111213, 0); + CALC_2(14151617, 181920, 78910, 111213, 1); + + _r_ori = vld1q_s8(r3); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_2(21222324, 252627, 14151617, 181920, 0); + CALC_2(21222324, 252627, 14151617, 181920, 1); + + _r_ori = vld1q_s8(r4); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_2(28293031, 323334, 21222324, 252627, 0); + CALC_2(28293031, 323334, 21222324, 252627, 1); + + _r_ori = vld1q_s8(r5); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_2(35363738, 394041, 28293031, 323334, 0); + CALC_2(35363738, 394041, 28293031, 323334, 1); + + _r_ori = vld1q_s8(r6); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_2(42434445, 464748, 35363738, 394041, 0); + CALC_2(42434445, 464748, 35363738, 394041, 1); + + _r_ori = vld1q_s8(r7); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_1(42434445, 464748, 0); + CALC_1(42434445, 464748, 1); + + vst1q_s32(outptr, _sum00); + vst1q_s32(outptr + 4, _sum01); + vst1q_s32(outptr2, _sum10); + vst1q_s32(outptr2 + 4, _sum11); + + r0 += 4; + r1 += 4; + r2 += 4; + r3 += 4; + r4 += 4; + r5 += 4; + r6 += 4; + r7 += 4; + outptr += 8; + outptr2 += 8; + } + for (; w < width; w++) { + int32x4_t _sum00 = vld1q_s32(outptr); + int32x4_t _sum10 = vld1q_s32(outptr2); + + int8x16_t _r_ori = vld1q_s8(r0); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_0(123, 456, 
0); + + _r_ori = vld1q_s8(r1); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_2(78910, 111213, 123, 456, 0); + + _r_ori = vld1q_s8(r2); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_2(14151617, 181920, 78910, 111213, 0); + + _r_ori = vld1q_s8(r3); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_2(21222324, 252627, 14151617, 181920, 0); + + _r_ori = vld1q_s8(r4); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_2(28293031, 323334, 21222324, 252627, 0); + + _r_ori = vld1q_s8(r5); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_2(35363738, 394041, 28293031, 323334, 0); + + _r_ori = vld1q_s8(r6); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_2(42434445, 464748, 35363738, 394041, 0); + + _r_ori = vld1q_s8(r7); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_1(42434445, 464748, 0); + + vst1q_s32(outptr, _sum00); + vst1q_s32(outptr2, _sum10); + + r0 += 2; + r1 += 2; + r2 += 2; + r3 += 2; + r4 += 2; + r5 += 2; + r6 += 2; + r7 += 2; + outptr += 4; + outptr2 += 4; + } + + r0 += tail_step + IW; + r1 += tail_step + IW; + r2 += tail_step + IW; + r3 += tail_step + IW; + r4 += tail_step + IW; + r5 += tail_step + IW; + r6 += tail_step + IW; + r7 += tail_step + IW; + + outptr += OW; + outptr2 += OW; + } + + for (; h < OH; h++) { + int w = 0; + for (; w + 2 < width; w += 2) { + int32x4_t _sum00 = vld1q_s32(outptr); + int32x4_t _sum01 = vld1q_s32(outptr + 4); + + int8x16_t _r_ori = vld1q_s8(r0); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_0(123, 456, 0); + CALC_0(123, 456, 1); + + _r_ori = vld1q_s8(r1); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_0(78910, 111213, 0); + CALC_0(78910, 111213, 1); + + _r_ori = vld1q_s8(r2); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_0(14151617, 181920, 0); + CALC_0(14151617, 181920, 1); + + _r_ori = vld1q_s8(r3); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_0(21222324, 252627, 0); + CALC_0(21222324, 252627, 1); + + _r_ori = vld1q_s8(r4); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_0(28293031, 323334, 0); + CALC_0(28293031, 323334, 1); + + _r_ori = vld1q_s8(r5); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_0(35363738, 394041, 0); + CALC_0(35363738, 394041, 1); + + _r_ori = vld1q_s8(r6); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_0(42434445, 464748, 0); + CALC_0(42434445, 464748, 1); + + vst1q_s32(outptr, _sum00); + vst1q_s32(outptr + 4, _sum01); + + r0 += 4; + r1 += 4; + r2 += 4; + r3 += 4; + r4 += 4; + r5 += 4; + r6 += 4; + outptr += 8; + } + for (; w < width; w++) { + int32x4_t _sum00 = vld1q_s32(outptr); + + int8x16_t _r_ori = vld1q_s8(r0); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_0(123, 456, 0); + + _r_ori = vld1q_s8(r1); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_0(78910, 111213, 0); + + _r_ori = vld1q_s8(r2); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_0(14151617, 181920, 0); + + _r_ori = vld1q_s8(r3); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_0(21222324, 252627, 0); + + _r_ori = vld1q_s8(r4); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_0(28293031, 323334, 0); + + _r_ori = vld1q_s8(r5); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_0(35363738, 394041, 0); + + _r_ori = vld1q_s8(r6); + _tmp = vqtbl1q_s8_common(_r_ori, _idx_r_0); + CALC_0(42434445, 464748, 0); + + vst1q_s32(outptr, _sum00); + + r0 += 2; + r1 += 2; + r2 += 2; + r3 += 2; + r4 += 2; + r5 += 2; + r6 += 2; + outptr += 4; + } + r0 += tail_step; + r1 += tail_step; + r2 += tail_step; + r3 += tail_step; + r4 += tail_step; + r5 += tail_step; + r6 += tail_step; + 
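+            //! the row pointers have covered OW / 2 input columns in this row;
+            //! tail_step (= IW - OW / 2) skips the remainder so r0..r6 line up
+            //! with the start of the next input row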
} + + filter += 49; + } +} + +#undef CALC_0 +#undef CALC_1 +#undef CALC_2 + +} // anonymous namespace + +size_t deconv::get_workspace_in_bytes_stride2_int8x8x32_dot( + const NCBKernSizeParam& param) { + return get_bundle(param).total_size_in_bytes(); +} + +bool deconv::can_stride2_int8x8x32_dot(const NCBKernSizeParam& param) { + auto&& fm = param.filter_meta; + auto FH = fm.spatial[0], FW = fm.spatial[1], OC = fm.ocpg, + PH = fm.padding[0], PW = fm.padding[1]; + bool avaiable = fm.format == param::Convolution::Format::NCHW && + !fm.should_flip && fm.spatial_ndim == 2 && + fm.dilation[0] == 1 && fm.dilation[1] == 1 && + fm.stride[0] == 2 && fm.stride[1] == 2 && FH == FW && + (FH == 2 || FH == 3 || FH == 5 || FH == 7) && + FH >= PH + 1 && FW >= PW + 1; + + return avaiable && ((FH == 2 && OC <= 4) || (FH == 3 && OC <= 8) || + (FH == 5 && OC <= 16) || (FH == 7 && OC < 32)); +} + +void deconv::stride2_int8x8x32_dot(const NCBKernParam& param) { + auto bundle = get_bundle(param); + bundle.set(param.workspace_ptr); + UNPACK_CONV_F32_NCB_KERN_SIZES(param); + MEGDNN_MARK_USED_VAR(SH); + MEGDNN_MARK_USED_VAR(SW); + size_t IH2, IW2, OW2; + int padding_h = FH - PH - 1, padding_w = FW - PW - 1; + + get_rectified_size(IH, IW, OH, OW, FH, FW, PH, PW, IH2, IW2, OW2); + + using Func = std::function; + Func conv = nullptr; + if (FH == 2) { + if ((padding_w & 1) == 0) + conv = deconv_direct_2x2; + else + conv = deconv_direct_2x2; + } else if (FH == 3) { + if ((padding_w & 1) == 0) + conv = deconv_direct_3x3; + else + conv = deconv_direct_3x3; + } else if (FH == 5) { + if ((padding_w & 1) == 0) + conv = deconv_direct_5x5; + else + conv = deconv_direct_5x5; + } else if (FH == 7) { + if ((padding_w & 1) == 0) + conv = deconv_direct_7x7; + else + conv = deconv_direct_7x7; + } else { + megdnn_assert(0); + } + + bool need_dst_copy_var = need_dst_copy(param); + int8_t* base_src_ptr = const_cast(param.diff()); + int32_t* base_dst_ptr = param.grad(); + const int8_t* fptr = param.filter(); + + for (size_t n = 0; n < N; ++n) { + int32_t* dptr_copied = static_cast(bundle.get(1)); + int32_t* dptr_ori = base_dst_ptr + n * param.out_bs; + int32_t* dptr = nullptr; + size_t OW_real = OW; + if (need_dst_copy_var) { + dptr = dptr_copied; + OW_real = OW2; + } else { + dptr = dptr_ori; + } + std::memset(dptr, 0, sizeof(int32_t) * IC * OH * OW_real); + + int8_t* sptr_ori = base_src_ptr + n * param.inp_bs; + int8_t* sptr_copied = static_cast(bundle.get(0)); + int8_t* sptr = nullptr; + rep(oc, OC) { + std::memset(sptr_copied, 0, sizeof(int8_t) * IH2 * IW2); + copy_plane_in_bytes(sptr_copied + padding_h * IW2 + padding_w / 2, + sptr_ori + oc * IH * IW, IH, + IW * sizeof(int8_t), 2 * IW2 * sizeof(int8_t), + IW * sizeof(int8_t)); + sptr = sptr_copied; + + conv(sptr, fptr + oc * IC * FH * FW, dptr, IH2, IW2, OH, OW_real, + IC); + } + if (need_dst_copy_var) { + for (size_t ic = 0; ic < IC; ++ic) { + copy_plane_in_bytes(dptr_ori + ic * OH * OW, + dptr + ic * OH * OW2, OH, + OW * sizeof(int32_t), OW * sizeof(int32_t), + OW2 * sizeof(int32_t)); + } + } + } +} + +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride2.h b/dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride2.h new file mode 100644 index 00000000..d6c1b6de --- /dev/null +++ b/dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride2.h @@ -0,0 +1,37 @@ +/** + * \file dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride2.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") 
+ * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once + +#if __ARM_FEATURE_DOTPROD +#include "src/arm_common/convolution/opr_impl.h" + +#include +#include + +namespace megdnn { +namespace arm_common { +namespace deconv { + +using NCBKernSizeParam = ConvolutionBackwardDataImpl::NCBKernSizeParam; +using NCBKernParam = ConvolutionBackwardDataImpl::NCBKernParam; + +bool can_stride2_int8x8x32_dot(const NCBKernSizeParam& param); + +void stride2_int8x8x32_dot(const NCBKernParam& param); + +size_t get_workspace_in_bytes_stride2_int8x8x32_dot(const NCBKernSizeParam& param); + +} // namespace convolution +} // namespace arm_common +} // namespace megdnn +#endif + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/convolution/opr_impl.cpp b/dnn/src/arm_common/convolution/opr_impl.cpp new file mode 100644 index 00000000..d9a36e94 --- /dev/null +++ b/dnn/src/arm_common/convolution/opr_impl.cpp @@ -0,0 +1,94 @@ +/** + * \file dnn/src/arm_common/convolution/opr_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "./opr_impl.h" +#include "./int8x8x32/algos.h" +#include "./quint8/algos.h" + +#include "src/common/metahelper.h" +#include "src/common/utils.h" +#include "src/naive/handle.h" +#include "src/common/opr_delegate.h" + +using namespace megdnn; +using namespace arm_common; + +namespace { +uint8_t arm_common_algo_type_storage; +} // anonymous namespace + +/* ===================== ConvolutionBackwardData ===================== */ +struct ConvolutionBackwardDataImpl::AlgoPack { +#if __ARM_FEATURE_DOTPROD + AlgoSdot8DirectStride1 i8x8x32_direct_stride1_sdot; + AlgoSdot8DirectStride2 i8x8x32_direct_stride2_sdot; + AlgoUdot8DirectStride1 quint8_direct_stride1_udot; + AlgoUdot8DirectStride2 quint8_direct_stride2_udot; +#endif +}; +ConvolutionBackwardDataImpl::AlgoPack ConvolutionBackwardDataImpl::sm_algo_pack; + +void* const ConvolutionBackwardDataImpl::sm_arm_common_algo_type = + &arm_common_algo_type_storage; + +ConvolutionBackwardDataImpl::ncb_kern_t ConvolutionBackwardDataImpl::ncb_1g_dispatch_kern( + Algorithm* algo, const NCBKernSizeParam& param) { + if (algo->type() == sm_arm_common_algo_type) { + return static_cast(algo)->dispatch_kern(this, param); + } + return fallback::ConvolutionBackwardDataImpl::ncb_1g_dispatch_kern(algo, param); +} + +size_t ConvolutionBackwardDataImpl::ncb_1g_get_workspace(Algorithm* algo, + const NCBKernSizeParam& param) { + if (algo->type() == sm_arm_common_algo_type) { + return static_cast(algo)->get_workspace(this, param); + } + return fallback::ConvolutionBackwardDataImpl::ncb_1g_get_workspace(algo, param); +} + +std::vector +ConvolutionBackwardDataImpl::ncb_1g_get_all_algorithms(const NCBKernSizeParam& param) { + + auto ret = fallback::ConvolutionBackwardDataImpl::ncb_1g_get_all_algorithms(param); + +#if __ARM_FEATURE_DOTPROD + if((param.filter_type.enumv() == DTypeEnum::QuantizedS8 || + param.filter_type.enumv() == DTypeEnum::Int8) && + (param.grad_type.enumv() == 
DTypeEnum::QuantizedS32 || + param.grad_type.enumv() == DTypeEnum::Int32)) { + + if (sm_algo_pack.i8x8x32_direct_stride1_sdot.usable(this, param)) { + ret.insert(ret.begin(), &sm_algo_pack.i8x8x32_direct_stride1_sdot); + } + if (sm_algo_pack.i8x8x32_direct_stride2_sdot.usable(this, param)) { + ret.insert(ret.begin(), &sm_algo_pack.i8x8x32_direct_stride2_sdot); + } + } + else if(param.filter_type.enumv() == DTypeEnum::Quantized8Asymm && + param.grad_type.enumv() == DTypeEnum::QuantizedS32) { + + if (sm_algo_pack.quint8_direct_stride1_udot.usable(this, param)) { + ret.insert(ret.begin(), &sm_algo_pack.quint8_direct_stride1_udot); + } + if (sm_algo_pack.quint8_direct_stride2_udot.usable(this, param)) { + ret.insert(ret.begin(), &sm_algo_pack.quint8_direct_stride2_udot); + } + } +#endif + return ret; +} +const char* ConvolutionBackwardDataImpl::get_algorithm_set_name() const { + // arm common version 0 + return "DeconvAC0"; +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/convolution/opr_impl.h b/dnn/src/arm_common/convolution/opr_impl.h new file mode 100644 index 00000000..17e8ec0b --- /dev/null +++ b/dnn/src/arm_common/convolution/opr_impl.h @@ -0,0 +1,65 @@ +/** + * \file dnn/src/arm_common/convolution/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once +#include "src/common/utils.h" +#include "src/fallback/convolution/opr_impl.h" +#include "src/arm_common/conv_bias/opr_impl.h" + +namespace megdnn { +namespace arm_common { + +class ConvBiasImpl; + +class ConvolutionBackwardDataImpl : public fallback::ConvolutionBackwardDataImpl { +public: + using fallback::ConvolutionBackwardDataImpl::ConvolutionBackwardDataImpl; + +protected: + static void* const sm_arm_common_algo_type; + + class AlgoBase : public Algorithm { + protected: + ~AlgoBase() = default; + + public: + virtual bool usable(ConvolutionBackwardDataImpl* opr, + const NCBKernSizeParam& param) const = 0; + virtual size_t get_workspace(ConvolutionBackwardDataImpl* opr, + const NCBKernSizeParam& param) const = 0; + virtual ncb_kern_t dispatch_kern( + ConvolutionBackwardDataImpl* opr, const NCBKernSizeParam& param) const = 0; + }; + + ncb_kern_t ncb_1g_dispatch_kern(Algorithm* algo, + const NCBKernSizeParam& param) override; + + size_t ncb_1g_get_workspace(Algorithm* algo, + const NCBKernSizeParam& param) override; + + std::vector ncb_1g_get_all_algorithms( + const NCBKernSizeParam& param) override; + + const char* get_algorithm_set_name() const override; + + private: +#if __ARM_FEATURE_DOTPROD + class AlgoSdot8DirectStride1; + class AlgoSdot8DirectStride2; + class AlgoUdot8DirectStride1; + class AlgoUdot8DirectStride2; +#endif + struct AlgoPack; + static AlgoPack sm_algo_pack; +}; + +} // namespace arm_common +} // namespace megdnn +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/convolution/quint8/algos.cpp b/dnn/src/arm_common/convolution/quint8/algos.cpp new file mode 100644 index 00000000..951af79f --- /dev/null +++ b/dnn/src/arm_common/convolution/quint8/algos.cpp @@ -0,0 +1,62 @@ +/** + * \file dnn/src/arm_common/convolution/quint8/algos.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. 
All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "src/arm_common/convolution/img2col_helper.h" +#include "src/arm_common/convolution/quint8/algos.h" +#include "src/arm_common/convolution/quint8/conv_backdata_stride1.h" +#include "src/arm_common/convolution/quint8/conv_backdata_stride2.h" +#include "src/common/opr_delegate.h" +#include "midout.h" + +MIDOUT_DECL(megdnn_arm_conv_quint8_kimpl) + +using namespace megdnn; +using namespace arm_common; + +#if __ARM_FEATURE_DOTPROD + +/* ===================== ConvolutionBackwardData ===================== */ + +/* ===================== direct stride 1 algo ===================== */ +bool ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1::usable( + ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const { + return deconv::can_stride1_quint8_dot(param); +} + +size_t ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1::get_workspace( + ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const { + return deconv::get_workspace_in_bytes_stride1_quint8_dot(param); +} + +ConvolutionBackwardDataImpl::ncb_kern_t +ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1::dispatch_kern( + ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const { + return deconv::stride1_quint8_dot; +} + +/* ===================== direct stride 2 algo ===================== */ +bool ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2::usable( + ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const { + return deconv::can_stride2_quint8_dot(param); +} + +size_t ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2::get_workspace( + ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const { + return deconv::get_workspace_in_bytes_stride2_quint8_dot(param); +} + +ConvolutionBackwardDataImpl::ncb_kern_t +ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2::dispatch_kern( + ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const { + return deconv::stride2_quint8_dot; +} +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/convolution/quint8/algos.h b/dnn/src/arm_common/convolution/quint8/algos.h new file mode 100644 index 00000000..5cba3485 --- /dev/null +++ b/dnn/src/arm_common/convolution/quint8/algos.h @@ -0,0 +1,58 @@ +/** + * \file dnn/src/arm_common/convolution/quint8/algos.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#pragma once + +#include "src/arm_common/convolution/opr_impl.h" + +namespace megdnn { +namespace arm_common { + +#if __ARM_FEATURE_DOTPROD +/* ===================== ConvolutionBackwardData ===================== */ +class ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1 final : public AlgoBase { +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "ARM_COMMON_QUINT8_DIRECT_DECONV_STRIDE1"; } + + bool usable(ConvolutionBackwardDataImpl*, + const NCBKernSizeParam& param) const override; + + size_t get_workspace(ConvolutionBackwardDataImpl*, + const NCBKernSizeParam& param) const override; + + ncb_kern_t dispatch_kern(ConvolutionBackwardDataImpl*, + const NCBKernSizeParam&) const override; + + void* type() const override { return sm_arm_common_algo_type; } +}; + +class ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2 final : public AlgoBase { +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "ARM_COMMON_QUINT8_DIRECT_DECONV_STRIDE2"; } + + bool usable(ConvolutionBackwardDataImpl*, + const NCBKernSizeParam& param) const override; + + size_t get_workspace(ConvolutionBackwardDataImpl*, + const NCBKernSizeParam& param) const override; + + ncb_kern_t dispatch_kern(ConvolutionBackwardDataImpl*, + const NCBKernSizeParam&) const override; + + void* type() const override { return sm_arm_common_algo_type; } +}; +#endif +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/convolution/quint8/conv_backdata_stride1.cpp b/dnn/src/arm_common/convolution/quint8/conv_backdata_stride1.cpp new file mode 100644 index 00000000..bd2f2176 --- /dev/null +++ b/dnn/src/arm_common/convolution/quint8/conv_backdata_stride1.cpp @@ -0,0 +1,1352 @@ +/** + * \file dnn/src/arm_common/convolution/quint8/conv_backdata_stride1.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#if __ARM_FEATURE_DOTPROD +#include "src/arm_common/convolution/quint8/conv_backdata_stride1.h" +#include "src/common/utils.h" + +#include +#include "src/arm_common/simd_macro/marm_neon.h" + +using namespace megdnn; +using namespace arm_common; +using namespace deconv; + +namespace { + +#define SHIFT_BITS 30 +#define SHIFT (1 << SHIFT_BITS) + +bool need_dst_copy(const NCBKernSizeParam& param) { + if (param.osz[1] % 4 != 0) { + // If the size of output is not multiples of 4, we need to copy it. + return true; + } + return false; +} + +bool need_src_copy(const NCBKernSizeParam& param) { + auto FH = param.filter_meta.spatial[0], FW = param.filter_meta.spatial[1], + PH = param.filter_meta.padding[0], PW = param.filter_meta.padding[1]; + return FH > PH + 1 || FW > PW + 1 || need_dst_copy(param); +} + +void get_rectified_size(size_t IH, size_t IW, size_t OH, size_t OW, size_t FH, + size_t FW, size_t PH, size_t PW, size_t& IH2, + size_t& IW2, size_t& OW2) { + MEGDNN_MARK_USED_VAR(OH); + MEGDNN_MARK_USED_VAR(IW); + MEGDNN_MARK_USED_VAR(PW); + //! OW should be a multiple of 4 + OW2 = (OW + 3) & ~3; + IH2 = IH + 2 * (FH - PH - 1); + //! 
OW2 - FW + 1 + 2 * PW + 2 * (FW - PW - 1); + IW2 = OW2 + FW - 1; +} + +WorkspaceBundle get_bundle(const NCBKernSizeParam& param) { + UNPACK_CONV_F32_NCB_KERN_SIZES(param); + MEGDNN_MARK_USED_VAR(N); + MEGDNN_MARK_USED_VAR(OC); + MEGDNN_MARK_USED_VAR(SH); + MEGDNN_MARK_USED_VAR(SW); + size_t src_size = 0, dst_size = 0; + size_t IH2, IW2, OW2; + get_rectified_size(IH, IW, OH, OW, FH, FW, PH, PW, IH2, IW2, OW2); + if (need_src_copy(param)) { + //! We only need one src plane + src_size = sizeof(uint8_t) * IH2 * IW2; + } + if (need_dst_copy(param)) { + dst_size = sizeof(int32_t) * IC * OH * OW2; + } + return WorkspaceBundle(nullptr, {src_size, dst_size}); +} + +inline uint8x16_t vqtbl1q_u8_common(uint8x16_t a, uint8x16_t index) { + uint8x8x2_t src; + src.val[0] = vget_low_u8(a); + src.val[1] = vget_high_u8(a); + uint8x8_t r00 = vtbl2_u8(src, vget_low_u8(index)); + uint8x8_t r01 = vtbl2_u8(src, vget_high_u8(index)); + uint8x16_t r = vcombine_u8(r00, r01); + return r; +} + +#define CALC_DST(_sum) \ + _sum = vreinterpretq_u32_s32( \ + vaddq_s32(vreinterpretq_s32_u32(_sum), _shift_zp)); + +#define CALC_0(_k_idx, _c_idx) \ + _elem = vqtbl1q_u8_common(_tmp, _idx##_c_idx); \ + _sum0##_c_idx = vdotq_u32(_sum0##_c_idx, _k##_k_idx, _elem); \ + _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, vdotq2_u32(_src_zp, _k##_k_idx)); \ + _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, vdotq2_u32(_filter_zp, _elem)); + +#define CALC_1(_k_idx, _c_idx) \ + _elem = vqtbl1q_u8_common(_tmp, _idx##_c_idx); \ + _sum1##_c_idx = vdotq_u32(_sum1##_c_idx, _k##_k_idx, _elem); \ + _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_src_zp, _k##_k_idx)); \ + _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_filter_zp, _elem)); + +#define CALC_2(_k1_idx, _k2_idx, _c_idx) \ + _elem = vqtbl1q_u8_common(_tmp, _idx##_c_idx); \ + _sum0##_c_idx = vdotq_u32(_sum0##_c_idx, _k##_k1_idx, _elem); \ + _sum0##_c_idx = \ + vsubq_u32(_sum0##_c_idx, vdotq2_u32(_src_zp, _k##_k1_idx)); \ + _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, vdotq2_u32(_filter_zp, _elem)); \ + _sum1##_c_idx = vdotq_u32(_sum1##_c_idx, _k##_k2_idx, _elem); \ + _sum1##_c_idx = \ + vsubq_u32(_sum1##_c_idx, vdotq2_u32(_src_zp, _k##_k2_idx)); \ + _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_filter_zp, _elem)); + +template +void deconv_direct_2x2(const uint8_t* src, const uint8_t* filter, int32_t* dst, + size_t IH, size_t IW, size_t OH, size_t OW, size_t IC, + uint8_t src_zp, uint8_t filter_zp, + int32_t src_filter_zp) { + MEGDNN_MARK_USED_VAR(IH); + MEGDNN_MARK_USED_VAR(IC); + const size_t tail_step = IW - OW; + + uint8x16_t _src_zp = vdupq_n_u8(src_zp); + uint8x16_t _filter_zp = vdupq_n_u8(filter_zp); + int32x4_t _shift_zp = vdupq_n_s32(src_filter_zp - SHIFT); + + const uint8x16_t _idx0 = {0, 1, 16, 16, 1, 2, 16, 16, + 2, 3, 16, 16, 3, 4, 16, 16}; + const uint8x16_t _idx1 = {4, 5, 16, 16, 5, 6, 16, 16, + 6, 7, 16, 16, 7, 8, 16, 16}; + const uint8_t* src_ptr = src; + //! 
here we use uint32_t for calc + uint32_t* outptr = reinterpret_cast(dst); + uint32_t* outptr2 = outptr + OW; + + const uint8_t* r0 = src_ptr; + const uint8_t* r1 = src_ptr + IW; + const uint8_t* r2 = src_ptr + 2 * IW; + + const uint8_t* k0 = filter; + + uint8x16_t _k0 = vreinterpretq_u8_u32( + vdupq_n_u32(*reinterpret_cast(k0))); + uint8x16_t _idx_k = {3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0}; + uint8x16_t _k = vqtbl1q_u8_common(_k0, _idx_k); + uint8x16_t _idx = {0, 1, 16, 16, 0, 1, 16, 16, 0, 1, 16, 16, 0, 1, 16, 16}; + uint8x16_t _k1 = vqtbl1q_u8_common(_k, _idx); + _idx = {2, 3, 16, 16, 2, 3, 16, 16, 2, 3, 16, 16, 2, 3, 16, 16}; + uint8x16_t _k23 = vqtbl1q_u8_common(_k, _idx); + +#define SUB_ZP(_sum, _r) \ + _sum = vdotq_u32(_sum, _k, _r); \ + _sum = vsubq_u32(_sum, vdotq2_u32(_src_zp, _k)); \ + _sum = vsubq_u32(_sum, vdotq2_u32(_filter_zp, _r)); + + uint8x16_t _tmp, _elem; + const int width = OW >> 2; + size_t h = 0; + for (; h + 1 < OH; h += 2) { + int w = 0; + for (; w + 4 < width; w += 4) { + uint32x4x2_t _sum00, _sum01, _sum10, _sum11; + _sum00 = vld2q_u32(outptr); + _sum01 = vld2q_u32(outptr + 8); + _sum10 = vld2q_u32(outptr2); + _sum11 = vld2q_u32(outptr2 + 8); + + uint8x16_t _r00 = vld1q_u8(r0); + //! here will not not read out of bound + uint8x16_t _r01_ = vdupq_n_u8(r0[16]); + uint8x16_t _r10 = vld1q_u8(r1); + uint8x16_t _r11_ = vdupq_n_u8(r1[16]); + uint8x16_t _r20 = vld1q_u8(r2); + uint8x16_t _r21_ = vdupq_n_u8(r2[16]); + uint8x16_t _r01 = vextq_u8(_r00, _r01_, 1); + uint8x16_t _r11 = vextq_u8(_r10, _r11_, 1); + uint8x16_t _r21 = vextq_u8(_r20, _r21_, 1); + + int16x8x2_t r_0 = vzipq_s16(vreinterpretq_s16_u8(_r00), + vreinterpretq_s16_u8(_r10)); + uint8x16_t _r0 = vreinterpretq_u8_s8(r_0.val[0]); + uint8x16_t _r2 = vreinterpretq_u8_s8(r_0.val[1]); + + int16x8x2_t r_1 = vzipq_s16(vreinterpretq_s16_u8(_r01), + vreinterpretq_s16_u8(_r11)); + uint8x16_t _r1 = vreinterpretq_u8_s8(r_1.val[0]); + uint8x16_t _r3 = vreinterpretq_u8_s8(r_1.val[1]); + + SUB_ZP(_sum00.val[0], _r0); + SUB_ZP(_sum00.val[1], _r1); + SUB_ZP(_sum01.val[0], _r2); + SUB_ZP(_sum01.val[1], _r3); + + r_0 = vzipq_s16(vreinterpretq_s16_u8(_r10), + vreinterpretq_s16_u8(_r20)); + _r0 = vreinterpretq_u8_s8(r_0.val[0]); + _r2 = vreinterpretq_u8_s8(r_0.val[1]); + + r_1 = vzipq_s16(vreinterpretq_s16_u8(_r11), + vreinterpretq_s16_u8(_r21)); + _r1 = vreinterpretq_u8_s8(r_1.val[0]); + _r3 = vreinterpretq_u8_s8(r_1.val[1]); + + SUB_ZP(_sum10.val[0], _r0); + SUB_ZP(_sum10.val[1], _r1); + SUB_ZP(_sum11.val[0], _r2); + SUB_ZP(_sum11.val[1], _r3); + + if (last_oc) { + CALC_DST(_sum00.val[0]); + CALC_DST(_sum00.val[1]); + CALC_DST(_sum01.val[0]); + CALC_DST(_sum01.val[1]); + CALC_DST(_sum10.val[0]); + CALC_DST(_sum10.val[1]); + CALC_DST(_sum11.val[0]); + CALC_DST(_sum11.val[1]); + } + vst2q_u32(outptr, _sum00); + vst2q_u32(outptr + 8, _sum01); + vst2q_u32(outptr2, _sum10); + vst2q_u32(outptr2 + 8, _sum11); + + r0 += 16; + r1 += 16; + r2 += 16; + outptr += 16; + outptr2 += 16; + } + for (; w + 2 < width; w += 2) { + uint32x4_t _sum00 = vld1q_u32(outptr); + uint32x4_t _sum01 = vld1q_u32(outptr + 4); + uint32x4_t _sum10 = vld1q_u32(outptr2); + uint32x4_t _sum11 = vld1q_u32(outptr2 + 4); + + _tmp = vld1q_u8(r0); + CALC_0(1, 0); + CALC_0(1, 1); + + _tmp = vld1q_u8(r1); + CALC_2(23, 1, 0); + CALC_2(23, 1, 1); + + _tmp = vld1q_u8(r2); + CALC_1(23, 0); + CALC_1(23, 1); + + if (last_oc) { + CALC_DST(_sum00); + CALC_DST(_sum01); + CALC_DST(_sum10); + CALC_DST(_sum11); + } + vst1q_u32(outptr, _sum00); + vst1q_u32(outptr + 4, _sum01); 
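+            //! _sum10/_sum11 are the accumulators of the second output row (outptr2)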
+ vst1q_u32(outptr2, _sum10); + vst1q_u32(outptr2 + 4, _sum11); + + r0 += 8; + r1 += 8; + r2 += 8; + outptr += 8; + outptr2 += 8; + } + + for (; w < width; w++) { + uint32x4_t _sum00 = vld1q_u32(outptr); + uint32x4_t _sum10 = vld1q_u32(outptr2); + + _tmp = vtranslq_u8(vld1_u8(r0)); + CALC_0(1, 0); + + _tmp = vtranslq_u8(vld1_u8(r1)); + CALC_2(23, 1, 0); + + _tmp = vtranslq_u8(vld1_u8(r2)); + CALC_1(23, 0); + + if (last_oc) { + CALC_DST(_sum00); + CALC_DST(_sum10); + } + vst1q_u32(outptr, _sum00); + vst1q_u32(outptr2, _sum10); + + r0 += 4; + r1 += 4; + r2 += 4; + outptr += 4; + outptr2 += 4; + } + r0 += tail_step + IW; + r1 += tail_step + IW; + r2 += tail_step + IW; + outptr += OW; + outptr2 += OW; + } + + for (; h < OH; h++) { + int w = 0; + for (; w + 4 < width; w += 4) { + uint32x4x2_t _sum0, _sum1; + _sum0 = vld2q_u32(outptr); + _sum1 = vld2q_u32(outptr + 8); + + uint8x16_t _r00 = vld1q_u8(r0); + //! here will not not read out of bound + uint8x16_t _r01_ = vdupq_n_u8(r0[16]); + uint8x16_t _r10 = vld1q_u8(r1); + uint8x16_t _r11_ = vdupq_n_u8(r1[16]); + uint8x16_t _r01 = vextq_u8(_r00, _r01_, 1); + uint8x16_t _r11 = vextq_u8(_r10, _r11_, 1); + + int16x8x2_t r_0 = vzipq_s16(vreinterpretq_s16_u8(_r00), + vreinterpretq_s16_u8(_r10)); + uint8x16_t _r0 = vreinterpretq_u8_s8(r_0.val[0]); + uint8x16_t _r2 = vreinterpretq_u8_s8(r_0.val[1]); + + int16x8x2_t r_1 = vzipq_s16(vreinterpretq_s16_u8(_r01), + vreinterpretq_s16_u8(_r11)); + uint8x16_t _r1 = vreinterpretq_u8_s8(r_1.val[0]); + uint8x16_t _r3 = vreinterpretq_u8_s8(r_1.val[1]); + + SUB_ZP(_sum0.val[0], _r0); + SUB_ZP(_sum0.val[1], _r1); + SUB_ZP(_sum1.val[0], _r2); + SUB_ZP(_sum1.val[1], _r3); + + if (last_oc) { + CALC_DST(_sum0.val[0]); + CALC_DST(_sum0.val[1]); + CALC_DST(_sum1.val[0]); + CALC_DST(_sum1.val[1]); + } + vst2q_u32(outptr, _sum0); + vst2q_u32(outptr + 8, _sum1); + + r0 += 16; + r1 += 16; + outptr += 16; + } + for (; w + 2 < width; w += 2) { + uint32x4_t _sum00 = vld1q_u32(outptr); + uint32x4_t _sum01 = vld1q_u32(outptr + 4); + + _tmp = vld1q_u8(r0); + CALC_0(1, 0); + CALC_0(1, 1); + + _tmp = vld1q_u8(r1); + CALC_0(23, 0); + CALC_0(23, 1); + + if (last_oc) { + CALC_DST(_sum00); + CALC_DST(_sum01); + } + vst1q_u32(outptr, _sum00); + vst1q_u32(outptr + 4, _sum01); + + r0 += 8; + r1 += 8; + outptr += 8; + } + + for (; w < width; w++) { + uint32x4_t _sum00 = vld1q_u32(outptr); + + _tmp = vtranslq_u8(vld1_u8(r0)); + CALC_0(1, 0); + + _tmp = vtranslq_u8(vld1_u8(r1)); + CALC_0(23, 0); + + if (last_oc) { + CALC_DST(_sum00); + } + vst1q_u32(outptr, _sum00); + + r0 += 4; + r1 += 4; + outptr += 4; + } + r0 += tail_step; + r1 += tail_step; + } +#undef SUB_ZP +} + +template +void deconv_direct_3x3(const uint8_t* src, const uint8_t* filter, int32_t* dst, + size_t IH, size_t IW, size_t OH, size_t OW, size_t IC, + uint8_t src_zp, uint8_t filter_zp, + int32_t src_filter_zp) { + MEGDNN_MARK_USED_VAR(IH); + MEGDNN_MARK_USED_VAR(IC); + const size_t tail_step = IW - OW; + + uint8x16_t _src_zp = vdupq_n_u8(src_zp); + uint8x16_t _filter_zp = vdupq_n_u8(filter_zp); + int32x4_t _shift_zp = vdupq_n_s32(src_filter_zp - SHIFT); + + const uint8x16_t _idx0 = {0, 1, 2, 16, 1, 2, 3, 16, + 2, 3, 4, 16, 3, 4, 5, 16}; + const uint8x16_t _idx1 = {4, 5, 6, 16, 5, 6, 7, 16, + 6, 7, 8, 16, 7, 8, 9, 16}; + const uint8x16_t _idx2 = {8, 9, 10, 16, 9, 10, 11, 16, + 10, 11, 12, 16, 11, 12, 13, 16}; + const uint8_t* src_ptr = src; + uint32_t* outptr = reinterpret_cast(dst); + uint32_t* outptr2 = outptr + OW; + + const uint8_t* r0 = src_ptr; + const uint8_t* r1 = src_ptr + 
IW; + const uint8_t* r2 = src_ptr + IW * 2; + const uint8_t* r3 = src_ptr + IW * 3; + + const uint8_t* k0 = filter; + + uint8x16_t _k_tmp = vcombine_u8(vld1_u8(k0), vdup_n_u8(k0[8])); + uint8x16_t _idx = {8, 7, 6, 16, 8, 7, 6, 16, 8, 7, 6, 16, 8, 7, 6, 16}; + uint8x16_t _k12 = vqtbl1q_u8_common(_k_tmp, _idx); + _idx = {5, 4, 3, 16, 5, 4, 3, 16, 5, 4, 3, 16, 5, 4, 3, 16}; + uint8x16_t _k345 = vqtbl1q_u8_common(_k_tmp, _idx); + _idx = {2, 1, 0, 16, 2, 1, 0, 16, 2, 1, 0, 16, 2, 1, 0, 16}; + uint8x16_t _k678 = vqtbl1q_u8_common(_k_tmp, _idx); + + uint8x16_t _tmp, _elem; + size_t h = 0; + for (; h + 1 < OH; h += 2) { + int width = OW >> 2; + + int w = 0; + for (; w + 3 < width; w += 3) { + //! As the inner kernel read 16 elements, and IW is times of 16 + uint32x4_t _sum00 = vld1q_u32(outptr); + uint32x4_t _sum01 = vld1q_u32(outptr + 4); + uint32x4_t _sum02 = vld1q_u32(outptr + 8); + uint32x4_t _sum10 = vld1q_u32(outptr2); + uint32x4_t _sum11 = vld1q_u32(outptr2 + 4); + uint32x4_t _sum12 = vld1q_u32(outptr2 + 8); + + _tmp = vld1q_u8(r0); + CALC_0(12, 0); + CALC_0(12, 1); + CALC_0(12, 2); + + _tmp = vld1q_u8(r1); + CALC_2(345, 12, 0); + CALC_2(345, 12, 1); + CALC_2(345, 12, 2); + + _tmp = vld1q_u8(r2); + CALC_2(678, 345, 0); + CALC_2(678, 345, 1); + CALC_2(678, 345, 2); + + _tmp = vld1q_u8(r3); + CALC_1(678, 0); + CALC_1(678, 1); + CALC_1(678, 2); + + if (last_oc) { + CALC_DST(_sum00); + CALC_DST(_sum01); + CALC_DST(_sum02); + CALC_DST(_sum10); + CALC_DST(_sum11); + CALC_DST(_sum12); + } + vst1q_u32(outptr, _sum00); + vst1q_u32(outptr + 4, _sum01); + vst1q_u32(outptr + 8, _sum02); + vst1q_u32(outptr2, _sum10); + vst1q_u32(outptr2 + 4, _sum11); + vst1q_u32(outptr2 + 8, _sum12); + + r0 += 12; + r1 += 12; + r2 += 12; + r3 += 12; + outptr += 12; + outptr2 += 12; + } + for (; w < width; w++) { + uint32x4_t _sum00 = vld1q_u32(outptr); + uint32x4_t _sum10 = vld1q_u32(outptr2); + + _tmp = vtranslq_u8(vld1_u8(r0)); + CALC_0(12, 0); + + _tmp = vtranslq_u8(vld1_u8(r1)); + CALC_2(345, 12, 0); + + _tmp = vtranslq_u8(vld1_u8(r2)); + CALC_2(678, 345, 0); + + _tmp = vtranslq_u8(vld1_u8(r3)); + CALC_1(678, 0); + + if (last_oc) { + CALC_DST(_sum00); + CALC_DST(_sum10); + } + vst1q_u32(outptr, _sum00); + vst1q_u32(outptr2, _sum10); + + r0 += 4; + r1 += 4; + r2 += 4; + r3 += 4; + outptr += 4; + outptr2 += 4; + } + + r0 += tail_step + IW; + r1 += tail_step + IW; + r2 += tail_step + IW; + r3 += tail_step + IW; + + outptr += OW; + outptr2 += OW; + } + + for (; h < OH; h++) { + int width = OW >> 2; + + int w = 0; + for (; w + 3 < width; w += 3) { + uint32x4_t _sum00 = vld1q_u32(outptr); + uint32x4_t _sum01 = vld1q_u32(outptr + 4); + uint32x4_t _sum02 = vld1q_u32(outptr + 8); + + _tmp = vld1q_u8(r0); + CALC_0(12, 0); + CALC_0(12, 1); + CALC_0(12, 2); + + _tmp = vld1q_u8(r1); + CALC_0(345, 0); + CALC_0(345, 1); + CALC_0(345, 2); + + _tmp = vld1q_u8(r2); + CALC_0(678, 0); + CALC_0(678, 1); + CALC_0(678, 2); + + if (last_oc) { + CALC_DST(_sum00); + CALC_DST(_sum01); + CALC_DST(_sum02); + } + vst1q_u32(outptr, _sum00); + vst1q_u32(outptr + 4, _sum01); + vst1q_u32(outptr + 8, _sum02); + + r0 += 12; + r1 += 12; + r2 += 12; + outptr += 12; + } + for (; w < width; w++) { + uint32x4_t _sum00 = vld1q_u32(outptr); + + _tmp = vtranslq_u8(vld1_u8(r0)); + CALC_0(12, 0); + + _tmp = vtranslq_u8(vld1_u8(r1)); + CALC_0(345, 0); + + _tmp = vtranslq_u8(vld1_u8(r2)); + CALC_0(678, 0); + + if (last_oc) { + CALC_DST(_sum00); + } + vst1q_u32(outptr, _sum00); + + r0 += 4; + r1 += 4; + r2 += 4; + outptr += 4; + } + r0 += tail_step; + r1 += 
tail_step; + r2 += tail_step; + } +} + +#undef CALC_0 +#undef CALC_1 +#undef CALC_2 + +#define CALC_0(_k00_idx, _k01_idx, _c_idx) \ + _elem = vqtbl1q_u8_common(_tmp, _idx##_c_idx##0); \ + _sum0##_c_idx = vdotq_u32(_sum0##_c_idx, _k##_k00_idx, _elem); \ + _sum0##_c_idx = \ + vsubq_u32(_sum0##_c_idx, vdotq2_u32(_src_zp, _k##_k00_idx)); \ + _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, vdotq2_u32(_filter_zp, _elem)); \ + _elem = vqtbl1q_u8_common(_tmp, _idx##_c_idx##1); \ + _sum0##_c_idx = vdotq_u32(_sum0##_c_idx, _k##_k01_idx, _elem); \ + _sum0##_c_idx = \ + vsubq_u32(_sum0##_c_idx, vdotq2_u32(_src_zp, _k##_k01_idx)); \ + _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, vdotq2_u32(_filter_zp, _elem)); + +#define CALC_1(_k00_idx, _k01_idx, _c_idx) \ + _elem = vqtbl1q_u8_common(_tmp, _idx##_c_idx##0); \ + _sum1##_c_idx = vdotq_u32(_sum1##_c_idx, _k##_k00_idx, _elem); \ + _sum1##_c_idx = \ + vsubq_u32(_sum1##_c_idx, vdotq2_u32(_src_zp, _k##_k00_idx)); \ + _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_filter_zp, _elem)); \ + _elem = vqtbl1q_u8_common(_tmp, _idx##_c_idx##1); \ + _sum1##_c_idx = vdotq_u32(_sum1##_c_idx, _k##_k01_idx, _elem); \ + _sum1##_c_idx = \ + vsubq_u32(_sum1##_c_idx, vdotq2_u32(_src_zp, _k##_k01_idx)); \ + _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_filter_zp, _elem)); + +#define CALC_2(_k00_idx, _k01_idx, _k10_idx, _k11_idx, _c_idx) \ + _elem = vqtbl1q_u8_common(_tmp, _idx##_c_idx##0); \ + _sum0##_c_idx = vdotq_u32(_sum0##_c_idx, _k##_k00_idx, _elem); \ + _sum0##_c_idx = \ + vsubq_u32(_sum0##_c_idx, vdotq2_u32(_src_zp, _k##_k00_idx)); \ + _elem2 = vdotq2_u32(_filter_zp, _elem); \ + _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, _elem2); \ + _sum1##_c_idx = vdotq_u32(_sum1##_c_idx, _k##_k10_idx, _elem); \ + _sum1##_c_idx = \ + vsubq_u32(_sum1##_c_idx, vdotq2_u32(_src_zp, _k##_k10_idx)); \ + _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, _elem2); \ + _elem = vqtbl1q_u8_common(_tmp, _idx##_c_idx##1); \ + _sum0##_c_idx = vdotq_u32(_sum0##_c_idx, _k##_k01_idx, _elem); \ + _sum0##_c_idx = \ + vsubq_u32(_sum0##_c_idx, vdotq2_u32(_src_zp, _k##_k01_idx)); \ + _elem2 = vdotq2_u32(_filter_zp, _elem); \ + _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, _elem2); \ + _sum1##_c_idx = vdotq_u32(_sum1##_c_idx, _k##_k11_idx, _elem); \ + _sum1##_c_idx = \ + vsubq_u32(_sum1##_c_idx, vdotq2_u32(_src_zp, _k##_k11_idx)); \ + _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, _elem2); + +template +void deconv_direct_5x5(const uint8_t* src, const uint8_t* filter, int32_t* dst, + size_t IH, size_t IW, size_t OH, size_t OW, size_t IC, + uint8_t src_zp, uint8_t filter_zp, + int32_t src_filter_zp) { + MEGDNN_MARK_USED_VAR(IH); + MEGDNN_MARK_USED_VAR(IC); + const size_t tail_step = IW - OW; + + uint8x16_t _src_zp = vdupq_n_u8(src_zp); + uint8x16_t _filter_zp = vdupq_n_u8(filter_zp); + int32x4_t _shift_zp = vdupq_n_s32(src_filter_zp - SHIFT); + + const uint8x16_t _idx00 = {0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6}; + const uint8x16_t _idx01 = {4, 16, 16, 16, 5, 16, 16, 16, + 6, 16, 16, 16, 7, 16, 16, 16}; + const uint8x16_t _idx10 = {4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10}; + const uint8x16_t _idx11 = {8, 16, 16, 16, 9, 16, 16, 16, + 10, 16, 16, 16, 11, 16, 16, 16}; + const uint8x16_t _idx20 = {8, 9, 10, 11, 9, 10, 11, 12, + 10, 11, 12, 13, 11, 12, 13, 14}; + const uint8x16_t _idx21 = {12, 16, 16, 16, 13, 16, 16, 16, + 14, 16, 16, 16, 15, 16, 16, 16}; + uint8x16_t _tmp, _elem; + uint32x4_t _elem2; + const uint8_t* src_ptr = src; + uint32_t* outptr = reinterpret_cast(dst); + uint32_t* outptr2 = outptr + OW; + + 
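+        //! each iteration of the main loop below writes two output rows, which for
+        //! a 5x5 stride-1 deconv needs six consecutive input rows (r0..r5)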
const uint8_t* r0 = src_ptr; + const uint8_t* r1 = src_ptr + IW; + const uint8_t* r2 = src_ptr + IW * 2; + const uint8_t* r3 = src_ptr + IW * 3; + const uint8_t* r4 = src_ptr + IW * 4; + const uint8_t* r5 = src_ptr + IW * 5; + + const uint8_t* k0 = filter; + + uint8x16_t _k = vld1q_u8(k0 + 9); + //! filter row 1 + uint8x16_t _idx = {15, 14, 13, 12, 15, 14, 13, 12, + 15, 14, 13, 12, 15, 14, 13, 12}; + uint8x16_t _k123 = vqtbl1q_u8_common(_k, _idx); + _idx = {11, 16, 16, 16, 11, 16, 16, 16, 11, 16, 16, 16, 11, 16, 16, 16}; + uint8x16_t _k4 = vqtbl1q_u8_common(_k, _idx); + //! filter row 2 + _idx = {10, 9, 8, 7, 10, 9, 8, 7, 10, 9, 8, 7, 10, 9, 8, 7}; + uint8x16_t _k5678 = vqtbl1q_u8_common(_k, _idx); + _idx = {6, 16, 16, 16, 6, 16, 16, 16, 6, 16, 16, 16, 6, 16, 16, 16}; + uint8x16_t _k9 = vqtbl1q_u8_common(_k, _idx); + //! filter row 3 + _idx = {5, 4, 3, 2, 5, 4, 3, 2, 5, 4, 3, 2, 5, 4, 3, 2}; + uint8x16_t _k10111213 = vqtbl1q_u8_common(_k, _idx); + _idx = {1, 16, 16, 16, 1, 16, 16, 16, 1, 16, 16, 16, 1, 16, 16, 16}; + uint8x16_t _k14 = vqtbl1q_u8_common(_k, _idx); + //! 9 10 11 12 -> 13 14 15 16 -> 17 18 19 20 -> 21 22 23 24 + _k = vld1q_u8(k0); + //! filter row 4 + _idx = {9, 8, 7, 6, 9, 8, 7, 6, 9, 8, 7, 6, 9, 8, 7, 6}; + uint8x16_t _k15161718 = vqtbl1q_u8_common(_k, _idx); + _idx = {5, 16, 16, 16, 5, 16, 16, 16, 5, 16, 16, 16, 5, 16, 16, 16}; + uint8x16_t _k19 = vqtbl1q_u8_common(_k, _idx); + //! filter row 5 + _idx = {4, 3, 2, 1, 4, 3, 2, 1, 4, 3, 2, 1, 4, 3, 2, 1}; + uint8x16_t _k20212223 = vqtbl1q_u8_common(_k, _idx); + _idx = {0, 16, 16, 16, 0, 16, 16, 16, 0, 16, 16, 16, 0, 16, 16, 16}; + uint8x16_t _k24 = vqtbl1q_u8_common(_k, _idx); + + const int width = OW >> 2; + size_t h = 0; + for (; h + 1 < OH; h += 2) { + int w = 0; + for (; w + 3 < width; w += 3) { + //! 
As the inner kernel read 16 elements, and IW is times of 16 + uint32x4_t _sum00 = vld1q_u32(outptr); + uint32x4_t _sum01 = vld1q_u32(outptr + 4); + uint32x4_t _sum02 = vld1q_u32(outptr + 8); + uint32x4_t _sum10 = vld1q_u32(outptr2); + uint32x4_t _sum11 = vld1q_u32(outptr2 + 4); + uint32x4_t _sum12 = vld1q_u32(outptr2 + 8); + + _tmp = vld1q_u8(r0); + CALC_0(123, 4, 0); + CALC_0(123, 4, 1); + CALC_0(123, 4, 2); + + _tmp = vld1q_u8(r1); + CALC_2(5678, 9, 123, 4, 0); + CALC_2(5678, 9, 123, 4, 1); + CALC_2(5678, 9, 123, 4, 2); + + _tmp = vld1q_u8(r2); + CALC_2(10111213, 14, 5678, 9, 0); + CALC_2(10111213, 14, 5678, 9, 1); + CALC_2(10111213, 14, 5678, 9, 2); + + _tmp = vld1q_u8(r3); + CALC_2(15161718, 19, 10111213, 14, 0); + CALC_2(15161718, 19, 10111213, 14, 1); + CALC_2(15161718, 19, 10111213, 14, 2); + + _tmp = vld1q_u8(r4); + CALC_2(20212223, 24, 15161718, 19, 0); + CALC_2(20212223, 24, 15161718, 19, 1); + CALC_2(20212223, 24, 15161718, 19, 2); + + _tmp = vld1q_u8(r5); + CALC_1(20212223, 24, 0); + CALC_1(20212223, 24, 1); + CALC_1(20212223, 24, 2); + + if (last_oc) { + CALC_DST(_sum00); + CALC_DST(_sum01); + CALC_DST(_sum02); + CALC_DST(_sum10); + CALC_DST(_sum11); + CALC_DST(_sum12); + } + vst1q_u32(outptr, _sum00); + vst1q_u32(outptr + 4, _sum01); + vst1q_u32(outptr + 8, _sum02); + vst1q_u32(outptr2, _sum10); + vst1q_u32(outptr2 + 4, _sum11); + vst1q_u32(outptr2 + 8, _sum12); + + r0 += 12; + r1 += 12; + r2 += 12; + r3 += 12; + r4 += 12; + r5 += 12; + outptr += 12; + outptr2 += 12; + } + for (; w < width; w++) { + uint32x4_t _sum00 = vld1q_u32(outptr); + uint32x4_t _sum10 = vld1q_u32(outptr2); + + _tmp = vtranslq_u8(vld1_u8(r0)); + CALC_0(123, 4, 0); + + _tmp = vtranslq_u8(vld1_u8(r1)); + CALC_2(5678, 9, 123, 4, 0); + + _tmp = vtranslq_u8(vld1_u8(r2)); + CALC_2(10111213, 14, 5678, 9, 0); + + _tmp = vtranslq_u8(vld1_u8(r3)); + CALC_2(15161718, 19, 10111213, 14, 0); + + _tmp = vtranslq_u8(vld1_u8(r4)); + CALC_2(20212223, 24, 15161718, 19, 0); + + _tmp = vtranslq_u8(vld1_u8(r5)); + CALC_1(20212223, 24, 0); + + if (last_oc) { + CALC_DST(_sum00); + CALC_DST(_sum10); + } + vst1q_u32(outptr, _sum00); + vst1q_u32(outptr2, _sum10); + + r0 += 4; + r1 += 4; + r2 += 4; + r3 += 4; + r4 += 4; + r5 += 4; + outptr += 4; + outptr2 += 4; + } + + r0 += tail_step + IW; + r1 += tail_step + IW; + r2 += tail_step + IW; + r3 += tail_step + IW; + r4 += tail_step + IW; + r5 += tail_step + IW; + + outptr += OW; + outptr2 += OW; + } + + for (; h < OH; h++) { + int w = 0; + for (; w + 3 < width; w += 3) { + uint32x4_t _sum00 = vld1q_u32(outptr); + uint32x4_t _sum01 = vld1q_u32(outptr + 4); + uint32x4_t _sum02 = vld1q_u32(outptr + 8); + + _tmp = vld1q_u8(r0); + CALC_0(123, 4, 0); + CALC_0(123, 4, 1); + CALC_0(123, 4, 2); + + _tmp = vld1q_u8(r1); + CALC_0(5678, 9, 0); + CALC_0(5678, 9, 1); + CALC_0(5678, 9, 2); + + _tmp = vld1q_u8(r2); + CALC_0(10111213, 14, 0); + CALC_0(10111213, 14, 1); + CALC_0(10111213, 14, 2); + + _tmp = vld1q_u8(r3); + CALC_0(15161718, 19, 0); + CALC_0(15161718, 19, 1); + CALC_0(15161718, 19, 2); + + _tmp = vld1q_u8(r4); + CALC_0(20212223, 24, 0); + CALC_0(20212223, 24, 1); + CALC_0(20212223, 24, 2); + + if (last_oc) { + CALC_DST(_sum00); + CALC_DST(_sum01); + CALC_DST(_sum02); + } + vst1q_u32(outptr, _sum00); + vst1q_u32(outptr + 4, _sum01); + vst1q_u32(outptr + 8, _sum02); + + r0 += 12; + r1 += 12; + r2 += 12; + r3 += 12; + r4 += 12; + outptr += 12; + } + for (; w < width; w++) { + uint32x4_t _sum00 = vld1q_u32(outptr); + + _tmp = vtranslq_u8(vld1_u8(r0)); + CALC_0(123, 4, 0); + + _tmp = 
vtranslq_u8(vld1_u8(r1)); + CALC_0(5678, 9, 0); + + _tmp = vtranslq_u8(vld1_u8(r2)); + CALC_0(10111213, 14, 0); + + _tmp = vtranslq_u8(vld1_u8(r3)); + CALC_0(15161718, 19, 0); + + _tmp = vtranslq_u8(vld1_u8(r4)); + CALC_0(20212223, 24, 0); + + if (last_oc) { + CALC_DST(_sum00); + } + vst1q_u32(outptr, _sum00); + + r0 += 4; + r1 += 4; + r2 += 4; + r3 += 4; + r4 += 4; + outptr += 4; + } + r0 += tail_step; + r1 += tail_step; + r2 += tail_step; + r3 += tail_step; + r4 += tail_step; + } +} + +template +void deconv_direct_7x7(const uint8_t* src, const uint8_t* filter, int32_t* dst, + size_t IH, size_t IW, size_t OH, size_t OW, size_t IC, + uint8_t src_zp, uint8_t filter_zp, + int32_t src_filter_zp) { + MEGDNN_MARK_USED_VAR(IH); + MEGDNN_MARK_USED_VAR(IC); + const size_t tail_step = IW - OW; + + uint8x16_t _src_zp = vdupq_n_u8(src_zp); + uint8x16_t _filter_zp = vdupq_n_u8(filter_zp); + int32x4_t _shift_zp = vdupq_n_s32(src_filter_zp - SHIFT); + + const uint8x16_t _idx00 = {0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6}; + const uint8x16_t _idx01 = {4, 5, 6, 16, 5, 6, 7, 16, + 6, 7, 8, 16, 7, 8, 9, 16}; + const uint8x16_t _idx10 = {4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10}; + const uint8x16_t _idx11 = {8, 9, 10, 16, 9, 10, 11, 16, + 10, 11, 12, 16, 11, 12, 13, 16}; + + uint8x16_t _tmp, _elem; + uint32x4_t _elem2; + const uint8_t* src_ptr = src; + uint32_t* outptr = reinterpret_cast(dst); + uint32_t* outptr2 = outptr + OW; + + const uint8_t* r0 = src_ptr; + const uint8_t* r1 = src_ptr + IW; + const uint8_t* r2 = src_ptr + IW * 2; + const uint8_t* r3 = src_ptr + IW * 3; + const uint8_t* r4 = src_ptr + IW * 4; + const uint8_t* r5 = src_ptr + IW * 5; + const uint8_t* r6 = src_ptr + IW * 6; + const uint8_t* r7 = src_ptr + IW * 7; + + const uint8_t* k0 = filter; + + uint8x16_t _k = vld1q_u8(k0 + 33); + //! filter row 1 + uint8x16_t _idx = {15, 14, 13, 12, 15, 14, 13, 12, + 15, 14, 13, 12, 15, 14, 13, 12}; + uint8x16_t _k123 = vqtbl1q_u8_common(_k, _idx); + _idx = {11, 10, 9, 16, 11, 10, 9, 16, 11, 10, 9, 16, 11, 10, 9, 16}; + uint8x16_t _k456 = vqtbl1q_u8_common(_k, _idx); + //! filter row 2 + _idx = {8, 7, 6, 5, 8, 7, 6, 5, 8, 7, 6, 5, 8, 7, 6, 5}; + uint8x16_t _k78910 = vqtbl1q_u8_common(_k, _idx); + _idx = {4, 3, 2, 16, 4, 3, 2, 16, 4, 3, 2, 16, 4, 3, 2, 16}; + uint8x16_t _k111213 = vqtbl1q_u8_common(_k, _idx); + + //! 12 13 14 15 -> 16 17 18 19 -> 20 21 22 23 -> 24 25 26 27 + _k = vld1q_u8(k0 + 19); + //! filter row 3 + _idx = {15, 14, 13, 12, 15, 14, 13, 12, 15, 14, 13, 12, 15, 14, 13, 12}; + uint8x16_t _k14151617 = vqtbl1q_u8_common(_k, _idx); + _idx = {11, 10, 9, 16, 11, 10, 9, 16, 11, 10, 9, 16, 11, 10, 9, 16}; + uint8x16_t _k181920 = vqtbl1q_u8_common(_k, _idx); + //! filter row 4 + _idx = {8, 7, 6, 5, 8, 7, 6, 5, 8, 7, 6, 5, 8, 7, 6, 5}; + uint8x16_t _k21222324 = vqtbl1q_u8_common(_k, _idx); + _idx = {4, 3, 2, 16, 4, 3, 2, 16, 4, 3, 2, 16, 4, 3, 2, 16}; + uint8x16_t _k252627 = vqtbl1q_u8_common(_k, _idx); + + //! 24 25 26 27->28 29 30 31 -> 32 33 34 35 -> 36 37 38 39 + _k = vld1q_u8(k0 + 5); + //! filter row 5 + _idx = {15, 14, 13, 12, 15, 14, 13, 12, 15, 14, 13, 12, 15, 14, 13, 12}; + uint8x16_t _k28293031 = vqtbl1q_u8_common(_k, _idx); + _idx = {11, 10, 9, 16, 11, 10, 9, 16, 11, 10, 9, 16, 11, 10, 9, 16}; + uint8x16_t _k323334 = vqtbl1q_u8_common(_k, _idx); + + //! 33 34 35 36 -> 37 38 39 40 -> 41 42 43 44 -> 45 46 47 48 + _k = vld1q_u8(k0); + //! 
filter row 6 + _idx = {13, 12, 11, 10, 13, 12, 11, 10, 13, 12, 11, 10, 13, 12, 11, 10}; + uint8x16_t _k35363738 = vqtbl1q_u8_common(_k, _idx); + _idx = {9, 8, 7, 16, 9, 8, 7, 16, 9, 8, 7, 16, 9, 8, 7, 16}; + uint8x16_t _k394041 = vqtbl1q_u8_common(_k, _idx); + + //! filter row 7 + _idx = {6, 5, 4, 3, 6, 5, 4, 3, 6, 5, 4, 3, 6, 5, 4, 3}; + uint8x16_t _k42434445 = vqtbl1q_u8_common(_k, _idx); + _idx = {2, 1, 0, 16, 2, 1, 0, 16, 2, 1, 0, 16, 2, 1, 0, 16}; + uint8x16_t _k464748 = vqtbl1q_u8_common(_k, _idx); + + const int width = OW >> 2; + size_t h = 0; + for (; h + 1 < OH; h += 2) { + int w = 0; + for (; w + 2 < width; w += 2) { + //! As the inner kernel read 16 elements, and IW is times of 16 + uint32x4_t _sum00 = vld1q_u32(outptr); + uint32x4_t _sum01 = vld1q_u32(outptr + 4); + uint32x4_t _sum10 = vld1q_u32(outptr2); + uint32x4_t _sum11 = vld1q_u32(outptr2 + 4); + + _tmp = vld1q_u8(r0); + CALC_0(123, 456, 0); + CALC_0(123, 456, 1); + + _tmp = vld1q_u8(r1); + CALC_2(78910, 111213, 123, 456, 0); + CALC_2(78910, 111213, 123, 456, 1); + + _tmp = vld1q_u8(r2); + CALC_2(14151617, 181920, 78910, 111213, 0); + CALC_2(14151617, 181920, 78910, 111213, 1); + + _tmp = vld1q_u8(r3); + CALC_2(21222324, 252627, 14151617, 181920, 0); + CALC_2(21222324, 252627, 14151617, 181920, 1); + + _tmp = vld1q_u8(r4); + CALC_2(28293031, 323334, 21222324, 252627, 0); + CALC_2(28293031, 323334, 21222324, 252627, 1); + + _tmp = vld1q_u8(r5); + CALC_2(35363738, 394041, 28293031, 323334, 0); + CALC_2(35363738, 394041, 28293031, 323334, 1); + + _tmp = vld1q_u8(r6); + CALC_2(42434445, 464748, 35363738, 394041, 0); + CALC_2(42434445, 464748, 35363738, 394041, 1); + + _tmp = vld1q_u8(r7); + CALC_1(42434445, 464748, 0); + CALC_1(42434445, 464748, 1); + + if (last_oc) { + CALC_DST(_sum00); + CALC_DST(_sum01); + CALC_DST(_sum10); + CALC_DST(_sum11); + } + vst1q_u32(outptr, _sum00); + vst1q_u32(outptr + 4, _sum01); + vst1q_u32(outptr2, _sum10); + vst1q_u32(outptr2 + 4, _sum11); + + r0 += 8; + r1 += 8; + r2 += 8; + r3 += 8; + r4 += 8; + r5 += 8; + r6 += 8; + r7 += 8; + outptr += 8; + outptr2 += 8; + } + for (; w < width; w++) { + uint32x4_t _sum00 = vld1q_u32(outptr); + uint32x4_t _sum10 = vld1q_u32(outptr2); + + _tmp = vld1q_u8(r0); + CALC_0(123, 456, 0); + + _tmp = vld1q_u8(r1); + CALC_2(78910, 111213, 123, 456, 0); + + _tmp = vld1q_u8(r2); + CALC_2(14151617, 181920, 78910, 111213, 0); + + _tmp = vld1q_u8(r3); + CALC_2(21222324, 252627, 14151617, 181920, 0); + + _tmp = vld1q_u8(r4); + CALC_2(28293031, 323334, 21222324, 252627, 0); + + _tmp = vld1q_u8(r5); + CALC_2(35363738, 394041, 28293031, 323334, 0); + + _tmp = vld1q_u8(r6); + CALC_2(42434445, 464748, 35363738, 394041, 0); + + _tmp = vld1q_u8(r7); + CALC_1(42434445, 464748, 0); + + if (last_oc) { + CALC_DST(_sum00); + CALC_DST(_sum10); + } + vst1q_u32(outptr, _sum00); + vst1q_u32(outptr2, _sum10); + + r0 += 4; + r1 += 4; + r2 += 4; + r3 += 4; + r4 += 4; + r5 += 4; + r6 += 4; + r7 += 4; + outptr += 4; + outptr2 += 4; + } + + r0 += tail_step + IW; + r1 += tail_step + IW; + r2 += tail_step + IW; + r3 += tail_step + IW; + r4 += tail_step + IW; + r5 += tail_step + IW; + r6 += tail_step + IW; + r7 += tail_step + IW; + + outptr += OW; + outptr2 += OW; + } + + for (; h < OH; h++) { + int w = 0; + for (; w + 2 < width; w += 2) { + uint32x4_t _sum00 = vld1q_u32(outptr); + uint32x4_t _sum01 = vld1q_u32(outptr + 4); + + _tmp = vld1q_u8(r0); + CALC_0(123, 456, 0); + CALC_0(123, 456, 1); + + _tmp = vld1q_u8(r1); + CALC_0(78910, 111213, 0); + CALC_0(78910, 111213, 1); + + _tmp = 
vld1q_u8(r2); + CALC_0(14151617, 181920, 0); + CALC_0(14151617, 181920, 1); + + _tmp = vld1q_u8(r3); + CALC_0(21222324, 252627, 0); + CALC_0(21222324, 252627, 1); + + _tmp = vld1q_u8(r4); + CALC_0(28293031, 323334, 0); + CALC_0(28293031, 323334, 1); + + _tmp = vld1q_u8(r5); + CALC_0(35363738, 394041, 0); + CALC_0(35363738, 394041, 1); + + _tmp = vld1q_u8(r6); + CALC_0(42434445, 464748, 0); + CALC_0(42434445, 464748, 1); + + if (last_oc) { + CALC_DST(_sum00); + CALC_DST(_sum01); + } + vst1q_u32(outptr, _sum00); + vst1q_u32(outptr + 4, _sum01); + + r0 += 8; + r1 += 8; + r2 += 8; + r3 += 8; + r4 += 8; + r5 += 8; + r6 += 8; + outptr += 8; + } + for (; w < width; w++) { + uint32x4_t _sum00 = vld1q_u32(outptr); + + _tmp = vld1q_u8(r0); + CALC_0(123, 456, 0); + + _tmp = vld1q_u8(r1); + CALC_0(78910, 111213, 0); + + _tmp = vld1q_u8(r2); + CALC_0(14151617, 181920, 0); + + _tmp = vld1q_u8(r3); + CALC_0(21222324, 252627, 0); + + _tmp = vld1q_u8(r4); + CALC_0(28293031, 323334, 0); + + _tmp = vld1q_u8(r5); + CALC_0(35363738, 394041, 0); + + _tmp = vld1q_u8(r6); + CALC_0(42434445, 464748, 0); + + if (last_oc) { + CALC_DST(_sum00); + } + vst1q_u32(outptr, _sum00); + + r0 += 4; + r1 += 4; + r2 += 4; + r3 += 4; + r4 += 4; + r5 += 4; + r6 += 4; + outptr += 4; + } + r0 += tail_step; + r1 += tail_step; + r2 += tail_step; + r3 += tail_step; + r4 += tail_step; + r5 += tail_step; + r6 += tail_step; + } +} + +#undef CALC_0 +#undef CALC_1 +#undef CALC_2 + +} // anonymous namespace + +size_t deconv::get_workspace_in_bytes_stride1_quint8_dot( + const NCBKernSizeParam& param) { + return get_bundle(param).total_size_in_bytes(); +} + +bool deconv::can_stride1_quint8_dot(const NCBKernSizeParam& param) { + auto&& fm = param.filter_meta; + auto FH = fm.spatial[0], FW = fm.spatial[1], OC = fm.ocpg, + PH = fm.padding[0], PW = fm.padding[1]; + bool avaiable = fm.format == param::Convolution::Format::NCHW && + !fm.should_flip && fm.spatial_ndim == 2 && + fm.dilation[0] == 1 && fm.dilation[1] == 1 && + fm.stride[0] == 1 && fm.stride[1] == 1 && FH == FW && + (FH == 2 || FH == 3 || FH == 5 || FH == 7) && + FH >= PH + 1 && FW >= PW + 1; + + /** + * \note In the kernel, we use int32_t to calc the value, in order + * not generate negative number, we first initialize SHIFT and sub + * it to get the actual value in the last oc calc. + * + * \warning the sum of dst value should not greater than SHIFT, + * otherwise it maybe error, but currently in mobile, it would not + * be possible(7*7*OC*2^8*2^8 > SHIFT => OC > 334). 
+ */ + avaiable &= (7 * 7 * OC < (1 << (SHIFT_BITS - 8 - 8))); + return avaiable && OC < 8; +} + +void deconv::stride1_quint8_dot(const NCBKernParam& param) { + auto bundle = get_bundle(param); + bundle.set(param.workspace_ptr); + UNPACK_CONV_F32_NCB_KERN_SIZES(param); + MEGDNN_MARK_USED_VAR(SH); + MEGDNN_MARK_USED_VAR(SW); + size_t IH2, IW2, OW2; + int padding_h = FH - PH - 1, padding_w = FW - PW - 1; + get_rectified_size(IH, IW, OH, OW, FH, FW, PH, PW, IH2, IW2, OW2); + + uint8_t filter_zp = + param.filter_type.param().zero_point; + uint8_t src_zp = param.diff_type.param().zero_point; + int32_t src_filter_zp = static_cast(filter_zp) * + static_cast(src_zp) * OC * FH * FH; + + using Func = std::function; + Func deconv = nullptr, deconv_last_oc = nullptr; + if (FH == 2) { + deconv = deconv_direct_2x2; + deconv_last_oc = deconv_direct_2x2; + } else if (FH == 3) { + deconv = deconv_direct_3x3; + deconv_last_oc = deconv_direct_3x3; + } else if (FH == 5) { + deconv = deconv_direct_5x5; + deconv_last_oc = deconv_direct_5x5; + } else if (FH == 7) { + deconv = deconv_direct_7x7; + deconv_last_oc = deconv_direct_7x7; + } else { + megdnn_assert(0); + } + + bool need_src_copy_var = need_src_copy(param); + bool need_dst_copy_var = need_dst_copy(param); + uint8_t* base_src_ptr = reinterpret_cast( + const_cast(param.diff())); + int32_t* base_dst_ptr = reinterpret_cast(param.grad()); + const uint8_t* fptr = + reinterpret_cast(param.filter()); + + for (size_t n = 0; n < N; ++n) { + int32_t* dptr_copied = static_cast(bundle.get(1)); + int32_t* dptr_ori = base_dst_ptr + n * param.out_bs; + int32_t* dptr = nullptr; + size_t OW_real = OW; + if (need_dst_copy_var) { + dptr = dptr_copied; + OW_real = OW2; + } else { + dptr = dptr_ori; + } + std::fill_n(dptr, IC * OH * OW_real, SHIFT); + + uint8_t* sptr_ori = base_src_ptr + n * param.inp_bs; + uint8_t* sptr_copied = static_cast(bundle.get(0)); + uint8_t* sptr = nullptr; + + rep(oc, OC) { + if (need_src_copy_var) { + // copy sptr_ori to sptr_copied + std::memset(sptr_copied, src_zp, sizeof(uint8_t) * IH2 * IW2); + copy_plane_in_bytes(sptr_copied + padding_h * IW2 + padding_w, + sptr_ori + oc * IH * IW, IH, + IW * sizeof(uint8_t), IW2 * sizeof(uint8_t), + IW * sizeof(uint8_t)); + sptr = sptr_copied; + } else { + sptr = sptr_ori + oc * IH * IW; + } + + int32_t* dst_ptr = dptr; + const uint8_t* filter = fptr + oc * IC * FH * FW; + for (size_t ic = 0; ic < IC; ic++) { + if (oc != OC - 1) { + deconv(sptr, filter, dst_ptr, IH2, IW2, OH, OW_real, IC, + src_zp, filter_zp, src_filter_zp); + } else { + deconv_last_oc(sptr, filter, dst_ptr, IH2, IW2, OH, OW_real, + IC, src_zp, filter_zp, src_filter_zp); + } + dst_ptr += OH * OW_real; + filter += FH * FH; + } + } + if (need_dst_copy_var) { + for (size_t ic = 0; ic < IC; ++ic) { + copy_plane_in_bytes(dptr_ori + ic * OH * OW, + dptr + ic * OH * OW2, OH, + OW * sizeof(int32_t), OW * sizeof(int32_t), + OW2 * sizeof(int32_t)); + } + } + } +} + +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/convolution/quint8/conv_backdata_stride1.h b/dnn/src/arm_common/convolution/quint8/conv_backdata_stride1.h new file mode 100644 index 00000000..588205fd --- /dev/null +++ b/dnn/src/arm_common/convolution/quint8/conv_backdata_stride1.h @@ -0,0 +1,37 @@ +/** + * \file dnn/src/arm_common/convolution/quint8/conv_backdata_stride1.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once + +#if __ARM_FEATURE_DOTPROD +#include "src/arm_common/convolution/opr_impl.h" + +#include +#include + +namespace megdnn { +namespace arm_common { +namespace deconv { + +using NCBKernSizeParam = ConvolutionBackwardDataImpl::NCBKernSizeParam; +using NCBKernParam = ConvolutionBackwardDataImpl::NCBKernParam; + +bool can_stride1_quint8_dot(const NCBKernSizeParam& param); + +void stride1_quint8_dot(const NCBKernParam& param); + +size_t get_workspace_in_bytes_stride1_quint8_dot(const NCBKernSizeParam& param); + +} // namespace deconv +} // namespace arm_common +} // namespace megdnn +#endif + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/convolution/quint8/conv_backdata_stride2.cpp b/dnn/src/arm_common/convolution/quint8/conv_backdata_stride2.cpp new file mode 100644 index 00000000..5db7ea06 --- /dev/null +++ b/dnn/src/arm_common/convolution/quint8/conv_backdata_stride2.cpp @@ -0,0 +1,1456 @@ +/** + * \file dnn/src/arm_common/convolution/quint8/conv_backdata_stride2.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#if __ARM_FEATURE_DOTPROD +#include "src/arm_common/convolution/quint8/conv_backdata_stride2.h" +#include "src/common/utils.h" + +#include +#include "src/arm_common/simd_macro/marm_neon.h" + +using namespace megdnn; +using namespace arm_common; +using namespace deconv; + +namespace { + +#define SHIFT_BITS 30 +#define SHIFT (1 << SHIFT_BITS) + +bool need_dst_copy(const NCBKernSizeParam& param) { + if (param.osz[1] % 4 != 0) { + // If the size of output is not multiples of 4, we need to copy it. + return true; + } + return false; +} + +void get_rectified_size(size_t IH, size_t IW, size_t OH, size_t OW, size_t FH, + size_t FW, size_t PH, size_t PW, size_t& IH2, + size_t& IW2, size_t& OW2) { + MEGDNN_MARK_USED_VAR(OH); + MEGDNN_MARK_USED_VAR(IW); + //! OW should be a multiple of 4 + OW2 = (OW + 3) & ~3; + IH2 = 2 * IH - 1 + 2 * (FH - PH - 1); + IW2 = (OW2 - FW + 2 * PW) / 2 + 1 + (FW - PW - 1) + 16; +} + +WorkspaceBundle get_bundle(const NCBKernSizeParam& param) { + UNPACK_CONV_F32_NCB_KERN_SIZES(param); + MEGDNN_MARK_USED_VAR(N); + MEGDNN_MARK_USED_VAR(OC); + MEGDNN_MARK_USED_VAR(SH); + MEGDNN_MARK_USED_VAR(SW); + size_t src_size = 0, dst_size = 0; + size_t IH2, IW2, OW2; + get_rectified_size(IH, IW, OH, OW, FH, FW, PH, PW, IH2, IW2, OW2); + //! 
We only need one src plane + src_size = sizeof(uint8_t) * IH2 * IW2; + if (need_dst_copy(param)) { + dst_size = sizeof(int32_t) * IC * OH * OW2; + } + return WorkspaceBundle(nullptr, {src_size, dst_size}); +} + +inline uint8x16_t vqtbl1q_u8_common(uint8x16_t a, uint8x16_t index) { + uint8x8x2_t src; + src.val[0] = vget_low_u8(a); + src.val[1] = vget_high_u8(a); + uint8x8_t r00 = vtbl2_u8(src, vget_low_u8(index)); + uint8x8_t r01 = vtbl2_u8(src, vget_high_u8(index)); + uint8x16_t r = vcombine_u8(r00, r01); + return r; +} + +inline uint8x16_t vqtbx1q_u8_common(uint8x16_t a, uint8x16_t t, + uint8x16_t idx) { + uint8x8x2_t _temp; + _temp.val[0] = vget_low_u8(t); + _temp.val[1] = vget_high_u8(t); + uint8x8_t a_low = vtbx2_u8(vget_low_u8(a), _temp, vget_low_u8(idx)); + uint8x8_t a_high = vtbx2_u8(vget_high_u8(a), _temp, vget_high_u8(idx)); + uint8x16_t r = vcombine_u8(a_low, a_high); + return r; +} + +#define CALC_DST(_sum) \ + _sum = vreinterpretq_u32_s32( \ + vaddq_s32(vreinterpretq_s32_u32(_sum), _shift_zp)); + +#define CALC_0(_k_idx, _c_idx) \ + _elem = vqtbl1q_u8_common(_tmp, _idx##_c_idx); \ + _sum0##_c_idx = vdotq_u32(_sum0##_c_idx, _k##_k_idx, _elem); \ + _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, vdotq2_u32(_src_zp, _k##_k_idx)); \ + _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, vdotq2_u32(_filter_zp, _elem)); + +#define CALC_1(_k_idx, _c_idx) \ + _elem = vqtbl1q_u8_common(_tmp, _idx##_c_idx); \ + _sum1##_c_idx = vdotq_u32(_sum1##_c_idx, _k##_k_idx, _elem); \ + _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_src_zp, _k##_k_idx)); \ + _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_filter_zp, _elem)); + +#define CALC_2(_k1_idx, _k2_idx, _c_idx) \ + _elem = vqtbl1q_u8_common(_tmp, _idx##_c_idx); \ + _sum0##_c_idx = vdotq_u32(_sum0##_c_idx, _k##_k1_idx, _elem); \ + _sum0##_c_idx = \ + vsubq_u32(_sum0##_c_idx, vdotq2_u32(_src_zp, _k##_k1_idx)); \ + _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, vdotq2_u32(_filter_zp, _elem)); \ + _sum1##_c_idx = vdotq_u32(_sum1##_c_idx, _k##_k2_idx, _elem); \ + _sum1##_c_idx = \ + vsubq_u32(_sum1##_c_idx, vdotq2_u32(_src_zp, _k##_k2_idx)); \ + _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_filter_zp, _elem)); + +template +void deconv_direct_2x2(const uint8_t* src, const uint8_t* filter, int32_t* dst, + size_t IH, size_t IW, size_t OH, size_t OW, size_t IC, + uint8_t src_zp, uint8_t filter_zp, + int32_t src_filter_zp) { + MEGDNN_MARK_USED_VAR(IH); + MEGDNN_MARK_USED_VAR(IC); + const size_t tail_step = IW - OW / 2; + + uint8x16_t _src_zp = vdupq_n_u8(src_zp); + uint8x16_t _filter_zp = vdupq_n_u8(filter_zp); + int32x4_t _shift_zp = vdupq_n_s32(src_filter_zp - SHIFT); + + const uint8x16_t _idx0 = {0, 1, 16, 16, 1, 2, 16, 16, + 2, 3, 16, 16, 3, 4, 16, 16}; + const uint8x16_t _idx1 = {4, 5, 16, 16, 5, 6, 16, 16, + 6, 7, 16, 16, 7, 8, 16, 16}; + uint8x16_t _idx_r_0, _idx_r_1; + if (even) { + _idx_r_0 = {0, 16, 1, 16, 2, 16, 3, 16, 4, 16, 5, 16, 6, 16, 7, 16}; + _idx_r_1 = {16, 1, 16, 2, 16, 3, 16, 4, 16, 5, 16, 6, 16, 7, 16, 8}; + } else { + _idx_r_0 = {16, 0, 16, 1, 16, 2, 16, 3, 16, 4, 16, 5, 16, 6, 16, 7}; + _idx_r_1 = {0, 16, 1, 16, 2, 16, 3, 16, 4, 16, 5, 16, 6, 16, 7, 16}; + } + const uint8_t* src_ptr = src; + //! 
here we use uint32_t for calc + uint32_t* outptr = reinterpret_cast(dst); + uint32_t* outptr2 = outptr + OW; + + const uint8_t* r0 = src_ptr; + const uint8_t* r1 = src_ptr + IW; + const uint8_t* r2 = src_ptr + 2 * IW; + + const uint8_t* k0 = filter; + + uint8x16_t _k0 = vreinterpretq_u8_u32( + vdupq_n_u32(*reinterpret_cast(k0))); + uint8x16_t _idx_k = {3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0}; + uint8x16_t _k = vqtbl1q_u8_common(_k0, _idx_k); + uint8x16_t _idx = {0, 1, 16, 16, 0, 1, 16, 16, 0, 1, 16, 16, 0, 1, 16, 16}; + uint8x16_t _k1 = vqtbl1q_u8_common(_k, _idx); + _idx = {2, 3, 16, 16, 2, 3, 16, 16, 2, 3, 16, 16, 2, 3, 16, 16}; + uint8x16_t _k23 = vqtbl1q_u8_common(_k, _idx); + +#define SUB_ZP(_sum, _r) \ + _sum = vdotq_u32(_sum, _k, _r); \ + _sum = vsubq_u32(_sum, vdotq2_u32(_src_zp, _k)); \ + _sum = vsubq_u32(_sum, vdotq2_u32(_filter_zp, _r)); + + uint8x16_t _tmp, _elem; + const int width = OW >> 2; + size_t h = 0; + for (; h + 1 < OH; h += 2) { + int w = 0; + for (; w + 4 < width; w += 4) { + uint32x4x2_t _sum00, _sum01, _sum10, _sum11; + _sum00 = vld2q_u32(outptr); + _sum01 = vld2q_u32(outptr + 8); + _sum10 = vld2q_u32(outptr2); + _sum11 = vld2q_u32(outptr2 + 8); + + uint8x16_t _r0_ori = vld1q_u8(r0); + uint8x16_t _r00 = vqtbx1q_u8_common(_src_zp, _r0_ori, _idx_r_0); + uint8x16_t _r01 = vqtbx1q_u8_common(_src_zp, _r0_ori, _idx_r_1); + uint8x16_t _r1_ori = vld1q_u8(r1); + uint8x16_t _r10 = vqtbx1q_u8_common(_src_zp, _r1_ori, _idx_r_0); + uint8x16_t _r11 = vqtbx1q_u8_common(_src_zp, _r1_ori, _idx_r_1); + uint8x16_t _r2_ori = vld1q_u8(r2); + uint8x16_t _r20 = vqtbx1q_u8_common(_src_zp, _r2_ori, _idx_r_0); + uint8x16_t _r21 = vqtbx1q_u8_common(_src_zp, _r2_ori, _idx_r_1); + + int16x8x2_t r_0 = vzipq_s16(vreinterpretq_s16_u8(_r00), + vreinterpretq_s16_u8(_r10)); + uint8x16_t _r0 = vreinterpretq_u8_s8(r_0.val[0]); + uint8x16_t _r2 = vreinterpretq_u8_s8(r_0.val[1]); + + int16x8x2_t r_1 = vzipq_s16(vreinterpretq_s16_u8(_r01), + vreinterpretq_s16_u8(_r11)); + uint8x16_t _r1 = vreinterpretq_u8_s8(r_1.val[0]); + uint8x16_t _r3 = vreinterpretq_u8_s8(r_1.val[1]); + + SUB_ZP(_sum00.val[0], _r0); + SUB_ZP(_sum00.val[1], _r1); + SUB_ZP(_sum01.val[0], _r2); + SUB_ZP(_sum01.val[1], _r3); + + r_0 = vzipq_s16(vreinterpretq_s16_u8(_r10), + vreinterpretq_s16_u8(_r20)); + _r0 = vreinterpretq_u8_s8(r_0.val[0]); + _r2 = vreinterpretq_u8_s8(r_0.val[1]); + + r_1 = vzipq_s16(vreinterpretq_s16_u8(_r11), + vreinterpretq_s16_u8(_r21)); + _r1 = vreinterpretq_u8_s8(r_1.val[0]); + _r3 = vreinterpretq_u8_s8(r_1.val[1]); + + SUB_ZP(_sum10.val[0], _r0); + SUB_ZP(_sum10.val[1], _r1); + SUB_ZP(_sum11.val[0], _r2); + SUB_ZP(_sum11.val[1], _r3); + + if (last_oc) { + CALC_DST(_sum00.val[0]); + CALC_DST(_sum00.val[1]); + CALC_DST(_sum01.val[0]); + CALC_DST(_sum01.val[1]); + CALC_DST(_sum10.val[0]); + CALC_DST(_sum10.val[1]); + CALC_DST(_sum11.val[0]); + CALC_DST(_sum11.val[1]); + } + vst2q_u32(outptr, _sum00); + vst2q_u32(outptr + 8, _sum01); + vst2q_u32(outptr2, _sum10); + vst2q_u32(outptr2 + 8, _sum11); + + r0 += 8; + r1 += 8; + r2 += 8; + outptr += 16; + outptr2 += 16; + } + for (; w + 2 < width; w += 2) { + uint32x4_t _sum00 = vld1q_u32(outptr); + uint32x4_t _sum01 = vld1q_u32(outptr + 4); + uint32x4_t _sum10 = vld1q_u32(outptr2); + uint32x4_t _sum11 = vld1q_u32(outptr2 + 4); + + uint8x16_t _r_ori = vld1q_u8(r0); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_0(1, 0); + CALC_0(1, 1); + + _r_ori = vld1q_u8(r1); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_2(23, 1, 0); + CALC_2(23, 
1, 1); + + _r_ori = vld1q_u8(r2); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_1(23, 0); + CALC_1(23, 1); + + if (last_oc) { + CALC_DST(_sum00); + CALC_DST(_sum01); + CALC_DST(_sum10); + CALC_DST(_sum11); + } + vst1q_u32(outptr, _sum00); + vst1q_u32(outptr + 4, _sum01); + vst1q_u32(outptr2, _sum10); + vst1q_u32(outptr2 + 4, _sum11); + + r0 += 4; + r1 += 4; + r2 += 4; + outptr += 8; + outptr2 += 8; + } + + for (; w < width; w++) { + uint32x4_t _sum00 = vld1q_u32(outptr); + uint32x4_t _sum10 = vld1q_u32(outptr2); + + uint8x16_t _r_ori = vtranslq_u8(vld1_u8(r0)); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_0(1, 0); + + _r_ori = vtranslq_u8(vld1_u8(r1)); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_2(23, 1, 0); + + _r_ori = vtranslq_u8(vld1_u8(r2)); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_1(23, 0); + + if (last_oc) { + CALC_DST(_sum00); + CALC_DST(_sum10); + } + vst1q_u32(outptr, _sum00); + vst1q_u32(outptr2, _sum10); + + r0 += 2; + r1 += 2; + r2 += 2; + outptr += 4; + outptr2 += 4; + } + r0 += tail_step + IW; + r1 += tail_step + IW; + r2 += tail_step + IW; + outptr += OW; + outptr2 += OW; + } + + for (; h < OH; h++) { + int w = 0; + for (; w + 4 < width; w += 4) { + uint32x4x2_t _sum0, _sum1; + _sum0 = vld2q_u32(outptr); + _sum1 = vld2q_u32(outptr + 8); + + uint8x16_t _r0_ori = vld1q_u8(r0); + uint8x16_t _r00 = vqtbx1q_u8_common(_src_zp, _r0_ori, _idx_r_0); + uint8x16_t _r01 = vqtbx1q_u8_common(_src_zp, _r0_ori, _idx_r_1); + uint8x16_t _r1_ori = vld1q_u8(r1); + uint8x16_t _r10 = vqtbx1q_u8_common(_src_zp, _r1_ori, _idx_r_0); + uint8x16_t _r11 = vqtbx1q_u8_common(_src_zp, _r1_ori, _idx_r_1); + + int16x8x2_t r_0 = vzipq_s16(vreinterpretq_s16_u8(_r00), + vreinterpretq_s16_u8(_r10)); + uint8x16_t _r0 = vreinterpretq_u8_s8(r_0.val[0]); + uint8x16_t _r2 = vreinterpretq_u8_s8(r_0.val[1]); + + int16x8x2_t r_1 = vzipq_s16(vreinterpretq_s16_u8(_r01), + vreinterpretq_s16_u8(_r11)); + uint8x16_t _r1 = vreinterpretq_u8_s8(r_1.val[0]); + uint8x16_t _r3 = vreinterpretq_u8_s8(r_1.val[1]); + + SUB_ZP(_sum0.val[0], _r0); + SUB_ZP(_sum0.val[1], _r1); + SUB_ZP(_sum1.val[0], _r2); + SUB_ZP(_sum1.val[1], _r3); + + if (last_oc) { + CALC_DST(_sum0.val[0]); + CALC_DST(_sum0.val[1]); + CALC_DST(_sum1.val[0]); + CALC_DST(_sum1.val[1]); + } + vst2q_u32(outptr, _sum0); + vst2q_u32(outptr + 8, _sum1); + + r0 += 8; + r1 += 8; + outptr += 16; + } + for (; w + 2 < width; w += 2) { + uint32x4_t _sum00 = vld1q_u32(outptr); + uint32x4_t _sum01 = vld1q_u32(outptr + 4); + + uint8x16_t _r_ori = vld1q_u8(r0); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_0(1, 0); + CALC_0(1, 1); + + _r_ori = vld1q_u8(r1); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_0(23, 0); + CALC_0(23, 1); + + if (last_oc) { + CALC_DST(_sum00); + CALC_DST(_sum01); + } + vst1q_u32(outptr, _sum00); + vst1q_u32(outptr + 4, _sum01); + + r0 += 4; + r1 += 4; + outptr += 8; + } + + for (; w < width; w++) { + uint32x4_t _sum00 = vld1q_u32(outptr); + + uint8x16_t _r_ori = vtranslq_u8(vld1_u8(r0)); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_0(1, 0); + + _r_ori = vtranslq_u8(vld1_u8(r1)); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_0(23, 0); + + if (last_oc) { + CALC_DST(_sum00); + } + vst1q_u32(outptr, _sum00); + + r0 += 2; + r1 += 2; + outptr += 4; + } + r0 += tail_step; + r1 += tail_step; + } +#undef SUB_ZP +} + +template +void deconv_direct_3x3(const uint8_t* src, const uint8_t* filter, int32_t* dst, + size_t IH, size_t IW, 
size_t OH, size_t OW, size_t IC, + uint8_t src_zp, uint8_t filter_zp, + int32_t src_filter_zp) { + MEGDNN_MARK_USED_VAR(IH); + MEGDNN_MARK_USED_VAR(IC); + const size_t tail_step = IW - OW / 2; + + uint8x16_t _src_zp = vdupq_n_u8(src_zp); + uint8x16_t _filter_zp = vdupq_n_u8(filter_zp); + int32x4_t _shift_zp = vdupq_n_s32(src_filter_zp - SHIFT); + + const uint8x16_t _idx0 = {0, 1, 2, 16, 1, 2, 3, 16, + 2, 3, 4, 16, 3, 4, 5, 16}; + const uint8x16_t _idx1 = {4, 5, 6, 16, 5, 6, 7, 16, + 6, 7, 8, 16, 7, 8, 9, 16}; + const uint8x16_t _idx2 = {8, 9, 10, 16, 9, 10, 11, 16, + 10, 11, 12, 16, 11, 12, 13, 16}; + uint8x16_t _idx_r_0; + if (even) { + _idx_r_0 = {0, 16, 1, 16, 2, 16, 3, 16, 4, 16, 5, 16, 6, 16, 7, 16}; + } else { + _idx_r_0 = {16, 0, 16, 1, 16, 2, 16, 3, 16, 4, 16, 5, 16, 6, 16, 7}; + } + const uint8_t* src_ptr = src; + uint32_t* outptr = reinterpret_cast(dst); + uint32_t* outptr2 = outptr + OW; + + const uint8_t* r0 = src_ptr; + const uint8_t* r1 = src_ptr + IW; + const uint8_t* r2 = src_ptr + IW * 2; + const uint8_t* r3 = src_ptr + IW * 3; + + const uint8_t* k0 = filter; + + uint8x16_t _k_tmp = vcombine_u8(vld1_u8(k0), vdup_n_u8(k0[8])); + uint8x16_t _idx = {8, 7, 6, 16, 8, 7, 6, 16, 8, 7, 6, 16, 8, 7, 6, 16}; + uint8x16_t _k12 = vqtbl1q_u8_common(_k_tmp, _idx); + _idx = {5, 4, 3, 16, 5, 4, 3, 16, 5, 4, 3, 16, 5, 4, 3, 16}; + uint8x16_t _k345 = vqtbl1q_u8_common(_k_tmp, _idx); + _idx = {2, 1, 0, 16, 2, 1, 0, 16, 2, 1, 0, 16, 2, 1, 0, 16}; + uint8x16_t _k678 = vqtbl1q_u8_common(_k_tmp, _idx); + + uint8x16_t _tmp, _elem; + size_t h = 0; + for (; h + 1 < OH; h += 2) { + int width = OW >> 2; + + int w = 0; + for (; w + 3 < width; w += 3) { + //! As the inner kernel read 16 elements, and IW is times of 16 + uint32x4_t _sum00 = vld1q_u32(outptr); + uint32x4_t _sum01 = vld1q_u32(outptr + 4); + uint32x4_t _sum02 = vld1q_u32(outptr + 8); + uint32x4_t _sum10 = vld1q_u32(outptr2); + uint32x4_t _sum11 = vld1q_u32(outptr2 + 4); + uint32x4_t _sum12 = vld1q_u32(outptr2 + 8); + + uint8x16_t _r_ori = vld1q_u8(r0); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_0(12, 0); + CALC_0(12, 1); + CALC_0(12, 2); + + _r_ori = vld1q_u8(r1); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_2(345, 12, 0); + CALC_2(345, 12, 1); + CALC_2(345, 12, 2); + + _r_ori = vld1q_u8(r2); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_2(678, 345, 0); + CALC_2(678, 345, 1); + CALC_2(678, 345, 2); + + _r_ori = vld1q_u8(r3); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_1(678, 0); + CALC_1(678, 1); + CALC_1(678, 2); + + if (last_oc) { + CALC_DST(_sum00); + CALC_DST(_sum01); + CALC_DST(_sum02); + CALC_DST(_sum10); + CALC_DST(_sum11); + CALC_DST(_sum12); + } + vst1q_u32(outptr, _sum00); + vst1q_u32(outptr + 4, _sum01); + vst1q_u32(outptr + 8, _sum02); + vst1q_u32(outptr2, _sum10); + vst1q_u32(outptr2 + 4, _sum11); + vst1q_u32(outptr2 + 8, _sum12); + + r0 += 6; + r1 += 6; + r2 += 6; + r3 += 6; + outptr += 12; + outptr2 += 12; + } + for (; w < width; w++) { + uint32x4_t _sum00 = vld1q_u32(outptr); + uint32x4_t _sum10 = vld1q_u32(outptr2); + + uint8x16_t _r_ori = vtranslq_u8(vld1_u8(r0)); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_0(12, 0); + + _r_ori = vtranslq_u8(vld1_u8(r1)); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_2(345, 12, 0); + + _r_ori = vtranslq_u8(vld1_u8(r2)); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_2(678, 345, 0); + + _r_ori = vtranslq_u8(vld1_u8(r3)); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, 
_idx_r_0); + CALC_1(678, 0); + + if (last_oc) { + CALC_DST(_sum00); + CALC_DST(_sum10); + } + vst1q_u32(outptr, _sum00); + vst1q_u32(outptr2, _sum10); + + r0 += 2; + r1 += 2; + r2 += 2; + r3 += 2; + outptr += 4; + outptr2 += 4; + } + + r0 += tail_step + IW; + r1 += tail_step + IW; + r2 += tail_step + IW; + r3 += tail_step + IW; + + outptr += OW; + outptr2 += OW; + } + + for (; h < OH; h++) { + int width = OW >> 2; + + int w = 0; + for (; w + 3 < width; w += 3) { + uint32x4_t _sum00 = vld1q_u32(outptr); + uint32x4_t _sum01 = vld1q_u32(outptr + 4); + uint32x4_t _sum02 = vld1q_u32(outptr + 8); + + uint8x16_t _r_ori = vld1q_u8(r0); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_0(12, 0); + CALC_0(12, 1); + CALC_0(12, 2); + + _r_ori = vld1q_u8(r1); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_0(345, 0); + CALC_0(345, 1); + CALC_0(345, 2); + + _r_ori = vld1q_u8(r2); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_0(678, 0); + CALC_0(678, 1); + CALC_0(678, 2); + + if (last_oc) { + CALC_DST(_sum00); + CALC_DST(_sum01); + CALC_DST(_sum02); + } + vst1q_u32(outptr, _sum00); + vst1q_u32(outptr + 4, _sum01); + vst1q_u32(outptr + 8, _sum02); + + r0 += 6; + r1 += 6; + r2 += 6; + outptr += 12; + } + for (; w < width; w++) { + uint32x4_t _sum00 = vld1q_u32(outptr); + + uint8x16_t _r_ori = vtranslq_u8(vld1_u8(r0)); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_0(12, 0); + + _r_ori = vtranslq_u8(vld1_u8(r1)); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_0(345, 0); + + _r_ori = vtranslq_u8(vld1_u8(r2)); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_0(678, 0); + + if (last_oc) { + CALC_DST(_sum00); + } + vst1q_u32(outptr, _sum00); + + r0 += 2; + r1 += 2; + r2 += 2; + outptr += 4; + } + r0 += tail_step; + r1 += tail_step; + r2 += tail_step; + } +} + +#undef CALC_0 +#undef CALC_1 +#undef CALC_2 + +#define CALC_0(_k00_idx, _k01_idx, _c_idx) \ + _elem = vqtbl1q_u8_common(_tmp, _idx##_c_idx##0); \ + _sum0##_c_idx = vdotq_u32(_sum0##_c_idx, _k##_k00_idx, _elem); \ + _sum0##_c_idx = \ + vsubq_u32(_sum0##_c_idx, vdotq2_u32(_src_zp, _k##_k00_idx)); \ + _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, vdotq2_u32(_filter_zp, _elem)); \ + _elem = vqtbl1q_u8_common(_tmp, _idx##_c_idx##1); \ + _sum0##_c_idx = vdotq_u32(_sum0##_c_idx, _k##_k01_idx, _elem); \ + _sum0##_c_idx = \ + vsubq_u32(_sum0##_c_idx, vdotq2_u32(_src_zp, _k##_k01_idx)); \ + _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, vdotq2_u32(_filter_zp, _elem)); + +#define CALC_1(_k00_idx, _k01_idx, _c_idx) \ + _elem = vqtbl1q_u8_common(_tmp, _idx##_c_idx##0); \ + _sum1##_c_idx = vdotq_u32(_sum1##_c_idx, _k##_k00_idx, _elem); \ + _sum1##_c_idx = \ + vsubq_u32(_sum1##_c_idx, vdotq2_u32(_src_zp, _k##_k00_idx)); \ + _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_filter_zp, _elem)); \ + _elem = vqtbl1q_u8_common(_tmp, _idx##_c_idx##1); \ + _sum1##_c_idx = vdotq_u32(_sum1##_c_idx, _k##_k01_idx, _elem); \ + _sum1##_c_idx = \ + vsubq_u32(_sum1##_c_idx, vdotq2_u32(_src_zp, _k##_k01_idx)); \ + _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_filter_zp, _elem)); + +#define CALC_2(_k00_idx, _k01_idx, _k10_idx, _k11_idx, _c_idx) \ + _elem = vqtbl1q_u8_common(_tmp, _idx##_c_idx##0); \ + _sum0##_c_idx = vdotq_u32(_sum0##_c_idx, _k##_k00_idx, _elem); \ + _sum0##_c_idx = \ + vsubq_u32(_sum0##_c_idx, vdotq2_u32(_src_zp, _k##_k00_idx)); \ + _elem2 = vdotq2_u32(_filter_zp, _elem); \ + _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, _elem2); \ + _sum1##_c_idx = vdotq_u32(_sum1##_c_idx, _k##_k10_idx, 
_elem); \ + _sum1##_c_idx = \ + vsubq_u32(_sum1##_c_idx, vdotq2_u32(_src_zp, _k##_k10_idx)); \ + _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, _elem2); \ + _elem = vqtbl1q_u8_common(_tmp, _idx##_c_idx##1); \ + _sum0##_c_idx = vdotq_u32(_sum0##_c_idx, _k##_k01_idx, _elem); \ + _sum0##_c_idx = \ + vsubq_u32(_sum0##_c_idx, vdotq2_u32(_src_zp, _k##_k01_idx)); \ + _elem2 = vdotq2_u32(_filter_zp, _elem); \ + _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, _elem2); \ + _sum1##_c_idx = vdotq_u32(_sum1##_c_idx, _k##_k11_idx, _elem); \ + _sum1##_c_idx = \ + vsubq_u32(_sum1##_c_idx, vdotq2_u32(_src_zp, _k##_k11_idx)); \ + _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, _elem2); + +template +void deconv_direct_5x5(const uint8_t* src, const uint8_t* filter, int32_t* dst, + size_t IH, size_t IW, size_t OH, size_t OW, size_t IC, + uint8_t src_zp, uint8_t filter_zp, + int32_t src_filter_zp) { + MEGDNN_MARK_USED_VAR(IH); + MEGDNN_MARK_USED_VAR(IC); + const size_t tail_step = IW - OW / 2; + + uint8x16_t _src_zp = vdupq_n_u8(src_zp); + uint8x16_t _filter_zp = vdupq_n_u8(filter_zp); + int32x4_t _shift_zp = vdupq_n_s32(src_filter_zp - SHIFT); + + const uint8x16_t _idx00 = {0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6}; + const uint8x16_t _idx01 = {4, 16, 16, 16, 5, 16, 16, 16, + 6, 16, 16, 16, 7, 16, 16, 16}; + const uint8x16_t _idx10 = {4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10}; + const uint8x16_t _idx11 = {8, 16, 16, 16, 9, 16, 16, 16, + 10, 16, 16, 16, 11, 16, 16, 16}; + const uint8x16_t _idx20 = {8, 9, 10, 11, 9, 10, 11, 12, + 10, 11, 12, 13, 11, 12, 13, 14}; + const uint8x16_t _idx21 = {12, 16, 16, 16, 13, 16, 16, 16, + 14, 16, 16, 16, 15, 16, 16, 16}; + uint8x16_t _idx_r_0; + if (even) { + _idx_r_0 = {0, 16, 1, 16, 2, 16, 3, 16, 4, 16, 5, 16, 6, 16, 7, 16}; + } else { + _idx_r_0 = {16, 0, 16, 1, 16, 2, 16, 3, 16, 4, 16, 5, 16, 6, 16, 7}; + } + uint8x16_t _tmp, _elem; + uint32x4_t _elem2; + const uint8_t* src_ptr = src; + uint32_t* outptr = reinterpret_cast(dst); + uint32_t* outptr2 = outptr + OW; + + const uint8_t* r0 = src_ptr; + const uint8_t* r1 = src_ptr + IW; + const uint8_t* r2 = src_ptr + IW * 2; + const uint8_t* r3 = src_ptr + IW * 3; + const uint8_t* r4 = src_ptr + IW * 4; + const uint8_t* r5 = src_ptr + IW * 5; + + const uint8_t* k0 = filter; + + uint8x16_t _k = vld1q_u8(k0 + 9); + //! filter row 1 + uint8x16_t _idx = {15, 14, 13, 12, 15, 14, 13, 12, + 15, 14, 13, 12, 15, 14, 13, 12}; + uint8x16_t _k123 = vqtbl1q_u8_common(_k, _idx); + _idx = {11, 16, 16, 16, 11, 16, 16, 16, 11, 16, 16, 16, 11, 16, 16, 16}; + uint8x16_t _k4 = vqtbl1q_u8_common(_k, _idx); + //! filter row 2 + _idx = {10, 9, 8, 7, 10, 9, 8, 7, 10, 9, 8, 7, 10, 9, 8, 7}; + uint8x16_t _k5678 = vqtbl1q_u8_common(_k, _idx); + _idx = {6, 16, 16, 16, 6, 16, 16, 16, 6, 16, 16, 16, 6, 16, 16, 16}; + uint8x16_t _k9 = vqtbl1q_u8_common(_k, _idx); + //! filter row 3 + _idx = {5, 4, 3, 2, 5, 4, 3, 2, 5, 4, 3, 2, 5, 4, 3, 2}; + uint8x16_t _k10111213 = vqtbl1q_u8_common(_k, _idx); + _idx = {1, 16, 16, 16, 1, 16, 16, 16, 1, 16, 16, 16, 1, 16, 16, 16}; + uint8x16_t _k14 = vqtbl1q_u8_common(_k, _idx); + //! 9 10 11 12 -> 13 14 15 16 -> 17 18 19 20 -> 21 22 23 24 + _k = vld1q_u8(k0); + //! filter row 4 + _idx = {9, 8, 7, 6, 9, 8, 7, 6, 9, 8, 7, 6, 9, 8, 7, 6}; + uint8x16_t _k15161718 = vqtbl1q_u8_common(_k, _idx); + _idx = {5, 16, 16, 16, 5, 16, 16, 16, 5, 16, 16, 16, 5, 16, 16, 16}; + uint8x16_t _k19 = vqtbl1q_u8_common(_k, _idx); + //! 
filter row 5 + _idx = {4, 3, 2, 1, 4, 3, 2, 1, 4, 3, 2, 1, 4, 3, 2, 1}; + uint8x16_t _k20212223 = vqtbl1q_u8_common(_k, _idx); + _idx = {0, 16, 16, 16, 0, 16, 16, 16, 0, 16, 16, 16, 0, 16, 16, 16}; + uint8x16_t _k24 = vqtbl1q_u8_common(_k, _idx); + + const int width = OW >> 2; + size_t h = 0; + for (; h + 1 < OH; h += 2) { + int w = 0; + for (; w + 3 < width; w += 3) { + //! As the inner kernel read 16 elements, and IW is times of 16 + uint32x4_t _sum00 = vld1q_u32(outptr); + uint32x4_t _sum01 = vld1q_u32(outptr + 4); + uint32x4_t _sum02 = vld1q_u32(outptr + 8); + uint32x4_t _sum10 = vld1q_u32(outptr2); + uint32x4_t _sum11 = vld1q_u32(outptr2 + 4); + uint32x4_t _sum12 = vld1q_u32(outptr2 + 8); + + uint8x16_t _r_ori = vld1q_u8(r0); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_0(123, 4, 0); + CALC_0(123, 4, 1); + CALC_0(123, 4, 2); + + _r_ori = vld1q_u8(r1); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_2(5678, 9, 123, 4, 0); + CALC_2(5678, 9, 123, 4, 1); + CALC_2(5678, 9, 123, 4, 2); + + _r_ori = vld1q_u8(r2); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_2(10111213, 14, 5678, 9, 0); + CALC_2(10111213, 14, 5678, 9, 1); + CALC_2(10111213, 14, 5678, 9, 2); + + _r_ori = vld1q_u8(r3); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_2(15161718, 19, 10111213, 14, 0); + CALC_2(15161718, 19, 10111213, 14, 1); + CALC_2(15161718, 19, 10111213, 14, 2); + + _r_ori = vld1q_u8(r4); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_2(20212223, 24, 15161718, 19, 0); + CALC_2(20212223, 24, 15161718, 19, 1); + CALC_2(20212223, 24, 15161718, 19, 2); + + _r_ori = vld1q_u8(r5); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_1(20212223, 24, 0); + CALC_1(20212223, 24, 1); + CALC_1(20212223, 24, 2); + + if (last_oc) { + CALC_DST(_sum00); + CALC_DST(_sum01); + CALC_DST(_sum02); + CALC_DST(_sum10); + CALC_DST(_sum11); + CALC_DST(_sum12); + } + vst1q_u32(outptr, _sum00); + vst1q_u32(outptr + 4, _sum01); + vst1q_u32(outptr + 8, _sum02); + vst1q_u32(outptr2, _sum10); + vst1q_u32(outptr2 + 4, _sum11); + vst1q_u32(outptr2 + 8, _sum12); + + r0 += 6; + r1 += 6; + r2 += 6; + r3 += 6; + r4 += 6; + r5 += 6; + outptr += 12; + outptr2 += 12; + } + for (; w < width; w++) { + uint32x4_t _sum00 = vld1q_u32(outptr); + uint32x4_t _sum10 = vld1q_u32(outptr2); + + uint8x16_t _r_ori = vtranslq_u8(vld1_u8(r0)); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_0(123, 4, 0); + + _r_ori = vtranslq_u8(vld1_u8(r1)); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_2(5678, 9, 123, 4, 0); + + _r_ori = vtranslq_u8(vld1_u8(r2)); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_2(10111213, 14, 5678, 9, 0); + + _r_ori = vtranslq_u8(vld1_u8(r3)); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_2(15161718, 19, 10111213, 14, 0); + + _r_ori = vtranslq_u8(vld1_u8(r4)); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_2(20212223, 24, 15161718, 19, 0); + + _r_ori = vtranslq_u8(vld1_u8(r5)); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_1(20212223, 24, 0); + + if (last_oc) { + CALC_DST(_sum00); + CALC_DST(_sum10); + } + vst1q_u32(outptr, _sum00); + vst1q_u32(outptr2, _sum10); + + r0 += 2; + r1 += 2; + r2 += 2; + r3 += 2; + r4 += 2; + r5 += 2; + outptr += 4; + outptr2 += 4; + } + + r0 += tail_step + IW; + r1 += tail_step + IW; + r2 += tail_step + IW; + r3 += tail_step + IW; + r4 += tail_step + IW; + r5 += tail_step + IW; + + outptr += OW; + outptr2 += OW; + } + + for (; h < OH; h++) { + 
int w = 0; + for (; w + 3 < width; w += 3) { + uint32x4_t _sum00 = vld1q_u32(outptr); + uint32x4_t _sum01 = vld1q_u32(outptr + 4); + uint32x4_t _sum02 = vld1q_u32(outptr + 8); + + uint8x16_t _r_ori = vld1q_u8(r0); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_0(123, 4, 0); + CALC_0(123, 4, 1); + CALC_0(123, 4, 2); + + _r_ori = vld1q_u8(r1); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_0(5678, 9, 0); + CALC_0(5678, 9, 1); + CALC_0(5678, 9, 2); + + _r_ori = vld1q_u8(r2); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_0(10111213, 14, 0); + CALC_0(10111213, 14, 1); + CALC_0(10111213, 14, 2); + + _r_ori = vld1q_u8(r3); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_0(15161718, 19, 0); + CALC_0(15161718, 19, 1); + CALC_0(15161718, 19, 2); + + _r_ori = vld1q_u8(r4); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_0(20212223, 24, 0); + CALC_0(20212223, 24, 1); + CALC_0(20212223, 24, 2); + + if (last_oc) { + CALC_DST(_sum00); + CALC_DST(_sum01); + CALC_DST(_sum02); + } + vst1q_u32(outptr, _sum00); + vst1q_u32(outptr + 4, _sum01); + vst1q_u32(outptr + 8, _sum02); + + r0 += 6; + r1 += 6; + r2 += 6; + r3 += 6; + r4 += 6; + outptr += 12; + } + for (; w < width; w++) { + uint32x4_t _sum00 = vld1q_u32(outptr); + + uint8x16_t _r_ori = vtranslq_u8(vld1_u8(r0)); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_0(123, 4, 0); + + _r_ori = vtranslq_u8(vld1_u8(r1)); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_0(5678, 9, 0); + + _r_ori = vtranslq_u8(vld1_u8(r2)); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_0(10111213, 14, 0); + + _r_ori = vtranslq_u8(vld1_u8(r3)); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_0(15161718, 19, 0); + + _r_ori = vtranslq_u8(vld1_u8(r4)); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_0(20212223, 24, 0); + + if (last_oc) { + CALC_DST(_sum00); + } + vst1q_u32(outptr, _sum00); + + r0 += 2; + r1 += 2; + r2 += 2; + r3 += 2; + r4 += 2; + outptr += 4; + } + r0 += tail_step; + r1 += tail_step; + r2 += tail_step; + r3 += tail_step; + r4 += tail_step; + } +} + +template +void deconv_direct_7x7(const uint8_t* src, const uint8_t* filter, int32_t* dst, + size_t IH, size_t IW, size_t OH, size_t OW, size_t IC, + uint8_t src_zp, uint8_t filter_zp, + int32_t src_filter_zp) { + MEGDNN_MARK_USED_VAR(IH); + MEGDNN_MARK_USED_VAR(IC); + const size_t tail_step = IW - OW / 2; + + uint8x16_t _src_zp = vdupq_n_u8(src_zp); + uint8x16_t _filter_zp = vdupq_n_u8(filter_zp); + int32x4_t _shift_zp = vdupq_n_s32(src_filter_zp - SHIFT); + + const uint8x16_t _idx00 = {0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6}; + const uint8x16_t _idx01 = {4, 5, 6, 16, 5, 6, 7, 16, + 6, 7, 8, 16, 7, 8, 9, 16}; + const uint8x16_t _idx10 = {4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10}; + const uint8x16_t _idx11 = {8, 9, 10, 16, 9, 10, 11, 16, + 10, 11, 12, 16, 11, 12, 13, 16}; + uint8x16_t _idx_r_0; + if (even) { + _idx_r_0 = {0, 16, 1, 16, 2, 16, 3, 16, 4, 16, 5, 16, 6, 16, 7, 16}; + } else { + _idx_r_0 = {16, 0, 16, 1, 16, 2, 16, 3, 16, 4, 16, 5, 16, 6, 16, 7}; + } + + uint8x16_t _tmp, _elem; + uint32x4_t _elem2; + const uint8_t* src_ptr = src; + uint32_t* outptr = reinterpret_cast(dst); + uint32_t* outptr2 = outptr + OW; + + const uint8_t* r0 = src_ptr; + const uint8_t* r1 = src_ptr + IW; + const uint8_t* r2 = src_ptr + IW * 2; + const uint8_t* r3 = src_ptr + IW * 3; + const uint8_t* r4 = src_ptr + IW * 4; + const uint8_t* r5 = src_ptr + IW * 5; + const uint8_t* r6 = 
src_ptr + IW * 6; + const uint8_t* r7 = src_ptr + IW * 7; + + const uint8_t* k0 = filter; + + uint8x16_t _k = vld1q_u8(k0 + 33); + //! filter row 1 + uint8x16_t _idx = {15, 14, 13, 12, 15, 14, 13, 12, + 15, 14, 13, 12, 15, 14, 13, 12}; + uint8x16_t _k123 = vqtbl1q_u8_common(_k, _idx); + _idx = {11, 10, 9, 16, 11, 10, 9, 16, 11, 10, 9, 16, 11, 10, 9, 16}; + uint8x16_t _k456 = vqtbl1q_u8_common(_k, _idx); + //! filter row 2 + _idx = {8, 7, 6, 5, 8, 7, 6, 5, 8, 7, 6, 5, 8, 7, 6, 5}; + uint8x16_t _k78910 = vqtbl1q_u8_common(_k, _idx); + _idx = {4, 3, 2, 16, 4, 3, 2, 16, 4, 3, 2, 16, 4, 3, 2, 16}; + uint8x16_t _k111213 = vqtbl1q_u8_common(_k, _idx); + + //! 12 13 14 15 -> 16 17 18 19 -> 20 21 22 23 -> 24 25 26 27 + _k = vld1q_u8(k0 + 19); + //! filter row 3 + _idx = {15, 14, 13, 12, 15, 14, 13, 12, 15, 14, 13, 12, 15, 14, 13, 12}; + uint8x16_t _k14151617 = vqtbl1q_u8_common(_k, _idx); + _idx = {11, 10, 9, 16, 11, 10, 9, 16, 11, 10, 9, 16, 11, 10, 9, 16}; + uint8x16_t _k181920 = vqtbl1q_u8_common(_k, _idx); + //! filter row 4 + _idx = {8, 7, 6, 5, 8, 7, 6, 5, 8, 7, 6, 5, 8, 7, 6, 5}; + uint8x16_t _k21222324 = vqtbl1q_u8_common(_k, _idx); + _idx = {4, 3, 2, 16, 4, 3, 2, 16, 4, 3, 2, 16, 4, 3, 2, 16}; + uint8x16_t _k252627 = vqtbl1q_u8_common(_k, _idx); + + //! 24 25 26 27->28 29 30 31 -> 32 33 34 35 -> 36 37 38 39 + _k = vld1q_u8(k0 + 5); + //! filter row 5 + _idx = {15, 14, 13, 12, 15, 14, 13, 12, 15, 14, 13, 12, 15, 14, 13, 12}; + uint8x16_t _k28293031 = vqtbl1q_u8_common(_k, _idx); + _idx = {11, 10, 9, 16, 11, 10, 9, 16, 11, 10, 9, 16, 11, 10, 9, 16}; + uint8x16_t _k323334 = vqtbl1q_u8_common(_k, _idx); + + //! 33 34 35 36 -> 37 38 39 40 -> 41 42 43 44 -> 45 46 47 48 + _k = vld1q_u8(k0); + //! filter row 6 + _idx = {13, 12, 11, 10, 13, 12, 11, 10, 13, 12, 11, 10, 13, 12, 11, 10}; + uint8x16_t _k35363738 = vqtbl1q_u8_common(_k, _idx); + _idx = {9, 8, 7, 16, 9, 8, 7, 16, 9, 8, 7, 16, 9, 8, 7, 16}; + uint8x16_t _k394041 = vqtbl1q_u8_common(_k, _idx); + + //! filter row 7 + _idx = {6, 5, 4, 3, 6, 5, 4, 3, 6, 5, 4, 3, 6, 5, 4, 3}; + uint8x16_t _k42434445 = vqtbl1q_u8_common(_k, _idx); + _idx = {2, 1, 0, 16, 2, 1, 0, 16, 2, 1, 0, 16, 2, 1, 0, 16}; + uint8x16_t _k464748 = vqtbl1q_u8_common(_k, _idx); + + const int width = OW >> 2; + size_t h = 0; + for (; h + 1 < OH; h += 2) { + int w = 0; + for (; w + 2 < width; w += 2) { + //! 
As the inner kernel read 16 elements, and IW is times of 16 + uint32x4_t _sum00 = vld1q_u32(outptr); + uint32x4_t _sum01 = vld1q_u32(outptr + 4); + uint32x4_t _sum10 = vld1q_u32(outptr2); + uint32x4_t _sum11 = vld1q_u32(outptr2 + 4); + + uint8x16_t _r_ori = vld1q_u8(r0); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_0(123, 456, 0); + CALC_0(123, 456, 1); + + _r_ori = vld1q_u8(r1); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_2(78910, 111213, 123, 456, 0); + CALC_2(78910, 111213, 123, 456, 1); + + _r_ori = vld1q_u8(r2); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_2(14151617, 181920, 78910, 111213, 0); + CALC_2(14151617, 181920, 78910, 111213, 1); + + _r_ori = vld1q_u8(r3); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_2(21222324, 252627, 14151617, 181920, 0); + CALC_2(21222324, 252627, 14151617, 181920, 1); + + _r_ori = vld1q_u8(r4); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_2(28293031, 323334, 21222324, 252627, 0); + CALC_2(28293031, 323334, 21222324, 252627, 1); + + _r_ori = vld1q_u8(r5); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_2(35363738, 394041, 28293031, 323334, 0); + CALC_2(35363738, 394041, 28293031, 323334, 1); + + _r_ori = vld1q_u8(r6); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_2(42434445, 464748, 35363738, 394041, 0); + CALC_2(42434445, 464748, 35363738, 394041, 1); + + _r_ori = vld1q_u8(r7); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_1(42434445, 464748, 0); + CALC_1(42434445, 464748, 1); + + if (last_oc) { + CALC_DST(_sum00); + CALC_DST(_sum01); + CALC_DST(_sum10); + CALC_DST(_sum11); + } + vst1q_u32(outptr, _sum00); + vst1q_u32(outptr + 4, _sum01); + vst1q_u32(outptr2, _sum10); + vst1q_u32(outptr2 + 4, _sum11); + + r0 += 4; + r1 += 4; + r2 += 4; + r3 += 4; + r4 += 4; + r5 += 4; + r6 += 4; + r7 += 4; + outptr += 8; + outptr2 += 8; + } + for (; w < width; w++) { + uint32x4_t _sum00 = vld1q_u32(outptr); + uint32x4_t _sum10 = vld1q_u32(outptr2); + + uint8x16_t _r_ori = vld1q_u8(r0); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_0(123, 456, 0); + + _r_ori = vld1q_u8(r1); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_2(78910, 111213, 123, 456, 0); + + _r_ori = vld1q_u8(r2); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_2(14151617, 181920, 78910, 111213, 0); + + _r_ori = vld1q_u8(r3); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_2(21222324, 252627, 14151617, 181920, 0); + + _r_ori = vld1q_u8(r4); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_2(28293031, 323334, 21222324, 252627, 0); + + _r_ori = vld1q_u8(r5); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_2(35363738, 394041, 28293031, 323334, 0); + + _r_ori = vld1q_u8(r6); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_2(42434445, 464748, 35363738, 394041, 0); + + _r_ori = vld1q_u8(r7); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_1(42434445, 464748, 0); + + if (last_oc) { + CALC_DST(_sum00); + CALC_DST(_sum10); + } + vst1q_u32(outptr, _sum00); + vst1q_u32(outptr2, _sum10); + + r0 += 2; + r1 += 2; + r2 += 2; + r3 += 2; + r4 += 2; + r5 += 2; + r6 += 2; + r7 += 2; + outptr += 4; + outptr2 += 4; + } + + r0 += tail_step + IW; + r1 += tail_step + IW; + r2 += tail_step + IW; + r3 += tail_step + IW; + r4 += tail_step + IW; + r5 += tail_step + IW; + r6 += tail_step + IW; + r7 += tail_step + IW; + + outptr += OW; + outptr2 += OW; + } + + for (; h < OH; h++) { + int w = 0; + 
for (; w + 2 < width; w += 2) { + uint32x4_t _sum00 = vld1q_u32(outptr); + uint32x4_t _sum01 = vld1q_u32(outptr + 4); + + uint8x16_t _r_ori = vld1q_u8(r0); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_0(123, 456, 0); + CALC_0(123, 456, 1); + + _r_ori = vld1q_u8(r1); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_0(78910, 111213, 0); + CALC_0(78910, 111213, 1); + + _r_ori = vld1q_u8(r2); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_0(14151617, 181920, 0); + CALC_0(14151617, 181920, 1); + + _r_ori = vld1q_u8(r3); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_0(21222324, 252627, 0); + CALC_0(21222324, 252627, 1); + + _r_ori = vld1q_u8(r4); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_0(28293031, 323334, 0); + CALC_0(28293031, 323334, 1); + + _r_ori = vld1q_u8(r5); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_0(35363738, 394041, 0); + CALC_0(35363738, 394041, 1); + + _r_ori = vld1q_u8(r6); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_0(42434445, 464748, 0); + CALC_0(42434445, 464748, 1); + + if (last_oc) { + CALC_DST(_sum00); + CALC_DST(_sum01); + } + vst1q_u32(outptr, _sum00); + vst1q_u32(outptr + 4, _sum01); + + r0 += 4; + r1 += 4; + r2 += 4; + r3 += 4; + r4 += 4; + r5 += 4; + r6 += 4; + outptr += 8; + } + for (; w < width; w++) { + uint32x4_t _sum00 = vld1q_u32(outptr); + + uint8x16_t _r_ori = vld1q_u8(r0); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_0(123, 456, 0); + + _r_ori = vld1q_u8(r1); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_0(78910, 111213, 0); + + _r_ori = vld1q_u8(r2); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_0(14151617, 181920, 0); + + _r_ori = vld1q_u8(r3); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_0(21222324, 252627, 0); + + _r_ori = vld1q_u8(r4); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_0(28293031, 323334, 0); + + _r_ori = vld1q_u8(r5); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_0(35363738, 394041, 0); + + _r_ori = vld1q_u8(r6); + _tmp = vqtbx1q_u8_common(_src_zp, _r_ori, _idx_r_0); + CALC_0(42434445, 464748, 0); + + if (last_oc) { + CALC_DST(_sum00); + } + vst1q_u32(outptr, _sum00); + + r0 += 2; + r1 += 2; + r2 += 2; + r3 += 2; + r4 += 2; + r5 += 2; + r6 += 2; + outptr += 4; + } + r0 += tail_step; + r1 += tail_step; + r2 += tail_step; + r3 += tail_step; + r4 += tail_step; + r5 += tail_step; + r6 += tail_step; + } +} + +#undef CALC_0 +#undef CALC_1 +#undef CALC_2 + +} // anonymous namespace + +size_t deconv::get_workspace_in_bytes_stride2_quint8_dot( + const NCBKernSizeParam& param) { + return get_bundle(param).total_size_in_bytes(); +} + +bool deconv::can_stride2_quint8_dot(const NCBKernSizeParam& param) { + auto&& fm = param.filter_meta; + auto FH = fm.spatial[0], FW = fm.spatial[1], OC = fm.ocpg, + PH = fm.padding[0], PW = fm.padding[1]; + bool avaiable = fm.format == param::Convolution::Format::NCHW && + !fm.should_flip && fm.spatial_ndim == 2 && + fm.dilation[0] == 1 && fm.dilation[1] == 1 && + fm.stride[0] == 2 && fm.stride[1] == 2 && FH == FW && + (FH == 2 || FH == 3 || FH == 5 || FH == 7) && + FH >= PH + 1 && FW >= PW + 1; + + /** + * \note In the kernel, we use uint32_t to calc the value, in order + * not generate negative number, we first initialize SHIFT and sub + * it to get the actual value in the last oc calc. 
+ * + * \warning the sum of dst value should not greater than SHIFT, + * otherwise it maybe error, but currently in mobile, it would not + * be possible(7*7*OC*2^8*2^8 > SHIFT => OC > 334). + */ + avaiable &= (7 * 7 * OC < (1 << (SHIFT_BITS - 8 - 8))); + return avaiable && ((FH == 2 && OC <= 4) || + ((FH == 3 || FH == 5 || FH == 7) && OC <= 8)); +} + +void deconv::stride2_quint8_dot(const NCBKernParam& param) { + auto bundle = get_bundle(param); + bundle.set(param.workspace_ptr); + UNPACK_CONV_F32_NCB_KERN_SIZES(param); + MEGDNN_MARK_USED_VAR(SH); + MEGDNN_MARK_USED_VAR(SW); + size_t IH2, IW2, OW2; + int padding_h = FH - PH - 1, padding_w = FW - PW - 1; + get_rectified_size(IH, IW, OH, OW, FH, FW, PH, PW, IH2, IW2, OW2); + + uint8_t filter_zp = + param.filter_type.param().zero_point; + uint8_t src_zp = param.diff_type.param().zero_point; + int32_t src_filter_zp = static_cast(filter_zp) * + static_cast(src_zp) * OC * FH * FH; + + using Func = std::function; + Func deconv = nullptr, deconv_last_oc = nullptr; + + switch (FH) { +#define cb(n) \ + case n: \ + do { \ + if ((padding_w & 1) == 0) { \ + deconv = deconv_direct_##n##x##n; \ + deconv_last_oc = deconv_direct_##n##x##n; \ + } else { \ + deconv = deconv_direct_##n##x##n; \ + deconv_last_oc = deconv_direct_##n##x##n; \ + } \ + } while (0); \ + break; + cb(2); + cb(3); + cb(5); + cb(7); +#undef cb + default: + megdnn_assert(0); + } + + bool need_dst_copy_var = need_dst_copy(param); + uint8_t* base_src_ptr = reinterpret_cast( + const_cast(param.diff())); + int32_t* base_dst_ptr = reinterpret_cast(param.grad()); + const uint8_t* fptr = + reinterpret_cast(param.filter()); + + for (size_t n = 0; n < N; ++n) { + int32_t* dptr_copied = static_cast(bundle.get(1)); + int32_t* dptr_ori = base_dst_ptr + n * param.out_bs; + int32_t* dptr = nullptr; + size_t OW_real = OW; + if (need_dst_copy_var) { + dptr = dptr_copied; + OW_real = OW2; + } else { + dptr = dptr_ori; + } + std::fill_n(dptr, IC * OH * OW_real, SHIFT); + + uint8_t* sptr_ori = base_src_ptr + n * param.inp_bs; + uint8_t* sptr_copied = static_cast(bundle.get(0)); + uint8_t* sptr = nullptr; + + rep(oc, OC) { + // copy sptr_ori to sptr_copied + std::memset(sptr_copied, src_zp, sizeof(uint8_t) * IH2 * IW2); + copy_plane_in_bytes(sptr_copied + padding_h * IW2 + padding_w / 2, + sptr_ori + oc * IH * IW, IH, + IW * sizeof(uint8_t), 2 * IW2 * sizeof(uint8_t), + IW * sizeof(uint8_t)); + sptr = sptr_copied; + + int32_t* dst_ptr = dptr; + const uint8_t* filter = fptr + oc * IC * FH * FW; + for (size_t ic = 0; ic < IC; ic++) { + if (oc != OC - 1) { + deconv(sptr, filter, dst_ptr, IH2, IW2, OH, OW_real, IC, + src_zp, filter_zp, src_filter_zp); + } else { + deconv_last_oc(sptr, filter, dst_ptr, IH2, IW2, OH, OW_real, + IC, src_zp, filter_zp, src_filter_zp); + } + dst_ptr += OH * OW_real; + filter += FH * FH; + } + } + if (need_dst_copy_var) { + for (size_t ic = 0; ic < IC; ++ic) { + copy_plane_in_bytes(dptr_ori + ic * OH * OW, + dptr + ic * OH * OW2, OH, + OW * sizeof(int32_t), OW * sizeof(int32_t), + OW2 * sizeof(int32_t)); + } + } + } +} + +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/convolution/quint8/conv_backdata_stride2.h b/dnn/src/arm_common/convolution/quint8/conv_backdata_stride2.h new file mode 100644 index 00000000..83732038 --- /dev/null +++ b/dnn/src/arm_common/convolution/quint8/conv_backdata_stride2.h @@ -0,0 +1,37 @@ +/** + * \file dnn/src/arm_common/convolution/quint8/conv_backdata_stride2.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the 
"License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once + +#if __ARM_FEATURE_DOTPROD +#include "src/arm_common/convolution/opr_impl.h" + +#include +#include + +namespace megdnn { +namespace arm_common { +namespace deconv { + +using NCBKernSizeParam = ConvolutionBackwardDataImpl::NCBKernSizeParam; +using NCBKernParam = ConvolutionBackwardDataImpl::NCBKernParam; + +bool can_stride2_quint8_dot(const NCBKernSizeParam& param); + +void stride2_quint8_dot(const NCBKernParam& param); + +size_t get_workspace_in_bytes_stride2_quint8_dot(const NCBKernSizeParam& param); + +} // namespace convolution +} // namespace arm_common +} // namespace megdnn +#endif + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/cvt_color/opr_impl.cpp b/dnn/src/arm_common/cvt_color/opr_impl.cpp new file mode 100644 index 00000000..fee255a1 --- /dev/null +++ b/dnn/src/arm_common/cvt_color/opr_impl.cpp @@ -0,0 +1,1666 @@ +/** + * By downloading, copying, installing or using the software you agree to this license. + * If you do not agree to this license, do not download, install, + * copy or use the software. + * + * + * License Agreement + * For Open Source Computer Vision Library + * (3-clause BSD License) + * + * Copyright (C) 2000-2020, Intel Corporation, all rights reserved. + * Copyright (C) 2009-2011, Willow Garage Inc., all rights reserved. + * Copyright (C) 2009-2016, NVIDIA Corporation, all rights reserved. + * Copyright (C) 2010-2013, Advanced Micro Devices, Inc., all rights reserved. + * Copyright (C) 2015-2016, OpenCV Foundation, all rights reserved. + * Copyright (C) 2015-2016, Itseez Inc., all rights reserved. + * Copyright (C) 2019-2020, Xperience AI, all rights reserved. + * Third party copyrights are property of their respective owners. + * + * Redistribution and use in source and binary forms, with or without modification, + * are permitted provided that the following conditions are met: + * + * * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * * Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * * Neither the names of the copyright holders nor the names of the contributors + * may be used to endorse or promote products derived from this software + * without specific prior written permission. + * + * This software is provided by the copyright holders and contributors "as is" and + * any express or implied warranties, including, but not limited to, the implied + * warranties of merchantability and fitness for a particular purpose are disclaimed. + * In no event shall copyright holders or contributors be liable for any direct, + * indirect, incidental, special, exemplary, or consequential damages + * (including, but not limited to, procurement of substitute goods or services; + * loss of use, data, or profits; or business interruption) however caused + * and on any theory of liability, whether in contract, strict liability, + * or tort (including negligence or otherwise) arising in any way out of + * the use of this software, even if advised of the possibility of such damage. 
+ * + * --------------------------------------------------------------------------- + * \file dnn/src/arm_common/cvt_color/opr_impl.cpp + * + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * + * This file has been modified by Megvii ("Megvii Modifications"). + * All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved. + * + * --------------------------------------------------------------------------- + */ +#include +#include "src/arm_common/cvt_color/opr_impl.h" +#include "src/arm_common/handle.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/cv/common.h" +#include "src/common/cv/cvt_color.h" +#include "src/common/cv/helper.h" +#include "src/common/utils.h" +#include "midout.h" + +MIDOUT_DECL(megdnn_arm_cvtcolor) +MIDOUT_DECL(megdnn_arm_cvtcolor_cases) +MIDOUT_DECL(megdnn_arm_cvt_bt601_yuv) + +namespace megdnn { +namespace arm_common { + +GENERATE_CVT_OPR_DECL_FOREACH(GENERATE_CVT_OPR_DECL) +GENERATE_UNSUPPORT_CVT_OPR_FOR_FLOAT(GENERATE_UNSUPPORT_CVT_OPR) + +using namespace megcv; + +namespace { +/** + * \brief yuv to rgb or bgr. + * + * \tparam rgb, is convert to rgb or bgr + * \tparam is_planar, if true, the layout is YYYYUUVV or YYYYVVUU, otherwise + * YYYYYUVUV or YYYYYVUVU + * \tparam is_uv, if true, U is before V, otherwise V is before U + */ +template +void cvt_yuv_transform(const Mat8u& src, Mat8u& dst) { + uint8x16_t v_y; + int32x4_t v_y_s32_0, v_y_s32_1, v_y_s32_2, v_y_s32_3; + uint8x8x2_t v_vu; + int32x4_t v_RV0, v_RV1, v_RV2, v_RV3; + int32x4_t v_GVU0, v_GVU1, v_GVU2, v_GVU3; + int32x4_t v_BU0, v_BU1, v_BU2, v_BU3; + + int32x4x4_t v_R; + int32x4x4_t v_G; + int32x4x4_t v_B; + uint8x16x3_t v_RGB, v_BGR; + + int16x8_t v_128; + v_128 = vdupq_n_s16(128); + + int16x4_t v_359, v_88, v_183, v_454; + v_359 = vdup_n_s16(359); + v_88 = vdup_n_s16(88); + v_183 = vdup_n_s16(183); + v_454 = vdup_n_s16(454); + size_t height = dst.rows(); + size_t width = dst.cols(); + int src_step = src.step(); + const unsigned char* pY = src.ptr(); + const unsigned char* pU; + const unsigned char* pV; + if (is_uv) { + pU = src.ptr(height); + //! only used if is_planar is false + pV = src.ptr(height + height / 4); + } else { + pV = src.ptr(height); + //! 
only used if is_planar is false + pU = src.ptr(height + height / 4); + } + +#define SET_COLOR(out, index) \ + if (rgb) { \ + out[index++] = R; \ + out[index++] = G; \ + out[index++] = B; \ + } else { \ + out[index++] = B; \ + out[index++] = G; \ + out[index++] = R; \ + } + + for (size_t r = 0; r < height; r += 2, pY += (src_step << 1)) { + unsigned char* dst0 = dst.ptr(r); + unsigned char* dst1 = dst.ptr(r + 1); + size_t index0 = 0; + size_t index1 = 0; + int c = 0; + for (; c <= (int)(width - 16); c += 16, index0 += 48, index1 += 48) { + int16x8x2_t v_vu_s16; + if (is_planar) { + v_vu_s16.val[0] = + vreinterpretq_s16_u16(vmovl_u8(vld1_u8(pV + c / 2))); + v_vu_s16.val[1] = + vreinterpretq_s16_u16(vmovl_u8(vld1_u8(pU + c / 2))); + } else { + if (is_uv) { + v_vu = vld2_u8(pU + c); + v_vu_s16.val[0] = + vreinterpretq_s16_u16(vmovl_u8(v_vu.val[1])); + v_vu_s16.val[1] = + vreinterpretq_s16_u16(vmovl_u8(v_vu.val[0])); + } else { + v_vu = vld2_u8(pV + c); + v_vu_s16.val[0] = + vreinterpretq_s16_u16(vmovl_u8(v_vu.val[0])); + v_vu_s16.val[1] = + vreinterpretq_s16_u16(vmovl_u8(v_vu.val[1])); + } + } + + v_vu_s16.val[0] = vsubq_s16(v_vu_s16.val[0], v_128); + v_vu_s16.val[1] = vsubq_s16(v_vu_s16.val[1], v_128); + + int16x4_t v_v0, v_u0; + int16x4_t v_v1, v_u1; + v_v0 = vget_low_s16(v_vu_s16.val[0]); + v_v1 = vget_high_s16(v_vu_s16.val[0]); + v_u0 = vget_low_s16(v_vu_s16.val[1]); + v_u1 = vget_high_s16(v_vu_s16.val[1]); + + v_RV1 = vshrq_n_s32(vmull_s16(v_v0, v_359), 8); + v_RV3 = vshrq_n_s32(vmull_s16(v_v1, v_359), 8); + v_GVU1 = vshrq_n_s32( + vaddq_s32(vmull_s16(v_u0, v_88), vmull_s16(v_v0, v_183)), + 8); + v_GVU3 = vshrq_n_s32( + vaddq_s32(vmull_s16(v_u1, v_88), vmull_s16(v_v1, v_183)), + 8); + v_BU1 = vshrq_n_s32(vmull_s16(v_u0, v_454), 8); + v_BU3 = vshrq_n_s32(vmull_s16(v_u1, v_454), 8); + + int32x4x2_t temp; + temp = vzipq_s32(v_RV1, v_RV1); + v_RV0 = temp.val[0]; + v_RV1 = temp.val[1]; + temp = vzipq_s32(v_RV3, v_RV3); + v_RV2 = temp.val[0]; + v_RV3 = temp.val[1]; + + temp = vzipq_s32(v_GVU1, v_GVU1); + v_GVU0 = temp.val[0]; + v_GVU1 = temp.val[1]; + temp = vzipq_s32(v_GVU3, v_GVU3); + v_GVU2 = temp.val[0]; + v_GVU3 = temp.val[1]; + + temp = vzipq_s32(v_BU1, v_BU1); + v_BU0 = temp.val[0]; + v_BU1 = temp.val[1]; + temp = vzipq_s32(v_BU3, v_BU3); + v_BU2 = temp.val[0]; + v_BU3 = temp.val[1]; + + v_y = vld1q_u8(pY + c); + uint8x8_t v_y_half; + v_y_half = vget_low_u8(v_y); + int16x8_t v_y_2quarter = vreinterpretq_s16_u16(vmovl_u8(v_y_half)); + v_y_s32_0 = vmovl_s16(vget_low_s16(v_y_2quarter)); + v_y_s32_1 = vmovl_s16(vget_high_s16(v_y_2quarter)); + + v_y_half = vget_high_u8(v_y); + v_y_2quarter = vreinterpretq_s16_u16(vmovl_u8(v_y_half)); + v_y_s32_2 = vmovl_s16(vget_low_s16(v_y_2quarter)); + v_y_s32_3 = vmovl_s16(vget_high_s16(v_y_2quarter)); + + v_R.val[0] = vaddq_s32(v_y_s32_0, v_RV0); + v_R.val[1] = vaddq_s32(v_y_s32_1, v_RV1); + v_R.val[2] = vaddq_s32(v_y_s32_2, v_RV2); + v_R.val[3] = vaddq_s32(v_y_s32_3, v_RV3); + + v_G.val[0] = vsubq_s32(v_y_s32_0, v_GVU0); + v_G.val[1] = vsubq_s32(v_y_s32_1, v_GVU1); + v_G.val[2] = vsubq_s32(v_y_s32_2, v_GVU2); + v_G.val[3] = vsubq_s32(v_y_s32_3, v_GVU3); + + v_B.val[0] = vaddq_s32(v_y_s32_0, v_BU0); + v_B.val[1] = vaddq_s32(v_y_s32_1, v_BU1); + v_B.val[2] = vaddq_s32(v_y_s32_2, v_BU2); + v_B.val[3] = vaddq_s32(v_y_s32_3, v_BU3); + + if (rgb) { + v_RGB.val[0] = vcombine_u8( + vqmovun_s16(vcombine_s16(vmovn_s32(v_R.val[0]), + vmovn_s32(v_R.val[1]))), + vqmovun_s16(vcombine_s16(vmovn_s32(v_R.val[2]), + vmovn_s32(v_R.val[3])))); + v_RGB.val[1] = vcombine_u8( 
+ vqmovun_s16(vcombine_s16(vmovn_s32(v_G.val[0]), + vmovn_s32(v_G.val[1]))), + vqmovun_s16(vcombine_s16(vmovn_s32(v_G.val[2]), + vmovn_s32(v_G.val[3])))); + v_RGB.val[2] = vcombine_u8( + vqmovun_s16(vcombine_s16(vmovn_s32(v_B.val[0]), + vmovn_s32(v_B.val[1]))), + vqmovun_s16(vcombine_s16(vmovn_s32(v_B.val[2]), + vmovn_s32(v_B.val[3])))); + + vst3q_u8((dst0 + c * 3), v_RGB); + } else { + v_BGR.val[0] = vcombine_u8( + vqmovun_s16(vcombine_s16(vmovn_s32(v_B.val[0]), + vmovn_s32(v_B.val[1]))), + vqmovun_s16(vcombine_s16(vmovn_s32(v_B.val[2]), + vmovn_s32(v_B.val[3])))); + v_BGR.val[1] = vcombine_u8( + vqmovun_s16(vcombine_s16(vmovn_s32(v_G.val[0]), + vmovn_s32(v_G.val[1]))), + vqmovun_s16(vcombine_s16(vmovn_s32(v_G.val[2]), + vmovn_s32(v_G.val[3])))); + v_BGR.val[2] = vcombine_u8( + vqmovun_s16(vcombine_s16(vmovn_s32(v_R.val[0]), + vmovn_s32(v_R.val[1]))), + vqmovun_s16(vcombine_s16(vmovn_s32(v_R.val[2]), + vmovn_s32(v_R.val[3])))); + vst3q_u8((dst0 + c * 3), v_BGR); + } + + v_y = vld1q_u8(pY + src_step + c); + v_y_half = vget_low_u8(v_y); + v_y_2quarter = vreinterpretq_s16_u16(vmovl_u8(v_y_half)); + v_y_s32_0 = vmovl_s16(vget_low_s16(v_y_2quarter)); + v_y_s32_1 = vmovl_s16(vget_high_s16(v_y_2quarter)); + + v_y_half = vget_high_u8(v_y); + v_y_2quarter = vreinterpretq_s16_u16(vmovl_u8(v_y_half)); + v_y_s32_2 = vmovl_s16(vget_low_s16(v_y_2quarter)); + v_y_s32_3 = vmovl_s16(vget_high_s16(v_y_2quarter)); + + v_R.val[0] = vaddq_s32(v_y_s32_0, v_RV0); + v_R.val[1] = vaddq_s32(v_y_s32_1, v_RV1); + v_R.val[2] = vaddq_s32(v_y_s32_2, v_RV2); + v_R.val[3] = vaddq_s32(v_y_s32_3, v_RV3); + + v_G.val[0] = vsubq_s32(v_y_s32_0, v_GVU0); + v_G.val[1] = vsubq_s32(v_y_s32_1, v_GVU1); + v_G.val[2] = vsubq_s32(v_y_s32_2, v_GVU2); + v_G.val[3] = vsubq_s32(v_y_s32_3, v_GVU3); + + v_B.val[0] = vaddq_s32(v_y_s32_0, v_BU0); + v_B.val[1] = vaddq_s32(v_y_s32_1, v_BU1); + v_B.val[2] = vaddq_s32(v_y_s32_2, v_BU2); + v_B.val[3] = vaddq_s32(v_y_s32_3, v_BU3); + + if (rgb) { + v_RGB.val[0] = vcombine_u8( + vqmovun_s16(vcombine_s16(vmovn_s32(v_R.val[0]), + vmovn_s32(v_R.val[1]))), + vqmovun_s16(vcombine_s16(vmovn_s32(v_R.val[2]), + vmovn_s32(v_R.val[3])))); + v_RGB.val[1] = vcombine_u8( + vqmovun_s16(vcombine_s16(vmovn_s32(v_G.val[0]), + vmovn_s32(v_G.val[1]))), + vqmovun_s16(vcombine_s16(vmovn_s32(v_G.val[2]), + vmovn_s32(v_G.val[3])))); + v_RGB.val[2] = vcombine_u8( + vqmovun_s16(vcombine_s16(vmovn_s32(v_B.val[0]), + vmovn_s32(v_B.val[1]))), + vqmovun_s16(vcombine_s16(vmovn_s32(v_B.val[2]), + vmovn_s32(v_B.val[3])))); + + vst3q_u8((dst1 + c * 3), v_RGB); + } else { + v_BGR.val[0] = vcombine_u8( + vqmovun_s16(vcombine_s16(vmovn_s32(v_B.val[0]), + vmovn_s32(v_B.val[1]))), + vqmovun_s16(vcombine_s16(vmovn_s32(v_B.val[2]), + vmovn_s32(v_B.val[3])))); + v_BGR.val[1] = vcombine_u8( + vqmovun_s16(vcombine_s16(vmovn_s32(v_G.val[0]), + vmovn_s32(v_G.val[1]))), + vqmovun_s16(vcombine_s16(vmovn_s32(v_G.val[2]), + vmovn_s32(v_G.val[3])))); + v_BGR.val[2] = vcombine_u8( + vqmovun_s16(vcombine_s16(vmovn_s32(v_R.val[0]), + vmovn_s32(v_R.val[1]))), + vqmovun_s16(vcombine_s16(vmovn_s32(v_R.val[2]), + vmovn_s32(v_R.val[3])))); + vst3q_u8((dst1 + c * 3), v_BGR); + } + } + + for (; c < (int)width; c += 2) { + int Y00, Y01, Y10, Y11, U, V; + int R, G, B; + Y00 = *((pY) + c); + Y01 = *((pY) + c + 1); + Y10 = *((pY) + src_step + c); + Y11 = *((pY) + src_step + c + 1); + if (is_planar) { + V = *(pV + c / 2); + U = *(pU + c / 2); + } else { + if (is_uv) { + U = *(pU + c); + V = *(pU + c + 1); + } else { + V = *(pV + c); + U = *(pV + c + 1); + } + } 
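+                //! 359, 88, 183 and 454 below are the full-range YCbCr->RGB
+                //! coefficients 1.403, 0.344, 0.714 and 1.773 scaled by 256
+                //! and rounded; the ">> 8" removes that scaling again. With
+                //! U = V = 128 every chroma offset is zero, so R = G = B = Y.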
+ + int ruv, guv, buv; + ruv = ((359 * (V - 128)) >> 8); + guv = -1 * ((88 * (U - 128) + 183 * (V - 128)) >> 8); + buv = ((454 * (U - 128)) >> 8); + + R = Y00 + ruv; + G = Y00 + guv; + B = Y00 + buv; + R = (R > 255) ? 255 : ((R < 0) ? 0 : R); + G = (G > 255) ? 255 : ((G < 0) ? 0 : G); + B = (B > 255) ? 255 : ((B < 0) ? 0 : B); + + SET_COLOR(dst0, index0); + + R = Y01 + ruv; + G = Y01 + guv; + B = Y01 + buv; + R = (R > 255) ? 255 : ((R < 0) ? 0 : R); + G = (G > 255) ? 255 : ((G < 0) ? 0 : G); + B = (B > 255) ? 255 : ((B < 0) ? 0 : B); + + SET_COLOR(dst0, index0); + + ruv = ((359 * (V - 128)) >> 8); + guv = -1 * ((88 * (U - 128) + 183 * (V - 128)) >> 8); + buv = ((454 * (U - 128)) >> 8); + R = Y10 + ruv; + G = Y10 + guv; + B = Y10 + buv; + R = (R > 255) ? 255 : ((R < 0) ? 0 : R); + G = (G > 255) ? 255 : ((G < 0) ? 0 : G); + B = (B > 255) ? 255 : ((B < 0) ? 0 : B); + + SET_COLOR(dst1, index1); + + R = Y11 + ruv; + G = Y11 + guv; + B = Y11 + buv; + R = (R > 255) ? 255 : ((R < 0) ? 0 : R); + G = (G > 255) ? 255 : ((G < 0) ? 0 : G); + B = (B > 255) ? 255 : ((B < 0) ? 0 : B); + + SET_COLOR(dst1, index1); + } + if (is_planar) { + pV += src_step / 2; + pU += src_step / 2; + } else { + if (is_uv) { + pU += src_step; + } else { + pV += src_step; + } + } + } +#undef SET_COLOR +} + +/** + * \brief real yuv to rgb or bgr. + * + * \tparam rgb, is convert to rgb or bgr + * \tparam is_planar, if true, the layout is YYYYUUVV or YYYYVVUU, otherwise + * YYYYYUVUV or YYYYYVUVU + * \tparam is_uv, if true, U is before V, otherwise V is before U + * + * \note it is BT.601 YUV to RGB reference, it refer to + * https://github.com/opencv/opencv/blob/1b53a4fccc1a61541b71340af9a04b59484ec2cf/modules/imgproc/src/opencl/color_yuv.cl#L253 + * R = (Y - 16) * 1.164 - (V - 128) * -1.596 + * G = (Y - 16) * 1.164 - (U - 128) * 0.391 - (V - 128) * 0.813 + * B = (Y - 16) * 1.164 - (U - 128) * -2.018 + * The Numerical approximations refers to libyuv + * implementation(https://github.com/lemenkov/libyuv/blob/7e936044d154b9fe159a67f9562e10b1ef1cb590/source/row_common.cc#L1002), + */ +template +void cvt_BT601_yuv_transform(const Mat8u& src, Mat8u& dst) { + typedef unsigned char uint8; + const uint8* pY; + const uint8* pU; + const uint8* pV; + +#define SET_COLOR(out, index) \ + if (rgb) { \ + out[index++] = R; \ + out[index++] = G; \ + out[index++] = B; \ + } else { \ + out[index++] = B; \ + out[index++] = G; \ + out[index++] = R; \ + } + +#define YG 18997 /* round(1.164 * 64 * 256 * 256 / 257) */ +#define YGB -1160 /* 1.164 * 64 * -16 + 64 / 2 */ + +// U and V contributions to R,G,B. +#define UB -128 /* max(-128, round(-2.018 * 64)) */ +#define UG 25 /* round(0.391 * 64) */ +#define VG 52 /* round(0.813 * 64) */ +#define VR -102 /* round(-1.596 * 64) */ + +// Bias values to subtract 16 from Y and 128 from U and V. 
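+// (Quick check of these constants: with Y1 = (Y * 0x0101 * YG) >> 16 and the
+// BB/BG/BR biases defined below, the per-pixel "(chroma term + Y1 + bias) >> 6"
+// formula maps Y = 16, U = V = 128 to (0, 0, 0) and Y = 235, U = V = 128 to
+// (255, 255, 255), i.e. the 16..235 video range is expanded to 0..255.)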
+#define BB (UB * 128 + YGB) +#define BG (UG * 128 + VG * 128 + YGB) +#define BR (VR * 128 + YGB) + + int32x4_t v_UB = vdupq_n_s32(UB); + int32x4_t v_YG = vdupq_n_s32(YG); + int32x4_t v_UG = vdupq_n_s32(UG); + int32x4_t v_VG = vdupq_n_s32(VG); + int32x4_t v_VR = vdupq_n_s32(VR); + int32x4_t v_BB = vdupq_n_s32(UB * 128 + YGB); + int32x4_t v_BG = vdupq_n_s32(UG * 128 + VG * 128 + YGB); + int32x4_t v_BR = vdupq_n_s32(VR * 128 + YGB); + int32x4_t v_0101 = vdupq_n_s32(0x0101); + + uint8x8x2_t v_vu; + int32x4x4_t v_R; + int32x4x4_t v_G; + int32x4x4_t v_B; + uint8x16x3_t v_RGB, v_BGR; + int32x4_t v_Y1; + + int width = dst.cols(); + int height = dst.rows(); + int src_step = src.step(); + pY = src.ptr(); + if (is_uv) { + pU = src.ptr(height); + pV = src.ptr(height + height / 4); + } else { + pV = src.ptr(height); + pU = src.ptr(height + height / 4); + } + for (int i = 0; i < height; i += 2, pY += src_step * 2) { + size_t index = 0; + size_t index1 = 0; + uint8* out = dst.ptr(i); + uint8* out1 = dst.ptr(i + 1); + int j = 0; + int jV = 0; + + for (; j <= (int)(width - 16); j += 16, index += 48, index1 += 48) { + int16x8x2_t v_vu_s16; + if (is_planar) { + v_vu_s16.val[0] = + vreinterpretq_s16_u16(vmovl_u8(vld1_u8(pV + jV))); + v_vu_s16.val[1] = + vreinterpretq_s16_u16(vmovl_u8(vld1_u8(pU + jV))); + jV += 8; + } else { + if (is_uv) { + v_vu = vld2_u8(pU + j); + v_vu_s16.val[0] = + vreinterpretq_s16_u16(vmovl_u8(v_vu.val[1])); + v_vu_s16.val[1] = + vreinterpretq_s16_u16(vmovl_u8(v_vu.val[0])); + } else { + v_vu = vld2_u8(pV + j); + v_vu_s16.val[0] = + vreinterpretq_s16_u16(vmovl_u8(v_vu.val[0])); + v_vu_s16.val[1] = + vreinterpretq_s16_u16(vmovl_u8(v_vu.val[1])); + } + } + + int32x4_t v_v0, v_u0; + int32x4_t v_v1, v_u1; + int32x4_t v_v2, v_u2; + int32x4_t v_v3, v_u3; + v_v0 = vmovl_s16(vget_low_s16(v_vu_s16.val[0])); + v_v2 = vmovl_s16(vget_high_s16(v_vu_s16.val[0])); + v_u0 = vmovl_s16(vget_low_s16(v_vu_s16.val[1])); + v_u2 = vmovl_s16(vget_high_s16(v_vu_s16.val[1])); + + //! zip the v0 to 0011/2233, as two y value share the shape u/v + int32x4x2_t temp; + temp = vzipq_s32(v_v0, v_v0); + v_v0 = temp.val[0]; + v_v1 = temp.val[1]; + temp = vzipq_s32(v_v2, v_v2); + v_v2 = temp.val[0]; + v_v3 = temp.val[1]; + + temp = vzipq_s32(v_u0, v_u0); + v_u0 = temp.val[0]; + v_u1 = temp.val[1]; + temp = vzipq_s32(v_u2, v_u2); + v_u2 = temp.val[0]; + v_u3 = temp.val[1]; + + uint8x16_t v_y = vld1q_u8(pY + j); + uint8x8_t v_y_half = vget_low_u8(v_y); + int16x8_t v_y_2quarter = vreinterpretq_s16_u16(vmovl_u8(v_y_half)); + int32x4_t v_y0 = vmovl_s16(vget_low_s16(v_y_2quarter)); + int32x4_t v_y1 = vmovl_s16(vget_high_s16(v_y_2quarter)); + v_y_half = vget_high_u8(v_y); + v_y_2quarter = vreinterpretq_s16_u16(vmovl_u8(v_y_half)); + int32x4_t v_y2 = vmovl_s16(vget_low_s16(v_y_2quarter)); + int32x4_t v_y3 = vmovl_s16(vget_high_s16(v_y_2quarter)); + + //! 
calc +#define CALC(_idx) \ + v_Y1 = vshrq_n_s32(vmulq_s32(vmulq_s32(v_y##_idx, v_0101), v_YG), 16); \ + v_B.val[_idx] = vshrq_n_s32( \ + vsubq_s32(vaddq_s32(v_Y1, v_BB), vmulq_s32(v_u##_idx, v_UB)), 6); \ + v_G.val[_idx] = \ + vshrq_n_s32(vsubq_s32(vaddq_s32(v_Y1, v_BG), \ + vaddq_s32(vmulq_s32(v_u##_idx, v_UG), \ + vmulq_s32(v_v##_idx, v_VG))), \ + 6); \ + v_R.val[_idx] = vshrq_n_s32( \ + vsubq_s32(vaddq_s32(v_Y1, v_BR), vmulq_s32(v_v##_idx, v_VR)), 6); + + CALC(0); + CALC(1); + CALC(2); + CALC(3); + + if (rgb) { + v_RGB.val[0] = vcombine_u8( + vqmovun_s16(vcombine_s16(vmovn_s32(v_R.val[0]), + vmovn_s32(v_R.val[1]))), + vqmovun_s16(vcombine_s16(vmovn_s32(v_R.val[2]), + vmovn_s32(v_R.val[3])))); + v_RGB.val[1] = vcombine_u8( + vqmovun_s16(vcombine_s16(vmovn_s32(v_G.val[0]), + vmovn_s32(v_G.val[1]))), + vqmovun_s16(vcombine_s16(vmovn_s32(v_G.val[2]), + vmovn_s32(v_G.val[3])))); + v_RGB.val[2] = vcombine_u8( + vqmovun_s16(vcombine_s16(vmovn_s32(v_B.val[0]), + vmovn_s32(v_B.val[1]))), + vqmovun_s16(vcombine_s16(vmovn_s32(v_B.val[2]), + vmovn_s32(v_B.val[3])))); + vst3q_u8((out + index), v_RGB); + } else { + v_BGR.val[0] = vcombine_u8( + vqmovun_s16(vcombine_s16(vmovn_s32(v_B.val[0]), + vmovn_s32(v_B.val[1]))), + vqmovun_s16(vcombine_s16(vmovn_s32(v_B.val[2]), + vmovn_s32(v_B.val[3])))); + v_BGR.val[1] = vcombine_u8( + vqmovun_s16(vcombine_s16(vmovn_s32(v_G.val[0]), + vmovn_s32(v_G.val[1]))), + vqmovun_s16(vcombine_s16(vmovn_s32(v_G.val[2]), + vmovn_s32(v_G.val[3])))); + v_BGR.val[2] = vcombine_u8( + vqmovun_s16(vcombine_s16(vmovn_s32(v_R.val[0]), + vmovn_s32(v_R.val[1]))), + vqmovun_s16(vcombine_s16(vmovn_s32(v_R.val[2]), + vmovn_s32(v_R.val[3])))); + vst3q_u8((out + index), v_BGR); + } + + v_y = vld1q_u8(pY + src_step + j); + v_y_half = vget_low_u8(v_y); + v_y_2quarter = vreinterpretq_s16_u16(vmovl_u8(v_y_half)); + v_y0 = vmovl_s16(vget_low_s16(v_y_2quarter)); + v_y1 = vmovl_s16(vget_high_s16(v_y_2quarter)); + v_y_half = vget_high_u8(v_y); + v_y_2quarter = vreinterpretq_s16_u16(vmovl_u8(v_y_half)); + v_y2 = vmovl_s16(vget_low_s16(v_y_2quarter)); + v_y3 = vmovl_s16(vget_high_s16(v_y_2quarter)); + + CALC(0); + CALC(1); + CALC(2); + CALC(3); + + if (rgb) { + v_RGB.val[0] = vcombine_u8( + vqmovun_s16(vcombine_s16(vmovn_s32(v_R.val[0]), + vmovn_s32(v_R.val[1]))), + vqmovun_s16(vcombine_s16(vmovn_s32(v_R.val[2]), + vmovn_s32(v_R.val[3])))); + v_RGB.val[1] = vcombine_u8( + vqmovun_s16(vcombine_s16(vmovn_s32(v_G.val[0]), + vmovn_s32(v_G.val[1]))), + vqmovun_s16(vcombine_s16(vmovn_s32(v_G.val[2]), + vmovn_s32(v_G.val[3])))); + v_RGB.val[2] = vcombine_u8( + vqmovun_s16(vcombine_s16(vmovn_s32(v_B.val[0]), + vmovn_s32(v_B.val[1]))), + vqmovun_s16(vcombine_s16(vmovn_s32(v_B.val[2]), + vmovn_s32(v_B.val[3])))); + vst3q_u8((out1 + index1), v_RGB); + } else { + v_BGR.val[0] = vcombine_u8( + vqmovun_s16(vcombine_s16(vmovn_s32(v_B.val[0]), + vmovn_s32(v_B.val[1]))), + vqmovun_s16(vcombine_s16(vmovn_s32(v_B.val[2]), + vmovn_s32(v_B.val[3])))); + v_BGR.val[1] = vcombine_u8( + vqmovun_s16(vcombine_s16(vmovn_s32(v_G.val[0]), + vmovn_s32(v_G.val[1]))), + vqmovun_s16(vcombine_s16(vmovn_s32(v_G.val[2]), + vmovn_s32(v_G.val[3])))); + v_BGR.val[2] = vcombine_u8( + vqmovun_s16(vcombine_s16(vmovn_s32(v_R.val[0]), + vmovn_s32(v_R.val[1]))), + vqmovun_s16(vcombine_s16(vmovn_s32(v_R.val[2]), + vmovn_s32(v_R.val[3])))); + vst3q_u8((out1 + index1), v_BGR); + } +#undef CALC + } + + for (; j < width; j += 2) { + int U = 0, V = 0, Y0 = 0; + if (is_planar) { + V = *(pV + jV); + U = *(pU + jV); + jV++; + } else { + if (is_uv) { + 
U = *(pU + j); + V = *(pU + j + 1); + } else { + V = *(pV + j); + U = *(pV + j + 1); + } + } + + Y0 = *((pY) + j); + uint32_t Y1 = static_cast(Y0 * 0x0101 * YG) >> 16; + uint8_t B = saturate_cast( + static_cast(-(U * UB) + Y1 + BB) >> 6); + uint8_t G = saturate_cast( + static_cast(-(U * UG + V * VG) + Y1 + BG) >> 6); + uint8_t R = saturate_cast( + static_cast(-(V * VR) + Y1 + BR) >> 6); + SET_COLOR(out, index) + + Y0 = *((pY) + j + 1); + Y1 = static_cast(Y0 * 0x0101 * YG) >> 16; + B = saturate_cast( + static_cast(-(U * UB) + Y1 + BB) >> 6); + G = saturate_cast( + static_cast(-(U * UG + V * VG) + Y1 + BG) >> 6); + R = saturate_cast( + static_cast(-(V * VR) + Y1 + BR) >> 6); + SET_COLOR(out, index) + + Y0 = *((pY) + src_step + j); + Y1 = static_cast(Y0 * 0x0101 * YG) >> 16; + B = saturate_cast( + static_cast(-(U * UB) + Y1 + BB) >> 6); + G = saturate_cast( + static_cast(-(U * UG + V * VG) + Y1 + BG) >> 6); + R = saturate_cast( + static_cast(-(V * VR) + Y1 + BR) >> 6); + SET_COLOR(out1, index1) + + Y0 = *((pY) + src_step + j + 1); + Y1 = static_cast(Y0 * 0x0101 * YG) >> 16; + B = saturate_cast( + static_cast(-(U * UB) + Y1 + BB) >> 6); + G = saturate_cast( + static_cast(-(U * UG + V * VG) + Y1 + BG) >> 6); + R = saturate_cast( + static_cast(-(V * VR) + Y1 + BR) >> 6); + SET_COLOR(out1, index1) + } + + if (is_planar) { + pV += src_step / 2; + pU += src_step / 2; + } else { + if (is_uv) { + pU += src_step; + } else { + pV += src_step; + } + } + } +#undef SET_COLOR +#undef BB +#undef BG +#undef BR +#undef YGB +#undef UB +#undef UG +#undef VG +#undef VR +#undef YG +} + +} // namespace + +void cvt_rgb2gray_32f_neon(const Mat32f& src, Mat32f& dst) { + static const float coef[] = {0.299f, 0.587f, 0.114f}; + // load coef into neon types + const float32x4_t v_cr(vdupq_n_f32(coef[0])), v_cg(vdupq_n_f32(coef[1])), + v_cb(vdupq_n_f32(coef[2])); + +#define EXPAND(offset) \ + v_src = vld3q_f32(psrc + offset * 3); \ + vst1q_f32(pdst + offset, \ + vmlaq_f32(vmlaq_f32(vmulq_f32(v_src.val[0], v_cr), v_src.val[1], \ + v_cg), \ + v_src.val[2], v_cb)); + for (size_t r = 0; r < src.rows(); ++r) { + const float* psrc = src.ptr(r); + float* pdst = dst.ptr(r); + + const float* pend = psrc + src.cols() * 3; + // pack 48 float at a time (16 pixels) + + for (; psrc <= pend - 16 * 3; psrc += 16 * 3, pdst += 16) { + float32x4x3_t v_src; + + EXPAND(0); + EXPAND(4); + EXPAND(8); + EXPAND(12); + } + // if more than 8 pixels left, do an extra pack + if (psrc <= pend - 8 * 3) { + float32x4x3_t v_src; + + EXPAND(0); + EXPAND(4); + + psrc += 8 * 3; + pdst += 8; + } + // if more than 4 pixels left, do an extra pack + if (psrc <= pend - 4 * 3) { + float32x4x3_t v_src; + + EXPAND(0); + + psrc += 4 * 3; + pdst += 4; + } + // loop over left pixels + for (; psrc < pend; psrc += 3, pdst += 1) { + *pdst = psrc[0] * coef[0] + psrc[1] * coef[1] + psrc[2] * coef[2]; + } + } +#undef EXPAND +} + +void cvt_rgb2yuv_8u_neon(const Mat8u& src, Mat8u& dst) { + const int yuv_shift = 14; + const int coeffs[] = {1868, 9617, 4899, 8061, 14369}; + const int delta = 128 << yuv_shift; + + const int C0 = coeffs[0], C1 = coeffs[1], C2 = coeffs[2], C3 = coeffs[3], + C4 = coeffs[4]; + + int16x4_t v_c0, v_c1, v_c2; + int32x4_t v_c3, v_c4, v_delta, v_delta2; + v_c0 = vdup_n_s16(coeffs[0]); + v_c1 = vdup_n_s16(coeffs[1]); + v_c2 = vdup_n_s16(coeffs[2]); + v_c3 = vdupq_n_s32(coeffs[3]); + v_c4 = vdupq_n_s32(coeffs[4]); + v_delta = vdupq_n_s32(128 << yuv_shift); + v_delta2 = vdupq_n_s32(1 << (yuv_shift - 1)); + + for (size_t r = 0; r < src.rows(); ++r) { + const 
uchar* psrc = src.ptr(r); + uchar* pdst = dst.ptr(r); + const uchar* const pend = psrc + src.cols() * 3; + + // pack 8 pixels (24 uchar) + for (; psrc <= pend - 8 * 3; psrc += 8 * 3, pdst += 8 * 3) { + uint8x8x3_t v_dst; + int16x8x3_t v_src16; + + uint8x8x3_t v_src = vld3_u8(psrc); + v_src16.val[0] = vreinterpretq_s16_u16(vmovl_u8(v_src.val[0])); + v_src16.val[1] = vreinterpretq_s16_u16(vmovl_u8(v_src.val[1])); + v_src16.val[2] = vreinterpretq_s16_u16(vmovl_u8(v_src.val[2])); + + int16x4x3_t v_src0; + v_src0.val[0] = vget_low_s16(v_src16.val[0]); + v_src0.val[1] = vget_low_s16(v_src16.val[1]); + v_src0.val[2] = vget_low_s16(v_src16.val[2]); + + int32x4_t v_Y0 = vmlal_s16(vmlal_s16(vmull_s16(v_src0.val[0], v_c0), + v_src0.val[1], v_c1), + v_src0.val[2], v_c2); + v_Y0 = vshrq_n_s32(vaddq_s32(v_Y0, v_delta2), yuv_shift); + int32x4_t v_Cr0 = vmlaq_s32( + v_delta, vsubq_s32(vmovl_s16(v_src0.val[0]), v_Y0), v_c3); + v_Cr0 = vshrq_n_s32(vaddq_s32(v_Cr0, v_delta2), yuv_shift); + int32x4_t v_Cb0 = vmlaq_s32( + v_delta, vsubq_s32(vmovl_s16(v_src0.val[2]), v_Y0), v_c4); + v_Cb0 = vshrq_n_s32(vaddq_s32(v_Cb0, v_delta2), yuv_shift); + + v_src0.val[0] = vget_high_s16(v_src16.val[0]); + v_src0.val[1] = vget_high_s16(v_src16.val[1]); + v_src0.val[2] = vget_high_s16(v_src16.val[2]); + + int32x4_t v_Y1 = vmlal_s16(vmlal_s16(vmull_s16(v_src0.val[0], v_c0), + v_src0.val[1], v_c1), + v_src0.val[2], v_c2); + v_Y1 = vshrq_n_s32(vaddq_s32(v_Y1, v_delta2), yuv_shift); + int32x4_t v_Cr1 = vmlaq_s32( + v_delta, vsubq_s32(vmovl_s16(v_src0.val[0]), v_Y1), v_c3); + v_Cr1 = vshrq_n_s32(vaddq_s32(v_Cr1, v_delta2), yuv_shift); + int32x4_t v_Cb1 = vmlaq_s32( + v_delta, vsubq_s32(vmovl_s16(v_src0.val[2]), v_Y1), v_c4); + v_Cb1 = vshrq_n_s32(vaddq_s32(v_Cb1, v_delta2), yuv_shift); + + v_dst.val[0] = vqmovun_s16( + vcombine_s16(vqmovn_s32(v_Y0), vqmovn_s32(v_Y1))); + v_dst.val[1] = vqmovun_s16( + vcombine_s16(vqmovn_s32(v_Cr0), vqmovn_s32(v_Cr1))); + v_dst.val[2] = vqmovun_s16( + vcombine_s16(vqmovn_s32(v_Cb0), vqmovn_s32(v_Cb1))); + + vst3_u8(pdst, v_dst); + } + for (; psrc < pend; psrc += 3, pdst += 3) { + int Y = descale(psrc[0] * C0 + psrc[1] * C1 + psrc[2] * C2, + yuv_shift); + int Cr = descale((psrc[0] - Y) * C3 + delta, yuv_shift); + int Cb = descale((psrc[2] - Y) * C4 + delta, yuv_shift); + pdst[0] = saturate_cast(Y); + pdst[1] = saturate_cast(Cr); + pdst[2] = saturate_cast(Cb); + } + } +} + +void cvt_rgb2yuv_32f_neon(const Mat32f& src, Mat32f& dst) { + const float coeffs[] = {0.114f, 0.587f, 0.299f, 0.492f, 0.877f}; + float32x4_t v_c0, v_c1, v_c2, v_c3, v_c4, v_delta; + const float C0 = coeffs[0], C1 = coeffs[1], C2 = coeffs[2], C3 = coeffs[3], + C4 = coeffs[4]; + const float delta = 0.5f; + v_c0 = vdupq_n_f32(coeffs[0]); + v_c1 = vdupq_n_f32(coeffs[1]); + v_c2 = vdupq_n_f32(coeffs[2]); + v_c3 = vdupq_n_f32(coeffs[3]); + v_c4 = vdupq_n_f32(coeffs[4]); + v_delta = vdupq_n_f32(0.5f); + + for (size_t r = 0; r < src.rows(); ++r) { + const float* psrc = src.ptr(r); + float* pdst = dst.ptr(r); + const float* const pend = psrc + src.cols() * 3; + + for (; psrc <= pend - 4 * 3; psrc += 4 * 3, pdst += 4 * 3) { + float32x4x3_t v_src = vld3q_f32(psrc), v_dst; + v_dst.val[0] = vmlaq_f32(vmlaq_f32(vmulq_f32(v_src.val[0], v_c0), + v_src.val[1], v_c1), + v_src.val[2], v_c2); + v_dst.val[1] = vmlaq_f32( + v_delta, vsubq_f32(v_src.val[0], v_dst.val[0]), v_c3); + v_dst.val[2] = vmlaq_f32( + v_delta, vsubq_f32(v_src.val[2], v_dst.val[0]), v_c4); + + vst3q_f32(pdst, v_dst); + } + for (; psrc < pend; psrc += 3, pdst += 3) { + float Y = 
psrc[0] * C0 + psrc[1] * C1 + psrc[2] * C2; + float Cr = (psrc[0] - Y) * C3 + delta; + float Cb = (psrc[2] - Y) * C4 + delta; + pdst[0] = Y; + pdst[1] = Cr; + pdst[2] = Cb; + } + } +} + +void cvt_yuv2rgb_8u_neon(const Mat8u& src, Mat8u& dst) { + static const int coeffs[] = {33292, -6472, -9519, 18678}; + const int C0 = coeffs[0], C1 = coeffs[1], C2 = coeffs[2], C3 = coeffs[3]; + const int yuv_shift = 14; + const int delta = 128; + + int32x4_t v_c0, v_c1, v_c2, v_c3, v_delta2; + int16x4_t v_delta; + v_c0 = vdupq_n_s32(coeffs[0]); + v_c1 = vdupq_n_s32(coeffs[1]); + v_c2 = vdupq_n_s32(coeffs[2]); + v_c3 = vdupq_n_s32(coeffs[3]); + v_delta = vdup_n_s16(128); + v_delta2 = vdupq_n_s32(1 << (yuv_shift - 1)); + for (size_t r = 0; r < src.rows(); ++r) { + const uchar* psrc = src.ptr(r); + uchar* pdst = dst.ptr(r); + const uchar* const pend = psrc + src.cols() * 3; + for (; psrc <= pend - 8 * 3; psrc += 8 * 3, pdst += 8 * 3) { + uint8x8x3_t v_src = vld3_u8(psrc); + int16x8x3_t v_src16; + v_src16.val[0] = vreinterpretq_s16_u16(vmovl_u8(v_src.val[0])); + v_src16.val[1] = vreinterpretq_s16_u16(vmovl_u8(v_src.val[1])); + v_src16.val[2] = vreinterpretq_s16_u16(vmovl_u8(v_src.val[2])); + + int16x4_t v_Y = vget_low_s16(v_src16.val[0]), + v_Cr = vget_low_s16(v_src16.val[1]), + v_Cb = vget_low_s16(v_src16.val[2]); + + int32x4_t v_b0 = vmulq_s32(v_c3, vsubl_s16(v_Cb, v_delta)); + v_b0 = vaddw_s16(vshrq_n_s32(vaddq_s32(v_b0, v_delta2), yuv_shift), + v_Y); + int32x4_t v_g0 = + vmlaq_s32(vmulq_s32(vsubl_s16(v_Cr, v_delta), v_c1), + vsubl_s16(v_Cb, v_delta), v_c2); + v_g0 = vaddw_s16(vshrq_n_s32(vaddq_s32(v_g0, v_delta2), yuv_shift), + v_Y); + int32x4_t v_r0 = vmulq_s32(v_c0, vsubl_s16(v_Cr, v_delta)); + v_r0 = vaddw_s16(vshrq_n_s32(vaddq_s32(v_r0, v_delta2), yuv_shift), + v_Y); + + v_Y = vget_high_s16(v_src16.val[0]); + v_Cr = vget_high_s16(v_src16.val[1]); + v_Cb = vget_high_s16(v_src16.val[2]); + + int32x4_t v_b1 = vmulq_s32(v_c3, vsubl_s16(v_Cb, v_delta)); + v_b1 = vaddw_s16(vshrq_n_s32(vaddq_s32(v_b1, v_delta2), yuv_shift), + v_Y); + int32x4_t v_g1 = + vmlaq_s32(vmulq_s32(vsubl_s16(v_Cr, v_delta), v_c1), + vsubl_s16(v_Cb, v_delta), v_c2); + v_g1 = vaddw_s16(vshrq_n_s32(vaddq_s32(v_g1, v_delta2), yuv_shift), + v_Y); + int32x4_t v_r1 = vmulq_s32(v_c0, vsubl_s16(v_Cr, v_delta)); + v_r1 = vaddw_s16(vshrq_n_s32(vaddq_s32(v_r1, v_delta2), yuv_shift), + v_Y); + + uint8x8_t v_b = + vqmovun_s16(vcombine_s16(vmovn_s32(v_b0), vmovn_s32(v_b1))); + uint8x8_t v_g = + vqmovun_s16(vcombine_s16(vmovn_s32(v_g0), vmovn_s32(v_g1))); + uint8x8_t v_r = + vqmovun_s16(vcombine_s16(vmovn_s32(v_r0), vmovn_s32(v_r1))); + + uint8x8x3_t v_dst; + v_dst.val[0] = v_r; + v_dst.val[1] = v_g; + v_dst.val[2] = v_b; + vst3_u8(pdst, v_dst); + } + for (; psrc < pend; psrc += 3, pdst += 3) { + uchar Y = psrc[0]; + uchar Cr = psrc[1]; + uchar Cb = psrc[2]; + + int b = Y + descale((Cb - delta) * C3, yuv_shift); + int g = Y + + descale((Cb - delta) * C2 + (Cr - delta) * C1, yuv_shift); + int r = Y + descale((Cr - delta) * C0, yuv_shift); + + pdst[0] = saturate_cast(r); + pdst[1] = saturate_cast(g); + pdst[2] = saturate_cast(b); + } + } +} + +void cvt_yuv2rgb_32f_neon(const Mat32f& src, Mat32f& dst) { + static const float coeffs[] = {2.032f, -0.395f, -0.581f, 1.140f}; + const float delta = 0.5f; + const float C0 = coeffs[0], C1 = coeffs[1], C2 = coeffs[2], C3 = coeffs[3]; + + float32x4_t v_c0, v_c1, v_c2, v_c3, v_delta; + v_c0 = vdupq_n_f32(coeffs[0]); + v_c1 = vdupq_n_f32(coeffs[1]); + v_c2 = vdupq_n_f32(coeffs[2]); + v_c3 = vdupq_n_f32(coeffs[3]); + 
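+    //! as in the 8-bit kernels, the chroma channels carry a fixed offset;
+    //! for normalized float data that offset is 0.5 rather than 128.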
v_delta = vdupq_n_f32(0.5f); + + for (size_t r = 0; r < src.rows(); ++r) { + const float* psrc = src.ptr(r); + float* pdst = dst.ptr(r); + const float* const pend = psrc + src.cols() * 3; + for (; psrc <= pend - 4 * 3; psrc += 4 * 3, pdst += 4 * 3) { + float32x4x3_t v_src = vld3q_f32(psrc), v_dst; + float32x4_t v_Y = v_src.val[0], v_Cr = v_src.val[1], + v_Cb = v_src.val[2]; + + v_dst.val[0] = vmlaq_f32(v_Y, vsubq_f32(v_Cr, v_delta), v_c0); + v_dst.val[1] = vaddq_f32( + vmlaq_f32(vmulq_f32(vsubq_f32(v_Cb, v_delta), v_c2), + vsubq_f32(v_Cr, v_delta), v_c1), + v_Y); + v_dst.val[2] = vmlaq_f32(v_Y, vsubq_f32(v_Cb, v_delta), v_c3); + + vst3q_f32(pdst, v_dst); + } + + for (; psrc < pend; psrc += 3, pdst += 3) { + float Y = psrc[0], Cr = psrc[1], Cb = psrc[2]; + + float b = Y + (Cb - delta) * C3; + float g = Y + (Cb - delta) * C2 + (Cr - delta) * C1; + float r = Y + (Cr - delta) * C0; + + pdst[0] = r; + pdst[1] = g; + pdst[2] = b; + } + } +} + +void cvt_rgba2rgb_8u_neon(const Mat8u& src, Mat8u& dst) { + for (size_t r = 0; r < src.rows(); ++r) { + const uchar* psrc = src.ptr(r); + uchar* pdst = dst.ptr(r); + const uchar* const pend = psrc + src.cols() * 4; + + for (; psrc <= pend - 64; pdst += 48, psrc += 64) { + uint8x16x4_t v_src = vld4q_u8(psrc); + uint8x16x3_t v_dst; + v_dst.val[0] = v_src.val[0]; + v_dst.val[1] = v_src.val[1]; + v_dst.val[2] = v_src.val[2]; + vst3q_u8(pdst, v_dst); + } + for (; psrc <= pend - 32; pdst += 24, psrc += 32) { + uint8x8x4_t v_src = vld4_u8(psrc); + uint8x8x3_t v_dst; + v_dst.val[0] = v_src.val[0]; + v_dst.val[1] = v_src.val[1]; + v_dst.val[2] = v_src.val[2]; + vst3_u8(pdst, v_dst); + } + for (; psrc < pend; pdst += 3, psrc += 4) { + uchar t0 = psrc[0], t1 = psrc[1], t2 = psrc[2]; + pdst[0] = t0; + pdst[1] = t1; + pdst[2] = t2; + } + } +} + +void cvt_rgba2bgr_8u_neon(const Mat8u& src, Mat8u& dst) { + for (size_t r = 0; r < src.rows(); ++r) { + const uchar* psrc = src.ptr(r); + uchar* pdst = dst.ptr(r); + const uchar* const pend = psrc + src.cols() * 4; + + for (; psrc <= pend - 64; pdst += 48, psrc += 64) { + uint8x16x4_t v_src = vld4q_u8(psrc); + uint8x16x3_t v_dst; + v_dst.val[0] = v_src.val[2]; + v_dst.val[1] = v_src.val[1]; + v_dst.val[2] = v_src.val[0]; + vst3q_u8(pdst, v_dst); + } + for (; psrc <= pend - 32; pdst += 24, psrc += 32) { + uint8x8x4_t v_src = vld4_u8(psrc); + uint8x8x3_t v_dst; + v_dst.val[0] = v_src.val[2]; + v_dst.val[1] = v_src.val[1]; + v_dst.val[2] = v_src.val[0]; + vst3_u8(pdst, v_dst); + } + for (; psrc < pend; pdst += 3, psrc += 4) { + uchar t0 = psrc[0], t1 = psrc[1], t2 = psrc[2]; + pdst[0] = t2; + pdst[1] = t1; + pdst[2] = t0; + } + } +} + +void cvt_rgb2bgr_8u_neon(const Mat8u& src, Mat8u& dst) { + for (size_t r = 0; r < src.rows(); ++r) { + const uchar* psrc = src.ptr(r); + uchar* pdst = dst.ptr(r); + const uchar* const pend = psrc + src.cols() * 3; + + for (; psrc <= pend - 48; pdst += 48, psrc += 48) { + uint8x16x3_t v_src = vld3q_u8(psrc), v_dst; + v_dst.val[0] = v_src.val[2]; + v_dst.val[1] = v_src.val[1]; + v_dst.val[2] = v_src.val[0]; + vst3q_u8(pdst, v_dst); + } + for (; psrc <= pend - 24; pdst += 24, psrc += 24) { + uint8x8x3_t v_src = vld3_u8(psrc), v_dst; + v_dst.val[0] = v_src.val[2]; + v_dst.val[1] = v_src.val[1]; + v_dst.val[2] = v_src.val[0]; + vst3_u8(pdst, v_dst); + } + for (; psrc < pend; pdst += 3, psrc += 3) { + uchar t0 = psrc[0], t1 = psrc[1], t2 = psrc[2]; + pdst[0] = t2; + pdst[1] = t1; + pdst[2] = t0; + } + } +} + +template <> +void cvt_rgb2gray(const Mat8u& src, Mat8u& dst) { + megdnn_assert(src.rows() == 
dst.rows()); + megdnn_assert(src.cols() == dst.cols()); + megdnn_assert(src.channels() == 3); + megdnn_assert(dst.channels() == 1); + + const int yuv_shift = 14, R2Y = 4899, G2Y = 9617, B2Y = 1868; + + const uchar* _src = src.ptr(); + uchar* _dst = dst.ptr(); + size_t rows = src.rows(); + size_t cols = src.cols(); + size_t src_step = src.step(); + size_t dst_step = dst.step(); + for (size_t r = 0; r < rows; ++r, _src += src_step, _dst += dst_step) { + const uchar* temp_src = _src; + uchar* temp_dst = _dst; + for (size_t c = 0; c < cols; ++c, temp_src += 3, temp_dst += 1) { + uchar x0 = temp_src[0]; + uchar x1 = temp_src[1]; + uchar x2 = temp_src[2]; + temp_dst[0] = + (x0 * R2Y + x1 * G2Y + x2 * B2Y + (1 << (yuv_shift - 1))) >> + yuv_shift; + } + } +} + +template <> +void cvt_rgb2gray(const Mat32f& src, Mat32f& dst) { + megdnn_assert(src.channels() == 3); + megdnn_assert(dst.channels() == 1); + megdnn_assert(src.rows() == dst.rows()); + megdnn_assert(src.cols() == dst.cols()); + + return cvt_rgb2gray_32f_neon(src, dst); +} + +// gray2rgb +template <> +void cvt_gray2rgb(const Mat8u& src, Mat8u& dst) { + megdnn_assert(src.rows() == dst.rows()); + megdnn_assert(src.cols() == dst.cols()); + megdnn_assert(src.channels() == 1); + megdnn_assert(dst.channels() == 3); + + for (size_t r = 0; r < src.rows(); ++r) { + const uchar* psrc = src.ptr(r); + uchar* pdst = dst.ptr(r); + const uchar* const pend = psrc + src.cols() * 1; + for (; psrc < pend; psrc += 1, pdst += 3) { + pdst[0] = pdst[1] = pdst[2] = psrc[0]; + } + } +} +template <> +void cvt_gray2rgb(const Mat32f& src, Mat32f& dst) { + megdnn_assert(src.rows() == dst.rows()); + megdnn_assert(src.cols() == dst.cols()); + megdnn_assert(src.channels() == 1); + megdnn_assert(dst.channels() == 3); + + for (size_t r = 0; r < src.rows(); ++r) { + const float* psrc = src.ptr(r); + float* pdst = dst.ptr(r); + const float* const pend = psrc + src.cols() * 1; + for (; psrc < pend; psrc += 1, pdst += 3) { + pdst[0] = pdst[1] = pdst[2] = psrc[0]; + } + } +} + +// rgb2yuv +template <> +void cvt_rgb2yuv(const Mat8u& src, Mat8u& dst) { + megdnn_assert(src.channels() == 3); + megdnn_assert(dst.channels() == 3); + megdnn_assert(src.rows() == dst.rows()); + megdnn_assert(src.cols() == dst.cols()); + + return cvt_rgb2yuv_8u_neon(src, dst); +} +template <> +void cvt_rgb2yuv(const Mat32f& src, Mat32f& dst) { + megdnn_assert(src.channels() == 3); + megdnn_assert(dst.channels() == 3); + megdnn_assert(src.rows() == dst.rows()); + megdnn_assert(src.cols() == dst.cols()); + + return cvt_rgb2yuv_32f_neon(src, dst); +} + +// yuv2rgb +template <> +void cvt_yuv2rgb(const Mat32f& src, Mat32f& dst) { + megdnn_assert(src.channels() == 3); + megdnn_assert(dst.channels() == 3); + megdnn_assert(src.rows() == dst.rows()); + megdnn_assert(src.cols() == dst.cols()); + + // turn on neon optimization wont improve + // return cvt_yuv2rgb_32f_neon(src, dst); + + const float coef[] = {2.032f, -0.395f, -0.581f, 1.140f}; + const float delta = 0.5f; + for (size_t r = 0; r < src.rows(); ++r) { + for (size_t c = 0; c < src.cols(); ++c) { + const float* v = &src.at(r, c, 0); + float Y = v[0]; + float Cr = v[1]; + float Cb = v[2]; + + float R = Y + (Cr - delta) * coef[0]; + float G = Y + (Cb - delta) * coef[2] + (Cr - delta) * coef[1]; + float B = Y + (Cb - delta) * coef[3]; + + float* target = &dst.at(r, c, 0); + target[0] = R; + target[1] = G; + target[2] = B; + } + } +} + +template <> +void cvt_yuv2rgb(const Mat8u& src, Mat8u& dst) { + megdnn_assert(src.channels() == 3); + 
megdnn_assert(dst.channels() == 3); + megdnn_assert(src.rows() == dst.rows()); + megdnn_assert(src.cols() == dst.cols()); + + return cvt_yuv2rgb_8u_neon(src, dst); +} + +template <> +void cvt_rgba2rgb(const Mat8u& src, Mat8u& dst) { + megdnn_assert(src.channels() == 4); + megdnn_assert(dst.channels() == 3); + megdnn_assert(src.rows() == dst.rows()); + megdnn_assert(src.cols() == dst.cols()); + + return cvt_rgba2rgb_8u_neon(src, dst); +} + +template <> +void cvt_rgba2bgr(const Mat8u& src, Mat8u& dst) { + megdnn_assert(src.channels() == 4); + megdnn_assert(dst.channels() == 3); + megdnn_assert(src.rows() == dst.rows()); + megdnn_assert(src.cols() == dst.cols()); + + return cvt_rgba2bgr_8u_neon(src, dst); +} + +template <> +void cvt_rgba2gray(const Mat8u& src, Mat8u& dst) { + megdnn_assert(src.channels() == 4); + megdnn_assert(dst.channels() == 1); + megdnn_assert(src.rows() == dst.rows()); + megdnn_assert(src.cols() == dst.cols()); + + const int yuv_shift = 14, R2Y = 4899, G2Y = 9617, B2Y = 1868; + + const uchar* _src = src.ptr(); + uchar* _dst = dst.ptr(); + size_t rows = src.rows(); + size_t cols = src.cols(); + size_t src_step = src.step(); + size_t dst_step = dst.step(); + for (size_t r = 0; r < rows; ++r, _src += src_step, _dst += dst_step) { + const uchar* temp_src = _src; + uchar* temp_dst = _dst; + for (size_t c = 0; c < cols; ++c, temp_src += 4, temp_dst += 1) { + uchar x0 = temp_src[0]; + uchar x1 = temp_src[1]; + uchar x2 = temp_src[2]; + temp_dst[0] = + (x0 * R2Y + x1 * G2Y + x2 * B2Y + (1 << (yuv_shift - 1))) >> + yuv_shift; + } + } +} + +template <> +void cvt_rgb2bgr(const Mat8u& src, Mat8u& dst) { + megdnn_assert(src.channels() == 3); + megdnn_assert(dst.channels() == 3); + megdnn_assert(src.rows() == dst.rows()); + megdnn_assert(src.cols() == dst.cols()); + + return cvt_rgb2bgr_8u_neon(src, dst); +} + +template <> +void cvt_bgr2gray(const Mat8u& src, Mat8u& dst) { + megdnn_assert(src.channels() == 3); + megdnn_assert(dst.channels() == 1); + megdnn_assert(src.rows() == dst.rows()); + megdnn_assert(src.cols() == dst.cols()); + + const int yuv_shift = 14, R2Y = 4899, G2Y = 9617, B2Y = 1868; + int tab[256 * 3]; + + int b = 0, g = 0, r = (1 << (yuv_shift - 1)); + for (int i = 0; i < 256; ++i, r += R2Y, g += G2Y, b += B2Y) { + tab[i] = r; + tab[i + 256] = g; + tab[i + 512] = b; + } + + const uchar* _src = src.ptr(); + uchar* _dst = dst.ptr(); + size_t rows = src.rows(); + size_t cols = src.cols(); + size_t src_step = src.step(); + size_t dst_step = dst.step(); + for (size_t r = 0; r < rows; ++r, _src += src_step, _dst += dst_step) { + const uchar* temp_src = _src; + uchar* temp_dst = _dst; + for (size_t c = 0; c < cols; ++c, temp_src += 3, temp_dst += 1) { + uchar x0 = temp_src[0]; + uchar x1 = temp_src[1]; + uchar x2 = temp_src[2]; + temp_dst[0] = + (tab[x2] + tab[x1 + 256] + tab[x0 + 512]) >> yuv_shift; + } + } +} + +template <> +void cvt_bgr2rgb(const Mat8u& src, Mat8u& dst) { + return cvt_rgb2bgr(src, dst); +} + +template <> +void cvt_yuv2gray_nv21(const Mat8u& src, Mat8u& dst) { + const uchar* _src = src.ptr(); + uchar* _dst = dst.ptr(); + size_t rows = dst.rows(); + size_t cols = dst.cols(); + size_t src_step = src.step(); + size_t dst_step = dst.step(); + for (size_t r = 0; r < rows; ++r, _src += src_step, _dst += dst_step) { + const uchar* temp_src = _src; + uchar* temp_dst = _dst; + for (size_t c = 0; c < cols; ++c, temp_src += 1, temp_dst += 1) { + temp_dst[0] = temp_src[0]; + } + } +} + +template <> +void cvt_yuv2rgb_nv21(const Mat8u& src, Mat8u& dst) { + return 
cvt_yuv_transform(src, dst); +} + +template <> +void cvt_yuv2bgr_nv21(const Mat8u& src, Mat8u& dst) { + return cvt_yuv_transform(src, dst); +} + +template <> +void cvt_yuv2rgb_nv12(const Mat8u& src, Mat8u& dst) { + return cvt_yuv_transform(src, dst); +} + +template <> +void cvt_yuv2bgr_nv12(const Mat8u& src, Mat8u& dst) { + return cvt_yuv_transform(src, dst); +} + +template <> +void cvt_yuv2rgb_yv12(const Mat8u& src, Mat8u& dst) { + return cvt_yuv_transform(src, dst); +} + +template <> +void cvt_yuv2bgr_yv12(const Mat8u& src, Mat8u& dst) { + return cvt_yuv_transform(src, dst); +} + +template <> +void cvt_yuv2rgb_yu12(const Mat8u& src, Mat8u& dst) { + return cvt_yuv_transform(src, dst); +} + +template <> +void cvt_yuv2bgr_yu12(const Mat8u& src, Mat8u& dst) { + return cvt_yuv_transform(src, dst); +} + +template +void cvt_bt601_yuv(const megcv::Mat& src, megcv::Mat& dst, + param::CvtColor::Mode mode) { + MEGDNN_MARK_USED_VAR(src); + MEGDNN_MARK_USED_VAR(dst); + MEGDNN_MARK_USED_VAR(mode); + megdnn_throw("Unsupport dtype for real yuv"); +} + +template <> +void cvt_bt601_yuv(const megcv::Mat& src, megcv::Mat& dst, + param::CvtColor::Mode mode) { + using Mode = param::CvtColor::Mode; + switch (mode) { + case Mode::BT601_YUV2RGB_NV21: + MIDOUT_BEGIN(megdnn_arm_cvt_bt601_yuv, midout_iv(0)) { + return cvt_BT601_yuv_transform(src, dst); + } + MIDOUT_END(); + case Mode::BT601_YUV2BGR_NV21: + MIDOUT_BEGIN(megdnn_arm_cvt_bt601_yuv, midout_iv(1)) { + return cvt_BT601_yuv_transform(src, dst); + } + MIDOUT_END(); + case Mode::BT601_YUV2RGB_NV12: + MIDOUT_BEGIN(megdnn_arm_cvt_bt601_yuv, midout_iv(2)) { + return cvt_BT601_yuv_transform(src, dst); + } + MIDOUT_END(); + case Mode::BT601_YUV2BGR_NV12: + MIDOUT_BEGIN(megdnn_arm_cvt_bt601_yuv, midout_iv(3)) { + return cvt_BT601_yuv_transform(src, dst); + } + MIDOUT_END(); + case Mode::BT601_YUV2RGB_YV12: + MIDOUT_BEGIN(megdnn_arm_cvt_bt601_yuv, midout_iv(4)) { + return cvt_BT601_yuv_transform(src, dst); + } + MIDOUT_END(); + case Mode::BT601_YUV2BGR_YV12: + MIDOUT_BEGIN(megdnn_arm_cvt_bt601_yuv, midout_iv(5)) { + return cvt_BT601_yuv_transform(src, dst); + } + MIDOUT_END(); + case Mode::BT601_YUV2RGB_YU12: + MIDOUT_BEGIN(megdnn_arm_cvt_bt601_yuv, midout_iv(6)) { + return cvt_BT601_yuv_transform(src, dst); + } + MIDOUT_END(); + case Mode::BT601_YUV2BGR_YU12: + MIDOUT_BEGIN(megdnn_arm_cvt_bt601_yuv, midout_iv(7)) { + return cvt_BT601_yuv_transform(src, dst); + } + MIDOUT_END(); + default: + megdnn_throw("unknown mode for real yuv."); + } +} + +template +void CvtColorImpl::cvt_color_exec(const TensorND& src_tensor, + const TensorND& dst_tensor) { + auto mode = param().mode; + for (size_t i = 0; i < src_tensor.layout.shape[0]; ++i) { + Mat src = TensorND2Mat(src_tensor, i); + Mat dst = TensorND2Mat(dst_tensor, i); + switch (mode) { + case Param::Mode::RGB2GRAY: + MIDOUT_BEGIN(megdnn_arm_cvtcolor_cases, midout_iv(0)) { + cvt_rgb2gray(src, dst); + } + MIDOUT_END(); + break; + case Param::Mode::RGB2YUV: + MIDOUT_BEGIN(megdnn_arm_cvtcolor_cases, midout_iv(1)) { + cvt_rgb2yuv(src, dst); + } + MIDOUT_END(); + break; + case Param::Mode::YUV2RGB: + MIDOUT_BEGIN(megdnn_arm_cvtcolor_cases, midout_iv(2)) { + cvt_yuv2rgb(src, dst); + } + MIDOUT_END(); + break; + case Param::Mode::GRAY2RGB: + MIDOUT_BEGIN(megdnn_arm_cvtcolor_cases, midout_iv(3)) { + cvt_gray2rgb(src, dst); + } + MIDOUT_END(); + break; + case Param::Mode::RGBA2RGB: + MIDOUT_BEGIN(megdnn_arm_cvtcolor_cases, midout_iv(4)) { + cvt_rgba2rgb(src, dst); + } + MIDOUT_END(); + break; + case Param::Mode::RGBA2BGR: + 
MIDOUT_BEGIN(megdnn_arm_cvtcolor_cases, midout_iv(5)) { + cvt_rgba2bgr(src, dst); + } + MIDOUT_END(); + break; + case Param::Mode::RGBA2GRAY: + MIDOUT_BEGIN(megdnn_arm_cvtcolor_cases, midout_iv(6)) { + cvt_rgba2gray(src, dst); + } + MIDOUT_END(); + break; + case Param::Mode::RGB2BGR: + MIDOUT_BEGIN(megdnn_arm_cvtcolor_cases, midout_iv(7)) { + cvt_rgb2bgr(src, dst); + } + MIDOUT_END(); + break; + case Param::Mode::BGR2GRAY: + MIDOUT_BEGIN(megdnn_arm_cvtcolor_cases, midout_iv(8)) { + cvt_bgr2gray(src, dst); + } + MIDOUT_END(); + break; + case Param::Mode::BGR2RGB: + MIDOUT_BEGIN(megdnn_arm_cvtcolor_cases, midout_iv(9)) { + cvt_bgr2rgb(src, dst); + } + MIDOUT_END(); + break; + case Param::Mode::YUV2GRAY_NV21: + case Param::Mode::YUV2GRAY_NV12: + case Param::Mode::YUV2GRAY_YV12: + case Param::Mode::YUV2GRAY_YU12: + cvt_yuv2gray_nv21(src, dst); + break; + case Param::Mode::YUV2RGB_NV21: + case Param::Mode::YCrCb2RGB: + MIDOUT_BEGIN(megdnn_arm_cvtcolor_cases, midout_iv(10)) { + cvt_yuv2rgb_nv21(src, dst); + } + MIDOUT_END(); + break; + case Param::Mode::YUV2BGR_NV21: + case Param::Mode::YCrCb2BGR: + MIDOUT_BEGIN(megdnn_arm_cvtcolor_cases, midout_iv(11)) { + cvt_yuv2bgr_nv21(src, dst); + } + MIDOUT_END(); + break; + case Param::Mode::YUV2RGB_NV12: + MIDOUT_BEGIN(megdnn_arm_cvtcolor_cases, midout_iv(12)) { + cvt_yuv2rgb_nv12(src, dst); + } + MIDOUT_END(); + break; + case Param::Mode::YUV2BGR_NV12: + MIDOUT_BEGIN(megdnn_arm_cvtcolor_cases, midout_iv(13)) { + cvt_yuv2bgr_nv12(src, dst); + } + MIDOUT_END(); + break; + case Param::Mode::YUV2RGB_YV12: + MIDOUT_BEGIN(megdnn_arm_cvtcolor_cases, midout_iv(14)) { + cvt_yuv2rgb_yv12(src, dst); + } + MIDOUT_END(); + break; + case Param::Mode::YUV2BGR_YV12: + MIDOUT_BEGIN(megdnn_arm_cvtcolor_cases, midout_iv(15)) { + cvt_yuv2bgr_yv12(src, dst); + } + MIDOUT_END(); + break; + case Param::Mode::YUV2RGB_YU12: + MIDOUT_BEGIN(megdnn_arm_cvtcolor_cases, midout_iv(16)) { + cvt_yuv2rgb_yu12(src, dst); + } + MIDOUT_END(); + break; + case Param::Mode::YUV2BGR_YU12: + MIDOUT_BEGIN(megdnn_arm_cvtcolor_cases, midout_iv(17)) { + cvt_yuv2bgr_yu12(src, dst); + } + MIDOUT_END(); + break; + case Param::Mode::BT601_YUV2BGR_NV12: + case Param::Mode::BT601_YUV2RGB_NV12: + case Param::Mode::BT601_YUV2BGR_NV21: + case Param::Mode::BT601_YUV2RGB_NV21: + case Param::Mode::BT601_YUV2RGB_YU12: + case Param::Mode::BT601_YUV2BGR_YU12: + case Param::Mode::BT601_YUV2RGB_YV12: + case Param::Mode::BT601_YUV2BGR_YV12: + MIDOUT_BEGIN(megdnn_arm_cvtcolor_cases, midout_iv(18)) { + cvt_bt601_yuv(src, dst, mode); + } + MIDOUT_END(); + break; + + default: + megdnn_throw("Can not find property cvt_color operator."); + } + } +} +void CvtColorImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, + _megdnn_workspace workspace) { + using namespace megcv; + check_exec(src.layout, dst.layout, workspace.size); + if (dst.layout.dtype == dtype::Float32()) { + MIDOUT_BEGIN(megdnn_arm_cvtcolor MEGDNN_COMMA midout_iv(0)) { + MEGDNN_DISPATCH_CPU_KERN_OPR(cvt_color_exec(src, dst)); + } MIDOUT_END(); + } else if (dst.layout.dtype == dtype::Uint8()) { + MIDOUT_BEGIN(megdnn_arm_cvtcolor MEGDNN_COMMA midout_iv(1)) { + MEGDNN_DISPATCH_CPU_KERN_OPR(cvt_color_exec(src, dst)); + } MIDOUT_END(); + } else { megdnn_throw("Unsupported datatype of CvtColor optr."); }; +} + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/cvt_color/opr_impl.h b/dnn/src/arm_common/cvt_color/opr_impl.h new file mode 100644 index 00000000..c0938452 --- /dev/null +++ 
b/dnn/src/arm_common/cvt_color/opr_impl.h @@ -0,0 +1,40 @@ +/** + * \file dnn/src/arm_common/cvt_color/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once +#include "megdnn/oprs.h" + +namespace megdnn { +namespace arm_common { + +class CvtColorImpl: public CvtColor { + private: + template + void cvt_color_exec(_megdnn_tensor_in src, + _megdnn_tensor_out dst); + + public: + using CvtColor::CvtColor; + + void exec(_megdnn_tensor_in src, + _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; + + size_t get_workspace_in_bytes(const TensorLayout &, + const TensorLayout &) override + { + return 0; + } +}; + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/elemwise/binary/algo.cpp b/dnn/src/arm_common/elemwise/binary/algo.cpp new file mode 100644 index 00000000..c3c0ddbb --- /dev/null +++ b/dnn/src/arm_common/elemwise/binary/algo.cpp @@ -0,0 +1,406 @@ +/** + * \file dnn/src/arm_common/elemwise/binary/algo.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/elemwise/binary/algo.h" +#include "src/arm_common/elemwise_op.h" + +#include "src/common/utils.h" +#include "src/naive/handle.h" + +#include "midout.h" + +MIDOUT_DECL(megdnn_arm_common_elemwise_binary) + +using namespace megdnn; +using namespace arm_common; + +namespace { +static inline bool is_available_common(Elemwise::Mode mode) { + /** + * Fused sigmoid & tanh may be slower than the naive algo, because the + * time spent in the neon function `exp_ps_f32` depends on the input.
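+     * Returning false for those two modes simply keeps them out of this
+     * algo, so the dispatcher leaves them to the other (e.g. naive) kernels.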
+ */ + if (mode == Elemwise::Mode::FUSE_ADD_SIGMOID || + mode == Elemwise::Mode::FUSE_ADD_TANH) { + return false; + } + + return true; +} +} // anonymous namespace + +#if MEGDNN_AARCH64 +#define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \ + auto mode = kern_param.mode; \ + if (mode == Mode::MIN || mode == Mode::MAX || mode == Mode::ADD || \ + mode == Mode::SUB || mode == Mode::MUL || mode == Mode::POW || \ + mode == Mode::TRUE_DIV || mode == Mode::FUSE_ADD_RELU || \ + mode == Mode::FUSE_ADD_H_SWISH) \ + return true; +#else +#define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \ + auto mode = kern_param.mode; \ + if (mode == Mode::MIN || mode == Mode::MAX || mode == Mode::ADD || \ + mode == Mode::SUB || mode == Mode::MUL || mode == Mode::POW || \ + mode == Mode::FUSE_ADD_RELU || mode == Mode::FUSE_ADD_H_SWISH) \ + return true; +#endif + +#define DISPATCH_MODE_INT(_case, _type, _type_midout_id) \ + auto mode = kern_param.mode; \ + if (mode == Mode::MIN || mode == Mode::MAX || mode == Mode::ADD || \ + mode == Mode::SUB || mode == Mode::MUL || mode == Mode::RMULH || \ + mode == Mode::FUSE_ADD_RELU) \ + return true; + +bool ElemwiseImpl::AlgoBinaryVecVec::is_available( + const KernParam& kern_param) const { + if (!is_available_common(kern_param.mode) || + (BcastType::VEC_VEC != kern_param.broad_cast_type)) + return false; + + auto& elparam = kern_param.binary_elparam; + auto& src0 = elparam[0]; + + //! exactly match [x, y] + [x, y] + DISPATCH_TYPE("AlgoBinaryVecVec::is_available"_hash); + + return false; +} + +bool ElemwiseImpl::AlgoBinaryVecScalar::is_available( + const KernParam& kern_param) const { + if (!is_available_common(kern_param.mode) || + ((BcastType::VEC_SCALAR != kern_param.broad_cast_type) && + (BcastType::SCALAR_VEC != kern_param.broad_cast_type))) + return false; + + auto& elparam = kern_param.binary_elparam; + auto& src0 = elparam[0]; + + DISPATCH_TYPE("AlgoBinaryVecScalar::is_available"_hash); + return false; +} + +bool ElemwiseImpl::AlgoBinaryVecBcast101::is_available( + const KernParam& kern_param) const { + if (!is_available_common(kern_param.mode) || + ((BcastType::VEC_BCAST101 != kern_param.broad_cast_type) && + (BcastType::BCAST101_VEC != kern_param.broad_cast_type))) + return false; + + auto& elparam = kern_param.binary_elparam; + auto& src0 = elparam[0]; + + DISPATCH_TYPE("AlgoBinaryVecBcast101::is_available"_hash); + + return false; +} + +bool ElemwiseImpl::AlgoBinaryVecBcast101x4::is_available( + const KernParam& kern_param) const { + if (!is_available_common(kern_param.mode) || + ((BcastType::VEC_BCAST101x4 != kern_param.broad_cast_type) && + (BcastType::BCAST101x4_VEC != kern_param.broad_cast_type))) + return false; + + auto& elparam = kern_param.binary_elparam; + auto& src0 = elparam[0]; +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + if (MEGDNN_FLOAT16_SELECT(src0.layout.dtype == dtype::Float16{}, false)) { + return false; + } +#endif + + DISPATCH_TYPE("AlgoBinaryVecBcast101x::is_available"_hash); + + return false; +} +#undef DISPATCH_MODE_FLOAT +#undef DISPATCH_MODE_INT + +#if MEGDNN_AARCH64 +#define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \ + switch (kern_param.mode) { \ + DISPATCH_BINARY(MIN, _case, _type, _type_midout_id, MinOp); \ + DISPATCH_BINARY(MAX, _case, _type, _type_midout_id, MaxOp); \ + DISPATCH_BINARY(ADD, _case, _type, _type_midout_id, AddOp); \ + DISPATCH_BINARY(SUB, _case, _type, _type_midout_id, SubOp); \ + DISPATCH_BINARY(MUL, _case, _type, _type_midout_id, MulOp); \ + DISPATCH_BINARY(POW, _case, _type, _type_midout_id, 
PowOp); \ + DISPATCH_BINARY(TRUE_DIV, _case, _type, _type_midout_id, TrueDivOp); \ + DISPATCH_BINARY(FUSE_ADD_RELU, _case, _type, _type_midout_id, \ + FuseAddReluOp); \ + DISPATCH_BINARY(FUSE_ADD_H_SWISH, _case, _type, _type_midout_id, \ + FuseAddHSwishOp); \ + default: \ + megdnn_throw(ssprintf("No avaiable algo find for: %d", \ + static_cast(kern_param.mode))); \ + } +#else +#define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \ + switch (kern_param.mode) { \ + DISPATCH_BINARY(MIN, _case, _type, _type_midout_id, MinOp); \ + DISPATCH_BINARY(MAX, _case, _type, _type_midout_id, MaxOp); \ + DISPATCH_BINARY(ADD, _case, _type, _type_midout_id, AddOp); \ + DISPATCH_BINARY(SUB, _case, _type, _type_midout_id, SubOp); \ + DISPATCH_BINARY(MUL, _case, _type, _type_midout_id, MulOp); \ + DISPATCH_BINARY(POW, _case, _type, _type_midout_id, PowOp); \ + DISPATCH_BINARY(FUSE_ADD_RELU, _case, _type, _type_midout_id, \ + FuseAddReluOp); \ + DISPATCH_BINARY(FUSE_ADD_H_SWISH, _case, _type, _type_midout_id, \ + FuseAddHSwishOp); \ + default: \ + megdnn_throw(ssprintf("No avaiable algo find for: %d", \ + static_cast(kern_param.mode))); \ + } + +#endif + +#define DISPATCH_MODE_INT(_case, _type, _type_midout_id) \ + switch (kern_param.mode) { \ + DISPATCH_BINARY(MIN, _case, _type, _type_midout_id, MinOp); \ + DISPATCH_BINARY(MAX, _case, _type, _type_midout_id, MaxOp); \ + DISPATCH_BINARY(ADD, _case, _type, _type_midout_id, AddOp); \ + DISPATCH_BINARY(SUB, _case, _type, _type_midout_id, SubOp); \ + DISPATCH_BINARY(MUL, _case, _type, _type_midout_id, MulOp); \ + DISPATCH_BINARY(RMULH, _case, _type, _type_midout_id, RmulhOp); \ + DISPATCH_BINARY(FUSE_ADD_RELU, _case, _type, _type_midout_id, \ + FuseAddReluOp); \ + default: \ + megdnn_throw(ssprintf("No avaiable algo find for: %d", \ + static_cast(kern_param.mode))); \ + } + +void ElemwiseImpl::AlgoBinaryVecVec::exec(const KernParam& kern_param) const { + auto& elparam = kern_param.binary_elparam; + auto &src0 = elparam[0], &src1 = elparam[1]; + + //! 
exactly match [x, y] + [x, y] +#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN(megdnn_arm_common_elemwise_binary, midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerBinary<_op<_type, _type>, \ + BcastType::VEC_VEC>::run; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(kern_param.handle), \ + run(static_cast(src0.raw_ptr), \ + static_cast(src1.raw_ptr), \ + static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ + src1.layout.dtype, dst.layout.dtype, \ + src0.layout.total_nr_elems())); \ + } \ + MIDOUT_END(); \ + return + + auto&& dst = *(kern_param.m_dst); + DISPATCH_TYPE("AlgoBinaryVecVec::exec"_hash); + +#undef DISPATCH_BINARY + + return; +} + +void ElemwiseImpl::AlgoBinaryVecScalar::exec( + const KernParam& kern_param) const { + auto& elparam = kern_param.binary_elparam; + auto &src0 = elparam[0], &src1 = elparam[1]; + auto&& dst = *(kern_param.m_dst); + + // Case 2: vector + scalar +#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN(megdnn_arm_common_elemwise_binary, midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerBinary<_op<_type, _type>, \ + BcastType::VEC_SCALAR>::run; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(kern_param.handle), \ + run(static_cast(src0.raw_ptr), \ + static_cast(src1.raw_ptr)[0], \ + static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ + src1.layout.dtype, dst.layout.dtype, \ + src0.layout.total_nr_elems())); \ + } \ + MIDOUT_END(); \ + return + + if (BcastType::VEC_SCALAR == kern_param.broad_cast_type) { + DISPATCH_TYPE("AlgoBinaryVecScalar::exec_vec_sca"_hash); + } +#undef DISPATCH_BINARY + + // scalar + vector +#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN(megdnn_arm_common_elemwise_binary, midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerBinary<_op<_type, _type>, \ + BcastType::SCALAR_VEC>::run; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(kern_param.handle), \ + run(static_cast(src0.raw_ptr)[0], \ + static_cast(src1.raw_ptr), \ + static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ + src1.layout.dtype, dst.layout.dtype, \ + src1.layout.total_nr_elems())); \ + } \ + MIDOUT_END(); \ + return + + if (BcastType::SCALAR_VEC == kern_param.broad_cast_type) { + DISPATCH_TYPE("AlgoBinaryVecScalar::exec_sca_vec"_hash); + } +#undef DISPATCH_BINARY + + return; +} + +void ElemwiseImpl::AlgoBinaryVecBcast101::exec( + const KernParam& kern_param) const { + auto& elparam = kern_param.binary_elparam; + auto &src0 = elparam[0], &src1 = elparam[1]; + auto&& dst = *(kern_param.m_dst); + BroadcastChannelInfo binfo; + + // Case 3: BcastType::VEC + BCAST_101 + if (BcastType::VEC_BCAST101 == kern_param.broad_cast_type && + is_broadcasted_channel_like(src1.layout, binfo)) { +#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN(megdnn_arm_common_elemwise_binary, midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerBinary<_op<_type, _type>, \ + BcastType::VEC_BCAST101>::run; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(kern_param.handle), \ + run(static_cast(src0.raw_ptr), \ + static_cast(src1.raw_ptr), \ + static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ + src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \ + binfo.z)); \ + } \ + MIDOUT_END(); \ + 
return + + DISPATCH_TYPE("AlgoBinaryVecBcast101::exec_vec_b"_hash); + +#undef DISPATCH_BINARY + } + + // BCAST_101 + BcastType::VEC + if (BcastType::BCAST101_VEC == kern_param.broad_cast_type && + is_broadcasted_channel_like(src0.layout, binfo)) { +#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN(megdnn_arm_common_elemwise_binary, midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerBinary<_op<_type, _type>, \ + BcastType::BCAST101_VEC>::run; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(kern_param.handle), \ + run(static_cast(src0.raw_ptr), \ + static_cast(src1.raw_ptr), \ + static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ + src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \ + binfo.z)); \ + } \ + MIDOUT_END(); \ + return + + DISPATCH_TYPE("AlgoBinaryVecBcast101::exec_b_vec"_hash); + +#undef DISPATCH_BINARY + } + return; +} + +void ElemwiseImpl::AlgoBinaryVecBcast101x4::exec( + const KernParam& kern_param) const { + auto& elparam = kern_param.binary_elparam; + auto &src0 = elparam[0], &src1 = elparam[1]; + auto&& dst = *(kern_param.m_dst); + BroadcastChannelInfo binfo; + + // BcastType::VEC + BCAST_101x + if (BcastType::VEC_BCAST101x4 == kern_param.broad_cast_type && + is_broadcastedx_channel_like<4>(src1.layout, binfo)) { +#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN(megdnn_arm_common_elemwise_binary, midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerBinary<_op<_type, _type>, \ + BcastType::VEC_BCAST101x4>::run; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(kern_param.handle), \ + run(static_cast(src0.raw_ptr), \ + static_cast(src1.raw_ptr), \ + static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ + src1.layout.dtype, dst.layout.dtype, batch_size, \ + binfo.x, binfo.y, binfo.z)); \ + } \ + MIDOUT_END(); \ + return + + size_t batch_size = + src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z); + DISPATCH_TYPE("AlgoBinaryVecBcast101x::exec_vec_b"_hash); + +#undef DISPATCH_BINARY + } + + // BCAST_101x + BcastType::VEC + if (BcastType::BCAST101x4_VEC == kern_param.broad_cast_type && + is_broadcastedx_channel_like<4>(src0.layout, binfo)) { +#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN(megdnn_arm_common_elemwise_binary, midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerBinary<_op<_type, _type>, \ + BcastType::BCAST101x4_VEC>::run; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(kern_param.handle), \ + run(static_cast(src0.raw_ptr), \ + static_cast(src1.raw_ptr), \ + static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ + src1.layout.dtype, dst.layout.dtype, batch_size, \ + binfo.x, binfo.y, binfo.z)); \ + } \ + MIDOUT_END(); \ + return + size_t batch_size = + src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z); + + DISPATCH_TYPE("AlgoBinaryVecBcast101x::exec_b_vec"_hash); + +#undef DISPATCH_BINARY + } + return; +} +#undef DISPATCH_MODE_FLOAT +#undef DISPATCH_MODE_INT + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/elemwise/binary/algo.h b/dnn/src/arm_common/elemwise/binary/algo.h new file mode 100644 index 00000000..20fb8f8d --- /dev/null +++ b/dnn/src/arm_common/elemwise/binary/algo.h @@ -0,0 +1,40 @@ +/** + * \file dnn/src/arm_common/elemwise/binary/algo.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + 
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#pragma once +#include "src/arm_common/elemwise/opr_impl.h" +namespace megdnn { +namespace arm_common { + +#define DECL_CB(case) \ + class ElemwiseImpl::AlgoBinary##case final \ + : public ElemwiseImpl::AlgoBase { \ + mutable std::string m_name; \ + bool is_reproducible() const override { return true; } \ + const char* name() const override { \ + if (m_name.empty()) { \ + m_name = megdnn_mangle( \ + ssprintf("Elemwise::AlgoBinaryCase" #case)); \ + } \ + return m_name.c_str(); \ + } \ + bool is_available(const KernParam&) const override; \ + void exec(const KernParam&) const override; \ + }; + +DECL_CB(VecVec); +DECL_CB(VecScalar); +DECL_CB(VecBcast101); +DECL_CB(VecBcast101x4); +#undef DECL_CB +} // namespace arm_common +} // namespace megdnn + // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/elemwise/neon_mathfun.cpp b/dnn/src/arm_common/elemwise/neon_mathfun.cpp new file mode 100644 index 00000000..99c9fbb4 --- /dev/null +++ b/dnn/src/arm_common/elemwise/neon_mathfun.cpp @@ -0,0 +1,378 @@ +/** + * \file dnn/src/arm_common/elemwise/neon_mathfun.cpp + * + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + * + * This file has been modified by Megvii ("Megvii Modifications"). + * All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights + * reserved. + * + */ + +/* NEON implementation of sin, cos, exp and log + + Inspired by Intel Approximate Math library, and based on the + corresponding algorithms of the cephes math library +*/ + +/* Copyright (C) 2011 Julien Pommier + + This software is provided 'as-is', without any express or implied + warranty. In no event will the authors be held liable for any damages + arising from the use of this software. + + Permission is granted to anyone to use this software for any purpose, + including commercial applications, and to alter it and redistribute it + freely, subject to the following restrictions: + + 1. The origin of this software must not be misrepresented; you must not + claim that you wrote the original software. If you use this software + in a product, an acknowledgment in the product documentation would be + appreciated but is not required. + 2. Altered source versions must be plainly marked as such, and must not be + misrepresented as being the original software. + 3. This notice may not be removed or altered from any source distribution. 
+ + (this is the zlib license) +*/ + +#include "./neon_mathfun.h" + +namespace megdnn { +namespace arm_common { + +#define c_inv_mant_mask ~0x7f800000u +#define c_cephes_SQRTHF 0.707106781186547524 +#define c_cephes_log_p0 7.0376836292E-2 +#define c_cephes_log_p1 -1.1514610310E-1 +#define c_cephes_log_p2 1.1676998740E-1 +#define c_cephes_log_p3 -1.2420140846E-1 +#define c_cephes_log_p4 +1.4249322787E-1 +#define c_cephes_log_p5 -1.6668057665E-1 +#define c_cephes_log_p6 +2.0000714765E-1 +#define c_cephes_log_p7 -2.4999993993E-1 +#define c_cephes_log_p8 +3.3333331174E-1 +#define c_cephes_log_q1 -2.12194440e-4 +#define c_cephes_log_q2 0.693359375 + +/** + * natural logarithm computed for 4 simultaneous float return NaN for x <= 0 + */ +v4sf log_ps_f32(v4sf x) { + v4sf one = vdupq_n_f32(1); + + x = vmaxq_f32(x, + vdupq_n_f32(0)); /* force flush to zero on denormal values */ + v4su invalid_mask = vcleq_f32(x, vdupq_n_f32(0)); + + v4si ux = vreinterpretq_s32_f32(x); + + v4si emm0 = vshrq_n_s32(ux, 23); + + /* keep only the fractional part */ + ux = vandq_s32(ux, vdupq_n_s32(c_inv_mant_mask)); + ux = vorrq_s32(ux, vreinterpretq_s32_f32(vdupq_n_f32(0.5f))); + x = vreinterpretq_f32_s32(ux); + + emm0 = vsubq_s32(emm0, vdupq_n_s32(0x7f)); + v4sf e = vcvtq_f32_s32(emm0); + + e = vaddq_f32(e, one); + + /* part2: + if( x < SQRTHF ) { + e -= 1; + x = x + x - 1.0; + } else { x = x - 1.0; } + */ + v4su mask = vcltq_f32(x, vdupq_n_f32(c_cephes_SQRTHF)); + v4sf tmp = vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(x), mask)); + x = vsubq_f32(x, one); + e = vsubq_f32(e, vreinterpretq_f32_u32( + vandq_u32(vreinterpretq_u32_f32(one), mask))); + x = vaddq_f32(x, tmp); + + v4sf z = vmulq_f32(x, x); + + v4sf y = vdupq_n_f32(c_cephes_log_p0); + y = vmulq_f32(y, x); + y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p1)); + y = vmulq_f32(y, x); + y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p2)); + y = vmulq_f32(y, x); + y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p3)); + y = vmulq_f32(y, x); + y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p4)); + y = vmulq_f32(y, x); + y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p5)); + y = vmulq_f32(y, x); + y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p6)); + y = vmulq_f32(y, x); + y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p7)); + y = vmulq_f32(y, x); + y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p8)); + y = vmulq_f32(y, x); + + y = vmulq_f32(y, z); + + tmp = vmulq_f32(e, vdupq_n_f32(c_cephes_log_q1)); + y = vaddq_f32(y, tmp); + + tmp = vmulq_f32(z, vdupq_n_f32(0.5f)); + y = vsubq_f32(y, tmp); + + tmp = vmulq_f32(e, vdupq_n_f32(c_cephes_log_q2)); + x = vaddq_f32(x, y); + x = vaddq_f32(x, tmp); + x = vreinterpretq_f32_u32( + vorrq_u32(vreinterpretq_u32_f32(x), + invalid_mask)); // negative arg will be NAN + return x; +} + +#define c_exp_hi 88.3762626647949f +#define c_exp_lo -88.3762626647949f + +#define c_cephes_LOG2EF 1.44269504088896341 +#define c_cephes_exp_C1 0.693359375 +#define c_cephes_exp_C2 -2.12194440e-4 + +#define c_cephes_exp_p0 1.9875691500E-4 +#define c_cephes_exp_p1 1.3981999507E-3 +#define c_cephes_exp_p2 8.3334519073E-3 +#define c_cephes_exp_p3 4.1665795894E-2 +#define c_cephes_exp_p4 1.6666665459E-1 +#define c_cephes_exp_p5 5.0000001201E-1 + +/* exp() computed for 4 float at once */ +v4sf exp_ps_f32(v4sf x) { + v4sf tmp, fx; + + v4sf one = vdupq_n_f32(1); + x = vminq_f32(x, vdupq_n_f32(c_exp_hi)); + x = vmaxq_f32(x, vdupq_n_f32(c_exp_lo)); + + /* express exp(x) as exp(g + n*log(2)) */ + fx = vmlaq_f32(vdupq_n_f32(0.5f), x, vdupq_n_f32(c_cephes_LOG2EF)); + + /* perform a floorf */ 
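+    /* note: vcvtq_s32_f32 truncates toward zero, so for negative fx the
+       converted value is one too large; the vcgtq/vsubq pair below subtracts
+       1 in exactly those lanes, giving floor(fx).  Together with the +0.5
+       bias added above, fx becomes x * log2(e) rounded to the nearest
+       integer, i.e. the n in exp(x) = 2^n * exp(g). */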
+ tmp = vcvtq_f32_s32(vcvtq_s32_f32(fx)); + + /* if greater, subtract 1 */ + v4su mask = vcgtq_f32(tmp, fx); + mask = vandq_u32(mask, vreinterpretq_u32_f32(one)); + + fx = vsubq_f32(tmp, vreinterpretq_f32_u32(mask)); + + tmp = vmulq_f32(fx, vdupq_n_f32(c_cephes_exp_C1)); + v4sf z = vmulq_f32(fx, vdupq_n_f32(c_cephes_exp_C2)); + x = vsubq_f32(x, tmp); + x = vsubq_f32(x, z); + + static const float cephes_exp_p[6] = {c_cephes_exp_p0, c_cephes_exp_p1, + c_cephes_exp_p2, c_cephes_exp_p3, + c_cephes_exp_p4, c_cephes_exp_p5}; + v4sf y = vld1q_dup_f32(cephes_exp_p + 0); + v4sf c1 = vld1q_dup_f32(cephes_exp_p + 1); + v4sf c2 = vld1q_dup_f32(cephes_exp_p + 2); + v4sf c3 = vld1q_dup_f32(cephes_exp_p + 3); + v4sf c4 = vld1q_dup_f32(cephes_exp_p + 4); + v4sf c5 = vld1q_dup_f32(cephes_exp_p + 5); + + y = vmulq_f32(y, x); + z = vmulq_f32(x, x); + y = vaddq_f32(y, c1); + y = vmulq_f32(y, x); + y = vaddq_f32(y, c2); + y = vmulq_f32(y, x); + y = vaddq_f32(y, c3); + y = vmulq_f32(y, x); + y = vaddq_f32(y, c4); + y = vmulq_f32(y, x); + y = vaddq_f32(y, c5); + + y = vmulq_f32(y, z); + y = vaddq_f32(y, x); + y = vaddq_f32(y, one); + + /* build 2^n */ + int32x4_t mm; + mm = vcvtq_s32_f32(fx); + mm = vaddq_s32(mm, vdupq_n_s32(0x7f)); + mm = vshlq_n_s32(mm, 23); + v4sf pow2n = vreinterpretq_f32_s32(mm); + + y = vmulq_f32(y, pow2n); + return y; +} + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +float16x8_t exp_ps_f16(float16x8_t x) { + float32x4_t low = vcvt_f32_f16(vget_low_f16(x)); + float32x4_t high = vcvt_f32_f16(vget_high_f16(x)); + low = exp_ps_f32(low); + high = exp_ps_f32(high); + + return vcombine_f16(vcvt_f16_f32(low), vcvt_f16_f32(high)); +} +#endif + +#define c_minus_cephes_DP1 -0.78515625 +#define c_minus_cephes_DP2 -2.4187564849853515625e-4 +#define c_minus_cephes_DP3 -3.77489497744594108e-8 +#define c_sincof_p0 -1.9515295891E-4 +#define c_sincof_p1 8.3321608736E-3 +#define c_sincof_p2 -1.6666654611E-1 +#define c_coscof_p0 2.443315711809948E-005 +#define c_coscof_p1 -1.388731625493765E-003 +#define c_coscof_p2 4.166664568298827E-002 +#define c_cephes_FOPI 1.27323954473516 // 4 / M_PI + +/* evaluation of 4 sines & cosines at once. + + The code is the exact rewriting of the cephes sinf function. + Precision is excellent as long as x < 8192 (I did not bother to + take into account the special handling they have for greater values + -- it does not return garbage for arguments over 8192, though, but + the extra precision is missing). + + Note that it is such that sinf((float)M_PI) = 8.74e-8, which is the + surprising but correct result. + + Note also that when you compute sin(x), cos(x) is available at + almost no extra price so both sin_ps_f32 and cos_ps_f32 make use of + sincos_ps_f32.. 
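+
+   In outline: the argument is scaled by 4/Pi and converted to an integer j
+   that identifies the octant (j is then forced even); the low bits of j pick
+   between the sincof_* and coscof_* minimax polynomials and fix the signs.
+   The reduced argument is obtained by subtracting j * Pi/4 in three pieces
+   (the DP1/DP2/DP3 constants, Cody-and-Waite style) so the reduction stays
+   accurate in single precision.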
+ */ +void sincos_ps_f32(v4sf x, v4sf* ysin, v4sf* ycos) { // any x + v4sf xmm1, xmm2, xmm3, y; + + v4su emm2; + + v4su sign_mask_sin, sign_mask_cos; + sign_mask_sin = vcltq_f32(x, vdupq_n_f32(0)); + x = vabsq_f32(x); + + /* scale by 4/Pi */ + y = vmulq_f32(x, vdupq_n_f32(c_cephes_FOPI)); + + /* store the integer part of y in mm0 */ + emm2 = vcvtq_u32_f32(y); + /* j=(j+1) & (~1) (see the cephes sources) */ + emm2 = vaddq_u32(emm2, vdupq_n_u32(1)); + emm2 = vandq_u32(emm2, vdupq_n_u32(~1)); + y = vcvtq_f32_u32(emm2); + + /* get the polynom selection mask + there is one polynom for 0 <= x <= Pi/4 + and another one for Pi/4 all_algos; +}; + +void ElemwiseImpl::exec(const TensorNDArray& srcs, _megdnn_tensor_out dst) { + m_src = &srcs; + m_dst = &dst; + + if (!dst.layout.is_contiguous()) { + return fallback::ElemwiseImpl::exec(srcs, dst); + } + + if (m_dst->layout.dtype == dtype::Float32() || + MEGDNN_FLOAT16_SELECT(m_dst->layout.dtype == dtype::Float16(), false) || + m_dst->layout.dtype == dtype::Int32() || + m_dst->layout.dtype == dtype::Int16() || + m_dst->layout.dtype == dtype::Int8()) { + auto kern_param = make_kern_param(this); + kern_param.m_dst = &dst; + static AlgoPack m_algo_pack; + for (auto& m_algo : m_algo_pack.all_algos) { + if (m_algo->is_available(kern_param)) { + m_algo->exec(kern_param); + return; + } + } + } + fallback::ElemwiseImpl::exec(srcs, dst); +} + +ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) { + KernParam kern_param; + kern_param.broad_cast_type = BcastType::UNKNOWN_BCAST_TYPE; + kern_param.mode = opr->param().mode; + kern_param.handle = opr->handle(); + + if ((opr->m_src->size() == 3) && + (opr->param().mode == Mode::FUSE_MUL_ADD3)) { + kern_param.ternary_elparam = opr->make_elemwise_op_param<3>(); + bool c_is_scalar; + opr->prepare_fma3(kern_param.ternary_elparam, c_is_scalar); + auto &src0 = kern_param.ternary_elparam[0], + &src1 = kern_param.ternary_elparam[1], + &src2 = kern_param.ternary_elparam[2]; + BroadcastChannelInfo binfo; + + if (is_vector(src0.layout) && is_vector(src1.layout) && + is_vector(src2.layout)) { + kern_param.broad_cast_type = BcastType::VEC_VEC_VEC; + return kern_param; + } + + if (is_vector(src0.layout) && is_vector(src1.layout) && c_is_scalar) { + kern_param.broad_cast_type = BcastType::VEC_VEC_SCALAR; + return kern_param; + } + + if (is_vector(src1.layout) && + is_broadcasted_channel_like(src0.layout, binfo) && + src0.layout.eq_layout(src2.layout)) { + kern_param.broad_cast_type = BcastType::BCAST101_VEC_BCAST101; + return kern_param; + } + + if (is_vector(src0.layout) && src0.layout.eq_layout(src2.layout) && + is_broadcasted_channel_like(src1.layout, binfo)) { + kern_param.broad_cast_type = BcastType::VEC_BCAST101_VEC; + return kern_param; + } + + if (is_vector(src0.layout) && is_vector(src2.layout) && + is_broadcasted_scalar(src1.layout)) { + kern_param.broad_cast_type = BcastType::VEC_SCALAR_VEC; + return kern_param; + } + + if (is_vector(src0.layout) && is_broadcasted_scalar(src1.layout) && + is_broadcasted_scalar(src2.layout)) { + kern_param.broad_cast_type = BcastType::VEC_SCALAR_SCALAR; + return kern_param; + } + } else if (opr->m_src->size() == 2) { + kern_param.binary_elparam = opr->make_elemwise_op_param<2>(); + auto &src0 = kern_param.binary_elparam[0], + &src1 = kern_param.binary_elparam[1]; + BroadcastChannelInfo binfo; + if (is_vector(src0.layout) && is_vector(src1.layout)) { + kern_param.broad_cast_type = BcastType::VEC_VEC; + return kern_param; + } + + if (is_vector(src0.layout) && 
is_broadcasted_scalar(src1.layout)) { + kern_param.broad_cast_type = BcastType::VEC_SCALAR; + return kern_param; + } + + if (is_vector(src1.layout) && is_broadcasted_scalar(src0.layout)) { + kern_param.broad_cast_type = BcastType::SCALAR_VEC; + return kern_param; + } + + if (is_vector(src0.layout) && + is_broadcasted_channel_like(src1.layout, binfo)) { + kern_param.broad_cast_type = BcastType::VEC_BCAST101; + return kern_param; + } + + if (is_vector(src1.layout) && + is_broadcasted_channel_like(src0.layout, binfo)) { + kern_param.broad_cast_type = BcastType::BCAST101_VEC; + return kern_param; + } + + if (is_vector(src0.layout) && + is_broadcastedx_channel_like<4>(src1.layout, binfo)) { + kern_param.broad_cast_type = BcastType::VEC_BCAST101x4; + return kern_param; + } + + if (is_vector(src1.layout) && + is_broadcastedx_channel_like<4>(src0.layout, binfo)) { + kern_param.broad_cast_type = BcastType::BCAST101x4_VEC; + return kern_param; + } + + } else if (opr->m_src->size() == 1) { + kern_param.broad_cast_type = BcastType::VEC; + kern_param.unary_elparam = opr->make_elemwise_op_param<1>(); + return kern_param; + } + + return kern_param; +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/elemwise/opr_impl.h b/dnn/src/arm_common/elemwise/opr_impl.h new file mode 100644 index 00000000..a0d7743f --- /dev/null +++ b/dnn/src/arm_common/elemwise/opr_impl.h @@ -0,0 +1,91 @@ +/** + * \file dnn/src/arm_common/elemwise/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#pragma once +#include "src/fallback/elemwise/opr_impl.h" + +#include "src/arm_common/elemwise_op.h" +namespace megdnn { +namespace arm_common { +class ElemwiseImpl final : public fallback::ElemwiseImpl { +public: + using fallback::ElemwiseImpl::ElemwiseImpl; + void exec(const TensorNDArray& srcs, _megdnn_tensor_out dst) override; + const char* get_algorithm_set_name() const { return "ARM COMMON ELEMWISE"; } + +private: + struct KernParam { + BcastType broad_cast_type; + Mode mode; + const TensorND* m_dst; + Handle* handle; + ElemwiseOpParamN<3> ternary_elparam; + ElemwiseOpParamN<2> binary_elparam; + ElemwiseOpParamN<1> unary_elparam; + }; + KernParam make_kern_param(ElemwiseImpl* opr); + class AlgoBase; + class AlgoUnary; + class AlgoBinaryVecVec; + class AlgoBinaryVecScalar; + class AlgoBinaryVecBcast101; + class AlgoBinaryVecBcast101x4; + class AlgoTernaryFma3VecVecVec; + class AlgoTernaryFma3VecVecScalar; + class AlgoTernaryFma3Bcast101VecBcast101; + class AlgoTernaryFma3VecBcast101Vec; + class AlgoTernaryFma3VecScalarVec; + class AlgoTernaryFma3VecScalarScalar; + class AlgoPack; +}; + +/*! 
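+ * Each concrete algorithm implements is_available() to test the broadcast
+ * pattern and dtype recorded in KernParam, and exec() to enqueue the matching
+ * NEON kernel through MEGDNN_DISPATCH_CPU_KERN.  ElemwiseImpl::exec() tries
+ * the algorithms in AlgoPack order and falls back to fallback::ElemwiseImpl
+ * when none reports itself available.
+ *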
+ * + * \brief base class for Elemwise algo + * + */ +class ElemwiseImpl::AlgoBase : public detail::Algorithm { +public: + virtual bool is_available(const KernParam&) const = 0; + virtual void exec(const KernParam&) const = 0; + virtual ~AlgoBase() = default; +}; + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#define DISPATCH_TYPE(_case) \ + if (src0.layout.dtype == dtype::Float32{}) { \ + DISPATCH_MODE_FLOAT(_case, float, 0); \ + } else if (MEGDNN_FLOAT16_SELECT(src0.layout.dtype == dtype::Float16{}, \ + false)) { \ + DISPATCH_MODE_FLOAT(_case, __fp16, 1); \ + } else if (src0.layout.dtype == dtype::Int32{}) { \ + DISPATCH_MODE_INT(_case, int, 2); \ + } else if (src0.layout.dtype == dtype::Int16{}) { \ + DISPATCH_MODE_INT(_case, dt_int16, 3); \ + } else if (src0.layout.dtype == dtype::Int8{}) { \ + DISPATCH_MODE_INT(_case, dt_int8, 4); \ + } +#else +#define DISPATCH_TYPE(_case) \ + if (src0.layout.dtype == dtype::Float32{}) { \ + DISPATCH_MODE_FLOAT(_case, float, 0); \ + } else if (src0.layout.dtype == dtype::Int32{}) { \ + DISPATCH_MODE_INT(_case, int, 2); \ + } else if (src0.layout.dtype == dtype::Int16{}) { \ + DISPATCH_MODE_INT(_case, dt_int16, 3); \ + } else if (src0.layout.dtype == dtype::Int8{}) { \ + DISPATCH_MODE_INT(_case, dt_int8, 4); \ + } +#endif + +} // namespace arm_common +} // namespace megdnn + // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/elemwise/ternary/algo.cpp b/dnn/src/arm_common/elemwise/ternary/algo.cpp new file mode 100644 index 00000000..624099a1 --- /dev/null +++ b/dnn/src/arm_common/elemwise/ternary/algo.cpp @@ -0,0 +1,263 @@ +/** + * \file dnn/src/arm_common/elemwise/ternary/algo.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ +#include "src/arm_common/elemwise/ternary/algo.h" +#include "src/arm_common/elemwise_op.h" + +#include "src/common/utils.h" +#include "src/naive/handle.h" + +#include "midout.h" + +MIDOUT_DECL(megdnn_arm_common_elemwise_ternary) + +using namespace megdnn; +using namespace arm_common; + +#define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \ + auto mode = kern_param.mode; \ + if (mode == Mode::FUSE_MUL_ADD3) \ + return true; +#define DISPATCH_MODE_INT DISPATCH_MODE_FLOAT + +#define DECL_AVAILABLE(case, type) \ + bool ElemwiseImpl::AlgoTernaryFma3##case ::is_available( \ + const KernParam& kern_param) const { \ + if (type == kern_param.broad_cast_type) { \ + auto& elparam = kern_param.ternary_elparam; \ + auto& src0 = elparam[0]; \ + DISPATCH_TYPE("AlgoTernaryFma3::is_available" #case##_hash); \ + } \ + return false; \ + } + +DECL_AVAILABLE(VecVecVec, BcastType::VEC_VEC_VEC); +DECL_AVAILABLE(VecVecScalar, BcastType::VEC_VEC_SCALAR); +DECL_AVAILABLE(Bcast101VecBcast101, BcastType::BCAST101_VEC_BCAST101); +DECL_AVAILABLE(VecBcast101Vec, BcastType::VEC_BCAST101_VEC); +DECL_AVAILABLE(VecScalarVec, BcastType::VEC_SCALAR_VEC); +DECL_AVAILABLE(VecScalarScalar, BcastType::VEC_SCALAR_SCALAR); +#undef DECL_CB +#undef DISPATCH_MODE_FLOAT +#undef DISPATCH_MODE_INT + +#define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \ + switch (kern_param.mode) { \ + DISPATCH_TERNARY(FUSE_MUL_ADD3, _case, _type, _type_midout_id, \ + FuseMulAdd3Op); \ + default: \ + megdnn_throw(ssprintf("No avaiable algo find for: %d", \ + static_cast(kern_param.mode))); \ + } +#define DISPATCH_MODE_INT DISPATCH_MODE_FLOAT +void ElemwiseImpl::AlgoTernaryFma3VecVecVec::exec( + const KernParam& kern_param) const { + auto& elparam = kern_param.ternary_elparam; + auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2]; + + // Case 1: shape of (src0, src2) and src1 are exactly match +#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN(megdnn_arm_common_elemwise_ternary, midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerTernary<_op<_type, _type>, \ + BcastType::VEC_VEC_VEC>::run; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(kern_param.handle), \ + run(static_cast(src0.raw_ptr), \ + static_cast(src1.raw_ptr), \ + static_cast(src2.raw_ptr), \ + static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ + src1.layout.dtype, src2.layout.dtype, \ + dst.layout.dtype, src0.layout.total_nr_elems())); \ + } \ + MIDOUT_END(); \ + return + + auto&& dst = *(kern_param.m_dst); + DISPATCH_TYPE("AlgoTernaryFma3VecVecVec::exec"_hash); +#undef DISPATCH_TERNARY + + return; +} +void ElemwiseImpl::AlgoTernaryFma3VecVecScalar::exec( + const KernParam& kern_param) const { + auto& elparam = kern_param.ternary_elparam; + auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2]; + + // Case 2: (src2 is a scalar) && (src0 and src1 has the same shape) +#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN(megdnn_arm_common_elemwise_ternary, midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerTernary<_op<_type, _type>, \ + BcastType::VEC_VEC_SCALAR>::run; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(kern_param.handle), \ + run(static_cast(src0.raw_ptr), \ + static_cast(src1.raw_ptr), \ + static_cast(src2.raw_ptr)[0], \ + static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ + src1.layout.dtype, src2.layout.dtype, \ + 
dst.layout.dtype, src0.layout.total_nr_elems())); \ + } \ + MIDOUT_END(); \ + return + + auto&& dst = *(kern_param.m_dst); + DISPATCH_TYPE("AlgoTernaryFma3VecVecScalar::exec"_hash); +#undef DISPATCH_TERNARY + + return; +} +void ElemwiseImpl::AlgoTernaryFma3Bcast101VecBcast101::exec( + const KernParam& kern_param) const { + auto& elparam = kern_param.ternary_elparam; + auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2]; + + // Case 3: shape of src0 and src2 is {1, C, 1, 1} + BroadcastChannelInfo binfo; + is_broadcasted_channel_like(src0.layout, binfo); +#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN(megdnn_arm_common_elemwise_ternary, midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerTernary< \ + _op<_type, _type>, \ + BcastType::BCAST101_VEC_BCAST101>::run; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(kern_param.handle), \ + run(static_cast(src0.raw_ptr), \ + static_cast(src1.raw_ptr), \ + static_cast(src2.raw_ptr), \ + static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ + src1.layout.dtype, src2.layout.dtype, \ + dst.layout.dtype, binfo.x, binfo.y, binfo.z)); \ + } \ + MIDOUT_END(); \ + return + + auto&& dst = *(kern_param.m_dst); + DISPATCH_TYPE("AlgoTernaryFma3Bcast101VecBcast101::exec"_hash); +#undef DISPATCH_TERNARY + + return; +} +void ElemwiseImpl::AlgoTernaryFma3VecBcast101Vec::exec( + const KernParam& kern_param) const { + auto& elparam = kern_param.ternary_elparam; + auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2]; + + // Case 4: shape of src1 is {1, C, 1, 1}, and src0 and src2 are contig + BroadcastChannelInfo binfo; + is_broadcasted_channel_like(src1.layout, binfo); +#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN(megdnn_arm_common_elemwise_ternary, midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerTernary<_op<_type, _type>, \ + BcastType::VEC_BCAST101_VEC>::run; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(kern_param.handle), \ + run(static_cast(src0.raw_ptr), \ + static_cast(src1.raw_ptr), \ + static_cast(src2.raw_ptr), \ + static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ + src1.layout.dtype, src2.layout.dtype, \ + dst.layout.dtype, binfo.x, binfo.y, binfo.z)); \ + } \ + MIDOUT_END(); \ + return + + auto&& dst = *(kern_param.m_dst); + DISPATCH_TYPE("AlgoTernaryFma3VecBcast101Vec::exec"_hash); +#undef DISPATCH_TERNARY + + return; +} +void ElemwiseImpl::AlgoTernaryFma3VecScalarVec::exec( + const KernParam& kern_param) const { + auto& elparam = kern_param.ternary_elparam; + auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2]; + + // Case 5: (src1 is a scalar) && (src0 and src2 has the same shape) +#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN(megdnn_arm_common_elemwise_ternary, midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerTernary<_op<_type, _type>, \ + BcastType::VEC_SCALAR_VEC>::run; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(kern_param.handle), \ + run(static_cast(src0.raw_ptr), \ + static_cast(src1.raw_ptr)[0], \ + static_cast(src2.raw_ptr), \ + static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ + src1.layout.dtype, src2.layout.dtype, \ + dst.layout.dtype, src0.layout.total_nr_elems())); \ + } \ + MIDOUT_END(); \ + return + + auto&& dst = *(kern_param.m_dst); + 
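+    // DISPATCH_TYPE branches on src0's dtype, DISPATCH_MODE_* then switches
+    // on the elemwise mode, and each DISPATCH_TERNARY case above binds the
+    // typed OpCallerTernary<FuseMulAdd3Op<T>, VEC_SCALAR_VEC>::run and
+    // submits it through MEGDNN_DISPATCH_CPU_KERN.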
DISPATCH_TYPE("AlgoTernaryFma3VecScalarVec::exec"_hash); +#undef DISPATCH_TERNARY + + return; +} +void ElemwiseImpl::AlgoTernaryFma3VecScalarScalar::exec( + const KernParam& kern_param) const { + auto& elparam = kern_param.ternary_elparam; + auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2]; + + // Case 6: (src1 and src2 is scalar) && (src0 is vector) +#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN(megdnn_arm_common_elemwise_ternary, midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerTernary<_op<_type, _type>, \ + BcastType::VEC_SCALAR_SCALAR>::run; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(kern_param.handle), \ + run(static_cast(src0.raw_ptr), \ + static_cast(src1.raw_ptr)[0], \ + static_cast(src2.raw_ptr)[0], \ + static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ + src1.layout.dtype, src2.layout.dtype, \ + dst.layout.dtype, src0.layout.total_nr_elems())); \ + } \ + MIDOUT_END(); \ + return + + auto&& dst = *(kern_param.m_dst); + DISPATCH_TYPE("AlgoTernaryFma3VecScalarScalar::exec"_hash); +#undef DISPATCH_TERNARY + + return; +} +#undef DISPATCH_MODE_FLOAT +#undef DISPATCH_MODE_INT + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/elemwise/ternary/algo.h b/dnn/src/arm_common/elemwise/ternary/algo.h new file mode 100644 index 00000000..975af62a --- /dev/null +++ b/dnn/src/arm_common/elemwise/ternary/algo.h @@ -0,0 +1,42 @@ +/** + * \file dnn/src/arm_common/elemwise/ternary/algo.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#pragma once +#include "src/arm_common/elemwise/opr_impl.h" +namespace megdnn { +namespace arm_common { + +#define DECL_CB(case) \ + class ElemwiseImpl::AlgoTernaryFma3##case final \ + : public ElemwiseImpl::AlgoBase { \ + mutable std::string m_name; \ + bool is_reproducible() const override { return true; } \ + const char* name() const override { \ + if (m_name.empty()) { \ + m_name = megdnn_mangle( \ + ssprintf("Elemwise::AlgoTernaryFma3" #case)); \ + } \ + return m_name.c_str(); \ + } \ + bool is_available(const KernParam&) const override; \ + void exec(const KernParam&) const override; \ + }; + +DECL_CB(VecVecVec); +DECL_CB(VecVecScalar); +DECL_CB(Bcast101VecBcast101); +DECL_CB(VecBcast101Vec); +DECL_CB(VecScalarVec); +DECL_CB(VecScalarScalar); +#undef DECL_CB +} // namespace arm_common +} // namespace megdnn + // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/elemwise/unary/algo.cpp b/dnn/src/arm_common/elemwise/unary/algo.cpp new file mode 100644 index 00000000..e8faaabe --- /dev/null +++ b/dnn/src/arm_common/elemwise/unary/algo.cpp @@ -0,0 +1,120 @@ +/** + * \file dnn/src/arm_common/elemwise/unary/algo.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ +#include "src/arm_common/elemwise/unary/algo.h" +#include "src/arm_common/elemwise_op.h" + +#include "src/common/utils.h" +#include "src/naive/handle.h" + +#include "midout.h" + +MIDOUT_DECL(megdnn_arm_common_elemwise_unary) + +using namespace megdnn; +using namespace arm_common; + +bool ElemwiseImpl::AlgoUnary::is_available(const KernParam& kern_param) const { + if (BcastType::VEC != kern_param.broad_cast_type) + return false; + + if (kern_param.m_dst->layout.dtype.category() != DTypeCategory::FLOAT && + (kern_param.mode == Mode::EXP || kern_param.mode == Mode::SIGMOID || + kern_param.mode == Mode::TANH || kern_param.mode == Mode::FAST_TANH || + kern_param.mode == Mode::H_SWISH)) { + return false; + } + //! As `NEGATE` mode is so simple, that the code generate by compiler is + //! vectorized optimized, while other mode such as `ABS` has branch, the + //! compiler may not generate code as good as user intrinsic. + if (kern_param.mode == Mode::NEGATE) { + return false; + } + + auto& elparam = kern_param.unary_elparam; + if (!elparam[0].layout.is_contiguous()) + return false; + megdnn_assert(elparam[0].layout.ndim == 1); + auto& src0 = elparam[0]; + +#define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \ + auto mode = kern_param.mode; \ + if (mode == Mode::RELU || mode == Mode::ABS || mode == Mode::SIGMOID || \ + mode == Mode::EXP || mode == Mode::TANH || mode == Mode::FAST_TANH || \ + mode == Mode::H_SWISH) \ + return true; + +#define DISPATCH_MODE_INT(_case, _type, _type_midout_id) \ + auto mode = kern_param.mode; \ + if (mode == Mode::RELU || mode == Mode::ABS) \ + return true; + + DISPATCH_TYPE("AlgoUnary::is_available"_hash); + return false; +#undef DISPATCH_MODE_FLOAT +#undef DISPATCH_MODE_INT +} + +void ElemwiseImpl::AlgoUnary::exec(const KernParam& kern_param) const { +#define DISPATCH_UNARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN(megdnn_arm_common_elemwise_unary, midout_iv(_case), \ + midout_iv(Mode::_mode), midout_iv(_type_midout_id)) { \ + thin_function \ + run = OpCallerUnary<_op<_type, _type>, \ + BcastType::VEC>::run; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(kern_param.handle), \ + run(static_cast(src0.raw_ptr), \ + static_cast<_type*>(dst_tensor.raw_ptr), \ + src0.layout.dtype, dst_tensor.layout.dtype, \ + nr_elems)); \ + } \ + MIDOUT_END(); \ + return + + auto& elparam = kern_param.unary_elparam; + megdnn_assert(elparam[0].layout.ndim == 1); + auto& src0 = elparam[0]; + auto& dst_tensor = *(kern_param.m_dst); + + size_t nr_elems = src0.layout.total_nr_elems(); + +#define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \ + switch (kern_param.mode) { \ + DISPATCH_UNARY(RELU, _case, _type, _type_midout_id, ReluOp); \ + DISPATCH_UNARY(ABS, _case, _type, _type_midout_id, AbsOp); \ + DISPATCH_UNARY(SIGMOID, _case, _type, _type_midout_id, SigmoidOp); \ + DISPATCH_UNARY(EXP, _case, _type, _type_midout_id, ExpOp); \ + DISPATCH_UNARY(TANH, _case, _type, _type_midout_id, TanhOp); \ + DISPATCH_UNARY(FAST_TANH, _case, _type, _type_midout_id, FastTanhOp); \ + DISPATCH_UNARY(H_SWISH, _case, _type, _type_midout_id, HSwishOp); \ + default: \ + megdnn_throw(ssprintf("No avaiable algo find for: %d", \ + static_cast(kern_param.mode))); \ + } + +#define DISPATCH_MODE_INT(_case, _type, _type_midout_id) \ + switch (kern_param.mode) { \ + DISPATCH_UNARY(RELU, _case, _type, _type_midout_id, ReluOp); \ + DISPATCH_UNARY(ABS, _case, _type, _type_midout_id, AbsOp); \ + default: \ + megdnn_throw(ssprintf("No avaiable algo find for: %d", \ + 
static_cast(kern_param.mode))); \ + } + + DISPATCH_TYPE("AlgoUnary::exec"_hash); +#undef DISPATCH_MODE_FLOAT +#undef DISPATCH_MODE_INT +#undef DISPATCH_UNARY +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/elemwise/unary/algo.h b/dnn/src/arm_common/elemwise/unary/algo.h new file mode 100644 index 00000000..881162f6 --- /dev/null +++ b/dnn/src/arm_common/elemwise/unary/algo.h @@ -0,0 +1,33 @@ +/** + * \file dnn/src/arm_common/elemwise/unary/algo.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#pragma once +#include "src/arm_common/elemwise/opr_impl.h" +namespace megdnn { +namespace arm_common { +class ElemwiseImpl::AlgoUnary final : public ElemwiseImpl::AlgoBase { + mutable std::string m_name; + + bool is_reproducible() const override { return true; } + const char* name() const override { + if (m_name.empty()) { + m_name = megdnn_mangle(ssprintf("Elemwise::AlgoUnary")); + } + return m_name.c_str(); + } + + bool is_available(const KernParam&) const override; + void exec(const KernParam&) const override; +}; + +} // namespace arm_common +} // namespace megdnn + // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/elemwise_helper/kimpl/abs.h b/dnn/src/arm_common/elemwise_helper/kimpl/abs.h new file mode 100644 index 00000000..14e56714 --- /dev/null +++ b/dnn/src/arm_common/elemwise_helper/kimpl/abs.h @@ -0,0 +1,124 @@ +/** + * \file dnn/src/arm_common/elemwise_helper/kimpl/abs.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once + +#include "src/arm_common/elemwise_helper/kimpl/op_base.h" + +namespace megdnn { +namespace arm_common { + +template +struct AbsOpBase : UnaryOpBase { + using UnaryOpBase::UnaryOpBase; + void operator()(const src_ctype& src, dst_ctype* dst) const { + *dst = operator()(src); + } + dst_ctype operator()(const src_ctype& src) const { + return src > 0 ? 
src : (-src); + } +}; + +template +struct AbsOp; + +#define OP(_ctype, _neon_type, _func_suffix, _simd_width) \ + template <> \ + struct AbsOp<_ctype> : AbsOpBase<_ctype> { \ + using AbsOpBase::AbsOpBase; \ + using AbsOpBase::operator(); \ + constexpr static size_t SIMD_WIDTH = _simd_width; \ + void operator()(const _neon_type& src, _ctype* dst) const { \ + auto vitem = operator()(src); \ + vst1q_##_func_suffix(dst, vitem.val[0]); \ + vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ + } \ + _neon_type operator()(const _neon_type& src) const { \ + auto vitem0 = vabsq_##_func_suffix(src.val[0]); \ + auto vitem1 = vabsq_##_func_suffix(src.val[1]); \ + return {{vitem0, vitem1}}; \ + } \ + }; +OP(dt_float32, float32x4x2_t, f32, 4) +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +OP(__fp16, float16x8x2_t, f16, 8) +#endif +OP(dt_int32, int32x4x2_t, s32, 4) +OP(dt_int16, int16x8x2_t, s16, 8) +OP(dt_int8, int8x16x2_t, s8, 16) +#undef OP + +template <> +struct AbsOpBase : UnaryOpBase { + using UnaryOpBase::UnaryOpBase; + void operator()(const dt_qint8& src, dt_qint8* dst) const { + *dst = operator()(src); + } + dt_qint8 operator()(const dt_qint8& src) const { + float fsrc = src.as_int8() * this->scale; + fsrc = fsrc > 0 ? fsrc : -fsrc; + return QConverter::convert(fsrc); + } +}; + +template <> +struct AbsOpBase : UnaryOpBase { + using UnaryOpBase::UnaryOpBase; + void operator()(const dt_quint8& src, dt_quint8* dst) const { + *dst = operator()(src); + } + dt_quint8 operator()(const dt_quint8& src) const { + float fsrc = src.as_uint8() * this->scale - this->szp; + fsrc = fsrc > 0 ? fsrc : -fsrc; + return QConverter::convert(fsrc, this->dzp); + } +}; + +template <> +struct AbsOp : AbsOpBase { + using AbsOpBase::AbsOpBase; + constexpr static size_t SIMD_WIDTH = 16; + using AbsOpBase::operator(); + void operator()(const int8x16x2_t& vsrc, dt_qint8* dst) const { + OPERATOR_UNARY_QINT8; + } + int8x8_t operator()(const int32x4x2_t& vsrc) const { + auto vitem0 = vmulq_f32(vcvtq_f32_s32(vsrc.val[0]), this->vscale); + auto vitem1 = vmulq_f32(vcvtq_f32_s32(vsrc.val[1]), this->vscale); + vitem0 = vabsq_f32(vitem0); + vitem1 = vabsq_f32(vitem1); + return QConverter::convert({{vitem0, vitem1}}); + } +}; + +template <> +struct AbsOp : AbsOpBase { + using AbsOpBase::AbsOpBase; + constexpr static size_t SIMD_WIDTH = 16; + using AbsOpBase::operator(); + void operator()(const uint8x16x2_t& vsrc, dt_quint8* dst) const { + OPERATOR_UNARY_QUINT8; + } + uint8x8_t operator()(const uint32x4x2_t& vsrc) const { + auto vitem0 = vmulq_f32(vcvtq_f32_u32(vsrc.val[0]), this->vscale); + auto vitem1 = vmulq_f32(vcvtq_f32_u32(vsrc.val[1]), this->vscale); + vitem0 = vsubq_f32(vitem0, this->vszp); + vitem1 = vsubq_f32(vitem1, this->vszp); + vitem0 = vabsq_f32(vitem0); + vitem1 = vabsq_f32(vitem1); + return QConverter::convert( + {{vitem0, vitem1}}, this->vdzp); + } +}; + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/elemwise_helper/kimpl/add.h b/dnn/src/arm_common/elemwise_helper/kimpl/add.h new file mode 100644 index 00000000..a6462100 --- /dev/null +++ b/dnn/src/arm_common/elemwise_helper/kimpl/add.h @@ -0,0 +1,314 @@ +/** + * \file dnn/src/arm_common/elemwise_helper/kimpl/add.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once + +#include "src/arm_common/elemwise_helper/kimpl/op_base.h" + +namespace megdnn { +namespace arm_common { + +template +struct AddOpBase : BinaryOpBase { + using BinaryOpBase::BinaryOpBase; + void operator()(const src_ctype& src0, const src_ctype& src1, + dst_ctype* dst) const { + *dst = operator()(src0, src1); + } + + dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { + return src0 + src1; + } +}; + +template +struct AddOp; + +#define OP(_ctype, _neon_type, _neon_type2, _func_suffix, _simd_width) \ + template <> \ + struct AddOp<_ctype> : AddOpBase<_ctype> { \ + using AddOpBase::AddOpBase; \ + using AddOpBase::operator(); \ + constexpr static size_t SIMD_WIDTH = _simd_width; \ + void operator()(const _neon_type2& src0, const _neon_type2& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + vst1q_##_func_suffix(dst, vitem.val[0]); \ + vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ + } \ + _neon_type2 operator()(const _neon_type2& src0, \ + const _neon_type2& src1) const { \ + auto vitem0 = vaddq_##_func_suffix(src0.val[0], src1.val[0]); \ + auto vitem1 = vaddq_##_func_suffix(src0.val[1], src1.val[1]); \ + return {{vitem0, vitem1}}; \ + } \ + void operator()(const _neon_type& src0, const _neon_type& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + vst1q_##_func_suffix(dst, vitem); \ + } \ + _neon_type operator()(const _neon_type& src0, \ + const _neon_type& src1) const { \ + return vaddq_##_func_suffix(src0, src1); \ + } \ + }; +OP(dt_float32, float32x4_t, float32x4x2_t, f32, 4) +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +OP(__fp16, float16x8_t, float16x8x2_t, f16, 8) +#endif +OP(dt_int32, int32x4_t, int32x4x2_t, s32, 4) +OP(dt_int16, int16x8_t, int16x8x2_t, s16, 8) +OP(dt_int8, int8x16_t, int8x16x2_t, s8, 16) +#undef OP + +template <> +struct AddOpBase : BinaryOpBase { + using BinaryOpBase::BinaryOpBase; + void operator()(const dt_qint8& src0, const dt_qint8& src1, + dt_qint8* dst) const { + *dst = operator()(src0, src1); + } + + dt_qint8 operator()(const dt_qint8& src0, const dt_qint8& src1) const { + return QConverter::convert( + src0.as_int8() * this->scale0 + src1.as_int8() * this->scale1); + } +}; + +template <> +struct AddOpBase : BinaryOpBase { + float szp; + float32x4_t vszp; + + AddOpBase(DType src0_dtype, DType src1_dtype, DType dst_dtype) + : BinaryOpBase(src0_dtype, src1_dtype, dst_dtype) { + szp = this->szp0 + this->szp1; + vszp = vdupq_n_f32(szp); + } + void operator()(const dt_quint8& src0, const dt_quint8& src1, + dt_quint8* dst) const { + *dst = operator()(src0, src1); + } + + dt_quint8 operator()(const dt_quint8& src0, const dt_quint8& src1) const { + return QConverter::convert( + src0.as_uint8() * this->scale0 + + src1.as_uint8() * this->scale1 - this->szp, + this->dzp); + } +}; + +template <> +struct AddOp : AddOpBase { + using AddOpBase::AddOpBase; + constexpr static size_t SIMD_WIDTH = 16; + using AddOpBase::operator(); + + void operator()(const int8x16x2_t& vsrc0, const int8x16x2_t& vsrc1, + dt_qint8* dst) const { + OPERATOR_BINARY_QINT8; + } + + int8x8_t operator()(const int32x4x2_t& vsrc0, + const int32x4x2_t& vsrc1) const { + auto vitem0 = vaddq_f32( + vmulq_f32(vcvtq_f32_s32(vsrc0.val[0]), this->vscale0), + 
vmulq_f32(vcvtq_f32_s32(vsrc1.val[0]), this->vscale1)); + auto vitem1 = vaddq_f32( + vmulq_f32(vcvtq_f32_s32(vsrc0.val[1]), this->vscale0), + vmulq_f32(vcvtq_f32_s32(vsrc1.val[1]), this->vscale1)); + + return QConverter::convert({{vitem0, vitem1}}); + } +}; + +template <> +struct AddOp : AddOpBase { + using AddOpBase::AddOpBase; + constexpr static size_t SIMD_WIDTH = 16; + using AddOpBase::operator(); + + void operator()(const uint8x16x2_t& vsrc0, const uint8x16x2_t& vsrc1, + dt_quint8* dst) const { + OPERATOR_BINARY_QUINT8; + } + + uint8x8_t operator()(const uint32x4x2_t& vsrc0, + const uint32x4x2_t& vsrc1) const { + auto vitem0 = vsubq_f32( + vaddq_f32( + vmulq_f32(vcvtq_f32_u32(vsrc0.val[0]), this->vscale0), + vmulq_f32(vcvtq_f32_u32(vsrc1.val[0]), this->vscale1)), + this->vszp); + auto vitem1 = vsubq_f32( + vaddq_f32( + vmulq_f32(vcvtq_f32_u32(vsrc0.val[1]), this->vscale0), + vmulq_f32(vcvtq_f32_u32(vsrc1.val[1]), this->vscale1)), + this->vszp); + + return QConverter::convert( + {{vitem0, vitem1}}, this->vdzp); + } +}; + +template <> +struct AddOpBase : BinaryOpBase { + using BinaryOpBase::BinaryOpBase; + void operator()(const dt_qint32& src0, const dt_qint32& src1, + dt_qint8* dst) const { + *dst = operator()(src0, src1); + } + + dt_qint8 operator()(const dt_qint32& src0, const dt_qint32& src1) const { + return QConverter::convert( + src0.as_int32() * this->scale0 + + src1.as_int32() * this->scale1); + } +}; + +template <> +struct AddOpBase : BinaryOpBase { + using BinaryOpBase::BinaryOpBase; + void operator()(const dt_qint32& src0, const dt_qint32& src1, + dt_quint8* dst) const { + *dst = operator()(src0, src1); + } + + dt_quint8 operator()(const dt_qint32& src0, const dt_qint32& src1) const { + return QConverter::convert( + src0.as_int32() * this->scale0 + src1.as_int32() * this->scale1, + zp); + } +}; + +#if MEGDNN_AARCH64 +template +struct AddOp + : AddOpBase { + using AddOpBase::AddOpBase; + using AddOpBase::operator(); + constexpr static size_t SIMD_WIDTH = 4; + + void operator()(const int32x4x2_t& vsrc0, const int32x4x2_t& vsrc1, + dt_qint8* dst) const { + vst1_s8(reinterpret_cast(dst), operator()(vsrc0, vsrc1)); + } + + int8x8_t operator()(const int32x4x2_t& vsrc0, + const int32x4x2_t& vsrc1) const { + if (enable_opt_or_fixup) { + auto vitem0 = vmulq_f32( + vcvtq_f32_s32(vaddq_s32(vsrc0.val[0], vsrc1.val[0])), + this->vscale0); + auto vitem1 = vmulq_f32( + vcvtq_f32_s32(vaddq_s32(vsrc0.val[1], vsrc1.val[1])), + this->vscale0); + return QConverter::convert( + {{vitem0, vitem1}}); + } else { + auto vitem0 = vaddq_f32( + vmulq_f32(vcvtq_f32_s32(vsrc0.val[0]), this->vscale0), + vmulq_f32(vcvtq_f32_s32(vsrc1.val[0]), this->vscale1)); + auto vitem1 = vaddq_f32( + vmulq_f32(vcvtq_f32_s32(vsrc0.val[1]), this->vscale0), + vmulq_f32(vcvtq_f32_s32(vsrc1.val[1]), this->vscale1)); + return QConverter::convert( + {{vitem0, vitem1}}); + } + } +}; +#else +template +struct AddOp + : AddOpBase, FixupBase { + using AddOpBase::operator(); + constexpr static size_t SIMD_WIDTH = 4; + + AddOp(DType src0_dtype, DType src1_dtype, DType dst_dtype) + : AddOpBase(src0_dtype, src1_dtype, dst_dtype), FixupBase(scale0) {} + + AddOp(float src0_scale, float src1_scale, float dst_scale) + : AddOpBase(src0_scale, src1_scale, dst_scale), FixupBase(scale0) {} + + void operator()(const int32x4x2_t& vsrc0, const int32x4x2_t& vsrc1, + dt_qint8* dst) const { + vst1_s8(reinterpret_cast(dst), operator()(vsrc0, vsrc1)); + } + + int8x8_t operator()(const int32x4x2_t& vsrc0, + const int32x4x2_t& vsrc1) const { + if 
(enable_opt_or_fixup) { + auto vitem0 = vqrdmulhq_s32(vaddq_s32(vsrc0.val[0], vsrc1.val[0]), + vmultiplier); + auto vitem1 = vqrdmulhq_s32(vaddq_s32(vsrc0.val[1], vsrc1.val[1]), + vmultiplier); + // FIXME Theoretically, we should check shift != 0 here. + auto fixup0 = vshrq_n_s32(vitem0, 31); + auto fixup1 = vshrq_n_s32(vitem1, 31); + vitem0 = vqaddq_s32(vitem0, fixup0); + vitem1 = vqaddq_s32(vitem1, fixup1); + return vqmovn_s16( + vcombine_s16(vqmovn_s32(vrshlq_s32(vitem0, vshift)), + vqmovn_s32(vrshlq_s32(vitem1, vshift)))); + } else { + auto vitem0 = vaddq_f32( + vmulq_f32(vcvtq_f32_s32(vsrc0.val[0]), this->vscale0), + vmulq_f32(vcvtq_f32_s32(vsrc1.val[0]), this->vscale1)); + auto vitem1 = vaddq_f32( + vmulq_f32(vcvtq_f32_s32(vsrc0.val[1]), this->vscale0), + vmulq_f32(vcvtq_f32_s32(vsrc1.val[1]), this->vscale1)); + return QConverter::convert( + {{vitem0, vitem1}}); + } + } +}; +#endif + +template +struct AddOp + : AddOpBase { + using AddOpBase::AddOpBase; + constexpr static size_t SIMD_WIDTH = 4; + using AddOpBase::operator(); + + void operator()(const int32x4x2_t& vsrc0, const int32x4x2_t& vsrc1, + dt_quint8* dst) const { + vst1_u8(reinterpret_cast(dst), operator()(vsrc0, vsrc1)); + } + + uint8x8_t operator()(const int32x4x2_t& vsrc0, + const int32x4x2_t& vsrc1) const { + if (enable_opt_or_fixup) { + auto vitem0 = vmulq_f32( + vcvtq_f32_s32(vaddq_s32(vsrc0.val[0], vsrc1.val[0])), + this->vscale0); + auto vitem1 = vmulq_f32( + vcvtq_f32_s32(vaddq_s32(vsrc0.val[1], vsrc1.val[1])), + this->vscale0); + return QConverter::convert( + {{vitem0, vitem1}}, this->vzp); + } else { + auto vitem0 = vaddq_f32( + vmulq_f32(vcvtq_f32_s32(vsrc0.val[0]), this->vscale0), + vmulq_f32(vcvtq_f32_s32(vsrc1.val[0]), this->vscale1)); + auto vitem1 = vaddq_f32( + vmulq_f32(vcvtq_f32_s32(vsrc0.val[1]), this->vscale0), + vmulq_f32(vcvtq_f32_s32(vsrc1.val[1]), this->vscale1)); + return QConverter::convert( + {{vitem0, vitem1}}, this->vzp); + } + } +}; + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/elemwise_helper/kimpl/exp.h b/dnn/src/arm_common/elemwise_helper/kimpl/exp.h new file mode 100644 index 00000000..6dfea966 --- /dev/null +++ b/dnn/src/arm_common/elemwise_helper/kimpl/exp.h @@ -0,0 +1,59 @@ +/** + * \file dnn/src/arm_common/elemwise_helper/kimpl/exp.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +#pragma once + +#include "src/arm_common/elemwise_helper/kimpl/op_base.h" + +namespace megdnn { +namespace arm_common { + +template +struct ExpOpBase : UnaryOpBase { + using UnaryOpBase::UnaryOpBase; + void operator()(const src_ctype& src, dst_ctype* dst) const { + *dst = operator()(src); + } + dst_ctype operator()(const src_ctype& src) const { + float tmp = src; + return exp(tmp); + } +}; + +template +struct ExpOp; + +#define OP(_ctype, _neon_type, _func_suffix, _simd_width) \ + template <> \ + struct ExpOp<_ctype> : ExpOpBase<_ctype> { \ + using ExpOpBase::ExpOpBase; \ + using ExpOpBase::operator(); \ + constexpr static size_t SIMD_WIDTH = _simd_width; \ + void operator()(const _neon_type& src, _ctype* dst) const { \ + auto vitem = operator()(src); \ + vst1q_##_func_suffix(dst, vitem.val[0]); \ + vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ + } \ + _neon_type operator()(const _neon_type& src) const { \ + auto vitem0 = exp_ps_##_func_suffix(src.val[0]); \ + auto vitem1 = exp_ps_##_func_suffix(src.val[1]); \ + return {{vitem0, vitem1}}; \ + } \ + }; +OP(dt_float32, float32x4x2_t, f32, 4) +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +OP(__fp16, float16x8x2_t, f16, 8) +#endif +#undef OP + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/elemwise_helper/kimpl/fast_tanh.h b/dnn/src/arm_common/elemwise_helper/kimpl/fast_tanh.h new file mode 100644 index 00000000..fcd5ebbf --- /dev/null +++ b/dnn/src/arm_common/elemwise_helper/kimpl/fast_tanh.h @@ -0,0 +1,80 @@ +/** + * \file dnn/src/arm_common/elemwise_helper/kimpl/fast_tanh.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once + +#include "src/arm_common/elemwise_helper/kimpl/op_base.h" + +namespace megdnn { +namespace arm_common { + +//! 
tanh = x * (27 + x^2) / (27 + 9 * x^2) +template +struct FastTanhOpBase : UnaryOpBase { + using UnaryOpBase::UnaryOpBase; + void operator()(const src_ctype& src, dst_ctype* dst) const { + *dst = operator()(src); + } + dst_ctype operator()(const src_ctype& src) const { + float x = src; + return x * (27.f + x * x) / (27.f + 9.f * x * x); + } +}; + +template +struct FastTanhOp; + +#define OP(_ctype, _neon_type, _func_suffix, _fix_func_suffix, _simd_width) \ + template <> \ + struct FastTanhOp<_ctype> : FastTanhOpBase<_ctype> { \ + using FastTanhOpBase::FastTanhOpBase; \ + using FastTanhOpBase::operator(); \ + constexpr static size_t SIMD_WIDTH = _simd_width; \ + void operator()(const _neon_type& src, _ctype* dst) const { \ + auto vitem = operator()(src); \ + vst1q_##_func_suffix(dst, vitem.val[0]); \ + vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ + } \ + _neon_type operator()(const _neon_type& src) const { \ + auto val_27 = vdupq_n_##_func_suffix(27.f); \ + auto val_9 = vdupq_n_##_func_suffix(9.f); \ + auto valx = src.val[0]; \ + auto valx1 = src.val[1]; \ + auto valxp2 = vmulq_##_fix_func_suffix(valx, valx); \ + auto valx1p2 = vmulq_##_fix_func_suffix(valx1, valx1); \ + auto denominator = vaddq_##_fix_func_suffix(valxp2, val_27); \ + auto denominator1 = vaddq_##_fix_func_suffix(valx1p2, val_27); \ + valx = vmulq_##_fix_func_suffix(valx, denominator); \ + valx1 = vmulq_##_fix_func_suffix(valx1, denominator1); \ + denominator = vmlaq_##_fix_func_suffix(val_27, valxp2, val_9); \ + denominator1 = vmlaq_##_fix_func_suffix(val_27, valx1p2, val_9); \ + auto r_denominator = vrecpeq_##_func_suffix(denominator); \ + auto r_denominator1 = vrecpeq_##_func_suffix(denominator1); \ + r_denominator = vmulq_##_fix_func_suffix( \ + vrecpsq_##_func_suffix(denominator, r_denominator), \ + r_denominator); \ + r_denominator1 = vmulq_##_fix_func_suffix( \ + vrecpsq_##_func_suffix(denominator1, r_denominator1), \ + r_denominator1); \ + valx = vmulq_##_fix_func_suffix(valx, r_denominator); \ + valx1 = vmulq_##_fix_func_suffix(valx1, r_denominator1); \ + return {{valx, valx1}}; \ + } \ + }; +OP(dt_float32, float32x4x2_t, f32, f32, 4) +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +OP(__fp16, float16x8x2_t, f16, fix_f16, 8) +#endif +#undef OP + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/elemwise_helper/kimpl/fuse_add_h_swish.h b/dnn/src/arm_common/elemwise_helper/kimpl/fuse_add_h_swish.h new file mode 100644 index 00000000..a1f606a1 --- /dev/null +++ b/dnn/src/arm_common/elemwise_helper/kimpl/fuse_add_h_swish.h @@ -0,0 +1,197 @@ +/** + * \file dnn/src/arm_common/elemwise_helper/kimpl/fuse_add_h_swish.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +#pragma once + +#include "src/arm_common/elemwise_helper/kimpl/op_base.h" +#include "src/arm_common/elemwise_helper/kimpl/kern_macro_prologue.h" + +namespace megdnn { +namespace arm_common { + +template +struct FuseAddHSwishOpBase : BinaryOpBase { + using BinaryOpBase::BinaryOpBase; + void operator()(const src_ctype& src0, const src_ctype& src1, + dst_ctype* dst) const { + *dst = operator()(src0, src1); + } + dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { + float tmp = src0 + src1; + tmp = tmp * std::max(std::min(tmp + 3.f, 6.f), 0.f) / 6.f; + return tmp; + } +}; + +template +struct FuseAddHSwishOp; + +#define OP(_ctype, _neon_type, _neon_type2, _func_suffix, _simd_width) \ + template <> \ + struct FuseAddHSwishOp<_ctype> : FuseAddHSwishOpBase<_ctype> { \ + using FuseAddHSwishOpBase::FuseAddHSwishOpBase; \ + using FuseAddHSwishOpBase::operator(); \ + constexpr static size_t SIMD_WIDTH = _simd_width; \ + void operator()(const _neon_type2& src0, const _neon_type2& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + vst1q_##_func_suffix(dst, vitem.val[0]); \ + vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ + } \ + _neon_type2 operator()(const _neon_type2& src0, \ + const _neon_type2& src1) const { \ + auto val1 = src0.val[0]; \ + auto val2 = src0.val[1]; \ + auto val3 = src1.val[0]; \ + auto val4 = src1.val[1]; \ + val1 = vaddq_##_func_suffix(val1, val3); \ + val2 = vaddq_##_func_suffix(val2, val4); \ + H_SWISH_KERN(_func_suffix, val1, val2); \ + return {{val1, val2}}; \ + } \ + void operator()(const _neon_type& src0, const _neon_type& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + vst1q_##_func_suffix(dst, vitem); \ + } \ + _neon_type operator()(const _neon_type& src0, \ + const _neon_type& src1) const { \ + auto val1 = src0; \ + auto val2 = src1; \ + val1 = vaddq_##_func_suffix(val1, val2); \ + H_SWISH_KERN_N1(_func_suffix, val1); \ + return val1; \ + } \ + }; +OP(dt_float32, float32x4_t, float32x4x2_t, f32, 4) +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +OP(__fp16, float16x8_t, float16x8x2_t, f16, 8) +#endif +#undef OP + +template <> +struct FuseAddHSwishOpBase + : BinaryOpBase { + using BinaryOpBase::BinaryOpBase; + void operator()(const dt_qint32& src0, const dt_qint32& src1, + dt_qint8* dst) const { + *dst = operator()(src0, src1); + } + + dt_qint8 operator()(const dt_qint32& src0, const dt_qint32& src1) const { + float tmp = src0.as_int32() * this->scale_src0 + + src1.as_int32() * this->scale_src1; + tmp = tmp * std::max(std::min(tmp + 3.f, 6.f), 0.f) / 6.f; + tmp *= this->scale_dst; + return QConverter::convert(tmp); + } +}; + +template <> +struct FuseAddHSwishOpBase + : BinaryOpBase { + using BinaryOpBase::BinaryOpBase; + void operator()(const dt_qint32& src0, const dt_qint32& src1, + dt_quint8* dst) const { + *dst = operator()(src0, src1); + } + + dt_quint8 operator()(const dt_qint32& src0, const dt_qint32& src1) const { + float tmp = src0.as_int32() * this->scale_src0 + + src1.as_int32() * this->scale_src1; + tmp = tmp * std::max(std::min(tmp + 3.f, 6.f), 0.f) / 6.f; + tmp *= this->scale_dst; + return QConverter::convert(tmp, zp); + } +}; + +template +struct FuseAddHSwishOp + : FuseAddHSwishOpBase { + using FuseAddHSwishOpBase::FuseAddHSwishOpBase; + using FuseAddHSwishOpBase::operator(); + constexpr static size_t SIMD_WIDTH = 4; + void operator()(const int32x4x2_t& vsrc0, const int32x4x2_t& vsrc1, + dt_qint8* dst) const { + vst1_s8(reinterpret_cast(dst), operator()(vsrc0, vsrc1)); + 
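+        // The arithmetic overload below dequantizes the two qint32 inputs,
+        // applies x * relu6(x + 3) / 6 via H_SWISH_KERN, rescales by
+        // vscale_dst and saturates back to int8; the enable_opt_or_fixup path
+        // adds the raw int32 values first and uses a single input scale,
+        // which presumes the two inputs share that scale.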
} + + int8x8_t operator()(const int32x4x2_t& vsrc0, + const int32x4x2_t& vsrc1) const { + float32x4_t vitem0, vitem1; + if (enable_opt_or_fixup) { + vitem0 = vmulq_f32( + vcvtq_f32_s32(vaddq_s32(vsrc0.val[0], vsrc1.val[0])), + this->vscale_src0); + vitem1 = vmulq_f32( + vcvtq_f32_s32(vaddq_s32(vsrc0.val[1], vsrc1.val[1])), + this->vscale_src0); + + } else { + vitem0 = vaddq_f32( + vmulq_f32(vcvtq_f32_s32(vsrc0.val[0]), this->vscale_src0), + vmulq_f32(vcvtq_f32_s32(vsrc1.val[0]), this->vscale_src1)); + vitem1 = vaddq_f32( + vmulq_f32(vcvtq_f32_s32(vsrc0.val[1]), this->vscale_src0), + vmulq_f32(vcvtq_f32_s32(vsrc1.val[1]), this->vscale_src1)); + } + H_SWISH_KERN(f32, vitem0, vitem1); + vitem0 = vmulq_f32(vitem0, this->vscale_dst); + vitem1 = vmulq_f32(vitem1, this->vscale_dst); + return QConverter::convert({{vitem0, vitem1}}); + } +}; + +template +struct FuseAddHSwishOp + : FuseAddHSwishOpBase { + using FuseAddHSwishOpBase::FuseAddHSwishOpBase; + using FuseAddHSwishOpBase::operator(); + constexpr static size_t SIMD_WIDTH = 4; + void operator()(const int32x4x2_t& vsrc0, const int32x4x2_t& vsrc1, + dt_quint8* dst) const { + vst1_u8(reinterpret_cast(dst), operator()(vsrc0, vsrc1)); + } + + uint8x8_t operator()(const int32x4x2_t& vsrc0, + const int32x4x2_t& vsrc1) const { + float32x4_t vitem0, vitem1; + if (enable_opt_or_fixup) { + vitem0 = vmulq_f32( + vcvtq_f32_s32(vaddq_s32(vsrc0.val[0], vsrc1.val[0])), + this->vscale_src0); + vitem1 = vmulq_f32( + vcvtq_f32_s32(vaddq_s32(vsrc0.val[1], vsrc1.val[1])), + this->vscale_src0); + + } else { + vitem0 = vaddq_f32( + vmulq_f32(vcvtq_f32_s32(vsrc0.val[0]), this->vscale_src0), + vmulq_f32(vcvtq_f32_s32(vsrc1.val[0]), this->vscale_src1)); + vitem1 = vaddq_f32( + vmulq_f32(vcvtq_f32_s32(vsrc0.val[1]), this->vscale_src0), + vmulq_f32(vcvtq_f32_s32(vsrc1.val[1]), this->vscale_src1)); + } + + H_SWISH_KERN(f32, vitem0, vitem1); + vitem0 = vmulq_f32(vitem0, this->vscale_dst); + vitem1 = vmulq_f32(vitem1, this->vscale_dst); + return QConverter::convert({{vitem0, vitem1}}, + this->vzp); + } +}; + +#include "src/arm_common/elemwise_helper/kimpl/kern_macro_epilogue.h" + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/elemwise_helper/kimpl/fuse_add_relu.h b/dnn/src/arm_common/elemwise_helper/kimpl/fuse_add_relu.h new file mode 100644 index 00000000..b5325ffe --- /dev/null +++ b/dnn/src/arm_common/elemwise_helper/kimpl/fuse_add_relu.h @@ -0,0 +1,363 @@ +/** + * \file dnn/src/arm_common/elemwise_helper/kimpl/fuse_add_relu.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once + +#include "src/arm_common/elemwise_helper/kimpl/op_base.h" +#include "src/arm_common/elemwise/neon_util_impl_helper.h" + +namespace megdnn { +namespace arm_common { + +template +struct FuseAddReluOpBase : BinaryOpBase { + using BinaryOpBase::BinaryOpBase; + void operator()(const src_ctype& src0, const src_ctype& src1, + dst_ctype* dst) const { + *dst = operator()(src0, src1); + } + dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { + auto tmp = src0 + src1; + return tmp > 0 ? 
tmp : 0; + } +}; + +template +struct FuseAddReluOp; + +#define OP(_ctype, _neon_type, _neon_type2, _func_suffix, _simd_width) \ + template <> \ + struct FuseAddReluOp<_ctype> : FuseAddReluOpBase<_ctype> { \ + using FuseAddReluOpBase::FuseAddReluOpBase; \ + using FuseAddReluOpBase::operator(); \ + constexpr static size_t SIMD_WIDTH = _simd_width; \ + void operator()(const _neon_type2& src0, const _neon_type2& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + vst1q_##_func_suffix(dst, vitem.val[0]); \ + vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ + } \ + _neon_type2 operator()(const _neon_type2& src0, \ + const _neon_type2& src1) const { \ + auto val1 = src0.val[0]; \ + auto val2 = src0.val[1]; \ + auto val3 = src1.val[0]; \ + auto val4 = src1.val[1]; \ + FUSE_ADD_RELU_NEON_PACK2(val1, val2, val3, val4, _func_suffix); \ + return {{val1, val2}}; \ + } \ + void operator()(const _neon_type& src0, const _neon_type& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + vst1q_##_func_suffix(dst, vitem); \ + } \ + _neon_type operator()(const _neon_type& src0, \ + const _neon_type& src1) const { \ + auto val1 = src0; \ + auto val2 = src1; \ + FUSE_ADD_RELU_NEON_PACK(val1, val2, _func_suffix); \ + return val1; \ + } \ + }; +OP(dt_float32, float32x4_t, float32x4x2_t, f32, 4) +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +OP(__fp16, float16x8_t, float16x8x2_t, f16, 8) +#endif +OP(dt_int32, int32x4_t, int32x4x2_t, s32, 4) +OP(dt_int16, int16x8_t, int16x8x2_t, s16, 8) +OP(dt_int8, int8x16_t, int8x16x2_t, s8, 16) +#undef OP + +template +struct FuseAddReluOpCommon; + +template <> +struct FuseAddReluOpCommon { + inline static float32x4_t vzero() { return vdupq_n_f32(0); } +}; + +template <> +struct FuseAddReluOpCommon { + inline static int32x4_t vzero() { return vdupq_n_s32(0); } +}; + +template <> +struct FuseAddReluOpBase + : BinaryOpBase { + using BinaryOpBase::BinaryOpBase; + void operator()(const dt_qint8& src0, const dt_qint8& src1, + dt_qint8* dst) const { + *dst = operator()(src0, src1); + } + + dt_qint8 operator()(const dt_qint8& src0, const dt_qint8& src1) const { + return QConverter::convert(std::max( + src0.as_int8() * this->scale0 + src1.as_int8() * this->scale1, + 0.f)); + } +}; + +template <> +struct FuseAddReluOpBase + : BinaryOpBase { + float szp; + float32x4_t vszp; + + FuseAddReluOpBase(DType src0_dtype, DType src1_dtype, DType dst_dtype) + : BinaryOpBase(src0_dtype, src1_dtype, dst_dtype) { + szp = this->szp0 + this->szp1; + vszp = vdupq_n_f32(szp); + } + void operator()(const dt_quint8& src0, const dt_quint8& src1, + dt_quint8* dst) const { + *dst = operator()(src0, src1); + } + + dt_quint8 operator()(const dt_quint8& src0, const dt_quint8& src1) const { + return QConverter::convert( + std::max(src0.as_uint8() * this->scale0 + + src1.as_uint8() * this->scale1 - + this->szp, + 0.f), + this->dzp); + } +}; + +template <> +struct FuseAddReluOp + : FuseAddReluOpBase, FuseAddReluOpCommon { + using FuseAddReluOpBase::FuseAddReluOpBase; + using FuseAddReluOpBase::operator(); + constexpr static size_t SIMD_WIDTH = 16; + + void operator()(const int8x16x2_t& vsrc0, const int8x16x2_t& vsrc1, + dt_qint8* dst) const { + OPERATOR_BINARY_QINT8; + } + + int8x8_t operator()(const int32x4x2_t& vsrc0, + const int32x4x2_t& vsrc1) const { + auto vitem0 = vaddq_f32( + vmulq_f32(vcvtq_f32_s32(vsrc0.val[0]), this->vscale0), + vmulq_f32(vcvtq_f32_s32(vsrc1.val[0]), this->vscale1)); + auto vitem1 = vaddq_f32( + vmulq_f32(vcvtq_f32_s32(vsrc0.val[1]), 
this->vscale0), + vmulq_f32(vcvtq_f32_s32(vsrc1.val[1]), this->vscale1)); + + vitem0 = vmaxq_f32(vitem0, this->vzero()); + vitem1 = vmaxq_f32(vitem1, this->vzero()); + return QConverter::convert({{vitem0, vitem1}}); + } +}; + +template <> +struct FuseAddReluOp + : FuseAddReluOpBase, FuseAddReluOpCommon { + using FuseAddReluOpBase::FuseAddReluOpBase; + using FuseAddReluOpBase::operator(); + constexpr static size_t SIMD_WIDTH = 16; + + void operator()(const uint8x16x2_t& vsrc0, const uint8x16x2_t& vsrc1, + dt_quint8* dst) const { + OPERATOR_BINARY_QUINT8; + } + + uint8x8_t operator()(const uint32x4x2_t& vsrc0, + const uint32x4x2_t& vsrc1) const { + auto vitem0 = vsubq_f32( + vaddq_f32( + vmulq_f32(vcvtq_f32_u32(vsrc0.val[0]), this->vscale0), + vmulq_f32(vcvtq_f32_u32(vsrc1.val[0]), this->vscale1)), + this->vszp); + auto vitem1 = vsubq_f32( + vaddq_f32( + vmulq_f32(vcvtq_f32_u32(vsrc0.val[1]), this->vscale0), + vmulq_f32(vcvtq_f32_u32(vsrc1.val[1]), this->vscale1)), + this->vszp); + + vitem0 = vmaxq_f32(vitem0, this->vzero()); + vitem1 = vmaxq_f32(vitem1, this->vzero()); + return QConverter::convert({{vitem0, vitem1}}, + this->vdzp); + } +}; + +template <> +struct FuseAddReluOpBase + : BinaryOpBase { + using BinaryOpBase::BinaryOpBase; + void operator()(const dt_qint32& src0, const dt_qint32& src1, + dt_qint8* dst) const { + *dst = operator()(src0, src1); + } + + dt_qint8 operator()(const dt_qint32& src0, const dt_qint32& src1) const { + return QConverter::convert(std::max( + src0.as_int32() * this->scale0 + src1.as_int32() * this->scale1, + 0.f)); + } +}; + +template <> +struct FuseAddReluOpBase + : BinaryOpBase { + using BinaryOpBase::BinaryOpBase; + void operator()(const dt_qint32& src0, const dt_qint32& src1, + dt_quint8* dst) const { + *dst = operator()(src0, src1); + } + + dt_quint8 operator()(const dt_qint32& src0, const dt_qint32& src1) const { + return QConverter::convert( + std::max(src0.as_int32() * this->scale0 + + src1.as_int32() * this->scale1, + 0.f), + zp); + } +}; + +#if MEGDNN_AARCH64 +template +struct FuseAddReluOp + : FuseAddReluOpBase, FuseAddReluOpCommon { + using FuseAddReluOpBase::FuseAddReluOpBase; + using FuseAddReluOpBase::operator(); + constexpr static size_t SIMD_WIDTH = 4; + void operator()(const int32x4x2_t& vsrc0, const int32x4x2_t& vsrc1, + dt_qint8* dst) const { + vst1_s8(reinterpret_cast(dst), operator()(vsrc0, vsrc1)); + } + + int8x8_t operator()(const int32x4x2_t& vsrc0, + const int32x4x2_t& vsrc1) const { + if (enable_opt_or_fixup) { + auto vitem0 = vmulq_f32( + vcvtq_f32_s32(vaddq_s32(vsrc0.val[0], vsrc1.val[0])), + this->vscale0); + auto vitem1 = vmulq_f32( + vcvtq_f32_s32(vaddq_s32(vsrc0.val[1], vsrc1.val[1])), + this->vscale0); + vitem0 = vmaxq_f32(vitem0, this->vzero()); + vitem1 = vmaxq_f32(vitem1, this->vzero()); + return QConverter::convert( + {{vitem0, vitem1}}); + + } else { + auto vitem0 = vaddq_f32( + vmulq_f32(vcvtq_f32_s32(vsrc0.val[0]), this->vscale0), + vmulq_f32(vcvtq_f32_s32(vsrc1.val[0]), this->vscale1)); + auto vitem1 = vaddq_f32( + vmulq_f32(vcvtq_f32_s32(vsrc0.val[1]), this->vscale0), + vmulq_f32(vcvtq_f32_s32(vsrc1.val[1]), this->vscale1)); + + vitem0 = vmaxq_f32(vitem0, this->vzero()); + vitem1 = vmaxq_f32(vitem1, this->vzero()); + return QConverter::convert( + {{vitem0, vitem1}}); + } + } +}; +#else +template +struct FuseAddReluOp + : FuseAddReluOpBase, + FuseAddReluOpCommon, + FixupBase { + using FuseAddReluOpBase::operator(); + constexpr static size_t SIMD_WIDTH = 4; + FuseAddReluOp(DType src0_dtype, DType src1_dtype, DType 
dst_dtype) + : FuseAddReluOpBase(src0_dtype, src1_dtype, dst_dtype), + FixupBase(scale0) {} + + FuseAddReluOp(float src0_scale, float src1_scale, float dst_scale) + : FuseAddReluOpBase(src0_scale, src1_scale, dst_scale), + FixupBase(scale0) {} + + void operator()(const int32x4x2_t& vsrc0, const int32x4x2_t& vsrc1, + dt_qint8* dst) const { + vst1_s8(reinterpret_cast(dst), operator()(vsrc0, vsrc1)); + } + + int8x8_t operator()(const int32x4x2_t& vsrc0, + const int32x4x2_t& vsrc1) const { + if (enable_opt_or_fixup) { + auto vitem0 = vqrdmulhq_s32(vaddq_s32(vsrc0.val[0], vsrc1.val[0]), + vmultiplier); + auto vitem1 = vqrdmulhq_s32(vaddq_s32(vsrc0.val[1], vsrc1.val[1]), + vmultiplier); + vitem0 = vmaxq_s32(vitem0, FuseAddReluOpCommon::vzero()); + vitem1 = vmaxq_s32(vitem1, FuseAddReluOpCommon::vzero()); + return vqmovn_s16( + vcombine_s16(vqmovn_s32(vrshlq_s32(vitem0, vshift)), + vqmovn_s32(vrshlq_s32(vitem1, vshift)))); + } else { + auto vitem0 = vaddq_f32( + vmulq_f32(vcvtq_f32_s32(vsrc0.val[0]), this->vscale0), + vmulq_f32(vcvtq_f32_s32(vsrc1.val[0]), this->vscale1)); + auto vitem1 = vaddq_f32( + vmulq_f32(vcvtq_f32_s32(vsrc0.val[1]), this->vscale0), + vmulq_f32(vcvtq_f32_s32(vsrc1.val[1]), this->vscale1)); + + vitem0 = vmaxq_f32(vitem0, this->vzero()); + vitem1 = vmaxq_f32(vitem1, this->vzero()); + return QConverter::convert( + {{vitem0, vitem1}}); + } + } +}; +#endif + +template +struct FuseAddReluOp + : FuseAddReluOpBase, FuseAddReluOpCommon { + using FuseAddReluOpBase::FuseAddReluOpBase; + using FuseAddReluOpBase::operator(); + constexpr static size_t SIMD_WIDTH = 4; + void operator()(const int32x4x2_t& vsrc0, const int32x4x2_t& vsrc1, + dt_quint8* dst) const { + vst1_u8(reinterpret_cast(dst), operator()(vsrc0, vsrc1)); + } + + uint8x8_t operator()(const int32x4x2_t& vsrc0, + const int32x4x2_t& vsrc1) const { + if (enable_opt_or_fixup) { + auto vitem0 = vmulq_f32( + vcvtq_f32_s32(vaddq_s32(vsrc0.val[0], vsrc1.val[0])), + this->vscale0); + auto vitem1 = vmulq_f32( + vcvtq_f32_s32(vaddq_s32(vsrc0.val[1], vsrc1.val[1])), + this->vscale0); + vitem0 = vmaxq_f32(vitem0, this->vzero()); + vitem1 = vmaxq_f32(vitem1, this->vzero()); + return QConverter::convert( + {{vitem0, vitem1}}, this->vzp); + + } else { + auto vitem0 = vaddq_f32( + vmulq_f32(vcvtq_f32_s32(vsrc0.val[0]), this->vscale0), + vmulq_f32(vcvtq_f32_s32(vsrc1.val[0]), this->vscale1)); + auto vitem1 = vaddq_f32( + vmulq_f32(vcvtq_f32_s32(vsrc0.val[1]), this->vscale0), + vmulq_f32(vcvtq_f32_s32(vsrc1.val[1]), this->vscale1)); + + vitem0 = vmaxq_f32(vitem0, this->vzero()); + vitem1 = vmaxq_f32(vitem1, this->vzero()); + + return QConverter::convert( + {{vitem0, vitem1}}, this->vzp); + } + } +}; + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/elemwise_helper/kimpl/fuse_add_sigmoid.h b/dnn/src/arm_common/elemwise_helper/kimpl/fuse_add_sigmoid.h new file mode 100644 index 00000000..50a94e4d --- /dev/null +++ b/dnn/src/arm_common/elemwise_helper/kimpl/fuse_add_sigmoid.h @@ -0,0 +1,82 @@ +/** + * \file dnn/src/arm_common/elemwise_helper/kimpl/fuse_add_sigmoid.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
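+ *
+ * The fused add+sigmoid kernels below compute 1.f / (1.f + exp(-(src0 + src1))).
+ * The NEON path uses exp_ps_f32 / exp_ps_f16 from neon_mathfun.h and approximates
+ * the final reciprocal with vrecpeq plus one vrecpsq Newton-Raphson step, so the
+ * result is an approximation rather than an exact division.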
+ */ +#pragma once + +#include "src/arm_common/elemwise_helper/kimpl/op_base.h" + +namespace megdnn { +namespace arm_common { + +template +struct FuseAddSigmoidOpBase : BinaryOpBase { + using BinaryOpBase::BinaryOpBase; + void operator()(const src_ctype& src0, const src_ctype& src1, + dst_ctype* dst) const { + *dst = operator()(src0, src1); + } + dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { + float tmpf = src0 + src1; + tmpf = exp(-tmpf); + tmpf = 1.f / (1.f + tmpf); + return tmpf; + } +}; + +template +struct FuseAddSigmoidOp; + +#define OP(_ctype, _neon_type, _func_suffix, _simd_width) \ + template <> \ + struct FuseAddSigmoidOp<_ctype> : FuseAddSigmoidOpBase<_ctype> { \ + using FuseAddSigmoidOpBase::FuseAddSigmoidOpBase; \ + using FuseAddSigmoidOpBase::operator(); \ + constexpr static size_t SIMD_WIDTH = _simd_width; \ + void operator()(const _neon_type& src0, const _neon_type& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + vst1q_##_func_suffix(dst, vitem.val[0]); \ + vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ + } \ + _neon_type operator()(const _neon_type& src0, \ + const _neon_type& src1) const { \ + auto zero_val = vdupq_n_##_func_suffix(0.f); \ + auto one_val = vdupq_n_##_func_suffix(1.f); \ + auto val1 = src0.val[0]; \ + auto val2 = src0.val[1]; \ + auto val3 = src1.val[0]; \ + auto val4 = src1.val[1]; \ + val1 = vaddq_##_func_suffix(val1, val3); \ + val2 = vaddq_##_func_suffix(val2, val4); \ + val1 = vsubq_##_func_suffix(zero_val, val1); \ + val2 = vsubq_##_func_suffix(zero_val, val2); \ + val1 = exp_ps_##_func_suffix(val1); \ + val2 = exp_ps_##_func_suffix(val2); \ + auto recipe1 = vaddq_##_func_suffix(one_val, val1); \ + auto recipe2 = vaddq_##_func_suffix(one_val, val2); \ + val1 = vrecpeq_##_func_suffix(recipe1); \ + val2 = vrecpeq_##_func_suffix(recipe2); \ + val1 = vmulq_##_func_suffix(vrecpsq_##_func_suffix(recipe1, val1), \ + val1); \ + val2 = vmulq_##_func_suffix(vrecpsq_##_func_suffix(recipe2, val2), \ + val2); \ + return {{val1, val2}}; \ + } \ + }; +OP(dt_float32, float32x4x2_t, f32, 4) +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +OP(__fp16, float16x8x2_t, f16, 8) +#endif +#undef OP + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/elemwise_helper/kimpl/fuse_add_tanh.h b/dnn/src/arm_common/elemwise_helper/kimpl/fuse_add_tanh.h new file mode 100644 index 00000000..c18d8b47 --- /dev/null +++ b/dnn/src/arm_common/elemwise_helper/kimpl/fuse_add_tanh.h @@ -0,0 +1,87 @@ +/** + * \file dnn/src/arm_common/elemwise_helper/kimpl/fuse_add_tanh.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
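+ *
+ * The fused add+tanh kernels below evaluate tanh(x), x = src0 + src1, as
+ * (e^x - e^-x) / (e^x + e^-x): exp_ps_f32 / exp_ps_f16 computes e^x once, while
+ * e^-x and the final division come from vrecpeq refined by a single vrecpsq
+ * Newton-Raphson step, so the result is approximate.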
+ */ +#pragma once + +#include "src/arm_common/elemwise_helper/kimpl/op_base.h" + +namespace megdnn { +namespace arm_common { + +template +struct FuseAddTanhOpBase : BinaryOpBase { + using BinaryOpBase::BinaryOpBase; + void operator()(const src_ctype& src0, const src_ctype& src1, + dst_ctype* dst) const { + *dst = operator()(src0, src1); + } + dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { + float tmpf = exp(src0 + (src1)); + float tmpf2 = 1 / tmpf; + return (tmpf - tmpf2) / (tmpf + tmpf2); + } +}; + +template +struct FuseAddTanhOp; + +#define OP(_ctype, _neon_type, _func_suffix, _simd_width) \ + template <> \ + struct FuseAddTanhOp<_ctype> : FuseAddTanhOpBase<_ctype> { \ + using FuseAddTanhOpBase::FuseAddTanhOpBase; \ + using FuseAddTanhOpBase::operator(); \ + constexpr static size_t SIMD_WIDTH = _simd_width; \ + void operator()(const _neon_type& src0, const _neon_type& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + vst1q_##_func_suffix(dst, vitem.val[0]); \ + vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ + } \ + _neon_type operator()(const _neon_type& src0, \ + const _neon_type& src1) const { \ + auto val1 = src0.val[0]; \ + auto val2 = src0.val[1]; \ + auto val3 = src1.val[0]; \ + auto val4 = src1.val[1]; \ + val1 = vaddq_##_func_suffix(val1, val3); \ + val2 = vaddq_##_func_suffix(val2, val4); \ + auto exp1 = exp_ps_##_func_suffix(val1); \ + auto exp2 = exp_ps_##_func_suffix(val2); \ + auto rexp1 = vrecpeq_##_func_suffix(exp1); \ + auto rexp2 = vrecpeq_##_func_suffix(exp2); \ + rexp1 = vmulq_##_func_suffix(vrecpsq_##_func_suffix(exp1, rexp1), \ + rexp1); \ + rexp2 = vmulq_##_func_suffix(vrecpsq_##_func_suffix(exp2, rexp2), \ + rexp2); \ + val1 = vsubq_##_func_suffix(exp1, rexp1); \ + val2 = vsubq_##_func_suffix(exp2, rexp2); \ + exp1 = vaddq_##_func_suffix(exp1, rexp1); \ + exp2 = vaddq_##_func_suffix(exp2, rexp2); \ + rexp1 = vrecpeq_##_func_suffix(exp1); \ + rexp2 = vrecpeq_##_func_suffix(exp2); \ + rexp1 = vmulq_##_func_suffix(vrecpsq_##_func_suffix(exp1, rexp1), \ + rexp1); \ + rexp2 = vmulq_##_func_suffix(vrecpsq_##_func_suffix(exp2, rexp2), \ + rexp2); \ + val1 = vmulq_##_func_suffix(val1, rexp1); \ + val2 = vmulq_##_func_suffix(val2, rexp2); \ + return {{val1, val2}}; \ + } \ + }; +OP(dt_float32, float32x4x2_t, f32, 4) +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +OP(__fp16, float16x8x2_t, f16, 8) +#endif +#undef OP + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/elemwise_helper/kimpl/fuse_mul_add3.h b/dnn/src/arm_common/elemwise_helper/kimpl/fuse_mul_add3.h new file mode 100644 index 00000000..871ca5f4 --- /dev/null +++ b/dnn/src/arm_common/elemwise_helper/kimpl/fuse_mul_add3.h @@ -0,0 +1,68 @@ +/** + * \file dnn/src/arm_common/elemwise_helper/kimpl/fuse_mul_add3.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
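+ *
+ * FuseMulAdd3 below is a three-input fused multiply-add, dst = src0 * src1 + src2;
+ * the vector path maps directly onto vmlaq_<suffix>(src2, src0, src1) for float32,
+ * float16 (when __ARM_FEATURE_FP16_VECTOR_ARITHMETIC is set), int32, int16 and int8.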
+ */ +#pragma once + +#include "src/arm_common/elemwise_helper/kimpl/op_base.h" + +namespace megdnn { +namespace arm_common { + +template +struct FuseMulAdd3OpBase : TernaryOpBase { + using TernaryOpBase::TernaryOpBase; + void operator()(const src_ctype& src0, const src_ctype& src1, + const src_ctype src2, dst_ctype* dst) const { + *dst = operator()(src0, src1, src2); + } + + dst_ctype operator()(const src_ctype& src0, const src_ctype& src1, + const src_ctype& src2) const { + return (src0 * src1) + src2; + } +}; + +template +struct FuseMulAdd3Op; + +#define OP(_ctype, _neon_type, _func_suffix, _simd_width) \ + template <> \ + struct FuseMulAdd3Op<_ctype> : FuseMulAdd3OpBase<_ctype> { \ + using FuseMulAdd3OpBase::FuseMulAdd3OpBase; \ + using FuseMulAdd3OpBase::operator(); \ + constexpr static size_t SIMD_WIDTH = _simd_width; \ + void operator()(const _neon_type& src0, const _neon_type& src1, \ + const _neon_type& src2, dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1, src2); \ + vst1q_##_func_suffix(dst, vitem.val[0]); \ + vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ + } \ + _neon_type operator()(const _neon_type& src0, const _neon_type& src1, \ + const _neon_type& src2) const { \ + auto vitem0 = vmlaq_##_func_suffix(src2.val[0], src0.val[0], \ + src1.val[0]); \ + auto vitem1 = vmlaq_##_func_suffix(src2.val[1], src0.val[1], \ + src1.val[1]); \ + return {{vitem0, vitem1}}; \ + } \ + }; +OP(dt_float32, float32x4x2_t, f32, 4) +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +OP(__fp16, float16x8x2_t, f16, 8) +#endif +OP(dt_int32, int32x4x2_t, s32, 4) +OP(dt_int16, int16x8x2_t, s16, 8) +OP(dt_int8, int8x16x2_t, s8, 16) +#undef OP + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/elemwise_helper/kimpl/hswish.h b/dnn/src/arm_common/elemwise_helper/kimpl/hswish.h new file mode 100644 index 00000000..dccd9486 --- /dev/null +++ b/dnn/src/arm_common/elemwise_helper/kimpl/hswish.h @@ -0,0 +1,163 @@ +/** + * \file dnn/src/arm_common/elemwise_helper/kimpl/hswish.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#pragma once + +#include "src/arm_common/elemwise_helper/kimpl/kern_macro_prologue.h" +#include "src/arm_common/elemwise_helper/kimpl/op_base.h" + +namespace megdnn { +namespace arm_common { + +template +struct HSwishOpBase : UnaryOpBase { + using UnaryOpBase::UnaryOpBase; + void operator()(const src_ctype& src, dst_ctype* dst) const { + *dst = operator()(src); + } + dst_ctype operator()(const src_ctype& src) const { + float tmp = src; + tmp = tmp * std::max(std::min(tmp + 3.f, 6.f), 0.f) / 6.f; + return (tmp); + } +}; + +//! 
h_swish(x) = x * clip(x + 3, 0, 6) / 6 +template +struct HSwishOp; + +#define OP(_ctype, _neon_type, _neon_type2, _func_suffix, _simd_width) \ + template <> \ + struct HSwishOp<_ctype> : HSwishOpBase<_ctype> { \ + using HSwishOpBase::HSwishOpBase; \ + using HSwishOpBase::operator(); \ + constexpr static size_t SIMD_WIDTH = _simd_width; \ + void operator()(const _neon_type2& src, _ctype* dst) const { \ + auto vitem = operator()(src); \ + vst1q_##_func_suffix(dst, vitem.val[0]); \ + vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ + } \ + _neon_type2 operator()(const _neon_type2& src) const { \ + auto val1 = src.val[0]; \ + auto val2 = src.val[1]; \ + H_SWISH_KERN(_func_suffix, val1, val2); \ + return {{val1, val2}}; \ + } \ + _neon_type operator()(const _neon_type& src) { \ + auto val_zero = vdupq_n_##_func_suffix(0.f); \ + auto val_six = vdupq_n_##_func_suffix(6.f); \ + auto val_three = vdupq_n_##_func_suffix(3.f); \ + auto val_rec_six = vdupq_n_##_func_suffix(1.f / 6.f); \ + auto clip1 = vmaxq_##_func_suffix( \ + vminq_##_func_suffix(vaddq_##_func_suffix(src, val_three), \ + val_six), \ + val_zero); \ + return vmulq_##_func_suffix(vmulq_##_func_suffix(src, clip1), \ + val_rec_six); \ + } \ + }; +OP(dt_float32, float32x4_t, float32x4x2_t, f32, 4) +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +OP(__fp16, float16x8_t, float16x8x2_t, f16, 8) +#endif +#undef OP + +template <> +struct HSwishOpBase : UnaryOpBase { + using UnaryOpBase::UnaryOpBase; + void operator()(const dt_qint32& src, dt_qint8* dst) const { + *dst = operator()(src); + } + + dt_qint8 operator()(const dt_qint32& src) const { + float tmp = src.as_int32() * this->scale_src; + tmp = tmp * std::max(std::min(tmp + 3.f, 6.f), 0.f) / 6.f; + tmp *= this->scale_dst; + return QConverter::convert(tmp); + } +}; + +template <> +struct HSwishOpBase : UnaryOpBase { + using UnaryOpBase::UnaryOpBase; + void operator()(const dt_qint32& src, dt_quint8* dst) const { + *dst = operator()(src); + } + + dt_quint8 operator()(const dt_qint32& src) const { + float tmp = src.as_int32() * this->scale_src; + tmp = tmp * std::max(std::min(tmp + 3.f, 6.f), 0.f) / 6.f; + tmp *= this->scale_dst; + return QConverter::convert(tmp, zp); + } +}; + +template <> +struct HSwishOp : HSwishOpBase { + using HSwishOpBase::HSwishOpBase; + using HSwishOpBase::operator(); + constexpr static size_t SIMD_WIDTH = 4; + + void operator()(const int32x4x2_t& vsrc, dt_qint8* dst) const { + vst1_s8(reinterpret_cast(dst), operator()(vsrc)); + } + void operator()(const int32x4_t& vsrc, dt_qint8* dst) const { + vst1_lane_s32(reinterpret_cast(dst), + (int32x2_t)(operator()(vsrc)), 0); + } + + int8x8_t operator()(const int32x4x2_t& vsrc) const { + auto vitem0 = vmulq_f32(vcvtq_f32_s32(vsrc.val[0]), this->vscale_src); + auto vitem1 = vmulq_f32(vcvtq_f32_s32(vsrc.val[1]), this->vscale_src); + + H_SWISH_KERN(f32, vitem0, vitem1); + vitem0 = vmulq_f32(vitem0, this->vscale_dst); + vitem1 = vmulq_f32(vitem1, this->vscale_dst); + + return QConverter::convert({{vitem0, vitem1}}); + } + int8x8_t operator()(const int32x4_t& src) const { + auto vitem0 = vmulq_f32(vcvtq_f32_s32(src), this->vscale_src); + + H_SWISH_KERN_N1(f32, vitem0); + vitem0 = vmulq_f32(vitem0, this->vscale_dst); + + return QConverter::convert(vitem0); + } +}; + +template <> +struct HSwishOp : HSwishOpBase { + using HSwishOpBase::HSwishOpBase; + using HSwishOpBase::operator(); + constexpr static size_t SIMD_WIDTH = 4; + + void operator()(const int32x4x2_t& vsrc, dt_quint8* dst) const { + vst1_u8(reinterpret_cast(dst), 
operator()(vsrc)); + } + + uint8x8_t operator()(const int32x4x2_t& vsrc) const { + auto vitem0 = vmulq_f32(vcvtq_f32_s32(vsrc.val[0]), this->vscale_src); + auto vitem1 = vmulq_f32(vcvtq_f32_s32(vsrc.val[1]), this->vscale_src); + H_SWISH_KERN(f32, vitem0, vitem1); + vitem0 = vmulq_f32(vitem0, this->vscale_dst); + vitem1 = vmulq_f32(vitem1, this->vscale_dst); + + return QConverter::convert({{vitem0, vitem1}}, + this->vzp); + } +}; + +} // namespace arm_common +} // namespace megdnn + +#include "src/arm_common/elemwise_helper/kimpl/kern_macro_epilogue.h" +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/elemwise_helper/kimpl/kern_macro_epilogue.h b/dnn/src/arm_common/elemwise_helper/kimpl/kern_macro_epilogue.h new file mode 100644 index 00000000..61f3472e --- /dev/null +++ b/dnn/src/arm_common/elemwise_helper/kimpl/kern_macro_epilogue.h @@ -0,0 +1,14 @@ +/** + * \file dnn/src/arm_common/elemwise_helper/kimpl/kern_macro_epilogue.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#undef H_SWISH_KERN + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/elemwise_helper/kimpl/kern_macro_prologue.h b/dnn/src/arm_common/elemwise_helper/kimpl/kern_macro_prologue.h new file mode 100644 index 00000000..edbf2253 --- /dev/null +++ b/dnn/src/arm_common/elemwise_helper/kimpl/kern_macro_prologue.h @@ -0,0 +1,47 @@ +/** + * \file dnn/src/arm_common/elemwise_helper/kimpl/kern_macro_prologue.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
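+ *
+ * H_SWISH_KERN(_func_suffix, _val1, _val2) defined below rewrites both registers
+ * in place with x * clip(x + 3.f, 0.f, 6.f) * (1.f / 6.f), building the clip from
+ * vminq/vmaxq; H_SWISH_KERN_N1 is the single-register variant used by the
+ * one-vector overloads.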
+ */ + +#define H_SWISH_KERN(_func_suffix, _val1, _val2) \ + do { \ + auto val_zero = vdupq_n_##_func_suffix(0.f); \ + auto val_six = vdupq_n_##_func_suffix(6.f); \ + auto val_three = vdupq_n_##_func_suffix(3.f); \ + auto val_rec_six = vdupq_n_##_func_suffix(1.f / 6.f); \ + auto clip1 = vmaxq_##_func_suffix( \ + vminq_##_func_suffix(vaddq_##_func_suffix(_val1, val_three), \ + val_six), \ + val_zero); \ + auto clip2 = vmaxq_##_func_suffix( \ + vminq_##_func_suffix(vaddq_##_func_suffix(_val2, val_three), \ + val_six), \ + val_zero); \ + _val1 = vmulq_##_func_suffix(vmulq_##_func_suffix(_val1, clip1), \ + val_rec_six); \ + _val2 = vmulq_##_func_suffix(vmulq_##_func_suffix(_val2, clip2), \ + val_rec_six); \ + } while (0); + +#define H_SWISH_KERN_N1(_func_suffix, _val1) \ + do { \ + auto val_zero = vdupq_n_##_func_suffix(0.f); \ + auto val_six = vdupq_n_##_func_suffix(6.f); \ + auto val_three = vdupq_n_##_func_suffix(3.f); \ + auto val_rec_six = vdupq_n_##_func_suffix(1.f / 6.f); \ + auto clip1 = vmaxq_##_func_suffix( \ + vminq_##_func_suffix(vaddq_##_func_suffix(_val1, val_three), \ + val_six), \ + val_zero); \ + _val1 = vmulq_##_func_suffix(vmulq_##_func_suffix(_val1, clip1), \ + val_rec_six); \ + } while (0); + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/elemwise_helper/kimpl/max.h b/dnn/src/arm_common/elemwise_helper/kimpl/max.h new file mode 100644 index 00000000..17e4bb63 --- /dev/null +++ b/dnn/src/arm_common/elemwise_helper/kimpl/max.h @@ -0,0 +1,165 @@ +/** + * \file dnn/src/arm_common/elemwise_helper/kimpl/max.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once + +#include "src/arm_common/elemwise_helper/kimpl/op_base.h" + +namespace megdnn { +namespace arm_common { +template +struct MaxOpBase : BinaryOpBase { + using BinaryOpBase::BinaryOpBase; + void operator()(const src_ctype& src0, const src_ctype& src1, + dst_ctype* dst) const { + *dst = operator()(src0, src1); + } + dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { + return src0 > src1 ? 
src0 : src1; + } +}; + +template +struct MaxOp; + +#define OP(_ctype, _neon_type, _neon_type2, _func_suffix, _simd_width) \ + template <> \ + struct MaxOp<_ctype> : MaxOpBase<_ctype> { \ + using MaxOpBase::MaxOpBase; \ + using MaxOpBase::operator(); \ + constexpr static size_t SIMD_WIDTH = _simd_width; \ + void operator()(const _neon_type2& src0, const _neon_type2& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + vst1q_##_func_suffix(dst, vitem.val[0]); \ + vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ + } \ + _neon_type2 operator()(const _neon_type2& src0, \ + const _neon_type2& src1) const { \ + auto vitem0 = vmaxq_##_func_suffix(src0.val[0], src1.val[0]); \ + auto vitem1 = vmaxq_##_func_suffix(src0.val[1], src1.val[1]); \ + return {{vitem0, vitem1}}; \ + } \ + void operator()(const _neon_type& src0, const _neon_type& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + vst1q_##_func_suffix(dst, vitem); \ + } \ + _neon_type operator()(const _neon_type& src0, \ + const _neon_type& src1) const { \ + return vmaxq_##_func_suffix(src0, src1); \ + } \ + }; +OP(dt_float32, float32x4_t, float32x4x2_t, f32, 4) +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +OP(__fp16, float16x8_t, float16x8x2_t, f16, 8) +#endif +OP(dt_int32, int32x4_t, int32x4x2_t, s32, 4) +OP(dt_int16, int16x8_t, int16x8x2_t, s16, 8) +OP(dt_int8, int8x16_t, int8x16x2_t, s8, 16) +#undef OP + +template <> +struct MaxOpBase : BinaryOpBase { + using src_ctype = dt_qint8; + using dst_ctype = dt_qint8; + using BinaryOpBase::BinaryOpBase; + + void operator()(const src_ctype& src0, const src_ctype& src1, + dst_ctype* dst) const { + *dst = operator()(src0, src1); + } + + dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { + float fsrc0 = src0.as_int8() * this->scale0; + float fsrc1 = src1.as_int8() * this->scale1; + return QConverter::convert(fsrc0 > fsrc1 ? fsrc0 + : fsrc1); + } +}; + +template <> +struct MaxOpBase : BinaryOpBase { + using src_ctype = dt_quint8; + using dst_ctype = dt_quint8; + using BinaryOpBase::BinaryOpBase; + + void operator()(const src_ctype& src0, const src_ctype& src1, + dst_ctype* dst) const { + *dst = operator()(src0, src1); + } + + dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { + float fsrc0 = src0.as_uint8() * this->scale0 - this->szp0; + float fsrc1 = src1.as_uint8() * this->scale1 - this->szp1; + return QConverter::convert( + fsrc0 > fsrc1 ? 
fsrc0 : fsrc1, this->dzp); + } +}; + +template <> +struct MaxOp : MaxOpBase { + using MaxOpBase::MaxOpBase; + constexpr static size_t SIMD_WIDTH = 16; + using MaxOpBase::operator(); + + void operator()(const int8x16x2_t& vsrc0, const int8x16x2_t& vsrc1, + dt_qint8* dst) const { + OPERATOR_BINARY_QINT8; + } + + int8x8_t operator()(const int32x4x2_t& vsrc0, + const int32x4x2_t& vsrc1) const { + auto vitem0 = vmaxq_f32( + vmulq_f32(vcvtq_f32_s32(vsrc0.val[0]), this->vscale0), + vmulq_f32(vcvtq_f32_s32(vsrc1.val[0]), this->vscale1)); + auto vitem1 = vmaxq_f32( + vmulq_f32(vcvtq_f32_s32(vsrc0.val[1]), this->vscale0), + vmulq_f32(vcvtq_f32_s32(vsrc1.val[1]), this->vscale1)); + return QConverter::convert({{vitem0, vitem1}}); + } +}; + +template <> +struct MaxOp : MaxOpBase { + using MaxOpBase::MaxOpBase; + constexpr static size_t SIMD_WIDTH = 16; + using MaxOpBase::operator(); + + void operator()(const uint8x16x2_t& vsrc0, const uint8x16x2_t& vsrc1, + dt_quint8* dst) const { + OPERATOR_BINARY_QUINT8; + } + + uint8x8_t operator()(const uint32x4x2_t& vsrc0, + const uint32x4x2_t vsrc1) const { + auto vsrct0 = + vsubq_f32(vmulq_f32(vcvtq_f32_u32(vsrc0.val[0]), this->vscale0), + this->vszp0); + auto vsrct1 = + vsubq_f32(vmulq_f32(vcvtq_f32_u32(vsrc1.val[0]), this->vscale1), + this->vszp1); + auto vitem0 = vmaxq_f32(vsrct0, vsrct1); + vsrct0 = + vsubq_f32(vmulq_f32(vcvtq_f32_u32(vsrc0.val[1]), this->vscale0), + this->vszp0); + vsrct1 = + vsubq_f32(vmulq_f32(vcvtq_f32_u32(vsrc1.val[1]), this->vscale1), + this->vszp1); + auto vitem1 = vmaxq_f32(vsrct0, vsrct1); + return QConverter::convert( + {{vitem0, vitem1}}, this->vdzp); + } +}; + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/elemwise_helper/kimpl/min.h b/dnn/src/arm_common/elemwise_helper/kimpl/min.h new file mode 100644 index 00000000..59f623be --- /dev/null +++ b/dnn/src/arm_common/elemwise_helper/kimpl/min.h @@ -0,0 +1,159 @@ +/** + * \file dnn/src/arm_common/elemwise_helper/kimpl/min.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once + +#include "src/arm_common/elemwise_helper/kimpl/op_base.h" + +namespace megdnn { +namespace arm_common { + +template +struct MinOpBase : BinaryOpBase { + using BinaryOpBase::BinaryOpBase; + void operator()(const src_ctype& src0, const src_ctype& src1, + dst_ctype* dst) const { + *dst = operator()(src0, src1); + } + dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { + return src0 < src1 ? 
src0 : src1; + } +}; + +template +struct MinOp; + +#define OP(_ctype, _neon_type, _neon_type2, _func_suffix, _simd_width) \ + template <> \ + struct MinOp<_ctype> : MinOpBase<_ctype> { \ + using MinOpBase::MinOpBase; \ + using MinOpBase::operator(); \ + constexpr static size_t SIMD_WIDTH = _simd_width; \ + void operator()(const _neon_type2& src0, const _neon_type2& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + vst1q_##_func_suffix(dst, vitem.val[0]); \ + vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ + } \ + _neon_type2 operator()(const _neon_type2& src0, \ + const _neon_type2& src1) const { \ + auto vitem0 = vminq_##_func_suffix(src0.val[0], src1.val[0]); \ + auto vitem1 = vminq_##_func_suffix(src0.val[1], src1.val[1]); \ + return {{vitem0, vitem1}}; \ + } \ + void operator()(const _neon_type& src0, const _neon_type& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + vst1q_##_func_suffix(dst, vitem); \ + } \ + _neon_type operator()(const _neon_type& src0, \ + const _neon_type& src1) const { \ + return vminq_##_func_suffix(src0, src1); \ + } \ + }; +OP(dt_float32, float32x4_t, float32x4x2_t, f32, 4) +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +OP(__fp16, float16x8_t, float16x8x2_t, f16, 8) +#endif +OP(dt_int32, int32x4_t, int32x4x2_t, s32, 4) +OP(dt_int16, int16x8_t, int16x8x2_t, s16, 8) +OP(dt_int8, int8x16_t, int8x16x2_t, s8, 16) +#undef OP + +template <> +struct MinOpBase : BinaryOpBase { + using BinaryOpBase::BinaryOpBase; + void operator()(const dt_qint8& src0, const dt_qint8& src1, + dt_qint8* dst) const { + *dst = operator()(src0, src1); + } + + dt_qint8 operator()(const dt_qint8& src0, const dt_qint8& src1) const { + float fsrc0 = src0.as_int8() * this->scale0; + float fsrc1 = src1.as_int8() * this->scale1; + return QConverter::convert(fsrc0 < fsrc1 ? fsrc0 + : fsrc1); + } +}; + +template <> +struct MinOpBase : BinaryOpBase { + using BinaryOpBase::BinaryOpBase; + void operator()(const dt_quint8& src0, const dt_quint8& src1, + dt_quint8* dst) const { + *dst = operator()(src0, src1); + } + dt_quint8 operator()(const dt_quint8& src0, const dt_quint8& src1) const { + float fsrc0 = src0.as_uint8() * this->scale0 - this->szp0; + float fsrc1 = src1.as_uint8() * this->scale1 - this->szp1; + return QConverter::convert( + fsrc0 < fsrc1 ? 
fsrc0 : fsrc1, this->dzp); + } +}; + +template <> +struct MinOp : MinOpBase { + using MinOpBase::MinOpBase; + constexpr static size_t SIMD_WIDTH = 16; + using MinOpBase::operator(); + + void operator()(const int8x16x2_t& vsrc0, const int8x16x2_t& vsrc1, + dt_qint8* dst) const { + OPERATOR_BINARY_QINT8; + } + + int8x8_t operator()(const int32x4x2_t& vsrc0, + const int32x4x2_t& vsrc1) const { + auto vitem0 = vminq_f32( + vmulq_f32(vcvtq_f32_s32(vsrc0.val[0]), this->vscale0), + vmulq_f32(vcvtq_f32_s32(vsrc1.val[0]), this->vscale1)); + auto vitem1 = vminq_f32( + vmulq_f32(vcvtq_f32_s32(vsrc0.val[1]), this->vscale0), + vmulq_f32(vcvtq_f32_s32(vsrc1.val[1]), this->vscale1)); + return QConverter::convert({{vitem0, vitem1}}); + } +}; + +template <> +struct MinOp : MinOpBase { + using MinOpBase::MinOpBase; + constexpr static size_t SIMD_WIDTH = 16; + using MinOpBase::operator(); + + void operator()(const uint8x16x2_t& vsrc0, const uint8x16x2_t& vsrc1, + dt_quint8* dst) const { + OPERATOR_BINARY_QUINT8; + } + + uint8x8_t operator()(const uint32x4x2_t& vsrc0, + const uint32x4x2_t& vsrc1) const { + auto vsrct0 = + vsubq_f32(vmulq_f32(vcvtq_f32_u32(vsrc0.val[0]), this->vscale0), + this->vszp0); + auto vsrct1 = + vsubq_f32(vmulq_f32(vcvtq_f32_u32(vsrc1.val[0]), this->vscale1), + this->vszp1); + auto vitem0 = vminq_f32(vsrct0, vsrct1); + vsrct0 = + vsubq_f32(vmulq_f32(vcvtq_f32_u32(vsrc0.val[1]), this->vscale0), + this->vszp0); + vsrct1 = + vsubq_f32(vmulq_f32(vcvtq_f32_u32(vsrc1.val[1]), this->vscale1), + this->vszp1); + auto vitem1 = vminq_f32(vsrct0, vsrct1); + return QConverter::convert( + {{vitem0, vitem1}}, this->vdzp); + } +}; + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/elemwise_helper/kimpl/mul.h b/dnn/src/arm_common/elemwise_helper/kimpl/mul.h new file mode 100644 index 00000000..e20c082d --- /dev/null +++ b/dnn/src/arm_common/elemwise_helper/kimpl/mul.h @@ -0,0 +1,157 @@ +/** + * \file dnn/src/arm_common/elemwise_helper/kimpl/mul.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
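+ *
+ * The quantized MulOp specializations below convert each input to float with its
+ * own scale (subtracting the zero-point offset for Quantized8Asymm), multiply,
+ * rescale to the destination scale and re-quantize with saturating narrowing
+ * through QConverter (re-adding the destination zero point where needed).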
+ */ +#pragma once + +#include "src/arm_common/elemwise_helper/kimpl/op_base.h" + +namespace megdnn { +namespace arm_common { + +template +struct MulOpBase : BinaryOpBase { + using BinaryOpBase::BinaryOpBase; + void operator()(const src_ctype& src0, const src_ctype& src1, + dst_ctype* dst) const { + *dst = operator()(src0, src1); + } + dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { + return src0 * src1; + } +}; + +template +struct MulOp; + +#define OP(_ctype, _neon_type, _neon_type2, _func_suffix, _simd_width) \ + template <> \ + struct MulOp<_ctype> : MulOpBase<_ctype> { \ + using MulOpBase::MulOpBase; \ + using MulOpBase::operator(); \ + constexpr static size_t SIMD_WIDTH = _simd_width; \ + void operator()(const _neon_type2& src0, const _neon_type2& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + vst1q_##_func_suffix(dst, vitem.val[0]); \ + vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ + } \ + _neon_type2 operator()(const _neon_type2& src0, \ + const _neon_type2& src1) const { \ + auto vitem0 = vmulq_##_func_suffix(src0.val[0], src1.val[0]); \ + auto vitem1 = vmulq_##_func_suffix(src0.val[1], src1.val[1]); \ + return {{vitem0, vitem1}}; \ + } \ + void operator()(const _neon_type& src0, const _neon_type& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + vst1q_##_func_suffix(dst, vitem); \ + } \ + _neon_type operator()(const _neon_type& src0, \ + const _neon_type& src1) const { \ + return vmulq_##_func_suffix(src0, src1); \ + } \ + }; +OP(dt_float32, float32x4_t, float32x4x2_t, f32, 4) +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +OP(__fp16, float16x8_t, float16x8x2_t, f16, 8) +#endif +OP(dt_int32, int32x4_t, int32x4x2_t, s32, 4) +OP(dt_int16, int16x8_t, int16x8x2_t, s16, 8) +OP(dt_int8, int8x16_t, int8x16x2_t, s8, 16) +#undef OP + +template <> +struct MulOpBase : BinaryOpBase { + using BinaryOpBase::BinaryOpBase; + void operator()(const dt_qint8& src0, const dt_qint8& src1, + dt_qint8* dst) const { + *dst = operator()(src0, src1); + } + + dt_qint8 operator()(const dt_qint8& src0, const dt_qint8& src1) const { + return QConverter::convert( + src0.as_int8() * scale_src0 * src1.as_int8() * scale1); + } +}; + +template <> +struct MulOpBase : BinaryOpBase { + using BinaryOpBase::BinaryOpBase; + void operator()(const dt_quint8& src0, const dt_quint8& src1, + dt_quint8* dst) const { + *dst = operator()(src0, src1); + } + dt_quint8 operator()(const dt_quint8& src0, const dt_quint8& src1) const { + float fsrc0 = src0.as_uint8() * scale_src0 - this->scale_zp0; + float fsrc1 = src1.as_uint8() * scale_src1 - this->scale_zp1; + return QConverter::convert( + fsrc0 * fsrc1 * scale_dst, this->dzp); + } +}; + +template <> +struct MulOp : MulOpBase { + using MulOpBase::MulOpBase; + constexpr static size_t SIMD_WIDTH = 16; + using MulOpBase::operator(); + + void operator()(const int8x16x2_t& vsrc0, const int8x16x2_t& vsrc1, + dt_qint8* dst) const { + OPERATOR_BINARY_QINT8; + } + + int8x8_t operator()(const int32x4x2_t& vsrc0, + const int32x4x2_t& vsrc1) const { + auto vitem0 = vmulq_f32( + vmulq_f32(vcvtq_f32_s32(vsrc0.val[0]), this->vscale_src0), + vmulq_f32(vcvtq_f32_s32(vsrc1.val[0]), this->vscale1)); + auto vitem1 = vmulq_f32( + vmulq_f32(vcvtq_f32_s32(vsrc0.val[1]), this->vscale_src0), + vmulq_f32(vcvtq_f32_s32(vsrc1.val[1]), this->vscale1)); + + return QConverter::convert({{vitem0, vitem1}}); + } +}; + +template <> +struct MulOp : MulOpBase { + using MulOpBase::MulOpBase; + constexpr static size_t SIMD_WIDTH 
= 16; + using MulOpBase::operator(); + + void operator()(const uint8x16x2_t& vsrc0, const uint8x16x2_t& vsrc1, + dt_quint8* dst) const { + OPERATOR_BINARY_QUINT8; + } + uint8x8_t operator()(const uint32x4x2_t& vsrc0, + const uint32x4x2_t vsrc1) const { + auto vfsrc0 = vsubq_f32( + vmulq_f32(vcvtq_f32_u32(vsrc0.val[0]), this->vscale_src0), + this->vscale_zp0); + auto vfsrc1 = vsubq_f32( + vmulq_f32(vcvtq_f32_u32(vsrc1.val[0]), this->vscale_src1), + this->vscale_zp1); + auto vitem0 = vmulq_f32(vmulq_f32(vfsrc0, vfsrc1), this->vscale_dst); + vfsrc0 = vsubq_f32( + vmulq_f32(vcvtq_f32_u32(vsrc0.val[1]), this->vscale_src0), + this->vscale_zp0); + vfsrc1 = vsubq_f32( + vmulq_f32(vcvtq_f32_u32(vsrc1.val[1]), this->vscale_src1), + this->vscale_zp1); + auto vitem1 = vmulq_f32(vmulq_f32(vfsrc0, vfsrc1), this->vscale_dst); + return QConverter::convert( + {{vitem0, vitem1}}, this->vdzp); + } +}; + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/elemwise_helper/kimpl/none.h b/dnn/src/arm_common/elemwise_helper/kimpl/none.h new file mode 100644 index 00000000..6cf5bd00 --- /dev/null +++ b/dnn/src/arm_common/elemwise_helper/kimpl/none.h @@ -0,0 +1,83 @@ +/** + * \file dnn/src/arm_common/elemwise_helper/kimpl/none.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once + +#include "src/arm_common/elemwise_helper/kimpl/op_base.h" + +namespace megdnn { +namespace arm_common { + +template +struct NoneOpBase : UnaryOpBase { + using UnaryOpBase::UnaryOpBase; + dst_ctype operator()(const src_ctype& src) const { return src; } +}; + +template +struct NoneOp; +#define OP(_ctype, _neon_type, _neon_type2, _func_suffix, _simd_width) \ + template <> \ + struct NoneOp<_ctype> : NoneOpBase<_ctype> { \ + using NoneOpBase::NoneOpBase; \ + using NoneOpBase::operator(); \ + constexpr static size_t SIMD_WIDTH = _simd_width; \ + _neon_type2 operator()(const _neon_type2& src) const { return src; } \ + _neon_type operator()(const _neon_type& src) const { return src; } \ + }; + +OP(dt_float32, float32x4_t, float32x4x2_t, f32, 4) +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +OP(__fp16, float16x8_t, float16x8x2_t, f16, 8) +#endif +OP(dt_int32, int32x4_t, int32x4x2_t, s32, 4) +OP(dt_int16, int16x8_t, int16x8x2_t, s16, 8) +OP(dt_int8, int8x16_t, int8x16x2_t, s8, 16) +#undef OP + +template <> +struct NoneOpBase : UnaryOpBase { + using UnaryOpBase::UnaryOpBase; + void operator()(const dt_qint8& src, dt_qint8* dst) const { *dst = src; } +}; + +template <> +struct NoneOpBase : UnaryOpBase { + using UnaryOpBase::UnaryOpBase; + void operator()(const dt_quint8& src, dt_quint8* dst) const { *dst = src; } +}; + +template <> +struct NoneOpBase : UnaryOpBase { + using UnaryOpBase::UnaryOpBase; + void operator()(const dt_qint32& src, dt_qint8* dst) const { + *(reinterpret_cast(dst)) = src; + } +}; + +template <> +struct NoneOp : NoneOpBase { + using NoneOpBase::NoneOpBase; + using NoneOpBase::operator(); + constexpr static size_t SIMD_WIDTH = 4; + + void operator()(const int32x4x2_t& vsrc, dt_qint8* dst) const { + vst1q_s32(reinterpret_cast(dst), vsrc.val[0]); + vst1q_s32(reinterpret_cast(dst + 16), vsrc.val[1]); + } + void operator()(const int32x4_t& src, 
dt_qint8* dst) const { + vst1q_s32(reinterpret_cast(dst), src); + } +}; + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/elemwise_helper/kimpl/op_base.h b/dnn/src/arm_common/elemwise_helper/kimpl/op_base.h new file mode 100644 index 00000000..e1cdbbd9 --- /dev/null +++ b/dnn/src/arm_common/elemwise_helper/kimpl/op_base.h @@ -0,0 +1,916 @@ +/** + * \file dnn/src/arm_common/elemwise_helper/kimpl/op_base.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once + +#include +#include "megdnn/dtype.h" +#include "megdnn/oprs.h" +#include "src/arm_common/elemwise/neon_mathfun.h" +#include "src/arm_common/quantized_converter.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/utils.h" + +namespace megdnn { +namespace arm_common { + +////////////////////////// unary ////////////////////////// +template +struct OpBase { + using src_ctype = _src_ctype; + using dst_ctype = _dst_ctype; + OpBase() = default; +}; + +template +struct UnaryOpBase : OpBase { + using OpBase::OpBase; + UnaryOpBase() = default; + UnaryOpBase(DType /*src_dtype*/, DType /*dst_dtype*/) {} +}; + +#define OPERATOR_UNARY_QINT8 \ + int16x8_t vsrct = vmovl_low_s8(vsrc.val[0]); \ + vst1_s8(reinterpret_cast(dst), \ + operator()({{vmovl_low_s16(vsrct), vmovl_high_s16(vsrct)}})); \ + \ + vsrct = vmovl_high_s8(vsrc.val[0]); \ + vst1_s8(reinterpret_cast(dst + 8), \ + operator()({{vmovl_low_s16(vsrct), vmovl_high_s16(vsrct)}})); \ + \ + vsrct = vmovl_low_s8(vsrc.val[1]); \ + vst1_s8(reinterpret_cast(dst + 16), \ + operator()({{vmovl_low_s16(vsrct), vmovl_high_s16(vsrct)}})); \ + \ + vsrct = vmovl_high_s8(vsrc.val[1]); \ + vst1_s8(reinterpret_cast(dst + 24), \ + operator()({{vmovl_low_s16(vsrct), vmovl_high_s16(vsrct)}})); + +#define OPERATOR_UNARY_QUINT8 \ + uint16x8_t vsrct = vmovl_low_u8(vsrc.val[0]); \ + vst1_u8(reinterpret_cast(dst), \ + operator()({{vmovl_low_u16(vsrct), vmovl_high_u16(vsrct)}})); \ + \ + vsrct = vmovl_high_u8(vsrc.val[0]); \ + vst1_u8(reinterpret_cast(dst + 8), \ + operator()({{vmovl_low_u16(vsrct), vmovl_high_u16(vsrct)}})); \ + \ + vsrct = vmovl_low_u8(vsrc.val[1]); \ + vst1_u8(reinterpret_cast(dst + 16), \ + operator()({{vmovl_low_u16(vsrct), vmovl_high_u16(vsrct)}})); \ + \ + vsrct = vmovl_high_u8(vsrc.val[1]); \ + vst1_u8(reinterpret_cast(dst + 24), \ + operator()({{vmovl_low_u16(vsrct), vmovl_high_u16(vsrct)}})); + +//! scale_src = src.scale; scale_dst = 1.f / dst.scale (div -> mul) +//! scale = src.scale / dst.scale +template <> +struct UnaryOpBase : OpBase { + using OpBase::OpBase; + float scale_src, scale_dst; + float32x4_t vscale_src, vscale_dst; + float scale; + float32x4_t vscale; + + void init(float src_scale, float dst_scale) { + scale_src = src_scale; + vscale_src = vdupq_n_f32(scale_src); + scale_dst = 1.f / dst_scale; + vscale_dst = vdupq_n_f32(scale_dst); + scale = src_scale / dst_scale; + vscale = vdupq_n_f32(scale); + } + + UnaryOpBase(DType src_dtype, DType dst_dtype) { + float src_scale = src_dtype.param().scale; + float dst_scale = dst_dtype.param().scale; + init(src_scale, dst_scale); + } + UnaryOpBase(float src_scale, float dst_scale) { + init(src_scale, dst_scale); + } +}; + + +//! 
scale_src = src.scale; scale_dst = 1.f / dst.scale +//! scale_zp = src.zp * src.scale; dzp = dst.zp +//! scale = src.scale / dst.scale; szp = src.zp * scale +template <> +struct UnaryOpBase : OpBase { + using OpBase::OpBase; + float scale_src, scale_dst; + float32x4_t vscale_src, vscale_dst; + float scale_zp; + float32x4_t vscale_zp; + uint8_t dzp; + int32x4_t vdzp; + float scale, szp; + float32x4_t vscale, vszp; + + void init(float src_scale, float dst_scale, uint8_t src_zp, + uint8_t dst_zp) { + scale_src = src_scale; + scale_dst = 1.f / dst_scale; + vscale_src = vdupq_n_f32(scale_src); + vscale_dst = vdupq_n_f32(scale_dst); + scale_zp = src_zp * src_scale; + vscale_zp = vdupq_n_f32(scale_zp); + dzp = dst_zp; + vdzp = vdupq_n_s32(static_cast(dzp)); + scale = src_scale / dst_scale; + vscale = vdupq_n_f32(scale); + szp = src_zp * scale; + vszp = vdupq_n_f32(szp); + } + UnaryOpBase(DType src_dtype, DType dst_dtype) { + float src_scale = src_dtype.param().scale; + float dst_scale = dst_dtype.param().scale; + uint8_t src_zp = src_dtype.param().zero_point; + uint8_t dst_zp = dst_dtype.param().zero_point; + init(src_scale, dst_scale, src_zp, dst_zp); + } + UnaryOpBase(float src_scale, float dst_scale, uint8_t src_zp, + uint8_t dst_zp) { + init(src_scale, dst_scale, src_zp, dst_zp); + } + float32x4x2_t cvt_to_float(const uint32x4x2_t& vsrc) { + auto vitem0 = vmulq_f32(vcvtq_f32_u32(vsrc.val[0]), this->vscale_src); + vitem0 = vsubq_f32(vitem0, this->vscale_zp); + auto vitem1 = vmulq_f32(vcvtq_f32_u32(vsrc.val[1]), this->vscale_src); + vitem1 = vsubq_f32(vitem1, this->vscale_zp); + return {{vitem0, vitem1}}; + } + uint8x8_t cvt_float_to_dst(float32x4x2_t& vsrc) { + auto vitem0 = vmulq_f32(vsrc.val[0], this->vscale_dst); + auto vitem1 = vmulq_f32(vsrc.val[1], this->vscale_dst); + return QConverter::convert( + {{vitem0, vitem1}}, this->vdzp); + } + float32x4x2_t cvt_to_fdst(const uint32x4x2_t& vsrc) { + auto vitem0 = vmulq_f32(vcvtq_f32_u32(vsrc.val[0]), this->vscale); + vitem0 = vsubq_f32(vitem0, this->vszp); + auto vitem1 = vmulq_f32(vcvtq_f32_u32(vsrc.val[1]), this->vscale); + vitem1 = vsubq_f32(vitem1, this->vszp); + return {{vitem0, vitem1}}; + } + uint8x8_t cvt_fdst_to_dst(float32x4x2_t& vsrc) { + return QConverter::convert( + vsrc, this->vdzp); + } +}; + +template <> +struct UnaryOpBase : OpBase { + using OpBase::OpBase; + using src_ctype = dt_qint32; + using dst_ctype = dt_qint8; + float scale; + float32x4_t vscale; + float scale_src, scale_dst; + float32x4_t vscale_src, vscale_dst; + + void init(float src_scale, float dst_scale) { + scale_src = src_scale; + vscale_src = vdupq_n_f32(src_scale); + scale_dst = 1 / dst_scale; + vscale_dst = vdupq_n_f32(scale_dst); + scale = src_scale / dst_scale; + vscale = vdupq_n_f32(scale); + } + + UnaryOpBase(DType src_dtype, DType dst_dtype) { + float src_scale = src_dtype.param().scale; + float dst_scale = dst_dtype.param().scale; + init(src_scale, dst_scale); + } + + UnaryOpBase(float src_scale, float dst_scale) { + init(src_scale, dst_scale); + } +}; + +template <> +struct UnaryOpBase : OpBase { + using OpBase::OpBase; + using src_ctype = dt_qint32; + using dst_ctype = dt_quint8; + float scale; + float32x4_t vscale; + float scale_src, scale_dst; + float32x4_t vscale_src, vscale_dst; + uint8_t zp; + int32x4_t vzp; + + void init(float src_scale, float dst_scale, uint8_t zero_point) { + scale_src = src_scale; + vscale_src = vdupq_n_f32(src_scale); + scale_dst = 1 / dst_scale; + vscale_dst = vdupq_n_f32(scale_dst); + zp = zero_point; + vzp = 
vdupq_n_s32(static_cast(zp)); + scale = src_scale / dst_scale; + vscale = vdupq_n_f32(scale); + } + + UnaryOpBase(DType src_dtype, DType dst_dtype) { + float src_scale = src_dtype.param().scale; + float dst_scale = dst_dtype.param().scale; + uint8_t zp = dst_dtype.param().zero_point; + init(src_scale, dst_scale, zp); + } + + UnaryOpBase(float src_scale, float dst_scale, uint8_t zero_point) { + init(src_scale, dst_scale, zero_point); + } +}; + +////////////////////////// binary ////////////////////////// +template +struct BinaryOpBase : OpBase { + using OpBase::OpBase; + BinaryOpBase() = default; + BinaryOpBase(DType /*src0_dtype*/, DType /*src1_dtype*/, + DType /*dst_dtype*/) {} +}; + +#define OPERATOR_BINARY_QINT8 \ + int16x8_t vsrct0 = vmovl_low_s8(vsrc0.val[0]); \ + int16x8_t vsrct1 = vmovl_low_s8(vsrc1.val[0]); \ + vst1_s8(reinterpret_cast(dst), \ + operator()({{vmovl_low_s16(vsrct0), vmovl_high_s16(vsrct0)}}, \ + {{vmovl_low_s16(vsrct1), vmovl_high_s16(vsrct1)}})); \ + \ + vsrct0 = vmovl_high_s8(vsrc0.val[0]); \ + vsrct1 = vmovl_high_s8(vsrc1.val[0]); \ + vst1_s8(reinterpret_cast(dst + 8), \ + operator()({{vmovl_low_s16(vsrct0), vmovl_high_s16(vsrct0)}}, \ + {{vmovl_low_s16(vsrct1), vmovl_high_s16(vsrct1)}})); \ + \ + vsrct0 = vmovl_low_s8(vsrc0.val[1]); \ + vsrct1 = vmovl_low_s8(vsrc1.val[1]); \ + vst1_s8(reinterpret_cast(dst + 16), \ + operator()({{vmovl_low_s16(vsrct0), vmovl_high_s16(vsrct0)}}, \ + {{vmovl_low_s16(vsrct1), vmovl_high_s16(vsrct1)}})); \ + \ + vsrct0 = vmovl_high_s8(vsrc0.val[1]); \ + vsrct1 = vmovl_high_s8(vsrc1.val[1]); \ + vst1_s8(reinterpret_cast(dst + 24), \ + operator()({{vmovl_low_s16(vsrct0), vmovl_high_s16(vsrct0)}}, \ + {{vmovl_low_s16(vsrct1), vmovl_high_s16(vsrct1)}})) + +#define OPERATOR_BINARY_QUINT8 \ + uint16x8_t vsrct0 = vmovl_low_u8(vsrc0.val[0]); \ + uint16x8_t vsrct1 = vmovl_low_u8(vsrc1.val[0]); \ + vst1_u8(reinterpret_cast(dst), \ + operator()({{vmovl_low_u16(vsrct0), vmovl_high_u16(vsrct0)}}, \ + {{vmovl_low_u16(vsrct1), vmovl_high_u16(vsrct1)}})); \ + \ + vsrct0 = vmovl_high_u8(vsrc0.val[0]); \ + vsrct1 = vmovl_high_u8(vsrc1.val[0]); \ + vst1_u8(reinterpret_cast(dst + 8), \ + operator()({{vmovl_low_u16(vsrct0), vmovl_high_u16(vsrct0)}}, \ + {{vmovl_low_u16(vsrct1), vmovl_high_u16(vsrct1)}})); \ + \ + vsrct0 = vmovl_low_u8(vsrc0.val[1]); \ + vsrct1 = vmovl_low_u8(vsrc1.val[1]); \ + vst1_u8(reinterpret_cast(dst + 16), \ + operator()({{vmovl_low_u16(vsrct0), vmovl_high_u16(vsrct0)}}, \ + {{vmovl_low_u16(vsrct1), vmovl_high_u16(vsrct1)}})); \ + \ + vsrct0 = vmovl_high_u8(vsrc0.val[1]); \ + vsrct1 = vmovl_high_u8(vsrc1.val[1]); \ + vst1_u8(reinterpret_cast(dst + 24), \ + operator()({{vmovl_low_u16(vsrct0), vmovl_high_u16(vsrct0)}}, \ + {{vmovl_low_u16(vsrct1), vmovl_high_u16(vsrct1)}})) + +/* ================= binary op for quantized types ================== */ + +//! scale_src0 = src0.scale; scale_src1 = src1.scale; scale_dst = 1.f / +//! 
dst.scale scale0 = src0.scale / dst.scale; scale1 = src1.scale / dst.scale +template <> +struct BinaryOpBase : OpBase { + using OpBase::OpBase; + using src_ctype = dt_qint8; + using dst_ctype = dt_qint8; + float scale_src0, scale_src1, scale_dst; + float32x4_t vscale_src0, vscale_src1, vscale_dst; + float scale0, scale1; + float32x4_t vscale0, vscale1; + + void init(float src0_scale, float src1_scale, float dst_scale) { + scale_src0 = src0_scale; + vscale_src0 = vdupq_n_f32(scale_src0); + scale_src1 = src1_scale; + vscale_src1 = vdupq_n_f32(scale_src1); + scale_dst = 1.f / dst_scale; + vscale_dst = vdupq_n_f32(scale_dst); + scale0 = src0_scale / dst_scale; + vscale0 = vdupq_n_f32(scale0); + scale1 = src1_scale / dst_scale; + vscale1 = vdupq_n_f32(scale1); + } + + BinaryOpBase(DType src0_dtype, DType src1_dtype, DType dst_dtype) { + float src0_scale = src0_dtype.param().scale; + float src1_scale = src1_dtype.param().scale; + float dst_scale = dst_dtype.param().scale; + init(src0_scale, src1_scale, dst_scale); + } + + BinaryOpBase(float src0_scale, float src1_scale, float dst_scale) { + init(src0_scale, src1_scale, dst_scale); + } +}; + +//! scale_src0 = src0.scale; scale_src1 = src1.scale; scale_dst = 1.f / +//! dst.scale scale_zp0 = src0.zp * src0.scale; scale_zp1 = src1.zp * src1.scale +//! scale0 = src0.scale / dst.scale; scale1 = src1.scale / dst.scale +//! szp0 = src0.zp * scale0; szp1 = src1.zp * scale1 +//! dzp = dst.zp +template <> +struct BinaryOpBase : OpBase { + using OpBase::OpBase; + using src_ctype = dt_quint8; + using dst_ctype = dt_quint8; + float scale_src0, scale_src1, scale_dst; + float32x4_t vscale_src0, vscale_src1, vscale_dst; + float scale_zp0, scale_zp1; + float32x4_t vscale_zp0, vscale_zp1; + float scale0, scale1, szp0, szp1; + float32x4_t vscale0, vscale1, vszp0, vszp1; + uint8_t dzp; + int32x4_t vdzp; + + void init(float src0_scale, float src1_scale, float dst_scale, + uint8_t src0_zp, uint8_t src1_zp, uint8_t dst_zp) { + scale_src0 = src0_scale; + vscale_src0 = vdupq_n_f32(scale_src0); + scale_src1 = src1_scale; + vscale_src1 = vdupq_n_f32(scale_src1); + scale_dst = 1.f / dst_scale; + vscale_dst = vdupq_n_f32(scale_dst); + scale_zp0 = src0_zp * src0_scale; + vscale_zp0 = vdupq_n_f32(scale_zp0); + scale_zp1 = src1_zp * src1_scale; + vscale_zp1 = vdupq_n_f32(scale_zp1); + scale0 = src0_scale / dst_scale; + vscale0 = vdupq_n_f32(scale0); + scale1 = src1_scale / dst_scale; + vscale1 = vdupq_n_f32(scale1); + dzp = dst_zp; + vdzp = vdupq_n_s32(static_cast(dzp)); + szp0 = src0_zp * scale0; + szp1 = src1_zp * scale1; + vszp0 = vdupq_n_f32(szp0); + vszp1 = vdupq_n_f32(szp1); + } + + BinaryOpBase(DType src0_dtype, DType src1_dtype, DType dst_dtype) { + float src0_scale = src0_dtype.param().scale; + float src1_scale = src1_dtype.param().scale; + float dst_scale = dst_dtype.param().scale; + uint8_t src0_zp = src0_dtype.param().zero_point; + uint8_t src1_zp = src1_dtype.param().zero_point; + uint8_t dst_zp = dst_dtype.param().zero_point; + init(src0_scale, src1_scale, dst_scale, src0_zp, src1_zp, dst_zp); + } + + BinaryOpBase(float src0_scale, float src1_scale, float dst_scale, + uint8_t src0_zp, uint8_t src1_zp, uint8_t dst_zp) { + init(src0_scale, src1_scale, dst_scale, src0_zp, src1_zp, dst_zp); + } +}; + +template <> +struct BinaryOpBase : OpBase { + using OpBase::OpBase; + using src_ctype = dt_qint32; + using dst_ctype = dt_qint8; + float scale0, scale1; + float32x4_t vscale0, vscale1; + float scale_src0, scale_src1, scale_dst; + float32x4_t vscale_src0, vscale_src1, 
vscale_dst; + + void init(float src0_scale, float src1_scale, float dst_scale) { + scale_src0 = src0_scale; + vscale_src0 = vdupq_n_f32(src0_scale); + scale_src1 = src1_scale; + vscale_src1 = vdupq_n_f32(src1_scale); + scale_dst = 1 / dst_scale; + vscale_dst = vdupq_n_f32(scale_dst); + scale0 = src0_scale / dst_scale; + vscale0 = vdupq_n_f32(scale0); + scale1 = src1_scale / dst_scale; + vscale1 = vdupq_n_f32(scale1); + } + + BinaryOpBase(DType src0_dtype, DType src1_dtype, DType dst_dtype) { + float src0_scale = src0_dtype.param().scale; + float src1_scale = src1_dtype.param().scale; + float dst_scale = dst_dtype.param().scale; + init(src0_scale, src1_scale, dst_scale); + } + + BinaryOpBase(float src0_scale, float src1_scale, float dst_scale) { + init(src0_scale, src1_scale, dst_scale); + } +}; + +template <> +struct BinaryOpBase : OpBase { + using OpBase::OpBase; + using src_ctype = dt_qint32; + using dst_ctype = dt_quint8; + float scale0, scale1; + float32x4_t vscale0, vscale1; + uint8_t zp; + int32x4_t vzp; + float scale_src0, scale_src1, scale_dst; + float32x4_t vscale_src0, vscale_src1, vscale_dst; + + void init(float src0_scale, float src1_scale, float dst_scale, + uint8_t zero_point) { + scale_src0 = src0_scale; + vscale_src0 = vdupq_n_f32(src0_scale); + scale_src1 = src1_scale; + vscale_src1 = vdupq_n_f32(src1_scale); + scale_dst = 1 / dst_scale; + vscale_dst = vdupq_n_f32(scale_dst); + zp = zero_point; + vzp = vdupq_n_s32(static_cast(zp)); + scale0 = src0_scale / dst_scale; + vscale0 = vdupq_n_f32(scale0); + scale1 = src1_scale / dst_scale; + vscale1 = vdupq_n_f32(scale1); + } + + BinaryOpBase(DType src0_dtype, DType src1_dtype, DType dst_dtype) { + float src0_scale = src0_dtype.param().scale; + float src1_scale = src1_dtype.param().scale; + float dst_scale = dst_dtype.param().scale; + uint8_t zp = dst_dtype.param().zero_point; + init(src0_scale, src1_scale, dst_scale, zp); + } + + BinaryOpBase(float src0_scale, float src1_scale, float dst_scale, + uint8_t zero_point) { + init(src0_scale, src1_scale, dst_scale, zero_point); + } +}; + +////////////////////////// ternary ////////////////////////// +template +struct TernaryOpBase : OpBase { + using OpBase::OpBase; + TernaryOpBase() = default; + TernaryOpBase(DType /*src0_dtype*/, DType /*src1_dtype*/, + DType /*src2_dtype*/, DType /*dst_dtype*/) {} +}; + +#define OPERATOR_TERNARY_QINT8 \ + int16x8_t vsrct0 = vmovl_low_s8(vsrc0.val[0]); \ + int16x8_t vsrct1 = vmovl_low_s8(vsrc1.val[0]); \ + int16x8_t vsrct2 = vmovl_low_s8(vsrc2.val[0]); \ + vst1_s8(reinterpret_cast(dst), \ + operator()({{vmovl_low_s16(vsrct0), vmovl_high_s16(vsrct0)}}, \ + {{vmovl_low_s16(vsrct1), vmovl_high_s16(vsrct1)}}, \ + {{vmovl_low_s16(vsrct2), vmovl_high_s16(vsrct2)}})); \ + \ + vsrct0 = vmovl_high_s8(vsrc0.val[0]); \ + vsrct1 = vmovl_high_s8(vsrc1.val[0]); \ + vsrct2 = vmovl_high_s8(vsrc2.val[0]); \ + vst1_s8(reinterpret_cast(dst + 8), \ + operator()({{vmovl_low_s16(vsrct0), vmovl_high_s16(vsrct0)}}, \ + {{vmovl_low_s16(vsrct1), vmovl_high_s16(vsrct1)}}, \ + {{vmovl_low_s16(vsrct2), vmovl_high_s16(vsrct2)}})); \ + \ + vsrct0 = vmovl_low_s8(vsrc0.val[1]); \ + vsrct1 = vmovl_low_s8(vsrc1.val[1]); \ + vsrct2 = vmovl_low_s8(vsrc2.val[1]); \ + vst1_s8(reinterpret_cast(dst + 16), \ + operator()({{vmovl_low_s16(vsrct0), vmovl_high_s16(vsrct0)}}, \ + {{vmovl_low_s16(vsrct1), vmovl_high_s16(vsrct1)}}, \ + {{vmovl_low_s16(vsrct2), vmovl_high_s16(vsrct2)}})); \ + \ + vsrct0 = vmovl_high_s8(vsrc0.val[1]); \ + vsrct1 = vmovl_high_s8(vsrc1.val[1]); \ + vsrct2 = 
vmovl_high_s8(vsrc2.val[1]); \ + vst1_s8(reinterpret_cast(dst + 24), \ + operator()({{vmovl_low_s16(vsrct0), vmovl_high_s16(vsrct0)}}, \ + {{vmovl_low_s16(vsrct1), vmovl_high_s16(vsrct1)}}, \ + {{vmovl_low_s16(vsrct2), vmovl_high_s16(vsrct2)}})) + +#define OPERATOR_TERNARY_QUINT8 \ + uint16x8_t vsrct0 = vmovl_low_u8(vsrc0.val[0]); \ + uint16x8_t vsrct1 = vmovl_low_u8(vsrc1.val[0]); \ + uint16x8_t vsrct2 = vmovl_low_u8(vsrc2.val[0]); \ + vst1_u8(reinterpret_cast(dst), \ + operator()({{vmovl_low_u16(vsrct0), vmovl_high_u16(vsrct0)}}, \ + {{vmovl_low_u16(vsrct1), vmovl_high_u16(vsrct1)}}, \ + {{vmovl_low_u16(vsrct2), vmovl_high_u16(vsrct2)}})); \ + \ + vsrct0 = vmovl_high_u8(vsrc0.val[0]); \ + vsrct1 = vmovl_high_u8(vsrc1.val[0]); \ + vsrct2 = vmovl_high_u8(vsrc2.val[0]); \ + vst1_u8(reinterpret_cast(dst + 8), \ + operator()({{vmovl_low_u16(vsrct0), vmovl_high_u16(vsrct0)}}, \ + {{vmovl_low_u16(vsrct1), vmovl_high_u16(vsrct1)}}, \ + {{vmovl_low_u16(vsrct2), vmovl_high_u16(vsrct2)}})); \ + \ + vsrct0 = vmovl_low_u8(vsrc0.val[1]); \ + vsrct1 = vmovl_low_u8(vsrc1.val[1]); \ + vsrct2 = vmovl_low_u8(vsrc2.val[1]); \ + vst1_u8(reinterpret_cast(dst + 16), \ + operator()({{vmovl_low_u16(vsrct0), vmovl_high_u16(vsrct0)}}, \ + {{vmovl_low_u16(vsrct1), vmovl_high_u16(vsrct1)}}, \ + {{vmovl_low_u16(vsrct2), vmovl_high_u16(vsrct2)}})); \ + \ + vsrct0 = vmovl_high_u8(vsrc0.val[1]); \ + vsrct1 = vmovl_high_u8(vsrc1.val[1]); \ + vsrct2 = vmovl_high_u8(vsrc2.val[1]); \ + vst1_u8(reinterpret_cast(dst + 24), \ + operator()({{vmovl_low_u16(vsrct0), vmovl_high_u16(vsrct0)}}, \ + {{vmovl_low_u16(vsrct1), vmovl_high_u16(vsrct1)}}, \ + {{vmovl_low_u16(vsrct2), vmovl_high_u16(vsrct2)}})) + +/*========================= ternaty op for quanzited ====================*/ +template <> +struct TernaryOpBase : OpBase { + using OpBase::OpBase; + using src_ctype = dt_qint8; + using dst_ctype = dt_qint8; + float scale_src0, scale_src1, scale_src2, scale_dst; + float32x4_t vscale_src0, vscale_src1, vscale_src2, vscale_dst; + float scale0, scale1, scale2; + float32x4_t vscale0, vscale1, vscale2; + void init(float src0_scale, float src1_scale, float src2_scale, + float dst_scale) { + scale_src0 = src0_scale; + scale_src1 = src1_scale; + scale_src2 = src2_scale; + scale_dst = 1.f / dst_scale; + vscale_src0 = vdupq_n_f32(scale_src0); + vscale_src1 = vdupq_n_f32(scale_src1); + vscale_src2 = vdupq_n_f32(scale_src2); + vscale_dst = vdupq_n_f32(scale_dst); + scale0 = src0_scale / dst_scale; + scale1 = src1_scale / dst_scale; + scale2 = src2_scale / dst_scale; + vscale0 = vdupq_n_f32(scale0); + vscale1 = vdupq_n_f32(scale1); + vscale2 = vdupq_n_f32(scale2); + } + TernaryOpBase(DType src0_dtype, DType src1_dtype, DType src2_dtype, + DType dst_dtype) { + float src0_scale = src0_dtype.param().scale; + float src1_scale = src1_dtype.param().scale; + float src2_scale = src2_dtype.param().scale; + float dst_scale = dst_dtype.param().scale; + init(src0_scale, src1_scale, src2_scale, dst_scale); + } + TernaryOpBase(float src0_scale, float src1_scale, float src2_scale, + float dst_scale) { + init(src0_scale, src1_scale, src2_scale, dst_scale); + } +}; + +template <> +struct TernaryOpBase : OpBase { + using OpBase::OpBase; + using src_ctype = dt_quint8; + using dst_ctype = dt_quint8; + float scale_src0, scale_src1, scale_src2, scale_dst; + float32x4_t vscale_src0, vscale_src1, vscale_src2, vscale_dst; + float scale_zp0, scale_zp1, scale_zp2; + float32x4_t vscale_zp0, vscale_zp1, vscale_zp2; + float scale0, scale1, scale2; + float32x4_t vscale0, 
vscale1, vscale2; + uint8_t dzp; + int32x4_t vdzp; + void init(float src0_scale, float src1_scale, float src2_scale, + float dst_scale, uint8_t src0_zp, uint8_t src1_zp, + uint8_t src2_zp, uint8_t dst_zp) { + scale_src0 = src0_scale; + scale_src1 = src1_scale; + scale_src2 = src2_scale; + scale_dst = 1.f / dst_scale; + vscale_src0 = vdupq_n_f32(scale_src0); + vscale_src1 = vdupq_n_f32(scale_src1); + vscale_src2 = vdupq_n_f32(scale_src2); + vscale_dst = vdupq_n_f32(scale_dst); + scale_zp0 = src0_zp * scale_src0; + scale_zp1 = src1_zp * scale_src1; + scale_zp2 = src2_zp * scale_src2; + vscale_zp0 = vdupq_n_f32(scale_zp0); + vscale_zp1 = vdupq_n_f32(scale_zp1); + vscale_zp2 = vdupq_n_f32(scale_zp2); + scale0 = src0_scale / dst_scale; + scale1 = src1_scale / dst_scale; + scale2 = src2_scale / dst_scale; + vscale0 = vdupq_n_f32(scale0); + vscale1 = vdupq_n_f32(scale1); + vscale2 = vdupq_n_f32(scale2); + dzp = dst_zp; + vdzp = vdupq_n_s32(static_cast(dzp)); + } + TernaryOpBase(DType src0_dtype, DType src1_dtype, DType src2_dtype, + DType dst_dtype) { + float src0_scale = src0_dtype.param().scale; + float src1_scale = src1_dtype.param().scale; + float src2_scale = src2_dtype.param().scale; + float dst_scale = dst_dtype.param().scale; + uint8_t src0_zp = src0_dtype.param().zero_point; + uint8_t src1_zp = src1_dtype.param().zero_point; + uint8_t src2_zp = src2_dtype.param().zero_point; + uint8_t dst_zp = dst_dtype.param().zero_point; + init(src0_scale, src1_scale, src2_scale, dst_scale, src0_zp, src1_zp, + src2_zp, dst_zp); + } + TernaryOpBase(float src0_scale, float src1_scale, float src2_scale, + float dst_scale, uint8_t src0_zp, uint8_t src1_zp, + uint8_t src2_zp, uint8_t dst_zp) { + init(src0_scale, src1_scale, src2_scale, dst_scale, src0_zp, src1_zp, + src2_zp, dst_zp); + } +}; + +////////////////////////// fixup ////////////////////////// +struct FixupBase { + int32x4_t vmultiplier, vshift; + FixupBase(float scale) { + //! ignore Fixup if scale >= 0.5, using typecvt instead of shift & + //! multiplier, as it may introduce errors. + if (scale >= 0.5) + return; + + int shift = static_cast(::ceilf(::log2f(0.5 / scale))); + scale *= ::powf(2, shift); + //! Using double can get full precision here, but it can be ignored. 
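+        //! Worked example (illustrative, assuming the usual round-to-nearest
+        //! behaviour of vqrdmulhq_s32): for scale = 0.1, shift =
+        //! ceil(log2(0.5 / 0.1)) = 3 and the rescaled scale is 0.1 * 2^3 =
+        //! 0.8, so multiplier = round(0.8 * 2^31) = 1717986918. At runtime
+        //! vqrdmulhq_s32(x, multiplier) ~= round(x * 0.8), and the later
+        //! rounding shift by 3 divides by 2^3, recovering x * 0.1 overall;
+        //! (2LL) << 30 is simply 2^31, the unit of the Q0.31 fixed-point
+        //! format.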
+ vmultiplier = vdupq_n_s32( + std::round(static_cast(scale) * ((2LL) << 30))); + vshift = vdupq_n_s32(-shift); + } +}; + +//////////////////////// quantization common //////////////////// +template +struct UnaryQuantizationOp; + +template +struct UnaryQuantizationOp + : UnaryOpBase { + using UnaryOpBase::UnaryOpBase; + constexpr static size_t SIMD_WIDTH = 16; + Op op; + + void operator()(const dt_qint8& src, dt_qint8* dst) const { + *dst = operator()(src); + } + + dt_qint8 operator()(const dt_qint8& src) const { + float fsrc = src.as_int8() * this->scale_src; + fsrc = op(fsrc); + fsrc = fsrc * this->scale_dst; + return QConverter::convert(fsrc); + } + + void operator()(const int8x16x2_t& vsrc, dt_qint8* dst) const { + OPERATOR_UNARY_QINT8; + } + + int8x8_t operator()(const int32x4x2_t& vsrc) const { + auto vitem0 = vmulq_f32(vcvtq_f32_s32(vsrc.val[0]), this->vscale_src); + auto vitem1 = vmulq_f32(vcvtq_f32_s32(vsrc.val[1]), this->vscale_src); + auto val = this->op({{vitem0, vitem1}}); + val.val[0] = vmulq_f32(val.val[0], this->vscale_dst); + val.val[1] = vmulq_f32(val.val[1], this->vscale_dst); + return QConverter::convert(val); + } +}; + +template +struct UnaryQuantizationOp + : UnaryOpBase { + using UnaryOpBase::UnaryOpBase; + constexpr static size_t SIMD_WIDTH = 16; + Op op; + + void operator()(const dt_quint8& src, dt_quint8* dst) const { + *dst = operator()(src); + } + + dt_quint8 operator()(const dt_quint8& src) const { + float fsrc = src.as_uint8() * this->scale_src - this->scale_zp; + fsrc = op(fsrc); + fsrc = fsrc * this->scale_dst; + return QConverter::convert(fsrc, this->dzp); + } + + void operator()(const uint8x16x2_t& vsrc, dt_quint8* dst) const { + OPERATOR_UNARY_QUINT8; + } + + uint8x8_t operator()(const uint32x4x2_t& vsrc) const { + auto vitem0 = vmulq_f32(vcvtq_f32_u32(vsrc.val[0]), this->vscale_src); + vitem0 = vsubq_f32(vitem0, this->vscale_zp); + auto vitem1 = vmulq_f32(vcvtq_f32_u32(vsrc.val[1]), this->vscale_src); + vitem1 = vsubq_f32(vitem1, this->vscale_zp); + auto val = this->op({{vitem0, vitem1}}); + val.val[0] = vmulq_f32(val.val[0], this->vscale_dst); + val.val[1] = vmulq_f32(val.val[1], this->vscale_dst); + return QConverter::convert( + val, this->vdzp); + } +}; + +template +struct BinaryQuantizationOp; + +template +struct BinaryQuantizationOp + : BinaryOpBase { + using BinaryOpBase::BinaryOpBase; + constexpr static size_t SIMD_WIDTH = 16; + Op op; + + void operator()(const dt_qint8& src0, const dt_qint8& src1, + dt_qint8* dst) const { + *dst = operator()(src0, src1); + } + + dt_qint8 operator()(const dt_qint8& src0, const dt_qint8& src1) const { + float fsrc0 = src0.as_int8() * this->scale_src0; + float fsrc1 = src1.as_int8() * this->scale_src1; + float fdst = op(fsrc0, fsrc1); + fdst = fdst * this->scale_dst; + return QConverter::convert(fdst); + } + + void operator()(const int8x16x2_t& vsrc0, const int8x16x2_t& vsrc1, + dt_qint8* dst) const { + OPERATOR_BINARY_QINT8; + } + + int8x8_t operator()(const int32x4x2_t& vsrc0, + const int32x4x2_t& vsrc1) const { + auto val0 = vmulq_f32(vcvtq_f32_s32(vsrc0.val[0]), this->vscale_src0); + auto val1 = vmulq_f32(vcvtq_f32_s32(vsrc0.val[1]), this->vscale_src0); + auto val2 = vmulq_f32(vcvtq_f32_s32(vsrc1.val[0]), this->vscale_src1); + auto val3 = vmulq_f32(vcvtq_f32_s32(vsrc1.val[1]), this->vscale_src1); + auto val = op({{val0, val1}}, {{val2, val3}}); + val.val[0] = vmulq_f32(val.val[0], this->vscale_dst); + val.val[1] = vmulq_f32(val.val[1], this->vscale_dst); + return QConverter::convert(val); + } +}; + +template 
+struct BinaryQuantizationOp + : BinaryOpBase { + using BinaryOpBase::BinaryOpBase; + constexpr static size_t SIMD_WIDTH = 16; + Op op; + + void operator()(const dt_quint8& src0, const dt_quint8& src1, + dt_quint8* dst) const { + *dst = operator()(src0, src1); + } + + dt_quint8 operator()(const dt_quint8& src0, const dt_quint8& src1) const { + float fsrc0 = src0.as_uint8() * this->scale_src0 - this->scale_zp0; + float fsrc1 = src1.as_uint8() * this->scale_src1 - this->scale_zp1; + float fdst = op(fsrc0, fsrc1); + fdst = fdst * this->scale_dst; + return QConverter::convert(fdst, this->dzp); + } + + void operator()(const uint8x16x2_t& vsrc0, const uint8x16x2_t& vsrc1, + dt_quint8* dst) const { + OPERATOR_BINARY_QUINT8; + } + + uint8x8_t operator()(const uint32x4x2_t& vsrc0, + const uint32x4x2_t& vsrc1) const { + auto val0 = vmulq_f32(vcvtq_f32_u32(vsrc0.val[0]), this->vscale_src0); + val0 = vsubq_f32(val0, this->vscale_zp0); + auto val1 = vmulq_f32(vcvtq_f32_u32(vsrc0.val[1]), this->vscale_src0); + val1 = vsubq_f32(val1, this->vscale_zp0); + auto val2 = vmulq_f32(vcvtq_f32_u32(vsrc1.val[0]), this->vscale_src1); + val2 = vsubq_f32(val2, this->vscale_zp1); + auto val3 = vmulq_f32(vcvtq_f32_u32(vsrc1.val[1]), this->vscale_src1); + val3 = vsubq_f32(val3, this->vscale_zp1); + auto val = op({{val0, val1}}, {{val2, val3}}); + val.val[0] = vmulq_f32(val.val[0], this->vscale_dst); + val.val[1] = vmulq_f32(val.val[1], this->vscale_dst); + return QConverter::convert( + val, this->vdzp); + } +}; + +template +struct TernaryQuantizationOp; + +template +struct TernaryQuantizationOp + : TernaryOpBase { + using TernaryOpBase::TernaryOpBase; + constexpr static size_t SIMD_WIDTH = 16; + Op op; + + void operator()(const dt_qint8& src0, const dt_qint8& src1, + const dt_qint8& src2, dt_qint8* dst) const { + *dst = operator()(src0, src1, src2); + } + + dt_qint8 operator()(const dt_qint8& src0, const dt_qint8& src1, + const dt_qint8& src2) const { + float fsrc0 = src0.as_int8() * this->scale_src0; + float fsrc1 = src1.as_int8() * this->scale_src1; + float fsrc2 = src2.as_int8() * this->scale_src2; + float fdst = op(fsrc0, fsrc1, fsrc2); + fdst = fdst * this->scale_dst; + return QConverter::convert(fdst); + } + + void operator()(const int8x16x2_t& vsrc0, const int8x16x2_t& vsrc1, + const int8x16x2_t& vsrc2, dt_qint8* dst) const { + OPERATOR_TERNARY_QINT8; + } + + int8x8_t operator()(const int32x4x2_t& vsrc0, + const int32x4x2_t& vsrc1, + const int32x4x2_t& vsrc2) const { + auto val0 = vmulq_f32(vcvtq_f32_s32(vsrc0.val[0]), this->vscale_src0); + auto val1 = vmulq_f32(vcvtq_f32_s32(vsrc0.val[1]), this->vscale_src0); + auto val2 = vmulq_f32(vcvtq_f32_s32(vsrc1.val[0]), this->vscale_src1); + auto val3 = vmulq_f32(vcvtq_f32_s32(vsrc1.val[1]), this->vscale_src1); + auto val4 = vmulq_f32(vcvtq_f32_s32(vsrc2.val[0]), this->vscale_src2); + auto val5 = vmulq_f32(vcvtq_f32_s32(vsrc2.val[1]), this->vscale_src2); + auto val = op({{val0, val1}}, {{val2, val3}}, {{val4, val5}}); + val.val[0] = vmulq_f32(val.val[0], this->vscale_dst); + val.val[1] = vmulq_f32(val.val[1], this->vscale_dst); + return QConverter::convert(val); + } +}; + +template +struct TernaryQuantizationOp + : TernaryOpBase { + using TernaryOpBase::TernaryOpBase; + constexpr static size_t SIMD_WIDTH = 16; + Op op; + + void operator()(const dt_quint8& src0, const dt_quint8& src1, + const dt_quint8& src2, dt_quint8* dst) const { + *dst = operator()(src0, src1, src2); + } + + dt_quint8 operator()(const dt_quint8& src0, const dt_quint8& src1, + const dt_quint8& src2) 
const { + float fsrc0 = src0.as_uint8() * this->scale_src0 - this->scale_zp0; + float fsrc1 = src1.as_uint8() * this->scale_src1 - this->scale_zp1; + float fsrc2 = src2.as_uint8() * this->scale_src2 - this->scale_zp2; + float fdst = op(fsrc0, fsrc1, fsrc2); + fdst = fdst * this->scale_dst; + return QConverter::convert(fdst, this->dzp); + } + + void operator()(const uint8x16x2_t& vsrc0, const uint8x16x2_t& vsrc1, + const uint8x16x2_t& vsrc2, dt_quint8* dst) const { + OPERATOR_TERNARY_QUINT8; + } + + uint8x8_t operator()(const uint32x4x2_t& vsrc0, const uint32x4x2_t& vsrc1, + const uint32x4x2_t& vsrc2) const { + auto val0 = vmulq_f32(vcvtq_f32_u32(vsrc0.val[0]), this->vscale_src0); + val0 = vsubq_f32(val0, this->vscale_zp0); + auto val1 = vmulq_f32(vcvtq_f32_u32(vsrc0.val[1]), this->vscale_src0); + val1 = vsubq_f32(val1, this->vscale_zp0); + auto val2 = vmulq_f32(vcvtq_f32_u32(vsrc1.val[0]), this->vscale_src1); + val2 = vsubq_f32(val2, this->vscale_zp1); + auto val3 = vmulq_f32(vcvtq_f32_u32(vsrc1.val[1]), this->vscale_src1); + val3 = vsubq_f32(val3, this->vscale_zp1); + auto val4 = vmulq_f32(vcvtq_f32_u32(vsrc2.val[0]), this->vscale_src2); + val4 = vsubq_f32(val4, this->vscale_zp2); + auto val5 = vmulq_f32(vcvtq_f32_u32(vsrc2.val[1]), this->vscale_src2); + val5 = vsubq_f32(val5, this->vscale_zp2); + auto val = op({{val0, val1}}, {{val2, val3}}, {{val4, val5}}); + val.val[0] = vmulq_f32(val.val[0], this->vscale_dst); + val.val[1] = vmulq_f32(val.val[1], this->vscale_dst); + return QConverter::convert( + val, this->vdzp); + } +}; + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/elemwise_helper/kimpl/pow.h b/dnn/src/arm_common/elemwise_helper/kimpl/pow.h new file mode 100644 index 00000000..d8b9b982 --- /dev/null +++ b/dnn/src/arm_common/elemwise_helper/kimpl/pow.h @@ -0,0 +1,36 @@ +/** + * \file dnn/src/arm_common/elemwise_helper/kimpl/pow.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once + +#include "src/arm_common/elemwise_helper/kimpl/op_base.h" + +namespace megdnn { +namespace arm_common { + +// when __fp16 is avaliable POW is very slow, so add there +/////////////////////// POW float only //////////////////////////// +template +struct PowOp : BinaryOpBase { + using BinaryOpBase::BinaryOpBase; + constexpr static size_t SIMD_WIDTH = 1; + void operator()(const src_ctype& src0, const src_ctype& src1, + dst_ctype* dst) const { + *dst = operator()(src0, src1); + } + dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { + return powf(src0, src1); + } +}; + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/elemwise_helper/kimpl/relu.h b/dnn/src/arm_common/elemwise_helper/kimpl/relu.h new file mode 100644 index 00000000..5335070c --- /dev/null +++ b/dnn/src/arm_common/elemwise_helper/kimpl/relu.h @@ -0,0 +1,246 @@ +/** + * \file dnn/src/arm_common/elemwise_helper/kimpl/relu.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#pragma once + +#include "src/arm_common/elemwise_helper/kimpl/op_base.h" + +namespace megdnn { +namespace arm_common { + +template +struct ReluOpBase : UnaryOpBase { + using UnaryOpBase::UnaryOpBase; + void operator()(const src_ctype& src, dst_ctype* dst) const { + *dst = operator()(src); + } + dst_ctype operator()(const src_ctype& src) const { + return src > 0 ? src : 0; + } +}; + +template +struct ReluOp; + +#define OP(_ctype, _neon_type, _neon_type2, _func_suffix, _simd_width) \ + template <> \ + struct ReluOp<_ctype> : ReluOpBase<_ctype> { \ + using ReluOpBase::ReluOpBase; \ + using ReluOpBase::operator(); \ + constexpr static size_t SIMD_WIDTH = _simd_width; \ + void operator()(const _neon_type2& src, _ctype* dst) const { \ + auto vitem = operator()(src); \ + vst1q_##_func_suffix(dst, vitem.val[0]); \ + vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ + } \ + _neon_type2 operator()(const _neon_type2& src) const { \ + auto vzero = vdupq_n_##_func_suffix(0); \ + auto vitem0 = vmaxq_##_func_suffix(src.val[0], vzero); \ + auto vitem1 = vmaxq_##_func_suffix(src.val[1], vzero); \ + return {{vitem0, vitem1}}; \ + } \ + _neon_type operator()(const _neon_type& src) const { \ + auto vzero = vdupq_n_##_func_suffix(0); \ + return vmaxq_##_func_suffix(src, vzero); \ + } \ + }; +OP(dt_float32, float32x4_t, float32x4x2_t, f32, 4) +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +OP(__fp16, float16x8_t, float16x8x2_t, f16, 8) +#endif +OP(dt_int32, int32x4_t, int32x4x2_t, s32, 4) +OP(dt_int16, int16x8_t, int16x8x2_t, s16, 8) +OP(dt_int8, int8x16_t, int8x16x2_t, s8, 16) +#undef OP + +template <> +struct ReluOpBase : UnaryOpBase { + using UnaryOpBase::UnaryOpBase; + void operator()(const dt_qint8& src, dt_qint8* dst) const { + *dst = operator()(src); + } + dt_qint8 operator()(const dt_qint8& src) const { + float fsrc = src.as_int8() * this->scale; + fsrc = std::max(fsrc, 0.f); + return QConverter::convert(fsrc); + } +}; + +template <> +struct ReluOpBase : UnaryOpBase { + using UnaryOpBase::UnaryOpBase; + void operator()(const dt_quint8& src, dt_quint8* dst) const { + *dst = operator()(src); + } + dt_quint8 operator()(const dt_quint8& src) const { + float fsrc = src.as_uint8() * this->scale - szp; + fsrc = std::max(fsrc, 0.f); + return QConverter::convert(fsrc, this->dzp); + } +}; + +template <> +struct ReluOp : ReluOpBase { + using ReluOpBase::ReluOpBase; + constexpr static size_t SIMD_WIDTH = 16; + using ReluOpBase::operator(); + + void operator()(const int8x16x2_t& vsrc, dt_qint8* dst) const { + OPERATOR_UNARY_QINT8; + } + int8x8_t operator()(const int32x4x2_t& vsrc) const { + auto vzero = vdupq_n_f32(0.f); + auto vitem0 = vmulq_f32(vcvtq_f32_s32(vsrc.val[0]), this->vscale); + auto vitem1 = vmulq_f32(vcvtq_f32_s32(vsrc.val[1]), this->vscale); + vitem0 = vmaxq_f32(vitem0, vzero); + vitem1 = vmaxq_f32(vitem1, vzero); + return QConverter::convert({{vitem0, vitem1}}); + } +}; + +template <> +struct ReluOp : ReluOpBase { + using ReluOpBase::ReluOpBase; + constexpr static size_t SIMD_WIDTH = 16; + using ReluOpBase::operator(); + + void operator()(const uint8x16x2_t& vsrc, dt_quint8* dst) const { + OPERATOR_UNARY_QUINT8; + } + uint8x8_t operator()(const uint32x4x2_t& vsrc) const { + auto vzero = vdupq_n_f32(0.f); + auto vitem0 = vmulq_f32(vcvtq_f32_u32(vsrc.val[0]), 
this->vscale); + auto vitem1 = vmulq_f32(vcvtq_f32_u32(vsrc.val[1]), this->vscale); + vitem0 = vsubq_f32(vitem0, this->vszp); + vitem1 = vsubq_f32(vitem1, this->vszp); + vitem0 = vmaxq_f32(vitem0, vzero); + vitem1 = vmaxq_f32(vitem1, vzero); + return QConverter::convert( + {{vitem0, vitem1}}, this->vdzp); + } +}; + +template <> +struct ReluOpBase : UnaryOpBase { + using UnaryOpBase::UnaryOpBase; + void operator()(const dt_qint32& src, dt_qint8* dst) const { + *dst = operator()(src); + } + + dt_qint8 operator()(const dt_qint32& src) const { + float fsrc = src.as_int32() * this->scale; + fsrc = std::max(fsrc, 0.f); + return QConverter::convert(fsrc); + } +}; + +template <> +struct ReluOpBase : UnaryOpBase { + using UnaryOpBase::UnaryOpBase; + void operator()(const dt_qint32& src, dt_quint8* dst) const { + *dst = operator()(src); + } + + dt_quint8 operator()(const dt_qint32& src) const { + float fsrc = src.as_int32() * this->scale; + fsrc = std::max(fsrc, 0.f); + return QConverter::convert(fsrc, this->zp); + } +}; + +#if __ARM_ARCH >= 8 +template <> +struct ReluOp : ReluOpBase { + using ReluOpBase::ReluOpBase; + using ReluOpBase::operator(); + constexpr static size_t SIMD_WIDTH = 4; + + void operator()(const int32x4x2_t& vsrc, dt_qint8* dst) const { + vst1_s8(reinterpret_cast(dst), operator()(vsrc)); + } + void operator()(const int32x4_t& src, dt_qint8* dst) const { + vst1_lane_s32(reinterpret_cast(dst), + (int32x2_t)(operator()(src)), 0); + } + + int8x8_t operator()(const int32x4x2_t& vsrc) const { + auto vitem0 = vmulq_f32(vcvtq_f32_s32(vsrc.val[0]), this->vscale); + auto vitem1 = vmulq_f32(vcvtq_f32_s32(vsrc.val[1]), this->vscale); + vitem0 = vmaxq_f32(vitem0, QConverterBase::vfzero()); + vitem1 = vmaxq_f32(vitem1, QConverterBase::vfzero()); + + return QConverter::convert({{vitem0, vitem1}}); + } + int8x8_t operator()(const int32x4_t& src) const { + auto vitem0 = vmulq_f32(vcvtq_f32_s32(src), this->vscale); + vitem0 = vmaxq_f32(vitem0, QConverterBase::vfzero()); + return QConverter::convert(vitem0); + } +}; +#else +template <> +struct ReluOp : ReluOpBase, + FixupBase { + using ReluOpBase::operator(); + constexpr static size_t SIMD_WIDTH = 4; + + ReluOp(DType src_dtype, DType dst_dtype) + : ReluOpBase(src_dtype, dst_dtype), FixupBase(scale) {} + + ReluOp(float src_scale, float dst_scale) + : ReluOpBase(src_scale, dst_scale), FixupBase(scale) {} + + void operator()(const int32x4x2_t& vsrc, dt_qint8* dst) const { + vst1_s8(reinterpret_cast(dst), operator()(vsrc)); + } + + int8x8_t operator()(const int32x4x2_t& vsrc) const { + int32x4_t vitem0 = vqrdmulhq_s32(vsrc.val[0], vmultiplier); + int32x4_t vitem1 = vqrdmulhq_s32(vsrc.val[1], vmultiplier); + vitem0 = vmaxq_s32(vitem0, QConverterBase::vzero()); + vitem1 = vmaxq_s32(vitem1, QConverterBase::vzero()); + return vqmovn_s16(vcombine_s16(vqmovn_s32(vrshlq_s32(vitem0, vshift)), + vqmovn_s32(vrshlq_s32(vitem1, vshift)))); + } + void operator()(const int32x4_t& src, dt_qint8* dst) const { + auto vitem0 = vmulq_f32(vcvtq_f32_s32(src), this->vscale); + vitem0 = vmaxq_f32(vitem0, QConverterBase::vfzero()); + auto result = QConverter::convert(vitem0); + vst1_lane_s32(reinterpret_cast(dst), (int32x2_t)result, 0); + } +}; +#endif + +template <> +struct ReluOp : ReluOpBase { + using ReluOpBase::ReluOpBase; + using ReluOpBase::operator(); + constexpr static size_t SIMD_WIDTH = 4; + + void operator()(const int32x4x2_t& vsrc, dt_quint8* dst) const { + vst1_u8(reinterpret_cast(dst), operator()(vsrc)); + } + + uint8x8_t operator()(const int32x4x2_t& vsrc) 
const { + auto vitem0 = vmulq_f32(vcvtq_f32_s32(vsrc.val[0]), this->vscale); + auto vitem1 = vmulq_f32(vcvtq_f32_s32(vsrc.val[1]), this->vscale); + vitem0 = vmaxq_f32(vitem0, QConverterBase::vfzero()); + vitem1 = vmaxq_f32(vitem1, QConverterBase::vfzero()); + + return QConverter::convert({{vitem0, vitem1}}, + this->vzp); + } +}; + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/elemwise_helper/kimpl/rmulh.h b/dnn/src/arm_common/elemwise_helper/kimpl/rmulh.h new file mode 100644 index 00000000..89e387f3 --- /dev/null +++ b/dnn/src/arm_common/elemwise_helper/kimpl/rmulh.h @@ -0,0 +1,144 @@ +/** + * \file dnn/src/arm_common/elemwise_helper/kimpl/rmulh.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once + +#include "src/arm_common/elemwise_helper/kimpl/op_base.h" + +namespace megdnn { +namespace arm_common { + +template +struct RmulhOpBase : BinaryOpBase { + using BinaryOpBase::BinaryOpBase; + void operator()(const src_ctype& src0, const src_ctype& src1, + dst_ctype* dst) const { + *dst = operator()(src0, src1); + } + dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { + return round_mulh_saturate(src0, src1); + } +}; + +template +struct RmulhOp; + +#define OP(_ctype, _neon_type, _neon_type2, _func_suffix, _simd_width) \ + template <> \ + struct RmulhOp<_ctype> : RmulhOpBase<_ctype> { \ + using RmulhOpBase::RmulhOpBase; \ + using RmulhOpBase::operator(); \ + constexpr static size_t SIMD_WIDTH = _simd_width; \ + void operator()(const _neon_type2& src0, const _neon_type2& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + vst1q_##_func_suffix(dst, vitem.val[0]); \ + vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ + } \ + _neon_type2 operator()(const _neon_type2& src0, \ + const _neon_type2& src1) const { \ + auto vitem0 = vqrdmulhq_##_func_suffix(src0.val[0], src1.val[0]); \ + auto vitem1 = vqrdmulhq_##_func_suffix(src0.val[1], src1.val[1]); \ + return {{vitem0, vitem1}}; \ + } \ + void operator()(const _neon_type& src0, const _neon_type& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + vst1q_##_func_suffix(dst, vitem); \ + } \ + _neon_type operator()(const _neon_type& src0, \ + const _neon_type& src1) const { \ + return vqrdmulhq_##_func_suffix(src0, src1); \ + } \ + }; +OP(dt_int32, int32x4_t, int32x4x2_t, s32, 4) +OP(dt_int16, int16x8_t, int16x8x2_t, s16, 8) +#undef OP +/** + * As There is no vqrdmulh.s8, we have to emulate it manually as this is + * requested by the researchers + */ +template <> +struct RmulhOp : RmulhOpBase { + using RmulhOpBase::RmulhOpBase; + using RmulhOpBase::operator(); + constexpr static size_t SIMD_WIDTH = 16; + void operator()(const int8x16x2_t& src0, const int8x16x2_t& src1, + int8_t* dst) const { + auto vitem = operator()(src0, src1); + vst1q_s8(dst, vitem.val[0]); + vst1q_s8(dst + SIMD_WIDTH, vitem.val[1]); + } + int8x16x2_t operator()(const int8x16x2_t& src0, + const int8x16x2_t& src1) const { + int8x16_t val, var; + int8x8_t lol, hil, lor, hir; + int16x8_t mu0, mu1; + + val = src0.val[0]; + var = src1.val[0]; + lol = vget_low_s8(val); + hil = vget_high_s8(val); + lor = 
vget_low_s8(var); + hir = vget_high_s8(var); + + mu0 = vmull_s8(lol, lor); + lol = vqrshrn_n_s16(mu0, 7); + mu1 = vmull_s8(hil, hir); + hil = vqrshrn_n_s16(mu1, 7); + + int8x16_t val1, var1; + int8x8_t lol1, hil1, lor1, hir1; + int16x8_t mu01, mu11; + + val1 = src0.val[1]; + var1 = src1.val[1]; + lol1 = vget_low_s8(val1); + hil1 = vget_high_s8(val1); + lor1 = vget_low_s8(var1); + hir1 = vget_high_s8(var1); + + mu01 = vmull_s8(lol1, lor1); + lol1 = vqrshrn_n_s16(mu01, 7); + mu11 = vmull_s8(hil1, hir1); + hil1 = vqrshrn_n_s16(mu11, 7); + + return {{vcombine_s8(lol, hil), vcombine_s8(lol1, hil1)}}; + } + void operator()(const int8x16_t& src0, const int8x16_t& src1, + int8_t* dst) const { + auto vitem = operator()(src0, src1); + vst1q_s8(dst, vitem); + } + int8x16_t operator()(const int8x16_t& src0, const int8x16_t& src1) const { + int8x16_t val, var; + int8x8_t lol, hil, lor, hir; + int16x8_t mu0, mu1; + + val = src0; + var = src1; + lol = vget_low_s8(val); + hil = vget_high_s8(val); + lor = vget_low_s8(var); + hir = vget_high_s8(var); + + mu0 = vmull_s8(lol, lor); + lol = vqrshrn_n_s16(mu0, 7); + mu1 = vmull_s8(hil, hir); + hil = vqrshrn_n_s16(mu1, 7); + + return vcombine_s8(lol, hil); + } +}; + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/elemwise_helper/kimpl/sigmoid.h b/dnn/src/arm_common/elemwise_helper/kimpl/sigmoid.h new file mode 100644 index 00000000..60be900b --- /dev/null +++ b/dnn/src/arm_common/elemwise_helper/kimpl/sigmoid.h @@ -0,0 +1,70 @@ +/** + * \file dnn/src/arm_common/elemwise_helper/kimpl/sigmoid.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +#pragma once + +#include "src/arm_common/elemwise_helper/kimpl/op_base.h" + +namespace megdnn { +namespace arm_common { + +template +struct SigmoidOpBase : UnaryOpBase { + using UnaryOpBase::UnaryOpBase; + void operator()(const src_ctype& src, dst_ctype* dst) const { + *dst = operator()(src); + } + dst_ctype operator()(const src_ctype& src) const { + float tmpf = src; + tmpf = exp(-tmpf); + tmpf = 1.f / (1.f + tmpf); + return tmpf; + } +}; + +template +struct SigmoidOp; + +#define OP(_ctype, _neon_type, _neon_type2, _func_suffix, _simd_width) \ + template <> \ + struct SigmoidOp<_ctype> : SigmoidOpBase<_ctype> { \ + using SigmoidOpBase::SigmoidOpBase; \ + using SigmoidOpBase::operator(); \ + constexpr static size_t SIMD_WIDTH = _simd_width; \ + void operator()(const _neon_type2& src, _ctype* dst) const { \ + auto vitem = operator()(src); \ + vst1q_##_func_suffix(dst, vitem.val[0]); \ + vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ + } \ + _neon_type2 operator()(const _neon_type2& src) const { \ + return {{operator()(src.val[0]), operator()(src.val[1])}}; \ + } \ + _neon_type operator()(const _neon_type& src) const { \ + auto zero_val = vdupq_n_##_func_suffix(0.f); \ + auto one_val = vdupq_n_##_func_suffix(1.f); \ + auto val1 = vsubq_##_func_suffix(zero_val, src); \ + val1 = exp_ps_##_func_suffix(val1); \ + auto recipe1 = vaddq_##_func_suffix(one_val, val1); \ + val1 = vrecpeq_##_func_suffix(recipe1); \ + val1 = vmulq_##_func_suffix(vrecpsq_##_func_suffix(recipe1, val1), \ + val1); \ + return val1; \ + } \ + }; +OP(dt_float32, float32x4_t, float32x4x2_t, f32, 4) +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +OP(__fp16, float16x8_t, float16x8x2_t, f16, 8) +#endif +#undef OP + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/elemwise_helper/kimpl/sub.h b/dnn/src/arm_common/elemwise_helper/kimpl/sub.h new file mode 100644 index 00000000..73a13227 --- /dev/null +++ b/dnn/src/arm_common/elemwise_helper/kimpl/sub.h @@ -0,0 +1,156 @@ +/** + * \file dnn/src/arm_common/elemwise_helper/kimpl/sub.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +#pragma once + +#include "src/arm_common/elemwise_helper/kimpl/op_base.h" + +namespace megdnn { +namespace arm_common { + +template +struct SubOpBase : BinaryOpBase { + using BinaryOpBase::BinaryOpBase; + void operator()(const src_ctype& src0, const src_ctype& src1, + dst_ctype* dst) const { + *dst = operator()(src0, src1); + } + dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { + return src0 - src1; + } +}; + +template +struct SubOp; + +#define OP(_ctype, _neon_type, _neon_type2, _func_suffix, _simd_width) \ + template <> \ + struct SubOp<_ctype> : SubOpBase<_ctype> { \ + using SubOpBase::SubOpBase; \ + using SubOpBase::operator(); \ + constexpr static size_t SIMD_WIDTH = _simd_width; \ + void operator()(const _neon_type2& src0, const _neon_type2& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + vst1q_##_func_suffix(dst, vitem.val[0]); \ + vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ + } \ + _neon_type2 operator()(const _neon_type2& src0, \ + const _neon_type2& src1) const { \ + auto vitem0 = vsubq_##_func_suffix(src0.val[0], src1.val[0]); \ + auto vitem1 = vsubq_##_func_suffix(src0.val[1], src1.val[1]); \ + return {{vitem0, vitem1}}; \ + } \ + void operator()(const _neon_type& src0, const _neon_type& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + vst1q_##_func_suffix(dst, vitem); \ + } \ + _neon_type operator()(const _neon_type& src0, \ + const _neon_type& src1) const { \ + return vsubq_##_func_suffix(src0, src1); \ + } \ + }; +OP(dt_float32, float32x4_t, float32x4x2_t, f32, 4) +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +OP(__fp16, float16x8_t, float16x8x2_t, f16, 8) +#endif +OP(dt_int32, int32x4_t, int32x4x2_t, s32, 4) +OP(dt_int16, int16x8_t, int16x8x2_t, s16, 8) +OP(dt_int8, int8x16_t, int8x16x2_t, s8, 16) +#undef OP + +template <> +struct SubOpBase : BinaryOpBase { + using BinaryOpBase::BinaryOpBase; + + void operator()(const dt_qint8& src0, const dt_qint8& src1, + dt_qint8* dst) const { + *dst = operator()(src0, src1); + } + dt_qint8 operator()(const dt_qint8& src0, const dt_qint8& src1) const { + return QConverter::convert(src0.as_int8() * scale0 - + src1.as_int8() * scale1); + } +}; + +template <> +struct SubOpBase : BinaryOpBase { + using BinaryOpBase::BinaryOpBase; + + void operator()(const dt_quint8& src0, const dt_quint8& src1, + dt_quint8* dst) const { + *dst = operator()(src0, src1); + } + dt_quint8 operator()(const dt_quint8& src0, const dt_quint8& src1) const { + float fsrc0 = src0.as_uint8() * scale0 - this->szp0; + float fsrc1 = src1.as_uint8() * scale1 - this->szp1; + return QConverter::convert(fsrc0 - fsrc1, + this->dzp); + } +}; + +template <> +struct SubOp : SubOpBase { + using SubOpBase::SubOpBase; + constexpr static size_t SIMD_WIDTH = 16; + using SubOpBase::operator(); + + void operator()(const int8x16x2_t& vsrc0, const int8x16x2_t& vsrc1, + dt_qint8* dst) const { + OPERATOR_BINARY_QINT8; + } + int8x8_t operator()(const int32x4x2_t& vsrc0, + const int32x4x2_t& vsrc1) const { + auto vitem0 = vsubq_f32( + vmulq_f32(vcvtq_f32_s32(vsrc0.val[0]), this->vscale0), + vmulq_f32(vcvtq_f32_s32(vsrc1.val[0]), this->vscale1)); + auto vitem1 = vsubq_f32( + vmulq_f32(vcvtq_f32_s32(vsrc0.val[1]), this->vscale0), + vmulq_f32(vcvtq_f32_s32(vsrc1.val[1]), this->vscale1)); + return QConverter::convert({{vitem0, vitem1}}); + } +}; + +template <> +struct SubOp : SubOpBase { + using SubOpBase::SubOpBase; + constexpr static size_t SIMD_WIDTH = 16; + using SubOpBase::operator(); + + void 
operator()(const uint8x16x2_t& vsrc0, const uint8x16x2_t& vsrc1, + dt_quint8* dst) const { + OPERATOR_BINARY_QUINT8; + } + uint8x8_t operator()(const uint32x4x2_t& vsrc0, + const uint32x4x2_t& vsrc1) const { + auto vfsrc0 = + vsubq_f32(vmulq_f32(vcvtq_f32_u32(vsrc0.val[0]), this->vscale0), + this->vszp0); + auto vfsrc1 = + vsubq_f32(vmulq_f32(vcvtq_f32_u32(vsrc1.val[0]), this->vscale1), + this->vszp1); + auto vitem0 = vsubq_f32(vfsrc0, vfsrc1); + vfsrc0 = + vsubq_f32(vmulq_f32(vcvtq_f32_u32(vsrc0.val[1]), this->vscale0), + this->vszp0); + vfsrc1 = + vsubq_f32(vmulq_f32(vcvtq_f32_u32(vsrc1.val[1]), this->vscale1), + this->vszp1); + auto vitem1 = vsubq_f32(vfsrc0, vfsrc1); + return QConverter::convert( + {{vitem0, vitem1}}, this->vdzp); + } +}; + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/elemwise_helper/kimpl/tanh.h b/dnn/src/arm_common/elemwise_helper/kimpl/tanh.h new file mode 100644 index 00000000..dc8b73c2 --- /dev/null +++ b/dnn/src/arm_common/elemwise_helper/kimpl/tanh.h @@ -0,0 +1,77 @@ +/** + * \file dnn/src/arm_common/elemwise_helper/kimpl/tanh.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once + +#include "src/arm_common/elemwise_helper/kimpl/op_base.h" + +namespace megdnn { +namespace arm_common { + +template +struct TanhOpBase : UnaryOpBase { + using UnaryOpBase::UnaryOpBase; + void operator()(const src_ctype& src, dst_ctype* dst) const { + *dst = operator()(src); + } + dst_ctype operator()(const src_ctype& src) const { + float tmp = src; + return tanh(tmp); + } +}; + +template +struct TanhOp; + +#define OP(_ctype, _neon_type, _func_suffix, _simd_width) \ + template <> \ + struct TanhOp<_ctype> : TanhOpBase<_ctype> { \ + using TanhOpBase::TanhOpBase; \ + using TanhOpBase::operator(); \ + constexpr static size_t SIMD_WIDTH = _simd_width; \ + void operator()(const _neon_type& src, _ctype* dst) const { \ + auto vitem = operator()(src); \ + vst1q_##_func_suffix(dst, vitem.val[0]); \ + vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ + } \ + _neon_type operator()(const _neon_type& src) const { \ + auto one_val = vdupq_n_##_func_suffix(1.f); \ + auto two_val = vdupq_n_##_func_suffix(2.f); \ + auto val1 = src.val[0]; \ + auto val2 = src.val[1]; \ + val1 = vmulq_##_func_suffix(two_val, val1); \ + val2 = vmulq_##_func_suffix(two_val, val2); \ + val1 = exp_ps_##_func_suffix(val1); \ + val2 = exp_ps_##_func_suffix(val2); \ + val1 = vaddq_##_func_suffix(one_val, val1); \ + val2 = vaddq_##_func_suffix(one_val, val2); \ + auto rval1 = vrecpeq_##_func_suffix(val1); \ + auto rval2 = vrecpeq_##_func_suffix(val2); \ + rval1 = vmulq_##_func_suffix(vrecpsq_##_func_suffix(val1, rval1), \ + rval1); \ + rval2 = vmulq_##_func_suffix(vrecpsq_##_func_suffix(val2, rval2), \ + rval2); \ + val1 = vmulq_##_func_suffix(two_val, rval1); \ + val2 = vmulq_##_func_suffix(two_val, rval2); \ + val1 = vsubq_##_func_suffix(one_val, val1); \ + val2 = vsubq_##_func_suffix(one_val, val2); \ + return {{val1, val2}}; \ + } \ + }; +OP(dt_float32, float32x4x2_t, f32, 4) +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +OP(__fp16, float16x8x2_t, f16, 8) +#endif +#undef OP + +} // namespace arm_common +} // namespace megdnn + 
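+// Note: the vectorized TanhOp above evaluates the identity
+// tanh(x) = 1 - 2 / (1 + exp(2 * x)); the reciprocal is computed with
+// vrecpeq followed by a single vrecpsq Newton-Raphson refinement step, so
+// the result is an approximation rather than a correctly rounded tanh.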
+// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/elemwise_helper/kimpl/true_div.h b/dnn/src/arm_common/elemwise_helper/kimpl/true_div.h new file mode 100644 index 00000000..d9341482 --- /dev/null +++ b/dnn/src/arm_common/elemwise_helper/kimpl/true_div.h @@ -0,0 +1,88 @@ +/** + * \file dnn/src/arm_common/elemwise_helper/kimpl/true_div.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once + +#include "src/arm_common/elemwise_helper/kimpl/op_base.h" + +namespace megdnn { +namespace arm_common { + +//! use a couple Newton-Raphson steps to refine the estimate. +//! A / B => 1. rB = vrecpeq_f32(B) 2. rB= vmulq_f32(vrecpsq_f32(B, rB), rB) +//! 3. A * rB +template +struct TrueDivOpBase : BinaryOpBase { + using BinaryOpBase::BinaryOpBase; + void operator()(const src_ctype& src0, const src_ctype& src1, + dst_ctype* dst) const { + *dst = operator()(src0, src1); + } + dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { + return src0 / src1; + } +}; + +#if MEGDNN_AARCH64 +template +struct TrueDivOp; + +#define OP(_ctype, _neon_type, _neon_type2, _func_suffix, _simd_width) \ + template <> \ + struct TrueDivOp<_ctype> : TrueDivOpBase<_ctype> { \ + using TrueDivOpBase::TrueDivOpBase; \ + using TrueDivOpBase::operator(); \ + constexpr static size_t SIMD_WIDTH = _simd_width; \ + void operator()(const _neon_type2& src0, const _neon_type2& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + vst1q_##_func_suffix(dst, vitem.val[0]); \ + vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ + } \ + _neon_type2 operator()(const _neon_type2& src0, \ + const _neon_type2& src1) const { \ + auto val1 = src0.val[0]; \ + auto val2 = src0.val[1]; \ + auto val3 = src1.val[0]; \ + auto val4 = src1.val[1]; \ + val1 = vdivq_##_func_suffix(val1, val3); \ + val2 = vdivq_##_func_suffix(val2, val4); \ + return {{val1, val2}}; \ + } \ + void operator()(const _neon_type& src0, const _neon_type& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + vst1q_##_func_suffix(dst, vitem); \ + } \ + _neon_type operator()(const _neon_type& src0, \ + const _neon_type& src1) const { \ + return vdivq_##_func_suffix(src0, src1); \ + } \ + }; +OP(dt_float32, float32x4_t, float32x4x2_t, f32, 4) +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +OP(__fp16, float16x8_t, float16x8x2_t, f16, 8) +#endif +#undef OP + +#else + +template +struct TrueDivOp : TrueDivOpBase { + using TrueDivOpBase::TrueDivOpBase; + using TrueDivOpBase::operator(); +}; + +#endif + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/elemwise_helper/kimpl/typecvt.h b/dnn/src/arm_common/elemwise_helper/kimpl/typecvt.h new file mode 100644 index 00000000..3104b0ee --- /dev/null +++ b/dnn/src/arm_common/elemwise_helper/kimpl/typecvt.h @@ -0,0 +1,130 @@ +/** + * \file dnn/src/arm_common/elemwise_helper/kimpl/typecvt.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#pragma once + +#include "src/arm_common/elemwise_helper/kimpl/op_base.h" + +namespace megdnn { +namespace arm_common { + +template +struct TypeCvtOp; + +#if __ARM_ARCH >= 8 +template <> +struct TypeCvtOp : UnaryOpBase { + using UnaryOpBase::UnaryOpBase; + constexpr static size_t SIMD_WIDTH = 4; + + void operator()(const int32x4x2_t& vsrc, dt_qint8* dst) const { + vst1_s8(reinterpret_cast(dst), operator()(vsrc)); + } + void operator()(const int32x4_t& vsrc, dt_qint8* dst) const { + vst1_lane_s32(reinterpret_cast(dst), + (int32x2_t)(operator()(vsrc)), 0); + } + void operator()(const src_ctype& src, dst_ctype* dst) const { + *dst = operator()(src); + } + dt_qint8 operator()(const dt_qint32& src) const { + float fsrc = src.as_int32() * this->scale; + return QConverter::convert(fsrc); + } + + int8x8_t operator()(const int32x4x2_t& vsrc) const { + auto vitem0 = vmulq_f32(vcvtq_f32_s32(vsrc.val[0]), this->vscale); + auto vitem1 = vmulq_f32(vcvtq_f32_s32(vsrc.val[1]), this->vscale); + + return QConverter::convert({{vitem0, vitem1}}); + } + int8x8_t operator()(const int32x4_t& src) const { + auto vitem0 = vmulq_f32(vcvtq_f32_s32(src), this->vscale); + return QConverter::convert(vitem0); + } +}; +#else +template <> +struct TypeCvtOp : UnaryOpBase, + FixupBase { + constexpr static size_t SIMD_WIDTH = 4; + + TypeCvtOp(DType src_dtype, DType dst_dtype) + : UnaryOpBase(src_dtype, dst_dtype), FixupBase(scale) {} + + TypeCvtOp(float src_scale, float dst_scale) + : UnaryOpBase(src_scale, dst_scale), FixupBase(scale) {} + + void operator()(const int32x4x2_t& vsrc, dt_qint8* dst) const { + vst1_s8(reinterpret_cast(dst), operator()(vsrc)); + } + void operator()(const int32x4_t& vsrc, dt_qint8* dst) const { + vst1_lane_s32(reinterpret_cast(dst), + (int32x2_t)(operator()(vsrc)), 0); + } + dt_qint8 operator()(const dt_qint32& src) const { + float fsrc = src.as_int32() * this->scale; + return QConverter::convert(fsrc); + } + void operator()(const src_ctype& src, dst_ctype* dst) const { + *dst = operator()(src); + } + int8x8_t operator()(const int32x4x2_t& vsrc) const { + int32x4_t vitem0 = vqrdmulhq_s32(vsrc.val[0], vmultiplier); + int32x4_t vitem1 = vqrdmulhq_s32(vsrc.val[1], vmultiplier); + auto fixup0 = vshrq_n_s32(vitem0, 31); + auto fixup1 = vshrq_n_s32(vitem1, 31); + // FIXME Theoretically, we should check shift != 0 here. 
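+        // The fixup adds the sign bit (vshrq_n_s32(x, 31) is -1 for negative
+        // lanes, 0 otherwise) before the rounding shift: vrshlq_s32 rounds
+        // halfway cases toward +inf, and subtracting 1 from negative lanes
+        // makes it behave like round-half-away-from-zero, approximating the
+        // scalar QConverter path. If shift were 0 the extra -1 would not be
+        // absorbed by the shift, which appears to be what the FIXME above
+        // refers to.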
+ vitem0 = vqaddq_s32(vitem0, fixup0); + vitem1 = vqaddq_s32(vitem1, fixup1); + return vqmovn_s16(vcombine_s16(vqmovn_s32(vrshlq_s32(vitem0, vshift)), + vqmovn_s32(vrshlq_s32(vitem1, vshift)))); + } + int8x8_t operator()(const int32x4_t& src) const { + int32x4_t vitem0 = vqrdmulhq_s32(src, vmultiplier); + auto fixup0 = vshrq_n_s32(vitem0, 31); + vitem0 = vqaddq_s32(vitem0, fixup0); + int16x4_t vres0_int16 = vqmovn_s32(vrshlq_s32(vitem0, vshift)); + return vqmovn_s16(vcombine_s16(vres0_int16, vres0_int16)); + } +}; +#endif + +template <> +struct TypeCvtOp : UnaryOpBase { + using UnaryOpBase::UnaryOpBase; + constexpr static size_t SIMD_WIDTH = 4; + + void operator()(const int32x4x2_t& vsrc, dt_quint8* dst) const { + vst1_u8(reinterpret_cast(dst), operator()(vsrc)); + } + + void operator()(const src_ctype& src, dst_ctype* dst) const { + *dst = operator()(src); + } + + dt_quint8 operator()(const src_ctype& src) const { + return QConverter::convert( + src.as_int32() * this->scale, this->zp); + } + uint8x8_t operator()(const int32x4x2_t& vsrc) const { + auto vitem0 = vmulq_f32(vcvtq_f32_s32(vsrc.val[0]), this->vscale); + auto vitem1 = vmulq_f32(vcvtq_f32_s32(vsrc.val[1]), this->vscale); + + return QConverter::convert({{vitem0, vitem1}}, + this->vzp); + } +}; + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/elemwise_helper/op_binary.h b/dnn/src/arm_common/elemwise_helper/op_binary.h new file mode 100644 index 00000000..2d1ada02 --- /dev/null +++ b/dnn/src/arm_common/elemwise_helper/op_binary.h @@ -0,0 +1,52 @@ +/** + * \file dnn/src/arm_common/elemwise_helper/op_binary.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +#pragma once + +#include "src/arm_common/elemwise_helper/kimpl/add.h" +#include "src/arm_common/elemwise_helper/kimpl/mul.h" +#include "src/arm_common/elemwise_helper/kimpl/rmulh.h" +#include "src/arm_common/elemwise_helper/kimpl/fuse_add_relu.h" +#include "src/arm_common/elemwise_helper/kimpl/fuse_add_sigmoid.h" +#include "src/arm_common/elemwise_helper/kimpl/fuse_add_tanh.h" +#include "src/arm_common/elemwise_helper/kimpl/fuse_add_h_swish.h" +#include "src/arm_common/elemwise_helper/kimpl/max.h" +#include "src/arm_common/elemwise_helper/kimpl/min.h" +#include "src/arm_common/elemwise_helper/kimpl/pow.h" +#include "src/arm_common/elemwise_helper/kimpl/sub.h" +#include "src/arm_common/elemwise_helper/kimpl/true_div.h" + +//////////////////// quantization ////////////////////////////// +namespace megdnn { +namespace arm_common { +#define cb(op) \ + template <> \ + struct op \ + : BinaryQuantizationOp > { \ + using BinaryQuantizationOp >::BinaryQuantizationOp; \ + }; \ + template <> \ + struct op \ + : BinaryQuantizationOp > { \ + using BinaryQuantizationOp >::BinaryQuantizationOp; \ + }; + +cb(TrueDivOp); +cb(FuseAddSigmoidOp); +cb(FuseAddTanhOp); +cb(FuseAddHSwishOp); + +#undef cb +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/elemwise_helper/op_ternary.h b/dnn/src/arm_common/elemwise_helper/op_ternary.h new file mode 100644 index 00000000..26ea4fcc --- /dev/null +++ b/dnn/src/arm_common/elemwise_helper/op_ternary.h @@ -0,0 +1,37 @@ +/** + * \file dnn/src/arm_common/elemwise_helper/op_ternary.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once + +#include "src/arm_common/elemwise_helper/kimpl/fuse_mul_add3.h" + +//////////////////// quantization ////////////////////////////// +namespace megdnn { +namespace arm_common { +#define cb(op) \ + template <> \ + struct op \ + : TernaryQuantizationOp > { \ + using TernaryQuantizationOp >::TernaryQuantizationOp; \ + }; \ + template <> \ + struct op \ + : TernaryQuantizationOp > { \ + using TernaryQuantizationOp >::TernaryQuantizationOp; \ + }; + +cb(FuseMulAdd3Op); +#undef cb +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/elemwise_helper/op_unary.h b/dnn/src/arm_common/elemwise_helper/op_unary.h new file mode 100644 index 00000000..ce901224 --- /dev/null +++ b/dnn/src/arm_common/elemwise_helper/op_unary.h @@ -0,0 +1,50 @@ +/** + * \file dnn/src/arm_common/elemwise_helper/op_unary.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */
+#pragma once
+
+#include "src/arm_common/elemwise_helper/kimpl/none.h"
+#include "src/arm_common/elemwise_helper/kimpl/abs.h"
+#include "src/arm_common/elemwise_helper/kimpl/exp.h"
+#include "src/arm_common/elemwise_helper/kimpl/fast_tanh.h"
+#include "src/arm_common/elemwise_helper/kimpl/hswish.h"
+#include "src/arm_common/elemwise_helper/kimpl/relu.h"
+#include "src/arm_common/elemwise_helper/kimpl/sigmoid.h"
+#include "src/arm_common/elemwise_helper/kimpl/tanh.h"
+#include "src/arm_common/elemwise_helper/kimpl/typecvt.h"
+
+//////////////////// quantization //////////////////////////////
+namespace megdnn {
+namespace arm_common {
+#define cb(op)                                                            \
+    template <>                                                           \
+    struct op<dt_qint8, dt_qint8>                                         \
+            : UnaryQuantizationOp<dt_qint8, dt_qint8, op<float> > {       \
+        using UnaryQuantizationOp<dt_qint8, dt_qint8,                     \
+                                  op<float> >::UnaryQuantizationOp;       \
+    };                                                                    \
+    template <>                                                           \
+    struct op<dt_quint8, dt_quint8>                                       \
+            : UnaryQuantizationOp<dt_quint8, dt_quint8, op<float> > {     \
+        using UnaryQuantizationOp<dt_quint8, dt_quint8,                   \
+                                  op<float> >::UnaryQuantizationOp;       \
+    };
+
+cb(SigmoidOp);
+cb(ExpOp);
+cb(TanhOp);
+cb(FastTanhOp);
+cb(HSwishOp);
+#undef cb
+} // namespace arm_common
+} // namespace megdnn
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/arm_common/elemwise_multi_type/opr_impl.cpp b/dnn/src/arm_common/elemwise_multi_type/opr_impl.cpp
new file mode 100644
index 00000000..bcebe699
--- /dev/null
+++ b/dnn/src/arm_common/elemwise_multi_type/opr_impl.cpp
@@ -0,0 +1,825 @@
+/**
+ * \file dnn/src/arm_common/elemwise_multi_type/opr_impl.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */ + +#include "./opr_impl.h" +#include "src/common/elemwise_multi_type/kern_defs.cuh" +#include "src/naive/handle.h" + +#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/simd_macro/marm_neon.h" + +namespace { + +using namespace megdnn; + +template +void neon_round_shr_saturate_int16_static_k(const int16_t* a_ptr, size_t size, + int8_t* dst_ptr) { + static_assert(k >= 1 && k <= 8, "Shift offset out of range"); + size_t i = 0; + int16x8_t x0, x1, f0, f1; + for (; i + 15 < size; i += 16, a_ptr += 16, dst_ptr += 16) { + x0 = vld1q_s16(a_ptr); + x1 = vld1q_s16(a_ptr + 8); + f0 = vshrq_n_s16(x0, 15); + f1 = vshrq_n_s16(x1, 15); + x0 = vqaddq_s16(x0, f0); + x1 = vqaddq_s16(x1, f1); + vst1_s8(dst_ptr, vqrshrn_n_s16(x0, k)); + vst1_s8(dst_ptr + 8, vqrshrn_n_s16(x1, k)); + } + for (; i < size; i++, a_ptr++, dst_ptr++) { + *dst_ptr = megdnn::elemwise_multi_type::round_shr_saturate( + *a_ptr, k); + } +} + +} // namespace + +namespace megdnn { +namespace arm_common { + +template +void ElemwiseMultiTypeImpl::neon_round_shr_saturate_bcast_scalar( + const stype* a_ptr, int8_t k, size_t size, dt_int8* dst_ptr) { + MEGDNN_MARK_USED_VAR(a_ptr); + MEGDNN_MARK_USED_VAR(k); + MEGDNN_MARK_USED_VAR(size); + MEGDNN_MARK_USED_VAR(dst_ptr); + megdnn_throw( + "ElemwiseMultiType (mode=ROUND_SHR_SATURATE) only supports int8, " + "int16 and int32 on ARM"); +} + +template <> +void ElemwiseMultiTypeImpl::neon_round_shr_saturate_bcast_scalar( + const int8_t* a_ptr, int8_t k, size_t size, dt_int8* dst_ptr) { + size_t i = 0; + const int8x16_t shift_vec = vdupq_n_s8(-k); + int8x16_t x0, x1, f0, f1; + for (; i + 31 < size; i += 32, a_ptr += 32, dst_ptr += 32) { + x0 = vld1q_s8(a_ptr); + x1 = vld1q_s8(a_ptr + 16); + f0 = vshrq_n_s8(x0, 7); + f1 = vshrq_n_s8(x1, 7); + x0 = vqaddq_s8(x0, f0); + x1 = vqaddq_s8(x1, f1); + vst1q_s8(dst_ptr, vrshlq_s8(x0, shift_vec)); + vst1q_s8(dst_ptr + 16, vrshlq_s8(x1, shift_vec)); + } + for (; i < size; i++, a_ptr++, dst_ptr++) { + *dst_ptr = elemwise_multi_type::round_shr_saturate( + *a_ptr, k); + } +} + +template <> +void ElemwiseMultiTypeImpl::neon_round_shr_saturate_bcast_scalar( + const int16_t* a_ptr, int8_t k, size_t size, dt_int8* dst_ptr) { + // vqrshrn_n_s16 is significantly faster than vrshlq_s16 + vqmovn_s16, but + // it requires that shift offset is known at compile time. 
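+    // The switch below maps the runtime shift k (1..8) onto the template
+    // parameter of neon_round_shr_saturate_int16_static_k, where the shift
+    // becomes an immediate operand of vqrshrn_n_s16 (e.g. k == 3 dispatches
+    // to the <3> instantiation); any other k falls through to the generic
+    // vrshlq_s16 + vqmovn_s16 loop after the switch.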
+ switch (k) { +#define DISPATCH(i) \ + case i: \ + neon_round_shr_saturate_int16_static_k(a_ptr, size, dst_ptr); \ + return; + DISPATCH(1) + DISPATCH(2) + DISPATCH(3) + DISPATCH(4) + DISPATCH(5) + DISPATCH(6) + DISPATCH(7) + DISPATCH(8) +#undef DISPATCH + default: + break; + } + + size_t i = 0; + const int16x8_t shift_vec = vdupq_n_s16(-k); + int16x8_t x0, x1, f0, f1; + for (; i + 15 < size; i += 16, a_ptr += 16, dst_ptr += 16) { + x0 = vld1q_s16(a_ptr); + x1 = vld1q_s16(a_ptr + 8); + f0 = vshrq_n_s16(x0, 15); + f1 = vshrq_n_s16(x1, 15); + x0 = vqaddq_s16(x0, f0); + x1 = vqaddq_s16(x1, f1); + vst1_s8(dst_ptr, vqmovn_s16(vrshlq_s16(x0, shift_vec))); + vst1_s8(dst_ptr + 8, vqmovn_s16(vrshlq_s16(x1, shift_vec))); + } + for (; i < size; i++, a_ptr++, dst_ptr++) { + *dst_ptr = elemwise_multi_type::round_shr_saturate( + *a_ptr, k); + } +} + +template <> +void ElemwiseMultiTypeImpl::neon_round_shr_saturate_bcast_scalar( + const int32_t* a_ptr, int8_t k, size_t size, dt_int8* dst_ptr) { + size_t i = 0; + const int32x4_t shift_vec = vdupq_n_s32(-k); + int32x4_t x0, x1, f0, f1; + int8x8_t o0; + for (; i + 7 < size; i += 8, a_ptr += 8, dst_ptr += 8) { + x0 = vld1q_s32(a_ptr); + x1 = vld1q_s32(a_ptr + 4); + f0 = vshrq_n_s32(x0, 31); + f1 = vshrq_n_s32(x1, 31); + x0 = vqaddq_s32(x0, f0); + x1 = vqaddq_s32(x1, f1); + o0 = vqmovn_s16(vcombine_s16(vqmovn_s32(vrshlq_s32(x0, shift_vec)), + vqmovn_s32(vrshlq_s32(x1, shift_vec)))); + vst1_s8(dst_ptr, o0); + } + for (; i < size; i++, a_ptr++, dst_ptr++) { + *dst_ptr = elemwise_multi_type::round_shr_saturate( + *a_ptr, k); + } +} + +template +void ElemwiseMultiTypeImpl::dispatch_round_shr_saturate_iXxi8xi8_bcast_scalar( + const ElemwiseOpParamN<2>& param, megdnn::dt_int8* dst) { + auto a_ptr = param[0].ptr(); + auto k = param[1].ptr()[0]; + size_t size = param.size; + + MEGDNN_DISPATCH_CPU_KERN_OPR( + neon_round_shr_saturate_bcast_scalar(a_ptr, k, size, dst)); +} + +void ElemwiseMultiTypeImpl::on_round_shr_saturate_iXxi8xi8( + const ElemwiseOpParamN<2>& param, megdnn::dt_int8* dst) { + if (is_vector(param[0].layout) && is_broadcasted_scalar(param[1].layout)) { + switch (param[0].layout.dtype.enumv()) { +#define cb(t) \ + case DTypeTrait::enumv: \ + return dispatch_round_shr_saturate_iXxi8xi8_bcast_scalar< \ + DTypeTrait::ctype>(param, dst); + MEGDNN_FOREACH_COMPUTING_DTYPE_INT(cb) +#undef cb + default: + megdnn_throw( + "ElemwiseMultiType (mode=ROUND_SHR_SATURATE) only " + "supports int8, int16 and int32 on ARM"); + } + } + + fallback::ElemwiseMultiTypeImpl::on_round_shr_saturate_iXxi8xi8(param, dst); +} + +void neon_fuse_add_rmulh_round_shr_saturate_bcast_1c11_int16( + size_t batch_size, size_t channel_size, size_t channel_stride, + const int16_t* x_ptr, const int16_t* b_ptr, const int16_t M, + const int offset, const int8_t minv, const int8_t maxv, size_t size, + int8_t* dst_ptr) { + MEGDNN_MARK_USED_VAR(size); + const int16x8_t shift_vec = vdupq_n_s16(-offset); + const int16x8_t M_vec = vdupq_n_s16(M); + const int8x16_t minv_vec = vdupq_n_s8(minv); + const int8x16_t maxv_vec = vdupq_n_s8(maxv); + + size_t i = 0, b_pos = 0, channel_offset = 0; + for (size_t batch = 0; batch < batch_size; ++batch) { + b_pos = 0; + for (size_t chan = 0; chan < channel_size; ++chan, ++b_pos) { + auto b_vec = vdupq_n_s16(b_ptr[b_pos]); + channel_offset += channel_stride; + for (; i + 15 < channel_offset; + i += 16, x_ptr += 16, dst_ptr += 16) { + auto x0 = vld1q_s16(x_ptr); + auto x1 = vld1q_s16(x_ptr + 8); + x0 = vaddq_s16(x0, b_vec); + x1 = vaddq_s16(x1, b_vec); + x0 = 
vqrdmulhq_s16(x0, M_vec); + x1 = vqrdmulhq_s16(x1, M_vec); + // FIXME Theoretically, we should check shift != 0 here, + auto fixup0 = vshrq_n_s16(x0, 15); + auto fixup1 = vshrq_n_s16(x1, 15); + x0 = vqaddq_s16(x0, fixup0); + x1 = vqaddq_s16(x1, fixup1); + auto o0 = vcombine_s8(vqmovn_s16(vrshlq_s16(x0, shift_vec)), + vqmovn_s16(vrshlq_s16(x1, shift_vec))); + o0 = vminq_s8(o0, maxv_vec); + o0 = vmaxq_s8(o0, minv_vec); + vst1q_s8(dst_ptr, o0); + } + for (; i + 7 < channel_offset; i += 8, x_ptr += 8, dst_ptr += 8) { + auto x0 = vld1q_s16(x_ptr); + x0 = vaddq_s16(x0, b_vec); + x0 = vqrdmulhq_s16(x0, M_vec); + // FIXME Theoretically, we should check shift != 0 here, + auto fixup0 = vshrq_n_s16(x0, 15); + x0 = vqaddq_s16(x0, fixup0); + auto o0 = vqmovn_s16(vrshlq_s16(x0, shift_vec)); + o0 = vmin_s8(o0, vget_low_s8(maxv_vec)); + o0 = vmax_s8(o0, vget_low_s8(minv_vec)); + vst1_s8(dst_ptr, o0); + } + dt_int16 bias = b_ptr[b_pos]; + for (; i < channel_offset; ++i, ++x_ptr, ++dst_ptr) { + dt_int16 result = rounding_shift_right_away_from_zero( + round_mulh_saturate(*x_ptr + bias, M), + offset); + *dst_ptr = static_cast(std::max( + std::min(result, maxv), minv)); + } + } + } +} + +void neon_fuse_add_rmulh_round_shr_saturate_bcast_1c11_int32( + size_t batch_size, size_t channel_size, size_t channel_stride, + const int32_t* x_ptr, const int32_t* b_ptr, const int32_t M, + const int offset, const int8_t minv, const int8_t maxv, size_t size, + int8_t* dst_ptr) { + MEGDNN_MARK_USED_VAR(size); + const int32x4_t shift_vec = vdupq_n_s32(-offset); + const int32x4_t M_vec = vdupq_n_s32(M); + const int8x8_t minv_vec = vdup_n_s8(minv); + const int8x8_t maxv_vec = vdup_n_s8(maxv); + + size_t i = 0, b_pos = 0, channel_offset = 0; + for (size_t batch = 0; batch < batch_size; ++batch) { + b_pos = 0; + for (size_t chan = 0; chan < channel_size; ++chan, ++b_pos) { + int32x4_t b_vec = vdupq_n_s32(b_ptr[b_pos]); + channel_offset += channel_stride; + for (; i + 7 < channel_offset; i += 8, x_ptr += 8, dst_ptr += 8) { + auto x0 = vld1q_s32(x_ptr); + auto x1 = vld1q_s32(x_ptr + 4); + x0 = vaddq_s32(x0, b_vec); + x1 = vaddq_s32(x1, b_vec); + x0 = vqrdmulhq_s32(x0, M_vec); + x1 = vqrdmulhq_s32(x1, M_vec); + // FIXME Theoretically, we should check shift != 0 here, + auto fixup0 = vshrq_n_s32(x0, 31); + auto fixup1 = vshrq_n_s32(x1, 31); + x0 = vqaddq_s32(x0, fixup0); + x1 = vqaddq_s32(x1, fixup1); + auto o0 = vqmovn_s32(vrshlq_s32(x0, shift_vec)); + auto o1 = vqmovn_s32(vrshlq_s32(x1, shift_vec)); + auto of = vqmovn_s16(vcombine_s16(o0, o1)); + of = vmin_s8(of, maxv_vec); + of = vmax_s8(of, minv_vec); + vst1_s8(dst_ptr, of); + } + dt_int32 bias = b_ptr[b_pos]; + for (; i < channel_offset; ++i, ++x_ptr, ++dst_ptr) { + dt_int32 result = rounding_shift_right_away_from_zero( + round_mulh_saturate(*x_ptr + bias, M), + offset); + *dst_ptr = static_cast(std::max( + std::min(result, maxv), minv)); + } + } + } +} + +bool ElemwiseMultiTypeImpl::dispatch_fuse_add_rmulh_rshr( + const ElemwiseOpParamN<6>& param, megdnn::dt_int8* dst) { + BroadcastChannelInfo binfo; + if (is_vector(param[0].layout) && + is_broadcasted_channel_like(param[1].layout, binfo) && + is_broadcasted_scalar(param[2].layout) && + is_broadcasted_scalar(param[3].layout) && + is_broadcasted_scalar(param[4].layout) && + is_broadcasted_scalar(param[5].layout)) { + auto offset = param[3].ptr()[0]; + auto minv = param[4].ptr()[0]; + auto maxv = param[5].ptr()[0]; + switch (param[0].layout.dtype.enumv()) { +#define DISPATCH(stype, suffix) \ + case DTypeTrait::enumv: { \ + auto 
x_ptr = param[0].ptr::ctype>(); \ + auto b_ptr = param[1].ptr::ctype>(); \ + auto M = param[2].ptr::ctype>()[0]; \ + MEGDNN_DISPATCH_CPU_KERN_OPR( \ + neon_fuse_add_rmulh_round_shr_saturate_bcast_1c11_##suffix( \ + binfo.x, binfo.y, binfo.z, x_ptr, b_ptr, M, offset, \ + minv, maxv, param.size, dst)); \ + break; \ + } + DISPATCH(dtype::Int16, int16) + DISPATCH(dtype::Int32, int32) + default: + megdnn_throw("unreachable"); + } + return true; + } + return false; +#undef DISPATCH +} + +void ElemwiseMultiTypeImpl::on_fuse_add_rmulh_round_shr_saturate_int16x16x16x8( + const ElemwiseOpParamN<6>& param, megdnn::dt_int8* dst) { + if (dispatch_fuse_add_rmulh_rshr(param, dst)) + return; + fallback::ElemwiseMultiTypeImpl:: + on_fuse_add_rmulh_round_shr_saturate_int16x16x16x8(param, dst); +} + +void ElemwiseMultiTypeImpl::on_fuse_add_rmulh_round_shr_saturate_int32x32x32x8( + const ElemwiseOpParamN<6>& param, megdnn::dt_int8* dst) { + if (dispatch_fuse_add_rmulh_rshr(param, dst)) + return; + fallback::ElemwiseMultiTypeImpl:: + on_fuse_add_rmulh_round_shr_saturate_int32x32x32x8(param, dst); +} + +void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<1>& param, + const TensorND& dst, + Elemwise::Mode mode) { + megdnn_assert(param[0].layout.dtype.category() == DTypeCategory::QUANTIZED); + megdnn_assert(dst.layout.dtype.category() == DTypeCategory::QUANTIZED); + +#define DISPATCH_MODE(_src_dt, _dst_dt) \ + switch (mode) { \ + DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::RELU, ReluOp) \ + DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::H_SWISH, \ + HSwishOp) \ + default: \ + break; \ + } + +#define DISPATCH_QUANTIZED_MODE(_src_dt, _dst_dt) \ + switch (mode) { \ + DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::RELU, ReluOp) \ + DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::ABS, AbsOp) \ + DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::SIGMOID, \ + SigmoidOp) \ + DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::EXP, ExpOp) \ + DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::TANH, TanhOp) \ + DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::FAST_TANH, \ + FastTanhOp) \ + DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::H_SWISH, \ + HSwishOp) \ + default: \ + break; \ + } + +#define DISPATCH() \ + if (param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS8 && \ + dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { \ + DISPATCH_QUANTIZED_MODE(dtype::QuantizedS8, dtype::QuantizedS8) \ + } else if (param[0].layout.dtype.enumv() == DTypeEnum::Quantized8Asymm && \ + dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) { \ + DISPATCH_QUANTIZED_MODE(dtype::Quantized8Asymm, \ + dtype::Quantized8Asymm) \ + } else if (param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS32 && \ + dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { \ + DISPATCH_MODE(dtype::QuantizedS32, dtype::QuantizedS8) \ + } else if (param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS32 && \ + dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) { \ + DISPATCH_MODE(dtype::QuantizedS32, dtype::Quantized8Asymm) \ + } + + TensorND src = param[0]; + + size_t nr_elems = src.layout.total_nr_elems(); +#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ + case _mode: { \ + using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ + using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ + thin_function \ + run = OpCallerUnary<_op, VEC>::run; \ + MEGDNN_DISPATCH_CPU_KERN_OPR( \ + run(src.ptr(), dst.ptr(), \ + src.layout.dtype, dst.layout.dtype, nr_elems)); \ 
+ return; \ + } + + DISPATCH() + + fallback::ElemwiseMultiTypeImpl::on_quantized_mode(param, dst, mode); + +#undef DISPATCH_SINGLE_MODE +#undef DISPATCH +#undef DISPATCH_QUANTIZED_MODE +#undef DISPATCH_MODE +} + +void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<2>& param, + const TensorND& dst, + Elemwise::Mode mode) { + megdnn_assert(param[0].layout.dtype.enumv() == + param[1].layout.dtype.enumv() && + param[0].layout.dtype.category() == DTypeCategory::QUANTIZED); + megdnn_assert(dst.layout.dtype.category() == DTypeCategory::QUANTIZED); + +#define DISPATCH_MODE(_src_dt, _dst_dt) \ + switch (mode) { \ + DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::ADD, AddOp) \ + DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::FUSE_ADD_RELU, \ + FuseAddReluOp) \ + DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, \ + Elemwise::Mode::FUSE_ADD_H_SWISH, \ + FuseAddHSwishOp) \ + default: \ + break; \ + } + +#if MEGDNN_AARCH64 +#define DISPATCH_QUANTIZED_MODE(_src_dt, _dst_dt) \ + switch (mode) { \ + DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::ADD, AddOp) \ + DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::MIN, MinOp) \ + DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::MAX, MaxOp) \ + DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::SUB, SubOp) \ + DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::MUL, MulOp) \ + DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::TRUE_DIV, \ + TrueDivOp) \ + DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::FUSE_ADD_RELU, \ + FuseAddReluOp) \ + DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, \ + Elemwise::Mode::FUSE_ADD_SIGMOID, \ + FuseAddSigmoidOp) \ + DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::FUSE_ADD_TANH, \ + FuseAddTanhOp) \ + DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, \ + Elemwise::Mode::FUSE_ADD_H_SWISH, \ + FuseAddHSwishOp) \ + default: \ + break; \ + } +#else +#define DISPATCH_QUANTIZED_MODE(_src_dt, _dst_dt) \ + switch (mode) { \ + DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::ADD, AddOp) \ + DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::MIN, MinOp) \ + DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::MAX, MaxOp) \ + DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::SUB, SubOp) \ + DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::MUL, MulOp) \ + DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::FUSE_ADD_RELU, \ + FuseAddReluOp) \ + DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, \ + Elemwise::Mode::FUSE_ADD_SIGMOID, \ + FuseAddSigmoidOp) \ + DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::FUSE_ADD_TANH, \ + FuseAddTanhOp) \ + DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, \ + Elemwise::Mode::FUSE_ADD_H_SWISH, \ + FuseAddHSwishOp) \ + default: \ + break; \ + } +#endif + +#define DISPATCH() \ + if (param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS32 && \ + dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { \ + DISPATCH_MODE(dtype::QuantizedS32, dtype::QuantizedS8) \ + } else if (param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS32 && \ + dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) { \ + DISPATCH_MODE(dtype::QuantizedS32, dtype::Quantized8Asymm) \ + } else if (param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS8 && \ + dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { \ + DISPATCH_QUANTIZED_MODE(dtype::QuantizedS8, dtype::QuantizedS8) \ + } else if (param[0].layout.dtype.enumv() == DTypeEnum::Quantized8Asymm && \ + dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) { \ + DISPATCH_QUANTIZED_MODE(dtype::Quantized8Asymm, \ + 
dtype::Quantized8Asymm) \ + } + + TensorND src0 = param[0]; + TensorND src1 = param[1]; + + //! VEC + VEC + if (is_vector(src0.layout) && is_vector(src1.layout)) { + size_t nr_elems = src0.layout.total_nr_elems(); +#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ + case _mode: { \ + using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ + using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ + thin_function \ + run = OpCallerBinary<_op, VEC_VEC>::run; \ + MEGDNN_DISPATCH_CPU_KERN_OPR( \ + run(src0.ptr(), src1.ptr(), \ + dst.ptr(), src0.layout.dtype, \ + src1.layout.dtype, dst.layout.dtype, nr_elems)); \ + return; \ + } + + DISPATCH() + +#undef DISPATCH_SINGLE_MODE + } + + //! VEC + SCALAR + { + bool normal_case = + is_vector(src0.layout) && is_broadcasted_scalar(src1.layout); + bool swap_case = false; + bool commutable = false; + if (mode != Elemwise::Mode::SUB && mode != Elemwise::Mode::TRUE_DIV) + commutable = true; + if (!normal_case && commutable) { + swap_case = is_vector(src1.layout) && + is_broadcasted_scalar(src0.layout); + } + if (normal_case || swap_case) { + auto &lhs = src0, &rhs = src1; + if (swap_case) + std::swap(lhs, rhs); +#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ + case _mode: { \ + using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ + using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ + thin_function \ + run = OpCallerBinary<_op, \ + VEC_SCALAR>::run; \ + MEGDNN_DISPATCH_CPU_KERN_OPR(run( \ + src0.ptr(), src1.ptr()[0], \ + dst.ptr(), src0.layout.dtype, src1.layout.dtype, \ + dst.layout.dtype, src0.layout.total_nr_elems())); \ + return; \ + } + + DISPATCH() + +#undef DISPATCH_SINGLE_MODE + } + + //! SCALAR + VEC + if (!commutable && is_vector(src1.layout) && + is_broadcasted_scalar(src0.layout)) { +#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ + case _mode: { \ + using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ + using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ + thin_function \ + run = OpCallerBinary<_op, \ + SCALAR_VEC>::run; \ + MEGDNN_DISPATCH_CPU_KERN_OPR(run( \ + src0.ptr()[0], src1.ptr(), \ + dst.ptr(), src0.layout.dtype, src1.layout.dtype, \ + dst.layout.dtype, src1.layout.total_nr_elems())); \ + return; \ + } + + DISPATCH() + +#undef DISPATCH_SINGLE_MODE + } + } + + //! VEC + BCAST101 + { + BroadcastChannelInfo binfo; + bool normal_case = is_vector(src0.layout) && + is_broadcasted_channel_like(src1.layout, binfo); + bool swap_case = false; + bool commutable = false; + if (mode != Elemwise::Mode::SUB && mode != Elemwise::Mode::TRUE_DIV) + commutable = true; + if (!normal_case && commutable) { + swap_case = is_vector(src1.layout) && + is_broadcasted_channel_like(src0.layout, binfo); + } + if (normal_case || swap_case) { + auto &lhs = src0, &rhs = src1; + if (swap_case) + std::swap(lhs, rhs); +#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ + case _mode: { \ + using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ + using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ + thin_function \ + run = OpCallerBinary<_op, \ + VEC_BCAST101>::run; \ + MEGDNN_DISPATCH_CPU_KERN_OPR(run( \ + src0.ptr(), src1.ptr(), \ + dst.ptr(), src0.layout.dtype, src1.layout.dtype, \ + dst.layout.dtype, binfo.x, binfo.y, binfo.z)); \ + return; \ + } + + DISPATCH() + +#undef DISPATCH_SINGLE_MODE + } + + //! 
BCAST101 + VEC : only for SUB or TRUE_DIV + if (!commutable && is_vector(src1.layout) && + is_broadcasted_channel_like(src0.layout, binfo)) { +#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ + case _mode: { \ + using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ + using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ + thin_function \ + run = OpCallerBinary<_op, \ + BCAST101_VEC>::run; \ + MEGDNN_DISPATCH_CPU_KERN_OPR(run( \ + src0.ptr(), src1.ptr(), \ + dst.ptr(), src0.layout.dtype, src1.layout.dtype, \ + dst.layout.dtype, binfo.x, binfo.y, binfo.z)); \ + return; \ + } + + DISPATCH() + +#undef DISPATCH_SINGLE_MODE + } + } + + //! VEC + BCAST101x4 + { + BroadcastChannelInfo binfo; + if (is_vector(src0.layout) && + is_broadcastedx_channel_like<4>(src1.layout, binfo)) { +#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ + case _mode: { \ + using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ + using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ + thin_function \ + run = OpCallerBinary<_op, \ + VEC_BCAST101x4>::run; \ + MEGDNN_DISPATCH_CPU_KERN_OPR(run( \ + src0.ptr(), src1.ptr(), \ + dst.ptr(), src0.layout.dtype, src1.layout.dtype, \ + dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \ + return; \ + } + + size_t batch_size = + src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z); + DISPATCH() + +#undef DISPATCH_SINGLE_MODE + } + + //! BCAST101x + VEC + if (is_vector(src1.layout) && + is_broadcastedx_channel_like<4>(src0.layout, binfo)) { +#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ + case _mode: { \ + using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ + using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ + thin_function \ + run = OpCallerBinary<_op, \ + BCAST101x4_VEC>::run; \ + MEGDNN_DISPATCH_CPU_KERN_OPR(run( \ + src0.ptr(), src1.ptr(), \ + dst.ptr(), src0.layout.dtype, src1.layout.dtype, \ + dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \ + return; \ + } + + size_t batch_size = + src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z); + DISPATCH() + +#undef DISPATCH_SINGLE_MODE + } + } + + fallback::ElemwiseMultiTypeImpl::on_quantized_mode(param, dst, mode); + +#undef DISPATCH_MODE +#undef DISPATCH_QUANTIZED_MODE +#undef DISPATCH +} + +void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<3>& param, + const TensorND& dst, + Elemwise::Mode mode) { + megdnn_assert( + param[0].layout.dtype.enumv() == param[1].layout.dtype.enumv() && + param[0].layout.dtype.enumv() == param[2].layout.dtype.enumv() && + param[0].layout.dtype.category() == DTypeCategory::QUANTIZED); + megdnn_assert(dst.layout.dtype.category() == DTypeCategory::QUANTIZED); + +#define DISPATCH_QUANTIZED_MODE(_src_dt, _dst_dt) \ + switch (mode) { \ + DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::FUSE_MUL_ADD3, \ + FuseMulAdd3Op) \ + default: \ + break; \ + } + +#define DISPATCH() \ + if (param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS8 && \ + dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { \ + DISPATCH_QUANTIZED_MODE(dtype::QuantizedS8, dtype::QuantizedS8) \ + } else if (param[0].layout.dtype.enumv() == DTypeEnum::Quantized8Asymm && \ + dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) { \ + DISPATCH_QUANTIZED_MODE(dtype::Quantized8Asymm, \ + dtype::Quantized8Asymm) \ + } + + TensorND src0 = param[0]; + TensorND src1 = param[1]; + TensorND src2 = param[2]; + + //! 
VEC + VEC + VEC + if (is_vector(src0.layout) && is_vector(src1.layout) && + is_vector(src2.layout)) { + size_t nr_elems = src0.layout.total_nr_elems(); +#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ + case _mode: { \ + using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ + using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ + thin_function \ + run = OpCallerTernary<_op, \ + VEC_VEC_VEC>::run; \ + MEGDNN_DISPATCH_CPU_KERN_OPR( \ + run(src0.ptr(), src1.ptr(), \ + src2.ptr(), dst.ptr(), \ + src0.layout.dtype, src1.layout.dtype, src2.layout.dtype, \ + dst.layout.dtype, nr_elems)); \ + return; \ + } + + DISPATCH() + +#undef DISPATCH_SINGLE_MODE + } + + //! VEC + VEC + SCALAR + if (is_vector(src0.layout) && is_vector(src1.layout) && + is_broadcasted_scalar(src2.layout)) { +#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ + case _mode: { \ + using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ + using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ + thin_function \ + run = OpCallerTernary<_op, \ + VEC_VEC_SCALAR>::run; \ + MEGDNN_DISPATCH_CPU_KERN_OPR( \ + run(src0.ptr(), src1.ptr(), \ + src2.ptr()[0], dst.ptr(), \ + src0.layout.dtype, src1.layout.dtype, src2.layout.dtype, \ + dst.layout.dtype, src0.layout.total_nr_elems())); \ + return; \ + } + + DISPATCH() + +#undef DISPATCH_SINGLE_MODE + } + + //! BCAST101 + VEC + BCAST101 + { + BroadcastChannelInfo binfo; + bool normal_case = is_vector(src1.layout) && + is_broadcasted_channel_like(src0.layout, binfo) && + src0.layout.eq_shape(src2.layout); + if (normal_case) { +#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ + case _mode: { \ + using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ + using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ + thin_function \ + run = OpCallerTernary<_op, \ + BCAST101_VEC_BCAST101>::run; \ + MEGDNN_DISPATCH_CPU_KERN_OPR( \ + run(src0.ptr(), src1.ptr(), \ + src2.ptr(), dst.ptr(), \ + src0.layout.dtype, src1.layout.dtype, src2.layout.dtype, \ + dst.layout.dtype, binfo.x, binfo.y, binfo.z)); \ + return; \ + } + + DISPATCH() + +#undef DISPATCH_SINGLE_MODE + } + } + + fallback::ElemwiseMultiTypeImpl::on_quantized_mode(param, dst, mode); +#undef DISPATCH +#undef DISPATCH_QUANTIZED_MODE +} + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/elemwise_multi_type/opr_impl.h b/dnn/src/arm_common/elemwise_multi_type/opr_impl.h new file mode 100644 index 00000000..5cfdd25d --- /dev/null +++ b/dnn/src/arm_common/elemwise_multi_type/opr_impl.h @@ -0,0 +1,55 @@ +/** + * \file dnn/src/arm_common/elemwise_multi_type/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#pragma once + +#include "src/fallback/elemwise_multi_type/opr_impl.h" + +namespace megdnn { +namespace arm_common { + +class ElemwiseMultiTypeImpl : public fallback::ElemwiseMultiTypeImpl { + template + void neon_round_shr_saturate_bcast_scalar(const stype* a_ptr, int8_t k, + size_t size, dt_int8* dst_ptr); + + template + void dispatch_round_shr_saturate_iXxi8xi8_bcast_scalar( + const ElemwiseOpParamN<2>& param, megdnn::dt_int8* dst); + + bool dispatch_fuse_add_rmulh_rshr(const ElemwiseOpParamN<6>& param, + megdnn::dt_int8* dst); + +protected: + void on_round_shr_saturate_iXxi8xi8(const ElemwiseOpParamN<2>& param, + dt_int8* dst) override; + void on_fuse_add_rmulh_round_shr_saturate_int16x16x16x8( + const ElemwiseOpParamN<6>& param, dt_int8* dst) override; + void on_fuse_add_rmulh_round_shr_saturate_int32x32x32x8( + const ElemwiseOpParamN<6>& param, dt_int8* dst) override; + + void on_quantized_mode(const ElemwiseOpParamN<1>& param, + const TensorND& dst, Elemwise::Mode mode) override; + + void on_quantized_mode(const ElemwiseOpParamN<2>& param, + const TensorND& dst, Elemwise::Mode mode) override; + + void on_quantized_mode(const ElemwiseOpParamN<3>& param, + const TensorND& dst, Elemwise::Mode mode) override; + +public: + using fallback::ElemwiseMultiTypeImpl::ElemwiseMultiTypeImpl; +}; + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/elemwise_op.h b/dnn/src/arm_common/elemwise_op.h new file mode 100644 index 00000000..5cb7312e --- /dev/null +++ b/dnn/src/arm_common/elemwise_op.h @@ -0,0 +1,801 @@ +/** + * \file dnn/src/arm_common/elemwise_op.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#pragma once + +#include "src/arm_common/elemwise_helper/op_binary.h" +#include "src/arm_common/elemwise_helper/op_ternary.h" +#include "src/arm_common/elemwise_helper/op_unary.h" + +namespace megdnn { +namespace arm_common { + +///////////////////////////////// ParamElemVistor /////////////////////////// +template +struct ParamElemVisitor; + +//! 
visitor single elemwise, and dup to vector +template +struct ParamElemVisitorDup; + +#define cb(_ctype, _inner_ctype, _neon_type, _fun_suffix) \ + template <> \ + struct ParamElemVisitor<_ctype> { \ + _neon_type operator()(const _ctype* src) const { \ + return vld1q_##_fun_suffix( \ + reinterpret_cast(src)); \ + } \ + }; \ + template <> \ + struct ParamElemVisitorDup<_ctype> { \ + _neon_type operator()(const _ctype* src) const { \ + return vdupq_n_##_fun_suffix( \ + *reinterpret_cast(src)); \ + } \ + } +cb(dt_qint32, int32_t, int32x4_t, s32); +cb(dt_qint8, int8_t, int8x16_t, s8); +cb(dt_quint8, uint8_t, uint8x16_t, u8); + +cb(dt_float32, float32_t, float32x4_t, f32); +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +cb(__fp16, __fp16, float16x8_t, f16); +#endif +cb(dt_int32, int32_t, int32x4_t, s32); +cb(dt_int16, int16_t, int16x8_t, s16); +cb(dt_int8, int8_t, int8x16_t, s8); +#undef cb + +template +struct ParamElemVisitorBcast101x4; +#define cb(_ctype, _inner_ctype, _neon_type, _fun_suffix, rel_suffix) \ + template <> \ + struct ParamElemVisitorBcast101x4<_ctype> { \ + _neon_type operator()(const _ctype* src) const { \ + return vreinterpretq_##_fun_suffix##_##rel_suffix( \ + vld1q_dup_##rel_suffix( \ + reinterpret_cast(src))); \ + } \ + } + +cb(dt_qint8, int32_t, int8x16_t, s8, s32); +cb(dt_quint8, uint32_t, uint8x16_t, u8, u32); +cb(dt_int8, int32_t, int8x16_t, s8, s32); +cb(dt_int16, int64_t, int16x8_t, s16, s64); +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +cb(__fp16, uint64_t, float16x8_t, f16, u64); +#endif +#undef cb +#define cb(_ctype, _inner_ctype, _neon_type, _fun_suffix) \ + template <> \ + struct ParamElemVisitorBcast101x4<_ctype> { \ + _neon_type operator()(const _ctype* src) const { \ + return vld1q_##_fun_suffix( \ + reinterpret_cast(src)); \ + } \ + } + +cb(dt_qint32, int32_t, int32x4_t, s32); +cb(dt_float32, float32_t, float32x4_t, f32); +cb(dt_int32, int32_t, int32x4_t, s32); +#undef cb + +/*! 
+ * \brief broadcast type + * BCAST_x[0]x[1]...: x[i] == !stride[i] + */ +enum BcastType { + VEC, + VEC_VEC, + VEC_BCAST101, + VEC_BCAST101x4, + VEC_SCALAR, + SCALAR_VEC, + BCAST101_VEC, + BCAST101x4_VEC, + VEC_VEC_VEC, + VEC_VEC_SCALAR, + BCAST101_VEC_BCAST101, + VEC_BCAST101_VEC, + VEC_SCALAR_VEC, + VEC_SCALAR_SCALAR, + UNKNOWN_BCAST_TYPE +}; + +///////////////////////////////// OpCaller ///////////////////////////// +template +struct OpCallerUnary; + +template +struct OpCallerUnary { + static void run(const typename Op::src_ctype* src, + typename Op::dst_ctype* dst, DType src_dtype, + DType dst_dtype, size_t nr_elems) { + Op op(src_dtype, dst_dtype); + ParamElemVisitor vis; + size_t i = 0; + for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) { + op({{vis(src), vis(src + Op::SIMD_WIDTH)}}, dst); + src += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < nr_elems; i++) { + op(*src, dst); + src++; + dst++; + } + } +}; + +template +struct OpCallerBinary; + +///////////////////////// Pow //////////////////////////////// +template +struct OpCallerBinary, VEC_VEC> { + using Op = PowOp; + static void run(const typename Op::src_ctype* src0, + const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, + DType src1_dtype, DType dst_dtype, size_t nr_elems) { + Op op(src0_dtype, src1_dtype, dst_dtype); + size_t i = 0; +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < nr_elems; i++) { + op(*src0, *src1, dst); + src0++; + src1++; + dst++; + } + } +}; + +template +struct OpCallerBinary, VEC_SCALAR> { + using Op = PowOp; + static void run(const typename Op::src_ctype* src0, + const typename Op::src_ctype src1, + typename Op::dst_ctype* dst, DType src0_dtype, + DType src1_dtype, DType dst_dtype, size_t nr_elems) { + Op op(src0_dtype, src1_dtype, dst_dtype); + size_t i = 0; +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < nr_elems; i++) { + op(*src0, src1, dst); + src0++; + dst++; + } + } +}; + +template +struct OpCallerBinary, VEC_BCAST101> { + using Op = PowOp; + static void run(const typename Op::src_ctype* src0, + const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, + DType src1_dtype, DType dst_dtype, size_t batch, + size_t channel, size_t channel_stride) { + Op op(src0_dtype, src1_dtype, dst_dtype); + for (size_t b = 0; b < batch; b++) { + const typename Op::src_ctype* src1_ptr = src1; + for (size_t c = 0; c < channel; c++) { + size_t i = 0; +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < channel_stride; i++) { + op(*src0, *src1_ptr, dst); + src0++; + dst++; + } + src1_ptr++; + } + } + } +}; + +template +struct OpCallerBinary, SCALAR_VEC> { + using Op = PowOp; + static void run(const typename Op::src_ctype src0, + const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, + DType src1_dtype, DType dst_dtype, size_t nr_elems) { + Op op(src0_dtype, src1_dtype, dst_dtype); + size_t i = 0; +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) 
+#endif + for (; i < nr_elems; i++) { + op(src0, *src1, dst); + src1++; + dst++; + } + } +}; + +template +struct OpCallerBinary, BCAST101_VEC> { + using Op = PowOp; + static void run(const typename Op::src_ctype* src0, + const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, + DType src1_dtype, DType dst_dtype, size_t batch, + size_t channel, size_t channel_stride) { + Op op(src0_dtype, src1_dtype, dst_dtype); + for (size_t b = 0; b < batch; b++) { + auto src0_ptr = src0; + for (size_t c = 0; c < channel; c++) { + size_t i = 0; +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < channel_stride; i++) { + op(*src0_ptr, *src1, dst); + src1++; + dst++; + } + src0_ptr++; + } + } + } +}; + +template +struct OpCallerBinary { + static void run(const typename Op::src_ctype* src0, + const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, + DType src1_dtype, DType dst_dtype, size_t nr_elems) { + Op op(src0_dtype, src1_dtype, dst_dtype); + ParamElemVisitor vis0; + ParamElemVisitor vis1; + size_t i = 0; + for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) { + op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, + {{vis1(src1), vis1(src1 + Op::SIMD_WIDTH)}}, dst); + src0 += Op::SIMD_WIDTH * 2; + src1 += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < nr_elems; i++) { + op(*src0, *src1, dst); + src0++; + src1++; + dst++; + } + } +}; + +template +struct OpCallerBinary { + static void run(const typename Op::src_ctype* src0, + const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, + DType src1_dtype, DType dst_dtype, size_t batch, + size_t channel, size_t channel_stride) { + Op op(src0_dtype, src1_dtype, dst_dtype); + ParamElemVisitor vis0; + ParamElemVisitorDup vis1; + for (size_t b = 0; b < batch; b++) { + const typename Op::src_ctype* src1_ptr = src1; + for (size_t c = 0; c < channel; c++) { + size_t i = 0; + auto src1_neon = vis1(src1_ptr); + for (; i + Op::SIMD_WIDTH * 2 <= channel_stride; + i += Op::SIMD_WIDTH * 2) { + op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, + {{src1_neon, src1_neon}}, dst); + src0 += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < channel_stride; i++) { + op(*src0, *src1_ptr, dst); + src0++; + dst++; + } + src1_ptr++; + } + } + } +}; + +template +struct OpCallerBinary, BCAST101x4_VEC> { + using Op = PowOp; + static void run(const typename Op::src_ctype* src0, + const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, + DType src1_dtype, DType dst_dtype, size_t batch, + size_t nr_channel_blocks, size_t channel_stride, + size_t channel_block_dim) { + Op op(src0_dtype, src1_dtype, dst_dtype); + for (size_t b = 0; b < batch; b++) { + auto src0_ptr = src0; + for (size_t cb = 0; cb < nr_channel_blocks; cb++) { + auto src0_block_ptr = src0_ptr + cb * channel_block_dim; + for (size_t i = 0; i < channel_stride; i++) { + for (size_t c_iter = 0; c_iter < channel_block_dim; + c_iter++) { + op(*(src0_block_ptr + c_iter), *src1, dst); + src1++; + dst++; + } + } + } + } + } +}; + +template +struct OpCallerBinary { + static void run(const 
typename Op::src_ctype* src0, + const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, + DType src1_dtype, DType dst_dtype, size_t batch, + size_t nr_channel_blocks, size_t channel_stride, + size_t channel_block_dim) { + megdnn_assert(channel_block_dim == 4, "only imp for nchw44"); + Op op(src0_dtype, src1_dtype, dst_dtype); + ParamElemVisitorBcast101x4 vis0; + ParamElemVisitor vis1; + for (size_t b = 0; b < batch; b++) { + auto src0_ptr = src0; + for (size_t cb = 0; cb < nr_channel_blocks; cb++) { + auto src0_block_ptr = src0_ptr + cb * channel_block_dim; + auto channel_block_vec = vis0(src0_block_ptr); + size_t img_index = 0; + auto src1_offset = Op::SIMD_WIDTH / channel_block_dim; + for (; img_index + 2 * src1_offset <= channel_stride; + img_index += 2 * src1_offset) { + op({{channel_block_vec, channel_block_vec}}, + {{vis1(src1), vis1(src1 + Op::SIMD_WIDTH)}}, dst); + src1 += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } + // TODO:all elemwise_multi_type op imp one simd mode + for (; img_index < channel_stride; img_index++) { + for (size_t c_iter = 0; c_iter < channel_block_dim; + c_iter++) { + op(*(src0_block_ptr + c_iter), *src1, dst); + src1++; + dst++; + } + } + } + } + } +}; + +template +struct OpCallerBinary, VEC_BCAST101x4> { + using Op = PowOp; + static void run(const typename Op::src_ctype* src0, + const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, + DType src1_dtype, DType dst_dtype, size_t batch, + size_t nr_channel_blocks, size_t channel_stride, + size_t channel_block_dim) { + Op op(src0_dtype, src1_dtype, dst_dtype); + for (size_t b = 0; b < batch; b++) { + auto src1_ptr = src1; + for (size_t cb = 0; cb < nr_channel_blocks; cb++) { + auto src1_block_ptr = src1_ptr + cb * channel_block_dim; + for (size_t i = 0; i < channel_stride; i++) { + for (size_t c_iter = 0; c_iter < channel_block_dim; + c_iter++) { + op(*(src0), *(src1_block_ptr + c_iter), dst); + src0++; + dst++; + } + } + } + } + } +}; + +template +struct OpCallerBinary { + static void run(const typename Op::src_ctype* src0, + const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, + DType src1_dtype, DType dst_dtype, size_t batch, + size_t nr_channel_blocks, size_t channel_stride, + size_t channel_block_dim) { + megdnn_assert(channel_block_dim == 4, "only imp for nchw44"); + Op op(src0_dtype, src1_dtype, dst_dtype); + ParamElemVisitor vis0; + ParamElemVisitorBcast101x4 vis1; + for (size_t b = 0; b < batch; b++) { + auto src1_ptr = src1; + for (size_t cb = 0; cb < nr_channel_blocks; cb++) { + auto src1_block_ptr = src1_ptr + cb * channel_block_dim; + auto channel_block_vec = vis1(src1_block_ptr); + size_t img_index = 0; + auto src0_offset = Op::SIMD_WIDTH / channel_block_dim; + for (; img_index + 2 * src0_offset <= channel_stride; + img_index += 2 * src0_offset) { + op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, + {{channel_block_vec, channel_block_vec}}, dst); + src0 += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } + // TODO:all elemwise_multi_type op imp one simd mode + for (; img_index < channel_stride; img_index++) { + for (size_t c_iter = 0; c_iter < channel_block_dim; + c_iter++) { + op(*src0, *(src1_block_ptr + c_iter), dst); + src0++; + dst++; + } + } + } + } + } +}; + +template +struct OpCallerBinary { + static void run(const typename Op::src_ctype* src0, + const typename Op::src_ctype src1, + typename Op::dst_ctype* dst, DType src0_dtype, + DType src1_dtype, DType dst_dtype, size_t nr_elems) { + Op 
op(src0_dtype, src1_dtype, dst_dtype); + ParamElemVisitor vis0; + ParamElemVisitorDup vis1; + auto vis1_neon = vis1(&src1); + size_t i = 0; + for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) { + op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, + {{vis1_neon, vis1_neon}}, dst); + src0 += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < nr_elems; i++) { + op(*src0, src1, dst); + src0++; + dst++; + } + } +}; + +//! this only for nonswap op, like SUB and DIV +template +struct OpCallerBinary { + static void run(const typename Op::src_ctype src0, + const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, + DType src1_dtype, DType dst_dtype, size_t nr_elems) { + Op op(src0_dtype, src1_dtype, dst_dtype); + ParamElemVisitorDup vis0; + ParamElemVisitor vis1; + auto vis0_neon = vis0(&src0); + size_t i = 0; + for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) { + op({{vis0_neon, vis0_neon}}, + {{vis1(src1), vis1(src1 + Op::SIMD_WIDTH)}}, dst); + src1 += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < nr_elems; i++) { + op(src0, *src1, dst); + src1++; + dst++; + } + } +}; + +template +struct OpCallerBinary { + static void run(const typename Op::src_ctype* src0, + const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, + DType src1_dtype, DType dst_dtype, size_t batch, + size_t channel, size_t channel_stride) { + Op op(src0_dtype, src1_dtype, dst_dtype); + ParamElemVisitorDup vis0; + ParamElemVisitor vis1; + for (size_t b = 0; b < batch; b++) { + auto src0_ptr = src0; + for (size_t c = 0; c < channel; c++) { + auto vis0_neon = vis0(src0_ptr); + size_t i = 0; + for (; i + Op::SIMD_WIDTH * 2 <= channel_stride; + i += Op::SIMD_WIDTH * 2) { + op({{vis0_neon, vis0_neon}}, + {{vis1(src1), vis1(src1 + Op::SIMD_WIDTH)}}, dst); + src1 += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < channel_stride; i++) { + op(*src0_ptr, *src1, dst); + src1++; + dst++; + } + src0_ptr++; + } + } + } +}; + +template +struct OpCallerTernary; + +template +struct OpCallerTernary { + static void run(const typename Op::src_ctype* src0, + const typename Op::src_ctype* src1, + const typename Op::src_ctype* src2, + typename Op::dst_ctype* dst, DType src0_dtype, + DType src1_dtype, DType src2_dtype, DType dst_dtype, + size_t nr_elems) { + Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); + ParamElemVisitor vis0; + ParamElemVisitor vis1; + ParamElemVisitor vis2; + size_t i = 0; + for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) { + op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, + {{vis1(src1), vis1(src1 + Op::SIMD_WIDTH)}}, + {{vis2(src2), vis2(src2 + Op::SIMD_WIDTH)}}, dst); + src0 += Op::SIMD_WIDTH * 2; + src1 += Op::SIMD_WIDTH * 2; + src2 += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < nr_elems; i++) { + op(*src0, *src1, *src2, dst); + src0++; + src1++; + src2++; + dst++; + } + } +}; 
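+
+/*
+ * Illustrative sketch (not part of this patch): most OpCaller specializations
+ * in this header share the same shape -- a main loop that consumes two SIMD
+ * registers per iteration, followed by a scalar tail. Spelled out with plain
+ * NEON intrinsics for an int8 ReLU (assumes <arm_neon.h> on an ARM target;
+ * the function name is hypothetical):
+ *
+ *   void relu_s8(const int8_t* src, int8_t* dst, size_t n) {
+ *       const int8x16_t zero = vdupq_n_s8(0);
+ *       size_t i = 0;
+ *       // main loop: 2 x 16 lanes per iteration
+ *       for (; i + 32 <= n; i += 32, src += 32, dst += 32) {
+ *           vst1q_s8(dst, vmaxq_s8(vld1q_s8(src), zero));
+ *           vst1q_s8(dst + 16, vmaxq_s8(vld1q_s8(src + 16), zero));
+ *       }
+ *       // scalar tail for the remaining n % 32 elements
+ *       for (; i < n; i++, src++, dst++) {
+ *           *dst = *src > 0 ? *src : 0;
+ *       }
+ *   }
+ */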
+ +//! src0: vector, src1: vector, src2: scalar +template +struct OpCallerTernary { + static void run(const typename Op::src_ctype* src0, + const typename Op::src_ctype* src1, + const typename Op::src_ctype src2, + typename Op::dst_ctype* dst, DType src0_dtype, + DType src1_dtype, DType src2_dtype, DType dst_dtype, + size_t nr_elems) { + Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); + ParamElemVisitor vis0; + ParamElemVisitor vis1; + ParamElemVisitorDup vis2; + auto vis2_neon = vis2(&src2); + size_t i = 0; + for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) { + op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, + {{vis1(src1), vis1(src1 + Op::SIMD_WIDTH)}}, + {{vis2_neon, vis2_neon}}, dst); + src0 += Op::SIMD_WIDTH * 2; + src1 += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < nr_elems; i++) { + op(*src0, *src1, src2, dst); + src0++; + src1++; + dst++; + } + } +}; + +//! src0: 1C11, src1: vector, src2: 1C11 +template +struct OpCallerTernary { + static void run(const typename Op::src_ctype* src0, + const typename Op::src_ctype* src1, + const typename Op::src_ctype* src2, + typename Op::dst_ctype* dst, DType src0_dtype, + DType src1_dtype, DType src2_dtype, DType dst_dtype, + size_t batch_size, size_t channel_size, + size_t channel_stride) { + Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); + ParamElemVisitor vis1; + ParamElemVisitorDup vis0; + ParamElemVisitorDup vis2; + for (size_t batch = 0; batch < batch_size; batch++) { + auto src0_ptr = src0; + auto src2_ptr = src2; + for (size_t channel = 0; channel < channel_size; channel++) { + size_t i = 0; + auto src0_neon = vis0(src0_ptr); + auto src2_neon = vis2(src2_ptr); + for (; i + Op::SIMD_WIDTH * 2 <= channel_stride; + i += Op::SIMD_WIDTH * 2) { + op({{src0_neon, src0_neon}}, + {{vis1(src1), vis1(src1 + Op::SIMD_WIDTH)}}, + {{src2_neon, src2_neon}}, dst); + src1 += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < channel_stride; i++) { + op(*src0_ptr, *src1, *src2_ptr, dst); + src1++; + dst++; + } + src0_ptr++; + src2_ptr++; + } + } + } +}; + +//! 
src1: 1C11, src0 and src2 are contig +template +struct OpCallerTernary { + static void run(const typename Op::src_ctype* src0, + const typename Op::src_ctype* src1, + const typename Op::src_ctype* src2, + typename Op::dst_ctype* dst, DType src0_dtype, + DType src1_dtype, DType src2_dtype, DType dst_dtype, + size_t batch_size, size_t channel_size, + size_t channel_stride) { + Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); + ParamElemVisitor vis0; + ParamElemVisitorDup vis1; + ParamElemVisitor vis2; + for (size_t batch = 0; batch < batch_size; batch++) { + auto src1_ptr = src1; + for (size_t channel = 0; channel < channel_size; channel++) { + size_t i = 0; + auto src1_neon = vis1(src1_ptr); + for (; i + Op::SIMD_WIDTH * 2 <= channel_stride; + i += Op::SIMD_WIDTH * 2) { + op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, + {{src1_neon, src1_neon}}, + {{vis2(src2), vis2(src2 + Op::SIMD_WIDTH)}}, dst); + src0 += Op::SIMD_WIDTH * 2; + src2 += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < channel_stride; i++) { + op(*src0, *src1_ptr, *src2, dst); + src0++; + src2++; + dst++; + } + src1_ptr++; + } + } + } +}; + +//! src1: scalar, src0 and src2 has the same shape +template +struct OpCallerTernary { + static void run(const typename Op::src_ctype* src0, + const typename Op::src_ctype src1, + const typename Op::src_ctype* src2, + typename Op::dst_ctype* dst, DType src0_dtype, + DType src1_dtype, DType src2_dtype, DType dst_dtype, + size_t nr_elems) { + Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); + ParamElemVisitor vis0; + ParamElemVisitorDup vis1; + ParamElemVisitor vis2; + auto vis1_neon = vis1(&src1); + size_t i = 0; + for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) { + op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, + {{vis1_neon, vis1_neon}}, + {{vis2(src2), vis2(src2 + Op::SIMD_WIDTH)}}, dst); + src0 += Op::SIMD_WIDTH * 2; + src2 += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < nr_elems; i++) { + op(*src0, src1, *src2, dst); + src0++; + src2++; + dst++; + } + } +}; + +//! 
src1, src2: scalar, src0 is vector +template +struct OpCallerTernary { + static void run(const typename Op::src_ctype* src0, + const typename Op::src_ctype src1, + const typename Op::src_ctype src2, + typename Op::dst_ctype* dst, DType src0_dtype, + DType src1_dtype, DType src2_dtype, DType dst_dtype, + size_t nr_elems) { + Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); + ParamElemVisitor vis0; + ParamElemVisitorDup vis1; + ParamElemVisitorDup vis2; + auto vis1_neon = vis1(&src1); + auto vis2_neon = vis2(&src2); + size_t i = 0; + for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) { + op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, + {{vis1_neon, vis1_neon}}, {{vis2_neon, vis2_neon}}, dst); + src0 += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < nr_elems; i++) { + op(*src0, src1, src2, dst); + src0++; + dst++; + } + } +}; + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/handle.cpp b/dnn/src/arm_common/handle.cpp new file mode 100644 index 00000000..b26c33cb --- /dev/null +++ b/dnn/src/arm_common/handle.cpp @@ -0,0 +1,65 @@ +/** + * \file dnn/src/arm_common/handle.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "src/common/handle_impl.h" + +#include "src/arm_common/handle.h" + +#include "src/arm_common/convolution/opr_impl.h" +#include "src/arm_common/pooling/opr_impl.h" +#include "src/arm_common/local/opr_impl.h" +#include "src/arm_common/separable_conv/opr_impl.h" +#include "src/arm_common/separable_filter/opr_impl.h" +#include "src/arm_common/elemwise/opr_impl.h" +#include "src/arm_common/elemwise_multi_type/opr_impl.h" +#include "src/arm_common/cvt_color/opr_impl.h" +#include "src/arm_common/warp_affine/opr_impl.h" +#include "src/arm_common/resize/opr_impl.h" +#include "src/arm_common/warp_perspective/opr_impl.h" +#include "src/arm_common/type_cvt/opr_impl.h" +#include "src/arm_common/reduce/opr_impl.h" +#include "src/arm_common/conv_bias/opr_impl.h" +#include "src/arm_common/winograd_filter_preprocess/opr_impl.h" + +namespace megdnn { +namespace arm_common { + +template +std::unique_ptr HandleImpl::create_operator() { + return fallback::HandleImpl::create_operator(); +} + +MEGDNN_SPECIALIZE_CREATE_OPERATOR(Pooling) +MEGDNN_SPECIALIZE_CREATE_OPERATOR(Local) +MEGDNN_SPECIALIZE_CREATE_OPERATOR(SeparableConv) +MEGDNN_SPECIALIZE_CREATE_OPERATOR(SeparableFilter) +MEGDNN_SPECIALIZE_CREATE_OPERATOR(Elemwise) +MEGDNN_SPECIALIZE_CREATE_OPERATOR(ElemwiseMultiType) +MEGDNN_SPECIALIZE_CREATE_OPERATOR(CvtColor) +MEGDNN_SPECIALIZE_CREATE_OPERATOR(WarpAffine) +MEGDNN_SPECIALIZE_CREATE_OPERATOR(Resize) +MEGDNN_SPECIALIZE_CREATE_OPERATOR(WarpPerspective) +MEGDNN_SPECIALIZE_CREATE_OPERATOR(TypeCvt) +MEGDNN_SPECIALIZE_CREATE_OPERATOR(Reduce) +MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvBias) +MEGDNN_SPECIALIZE_CREATE_OPERATOR(WinogradFilterPreprocess) +MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvolutionBackwardData) + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wpragmas" +#pragma GCC diagnostic ignored 
"-Winstantiation-after-specialization" +MEGDNN_FOREACH_OPR_CLASS(MEGDNN_INST_CREATE_OPERATOR) +#pragma GCC diagnostic pop + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/handle.h b/dnn/src/arm_common/handle.h new file mode 100644 index 00000000..ffc63712 --- /dev/null +++ b/dnn/src/arm_common/handle.h @@ -0,0 +1,32 @@ +/** + * \file dnn/src/arm_common/handle.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once +#include "src/fallback/handle.h" + +namespace megdnn { +namespace arm_common { + +class HandleImpl: public fallback::HandleImpl { + public: + HandleImpl(megcoreComputingHandle_t computing_handle, + HandleType type = HandleType::ARM_COMMON): + fallback::HandleImpl::HandleImpl(computing_handle, type) + { + } + + template + std::unique_ptr create_operator(); +}; + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/local/opr_impl.cpp b/dnn/src/arm_common/local/opr_impl.cpp new file mode 100644 index 00000000..578e4da1 --- /dev/null +++ b/dnn/src/arm_common/local/opr_impl.cpp @@ -0,0 +1,132 @@ +/** + * \file dnn/src/arm_common/local/opr_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +#include "src/arm_common/local/opr_impl.h" + +#include +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/utils.h" +#include "src/naive/handle.h" + +using namespace megdnn; +using namespace arm_common; + +namespace { + +void do_one_pixel(float* dst, const float* filter, float sval, int OC) { + const int width = 4u; + int oc = 0; + float32x4_t vs = vdupq_n_f32(sval); + for (; oc + width <= OC; oc += width, filter += width, dst += width) { + float32x4_t vf = vld1q_f32(filter); + float32x4_t vd = vld1q_f32(dst); + vd = vmlaq_f32(vd, vs, vf); + vst1q_f32(dst, vd); + } + for (; oc < OC; oc++, dst++, filter++) { + *dst += sval * (*filter); + } +} + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +void do_one_pixel(dt_float16* dst, const dt_float16* filter, dt_float16 sval, + int OC) { + const __fp16* filter_ptr = reinterpret_cast(filter); + __fp16* dst_ptr = reinterpret_cast<__fp16*>(dst); + const int width = 8u; + int oc = 0; + float16x8_t vs = vdupq_n_f16(sval); + for (; oc + width <= OC; + oc += width, filter_ptr += width, dst_ptr += width) { + float16x8_t vf = vld1q_f16(filter_ptr); + float16x8_t vd = vld1q_f16(dst_ptr); + vd = vmlaq_f16(vd, vs, vf); + vst1q_f16(dst_ptr, vd); + } +#if MEGDNN_FIX_AARCH32_BUG + // FIXME: as llvm may cause cannot select error if enable vectorize + #pragma clang loop vectorize(disable) +#endif + for (; oc < OC; oc++, dst_ptr++, filter_ptr++) { + *dst_ptr += sval * (*filter_ptr); + } +} +#endif + +template +void exec_internal(const LocalImpl::FloatNoncontigBatchKernParam& kparam) { + UNPACK_LOCAL_FLOAT_NONCONTIG_BATCH_KERN_PARAM(kparam, dtype); + auto dst2 = workspace; + // dst2 is (H, W, N, C) + std::memset(dst2, 0, sizeof(dtype) * OH * OW * N * OC); + dtype* dst2_hwnc = dst2; + rep(oh, OH) rep(ow, OW) { + const dtype* src_bak = src; + rep(ic, IC) { + rep(fh, FH) for (int fw = 0; fw < FW; ++fw, filter += OC) { + int ih = -PH + oh * SH + (is_xcorr ? fh : (FH - fh - 1)); + int iw = -PW + ow * SW + (is_xcorr ? 
fw : (FW - fw - 1)); + if (ih < 0 || ih >= IH || iw < 0 || iw >= IW) + continue; + dtype* dst2_bak = dst2; + rep(n, N) { + dtype s = src[n * INP_BS + ih * IW + iw]; + do_one_pixel(dst2, filter, s, OC); + dst2 += OC; + } + dst2 = dst2_bak; + } + src += IH * IW; + } + src = src_bak; + dst2 += N * OC; + } + transpose_knc2nsck(dst2_hwnc, dst, OH * OW, N, OC, OUT_BS); +} + +} // anonymous namespace + +size_t LocalImpl::get_workspace_in_bytes(const TensorLayout& /* src */, + const TensorLayout& /* filter */, + const TensorLayout& dst) { + return dst.span().dist_byte(); +} + +LocalImpl::float_noncontig_batch_kern LocalImpl::dispatch_float_noncontig_batch( + const TensorLayout& src, const TensorLayout&, const TensorLayout&) { + megdnn_assert(src.stride[0] > 0 && + static_cast(src.stride[0]) >= + src.total_nr_elems() / src.shape[0]); + if (src.dtype == dtype::Float32()) { + if (param().mode == Mode::CROSS_CORRELATION) { + return exec_internal; + } else { + return exec_internal; + } +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + } else { + megdnn_assert(src.dtype == dtype::Float16()); + if (param().mode == Mode::CROSS_CORRELATION) { + return exec_internal; + } else { + return exec_internal; + } +#endif + } + megdnn_assert_internal(false); + return nullptr; +} + +void LocalImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, + _megdnn_tensor_out dst, _megdnn_workspace workspace) { + return exec_use_float_noncontig_batch(src, filter, dst, workspace); +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/local/opr_impl.h b/dnn/src/arm_common/local/opr_impl.h new file mode 100644 index 00000000..f19d06c2 --- /dev/null +++ b/dnn/src/arm_common/local/opr_impl.h @@ -0,0 +1,40 @@ +/** + * \file dnn/src/arm_common/local/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include "src/naive/local/opr_impl.h" + +namespace megdnn { +namespace arm_common { + +class LocalImpl final: public naive::LocalForwardImpl { + public: + using naive::LocalForwardImpl::LocalForwardImpl; + + float_noncontig_batch_kern dispatch_float_noncontig_batch( + const TensorLayout &src, + const TensorLayout &filter, + const TensorLayout &dst) override; + + void exec(_megdnn_tensor_in src, + _megdnn_tensor_in filter, + _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; + + size_t get_workspace_in_bytes(const TensorLayout &src, + const TensorLayout &filter, + const TensorLayout &dst) override; +}; + +} // namespace arm_common +} // namespace megdnn +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/matrix_mul/algos.cpp b/dnn/src/arm_common/matrix_mul/algos.cpp new file mode 100644 index 00000000..be5b663c --- /dev/null +++ b/dnn/src/arm_common/matrix_mul/algos.cpp @@ -0,0 +1,188 @@ +/** + * \file dnn/src/arm_common/matrix_mul/algos.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#include "src/arm_common/matrix_mul/algos.h" +#include "src/arm_common/matrix_mul/exec_gemm_int8_int8_int16.h" +#include "src/arm_common/matrix_mul/fp16/hgemv.h" +#include "src/arm_common/matrix_mul/fp32/exec_sgemv.h" +#include "src/arm_common/matrix_mul/int8/gemv.h" +#include "midout.h" + +MIDOUT_DECL(megdnn_arm_hgemv) +MIDOUT_DECL(megdnn_arm_exec_int8816) + +using namespace megdnn; +using namespace arm_common; + +/* ===================== Int8x8x16 algo ===================== */ + +namespace { +WorkspaceBundle get_workspace_bundle_int_8x8x16( + const MatrixMulImpl::KernSizeParam& kern_size_param) { + auto M = kern_size_param.M, K = kern_size_param.K, N = kern_size_param.N; + // Use 8x8 tile + return WorkspaceBundle(nullptr, {(M + 8) * K * sizeof(int8_t), + K * (N + 8) * sizeof(int8_t)}); +} + +void exec_int_8x8x16(const MatrixMulImpl::KernParam& kern_param) { + MIDOUT_BEGIN(megdnn_arm_exec_int8816, void) { + auto bundle = get_workspace_bundle_int_8x8x16(kern_param); + bundle.set(kern_param.workspace_ptr); + auto w0 = static_cast(bundle.get(0)); + auto w1 = static_cast(bundle.get(1)); + size_t M = kern_param.M; + size_t N = kern_param.N; + size_t K = kern_param.K; + size_t LDB = kern_param.LDB; + exec_gemm_int8_int8_int16(kern_param.A(), + kern_param.B(), + kern_param.C(), M, K, N, LDB, w0, w1); + } + MIDOUT_END(); +} +} // anonymous namespace + +bool MatrixMulImpl::AlgoInt8x8x16::usable( + const KernSizeParam& kern_size_param) const { + return kern_size_param.A_type == dtype::Int8() && + kern_size_param.B_type == dtype::Int8() && + kern_size_param.C_type == dtype::Int16() && + kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && + kern_size_param.format == param::MatrixMul::Format::DEFAULT && + !kern_size_param.trA && !kern_size_param.trB; +} + +size_t MatrixMulImpl::AlgoInt8x8x16::get_workspace( + const KernSizeParam& kern_size_param) const { + auto wbundle = get_workspace_bundle_int_8x8x16(kern_size_param); + return wbundle.total_size_in_bytes(); +} + +MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16::get_kern( + const KernSizeParam&) const { + return exec_int_8x8x16; +} + +#if !__ARM_FEATURE_DOTPROD +/* ===================== Int8x8x32 Gemv algo ===================== */ +namespace { +void int8x8x32_gemv_kern(const MatrixMulImpl::KernParam& kern_param) { + auto M = kern_param.M, N = kern_param.N, K = kern_param.K; + auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; + const auto Aptr = kern_param.A(), Bptr = kern_param.B(); + auto Cptr = kern_param.C(); + arm_common::matmul::gemv_like_int8(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, + LDC); +} +} // anonymous namespace + +bool MatrixMulImpl::AlgoInt8x8x32Gemv::usable( + const KernSizeParam& kern_size_param) const { + auto N = kern_size_param.N, LDB = kern_size_param.LDB; + return can_be_treated_as_int8x8x32(kern_size_param) && + !kern_size_param.trA && !kern_size_param.trB && (N == 1 && LDB == 1); +} + +bool MatrixMulImpl::AlgoInt8x8x32Gemv::preferred( + const KernSizeParam& kern_size_param) const { + auto N = kern_size_param.N, LDB = kern_size_param.LDB; + return N == 1 && LDB == 1; +} + +MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32Gemv::get_kern( + const KernSizeParam&) const { + return int8x8x32_gemv_kern; +} +#endif + +/* ===================== F32 Gemv algo ===================== */ +namespace { +void f32_gemv_kern(const MatrixMulImpl::KernParam& kern_param) { + auto M = kern_param.M, N = kern_param.N, K = kern_param.K; + auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; + 
const auto Aptr = kern_param.A(), + Bptr = kern_param.B(); + auto Cptr = kern_param.C(); + + arm_common::sgemm_sgemv_like(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC); +} +} // anonymous namespace + +bool MatrixMulImpl::AlgoF32Gemv::usable( + const KernSizeParam& kern_size_param) const { + // enumerate the M, N, K, only usable when preferred + return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && + kern_size_param.format == param::MatrixMul::Format::DEFAULT && + kern_size_param.B_type == kern_size_param.A_type && + kern_size_param.C_type == kern_size_param.A_type && + kern_size_param.A_type == dtype::Float32() && !kern_size_param.trA && + !kern_size_param.trB && preferred(kern_size_param); +} + +bool MatrixMulImpl::AlgoF32Gemv::preferred( + const KernSizeParam& kern_size_param) const { + auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K, + LDB = kern_size_param.LDB; + + return M < 8 || (M == 8 && K <= 2) || (N == 1 && LDB == 1); +} + +MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32Gemv::get_kern( + const KernSizeParam&) const { + return f32_gemv_kern; +} + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +/* ===================== F16 Gemv algo ===================== */ +namespace { +void f16_gemv_kern(const MatrixMulImpl::KernParam& kern_param) { + auto M = kern_param.M, N = kern_param.N, K = kern_param.K; + auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; + const auto Aptr = kern_param.A(), + Bptr = kern_param.B(); + auto Cptr = kern_param.C(); + MIDOUT_BEGIN(megdnn_arm_hgemv, void) { + arm_common::hgemv_exec(reinterpret_cast(Aptr), + reinterpret_cast(Bptr), + reinterpret_cast<__fp16*>(Cptr), M, N, K, LDA, + LDB, LDC); + } + MIDOUT_END(); +} +} // anonymous namespace + +bool MatrixMulImpl::AlgoF16Gemv::usable( + const KernSizeParam& kern_size_param) const { + // enumerate the M, N, K, only usable when preferred + return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && + kern_size_param.format == param::MatrixMul::Format::DEFAULT && + kern_size_param.B_type == kern_size_param.A_type && + kern_size_param.C_type == kern_size_param.A_type && + kern_size_param.A_type == dtype::Float16() && !kern_size_param.trA && + !kern_size_param.trB && preferred(kern_size_param); +} + +bool MatrixMulImpl::AlgoF16Gemv::preferred( + const KernSizeParam& kern_size_param) const { + auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K, + LDB = kern_size_param.LDB; + + return M <= 4 || (M == 8 && K <= 2) || (N == 1 && LDB == 1); +} + +MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16Gemv::get_kern( + const KernSizeParam&) const { + return f16_gemv_kern; +} +#endif + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/matrix_mul/algos.h b/dnn/src/arm_common/matrix_mul/algos.h new file mode 100644 index 00000000..8db0b316 --- /dev/null +++ b/dnn/src/arm_common/matrix_mul/algos.h @@ -0,0 +1,81 @@ +/** + * \file dnn/src/arm_common/matrix_mul/algos.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
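
The GEMV algorithms registered above are gated on shape heuristics rather than on dtype alone: AlgoF32Gemv::preferred accepts M < 8, the narrow M == 8 && K <= 2 case, or a true matrix-vector product (N == 1 with contiguous B, i.e. LDB == 1), AlgoF16Gemv uses M <= 4 plus the same two cases, and usable() additionally requires the default format and compute mode, matching A/B/C dtypes and no transposition before delegating to preferred(). A condensed sketch of that shape check (a hypothetical free function, not part of the patch):

#include <cstddef>

// Mirrors the shape tests in AlgoF32Gemv/AlgoF16Gemv::preferred: a short
// matrix, an almost-degenerate K, or a genuine matrix*vector with LDB == 1.
static bool gemv_shape_preferred(bool fp16, size_t M, size_t N, size_t K,
                                 size_t LDB) {
    bool small_m = fp16 ? (M <= 4) : (M < 8);  // the fp16 path keeps fewer rows
    return small_m || (M == 8 && K <= 2) || (N == 1 && LDB == 1);
}
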
+ */ + +#pragma once + +#include "src/arm_common/matrix_mul/opr_impl.h" + +namespace megdnn { +namespace arm_common { + +class MatrixMulImpl::AlgoInt8x8x16 final : public AlgoBase { +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "ARM_COMMON_INT8X8X16"; } + bool usable(const KernSizeParam&) const override; + size_t get_workspace(const KernSizeParam&) const override; + kern_t get_kern(const KernSizeParam&) const override; + void* type() const override { return sm_arm_common_algo_type; } + PackMode packmode() const override { return PackMode::NO_PACK; } +}; + +#if !__ARM_FEATURE_DOTPROD +class MatrixMulImpl::AlgoInt8x8x32Gemv : public AlgoBase { +protected: + ~AlgoInt8x8x32Gemv() = default; + +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV"; } + bool usable(const KernSizeParam&) const override; + bool preferred(const KernSizeParam&) const override; + size_t get_workspace(const KernSizeParam&) const override { return 0; } + kern_t get_kern(const KernSizeParam&) const override; + void* type() const override { return sm_arm_common_algo_type; } + AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } + PackMode packmode() const override { return PackMode::NO_PACK; } +}; +#endif + +class MatrixMulImpl::AlgoF32Gemv : public AlgoBase { +protected: + ~AlgoF32Gemv() = default; + +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "ARM_COMMON_F32_GEMV"; } + bool usable(const KernSizeParam&) const override; + bool preferred(const KernSizeParam&) const override; + size_t get_workspace(const KernSizeParam&) const override { return 0; } + kern_t get_kern(const KernSizeParam&) const override; + void* type() const override { return sm_arm_common_algo_type; } + AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } + PackMode packmode() const override { return PackMode::NO_PACK; } +}; + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +class MatrixMulImpl::AlgoF16Gemv : public AlgoBase { +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "ARM_COMMON_F16_GEMV"; } + bool usable(const KernSizeParam&) const override; + bool preferred(const KernSizeParam&) const override; + size_t get_workspace(const KernSizeParam&) const override { return 0; } + kern_t get_kern(const KernSizeParam&) const override; + void* type() const override { return sm_arm_common_algo_type; } + AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } + PackMode packmode() const override { return PackMode::NO_PACK; } +}; +#endif +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/matrix_mul/exec_gemm_int8_int8_int16.cpp b/dnn/src/arm_common/matrix_mul/exec_gemm_int8_int8_int16.cpp new file mode 100644 index 00000000..bf5a77d2 --- /dev/null +++ b/dnn/src/arm_common/matrix_mul/exec_gemm_int8_int8_int16.cpp @@ -0,0 +1,454 @@ +/** + * \file dnn/src/arm_common/matrix_mul/exec_gemm_int8_int8_int16.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
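
All four classes declared in algos.h share the AlgoBase contract: usable() says whether the algorithm can run a given problem at all, preferred() whether it is expected to win, get_workspace() how much scratch memory it needs, and get_kern() returns the kernel to execute; all four report PackMode::NO_PACK, and the GEMV variants additionally tag themselves as ALGO_TYPE_GEMV. The loop below is only an illustration of how such objects can be consumed; it is not MegDNN's actual dispatcher:

#include <vector>

// Hypothetical selection over AlgoBase-like objects: take the first usable
// algorithm that also reports preferred(), otherwise fall back to the first
// merely usable one.
template <typename Algo, typename SizeParam>
const Algo* pick_algo(const std::vector<const Algo*>& algos,
                      const SizeParam& param) {
    const Algo* fallback = nullptr;
    for (const Algo* algo : algos) {
        if (!algo->usable(param))
            continue;
        if (algo->preferred(param))
            return algo;
        if (!fallback)
            fallback = algo;
    }
    return fallback;
}
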
+ */ +#include "src/arm_common/matrix_mul/exec_gemm_int8_int8_int16.h" + +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/utils.h" + +namespace { + +inline int8x8_t vreinterpret_s8_s8(int8x8_t x) { return x; } + +void packA(const int8_t *src, int8_t *dst, size_t M, size_t K) +{ +#if 0 + // naive impl + megdnn_assert(M % 8 == 0); + for (size_t m = 0; m+8 <= M; m += 8) { + for (size_t k = 0; k < K; ++k) { + for (size_t m2 = m; m2 < m+8; ++m2) *(dst++) = src[m2*K + k]; + } + } +#else + // 8x8 block at a time + size_t m = 0; + int8_t * __restrict dptr = dst; + for (; m+8 <= M; m += 8) { + size_t k = 0; + for (; k+8 <= K; k += 8) { + const int8_t * __restrict sptr = src + (m*K + k); + int8x8_t l0 = vld1_s8(sptr + 0*K), + l1 = vld1_s8(sptr + 1*K), + l2 = vld1_s8(sptr + 2*K), + l3 = vld1_s8(sptr + 3*K), + l4 = vld1_s8(sptr + 4*K), + l5 = vld1_s8(sptr + 5*K), + l6 = vld1_s8(sptr + 6*K), + l7 = vld1_s8(sptr + 7*K); + // do transpose +#define TRANS(lhs, rhs, bit) { \ + auto tmp = vtrn_s ## bit(vreinterpret_s ## bit ## _s8(lhs), \ + vreinterpret_s ## bit ## _s8(rhs)); \ + lhs = vreinterpret_s8_s ## bit(tmp.val[0]); \ + rhs = vreinterpret_s8_s ## bit(tmp.val[1]); \ +} + TRANS(l0, l4, 32); + TRANS(l1, l5, 32); + TRANS(l2, l6, 32); + TRANS(l3, l7, 32); + TRANS(l0, l2, 16); + TRANS(l1, l3, 16); + TRANS(l4, l6, 16); + TRANS(l5, l7, 16); + TRANS(l0, l1, 8); + TRANS(l2, l3, 8); + TRANS(l4, l5, 8); + TRANS(l6, l7, 8); +#undef TRANS + vst1_s8(dptr, l0); dptr += 8; + vst1_s8(dptr, l1); dptr += 8; + vst1_s8(dptr, l2); dptr += 8; + vst1_s8(dptr, l3); dptr += 8; + vst1_s8(dptr, l4); dptr += 8; + vst1_s8(dptr, l5); dptr += 8; + vst1_s8(dptr, l6); dptr += 8; + vst1_s8(dptr, l7); dptr += 8; + } + for (; k < K; ++k) { + const int8_t * __restrict sptr = src + (m*K + k); + for (size_t i = 0; i < 8; ++i) *(dptr++) = *(sptr + i*K); + } + } + if (m < M) { + for (size_t k = 0; k < K; ++k) { + const int8_t * __restrict sptr = src + (m*K + k); + for (size_t i = 0; i < 8; ++i) { + *(dptr++) = (m+i < M ? 
*(sptr + i*K) : 0); + } + } + } +#endif +} + +#define LOAD(i) \ + int8x8_t l ## i = vld1_s8(sptr); \ + int8x8_t s ## i = vld1_s8(sptr + 8); \ + sptr += LDB; + +#define STORE(i) \ + vst1_s8(dptr, l ## i); \ + dptr += 8; \ + vst1_s8(dptr, s ## i); \ + dptr += 8; + +#define TRANS(i) \ + int8x8_t l ## i = vld1_s8(sptr); \ + int8x8_t s ## i = vld1_s8(sptr + 8); \ + sptr += N; \ + vst1_s8(dptr, l ## i); \ + dptr += 8; \ + vst1_s8(dptr, s ## i); \ + dptr += 8; + +void packB(const int8_t *src, int8_t *dst, size_t K, size_t N, size_t LDB) +{ +#if 0 + megdnn_assert(N % 8 == 0); + for (size_t n = 0; n+8 <= N; n += 8) + for (size_t k = 0; k < K; ++k) + { + for (size_t n2 = n; n2 < n+8; ++n2) *(dst++) = src[k*N + n2]; + } +#else + int8_t * __restrict dptr = dst; + size_t n = 0; + for(; n+16 <=N; n += 16) { + size_t k = 0; + for (; k+8 <= K; k += 8) { + const int8_t * __restrict sptr = src + k * LDB + n; + + LOAD(0); + LOAD(1); + LOAD(2); + LOAD(3); + LOAD(4); + LOAD(5); + LOAD(6); + LOAD(7); +#undef LOAD + STORE(0); + STORE(1); + STORE(2); + STORE(3); + STORE(4); + STORE(5); + STORE(6); + STORE(7); +#undef STORE +#undef TRANS + + } + for (; k < K; ++k) { + const int8_t * __restrict sptr = src + k * LDB + n; + int8x8_t l = vld1_s8(sptr); + int8x8_t s = vld1_s8(sptr + 8); + vst1_s8(dptr, l); dptr += 8; + vst1_s8(dptr, s); dptr += 8; + } + } + for (; n+8 <= N; n += 8) { + size_t k = 0; + for (; k+8 <= K; k += 8) { + const int8_t * __restrict sptr = src + k * LDB + n; + int8x8_t l0 = vld1_s8(sptr + 0*N), + l1 = vld1_s8(sptr + 1*N), + l2 = vld1_s8(sptr + 2*N), + l3 = vld1_s8(sptr + 3*N), + l4 = vld1_s8(sptr + 4*N), + l5 = vld1_s8(sptr + 5*N), + l6 = vld1_s8(sptr + 6*N), + l7 = vld1_s8(sptr + 7*N); + vst1_s8(dptr, l0); dptr += 8; + vst1_s8(dptr, l1); dptr += 8; + vst1_s8(dptr, l2); dptr += 8; + vst1_s8(dptr, l3); dptr += 8; + vst1_s8(dptr, l4); dptr += 8; + vst1_s8(dptr, l5); dptr += 8; + vst1_s8(dptr, l6); dptr += 8; + vst1_s8(dptr, l7); dptr += 8; + } + for (; k < K; ++k) { + const int8_t * __restrict sptr = src + k * LDB + n; + int8x8_t l = vld1_s8(sptr); + vst1_s8(dptr, l); dptr += 8; + } + } + if (n < N) { + for (size_t k = 0; k < K; ++k) { + const int8_t * __restrict sptr = src + k * LDB + n; + int8_t l[8] = {0}; + for (size_t i = 0; n+i < N; ++i) l[i] = sptr[i]; + for (size_t i = 0; i < 8; ++i) *(dptr++) = l[i]; + } + } +#endif +} + +} // anonymous namespace + +//#include + +namespace megdnn { +namespace arm_common { + +#define GAO(i) { \ + tmp = vdup_lane_s8(a, i); \ + l ## i = vmlal_s8(l ## i, tmp, b); \ +} + +#define STORE_REMAIN_N(i, p) \ + if(plen > p) \ + Cptr[p] = vgetq_lane_s16(l##i, p); \ + else \ + break; + +#define STORE_PARTRIAL_N(i) { \ + while(1) { \ + STORE_REMAIN_N(i, 0) \ + STORE_REMAIN_N(i, 1) \ + STORE_REMAIN_N(i, 2) \ + STORE_REMAIN_N(i, 3) \ + STORE_REMAIN_N(i, 4) \ + STORE_REMAIN_N(i, 5) \ + STORE_REMAIN_N(i, 6) \ + break; \ + } \ + Cptr += N; \ +} + +#define STORE_PARTRIAL_M(i) { \ + if(plen > i) { \ + vst1q_s16(Cptr, l##i); \ + Cptr += N; \ + } \ + else \ + break; \ +} + +#define GAO_16(i) { \ + tmp = vdup_lane_s8(a, i); \ + l ## i = vmlal_s8(l ## i, tmp, b0); \ + s ## i = vmlal_s8(s ## i, tmp, b1); \ +} + +#define STORE_16(i) { \ + vst1q_s16(Cptr, l##i); \ + vst1q_s16(Cptr + 8, s##i); \ + Cptr += N; \ +} + +#define STORE_REMAIN_N_16(i, p) \ + if(plen > p) \ + Cptr[8+p] = vgetq_lane_s16(s##i, p); \ + else \ + break; + +#define STORE_PARTRIAL_N_16(i) { \ + while(1) { \ + vst1q_s16(Cptr, l##i); \ + STORE_REMAIN_N_16(i, 0) \ + STORE_REMAIN_N_16(i, 1) \ + STORE_REMAIN_N_16(i, 2) \ 
+ STORE_REMAIN_N_16(i, 3) \ + STORE_REMAIN_N_16(i, 4) \ + STORE_REMAIN_N_16(i, 5) \ + STORE_REMAIN_N_16(i, 6) \ + break; \ + } \ + Cptr += N; \ +} + +#define STORE_PARTRIAL_M_16(i) { \ + if(plen > i) \ + STORE_16(i) \ + else \ + break; \ +} + +void exec_gemm_int8_int8_int16(const int8_t *A_, const int8_t *B_, int16_t *C, + size_t M, size_t K, size_t N,size_t LDB, + int8_t *w0, int8_t *w1) +{ + // for test + //printf("matrix_mul M %ld, K %ld, N %ld \n", M, K, N); + packA(A_, w0, M, K); + packB(B_, w1, K, N, LDB); + + const int8_t * A = w0; + const int8_t * B = w1; + for (size_t m = 0; m < M; m += 8) { + size_t n = 0; + for (; n + 16 <= N; n += 16) { + //for (; n + 7 < N; n += 16) { + int16x8_t l0 = vdupq_n_s16(0), + l1 = vdupq_n_s16(0), + l2 = vdupq_n_s16(0), + l3 = vdupq_n_s16(0), + l4 = vdupq_n_s16(0), + l5 = vdupq_n_s16(0), + l6 = vdupq_n_s16(0), + l7 = vdupq_n_s16(0), + s0 = vdupq_n_s16(0), + s1 = vdupq_n_s16(0), + s2 = vdupq_n_s16(0), + s3 = vdupq_n_s16(0), + s4 = vdupq_n_s16(0), + s5 = vdupq_n_s16(0), + s6 = vdupq_n_s16(0), + s7 = vdupq_n_s16(0); + + const int8_t * __restrict Aptr = A + m*K; + const int8_t * __restrict Bptr = B + n*K; + + for (size_t k = 0; k < K; ++k) { + int8x8_t tmp; + int8x8_t a = vld1_s8(Aptr), + b0 = vld1_s8(Bptr), + b1 = vld1_s8(Bptr + 8); + Aptr += 8; + Bptr += 16; + //__builtin_prefetch(Aptr, 0, 0); + __builtin_prefetch(Bptr, 0, 0); + + GAO_16(0); + GAO_16(1); + GAO_16(2); + GAO_16(3); + GAO_16(4); + GAO_16(5); + GAO_16(6); + GAO_16(7); + + + } + + int16_t * __restrict Cptr = C + m*N + n; + + if (m+8 <= M) { // sub-case 1: m+8 <= M && n+16 <= N + STORE_16(0) + STORE_16(1) + STORE_16(2) + STORE_16(3) + STORE_16(4) + STORE_16(5) + STORE_16(6) + STORE_16(7) + } else { + size_t plen = M - m; + while(1) { + STORE_PARTRIAL_M_16(0) + STORE_PARTRIAL_M_16(1) + STORE_PARTRIAL_M_16(2) + STORE_PARTRIAL_M_16(3) + STORE_PARTRIAL_M_16(4) + STORE_PARTRIAL_M_16(5) + STORE_PARTRIAL_M_16(6) + break; + } + } + } + + for (; n < N; n += 8) { + int16x8_t l0 = vdupq_n_s16(0), + l1 = vdupq_n_s16(0), + l2 = vdupq_n_s16(0), + l3 = vdupq_n_s16(0), + l4 = vdupq_n_s16(0), + l5 = vdupq_n_s16(0), + l6 = vdupq_n_s16(0), + l7 = vdupq_n_s16(0); + const int8_t * __restrict Aptr = A + m*K; + const int8_t * __restrict Bptr = B + n*K; + for (size_t k = 0; k < K; ++k) { + int8x8_t a = vld1_s8(Aptr), + b = vld1_s8(Bptr); + int8x8_t tmp; + GAO(0); + GAO(1); + GAO(2); + GAO(3); + GAO(4); + GAO(5); + GAO(6); + GAO(7); + Aptr += 8; + Bptr += 8; + } + int16_t * __restrict Cptr = C + m*N + n; + + if (m+8 <= M && n+8 <= N) { + vst1q_s16(Cptr + 0*N, l0); + vst1q_s16(Cptr + 1*N, l1); + vst1q_s16(Cptr + 2*N, l2); + vst1q_s16(Cptr + 3*N, l3); + vst1q_s16(Cptr + 4*N, l4); + vst1q_s16(Cptr + 5*N, l5); + vst1q_s16(Cptr + 6*N, l6); + vst1q_s16(Cptr + 7*N, l7); + } else if (m+8 <=M && n+8 > N) { // m+8<=M && n+8<=N && n+8>N + size_t plen = N - n; + STORE_PARTRIAL_N(0) + STORE_PARTRIAL_N(1) + STORE_PARTRIAL_N(2) + STORE_PARTRIAL_N(3) + STORE_PARTRIAL_N(4) + STORE_PARTRIAL_N(5) + STORE_PARTRIAL_N(6) + STORE_PARTRIAL_N(7) + } else if(n+8 <= N) { // m+8>M && n+8<=N + size_t plen = M - m; + while(1) { + STORE_PARTRIAL_M(0) + STORE_PARTRIAL_M(1) + STORE_PARTRIAL_M(2) + STORE_PARTRIAL_M(3) + STORE_PARTRIAL_M(4) + STORE_PARTRIAL_M(5) + STORE_PARTRIAL_M(6) + break; + } + } else { + int16_t cache[8*8]; + vst1q_s16(cache + 0*8, l0); + vst1q_s16(cache + 1*8, l1); + vst1q_s16(cache + 2*8, l2); + vst1q_s16(cache + 3*8, l3); + vst1q_s16(cache + 4*8, l4); + vst1q_s16(cache + 5*8, l5); + vst1q_s16(cache + 6*8, l6); + 
vst1q_s16(cache + 7*8, l7); + + for (size_t i = 0; m+i < M && i < 8; ++i) + for (size_t j = 0; n+j < N && j < 8; ++j) + { + Cptr[i*N + j] = cache[i*8 + j]; + } + + } + } + } +} +#undef GAO +#undef STORE_REMAIN_N +#undef STORE_PARTRIAL_N +#undef STORE_PARTRIAL_M + +#undef GAO_16 +#undef STORE_16 +#undef STORE_REMAIN_N_16 +#undef STORE_PARTRIAL_N_16 +#undef STORE_PARTRIAL_M_16 + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen + diff --git a/dnn/src/arm_common/matrix_mul/exec_gemm_int8_int8_int16.h b/dnn/src/arm_common/matrix_mul/exec_gemm_int8_int8_int16.h new file mode 100644 index 00000000..50fb8186 --- /dev/null +++ b/dnn/src/arm_common/matrix_mul/exec_gemm_int8_int8_int16.h @@ -0,0 +1,26 @@ +/** + * \file dnn/src/arm_common/matrix_mul/exec_gemm_int8_int8_int16.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once +#include <cstddef> +#include <cstdint> + +namespace megdnn { +namespace arm_common { + +///! Row-major gemm +void exec_gemm_int8_int8_int16(const int8_t* A, const int8_t* B, int16_t* C, + size_t M, size_t K, size_t N, size_t LDB, + int8_t* w0, int8_t* w1); + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/matrix_mul/fp16/hgemv.cpp b/dnn/src/arm_common/matrix_mul/fp16/hgemv.cpp new file mode 100644 index 00000000..38b58165 --- /dev/null +++ b/dnn/src/arm_common/matrix_mul/fp16/hgemv.cpp @@ -0,0 +1,879 @@ +/** + * \file dnn/src/arm_common/matrix_mul/fp16/hgemv.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied.
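
packA and packB above implement the data layout that the (M + 8) * K and K * (N + 8) byte workspace buffers are sized for: A is stored as 8-row panels with one panel column per 8-byte group, the NEON path transposing an 8x8 block in registers with three rounds of vtrn (32-, 16-, then 8-bit lanes) and zero-padding the ragged last panel, while B is stored as 16- or 8-column panels so the inner GEMM loop can stream both operands sequentially. A scalar reference of the packed-A layout, which can be used to test the intrinsic path (an illustrative helper, not part of the patch):

#include <cstddef>
#include <cstdint>
#include <vector>

// Scalar model of packA(): for every 8-row panel of row-major A (M x K),
// emit K groups of 8 values, each group being one column of that panel.
// Rows past M are padded with zero, matching the kernel's tail handling.
std::vector<int8_t> pack_a_reference(const int8_t* A, size_t M, size_t K) {
    std::vector<int8_t> out;
    out.reserve(((M + 7) / 8) * 8 * K);
    for (size_t m = 0; m < M; m += 8)
        for (size_t k = 0; k < K; ++k)
            for (size_t i = 0; i < 8; ++i)
                out.push_back(m + i < M ? A[(m + i) * K + k] : 0);
    return out;
}
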
+ */ +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +#include "src/arm_common/matrix_mul/fp16/hgemv.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/unroll_macro.h" +#include "src/fallback/matrix_mul/gemm_common.h" + +namespace { + +#define UNROLL_OUT(cb, step) UNROLL_CALL_RAW(step, cb) + +void hgemv_naive_n(const __fp16* __restrict A, const __fp16* __restrict B, + __fp16* __restrict C, size_t M, size_t N, size_t K, + size_t Astride, size_t Bstride, size_t Cstride) { + megdnn_assert(N == 1 && Bstride == 1); +#define vaddvq_f16(v) \ + ((v)[0] + (v)[1] + (v)[2] + (v)[3] + (v)[4] + (v)[5] + (v)[6] + (v)[7]) + size_t m = 0; + for (; m + 4 <= M; m += 4) { + float16x8_t a0, a1, a2, a3, b0; + float16x8_t sum0, sum1, sum2, sum3; + sum0 = vdupq_n_f16(0.f); + sum1 = vdupq_n_f16(0.f); + sum2 = vdupq_n_f16(0.f); + sum3 = vdupq_n_f16(0.f); + size_t k = 0; + for (; k + 8 <= K; k += 8) { + a0 = vld1q_f16(A + (m + 0) * Astride + k); + a1 = vld1q_f16(A + (m + 1) * Astride + k); + a2 = vld1q_f16(A + (m + 2) * Astride + k); + a3 = vld1q_f16(A + (m + 3) * Astride + k); + b0 = vld1q_f16(B + k); + sum0 = vmlaq_f16(sum0, a0, b0); + sum1 = vmlaq_f16(sum1, a1, b0); + sum2 = vmlaq_f16(sum2, a2, b0); + sum3 = vmlaq_f16(sum3, a3, b0); + } + for (; k < K; ++k) { + sum0[0] += A[(m + 0) * Astride + k] * B[k]; + sum1[0] += A[(m + 1) * Astride + k] * B[k]; + sum2[0] += A[(m + 2) * Astride + k] * B[k]; + sum3[0] += A[(m + 3) * Astride + k] * B[k]; + } + C[(m + 0) * Cstride] = vaddvq_f16(sum0); + C[(m + 1) * Cstride] = vaddvq_f16(sum1); + C[(m + 2) * Cstride] = vaddvq_f16(sum2); + C[(m + 3) * Cstride] = vaddvq_f16(sum3); + } + for (; m + 2 <= M; m += 2) { + float16x8_t a0, a1, b0; + float16x8_t sum0, sum1; + sum0 = vdupq_n_f16(0.f); + sum1 = vdupq_n_f16(0.f); + size_t k = 0; + for (; k + 8 <= K; k += 8) { + a0 = vld1q_f16(A + (m + 0) * Astride + k); + a1 = vld1q_f16(A + (m + 1) * Astride + k); + b0 = vld1q_f16(B + k); + sum0 = vmlaq_f16(sum0, a0, b0); + sum1 = vmlaq_f16(sum1, a1, b0); + } + for (; k < K; ++k) { + sum0[0] += A[(m + 0) * Astride + k] * B[k]; + sum1[0] += A[(m + 1) * Astride + k] * B[k]; + } + C[(m + 0) * Cstride] = vaddvq_f16(sum0); + C[(m + 1) * Cstride] = vaddvq_f16(sum1); + } + for (; m < M; m += 1) { + float16x8_t a0, b0; + float16x8_t sum0; + sum0 = vdupq_n_f16(0.f); + size_t k = 0; + for (; k + 8 <= K; k += 8) { + a0 = vld1q_f16(A + (m + 0) * Astride + k); + b0 = vld1q_f16(B + k); + sum0 = vfmaq_f16(sum0, a0, b0); + } + for (; k < K; k += 1) { + sum0[0] += A[(m + 0) * Astride + k] * B[k]; + } + C[(m + 0) * Cstride] = vaddvq_f16(sum0); + } +#undef vaddvq_f16 +} +} // namespace + +void megdnn::arm_common::hgemv_exec(const __fp16* __restrict A, + const __fp16* __restrict B, + __fp16* __restrict C, size_t M, size_t N, + size_t K, size_t Astride, size_t Bstride, + size_t Cstride) { + megdnn_assert((M <= 4) || (M == 8 && K <= 2) || (N == 1 && Bstride == 1)); + if (N == 1) { + return hgemv_naive_n(A, B, C, M, N, K, Astride, Bstride, Cstride); + } + size_t m = 0; + for (; m + 4 <= M; m += 4) { + size_t k = 0; + memset(C + m * Cstride, 0, 4 * sizeof(__fp16) * N); + for (; k + 4 <= K; k += 4) { + size_t n = 0; + for (; n + 8 <= N; n += 8) { + float16x8_t a00, a01, a02, a03, a10, a11, a12, a13, a20, a21, + a22, a23, a30, a31, a32, a33; + float16x8_t b0, b1, b2, b3; + float16x8_t c0, c1, c2, c3; +#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n); +#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n); +#define loadA0(i) a0##i = vdupq_n_f16(A[(m + 0) * Astride + k + i]); 
+#define loadA1(i) a1##i = vdupq_n_f16(A[(m + 1) * Astride + k + i]); +#define loadA2(i) a2##i = vdupq_n_f16(A[(m + 2) * Astride + k + i]); +#define loadA3(i) a3##i = vdupq_n_f16(A[(m + 3) * Astride + k + i]); + UNROLL_OUT(loadC, 4) + UNROLL_OUT(loadB, 4) + UNROLL_OUT(loadA0, 4) + UNROLL_OUT(loadA1, 4) + UNROLL_OUT(loadA2, 4) + UNROLL_OUT(loadA3, 4) +#undef loadB +#undef loadC +#undef loadA0 +#undef loadA1 +#undef loadA2 +#undef loadA3 +#define calculate_row0(i) c0 = vmlaq_f16(c0, b##i, a0##i); +#define calculate_row1(i) c1 = vmlaq_f16(c1, b##i, a1##i); +#define calculate_row2(i) c2 = vmlaq_f16(c2, b##i, a2##i); +#define calculate_row3(i) c3 = vmlaq_f16(c3, b##i, a3##i); + UNROLL_OUT(calculate_row0, 4) + UNROLL_OUT(calculate_row1, 4) + UNROLL_OUT(calculate_row2, 4) + UNROLL_OUT(calculate_row3, 4) +#undef calculate_row0 +#undef calculate_row1 +#undef calculate_row2 +#undef calculate_row3 +#define vstore(i) vst1q_f16(C + (m + i) * Cstride + n, c##i); + UNROLL_OUT(vstore, 4) +#undef vstore + } + for (; n + 4 <= N; n += 4) { + float16x4_t a00, a01, a02, a03, a10, a11, a12, a13, a20, a21, + a22, a23, a30, a31, a32, a33; + float16x4_t b0, b1, b2, b3; + float16x4_t c0, c1, c2, c3; +#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n); +#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n); +#define loadA0(i) a0##i = vdup_n_f16(A[(m + 0) * Astride + k + i]); +#define loadA1(i) a1##i = vdup_n_f16(A[(m + 1) * Astride + k + i]); +#define loadA2(i) a2##i = vdup_n_f16(A[(m + 2) * Astride + k + i]); +#define loadA3(i) a3##i = vdup_n_f16(A[(m + 3) * Astride + k + i]); + UNROLL_OUT(loadC, 4) + UNROLL_OUT(loadB, 4) + UNROLL_OUT(loadA0, 4) + UNROLL_OUT(loadA1, 4) + UNROLL_OUT(loadA2, 4) + UNROLL_OUT(loadA3, 4) +#undef loadB +#undef loadC +#undef loadA0 +#undef loadA1 +#undef loadA2 +#undef loadA3 +#define calculate_row0(i) c0 = vfma_f16(c0, b##i, a0##i); +#define calculate_row1(i) c1 = vfma_f16(c1, b##i, a1##i); +#define calculate_row2(i) c2 = vfma_f16(c2, b##i, a2##i); +#define calculate_row3(i) c3 = vfma_f16(c3, b##i, a3##i); + UNROLL_OUT(calculate_row0, 4) + UNROLL_OUT(calculate_row1, 4) + UNROLL_OUT(calculate_row2, 4) + UNROLL_OUT(calculate_row3, 4) +#undef calculate_row0 +#undef calculate_row1 +#undef calculate_row2 +#undef calculate_row3 +#define vstore(i) vst1_f16(C + (m + i) * Cstride + n, c##i); + UNROLL_OUT(vstore, 4) +#undef vstore + } + for (; n < N; n += 1) { + __fp16 a00, a01, a02, a03, a10, a11, a12, a13, a20, a21, a22, + a23, a30, a31, a32, a33; + __fp16 b0, b1, b2, b3; + __fp16 c0, c1, c2, c3; +#define loadC(i) c##i = C[(m + i) * Cstride + n]; +#define loadB(i) b##i = B[(k + i) * Bstride + n]; + UNROLL_OUT(loadC, 4) + UNROLL_OUT(loadB, 4) +#undef loadB +#undef loadC +#define loadA0(i) a0##i = A[m * Astride + k + i]; +#define loadA1(i) a1##i = A[(m + 1) * Astride + k + i]; +#define loadA2(i) a2##i = A[(m + 2) * Astride + k + i]; +#define loadA3(i) a3##i = A[(m + 3) * Astride + k + i]; + UNROLL_OUT(loadA0, 4) + UNROLL_OUT(loadA1, 4) + UNROLL_OUT(loadA2, 4) + UNROLL_OUT(loadA3, 4) +#undef loadA0 +#undef loadA1 +#undef loadA2 +#undef loadA3 + c0 += a00 * b0 + a01 * b1 + a02 * b2 + a03 * b3; + c1 += a10 * b0 + a11 * b1 + a12 * b2 + a13 * b3; + c2 += a20 * b0 + a21 * b1 + a22 * b2 + a23 * b3; + c3 += a30 * b0 + a31 * b1 + a32 * b2 + a33 * b3; +#define vstore(i) C[(m + i) * Cstride + n] = c##i; + UNROLL_OUT(vstore, 4) +#undef vstore + } + } + for (; k + 2 <= K; k += 2) { + size_t n = 0; + for (; n + 8 <= N; n += 8) { + float16x8_t a00, a01, a10, a11, a20, a21, a30, a31; + float16x8_t 
b0, b1; + float16x8_t c0, c1, c2, c3; +#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n); +#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n); +#define loadA0(i) a0##i = vdupq_n_f16(A[(m + 0) * Astride + k + i]); +#define loadA1(i) a1##i = vdupq_n_f16(A[(m + 1) * Astride + k + i]); +#define loadA2(i) a2##i = vdupq_n_f16(A[(m + 2) * Astride + k + i]); +#define loadA3(i) a3##i = vdupq_n_f16(A[(m + 3) * Astride + k + i]); + UNROLL_OUT(loadC, 4) + UNROLL_OUT(loadB, 2) + UNROLL_OUT(loadA0, 2) + UNROLL_OUT(loadA1, 2) + UNROLL_OUT(loadA2, 2) + UNROLL_OUT(loadA3, 2) +#undef loadB +#undef loadC +#undef loadA0 +#undef loadA1 +#undef loadA2 +#undef loadA3 +#define calculate_row0(i) c0 = vmlaq_f16(c0, b##i, a0##i); +#define calculate_row1(i) c1 = vmlaq_f16(c1, b##i, a1##i); +#define calculate_row2(i) c2 = vmlaq_f16(c2, b##i, a2##i); +#define calculate_row3(i) c3 = vmlaq_f16(c3, b##i, a3##i); + UNROLL_OUT(calculate_row0, 2) + UNROLL_OUT(calculate_row1, 2) + UNROLL_OUT(calculate_row2, 2) + UNROLL_OUT(calculate_row3, 2) +#undef calculate_row0 +#undef calculate_row1 +#undef calculate_row2 +#undef calculate_row3 +#define vstore(i) vst1q_f16(C + (m + i) * Cstride + n, c##i); + UNROLL_OUT(vstore, 4) +#undef vstore + } + for (; n + 4 <= N; n += 4) { + float16x4_t a00, a01, a10, a11, a20, a21, a30, a31; + float16x4_t b0, b1; + float16x4_t c0, c1, c2, c3; +#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n); +#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n); +#define loadA0(i) a0##i = vdup_n_f16(A[(m + 0) * Astride + k + i]); +#define loadA1(i) a1##i = vdup_n_f16(A[(m + 1) * Astride + k + i]); +#define loadA2(i) a2##i = vdup_n_f16(A[(m + 2) * Astride + k + i]); +#define loadA3(i) a3##i = vdup_n_f16(A[(m + 3) * Astride + k + i]); + UNROLL_OUT(loadC, 4) + UNROLL_OUT(loadB, 2) + UNROLL_OUT(loadA0, 2) + UNROLL_OUT(loadA1, 2) + UNROLL_OUT(loadA2, 2) + UNROLL_OUT(loadA3, 2) +#undef loadB +#undef loadC +#undef loadA0 +#undef loadA1 +#undef loadA2 +#undef loadA3 +#define calculate_row0(i) c0 = vfma_f16(c0, b##i, a0##i); +#define calculate_row1(i) c1 = vfma_f16(c1, b##i, a1##i); +#define calculate_row2(i) c2 = vfma_f16(c2, b##i, a2##i); +#define calculate_row3(i) c3 = vfma_f16(c3, b##i, a3##i); + UNROLL_OUT(calculate_row0, 2) + UNROLL_OUT(calculate_row1, 2) + UNROLL_OUT(calculate_row2, 2) + UNROLL_OUT(calculate_row3, 2) +#undef calculate_row0 +#undef calculate_row1 +#undef calculate_row2 +#undef calculate_row3 +#define vstore(i) vst1_f16(C + (m + i) * Cstride + n, c##i); + UNROLL_OUT(vstore, 4) +#undef vstore + } + for (; n < N; n += 1) { + __fp16 a00, a01, a10, a11, a20, a21, a30, a31; + __fp16 b0, b1; + __fp16 c0, c1, c2, c3; +#define loadC(i) c##i = C[(m + i) * Cstride + n]; +#define loadB(i) b##i = B[(k + i) * Bstride + n]; + UNROLL_OUT(loadC, 4) + UNROLL_OUT(loadB, 2) +#undef loadB +#undef loadC +#define loadA0(i) a0##i = A[(m + 0) * Astride + k + i]; +#define loadA1(i) a1##i = A[(m + 1) * Astride + k + i]; +#define loadA2(i) a2##i = A[(m + 2) * Astride + k + i]; +#define loadA3(i) a3##i = A[(m + 3) * Astride + k + i]; + UNROLL_OUT(loadA0, 2) + UNROLL_OUT(loadA1, 2) + UNROLL_OUT(loadA2, 2) + UNROLL_OUT(loadA3, 2) +#undef loadA0 +#undef loadA1 +#undef loadA2 +#undef loadA3 + c0 += a00 * b0 + a01 * b1; + c1 += a10 * b0 + a11 * b1; + c2 += a20 * b0 + a21 * b1; + c3 += a30 * b0 + a31 * b1; +#define vstore(i) C[(m + i) * Cstride + n] = c##i; + UNROLL_OUT(vstore, 4) +#undef vstore + } + } + for (; k < K; k += 1) { + size_t n = 0; + for (; n + 8 <= N; n += 8) { + float16x8_t a00, 
a10, a20, a30; + float16x8_t b0; + float16x8_t c0, c1, c2, c3; +#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n); +#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n); +#define loadA0(i) a0##i = vdupq_n_f16(A[(m + 0) * Astride + k + i]); +#define loadA1(i) a1##i = vdupq_n_f16(A[(m + 1) * Astride + k + i]); +#define loadA2(i) a2##i = vdupq_n_f16(A[(m + 2) * Astride + k + i]); +#define loadA3(i) a3##i = vdupq_n_f16(A[(m + 3) * Astride + k + i]); + UNROLL_OUT(loadC, 4) + UNROLL_OUT(loadB, 1) + UNROLL_OUT(loadA0, 1) + UNROLL_OUT(loadA1, 1) + UNROLL_OUT(loadA2, 1) + UNROLL_OUT(loadA3, 1) +#undef loadB +#undef loadC +#undef loadA0 +#undef loadA1 +#undef loadA2 +#undef loadA3 +#define calculate_row0(i) c0 = vmlaq_f16(c0, b##i, a0##i); +#define calculate_row1(i) c1 = vmlaq_f16(c1, b##i, a1##i); +#define calculate_row2(i) c2 = vmlaq_f16(c2, b##i, a2##i); +#define calculate_row3(i) c3 = vmlaq_f16(c3, b##i, a3##i); + UNROLL_OUT(calculate_row0, 1) + UNROLL_OUT(calculate_row1, 1) + UNROLL_OUT(calculate_row2, 1) + UNROLL_OUT(calculate_row3, 1) +#undef calculate_row0 +#undef calculate_row1 +#undef calculate_row2 +#undef calculate_row3 +#define vstore(i) vst1q_f16(C + (m + i) * Cstride + n, c##i); + UNROLL_OUT(vstore, 4) +#undef vstore + } + for (; n + 4 <= N; n += 4) { + float16x4_t a00, a10, a20, a30; + float16x4_t b0; + float16x4_t c0, c1, c2, c3; +#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n); +#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n); +#define loadA0(i) a0##i = vdup_n_f16(A[(m + 0) * Astride + k + i]); +#define loadA1(i) a1##i = vdup_n_f16(A[(m + 1) * Astride + k + i]); +#define loadA2(i) a2##i = vdup_n_f16(A[(m + 2) * Astride + k + i]); +#define loadA3(i) a3##i = vdup_n_f16(A[(m + 3) * Astride + k + i]); + UNROLL_OUT(loadC, 4) + UNROLL_OUT(loadB, 1) + UNROLL_OUT(loadA0, 1) + UNROLL_OUT(loadA1, 1) + UNROLL_OUT(loadA2, 1) + UNROLL_OUT(loadA3, 1) +#undef loadB +#undef loadC +#undef loadA0 +#undef loadA1 +#undef loadA2 +#undef loadA3 +#define calculate_row0(i) c0 = vfma_f16(c0, b##i, a0##i); +#define calculate_row1(i) c1 = vfma_f16(c1, b##i, a1##i); +#define calculate_row2(i) c2 = vfma_f16(c2, b##i, a2##i); +#define calculate_row3(i) c3 = vfma_f16(c3, b##i, a3##i); + UNROLL_OUT(calculate_row0, 1) + UNROLL_OUT(calculate_row1, 1) + UNROLL_OUT(calculate_row2, 1) + UNROLL_OUT(calculate_row3, 1) +#undef calculate_row0 +#undef calculate_row1 +#undef calculate_row2 +#undef calculate_row3 +#define vstore(i) vst1_f16(C + (m + i) * Cstride + n, c##i); + UNROLL_OUT(vstore, 4) +#undef vstore + } + for (; n < N; n += 1) { + __fp16 a00, a10, a20, a30; + __fp16 b0; + __fp16 c0, c1, c2, c3; +#define loadC(i) c##i = C[(m + i) * Cstride + n]; +#define loadB(i) b##i = B[(k + i) * Bstride + n]; + UNROLL_OUT(loadC, 4) + UNROLL_OUT(loadB, 1) +#undef loadB +#undef loadC +#define loadA0(i) a0##i = A[(m + 0) * Astride + k + i]; +#define loadA1(i) a1##i = A[(m + 1) * Astride + k + i]; +#define loadA2(i) a2##i = A[(m + 2) * Astride + k + i]; +#define loadA3(i) a3##i = A[(m + 3) * Astride + k + i]; + UNROLL_OUT(loadA0, 1) + UNROLL_OUT(loadA1, 1) + UNROLL_OUT(loadA2, 1) + UNROLL_OUT(loadA3, 1) +#undef loadA0 +#undef loadA1 +#undef loadA2 +#undef loadA3 + c0 = c0 + a00 * b0; + c1 = c1 + a10 * b0; + c2 = c2 + a20 * b0; + c3 = c3 + a30 * b0; +#define vstore(i) C[(m + i) * Cstride + n] = c##i; + UNROLL_OUT(vstore, 4) +#undef vstore + } + } + } + for (; m + 2 <= M; m += 2) { + size_t k = 0; + memset(C + m * Cstride, 0, 2 * sizeof(__fp16) * N); + for (; k + 4 <= K; k += 4) { + size_t 
n = 0; + for (; n + 8 <= N; n += 8) { + float16x8_t a00, a01, a02, a03, a10, a11, a12, a13; + float16x8_t b0, b1, b2, b3; + float16x8_t c0, c1; +#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n); +#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n); +#define loadA0(i) a0##i = vdupq_n_f16(A[(m + 0) * Astride + k + i]); +#define loadA1(i) a1##i = vdupq_n_f16(A[(m + 1) * Astride + k + i]); + UNROLL_OUT(loadC, 2) + UNROLL_OUT(loadB, 4) + UNROLL_OUT(loadA0, 4) + UNROLL_OUT(loadA1, 4) +#undef loadB +#undef loadC +#undef loadA0 +#undef loadA1 +#define calculate_row0(i) c0 = vmlaq_f16(c0, b##i, a0##i); +#define calculate_row1(i) c1 = vmlaq_f16(c1, b##i, a1##i); + UNROLL_OUT(calculate_row0, 4) + UNROLL_OUT(calculate_row1, 4) +#undef calculate_row0 +#undef calculate_row1 +#define vstore(i) vst1q_f16(C + (m + i) * Cstride + n, c##i); + UNROLL_OUT(vstore, 2) +#undef vstore + } + for (; n + 4 <= N; n += 4) { + float16x4_t a00, a01, a02, a03, a10, a11, a12, a13; + float16x4_t b0, b1, b2, b3; + float16x4_t c0, c1; +#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n); +#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n); +#define loadA0(i) a0##i = vdup_n_f16(A[(m + 0) * Astride + k + i]); +#define loadA1(i) a1##i = vdup_n_f16(A[(m + 1) * Astride + k + i]); + UNROLL_OUT(loadC, 2) + UNROLL_OUT(loadB, 4) + UNROLL_OUT(loadA0, 4) + UNROLL_OUT(loadA1, 4) +#undef loadB +#undef loadC +#undef loadA0 +#undef loadA1 +#define calculate_row0(i) c0 = vfma_f16(c0, b##i, a0##i); +#define calculate_row1(i) c1 = vfma_f16(c1, b##i, a1##i); + UNROLL_OUT(calculate_row0, 4) + UNROLL_OUT(calculate_row1, 4) +#undef calculate_row0 +#undef calculate_row1 +#define vstore(i) vst1_f16(C + (m + i) * Cstride + n, c##i); + UNROLL_OUT(vstore, 2) +#undef vstore + } + for (; n < N; n += 1) { + __fp16 a00, a01, a02, a03, a10, a11, a12, a13; + __fp16 b0, b1, b2, b3; + __fp16 c0, c1; +#define loadC(i) c##i = C[(m + i) * Cstride + n]; +#define loadB(i) b##i = B[(k + i) * Bstride + n]; + UNROLL_OUT(loadC, 2) + UNROLL_OUT(loadB, 4) +#undef loadB +#undef loadC +#define loadA0(i) a0##i = A[m * Astride + k + i]; +#define loadA1(i) a1##i = A[(m + 1) * Astride + k + i]; + UNROLL_OUT(loadA0, 4) + UNROLL_OUT(loadA1, 4) +#undef loadA0 +#undef loadA1 + c0 += a00 * b0 + a01 * b1 + a02 * b2 + a03 * b3; + c1 += a10 * b0 + a11 * b1 + a12 * b2 + a13 * b3; +#define vstore(i) C[(m + i) * Cstride + n] = c##i; + UNROLL_OUT(vstore, 2) +#undef vstore + } + } + for (; k + 2 <= K; k += 2) { + size_t n = 0; + for (; n + 8 <= N; n += 8) { + float16x8_t a00, a01, a10, a11; + float16x8_t b0, b1; + float16x8_t c0, c1; +#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n); +#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n); +#define loadA0(i) a0##i = vdupq_n_f16(A[(m + 0) * Astride + k + i]); +#define loadA1(i) a1##i = vdupq_n_f16(A[(m + 1) * Astride + k + i]); + UNROLL_OUT(loadC, 2) + UNROLL_OUT(loadB, 2) + UNROLL_OUT(loadA0, 2) + UNROLL_OUT(loadA1, 2) +#undef loadB +#undef loadC +#undef loadA0 +#undef loadA1 +#define calculate_row0(i) c0 = vmlaq_f16(c0, b##i, a0##i); +#define calculate_row1(i) c1 = vmlaq_f16(c1, b##i, a1##i); + UNROLL_OUT(calculate_row0, 2) + UNROLL_OUT(calculate_row1, 2) +#undef calculate_row0 +#undef calculate_row1 +#define vstore(i) vst1q_f16(C + (m + i) * Cstride + n, c##i); + UNROLL_OUT(vstore, 2) +#undef vstore + } + for (; n + 4 <= N; n += 4) { + float16x4_t a00, a01, a10, a11; + float16x4_t b0, b1; + float16x4_t c0, c1; +#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n); +#define 
loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n); +#define loadA0(i) a0##i = vdup_n_f16(A[(m + 0) * Astride + k + i]); +#define loadA1(i) a1##i = vdup_n_f16(A[(m + 1) * Astride + k + i]); + UNROLL_OUT(loadC, 2) + UNROLL_OUT(loadB, 2) + UNROLL_OUT(loadA0, 2) + UNROLL_OUT(loadA1, 2) +#undef loadB +#undef loadC +#undef loadA0 +#undef loadA1 +#define calculate_row0(i) c0 = vfma_f16(c0, b##i, a0##i); +#define calculate_row1(i) c1 = vfma_f16(c1, b##i, a1##i); + UNROLL_OUT(calculate_row0, 2) + UNROLL_OUT(calculate_row1, 2) +#undef calculate_row0 +#undef calculate_row1 +#define vstore(i) vst1_f16(C + (m + i) * Cstride + n, c##i); + UNROLL_OUT(vstore, 2) +#undef vstore + } + for (; n < N; n += 1) { + __fp16 a00, a01, a10, a11; + __fp16 b0, b1; + __fp16 c0, c1; +#define loadC(i) c##i = C[(m + i) * Cstride + n]; +#define loadB(i) b##i = B[(k + i) * Bstride + n]; + UNROLL_OUT(loadC, 2) + UNROLL_OUT(loadB, 2) +#undef loadB +#undef loadC +#define loadA0(i) a0##i = A[(m + 0) * Astride + k + i]; +#define loadA1(i) a1##i = A[(m + 1) * Astride + k + i]; + UNROLL_OUT(loadA0, 2) + UNROLL_OUT(loadA1, 2) +#undef loadA0 +#undef loadA1 + c0 += a00 * b0 + a01 * b1; + c1 += a10 * b0 + a11 * b1; +#define vstore(i) C[(m + i) * Cstride + n] = c##i; + UNROLL_OUT(vstore, 2) +#undef vstore + } + } + for (; k < K; k += 1) { + size_t n = 0; + for (; n + 8 <= N; n += 8) { + float16x8_t a00, a10; + float16x8_t b0; + float16x8_t c0, c1; +#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n); +#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n); +#define loadA0(i) a0##i = vdupq_n_f16(A[(m + 0) * Astride + k + i]); +#define loadA1(i) a1##i = vdupq_n_f16(A[(m + 1) * Astride + k + i]); + UNROLL_OUT(loadC, 2) + UNROLL_OUT(loadB, 1) + UNROLL_OUT(loadA0, 1) + UNROLL_OUT(loadA1, 1) +#undef loadB +#undef loadC +#undef loadA0 +#undef loadA1 +#define calculate_row0(i) c0 = vmlaq_f16(c0, b##i, a0##i); +#define calculate_row1(i) c1 = vmlaq_f16(c1, b##i, a1##i); + UNROLL_OUT(calculate_row0, 1) + UNROLL_OUT(calculate_row1, 1) +#undef calculate_row0 +#undef calculate_row1 +#define vstore(i) vst1q_f16(C + (m + i) * Cstride + n, c##i); + UNROLL_OUT(vstore, 2) +#undef vstore + } + for (; n + 4 <= N; n += 4) { + float16x4_t a00, a10; + float16x4_t b0; + float16x4_t c0, c1; +#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n); +#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n); +#define loadA0(i) a0##i = vdup_n_f16(A[(m + 0) * Astride + k + i]); +#define loadA1(i) a1##i = vdup_n_f16(A[(m + 1) * Astride + k + i]); + UNROLL_OUT(loadC, 2) + UNROLL_OUT(loadB, 1) + UNROLL_OUT(loadA0, 1) + UNROLL_OUT(loadA1, 1) +#undef loadB +#undef loadC +#undef loadA0 +#undef loadA1 +#define calculate_row0(i) c0 = vfma_f16(c0, b##i, a0##i); +#define calculate_row1(i) c1 = vfma_f16(c1, b##i, a1##i); + UNROLL_OUT(calculate_row0, 1) + UNROLL_OUT(calculate_row1, 1) +#undef calculate_row0 +#undef calculate_row1 +#define vstore(i) vst1_f16(C + (m + i) * Cstride + n, c##i); + UNROLL_OUT(vstore, 2) +#undef vstore + } + for (; n < N; n += 1) { + __fp16 a00, a10; + __fp16 b0; + __fp16 c0, c1; +#define loadC(i) c##i = C[(m + i) * Cstride + n]; +#define loadB(i) b##i = B[(k + i) * Bstride + n]; + UNROLL_OUT(loadC, 2) + UNROLL_OUT(loadB, 1) +#undef loadB +#undef loadC +#define loadA0(i) a0##i = A[(m + 0) * Astride + k + i]; +#define loadA1(i) a1##i = A[(m + 1) * Astride + k + i]; + UNROLL_OUT(loadA0, 1) + UNROLL_OUT(loadA1, 1) +#undef loadA0 +#undef loadA1 + c0 = c0 + a00 * b0; + c1 = c1 + a10 * b0; +#define vstore(i) C[(m + i) * Cstride + n] 
= c##i; + UNROLL_OUT(vstore, 2) +#undef vstore + } + } + } + for (; m < M; m += 1) { + size_t k = 0; + memset(C + m * Cstride, 0, sizeof(__fp16) * N); + for (; k + 4 <= K; k += 4) { + size_t n = 0; + for (; n + 8 <= N; n += 8) { + float16x8_t a00, a01, a02, a03; + float16x8_t b0, b1, b2, b3; + float16x8_t c0; +#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n); +#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n); +#define loadA0(i) a0##i = vdupq_n_f16(A[(m + 0) * Astride + k + i]); + UNROLL_OUT(loadC, 1) + UNROLL_OUT(loadB, 4) + UNROLL_OUT(loadA0, 4) +#undef loadB +#undef loadC +#undef loadA0 +#define calculate_row0(i) c0 = vmlaq_f16(c0, b##i, a0##i); + UNROLL_OUT(calculate_row0, 4) +#undef calculate_row0 +#define vstore(i) vst1q_f16(C + (m + i) * Cstride + n, c##i); + UNROLL_OUT(vstore, 1) +#undef vstore + } + for (; n + 4 <= N; n += 4) { + float16x4_t a00, a01, a02, a03; + float16x4_t b0, b1, b2, b3; + float16x4_t c0; +#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n); +#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n); +#define loadA0(i) a0##i = vdup_n_f16(A[(m + 0) * Astride + k + i]); + UNROLL_OUT(loadC, 1) + UNROLL_OUT(loadB, 4) + UNROLL_OUT(loadA0, 4) +#undef loadB +#undef loadC +#undef loadA0 +#define calculate_row0(i) c0 = vfma_f16(c0, b##i, a0##i); + UNROLL_OUT(calculate_row0, 4) +#undef calculate_row0 +#define vstore(i) vst1_f16(C + (m + i) * Cstride + n, c##i); + UNROLL_OUT(vstore, 1) +#undef vstore + } + for (; n < N; n += 1) { + __fp16 a00, a01, a02, a03; + __fp16 b0, b1, b2, b3; + __fp16 c0; +#define loadC(i) c##i = C[(m + i) * Cstride + n]; +#define loadB(i) b##i = B[(k + i) * Bstride + n]; + UNROLL_OUT(loadC, 1) + UNROLL_OUT(loadB, 4) +#undef loadB +#undef loadC +#define loadA0(i) a0##i = A[m * Astride + k + i]; + UNROLL_OUT(loadA0, 4) +#undef loadA0 + c0 += a00 * b0 + a01 * b1 + a02 * b2 + a03 * b3; +#define vstore(i) C[(m + i) * Cstride + n] = c##i; + UNROLL_OUT(vstore, 1) +#undef vstore + } + } + for (; k + 2 <= K; k += 2) { + size_t n = 0; + for (; n + 8 <= N; n += 8) { + float16x8_t a00, a01; + float16x8_t b0, b1; + float16x8_t c0; +#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n); +#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n); +#define loadA0(i) a0##i = vdupq_n_f16(A[(m + 0) * Astride + k + i]); + UNROLL_OUT(loadC, 1) + UNROLL_OUT(loadB, 2) + UNROLL_OUT(loadA0, 2) +#undef loadB +#undef loadC +#undef loadA0 +#define calculate_row0(i) c0 = vmlaq_f16(c0, b##i, a0##i); + UNROLL_OUT(calculate_row0, 2) +#undef calculate_row0 +#define vstore(i) vst1q_f16(C + (m + i) * Cstride + n, c##i); + UNROLL_OUT(vstore, 1) +#undef vstore + } + for (; n + 4 <= N; n += 4) { + float16x4_t a00, a01; + float16x4_t b0, b1; + float16x4_t c0; +#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n); +#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n); +#define loadA0(i) a0##i = vdup_n_f16(A[(m + 0) * Astride + k + i]); + UNROLL_OUT(loadC, 1) + UNROLL_OUT(loadB, 2) + UNROLL_OUT(loadA0, 2) +#undef loadB +#undef loadC +#undef loadA0 +#define calculate_row0(i) c0 = vfma_f16(c0, b##i, a0##i); + UNROLL_OUT(calculate_row0, 2) +#undef calculate_row0 +#define vstore(i) vst1_f16(C + (m + i) * Cstride + n, c##i); + UNROLL_OUT(vstore, 1) +#undef vstore + } + for (; n < N; n += 1) { + __fp16 a00, a01; + __fp16 b0, b1; + __fp16 c0; +#define loadC(i) c##i = C[(m + i) * Cstride + n]; +#define loadB(i) b##i = B[(k + i) * Bstride + n]; + UNROLL_OUT(loadC, 1) + UNROLL_OUT(loadB, 2) +#undef loadB +#undef loadC +#define loadA0(i) 
a0##i = A[(m + 0) * Astride + k + i]; + UNROLL_OUT(loadA0, 2) +#undef loadA0 + c0 += a00 * b0 + a01 * b1; +#define vstore(i) C[(m + i) * Cstride + n] = c##i; + UNROLL_OUT(vstore, 1) +#undef vstore + } + } + for (; k < K; k += 1) { + size_t n = 0; + for (; n + 8 <= N; n += 8) { + float16x8_t a00; + float16x8_t b0; + float16x8_t c0; +#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n); +#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n); +#define loadA0(i) a0##i = vdupq_n_f16(A[(m + 0) * Astride + k + i]); + UNROLL_OUT(loadC, 1) + UNROLL_OUT(loadB, 1) + UNROLL_OUT(loadA0, 1) +#undef loadB +#undef loadC +#undef loadA0 +#define calculate_row0(i) c0 = vmlaq_f16(c0, b##i, a0##i); + UNROLL_OUT(calculate_row0, 1) +#undef calculate_row0 +#define vstore(i) vst1q_f16(C + (m + i) * Cstride + n, c##i); + UNROLL_OUT(vstore, 1) +#undef vstore + } + for (; n + 4 <= N; n += 4) { + float16x4_t a00; + float16x4_t b0; + float16x4_t c0; +#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n); +#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n); +#define loadA0(i) a0##i = vdup_n_f16(A[(m + 0) * Astride + k + i]); + UNROLL_OUT(loadC, 1) + UNROLL_OUT(loadB, 1) + UNROLL_OUT(loadA0, 1) +#undef loadB +#undef loadC +#undef loadA0 +#define calculate_row0(i) c0 = vfma_f16(c0, b##i, a0##i); + UNROLL_OUT(calculate_row0, 1) +#undef calculate_row0 +#define vstore(i) vst1_f16(C + (m + i) * Cstride + n, c##i); + UNROLL_OUT(vstore, 1) +#undef vstore + } + for (; n < N; n += 1) { + __fp16 a00; + __fp16 b0; + __fp16 c0; +#define loadC(i) c##i = C[(m + i) * Cstride + n]; +#define loadB(i) b##i = B[(k + i) * Bstride + n]; + UNROLL_OUT(loadC, 1) + UNROLL_OUT(loadB, 1) +#undef loadB +#undef loadC +#define loadA0(i) a0##i = A[(m + 0) * Astride + k + i]; + UNROLL_OUT(loadA0, 1) +#undef loadA0 + c0 = c0 + a00 * b0; +#define vstore(i) C[(m + i) * Cstride + n] = c##i; + UNROLL_OUT(vstore, 1) +#undef vstore + } + } + } +} +bool megdnn::arm_common::is_hgemv_preferred(bool transposeA, bool transposeB, + size_t M, size_t N, size_t K, + size_t /*LDA*/, size_t LDB, + size_t /*LDC*/) { + if (transposeA) + return false; + if (transposeB) + return false; + + return M <= 4 || (M <= 8 && K <= 2) || (N == 1 && LDB == 1); +} + +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/matrix_mul/fp16/hgemv.h b/dnn/src/arm_common/matrix_mul/fp16/hgemv.h new file mode 100644 index 00000000..8a2fa952 --- /dev/null +++ b/dnn/src/arm_common/matrix_mul/fp16/hgemv.h @@ -0,0 +1,30 @@ +/** + * \file dnn/src/arm_common/matrix_mul/fp16/hgemv.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
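
hgemv_exec above treats N == 1 as a batch of fp16 dot products: four rows of A at a time, eight halves per vmlaq_f16 step, then a horizontal lane sum (the file defines its own vaddvq_f16 macro because that intrinsic is not universally available), while the general-N path keeps a 4x8 tile of C in registers, broadcasting A elements with vdupq_n_f16 and reloading B rows. Everything is compiled only when __ARM_FEATURE_FP16_VECTOR_ARITHMETIC is defined. A single-row sketch of the N == 1 kernel under the same guard (not the code above, just the same idea):

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#include <arm_neon.h>
#include <cstddef>

// dot(a, b) over K fp16 elements: 8-wide fused multiply-add, a horizontal
// lane sum, and a scalar loop for the K % 8 tail.
static __fp16 dot_f16(const __fp16* a, const __fp16* b, size_t K) {
    float16x8_t acc = vdupq_n_f16(0);
    size_t k = 0;
    for (; k + 8 <= K; k += 8)
        acc = vfmaq_f16(acc, vld1q_f16(a + k), vld1q_f16(b + k));
    __fp16 sum = 0;
    for (int lane = 0; lane < 8; ++lane)
        sum += acc[lane];  // same idea as the file's vaddvq_f16 macro
    for (; k < K; ++k)
        sum += a[k] * b[k];
    return sum;
}
#endif
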
+ */ +#pragma once + +#include +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +namespace megdnn { +namespace arm_common { + +void hgemv_exec(const __fp16* __restrict A, const __fp16* __restrict B, + __fp16* __restrict C, size_t M, size_t N, size_t K, + size_t Astride, size_t Bstride, size_t Cstride); + +bool is_hgemv_preferred(bool transposeA, bool transposeB, size_t M, size_t N, + size_t K, size_t /*LDA*/, size_t LDB, size_t /*LDC*/); + +} // namespace aarch64 +} // namespace megdnn + +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/matrix_mul/fp32/exec_sgemv.cpp b/dnn/src/arm_common/matrix_mul/fp32/exec_sgemv.cpp new file mode 100644 index 00000000..31e0f535 --- /dev/null +++ b/dnn/src/arm_common/matrix_mul/fp32/exec_sgemv.cpp @@ -0,0 +1,802 @@ +/** + * \file dnn/src/arm_common/matrix_mul/fp32/exec_sgemv.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "src/arm_common/matrix_mul/fp32/exec_sgemv.h" +#include +#include "include/megdnn/oprs.h" +#include "midout.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/unroll_macro.h" +#include "src/common/utils.h" + +MIDOUT_DECL(megdnn_fp32_sgemv) + +using namespace megdnn; +using namespace arm_common; + +namespace { + +#define UNROLL_OUT(cb, step) UNROLL_CALL_RAW(step, cb) + +#if !defined(__aarch64__) +#define vaddvq_f32(v) (v)[0] + (v)[1] + (v)[2] + (v)[3] +#endif +void sgemv_naive_n(const float* __restrict A, const float* __restrict B, + float* __restrict C, size_t M, size_t N, size_t K, + size_t Astride, size_t Bstride, size_t Cstride) { + megdnn_assert(N == 1 && Bstride == 1); +#define reset_acc(i) acc##i = 0; +#define acc_calu(i) acc##i += A[(m + i) * Astride + k] * B[k]; +#define vdupq_sum(i) sum##i = vdupq_n_f32(0.f); +#define loadA(i) a##i = vld1q_f32(A + (m + i) * Astride + k); +#define loadB(i) b##i = vld1q_f32(B + k); +#define calculate(i) sum##i = vmlaq_f32(sum##i, a##i, b0); +#define vstore(i) C[(m + i) * Cstride] = vaddvq_f32(sum##i) + acc##i; + size_t m = 0; + for (; m + 4 <= M; m += 4) { + float acc0, acc1, acc2, acc3; + float32x4_t a0, a1, a2, a3, b0; + float32x4_t sum0, sum1, sum2, sum3; + UNROLL_OUT(vdupq_sum, 4) + size_t k = 0; + for (; k + 4 <= K; k += 4) { + UNROLL_OUT(loadA, 4) + UNROLL_OUT(loadB, 1) + UNROLL_OUT(calculate, 4) + } + UNROLL_OUT(reset_acc, 4) + for (; k < K; ++k) { + UNROLL_OUT(acc_calu, 4) + } + UNROLL_OUT(vstore, 4) + } + for (; m + 2 <= M; m += 2) { + float acc0, acc1; + float32x4_t a0, a1, b0; + float32x4_t sum0, sum1; + UNROLL_OUT(vdupq_sum, 2) + size_t k = 0; + for (; k + 4 <= K; k += 4) { + UNROLL_OUT(loadA, 2) + UNROLL_OUT(loadB, 1) + UNROLL_OUT(calculate, 2) + } + UNROLL_OUT(reset_acc, 2) + for (; k < K; ++k) { + UNROLL_OUT(acc_calu, 2) + } + UNROLL_OUT(vstore, 2) + } + for (; m < M; m += 1) { + float acc0; + float32x4_t a0, b0; + float32x4_t sum0; + UNROLL_OUT(vdupq_sum, 1) + size_t k = 0; + for (; k + 4 <= K; k += 4) { + UNROLL_OUT(loadA, 1) + UNROLL_OUT(loadB, 1) + UNROLL_OUT(calculate, 1) + } + UNROLL_OUT(reset_acc, 1) + for (; k < K; ++k) { + UNROLL_OUT(acc_calu, 1) + } + UNROLL_OUT(vstore, 1) + } +#undef vdupq_sum +#undef loadA +#undef loadB +#undef calculate +#undef vstore +} +#if !defined(__aarch64__) +#undef 
vaddvq_f32 +#endif +} // namespace + +namespace megdnn { +namespace arm_common { + +void sgemm_sgemv_like(const float* __restrict A, const float* __restrict B, + float* __restrict C, size_t M, size_t N, size_t K, + size_t Astride, size_t Bstride, size_t Cstride) { + megdnn_assert(M < 8 || (M == 8 && K <= 2) || (N == 1 && Bstride == 1)); + if (N == 1) { + return sgemv_naive_n(A, B, C, M, N, K, Astride, Bstride, Cstride); + } + size_t m = 0; + for (; m + 4 <= M; m += 4) { + size_t k = 0; + memset(C + m * Cstride, 0, 4 * sizeof(float) * N); + for (; k + 4 <= K; k += 4) { + size_t n = 0; + for (; n + 4 <= N; n += 4) { + float32x4_t a00, a01, a02, a03, a10, a11, a12, a13, a20, a21, + a22, a23, a30, a31, a32, a33; + float32x4_t b0, b1, b2, b3; + float32x4_t c0, c1, c2, c3; +#define loadB(i) b##i = vld1q_f32(B + (k + i) * Bstride + n); +#define loadC(i) c##i = vld1q_f32(C + (m + i) * Cstride + n); +#define loadA0(i) a0##i = vdupq_n_f32(A[(m + 0) * Astride + k + i]); +#define loadA1(i) a1##i = vdupq_n_f32(A[(m + 1) * Astride + k + i]); +#define loadA2(i) a2##i = vdupq_n_f32(A[(m + 2) * Astride + k + i]); +#define loadA3(i) a3##i = vdupq_n_f32(A[(m + 3) * Astride + k + i]); + UNROLL_OUT(loadC, 4) + UNROLL_OUT(loadB, 4) + UNROLL_OUT(loadA0, 4) + UNROLL_OUT(loadA1, 4) + UNROLL_OUT(loadA2, 4) + UNROLL_OUT(loadA3, 4) +#undef loadB +#undef loadC +#undef loadA0 +#undef loadA1 +#undef loadA2 +#undef loadA3 +#define calculate_row0(i) c0 = vmlaq_f32(c0, b##i, a0##i); +#define calculate_row1(i) c1 = vmlaq_f32(c1, b##i, a1##i); +#define calculate_row2(i) c2 = vmlaq_f32(c2, b##i, a2##i); +#define calculate_row3(i) c3 = vmlaq_f32(c3, b##i, a3##i); + UNROLL_OUT(calculate_row0, 4) + UNROLL_OUT(calculate_row1, 4) + UNROLL_OUT(calculate_row2, 4) + UNROLL_OUT(calculate_row3, 4) +#undef calculate_row0 +#undef calculate_row1 +#undef calculate_row2 +#undef calculate_row3 +#define vstore(i) vst1q_f32(C + (m + i) * Cstride + n, c##i); + UNROLL_OUT(vstore, 4) +#undef vstore + } + for (; n + 2 <= N; n += 2) { + float32x4_t a0, a1, a2, a3; + float32x2_t b0, b1, b2, b3; + float32x2_t c0, c1, c2, c3; +#define loadA(i) a##i = vld1q_f32(A + (m + i) * Astride + k); +#define loadB(i) b##i = vld1_f32(B + (k + i) * Bstride + n); +#define loadC(i) c##i = vld1_f32(C + (m + i) * Cstride + n); + UNROLL_OUT(loadC, 4) + UNROLL_OUT(loadA, 4) + UNROLL_OUT(loadB, 4) +#undef loadA +#undef loadB +#undef loadC +#define calculateB0(i) c##i = vmla_lane_f32(c##i, b0, vget_low_f32(a##i), 0); +#define calculateB1(i) c##i = vmla_lane_f32(c##i, b1, vget_low_f32(a##i), 1); +#define calculateB2(i) c##i = vmla_lane_f32(c##i, b2, vget_high_f32(a##i), 0); +#define calculateB3(i) c##i = vmla_lane_f32(c##i, b3, vget_high_f32(a##i), 1); + UNROLL_OUT(calculateB0, 4) + UNROLL_OUT(calculateB1, 4) + UNROLL_OUT(calculateB2, 4) + UNROLL_OUT(calculateB3, 4) +#undef calculateB0 +#undef calculateB1 +#undef calculateB2 +#undef calculateB3 +#define vstore(i) vst1_f32(C + (m + i) * Cstride + n, c##i); + UNROLL_OUT(vstore, 4) +#undef vstore + } + for (; n < N; n += 1) { + float a00, a01, a02, a03, a10, a11, a12, a13, a20, a21, a22, + a23, a30, a31, a32, a33; + float b0, b1, b2, b3; + float c0, c1, c2, c3; +#define loadC(i) c##i = C[(m + i) * Cstride + n]; +#define loadB(i) b##i = B[(k + i) * Bstride + n]; + UNROLL_OUT(loadC, 4) + UNROLL_OUT(loadB, 4) +#undef loadB +#undef loadC +#define loadA0(i) a0##i = A[m * Astride + k + i]; +#define loadA1(i) a1##i = A[(m + 1) * Astride + k + i]; +#define loadA2(i) a2##i = A[(m + 2) * Astride + k + i]; +#define loadA3(i) a3##i = A[(m 
+ 3) * Astride + k + i]; + UNROLL_OUT(loadA0, 4) + UNROLL_OUT(loadA1, 4) + UNROLL_OUT(loadA2, 4) + UNROLL_OUT(loadA3, 4) +#undef loadA0 +#undef loadA1 +#undef loadA2 +#undef loadA3 + c0 += a00 * b0 + a01 * b1 + a02 * b2 + a03 * b3; + c1 += a10 * b0 + a11 * b1 + a12 * b2 + a13 * b3; + c2 += a20 * b0 + a21 * b1 + a22 * b2 + a23 * b3; + c3 += a30 * b0 + a31 * b1 + a32 * b2 + a33 * b3; +#define vstore(i) C[(m + i) * Cstride + n] = c##i; + UNROLL_OUT(vstore, 4) +#undef vstore + } + } + for (; k + 2 <= K; k += 2) { + size_t n = 0; + for (; n + 4 <= N; n += 4) { + float32x2_t a0, a1, a2, a3; + float32x4_t b0, b1; + float32x4_t c0, c1, c2, c3; +#define loadC(i) c##i = vld1q_f32(C + (m + i) * Cstride + n); +#define loadA(i) a##i = vld1_f32(A + (m + i) * Astride + k); +#define loadB(i) b##i = vld1q_f32(B + (k + i) * Bstride + n); + UNROLL_OUT(loadC, 4) + UNROLL_OUT(loadA, 4) + UNROLL_OUT(loadB, 2) +#undef loadA +#undef loadC +#undef loadB +#define calculateB0(i) c##i = vmlaq_lane_f32(c##i, b0, a##i, 0); +#define calculateB1(i) c##i = vmlaq_lane_f32(c##i, b1, a##i, 1); + UNROLL_OUT(calculateB0, 4) + UNROLL_OUT(calculateB1, 4) +#undef calculateB0 +#undef calculateB1 +#define vstore(i) vst1q_f32(C + (m + i) * Cstride + n, c##i); + UNROLL_OUT(vstore, 4) +#undef vstore + } + for (; n + 2 <= N; n += 2) { + float32x2_t a0, a1, a2, a3; + float32x2_t b0, b1; + float32x2_t c0, c1, c2, c3; +#define loadC(i) c##i = vld1_f32(C + (m + i) * Cstride + n); +#define loadB(i) b##i = vld1_f32(B + (k + i) * Bstride + n); +#define loadA(i) a##i = vld1_f32(A + (m + i) * Astride + k); + UNROLL_OUT(loadC, 4) + UNROLL_OUT(loadA, 4) + UNROLL_OUT(loadB, 2) +#undef loadA +#undef loadB +#undef loadC +#define calculateB0(i) c##i = vmla_lane_f32(c##i, b0, a##i, 0); +#define calculateB1(i) c##i = vmla_lane_f32(c##i, b1, a##i, 1); + UNROLL_OUT(calculateB0, 4) + UNROLL_OUT(calculateB1, 4) +#undef calculateB0 +#undef calculateB1 +#define vstore(i) vst1_f32(C + (m + i) * Cstride + n, c##i); + UNROLL_OUT(vstore, 4) +#undef vstore + } + for (; n < N; n += 1) { + float a00, a01, a10, a11, a20, a21, a30, a31; + float b0, b1; + float c0, c1, c2, c3; +#define loadC(i) c##i = C[(m + i) * Cstride + n]; +#define loadB(i) b##i = B[(k + i) * Bstride + n]; + UNROLL_OUT(loadC, 4) + UNROLL_OUT(loadB, 2) +#undef loadB +#undef loadC +#define loadA0(i) a0##i = A[(m + 0) * Astride + k + i]; +#define loadA1(i) a1##i = A[(m + 1) * Astride + k + i]; +#define loadA2(i) a2##i = A[(m + 2) * Astride + k + i]; +#define loadA3(i) a3##i = A[(m + 3) * Astride + k + i]; + UNROLL_OUT(loadA0, 2) + UNROLL_OUT(loadA1, 2) + UNROLL_OUT(loadA2, 2) + UNROLL_OUT(loadA3, 2) +#undef loadA0 +#undef loadA1 +#undef loadA2 +#undef loadA3 + c0 += a00 * b0 + a01 * b1; + c1 += a10 * b0 + a11 * b1; + c2 += a20 * b0 + a21 * b1; + c3 += a30 * b0 + a31 * b1; +#define vstore(i) C[(m + i) * Cstride + n] = c##i; + UNROLL_OUT(vstore, 4) +#undef vstore + } + } + for (; k < K; k += 1) { + size_t n = 0; + for (; n + 4 <= N; n += 4) { + float32x4_t a0, a1, a2, a3; + float32x4_t b0; + float32x4_t c0, c1, c2, c3; +#define loadC(i) c##i = vld1q_f32(C + (m + i) * Cstride + n); +#define loadB(i) b##i = vld1q_f32(B + (k + i) * Bstride + n); +#define loadA(i) a##i = vdupq_n_f32(A[(m + i) * Astride + k]); + UNROLL_OUT(loadC, 4) + UNROLL_OUT(loadA, 4) + UNROLL_OUT(loadB, 1) +#undef loadA +#undef loadB +#undef loadC +#define calculateB0(i) c##i = vmlaq_f32(c##i, a##i, b0); + UNROLL_OUT(calculateB0, 4) +#undef calculateB0 +#define vstore(i) vst1q_f32(C + (m + i) * Cstride + n, c##i); + UNROLL_OUT(vstore, 
4) +#undef vstore + } + for (; n + 2 <= N; n += 2) { + float32x2_t a0, a1, a2, a3; + float32x2_t b0; + float32x2_t c0, c1, c2, c3; +#define loadC(i) c##i = vld1_f32(C + (m + i) * Cstride + n); +#define loadB(i) b##i = vld1_f32(B + (k + i) * Bstride + n); +#define loadA(i) a##i = vdup_n_f32(A[(m + i) * Astride + k]); + UNROLL_OUT(loadC, 4) + UNROLL_OUT(loadA, 4) + UNROLL_OUT(loadB, 1) +#undef loadA +#undef loadB +#undef loadC +#define calculateB0(i) c##i = vmla_f32(c##i, a##i, b0); + UNROLL_OUT(calculateB0, 4) +#undef calculateB0 +#define vstore(i) vst1_f32(C + (m + i) * Cstride + n, c##i); + UNROLL_OUT(vstore, 4) +#undef vstore + } + for (; n < N; n += 1) { + float a00, a10, a20, a30; + float b0; + float c0, c1, c2, c3; +#define loadC(i) c##i = C[(m + i) * Cstride + n]; +#define loadB(i) b##i = B[(k + i) * Bstride + n]; + UNROLL_OUT(loadC, 4) + UNROLL_OUT(loadB, 1) +#undef loadB +#undef loadC +#define loadA0(i) a0##i = A[(m + 0) * Astride + k + i]; +#define loadA1(i) a1##i = A[(m + 1) * Astride + k + i]; +#define loadA2(i) a2##i = A[(m + 2) * Astride + k + i]; +#define loadA3(i) a3##i = A[(m + 3) * Astride + k + i]; + UNROLL_OUT(loadA0, 1) + UNROLL_OUT(loadA1, 1) + UNROLL_OUT(loadA2, 1) + UNROLL_OUT(loadA3, 1) +#undef loadA0 +#undef loadA1 +#undef loadA2 +#undef loadA3 + c0 = c0 + a00 * b0; + c1 = c1 + a10 * b0; + c2 = c2 + a20 * b0; + c3 = c3 + a30 * b0; +#define vstore(i) C[(m + i) * Cstride + n] = c##i; + UNROLL_OUT(vstore, 4) +#undef vstore + } + } + } + for (; m + 2 <= M; m += 2) { + size_t k = 0; + memset(C + m * Cstride, 0, 2 * sizeof(float) * N); + for (; k + 4 <= K; k += 4) { + size_t n = 0; + for (; n + 4 <= N; n += 4) { + float32x4_t a00, a01, a02, a03, a10, a11, a12, a13; + float32x4_t b0, b1, b2, b3; + float32x4_t c0, c1; +#define loadA0(i) a0##i = vdupq_n_f32(A[(m + 0) * Astride + k + i]); +#define loadA1(i) a1##i = vdupq_n_f32(A[(m + 1) * Astride + k + i]); +#define loadB(i) b##i = vld1q_f32(B + (k + i) * Bstride + n); +#define loadC(i) c##i = vld1q_f32(C + (m + i) * Cstride + n); + UNROLL_OUT(loadC, 2) + UNROLL_OUT(loadB, 4) + UNROLL_OUT(loadA0, 4) + UNROLL_OUT(loadA1, 4) +#undef loadB +#undef loadC +#undef loadA0 +#undef loadA1 +#define calculate_row0(i) c0 = vmlaq_f32(c0, b##i, a0##i); +#define calculate_row1(i) c1 = vmlaq_f32(c1, b##i, a1##i); + UNROLL_OUT(calculate_row0, 4) + UNROLL_OUT(calculate_row1, 4) +#undef calculate_row0 +#undef calculate_row1 +#define vstore(i) vst1q_f32(C + (m + i) * Cstride + n, c##i); + UNROLL_OUT(vstore, 2) +#undef vstore + } + for (; n + 2 <= N; n += 2) { + float32x4_t a0, a1; + float32x2_t b0, b1, b2, b3; + float32x2_t c0, c1; +#define loadA(i) a##i = vld1q_f32(A + (m + i) * Astride + k); +#define loadB(i) b##i = vld1_f32(B + (k + i) * Bstride + n); +#define loadC(i) c##i = vld1_f32(C + (m + i) * Cstride + n); + UNROLL_OUT(loadC, 2) + UNROLL_OUT(loadA, 2) + UNROLL_OUT(loadB, 4) +#undef loadA +#undef loadB +#undef loadC +#define calculateB0(i) c##i = vmla_lane_f32(c##i, b0, vget_low_f32(a##i), 0); +#define calculateB1(i) c##i = vmla_lane_f32(c##i, b1, vget_low_f32(a##i), 1); +#define calculateB2(i) c##i = vmla_lane_f32(c##i, b2, vget_high_f32(a##i), 0); +#define calculateB3(i) c##i = vmla_lane_f32(c##i, b3, vget_high_f32(a##i), 1); + UNROLL_OUT(calculateB0, 2) + UNROLL_OUT(calculateB1, 2) + UNROLL_OUT(calculateB2, 2) + UNROLL_OUT(calculateB3, 2) +#undef calculateB0 +#undef calculateB1 +#undef calculateB2 +#undef calculateB3 +#define vstore(i) vst1_f32(C + (m + i) * Cstride + n, c##i); + UNROLL_OUT(vstore, 2) +#undef vstore + } + for (; n < 
N; n += 1) { + float a00, a01, a02, a03, a10, a11, a12, a13; + float b0, b1, b2, b3; + float c0, c1; +#define loadC(i) c##i = C[(m + i) * Cstride + n]; +#define loadB(i) b##i = B[(k + i) * Bstride + n]; + UNROLL_OUT(loadC, 2) + UNROLL_OUT(loadB, 4) +#undef loadB +#undef loadC +#define loadA0(i) a0##i = A[m * Astride + k + i]; +#define loadA1(i) a1##i = A[(m + 1) * Astride + k + i]; + UNROLL_OUT(loadA0, 4) + UNROLL_OUT(loadA1, 4) +#undef loadA0 +#undef loadA1 + c0 += a00 * b0 + a01 * b1 + a02 * b2 + a03 * b3; + c1 += a10 * b0 + a11 * b1 + a12 * b2 + a13 * b3; +#define vstore(i) C[(m + i) * Cstride + n] = c##i; + UNROLL_OUT(vstore, 2) +#undef vstore + } + } + for (; k + 2 <= K; k += 2) { + size_t n = 0; + for (; n + 4 <= N; n += 4) { + float32x2_t a0, a1; + float32x4_t b0, b1; + float32x4_t c0, c1; +#define loadC(i) c##i = vld1q_f32(C + (m + i) * Cstride + n); +#define loadA(i) a##i = vld1_f32(A + (m + i) * Astride + k); +#define loadB(i) b##i = vld1q_f32(B + (k + i) * Bstride + n); + UNROLL_OUT(loadC, 2) + UNROLL_OUT(loadA, 2) + UNROLL_OUT(loadB, 2) +#undef loadA +#undef loadC +#undef loadB +#define calculateB0(i) c##i = vmlaq_lane_f32(c##i, b0, a##i, 0); +#define calculateB1(i) c##i = vmlaq_lane_f32(c##i, b1, a##i, 1); + UNROLL_OUT(calculateB0, 2) + UNROLL_OUT(calculateB1, 2) +#undef calculateB0 +#undef calculateB1 +#define vstore(i) vst1q_f32(C + (m + i) * Cstride + n, c##i); + UNROLL_OUT(vstore, 2) +#undef vstore + } + for (; n + 2 <= N; n += 2) { + float32x2_t a0, a1; + float32x2_t b0, b1; + float32x2_t c0, c1; +#define loadC(i) c##i = vld1_f32(C + (m + i) * Cstride + n); +#define loadB(i) b##i = vld1_f32(B + (k + i) * Bstride + n); +#define loadA(i) a##i = vld1_f32(A + (m + i) * Astride + k); + UNROLL_OUT(loadC, 2) + UNROLL_OUT(loadA, 2) + UNROLL_OUT(loadB, 2) +#undef loadA +#undef loadB +#undef loadC +#define calculateB0(i) c##i = vmla_lane_f32(c##i, b0, a##i, 0); +#define calculateB1(i) c##i = vmla_lane_f32(c##i, b1, a##i, 1); + UNROLL_OUT(calculateB0, 2) + UNROLL_OUT(calculateB1, 2) +#undef calculateB0 +#undef calculateB1 +#define vstore(i) vst1_f32(C + (m + i) * Cstride + n, c##i); + UNROLL_OUT(vstore, 2) +#undef vstore + } + for (; n < N; n += 1) { + float a00, a01, a10, a11; + float b0, b1; + float c0, c1; +#define loadC(i) c##i = C[(m + i) * Cstride + n]; +#define loadB(i) b##i = B[(k + i) * Bstride + n]; + UNROLL_OUT(loadC, 2) + UNROLL_OUT(loadB, 2) +#undef loadB +#undef loadC +#define loadA0(i) a0##i = A[(m + 0) * Astride + k + i]; +#define loadA1(i) a1##i = A[(m + 1) * Astride + k + i]; + UNROLL_OUT(loadA0, 2) + UNROLL_OUT(loadA1, 2) +#undef loadA0 +#undef loadA1 + c0 += a00 * b0 + a01 * b1; + c1 += a10 * b0 + a11 * b1; +#define vstore(i) C[(m + i) * Cstride + n] = c##i; + UNROLL_OUT(vstore, 2) +#undef vstore + } + } + for (; k < K; k += 1) { + size_t n = 0; + for (; n + 4 <= N; n += 4) { + float32x4_t a0, a1; + float32x4_t b0; + float32x4_t c0, c1; +#define loadC(i) c##i = vld1q_f32(C + (m + i) * Cstride + n); +#define loadB(i) b##i = vld1q_f32(B + (k + i) * Bstride + n); +#define loadA(i) a##i = vdupq_n_f32(A[(m + i) * Astride + k]); + UNROLL_OUT(loadC, 2) + UNROLL_OUT(loadA, 2) + UNROLL_OUT(loadB, 1) +#undef loadA +#undef loadB +#undef loadC +#define calculateB0(i) c##i = vmlaq_f32(c##i, a##i, b0); + UNROLL_OUT(calculateB0, 2) +#undef calculateB0 +#define vstore(i) vst1q_f32(C + (m + i) * Cstride + n, c##i); + UNROLL_OUT(vstore, 2) +#undef vstore + } + for (; n + 2 <= N; n += 2) { + float32x2_t a0, a1; + float32x2_t b0; + float32x2_t c0, c1; +#define loadC(i) c##i = 
vld1_f32(C + (m + i) * Cstride + n); +#define loadB(i) b##i = vld1_f32(B + (k + i) * Bstride + n); +#define loadA(i) a##i = vdup_n_f32(A[(m + i) * Astride + k]); + UNROLL_OUT(loadC, 2) + UNROLL_OUT(loadA, 2) + UNROLL_OUT(loadB, 1) +#undef loadA +#undef loadB +#undef loadC +#define calculateB0(i) c##i = vmla_f32(c##i, a##i, b0); + UNROLL_OUT(calculateB0, 2) +#undef calculateB0 +#define vstore(i) vst1_f32(C + (m + i) * Cstride + n, c##i); + UNROLL_OUT(vstore, 2) +#undef vstore + } + for (; n < N; n += 1) { + float a00, a10; + float b0; + float c0, c1; +#define loadC(i) c##i = C[(m + i) * Cstride + n]; +#define loadB(i) b##i = B[(k + i) * Bstride + n]; + UNROLL_OUT(loadC, 2) + UNROLL_OUT(loadB, 1) +#undef loadB +#undef loadC +#define loadA0(i) a0##i = A[(m + 0) * Astride + k + i]; +#define loadA1(i) a1##i = A[(m + 1) * Astride + k + i]; + UNROLL_OUT(loadA0, 1) + UNROLL_OUT(loadA1, 1) +#undef loadA0 +#undef loadA1 + c0 = c0 + a00 * b0; + c1 = c1 + a10 * b0; +#define vstore(i) C[(m + i) * Cstride + n] = c##i; + UNROLL_OUT(vstore, 2) +#undef vstore + } + } + } + for (; m < M; m += 1) { + size_t k = 0; + memset(C + m * Cstride, 0, sizeof(float) * N); + for (; k + 4 <= K; k += 4) { + size_t n = 0; + for (; n + 4 <= N; n += 4) { + float32x4_t a00, a01, a02, a03; + float32x4_t b0, b1, b2, b3; + float32x4_t c0; +#define loadA0(i) a0##i = vdupq_n_f32(A[m * Astride + k + i]); +#define loadB(i) b##i = vld1q_f32(B + (k + i) * Bstride + n); +#define loadC(i) c##i = vld1q_f32(C + (m + i) * Cstride + n); + UNROLL_OUT(loadC, 1) + UNROLL_OUT(loadB, 4) + UNROLL_OUT(loadA0, 4) +#undef loadB +#undef loadC +#undef loadA0 +#define calculate_row0(i) c0 = vmlaq_f32(c0, b##i, a0##i); + UNROLL_OUT(calculate_row0, 4) +#undef calculate_row0 +#define vstore(i) vst1q_f32(C + (m + i) * Cstride + n, c##i); + UNROLL_OUT(vstore, 1) +#undef vstore + } + for (; n + 2 <= N; n += 2) { + float32x4_t a0; + float32x2_t b0, b1, b2, b3; + float32x2_t c0; +#define loadA(i) a##i = vld1q_f32(A + (m + i) * Astride + k); +#define loadB(i) b##i = vld1_f32(B + (k + i) * Bstride + n); +#define loadC(i) c##i = vld1_f32(C + (m + i) * Cstride + n); + UNROLL_OUT(loadC, 1) + UNROLL_OUT(loadA, 1) + UNROLL_OUT(loadB, 4) +#undef loadA +#undef loadB +#undef loadC +#define calculateB0(i) c##i = vmla_lane_f32(c##i, b0, vget_low_f32(a##i), 0); +#define calculateB1(i) c##i = vmla_lane_f32(c##i, b1, vget_low_f32(a##i), 1); +#define calculateB2(i) c##i = vmla_lane_f32(c##i, b2, vget_high_f32(a##i), 0); +#define calculateB3(i) c##i = vmla_lane_f32(c##i, b3, vget_high_f32(a##i), 1); + UNROLL_OUT(calculateB0, 1) + UNROLL_OUT(calculateB1, 1) + UNROLL_OUT(calculateB2, 1) + UNROLL_OUT(calculateB3, 1) +#undef calculateB0 +#undef calculateB1 +#undef calculateB2 +#undef calculateB3 +#define vstore(i) vst1_f32(C + (m + i) * Cstride + n, c##i); + UNROLL_OUT(vstore, 1) +#undef vstore + } + for (; n < N; n += 1) { + float a00, a01, a02, a03; + float b0, b1, b2, b3; + float c0; +#define loadC(i) c##i = C[(m + i) * Cstride + n]; +#define loadB(i) b##i = B[(k + i) * Bstride + n]; + UNROLL_OUT(loadC, 1) + UNROLL_OUT(loadB, 4) +#undef loadB +#undef loadC +#define loadA0(i) a0##i = A[m * Astride + k + i]; + UNROLL_OUT(loadA0, 4) +#undef loadA0 + c0 += a00 * b0 + a01 * b1 + a02 * b2 + a03 * b3; +#define vstore(i) C[(m + i) * Cstride + n] = c##i; + UNROLL_OUT(vstore, 1) +#undef vstore + } + } + for (; k + 2 <= K; k += 2) { + size_t n = 0; + for (; n + 4 <= N; n += 4) { + float32x2_t a0; + float32x4_t b0, b1; + float32x4_t c0; +#define loadC(i) c##i = vld1q_f32(C + (m + i) * 
Cstride + n); +#define loadA(i) a##i = vld1_f32(A + (m + i) * Astride + k); +#define loadB(i) b##i = vld1q_f32(B + (k + i) * Bstride + n); + UNROLL_OUT(loadC, 1) + UNROLL_OUT(loadA, 1) + UNROLL_OUT(loadB, 2) +#undef loadA +#undef loadC +#undef loadB +#define calculateB0(i) c##i = vmlaq_lane_f32(c##i, b0, a##i, 0); +#define calculateB1(i) c##i = vmlaq_lane_f32(c##i, b1, a##i, 1); + UNROLL_OUT(calculateB0, 1) + UNROLL_OUT(calculateB1, 1) +#undef calculateB0 +#undef calculateB1 +#define vstore(i) vst1q_f32(C + (m + i) * Cstride + n, c##i); + UNROLL_OUT(vstore, 1) +#undef vstore + } + for (; n + 2 <= N; n += 2) { + float32x2_t a0; + float32x2_t b0, b1; + float32x2_t c0; +#define loadC(i) c##i = vld1_f32(C + (m + i) * Cstride + n); +#define loadB(i) b##i = vld1_f32(B + (k + i) * Bstride + n); +#define loadA(i) a##i = vld1_f32(A + (m + i) * Astride + k); + UNROLL_OUT(loadC, 1) + UNROLL_OUT(loadA, 1) + UNROLL_OUT(loadB, 2) +#undef loadA +#undef loadB +#undef loadC +#define calculateB0(i) c##i = vmla_lane_f32(c##i, b0, a##i, 0); +#define calculateB1(i) c##i = vmla_lane_f32(c##i, b1, a##i, 1); + UNROLL_OUT(calculateB0, 1) + UNROLL_OUT(calculateB1, 1) +#undef calculateB0 +#undef calculateB1 +#define vstore(i) vst1_f32(C + (m + i) * Cstride + n, c##i); + UNROLL_OUT(vstore, 1) +#undef vstore + } + for (; n < N; n += 1) { + float a00, a01; + float b0, b1; + float c0; +#define loadC(i) c##i = C[(m + i) * Cstride + n]; +#define loadB(i) b##i = B[(k + i) * Bstride + n]; + UNROLL_OUT(loadC, 1) + UNROLL_OUT(loadB, 2) +#undef loadB +#undef loadC +#define loadA0(i) a0##i = A[(m + 0) * Astride + k + i]; + UNROLL_OUT(loadA0, 2) +#undef loadA0 + c0 += a00 * b0 + a01 * b1; +#define vstore(i) C[(m + i) * Cstride + n] = c##i; + UNROLL_OUT(vstore, 1) +#undef vstore + } + } + for (; k < K; k += 1) { + size_t n = 0; + for (; n + 4 <= N; n += 4) { + float32x4_t a0; + float32x4_t b0; + float32x4_t c0; +#define loadC(i) c##i = vld1q_f32(C + (m + i) * Cstride + n); +#define loadB(i) b##i = vld1q_f32(B + (k + i) * Bstride + n); +#define loadA(i) a##i = vdupq_n_f32(A[(m + i) * Astride + k]); + UNROLL_OUT(loadC, 1) + UNROLL_OUT(loadA, 1) + UNROLL_OUT(loadB, 1) +#undef loadA +#undef loadB +#undef loadC +#define calculateB0(i) c##i = vmlaq_f32(c##i, a##i, b0); + UNROLL_OUT(calculateB0, 1) +#undef calculateB0 +#define vstore(i) vst1q_f32(C + (m + i) * Cstride + n, c##i); + UNROLL_OUT(vstore, 1) +#undef vstore + } + for (; n + 2 <= N; n += 2) { + float32x2_t a0; + float32x2_t b0; + float32x2_t c0; +#define loadC(i) c##i = vld1_f32(C + (m + i) * Cstride + n); +#define loadB(i) b##i = vld1_f32(B + (k + i) * Bstride + n); +#define loadA(i) a##i = vdup_n_f32(A[(m + i) * Astride + k]); + UNROLL_OUT(loadC, 1) + UNROLL_OUT(loadA, 1) + UNROLL_OUT(loadB, 1) +#undef loadA +#undef loadB +#undef loadC +#define calculateB0(i) c##i = vmla_f32(c##i, a##i, b0); + UNROLL_OUT(calculateB0, 1) +#undef calculateB0 +#define vstore(i) vst1_f32(C + (m + i) * Cstride + n, c##i); + UNROLL_OUT(vstore, 1) +#undef vstore + } + for (; n < N; n += 1) { + float a00; + float b0; + float c0; +#define loadC(i) c##i = C[(m + i) * Cstride + n]; +#define loadB(i) b##i = B[(k + i) * Bstride + n]; + UNROLL_OUT(loadC, 1) + UNROLL_OUT(loadB, 1) +#undef loadB +#undef loadC +#define loadA0(i) a0##i = A[(m + 0) * Astride + k + i]; + UNROLL_OUT(loadA0, 1) +#undef loadA0 + c0 = c0 + a00 * b0; +#define vstore(i) C[(m + i) * Cstride + n] = c##i; + UNROLL_OUT(vstore, 1) +#undef vstore + } + } + } +} +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen 
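[Editor's note, not part of the patch] The fp32 kernel that ends above (dnn/src/arm_common/matrix_mul/fp32/exec_sgemv.cpp, matching the sgemm_sgemv_like signature declared in the header that follows) computes a plain row-major C = A * B: blocks of output rows are first cleared with memset (visible in the 2-row and 1-row paths) and then accumulated over K, and the 4/2/1 unrolling in M, K and N via UNROLL_OUT only regroups the same multiply-adds around NEON vmla intrinsics. The scalar sketch below states that contract; the helper name sgemm_naive_reference is hypothetical, while the Astride/Bstride/Cstride row-stride convention is taken from the kernel itself.

#include <cstddef>
#include <cstring>

// Scalar reference for what the unrolled NEON kernel is expected to produce:
// row-major C = A * B with per-matrix row strides, C cleared before accumulation.
void sgemm_naive_reference(const float* A, const float* B, float* C,
                           size_t M, size_t N, size_t K,
                           size_t Astride, size_t Bstride, size_t Cstride) {
    for (size_t m = 0; m < M; ++m) {
        // The NEON kernel memsets each block of C rows up front; do the same per row here.
        std::memset(C + m * Cstride, 0, sizeof(float) * N);
        for (size_t k = 0; k < K; ++k) {
            const float a = A[m * Astride + k];
            for (size_t n = 0; n < N; ++n) {
                C[m * Cstride + n] += a * B[k * Bstride + n];
            }
        }
    }
}

Comparing a kernel run against this reference over a few random (M, N, K, stride) shapes exercises every 4/2/1 tail path in the unrolled code above.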
diff --git a/dnn/src/arm_common/matrix_mul/fp32/exec_sgemv.h b/dnn/src/arm_common/matrix_mul/fp32/exec_sgemv.h
new file mode 100644
index 00000000..d1c352be
--- /dev/null
+++ b/dnn/src/arm_common/matrix_mul/fp32/exec_sgemv.h
@@ -0,0 +1,30 @@
+/**
+ * \file dnn/src/arm_common/matrix_mul/fp32/exec_sgemv.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+#pragma once
+
+#include <cstddef>
+
+namespace megdnn {
+namespace arm_common {
+
+bool is_sgemv_like_preferred(bool row_major, bool transposeA, bool transposeB,
+                             size_t M, size_t N, size_t K, float alpha,
+                             size_t /* LDA */, size_t LDB, float beta,
+                             size_t /* LDC */);
+
+void sgemm_sgemv_like(const float* __restrict A, const float* __restrict B,
+                      float* __restrict C, size_t M, size_t N, size_t K,
+                      size_t Astride, size_t Bstride, size_t Cstride);
+
+} // namespace arm_common
+} // namespace megdnn
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/arm_common/matrix_mul/int8/gemv.cpp b/dnn/src/arm_common/matrix_mul/int8/gemv.cpp
new file mode 100644
index 00000000..676d0ba8
--- /dev/null
+++ b/dnn/src/arm_common/matrix_mul/int8/gemv.cpp
@@ -0,0 +1,129 @@
+/**
+ * \file dnn/src/arm_common/matrix_mul/int8/gemv.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#if !__ARM_FEATURE_DOTPROD
+
+#include <cstddef>
+#include "src/arm_common/matrix_mul/int8/gemv.h"
+#include "src/arm_common/simd_macro/marm_neon.h"
+#include "src/common/utils.h"
+#include "megdnn/oprs.h"
+
+#include "midout.h"
+MIDOUT_DECL(megdnn_arm_common_int8_gemv)
+
+using namespace megdnn;
+using namespace arm_common;
+
+namespace {
+
+void gemv_naive_n(const int8_t* __restrict A, const int8_t* __restrict B,
+                  int32_t* __restrict C, size_t M, size_t N, size_t K,
+                  size_t Astride, size_t Bstride, size_t Cstride) {
+    megdnn_assert(N == 1 && Bstride == 1);
+    size_t m = 0;
+    for (; m + 2 <= M; m += 2) {
+        int32_t acc0 = 0, acc1 = 0;
+        size_t k = 0;
+        for (; k + 16 <= K; k += 16) {
+            int8x16_t a0 = vld1q_s8(A + m * Astride + k);
+            int8x16_t a1 = vld1q_s8(A + (m + 1) * Astride + k);
+
+            int8x16_t b0 = vld1q_s8(B + k);
+
+            int16x8_t c0 = vmull_s8(vget_low_s8(a0), vget_low_s8(b0));
+            c0 = vmlal_high_s8(c0, a0, b0);
+
+            int16x8_t c1 = vmull_s8(vget_low_s8(a1), vget_low_s8(b0));
+            c1 = vmlal_high_s8(c1, a1, b0);
+            acc0 += vaddlvq_s16(c0);
+            acc1 += vaddlvq_s16(c1);
+        }
+
+        for (; k + 8 <= K; k += 8) {
+            int8x8_t a0 = vld1_s8(A + m * Astride + k);
+            int8x8_t a1 = vld1_s8(A + (m + 1) * Astride + k);
+            int8x8_t b0 = vld1_s8(B + k);
+
+            int16x8_t c0 = vmull_s8(a0, b0);
+
+            int16x8_t c1 = vmull_s8(a1, b0);
+            acc0 += vaddlvq_s16(c0);
+            acc1 += vaddlvq_s16(c1);
+        }
+
+        for (; k < K; ++k) {
+            acc0 += static_cast<int32_t>(A[m * Astride + k]) * B[k];
+            acc1 += static_cast<int32_t>(A[(m + 1) * Astride + k]) * B[k];
+        }
+        C[m * Cstride] = acc0;
+        C[(m + 1) * Cstride] = acc1;
+    }
+
+    for (; m < M; ++m) {
+        int32_t acc0 = 0;
+        size_t k = 0;
+        for (; k + 16 <= K; k += 16) {
+            int8x16_t a0 = vld1q_s8(A + m * Astride + k);
+            int8x16_t b0 = vld1q_s8(B + k);
+
+            int16x8_t c0 = vmull_s8(vget_low_s8(a0), vget_low_s8(b0));
+            c0 = vmlal_high_s8(c0, a0, b0);
+
+            acc0 += vaddlvq_s16(c0);
+        }
+
+        for (; k + 8 <= K; k += 8) {
+            int8x8_t a0 = vld1_s8(A + m * Astride + k);
+            int8x8_t b0 = vld1_s8(B + k);
+
+            int16x8_t c0 = vmull_s8(a0, b0);
+            acc0 += vaddlvq_s16(c0);
+        }
+
+        for (; k < K; ++k) {
+            acc0 += static_cast<int32_t>(A[m * Astride + k]) * B[k];
+        }
+        C[m * Cstride] = acc0;
+    }
+}
+
+} // namespace
+
+bool matmul::is_gemv_like_preferred_int8(bool transposeA, bool transposeB,
+                                         size_t M, size_t N, size_t K,
+                                         size_t LDA, size_t LDB, size_t LDC) {
+    MEGDNN_MARK_USED_VAR(LDA);
+    MEGDNN_MARK_USED_VAR(LDB);
+    MEGDNN_MARK_USED_VAR(LDC);
+    MEGDNN_MARK_USED_VAR(M);
+    MEGDNN_MARK_USED_VAR(K);
+    if (transposeA)
+        return false;
+    if (transposeB)
+        return false;
+
+    return N == 1 && LDB == 1;
+}
+
+void matmul::gemv_like_int8(const int8_t* __restrict A,
+                            const int8_t* __restrict B, int32_t* __restrict C,
+                            size_t M, size_t N, size_t K, size_t Astride,
+                            size_t Bstride, size_t Cstride) {
+    megdnn_assert(N == 1);
+    MIDOUT_BEGIN(megdnn_arm_common_int8_gemv) {
+        return gemv_naive_n(A, B, C, M, N, K, Astride, Bstride, Cstride);
+    } MIDOUT_END();
+}
+
+#endif
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/arm_common/matrix_mul/int8/gemv.h b/dnn/src/arm_common/matrix_mul/int8/gemv.h
new file mode 100644
index 00000000..dc66d84f
--- /dev/null
+++ b/dnn/src/arm_common/matrix_mul/int8/gemv.h
@@ -0,0 +1,33 @@
+/**
+ * \file dnn/src/arm_common/matrix_mul/int8/gemv.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+#pragma once
+
+#include <cstddef>
+#include <cstdint>
+
+#if !__ARM_FEATURE_DOTPROD
+namespace megdnn {
+namespace arm_common {
+namespace matmul {
+
+bool is_gemv_like_preferred_int8(bool transposeA, bool transposeB, size_t M,
+                                 size_t N, size_t K, size_t LDA, size_t LDB,
+                                 size_t LDC);
+
+void gemv_like_int8(const int8_t* __restrict A, const int8_t* __restrict B,
+                    int32_t* __restrict C, size_t M, size_t N, size_t K,
+                    size_t Astride, size_t Bstride, size_t Cstride);
+} // namespace matmul
+} // namespace arm_common
+} // namespace megdnn
+#endif
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/arm_common/matrix_mul/opr_impl.cpp b/dnn/src/arm_common/matrix_mul/opr_impl.cpp
new file mode 100644
index 00000000..d06311d1
--- /dev/null
+++ b/dnn/src/arm_common/matrix_mul/opr_impl.cpp
@@ -0,0 +1,49 @@
+/**
+ * \file dnn/src/arm_common/matrix_mul/opr_impl.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+#include "src/arm_common/matrix_mul/opr_impl.h"
+#include "src/arm_common/matrix_mul/algos.h"
+#include "src/common/metahelper.h"
+
+using namespace megdnn;
+using namespace arm_common;
+
+namespace {
+uint8_t arm_common_algo_type_storage;
+} // anonymous namespace
+
+void* const MatrixMulImpl::sm_arm_common_algo_type =
+        &arm_common_algo_type_storage;
+
+class MatrixMulImpl::AlgoPack : NonCopyableObj {
+    AlgoInt8x8x16 int8x8x16;
+#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+    AlgoF16Gemv f16gemv;
+#endif
+
+public:
+    AlgoPack() {
+        all_algos.emplace_back(&int8x8x16);
+#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+        all_algos.emplace_back(&f16gemv);
+#endif
+    }
+    SmallVector<AlgoBase*> all_algos;
+};
+
+SmallVector<fallback::MatrixMulImpl::AlgoBase*> MatrixMulImpl::algo_pack() {
+    static AlgoPack s_algo_pack;
+    auto&& algos = fallback::MatrixMulImpl::algo_pack();
+    algos.insert(algos.begin(), s_algo_pack.all_algos.begin(),
+                 s_algo_pack.all_algos.end());
+    return std::move(algos);
+}
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/arm_common/matrix_mul/opr_impl.h b/dnn/src/arm_common/matrix_mul/opr_impl.h
new file mode 100644
index 00000000..78a1bc2a
--- /dev/null
+++ b/dnn/src/arm_common/matrix_mul/opr_impl.h
@@ -0,0 +1,42 @@
+/**
+ * \file dnn/src/arm_common/matrix_mul/opr_impl.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */ +#pragma once +#include "src/common/utils.h" +#include "src/fallback/matrix_mul/opr_impl.h" + +namespace megdnn { +namespace arm_common { + +class MatrixMulImpl : public fallback::MatrixMulImpl { +public: + using fallback::MatrixMulImpl::MatrixMulImpl; + + bool is_thread_safe() const override { return true; } + + SmallVector algo_pack() override; + +protected: + static void* const sm_arm_common_algo_type; +#if !__ARM_FEATURE_DOTPROD + class AlgoInt8x8x32Gemv; // Arm_common Int 8x8x32 Gemv +#endif + class AlgoF32Gemv; // Arm_common F32 Gemv +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + class AlgoF16Gemv; +#endif + class AlgoInt8x8x16; // Arm_common Int 8x8x16 + class AlgoPack; +}; + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/pooling/algo.cpp b/dnn/src/arm_common/pooling/algo.cpp new file mode 100644 index 00000000..4cc52252 --- /dev/null +++ b/dnn/src/arm_common/pooling/algo.cpp @@ -0,0 +1,562 @@ +/** + * \file dnn/src/arm_common/pooling/algo.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/pooling/algo.h" +#include "src/arm_common/pooling/do_max_pooling_3x3_s2x2_int8.h" +#include "src/arm_common/pooling/do_max_pooling_w2x2_s2x2.h" +#include "src/arm_common/pooling/do_max_pooling_w4x4_s2x2.h" + +#include "midout.h" + +MIDOUT_DECL(megdnn_arm_common_pooling) + +namespace megdnn { +namespace arm_common { + +WorkspaceBundle get_bundle(const PoolingImpl::PoolingKernSizeParam& param) { + megdnn_assert((param.src_type.category() == DTypeCategory::FLOAT || + param.src_type.enumv() == DTypeEnum::QuantizedS8 || + param.src_type.enumv() == DTypeEnum::Quantized8Asymm || + param.src_type == dtype::Int8{}) && + param.format == param::Pooling::Format::NCHW && + (param.mode == param::Pooling::Mode::MAX || + (param.mode == param::Pooling::Mode::AVERAGE && + param.filter[0] == 3)) && + param.filter[0] == param.filter[1] && + (param.filter[0] == 3 || param.filter[1] == 5) && + param.stride[0] == 2 && param.stride[1] == 2 && + param.isz[0] >= 2 && param.isz[1] >= 2); + //! 
max pooling nxn stride 2 + auto IW = param.isz[1]; + auto OW = param.osz[1]; + + // In order to process odd size filter, + // Firstly, Store a row of the input separately by odd and even numbers + // Then process them, get a row of the outputs + // We need to store n rows of results + SmallVector needed_mem; + for (size_t i = 0; i < param.filter[0]; ++i) + needed_mem.push_back(OW * param.src_type.size()); + needed_mem.push_back((IW + 1) / 2 * param.src_type.size()); + needed_mem.push_back((IW + 1) / 2 * param.src_type.size()); + WorkspaceBundle ws(nullptr, needed_mem, 16); + return ws; +} + +bool PoolingImpl::AlgoFilterxModexStride1::usable( + const PoolingKernSizeParam& param) const { + auto SH = param.stride[0]; + auto SW = param.stride[1]; + auto FH = param.filter[0]; + auto FW = param.filter[1]; + + bool avaible = (param.src_type.category() == DTypeCategory::FLOAT || + param.src_type.category() == DTypeCategory::QUANTIZED) && + param.format == Param::Format::NCHW && SH == 1 && SW == 1 && + FH == FW && (FH == 2 || FH == 3); + return avaible; +} + +void PoolingImpl::AlgoFilterxModexStride1::exec( + const PoolingKernParam& param) const { + auto IH = param.isz[0], IW = param.isz[1]; + auto OH = param.osz[0], OW = param.osz[1]; + auto N = param.n, C = param.ic; + auto PH = param.padding[0]; + auto PW = param.padding[1]; + auto FH = param.filter[0]; + + void* src_ptr = param.src_ptr; + void* dst_ptr = param.dst_ptr; + +#define DISPATCH_FUNC(Pooler, NeonPooler, window, midout_type_id) \ + MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(0), \ + midout_iv(midout_type_id), Pooler::MIDOUT_CASE_NUM, \ + NeonPooler::MIDOUT_CASE_NUM, window) { \ + auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \ + src_dtype = param.src_type](size_t index, size_t) { \ + size_t n = index / C; \ + size_t c = index % C; \ + do_pooling_compact< \ + Pooler MEGDNN_COMMA NeonPooler MEGDNN_COMMA window>( \ + static_cast(src_ptr) + \ + n * C * IH * IW + c * IH * IW, \ + static_cast(dst_ptr) + \ + n * C * OH * OW + c * OH * OW, \ + src_dtype, IH, IW, OH, OW, PH, PW); \ + }; \ + MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ + static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \ + run); \ + } \ + MIDOUT_END() + +#define DISPATCH_WINDOW(Pooler, NeonPooler, dtype, ctype, comp_type, \ + midout_type_id) \ + switch (FH) { \ + case 2: { \ + using _Pooler = Pooler<4, dtype, ctype, comp_type>; \ + using _NeonPooler = NeonPooler<4, dtype, ctype, comp_type>; \ + DISPATCH_FUNC(_Pooler, _NeonPooler, 2, midout_type_id); \ + break; \ + } \ + case 3: { \ + using _Pooler = Pooler<9, dtype, ctype, comp_type>; \ + using _NeonPooler = NeonPooler<9, dtype, ctype, comp_type>; \ + DISPATCH_FUNC(_Pooler, _NeonPooler, 3, midout_type_id); \ + break; \ + } \ + default: \ + megdnn_assert(0, "unsupport pooling filter size"); \ + break; \ + } + +#define DISPATCH_MODE(dtype, ctype, comp_type, midout_type_id) \ + switch (param.mode) { \ + case Mode::MAX: \ + DISPATCH_WINDOW(MaxPooler, NeonMaxPooler, dtype, ctype, comp_type, \ + midout_type_id); \ + break; \ + case Mode::AVERAGE: \ + DISPATCH_WINDOW(MeanInPooler, NeonMeanPooler, dtype, ctype, \ + comp_type, midout_type_id); \ + break; \ + default: \ + megdnn_assert(0, "unsupport pooling mode"); \ + break; \ + } + + if (param.src_type == dtype::Float32{}) { + DISPATCH_MODE(dt_float32, float, float, 0); + } else if (param.src_type.enumv() == DTypeEnum::QuantizedS8) { + DISPATCH_MODE(dt_qint8, int8_t, float, 1); + } else if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) { + 
DISPATCH_MODE(dt_quint8, uint8_t, float, 2); +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + } else if (param.src_type == dtype::Float16{}) { + DISPATCH_MODE(dt_float16, __fp16, __fp16, 3); +#endif + } +#undef DISPATCH_FUNC +#undef DISPATCH_WINDOW +#undef DISPATCH_MODE +} +bool PoolingImpl::AlgoFilter2ModexStride2::usable( + const PoolingKernSizeParam& param) const { + auto SH = param.stride[0]; + auto SW = param.stride[1]; + auto FH = param.filter[0]; + auto FW = param.filter[1]; + + bool avaible = (param.src_type.category() == DTypeCategory::FLOAT || + param.src_type.category() == DTypeCategory::QUANTIZED) && + param.format == Param::Format::NCHW && FH == FW && + SH == SW && FH == 2 && SH == 2; + return avaible; +} + +void PoolingImpl::AlgoFilter2ModexStride2::exec( + const PoolingKernParam& param) const { + auto IH = param.isz[0], IW = param.isz[1]; + auto OH = param.osz[0], OW = param.osz[1]; + auto N = param.n, C = param.ic; + auto PH = param.padding[0]; + auto PW = param.padding[1]; + + void* src_ptr = param.src_ptr; + void* dst_ptr = param.dst_ptr; +#define DISPATCH_FUNC(Pooler, mode, midout_type_id) \ + MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(1), \ + midout_iv(midout_type_id), Pooler::MIDOUT_CASE_NUM) { \ + auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \ + src_dtype = param.src_type](size_t index, size_t) { \ + size_t n = index / C; \ + size_t c = index % C; \ + do_pooling_2x2( \ + static_cast(src_ptr) + \ + n * C * IH * IW + c * IH * IW, \ + static_cast(dst_ptr) + \ + n * C * OH * OW + c * OH * OW, \ + src_dtype, IH, IW, OH, OW, PH, PW); \ + }; \ + MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ + static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \ + run); \ + } \ + MIDOUT_END() + +#define DISPATCH_MODE(dtype, ctype, comp_type, midout_type_id) \ + switch (param.mode) { \ + case Mode::MAX: { \ + using _Pooler = MaxPooler<4, dtype, ctype, comp_type>; \ + DISPATCH_FUNC(_Pooler, Mode::MAX, midout_type_id); \ + break; \ + } \ + case Mode::AVERAGE: { \ + using _Pooler = MeanInPooler<4, dtype, ctype, comp_type>; \ + DISPATCH_FUNC(_Pooler, Mode::AVERAGE, midout_type_id); \ + break; \ + } \ + default: \ + megdnn_assert(0, "unsupport pooling mode"); \ + break; \ + } + + if (param.src_type == dtype::Float32{}) { + DISPATCH_MODE(dt_float32, float, float, 0); + } else if (param.src_type.enumv() == DTypeEnum::QuantizedS8) { + DISPATCH_MODE(dt_qint8, int8_t, float, 1); + } else if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) { + DISPATCH_MODE(dt_quint8, uint8_t, float, 2); +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + } else if (param.src_type == dtype::Float16{}) { + DISPATCH_MODE(dt_float16, __fp16, __fp16, 3); +#endif + } +#undef DISPATCH_FUNC +#undef DISPATCH_PAD +#undef DISPATCH_MODE +} + +bool PoolingImpl::AlgoFilter3MaxStride2::usable( + const PoolingKernSizeParam& param) const { + bool avaible = (param.src_type.category() == DTypeCategory::FLOAT || + param.src_type.category() == DTypeCategory::QUANTIZED) && + param.format == Param::Format::NCHW && + param.mode == Mode::MAX && param.filter[0] == 3 && + param.filter[1] == 3 && param.stride[0] == 2 && + param.stride[1] == 2 && param.isz[0] >= 2 && + param.isz[1] >= 2; + return avaible; +} + +void PoolingImpl::AlgoFilter3MaxStride2::exec( + const PoolingKernParam& param) const { + auto IH = param.isz[0], IW = param.isz[1]; + auto OH = param.osz[0], OW = param.osz[1]; + auto N = param.n, C = param.ic; + auto PH = param.padding[0]; + auto PW = param.padding[1]; + + void* src_ptr = param.src_ptr; + void* dst_ptr = 
param.dst_ptr; + +#define DISPATCH_FUNC(type, func, midout_type_id) \ + MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), \ + midout_iv(midout_type_id)) { \ + WorkspaceBundle wbundle = get_bundle(param); \ + auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \ + wbundle = wbundle, \ + workspace_ptr = param.workspace()]( \ + size_t index, size_t thread_id) { \ + auto ws = wbundle; \ + ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \ + size_t n = index / C; \ + size_t c = index % C; \ + do_max_pooling_3x3_s2x2_##func##_NEON( \ + static_cast(src_ptr) + n * C * IH * IW + \ + c * IH * IW, \ + static_cast(dst_ptr) + n * C * OH * OW + \ + c * OH * OW, \ + IH, IW, OH, OW, PH, PW, ws); \ + }; \ + MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ + static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \ + run); \ + } \ + MIDOUT_END(); + + if (param.src_type == dtype::Float32{}) { + DISPATCH_FUNC(float, float, 0); + } else if (param.src_type.enumv() == DTypeEnum::QuantizedS8) { + DISPATCH_FUNC(int8_t, int8, 1); + } else if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) { + DISPATCH_FUNC(uint8_t, uint8, 2); +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + } else if (param.src_type == dtype::Float16{}) { + DISPATCH_FUNC(__fp16, float16, 3); +#endif + } +#undef DISPATCH_FUNC +} +bool PoolingImpl::AlgoFilter3AverageStride2::usable( + const PoolingKernSizeParam& param) const { + bool avaible = (param.src_type.category() == DTypeCategory::FLOAT) && + param.format == Param::Format::NCHW && + param.mode == Mode::AVERAGE && param.filter[0] == 3 && + param.filter[1] == 3 && param.stride[0] == 2 && + param.stride[1] == 2 && param.isz[0] >= 2 && + param.isz[1] >= 2; + return avaible; +} + +void PoolingImpl::AlgoFilter3AverageStride2::exec( + const PoolingKernParam& param) const { + auto IH = param.isz[0], IW = param.isz[1]; + auto OH = param.osz[0], OW = param.osz[1]; + auto N = param.n, C = param.ic; + auto PH = param.padding[0]; + auto PW = param.padding[1]; + + void* src_ptr = param.src_ptr; + void* dst_ptr = param.dst_ptr; + +#define DISPATCH_FUNC(type, MEGDNN_SIMD_WIDTH, midout_type_id) \ + MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(3), \ + midout_iv(midout_type_id)) { \ + WorkspaceBundle wbundle = get_bundle(param); \ + auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \ + wbundle = wbundle, \ + workspace_ptr = param.workspace()]( \ + size_t index, size_t thread_id) { \ + auto ws = wbundle; \ + ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \ + size_t n = index / C; \ + size_t c = index % C; \ + do_average_pooling_3x3_s2x2_NEON( \ + static_cast(src_ptr) + n * C * IH * IW + \ + c * IH * IW, \ + static_cast(dst_ptr) + n * C * OH * OW + \ + c * OH * OW, \ + IH, IW, OH, OW, PH, PW, ws, MEGDNN_SIMD_WIDTH); \ + }; \ + MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ + static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \ + run); \ + } \ + MIDOUT_END(); + if (param.src_type == dtype::Float32{}) { + DISPATCH_FUNC(dt_float32, 4, 0); +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + } else if (param.src_type == dtype::Float16{}) { + DISPATCH_FUNC(__fp16, 8, 1); +#endif + } +#undef DISPATCH_FUNC +} +bool PoolingImpl::AlgoFilter4MaxStride2::usable( + const PoolingKernSizeParam& param) const { + auto SH = param.stride[0]; + auto SW = param.stride[1]; + auto FH = param.filter[0]; + auto FW = param.filter[1]; + auto OH = param.osz[0], OW = param.osz[1]; + + bool avaible = (param.src_type.category() == DTypeCategory::FLOAT || + param.src_type.category() == 
DTypeCategory::QUANTIZED) && + param.format == Param::Format::NCHW && + param.mode == Mode::MAX && FH == 4 && FW == 4 && SH == 2 && + SW == 2 && OH >= 2 && OW >= 2; + return avaible; +} + +void PoolingImpl::AlgoFilter4MaxStride2::exec( + const PoolingKernParam& param) const { + auto IH = param.isz[0], IW = param.isz[1]; + auto OH = param.osz[0], OW = param.osz[1]; + auto N = param.n, C = param.ic; + auto PH = param.padding[0]; + auto PW = param.padding[1]; + + void* src_ptr = param.src_ptr; + void* dst_ptr = param.dst_ptr; + +#define DISPATCH_FUNC(type, func, midout_type_id) \ + MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(4), \ + midout_iv(midout_type_id)) { \ + auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \ + src_dtype = param.src_type](size_t index, size_t) { \ + size_t n = index / C; \ + size_t c = index % C; \ + do_max_pooling_w4x4_s2x2_##func##_NEON( \ + static_cast(src_ptr) + n * C * IH * IW + \ + c * IH * IW, \ + static_cast(dst_ptr) + n * C * OH * OW + \ + c * OH * OW, \ + src_dtype, IH, IW, OH, OW, PH, PW); \ + }; \ + MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ + static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \ + run); \ + } \ + MIDOUT_END(); + + if (param.src_type == dtype::Float32{}) { + DISPATCH_FUNC(float, float, 0); + } else if (param.src_type.enumv() == DTypeEnum::QuantizedS8) { + DISPATCH_FUNC(int8_t, int8, 1); + } else if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) { + DISPATCH_FUNC(uint8_t, uint8, 2); +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + } else if (param.src_type == dtype::Float16{}) { + DISPATCH_FUNC(__fp16, float16, 3); +#endif + } +#undef DISPATCH_FUNC +} +bool PoolingImpl::AlgoFilter5MaxStride2::usable( + const PoolingKernSizeParam& param) const { + auto SH = param.stride[0]; + auto SW = param.stride[1]; + auto FH = param.filter[0]; + auto FW = param.filter[1]; + auto OH = param.osz[0], OW = param.osz[1]; + + bool avaible = (param.src_type.category() == DTypeCategory::FLOAT || + param.src_type.category() == DTypeCategory::QUANTIZED) && + param.format == Param::Format::NCHW && + param.mode == Mode::MAX && FH == 5 && FW == 5 && SH == 2 && + SW == 2 && OH >= 2 && OW >= 2; + return avaible; +} + +void PoolingImpl::AlgoFilter5MaxStride2::exec( + const PoolingKernParam& param) const { + auto IH = param.isz[0], IW = param.isz[1]; + auto OH = param.osz[0], OW = param.osz[1]; + auto N = param.n, C = param.ic; + auto PH = param.padding[0]; + auto PW = param.padding[1]; + + void* src_ptr = param.src_ptr; + void* dst_ptr = param.dst_ptr; + +#define DISPATCH_FUNC(dtype, type, midout_type_id, MEGDNN_SIMD_WIDTH) \ + MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(5), \ + midout_iv(midout_type_id)) { \ + WorkspaceBundle wbundle = get_bundle(param); \ + auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \ + wbundle = wbundle, \ + workspace_ptr = param.workspace()]( \ + size_t index, size_t thread_id) { \ + auto ws = wbundle; \ + ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \ + size_t n = index / C; \ + size_t c = index % C; \ + do_max_pooling_w5x5_s2x2_NEON( \ + static_cast(src_ptr) + n * C * IH * IW + \ + c * IH * IW, \ + static_cast(dst_ptr) + n * C * OH * OW + \ + c * OH * OW, \ + IH, IW, OH, OW, PH, PW, ws, MEGDNN_SIMD_WIDTH); \ + }; \ + MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ + static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \ + run); \ + } \ + MIDOUT_END(); + + if (param.src_type == dtype::Float32{}) { + DISPATCH_FUNC(dt_float32, float, 0, 4); + } else if (param.src_type.enumv() == 
DTypeEnum::QuantizedS8) { + DISPATCH_FUNC(dt_int8, int8_t, 1, 16); + } else if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) { + DISPATCH_FUNC(dt_uint8, uint8_t, 2, 16); +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + } else if (param.src_type == dtype::Float16{}) { + DISPATCH_FUNC(dt_float16, __fp16, 3, 8); +#endif + } +#undef DISPATCH_FUNC +} + +bool PoolingImpl::AlgoInt8Filter2MaxStride2::usable( + const PoolingKernSizeParam& param) const { + auto SH = param.stride[0]; + auto SW = param.stride[1]; + auto FH = param.filter[0]; + auto FW = param.filter[1]; + auto PH = param.padding[0]; + auto PW = param.padding[1]; + + bool avaible = param.src_type == dtype::Int8() && + param.format == Param::Format::NCHW && + param.mode == Mode::MAX && SH == 2 && SW == 2 && PH == 0 && + PW == 0 && FH == 2 && FW == 2; + return avaible; +} + +void PoolingImpl::AlgoInt8Filter2MaxStride2::exec( + const PoolingKernParam& param) const { + auto IH = param.isz[0], IW = param.isz[1]; + auto OH = param.osz[0], OW = param.osz[1]; + auto N = param.n, C = param.ic; + + auto src_ptr = param.src(); + auto dst_ptr = param.dst(); + + MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(6)) { + auto run = [C, IH, IW, OH, OW, src_ptr, dst_ptr](size_t index, size_t) { + size_t n = index / C; + size_t c = index % C; + pooling_max_w2x2_s2x2(src_ptr + n * C * IH * IW + c * IH * IW, + dst_ptr + n * C * OH * OW + c * OH * OW, 1, 1, + IH, IW, OH, OW); + }; + MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( + static_cast<::megdnn::naive::HandleImpl*>(param.handle), N * C, + run); + } + MIDOUT_END(); +} + +bool PoolingImpl::AlgoInt8Filter3MaxStride2::usable( + const PoolingKernSizeParam& param) const { + auto SH = param.stride[0]; + auto SW = param.stride[1]; + auto FH = param.filter[0]; + auto FW = param.filter[1]; + auto IH = param.isz[0]; + auto IW = param.isz[1]; + + bool avaible = param.src_type == dtype::Int8() && + param.format == Param::Format::NCHW && + param.mode == Mode::MAX && FH == 3 && FW == 3 && SH == 2 && + SW == 2 && IH >= 2 && IW >= 2; + return avaible; +} + +void PoolingImpl::AlgoInt8Filter3MaxStride2::exec( + const PoolingKernParam& param) const { + auto IH = param.isz[0], IW = param.isz[1]; + auto OH = param.osz[0], OW = param.osz[1]; + auto N = param.n, C = param.ic; + auto PH = param.padding[0]; + auto PW = param.padding[1]; + + auto src_ptr = param.src(); + auto dst_ptr = param.dst(); + + MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(7)) { + WorkspaceBundle wbundle = get_bundle(param); + auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, + wbundle = wbundle, + workspace_ptr = param.workspace()]( + size_t index, size_t thread_id) { + auto ws = wbundle; + ws.set(workspace_ptr + thread_id * ws.total_size_in_bytes()); + size_t n = index / C; + size_t c = index % C; + do_max_pooling_3x3_s2x2_int8_NEON( + src_ptr + n * C * IH * IW + c * IH * IW, + dst_ptr + n * C * OH * OW + c * OH * OW, IH, IW, OH, OW, PH, + PW, ws); + }; + MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( + static_cast<::megdnn::naive::HandleImpl*>(param.handle), N * C, + run); + } + MIDOUT_END(); +} +} // namespace arm_common +} // namespace megdnn +// vim: syntax=cpp.doxygen + diff --git a/dnn/src/arm_common/pooling/algo.h b/dnn/src/arm_common/pooling/algo.h new file mode 100644 index 00000000..a9969590 --- /dev/null +++ b/dnn/src/arm_common/pooling/algo.h @@ -0,0 +1,91 @@ +/** + * \file dnn/src/arm_common/pooling/algo.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. 
All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#pragma once +#include "src/arm_common/pooling/opr_impl.h" +#include "src/arm_common/pooling/pooling_helper.h" +#include "src/common//utils.h" +#include "src/naive/handle.h" + +namespace megdnn { +namespace arm_common { + +using AlgoBase = PoolingImpl::AlgoBase; + +class PoolingImpl::AlgoFilterxModexStride1 final : public AlgoBase { +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "ARM_POOLING_STRIDE1"; } + bool usable(const PoolingKernSizeParam& param) const override; + void exec(const PoolingKernParam& param) const override; +}; + +class PoolingImpl::AlgoFilter2ModexStride2 final : public AlgoBase { +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "ARM_POOLING_STRIDE2"; } + bool usable(const PoolingKernSizeParam& param) const override; + void exec(const PoolingKernParam& param) const override; +}; +class PoolingImpl::AlgoFilter3MaxStride2 final : public AlgoBase { +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "ARM_POOLING_FILTER3_MAX"; } + bool usable(const PoolingKernSizeParam& param) const override; + void exec(const PoolingKernParam& param) const override; +}; + +class PoolingImpl::AlgoFilter3AverageStride2 final : public AlgoBase { +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "ARM_POOLING_FILTER3_AVERAGE"; } + bool usable(const PoolingKernSizeParam& param) const override; + void exec(const PoolingKernParam& param) const override; +}; + +class PoolingImpl::AlgoFilter4MaxStride2 final : public AlgoBase { +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "ARM_POOLING_FILTER4_MAX"; } + bool usable(const PoolingKernSizeParam& param) const override; + void exec(const PoolingKernParam& param) const override; +}; + +class PoolingImpl::AlgoFilter5MaxStride2 final : public AlgoBase { +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "ARM_POOLING_FILTER5_MAX"; } + bool usable(const PoolingKernSizeParam& param) const override; + void exec(const PoolingKernParam& param) const override; +}; + +class PoolingImpl::AlgoInt8Filter2MaxStride2 final : public AlgoBase { +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "ARM_POOLING_INT8_FILTER2X2"; } + bool usable(const PoolingKernSizeParam& param) const override; + void exec(const PoolingKernParam& param) const override; +}; + +class PoolingImpl::AlgoInt8Filter3MaxStride2 final : public AlgoBase { +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "ARM_POOLING_INT8_FILTER3X3"; } + bool usable(const PoolingKernSizeParam& param) const override; + void exec(const PoolingKernParam& param) const override; +}; +WorkspaceBundle get_bundle(const PoolingImpl::PoolingKernSizeParam& param); + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen + diff --git a/dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_float.cpp b/dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_float.cpp new file mode 100644 
index 00000000..53d6ad37 --- /dev/null +++ b/dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_float.cpp @@ -0,0 +1,15 @@ +/** + * \file dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_float.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "src/arm_common/simd_macro/neon_helper.h" + +#include "src/common/pooling/do_max_pooling_3x3_s2x2_float_def.inl" + +#include "src/arm_common/simd_macro/neon_helper_epilogue.h" diff --git a/dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_float.h b/dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_float.h new file mode 100644 index 00000000..b009613c --- /dev/null +++ b/dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_float.h @@ -0,0 +1,17 @@ +/** + * \file dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_float.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once + +#include "src/arm_common/simd_macro/neon_helper.h" + +#include "src/common/pooling/do_max_pooling_3x3_s2x2_float_decl.inl" + +#include "src/arm_common/simd_macro/neon_helper_epilogue.h" diff --git a/dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_float16.cpp b/dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_float16.cpp new file mode 100644 index 00000000..f422f8c2 --- /dev/null +++ b/dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_float16.cpp @@ -0,0 +1,160 @@ +/** + * \file dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_float16.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "src/arm_common/pooling/do_max_pooling_3x3_s2x2_float16.h" + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#include +#include +#include +#include "src/arm_common/simd_macro/marm_neon.h" +#include + +namespace megdnn { +namespace arm_common { + +#define MEGDNN_SIMD_WIDTH 8 +void do_max_pooling_3x3_s2x2_float16_NEON(const __fp16* src, __fp16* dst, + size_t IH_, size_t IW_, size_t OH_, + size_t OW_, size_t PH_, size_t PW_, + const WorkspaceBundle& ws) { + int IH = IH_, IW = IW_, OH = OH_, OW = OW_, PH = PH_, PW = PW_; + // cache[i] stores the answer of the i-th line after + // pooling along the W dimension. + __fp16* cache[3] = {static_cast<__fp16*>(ws.get(0)), + static_cast<__fp16*>(ws.get(1)), + static_cast<__fp16*>(ws.get(2))}; + __fp16* odd = static_cast<__fp16*>(ws.get(3)); + __fp16* even = static_cast<__fp16*>(ws.get(4)); + int ih_next = 0; + // "good" area means we can use SIMD to accelerate. 
+ auto get_good_area = [](int I, int /* O */, int P, int& O_from, int& O_to) { + // x*2 - P >= 0; 2x >= P; x >= P/2 + O_from = (P + 1) / 2; + // x*2 - P + 3 <= I; x*2 <= I+P-3; x <= (I+P-3)/2 + O_to = (I + P - 3) / 2 + 1; + // we must have I >= 2 to ensure O_from <= O_to + }; + int OW_from, OW_to; + get_good_area(IW, OW, PW, OW_from, OW_to); + auto process_cache = [&](int ih) { + const __fp16* __restrict sptr = src + ih * IW; + auto tmp = cache[2]; + cache[2] = cache[1]; + cache[1] = cache[0]; + cache[0] = tmp; + // cache 0 is used to store the current answer. + auto run_single = [&](int ow) { + int iw = ow * 2 - PW; + __fp16 res = std::numeric_limits::lowest(); + if (iw + 0 >= 0 && iw + 0 < IW) { + res = std::max(res, sptr[iw + 0]); + } + if (iw + 1 >= 0 && iw + 1 < IW) { + res = std::max(res, sptr[iw + 1]); + } + if (iw + 2 >= 0 && iw + 2 < IW) { + res = std::max(res, sptr[iw + 2]); + } + cache[0][ow] = res; + }; + // build odd/even + int iw = 0; + int odd_offset = 0, even_offset = 0; + + for (; iw + 2 * MEGDNN_SIMD_WIDTH <= IW; iw += 2 * MEGDNN_SIMD_WIDTH) { + float16x8_t s0, s1; + s0 = vld1q_f16(sptr + iw + 0); + s1 = vld1q_f16(sptr + iw + MEGDNN_SIMD_WIDTH); + float16x8x2_t d = vuzpq_f16(s0, s1); + vst1q_f16(even + even_offset, d.val[0]); + vst1q_f16(odd + odd_offset, d.val[1]); + even_offset += MEGDNN_SIMD_WIDTH; + odd_offset += MEGDNN_SIMD_WIDTH; + } + for (; iw < IW; ++iw) { + if (iw & 1) + odd[odd_offset++] = sptr[iw]; + else + even[even_offset++] = sptr[iw]; + } + int ow = 0; + for (; ow < OW_from; ++ow) + run_single(ow); + if (PW & 1) { + for (; ow + MEGDNN_SIMD_WIDTH <= OW_to; ow += MEGDNN_SIMD_WIDTH) { + float16x8_t d, s0, s1, s2; + s0 = vld1q_f16(odd + ow - (PW >> 1) - 1); + s1 = vld1q_f16(even + ow - (PW >> 1)); + s2 = vld1q_f16(odd + ow - (PW >> 1)); + d = vmaxq_f16(vmaxq_f16(s0, s1), s2); + vst1q_f16(cache[0] + ow, d); + } + } else { + for (; ow + MEGDNN_SIMD_WIDTH <= OW_to; + ow += MEGDNN_SIMD_WIDTH) { + float16x8_t d, s0, s1, s2; + s0 = vld1q_f16(even + ow - (PW >> 1)); + s1 = vld1q_f16(odd + ow - (PW >> 1)); + s2 = vld1q_f16(even + ow - (PW >> 1) + 1); + d = vmaxq_f16(vmaxq_f16(s0, s1), s2); + vst1q_f16(cache[0] + ow, d); + } + } + for (; ow < OW; ++ow) + run_single(ow); + }; + for (int oh = 0; oh < OH; ++oh) { + __fp16* __restrict dptr = dst + oh * OW; + int ih_from = std::min(IH, std::max(0, oh * 2 - PH)); + int ih_to = std::min(IH, std::max(0, oh * 2 - PH + 3)); + while (ih_next < ih_to) { + process_cache(ih_next++); + } + if (ih_to - ih_from == 3) { + int ow = 0; + for (; ow + MEGDNN_SIMD_WIDTH <= OW; ow += MEGDNN_SIMD_WIDTH) { + float16x8_t d, s0, s1, s2; + s0 = vld1q_f16(cache[0] + ow); + s1 = vld1q_f16(cache[1] + ow); + s2 = vld1q_f16(cache[2] + ow); + d = vmaxq_f16(vmaxq_f16(s0, s1), s2); + vst1q_f16(dptr + ow, d); + } + for (; ow < OW; ++ow) { + dptr[ow] = std::max(std::max(cache[0][ow], cache[1][ow]), + cache[2][ow]); + } + } else { + std::memcpy(dptr, cache[0], sizeof(__fp16) * OW); + for (int i = 1; i < ih_to - ih_from; ++i) { + int ow = 0; + for (; ow + MEGDNN_SIMD_WIDTH <= OW; + ow += MEGDNN_SIMD_WIDTH) { + float16x8_t d, s; + s = vld1q_f16(cache[i] + ow); + d = vld1q_f16(dptr + ow); + d = vmaxq_f16(d, s); + vst1q_f16(dptr + ow, d); + } + for (; ow < OW; ++ow) { + dptr[ow] = std::max(dptr[ow], cache[i][ow]); + } + } + } + } +} + +} // namespace arm_common +} // namespace megdnn + +#endif + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_float16.h b/dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_float16.h new 
file mode 100644 index 00000000..a87e5773 --- /dev/null +++ b/dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_float16.h @@ -0,0 +1,30 @@ +/** + * \file dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_float16.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once +#include +#include +#include "src/arm_common/simd_macro/marm_neon.h" +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#include "src/common/utils.h" + +namespace megdnn { +namespace arm_common { + +void do_max_pooling_3x3_s2x2_float16_NEON(const __fp16* src, __fp16* dst, + size_t IH, size_t IW, size_t OH, + size_t OW, size_t PH, size_t PW, + const WorkspaceBundle& ws); + +} // namespace arm_common +} // namespace megdnn + +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_int8.cpp b/dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_int8.cpp new file mode 100644 index 00000000..5d93d43e --- /dev/null +++ b/dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_int8.cpp @@ -0,0 +1,284 @@ +/** + * \file dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_int8.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "src/arm_common/pooling/do_max_pooling_3x3_s2x2_int8.h" + +#include +#include +#include +#include "src/arm_common/simd_macro/marm_neon.h" +#include + +namespace megdnn { +namespace arm_common { + +void do_max_pooling_3x3_s2x2_int8_NEON(const int8_t* src, int8_t* dst, + size_t IH_, size_t IW_, size_t OH_, + size_t OW_, size_t PH_, size_t PW_, + const WorkspaceBundle& ws) { + int IH = IH_, IW = IW_, OH = OH_, OW = OW_, PH = PH_, PW = PW_; + // cache[i] stores the answer of the i-th line after + // pooling along the W dimension. + int8_t* cache[3] = {static_cast(ws.get(0)), + static_cast(ws.get(1)), + static_cast(ws.get(2))}; + int8_t* odd = static_cast(ws.get(3)); + int8_t* even = static_cast(ws.get(4)); + + int ih_next = 0; + // "good" area means we can use SIMD to accelerate. + auto get_good_area = [](int I, int /* O */, int P, int& O_from, int& O_to) { + // x*2 - P >= 0; 2x >= P; x >= P/2 + O_from = (P + 1) / 2; + // x*2 - P + 3 <= I; x*2 <= I+P-3; x <= (I+P-3)/2 + O_to = (I + P - 3) / 2 + 1; + // we must have I >= 2 to ensure O_from <= O_to + }; + int OW_from, OW_to; + get_good_area(IW, OW, PW, OW_from, OW_to); + auto process_cache = [&](int ih) { + const int8_t* __restrict sptr = src + ih * IW; + auto tmp = cache[2]; + cache[2] = cache[1]; + cache[1] = cache[0]; + cache[0] = tmp; + // cache 0 is used to store the current answer. 
+ auto run_single = [&](int ow) { + int iw = ow * 2 - PW; + int8_t res = std::numeric_limits::lowest(); + if (iw + 0 >= 0 && iw + 0 < IW) { + res = std::max(res, sptr[iw + 0]); + } + if (iw + 1 >= 0 && iw + 1 < IW) { + res = std::max(res, sptr[iw + 1]); + } + if (iw + 2 >= 0 && iw + 2 < IW) { + res = std::max(res, sptr[iw + 2]); + } + cache[0][ow] = res; + }; + // build odd/even + int iw = 0; + int odd_offset = 0, even_offset = 0; + + for (; iw + 32 <= IW; iw += 32) { + int8x16_t s0, s1; + s0 = vld1q_s8(sptr + iw + 0); + s1 = vld1q_s8(sptr + iw + 16); + int8x16x2_t d = vuzpq_s8(s0, s1); + vst1q_s8(even + even_offset, d.val[0]); + vst1q_s8(odd + odd_offset, d.val[1]); + even_offset += 16; + odd_offset += 16; + } + for (; iw < IW; ++iw) { + if (iw & 1) + odd[odd_offset++] = sptr[iw]; + else + even[even_offset++] = sptr[iw]; + } + int ow = 0; + for (; ow < OW_from; ++ow) + run_single(ow); + if (PW & 1) { + for (; ow + 16 <= OW_to; ow += 16) { + int8x16_t d, s0, s1, s2; + s0 = vld1q_s8(odd + ow - (PW >> 1) - 1); + s1 = vld1q_s8(even + ow - (PW >> 1)); + s2 = vld1q_s8(odd + ow - (PW >> 1)); + d = vmaxq_s8(vmaxq_s8(s0, s1), s2); + vst1q_s8(cache[0] + ow, d); + } + } else { + for (; ow + 16 <= OW_to; ow += 16) { + int8x16_t d, s0, s1, s2; + s0 = vld1q_s8(even + ow - (PW >> 1)); + s1 = vld1q_s8(odd + ow - (PW >> 1)); + s2 = vld1q_s8(even + ow - (PW >> 1) + 1); + d = vmaxq_s8(vmaxq_s8(s0, s1), s2); + vst1q_s8(cache[0] + ow, d); + } + } + for (; ow < OW; ++ow) + run_single(ow); + }; + for (int oh = 0; oh < OH; ++oh) { + int8_t* __restrict dptr = dst + oh * OW; + int ih_from = std::min(IH, std::max(0, oh * 2 - PH)); + int ih_to = std::min(IH, std::max(0, oh * 2 - PH + 3)); + while (ih_next < ih_to) { + process_cache(ih_next++); + } + if (ih_to - ih_from == 3) { + int ow = 0; + for (; ow + 16 <= OW; ow += 16) { + int8x16_t d, s0, s1, s2; + s0 = vld1q_s8(cache[0] + ow); + s1 = vld1q_s8(cache[1] + ow); + s2 = vld1q_s8(cache[2] + ow); + d = vmaxq_s8(vmaxq_s8(s0, s1), s2); + vst1q_s8(dptr + ow, d); + } + for (; ow < OW; ++ow) { + dptr[ow] = std::max(std::max(cache[0][ow], cache[1][ow]), + cache[2][ow]); + } + } else { + std::memcpy(dptr, cache[0], sizeof(int8_t) * OW); + for (int i = 1; i < ih_to - ih_from; ++i) { + int ow = 0; + for (; ow + 16 <= OW; ow += 16) { + int8x16_t d, s; + s = vld1q_s8(cache[i] + ow); + d = vld1q_s8(dptr + ow); + d = vmaxq_s8(d, s); + vst1q_s8(dptr + ow, d); + } + for (; ow < OW; ++ow) { + dptr[ow] = std::max(dptr[ow], cache[i][ow]); + } + } + } + } +} + +void do_max_pooling_3x3_s2x2_uint8_NEON(const uint8_t* src, uint8_t* dst, + size_t IH_, size_t IW_, size_t OH_, + size_t OW_, size_t PH_, size_t PW_, + const WorkspaceBundle& ws) { + int IH = IH_, IW = IW_, OH = OH_, OW = OW_, PH = PH_, PW = PW_; + // cache[i] stores the answer of the i-th line after + // pooling along the W dimension. + uint8_t* cache[3] = {static_cast(ws.get(0)), + static_cast(ws.get(1)), + static_cast(ws.get(2))}; + uint8_t* odd = static_cast(ws.get(3)); + uint8_t* even = static_cast(ws.get(4)); + + int ih_next = 0; + // "good" area means we can use SIMD to accelerate. 
+ auto get_good_area = [](int I, int /* O */, int P, int& O_from, int& O_to) { + // x*2 - P >= 0; 2x >= P; x >= P/2 + O_from = (P + 1) / 2; + // x*2 - P + 3 <= I; x*2 <= I+P-3; x <= (I+P-3)/2 + O_to = (I + P - 3) / 2 + 1; + // we must have I >= 2 to ensure O_from <= O_to + }; + int OW_from, OW_to; + get_good_area(IW, OW, PW, OW_from, OW_to); + auto process_cache = [&](int ih) { + const uint8_t* __restrict sptr = src + ih * IW; + auto tmp = cache[2]; + cache[2] = cache[1]; + cache[1] = cache[0]; + cache[0] = tmp; + // cache 0 is used to store the current answer. + auto run_single = [&](int ow) { + int iw = ow * 2 - PW; + uint8_t res = std::numeric_limits::lowest(); + if (iw + 0 >= 0 && iw + 0 < IW) { + res = std::max(res, sptr[iw + 0]); + } + if (iw + 1 >= 0 && iw + 1 < IW) { + res = std::max(res, sptr[iw + 1]); + } + if (iw + 2 >= 0 && iw + 2 < IW) { + res = std::max(res, sptr[iw + 2]); + } + cache[0][ow] = res; + }; + // build odd/even + int iw = 0; + int odd_offset = 0, even_offset = 0; + + for (; iw + 32 <= IW; iw += 32) { + uint8x16_t s0, s1; + s0 = vld1q_u8(sptr + iw + 0); + s1 = vld1q_u8(sptr + iw + 16); + uint8x16x2_t d = vuzpq_u8(s0, s1); + vst1q_u8(even + even_offset, d.val[0]); + vst1q_u8(odd + odd_offset, d.val[1]); + even_offset += 16; + odd_offset += 16; + } + for (; iw < IW; ++iw) { + if (iw & 1) + odd[odd_offset++] = sptr[iw]; + else + even[even_offset++] = sptr[iw]; + } + int ow = 0; + for (; ow < OW_from; ++ow) + run_single(ow); + if (PW & 1) { + for (; ow + 16 <= OW_to; ow += 16) { + uint8x16_t d, s0, s1, s2; + s0 = vld1q_u8(odd + ow - (PW >> 1) - 1); + s1 = vld1q_u8(even + ow - (PW >> 1)); + s2 = vld1q_u8(odd + ow - (PW >> 1)); + d = vmaxq_u8(vmaxq_u8(s0, s1), s2); + vst1q_u8(cache[0] + ow, d); + } + } else { + for (; ow + 16 <= OW_to; ow += 16) { + uint8x16_t d, s0, s1, s2; + s0 = vld1q_u8(even + ow - (PW >> 1)); + s1 = vld1q_u8(odd + ow - (PW >> 1)); + s2 = vld1q_u8(even + ow - (PW >> 1) + 1); + d = vmaxq_u8(vmaxq_u8(s0, s1), s2); + vst1q_u8(cache[0] + ow, d); + } + } + for (; ow < OW; ++ow) + run_single(ow); + }; + for (int oh = 0; oh < OH; ++oh) { + uint8_t* __restrict dptr = dst + oh * OW; + int ih_from = std::min(IH, std::max(0, oh * 2 - PH)); + int ih_to = std::min(IH, std::max(0, oh * 2 - PH + 3)); + while (ih_next < ih_to) { + process_cache(ih_next++); + } + if (ih_to - ih_from == 3) { + int ow = 0; + for (; ow + 16 <= OW; ow += 16) { + uint8x16_t d, s0, s1, s2; + s0 = vld1q_u8(cache[0] + ow); + s1 = vld1q_u8(cache[1] + ow); + s2 = vld1q_u8(cache[2] + ow); + d = vmaxq_u8(vmaxq_u8(s0, s1), s2); + vst1q_u8(dptr + ow, d); + } + for (; ow < OW; ++ow) { + dptr[ow] = std::max(std::max(cache[0][ow], cache[1][ow]), + cache[2][ow]); + } + } else { + std::memcpy(dptr, cache[0], sizeof(uint8_t) * OW); + for (int i = 1; i < ih_to - ih_from; ++i) { + int ow = 0; + for (; ow + 16 <= OW; ow += 16) { + uint8x16_t d, s; + s = vld1q_u8(cache[i] + ow); + d = vld1q_u8(dptr + ow); + d = vmaxq_u8(d, s); + vst1q_u8(dptr + ow, d); + } + for (; ow < OW; ++ow) { + dptr[ow] = std::max(dptr[ow], cache[i][ow]); + } + } + } + } +} + +} // namespace arm_common +} // namespace megdnn + // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_int8.h b/dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_int8.h new file mode 100644 index 00000000..68c884f0 --- /dev/null +++ b/dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_int8.h @@ -0,0 +1,32 @@ +/** + * \file dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_int8.h + * MegEngine is Licensed under the Apache 
License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once +#include +#include +#include "src/common/utils.h" + +namespace megdnn { +namespace arm_common { + +void do_max_pooling_3x3_s2x2_int8_NEON(const int8_t* src, int8_t* dst, + size_t IH, size_t IW, size_t OH, + size_t OW, size_t PH, size_t PW, + const WorkspaceBundle& boudle); + +void do_max_pooling_3x3_s2x2_uint8_NEON(const uint8_t* src, uint8_t* dst, + size_t IH, size_t IW, size_t OH, + size_t OW, size_t PH, size_t PW, + const WorkspaceBundle& boudle); + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/pooling/do_max_pooling_w2x2_s2x2.cpp b/dnn/src/arm_common/pooling/do_max_pooling_w2x2_s2x2.cpp new file mode 100644 index 00000000..6005bac5 --- /dev/null +++ b/dnn/src/arm_common/pooling/do_max_pooling_w2x2_s2x2.cpp @@ -0,0 +1,52 @@ +/** + * \file dnn/src/arm_common/pooling/do_max_pooling_w2x2_s2x2.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/pooling/do_max_pooling_w2x2_s2x2.h" +#include "src/arm_common/pooling/pooling_helper.h" + +namespace megdnn { +namespace arm_common { + +void pooling_max_w2x2_s2x2(const int8_t* src, int8_t* dst, size_t N, size_t C, + size_t IH, size_t IW, size_t OH, size_t OW) { + for (size_t nc = 0; nc < N * C; ++nc) { + for (size_t oh = 0; oh < OH; ++oh) { + size_t ih = oh << 1; + const int8_t* __restrict sptr0 = src + (ih + 0) * IW; + const int8_t* __restrict sptr1 = src + (ih + 1) * IW; + int8_t* __restrict dptr = dst + oh * OW; + size_t ow = 0; + for (; ow + 8 <= OW; ow += 8) { + // process 2x16 to produce 1x8 elements at a time. + int8x16_t vs0 = vld1q_s8(sptr0), vs1 = vld1q_s8(sptr1); + int8x16_t vi = vmaxq_s8(vs0, vs1); + int8x8_t vd = vpmax_s8(vget_low_s8(vi), vget_high_s8(vi)); + vst1_s8(dptr, vd); + sptr0 += 16; + sptr1 += 16; + dptr += 8; + } + for (; ow < OW; ++ow) { + dptr[0] = std::max(std::max(sptr0[0], sptr0[1]), + std::max(sptr1[0], sptr1[1])); + sptr0 += 2; + sptr1 += 2; + dptr += 1; + } + } + src += IH * IW; + dst += OH * OW; + } +} +} // namespace arm_common +} // namespace megdnn +// vim: syntax=cpp.doxygen + diff --git a/dnn/src/arm_common/pooling/do_max_pooling_w2x2_s2x2.h b/dnn/src/arm_common/pooling/do_max_pooling_w2x2_s2x2.h new file mode 100644 index 00000000..df730504 --- /dev/null +++ b/dnn/src/arm_common/pooling/do_max_pooling_w2x2_s2x2.h @@ -0,0 +1,23 @@ +/** + * \file dnn/src/arm_common/pooling/do_max_pooling_w2x2_s2x2.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
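A standalone illustration, not part of the patch, of the 2x2 / stride-2 trick used in pooling_max_w2x2_s2x2 above: one vertical vmaxq_s8 plus one pairwise vpmax_s8 turns a 2x16 input tile into 8 outputs. It assumes an ARM target with NEON:

#include <arm_neon.h>
#include <cstdint>

// dst[0..7] = max of each 2x2 block of the tile formed by row0[0..15], row1[0..15].
static inline void max_pool_2x16_tile(const int8_t* row0, const int8_t* row1,
                                      int8_t* dst) {
    int8x16_t v0 = vld1q_s8(row0);
    int8x16_t v1 = vld1q_s8(row1);
    int8x16_t vert = vmaxq_s8(v0, v1);  // column-wise max of the two rows
    // adjacent-pair max of the 16 column maxima -> 8 pooled values, in order
    int8x8_t out = vpmax_s8(vget_low_s8(vert), vget_high_s8(vert));
    vst1_s8(dst, out);
}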
+ */ +#pragma once +#include "src/fallback/pooling/opr_impl.h" + +namespace megdnn { +namespace arm_common { +void pooling_max_w2x2_s2x2(const int8_t* src, int8_t* dst, size_t N, size_t C, + size_t IH, size_t IW, size_t OH, size_t OW); +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen + diff --git a/dnn/src/arm_common/pooling/do_max_pooling_w4x4_s2x2.cpp b/dnn/src/arm_common/pooling/do_max_pooling_w4x4_s2x2.cpp new file mode 100644 index 00000000..f5cb8093 --- /dev/null +++ b/dnn/src/arm_common/pooling/do_max_pooling_w4x4_s2x2.cpp @@ -0,0 +1,293 @@ +/** + * \file dnn/src/arm_common/pooling/do_max_pooling_w4x4_s2x2.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/pooling/do_max_pooling_w4x4_s2x2.h" +#include "src/arm_common/pooling/pooling_helper.h" + +namespace megdnn { +namespace arm_common { + +void do_max_pooling_w4x4_s2x2_float_NEON(const dt_float32* src, dt_float32* dst, + DType src_dtype, const int IH, + const int IW, const int OH, + const int OW, const int PH, + const int PW) { + const int window = 4; + const int stride = 2; + using Pooler = MaxPooler<16, dt_float32, float, float>; + int oh = 0; + for (; oh < OH && -PH + stride * oh < 0; ++oh) { + int ow = 0; + for (; ow < OW; ++ow) { + do_pxl_naive(oh, ow, src, dst, src_dtype, IH, IW, + OH, OW, PH, PW, stride, stride); + } + } + for (; oh < OH && -PH + stride * oh + window <= IH; ++oh) { + int ow = 0; + for (; ow < OW && -PW + stride * ow < 0; ++ow) { + do_pxl_naive(oh, ow, src, dst, src_dtype, IH, IW, + OH, OW, PH, PW, stride, stride); + } + dt_float32 last_hf_res = -std::numeric_limits::infinity(); + int ih = -PH + stride * oh, iw = -PW + stride * ow; + if (-PW + stride * ow + window <= IW) { + float32x4_t i0 = vld1q_f32(src + (ih + 0) * IW + iw), + i1 = vld1q_f32(src + (ih + 1) * IW + iw), + i2 = vld1q_f32(src + (ih + 2) * IW + iw), + i3 = vld1q_f32(src + (ih + 3) * IW + iw); + float32x4_t sum0 = vmaxq_f32(vmaxq_f32(i0, i1), vmaxq_f32(i2, i3)); + float32x2_t t = vpmax_f32(vget_low_f32(sum0), vget_high_f32(sum0)); + dst[oh * OW + ow] = + std::max(vget_lane_f32(t, 0), vget_lane_f32(t, 1)); + last_hf_res = vget_lane_f32(t, 1); + ow += 1; + } + for (; ow + 1 < OW && -PW + stride * (ow + 1) + window <= IW; ow += 2) { + iw = -PW + stride * (ow + 1); + float32x4_t i0 = vld1q_f32(src + (ih + 0) * IW + iw), + i1 = vld1q_f32(src + (ih + 1) * IW + iw), + i2 = vld1q_f32(src + (ih + 2) * IW + iw), + i3 = vld1q_f32(src + (ih + 3) * IW + iw); + float32x4_t sum0 = vmaxq_f32(vmaxq_f32(i0, i1), vmaxq_f32(i2, i3)); + float32x2_t t = vpmax_f32(vget_low_f32(sum0), vget_high_f32(sum0)); + dst[oh * OW + ow + 0] = std::max(vget_lane_f32(t, 0), last_hf_res); + dst[oh * OW + ow + 1] = + std::max(vget_lane_f32(t, 0), vget_lane_f32(t, 1)); + last_hf_res = vget_lane_f32(t, 1); + } + for (; ow < OW; ++ow) { + do_pxl_naive(oh, ow, src, dst, src_dtype, IH, IW, + OH, OW, PH, PW, stride, stride); + } + } + for (; oh < OH; ++oh) { + int ow = 0; + for (; ow < OW; ++ow) { + do_pxl_naive(oh, ow, src, dst, src_dtype, IH, IW, + OH, OW, PH, PW, stride, stride); + } + } +} + +void do_max_pooling_w4x4_s2x2_int8_NEON(const int8_t* src, int8_t* dst, + DType src_dtype, const int IH, + const int 
IW, const int OH, + const int OW, const int PH, + const int PW) { + const int window = 4; + const int stride = 2; + using Pooler = MaxPooler<16, dt_qint8, int8_t, float>; + int oh = 0; + for (; oh < OH && -PH + stride * oh < 0; ++oh) { + int ow = 0; + for (; ow < OW; ++ow) { + do_pxl_naive(oh, ow, src, dst, src_dtype, IH, IW, + OH, OW, PH, PW, stride, stride); + } + } + for (; oh < OH && -PH + stride * oh + window <= IH; ++oh) { + int ow = 0; + for (; ow < OW && -PW + stride * ow < 0; ++ow) { + do_pxl_naive(oh, ow, src, dst, src_dtype, IH, IW, + OH, OW, PH, PW, stride, stride); + } + int8_t last_res = std::numeric_limits::lowest(); + int ih = -PH + stride * oh, iw = -PW + stride * ow; + if (-PW + stride * (ow + 6) + window <= IW) { + int8x16_t i0 = vld1q_s8(src + (ih + 0) * IW + iw), + i1 = vld1q_s8(src + (ih + 1) * IW + iw), + i2 = vld1q_s8(src + (ih + 2) * IW + iw), + i3 = vld1q_s8(src + (ih + 3) * IW + iw); + int8x16_t sum0 = vmaxq_s8(vmaxq_s8(i0, i1), vmaxq_s8(i2, i3)); + int8x8_t t = vpmax_s8(vget_low_s8(sum0), vget_high_s8(sum0)); +#define cb(i) \ + dst[oh * OW + ow + i] = \ + std::max(vget_lane_s8(t, i), vget_lane_s8(t, i + 1)); + UNROLL_CALL_NOWRAPPER(7, cb) +#undef cb + last_res = vget_lane_s8(t, 7); + ow += 7; + } + for (; ow + 7 < OW && -PW + stride * (ow + 7) + window <= IW; ow += 8) { + iw = -PW + stride * (ow + 1); + int8x16_t i0 = vld1q_s8(src + (ih + 0) * IW + iw), + i1 = vld1q_s8(src + (ih + 1) * IW + iw), + i2 = vld1q_s8(src + (ih + 2) * IW + iw), + i3 = vld1q_s8(src + (ih + 3) * IW + iw); + int8x16_t sum0 = vmaxq_s8(vmaxq_s8(i0, i1), vmaxq_s8(i2, i3)); + int8x8_t t = vpmax_s8(vget_low_s8(sum0), vget_high_s8(sum0)); + dst[oh * OW + ow + 0] = std::max(vget_lane_s8(t, 0), last_res); +#define cb(i) \ + dst[oh * OW + ow + i + 1] = \ + std::max(vget_lane_s8(t, i), vget_lane_s8(t, i + 1)); + UNROLL_CALL_NOWRAPPER(7, cb) +#undef cb + last_res = vget_lane_s8(t, 7); + } + for (; ow < OW; ++ow) { + do_pxl_naive(oh, ow, src, dst, src_dtype, IH, IW, + OH, OW, PH, PW, stride, stride); + } + } + for (; oh < OH; ++oh) { + int ow = 0; + for (; ow < OW; ++ow) { + do_pxl_naive(oh, ow, src, dst, src_dtype, IH, IW, + OH, OW, PH, PW, stride, stride); + } + } +} + +void do_max_pooling_w4x4_s2x2_uint8_NEON(const uint8_t* src, uint8_t* dst, + DType src_dtype, const int IH, + const int IW, const int OH, + const int OW, const int PH, + const int PW) { + const int window = 4; + const int stride = 2; + using Pooler = MaxPooler<16, dt_quint8, uint8_t, float>; + int oh = 0; + for (; oh < OH && -PH + stride * oh < 0; ++oh) { + int ow = 0; + for (; ow < OW; ++ow) { + do_pxl_naive(oh, ow, src, dst, src_dtype, IH, IW, + OH, OW, PH, PW, stride, stride); + } + } + for (; oh < OH && -PH + stride * oh + window <= IH; ++oh) { + int ow = 0; + for (; ow < OW && -PW + stride * ow < 0; ++ow) { + do_pxl_naive(oh, ow, src, dst, src_dtype, IH, IW, + OH, OW, PH, PW, stride, stride); + } + uint8_t last_res = std::numeric_limits::lowest(); + int ih = -PH + stride * oh, iw = -PW + stride * ow; + if (-PW + stride * (ow + 6) + window <= IW) { + uint8x16_t i0 = vld1q_u8(src + (ih + 0) * IW + iw), + i1 = vld1q_u8(src + (ih + 1) * IW + iw), + i2 = vld1q_u8(src + (ih + 2) * IW + iw), + i3 = vld1q_u8(src + (ih + 3) * IW + iw); + uint8x16_t sum0 = vmaxq_u8(vmaxq_u8(i0, i1), vmaxq_u8(i2, i3)); + uint8x8_t t = vpmax_u8(vget_low_u8(sum0), vget_high_u8(sum0)); +#define cb(i) \ + dst[oh * OW + ow + i] = \ + std::max(vget_lane_u8(t, i), vget_lane_u8(t, i + 1)); + UNROLL_CALL_NOWRAPPER(7, cb) +#undef cb + last_res = vget_lane_u8(t, 7); 
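The last_res carried between iterations above relies on a simple identity for window 4 / stride 2: adjacent outputs overlap by two columns, so each output is the max of two neighbouring column-pair maxima, and the last pair can be reused by the next tile. A scalar reference of that row pass, illustration only:

#include <algorithm>
#include <cstddef>
#include <cstdint>

// col_max[] holds the vertical max of the four input rows for consecutive columns;
// carry_in is the pair max carried over from the previous tile.
static void pool_row_w4_s2_ref(const int8_t* col_max, std::size_t n_pairs,
                               int8_t carry_in, int8_t* dst) {
    int8_t prev = carry_in;
    for (std::size_t j = 0; j < n_pairs; ++j) {
        int8_t pair = std::max(col_max[2 * j], col_max[2 * j + 1]);
        dst[j] = std::max(prev, pair);  // 4-wide window = previous pair + this pair
        prev = pair;
    }
}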
+ ow += 7; + } + for (; ow + 7 < OW && -PW + stride * (ow + 7) + window <= IW; ow += 8) { + iw = -PW + stride * (ow + 1); + uint8x16_t i0 = vld1q_u8(src + (ih + 0) * IW + iw), + i1 = vld1q_u8(src + (ih + 1) * IW + iw), + i2 = vld1q_u8(src + (ih + 2) * IW + iw), + i3 = vld1q_u8(src + (ih + 3) * IW + iw); + uint8x16_t sum0 = vmaxq_u8(vmaxq_u8(i0, i1), vmaxq_u8(i2, i3)); + uint8x8_t t = vpmax_u8(vget_low_u8(sum0), vget_high_u8(sum0)); + dst[oh * OW + ow + 0] = std::max(vget_lane_u8(t, 0), last_res); +#define cb(i) \ + dst[oh * OW + ow + i + 1] = \ + std::max(vget_lane_u8(t, i), vget_lane_u8(t, i + 1)); + UNROLL_CALL_NOWRAPPER(7, cb) +#undef cb + last_res = vget_lane_u8(t, 7); + } + for (; ow < OW; ++ow) { + do_pxl_naive(oh, ow, src, dst, src_dtype, IH, IW, + OH, OW, PH, PW, stride, stride); + } + } + for (; oh < OH; ++oh) { + int ow = 0; + for (; ow < OW; ++ow) { + do_pxl_naive(oh, ow, src, dst, src_dtype, IH, IW, + OH, OW, PH, PW, stride, stride); + } + } +} +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +void do_max_pooling_w4x4_s2x2_float16_NEON(const __fp16* src, __fp16* dst, + DType src_dtype, const int IH, + const int IW, const int OH, + const int OW, const int PH, + const int PW) { + const int window = 4; + const int stride = 2; + using Pooler = MaxPooler<16, dt_float16, __fp16, __fp16>; + int oh = 0; + for (; oh < OH && -PH + stride * oh < 0; ++oh) { + int ow = 0; + for (; ow < OW; ++ow) { + do_pxl_naive(oh, ow, src, dst, src_dtype, IH, IW, + OH, OW, PH, PW, stride, stride); + } + } + for (; oh < OH && -PH + stride * oh + window <= IH; ++oh) { + int ow = 0; + for (; ow < OW && -PW + stride * ow < 0; ++ow) { + do_pxl_naive(oh, ow, src, dst, src_dtype, IH, IW, + OH, OW, PH, PW, stride, stride); + } + __fp16 last_hf_res = -std::numeric_limits::infinity(); + int ih = -PH + stride * oh, iw = -PW + stride * ow; + if (-PW + stride * (ow + 2) + window <= IW) { + float16x8_t i0 = vld1q_f16(src + (ih + 0) * IW + iw), + i1 = vld1q_f16(src + (ih + 1) * IW + iw), + i2 = vld1q_f16(src + (ih + 2) * IW + iw), + i3 = vld1q_f16(src + (ih + 3) * IW + iw); + float16x8_t sum0 = vmaxq_f16(vmaxq_f16(i0, i1), vmaxq_f16(i2, i3)); + float16x4_t t = vpmax_f16(vget_low_f16(sum0), vget_high_f16(sum0)); + dst[oh * OW + ow] = + std::max(vget_lane_f16(t, 0), vget_lane_f16(t, 1)); + dst[oh * OW + ow + 1] = + std::max(vget_lane_f16(t, 1), vget_lane_f16(t, 2)); + dst[oh * OW + ow + 2] = + std::max(vget_lane_f16(t, 2), vget_lane_f16(t, 3)); + last_hf_res = vget_lane_f16(t, 3); + ow += 3; + } + for (; ow + 3 < OW && -PW + stride * (ow + 3) + window <= IW; ow += 4) { + iw = -PW + stride * (ow + 1); + float16x8_t i0 = vld1q_f16(src + (ih + 0) * IW + iw), + i1 = vld1q_f16(src + (ih + 1) * IW + iw), + i2 = vld1q_f16(src + (ih + 2) * IW + iw), + i3 = vld1q_f16(src + (ih + 3) * IW + iw); + float16x8_t sum0 = vmaxq_f16(vmaxq_f16(i0, i1), vmaxq_f16(i2, i3)); + float16x4_t t = vpmax_f16(vget_low_f16(sum0), vget_high_f16(sum0)); + dst[oh * OW + ow + 0] = std::max(vget_lane_f16(t, 0), last_hf_res); + dst[oh * OW + ow + 1] = + std::max(vget_lane_f16(t, 0), vget_lane_f16(t, 1)); + dst[oh * OW + ow + 2] = + std::max(vget_lane_f16(t, 1), vget_lane_f16(t, 2)); + dst[oh * OW + ow + 3] = + std::max(vget_lane_f16(t, 2), vget_lane_f16(t, 3)); + last_hf_res = vget_lane_f16(t, 3); + } + for (; ow < OW; ++ow) { + do_pxl_naive(oh, ow, src, dst, src_dtype, IH, IW, + OH, OW, PH, PW, stride, stride); + } + } + for (; oh < OH; ++oh) { + int ow = 0; + for (; ow < OW; ++ow) { + do_pxl_naive(oh, ow, src, dst, src_dtype, IH, IW, + OH, OW, PH, PW, stride, 
stride); + } + } +} +#endif +} // namespace arm_common +} // namespace megdnn +// vim: syntax=cpp.doxygen + diff --git a/dnn/src/arm_common/pooling/do_max_pooling_w4x4_s2x2.h b/dnn/src/arm_common/pooling/do_max_pooling_w4x4_s2x2.h new file mode 100644 index 00000000..cf7839b3 --- /dev/null +++ b/dnn/src/arm_common/pooling/do_max_pooling_w4x4_s2x2.h @@ -0,0 +1,44 @@ +/** + * \file dnn/src/arm_common/pooling/do_max_pooling_w4x4_s2x2.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#pragma once +#include "src/fallback/pooling/opr_impl.h" + +namespace megdnn { +namespace arm_common { + +void do_max_pooling_w4x4_s2x2_float_NEON(const dt_float32* src, dt_float32* dst, + DType src_dtype, const int IH, + const int IW, const int OH, + const int OW, const int PH, + const int PW); +void do_max_pooling_w4x4_s2x2_int8_NEON(const int8_t* src, int8_t* dst, + DType src_dtype, const int IH, + const int IW, const int OH, + const int OW, const int PH, + const int PW); +void do_max_pooling_w4x4_s2x2_uint8_NEON(const uint8_t* src, uint8_t* dst, + DType src_dtype, const int IH, + const int IW, const int OH, + const int OW, const int PH, + const int PW); +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +void do_max_pooling_w4x4_s2x2_float16_NEON(const __fp16* src, __fp16* dst, + DType src_dtype, const int IH, + const int IW, const int OH, + const int OW, const int PH, + const int PW); +#endif +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen + diff --git a/dnn/src/arm_common/pooling/opr_impl.cpp b/dnn/src/arm_common/pooling/opr_impl.cpp new file mode 100644 index 00000000..ce3cb723 --- /dev/null +++ b/dnn/src/arm_common/pooling/opr_impl.cpp @@ -0,0 +1,135 @@ +/** + * \file dnn/src/arm_common/pooling/opr_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
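A hedged usage sketch, not part of the patch, of the float kernel declared above on a single contiguous (n, c) plane. The helper name run_float_plane and the output-size arithmetic (window 4, stride 2) are the sketch's own assumptions, and dtype::Float32() is assumed to be an acceptable DType argument here:

#include "src/arm_common/pooling/do_max_pooling_w4x4_s2x2.h"

static void run_float_plane(const float* src_plane, float* dst_plane, int IH,
                            int IW, int PH, int PW) {
    // standard pooling shape: O = (I + 2*P - window) / stride + 1
    const int OH = (IH + 2 * PH - 4) / 2 + 1;
    const int OW = (IW + 2 * PW - 4) / 2 + 1;
    megdnn::arm_common::do_max_pooling_w4x4_s2x2_float_NEON(
            src_plane, dst_plane, megdnn::dtype::Float32(), IH, IW, OH, OW, PH,
            PW);
}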
+ */ +#include "src/arm_common/pooling/opr_impl.h" +#include "src/arm_common/pooling/algo.h" +#include "src/common/metahelper.h" + +using namespace megdnn; +using namespace arm_common; + +class PoolingImpl::AlgoPack : NonCopyableObj { + AlgoFilterxModexStride1 algo_filterx_modex_stride1; + AlgoFilter2ModexStride2 algo_filter2_modex_stride2; + AlgoFilter3MaxStride2 algo_filter3_max_stride2; + AlgoFilter3AverageStride2 algo_filter3_average_stride2; + AlgoFilter4MaxStride2 algo_filter4_max_stride2; + AlgoFilter5MaxStride2 algo_filter5_max_stride2; + AlgoInt8Filter2MaxStride2 algo_int8_filter2_max_stride2; + AlgoInt8Filter3MaxStride2 algo_int8_filter3_max_stride2; + +public: + AlgoPack() { + all_algos.emplace_back(&algo_filterx_modex_stride1); + all_algos.emplace_back(&algo_filter2_modex_stride2); + all_algos.emplace_back(&algo_filter3_max_stride2); + all_algos.emplace_back(&algo_filter3_average_stride2); + all_algos.emplace_back(&algo_filter4_max_stride2); + all_algos.emplace_back(&algo_filter5_max_stride2); + all_algos.emplace_back(&algo_int8_filter2_max_stride2); + all_algos.emplace_back(&algo_int8_filter3_max_stride2); + } + SmallVector all_algos; +}; + +PoolingImpl::PoolingKernSizeParam PoolingImpl::make_pooling_kern_szie_param( + fallback::PoolingImpl* opr, const TensorLayout& src, + const TensorLayout& dst) { + auto safe_u32 = [](size_t v) -> uint32_t { + megdnn_assert(v <= std::numeric_limits::max(), + "value too large: %zu", v); + return v; + }; + return {safe_u32(src.shape[0]), + safe_u32(src.shape[1]), + {{safe_u32(src.shape[2]), safe_u32(src.shape[3])}}, + {{safe_u32(dst.shape[2]), safe_u32(dst.shape[3])}}, + {{safe_u32(opr->param().pad_h), safe_u32(opr->param().pad_w)}}, + {{safe_u32(opr->param().window_h), + safe_u32(opr->param().window_w)}}, + {{safe_u32(opr->param().stride_h), + safe_u32(opr->param().stride_w)}}, + src.dtype, + dst.dtype, + opr->handle(), + opr->param().format, + opr->param().mode}; +}; + +PoolingImpl::PoolingKernParam PoolingImpl::make_pooling_kern_param( + fallback::PoolingImpl* opr, _megdnn_tensor_in src, + _megdnn_tensor_out dst, _megdnn_workspace workspace) { + PoolingKernParam ret; + static_cast(ret) = + make_pooling_kern_szie_param(opr, src.layout, dst.layout); + ret.src_ptr = src.raw_ptr; + ret.dst_ptr = dst.raw_ptr; + ret.workspace_ptr = workspace.raw_ptr; + ret.workspace_size = workspace.size; + return ret; +}; + +size_t PoolingImpl::get_workspace_in_bytes(const TensorLayout& src, + const TensorLayout& dst) { + bool find_algo = false; + static AlgoPack m_algo_pack; + auto param = make_pooling_kern_szie_param(this, src, dst); + for (auto& m_algo : m_algo_pack.all_algos) { + if (m_algo->usable(param)) { + find_algo = true; + break; + } + } + size_t arm_common_workspace = 0; + + //! 
When multi-thread, every thread has its own workspace + size_t nr_threads = static_cast(handle()) + ->megcore_dispatcher() + ->nr_threads(); + if ((param.src_type.category() == DTypeCategory::FLOAT || + param.src_type == dtype::Int8{} || + param.src_type.enumv() == DTypeEnum::QuantizedS8 || + param.src_type.enumv() == DTypeEnum::Quantized8Asymm) && + param.filter[0] == param.filter[1] && + (param.filter[0] == 3 || param.filter[0] == 5) && + param.format == Param::Format::NCHW && + (param.mode == Mode::MAX || + (param.mode == Mode::AVERAGE && param.filter[0] == 3)) && + param.stride[0] == 2 && param.stride[1] == 2 && param.isz[0] >= 2 && + param.isz[1] >= 2) { + WorkspaceBundle ws = get_bundle(param); + arm_common_workspace = ws.total_size_in_bytes() * nr_threads; + } + + if (find_algo) { + return arm_common_workspace; + } else { + auto fallback_worksapce = + fallback::PoolingImpl::get_workspace_in_bytes(src, dst); + return fallback_worksapce; + } +} + +void PoolingImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) { + check_exec(src.layout, dst.layout, workspace.size); + auto param = make_pooling_kern_param(this, src, dst, workspace); + static AlgoPack m_algo_pack; + for (auto& m_algo : m_algo_pack.all_algos) { + if (m_algo->usable(param)) { + m_algo->exec(param); + return; + } + } + fallback::PoolingImpl::exec(src, dst, workspace); +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/pooling/opr_impl.h b/dnn/src/arm_common/pooling/opr_impl.h new file mode 100644 index 00000000..9b715991 --- /dev/null +++ b/dnn/src/arm_common/pooling/opr_impl.h @@ -0,0 +1,92 @@ +/** + * \file dnn/src/arm_common/pooling/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
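A minimal self-contained sketch of the dispatch pattern exec above uses: walk a fixed algorithm list, run the first one whose usable() accepts the parameters, otherwise fall back to the generic implementation. The FakeAlgo/FakeParam names are hypothetical, illustration only:

#include <functional>
#include <vector>

struct FakeParam { int filter = 3, stride = 2; };

struct FakeAlgo {
    std::function<bool(const FakeParam&)> usable;
    std::function<void(const FakeParam&)> exec;
};

static void dispatch(const std::vector<FakeAlgo>& algos, const FakeParam& p,
                     const std::function<void(const FakeParam&)>& fallback) {
    for (const auto& algo : algos) {
        if (algo.usable(p)) {  // first usable algorithm wins
            algo.exec(p);
            return;
        }
    }
    fallback(p);  // nothing matched: use the generic implementation
}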
+ */ +#pragma once +#include "src/fallback/pooling/opr_impl.h" + +namespace megdnn { +namespace arm_common { + +class PoolingImpl final : public fallback::PoolingImpl { +public: + using fallback::PoolingImpl::PoolingImpl; + void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes(const TensorLayout&, + const TensorLayout&) override; + + static size_t constexpr MAX_SPATIAL_DIM = 2; + + struct PoolingKernSizeParam { + uint32_t n, ic; + std::array isz, osz; + std::array padding, filter, stride; + DType src_type, dst_type; + Handle* handle; + Param::Format format; + Mode mode; + }; + + struct PoolingKernParam : public PoolingKernSizeParam { + void* src_ptr; + void* dst_ptr; + void* workspace_ptr; + size_t workspace_size; + + template + const T* src() const { + src_type.assert_is_compatible_ctype(); + return static_cast(src_ptr); + } + + template + T* dst() const { + dst_type.assert_is_compatible_ctype(); + return static_cast(dst_ptr); + } + + template + T* workspace() const { + return static_cast(workspace_ptr); + } + }; + + PoolingKernSizeParam make_pooling_kern_szie_param( + fallback::PoolingImpl* opr, const TensorLayout& src, + const TensorLayout& dst); + + PoolingKernParam make_pooling_kern_param(fallback::PoolingImpl* opr, + _megdnn_tensor_in src, + _megdnn_tensor_out dst, + _megdnn_workspace workspace); + class AlgoBase : public detail::Algorithm { + public: + virtual ~AlgoBase() = default; + virtual bool usable(const PoolingKernSizeParam& param) const = 0; + virtual void exec(const PoolingKernParam& param) const = 0; + }; + +private: + class AlgoFilterxModexStride1; + class AlgoFilter2ModexStride2; + class AlgoFilter3MaxStride2; + class AlgoFilter3AverageStride2; + class AlgoFilter4MaxStride2; + class AlgoFilter5MaxStride2; + class AlgoInt8Filter2MaxStride2; + class AlgoInt8Filter3MaxStride2; + class AlgoPack; +}; +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen + diff --git a/dnn/src/arm_common/pooling/pooling_helper.h b/dnn/src/arm_common/pooling/pooling_helper.h new file mode 100644 index 00000000..e14128a3 --- /dev/null +++ b/dnn/src/arm_common/pooling/pooling_helper.h @@ -0,0 +1,1040 @@ +/** + * \file dnn/src/arm_common/pooling/pooling_hleper.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "megdnn/dtype.h" +#include "src/arm_common/pooling/do_max_pooling_3x3_s2x2_float.h" +#include "src/arm_common/pooling/do_max_pooling_3x3_s2x2_float16.h" +#include "src/common/unroll_macro.h" +#include "src/common/utils.h" + +#include "src/arm_common/simd_macro/marm_neon.h" + +namespace { + +/* ======================= MeanPooler ======================== */ +using namespace megdnn; +/** + * \brief Mean mode for pooling + * \tparam area the pooling area size, FH * FW + * \tparam dtype the input type + * \tparam ctype the inner raw type + * \tparam comp_type compute type + */ +template +struct MeanPoolerCommon { + //! 
the neon register size is 16 bytes(128 bits) + static constexpr int SIMD_WIDTH = 16 / sizeof(ctype); + static constexpr comp_type coef = static_cast(1.0f) / area; + comp_type res; + MeanPoolerCommon() : res(0) {} + void feed(const ctype* val) { res += *val; } +}; +template +constexpr comp_type MeanPoolerCommon::coef; + +template +struct MeanInPooler : MeanPoolerCommon { + using ctype = _ctype; + //! `MIDOUT_CASE_NUM` is a unique int id + static constexpr int MIDOUT_CASE_NUM = 1; + MeanInPooler(DType) : MeanPoolerCommon() {} + void post(ctype* dst) { + this->res *= this->coef; + *dst = this->res; + } +}; + +template +struct MeanInRoundPooler : MeanPoolerCommon { + using ctype = _ctype; + void post(ctype* dst) { + this->res *= this->coef; + *dst = std::round(this->res); + } +}; + +template +struct MeanInPooler + : MeanInRoundPooler { + static constexpr int MIDOUT_CASE_NUM = 2; + MeanInPooler(DType) : MeanInRoundPooler() {} +}; + +template +struct MeanInPooler + : MeanInRoundPooler { + static constexpr int MIDOUT_CASE_NUM = 3; + uint8_t zero_point; + uint8_t feed_cnt; + MeanInPooler(DType dtype) + : MeanInRoundPooler(), + zero_point(dtype.param().zero_point), + feed_cnt(0) {} + void feed(const uint8_t* val) { + this->res += *val; + feed_cnt += 1; + } + void post(uint8_t* dst) { + this->res = + this->res + static_cast(area - feed_cnt) * zero_point; + this->res *= this->coef; + *dst = std::round(this->res); + } +}; + +template +struct NeonMeanPooler; + +template +struct NeonMeanPooler { + using ctype = float; + static constexpr int MIDOUT_CASE_NUM = 1; + static constexpr int SIMD_WIDTH = 4; + + static const float32x4_t coef; + float32x4_t res; + NeonMeanPooler(DType) : res(vdupq_n_f32(0.0f)) {} + void feed(const float* val) { res = vaddq_f32(res, vld1q_f32(val)); } + void post(float* dst) { + res = vmulq_f32(res, coef); + vst1q_f32(dst, res); + } +}; +template +const float32x4_t NeonMeanPooler::coef = + vdupq_n_f32(1.0f / area); + +template +struct NeonMeanPooler { + using ctype = int8_t; + static constexpr int MIDOUT_CASE_NUM = 2; + static constexpr int SIMD_WIDTH = 16; + + static const float32x4_t coef; +#if MEGDNN_ARMV7 + static const float32x4_t fzero; + static const float32x4_t half; + static const float32x4_t neg_half; +#endif + float32x4_t sum0; + float32x4_t sum1; + float32x4_t sum2; + float32x4_t sum3; + NeonMeanPooler(DType) + : sum0(vdupq_n_f32(0.0f)), + sum1(vdupq_n_f32(0.0f)), + sum2(vdupq_n_f32(0.0f)), + sum3(vdupq_n_f32(0.0f)) {} + void feed(const int8_t* val) { + int8x16_t item = vld1q_s8(val); + float32x4_t tmp; +#define cb(i) \ + tmp = (float32x4_t){static_cast(vgetq_lane_s8(item, 4 * i + 0)), \ + static_cast(vgetq_lane_s8(item, 4 * i + 1)), \ + static_cast(vgetq_lane_s8(item, 4 * i + 2)), \ + static_cast(vgetq_lane_s8(item, 4 * i + 3))}; \ + sum##i = vaddq_f32(sum##i, tmp); + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + } + void post(int8_t* dst) { +#define cb(i) sum##i = vmulq_f32(sum##i, coef); + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + +#if MEGDNN_AARCH64 +#define cb(i) auto res##i = vcvtaq_s32_f32(sum##i); +#elif MEGDNN_ARMV7 +#define cb(i) \ + auto inc##i = vbslq_f32(vcgeq_f32(sum##i, fzero), half, neg_half); \ + sum##i = vaddq_f32(sum##i, inc##i); \ + auto res##i = vcvtq_s32_f32(sum##i); +#else +#error "unsupport android arch" +#endif + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + + int8x8_t merge_res1 = + vqmovn_s16(vcombine_s16(vqmovn_s32(res0), vqmovn_s32(res1))); + int8x8_t merge_res2 = + vqmovn_s16(vcombine_s16(vqmovn_s32(res2), vqmovn_s32(res3))); + + vst1q_s8(dst, 
vcombine_s8(merge_res1, merge_res2)); + } +}; +template +const float32x4_t NeonMeanPooler::coef = + vdupq_n_f32(1.0f / area); +#if MEGDNN_ARMV7 +template +const float32x4_t NeonMeanPooler::fzero = + vdupq_n_f32(0.f); +template +const float32x4_t NeonMeanPooler::half = + vdupq_n_f32(0.5f); +template +const float32x4_t NeonMeanPooler::neg_half = + vdupq_n_f32(-0.5f); +#endif + +template +struct NeonMeanPooler { + using ctype = uint8_t; + static constexpr int MIDOUT_CASE_NUM = 3; + static constexpr int SIMD_WIDTH = 16; + + static const float32x4_t coef; +#if MEGDNN_ARMV7 + static const float32x4_t fzero; + static const float32x4_t half; + static const float32x4_t neg_half; +#endif + float32x4_t sum0; + float32x4_t sum1; + float32x4_t sum2; + float32x4_t sum3; + NeonMeanPooler(DType) + : sum0(vdupq_n_f32(0.0f)), + sum1(vdupq_n_f32(0.0f)), + sum2(vdupq_n_f32(0.0f)), + sum3(vdupq_n_f32(0.0f)) {} + void feed(const uint8_t* val) { + uint8x16_t item = vld1q_u8(val); + float32x4_t tmp; +#define cb(i) \ + tmp = (float32x4_t){static_cast(vgetq_lane_u8(item, 4 * i + 0)), \ + static_cast(vgetq_lane_u8(item, 4 * i + 1)), \ + static_cast(vgetq_lane_u8(item, 4 * i + 2)), \ + static_cast(vgetq_lane_u8(item, 4 * i + 3))}; \ + sum##i = vaddq_f32(sum##i, tmp); + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + } + void post(uint8_t* dst) { +#define cb(i) sum##i = vmulq_f32(sum##i, coef); + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + +#if MEGDNN_AARCH64 +#define cb(i) auto res##i = vcvtaq_s32_f32(sum##i); +#elif MEGDNN_ARMV7 +#define cb(i) \ + auto inc##i = vbslq_f32(vcgeq_f32(sum##i, fzero), half, neg_half); \ + sum##i = vaddq_f32(sum##i, inc##i); \ + auto res##i = vcvtq_s32_f32(sum##i); +#else +#error "unsupport android arch" +#endif + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + + uint8x8_t merge_res1 = vqmovn_u16(vreinterpretq_u16_s16( + vcombine_s16(vqmovn_s32(res0), vqmovn_s32(res1)))); + uint8x8_t merge_res2 = vqmovn_u16(vreinterpretq_u16_s16( + vcombine_s16(vqmovn_s32(res2), vqmovn_s32(res3)))); + + vst1q_u8(dst, vcombine_u8(merge_res1, merge_res2)); + } +}; +template +const float32x4_t NeonMeanPooler::coef = + vdupq_n_f32(1.0f / area); +#if MEGDNN_ARMV7 +template +const float32x4_t NeonMeanPooler::fzero = + vdupq_n_f32(0.f); +template +const float32x4_t NeonMeanPooler::half = + vdupq_n_f32(0.5f); +template +const float32x4_t NeonMeanPooler::neg_half = + vdupq_n_f32(-0.5f); +#endif + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +template +struct NeonMeanPooler { + using ctype = __fp16; + static constexpr int MIDOUT_CASE_NUM = 4; + static constexpr int SIMD_WIDTH = 8; + + static const float16x8_t coef; + float16x8_t res; + NeonMeanPooler(DType) : res(vdupq_n_f16(0.0f)) {} + void feed(const __fp16* val) { res = vaddq_f16(res, vld1q_f16(val)); } + void post(__fp16* dst) { + res = vmulq_fix_f16(res, coef); + vst1q_f16(dst, res); + } +}; +template +const float16x8_t NeonMeanPooler::coef = + vdupq_n_f16(1.0f / area); +#endif + +/* ======================= MaxPooler ======================== */ + +template +struct MaxPooler { + using ctype = _ctype; + static constexpr int MIDOUT_CASE_NUM = 11; + static constexpr int SIMD_WIDTH = 16 / sizeof(ctype); + + static const ctype outsider; + ctype res; + MaxPooler(DType) : res(DTypeTrait::min()) {} + void feed(const ctype* val) { res = std::max(res, *val); } + void post(ctype* dst) { *dst = res; } +}; +template +const ctype MaxPooler::outsider = + DTypeTrait::min(); + +template +struct NeonMaxPooler; + +template +struct NeonMaxPooler { + using ctype = float; + static constexpr int 
MIDOUT_CASE_NUM = 11; + static constexpr int SIMD_WIDTH = 4; + + float32x4_t res; + NeonMaxPooler(DType) : res(vdupq_n_f32(DTypeTrait::min())) {} + void feed(const float* val) { res = vmaxq_f32(res, vld1q_f32(val)); } + void post(float* dst) { vst1q_f32(dst, res); } +}; + +template +struct NeonMaxPooler { + using ctype = int8_t; + static constexpr int MIDOUT_CASE_NUM = 12; + static constexpr int SIMD_WIDTH = 16; + + int8x16_t res; + NeonMaxPooler(DType) + : res(vdupq_n_s8(std::numeric_limits::lowest())) {} + void feed(const int8_t* val) { res = vmaxq_s8(res, vld1q_s8(val)); } + void post(int8_t* dst) { vst1q_s8(dst, res); } +}; + +template +struct NeonMaxPooler { + using ctype = uint8_t; + static constexpr int MIDOUT_CASE_NUM = 13; + static constexpr int SIMD_WIDTH = 16; + + uint8x16_t res; + NeonMaxPooler(DType) + : res(vdupq_n_u8(std::numeric_limits::lowest())) {} + void feed(const uint8_t* val) { res = vmaxq_u8(res, vld1q_u8(val)); } + void post(uint8_t* dst) { vst1q_u8(dst, res); } +}; + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +template +struct NeonMaxPooler { + using ctype = __fp16; + static constexpr int MIDOUT_CASE_NUM = 14; + static constexpr int SIMD_WIDTH = 8; + + float16x8_t res; + NeonMaxPooler(DType) : res(vdupq_n_f16(DTypeTrait::min())) {} + void feed(const __fp16* val) { res = vmaxq_f16(res, vld1q_f16(val)); } + void post(__fp16* dst) { vst1q_f16(dst, res); } +}; +#endif + +template +void do_pxl_naive(int oh, int ow, const typename Pooler::ctype* src, + typename Pooler::ctype* dst, DType src_dtype, const int IH, + const int IW, const int OH, const int OW, const int PH, + const int PW, const int SH, const int SW) { + MEGDNN_MARK_USED_VAR(OH); + Pooler pooler(src_dtype); + rep(wh, window) rep(ww, window) { + int ih = -PH + oh * SH + wh; + int iw = -PW + ow * SW + ww; + if (ih >= 0 && iw >= 0 && ih < IH && iw < IW) { + pooler.feed(src + ih * IW + iw); + } + } + pooler.post(dst + oh * OW + ow); +} + +namespace detail { + +template +struct do_pxl_2x2_pack_proxy { + static void gao(int oh, int ow, const typename Pooler::ctype* src, + typename Pooler::ctype* dst, DType, const int IH, + const int IW, const int OH, const int OW, const int PH, + const int PW); +}; + +template <> +struct do_pxl_2x2_pack_proxy, + Pooling::Mode::AVERAGE> { + static void gao(int oh, int ow, const dt_float32* src, dt_float32* dst, + DType, const int IH, const int IW, const int OH, + const int OW, const int PH, const int PW) { + MEGDNN_MARK_USED_VAR(IH); + MEGDNN_MARK_USED_VAR(OH); + static const auto avg_coef = vdupq_n_f32(0.25f); + int ih = -PH + 2 * oh; + int iw = -PW + 2 * ow; + auto i00 = vld1q_f32(src + (ih + 0) * IW + (iw + 0)), + i01 = vld1q_f32(src + (ih + 0) * IW + (iw + 4)), + i10 = vld1q_f32(src + (ih + 1) * IW + (iw + 0)), + i11 = vld1q_f32(src + (ih + 1) * IW + (iw + 4)); + auto sum0 = vaddq_f32(i00, i10), sum1 = vaddq_f32(i01, i11); + auto vlow = vpadd_f32(vget_low_f32(sum0), vget_high_f32(sum0)); + auto vhigh = vpadd_f32(vget_low_f32(sum1), vget_high_f32(sum1)); + auto comb = vcombine_f32(vlow, vhigh); + auto result = vmulq_f32(comb, avg_coef); + vst1q_f32(dst + oh * OW + ow, result); + } +}; + +template <> +struct do_pxl_2x2_pack_proxy, + Pooling::Mode::AVERAGE> { + static void gao(int oh, int ow, const int8_t* src, int8_t* dst, DType, + const int IH, const int IW, const int OH, const int OW, + const int PH, const int PW) { + MEGDNN_MARK_USED_VAR(IH); + MEGDNN_MARK_USED_VAR(OH); + int ih = -PH + 2 * oh; + int iw = -PW + 2 * ow; + auto zero = vdupq_n_s16(0); + auto one = vdupq_n_s16(1); + 
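The zero/one vectors set up just above feed the rounding fixup applied a few lines further on: vqrshrn_n_s16(x, 2) effectively computes (x + 2) >> 2, which sends exact halves toward positive infinity, so one is first subtracted from negative sums to obtain round-half-away-from-zero. A scalar illustration of that adjustment, assuming an arithmetic right shift as on the targets here; illustration only:

#include <cassert>

// Mimics vqrshrn_n_s16(x, 2) without the narrowing/saturation part.
static int rounding_shr2(int x) { return (x + 2) >> 2; }

// With the fixup (subtract 1 from negative sums first), the 2x2 average of four
// int8 values rounds to nearest with ties away from zero.
static int avg4_with_fixup(int sum) { return rounding_shr2(sum < 0 ? sum - 1 : sum); }

int main() {
    assert(rounding_shr2(-6) == -1);    // plain rounding shift: -1.5 -> -1 (toward +inf)
    assert(avg4_with_fixup(-6) == -2);  // fixup: -1.5 -> -2 (away from zero)
    assert(avg4_with_fixup(6) == 2);    //  1.5 ->  2
    assert(avg4_with_fixup(-5) == -1);  // -1.25 -> -1
    assert(avg4_with_fixup(7) == 2);    //  1.75 ->  2
    return 0;
}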
auto i00 = vld1q_s8(src + (ih + 0) * IW + (iw + 0)), + i01 = vld1q_s8(src + (ih + 0) * IW + (iw + 16)), + i10 = vld1q_s8(src + (ih + 1) * IW + (iw + 0)), + i11 = vld1q_s8(src + (ih + 1) * IW + (iw + 16)); + int16x8_t sum0 = vaddl_s8(vget_low_s8(i00), vget_low_s8(i10)), + sum1 = vaddl_s8(vget_high_s8(i00), vget_high_s8(i10)), + sum2 = vaddl_s8(vget_low_s8(i01), vget_low_s8(i11)), + sum3 = vaddl_s8(vget_high_s8(i01), vget_high_s8(i11)); + + auto vlow0 = vpadd_s16(vget_low_s16(sum0), vget_high_s16(sum0)); + auto vhigh0 = vpadd_s16(vget_low_s16(sum1), vget_high_s16(sum1)); + auto vlow1 = vpadd_s16(vget_low_s16(sum2), vget_high_s16(sum2)); + auto vhigh1 = vpadd_s16(vget_low_s16(sum3), vget_high_s16(sum3)); + auto comb0 = vcombine_s16(vlow0, vhigh0); + auto comb1 = vcombine_s16(vlow1, vhigh1); + + auto fixup0 = vcltq_s16(comb0, zero); + comb0 = vsubq_s16(comb0, vbslq_s16(fixup0, one, zero)); + //! as vqrshrn_n_s16 is round to positive infinity + auto result0 = vqrshrn_n_s16(comb0, 2); + auto fixup1 = vcltq_s16(comb1, zero); + comb1 = vsubq_s16(comb1, vbslq_s16(fixup1, one, zero)); + auto result1 = vqrshrn_n_s16(comb1, 2); + vst1q_s8(dst + oh * OW + ow, vcombine_s8(result0, result1)); + } +}; + +template <> +struct do_pxl_2x2_pack_proxy, + Pooling::Mode::AVERAGE> { + static void gao(int oh, int ow, const uint8_t* src, uint8_t* dst, DType, + const int IH, const int IW, const int OH, const int OW, + const int PH, const int PW) { + MEGDNN_MARK_USED_VAR(IH); + MEGDNN_MARK_USED_VAR(OH); + int ih = -PH + 2 * oh; + int iw = -PW + 2 * ow; + auto i00 = vld1q_u8(src + (ih + 0) * IW + (iw + 0)), + i01 = vld1q_u8(src + (ih + 0) * IW + (iw + 16)), + i10 = vld1q_u8(src + (ih + 1) * IW + (iw + 0)), + i11 = vld1q_u8(src + (ih + 1) * IW + (iw + 16)); + uint16x8_t sum0 = vaddl_u8(vget_low_u8(i00), vget_low_u8(i10)), + sum1 = vaddl_u8(vget_high_u8(i00), vget_high_u8(i10)), + sum2 = vaddl_u8(vget_low_u8(i01), vget_low_u8(i11)), + sum3 = vaddl_u8(vget_high_u8(i01), vget_high_u8(i11)); + + auto vlow0 = vpadd_u16(vget_low_u16(sum0), vget_high_u16(sum0)); + auto vhigh0 = vpadd_u16(vget_low_u16(sum1), vget_high_u16(sum1)); + auto vlow1 = vpadd_u16(vget_low_u16(sum2), vget_high_u16(sum2)); + auto vhigh1 = vpadd_u16(vget_low_u16(sum3), vget_high_u16(sum3)); + auto comb0 = vcombine_u16(vlow0, vhigh0); + auto comb1 = vcombine_u16(vlow1, vhigh1); + + auto result0 = vqrshrn_n_u16(comb0, 2); + auto result1 = vqrshrn_n_u16(comb1, 2); + vst1q_u8(dst + oh * OW + ow, vcombine_u8(result0, result1)); + } +}; + +template <> +struct do_pxl_2x2_pack_proxy, + Pooling::Mode::MAX> { + static void gao(int oh, int ow, const dt_float32* src, dt_float32* dst, + DType, const int IH, const int IW, const int OH, + const int OW, const int PH, const int PW) { + MEGDNN_MARK_USED_VAR(IH); + MEGDNN_MARK_USED_VAR(OH); + int ih = -PH + 2 * oh; + int iw = -PW + 2 * ow; + auto i00 = vld1q_f32(src + (ih + 0) * IW + (iw + 0)), + i01 = vld1q_f32(src + (ih + 0) * IW + (iw + 4)), + i10 = vld1q_f32(src + (ih + 1) * IW + (iw + 0)), + i11 = vld1q_f32(src + (ih + 1) * IW + (iw + 4)); + auto sum0 = vmaxq_f32(i00, i10), sum1 = vmaxq_f32(i01, i11); + auto vlow = vpmax_f32(vget_low_f32(sum0), vget_high_f32(sum0)); + auto vhigh = vpmax_f32(vget_low_f32(sum1), vget_high_f32(sum1)); + auto comb = vcombine_f32(vlow, vhigh); + vst1q_f32(dst + oh * OW + ow, comb); + } +}; + +template <> +struct do_pxl_2x2_pack_proxy, + Pooling::Mode::MAX> { + static void gao(int oh, int ow, const int8_t* src, int8_t* dst, DType, + const int IH, const int IW, const int OH, const int OW, + 
const int PH, const int PW) { + MEGDNN_MARK_USED_VAR(IH); + MEGDNN_MARK_USED_VAR(OH); + int ih = -PH + 2 * oh; + int iw = -PW + 2 * ow; + auto i00 = vld1q_s8(src + (ih + 0) * IW + (iw + 0)), + i01 = vld1q_s8(src + (ih + 0) * IW + (iw + 16)), + i10 = vld1q_s8(src + (ih + 1) * IW + (iw + 0)), + i11 = vld1q_s8(src + (ih + 1) * IW + (iw + 16)); + auto sum0 = vmaxq_s8(i00, i10), sum1 = vmaxq_s8(i01, i11); + auto vlow = vpmax_s8(vget_low_s8(sum0), vget_high_s8(sum0)); + auto vhigh = vpmax_s8(vget_low_s8(sum1), vget_high_s8(sum1)); + auto comb = vcombine_s8(vlow, vhigh); + vst1q_s8(dst + oh * OW + ow, comb); + } +}; + +template <> +struct do_pxl_2x2_pack_proxy, + Pooling::Mode::MAX> { + static void gao(int oh, int ow, const uint8_t* src, uint8_t* dst, DType, + const int IH, const int IW, const int OH, const int OW, + const int PH, const int PW) { + MEGDNN_MARK_USED_VAR(IH); + MEGDNN_MARK_USED_VAR(OH); + int ih = -PH + 2 * oh; + int iw = -PW + 2 * ow; + auto i00 = vld1q_u8(src + (ih + 0) * IW + (iw + 0)), + i01 = vld1q_u8(src + (ih + 0) * IW + (iw + 16)), + i10 = vld1q_u8(src + (ih + 1) * IW + (iw + 0)), + i11 = vld1q_u8(src + (ih + 1) * IW + (iw + 16)); + auto sum0 = vmaxq_u8(i00, i10), sum1 = vmaxq_u8(i01, i11); + auto vlow = vpmax_u8(vget_low_u8(sum0), vget_high_u8(sum0)); + auto vhigh = vpmax_u8(vget_low_u8(sum1), vget_high_u8(sum1)); + auto comb = vcombine_u8(vlow, vhigh); + vst1q_u8(dst + oh * OW + ow, comb); + } +}; + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +template <> +struct do_pxl_2x2_pack_proxy, + Pooling::Mode::AVERAGE> { + static void gao(int oh, int ow, const __fp16* src, __fp16* dst, DType, + const int IH, const int IW, const int OH, const int OW, + const int PH, const int PW) { + MEGDNN_MARK_USED_VAR(IH); + MEGDNN_MARK_USED_VAR(OH); + static const auto avg_coef = vdupq_n_f16(0.25f); + int ih = -PH + 2 * oh; + int iw = -PW + 2 * ow; + auto i00 = vld1q_f16(src + (ih + 0) * IW + (iw + 0)), + i01 = vld1q_f16(src + (ih + 0) * IW + (iw + 8)), + i10 = vld1q_f16(src + (ih + 1) * IW + (iw + 0)), + i11 = vld1q_f16(src + (ih + 1) * IW + (iw + 8)); + auto sum0 = vaddq_f16(i00, i10), sum1 = vaddq_f16(i01, i11); + auto vlow = vpadd_f16(vget_low_f16(sum0), vget_high_f16(sum0)); + auto vhigh = vpadd_f16(vget_low_f16(sum1), vget_high_f16(sum1)); + auto comb = vcombine_f16(vlow, vhigh); + auto result = vmulq_f16(comb, avg_coef); + vst1q_f16(dst + oh * OW + ow, result); + } +}; + +template <> +struct do_pxl_2x2_pack_proxy, + Pooling::Mode::MAX> { + static void gao(int oh, int ow, const __fp16* src, __fp16* dst, DType, + const int IH, const int IW, const int OH, const int OW, + const int PH, const int PW) { + MEGDNN_MARK_USED_VAR(IH); + MEGDNN_MARK_USED_VAR(OH); + int ih = -PH + 2 * oh; + int iw = -PW + 2 * ow; + auto i00 = vld1q_f16(src + (ih + 0) * IW + (iw + 0)), + i01 = vld1q_f16(src + (ih + 0) * IW + (iw + 8)), + i10 = vld1q_f16(src + (ih + 1) * IW + (iw + 0)), + i11 = vld1q_f16(src + (ih + 1) * IW + (iw + 8)); + auto sum0 = vmaxq_f16(i00, i10), sum1 = vmaxq_f16(i01, i11); + auto vlow = vpmax_f16(vget_low_f16(sum0), vget_high_f16(sum0)); + auto vhigh = vpmax_f16(vget_low_f16(sum1), vget_high_f16(sum1)); + auto comb = vcombine_f16(vlow, vhigh); + vst1q_f16(dst + oh * OW + ow, comb); + } +}; +#endif + +} // namespace detail + +template +void do_pxl_2x2_pack(int oh, int ow, const typename Pooler::ctype* src, + typename Pooler::ctype* dst, DType src_dtype, const int IH, + const int IW, const int OH, const int OW, const int PH, + const int PW) { + detail::do_pxl_2x2_pack_proxy::gao( + oh, ow, src, 
dst, src_dtype, IH, IW, OH, OW, PH, PW); +} + +template +void do_pxl_compact_packed(int oh, int ow, + const typename NeonPooler::ctype* src, + typename NeonPooler::ctype* dst, DType src_dtype, + const int IH, const int IW, const int OH, + const int OW, const int PH, const int PW) { + MEGDNN_MARK_USED_VAR(IH); + MEGDNN_MARK_USED_VAR(OH); + NeonPooler pooler(src_dtype); + rep(wh, window) rep(ww, window) { + int ih = -PH + oh + wh; + int iw = -PW + ow + ww; + pooler.feed(src + ih * IW + iw); + } + pooler.post(dst + oh * OW + ow); +} + +template +void do_pooling_compact(const typename Pooler::ctype* src, + typename Pooler::ctype* dst, DType src_dtype, + const int IH, const int IW, const int OH, const int OW, + const int PH, const int PW) { + static_assert(std::is_same::value, + "ctype of Pooler and NeonPooler is not the same"); + const int stride = 1; + int oh = 0; + for (; oh < OH && oh - PH < 0; ++oh) { + int ow = 0; + for (; ow < OW; ++ow) { + do_pxl_naive(oh, ow, src, dst, src_dtype, IH, IW, + OH, OW, PH, PW, stride, stride); + } + } + for (; oh < OH && oh - PH + window <= IH; ++oh) { + int ow = 0; + for (; ow < OW && ow - PW < 0; ++ow) { + do_pxl_naive(oh, ow, src, dst, src_dtype, IH, IW, + OH, OW, PH, PW, stride, stride); + } + for (; ow + NeonPooler::SIMD_WIDTH <= OW && + ow + NeonPooler::SIMD_WIDTH - 1 - PW + window <= IW; + ow += NeonPooler::SIMD_WIDTH) { + do_pxl_compact_packed( + oh, ow, src, dst, src_dtype, IH, IW, OH, OW, PH, PW); + } + for (; ow < OW; ++ow) { + do_pxl_naive(oh, ow, src, dst, src_dtype, IH, IW, + OH, OW, PH, PW, stride, stride); + } + } + for (; oh < OH; ++oh) { + int ow = 0; + for (; ow < OW; ++ow) { + do_pxl_naive(oh, ow, src, dst, src_dtype, IH, IW, + OH, OW, PH, PW, stride, stride); + } + } +} + +template +void do_pooling_2x2(const typename Pooler::ctype* src, + typename Pooler::ctype* dst, DType src_dtype, const int IH, + const int IW, const int OH, const int OW, const int PH, + const int PW) { + const int window = 2; + const int stride = 2; + int oh = 0; + for (; oh < OH && -PH + stride * oh < 0; ++oh) { + int ow = 0; + for (; ow < OW; ++ow) { + do_pxl_naive(oh, ow, src, dst, src_dtype, IH, IW, + OH, OW, PH, PW, stride, stride); + } + } + for (; oh < OH && -PH + stride * oh + window <= IH; ++oh) { + int ow = 0; + for (; ow < OW && -PW + stride * ow < 0; ++ow) { + do_pxl_naive(oh, ow, src, dst, src_dtype, IH, IW, + OH, OW, PH, PW, stride, stride); + } + for (; ow + Pooler::SIMD_WIDTH <= OW && + -PW + stride * (ow + Pooler::SIMD_WIDTH - 1) + window <= IW; + ow += Pooler::SIMD_WIDTH) { + do_pxl_2x2_pack(oh, ow, src, dst, src_dtype, IH, IW, + OH, OW, PH, PW); + } + for (; ow < OW; ++ow) { + do_pxl_naive(oh, ow, src, dst, src_dtype, IH, IW, + OH, OW, PH, PW, stride, stride); + } + } + for (; oh < OH; ++oh) { + int ow = 0; + for (; ow < OW; ++ow) { + do_pxl_naive(oh, ow, src, dst, src_dtype, IH, IW, + OH, OW, PH, PW, stride, stride); + } + } +} +inline float32x4_t vload(const float* src) { + return vld1q_f32(src); +} +inline float32x4x2_t vload2(const float* src) { + return vld2q_f32(src); +} +inline float32x4_t vdupq(float a) { + return vdupq_n_f32(a); +} +inline float32x4_t vaddq(float32x4_t a, float32x4_t b) { + return vaddq_f32(a, b); +} +inline float32x4_t vmulq(float32x4_t a, float32x4_t b) { + return vmulq_f32(a, b); +} +inline float32x4_t vmax(float32x4_t a, float32x4_t b) { + return vmaxq_f32(a, b); +} +inline void vset(float* src, float32x4_t dst) { + vst1q_f32(src, dst); +} +inline float32x4x2_t vunzip(float32x4_t a, float32x4_t b) { + return 
vuzpq_f32(a, b); +} + +inline int8x16_t vload(const int8_t* src) { + return vld1q_s8(src); +} +inline int8x16_t vmax(int8x16_t a, int8x16_t b) { + return vmaxq_s8(a, b); +} +inline void vset(int8_t* src, int8x16_t dst) { + vst1q_s8(src, dst); +} +inline int8x16x2_t vunzip(int8x16_t a, int8x16_t b) { + return vuzpq_s8(a, b); +} + +inline uint8x16_t vload(const uint8_t* src) { + return vld1q_u8(src); +} +inline uint8x16_t vmax(uint8x16_t a, uint8x16_t b) { + return vmaxq_u8(a, b); +} +inline void vset(uint8_t* src, uint8x16_t dst) { + vst1q_u8(src, dst); +} +inline uint8x16x2_t vunzip(uint8x16_t a, uint8x16_t b) { + return vuzpq_u8(a, b); +} + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +inline float16x8_t vload(const __fp16* src) { + return vld1q_f16(src); +} +inline float16x8x2_t vload2(const __fp16* src) { + return vld2q_f16(src); +} +inline float16x8_t vaddq(float16x8_t a, float16x8_t b) { + return vaddq_f16(a, b); +} +inline float16x8_t vmulq(float16x8_t a, float16x8_t b) { + return vmulq_fix_f16(a, b); +} +inline float16x8_t vdupq(__fp16 a) { + return vdupq_n_f16(a); +} +inline float16x8_t vmax(float16x8_t a, float16x8_t b) { + return vmaxq_f16(a, b); +} +inline void vset(__fp16* src, float16x8_t dst) { + vst1q_f16(src, dst); +} +inline float16x8x2_t vunzip(float16x8_t a, float16x8_t b) { + return vuzpq_f16(a, b); +} +#endif + +// because the __fp16 can't get the lowest value, so add dtype +template +void do_max_pooling_w5x5_s2x2_NEON(const ctype* src, ctype* dst, const int IH, + const int IW, const int OH, const int OW, + const int PH, const int PW, + const WorkspaceBundle& ws, + const int MEGDNN_SIMD_WIDTH) { + ctype* cache[5] = { + static_cast(ws.get(0)), static_cast(ws.get(1)), + static_cast(ws.get(2)), static_cast(ws.get(3)), + static_cast(ws.get(4))}; + ctype* odd = static_cast(ws.get(5)); + ctype* even = static_cast(ws.get(6)); + int ih_next = 0; + int OW_from = (PW + 1) / 2, OW_to = (IW + PW - 5) / 2 + 1; + auto process_cache = [&](int ih) { + const ctype* __restrict sptr = src + ih * IW; + auto tmp = cache[4]; + for (auto i = 4; i >= 1; --i) + cache[i] = cache[i - 1]; + cache[0] = tmp; + auto run_single = [&](int ow) { + int iw = ow * 2 - PW; + ctype res = std::numeric_limits::lowest(); + for (auto i = 0; i < 5; ++i) + if (iw + i >= 0 && iw + i < IW) + res = std::max(res, sptr[iw + i]); + cache[0][ow] = res; + }; + int iw = 0; + int odd_offset = 0, even_offset = 0; + for (; iw + 2 * MEGDNN_SIMD_WIDTH <= IW; iw += 2 * MEGDNN_SIMD_WIDTH) { + auto s0 = vload(sptr + iw + 0); + auto s1 = vload(sptr + iw + MEGDNN_SIMD_WIDTH); + auto d = vunzip(s0, s1); + vset(even + even_offset, d.val[0]); + vset(odd + odd_offset, d.val[1]); + even_offset += MEGDNN_SIMD_WIDTH; + odd_offset += MEGDNN_SIMD_WIDTH; + } + for (; iw < IW; ++iw) { + if (iw & 1) + odd[odd_offset++] = sptr[iw]; + else + even[even_offset++] = sptr[iw]; + } + int ow = 0; + for (; ow < OW_from; ++ow) + run_single(ow); + if (PW & 1) { + for (; ow + MEGDNN_SIMD_WIDTH <= OW_to; ow += MEGDNN_SIMD_WIDTH) { + auto s0 = vload(odd + ow - (PW >> 1) - 1); + auto s1 = vload(even + ow - (PW >> 1)); + auto s2 = vload(odd + ow - (PW >> 1)); + auto s3 = vload(even + ow - (PW >> 1) + 1); + auto s4 = vload(odd + ow - (PW >> 1) + 1); + auto d = vmax(s0, vmax(vmax(s1, s2), vmax(s3, s4))); + vset(cache[0] + ow, d); + } + } else { + for (; ow + MEGDNN_SIMD_WIDTH <= OW_to; ow += MEGDNN_SIMD_WIDTH) { + auto s0 = vload(even + ow - (PW >> 1)); + auto s1 = vload(odd + ow - (PW >> 1)); + auto s2 = vload(even + ow - (PW >> 1) + 1); + auto s3 = vload(odd + 
ow - (PW >> 1) + 1); + auto s4 = vload(even + ow - (PW >> 1) + 2); + auto d = vmax(s0, vmax(vmax(s1, s2), vmax(s3, s4))); + vset(cache[0] + ow, d); + } + } + for (; ow < OW; ++ow) + run_single(ow); + }; + + for (int oh = 0; oh < OH; ++oh) { + ctype* __restrict dptr = dst + oh * OW; + int ih_from = std::min(IH, std::max(0, oh * 2 - PH)); + int ih_to = std::min(IH, std::max(0, oh * 2 - PH + 5)); + while (ih_next < ih_to) + process_cache(ih_next++); + if (ih_to - ih_from == 5) { + int ow = 0; + for (; ow + MEGDNN_SIMD_WIDTH <= OW; ow += MEGDNN_SIMD_WIDTH) { + auto s0 = vload(cache[0] + ow); + auto s1 = vload(cache[1] + ow); + auto s2 = vload(cache[2] + ow); + auto s3 = vload(cache[3] + ow); + auto s4 = vload(cache[4] + ow); + auto d = vmax(s0, vmax(vmax(s1, s2), vmax(s3, s4))); + vset(dptr + ow, d); + } + for (; ow < OW; ++ow) + dptr[ow] = std::max({cache[0][ow], cache[1][ow], cache[2][ow], + cache[3][ow], cache[4][ow]}); + } else { + std::memcpy(dptr, cache[0], sizeof(ctype) * OW); + for (int i = 1; i < ih_to - ih_from; ++i) { + int ow = 0; + for (; ow + MEGDNN_SIMD_WIDTH <= OW; ow += MEGDNN_SIMD_WIDTH) { + auto s = vload(cache[i] + ow); + auto d = vload(dptr + ow); + d = vmax(d, s); + vset(dptr + ow, d); + } + for (; ow < OW; ++ow) + dptr[ow] = std::max(dptr[ow], cache[i][ow]); + } + } + } +} + +template +void do_average_pooling_3x3_s2x2_NEON(const ctype* src, ctype* dst, size_t IH_, + size_t IW_, size_t OH_, size_t OW_, + size_t PH_, size_t PW_, + const WorkspaceBundle& ws, + const int MEGDNN_SIMD_WIDTH) { + int IH = IH_, IW = IW_, OH = OH_, OW = OW_, PH = PH_, PW = PW_; + // cache[i] stores the answer of the i-th line after + // pooling along the W dimension. + ctype* cache[3] = {static_cast(ws.get(0)), + static_cast(ws.get(1)), + static_cast(ws.get(2))}; + ctype* odd = static_cast(ws.get(3)); + ctype* even = static_cast(ws.get(4)); + int ih_next = 0; + // "good" area means we can use SIMD to accelerate. + auto get_good_area = [](int I, int /* O */, int P, int& O_from, int& O_to) { + // x*2 - P >= 0; 2x >= P; x >= P/2 + O_from = (P + 1) / 2; + // x*2 - P + 3 <= I; x*2 <= I+P-3; x <= (I+P-3)/2 + O_to = (I + P - 3) / 2 + 1; + // we must have I >= 2 to ensure O_from <= O_to + }; + int OW_from, OW_to; + get_good_area(IW, OW, PW, OW_from, OW_to); + auto process_cache = [&](int ih) { + const ctype* __restrict sptr = src + ih * IW; + auto tmp = cache[2]; + cache[2] = cache[1]; + cache[1] = cache[0]; + cache[0] = tmp; + // cache 0 is used to store the current answer. 
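The template above is written once and becomes float, fp16, int8 or uint8 code purely through the overloaded vload/vaddq/vmulq/vmax/vset helpers defined earlier. A minimal standalone sketch of that overload pattern, with renamed helpers so it stays independent of the real ones; assumes an ARM target with NEON, illustration only:

#include <arm_neon.h>
#include <cstdint>

static inline float32x4_t vload_(const float* p) { return vld1q_f32(p); }
static inline int8x16_t vload_(const int8_t* p) { return vld1q_s8(p); }
static inline float32x4_t vmax_(float32x4_t a, float32x4_t b) { return vmaxq_f32(a, b); }
static inline int8x16_t vmax_(int8x16_t a, int8x16_t b) { return vmaxq_s8(a, b); }

// One body, two element types: overload resolution picks the right intrinsics.
template <typename ctype>
static inline auto max_of_two_loads(const ctype* a, const ctype* b) {
    return vmax_(vload_(a), vload_(b));
}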
+ auto run_single = [&](int ow) { + int iw = ow * 2 - PW; + ctype res = 0; + if (iw + 0 >= 0 && iw + 0 < IW) { + res += sptr[iw + 0]; + } + if (iw + 1 >= 0 && iw + 1 < IW) { + res += sptr[iw + 1]; + } + if (iw + 2 >= 0 && iw + 2 < IW) { + res += sptr[iw + 2]; + } + cache[0][ow] = res; + }; + // build odd/even + int iw = 0; + int odd_offset = 0, even_offset = 0; + + for (; iw + 2 * MEGDNN_SIMD_WIDTH <= IW; iw += 2 * MEGDNN_SIMD_WIDTH) { + auto s0 = vload2(sptr + iw); + vset(even + even_offset, s0.val[0]); + vset(odd + odd_offset, s0.val[1]); + even_offset += MEGDNN_SIMD_WIDTH; + odd_offset += MEGDNN_SIMD_WIDTH; + } + for (; iw < IW; ++iw) { + if (iw & 1) + odd[odd_offset++] = sptr[iw]; + else + even[even_offset++] = sptr[iw]; + } + int ow = 0; + for (; ow < OW_from; ++ow) + run_single(ow); + if (PW & 1) { + for (; ow + MEGDNN_SIMD_WIDTH <= OW_to; ow += MEGDNN_SIMD_WIDTH) { + auto s0 = vload(odd + ow - (PW >> 1) - 1); + auto s1 = vload(even + ow - (PW >> 1)); + auto s2 = vload(odd + ow - (PW >> 1)); + auto d = vaddq(vaddq(s0, s1), s2); + vset(cache[0] + ow, d); + } + } else { + for (; ow + MEGDNN_SIMD_WIDTH <= OW_to; ow += MEGDNN_SIMD_WIDTH) { + auto s0 = vload(even + ow - (PW >> 1)); + auto s1 = vload(odd + ow - (PW >> 1)); + auto s2 = vload(even + ow - (PW >> 1) + 1); + auto d = vaddq(vaddq(s0, s1), s2); + vset(cache[0] + ow, d); + } + } + for (; ow < OW; ++ow) + run_single(ow); + }; + for (int oh = 0; oh < OH; ++oh) { + ctype* __restrict dptr = dst + oh * OW; + int ih_from = std::min(IH, std::max(0, oh * 2 - PH)); + int ih_to = std::min(IH, std::max(0, oh * 2 - PH + 3)); + while (ih_next < ih_to) { + process_cache(ih_next++); + } + ctype factor = (1.0f / 9); + auto coef = vdupq(factor); + if (ih_to - ih_from == 3) { + int ow = 0; + for (; ow + MEGDNN_SIMD_WIDTH <= OW; ow += MEGDNN_SIMD_WIDTH) { + auto s0 = vload(cache[0] + ow); + auto s1 = vload(cache[1] + ow); + auto s2 = vload(cache[2] + ow); + auto d = vaddq(vaddq(s0, s1), s2); + d = vmulq(d, coef); + vset(dptr + ow, d); + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; ow < OW; ++ow) { + dptr[ow] = + (cache[0][ow] + cache[1][ow] + cache[2][ow]) * factor; + } + } else { + std::memcpy(dptr, cache[0], sizeof(ctype) * OW); + int i = 1; + for (; i < ih_to - ih_from; ++i) { + int ow = 0; + for (; ow + MEGDNN_SIMD_WIDTH <= OW; ow += MEGDNN_SIMD_WIDTH) { + auto s = vload(cache[i] + ow); + auto d = vload(dptr + ow); + d = vaddq(d, s); + vset(dptr + ow, d); + } + for (; ow < OW; ++ow) { + dptr[ow] = (dptr[ow] + cache[i][ow]); + } + } + int ow = 0; + for (; ow + MEGDNN_SIMD_WIDTH <= OW; ow += MEGDNN_SIMD_WIDTH) { + auto d = vload(dptr + ow); + d = vmulq(d, coef); + vset(dptr + ow, d); + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; ow < OW; ++ow) { + dptr[ow] *= factor; + } + } + } +} +} // anonymous namespace + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/quantized_converter.h b/dnn/src/arm_common/quantized_converter.h new file mode 100644 index 00000000..1586bca2 --- /dev/null +++ b/dnn/src/arm_common/quantized_converter.h @@ -0,0 +1,147 @@ +/** + * \file dnn/src/arm_common/quantized_converter.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. 
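For comparison with the NEON path above, a scalar reference of what the 3x3 / stride-2 average computes for one interior output element (all nine taps inside the image); border rows and columns are handled separately in the code above. Illustration only:

#include <cstddef>

static float avg_3x3_s2_at(const float* src, std::size_t IW, int oh, int ow,
                           int PH, int PW) {
    float sum = 0.0f;
    for (int kh = 0; kh < 3; ++kh) {
        for (int kw = 0; kw < 3; ++kw) {
            int ih = oh * 2 - PH + kh;
            int iw = ow * 2 - PW + kw;
            sum += src[static_cast<std::size_t>(ih) * IW + iw];  // interior: ih, iw >= 0
        }
    }
    return sum * (1.0f / 9);  // same 1/9 factor as the coef vector above
}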
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#pragma once + +#include "megdnn/dtype.h" +#include "megdnn/oprs.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/utils.h" + +namespace megdnn { +namespace arm_common { + +struct QConverterBase { + inline static int32x4_t vzero() { return vdupq_n_s32(0); } + + inline static float32x4_t vfzero() { return vdupq_n_f32(0.f); } + + inline static float32x4_t vfhalf() { return vdupq_n_f32(0.5f); } + + inline static float32x4_t vfneg_half() { return vdupq_n_f32(-0.5f); } +}; + +struct QConverter { + template + static inline dst_type convert(const src_type&... src); +}; + +template <> +inline dt_qint8 QConverter::convert(const float& src) { + return dt_qint8(saturate(std::round(src), -128, 127)); +} + +template <> +inline dt_quint8 QConverter::convert(const float& src, const uint8_t& zp) { + return dt_quint8(saturate(std::round(src) + zp, 0, 255)); +} + +template <> +inline dt_qint32 QConverter::convert(const float& src) { + return dt_qint32( + saturate(std::round(src), -2147483648, 2147483647)); +} + +#if __ARM_ARCH >= 8 +template <> +inline int8x8_t QConverter::convert(const float32x4x2_t& vsrc) { + int32x4_t vres0 = vcvtaq_s32_f32(vsrc.val[0]); + int32x4_t vres1 = vcvtaq_s32_f32(vsrc.val[1]); + return vqmovn_s16(vcombine_s16(vqmovn_s32(vres0), vqmovn_s32(vres1))); +} +template <> +inline int8x8_t QConverter::convert(const float32x4_t& src) { + int32x4_t res0 = vcvtaq_s32_f32(src); + int16x4_t res0_int16 = vqmovn_s32(res0); + return vqmovn_s16(vcombine_s16(res0_int16, res0_int16)); +} + +template <> +inline uint8x8_t QConverter::convert(const float32x4x2_t& vsrc, + const int32x4_t& vzp) { + int32x4_t vres0 = vcvtaq_s32_f32(vsrc.val[0]); + int32x4_t vres1 = vcvtaq_s32_f32(vsrc.val[1]); + vres0 = vqaddq_s32(vres0, vzp); + vres1 = vqaddq_s32(vres1, vzp); + vres0 = vmaxq_s32(vres0, QConverterBase::vzero()); + vres1 = vmaxq_s32(vres1, QConverterBase::vzero()); + + return vqmovn_u16(vreinterpretq_u16_s16( + vcombine_s16(vqmovn_s32(vres0), vqmovn_s32(vres1)))); +} + +template <> +inline int32x4_t QConverter::convert(const float32x4_t& vsrc) { + return vcvtaq_s32_f32(vsrc); +} + +#else +template <> +inline int8x8_t QConverter::convert(const float32x4x2_t& vsrc) { + float32x4_t vinc0 = + vbslq_f32(vcgeq_f32(vsrc.val[0], QConverterBase::vfzero()), + QConverterBase::vfhalf(), QConverterBase::vfneg_half()); + float32x4_t vinc1 = + vbslq_f32(vcgeq_f32(vsrc.val[1], QConverterBase::vfzero()), + QConverterBase::vfhalf(), QConverterBase::vfneg_half()); + + int32x4_t vres0 = vcvtq_s32_f32(vaddq_f32(vsrc.val[0], vinc0)); + int32x4_t vres1 = vcvtq_s32_f32(vaddq_f32(vsrc.val[1], vinc1)); + + return vqmovn_s16(vcombine_s16(vqmovn_s32(vres0), vqmovn_s32(vres1))); +} + +template <> +inline int8x8_t QConverter::convert(const float32x4_t& src) { + float32x4_t vinc0 = + vbslq_f32(vcgeq_f32(src, QConverterBase::vfzero()), + QConverterBase::vfhalf(), QConverterBase::vfneg_half()); + + int32x4_t vres0 = vcvtq_s32_f32(vaddq_f32(src, vinc0)); + int16x4_t vres0_int16 = vqmovn_s32(vres0); + return vqmovn_s16(vcombine_s16(vres0_int16, vres0_int16)); +} + +template <> +inline uint8x8_t QConverter::convert(const float32x4x2_t& vsrc, + const int32x4_t& vzp) { + float32x4_t vinc0 = + vbslq_f32(vcgeq_f32(vsrc.val[0], QConverterBase::vfzero()), + QConverterBase::vfhalf(), 
QConverterBase::vfneg_half()); + float32x4_t vinc1 = + vbslq_f32(vcgeq_f32(vsrc.val[1], QConverterBase::vfzero()), + QConverterBase::vfhalf(), QConverterBase::vfneg_half()); + + int32x4_t vres0 = vcvtq_s32_f32(vaddq_f32(vsrc.val[0], vinc0)); + int32x4_t vres1 = vcvtq_s32_f32(vaddq_f32(vsrc.val[1], vinc1)); + vres0 = vqaddq_s32(vres0, vzp); + vres1 = vqaddq_s32(vres1, vzp); + vres0 = vmaxq_s32(vres0, QConverterBase::vzero()); + vres1 = vmaxq_s32(vres1, QConverterBase::vzero()); + + return vqmovn_u16(vreinterpretq_u16_s16( + vcombine_s16(vqmovn_s32(vres0), vqmovn_s32(vres1)))); +} + +template <> +inline int32x4_t QConverter::convert(const float32x4_t& vsrc) { + float32x4_t vinc = + vbslq_f32(vcgeq_f32(vsrc, QConverterBase::vfzero()), + QConverterBase::vfhalf(), QConverterBase::vfneg_half()); + return vcvtq_s32_f32(vaddq_f32(vsrc, vinc)); +} + +#endif + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/reduce/opr_impl.cpp b/dnn/src/arm_common/reduce/opr_impl.cpp new file mode 100644 index 00000000..70457f42 --- /dev/null +++ b/dnn/src/arm_common/reduce/opr_impl.cpp @@ -0,0 +1,932 @@ +/** + * \file dnn/src/arm_common/reduce/opr_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "src/arm_common/reduce/opr_impl.h" + +#include +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/arm_common/quantized_converter.h" +#include "src/common/reduce_helper.h" +#include "src/common/unroll_macro.h" +#include "src/common/utils.h" +#include "src/naive/handle.h" + +using namespace megdnn; +using namespace arm_common; + +#include "midout.h" +MIDOUT_DECL(megdnn_arm_common_reduce) + +namespace { + +//!FIXME: we should check this when update the compiler +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#if MEGDNN_ARMV7 +typedef float fp16_fix_t; +#else +typedef __fp16 fp16_fix_t; +#endif +#endif + +/*****************************Mean Reducer***********************/ +template +struct MeanReducer; + +template <> +struct MeanReducer { + using ctype = int8_t; + static constexpr int SIMD_WIDTH = 16; + + int32_t res; + float coef; + MeanReducer(DType, size_t cnt) : res(0), coef(1.0 / cnt) {} + MeanReducer() = default; + void feed(const int8_t* val) { +#if MEGDNN_AARCH64 + res += vaddlvq_s8(vld1q_s8(val)); +#elif MEGDNN_ARMV7 + auto sum = vpaddlq_s16(vpaddlq_s8(vld1q_s8(val))); + res += (vgetq_lane_s32(sum, 0) + vgetq_lane_s32(sum, 1) + + vgetq_lane_s32(sum, 2) + vgetq_lane_s32(sum, 3)); +#else +#error "unsupport android arch" +#endif + } + void feed_remain(const int8_t* val) { res += *val; } + void post(int8_t* dst) { + float sum = res * coef; + *dst = std::round(sum); + } +}; + +template <> +struct MeanReducer { + using ctype = uint8_t; + static constexpr int SIMD_WIDTH = 16; + + int32_t res; + int32_t zp; + int32_t cnt; + float coef; + MeanReducer(DType src_dtype, size_t cnt) + : res(0), cnt(cnt), coef(1.0 / cnt) { + zp = src_dtype.param().zero_point; + } + MeanReducer() = default; + void feed(const uint8_t* val) { +#if MEGDNN_AARCH64 + res += vaddlvq_u8(vld1q_u8(val)); +#elif MEGDNN_ARMV7 + auto sum = + vreinterpretq_s32_u32(vpaddlq_u16(vpaddlq_u8(vld1q_u8(val)))); + res += (vgetq_lane_s32(sum, 0) + 
vgetq_lane_s32(sum, 1) + + vgetq_lane_s32(sum, 2) + vgetq_lane_s32(sum, 3)); +#else +#error "unsupport android arch" +#endif + } + void feed_remain(const uint8_t* val) { res += *val; } + void post(uint8_t* dst) { + float sum = (res - zp * cnt) * coef; + *dst = std::round(sum) + zp; + } +}; + +template <> +struct MeanReducer { + using ctype = float; + static constexpr int SIMD_WIDTH = 4; + + float32x4_t res; + float result; + float coef; + MeanReducer(DType, size_t cnt) : result(0.0f), coef(1.0 / cnt) { + res = vdupq_n_f32(0.0f); + } + MeanReducer() = default; + void feed(const float* val) { res = vaddq_f32(vld1q_f32(val), res); } + void feed_remain(const float* val) { result += *val; } + void post(float* dst) { +#if MEGDNN_AARCH64 + result += vaddvq_f32(res); +#elif MEGDNN_ARMV7 + auto sum_temp = vpadd_f32(vget_low_f32(res), vget_high_f32(res)); + result += (vget_lane_f32(sum_temp, 0) + vget_lane_f32(sum_temp, 1)); +#else +#error "unsupport android arch" +#endif + *dst = result * coef; + } +}; + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +template <> +struct MeanReducer<__fp16, __fp16, __fp16, true> { + using ctype = __fp16; + static constexpr int SIMD_WIDTH = 8; + + float16x8_t res; + fp16_fix_t result; + fp16_fix_t coef; + MeanReducer(DType, size_t cnt) : result(0.0f), coef(1.0 / cnt) { + res = vdupq_n_f16(0.0f); + } + MeanReducer() = default; + void feed(const ctype* val) { res = vaddq_f16(vld1q_f16(val), res); } + void feed_remain(const ctype* val) { result += *val; } + void post(ctype* dst) { + auto sum_tmp = vadd_f16(vget_low_f16(res), vget_high_f16(res)); + result += (vget_lane_f16(sum_tmp, 0) + vget_lane_f16(sum_tmp, 1) + + vget_lane_f16(sum_tmp, 2) + vget_lane_f16(sum_tmp, 3)); + *dst = result * coef; + } +}; + +template <> +struct MeanReducer<__fp16, __fp16, __fp16, false> { + using ctype = __fp16; + static constexpr int SIMD_WIDTH = 8; + + float16x8_t res; + fp16_fix_t remain; + fp16_fix_t coef; + MeanReducer(DType, size_t cnt) : remain(0.0f), coef(1.0 / cnt) { + res = vdupq_n_f16(0.0f); + } + MeanReducer() = default; + void feed(const ctype* val) { res = vaddq_f16(vld1q_f16(val), res); } + void feed_remain(const ctype* val) { remain += *val; } + void post(ctype* dst) { + res = vmulq_n_f16(res, coef); + vst1q_f16(dst, res); + } + void post_remain(ctype* dst){ + *dst = remain * coef; + } +}; +#endif + +template <> +struct MeanReducer { + using ctype = float; + static constexpr int SIMD_WIDTH = 4; + + float32x4_t res; + float remain; + float coef; + MeanReducer(DType, size_t cnt) : remain(0.0f), coef(1.0 / cnt) { + res = vdupq_n_f32(0.0f); + } + MeanReducer() = default; + void feed(const float* val) { res = vaddq_f32(vld1q_f32(val), res); } + void feed_remain(const float* val) { remain += *val; } + void post(float* dst) { + res = vmulq_n_f32(res, coef); + vst1q_f32(dst, res); + } + void post_remain(float* dst){ + *dst = remain * coef; + } +}; + +template <> +struct MeanReducer { + using ctype = int8_t; + static constexpr int SIMD_WIDTH = 16; + + int32x4_t res[4]; + int32_t remain; + int32_t cnt; + float coef; + float32x4_t vcoef; + MeanReducer(DType, size_t cnt) + : remain(0), cnt(cnt), coef(1.0 / cnt) { + memset(res, 0, sizeof (res)); + vcoef = vdupq_n_f32(coef); + } + MeanReducer() = default; + void feed(const int8_t* val) { + const int8x16_t vval = vld1q_s8(val); + const int16x8_t vval_low = vmovl_s8(vget_low_s8(vval)); + const int16x8_t vval_high = vmovl_s8(vget_high_s8(vval)); + + const int32x4_t vval_low_low = vmovl_s16(vget_low_s16(vval_low)); + const int32x4_t vval_low_high 
= vmovl_s16(vget_high_s16(vval_low)); + const int32x4_t vval_high_low = vmovl_s16(vget_low_s16(vval_high)); + const int32x4_t vval_high_high = vmovl_s16(vget_high_s16(vval_high)); + + res[0] = vaddq_s32(res[0], vval_low_low); + res[1] = vaddq_s32(res[1], vval_low_high); + res[2] = vaddq_s32(res[2], vval_high_low); + res[3] = vaddq_s32(res[3], vval_high_high); + } + void feed_remain(const int8_t* val) { remain += *val; } + void post(int8_t* dst) { + for (int i = 0; i < 4; i += 2) { + float32x4_t vitem0 = vmulq_f32(vcvtq_f32_s32(res[i]), vcoef); + float32x4_t vitem1 = vmulq_f32(vcvtq_f32_s32(res[i + 1]), vcoef); + vst1_s8(dst, (QConverter::convert({{vitem0, vitem1}}))); + dst += 8; + } + } + void post_remain(int8_t* dst) { + float sum = remain * coef; + *dst = std::round(sum); + } +}; + +template <> +struct MeanReducer { + using ctype = uint8_t; + static constexpr int SIMD_WIDTH = 16; + + int32x4_t res[4]; + int32_t remain; + int32_t zp; + int32x4_t vzp; + int32_t cnt; + int32x4_t vcnt; + float coef; + float32x4_t vcoef; + MeanReducer(DType src_dtype, size_t cnt) + : remain(0), cnt(cnt), coef(1.0 / cnt) { + zp = src_dtype.param().zero_point; + vzp = vdupq_n_s32(zp); + memset(res, 0, sizeof (res)); + vcnt = vdupq_n_s32(cnt); + vcoef = vdupq_n_f32(coef); + } + MeanReducer() = default; + void feed(const uint8_t* val) { + const uint8x16_t vval = vld1q_u8(val); + const uint16x8_t vval_low = vmovl_u8(vget_low_u8(vval)); + const uint16x8_t vval_high = vmovl_u8(vget_high_u8(vval)); + + const uint32x4_t vval_low_low = vmovl_u16(vget_low_u16(vval_low)); + const uint32x4_t vval_low_high = vmovl_u16(vget_high_u16(vval_low)); + const uint32x4_t vval_high_low = vmovl_u16(vget_low_u16(vval_high)); + const uint32x4_t vval_high_high = vmovl_u16(vget_high_u16(vval_high)); + + res[0] = vaddq_s32(res[0], vreinterpretq_s32_u32(vval_low_low)); + res[1] = vaddq_s32(res[1], vreinterpretq_s32_u32(vval_low_high)); + res[2] = vaddq_s32(res[2], vreinterpretq_s32_u32(vval_high_low)); + res[3] = vaddq_s32(res[3], vreinterpretq_s32_u32(vval_high_high)); + } + void feed_remain(const uint8_t* val) { remain += *val; } + void post(uint8_t* dst) { + for (int i = 0; i < 4; i += 2) { + int32x4_t tmp = vmulq_s32(vzp, vcnt); + int32x4_t tmp0 = vsubq_s32(res[i], tmp); + int32x4_t tmp1 = vsubq_s32(res[i + 1], tmp); + float32x4_t vitem0 = vmulq_f32(vcvtq_f32_s32(tmp0), vcoef); + float32x4_t vitem1 = vmulq_f32(vcvtq_f32_s32(tmp1), vcoef); + + vst1_u8(dst, (QConverter::convert({{vitem0, vitem1}}, vzp))); + dst += 8; + } + } + void post_remain(uint8_t* dst) { + float sum = (remain - zp * cnt) * coef; + *dst = std::round(sum) + zp; + } +}; + +/******************************max min Reducer****************************/ +template +struct maxReducer; +template +struct minReducer; + +#define REDUCER_MAX_MIN_C1(_mode, _dtype, _ctype, _comp_type, _stype, __stype, _init) \ + template<> \ + struct _mode##Reducer<_dtype, _ctype, _comp_type, true> { \ + using ctype = _ctype; \ + static constexpr int SIMD_WIDTH = 16; \ + __stype##8x16_t res; \ + _mode##Reducer(DType, size_t) { res = vdupq_n_##_stype##8(_init); } \ + _mode##Reducer() = default; \ + void feed(const ctype* val) { \ + __stype##8x16_t vval = vld1q_##_stype##8(val); \ + res = v##_mode##q_##_stype##8(vval, res); \ + } \ + void feed_remain(const ctype* val) { \ + __stype##8x16_t vval = vdupq_n_##_stype##8(*val); \ + res = v##_mode##q_##_stype##8(vval, res); \ + } \ + void post(ctype* dst) { \ + __stype##16x8_t vval_low = \ + vmovl_##_stype##8(vget_low_##_stype##8(res)); \ + __stype##16x8_t 
vval_high = \ + vmovl_##_stype##8(vget_high_##_stype##8(res)); \ + __stype##16x8_t vval_m = \ + v##_mode##q_##_stype##16(vval_low, vval_high); \ + \ + __stype##32x4_t vval_m_low = \ + vmovl_##_stype##16(vget_low_##_stype##16(vval_m)); \ + __stype##32x4_t vval_m_high = \ + vmovl_##_stype##16(vget_high_##_stype##16(vval_m)); \ + __stype##32x4_t vval_m_m = \ + v##_mode##q_##_stype##32(vval_m_low, vval_m_high); \ + using namespace std; \ + *dst = _mode({vgetq_lane_##_stype##32(vval_m_m, 0), \ + vgetq_lane_##_stype##32(vval_m_m, 1), \ + vgetq_lane_##_stype##32(vval_m_m, 2), \ + vgetq_lane_##_stype##32(vval_m_m, 3)}); \ + } \ + } + +REDUCER_MAX_MIN_C1(max, dt_qint8, int8_t, int8_t, s, int, -128); +REDUCER_MAX_MIN_C1(min, dt_qint8, int8_t, int8_t, s, int, 127); +REDUCER_MAX_MIN_C1(max, dt_quint8, uint8_t, uint8_t, u, uint, 0); +REDUCER_MAX_MIN_C1(min, dt_quint8, uint8_t, uint8_t, u, uint, 255); +#undef REDUCER_MAX_MIN_C1 + +#define REDUCER_MAX_MIN_C(_mode, _dtype, _ctype, _comp_type, _stype, __stype, _init) \ + template<> \ + struct _mode##Reducer<_dtype, _ctype, _comp_type, false> { \ + using ctype = _ctype; \ + static constexpr int SIMD_WIDTH = 16; \ + __stype##8x16_t res, remain; \ + _mode##Reducer(DType, size_t) { \ + res = vdupq_n_##_stype(_init); \ + remain = vdupq_n_##_stype(_init); \ + } \ + _mode##Reducer() = default; \ + void feed(const ctype* val) { \ + __stype##8x16_t vval = vld1q_##_stype(val); \ + res = v##_mode##q_##_stype(vval, res); \ + } \ + void feed_remain(const ctype* val) { \ + __stype##8x16_t vval = vdupq_n_##_stype(*val); \ + remain = v##_mode##q_##_stype(vval, remain); \ + } \ + void post(ctype* dst) { \ + vst1q_##_stype(dst, res); \ + } \ + void post_remain(ctype* dst) { \ + vst1q_lane_##_stype(dst, remain, 0); \ + } \ + } + +REDUCER_MAX_MIN_C(max, dt_qint8, int8_t, int8_t, s8, int, -128); +REDUCER_MAX_MIN_C(min, dt_qint8, int8_t, int8_t, s8, int, 127); +REDUCER_MAX_MIN_C(max, dt_quint8, uint8_t, uint8_t, u8, uint, 0); +REDUCER_MAX_MIN_C(min, dt_quint8, uint8_t, uint8_t, u8, uint, 255); +#undef REDUCER_MAX_MIN_C + +#define REDUCER_MAX_MIN_C1(_mode, _dtype, _ctype, _comp_type, _stype, __stype, \ + _num, _init) \ + template <> \ + struct _mode##Reducer<_dtype, _ctype, _comp_type, true> { \ + using ctype = _ctype; \ + static constexpr int SIMD_WIDTH = _num; \ + __stype res; \ + _mode##Reducer(DType, size_t) { res = vdupq_n_##_stype(_init); } \ + _mode##Reducer() = default; \ + void feed(const ctype* val) { \ + __stype vval = vld1q_##_stype(val); \ + res = v##_mode##q_##_stype(vval, res); \ + } \ + void feed_remain(const ctype* val) { \ + __stype vval = vdupq_n_##_stype(*val); \ + res = v##_mode##q_##_stype(vval, res); \ + } \ + void post(ctype* dst) { \ + auto val = v##_mode##_##_stype(vget_low_##_stype(res), \ + vget_high_##_stype(res)); \ + using namespace std; \ + *dst = _mode( \ + {vget_lane_##_stype(val, 0), vget_lane_##_stype(val, 1)}); \ + } \ + } + +REDUCER_MAX_MIN_C1(max, dt_float32, float, float, f32, float32x4_t, 4, + std::numeric_limits::lowest()); +REDUCER_MAX_MIN_C1(min, dt_float32, float, float, f32, float32x4_t, 4, + std::numeric_limits::max()); +#undef REDUCER_MAX_MIN_C1 + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#define REDUCER_MAX_MIN_C1(_mode, _dtype, _ctype, _comp_type, _stype, __stype, \ + _num, _init) \ + template <> \ + struct _mode##Reducer<_dtype, _ctype, _comp_type, true> { \ + using ctype = _ctype; \ + static constexpr int SIMD_WIDTH = _num; \ + __stype res; \ + _mode##Reducer(DType, size_t) { res = vdupq_n_##_stype(_init); } \ + 
_mode##Reducer() = default; \ + void feed(const ctype* val) { \ + __stype vval = vld1q_##_stype(val); \ + res = v##_mode##q_##_stype(vval, res); \ + } \ + void feed_remain(const ctype* val) { \ + __stype vval = vdupq_n_##_stype(*val); \ + res = v##_mode##q_##_stype(vval, res); \ + } \ + void post(ctype* dst) { \ + auto val = v##_mode##_##_stype(vget_low_##_stype(res), \ + vget_high_##_stype(res)); \ + using namespace std; \ + *dst = _mode( \ + {vget_lane_##_stype(val, 0), vget_lane_##_stype(val, 1), \ + vget_lane_##_stype(val, 2), vget_lane_##_stype(val, 3)}); \ + } \ + } + +REDUCER_MAX_MIN_C1(max, __fp16, __fp16, __fp16, f16, float16x8_t, 8, + std::numeric_limits::lowest()); +REDUCER_MAX_MIN_C1(min, __fp16, __fp16, __fp16, f16, float16x8_t, 8, + std::numeric_limits::max()); +#undef REDUCER_MAX_MIN_C1 +#endif + +#define REDUCER_MAX_MIN_C(_mode, _dtype, _ctype, _comp_type, _stype, __stype, \ + _num, _init) \ + template <> \ + struct _mode##Reducer<_dtype, _ctype, _comp_type, false> { \ + using ctype = _ctype; \ + static constexpr int SIMD_WIDTH = _num; \ + __stype res; \ + ctype remain; \ + _mode##Reducer(DType, size_t) { \ + res = vdupq_n_##_stype(_init); \ + remain = _init; \ + } \ + _mode##Reducer() = default; \ + void feed(const ctype* val) { \ + __stype vval = vld1q_##_stype(val); \ + res = v##_mode##q_##_stype(vval, res); \ + } \ + void feed_remain(const ctype* val) { \ + using namespace std; \ + remain = _mode(*val, remain); \ + } \ + void post(ctype* dst) { vst1q_##_stype(dst, res); } \ + void post_remain(ctype* dst) { *dst = remain; } \ + } + +REDUCER_MAX_MIN_C(max, dt_float32, float, float, f32, float32x4_t, 4, + std::numeric_limits::lowest()); +REDUCER_MAX_MIN_C(min, dt_float32, float, float, f32, float32x4_t, 4, + std::numeric_limits::max()); +#undef REDUCER_MAX_MIN_C +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#define REDUCER_MAX_MIN_C(_mode, _dtype, _ctype, _comp_type, _stype, __stype, \ + _num, _init) \ + template <> \ + struct _mode##Reducer<_dtype, _ctype, _comp_type, false> { \ + using ctype = _ctype; \ + static constexpr int SIMD_WIDTH = _num; \ + __stype res; \ + fp16_fix_t remain; \ + _mode##Reducer(DType, size_t) { \ + res = vdupq_n_##_stype(_init); \ + remain = _init; \ + } \ + _mode##Reducer() = default; \ + void feed(const ctype* val) { \ + __stype vval = vld1q_##_stype(val); \ + res = v##_mode##q_##_stype(vval, res); \ + } \ + void feed_remain(const ctype* val) { \ + using namespace std; \ + remain = _mode(*val, static_cast<__fp16>(remain)); \ + } \ + void post(ctype* dst) { vst1q_##_stype(dst, res); } \ + void post_remain(ctype* dst) { *dst = static_cast<__fp16>(remain); } \ + } + +REDUCER_MAX_MIN_C(max, __fp16, __fp16, __fp16, f16, float16x8_t, 8, + std::numeric_limits::lowest()); +REDUCER_MAX_MIN_C(min, __fp16, __fp16, __fp16, f16, float16x8_t, 8, + std::numeric_limits::max()); +#undef REDUCER_MAX_MIN_C +#endif + +/***************************Sum Product Reducer***************************/ +template +struct SumReducer; +template +struct ProductReducer; + +#define REDUCER_SUM_PRODUCT_C1(_mode, _dtype, _ctype, _comp_type, _stype, \ + __stype, _num, _init, _act, _op) \ + template <> \ + struct _mode##Reducer<_dtype, _ctype, _comp_type, true> { \ + using ctype = _ctype; \ + static constexpr int SIMD_WIDTH = _num; \ + __stype res; \ + ctype remain; \ + _mode##Reducer(DType, size_t) { \ + res = vdupq_n_##_stype(_init); \ + remain = _init; \ + } \ + _mode##Reducer() = default; \ + void feed(const ctype* val) { \ + __stype vval = vld1q_##_stype(val); \ + res = 
v##_act##q_##_stype(vval, res); \ + } \ + void feed_remain(const ctype* val) { \ + using namespace std; \ + auto op = _op(); \ + remain = op(remain, *val); \ + } \ + void post(ctype* dst) { \ + using namespace std; \ + auto val = v##_act##_##_stype(vget_low_##_stype(res), \ + vget_high_##_stype(res)); \ + auto op = _op(); \ + *dst = op(remain, op(vget_lane_##_stype(val, 0), \ + vget_lane_##_stype(val, 1))); \ + } \ + } + +REDUCER_SUM_PRODUCT_C1(Sum, dt_float32, float, float, f32, float32x4_t, 4, 0, + add, plus); +REDUCER_SUM_PRODUCT_C1(Product, dt_float32, float, float, f32, float32x4_t, 4, + 1.0f, mul, multiplies); +#undef REDUCER_SUM_PRODUCT_C1 + +#define REDUCER_SUM_PRODUCT_C(_mode, _dtype, _ctype, _comp_type, _stype, \ + __stype, _num, _init, _act, _op) \ + template <> \ + struct _mode##Reducer<_dtype, _ctype, _comp_type, false> { \ + using ctype = _ctype; \ + static constexpr int SIMD_WIDTH = _num; \ + __stype res; \ + ctype remain; \ + _mode##Reducer(DType, size_t) { \ + res = vdupq_n_##_stype(_init); \ + remain = _init; \ + } \ + _mode##Reducer() = default; \ + void feed(const ctype* val) { \ + __stype vval = vld1q_##_stype(val); \ + res = v##_act##q_##_stype(vval, res); \ + } \ + void feed_remain(const ctype* val) { \ + using namespace std; \ + auto op = _op(); \ + remain = op(remain, (*val)); \ + } \ + void post(ctype* dst) { vst1q_##_stype(dst, res); } \ + void post_remain(ctype* dst) { *dst = remain; } \ + } + +REDUCER_SUM_PRODUCT_C(Sum, dt_float32, float, float, f32, float32x4_t, 4, 0, + add, plus); +REDUCER_SUM_PRODUCT_C(Product, dt_float32, float, float, f32, float32x4_t, 4, 1, + mul, multiplies); +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +REDUCER_SUM_PRODUCT_C(Sum, __fp16, __fp16, __fp16, f16, float16x8_t, 8, 0, add, + plus); +REDUCER_SUM_PRODUCT_C(Product, __fp16, __fp16, __fp16, f16, float16x8_t, 8, 1, + mul, multiplies); +#endif +#undef REDUCER_SUM_PRODUCT_C + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#define REDUCER_SUM_PRODUCT_C1(_mode, _dtype, _ctype, _comp_type, _stype, \ + __stype, _num, _init, _act, _op) \ + template <> \ + struct _mode##Reducer<_dtype, _ctype, _comp_type, true> { \ + using ctype = _ctype; \ + static constexpr int SIMD_WIDTH = _num; \ + __stype res; \ + fp16_fix_t remain; \ + _mode##Reducer(DType, size_t) { \ + res = vdupq_n_##_stype(_init); \ + remain = _init; \ + } \ + _mode##Reducer() = default; \ + void feed(const ctype* val) { \ + __stype vval = vld1q_##_stype(val); \ + res = v##_act##q_##_stype(vval, res); \ + } \ + void feed_remain(const ctype* val) { \ + using namespace std; \ + auto op = _op(); \ + remain = op(remain, *val); \ + } \ + void post(ctype* dst) { \ + using namespace std; \ + auto val = v##_act##_##_stype(vget_low_##_stype(res), \ + vget_high_##_stype(res)); \ + auto op = _op(); \ + *dst = op(remain, op(op(vget_lane_##_stype(val, 0), \ + vget_lane_##_stype(val, 1)), \ + op(vget_lane_##_stype(val, 2), \ + vget_lane_##_stype(val, 3)))); \ + } \ + } + +REDUCER_SUM_PRODUCT_C1(Sum, __fp16, __fp16, __fp16, f16, float16x8_t, 8, 0, add, + plus); +REDUCER_SUM_PRODUCT_C1(Product, __fp16, __fp16, __fp16, f16, float16x8_t, 8, 1, + mul, multiplies); +#undef REDUCER_SUM_PRODUCT_C1 +#endif + +/***************************SumSqr Reducer***************************/ +template +struct SumSqrReducer; + +template <> +struct SumSqrReducer { + using ctype = float; + static constexpr int SIMD_WIDTH = 4; + + float32x4_t res; + float result; + SumSqrReducer(DType, size_t cnt) : result(0.0f) { + MEGDNN_MARK_USED_VAR(cnt); + res = vdupq_n_f32(0.0f); + } + 
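+    //! feed() accumulates val * val lane-wise into `res`; post() then does a
+    //! horizontal add (vaddvq_f32 on aarch64, a vpadd_f32 / vget_lane pair on
+    //! armv7), folds in the scalar tail gathered by feed_remain(), and writes
+    //! a single value.  This specialization serves the C == 1 dispatch path
+    //! in ReduceImpl::exec() below.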
SumSqrReducer() = default; + void feed(const float* val) { + float32x4_t vval = vld1q_f32(val); + res = vaddq_f32(vmulq_f32(vval, vval), res); + } + void feed_remain(const float* val) { + float vval = *val; + result += vval * vval; + } + void post(float* dst) { +#if MEGDNN_AARCH64 + result += vaddvq_f32(res); +#elif MEGDNN_ARMV7 + auto sum_temp = vpadd_f32(vget_low_f32(res), vget_high_f32(res)); + result += (vget_lane_f32(sum_temp, 0) + vget_lane_f32(sum_temp, 1)); +#else +#error "unsupport android arch" +#endif + *dst = result; + } +}; +template <> +struct SumSqrReducer { + using ctype = float; + static constexpr int SIMD_WIDTH = 4; + + float32x4_t res; + float remain; + SumSqrReducer(DType, size_t cnt) : remain(0.0f){ + MEGDNN_MARK_USED_VAR(cnt); + res = vdupq_n_f32(0.0f); + } + SumSqrReducer() = default; + void feed(const float* val) { + float32x4_t vval = vld1q_f32(val); + res = vaddq_f32(vmulq_f32(vval, vval), res); + } + void feed_remain(const float* val) { remain += (*val) * (*val); } + void post(float* dst) { + vst1q_f32(dst, res); + } + void post_remain(float* dst){ + *dst = remain; + } +}; + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +template <> +struct SumSqrReducer<__fp16, __fp16, __fp16, true> { + using ctype = __fp16; + static constexpr int SIMD_WIDTH = 8; + + float16x8_t res; + //! We set the dtype of result to float instead of __fp16, As in compile + //! armv7+fp16, it may trigger result error. + //! ldr instrucation need alignment of 4bytes, while __fp16 result placed in + //! text segments is not satisfied. + //!FIXME: we should check it if we upgrade compiler. + fp16_fix_t result; + SumSqrReducer(DType, size_t cnt) : result(0.0f) { res = vdupq_n_f16(0.0f); } + SumSqrReducer() = default; + void feed(const __fp16* val) { + float16x8_t vval = vld1q_f16(val); + res = vaddq_f16(vmulq_f16(vval, vval), res); + } + void feed_remain(const __fp16* val) { + __fp16 vval = *val; + result += vval * vval; + } + void post(__fp16* dst) { + auto sum_temp = vpadd_f16(vget_low_f16(res), vget_high_f16(res)); + result += (vget_lane_f16(sum_temp, 0) + vget_lane_f16(sum_temp, 1)) + + (vget_lane_f16(sum_temp, 2) + vget_lane_f16(sum_temp, 3)); + *dst = result; + } +}; +template <> +struct SumSqrReducer<__fp16, __fp16, __fp16, false> { + using ctype = __fp16; + static constexpr int SIMD_WIDTH = 8; + + float16x8_t res; + //! We set the dtype of result to float instead of __fp16, As in compile + //! armv7+fp16, it may trigger result error. + //! ldr instrucation need alignment of 4bytes, while __fp16 result placed in + //! text segments is not satisfied. + //!FIXME: we should check it if we upgrade compiler. 
+ fp16_fix_t remain; + SumSqrReducer(DType, size_t cnt) : remain(0.0f){ + res = vdupq_n_f16(0.0f); + } + SumSqrReducer() = default; + void feed(const __fp16* val) { + float16x8_t vval = vld1q_f16(val); + res = vaddq_f16(vmulq_f16(vval, vval), res); + } + void feed_remain(const __fp16* val) { remain += (*val) * (*val); } + void post(__fp16* dst) { + vst1q_f16(dst, res); + } + void post_remain(__fp16* dst){ + *dst = remain; + } +}; +#endif + +/**************************************do reduce*************************/ + +template +struct Exec { + static void do_reduce(const typename Reducer::ctype* src, + const typename Reducer::ctype* dst, DType src_dtype, + size_t A, size_t B, size_t C); +}; + +template +struct Exec { + static void do_reduce(const typename Reducer::ctype* src, + typename Reducer::ctype* dst, DType src_dtype, + size_t A, size_t B, size_t) { + size_t a = 0; + for (; a < A; a++) { + Reducer reducer0(src_dtype, B); + auto temp_src0 = src + a * B; + size_t b = 0; + for (; b + Reducer::SIMD_WIDTH <= B; b += Reducer::SIMD_WIDTH) { + reducer0.feed(temp_src0); + temp_src0 += Reducer::SIMD_WIDTH; + } + for (; b < B; b++) { + reducer0.feed_remain(temp_src0); + temp_src0++; + } + reducer0.post(dst); + dst++; + } + } +}; + +template +struct Exec { + static void do_reduce(const typename Reducer::ctype* src, + typename Reducer::ctype* dst, DType src_dtype, + size_t A, size_t B, size_t C) { + for (size_t a = 0; a < A; a++) { + size_t c = 0; + for (; c + Reducer::SIMD_WIDTH <= C; c += Reducer::SIMD_WIDTH) { + Reducer reducer(src_dtype, B); + for (size_t b = 0; b < B; b++) + reducer.feed(src + c + C * b); + reducer.post(dst); + dst += Reducer::SIMD_WIDTH; + } + for (; c < C; c++) { + Reducer reducer(src_dtype, B); + for (size_t b = 0; b < B; b++) + reducer.feed_remain(src + c + C * b); + reducer.post_remain(dst); + dst++; + } + src += B * C; + } + } +}; + +} // anonymous namespace + +void ReduceImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) { + check_exec(src.layout, dst.layout, workspace.size); + size_t A, B, C; + reduce::get_ABC(src.layout, A, B, C, param().axis); + bool execed = false; + using Mode = param::Reduce::Mode; +#define DISPATCH_FUNC(Reducer, dtype, ctype, comp_type) \ + if (C == 1) { \ + using _Reducer = Reducer; \ + std::function \ + do_reduce = Exec<_Reducer, true>::do_reduce; \ + MIDOUT_BEGIN(megdnn_arm_common_reduce, ctype, dtype, comp_type, \ + midout_iv(1)) { \ + MEGDNN_DISPATCH_CPU_KERN_OPR( \ + do_reduce(reinterpret_cast(src.raw_ptr), \ + reinterpret_cast(dst.raw_ptr), src_type, \ + A, B, C)); \ + execed = true; \ + } \ + MIDOUT_END(); \ + } else { \ + using _Reducer = Reducer; \ + std::function \ + do_reduce = Exec<_Reducer, false>::do_reduce; \ + MIDOUT_BEGIN(megdnn_arm_common_reduce, ctype, dtype, comp_type, \ + midout_iv(1)) { \ + MEGDNN_DISPATCH_CPU_KERN_OPR( \ + do_reduce(reinterpret_cast(src.raw_ptr), \ + reinterpret_cast(dst.raw_ptr), src_type, \ + A, B, C)); \ + execed = true; \ + } \ + MIDOUT_END(); \ + } + +#define DISPATCH_MODE_QUANTIZED(dtype, ctype, comp_type) \ + switch (param().mode) { \ + case Mode::MEAN: \ + DISPATCH_FUNC(MeanReducer, dtype, ctype, comp_type); \ + break; \ + case Mode::MAX: \ + DISPATCH_FUNC(maxReducer, dtype, ctype, ctype); \ + break; \ + case Mode::MIN: \ + DISPATCH_FUNC(minReducer, dtype, ctype, ctype); \ + break; \ + default: \ + break; \ + } + +#define DISPATCH_MODE_FLOAT(dtype, ctype, comp_type) \ + switch (param().mode) { \ + case Mode::MEAN: \ + DISPATCH_FUNC(MeanReducer, dtype, ctype, 
comp_type); \ + break; \ + case Mode::MAX: \ + DISPATCH_FUNC(maxReducer, dtype, ctype, ctype); \ + break; \ + case Mode::MIN: \ + DISPATCH_FUNC(minReducer, dtype, ctype, ctype); \ + break; \ + case Mode::SUM: \ + DISPATCH_FUNC(SumReducer, dtype, ctype, ctype); \ + break; \ + case Mode::SUM_SQR: \ + DISPATCH_FUNC(SumSqrReducer, dtype, ctype, ctype); \ + break; \ + case Mode::PRODUCT: \ + DISPATCH_FUNC(ProductReducer, dtype, ctype, ctype); \ + break; \ + default: \ + break; \ + } + if (src.layout.is_contiguous() && + src.layout.dtype.category() == DTypeCategory::QUANTIZED && + param().data_type == param::Reduce::DataType::DEFAULT) { + DType src_type = src.layout.dtype; + + if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { + DISPATCH_MODE_QUANTIZED(dt_qint8, int8_t, int32_t) + } + if (src.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) { + DISPATCH_MODE_QUANTIZED(dt_quint8, uint8_t, int32_t) + } + } else if (src.layout.is_contiguous() && + src.layout.dtype.category() == DTypeCategory::FLOAT && + param().data_type == param::Reduce::DataType::DEFAULT) { + + DType src_type = src.layout.dtype; + if (src.layout.dtype.enumv() == DTypeEnum::Float32) { + DISPATCH_MODE_FLOAT(dt_float32, float, float) + } +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + if (src.layout.dtype.enumv() == DTypeEnum::Float16) { + MEGDNN_INC_FLOAT16(DISPATCH_MODE_FLOAT(__fp16, __fp16, __fp16)); + } +#endif + } +#undef DISPATCH_FUNC +#undef DISPATCH_MODE_QUANTIZED +#undef DISPATCH_MODE_FLOAT + + if (!execed) { + return fallback::ReduceImpl::exec(src, dst, workspace); + } +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/reduce/opr_impl.h b/dnn/src/arm_common/reduce/opr_impl.h new file mode 100644 index 00000000..2374ea9e --- /dev/null +++ b/dnn/src/arm_common/reduce/opr_impl.h @@ -0,0 +1,29 @@ +/** + * \file dnn/src/arm_common/reduce/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once + +#include "src/fallback/reduce/opr_impl.h" + +namespace megdnn { +namespace arm_common { + +class ReduceImpl : public fallback::ReduceImpl { +public: + using fallback::ReduceImpl::ReduceImpl; + + void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; +}; + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/resize/opr_impl.cpp b/dnn/src/arm_common/resize/opr_impl.cpp new file mode 100644 index 00000000..8c46f5f6 --- /dev/null +++ b/dnn/src/arm_common/resize/opr_impl.cpp @@ -0,0 +1,33 @@ +/** + * \file dnn/src/arm_common/resize/opr_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#include "src/arm_common/resize/opr_impl.h" +#include "src/arm_common/handle.h" +#include "src/arm_common/resize/resize_cv.h" + +using namespace megdnn; +using namespace arm_common; + +void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, + _megdnn_workspace workspace) { + check_exec(src.layout, dst.layout, workspace.size); + if (param().format == param::Resize::Format::NCHW || + (src.layout[3] != 1 && src.layout[3] != 3) || + !is_nhwc_contig_wc(src.layout)) { + fallback::ResizeImpl::exec(src, dst, workspace); + } else { + megdnn_assert(param().format == param::Resize::Format::NHWC, + "invalid resize format"); + MEGDNN_DISPATCH_CPU_KERN_OPR(resize_cv_exec(src, dst, param().imode)); + } +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/resize/opr_impl.h b/dnn/src/arm_common/resize/opr_impl.h new file mode 100644 index 00000000..5e4b6f31 --- /dev/null +++ b/dnn/src/arm_common/resize/opr_impl.h @@ -0,0 +1,33 @@ +/** + * \file dnn/src/arm_common/resize/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once +#include "megdnn/oprs.h" +#include "src/fallback/resize/opr_impl.h" + +namespace megdnn { +namespace arm_common { +class ResizeImpl : public fallback::ResizeImpl { +public: + using fallback::ResizeImpl::ResizeImpl; + + void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; + + size_t get_workspace_in_bytes(const TensorLayout&, + const TensorLayout&) override { + return 0; + } +}; + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/resize/resize_cv.cpp b/dnn/src/arm_common/resize/resize_cv.cpp new file mode 100644 index 00000000..ab85695d --- /dev/null +++ b/dnn/src/arm_common/resize/resize_cv.cpp @@ -0,0 +1,2052 @@ +/** + * By downloading, copying, installing or using the software you agree to this license. + * If you do not agree to this license, do not download, install, + * copy or use the software. + * + * + * License Agreement + * For Open Source Computer Vision Library + * (3-clause BSD License) + * + * Copyright (C) 2000-2020, Intel Corporation, all rights reserved. + * Copyright (C) 2009-2011, Willow Garage Inc., all rights reserved. + * Copyright (C) 2009-2016, NVIDIA Corporation, all rights reserved. + * Copyright (C) 2010-2013, Advanced Micro Devices, Inc., all rights reserved. + * Copyright (C) 2015-2016, OpenCV Foundation, all rights reserved. + * Copyright (C) 2015-2016, Itseez Inc., all rights reserved. + * Copyright (C) 2019-2020, Xperience AI, all rights reserved. + * Third party copyrights are property of their respective owners. + * + * Redistribution and use in source and binary forms, with or without modification, + * are permitted provided that the following conditions are met: + * + * * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * * Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. 
+ * + * * Neither the names of the copyright holders nor the names of the contributors + * may be used to endorse or promote products derived from this software + * without specific prior written permission. + * + * This software is provided by the copyright holders and contributors "as is" and + * any express or implied warranties, including, but not limited to, the implied + * warranties of merchantability and fitness for a particular purpose are disclaimed. + * In no event shall copyright holders or contributors be liable for any direct, + * indirect, incidental, special, exemplary, or consequential damages + * (including, but not limited to, procurement of substitute goods or services; + * loss of use, data, or profits; or business interruption) however caused + * and on any theory of liability, whether in contract, strict liability, + * or tort (including negligence or otherwise) arising in any way out of + * the use of this software, even if advised of the possibility of such damage. + * + * --------------------------------------------------------------------------- + * \file dnn/src/arm_common/resize/resize_cv.cpp + * + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * + * This file has been modified by Megvii ("Megvii Modifications"). + * All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved. + * + * --------------------------------------------------------------------------- + */ +#include +#include "src/arm_common/handle.h" +#include "src/arm_common/resize/opr_impl.h" +#include "src/arm_common/resize/resize_cv.h" +#include "src/common/cv/common.h" +#include "src/common/cv/helper.h" +#include "src/common/utils.h" + +#include "midout.h" +#include "src/arm_common/simd_macro/marm_neon.h" + +MIDOUT_DECL(megdnn_arm_resizecv_imode) +MIDOUT_DECL(megdnn_arm_resizecv_dtype) + +using namespace megdnn; +using namespace arm_common; +using namespace megcv; + +namespace { + +const int SCALE = 11; + +using InterpolationMode = param::Resize::InterpolationMode; +using IMode = InterpolationMode; + +// nearest neighbor + +void resize_nearest_8u(const Mat8u& src, Mat8u& dst) { + AlignedVector tabx(dst.rows()); + AlignedVector taby(dst.cols()); + const double fx = static_cast(dst.rows()) / src.rows(); + const double fy = static_cast(dst.cols()) / src.cols(); + const double ifx = 1.0f / fx; + const double ify = 1.0f / fy; + const size_t ch = src.channels(); + for (size_t dx = 0; dx < tabx.size(); ++dx) { + double rx = dx * ifx; + int sx = static_cast(floor(rx)); + sx = megcv::saturate(sx, 0, static_cast(src.rows())); + tabx[dx] = sx; + } + for (size_t dy = 0; dy < taby.size(); ++dy) { + double ry = dy * ify; + int sy = static_cast(floor(ry)); + sy = megcv::saturate(sy, 0, static_cast(src.cols())); + taby[dy] = sy; + } + + int tabxsize = tabx.size(); + int tabysize = taby.size(); + if (ch == 1) { + for (int dx = 0; dx < tabxsize; ++dx) { + uchar* pdst = dst.ptr(dx); + const uchar* psrc = src.ptr(tabx[dx]); + for (int dy = 0; dy < tabysize; ++dy) { + uchar* pcdst = pdst + dy; + const uchar* pcsrc = psrc + taby[dy]; + pcdst[0] = pcsrc[0]; + } + } + } else if (ch == 3) { + for (int dx = 0; dx < tabxsize; ++dx) { + uchar* pdst = dst.ptr(dx); + const uchar* psrc = 
src.ptr(tabx[dx]); + int dy3 = 0; + for (int dy = 0; dy < tabysize; ++dy, dy3 += 3) { + uchar* pcdst = pdst + dy3; + const uchar* pcsrc = psrc + taby[dy] * 3; + pcdst[0] = pcsrc[0]; + pcdst[1] = pcsrc[1]; + pcdst[2] = pcsrc[2]; + } + } + } +} + +void resize_nearest_32f(const Mat32f& src, Mat32f& dst) { + AlignedVector tabx(dst.rows()); + AlignedVector taby(dst.cols()); + const double fx = static_cast(dst.rows()) / src.rows(); + const double fy = static_cast(dst.cols()) / src.cols(); + const double ifx = 1.0f / fx; + const double ify = 1.0f / fy; + const size_t ch = src.channels(); + for (size_t dx = 0; dx < tabx.size(); ++dx) { + double rx = dx * ifx; + int sx = static_cast(floor(rx)); + sx = megcv::saturate(sx, 0, static_cast(src.rows())); + tabx[dx] = sx; + } + for (size_t dy = 0; dy < taby.size(); ++dy) { + double ry = dy * ify; + int sy = static_cast(floor(ry)); + sy = megcv::saturate(sy, 0, static_cast(src.cols())); + taby[dy] = sy; + } + // taby[taby.size() - 1] = src.cols() - 1; + size_t tabxsize = tabx.size(); + size_t tabysize = taby.size(); + if (ch == 1) { + for (size_t dx = 0; dx < tabxsize; ++dx) { + float* pdst = dst.ptr(dx); + const float* psrc = src.ptr(tabx[dx]); + size_t dy = 0; + for (; dy < tabysize; dy++) { + const float* pcsrc = psrc + taby[dy]; + pdst[dy] = pcsrc[0]; + } + } + } else if (ch == 3) { + for (size_t dx = 0; dx < tabxsize; ++dx) { + float* pdst = dst.ptr(dx); + const float* psrc = src.ptr(tabx[dx]); + size_t dy3 = 0; + for (size_t dy = 0; dy < tabysize; ++dy, dy3 += 3) { + float* pcdst = pdst + dy3; + const float* pcsrc = psrc + taby[dy] * 3; + pcdst[0] = pcsrc[0]; + pcdst[1] = pcsrc[1]; + pcdst[2] = pcsrc[2]; + } + } + } +} + +// linear 32f +void build_tabs_linear_32f(const Mat32f& src, const Mat32f& dst, + AlignedVector& tabsx, AlignedVector& tabsy, + AlignedVector& tabrx, + AlignedVector& tabry) { + megdnn_assert(src.rows() >= 2); + megdnn_assert(src.cols() >= 2); + megdnn_assert(dst.rows() >= 2); + megdnn_assert(dst.cols() >= 2); + const float fx = static_cast(dst.rows()) / src.rows(); + const float fy = static_cast(dst.cols()) / src.cols(); + const float ifx = 1.0f / fx; + const float ify = 1.0f / fy; + for (size_t dx = 0; dx < dst.rows(); ++dx) { + float rx = (dx + 0.5f) * ifx - 0.5f; + int sx = static_cast(floor(rx)); + rx -= sx; + if (sx < 0) { + sx = 0; + rx = 0; + } else if (sx + 1 >= static_cast(src.rows())) { + sx = src.rows() - 2; + rx = 1; + } + tabsx[dx] = sx; + tabrx[dx] = rx; + } + for (size_t dy = 0; dy < dst.cols(); ++dy) { + float ry = (dy + 0.5f) * ify - 0.5f; + int sy = static_cast(floor(ry)); + ry -= sy; + if (sy < 0) { + sy = 0; + ry = 0; + } else if (sy + 1 >= static_cast(src.cols())) { + sy = src.cols() - 2; + ry = 1; + } + tabsy[dy] = sy; + tabry[dy] = ry; + } +} + +void calc_cache_linear_32fc1_1(const Mat32f& src, const Mat32f& dst, + const AlignedVector& tabsx, + const AlignedVector& tabsy, + const AlignedVector& tabrx, + const AlignedVector& tabry, int dx, + AlignedVector& cache0, + AlignedVector& cache1) { + (void)tabrx; + const float* psrc1 = src.ptr(tabsx[dx] + 1); + size_t dstcols = dst.cols(); + size_t dy = 0; + + // cache0 = cache1; + std::swap(cache0, cache1); + for (; dy < dstcols; ++dy) { + const float* pcsrc10 = psrc1 + (tabsy[dy] + 0); + const float* pcsrc11 = psrc1 + (tabsy[dy] + 1); + float ry = tabry[dy]; + float iry = 1.0f - ry; + cache1[dy] = pcsrc11[0] * ry + pcsrc10[0] * iry; + } +} + +void calc_cache_linear_32fc1_2(const Mat32f& src, const Mat32f& dst, + const AlignedVector& tabsx, + const AlignedVector& 
tabsy, + const AlignedVector& tabrx, + const AlignedVector& tabry, int dx, + AlignedVector& cache0, + AlignedVector& cache1) { + (void)tabrx; + const float* psrc0 = src.ptr(tabsx[dx] + 0); + const float* psrc1 = src.ptr(tabsx[dx] + 1); + int dstcols = dst.cols(); + int dy = 0; + + // 4 pixels each time + float* cache0_ptr = cache0.data(); + float* cache1_ptr = cache1.data(); + const float* tabry_ptr = tabry.data(); + for (; dy + 4 <= dstcols; dy += 4) { +#define EXPAND(dy) \ + { \ + int t0 = tabsy[dy + 0]; \ + int t1 = tabsy[dy + 1]; \ + int t2 = tabsy[dy + 2]; \ + int t3 = tabsy[dy + 3]; \ + const float pcsrc00[4] = {psrc0[t0 + 0], psrc0[t1 + 0], psrc0[t2 + 0], \ + psrc0[t3 + 0]}; \ + const float pcsrc01[4] = { \ + psrc0[t0 + 1], \ + psrc0[t1 + 1], \ + psrc0[t2 + 1], \ + psrc0[t3 + 1], \ + }; \ + const float pcsrc10[4] = { \ + psrc1[t0 + 0], \ + psrc1[t1 + 0], \ + psrc1[t2 + 0], \ + psrc1[t3 + 0], \ + }; \ + const float pcsrc11[4] = { \ + psrc1[t0 + 1], \ + psrc1[t1 + 1], \ + psrc1[t2 + 1], \ + psrc1[t3 + 1], \ + }; \ + float32x4_t v_pcsrc00 = vld1q_f32(pcsrc00); \ + float32x4_t v_pcsrc01 = vld1q_f32(pcsrc01); \ + float32x4_t v_pcsrc10 = vld1q_f32(pcsrc10); \ + float32x4_t v_pcsrc11 = vld1q_f32(pcsrc11); \ + float32x4_t v_ry = vld1q_f32(tabry_ptr + dy); \ + float32x4_t v_iry = vsubq_f32(vdupq_n_f32(1.0f), v_ry); \ + vst1q_f32(cache0_ptr + dy, \ + vmlaq_f32(vmulq_f32(v_pcsrc01, v_ry), v_pcsrc00, v_iry)); \ + vst1q_f32(cache1_ptr + dy, \ + vmlaq_f32(vmulq_f32(v_pcsrc11, v_ry), v_pcsrc10, v_iry)); \ + } \ + while (0) + + EXPAND(dy); +#undef EXPAND + } + for (; dy < dstcols; ++dy) { + const float* pcsrc00 = psrc0 + (tabsy[dy] + 0); + const float* pcsrc01 = psrc0 + (tabsy[dy] + 1); + const float* pcsrc10 = psrc1 + (tabsy[dy] + 0); + const float* pcsrc11 = psrc1 + (tabsy[dy] + 1); + float ry = tabry[dy]; + float iry = 1.0f - ry; + cache0[dy] = pcsrc01[0] * ry + pcsrc00[0] * iry; + cache1[dy] = pcsrc11[0] * ry + pcsrc10[0] * iry; + } +} + +void calc_cache_linear_32fc3_1(const Mat32f& src, const Mat32f& dst, + const AlignedVector& tabsx, + const AlignedVector& tabsy, + const AlignedVector& tabrx, + const AlignedVector& tabry, int dx, + AlignedVector& cache0, + AlignedVector& cache1) { + (void)tabrx; + const float* psrc1 = src.ptr(tabsx[dx] + 1); + const size_t dstcols = dst.cols(); + size_t dy = 0, dy3 = 0; + + // cache0 = cache1; + std::swap(cache0, cache1); + for (; dy < dstcols; ++dy, dy3 += 3) { + const float* pcsrc10 = psrc1 + (tabsy[dy] + 0) * 3; + const float* pcsrc11 = psrc1 + (tabsy[dy] + 1) * 3; + float ry = tabry[dy]; + float iry = 1.0f - ry; + cache1[dy3 + 0] = pcsrc11[0] * ry + pcsrc10[0] * iry; + cache1[dy3 + 1] = pcsrc11[1] * ry + pcsrc10[1] * iry; + cache1[dy3 + 2] = pcsrc11[2] * ry + pcsrc10[2] * iry; + } +} + +void calc_cache_linear_32fc3_2(const Mat32f& src, const Mat32f& dst, + const AlignedVector& tabsx, + const AlignedVector& tabsy, + const AlignedVector& tabrx, + const AlignedVector& tabry, int dx, + AlignedVector& cache0, + AlignedVector& cache1) { + (void)tabrx; + const float* psrc0 = src.ptr(tabsx[dx] + 0); + const float* psrc1 = src.ptr(tabsx[dx] + 1); + int dstcols = dst.cols(); + int dy = 0, dy3 = 0; + + for (; dy < dstcols; ++dy, dy3 += 3) { + const float* pcsrc00 = psrc0 + (tabsy[dy] + 0) * 3; + const float* pcsrc01 = psrc0 + (tabsy[dy] + 1) * 3; + const float* pcsrc10 = psrc1 + (tabsy[dy] + 0) * 3; + const float* pcsrc11 = psrc1 + (tabsy[dy] + 1) * 3; + float ry = tabry[dy]; + float iry = 1.0f - ry; + cache0[dy3 + 0] = pcsrc01[0] * ry + pcsrc00[0] * iry; + 
cache1[dy3 + 0] = pcsrc11[0] * ry + pcsrc10[0] * iry; + cache0[dy3 + 1] = pcsrc01[1] * ry + pcsrc00[1] * iry; + cache1[dy3 + 1] = pcsrc11[1] * ry + pcsrc10[1] * iry; + cache0[dy3 + 2] = pcsrc01[2] * ry + pcsrc00[2] * iry; + cache1[dy3 + 2] = pcsrc11[2] * ry + pcsrc10[2] * iry; + } +} +void resize_linear_32f_neon(const Mat32f& src, Mat32f& dst) { + AlignedVector tabsx(dst.rows()); + AlignedVector tabsy(dst.cols()); + AlignedVector tabrx(dst.rows()); + AlignedVector tabry(dst.cols()); + build_tabs_linear_32f(src, dst, tabsx, tabsy, tabrx, tabry); + + if (src.channels() == 1) { + AlignedVector cache0(dst.cols()), cache1(dst.cols()); + int dstrows = dst.rows(); + int dstcols = dst.cols(); + for (int dx = 0; dx < dstrows; ++dx) { + if (dx == 0 || tabsx[dx] != tabsx[dx - 1]) { + if (dx > 0 && tabsx[dx] == tabsx[dx - 1] + 1) { + calc_cache_linear_32fc1_1(src, dst, tabsx, tabsy, tabrx, + tabry, dx, cache0, cache1); + } else { + calc_cache_linear_32fc1_2(src, dst, tabsx, tabsy, tabrx, + tabry, dx, cache0, cache1); + } + } + const float* cache0_ptr = cache0.data(); + const float* cache1_ptr = cache1.data(); + float rx = tabrx[dx]; + float irx = 1.0f - rx; + float* pdst = dst.ptr(dx); + int dy = 0; +#define EXPAND(x) \ + v_cache0 = vld1q_f32(cache0_ptr + dy + x); \ + v_cache1 = vld1q_f32(cache1_ptr + dy + x); \ + vst1q_f32(pdst + dy + x, \ + vmlaq_f32(vmulq_f32(v_rx, v_cache1), v_irx, v_cache0)); + float32x4_t v_rx = vdupq_n_f32(rx); + float32x4_t v_irx = vdupq_n_f32(irx); + for (; dy + 8 <= dstcols; dy += 8) { + float32x4_t v_cache0; + float32x4_t v_cache1; + EXPAND(0); + EXPAND(4); + } + if (dy + 4 <= dstcols) { + float32x4_t v_cache0; + float32x4_t v_cache1; + EXPAND(0); + dy += 4; + } +#undef EXPAND + for (; dy < dstcols; ++dy) { + float* pcdst = pdst + dy; + pcdst[0] = rx * cache1[dy] + irx * cache0[dy]; + } + } + } else if (src.channels() == 3) { + int dstrows = dst.rows(); + int dstcols = dst.cols() * 3; + AlignedVector cache0(dstcols), cache1(dstcols); + for (int dx = 0; dx < dstrows; ++dx) { + if (dx == 0 || tabsx[dx] != tabsx[dx - 1]) { + if (dx > 0 && tabsx[dx] == tabsx[dx - 1] + 1) { + calc_cache_linear_32fc3_1(src, dst, tabsx, tabsy, tabrx, + tabry, dx, cache0, cache1); + } else { + calc_cache_linear_32fc3_2(src, dst, tabsx, tabsy, tabrx, + tabry, dx, cache0, cache1); + } + } + const float* cache0_ptr = cache0.data(); + const float* cache1_ptr = cache1.data(); + float rx = tabrx[dx]; + float irx = 1.0f - rx; + float* pdst = dst.ptr(dx); + int dy = 0; + float32x4_t v_rx = vdupq_n_f32(rx); + float32x4_t v_irx = vdupq_n_f32(irx); +#define EXPAND(x) \ + v_cache0 = vld3q_f32(cache0_ptr + dy + (x)*3); \ + v_cache1 = vld3q_f32(cache1_ptr + dy + (x)*3); \ + v_dst.val[0] = vmlaq_f32(vmulq_f32(v_rx, v_cache1.val[0]), v_irx, \ + v_cache0.val[0]); \ + v_dst.val[1] = vmlaq_f32(vmulq_f32(v_rx, v_cache1.val[1]), v_irx, \ + v_cache0.val[1]); \ + v_dst.val[2] = vmlaq_f32(vmulq_f32(v_rx, v_cache1.val[2]), v_irx, \ + v_cache0.val[2]); \ + vst3q_f32(pdst + dy + (x)*3, v_dst); + for (; dy + 8 * 3 <= dstcols; dy += 8 * 3) { + float32x4x3_t v_cache0; + float32x4x3_t v_cache1; + float32x4x3_t v_dst; + + EXPAND(0); + EXPAND(4); + } + + if (dy + 4 * 3 <= dstcols) { + float32x4x3_t v_cache0; + float32x4x3_t v_cache1; + float32x4x3_t v_dst; + + EXPAND(0); + + dy += 4 * 3; + } +#undef EXPAND + for (; dy < dstcols; dy += 3) { + float* pcdst = pdst + dy; + pcdst[0] = rx * cache1[dy + 0] + irx * cache0[dy + 0]; + pcdst[1] = rx * cache1[dy + 1] + irx * cache0[dy + 1]; + pcdst[2] = rx * cache1[dy + 2] + irx * cache0[dy + 
2]; + } + } + } else { + megdnn_throw(("nr. of channels must be 1 or 3.")); + } +} + +void resize_linear_32f(const Mat32f& src, Mat32f& dst) { + return resize_linear_32f_neon(src, dst); +} + +// linear 8u +void build_tabs_linear_8u(const Mat8u& src, const Mat8u& dst, + AlignedVector& tabsx, AlignedVector& tabsy, + AlignedVector& tabrx, + AlignedVector& tabry) { + megdnn_assert(src.rows() >= 2); + megdnn_assert(src.cols() >= 2); + megdnn_assert(dst.rows() >= 2); + megdnn_assert(dst.cols() >= 2); + const float fx = static_cast(dst.rows()) / src.rows(); + const float fy = static_cast(dst.cols()) / src.cols(); + const float ifx = 1.0f / fx; + const float ify = 1.0f / fy; + for (size_t dx = 0; dx < dst.rows(); ++dx) { + float rx = (dx + 0.5f) * ifx - 0.5f; + int sx = static_cast(floor(rx)); + rx -= sx; + if (sx < 0) { + sx = 0; + rx = 0; + } else if (sx + 1 >= static_cast(src.rows())) { + sx = src.rows() - 2; + rx = 1; + } + tabsx[dx] = sx; + tabrx[dx] = static_cast(rx * (1 << SCALE)); + } + for (size_t dy = 0; dy < dst.cols(); ++dy) { + float ry = (dy + 0.5f) * ify - 0.5f; + int sy = static_cast(floor(ry)); + ry -= sy; + if (sy < 0) { + sy = 0; + ry = 0; + } else if (sy + 1 >= static_cast(src.cols())) { + sy = src.cols() - 2; + ry = 1; + } + tabsy[dy] = sy; + tabry[dy] = static_cast(ry * (1 << SCALE)); + } +} + +void calc_cache_8uc1_1(const Mat8u& src, const Mat8u& dst, + const AlignedVector& tabsx, + const AlignedVector& tabsy, + const AlignedVector& tabrx, + const AlignedVector& tabry, int dx, + AlignedVector& cache0, AlignedVector& cache1) { + (void)tabrx; + const uchar* psrc1 = src.ptr(tabsx[dx] + 1); + size_t dstcols = dst.cols(); + size_t dy = 0; + + // cache0 = cache1; + std::swap(cache0, cache1); + for (; dy < dstcols; ++dy) { + const uchar* pcsrc10 = psrc1 + (tabsy[dy] + 0); + const uchar* pcsrc11 = psrc1 + (tabsy[dy] + 1); + int ry = tabry[dy]; + int iry = (1 << SCALE) - ry; + cache1[dy] = pcsrc11[0] * ry + pcsrc10[0] * iry; + } +} + +void calc_cache_8uc1_2(const Mat8u& src, const Mat8u& dst, + const AlignedVector& tabsx, + const AlignedVector& tabsy, + const AlignedVector& tabrx, + const AlignedVector& tabry, int dx, + AlignedVector& cache0, AlignedVector& cache1) { + (void)tabrx; + const uchar* psrc0 = src.ptr(tabsx[dx] + 0); + const uchar* psrc1 = src.ptr(tabsx[dx] + 1); + int dstcols = dst.cols(); + int dy = 0; + + // 4 pixels each time + for (; dy < dstcols; ++dy) { + const uchar* pcsrc00 = psrc0 + (tabsy[dy] + 0); + const uchar* pcsrc01 = psrc0 + (tabsy[dy] + 1); + const uchar* pcsrc10 = psrc1 + (tabsy[dy] + 0); + const uchar* pcsrc11 = psrc1 + (tabsy[dy] + 1); + int ry = tabry[dy]; + int iry = (1 << SCALE) - ry; + cache0[dy] = pcsrc01[0] * ry + pcsrc00[0] * iry; + cache1[dy] = pcsrc11[0] * ry + pcsrc10[0] * iry; + } +} + +void calc_cache_8uc3_1(const Mat8u& src, const Mat8u& dst, + const AlignedVector& tabsx, + const AlignedVector& tabsy, + const AlignedVector& tabrx, + const AlignedVector& tabry, int dx, + AlignedVector& cache0, AlignedVector& cache1) { + (void)tabrx; + const uchar* psrc1 = src.ptr(tabsx[dx] + 1); + size_t dstcols = dst.cols(); + size_t dy = 0, dy3 = 0; + + // cache0 = cache1; + std::swap(cache0, cache1); + for (; dy < dstcols; ++dy, dy3 += 3) { + const uchar* pcsrc10 = psrc1 + (tabsy[dy] + 0) * 3; + const uchar* pcsrc11 = psrc1 + (tabsy[dy] + 1) * 3; + int ry = tabry[dy]; + int iry = (1 << SCALE) - ry; + cache1[dy3 + 0] = pcsrc11[0] * ry + pcsrc10[0] * iry; + cache1[dy3 + 1] = pcsrc11[1] * ry + pcsrc10[1] * iry; + cache1[dy3 + 2] = pcsrc11[2] * ry + 
pcsrc10[2] * iry; + } +} + +void calc_cache_8uc3_2(const Mat8u& src, const Mat8u& dst, + const AlignedVector& tabsx, + const AlignedVector& tabsy, + const AlignedVector& tabrx, + const AlignedVector& tabry, int dx, + AlignedVector& cache0, AlignedVector& cache1) { + (void)tabrx; + const uchar* psrc0 = src.ptr(tabsx[dx] + 0); + const uchar* psrc1 = src.ptr(tabsx[dx] + 1); + int dstcols = dst.cols(); + int dy = 0, dy3 = 0; + + // 4 pixels each time + for (; dy < dstcols; ++dy, dy3 += 3) { + const uchar* pcsrc00 = psrc0 + (tabsy[dy] + 0) * 3; + const uchar* pcsrc01 = psrc0 + (tabsy[dy] + 1) * 3; + const uchar* pcsrc10 = psrc1 + (tabsy[dy] + 0) * 3; + const uchar* pcsrc11 = psrc1 + (tabsy[dy] + 1) * 3; + int ry = tabry[dy]; + int iry = (1 << SCALE) - ry; + cache0[dy3 + 0] = pcsrc01[0] * ry + pcsrc00[0] * iry; + cache1[dy3 + 0] = pcsrc11[0] * ry + pcsrc10[0] * iry; + cache0[dy3 + 1] = pcsrc01[1] * ry + pcsrc00[1] * iry; + cache1[dy3 + 1] = pcsrc11[1] * ry + pcsrc10[1] * iry; + cache0[dy3 + 2] = pcsrc01[2] * ry + pcsrc00[2] * iry; + cache1[dy3 + 2] = pcsrc11[2] * ry + pcsrc10[2] * iry; + } +} + +void resize_linear_8u_neon(const Mat8u& src, Mat8u& dst) { + AlignedVector tabsx(dst.rows()); + AlignedVector tabsy(dst.cols()); + AlignedVector tabrx(dst.rows()); + AlignedVector tabry(dst.cols()); + build_tabs_linear_8u(src, dst, tabsx, tabsy, tabrx, tabry); + + if (src.channels() == 1) { + AlignedVector cache0(dst.cols()), cache1(dst.cols()); + int dstrows = dst.rows(); + int dstcols = dst.cols(); + for (int dx = 0; dx < dstrows; ++dx) { + if (dx == 0 || tabsx[dx] != tabsx[dx - 1]) { + if (dx > 0 && tabsx[dx] == tabsx[dx - 1] + 1) { + calc_cache_8uc1_1(src, dst, tabsx, tabsy, tabrx, tabry, dx, + cache0, cache1); + } else { + calc_cache_8uc1_2(src, dst, tabsx, tabsy, tabrx, tabry, dx, + cache0, cache1); + } + } + int rx = tabrx[dx]; + int irx = (1 << SCALE) - rx; + uchar* pdst = dst.ptr(dx); + int dy = 0; + + const int* cache0_ptr = cache0.data(); + const int* cache1_ptr = cache1.data(); + int32x4_t v_rx = vdupq_n_s32(rx); + int32x4_t v_irx = vdupq_n_s32(irx); + const int RSCALE = SCALE + SCALE - 16; + for (; dy + 16 <= dstcols; dy += 16) { + int32x4_t v_cache0_0; + int32x4_t v_cache1_0; + int32x4_t v_cache0_4; + int32x4_t v_cache1_4; + int32x4_t v_cache0_8; + int32x4_t v_cache1_8; + int32x4_t v_cache0_c; + int32x4_t v_cache1_c; + + v_cache0_0 = vld1q_s32(cache0_ptr + dy + 0x0); + v_cache1_0 = vld1q_s32(cache1_ptr + dy + 0x0); + v_cache0_4 = vld1q_s32(cache0_ptr + dy + 0x4); + v_cache1_4 = vld1q_s32(cache1_ptr + dy + 0x4); + v_cache0_8 = vld1q_s32(cache0_ptr + dy + 0x8); + v_cache1_8 = vld1q_s32(cache1_ptr + dy + 0x8); + v_cache0_c = vld1q_s32(cache0_ptr + dy + 0xc); + v_cache1_c = vld1q_s32(cache1_ptr + dy + 0xc); + + int16x4_t v_ans0, v_ans4, v_ans8, v_ansc; + v_ans0 = vqshrn_n_s32(vmlaq_s32(vmulq_s32(v_rx, v_cache1_0), + v_irx, v_cache0_0), + 16); + v_ans4 = vqshrn_n_s32(vmlaq_s32(vmulq_s32(v_rx, v_cache1_4), + v_irx, v_cache0_4), + 16); + v_ans8 = vqshrn_n_s32(vmlaq_s32(vmulq_s32(v_rx, v_cache1_8), + v_irx, v_cache0_8), + 16); + v_ansc = vqshrn_n_s32(vmlaq_s32(vmulq_s32(v_rx, v_cache1_c), + v_irx, v_cache0_c), + 16); + + int16x8_t v_half16_0, v_half16_1; + v_half16_0 = vcombine_s16(v_ans0, v_ans4); // x0 + v_half16_1 = vcombine_s16(v_ans8, v_ansc); // y0 + + uint8x8_t v_half8_0, v_half8_1; + v_half8_0 = vqshrun_n_s16(v_half16_0, RSCALE); + v_half8_1 = vqshrun_n_s16(v_half16_1, RSCALE); + + vst1q_u8(pdst + dy, vcombine_u8(v_half8_0, v_half8_1)); + } + + for (; dy < dstcols; ++dy) { + uchar* pcdst = 
pdst + dy; + pcdst[0] = + (rx * cache1[dy] + irx * cache0[dy]) >> (SCALE + SCALE); + } + } + } else if (src.channels() == 3) { + int dstrows = dst.rows(); + int dstcols = dst.cols() * 3; + AlignedVector cache0(dstcols), cache1(dstcols); + for (int dx = 0; dx < dstrows; ++dx) { + if (dx == 0 || tabsx[dx] != tabsx[dx - 1]) { + if (dx > 0 && tabsx[dx] == tabsx[dx - 1] + 1) { + calc_cache_8uc3_1(src, dst, tabsx, tabsy, tabrx, tabry, dx, + cache0, cache1); + } else { + calc_cache_8uc3_2(src, dst, tabsx, tabsy, tabrx, tabry, dx, + cache0, cache1); + } + } + int rx = tabrx[dx]; + int irx = (1 << SCALE) - rx; + uchar* pdst = dst.ptr(dx); + int dy = 0; + + for (; dy < dstcols; dy += 3) { + uchar* pcdst = pdst + dy; + pcdst[0] = (rx * cache1[dy + 0] + irx * cache0[dy + 0]) >> + (SCALE + SCALE); + pcdst[1] = (rx * cache1[dy + 1] + irx * cache0[dy + 1]) >> + (SCALE + SCALE); + pcdst[2] = (rx * cache1[dy + 2] + irx * cache0[dy + 2]) >> + (SCALE + SCALE); + } + } + } else { + megdnn_throw(("nr. of channels must be 1 or 3.")); + } +} + +void resize_linear_8u(const Mat8u& src, Mat8u& dst) { + return resize_linear_8u_neon(src, dst); +} + +const int INTER_RESIZE_COEF_BITS = 11; +const int INTER_RESIZE_COEF_SCALE = 1 << INTER_RESIZE_COEF_BITS; +const float MEGCV_PI = acos(-1); +struct HResizeNoVec { + int operator()(const uchar**, uchar**, int, const int*, const uchar*, int, + int, int, int, int) const { + return 0; + } +}; +struct VResizeNoVec { + int operator()(const uchar**, uchar*, const uchar*, int) const { return 0; } +}; +template +struct ResizeAreaFastNoVec { + ResizeAreaFastNoVec(int, int) {} + ResizeAreaFastNoVec(int, int, int, int) {} + int operator()(const T*, T*, int) const { return 0; } +}; + +struct VResizeCubicVec_32f { + int operator()(const uchar** _src, uchar* _dst, const uchar* _beta, + int width) const { + const float** src = (const float**)_src; + const float* beta = (const float*)_beta; + const float *S0 = src[0], *S1 = src[1], *S2 = src[2], *S3 = src[3]; + float* dst = (float*)_dst; + int x = 0; + float32x4_t v_b0 = vdupq_n_f32(beta[0]), v_b1 = vdupq_n_f32(beta[1]), + v_b2 = vdupq_n_f32(beta[2]), v_b3 = vdupq_n_f32(beta[3]); + + for (; x <= width - 8; x += 8) { + vst1q_f32( + dst + x, + vmlaq_f32(vmlaq_f32(vmlaq_f32(vmulq_f32(v_b0, + vld1q_f32(S0 + x)), + v_b1, vld1q_f32(S1 + x)), + v_b2, vld1q_f32(S2 + x)), + v_b3, vld1q_f32(S3 + x))); + vst1q_f32( + dst + x + 4, + vmlaq_f32(vmlaq_f32(vmlaq_f32(vmulq_f32(v_b0, + vld1q_f32(S0 + x + + 4)), + v_b1, vld1q_f32(S1 + x + 4)), + v_b2, vld1q_f32(S2 + x + 4)), + v_b3, vld1q_f32(S3 + x + 4))); + } + + return x; + } +}; + +struct VResizeLanczos4Vec_32f { + int operator()(const uchar** _src, uchar* _dst, const uchar* _beta, + int width) const { + const float** src = (const float**)_src; + const float* beta = (const float*)_beta; + const float *S0 = src[0], *S1 = src[1], *S2 = src[2], *S3 = src[3], + *S4 = src[4], *S5 = src[5], *S6 = src[6], *S7 = src[7]; + float* dst = (float*)_dst; + int x = 0; + float32x4_t v_b0 = vdupq_n_f32(beta[0]), v_b1 = vdupq_n_f32(beta[1]), + v_b2 = vdupq_n_f32(beta[2]), v_b3 = vdupq_n_f32(beta[3]), + v_b4 = vdupq_n_f32(beta[4]), v_b5 = vdupq_n_f32(beta[5]), + v_b6 = vdupq_n_f32(beta[6]), v_b7 = vdupq_n_f32(beta[7]); + + for (; x <= width - 4; x += 4) { + float32x4_t v_dst0 = vmlaq_f32( + vmlaq_f32(vmlaq_f32(vmulq_f32(v_b0, vld1q_f32(S0 + x)), + v_b1, vld1q_f32(S1 + x)), + v_b2, vld1q_f32(S2 + x)), + v_b3, vld1q_f32(S3 + x)); + float32x4_t v_dst1 = vmlaq_f32( + vmlaq_f32(vmlaq_f32(vmulq_f32(v_b4, vld1q_f32(S4 + x)), + 
v_b5, vld1q_f32(S5 + x)), + v_b6, vld1q_f32(S6 + x)), + v_b7, vld1q_f32(S7 + x)); + vst1q_f32(dst + x, vaddq_f32(v_dst0, v_dst1)); + } + + return x; + } +}; +struct VResizeLinearVec_32f { + int operator()(const uchar** _src, uchar* _dst, const uchar* _beta, + int width) const { + const float** src = (const float**)_src; + const float* beta = (const float*)_beta; + const float *S0 = src[0], *S1 = src[1]; + float* dst = (float*)_dst; + int x = 0; + + float32x4_t v_b0 = vdupq_n_f32(beta[0]), v_b1 = vdupq_n_f32(beta[1]); + + for (; x <= width - 8; x += 8) { + float32x4_t v_src00 = vld1q_f32(S0 + x), + v_src01 = vld1q_f32(S0 + x + 4); + float32x4_t v_src10 = vld1q_f32(S1 + x), + v_src11 = vld1q_f32(S1 + x + 4); + + vst1q_f32(dst + x, + vmlaq_f32(vmulq_f32(v_src00, v_b0), v_src10, v_b1)); + vst1q_f32(dst + x + 4, + vmlaq_f32(vmulq_f32(v_src01, v_b0), v_src11, v_b1)); + } + + return x; + } +}; +struct VResizeLinearVec_32s8u { + int operator()(const uchar** _src, uchar* dst, const uchar* _beta, + int width) const { + const int **src = (const int**)_src, *S0 = src[0], *S1 = src[1]; + const short* beta = (const short*)_beta; + int x = 0; + int16x8_t v_b0 = vdupq_n_s16(beta[0]), v_b1 = vdupq_n_s16(beta[1]), + v_delta = vdupq_n_s16(2); + + for (; x <= width - 16; x += 16) { + int32x4_t v_src00 = vshrq_n_s32(vld1q_s32(S0 + x), 4), + v_src10 = vshrq_n_s32(vld1q_s32(S1 + x), 4); + int32x4_t v_src01 = vshrq_n_s32(vld1q_s32(S0 + x + 4), 4), + v_src11 = vshrq_n_s32(vld1q_s32(S1 + x + 4), 4); + + int16x8_t v_src0 = + vcombine_s16(vmovn_s32(v_src00), vmovn_s32(v_src01)); + int16x8_t v_src1 = + vcombine_s16(vmovn_s32(v_src10), vmovn_s32(v_src11)); + + int16x8_t v_dst0 = + vaddq_s16(vshrq_n_s16(vqdmulhq_s16(v_src0, v_b0), 1), + vshrq_n_s16(vqdmulhq_s16(v_src1, v_b1), 1)); + v_dst0 = vshrq_n_s16(vaddq_s16(v_dst0, v_delta), 2); + + v_src00 = vshrq_n_s32(vld1q_s32(S0 + x + 8), 4); + v_src10 = vshrq_n_s32(vld1q_s32(S1 + x + 8), 4); + v_src01 = vshrq_n_s32(vld1q_s32(S0 + x + 12), 4); + v_src11 = vshrq_n_s32(vld1q_s32(S1 + x + 12), 4); + + v_src0 = vcombine_s16(vmovn_s32(v_src00), vmovn_s32(v_src01)); + v_src1 = vcombine_s16(vmovn_s32(v_src10), vmovn_s32(v_src11)); + + int16x8_t v_dst1 = + vaddq_s16(vshrq_n_s16(vqdmulhq_s16(v_src0, v_b0), 1), + vshrq_n_s16(vqdmulhq_s16(v_src1, v_b1), 1)); + v_dst1 = vshrq_n_s16(vaddq_s16(v_dst1, v_delta), 2); + + vst1q_u8(dst + x, + vcombine_u8(vqmovun_s16(v_dst0), vqmovun_s16(v_dst1))); + } + + return x; + } +}; + +typedef HResizeNoVec HResizeLinearVec_32f; +typedef HResizeNoVec HResizeLinearVec_8u32s; +typedef VResizeNoVec VResizeLanczos4Vec_32s8u; +typedef VResizeNoVec VResizeCubicVec_32s8u; + +class ResizeAreaFastVec_SIMD_8u { +public: + ResizeAreaFastVec_SIMD_8u(int _cn, int _step) : cn(_cn), step(_step) {} + + int operator()(const uchar* S, uchar* D, int w) const { + int dx = 0; + const uchar *S0 = S, *S1 = S0 + step; + + uint16x8_t v_2 = vdupq_n_u16(2); + + if (cn == 1) { + for (; dx <= w - 16; dx += 16, S0 += 32, S1 += 32, D += 16) { + uint8x16x2_t v_row0 = vld2q_u8(S0), v_row1 = vld2q_u8(S1); + + uint16x8_t v_dst0 = vaddl_u8(vget_low_u8(v_row0.val[0]), + vget_low_u8(v_row0.val[1])); + v_dst0 = + vaddq_u16(v_dst0, vaddl_u8(vget_low_u8(v_row1.val[0]), + vget_low_u8(v_row1.val[1]))); + v_dst0 = vshrq_n_u16(vaddq_u16(v_dst0, v_2), 2); + + uint16x8_t v_dst1 = vaddl_u8(vget_high_u8(v_row0.val[0]), + vget_high_u8(v_row0.val[1])); + v_dst1 = vaddq_u16(v_dst1, + vaddl_u8(vget_high_u8(v_row1.val[0]), + vget_high_u8(v_row1.val[1]))); + v_dst1 = vshrq_n_u16(vaddq_u16(v_dst1, v_2), 2); + 
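+                // v_dst0/v_dst1 together hold the sixteen 2x2 box averages for
+                // this iteration: vld2q_u8 de-interleaved even/odd source
+                // columns, the four u8 neighbours were widened and summed
+                // (at most 1020), and the "+2, >>2" rounds to nearest, so the
+                // per-pixel scalar equivalent is
+                //   D[i] = (S0[2*i] + S0[2*i+1] + S1[2*i] + S1[2*i+1] + 2) >> 2;
+                // and the narrowing store below cannot overflow u8.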
+ vst1q_u8(D, vcombine_u8(vmovn_u16(v_dst0), vmovn_u16(v_dst1))); + } + } else { + megdnn_assert(cn == 3); + } + + return dx; + } + +private: + int cn, step; +}; +struct ResizeAreaFastVec_SIMD_32f { + ResizeAreaFastVec_SIMD_32f(int _scale_x, int _scale_y, int _cn, int _step) + : scale_x(_scale_x), + scale_y(_scale_y), + cn(_cn), + step(_step * sizeof(float)) { + fast_mode = + scale_x == 2 && scale_y == 2 && (cn == 1 || cn == 3 || cn == 4); + } + + int operator()(const float* S, float* D, int w) const { + if (!fast_mode) + return 0; + + const float *S0 = S, *S1 = (const float*)((const uchar*)(S0) + step); + int dx = 0; + + float32x4_t v_025 = vdupq_n_f32(0.25f); + + if (cn == 1) { + for (; dx <= w - 4; dx += 4, S0 += 8, S1 += 8, D += 4) { + float32x4x2_t v_row0 = vld2q_f32(S0), v_row1 = vld2q_f32(S1); + + float32x4_t v_dst0 = vaddq_f32(v_row0.val[0], v_row0.val[1]); + float32x4_t v_dst1 = vaddq_f32(v_row1.val[0], v_row1.val[1]); + + vst1q_f32(D, vmulq_f32(vaddq_f32(v_dst0, v_dst1), v_025)); + } + } + + return dx; + } + +private: + int scale_x, scale_y; + int cn; + bool fast_mode; + int step; +}; + +struct DecimateAlpha { + int si, di; + float alpha; +}; +template +using ResizeFunc = void (*)(const Mat& src, Mat& dst, const int* xofs, + const void* alpha, const int* yofs, + const void* beta, int xmin, int xmax, int ksize); +template +using ResizeAreaFastFunc = void (*)(const Mat& src, Mat& dst, + const int* ofs, const int* xofs, + int scale_x, int scale_y); +template +using ResizeAreaFunc = void (*)(const Mat& src, Mat& dst, + const DecimateAlpha* xtab, int xtab_size, + const DecimateAlpha* ytab, int ytab_size, + const int* yofs); + +static inline void interpolate_cubic(float x, float* coeffs) { + const float A = -0.75f; + + coeffs[0] = ((A * (x + 1) - 5 * A) * (x + 1) + 8 * A) * (x + 1) - 4 * A; + coeffs[1] = ((A + 2) * x - (A + 3)) * x * x + 1; + coeffs[2] = ((A + 2) * (1 - x) - (A + 3)) * (1 - x) * (1 - x) + 1; + coeffs[3] = 1.f - coeffs[0] - coeffs[1] - coeffs[2]; +} +static inline void interpolate_lanczos4(float x, float* coeffs) { + static const double s45 = 0.70710678118654752440084436210485; + static const double cs[][2] = {{1, 0}, {-s45, -s45}, {0, 1}, {s45, -s45}, + {-1, 0}, {s45, s45}, {0, -1}, {-s45, s45}}; + + if (x < FLT_EPSILON) { + for (int i = 0; i < 8; i++) + coeffs[i] = 0; + coeffs[3] = 1; + return; + } + + float sum = 0; + double y0 = -(x + 3) * MEGCV_PI * 0.25, s0 = sin(y0), c0 = cos(y0); + for (int i = 0; i < 8; i++) { + double y = -(x + 3 - i) * MEGCV_PI * 0.25; + coeffs[i] = (float)((cs[i][0] * s0 + cs[i][1] * c0) / (y * y)); + sum += coeffs[i]; + } + + sum = 1.f / sum; + for (int i = 0; i < 8; i++) + coeffs[i] *= sum; +} + +template +struct HResizeLanczos4 { + typedef T value_type; + typedef WT buf_type; + typedef AT alpha_type; + + void operator()(const T** src, WT** dst, int count, const int* xofs, + const AT* alpha, int swidth, int dwidth, int cn, int xmin, + int xmax) const { + for (int k = 0; k < count; k++) { + const T* S = src[k]; + WT* D = dst[k]; + int dx = 0, limit = xmin; + if (cn == 1) { + for (;;) { + for (; dx < limit; dx++, alpha += 8) { + int j, sx = xofs[dx] - 1 * 3; + WT v = 0; + for (j = 0; j < 8; j++) { + int sxj = sx + j * 1; + if ((unsigned)sxj >= (unsigned)swidth) { + while (sxj < 0) + sxj += 1; + while (sxj >= swidth) + sxj -= 1; + } + v += S[sxj] * alpha[j]; + } + D[dx] = v; + } + if (limit == dwidth) + break; + for (; dx < xmax; dx++, alpha += 8) { + int sx = xofs[dx]; + D[dx] = S[sx - 1 * 3] * alpha[0] + + S[sx - 1 * 2] * alpha[1] + + S[sx 
- 1] * alpha[2] + S[sx] * alpha[3] + + S[sx + 1] * alpha[4] + + S[sx + 1 * 2] * alpha[5] + + S[sx + 1 * 3] * alpha[6] + + S[sx + 1 * 4] * alpha[7]; + } + limit = dwidth; + } + } else { + megdnn_assert(cn == 3); + for (;;) { + for (; dx < limit; dx++, alpha += 8) { + int j, sx = xofs[dx] - 3 * 3; + WT v = 0; + for (j = 0; j < 8; j++) { + int sxj = sx + j * 3; + if ((unsigned)sxj >= (unsigned)swidth) { + while (sxj < 0) + sxj += 3; + while (sxj >= swidth) + sxj -= 3; + } + v += S[sxj] * alpha[j]; + } + D[dx] = v; + } + if (limit == dwidth) + break; + for (; dx < xmax; dx++, alpha += 8) { + int sx = xofs[dx]; + D[dx] = S[sx - 3 * 3] * alpha[0] + + S[sx - 3 * 2] * alpha[1] + + S[sx - 3] * alpha[2] + S[sx] * alpha[3] + + S[sx + 3] * alpha[4] + + S[sx + 3 * 2] * alpha[5] + + S[sx + 3 * 3] * alpha[6] + + S[sx + 3 * 4] * alpha[7]; + } + limit = dwidth; + } + } + alpha -= dwidth * 8; + } + } +}; +template +struct HResizeLinear { + typedef T value_type; + typedef WT buf_type; + typedef AT alpha_type; + + void operator()(const T** src, WT** dst, int count, const int* xofs, + const AT* alpha, int swidth, int dwidth, int cn, int xmin, + int xmax) const { + int dx, k; + VecOp vecOp; + + int dx0 = vecOp((const uchar**)src, (uchar**)dst, count, xofs, + (const uchar*)alpha, swidth, dwidth, cn, xmin, xmax); + + for (k = 0; k <= count - 2; k++) { + const T *S0 = src[k], *S1 = src[k + 1]; + WT *D0 = dst[k], *D1 = dst[k + 1]; + for (dx = dx0; dx < xmax; dx++) { + int sx = xofs[dx]; + WT a0 = alpha[dx * 2], a1 = alpha[dx * 2 + 1]; + WT t0 = S0[sx] * a0 + S0[sx + cn] * a1; + WT t1 = S1[sx] * a0 + S1[sx + cn] * a1; + D0[dx] = t0; + D1[dx] = t1; + } + + for (; dx < dwidth; dx++) { + int sx = xofs[dx]; + D0[dx] = WT(S0[sx] * ONE); + D1[dx] = WT(S1[sx] * ONE); + } + } + + for (; k < count; k++) { + const T* S = src[k]; + WT* D = dst[k]; + for (dx = 0; dx < xmax; dx++) { + int sx = xofs[dx]; + D[dx] = S[sx] * alpha[dx * 2] + S[sx + cn] * alpha[dx * 2 + 1]; + } + + for (; dx < dwidth; dx++) + D[dx] = WT(S[xofs[dx]] * ONE); + } + } +}; +template +struct HResizeCubic { + typedef T value_type; + typedef WT buf_type; + typedef AT alpha_type; + + void operator()(const T** src, WT** dst, int count, const int* xofs, + const AT* alpha, int swidth, int dwidth, int cn, int xmin, + int xmax) const { + for (int k = 0; k < count; k++) { + const T* S = src[k]; + WT* D = dst[k]; + int dx = 0, limit = xmin; + if (cn == 1) { + for (;;) { + for (; dx < limit; dx++, alpha += 4) { + int j, sx = xofs[dx] - 1; + WT v = 0; + for (j = 0; j < 4; j++) { + int sxj = sx + j * 1; + if ((unsigned)sxj >= (unsigned)swidth) { + while (sxj < 0) + sxj += 1; + while (sxj >= swidth) + sxj -= 1; + } + v += S[sxj] * alpha[j]; + } + D[dx] = v; + } + if (limit == dwidth) + break; + for (; dx < xmax; dx++, alpha += 4) { + int sx = xofs[dx]; + D[dx] = S[sx - 1] * alpha[0] + S[sx] * alpha[1] + + S[sx + 1] * alpha[2] + S[sx + 1 * 2] * alpha[3]; + } + limit = dwidth; + } + } else { + megdnn_assert(cn == 3); + for (;;) { + for (; dx < limit; dx++, alpha += 4) { + int j, sx = xofs[dx] - 3; + WT v = 0; + for (j = 0; j < 4; j++) { + int sxj = sx + j * 3; + if ((unsigned)sxj >= (unsigned)swidth) { + while (sxj < 0) + sxj += 3; + while (sxj >= swidth) + sxj -= 3; + } + v += S[sxj] * alpha[j]; + } + D[dx] = v; + } + if (limit == dwidth) + break; + for (; dx < xmax; dx++, alpha += 4) { + int sx = xofs[dx]; + D[dx] = S[sx - 3] * alpha[0] + S[sx] * alpha[1] + + S[sx + 3] * alpha[2] + S[sx + 3 * 2] * alpha[3]; + } + limit = dwidth; + } + } + alpha -= dwidth * 4; + } + } +}; 
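+// Editor's note: the HResize*/VResize* functor pairs in this block implement
+// the two passes of a separable resize. An HResize functor expands one source
+// row into a row of intermediate sums (buf_type WT) using the per-column alpha
+// weights; a VResize functor then blends `ksize` such rows with the beta
+// weights into one destination row (driven by resizeGeneric_ further down).
+// The helper below is only an illustrative scalar reference for the 8-bit
+// bilinear case, assuming weights pre-scaled by INTER_RESIZE_COEF_SCALE
+// (1 << 11) as in the fixed-point path; its name and signature are the
+// editor's, it is not called by this patch, and its rounding differs slightly
+// from the NEON code.
+static inline uchar resize_bilinear_8u_ref(const uchar* row0, const uchar* row1,
+                                           int sx,          // left source column
+                                           int a0, int a1,  // horizontal weights, a0 + a1 == 1 << 11
+                                           int b0, int b1) {// vertical weights,   b0 + b1 == 1 << 11
+    // horizontal pass: one fixed-point sum per source row (bounded by 255 << 11)
+    int h0 = row0[sx] * a0 + row0[sx + 1] * a1;
+    int h1 = row1[sx] * a0 + row1[sx + 1] * a1;
+    // vertical pass: blend the two intermediate rows, then drop the
+    // 11 + 11 = 22 fractional bits with round-to-nearest (sum stays below 2^31)
+    int v = (h0 * b0 + h1 * b1 + (1 << 21)) >> 22;
+    // clamp to the u8 range, mirroring saturate_cast to uchar
+    return (uchar)(v < 0 ? 0 : (v > 255 ? 255 : v));
+}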
+ +template +struct VResizeLanczos4 { + typedef T value_type; + typedef WT buf_type; + typedef AT alpha_type; + + void operator()(const WT** src, T* dst, const AT* beta, int width) const { + CastOp castOp; + VecOp vecOp; + int k, x = vecOp((const uchar**)src, (uchar*)dst, (const uchar*)beta, + width); +#if MEGCV_ENABLE_UNROLLED + for (; x <= width - 4; x += 4) { + WT b = beta[0]; + const WT* S = src[0]; + WT s0 = S[x] * b, s1 = S[x + 1] * b, s2 = S[x + 2] * b, + s3 = S[x + 3] * b; + + for (k = 1; k < 8; k++) { + b = beta[k]; + S = src[k]; + s0 += S[x] * b; + s1 += S[x + 1] * b; + s2 += S[x + 2] * b; + s3 += S[x + 3] * b; + } + + dst[x] = castOp(s0); + dst[x + 1] = castOp(s1); + dst[x + 2] = castOp(s2); + dst[x + 3] = castOp(s3); + } +#endif + + for (; x < width; x++) { + dst[x] = castOp(src[0][x] * beta[0] + src[1][x] * beta[1] + + src[2][x] * beta[2] + src[3][x] * beta[3] + + src[4][x] * beta[4] + src[5][x] * beta[5] + + src[6][x] * beta[6] + src[7][x] * beta[7]); + } + } +}; +template +struct VResizeLinear { + typedef T value_type; + typedef WT buf_type; + typedef AT alpha_type; + + void operator()(const WT** src, T* dst, const AT* beta, int width) const { + WT b0 = beta[0], b1 = beta[1]; + const WT *S0 = src[0], *S1 = src[1]; + CastOp castOp; + VecOp vecOp; + int x = vecOp((const uchar**)src, (uchar*)dst, (const uchar*)beta, + width); +#if MEGCV_ENABLE_UNROLLED + for (; x <= width - 4; x += 4) { + WT t0, t1; + t0 = S0[x] * b0 + S1[x] * b1; + t1 = S0[x + 1] * b0 + S1[x + 1] * b1; + dst[x] = castOp(t0); + dst[x + 1] = castOp(t1); + t0 = S0[x + 2] * b0 + S1[x + 2] * b1; + t1 = S0[x + 3] * b0 + S1[x + 3] * b1; + dst[x + 2] = castOp(t0); + dst[x + 3] = castOp(t1); + } +#endif + for (; x < width; x++) + dst[x] = castOp(S0[x] * b0 + S1[x] * b1); + } +}; +template +struct VResizeCubic { + typedef T value_type; + typedef WT buf_type; + typedef AT alpha_type; + + void operator()(const WT** src, T* dst, const AT* beta, int width) const { + WT b0 = beta[0], b1 = beta[1], b2 = beta[2], b3 = beta[3]; + const WT *S0 = src[0], *S1 = src[1], *S2 = src[2], *S3 = src[3]; + CastOp castOp; + VecOp vecOp; + + int x = vecOp((const uchar**)src, (uchar*)dst, (const uchar*)beta, + width); + for (; x < width; x++) + dst[x] = castOp(S0[x] * b0 + S1[x] * b1 + S2[x] * b2 + S3[x] * b3); + } +}; +template <> +struct VResizeLinear, + VResizeLinearVec_32s8u> { + typedef uchar value_type; + typedef int buf_type; + typedef short alpha_type; + + void operator()(const buf_type** src, value_type* dst, + const alpha_type* beta, int width) const { + alpha_type b0 = beta[0], b1 = beta[1]; + const buf_type *S0 = src[0], *S1 = src[1]; + VResizeLinearVec_32s8u vecOp; + + int x = vecOp((const uchar**)src, (uchar*)dst, (const uchar*)beta, + width); +#if MEGCV_ENABLE_UNROLLED + for (; x <= width - 4; x += 4) { + dst[x + 0] = uchar((((b0 * (S0[x + 0] >> 4)) >> 16) + + ((b1 * (S1[x + 0] >> 4)) >> 16) + 2) >> + 2); + dst[x + 1] = uchar((((b0 * (S0[x + 1] >> 4)) >> 16) + + ((b1 * (S1[x + 1] >> 4)) >> 16) + 2) >> + 2); + dst[x + 2] = uchar((((b0 * (S0[x + 2] >> 4)) >> 16) + + ((b1 * (S1[x + 2] >> 4)) >> 16) + 2) >> + 2); + dst[x + 3] = uchar((((b0 * (S0[x + 3] >> 4)) >> 16) + + ((b1 * (S1[x + 3] >> 4)) >> 16) + 2) >> + 2); + } +#endif + for (; x < width; x++) + dst[x] = uchar((((b0 * (S0[x] >> 4)) >> 16) + + ((b1 * (S1[x] >> 4)) >> 16) + 2) >> + 2); + } +}; + +template +void resizeGeneric_(const Mat& src, Mat& dst, const int* xofs, + const void* _alpha, const int* yofs, const void* _beta, + int xmin, int xmax, int ksize) { + typedef 
typename HResize::value_type T; + typedef typename HResize::buf_type WT; + typedef typename HResize::alpha_type AT; + + const AT* beta = static_cast(_beta); + const AT* alpha = static_cast(_alpha); + int swidth = src.width(); + int sheight = src.height(); + int dwidth = dst.width(); + int dheight = dst.height(); + int cn = src.channels(); + swidth *= cn; + dwidth *= cn; + xmin *= cn; + xmax *= cn; + // image resize is a separable operation. In case of not too strong + // dsize.height + int dy; + HResize hresize; + VResize vresize; + + int bufstep = static_cast(align_size(dwidth, 16)); + AlignedVector _buffer(bufstep * ksize); + WT* buffer = _buffer.data(); + const T* srows[16] = {0}; + WT* rows[16] = {0}; + int prev_sy[16]; + + for (int k = 0; k < ksize; ++k) { + prev_sy[k] = -1; + rows[k] = buffer + bufstep * k; + } + + for (dy = 0; dy < dheight; ++dy, beta += ksize) { + int sy0 = yofs[dy], k0 = ksize, k1 = 0, ksize2 = ksize / 2; + + for (int k = 0; k < ksize; ++k) { + int sy = saturate(sy0 - ksize2 + 1 + k, 0, sheight); + for (k1 = std::max(k1, k); k1 < ksize; ++k1) { + if (sy == prev_sy[k1]) { + if (k1 > k) + memcpy(rows[k], rows[k1], bufstep * sizeof(rows[0][0])); + break; + } + } + if (k1 == ksize) + k0 = std::min(k0, k); + srows[k] = src.ptr(sy); + prev_sy[k] = sy; + } + if (k0 < ksize) + hresize(srows + k0, rows + k0, ksize - k0, xofs, alpha, swidth, + dwidth, cn, xmin, xmax); + vresize((const WT**)(rows), dst.ptr(dy), beta, dwidth); + } +} + +template +void setup_resize_env(InterpolationMode /* ip */, int& /* ksize */, + bool& /* fixedpt */, ResizeFunc& /* func */) { + megdnn_throw(("unimplemented")); +} +template <> +void setup_resize_env(InterpolationMode ip, int& ksize, bool& fixedpt, + ResizeFunc& func) { + fixedpt = false; + switch (ip) { + case IMode::INTER_CUBIC: + ksize = 4; + func = resizeGeneric_< + HResizeCubic, + VResizeCubic, + VResizeCubicVec_32f>, + float>; + break; + case IMode::INTER_LANCZOS4: + ksize = 8; + func = resizeGeneric_< + HResizeLanczos4, + VResizeLanczos4, + VResizeLanczos4Vec_32f>, + float>; + break; + case IMode::INTER_LINEAR: + case IMode::INTER_AREA: + ksize = 2; + func = resizeGeneric_< + HResizeLinear, + VResizeLinear, + VResizeLinearVec_32f>, + float>; + break; + default: + megdnn_throw(("unknown interpolation method")); + } +} +template <> +void setup_resize_env(InterpolationMode ip, int& ksize, bool& fixedpt, + ResizeFunc& func) { + fixedpt = true; + switch (ip) { + case IMode::INTER_CUBIC: + ksize = 4; + func = resizeGeneric_< + HResizeCubic, + VResizeCubic< + uchar, int, short, + FixedPtCast, + VResizeCubicVec_32s8u>, + uchar>; + break; + case IMode::INTER_LANCZOS4: + ksize = 8; + func = resizeGeneric_< + HResizeLanczos4, + VResizeLanczos4< + uchar, int, short, + FixedPtCast, + VResizeLanczos4Vec_32s8u>, + uchar>; + break; + case IMode::INTER_LINEAR: + case IMode::INTER_AREA: + ksize = 2; + func = resizeGeneric_< + HResizeLinear, + VResizeLinear< + uchar, int, short, + FixedPtCast, + VResizeLinearVec_32s8u>, + uchar>; + break; + default: + megdnn_throw(("unknown interpolation method")); + } +} + +int compute_resize_area_tab(int ssize, int dsize, int cn, double scale, + DecimateAlpha* tab) { + int k = 0; + for (int dx = 0; dx < dsize; dx++) { + double fsx1 = dx * scale; + double fsx2 = fsx1 + scale; + double cellWidth = std::min(scale, ssize - fsx1); + + int sx1 = ceil(fsx1), sx2 = floor(fsx2); + + sx2 = std::min(sx2, ssize - 1); + sx1 = std::min(sx1, sx2); + + if (sx1 - fsx1 > 1e-3) { + megdnn_assert(k < ssize * 2); + tab[k].di = dx * cn; + 
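+            // Editor's note: this branch records the partially covered source
+            // cell at the left edge of this dst pixel's footprint (its si and
+            // alpha are filled in just below); the loop that follows adds the
+            // fully covered cells, and the final branch adds the partial cell
+            // on the right edge.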
tab[k].si = (sx1 - 1) * cn; + tab[k++].alpha = (float)((sx1 - fsx1) / cellWidth); + } + + for (int sx = sx1; sx < sx2; sx++) { + megdnn_assert(k < ssize * 2); + tab[k].di = dx * cn; + tab[k].si = sx * cn; + tab[k++].alpha = float(1.0 / cellWidth); + } + + if (fsx2 - sx2 > 1e-3) { + megdnn_assert(k < ssize * 2); + tab[k].di = dx * cn; + tab[k].si = sx2 * cn; + tab[k++].alpha = + (float)(std::min(std::min(fsx2 - sx2, 1.), cellWidth) / + cellWidth); + } + } + return k; +} + +// resize Area Fast +template +void resizeAreaFast_(const Mat& src, Mat& dst, const int* ofs, + const int* xofs, int scale_x, int scale_y) { + // Range range(0, dst.rows); + int swidth = src.width(); + int sheight = src.height(); + int dwidth = dst.width(); + int dheight = dst.height(); + int cn = src.channels(); + int area = scale_x * scale_y; + float scale = 1.f / (area); + int dwidth1 = (swidth / scale_x) * cn; + dwidth *= cn; + swidth *= cn; + int dy, dx, k = 0; + + VecOp vop(scale_x, scale_y, src.channels(), (int)src.step()); + + for (dy = 0; dy < dheight; dy++) { + T* D = (T*)(dst.ptr(dy)); + int sy0 = dy * scale_y; + int w = sy0 + scale_y <= sheight ? dwidth1 : 0; + + if (sy0 >= sheight) { + for (dx = 0; dx < dwidth; dx++) + D[dx] = 0; + continue; + } + + dx = vop((const T*)(src.ptr(sy0)), D, w); + for (; dx < w; dx++) { + const T* S = (const T*)(src.ptr(sy0)) + xofs[dx]; + WT sum = 0; + k = 0; +#if MEGCV_ENABLE_UNROLLED + for (; k <= area - 4; k += 4) + sum += S[ofs[k]] + S[ofs[k + 1]] + S[ofs[k + 2]] + + S[ofs[k + 3]]; +#endif + for (; k < area; k++) + sum += S[ofs[k]]; + + D[dx] = saturate_cast(sum * scale); + } + + for (; dx < dwidth; dx++) { + WT sum = 0; + int count = 0, sx0 = xofs[dx]; + if (sx0 >= swidth) + D[dx] = 0; + + for (int sy = 0; sy < scale_y; sy++) { + if (sy0 + sy >= sheight) + break; + const T* S = (const T*)(src.ptr(sy0 + sy)) + sx0; + for (int sx = 0; sx < scale_x * cn; sx += cn) { + if (sx0 + sx >= swidth) + break; + sum += S[sx]; + count++; + } + } + + D[dx] = saturate_cast((float)sum / count); + } + } +} + +template +struct ResizeAreaFastVec { + ResizeAreaFastVec(int _scale_x, int _scale_y, int _cn, int _step) + : scale_x(_scale_x), + scale_y(_scale_y), + cn(_cn), + step(_step), + vecOp(_cn, _step) { + fast_mode = + scale_x == 2 && scale_y == 2 && (cn == 1 || cn == 3 || cn == 4); + } + + int operator()(const T* S, T* D, int w) const { + if (!fast_mode) + return 0; + + const T* nextS = (const T*)((const uchar*)S + step); + int dx = vecOp(S, D, w); + + if (cn == 1) + for (; dx < w; ++dx) { + int index = dx * 2; + D[dx] = (T)((S[index] + S[index + 1] + nextS[index] + + nextS[index + 1] + 2) >> + 2); + } + else if (cn == 3) + for (; dx < w; dx += 3) { + int index = dx * 2; + D[dx] = (T)((S[index] + S[index + 3] + nextS[index] + + nextS[index + 3] + 2) >> + 2); + D[dx + 1] = (T)((S[index + 1] + S[index + 4] + + nextS[index + 1] + nextS[index + 4] + 2) >> + 2); + D[dx + 2] = (T)((S[index + 2] + S[index + 5] + + nextS[index + 2] + nextS[index + 5] + 2) >> + 2); + } + else { + megdnn_assert(cn == 4); + for (; dx < w; dx += 4) { + int index = dx * 2; + D[dx] = (T)((S[index] + S[index + 4] + nextS[index] + + nextS[index + 4] + 2) >> + 2); + D[dx + 1] = (T)((S[index + 1] + S[index + 5] + + nextS[index + 1] + nextS[index + 5] + 2) >> + 2); + D[dx + 2] = (T)((S[index + 2] + S[index + 6] + + nextS[index + 2] + nextS[index + 6] + 2) >> + 2); + D[dx + 3] = (T)((S[index + 3] + S[index + 7] + + nextS[index + 3] + nextS[index + 7] + 2) >> + 2); + } + } + + return dx; + } + +private: + int scale_x, scale_y; + 
int cn; + bool fast_mode; + int step; + SIMDVecOp vecOp; +}; + +template +ResizeAreaFastFunc get_resize_area_fast_func() { + megdnn_throw(("unknown type")); +} + +template <> +ResizeAreaFastFunc get_resize_area_fast_func() { + return resizeAreaFast_; +} + +template <> +ResizeAreaFastFunc get_resize_area_fast_func() { + return resizeAreaFast_>; +} + +// Resize Area +template +static void resizeArea_(const Mat& src, Mat& dst, + const DecimateAlpha* xtab, int xtab_size, + const DecimateAlpha* ytab, int ytab_size, + const int* tabofs) { + // parallel_for_(Range(0, dst.rows), + // ResizeArea_Invoker(src, dst, xtab, xtab_size, ytab, ytab_size, + // tabofs), dst.total()/((double)(1 << 16))); + (void)ytab_size; + int dwidth = dst.width(), dheight = dst.height(); + int cn = dst.channels(); + dwidth *= cn; + AlignedVector _buffer(dwidth * 2); + WT *buf = _buffer.data(), *sum = buf + dwidth; + int j_start = tabofs[0], j_end = tabofs[dheight], j, k, dx, + prev_dy = ytab[j_start].di; + + for (dx = 0; dx < dwidth; dx++) + sum[dx] = (WT)0; + + for (j = j_start; j < j_end; j++) { + WT beta = ytab[j].alpha; + int dy = ytab[j].di; + int sy = ytab[j].si; + + { + const T* S = (const T*)(src.ptr(sy)); + for (dx = 0; dx < dwidth; dx++) + buf[dx] = (WT)0; + + if (cn == 1) + for (k = 0; k < xtab_size; k++) { + int dxn = xtab[k].di; + WT alpha = xtab[k].alpha; + buf[dxn] += S[xtab[k].si] * alpha; + } + else if (cn == 3) + for (k = 0; k < xtab_size; k++) { + int sxn = xtab[k].si; + int dxn = xtab[k].di; + WT alpha = xtab[k].alpha; + WT t0 = buf[dxn] + S[sxn] * alpha; + WT t1 = buf[dxn + 1] + S[sxn + 1] * alpha; + WT t2 = buf[dxn + 2] + S[sxn + 2] * alpha; + buf[dxn] = t0; + buf[dxn + 1] = t1; + buf[dxn + 2] = t2; + } + else { + megdnn_throw(("nr. of channels must be 1 or 3")); + } + } + + if (dy != prev_dy) { + T* D = dst.ptr(prev_dy); + + for (dx = 0; dx < dwidth; dx++) { + D[dx] = saturate_cast(sum[dx]); + sum[dx] = beta * buf[dx]; + } + prev_dy = dy; + } else { + for (dx = 0; dx < dwidth; dx++) + sum[dx] += beta * buf[dx]; + } + } + + { + T* D = dst.ptr(prev_dy); + for (dx = 0; dx < dwidth; dx++) + D[dx] = saturate_cast(sum[dx]); + } +} + +template +ResizeAreaFunc get_resize_area_func() { + megdnn_throw(("unknown type")); +} +template <> +ResizeAreaFunc get_resize_area_func() { + return resizeArea_; +} +template <> +ResizeAreaFunc get_resize_area_func() { + return resizeArea_; +} + +template +void resize_opencv(const Mat& src, Mat& dst, InterpolationMode ip) { + // fake area mode missing here + int dwidth = dst.width(); + int dheight = dst.height(); + int swidth = src.width(); + int sheight = src.height(); + int xmin = 0, xmax = dwidth, width = dwidth * dst.channels(); + double inv_scale_x = static_cast(dwidth) / swidth; + double inv_scale_y = static_cast(dheight) / sheight; + double scale_x = 1.0 / inv_scale_x; + double scale_y = 1.0 / inv_scale_y; + int dx, sx, dy, sy, k; + float fx, fy; + int cn = src.channels(); + { + int iscale_x = saturate_cast(scale_x); + int iscale_y = saturate_cast(scale_y); + + bool is_area_fast = std::abs(scale_x - iscale_x) < DBL_EPSILON && + std::abs(scale_y - iscale_y) < DBL_EPSILON; + if (ip == IMode::INTER_LINEAR && is_area_fast && iscale_x == 2 && + iscale_y == 2) { + ip = IMode::INTER_AREA; + } + if (ip == IMode::INTER_AREA && scale_x >= 1 && scale_y >= 1) { + if (is_area_fast) { + int area = iscale_x * iscale_y; + size_t srcstep = src.step(); + AlignedVector _ofs(area + dwidth * cn); + int* ofs = _ofs.data(); + int* xofs = ofs + area; + ResizeAreaFastFunc func = + 
get_resize_area_fast_func(); /// need change + for (sy = 0, k = 0; sy < iscale_y; ++sy) + for (sx = 0; sx < iscale_x; ++sx) + ofs[k++] = static_cast(sy * srcstep + sx * cn); + for (dx = 0; dx < dwidth; ++dx) { + int j = dx * cn; + sx = iscale_x * j; + for (k = 0; k < cn; ++k) + xofs[j + k] = sx + k; + } + func(src, dst, ofs, xofs, iscale_x, iscale_y); + return; + } + ResizeAreaFunc func = get_resize_area_func(); + AlignedVector _xytab((swidth + sheight) * 2); + DecimateAlpha *xtab = _xytab.data(), *ytab = xtab + swidth * 2; + int xtab_size = + compute_resize_area_tab(swidth, dwidth, cn, scale_x, xtab); + int ytab_size = + compute_resize_area_tab(sheight, dheight, 1, scale_y, ytab); + AlignedVector _tabofs(dheight + 1); + int* tabofs = _tabofs.data(); + for (k = 0, dy = 0; k < ytab_size; ++k) { + if (k == 0 || ytab[k].di != ytab[k - 1].di) { + megdnn_assert(ytab[k].di == dy); + tabofs[dy++] = k; + } + } + tabofs[dy] = ytab_size; + func(src, dst, xtab, xtab_size, ytab, ytab_size, tabofs); + return; + } + } + bool area_mode = (ip == IMode::INTER_AREA); + int ksize, ksize2; + ResizeFunc func; + bool fixedpt; + setup_resize_env(ip, ksize, fixedpt, func); + ksize2 = ksize / 2; + AlignedVector _buffer((width + dst.height()) * + (sizeof(int) + sizeof(float) * ksize)); + uchar* buffer = _buffer.data(); + int* xofs = static_cast(static_cast(buffer)); + int* yofs = xofs + width; + float* alpha = static_cast(static_cast(yofs + dst.height())); + short* ialpha = static_cast(static_cast(alpha)); + float* beta = alpha + width * ksize; + short* ibeta = static_cast(static_cast(beta)); + // float cbuf[16]; + float cbuf[16] = {0}; + for (dx = 0; dx < dwidth; ++dx) { + if (!area_mode) { + fx = (float)((dx + 0.5) * scale_x - 0.5); + sx = floor(fx); + fx -= sx; + } else { + sx = floor(dx * scale_x); + fx = (float)((dx + 1) - (sx + 1) * inv_scale_x); + fx = (fx <= 0 ? 0.0f : fx - floor(fx)); + } + + if (sx < ksize2 - 1) { + xmin = dx + 1; + if (sx < 0 && + (ip != IMode::INTER_CUBIC && ip != IMode::INTER_LANCZOS4)) { + fx = 0; + sx = 0; + } + } + if (sx + ksize2 >= swidth) { + xmax = std::min(xmax, dx); + if (sx >= swidth - 1 && ip != IMode::INTER_CUBIC && + ip != IMode::INTER_LANCZOS4) { + fx = 0; + sx = swidth - 1; + } + } + int k; + for (k = 0, sx *= cn; k < cn; ++k) + xofs[dx * cn + k] = sx + k; + if (ip == IMode::INTER_CUBIC) { + interpolate_cubic(fx, cbuf); + } else if (ip == IMode::INTER_LANCZOS4) { + interpolate_lanczos4(fx, cbuf); + } else { + cbuf[0] = 1.0f - fx; + cbuf[1] = fx; + } + if (fixedpt) { + for (k = 0; k < ksize; ++k) { + ialpha[dx * cn * ksize + k] = + saturate_cast(cbuf[k] * INTER_RESIZE_COEF_SCALE); + } + for (; k < cn * ksize; ++k) { + ialpha[dx * cn * ksize + k] = + ialpha[dx * cn * ksize + k - ksize]; + } + } else { + for (k = 0; k < ksize; ++k) { + alpha[dx * cn * ksize + k] = cbuf[k]; + } + for (; k < cn * ksize; ++k) { + alpha[dx * cn * ksize + k] = alpha[dx * cn * ksize + k - ksize]; + } + } + } + for (dy = 0; dy < dheight; ++dy) { + if (!area_mode) { + fy = static_cast((dy + 0.5) * scale_y - 0.5); + sy = floor(fy); + fy -= sy; + } else { + sy = floor(dy * scale_y); + fy = static_cast((dy + 1) - (sy + 1) * inv_scale_y); + fy = (fy <= 0 ? 
0.0f : fy - floor(fy)); + } + yofs[dy] = sy; + if (ip == IMode::INTER_CUBIC) { + interpolate_cubic(fy, cbuf); + } else if (ip == IMode::INTER_LANCZOS4) { + interpolate_lanczos4(fy, cbuf); + } else { + cbuf[0] = 1.0f - fy; + cbuf[1] = fy; + } + if (fixedpt) { + for (int k = 0; k < ksize; ++k) { + ibeta[dy * ksize + k] = + saturate_cast(cbuf[k] * INTER_RESIZE_COEF_SCALE); + } + } else { + for (int k = 0; k < ksize; ++k) { + beta[dy * ksize + k] = cbuf[k]; + } + } + } + func(src, dst, xofs, + fixedpt ? static_cast(ialpha) : static_cast(alpha), yofs, + fixedpt ? static_cast(ibeta) : static_cast(beta), xmin, + xmax, ksize); +} + +} // anonymous namespace + +void megdnn::arm_common::resize_cv_exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, + param::Resize::InterpolationMode imode) { + megdnn_assert(src.layout[3] == 1 || src.layout[3] == 3, + "unsupported src channel"); + for (size_t i = 0; i < src.layout.shape[0]; ++i) { + if (dst.layout.dtype == dtype::Float32()) { + MIDOUT_BEGIN(megdnn_arm_resizecv_dtype, midout_iv(0)) { + Mat src_mat = TensorND2Mat(src, i); + Mat dst_mat = TensorND2Mat(dst, i); + switch (imode) { + case IMode::INTER_NEAREST: + MIDOUT_BEGIN(megdnn_arm_resizecv_imode, midout_iv(0)) { + resize_nearest_32f(src_mat, dst_mat); + } + MIDOUT_END(); + break; + case IMode::INTER_LINEAR: + MIDOUT_BEGIN(megdnn_arm_resizecv_imode, midout_iv(1)) { + resize_linear_32f(src_mat, dst_mat); + } + MIDOUT_END(); + break; + case IMode::INTER_CUBIC: + case IMode::INTER_LANCZOS4: + case IMode::INTER_AREA: + MIDOUT_BEGIN(megdnn_arm_resizecv_imode, midout_iv(2)) { + resize_opencv(src_mat, dst_mat, imode); + } + MIDOUT_END(); + break; + default: + megdnn_throw("unsupported interpolation mode"); + break; + } + } + MIDOUT_END(); + } else if (dst.layout.dtype == dtype::Uint8()) { + MIDOUT_BEGIN(megdnn_arm_resizecv_dtype, midout_iv(1)) { + Mat src_mat = TensorND2Mat(src, i); + Mat dst_mat = TensorND2Mat(dst, i); + switch (imode) { + case IMode::INTER_NEAREST: + MIDOUT_BEGIN(megdnn_arm_resizecv_imode, midout_iv(0)) { + resize_nearest_8u(src_mat, dst_mat); + } + MIDOUT_END(); + break; + case IMode::INTER_LINEAR: + MIDOUT_BEGIN(megdnn_arm_resizecv_imode, midout_iv(1)) { + resize_linear_8u(src_mat, dst_mat); + } + MIDOUT_END(); + break; + case IMode::INTER_CUBIC: + case IMode::INTER_LANCZOS4: + case IMode::INTER_AREA: + MIDOUT_BEGIN(megdnn_arm_resizecv_imode, midout_iv(2)) { + resize_opencv(src_mat, dst_mat, imode); + } + MIDOUT_END(); + break; + default: + megdnn_throw("unsupported interpolation mode"); + break; + } + } + MIDOUT_END(); + } else { + megdnn_throw(megdnn_mangle("Unsupported datatype of resize optr.")); + } + } +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/resize/resize_cv.h b/dnn/src/arm_common/resize/resize_cv.h new file mode 100644 index 00000000..7b2093b6 --- /dev/null +++ b/dnn/src/arm_common/resize/resize_cv.h @@ -0,0 +1,29 @@ +/** + * \file dnn/src/arm_common/resize/resize_cv.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#include + +#include "src/common/cv/helper.h" + +namespace megdnn { +namespace arm_common { + +/** + * \fn resize_cv_exec + * \brief Used if the format is NHWC, transfer from megcv + */ +void resize_cv_exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, + param::Resize::InterpolationMode imode); + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/separable_conv/opr_impl.cpp b/dnn/src/arm_common/separable_conv/opr_impl.cpp new file mode 100644 index 00000000..f4ddb1fa --- /dev/null +++ b/dnn/src/arm_common/separable_conv/opr_impl.cpp @@ -0,0 +1,50 @@ +/** + * \file dnn/src/arm_common/separable_conv/opr_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "src/arm_common/separable_conv/opr_impl.h" +#include "./sep_conv_filter.h" +#include "src/common/utils.h" +//#include "src/arm_common/profile.h" +#include "src/arm_common/handle.h" +#include + +namespace megdnn { +namespace arm_common { +using namespace sep_conv; + +void SeparableConvImpl::exec(_megdnn_tensor_in src, + _megdnn_tensor_in filter_x, + _megdnn_tensor_in filter_y, + _megdnn_tensor_out dst, + _megdnn_workspace workspace) +{ + check_exec(src.layout, filter_x.layout, filter_y.layout, dst.layout, workspace.size); + int ih = src.layout.shape[2]; + int iw = src.layout.shape[3]; + int oh = dst.layout.shape[2]; + int ow = dst.layout.shape[3]; + + filter_engine_ = new FilterEngine(ih, iw, oh, ow, + param().ksize_h, param().ksize_w, + param().anchor_h, param().anchor_w, + param().borderMode, param().is_symm_kernel); + + MEGDNN_DISPATCH_CPU_KERN_OPR( + filter_engine_->exec(src, filter_x, filter_y, dst); + ); + + delete(filter_engine_); + +} + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/separable_conv/opr_impl.h b/dnn/src/arm_common/separable_conv/opr_impl.h new file mode 100644 index 00000000..352c8a7e --- /dev/null +++ b/dnn/src/arm_common/separable_conv/opr_impl.h @@ -0,0 +1,40 @@ +/** + * \file dnn/src/arm_common/separable_conv/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once +#include "megdnn/oprs.h" +#include "./sep_conv_filter.h" +namespace megdnn { +namespace arm_common { +using namespace sep_conv; +class SeparableConvImpl: public SeparableConvForward { + public: + //SeparableConvForwardImpl(Handle *handle): SeparableConvForward(handle) {} + using SeparableConvForward::SeparableConvForward; + void exec(_megdnn_tensor_in src, + _megdnn_tensor_in filter_x, + _megdnn_tensor_in filter_y, + _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; + + size_t get_workspace_in_bytes(const TensorLayout &, + const TensorLayout &, + const TensorLayout &, + const TensorLayout &) override + { + // TODO: deduce the size of ring buffer. 
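+        // Returning 0 is safe for now: FilterEngine allocates its own row and
+        // ring buffers internally (std::vector members), so exec() does not
+        // need any caller-provided workspace.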
+ return 0; + } + FilterEngine* filter_engine_; +}; + +} // namespace arm_common +} // namespace megdnn +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/separable_conv/sep_conv_common.h b/dnn/src/arm_common/separable_conv/sep_conv_common.h new file mode 100644 index 00000000..e8a93759 --- /dev/null +++ b/dnn/src/arm_common/separable_conv/sep_conv_common.h @@ -0,0 +1,158 @@ +/** + * \file dnn/src/arm_common/separable_conv/sep_conv_common.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once + +#include "src/common/utils.h" +#include "megdnn/oprs.h" + +namespace megdnn { +namespace arm_common { +namespace sep_conv { + +#define VEC_ALIGN 16 + +using BorderMode = SeparableConv::Param::BorderMode; +using uchar = unsigned char; +using ushort = unsigned short; + +/////////// helper /////////// + +static inline size_t align_size(size_t sz, int n) +{ + megdnn_assert((n & (n - 1)) == 0); + return (sz + n-1) & -n; +} + +static inline int clip(int x, int a, int b) +{ + return x >= a ? (x < b ? x : b-1) : a; +} + +template static inline _Tp* align_ptr(_Tp* ptr, int n=(int)sizeof(_Tp)) +{ + return (_Tp*)(((size_t)ptr + n-1) & -n); +} + +template +T saturate_cast(T x) +{ return x; } + +template +T saturate_cast(int x) +{ + return static_cast(x); +} +template +T saturate_cast(float x) +{ + return static_cast(x); +} +template +T saturate_cast(double x) +{ + return static_cast(x); +} + +// int -> uchar +template<> unsigned char saturate_cast(int x); +// int -> short +template<> short saturate_cast(int x); +// float -> int +template<> int saturate_cast(float x); +// float -> short +template<> short saturate_cast(float x); +// double -> int +template<> int saturate_cast(double x); + + +template struct FixedPtCast +{ + typedef ST type1; + typedef DT rtype; + enum { SHIFT = bits, DELTA = 1 << (bits-1) }; + + DT operator()(ST val) const + { return saturate_cast
<DT>((val + DELTA)>>SHIFT); } +}; + +template<typename ST, typename DT> struct FixedPtCastEx +{ + typedef ST type1; + typedef DT rtype; + + FixedPtCastEx() : SHIFT(0), DELTA(0) {} + FixedPtCastEx(int bits) : SHIFT(bits), DELTA(bits ? 1 << (bits-1) : 0) {} + DT operator()(ST val) const { return saturate_cast<DT>(val + DELTA); } + int SHIFT, DELTA; +}; + +template<> struct FixedPtCastEx<int, uchar> +{ + typedef int type1; + typedef uchar rtype; + + FixedPtCastEx() : SHIFT(0), DELTA(0) {} + FixedPtCastEx(int bits) : SHIFT(bits), DELTA(bits ? 1 << (bits-1) : 0) {} + uchar operator()(int val) const { return saturate_cast<uchar>((val + DELTA)>>SHIFT); } + int SHIFT, DELTA; +}; + + +template<typename ST, typename DT> struct Cast +{ + typedef ST type1; + typedef DT rtype; + + DT operator()(ST val) const { return saturate_cast<DT>
(val); } +}; + +static inline int border_interpolate(int p, int len, BorderMode bmode) +{ + if( (unsigned)p < (unsigned)len ) + ; + else if( bmode == BorderMode::BORDER_REPLICATE ) + p = p < 0 ? 0 : len - 1; + else if( bmode == BorderMode::BORDER_REFLECT || bmode == BorderMode::BORDER_REFLECT_101 ) + { + int delta = (bmode == BorderMode::BORDER_REFLECT_101); + if( len == 1 ) + return 0; + do + { + if( p < 0 ) + p = -p - 1 + delta; + else + p = len - 1 - (p - len) - delta; + } + while( (unsigned)p >= (unsigned)len ); + } + else if( bmode == BorderMode::BORDER_WRAP ) + { + megdnn_assert(len > 0); + if( p < 0 ) + p -= ((p-len+1)/len)*len; + while (p >= len) { + p -= len; + } + } + else if( bmode == BorderMode::BORDER_CONSTANT ) + p = -1; + else + megdnn_throw("Unknown/unsupported border type"); + return p; +} +/////////// helper /////////// + +} // namespace sep_conv +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/separable_conv/sep_conv_filter.h b/dnn/src/arm_common/separable_conv/sep_conv_filter.h new file mode 100644 index 00000000..d25a44f6 --- /dev/null +++ b/dnn/src/arm_common/separable_conv/sep_conv_filter.h @@ -0,0 +1,116 @@ +/** + * \file dnn/src/arm_common/separable_conv/sep_conv_filter.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "./sep_conv_common.h" +#include "src/common/utils.h" +#pragma once +namespace megdnn { +namespace arm_common { +namespace sep_conv { +//#define BorderMode param::SeparableConv::BorderMode +//#define BorderMode SeparableConv::Param::BorderMode +using BorderMode = SeparableConv::Param::BorderMode; +//using uchar = unsigned char; +//using ushort = unsigned short; + +class BaseRowFilter { +public: + //! the default constructor + BaseRowFilter(); + //! the destructor + virtual ~BaseRowFilter(); + //! the filtering operator. Must be overridden in the derived classes. The horizontal border interpolation is done outside of the class. + virtual void operator()(const uchar* src, uchar* dst, uchar* kernel, int width, int cn) = 0; + + int ksize; + int anchor; +}; + + +class BaseColumnFilter { +public: + //! the default constructor + BaseColumnFilter(); + //! the destructor + virtual ~BaseColumnFilter(); + //! the filtering operator. Must be overridden in the derived classes. The vertical border interpolation is done outside of the class. + virtual void operator()(const uchar** src, uchar* dst, uchar* kernel, int dststep, int dstcount, int width) = 0; + //! 
resets the internal buffers, if any + virtual void reset(); + + int ksize; + int anchor; +}; + +class FilterEngine { +public: + //FilterEngine(); + + FilterEngine(const int &ih, const int &iw, + const int &oh, const int &ow, + const int &kh, const int &kw, + const int &anchor_h, const int &anchor_w, + BorderMode borderType = BorderMode::BORDER_CONSTANT, + bool is_symm_kernel = true); + + virtual ~FilterEngine(); + + void init( const int &ih, const int &iw, + const int &oh, const int &ow, + const int &kh, const int &kw, + const int &anchor_h, const int &anchor_w, + BorderMode borderType, + bool is_symm_kernel); + + void exec( const TensorND & src, + const TensorND & kernel_x, + const TensorND & kernel_y, + const TensorND & dst); + + BaseRowFilter* getSepRowFilter(); + BaseColumnFilter* getSepColFilter(); + + inline int getBorderRowIdx1(int idx); + + +private: + // kernel + int ksize_x_, ksize_y_; + int anchor_x_, anchor_y_; // anchors is useless in this version. + int is_symm_kernel_; // are the kernels symmtric. + + //filter + BaseRowFilter *rowFilter_; + BaseColumnFilter *colFilter_; + + //buffer + std::vector srcRow_; // a buffer of a single appended input row + std::vector ringBuf_; // a buffer of middle results. size = maxBufferRow * (maxWidth + kernel_w - 1) + std::vector row_ptr_; + int rowBuffStride_; // aligned stride of a row in the buffer. + int rowBufferOutputRow_; // each time the buffer is full, we can calculate 'rowBufferOutputRow' out rows at one time. + // In this version rowBufferOutputRow_ = 1. + int maxBufferRow_; // max_size_of buffer row. maxBufferRow_ = ksize_y + (rowBufferOutputRow_ - 1) + // In this version maxBufferRow_ = ksize_y. + + //border + BorderMode borderType_; + int dx1_, dx2_, dy1_, dy2_; + std::vector borderTab_; // src idx of border elements + std::vector constBorderValue_; // template of append value (out of mat edge) + std::vector constBorderRow_; // a row of srcRow full of border value ---rowFilter---> constBorderRow +}; + + +} // namespace sep_conv +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/separable_conv/sep_conv_filter_engine.cpp b/dnn/src/arm_common/separable_conv/sep_conv_filter_engine.cpp new file mode 100644 index 00000000..67ffe737 --- /dev/null +++ b/dnn/src/arm_common/separable_conv/sep_conv_filter_engine.cpp @@ -0,0 +1,895 @@ +/** + * \file dnn/src/arm_common/separable_conv/sep_conv_filter_engine.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +#include "./sep_conv_filter.h" + +#include +#include +#include "src/arm_common/simd_macro/marm_neon.h" +#include + +namespace megdnn { +namespace arm_common { +namespace sep_conv { +using BorderMode = SeparableConv::Param::BorderMode; +using uchar = unsigned char; +using ushort = unsigned short; + +////////////////////////////////////////////// +//vecOp +///////////////////////////////////////////// + +struct RowVec_32f +{ + RowVec_32f() + {} + + RowVec_32f(int _len) + { + ksize = _len; + } + + int operator()(const uchar* _src, uchar* _dst, uchar * kernel, int width, int cn) const + { + + int _ksize = ksize; + const float* src0 = (const float*)_src; + float* dst = (float*)_dst; + const float* _kx = (float*)kernel; + + int i = 0, k; + width *= cn; + + for( ; i <= width - 8; i += 8 ) + { + const float* src = src0 + i; + float32x4_t f, s0 = vdupq_n_f32(0), s1 = s0, x0, x1; + for( k = 0; k < _ksize; k++, src += cn ) + { + f = vdupq_n_f32(_kx[k]); + + x0 = vld1q_f32(src); + x1 = vld1q_f32(src + 4); + s0 = vmlaq_f32(s0, x0, f); + s1 = vmlaq_f32(s1, x1, f); + } + vst1q_f32(dst + i, s0); + vst1q_f32(dst + i + 4, s1); + } + for( ; i <= width - 4; i += 4 ) + { + const float* src = src0 + i; + float32x4_t f, s0 = vdupq_n_f32(0), x0; + for( k = 0; k < _ksize; k++, src += cn ) + { + f = vdupq_n_f32(_kx[k]); + + x0 = vld1q_f32(src); + s0 = vmlaq_f32(s0, x0, f); + } + vst1q_f32(dst + i, s0); + } + return i; + } + int ksize; +}; + +struct SymmRowSmallVec_32f +{ + SymmRowSmallVec_32f() {} + SymmRowSmallVec_32f(int _len) + { + ksize = _len; + } + + int operator()(const uchar* _src, uchar* _dst, uchar * kernel, int width, int cn) const + { + int i = 0, _ksize = ksize; + float* dst = (float*)_dst; + const float* src = (const float*)_src + (_ksize/2)*cn; + const float* kx = (float*)kernel + _ksize/2; + width *= cn; + + { + if( _ksize == 1 ) + return 0; + if( _ksize == 3 ) + { + float32x4_t k0 = vdupq_n_f32(kx[0]), k1 = vdupq_n_f32(kx[1]); + for( ; i <= width - 8; i += 8, src += 8 ) + { + float32x4_t x0, x1, x2, y0, y1, y2; + x0 = vld1q_f32(src - cn); + x1 = vld1q_f32(src); + x2 = vld1q_f32(src + cn); + y0 = vld1q_f32(src - cn + 4); + y1 = vld1q_f32(src + 4); + y2 = vld1q_f32(src + cn + 4); + + x0 = vmulq_f32(vaddq_f32(x0, x2), k1); + y0 = vmulq_f32(vaddq_f32(y0, y2), k1); + x0 = vmlaq_f32(x0, x1, k0); + y0 = vmlaq_f32(y0, y1, k0); + vst1q_f32(dst + i, x0); + vst1q_f32(dst + i + 4, y0); + } + } + else if( _ksize == 5 ) + { + float32x4_t k0 = vdupq_n_f32(kx[0]), k1 = vdupq_n_f32(kx[1]), k2 = vdupq_n_f32(kx[2]); + for( ; i <= width - 8; i += 8, src += 8 ) + { + float32x4_t x0, x1, x2, y0, y1, y2; + x0 = vld1q_f32(src - cn); + x1 = vld1q_f32(src); + x2 = vld1q_f32(src + cn); + y0 = vld1q_f32(src - cn + 4); + y1 = vld1q_f32(src + 4); + y2 = vld1q_f32(src + cn + 4); + + x0 = vmulq_f32(vaddq_f32(x0, x2), k1); + y0 = vmulq_f32(vaddq_f32(y0, y2), k1); + x0 = vmlaq_f32(x0, x1, k0); + y0 = vmlaq_f32(y0, y1, k0); + + x2 = vaddq_f32(vld1q_f32(src + cn*2), vld1q_f32(src - cn*2)); + y2 = vaddq_f32(vld1q_f32(src + cn*2 + 4), vld1q_f32(src - cn*2 + 4)); + x0 = vmlaq_f32(x0, x2, k2); + y0 = vmlaq_f32(y0, y2, k2); + + vst1q_f32(dst + i, x0); + vst1q_f32(dst + i + 4, y0); + } + } + + } + return i; + } + int ksize; +}; + +struct ColumnVec_32f +{ + ColumnVec_32f() {} + ColumnVec_32f(int _len, int) + { + ksize = _len; + } + + int operator()(const uchar** _src, uchar* _dst, uchar * kernel, int &, int width) const + { + const float* ky = (const float*)kernel; + int i = 0, k; + const float** src = (const float**)_src; + const 
float *S; + float* dst = (float*)_dst; + + { + for( ; i <= width - 16; i += 16 ) + { + float32x4_t f = vdupq_n_f32(ky[0]); + + float32x4_t s0, s1, s2, s3; + float32x4_t x0, x1; + S = src[0] + i; + s0 = vld1q_f32(S); + s1 = vld1q_f32(S+4); + s0 = vmulq_f32(s0, f); + s1 = vmulq_f32(s1, f); + s2 = vld1q_f32(S+8); + s3 = vld1q_f32(S+12); + s2 = vmulq_f32(s2, f); + s3 = vmulq_f32(s3, f); + + for( k = 1; k < ksize; k++ ) + { + S = src[k] + i; + float32x4_t f = vdupq_n_f32(ky[k]); + x0 = vld1q_f32(S); + x1 = vld1q_f32(S+4); + s0 = vmlaq_f32(s0, f, x0); + s1 = vmlaq_f32(s1, f, x1); + + x0 = vld1q_f32(S+8); + x1 = vld1q_f32(S+12); + s2 = vmlaq_f32(s2, f, x0); + s3 = vmlaq_f32(s3, f, x1); + } + s0 = vaddq_f32(s0, vld1q_f32(dst+i)); + s1 = vaddq_f32(s1, vld1q_f32(dst+i+4)); + s2 = vaddq_f32(s2, vld1q_f32(dst+i+8)); + s3 = vaddq_f32(s3, vld1q_f32(dst+i+12)); + + vst1q_f32(dst + i, s0); + vst1q_f32(dst + i + 4, s1); + vst1q_f32(dst + i + 8, s2); + vst1q_f32(dst + i + 12, s3); + } + + for( ; i <= width - 4; i += 4 ) + { + float32x4_t f = vdupq_n_f32(ky[0]); + + float32x4_t x0, s0 = vld1q_f32(src[0] + i); + s0 = vmulq_f32(s0, f); + + for( k = 1; k < ksize; k++ ) + { + float32x4_t f = vdupq_n_f32(ky[k]); + S = src[k] + i; + x0 = vld1q_f32(S); + s0 = vmlaq_f32(s0, f, x0); + } + s0 = vaddq_f32(s0, vld1q_f32(dst + i)); + vst1q_f32(dst + i, s0); + } + } + + return i; + } + int ksize; +}; + +struct SymmColumnVec_32f +{ + SymmColumnVec_32f() {} + SymmColumnVec_32f(int _len, int) + { + ksize = _len; + } + + int operator()(const uchar** _src, uchar* _dst, uchar * kernel, int &, int width) const + { + int ksize2 = (ksize)/2; + const float* ky = (const float*)kernel + ksize2; + int i = 0, k; + const float** src = (const float**)_src; + const float *S, *S2; + float* dst = (float*)_dst; + + { + for( ; i <= width - 16; i += 16 ) + { + float32x4_t f = vdupq_n_f32(ky[0]); + + float32x4_t s0, s1, s2, s3; + float32x4_t x0, x1; + S = src[0] + i; + s0 = vld1q_f32(S); + s1 = vld1q_f32(S+4); + s0 = vmulq_f32(s0, f); + s1 = vmulq_f32(s1, f); + s2 = vld1q_f32(S+8); + s3 = vld1q_f32(S+12); + s2 = vmulq_f32(s2, f); + s3 = vmulq_f32(s3, f); + + for( k = 1; k <= ksize2; k++ ) + { + S = src[k] + i; + S2 = src[-k] + i; + float32x4_t f = vdupq_n_f32(ky[k]); + + x0 = vaddq_f32(vld1q_f32(S), vld1q_f32(S2)); + x1 = vaddq_f32(vld1q_f32(S+4), vld1q_f32(S2+4)); + s0 = vmlaq_f32(s0, x0, f); + s1 = vmlaq_f32(s1, x1, f); + x0 = vaddq_f32(vld1q_f32(S+8), vld1q_f32(S2+8)); + x1 = vaddq_f32(vld1q_f32(S+12), vld1q_f32(S2+12)); + s2 = vmlaq_f32(s2, x0, f); + s3 = vmlaq_f32(s3, x1, f); + + } + s0 = vaddq_f32(s0, vld1q_f32(dst+i)); + s1 = vaddq_f32(s1, vld1q_f32(dst+i+4)); + s2 = vaddq_f32(s2, vld1q_f32(dst+i+8)); + s3 = vaddq_f32(s3, vld1q_f32(dst+i+12)); + + vst1q_f32(dst + i, s0); + vst1q_f32(dst + i + 4, s1); + vst1q_f32(dst + i + 8, s2); + vst1q_f32(dst + i + 12, s3); + } + + for( ; i <= width - 4; i += 4 ) + { + float32x4_t f = vdupq_n_f32(ky[0]); + float32x4_t x0, s0 = vld1q_f32(src[0] + i); + s0 = vmulq_f32(s0, f); + + for( k = 1; k <= ksize2; k++ ) + { + float32x4_t f = vdupq_n_f32(ky[k]); + S = src[k] + i; + S2 = src[-k] + i; + x0 = vaddq_f32(vld1q_f32(S), vld1q_f32(S2)); + s0 = vmlaq_f32(s0, x0, f); + } + s0 = vaddq_f32(s0, vld1q_f32(dst + i)); + vst1q_f32(dst + i, s0); + } + } + + return i; + + } + int ksize; +}; + + +struct SymmColumnSmallVec_32f +{ + SymmColumnSmallVec_32f() { } + SymmColumnSmallVec_32f(int _len, int) + { + ksize = _len; + } + + int operator()(const uchar** _src, uchar* _dst, uchar * kernel, int & count, int width) 
const + { + (void)count; + + int ksize2 = (ksize)/2; + const float* ky = (float*)kernel + ksize2; + int i = 0; + const float** src = (const float**)_src; + const float *S0 = src[-1], *S1 = src[0], *S2 = src[1]; + float* dst = (float*)_dst; + { + float32x4_t k0 = vdupq_n_f32(ky[0]), k1 = vdupq_n_f32(ky[1]); + for( ; i <= width - 8; i += 8 ) + { + float32x4_t s0, s1, x0, x1; + s0 = vld1q_f32(S1 + i); + s1 = vld1q_f32(S1 + i + 4); + s0 = vmulq_f32(s0, k0); + s1 = vmulq_f32(s1, k0); + x0 = vaddq_f32(vld1q_f32(S0 + i), vld1q_f32(S2 + i)); + x1 = vaddq_f32(vld1q_f32(S0 + i + 4), vld1q_f32(S2 + i + 4)); + s0 = vmlaq_f32(s0, x0, k1); + s1 = vmlaq_f32(s1, x1, k1); + s0 = vaddq_f32(s0, vld1q_f32(dst + i)); + s1 = vaddq_f32(s1, vld1q_f32(dst + i + 4)); + vst1q_f32(dst + i, s0); + vst1q_f32(dst + i + 4, s1); + } + } + + return i; + } + int ksize; +}; + +////////////////////////////////////////////////////////////////////////////////////// +//%RowFilter% +////////////////////////////////////////////////////////////////////////////////////// + +BaseRowFilter::BaseRowFilter() { ksize = anchor = -1; } +BaseRowFilter::~BaseRowFilter() {} + +template struct RowFilter : public BaseRowFilter +{ + RowFilter(int _ksize, int _anchor, const VecOp& _vecOp=VecOp() ) + { + anchor = _anchor; + ksize = _ksize; + vecOp = _vecOp; + } + + void operator()(const uchar* src, uchar* dst, uchar* kernel, int width, int cn) + { + int _ksize = ksize; + const DT* kx = (DT* )kernel; + const ST* S; + DT* D = (DT*)dst; + int i, k; + + i = vecOp(src, dst, kernel, width, cn); + width *= cn; +#if MEGCV_ENABLE_UNROLLED + for( ; i <= width - 4; i += 4 ) + { + S = (const ST*)src + i; + DT f = kx[0]; + DT s0 = f*S[0], s1 = f*S[1], s2 = f*S[2], s3 = f*S[3]; + + for( k = 1; k < _ksize; k++ ) + { + S += cn; + f = kx[k]; + s0 += f*S[0]; s1 += f*S[1]; + s2 += f*S[2]; s3 += f*S[3]; + } + + D[i] = s0; D[i+1] = s1; + D[i+2] = s2; D[i+3] = s3; + } +#endif + for( ; i < width; i++ ) + { + S = (const ST*)src + i; + DT s0 = kx[0]*S[0]; + for( k = 1; k < _ksize; k++ ) + { + S += cn; + s0 += kx[k]*S[0]; + } + D[i] = s0; + } + } + VecOp vecOp; +}; + + +template struct SymmRowSmallFilter : + public RowFilter +{ + SymmRowSmallFilter(int _ksize, int _anchor, + const VecOp& _vecOp = VecOp() ) + : RowFilter( _ksize, _anchor, _vecOp ) + {} + + void operator()(const uchar* src, uchar* dst, uchar* kernel, int width, int cn) + { + int ksize2 = this->ksize/2, ksize2n = ksize2*cn; + const DT* kx = (DT*)kernel + ksize2; + DT* D = (DT*)dst; + int i = this->vecOp(src, dst, kernel, width, cn), j, k; + const ST* S = (const ST*)src + i + ksize2n; + width *= cn; + + { + if( this->ksize == 1 && kx[0] == 1 ) + { + for( ; i <= width - 2; i += 2 ) + { + DT s0 = S[i], s1 = S[i+1]; + D[i] = s0; D[i+1] = s1; + } + S += i; + } + else if( this->ksize == 3 ) + { + DT k0 = kx[0], k1 = kx[1]; + for( ; i <= width - 2; i += 2, S += 2 ) + { + DT s0 = S[0]*k0 + (S[-cn] + S[cn])*k1, s1 = S[1]*k0 + (S[1-cn] + S[1+cn])*k1; + D[i] = s0; D[i+1] = s1; + } + } + else if( this->ksize == 5 ) + { + DT k0 = kx[0], k1 = kx[1], k2 = kx[2]; + for( ; i <= width - 2; i += 2, S += 2 ) + { + DT s0 = S[0]*k0 + (S[-cn] + S[cn])*k1 + (S[-cn*2] + S[cn*2])*k2; + DT s1 = S[1]*k0 + (S[1-cn] + S[1+cn])*k1 + (S[1-cn*2] + S[1+cn*2])*k2; + D[i] = s0; D[i+1] = s1; + } + } + + for( ; i < width; i++, S++ ) + { + DT s0 = kx[0]*S[0]; + for( k = 1, j = cn; k <= ksize2; k++, j += cn ) + s0 += kx[k]*(S[j] + S[-j]); + D[i] = s0; + } + } + } +}; + +template + BaseRowFilter * getLinearRowFilter(int ksize, bool is_symm_kernel) + 
{ + // TODO: calculate anchor + int anchor = ksize/2; + if(is_symm_kernel) { + if( ksize <= 5 ) + { + //if( typeid(T) == typeid(float) && typeid(T1) == typeid(float)) + return new SymmRowSmallFilter + (ksize, anchor, SymmRowSmallVec_32f(ksize)); + } + + //if( typeid(T) == typeid(float) && typeid(T1) == typeid(float)) + return new RowFilter + (ksize, anchor, RowVec_32f(ksize)); + } else { + //if( typeid(T) == typeid(float) && typeid(T1) == typeid(float)) + return new RowFilter + (ksize, anchor, RowVec_32f(ksize)); + } + + //printf("Unsupported combination of source format (=%s), and buffer format (=%s)", + // typeid(T).name(), typeid(T1).name()); + //exit(1); + } +////////////////////////////////////////////////////////////////////////////////////// + + +////////////////////////////////////////////////////////////////////////////////////// +//%BaseColFilter% +////////////////////////////////////////////////////////////////////////////////////// + +BaseColumnFilter::BaseColumnFilter() { ksize = anchor = -1; } +BaseColumnFilter::~BaseColumnFilter() {} +void BaseColumnFilter::reset() {} + +template struct ColumnFilter : public BaseColumnFilter +{ + typedef typename CastOp::type1 ST; + typedef typename CastOp::rtype DT; + + ColumnFilter(int _ksize, int _anchor, + const CastOp& _castOp=CastOp(), + const VecOp& _vecOp=VecOp()) + { + this->anchor = _anchor; + this->ksize = _ksize; + this->castOp0 = _castOp; + this->vecOp = _vecOp; + } + + void operator()(const uchar** src, uchar* dst, uchar* kernel, int dststep, int count, int width) + { + const ST* ky = (ST*)kernel; + int i = 0, k; + CastOp castOp = this->castOp0; + + { + for( ; count > 0; count--, dst += dststep, src++ ) + { + DT* D = (DT*)dst; + i = (this->vecOp)(src, dst, kernel, count, width); +#if MEGCV_ENABLE_UNROLLED + for( ; i <= width - 4; i += 4 ) + { + ST f = ky[0]; + const ST* S = (const ST*)src[0] + i; + ST s0 = f*S[0], s1 = f*S[1], + s2 = f*S[2], s3 = f*S[3]; + + for( k = 1; k < ksize; k++ ) + { + S = (const ST*)src[k] + i; + f = ky[k]; + s0 += f*S[0]; + s1 += f*S[1]; + s2 += f*S[2]; + s3 += f*S[3]; + } + + D[i] += castOp(s0); D[i+1] += castOp(s1); + D[i+2] += castOp(s2); D[i+3] += castOp(s3); + } +#endif + for( ; i < width; i++ ) + { + ST s0 = D[i]; + //ST s0 = ky[0]*((const ST*)src[0])[i]; + for( k = 0; k < ksize; k++ ) { + s0 += ky[k]* ((const ST*)src[k])[i]; + } + D[i] = castOp(s0); + //D[i] += castOp(s0); + } + } + } + } + CastOp castOp0; + VecOp vecOp; +}; + +template struct SymmColumnFilter : public BaseColumnFilter +{ + typedef typename CastOp::type1 ST; + typedef typename CastOp::rtype DT; + + SymmColumnFilter(int _ksize, int _anchor, + const CastOp& _castOp=CastOp(), + const VecOp& _vecOp=VecOp()) + { + this->anchor = _anchor; + this->ksize = _ksize; + this->castOp0 = _castOp; + this->vecOp = _vecOp; + } + + void operator()(const uchar** src, uchar* dst, uchar* kernel, int dststep, int count, int width) + { + int ksize2 = this->ksize/2; + const ST* ky = (ST*)kernel + ksize2; + int i, k; + CastOp castOp = this->castOp0; + src += ksize2; + + { + for( ; count > 0; count--, dst += dststep, src++ ) + { + DT* D = (DT*)dst; + i = (this->vecOp)(src, dst, kernel, count, width); +#if MEGCV_ENABLE_UNROLLED + for( ; i <= width - 4; i += 4 ) + { + ST f = ky[0]; + const ST* S = (const ST*)src[0] + i, *S2; + ST s0 = f*S[0], s1 = f*S[1], + s2 = f*S[2], s3 = f*S[3]; + + for( k = 1; k <= ksize2; k++ ) + { + S = (const ST*)src[k] + i; + S2 = (const ST*)src[-k] + i; + f = ky[k]; + s0 += f*(S[0] + S2[0]); + s1 += f*(S[1] + S2[1]); + s2 += 
f*(S[2] + S2[2]); + s3 += f*(S[3] + S2[3]); + } + + D[i] += castOp(s0); D[i+1] += castOp(s1); + D[i+2] += castOp(s2); D[i+3] += castOp(s3); + } +#endif + for( ; i < width; i++ ) + { + ST s0 = ky[0]*((const ST*)src[0])[i]; + for( k = 1; k <= ksize2; k++ ) { + s0 += ky[k]*(((const ST*)src[k])[i] + ((const ST*)src[-k])[i]); + //s0 += ky[k]*((const ST*)src[k])[i]; + //s0 += ky[k]*((const ST*)src[-k])[i]; + } + D[i] += castOp(s0); + } + } + } + } + CastOp castOp0; + VecOp vecOp; +}; + + +template + struct SymmColumnSmallFilter : public SymmColumnFilter +{ + typedef typename CastOp::type1 ST; + typedef typename CastOp::rtype DT; + + SymmColumnSmallFilter( int _ksize, int _anchor, + const CastOp & _castOp=CastOp(), + const VecOp & _vecOp=VecOp()) + : SymmColumnFilter(_ksize, _anchor, _castOp, _vecOp ) + { + megdnn_assert(this->ksize == 3 ); + } + + void operator()(const uchar** src, uchar* dst, uchar* kernel, int dststep, int count, int width) + { + int ksize2 = this->ksize/2; + const ST* ky = (ST*)kernel + ksize2; + int i = 0; + ST f0 = ky[0], f1 = ky[1]; + CastOp castOp = this->castOp0; + src += ksize2; + + /* + if((typeid(ST) == typeid(int) && typeid(DT) == typeid(uchar))) + { + (this->vecOp)(src, dst, kernel, count, width); + } + */ + for( ; count > 0; count--, dst += dststep, src++ ) + { + DT* D = (DT*)dst; + + i = (this->vecOp)(src, dst, kernel, count, width); + if(count == 0) + break; + const ST* S0 = (const ST*)src[-1]; + const ST* S1 = (const ST*)src[0]; + const ST* S2 = (const ST*)src[1]; + { +#if MEGCV_ENABLE_UNROLLED + for( ; i <= width - 4; i += 4 ) + { + ST s0 = (S0[i] + S2[i])*f1 + S1[i]*f0; + ST s1 = (S0[i+1] + S2[i+1])*f1 + S1[i+1]*f0; + D[i] += castOp(s0); + D[i+1] += castOp(s1); + + s0 = (S0[i+2] + S2[i+2])*f1 + S1[i+2]*f0; + s1 = (S0[i+3] + S2[i+3])*f1 + S1[i+3]*f0; + D[i+2] += castOp(s0); + D[i+3] += castOp(s1); + } +#endif + for( ; i < width; i ++ ) + { + ST s0 = (S0[i] + S2[i])*f1 + S1[i]*f0; + D[i] += castOp(s0); + } + } + } + + } +}; + + +template + BaseColumnFilter * getLinearColumnFilter(int ksize, int /*bits*/, bool is_symm_kernel) + { + int anchor = ksize/2; + { + if(is_symm_kernel) { + if( ksize == 3 ) + { + + //if( typeid(T1) == typeid(float) && typeid(T) == typeid(float) ) + return new SymmColumnSmallFilter,SymmColumnSmallVec_32f> + (ksize, anchor, FixedPtCastEx(0), + SymmColumnSmallVec_32f(ksize, 0)); + } + //if( typeid(T1) == typeid(float) && typeid(T) == typeid(float) ) + return new SymmColumnFilter, SymmColumnVec_32f> + (ksize, anchor, FixedPtCastEx(), + SymmColumnVec_32f(ksize, 0)); + } else { + //if( typeid(T1) == typeid(float) && typeid(T) == typeid(float) ) + return new ColumnFilter, ColumnVec_32f> + (ksize, anchor, FixedPtCastEx(), + ColumnVec_32f(ksize, 0)); + } + } + //printf("Unsupported combination of buffer format (=%s), and destination format (=%s)", + // typeid(T1).name(), typeid(T).name()); + //exit(1); + } + +////////////////////////////////////////////////////////////////////////////////////// +////%FilterEngine% +////////////////////////////////////////////////////////////////////////////////////// + + FilterEngine::FilterEngine(const int &ih, const int &iw, + const int &oh, const int &ow, + const int &kh, const int &kw, + const int &anchor_h, const int &anchor_w, + BorderMode borderType, + bool is_symm_kernel) { + init(ih, iw, oh, ow, kh, kw, anchor_h, anchor_w, borderType, is_symm_kernel); + } + + + FilterEngine::~FilterEngine() + { + delete rowFilter_; + delete colFilter_; + } + + void FilterEngine::init(const int &ih, const int &iw, + const 
int &oh, const int &ow, + const int &kh, const int &kw, + const int &anchor_h, const int &anchor_w, + BorderMode borderType, + bool is_symm_kernel) { + // reduce warning + int wrn = ih + iw + oh; ++wrn; + + ksize_x_ = kw; + ksize_y_ = kh; + anchor_x_ = anchor_w; + anchor_y_ = anchor_h; + borderType_ = borderType; + is_symm_kernel_ = is_symm_kernel; + + rowFilter_ = getLinearRowFilter(kw, is_symm_kernel_); + colFilter_ = getLinearColumnFilter(kh, 0, is_symm_kernel_); + + rowBufferOutputRow_ = 1; + maxBufferRow_ = ksize_y_ + rowBufferOutputRow_ - 1; + //int rowBuffStride_ = sizeof(float)*(int)align_size(maxWidth + (ksize_y_ - 1),VEC_ALIGN); + rowBuffStride_ = sizeof(float) * (int)align_size(ow, VEC_ALIGN); + row_ptr_.resize(maxBufferRow_); + ringBuf_.resize(rowBuffStride_ * maxBufferRow_ + VEC_ALIGN); + + // There is no need to use constBorder when padding == 0. + //if (borderType_ = BORDER_CONSTANT) { + // constBorderRow.resize(sizeof(int) * (maxWidth + ksize.cols() - 1) + VEC_ALIGN); + //} + + + } + + void FilterEngine::exec( const TensorND & src, + const TensorND & kernel_x, + const TensorND & kernel_y, + const TensorND & dst) { + + //int stride_src = src.layout.stride[1]; + //int stride_dst = dst.layout.stride[1]; + //float *src0 = src.ptr(); + //float *dst0 = dst.ptr(); + float * src_cur_row = src.ptr(); + float * src_cur_step = src.ptr(); + float * dst_cur_chan = dst.ptr(); + int width_src = (int)src.layout.shape[3]; + int width_dst = (int)dst.layout.shape[3]; + int height_src = (int)src.layout.shape[2]; + //int height_dst = dst.layout.shape[2]; + int kernel_chan_stride = (int)kernel_x.layout.stride[1]; + memset(dst.ptr(), 0, sizeof(float) * dst.layout.total_nr_elems()); + + for(int step = 0; step < (int)src.layout.shape[0]; ++step) { + for(int chan_out = 0; chan_out < (int)dst.layout.shape[1]; + ++ chan_out, dst_cur_chan += dst.layout.stride[1]) { + float* kx = kernel_x.ptr(); + float* ky = kernel_y.ptr(); + src_cur_row = src_cur_step; + // handle a channel of input + for(int chan_in = 0; chan_in < (int)src.layout.shape[1]; ++ chan_in) { + // 1. init row buffer borden + // No need to init row border when padding == 0. + + // 2. fill ring buffer & calculate + int row_count = 0; + int row_ptr_pos = 0; + int dststep = dst.layout.stride[2]; + int bufRows = (int)row_ptr_.size(); + int bi = 0; + float* dst_cur_row = dst_cur_chan; + for(row_count = 0; row_count < height_src; + ++row_count, src_cur_row += width_src) { + + //2.1 Get tab row. No need to do this when padding == 0. + + //2.2 Calculate a row. 
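The loop body that follows implements the standard separable-filter schedule: every source row is pushed through the 1-D row filter into a ring buffer of ksize_y rows, and once the buffer holds a full kernel window the 1-D column filter collapses those rows into one output row. A minimal scalar sketch of that schedule is given below, assuming a single input channel, no padding, and a hypothetical helper name (the real exec additionally accumulates over input channels through the += in its column filters):

#include <vector>

// Minimal scalar model of the row-buffer / column-filter schedule above:
// one input channel, no padding, output is the valid region
// (ow = iw - kw + 1; output row r uses source rows r .. r + kh - 1).
static void separable_conv_valid(const float* src, float* dst, int ih, int iw,
                                 const std::vector<float>& kx,   // row kernel, size kw
                                 const std::vector<float>& ky) { // column kernel, size kh
    const int kw = (int)kx.size(), kh = (int)ky.size();
    const int ow = iw - kw + 1;
    std::vector<std::vector<float>> ring((size_t)kh, std::vector<float>((size_t)ow));
    for (int r = 0; r < ih; ++r) {
        // Row pass: horizontally filter source row r into one ring-buffer slot.
        std::vector<float>& brow = ring[r % kh];
        for (int c = 0; c < ow; ++c) {
            float s = 0.f;
            for (int k = 0; k < kw; ++k)
                s += kx[k] * src[r * iw + c + k];
            brow[c] = s;
        }
        if (r < kh - 1)
            continue;  // keep filling the ring buffer until it holds kh rows
        // Column pass: combine the last kh filtered rows into output row r - kh + 1.
        float* drow = dst + (r - kh + 1) * ow;
        for (int c = 0; c < ow; ++c) {
            float s = 0.f;
            for (int k = 0; k < kh; ++k)
                s += ky[k] * ring[(r - kh + 1 + k) % kh][c];
            drow[c] = s;
        }
    }
}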
+ bi = row_count % bufRows; + uchar* brow = align_ptr(&ringBuf_[0], VEC_ALIGN) + bi * rowBuffStride_; + if(row_count < bufRows - 1) { + row_ptr_[bi] = (float*)brow; + } else { + row_ptr_[bufRows - 1] = (float*)brow; + } + + // Get a row & make border + //uchar* row = &srcRow[0]; + //memcpy( row + _dx1*esz, src, (width1 - _dx2 - _dx1)*esz ); + uchar* row = (uchar*)src_cur_row; + (*rowFilter_)(row, brow, (uchar*)kx, width_dst, 1); + // operator()(const uchar* src, uchar* dst, uchar* kernel, int width, int cn) + + // Keeping fill the ring_buff until its length is ky + if(row_count < bufRows - 1) { + ++ row_ptr_pos; + continue; + } + + // 2.3 Calculate column + // operator()(const uchar** src, uchar* dst, ST* kernel, int dststep, int count, int width) + (*colFilter_)((const uchar**)(&row_ptr_[0]), (uchar*)dst_cur_row, + (uchar*)ky, dststep, rowBufferOutputRow_, width_dst); + + // Update row_ptr + for(int i = 0; i< bufRows - 1; ++i) { + row_ptr_[i] = row_ptr_[i+1]; + } + dst_cur_row += width_dst; //dst.layout.stride[2]; + } + kx += kernel_chan_stride; + ky += kernel_chan_stride; + } // chan_in + } // chan_out + src_cur_step += src.layout.shape[0]; + } //step_in + } + +} // namespace sep_conv +} // namespace arm_common +} // namespace megdnn diff --git a/dnn/src/arm_common/separable_filter/filter.h b/dnn/src/arm_common/separable_filter/filter.h new file mode 100644 index 00000000..b4386656 --- /dev/null +++ b/dnn/src/arm_common/separable_filter/filter.h @@ -0,0 +1,662 @@ +/** + * By downloading, copying, installing or using the software you agree to this license. + * If you do not agree to this license, do not download, install, + * copy or use the software. + * + * + * License Agreement + * For Open Source Computer Vision Library + * (3-clause BSD License) + * + * Copyright (C) 2000-2020, Intel Corporation, all rights reserved. + * Copyright (C) 2009-2011, Willow Garage Inc., all rights reserved. + * Copyright (C) 2009-2016, NVIDIA Corporation, all rights reserved. + * Copyright (C) 2010-2013, Advanced Micro Devices, Inc., all rights reserved. + * Copyright (C) 2015-2016, OpenCV Foundation, all rights reserved. + * Copyright (C) 2015-2016, Itseez Inc., all rights reserved. + * Copyright (C) 2019-2020, Xperience AI, all rights reserved. + * Third party copyrights are property of their respective owners. + * + * Redistribution and use in source and binary forms, with or without modification, + * are permitted provided that the following conditions are met: + * + * * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * * Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * * Neither the names of the copyright holders nor the names of the contributors + * may be used to endorse or promote products derived from this software + * without specific prior written permission. + * + * This software is provided by the copyright holders and contributors "as is" and + * any express or implied warranties, including, but not limited to, the implied + * warranties of merchantability and fitness for a particular purpose are disclaimed. 
+ * In no event shall copyright holders or contributors be liable for any direct, + * indirect, incidental, special, exemplary, or consequential damages + * (including, but not limited to, procurement of substitute goods or services; + * loss of use, data, or profits; or business interruption) however caused + * and on any theory of liability, whether in contract, strict liability, + * or tort (including negligence or otherwise) arising in any way out of + * the use of this software, even if advised of the possibility of such damage. + * + * --------------------------------------------------------------------------- + * \file dnn/src/arm_common/separable_filter/filter.h + * + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * + * This file has been modified by Megvii ("Megvii Modifications"). + * All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved. + * + * --------------------------------------------------------------------------- + */ + +#pragma once +#include "src/common/cv/filter.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include +#include + +namespace megdnn { +namespace megcv { +namespace sep_filter { + +using namespace filter_common; + +struct SymmRowSmallVec_8u32s { + SymmRowSmallVec_8u32s() {} + SymmRowSmallVec_8u32s(const uchar* _kernel, int _len) { + kernel = (int*)_kernel; + ksize = _len; + } + + int operator()(const uchar* src, uchar* _dst, int width, int cn) const { + int i = 0, _ksize = ksize; + int* dst = (int*)_dst; + const int* kx = kernel + _ksize / 2; + + src += (_ksize / 2) * cn; + width *= cn; + + if (_ksize == 1) + return 0; + if (_ksize == 3) { + if (kx[0] == 2 && kx[1] == 1) { + uint16x8_t zq = vdupq_n_u16(0); + + for (; i <= width - 8; i += 8, src += 8) { + uint8x8_t x0, x1, x2; + x0 = vld1_u8((uint8_t*)(src - cn)); + x1 = vld1_u8((uint8_t*)(src)); + x2 = vld1_u8((uint8_t*)(src + cn)); + + uint16x8_t y0, y1, y2; + y0 = vaddl_u8(x0, x2); + y1 = vshll_n_u8(x1, 1); + y2 = vaddq_u16(y0, y1); + + uint16x8x2_t str; + str.val[0] = y2; + str.val[1] = zq; + vst2q_u16((uint16_t*)(dst + i), str); + } + } else if (kx[0] == -2 && kx[1] == 1) + return 0; + else { + int32x4_t k32 = vdupq_n_s32(0); + k32 = vld1q_lane_s32(kx, k32, 0); + k32 = vld1q_lane_s32(kx + 1, k32, 1); + + int16x4_t k = vqmovn_s32(k32); + + uint8x8_t z = vdup_n_u8(0); + + for (; i <= width - 8; i += 8, src += 8) { + uint8x8_t x0, x1, x2; + x0 = vld1_u8((uint8_t*)(src - cn)); + x1 = vld1_u8((uint8_t*)(src)); + x2 = vld1_u8((uint8_t*)(src + cn)); + + int16x8_t y0, y1; + int32x4_t y2, y3; + y0 = vreinterpretq_s16_u16(vaddl_u8(x1, z)); + y1 = vreinterpretq_s16_u16(vaddl_u8(x0, x2)); + y2 = vmull_lane_s16(vget_low_s16(y0), k, 0); + y2 = vmlal_lane_s16(y2, vget_low_s16(y1), k, 1); + y3 = vmull_lane_s16(vget_high_s16(y0), k, 0); + y3 = vmlal_lane_s16(y3, vget_high_s16(y1), k, 1); + + vst1q_s32((int32_t*)(dst + i), y2); + vst1q_s32((int32_t*)(dst + i + 4), y3); + } + } + } else if (_ksize == 5) { + if (kx[0] == -2 && kx[1] == 0 && kx[2] == 1) + return 0; + else { + int32x4_t k32 = vdupq_n_s32(0); + k32 = vld1q_lane_s32(kx, k32, 0); + k32 = vld1q_lane_s32(kx + 1, k32, 1); + k32 = vld1q_lane_s32(kx + 2, k32, 2); + + int16x4_t k = vqmovn_s32(k32); + + uint8x8_t z = 
vdup_n_u8(0); + + for (; i <= width - 8; i += 8, src += 8) { + uint8x8_t x0, x1, x2, x3, x4; + x0 = vld1_u8((uint8_t*)(src - cn)); + x1 = vld1_u8((uint8_t*)(src)); + x2 = vld1_u8((uint8_t*)(src + cn)); + + int16x8_t y0, y1; + int32x4_t accl, acch; + y0 = vreinterpretq_s16_u16(vaddl_u8(x1, z)); + y1 = vreinterpretq_s16_u16(vaddl_u8(x0, x2)); + accl = vmull_lane_s16(vget_low_s16(y0), k, 0); + accl = vmlal_lane_s16(accl, vget_low_s16(y1), k, 1); + acch = vmull_lane_s16(vget_high_s16(y0), k, 0); + acch = vmlal_lane_s16(acch, vget_high_s16(y1), k, 1); + + int16x8_t y2; + x3 = vld1_u8((uint8_t*)(src - cn * 2)); + x4 = vld1_u8((uint8_t*)(src + cn * 2)); + y2 = vreinterpretq_s16_u16(vaddl_u8(x3, x4)); + accl = vmlal_lane_s16(accl, vget_low_s16(y2), k, 2); + acch = vmlal_lane_s16(acch, vget_high_s16(y2), k, 2); + + vst1q_s32((int32_t*)(dst + i), accl); + vst1q_s32((int32_t*)(dst + i + 4), acch); + } + } + } + + return i; + } + + int* kernel; + size_t ksize; +}; + +struct SymmColumnVec_32s8u { + SymmColumnVec_32s8u() {} + SymmColumnVec_32s8u(const uchar* _kernel, int _len, int _bits) { + ksize = _len; + kernel = (float*)malloc(sizeof(float) * ksize); + + for (size_t i = 0; i < ksize; i++) + kernel[i] = (float)(((int*)_kernel)[i]) * (1. / (1 << _bits)); + } + + ~SymmColumnVec_32s8u() { free(kernel); } + + int operator()(const uchar** _src, uchar* dst, int& count, + int width) const { + MEGDNN_MARK_USED_VAR(count); + int _ksize = ksize; + int ksize2 = _ksize / 2; + const float* ky = kernel + ksize2; + const int** src = (const int**)_src; + const int *S, *S2; + int i = 0, k; + + float32x4_t d4 = vdupq_n_f32(0); + + if (_ksize == 1) + return 0; + + float32x2_t k32; + k32 = vdup_n_f32(0); + k32 = vld1_lane_f32(ky, k32, 0); + k32 = vld1_lane_f32(ky + 1, k32, 1); + + for (; i <= width - 8; i += 8) { + float32x4_t accl, acch; + float32x4_t f0l, f0h, f1l, f1h, f2l, f2h; + + S = src[0] + i; + + f0l = vcvtq_f32_s32(vld1q_s32(S)); + f0h = vcvtq_f32_s32(vld1q_s32(S + 4)); + + S = src[1] + i; + S2 = src[-1] + i; + + f1l = vcvtq_f32_s32(vld1q_s32(S)); + f1h = vcvtq_f32_s32(vld1q_s32(S + 4)); + f2l = vcvtq_f32_s32(vld1q_s32(S2)); + f2h = vcvtq_f32_s32(vld1q_s32(S2 + 4)); + + accl = acch = d4; + accl = vmlaq_lane_f32(accl, f0l, k32, 0); + acch = vmlaq_lane_f32(acch, f0h, k32, 0); + accl = vmlaq_lane_f32(accl, vaddq_f32(f1l, f2l), k32, 1); + acch = vmlaq_lane_f32(acch, vaddq_f32(f1h, f2h), k32, 1); + + for (k = 2; k <= ksize2; k++) { + S = src[k] + i; + S2 = src[-k] + i; + + float32x4_t f3l, f3h, f4l, f4h; + f3l = vcvtq_f32_s32(vld1q_s32(S)); + f3h = vcvtq_f32_s32(vld1q_s32(S + 4)); + f4l = vcvtq_f32_s32(vld1q_s32(S2)); + f4h = vcvtq_f32_s32(vld1q_s32(S2 + 4)); + + accl = vmlaq_n_f32(accl, vaddq_f32(f3l, f4l), ky[k]); + acch = vmlaq_n_f32(acch, vaddq_f32(f3h, f4h), ky[k]); + } + + int32x4_t s32l, s32h; + s32l = vcvtq_s32_f32(accl); + s32h = vcvtq_s32_f32(acch); + + int16x4_t s16l, s16h; + s16l = vqmovn_s32(s32l); + s16h = vqmovn_s32(s32h); + + uint8x8_t u8; + u8 = vqmovun_s16(vcombine_s16(s16l, s16h)); + + vst1_u8((uint8_t*)(dst + i), u8); + } + + return i; + } + + float* kernel; + size_t ksize; +}; + +//! 
32f + +struct RowVec_32f { + RowVec_32f() {} + + RowVec_32f(const uchar* _kernel, int _len) { + ksize = _len; + kernel = (float*)_kernel; + } + + int operator()(const uchar* _src, uchar* _dst, int width, int cn) const { + int _ksize = ksize; + const float* src0 = (const float*)_src; + float* dst = (float*)_dst; + const float* _kx = (float*)kernel; + + int i = 0, k; + width *= cn; + + for (; i <= width - 8; i += 8) { + const float* src = src0 + i; + float32x4_t f, s0 = vdupq_n_f32(0), s1 = s0, x0, x1; + for (k = 0; k < _ksize; k++, src += cn) { + f = vdupq_n_f32(_kx[k]); + x0 = vld1q_f32(src); + x1 = vld1q_f32(src + 4); + s0 = vmlaq_f32(s0, x0, f); + s1 = vmlaq_f32(s1, x1, f); + } + vst1q_f32(dst + i, s0); + vst1q_f32(dst + i + 4, s1); + } + for (; i <= width - 4; i += 4) { + const float* src = src0 + i; + float32x4_t f, s0 = vdupq_n_f32(0), x0; + for (k = 0; k < _ksize; k++, src += cn) { + f = vdupq_n_f32(_kx[k]); + + x0 = vld1q_f32(src); + s0 = vmlaq_f32(s0, x0, f); + } + vst1q_f32(dst + i, s0); + } + return i; + } + + float* kernel; + int ksize; +}; + +struct SymmRowSmallVec_32f { + SymmRowSmallVec_32f() {} + SymmRowSmallVec_32f(const uchar* _kernel, int _len) { + ksize = _len; + kernel = (float*)_kernel; + } + + int operator()(const uchar* _src, uchar* _dst, int width, int cn) const { + int i = 0, _ksize = ksize; + float* dst = (float*)_dst; + const float* src = (const float*)_src + (_ksize / 2) * cn; + const float* kx = (float*)kernel + _ksize / 2; + width *= cn; + + { + if (_ksize == 1) + return 0; + if (_ksize == 3) { + float32x4_t k0 = vdupq_n_f32(kx[0]), k1 = vdupq_n_f32(kx[1]); + for (; i <= width - 8; i += 8, src += 8) { + float32x4_t x0, x1, x2, y0, y1, y2; + x0 = vld1q_f32(src - cn); + x1 = vld1q_f32(src); + x2 = vld1q_f32(src + cn); + y0 = vld1q_f32(src - cn + 4); + y1 = vld1q_f32(src + 4); + y2 = vld1q_f32(src + cn + 4); + + x0 = vmulq_f32(vaddq_f32(x0, x2), k1); + y0 = vmulq_f32(vaddq_f32(y0, y2), k1); + x0 = vmlaq_f32(x0, x1, k0); + y0 = vmlaq_f32(y0, y1, k0); + vst1q_f32(dst + i, x0); + vst1q_f32(dst + i + 4, y0); + } + } else if (_ksize == 5) { + float32x4_t k0 = vdupq_n_f32(kx[0]), k1 = vdupq_n_f32(kx[1]), + k2 = vdupq_n_f32(kx[2]); + for (; i <= width - 8; i += 8, src += 8) { + float32x4_t x0, x1, x2, y0, y1, y2; + x0 = vld1q_f32(src - cn); + x1 = vld1q_f32(src); + x2 = vld1q_f32(src + cn); + y0 = vld1q_f32(src - cn + 4); + y1 = vld1q_f32(src + 4); + y2 = vld1q_f32(src + cn + 4); + + x0 = vmulq_f32(vaddq_f32(x0, x2), k1); + y0 = vmulq_f32(vaddq_f32(y0, y2), k1); + x0 = vmlaq_f32(x0, x1, k0); + y0 = vmlaq_f32(y0, y1, k0); + + x2 = vaddq_f32(vld1q_f32(src + cn * 2), + vld1q_f32(src - cn * 2)); + y2 = vaddq_f32(vld1q_f32(src + cn * 2 + 4), + vld1q_f32(src - cn * 2 + 4)); + x0 = vmlaq_f32(x0, x2, k2); + y0 = vmlaq_f32(y0, y2, k2); + + vst1q_f32(dst + i, x0); + vst1q_f32(dst + i + 4, y0); + } + } + } + return i; + } + + float* kernel; + int ksize; +}; + +struct ColumnVec_32f { + ColumnVec_32f() {} + ColumnVec_32f(const uchar* _kernel, int _len, int) { + ksize = _len; + kernel = (float*)_kernel; + } + + int operator()(const uchar** _src, uchar* _dst, int&, int width) const { + const float* ky = (const float*)kernel; + int i = 0, k; + const float** src = (const float**)_src; + const float* S; + float* dst = (float*)_dst; + + { + for (; i <= width - 16; i += 16) { + float32x4_t f = vdupq_n_f32(ky[0]); + + float32x4_t s0, s1, s2, s3; + float32x4_t x0, x1; + S = src[0] + i; + s0 = vld1q_f32(S); + s1 = vld1q_f32(S + 4); + s0 = vmulq_f32(s0, f); + s1 = vmulq_f32(s1, f); + s2 = 
vld1q_f32(S + 8); + s3 = vld1q_f32(S + 12); + s2 = vmulq_f32(s2, f); + s3 = vmulq_f32(s3, f); + + for (k = 1; k < ksize; k++) { + S = src[k] + i; + float32x4_t f = vdupq_n_f32(ky[k]); + x0 = vld1q_f32(S); + x1 = vld1q_f32(S + 4); + s0 = vmlaq_f32(s0, f, x0); + s1 = vmlaq_f32(s1, f, x1); + + x0 = vld1q_f32(S + 8); + x1 = vld1q_f32(S + 12); + s2 = vmlaq_f32(s2, f, x0); + s3 = vmlaq_f32(s3, f, x1); + } + vst1q_f32(dst + i, s0); + vst1q_f32(dst + i + 4, s1); + vst1q_f32(dst + i + 8, s2); + vst1q_f32(dst + i + 12, s3); + } + + for (; i <= width - 4; i += 4) { + float32x4_t f = vdupq_n_f32(ky[0]); + + float32x4_t x0, s0 = vld1q_f32(src[0] + i); + s0 = vmulq_f32(s0, f); + + for (k = 1; k < ksize; k++) { + float32x4_t f = vdupq_n_f32(ky[k]); + S = src[k] + i; + x0 = vld1q_f32(S); + s0 = vmlaq_f32(s0, f, x0); + } + vst1q_f32(dst + i, s0); + } + } + + return i; + } + + float* kernel; + int ksize; +}; + +struct SymmColumnVec_32f { + SymmColumnVec_32f() {} + SymmColumnVec_32f(const uchar* _kernel, int _len, int) { + ksize = _len; + kernel = (float*)_kernel; + } + + int operator()(const uchar** _src, uchar* _dst, int&, int width) const { + int ksize2 = (ksize) / 2; + const float* ky = (const float*)kernel + ksize2; + int i = 0, k; + const float** src = (const float**)_src; + const float *S, *S2; + float* dst = (float*)_dst; + + { + for (; i <= width - 16; i += 16) { + float32x4_t f = vdupq_n_f32(ky[0]); + + float32x4_t s0, s1, s2, s3; + float32x4_t x0, x1; + S = src[0] + i; + s0 = vld1q_f32(S); + s1 = vld1q_f32(S + 4); + s0 = vmulq_f32(s0, f); + s1 = vmulq_f32(s1, f); + s2 = vld1q_f32(S + 8); + s3 = vld1q_f32(S + 12); + s2 = vmulq_f32(s2, f); + s3 = vmulq_f32(s3, f); + + for (k = 1; k <= ksize2; k++) { + S = src[k] + i; + S2 = src[-k] + i; + float32x4_t f = vdupq_n_f32(ky[k]); + + x0 = vaddq_f32(vld1q_f32(S), vld1q_f32(S2)); + x1 = vaddq_f32(vld1q_f32(S + 4), vld1q_f32(S2 + 4)); + s0 = vmlaq_f32(s0, x0, f); + s1 = vmlaq_f32(s1, x1, f); + x0 = vaddq_f32(vld1q_f32(S + 8), vld1q_f32(S2 + 8)); + x1 = vaddq_f32(vld1q_f32(S + 12), vld1q_f32(S2 + 12)); + s2 = vmlaq_f32(s2, x0, f); + s3 = vmlaq_f32(s3, x1, f); + } + + vst1q_f32(dst + i, s0); + vst1q_f32(dst + i + 4, s1); + vst1q_f32(dst + i + 8, s2); + vst1q_f32(dst + i + 12, s3); + } + + for (; i <= width - 4; i += 4) { + float32x4_t f = vdupq_n_f32(ky[0]); + float32x4_t x0, s0 = vld1q_f32(src[0] + i); + s0 = vmulq_f32(s0, f); + + for (k = 1; k <= ksize2; k++) { + float32x4_t f = vdupq_n_f32(ky[k]); + S = src[k] + i; + S2 = src[-k] + i; + x0 = vaddq_f32(vld1q_f32(S), vld1q_f32(S2)); + s0 = vmlaq_f32(s0, x0, f); + } + vst1q_f32(dst + i, s0); + } + } + + return i; + } + + float* kernel; + int ksize; +}; + +struct SymmColumnSmallVec_32f { + SymmColumnSmallVec_32f() {} + SymmColumnSmallVec_32f(const uchar* _kernel, int _len, int) { + ksize = _len; + kernel = (float*)_kernel; + } + + int operator()(const uchar** _src, uchar* _dst, int& count, + int width) const { + MEGDNN_MARK_USED_VAR(count); + int ksize2 = (ksize) / 2; + const float* ky = (float*)kernel + ksize2; + int i = 0; + const float** src = (const float**)_src; + const float *S0 = src[-1], *S1 = src[0], *S2 = src[1]; + float* dst = (float*)_dst; + { + float32x4_t k0 = vdupq_n_f32(ky[0]), k1 = vdupq_n_f32(ky[1]); + for (; i <= width - 8; i += 8) { + float32x4_t s0, s1, x0, x1; + s0 = vld1q_f32(S1 + i); + s1 = vld1q_f32(S1 + i + 4); + s0 = vmulq_f32(s0, k0); + s1 = vmulq_f32(s1, k0); + x0 = vaddq_f32(vld1q_f32(S0 + i), vld1q_f32(S2 + i)); + x1 = vaddq_f32(vld1q_f32(S0 + i + 4), vld1q_f32(S2 + i + 4)); + s0 
= vmlaq_f32(s0, x0, k1); + s1 = vmlaq_f32(s1, x1, k1); + vst1q_f32(dst + i, s0); + vst1q_f32(dst + i + 4, s1); + } + } + + return i; + } + + float* kernel; + int ksize; +}; + +/*! + * \brief get the column filter + * \tparam FT The inner buffer type, used to store the product of src and filter + * \tparam DT The dst image type. + */ +template +static BaseColumnFilter* getLinearColumnFilter(Mat& kernel, int bits) { + MEGDNN_MARK_USED_VAR(bits); + int ksize = kernel.cols(); + int anchor = ksize / 2; + uchar* kernel_str = static_cast(kernel.raw_ptr()); + if (SYMM && ksize == 3) { + if (std::is_same::value && std::is_same::value) + return new SymmColumnSmallFilter, + SymmColumnVec_32s8u>( + kernel, anchor, FixedPtCastEx(bits), + SymmColumnVec_32s8u(kernel_str, ksize, bits)); + if (std::is_same::value && std::is_same::value) + return new SymmColumnSmallFilter, + SymmColumnSmallVec_32f>( + kernel, anchor, FixedPtCastEx(0), + SymmColumnSmallVec_32f(kernel_str, ksize, 0)); + } + + if (std::is_same::value && std::is_same::value) + return new ColumnFilter, ColumnNoVec>( + kernel, anchor, FixedPtCastEx(bits), + ColumnNoVec(kernel_str, ksize, bits)); + + if (std::is_same::value && std::is_same::value) + return new ColumnFilter, ColumnVec_32f>( + kernel, anchor, FixedPtCastEx(), + ColumnVec_32f(kernel_str, ksize, 0)); + + MegCVException( + "Unsupported combination of source format and buffer format\n"); +} + +/*! + * \brief get the row filter + * \tparam ST The src image type + * \tparam FT The inner buffer type, used to store the product of src and filter + */ +template +static BaseRowFilter* getLinearRowFilter(Mat& kernel) { + int ksize = kernel.cols(); + int anchor = ksize / 2; + + uchar* kernel_str = static_cast(kernel.raw_ptr()); + + if (SYMM && (ksize == 1 || ksize == 3 || ksize == 5)) { + if (std::is_same::value && std::is_same::value) + return new SymmRowSmallFilter( + kernel, anchor, SymmRowSmallVec_8u32s(kernel_str, ksize)); + if (std::is_same::value && std::is_same::value) + return new SymmRowSmallFilter( + kernel, anchor, SymmRowSmallVec_32f(kernel_str, ksize)); + } + + if (std::is_same::value && std::is_same::value) + return new RowFilter(kernel, anchor, + RowNoVec(kernel_str, ksize)); + + if (std::is_same::value && std::is_same::value) + return new RowFilter(kernel, anchor, + RowVec_32f(kernel_str, ksize)); + + MegCVException( + "Unsupported combination of source format and buffer format\n"); +} + +} // namespace sep_filter +} // namespace x86 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/separable_filter/opr_impl.cpp b/dnn/src/arm_common/separable_filter/opr_impl.cpp new file mode 100644 index 00000000..dbf1fc49 --- /dev/null +++ b/dnn/src/arm_common/separable_filter/opr_impl.cpp @@ -0,0 +1,184 @@ +/** + * By downloading, copying, installing or using the software you agree to this license. + * If you do not agree to this license, do not download, install, + * copy or use the software. + * + * + * License Agreement + * For Open Source Computer Vision Library + * (3-clause BSD License) + * + * Copyright (C) 2000-2020, Intel Corporation, all rights reserved. + * Copyright (C) 2009-2011, Willow Garage Inc., all rights reserved. + * Copyright (C) 2009-2016, NVIDIA Corporation, all rights reserved. + * Copyright (C) 2010-2013, Advanced Micro Devices, Inc., all rights reserved. + * Copyright (C) 2015-2016, OpenCV Foundation, all rights reserved. + * Copyright (C) 2015-2016, Itseez Inc., all rights reserved. 
+ * Copyright (C) 2019-2020, Xperience AI, all rights reserved. + * Third party copyrights are property of their respective owners. + * + * Redistribution and use in source and binary forms, with or without modification, + * are permitted provided that the following conditions are met: + * + * * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * * Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * * Neither the names of the copyright holders nor the names of the contributors + * may be used to endorse or promote products derived from this software + * without specific prior written permission. + * + * This software is provided by the copyright holders and contributors "as is" and + * any express or implied warranties, including, but not limited to, the implied + * warranties of merchantability and fitness for a particular purpose are disclaimed. + * In no event shall copyright holders or contributors be liable for any direct, + * indirect, incidental, special, exemplary, or consequential damages + * (including, but not limited to, procurement of substitute goods or services; + * loss of use, data, or profits; or business interruption) however caused + * and on any theory of liability, whether in contract, strict liability, + * or tort (including negligence or otherwise) arising in any way out of + * the use of this software, even if advised of the possibility of such damage. + * + * --------------------------------------------------------------------------- + * \file dnn/src/arm_common/separable_filter/opr_impl.cpp + * + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * + * This file has been modified by Megvii ("Megvii Modifications"). + * All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved. + * + * --------------------------------------------------------------------------- + */ +#include "src/arm_common/separable_filter/opr_impl.h" +#include "src/arm_common/separable_filter/filter.h" +#include "src/arm_common/handle.h" +#include "src/common/cv/common.h" +#include "src/common/cv/helper.h" +#include "src/common/utils.h" +#include + +namespace megdnn { +namespace arm_common { +using namespace megcv; +using namespace sep_filter; +using BorderMode = param::SeparableFilter::BorderMode; + +void SeparableFilterImpl::separable_filter_exec_8u(_megdnn_tensor_in src, + _megdnn_tensor_in filter_x, + _megdnn_tensor_in filter_y, + _megdnn_tensor_out dst) { + megdnn_assert(src.layout.dtype == dtype::Uint8()); + + Mat kernel_column(1, filter_y.layout.shape[3], 1, + static_cast(filter_y.raw_ptr)); + Mat kernel_row(1, filter_x.layout.shape[3], 1, + static_cast(filter_x.raw_ptr)); + + size_t src_channels = src.layout.shape[3]; + + constexpr uint8_t bits = 8; + //! 
Shift, make the elements of the kernel int + Mat kernel_column_int(1, kernel_column.cols(), 1); + Mat kernel_row_int(1, kernel_row.cols(), 1); + for (size_t i = 0; i < kernel_row.cols(); i++) { + kernel_row_int.at(0, i, 0) = + static_cast(kernel_row.at(0, i, 0) * (1 << bits)); + } + for (size_t i = 0; i < kernel_column.cols(); i++) { + kernel_column_int.at(0, i, 0) = + static_cast(kernel_column.at(0, i, 0) * (1 << bits)); + } + + uchar border_value[4] = {0, 0, 0, 0}; + + BaseRowFilter* rowFilter = nullptr; + BaseColumnFilter* columnFilter = nullptr; + if (param().is_symm_kernel) { + rowFilter = getLinearRowFilter(kernel_row_int); + columnFilter = getLinearColumnFilter( + kernel_column_int, bits * 2); + } else { + rowFilter = getLinearRowFilter(kernel_row_int); + columnFilter = getLinearColumnFilter( + kernel_column_int, bits * 2); + } + + FilterEngine filter(rowFilter, columnFilter, src_channels, + border_value, param().borderMode); + + megdnn_assert(param().borderMode != BorderMode::BORDER_ISOLATED); + for (size_t i = 0; i < src.layout.shape[0]; ++i) { + Mat src_mat = TensorND2Mat(src, i); + Mat dst_mat = TensorND2Mat(dst, i); + + filter.apply(src_mat, dst_mat); + } +} + +template +void SeparableFilterImpl::separable_filter_exec(_megdnn_tensor_in src, + _megdnn_tensor_in filter_x, + _megdnn_tensor_in filter_y, + _megdnn_tensor_out dst) { + Mat kernel_column(1, filter_y.layout.shape[3], 1, + static_cast(filter_y.raw_ptr)); + Mat kernel_row(1, filter_x.layout.shape[3], 1, + static_cast(filter_x.raw_ptr)); + size_t src_channels = src.layout.shape[3]; + + T border_value[4] = {0, 0, 0, 0}; + + BaseRowFilter* row_filter = nullptr; + BaseColumnFilter* column_filter = nullptr; + if (param().is_symm_kernel) { + row_filter = getLinearRowFilter(kernel_row); + column_filter = + getLinearColumnFilter(kernel_column, (int)0); + } else { + row_filter = getLinearRowFilter(kernel_row); + column_filter = + getLinearColumnFilter(kernel_column, (int)0); + } + + FilterEngine filter(row_filter, column_filter, src_channels, + border_value, param().borderMode); + + megdnn_assert(param().borderMode != BorderMode::BORDER_ISOLATED); + for (size_t i = 0; i < src.layout.shape[0]; ++i) { + Mat src_mat = TensorND2Mat(src, i); + Mat dst_mat = TensorND2Mat(dst, i); + filter.apply(src_mat, dst_mat); + } +} + +void SeparableFilterImpl::exec(_megdnn_tensor_in src, + _megdnn_tensor_in filter_x, + _megdnn_tensor_in filter_y, + _megdnn_tensor_out dst, + _megdnn_workspace workspace) { + check_exec(src.layout, filter_x.layout, filter_y.layout, dst.layout, + workspace.size); + if (dst.layout.dtype == dtype::Float32()) { + MEGDNN_DISPATCH_CPU_KERN_OPR( + separable_filter_exec(src, filter_x, filter_y, dst)); + } else if (dst.layout.dtype == dtype::Uint8()) { + MEGDNN_DISPATCH_CPU_KERN_OPR( + separable_filter_exec_8u(src, filter_x, filter_y, dst)); + } else { + megdnn_throw("Unsupported datatype of SeparableFilter opr."); + }; +} + +} // namespace x86 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/separable_filter/opr_impl.h b/dnn/src/arm_common/separable_filter/opr_impl.h new file mode 100644 index 00000000..b7ef816b --- /dev/null +++ b/dnn/src/arm_common/separable_filter/opr_impl.h @@ -0,0 +1,42 @@ +/** + * \file dnn/src/arm_common/separable_filter/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once +#include "megdnn/oprs.h" +namespace megdnn { +namespace arm_common { +class SeparableFilterImpl : public SeparableFilterForward { +public: + using SeparableFilterForward::SeparableFilterForward; + void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter_x, + _megdnn_tensor_in filter_y, _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; + + size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, + const TensorLayout&, + const TensorLayout&) override { + return 0; + } + +private: + template + void separable_filter_exec(_megdnn_tensor_in src, + _megdnn_tensor_in filter_x, + _megdnn_tensor_in filter_y, + _megdnn_tensor_out dst); + void separable_filter_exec_8u(_megdnn_tensor_in src, + _megdnn_tensor_in filter_x, + _megdnn_tensor_in filter_y, + _megdnn_tensor_out dst); +}; + +} // namespace arm_common +} // namespace megdnn +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/simd_macro/marm_neon.cpp b/dnn/src/arm_common/simd_macro/marm_neon.cpp new file mode 100644 index 00000000..97218b0f --- /dev/null +++ b/dnn/src/arm_common/simd_macro/marm_neon.cpp @@ -0,0 +1,16 @@ +/** + * \file dnn/src/arm_common/simd_macro/marm_neon.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "src/arm_common/simd_macro/marm_neon.h" + +#pragma message \ + "remove these functions defined in march_neon.h when these functions defined in the future compiler(arm_neon.h)" + diff --git a/dnn/src/arm_common/simd_macro/marm_neon.h b/dnn/src/arm_common/simd_macro/marm_neon.h new file mode 100644 index 00000000..f0dc8458 --- /dev/null +++ b/dnn/src/arm_common/simd_macro/marm_neon.h @@ -0,0 +1,527 @@ +/** + * \file dnn/src/arm_common/simd_macro/marm_neon.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#pragma once + +#include +#include "megdnn/arch.h" +#include "src/common/unroll_macro.h" + +// GCC does not support __nodebug__, it reports: +// '__nodebug__' attribute directive ignored +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wpragmas" +#pragma GCC diagnostic ignored "-Wattributes" +#define __ai static inline __attribute__((__always_inline__, __nodebug__)) + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC && !MEGDNN_DISABLE_FLOAT16 +#define MEGDNN_INC_ARM_FP16(_x) _x +#else +#define MEGDNN_INC_ARM_FP16(_x) +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +//! 
copy from arm_neon, as in clang7.0 these function not exists +#ifdef __LITTLE_ENDIAN__ +__ai float16x8_t vmlaq_f16(float16x8_t __p0, float16x8_t __p1, + float16x8_t __p2) { + float16x8_t __ret; + __ret = __p0 + __p1 * __p2; + return __ret; +} +#else +__ai float16x8_t vmlaq_f16(float16x8_t __p0, float16x8_t __p1, + float16x8_t __p2) { + float16x8_t __rev0; + __rev0 = __builtin_shufflevector(__p0, __p0, 7, 6, 5, 4, 3, 2, 1, 0); + float16x8_t __rev1; + __rev1 = __builtin_shufflevector(__p1, __p1, 7, 6, 5, 4, 3, 2, 1, 0); + float16x8_t __rev2; + __rev2 = __builtin_shufflevector(__p2, __p2, 7, 6, 5, 4, 3, 2, 1, 0); + float16x8_t __ret; + __ret = __rev0 + __rev1 * __rev2; + __ret = __builtin_shufflevector(__ret, __ret, 7, 6, 5, 4, 3, 2, 1, 0); + return __ret; +} +#endif + +#ifdef __LITTLE_ENDIAN__ +#define vmlaq_lane_f16(__p0, __p1, __p2, __p3) \ + __extension__({ \ + float16x8_t __s0 = __p0; \ + float16x8_t __s1 = __p1; \ + float16x4_t __s2 = __p2; \ + float16x8_t __ret; \ + __ret = __s0 + __s1 * __builtin_shufflevector(__s2, __s2, __p3, __p3, \ + __p3, __p3, __p3, __p3, \ + __p3, __p3); \ + __ret; \ + }) +#else +#define vmlaq_lane_f16(__p0, __p1, __p2, __p3) \ + __extension__({ \ + float16x8_t __s0 = __p0; \ + float16x8_t __s1 = __p1; \ + float16x4_t __s2 = __p2; \ + float16x8_t __rev0; \ + __rev0 = __builtin_shufflevector(__s0, __s0, 7, 6, 5, 4, 3, 2, 1, 0); \ + float16x8_t __rev1; \ + __rev1 = __builtin_shufflevector(__s1, __s1, 7, 6, 5, 4, 3, 2, 1, 0); \ + float16x4_t __rev2; \ + __rev2 = __builtin_shufflevector(__s2, __s2, 3, 2, 1, 0); \ + float16x8_t __ret; \ + __ret = __rev0 + __rev1 * __builtin_shufflevector( \ + __rev2, __rev2, __p3, __p3, __p3, \ + __p3, __p3, __p3, __p3, __p3); \ + __ret = __builtin_shufflevector(__ret, __ret, 7, 6, 5, 4, 3, 2, 1, 0); \ + __ret; \ + }) +#endif + +#if 0 +//! As in arm_neon.h, `vdupq_n_f16` is macro, may be different with +//! `vdupq_n_f32`, So here just undefine the macro, and declare a function to +//! implement just as `vdupq_n_f32`. 
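Kernels built on top of this header can lean on the vmlaq_f16 fallback defined above whenever the toolchain's arm_neon.h does not provide it. A minimal usage sketch, with a hypothetical caller (fp16_fma_inplace) introduced only for illustration:

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#include <arm_neon.h>

// y[i] += a[i] * b[i] over fp16 data, 8 lanes at a time, using the
// vmlaq_f16 fallback above (or the toolchain's own intrinsic if present).
static inline void fp16_fma_inplace(__fp16* y, const __fp16* a, const __fp16* b,
                                    int n) {
    int i = 0;
    for (; i + 8 <= n; i += 8) {
        float16x8_t vy = vld1q_f16(y + i);
        vy = vmlaq_f16(vy, vld1q_f16(a + i), vld1q_f16(b + i));
        vst1q_f16(y + i, vy);
    }
    for (; i < n; ++i)
        y[i] += a[i] * b[i];  // scalar tail
}
#endif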
+#undef vdupq_n_f16 +#ifdef __LITTLE_ENDIAN__ +__ai float16x8_t vdupq_n_f16(float16_t __p0) { + float16x8_t __ret; + __ret = (float16x8_t){__p0, __p0, __p0, __p0, __p0, __p0, __p0, __p0}; + return __ret; +} + +#else +__ai float16x8_t vdupq_n_f16(float16_t __p0) { + float16x8_t __ret; + __ret = (float16x8_t){__p0, __p0, __p0, __p0, __p0, __p0, __p0, __p0}; + __ret = __builtin_shufflevector(__ret, __ret, 7, 6, 5, 4, 3, 2, 1, 0); + return __ret; +} +#endif +#endif + +#ifdef __LITTLE_ENDIAN__ +#define vmlaq_laneq_f16(__p0, __p1, __p2, __p3) \ + __extension__({ \ + float16x8_t __s0 = __p0; \ + float16x8_t __s1 = __p1; \ + float16x8_t __s2 = __p2; \ + float16x8_t __ret; \ + __ret = __s0 + __s1 * __builtin_shufflevector(__s2, __s2, __p3, __p3, \ + __p3, __p3, __p3, __p3, \ + __p3, __p3); \ + __ret; \ + }) +#else +#define vmlaq_laneq_f16(__p0, __p1, __p2, __p3) \ + __extension__({ \ + float16x8_t __s0 = __p0; \ + float16x8_t __s1 = __p1; \ + float16x8_t __s2 = __p2; \ + float16x8_t __rev0; \ + __rev0 = __builtin_shufflevector(__s0, __s0, 7, 6, 5, 4, 3, 2, 1, 0); \ + float16x8_t __rev1; \ + __rev1 = __builtin_shufflevector(__s1, __s1, 7, 6, 5, 4, 3, 2, 1, 0); \ + float16x8_t __rev2; \ + __rev2 = __builtin_shufflevector(__s2, __s2, 7, 6, 5, 4, 3, 2, 1, 0); \ + float16x8_t __ret; \ + __ret = __rev0 + __rev1 * __builtin_shufflevector( \ + __rev2, __rev2, __p3, __p3, __p3, \ + __p3, __p3, __p3, __p3, __p3); \ + __ret = __builtin_shufflevector(__ret, __ret, 7, 6, 5, 4, 3, 2, 1, 0); \ + __ret; \ + }) +#endif + +#if MEGDNN_ARMV7 +#define vmlaq_low_lane_f16(__a, __b, __v, __lane) \ + __extension__({ \ + auto c = vget_low_f16(__v); \ + auto __ret = vmlaq_lane_f16(__a, __b, c, __lane); \ + __ret; \ + }) + +#define vmlaq_high_lane_f16(__a, __b, __v, __lane) \ + __extension__({ \ + auto c = vget_high_f16(__v); \ + auto __ret = vmlaq_lane_f16(__a, __b, c, (__lane - 4)); \ + __ret; \ + }) + +//! FIXME: remove these funtion once llvm fix such bugs +//! As origin implentation in \c arm_neon.h may cause +//! \attention {error in backend: Do not know how to split this operator's +//! operand!} +/////////////////////////////////////////////////////////////////////// +__ai float16x8_t vmulq_fix_f16(float16x8_t a, float16x8_t b) { + float16x8_t ret; + asm volatile("vmul.f16 %0, %1, %2\n" : "+w"(ret) : "w"(a), "w"(b)); + return ret; +} + +__ai float16x8_t vmulq_n_fix_f16(float16x8_t a, __fp16 b) { + float16x8_t ret; + asm volatile( + "vdup.16 q0, %2 \n" + "vmul.f16 %0, %1, q0\n" + : "+w"(ret) + : "w"(a), "r"(b) + : "q0"); + return ret; +} + +__ai float16x4_t vmul_n_fix_f16(float16x4_t a, __fp16 b) { + float16x4_t ret; + asm volatile( + "vdup.16 d0,%2\n" + "vmul.f16 %0, %1, d0[0]\n" + : "+w"(ret) + : "w"(a), "r"(b) + : "d0"); + return ret; +} +__ai float16x8_t vmlaq_fix_f16(float16x8_t a, float16x8_t b, float16x8_t c) { + asm volatile("vmla.f16 %0, %1, %2\n" : "+w"(a) : "w"(b), "w"(c)); + return a; +} + +__ai float16x8_t vaddq_fix_f16(float16x8_t a, float16x8_t b) { + float16x8_t ret; + asm volatile("vadd.f16 %0, %1, %2\n" : "+w"(ret) : "w"(a), "w"(b)); + return ret; +} + +#undef vdupq_n_f16 +__ai float16x8_t vdupq_n_f16(__fp16 a) { + float16x8_t ret; + asm volatile("vdup.16 %0, %1\n" : "+w"(ret) : "r"(a) :); + return ret; +} + +/////////////////////////////////////////////////////////////////////// + +#elif MEGDNN_AARCH64 +#define vmlaq_low_lane_f16(__a, __b, __v, __lane) \ + vmlaq_laneq_f16(__a, __b, __v, __lane) + +#define vmlaq_high_lane_f16(__a, __b, __v, __lane) \ + vmlaq_laneq_f16(__a, __b, __v, __lane) + +//! 
FIXME: remove these funtion once llvm fix such bugs +//! As origin implentation in \c arm_neon.h may cause +//! \attention {error in backend: Do not know how to split this operator's +//! operand!} +/////////////////////////////////////////////////////////////////////// + +__ai float16x8_t vmulq_fix_f16(float16x8_t a, float16x8_t b) { + return vmulq_f16(a, b); +} + +__ai float16x8_t vmlaq_fix_f16(float16x8_t a, float16x8_t b, float16x8_t c) { + return vmlaq_f16(a, b, c); +} + +__ai float16x8_t vaddq_fix_f16(float16x8_t a, float16x8_t b) { + return vaddq_f16(a, b); +} + +#undef vdupq_n_f16 +__ai float16x8_t vdupq_n_f16(__fp16 a) { + float16x8_t ret; + asm volatile("dup %0.8h, %w1\n" : "+w"(ret) : "r"(a) :); + return ret; +} + +/////////////////////////////////////////////////////////////////////// + +#endif + +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +#if __ARM_FEATURE_DOTPROD + +__ai int32x4_t vdotq2_s32(int8x16_t a, int8x16_t b) { + int32x4_t c = vdupq_n_s32(0); + return vdotq_s32(c, a, b); +} + +__ai uint32x4_t vdotq2_u32(uint8x16_t a, uint8x16_t b) { + uint32x4_t c = vdupq_n_u32(0); + return vdotq_u32(c, a, b); +} + +#define vdotq2_lane_s32(a, b, lane) \ + __extension__({ \ + int32x4_t c = vdupq_n_s32(0); \ + c = vdotq_lane_s32(c, a, b, lane); \ + c; \ + }) + +#define vdotq2_lane_u32(a, b, lane) \ + __extension__({ \ + uint32x4_t c = vdupq_n_u32(0); \ + c = vdotq_lane_u32(c, a, b, lane); \ + c; \ + }) + +__ai int32x2_t vdot2_s32(int8x8_t a, int8x8_t b) { + int32x2_t c = vdup_n_s32(0); + return vdot_s32(c, a, b); +} + +__ai uint32x2_t vdot2_u8(uint8x8_t a, uint8x8_t b) { + uint32x2_t c = vdup_n_u32(0); + return vdot_u32(c, a, b); +} + +#define vdot2_lane_s32(a, b, lane) \ + __extension__({ \ + int32x2_t c = vdup_n_s32(0); \ + c = vdot_lane_s32(c, a, b, lane); \ + c; \ + }) + +#define vdot2_lane_u8(a, b, lane) \ + __extension__({ \ + uint32x2_t c = vdup_n_u32(0); \ + c = vdot_lane_u32(c, a, b, lane); \ + c; \ + }) + +#endif // __ARM_FEATURE_DOTPROD + +#undef vld1q_f32_x2 +__ai float32x4x2_t vld1q_f32_x2(const float* p) { + return {{vld1q_f32(p), vld1q_f32(p + 4)}}; +} + +#undef vst1q_f32_x2 +__ai void vst1q_f32_x2(const float* p, float32x4x2_t v) { + vst1q_f32(const_cast(p), v.val[0]); + vst1q_f32(const_cast(p) + 4, v.val[1]); +} + +__ai int8x16_t vtranslq_s8(int8x8_t a) { + int8x16_t ret; +#if MEGDNN_AARCH64 + asm volatile("ins %0.d[0], %1.d[0]\n" : "+w"(ret) : "w"(a) :); +#else + asm volatile("vmov %e0, %P1\n" : "+w"(ret) : "w"(a) :); +#endif + return ret; +} + +__ai uint8x16_t vtranslq_u8(uint8x8_t a) { + uint8x16_t ret; +#if MEGDNN_AARCH64 + asm volatile("ins %0.d[0], %1.d[0]\n" : "+w"(ret) : "w"(a) :); +#else + asm volatile("vmov %e0, %P1\n" : "+w"(ret) : "w"(a) :); +#endif + return ret; +} + +#ifdef MEGDNN_TEGRA_X1 +#define vset_lane_s16_fix_tx1(__elem, __vec, __index) \ + { \ + asm volatile("ins %0.h[" #__index "], %w1\n" \ + : "+w"(__vec) \ + : "r"(__elem) \ + :); \ + } +#else +#define vset_lane_s16_fix_tx1(__elem, __vec, __index) \ + __vec = vset_lane_s16(__elem, __vec, __index) +#endif + +#if MEGDNN_ARMV7 +__ai int32_t vaddlvq_s16(int16x8_t __p0) { + int32_t __ret = 0; + auto sum = vpaddlq_s16(__p0); + __ret += (vgetq_lane_s32(sum, 0) + vgetq_lane_s32(sum, 1) + + vgetq_lane_s32(sum, 2) + vgetq_lane_s32(sum, 3)); + return __ret; +} + +__ai int16x8_t vmlal_high_s8(int16x8_t __p0, int8x16_t __p1, int8x16_t __p2) { + int16x8_t __ret; + __ret = vmlal_s8(__p0, vget_high_s8(__p1), vget_high_s8(__p2)); + return __ret; +} + +__ai int16x8_t vmull_high_s8(int8x16_t __p0, int8x16_t 
__p1) { + int16x8_t __ret; + __ret = vmull_s8(vget_high_s8(__p0), vget_high_s8(__p1)); + return __ret; +} + +//! armv7 : vmovl_xx(vget_high_xx()), armv8 : vmovl_high_xx() +__ai int16x8_t vmovl_high_s8(int8x16_t __p0) { + return vmovl_s8(vget_high_s8(__p0)); +} + +__ai uint16x8_t vmovl_high_u8(uint8x16_t __p0) { + return vmovl_u8(vget_high_u8(__p0)); +} + +__ai int32x4_t vmovl_high_s16(int16x8_t __p0) { + return vmovl_s16(vget_high_s16(__p0)); +} + +__ai uint32x4_t vmovl_high_u16(uint16x8_t __p0) { + return vmovl_u16(vget_high_u16(__p0)); +} + +__ai int64x2_t vmovl_high_s32(int32x4_t __p0) { + return vmovl_s32(vget_high_s32(__p0)); +} + +__ai uint64x2_t vmovl_high_u32(uint32x4_t __p0) { + return vmovl_u32(vget_high_u32(__p0)); +} +#endif // MEGDNN_ARMV7 + +//! pack vmovl_low_xx() on armv7 and armv8 +__ai int16x8_t vmovl_low_s8(int8x16_t __p0) { + return vmovl_s8(vget_low_s8(__p0)); +} + +__ai uint16x8_t vmovl_low_u8(uint8x16_t __p0) { + return vmovl_u8(vget_low_u8(__p0)); +} + +__ai int32x4_t vmovl_low_s16(int16x8_t __p0) { + return vmovl_s16(vget_low_s16(__p0)); +} + +__ai uint32x4_t vmovl_low_u16(uint16x8_t __p0) { + return vmovl_u16(vget_low_u16(__p0)); +} + +__ai int64x2_t vmovl_low_s32(int32x4_t __p0) { + return vmovl_s32(vget_low_s32(__p0)); +} + +__ai uint64x2_t vmovl_low_u32(uint32x4_t __p0) { + return vmovl_u32(vget_low_u32(__p0)); +} + +#if MEGDNN_ARMV7 +#define vmlaq_low_lane_f32(__a, __b, __v, __lane) \ + __extension__({ \ + auto c = vget_low_f32(__v); \ + auto __ret = vmlaq_lane_f32(__a, __b, c, __lane); \ + __ret; \ + }) + +#define vmlaq_high_lane_f32(__a, __b, __v, __lane) \ + __extension__({ \ + auto c = vget_high_f32(__v); \ + auto __ret = vmlaq_lane_f32(__a, __b, c, (__lane - 2)); \ + __ret; \ + }) + +#elif MEGDNN_AARCH64 +__ai float64x2_t vbitq_f64(float64x2_t dst, float64x2_t v1, uint64x2_t mask) { + asm volatile("bit %0.16b, %1.16b, %2.16b\n" + : "+w"(dst) + : "w"(v1), "w"(mask) + :); + return dst; +} + +#define vmlaq_low_lane_f32(__a, __b, __v, __lane) \ + vmlaq_laneq_f32(__a, __b, __v, __lane) + +#define vmlaq_high_lane_f32(__a, __b, __v, __lane) \ + vmlaq_laneq_f32(__a, __b, __v, __lane) + +#endif + +#if MEGDNN_ARMV7 +__ai int8x16_t vqtbl1q_s8(int8x16_t& a, uint8x16_t& idx) { + int8x8_t src_low = vget_low_s8(a); + int8x8_t src_high = vget_high_s8(a); + return vcombine_s8(vtbl2_s8({src_low, src_high}, + vget_low_s8(vreinterpretq_s8_u8(idx))), + vtbl2_s8({src_low, src_high}, + vget_high_s8(vreinterpretq_s8_u8(idx)))); +} +namespace { +template +struct Vdup_laneq_s16_armv7 { + static int16x4_t impl(int16x8_t vec); +}; +#define cb(step) \ + template <> \ + struct Vdup_laneq_s16_armv7 { \ + static int16x4_t impl(int16x8_t vec) { \ + return vdup_lane_s16(vget_high_s16(vec), step); \ + } \ + }; \ + template <> \ + struct Vdup_laneq_s16_armv7 { \ + static int16x4_t impl(int16x8_t vec) { \ + return vdup_lane_s16(vget_low_s16(vec), step); \ + } \ + }; + +UNROLL_CALL_RAW(4, cb); +#undef cb +} // namespace +#define vdup_laneq_s16(vec, lane) Vdup_laneq_s16_armv7::impl(vec) + +#endif + +__ai int8x16_t vld_dup_tbl_s32(const int8_t* ptr, uint8x16_t& idx) { + int8x16_t result = vreinterpretq_s8_s32(vld1q_dup_s32((const int32_t*)ptr)); + result = vqtbl1q_s8(result, idx); + return result; +} +__ai int8x16_t vldq_tbl_s8(const int8_t* ptr, uint8x16_t& idx) { + int8x16_t result = vld1q_s8(ptr); + result = vqtbl1q_s8(result, idx); + return result; +} +__ai int32x4_t vdotq_s32_h(int8x16_t& a, int8x16_t& b, int32x4_t& c, + int16x8_t& temp) { + temp = vmull_s8(vget_low_s8(a), vget_low_s8(b)); 
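The vmull_s8 above, together with the vmlal_high_s8 and vpadalq_s16 steps that complete vdotq_s32_h just below, folds four int8 products into every 32-bit lane of c. The lane grouping differs from the armv8.2 vdotq_s32 instruction, so callers are expected to lay out their data to match. A scalar reference of the accumulation, with a hypothetical helper name for illustration:

#include <cstdint>

// Scalar model of what vdotq_s32_h accumulates (a, b: 16 x int8; c: 4 x int32).
// Note the grouping (2j, 2j+1, 2j+8, 2j+9) is not the 4j .. 4j+3 grouping of
// the hardware vdotq_s32, so data must be interleaved accordingly.
static inline void vdotq_s32_h_ref(const int8_t a[16], const int8_t b[16],
                                   int32_t c[4]) {
    for (int j = 0; j < 4; ++j)
        c[j] += (int32_t)a[2 * j] * b[2 * j] +
                (int32_t)a[2 * j + 1] * b[2 * j + 1] +
                (int32_t)a[2 * j + 8] * b[2 * j + 8] +
                (int32_t)a[2 * j + 9] * b[2 * j + 9];
}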
+ temp = vmlal_high_s8(temp, a, b); + c = vpadalq_s16(c, temp); + return c; +} +__ai int32x4_t vdot2_s32_h(int8x8_t& a, int8x8_t& b, int32x4_t& c, + int16x8_t& temp) { + temp = vmull_s8(a, b); + c = vpadalq_s16(c, temp); + return c; +} + +__ai int32x4_t vmlal_s16(int32x4_t& a, int16x8_t& b, int16x8_t& c) { + return vmlal_s16(a, vget_low_s16(b), vget_low_s16(c)); +} + +__ai int16x8_t vldq_dup_4s8_8s16(const int8_t* ptr) { + return vmovl_s8(vreinterpret_s8_s32( + vld1_dup_s32(reinterpret_cast(ptr)))); +} +__ai int8x8_t vldq_tbl_low_s8(const int8_t* ptr, uint8x16_t idx) { + return vget_low_s8(vldq_tbl_s8(ptr, idx)); +} +__ai int16x8_t vld1_dup_s8_s16(const int8_t* ptr) { + return vmovl_s8(vld1_dup_s8(ptr)); +} + +#undef __ai +#pragma GCC diagnostic pop + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/simd_macro/neon_helper.h b/dnn/src/arm_common/simd_macro/neon_helper.h new file mode 100644 index 00000000..3f0cb92b --- /dev/null +++ b/dnn/src/arm_common/simd_macro/neon_helper.h @@ -0,0 +1,55 @@ +/** + * \file dnn/src/arm_common/simd_macro/neon_helper.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "src/arm_common/simd_macro/marm_neon.h" + +#define MEGDNN_SIMD_NAME NEON +#define MEGDNN_SIMD_TARGET neon +#define MEGDNN_SIMD_ATTRIBUTE_TARGET +#define MEGDNN_SIMD_LAMBDA_ATTRIBUTE_TARGET +#define MEGDNN_SIMD_WIDTH 4 +#define MEGDNN_SIMD_TYPE float32x4_t +#define MEGDNN_SIMD_TYPE2 float32x4x2_t +#define MEGDNN_SIMD_LOADU(addr) vld1q_f32(addr) +#define MEGDNN_SIMD_STOREU(addr, reg) vst1q_f32(addr, reg) +#define MEGDNN_SIMD_SETZERO() vdupq_n_f32(0.0f) +#define MEGDNN_SIMD_SET1(num) vdupq_n_f32(num) +// XXX The order of a, b, c +#define MEGDNN_SIMD_FMADD(a, b, c) vmlaq_f32(c, a, b) +#define MEGDNN_SIMD_MAX(a, b) vmaxq_f32(a, b) +#define MEGDNN_SIMD_UZP(s0, s1, d0, d1) do { \ + auto tmp__ = vuzpq_f32(s0, s1); \ + d0 = tmp__.val[0]; \ + d1 = tmp__.val[1]; \ +} while (0) +#define MEGDNN_SIMD_LOAD2(addr) vld2q_f32(addr) +#define MEGDNN_SIMD_EXT(a, b, c) vextq_f32(a, b, c) +#define MEGDNN_SIMD_MUL(a, b) vmulq_f32(a, b) +#define MEGDNN_SIMD_ADD(a, b) vaddq_f32(a, b) +#define MEGDNN_SIMD_SET_LANE(a, b, c) vsetq_lane_f32(a, b, c) +#define MEGDNN_SIMD_GET_LOW(a) vget_low_f32(a) +#define MEGDNN_SIMD_GET_HIGH(a) vget_high_f32(a) +#define MEGDNN_SIMD_VMLAQ_LANE(a, b, c, d) vmlaq_lane_f32(a, b, c, d) +#if MEGDNN_ARMV7 +#define MEGDNN_SIMD_FMA_LANE(a, b, c, d) ({ \ + auto ret__ = vdupq_n_f32(vgetq_lane_f32(c, d)); \ + ret__ = vmlaq_f32(a, b, ret__); \ + ret__;}) +#define MEGDNN_SIMD_ADD_VEC(a) ({ \ + auto tmp__ = vadd_f32(vget_low_f32(a), vget_high_f32(a)); \ + tmp__ = vpadd_f32(tmp__, tmp__); \ + auto ret__ = vget_lane_f32(tmp__, 0); \ + ret__;}) +#else +// MEGDNN_AARCH64 +#define MEGDNN_SIMD_FMA_LANE(a, b, c, d) vfmaq_laneq_f32(a, b, c, d) +#define MEGDNN_SIMD_ADD_VEC(a) vaddvq_f32(a) +#endif + diff --git a/dnn/src/arm_common/simd_macro/neon_helper_epilogue.h b/dnn/src/arm_common/simd_macro/neon_helper_epilogue.h new file mode 100644 index 00000000..7618ce0c --- /dev/null +++ b/dnn/src/arm_common/simd_macro/neon_helper_epilogue.h @@ -0,0 +1,11 @@ +/** + * \file dnn/src/arm_common/simd_macro/neon_helper_epilogue.h + * MegEngine is Licensed under the Apache License, 
Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "src/common/simd_macro/epilogue.h" diff --git a/dnn/src/arm_common/simd_macro/neon_helper_fp16.h b/dnn/src/arm_common/simd_macro/neon_helper_fp16.h new file mode 100644 index 00000000..e581b74b --- /dev/null +++ b/dnn/src/arm_common/simd_macro/neon_helper_fp16.h @@ -0,0 +1,25 @@ +/** + * \file dnn/src/arm_common/simd_macro/neon_helper_fp16.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "src/arm_common/simd_macro/marm_neon.h" + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#define MEGDNN_SIMD_NAME NEON +#define MEGDNN_SIMD_TARGET neon +#define MEGDNN_SIMD_ATTRIBUTE_TARGET +#define MEGDNN_SIMD_WIDTH 4 +#define MEGDNN_SIMD_TYPE float16x8_t +#define MEGDNN_SIMD_TYPE2 float16x8x2_t +#define MEGDNN_SIMD_LOADU(addr) vld1q_f16(addr) +#define MEGDNN_SIMD_STOREU(addr, reg) vst1q_f16(addr, reg) +#define MEGDNN_SIMD_SETZERO() vdupq_n_f16(0.0f) +#define MEGDNN_SIMD_SET1(num) vdupq_n_f16(num) + +#endif diff --git a/dnn/src/arm_common/type_cvt/opr_impl.cpp b/dnn/src/arm_common/type_cvt/opr_impl.cpp new file mode 100644 index 00000000..13545e3b --- /dev/null +++ b/dnn/src/arm_common/type_cvt/opr_impl.cpp @@ -0,0 +1,413 @@ +/** + * \file dnn/src/arm_common/type_cvt/opr_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
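The quantized TypeCvt kernels that follow all reduce to one scalar requantization rule: scale by s_src / s_dst, round, add the destination zero point for asymmetric uint8, then saturate to the destination range, which is exactly what the cvt_remain tail handlers do element by element. A minimal scalar sketch of that rule, using hypothetical helper names and the S32-to-S8 and S32-to-asymmetric-U8 cases:

#include <algorithm>
#include <cmath>
#include <cstdint>

// q_dst = saturate(round(q_src * s_src / s_dst) + zp_dst)
static inline int8_t requant_s32_to_s8(int32_t q, float s_src, float s_dst) {
    float v = std::round(q * (s_src / s_dst));
    return (int8_t)std::min(127.f, std::max(-128.f, v));
}

static inline uint8_t requant_s32_to_u8(int32_t q, float s_src, float s_dst,
                                        uint8_t zp_dst) {
    float v = std::round(q * (s_src / s_dst)) + (float)zp_dst;
    return (uint8_t)std::min(255.f, std::max(0.f, v));
}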
+ */ +#include "src/arm_common/type_cvt/opr_impl.h" + +#include +#include "midout.h" +#include "src/arm_common/quantized_converter.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/utils.h" +#include "src/naive/handle.h" + +MIDOUT_DECL(megdnn_arm_typecvt_quantized) +MIDOUT_DECL(megdnn_arm_typecvt_float) + +using namespace megdnn; +using namespace arm_common; + +namespace { + +template +struct QuantizedTypeCvter; + +template <> +struct QuantizedTypeCvter { + using stype = int32_t; + using dst_type = int8_t; + static constexpr size_t SIMD_WIDTH = 8; + float scale; + float32x4_t vscale; + + QuantizedTypeCvter(DType src_dtype, DType dst_dtype) { + float src_scale = src_dtype.param().scale; + float dst_scale = dst_dtype.param().scale; + scale = src_scale / dst_scale; + vscale = vdupq_n_f32(scale); + } + + void cvt(const int32_t* src, int8_t* dst) { + float32x4_t vitem0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(src)), vscale); + float32x4_t vitem1 = + vmulq_f32(vcvtq_f32_s32(vld1q_s32(src + 4)), vscale); + + auto vres = QConverter::convert( + {{vitem0, vitem1}}); + vst1_s8(dst, vres); + } + + void cvt_remain(const int32_t* src, int8_t* dst) { + *dst = saturate(std::round(*src * scale), -128, 127); + } +}; + +template <> +struct QuantizedTypeCvter { + using stype = int8_t; + using dst_type = int32_t; + static constexpr size_t SIMD_WIDTH = 8; + float scale; + float32x4_t vscale; + + QuantizedTypeCvter(DType src_dtype, DType dst_dtype) { + float src_scale = src_dtype.param().scale; + float dst_scale = dst_dtype.param().scale; + scale = src_scale / dst_scale; + vscale = vdupq_n_f32(scale); + } + + void cvt(const int8_t* src, int32_t* dst) { + int16x8_t vitem = vmovl_s8(vld1_s8(src)); + auto vret0 = QConverter::convert(vmulq_f32( + vcvtq_f32_s32(vmovl_s16(vget_low_s16(vitem))), vscale)); + auto vret1 = QConverter::convert(vmulq_f32( + vcvtq_f32_s32(vmovl_s16(vget_high_s16(vitem))), vscale)); + vst1q_s32(dst, vret0); + vst1q_s32(dst + 4, vret1); + } + + void cvt_remain(const int8_t* src, int32_t* dst) { + *dst = saturate(std::round(*src * scale), -2147483648, + 2147483647); + } +}; + +template <> +struct QuantizedTypeCvter { + using stype = float; + using dst_type = int8_t; + static constexpr size_t SIMD_WIDTH = 8; + float scale; + float32x4_t vscale; + + QuantizedTypeCvter(DType src_dtype, DType dst_dtype) { + MEGDNN_MARK_USED_VAR(src_dtype); + float src_scale = 1; + float dst_scale = dst_dtype.param().scale; + scale = src_scale / dst_scale; + vscale = vdupq_n_f32(scale); + } + + void cvt(const float* src, int8_t* dst) { + float32x4_t vitem0 = vmulq_f32(vld1q_f32(src), vscale); + float32x4_t vitem1 = vmulq_f32(vld1q_f32(src + 4), vscale); + + auto vres = QConverter::convert( + {{vitem0, vitem1}}); + vst1_s8(dst, vres); + } + + void cvt_remain(const float* src, int8_t* dst) { + *dst = saturate(std::round(*src * scale), -128, 127); + } +}; + +template <> +struct QuantizedTypeCvter { + using stype = int32_t; + using dst_type = int32_t; + static constexpr size_t SIMD_WIDTH = 4; + float scale; + float32x4_t vscale; + + QuantizedTypeCvter(DType src_dtype, DType dst_dtype) { + float src_scale = src_dtype.param().scale; + float dst_scale = dst_dtype.param().scale; + scale = src_scale / dst_scale; + vscale = vdupq_n_f32(scale); + } + + void cvt(const int32_t* src, int32_t* dst) { + float32x4_t vitem = vmulq_f32(vcvtq_f32_s32(vld1q_s32(src)), vscale); + + auto vres = QConverter::convert(vitem); + vst1q_s32(dst, vres); + } + + void cvt_remain(const int32_t* src, int32_t* dst) { + *dst = 
saturate(std::round(*src * scale), -2147483648, + 2147483647); + } +}; + +template <> +struct QuantizedTypeCvter { + using stype = int8_t; + using dst_type = int8_t; + static constexpr size_t SIMD_WIDTH = 8; + float scale; + float32x4_t vscale; + + QuantizedTypeCvter(DType src_dtype, DType dst_dtype) { + float src_scale = src_dtype.param().scale; + float dst_scale = dst_dtype.param().scale; + scale = src_scale / dst_scale; + vscale = vdupq_n_f32(scale); + } + + void cvt(const int8_t* src, int8_t* dst) { + int16x8_t vdata = vmovl_s8(vld1_s8(src)); + float32x4_t vitem0 = vmulq_f32( + vcvtq_f32_s32(vmovl_s16(vget_low_s16(vdata))), vscale); + float32x4_t vitem1 = vmulq_f32( + vcvtq_f32_s32(vmovl_s16(vget_high_s16(vdata))), vscale); + + auto vres = QConverter::convert( + {{vitem0, vitem1}}); + vst1_s8(dst, vres); + } + + void cvt_remain(const int8_t* src, int8_t* dst) { + *dst = saturate(std::round(*src * scale), -128, 127); + } +}; + +template <> +struct QuantizedTypeCvter { + using stype = float; + using dst_type = uint8_t; + static constexpr size_t SIMD_WIDTH = 8; + float scale; + uint8_t zp; + int32x4_t vzp, vzero; + float32x4_t vscale; + + QuantizedTypeCvter(DType src_dtype, DType dst_dtype) { + MEGDNN_MARK_USED_VAR(src_dtype); + float src_scale = 1; + float dst_scale = dst_dtype.param().scale; + scale = src_scale / dst_scale; + zp = dst_dtype.param().zero_point; + vzp = vdupq_n_s32(static_cast(zp)); + vzero = vdupq_n_s32(0); + vscale = vdupq_n_f32(scale); + } + + void cvt(const float* src, uint8_t* dst) { + float32x4_t vitem0 = vmulq_f32(vld1q_f32(src), vscale); + float32x4_t vitem1 = vmulq_f32(vld1q_f32(src + 4), vscale); + + auto vres = QConverter::convert( + {{vitem0, vitem1}}, this->vzp); + vst1_u8(dst, vres); + } + + void cvt_remain(const float* src, uint8_t* dst) { + *dst = saturate(std::round(*src * scale) + zp, 0, 255); + } +}; + +template <> +struct QuantizedTypeCvter { + using stype = int32_t; + using dst_type = uint8_t; + static constexpr size_t SIMD_WIDTH = 8; + float scale; + uint8_t zp; + int32x4_t vzp, vzero; + float32x4_t vscale; + + QuantizedTypeCvter(DType src_dtype, DType dst_dtype) { + float src_scale = src_dtype.param().scale; + float dst_scale = dst_dtype.param().scale; + scale = src_scale / dst_scale; + zp = dst_dtype.param().zero_point; + vzp = vdupq_n_s32(static_cast(zp)); + vzero = vdupq_n_s32(0); + vscale = vdupq_n_f32(scale); + } + + void cvt(const int32_t* src, uint8_t* dst) { + float32x4_t vitem0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(src)), vscale); + float32x4_t vitem1 = + vmulq_f32(vcvtq_f32_s32(vld1q_s32(src + 4)), vscale); + auto vres = QConverter::convert( + {{vitem0, vitem1}}, this->vzp); + vst1_u8(dst, vres); + } + + void cvt_remain(const int32_t* src, uint8_t* dst) { + *dst = saturate(std::round(*src * scale) + zp, 0, 255); + } +}; + +template <> +struct QuantizedTypeCvter { + using stype = uint8_t; + using dst_type = uint8_t; + static constexpr size_t SIMD_WIDTH = 8; + float scale; + uint8_t zp_dst, zp_src; + int32x4_t vzp_dst, vzero; + int16x8_t vzp_src; + float32x4_t vscale; + + QuantizedTypeCvter(DType src_dtype, DType dst_dtype) { + float src_scale = src_dtype.param().scale; + float dst_scale = dst_dtype.param().scale; + scale = src_scale / dst_scale; + zp_dst = dst_dtype.param().zero_point; + zp_src = src_dtype.param().zero_point; + vzp_dst = vdupq_n_s32(static_cast(zp_dst)); + vzp_src = vdupq_n_s16(static_cast(zp_src)); + vzero = vdupq_n_s32(0); + vscale = vdupq_n_f32(scale); + } + + void cvt(const uint8_t* src, uint8_t* dst) { + int16x8_t vdata 
= vreinterpretq_s16_u16(vmovl_u8(vld1_u8(src))); + vdata = vsubq_s16(vdata, vzp_src); + float32x4_t vitem0 = vmulq_f32( + vcvtq_f32_s32(vmovl_s16(vget_low_s16(vdata))), vscale); + float32x4_t vitem1 = vmulq_f32( + vcvtq_f32_s32(vmovl_s16(vget_high_s16(vdata))), vscale); + + auto vres = QConverter::convert( + {{vitem0, vitem1}}, this->vzp_dst); + vst1_u8(dst, vres); + } + + void cvt_remain(const uint8_t* src, uint8_t* dst) { + *dst = saturate( + std::round((*src - zp_src) * scale) + zp_dst, 0, 255); + } +}; + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +template +struct FloatTypeCvter; + +template <> +struct FloatTypeCvter<__fp16, float> { + using stype = __fp16; + using dst_type = float; + static constexpr size_t SIMD_WIDTH = 8; + FloatTypeCvter(DType src_dtype, DType dst_dtype) {} + + void cvt(const __fp16* src, float* dst) { + float16x8_t vdata = vld1q_f16(src); + float32x4_t vdata_low = vcvt_f32_f16(vget_low_f16(vdata)); + float32x4_t vdata_high = vcvt_f32_f16(vget_high_f16(vdata)); + vst1q_f32(dst, vdata_low); + vst1q_f32(dst + 4, vdata_high); + } + + void cvt_remain(const __fp16* src, float* dst) { *dst = *src; } +}; + +template <> +struct FloatTypeCvter { + using stype = float; + using dst_type = __fp16; + static constexpr size_t SIMD_WIDTH = 8; + FloatTypeCvter(DType src_dtype, DType dst_dtype) {} + + void cvt(const float* src, __fp16* dst) { + float32x4_t vdata0 = vld1q_f32(src); + float32x4_t vdata1 = vld1q_f32(src + 4); + float16x8_t vdata = + vcombine_f16(vcvt_f16_f32(vdata0), vcvt_f16_f32(vdata1)); + vst1q_f16(dst, vdata); + } + + void cvt_remain(const float* src, __fp16* dst) { *dst = *src; } +}; +#endif + +template +void do_typecvt(const typename TypeCvter::stype* src, + typename TypeCvter::dst_type* dst, DType src_dtype, + DType dst_dtype, size_t nr_elems) { + TypeCvter typecvt(src_dtype, dst_dtype); + size_t i = 0; + for (; i + TypeCvter::SIMD_WIDTH <= nr_elems; i += TypeCvter::SIMD_WIDTH) { + typecvt.cvt(src, dst); + src += TypeCvter::SIMD_WIDTH; + dst += TypeCvter::SIMD_WIDTH; + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < nr_elems; i++) { + typecvt.cvt_remain(src, dst); + src++; + dst++; + } +} + +} // anonymous namespace + +void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { + DType src_dtype = src.layout.dtype; + DType dst_dtype = dst.layout.dtype; + size_t nr_elems = src.layout.total_nr_elems(); + bool execed = false; + if (src.layout.is_contiguous()) { + using namespace dtype; +#define DISPATCH_QUANTIZED(_stype_enumv, _stype, _dtype_enumv, _dtype, \ + _midout_iv) \ + if (src_dtype.enumv() == DTypeTrait<_stype_enumv>::enumv && \ + dst_dtype.enumv() == DTypeTrait<_dtype_enumv>::enumv) { \ + MIDOUT_BEGIN(megdnn_arm_typecvt_quantized, midout_iv(_midout_iv)) { \ + using _TypeCvter = QuantizedTypeCvter<_stype, _dtype>; \ + MEGDNN_DISPATCH_CPU_KERN_OPR( \ + do_typecvt<_TypeCvter>(src.compatible_ptr<_stype>(), \ + dst.compatible_ptr<_dtype>(), \ + src_dtype, dst_dtype, nr_elems)); \ + execed = true; \ + } \ + MIDOUT_END(); \ + } + + DISPATCH_QUANTIZED(QuantizedS32, int32_t, Quantized8Asymm, uint8_t, 0); + DISPATCH_QUANTIZED(QuantizedS32, int32_t, QuantizedS8, int8_t, 1); + DISPATCH_QUANTIZED(QuantizedS8, int8_t, QuantizedS32, int32_t, 2); + DISPATCH_QUANTIZED(QuantizedS8, int8_t, QuantizedS8, int8_t, 3); + DISPATCH_QUANTIZED(Quantized8Asymm, uint8_t, Quantized8Asymm, uint8_t, + 4); + DISPATCH_QUANTIZED(QuantizedS32, int32_t, QuantizedS32, 
int32_t, 5); + DISPATCH_QUANTIZED(float, float, QuantizedS8, int8_t, 6); + DISPATCH_QUANTIZED(float, float, Quantized8Asymm, uint8_t, 7); + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#define DISPATCH_FLOAT(_stype_enumv, _stype, _dtype_enumv, _dtype, _midout_iv) \ + if (src_dtype.enumv() == DTypeTrait<_stype_enumv>::enumv && \ + dst_dtype.enumv() == DTypeTrait<_dtype_enumv>::enumv) { \ + MIDOUT_BEGIN(megdnn_arm_typecvt_float, midout_iv(_midout_iv)) { \ + using _TypeCvter = FloatTypeCvter<_stype, _dtype>; \ + MEGDNN_DISPATCH_CPU_KERN_OPR(do_typecvt<_TypeCvter>( \ + reinterpret_cast<_stype*>(src.raw_ptr), \ + reinterpret_cast<_dtype*>(dst.raw_ptr), src_dtype, \ + dst_dtype, nr_elems)); \ + execed = true; \ + } \ + MIDOUT_END(); \ + } + DISPATCH_FLOAT(dt_float16, __fp16, float, float, 0); + DISPATCH_FLOAT(float, float, dt_float16, __fp16, 1); +#endif + } + if (!execed) { + fallback::TypeCvtImpl::exec(src, dst); + } +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/type_cvt/opr_impl.h b/dnn/src/arm_common/type_cvt/opr_impl.h new file mode 100644 index 00000000..ab835b61 --- /dev/null +++ b/dnn/src/arm_common/type_cvt/opr_impl.h @@ -0,0 +1,28 @@ +/** + * \file dnn/src/arm_common/type_cvt/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once + +#include "src/fallback/type_cvt/opr_impl.h" + +namespace megdnn { +namespace arm_common { + +class TypeCvtImpl : public fallback::TypeCvtImpl { +public: + using fallback::TypeCvtImpl::TypeCvtImpl; + + void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) override; +}; + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/utils.cpp b/dnn/src/arm_common/utils.cpp new file mode 100644 index 00000000..48e3b3fd --- /dev/null +++ b/dnn/src/arm_common/utils.cpp @@ -0,0 +1,184 @@ +/** + * \file dnn/src/arm_common/utils.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
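Every QuantizedTypeCvter specialization above follows the same rule: fold the two quantization scales into one multiplier in the constructor, then round and saturate per element. A scalar sketch of the int32-to-int8 case, matching the semantics of cvt_remain() (the helper name is illustrative only):

#include <algorithm>
#include <cmath>
#include <cstdint>

static int8_t requantize_s32_to_s8_sketch(int32_t src, float src_scale,
                                          float dst_scale) {
    float scale = src_scale / dst_scale;       // folded once, as in the ctor
    float v = std::round(src * scale);         // round to nearest
    v = std::min(127.f, std::max(-128.f, v));  // saturate to the int8 range
    return static_cast<int8_t>(v);
}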
+ */ + +#include "src/common/utils.h" +#include +#include "src/arm_common/simd_macro/marm_neon.h" + +using namespace megdnn; + +namespace { + +template +void transpose_naive(const dtype *src, dtype *dst, + int lda, int ldb, int n, int m) +{ + rep(i, n) rep(j, m) { + dst[i*ldb + j] = src[j*lda + i]; + } +} + +void transpose_4x4_neon(const float *src, float *dst, int lda, int ldb) +{ + float32x4x2_t a0, a1; + a0.val[0] = vld1q_f32(src + 0*lda); + a0.val[1] = vld1q_f32(src + 1*lda); + a1.val[0] = vld1q_f32(src + 2*lda); + a1.val[1] = vld1q_f32(src + 3*lda); + float32x4x2_t b0 = vzipq_f32(a0.val[0], a1.val[0]); + float32x4x2_t b1 = vzipq_f32(a0.val[1], a1.val[1]); + float32x4x2_t c0 = vzipq_f32(b0.val[0], b1.val[0]); + float32x4x2_t c1 = vzipq_f32(b0.val[1], b1.val[1]); + vst1q_f32(dst + 0*ldb, c0.val[0]); + vst1q_f32(dst + 1*ldb, c0.val[1]); + vst1q_f32(dst + 2*ldb, c1.val[0]); + vst1q_f32(dst + 3*ldb, c1.val[1]); +} + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +void transpose_8x8_neon(const dt_float16 *src, dt_float16 *dst, int lda, int ldb) +{ + const __fp16* src_ptr = reinterpret_cast(src); + __fp16* dst_ptr = reinterpret_cast<__fp16*>(dst); + float16x8x4_t a0, a1; + a0.val[0] = vld1q_f16(src_ptr + 0*lda); // A0A1A2A3A4A5A6A7 + a0.val[1] = vld1q_f16(src_ptr + 1*lda); // B0B1B2B3B4B5B6B7 + a0.val[2] = vld1q_f16(src_ptr + 2*lda); // C0C1C2C3C4C5C6C7 + a0.val[3] = vld1q_f16(src_ptr + 3*lda); // D0D1D2D3D4D5D6D7 + a1.val[0] = vld1q_f16(src_ptr + 4*lda); // E0E1E2E3E4E5E6E7 + a1.val[1] = vld1q_f16(src_ptr + 5*lda); // F0F1F2F3F4F5F6F7 + a1.val[2] = vld1q_f16(src_ptr + 6*lda); // G0G1G2G3G4G5G6G7 + a1.val[3] = vld1q_f16(src_ptr + 7*lda); // H0H1H2H3H4H5H6H7 + + float16x8x2_t b0 = vzipq_f16(a0.val[0], a1.val[0]); // A0E0A1E1A2E2A3E3 A4E4A5E5A6E6A7E7 + float16x8x2_t b1 = vzipq_f16(a0.val[2], a1.val[2]); // C0G0C1G1C2G2C3G3 C4G4C5G5C6G6C7G7 + float16x8x2_t c0 = vzipq_f16(a0.val[1], a1.val[1]); // B0F0B1F1B2F2B3F3 B4F4B5F5B6F6B7F7 + float16x8x2_t c1 = vzipq_f16(a0.val[3], a1.val[3]); // D0H0D1H1D2H2D3H3 D4H4D5H5D6H6D7H7 + + float16x8x2_t d0 = vzipq_f16(b0.val[0], b1.val[0]); // A0C0E0G0A1C1E1G1 A2C2E2G2A3C3E3G3 + float16x8x2_t d1 = vzipq_f16(c0.val[0], c1.val[0]); // B0D0F0H0B1D1F1H1 B2D2F2H2B3D3F3H3 + float16x8x2_t e0 = vzipq_f16(d0.val[0], d1.val[0]); // A0B0C0D0E0F0G0H0 A1B1C1D1E1F1G1H1 + float16x8x2_t e1 = vzipq_f16(d0.val[1], d1.val[1]); // A2B2C2D2E2F2G2H2 A3B3C3D3E3F3G3H3 + + float16x8x2_t f0 = vzipq_f16(b0.val[1], b1.val[1]); // A4C4E4G4A5C5E5G5 A6C6E6G6A7C7E7G7 + float16x8x2_t f1 = vzipq_f16(c0.val[1], c1.val[1]); // B4D4F4H4B5D5F5H5 B6D6E6G6B7D7E7H7 + float16x8x2_t g0 = vzipq_f16(f0.val[0], f1.val[0]); // A4B4C4D4E4F4G4H4 A5B5C5D5E5F5G5H5 + float16x8x2_t g1 = vzipq_f16(f0.val[1], f1.val[1]); // A6B6C6D6E6F6G6H6 A7B7C7D7E7F7G7H7 + + vst1q_f16(dst_ptr + 0*ldb, e0.val[0]); + vst1q_f16(dst_ptr + 1*ldb, e0.val[1]); + vst1q_f16(dst_ptr + 2*ldb, e1.val[0]); + vst1q_f16(dst_ptr + 3*ldb, e1.val[1]); + vst1q_f16(dst_ptr + 4*ldb, g0.val[0]); + vst1q_f16(dst_ptr + 5*ldb, g0.val[1]); + vst1q_f16(dst_ptr + 6*ldb, g1.val[0]); + vst1q_f16(dst_ptr + 7*ldb, g1.val[1]); +} +#endif + +} // anonymous namespace + +namespace megdnn { + +template <> +void transpose(const float* src, float* dst, size_t m, size_t n, ptrdiff_t lds, + ptrdiff_t ldd) { + if (lds == -1) { + lds = n; + } + if (ldd == -1) { + ldd = m; + } + + for (size_t is = 0; is < n; is += 16) { + for (size_t js = 0; js < m; js += 16) { + auto ie = std::min(is + 16, n), je = std::min(js + 16, m), i = is; + for (; i + 4 <= ie; i += 4) { + auto j = js; + 
for (; j + 4 <= je; j += 4) { + transpose_4x4_neon(src + j * lds + i, dst + i * ldd + j, + lds, ldd); + } + if (j < je) { + transpose_naive(src + j * lds + i, dst + i * ldd + j, lds, + ldd, 4, je - j); + } + } + if (i < ie) { + transpose_naive(src + js * lds + i, dst + i * ldd + js, lds, + ldd, ie - i, je - js); + } + } + } +} + +template +void transpose_knc2nsck_helper(const dtype *src, dtype *dst, + size_t k, size_t n, size_t c, size_t n_stride) { + if (n_stride == k * c) { + // dst is contiguous + transpose(src, dst, k, n * c); + } else { + for (size_t i = 0; i < n; ++ i) { + transpose(src + i * c, dst + i * n_stride, + k, c, n * c); + } + } +} + +template <> +void transpose_knc2nsck(const float *src, float *dst, + size_t k, size_t n, size_t c, size_t n_stride) { + transpose_knc2nsck_helper(src, dst, k, n, c, n_stride); +} + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +template <> +void transpose(const dt_float16* src, dt_float16* dst, size_t m, size_t n, + ptrdiff_t lds, ptrdiff_t ldd) { + if (lds == -1) { + lds = n; + } + if (ldd == -1) { + ldd = m; + } + + for (size_t is = 0; is < n; is += 16) { + for (size_t js = 0; js < m; js += 16) { + auto ie = std::min(is + 16, n), je = std::min(js + 16, m), i = is; + for (; i + 8 <= ie; i += 8) { + auto j = js; + for (; j + 8 <= je; j += 8) { + transpose_8x8_neon(src + j * lds + i, dst + i * ldd + j, + lds, ldd); + } + if (j < je) { + transpose_naive(src + j * lds + i, dst + i * ldd + j, lds, + ldd, 8, je - j); + } + } + if (i < ie) { + transpose_naive(src + js * lds + i, dst + i * ldd + js, lds, + ldd, ie - i, je - js); + } + } + } +} + +template <> +void transpose_knc2nsck(const dt_float16* src, dt_float16* dst, size_t k, + size_t n, size_t c, size_t n_stride) { + transpose_knc2nsck_helper(src, dst, k, n, c, n_stride); +} +#endif + +} // namespace megdnn +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/utils.h b/dnn/src/arm_common/utils.h new file mode 100644 index 00000000..3b790fbe --- /dev/null +++ b/dnn/src/arm_common/utils.h @@ -0,0 +1,464 @@ +/** + * \file dnn/src/arm_common/utils.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
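For reference, a scalar version of what transpose_4x4_neon above computes on one 4x4 tile (illustration only; it matches transpose_naive with n = m = 4):

static void transpose_4x4_ref(const float* src, float* dst, int lda, int ldb) {
    // dst row i holds column i of src
    for (int i = 0; i < 4; ++i)
        for (int j = 0; j < 4; ++j)
            dst[i * ldb + j] = src[j * lda + i];
}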
+ */ + +#pragma once + +#include +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/utils.h" + +namespace megdnn { +namespace arm_common { + +template +struct Vector; +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +template <> +struct Vector<__fp16, 4> { + float16x4_t value; + Vector() {} + Vector(const __fp16 v) { value = vdup_n_f16(v); } + Vector(const Vector& lr) { value = lr.value; } + Vector(const Vector&& lr) { value = std::move(lr.value); } + Vector(const float16x4_t& v) { value = v; } + static Vector load(const __fp16* addr) { + Vector v; + v.value = vld1_f16(addr); + return v; + } + static void save(__fp16* addr, const Vector& v) { vst1_f16(addr, v.value); } + void save(__fp16* addr) { save(addr, *this); } + Vector operator+(const Vector& lr) { + Vector dst; + dst.value = value + lr.value; + return dst; + } + Vector& operator+=(const Vector& lr) { + value = value + lr.value; + return *this; + } + Vector operator-(const Vector& lr) { + Vector dst; + dst.value = value - lr.value; + return dst; + } + Vector& operator-=(const Vector& lr) { + value = value - lr.value; + return *this; + } + Vector operator*(__fp16 lr) { + Vector dst; +#if MEGDNN_AARCH64 + dst.value = vmul_n_f16(value, lr); +#else + dst.value = vmul_n_fix_f16(value, lr); +#endif + return dst; + } + Vector operator*(const Vector& lr) { + Vector dst; + dst.value = vmul_f16(value, lr.value); // value * lr.value; + return dst; + } + Vector& operator*=(const Vector& lr) { + value = vmul_f16(value, lr.value); + return *this; + } + Vector& operator=(const Vector& lr) { + value = lr.value; + return *this; + } + Vector& operator=(const Vector&& lr) { + value = std::move(lr.value); + return *this; + } + Vector operator-() { + Vector dst; + dst.value = -value; + return dst; + } +}; +template <> +struct Vector<__fp16, 8> { + float16x8_t value; + Vector() {} + Vector(const __fp16 v) { value = vdupq_n_f16(v); } + Vector(const Vector& lr) { value = lr.value; } + Vector(const Vector&& lr) { value = std::move(lr.value); } + Vector(const float16x8_t& v) { value = v; } + static Vector load(const __fp16* addr) { + Vector v; + v.value = vld1q_f16(addr); + return v; + } + static void save(__fp16* addr, const Vector& v) { + vst1q_f16(addr, v.value); + } + void save(__fp16* addr) { save(addr, *this); } + Vector operator+(const Vector& lr) { + Vector dst; + dst.value = value + lr.value; + return dst; + } + Vector& operator+=(const Vector& lr) { + value = value + lr.value; + return *this; + } + Vector operator-(const Vector& lr) { + Vector dst; + dst.value = value - lr.value; + return dst; + } + Vector& operator-=(const Vector& lr) { + value = value - lr.value; + return *this; + } + Vector operator*(__fp16 lr) { + Vector dst; +#if MEGDNN_AARCH64 + dst.value = vmulq_n_f16(value, lr); +#else + dst.value = vmulq_n_fix_f16(value, lr); +#endif + return dst; + } + Vector operator*(const Vector& lr) { + Vector dst; + dst.value = value * lr.value; + return dst; + } + Vector& operator*=(const Vector& lr) { + value = value * lr.value; + return *this; + } + Vector& operator=(const Vector& lr) { + value = lr.value; + return *this; + } + Vector& operator=(const Vector&& lr) { + value = std::move(lr.value); + return *this; + } + Vector operator-() { + Vector dst; + dst.value = -value; + return dst; + } +}; +#endif + +template <> +struct Vector { + float32x4_t value; + Vector() {} + Vector(const float v) { value = vdupq_n_f32(v); } + Vector(const Vector& lr) { value = lr.value; } + Vector(const Vector&& lr) { value = std::move(lr.value); } + 
Vector(const float32x4_t& v) { value = v; } + static Vector load(const float* addr) { + Vector v; + v.value = vld1q_f32(addr); + return v; + } + static void save(float* addr, const Vector& v) { vst1q_f32(addr, v.value); } + void save(float* addr) { save(addr, *this); } + Vector operator+(const Vector& lr) { + Vector dst; + dst.value = vaddq_f32(value, lr.value); + return dst; + } + Vector& operator+=(const Vector& lr) { + value = vaddq_f32(value, lr.value); + return *this; + } + Vector operator-(const Vector& lr) { + Vector dst; + dst.value = vsubq_f32(value, lr.value); + return dst; + } + Vector& operator-=(const Vector& lr) { + value = vsubq_f32(value, lr.value); + return *this; + } + Vector operator*(float lr) { + Vector dst; + dst.value = vmulq_n_f32(value, lr); + return dst; + } + Vector operator*(const Vector& lr) { + Vector dst; + dst.value = vmulq_f32(value, lr.value); + return dst; + } + Vector& operator*=(const Vector& lr) { + value = vmulq_f32(value, lr.value); + return *this; + } + Vector& operator=(const Vector& lr) { + value = lr.value; + return *this; + } + Vector& operator=(const Vector&& lr) { + value = std::move(lr.value); + return *this; + } + Vector operator-() { + Vector dst; + dst.value = -value; + return dst; + } +}; + +template <> +struct Vector { + float32x4x2_t value; + Vector() {} + Vector(const float v) { + value.val[0] = vdupq_n_f32(v); + value.val[1] = vdupq_n_f32(v); + } + Vector(const Vector& lr) { value = lr.value; } + Vector(const Vector&& lr) { value = std::move(lr.value); } + Vector(const float32x4x2_t& v) { value = v; } + static Vector load(const float* addr) { + Vector v; + v.value = vld1q_f32_x2(addr); + return v; + } + static void save(float* addr, const Vector& v) { + vst1q_f32_x2(addr, v.value); + } + + void save(float* addr) { save(addr, *this); } + Vector operator+(const Vector& lr) { + Vector dst; + dst.value.val[0] = vaddq_f32(value.val[0], lr.value.val[0]); + dst.value.val[1] = vaddq_f32(value.val[1], lr.value.val[1]); + return dst; + } + Vector& operator+=(const Vector& lr) { + value.val[0] = vaddq_f32(value.val[0], lr.value.val[0]); + value.val[1] = vaddq_f32(value.val[1], lr.value.val[1]); + return *this; + } + Vector& add(const Vector& lr) { + value.val[0] = vaddq_f32(value.val[0], lr.value.val[0]); + value.val[1] = vaddq_f32(value.val[1], lr.value.val[1]); + return *this; + } + Vector operator-(const Vector& lr) { + Vector dst; + dst.value.val[0] = vsubq_f32(value.val[0], lr.value.val[0]); + dst.value.val[1] = vsubq_f32(value.val[1], lr.value.val[1]); + return dst; + } + Vector& operator-=(const Vector& lr) { + value.val[0] = vsubq_f32(value.val[0], lr.value.val[0]); + value.val[1] = vsubq_f32(value.val[1], lr.value.val[1]); + return *this; + } + Vector operator*(float lr) { + Vector dst; + dst.value.val[0] = vmulq_n_f32(value.val[0], lr); + dst.value.val[1] = vmulq_n_f32(value.val[1], lr); + return dst; + } + //! 
val + lr * n + Vector& mla(const Vector& lr, float n) { + value.val[0] = vmlaq_n_f32(value.val[0], lr.value.val[0], n); + value.val[1] = vmlaq_n_f32(value.val[1], lr.value.val[1], n); + return *this; + } + + Vector operator*(const Vector& lr) { + Vector dst; + dst.value.val[0] = vmulq_f32(value.val[0], lr.value.val[0]); + dst.value.val[1] = vmulq_f32(value.val[1], lr.value.val[1]); + return dst; + } + Vector& operator*=(const Vector& lr) { + value.val[0] = vmulq_f32(value.val[0], lr.value.val[0]); + value.val[1] = vmulq_f32(value.val[1], lr.value.val[1]); + return *this; + } + Vector& operator=(const Vector& lr) { + value = lr.value; + return *this; + } + Vector& operator=(const Vector&& lr) { + value = std::move(lr.value); + return *this; + } + Vector operator-() { + Vector dst; + dst.value.val[0] = -value.val[0]; + dst.value.val[1] = -value.val[1]; + return dst; + } +}; + +template <> +struct Vector { + int16x8_t value; + Vector() {} + Vector(const int16_t v) { value = vdupq_n_s16(v); } + Vector(const Vector& lr) { value = lr.value; } + Vector(const Vector&& lr) { value = std::move(lr.value); } + Vector(const int16x8_t& v) { value = v; } + static Vector load(const int16_t* addr) { + Vector v; + v.value = vld1q_s16(addr); + return v; + } + static void save(int16_t* addr, const Vector& v) { + vst1q_s16(addr, v.value); + } + void save(int16_t* addr) { save(addr, *this); } + Vector operator+(const Vector& lr) { + Vector dst; + dst.value = vaddq_s16(value, lr.value); + return dst; + } + Vector& operator+=(const Vector& lr) { + value = vaddq_s16(value, lr.value); + return *this; + } + Vector operator-(const Vector& lr) { + Vector dst; + dst.value = vsubq_s16(value, lr.value); + return dst; + } + Vector& operator-=(const Vector& lr) { + value = vsubq_s16(value, lr.value); + return *this; + } + Vector operator*(int16_t lr) { + Vector dst; + dst.value = vmulq_n_s16(value, lr); + return dst; + } + Vector operator*(const Vector& lr) { + Vector dst; + dst.value = vmulq_s16(value, lr.value); + return dst; + } + Vector& operator*=(const Vector& lr) { + value = vmulq_s16(value, lr.value); + return *this; + } + Vector& operator=(const Vector& lr) { + value = lr.value; + return *this; + } + Vector& operator=(const Vector&& lr) { + value = std::move(lr.value); + return *this; + } + Vector operator-() { + Vector dst; + dst.value = -value; + return dst; + } +}; + +template <> +struct Vector { + int32x4x2_t value; + Vector() {} + Vector(const int32_t v) { + value.val[0] = vdupq_n_s32(v); + value.val[1] = vdupq_n_s32(v); + } + Vector(const Vector& lr) { value = lr.value; } + Vector(const Vector&& lr) { value = std::move(lr.value); } + Vector(const int32x4x2_t& v) { value = v; } + static Vector load(const int32_t* addr) { + Vector v; + v.value.val[0] = vld1q_s32(addr); + v.value.val[1] = vld1q_s32(addr + 4); + return v; + } + static void save(int32_t* addr, const Vector& v) { + vst1q_s32(addr, v.value.val[0]); + vst1q_s32(addr + 4, v.value.val[1]); + } + + void save(int32_t* addr) { save(addr, *this); } + Vector operator+(const Vector& lr) { + Vector dst; + dst.value.val[0] = vaddq_s32(value.val[0], lr.value.val[0]); + dst.value.val[1] = vaddq_s32(value.val[1], lr.value.val[1]); + return dst; + } + Vector& operator+=(const Vector& lr) { + value.val[0] = vaddq_s32(value.val[0], lr.value.val[0]); + value.val[1] = vaddq_s32(value.val[1], lr.value.val[1]); + return *this; + } + Vector& add(const Vector& lr) { + value.val[0] = vaddq_s32(value.val[0], lr.value.val[0]); + value.val[1] = vaddq_s32(value.val[1], 
lr.value.val[1]); + return *this; + } + Vector operator-(const Vector& lr) { + Vector dst; + dst.value.val[0] = vsubq_s32(value.val[0], lr.value.val[0]); + dst.value.val[1] = vsubq_s32(value.val[1], lr.value.val[1]); + return dst; + } + Vector& operator-=(const Vector& lr) { + value.val[0] = vsubq_s32(value.val[0], lr.value.val[0]); + value.val[1] = vsubq_s32(value.val[1], lr.value.val[1]); + return *this; + } + Vector operator*(int32_t lr) { + Vector dst; + dst.value.val[0] = vmulq_n_s32(value.val[0], lr); + dst.value.val[1] = vmulq_n_s32(value.val[1], lr); + return dst; + } + //! val + lr * n + Vector& mla(const Vector& lr, int32_t n) { + value.val[0] = vmlaq_n_s32(value.val[0], lr.value.val[0], n); + value.val[1] = vmlaq_n_s32(value.val[1], lr.value.val[1], n); + return *this; + } + Vector operator*(const Vector& lr) { + Vector dst; + dst.value.val[0] = vmulq_s32(value.val[0], lr.value.val[0]); + dst.value.val[1] = vmulq_s32(value.val[1], lr.value.val[1]); + return dst; + } + Vector& operator*=(const Vector& lr) { + value.val[0] = vmulq_s32(value.val[0], lr.value.val[0]); + value.val[1] = vmulq_s32(value.val[1], lr.value.val[1]); + return *this; + } + Vector& operator=(const Vector& lr) { + value = lr.value; + return *this; + } + Vector& operator=(const Vector&& lr) { + value = std::move(lr.value); + return *this; + } + Vector operator-() { + Vector dst; + dst.value.val[0] = -value.val[0]; + dst.value.val[1] = -value.val[1]; + return dst; + } +}; + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/warp_affine/opr_impl.cpp b/dnn/src/arm_common/warp_affine/opr_impl.cpp new file mode 100644 index 00000000..0ef52511 --- /dev/null +++ b/dnn/src/arm_common/warp_affine/opr_impl.cpp @@ -0,0 +1,40 @@ +/** + * \file dnn/src/arm_common/warp_affine/opr_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "src/arm_common/handle.h" +#include "src/arm_common/warp_affine/opr_impl.h" +#include "src/arm_common/warp_affine/warp_affine_cv.h" +#include "src/common/warp_common.h" + +#include "midout.h" + +MIDOUT_DECL(megdnn_arm_warpaffine) + +using namespace megdnn; +using namespace arm_common; +using namespace megcv; + +void WarpAffineImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in mat, + _megdnn_tensor_out dst, _megdnn_workspace workspace) { + check_exec(src.layout, mat.layout, dst.layout, workspace.size); + if (warp::is_cv_available(src.layout, mat.layout, dst.layout, param().imode, + param().format)) { + MIDOUT_BEGIN(megdnn_arm_warpaffine, void) { + warp_affine_cv_exec(src, mat, dst, param().border_val, + param().border_mode, param().imode, handle()); + } + MIDOUT_END(); + } else { + //! Use fallback implementation + naive::WarpAffineImpl::exec(src, mat, dst, workspace); + } +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/warp_affine/opr_impl.h b/dnn/src/arm_common/warp_affine/opr_impl.h new file mode 100644 index 00000000..ad7cf3e2 --- /dev/null +++ b/dnn/src/arm_common/warp_affine/opr_impl.h @@ -0,0 +1,32 @@ +/** + * \file dnn/src/arm_common/warp_affine/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. 
All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once +#include "src/naive/warp_affine/opr_impl.h" + +namespace megdnn { +namespace arm_common { + +class WarpAffineImpl : public naive::WarpAffineImpl { +public: + using naive::WarpAffineImpl::WarpAffineImpl; + void exec(_megdnn_tensor_in src, _megdnn_tensor_in mat, + _megdnn_tensor_in dst, _megdnn_workspace workspace) override; + + size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, + const TensorLayout&) override { + return 0; + } +}; + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/warp_affine/warp_affine_cv.cpp b/dnn/src/arm_common/warp_affine/warp_affine_cv.cpp new file mode 100644 index 00000000..58da0162 --- /dev/null +++ b/dnn/src/arm_common/warp_affine/warp_affine_cv.cpp @@ -0,0 +1,292 @@ +/** + * By downloading, copying, installing or using the software you agree to this license. + * If you do not agree to this license, do not download, install, + * copy or use the software. + * + * + * License Agreement + * For Open Source Computer Vision Library + * (3-clause BSD License) + * + * Copyright (C) 2000-2020, Intel Corporation, all rights reserved. + * Copyright (C) 2009-2011, Willow Garage Inc., all rights reserved. + * Copyright (C) 2009-2016, NVIDIA Corporation, all rights reserved. + * Copyright (C) 2010-2013, Advanced Micro Devices, Inc., all rights reserved. + * Copyright (C) 2015-2016, OpenCV Foundation, all rights reserved. + * Copyright (C) 2015-2016, Itseez Inc., all rights reserved. + * Copyright (C) 2019-2020, Xperience AI, all rights reserved. + * Third party copyrights are property of their respective owners. + * + * Redistribution and use in source and binary forms, with or without modification, + * are permitted provided that the following conditions are met: + * + * * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * * Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * * Neither the names of the copyright holders nor the names of the contributors + * may be used to endorse or promote products derived from this software + * without specific prior written permission. + * + * This software is provided by the copyright holders and contributors "as is" and + * any express or implied warranties, including, but not limited to, the implied + * warranties of merchantability and fitness for a particular purpose are disclaimed. + * In no event shall copyright holders or contributors be liable for any direct, + * indirect, incidental, special, exemplary, or consequential damages + * (including, but not limited to, procurement of substitute goods or services; + * loss of use, data, or profits; or business interruption) however caused + * and on any theory of liability, whether in contract, strict liability, + * or tort (including negligence or otherwise) arising in any way out of + * the use of this software, even if advised of the possibility of such damage. 
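The Vector<> wrappers from src/arm_common/utils.h above can be used like plain arithmetic types while still mapping to single NEON instructions. A minimal sketch of the mla() pattern ("val + lr * n") on the two-register float variant, assuming the specialization is instantiated as Vector<float, 8> as in the patch (axpy_sketch is illustrative only):

#include <cstddef>
#include "src/arm_common/utils.h"

using megdnn::arm_common::Vector;

// y[i] += alpha * x[i], 8 floats per step; n is assumed to be a multiple of 8.
static void axpy_sketch(float* y, const float* x, float alpha, size_t n) {
    for (size_t i = 0; i < n; i += 8) {
        auto vy = Vector<float, 8>::load(y + i);
        auto vx = Vector<float, 8>::load(x + i);
        // per register half: value.val[k] = vmlaq_n_f32(value.val[k], vx, alpha)
        vy.mla(vx, alpha);
        vy.save(y + i);
    }
}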
+ * + * --------------------------------------------------------------------------- + * \file dnn/src/arm_common/warp_affine/warp_affine_cv.cpp + * + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * + * This file has been modified by Megvii ("Megvii Modifications"). + * All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved. + * + * --------------------------------------------------------------------------- + */ + +#include "src/arm_common/warp_affine/warp_affine_cv.h" +#include "src/arm_common/handle.h" + +#include "src/common/cv/common.h" +#include "src/common/cv/helper.h" +#include "src/common/cv/interp_helper.h" +#include "src/common/utils.h" +#include "src/common/warp_common.h" + +#include +#include + +#include "src/arm_common/simd_macro/marm_neon.h" + +using namespace megdnn; +using namespace arm_common; +using namespace megcv; +using namespace warp; + +#include "midout.h" +MIDOUT_DECL(megdnn_arm_common_warp_affine_cv) + +namespace { + +constexpr size_t BLOCK_SZ = 64_z; +template +void warp_affine_cv(const Mat& src, Mat& dst, const float* trans, + const float border_value, size_t task_id) { + // no extra padding + double M[6]; + rep(i, 6) M[i] = trans[i]; + T bvalue[3] = {(T)border_value, (T)border_value, (T)border_value}; + + std::vector _adelta(dst.cols() * 2); + int *adelta = _adelta.data(), *bdelta = adelta + dst.cols(); + // clang 3.6 can not deduce that `std::max(10, (int)INTER_BITS)' is a + // constant, which will cause compilation error in subsequent vshrq_n_s32. + constexpr int AB_BITS = 10 > INTER_BITS ? 10 : INTER_BITS; + constexpr int AB_SCALE = 1 << AB_BITS; + size_t dstcols = dst.cols(); + for (size_t x = 0; x < dstcols; ++x) { + adelta[x] = saturate_cast(M[0] * x * AB_SCALE); + bdelta[x] = saturate_cast(M[3] * x * AB_SCALE); + } + size_t x1, y1, dstrows = dst.rows(); + size_t BLOCK_SZ_H = std::min(BLOCK_SZ / 2, dstrows); + size_t BLOCK_SZ_W = std::min(BLOCK_SZ * BLOCK_SZ / BLOCK_SZ_H, dstcols); + BLOCK_SZ_H = std::min(BLOCK_SZ * BLOCK_SZ / BLOCK_SZ_W, dstrows); + + size_t width_block_size = div_ceil(dstcols, BLOCK_SZ_W); + size_t y = (task_id / width_block_size) * BLOCK_SZ_H; + size_t x = (task_id % width_block_size) * BLOCK_SZ_W; + + short XY[BLOCK_SZ * BLOCK_SZ * 2], A[BLOCK_SZ * BLOCK_SZ]; + int round_delta = + (imode == IMode::INTER_NEAREST ? 
AB_SCALE / 2 + : AB_SCALE / INTER_TAB_SIZE / 2); + size_t bw = std::min(BLOCK_SZ_W, dstcols - x); + size_t bh = std::min(BLOCK_SZ_H, dstrows - y); + Mat _XY(bh, bw, 2, XY); + Mat dpart(dst, y, bh, x, bw); + + for (y1 = 0; y1 < bh; ++y1) { + short* xy = XY + y1 * bw * 2; + int X0 = saturate_cast((M[1] * (y + y1) + M[2]) * AB_SCALE) + + round_delta; + int Y0 = saturate_cast((M[4] * (y + y1) + M[5]) * AB_SCALE) + + round_delta; + + if (imode == IMode::INTER_NEAREST) { + x1 = 0; + + int32x4_t v_X0 = vdupq_n_s32(X0), v_Y0 = vdupq_n_s32(Y0); + for (; x1 + 8 <= bw; x1 += 8) { + int16x8x2_t v_dst; + v_dst.val[0] = vcombine_s16( + vqmovn_s32(vshrq_n_s32( + vaddq_s32(v_X0, vld1q_s32(adelta + x + x1)), + AB_BITS)), + vqmovn_s32(vshrq_n_s32( + vaddq_s32(v_X0, vld1q_s32(adelta + x + x1 + 4)), + AB_BITS))); + v_dst.val[1] = vcombine_s16( + vqmovn_s32(vshrq_n_s32( + vaddq_s32(v_Y0, vld1q_s32(bdelta + x + x1)), + AB_BITS)), + vqmovn_s32(vshrq_n_s32( + vaddq_s32(v_Y0, vld1q_s32(bdelta + x + x1 + 4)), + AB_BITS))); + + vst2q_s16(xy + (x1 << 1), v_dst); + } + + for (; x1 < bw; x1++) { + int X = (X0 + adelta[x + x1]) >> AB_BITS; + int Y = (Y0 + bdelta[x + x1]) >> AB_BITS; + xy[x1 * 2] = saturate_cast(X); + xy[x1 * 2 + 1] = saturate_cast(Y); + } + } else { + // if imode is not INTER_NEAREST + short* alpha = A + y1 * bw; + x1 = 0; + + int32x4_t v__X0 = vdupq_n_s32(X0), v__Y0 = vdupq_n_s32(Y0), + v_mask = vdupq_n_s32(INTER_TAB_SIZE - 1); + for (; x1 + 8 <= bw; x1 += 8) { + int32x4_t v_X0 = vshrq_n_s32( + vaddq_s32(v__X0, vld1q_s32(adelta + x + x1)), + AB_BITS - INTER_BITS); + int32x4_t v_Y0 = vshrq_n_s32( + vaddq_s32(v__Y0, vld1q_s32(bdelta + x + x1)), + AB_BITS - INTER_BITS); + int32x4_t v_X1 = vshrq_n_s32( + vaddq_s32(v__X0, vld1q_s32(adelta + x + x1 + 4)), + AB_BITS - INTER_BITS); + int32x4_t v_Y1 = vshrq_n_s32( + vaddq_s32(v__Y0, vld1q_s32(bdelta + x + x1 + 4)), + AB_BITS - INTER_BITS); + + int16x8x2_t v_xy; + v_xy.val[0] = + vcombine_s16(vqmovn_s32(vshrq_n_s32(v_X0, INTER_BITS)), + vqmovn_s32(vshrq_n_s32(v_X1, INTER_BITS))); + v_xy.val[1] = + vcombine_s16(vqmovn_s32(vshrq_n_s32(v_Y0, INTER_BITS)), + vqmovn_s32(vshrq_n_s32(v_Y1, INTER_BITS))); + + vst2q_s16(xy + (x1 << 1), v_xy); + + int16x4_t v_alpha0 = vmovn_s32(vaddq_s32( + vshlq_n_s32(vandq_s32(v_Y0, v_mask), INTER_BITS), + vandq_s32(v_X0, v_mask))); + int16x4_t v_alpha1 = vmovn_s32(vaddq_s32( + vshlq_n_s32(vandq_s32(v_Y1, v_mask), INTER_BITS), + vandq_s32(v_X1, v_mask))); + vst1q_s16(alpha + x1, vcombine_s16(v_alpha0, v_alpha1)); + } + + for (; x1 < bw; x1++) { + int X = (X0 + adelta[x + x1]) >> (AB_BITS - INTER_BITS); + int Y = (Y0 + bdelta[x + x1]) >> (AB_BITS - INTER_BITS); + xy[x1 * 2] = saturate_cast(X >> INTER_BITS); + xy[x1 * 2 + 1] = saturate_cast(Y >> INTER_BITS); + alpha[x1] = + (short)((Y & (INTER_TAB_SIZE - 1)) * INTER_TAB_SIZE + + (X & (INTER_TAB_SIZE - 1))); + } + } + } + Mat _matA(bh, bw, 1, (ushort*)(A)); + remap>(src, dpart, _XY, _matA, bvalue); +} +} // anonymous namespace + +void megdnn::arm_common::warp_affine_cv_exec( + _megdnn_tensor_in src, _megdnn_tensor_in trans, _megdnn_tensor_in dst, + float border_value, BorderMode bmode, InterpolationMode imode, + Handle* handle) { + size_t ch = dst.layout[3]; + size_t width = dst.layout[2]; + size_t height = dst.layout[1]; + const size_t batch = dst.layout.shape[0]; + + size_t BLOCK_SZ_H = std::min(BLOCK_SZ / 2, height); + size_t BLOCK_SZ_W = std::min(BLOCK_SZ * BLOCK_SZ / BLOCK_SZ_H, width); + BLOCK_SZ_H = std::min(BLOCK_SZ * BLOCK_SZ / BLOCK_SZ_W, height); + + size_t parallelism_batch = 
div_ceil(height, BLOCK_SZ_H) * + div_ceil(width, BLOCK_SZ_W); + + megdnn_assert(ch == 1 || ch == 3 || ch == 2, + "unsupported src channel: %zu, avaiable channel size: 1/2/3", + ch); + const float* trans_ptr = trans.ptr(); + if (dst.layout.dtype.enumv() == DTypeEnum::Float32) { +#define cb(_imode, _bmode, _ch) \ + MIDOUT_BEGIN(megdnn_arm_common_warp_affine_cv, midout_iv(_imode), \ + midout_iv(_bmode), midout_iv(_ch), float) { \ + auto task = [src, trans_ptr, dst, border_value, parallelism_batch]( \ + size_t index, size_t) { \ + size_t batch_id = index / parallelism_batch; \ + size_t task_id = index % parallelism_batch; \ + Mat src_mat = TensorND2Mat(src, batch_id); \ + Mat dst_mat = TensorND2Mat(dst, batch_id); \ + const float* task_trans_ptr = trans_ptr + batch_id * 2 * 3; \ + warp_affine_cv( \ + src_mat MEGDNN_COMMA const_cast&>(dst_mat) \ + MEGDNN_COMMA task_trans_ptr MEGDNN_COMMA \ + border_value, \ + task_id); \ + }; \ + MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ + static_cast(handle), \ + batch* parallelism_batch, task); \ + } \ + MIDOUT_END(); + DISPATCH_IMODE(imode, bmode, ch, cb) + } else if (dst.layout.dtype.enumv() == DTypeEnum::Uint8) { +#undef cb +#define cb(_imode, _bmode, _ch) \ + MIDOUT_BEGIN(megdnn_arm_common_warp_affine_cv, midout_iv(_imode), \ + midout_iv(_bmode), midout_iv(_ch), uchar) { \ + auto task = [src, trans_ptr, dst, border_value, parallelism_batch]( \ + size_t index, size_t) { \ + size_t batch_id = index / parallelism_batch; \ + size_t task_id = index % parallelism_batch; \ + Mat src_mat = TensorND2Mat(src, batch_id); \ + Mat dst_mat = TensorND2Mat(dst, batch_id); \ + const float* task_trans_ptr = trans_ptr + batch_id * 2 * 3; \ + warp_affine_cv( \ + src_mat MEGDNN_COMMA const_cast&>(dst_mat) \ + MEGDNN_COMMA task_trans_ptr MEGDNN_COMMA \ + border_value, \ + task_id); \ + }; \ + MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ + static_cast(handle), \ + batch* parallelism_batch, task); \ + } \ + MIDOUT_END(); + DISPATCH_IMODE(imode, bmode, ch, cb) +#undef cb + } else { + megdnn_throw(megdnn_mangle("Unsupported datatype of WarpAffine optr.")); + } +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/warp_affine/warp_affine_cv.h b/dnn/src/arm_common/warp_affine/warp_affine_cv.h new file mode 100644 index 00000000..4005343b --- /dev/null +++ b/dnn/src/arm_common/warp_affine/warp_affine_cv.h @@ -0,0 +1,32 @@ +/** + * \file dnn/src/arm_common/warp_affine/warp_affine_cv.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
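warp_affine_cv above keeps floating point out of the inner loop by pre-scaling the affine matrix into Q(AB_BITS) fixed point: adelta/bdelta hold M[0]*x and M[3]*x already multiplied by AB_SCALE, so each pixel needs only integer adds and shifts. A scalar sketch of the INTER_NEAREST path (std::lround stands in for saturate_cast<int>, so this is an approximation rather than the patch's exact code):

#include <cmath>
#include <utility>

// Maps destination pixel (dx, dy) through the 2x3 affine matrix M using the
// same AB_BITS fixed-point scheme as the kernel above.
static std::pair<int, int> affine_nearest_sketch(const double M[6], int dx,
                                                 int dy) {
    const int AB_BITS = 10;            // max(10, INTER_BITS) in the patch
    const int AB_SCALE = 1 << AB_BITS;
    int adelta = static_cast<int>(std::lround(M[0] * dx * AB_SCALE));
    int bdelta = static_cast<int>(std::lround(M[3] * dx * AB_SCALE));
    int round_delta = AB_SCALE / 2;    // nearest-neighbour rounding bias
    int X0 = static_cast<int>(std::lround((M[1] * dy + M[2]) * AB_SCALE)) +
             round_delta;
    int Y0 = static_cast<int>(std::lround((M[4] * dy + M[5]) * AB_SCALE)) +
             round_delta;
    return {(X0 + adelta) >> AB_BITS, (Y0 + bdelta) >> AB_BITS};
}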
+ */ + +#include + +#include "src/common/cv/helper.h" + +namespace megdnn { +namespace arm_common { + +/** + * \fn warp_affine_cv + * \brief Used if the format is NHWC, transfer from megcv + */ +void warp_affine_cv_exec(_megdnn_tensor_in src, _megdnn_tensor_in trans, + _megdnn_tensor_in dst, float border_value, + param::WarpAffine::BorderMode border_mode, + param::WarpAffine::InterpolationMode imode, + Handle* handle); + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/warp_perspective/opr_impl.cpp b/dnn/src/arm_common/warp_perspective/opr_impl.cpp new file mode 100644 index 00000000..1a98203f --- /dev/null +++ b/dnn/src/arm_common/warp_perspective/opr_impl.cpp @@ -0,0 +1,47 @@ +/** + * \file dnn/src/arm_common/warp_perspective/opr_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "src/arm_common/warp_perspective/opr_impl.h" + +#include "src/arm_common/warp_perspective/warp_perspective_cv.h" + +#include "midout.h" +#include "src/common/utils.h" +#include "src/common/warp_common.h" +#include "src/naive/handle.h" + +MIDOUT_DECL(megdnn_arm_warpperspective) + +namespace megdnn { +namespace arm_common { + +void WarpPerspectiveImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in mat, + _megdnn_tensor_in mat_idx, _megdnn_tensor_in dst, + _megdnn_workspace workspace) { + check_exec_allow_nhwc_mat_idx(src.layout, mat.layout, mat_idx.layout, + dst.layout, workspace.size); + if (warp::is_cv_available(src.layout, mat.layout, dst.layout, param().imode, + param().format) && + !mat_idx.layout.ndim) { + MIDOUT_BEGIN(megdnn_arm_warpperspective, void) { + warp_perspective_cv_exec(src, mat, dst, param().border_val, + param().bmode, param().imode, handle()); + } + MIDOUT_END(); + } else { + //! Use fallback implementation + fallback::WarpPerspectiveImpl::exec(src, mat, mat_idx, dst, workspace); + } +} + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/warp_perspective/opr_impl.h b/dnn/src/arm_common/warp_perspective/opr_impl.h new file mode 100644 index 00000000..e533618e --- /dev/null +++ b/dnn/src/arm_common/warp_perspective/opr_impl.h @@ -0,0 +1,30 @@ +/** + * \file dnn/src/arm_common/warp_perspective/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +#pragma once +#include "megdnn/oprs.h" +#include "src/fallback/warp_perspective/opr_impl.h" + +namespace megdnn { +namespace arm_common { + +class WarpPerspectiveImpl : public fallback::WarpPerspectiveImpl { +public: + using fallback::WarpPerspectiveImpl::WarpPerspectiveImpl; + + void exec(_megdnn_tensor_in src, _megdnn_tensor_in mat, + _megdnn_tensor_in mat_idx, _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; +}; + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/warp_perspective/warp_perspective_cv.cpp b/dnn/src/arm_common/warp_perspective/warp_perspective_cv.cpp new file mode 100644 index 00000000..7d67ee3d --- /dev/null +++ b/dnn/src/arm_common/warp_perspective/warp_perspective_cv.cpp @@ -0,0 +1,215 @@ +/** + * By downloading, copying, installing or using the software you agree to this license. + * If you do not agree to this license, do not download, install, + * copy or use the software. + * + * + * License Agreement + * For Open Source Computer Vision Library + * (3-clause BSD License) + * + * Copyright (C) 2000-2020, Intel Corporation, all rights reserved. + * Copyright (C) 2009-2011, Willow Garage Inc., all rights reserved. + * Copyright (C) 2009-2016, NVIDIA Corporation, all rights reserved. + * Copyright (C) 2010-2013, Advanced Micro Devices, Inc., all rights reserved. + * Copyright (C) 2015-2016, OpenCV Foundation, all rights reserved. + * Copyright (C) 2015-2016, Itseez Inc., all rights reserved. + * Copyright (C) 2019-2020, Xperience AI, all rights reserved. + * Third party copyrights are property of their respective owners. + * + * Redistribution and use in source and binary forms, with or without modification, + * are permitted provided that the following conditions are met: + * + * * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * * Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * * Neither the names of the copyright holders nor the names of the contributors + * may be used to endorse or promote products derived from this software + * without specific prior written permission. + * + * This software is provided by the copyright holders and contributors "as is" and + * any express or implied warranties, including, but not limited to, the implied + * warranties of merchantability and fitness for a particular purpose are disclaimed. + * In no event shall copyright holders or contributors be liable for any direct, + * indirect, incidental, special, exemplary, or consequential damages + * (including, but not limited to, procurement of substitute goods or services; + * loss of use, data, or profits; or business interruption) however caused + * and on any theory of liability, whether in contract, strict liability, + * or tort (including negligence or otherwise) arising in any way out of + * the use of this software, even if advised of the possibility of such damage. + * + * --------------------------------------------------------------------------- + * \file dnn/src/arm_common/warp_perspective/warp_perspective_cv.cpp + * + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * + * This file has been modified by Megvii ("Megvii Modifications"). + * All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved. + * + * --------------------------------------------------------------------------- + */ +#include "src/arm_common/warp_perspective/warp_perspective_cv.h" +#include "src/arm_common/handle.h" +#include "src/common/cv/common.h" +#include "src/common/cv/helper.h" +#include "src/common/cv/interp_helper.h" +#include "src/common/utils.h" +#include "src/common/warp_common.h" +#include +#include + +#include "src/arm_common/simd_macro/marm_neon.h" + +using namespace megdnn; +using namespace arm_common; +using namespace megcv; +using namespace warp; + +namespace { + +constexpr size_t BLOCK_SZ = 32u; +template +void warp_perspective_cv(const Mat& src, Mat& dst, const float* trans, + const float border_value, size_t task_id) { + // no extra padding + double M[9]; + rep(i, 9) M[i] = trans[i]; + T bvalue[3] = {(T)border_value, (T)border_value, (T)border_value}; + + size_t x1, y1, width = dst.cols(), height = dst.rows(); + size_t BLOCK_SZ_H = std::min(BLOCK_SZ / 2, height); + size_t BLOCK_SZ_W = std::min(BLOCK_SZ * BLOCK_SZ / BLOCK_SZ_H, width); + BLOCK_SZ_H = std::min(BLOCK_SZ * BLOCK_SZ / BLOCK_SZ_W, height); + + size_t width_block_size = div_ceil(width, BLOCK_SZ_W); + size_t y = (task_id / width_block_size) * BLOCK_SZ_H; + size_t x = (task_id % width_block_size) * BLOCK_SZ_W; + + // start invoke + short XY[BLOCK_SZ * BLOCK_SZ * 2], A[BLOCK_SZ * BLOCK_SZ]; + size_t bw = std::min(BLOCK_SZ_W, width - x); + size_t bh = std::min(BLOCK_SZ_H, height - y); // height + Mat _XY(bh, bw, 2, XY); + Mat dpart(dst, y, bh, x, bw); + for (y1 = 0; y1 < bh; y1++) { + short* xy = XY + y1 * bw * 2; + double X0 = M[0] * x + M[1] * (y + y1) + M[2]; + double Y0 = M[3] * x + M[4] * (y + y1) + M[5]; + double W0 = M[6] * x + M[7] * (y + y1) + M[8]; + if (imode == IMode::NEAREST) + for (x1 = 0; x1 < bw; x1++) { + double W = W0 + M[6] * x1; + W = W ? 1. / W : 0; + double fX = std::max( + (double)INT_MIN, + std::min((double)INT_MAX, (X0 + M[0] * x1) * W)); + double fY = std::max( + (double)INT_MIN, + std::min((double)INT_MAX, (Y0 + M[3] * x1) * W)); + int X = saturate_cast(fX); + int Y = saturate_cast(fY); + xy[x1 * 2] = saturate_cast(X); + xy[x1 * 2 + 1] = saturate_cast(Y); + } + else { + short* alpha = A + y1 * bw; + for (x1 = 0; x1 < bw; x1++) { + double W = W0 + M[6] * x1; + W = W ? 
INTER_TAB_SIZE / W : 0; + double fX = std::max( + (double)INT_MIN, + std::min((double)INT_MAX, (X0 + M[0] * x1) * W)); + double fY = std::max( + (double)INT_MIN, + std::min((double)INT_MAX, (Y0 + M[3] * x1) * W)); + int X = saturate_cast(fX); + int Y = saturate_cast(fY); + xy[x1 * 2] = saturate_cast(X >> INTER_BITS); + xy[x1 * 2 + 1] = saturate_cast(Y >> INTER_BITS); + alpha[x1] = + (short)((Y & (INTER_TAB_SIZE - 1)) * INTER_TAB_SIZE + + (X & (INTER_TAB_SIZE - 1))); + } + } + } + Mat _matA(bh, bw, 1, (ushort*)(A)); + remap>(src, dpart, _XY, _matA, bvalue); +} + +} // anonymous namespace + +void megdnn::arm_common::warp_perspective_cv_exec( + _megdnn_tensor_in src, _megdnn_tensor_in trans, _megdnn_tensor_in dst, + float border_value, BorderMode bmode, InterpolationMode imode, + Handle* handle) { + size_t ch = dst.layout[3]; + size_t width = dst.layout[2]; + size_t height = dst.layout[1]; + const size_t batch = dst.layout.shape[0]; + + size_t BLOCK_SZ_H = std::min(BLOCK_SZ / 2, height); + size_t BLOCK_SZ_W = std::min(BLOCK_SZ * BLOCK_SZ / BLOCK_SZ_H, width); + BLOCK_SZ_H = std::min(BLOCK_SZ * BLOCK_SZ / BLOCK_SZ_W, height); + + size_t parallelism_batch = div_ceil(height, BLOCK_SZ_H) * + div_ceil(width, BLOCK_SZ_W); + megdnn_assert(ch == 1 || ch == 3 || ch == 2, + "unsupported src channel: %zu, avaiable channel size: 1/2/3", + ch); + const float* trans_ptr = trans.ptr(); + if (dst.layout.dtype.enumv() == DTypeEnum::Float32) { +#define cb(_imode, _bmode, _ch) \ + auto task = [src, trans_ptr, dst, border_value, parallelism_batch]( \ + size_t index, size_t) { \ + size_t batch_id = index / parallelism_batch; \ + size_t task_id = index % parallelism_batch; \ + Mat src_mat = TensorND2Mat(src, batch_id); \ + Mat dst_mat = TensorND2Mat(dst, batch_id); \ + const float* task_trans_ptr = trans_ptr + batch_id * 3 * 3; \ + warp_perspective_cv( \ + src_mat MEGDNN_COMMA const_cast&>(dst_mat) \ + MEGDNN_COMMA task_trans_ptr MEGDNN_COMMA border_value, \ + task_id); \ + }; \ + MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ + static_cast(handle), batch* parallelism_batch, \ + task); + DISPATCH_IMODE(imode, bmode, ch, cb) +#undef cb + } else if (dst.layout.dtype.enumv() == DTypeEnum::Uint8) { +#define cb(_imode, _bmode, _ch) \ + auto task = [src, trans_ptr, dst, border_value, parallelism_batch]( \ + size_t index, size_t) { \ + size_t batch_id = index / parallelism_batch; \ + size_t task_id = index % parallelism_batch; \ + Mat src_mat = TensorND2Mat(src, batch_id); \ + Mat dst_mat = TensorND2Mat(dst, batch_id); \ + const float* task_trans_ptr = trans_ptr + batch_id * 3 * 3; \ + warp_perspective_cv( \ + src_mat MEGDNN_COMMA const_cast&>(dst_mat) \ + MEGDNN_COMMA task_trans_ptr MEGDNN_COMMA border_value, \ + task_id); \ + }; \ + MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ + static_cast(handle), batch* parallelism_batch, \ + task); + DISPATCH_IMODE(imode, bmode, ch, cb) +#undef cb + } else { + megdnn_throw(megdnn_mangle("Unsupported datatype of WarpAffine optr.")); + } +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/warp_perspective/warp_perspective_cv.h b/dnn/src/arm_common/warp_perspective/warp_perspective_cv.h new file mode 100644 index 00000000..75f097a7 --- /dev/null +++ b/dnn/src/arm_common/warp_perspective/warp_perspective_cv.h @@ -0,0 +1,32 @@ +/** + * \file dnn/src/arm_common/warp_perspective/warp_perspective_cv.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. 
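Both warp_affine_cv_exec and warp_perspective_cv_exec above parallelize over (batch, tile) pairs: one flat kernel index selects an image and a BLOCK_SZ_W x BLOCK_SZ_H tile inside it. A standalone sketch of that decoding (the struct and function names are illustrative):

#include <cstddef>

struct TileSketch {
    size_t batch_id, y, x;  // image index and top-left corner of the tile
};

static TileSketch decode_task_sketch(size_t index, size_t parallelism_batch,
                                     size_t width_block_size,
                                     size_t BLOCK_SZ_H, size_t BLOCK_SZ_W) {
    size_t batch_id = index / parallelism_batch;  // which image in the batch
    size_t task_id = index % parallelism_batch;   // which tile in that image
    size_t y = (task_id / width_block_size) * BLOCK_SZ_H;
    size_t x = (task_id % width_block_size) * BLOCK_SZ_W;
    return {batch_id, y, x};
}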
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include + +#include "src/common/cv/helper.h" + +namespace megdnn { +namespace arm_common { + +/** + * \fn warp_perspective_cv + * \brief Used if the format is NHWC, transfer from megcv + */ +void warp_perspective_cv_exec(_megdnn_tensor_in src, _megdnn_tensor_in trans, + _megdnn_tensor_in dst, float border_value, + param::WarpPerspective::BorderMode border_mode, + param::WarpPerspective::InterpolationMode imode, + Handle* handle); + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/winograd_filter_preprocess/opr_impl.cpp b/dnn/src/arm_common/winograd_filter_preprocess/opr_impl.cpp new file mode 100644 index 00000000..ce6e4a0c --- /dev/null +++ b/dnn/src/arm_common/winograd_filter_preprocess/opr_impl.cpp @@ -0,0 +1,132 @@ +/** + * \file dnn/src/arm_common/winograd_filter_preprocess/opr_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "src/arm_common/winograd_filter_preprocess/opr_impl.h" +#include "src/arm_common/handle.h" +#include "src/common/utils.h" +#include "src/arm_common/conv_bias/fp32/strategy.h" +#include "src/arm_common/conv_bias/int8/strategy.h" +#include "src/arm_common/conv_bias/f16/strategy.h" + +#include "midout.h" +MIDOUT_DECL(megdnn_arm_common_winograd_filter_preprocess) + +using namespace megdnn; +using namespace arm_common; + +void WinogradFilterPreprocessImpl::exec(_megdnn_tensor_in src, + _megdnn_tensor_out dst, + _megdnn_workspace workspace) { + using namespace winograd; + check_exec(src.layout, dst.layout, workspace.size); + + size_t flt_start = 0; + size_t group = 1; + if (src.layout.ndim == 5) { + flt_start = 1; + group = src.layout[0]; + } + size_t OC = src.layout[flt_start], IC = src.layout[flt_start + 1], + FW = src.layout[flt_start + 3]; + size_t m = param().output_block_size; + + bool execed = false; + +#define DISPATCH(_strategy, _format, ...) 
\ + MIDOUT_BEGIN(megdnn_arm_common_winograd_filter_preprocess, \ + ##__VA_ARGS__) { \ + if (param().format == _format) { \ + for (size_t g = 0; g < group; g++) { \ + auto run = [=]() { \ + _strategy strategy(src.layout.dtype, src.layout.dtype, \ + src.layout.dtype); \ + megdnn::winograd::ConvBias<_strategy, _format>( \ + strategy, 1, 1, 1, 1, 1) \ + .filter_process(src_ptr, dst_ptr, workspace_ptr, \ + OC, IC); \ + }; \ + MEGDNN_DISPATCH_CPU_KERN_OPR(run()); \ + src_ptr += src.layout.stride[0]; \ + dst_ptr += dst.layout.stride[0]; \ + } \ + execed = true; \ + } \ + } \ + MIDOUT_END(); + + if (src.layout.dtype.enumv() == DTypeEnum::Float32) { + const float* src_ptr = src.ptr(); + float* dst_ptr = dst.ptr(); + float* workspace_ptr = workspace.ptr(); + if (FW == 3) { + if (m == 2) { + DISPATCH(winograd_2x3_4x4_f, param::Winograd::Format::MK4, 0, + 0); + } else if (m == 6) { + DISPATCH(winograd_6x3_1x1_f, param::Winograd::Format::DEFAULT, + 0, 1); + DISPATCH(winograd_6x3_4x4_f, param::Winograd::Format::MK4, 0, + 2); + } + } else if (FW == 4) { + if (m == 5) { + DISPATCH(winograd_5x4_1x1_f, param::Winograd::Format::DEFAULT, + 0, 3); + } + } else if (FW == 5) { + if (m == 4) { + DISPATCH(winograd_4x5_1x1_f, param::Winograd::Format::DEFAULT, + 0, 4); + } + } + } + if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { + const dt_int8* src_ptr = src.compatible_ptr(); + dt_int16* dst_ptr = dst.compatible_ptr(); + dt_int16* workspace_ptr = workspace.ptr(); + if (FW == 3) { + if (m == 2) { + DISPATCH(winograd_2x3_8x8_s8, param::Winograd::Format::MK8, 1, + 0); + } + } + } +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + if (src.layout.dtype.enumv() == DTypeEnum::Float16) { + const dt_float16* src_ptr = src.ptr(); + dt_float16* dst_ptr = dst.ptr(); + dt_float16* workspace_ptr = workspace.ptr(); + if (FW == 3) { + if (m == 2) { + DISPATCH(winograd_2x3_4x4_f16, param::Winograd::Format::DEFAULT, + 2, 0); + DISPATCH(winograd_2x3_8x8_f16, param::Winograd::Format::MK8, 2, + 1); + } else if (m == 6) { + DISPATCH(winograd_6x3_1x1_f16, param::Winograd::Format::DEFAULT, + 2, 2); + } + } else if (FW == 5) { + if (m == 4) { + DISPATCH(winograd_4x5_1x1_f16, param::Winograd::Format::DEFAULT, + 2, 3); + } + } + } +#endif +#undef DISPATCH + + megdnn_assert(execed, + "Unsupport winograd filter preprocess. m: %zu src: %s", m, + src.layout.to_string().c_str()); +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/winograd_filter_preprocess/opr_impl.h b/dnn/src/arm_common/winograd_filter_preprocess/opr_impl.h new file mode 100644 index 00000000..e2e5bb65 --- /dev/null +++ b/dnn/src/arm_common/winograd_filter_preprocess/opr_impl.h @@ -0,0 +1,28 @@ +/** + * \file dnn/src/arm_common/winograd_filter_preprocess/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +#pragma once +#include "megdnn/oprs.h" +#include "src/common/utils.h" + +namespace megdnn { +namespace arm_common { + +class WinogradFilterPreprocessImpl : public WinogradFilterPreprocess { +public: + using WinogradFilterPreprocess::WinogradFilterPreprocess; + void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; +}; + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/armv7/conv_bias/int8/algos.cpp b/dnn/src/armv7/conv_bias/int8/algos.cpp new file mode 100644 index 00000000..5ead3cf1 --- /dev/null +++ b/dnn/src/armv7/conv_bias/int8/algos.cpp @@ -0,0 +1,176 @@ +/** + * \file dnn/src/armv7/conv_bias/int8/algos.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "src/armv7/conv_bias/int8/algos.h" +#include "src/arm_common/convolution/img2col_helper.h" +#include "src/armv7/conv_bias/int8/strategy.h" +#include "src/common/opr_delegate.h" +#include "src/fallback/conv_bias/common.h" +#include "src/fallback/matrix_mul/gemm_impl.h" + +#include "midout.h" + +MIDOUT_DECL(megdnn_armv7_conv_bias_int8) + +using namespace megdnn; +using namespace armv7; + +/* ===================== matrix mul algo ===================== */ + +bool ConvBiasImpl::AlgoS8MatrixMul::usable( + FallbackConvBiasImpl* opr, const NCBKernSizeParam& param, + AlgoSelectionStrategy /*algo_selection_strategy*/) const { + MEGDNN_MARK_USED_VAR(opr); + auto&& fm = param.filter_meta; + return param.src_type.enumv() == DTypeEnum::QuantizedS8 && + param.dst_type.enumv() == DTypeEnum::QuantizedS8 && + fm.format == param::ConvBias::Format::NCHW && fm.spatial_ndim == 2 && + fm.dilation[0] == 1 && fm.dilation[1] == 1 && + //! As postprocess, the bias is not contigous read, make the + //! performance bad, so we do not process it in fused kernel + param.bias_mode != BiasMode::BIAS && + //! 
This algo is only support single thread + param.nr_threads == 1_z; +} + +WorkspaceBundle ConvBiasImpl::AlgoS8MatrixMul::get_bundle( + const NCBKernSizeParam& param) { + UNPACK_CONV_NCB_KERN_SIZES(param); + MEGDNN_MARK_USED_VAR(N); + auto IW2 = IH + 2 * PH; + auto IH2 = IW + 2 * PW; + bool can_matrix_mul_direct = + (FH == 1 && FW == 1 && SH == 1 && SW == 1 && PH == 0 && PW == 0); + // temp space to store padding-free src (with 16 extra int8) + // temp space to store unrolled matrix (with 16 extra int8) + // workspace for matrix mul opr + size_t part0, part1, part2; + if (can_matrix_mul_direct) { + part0 = part1 = 0; + } else { + part0 = (IC * IH2 * IW2 + 16) * sizeof(int8_t); + part1 = (IC * FH * FW * OH * OW + 16) * sizeof(int8_t); + } + { + size_t M = OC; + size_t K = IC * FH * FW; + size_t N = OH * OW; + +#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ + _bias_midout_enum, _nonline, \ + _nonline_midout_enum) \ + MIDOUT_BEGIN(megdnn_armv7_conv_bias_int8, 0, _gemm_midout_enum, \ + _bias_midout_enum, _nonline_midout_enum) { \ + matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ + M, N, K, param.filter_type, param.src_type, param.dst_type); \ + part2 = megdnn::matmul::GemmInterleaved< \ + matmul::gemm_##_gemm##_##_bias##_##_nonline>( \ + M, N, K, false, false, strategy) \ + .get_workspace_size(); \ + } \ + MIDOUT_END() + + DISPATCH_GEMM_BIAS(s8_4x2, 0) + +#undef DISPATCH_GEMM_STRATEGY + } + return {nullptr, {part0, part1, part2}}; +} + +void ConvBiasImpl::AlgoS8MatrixMul::kimpl(const NCBKernParam& param, + const NCBKernIndex& ncb_index) { + auto is_xcorr = !param.filter_meta.should_flip; + UNPACK_CONV_NCB_KERN_SIZES(param); + auto bundle = get_bundle(param); + bundle.set(param.workspace_ptr); + auto IH2 = IH + 2 * PH; + auto IW2 = IW + 2 * PW; + size_t group_id = ncb_index.ndrange_id[0]; + // workspace = tmp..src2 + for (size_t n = 0; n < N; ++n) { + dt_int8* src = const_cast(param.src(n, group_id)); + dt_int8* filter = const_cast(param.filter(group_id)); + dt_int8* dst = static_cast(param.dst(n, group_id)); + dt_int32* bias = const_cast(param.bias(n, group_id)); + + dt_int8 *B, *src2; + if (FH == 1 && FW == 1 && SH == 1 && SW == 1 && PH == 0 && PW == 0) { + // special case: 1x1 + B = const_cast(src); + } else { + src2 = static_cast(bundle.get(0)); + // copy src to src2; + dt_int8* src2_ptr = src2; + const dt_int8* src_ptr = src; + rep(ic, IC) { + if (PH != 0) { + std::memset(src2_ptr, 0, sizeof(dt_int8) * PH * IW2); + src2_ptr += PH * IW2; + } + rep(ih, IH) { + if (PW != 0) + rep(pw, PW) { *(src2_ptr++) = 0.0f; } + std::memcpy(src2_ptr, src_ptr, sizeof(dt_int8) * IW); + src2_ptr += IW; + src_ptr += IW; + if (PW != 0) + rep(pw, PW) { *(src2_ptr++) = 0.0f; } + } + if (PH != 0) { + std::memset(src2_ptr, 0, sizeof(dt_int8) * PH * IW2); + src2_ptr += PH * IW2; + } + } + + B = static_cast(bundle.get(1)); + if (SH == 1 && SW == 1) { + if (is_xcorr) + img2col(src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW); + else + img2col(src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW); + } else { + if (is_xcorr) + img2col_stride(src2, B, OC, OH, OW, IC, IH2, IW2, FH, + FW, SH, SW); + else + img2col_stride(src2, B, OC, OH, OW, IC, IH2, IW2, FH, + FW, SH, SW); + } + } + { + Workspace workspace(static_cast(bundle.get(2)), + bundle.get_size(2)); + size_t M = OC; + size_t K = IC * FH * FW; + size_t N = OH * OW; + +#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ + _bias_midout_enum, _nonline, \ + _nonline_midout_enum) \ + MIDOUT_BEGIN(megdnn_armv7_conv_bias_int8, 1, _gemm_midout_enum, \ 
+ _bias_midout_enum, _nonline_midout_enum) { \ + matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ + M, N, K, param.filter_type, param.src_type, param.dst_type); \ + megdnn::matmul::GemmInterleaved< \ + matmul::gemm_##_gemm##_##_bias##_##_nonline> \ + gemm_interleaved(M, N, K, false, false, strategy); \ + gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, \ + bias); \ + } \ + MIDOUT_END() + + DISPATCH_GEMM_BIAS(s8_4x2, 0) +#undef DISPATCH_GEMM_STRATEGY + } + } +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/armv7/conv_bias/int8/algos.h b/dnn/src/armv7/conv_bias/int8/algos.h new file mode 100644 index 00000000..9199584e --- /dev/null +++ b/dnn/src/armv7/conv_bias/int8/algos.h @@ -0,0 +1,47 @@ +/** + * \file dnn/src/armv7/conv_bias/int8/algos.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include "src/armv7/conv_bias/opr_impl.h" +#include "src/fallback/conv_bias/opr_impl.h" + +namespace megdnn { +namespace armv7 { + +using FallbackConvBiasImpl = fallback::ConvBiasImpl; + +class ConvBiasImpl::AlgoS8MatrixMul final : public AlgoBase { + static WorkspaceBundle get_bundle(const NCBKernSizeParam& param); + static void kimpl(const NCBKernParam& param, const NCBKernIndex&); + +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "S8MATMUL"; } + + bool usable(FallbackConvBiasImpl* opr, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; + size_t get_workspace(FallbackConvBiasImpl*, + const NCBKernSizeParam& param) const override { + return get_bundle(param).total_size_in_bytes(); + } + SmallVector dispatch_kerns( + FallbackConvBiasImpl*, + const NCBKernSizeParam& param) const override { + size_t group = param.filter_meta.group; + return {{kimpl, {group, 1_z, 1_z}}}; + } +}; + +} // namespace armv7 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/armv7/conv_bias/int8/strategy.cpp b/dnn/src/armv7/conv_bias/int8/strategy.cpp new file mode 100644 index 00000000..c92fae47 --- /dev/null +++ b/dnn/src/armv7/conv_bias/int8/strategy.cpp @@ -0,0 +1,160 @@ +/** + * \file dnn/src/armv7/conv_bias/int8/strategy.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#include "src/armv7/conv_bias/int8/strategy.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/armv7/matrix_mul/asm/common.h" +#include "src/common/utils.h" +#include "src/fallback/conv_bias/common.h" + +#include "src/arm_common/conv_bias/matmul_postprocess.h" +#include "src/armv7/matrix_mul/int8/kernel_4x2x16.h" + +using namespace megdnn; +using namespace armv7; +using namespace armv7::matmul; + +namespace impl { +template +struct KernCaller; + +template +struct KernCaller { + static void run(const dt_int8* packA, const dt_int8* packB, size_t M, + size_t N, size_t K, dt_int8* C, size_t LDC, bool is_first_k, + Op op, const dt_int32* bias, dt_int32* workspace) { + megdnn_assert(is_first_k); + + constexpr size_t A_INTERLEAVE = 4; + constexpr size_t B_INTERLEAVE = 2; + //! K is packed to times of 4 + K = round_up(K, 16); + const int K4 = K * 4; + const int K2 = K * 2; + + size_t m = 0; + for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) { + int8_t* output = C + (m * LDC); + + size_t n = 0; + const dt_int8* cur_packB = packB; + for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { + matmul_4x2x16::kern_4x2(packA, cur_packB, K, workspace, 2, + is_first_k, 4, 2); + arm_common::ConvBiasMatmul::postprocess(bias, workspace, + output, LDC, op); + output += B_INTERLEAVE; + cur_packB += K2; + } + + for (; n < N; n += B_INTERLEAVE) { + matmul_4x2x16::kern_4x2(packA, cur_packB, K, workspace, 2, + is_first_k, 4, + std::min(N - n, 2)); +#define cb(m, n) \ + arm_common::ConvBiasMatmul::postprocess( \ + bias, workspace, output, LDC, op); + DISPATCH_N(cb, 4, std::min(N - n, 2)); +#undef cb + output += B_INTERLEAVE; + cur_packB += K2; + } + + packA += K4; + if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + bias += A_INTERLEAVE; + } + } + + for (; m < M; m += A_INTERLEAVE) { + int8_t* output = C + (m * LDC); + + size_t n = 0; + const dt_int8* cur_packB = packB; + for (; n < N; n += B_INTERLEAVE) { + matmul_4x2x16::kern_4x2(packA, cur_packB, K, workspace, 2, + is_first_k, std::min(M - m, 4), + std::min(N - n, 2)); +#define cb(m, n) \ + arm_common::ConvBiasMatmul::postprocess( \ + bias, workspace, output, LDC, op); + DISPATCH_M(cb, std::min(M - m, 4), + std::min(N - n, 2)); +#undef cb + + output += B_INTERLEAVE; + cur_packB += K2; + } + packA += K4; + if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + bias += A_INTERLEAVE; + } + } + } +}; + +} // namespace impl + +MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_4x2_nobias_identity) + +void gemm_s8_4x2_nobias_identity::pack_A(dt_int8* outptr, const dt_int8* inptr, + int ldin, int y0, int ymax, int k0, + int kmax, bool /*transpose*/) const { + MEGDNN_MARK_USED_VAR(matmul_4x2x16::gemm_s8_4x2_pack_A_t); + matmul_4x2x16::gemm_s8_4x2_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, + kmax); +} + +void gemm_s8_4x2_nobias_identity::pack_B(dt_int8* out, const dt_int8* in, + int ldin, int x0, int xmax, int k0, + int kmax, bool /*transpose*/) const { + MEGDNN_MARK_USED_VAR(matmul_4x2x16::gemm_s8_4x2_pack_B_t); + matmul_4x2x16::gemm_s8_4x2_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); +} + +size_t gemm_s8_4x2_nobias_identity::get_workspace_size() const { + return 4 * 2 * sizeof(dt_int32); +} + +#define KERN(_bias, _BIAS, _nonline, _OP) \ + void gemm_s8_4x2_##_bias##_##_nonline::kern( \ + const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, \ + size_t K, dt_int8* C, size_t LDC, bool is_first_k, \ + const dt_int32* bias, dt_int32* workspace) const { \ + float scale_A = A_dtype.param().scale; \ + float scale_B = B_dtype.param().scale; \ + float scale_C = 
C_dtype.param().scale; \ + DEFINE_OP(_OP); \ + impl::KernCaller<_BIAS, decltype(op), 4, 2>::run( \ + packA, packB, M, N, K, C, LDC, is_first_k, op, bias, \ + workspace); \ + } + +#define DEFINE_OP(_Op) \ + arm_common::_Op op(scale_A* scale_B, scale_C); + +KERN(nobias, BiasMode::NO_BIAS, identity, TypeCvtOp) +KERN(nobias, BiasMode::NO_BIAS, relu, ReluOp) +KERN(nobias, BiasMode::NO_BIAS, hswish, HSwishOp) +#undef DEFINE_OP + +#define DEFINE_OP(_Op) \ + arm_common::_Op op(scale_A* scale_B, \ + scale_A* scale_B, scale_C); +KERN(bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp) +KERN(bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp) +KERN(bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, FuseAddHSwishOp) +#undef DEFINE_OP + +#undef KERN + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/armv7/conv_bias/int8/strategy.h b/dnn/src/armv7/conv_bias/int8/strategy.h new file mode 100644 index 00000000..5ba29af3 --- /dev/null +++ b/dnn/src/armv7/conv_bias/int8/strategy.h @@ -0,0 +1,46 @@ +/** + * \file dnn/src/armv7/conv_bias/int8/strategy.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once +#include "src/fallback/matrix_mul/gemm_common.h" + +namespace megdnn { +namespace armv7 { +namespace matmul { + +/** + * \brief base strategy of gemm. + * + * \name gemm___biasmode_nolinemode + */ +MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_int8, dt_int8, dt_int32, 4, 2, 16, + false, true, + gemm_s8_4x2_nobias_identity); + +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x2_nobias_relu, + gemm_s8_4x2_nobias_identity); + +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x2_nobias_hswish, + gemm_s8_4x2_nobias_identity); + +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x2_bias_channel_identity, + gemm_s8_4x2_nobias_identity); + +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x2_bias_channel_relu, + gemm_s8_4x2_nobias_identity); + +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x2_bias_channel_hswish, + gemm_s8_4x2_nobias_identity); + +} // namespace matmul +} // namespace armv7 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/armv7/conv_bias/opr_impl.cpp b/dnn/src/armv7/conv_bias/opr_impl.cpp new file mode 100644 index 00000000..76602933 --- /dev/null +++ b/dnn/src/armv7/conv_bias/opr_impl.cpp @@ -0,0 +1,50 @@ +/** + * \file dnn/src/armv7/conv_bias/opr_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */
+
+#include "src/armv7/conv_bias/opr_impl.h"
+#include "src/armv7/conv_bias/int8/algos.h"
+#include "src/armv7/conv_bias/quint8/algos.h"
+#include "src/common/utils.h"
+#include "src/naive/handle.h"
+#include "src/common/metahelper.h"
+
+#include "src/fallback/convolution/opr_impl.h"
+
+using namespace megdnn;
+using namespace armv7;
+
+class ConvBiasImpl::AlgoPack : NonCopyableObj {
+    AlgoS8MatrixMul s8_matrix_mul;
+    AlgoQU8MatrixMul qu8_matrix_mul;
+public:
+    AlgoPack() {
+        all_algos.emplace_back(&qu8_matrix_mul);
+        all_algos.emplace_back(&s8_matrix_mul);
+    }
+    SmallVector all_algos;
+};
+
+SmallVector ConvBiasImpl::algo_pack() {
+    static AlgoPack sl_algo_pack;
+    auto&& algos = arm_common::ConvBiasImpl::algo_pack();
+    //! TODO: fused matmul+bias is slower than matmul + elemwise on armv7 for
+    //! now, and nearly equal on aarch64, because of the waste of registers in
+    //! the postprocess
+    algos.insert(algos.end(), sl_algo_pack.all_algos.begin(),
+                 sl_algo_pack.all_algos.end());
+    return std::move(algos);
+}
+
+const char* ConvBiasImpl::get_algorithm_set_name() const {
+    return "ARMV7";
+}
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/armv7/conv_bias/opr_impl.h b/dnn/src/armv7/conv_bias/opr_impl.h
new file mode 100644
index 00000000..4cbf4b06
--- /dev/null
+++ b/dnn/src/armv7/conv_bias/opr_impl.h
@@ -0,0 +1,37 @@
+/**
+ * \file dnn/src/armv7/conv_bias/opr_impl.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+#pragma once
+#include "src/arm_common/conv_bias/opr_impl.h"
+#include "src/common/utils.h"
+
+namespace megdnn {
+namespace armv7 {
+
+class ConvBiasImpl : public arm_common::ConvBiasImpl {
+public:
+    using arm_common::ConvBiasImpl::ConvBiasImpl;
+
+    SmallVector algo_pack() override;
+
+protected:
+
+    const char* get_algorithm_set_name() const override;
+
+private:
+    class AlgoS8MatrixMul;
+    class AlgoQU8MatrixMul;
+    class AlgoPack;
+};
+
+} // namespace armv7
+} // namespace megdnn
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/armv7/conv_bias/quint8/algos.cpp b/dnn/src/armv7/conv_bias/quint8/algos.cpp
new file mode 100644
index 00000000..648dda05
--- /dev/null
+++ b/dnn/src/armv7/conv_bias/quint8/algos.cpp
@@ -0,0 +1,178 @@
+/**
+ * \file dnn/src/armv7/conv_bias/quint8/algos.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */ + +#include "src/armv7/conv_bias/quint8/algos.h" +#include "src/arm_common/convolution/img2col_helper.h" +#include "src/armv7/conv_bias/quint8/strategy.h" +#include "src/common/opr_delegate.h" +#include "src/fallback/conv_bias/common.h" +#include "src/fallback/matrix_mul/gemm_impl.h" + +#include "midout.h" + +MIDOUT_DECL(megdnn_armv7_conv_bias_quint8) + +using namespace megdnn; +using namespace armv7; + +/* ===================== matrix mul algo ===================== */ + +bool ConvBiasImpl::AlgoQU8MatrixMul::usable( + FallbackConvBiasImpl* opr, const NCBKernSizeParam& param, + AlgoSelectionStrategy /*algo_selection_strategy*/) const { + MEGDNN_MARK_USED_VAR(opr); + auto&& fm = param.filter_meta; + return param.src_type.enumv() == DTypeEnum::Quantized8Asymm && + param.dst_type.enumv() == DTypeEnum::Quantized8Asymm && + fm.format == param::ConvBias::Format::NCHW && fm.spatial_ndim == 2 && + fm.dilation[0] == 1 && fm.dilation[1] == 1 && + //! As postprocess, the bias is not contigous read, make the + //! performance bad, so we do not process it in fused kernel + param.bias_mode != BiasMode::BIAS && + //! This algo is only support single thread + param.nr_threads == 1_z; +} + +WorkspaceBundle ConvBiasImpl::AlgoQU8MatrixMul::get_bundle( + const NCBKernSizeParam& param) { + UNPACK_CONV_NCB_KERN_SIZES(param); + MEGDNN_MARK_USED_VAR(N); + auto IW2 = IH + 2 * PH; + auto IH2 = IW + 2 * PW; + bool can_matrix_mul_direct = + (FH == 1 && FW == 1 && SH == 1 && SW == 1 && PH == 0 && PW == 0); + // temp space to store padding-free src (with 16 extra int8) + // temp space to store unrolled matrix (with 16 extra int8) + // workspace for matrix mul opr + size_t part0, part1, part2; + if (can_matrix_mul_direct) { + part0 = part1 = 0; + } else { + part0 = (IC * IH2 * IW2 + 16) * sizeof(uint8_t); + part1 = (IC * FH * FW * OH * OW + 16) * sizeof(uint8_t); + } + { + size_t M = OC; + size_t K = IC * FH * FW; + size_t N = OH * OW; + +#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ + _bias_midout_enum, _nonline, \ + _nonline_midout_enum) \ + MIDOUT_BEGIN(megdnn_armv7_conv_bias_quint8, 0, _gemm_midout_enum, \ + _bias_midout_enum, _nonline_midout_enum) { \ + matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ + M, N, K, param.filter_type, param.src_type, param.dst_type); \ + part2 = megdnn::matmul::GemmInterleaved< \ + matmul::gemm_##_gemm##_##_bias##_##_nonline>( \ + M, N, K, false, false, strategy) \ + .get_workspace_size(); \ + } \ + MIDOUT_END() + + DISPATCH_GEMM_BIAS(u8_4x8, 0) + +#undef DISPATCH_GEMM_STRATEGY + } + return {nullptr, {part0, part1, part2}}; +} + +void ConvBiasImpl::AlgoQU8MatrixMul::kimpl(const NCBKernParam& param, + const NCBKernIndex& ncb_index) { + auto is_xcorr = !param.filter_meta.should_flip; + UNPACK_CONV_NCB_KERN_SIZES(param); + auto bundle = get_bundle(param); + bundle.set(param.workspace_ptr); + auto IH2 = IH + 2 * PH; + auto IW2 = IW + 2 * PW; + size_t group_id = ncb_index.ndrange_id[0]; + uint8_t src_zp = param.src_type.param().zero_point; + // workspace = tmp..src2 + for (size_t n = 0; n < N; ++n) { + uint8_t* src = const_cast(param.src(n, group_id)); + uint8_t* filter = const_cast(param.filter(group_id)); + uint8_t* dst = static_cast(param.dst(n, group_id)); + int32_t* bias = const_cast(param.bias(n, group_id)); + + uint8_t *B, *src2; + + if (FH == 1 && FW == 1 && SH == 1 && SW == 1 && PH == 0 && PW == 0) { + // special case: 1x1 + B = const_cast(src); + } else { + src2 = static_cast(bundle.get(0)); + // copy src to src2; + uint8_t* src2_ptr = src2; + const 
uint8_t* src_ptr = src; + rep(ic, IC) { + if (PH != 0) { + std::memset(src2_ptr, src_zp, sizeof(uint8_t) * PH * IW2); + src2_ptr += PH * IW2; + } + rep(ih, IH) { + if (PW != 0) + rep(pw, PW) { *(src2_ptr++) = src_zp; } + std::memcpy(src2_ptr, src_ptr, sizeof(uint8_t) * IW); + src2_ptr += IW; + src_ptr += IW; + if (PW != 0) + rep(pw, PW) { *(src2_ptr++) = src_zp; } + } + if (PH != 0) { + std::memset(src2_ptr, src_zp, sizeof(uint8_t) * PH * IW2); + src2_ptr += PH * IW2; + } + } + + B = static_cast(bundle.get(1)); + if (SH == 1 && SW == 1) { + if (is_xcorr) + img2col(src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW); + else + img2col(src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW); + } else { + if (is_xcorr) + img2col_stride(src2, B, OC, OH, OW, IC, IH2, IW2, FH, + FW, SH, SW); + else + img2col_stride(src2, B, OC, OH, OW, IC, IH2, IW2, FH, + FW, SH, SW); + } + } + { + Workspace workspace(static_cast(bundle.get(2)), + bundle.get_size(2)); + size_t M = OC; + size_t K = IC * FH * FW; + size_t N = OH * OW; + +#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ + _bias_midout_enum, _nonline, \ + _nonline_midout_enum) \ + MIDOUT_BEGIN(megdnn_armv7_conv_bias_quint8, 1, _gemm_midout_enum, \ + _bias_midout_enum, _nonline_midout_enum) { \ + matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ + M, N, K, param.filter_type, param.src_type, param.dst_type); \ + megdnn::matmul::GemmInterleaved< \ + matmul::gemm_##_gemm##_##_bias##_##_nonline> \ + gemm_interleaved(M, N, K, false, false, strategy); \ + gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, \ + bias); \ + } \ + MIDOUT_END() + + DISPATCH_GEMM_BIAS(u8_4x8, 0) +#undef DISPATCH_GEMM_STRATEGY + } + } +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/armv7/conv_bias/quint8/algos.h b/dnn/src/armv7/conv_bias/quint8/algos.h new file mode 100644 index 00000000..d7399cd1 --- /dev/null +++ b/dnn/src/armv7/conv_bias/quint8/algos.h @@ -0,0 +1,48 @@ +/** + * \file dnn/src/armv7/conv_bias/quint8/algos.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#pragma once + +#include "src/armv7/conv_bias/opr_impl.h" +#include "src/fallback/conv_bias/opr_impl.h" + +namespace megdnn { +namespace armv7 { + +using FallbackConvBiasImpl = fallback::ConvBiasImpl; + +class ConvBiasImpl::AlgoQU8MatrixMul final : public AlgoBase { + static WorkspaceBundle get_bundle(const NCBKernSizeParam& param); + static void kimpl(const NCBKernParam& param, const NCBKernIndex&); + +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "QU8MATMUL"; } + + bool usable(FallbackConvBiasImpl* opr, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; + size_t get_workspace(FallbackConvBiasImpl*, + const NCBKernSizeParam& param) const override { + return get_bundle(param).total_size_in_bytes(); + } + + SmallVector dispatch_kerns( + fallback::ConvBiasImpl* /*opr*/, + const NCBKernSizeParam& param) const override { + size_t group = param.filter_meta.group; + return {{kimpl, {group, 1_z, 1_z}}}; + } +}; + +} // namespace armv7 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/armv7/conv_bias/quint8/strategy.cpp b/dnn/src/armv7/conv_bias/quint8/strategy.cpp new file mode 100644 index 00000000..d3385015 --- /dev/null +++ b/dnn/src/armv7/conv_bias/quint8/strategy.cpp @@ -0,0 +1,156 @@ +/** + * \file dnn/src/armv7/conv_bias/quint8/strategy.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "src/armv7/conv_bias/quint8/strategy.h" +#include "src/armv7/matrix_mul/asm/common.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/utils.h" +#include "src/fallback/conv_bias/common.h" + +#include "src/armv7/matrix_mul/quint8/kernel_4x8x8.h" +#include "src/arm_common/conv_bias/matmul_postprocess.h" + +using namespace megdnn; +using namespace armv7; +using namespace armv7::matmul; + +namespace impl { +template +struct KernCaller; + +template +struct KernCaller { + static void run(const dt_uint8* packA, const dt_uint8* packB, size_t M, + size_t N, size_t K, dt_uint8* C, size_t LDC, + bool is_first_k, Op op, const dt_int32* bias, + dt_int32* workspace, uint8_t zp_A, uint8_t zp_B) { + megdnn_assert(is_first_k); + + constexpr size_t A_INTERLEAVE = 4; + constexpr size_t B_INTERLEAVE = 8; + //! 
K is packed to times of 8 + K = round_up(K, 8); + const int K8 = K * 8; + const int K4 = K * 4; + + size_t m = 0; + for (; m < M; m += A_INTERLEAVE) { + uint8_t* output = C + (m * LDC); + const dt_uint8* cur_packB = packB; + size_t n = 0; + for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { + matmul_4x8x8::kern_4x8(packA, cur_packB, K, workspace, 8, + is_first_k, std::min(M - m, 4), + zp_A, zp_B); +#define cb(m, n) \ + arm_common::ConvBiasMatmul::postprocess( \ + bias, workspace, output, LDC, op); + DISPATCH_M_N(cb, std::min(M - m, 4), 8); +#undef cb + + output += B_INTERLEAVE; + cur_packB += K8; + } + + for (; n < N; n += 4) { + matmul_4x8x8::kern_4x4(packA, cur_packB, K, workspace, 4, + is_first_k, std::min(M - m, 4), + std::min(N - n, 4), zp_A, zp_B); +#define cb(m, n) \ + arm_common::ConvBiasMatmul::postprocess( \ + bias, workspace, output, LDC, op); + DISPATCH_M(cb, std::min(M - m, 4), + std::min(N - n, 4)); +#undef cb + + output += 4; + cur_packB += K4; + + } + packA += K4; + if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + bias += 4; + } + } + } +}; + +} // namespace impl + +MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_4x8_nobias_identity); +void gemm_u8_4x8_nobias_identity::pack_A(dt_uint8* outptr, + const dt_uint8* inptr, int ldin, + int y0, int ymax, int k0, int kmax, + bool transpose) const { + uint8_t zA = A_dtype.param().zero_point; + if (transpose) { + matmul_4x8x8::gemm_u8_4x8_transpose_pack_A_n(outptr, inptr, ldin, y0, + ymax, k0, kmax, zA); + } else { + matmul_4x8x8::gemm_u8_4x8_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, + kmax, zA); + } +} + +void gemm_u8_4x8_nobias_identity::pack_B(dt_uint8* out, const dt_uint8* in, + int ldin, int x0, int xmax, int k0, + int kmax, bool transpose) const { + uint8_t zB = B_dtype.param().zero_point; + if (transpose) { + matmul_4x8x8::gemm_u8_4x8_transpose_pack_B_n(out, in, ldin, x0, xmax, + k0, kmax, zB); + } else { + matmul_4x8x8::gemm_u8_4x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax, + zB); + } +} + +size_t gemm_u8_4x8_nobias_identity::get_workspace_size() const { + return 4 * 8 * sizeof(dt_int32); +} + +#define KERN(_block_m, _block_n, _bias, _BIAS, _nonline, _OP) \ + void gemm_u8_##_block_m##x##_block_n##_##_bias##_##_nonline::kern( \ + const dt_uint8* packA, const dt_uint8* packB, size_t M, size_t N, \ + size_t K, dt_uint8* C, size_t LDC, bool is_first_k, \ + const dt_int32* bias, dt_int32* workspace) const { \ + float scale_A = A_dtype.param().scale; \ + uint8_t zp_A = A_dtype.param().zero_point; \ + float scale_B = B_dtype.param().scale; \ + uint8_t zp_B = B_dtype.param().zero_point; \ + float scale_C = C_dtype.param().scale; \ + uint8_t zp_C = C_dtype.param().zero_point; \ + DEFINE_OP(_OP); \ + impl::KernCaller<_BIAS, decltype(op), _block_m, _block_n>::run( \ + packA, packB, M, N, K, C, LDC, is_first_k, op, bias, \ + workspace, zp_A, zp_B); \ + } + +#define DEFINE_OP(_Op) \ + arm_common::_Op op(scale_A* scale_B, scale_C, zp_C); + +KERN(4, 8, nobias, BiasMode::NO_BIAS, identity, TypeCvtOp) +KERN(4, 8, nobias, BiasMode::NO_BIAS, relu, ReluOp) +KERN(4, 8, nobias, BiasMode::NO_BIAS, hswish, HSwishOp) +#undef DEFINE_OP + +#define DEFINE_OP(_Op) \ + arm_common::_Op op(scale_A* scale_B, \ + scale_A* scale_B, scale_C, zp_C); +KERN(4, 8, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp) +KERN(4, 8, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp) +KERN(4, 8, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, + FuseAddHSwishOp) +#undef DEFINE_OP + +#undef KERN + +// vim: syntax=cpp.doxygen diff --git 
a/dnn/src/armv7/conv_bias/quint8/strategy.h b/dnn/src/armv7/conv_bias/quint8/strategy.h
new file mode 100644
index 00000000..c34d0776
--- /dev/null
+++ b/dnn/src/armv7/conv_bias/quint8/strategy.h
@@ -0,0 +1,46 @@
+/**
+ * \file dnn/src/armv7/conv_bias/quint8/strategy.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+#pragma once
+#include "src/fallback/matrix_mul/gemm_common.h"
+
+namespace megdnn {
+namespace armv7 {
+namespace matmul {
+
+/**
+ * \brief base strategy of gemm.
+ *
+ * \name gemm___biasmode_nolinemode
+ */
+MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_uint8, dt_uint8, dt_int32, 4, 8, 8,
+                                        false, true,
+                                        gemm_u8_4x8_nobias_identity);
+
+MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_4x8_nobias_relu,
+                                    gemm_u8_4x8_nobias_identity);
+
+MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_4x8_nobias_hswish,
+                                    gemm_u8_4x8_nobias_identity);
+
+MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_4x8_bias_channel_identity,
+                                    gemm_u8_4x8_nobias_identity);
+
+MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_4x8_bias_channel_relu,
+                                    gemm_u8_4x8_nobias_identity);
+
+MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_4x8_bias_channel_hswish,
+                                    gemm_u8_4x8_nobias_identity);
+
+} // namespace matmul
+} // namespace armv7
+} // namespace megdnn
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/armv7/handle.cpp b/dnn/src/armv7/handle.cpp
new file mode 100644
index 00000000..7a381a94
--- /dev/null
+++ b/dnn/src/armv7/handle.cpp
@@ -0,0 +1,43 @@
+/**
+ * \file dnn/src/armv7/handle.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#include "src/common/handle_impl.h"
+
+#include "src/armv7/handle.h"
+
+#include "src/armv7/matrix_mul/opr_impl.h"
+#include "src/armv7/rotate/opr_impl.h"
+#include "src/armv7/relayout/opr_impl.h"
+#include "src/armv7/conv_bias/opr_impl.h"
+
+namespace megdnn {
+namespace armv7 {
+
+template <typename Opr>
+std::unique_ptr<Opr> HandleImpl::create_operator() {
+    return arm_common::HandleImpl::create_operator<Opr>();
+}
+
+MEGDNN_SPECIALIZE_CREATE_OPERATOR(MatrixMul)
+MEGDNN_SPECIALIZE_CREATE_OPERATOR(Rotate)
+MEGDNN_SPECIALIZE_CREATE_OPERATOR(RelayoutForward)
+MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvBias)
+
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wpragmas"
+#pragma GCC diagnostic ignored "-Winstantiation-after-specialization"
+MEGDNN_FOREACH_OPR_CLASS(MEGDNN_INST_CREATE_OPERATOR)
+#pragma GCC diagnostic pop
+
+} // namespace armv7
+} // namespace megdnn
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/armv7/handle.h b/dnn/src/armv7/handle.h
new file mode 100644
index 00000000..7c5b03d7
--- /dev/null
+++ b/dnn/src/armv7/handle.h
@@ -0,0 +1,34 @@
+/**
+ * \file dnn/src/armv7/handle.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once +#include "src/arm_common/handle.h" + +namespace megdnn { +namespace armv7 { + +class HandleImpl: public arm_common::HandleImpl { + public: + HandleImpl(megcoreComputingHandle_t computing_handle, + HandleType type = HandleType::ARMV7): + arm_common::HandleImpl::HandleImpl(computing_handle, type) + { + } + + template + std::unique_ptr create_operator(); +}; + +} // namespace armv7 +} // namespace megdnn + +// vim: syntax=cpp.doxygen + + diff --git a/dnn/src/armv7/matrix_mul/algos.cpp b/dnn/src/armv7/matrix_mul/algos.cpp new file mode 100644 index 00000000..59b327ec --- /dev/null +++ b/dnn/src/armv7/matrix_mul/algos.cpp @@ -0,0 +1,888 @@ +/** + * \file dnn/src/armv7/matrix_mul/algos.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "src/armv7/matrix_mul/algos.h" +#include "src/armv7/matrix_mul/fp16/strategy.h" +#include "src/armv7/matrix_mul/fp32/strategy.h" +#include "src/armv7/matrix_mul/int16x16x32/strategy.h" +#include "src/armv7/matrix_mul/int8/strategy.h" +#include "src/armv7/matrix_mul/int8x8x16/strategy.h" +#include "src/armv7/matrix_mul/quint8/strategy.h" +#include "src/common/utils.h" +#include "src/fallback/matrix_mul/gemm_impl.h" + +#include "midout.h" + +MIDOUT_DECL(megdnn_armv7_matmul_kern) + +using namespace megdnn; +using namespace armv7; + +/* ===================== F32 algo ===================== */ + +namespace { +void f32_kern(const MatrixMulImpl::KernParam& kern_param) { + MIDOUT_BEGIN(megdnn_armv7_matmul_kern, midout_iv("f32_kern"_hash)) { + auto M = kern_param.M, N = kern_param.N, K = kern_param.K; + auto trA = kern_param.trA, trB = kern_param.trB; + auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; + auto A_type = kern_param.A_type, B_type = kern_param.B_type, + C_type = kern_param.C_type; + const auto Aptr = kern_param.A(), Bptr = kern_param.B(); + auto Cptr = kern_param.C(); + + armv7::matmul::sgemm_4x12 strategy(M, N, K, A_type, B_type, C_type); + megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, + kern_param.workspace_ptr); + } + MIDOUT_END(); +} + +} // anonymous namespace + +bool MatrixMulImpl::AlgoF32::usable( + const KernSizeParam& kern_size_param) const { + return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && + kern_size_param.format == param::MatrixMul::Format::DEFAULT && + kern_size_param.B_type == kern_size_param.A_type && + kern_size_param.C_type == kern_size_param.A_type && + kern_size_param.A_type == dtype::Float32(); +} + +size_t MatrixMulImpl::AlgoF32::get_workspace( + const KernSizeParam& kern_size_param) const { + MIDOUT_BEGIN(megdnn_armv7_matmul_kern, + midout_iv("AlgoF32::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, + K = kern_size_param.K; + auto trA = kern_size_param.trA, trB = kern_size_param.trB; + auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, + C_type = kern_size_param.C_type; + armv7::matmul::sgemm_4x12 strategy(M, N, K, 
A_type, B_type, C_type); + return megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .get_workspace_size(); + } + MIDOUT_END(); +} + +MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32::get_kern( + const KernSizeParam&) const { + return f32_kern; +} + +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32, megdnn_armv7_matmul_kern, + "AlgoF32Impl"_hash, + armv7::matmul::sgemm_4x12, float, float); + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +/* ===================== F16 K4x16x1 algo ===================== */ +namespace { +void f16_kern(const MatrixMulImpl::KernParam& kern_param) { + MIDOUT_BEGIN(megdnn_armv7_matmul_kern, midout_iv("f16_kern"_hash)) { + auto M = kern_param.M, N = kern_param.N, K = kern_param.K; + auto trA = kern_param.trA, trB = kern_param.trB; + auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; + auto A_type = kern_param.A_type, B_type = kern_param.B_type, + C_type = kern_param.C_type; + const auto Aptr = kern_param.A(), + Bptr = kern_param.B(); + auto Cptr = kern_param.C(); + + armv7::matmul::hgemm_4x16 strategy(M, N, K, A_type, B_type, C_type); + megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, + kern_param.workspace_ptr); + } + MIDOUT_END(); +} +} // anonymous namespace + +bool MatrixMulImpl::AlgoF16K4x16x1::usable( + const KernSizeParam& kern_size_param) const { + return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && + kern_size_param.format == param::MatrixMul::Format::DEFAULT && + kern_size_param.C_type == kern_size_param.A_type && + kern_size_param.B_type == kern_size_param.A_type && + kern_size_param.A_type == dtype::Float16(); +} + +size_t MatrixMulImpl::AlgoF16K4x16x1::get_workspace( + const KernSizeParam& kern_size_param) const { + MIDOUT_BEGIN(megdnn_armv7_matmul_kern, + midout_iv("AlgoF16K4x16x1::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, + K = kern_size_param.K; + auto trA = kern_size_param.trA, trB = kern_size_param.trB; + auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, + C_type = kern_size_param.C_type; + armv7::matmul::hgemm_4x16 strategy(M, N, K, A_type, B_type, C_type); + return megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .get_workspace_size(); + } + MIDOUT_END(); +} + +MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16K4x16x1::get_kern( + const KernSizeParam&) const { + return f16_kern; +} + +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF16K4x16x1, megdnn_armv7_matmul_kern, + "AlgoF16K4x16x1"_hash, + armv7::matmul::hgemm_4x16, dt_float16, + dt_float16); + +#endif + +/* ===================== Int8x8x32 Kernel 4x2x16 algo ===================== */ + +namespace { +void kern_int8x8x32_k4x2x16(const MatrixMulImpl::KernParam& kern_param) { + MIDOUT_BEGIN(megdnn_armv7_matmul_kern, + midout_iv("kern_int8x8x32_k4x2x16"_hash)) { + auto M = kern_param.M, N = kern_param.N, K = kern_param.K; + auto Aptr = kern_param.A(), Bptr = kern_param.B(); + auto Cptr = kern_param.C(); + auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; + auto trA = kern_param.trA, trB = kern_param.trB; + + armv7::matmul::gemm_s8_4x2 strategy(M, N, K, kern_param.A_type, + kern_param.B_type, + kern_param.C_type); + megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, + kern_param.workspace_ptr); + } + MIDOUT_END(); +} +} // anonymous namespace + +bool MatrixMulImpl::AlgoInt8x8x32K4x2x16::usable( + const KernSizeParam& kern_size_param) const { + return 
can_be_treated_as_int8x8x32(kern_size_param); +} + +bool MatrixMulImpl::AlgoInt8x8x32K4x2x16::preferred( + const KernSizeParam& kern_size_param) const { + return kern_size_param.K > 32; +} + +size_t MatrixMulImpl::AlgoInt8x8x32K4x2x16::get_workspace( + const KernSizeParam& kern_size_param) const { + MIDOUT_BEGIN(megdnn_armv7_matmul_kern, + midout_iv("AlgoInt8x8x32K4x2x16::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, + K = kern_size_param.K; + auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, + C_type = kern_size_param.C_type; + auto trA = kern_size_param.trA, trB = kern_size_param.trB; + matmul::gemm_s8_4x2 strategy(M, N, K, A_type, B_type, C_type); + return megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .get_workspace_size(); + } + MIDOUT_END(); +} + +MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32K4x2x16::get_kern( + const KernSizeParam&) const { + return kern_int8x8x32_k4x2x16; +} + +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K4x2x16, + megdnn_armv7_matmul_kern, + "AlgoInt8x8x32K4x2x16"_hash, + armv7::matmul::gemm_s8_4x2, int8_t, + int32_t); +/* ===================== Int8x8x32 Kernel 4x8x8 algo ===================== */ + +namespace { +void kern_int8x8x32_k4x8x8(const MatrixMulImpl::KernParam& kern_param) { + MIDOUT_BEGIN(megdnn_armv7_matmul_kern, + midout_iv("kern_int8x8x32_k4x8x8"_hash)) { + auto M = kern_param.M, N = kern_param.N, K = kern_param.K; + auto Aptr = kern_param.A(), Bptr = kern_param.B(); + auto Cptr = kern_param.C(); + auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; + auto trA = kern_param.trA, trB = kern_param.trB; + + armv7::matmul::gemm_s8_4x8 strategy(M, N, K, kern_param.A_type, + kern_param.B_type, + kern_param.C_type); + megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, + kern_param.workspace_ptr); + } + MIDOUT_END(); +} +} // anonymous namespace + +bool MatrixMulImpl::AlgoInt8x8x32K4x8x8::usable( + const KernSizeParam& kern_size_param) const { + return can_be_treated_as_int8x8x32(kern_size_param); +} + +bool MatrixMulImpl::AlgoInt8x8x32K4x8x8::preferred( + const KernSizeParam& kern_size_param) const { + return kern_size_param.K <= 32; +} + +size_t MatrixMulImpl::AlgoInt8x8x32K4x8x8::get_workspace( + const KernSizeParam& kern_size_param) const { + MIDOUT_BEGIN(megdnn_armv7_matmul_kern, + midout_iv("AlgoInt8x8x32K4x8x8::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, + K = kern_size_param.K; + auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, + C_type = kern_size_param.C_type; + auto trA = kern_size_param.trA, trB = kern_size_param.trB; + matmul::gemm_s8_4x8 strategy(M, N, K, A_type, B_type, C_type); + return megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .get_workspace_size(); + } + MIDOUT_END(); +} + +MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32K4x8x8::get_kern( + const KernSizeParam&) const { + return kern_int8x8x32_k4x8x8; +} + +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K4x8x8, + megdnn_armv7_matmul_kern, + "AlgoInt8x8x32K4x8x8"_hash, + armv7::matmul::gemm_s8_4x8, int8_t, + int32_t); +/* ===================== Quint8 Kernel 4x8x8 algo ===================== */ + +namespace { +void kern_quint8_k4x8x8(const MatrixMulImpl::KernParam& kern_param) { + MIDOUT_BEGIN(megdnn_armv7_matmul_kern, + midout_iv("kern_quint8_k4x8x8"_hash)) { + auto M = kern_param.M, N = kern_param.N, K = kern_param.K; + auto Aptr = kern_param.A(), Bptr = 
kern_param.B(); + auto Cptr = kern_param.C(); + auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; + auto trA = kern_param.trA, trB = kern_param.trB; + + armv7::matmul::gemm_u8_4x8 strategy(M, N, K, kern_param.A_type, + kern_param.B_type, + kern_param.C_type); + megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, + kern_param.workspace_ptr); + } + MIDOUT_END(); +} +} // anonymous namespace + +bool MatrixMulImpl::AlgoQuint8K4x8x8::usable( + const KernSizeParam& kern_size_param) const { + return kern_size_param.A_type.enumv() == DTypeEnum::Quantized8Asymm && + kern_size_param.B_type.enumv() == DTypeEnum::Quantized8Asymm && + kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32 && + kern_size_param.format == param::MatrixMul::Format::DEFAULT && + kern_size_param.compute_mode == Param::ComputeMode::DEFAULT; +} + +size_t MatrixMulImpl::AlgoQuint8K4x8x8::get_workspace( + const KernSizeParam& kern_size_param) const { + MIDOUT_BEGIN(megdnn_armv7_matmul_kern, + midout_iv("AlgoQuint8K4x8x8::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, + K = kern_size_param.K; + auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, + C_type = kern_size_param.C_type; + auto trA = kern_size_param.trA, trB = kern_size_param.trB; + matmul::gemm_u8_4x8 strategy(M, N, K, A_type, B_type, C_type); + return megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .get_workspace_size(); + } + MIDOUT_END(); +} + +MatrixMulImpl::kern_t MatrixMulImpl::AlgoQuint8K4x8x8::get_kern( + const KernSizeParam&) const { + return kern_quint8_k4x8x8; +} + +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8K4x8x8, megdnn_armv7_matmul_kern, + "AlgoQuint8K4x8x8"_hash, + armv7::matmul::gemm_u8_4x8, uint8_t, + int32_t); +/* ===================== Int8x8x16 Kernel 2x4x16 algo ===================== */ + +namespace { +void kern_int8x8x16_k2x4x16(const MatrixMulImpl::KernParam& kern_param) { + MIDOUT_BEGIN(megdnn_armv7_matmul_kern, + midout_iv("kern_int8x8x16_k2x4x16"_hash)) { + auto M = kern_param.M, N = kern_param.N, K = kern_param.K; + auto Aptr = kern_param.A(), Bptr = kern_param.B(); + auto Cptr = kern_param.C(); + auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; + auto trA = kern_param.trA, trB = kern_param.trB; + + armv7::matmul::gemm_s8x8x16_4x2 strategy(M, N, K, kern_param.A_type, + kern_param.B_type, + kern_param.C_type); + megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, + kern_param.workspace_ptr); + } + MIDOUT_END(); +} +} // anonymous namespace + +bool MatrixMulImpl::AlgoInt8x8x16K4x2x16::usable( + const KernSizeParam& kern_size_param) const { + return kern_size_param.A_type == kern_size_param.B_type && + kern_size_param.A_type == dtype::Int8() && + kern_size_param.C_type == dtype::Int16() && + kern_size_param.format == param::MatrixMul::Format::DEFAULT && + kern_size_param.compute_mode == Param::ComputeMode::DEFAULT; +} + +size_t MatrixMulImpl::AlgoInt8x8x16K4x2x16::get_workspace( + const KernSizeParam& kern_size_param) const { + MIDOUT_BEGIN(megdnn_armv7_matmul_kern, + midout_iv("AlgoInt8x8x16K4x2x16::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, + K = kern_size_param.K; + auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, + C_type = kern_size_param.C_type; + auto trA = kern_size_param.trA, trB = kern_size_param.trB; + matmul::gemm_s8x8x16_4x2 strategy(M, N, K, A_type, 
B_type, C_type); + return megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .get_workspace_size(); + } + MIDOUT_END(); +} + +MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16K4x2x16::get_kern( + const KernSizeParam&) const { + return kern_int8x8x16_k2x4x16; +} + +bool MatrixMulImpl::AlgoInt8x8x16K4x2x16::preferred( + const KernSizeParam& kern_size_param) const { + return kern_size_param.K > 128; +} + +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16K4x2x16, + megdnn_armv7_matmul_kern, + "AlgoInt8x8x16K4x2x16"_hash, + armv7::matmul::gemm_s8x8x16_4x2, int8_t, + int16_t); +/* ===================== Int8x8x16 Kernel 4x8x8 algo ===================== */ + +namespace { +void kern_int8x8x16_k4x8x8(const MatrixMulImpl::KernParam& kern_param) { + MIDOUT_BEGIN(megdnn_armv7_matmul_kern, + midout_iv("kern_int8x8x16_k4x8x8"_hash)) { + auto M = kern_param.M, N = kern_param.N, K = kern_param.K; + auto Aptr = kern_param.A(), Bptr = kern_param.B(); + auto Cptr = kern_param.C(); + auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; + auto trA = kern_param.trA, trB = kern_param.trB; + + armv7::matmul::gemm_s8x8x16_4x8 strategy(M, N, K, kern_param.A_type, + kern_param.B_type, + kern_param.C_type); + megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, + kern_param.workspace_ptr); + } + MIDOUT_END(); +} +} // anonymous namespace + +bool MatrixMulImpl::AlgoInt8x8x16K4x8x8::usable( + const KernSizeParam& kern_size_param) const { + return kern_size_param.A_type == kern_size_param.B_type && + kern_size_param.A_type == dtype::Int8() && + kern_size_param.C_type == dtype::Int16() && + kern_size_param.format == param::MatrixMul::Format::DEFAULT && + kern_size_param.compute_mode == Param::ComputeMode::DEFAULT; +} + +size_t MatrixMulImpl::AlgoInt8x8x16K4x8x8::get_workspace( + const KernSizeParam& kern_size_param) const { + MIDOUT_BEGIN(megdnn_armv7_matmul_kern, + midout_iv("AlgoInt8x8x16K4x8x8::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, + K = kern_size_param.K; + auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, + C_type = kern_size_param.C_type; + auto trA = kern_size_param.trA, trB = kern_size_param.trB; + matmul::gemm_s8x8x16_4x8 strategy(M, N, K, A_type, B_type, C_type); + return megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .get_workspace_size(); + } + MIDOUT_END(); +} + +MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16K4x8x8::get_kern( + const KernSizeParam&) const { + return kern_int8x8x16_k4x8x8; +} + +bool MatrixMulImpl::AlgoInt8x8x16K4x8x8::preferred( + const KernSizeParam& kern_size_param) const { + return kern_size_param.K >= 8 && kern_size_param.K <= 128; +} + +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16K4x8x8, + megdnn_armv7_matmul_kern, + "AlgoInt8x8x16K4x8x8"_hash, + armv7::matmul::gemm_s8x8x16_4x8, int8_t, + int16_t); +/* ===================== Int16x16x32 Kernel 12x4x1 algo ===================== */ + +namespace { +void kern_int16x16x32K12x4x1(const MatrixMulImpl::KernParam& kern_param) { + MIDOUT_BEGIN(megdnn_armv7_matmul_kern, + midout_iv("kern_int16x16x32K12x4x1"_hash)) { + auto M = kern_param.M, N = kern_param.N, K = kern_param.K; + auto Aptr = kern_param.A(), Bptr = kern_param.B(); + auto Cptr = kern_param.C(); + auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; + auto trA = kern_param.trA, trB = kern_param.trB; + + armv7::matmul::gemm_s16x16x32_12x4 strategy(M, N, K, kern_param.A_type, + kern_param.B_type, 
+ kern_param.C_type); + megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, + kern_param.workspace_ptr); + } + MIDOUT_END(); +} +} // anonymous namespace +bool MatrixMulImpl::AlgoInt16x16x32K12x4x1::usable( + const KernSizeParam& kern_size_param) const { + return kern_size_param.A_type == kern_size_param.B_type && + kern_size_param.A_type == dtype::Int16() && + kern_size_param.C_type == dtype::Int32() && + kern_size_param.format == param::MatrixMul::Format::DEFAULT && + kern_size_param.compute_mode == Param::ComputeMode::DEFAULT; +} + +size_t MatrixMulImpl::AlgoInt16x16x32K12x4x1::get_workspace( + const KernSizeParam& kern_size_param) const { + MIDOUT_BEGIN(megdnn_armv7_matmul_kern, + midout_iv("AlgoInt16x16x32K12x4x1::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, + K = kern_size_param.K; + auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, + C_type = kern_size_param.C_type; + auto trA = kern_size_param.trA, trB = kern_size_param.trB; + matmul::gemm_s16x16x32_12x4 strategy(M, N, K, A_type, B_type, C_type); + return megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .get_workspace_size(); + } + MIDOUT_END(); +} + +MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt16x16x32K12x4x1::get_kern( + const KernSizeParam&) const { + return kern_int16x16x32K12x4x1; +} + +bool MatrixMulImpl::AlgoInt16x16x32K12x4x1::preferred( + const KernSizeParam& /*kern_size_param*/) const { + return true; +} + +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt16x16x32K12x4x1, + megdnn_armv7_matmul_kern, + "AlgoInt16x16x32K12x4x1"_hash, + armv7::matmul::gemm_s16x16x32_12x4, + int16_t, int32_t); +#if __ARM_FEATURE_DOTPROD +/* ===================== Int8 K6x8x4 algo ===================== */ +namespace { +void int8_k6x8x4_kern(const MatrixMulImpl::KernParam& kern_param) { + MIDOUT_BEGIN(megdnn_armv7_matmul_kern, midout_iv("int8_k6x8x4_kern"_hash)) { + auto M = kern_param.M, N = kern_param.N, K = kern_param.K; + auto trA = kern_param.trA, trB = kern_param.trB; + auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; + auto A_type = kern_param.A_type, B_type = kern_param.B_type, + C_type = kern_param.C_type; + const auto Aptr = kern_param.A(), + Bptr = kern_param.B(); + auto Cptr = kern_param.C(); + armv7::matmul::gemm_dots8_6x8 strategy(M, N, K, A_type, B_type, C_type); + megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, + kern_param.workspace_ptr); + } + MIDOUT_END(); +} +} // namespace + +bool MatrixMulImpl::AlgoInt8x8x32K6x8x4::usable( + const KernSizeParam& kern_size_param) const { + return can_be_treated_as_int8x8x32(kern_size_param); +} + +size_t MatrixMulImpl::AlgoInt8x8x32K6x8x4::get_workspace( + const KernSizeParam& kern_size_param) const { + MIDOUT_BEGIN(megdnn_armv7_matmul_kern, + midout_iv("AlgoInt8x8x32K6x8x4::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, + K = kern_size_param.K; + auto trA = kern_size_param.trA, trB = kern_size_param.trB; + auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, + C_type = kern_size_param.C_type; + armv7::matmul::gemm_dots8_6x8 strategy(M, N, K, A_type, B_type, C_type); + return megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .get_workspace_size(); + } + MIDOUT_END(); +} + +MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32K6x8x4::get_kern( + const KernSizeParam&) const { + return int8_k6x8x4_kern; +} + 
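+/*
+ * Illustrative sketch only, not part of the kernel: a scalar model of the
+ * four-way multiply-accumulate that the dot-product instructions (e.g.
+ * vsdot.s8) perform per 32-bit lane. As its name suggests, the K6x8x4
+ * kernel tiles 6x8 such accumulators and steps K by 4; the helper name
+ * below is hypothetical.
+ *
+ *   static inline int32_t dot4_s8(const int8_t* a, const int8_t* b,
+ *                                 int32_t acc) {
+ *       for (int k = 0; k < 4; ++k)
+ *           acc += int32_t(a[k]) * int32_t(b[k]);
+ *       return acc;
+ *   }
+ */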
+MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K6x8x4, + megdnn_armv7_matmul_kern, + "AlgoInt8x8x32K6x8x4"_hash, + armv7::matmul::gemm_dots8_6x8, int8_t, + int32_t); +/* ===================== Quint8 K4x8x4 algo ===================== */ +namespace { +void quint8_dot_k4x8x4_kern(const MatrixMulImpl::KernParam& kern_param) { + MIDOUT_BEGIN(megdnn_armv7_matmul_kern, + midout_iv("quint8_dot_k4x8x4_kern"_hash)) { + auto M = kern_param.M, N = kern_param.N, K = kern_param.K; + auto trA = kern_param.trA, trB = kern_param.trB; + auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; + auto A_type = kern_param.A_type, B_type = kern_param.B_type, + C_type = kern_param.C_type; + const auto Aptr = kern_param.A(), + Bptr = kern_param.B(); + auto Cptr = kern_param.C(); + armv7::matmul::gemm_dot_quint8_4x8 strategy(M, N, K, A_type, B_type, + C_type); + megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, + kern_param.workspace_ptr); + } + MIDOUT_END(); +} +} // namespace + +bool MatrixMulImpl::AlgoQuint8DotK4x8x4::usable( + const KernSizeParam& kern_size_param) const { + return kern_size_param.A_type.enumv() == DTypeEnum::Quantized8Asymm && + kern_size_param.B_type.enumv() == DTypeEnum::Quantized8Asymm && + kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32 && + kern_size_param.format == param::MatrixMul::Format::DEFAULT && + kern_size_param.compute_mode == Param::ComputeMode::DEFAULT; +} + +size_t MatrixMulImpl::AlgoQuint8DotK4x8x4::get_workspace( + const KernSizeParam& kern_size_param) const { + MIDOUT_BEGIN(megdnn_armv7_matmul_kern, + midout_iv("AlgoQuint8DotK4x8x4::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, + K = kern_size_param.K; + auto trA = kern_size_param.trA, trB = kern_size_param.trB; + auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, + C_type = kern_size_param.C_type; + armv7::matmul::gemm_dot_quint8_4x8 strategy(M, N, K, A_type, B_type, + C_type); + return megdnn::matmul::GemmInterleaved< + armv7::matmul::gemm_dot_quint8_4x8>(M, N, K, trA, trB, + strategy) + .get_workspace_size(); + } + MIDOUT_END(); +} + +MatrixMulImpl::kern_t MatrixMulImpl::AlgoQuint8DotK4x8x4::get_kern( + const KernSizeParam&) const { + return quint8_dot_k4x8x4_kern; +} +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8DotK4x8x4, + megdnn_armv7_matmul_kern, + "AlgoQuint8DotK4x8x4"_hash, + armv7::matmul::gemm_dot_quint8_4x8, + uint8_t, int32_t); +#endif + +/* ===================== F32 algo K4x8 ===================== */ + +namespace { +void f32_mk4_4x8_kern(const MatrixMulImpl::KernParam& kern_param) { + MIDOUT_BEGIN(megdnn_armv7_matmul_kern, midout_iv("f32_mk4_4x8_kern"_hash)) { + auto M = kern_param.M, N = kern_param.N, K = kern_param.K; + auto trA = kern_param.trA, trB = kern_param.trB; + auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; + auto A_type = kern_param.A_type, B_type = kern_param.B_type, + C_type = kern_param.C_type; + const auto Aptr = kern_param.A(), Bptr = kern_param.B(); + auto Cptr = kern_param.C(); + + armv7::matmul::sgemm_nopack_4x8 strategy(A_type, B_type, C_type); + megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, + kern_param.workspace_ptr); + } + MIDOUT_END(); +} + +} // anonymous namespace + +bool MatrixMulImpl::AlgoF32MK4_4x8::usable( + const KernSizeParam& kern_size_param) const { + return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && + kern_size_param.format == 
param::MatrixMul::Format::MK4 && + kern_size_param.B_type == kern_size_param.A_type && + kern_size_param.C_type == kern_size_param.A_type && + kern_size_param.A_type == dtype::Float32() && + kern_size_param.N % 4 == 0 && !kern_size_param.trA && + !kern_size_param.trB; +} + +size_t MatrixMulImpl::AlgoF32MK4_4x8::get_workspace( + const KernSizeParam& kern_size_param) const { + MIDOUT_BEGIN(megdnn_armv7_matmul_kern, + midout_iv("AlgoF32MK4_4x8::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, + K = kern_size_param.K; + auto trA = kern_size_param.trA, trB = kern_size_param.trB; + auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, + C_type = kern_size_param.C_type; + armv7::matmul::sgemm_nopack_4x8 strategy(A_type, B_type, C_type); + return megdnn::matmul::GemmInterleaved(M, N, K, trA, trB, + strategy) + .get_workspace_size(); + } + MIDOUT_END(); +} + +MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32MK4_4x8::get_kern( + const KernSizeParam&) const { + return f32_mk4_4x8_kern; +} + +/* ===================== Int16x16x32 MK8 4x8 algo ===================== */ + +bool MatrixMulImpl::AlgoInt16x16x32MK8_4x8::usable( + const KernSizeParam& kern_size_param) const { + return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && + kern_size_param.format == param::MatrixMul::Format::MK8 && + kern_size_param.A_type == dtype::Int16() && + kern_size_param.B_type == dtype::Int16() && + kern_size_param.C_type == dtype::Int32() && + kern_size_param.N % 4 == 0 && !kern_size_param.trA && + !kern_size_param.trB; +} + +size_t MatrixMulImpl::AlgoInt16x16x32MK8_4x8::get_workspace( + const KernSizeParam& kern_size_param) const { + MIDOUT_BEGIN(megdnn_armv7_matmul_kern, + midout_iv("AlgoInt16x16x32MK8_4x8::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, + K = kern_size_param.K; + auto trA = kern_size_param.trA, trB = kern_size_param.trB; + auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, + C_type = kern_size_param.C_type; + armv7::matmul::gemm_nopack_s16_4x8 strategy(A_type, B_type, C_type); + return megdnn::matmul::GemmInterleaved< + armv7::matmul::gemm_nopack_s16_4x8, false>(M, N, K, trA, + trB, strategy) + .get_workspace_size(); + } + MIDOUT_END(); +} + +MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt16x16x32MK8_4x8::get_kern( + const KernSizeParam&) const { + auto kern_mk8_4x8 = [](const MatrixMulImpl::KernParam& kern_param) { + MIDOUT_BEGIN(megdnn_armv7_matmul_kern, + midout_iv("AlgoInt16x16x32MK8_4x8::get_kern"_hash)) { + auto M = kern_param.M, N = kern_param.N, K = kern_param.K; + auto trA = kern_param.trA, trB = kern_param.trB; + auto LDA = kern_param.LDA, LDB = kern_param.LDB, + LDC = kern_param.LDC; + auto A_type = kern_param.A_type, B_type = kern_param.B_type, + C_type = kern_param.C_type; + const auto Aptr = kern_param.A(), + Bptr = kern_param.B(); + auto Cptr = kern_param.C(); + + armv7::matmul::gemm_nopack_s16_4x8 strategy(A_type, B_type, C_type); + megdnn::matmul::GemmInterleaved(M, N, K, trA, trB, strategy) + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, + kern_param.workspace_ptr); + } + MIDOUT_END(); + }; + return kern_mk8_4x8; +} + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +/* ===================== F16_MK8_4x8 algo ===================== */ + +bool MatrixMulImpl::AlgoF16MK8_4x8::usable( + const KernSizeParam& kern_size_param) const { + return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && + kern_size_param.C_type == kern_size_param.A_type && + kern_size_param.B_type == kern_size_param.A_type 
&& + kern_size_param.A_type == dtype::Float16() && + kern_size_param.format == param::MatrixMul::Format::MK8 && + !kern_size_param.trA && !kern_size_param.trB && + kern_size_param.N % 4 == 0; +} + +size_t MatrixMulImpl::AlgoF16MK8_4x8::get_workspace( + const KernSizeParam& kern_size_param) const { + MIDOUT_BEGIN(megdnn_armv7_matmul_kern, + midout_iv("AlgoF16MK8_4x8::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, + K = kern_size_param.K; + auto trA = kern_size_param.trA, trB = kern_size_param.trB; + auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, + C_type = kern_size_param.C_type; + armv7::matmul::gemm_nopack_f16_4x8 strategy(A_type, B_type, C_type); + return megdnn::matmul::GemmInterleaved< + armv7::matmul::gemm_nopack_f16_4x8, false>(M, N, K, trA, + trB, strategy) + .get_workspace_size(); + } + MIDOUT_END(); +} + +MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16MK8_4x8::get_kern( + const KernSizeParam&) const { + auto kern_mk8_4x8 = [](const MatrixMulImpl::KernParam& kern_param) { + MIDOUT_BEGIN(megdnn_armv7_matmul_kern, + midout_iv("AlgoF16MK8_4x8::get_kern"_hash)) { + auto M = kern_param.M, N = kern_param.N, K = kern_param.K; + auto trA = kern_param.trA, trB = kern_param.trB; + auto LDA = kern_param.LDA, LDB = kern_param.LDB, + LDC = kern_param.LDC; + auto A_type = kern_param.A_type, B_type = kern_param.B_type, + C_type = kern_param.C_type; + const auto Aptr = kern_param.A(), + Bptr = kern_param.B(); + auto Cptr = kern_param.C(); + + armv7::matmul::gemm_nopack_f16_4x8 strategy(A_type, B_type, + C_type); + megdnn::matmul::GemmInterleaved< + armv7::matmul::gemm_nopack_f16_4x8, false>(M, N, K, trA, + trB, strategy) + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, + kern_param.workspace_ptr); + } + MIDOUT_END(); + }; + return kern_mk8_4x8; +} +#endif + +/* ===================== Int8x8x16 Kernel 2x4x16 algo ===================== */ + +namespace { +void kern_int8x8x32_mk4_4x2x16(const MatrixMulImpl::KernParam& kern_param) { + MIDOUT_BEGIN(megdnn_armv7_matmul_kern, + midout_iv("kern_int8x8x32_mk4_4x2x16"_hash)) { + auto M = kern_param.M, N = kern_param.N, K = kern_param.K; + auto Aptr = kern_param.A(), Bptr = kern_param.B(); + auto Cptr = kern_param.C(); + auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; + auto trA = kern_param.trA, trB = kern_param.trB; + + armv7::matmul::gemm_mk4_s8_4x2 strategy(M, N, K, kern_param.A_type, + kern_param.B_type, + kern_param.C_type); + megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, + kern_param.workspace_ptr); + } + MIDOUT_END(); +} +} // anonymous namespace + +bool MatrixMulImpl::AlgoInt8x8x32MK4_4x2x16::usable( + const KernSizeParam& param) const { + return param.A_type.enumv() == param.B_type.enumv() && + (param.A_type.enumv() == DTypeEnum::Int8 || + param.A_type.enumv() == DTypeEnum::QuantizedS8) && + (param.C_type.enumv() == DTypeEnum::Int32 || + param.C_type.enumv() == DTypeEnum::QuantizedS32) && + param.compute_mode == Param::ComputeMode::DEFAULT && + param.format == param::MatrixMul::Format::MK4 && param.M % 4 == 0 && + param.K % 4 == 0 && !param.trA && !param.trB; +} + +size_t MatrixMulImpl::AlgoInt8x8x32MK4_4x2x16::get_workspace( + const KernSizeParam& kern_size_param) const { + MIDOUT_BEGIN(megdnn_armv7_matmul_kern, + midout_iv("AlgoInt8x8x32MK4_4x2x16::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, + K = kern_size_param.K; + auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, 
+ C_type = kern_size_param.C_type; + auto trA = kern_size_param.trA, trB = kern_size_param.trB; + matmul::gemm_mk4_s8_4x2 strategy(M, N, K, A_type, B_type, C_type); + return megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .get_workspace_size(); + } + MIDOUT_END(); +} + +MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32MK4_4x2x16::get_kern( + const KernSizeParam&) const { + return kern_int8x8x32_mk4_4x2x16; +} + +bool MatrixMulImpl::AlgoInt8x8x32MK4_4x2x16::preferred( + const KernSizeParam& kern_size_param) const { + return kern_size_param.K > 16; +} + +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32MK4_4x2x16, + megdnn_armv7_matmul_kern, + "AlgoInt8x8x32MK4_4x2x16"_hash, + armv7::matmul::gemm_mk4_s8_4x2, int8_t, + int32_t); + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/armv7/matrix_mul/algos.h b/dnn/src/armv7/matrix_mul/algos.h new file mode 100644 index 00000000..141d5650 --- /dev/null +++ b/dnn/src/armv7/matrix_mul/algos.h @@ -0,0 +1,192 @@ +/** + * \file dnn/src/armv7/matrix_mul/algos.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include "src/arm_common/matrix_mul/algos.h" +#include "src/armv7/matrix_mul/opr_impl.h" +#include "src/fallback/matrix_mul/gemm_common.h" + +namespace megdnn { +namespace armv7 { + +class MatrixMulImpl::AlgoF32 final : public AlgoBase { +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "ARMV7_F32"; } + bool usable(const KernSizeParam&) const override; + size_t get_workspace(const KernSizeParam&) const override; + kern_t get_kern(const KernSizeParam&) const override; + void* type() const override { return sm_arm_common_algo_type; } + MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); +}; + +class MatrixMulImpl::AlgoF32MK4_4x8 final : public AlgoBase { +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "ARMV7_F32_MK4_4x8"; } + bool usable(const KernSizeParam&) const override; + size_t get_workspace(const KernSizeParam&) const override; + kern_t get_kern(const KernSizeParam&) const override; + void* type() const override { return sm_arm_common_algo_type; } + PackMode packmode() const override { return PackMode::NO_PACK; } +}; + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +class MatrixMulImpl::AlgoF16K4x16x1 final : public AlgoBase { +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "AARCH32_F16_K4X16X1"; } + bool usable(const KernSizeParam&) const override; + size_t get_workspace(const KernSizeParam&) const override; + kern_t get_kern(const KernSizeParam&) const override; + void* type() const override { return sm_arm_common_algo_type; } + MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); +}; +class MatrixMulImpl::AlgoF16MK8_4x8 final : public AlgoBase { +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "AARCH32_F16_MK8_4X8"; } + bool usable(const KernSizeParam&) const override; + size_t get_workspace(const KernSizeParam&) const override; + kern_t get_kern(const KernSizeParam&) const override; + void* type() const override { return sm_arm_common_algo_type; } + PackMode packmode() const 
override { return PackMode::NO_PACK; } +}; +#endif +#if __ARM_FEATURE_DOTPROD +class MatrixMulImpl::AlgoInt8x8x32K6x8x4 final : public AlgoBase { +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "AARCH32_INT8_K6X8X4"; } + bool usable(const KernSizeParam&) const override; + size_t get_workspace(const KernSizeParam&) const override; + kern_t get_kern(const KernSizeParam&) const override; + MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); +}; + +class MatrixMulImpl::AlgoQuint8DotK4x8x4 final : public AlgoBase { +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "AARCH32_QUINT8_K4X8X4"; } + bool usable(const KernSizeParam&) const override; + size_t get_workspace(const KernSizeParam&) const override; + kern_t get_kern(const KernSizeParam&) const override; + MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); +}; +#endif + +class MatrixMulImpl::AlgoF32Gemv final + : public arm_common::MatrixMulImpl::AlgoF32Gemv {}; + +class MatrixMulImpl::AlgoInt8x8x32K4x2x16 final : public AlgoBase { +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "ARMV7_INT8X8X32_K4X2X16"; } + bool usable(const KernSizeParam&) const override; + bool preferred(const KernSizeParam&) const override; + size_t get_workspace(const KernSizeParam&) const override; + kern_t get_kern(const KernSizeParam&) const override; + void* type() const override { return sm_arm_common_algo_type; } + MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); +}; + +class MatrixMulImpl::AlgoInt8x8x32K4x8x8 final : public AlgoBase { +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "ARMV7_INT8X8X32_K4X8X8"; } + bool usable(const KernSizeParam&) const override; + bool preferred(const KernSizeParam&) const override; + size_t get_workspace(const KernSizeParam&) const override; + kern_t get_kern(const KernSizeParam&) const override; + void* type() const override { return sm_arm_common_algo_type; } + MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); +}; + +#if !__ARM_FEATURE_DOTPROD +class MatrixMulImpl::AlgoInt8x8x32Gemv final + : public arm_common::MatrixMulImpl::AlgoInt8x8x32Gemv {}; +#endif + +class MatrixMulImpl::AlgoQuint8K4x8x8 final : public AlgoBase { +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "ARMV7_QUINT8_K4X8X8"; } + bool usable(const KernSizeParam&) const override; + size_t get_workspace(const KernSizeParam&) const override; + kern_t get_kern(const KernSizeParam&) const override; + void* type() const override { return sm_arm_common_algo_type; } + MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); +}; + +class MatrixMulImpl::AlgoInt8x8x16K4x2x16 final : public AlgoBase { +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "ARMV7_INT8X8X16_K4X2X16"; } + bool usable(const KernSizeParam&) const override; + bool preferred(const KernSizeParam&) const override; + size_t get_workspace(const KernSizeParam&) const override; + kern_t get_kern(const KernSizeParam&) const override; + void* type() const override { return sm_arm_common_algo_type; } + MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); +}; + +class MatrixMulImpl::AlgoInt8x8x16K4x8x8 final : public AlgoBase { +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "ARMV7_INT8X8X16_K4X8X8"; } + bool usable(const KernSizeParam&) const override; + bool preferred(const 
KernSizeParam&) const override; + size_t get_workspace(const KernSizeParam&) const override; + kern_t get_kern(const KernSizeParam&) const override; + void* type() const override { return sm_arm_common_algo_type; } + MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); +}; + +class MatrixMulImpl::AlgoInt16x16x32K12x4x1 final : public AlgoBase { +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "ARMV7_INT16X16X32_K12X4X1"; } + bool usable(const KernSizeParam&) const override; + bool preferred(const KernSizeParam&) const override; + size_t get_workspace(const KernSizeParam&) const override; + kern_t get_kern(const KernSizeParam&) const override; + void* type() const override { return sm_arm_common_algo_type; } + MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); +}; + +class MatrixMulImpl::AlgoInt16x16x32MK8_4x8 final : public AlgoBase { +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "ARMV7_INT16X16X32_MK8_4X8"; } + bool usable(const KernSizeParam&) const override; + size_t get_workspace(const KernSizeParam&) const override; + kern_t get_kern(const KernSizeParam&) const override; + void* type() const override { return sm_arm_common_algo_type; } + PackMode packmode() const override { return PackMode::NO_PACK; } +}; + +class MatrixMulImpl::AlgoInt8x8x32MK4_4x2x16 final : public AlgoBase { +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "ARMV7_INT8X8X32_MK4_4X2X16"; } + bool usable(const KernSizeParam&) const override; + bool preferred(const KernSizeParam&) const override; + size_t get_workspace(const KernSizeParam&) const override; + kern_t get_kern(const KernSizeParam&) const override; + void* type() const override { return sm_arm_common_algo_type; } + MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); +}; + +} // namespace armv7 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/armv7/matrix_mul/asm/common.h b/dnn/src/armv7/matrix_mul/asm/common.h new file mode 100644 index 00000000..20bbd3f3 --- /dev/null +++ b/dnn/src/armv7/matrix_mul/asm/common.h @@ -0,0 +1,1300 @@ +/** + * \file dnn/src/armv7/matrix_mul/asm/common.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +#pragma once +#include +#include +#include +#include +#include "src/common/utils.h" +#include "src/fallback/conv_bias/common.h" + +namespace megdnn { +namespace armv7 { + +/* ======================== Prefetch ======================== */ +#define ASM_PREFETCH(address) "PLD " address "\n" + +static inline void prefetch_6x(const void* pfp) { + // clang-format off + asm volatile(ASM_PREFETCH("[%[pfp]]") + ASM_PREFETCH("[%[pfp], #64]") + ASM_PREFETCH("[%[pfp], #128]") + ASM_PREFETCH("[%[pfp], #192]") + ASM_PREFETCH("[%[pfp], #256]") + ASM_PREFETCH("[%[pfp], #320]") + : + : [pfp] "r"(pfp) + : "memory"); + // clang-format on +} + +static inline void prefetch_5x(const void* pfp) { + // clang-format off + asm volatile(ASM_PREFETCH("[%[pfp]]") + ASM_PREFETCH("[%[pfp], #64]") + ASM_PREFETCH("[%[pfp], #128]") + ASM_PREFETCH("[%[pfp], #192]") + ASM_PREFETCH("[%[pfp], #256]") + : + : [pfp] "r"(pfp) + : "memory"); + // clang-format on +} + +static inline void prefetch_4x(const void* pfp) { + // clang-format off + asm volatile(ASM_PREFETCH("[%[pfp]]") + ASM_PREFETCH("[%[pfp], #64]") + ASM_PREFETCH("[%[pfp], #128]") + ASM_PREFETCH("[%[pfp], #192]") + : + : [pfp] "r"(pfp) + : "memory"); + // clang-format on +} + +static inline void prefetch_3x(const void* pfp) { + // clang-format off + asm volatile(ASM_PREFETCH("[%[pfp]]") + ASM_PREFETCH("[%[pfp], #64]") + ASM_PREFETCH("[%[pfp], #128]") + : + : [pfp] "r"(pfp) + : "memory"); + // clang-format on +} + +static inline void prefetch_2x(const void* pfp) { + // clang-format off + asm volatile(ASM_PREFETCH("[%[pfp]]") + ASM_PREFETCH("[%[pfp], #64]") + : + : [pfp] "r"(pfp) + : "memory"); + // clang-format on +} + +static inline void prefetch_1x(const void* pfp) { + // clang-format off + asm volatile(ASM_PREFETCH("[%[pfp]]") : : [pfp] "r"(pfp) : "memory"); + // clang-format on +} + +/* ======================== transform ======================== */ +/** + * interleave_INTERLEAVE_UNROLLK_BATCH_type + * + * BATCH means process BATCH * UNROLL_K cols once, BATCH * sizeof(TYPE) * + * UNROLL_K = 16bytes(128bits, a vector size). 
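+ * For instance, interleave_4x4_4_b with int8_t has INTERLEAVE = 4,
+ * UNROLL_K = 4 and BATCH = 4, so each row contributes
+ * 4 * sizeof(int8_t) * 4 = 16 bytes (one 128-bit q register) per step.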
+ *
+ * the elements traverse order:
+ * rep(j, 0, INTERLEAVE) rep(i, 0, UNROLL_K) *outptr++ = inptr[j, i]
+ */
+
+static inline void interleave_4x1_2_d(const int64_t*& inptr0,
+                                      const int64_t*& inptr1,
+                                      const int64_t*& inptr2,
+                                      const int64_t*& inptr3,
+                                      int64_t*& outptr) {
+    asm volatile(
+            "vld1.32 {d0, d1}, [%[inptr0]]!\n"  // A0A1
+            "vld1.32 {d2, d3}, [%[inptr1]]!\n"  // B0B1
+            "vld1.32 {d4, d5}, [%[inptr2]]!\n"  // C0C1
+            "vld1.32 {d6, d7}, [%[inptr3]]!\n"  // D0D1
+
+            "vst1.32 {d0}, [%[outptr]]!\n"
+            "vst1.32 {d2}, [%[outptr]]!\n"
+            "vst1.32 {d4}, [%[outptr]]!\n"
+            "vst1.32 {d6}, [%[outptr]]!\n"
+            "vst1.32 {d1}, [%[outptr]]!\n"
+            "vst1.32 {d3}, [%[outptr]]!\n"
+            "vst1.32 {d5}, [%[outptr]]!\n"
+            "vst1.32 {d7}, [%[outptr]]!\n"
+            : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1),
+              [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3),
+              [outptr] "+r"(outptr)
+            :
+            : "q0", "q1", "q2", "q3", "cc", "memory");
+}
+
+template <typename T>
+static inline void interleave_8x8_1_b(const T*& inptr0, const T*& inptr1,
+                                      const T*& inptr2, const T*& inptr3,
+                                      const T*& inptr4, const T*& inptr5,
+                                      const T*& inptr6, const T*& inptr7,
+                                      T*& outptr) {
+    static_assert(
+            std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value,
+            "interleave_8x8_1_b only support uint8_t and int8_t");
+    asm volatile(
+            "vld1.32 {d0}, [%[inptr0]]!\n"  // A1A2A3A4A5A6A7A8
+            "vld1.32 {d1}, [%[inptr1]]!\n"  // B1B2B3B4B5B6B7B8
+            "vld1.32 {d2}, [%[inptr2]]!\n"  // C1C2C3C4C5C6C7C8
+            "vld1.32 {d3}, [%[inptr3]]!\n"  // D1D2D3D4D5D6D7D8
+            "vld1.32 {d4}, [%[inptr4]]!\n"  // E1E2E3E4E5E6E7E8
+            "vld1.32 {d5}, [%[inptr5]]!\n"  // F1F2F3F4F5F6F7F8
+            "vld1.32 {d6}, [%[inptr6]]!\n"  // G1G2G3G4G5G6G7G8
+            "vld1.32 {d7}, [%[inptr7]]!\n"  // H1H2H3H4H5H6H7H8
+
+            "vst1.32 {d0}, [%[outptr]]!\n"  // A1A2A3A4A5A6A7A8
+            "vst1.32 {d1}, [%[outptr]]!\n"  // B1B2B3B4B5B6B7B8
+            "vst1.32 {d2}, [%[outptr]]!\n"  // C1C2C3C4C5C6C7C8
+            "vst1.32 {d3}, [%[outptr]]!\n"  // D1D2D3D4D5D6D7D8
+            "vst1.32 {d4}, [%[outptr]]!\n"  // E1E2E3E4E5E6E7E8
+            "vst1.32 {d5}, [%[outptr]]!\n"  // F1F2F3F4F5F6F7F8
+            "vst1.32 {d6}, [%[outptr]]!\n"  // G1G2G3G4G5G6G7G8
+            "vst1.32 {d7}, [%[outptr]]!\n"  // H1H2H3H4H5H6H7H8
+            :
+            [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2),
+            [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5),
+            [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [outptr] "+r"(outptr)
+            :
+            : "q0", "q1", "q2", "q3", "memory");
+}
+
+template <typename T>
+static inline void interleave_4x4_4_b(const T*& inptr0, const T*& inptr1,
+                                      const T*& inptr2, const T*& inptr3,
+                                      T*& outptr) {
+    static_assert(
+            std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value,
+            "interleave_4x4_4_b only support uint8_t and int8_t");
+    asm volatile(
+            "vld1.32 {d0, d1}, [%[inptr0]]!\n"  // A0A1A2A3
+            "vld1.32 {d2, d3}, [%[inptr1]]!\n"  // B0B1B2B3
+            "vld1.32 {d4, d5}, [%[inptr2]]!\n"  // C0C1C2C3
+            "vld1.32 {d6, d7}, [%[inptr3]]!\n"  // D0D1D2D3
+            "vtrn.32 q0, q1\n"  // A0B0A2B2 A1B1A3B3
+            "vtrn.32 q2, q3\n"  // C0D0C2D2 C1D1C3D3
+            "vswp d1, d4 \n"    // q0=A0,B0,C0,D0 q2=A2,B2,C2,D2
+            "vswp d3, d6 \n"    // q1=A1,B1,C1,D1 q3=A3,B3,C3,D3
+            "vst1.32 {d0-d1},[%[outptr]]!\n"
+            "vst1.32 {d2-d3},[%[outptr]]!\n"
+            "vst1.32 {d4-d5},[%[outptr]]!\n"
+            "vst1.32 {d6-d7},[%[outptr]]!\n"
+            : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1),
+              [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3),
+              [outptr] "+r"(outptr)
+            :
+            : "q0", "q1", "q2", "q3", "memory");
+}
+
+template <typename T>
+static inline void interleave_6x4_4_b(const T*& inptr0, const T*& inptr1,
+                                      const T*& inptr2, const T*& inptr3,
+                                      const T*& inptr4, const T*& inptr5,
+                                      T*& outptr) {
+    static_assert(
std::is_same::value || std::is_same::value, + "interleave_6x4_4_b only support uint8_t and int8_t"); + asm volatile( + "vld1.32 {d0, d1}, [%[inptr0]]!\n" // A0A1A2A3 + "vld1.32 {d2, d3}, [%[inptr1]]!\n" // B0B1B2B3 + "vld1.32 {d4, d5}, [%[inptr2]]!\n" // C0C1C2C3 + "vld1.32 {d6, d7}, [%[inptr3]]!\n" // D0D1D2D3 + "vld1.32 {d8, d9}, [%[inptr4]]!\n" // E0E1E2E3 + "vld1.32 {d10, d11}, [%[inptr5]]!\n" // F0F1F2F3 + "vtrn.32 q0, q1\n" // A0B0A2B2 A1B1A3B3 + "vtrn.32 q2, q3\n" // C0D0C2D2 C1D1C3D3 + "vtrn.32 q4, q5\n" // E0F0E2F2 E1F1E3F3 + "vswp d1, d4 \n" // q0=A0,B0,C0,D0 q2=A2,B2,C2,D2 + "vswp d3, d6 \n" // q1=A1,B1,C1,D1 q3=A3,B3,C3,D3 + "vst1.32 {d0-d1},[%[outptr]]!\n" + "vst1.32 {d8}, [%[outptr]]!\n" + + "vst1.32 {d2-d3},[%[outptr]]!\n" + "vst1.32 {d10}, [%[outptr]]!\n" + + "vst1.32 {d4-d5},[%[outptr]]!\n" + "vst1.32 {d9}, [%[outptr]]!\n" + + "vst1.32 {d6-d7},[%[outptr]]!\n" + "vst1.32 {d11}, [%[outptr]]!\n" + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), + [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), + [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + [outptr] "+r"(outptr) + : + : "q0", "q1", "q2", "q3", "q4", "q5", "memory"); +} + +template +static inline void interleave_8x4_4_b(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, + const T*& inptr6, const T*& inptr7, + T*& outptr) { + static_assert( + std::is_same::value || std::is_same::value, + "interleave_8x4_4_b only support uint8_t and int8_t"); + asm volatile( + "vld1.32 {d0, d1}, [%[inptr0]]!\n" // A0A1A2A3 + "vld1.32 {d2, d3}, [%[inptr1]]!\n" // B0B1B2B3 + "vld1.32 {d4, d5}, [%[inptr2]]!\n" // C0C1C2C3 + "vld1.32 {d6, d7}, [%[inptr3]]!\n" // D0D1D2D3 + "vld1.32 {d8, d9}, [%[inptr4]]!\n" // E0E1E2E3 + "vld1.32 {d10, d11}, [%[inptr5]]!\n" // F0F1F2F3 + "vld1.32 {d12, d13}, [%[inptr6]]!\n" // G0G1G2G3 + "vld1.32 {d14, d15}, [%[inptr7]]!\n" // H0H1H2H3 + "vtrn.32 q0, q1\n" // A0B0A2B2 A1B1A3B3 + "vtrn.32 q2, q3\n" // C0D0C2D2 C1D1C3D3 + "vtrn.32 q4, q5\n" // E0F0E2F2 E1F1E3F3 + "vtrn.32 q6, q7\n" // G0H0G2H2 G1H1G3H3 + + "vswp d1, d4 \n" // q0=A0,B0,C0,D0 q2=A2,B2,C2,D2 + "vswp d3, d6 \n" // q1=A1,B1,C1,D1 q3=A3,B3,C3,D3 + + "vswp d9, d12 \n" // q4=E0,F0,G0,H0 q6=E2,F2,G2,H2 + "vswp d11, d14 \n" // q5=E1,F1,G1,H1 q7=E3,F3,G3,H3 + + "vst1.32 {d0-d1},[%[outptr]]!\n" + "vst1.32 {d8-d9},[%[outptr]]!\n" + + "vst1.32 {d2-d3},[%[outptr]]!\n" + "vst1.32 {d10-d11},[%[outptr]]!\n" + + "vst1.32 {d4-d5},[%[outptr]]!\n" + "vst1.32 {d12-d13},[%[outptr]]!\n" + + "vst1.32 {d6-d7},[%[outptr]]!\n" + "vst1.32 {d14-d15},[%[outptr]]!\n" + : + [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [outptr] "+r"(outptr) + : + : "q0", "q1", "q2", "q3", "q4", "q5", "memory"); +} + +template +static inline void interleave_6x4_8_b(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, + T*& outptr) { + static_assert( + std::is_same::value || std::is_same::value, + "interleave_6x8_4_b only support uint8_t and int8_t"); + asm volatile( + "vld4.32 {d0-d3}, [%[inptr0]]! \n" // q0,q1=r00,r04,r01,r05,r02,r06,r03,r07 + "vld4.32 {d4-d7}, [%[inptr1]]! 
\n" // q2,q3=r10,r14,r11,r15,r12,r16,r13,r17 + "vld4.32 {d8-d11}, [%[inptr2]]!\n" // q4,q5=r20,r24,r21,r25,r22,r26,r23,r27 + "vld4.32 {d12-d15}, [%[inptr3]]!\n" // q6,q7=r30,r34,r31,r35,r32,r36,r33,r37 + "vld4.32 {d16-d19}, [%[inptr4]]!\n" // q8,q9=r40,r44,r41,r45,r42,r46,r43,r47 + "vld4.32 {d20-d23}, [%[inptr5]]!\n" // q10,q11=r50,r54,r51,r55,r52,r56,r53,r5 + + "vtrn.32 q0, q2 \n" // q0=r00,r10,r01,r11 q2=r04,r14,r05,r15 + "vtrn.32 q4, q6 \n" // q4=r20,r30,r21,r31 q6=r24,r34,r25,r35 + "vtrn.32 q8, q10 \n" // q8=r40,r50,r41,r51 q10=r44,r54,r45,r55 + "vswp d1, d8 \n" // q0=r00,r10,r20,r30 q4=r01,r11,r21,r31 + "vtrn.32 q1, q3 \n" // q1=r02,r12,r03,r13 q3=r06,r16,r07,r17 + "vtrn.32 q5, q7 \n" // q5=r22,r32,r23,r33 q7=r26,r36,r27,r37 + "vtrn.32 q9, q11 \n" // q9=r42,r52,r43,r53 q11=r46,r56,r47,r57 + "vst1.32 {d0-d1}, [%[outptr]]! \n" + "vst1.32 {d16}, [%[outptr]]! \n" + "vswp d3, d10 \n" // q1=r02,r12,r22,r32 q5=r03,r13,r23,r33 + "vst1.32 {d8-d9}, [%[outptr]]! \n" + "vst1.32 {d17}, [%[outptr]]! \n" + "vst1.32 {d2-d3}, [%[outptr]]!\n" + "vst1.32 {d18}, [%[outptr]]!\n" + "vswp d5, d12 \n" // q2=r04,r14,r24,r34 q6=r05,r15,r25,r35 + "vst1.32 {d10-d11},[%[outptr]]!\n" + "vst1.32 {d19}, [%[outptr]]!\n" + "vst1.32 {d4-d5}, [%[outptr]]! \n" + "vst1.32 {d20}, [%[outptr]]! \n" + "vswp d7, d14 \n" // q3=r06,r16,r26,r36 q7=r07,r17,r27,r37 + "vst1.32 {d12-d13},[%[outptr]]! \n" + "vst1.32 {d21}, [%[outptr]]! \n" + "vst1.32 {d6-d7}, [%[outptr]]! \n" + "vst1.32 {d22}, [%[outptr]]! \n" + "vst1.32 {d14-d15},[%[outptr]]! \n" + "vst1.32 {d23}, [%[outptr]]! \n" + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), + [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), + [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + [outptr] "+r"(outptr) + : + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", + "q11", "cc", "memory"); +} + +template +static inline void interleave_4x16_1_b(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, + T*& outptr) { + static_assert(sizeof(T) == 1, "only support size == 1"); + asm volatile( + "vld1.32 {d0, d1}, [%[inptr0]]!\n" // d0 = A0A1A2A3 + "vld1.32 {d2, d3}, [%[inptr1]]!\n" // d1 = B0B1B2B3 + "vld1.32 {d4, d5}, [%[inptr2]]!\n" // d2 = C0C1C2C3 + "vld1.32 {d6, d7}, [%[inptr3]]!\n" // d3 = D0D1D2D3 + "vst1.32 {d0, d1}, [%[outptr]]!\n" + "vst1.32 {d2, d3}, [%[outptr]]!\n" + "vst1.32 {d4, d5}, [%[outptr]]!\n" + "vst1.32 {d6, d7}, [%[outptr]]!\n" + + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), + [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), + [outptr] "+r"(outptr) + : + : "q0", "q1", "q2", "q3", "cc", "memory"); +} + +template +static inline void interleave_4x8_2_b(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, + T*& outptr) { + static_assert( + std::is_same::value || std::is_same::value, + "interleave_4x8_2_b only support uint8_t and int8_t"); + interleave_4x1_2_d(reinterpret_cast(inptr0), + reinterpret_cast(inptr1), + reinterpret_cast(inptr2), + reinterpret_cast(inptr3), + reinterpret_cast(outptr)); +} + +template +static inline void interleave_2x16_1_b(const T*& inptr0, const T*& inptr1, + T*& outptr) { + static_assert(sizeof(T) == 1, "only support size == 2"); + asm volatile( + "vld1.32 {d0, d1}, [%[inptr0]]!\n" + "vld1.32 {d2, d3}, [%[inptr1]]!\n" + "vst1.32 {d0, d1}, [%[outptr]]!\n" + "vst1.32 {d2, d3}, [%[outptr]]!\n" + + : + [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [outptr] "+r"(outptr) + : + : "q0", "q1", "cc", "memory"); +} + +template +static inline void interleave_4x4_1_h(const T*& inptr0, const T*& inptr1, + 
const T*& inptr2, const T*& inptr3, + T*& outptr) { + static_assert(sizeof(T) == 2, + "interleave_4x16_1_h only support sizeof(T) == 2"); + asm volatile( + "vld1.16 {d0}, [%[inptr0]]!\n" + "vld1.16 {d1}, [%[inptr1]]!\n" + "vld1.16 {d2}, [%[inptr2]]!\n" + "vld1.16 {d3}, [%[inptr3]]!\n" + + "vst1.16 {d0}, [%[outptr]]!\n" + "vst1.16 {d1}, [%[outptr]]!\n" + "vst1.16 {d2}, [%[outptr]]!\n" + "vst1.16 {d3}, [%[outptr]]!\n" + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), + [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), + [outptr] "+r"(outptr) + : + : "d0", "d1", "d2", "d3", "d4", "memory"); +} + +template +static inline void interleave_4x12_1_h(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, + T*& outptr) { + static_assert(sizeof(T) == 2, + "interleave_4x12_1_h only support sizeof(T) == 2"); + asm volatile( + "pld [%[inptr0],#192]\n" + "vld1.16 {d0}, [%[inptr0]]!\n" // A0A1A2A3 + "vld1.16 {d1}, [%[inptr0]]!\n" // B0B1B2B3 + "vld1.16 {d2}, [%[inptr0]]!\n" // C0C1C2C3 + "pld [%[inptr1],#192]\n" + "vld1.16 {d3}, [%[inptr1]]!\n" // A0A1A2A3 + "vld1.16 {d4}, [%[inptr1]]!\n" // B0B1B2B3 + "vld1.16 {d5}, [%[inptr1]]!\n" // C0C1C2C3 + "pld [%[inptr2],#192]\n" + "vld1.16 {d6}, [%[inptr2]]!\n" // A0A1A2A3 + "vld1.16 {d7}, [%[inptr2]]!\n" // B0B1B2B3 + "vld1.16 {d8}, [%[inptr2]]!\n" // C0C1C2C3 + "pld [%[inptr3],#192]\n" + "vld1.16 {d9}, [%[inptr3]]!\n" // A0A1A2A3 + "vld1.16 {d10}, [%[inptr3]]!\n" // B0B1B2B3 + "vld1.16 {d11}, [%[inptr3]]!\n" // C0C1C2C3 + + "vst1.16 {d0}, [%[outptr]]!\n" // A0B0C0D0 + "vst1.16 {d1}, [%[outptr]]!\n" // E0F0G0H0 + "vst1.16 {d2}, [%[outptr]]!\n" // I0J0K0L0 + "vst1.16 {d3}, [%[outptr]]!\n" // D0D1D2D3 + "vst1.16 {d4}, [%[outptr]]!\n" // E0E1E2E3 + "vst1.16 {d5}, [%[outptr]]!\n" // F0F1F2F3 + "vst1.16 {d6}, [%[outptr]]!\n" // G0G1G2G3 + "vst1.16 {d7}, [%[outptr]]!\n" // H0H1H2H3 + "vst1.16 {d8}, [%[outptr]]!\n" // H0H1H2H3 + "vst1.16 {d9}, [%[outptr]]!\n" // G0G1G2G3 + "vst1.16 {d10}, [%[outptr]]!\n" // H0H1H2H3 + "vst1.16 {d11}, [%[outptr]]!\n" // H0H1H2H3 + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), + [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), + [outptr] "+r"(outptr) + : + : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", + "d11", "memory"); +} + +template +static inline void interleave_4x16_1_h(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, + T*& outptr) { + static_assert(sizeof(T) == 2, + "interleave_4x16_1_h only support sizeof(T) == 2"); + asm volatile( + "vld1.16 {d0, d1, d2, d3}, [%[inptr0]]!\n" + "vld1.16 {d4, d5, d6, d7}, [%[inptr1]]!\n" + "vld1.16 {d8, d9, d10, d11}, [%[inptr2]]!\n" + "vld1.16 {d12, d13, d14, d15}, [%[inptr3]]!\n" + + "vst1.16 {d0, d1, d2, d3}, [%[outptr]]!\n" + "vst1.16 {d4, d5, d6, d7}, [%[outptr]]!\n" + "vst1.16 {d8, d9, d10, d11}, [%[outptr]]!\n" + "vst1.16 {d12, d13, d14, d15}, [%[outptr]]!\n" + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), + [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), + [outptr] "+r"(outptr) + : + : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", + "d11", "d12", "d13", "d14", "d15", "memory"); +} + +template +static inline void interleave_4x4_1_s(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, + T*& outptr) { + static_assert(sizeof(T) == 4, + "interleave_4x4_1_s only support sizeof(T) == 4"); + asm volatile( + "vld1.32 {d0, d1}, [%[inptr0]]!\n" // A0A1A2A3 + "vld1.32 {d2, d3}, [%[inptr1]]!\n" // A0A1A2A3 + "vld1.32 {d4, d5}, [%[inptr2]]!\n" // A0A1A2A3 + "vld1.32 {d6, d7}, [%[inptr3]]!\n" // 
A0A1A2A3 + + "vst1.32 {d0, d1}, [%[outptr]]!\n" // A0B0C0D0 + "vst1.32 {d2, d3}, [%[outptr]]!\n" // E0F0G0H0 + "vst1.32 {d4, d5}, [%[outptr]]!\n" // I0J0K0L0 + "vst1.32 {d6, d7}, [%[outptr]]!\n" // D0D1D2D3 + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), + [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), + [outptr] "+r"(outptr) + : + : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "memory"); +} + +template +static inline void interleave_1x4_1_h(const T*& inptr0, T*& outptr) { + static_assert(sizeof(T) == 2, + "transpose_1x4_1_h only support sizeof(T) == 2"); + asm volatile( + "vld1.16 {d0}, [%[inptr0]]!\n" // A01234567 + "vst1.16 {d0}, [%[outptr]]!\n" + : [inptr0] "+r"(inptr0), [outptr] "+r"(outptr) + : + : "d0", "memory"); +} + +template +static inline void interleave_4x12_1_s(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, + T*& outptr) { + static_assert(sizeof(T) == 4, + "interleave_4x12_1_s only support sizeof(T) == 4"); + asm volatile( + "vld1.32 {d0, d1}, [%[inptr0]]!\n" // A0A1A2A3 + "vld1.32 {d2, d3}, [%[inptr0]]!\n" // B0B1B2B3 + "vld1.32 {d4, d5}, [%[inptr0]]!\n" // C0C1C2C3 + "vld1.32 {d6, d7}, [%[inptr1]]!\n" // A0A1A2A3 + "vld1.32 {d8, d9}, [%[inptr1]]!\n" // B0B1B2B3 + "vld1.32 {d10, d11}, [%[inptr1]]!\n" // C0C1C2C3 + "vld1.32 {d12, d13}, [%[inptr2]]!\n" // A0A1A2A3 + "vld1.32 {d14, d15}, [%[inptr2]]!\n" // B0B1B2B3 + "vld1.32 {d16, d17}, [%[inptr2]]!\n" // C0C1C2C3 + "vld1.32 {d18, d19}, [%[inptr3]]!\n" // A0A1A2A3 + "vld1.32 {d20, d21}, [%[inptr3]]!\n" // B0B1B2B3 + "vld1.32 {d22, d23}, [%[inptr3]]!\n" // C0C1C2C3 + + "vst1.32 {d0, d1}, [%[outptr]]!\n" // A0B0C0D0 + "vst1.32 {d2, d3}, [%[outptr]]!\n" // E0F0G0H0 + "vst1.32 {d4, d5}, [%[outptr]]!\n" // I0J0K0L0 + "vst1.32 {d6, d7}, [%[outptr]]!\n" // D0D1D2D3 + "vst1.32 {d8, d9}, [%[outptr]]!\n" // E0E1E2E3 + "vst1.32 {d10, d11}, [%[outptr]]!\n" // F0F1F2F3 + "vst1.32 {d12, d13}, [%[outptr]]!\n" // G0G1G2G3 + "vst1.32 {d14, d15}, [%[outptr]]!\n" // H0H1H2H3 + "vst1.32 {d16, d17}, [%[outptr]]!\n" // H0H1H2H3 + "vst1.32 {d18, d19}, [%[outptr]]!\n" // G0G1G2G3 + "vst1.32 {d20, d21}, [%[outptr]]!\n" // H0H1H2H3 + "vst1.32 {d22, d23}, [%[outptr]]!\n" // H0H1H2H3 + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), + [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), + [outptr] "+r"(outptr) + : + : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", + "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", + "d20", "d21", "d22", "d23", "memory"); +} + +template +static inline void interleave_1x12_1_h(const T*& inptr0, T*& outptr) { + static_assert(sizeof(T) == 2, + "transpose_1x12_1_h only support sizeof(T) == 2"); + asm volatile( + "vld1.16 {d0,d1}, [%[inptr0]]!\n" // A01234567 + "vld1.16 {d2} , [%[inptr0]]!\n" // A891011 + "vst1.16 {d0,d1}, [%[outptr]]!\n" + "vst1.16 {d2} , [%[outptr]]\n" + : [inptr0] "+r"(inptr0), [outptr] "+r"(outptr) + : + : "d0", "d1", "d2", "memory"); +} + +template +static inline void interleave_1x12_1_s(const T*& inptr0, T*& outptr) { + static_assert(sizeof(T) == 4, + "interleave_1x12_1_s only support sizeof(T) == 4"); + asm volatile( + "vld1.32 {d0, d1}, [%[inptr0]]!\n" + "vld1.32 {d2, d3}, [%[inptr0]]!\n" + "vld1.32 {d4, d5}, [%[inptr0]]!\n" + "vst1.32 {d0, d1}, [%[outptr]]!\n" + "vst1.32 {d2, d3}, [%[outptr]]!\n" + "vst1.32 {d4, d5}, [%[outptr]]!\n" + : [inptr0] "+r"(inptr0), [outptr] "+r"(outptr) + : + : "d0", "d1", "d2", "d3", "d4", "d5", "memory"); +} + +template +static inline void interleave_1x16_1_h(const T*& inptr0, T*& outptr) { + 
static_assert(sizeof(T) == 2, + "transpose_1x12_1_h only support sizeof(T) == 2"); + asm volatile( + "vld1.16 {d0,d1, d2, d3}, [%[inptr0]]!\n" + "vst1.16 {d0,d1, d2, d3}, [%[outptr]]!\n" + : [inptr0] "+r"(inptr0), [outptr] "+r"(outptr) + : + : "d0", "d1", "d2", "d3", "memory"); +} + +template +static inline void interleave_1x4_1_s(const T*& inptr0, T*& outptr) { + static_assert(sizeof(T) == 4, + "interleave_1x4_1_s only support sizeof(T) == 4"); + asm volatile( + "vld1.32 {d0, d1}, [%[inptr0]]!\n" + "vst1.32 {d0, d1}, [%[outptr]]\n" + : [inptr0] "+r"(inptr0), [outptr] "+r"(outptr) + : + : "d0", "d1", "memory"); +} + +template +static inline void interleave_helper(const T*& inptr, T*& outptr, int unroll_k, + int ksize, T val = 0) { + int k = 0; + for (; k < ksize; k++) { + *outptr++ = *inptr++; + } + for (; k < unroll_k; k++) { + *outptr++ = val; + } +} + +template +static inline void interleave_1(const T*& inptr0, T*& outptr, int unroll_k, + int ksize, T val = 0) { + for (int k = 0; k < ksize; k += unroll_k) { + int size = std::min(unroll_k, ksize - k); + interleave_helper(inptr0, outptr, unroll_k, size, val); + } +} + +template +static inline void interleave_2(const T*& inptr0, const T*& inptr1, T*& outptr, + int unroll_k, int ksize, T val = 0) { + for (int k = 0; k < ksize; k += unroll_k) { + int size = std::min(unroll_k, ksize - k); + interleave_helper(inptr0, outptr, unroll_k, size, val); + interleave_helper(inptr1, outptr, unroll_k, size, val); + } +} + +template +static inline void interleave_4(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, T*& outptr, + int unroll_k, int ksize, T val = 0) { + for (int k = 0; k < ksize; k += unroll_k) { + int size = std::min(unroll_k, ksize - k); + interleave_helper(inptr0, outptr, unroll_k, size, val); + interleave_helper(inptr1, outptr, unroll_k, size, val); + interleave_helper(inptr2, outptr, unroll_k, size, val); + interleave_helper(inptr3, outptr, unroll_k, size, val); + } +} + +template +static inline void interleave_6(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, T*& outptr, + int unroll_k, int ksize, T val = 0) { + for (int k = 0; k < ksize; k += unroll_k) { + int size = std::min(unroll_k, ksize - k); + interleave_helper(inptr0, outptr, unroll_k, size, val); + interleave_helper(inptr1, outptr, unroll_k, size, val); + interleave_helper(inptr2, outptr, unroll_k, size, val); + interleave_helper(inptr3, outptr, unroll_k, size, val); + interleave_helper(inptr4, outptr, unroll_k, size, val); + interleave_helper(inptr5, outptr, unroll_k, size, val); + } +} +template +static inline void interleave_8(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, + const T*& inptr6, const T*& inptr7, T*& outptr, + int unroll_k, int ksize, T val = 0) { + for (int k = 0; k < ksize; k += unroll_k) { + int size = std::min(unroll_k, ksize - k); + interleave_helper(inptr0, outptr, unroll_k, size, val); + interleave_helper(inptr1, outptr, unroll_k, size, val); + interleave_helper(inptr2, outptr, unroll_k, size, val); + interleave_helper(inptr3, outptr, unroll_k, size, val); + interleave_helper(inptr4, outptr, unroll_k, size, val); + interleave_helper(inptr5, outptr, unroll_k, size, val); + interleave_helper(inptr6, outptr, unroll_k, size, val); + interleave_helper(inptr7, outptr, unroll_k, size, val); + } +} + +/* ======================== transpose pack B ======================== */ +/** + * 
transpose_INTERLEAVE_UNROLLK_BATCH_type + * + * BATCH means process BATCH * INTERLEAVE cols once, BATCH * sizeof(TYPE) * + * INTERLEAVE = 16bytes(128bits, a vector size). + * + * the elements traverse order: + * rep(j, 0, INTERLEAVE) rep(i, 0, UNROLL_K) *ouptr++ = inptr[i, j] + */ +template +static inline void transpose_8x8_1_b(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, + const T*& inptr6, const T*& inptr7, + T* outptr) { + static_assert( + std::is_same::value || std::is_same::value, + "transpose_8x8_1_b only support uint8_t and int8_t"); + asm volatile( + "vld1.32 {d0}, [%[inptr0]]!\n" // A1A2A3A4A5A6A7A8 + "vld1.32 {d1}, [%[inptr1]]!\n" // B1B2B3B4B5B6B7B8 + "vld1.32 {d2}, [%[inptr2]]!\n" // C1C2C3C4C5C6C7C8 + "vld1.32 {d3}, [%[inptr3]]!\n" // D1D2D3D4D5D6D7D8 + "vld1.32 {d4}, [%[inptr4]]!\n" // E1E2E3E4E5E6E7E8 + "vld1.32 {d5}, [%[inptr5]]!\n" // F1F2F3F4F5F6F7F8 + "vld1.32 {d6}, [%[inptr6]]!\n" // G1G2G3G4G5G6G7G8 + "vld1.32 {d7}, [%[inptr7]]!\n" // H1H2H3H4H5H6H7H8 + + "vzip.8 d0, d1\n" // A1B1A2B2A3B3A4B4 A5B5A6B6A7B7A8B8 + "vzip.8 d2, d3\n" // C1D1C2D2C3D3C4D4 C5D5C6D6C7D7C8D8 + "vzip.8 d4, d5\n" // E1F1E2F2E3F3E4F4 E5F5E6F6E7F7E8F8 + "vzip.8 d6, d7\n" // G1H1G2H2G3H3G4H4 G5H5G6H6G7H7G8H8 + + "vzip.16 d0, d2\n" // A1B1C1D1A2B2C2D2 A3B3C3D3A4B4C4D4 + "vzip.16 d4, d6\n" // E1F1G1H1E2F2G2H2 E3F3G3H3E4F4G4H4 + "vzip.16 d1, d3\n" // A5B5C5D5A6B6C6D6 A7B7C7D7A8B8C8D8 + "vzip.16 d5, d7\n" // E5F5G5H5E6F6G6H6 E7F7G7H7E8F8G8H8 + + "vzip.32 d0, d4\n" // A1B1C1D1E1F1G1H1 A2B2C2D2E2F2G2H2 + "vzip.32 d1, d5\n" // A5B5C5D5E5F5G5H5 A6B6C6D6E6F6G6H6 + "vzip.32 d2, d6\n" // A3B3C3D3E3F3G3H3 A4B4C4D4E4F4G4H4 + "vzip.32 d3, d7\n" // A7B7C7D7E7F7G7H7 A8B8C8D8E8F8G8H8 + + "vst1.32 {d0}, [%[outptr]]!\n" // A1B1C1D1E1F1G1H1 + "vst1.32 {d4}, [%[outptr]]!\n" // A2B2C2D2E2F2G2H2 + "vst1.32 {d2}, [%[outptr]]!\n" // A3B3C3D3E3F3G3H3 + "vst1.32 {d6}, [%[outptr]]!\n" // A4B4C4D4E4F4G4H4 + "vst1.32 {d1}, [%[outptr]]!\n" // A5B5C5D5E5F5G5H5 + "vst1.32 {d5}, [%[outptr]]!\n" // A6B6C6D6E6F6G6H6 + "vst1.32 {d3}, [%[outptr]]!\n" // A7B7C7D7E7F7G7H7 + "vst1.32 {d7}, [%[outptr]]!\n" // A8B8C8D8E8F8G8H8 + : + [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [outptr] "+r"(outptr) + : + : "q0", "q1", "q2", "q3", "cc", "memory"); +} + +template +static inline void transpose_8x4_1_b(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, + T* outptr) { + static_assert( + std::is_same::value || std::is_same::value, + "transpose_8x4_1_b only support uint8_t and int8_t"); + asm volatile( + "vld1.32 {d0}, [%[inptr0]]!\n" // A1A2A3A4A5A6A7A8 + "vld1.32 {d1}, [%[inptr1]]!\n" // B1B2B3B4B5B6B7B8 + "vld1.32 {d2}, [%[inptr2]]!\n" // C1C2C3C4C5C6C7C8 + "vld1.32 {d3}, [%[inptr3]]!\n" // D1D2D3D4D5D6D7D8 + + "vtrn.8 d0, d1\n" // A1B1A3B3A5B5A7B7 A2B2A4B4A6B6A8B8 + "vtrn.8 d2, d3\n" // C1D1C3D3C5D5C7D7 C2D2C4D4C6D6C8D8 + + "vtrn.16 d0, d2\n" // A1B1C1D1A5B5C5D5 A3B3C3D3A7B7C7D7 + "vtrn.16 d1, d3\n" // A2B2C2D2A6B6C6D6 A4B4C4D4A8B8C8D8 + + //! ABCD=E then + //! 
d0: E1E5 d1: E2E6 d2: E3E7 d3: E4E8 + "vzip.32 d0, d1\n" // E1E2 E5E6 + "vzip.32 d2, d3\n" // E3E4 E7E8 + + "vst1.32 {d0}, [%[outptr]]!\n" + "vst1.32 {d2}, [%[outptr]]!\n" + "vst1.32 {d1}, [%[outptr]]!\n" + "vst1.32 {d3}, [%[outptr]]!\n" + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), + [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), + [outptr] "+r"(outptr) + : + : "q0", "q1", "memory"); +} + +template +static inline void transpose_12x4_1_h(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, + const T*& inptr6, const T*& inptr7, + const T*& inptr8, const T*& inptr9, + const T*& inptr10, const T*& inptr11, + int ldin, T*& outptr) { + static_assert( + std::is_same::value || std::is_same::value, + "interleave_12x4_1_h only support uint16_t and int16_t"); + auto ldin_asm = ldin << 1; + asm volatile( + "vld1.16 {d0}, [%[inptr0]]!\n" // A0A1A2A3 + "vld1.16 {d1}, [%[inptr1]]!\n" // B0B1B2B3 + "vld1.16 {d2}, [%[inptr2]]!\n" // C0C1C2C3 + "vld1.16 {d3}, [%[inptr3]]!\n" // D0D1D2D3 + "vld1.16 {d4}, [%[inptr4]]!\n" // E0E1E2E3 + "vld1.16 {d5}, [%[inptr5]]!\n" // F0F1F2F3 + "vld1.16 {d6}, [%[inptr6]]!\n" // G0G1G2G3 + "vld1.16 {d7}, [%[inptr7]]!\n" // H0H1H2H3 + "vld1.16 {d8}, [%[inptr8]]!\n" // I0I1I2I3 + "vld1.16 {d9}, [%[inptr9]]\n" // J0J1J2J3 + "add %[inptr9], %[inptr9], %[ldin_asm]\n" + "vld1.16 {d10}, [%[inptr9]]\n" // K0K1K2K3 + "add %[inptr9], %[inptr9], %[ldin_asm]\n" + "vld1.16 {d11}, [%[inptr9]]\n" // L0L1L2L3 + + "vtrn.16 d0, d1\n" // A0B0A2B2A1B1A3B3 + "vtrn.16 d2, d3\n" // C0D0C2D2C1D1C3D3 + "vtrn.16 d4, d5\n" // E0F0E2F2E1F1E3F3 + "vtrn.16 d6, d7\n" // G0H0G2H2G1H1G3H3 + + "vtrn.16 d8, d9\n" // I0J0I2J2I1J1I3J3 + "vtrn.16 d10, d11\n" // K0L0K2L2K1L1K3L3 + + "vtrn.32 q0, q1\n" // A0B0C0D0 A1B1C1D1 A2B2C2D2 A3B3C3D3 + "vtrn.32 q2, q3\n" // E0F0G0H0 E1F1G1G1 E2F2G2H2 E3F3G3H3 + "vtrn.32 q4, q5\n" // I0J0K0L0 I1J1K1L1 I2J2K2L2 I3J3K3L3 + + "vst1.16 {d0}, [%[outptr]]!\n" // A0B0C0D0 + "vst1.16 {d4}, [%[outptr]]!\n" // E0F0G0H0 + "vst1.16 {d8}, [%[outptr]]!\n" // I0J0K0L0 + "vst1.16 {d1}, [%[outptr]]!\n" // D0D1D2D3 + "vst1.16 {d5}, [%[outptr]]!\n" // E0E1E2E3 + "vst1.16 {d9}, [%[outptr]]!\n" // F0F1F2F3 + "vst1.16 {d2}, [%[outptr]]!\n" // G0G1G2G3 + "vst1.16 {d6}, [%[outptr]]!\n" // H0H1H2H3 + "vst1.16 {d10}, [%[outptr]]!\n" // H0H1H2H3 + "vst1.16 {d3}, [%[outptr]]!\n" // G0G1G2G3 + "vst1.16 {d7}, [%[outptr]]!\n" // H0H1H2H3 + "vst1.16 {d11}, [%[outptr]]!\n" // H0H1H2H3 + : + [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [inptr8] "+r"(inptr8), + [inptr9] "+r"(inptr9), [outptr] "+r"(outptr) + :[ldin_asm] "r"(ldin_asm) + : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", + "d11", "memory"); + inptr9 -= ldin_asm; + inptr9 += 4; + inptr10 += 4; + inptr11 += 4; +} + + + +template +static inline void transpose_2x16_1_b_helper(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, + const T*& inptr6, const T*& inptr7, + T* outptr) { + static_assert(sizeof(T) == 1, "only support size == 1"); + static uint8x8_t shuffle_idx = {0, 2, 4, 6, 1, 3, 5, 7}; + asm volatile( + "vld1.16 {d0[0]}, [%[inptr0]]!\n" + "vld1.16 {d0[1]}, [%[inptr1]]!\n" + "vld1.16 {d0[2]}, [%[inptr2]]!\n" + "vld1.16 {d0[3]}, [%[inptr3]]!\n" + "vld1.16 {d2[0]}, [%[inptr4]]!\n" + "vld1.16 {d2[1]}, [%[inptr5]]!\n" + "vld1.16 {d2[2]}, [%[inptr6]]!\n" + "vld1.16 
{d2[3]}, [%[inptr7]]!\n" + "mov r0, #16\n" + + "vtbl.8 d1, {d0}, %[shuffle_idx]\n" + "vtbl.8 d3, {d2}, %[shuffle_idx]\n" + + "vzip.32 d1, d3\n" + + "vst1.64 d1, [%[outptr]], r0\n" + "vst1.64 d3, [%[outptr]]\n" + + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), + [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), + [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), + [outptr] "+r"(outptr), [shuffle_idx] "+w"(shuffle_idx) + : + : "q0", "q1", "q2", "r0", "memory"); +} + +template +static inline void transpose_4x8_1_b(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, + const T*& inptr6, const T*& inptr7, + T* outptr) { + static uint8x8_t shuffle_idx = {0, 4, 1, 5, 2, 6, 3, 7}; + static_assert( + std::is_same::value || std::is_same::value, + "transpose_4x8_1_b only support uint8_t and int8_t"); + asm volatile( + "vld1.32 {d0[0]}, [%[inptr0]]!\n" // A1A2A3A4 + "vld1.32 {d0[1]}, [%[inptr1]]!\n" // B1B2B3B4 + "vld1.32 {d1[0]}, [%[inptr2]]!\n" // C1C2C3C4 + "vld1.32 {d1[1]}, [%[inptr3]]!\n" // D1D2D3D4 + "vld1.32 {d2[0]}, [%[inptr4]]!\n" // E1E2E3E4 + "vld1.32 {d2[1]}, [%[inptr5]]!\n" // F1F2F3F4 + "vld1.32 {d3[0]}, [%[inptr6]]!\n" // G1G2G3G4 + "vld1.32 {d3[1]}, [%[inptr7]]!\n" // H1H2H3H4 + + "vtbl.8 d4, {d0}, %[shuffle_idx]\n" // A1B1A2B2A3B3A4B4 + "vtbl.8 d5, {d1}, %[shuffle_idx]\n" // C1D1C2D2C3D3C4D4 + "vtbl.8 d6, {d2}, %[shuffle_idx]\n" // E1F1E2F2E3F3E4F4 + "vtbl.8 d7, {d3}, %[shuffle_idx]\n" // G1H1G2H2G3H3G4H4 + + "vzip.16 d4, d5\n" // A1B1C1D1A2B2C2D2 A3B3C3D3A4B4C4D4 + "vzip.16 d6, d7\n" // E1F1G1H1E2F2G2H2 E3F3G3H3E4F4G4H4 + "vzip.32 d4, d6\n" // A1B1C1D1E1F1G1H1 A2B2C2D2E2F2G2H2 + "vzip.32 d5, d7\n" // A3B3C3D3E3F3G3H3 A4B4C4D4E4F4G4H4 + + "vst1.32 {d4}, [%[outptr]]!\n" // A1B1C1D1E1F1G1H1 + "vst1.32 {d6}, [%[outptr]]!\n" // A2B2C2D2E2F2G2H2 + "vst1.32 {d5}, [%[outptr]]!\n" // A3B3C3D3E3F3G3H3 + "vst1.32 {d7}, [%[outptr]]!\n" // A4B4C4D4E4F4G4H4 + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), + [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), + [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), + [outptr] "+r"(outptr), [shuffle_idx] "+w"(shuffle_idx) + : + : "q0", "q1", "q2", "q3", "cc", "memory"); +} + +template +static inline void transpose_4x16_1_b_helper(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, + const T*& inptr6, const T*& inptr7, + T* outptr) { + static_assert(sizeof(T) == 1, "only support size == 1"); + static uint8x8_t shuffle_idx = {0, 4, 1, 5, 2, 6, 3, 7}; + asm volatile( + "vld1.32 {d0[0]}, [%[inptr0]]!\n" + "vld1.32 {d0[1]}, [%[inptr1]]!\n" + "vld1.32 {d1[0]}, [%[inptr2]]!\n" + "vld1.32 {d1[1]}, [%[inptr3]]!\n" + "vld1.32 {d2[0]}, [%[inptr4]]!\n" + "vld1.32 {d2[1]}, [%[inptr5]]!\n" + "vld1.32 {d3[0]}, [%[inptr6]]!\n" + "vld1.32 {d3[1]}, [%[inptr7]]!\n" + "mov r0, #16\n" + + "vtbl.8 d4, {d0}, %[shuffle_idx]\n" + "vtbl.8 d5, {d1}, %[shuffle_idx]\n" + "vtbl.8 d6, {d2}, %[shuffle_idx]\n" + "vtbl.8 d7, {d3}, %[shuffle_idx]\n" + + "vzip.16 d4, d5\n" + "vzip.16 d6, d7\n" + "vzip.32 d4, d6\n" + "vzip.32 d5, d7\n" + + "vst1.64 d4, [%[outptr]], r0\n" + "vst1.64 d6, [%[outptr]], r0\n" + "vst1.64 d5, [%[outptr]], r0\n" + "vst1.64 d7, [%[outptr]]\n" + + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), + [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), + [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), + [outptr] "+r"(outptr), 
[shuffle_idx] "+w"(shuffle_idx) + : + : "q0", "q1", "q2", "q3", "q4", "r0", "memory"); +} + +template +static inline void transpose_4x4_1_h(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, + T*& outptr, int stride = 8) { + static_assert(sizeof(T) == 2, + "transpose_4x4_1_h only support sizeof(T) == 2"); + + asm volatile( + "vld1.16 {d0}, [%[inptr0]]!\n" // A0A1A2A3 + "vld1.16 {d1}, [%[inptr1]]!\n" // B0B1B2B3 + "vld1.16 {d2}, [%[inptr2]]!\n" // C0C1C2C3 + "vld1.16 {d3}, [%[inptr3]]!\n" // D0D1D2D3 + "vtrn.16 d0, d1\n" // A0B0A2B2A1B1A3B3 + "vtrn.16 d2, d3\n" // C0D0C2D2C1D1C3D3 + "vtrn.32 q0, q1\n" // A0B0C0D0 A1B1C1D1 A2B2C2D2 A3B3C3D3 + "vst1.16 {d0}, [%[outptr]], %[stride]\n" // A0B0C0D0 + "vst1.16 {d1}, [%[outptr]], %[stride]\n" // A1B1C1D1 + "vst1.16 {d2}, [%[outptr]], %[stride]\n" // A2B2C2D2 + "vst1.16 {d3}, [%[outptr]], %[stride]\n" // A3B3C3D3 + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), + [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), + [outptr] "+r"(outptr) + : [stride] "r"(stride) + : "d0", "d1", "d2", "d3", "memory"); +} + +template +static inline void transpose_4x4_1_s(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, + T*& outptr, int stride = 16) { + static_assert(sizeof(T) == 4, + "transpose_4x4_1_s only support sizeof(T) == 4"); + + stride -= 8; + asm volatile( + "vld1.32 {d0, d1}, [%[inptr0]]!\n" // A0A1A2A3 + "vld1.32 {d2, d3}, [%[inptr1]]!\n" // B0B1B2B3 + "vld1.32 {d4, d5}, [%[inptr2]]!\n" // C0C1C2C3 + "vld1.32 {d6, d7}, [%[inptr3]]!\n" // D0D1D2D3 + "vtrn.32 q0, q1\n" // A0B0A2B2 A1B1A3B3 + "vtrn.32 q2, q3\n" // C0D0C2D2 C1D1C3D3 + "vst1.32 {d0}, [%[outptr]]!\n" + "vst1.32 {d4}, [%[outptr]], %[stride]\n" + "vst1.32 {d2}, [%[outptr]]!\n" + "vst1.32 {d6}, [%[outptr]], %[stride]\n" + "vst1.32 {d1}, [%[outptr]]!\n" + "vst1.32 {d5}, [%[outptr]], %[stride]\n" + "vst1.32 {d3}, [%[outptr]]!\n" + "vst1.32 {d7}, [%[outptr]], %[stride]\n" + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), + [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), + [outptr] "+r"(outptr), [stride] "+r" (stride) + : + : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "memory"); +} + +template +static inline void transpose_4x2_1_s(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, + T* outptr, int stride = 8) { + static_assert(sizeof(T) == 4, + "transpose_4x2_1_s only support sizeof(T) == 4"); + + stride -= 8; + asm volatile( + "vld1.32 {d0}, [%[inptr0]]!\n" // A0A1 + "vld1.32 {d1}, [%[inptr1]]!\n" // B0B1 + "vld1.32 {d2}, [%[inptr2]]!\n" // C0C1 + "vld1.32 {d3}, [%[inptr3]]!\n" // D0D1 + "vtrn.32 d0, d1\n" // A0B0 A1B1 + "vtrn.32 d2, d3\n" // C0D0 C1D1 + "vst1.32 {d0}, [%[outptr]]!\n" + "vst1.32 {d2}, [%[outptr]]!\n" + "vst1.32 {d1}, [%[outptr]]!\n" + "vst1.32 {d3}, [%[outptr]]!\n" + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), + [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), + [outptr] "+r"(outptr), [stride] "+r"(stride) + : + : "d0", "d1", "d2", "d3", "memory"); +} + + +template +static inline void transpose_6x4_1_b(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, + T* outptr) { + static_assert( + std::is_same::value || std::is_same::value, + "interleave_6x4_1_b only support uint8_t and int8_t"); + asm volatile( + "vld1.8 {d0}, [%[inptr0]]\n" // A0A1A2A3A4A5 A6A7 + "vld1.8 {d1}, [%[inptr1]]\n" // B0B1B2B3B4B5 B6B7 + "vld1.8 {d2}, [%[inptr2]]\n" // C0C1C2C3C4C5 C6C7 + "vld1.8 {d3}, [%[inptr3]]\n" // D0D1D2D3D4D5 D6D7 + "vtrn.8 d0, d1\n" // A0B0A2B2A4B4A6B6 A1B1A3B3A5B5A7B7 + "vtrn.8 d2, d3\n" // 
C0D0C2D2C4D4C6D6 C1D1C3D3C5D5C7D7 + + "add %[inptr0],%[inptr0],#6 \n" + "add %[inptr1],%[inptr1],#6 \n" + "add %[inptr2],%[inptr2],#6 \n" + "add %[inptr3],%[inptr3],#6 \n" + + "vtrn.16 d0, d2\n" // A0B0 C0D0 A4B4 C4D4---A2B2 C2D2 A6B6 C6D6 + "vtrn.16 d1, d3\n" // A1B1 C1D1 A5B5 C5D5---A3B3 C3D3 A7B7 C7D7 + + "vst1.32 {d0[0]},[%[outptr]]!\n" + "vst1.32 {d1[0]},[%[outptr]]!\n" + + "vst1.32 {d2[0]},[%[outptr]]!\n" + "vst1.32 {d3[0]},[%[outptr]]!\n" + + "vst1.32 {d0[1]},[%[outptr]]!\n" + "vst1.32 {d1[1]},[%[outptr]]!\n" + + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), + [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), + [outptr] "+r"(outptr) + : + : "q0", "q1", "q2", "memory"); +} + +template +static inline void transpose_4x4_1_b(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, + T* outptr) { + static_assert( + std::is_same::value || std::is_same::value, + "interleave_4x4_1_b only support uint8_t and int8_t"); + asm volatile( + "vld1.8 {d0}, [%[inptr0]]\n" // A0A1A2A3A4A5 A6A7 + "vld1.8 {d1}, [%[inptr1]]\n" // B0B1B2B3B4B5 B6B7 + "vld1.8 {d2}, [%[inptr2]]\n" // C0C1C2C3C4C5 C6C7 + "vld1.8 {d3}, [%[inptr3]]\n" // D0D1D2D3D4D5 D6D7 + "vtrn.8 d0, d1\n" // A0B0A2B2A4B4A6B6 A1B1A3B3A5B5A7B7 + "vtrn.8 d2, d3\n" // C0D0C2D2C4D4C6D6 C1D1C3D3C5D5C7D7 + + "add %[inptr0],%[inptr0],#4 \n" + "add %[inptr1],%[inptr1],#4 \n" + "add %[inptr2],%[inptr2],#4 \n" + "add %[inptr3],%[inptr3],#4 \n" + + "vtrn.16 d0, d2\n" // A0B0 C0D0 A4B4 C4D4---A2B2 C2D2 A6B6 C6D6 + "vtrn.16 d1, d3\n" // A1B1 C1D1 A5B5 C5D5---A3B3 C3D3 A7B7 C7D7 + + "vst1.32 {d0[0]},[%[outptr]]!\n" + "vst1.32 {d1[0]},[%[outptr]]!\n" + + "vst1.32 {d2[0]},[%[outptr]]!\n" + "vst1.32 {d3[0]},[%[outptr]]!\n" + + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), + [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), + [outptr] "+r"(outptr) + : + : "q0", "q1", "q2", "memory"); +} + +template +static inline void transpose_4(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, T* outptr, + int interleave, int size, T val = 0) { + megdnn_assert(size <= interleave); + int i = 0; + for (; i < size; i++) { + *outptr++ = *inptr0++; + *outptr++ = *inptr1++; + *outptr++ = *inptr2++; + *outptr++ = *inptr3++; + } + for (; i < interleave; i++) { + *outptr++ = val; + *outptr++ = val; + *outptr++ = val; + *outptr++ = val; + } +} + +template +static inline void transpose_8(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, + const T*& inptr6, const T*& inptr7, T* outptr, + int interleave, int size, T val = 0) { + megdnn_assert(size <= interleave); + int i = 0; + for (; i < size; i++) { + *outptr++ = *inptr0++; + *outptr++ = *inptr1++; + *outptr++ = *inptr2++; + *outptr++ = *inptr3++; + *outptr++ = *inptr4++; + *outptr++ = *inptr5++; + *outptr++ = *inptr6++; + *outptr++ = *inptr7++; + } + for (; i < interleave; i++) { + *outptr++ = val; + *outptr++ = val; + *outptr++ = val; + *outptr++ = val; + *outptr++ = val; + *outptr++ = val; + *outptr++ = val; + *outptr++ = val; + } +} + + +template +static inline void transpose_4x1(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, + T*& outptr) { + *outptr++ = *inptr0++; + *outptr++ = *inptr1++; + *outptr++ = *inptr2++; + *outptr++ = *inptr3++; +} + +template +static inline void transpose_12x1(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, + const T*& inptr6, const T*& inptr7, + const T*& inptr8, const T*& inptr9, + const T*& inptr10, const T*& 
inptr11, + T*& outptr) { + *outptr++ = *inptr0++; + *outptr++ = *inptr1++; + *outptr++ = *inptr2++; + *outptr++ = *inptr3++; + *outptr++ = *inptr4++; + *outptr++ = *inptr5++; + *outptr++ = *inptr6++; + *outptr++ = *inptr7++; + *outptr++ = *inptr8++; + *outptr++ = *inptr9++; + *outptr++ = *inptr10++; + *outptr++ = *inptr11++; +} + +/***********************************Transpose interleave *************/ +//! pack form {1, 4(icb), 4(ic), 4(oc)} to {1, 1, 4(oc), 16(ic)} +template +static inline void transpose_interleave_4x4_4_b(const T*& inptr0, + const T*& inptr1, + const T*& inptr2, + const T*& inptr3, T* outptr, + int stride = 64) { + static_assert(sizeof(T) == 1, + "transpose_interleave_4x4_4_b only support sizeof(T) == 1"); + + asm volatile( + "add r1, %[outptr], %[stride]\n" + "vld4.8 {d0-d3},[%[inptr0]]!\n" + "vld4.8 {d4-d7},[%[inptr0]]!\n" + "add r2, r1, %[stride]\n" + "vld4.8 {d8-d11},[%[inptr1]]!\n" + "vld4.8 {d12-d15},[%[inptr1]]!\n" + "vld4.8 {d16-d19},[%[inptr2]]!\n" + "add r3, r2, %[stride]\n" + "vld4.8 {d20-d23},[%[inptr2]]!\n" + "vld4.8 {d24-d27},[%[inptr3]]!\n" + "vld4.8 {d28-d31},[%[inptr3]]!\n" + + "vst1.8 d0, [%[outptr]]!\n" + "vst1.8 d4, [%[outptr]]!\n" + "vst1.8 d1, [%[outptr]]!\n" + "vst1.8 d5, [%[outptr]]!\n" + "vst1.8 d2, [%[outptr]]!\n" + "vst1.8 d6, [%[outptr]]!\n" + "vst1.8 d3, [%[outptr]]!\n" + "vst1.8 d7, [%[outptr]]!\n" + + "vst1.8 d8, [r1]!\n" + "vst1.8 d12,[r1]!\n" + "vst1.8 d9, [r1]!\n" + "vst1.8 d13,[r1]!\n" + "vst1.8 d10,[r1]!\n" + "vst1.8 d14,[r1]!\n" + "vst1.8 d11,[r1]!\n" + "vst1.8 d15,[r1]!\n" + + "vst1.8 d16,[r2]!\n" + "vst1.8 d20,[r2]!\n" + "vst1.8 d17,[r2]!\n" + "vst1.8 d21,[r2]!\n" + "vst1.8 d18,[r2]!\n" + "vst1.8 d22,[r2]!\n" + "vst1.8 d19,[r2]!\n" + "vst1.8 d23,[r2]!\n" + + "vst1.8 d24,[r3]!\n" + "vst1.8 d28,[r3]!\n" + "vst1.8 d25,[r3]!\n" + "vst1.8 d29,[r3]!\n" + "vst1.8 d26,[r3]!\n" + "vst1.8 d30,[r3]!\n" + "vst1.8 d27,[r3]!\n" + "vst1.8 d31,[r3]!\n" + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), + [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), + [outptr] "+r"(outptr), [stride] "+r"(stride) + : + : "r1", "r2", "r3", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q9", "q10", "q11", "q12", "q14", "q15", "memory"); +} + +template +static inline void transpose_interleave_1x4_4_b(const T*& inptr0, T* outptr, + int stride = 64) { + static_assert(sizeof(T) == 1, + "transpose_interleave_1x4_4_b only support sizeof(T) == 1"); + + asm volatile( + "vld4.8 {d0-d3},[%[inptr0]]!\n" + "vld4.8 {d4-d7},[%[inptr0]]!\n" + + "vst1.8 d0, [%[outptr]]!\n" + "vst1.8 d4, [%[outptr]]!\n" + "vst1.8 d1, [%[outptr]]!\n" + "vst1.8 d5, [%[outptr]]!\n" + "vst1.8 d2, [%[outptr]]!\n" + "vst1.8 d6, [%[outptr]]!\n" + "vst1.8 d3, [%[outptr]]!\n" + "vst1.8 d7, [%[outptr]]!\n" + : + [inptr0] "+r"(inptr0), [outptr] "+r"(outptr), [stride] "+r"(stride) + : + : "q0", "q1", "q2", "q3", "memory"); +} + +} // armv7 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/armv7/matrix_mul/fp16/strategy.cpp b/dnn/src/armv7/matrix_mul/fp16/strategy.cpp new file mode 100644 index 00000000..214787ed --- /dev/null +++ b/dnn/src/armv7/matrix_mul/fp16/strategy.cpp @@ -0,0 +1,738 @@ +/** + * \file dnn/src/armv7/matrix_mul/fp16/strategy.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "src/armv7/matrix_mul/fp16/strategy.h" +#include "src/armv7/matrix_mul/asm/common.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/utils.h" + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +using namespace megdnn; +using namespace armv7; +using namespace armv7::matmul; + +namespace { + +// Overview of register layout: +// +// A 2x16 cell of Rhs is stored in 16bit in q1-q4 +// A 4x2 cell of Lhs is stored in 16bit in q0 +// A 4x16 block of accumulators is stored in 16bit in q5-q12. +// +// +--------+--------+ +// | v1[0-7]| v2[0-7]| +// Rhs +--------+--------+ +// | v3[0-7]| v4[0-7]| +// +--------+--------+ +// +// | | | +// +// Lhs | | | +// +// +--+--+ - - - - +--------+--------+ +// |v0|v0| | v5[0-7]| v6[0-7]| +// |v0|v0| | v7[0-7]| v8[0-7]| +// |v0|v0| | v9[0-7]|v10[0-7]| +// |v0|v0| |v11[0-7]|v12[0-7]| +// +--+--+ - - - - +--------+--------+ +// +// Accumulator +void kern_4x16(const dt_float16* packA, const dt_float16* packB, int K, + dt_float16* output, int LDC, bool is_first_k, int m_remain) { + const __fp16* a_ptr = reinterpret_cast(packA); + const __fp16* b_ptr = reinterpret_cast(packB); + int oddk = (K & 1); + K = ((K + 1) / 2) - 1; + + LDC = LDC * sizeof(__fp16); + register __fp16* outptr asm("r0") = reinterpret_cast<__fp16*>(output); + +// clang-format off +#define LOAD_LINE(d0, d1, d2, d3, n) \ + "cmp r10, #0\n" \ + "beq 100f\n" \ + "vld1.16 {d" d0 ",d" d1 ",d" d2 ",d" d3 "}, [r" n "]\n" \ + "subs r10, r10, #1\n" + +#define LOAD_C \ + "mov r10, %[m_remain]\n" \ + LOAD_LINE("10", "11", "12", "13", "0") \ + LOAD_LINE("14", "15", "16", "17", "1") \ + LOAD_LINE("18", "19", "20", "21", "2") \ + LOAD_LINE("22", "23", "24", "25", "3") \ + "100:\n" + +#define STORE_LINE(d0, d1, d2, d3, n) \ + "cmp r10, #0\n" \ + "beq 101f\n" \ + "vst1.16 {d" d0 ",d" d1 ",d" d2 ",d" d3 "}, [r" n "]\n" \ + "subs r10, r10, #1\n" + +#define STORE_C \ + "mov r10, %[m_remain]\n" \ + STORE_LINE("10", "11", "12", "13", "0") \ + STORE_LINE("14", "15", "16", "17", "1") \ + STORE_LINE("18", "19", "20", "21", "2") \ + STORE_LINE("22", "23", "24", "25", "3") \ + "101:\n" + // clang-format on + + asm volatile( + // load accumulator C + "add r1, r0, %[LDC]\n" + "add r2, r1, %[LDC]\n" + "add r3, r2, %[LDC]\n" + + "cmp %[is_first_k], #1\n" + "beq 1f\n" LOAD_C + + "b 2f\n" + + "1:\n" + "veor.32 q5, q5, q5\n" + "veor.32 q6, q6, q6\n" + "veor.32 q7, q7, q7\n" + "veor.32 q8, q8, q8\n" + "veor.32 q9, q9, q9\n" + "veor.32 q10, q10, q10\n" + "veor.32 q11, q11, q11\n" + "veor.32 q12, q12, q12\n" + + "2: \n" + "vld1.16 {d2, d3, d4, d5}, [%[b_ptr]]!\n" + + "cmp %[K], #0\n" + "beq 4f\n" + + "3:\n" + "vld1.16 {d0, d1}, [%[a_ptr]]!\n" + "vld1.16 {d6, d7, d8, d9}, [%[b_ptr]]!\n" + "vmla.f16 q5, q1, d0[0]\n" + "vmla.f16 q6, q2, d0[0]\n" + "vmla.f16 q7, q1, d0[1]\n" + "vmla.f16 q8, q2, d0[1]\n" + "vmla.f16 q9, q1, d0[2]\n" + "vmla.f16 q10, q2, d0[2]\n" + "vmla.f16 q11, q1, d0[3]\n" + "vmla.f16 q12, q2, d0[3]\n" + + "vmla.f16 q5, q3, d1[0]\n" + "vmla.f16 q6, q4, d1[0]\n" + "vmla.f16 q7, q3, d1[1]\n" + "vmla.f16 q8, q4, d1[1]\n" + "vmla.f16 q9, q3, d1[2]\n" + "vmla.f16 q10, q4, d1[2]\n" + "vmla.f16 q11, q3, d1[3]\n" + "vmla.f16 q12, q4, d1[3]\n" + + "vld1.16 {d2, d3, d4, d5}, [%[b_ptr]]!\n" + "subs %[K], #1\n" + "bne 3b\n" + + "4:\n" + "cmp %[oddk], #1\n" + "beq 5f\n" + + // Even 
tail + "vld1.16 {d0, d1}, [%[a_ptr]]!\n" + "vld1.16 {d6, d7, d8, d9}, [%[b_ptr]]!\n" + "vmla.f16 q5, q1, d0[0]\n" + "vmla.f16 q6, q2, d0[0]\n" + "vmla.f16 q7, q1, d0[1]\n" + "vmla.f16 q8, q2, d0[1]\n" + "vmla.f16 q9, q1, d0[2]\n" + "vmla.f16 q10, q2, d0[2]\n" + "vmla.f16 q11, q1, d0[3]\n" + "vmla.f16 q12, q2, d0[3]\n" + + "vmla.f16 q5, q3, d1[0]\n" + "vmla.f16 q6, q4, d1[0]\n" + "vmla.f16 q7, q3, d1[1]\n" + "vmla.f16 q8, q4, d1[1]\n" + "vmla.f16 q9, q3, d1[2]\n" + "vmla.f16 q10, q4, d1[2]\n" + "vmla.f16 q11, q3, d1[3]\n" + "vmla.f16 q12, q4, d1[3]\n" + "b 6f\n" + + // odd tail + "5:\n" + "vld1.16 {d0}, [%[a_ptr]]!\n" + "vmla.f16 q5, q1, d0[0]\n" + "vmla.f16 q6, q2, d0[0]\n" + "vmla.f16 q7, q1, d0[1]\n" + "vmla.f16 q8, q2, d0[1]\n" + "vmla.f16 q9, q1, d0[2]\n" + "vmla.f16 q10, q2, d0[2]\n" + "vmla.f16 q11, q1, d0[3]\n" + "vmla.f16 q12, q2, d0[3]\n" + + "6:\n" STORE_C + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), + [oddk] "+r"(oddk), [m_remain] "+r"(m_remain), + [outptr] "+r"(outptr) + : + : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", + "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", + "d17", "d18", "d19", "d20", "d21", "d22", "d23", "d24", + "d25", "r1", "r2", "r3", "r10", "cc", "memory"); + +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +} + + +// Overview of register layout: +// +// A 2x4 cell of Rhs is stored in 16bit in q1 +// A 4x2 cell of Lhs is stored in 16bit in q0 +// A 4x4 block of accumulators is stored in 16bit in q2-q5 +// +// +--------+ +// | v1[0-3]| +// Rhs +--------+ +// | v1[4-7]| +// +--------+ +// +// | | +// +// Lhs | | +// +// +--+--+ - - - - +--------+ +// |v0|v0| | v2[0-3]| +// |v0|v0| | v3[0-3]| +// |v0|v0| | v4[0-3]| +// |v0|v0| | v5[0-3]| +// +--+--+ - - - - +--------+ +// +// Accumulator +void kern_4x4(const dt_float16* packA, const dt_float16* packB, int K, + dt_float16* output, int LDC, bool is_first_k, int m_remain, + int n_remain) { + const __fp16* a_ptr = reinterpret_cast(packA); + const __fp16* b_ptr = reinterpret_cast(packB); + int oddk = (K & 1); + K = ((K + 1) / 2) - 1; + + LDC = LDC * sizeof(__fp16); + register __fp16* outptr asm("r0") = reinterpret_cast<__fp16*>(output); + +// clang-format off +#define LOAD_LINE(d0, n) \ + "cmp r10, #0\n" \ + "beq 102f\n" \ + "cmp %[n_remain], #4\n" \ + "blt 100" n "f\n" \ + "vld1.16 {d" d0 "}, [r" n "]\n" \ + "b 101" n "f\n" \ + "100" n ":\n" \ + "cmp %[n_remain], #0\n" \ + "beq 101" n "f\n" \ + "vld1.16 {d" d0 "[0]}, [r" n " ]!\n" \ + "cmp %[n_remain], #1\n" \ + "beq 101" n "f\n" \ + "vld1.16 {d" d0 "[1]}, [r" n " ]!\n" \ + "cmp %[n_remain], #2\n" \ + "beq 101" n "f\n" \ + "vld1.16 {d" d0 "[2]}, [r" n " ]!\n" \ + "101" n ":\n" \ + "subs r10, r10, #1\n" + +#define LOAD_C \ + "mov r10, %[m_remain]\n" \ + LOAD_LINE("4", "0") \ + LOAD_LINE("6", "1") \ + LOAD_LINE("8", "2") \ + LOAD_LINE("10", "3") \ + "102:\n" + +#define STORE_LINE(d0, n) \ + "cmp r10, #0 \n" \ + "beq 105f\n" \ + "cmp %[n_remain], #4\n" \ + "blt 103" n "f\n" \ + "vst1.16 {d" d0 "}, [r" n " ]!\n" \ + "b 104" n "f\n" \ + "103" n ":\n" \ + "cmp %[n_remain], #0\n" \ + "beq 104" n "f\n" \ + "vst1.16 {d" d0 "[0]}, [r" n " ]!\n" \ + "cmp %[n_remain], #1\n" \ + "beq 104" n "f\n" \ + "vst1.16 {d" d0 "[1]}, [r" n " ]!\n" \ + "cmp %[n_remain], #2\n" \ + "beq 104" n "f\n" \ + "vst1.16 {d" d0 "[2]}, [r" n " ]!\n" \ + "104" n ":\n" \ + "subs r10, r10, #1\n" + + + +#define STORE_C \ + "mov r10, %[m_remain]\n" \ + STORE_LINE("4", "0") \ + STORE_LINE("6", "1") \ + 
STORE_LINE("8", "2") \ + STORE_LINE("10", "3") \ + "105:\n" + // clang-format on + + asm volatile( + // load accumulator C + "add r1, r0, %[LDC]\n" + "add r2, r1, %[LDC]\n" + "add r3, r2, %[LDC]\n" + + "cmp %[is_first_k], #1\n" + "beq 1f\n" LOAD_C + + "b 2f\n" + + "1:\n" + "veor.32 q2, q2, q2\n" + "veor.32 q3, q3, q3\n" + "veor.32 q4, q4, q4\n" + "veor.32 q5, q5, q5\n" + + "2: \n" + "cmp %[K], #0\n" + "beq 4f\n" + + "3:\n" + "vld1.16 {d0, d1}, [%[a_ptr]]!\n" + "vld1.16 {d2, d3}, [%[b_ptr]]!\n" + "vmla.f16 d4, d2, d0[0]\n" + "vmla.f16 d6, d2, d0[1]\n" + "vmla.f16 d8, d2, d0[2]\n" + "vmla.f16 d10, d2, d0[3]\n" + + "vmla.f16 d4, d3, d1[0]\n" + "vmla.f16 d6, d3, d1[1]\n" + "vmla.f16 d8, d3, d1[2]\n" + "vmla.f16 d10, d3, d1[3]\n" + + "subs %[K], #1\n" + "bne 3b\n" + + "4:\n" + "cmp %[oddk], #1\n" + "beq 5f\n" + + // Even tail + "vld1.16 {d0, d1}, [%[a_ptr]]!\n" + "vld1.16 {d2, d3}, [%[b_ptr]]!\n" + "vmla.f16 d4, d2, d0[0]\n" + "vmla.f16 d6, d2, d0[1]\n" + "vmla.f16 d8, d2, d0[2]\n" + "vmla.f16 d10, d2, d0[3]\n" + + "vmla.f16 d4, d3, d1[0]\n" + "vmla.f16 d6, d3, d1[1]\n" + "vmla.f16 d8, d3, d1[2]\n" + "vmla.f16 d10, d3, d1[3]\n" + + "b 6f\n" + + // odd tail + "5:\n" + "vld1.16 {d0}, [%[a_ptr]]!\n" + "vld1.16 {d2}, [%[b_ptr]]!\n" + "vmla.f16 d4, d2, d0[0]\n" + "vmla.f16 d6, d2, d0[1]\n" + "vmla.f16 d8, d2, d0[2]\n" + "vmla.f16 d10, d2, d0[3]\n" + + "6:\n" STORE_C + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), + [oddk] "+r"(oddk), [m_remain] "+r"(m_remain), + [n_remain] "+r"(n_remain), [outptr] "+r"(outptr) + : + : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", + "d9", "d10", "r1", "r2", "r3", "r10", "cc", "memory"); +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +} + +void hgemm_4x16_pack_A_n(__fp16* outptr, const __fp16* inptr, int ldin, int y0, + int ymax, int k0, int kmax) { + __fp16 zerobuff[16]; + std::memset(zerobuff, 0, sizeof(__fp16) * 8); + + int y = y0; + for (; y + 3 < ymax; y += 4) { + const __fp16* inptr0 = inptr + y * ldin + k0; + const __fp16* inptr1 = inptr0 + ldin; + const __fp16* inptr2 = inptr1 + ldin; + const __fp16* inptr3 = inptr2 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + + int K = (kmax - k0); + for (; K > 3; K -= 4) { + transpose_4x4_1_h(inptr0, inptr1, inptr2, inptr3, outptr); + } + + interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 1, K); + } + + for (; y < ymax; y += 4) { + const __fp16* inptr0 = inptr + y * ldin + k0; + const __fp16* inptr1 = inptr0 + ldin; + const __fp16* inptr2 = inptr1 + ldin; + const __fp16* inptr3 = inptr2 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + + int K = (kmax - k0); + for (; K > 3; K -= 4) { + if ((y + 3) >= ymax) { + switch ((y + 3) - ymax) { + /* Everything falls through in here */ + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + transpose_4x4_1_h(inptr0, inptr1, inptr2, inptr3, outptr); + } + + if (K > 0) { + if ((y + 3) >= ymax) { + switch ((y + 3) - ymax) { + /* Everything falls through in here */ + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 1, K); + } + } +} + +void hgemm_4x16_pack_A_t(__fp16* out, const __fp16* in, int ldin, int x0, + int xmax, int k0, int kmax) { + int ksize = 
kmax - k0; + int ksize4 = (ksize << 2); + __fp16* outptr_base = reinterpret_cast<__fp16*>(out); + + int k = k0; + for (; k + 3 < kmax; k += 4) { + const __fp16* inptr = in + k * ldin + x0; + const __fp16* inptr1 = inptr + ldin; + const __fp16* inptr2 = inptr1 + ldin; + const __fp16* inptr3 = inptr2 + ldin; + + prefetch_3x(inptr); + prefetch_3x(inptr1); + prefetch_3x(inptr2); + prefetch_3x(inptr3); + + int x = x0; + auto outptr = outptr_base; + for (; x + 4 <= xmax; x += 4) { + auto outptr_interleave = outptr; + interleave_4x4_1_h(inptr, inptr1, inptr2, inptr3, + outptr_interleave); + outptr += ksize4; + } + + if (x < xmax) { + interleave_4(inptr, inptr1, inptr2, inptr3, outptr, 4, xmax - x); + } + + outptr_base += 4 * 4; + } + + for (; k < kmax; k++) { + const __fp16* inptr = + reinterpret_cast(in + k * ldin + x0); + prefetch_3x(inptr); + int x = x0; + auto outptr = outptr_base; + for (; x + 4 <= xmax; x += 4) { + auto outptr_interleave = outptr; + interleave_1x4_1_h(inptr, outptr_interleave); + outptr += ksize4; + } + + if (x < xmax) { + interleave_1(inptr, outptr, 4, xmax - x); + } + + outptr_base += 4; + } + +} + + +void hgemm_4x16_pack_B_n(__fp16* out, const __fp16* in, int ldin, + int x0, int xmax, int k0, int kmax) { + int ksize = kmax - k0; + int ksize16 = (ksize << 4); + int ksize4 = (ksize << 2); + __fp16* outptr_base = reinterpret_cast<__fp16*>(out); + __fp16* outptr_base4 = outptr_base + (xmax - x0) / 16 * ksize16; + + int k = k0; + for (; k + 3 < kmax; k += 4) { + const __fp16* inptr = in + k * ldin + x0; + const __fp16* inptr1 = inptr + ldin; + const __fp16* inptr2 = inptr1 + ldin; + const __fp16* inptr3 = inptr2 + ldin; + + prefetch_3x(inptr); + prefetch_3x(inptr1); + prefetch_3x(inptr2); + prefetch_3x(inptr3); + + int x = x0; + auto outptr = outptr_base; + for (; x + 16 <= xmax; x += 16) { + auto outptr_interleave = outptr; + interleave_4x16_1_h(inptr, inptr1, inptr2, inptr3, + outptr_interleave); + outptr += ksize16; + } + outptr = outptr_base4; + for (; x + 4 <= xmax; x += 4) { + auto outptr_interleave = outptr; + interleave_4x4_1_h(inptr, inptr1, inptr2, inptr3, + outptr_interleave); + outptr += ksize4; + } + + if (x < xmax) { + interleave_4(inptr, inptr1, inptr2, inptr3, outptr, 4, xmax - x); + } + + outptr_base += 16 * 4; + outptr_base4 += 4 * 4; + } + + for (; k < kmax; k++) { + const __fp16* inptr = + reinterpret_cast(in + k * ldin + x0); + prefetch_3x(inptr); + int x = x0; + auto outptr = outptr_base; + for (; x + 16 <= xmax; x += 16) { + auto outptr_interleave = outptr; + interleave_1x16_1_h(inptr, outptr_interleave); + outptr += ksize16; + } + outptr = outptr_base4; + for (; x + 4 <= xmax; x += 4) { + auto outptr_interleave = outptr; + interleave_1x4_1_h(inptr, outptr_interleave); + outptr += ksize4; + } + + if (x < xmax) { + interleave_1(inptr, outptr, 4, xmax - x); + } + + outptr_base += 16; + outptr_base4 += 4; + } +} + +void hgemm_4x16_pack_B_t(__fp16* out, const __fp16* in, int ldin, + int y0, int ymax, int k0, int kmax) { + __fp16* outptr = out; + const __fp16* inptr = in; + __fp16 zerobuff[16]; + std::memset(zerobuff, 0, sizeof(__fp16) * 16); + int K16 = 16 * (kmax - k0); + + int y = y0; + + for (; y + 16 <= ymax; y += 16) { + int yi = y; + for (; yi < y + 16; yi += 4) { + const __fp16* inptr0 = inptr + yi * ldin + k0; + const __fp16* inptr1 = inptr0 + ldin; + const __fp16* inptr2 = inptr1 + ldin; + const __fp16* inptr3 = inptr2 + ldin; + __fp16* outptr_inner = outptr + yi - y; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + 
prefetch_2x(inptr3); + + int x = (kmax - k0); + for (; x > 3; x -= 4) { + transpose_4x4_1_h(inptr0, inptr1, inptr2, inptr3, outptr_inner, + 32); + } + for (; x > 0; x--) { + *outptr_inner++ = *inptr0++; + *outptr_inner++ = *inptr1++; + *outptr_inner++ = *inptr2++; + *outptr_inner++ = *inptr3++; + outptr_inner += 12; + } + } + outptr += K16; + } + + for (; y < ymax; y += 4) { + const __fp16* inptr0 = inptr + y * ldin + k0; + const __fp16* inptr1 = inptr0 + ldin; + const __fp16* inptr2 = inptr1 + ldin; + const __fp16* inptr3 = inptr2 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + + /* Cope with ragged cases by copying from a buffer of zeroes instead + */ + int x = (kmax - k0); + for (; x > 3; x -= 4) { + if ((y + 3) >= ymax) { + switch ((y + 3) - ymax) { + /* Everything falls through in here */ + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + transpose_4x4_1_h(inptr0, inptr1, inptr2, inptr3, outptr); + } + + if (x > 0) { + if ((y + 3) >= ymax) { + switch ((y + 3) - ymax) { + /* Everything falls through in here */ + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 1, x); + } + } +} + +} // namespace + +MEGDNN_REG_GEMM_STRATEGY_IMPL(hgemm_4x16); + +void hgemm_4x16::pack_A(dt_float16* out, const dt_float16* in, int ldin, int y0, + int ymax, int k0, int kmax, bool transpose_A) const { + if (transpose_A) { + hgemm_4x16_pack_A_t(reinterpret_cast<__fp16*>(out), + reinterpret_cast(in), ldin, y0, ymax, + k0, kmax); + } else { + hgemm_4x16_pack_A_n(reinterpret_cast<__fp16*>(out), + reinterpret_cast(in), ldin, y0, ymax, + k0, kmax); + } +} + +void hgemm_4x16::pack_B(dt_float16* out, const dt_float16* in, int ldin, int x0, + int xmax, int k0, int kmax, bool transpose_B) const { + if (transpose_B) { + hgemm_4x16_pack_B_t(reinterpret_cast<__fp16*>(out), + reinterpret_cast(in), ldin, x0, xmax, + k0, kmax); + } else { + hgemm_4x16_pack_B_n(reinterpret_cast<__fp16*>(out), + reinterpret_cast(in), ldin, x0, xmax, + k0, kmax); + } +} + +void hgemm_4x16::kern(const dt_float16* packA, const dt_float16* packB, + size_t M, size_t N, size_t K, dt_float16* C, size_t LDC, + bool is_first_k, const dt_float16*, dt_float16*) const { + megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && + A_dtype.enumv() == C_dtype.enumv() && + A_dtype.enumv() == DTypeEnum::Float16); + MEGDNN_MARK_USED_VAR(A_dtype); + MEGDNN_MARK_USED_VAR(B_dtype); + MEGDNN_MARK_USED_VAR(C_dtype); + + constexpr size_t A_INTERLEAVE = 4; + constexpr size_t B_INTERLEAVE = 16; + const int K16 = K * 16; + const int K4 = K * 4; + + size_t m = 0; + for (; m < M; m += A_INTERLEAVE) { + dt_float16* output = C + (m * LDC); + + size_t n = 0; + const dt_float16* cur_packB = packB; + for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { + kern_4x16(packA, cur_packB, K, output, LDC, is_first_k, + std::min(M - m, 4)); + output += B_INTERLEAVE; + cur_packB += K16; + } + + for (; n < N; n += 4) { + kern_4x4(packA, cur_packB, K, output, LDC, is_first_k, + std::min(M - m, 4), std::min(N - n, 4)); + output += 4; + cur_packB += K4; + } + + packA += K4; + } +} +#endif + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/armv7/matrix_mul/fp16/strategy.h b/dnn/src/armv7/matrix_mul/fp16/strategy.h new file mode 100644 index 00000000..4b10675d --- /dev/null +++ 
b/dnn/src/armv7/matrix_mul/fp16/strategy.h @@ -0,0 +1,29 @@ +/** + * \file dnn/src/armv7/matrix_mul/fp16/strategy.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once +#include "src/fallback/matrix_mul/gemm_common.h" + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +namespace megdnn { +namespace armv7 { +namespace matmul { + +MEGDNN_REG_GEMM_STRATEGY(dt_float16, dt_float16, dt_float16, 4, 16, 1, false, + true, hgemm_4x16); + +MEGDNN_REG_GEMM_STRATEGY_NOPACK(dt_float16, dt_float16, dt_float16, 4, 8, 1, + false, true, gemm_nopack_f16_4x8); + +} // namespace matmul +} // namespace armv7 +} // namespace megdnn +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/armv7/matrix_mul/fp16/strategy_mk8_4x8.cpp b/dnn/src/armv7/matrix_mul/fp16/strategy_mk8_4x8.cpp new file mode 100644 index 00000000..90903f63 --- /dev/null +++ b/dnn/src/armv7/matrix_mul/fp16/strategy_mk8_4x8.cpp @@ -0,0 +1,200 @@ +/** + * \file dnn/src/armv7/matrix_mul/fp16/strategy_mk8_4x8.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "src/armv7/matrix_mul/fp16/strategy.h" +#include "src/armv7/matrix_mul/asm/common.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/utils.h" + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +using namespace megdnn; +using namespace armv7; +using namespace armv7::matmul; + +namespace { + +// Overview of register layout: +// +// A 8x1 cell of Rhs is stored in 16bit in v4-v11 +// A 4x1 cell of Lhs is stored in 16bit in v0-v3 +// A 4x1 block of accumulators is stored in 16bit in v12-v15. +// +// Rhs +--------+ +// | v4[0-7]| +// | v5[0-7]| +// | v6[0-7]| +// | v7[0-7]| +// | v8[0-7]| +// | v9[0-7]| +// |v10[0-7]| +// |v11[0-7]| +// +--------+ +// Lhs +// +--------+ - - - - -+--------+ +// | v0[0-7]| |v12[0-7]| +// | v1[0-7]| |v13[0-7]| +// | v2[0-7]| |v14[0-7]| +// | v3[0-7]| |v15[0-7]| +// +--------+ +--------+--------+ +// Accumulator +void kern_4x8(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, + dt_float16* output) { + //! As each load 64 number from B, but the pos add 48 * 2, so we minus 48 + //! here. 
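+    //! (Clarifying note: each main-loop iteration actually reads 32 fp16
+    //! values from B. The first vld1 post-increments b_ptr by 16 elements and
+    //! the second applies the stride operand directly, so 16 is subtracted
+    //! from LDB, in elements, before it is scaled to a byte stride below.)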
+ LDB = (LDB - 16) * sizeof(dt_float16); + + asm volatile( + "subs %[K], #8\n" + + "vld1.32 {d0, d1, d2, d3}, [%[b_ptr]]!\n" + "vld1.32 {d4, d5, d6, d7}, [%[b_ptr]], %[LDB]\n" + "vld1.32 {d8, d9, d10, d11}, [%[a_ptr]]!\n" + + "vmul.f16 q12, q4, d0[0]\n" + "vmul.f16 q13, q4, d2[0]\n" + "vmul.f16 q14, q4, d4[0]\n" + "vmul.f16 q15, q4, d6[0]\n" + + "vld1.32 {d12, d13, d14, d15}, [%[a_ptr]]!\n" + "vmla.f16 q12, q5, d0[1]\n" + "vmla.f16 q13, q5, d2[1]\n" + "vmla.f16 q14, q5, d4[1]\n" + "vmla.f16 q15, q5, d6[1]\n" + + "vld1.32 {d16, d17, d18, d19}, [%[a_ptr]]!\n" + "vmla.f16 q12, q6, d0[2]\n" + "vmla.f16 q13, q6, d2[2]\n" + "vmla.f16 q14, q6, d4[2]\n" + "vmla.f16 q15, q6, d6[2]\n" + + "vmla.f16 q12, q7, d0[3]\n" + "vmla.f16 q13, q7, d2[3]\n" + "vmla.f16 q14, q7, d4[3]\n" + "vmla.f16 q15, q7, d6[3]\n" + + "vld1.32 {d20, d21, d22, d23}, [%[a_ptr]]!\n" + "vmla.f16 q12, q8, d1[0]\n" + "vmla.f16 q13, q8, d3[0]\n" + "vmla.f16 q14, q8, d5[0]\n" + "vmla.f16 q15, q8, d7[0]\n" + + "vmla.f16 q12, q9, d1[1]\n" + "vmla.f16 q13, q9, d3[1]\n" + "vmla.f16 q14, q9, d5[1]\n" + "vmla.f16 q15, q9, d7[1]\n" + + "vmla.f16 q12, q10, d1[2]\n" + "vmla.f16 q13, q10, d3[2]\n" + "vmla.f16 q14, q10, d5[2]\n" + "vmla.f16 q15, q10, d7[2]\n" + + "vmla.f16 q12, q11, d1[3]\n" + "vmla.f16 q13, q11, d3[3]\n" + "vmla.f16 q14, q11, d5[3]\n" + "vmla.f16 q15, q11, d7[3]\n" + + "beq 2f\n" + + "1:\n" + "vld1.32 {d0, d1, d2, d3}, [%[b_ptr]]!\n" + "vld1.32 {d4, d5, d6, d7}, [%[b_ptr]], %[LDB]\n" + "vld1.32 {d8, d9, d10, d11}, [%[a_ptr]]!\n" + + "vmla.f16 q12, q4, d0[0]\n" + "vmla.f16 q13, q4, d2[0]\n" + "vmla.f16 q14, q4, d4[0]\n" + "vmla.f16 q15, q4, d6[0]\n" + + "vld1.32 {d12, d13, d14, d15}, [%[a_ptr]]!\n" + "vmla.f16 q12, q5, d0[1]\n" + "vmla.f16 q13, q5, d2[1]\n" + "vmla.f16 q14, q5, d4[1]\n" + "vmla.f16 q15, q5, d6[1]\n" + + "vld1.32 {d16, d17, d18, d19}, [%[a_ptr]]!\n" + "vmla.f16 q12, q6, d0[2]\n" + "vmla.f16 q13, q6, d2[2]\n" + "vmla.f16 q14, q6, d4[2]\n" + "vmla.f16 q15, q6, d6[2]\n" + + "vmla.f16 q12, q7, d0[3]\n" + "vmla.f16 q13, q7, d2[3]\n" + "vmla.f16 q14, q7, d4[3]\n" + "vmla.f16 q15, q7, d6[3]\n" + + "vld1.32 {d20, d21, d22, d23}, [%[a_ptr]]!\n" + "vmla.f16 q12, q8, d1[0]\n" + "vmla.f16 q13, q8, d3[0]\n" + "vmla.f16 q14, q8, d5[0]\n" + "vmla.f16 q15, q8, d7[0]\n" + + "vmla.f16 q12, q9, d1[1]\n" + "vmla.f16 q13, q9, d3[1]\n" + "vmla.f16 q14, q9, d5[1]\n" + "vmla.f16 q15, q9, d7[1]\n" + + "vmla.f16 q12, q10, d1[2]\n" + "vmla.f16 q13, q10, d3[2]\n" + "vmla.f16 q14, q10, d5[2]\n" + "vmla.f16 q15, q10, d7[2]\n" + + "vmla.f16 q12, q11, d1[3]\n" + "vmla.f16 q13, q11, d3[3]\n" + "vmla.f16 q14, q11, d5[3]\n" + "vmla.f16 q15, q11, d7[3]\n" + + "subs %[K], #8\n" + "bne 1b\n" + + "2:\n" + "vst1.32 {d24, d25, d26, d27}, [%[output]]!\n" + "vst1.32 {d28, d29, d30, d31}, [%[output]]!\n" + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [output] "+r"(output), [LDB] "+r"(LDB) + : + : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", + "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", + "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28", + "d29", "d30", "d31", "cc", "memory"); +} + +} // anonymous namespace + +MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(gemm_nopack_f16_4x8); + +void gemm_nopack_f16_4x8::kern(const dt_float16* A, size_t LDA, + const dt_float16* B, size_t LDB, dt_float16* C, + size_t LDC, size_t M, size_t K, size_t N, + const dt_float16*, void*, bool trA, + bool trB) const { + constexpr static size_t MB = 8; + constexpr static size_t KB = 8; + constexpr static size_t NB = 4; + 
constexpr static size_t CALCBLK = 4; + + megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0 && N % CALCBLK == 0); + + //! (m/8, k/8, 8, 8) * (k/8, n, 8) = (m/8, n, 8) + for (size_t m = 0; m < M; m += MB) { + dt_float16* output = C + (m / MB) * LDC; + const dt_float16* cur_B = B; + for (size_t n = 0; n < N; n += NB) { + kern_4x8(A, cur_B, LDB, K, output); + cur_B += KB * NB; + output += MB * NB; + } + A += LDA; + } +} + +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/armv7/matrix_mul/fp32/strategy.h b/dnn/src/armv7/matrix_mul/fp32/strategy.h new file mode 100644 index 00000000..9b7b0bae --- /dev/null +++ b/dnn/src/armv7/matrix_mul/fp32/strategy.h @@ -0,0 +1,28 @@ +/** + * \file dnn/src/armv7/matrix_mul/fp32/strategy.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once +#include "src/fallback/matrix_mul/gemm_common.h" + +namespace megdnn { +namespace armv7 { +namespace matmul { + +MEGDNN_REG_GEMM_STRATEGY(float, float, float, 4, 12, 1, false, true, + sgemm_4x12); + +MEGDNN_REG_GEMM_STRATEGY_NOPACK(float, float, float, 4, 8, 1, false, true, + sgemm_nopack_4x8); + +} // namespace matmul +} // namespace armv7 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/armv7/matrix_mul/fp32/strategy_4x12.cpp b/dnn/src/armv7/matrix_mul/fp32/strategy_4x12.cpp new file mode 100644 index 00000000..ebfa1a4e --- /dev/null +++ b/dnn/src/armv7/matrix_mul/fp32/strategy_4x12.cpp @@ -0,0 +1,757 @@ +/** + * \file dnn/src/armv7/matrix_mul/fp32/strategy_4x12.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "src/armv7/matrix_mul/fp32/strategy.h" +#include "src/armv7/matrix_mul/asm/common.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/utils.h" + +using namespace megdnn; +using namespace armv7; +using namespace armv7::matmul; + +namespace { + +// Overview of register layout: +// +// A 1x12 cell of Rhs is stored in 32bit in q1-q3 +// A 4x1 cell of Lhs is stored in 32bit in q0 +// A 4x12 block of accumulators is stored in 32bit in q4-q15.
+// +// +--------+--------+--------+ +// | v1[0-3]| v2[0-3]| v3[0-3]| +// Rhs +--------+--------+--------+ +// +// | | | | +// +// Lhs | | | | +// +// +--+ - - - - +--------+--------+--------+ +// |v0| | v4[0-3]| v5[0-3]| v6[0-3]| +// |v0| | v7[0-3]| v8[0-3]| v9[0-3]| +// |v0| |v10[0-3]|v11[0-3]|v12[0-3]| +// |v0| |v13[0-3]|v14[0-3]|v15[0-3]| +// +--+ - - - - +--------+--------+--------+ +// +// Accumulator +void kern_4x12(const float* packA, const float* packB, int K, + float* output, int LDC, bool is_first_k, int m_remain) { + const float* a_ptr = packA; + const float* b_ptr = packB; + int oddk = (K & 1); + K = ((K + 1) / 2) - 1; + + LDC = LDC * sizeof(float); + register float* outptr asm("r0") = reinterpret_cast(output); + +// clang-format off +#define LOAD_LINE(d0, d1, d2, d3, d4, d5, n) \ + "cmp r10, #0\n" \ + "beq 100f\n" \ + "mov r9, r" n "\n" \ + "vld1.32 {d" d0 ",d" d1 ",d" d2 ",d" d3 "}, [r9]!\n" \ + "vld1.32 {d" d4 ",d" d5 "}, [r9]\n" \ + "subs r10, r10, #1\n" + +#define LOAD_C \ + "mov r10, %[m_remain]\n" \ + LOAD_LINE("8", "9", "10", "11", "12", "13", "0") \ + LOAD_LINE("14", "15", "16", "17", "18", "19", "1") \ + LOAD_LINE("20", "21", "22", "23", "24", "25", "2") \ + LOAD_LINE("26", "27", "28", "29", "30", "31", "3") \ + "100:\n" + +#define STORE_LINE(d0, d1, d2, d3, d4, d5, n) \ + "cmp r10, #0\n" \ + "beq 101f\n" \ + "mov r9, r" n "\n" \ + "vst1.32 {d" d0 ",d" d1 ",d" d2 ",d" d3 "}, [r9]!\n" \ + "vst1.32 {d" d4 ",d" d5 "}, [r9]\n" \ + "subs r10, r10, #1\n" + +#define STORE_C \ + "mov r10, %[m_remain]\n" \ + STORE_LINE("8", "9", "10", "11", "12", "13", "0") \ + STORE_LINE("14", "15", "16", "17", "18", "19", "1") \ + STORE_LINE("20", "21", "22", "23", "24", "25", "2") \ + STORE_LINE("26", "27", "28", "29", "30", "31", "3") \ + "101:\n" + // clang-format on + + asm volatile( + // load accumulator C + "add r1, r0, %[LDC]\n" + "add r2, r1, %[LDC]\n" + "add r3, r2, %[LDC]\n" + + "cmp %[is_first_k], #1\n" + "beq 1f\n" LOAD_C + + "b 2f\n" + + "1:\n" + "veor.32 q4, q4, q4\n" + "veor.32 q5, q5, q5\n" + "veor.32 q6, q6, q6\n" + "veor.32 q7, q7, q7\n" + "veor.32 q8, q8, q8\n" + "veor.32 q9, q9, q9\n" + "veor.32 q10, q10, q10\n" + "veor.32 q11, q11, q11\n" + "veor.32 q12, q12, q12\n" + "veor.32 q13, q13, q13\n" + "veor.32 q14, q14, q14\n" + "veor.32 q15, q15, q15\n" + + "2: \n" + "vld1.32 {d2, d3, d4, d5}, [%[b_ptr]]!\n" + "vld1.32 {d6, d7}, [%[b_ptr]]!\n" + + "cmp %[K], #0\n" + "beq 4f\n" + + "3:\n" + "vld1.32 {d0, d1}, [%[a_ptr]]!\n" + "vmla.f32 q4, q1, d0[0]\n" + "vmla.f32 q5, q2, d0[0]\n" + "vmla.f32 q6, q3, d0[0]\n" + "vmla.f32 q7, q1, d0[1]\n" + "vmla.f32 q8, q2, d0[1]\n" + "vmla.f32 q9, q3, d0[1]\n" + "vmla.f32 q10, q1, d1[0]\n" + "vmla.f32 q11, q2, d1[0]\n" + "vmla.f32 q12, q3, d1[0]\n" + "vmla.f32 q13, q1, d1[1]\n" + "vmla.f32 q14, q2, d1[1]\n" + "vmla.f32 q15, q3, d1[1]\n" + + "vld1.32 {d0, d1}, [%[a_ptr]]!\n" + "vld1.32 {d2, d3, d4, d5}, [%[b_ptr]]!\n" + "vld1.32 {d6, d7}, [%[b_ptr]]!\n" + "vmla.f32 q4, q1, d0[0]\n" + "vmla.f32 q5, q2, d0[0]\n" + "vmla.f32 q6, q3, d0[0]\n" + "vmla.f32 q7, q1, d0[1]\n" + "vmla.f32 q8, q2, d0[1]\n" + "vmla.f32 q9, q3, d0[1]\n" + "vmla.f32 q10, q1, d1[0]\n" + "vmla.f32 q11, q2, d1[0]\n" + "vmla.f32 q12, q3, d1[0]\n" + "vmla.f32 q13, q1, d1[1]\n" + "vmla.f32 q14, q2, d1[1]\n" + "vmla.f32 q15, q3, d1[1]\n" + + "vld1.32 {d2, d3, d4, d5}, [%[b_ptr]]!\n" + "vld1.32 {d6, d7}, [%[b_ptr]]!\n" + "subs %[K], #1\n" + "bne 3b\n" + + "4:\n" + "cmp %[oddk], #1\n" + "beq 5f\n" + + // Even tail + "vld1.32 {d0, d1}, [%[a_ptr]]!\n" + "vmla.f32 q4, q1, d0[0]\n" + 
"vmla.f32 q5, q2, d0[0]\n" + "vmla.f32 q6, q3, d0[0]\n" + "vmla.f32 q7, q1, d0[1]\n" + "vmla.f32 q8, q2, d0[1]\n" + "vmla.f32 q9, q3, d0[1]\n" + "vmla.f32 q10, q1, d1[0]\n" + "vmla.f32 q11, q2, d1[0]\n" + "vmla.f32 q12, q3, d1[0]\n" + "vmla.f32 q13, q1, d1[1]\n" + "vmla.f32 q14, q2, d1[1]\n" + "vmla.f32 q15, q3, d1[1]\n" + + "vld1.32 {d0, d1}, [%[a_ptr]]!\n" + "vld1.32 {d2, d3, d4, d5}, [%[b_ptr]]!\n" + "vld1.32 {d6, d7}, [%[b_ptr]]!\n" + "vmla.f32 q4, q1, d0[0]\n" + "vmla.f32 q5, q2, d0[0]\n" + "vmla.f32 q6, q3, d0[0]\n" + "vmla.f32 q7, q1, d0[1]\n" + "vmla.f32 q8, q2, d0[1]\n" + "vmla.f32 q9, q3, d0[1]\n" + "vmla.f32 q10, q1, d1[0]\n" + "vmla.f32 q11, q2, d1[0]\n" + "vmla.f32 q12, q3, d1[0]\n" + "vmla.f32 q13, q1, d1[1]\n" + "vmla.f32 q14, q2, d1[1]\n" + "vmla.f32 q15, q3, d1[1]\n" + "b 6f\n" + + // odd tail + "5:\n" + "vld1.32 {d0, d1}, [%[a_ptr]]!\n" + "vmla.f32 q4, q1, d0[0]\n" + "vmla.f32 q5, q2, d0[0]\n" + "vmla.f32 q6, q3, d0[0]\n" + "vmla.f32 q7, q1, d0[1]\n" + "vmla.f32 q8, q2, d0[1]\n" + "vmla.f32 q9, q3, d0[1]\n" + "vmla.f32 q10, q1, d1[0]\n" + "vmla.f32 q11, q2, d1[0]\n" + "vmla.f32 q12, q3, d1[0]\n" + "vmla.f32 q13, q1, d1[1]\n" + "vmla.f32 q14, q2, d1[1]\n" + "vmla.f32 q15, q3, d1[1]\n" + + "6:\n" STORE_C + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), + [oddk] "+r"(oddk), [m_remain] "+r"(m_remain), + [outptr] "+r"(outptr) + : + : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", + "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", + "d17", "d18", "d19", "d20", "d21", "d22", "d23", "d24", + "d25", "d26", "d27", "d28", "d29", "d30", "d31", "r1", + "r2", "r3", "r9", "r10", "cc", "memory"); + +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +} + +// Overview of register layout: +// +// A 2x4 cell of Rhs is stored in 32bit in q2 - q3 +// A 4x2 cell of Lhs is stored in 32bit in q0 - q1 +// A 4x4 block of accumulators is stored in 32bit in q4-q6 +// +// +--------+ +// | v2[0-3]| +// Rhs +--------+ +// | v3[0-3]| +// +--------+ +// +// | | +// +// Lhs | | +// +// +--+--+ - - - - +--------+ +// |v0|v1| | v4[0-3]| +// |v0|v1| | v5[0-3]| +// |v0|v1| | v6[0-3]| +// |v0|v1| | v7[0-3]| +// +--+--+ - - - - +--------+ +// +// Accumulator +void kern_4x4(const float* packA, const float* packB, int K, float* output, + int LDC, bool is_first_k, int m_remain, int n_remain) { + const float* a_ptr = packA; + const float* b_ptr = packB; + int oddk = (K & 1); + K = ((K + 1) / 2) - 1; + + LDC = LDC * sizeof(float); + register float* outptr asm("r0") = output; + +// clang-format off +#define LOAD_LINE(d0, d1, n) \ + "cmp r10, #0\n" \ + "beq 102f\n" \ + "cmp %[n_remain], #4\n" \ + "blt 100" n "f\n" \ + "vld1.32 {d" d0 ", d" d1 "}, [r" n "]\n" \ + "b 101" n "f\n" \ + "100" n ":\n" \ + "cmp %[n_remain], #0\n" \ + "beq 101" n "f\n" \ + "vld1.32 {d" d0 "[0]}, [r" n "]!\n" \ + "cmp %[n_remain], #1\n" \ + "beq 101" n "f\n" \ + "vld1.32 {d" d0 "[1]}, [r" n "]!\n" \ + "cmp %[n_remain], #2\n" \ + "beq 101" n "f\n" \ + "vld1.32 {d" d1 "[0]}, [r" n "]!\n" \ + "101" n ":\n" \ + "subs r10, r10, #1\n" + +#define LOAD_C \ + "mov r10, %[m_remain]\n" \ + LOAD_LINE("8", "9", "0") \ + LOAD_LINE("10", "11", "1") \ + LOAD_LINE("12", "13", "2") \ + LOAD_LINE("14", "15", "3") \ + "102:\n" + +#define STORE_LINE(d0, d1, n) \ + "cmp r10, #0 \n" \ + "beq 105f\n" \ + "cmp %[n_remain], #4\n" \ + "blt 103" n "f\n" \ + "vst1.32 {d" d0 ", d" d1 "}, [r" n " ]!\n" \ + "b 104" n "f\n" \ + "103" n ":\n" \ + "cmp %[n_remain], #0\n" \ + "beq 104" n 
"f\n" \ + "vst1.32 {d" d0 "[0]}, [r" n "]!\n" \ + "cmp %[n_remain], #1\n" \ + "beq 104" n "f\n" \ + "vst1.32 {d" d0 "[1]}, [r" n "]!\n" \ + "cmp %[n_remain], #2\n" \ + "beq 104" n "f\n" \ + "vst1.32 {d" d1 "[0]}, [r" n "]!\n" \ + "104" n ":\n" \ + "subs r10, r10, #1\n" + + +#define STORE_C \ + "mov r10, %[m_remain]\n" \ + STORE_LINE("8", "9", "0") \ + STORE_LINE("10", "11", "1") \ + STORE_LINE("12", "13", "2") \ + STORE_LINE("14", "15", "3") \ + "105:\n" + // clang-format on + + asm volatile( + // load accumulator C + "add r1, r0, %[LDC]\n" + "add r2, r1, %[LDC]\n" + "add r3, r2, %[LDC]\n" + + "cmp %[is_first_k], #1\n" + "beq 1f\n" LOAD_C + + "b 2f\n" + + "1:\n" + "veor.32 q4, q4, q4\n" + "veor.32 q5, q5, q5\n" + "veor.32 q6, q6, q6\n" + "veor.32 q7, q7, q7\n" + + "2: \n" + "vld1.32 {d0, d1}, [%[a_ptr]]!\n" + "vld1.32 {d4, d5}, [%[b_ptr]]!\n" + "cmp %[K], #0\n" + "beq 4f\n" + + "3:\n" + "vld1.32 {d2, d3}, [%[a_ptr]]!\n" + "vld1.32 {d6, d7}, [%[b_ptr]]!\n" + "vmla.f32 q4, q2, d0[0]\n" + "vmla.f32 q5, q2, d0[1]\n" + "vmla.f32 q6, q2, d1[0]\n" + "vmla.f32 q7, q2, d1[1]\n" + + "vld1.32 {d0, d1}, [%[a_ptr]]!\n" + "vld1.32 {d4, d5}, [%[b_ptr]]!\n" + "vmla.f32 q4, q3, d2[0]\n" + "vmla.f32 q5, q3, d2[1]\n" + "vmla.f32 q6, q3, d3[0]\n" + "vmla.f32 q7, q3, d3[1]\n" + + "subs %[K], #1\n" + "bne 3b\n" + + "4:\n" + "cmp %[oddk], #1\n" + "beq 5f\n" + + // Even tail + "vld1.32 {d2, d3}, [%[a_ptr]]!\n" + "vld1.32 {d6, d7}, [%[b_ptr]]!\n" + "vmla.f32 q4, q2, d0[0]\n" + "vmla.f32 q5, q2, d0[1]\n" + "vmla.f32 q6, q2, d1[0]\n" + "vmla.f32 q7, q2, d1[1]\n" + + "vmla.f32 q4, q3, d2[0]\n" + "vmla.f32 q5, q3, d2[1]\n" + "vmla.f32 q6, q3, d3[0]\n" + "vmla.f32 q7, q3, d3[1]\n" + + "b 6f\n" + + // odd tail + "5:\n" + "vmla.f32 q4, q2, d0[0]\n" + "vmla.f32 q5, q2, d0[1]\n" + "vmla.f32 q6, q2, d1[0]\n" + "vmla.f32 q7, q2, d1[1]\n" + + "6:\n" STORE_C + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), + [oddk] "+r"(oddk), [m_remain] "+r"(m_remain), + [n_remain] "+r"(n_remain), [outptr] "+r"(outptr) + : + : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", + "d9", "d10", "d11", "d12", "d13", "d14", "d15", "r1", + "r2", "r3", "r10", "cc", "memory"); +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +} + +void sgemm_4x12_pack_A_n(float * outptr, const float * inptr, int ldin, int y0, + int ymax, int k0, int kmax) { + float zerobuff[4]; + std::memset(zerobuff, 0, sizeof(float) * 4); + + int y = y0; + for (; y + 3 < ymax; y += 4) { + const float* inptr0 = inptr + y * ldin + k0; + const float* inptr1 = inptr0 + ldin; + const float* inptr2 = inptr1 + ldin; + const float* inptr3 = inptr2 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + + int K = (kmax - k0); + for (; K > 3; K -= 4) { + transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr); + } + + interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 1, K); + } + + for (; y < ymax; y += 4) { + const float* inptr0 = inptr + y * ldin + k0; + const float* inptr1 = inptr0 + ldin; + const float* inptr2 = inptr1 + ldin; + const float* inptr3 = inptr2 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + + int K = (kmax - k0); + for (; K > 3; K -= 4) { + if ((y + 3) >= ymax) { + switch ((y + 3) - ymax) { + /* Everything falls through in here */ + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + 
transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr); + } + + if (K > 0) { + if ((y + 3) >= ymax) { + switch ((y + 3) - ymax) { + /* Everything falls through in here */ + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 1, K); + } + } +} + +void sgemm_4x12_pack_A_t(float* out, const float* in, int ldin, int x0, + int xmax, int k0, int kmax) { + int ksize = kmax - k0; + int ksize4 = (ksize << 2); + float* outptr_base = out; + + int k = k0; + for (; k + 3 < kmax; k += 4) { + const float* inptr = in + k * ldin + x0; + const float* inptr1 = inptr + ldin; + const float* inptr2 = inptr1 + ldin; + const float* inptr3 = inptr2 + ldin; + + prefetch_3x(inptr); + prefetch_3x(inptr1); + prefetch_3x(inptr2); + prefetch_3x(inptr3); + + int x = x0; + auto outptr = outptr_base; + for (; x + 4 <= xmax; x += 4) { + auto outptr_interleave = outptr; + interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, + outptr_interleave); + outptr += ksize4; + } + + if (x < xmax) { + interleave_4(inptr, inptr1, inptr2, inptr3, outptr, 4, xmax - x); + } + + outptr_base += 4 * 4; + } + + for (; k < kmax; k++) { + const float* inptr = in + k * ldin + x0; + prefetch_3x(inptr); + int x = x0; + auto outptr = outptr_base; + for (; x + 4 <= xmax; x += 4) { + auto outptr_interleave = outptr; + interleave_1x4_1_s(inptr, outptr_interleave); + outptr += ksize4; + } + + if (x < xmax) { + interleave_1(inptr, outptr, 4, xmax - x); + } + + outptr_base += 4; + } +} + +void sgemm_4x12_pack_B_n(float* out, const float* in, int ldin, + int x0, int xmax, int k0, int kmax) { + int ksize = kmax - k0; + int ksize12 = ksize * 12; + int ksize4 = (ksize << 2); + float* outptr_base = out; + float* outptr_base4 = outptr_base + (xmax - x0) / 12 * ksize12; + + int k = k0; + for (; k + 3 < kmax; k += 4) { + const float* inptr = in + k * ldin + x0; + const float* inptr1 = inptr + ldin; + const float* inptr2 = inptr1 + ldin; + const float* inptr3 = inptr2 + ldin; + + prefetch_3x(inptr); + prefetch_3x(inptr1); + prefetch_3x(inptr2); + prefetch_3x(inptr3); + + int x = x0; + auto outptr = outptr_base; + for (; x + 12 <= xmax; x += 12) { + auto outptr_interleave = outptr; + interleave_4x12_1_s(inptr, inptr1, inptr2, inptr3, + outptr_interleave); + outptr += ksize12; + } + outptr = outptr_base4; + for (; x + 4 <= xmax; x += 4) { + auto outptr_interleave = outptr; + interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, + outptr_interleave); + outptr += ksize4; + } + + if (x < xmax) { + interleave_4(inptr, inptr1, inptr2, inptr3, outptr, 4, xmax - x); + } + + outptr_base += 12 * 4; + outptr_base4 += 4 * 4; + } + + for (; k < kmax; k++) { + const float* inptr = in + k * ldin + x0; + prefetch_3x(inptr); + int x = x0; + auto outptr = outptr_base; + for (; x + 12 <= xmax; x += 12) { + auto outptr_interleave = outptr; + interleave_1x12_1_s(inptr, outptr_interleave); + outptr += ksize12; + } + outptr = outptr_base4; + for (; x + 4 <= xmax; x += 4) { + auto outptr_interleave = outptr; + interleave_1x4_1_s(inptr, outptr_interleave); + outptr += ksize4; + } + + if (x < xmax) { + interleave_1(inptr, outptr, 4, xmax - x); + } + + outptr_base += 12; + outptr_base4 += 4; + } +} + +void sgemm_4x12_pack_B_t(float* out, const float* in, int ldin, + int y0, int ymax, int k0, int kmax) { + float* outptr = out; + const float* inptr = in; + float zerobuff[4]; + std::memset(zerobuff, 0, sizeof(float) * 4); + int K12 = 12 * (kmax - k0); + + int 
y = y0; + + for (; y + 12 <= ymax; y += 12) { + int yi = y; + for (; yi < y + 12; yi += 4) { + const float* inptr0 = inptr + yi * ldin + k0; + const float* inptr1 = inptr0 + ldin; + const float* inptr2 = inptr1 + ldin; + const float* inptr3 = inptr2 + ldin; + float* outptr_inner = outptr + yi - y; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + + int x = (kmax - k0); + for (; x > 3; x -= 4) { + transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr_inner, + 48); + } + for (; x > 0; x--) { + *outptr_inner++ = *inptr0++; + *outptr_inner++ = *inptr1++; + *outptr_inner++ = *inptr2++; + *outptr_inner++ = *inptr3++; + outptr_inner += 8; + } + } + outptr += K12; + } + + for (; y < ymax; y += 4) { + const float* inptr0 = inptr + y * ldin + k0; + const float* inptr1 = inptr0 + ldin; + const float* inptr2 = inptr1 + ldin; + const float* inptr3 = inptr2 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + + /* Cope with ragged cases by copying from a buffer of zeroes instead + */ + int x = (kmax - k0); + for (; x > 3; x -= 4) { + if ((y + 3) >= ymax) { + switch ((y + 3) - ymax) { + /* Everything falls through in here */ + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr); + } + + if (x > 0) { + if ((y + 3) >= ymax) { + switch ((y + 3) - ymax) { + /* Everything falls through in here */ + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 1, x); + } + } +} + +} // namespace + +MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_4x12); + +void sgemm_4x12::pack_A(float* out, const float* in, int ldin, int y0, + int ymax, int k0, int kmax, bool transpose_A) const { + if (transpose_A) { + sgemm_4x12_pack_A_t(out, in, ldin, y0, ymax, k0, kmax); + } else { + sgemm_4x12_pack_A_n(out, in, ldin, y0, ymax, k0, kmax); + } +} + +void sgemm_4x12::pack_B(float* out, const float* in, int ldin, int x0, int xmax, + int k0, int kmax, bool transpose_B) const { + if (transpose_B) { + sgemm_4x12_pack_B_t(out, in, ldin, x0, xmax, k0, kmax); + } else { + sgemm_4x12_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); + } +} + +void sgemm_4x12::kern(const float* packA, const float* packB, + size_t M, size_t N, size_t K, float* C, size_t LDC, + bool is_first_k, const float*, float*) const { + megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && + A_dtype.enumv() == C_dtype.enumv() && + A_dtype.enumv() == DTypeEnum::Float32); + MEGDNN_MARK_USED_VAR(A_dtype); + MEGDNN_MARK_USED_VAR(B_dtype); + MEGDNN_MARK_USED_VAR(C_dtype); + + constexpr size_t A_INTERLEAVE = 4; + constexpr size_t B_INTERLEAVE = 12; + const int K12 = K * 12; + const int K4 = K * 4; + + size_t m = 0; + for (; m < M; m += A_INTERLEAVE) { + float* output = C + (m * LDC); + + size_t n = 0; + const float* cur_packB = packB; + for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { + kern_4x12(packA, cur_packB, K, output, LDC, is_first_k, + std::min(M - m, 4)); + output += B_INTERLEAVE; + cur_packB += K12; + } + + for (; n < N; n += 4) { + kern_4x4(packA, cur_packB, K, output, LDC, is_first_k, + std::min(M - m, 4), std::min(N - n, 4)); + output += 4; + cur_packB += K4; + } + + packA += K4; + } +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/armv7/matrix_mul/fp32/strategy_mk4_4x8.cpp 
b/dnn/src/armv7/matrix_mul/fp32/strategy_mk4_4x8.cpp new file mode 100644 index 00000000..096e53c9 --- /dev/null +++ b/dnn/src/armv7/matrix_mul/fp32/strategy_mk4_4x8.cpp @@ -0,0 +1,292 @@ +/** + * \file dnn/src/armv7/matrix_mul/fp32/strategy_mk4_4x8.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "src/armv7/matrix_mul/fp32/strategy.h" +#include "src/armv7/matrix_mul/asm/common.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/utils.h" + +using namespace megdnn; +using namespace armv7; +using namespace armv7::matmul; + +namespace { + +// Overview of register layout: +// +// A 8x4 cell of Rhs is stored in 32bit in q0-q3, load 4 register each time +// A 4x4 cell of Lhs is stored in 32bit in q4-q7 +// A 4x8 block of accumulators is stored in 32bit in q8-q11. +// +// +--------+ +// | q0-q3 | +// Rhs +--------+ +// +// | | +// +// Lhs | | +// +// +---+ - - - - +--------+ +// | q4| | q8-11 | +// | q5| | | +// | q6| | | +// | q7| | | +// +---+ - - - - +--------+ +// +// Accumulator +void kern_4x4(const float* A, const float* B, size_t LDB, size_t K, float* C) { + //! as each load 16 number from B, and pos add 16 * 4, we should minus it + //! before we add stride + LDB = (LDB - 16) * sizeof(float); + asm volatile( + "subs %[K], %[K], #4\n" + + "vld1.32 {d8-d11}, [%[A]]!\n" + "vld1.32 {d12-d15}, [%[A]]!\n" + + "vld1.32 {d0-d3}, [%[B]]!\n" + "vld1.32 {d4-d7}, [%[B]]!\n" + + "vmul.f32 q8, q4, d0[0]\n" + "vmul.f32 q9, q4, d2[0]\n" + "vmul.f32 q10, q4, d4[0]\n" + "vmul.f32 q11, q4, d6[0]\n" + + "vmla.f32 q8, q5, d0[1]\n" + "vmla.f32 q9, q5, d2[1]\n" + "vmla.f32 q10, q5, d4[1]\n" + "vmla.f32 q11, q5, d6[1]\n" + + "beq 2f\n" + + "1:\n" + + "vld1.32 {d8-d11}, [%[A]]!\n" + + "vmla.f32 q8, q6, d1[0]\n" + "vmla.f32 q9, q6, d3[0]\n" + "vmla.f32 q10, q6, d5[0]\n" + "vmla.f32 q11, q6, d7[0]\n" + + "add %[B], %[B], %[LDB]\n" + + "vmla.f32 q8, q7, d1[1]\n" + "vmla.f32 q9, q7, d3[1]\n" + "vld1.32 {d0-d1}, [%[B]]!\n" + "vmla.f32 q10, q7, d5[1]\n" + "vld1.32 {d2-d3}, [%[B]]!\n" + "vmla.f32 q11, q7, d7[1]\n" + "vld1.32 {d4-d5}, [%[B]]!\n" + + "vmla.f32 q8, q4, d0[0]\n" + "vld1.32 {d6-d7}, [%[B]]!\n" + "vmla.f32 q9, q4, d2[0]\n" + "vmla.f32 q10, q4, d4[0]\n" + "vmla.f32 q11, q4, d6[0]\n" + + "vld1.32 {d12-d15}, [%[A]]!\n" + + "vmla.f32 q8, q5, d0[1]\n" + "vmla.f32 q9, q5, d2[1]\n" + "vmla.f32 q10, q5, d4[1]\n" + "vmla.f32 q11, q5, d6[1]\n" + + "subs %[K], %[K], #4\n" + "bne 1b\n" + + "2:\n" + + "vmla.f32 q8, q6, d1[0]\n" + "vmla.f32 q9, q6, d3[0]\n" + "vmla.f32 q10, q6, d5[0]\n" + "vmla.f32 q11, q6, d7[0]\n" + + "vmla.f32 q8, q7, d1[1]\n" + "vmla.f32 q9, q7, d3[1]\n" + "vmla.f32 q10, q7, d5[1]\n" + "vmla.f32 q11, q7, d7[1]\n" + + "vst1.32 {d16, d17, d18, d19}, [%[C]]!\n" + "vst1.32 {d20, d21, d22, d23}, [%[C]]!\n" + + : [A] "+r"(A), [B] "+r"(B), [K] "+r"(K), [C] "+r"(C) + : [LDB] "r"(LDB) + : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", + "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", + "d20", "d21", "d22", "d23", "cc", "memory"); +} + +// Overview of register layout: +// +// A 8x4 cell of Rhs is stored in 32bit in v0-v3, load 4 register each time +// A 4x4 cell of Lhs is stored in 32bit in v4-v7 +// A 4x8 block of 
accumulators is stored in 32bit in q8-q15. +// +// +--------+--------+ +// | v0-v3 | v0-v3 | +// Rhs +--------+--------+ +// +// | | | +// +// Lhs | | | +// +// +---+ - - - - +--------+--------+ +// | v4| | v8-11 | v12-15 | +// | v5| | | | +// | v6| | | | +// | v7| | | | +// +---+ - - - - +--------+--------+ +// +// Accumulator +void kern_4x8(const float* A, const float* B, size_t LDB, size_t K, float* C) { + LDB *= sizeof(float); + //! as each load 32 number from B, the pos add 32 * 4, we should minus it + //! before we add stride + LDB -= 32 * sizeof(float); + asm volatile( + "vld1.32 {d8, d9, d10, d11}, [%[A]]!\n" + "vld1.32 {d12, d13, d14, d15}, [%[A]]!\n" + + "vld1.32 {d0, d1, d2, d3}, [%[B]]!\n" + "vld1.32 {d4, d5, d6, d7}, [%[B]]!\n" + "vmul.f32 q8, q4, d0[0]\n" + "vmla.f32 q8, q5, d0[1]\n" + "vmul.f32 q9, q4, d2[0]\n" + "vmla.f32 q8, q6, d1[0]\n" + "vmla.f32 q9, q5, d2[1]\n" + "vmla.f32 q8, q7, d1[1]\n" + "vmla.f32 q9, q6, d3[0]\n" + "vmla.f32 q9, q7, d3[1]\n" + "vld1.32 {d0, d1, d2, d3}, [%[B]]!\n" + "vmul.f32 q10, q4, d4[0]\n" + "vmla.f32 q10, q5, d4[1]\n" + "vmul.f32 q11, q4, d6[0]\n" + "vmla.f32 q10, q6, d5[0]\n" + "vmla.f32 q11, q5, d6[1]\n" + "vmla.f32 q10, q7, d5[1]\n" + "vmla.f32 q11, q6, d7[0]\n" + "vmla.f32 q11, q7, d7[1]\n" + + "vld1.32 {d4, d5, d6, d7}, [%[B]]!\n" + "vmul.f32 q12, q4, d0[0]\n" + "vmla.f32 q12, q5, d0[1]\n" + "vmul.f32 q13, q4, d2[0]\n" + "vmla.f32 q12, q6, d1[0]\n" + "vmla.f32 q13, q5, d2[1]\n" + "vmla.f32 q12, q7, d1[1]\n" + "vmla.f32 q13, q6, d3[0]\n" + "vmla.f32 q13, q7, d3[1]\n" + "vmul.f32 q14, q4, d4[0]\n" + "vmla.f32 q14, q5, d4[1]\n" + "vmul.f32 q15, q4, d6[0]\n" + "vmla.f32 q14, q6, d5[0]\n" + "vmla.f32 q15, q5, d6[1]\n" + "vmla.f32 q14, q7, d5[1]\n" + "vmla.f32 q15, q6, d7[0]\n" + "vmla.f32 q15, q7, d7[1]\n" + + "add %[B], %[B], %[LDB]\n" + "subs %[K], %[K], #4\n" + "cmp %[K], #0\n" + "beq 2f\n" + + "1:\n" + "vld1.32 {d8, d9, d10, d11}, [%[A]]!\n" + "vld1.32 {d12, d13, d14, d15}, [%[A]]!\n" + + "vld1.32 {d0, d1, d2, d3}, [%[B]]!\n" + "vld1.32 {d4, d5, d6, d7}, [%[B]]!\n" + "vmla.f32 q8, q4, d0[0]\n" + "vmla.f32 q8, q5, d0[1]\n" + "vmla.f32 q9, q4, d2[0]\n" + "vmla.f32 q8, q6, d1[0]\n" + "vmla.f32 q9, q5, d2[1]\n" + "vmla.f32 q8, q7, d1[1]\n" + "vmla.f32 q9, q6, d3[0]\n" + "vmla.f32 q9, q7, d3[1]\n" + "vld1.32 {d0, d1, d2, d3}, [%[B]]!\n" + "vmla.f32 q10, q4, d4[0]\n" + "vmla.f32 q10, q5, d4[1]\n" + "vmla.f32 q11, q4, d6[0]\n" + "vmla.f32 q10, q6, d5[0]\n" + "vmla.f32 q11, q5, d6[1]\n" + "vmla.f32 q10, q7, d5[1]\n" + "vmla.f32 q11, q6, d7[0]\n" + "vmla.f32 q11, q7, d7[1]\n" + + "vld1.32 {d4, d5, d6, d7}, [%[B]]!\n" + "vmla.f32 q12, q4, d0[0]\n" + "vmla.f32 q12, q5, d0[1]\n" + "vmla.f32 q13, q4, d2[0]\n" + "vmla.f32 q12, q6, d1[0]\n" + "vmla.f32 q13, q5, d2[1]\n" + "vmla.f32 q12, q7, d1[1]\n" + "vmla.f32 q13, q6, d3[0]\n" + "vmla.f32 q13, q7, d3[1]\n" + "vmla.f32 q14, q4, d4[0]\n" + "vmla.f32 q14, q5, d4[1]\n" + "vmla.f32 q15, q4, d6[0]\n" + "vmla.f32 q14, q6, d5[0]\n" + "vmla.f32 q15, q5, d6[1]\n" + "vmla.f32 q14, q7, d5[1]\n" + "vmla.f32 q15, q6, d7[0]\n" + "vmla.f32 q15, q7, d7[1]\n" + + "add %[B], %[B], %[LDB]\n" + "subs %[K], %[K], #4\n" + "cmp %[K], #0\n" + "bne 1b\n" + "2:\n" + "vst1.32 {d16, d17, d18, d19}, [%[C]]!\n" + "vst1.32 {d20, d21, d22, d23}, [%[C]]!\n" + "vst1.32 {d24, d25, d26, d27}, [%[C]]!\n" + "vst1.32 {d28, d29, d30, d31}, [%[C]]!\n" + : [A] "+r"(A), [B] "+r"(B), [K] "+r"(K), [C] "+r"(C) + : [LDB] "r"(LDB) + : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", + "d11", "d12", "d13", "d14", "d15", "d16", 
"d17", "d18", "d19", + "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28", + "d29", "d30", "d31", "cc", "memory"); +} + +} // namespace + +MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(sgemm_nopack_4x8); + +void sgemm_nopack_4x8::kern(const float* A, size_t LDA, const float* B, + size_t LDB, float* C, size_t LDC, size_t M, + size_t K, size_t N, const float*, void*, bool trA, + bool trB) const { + constexpr size_t MB = 4; + constexpr size_t KB = 4; + constexpr size_t NB = 8; + constexpr size_t CALCBLK = 4; + + megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0 && N % CALCBLK == 0); + + //! (m/8, k/8, 8, 8) * (k/8, n, 8) = (m/8, n, 8) + for (size_t m = 0; m < M; m += MB) { + float* output = C + (m / MB) * LDC; + const float* cur_B = B; + size_t n = 0; + for (; n + NB - 1 < N; n += NB) { + kern_4x8(A, cur_B, LDB, K, output); + cur_B += KB * NB; + output += MB * NB; + } + if (n < N) { + kern_4x4(A, cur_B, LDB, K, output); + } + A += LDA; + } +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/armv7/matrix_mul/int16x16x32/kernel_12x4x1.h b/dnn/src/armv7/matrix_mul/int16x16x32/kernel_12x4x1.h new file mode 100644 index 00000000..9cd6efe8 --- /dev/null +++ b/dnn/src/armv7/matrix_mul/int16x16x32/kernel_12x4x1.h @@ -0,0 +1,1012 @@ +/** + * \file dnn/src/armv7/matrix_mul/int16x16x32/kernel_12x4x1.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/armv7/matrix_mul/asm/common.h" + +namespace megdnn { +namespace armv7 { +namespace matmul_12x4x1 { +/** + * Overview of register layout: + * + * A 12x4 cell of Rhs is stored in 16bit in d3 + * A 1x4 cell of Lhs is stored in 16bit in d0 d1 d2 + * A 4x4 block of accumulators is stored in 32bit in q4-q15 + * + * +--------+ + * | d3[0-3]| + * Rhs +--------+ + * Lhs | | + * + * +--------+ - - - - +--------- + * |d0[0]| | q4[0-3]| + * |d0[1]| | q5[0-3]| + * |d0[2]| | q6[0-3]| + * |d0[3]| | q7[0-3]| + * |d1[0]| | q8[0-3]| + * |d1[1]| | q9[0-3]| + * |d1[2]| | q10[0-3]| + * |d1[3]| | q11[0-3]| + * |d2[0]| | q12[0-3]| + * |d2[1]| | q13[0-3]| + * |d2[2]| | q14[0-3]| + * |d2[3]| | q15[0-3]| + * +--------+ - - - - +--------- + * + * Accumulator + */ +static void kern_12x4(const int16_t* packA, const int16_t* packB, int K, + int32_t* output, int LDC, bool is_first_k) { + const int16_t* a_ptr = packA; + const int16_t* b_ptr = packB; + + int asmLDC = LDC * sizeof(int32_t); + int oddLDC = LDC * 2; + int32_t* outptr_row0 = output; + int32_t* outptr_row2 = outptr_row0 + oddLDC; + int32_t* outptr_row4 = outptr_row2 + oddLDC; + int32_t* outptr_row6 = outptr_row4 + oddLDC; + int32_t* outptr_row8 = outptr_row6 + oddLDC; + int32_t* outptr_row10 = outptr_row8 + oddLDC; + asm volatile( + "cmp %[is_first_k], #1\n" + "beq 1f\n" + "vld1.32 {d8,d9} ,[%[outptr_row0]]\n" + "vld1.32 {d12,d13},[%[outptr_row2]]\n" + "vld1.32 {d16,d17},[%[outptr_row4]]\n" + "vld1.32 {d20,d21},[%[outptr_row6]]\n" + "vld1.32 {d24,d25},[%[outptr_row8]]\n" + "vld1.32 {d28,d29},[%[outptr_row10]]\n" + "add %[outptr_row0],%[outptr_row0],%[asmLDC]\n" + "add %[outptr_row2],%[outptr_row2],%[asmLDC]\n" + "add %[outptr_row4],%[outptr_row4],%[asmLDC]\n" + "add %[outptr_row6],%[outptr_row6],%[asmLDC]\n" + "add 
%[outptr_row8],%[outptr_row8],%[asmLDC]\n" + "add %[outptr_row10],%[outptr_row10],%[asmLDC]\n" + "vld1.16 {d10,d11},[%[outptr_row0]]\n" + "vld1.16 {d14,d15},[%[outptr_row2]]\n" + "vld1.16 {d18,d19},[%[outptr_row4]]\n" + "vld1.16 {d22,d23},[%[outptr_row6]]\n" + "vld1.16 {d26,d27},[%[outptr_row8]]\n" + "vld1.16 {d30,d31},[%[outptr_row10]]\n" + "2:\n" + "pld [%[b_ptr],#64]\n" + "vld1.16 {d3},[%[b_ptr]]!\n" + "pld [%[a_ptr],#196]\n" + "vld1.16 {d0},[%[a_ptr]]!\n" + "vld1.16 {d1},[%[a_ptr]]!\n" + "vld1.16 {d2},[%[a_ptr]]!\n" + "vmlal.s16 q4,d3,d0[0]\n" + "vmlal.s16 q5,d3,d0[1]\n" + "vmlal.s16 q6,d3,d0[2]\n" + "vmlal.s16 q7,d3,d0[3]\n" + + "vmlal.s16 q8 ,d3,d1[0]\n" + "vmlal.s16 q9 ,d3,d1[1]\n" + "vmlal.s16 q10,d3,d1[2]\n" + "vmlal.s16 q11,d3,d1[3]\n" + + "vmlal.s16 q12,d3,d2[0]\n" + "vmlal.s16 q13,d3,d2[1]\n" + "vmlal.s16 q14,d3,d2[2]\n" + "vmlal.s16 q15,d3,d2[3]\n" + "subs %[K], %[K], #1\n" + "bne 2b\n" + "vst1.32 {d10,d11},[%[outptr_row0]]\n" + "vst1.32 {d14,d15},[%[outptr_row2]]\n" + "vst1.32 {d18,d19},[%[outptr_row4]]\n" + "vst1.32 {d22,d23},[%[outptr_row6]]\n" + "vst1.32 {d26,d27},[%[outptr_row8]]\n" + "vst1.32 {d30,d31},[%[outptr_row10]]\n" + "sub %[outptr_row0],%[outptr_row0],%[asmLDC]\n" + "sub %[outptr_row2],%[outptr_row2],%[asmLDC]\n" + "sub %[outptr_row4],%[outptr_row4],%[asmLDC]\n" + "sub %[outptr_row6],%[outptr_row6],%[asmLDC]\n" + "sub %[outptr_row8],%[outptr_row8],%[asmLDC]\n" + "sub %[outptr_row10],%[outptr_row10],%[asmLDC]\n" + "vst1.32 {d8,d9} ,[%[outptr_row0]]\n" + "vst1.32 {d12,d13},[%[outptr_row2]]\n" + "vst1.32 {d16,d17},[%[outptr_row4]]\n" + "vst1.32 {d20,d21},[%[outptr_row6]]\n" + "vst1.32 {d24,d25},[%[outptr_row8]]\n" + "vst1.32 {d28,d29},[%[outptr_row10]]\n" + "b 4f \n" + "1:\n" // handle fisrt reduce 1 cmp + "veor.s32 q4, q4, q4\n" + "veor.s32 q5, q5, q5\n" + "veor.s32 q6, q6, q6\n" + "veor.s32 q7, q7, q7\n" + + "veor.s32 q8, q8, q8\n" + "veor.s32 q9, q9, q9\n" + "veor.s32 q10, q10, q10\n" + "veor.s32 q11, q11, q11\n" + + "veor.s32 q12, q12, q12\n" + "veor.s32 q13, q13, q13\n" + "veor.s32 q14, q14, q14\n" + "veor.s32 q15, q15, q15\n" + "3:\n" + "pld [%[b_ptr],#64]\n" + "vld1.16 {d3},[%[b_ptr]]!\n" + "pld [%[a_ptr],#196]\n" + "vld1.16 {d0},[%[a_ptr]]!\n" + "vld1.16 {d1},[%[a_ptr]]!\n" + "vld1.16 {d2},[%[a_ptr]]!\n" + "vmlal.s16 q4,d3,d0[0]\n" + "vmlal.s16 q5,d3,d0[1]\n" + "vmlal.s16 q6,d3,d0[2]\n" + "vmlal.s16 q7,d3,d0[3]\n" + + "vmlal.s16 q8 ,d3,d1[0]\n" + "vmlal.s16 q9 ,d3,d1[1]\n" + "vmlal.s16 q10,d3,d1[2]\n" + "vmlal.s16 q11,d3,d1[3]\n" + + "vmlal.s16 q12,d3,d2[0]\n" + "vmlal.s16 q13,d3,d2[1]\n" + "vmlal.s16 q14,d3,d2[2]\n" + "vmlal.s16 q15,d3,d2[3]\n" + "subs %[K], %[K], #1\n" + "bne 3b\n" + "vst1.32 {d8,d9} ,[%[outptr_row0]]\n" + "vst1.32 {d12,d13},[%[outptr_row2]]\n" + "vst1.32 {d16,d17},[%[outptr_row4]]\n" + "vst1.32 {d20,d21},[%[outptr_row6]]\n" + "vst1.32 {d24,d25},[%[outptr_row8]]\n" + "vst1.32 {d28,d29},[%[outptr_row10]]\n" + "add %[outptr_row0],%[outptr_row0],%[asmLDC]\n" + "add %[outptr_row2],%[outptr_row2],%[asmLDC]\n" + "add %[outptr_row4],%[outptr_row4],%[asmLDC]\n" + "add %[outptr_row6],%[outptr_row6],%[asmLDC]\n" + "add %[outptr_row8],%[outptr_row8],%[asmLDC]\n" + "add %[outptr_row10],%[outptr_row10],%[asmLDC]\n" + "vst1.32 {d10,d11},[%[outptr_row0]]\n" + "vst1.32 {d14,d15},[%[outptr_row2]]\n" + "vst1.32 {d18,d19},[%[outptr_row4]]\n" + "vst1.32 {d22,d23},[%[outptr_row6]]\n" + "vst1.32 {d26,d27},[%[outptr_row8]]\n" + "vst1.32 {d30,d31},[%[outptr_row10]]\n" + "4: \n" + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [asmLDC] "+r"(asmLDC), 
[is_first_k] "+r"(is_first_k), + [outptr_row0] "+r"(outptr_row0), [outptr_row2] "+r"(outptr_row2), + [outptr_row4] "+r"(outptr_row4), [outptr_row6] "+r"(outptr_row6), + [outptr_row8] "+r"(outptr_row8), [outptr_row10] "+r"(outptr_row10) + : + : "d0", "d1", "d2", "d3", "d4", "d5", "d8", "d9", "d10", "d11", + "d12", "d13", "d14", "d15", "d16", "d18", "d20", "d21", "d22", + "d24", "d26", "d28", "d30", "cc", "memory"); +} + +static void kern_12x123(const int16_t* packA, const int16_t* packB, int K, + int32_t* output, int LDC, bool is_first_k, + size_t n_remain) { + const int16_t* a_ptr = packA; + const int16_t* b_ptr = packB; + int asmLDC = LDC * sizeof(int32_t); + int oddLDC = LDC * 2; + int32_t* outptr_row0 = output; + int32_t* outptr_row2 = outptr_row0 + oddLDC; + int32_t* outptr_row4 = outptr_row2 + oddLDC; + int32_t* outptr_row6 = outptr_row4 + oddLDC; + int32_t* outptr_row8 = outptr_row6 + oddLDC; + int32_t* outptr_row10 = outptr_row8 + oddLDC; + + asm volatile( + "cmp %[is_first_k], #1\n" + "beq 1f\n" + "cmp %[n_remain] ,#3 \n" + "beq 5f\n" + "cmp %[n_remain],#2\n" + "beq 6f\n" + "cmp %[n_remain],#1\n" + "beq 7f\n" + "5: \n" + "vld1.32 {d8} ,[%[outptr_row0]]!\n" + "vld1.32 {d9[0]} ,[%[outptr_row0]]\n" + "vld1.32 {d12} ,[%[outptr_row2]]!\n" + "vld1.32 {d13[0]},[%[outptr_row2]]\n" + "vld1.32 {d16} ,[%[outptr_row4]]!\n" + "vld1.32 {d17[0]},[%[outptr_row4]]\n" + "vld1.32 {d20} ,[%[outptr_row6]]!\n" + "vld1.32 {d20[0]},[%[outptr_row6]]\n" + "vld1.32 {d24} ,[%[outptr_row8]]!\n" + "vld1.32 {d25[0]},[%[outptr_row8]]\n" + "vld1.32 {d28} ,[%[outptr_row10]]!\n" + "vld1.32 {d29[0]},[%[outptr_row10]]\n" + "sub %[outptr_row0],%[outptr_row0],#8\n" + "sub %[outptr_row2],%[outptr_row2],#8\n" + "sub %[outptr_row4],%[outptr_row4],#8\n" + "sub %[outptr_row6],%[outptr_row6],#8\n" + "sub %[outptr_row8],%[outptr_row8],#8\n" + "sub %[outptr_row10],%[outptr_row10],#8\n" + "add %[outptr_row0],%[outptr_row0],%[asmLDC]\n" + "add %[outptr_row2],%[outptr_row2],%[asmLDC]\n" + "add %[outptr_row4],%[outptr_row4],%[asmLDC]\n" + "add %[outptr_row6],%[outptr_row6],%[asmLDC]\n" + "add %[outptr_row8],%[outptr_row8],%[asmLDC]\n" + "add %[outptr_row10],%[outptr_row10],%[asmLDC]\n" + "vld1.32 {d10} ,[%[outptr_row0]]!\n" + "vld1.32 {d11[0]},[%[outptr_row0]]\n" + "vld1.32 {d14} ,[%[outptr_row2]]!\n" + "vld1.32 {d15[0]},[%[outptr_row2]]\n" + "vld1.32 {d18} ,[%[outptr_row4]]!\n" + "vld1.32 {d19[0]},[%[outptr_row4]]\n" + "vld1.32 {d22} ,[%[outptr_row6]]!\n" + "vld1.32 {d23[0]},[%[outptr_row6]]\n" + "vld1.32 {d26} ,[%[outptr_row8]]!\n" + "vld1.32 {d27[0]},[%[outptr_row8]]\n" + "vld1.32 {d30} ,[%[outptr_row10]]!\n" + "vld1.32 {d31[0]},[%[outptr_row10]]\n" + "sub %[outptr_row0],%[outptr_row0],#8\n" + "sub %[outptr_row2],%[outptr_row2],#8\n" + "sub %[outptr_row4],%[outptr_row4],#8\n" + "sub %[outptr_row6],%[outptr_row6],#8\n" + "sub %[outptr_row8],%[outptr_row8],#8\n" + "sub %[outptr_row10],%[outptr_row10],#8\n" + "b 2f\n" + "6: \n" + "vld1.32 {d8} ,[%[outptr_row0]]\n" + "vld1.32 {d12},[%[outptr_row2]]\n" + "vld1.32 {d16},[%[outptr_row4]]\n" + "vld1.32 {d20},[%[outptr_row6]]\n" + "vld1.32 {d24},[%[outptr_row8]]\n" + "vld1.32 {d28},[%[outptr_row10]]\n" + "add %[outptr_row0],%[outptr_row0],%[asmLDC]\n" + "add %[outptr_row2],%[outptr_row2],%[asmLDC]\n" + "add %[outptr_row4],%[outptr_row4],%[asmLDC]\n" + "add %[outptr_row6],%[outptr_row6],%[asmLDC]\n" + "add %[outptr_row8],%[outptr_row8],%[asmLDC]\n" + "add %[outptr_row10],%[outptr_row10],%[asmLDC]\n" + "vld1.32 {d10},[%[outptr_row0]]\n" + "vld1.32 {d14},[%[outptr_row2]]\n" + "vld1.32 
{d18},[%[outptr_row4]]\n" + "vld1.32 {d22},[%[outptr_row6]]\n" + "vld1.32 {d26},[%[outptr_row8]]\n" + "vld1.32 {d30},[%[outptr_row10]]\n" + "b 2f\n" + "7: \n" + "vld1.32 {d8[0]} ,[%[outptr_row0]]\n" + "vld1.32 {d12[0]},[%[outptr_row2]]\n" + "vld1.32 {d16[0]},[%[outptr_row4]]\n" + "vld1.32 {d20[0]},[%[outptr_row6]]\n" + "vld1.32 {d24[0]},[%[outptr_row8]]\n" + "vld1.32 {d28[0]},[%[outptr_row10]]\n" + "add %[outptr_row0],%[outptr_row0],%[asmLDC]\n" + "add %[outptr_row2],%[outptr_row2],%[asmLDC]\n" + "add %[outptr_row4],%[outptr_row4],%[asmLDC]\n" + "add %[outptr_row6],%[outptr_row6],%[asmLDC]\n" + "add %[outptr_row8],%[outptr_row8],%[asmLDC]\n" + "add %[outptr_row10],%[outptr_row10],%[asmLDC]\n" + "vld1.32 {d10[0]},[%[outptr_row0]]\n" + "vld1.32 {d14[0]},[%[outptr_row2]]\n" + "vld1.32 {d18[0]},[%[outptr_row4]]\n" + "vld1.32 {d22[0]},[%[outptr_row6]]\n" + "vld1.32 {d26[0]},[%[outptr_row8]]\n" + "vld1.32 {d30[0]},[%[outptr_row10]]\n" + "b 2f \n" + "1:\n" + "veor.s32 q4, q4, q4\n" + "veor.s32 q5, q5, q5\n" + "veor.s32 q6, q6, q6\n" + "veor.s32 q7, q7, q7\n" + + "veor.s32 q8, q8, q8\n" + "veor.s32 q9, q9, q9\n" + "veor.s32 q10, q10, q10\n" + "veor.s32 q11, q11, q11\n" + + "veor.s32 q12, q12, q12\n" + "veor.s32 q13, q13, q13\n" + "veor.s32 q14, q14, q14\n" + "veor.s32 q15, q15, q15\n" + "2:\n" + "pld [%[b_ptr],#16]\n" + "vld1.16 {d3},[%[b_ptr]]!\n" + "pld [%[a_ptr],#196]\n" + "vld1.16 {d0},[%[a_ptr]]!\n" + "vld1.16 {d1},[%[a_ptr]]!\n" + "vld1.16 {d2},[%[a_ptr]]!\n" + "vmlal.s16 q4,d3,d0[0]\n" + "vmlal.s16 q5,d3,d0[1]\n" + "vmlal.s16 q6,d3,d0[2]\n" + "vmlal.s16 q7,d3,d0[3]\n" + + "vmlal.s16 q8 ,d3,d1[0]\n" + "vmlal.s16 q9 ,d3,d1[1]\n" + "vmlal.s16 q10,d3,d1[2]\n" + "vmlal.s16 q11,d3,d1[3]\n" + + "vmlal.s16 q12,d3,d2[0]\n" + "vmlal.s16 q13,d3,d2[1]\n" + "vmlal.s16 q14,d3,d2[2]\n" + "vmlal.s16 q15,d3,d2[3]\n" + "subs %[K], %[K], #1\n" + "bne 2b\n" + "cmp %[is_first_k], #1\n" + "beq 3f\n" + "cmp %[n_remain] ,#3 \n" + "beq 5f\n" + "cmp %[n_remain] ,#2 \n" + "beq 6f\n" + "cmp %[n_remain] ,#1 \n" + "beq 7f\n" + "5: \n" + "vst1.32 {d10} ,[%[outptr_row0]]!\n" + "vst1.32 {d11[0]},[%[outptr_row0]]\n" + "vst1.32 {d14} ,[%[outptr_row2]]!\n" + "vst1.32 {d15[0]},[%[outptr_row2]]\n" + "vst1.32 {d18} ,[%[outptr_row4]]!\n" + "vst1.32 {d19[0]},[%[outptr_row4]]\n" + "vst1.32 {d22} ,[%[outptr_row6]]!\n" + "vst1.32 {d23[0]},[%[outptr_row6]]\n" + "vst1.32 {d26} ,[%[outptr_row8]]!\n" + "vst1.32 {d27[0]},[%[outptr_row8]]\n" + "vst1.32 {d30} ,[%[outptr_row10]]!\n" + "vst1.32 {d31[0]},[%[outptr_row10]]\n" + "sub %[outptr_row0],%[outptr_row0],#8\n" + "sub %[outptr_row2],%[outptr_row2],#8\n" + "sub %[outptr_row4],%[outptr_row4],#8\n" + "sub %[outptr_row6],%[outptr_row6],#8\n" + "sub %[outptr_row8],%[outptr_row8],#8\n" + "sub %[outptr_row10],%[outptr_row10],#8\n" + "sub %[outptr_row0],%[outptr_row0],%[asmLDC]\n" + "sub %[outptr_row2],%[outptr_row2],%[asmLDC]\n" + "sub %[outptr_row4],%[outptr_row4],%[asmLDC]\n" + "sub %[outptr_row6],%[outptr_row6],%[asmLDC]\n" + "sub %[outptr_row8],%[outptr_row8],%[asmLDC]\n" + "sub %[outptr_row10],%[outptr_row10],%[asmLDC]\n" + "vst1.32 {d8} ,[%[outptr_row0]]!\n" + "vst1.32 {d9[0]} ,[%[outptr_row0]]\n" + "vst1.32 {d12} ,[%[outptr_row2]]!\n" + "vst1.32 {d13[0]},[%[outptr_row2]]\n" + "vst1.32 {d16} ,[%[outptr_row4]]!\n" + "vst1.32 {d17[0]},[%[outptr_row4]]\n" + "vst1.32 {d20} ,[%[outptr_row6]]!\n" + "vst1.32 {d21[0]},[%[outptr_row6]]\n" + "vst1.32 {d24} ,[%[outptr_row8]]!\n" + "vst1.32 {d25[0]},[%[outptr_row8]]\n" + "vst1.32 {d28} ,[%[outptr_row10]]!\n" + "vst1.32 {d29[0]},[%[outptr_row10]]\n" + 
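        // Each outptr_rowN pointer serves two consecutive output rows, N and N+1.
        // The odd rows (q5, q7, q9, q11, q13, q15) were just stored at the
        // LDC-advanced addresses; the fix-ups below undo the 8-byte post-increment
        // and step back one row stride so the even rows (q4, q6, q8, q10, q12, q14)
        // land on rows 0, 2, 4, 6, 8 and 10.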
"sub %[outptr_row0],%[outptr_row0],#8\n" + "sub %[outptr_row2],%[outptr_row2],#8\n" + "sub %[outptr_row4],%[outptr_row4],#8\n" + "sub %[outptr_row6],%[outptr_row6],#8\n" + "sub %[outptr_row8],%[outptr_row8],#8\n" + "sub %[outptr_row10],%[outptr_row10],#8\n" + + "b 4f\n" + "6: \n" + "vst1.32 {d10} ,[%[outptr_row0]]\n" + "vst1.32 {d14} ,[%[outptr_row2]]\n" + "vst1.32 {d18} ,[%[outptr_row4]]\n" + "vst1.32 {d22} ,[%[outptr_row6]]\n" + "vst1.32 {d26} ,[%[outptr_row8]]\n" + "vst1.32 {d30} ,[%[outptr_row10]]\n" + "sub %[outptr_row0],%[outptr_row0],%[asmLDC]\n" + "sub %[outptr_row2],%[outptr_row2],%[asmLDC]\n" + "sub %[outptr_row4],%[outptr_row4],%[asmLDC]\n" + "sub %[outptr_row6],%[outptr_row6],%[asmLDC]\n" + "sub %[outptr_row8],%[outptr_row8],%[asmLDC]\n" + "sub %[outptr_row10],%[outptr_row10],%[asmLDC]\n" + "vst1.32 {d8} ,[%[outptr_row0]]\n" + "vst1.32 {d12} ,[%[outptr_row2]]\n" + "vst1.32 {d16} ,[%[outptr_row4]]\n" + "vst1.32 {d20} ,[%[outptr_row6]]\n" + "vst1.32 {d24} ,[%[outptr_row8]]\n" + "vst1.32 {d28} ,[%[outptr_row10]]\n" + "b 4f\n" + "7: \n" + "vst1.32 {d10[0]},[%[outptr_row0]]\n" + "vst1.32 {d14[0]},[%[outptr_row2]]\n" + "vst1.32 {d18[0]},[%[outptr_row4]]\n" + "vst1.32 {d22[0]},[%[outptr_row6]]\n" + "vst1.32 {d26[0]},[%[outptr_row8]]\n" + "vst1.32 {d30[0]},[%[outptr_row10]]\n" + "sub %[outptr_row0],%[outptr_row0],%[asmLDC]\n" + "sub %[outptr_row2],%[outptr_row2],%[asmLDC]\n" + "sub %[outptr_row4],%[outptr_row4],%[asmLDC]\n" + "sub %[outptr_row6],%[outptr_row6],%[asmLDC]\n" + "sub %[outptr_row8],%[outptr_row8],%[asmLDC]\n" + "sub %[outptr_row10],%[outptr_row10],%[asmLDC]\n" + "vst1.32 {d8[0]} ,[%[outptr_row0]]\n" + "vst1.32 {d12[0]},[%[outptr_row2]]\n" + "vst1.32 {d16[0]},[%[outptr_row4]]\n" + "vst1.32 {d20[0]},[%[outptr_row6]]\n" + "vst1.32 {d24[0]},[%[outptr_row8]]\n" + "vst1.32 {d28[0]},[%[outptr_row10]]\n" + "b 4f\n" + "3: \n" // first k + "cmp %[n_remain] ,#3 \n" + "beq 5f\n" + "cmp %[n_remain] ,#2 \n" + "beq 6f\n" + "cmp %[n_remain] ,#1 \n" + "beq 7f\n" + "5:\n" + "vst1.32 {d8} ,[%[outptr_row0]]!\n" + "vst1.32 {d9[0]} ,[%[outptr_row0]]\n" + "vst1.32 {d12} ,[%[outptr_row2]]!\n" + "vst1.32 {d13[0]},[%[outptr_row2]]\n" + "vst1.32 {d16} ,[%[outptr_row4]]!\n" + "vst1.32 {d17[0]},[%[outptr_row4]]\n" + "vst1.32 {d20} ,[%[outptr_row6]]!\n" + "vst1.32 {d21[0]},[%[outptr_row6]]\n" + "vst1.32 {d24} ,[%[outptr_row8]]!\n" + "vst1.32 {d25[0]},[%[outptr_row8]]\n" + "vst1.32 {d28} ,[%[outptr_row10]]!\n" + "vst1.32 {d29[0]},[%[outptr_row10]]\n" + "sub %[outptr_row0],%[outptr_row0],#8\n" + "sub %[outptr_row2],%[outptr_row2],#8\n" + "sub %[outptr_row4],%[outptr_row4],#8\n" + "sub %[outptr_row6],%[outptr_row6],#8\n" + "sub %[outptr_row8],%[outptr_row8],#8\n" + "sub %[outptr_row10],%[outptr_row10],#8\n" + "add %[outptr_row0],%[outptr_row0],%[asmLDC]\n" + "add %[outptr_row2],%[outptr_row2],%[asmLDC]\n" + "add %[outptr_row4],%[outptr_row4],%[asmLDC]\n" + "add %[outptr_row6],%[outptr_row6],%[asmLDC]\n" + "add %[outptr_row8],%[outptr_row8],%[asmLDC]\n" + "add %[outptr_row10],%[outptr_row10],%[asmLDC]\n" + "vst1.32 {d10} ,[%[outptr_row0]]!\n" + "vst1.32 {d11[0]},[%[outptr_row0]]\n" + "vst1.32 {d14} ,[%[outptr_row2]]!\n" + "vst1.32 {d15[0]},[%[outptr_row2]]\n" + "vst1.32 {d18} ,[%[outptr_row4]]!\n" + "vst1.32 {d19[0]},[%[outptr_row4]]\n" + "vst1.32 {d22} ,[%[outptr_row6]]!\n" + "vst1.32 {d23[0]},[%[outptr_row6]]\n" + "vst1.32 {d26} ,[%[outptr_row8]]!\n" + "vst1.32 {d27[0]},[%[outptr_row8]]\n" + "vst1.32 {d30} ,[%[outptr_row10]]!\n" + "vst1.32 {d31[0]},[%[outptr_row10]]\n" + "sub 
%[outptr_row0],%[outptr_row0],#8\n" + "sub %[outptr_row2],%[outptr_row2],#8\n" + "sub %[outptr_row4],%[outptr_row4],#8\n" + "sub %[outptr_row6],%[outptr_row6],#8\n" + "sub %[outptr_row8],%[outptr_row8],#8\n" + "sub %[outptr_row10],%[outptr_row10],#8\n" + "b 4f\n" + "6:\n" + "vst1.32 {d8} ,[%[outptr_row0]]\n" + "vst1.32 {d12} ,[%[outptr_row2]]\n" + "vst1.32 {d16} ,[%[outptr_row4]]\n" + "vst1.32 {d20} ,[%[outptr_row6]]\n" + "vst1.32 {d24} ,[%[outptr_row8]]\n" + "vst1.32 {d28} ,[%[outptr_row10]]\n" + "add %[outptr_row0],%[outptr_row0],%[asmLDC]\n" + "add %[outptr_row2],%[outptr_row2],%[asmLDC]\n" + "add %[outptr_row4],%[outptr_row4],%[asmLDC]\n" + "add %[outptr_row6],%[outptr_row6],%[asmLDC]\n" + "add %[outptr_row8],%[outptr_row8],%[asmLDC]\n" + "add %[outptr_row10],%[outptr_row10],%[asmLDC]\n" + "vst1.32 {d10} ,[%[outptr_row0]]\n" + "vst1.32 {d14} ,[%[outptr_row2]]\n" + "vst1.32 {d18} ,[%[outptr_row4]]\n" + "vst1.32 {d22} ,[%[outptr_row6]]\n" + "vst1.32 {d26} ,[%[outptr_row8]]\n" + "vst1.32 {d30} ,[%[outptr_row10]]\n" + "b 4f\n" + "7: \n" + "vst1.32 {d8[0]} ,[%[outptr_row0]]\n" + "vst1.32 {d12[0]},[%[outptr_row2]]\n" + "vst1.32 {d16[0]},[%[outptr_row4]]\n" + "vst1.32 {d20[0]},[%[outptr_row6]]\n" + "vst1.32 {d24[0]},[%[outptr_row8]]\n" + "vst1.32 {d28[0]},[%[outptr_row10]]\n" + "add %[outptr_row0],%[outptr_row0],%[asmLDC]\n" + "add %[outptr_row2],%[outptr_row2],%[asmLDC]\n" + "add %[outptr_row4],%[outptr_row4],%[asmLDC]\n" + "add %[outptr_row6],%[outptr_row6],%[asmLDC]\n" + "add %[outptr_row8],%[outptr_row8],%[asmLDC]\n" + "add %[outptr_row10],%[outptr_row10],%[asmLDC]\n" + "vst1.32 {d10[0]},[%[outptr_row0]]\n" + "vst1.32 {d14[0]},[%[outptr_row2]]\n" + "vst1.32 {d18[0]},[%[outptr_row4]]\n" + "vst1.32 {d22[0]},[%[outptr_row6]]\n" + "vst1.32 {d26[0]},[%[outptr_row8]]\n" + "vst1.32 {d30[0]},[%[outptr_row10]]\n" + "4:\n" + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [n_remain] "+r"(n_remain), [asmLDC] "+r"(asmLDC), + [is_first_k] "+r"(is_first_k), [outptr_row0] "+r"(outptr_row0), + [outptr_row2] "+r"(outptr_row2), [outptr_row4] "+r"(outptr_row4), + [outptr_row6] "+r"(outptr_row6), [outptr_row8] "+r"(outptr_row8), + [outptr_row10] "+r"(outptr_row10) + : + : "d0", "d1", "d2", "d3", "d4", "d5", "d8", "d9", "d10", "d11", + "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "d20", + "d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28", "d29", + "d30", "d31", "cc", "memory"); +} + +static void kern_4x4(const int16_t* packA, const int16_t* packB, int K, + int32_t* output, int LDC, bool is_first_k) { + const int16_t* a_ptr = packA; + const int16_t* b_ptr = packB; + + int32_t* outptr_row0 = output; + int32_t* outptr_row1 = outptr_row0 + LDC; + int32_t* outptr_row2 = outptr_row1 + LDC; + int32_t* outptr_row3 = outptr_row2 + LDC; + + asm volatile( + "cmp %[is_first_k], #1\n" + "beq 1f\n" + "vld1.32 {d8,d9} ,[%[outptr_row0]]\n" + "vld1.32 {d10,d11},[%[outptr_row1]]\n" + "vld1.32 {d12,d13},[%[outptr_row2]]\n" + "vld1.32 {d14,d15},[%[outptr_row3]]\n" + "b 2f \n" + "1:\n" + "veor.s32 q4, q4, q4\n" + "veor.s32 q5, q5, q5\n" + "veor.s32 q6, q6, q6\n" + "veor.s32 q7, q7, q7\n" + "2:\n" + "pld [%[b_ptr],#64]\n" + "vld1.16 {d3},[%[b_ptr]]!\n" + "pld [%[a_ptr],#64]\n" + "vld1.16 {d0},[%[a_ptr]]!\n" + "vmlal.s16 q4,d3,d0[0]\n" + "vmlal.s16 q5,d3,d0[1]\n" + "vmlal.s16 q6,d3,d0[2]\n" + "vmlal.s16 q7,d3,d0[3]\n" + "subs %[K], %[K], #1\n" + "bne 2b\n" + "vst1.32 {d8,d9} ,[%[outptr_row0]]\n" + "vst1.32 {d10,d11},[%[outptr_row1]]\n" + "vst1.32 {d12,d13},[%[outptr_row2]]\n" + "vst1.32 
{d14,d15},[%[outptr_row3]]\n" + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [is_first_k] "+r"(is_first_k), [outptr_row0] "+r"(outptr_row0), + [outptr_row1] "+r"(outptr_row1), [outptr_row2] "+r"(outptr_row2), + [outptr_row3] "+r"(outptr_row3) + : + : "d0", "d1", "d2", "d3", "d4", "d5", "d8", "d9", "d10", "d11", + "d12", "d13", "d14", "d15", "d16", "d18", "d20", "d21", "d22", + "d24", "d26", "d28", "d30", "cc", "memory"); +} + +static void kern_4x123(const int16_t* packA, const int16_t* packB, int K, + int32_t* output, int LDC, bool is_first_k, + int n_remain) { + const int16_t* a_ptr = packA; + const int16_t* b_ptr = packB; + + int32_t* outptr_row0 = output; + int32_t* outptr_row1 = outptr_row0 + LDC; + int32_t* outptr_row2 = outptr_row1 + LDC; + int32_t* outptr_row3 = outptr_row2 + LDC; + + asm volatile( + "cmp %[is_first_k], #1\n" + "beq 1f\n" + "cmp %[n_remain],#3\n" + "beq 3f\n" + "cmp %[n_remain],#2\n" + "beq 4f\n" + "cmp %[n_remain],#1\n" + "beq 5f\n" + "3: \n" + "vld1.32 {d8} ,[%[outptr_row0]]!\n" + "vld1.32 {d9[0]} ,[%[outptr_row0]]\n" + "vld1.32 {d10},[%[outptr_row1]]!\n" + "vld1.32 {d11[0]},[%[outptr_row1]]\n" + "vld1.32 {d12},[%[outptr_row2]]!\n" + "vld1.32 {d13[0]},[%[outptr_row2]]\n" + "vld1.32 {d14},[%[outptr_row3]]!\n" + "vld1.32 {d15[0]},[%[outptr_row3]]\n" + "sub %[outptr_row0],%[outptr_row0],#8\n" + "sub %[outptr_row1],%[outptr_row1],#8\n" + "sub %[outptr_row2],%[outptr_row2],#8\n" + "sub %[outptr_row3],%[outptr_row3],#8\n" + "b 2f\n" + "4:\n" + "vld1.32 {d8} ,[%[outptr_row0]]\n" + "vld1.32 {d10},[%[outptr_row1]]\n" + "vld1.32 {d12},[%[outptr_row2]]\n" + "vld1.32 {d14},[%[outptr_row3]]\n" + "b 2f\n" + "5:\n" + "vld1.32 {d8[0]} ,[%[outptr_row0]]\n" + "vld1.32 {d10[0]},[%[outptr_row1]]\n" + "vld1.32 {d12[0]},[%[outptr_row2]]\n" + "vld1.32 {d14[0]},[%[outptr_row3]]\n" + "b 2f \n" + "1:\n" + "veor.s32 q4, q4, q4\n" + "veor.s32 q5, q5, q5\n" + "veor.s32 q6, q6, q6\n" + "veor.s32 q7, q7, q7\n" + "2:\n" + "pld [%[b_ptr],#16]\n" + "vld1.16 {d3},[%[b_ptr]]!\n" + "pld [%[a_ptr],#64]\n" + "vld1.16 {d0},[%[a_ptr]]!\n" + "vmlal.s16 q4,d3,d0[0]\n" + "vmlal.s16 q5,d3,d0[1]\n" + "vmlal.s16 q6,d3,d0[2]\n" + "vmlal.s16 q7,d3,d0[3]\n" + "subs %[K], %[K], #1\n" + "bne 2b\n" + "cmp %[n_remain],#3\n" + "beq 3f\n" + "cmp %[n_remain],#2\n" + "beq 4f\n" + "cmp %[n_remain],#1\n" + "beq 5f\n" + "3:\n" + "vst1.32 {d8} ,[%[outptr_row0]]!\n" + "vst1.32 {d9[0]} ,[%[outptr_row0]]\n" + "vst1.32 {d10},[%[outptr_row1]]!\n" + "vst1.32 {d11[0]},[%[outptr_row1]]\n" + "vst1.32 {d12},[%[outptr_row2]]!\n" + "vst1.32 {d13[0]},[%[outptr_row2]]\n" + "vst1.32 {d14},[%[outptr_row3]]!\n" + "vst1.32 {d15[0]},[%[outptr_row3]]\n" + "b 6f\n" + "4:\n" + "vst1.32 {d8} ,[%[outptr_row0]]\n" + "vst1.32 {d10},[%[outptr_row1]]\n" + "vst1.32 {d12},[%[outptr_row2]]\n" + "vst1.32 {d14},[%[outptr_row3]]\n" + "b 6f\n" + "5:\n" + "vst1.32 {d8[0]} ,[%[outptr_row0]]\n" + "vst1.32 {d10[0]},[%[outptr_row1]]\n" + "vst1.32 {d12[0]},[%[outptr_row2]]\n" + "vst1.32 {d14[0]},[%[outptr_row3]]\n" + "6:\n" + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [n_remain] "+r"(n_remain), [is_first_k] "+r"(is_first_k), + [outptr_row0] "+r"(outptr_row0), [outptr_row1] "+r"(outptr_row1), + [outptr_row2] "+r"(outptr_row2), [outptr_row3] "+r"(outptr_row3) + : + : "d0", "d1", "d2", "d3", "d4", "d5", "d8", "d9", "d10", "d11", + "d12", "d13", "d14", "d15", "d16", "d18", "d20", "d21", "d22", + "d24", "d26", "d28", "d30", "cc", "memory"); +} + +static void kern_1x4(const int16_t* packA, const int16_t* packB, int K, + int32_t* output, int 
LDC, bool is_first_k) { + MEGDNN_MARK_USED_VAR(LDC); + const int16_t* a_ptr = packA; + const int16_t* b_ptr = packB; + + int32_t* outptr_row0 = output; + asm volatile( + "cmp %[is_first_k], #1\n" + "beq 1f\n" + "pld [%[outptr_row0],#64]\n" + "vld1.32 {d8,d9} ,[%[outptr_row0]]\n" + "b 2f \n" + "1:\n" + "veor.s32 q4, q4, q4\n" + "2:\n" + "pld [%[b_ptr],#64]\n" + "pld [%[a_ptr],#16]\n" + "vld1.16 {d3},[%[b_ptr]]!\n" + "vld1.16 {d0[0]},[%[a_ptr]]!\n" + "vmlal.s16 q4,d3,d0[0]\n" + "subs %[K], %[K], #1\n" + "bne 2b\n" + "vst1.32 {d8,d9} ,[%[outptr_row0]]\n" + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [is_first_k] "+r"(is_first_k), [outptr_row0] "+r"(outptr_row0) + : + : "d0", "d3", "d4", "d8", "d9", "cc", "memory"); +} + +/************************************************ + *this kern can hanle 1xk mul kx1 kx2 kx3 get 1x1 1x2 1x3 + *123 stands for n remain 1 2 3 + ************************************************/ +static void kern_1x123(const int16_t* packA, const int16_t* packB, int K, + int32_t* output, int LDC, bool is_first_k, + int n_remain) { + MEGDNN_MARK_USED_VAR(LDC); + const int16_t* a_ptr = packA; + const int16_t* b_ptr = packB; + int32_t* outptr_row0 = output; + asm volatile( + "cmp %[is_first_k], #1\n" + "beq 1f\n" + "cmp %[n_remain],#3\n" + "beq 3f\n" + "cmp %[n_remain],#2\n" + "beq 4f\n" + "cmp %[n_remain],#1\n" + "beq 5f\n" + "3:\n" + "vld1.32 {d8} ,[%[outptr_row0]]!\n" + "vld1.32 {d9[0]} ,[%[outptr_row0]]\n" + "sub %[outptr_row0],%[outptr_row0],#8 \n" + "b 2f\n" + "4:\n" + "vld1.32 {d8} ,[%[outptr_row0]]\n" + "b 2f\n" + "5:\n" + "vld1.32 {d8[0]} ,[%[outptr_row0]]\n" + "b 2f \n" + "1:\n" + "veor.s32 q4, q4, q4\n" + "2:\n" + "vld1.16 {d3},[%[b_ptr]]!\n" + "vld1.16 {d0[0]},[%[a_ptr]]!\n" + "vmlal.s16 q4,d3,d0[0]\n" + "subs %[K], %[K], #1\n" + "bne 2b\n" + "cmp %[n_remain],#3\n" + "beq 3f\n" + "cmp %[n_remain],#2\n" + "beq 4f\n" + "cmp %[n_remain],#1\n" + "beq 5f\n" + "3:\n" + "vst1.32 {d8} ,[%[outptr_row0]]!\n" + "vst1.32 {d9[0]} ,[%[outptr_row0]]\n" + "b 7f\n" + "4:\n" + "vst1.32 {d8} ,[%[outptr_row0]]\n" + "b 7f\n" + "5:\n" + "vst1.32 {d8[0]} ,[%[outptr_row0]]\n" + "7:\n" + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [n_remain] "+r"(n_remain), [is_first_k] "+r"(is_first_k), + [outptr_row0] "+r"(outptr_row0) + : + : "d0", "d3", "d8", "d9", "cc", "memory"); +} + +static void gemm_s16x16x32_12x4_pack_A_n(dt_int16* outptr, + const dt_int16* inptr, int ldin, + int y0, int ymax, int k0, int kmax) { + int y = y0; + int K = kmax - k0; + for (; y + 11 < ymax; y += 12) { + const int16_t* inptr0 = inptr + y * ldin + k0; + const int16_t* inptr1 = inptr0 + ldin; + const int16_t* inptr2 = inptr1 + ldin; + const int16_t* inptr3 = inptr2 + ldin; + const int16_t* inptr4 = inptr3 + ldin; + const int16_t* inptr5 = inptr4 + ldin; + const int16_t* inptr6 = inptr5 + ldin; + const int16_t* inptr7 = inptr6 + ldin; + const int16_t* inptr8 = inptr7 + ldin; + const int16_t* inptr9 = inptr8 + ldin; + const int16_t* inptr10 = inptr9 + ldin; + const int16_t* inptr11 = inptr10 + ldin; + + int k = k0; + for (; k + 3 < kmax; k += 4) { + transpose_12x4_1_h(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, + inptr6, inptr7, inptr8, inptr9, inptr10, inptr11, + ldin, outptr); + } + + for (; k < kmax; k++) { + transpose_12x1(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, + inptr6, inptr7, inptr8, inptr9, inptr10, inptr11, + outptr); + } + } + + for (; y + 3 < ymax; y += 4) { + const int16_t* inptr0 = inptr + y * ldin + k0; + const int16_t* inptr1 = inptr0 + ldin; + const int16_t* inptr2 
= inptr1 + ldin; + const int16_t* inptr3 = inptr2 + ldin; + + int k = 0; + for (; k + 3 < K; k += 4) { + transpose_4x4_1_h(inptr0, inptr1, inptr2, inptr3, outptr); + } + for (; k < K; k++) { + transpose_4x1(inptr0, inptr1, inptr2, inptr3, outptr); + } + } + for (; y < ymax; y++) { + const int16_t* inptr0 = inptr + y * ldin + k0; + std::memcpy(outptr, inptr0, sizeof(int16_t) * K); + outptr += K; + } +} + +static void gemm_s16x16x32_12x4_transpose_pack_A_n(dt_int16* out, + const dt_int16* in, int ldin, + int x0, int xmax, int k0, + int kmax) { + const int ksize = kmax - k0; + const int ksize12 = ksize * 12; + const int ksize4 = ksize * 4; + int16_t* outptr = out; + int16_t* outptr_interleave = out; + int16_t* outptr_base = out; + int16_t* outptr_times4_base = out + (xmax - x0) / 12 * ksize12; + int16_t* outptr_times1_base = + outptr_times4_base + ((xmax - x0) % 12) / 4 * ksize4; + int k = k0; + for (; k + 3 < kmax; k += 4) { + const int16_t* inptr0 = in + k * ldin + x0; + const int16_t* inptr1 = inptr0 + ldin; + const int16_t* inptr2 = inptr1 + ldin; + const int16_t* inptr3 = inptr2 + ldin; + + int x = x0; + outptr = outptr_base; + + for (; x + 11 < xmax; x += 12) { + outptr_interleave = outptr; + + interleave_4x12_1_h(inptr0, inptr1, inptr2, inptr3, + outptr_interleave); + outptr += ksize12; + } + outptr = outptr_times4_base; + for (; x + 3 < xmax; x += 4) { + outptr_interleave = outptr; + interleave_4x4_1_h(inptr0, inptr1, inptr2, inptr3, + outptr_interleave); + outptr += ksize4; + } + + outptr = outptr_times1_base; + for (; x < xmax; x++) { + outptr_interleave = outptr; + transpose_4x1(inptr0, inptr1, inptr2, inptr3, outptr_interleave); + outptr += ksize; + } + outptr_base += 48; + outptr_times4_base += 16; + outptr_times1_base += 4; + } + for (; k < kmax; k++) { + const int16_t* inptr0 = in + k * ldin + x0; + prefetch_2x(inptr0); + + int x = x0; + outptr = outptr_base; + + for (; x + 11 < xmax; x += 12) { + outptr_interleave = outptr; + interleave_1x12_1_h(inptr0, outptr_interleave); + outptr += ksize12; + } + outptr = outptr_times4_base; + for (; x + 3 < xmax; x += 4) { + outptr_interleave = outptr; + interleave_1x4_1_h(inptr0, outptr_interleave); + outptr += ksize4; + } + + outptr = outptr_times1_base; + for (; x < xmax; x++) { + outptr_interleave = outptr; + *outptr_interleave++ = *inptr0++; + outptr += ksize; + } + outptr_base += 12; + outptr_times4_base += 4; + outptr_times1_base += 1; + } +} + +static void gemm_s16x16x32_12x4_pack_B_n(dt_int16* out, const dt_int16* in, + int ldin, int x0, int xmax, int k0, + int kmax) { + const int ksize = kmax - k0; + const int ksize4 = ksize * 4; + int16_t* outptr = out; + int16_t* outptr_base = out; + int16_t* outptr_interleave = NULL; + int k = k0; + for (; k + 3 < kmax; k += 4) { + const int16_t* inptr0 = in + k * ldin + x0; + const int16_t* inptr1 = inptr0 + ldin; + const int16_t* inptr2 = inptr1 + ldin; + const int16_t* inptr3 = inptr2 + ldin; + + int x = x0; + outptr = outptr_base; + for (; x + 3 < xmax; x += 4) { + outptr_interleave = outptr; + interleave_4x4_1_h(inptr0, inptr1, inptr2, inptr3, + outptr_interleave); + outptr += ksize4; + } + if (x < xmax) { + outptr_interleave = outptr; + interleave_4(inptr0, inptr1, inptr2, inptr3, outptr_interleave, 4, + xmax - x); + outptr += ksize4; + } + outptr_base += 4 * 4; + } + for (; k < kmax; k++) { + const int16_t* inptr0 = in + k * ldin + x0; + prefetch_2x(inptr0); + + int x = x0; + outptr = outptr_base; + for (; x + 3 < xmax; x += 4) { + outptr_interleave = outptr; + int16x4_t vdata = 
vld1_s16(inptr0); + vst1_s16(outptr_interleave, vdata); + inptr0 += 4; + outptr += ksize4; + } + if (x < xmax) { + int remain = xmax - x; + outptr_interleave = outptr; + interleave_helper(inptr0, outptr_interleave, 4, remain); + outptr += ksize4; + } + outptr_base += 4; + } +} + +static void gemm_s16x16x32_12x4_transpose_pack_B_n(dt_int16* outptr, + const dt_int16* inptr, + int ldin, int y0, int ymax, + int k0, int kmax) { + int K = kmax - k0; + int y = y0; + int16_t* out = outptr; + int16_t zerobuff[4]; + std::memset(zerobuff, 0, sizeof(int16_t) * 4); + for (; y + 3 < ymax; y += 4) { + const int16_t* inptr0 = inptr + y * ldin + k0; + const int16_t* inptr1 = inptr0 + ldin; + const int16_t* inptr2 = inptr1 + ldin; + const int16_t* inptr3 = inptr2 + ldin; + int k = 0; + for (; k + 3 < K; k += 4) { + transpose_4x4_1_h(inptr0, inptr1, inptr2, inptr3, out); + } + for (; k < K; k++) { + transpose_4x1(inptr0, inptr1, inptr2, inptr3, out); + } + } + if (y < ymax) { + const int16_t *inptr0, *inptr1, *inptr2, *inptr3; + inptr0 = inptr + y * ldin + k0; + inptr1 = inptr0 + ldin; + inptr2 = inptr1 + ldin; + + switch (y + 3 - ymax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + int k = 0; + for (; k + 3 < K; k += 4) { + transpose_4x4_1_h(inptr0, inptr1, inptr2, inptr3, out); + } + for (; k < K; k++) { + transpose_4x1(inptr0, inptr1, inptr2, inptr3, out); + } + } +} +} // namespace matmul_12x4x1 +} // namespace armv7 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/armv7/matrix_mul/int16x16x32/strategy.cpp b/dnn/src/armv7/matrix_mul/int16x16x32/strategy.cpp new file mode 100644 index 00000000..08ae30a9 --- /dev/null +++ b/dnn/src/armv7/matrix_mul/int16x16x32/strategy.cpp @@ -0,0 +1,139 @@ +/** + * \file dnn/src/armv7/matrix_mul/int16x16x32/strategy.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/armv7/matrix_mul/asm/common.h" +#include "src/armv7/matrix_mul/int16x16x32/kernel_12x4x1.h" +#include "src/armv7/matrix_mul/int16x16x32/strategy.h" +#include "src/common/utils.h" +#include "src/fallback/matrix_mul/gemm_common.h" + +using namespace megdnn; +using namespace armv7; +using namespace armv7::matmul; + +// ===========================gemm_s16x16x32_4x4================================= +MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s16x16x32_12x4); + +void gemm_s16x16x32_12x4::pack_A(dt_int16* out, const dt_int16* in, int ldin, + int y0, int ymax, int k0, int kmax, + bool transpose) const { + if (transpose) { + matmul_12x4x1::gemm_s16x16x32_12x4_transpose_pack_A_n(out, in, ldin, y0, + ymax, k0, kmax); + } else { + matmul_12x4x1::gemm_s16x16x32_12x4_pack_A_n(out, in, ldin, y0, ymax, k0, + kmax); + } +} + +void gemm_s16x16x32_12x4::pack_B(dt_int16* out, const dt_int16* in, int ldin, + int x0, int xmax, int k0, int kmax, + bool transpose) const { + if (transpose) { + matmul_12x4x1::gemm_s16x16x32_12x4_transpose_pack_B_n(out, in, ldin, x0, + xmax, k0, kmax); + } else { + matmul_12x4x1::gemm_s16x16x32_12x4_pack_B_n(out, in, ldin, x0, xmax, k0, + kmax); + } +} + +void gemm_s16x16x32_12x4::kern(const dt_int16* packA, const dt_int16* packB, + size_t M, size_t N, size_t K, dt_int32* C, + size_t LDC, bool is_first_k, const dt_int32*, + dt_int32*) const { + megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && + ((A_dtype.enumv() == DTypeEnum::Int16 && + C_dtype.enumv() == DTypeEnum::Int32)), + "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), + C_dtype.name()); + + MEGDNN_MARK_USED_VAR(A_dtype); + MEGDNN_MARK_USED_VAR(B_dtype); + MEGDNN_MARK_USED_VAR(C_dtype); + + constexpr size_t A_INTERLEAVE = 12; + constexpr size_t B_INTERLEAVE = 4; + const int K12 = K * 12; + const int K4 = K * 4; + + size_t m = 0; + for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) { + int32_t* output = C + (m * LDC); + + size_t n = 0; + const dt_int16* cur_packB = packB; + for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { + matmul_12x4x1::kern_12x4(packA, cur_packB, K, output, LDC, + is_first_k); + output += B_INTERLEAVE; + cur_packB += K4; + } + + if (n < N ){ + matmul_12x4x1::kern_12x123(packA, cur_packB, K, output, LDC, + is_first_k, (N-n)); + output += (N-n); + cur_packB += K4; + + } + + packA += K12; + } + + for (; m + 3 < M; m += 4) { + int32_t* output = C + (m * LDC); + + size_t n = 0; + const dt_int16* cur_packB = packB; + for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { + matmul_12x4x1::kern_4x4(packA, cur_packB, K, output, LDC, + is_first_k); + output += B_INTERLEAVE; + cur_packB += K4; + } + + if (n < N){ + int remain = N - n; + matmul_12x4x1::kern_4x123(packA, cur_packB, K, output, LDC, + is_first_k,remain); + output += remain; + cur_packB += K4; + } + + packA += K4; + } + + for (; m < M; m++) { + int32_t* output = C + (m * LDC); + + size_t n = 0; + const dt_int16* cur_packB = packB; + for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { + matmul_12x4x1::kern_1x4(packA, cur_packB, K, output, LDC, + is_first_k); + output += B_INTERLEAVE; + cur_packB += K4; + } + + if (n < N) { + int remain = N - n; + matmul_12x4x1::kern_1x123(packA, cur_packB, K, output, LDC, + is_first_k,remain); + output += remain; + cur_packB += K4; + } + + packA += K; + } +} +// vim: syntax=cpp.doxygen diff --git a/dnn/src/armv7/matrix_mul/int16x16x32/strategy.h b/dnn/src/armv7/matrix_mul/int16x16x32/strategy.h new file mode 100644 index 00000000..d550ac80 --- 
/dev/null +++ b/dnn/src/armv7/matrix_mul/int16x16x32/strategy.h @@ -0,0 +1,28 @@ +/** + * \file dnn/src/armv7/matrix_mul/int16x16x32/strategy.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once +#include "src/fallback/matrix_mul/gemm_common.h" + +namespace megdnn { +namespace armv7 { +namespace matmul { + +MEGDNN_REG_GEMM_STRATEGY(int16_t, int32_t, int32_t, 12, 4, 1, false, true, + gemm_s16x16x32_12x4); + +MEGDNN_REG_GEMM_STRATEGY_NOPACK(dt_int16, dt_int32, dt_int32, 4, 8, 1, false, + true, gemm_nopack_s16_4x8); + +} // namespace matmul +} // namespace armv7 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/armv7/matrix_mul/int16x16x32/strategy_mk8_4x8.cpp b/dnn/src/armv7/matrix_mul/int16x16x32/strategy_mk8_4x8.cpp new file mode 100644 index 00000000..335c9b67 --- /dev/null +++ b/dnn/src/armv7/matrix_mul/int16x16x32/strategy_mk8_4x8.cpp @@ -0,0 +1,267 @@ +/** + * \file dnn/src/armv7/matrix_mul/int16x16x32/strategy_mk8_4x8.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "src/armv7/matrix_mul/int16x16x32/strategy.h" +#include "src/armv7/matrix_mul/asm/common.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/utils.h" + +using namespace megdnn; +using namespace armv7; +using namespace armv7::matmul; + +namespace { + +// Overview of register layout: +// +// A 4x8 cell of Rhs is stored in 16bit in q0-q3 +// A 8x8 cell of Lhs is stored in 16bit in q4-q7 +// A 2x8 block of accumulators is stored in 32bit in q8-v15. +// +// Rhs +--------+ +// | q4[0-7]| +// | q5[0-7]| +// | q6[0-7]| +// | q7[0-7]| +// +--------+ +// Lhs +// +--------+ - - - - -+--------+--------+ +// | q0[0-7]| | q8[0-3]| v9[0-3]| +// | q1[0-7]| |q10[0-3]|v11[0-3]| +// | q2[0-7]| |q12[0-3]|v13[0-3]| +// | q3[0-7]| |q14[0-3]|v15[0-3]| +// +--------+ +--------+--------+ +// Accumulator +void kern_4x8(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, + dt_int32* output) { + //! As each load 16 number from B, but the pos add 16 * 2, so we minus 16 + //! here. 
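    //! Note: after this adjustment b_ptr advances by exactly LDB int16 elements
    //! per 8-step chunk of K, since the first vld1 below has already moved it
    //! forward by 16 elements. The accumulators live in q8-q15 (the "v" names in
    //! the diagram above are the aarch64 spellings of the same registers).
    //!
    //! Roughly, assuming the MK8 packing used by the nopack kern below
    //! (a_ptr walking consecutive (8, 8) sub-blocks of A, b_ptr walking
    //! LDB-strided panels of four columns, LDB counted in elements as passed in),
    //! one call is expected to compute an 8(M) x 4(N) tile:
    //!
    //!     for (size_t n = 0; n < 4; ++n)
    //!         for (size_t m = 0; m < 8; ++m) {
    //!             int32_t acc = 0;
    //!             for (size_t k = 0; k < K; ++k)
    //!                 acc += int32_t(a_ptr[(k / 8) * 64 + (k % 8) * 8 + m]) *
    //!                        int32_t(b_ptr[(k / 8) * LDB + n * 8 + (k % 8)]);
    //!             output[n * 8 + m] = acc;
    //!         }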
+ LDB = (LDB - 16) * sizeof(dt_int16); + + asm volatile( + "vld1.32 {d8, d9}, [%[a_ptr]]!\n" + "vld1.32 {d0, d1, d2, d3}, [%[b_ptr]]!\n" + "subs %[K], #8\n" + + "vld1.32 {d4, d5, d6, d7}, [%[b_ptr]], %[LDB]\n" + "vmull.s16 q8, d8, d0[0]\n" + "vmull.s16 q10, d8, d2[0]\n" + "vmull.s16 q12, d8, d4[0]\n" + "vmull.s16 q14, d8, d6[0]\n" + + "vld1.32 {d10, d11}, [%[a_ptr]]!\n" + "vmull.s16 q9, d9, d0[0]\n" + "vmull.s16 q11, d9, d2[0]\n" + "vmull.s16 q13, d9, d4[0]\n" + "vmull.s16 q15, d9, d6[0]\n" + + "vld1.32 {d12, d13}, [%[a_ptr]]!\n" + "vmlal.s16 q8, d10, d0[1]\n" + "vmlal.s16 q10, d10, d2[1]\n" + "vmlal.s16 q12, d10, d4[1]\n" + "vmlal.s16 q14, d10, d6[1]\n" + "vmlal.s16 q9, d11, d0[1]\n" + "vmlal.s16 q11, d11, d2[1]\n" + "vmlal.s16 q13, d11, d4[1]\n" + "vmlal.s16 q15, d11, d6[1]\n" + + "vld1.32 {d14, d15}, [%[a_ptr]]!\n" + "vmlal.s16 q8, d12, d0[2]\n" + "vmlal.s16 q10, d12, d2[2]\n" + "vmlal.s16 q12, d12, d4[2]\n" + "vmlal.s16 q14, d12, d6[2]\n" + "vmlal.s16 q9, d13, d0[2]\n" + "vmlal.s16 q11, d13, d2[2]\n" + "vmlal.s16 q13, d13, d4[2]\n" + "vmlal.s16 q15, d13, d6[2]\n" + + "vld1.32 {d8, d9}, [%[a_ptr]]!\n" + "vmlal.s16 q8, d14, d0[3]\n" + "vmlal.s16 q10, d14, d2[3]\n" + "vmlal.s16 q12, d14, d4[3]\n" + "vmlal.s16 q14, d14, d6[3]\n" + "vmlal.s16 q9, d15, d0[3]\n" + "vmlal.s16 q11, d15, d2[3]\n" + "vmlal.s16 q13, d15, d4[3]\n" + "vmlal.s16 q15, d15, d6[3]\n" + + "vld1.32 {d10, d11}, [%[a_ptr]]!\n" + "vmlal.s16 q8, d8, d1[0]\n" + "vmlal.s16 q10, d8, d3[0]\n" + "vmlal.s16 q12, d8, d5[0]\n" + "vmlal.s16 q14, d8, d7[0]\n" + "vmlal.s16 q9, d9, d1[0]\n" + "vmlal.s16 q11, d9, d3[0]\n" + "vmlal.s16 q13, d9, d5[0]\n" + "vmlal.s16 q15, d9, d7[0]\n" + + "vld1.32 {d12, d13}, [%[a_ptr]]!\n" + "vmlal.s16 q8, d10, d1[1]\n" + "vmlal.s16 q10, d10, d3[1]\n" + "vmlal.s16 q12, d10, d5[1]\n" + "vmlal.s16 q14, d10, d7[1]\n" + "vmlal.s16 q9, d11, d1[1]\n" + "vmlal.s16 q11, d11, d3[1]\n" + "vmlal.s16 q13, d11, d5[1]\n" + "vmlal.s16 q15, d11, d7[1]\n" + + "vld1.32 {d14, d15}, [%[a_ptr]]!\n" + "vmlal.s16 q8, d12, d1[2]\n" + "vmlal.s16 q10, d12, d3[2]\n" + "vmlal.s16 q12, d12, d5[2]\n" + "vmlal.s16 q14, d12, d7[2]\n" + "vmlal.s16 q9, d13, d1[2]\n" + "vmlal.s16 q11, d13, d3[2]\n" + "vmlal.s16 q13, d13, d5[2]\n" + "vmlal.s16 q15, d13, d7[2]\n" + + "beq 2f\n" + + "1:\n" + "vld1.32 {d8, d9}, [%[a_ptr]]!\n" + "vmlal.s16 q8, d14, d1[3]\n" + "vmlal.s16 q10, d14, d3[3]\n" + "vmlal.s16 q9, d15, d1[3]\n" + "vmlal.s16 q11, d15, d3[3]\n" + + "vld1.32 {d0, d1, d2, d3}, [%[b_ptr]]!\n" + "vmlal.s16 q12, d14, d5[3]\n" + "vmlal.s16 q14, d14, d7[3]\n" + "vmlal.s16 q13, d15, d5[3]\n" + "vmlal.s16 q15, d15, d7[3]\n" + + "vld1.32 {d4, d5, d6, d7}, [%[b_ptr]], %[LDB]\n" + "vmlal.s16 q8, d8, d0[0]\n" + "vmlal.s16 q10, d8, d2[0]\n" + "vmlal.s16 q12, d8, d4[0]\n" + "vmlal.s16 q14, d8, d6[0]\n" + + "vld1.32 {d10, d11}, [%[a_ptr]]!\n" + "vmlal.s16 q9, d9, d0[0]\n" + "vmlal.s16 q11, d9, d2[0]\n" + "vmlal.s16 q13, d9, d4[0]\n" + "vmlal.s16 q15, d9, d6[0]\n" + + "vld1.32 {d12, d13}, [%[a_ptr]]!\n" + "vmlal.s16 q8, d10, d0[1]\n" + "vmlal.s16 q10, d10, d2[1]\n" + "vmlal.s16 q12, d10, d4[1]\n" + "vmlal.s16 q14, d10, d6[1]\n" + "vmlal.s16 q9, d11, d0[1]\n" + "vmlal.s16 q11, d11, d2[1]\n" + "vmlal.s16 q13, d11, d4[1]\n" + "vmlal.s16 q15, d11, d6[1]\n" + + "vld1.32 {d14, d15}, [%[a_ptr]]!\n" + "vmlal.s16 q8, d12, d0[2]\n" + "vmlal.s16 q10, d12, d2[2]\n" + "vmlal.s16 q12, d12, d4[2]\n" + "vmlal.s16 q14, d12, d6[2]\n" + "vmlal.s16 q9, d13, d0[2]\n" + "vmlal.s16 q11, d13, d2[2]\n" + "vmlal.s16 q13, d13, d4[2]\n" + "vmlal.s16 q15, d13, d6[2]\n" + + "vld1.32 
{d8, d9}, [%[a_ptr]]!\n" + "vmlal.s16 q8, d14, d0[3]\n" + "vmlal.s16 q10, d14, d2[3]\n" + "vmlal.s16 q12, d14, d4[3]\n" + "vmlal.s16 q14, d14, d6[3]\n" + "vmlal.s16 q9, d15, d0[3]\n" + "vmlal.s16 q11, d15, d2[3]\n" + "vmlal.s16 q13, d15, d4[3]\n" + "vmlal.s16 q15, d15, d6[3]\n" + + "vld1.32 {d10, d11}, [%[a_ptr]]!\n" + "vmlal.s16 q8, d8, d1[0]\n" + "vmlal.s16 q10, d8, d3[0]\n" + "vmlal.s16 q12, d8, d5[0]\n" + "vmlal.s16 q14, d8, d7[0]\n" + "vmlal.s16 q9, d9, d1[0]\n" + "vmlal.s16 q11, d9, d3[0]\n" + "vmlal.s16 q13, d9, d5[0]\n" + "vmlal.s16 q15, d9, d7[0]\n" + + "vld1.32 {d12, d13}, [%[a_ptr]]!\n" + "vmlal.s16 q8, d10, d1[1]\n" + "vmlal.s16 q10, d10, d3[1]\n" + "vmlal.s16 q12, d10, d5[1]\n" + "vmlal.s16 q14, d10, d7[1]\n" + "vmlal.s16 q9, d11, d1[1]\n" + "vmlal.s16 q11, d11, d3[1]\n" + "vmlal.s16 q13, d11, d5[1]\n" + "vmlal.s16 q15, d11, d7[1]\n" + + "vld1.32 {d14, d15}, [%[a_ptr]]!\n" + "vmlal.s16 q8, d12, d1[2]\n" + "vmlal.s16 q10, d12, d3[2]\n" + "vmlal.s16 q12, d12, d5[2]\n" + "vmlal.s16 q14, d12, d7[2]\n" + "vmlal.s16 q9, d13, d1[2]\n" + "vmlal.s16 q11, d13, d3[2]\n" + "vmlal.s16 q13, d13, d5[2]\n" + "vmlal.s16 q15, d13, d7[2]\n" + + "subs %[K], %[K], #8\n" + "bne 1b\n" + + "2:\n" + "vmlal.s16 q8, d14, d1[3]\n" + "vmlal.s16 q10, d14, d3[3]\n" + "vmlal.s16 q9, d15, d1[3]\n" + "vmlal.s16 q11, d15, d3[3]\n" + "vst1.32 {d16, d17, d18, d19}, [%[output]]!\n" + "vmlal.s16 q12, d14, d5[3]\n" + "vmlal.s16 q14, d14, d7[3]\n" + "vmlal.s16 q13, d15, d5[3]\n" + "vmlal.s16 q15, d15, d7[3]\n" + "vst1.32 {d20, d21, d22, d23}, [%[output]]!\n" + "vst1.32 {d24, d25, d26, d27}, [%[output]]!\n" + "vst1.32 {d28, d29, d30, d31}, [%[output]]!\n" + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [output] "+r"(output), [LDB] "+r"(LDB) + : + : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", + "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", + "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28", + "d29", "d30", "d31", "cc", "memory"); +} + +} // anonymous namespace + +MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(gemm_nopack_s16_4x8); + +void gemm_nopack_s16_4x8::kern(const dt_int16* A, size_t LDA, const dt_int16* B, + size_t LDB, dt_int32* C, size_t LDC, size_t M, + size_t K, size_t N, const dt_int32*, void*, + bool trA, bool trB) const { + constexpr static size_t MB = 8; + constexpr static size_t KB = 8; + constexpr static size_t NB = 4; + constexpr static size_t CALCBLK = 4; + + megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0 && N % CALCBLK == 0); + + //! (m/8, k/8, 8, 8) * (k/8, n, 8) = (m/8, n, 8) + for (size_t m = 0; m < M; m += MB) { + dt_int32* output = C + (m / MB) * LDC; + const dt_int16* cur_B = B; + for (size_t n = 0; n < N; n += NB) { + kern_4x8(A, cur_B, LDB, K, output); + cur_B += KB * NB; + output += MB * NB; + } + A += LDA; + } +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/armv7/matrix_mul/int8/kernel_4x2x16.h b/dnn/src/armv7/matrix_mul/int8/kernel_4x2x16.h new file mode 100644 index 00000000..62f03730 --- /dev/null +++ b/dnn/src/armv7/matrix_mul/int8/kernel_4x2x16.h @@ -0,0 +1,625 @@ +/** + * \file dnn/src/armv7/matrix_mul/int8/kernel_4x2x16.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/armv7/matrix_mul/asm/common.h" + +namespace megdnn { +namespace armv7 { +namespace matmul_4x2x16 { + +/** + * Overview of register layout. + * + * A 2x16 block of Rhs is stored in 8 bit in d0--d3. + * A 4x16 block of Lhs is stored in 8 bit in d4--d7. That is only + * half of the register space required, so we loop over these registers + * twice. Only half of it, a 2x16 block, is stored in d4--d7 at + * any given time. + * + * A 4x2 block of accumulators is stored in q8--q15 (as 4x32 bit + * components which need to be horizontally-added at the end) + * + * Only then, having processed 16 levels of depth, do we need to + * horizontally add these int16x8 accumulators into the final + * int32x4 accumulators. + * + * As we do not have enough registers to store all 16 int16x8 + * temporary-16bit-accumulators, we have them cycle through q4--q7. + * + * \warning Fast kernel operating on int8 operands. + * It is assumed that one of the two int8 operands only takes values + * in [-127, 127], while the other may freely range in [-128, 127]. + * The issue with both operands taking the value -128 is that: + * -128*-128 + -128*-128 == -32768 overflows int16. + * Every other expression a*b + c*d, for any int8 a,b,c,d, fits in int16 + * range. That is the basic idea of this kernel. + * + * + * +--------+--------+ + * |d0[0-8] |d2[0-8] | + * Rhs +--------+--------+ + * |d1[0-8] |d3[0-8] | + * +--------+--------+ + * | | | + * + * Lhs | | | + * + * +--------+--------+ - - - - +------------------ + * |d4[0-8] |d5[0-8] | |q8[0-4] |q9[0-4] | + * |d6[0-8] |d7[0-8] | |q10[0-4]|q11[0-4]| + * |d4[0-8] |d5[0-8] | |q12[0-4]|q13[0-4]| + * |d6[0-8] |d7[0-8] | |q14[0-4]|q15[0-4]| + * +--------+--------+ - - - - +------------------ + * + * Accumulator + */ + +static void kern_4x2(const int8_t* packA, const int8_t* packB, int K, + int32_t* output, int LDC, bool is_first_k, int m_remain, + int n_remain) { + MEGDNN_MARK_USED_VAR(m_remain); + MEGDNN_MARK_USED_VAR(n_remain); + K /= 16; + const int8_t* a_ptr = packA; + const int8_t* b_ptr = packB; + + LDC = LDC * sizeof(int32_t); +// clang-format off +#define LOAD_LINE(reg_index, n) \ + "cmp r5, #0 \n" \ + "beq 102f\n" \ + "cmp %[n_remain], #2\n" \ + "blt 100" n "f\n" \ + "vld1.32 {d" reg_index "}, [r" n "]\n" \ + "b 101" n "f\n" \ + "100" n ":\n" \ + "cmp %[n_remain], #0\n" \ + "beq 101" n "f\n" \ + "vld1.32 {d" reg_index "[0]}, [r" n "]\n" \ + "101" n ":\n" \ + "subs r5, r5, #1\n" + +#define LOAD_C \ + "mov r5, %[m_remain]\n" \ + LOAD_LINE("16", "0") \ + LOAD_LINE("17", "1") \ + LOAD_LINE("18", "2") \ + LOAD_LINE("19", "3") \ + "102:\n" + +#define STORE_LINE(reg_index, n) \ + "cmp r5, #0 \n" \ + "beq 105f\n" \ + "cmp %[n_remain], #2\n" \ + "blt 103" n "f\n" \ + "vst1.32 {d" reg_index "}, [r" n "]\n" \ + "b 104" n "f\n" \ + "103" n ":\n" \ + "cmp %[n_remain], #0\n" \ + "beq 104" n "f\n" \ + "vst1.32 {d" reg_index "[0]}, [r" n " ]\n" \ + "104" n ":\n" \ + "subs r5, r5, #1\n" + +#define STORE_C \ + "mov r5, %[m_remain]\n" \ + STORE_LINE("8", "0") \ + STORE_LINE("9", "1") \ + STORE_LINE("10", "2") \ + STORE_LINE("11", "3") \ + "105:\n" + + register int32_t* outptr asm("r0") = output; + asm volatile( + "add r1, r0, %[LDC]\n" + "add r2, r1, %[LDC]\n" + "add r3, r2, %[LDC]\n" + "vldr d0, [%[b_ptr], #0]\n" + "vmov.i32 q8, #0\n" + "vldr d4, [%[a_ptr], #0]\n" + "vmov.i32 q9, #0\n" + "vldr d2, [%[b_ptr], #16]\n" + "vmov.i32 q10, q8\n" + "vldr d6, [%[a_ptr], #16]\n" + "vmov.i32 q11, q8\n" + "vldr d1, [%[b_ptr], 
#8]\n" + "vmov.i32 q12, q8\n" + "vldr d5, [%[a_ptr], #8]\n" + "vmov.i32 q13, q8\n" + "vldr d3, [%[b_ptr], #24]\n" + "vmov.i32 q14, q8\n" + "vldr d7, [%[a_ptr], #24]\n" + "vmov.i32 q15, q8\n" + + // General loop. + "1:\n" + + // Multiply 8 first levels of depth. + "vmull.s8 q4, d0, d4\n" + "add %[b_ptr], %[b_ptr], #32\n" + "vmull.s8 q5, d2, d4\n" + "vldr d4, [%[a_ptr], #32]\n" + "vmull.s8 q6, d0, d6\n" + "vmull.s8 q7, d2, d6\n" + "vldr d6, [%[a_ptr], #48]\n" + + // Multiply-accumulate second-half, again into the same + // 16bit local accumulator registers. This is where we + // take advantage of having int8 instead of uint8 and therefore + // being able to accumulate two products into int16. + "vmlal.s8 q4, d1, d5\n" + "vmlal.s8 q5, d3, d5\n" + "vldr d5, [%[a_ptr], #40]\n" + "vmlal.s8 q6, d1, d7\n" + "vmlal.s8 q7, d3, d7\n" + "vldr d7, [%[a_ptr], #56]\n" + + // Add pairwise, accumulate into 32-bit accumulators. + "vpadal.s16 q8, q4\n" + "add %[a_ptr], %[a_ptr], #64\n" + "vpadal.s16 q9, q5\n" + "subs %[K], %[K], #1\n" + "vpadal.s16 q10, q6\n" + "vpadal.s16 q11, q7\n" + + "beq 2f\n" + + // Multiply first half. + "vmull.s8 q4, d0, d4\n" + "vmull.s8 q5, d2, d4\n" + "vldr d4, [%[a_ptr], #0]\n" + "vmull.s8 q6, d0, d6\n" + "vldr d0, [%[b_ptr], #0]\n" + "vmull.s8 q7, d2, d6\n" + "vldr d2, [%[b_ptr], #16]\n" + + // Multiply-accumulate second-half, again into the same + // 16bit local accumulator registers. This is where we + // take advantage of having int8 instead of uint8 and therefore + // being able to accumulate two products into int16. + "vmlal.s8 q4, d1, d5\n" + "vldr d6, [%[a_ptr], #16]\n" + "vmlal.s8 q5, d3, d5\n" + "vldr d5, [%[a_ptr], #8]\n" + "vmlal.s8 q6, d1, d7\n" + "vldr d1, [%[b_ptr], #8]\n" + "vmlal.s8 q7, d3, d7\n" + "vldr d3, [%[b_ptr], #24]\n" + + // Add pairwise, accumulate into 32-bit accumulators. + "vpadal.s16 q12, q4\n" + "vldr d7, [%[a_ptr], #24]\n" + "vpadal.s16 q13, q5\n" + "vpadal.s16 q14, q6\n" + "vpadal.s16 q15, q7\n" + + "b 1b\n" + + "2:\n" + + // Multiply first half. + "vmull.s8 q4, d0, d4\n" + "vmull.s8 q5, d2, d4\n" + "vmull.s8 q6, d0, d6\n" + "vmull.s8 q7, d2, d6\n" + + // Multiply-accumulate second-half, again into the same + // 16bit local accumulator registers. This is where we + // take advantage of having int8 instead of uint8 and therefore + // being able to accumulate two products into int16. + "vmlal.s8 q4, d1, d5\n" + "vmlal.s8 q5, d3, d5\n" + "vmlal.s8 q6, d1, d7\n" + "vmlal.s8 q7, d3, d7\n" + + // Add pairwise, accumulate into 32-bit accumulators. + "vpadal.s16 q12, q4\n" + "vpadal.s16 q13, q5\n" + "vpadal.s16 q14, q6\n" + "vpadal.s16 q15, q7\n" + "cmp %[is_first_k], #1\n" + + // Reduce 32bit accumulators horizontally. + "vpadd.s32 d0, d16, d17\n" + "vpadd.s32 d1, d18, d19\n" + "vpadd.s32 d2, d20, d21\n" + "vpadd.s32 d3, d22, d23\n" + "vpadd.s32 d4, d24, d25\n" + "vpadd.s32 d5, d26, d27\n" + "vpadd.s32 d6, d28, d29\n" + "vpadd.s32 d7, d30, d31\n" + + "bne 3f\n" + + // Reduce 32bit accumulators horizontally, second pass + // (each pass adds pairwise. we need to add 4-wise). + "vpadd.s32 d8, d0, d1\n" + "vpadd.s32 d9, d2, d3\n" + "vpadd.s32 d10, d4, d5\n" + "vpadd.s32 d11, d6, d7\n" + + "b 4f\n" + + "3:\n" + + // Reduce 32bit accumulators horizontally, second pass + // (each pass adds pairwise. we need to add 4-wise), + // and load destination values from memory. 
+ LOAD_C // + "vpadd.s32 d8, d0, d1\n" + "vpadd.s32 d9, d2, d3\n" + "vpadd.s32 d10, d4, d5\n" + "vpadd.s32 d11, d6, d7\n" + + // Add horizontally-reduced accumulators into + // the values loaded from memory + "vadd.s32 q4, q8, q4\n" + "vadd.s32 q5, q9, q5\n" + + "4:\n" + // Store back into memory + STORE_C + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), + [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), + [outptr] "+r"(outptr), [m_remain] "+r" (m_remain), + [n_remain] "+r" (n_remain) + : + : "cc", "memory", "r1", "r2", "r3", "r4", "r5", + "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", + "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", + "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27", + "d28", "d29", "d30", "d31"); + +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +} + +static void gemm_s8_4x2_pack_A_n(dt_int8* outptr, const dt_int8* inptr, + int ldin, int y0, int ymax, int k0, int kmax) { + int8_t zerobuff[16]; + std::memset(zerobuff, 0, sizeof(int8_t) * 16); + + int y = y0; + for (; y + 3 < ymax; y += 4) { + const int8_t* inptr0 = inptr + y * ldin + k0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + + int K = kmax - k0; + //! read 4 * 16 in each row + for (; K > 15; K -= 16) { + interleave_4x16_1_b(inptr0, inptr1, inptr2, inptr3, outptr); + } + + if (K > 0) { + interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 16, K); + } + } + for (; y < ymax; y += 4) { + const int8_t* inptr0 = inptr + y * ldin + k0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + + int K = kmax - k0; + //! 
read 4 * 16 in each row + for (; K > 15; K -= 16) { + if (y + 3 >= ymax) { + switch (y + 3 - ymax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + interleave_4x16_1_b(inptr0, inptr1, inptr2, inptr3, outptr); + } + + if (K > 0) { + if (y + 3 >= ymax) { + switch (y + 3 - ymax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 16, K); + } + } +} + +static void gemm_s8_4x2_pack_A_t(dt_int8* out, const dt_int8* in, int ldin, + int x0, int xmax, int k0, int kmax) { + int8_t zerobuff[16]; + std::memset(zerobuff, 0, sizeof(int8_t) * 16); + const int ksize = kmax - k0; + const int ksize4 = round_up(ksize, 16) * 4; + int8_t* outptr = out; + + int k = k0; + for (; k < kmax; k += 16) { + int ki = k; + for (int cnt = 0; cnt < 2; ki += 8, cnt++) { + const int8_t* inptr0 = in + ki * ldin + x0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + const int8_t* inptr4 = inptr3 + ldin; + const int8_t* inptr5 = inptr4 + ldin; + const int8_t* inptr6 = inptr5 + ldin; + const int8_t* inptr7 = inptr6 + ldin; + int8_t* outptr_inner = outptr + ki - k; + + int remain = std::min(ki + 7 - kmax, 7); + int x = x0; + for (; x + 3 < xmax; x += 4) { + if (remain >= 0) { + switch (remain) { + case 7: + inptr0 = zerobuff; + case 6: + inptr1 = zerobuff; + case 5: + inptr2 = zerobuff; + case 4: + inptr3 = zerobuff; + case 3: + inptr4 = zerobuff; + case 2: + inptr5 = zerobuff; + case 1: + inptr6 = zerobuff; + case 0: + inptr7 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + transpose_4x16_1_b_helper(inptr0, inptr1, inptr2, inptr3, + inptr4, inptr5, inptr6, inptr7, + outptr_inner); + outptr_inner += ksize4; + } + + if (x < xmax) { + if (remain >= 0) { + switch (remain) { + case 7: + inptr0 = zerobuff; + case 6: + inptr1 = zerobuff; + case 5: + inptr2 = zerobuff; + case 4: + inptr3 = zerobuff; + case 3: + inptr4 = zerobuff; + case 2: + inptr5 = zerobuff; + case 1: + inptr6 = zerobuff; + case 0: + inptr7 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + for (; x < xmax; x++) { + *outptr_inner++ = *inptr0++; + *outptr_inner++ = *inptr1++; + *outptr_inner++ = *inptr2++; + *outptr_inner++ = *inptr3++; + *outptr_inner++ = *inptr4++; + *outptr_inner++ = *inptr5++; + *outptr_inner++ = *inptr6++; + *outptr_inner++ = *inptr7++; + outptr_inner += 8; + } + } + } + + outptr += 16 * 4; + } +} + +static void gemm_s8_4x2_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, + int x0, int xmax, int k0, int kmax) { + int8_t zerobuff[16]; + std::memset(zerobuff, 0, sizeof(int8_t) * 16); + const int ksize = kmax - k0; + const int ksize2 = round_up(ksize, 16) * 2; + int8_t* outptr = out; + + int k = k0; + for (; k < kmax; k += 16) { + int ki = k; + for (int cnt = 0; cnt < 2; ki += 8, cnt++) { + const int8_t* inptr0 = in + ki * ldin + x0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + const int8_t* inptr4 = inptr3 + ldin; + const int8_t* inptr5 = inptr4 + ldin; + const int8_t* inptr6 = inptr5 + ldin; + const int8_t* inptr7 = inptr6 + ldin; + int8_t* outptr_inner = outptr + ki - k; + + int remain = std::min(ki + 7 - kmax, 7); + int x = x0; + for (; x + 1 < xmax; x += 2) { + if (remain >= 0) { + switch (remain) { + case 7: + inptr0 = zerobuff; 
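+                        // intentional fall-through: every source row from here down lies
+                        // beyond kmax, so it is redirected to the shared zero buffer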
+ case 6: + inptr1 = zerobuff; + case 5: + inptr2 = zerobuff; + case 4: + inptr3 = zerobuff; + case 3: + inptr4 = zerobuff; + case 2: + inptr5 = zerobuff; + case 1: + inptr6 = zerobuff; + case 0: + inptr7 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + transpose_2x16_1_b_helper(inptr0, inptr1, inptr2, inptr3, + inptr4, inptr5, inptr6, inptr7, + outptr_inner); + outptr_inner += ksize2; + } + + if (x < xmax) { + if (remain >= 0) { + switch (remain) { + case 7: + inptr0 = zerobuff; + case 6: + inptr1 = zerobuff; + case 5: + inptr2 = zerobuff; + case 4: + inptr3 = zerobuff; + case 3: + inptr4 = zerobuff; + case 2: + inptr5 = zerobuff; + case 1: + inptr6 = zerobuff; + case 0: + inptr7 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + for (; x < xmax; x++) { + *outptr_inner++ = *inptr0++; + *outptr_inner++ = *inptr1++; + *outptr_inner++ = *inptr2++; + *outptr_inner++ = *inptr3++; + *outptr_inner++ = *inptr4++; + *outptr_inner++ = *inptr5++; + *outptr_inner++ = *inptr6++; + *outptr_inner++ = *inptr7++; + outptr_inner += 8; + } + } + } + + outptr += 16 * 2; + } +} + +static void gemm_s8_4x2_pack_B_t(dt_int8* outptr, const dt_int8* inptr, + int ldin, int y0, int ymax, int k0, int kmax) { + int8_t zerobuff[16]; + std::memset(zerobuff, 0, sizeof(int8_t) * 16); + + int y = y0; + for (; y + 1 < ymax; y += 2) { + const int8_t* inptr0 = inptr + y * ldin + k0; + const int8_t* inptr1 = inptr0 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + + int K = kmax - k0; + //! read 16 * 2 in each row + for (; K > 15; K -= 16) { + interleave_2x16_1_b(inptr0, inptr1, outptr); + } + + if (K > 0) { + interleave_2(inptr0, inptr1, outptr, 16, K); + } + } + for (; y < ymax; y += 2) { + const int8_t* inptr0 = inptr + y * ldin + k0; + const int8_t* inptr1 = inptr0 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + + int K = kmax - k0; + for (; K > 15; K -= 16) { + if (y + 1 >= ymax) { + switch (y + 1 - ymax) { + case 0: + inptr1 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + interleave_2x16_1_b(inptr0, inptr1, outptr); + } + + if (K > 0) { + if (y + 1 >= ymax) { + switch (y + 1 - ymax) { + case 0: + inptr1 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + interleave_2(inptr0, inptr1, outptr, 16, K); + } + } +} + +} // matmul_4x2x16 +} // namespace armv7 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/armv7/matrix_mul/int8/kernel_4x8x8.h b/dnn/src/armv7/matrix_mul/int8/kernel_4x8x8.h new file mode 100644 index 00000000..6d01e9d4 --- /dev/null +++ b/dnn/src/armv7/matrix_mul/int8/kernel_4x8x8.h @@ -0,0 +1,750 @@ +/** + * \file dnn/src/armv7/matrix_mul/int8/kernel_4x8x8.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/armv7/matrix_mul/asm/common.h" + +namespace megdnn { +namespace armv7 { +namespace matmul_4x8x8 { + +static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, + int32_t* output, int LDC, bool is_first_k, + size_t m_remain) { + K /= 8; + const int8_t* a_ptr = packA; + const int8_t* b_ptr = packB; + + LDC = LDC * sizeof(int32_t); + size_t x0 = 0; + +// clang-format off +#define LOAD_LINE(reg_index1, reg_index2, reg_index3, reg_index4, n) \ + "cmp %[x0], #0 \n" \ + "beq 100f\n" \ + "vld1.32 {d" reg_index1 ", d" reg_index2 ", d" reg_index3 ", d" \ + reg_index4 "}, [r" n "]!\n" \ + "subs %[x0], %[x0], #1\n" + +#define LOAD_C \ + "mov %[x0], %[m_remain]\n" \ + LOAD_LINE("8", "9", "10", "11", "0") \ + LOAD_LINE("12", "13", "14", "15", "1") \ + LOAD_LINE("16", "17", "18", "19", "2") \ + LOAD_LINE("20", "21", "22", "23", "3") \ + "100:\n" + +#define STORE_LINE(reg_index1, reg_index2, reg_index3, reg_index4, n) \ + "cmp %[x0], #0 \n" \ + "beq 101f\n" \ + "vst1.32 {d" reg_index1 ", d" reg_index2 ", d" reg_index3 ", d" \ + reg_index4 "}, [r" n "]!\n" \ + "subs %[x0], %[x0], #1\n" + +#define STORE_C \ + "mov %[x0], %[m_remain]\n" \ + STORE_LINE("8", "9", "10", "11", "0") \ + STORE_LINE("12", "13", "14", "15", "1") \ + STORE_LINE("16", "17", "18", "19", "2") \ + STORE_LINE("20", "21", "22", "23", "3") \ + "101:\n" + + // clang-format on + + register int32_t* outptr asm("r0") = output; + asm volatile( + // load accumulator C + "add r1, r0, %[LDC]\n" + "add r2, r1, %[LDC]\n" + "add r3, r2, %[LDC]\n" + "cmp %[is_first_k], #1\n" + "beq 1f\n" LOAD_C + + "b 2f\n" + + "1:\n" + "veor.s32 q4, q4, q4\n" + "veor.s32 q5, q5, q5\n" + "veor.s32 q6, q6, q6\n" + "veor.s32 q7, q7, q7\n" + "veor.s32 q8, q8, q8\n" + "veor.s32 q9, q9, q9\n" + "veor.s32 q10, q10, q10\n" + "veor.s32 q11, q11, q11\n" + + "2: \n" + "vld1.8 {d24}, [%[b_ptr]]!\n" + "vld1.8 {d0}, [%[a_ptr]]!\n" + "vld1.8 {d2}, [%[a_ptr]]!\n" + "vld1.8 {d4}, [%[a_ptr]]!\n" + "vld1.8 {d6}, [%[a_ptr]]!\n" + "vmovl.s8 q12, d24\n" + "vmovl.s8 q0, d0\n" + "vmovl.s8 q1, d2\n" + "vmovl.s8 q2, d4\n" + "vmovl.s8 q3, d6\n" + + "vld1.8 {d26}, [%[b_ptr]]!\n" + "vmlal.s16 q4, d24, d0[0]\n" + "vmlal.s16 q6, d24, d2[0]\n" + "vmlal.s16 q8, d24, d4[0]\n" + "vmlal.s16 q10, d24, d6[0]\n" + "vmovl.s8 q13, d26\n" + "vmlal.s16 q5, d25, d0[0]\n" + "vmlal.s16 q7, d25, d2[0]\n" + "vmlal.s16 q9, d25, d4[0]\n" + "vmlal.s16 q11, d25, d6[0]\n" + + "vld1.8 {d24}, [%[b_ptr]]!\n" + "vmlal.s16 q4, d26, d0[1]\n" + "vmlal.s16 q6, d26, d2[1]\n" + "vmlal.s16 q8, d26, d4[1]\n" + "vmlal.s16 q10, d26, d6[1]\n" + "vmovl.s8 q12, d24\n" + "vmlal.s16 q5, d27, d0[1]\n" + "vmlal.s16 q7, d27, d2[1]\n" + "vmlal.s16 q9, d27, d4[1]\n" + "vmlal.s16 q11, d27, d6[1]\n" + + "vld1.8 {d26}, [%[b_ptr]]!\n" + "vmlal.s16 q4, d24, d0[2]\n" + "vmlal.s16 q6, d24, d2[2]\n" + "vmlal.s16 q8, d24, d4[2]\n" + "vmlal.s16 q10, d24, d6[2]\n" + "vmovl.s8 q13, d26\n" + "vmlal.s16 q5, d25, d0[2]\n" + "vmlal.s16 q7, d25, d2[2]\n" + "vmlal.s16 q9, d25, d4[2]\n" + "vmlal.s16 q11, d25, d6[2]\n" + + "vld1.8 {d24}, [%[b_ptr]]!\n" + "vmlal.s16 q4, d26, d0[3]\n" + "vmlal.s16 q6, d26, d2[3]\n" + "vmlal.s16 q8, d26, d4[3]\n" + "vmlal.s16 q10, d26, d6[3]\n" + "vmovl.s8 q12, d24\n" + "vmlal.s16 q5, d27, d0[3]\n" + "vmlal.s16 q7, d27, d2[3]\n" + "vmlal.s16 q9, d27, d4[3]\n" + "vmlal.s16 q11, d27, d6[3]\n" + + "vld1.8 {d26}, [%[b_ptr]]!\n" + "vmlal.s16 q4, d24, d1[0]\n" + "vmlal.s16 q6, d24, d3[0]\n" + "vmlal.s16 q8, d24, d5[0]\n" + "vmlal.s16 q10, d24, d7[0]\n" + 
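// lanes d1/d3/d5/d7 are the upper halves of the widened A rows and
+            // cover depth levels 4-7 of the current 8-deep block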
"vmovl.s8 q13, d26\n" + "vmlal.s16 q5, d25, d1[0]\n" + "vmlal.s16 q7, d25, d3[0]\n" + "vmlal.s16 q9, d25, d5[0]\n" + "vmlal.s16 q11, d25, d7[0]\n" + + "vld1.8 {d24}, [%[b_ptr]]!\n" + "vmlal.s16 q4, d26, d1[1]\n" + "vmlal.s16 q6, d26, d3[1]\n" + "vmlal.s16 q8, d26, d5[1]\n" + "vmlal.s16 q10, d26, d7[1]\n" + "vmovl.s8 q12, d24\n" + "vmlal.s16 q5, d27, d1[1]\n" + "vmlal.s16 q7, d27, d3[1]\n" + "vmlal.s16 q9, d27, d5[1]\n" + "vmlal.s16 q11, d27, d7[1]\n" + + "vld1.8 {d26}, [%[b_ptr]]!\n" + "vmlal.s16 q4, d24, d1[2]\n" + "vmlal.s16 q6, d24, d3[2]\n" + "vmlal.s16 q8, d24, d5[2]\n" + "vmlal.s16 q10, d24, d7[2]\n" + "vmovl.s8 q13, d26\n" + "vmlal.s16 q5, d25, d1[2]\n" + "vmlal.s16 q7, d25, d3[2]\n" + "vmlal.s16 q9, d25, d5[2]\n" + "vmlal.s16 q11, d25, d7[2]\n" + + "vmlal.s16 q4, d26, d1[3]\n" + "vmlal.s16 q6, d26, d3[3]\n" + "vmlal.s16 q8, d26, d5[3]\n" + "vmlal.s16 q10, d26, d7[3]\n" + "vmlal.s16 q5, d27, d1[3]\n" + "vmlal.s16 q7, d27, d3[3]\n" + "vmlal.s16 q9, d27, d5[3]\n" + "vmlal.s16 q11, d27, d7[3]\n" + + "subs %[K], %[K], #1\n" + "bne 2b\n" + + "3:\n" STORE_C + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [x0] "+r"(x0), [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), + [m_remain] "+r"(m_remain), [outptr] "+r"(outptr) + : + : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", + "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", + "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28", + "d29", "d30", "d31", "r1", "r2", "r3", "cc", "memory"); + +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +} + +static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, + int32_t* output, int LDC, bool is_first_k, size_t m_remain, + size_t n_remain) { + K /= 8; + const int8_t* a_ptr = packA; + const int8_t* b_ptr = packB; + + LDC = LDC * sizeof(int32_t); + size_t x0 = 0; + +// clang-format off +#define LOAD_LINE(reg_index1, reg_index2, n) \ + "cmp %[x0], #0 \n" \ + "beq 102f\n" \ + "cmp %[n_remain], #4\n" \ + "blt 100" n "f\n" \ + "vld1.32 {d" reg_index1 ", d" reg_index2 "}, [r" n " ]!\n" \ + "b 101" n "f\n" \ + "100" n ":\n" \ + "cmp %[n_remain], #0\n" \ + "beq 101" n "f\n" \ + "vld1.32 {d" reg_index1 "[0]}, [r" n " ]!\n" \ + "cmp %[n_remain], #1\n" \ + "beq 101" n "f\n" \ + "vld1.32 {d" reg_index1 "[1]}, [r" n " ]!\n" \ + "cmp %[n_remain], #2\n" \ + "beq 101" n "f\n" \ + "vld1.32 {d" reg_index2 "[0]}, [r" n " ]!\n" \ + "101" n ":\n" \ + "subs %[x0], %[x0], #1\n" + +#define LOAD_C \ + "mov %[x0], %[m_remain]\n" \ + "mov r1, r0\n" \ + LOAD_LINE("8", "9", "1") \ + "add r1, r0, %[LDC]\n" \ + "add r0, r0, %[LDC]\n" \ + LOAD_LINE("10", "11", "1") \ + "add r1, r0, %[LDC]\n" \ + "add r0, r0, %[LDC]\n" \ + LOAD_LINE("12", "13", "1") \ + "add r1, r0, %[LDC]\n" \ + LOAD_LINE("14", "15", "1") \ + "102:\n" + +#define STORE_LINE(reg_index1, reg_index2, n) \ + "cmp %[x0], #0 \n" \ + "beq 105f\n" \ + "cmp %[n_remain], #4\n" \ + "blt 103" n "f\n" \ + "vst1.32 {d" reg_index1 ", d" reg_index2 "}, [r" n " ]!\n" \ + "b 104" n "f\n" \ + "103" n ":\n" \ + "cmp %[n_remain], #0\n" \ + "beq 104" n "f\n" \ + "vst1.32 {d" reg_index1 "[0]}, [r" n " ]!\n" \ + "cmp %[n_remain], #1\n" \ + "beq 104" n "f\n" \ + "vst1.32 {d" reg_index1 "[1]}, [r" n " ]!\n" \ + "cmp %[n_remain], #2\n" \ + "beq 104" n "f\n" \ + "vst1.32 {d" reg_index2 "[0]}, [r" n " ]!\n" \ + "104" n ":\n" \ + "subs %[x0], %[x0], #1\n" + +#define STORE_C \ + "mov %[x0], %[m_remain]\n" \ + "mov r1, r0\n" \ + STORE_LINE("8", "9", "1") \ + "add r1, r0, %[LDC]\n" \ + "add r0, r0, %[LDC]\n" \ + 
STORE_LINE("10", "11", "1") \ + "add r1, r0, %[LDC]\n" \ + "add r0, r0, %[LDC]\n" \ + STORE_LINE("12", "13", "1") \ + "add r1, r0, %[LDC]\n" \ + STORE_LINE("14", "15", "1") \ + "105:\n" + + // clang-format on + + register int32_t* outptr asm("r0") = output; + asm volatile( + // load accumulator C + "cmp %[is_first_k], #1\n" + "beq 1f\n" LOAD_C + + "b 2f\n" + + "1:\n" + "veor.s32 q4, q4, q4\n" + "veor.s32 q5, q5, q5\n" + "veor.s32 q6, q6, q6\n" + "veor.s32 q7, q7, q7\n" + + "2: \n" + "vld1.32 {d16[0]}, [%[b_ptr]]!\n" + "vld1.8 {d0}, [%[a_ptr]]!\n" + "vld1.8 {d2}, [%[a_ptr]]!\n" + "vld1.8 {d4}, [%[a_ptr]]!\n" + "vld1.8 {d6}, [%[a_ptr]]!\n" + "vmovl.s8 q8, d16\n" + "vmovl.s8 q0, d0\n" + "vmovl.s8 q1, d2\n" + "vmovl.s8 q2, d4\n" + "vmovl.s8 q3, d6\n" + + "vld1.32 {d18[0]}, [%[b_ptr]]!\n" + "vmlal.s16 q4, d16, d0[0]\n" + "vmlal.s16 q5, d16, d2[0]\n" + "vmovl.s8 q9, d18\n" + "vmlal.s16 q6, d16, d4[0]\n" + "vmlal.s16 q7, d16, d6[0]\n" + + "vld1.32 {d16[0]}, [%[b_ptr]]!\n" + "vmlal.s16 q4, d18, d0[1]\n" + "vmlal.s16 q5, d18, d2[1]\n" + "vmovl.s8 q8, d16\n" + "vmlal.s16 q6, d18, d4[1]\n" + "vmlal.s16 q7, d18, d6[1]\n" + + "vld1.32 {d18[0]}, [%[b_ptr]]!\n" + "vmlal.s16 q4, d16, d0[2]\n" + "vmlal.s16 q5, d16, d2[2]\n" + "vmovl.s8 q9, d18\n" + "vmlal.s16 q6, d16, d4[2]\n" + "vmlal.s16 q7, d16, d6[2]\n" + + "vld1.32 {d16[0]}, [%[b_ptr]]!\n" + "vmlal.s16 q4, d18, d0[3]\n" + "vmlal.s16 q5, d18, d2[3]\n" + "vmovl.s8 q8, d16\n" + "vmlal.s16 q6, d18, d4[3]\n" + "vmlal.s16 q7, d18, d6[3]\n" + + "vld1.32 {d18[0]}, [%[b_ptr]]!\n" + "vmlal.s16 q4, d16, d1[0]\n" + "vmlal.s16 q5, d16, d3[0]\n" + "vmovl.s8 q9, d18\n" + "vmlal.s16 q6, d16, d5[0]\n" + "vmlal.s16 q7, d16, d7[0]\n" + + "vld1.32 {d16[0]}, [%[b_ptr]]!\n" + "vmlal.s16 q4, d18, d1[1]\n" + "vmlal.s16 q5, d18, d3[1]\n" + "vmovl.s8 q8, d16\n" + "vmlal.s16 q6, d18, d5[1]\n" + "vmlal.s16 q7, d18, d7[1]\n" + + "vld1.32 {d18[0]}, [%[b_ptr]]!\n" + "vmlal.s16 q4, d16, d1[2]\n" + "vmlal.s16 q5, d16, d3[2]\n" + "vmovl.s8 q9, d18\n" + "vmlal.s16 q6, d16, d5[2]\n" + "vmlal.s16 q7, d16, d7[2]\n" + + "vmlal.s16 q4, d18, d1[3]\n" + "vmlal.s16 q5, d18, d3[3]\n" + "vmlal.s16 q6, d18, d5[3]\n" + "vmlal.s16 q7, d18, d7[3]\n" + + "subs %[K], %[K], #1\n" + "bne 2b\n" + + "3:\n" STORE_C + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [outptr] "+r"(outptr), + [K] "+r"(K), [is_first_k] "+r"(is_first_k), [LDC] "+r"(LDC), + [x0] "+r"(x0), [m_remain] "+r"(m_remain), + [n_remain] "+r"(n_remain) + : + : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", + "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", + "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28", + "d29", "d30", "d31", "r1", "cc", "memory"); + +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +} + + +static void gemm_s8_4x8_pack_A_n(dt_int8* outptr, const dt_int8* inptr, + int ldin, int y0, int ymax, int k0, int kmax) { + int8_t zerobuff[16]; + std::memset(zerobuff, 0, sizeof(int8_t) * 16); + + int y = y0; + for (; y < ymax; y += 4) { + const int8_t* inptr0 = inptr + y * ldin + k0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + + int K = kmax - k0; + for (; K > 15; K -= 16) { + if (y + 3 >= ymax) { + switch (y + 3 - ymax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + interleave_4x8_2_b(inptr0, inptr1, 
inptr2, inptr3, outptr); + } + + if (K > 0) { + if (y + 3 >= ymax) { + switch (y + 3 - ymax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 8, K); + } + } +} + +static void gemm_s8_4x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in, + int ldin, int x0, int xmax, int k0, + int kmax) { + int8_t zerobuff[16]; + std::memset(zerobuff, 0, sizeof(int8_t) * 16); + + const int ksize = kmax - k0; + const int ksize4 = round_up(ksize, 8) * 4; + int8_t* outptr = out; + int8_t* outptr_base = out; + + int k = k0; + for (; k < kmax; k += 8) { + const int8_t* inptr0 = in + k * ldin + x0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + const int8_t* inptr4 = inptr3 + ldin; + const int8_t* inptr5 = inptr4 + ldin; + const int8_t* inptr6 = inptr5 + ldin; + const int8_t* inptr7 = inptr6 + ldin; + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + prefetch_2x(inptr4); + prefetch_2x(inptr5); + prefetch_2x(inptr6); + prefetch_2x(inptr7); + + int x = x0; + outptr = outptr_base; + + for (; x + 3 < xmax; x += 4) { + if (k + 7 >= kmax) { + switch (k + 7 - kmax) { + case 6: + inptr1 = zerobuff; + case 5: + inptr2 = zerobuff; + case 4: + inptr3 = zerobuff; + case 3: + inptr4 = zerobuff; + case 2: + inptr5 = zerobuff; + case 1: + inptr6 = zerobuff; + case 0: + inptr7 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + transpose_4x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, + inptr6, inptr7, outptr); + outptr += ksize4; + } + + if (x < xmax) { + if (k + 7 >= kmax) { + switch (k + 7 - kmax) { + case 6: + inptr1 = zerobuff; + case 5: + inptr2 = zerobuff; + case 4: + inptr3 = zerobuff; + case 3: + inptr4 = zerobuff; + case 2: + inptr5 = zerobuff; + case 1: + inptr6 = zerobuff; + case 0: + inptr7 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, + inptr7, outptr, 4, xmax - x); + } + + outptr_base += 4 * 8; + } +} + +static void gemm_s8_4x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, + int x0, int xmax, int k0, int kmax) { + int8_t zerobuff[16]; + std::memset(zerobuff, 0, sizeof(int8_t) * 16); + const int ksize = kmax - k0; + const int ksize4 = round_up(ksize, 8) * 4; + const int ksize8 = ksize4 * 2; + int8_t* outptr = out; + int8_t* outptr_base = out; + int8_t* outptr_interleave = nullptr; + int8_t* outptr_base4 = out + ((xmax - x0) / 8) * ksize8; + + int k = k0; + for (; k < kmax; k += 8) { + const int8_t* inptr0 = in + k * ldin + x0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + const int8_t* inptr4 = inptr3 + ldin; + const int8_t* inptr5 = inptr4 + ldin; + const int8_t* inptr6 = inptr5 + ldin; + const int8_t* inptr7 = inptr6 + ldin; + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + prefetch_2x(inptr4); + prefetch_2x(inptr5); + prefetch_2x(inptr6); + prefetch_2x(inptr7); + + int x = x0; + outptr = outptr_base; + + for (; x + 7 < xmax; x += 8) { + if (k + 7 >= kmax) { + switch (k + 7 - kmax) { + case 6: + inptr1 = zerobuff; + case 5: + inptr2 = zerobuff; + case 4: + inptr3 = zerobuff; + case 3: + inptr4 = zerobuff; + case 2: + inptr5 = zerobuff; + case 1: + inptr6 = zerobuff; + case 0: + inptr7 = zerobuff; + break; + default: + 
megdnn_assert(0); + } + } + outptr_interleave = outptr; + interleave_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, + inptr6, inptr7, outptr_interleave); + outptr += ksize8; + } + + outptr = outptr_base4; + for (; x + 3 < xmax; x += 4) { + if (k + 7 >= kmax) { + switch (k + 7 - kmax) { + case 6: + inptr1 = zerobuff; + case 5: + inptr2 = zerobuff; + case 4: + inptr3 = zerobuff; + case 3: + inptr4 = zerobuff; + case 2: + inptr5 = zerobuff; + case 1: + inptr6 = zerobuff; + case 0: + inptr7 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + outptr_interleave = outptr; + interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, + inptr7, outptr_interleave, 4, 4); + outptr += ksize4; + } + + if (x < xmax) { + if (k + 7 >= kmax) { + switch (k + 7 - kmax) { + case 6: + inptr1 = zerobuff; + case 5: + inptr2 = zerobuff; + case 4: + inptr3 = zerobuff; + case 3: + inptr4 = zerobuff; + case 2: + inptr5 = zerobuff; + case 1: + inptr6 = zerobuff; + case 0: + inptr7 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + outptr_interleave = outptr; + interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, + inptr7, outptr_interleave, 4, xmax - x); + } + + outptr_base += 8 * 8; + outptr_base4 += 4 * 8; + } +} + +static void gemm_s8_4x8_transpose_pack_B_n(dt_int8* outptr, + const dt_int8* inptr, int ldin, + int y0, int ymax, int k0, int kmax) { + int8_t zerobuff[16]; + std::memset(zerobuff, 0, sizeof(int8_t) * 16); + constexpr int interleave4 = 32; + constexpr int interleave8 = 64; + + int y = y0; + for (; y + 7 < ymax; y += 8) { + const int8_t* inptr0 = inptr + y * ldin + k0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + const int8_t* inptr4 = inptr3 + ldin; + const int8_t* inptr5 = inptr4 + ldin; + const int8_t* inptr6 = inptr5 + ldin; + const int8_t* inptr7 = inptr6 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + prefetch_2x(inptr4); + prefetch_2x(inptr5); + prefetch_2x(inptr6); + prefetch_2x(inptr7); + + int K = kmax - k0; + for (; K > 7; K -= 8) { + transpose_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, + inptr6, inptr7, outptr); + outptr += interleave8; + } + + if (K > 0) { + transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, + inptr7, outptr, 8, K); + outptr += interleave8; + } + } + + for (; y < ymax; y += 4) { + const int8_t* inptr0 = inptr + y * ldin + k0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + + int K = kmax - k0; + for (; K > 7; K -= 8) { + if (y + 3 >= ymax) { + switch (y + 3 - ymax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + transpose_8x4_1_b(inptr0, inptr1, inptr2, inptr3, outptr); + outptr += interleave4; + } + + if (K > 0) { + if (y + 3 >= ymax) { + switch (y + 3 - ymax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + transpose_4(inptr0, inptr1, inptr2, inptr3, outptr, 8, K); + outptr += interleave4; + } + } +} +} // namespace matmul_4x8x8 +} // namespace armv7 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/armv7/matrix_mul/int8/kernel_6x8x4.h 
b/dnn/src/armv7/matrix_mul/int8/kernel_6x8x4.h new file mode 100644 index 00000000..f76c5395 --- /dev/null +++ b/dnn/src/armv7/matrix_mul/int8/kernel_6x8x4.h @@ -0,0 +1,806 @@ +/** + * \file dnn/src/armv7/matrix_mul/int8/kernel_6x8x4.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#if __ARM_FEATURE_DOTPROD + +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/armv7/matrix_mul/asm/common.h" + +namespace megdnn { +namespace armv7 { +namespace matmul_dot_6x8x4 { + +// Overview of register layout: +// +// A 8x4 cell of Rhs is stored in 8bit in q2-q3. +// A 6x4 ping-pong cell of Lhs is stored in 8bit in q0-q1 +// A 6x8 block of accumulators is stored in 8bit in q4-q15 +// +// +--------+--------+ +// |q2[0-16]|q3[0-16]| +// Rhs +--------+--------+ +// +// | | | +// +// Lhs | | | +// +// +-------+-------+ - - - - +--------+--------+ +// |q0[0-4]| | q4[0-4]| q5[0-4]| +// |q0[0-4]| | q6[0-4]| q7[0-4]| +// |q0[0-4]| | q8[0-4]| q9[0-4]| +// |q0[0-4]| |q10[0-4]|q11[0-4]| +// |q1[0-4]| |q12[0-4]|q13[0-4]| +// |q1[0-4]| |q14[0-4]|q15[0-4]| +// +-------+-------+ - - - - +--------+--------+--------+ +// +// Accumulator + +static void kern_6x8(const int8_t* packA, const int8_t* packB, int K, + int32_t* output, int LDC, bool is_first_k, + size_t m_remain = 6) { + K /= 4; + const int8_t* a_ptr = packA; + const int8_t* b_ptr = packB; + // Fix up for odd lengths - set a flag if K is odd, but make + // sure we round up the iteration count. 
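+    // e.g. with K = 4 four-deep groups: oddk = 0 and k = 1, so the main loop
+    // runs once (two groups) and the even-K tail handles the final two groups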
+ int oddk = (K & 1); + int k = ((K + 1) / 2) - 1; + + register int32_t* outptr0 asm("r0") = output; + register int32_t* outptr1 asm("r1") = outptr0 + LDC; + register int32_t* outptr2 asm("r2") = outptr1 + LDC; + register int32_t* outptr3 asm("r3") = outptr2 + LDC; + register int32_t* outptr4 asm("r4") = outptr3 + LDC; + register int32_t* outptr5 asm("r5") = outptr4 + LDC; + +// clang-format off +#define LOAD_LINE(reg_index1, reg_index2, reg_index3, reg_index4, n) \ + "cmp r12, #0 \n" \ + "beq 100f\n" \ + "vld1.32 {d" reg_index1 ", d" reg_index2 ", d" reg_index3 ", d" \ + reg_index4 "}, [r" n "]!\n" \ + "subs r12, r12, #1\n" + +#define LOAD_C \ + "mov r12, %[m_remain]\n" \ + LOAD_LINE("8", "9", "10", "11", "0") \ + LOAD_LINE("12", "13", "14", "15", "1") \ + LOAD_LINE("16", "17", "18", "19", "2") \ + LOAD_LINE("20", "21", "22", "23", "3") \ + LOAD_LINE("24", "25", "26", "27", "4") \ + LOAD_LINE("28", "29", "30", "31", "5") \ + "100:\n" + +#define STORE_LINE(reg_index1, reg_index2, reg_index3, reg_index4, n) \ + "cmp r12, #0 \n" \ + "beq 101f\n" \ + "vst1.32 {d" reg_index1 ", d" reg_index2 ", d" reg_index3 ", d" \ + reg_index4 "}, [r" n "]!\n" \ + "subs r12, r12, #1\n" + +#define STORE_C \ + "mov r12, %[m_remain]\n" \ + STORE_LINE("8", "9", "10", "11", "0") \ + STORE_LINE("12", "13", "14", "15", "1") \ + STORE_LINE("16", "17", "18", "19", "2") \ + STORE_LINE("20", "21", "22", "23", "3") \ + STORE_LINE("24", "25", "26", "27", "4") \ + STORE_LINE("28", "29", "30", "31", "5") \ + "101:\n" + + // clang-format on + + asm volatile( + // load accumulator C + "pld [%[a_ptr]] \n" + "pld [%[b_ptr]] \n" + + "cmp %[is_first_k], #1 \n" + "beq 5f \n" + "cmp %[m_remain], #6 \n" + "beq 7f \n" LOAD_C + "b 6f \n" + + "7:\n" + "vld1.s32 {q4, q5}, [%[outptr0]]\n" + "vld1.s32 {q6, q7}, [%[outptr1]]\n" + "vld1.s32 {q8, q9}, [%[outptr2]]\n" + "vld1.s32 {q10, q11}, [%[outptr3]]\n" + "vld1.s32 {q12, q13}, [%[outptr4]]\n" + "vld1.s32 {q14, q15}, [%[outptr5]]\n" + "b 6f \n" + + "5:\n" + "veor.s32 q4, q4, q4\n" + "pld [%[outptr0]] \n" + "veor.s32 q5, q5, q5\n" + "veor.s32 q6, q6, q6\n" + "pld [%[outptr1]] \n" + "veor.s32 q7, q7, q7\n" + "veor.s32 q8, q8, q8\n" + "pld [%[outptr2]] \n" + "veor.s32 q9, q9, q9\n" + "veor.s32 q10, q10, q10\n" + "pld [%[outptr3]] \n" + "veor.s32 q11, q11, q11\n" + "veor.s32 q12, q12, q12\n" + "pld [%[outptr4]] \n" + "veor.s32 q13, q13, q13\n" + "veor.s32 q14, q14, q14\n" + "pld [%[outptr5]] \n" + "veor.s32 q15, q15, q15\n" + + "6: \n" + "vld1.s8 {d1}, [%[a_ptr]]!\n" + "vld1.s8 {q2}, [%[b_ptr]]!\n" + + // Skip loop if we are doing zero iterations of it. 
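+            // (k is zero when only one or two 4-deep groups remain; those are
+            //  handled entirely by the detached tails below)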
+ "cmp %[k], #0 \n" + "beq 4f \n" + + // Loop proper + "1:\n" + "vld1.s8 {q1}, [%[a_ptr]]!\n" + "vsdot.s8 q4 , q2, d1[0]\n" + "vsdot.s8 q6 , q2, d1[1]\n" + "vld1.s8 {q3}, [%[b_ptr]]!\n" + "vsdot.s8 q8 , q2, d2[0]\n" + "vsdot.s8 q10 , q2, d2[1]\n" + "vsdot.s8 q12 , q2, d3[0]\n" + "vsdot.s8 q14 , q2, d3[1]\n" + + "vsdot.s8 q5, q3, d1[0]\n" + "vsdot.s8 q7, q3, d1[1]\n" + "vld1.s8 {q2}, [%[b_ptr]]!\n" + "vsdot.s8 q9, q3, d2[0]\n" + "vsdot.s8 q11, q3, d2[1]\n" + "vld1.s8 {d1}, [%[a_ptr]]!\n" + "vsdot.s8 q13, q3, d3[0]\n" + "vsdot.s8 q15, q3, d3[1]\n" + /////////////////////////////////////// + "vld1.s8 {q1}, [%[a_ptr]]!\n" + "vsdot.s8 q4 , q2, d1[0]\n" + "vsdot.s8 q6 , q2, d1[1]\n" + "vld1.s8 {q3}, [%[b_ptr]]!\n" + "vsdot.s8 q8 , q2, d2[0]\n" + "vsdot.s8 q10 , q2, d2[1]\n" + "vsdot.s8 q12 , q2, d3[0]\n" + "vsdot.s8 q14 , q2, d3[1]\n" + + "vsdot.s8 q5, q3, d1[0]\n" + "vsdot.s8 q7, q3, d1[1]\n" + "vld1.s8 {q2}, [%[b_ptr]]!\n" + "vsdot.s8 q9, q3, d2[0]\n" + "pld [%[b_ptr]] \n" + "vsdot.s8 q11, q3, d2[1]\n" + "vld1.s8 {d1}, [%[a_ptr]]!\n" + "subs %[k], %[k], #1\n" + "vsdot.s8 q13, q3, d3[0]\n" + "pld [%[a_ptr]] \n" + "vsdot.s8 q15, q3, d3[1]\n" + + "bne 1b\n" + + // Target to use when K is 1 or 2 (i.e. zero iterations of main + // loop) + "4:\n" + + // Branch to alternative tail for odd K + "cmp %[oddk], #0 \n" + "bne 2f \n" + + // Detached final iteration (even K) + "vld1.s8 {q1}, [%[a_ptr]]!\n" + "vsdot.s8 q4 , q2, d1[0]\n" + "vsdot.s8 q6 , q2, d1[1]\n" + "vsdot.s8 q8 , q2, d2[0]\n" + "vld1.s8 {q3}, [%[b_ptr]]!\n" + "vsdot.s8 q10 , q2, d2[1]\n" + "vsdot.s8 q12 , q2, d3[0]\n" + "vsdot.s8 q14 , q2, d3[1]\n" + + "vsdot.s8 q5, q3, d1[0]\n" + "vsdot.s8 q7, q3, d1[1]\n" + "vld1.s8 {q2}, [%[b_ptr]]!\n" + "vsdot.s8 q9, q3, d2[0]\n" + "vsdot.s8 q11, q3, d2[1]\n" + "vld1.s8 {d1}, [%[a_ptr]]!\n" + "vsdot.s8 q13, q3, d3[0]\n" + "vsdot.s8 q15, q3, d3[1]\n" + /////////////////////////////////////// + + "2:\n" + "vld1.s8 {q1}, [%[a_ptr]]!\n" + "vsdot.s8 q4 , q2, d1[0]\n" + "vsdot.s8 q6 , q2, d1[1]\n" + "vsdot.s8 q8 , q2, d2[0]\n" + "vld1.s8 {q3}, [%[b_ptr]]!\n" + "vsdot.s8 q10 , q2, d2[1]\n" + "vsdot.s8 q12 , q2, d3[0]\n" + "vsdot.s8 q14 , q2, d3[1]\n" + + "vsdot.s8 q5, q3, d1[0]\n" + "vsdot.s8 q7, q3, d1[1]\n" + "vsdot.s8 q9, q3, d2[0]\n" + "vsdot.s8 q11, q3, d2[1]\n" + "vsdot.s8 q13, q3, d3[0]\n" + "vsdot.s8 q15, q3, d3[1]\n" STORE_C + + : [k] "+r"(k), [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), + [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), + [m_remain] "+r"(m_remain), [outptr0] "+r"(outptr0), + [outptr1] "+r"(outptr1), [outptr2] "+r"(outptr2), + [outptr3] "+r"(outptr3), [outptr4] "+r"(outptr4), + [outptr5] "+r"(outptr5) + : + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", + "q11", "q12", "q13", "q14", "q15", "r12", "cc", "memory"); +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +} +// Overview of register layout: +// +// A 8x4 cell of Rhs is stored in 8bit in q2-q3. 
+// A 6x4 cell of Lhs is stored in 8bit in q0-q1 +// A 6x8 block of accumulators is stored in 8bit in q4-q15 +// +// +--------+ +// |q2[0-16]| +// |q3[0-16]| +// Rhs +--------+ +// +// | | +// +// Lhs | | +// +// +-------+-------+ - - - - +--------+ +// |q0[0-4]| | q4[0-4]| +// |q0[0-4]| | q6[0-4]| +// |q1[0-4]| | q8[0-4]| +// |q1[0-4]| |q10[0-4]| +// |q1[0-4]| |q12[0-4]| +// |q1[0-4]| |q14[0-4]| +// +-------+-------+ - - - - +--------+ +// +// Accumulator + +static void kern_6x4(const int8_t* packA, const int8_t* packB, int K, + int32_t* output, int LDC, bool is_first_k, + size_t n_remain = 8, size_t m_remain = 6) { + K /= 4; + const int8_t* a_ptr = packA; + const int8_t* b_ptr = packB; + // Fix up for odd lengths - set a flag if K is odd, but make + // sure we round up the iteration count. + +// clang-format off +#define LOAD_LINE(reg_index1, reg_index2, n) \ + "cmp r12, #0 \n" \ + "beq 102f\n" \ + "cmp %[n_remain], #4\n" \ + "blt 100" n "f\n" \ + "vld1.32 {d" reg_index1 ", d" reg_index2 "}, [r" n " ]!\n" \ + "b 101" n "f\n" \ + "100" n ":\n" \ + "cmp %[n_remain], #0\n" \ + "beq 101" n "f\n" \ + "vld1.32 {d" reg_index1 "[0]}, [r" n " ]!\n" \ + "cmp %[n_remain], #1\n" \ + "beq 101" n "f\n" \ + "vld1.32 {d" reg_index1 "[1]}, [r" n " ]!\n" \ + "cmp %[n_remain], #2\n" \ + "beq 101" n "f\n" \ + "vld1.32 {d" reg_index2 "[0]}, [r" n " ]!\n" \ + "101" n ":\n" \ + "subs r12, r12, #1\n" + +#define LOAD_C \ + "mov r12, %[m_remain]\n" \ + LOAD_LINE("8", "9", "0") \ + LOAD_LINE("12", "13", "1") \ + LOAD_LINE("16", "17", "2") \ + LOAD_LINE("20", "21", "3") \ + LOAD_LINE("24", "25", "4") \ + LOAD_LINE("28", "29", "5") \ + "102:\n" + +#define STORE_LINE(reg_index1, reg_index2, n) \ + "cmp r12, #0 \n" \ + "beq 105f\n" \ + "cmp %[n_remain], #4\n" \ + "blt 103" n "f\n" \ + "vst1.32 {d" reg_index1 ", d" reg_index2 "}, [r" n " ]!\n" \ + "b 104" n "f\n" \ + "103" n ":\n" \ + "cmp %[n_remain], #0\n" \ + "beq 104" n "f\n" \ + "vst1.32 {d" reg_index1 "[0]}, [r" n " ]!\n" \ + "cmp %[n_remain], #1\n" \ + "beq 104" n "f\n" \ + "vst1.32 {d" reg_index1 "[1]}, [r" n " ]!\n" \ + "cmp %[n_remain], #2\n" \ + "beq 104" n "f\n" \ + "vst1.32 {d" reg_index2 "[0]}, [r" n " ]!\n" \ + "104" n ":\n" \ + "subs r12, r12, #1\n" + +#define STORE_C \ + "mov r12, %[m_remain]\n" \ + STORE_LINE("8", "9", "0") \ + STORE_LINE("12", "13", "1") \ + STORE_LINE("16", "17", "2") \ + STORE_LINE("20", "21", "3") \ + STORE_LINE("24", "25", "4") \ + STORE_LINE("28", "29", "5") \ + "105:\n" + + // clang-format on + + register int32_t* outptr0 asm("r0") = output; + register int32_t* outptr1 asm("r1") = outptr0 + LDC; + register int32_t* outptr2 asm("r2") = outptr1 + LDC; + register int32_t* outptr3 asm("r3") = outptr2 + LDC; + register int32_t* outptr4 asm("r4") = outptr3 + LDC; + register int32_t* outptr5 asm("r5") = outptr4 + LDC; + + asm volatile( + "pld [%[a_ptr]] \n" + "pld [%[b_ptr]] \n" + + "cmp %[is_first_k], #1 \n" + "beq 5f \n" + "cmp %[m_remain], #6 \n" + "beq 7f \n" LOAD_C + "b 6f \n" + + "7:\n" + "vld1.s32 {q4}, [%[outptr0]]\n" + "vld1.s32 {q6}, [%[outptr1]]\n" + "vld1.s32 {q8}, [%[outptr2]]\n" + "vld1.s32 {q10}, [%[outptr3]]\n" + "vld1.s32 {q12}, [%[outptr4]]\n" + "vld1.s32 {q14}, [%[outptr5]]\n" + "b 6f \n" + + "5:\n" + "veor.s32 q4, q4, q4\n" + "pld [%[outptr0]] \n" + "veor.s32 q6, q6, q6\n" + "pld [%[outptr1]] \n" + "veor.s32 q8, q8, q8\n" + "pld [%[outptr2]] \n" + "veor.s32 q10, q10, q10\n" + "pld [%[outptr3]] \n" + "veor.s32 q12, q12, q12\n" + "pld [%[outptr4]] \n" + "veor.s32 q14, q14, q14\n" + "pld [%[outptr5]] \n" + + "6:\n" + 
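// preload the first packed A rows (d1) and B block (q2) before the K loop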
"vld1.s8 {d1}, [%[a_ptr]]!\n" + "vld1.s8 {q2}, [%[b_ptr]]!\n" + + // Skip loop if we are doing zero iterations of it. + "cmp %[k], #2\n" + "bgt 1f\n" + "beq 4f\n" + "blt 2f\n" + + // Loop proper + "1:\n" + "vld1.s8 {q1}, [%[a_ptr]]!\n" + "vsdot.s8 q4 , q2, d1[0]\n" + "vsdot.s8 q6 , q2, d1[1]\n" + "vsdot.s8 q8 , q2, d2[0]\n" + "vld1.s8 {q3}, [%[b_ptr]]!\n" + "vsdot.s8 q10 , q2, d2[1]\n" + "vld1.s8 {d1}, [%[a_ptr]]!\n" + "vsdot.s8 q12 , q2, d3[0]\n" + "vsdot.s8 q14 , q2, d3[1]\n" + + /////////////////////////////////////// + "vld1.s8 {q1}, [%[a_ptr]]!\n" + "vsdot.s8 q4 , q3, d1[0]\n" + "vsdot.s8 q6 , q3, d1[1]\n" + "vsdot.s8 q8 , q3, d2[0]\n" + "vld1.s8 {q2}, [%[b_ptr]]!\n" + "vsdot.s8 q10 , q3, d2[1]\n" + "vld1.s8 {d1}, [%[a_ptr]]!\n" + "vsdot.s8 q12 , q3, d3[0]\n" + "vsdot.s8 q14 , q3, d3[1]\n" + + "sub %[k], %[k], #2\n" + "cmp %[k], #2\n" + "bgt 1b\n" + "blt 2f\n" + + // Target to use when left K is 2 + "4:\n" + + "vld1.s8 {q1}, [%[a_ptr]]!\n" + "vsdot.s8 q4 , q2, d1[0]\n" + "vsdot.s8 q6 , q2, d1[1]\n" + "vsdot.s8 q8 , q2, d2[0]\n" + "vld1.s8 {q3}, [%[b_ptr]]!\n" + "vsdot.s8 q10 , q2, d2[1]\n" + "vld1.s8 {d1}, [%[a_ptr]]!\n" + "vsdot.s8 q12 , q2, d3[0]\n" + "vsdot.s8 q14 , q2, d3[1]\n" + + /////////////////////////////////////// + "vld1.s8 {q1}, [%[a_ptr]]!\n" + "vsdot.s8 q4 , q3, d1[0]\n" + "vsdot.s8 q6 , q3, d1[1]\n" + "vsdot.s8 q8 , q3, d2[0]\n" + "vsdot.s8 q10 , q3, d2[1]\n" + "vsdot.s8 q12 , q3, d3[0]\n" + "vsdot.s8 q14 , q3, d3[1]\n" + "b 3f\n" + + // tail for left K is 1 + + "2:\n" + "vld1.s8 {q1}, [%[a_ptr]]!\n" + "vsdot.s8 q4 , q2, d1[0]\n" + "vsdot.s8 q6 , q2, d1[1]\n" + "vsdot.s8 q8 , q2, d2[0]\n" + "vsdot.s8 q10 , q2, d2[1]\n" + "vsdot.s8 q12 , q2, d3[0]\n" + "vsdot.s8 q14 , q2, d3[1]\n" + + "3:\n" + + STORE_C + + : [k] "+r"(K), [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), + [is_first_k] "+r"(is_first_k), [outptr0] "+r"(outptr0), + [outptr1] "+r"(outptr1), [outptr2] "+r"(outptr2), + [outptr3] "+r"(outptr3), [outptr4] "+r"(outptr4), + [outptr5] "+r"(outptr5), [m_remain] "+r"(m_remain), + [n_remain] "+r"(n_remain) + : + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", + "q11", "q12", "q13", "q14", "q15", "cc", "r12", "memory"); + +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +} + +static void gemm_s8_6x8_pack_A_n(dt_int8* outptr, const dt_int8* inptr, + int ldin, int y0, int ymax, int k0, int kmax) { + int8_t zerobuff[16]; + std::memset(zerobuff, 0, sizeof(int8_t) * 16); + + int y = y0; + for (; y < ymax; y += 6) { + const int8_t* inptr0 = inptr + y * ldin + k0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + const int8_t* inptr4 = inptr3 + ldin; + const int8_t* inptr5 = inptr4 + ldin; + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + prefetch_2x(inptr4); + prefetch_2x(inptr5); + int K = kmax - k0; + for (; K > 31; K -= 32) { + if (y + 5 >= ymax) { + switch (y + 5 - ymax) { + case 4: + inptr1 = zerobuff; + case 3: + inptr2 = zerobuff; + case 2: + inptr3 = zerobuff; + case 1: + inptr4 = zerobuff; + case 0: + inptr5 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + interleave_6x4_8_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, + outptr); + } + for (; K > 15; K -= 16) { + if (y + 5 >= ymax) { + switch (y + 5 - ymax) { + case 4: + inptr1 = zerobuff; + case 3: + inptr2 = zerobuff; + case 2: + inptr3 = zerobuff; + case 1: + inptr4 = zerobuff; + case 0: + inptr5 = zerobuff; + break; + default: + megdnn_assert(0); + 
} + } + interleave_6x4_4_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, + outptr); + } + if (K > 0) { + if (y + 5 >= ymax) { + switch (y + 5 - ymax) { + case 4: + inptr1 = zerobuff; + case 3: + inptr2 = zerobuff; + case 2: + inptr3 = zerobuff; + case 1: + inptr4 = zerobuff; + case 0: + inptr5 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + interleave_6(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, outptr, + 4, K); + } + } +} + +static void gemm_s8_6x8_pack_A_t(dt_int8* out, const dt_int8* in, int ldin, + int x0, int xmax, int k0, int kmax) { + int8_t zerobuff[16]; + std::memset(zerobuff, 0, sizeof(int8_t) * 16); + const int ksize = kmax - k0; + const int ksize6 = round_up(ksize, 4) * 6; + int8_t* outptr = out; + int8_t* outptr_base = out; + + int k = k0; + for (; k < kmax; k += 4) { + const int8_t* inptr0 = in + k * ldin + x0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + + int x = x0; + outptr = outptr_base; + for (; x + 5 < xmax; x += 6) { + if (k + 3 >= kmax) { + switch (k + 3 - kmax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + transpose_6x4_1_b(inptr0, inptr1, inptr2, inptr3, outptr); + outptr += ksize6; + } + + if (x < xmax) { + if (k + 3 >= kmax) { + switch (k + 3 - kmax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + transpose_4(inptr0, inptr1, inptr2, inptr3, outptr, 6, xmax - x); + } + outptr_base += 6 * 4; + } +} + +static void gemm_s8_6x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, + int x0, int xmax, int k0, int kmax) { + int8_t zerobuff[16]; + std::memset(zerobuff, 0, sizeof(int8_t) * 16); + const int ksize = kmax - k0; + const int ksize8 = round_up(ksize, 4) * 8; + const int ksize4 = round_up(ksize, 4) * 4; + int8_t* outptr = out; + int8_t* outptr_base = out; + //! 
4x4 block output start pos + int8_t* outptr_base4 = out + ((xmax - x0) / 8) * ksize8; + + int k = k0; + for (; k < kmax; k += 4) { + const int8_t* inptr0 = in + k * ldin + x0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + + int x = x0; + outptr = outptr_base; + for (; x + 7 < xmax; x += 8) { + if (k + 3 >= kmax) { + switch (k + 3 - kmax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + transpose_8x4_1_b(inptr0, inptr1, inptr2, inptr3, outptr); + outptr += ksize8; + } + + outptr = outptr_base4; + for (; x + 3 < xmax; x += 4) { + if (k + 3 >= kmax) { + switch (k + 3 - kmax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + transpose_4(inptr0, inptr1, inptr2, inptr3, outptr, 4, 4); + outptr += ksize4; + } + + if (x < xmax) { + if (k + 3 >= kmax) { + switch (k + 3 - kmax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + transpose_4(inptr0, inptr1, inptr2, inptr3, outptr, 4, xmax - x); + } + + outptr_base += 8 * 4; + outptr_base4 += 4 * 4; + } +} + +static void gemm_s8_6x8_pack_B_t(dt_int8* outptr, const dt_int8* inptr, + int ldin, int y0, int ymax, int k0, int kmax) { + int8_t zerobuff[16]; + std::memset(zerobuff, 0, sizeof(int8_t) * 16); + + int y = y0; + for (; y + 7 < ymax; y += 8) { + const int8_t* inptr0 = inptr + y * ldin + k0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + const int8_t* inptr4 = inptr3 + ldin; + const int8_t* inptr5 = inptr4 + ldin; + const int8_t* inptr6 = inptr5 + ldin; + const int8_t* inptr7 = inptr6 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + prefetch_2x(inptr4); + prefetch_2x(inptr5); + prefetch_2x(inptr6); + prefetch_2x(inptr7); + + int K = kmax - k0; + //! read 12 * 4 in each row + for (; K > 15; K -= 16) { + interleave_8x4_4_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, + inptr6, inptr7, outptr); + } + if (K > 0) { + interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, + inptr7, outptr, 4, K); + } + } + for (; y < ymax; y += 4) { + const int8_t* inptr0 = inptr + y * ldin + k0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + + int K = kmax - k0; + //! 
read 4 * 4 in each row + for (; K > 15; K -= 16) { + if (y + 3 >= ymax) { + switch (y + 3 - ymax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + interleave_4x4_4_b(inptr0, inptr1, inptr2, inptr3, outptr); + } + + if (K > 0) { + if (y + 3 >= ymax) { + switch (y + 3 - ymax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 4, K); + } + } +} + +} // namespace matmul_dot_6x8x4 +} // namespace armv7 +} // namespace megdnn +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/armv7/matrix_mul/int8/kernel_mk4_4x2x16.h b/dnn/src/armv7/matrix_mul/int8/kernel_mk4_4x2x16.h new file mode 100644 index 00000000..f49aeebc --- /dev/null +++ b/dnn/src/armv7/matrix_mul/int8/kernel_mk4_4x2x16.h @@ -0,0 +1,365 @@ +/** + * \file dnn/src/armv7/matrix_mul/int8/kernel_mk4_4x2x16.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/armv7/matrix_mul/asm/common.h" + +namespace megdnn { +namespace armv7 { +namespace matmul_mk4_4x2x16 { + +/** + * Overview of register layout. + * + * A 2x16 block of Rhs is stored in 8 bit in d0--d3. + * A 4x16 block of Lhs is stored in 8 bit in d4--d7. That is only + * half of the register space required, so we loop over these registers + * twice. Only half of it, a 2x16 block, is stored in d4--d7 at + * any given time. + * + * A 4x2 block of accumulators is stored in q8--q15 (as 4x32 bit + * components which need to be horizontally-added at the end) + * + * Only then, having processed 16 levels of depth, do we need to + * horizontally add these int16x8 accumulators into the final + * int32x4 accumulators. + * + * As we do not have enough registers to store all 16 int16x8 + * temporary-16bit-accumulators, we have them cycle through q4--q7. + * + * \warning Fast kernel operating on int8 operands. + * It is assumed that one of the two int8 operands only takes values + * in [-127, 127], while the other may freely range in [-128, 127]. + * The issue with both operands taking the value -128 is that: + * -128*-128 + -128*-128 == -32768 overflows int16. + * Every other expression a*b + c*d, for any int8 a,b,c,d, fits in int16 + * range. That is the basic idea of this kernel. 
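+ * For instance, 127*127 + 127*127 = 32258 and (-128)*127 + (-128)*127 = -32512
+ * both fit in int16, while (-128)*(-128) + (-128)*(-128) = 32768 does not.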
+ * + * + * +--------+--------+ + * |d0[0-8] |d2[0-8] | + * Rhs +--------+--------+ + * |d1[0-8] |d3[0-8] | + * +--------+--------+ + * | | | + * + * Lhs | | | + * + * +--------+--------+ - - - - +------------------ + * |d4[0-8] |d5[0-8] | |q8[0-4] |q9[0-4] | + * |d6[0-8] |d7[0-8] | |q10[0-4]|q11[0-4]| + * |d4[0-8] |d5[0-8] | |q12[0-4]|q13[0-4]| + * |d6[0-8] |d7[0-8] | |q14[0-4]|q15[0-4]| + * +--------+--------+ - - - - +------------------ + * + * Accumulator + */ + +static void kern_4x2(const int8_t* packA, const int8_t* packB, int K, + int32_t* output, bool is_first_k, int n_remain) { + MEGDNN_MARK_USED_VAR(n_remain); + K /= 16; + const int8_t* a_ptr = packA; + const int8_t* b_ptr = packB; + asm volatile( + "vldr d0, [%[b_ptr], #0]\n" + "vmov.i32 q8, #0\n" + "vldr d4, [%[a_ptr], #0]\n" + "vmov.i32 q9, #0\n" + "vldr d1, [%[b_ptr], #8]\n" + "vmov.i32 q10, q8\n" + "vldr d5, [%[a_ptr], #8]\n" + "vmov.i32 q11, q8\n" + "vldr d2, [%[b_ptr], #16]\n" + "vmov.i32 q12, q8\n" + "vldr d6, [%[a_ptr], #16]\n" + "vmov.i32 q13, q8\n" + "vldr d3, [%[b_ptr], #24]\n" + "vmov.i32 q14, q8\n" + "vldr d7, [%[a_ptr], #24]\n" + "vmov.i32 q15, q8\n" + + // General loop. + "1:\n" + "vmull.s8 q4, d0, d4\n" + "add %[b_ptr], %[b_ptr], #32\n" + "vmull.s8 q5, d2, d4\n" + "vldr d4, [%[a_ptr], #32]\n" + "vmull.s8 q6, d0, d6\n" + "vmull.s8 q7, d2, d6\n" + "vldr d6, [%[a_ptr], #48]\n" + + "vmlal.s8 q4, d1, d5\n" + "vmlal.s8 q5, d3, d5\n" + "vldr d5, [%[a_ptr], #40]\n" + "vmlal.s8 q6, d1, d7\n" + "vmlal.s8 q7, d3, d7\n" + "vldr d7, [%[a_ptr], #56]\n" + + "vpadal.s16 q8, q4\n" + "add %[a_ptr], %[a_ptr], #64\n" + "vpadal.s16 q9, q5\n" + "subs %[K], %[K], #1\n" + "vpadal.s16 q10, q6\n" + "vpadal.s16 q11, q7\n" + + "beq 2f\n" + + "vmull.s8 q4, d0, d4\n" + "vmull.s8 q5, d2, d4\n" + "vldr d4, [%[a_ptr], #0]\n" + "vmull.s8 q6, d0, d6\n" + "vldr d0, [%[b_ptr], #0]\n" + "vmull.s8 q7, d2, d6\n" + "vldr d2, [%[b_ptr], #16]\n" + + "vmlal.s8 q4, d1, d5\n" + "vldr d6, [%[a_ptr], #16]\n" + "vmlal.s8 q5, d3, d5\n" + "vldr d5, [%[a_ptr], #8]\n" + "vmlal.s8 q6, d1, d7\n" + "vldr d1, [%[b_ptr], #8]\n" + "vmlal.s8 q7, d3, d7\n" + "vldr d3, [%[b_ptr], #24]\n" + + // Add pairwise, accumulate into 32-bit accumulators. + "vpadal.s16 q12, q4\n" + "vldr d7, [%[a_ptr], #24]\n" + "vpadal.s16 q13, q5\n" + "vpadal.s16 q14, q6\n" + "vpadal.s16 q15, q7\n" + + "b 1b\n" + + "2:\n" + // Multiply first half. + "vmull.s8 q4, d0, d4\n" + "vmull.s8 q5, d2, d4\n" + "vmull.s8 q6, d0, d6\n" + "vmull.s8 q7, d2, d6\n" + + "vmlal.s8 q4, d1, d5\n" + "vmlal.s8 q5, d3, d5\n" + "vmlal.s8 q6, d1, d7\n" + "vmlal.s8 q7, d3, d7\n" + + "vpadal.s16 q12, q4\n" + "vpadal.s16 q13, q5\n" + "vpadal.s16 q14, q6\n" + "vpadal.s16 q15, q7\n" + "cmp %[is_first_k], #1\n" + + // Reduce 32bit accumulators horizontally. 
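+            // (each of q8-q15 holds four partial sums for one output element; the
+            //  two vpadd passes below collapse them into a single int32 apiece)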
+ "vpadd.s32 d0, d16, d17\n" + "vpadd.s32 d1, d18, d19\n" + "vpadd.s32 d2, d20, d21\n" + "vpadd.s32 d3, d22, d23\n" + "vpadd.s32 d4, d24, d25\n" + "vpadd.s32 d5, d26, d27\n" + "vpadd.s32 d6, d28, d29\n" + "vpadd.s32 d7, d30, d31\n" + + "bne 3f\n" + "vpadd.s32 d8, d0, d2\n" + "vpadd.s32 d9, d4, d6\n" + "vpadd.s32 d10, d1, d3\n" + "vpadd.s32 d11, d5, d7\n" + "b 5f\n" + + "3:\n" + "cmp %[n_remain], #1\n" + "beq 4f\n" + "vldr d18, [%[outptr], #16]\n" + "vldr d19, [%[outptr], #24]\n" + "4:\n" + "vldr d16, [%[outptr]]\n" + "vldr d17, [%[outptr], #8]\n" + + "vpadd.s32 d8, d0, d2\n" + "vpadd.s32 d9, d4, d6\n" + "vpadd.s32 d10, d1, d3\n" + "vpadd.s32 d11, d5, d7\n" + + "vadd.s32 q4, q8, q4\n" + "vadd.s32 q5, q9, q5\n" + + "5:\n" + "cmp %[n_remain], #1\n" + "beq 6f\n" + "vstr d10, [%[outptr], #16]\n" + "vstr d11, [%[outptr], #24]\n" + "6:\n" + "vstr d8, [%[outptr]]\n" + "vstr d9, [%[outptr], #8]\n" + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), + [is_first_k] "+r"(is_first_k), [K] "+r"(K), [outptr] "+r"(output), + [n_remain] "+r"(n_remain) + : + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); +} + +static void gemm_mk4_s8_4x2_pack_A(dt_int8* outptr, const dt_int8* inptr, + int ldin, int y0, int ymax, int k0, + int kmax) { + //! pack form {oc/4, ic/4, 4(ic), 4(oc)} to {oc/4, ic/16, 4(oc), 16(ic)} + int8_t zerobuff[4][64]; + std::memset(zerobuff, 0, sizeof(int8_t) * 64 * 4); + megdnn_assert(ymax % 4 == 0 && y0 % 4 == 0 && (ymax - y0) % 4 == 0, + "mk4 matmul with m is not times of 4"); + megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0 && (kmax - k0) % 4 == 0, + "mk4 matmul with k is not times of 4"); + size_t roundk = round_up(kmax - k0, 16); + size_t out_offset = roundk * 4; + int y = y0; + int start_y = y0 / 4; + for (; y + 15 < ymax; y += 16, start_y += 4) { + const int8_t* inptr0 = inptr + start_y * ldin + k0 * 4; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + int8_t* output = outptr + start_y * out_offset; + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + int K = kmax - k0; + for (; K > 15; K -= 16) { + transpose_interleave_4x4_4_b(inptr0, inptr1, inptr2, inptr3, output, + out_offset); + output += 64; + } + if (K > 0) { + std::memcpy(zerobuff[0], inptr0, sizeof(int8_t) * K * 4); + std::memcpy(zerobuff[1], inptr1, sizeof(int8_t) * K * 4); + std::memcpy(zerobuff[2], inptr2, sizeof(int8_t) * K * 4); + std::memcpy(zerobuff[3], inptr3, sizeof(int8_t) * K * 4); + inptr0 = zerobuff[0]; + inptr1 = zerobuff[1]; + inptr2 = zerobuff[2]; + inptr3 = zerobuff[3]; + transpose_interleave_4x4_4_b(inptr0, inptr1, inptr2, inptr3, output, + out_offset); + output += 64; + } + } + for (; y + 3 < ymax; y += 4, start_y++) { + const int8_t* inptr0 = inptr + start_y * ldin + k0 * 4; + int8_t* output = outptr + start_y * out_offset; + prefetch_2x(inptr0); + int K = kmax - k0; + for (; K > 15; K -= 16) { + transpose_interleave_1x4_4_b(inptr0, output); + output += 64; + } + if (K > 0) { + std::memcpy(zerobuff[0], inptr0, sizeof(int8_t) * K * 4); + inptr0 = zerobuff[0]; + transpose_interleave_1x4_4_b(inptr0, output); + output += 64; + } + } +} + +static void gemm_mk4_s8_4x2_pack_B(dt_int8* out, const dt_int8* in, int ldin, + int x0, int xmax, int k0, int kmax) { + int32_t zerobuff[4]; + std::memset(zerobuff, 0, sizeof(int8_t) * 16); + const int ksize = kmax - k0; + const int ICB = (ksize) / 4; + const int ksize2 = round_up(ICB, 4) * 2; + 
int32_t* outptr = reinterpret_cast(out); + megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0 && ksize % 4 == 0, + "mk4 matmul with k is not times of 4"); + + int k = k0 / 4; + for (; k + 3 < ICB; k += 4) { + const int32_t* inptr0 = + reinterpret_cast(in + k * ldin + x0); + const int32_t* inptr1 = + reinterpret_cast(in + (k + 1) * ldin + x0); + const int32_t* inptr2 = + reinterpret_cast(in + (k + 2) * ldin + x0); + const int32_t* inptr3 = + reinterpret_cast(in + (k + 3) * ldin + x0); + int32_t* outptr_inner = outptr; + + int x = x0; + for (; x + 1 < xmax; x += 2) { + transpose_4x2_1_s(inptr0, inptr1, inptr2, inptr3, outptr_inner); + outptr_inner += ksize2; + } + if (x < xmax) { + *outptr_inner++ = *inptr0++; + *outptr_inner++ = *inptr1++; + *outptr_inner++ = *inptr2++; + *outptr_inner++ = *inptr3++; + } + outptr += 4 * 2; + } + if (k < ICB) { + const int32_t* inptr0 = + reinterpret_cast(in + k * ldin + x0); + const int32_t* inptr1 = + reinterpret_cast(in + (k + 1) * ldin + x0); + const int32_t* inptr2 = + reinterpret_cast(in + (k + 2) * ldin + x0); + const int32_t* inptr3 = + reinterpret_cast(in + (k + 3) * ldin + x0); + int32_t* outptr_inner = outptr; + + int x = x0; + for (; x + 1 < xmax; x += 2) { + if (k + 3 >= ICB) { + switch (k + 3 - ICB) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + transpose_4x2_1_s(inptr0, inptr1, inptr2, inptr3, outptr_inner); + outptr_inner += ksize2; + } + if (x < xmax) { + if (k + 3 >= ICB) { + switch (k + 3 - ICB) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + *outptr_inner++ = *inptr0; + *outptr_inner++ = *inptr1; + *outptr_inner++ = *inptr2; + *outptr_inner++ = *inptr3; + } + outptr += 4 * 2; + } +} + +} // namespace matmul_mk4_4x2x16 +} // namespace armv7 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/armv7/matrix_mul/int8/strategy.cpp b/dnn/src/armv7/matrix_mul/int8/strategy.cpp new file mode 100644 index 00000000..275002a0 --- /dev/null +++ b/dnn/src/armv7/matrix_mul/int8/strategy.cpp @@ -0,0 +1,311 @@ +/** + * \file dnn/src/armv7/matrix_mul/int8/strategy.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#include "src/armv7/matrix_mul/int8/strategy.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/armv7/matrix_mul/asm/common.h" +#include "src/armv7/matrix_mul/int8/kernel_4x2x16.h" +#include "src/armv7/matrix_mul/int8/kernel_4x8x8.h" +#include "src/armv7/matrix_mul/int8/kernel_6x8x4.h" +#include "src/armv7/matrix_mul/int8/kernel_mk4_4x2x16.h" +#include "src/common/utils.h" +#include "src/fallback/matrix_mul/gemm_common.h" + +using namespace megdnn; +using namespace armv7; +using namespace armv7::matmul; + +MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_4x2); + +// ===========================gemm_s8_4x2====================================== + +void gemm_s8_4x2::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0, + int ymax, int k0, int kmax, bool transpose) const { + if (transpose) { + matmul_4x2x16::gemm_s8_4x2_pack_A_t(out, in, ldin, y0, ymax, k0, kmax); + } else { + matmul_4x2x16::gemm_s8_4x2_pack_A_n(out, in, ldin, y0, ymax, k0, kmax); + } +} + +void gemm_s8_4x2::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, + int xmax, int k0, int kmax, bool transpose) const { + if (transpose) { + matmul_4x2x16::gemm_s8_4x2_pack_B_t(out, in, ldin, x0, xmax, k0, kmax); + } else { + matmul_4x2x16::gemm_s8_4x2_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); + } +} + +void gemm_s8_4x2::kern(const dt_int8* packA, const dt_int8* packB, size_t M, + size_t N, size_t K, dt_int32* C, size_t LDC, + bool is_first_k, const dt_int32*, dt_int32*) const { + megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && + ((A_dtype.enumv() == DTypeEnum::Int8 && + C_dtype.enumv() == DTypeEnum::Int32) || + (A_dtype.enumv() == DTypeEnum::QuantizedS8 && + C_dtype.enumv() == DTypeEnum::QuantizedS32)), + "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), + C_dtype.name()); + + MEGDNN_MARK_USED_VAR(A_dtype); + MEGDNN_MARK_USED_VAR(B_dtype); + MEGDNN_MARK_USED_VAR(C_dtype); + + constexpr size_t A_INTERLEAVE = 4; + constexpr size_t B_INTERLEAVE = 2; + //! 
K is packed to times of 4 + K = round_up(K, 16); + const int K4 = K * 4; + const int K2 = K * 2; + + size_t m = 0; + for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) { + int32_t* output = C + (m * LDC); + + size_t n = 0; + const dt_int8* cur_packB = packB; + for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { + matmul_4x2x16::kern_4x2(packA, cur_packB, K, output, LDC, + is_first_k, 4, 2); + output += B_INTERLEAVE; + cur_packB += K2; + } + + for (; n < N; n += B_INTERLEAVE) { + matmul_4x2x16::kern_4x2(packA, cur_packB, K, output, LDC, + is_first_k, 4, std::min(N - n, 2)); + output += B_INTERLEAVE; + cur_packB += K2; + } + + packA += K4; + } + + for (; m < M; m += 4) { + int32_t* output = C + (m * LDC); + + size_t n = 0; + const dt_int8* cur_packB = packB; + for (; n < N; n += B_INTERLEAVE) { + matmul_4x2x16::kern_4x2(packA, cur_packB, K, output, LDC, + is_first_k, std::min(M - m, 4), + std::min(N - n, 2)); + output += B_INTERLEAVE; + cur_packB += K2; + } + packA += K4; + } +} + +// ===========================gemm_s8_4x4====================================== +MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_4x8); + +void gemm_s8_4x8::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0, + int ymax, int k0, int kmax, bool transpose) const { + if (transpose) { + matmul_4x8x8::gemm_s8_4x8_transpose_pack_A_n(out, in, ldin, y0, ymax, + k0, kmax); + } else { + matmul_4x8x8::gemm_s8_4x8_pack_A_n(out, in, ldin, y0, ymax, k0, kmax); + } +} + +void gemm_s8_4x8::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, + int xmax, int k0, int kmax, bool transpose) const { + if (transpose) { + matmul_4x8x8::gemm_s8_4x8_transpose_pack_B_n(out, in, ldin, x0, xmax, + k0, kmax); + } else { + matmul_4x8x8::gemm_s8_4x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); + } +} + +void gemm_s8_4x8::kern(const dt_int8* packA, const dt_int8* packB, size_t M, + size_t N, size_t K, dt_int32* C, size_t LDC, + bool is_first_k, const dt_int32*, dt_int32*) const { + megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && + ((A_dtype.enumv() == DTypeEnum::Int8 && + C_dtype.enumv() == DTypeEnum::Int32) || + (A_dtype.enumv() == DTypeEnum::QuantizedS8 && + C_dtype.enumv() == DTypeEnum::QuantizedS32)), + "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), + C_dtype.name()); + + MEGDNN_MARK_USED_VAR(A_dtype); + MEGDNN_MARK_USED_VAR(B_dtype); + MEGDNN_MARK_USED_VAR(C_dtype); + + constexpr size_t A_INTERLEAVE = 4; + constexpr size_t B_INTERLEAVE = 8; + //! 
K is packed to times of 8 + K = round_up(K, 8); + const int K4 = K * 4; + const int K8 = K * 8; + + size_t m = 0; + for (; m < M; m += A_INTERLEAVE) { + int32_t* output = C + (m * LDC); + const dt_int8* cur_packB = packB; + size_t n = 0; + for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { + matmul_4x8x8::kern_4x8(packA, cur_packB, K, output, LDC, is_first_k, + std::min(M - m, 4)); + output += B_INTERLEAVE; + cur_packB += K8; + } + + for (; n < N; n += 4) { + matmul_4x8x8::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k, + std::min(M - m, 4), + std::min(N - n, 4)); + output += 4; + cur_packB += K4; + } + packA += K4; + } +} + +#if __ARM_FEATURE_DOTPROD +// ===========================gemm_s8_6x8====================================== +MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_dots8_6x8); +void gemm_dots8_6x8::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0, + int ymax, int k0, int kmax, bool transpose) const { + if (transpose) { + matmul_dot_6x8x4::gemm_s8_6x8_pack_A_t(out, in, ldin, y0, ymax, k0, + kmax); + } else { + matmul_dot_6x8x4::gemm_s8_6x8_pack_A_n(out, in, ldin, y0, ymax, k0, + kmax); + } +} + +void gemm_dots8_6x8::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, + int xmax, int k0, int kmax, bool transpose) const { + if (transpose) { + matmul_dot_6x8x4::gemm_s8_6x8_pack_B_t(out, in, ldin, x0, xmax, k0, + kmax); + } else { + matmul_dot_6x8x4::gemm_s8_6x8_pack_B_n(out, in, ldin, x0, xmax, k0, + kmax); + } +} + +void gemm_dots8_6x8::kern(const dt_int8* packA, const dt_int8* packB, size_t M, + size_t N, size_t K, dt_int32* C, size_t LDC, + bool is_first_k, const dt_int32* bias, + dt_int32* workspace) const { + MEGDNN_MARK_USED_VAR(bias); + constexpr size_t A_INTERLEAVE = 6; + constexpr size_t B_INTERLEAVE = 8; + //! K is packed to times of 4 + K = round_up(K, 4); + const int K4 = K * 4; + const int K6 = K * 6; + const int K8 = K * 8; + + size_t m = 0; + for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) { + int32_t* output = C + (m * LDC); + const dt_int8* cur_packB = packB; + size_t n = 0; + for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { + matmul_dot_6x8x4::kern_6x8(packA, cur_packB, K, output, LDC, + is_first_k); + output += B_INTERLEAVE; + cur_packB += K8; + } + for (; n < N; n += 4) { + size_t n_remain = std::min(N - n, 4); + matmul_dot_6x8x4::kern_6x4(packA, cur_packB, K, output, LDC, + is_first_k, n_remain); + output += n_remain; + cur_packB += K4; + } + packA += K6; + } + if (m < M) { + int32_t* output = C + (m * LDC); + const dt_int8* cur_packB = packB; + size_t m_remain = std::min(M - m, 6); + size_t n = 0; + for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { + matmul_dot_6x8x4::kern_6x8(packA, cur_packB, K, output, LDC, + is_first_k, m_remain); + output += B_INTERLEAVE; + cur_packB += K8; + } + for (; n < N; n += 4) { + size_t n_remain = std::min(N - n, 4); + matmul_dot_6x8x4::kern_6x4(packA, cur_packB, K, output, LDC, + is_first_k, n_remain, m_remain); + output += n_remain; + cur_packB += K4; + } + } +} +#endif + +// ===========================gemm_mk4_s8_4x2====================================== +MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_mk4_s8_4x2); +void gemm_mk4_s8_4x2::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0, + int ymax, int k0, int kmax, bool transpose) const { + megdnn_assert(!transpose); + matmul_mk4_4x2x16::gemm_mk4_s8_4x2_pack_A(out, in, ldin, y0, ymax, k0, + kmax); +} + +void gemm_mk4_s8_4x2::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, + int xmax, int k0, int kmax, bool transpose) const { + 
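+    //! Editorial note (not part of the original patch): the MK4 format fixes
+    //! both operand layouts (A as {M/4, K/4, 4(k), 4(m)}, B as {K/4, N, 4(k)},
+    //! C as {M/4, N, 4(m)}), so only the non-transposed packing path is
+    //! implemented and a transposed request is rejected below; callers that
+    //! need a transposed B have to pick a different algorithm.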
megdnn_assert(!transpose); + matmul_mk4_4x2x16::gemm_mk4_s8_4x2_pack_B(out, in, ldin, x0, xmax, k0, + kmax); +} + +void gemm_mk4_s8_4x2::kern(const dt_int8* packA, const dt_int8* packB, size_t M, + size_t N, size_t K, dt_int32* C, size_t LDC, + bool is_first_k, const dt_int32*, dt_int32*) const { + megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && + ((A_dtype.enumv() == DTypeEnum::Int8 && + C_dtype.enumv() == DTypeEnum::Int32) || + (A_dtype.enumv() == DTypeEnum::QuantizedS8 && + C_dtype.enumv() == DTypeEnum::QuantizedS32)), + "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), + C_dtype.name()); + + MEGDNN_MARK_USED_VAR(A_dtype); + MEGDNN_MARK_USED_VAR(B_dtype); + MEGDNN_MARK_USED_VAR(C_dtype); + + constexpr size_t A_INTERLEAVE = 4; + constexpr size_t B_INTERLEAVE = 2; + //! K is packed to times of 4 + K = round_up(K, 16); + const int K4 = K * 4; + const int K2 = K * 2; + megdnn_assert(M % 4 == 0, "mk4 matmul with m is not times of 4"); + + size_t m = 0; + for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) { + int32_t* output = C + (m / 4 * LDC); + + size_t n = 0; + const dt_int8* cur_packB = packB; + for (; n < N; n += B_INTERLEAVE) { + matmul_mk4_4x2x16::kern_4x2(packA, cur_packB, K, output, is_first_k, + std::min(N - n, 2)); + output += B_INTERLEAVE * 4; + cur_packB += K2; + } + packA += K4; + } +} +// vim: syntax=cpp.doxygen diff --git a/dnn/src/armv7/matrix_mul/int8/strategy.h b/dnn/src/armv7/matrix_mul/int8/strategy.h new file mode 100644 index 00000000..516a9931 --- /dev/null +++ b/dnn/src/armv7/matrix_mul/int8/strategy.h @@ -0,0 +1,34 @@ +/** + * \file dnn/src/armv7/matrix_mul/int8/strategy.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once +#include "src/fallback/matrix_mul/gemm_common.h" + +namespace megdnn { +namespace armv7 { +namespace matmul { + +MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 4, 2, 16, false, true, + gemm_s8_4x2); + +MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 4, 8, 8, false, true, + gemm_s8_4x8); + +MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 4, 2, 16, false, false, + gemm_mk4_s8_4x2); +#if __ARM_FEATURE_DOTPROD +MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 6, 8, 4, false, false, + gemm_dots8_6x8); +#endif +} // namespace matmul +} // namespace armv7 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/armv7/matrix_mul/int8x8x16/kernel_4x2x16.h b/dnn/src/armv7/matrix_mul/int8x8x16/kernel_4x2x16.h new file mode 100644 index 00000000..ac0a591c --- /dev/null +++ b/dnn/src/armv7/matrix_mul/int8x8x16/kernel_4x2x16.h @@ -0,0 +1,573 @@ +/** + * \file dnn/src/armv7/matrix_mul/int8x8x16/kernel_4x2x16.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/armv7/matrix_mul/asm/common.h" + +namespace megdnn { +namespace armv7 { +namespace matmul_4x2x16 { + +/** + * Overview of register layout. 
+ * + * A 2x16 block of Rhs is stored in 8 bit in d0--d3. + * A 4x16 block of Lhs is stored in 8 bit in d4--d10. + * + * A 4x2 block of accumulators is stored in q6--q13 (as 4x16 bit + * components which need to be horizontally-added at the end) + * + * + * +--------+--------+ + * |d0[0-8] |d2[0-8] | + * Rhs +--------+--------+ + * |d1[0-8] |d3[0-8] | + * +--------+--------+ + * | | | + * + * Lhs | | | + * + * +--------+--------+ - - - - +------------------ + * |d4[0-8] |d5[0-8] | |q6[0-8] | q7[0-8]| + * |d6[0-8] |d7[0-8] | |q8[0-8] | q9[0-8]| + * |d8[0-8] |d9[0-8] | |q10[0-8]|q11[0-8]| + * |d10[0-8]|d11[0-8]| |q12[0-8]|q13[0-8]| + * +--------+--------+ - - - - +------------------ + * + * Accumulator + */ + +static void kern_4x2(const int8_t* packA, const int8_t* packB, int K, + int16_t* output, int LDC, bool is_first_k, int m_remain, + int n_remain) { + MEGDNN_MARK_USED_VAR(m_remain); + MEGDNN_MARK_USED_VAR(n_remain); + K /= 16; + const int8_t* a_ptr = packA; + const int8_t* b_ptr = packB; + + LDC = LDC * sizeof(int16_t); +// clang-format off +#define LOAD_LINE(reg_index, r1, r2, n) \ + "cmp r5, #0 \n" \ + "beq 102f\n" \ + "cmp %[n_remain], #2\n" \ + "blt 100" n "f\n" \ + "vld1.32 {d" reg_index "[" r1 "]}, [r" n "]\n" \ + "b 101" n "f\n" \ + "100" n ":\n" \ + "cmp %[n_remain], #0\n" \ + "beq 101" n "f\n" \ + "vld1.16 {d" reg_index "[" r2 "]}, [r" n "]\n" \ + "101" n ":\n" \ + "subs r5, r5, #1\n" + +#define LOAD_C \ + "mov r5, %[m_remain]\n" \ + LOAD_LINE("28", "0", "0", "0") \ + LOAD_LINE("28", "1", "2", "1") \ + LOAD_LINE("29", "0", "0", "2") \ + LOAD_LINE("29", "1", "2", "3") \ + "102:\n" + +#define STORE_LINE(reg_index, r1, r2, n) \ + "cmp r5, #0 \n" \ + "beq 105f\n" \ + "cmp %[n_remain], #2\n" \ + "blt 103" n "f\n" \ + "vst1.32 {d" reg_index "[" r1 "]}, [r" n "]\n" \ + "b 104" n "f\n" \ + "103" n ":\n" \ + "cmp %[n_remain], #0\n" \ + "beq 104" n "f\n" \ + "vst1.16 {d" reg_index "[" r2 "]}, [r" n " ]\n" \ + "104" n ":\n" \ + "subs r5, r5, #1\n" + +#define STORE_C \ + "mov r5, %[m_remain]\n" \ + STORE_LINE("30", "0", "0", "0") \ + STORE_LINE("30", "1", "2", "1") \ + STORE_LINE("31", "0", "0", "2") \ + STORE_LINE("31", "1", "2", "3") \ + "105:\n" + + register int16_t* outptr asm("r0") = output; + asm volatile( + "add r1, r0, %[LDC]\n" + "add r2, r1, %[LDC]\n" + "add r3, r2, %[LDC]\n" + "vldr d0, [%[b_ptr], #0]\n" + "vmov.i32 q6, #0\n" + "vldr d4, [%[a_ptr], #0]\n" + "vmov.i32 q7, #0\n" + "vldr d2, [%[b_ptr], #16]\n" + "vmov.i32 q8, #0\n" + "vldr d6, [%[a_ptr], #16]\n" + "vmov.i32 q9, #0\n" + "vldr d1, [%[b_ptr], #8]\n" + "vmov.i32 q10, q8\n" + "vldr d5, [%[a_ptr], #8]\n" + "vmov.i32 q11, q8\n" + "vldr d3, [%[b_ptr], #24]\n" + "vmov.i32 q12, q8\n" + "vldr d7, [%[a_ptr], #24]\n" + "vmov.i32 q13, q8\n" + + // General loop. + "1:\n" + + // Multiply 8 first levels of depth. + "vmlal.s8 q6, d0, d4\n" + "add %[b_ptr], %[b_ptr], #32\n" + "vmlal.s8 q7, d2, d4\n" + "vldr d8, [%[a_ptr], #32]\n" + "vmlal.s8 q8, d0, d6\n" + "vmlal.s8 q9, d2, d6\n" + "vldr d10, [%[a_ptr], #48]\n" + + "vmlal.s8 q6, d1, d5\n" + "vmlal.s8 q7, d3, d5\n" + "vldr d9, [%[a_ptr], #40]\n" + "vmlal.s8 q8, d1, d7\n" + "vmlal.s8 q9, d3, d7\n" + "vldr d11, [%[a_ptr], #56]\n" + + "add %[a_ptr], %[a_ptr], #64\n" + "subs %[K], %[K], #1\n" + + "beq 2f\n" + + // Multiply first half. 
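+            // Editorial note (not part of the original patch): each vmlal.s8
+            // in this kernel widens eight int8 products into eight int16
+            // lanes and accumulates them into one of q6-q13, so one loop
+            // iteration consumes 16 values along K per (row, column) pair
+            // (two vmlal per accumulator); the eight per-lane partial sums
+            // are reduced with vpadd.s16 after the loop. Conceptually, per
+            // accumulator:
+            //   for (int k = 0; k < 16; ++k)
+            //       acc16[m][n] += int16_t(a[m][k]) * int16_t(b[n][k]);
+            // Accumulation stays in 16 bits because this kernel produces the
+            // int8x8x16 result type.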
+ "vmlal.s8 q10, d0, d8\n" + "vmlal.s8 q11, d2, d8\n" + "vldr d4, [%[a_ptr], #0]\n" + "vmlal.s8 q12, d0, d10\n" + "vldr d0, [%[b_ptr], #0]\n" + "vmlal.s8 q13, d2, d10\n" + "vldr d2, [%[b_ptr], #16]\n" + + "vmlal.s8 q10, d1, d9\n" + "vldr d6, [%[a_ptr], #16]\n" + "vmlal.s8 q11, d3, d9\n" + "vldr d5, [%[a_ptr], #8]\n" + "vmlal.s8 q12, d1, d11\n" + "vldr d1, [%[b_ptr], #8]\n" + "vmlal.s8 q13, d3, d11\n" + "vldr d3, [%[b_ptr], #24]\n" + + "vldr d7, [%[a_ptr], #24]\n" + + "b 1b\n" + + "2:\n" + + // Multiply first half. + "vmlal.s8 q10, d0, d8\n" + "vmlal.s8 q11, d2, d8\n" + "vmlal.s8 q12, d0, d10\n" + "vmlal.s8 q13, d2, d10\n" + + "vmlal.s8 q10, d1, d9\n" + "vmlal.s8 q11, d3, d9\n" + "vmlal.s8 q12, d1, d11\n" + "vmlal.s8 q13, d3, d11\n" + + "cmp %[is_first_k], #1\n" + + + // Reduce q6-q13 to q0-q3 + "vpadd.s16 d0, d12, d13\n" + "vpadd.s16 d1, d14, d15\n" + "vpadd.s16 d2, d16, d17\n" + "vpadd.s16 d3, d18, d19\n" + "vpadd.s16 d4, d20, d21\n" + "vpadd.s16 d5, d22, d23\n" + "vpadd.s16 d6, d24, d25\n" + "vpadd.s16 d7, d26, d27\n" + + "vpadd.s16 d8, d0, d1\n" + "vpadd.s16 d9, d2, d3\n" + "vpadd.s16 d10, d4, d5\n" + "vpadd.s16 d11, d6, d7\n" + + "vpadd.s16 d30, d8, d9\n" + "vpadd.s16 d31, d10, d11\n" + + "bne 3f\n" + + "b 4f\n" + + "3:\n" + + LOAD_C // + + "vadd.s16 q15, q14, q15\n" + + "4:\n" + // Store back into memory + STORE_C + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), + [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), + [outptr] "+r"(outptr), [m_remain] "+r" (m_remain), + [n_remain] "+r" (n_remain) + : + : "cc", "memory", "r1", "r2", "r3", "r4", "r5", + "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", + "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", + "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27", + "d28", "d29", "d30", "d31"); + +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +} + + +static void gemm_s8x8x16_4x2_pack_A_n(dt_int8* outptr, const dt_int8* inptr, + int ldin, int y0, int ymax, int k0, int kmax) { + int8_t zerobuff[16]; + std::memset(zerobuff, 0, sizeof(int8_t) * 16); + + int y = y0; + for (; y + 3 < ymax; y += 4) { + const int8_t* inptr0 = inptr + y * ldin + k0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + + int K = kmax - k0; + //! read 4 * 16 in each row + for (; K > 15; K -= 16) { + interleave_4x16_1_b(inptr0, inptr1, inptr2, inptr3, outptr); + } + + if (K > 0) { + interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 16, K); + } + } + for (; y < ymax; y += 4) { + const int8_t* inptr0 = inptr + y * ldin + k0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + + int K = kmax - k0; + //! 
read 4 * 16 in each row + for (; K > 15; K -= 16) { + if (y + 3 >= ymax) { + switch (y + 3 - ymax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + interleave_4x16_1_b(inptr0, inptr1, inptr2, inptr3, outptr); + } + + if (K > 0) { + if (y + 3 >= ymax) { + switch (y + 3 - ymax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 16, K); + } + } +} + +static void gemm_s8x8x16_4x2_pack_A_t(dt_int8* out, const dt_int8* in, int ldin, + int x0, int xmax, int k0, int kmax) { + int8_t zerobuff[16]; + std::memset(zerobuff, 0, sizeof(int8_t) * 16); + const int ksize = kmax - k0; + const int ksize4 = round_up(ksize, 16) * 4; + int8_t* outptr = out; + + int k = k0; + for (; k < kmax; k += 16) { + int ki = k; + for (int cnt = 0; cnt < 2; ki += 8, cnt++) { + const int8_t* inptr0 = in + ki * ldin + x0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + const int8_t* inptr4 = inptr3 + ldin; + const int8_t* inptr5 = inptr4 + ldin; + const int8_t* inptr6 = inptr5 + ldin; + const int8_t* inptr7 = inptr6 + ldin; + int8_t* outptr_inner = outptr + ki - k; + + int remain = std::min(ki + 7 - kmax, 7); + int x = x0; + for (; x + 3 < xmax; x += 4) { + if (remain >= 0) { + switch (remain) { + case 7: + inptr0 = zerobuff; + case 6: + inptr1 = zerobuff; + case 5: + inptr2 = zerobuff; + case 4: + inptr3 = zerobuff; + case 3: + inptr4 = zerobuff; + case 2: + inptr5 = zerobuff; + case 1: + inptr6 = zerobuff; + case 0: + inptr7 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + transpose_4x16_1_b_helper(inptr0, inptr1, inptr2, inptr3, + inptr4, inptr5, inptr6, inptr7, + outptr_inner); + outptr_inner += ksize4; + } + + if (x < xmax) { + if (remain >= 0) { + switch (remain) { + case 7: + inptr0 = zerobuff; + case 6: + inptr1 = zerobuff; + case 5: + inptr2 = zerobuff; + case 4: + inptr3 = zerobuff; + case 3: + inptr4 = zerobuff; + case 2: + inptr5 = zerobuff; + case 1: + inptr6 = zerobuff; + case 0: + inptr7 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + for (; x < xmax; x++) { + *outptr_inner++ = *inptr0++; + *outptr_inner++ = *inptr1++; + *outptr_inner++ = *inptr2++; + *outptr_inner++ = *inptr3++; + *outptr_inner++ = *inptr4++; + *outptr_inner++ = *inptr5++; + *outptr_inner++ = *inptr6++; + *outptr_inner++ = *inptr7++; + outptr_inner += 8; + } + } + } + + outptr += 16 * 4; + } +} + +static void gemm_s8x8x16_4x2_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, + int x0, int xmax, int k0, int kmax) { + int8_t zerobuff[16]; + std::memset(zerobuff, 0, sizeof(int8_t) * 16); + const int ksize = kmax - k0; + const int ksize2 = round_up(ksize, 16) * 2; + int8_t* outptr = out; + + int k = k0; + for (; k < kmax; k += 16) { + int ki = k; + for (int cnt = 0; cnt < 2; ki += 8, cnt++) { + const int8_t* inptr0 = in + ki * ldin + x0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + const int8_t* inptr4 = inptr3 + ldin; + const int8_t* inptr5 = inptr4 + ldin; + const int8_t* inptr6 = inptr5 + ldin; + const int8_t* inptr7 = inptr6 + ldin; + int8_t* outptr_inner = outptr + ki - k; + + int remain = std::min(ki + 7 - kmax, 7); + int x = x0; + for (; x + 1 < xmax; x += 2) { + if (remain >= 0) { + switch (remain) { + case 7: + inptr0 = 
zerobuff; + case 6: + inptr1 = zerobuff; + case 5: + inptr2 = zerobuff; + case 4: + inptr3 = zerobuff; + case 3: + inptr4 = zerobuff; + case 2: + inptr5 = zerobuff; + case 1: + inptr6 = zerobuff; + case 0: + inptr7 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + transpose_2x16_1_b_helper(inptr0, inptr1, inptr2, inptr3, + inptr4, inptr5, inptr6, inptr7, + outptr_inner); + outptr_inner += ksize2; + } + + if (x < xmax) { + if (remain >= 0) { + switch (remain) { + case 7: + inptr0 = zerobuff; + case 6: + inptr1 = zerobuff; + case 5: + inptr2 = zerobuff; + case 4: + inptr3 = zerobuff; + case 3: + inptr4 = zerobuff; + case 2: + inptr5 = zerobuff; + case 1: + inptr6 = zerobuff; + case 0: + inptr7 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + for (; x < xmax; x++) { + *outptr_inner++ = *inptr0++; + *outptr_inner++ = *inptr1++; + *outptr_inner++ = *inptr2++; + *outptr_inner++ = *inptr3++; + *outptr_inner++ = *inptr4++; + *outptr_inner++ = *inptr5++; + *outptr_inner++ = *inptr6++; + *outptr_inner++ = *inptr7++; + outptr_inner += 8; + } + } + } + + outptr += 16 * 2; + } +} + +static void gemm_s8x8x16_4x2_pack_B_t(dt_int8* outptr, const dt_int8* inptr, + int ldin, int y0, int ymax, int k0, int kmax) { + int8_t zerobuff[16]; + std::memset(zerobuff, 0, sizeof(int8_t) * 16); + + int y = y0; + for (; y + 1 < ymax; y += 2) { + const int8_t* inptr0 = inptr + y * ldin + k0; + const int8_t* inptr1 = inptr0 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + + int K = kmax - k0; + //! read 16 * 2 in each row + for (; K > 15; K -= 16) { + interleave_2x16_1_b(inptr0, inptr1, outptr); + } + + if (K > 0) { + interleave_2(inptr0, inptr1, outptr, 16, K); + } + } + for (; y < ymax; y += 2) { + const int8_t* inptr0 = inptr + y * ldin + k0; + const int8_t* inptr1 = inptr0 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + + int K = kmax - k0; + for (; K > 15; K -= 16) { + if (y + 1 >= ymax) { + switch (y + 1 - ymax) { + case 0: + inptr1 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + interleave_2x16_1_b(inptr0, inptr1, outptr); + } + + if (K > 0) { + if (y + 1 >= ymax) { + switch (y + 1 - ymax) { + case 0: + inptr1 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + interleave_2(inptr0, inptr1, outptr, 16, K); + } + } +} + +} // matmul_4x2x16 +} // namespace armv7 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/armv7/matrix_mul/int8x8x16/kernel_4x8x8.h b/dnn/src/armv7/matrix_mul/int8x8x16/kernel_4x8x8.h new file mode 100644 index 00000000..5ad26daa --- /dev/null +++ b/dnn/src/armv7/matrix_mul/int8x8x16/kernel_4x8x8.h @@ -0,0 +1,755 @@ +/** + * \file dnn/src/armv7/matrix_mul/int8x8x16/kernel_4x8x8.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/armv7/matrix_mul/asm/common.h" + +namespace megdnn { +namespace armv7 { +namespace matmul_4x8x8 { + +/** + * Overview of register layout: + * + * A 8x8x8 cell of Rhs is stored in 8bit in q8 + * A 8x8x8 cell of Lhs is stored in 8bit in q0-q3 + * A 8x8 block of accumulators is stored in 16bit in q4-q7 + * + * +--------+ + * | q8[0-8]| + * Rhs +--------+ + * Lhs | | + * + * +--------+ - - - - +--------- + * |q0[0-8]| | q4[0-8]| + * |q1[0-8]| | q5[0-8]| + * |q2[0-8]| | q6[0-8]| + * |q3[0-8]| | q7[0-8]| + * +--------+ - - - - +--------- + * + * Accumulator + */ +static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, + int16_t* output, int LDC, bool is_first_k, + size_t m_remain) { + K /= 8; + const int8_t* a_ptr = packA; + const int8_t* b_ptr = packB; + + LDC = LDC * sizeof(int16_t); + size_t x0 = 0; + +// clang-format off +#define LOAD_LINE(reg_index1, reg_index2, n) \ + "cmp %[x0], #0 \n" \ + "beq 100f\n" \ + "vld1.16 {d" reg_index1 ", d" reg_index2 "}, [r" n "]\n" \ + "subs %[x0], %[x0], #1\n" + +#define LOAD_C \ + "mov %[x0], %[m_remain]\n" \ + LOAD_LINE("8", "9", "0") \ + LOAD_LINE("10", "11", "1") \ + LOAD_LINE("12", "13", "2") \ + LOAD_LINE("14", "15", "3") \ + "100:\n" + +#define STORE_LINE(reg_index1, reg_index2, n) \ + "cmp %[x0], #0 \n" \ + "beq 101f\n" \ + "vst1.16 {d" reg_index1 ", d" reg_index2 "}, [r" n "]\n" \ + "subs %[x0], %[x0], #1\n" + +#define STORE_C \ + "mov %[x0], %[m_remain]\n" \ + STORE_LINE("8", "9", "0") \ + STORE_LINE("10", "11","1") \ + STORE_LINE("12", "13", "2") \ + STORE_LINE("14", "15", "3") \ + "101:\n" + +// clang-format on + + register int16_t* outptr asm("r0") = output; + asm volatile( + // load accumulator C + "add r1, r0, %[LDC]\n" + "add r2, r1, %[LDC]\n" + "add r3, r2, %[LDC]\n" + "cmp %[is_first_k], #1\n" + "beq 1f\n" LOAD_C + + "b 2f\n" + + "1:\n" + "veor.s32 q4, q4, q4\n" + "veor.s32 q5, q5, q5\n" + "veor.s32 q6, q6, q6\n" + "veor.s32 q7, q7, q7\n" + + "2: \n" + "vld1.8 {d16}, [%[b_ptr]]!\n" + "vld1.8 {d0}, [%[a_ptr]]!\n" + "vld1.8 {d2}, [%[a_ptr]]!\n" + "vld1.8 {d4}, [%[a_ptr]]!\n" + "vld1.8 {d6}, [%[a_ptr]]!\n" + "vmovl.s8 q8, d16\n" + "vmovl.s8 q0, d0\n" + "vmovl.s8 q1, d2\n" + "vmovl.s8 q2, d4\n" + "vmovl.s8 q3, d6\n" + + "vld1.8 {d18}, [%[b_ptr]]!\n" + "vmla.s16 q4, q8, d0[0]\n" + "vmla.s16 q5, q8, d2[0]\n" + "vmla.s16 q6, q8, d4[0]\n" + "vmla.s16 q7, q8, d6[0]\n" + "vmovl.s8 q9, d18\n" + + "vld1.8 {d20}, [%[b_ptr]]!\n" + "vmla.s16 q4, q9, d0[1]\n" + "vmla.s16 q5, q9, d2[1]\n" + "vmla.s16 q6, q9, d4[1]\n" + "vmla.s16 q7, q9, d6[1]\n" + "vmovl.s8 q10, d20\n" + + "vld1.8 {d22}, [%[b_ptr]]!\n" + "vmla.s16 q4, q10, d0[2]\n" + "vmla.s16 q5, q10, d2[2]\n" + "vmla.s16 q6, q10, d4[2]\n" + "vmla.s16 q7, q10, d6[2]\n" + "vmovl.s8 q11, d22\n" + + "vld1.8 {d24}, [%[b_ptr]]!\n" + "vmla.s16 q4, q11, d0[3]\n" + "vmla.s16 q5, q11, d2[3]\n" + "vmla.s16 q6, q11, d4[3]\n" + "vmla.s16 q7, q11, d6[3]\n" + "vmovl.s8 q12, d24\n" + + "vld1.8 {d26}, [%[b_ptr]]!\n" + "vmla.s16 q4, q12, d1[0]\n" + "vmla.s16 q5, q12, d3[0]\n" + "vmla.s16 q6, q12, d5[0]\n" + "vmla.s16 q7, q12, d7[0]\n" + "vmovl.s8 q13, d26\n" + + "vld1.8 {d28}, [%[b_ptr]]!\n" + "vmla.s16 q4, q13, d1[1]\n" + "vmla.s16 q5, q13, d3[1]\n" + "vmla.s16 q6, q13, d5[1]\n" + "vmla.s16 q7, q13, d7[1]\n" + "vmovl.s8 q14, d28\n" + + "vld1.8 {d30}, [%[b_ptr]]!\n" + "vmla.s16 q4, q14, d1[2]\n" + "vmla.s16 q5, q14, d3[2]\n" + "vmla.s16 q6, q14, d5[2]\n" + "vmla.s16 q7, q14, d7[2]\n" + "vmovl.s8 q15, d30\n" + + "vmla.s16 q4, q15, 
d1[3]\n" + "vmla.s16 q5, q15, d3[3]\n" + "vmla.s16 q6, q15, d5[3]\n" + "vmla.s16 q7, q15, d7[3]\n" + + "subs %[K], %[K], #1\n" + "bne 2b\n" + + "3:\n" STORE_C + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [x0] "+r"(x0), [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), + [m_remain] "+r"(m_remain), [outptr] "+r"(outptr) + : + : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", + "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", + "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28", + "d29", "d30", "d31", "r1", "r2", "r3", "cc", "memory"); + +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +} + +/** + * Overview of register layout: + * + * A 8x8x8 cell of Rhs is stored in 8bit in q8 + * A 8x8x8 cell of Lhs is stored in 8bit in q0-q3 + * A 8x8 block of accumulators is stored in 16bit in q4-q7 + * + * +--------+ + * | q8[0-4]| + * Rhs +--------+ + * Lhs | | + * + * +--------+ - - - - +--------- + * |q0[0-8]| | q4[0-4]| + * |q1[0-8]| | q5[0-4]| + * |q2[0-8]| | q6[0-4]| + * |q3[0-8]| | q7[0-4]| + * +--------+ - - - - +--------- + * + * Accumulator + */ +static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, + int16_t* output, int LDC, bool is_first_k, size_t m_remain, + size_t n_remain) { + K /= 8; + const int8_t* a_ptr = packA; + const int8_t* b_ptr = packB; + + LDC = LDC * sizeof(int16_t); + size_t x0 = 0; + +// clang-format off +#define LOAD_LINE(reg_index1, n) \ + "cmp %[x0], #0 \n" \ + "beq 102f\n" \ + "cmp %[n_remain], #4\n" \ + "blt 100" n "f\n" \ + "vld1.16 {d" reg_index1 "}, [r" n " ]!\n" \ + "b 101" n "f\n" \ + "100" n ":\n" \ + "cmp %[n_remain], #0\n" \ + "beq 101" n "f\n" \ + "vld1.16 {d" reg_index1 "[0]}, [r" n " ]!\n" \ + "cmp %[n_remain], #1\n" \ + "beq 101" n "f\n" \ + "vld1.16 {d" reg_index1 "[1]}, [r" n " ]!\n" \ + "cmp %[n_remain], #2\n" \ + "beq 101" n "f\n" \ + "vld1.16 {d" reg_index1 "[2]}, [r" n " ]!\n" \ + "101" n ":\n" \ + "subs %[x0], %[x0], #1\n" + +#define LOAD_C \ + "mov %[x0], %[m_remain]\n" \ + "mov r1, r0\n" \ + LOAD_LINE("8", "1") \ + "add r1, r0, %[LDC]\n" \ + "add r0, r0, %[LDC]\n" \ + LOAD_LINE("10", "1") \ + "add r1, r0, %[LDC]\n" \ + "add r0, r0, %[LDC]\n" \ + LOAD_LINE("12", "1") \ + "add r1, r0, %[LDC]\n" \ + LOAD_LINE("14", "1") \ + "102:\n" + +#define STORE_LINE(reg_index1, n) \ + "cmp %[x0], #0 \n" \ + "beq 105f\n" \ + "cmp %[n_remain], #4\n" \ + "blt 103" n "f\n" \ + "vst1.16 {d" reg_index1 "}, [r" n " ]!\n" \ + "b 104" n "f\n" \ + "103" n ":\n" \ + "cmp %[n_remain], #0\n" \ + "beq 104" n "f\n" \ + "vst1.16 {d" reg_index1 "[0]}, [r" n " ]!\n" \ + "cmp %[n_remain], #1\n" \ + "beq 104" n "f\n" \ + "vst1.16 {d" reg_index1 "[1]}, [r" n " ]!\n" \ + "cmp %[n_remain], #2\n" \ + "beq 104" n "f\n" \ + "vst1.16 {d" reg_index1 "[2]}, [r" n " ]!\n" \ + "104" n ":\n" \ + "subs %[x0], %[x0], #1\n" + +#define STORE_C \ + "mov %[x0], %[m_remain]\n" \ + "mov r1, r0\n" \ + STORE_LINE("8", "1") \ + "add r1, r0, %[LDC]\n" \ + "add r0, r0, %[LDC]\n" \ + STORE_LINE("10", "1") \ + "add r1, r0, %[LDC]\n" \ + "add r0, r0, %[LDC]\n" \ + STORE_LINE("12", "1") \ + "add r1, r0, %[LDC]\n" \ + STORE_LINE("14", "1") \ + "105:\n" + + // clang-format on + + register int16_t* outptr asm("r0") = output; + asm volatile( + // load accumulator C + "cmp %[is_first_k], #1\n" + "beq 1f\n" LOAD_C + + "b 2f\n" + + "1:\n" + "veor.s32 q4, q4, q4\n" + "veor.s32 q5, q5, q5\n" + "veor.s32 q6, q6, q6\n" + "veor.s32 q7, q7, q7\n" + + "2: \n" + "vld1.32 {d16[0]}, [%[b_ptr]]!\n" + "vld1.8 {d0}, [%[a_ptr]]!\n" + "vld1.8 
{d2}, [%[a_ptr]]!\n" + "vld1.8 {d4}, [%[a_ptr]]!\n" + "vld1.8 {d6}, [%[a_ptr]]!\n" + "vmovl.s8 q8, d16\n" + "vmovl.s8 q0, d0\n" + "vmovl.s8 q1, d2\n" + "vmovl.s8 q2, d4\n" + "vmovl.s8 q3, d6\n" + + "vld1.32 {d18[0]}, [%[b_ptr]]!\n" + "vmla.s16 d8, d16, d0[0]\n" + "vmla.s16 d10, d16, d2[0]\n" + "vmovl.s8 q9, d18\n" + "vmla.s16 d12, d16, d4[0]\n" + "vmla.s16 d14, d16, d6[0]\n" + + "vld1.32 {d20[0]}, [%[b_ptr]]!\n" + "vmla.s16 d8, d18, d0[1]\n" + "vmla.s16 d10, d18, d2[1]\n" + "vmovl.s8 q10, d20\n" + "vmla.s16 d12, d18, d4[1]\n" + "vmla.s16 d14, d18, d6[1]\n" + + "vld1.32 {d22[0]}, [%[b_ptr]]!\n" + "vmla.s16 d8, d20, d0[2]\n" + "vmla.s16 d10, d20, d2[2]\n" + "vmovl.s8 q11, d22\n" + "vmla.s16 d12, d20, d4[2]\n" + "vmla.s16 d14, d20, d6[2]\n" + + "vld1.32 {d24[0]}, [%[b_ptr]]!\n" + "vmla.s16 d8, d22, d0[3]\n" + "vmla.s16 d10, d22, d2[3]\n" + "vmovl.s8 q12, d24\n" + "vmla.s16 d12, d22, d4[3]\n" + "vmla.s16 d14, d22, d6[3]\n" + + "vld1.32 {d26[0]}, [%[b_ptr]]!\n" + "vmla.s16 d8, d24, d1[0]\n" + "vmla.s16 d10, d24, d3[0]\n" + "vmovl.s8 q13, d26\n" + "vmla.s16 d12, d24, d5[0]\n" + "vmla.s16 d14, d24, d7[0]\n" + + "vld1.32 {d28[0]}, [%[b_ptr]]!\n" + "vmla.s16 d8, d26, d1[1]\n" + "vmla.s16 d10, d26, d3[1]\n" + "vmovl.s8 q14, d28\n" + "vmla.s16 d12, d26, d5[1]\n" + "vmla.s16 d14, d26, d7[1]\n" + + "vld1.32 {d30[0]}, [%[b_ptr]]!\n" + "vmla.s16 d8, d28, d1[2]\n" + "vmla.s16 d10, d28, d3[2]\n" + "vmovl.s8 q15, d30\n" + "vmla.s16 d12, d28, d5[2]\n" + "vmla.s16 d14, d28, d7[2]\n" + + "vmla.s16 d8, d30, d1[3]\n" + "vmla.s16 d10, d30, d3[3]\n" + "vmla.s16 d12, d30, d5[3]\n" + "vmla.s16 d14, d30, d7[3]\n" + + "subs %[K], %[K], #1\n" + "bne 2b\n" + + "3:\n" STORE_C + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [outptr] "+r"(outptr), + [K] "+r"(K), [is_first_k] "+r"(is_first_k), [LDC] "+r"(LDC), + [x0] "+r"(x0), [m_remain] "+r"(m_remain), + [n_remain] "+r"(n_remain) + : + : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", + "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", + "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28", + "d29", "d30", "d31", "r1", "cc", "memory"); + +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +} + +static void gemm_s8x8x16_4x8_pack_A_n(dt_int8* outptr, const dt_int8* inptr, + int ldin, int y0, int ymax, int k0, + int kmax) { + int8_t zerobuff[16]; + std::memset(zerobuff, 0, sizeof(int8_t) * 16); + + int y = y0; + for (; y < ymax; y += 4) { + const int8_t* inptr0 = inptr + y * ldin + k0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + + int K = kmax - k0; + for (; K > 15; K -= 16) { + if (y + 3 >= ymax) { + switch (y + 3 - ymax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + interleave_4x8_2_b(inptr0, inptr1, inptr2, inptr3, outptr); + } + + if (K > 0) { + if (y + 3 >= ymax) { + switch (y + 3 - ymax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 8, K); + } + } +} + +static void gemm_s8x8x16_4x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in, + int ldin, int x0, int xmax, + int k0, int kmax) { + int8_t zerobuff[16]; + std::memset(zerobuff, 0, sizeof(int8_t) * 16); + + const int ksize = kmax - k0; + const int 
ksize4 = round_up(ksize, 8) * 4; + int8_t* outptr = out; + int8_t* outptr_base = out; + + int k = k0; + for (; k < kmax; k += 8) { + const int8_t* inptr0 = in + k * ldin + x0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + const int8_t* inptr4 = inptr3 + ldin; + const int8_t* inptr5 = inptr4 + ldin; + const int8_t* inptr6 = inptr5 + ldin; + const int8_t* inptr7 = inptr6 + ldin; + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + prefetch_2x(inptr4); + prefetch_2x(inptr5); + prefetch_2x(inptr6); + prefetch_2x(inptr7); + + int x = x0; + outptr = outptr_base; + + for (; x + 3 < xmax; x += 4) { + if (k + 7 >= kmax) { + switch (k + 7 - kmax) { + case 6: + inptr1 = zerobuff; + case 5: + inptr2 = zerobuff; + case 4: + inptr3 = zerobuff; + case 3: + inptr4 = zerobuff; + case 2: + inptr5 = zerobuff; + case 1: + inptr6 = zerobuff; + case 0: + inptr7 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + transpose_4x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, + inptr6, inptr7, outptr); + outptr += ksize4; + } + + if (x < xmax) { + if (k + 7 >= kmax) { + switch (k + 7 - kmax) { + case 6: + inptr1 = zerobuff; + case 5: + inptr2 = zerobuff; + case 4: + inptr3 = zerobuff; + case 3: + inptr4 = zerobuff; + case 2: + inptr5 = zerobuff; + case 1: + inptr6 = zerobuff; + case 0: + inptr7 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, + inptr7, outptr, 4, xmax - x); + } + + outptr_base += 4 * 8; + } +} + +static void gemm_s8x8x16_4x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, + int x0, int xmax, int k0, int kmax) { + int8_t zerobuff[16]; + std::memset(zerobuff, 0, sizeof(int8_t) * 16); + const int ksize = kmax - k0; + const int ksize4 = round_up(ksize, 8) * 4; + const int ksize8 = ksize4 * 2; + int8_t* outptr = out; + int8_t* outptr_base = out; + int8_t* outptr_interleave = nullptr; + int8_t* outptr_base4 = out + ((xmax - x0) / 8) * ksize8; + + int k = k0; + for (; k < kmax; k += 8) { + const int8_t* inptr0 = in + k * ldin + x0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + const int8_t* inptr4 = inptr3 + ldin; + const int8_t* inptr5 = inptr4 + ldin; + const int8_t* inptr6 = inptr5 + ldin; + const int8_t* inptr7 = inptr6 + ldin; + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + prefetch_2x(inptr4); + prefetch_2x(inptr5); + prefetch_2x(inptr6); + prefetch_2x(inptr7); + + int x = x0; + outptr = outptr_base; + + for (; x + 7 < xmax; x += 8) { + if (k + 7 >= kmax) { + switch (k + 7 - kmax) { + case 6: + inptr1 = zerobuff; + case 5: + inptr2 = zerobuff; + case 4: + inptr3 = zerobuff; + case 3: + inptr4 = zerobuff; + case 2: + inptr5 = zerobuff; + case 1: + inptr6 = zerobuff; + case 0: + inptr7 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + outptr_interleave = outptr; + interleave_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, + inptr6, inptr7, outptr_interleave); + outptr += ksize8; + } + + outptr = outptr_base4; + for (; x + 3 < xmax; x += 4) { + if (k + 7 >= kmax) { + switch (k + 7 - kmax) { + case 6: + inptr1 = zerobuff; + case 5: + inptr2 = zerobuff; + case 4: + inptr3 = zerobuff; + case 3: + inptr4 = zerobuff; + case 2: + inptr5 = zerobuff; + case 1: + inptr6 = zerobuff; + case 0: + inptr7 = zerobuff; + break; + default: + megdnn_assert(0); + } 
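+            //! Editorial note (not part of the original patch): the case
+            //! fall-through in the switch above is deliberate. When fewer
+            //! than eight k rows remain (k + 7 >= kmax), every missing row
+            //! pointer is redirected to zerobuff, e.g. k + 7 - kmax == 2
+            //! rewires inptr5, inptr6 and inptr7, so interleave_8() below
+            //! always reads eight full rows and the packed panel is
+            //! implicitly zero padded.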
+ } + + outptr_interleave = outptr; + interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, + inptr7, outptr_interleave, 4, 4); + outptr += ksize4; + } + + if (x < xmax) { + if (k + 7 >= kmax) { + switch (k + 7 - kmax) { + case 6: + inptr1 = zerobuff; + case 5: + inptr2 = zerobuff; + case 4: + inptr3 = zerobuff; + case 3: + inptr4 = zerobuff; + case 2: + inptr5 = zerobuff; + case 1: + inptr6 = zerobuff; + case 0: + inptr7 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + outptr_interleave = outptr; + interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, + inptr7, outptr_interleave, 4, xmax - x); + } + + outptr_base += 8 * 8; + outptr_base4 += 4 * 8; + } +} + +static void gemm_s8x8x16_4x8_transpose_pack_B_n(dt_int8* outptr, + const dt_int8* inptr, int ldin, + int y0, int ymax, int k0, + int kmax) { + int8_t zerobuff[16]; + std::memset(zerobuff, 0, sizeof(int8_t) * 16); + constexpr int interleave4 = 32; + constexpr int interleave8 = 64; + + int y = y0; + for (; y + 7 < ymax; y += 8) { + const int8_t* inptr0 = inptr + y * ldin + k0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + const int8_t* inptr4 = inptr3 + ldin; + const int8_t* inptr5 = inptr4 + ldin; + const int8_t* inptr6 = inptr5 + ldin; + const int8_t* inptr7 = inptr6 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + prefetch_2x(inptr4); + prefetch_2x(inptr5); + prefetch_2x(inptr6); + prefetch_2x(inptr7); + + int K = kmax - k0; + for (; K > 7; K -= 8) { + transpose_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, + inptr6, inptr7, outptr); + outptr += interleave8; + } + + if (K > 0) { + transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, + inptr7, outptr, 8, K); + outptr += interleave8; + } + } + + for (; y < ymax; y += 4) { + const int8_t* inptr0 = inptr + y * ldin + k0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + + int K = kmax - k0; + for (; K > 7; K -= 8) { + if (y + 3 >= ymax) { + switch (y + 3 - ymax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + transpose_8x4_1_b(inptr0, inptr1, inptr2, inptr3, outptr); + outptr += interleave4; + } + + if (K > 0) { + if (y + 3 >= ymax) { + switch (y + 3 - ymax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + transpose_4(inptr0, inptr1, inptr2, inptr3, outptr, 8, K); + outptr += interleave4; + } + } +} +} // namespace matmul_4x8x8 +} // namespace armv7 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/armv7/matrix_mul/int8x8x16/strategy.cpp b/dnn/src/armv7/matrix_mul/int8x8x16/strategy.cpp new file mode 100644 index 00000000..93494ef4 --- /dev/null +++ b/dnn/src/armv7/matrix_mul/int8x8x16/strategy.cpp @@ -0,0 +1,182 @@ +/** + * \file dnn/src/armv7/matrix_mul/int8x8x16/strategy.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "src/armv7/matrix_mul/int8x8x16/strategy.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/armv7/matrix_mul/asm/common.h" +#include "src/armv7/matrix_mul/int8x8x16/kernel_4x2x16.h" +#include "src/armv7/matrix_mul/int8x8x16/kernel_4x8x8.h" +#include "src/common/utils.h" +#include "src/fallback/matrix_mul/gemm_common.h" + +using namespace megdnn; +using namespace armv7; +using namespace armv7::matmul; + +// ===========================gemm_s8x8x16_4x2================================= +MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_4x2); + +void gemm_s8x8x16_4x2::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0, + int ymax, int k0, int kmax, + bool transpose) const { + if (transpose) { + matmul_4x2x16::gemm_s8x8x16_4x2_pack_A_t(out, in, ldin, y0, ymax, k0, + kmax); + } else { + matmul_4x2x16::gemm_s8x8x16_4x2_pack_A_n(out, in, ldin, y0, ymax, k0, + kmax); + } +} + +void gemm_s8x8x16_4x2::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, + int xmax, int k0, int kmax, + bool transpose) const { + if (transpose) { + matmul_4x2x16::gemm_s8x8x16_4x2_pack_B_t(out, in, ldin, x0, xmax, k0, + kmax); + } else { + matmul_4x2x16::gemm_s8x8x16_4x2_pack_B_n(out, in, ldin, x0, xmax, k0, + kmax); + } +} + +void gemm_s8x8x16_4x2::kern(const dt_int8* packA, const dt_int8* packB, + size_t M, size_t N, size_t K, dt_int16* C, + size_t LDC, bool is_first_k, const dt_int16*, + dt_int16*) const { + megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && + ((A_dtype.enumv() == DTypeEnum::Int8 && + C_dtype.enumv() == DTypeEnum::Int16)), + "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), + C_dtype.name()); + + MEGDNN_MARK_USED_VAR(A_dtype); + MEGDNN_MARK_USED_VAR(B_dtype); + MEGDNN_MARK_USED_VAR(C_dtype); + + constexpr size_t A_INTERLEAVE = 4; + constexpr size_t B_INTERLEAVE = 2; + //! 
K is packed to times of 4 + K = round_up(K, 16); + const int K4 = K * 4; + const int K2 = K * 2; + + size_t m = 0; + for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) { + int16_t* output = C + (m * LDC); + + size_t n = 0; + const dt_int8* cur_packB = packB; + for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { + matmul_4x2x16::kern_4x2(packA, cur_packB, K, output, LDC, + is_first_k, 4, 2); + output += B_INTERLEAVE; + cur_packB += K2; + } + + for (; n < N; n += B_INTERLEAVE) { + matmul_4x2x16::kern_4x2(packA, cur_packB, K, output, LDC, + is_first_k, 4, std::min(N - n, 2)); + output += B_INTERLEAVE; + cur_packB += K2; + } + + packA += K4; + } + + for (; m < M; m += 4) { + int16_t* output = C + (m * LDC); + + size_t n = 0; + const dt_int8* cur_packB = packB; + for (; n < N; n += B_INTERLEAVE) { + matmul_4x2x16::kern_4x2(packA, cur_packB, K, output, LDC, + is_first_k, std::min(M - m, 4), + std::min(N - n, 2)); + output += B_INTERLEAVE; + cur_packB += K2; + } + packA += K4; + } +} + +// ===========================gemm_s8x8x16_4x4================================== +MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_4x8); + +void gemm_s8x8x16_4x8::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0, + int ymax, int k0, int kmax, + bool transpose) const { + if (transpose) { + matmul_4x8x8::gemm_s8x8x16_4x8_transpose_pack_A_n(out, in, ldin, y0, + ymax, k0, kmax); + } else { + matmul_4x8x8::gemm_s8x8x16_4x8_pack_A_n(out, in, ldin, y0, ymax, k0, + kmax); + } +} + +void gemm_s8x8x16_4x8::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, + int xmax, int k0, int kmax, + bool transpose) const { + if (transpose) { + matmul_4x8x8::gemm_s8x8x16_4x8_transpose_pack_B_n(out, in, ldin, x0, + xmax, k0, kmax); + } else { + matmul_4x8x8::gemm_s8x8x16_4x8_pack_B_n(out, in, ldin, x0, xmax, k0, + kmax); + } +} + +void gemm_s8x8x16_4x8::kern(const dt_int8* packA, const dt_int8* packB, + size_t M, size_t N, size_t K, dt_int16* C, + size_t LDC, bool is_first_k, const dt_int16*, + dt_int16*) const { + megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && + ((A_dtype.enumv() == DTypeEnum::Int8 && + C_dtype.enumv() == DTypeEnum::Int16)), + "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), + C_dtype.name()); + + MEGDNN_MARK_USED_VAR(A_dtype); + MEGDNN_MARK_USED_VAR(B_dtype); + MEGDNN_MARK_USED_VAR(C_dtype); + + constexpr size_t A_INTERLEAVE = 4; + constexpr size_t B_INTERLEAVE = 8; + //! K is packed to times of 8 + K = round_up(K, 8); + const int K4 = K * 4; + const int K8 = K * 8; + + size_t m = 0; + for (; m < M; m += A_INTERLEAVE) { + int16_t* output = C + (m * LDC); + const dt_int8* cur_packB = packB; + size_t n = 0; + for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { + matmul_4x8x8::kern_4x8(packA, cur_packB, K, output, LDC, is_first_k, + std::min(M - m, 4)); + output += B_INTERLEAVE; + cur_packB += K8; + } + + for (; n < N; n += 4) { + matmul_4x8x8::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k, + std::min(M - m, 4), + std::min(N - n, 4)); + output += 4; + cur_packB += K4; + } + packA += K4; + } +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/armv7/matrix_mul/int8x8x16/strategy.h b/dnn/src/armv7/matrix_mul/int8x8x16/strategy.h new file mode 100644 index 00000000..98d24bcd --- /dev/null +++ b/dnn/src/armv7/matrix_mul/int8x8x16/strategy.h @@ -0,0 +1,28 @@ +/** + * \file dnn/src/armv7/matrix_mul/int8x8x16/strategy.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once +#include "src/fallback/matrix_mul/gemm_common.h" + +namespace megdnn { +namespace armv7 { +namespace matmul { + +MEGDNN_REG_GEMM_STRATEGY(int8_t, int16_t, int16_t, 4, 2, 16, false, true, + gemm_s8x8x16_4x2); + +MEGDNN_REG_GEMM_STRATEGY(int8_t, int16_t, int16_t, 4, 8, 8, false, true, + gemm_s8x8x16_4x8); + +} // namespace matmul +} // namespace armv7 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/armv7/matrix_mul/opr_impl.cpp b/dnn/src/armv7/matrix_mul/opr_impl.cpp new file mode 100644 index 00000000..4244b077 --- /dev/null +++ b/dnn/src/armv7/matrix_mul/opr_impl.cpp @@ -0,0 +1,82 @@ +/** + * \file dnn/src/armv7/matrix_mul/opr_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "src/armv7/matrix_mul/opr_impl.h" +#include "src/armv7/matrix_mul/algos.h" +#include "src/common/metahelper.h" +#include "src/common/utils.h" +#include "src/fallback/matrix_mul/gemm_impl.h" +#include "src/naive/handle.h" + +using namespace megdnn; +using namespace armv7; + +class MatrixMulImpl::AlgoPack : NonCopyableObj { + AlgoF32 f32; + AlgoF32MK4_4x8 f32_mk4_4x8; +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + AlgoF16K4x16x1 f16_k4x16x1; + AlgoF16MK8_4x8 f16_mk8_4x8; +#endif +#if __ARM_FEATURE_DOTPROD + AlgoInt8x8x32K6x8x4 int8_k6x8x4; + AlgoQuint8DotK4x8x4 quint8_k4x8x4; +#endif + AlgoF32Gemv f32_gemv; + AlgoInt8x8x32MK4_4x2x16 int8x8x32_mk4_4x2x16; + AlgoInt8x8x32K4x2x16 int8x8x32_k4x2x16; + AlgoInt8x8x32K4x8x8 int8x8x32_k4x8x8; +#if !__ARM_FEATURE_DOTPROD + AlgoInt8x8x32Gemv int8x8x32_gemv; +#endif + AlgoQuint8K4x8x8 quint8_k4x8x8; + AlgoInt8x8x16K4x2x16 int8x8x16_k4x2x16; + AlgoInt8x8x16K4x8x8 int8x8x16_k4x8x8; + AlgoInt16x16x32K12x4x1 int16x16x32_k12x4x1; + AlgoInt16x16x32MK8_4x8 int16x16x32_mk8_4x8; + +public: + SmallVector all_algos; + + AlgoPack() { + all_algos.emplace_back(&f32_gemv); + all_algos.emplace_back(&f32); + all_algos.emplace_back(&f32_mk4_4x8); +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + all_algos.emplace_back(&f16_k4x16x1); + all_algos.emplace_back(&f16_mk8_4x8); +#endif +#if __ARM_FEATURE_DOTPROD + all_algos.emplace_back(&int8_k6x8x4); + all_algos.emplace_back(&quint8_k4x8x4); +#endif +#if !__ARM_FEATURE_DOTPROD + all_algos.emplace_back(&int8x8x32_gemv); +#endif + all_algos.emplace_back(&int8x8x32_mk4_4x2x16); + all_algos.emplace_back(&int8x8x32_k4x8x8); + all_algos.emplace_back(&int8x8x32_k4x2x16); + all_algos.emplace_back(&quint8_k4x8x8); + all_algos.emplace_back(&int8x8x16_k4x8x8); + all_algos.emplace_back(&int8x8x16_k4x2x16); + all_algos.emplace_back(&int16x16x32_k12x4x1); + all_algos.emplace_back(&int16x16x32_mk8_4x8); + } +}; + +SmallVector MatrixMulImpl::algo_pack() { + static AlgoPack s_algo_pack; + auto algos = arm_common::MatrixMulImpl::algo_pack(); + algos.insert(algos.begin(), s_algo_pack.all_algos.begin(), + s_algo_pack.all_algos.end()); + return algos; +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/armv7/matrix_mul/opr_impl.h b/dnn/src/armv7/matrix_mul/opr_impl.h new file mode 
100644 index 00000000..51e74db0 --- /dev/null +++ b/dnn/src/armv7/matrix_mul/opr_impl.h @@ -0,0 +1,51 @@ +/** + * \file dnn/src/armv7/matrix_mul/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once +#include "src/arm_common/matrix_mul/opr_impl.h" + +namespace megdnn { +namespace armv7 { + +class MatrixMulImpl : public arm_common::MatrixMulImpl { +public: + using arm_common::MatrixMulImpl::MatrixMulImpl; + + SmallVector algo_pack() override; +private: + class AlgoF32; // Armv7 F32 + class AlgoF32MK4_4x8; // Armv7 F32 Kernel 4x8 nopack + class AlgoF32Gemv; // Armv7 F32 Gemv + class AlgoInt8x8x32K4x8x8; // Armv7 Int8x8x32 Kernel 4x8x8 + class AlgoInt8x8x32K4x2x16; // Armv7 Int8x8x32 Kernel 4x2x16 + class AlgoInt8x8x32MK4_4x2x16; // Armv7 Int8x8x32 Kernel MK4 4x2x16 +#if !__ARM_FEATURE_DOTPROD + class AlgoInt8x8x32Gemv; // Armv7 Int8x8x32 Gemv +#endif + class AlgoQuint8K4x8x8; // Armv7 Quint8 Kernel 4x8x8 + class AlgoInt8x8x16K4x2x16; // Armv7 Int8x8x16 Kernel 4x2x16 + class AlgoInt8x8x16K4x8x8; // Armv7 Int8x8x16 Kernel 4x8x8 + class AlgoInt16x16x32K12x4x1; // Armv7 Int16x16x32 Kernel 12x4x1 + class AlgoInt16x16x32MK8_4x8; // Armv7 Int16x16x32 MK8 Format block 4x8 +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + class AlgoF16K4x16x1; // Armv7 F16 Kernel 4x16x1 + class AlgoF16MK8_4x8; // Armv7 F16 MK8 Format block 4x8 +#endif +#if __ARM_FEATURE_DOTPROD + class AlgoInt8x8x32K6x8x4; // Armv7 Int8 Kernel 6x8x4 + class AlgoQuint8DotK4x8x4; // Armv7 Quint8 Kernel 6x8x4 +#endif + class AlgoPack; +}; + +} // namespace armv7 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/armv7/matrix_mul/quint8/kernel_4x8x8.h b/dnn/src/armv7/matrix_mul/quint8/kernel_4x8x8.h new file mode 100644 index 00000000..31afa188 --- /dev/null +++ b/dnn/src/armv7/matrix_mul/quint8/kernel_4x8x8.h @@ -0,0 +1,805 @@ +/** + * \file dnn/src/armv7/matrix_mul/quint8/kernel_4x8x8.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/armv7/matrix_mul/asm/common.h" + +namespace megdnn { +namespace armv7 { +namespace matmul_4x8x8 { +/** + * Overview of register layout: + * + * A 8x8x8 cell of Rhs is stored in 8bit in q12-q13 + * A 8x8x4 cell of Lhs is stored in 8bit in q0-q3 + * A 4x8 block of accumulators is stored in 32bit in q4-q11 + * zero_point_A is stored in 8bit in q14 + * zero_point_B is stored in 8bit in q15. 
+ * + * +--------+--------+ + * |v12[0-8]|v13[0-8]| + * Rhs +--------+--------+ + * Lhs | | | + * + * +--------+ - - - - +-----------------+ + * |v0[0-8]| | v4[0-4]| v5[0-4]| + * |v1[0-8]| | v6[0-4]| v7[0-4]| + * |v2[0-8]| | v8[0-4]| v9[0-4]| + * |v3[0-8]| |v10[0-4]|v11[0-4]| + * +--------+ - - - - +-----------------+ + * + * Accumulator + */ +static void kern_4x8(const uint8_t* packA, const uint8_t* packB, int K, + int32_t* output, int LDC, bool is_first_k, size_t m_remain, + uint8_t za, uint8_t zb) { + K /= 8; + const uint8_t* a_ptr = packA; + const uint8_t* b_ptr = packB; + + LDC = LDC * sizeof(int32_t); + size_t x0 = 0; + +// clang-format off +#define LOAD_LINE(reg_index1, reg_index2, reg_index3, reg_index4, n) \ + "cmp %[x0], #0 \n" \ + "beq 100f\n" \ + "vld1.32 {d" reg_index1 ", d" reg_index2 ", d" reg_index3 ", d" \ + reg_index4 "}, [r" n "]!\n" \ + "subs %[x0], %[x0], #1\n" + +#define LOAD_C \ + "mov %[x0], %[m_remain]\n" \ + LOAD_LINE("8", "9", "10", "11", "0") \ + LOAD_LINE("12", "13", "14", "15", "1") \ + LOAD_LINE("16", "17", "18", "19", "2") \ + LOAD_LINE("20", "21", "22", "23", "3") \ + "100:\n" + +#define STORE_LINE(reg_index1, reg_index2, reg_index3, reg_index4, n) \ + "cmp %[x0], #0 \n" \ + "beq 101f\n" \ + "vst1.32 {d" reg_index1 ", d" reg_index2 ", d" reg_index3 ", d" \ + reg_index4 "}, [r" n "]!\n" \ + "subs %[x0], %[x0], #1\n" + +#define STORE_C \ + "mov %[x0], %[m_remain]\n" \ + STORE_LINE("8", "9", "10", "11", "0") \ + STORE_LINE("12", "13", "14", "15", "1") \ + STORE_LINE("16", "17", "18", "19", "2") \ + STORE_LINE("20", "21", "22", "23", "3") \ + "101:\n" + + // clang-format on + + register int32_t* outptr asm("r0") = output; + asm volatile( + // load accumulator C + "add r1, r0, %[LDC]\n" + "add r2, r1, %[LDC]\n" + "add r3, r2, %[LDC]\n" + "vdup.8 d28, %[za]\n" + "vdup.8 d30, %[zb]\n" + "cmp %[is_first_k], #1\n" + "beq 1f\n" LOAD_C + + "b 2f\n" + + "1:\n" + "veor.s32 q4, q4, q4\n" + "veor.s32 q5, q5, q5\n" + "veor.s32 q6, q6, q6\n" + "veor.s32 q7, q7, q7\n" + "veor.s32 q8, q8, q8\n" + "veor.s32 q9, q9, q9\n" + "veor.s32 q10, q10, q10\n" + "veor.s32 q11, q11, q11\n" + + "2: \n" + "vld1.8 {d24}, [%[b_ptr]]!\n" + "vld1.8 {d0}, [%[a_ptr]]!\n" + "vld1.8 {d2}, [%[a_ptr]]!\n" + "vld1.8 {d4}, [%[a_ptr]]!\n" + "vld1.8 {d6}, [%[a_ptr]]!\n" + "vsubl.u8 q12, d24, d30\n" + "vsubl.u8 q0, d0, d28\n" + "vsubl.u8 q1, d2, d28\n" + "vsubl.u8 q2, d4, d28\n" + "vsubl.u8 q3, d6, d28\n" + + "vld1.8 {d26}, [%[b_ptr]]!\n" + "vmlal.s16 q4, d24, d0[0]\n" + "vmlal.s16 q6, d24, d2[0]\n" + "vmlal.s16 q8, d24, d4[0]\n" + "vmlal.s16 q10, d24, d6[0]\n" + "vsubl.u8 q13, d26, d30\n" + "vmlal.s16 q5, d25, d0[0]\n" + "vmlal.s16 q7, d25, d2[0]\n" + "vmlal.s16 q9, d25, d4[0]\n" + "vmlal.s16 q11, d25, d6[0]\n" + + "vld1.8 {d24}, [%[b_ptr]]!\n" + "vmlal.s16 q4, d26, d0[1]\n" + "vmlal.s16 q6, d26, d2[1]\n" + "vmlal.s16 q8, d26, d4[1]\n" + "vmlal.s16 q10, d26, d6[1]\n" + "vsubl.u8 q12, d24, d30\n" + "vmlal.s16 q5, d27, d0[1]\n" + "vmlal.s16 q7, d27, d2[1]\n" + "vmlal.s16 q9, d27, d4[1]\n" + "vmlal.s16 q11, d27, d6[1]\n" + + "vld1.8 {d26}, [%[b_ptr]]!\n" + "vmlal.s16 q4, d24, d0[2]\n" + "vmlal.s16 q6, d24, d2[2]\n" + "vmlal.s16 q8, d24, d4[2]\n" + "vmlal.s16 q10, d24, d6[2]\n" + "vsubl.u8 q13, d26, d30\n" + "vmlal.s16 q5, d25, d0[2]\n" + "vmlal.s16 q7, d25, d2[2]\n" + "vmlal.s16 q9, d25, d4[2]\n" + "vmlal.s16 q11, d25, d6[2]\n" + + "vld1.8 {d24}, [%[b_ptr]]!\n" + "vmlal.s16 q4, d26, d0[3]\n" + "vmlal.s16 q6, d26, d2[3]\n" + "vmlal.s16 q8, d26, d4[3]\n" + "vmlal.s16 q10, d26, d6[3]\n" + "vsubl.u8 q12, d24, 
d30\n" + "vmlal.s16 q5, d27, d0[3]\n" + "vmlal.s16 q7, d27, d2[3]\n" + "vmlal.s16 q9, d27, d4[3]\n" + "vmlal.s16 q11, d27, d6[3]\n" + + "vld1.8 {d26}, [%[b_ptr]]!\n" + "vmlal.s16 q4, d24, d1[0]\n" + "vmlal.s16 q6, d24, d3[0]\n" + "vmlal.s16 q8, d24, d5[0]\n" + "vmlal.s16 q10, d24, d7[0]\n" + "vsubl.u8 q13, d26, d30\n" + "vmlal.s16 q5, d25, d1[0]\n" + "vmlal.s16 q7, d25, d3[0]\n" + "vmlal.s16 q9, d25, d5[0]\n" + "vmlal.s16 q11, d25, d7[0]\n" + + "vld1.8 {d24}, [%[b_ptr]]!\n" + "vmlal.s16 q4, d26, d1[1]\n" + "vmlal.s16 q6, d26, d3[1]\n" + "vmlal.s16 q8, d26, d5[1]\n" + "vmlal.s16 q10, d26, d7[1]\n" + "vsubl.u8 q12, d24, d30\n" + "vmlal.s16 q5, d27, d1[1]\n" + "vmlal.s16 q7, d27, d3[1]\n" + "vmlal.s16 q9, d27, d5[1]\n" + "vmlal.s16 q11, d27, d7[1]\n" + + "vld1.8 {d26}, [%[b_ptr]]!\n" + "vmlal.s16 q4, d24, d1[2]\n" + "vmlal.s16 q6, d24, d3[2]\n" + "vmlal.s16 q8, d24, d5[2]\n" + "vmlal.s16 q10, d24, d7[2]\n" + "vsubl.u8 q13, d26, d30\n" + "vmlal.s16 q5, d25, d1[2]\n" + "vmlal.s16 q7, d25, d3[2]\n" + "vmlal.s16 q9, d25, d5[2]\n" + "vmlal.s16 q11, d25, d7[2]\n" + + "vmlal.s16 q4, d26, d1[3]\n" + "vmlal.s16 q6, d26, d3[3]\n" + "vmlal.s16 q8, d26, d5[3]\n" + "vmlal.s16 q10, d26, d7[3]\n" + "vmlal.s16 q5, d27, d1[3]\n" + "vmlal.s16 q7, d27, d3[3]\n" + "vmlal.s16 q9, d27, d5[3]\n" + "vmlal.s16 q11, d27, d7[3]\n" + + "subs %[K], %[K], #1\n" + "bne 2b\n" + + "3:\n" STORE_C + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [x0] "+r"(x0), [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), + [m_remain] "+r"(m_remain), [za] "+r"(za), [zb] "+r"(zb), + [outptr] "+r"(outptr) + : + : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", + "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", + "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28", + "d29", "d30", "d31", "r1", "r2", "r3", "cc", "memory"); + +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +} + +/** + * Overview of register layout: + * + * A 8x4x8 cell of Rhs is stored in 8bit in q8-q9 + * A 8x8x4 cell of Lhs is stored in 8bit in q0-q3 + * A 4x4 block of accumulators is stored in 32bit in q4-q7 + * zero_point_A is stored in 8bit in q10 + * zero_point_B is stored in 8bit in q11. 
+ * + * +--------+ + * | v8[0-4]| + * Rhs +--------+ + * | v9[0-4]| + * Lhs +--------+ + * + * +--------+ - - - - +--------+ + * |v0[0-8]| | v4[0-4]| + * |v1[0-8]| | v5[0-4]| + * |v2[0-8]| | v6[0-4]| + * |v3[0-8]| | v7[0-4]| + * +--------+ - - - - +--------+ + * + * Accumulator + */ +static void kern_4x4(const uint8_t* packA, const uint8_t* packB, int K, + int32_t* output, int LDC, bool is_first_k, size_t m_remain, + size_t n_remain, uint8_t za, uint8_t zb) { + K /= 8; + const uint8_t* a_ptr = packA; + const uint8_t* b_ptr = packB; + + LDC = LDC * sizeof(int32_t); + size_t x0 = 0; + +// clang-format off +#define LOAD_LINE(reg_index1, reg_index2, n) \ + "cmp %[x0], #0 \n" \ + "beq 102f\n" \ + "cmp %[n_remain], #4\n" \ + "blt 100" n "f\n" \ + "vld1.32 {d" reg_index1 ", d" reg_index2 "}, [r" n " ]!\n" \ + "b 101" n "f\n" \ + "100" n ":\n" \ + "cmp %[n_remain], #0\n" \ + "beq 101" n "f\n" \ + "vld1.32 {d" reg_index1 "[0]}, [r" n " ]!\n" \ + "cmp %[n_remain], #1\n" \ + "beq 101" n "f\n" \ + "vld1.32 {d" reg_index1 "[1]}, [r" n " ]!\n" \ + "cmp %[n_remain], #2\n" \ + "beq 101" n "f\n" \ + "vld1.32 {d" reg_index2 "[0]}, [r" n " ]!\n" \ + "101" n ":\n" \ + "subs %[x0], %[x0], #1\n" + +#define LOAD_C \ + "mov %[x0], %[m_remain]\n" \ + "mov r1, r0\n" \ + LOAD_LINE("8", "9", "1") \ + "add r1, r0, %[LDC]\n" \ + "add r0, r0, %[LDC]\n" \ + LOAD_LINE("10", "11", "1") \ + "add r1, r0, %[LDC]\n" \ + "add r0, r0, %[LDC]\n" \ + LOAD_LINE("12", "13", "1") \ + "add r1, r0, %[LDC]\n" \ + LOAD_LINE("14", "15", "1") \ + "102:\n" + +#define STORE_LINE(reg_index1, reg_index2, n) \ + "cmp %[x0], #0 \n" \ + "beq 105f\n" \ + "cmp %[n_remain], #4\n" \ + "blt 103" n "f\n" \ + "vst1.32 {d" reg_index1 ", d" reg_index2 "}, [r" n " ]!\n" \ + "b 104" n "f\n" \ + "103" n ":\n" \ + "cmp %[n_remain], #0\n" \ + "beq 104" n "f\n" \ + "vst1.32 {d" reg_index1 "[0]}, [r" n " ]!\n" \ + "cmp %[n_remain], #1\n" \ + "beq 104" n "f\n" \ + "vst1.32 {d" reg_index1 "[1]}, [r" n " ]!\n" \ + "cmp %[n_remain], #2\n" \ + "beq 104" n "f\n" \ + "vst1.32 {d" reg_index2 "[0]}, [r" n " ]!\n" \ + "104" n ":\n" \ + "subs %[x0], %[x0], #1\n" + +#define STORE_C \ + "mov %[x0], %[m_remain]\n" \ + "mov r1, r0\n" \ + STORE_LINE("8", "9", "1") \ + "add r1, r0, %[LDC]\n" \ + "add r0, r0, %[LDC]\n" \ + STORE_LINE("10", "11", "1") \ + "add r1, r0, %[LDC]\n" \ + "add r0, r0, %[LDC]\n" \ + STORE_LINE("12", "13", "1") \ + "add r1, r0, %[LDC]\n" \ + STORE_LINE("14", "15", "1") \ + "105:\n" + + // clang-format on + + register int32_t* outptr asm("r0") = output; + asm volatile( + // load accumulator C + "vdup.8 d20, %[za]\n" + "vdup.8 d22, %[zb]\n" + "cmp %[is_first_k], #1\n" + "beq 1f\n" LOAD_C + + "b 2f\n" + + "1:\n" + "veor.s32 q4, q4, q4\n" + "veor.s32 q5, q5, q5\n" + "veor.s32 q6, q6, q6\n" + "veor.s32 q7, q7, q7\n" + + "2: \n" + "vld1.32 {d16[0]}, [%[b_ptr]]!\n" + "vld1.8 {d0}, [%[a_ptr]]!\n" + "vld1.8 {d2}, [%[a_ptr]]!\n" + "vld1.8 {d4}, [%[a_ptr]]!\n" + "vld1.8 {d6}, [%[a_ptr]]!\n" + "vsubl.u8 q8, d16, d22\n" + "vsubl.u8 q0, d0, d20\n" + "vsubl.u8 q1, d2, d20\n" + "vsubl.u8 q2, d4, d20\n" + "vsubl.u8 q3, d6, d20\n" + + "vld1.32 {d18[0]}, [%[b_ptr]]!\n" + "vmlal.s16 q4, d16, d0[0]\n" + "vmlal.s16 q5, d16, d2[0]\n" + "vsubl.u8 q9, d18, d22\n" + "vmlal.s16 q6, d16, d4[0]\n" + "vmlal.s16 q7, d16, d6[0]\n" + + "vld1.32 {d16[0]}, [%[b_ptr]]!\n" + "vmlal.s16 q4, d18, d0[1]\n" + "vmlal.s16 q5, d18, d2[1]\n" + "vsubl.u8 q8, d16, d22\n" + "vmlal.s16 q6, d18, d4[1]\n" + "vmlal.s16 q7, d18, d6[1]\n" + + "vld1.32 {d18[0]}, [%[b_ptr]]!\n" + "vmlal.s16 q4, d16, d0[2]\n" + 
"vmlal.s16 q5, d16, d2[2]\n" + "vsubl.u8 q9, d18, d22\n" + "vmlal.s16 q6, d16, d4[2]\n" + "vmlal.s16 q7, d16, d6[2]\n" + + "vld1.32 {d16[0]}, [%[b_ptr]]!\n" + "vmlal.s16 q4, d18, d0[3]\n" + "vmlal.s16 q5, d18, d2[3]\n" + "vsubl.u8 q8, d16, d22\n" + "vmlal.s16 q6, d18, d4[3]\n" + "vmlal.s16 q7, d18, d6[3]\n" + + "vld1.32 {d18[0]}, [%[b_ptr]]!\n" + "vmlal.s16 q4, d16, d1[0]\n" + "vmlal.s16 q5, d16, d3[0]\n" + "vsubl.u8 q9, d18, d22\n" + "vmlal.s16 q6, d16, d5[0]\n" + "vmlal.s16 q7, d16, d7[0]\n" + + "vld1.32 {d16[0]}, [%[b_ptr]]!\n" + "vmlal.s16 q4, d18, d1[1]\n" + "vmlal.s16 q5, d18, d3[1]\n" + "vsubl.u8 q8, d16, d22\n" + "vmlal.s16 q6, d18, d5[1]\n" + "vmlal.s16 q7, d18, d7[1]\n" + + "vld1.32 {d18[0]}, [%[b_ptr]]!\n" + "vmlal.s16 q4, d16, d1[2]\n" + "vmlal.s16 q5, d16, d3[2]\n" + "vsubl.u8 q9, d18, d22\n" + "vmlal.s16 q6, d16, d5[2]\n" + "vmlal.s16 q7, d16, d7[2]\n" + + "vmlal.s16 q4, d18, d1[3]\n" + "vmlal.s16 q5, d18, d3[3]\n" + "vmlal.s16 q6, d18, d5[3]\n" + "vmlal.s16 q7, d18, d7[3]\n" + + "subs %[K], %[K], #1\n" + "bne 2b\n" + + "3:\n" STORE_C + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [outptr] "+r"(outptr), + [K] "+r"(K), [is_first_k] "+r"(is_first_k), [LDC] "+r"(LDC), + [x0] "+r"(x0), [m_remain] "+r"(m_remain), + [n_remain] "+r"(n_remain), [za] "+r"(za), [zb] "+r"(zb) + : + : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", + "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", + "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28", + "d29", "d30", "d31", "r1", "cc", "memory"); + +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +} + +static void gemm_u8_4x8_pack_A_n(dt_uint8* outptr, const dt_uint8* inptr, + int ldin, int y0, int ymax, int k0, int kmax, + uint8_t zero_point) { + uint8_t zerobuff[16]; + std::fill(zerobuff, zerobuff + 16, zero_point); + + int y = y0; + for (; y < ymax; y += 4) { + const uint8_t* inptr0 = inptr + y * ldin + k0; + const uint8_t* inptr1 = inptr0 + ldin; + const uint8_t* inptr2 = inptr1 + ldin; + const uint8_t* inptr3 = inptr2 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + + int K = kmax - k0; + for (; K > 15; K -= 16) { + if (y + 3 >= ymax) { + switch (y + 3 - ymax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + interleave_4x8_2_b(inptr0, inptr1, inptr2, inptr3, outptr); + } + + if (K > 0) { + if (y + 3 >= ymax) { + switch (y + 3 - ymax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 8, K, + zero_point); + } + } +} + +static void gemm_u8_4x8_transpose_pack_A_n(dt_uint8* out, const dt_uint8* in, + int ldin, int x0, int xmax, int k0, + int kmax, uint8_t zero_point) { + uint8_t zerobuff[16]; + std::fill(zerobuff, zerobuff + 16, zero_point); + const int ksize = kmax - k0; + const int ksize4 = round_up(ksize, 8) * 4; + uint8_t* outptr = out; + uint8_t* outptr_base = out; + + int k = k0; + for (; k < kmax; k += 8) { + const uint8_t* inptr0 = in + k * ldin + x0; + const uint8_t* inptr1 = inptr0 + ldin; + const uint8_t* inptr2 = inptr1 + ldin; + const uint8_t* inptr3 = inptr2 + ldin; + const uint8_t* inptr4 = inptr3 + ldin; + const uint8_t* inptr5 = inptr4 + ldin; + const uint8_t* inptr6 = inptr5 + ldin; + const uint8_t* inptr7 = inptr6 + ldin; + prefetch_2x(inptr0); + prefetch_2x(inptr1); + 
prefetch_2x(inptr2); + prefetch_2x(inptr3); + prefetch_2x(inptr4); + prefetch_2x(inptr5); + prefetch_2x(inptr6); + prefetch_2x(inptr7); + + int x = x0; + outptr = outptr_base; + + for (; x + 3 < xmax; x += 4) { + if (k + 7 >= kmax) { + switch (k + 7 - kmax) { + case 6: + inptr1 = zerobuff; + case 5: + inptr2 = zerobuff; + case 4: + inptr3 = zerobuff; + case 3: + inptr4 = zerobuff; + case 2: + inptr5 = zerobuff; + case 1: + inptr6 = zerobuff; + case 0: + inptr7 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + transpose_4x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, + inptr6, inptr7, outptr); + outptr += ksize4; + } + + if (x < xmax) { + if (k + 7 >= kmax) { + switch (k + 7 - kmax) { + case 6: + inptr1 = zerobuff; + case 5: + inptr2 = zerobuff; + case 4: + inptr3 = zerobuff; + case 3: + inptr4 = zerobuff; + case 2: + inptr5 = zerobuff; + case 1: + inptr6 = zerobuff; + case 0: + inptr7 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, + inptr7, outptr, 4, xmax - x, zero_point); + } + + outptr_base += 4 * 8; + } +} + +static void gemm_u8_4x8_pack_B_n(dt_uint8* out, const dt_uint8* in, int ldin, + int x0, int xmax, int k0, int kmax, + uint8_t zero_point) { + uint8_t zerobuff[16]; + std::fill(zerobuff, zerobuff + 16, zero_point); + const int ksize = kmax - k0; + const int ksize4 = round_up(ksize, 8) * 4; + const int ksize8 = ksize4 * 2; + uint8_t* outptr = out; + uint8_t* outptr_base = out; + uint8_t* outptr_interleave = nullptr; + uint8_t* outptr_base4 = out + ((xmax - x0) / 8) * ksize8; + + int k = k0; + for (; k < kmax; k += 8) { + const uint8_t* inptr0 = in + k * ldin + x0; + const uint8_t* inptr1 = inptr0 + ldin; + const uint8_t* inptr2 = inptr1 + ldin; + const uint8_t* inptr3 = inptr2 + ldin; + const uint8_t* inptr4 = inptr3 + ldin; + const uint8_t* inptr5 = inptr4 + ldin; + const uint8_t* inptr6 = inptr5 + ldin; + const uint8_t* inptr7 = inptr6 + ldin; + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + prefetch_2x(inptr4); + prefetch_2x(inptr5); + prefetch_2x(inptr6); + prefetch_2x(inptr7); + + int x = x0; + outptr = outptr_base; + + for (; x + 7 < xmax; x += 8) { + if (k + 7 >= kmax) { + switch (k + 7 - kmax) { + case 6: + inptr1 = zerobuff; + case 5: + inptr2 = zerobuff; + case 4: + inptr3 = zerobuff; + case 3: + inptr4 = zerobuff; + case 2: + inptr5 = zerobuff; + case 1: + inptr6 = zerobuff; + case 0: + inptr7 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + outptr_interleave = outptr; + interleave_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, + inptr6, inptr7, outptr_interleave); + outptr += ksize8; + } + + outptr = outptr_base4; + for (; x + 3 < xmax; x += 4) { + if (k + 7 >= kmax) { + switch (k + 7 - kmax) { + case 6: + inptr1 = zerobuff; + case 5: + inptr2 = zerobuff; + case 4: + inptr3 = zerobuff; + case 3: + inptr4 = zerobuff; + case 2: + inptr5 = zerobuff; + case 1: + inptr6 = zerobuff; + case 0: + inptr7 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + outptr_interleave = outptr; + interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, + inptr7, outptr_interleave, 4, 4, zero_point); + outptr += ksize4; + } + + if (x < xmax) { + if (k + 7 >= kmax) { + switch (k + 7 - kmax) { + case 6: + inptr1 = zerobuff; + case 5: + inptr2 = zerobuff; + case 4: + inptr3 = zerobuff; + case 3: + inptr4 = zerobuff; + case 2: + inptr5 = zerobuff; + case 1: + inptr6 = zerobuff; + case 0: + inptr7 = 
zerobuff; + break; + default: + megdnn_assert(0); + } + } + + outptr_interleave = outptr; + interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, + inptr7, outptr_interleave, 4, xmax - x, zero_point); + } + + outptr_base += 8 * 8; + outptr_base4 += 4 * 8; + } +} + +static void gemm_u8_4x8_transpose_pack_B_n(dt_uint8* outptr, + const dt_uint8* inptr, int ldin, + int y0, int ymax, int k0, int kmax, + uint8_t zero_point) { + uint8_t zerobuff[16]; + std::fill(zerobuff, zerobuff + 16, zero_point); + constexpr int interleave4 = 32; + constexpr int interleave8 = 64; + + int y = y0; + for (; y + 7 < ymax; y += 8) { + const uint8_t* inptr0 = inptr + y * ldin + k0; + const uint8_t* inptr1 = inptr0 + ldin; + const uint8_t* inptr2 = inptr1 + ldin; + const uint8_t* inptr3 = inptr2 + ldin; + const uint8_t* inptr4 = inptr3 + ldin; + const uint8_t* inptr5 = inptr4 + ldin; + const uint8_t* inptr6 = inptr5 + ldin; + const uint8_t* inptr7 = inptr6 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + prefetch_2x(inptr4); + prefetch_2x(inptr5); + prefetch_2x(inptr6); + prefetch_2x(inptr7); + + int K = kmax - k0; + for (; K > 7; K -= 8) { + transpose_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, + inptr6, inptr7, outptr); + outptr += interleave8; + } + + if (K > 0) { + transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, + inptr7, outptr, 8, K, zero_point); + outptr += interleave8; + } + } + + for (; y < ymax; y += 4) { + const uint8_t* inptr0 = inptr + y * ldin + k0; + const uint8_t* inptr1 = inptr0 + ldin; + const uint8_t* inptr2 = inptr1 + ldin; + const uint8_t* inptr3 = inptr2 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + + int K = kmax - k0; + for (; K > 7; K -= 8) { + if (y + 3 >= ymax) { + switch (y + 3 - ymax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + transpose_8x4_1_b(inptr0, inptr1, inptr2, inptr3, outptr); + outptr += interleave4; + } + + if (K > 0) { + if (y + 3 >= ymax) { + switch (y + 3 - ymax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + transpose_4(inptr0, inptr1, inptr2, inptr3, outptr, 8, K, + zero_point); + outptr += interleave4; + } + } +} + +} // namespace matmul_4x8x8 +} // namespace armv7 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/armv7/matrix_mul/quint8/kernel_dot_4x8x4.h b/dnn/src/armv7/matrix_mul/quint8/kernel_dot_4x8x4.h new file mode 100644 index 00000000..6d92cebd --- /dev/null +++ b/dnn/src/armv7/matrix_mul/quint8/kernel_dot_4x8x4.h @@ -0,0 +1,739 @@ +/** + * \file dnn/src/armv7/matrix_mul/quint8/kernel_dot_4x8x4.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#if __ARM_FEATURE_DOTPROD + +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/armv7/matrix_mul/asm/common.h" + +namespace megdnn { +namespace armv7 { +namespace matmul_dot_4x8x4 { + +// Overview of register layout: +// +// A 8x4 cell of Rhs is stored in 8bit in q2-q3. 
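The dot-product kernels in this file take a different route from the 4x8x8 kernel above: vudot accumulates raw uint8 dot products, and the zero points are removed afterwards using the expansion sum_k (a_k - zA)(b_k - zB) = sum_k a_k*b_k - zB*sum_k a_k - zA*sum_k b_k + K*zA*zB. The zero-point-scaled running sums are what q12/q13 (zA times column sums of B) and q1 (zB times row sums of A) carry, and zAB supplies the constant K*zA*zB. A scalar sketch of the final correction step, with hypothetical names:

    #include <cstdint>
    #include <cstddef>

    // After raw uint8 dot-product accumulation, remove the zero-point terms:
    //   C[m][n] = raw[m][n] - zB*rowsum_A[m] - zA*colsum_B[n] + K*zA*zB
    static void apply_zero_point_correction(int32_t* C, const int32_t* raw,
                                            const uint32_t* rowsum_A,  // length M
                                            const uint32_t* colsum_B,  // length N
                                            size_t M, size_t N, size_t K,
                                            uint8_t zA, uint8_t zB) {
        const int32_t zAB = int32_t(K) * zA * zB;
        for (size_t m = 0; m < M; ++m)
            for (size_t n = 0; n < N; ++n)
                C[m * N + n] = raw[m * N + n]
                             - int32_t(zB) * int32_t(rowsum_A[m])
                             - int32_t(zA) * int32_t(colsum_B[n]) + zAB;
    }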
+// A 4x4 ping-pong cell of Lhs is stored in 8bit in q0-q1 +// A 4x8 block of accumulators is stored in 8bit in q4-q11 +// A 3x4 sum zero point ZA,ZB,ZAB stroed in q1,q12,q13 +// +// +--------+--------+ +// |q2[0-16]|q3[0-16]| +// Rhs +--------+--------+ +// +// | | | +// +// Lhs | | | +// +// +-------+-------+ - - - - +--------+--------+ +// |q0[0-4]| | q4[0-4]| q5[0-4]| +// |q0[0-4]| | q6[0-4]| q7[0-4]| +// |q0[0-4]| | q8[0-4]| q9[0-4]| +// |q0[0-4]| |q10[0-4]|q11[0-4]| +// +-------+-------+ - - - - +--------+--------+--------+ +// +// Accumulator + +static void kern_4x8(const uint8_t* packA, const uint8_t* packB, int K, + int32_t* output, int LDC, bool is_first_k, uint8_t zA, + uint8_t zB, uint32_t zAB, size_t m_remain = 4) { + K /= 4; + const uint8_t* a_ptr = packA; + const uint8_t* b_ptr = packB; + // Fix up for odd lengths - set a flag if K is odd, but make + // sure we round up the iteration count. + int oddk = (K & 1); + int k = K / 2; + + register uint8x16_t za asm("q14") = vdupq_n_u8(zA); + register uint8x16_t zb asm("q15") = vdupq_n_u8(zB); + + register int32_t* outptr0 asm("r0") = output; + register int32_t* outptr1 asm("r1") = outptr0 + LDC; + register int32_t* outptr2 asm("r2") = outptr1 + LDC; + register int32_t* outptr3 asm("r3") = outptr2 + LDC; + +// clang-format off +#define LOAD_LINE(reg_index1, reg_index2, reg_index3, reg_index4, n) \ + "cmp r12, #0 \n" \ + "beq 100f\n" \ + "vld1.32 {d" reg_index1 ", d" reg_index2 ", d" reg_index3 ", d" \ + reg_index4 "}, [r" n "]\n" \ + "subs r12, r12, #1\n" + +#define LOAD_C \ + "mov r12, %[m_remain]\n" \ + LOAD_LINE("8", "9", "10", "11", "0") \ + LOAD_LINE("12", "13", "14", "15", "1") \ + LOAD_LINE("16", "17", "18", "19", "2") \ + LOAD_LINE("20", "21", "22", "23", "3") \ + "100:\n" + +#define STORE_LINE(reg_index1, reg_index2, reg_index3, reg_index4, n) \ + "cmp r12, #0 \n" \ + "beq 101f\n" \ + "vst1.32 {d" reg_index1 ", d" reg_index2 ", d" reg_index3 ", d" \ + reg_index4 "}, [r" n "]!\n" \ + "subs r12, r12, #1\n" + +#define STORE_C \ + "mov r12, %[m_remain]\n" \ + STORE_LINE("8", "9", "10", "11", "0") \ + STORE_LINE("12", "13", "14", "15", "1") \ + STORE_LINE("16", "17", "18", "19", "2") \ + STORE_LINE("20", "21", "22", "23", "3") \ + "101:\n" + + // clang-format on + + asm volatile( + "pld [%[a_ptr]] \n" + "pld [%[b_ptr]] \n" + + "cmp %[is_first_k], #1 \n" + "beq 5f \n" + "cmp %[m_remain], #4 \n" + "beq 7f \n" LOAD_C + "b 6f \n" + + "7:\n" + "vld1.s32 {q4, q5}, [%[outptr0]]\n" + "vld1.s32 {q6, q7}, [%[outptr1]]\n" + "vld1.s32 {q8, q9}, [%[outptr2]]\n" + "vld1.s32 {q10, q11}, [%[outptr3]]\n" + "b 6f \n" + + "5:\n" + "veor.s32 q4, q4, q4\n" + "veor.s32 q5, q5, q5\n" + "veor.s32 q6, q6, q6\n" + "veor.s32 q7, q7, q7\n" + "veor.s32 q8, q8, q8\n" + "veor.s32 q9, q9, q9\n" + "veor.s32 q10, q10, q10\n" + "veor.s32 q11, q11, q11\n" + + "6: \n" + "veor.s32 q12, q12, q12\n" + "veor.s32 q13, q13, q13\n" + "veor.s32 q1, q1, q1\n" + + "vld1.u8 {q0}, [%[a_ptr]]!\n" + "vld1.u8 {q2}, [%[b_ptr]]!\n" + + // Skip loop if we are doing zero iterations of it. 
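For reference, each vudot.u8 used below performs four independent 4-element uint8 dot products: destination lane i accumulates the dot product of the i-th group of four bytes of the first source register with the four bytes selected by the lane index from the second source. A scalar model of one accumulator lane, as an illustration only:

    #include <cstdint>

    // One 32-bit lane of vudot.u8: acc += dot(b[0..3], a[0..3]) on uint8 inputs.
    static inline uint32_t udot_lane(uint32_t acc, const uint8_t b[4],
                                     const uint8_t a[4]) {
        for (int i = 0; i < 4; ++i)
            acc += uint32_t(b[i]) * uint32_t(a[i]);
        return acc;
    }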
+ "cmp %[k], #0 \n" + "beq 4f \n" + + // Loop proper + "1:\n" + "vudot.u8 q12, q2, %[za] \n" + "vld1.u8 {q3}, [%[b_ptr]]!\n" + "vudot.u8 q1, q0, %[zb] \n" + "vudot.u8 q4 , q2, d0[0]\n" + "vudot.u8 q6 , q2, d0[1]\n" + "vudot.u8 q8 , q2, d1[0]\n" + "vudot.u8 q10 , q2, d1[1]\n" + + "vudot.u8 q5, q3, d0[0]\n" + "vudot.u8 q7, q3, d0[1]\n" + "vld1.u8 {q2}, [%[b_ptr]]!\n" + "vudot.u8 q9, q3, d1[0]\n" + "vudot.u8 q11, q3, d1[1]\n" + + "vld1.u8 {q0}, [%[a_ptr]]!\n" + "vudot.u8 q13, q3, %[za] \n" + /////////////////////////////////////// + "vudot.u8 q12, q2, %[za] \n" + "vld1.u8 {q3}, [%[b_ptr]]!\n" + "vudot.u8 q1, q0, %[zb] \n" + "vudot.u8 q4 , q2, d0[0]\n" + "vudot.u8 q6 , q2, d0[1]\n" + "vudot.u8 q8 , q2, d1[0]\n" + "vudot.u8 q10 , q2, d1[1]\n" + + "vudot.u8 q5, q3, d0[0]\n" + "vudot.u8 q7, q3, d0[1]\n" + "vld1.u8 {q2}, [%[b_ptr]]!\n" + "vudot.u8 q9, q3, d1[0]\n" + "vudot.u8 q11, q3, d1[1]\n" + + "pld [%[b_ptr]] \n" + "subs %[k], %[k], #1\n" + "pld [%[a_ptr]] \n" + + "vld1.u8 {q0}, [%[a_ptr]]!\n" + "vudot.u8 q13, q3, %[za] \n" + "bne 1b\n" + + "4:\n" + // Branch to alternative tail for even K + "cmp %[oddk], #0 \n" + "beq 2f \n" + + "vudot.u8 q12, q2, %[za] \n" + "vld1.u8 {q3}, [%[b_ptr]]!\n" + "vudot.u8 q1, q0, %[zb] \n" + "vudot.u8 q4 , q2, d0[0]\n" + "vudot.u8 q6 , q2, d0[1]\n" + "vudot.u8 q8 , q2, d1[0]\n" + "vudot.u8 q10 , q2, d1[1]\n" + + "vudot.u8 q5, q3, d0[0]\n" + "vudot.u8 q7, q3, d0[1]\n" + "vudot.u8 q9, q3, d1[0]\n" + "vudot.u8 q11, q3, d1[1]\n" + "vudot.u8 q13, q3, %[za] \n" + + "2:\n" + "vdup.s32 q2, %[zab]\n" + "vsub.s32 q1, q1, q2 \n" // sub zab + + "vdup.s32 q3, d2[1]\n" + "vdup.s32 q2, d2[0]\n" + "vsub.s32 q6, q6, q3\n" + "vsub.s32 q7, q7, q3\n" + "vsub.s32 q4, q4, q2\n" + "vsub.s32 q5, q5, q2\n" + "vsub.s32 q6, q6, q12\n" + "vsub.s32 q7, q7, q13\n" + "vsub.s32 q4, q4, q12\n" + "vsub.s32 q5, q5, q13\n" + + "vdup.s32 q2, d3[0]\n" + "vdup.s32 q3, d3[1]\n" + "vsub.s32 q8, q8, q2\n" + "vsub.s32 q9, q9, q2\n" + "vsub.s32 q10, q10, q3\n" + "vsub.s32 q11, q11, q3\n" + "vsub.s32 q8, q8, q12\n" + "vsub.s32 q9, q9, q13\n" + "vsub.s32 q10, q10, q12\n" + "vsub.s32 q11, q11, q13\n" + + STORE_C + + : [k] "+r"(k), [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), + [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), + [m_remain] "+r"(m_remain), [za] "+w"(za), [zb] "+w"(zb), + [zab] "+r"(zAB), [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1), + [outptr2] "+r"(outptr2), [outptr3] "+r"(outptr3) + : + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", + "q11", "q12", "q13", "cc", "r12", "memory"); +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +} +// Overview of register layout: +// +// A 4x4 cell of Rhs is stored in 8bit in q2-q3. 
+// A 4x4 ping-pong cell of Lhs is stored in 8bit in q0-q1 +// A 4x8 block of accumulators is stored in 8bit in q4-q10 +// A 2x4 sum zero point ZA,ZB,ZAB stroed in q1,q12 +// +// +--------+ +// |q2[0-16]| +// Rhs +--------+ +// +// | | +// +// Lhs | | +// +// +-------+-------+ - - - - +--------+ +// |q0[0-4]| | q4[0-4]| +// |q0[0-4]| | q6[0-4]| +// |q0[0-4]| | q8[0-4]| +// |q0[0-4]| |q10[0-4]| +// +-------+-------+ - - - - +--------+--------+--------+ +// +// Accumulator +static void kern_4x4(const uint8_t* packA, const uint8_t* packB, int K, + int32_t* output, int LDC, bool is_first_k, uint8_t zA, + uint8_t zB, uint32_t zAB, size_t m_remain = 4, + size_t n_remain = 4) { + K /= 4; + const uint8_t* a_ptr = packA; + const uint8_t* b_ptr = packB; + // Fix up for odd lengths - set a flag if K is odd, but make + // sure we round up the iteration count. + int oddk = (K & 1); + int k = K / 2; + LDC = LDC * sizeof(int32_t); + + register uint8x16_t za asm("q14") = vdupq_n_u8(zA); + register uint8x16_t zb asm("q15") = vdupq_n_u8(zB); + + register int32_t* outptr0 asm("r2") = output; + size_t x0 = 0; + + // clang-format off +#define LOAD_LINE(reg_index1, reg_index2, n) \ + "cmp %[x0], #0 \n" \ + "beq 102f\n" \ + "cmp %[n_remain], #4\n" \ + "blt 100" n "f\n" \ + "vld1.32 {d" reg_index1 ", d" reg_index2 "}, [r" n " ]!\n" \ + "b 101" n "f\n" \ + "100" n ":\n" \ + "cmp %[n_remain], #0\n" \ + "beq 101" n "f\n" \ + "vld1.32 {d" reg_index1 "[0]}, [r" n " ]!\n" \ + "cmp %[n_remain], #1\n" \ + "beq 101" n "f\n" \ + "vld1.32 {d" reg_index1 "[1]}, [r" n " ]!\n" \ + "cmp %[n_remain], #2\n" \ + "beq 101" n "f\n" \ + "vld1.32 {d" reg_index2 "[0]}, [r" n " ]!\n" \ + "101" n ":\n" \ + "subs %[x0], %[x0], #1\n" + +#define LOAD_C \ + "mov %[x0], %[m_remain]\n" \ + "mov r0, r2\n" \ + "mov r1, r0\n" \ + LOAD_LINE("8", "9", "1") \ + "add r1, r0, %[LDC]\n" \ + "add r0, r0, %[LDC]\n" \ + LOAD_LINE("12", "13", "1") \ + "add r1, r0, %[LDC]\n" \ + "add r0, r0, %[LDC]\n" \ + LOAD_LINE("16", "17", "1") \ + "add r1, r0, %[LDC]\n" \ + LOAD_LINE("20", "21", "1") \ + "102:\n" + +#define STORE_LINE(reg_index1, reg_index2, n) \ + "cmp %[x0], #0 \n" \ + "beq 105f\n" \ + "cmp %[n_remain], #4\n" \ + "blt 103" n "f\n" \ + "vst1.32 {d" reg_index1 ", d" reg_index2 "}, [r" n " ]!\n" \ + "b 104" n "f\n" \ + "103" n ":\n" \ + "cmp %[n_remain], #0\n" \ + "beq 104" n "f\n" \ + "vst1.32 {d" reg_index1 "[0]}, [r" n " ]!\n" \ + "cmp %[n_remain], #1\n" \ + "beq 104" n "f\n" \ + "vst1.32 {d" reg_index1 "[1]}, [r" n " ]!\n" \ + "cmp %[n_remain], #2\n" \ + "beq 104" n "f\n" \ + "vst1.32 {d" reg_index2 "[0]}, [r" n " ]!\n" \ + "104" n ":\n" \ + "subs %[x0], %[x0], #1\n" + +#define STORE_C \ + "mov %[x0], %[m_remain]\n" \ + "mov r1, r2\n" \ + "mov r0, r2\n" \ + STORE_LINE("8", "9", "1") \ + "add r1, r0, %[LDC]\n" \ + "add r0, r0, %[LDC]\n" \ + STORE_LINE("12", "13", "1") \ + "add r1, r0, %[LDC]\n" \ + "add r0, r0, %[LDC]\n" \ + STORE_LINE("16", "17", "1") \ + "add r1, r0, %[LDC]\n" \ + STORE_LINE("20", "21", "1") \ + "105:\n" + + // clang-format on + + asm volatile( + "pld [%[a_ptr]] \n" + "pld [%[b_ptr]] \n" + + "cmp %[is_first_k], #1 \n" + "beq 5f \n" LOAD_C + "b 6f \n" + + "5:\n" + "veor.s32 q4, q4, q4\n" + "veor.s32 q6, q6, q6\n" + "veor.s32 q8, q8, q8\n" + "veor.s32 q10, q10, q10\n" + + "6: \n" + "veor.s32 q12, q12, q12\n" + "veor.s32 q1, q1, q1\n" + + "vld1.u8 {q0}, [%[a_ptr]]!\n" + "vld1.u8 {q2}, [%[b_ptr]]!\n" + + // Skip loop if we are doing zero iterations of it. 
+ "cmp %[k], #0 \n" + "beq 4f \n" + + // Loop proper + "1:\n" + "vudot.u8 q12, q2, %[za]\n" + "vld1.u8 {q3}, [%[b_ptr]]!\n" + "vudot.u8 q1, q0, %[zb] \n" + "vudot.u8 q4, q2, d0[0]\n" + "vld1.u8 {q5}, [%[a_ptr]]!\n" + "vudot.u8 q6 , q2, d0[1]\n" + "vudot.u8 q8 , q2, d1[0]\n" + "vudot.u8 q10 , q2, d1[1]\n" + + /////////////////////////////////////// + "vudot.u8 q12, q3, %[za] \n" + "vudot.u8 q1, q5, %[zb] \n" + "vld1.u8 {q2}, [%[b_ptr]]!\n" + "vudot.u8 q4 , q3, d10[0]\n" + "vudot.u8 q6 , q3, d10[1]\n" + "vudot.u8 q8 , q3, d11[0]\n" + "vudot.u8 q10 , q3, d11[1]\n" + "vld1.u8 {q0}, [%[a_ptr]]!\n" + + "pld [%[b_ptr]] \n" + "subs %[k], %[k], #1\n" + "pld [%[a_ptr]] \n" + "bne 1b\n" + + "4:\n" + // Branch to alternative tail for even K + "cmp %[oddk], #0 \n" + "beq 2f \n" + + "vudot.u8 q12, q2, %[za]\n" + "vudot.u8 q1, q0, %[zb] \n" + "vudot.u8 q4, q2, d0[0]\n" + "vudot.u8 q6 , q2, d0[1]\n" + "vudot.u8 q8 , q2, d1[0]\n" + "vudot.u8 q10 , q2, d1[1]\n" + + "2:\n" + "vdup.s32 q2, %[zab]\n" + "vsub.s32 q1, q1, q2 \n" // sub zab + + "vdup.s32 q3, d2[1]\n" + "vdup.s32 q2, d2[0]\n" + "vdup.s32 q5, d3[0]\n" + "vdup.s32 q7, d3[1]\n" + + "vsub.s32 q4, q4, q2\n" + "vsub.s32 q6, q6, q3\n" + "vsub.s32 q8, q8, q5\n" + "vsub.s32 q10, q10, q7\n" + + "vsub.s32 q4, q4, q12\n" + "vsub.s32 q6, q6, q12\n" + "vsub.s32 q8, q8, q12\n" + "vsub.s32 q10, q10, q12\n" STORE_C + + : [k] "+r"(k), [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), + [oddk] "+r"(oddk), [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), + [za] "+w"(za), [zb] "+w"(zb), [zab] "+r"(zAB), + [outptr0] "+r"(outptr0), [m_remain] "+r"(m_remain), + [n_remain] "+r"(n_remain), [x0] "+r"(x0) + : + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", + "q11", "q12", "q13", "r0", "r1", "cc", "memory"); +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +} + + +static void gemm_quint8_4x8_pack_A_n(dt_uint8* outptr, const dt_uint8* inptr, + int ldin, int y0, int ymax, int k0, int kmax) { + uint8_t zerobuff[16]; + std::memset(zerobuff, 0, sizeof(uint8_t) * 16); + + int y = y0; + for (; y < ymax; y += 4) { + const uint8_t* inptr0 = inptr + y * ldin + k0; + const uint8_t* inptr1 = inptr0 + ldin; + const uint8_t* inptr2 = inptr1 + ldin; + const uint8_t* inptr3 = inptr2 + ldin; + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + int K = kmax - k0; + for (; K > 15; K -= 16) { + if (y + 3 >= ymax) { + switch (y + 3 - ymax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + interleave_4x4_4_b(inptr0, inptr1, inptr2, inptr3, outptr); + } + if (K > 0) { + if (y + 3 >= ymax) { + switch (y + 3 - ymax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 4, K); + } + } +} + +static void gemm_quint8_4x8_pack_A_t(dt_uint8* out, const dt_uint8* in, int ldin, + int x0, int xmax, int k0, int kmax) { + uint8_t zerobuff[16]; + std::memset(zerobuff, 0, sizeof(uint8_t) * 16); + const int ksize = kmax - k0; + const int ksize4 = round_up(ksize, 4) * 4; + uint8_t* outptr = out; + uint8_t* outptr_base = out; + + int k = k0; + for (; k < kmax; k += 4) { + const uint8_t* inptr0 = in + k * ldin + x0; + const uint8_t* inptr1 = inptr0 + ldin; + const uint8_t* inptr2 = inptr1 + ldin; + const uint8_t* inptr3 = inptr2 + ldin; + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + 
prefetch_2x(inptr3); + + int x = x0; + outptr = outptr_base; + for (; x + 4 < xmax; x += 4) { + if (k + 3 >= kmax) { + switch (k + 3 - kmax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + transpose_4x4_1_b(inptr0, inptr1, inptr2, inptr3, outptr); + outptr += ksize4; + } + if (x < xmax) { + if (k + 3 >= kmax) { + switch (k + 3 - kmax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + transpose_4(inptr0, inptr1, inptr2, inptr3, outptr, 4, xmax - x); + } + outptr_base += 4 * 4; + } +} + +static void gemm_quint8_4x8_pack_B_n(dt_uint8* out, const dt_uint8* in, + int ldin, int x0, int xmax, int k0, + int kmax) { + uint8_t zerobuff[16]; + std::memset(zerobuff, 0, sizeof(uint8_t) * 16); + const int ksize = kmax - k0; + const int ksize8 = round_up(ksize, 4) * 8; + const int ksize4 = round_up(ksize, 4) * 4; + uint8_t* outptr = out; + uint8_t* outptr_base = out; + //! 4x4 block output start pos + uint8_t* outptr_base4 = out + ((xmax - x0) / 8) * ksize8; + + int k = k0; + for (; k < kmax; k += 4) { + const uint8_t* inptr0 = in + k * ldin + x0; + const uint8_t* inptr1 = inptr0 + ldin; + const uint8_t* inptr2 = inptr1 + ldin; + const uint8_t* inptr3 = inptr2 + ldin; + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + + int x = x0; + outptr = outptr_base; + for (; x + 7 < xmax; x += 8) { + if (k + 3 >= kmax) { + switch (k + 3 - kmax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + transpose_8x4_1_b(inptr0, inptr1, inptr2, inptr3, outptr); + outptr += ksize8; + } + + outptr = outptr_base4; + for (; x + 3 < xmax; x += 4) { + if (k + 3 >= kmax) { + switch (k + 3 - kmax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + transpose_4(inptr0, inptr1, inptr2, inptr3, outptr, 4, 4); + outptr += ksize4; + } + + if (x < xmax) { + if (k + 3 >= kmax) { + switch (k + 3 - kmax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + transpose_4(inptr0, inptr1, inptr2, inptr3, outptr, 4, xmax - x); + } + + outptr_base += 8 * 4; + outptr_base4 += 4 * 4; + } +} + +static void gemm_quint8_4x8_pack_B_t(dt_uint8* outptr, const dt_uint8* inptr, + int ldin, int y0, int ymax, int k0, + int kmax) { + uint8_t zerobuff[16]; + std::memset(zerobuff, 0, sizeof(uint8_t) * 16); + + int y = y0; + for (; y + 7 < ymax; y += 8) { + const uint8_t* inptr0 = inptr + y * ldin + k0; + const uint8_t* inptr1 = inptr0 + ldin; + const uint8_t* inptr2 = inptr1 + ldin; + const uint8_t* inptr3 = inptr2 + ldin; + const uint8_t* inptr4 = inptr3 + ldin; + const uint8_t* inptr5 = inptr4 + ldin; + const uint8_t* inptr6 = inptr5 + ldin; + const uint8_t* inptr7 = inptr6 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + prefetch_2x(inptr4); + prefetch_2x(inptr5); + prefetch_2x(inptr6); + prefetch_2x(inptr7); + + int K = kmax - k0; + //! 
read 12 * 4 in each row + for (; K > 15; K -= 16) { + interleave_8x4_4_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, + inptr6, inptr7, outptr); + } + if (K > 0) { + interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, + inptr7, outptr, 4, K); + } + } + for (; y < ymax; y += 4) { + const uint8_t* inptr0 = inptr + y * ldin + k0; + const uint8_t* inptr1 = inptr0 + ldin; + const uint8_t* inptr2 = inptr1 + ldin; + const uint8_t* inptr3 = inptr2 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + + int K = kmax - k0; + //! read 4 * 4 in each row + for (; K > 15; K -= 16) { + if (y + 3 >= ymax) { + switch (y + 3 - ymax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + interleave_4x4_4_b(inptr0, inptr1, inptr2, inptr3, outptr); + } + + if (K > 0) { + if (y + 3 >= ymax) { + switch (y + 3 - ymax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 4, K); + } + } +} + +} // namespace matmul_dot_4x8x4 +} // namespace armv7 +} // namespace megdnn +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/armv7/matrix_mul/quint8/strategy.cpp b/dnn/src/armv7/matrix_mul/quint8/strategy.cpp new file mode 100644 index 00000000..acbcb4db --- /dev/null +++ b/dnn/src/armv7/matrix_mul/quint8/strategy.cpp @@ -0,0 +1,183 @@ +/** + * \file dnn/src/armv7/matrix_mul/quint8/strategy.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#include "src/armv7/matrix_mul/quint8/strategy.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/armv7/matrix_mul/asm/common.h" +#include "src/armv7/matrix_mul/quint8/kernel_4x8x8.h" +#include "src/armv7/matrix_mul/quint8/kernel_dot_4x8x4.h" +#include "src/common/utils.h" + +using namespace megdnn; +using namespace armv7; +using namespace armv7::matmul; + +MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_4x8); + +void gemm_u8_4x8::pack_A(dt_uint8* outptr, const dt_uint8* inptr, int ldin, + int y0, int ymax, int k0, int kmax, + bool transpose) const { + uint8_t zA = A_dtype.param().zero_point; + if (transpose) { + matmul_4x8x8::gemm_u8_4x8_transpose_pack_A_n(outptr, inptr, ldin, y0, + ymax, k0, kmax, zA); + } else { + matmul_4x8x8::gemm_u8_4x8_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, + kmax, zA); + } +} + +void gemm_u8_4x8::pack_B(dt_uint8* out, const dt_uint8* in, int ldin, int x0, + int xmax, int k0, int kmax, bool transpose) const { + uint8_t zB = B_dtype.param().zero_point; + if (transpose) { + matmul_4x8x8::gemm_u8_4x8_transpose_pack_B_n(out, in, ldin, x0, xmax, + k0, kmax, zB); + } else { + matmul_4x8x8::gemm_u8_4x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax, + zB); + } +} + +void gemm_u8_4x8::kern(const dt_uint8* packA, const dt_uint8* packB, size_t M, + size_t N, size_t K, dt_int32* C, size_t LDC, + bool is_first_k, const dt_int32*, dt_int32*) const { + megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && + A_dtype.enumv() == DTypeEnum::Quantized8Asymm && + C_dtype.enumv() == DTypeEnum::QuantizedS32, + "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), + C_dtype.name()); + uint8_t zA = A_dtype.param().zero_point; + uint8_t zB = B_dtype.param().zero_point; + + constexpr size_t A_INTERLEAVE = 4; + constexpr size_t B_INTERLEAVE = 8; + //! 
K is packed to times of 8 + K = round_up(K, 8); + const int K8 = K * 8; + const int K4 = K * 4; + + size_t m = 0; + for (; m < M; m += A_INTERLEAVE) { + int32_t* output = C + (m * LDC); + const dt_uint8* cur_packB = packB; + size_t n = 0; + for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { + matmul_4x8x8::kern_4x8(packA, cur_packB, K, output, LDC, is_first_k, + std::min(M - m, 4), zA, zB); + output += B_INTERLEAVE; + cur_packB += K8; + } + + for (; n < N; n += 4) { + matmul_4x8x8::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k, + std::min(M - m, 4), + std::min(N - n, 4), zA, zB); + output += 4; + cur_packB += K4; + } + packA += K4; + } +} + +#if __ARM_FEATURE_DOTPROD +// ===========================gemm_dot_quint8_4x8====================================== +MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_dot_quint8_4x8); +void gemm_dot_quint8_4x8::pack_A(dt_uint8* out, const dt_uint8* in, int ldin, + int y0, int ymax, int k0, int kmax, + bool transpose) const { + if (transpose) { + matmul_dot_4x8x4::gemm_quint8_4x8_pack_A_t(out, in, ldin, y0, ymax, k0, + kmax); + } else { + matmul_dot_4x8x4::gemm_quint8_4x8_pack_A_n(out, in, ldin, y0, ymax, k0, + kmax); + } +} + +void gemm_dot_quint8_4x8::pack_B(dt_uint8* out, const dt_uint8* in, int ldin, int x0, + int xmax, int k0, int kmax, bool transpose) const { + if (transpose) { + matmul_dot_4x8x4::gemm_quint8_4x8_pack_B_t(out, in, ldin, x0, xmax, k0, + kmax); + } else { + matmul_dot_4x8x4::gemm_quint8_4x8_pack_B_n(out, in, ldin, x0, xmax, k0, + kmax); + } +} + +void gemm_dot_quint8_4x8::kern(const dt_uint8* packA, const dt_uint8* packB, + size_t M, size_t N, size_t K, dt_int32* C, + size_t LDC, bool is_first_k, const dt_int32*, + dt_int32* workspace) const { + megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && + A_dtype.enumv() == DTypeEnum::Quantized8Asymm && + C_dtype.enumv() == DTypeEnum::QuantizedS32, + "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), + C_dtype.name()); + uint8_t zA = A_dtype.param().zero_point; + uint8_t zB = B_dtype.param().zero_point; + const uint32_t zAB = + static_cast(zA) * static_cast(zB) * K; + + constexpr size_t A_INTERLEAVE = 4; + constexpr size_t B_INTERLEAVE = 8; + K = round_up(K, 4); + const int K8 = K * 8; + const int K4 = K * 4; + + size_t m = 0; + for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) { + int32_t* output = C + (m * LDC); + const dt_uint8* cur_packB = packB; + size_t n = 0; + for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { + matmul_dot_4x8x4::kern_4x8(packA, cur_packB, K, output, LDC, + is_first_k, zA, zB, zAB); + output += B_INTERLEAVE; + cur_packB += K8; + } + for (; n < N; n += 4) { + size_t n_remain = std::min(N - n, 4); + matmul_dot_4x8x4::kern_4x4(packA, cur_packB, K, output, LDC, + is_first_k, zA, zB, zAB, 4, n_remain); + output += n_remain; + cur_packB += K4; + } + packA += K4; + } + if(m(M - m, 4); + size_t n = 0; + for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { + matmul_dot_4x8x4::kern_4x8(packA, cur_packB, K, output, LDC, + is_first_k, zA, zB, zAB, m_remain); + output += B_INTERLEAVE; + cur_packB += K8; + } + + for (; n < N; n += 4) { + size_t n_remain = std::min(N - n, 4); + matmul_dot_4x8x4::kern_4x4(packA, cur_packB, K, output, LDC, + is_first_k, zA, zB, zAB, m_remain, + n_remain); + output += n_remain; + cur_packB += K4; + } + packA += K4; + } +} + +#endif + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/armv7/matrix_mul/quint8/strategy.h b/dnn/src/armv7/matrix_mul/quint8/strategy.h new file mode 100644 index 00000000..d732e47e --- /dev/null +++ 
b/dnn/src/armv7/matrix_mul/quint8/strategy.h @@ -0,0 +1,29 @@ +/** + * \file dnn/src/armv7/matrix_mul/quint8/strategy.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once +#include "src/fallback/matrix_mul/gemm_common.h" + +namespace megdnn { +namespace armv7 { +namespace matmul { + +MEGDNN_REG_GEMM_STRATEGY(dt_uint8, dt_int32, dt_int32, 4, 8, 8, false, true, + gemm_u8_4x8); +#if __ARM_FEATURE_DOTPROD +MEGDNN_REG_GEMM_STRATEGY(dt_uint8, dt_int32, dt_int32, 4, 8, 4, false, false, + gemm_dot_quint8_4x8); +#endif + +} // namespace matmul +} // namespace armv7 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/armv7/relayout/opr_impl.cpp b/dnn/src/armv7/relayout/opr_impl.cpp new file mode 100644 index 00000000..68a7bc2d --- /dev/null +++ b/dnn/src/armv7/relayout/opr_impl.cpp @@ -0,0 +1,151 @@ +/** + * \file dnn/src/armv7/relayout/opr_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "src/common/utils.h" +#include "src/common/relayout_helper.h" + +#include "src/armv7/handle.h" +#include "src/armv7/relayout/opr_impl.h" + +using namespace megdnn; +using namespace relayout; + +namespace { +struct TransposeByte { + uint8_t v; +}; + +void trans_16x16_u8(const void* src, void* dst, const size_t src_step, + const size_t dst_step) { + // 16x16 + asm volatile( + "\n" + "vld1.8 {d0, d1}, [%[src]], %[src_step] \n" + "vld1.8 {d2, d3}, [%[src]], %[src_step] \n" + "vld1.8 {d4, d5}, [%[src]], %[src_step] \n" + "vld1.8 {d6, d7}, [%[src]], %[src_step] \n" + "vld1.8 {d8, d9}, [%[src]], %[src_step] \n" + "vld1.8 {d10, d11}, [%[src]], %[src_step] \n" + "vld1.8 {d12, d13}, [%[src]], %[src_step] \n" + "vld1.8 {d14, d15}, [%[src]], %[src_step] \n" + "vld1.8 {d16, d17}, [%[src]], %[src_step] \n" + "vld1.8 {d18, d19}, [%[src]], %[src_step] \n" + "vld1.8 {d20, d21}, [%[src]], %[src_step] \n" + "vld1.8 {d22, d23}, [%[src]], %[src_step] \n" + "vld1.8 {d24, d25}, [%[src]], %[src_step] \n" + "vld1.8 {d26, d27}, [%[src]], %[src_step] \n" + "vld1.8 {d28, d29}, [%[src]], %[src_step] \n" + "vld1.8 {d30, d31}, [%[src]], %[src_step] \n" + "vtrn.8 q0, q1 \n" + "vtrn.8 q2, q3 \n" + "vtrn.8 q4, q5 \n" + "vtrn.8 q6, q7 \n" + "vtrn.8 q8, q9 \n" + "vtrn.8 q10, q11 \n" + "vtrn.8 q12, q13 \n" + "vtrn.8 q14, q15 \n" + "vtrn.16 q0, q2 \n" + "vtrn.16 q1, q3 \n" + "vtrn.16 q4, q6 \n" + "vtrn.16 q5, q7 \n" + "vtrn.16 q8, q10 \n" + "vtrn.16 q9, q11 \n" + "vtrn.16 q12, q14 \n" + "vtrn.16 q13, q15 \n" + "vtrn.32 q0, q4 \n" + "vtrn.32 q1, q5 \n" + "vtrn.32 q2, q6 \n" + "vtrn.32 q3, q7 \n" + "vtrn.32 q8, q12 \n" + "vtrn.32 q9, q13 \n" + "vtrn.32 q10, q14 \n" + "vtrn.32 q11, q15 \n" + "vswp d1, d16 \n" + "vswp d3, d18 \n" + "vswp d5, d20 \n" + "vswp d7, d22 \n" + "vswp d9, d24 \n" + "vswp d11, d26 \n" + "vswp d13, d28 \n" + "vswp d15, d30 \n" + "vst1.8 {d0, d1}, [%[dst]], %[dst_step] \n" + "vst1.8 {d2, d3}, [%[dst]], %[dst_step] \n" + "vst1.8 {d4, d5}, 
[%[dst]], %[dst_step] \n" + "vst1.8 {d6, d7}, [%[dst]], %[dst_step] \n" + "vst1.8 {d8, d9}, [%[dst]], %[dst_step] \n" + "vst1.8 {d10, d11}, [%[dst]], %[dst_step] \n" + "vst1.8 {d12, d13}, [%[dst]], %[dst_step] \n" + "vst1.8 {d14, d15}, [%[dst]], %[dst_step] \n" + "vst1.8 {d16, d17}, [%[dst]], %[dst_step] \n" + "vst1.8 {d18, d19}, [%[dst]], %[dst_step] \n" + "vst1.8 {d20, d21}, [%[dst]], %[dst_step] \n" + "vst1.8 {d22, d23}, [%[dst]], %[dst_step] \n" + "vst1.8 {d24, d25}, [%[dst]], %[dst_step] \n" + "vst1.8 {d26, d27}, [%[dst]], %[dst_step] \n" + "vst1.8 {d28, d29}, [%[dst]], %[dst_step] \n" + "vst1.8 {d30, d31}, [%[dst]], %[dst_step] \n" + : + [src] "+r" (src), + [dst] "+r" (dst) + : + [src_step] "r" (src_step), + [dst_step] "r" (dst_step) + : + "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", + "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "d20", + "d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28", "d29", "d30", + "d31"); +} + +} // anonymous namespace + +namespace megdnn { +namespace relayout { +namespace transpose_fallback { +template <> +struct transpose_traits { + static constexpr size_t block_size = 16; +}; + +template <> +void transpose_block(const TransposeByte* src, + TransposeByte* dst, const size_t src_stride, + const size_t dst_stride) { + trans_16x16_u8(src, dst, src_stride, dst_stride); +} + +} // namespace transpose_fallback +} // namespace relayout +} // namespace megdnn + + +void armv7::RelayoutForwardImpl::exec(_megdnn_tensor_in src0, + _megdnn_tensor_out dst0, + Handle* src_handle) { + check_cpu_handle(src_handle); + TensorND src = src0, dst = dst0; + check_layout_and_canonize(src.layout, dst.layout); + + relayout::TransposeParam trans_param; + bool trans = relayout::is_transpose(src.layout, dst.layout, trans_param); + if (trans && trans_param.c == 1 && src0.layout.dtype.size() == 1) { + auto sptr = static_cast(src.raw_ptr), + dptr = static_cast(dst.raw_ptr); + MEGDNN_DISPATCH_CPU_KERN_OPR( + transpose_fallback::transpose( + trans_param.batch, trans_param.m, trans_param.n, sptr, + dptr)); + return; + } + exec_after_preprocess(src, dst, trans ? &trans_param : nullptr); +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/armv7/relayout/opr_impl.h b/dnn/src/armv7/relayout/opr_impl.h new file mode 100644 index 00000000..43f260b1 --- /dev/null +++ b/dnn/src/armv7/relayout/opr_impl.h @@ -0,0 +1,31 @@ +/** + * \file dnn/src/armv7/relayout/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
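trans_16x16_u8 above transposes a 16x16 byte block entirely in registers: vtrn.8, vtrn.16 and vtrn.32 exchange progressively larger sub-elements between register pairs, and the trailing vswp exchanges 64-bit halves, which together realize the full transpose before the strided stores. A scalar reference of the same operation, handy as a correctness oracle when modifying the assembly (name hypothetical):

    #include <cstdint>
    #include <cstddef>

    // Scalar reference: dst[j][i] = src[i][j] over a 16x16 byte tile, with
    // independent row strides for source and destination.
    static void trans_16x16_u8_ref(const uint8_t* src, uint8_t* dst,
                                   size_t src_step, size_t dst_step) {
        for (size_t i = 0; i < 16; ++i)
            for (size_t j = 0; j < 16; ++j)
                dst[j * dst_step + i] = src[i * src_step + j];
    }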
+ */ +#pragma once +#include "megdnn/oprs.h" +#include "src/fallback/relayout/opr_impl.h" + +namespace megdnn { +namespace armv7 { + +class RelayoutForwardImpl final : public fallback::RelayoutForwardImpl { + public: + using fallback::RelayoutForwardImpl::RelayoutForwardImpl; + + void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, + Handle *src_handle) override; + + bool is_thread_safe() const override { return true; } +}; + +} // namespace armv7 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/armv7/rotate/opr_impl.cpp b/dnn/src/armv7/rotate/opr_impl.cpp new file mode 100644 index 00000000..8114194d --- /dev/null +++ b/dnn/src/armv7/rotate/opr_impl.cpp @@ -0,0 +1,325 @@ +/** + * \file dnn/src/armv7/rotate/opr_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include + +#include "src/armv7/rotate/opr_impl.h" +#include "src/armv7/handle.h" +#include "src/common/cv/common.h" +#include "src/common/cv/helper.h" +#include "src/common/utils.h" + +namespace megdnn { + +namespace megcv { +void rotate_8uc1_clockwise_16x16(const uchar *src, + uchar *dst, + size_t src_step, size_t dst_step) +{ + asm volatile ("\n" + "vld1.8 {d0, d1}, [%[src]], %[src_step] \n" + "vld1.8 {d2, d3}, [%[src]], %[src_step] \n" + "vld1.8 {d4, d5}, [%[src]], %[src_step] \n" + "vld1.8 {d6, d7}, [%[src]], %[src_step] \n" + "vld1.8 {d8, d9}, [%[src]], %[src_step] \n" + "vld1.8 {d10, d11}, [%[src]], %[src_step] \n" + "vld1.8 {d12, d13}, [%[src]], %[src_step] \n" + "vld1.8 {d14, d15}, [%[src]], %[src_step] \n" + "vld1.8 {d16, d17}, [%[src]], %[src_step] \n" + "vld1.8 {d18, d19}, [%[src]], %[src_step] \n" + "vld1.8 {d20, d21}, [%[src]], %[src_step] \n" + "vld1.8 {d22, d23}, [%[src]], %[src_step] \n" + "vld1.8 {d24, d25}, [%[src]], %[src_step] \n" + "vld1.8 {d26, d27}, [%[src]], %[src_step] \n" + "vld1.8 {d28, d29}, [%[src]], %[src_step] \n" + "vld1.8 {d30, d31}, [%[src]], %[src_step] \n" + "vtrn.8 q0, q1 \n" + "vtrn.8 q2, q3 \n" + "vtrn.8 q4, q5 \n" + "vtrn.8 q6, q7 \n" + "vtrn.8 q8, q9 \n" + "vtrn.8 q10, q11 \n" + "vtrn.8 q12, q13 \n" + "vtrn.8 q14, q15 \n" + "vtrn.16 q0, q2 \n" + "vtrn.16 q1, q3 \n" + "vtrn.16 q4, q6 \n" + "vtrn.16 q5, q7 \n" + "vtrn.16 q8, q10 \n" + "vtrn.16 q9, q11 \n" + "vtrn.16 q12, q14 \n" + "vtrn.16 q13, q15 \n" + "vtrn.32 q0, q4 \n" + "vtrn.32 q1, q5 \n" + "vtrn.32 q2, q6 \n" + "vtrn.32 q3, q7 \n" + "vtrn.32 q8, q12 \n" + "vtrn.32 q9, q13 \n" + "vtrn.32 q10, q14 \n" + "vtrn.32 q11, q15 \n" + "vswp d1, d16 \n" + "vswp d3, d18 \n" + "vswp d5, d20 \n" + "vswp d7, d22 \n" + "vswp d9, d24 \n" + "vswp d11, d26 \n" + "vswp d13, d28 \n" + "vswp d15, d30 \n" + "vswp d0, d1 \n" + "vswp d2, d3 \n" + "vswp d4, d5 \n" + "vswp d6, d7 \n" + "vswp d8, d9 \n" + "vswp d10, d11 \n" + "vswp d12, d13 \n" + "vswp d14, d15 \n" + "vswp d16, d17 \n" + "vswp d18, d19 \n" + "vswp d20, d21 \n" + "vswp d22, d23 \n" + "vswp d24, d25 \n" + "vswp d26, d27 \n" + "vswp d28, d29 \n" + "vswp d30, d31 \n" + "vrev64.8 q0, q0\n" + "vrev64.8 q1, q1\n" + "vrev64.8 q2, q2\n" + "vrev64.8 q3, q3\n" + "vrev64.8 q4, q4\n" + "vrev64.8 q5, q5\n" + "vrev64.8 q6, q6\n" + "vrev64.8 q7, q7\n" + "vrev64.8 q8, q8\n" + "vrev64.8 q9, q9\n" + "vrev64.8 q10, q10\n" + "vrev64.8 q11, q11\n" + "vrev64.8 
q12, q12\n" + "vrev64.8 q13, q13\n" + "vrev64.8 q14, q14\n" + "vrev64.8 q15, q15\n" + "vst1.8 {d0, d1}, [%[dst]], %[dst_step] \n" + "vst1.8 {d2, d3}, [%[dst]], %[dst_step] \n" + "vst1.8 {d4, d5}, [%[dst]], %[dst_step] \n" + "vst1.8 {d6, d7}, [%[dst]], %[dst_step] \n" + "vst1.8 {d8, d9}, [%[dst]], %[dst_step] \n" + "vst1.8 {d10, d11}, [%[dst]], %[dst_step] \n" + "vst1.8 {d12, d13}, [%[dst]], %[dst_step] \n" + "vst1.8 {d14, d15}, [%[dst]], %[dst_step] \n" + "vst1.8 {d16, d17}, [%[dst]], %[dst_step] \n" + "vst1.8 {d18, d19}, [%[dst]], %[dst_step] \n" + "vst1.8 {d20, d21}, [%[dst]], %[dst_step] \n" + "vst1.8 {d22, d23}, [%[dst]], %[dst_step] \n" + "vst1.8 {d24, d25}, [%[dst]], %[dst_step] \n" + "vst1.8 {d26, d27}, [%[dst]], %[dst_step] \n" + "vst1.8 {d28, d29}, [%[dst]], %[dst_step] \n" + "vst1.8 {d30, d31}, [%[dst]], %[dst_step] \n" + : + [src] "+r" (src), + [dst] "+r" (dst) + : + [src_step] "r" (src_step), + [dst_step] "r" (dst_step) + : + "r0", "r1", "r2", "r3", + "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", + "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", + "d16", "d17", "d18", "d19", "d20", "d21", "d22", "d23", + "d24", "d25", "d26", "d27", "d28", "d29", "d30", "d31" + ); + +} + +void rotate_8uc1_counterclockwise_16x16(const uchar *src, + uchar *dst, + size_t src_step, size_t dst_step) +{ + asm volatile ("\n" + "vld1.8 {d0, d1}, [%[src]], %[src_step] \n" + "vld1.8 {d2, d3}, [%[src]], %[src_step] \n" + "vld1.8 {d4, d5}, [%[src]], %[src_step] \n" + "vld1.8 {d6, d7}, [%[src]], %[src_step] \n" + "vld1.8 {d8, d9}, [%[src]], %[src_step] \n" + "vld1.8 {d10, d11}, [%[src]], %[src_step] \n" + "vld1.8 {d12, d13}, [%[src]], %[src_step] \n" + "vld1.8 {d14, d15}, [%[src]], %[src_step] \n" + "vld1.8 {d16, d17}, [%[src]], %[src_step] \n" + "vld1.8 {d18, d19}, [%[src]], %[src_step] \n" + "vld1.8 {d20, d21}, [%[src]], %[src_step] \n" + "vld1.8 {d22, d23}, [%[src]], %[src_step] \n" + "vld1.8 {d24, d25}, [%[src]], %[src_step] \n" + "vld1.8 {d26, d27}, [%[src]], %[src_step] \n" + "vld1.8 {d28, d29}, [%[src]], %[src_step] \n" + "vld1.8 {d30, d31}, [%[src]], %[src_step] \n" + "vtrn.8 q0, q1 \n" + "vtrn.8 q2, q3 \n" + "vtrn.8 q4, q5 \n" + "vtrn.8 q6, q7 \n" + "vtrn.8 q8, q9 \n" + "vtrn.8 q10, q11 \n" + "vtrn.8 q12, q13 \n" + "vtrn.8 q14, q15 \n" + "vtrn.16 q0, q2 \n" + "vtrn.16 q1, q3 \n" + "vtrn.16 q4, q6 \n" + "vtrn.16 q5, q7 \n" + "vtrn.16 q8, q10 \n" + "vtrn.16 q9, q11 \n" + "vtrn.16 q12, q14 \n" + "vtrn.16 q13, q15 \n" + "vtrn.32 q0, q4 \n" + "vtrn.32 q1, q5 \n" + "vtrn.32 q2, q6 \n" + "vtrn.32 q3, q7 \n" + "vtrn.32 q8, q12 \n" + "vtrn.32 q9, q13 \n" + "vtrn.32 q10, q14 \n" + "vtrn.32 q11, q15 \n" + "vswp d1, d16 \n" + "vswp d3, d18 \n" + "vswp d5, d20 \n" + "vswp d7, d22 \n" + "vswp d9, d24 \n" + "vswp d11, d26 \n" + "vswp d13, d28 \n" + "vswp d15, d30 \n" + "vst1.8 {d30, d31}, [%[dst]], %[dst_step] \n" + "vst1.8 {d28, d29}, [%[dst]], %[dst_step] \n" + "vst1.8 {d26, d27}, [%[dst]], %[dst_step] \n" + "vst1.8 {d24, d25}, [%[dst]], %[dst_step] \n" + "vst1.8 {d22, d23}, [%[dst]], %[dst_step] \n" + "vst1.8 {d20, d21}, [%[dst]], %[dst_step] \n" + "vst1.8 {d18, d19}, [%[dst]], %[dst_step] \n" + "vst1.8 {d16, d17}, [%[dst]], %[dst_step] \n" + "vst1.8 {d14, d15}, [%[dst]], %[dst_step] \n" + "vst1.8 {d12, d13}, [%[dst]], %[dst_step] \n" + "vst1.8 {d10, d11}, [%[dst]], %[dst_step] \n" + "vst1.8 {d8, d9}, [%[dst]], %[dst_step] \n" + "vst1.8 {d6, d7}, [%[dst]], %[dst_step] \n" + "vst1.8 {d4, d5}, [%[dst]], %[dst_step] \n" + "vst1.8 {d2, d3}, [%[dst]], %[dst_step] \n" + "vst1.8 {d0, d1}, [%[dst]], %[dst_step] 
\n" + : + [src] "+r" (src), + [dst] "+r" (dst) + : + [src_step] "r" (src_step), + [dst_step] "r" (dst_step) + : + "r0", "r1", "r2", "r3", + "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", + "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", + "d16", "d17", "d18", "d19", "d20", "d21", "d22", "d23", + "d24", "d25", "d26", "d27", "d28", "d29", "d30", "d31" + ); +} + +void rotate_8uc1_clockwise(const uchar *src, uchar *dst, + const size_t rows, const size_t cols, + const size_t src_step, const size_t dst_step) +{ + const size_t block = 16; + (void)block; + size_t i = 0; + + for (; i + block <= rows; i += block) { + size_t j = 0; + for (; j + block <= cols; j += block) { + rotate_8uc1_clockwise_16x16(src + i*src_step + j, + dst + j*dst_step + (rows-(i+block)), + src_step, dst_step); + } + for (; j < cols; ++j) { + for (size_t k = 0; k < block; ++k) { + dst[j*dst_step + (rows-1-(i+k))] = src[(i+k)*src_step + j]; + } + } + } + + for (; i < rows; ++i) { + for (size_t j = 0; j < cols; ++j) { + dst[j*dst_step + (rows-1-i)] = src[i*src_step + j]; + } + } +} + +void rotate_8uc1_counterclockwise(const uchar *src, uchar *dst, + const size_t rows, const size_t cols, + const size_t src_step, const size_t dst_step) +{ + const size_t block = 16; + (void)block; + size_t i = 0; + + for (; i + block <= rows; i += block) { + size_t j = 0; + for (; j + block <= cols; j += block) { + rotate_8uc1_counterclockwise_16x16(src + i*src_step + j, + dst + (cols-(j+block))*dst_step + i, + src_step, dst_step); + } + for (; j < cols; ++j) { + for (size_t k = 0; k < block; ++k) { + dst[(cols-1-j)*dst_step + (i+k)] = src[(i+k)*src_step + j]; + } + } + } + + for (; i < rows; ++i) { + for (size_t j = 0; j < cols; ++j) { + dst[(cols-1-j)*dst_step + i] = src[i*src_step + j]; + } + } +} + +void rotate(const Mat &src, Mat &dst, + bool clockwise) +{ + megdnn_assert(src.rows() == dst.cols()); + megdnn_assert(src.cols() == dst.rows()); + megdnn_assert(src.channels() == dst.channels()); + megdnn_assert(src.channels() == 1_z); + if (clockwise) { + rotate_8uc1_clockwise(src.ptr(), dst.ptr(), src.rows(), src.cols(), + src.step(), dst.step()); + } else { + rotate_8uc1_counterclockwise(src.ptr(), dst.ptr(), src.rows(), + src.cols(), src.step(), dst.step()); + } +} + +} // namespace megcv + +namespace armv7 { + +void RotateImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, + _megdnn_workspace workspace) { + using namespace megcv; + check_exec(src.layout, dst.layout, workspace.size); + + //! rotate only support data type is uchar and the channel size is 1 + if (dst.layout.dtype != dtype::Uint8() || src.layout.shape[3] != 1) { + return fallback::RotateImpl::exec(src, dst, workspace); + } + + MEGDNN_DISPATCH_CPU_KERN_OPR({ + for (size_t i = 0; i < src.layout.shape[0]; ++i) { + Mat src_mat = TensorND2Mat(src, i); + Mat dst_mat = TensorND2Mat(dst, i); + rotate(src_mat, dst_mat, param().clockwise); + } + }); + +} + +} // namespace armv7 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/armv7/rotate/opr_impl.h b/dnn/src/armv7/rotate/opr_impl.h new file mode 100644 index 00000000..574d68cd --- /dev/null +++ b/dnn/src/armv7/rotate/opr_impl.h @@ -0,0 +1,35 @@ +/** + * \file dnn/src/armv7/rotate/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once + +#include "megdnn/oprs.h" +#include "src/fallback/rotate/opr_impl.h" + +namespace megdnn { +namespace armv7 { + +class RotateImpl : public fallback::RotateImpl { + public: + using fallback::RotateImpl::RotateImpl; + + void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; + + size_t get_workspace_in_bytes(const TensorLayout&, + const TensorLayout&) override { + return 0; + } +}; + +} // namespace armv7 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/common/handle.cpp b/dnn/src/common/handle.cpp index d9333a9a..ade602db 100644 --- a/dnn/src/common/handle.cpp +++ b/dnn/src/common/handle.cpp @@ -22,6 +22,14 @@ #include "src/x86/handle.h" #endif +#if MEGDNN_ARMV7 +#include "src/armv7/handle.h" +#endif + +#if MEGDNN_AARCH64 +#include "src/aarch64/handle.h" +#endif + #if MEGDNN_WITH_CUDA #include "src/cuda/handle.h" @@ -59,6 +67,10 @@ std::unique_ptr Handle::make(megcoreComputingHandle_t computing_handle, return std::unique_ptr( new x86::HandleImpl(computing_handle)); // return make_unique(computing_handle); +#elif MEGDNN_ARMV7 + return make_unique(computing_handle); +#elif MEGDNN_AARCH64 + return make_unique(computing_handle); #else return make_unique(computing_handle); #endif @@ -142,6 +154,15 @@ std::unique_ptr Handle::make(megcoreComputingHandle_t computing_handle, #if MEGDNN_X86 CASE(X86, x86); #endif +#if MEGDNN_ARMV7 + CASE(ARMV7, armv7); +#endif +#if MEGDNN_AARCH64 + CASE(AARCH64, aarch64); +#endif +#if MEGDNN_ARMV7 || MEGDNN_AARCH64 + CASE(ARM_COMMON, arm_common); +#endif #endif // !MEGDNN_NAIVE #if MEGDNN_WITH_CUDA CASE(CUDA,cuda); diff --git a/dnn/src/common/relayout_helper.h b/dnn/src/common/relayout_helper.h index 5c1e0d02..56d083bd 100644 --- a/dnn/src/common/relayout_helper.h +++ b/dnn/src/common/relayout_helper.h @@ -38,6 +38,8 @@ namespace transpose_fallback { #if MEGDNN_X86 constexpr size_t BLOCK_LINE_SIZE_BYTES = 64; +#elif MEGDNN_AARCH64 || MEGDNN_ARMV7 +constexpr size_t BLOCK_LINE_SIZE_BYTES = 32; #else #error "unknown megdnn arch" #endif diff --git a/dnn/src/common/warp_common.h b/dnn/src/common/warp_common.h index efa719f6..f381a0f1 100644 --- a/dnn/src/common/warp_common.h +++ b/dnn/src/common/warp_common.h @@ -71,6 +71,8 @@ #if MEGDNN_X86 #include +#elif MEGDNN_AARCH64 || MEGDNN_ARMV7 +#include "src/arm_common/simd_macro/marm_neon.h" #endif MIDOUT_DECL(megdnn_warp) @@ -873,6 +875,12 @@ void remap(const Mat& src, Mat& dst, Mat& map1, Mat& map2, d_data = _mm_and_si128(sA_data, v_INTER_TAB_SIZE2); _mm_storeu_si128(dst, d_data); } +#elif MEGDNN_AARCH64 || MEGDNN_ARMV7 + uint16x8_t v_scale = vdupq_n_u16(INTER_TAB_SIZE2 - 1); + for (; x1 <= bcols - 8; x1 += 8) + vst1q_u16(A + x1, + vandq_u16(vld1q_u16(sA + x1), v_scale)); + #endif for (; x1 < bcols; ++x1) A[x1] = (ushort)(sA[x1] & (INTER_TAB_SIZE2 - 1)); diff --git a/dnn/src/fallback/conv_bias/conv1x1/algos.cpp b/dnn/src/fallback/conv_bias/conv1x1/algos.cpp index bad00033..088b8f97 100644 --- a/dnn/src/fallback/conv_bias/conv1x1/algos.cpp +++ b/dnn/src/fallback/conv_bias/conv1x1/algos.cpp @@ -21,6 +21,8 @@ #if MEGDNN_X86 #include "src/x86/conv_bias/postprocess_helper.h" +#elif (MEGDNN_ARMV7 || MEGDNN_AARCH64) +#include "src/arm_common/conv_bias/postprocess_helper.h" #endif #include "midout.h" @@ -52,7 +54,7 @@ size_t 
ConvBiasImpl::AlgoConv1x1::get_workspace( auto matmul_param = get_matmul_kern_param(param, OH * OW, compt_oc_block_size); - + auto pack_mode = m_matmul_algo->packmode(); if (pack_mode == MatrixMulImpl::AlgoBase::PackMode::DEFAULT) { MIDOUT_BEGIN(megdnn_fallback_conv1x1, 0, 0, 0) { diff --git a/dnn/src/fallback/conv_bias/conv1x1/conv1x1_strategy.h b/dnn/src/fallback/conv_bias/conv1x1/conv1x1_strategy.h index 65ca322d..2030d02c 100644 --- a/dnn/src/fallback/conv_bias/conv1x1/conv1x1_strategy.h +++ b/dnn/src/fallback/conv_bias/conv1x1/conv1x1_strategy.h @@ -16,6 +16,8 @@ #include "src/fallback/conv_bias/opr_impl.h" #if MEGDNN_X86 #include "src/x86/conv_bias/postprocess_helper.h" +#elif (MEGDNN_ARMV7 || MEGDNN_AARCH64) +#include "src/arm_common/conv_bias/postprocess_helper.h" #endif namespace megdnn { @@ -97,7 +99,7 @@ public: const ConvBiasImpl::NCBKernSizeParam& param, const ConvBiasImpl::NCBKernParam& ncb_param, const ConvBiasImpl::NCBKernIndex& ncb_index) override { - + if (pack_mode == MatrixMulImpl::AlgoBase::PackMode::NO_PACK) { megdnn_log_error("NoPack mode has no packA kernel"); return; @@ -321,3 +323,5 @@ public: } // namespace conv1x1 } // namespace fallback } // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/conv_bias/im2col/factory.h b/dnn/src/fallback/conv_bias/im2col/factory.h index 31542af1..528917e8 100644 --- a/dnn/src/fallback/conv_bias/im2col/factory.h +++ b/dnn/src/fallback/conv_bias/im2col/factory.h @@ -23,11 +23,19 @@ namespace im2col { enum class StrategyType : uint32_t { FLOAT = 0, +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + FLOAT_FP16 = 1, +#else #if !MEGDNN_DISABLE_FLOAT16 FLOAT16_FLOAT16 = 2, #endif +#endif INT8x8x32 = 3, INT8x8x16 = 4, +#if MEGDNN_AARCH64 || MEGDNN_ARMV7 + QUINT8x8x32 = 5, + QUINT8x8x32x8 = 6, +#endif QINT8x8x32 = 7, QINT8x8x32x8 = 8 }; @@ -130,9 +138,13 @@ public: } cb1(dt_float32, dt_float32, StrategyType::FLOAT); +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + cb1(dt_float16, __fp16, StrategyType::FLOAT_FP16); +#else #if !MEGDNN_DISABLE_FLOAT16 cb1(dt_float16, dt_float16, StrategyType::FLOAT16_FLOAT16); #endif +#endif cb2(dt_int8, dt_int32, dt_int32, dt_int8, dt_int32, dt_int32, StrategyType::INT8x8x32); @@ -140,6 +152,13 @@ public: cb2(dt_int8, dt_int16, dt_int16, dt_int8, dt_int16, dt_int16, StrategyType::INT8x8x16); +#if MEGDNN_AARCH64 || MEGDNN_ARMV7 + cb2(dtype::Quantized8Asymm, dtype::QuantizedS32, dtype::QuantizedS32, + dt_uint8, dt_int32, dt_int32, StrategyType::QUINT8x8x32); + + cb2(dtype::Quantized8Asymm, dtype::QuantizedS32, dtype::Quantized8Asymm, + dt_uint8, dt_int32, dt_uint8, StrategyType::QUINT8x8x32x8); +#endif cb2(dtype::QuantizedS8, dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, dt_int32, dt_int32, StrategyType::QINT8x8x32); @@ -193,6 +212,12 @@ public: cb1(NCHW, DEFAULT, dt_float32, dt_float32, PostprocessMode::FLOAT, "DefaultStrategyType::FLOAT"_hash); break; +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + case StrategyType::FLOAT_FP16: + cb1(NCHW, DEFAULT, dt_float16, __fp16, PostprocessMode::FLOAT, + "DefaultStrategyType::FLOAT_FP16"_hash); + break; +#else #if !MEGDNN_DISABLE_FLOAT16 case StrategyType::FLOAT16_FLOAT16: cb1(NCHW, DEFAULT, dt_float16, dt_float16, @@ -200,6 +225,7 @@ public: "DefaultStrategyType::FLOAT16_FLOAT16"_hash); break; #endif +#endif case StrategyType::INT8x8x32: if (format == param::ConvBias::Format::NCHW) { cb2(NCHW, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8, @@ -220,6 +246,21 @@ public: dt_int16, dt_int16, PostprocessMode::NO_PROCESS, 
"DefaultStrategyType::INT8x8x16"_hash); break; +#if MEGDNN_AARCH64 || MEGDNN_ARMV7 + case StrategyType::QUINT8x8x32: + cb2(NCHW, DEFAULT, dtype::Quantized8Asymm, dtype::QuantizedS32, + dtype::QuantizedS32, dt_uint8, dt_int32, dt_int32, + PostprocessMode::NO_PROCESS, + "DefaultStrategyType::QUINT8x8x32"_hash); + break; + + case StrategyType::QUINT8x8x32x8: + cb2(NCHW, DEFAULT, dtype::Quantized8Asymm, dtype::QuantizedS32, + dtype::Quantized8Asymm, dt_uint8, dt_int32, dt_uint8, + PostprocessMode::QUANTIZED, + "DefaultStrategyType::QUINT8x8x32x8"_hash); + break; +#endif case StrategyType::QINT8x8x32: if (format == param::ConvBias::Format::NCHW) { cb2(NCHW, DEFAULT, dtype::QuantizedS8, dtype::QuantizedS32, @@ -265,6 +306,12 @@ public: cb1(NCHW, NO_PACK, dt_float32, dt_float32, PostprocessMode::FLOAT, "NoPackStrategyType::FLOAT"_hash); break; +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + case StrategyType::FLOAT_FP16: + cb1(NCHW, NO_PACK, dt_float16, __fp16, PostprocessMode::FLOAT, + "NoPackStrategyType::FLOAT_FP16"_hash); + break; +#else #if !MEGDNN_DISABLE_FLOAT16 case StrategyType::FLOAT16_FLOAT16: cb1(NCHW, NO_PACK, dt_float16, dt_float16, @@ -272,6 +319,7 @@ public: "NoPackStrategyType::FLOAT16_FLOAT16"_hash); break; #endif +#endif case StrategyType::INT8x8x32: cb2(NCHW, NO_PACK, dt_int8, dt_int32, dt_int32, dt_int8, dt_int32, dt_int32, PostprocessMode::NO_PROCESS, @@ -284,6 +332,21 @@ public: "NoPackStrategyType::INT8x8x16"_hash); break; +#if MEGDNN_AARCH64 || MEGDNN_ARMV7 + case StrategyType::QUINT8x8x32: + cb2(NCHW, NO_PACK, dtype::Quantized8Asymm, dtype::QuantizedS32, + dtype::QuantizedS32, dt_uint8, dt_int32, dt_int32, + PostprocessMode::NO_PROCESS, + "NoPackStrategyType::QUINT8x8x32"_hash); + break; + + case StrategyType::QUINT8x8x32x8: + cb2(NCHW, NO_PACK, dtype::Quantized8Asymm, dtype::QuantizedS32, + dtype::Quantized8Asymm, dt_uint8, dt_int32, dt_uint8, + PostprocessMode::QUANTIZED, + "NoPackStrategyType::QUINT8x8x32x8"_hash); + break; +#endif case StrategyType::QINT8x8x32: cb2(NCHW, NO_PACK, dtype::QuantizedS8, dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, dt_int32, dt_int32, @@ -312,6 +375,13 @@ public: PostprocessMode::FLOAT, "OnlyPackaStrategyType::FLOAT"_hash); break; +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + case StrategyType::FLOAT_FP16: + cb1(NCHW, ONLY_PACKA, dt_float16, __fp16, + PostprocessMode::FLOAT, + "OnlyPackaStrategyType::FLOAT_FP16"_hash); + break; +#else #if !MEGDNN_DISABLE_FLOAT16 case StrategyType::FLOAT16_FLOAT16: cb1(NCHW, ONLY_PACKA, dt_float16, dt_float16, @@ -319,6 +389,7 @@ public: "OnlyPackaStrategyType::FLOAT16_FLOAT16"_hash); break; #endif +#endif case StrategyType::INT8x8x32: cb2(NCHW, ONLY_PACKA, dt_int8, dt_int32, dt_int32, dt_int8, dt_int32, dt_int32, PostprocessMode::NO_PROCESS, @@ -331,6 +402,21 @@ public: "OnlyPackaStrategyType::INT8x8x16"_hash); break; +#if MEGDNN_AARCH64 || MEGDNN_ARMV7 + case StrategyType::QUINT8x8x32: + cb2(NCHW, ONLY_PACKA, dtype::Quantized8Asymm, + dtype::QuantizedS32, dtype::QuantizedS32, dt_uint8, + dt_int32, dt_int32, PostprocessMode::NO_PROCESS, + "OnlyPackaStrategyType::QUINT8x8x32"_hash); + break; + + case StrategyType::QUINT8x8x32x8: + cb2(NCHW, ONLY_PACKA, dtype::Quantized8Asymm, + dtype::QuantizedS32, dtype::Quantized8Asymm, dt_uint8, + dt_int32, dt_uint8, PostprocessMode::QUANTIZED, + "OnlyPackaStrategyType::QUINT8x8x32x8"_hash); + break; +#endif case StrategyType::QINT8x8x32: cb2(NCHW, ONLY_PACKA, dtype::QuantizedS8, dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, dt_int32, dt_int32, @@ -409,3 +495,5 @@ 
Strategy* StrategyDelegationStorage::get( } // namespace im2col } // namespace fallback } // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/conv_bias/im2col/strategy_default.cpp b/dnn/src/fallback/conv_bias/im2col/strategy_default.cpp index 69854d29..09aeaf5c 100644 --- a/dnn/src/fallback/conv_bias/im2col/strategy_default.cpp +++ b/dnn/src/fallback/conv_bias/im2col/strategy_default.cpp @@ -12,6 +12,8 @@ #include "src/fallback/convolution/img2col_helper.h" #if MEGDNN_X86 #include "src/x86/conv_bias/postprocess_helper.h" +#elif (MEGDNN_ARMV7 || MEGDNN_AARCH64) +#include "src/arm_common/conv_bias/postprocess_helper.h" #endif using namespace megdnn; @@ -346,11 +348,23 @@ void Strategy @@ -314,11 +317,26 @@ void img2col_nchw4(const dtype* __restrict src, dtype* __restrict dst, fh2 = FH - fh - 1; fw2 = FW - fw - 1; } +#if MEGDNN_ARMV7 || MEGDNN_AARCH64 + int w = cur_remain_w; + size_t index = (ic * IH * IW + (start_h + fh2) * IW + + (w + fw2)); + for (; w + 3 < end_remain_w; w += 4) { + vst1q_u32(&output[i], + vld1q_u32(&uint32_src[index])); + i += 4; + index += 4; + } + for (; w < end_remain_w; w++) { + output[i++] = uint32_src[index]; + } +#else for (int w = cur_remain_w; w < end_remain_w; w++) { size_t index = (ic * IH * IW + (start_h + fh2) * IW + (w + fw2)); output[i++] = uint32_src[index]; } +#endif } } } @@ -342,11 +360,27 @@ void img2col_nchw4(const dtype* __restrict src, dtype* __restrict dst, } for (int h = start_h + 1; h < end_h; h++) { +#if MEGDNN_ARMV7 || MEGDNN_AARCH64 + int ow = 0; + size_t index = (ic * IH * IW + (h + fh2) * IW + + (ow + fw2)); + for (; ow + 3 < OW; ow += 4) { + vst1q_u32(&output[i], + vld1q_u32(&uint32_src[index])); + i += 4; + index += 4; + } + + for (; ow < OW; ow++) { + output[i++] = uint32_src[index++]; + } +#else rep(ow, OW) { size_t index = (ic * IH * IW + (h + fh2) * IW + (ow + fw2)); output[i++] = uint32_src[index]; } +#endif } for (int w = 0; w < end_remain_w; w++) { diff --git a/dnn/src/fallback/powc/opr_impl.cpp b/dnn/src/fallback/powc/opr_impl.cpp index e74e2df8..58640e49 100644 --- a/dnn/src/fallback/powc/opr_impl.cpp +++ b/dnn/src/fallback/powc/opr_impl.cpp @@ -13,6 +13,9 @@ #include "src/naive/handle.h" +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#include "src/arm_common/simd_macro/marm_neon.h" +#endif #include @@ -167,6 +170,10 @@ void pow_invoke(const T* src, T* dst, size_t size, ExpFunc expfunc) { dst[i + 2] = b2; dst[i + 3] = b3; } +#if MEGDNN_FIX_AARCH32_BUG + // FIXME: as llvm may cause cannot select error if enable vectorize + #pragma clang loop vectorize(disable) +#endif for (; i < size; ++i) { dst[i] = expfunc.apply(src[i]); } @@ -254,9 +261,14 @@ void PowCImpl::do_exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, #if !MEGDNN_DISABLE_FLOAT16 case DTypeTrait::enumv: +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return MEGDNN_INC_FLOAT16( + do_exec_ct<__fp16>(src, dst, exp_f, exp_i)); +#else return MEGDNN_INC_FLOAT16( do_exec_ct(src, dst, exp_f, exp_i)); #endif +#endif default: megdnn_throw("unsupported dtype for PowC"); } diff --git a/dnn/src/fallback/type_cvt/opr_impl.cpp b/dnn/src/fallback/type_cvt/opr_impl.cpp index c63645b5..4c952130 100644 --- a/dnn/src/fallback/type_cvt/opr_impl.cpp +++ b/dnn/src/fallback/type_cvt/opr_impl.cpp @@ -36,6 +36,56 @@ struct TypeCvt { } }; +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +//! As aarch32 __fp16 vectorize may cause llvm error, so if macro \c +//! 
MEGDNN_FIX_AARCH32_BUG defined, we use dt_float16, otherwise __fp16 +#if MEGDNN_FIX_AARCH32_BUG +#define FLOAT16 dt_float16 +#else +#define FLOAT16 __fp16 +#endif +template +struct TypeCvt { + static void do_cvt(_megdnn_tensor_in src, _megdnn_tensor_out dst) { + using sctype = typename DTypeTrait::ctype; + auto n = src.layout.total_nr_elems(); + const sctype* __restrict sptr = src.ptr(); + FLOAT16* __restrict dptr = static_cast(dst.raw_ptr); + for (size_t i = 0; i < n; ++i) { + dptr[i] = static_cast(sptr[i]); + } + } +}; + +template +struct TypeCvt { + static void do_cvt(_megdnn_tensor_in src, _megdnn_tensor_out dst) { + auto n = src.layout.total_nr_elems(); + using dctype = typename DTypeTrait::ctype; + const FLOAT16* __restrict sptr = static_cast(src.raw_ptr); + dctype* __restrict dptr = dst.ptr(); + for (size_t i = 0; i < n; ++i) { + dptr[i] = static_cast(sptr[i]); + } + } +}; + +template <> +struct TypeCvt { + static void do_cvt(_megdnn_tensor_in src, _megdnn_tensor_out dst) { + auto n = src.layout.total_nr_elems(); + const FLOAT16* __restrict sptr = static_cast(src.raw_ptr); + FLOAT16* __restrict dptr = static_cast(dst.raw_ptr); + for (size_t i = 0; i < n; ++i) { + dptr[i] = static_cast(sptr[i]); + } + } +}; + +#undef FLOAT16 + +#endif template void do_cvt_normal_s8(_megdnn_tensor_in src, _megdnn_tensor_out dst) { diff --git a/dnn/test/aarch64/batched_matrix_mul.cpp b/dnn/test/aarch64/batched_matrix_mul.cpp new file mode 100644 index 00000000..23003f33 --- /dev/null +++ b/dnn/test/aarch64/batched_matrix_mul.cpp @@ -0,0 +1,257 @@ +/** + * \file dnn/test/aarch64/batched_matrix_mul.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "test/common/benchmarker.h" +#include "test/common/checker.h" +#include "test/common/rng.h" +#include "test/common/matrix_mul.h" + +#include "test/aarch64/fixture.h" + +namespace megdnn { +namespace test { + +TEST_F(AARCH64, BATCHED_MATRIX_MUL) { + Checker checker(handle()); + checker.set_epsilon(1e-2); + using Param = MatrixMul::Param; + // auto args = get_batch_matmul_args(); + auto args = matrix_mul::get_batched_matmul_args(); + + for (DType dtype : std::vector{dtype::Float32()}) { + for (unsigned mask = 0; mask < 4; ++mask) { + for (auto& arg : args) { + size_t b = arg.b, m = arg.m, n = arg.n, k = arg.k; + //! if test all batch sizes, the test case will time out. 
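Editor's aside on the `__fp16` TypeCvt specializations above: the patch keeps those element loops scalar and relies on the compiler to auto-vectorize them once `__fp16` is a native arithmetic type (the `MEGDNN_FIX_AARCH32_BUG` fallback to `dt_float16` sidesteps the llvm "cannot select" problem mentioned in the comments). Purely as an illustration — not part of this patch, helper name invented — a hand-written NEON narrowing loop for the fp32-to-fp16 direction could look roughly like this:

#include <arm_neon.h>
#include <cstddef>

// Illustrative sketch only: narrow fp32 -> fp16 four lanes at a time with the
// hardware convert instruction, then handle the remainder with a scalar tail.
// Assumes a toolchain where __fp16/float16_t storage is available (aarch64,
// or armv8.2-a+fp16 on aarch32).
static void cvt_f32_to_f16_sketch(const float* __restrict src,
                                  __fp16* __restrict dst, size_t n) {
    size_t i = 0;
    for (; i + 4 <= n; i += 4) {
        float32x4_t v = vld1q_f32(src + i);
        vst1_f16(reinterpret_cast<float16_t*>(dst) + i, vcvt_f16_f32(v));
    }
    for (; i < n; ++i) {
        dst[i] = static_cast<__fp16>(src[i]);  // scalar remainder
    }
}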
+ if (b != 2) { + continue; + } + Param param; + param.transposeA = mask & 1; + param.transposeB = mask & 2; + TensorShape A, B; + if (param.transposeA) + A = TensorShape{b, k, m}; + else + A = TensorShape{b, m, k}; + if (param.transposeB) + B = TensorShape{b, n, k}; + else + B = TensorShape{b, k, n}; + checker.set_param(param) + .set_dtype(0, dtype) + .set_dtype(1, dtype) + .execs({A, B, {}}); + } + } + } +} + +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +TEST_F(AARCH64, BATCHED_MATRIX_MUL_FP16) { + Checker checker(handle()); + using Param = MatrixMul::Param; + auto args = matrix_mul::get_batched_matmul_args(); + + NormalRNG rng(1.f); + checker.set_rng(0, &rng).set_rng(1, &rng).set_epsilon(1e-2); + for (DType dtype : std::vector{dtype::Float16()}) { + for (unsigned mask = 0; mask < 4; ++mask) { + for (auto& arg : args) { + size_t b = arg.b, m = arg.m, n = arg.n, k = arg.k; + //! if test all batch sizes, the test case will time out on + //! sdm855 + if (b != 1) { + continue; + } + Param param; + param.transposeA = mask & 1; + param.transposeB = mask & 2; + TensorShape A, B; + if (param.transposeA) + A = TensorShape{b, k, m}; + else + A = TensorShape{b, m, k}; + if (param.transposeB) + B = TensorShape{b, n, k}; + else + B = TensorShape{b, k, n}; + checker.set_param(param) + .set_dtype(0, dtype) + .set_dtype(1, dtype) + .set_dtype(2, dtype) + .execs({A, B, {}}); + } + } + } +} + +#if MEGDNN_WITH_BENCHMARK +TEST_F(AARCH64, BENCHMARK_TRANSPOSED_MATRIX_MUL_QUICK_FP16) { + int exec_times = 10; + Benchmarker benchmarker_gemm(handle()); + benchmarker_gemm.set_times(exec_times); + + float mod = 1000 * exec_times / 1e9; + using Param = MatrixMul::Param; + auto run = [&](size_t M, size_t K, size_t N) { + float time = 1.f, perf = 1.f; + + std::cout << "GEMM: (" << M << ", " << K << ", " << N << ")" + << std::endl; + Param param; + param.transposeA = true; + param.transposeB = true; + benchmarker_gemm.set_param(param) + .set_dtype(0, dtype::Float32()) + .set_dtype(1, dtype::Float32()); + time = benchmarker_gemm.exec({{M, K}, {K, N}, {}}); + perf = 2.f * M * K * N / time * mod; + std::cout << "gemm fp32, Performance is " << perf << " Gflops" + << std::endl; + benchmarker_gemm.set_param(param) + .set_dtype(0, dtype::Float16()) + .set_dtype(1, dtype::Float16()); + time = benchmarker_gemm.exec({{M, K}, {K, N}, {}}); + perf = 2.f * M * K * N / time * mod; + std::cout << "gemm fp16, Performance is " << perf << " Gflops" + << std::endl; + + }; + + // run M = K = N + run(32, 32, 32); + run(64, 64, 64); + run(128, 128, 128); + run(256, 256, 256); + run(512, 512, 512); + run(1024, 1024, 1024); + run(2048, 2048, 2048); +} + +TEST_F(AARCH64, BENCHMARK_TRANSPOSED_MATRIX_MUL_ALL_SIZES_FP16) { + int exec_times = 50; + Benchmarker benchmarker_gemm(handle()); + benchmarker_gemm.set_times(exec_times); + + float mod = 1000 * exec_times / 1e9; + using Param = MatrixMul::Param; + auto run = [&](size_t M, size_t K, size_t N) { + float time = 1.f, perf = 1.f; + + std::cout << "GEMM: (" << M << ", " << K << ", " << N << ")" + << std::endl; + Param param; + param.transposeA = param.transposeB = true; + benchmarker_gemm.set_param(param) + .set_dtype(0, dtype::Float32()) + .set_dtype(1, dtype::Float32()); + time = benchmarker_gemm.exec({{K, M}, {N, K}, {}}); + perf = 2.f * M * K * N / time * mod; + std::cout << "gemm fp32, Performance is " << perf << " Gflops" + << std::endl; + benchmarker_gemm.set_param(param) + .set_dtype(0, dtype::Float16()) + .set_dtype(1, dtype::Float16()); + time = benchmarker_gemm.exec({{K, M}, {N, K}, 
{}}); + perf = 2.f * M * K * N / time * mod; + std::cout << "gemm fp16, Performance is " << perf << " Gflops" + << std::endl; + + }; + + std::cout << "warm up:\n"; + for (int i = 0; i < 50; i++) { + benchmarker_gemm.set_dtype(0, dtype::Float32()) + .set_dtype(1, dtype::Float32()) + .set_display(false) + .exec({{256, 256}, {256, 256}, {}}); + benchmarker_gemm.set_display(true); + } + + // run M = K = N + run(8, 8, 8); + run(16, 16, 16); + run(32, 32, 32); + run(64, 64, 64); + run(128, 128, 128); + run(256, 256, 256); + run(512, 512, 512); + run(1024, 1024, 1024); + run(2048, 2048, 2048); + + // run sgmev like + run(32, 32, 1); + run(64, 64, 1); + run(128, 128, 1); + run(256, 256, 1); + run(512, 512, 1); + + // run M, N >> K + run(32, 16, 32); + run(64, 16, 64); + run(128, 16, 128); + run(256, 16, 256); + run(512, 16, 512); + + // run N, K >> M + run(16, 32, 32); + run(16, 64, 64); + run(16, 128, 128); + run(16, 256, 256); + run(16, 512, 512); + + // run M >> K, N + run(32, 16, 16); + run(64, 16, 16); + run(128, 16, 16); + run(256, 16, 16); + run(512, 16, 16); + + // run K >> M, N + run(16, 32, 16); + run(16, 64, 16); + run(16, 128, 16); + run(16, 256, 16); + run(16, 512, 16); + + // run N >> M, K + run(16, 16, 32); + run(16, 16, 64); + run(16, 16, 128); + run(16, 16, 256); + run(16, 16, 512); + + // run VGG + // conv 1.1 + run(64, 3 * 3 * 3, 224 * 224); + // conv 1.2 + run(128, 64 * 3 * 3, 112 * 112); + // conv 2.1 + run(128, 128 * 3 * 3, 112 * 112); + // conv 2.2 + run(128, 128 * 3 * 3, 56 * 56); + // conv 3.1 + run(256, 128 * 3 * 3, 56 * 56); + // conv 3.2 + run(256, 256 * 3 * 3, 28 * 28); + // conv 4.1 + run(512, 256 * 3 * 3, 28 * 28); + // conv 4.2 + run(512, 512 * 3 * 3, 14 * 14); +} + +#endif +#endif + +} // namespace test +} // namespace megdnn +// vim: syntax=cpp.doxygen diff --git a/dnn/test/aarch64/conv_bias.cpp b/dnn/test/aarch64/conv_bias.cpp new file mode 100644 index 00000000..02eeed25 --- /dev/null +++ b/dnn/test/aarch64/conv_bias.cpp @@ -0,0 +1,291 @@ +/** + * \file dnn/test/aarch64/conv_bias.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "test/aarch64/fixture.h" + +#include "src/fallback/conv_bias/common.h" +#include "test/common/benchmarker.h" +#include "test/common/checker.h" +#include "test/common/conv_bias.h" +#include "test/common/rng.h" +#include "test/common/tensor.h" + + +namespace megdnn { +namespace test { + +std::vector get_conv_bias_args(std::vector kernel, + size_t stride) { + using namespace conv_bias; + using Param = param::ConvBias; + using NLMode = param::ConvBias::NonlineMode; + + std::vector args; + auto pack = [&](size_t n, size_t oc, size_t ic, size_t w, size_t h, + size_t kernel, size_t stride, NLMode nonline_mode) { + Param param; + param.stride_h = stride; + param.stride_w = stride; + param.pad_h = kernel == 1 ? 0 : kernel / 2; + param.pad_w = kernel == 1 ? 0 : kernel / 2; + param.nonlineMode = nonline_mode; + + //! no bias + args.emplace_back(param, TensorShape{n, ic, h, w}, + TensorShape{oc, ic, kernel, kernel}, TensorShape{}); + //! bias broadcast channle + args.emplace_back(param, TensorShape{n, ic, h, w}, + TensorShape{oc, ic, kernel, kernel}, + TensorShape{1, oc, 1, 1}); + //! 
bias + args.emplace_back( + param, TensorShape{n, ic, h, w}, + TensorShape{oc, ic, kernel, kernel}, + TensorShape{n, oc, (h + 2 * param.pad_h - kernel) / stride + 1, + (w + 2 * param.pad_h - kernel) / stride + 1}); + }; + + for (auto nlmode : {NLMode::IDENTITY, NLMode::RELU, NLMode::SIGMOID}) { + for (size_t n : {1, 2}) { + for (size_t ic : {1, 2, 3, 4, 8}) { + for (size_t oc : {1, 2, 3, 4, 8}) { + for (size_t size : {1, 2, 3, 4, 8, 24}) { + for (size_t k : kernel) { + pack(n, oc, ic, size + 24, size + 24, k, stride, + nlmode); + } + } + } + } + } + } + return args; +} + +void checker_conv_bias(std::vector args, Handle* handle, + const char* algo_name) { + using namespace conv_bias; + + Checker checker(handle); + checker.set_before_exec_callback( + conv_bias::ConvBiasAlgoChecker(algo_name)); + for (auto&& arg : args) { + checker.set_param(arg.param).execs( + {arg.src, arg.filter, arg.bias, {}, {}}); + } +} +TEST_F(AARCH64_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2_LARGE_GROUP) { + check_conv_bias( + conv_bias::get_conv_bias_args({2, 3, 5, 7}, 2, false, false, false), + handle(), "ARMV8F32STRD2_LARGE_GROUP"); +} +TEST_F(AARCH64_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2_SMALL_GROUP) { + check_conv_bias( + conv_bias::get_conv_bias_args({2, 3, 5, 7}, 2, false, false, false), + handle(), "ARMV8F32STRD2_SMALL_GROUP"); +} + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +void checker_conv_bias_fp16(std::vector args, + Handle* handle, const char* algo_name, + float epsilon) { + using namespace conv_bias; + Checker checker(handle); + checker.set_before_exec_callback( + conv_bias::ConvBiasAlgoChecker(algo_name)); + checker.set_epsilon(epsilon); + checker.set_dtype(0, dtype::Float16()) + .set_dtype(1, dtype::Float16()) + .set_dtype(2, dtype::Float16()) + .set_dtype(4, dtype::Float16()); + NormalRNG rng(1.f); + checker.set_rng(0, &rng).set_rng(1, &rng); + + for (auto&& arg : args) { + checker.set_param(arg.param).execs( + {arg.src, arg.filter, arg.bias, {}, {}}); + } +} + +TEST_F(AARCH64_MULTI_THREADS, CONVBIAS_DIRECT_FP16_STR2_LARGE_GROUP) { + NormalRNG rng(1); + checker_conv_bias_f16( + conv_bias::get_conv_bias_args({2, 3, 5}, 2, false, false, false), + handle(), rng, "ARMV8F16STRD2_LARGE_GROUP", 0.04); +} +TEST_F(AARCH64_MULTI_THREADS, CONVBIAS_DIRECT_FP16_STR2_SMALL_GROUP) { + NormalRNG rng(1); + checker_conv_bias_f16( + conv_bias::get_conv_bias_args({2, 3, 5}, 2, false, false, false), + handle(), rng, "ARMV8F16STRD2_SMALL_GROUP", 0.04); +} +#endif + +#if MEGDNN_WITH_BENCHMARK +std::vector get_conv_bias_benchmaker_args( + std::vector kernel, size_t stride) { + using namespace conv_bias; + using Param = param::ConvBias; + using NLMode = param::ConvBias::NonlineMode; + + std::vector args; + auto pack = [&](size_t oc, size_t ic, size_t w, size_t h, size_t kernel, + size_t stride, NLMode nonline_mode) { + Param param; + param.stride_h = stride; + param.stride_w = stride; + param.pad_h = kernel == 1 ? 0 : kernel / 2; + param.pad_w = kernel == 1 ? 0 : kernel / 2; + param.nonlineMode = nonline_mode; + //! no bias + args.emplace_back(param, TensorShape{1, ic, h, w}, + TensorShape{oc, ic, kernel, kernel}, TensorShape{}); + //! bias broadcast channle + args.emplace_back(param, TensorShape{1, ic, h, w}, + TensorShape{oc, ic, kernel, kernel}, + TensorShape{1, oc, 1, 1}); + //! 
bias + args.emplace_back( + param, TensorShape{1, ic, h, w}, + TensorShape{oc, ic, kernel, kernel}, + TensorShape{1, oc, (h + 2 * param.pad_h - kernel) / stride + 1, + (w + 2 * param.pad_w - kernel) / stride + 1}); + }; + + for (auto nlmode : {NLMode::IDENTITY, NLMode::RELU, NLMode::SIGMOID}) { + for (size_t k : kernel) { + for (size_t ic : {3, 6, 12, 24}) { + for (size_t oc : {3, 6, 12, 24}) { + for (size_t size : + {4, 7, 8, 14, 16, 17, 28, 32, 34, 64, 112}) { + pack(oc, ic, size, size, k, stride, nlmode); + } + } + } + } + } + return args; +} + +void benchmarker_conv_bias(std::vector args, Handle* handle, + const char* algo_name, const char* cmp_algo_name) { + using namespace conv_bias; + + constexpr size_t N = 10; + Benchmarker benchmark_float(handle); + benchmark_float + .set_before_exec_callback( + conv_bias::ConvBiasAlgoChecker(algo_name)) + .set_times(N) + .set_display(false); +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + Benchmarker benchmark_float16(handle); + benchmark_float16 + .set_before_exec_callback( + conv_bias::ConvBiasAlgoChecker(cmp_algo_name)) + .set_times(N) + .set_dtype(0, dtype::Float16()) + .set_dtype(1, dtype::Float16()) + .set_dtype(2, dtype::Float16()) + .set_dtype(4, dtype::Float16()) + .set_display(false); +#endif + for (auto&& arg : args) { + TensorLayout dst_layout; + auto opr = handle->create_operator(); + opr->param() = arg.param; + opr->deduce_layout({arg.src, dtype::Float32()}, + {arg.filter, dtype::Float32()}, + {arg.bias, dtype::Float32()}, {}, dst_layout); + float computations = dst_layout.total_nr_elems() * arg.filter[1] * + arg.filter[2] * arg.filter[3] * 2.0 / + (1024 * 1024 * 1024) * 1e3; // GFLOPS + printf("filter n: %zu c: %zu h:%zu w:%zu ", arg.filter[0], + arg.filter[1], arg.filter[2], arg.filter[3]); + printf("input c: %zu h:%zu w:%zu \n", arg.src[1], arg.src[2], + arg.src[3]); + auto time32 = benchmark_float.set_param(arg.param).execs( + {arg.src, arg.filter, arg.bias, {}, {}}) / + N; +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + auto time16 = benchmark_float16.set_param(arg.param).execs( + {arg.src, arg.filter, arg.bias, {}, {}}) / + N; + printf("---------------------------------fp32 flops: %.3f Gflops fp16 " + "flops %.3f Gflops speedup: %f\n", + computations / time32, computations / time16, time32 / time16); +#else + printf("---------------------------------fp32 flops: %.3f Gflops\n", + computations / time32); +#endif + } +} + + +TEST_F(AARCH64, BENCHMARK_CONVBIAS_STRIDE2_FP32_FP16) { + benchmarker_conv_bias(get_conv_bias_benchmaker_args({2,3,5,7},2), + handle(),"ARMV8F32STRD2", "ARMV8F16STRD2"); +} + +TEST_F(AARCH64, BENCHMARK_CONVBIAS) { + constexpr size_t RUNS = 10; + param::ConvBias param; + param.stride_h = 1; + param.stride_w = 1; + Benchmarker benchmarker_int(handle()); + benchmarker_int.set_times(RUNS) + .set_dtype(0, dtype::QuantizedS8(2.5f)) + .set_dtype(1, dtype::QuantizedS8(2.5f)) + .set_dtype(2, dtype::QuantizedS32(6.25f)) + .set_dtype(4, dtype::QuantizedS8(40.25f)) + .set_display(false); + Benchmarker benchmarker_float(handle()); + benchmarker_float.set_display(false).set_times(RUNS); + + auto run = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, + size_t FS) { + TensorShape src({N, IC, H, W}), filter({OC, IC, FS, FS}), + bias({N, OC, H, W}), dst({N, OC, H, W}); + param.pad_h = FS / 2; + param.pad_w = FS / 2; + auto int_used = benchmarker_int.set_param(param).exec( + {src, filter, bias, {}, dst}) / + RUNS; + auto float_used = benchmarker_float.set_param(param).exec( + {src, filter, bias, {}, dst}) / + RUNS; + float 
computations = + IC * (FS * FS + 1) * dst.total_nr_elems() * 2 * 1e-6; + printf("run: %s %s %s->%s \nfloat: %f ms %f Gflops int: %f ms " + "%f Gflops speedup: %f\n", + src.to_string().c_str(), filter.to_string().c_str(), + bias.to_string().c_str(), dst.to_string().c_str(), float_used, + computations / float_used, int_used, computations / int_used, + float_used / int_used); + }; + + run(1, 128, 128, 32, 32, 3); + + for (size_t IC : {1, 4, 8, 16, 32, 64}) { + for (size_t OC : {1, 4, 8, 16, 32, 64}) { + for (size_t size : {7, 14, 28, 56}) { + for (size_t FS : {1, 3, 5}) { + run(1, IC, OC, size, size, FS); + } + } + } + } +} + +#endif +} // namespace test +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/test/aarch64/convolution.cpp b/dnn/test/aarch64/convolution.cpp new file mode 100644 index 00000000..6bdda2e6 --- /dev/null +++ b/dnn/test/aarch64/convolution.cpp @@ -0,0 +1,309 @@ +/** + * \file dnn/test/aarch64/convolution.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "test/aarch64/fixture.h" + +#include "test/common/benchmarker.h" +#include "test/common/checker.h" +#include "test/common/convolution.h" + +#include "test/common/rng.h" + +using namespace megdnn; +using namespace test; +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +TEST_F(AARCH64, CONVOLUTION_BACKWARD_DATA_FP16) { + Checker checker(handle()); + using Param = ConvolutionBackwardData::Param; + Param param; + auto run = [&](size_t n, size_t ic, size_t oh, size_t ow, size_t oc, + size_t fh, size_t fw, size_t stride, size_t padding, + size_t group = 1) { + param.pad_h = param.pad_w = padding; + param.stride_h = param.stride_w = stride; + + TensorLayout diff = + TensorLayout{{n, oc * group, oh, ow}, dtype::Float16()}; + TensorLayout grad; + TensorLayout filter; + if (group == 1) { + param.sparse = Param::Sparse::DENSE; + filter = {{oc, ic, fh, fw}, dtype::Float16()}; + } else { + param.sparse = Param::Sparse::GROUP; + filter = {{group, oc, ic, fh, fw}, dtype::Float16()}; + } + // TensorLayout grad; + { + auto opr = handle()->create_operator(); + opr->param() = param; + opr->deduce_layout(filter, diff, grad); + } + NormalRNG rng(10.f); + checker.set_param(param) + .set_dtype(0, dtype::Float16()) + .set_dtype(1, dtype::Float16()) + .set_dtype(2, dtype::Float16()) + .set_rng(0, &rng).set_rng(1, &rng) + .set_epsilon(1e-2) + .set_before_exec_callback( + AlgoChecker("DeconvMatmul")); + checker.exec(TensorLayoutArray{filter, diff, grad}); + }; + + for (auto mode : + {Param::Mode::CONVOLUTION, Param::Mode::CROSS_CORRELATION}) { + param.mode = mode; + run(4, 3, 10, 13, 5, 1, 1, 1, 0, 1); + run(4, 3, 10, 45, 2, 1, 1, 1, 0, 4); + run(2, 3, 9, 12, 2, 4, 6, 1, 0, 1); + run(3, 4, 17, 32, 2, 3, 2, 5, 4, 4); + run(5, 5, 24, 43, 11, 9, 3, 3, 12, 2); + run(2, 3, 20, 33, 3, 5, 7, 4, 15, 2); + run(4, 4, 6, 7, 9, 3, 2, 2, 1, 3); + } +} + + +#if MEGDNN_WITH_BENCHMARK +TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_QUICK_FP16) { + int exec_times = 10; + Benchmarker benchmarker_gemm(handle()); + benchmarker_gemm.set_times(exec_times); + + float mod = 1000 * exec_times / 1e9; + auto run = [&](size_t M, size_t K, size_t N) { + float time = 1.f, perf = 1.f; + + std::cout << "GEMM: (" << M << ", " << K << ", " << N << ")" 
+ << std::endl; + benchmarker_gemm.set_dtype(0, dtype::Float32()) + .set_dtype(1, dtype::Float32()); + time = benchmarker_gemm.exec({{M, K}, {K, N}, {}}); + perf = 2.f * M * K * N / time * mod; + std::cout << "gemm fp32, Performance is " << perf << " Gflops" + << std::endl; + benchmarker_gemm.set_dtype(0, dtype::Float16()) + .set_dtype(1, dtype::Float16()); + time = benchmarker_gemm.exec({{M, K}, {K, N}, {}}); + perf = 2.f * M * K * N / time * mod; + std::cout << "gemm fp16, Performance is " << perf << " Gflops" + << std::endl; + + }; + + // run M = K = N + run(32, 32, 32); + run(64, 64, 64); + run(128, 128, 128); + run(256, 256, 256); + run(512, 512, 512); + run(1024, 1024, 1024); + run(2048, 2048, 2048); +} + +TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_ALL_SIZES_FP16) { + int exec_times = 10; + Benchmarker benchmarker_gemm(handle()); + benchmarker_gemm.set_times(exec_times); + + float mod = 1000 * exec_times / 1e9; + auto run = [&](size_t M, size_t K, size_t N) { + float time = 1.f, perf = 1.f; + + std::cout << "GEMM: (" << M << ", " << K << ", " << N << ")" + << std::endl; + benchmarker_gemm.set_dtype(0, dtype::Float32()) + .set_dtype(1, dtype::Float32()); + time = benchmarker_gemm.exec({{M, K}, {K, N}, {}}); + perf = 2.f * M * K * N / time * mod; + std::cout << "gemm fp32, Performance is " << perf << " Gflops" + << std::endl; + benchmarker_gemm.set_dtype(0, dtype::Float16()) + .set_dtype(1, dtype::Float16()); + time = benchmarker_gemm.exec({{M, K}, {K, N}, {}}); + perf = 2.f * M * K * N / time * mod; + std::cout << "gemm fp16, Performance is " << perf << " Gflops" + << std::endl; + + }; + + std::cout << "warm up:\n"; + for (int i = 0; i < 50; i++) { + benchmarker_gemm.set_dtype(0, dtype::Float32()) + .set_dtype(1, dtype::Float32()) + .set_display(false) + .exec({{256, 256}, {256, 256}, {}}); + benchmarker_gemm.set_display(true); + } + + // run M = K = N + run(8, 8, 8); + run(16, 16, 16); + run(32, 32, 32); + run(64, 64, 64); + run(128, 128, 128); + run(256, 256, 256); + run(512, 512, 512); + run(1024, 1024, 1024); + run(2048, 2048, 2048); + + // run sgmev like + run(32, 32, 1); + run(64, 64, 1); + run(128, 128, 1); + run(256, 256, 1); + run(512, 512, 1); + + // run M, N >> K + run(32, 16, 32); + run(64, 16, 64); + run(128, 16, 128); + run(256, 16, 256); + run(512, 16, 512); + + // run N, K >> M + run(16, 32, 32); + run(16, 64, 64); + run(16, 128, 128); + run(16, 256, 256); + run(16, 512, 512); + + // run M >> K, N + run(32, 16, 16); + run(64, 16, 16); + run(128, 16, 16); + run(256, 16, 16); + run(512, 16, 16); + + // run K >> M, N + run(16, 32, 16); + run(16, 64, 16); + run(16, 128, 16); + run(16, 256, 16); + run(16, 512, 16); + + // run N >> M, K + run(16, 16, 32); + run(16, 16, 64); + run(16, 16, 128); + run(16, 16, 256); + run(16, 16, 512); + + // run VGG + // conv 1.1 + run(64, 3 * 3 * 3, 224 * 224); + // conv 1.2 + run(128, 64 * 3 * 3, 112 * 112); + // conv 2.1 + run(128, 128 * 3 * 3, 112 * 112); + // conv 2.2 + run(128, 128 * 3 * 3, 56 * 56); + // conv 3.1 + run(256, 128 * 3 * 3, 56 * 56); + // conv 3.2 + run(256, 256 * 3 * 3, 28 * 28); + // conv 4.1 + run(512, 256 * 3 * 3, 28 * 28); + // conv 4.2 + run(512, 512 * 3 * 3, 14 * 14); +} + + +#endif +#endif + +#if MEGDNN_WITH_BENCHMARK +TEST_F(AARCH64, BENCHMARK_CONVOLUTION_STRIDE2) { + using Param = param::Convolution; + auto run = [&](const TensorShapeArray& shapes, Param param) { + Benchmarker benchmarker_float(handle()); + size_t RUN = 50; + auto tfloat = + benchmarker_float.set_display(false) + .set_dtype(0, dtype::Float32{}) + 
.set_dtype(1, dtype::Float32{}) + .set_before_exec_callback(AlgoChecker( + "CONVOLUTION_DEFAULT_ARMV8F32STRD2_LARGE_" + "GROUP")) + .set_times(RUN) + .set_param(param) + .exec(shapes); + size_t IC = shapes[1][1]; + size_t FH = shapes[1][2]; + size_t FW = shapes[1][3]; + TensorLayout dst_layout; + auto opr = handle()->create_operator(); + opr->param() = param; + opr->deduce_layout({shapes[0], dtype::Float32()}, + {shapes[1], dtype::Float32()}, dst_layout); + printf("fp32 flops: %.3f mflops\n", + (IC * dst_layout.total_nr_elems() * FH * FW * 2) / + (tfloat / RUN * 1000)); + }; +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + auto run1 = [&](const TensorShapeArray& shapes, Param param) { + Benchmarker benchmarker_float(handle()); + size_t RUN = 50; + auto tfloat = + benchmarker_float.set_display(false) + .set_dtype(0, dtype::Float16()) + .set_dtype(1, dtype::Float16()) + .set_before_exec_callback(AlgoChecker( + "CONVOLUTION_DEFAULT_ARMV8F16STRD2_LARGE_" + "GROUP")) + .set_times(RUN) + .set_param(param) + .exec(shapes); + size_t IC = shapes[1][1]; + size_t FH = shapes[1][2]; + size_t FW = shapes[1][3]; + TensorLayout dst_layout; + auto opr = handle()->create_operator(); + opr->param() = param; + opr->deduce_layout({shapes[0], dtype::Float16()}, + {shapes[1], dtype::Float16()}, dst_layout); + printf("fp16 flops: %.3f mflops\n", + (IC * dst_layout.total_nr_elems() * FH * FW * 2) / + (tfloat / RUN * 1000)); + }; +#endif + auto profile = [&](size_t oc, size_t ic, size_t w, size_t h, size_t kernel, + size_t stride) { + Param param; + param.stride_h = stride; + param.stride_w = stride; + param.pad_h = kernel / 2; + param.pad_w = kernel / 2; + printf("oc: %zd ic: %zd w: %zd h: %zd stride: %zd kernel_size: %zd\n", + oc, ic, w, h, stride, kernel); + + run({{1, ic, h, w}, {oc, ic, kernel, kernel}, {}}, param); +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + run1({{1, ic, h, w}, {oc, ic, kernel, kernel}, {}}, param); +#endif + + }; + + for (size_t kernel : {2, 3, 5, 7}) { + for (size_t ic : {3, 6, 12, 24}) { + for (size_t oc : {3, 6, 12, 24}) { + for (size_t size : {4, 7, 8, 14, 16, 17, 28, 32, 34, 64, 112}) { + profile(oc, ic, size, size, kernel, 2); + } + } + } + } +} +#endif + + +// vim: syntax=cpp.doxygen + diff --git a/dnn/test/aarch64/fixture.h b/dnn/test/aarch64/fixture.h new file mode 100644 index 00000000..3c8ffda9 --- /dev/null +++ b/dnn/test/aarch64/fixture.h @@ -0,0 +1,29 @@ +/** + * \file dnn/test/aarch64/fixture.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once +#include + +#include "megdnn/handle.h" +#include "test/arm_common/fixture.h" + +#include + +namespace megdnn { +namespace test { + +class AARCH64 : public ARM_COMMON {}; + +class AARCH64_MULTI_THREADS : public ARM_COMMON_MULTI_THREADS {}; + +} // namespace test +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/test/aarch64/matrix_mul.cpp b/dnn/test/aarch64/matrix_mul.cpp new file mode 100644 index 00000000..e4baef13 --- /dev/null +++ b/dnn/test/aarch64/matrix_mul.cpp @@ -0,0 +1,612 @@ +/** + * \file dnn/test/aarch64/matrix_mul.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "test/aarch64/fixture.h" + +#include "test/common/benchmarker.h" +#include "test/common/checker.h" + +#include "test/common/matrix_mul.h" +#include "test/common/rng.h" + +using namespace megdnn; +using namespace test; + +TEST_F(AARCH64, MATRIX_MUL_FP32K8X12) { + matrix_mul::check_matrix_mul(dtype::Float32{}, dtype::Float32{}, + dtype::Float32{}, handle(), + "AARCH64_F32K8X12X1"); +} + +TEST_F(AARCH64, MATRIX_MUL_FP32K4X16) { + matrix_mul::check_matrix_mul(dtype::Float32{}, dtype::Float32{}, + dtype::Float32{}, handle(), + "AARCH64_F32K4X16X1"); +} + +TEST_F(AARCH64, MATRIX_MUL_FP32_MK4) { + //! nbase should be 4 in order to test the last rest 4 in N dim + matrix_mul::check_matrix_mul( + dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), + "AARCH64_F32_MK4_4x16", param::MatrixMul::Format::MK4, 4); +} + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +TEST_F(AARCH64, MATRIX_MUL_F16_K8X24X1) { + matrix_mul::check_matrix_mul(dtype::Float16{}, dtype::Float16{}, + dtype::Float16{}, handle(), + "AARCH64_F16_K8X24X1"); +} + +TEST_F(AARCH64, MATRIX_MUL_F16_MK8) { + //! nbase should be 4 in order to test the last rest 4 in N dim + matrix_mul::check_matrix_mul( + dtype::Float16{}, dtype::Float16{}, dtype::Float16{}, handle(), + "AARCH64_F16_MK8_8X8", param::MatrixMul::Format::MK8, 4); +} +#endif + +#if __ARM_FEATURE_DOTPROD +TEST_F(AARCH64, MATRIX_MUL_INT8X8X32_K8X12X4_DOTPROD) { + matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{}, + handle(), "AARCH64_INT8X8X32_K8X12X4_DOTPROD"); +} +#else +TEST_F(AARCH64, MATRIX_MUL_INT8X8X32_K4X4X16) { + matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{}, + handle(), "AARCH64_INT8X8X32_K4X4X16"); +} + +TEST_F(AARCH64, MATRIX_MUL_INT8_MK4) { + std::vector args; + for (size_t m : {1, 2, 3, 4, 5, 7, 10, 11}) + for (size_t n : {1, 2, 3, 4, 5, 8, 16, 24, 25, 32}) + for (size_t k : {1, 2, 3, 4, 5, 6, 7, 8, 16, 32, 33, 34}) + args.emplace_back(m, n, k, 0); + matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{}, + handle(), "AARCH64_INT8X8X32_MK4_4X4X16", + param::MatrixMul::Format::MK4, 1, 1e-3, + std::move(args)); +} + +TEST_F(AARCH64, MATRIX_MUL_INT8x8x32_K8x8x8) { + matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{}, + handle(), "AARCH64_INT8X8X32_K8X8X8"); +} +#endif + +TEST_F(AARCH64, MATRIX_MUL_INT8x8x16_K8x8x8) { + matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{}, + handle(), "AARCH64_INT8X8X16_K8X8X8"); +} + +TEST_F(AARCH64, MATRIX_MUL_INT8x8x16_K4x4x16) { + matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{}, + handle(), "AARCH64_INT8X8X16_K4X4X16"); +} + +TEST_F(AARCH64, MATRIX_MUL_INT16x16x32_K12X8X1) { + matrix_mul::check_matrix_mul(dtype::Int16{}, dtype::Int16{}, dtype::Int32{}, + handle(), "AARCH64_INT16X16X32_K12X8X1"); +} + +TEST_F(AARCH64, MATRIX_MUL_INT16x16x32_MK8) { + //! nbase should be 4 in order to test the last rest 4 in N dim + matrix_mul::check_matrix_mul(dtype::Int16{}, dtype::Int16{}, dtype::Int32{}, + handle(), "AARCH64_INT16X16X32_MK8_8X8", + param::MatrixMul::Format::MK8, 4); +} + +//! 
FIXME: need to add tests of GEMV and QUINT8 + +#if MEGDNN_WITH_BENCHMARK + +TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_FP32_K4X16) { + constexpr size_t RUNS = 50; + param::MatrixMul param; + param.transposeA = false; + param.transposeB = false; + Benchmarker benchmarker_K4X16(handle()); + Benchmarker benchmarker_K12X8(handle()); + benchmarker_K4X16.set_times(RUNS) + .set_dtype(0, dtype::Float32{}) + .set_dtype(1, dtype::Float32{}) + .set_dtype(2, dtype::Float32{}) + .set_param(param) + .set_display(false); + benchmarker_K4X16.set_before_exec_callback( + AlgoChecker("AARCH64_F32K4X16X1")); + + benchmarker_K12X8.set_before_exec_callback( + AlgoChecker("AARCH64_F32K8X12X1")); + benchmarker_K12X8.set_times(RUNS) + .set_dtype(0, dtype::Float32{}) + .set_dtype(1, dtype::Float32{}) + .set_dtype(2, dtype::Float32{}) + .set_param(param) + .set_display(false); + + auto run = [&](size_t M, size_t N, size_t K) { + TensorShape A, B; + if (param.transposeA) { + A = TensorShape{K, M}; + } else { + A = TensorShape{M, K}; + } + if (param.transposeB) { + B = TensorShape{N, K}; + } else { + B = TensorShape{K, N}; + } + + auto k4x16_used = benchmarker_K4X16.exec({A, B, {}}) / RUNS; + auto k12x8_used = benchmarker_K12X8.exec({A, B, {}}) / RUNS; + float computations = 2.f * M * K * N * 1e-6; + printf("run: {%zu{M} %zu{K} %zu{N}} k4x16: %f ms %f Gflops k12x8: %f " + "ms " + "%f Gflops k4x16_vs_k12x8: %f\n", + M, K, N, k4x16_used, computations / k4x16_used, k12x8_used, + computations / k12x8_used, k12x8_used / k4x16_used); + }; + + run(256, 256, 128); + + for (size_t k = 4; k <= 256; k *= 8) { + for (size_t m = 4; m <= 256; m *= 4) { + for (size_t n = 4; n <= 256; n *= 4) { + run(m, n, k); + } + printf("\n"); + } + printf("\n"); + } +} + +TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_INT16_8X8X8) { + constexpr size_t RUNS = 50; + param::MatrixMul param; + param.transposeA = false; + param.transposeB = false; + Benchmarker benchmarker_int(handle()); + Benchmarker benchmarker_int32(handle()); + benchmarker_int.set_times(RUNS) + .set_dtype(0, dtype::Int8{}) + .set_dtype(1, dtype::Int8{}) + .set_dtype(2, dtype::Int16{}) + .set_param(param) + .set_display(false); + benchmarker_int.set_before_exec_callback( + AlgoChecker("AARCH64_INT8X8X16_K8X8X8")); + + benchmarker_int32.set_before_exec_callback( + AlgoChecker("AARCH64_INT8X8X32_K8X8X8")); + benchmarker_int32.set_times(RUNS) + .set_dtype(0, dtype::Int8{}) + .set_dtype(1, dtype::Int8{}) + .set_dtype(2, dtype::Int32{}) + .set_param(param) + .set_display(false); + Benchmarker benchmarker_float(handle()); + benchmarker_float.set_param(param).set_display(false).set_times(RUNS); + + auto run = [&](size_t M, size_t N, size_t K) { + TensorShape A, B; + if (param.transposeA) { + A = TensorShape{K, M}; + } else { + A = TensorShape{M, K}; + } + if (param.transposeB) { + B = TensorShape{N, K}; + } else { + B = TensorShape{K, N}; + } + + auto int_used = benchmarker_int.exec({A, B, {}}) / RUNS; + auto float_used = benchmarker_float.exec({A, B, {}}) / RUNS; + auto int32_used = benchmarker_int32.exec({A, B, {}}) / RUNS; + float computations = 2.f * M * K * N * 1e-6; + printf("run: {%zu{M} %zu{K} %zu{N}} float: %f ms %f Gflops int: %f ms " + "%f Gflops speedup_vs_fp32: %f, speedup_vs_int32: %f\n", + M, K, N, float_used, computations / float_used, int_used, + computations / int_used, float_used / int_used, + int32_used / int_used); + }; + + run(256, 256, 128); + + for (size_t k = 4; k <= 256; k *= 8) { + for (size_t m = 4; m <= 256; m *= 4) { + for (size_t n = 4; n <= 256; n *= 4) { + run(m, n, k); + 
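// (Editor's aside, not part of the original patch.) The "Gflops" figures these
// benchmarks print come from `float computations = 2.f * M * K * N * 1e-6;`
// divided by the per-run time in milliseconds: a GEMM performs roughly
// 2*M*K*N multiply-add operations, so 2*M*K*N*1e-6 is the work in Mflop,
// and Mflop / ms equals Gflop/s.
// Worked example: M = K = N = 256 gives 2 * 256^3 = 33554432 ops
// (~33.55 Mflop); a run that takes 1.0 ms is therefore reported as ~33.55 Gflops.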
} + std::cout << std::endl; + } + std::cout << std::endl; + } +} + +TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_INT32_MK_4X4X16) { + constexpr size_t RUNS = 50; + param::MatrixMul param; + param.transposeA = false; + param.transposeB = false; + Benchmarker benchmarker(handle()); + Benchmarker benchmarker_mk4(handle()); + benchmarker.set_times(RUNS) + .set_dtype(0, dtype::Int8{}) + .set_dtype(1, dtype::Int8{}) + .set_dtype(2, dtype::Int32{}) + .set_param(param) + .set_display(false); + benchmarker.set_before_exec_callback( + AlgoChecker("AARCH64_INT8X8X32_K4X4X16")); + + param.format = MatrixMul::Param::Format::MK4; + benchmarker_mk4.set_before_exec_callback( + AlgoChecker("AARCH64_INT8X8X32_MK4_4X4X16")); + benchmarker_mk4.set_times(RUNS) + .set_dtype(0, dtype::Int8{}) + .set_dtype(1, dtype::Int8{}) + .set_dtype(2, dtype::Int32{}) + .set_param(param) + .set_display(false); + + auto run = [&](size_t M, size_t N, size_t K) { + auto default_used = benchmarker.exec({{M, K}, {K, N}, {}}) / RUNS; + auto mk_used = benchmarker_mk4.exec( + {{M / 4, K / 4, 4, 4}, {K / 4, N, 4}, {}}) / + RUNS; + float computations = 2.f * M * K * N * 1e-6; + printf("run: {%zu{M} %zu{K} %zu{N}} normal: %f ms %f Gflops mk4: %f ms " + "%f Gflops speedup_vs_normal: %f\n", + M, K, N, default_used, computations / default_used, mk_used, + computations / mk_used, default_used / mk_used); + }; + + run(256, 256, 128); + for (size_t k = 4; k <= 512; k *= 2) { + for (size_t m = 4; m <= 512; m *= 2) { + for (size_t n = 4; n <= 512; n *= 2) { + run(m, n, k); + } + } + std::cout << std::endl; + } +} + +TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_INT16_4X4X16) { + constexpr size_t RUNS = 50; + param::MatrixMul param; + param.transposeA = false; + param.transposeB = false; + Benchmarker benchmarker_int(handle()); + Benchmarker benchmarker_int32(handle()); + benchmarker_int.set_times(RUNS) + .set_dtype(0, dtype::Int8{}) + .set_dtype(1, dtype::Int8{}) + .set_dtype(2, dtype::Int16{}) + .set_param(param) + .set_display(false); + benchmarker_int.set_before_exec_callback( + AlgoChecker("AARCH64_INT8X8X16_K4X4X16")); + + benchmarker_int32.set_before_exec_callback( + AlgoChecker("AARCH64_INT8X8X32_K4X4X16")); + benchmarker_int32.set_times(RUNS) + .set_dtype(0, dtype::Int8{}) + .set_dtype(1, dtype::Int8{}) + .set_dtype(2, dtype::Int32{}) + .set_param(param) + .set_display(false); + Benchmarker benchmarker_float(handle()); + benchmarker_float.set_param(param).set_display(false).set_times(RUNS); + + auto run = [&](size_t M, size_t N, size_t K) { + TensorShape A, B; + if (param.transposeA) { + A = TensorShape{K, M}; + } else { + A = TensorShape{M, K}; + } + if (param.transposeB) { + B = TensorShape{N, K}; + } else { + B = TensorShape{K, N}; + } + + auto int_used = benchmarker_int.exec({A, B, {}}) / RUNS; + auto float_used = benchmarker_float.exec({A, B, {}}) / RUNS; + auto int32_used = benchmarker_int32.exec({A, B, {}}) / RUNS; + float computations = 2.f * M * K * N * 1e-6; + printf("run: {%zu{M} %zu{K} %zu{N}} float: %f ms %f Gflops int: %f ms " + "%f Gflops speedup_vs_fp32: %f, speedup_vs_int32: %f\n", + M, K, N, float_used, computations / float_used, int_used, + computations / int_used, float_used / int_used, + int32_used / int_used); + }; + + run(256, 256, 128); + + for (size_t k = 4; k <= 16; k *= 2) { + for (size_t m = 4; m <= 64; m *= 2) { + for (size_t n = 4; n <= 64; n *= 2) { + run(m, n, k); + } + } + std::cout << std::endl; + } +} + +TEST_F(AARCH64, BENCHMARK_GEMV) { + int exec_times = 10; + Benchmarker benchmarker_gemm(handle()); + 
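For orientation, the Checker-based tests in this file compare the optimized kernels against a reference implementation. Conceptually, the int8x8x32 GEMM being benchmarked computes the following; this is an illustrative, unpacked row-major sketch only, ignoring transposition and the MK4/MK8 packed formats:

#include <cstddef>
#include <cstdint>

// Naive reference for C[M x N] = A[M x K] * B[K x N] with int8 inputs and
// int32 accumulation, row-major, no transposition. Illustrative only: the
// real kernels use packed layouts (e.g. MK4) and NEON / dot-product
// instructions, but compute the same result.
static void gemm_s8s8s32_ref(const int8_t* A, const int8_t* B, int32_t* C,
                             size_t M, size_t N, size_t K) {
    for (size_t m = 0; m < M; ++m) {
        for (size_t n = 0; n < N; ++n) {
            int32_t acc = 0;
            for (size_t k = 0; k < K; ++k) {
                acc += static_cast<int32_t>(A[m * K + k]) *
                       static_cast<int32_t>(B[k * N + n]);
            }
            C[m * N + n] = acc;  // one multiply-add per k -> ~2*M*K*N ops total
        }
    }
}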
benchmarker_gemm.set_times(exec_times); + + float mod = 1000 * exec_times / 1e9; + auto run = [&](size_t M, size_t K, size_t N) { + float time = 1.f, perf = 1.f; + + std::cout << "GEMM: (" << M << ", " << K << ", " << N << ")" + << std::endl; + benchmarker_gemm.set_dtype(0, dtype::Float32()) + .set_dtype(1, dtype::Float32()); + time = benchmarker_gemm.exec({{M, K}, {K, N}, {}}); + perf = 2.f * M * K * N / time * mod; + std::cout << "gemm fp32, Performance is " << perf << " Gflops" + << std::endl; + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + benchmarker_gemm.set_dtype(0, dtype::Float16()) + .set_dtype(1, dtype::Float16()); + time = benchmarker_gemm.exec({{M, K}, {K, N}, {}}); + perf = 2.f * M * K * N / time * mod; + std::cout << "gemm fp16, Performance is " << perf << " Gflops" + << std::endl; +#endif + }; + + std::cout << "warm up:\n"; + for (int i = 0; i < 50; i++) { + benchmarker_gemm.set_dtype(0, dtype::Float32()) + .set_dtype(1, dtype::Float32()) + .set_display(false) + .exec({{256, 256}, {256, 256}, {}}); + benchmarker_gemm.set_display(true); + } + + // run gemv + for (size_t M : {1, 2, 3, 4, 5, 6, 7, 8, 64, 256}) + for (size_t K : {1, 2, 3, 4, 5, 6, 7, 8, 64, 256}) + for (size_t N : {112}) + run(M, K, N); +} + +#if __ARM_FEATURE_DOTPROD +TEST_F(AARCH64, BENCHMARK_TRANSPOSED_MATRIX_MUL_INT_8X8X32) { + constexpr size_t RUNS = 50; + param::MatrixMul param; + param.transposeA = param.transposeB = true; + Benchmarker benchmarker_int(handle()); + benchmarker_int.set_times(RUNS) + .set_dtype(0, dtype::Int8{}) + .set_dtype(1, dtype::Int8{}) + .set_dtype(2, {}) + .set_param(param) + .set_display(false); + Benchmarker benchmarker_float(handle()); + benchmarker_float.set_param(param).set_display(false).set_times(RUNS); + + auto run = [&](size_t M, size_t N, size_t K) { + auto int_used = benchmarker_int.exec({{K, M}, {N, K}, {}}) / RUNS; + auto float_used = benchmarker_float.exec({{K, M}, {N, K}, {}}) / RUNS; + float computations = 2.f * M * K * N * 1e-6; + printf("run: {%zu{M} %zu{K} %zu{N}} float: %f ms %f Gflops int: %f ms " + "%f Gflops speedup: %f\n", + M, K, N, float_used, computations / float_used, int_used, + computations / int_used, float_used / int_used); + }; + + run(256, 12 * 24, 256); + + for (size_t M : {8, 64, 112, 256}) { + for (size_t K : {8, 64, 112, 256}) { + for (size_t N : {8, 64, 112, 256}) { + run(M, N, K); + } + } + } +} + +TEST_F(AARCH64, BENCHMARK_GEMV_INT_8X8X32) { + constexpr size_t RUNS = 50; + param::MatrixMul param; + Benchmarker benchmarker_int(handle()); + benchmarker_int.set_times(RUNS) + .set_dtype(0, dtype::Int8{}) + .set_dtype(1, dtype::Int8{}) + .set_dtype(2, {}) + .set_param(param) + .set_display(false); + Benchmarker benchmarker_float(handle()); + benchmarker_float.set_display(false).set_times(RUNS); + + auto run = [&](size_t M, size_t N, size_t K) { + auto int_used = benchmarker_int.exec({{M, K}, {K, N}, {}}) / RUNS; + auto float_used = benchmarker_float.exec({{M, K}, {K, N}, {}}) / RUNS; + float computations = 2.f * M * K * N * 1e-6; + printf("run: {%zu{M} %zu{K} %zu{N}} float: %f ms %f Gflops int: %f ms " + "%f Gflops speedup: %f\n", + M, K, N, float_used, computations / float_used, int_used, + computations / int_used, float_used / int_used); + }; + + for (size_t M : {1, 2, 3, 4, 5, 6, 7, 8, 16, 32, 64, 256}) + for (size_t N : {1, 2, 3, 4, 5, 6, 7, 8, 16, 32, 64, 256}) + for (size_t K : {1, 2, 3, 4, 5, 6, 7, 8, 16, 32, 64, 256}) + run(M, N, K); +} + +#endif // __ARM_FEATURE_DOTPROD + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +TEST_F(AARCH64, 
BENCHMARK_MATRIX_MUL_F16_MK8) { + auto args = matrix_mul::get_benchmark_matmul_mk_packed_args(8); + matrix_mul::benchmark_with_contrast( + handle(), args, dtype::Float16{}, dtype::Float16{}, + dtype::Float16{}, "AARCH64_F16_MK8_8X8", + param::MatrixMul::Format::MK8, dtype::Float16{}, dtype::Float16{}, + dtype::Float16{}, "AARCH64_F16_K8X24X1"); +} +#endif + +TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_INT16x16x32) { + constexpr size_t RUNS = 50; + Benchmarker benchmarker_int(handle()); + benchmarker_int.set_times(RUNS) + .set_dtype(0, dtype::Int16{}) + .set_dtype(1, dtype::Int16{}) + .set_dtype(2, dtype::Int32{}) + .set_display(false); + Benchmarker benchmarker_float(handle()); + benchmarker_float.set_display(false).set_times(RUNS); + + auto run = [&](size_t M, size_t N, size_t K, int mask) { + param::MatrixMul param; + param.transposeA = mask & 0x1; + param.transposeB = mask & 0x2; + benchmarker_int.set_param(param); + benchmarker_float.set_param(param); + TensorShape A, B; + if (param.transposeA) { + A = TensorShape{K, M}; + } else { + A = TensorShape{M, K}; + } + if (param.transposeB) { + B = TensorShape{N, K}; + } else { + B = TensorShape{K, N}; + } + auto int_used = benchmarker_int.exec({A, B, {}}) / RUNS; + auto float_used = benchmarker_float.exec({A, B, {}}) / RUNS; + float computations = 2.f * M * K * N * 1e-6; + printf("run: {%zu{M} %zu{K} %zu{N} %d{TA} %d{TB}} " + "float: %f ms %f Gflops int: %f ms " + "%f Gflops speedup: %f\n", + M, K, N, param.transposeA, param.transposeB, float_used, + computations / float_used, int_used, computations / int_used, + float_used / int_used); + }; + + constexpr int mask = 4; + for (auto i = 0; i < mask; i++) { + for (size_t M : {8, 64, 112, 256}) { + for (size_t K : {8, 64, 112, 256}) { + for (size_t N : {8, 64, 112, 256}) { + run(M, N, K, i); + } + } + } + } +} + +TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_FP32_MK4) { + auto args = matrix_mul::get_benchmark_matmul_mk_packed_args(16); + matrix_mul::benchmark_with_contrast( + handle(), args, dtype::Float32{}, dtype::Float32{}, + dtype::Float32{}, "AARCH64_F32_MK4_4x16", + param::MatrixMul::Format::MK4, dtype::Float32{}, dtype::Float32{}, + dtype::Float32{}); +} + +TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_INT16x16x32_MK8) { + auto args = matrix_mul::get_benchmark_matmul_mk_packed_args(8); + matrix_mul::benchmark_with_contrast( + handle(), args, dtype::Int16{}, dtype::Int16{}, dtype::Int32{}, + "AARCH64_INT16X16X32_MK8_8X8", param::MatrixMul::Format::MK8, + dtype::Int16{}, dtype::Int16{}, dtype::Int32{}); +} + +TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_FP32_K8X12) { + constexpr size_t RUNS = 50; + param::MatrixMul param; + param.transposeA = param.transposeB = true; + Benchmarker benchmarker_k12x8(handle()); + Benchmarker benchmarker_k8x12(handle()); + benchmarker_k12x8.set_param(param).set_display(false).set_times(RUNS); + benchmarker_k8x12.set_param(param).set_display(false).set_times(RUNS); + benchmarker_k12x8.set_before_exec_callback( + AlgoChecker("AARCH64_F32K4X16X1")); + + benchmarker_k8x12.set_before_exec_callback( + AlgoChecker("AARCH64_F32K8X12X1")); + + auto run = [&](size_t M, size_t N, size_t K) { + auto k12x8_used = benchmarker_k12x8.exec({{K, M}, {N, K}, {}}) / RUNS; + auto k8x12_used = benchmarker_k8x12.exec({{K, M}, {N, K}, {}}) / RUNS; + float computations = 2.f * M * K * N * 1e-6; + printf("run: {%zu{M} %zu{K} %zu{N}} float k12x8: %f ms %f Gflops " + "k8x12: %f ms " + "%f Gflops speedup: %f\n", + M, K, N, k12x8_used, computations / k12x8_used, k8x12_used, + computations / k8x12_used, k12x8_used / 
k8x12_used); + }; + + run(256, 12 * 24, 256); + + for (size_t M : {8, 64, 112, 256}) { + for (size_t K : {8, 64, 112, 256}) { + for (size_t N : {8, 64, 112, 256}) { + run(M, N, K); + } + } + } +} + +TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_FP32_K8X12_NO_TRANS) { + constexpr size_t RUNS = 50; + param::MatrixMul param; + param.transposeA = param.transposeB = false; + Benchmarker benchmarker_k12x8(handle()); + Benchmarker benchmarker_k8x12(handle()); + benchmarker_k12x8.set_param(param).set_display(false).set_times(RUNS); + benchmarker_k8x12.set_param(param).set_display(false).set_times(RUNS); + benchmarker_k12x8.set_before_exec_callback( + AlgoChecker("AARCH64_F32K4X16X1")); + + benchmarker_k8x12.set_before_exec_callback( + AlgoChecker("AARCH64_F32K8X12X1")); + + auto run = [&](size_t M, size_t N, size_t K) { + auto k12x8_used = benchmarker_k12x8.exec({{M, K}, {K, N}, {}}) / RUNS; + auto k8x12_used = benchmarker_k8x12.exec({{M, K}, {K, N}, {}}) / RUNS; + float computations = 2.f * M * K * N * 1e-6; + printf("run: {%zu{M} %zu{K} %zu{N}} float k12x8: %f ms %f Gflops " + "k8x12: %f ms " + "%f Gflops speedup: %f\n", + M, K, N, k12x8_used, computations / k12x8_used, k8x12_used, + computations / k8x12_used, k12x8_used / k8x12_used); + }; + + run(256, 12 * 24, 256); + + for (size_t M : {8, 64, 112, 256}) { + for (size_t K : {8, 64, 112, 256}) { + for (size_t N : {8, 64, 112, 256}) { + run(M, N, K); + } + } + } +} + +#endif // MEGDNN_WITH_BENCHMARK + +// vim: syntax=cpp.doxygen diff --git a/dnn/test/aarch64/pooling.cpp b/dnn/test/aarch64/pooling.cpp new file mode 100644 index 00000000..b890d329 --- /dev/null +++ b/dnn/test/aarch64/pooling.cpp @@ -0,0 +1,33 @@ +/** + * \file dnn/test/aarch64/pooling.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "test/aarch64/fixture.h" + +#include "test/common/pooling.h" +#include "test/common/checker.h" + +namespace megdnn { +namespace test { + +TEST_F(AARCH64, POOLING) +{ + auto args = pooling::get_args(); + for (auto &&arg: args) { + Checker checker(handle()); + checker.set_param(arg.param).exec(TensorShapeArray{ + arg.ishape, {}}); + } +} + +} // namespace test +} // namespace megdnn +// vim: syntax=cpp.doxygen + + diff --git a/dnn/test/aarch64/rotate.cpp b/dnn/test/aarch64/rotate.cpp new file mode 100644 index 00000000..cedd45fd --- /dev/null +++ b/dnn/test/aarch64/rotate.cpp @@ -0,0 +1,73 @@ +/** + * \file dnn/test/aarch64/rotate.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +#include "test/common/rotate.h" +#include "test/common/checker.h" +#include "test/common/benchmarker.h" + +#include "test/aarch64/fixture.h" + +namespace megdnn { +namespace test { + +TEST_F(AARCH64, ROTATE) +{ + using namespace rotate; + std::vector args = get_args(); + Checker checker(handle()); + + for (auto &&arg: args) { + checker.set_param(arg.param) + .set_dtype(0, arg.dtype) + .set_dtype(1, arg.dtype) + .execs({arg.src, {}}); + } +} + +TEST_F(AARCH64, BENCHMARK_ROTATE) +{ + using namespace rotate; + using Param = param::Rotate; + +#define BENCHMARK_PARAM(benchmarker) \ + benchmarker.set_param(param); \ + benchmarker.set_dtype(0, dtype::Uint8()); + + auto run = [&](const TensorShapeArray& shapes, Param param) { + auto handle_naive = create_cpu_handle(2); + Benchmarker benchmarker(handle()); + Benchmarker benchmarker_naive(handle_naive.get()); + + BENCHMARK_PARAM(benchmarker); + BENCHMARK_PARAM(benchmarker_naive); + for (auto&& shape : shapes) { + printf("execute %s: current---naive\n", shape.to_string().c_str()); + benchmarker.execs({shape, {}}); + benchmarker_naive.execs({shape, {}}); + } + }; + + Param param; + TensorShapeArray shapes = { + {1, 100, 100, 1}, + {2, 100, 100, 3}, + }; + + param.clockwise = true; + run(shapes, param); + + param.clockwise = false; + run(shapes, param); +} + + +} // namespace test +} // namespace megdnn +// vim: syntax=cpp.doxygen diff --git a/dnn/test/aarch64/warp_perspective.cpp b/dnn/test/aarch64/warp_perspective.cpp new file mode 100644 index 00000000..ea4ee66e --- /dev/null +++ b/dnn/test/aarch64/warp_perspective.cpp @@ -0,0 +1,206 @@ +/** + * \file dnn/test/aarch64/warp_perspective.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include +#include +#include "test/aarch64/fixture.h" + +#include "test/common/benchmarker.h" +#include "test/common/checker.h" +#include "test/common/random_state.h" +#include "test/common/rng.h" + +#include "test/common/warp_perspective.h" + +namespace megdnn { +namespace test { + +TEST_F(AARCH64, WARP_PERSPECTIVE_CV) { + //! Just for the format NHWC + Checker checker(handle()); + param::WarpPerspective param; + class ResizeMatRNG : public RNG { + void gen(const TensorND& tensor_) override { + auto& gen = RandomState::generator(); + std::uniform_real_distribution pdist3(1.9f, 3.1f); + std::uniform_real_distribution pdist(0.9f, 1.1f); + std::uniform_real_distribution pdisth(0.4f, 0.6f); + std::uniform_real_distribution ndist(-1.1f, -0.9f); + std::uniform_real_distribution ndist3(-3.1f, -1.9f); + std::uniform_real_distribution ndisth(-0.6f, -0.4f); + std::uniform_int_distribution dice(0, 5); + float* ptr = tensor_.ptr(); + auto N = tensor_.layout.shape[0]; + for (size_t n = 0; n < N; ++n) { + for (size_t i = 0; i < 9; ++i) { + switch (dice(gen)) { + case 0: + ptr[i] = pdist3(gen); + break; + case 1: + ptr[i] = pdist(gen); + break; + case 2: + ptr[i] = pdisth(gen); + break; + case 3: + ptr[i] = ndist(gen); + break; + case 4: + ptr[i] = ndist3(gen); + break; + case 5: + ptr[i] = ndisth(gen); + break; + } + } + // is resize? 
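+ // For odd-numbered samples the branch below zeroes the skew terms (m01, m10)
+ // and the perspective terms (m20, m21), so the random matrix degenerates to a
+ // pure scale-and-translate mapping, i.e. an effective resize case.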
+ if (n & 1) { + ptr[1] = 0; + ptr[3] = 0; + ptr[6] = ptr[7] = 0; + } + ptr += 9; + } + } + } rng; + + using BMode = param::WarpPerspective::BorderMode; + param.format = param::WarpPerspective::Format::NHWC; + // add for nearest test + param.imode = param::WarpPerspective::InterpolationMode::NEAREST; + for (auto mode : {BMode::REFLECT_101, BMode::REPLICATE, BMode::REFLECT, + BMode::WRAP, BMode::CONSTANT}) { + param.bmode = mode; + param.border_val = 1.737; + checker.set_param(param); + checker.exec({{10, 128, 108, 3}, {10, 3, 3}, {10, 56, 128, 3}}); + } + // resize nan case + UniformFloatRNG rng_zero(0, 0); + checker.set_rng(1, &rng_zero); + { + param.bmode = BMode::CONSTANT; + param.border_val = 1.737; + checker.set_param(param); + checker.exec({{1000, 2, 10, 3}, {1000, 3, 3}, {1000, 2, 12, 3}}); + } + + // add linear test + param.imode = param::WarpPerspective::InterpolationMode::INTER_LINEAR; + for (auto mode : {BMode::REFLECT_101, BMode::REPLICATE, BMode::REFLECT, + BMode::WRAP, BMode::CONSTANT}) { + param.bmode = mode; + param.border_val = 1.737; + checker.set_param(param); + checker.exec({{10, 128, 108, 3}, {10, 3, 3}, {10, 56, 128, 3}}); + } + // resize nan case + checker.set_rng(1, &rng_zero); + { + param.bmode = BMode::CONSTANT; + param.border_val = 1.737; + checker.set_param(param); + checker.exec({{1000, 2, 10, 3}, {1000, 3, 3}, {1000, 2, 12, 3}}); + } + + auto args = warp_perspective::get_cv_args(); + for (auto&& arg : args) { + checker.set_param(arg.param) + .set_dtype(0, dtype::Uint8()) + .set_dtype(1, dtype::Float32()) + .set_dtype(2, dtype::Uint8()) + .execs({arg.src, arg.trans, arg.dst}); + } + + for (auto&& arg : args) { + checker.set_param(arg.param) + .set_dtype(0, dtype::Float32()) + .set_dtype(1, dtype::Float32()) + .set_dtype(2, dtype::Float32()) + .execs({arg.src, arg.trans, arg.dst}); + } +} + +#if MEGDNN_WITH_BENCHMARK +TEST_F(AARCH64, BENCHMARK_WARP_PERSPECTIVE_FORWARD) { + Benchmarker benchmarker(handle()); + auto handle_naive = create_cpu_handle(2); + Benchmarker benchmarker_naive(handle_naive.get()); + constexpr size_t NR_RUN = 50; + + using BMode = param::WarpPerspective::BorderMode; + using IMode = param::WarpPerspective::InterpolationMode; + + WarpPerspective::Param param; + param.border_val = 0.3f; + param.format = param::WarpPerspective::Format::NHWC; + + auto run = [&](size_t N, size_t C, size_t IH, size_t IW, size_t OH, + size_t OW, size_t scale) { + printf("src={%zu, %zu, %zu, %zu}, dst={%zu, %zu, %zu, %zu}\n", N, IH, + IW, C, N, OH, OW, C); + auto time_ms = + benchmarker.exec({{N, IH, IW, C}, {N, 3, 3}, {N, OH, OW, C}}) / + NR_RUN; + auto time_naive_ms = + benchmarker_naive.exec( + {{N, IH, IW, C}, {N, 3, 3}, {N, OH, OW, C}}) / + NR_RUN; + auto bandwidth = N * C * (scale * OH * OW) * dtype::Float32().size(); + printf("aarch64: %.3f, perf: %.3f GBPS naive: %.3f, perf %.3f GBPS " + "speedup: %f\n", + time_ms, bandwidth / time_ms / 1e6, time_naive_ms, + bandwidth / time_naive_ms / 1e6, time_naive_ms / time_ms); + }; + + std::vector bmodestringmap = { + "REPLICATE", "REFLECT", "REFLECT_101", "WARP", "CONSTANT"}; + + std::vector imodestringmap = {"NEAREST", "INTER_LINEAR"}; + size_t scales[2] = {2, 5}; + + for (auto imode : {IMode::NEAREST, IMode::INTER_LINEAR}) { + for (auto bmode : {BMode::REFLECT_101, BMode::REPLICATE, BMode::REFLECT, + BMode::WRAP, BMode::CONSTANT}) { + param.imode = imode; + param.bmode = bmode; + benchmarker.set_param(param).set_display(false).set_times(NR_RUN); + benchmarker_naive.set_param(param).set_display(false).set_times( + 
NR_RUN); + size_t scale = scales[(int)imode]; + printf("\n\n\n warpperspective InterpolationMode::%s " + "BorderMode::%s start\n", + imodestringmap[(int)imode].c_str(), + bmodestringmap[(int)bmode].c_str()); + for (auto&& shape : + std::vector>{{700, 490}, + {500, 334}, + {472, 342}, + {448, 306}, + {626, 412}, + {140, 144}, + {120, 128}, + {180, 176}}) { + for (size_t ch : {1, 2, 3}) { + run(1, ch, shape.first, shape.second, 256, 256, scale); + } + } + } + } +} + +#endif + +} // namespace test +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/test/arm_common/conv_bias.cpp b/dnn/test/arm_common/conv_bias.cpp new file mode 100644 index 00000000..620e12a8 --- /dev/null +++ b/dnn/test/arm_common/conv_bias.cpp @@ -0,0 +1,1750 @@ +/** + * \file dnn/test/arm_common/conv_bias.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "megdnn/dtype.h" +#include "test/arm_common/fixture.h" + +#include "megdnn/opr_param_defs.h" +#include "megdnn/oprs.h" +#include "src/fallback/conv_bias/common.h" +#include "test/common/benchmarker.h" +#include "test/common/checker.h" +#include "test/common/conv_bias.h" +#include "test/common/rng.h" +#include "test/common/tensor.h" +#include "test/common/workspace_wrapper.h" + +using namespace megdnn; +using namespace test; +using namespace conv_bias; + +//! TODO this algo current does not support multithread +TEST_F(ARM_COMMON, CONVBIAS_INT8_INT8_INT16_STRIDE2F2) { + checker_conv_bias_int8x8x16(get_conv_bias_args({2}, 2, true, true, true), + handle(), "I8816STRD2F2"); +} + +TEST_F(ARM_COMMON, CONV_BIAS_MATMUL) { + using namespace conv_bias; + std::vector args = get_quantized_args(); + Checker checker(handle()); + checker.set_before_exec_callback( + conv_bias::ConvBiasAlgoChecker("S8MATMUL")); +#if MEGDNN_ARMV7 + checker.set_epsilon(1); +#endif + UniformIntRNG rng{-50, 50}; + for (auto&& arg : args) { + if (arg.bias.ndim == 4 && arg.bias[2] != 1 && arg.bias[3] != 1) + continue; + checker.set_dtype(0, dtype::QuantizedS8(0.41113496f)) + .set_dtype(1, dtype::QuantizedS8(0.01887994f)) + .set_dtype(2, dtype::QuantizedS32(0.41113496f * 0.01887994f)) + .set_dtype(4, dtype::QuantizedS8(0.49550694f)) + .set_rng(0, &rng) + .set_rng(1, &rng) + .set_rng(2, &rng) + .set_param(arg.param) + .execs({arg.src, arg.filter, arg.bias, {}, {}}); + } +} + +TEST_F(ARM_COMMON, CONV_BIAS_MATMUL_QU8) { + using namespace conv_bias; + std::vector args = get_quantized_args(); + Checker checker(handle()); + checker.set_before_exec_callback( + conv_bias::ConvBiasAlgoChecker("QU8MATMUL")); + + UniformIntRNG rng{0, 127}; + for (auto&& arg : args) { + if (arg.bias.ndim == 4 && arg.bias[2] != 1 && arg.bias[3] != 1) + continue; + checker.set_dtype(0, dtype::Quantized8Asymm(2.5f, + static_cast(127))) + .set_dtype(1, dtype::Quantized8Asymm(2.7f, + static_cast(126))) + .set_dtype(2, dtype::QuantizedS32(6.75f)) + .set_dtype(4, dtype::Quantized8Asymm(60.25f, + static_cast(125))) + .set_rng(0, &rng) + .set_rng(1, &rng) + .set_rng(2, &rng) + .set_param(arg.param) + .execs({arg.src, arg.filter, arg.bias, {}, {}}); + } +} + +#if MEGDNN_WITH_BENCHMARK + +static void benchmark_convbias(Handle* handle) { + constexpr size_t RUNS = 30; + + Benchmarker 
benchmarker_int(handle); + benchmarker_int.set_times(RUNS) + .set_dtype(0, dtype::QuantizedS8(2.5)) + .set_dtype(1, dtype::QuantizedS8(2.5)) + .set_dtype(2, dtype::QuantizedS32(6.25)) + .set_dtype(4, dtype::QuantizedS8(60.25)) + .set_display(false); + benchmarker_int.set_before_exec_callback( + conv_bias::ConvBiasAlgoChecker( + "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384")); + + Benchmarker benchmarker_float(handle); + benchmarker_float.set_display(false).set_times(RUNS); + benchmarker_float.set_before_exec_callback( + conv_bias::ConvBiasAlgoChecker(".+")); + + Benchmarker benchmarker_int_nchw44(handle); + benchmarker_int_nchw44.set_times(RUNS) + .set_dtype(0, dtype::QuantizedS8(2.5)) + .set_dtype(1, dtype::QuantizedS8(2.5)) + .set_dtype(2, dtype::QuantizedS32(6.25)) + .set_dtype(4, dtype::QuantizedS8(60.25)) + .set_display(false); + benchmarker_int_nchw44.set_before_exec_callback( + conv_bias::ConvBiasAlgoChecker(".+")); + + auto run = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, + size_t FS, size_t stride, bool input_nchw = false) { + param::ConvBias param; + param.nonlineMode = param::ConvBias::NonlineMode::RELU; + param.stride_h = stride; + param.stride_w = stride; + + param.pad_h = FS / 2; + param.pad_w = FS / 2; + auto OH = (H + 2 * param.pad_h - FS) / + static_cast(param.stride_h) + + 1; + auto OW = (W + 2 * param.pad_w - FS) / + static_cast(param.stride_w) + + 1; + TensorShape src({N, IC, H, W}), filter({OC, IC, FS, FS}), + bias({1, OC, 1, 1}), dst({N, OC, OH, OW}); + param.format = param::ConvBias::Format::NCHW; + auto int_used = benchmarker_int.set_param(param).exec( + {src, filter, bias, {}, dst}) / + RUNS; + auto float_used = benchmarker_float.set_param(param).exec( + {src, filter, bias, {}, dst}) / + RUNS; + param.format = param::ConvBias::Format::NCHW44; + src = {N, IC / 4, H, W, 4}; + filter = {OC / 4, IC / 4, FS, FS, 4, 4}; + if (input_nchw) { + src = {N, IC, H, W}; + filter = {OC / 4, FS, FS, IC, 4}; + } + + bias = {1, OC / 4, 1, 1, 4}; + dst = {N, OC / 4, OH, OW, 4}; + auto int_nchw44_used = benchmarker_int_nchw44.set_param(param).exec( + {src, filter, bias, {}, dst}) / + RUNS; + + float computations = IC * (FS * FS) * dst.total_nr_elems() * 2 * 1e-6; + printf("run: %s %s %s->%s \n", src.to_string().c_str(), + filter.to_string().c_str(), bias.to_string().c_str(), + dst.to_string().c_str()); + printf("float: %f ms %f Gflops, ", float_used, + computations / float_used); + printf("int_nchw: %f ms %f Gflops, ", int_used, + computations / int_used); + printf("int_nchw44: %f ms %f Gflops %f speedup, ", int_nchw44_used, + computations / int_nchw44_used, int_used / int_nchw44_used); + printf("\n"); + }; + run(1, 3, 32, 224, 224, 3, 2, true); + run(1, 3, 64, 224, 224, 5, 2, true); + run(1, 3, 64, 224, 224, 7, 2, true); + run(1, 3, 32, 224, 224, 7, 2, true); + for (size_t stride : {1, 2}) { + printf("stride %zu\n", stride); + for (size_t filter_size : {2, 3, 5, 7}) { + for (size_t img_size : {32}) { + for (size_t channel : {8, 16, 32, 64, 128, 256}) { + run(1, channel, channel, img_size, img_size, filter_size, + stride, false); + } + } + } + } +} +TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_NCHW44) { + benchmark_convbias(handle()); +} +TEST_F(ARM_COMMON_MULTI_THREADS, BENCHMARK_CONVBIAS_NCHW44) { + benchmark_convbias(handle()); +} +#endif +TEST_F(ARM_COMMON, CONV_BIAS_MATMUL_QS8) { + using namespace conv_bias; + std::vector args = get_quantized_args(); + Checker checker(handle()); + checker.set_before_exec_callback( + conv_bias::ConvBiasAlgoChecker("S8MATMUL")); + +#if 
MEGDNN_ARMV7 + checker.set_epsilon(1); +#endif + UniformIntRNG rng{0, 255}; + for (auto&& arg : args) { + if (arg.bias.ndim == 4 && arg.bias[2] != 1 && arg.bias[3] != 1) + continue; + checker.set_dtype(0, dtype::QuantizedS8(2.5f)) + .set_dtype(1, dtype::QuantizedS8(2.7f)) + .set_dtype(2, dtype::QuantizedS32(6.75f)) + .set_dtype(4, dtype::QuantizedS8(60.25f)) + .set_rng(0, &rng) + .set_rng(1, &rng) + .set_rng(2, &rng) + .set_param(arg.param) + .execs({arg.src, arg.filter, arg.bias, {}, {}}); + } +} + +#if MEGDNN_ARMV7 +TEST_F(ARM_COMMON, CONV_BIAS_RESCALE_OP) { + using namespace conv_bias; + + Checker checker(handle()); + checker.set_before_exec_callback( + conv_bias::ConvBiasAlgoChecker("S8MATMUL")); + checker.set_epsilon(1).set_max_avg_error(1e-2).set_max_avg_biased_error( + 1e-3); + UniformIntRNG rng{-128, 127}; + checker.set_dtype(0, dtype::QuantizedS8(0.41113496f)) + .set_dtype(1, dtype::QuantizedS8(0.01887994f)) + .set_dtype(2, dtype::QuantizedS32(0.41113496f * 0.01887994f)) + .set_dtype(4, dtype::QuantizedS8(0.49550694f)) + .set_rng(0, &rng) + .set_rng(1, &rng) + .set_rng(2, &rng); + + param::ConvBias param; + param.stride_h = 1; + param.stride_w = 1; + param.pad_h = 0; + param.pad_w = 0; + param.nonlineMode = NonlineMode::IDENTITY; + + //! Unary op + checker.set_param(param).exec({TensorShape{2, 1, 128, 128}, + TensorShape{16, 1, 2, 2}, + TensorShape{}, + TensorShape{}, + {}}); + //! Binary op + checker.set_param(param).exec({TensorShape{2, 1, 128, 128}, + TensorShape{16, 1, 2, 2}, + TensorShape{1, 16, 1, 1}, + TensorShape{}, + {}}); +} +#endif + +#if MEGDNN_WITH_BENCHMARK + +void benchmark_im2col(const char* algo_name, const char* im2col_name, + Handle* handle, size_t kernel, size_t pack_size = 1) { + auto&& args = get_winograd_benchmark_args(kernel, pack_size); + using namespace conv_bias; + constexpr size_t RUN = 10; + Benchmarker benchmark(handle); + benchmark.set_display(false); + benchmark.set_times(RUN); + + Benchmarker benchmark_im2col(handle); + benchmark_im2col.set_display(false); + benchmark_im2col.set_times(RUN); + + for (auto&& arg : args) { + TensorLayout dst_layout; + auto opr = handle->create_operator(); + opr->param() = arg.param; + opr->deduce_layout({arg.src, dtype::Float32()}, + {arg.filter, dtype::Float32()}, + {arg.bias, dtype::Float32()}, {}, dst_layout); + //! 
dst.nr_elems * IC * FH * FW * 2 + float computations = dst_layout.total_nr_elems() * arg.filter[1] * + arg.filter[2] * arg.filter[3] * 2.0 / + (1024 * 1024 * 1024) * 1e3; + + benchmark.set_param(arg.param); + auto used = algo_benchmark(benchmark, + {arg.src, arg.filter, {}, {}, {}}, + algo_name) / + RUN; + benchmark_im2col.set_param(arg.param); + auto used_im2col = + algo_benchmark(benchmark_im2col, + {arg.src, arg.filter, {}, {}, {}}, + im2col_name) / + RUN; + + printf("%s %s: normal: %f ms %f Gflops im2col: %f ms %f GFlops " + "speedup: " + "%f\n", + arg.src.to_string().c_str(), arg.filter.to_string().c_str(), + used, computations / used, used_im2col, + computations / used_im2col, used / used_im2col); + } +} + +void benchmark_im2col_single_algo(const char* im2col_name, Handle* handle, + size_t kernel, size_t pack_size = 1) { + std::vector args; + auto pack = [&](size_t oc, size_t ic, size_t w, size_t h, size_t kernel, + size_t p) { + if (ic % pack_size != 0 || oc % pack_size != 0) + return; + if (w + 2 * p < kernel || h + 2 * p < kernel) + return; + param::ConvBias param; + param.stride_h = 1; + param.stride_w = 1; + param.pad_h = p; + param.pad_w = p; + + args.push_back(conv_bias::TestArg{param, + TensorShape{1, ic, h, w}, + TensorShape{oc, ic, kernel, kernel}, + {1, oc, 1, 1}}); + }; + pack(1, 64, 100, 100, kernel, 1); + pack(8, 64, 100, 100, kernel, 1); + pack(16, 64, 100, 100, kernel, 1); + pack(32, 64, 100, 100, kernel, 1); + pack(64, 64, 100, 100, kernel, 1); + pack(128, 64, 100, 100, kernel, 1); + pack(256, 64, 100, 100, kernel, 1); + pack(512, 64, 100, 100, kernel, 1); + pack(1024, 64, 100, 100, kernel, 1); + pack(1, 64, 10, 10, kernel, 1); + pack(8, 64, 10, 10, kernel, 1); + pack(16, 64, 10, 10, kernel, 1); + pack(32, 64, 10, 10, kernel, 1); + pack(64, 64, 10, 10, kernel, 1); + pack(128, 64, 10, 10, kernel, 1); + pack(256, 64, 10, 10, kernel, 1); + pack(512, 64, 10, 10, kernel, 1); + pack(1024, 64, 10, 10, kernel, 1); + pack(1, 16, 10, 10, kernel, 1); + pack(8, 16, 10, 10, kernel, 1); + pack(16, 16, 10, 10, kernel, 1); + pack(32, 16, 10, 10, kernel, 1); + pack(64, 16, 10, 10, kernel, 1); + pack(128, 16, 10, 10, kernel, 1); + pack(256, 16, 10, 10, kernel, 1); + pack(512, 16, 10, 10, kernel, 1); + pack(1024, 16, 10, 10, kernel, 1); + + using namespace conv_bias; + constexpr size_t RUN = 20; + + Benchmarker benchmark_im2col(handle); + benchmark_im2col.set_display(false); + benchmark_im2col.set_times(RUN); + + for (auto&& arg : args) { + TensorLayout dst_layout; + auto opr = handle->create_operator(); + opr->param() = arg.param; + opr->deduce_layout({arg.src, dtype::Float32()}, + {arg.filter, dtype::Float32()}, + {arg.bias, dtype::Float32()}, {}, dst_layout); + //! 
dst.nr_elems * IC * FH * FW * 2 + float computations = dst_layout.total_nr_elems() * arg.filter[1] * + arg.filter[2] * arg.filter[3] * 2.0 / + (1024 * 1024 * 1024) * 1e3; + + benchmark_im2col.set_param(arg.param); + auto used_im2col = + algo_benchmark(benchmark_im2col, + {arg.src, arg.filter, {}, {}, {}}, + im2col_name) / + RUN; + + printf("%s %s: im2col: %f ms %f GFlops \n", arg.src.to_string().c_str(), + arg.filter.to_string().c_str(), used_im2col, + computations / used_im2col); + } +} + +void BENCHMARK_IM2COL_NCHW44_VS_NCHW(const char* algo_name, + const char* im2col_name, Handle* handle, + size_t kernel, size_t pack_size = 1) { + auto&& args = get_winograd_benchmark_args(kernel, pack_size); + using namespace conv_bias; + constexpr size_t RUN = 10; + Benchmarker benchmark(handle); + benchmark.set_display(false); + benchmark.set_times(RUN); + benchmark.set_dtype(0, dtype::Int8()); + benchmark.set_dtype(1, dtype::Int8()); + benchmark.set_dtype(2, dtype::Int32()); + benchmark.set_dtype(4, dtype::Int32()); + + Benchmarker benchmark_im2col(handle); + benchmark_im2col.set_display(false); + benchmark_im2col.set_times(RUN); + benchmark_im2col.set_dtype(0, dtype::Int8()); + benchmark_im2col.set_dtype(1, dtype::Int8()); + benchmark_im2col.set_dtype(2, dtype::Int32()); + benchmark_im2col.set_dtype(4, dtype::Int32()); + + for (auto&& arg : args) { + TensorLayout dst_layout; + auto opr = handle->create_operator(); + opr->param() = arg.param; + opr->deduce_layout({arg.src, dtype::Float32()}, + {arg.filter, dtype::Float32()}, + {arg.bias, dtype::Float32()}, {}, dst_layout); + //! dst.nr_elems * IC * FH * FW * 2 + float computations = dst_layout.total_nr_elems() * arg.filter[1] * + arg.filter[2] * arg.filter[3] * 2.0 / + (1024 * 1024 * 1024) * 1e3; + std::vector nchw44param; + + benchmark.set_param(arg.param); + auto used = algo_benchmark(benchmark, + {arg.src, arg.filter, {}, {}, {}}, + algo_name) / + RUN; + + arg.param.nonlineMode = param::ConvBias::NonlineMode::IDENTITY; + arg.param.format = param::ConvBias::Format::NCHW44; + benchmark_im2col.set_param(arg.param); + nchw44param.push_back(conv_bias::TestArg{ + arg.param, + TensorShape{arg.src.shape[0], arg.src.shape[1] / 4, arg.src[2], + arg.src.shape[3], 4}, + TensorShape{arg.filter.shape[0] / 4, arg.filter.shape[1] / 4, + kernel, kernel, 4, 4}, + TensorShape{}}); + + auto used_im2col = + algo_benchmark( + benchmark_im2col, + {nchw44param[0].src, nchw44param[0].filter, {}, {}, {}}, + im2col_name) / + RUN; + printf("nchw44 shape src %s filter %s\n", + nchw44param[0].src.to_string().c_str(), + nchw44param[0].filter.to_string().c_str()); + printf("%s %s: normal: %f ms %f Gflops im2col: %f ms %f GFlops " + "speedup: " + "%f\n", + arg.src.to_string().c_str(), arg.filter.to_string().c_str(), + used, computations / used, used_im2col, + computations / used_im2col, used / used_im2col); + } +} + +TEST_F(ARM_COMMON, BENCHMARK_NCHW_VS_NCHW44_INT8x8x32) { + printf("=========================compare " + "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16, " + "IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16 \n"); + BENCHMARK_IM2COL_NCHW44_VS_NCHW("IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16", + "IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16", + handle(), 3, 4); +} + +TEST_F(ARM_COMMON, BENCHMARK_GROUP_CONVBIAS_QUANTIZED) { + constexpr size_t RUNS = 50; + param::ConvBias param; + param.sparse = param::ConvBias::Sparse::GROUP; + param.nonlineMode = param::ConvBias::NonlineMode::RELU; + Benchmarker benchmarker_int(handle()); + benchmarker_int.set_times(RUNS) + .set_dtype(0, 
dtype::QuantizedS8(2.5f)) + .set_dtype(1, dtype::QuantizedS8(2.5f)) + .set_dtype(2, dtype::QuantizedS32(6.25f)) + .set_dtype(4, dtype::QuantizedS8(40.25f)) + .set_display(false); + Benchmarker benchmarker_float(handle()); + benchmarker_float.set_display(false).set_times(RUNS); + + auto run = [&](size_t N, size_t GROUP, size_t IC, size_t OC, size_t H, + size_t W, size_t FS, size_t STRD) { + megdnn_assert(IC % GROUP == 0 && OC % GROUP == 0); + TensorShape src({N, IC, H, W}), + filter({GROUP, OC / GROUP, IC / GROUP, FS, FS}), + bias({1, OC, 1, 1}), dst({N, OC, H / STRD, W / STRD}); + param.pad_h = FS / 2; + param.pad_w = FS / 2; + param.stride_h = STRD; + param.stride_w = STRD; + auto int_used = benchmarker_int.set_param(param).exec( + {src, filter, bias, {}, dst}) / + RUNS; + auto float_used = benchmarker_float.set_param(param).exec( + {src, filter, bias, {}, dst}) / + RUNS; + float computations = (IC / GROUP * FS * FS * dst.total_nr_elems() * 2 + + dst.total_nr_elems()) * + 1e-6; + printf("run: %s %s %s->%s \nfloat: %f ms %f Gflops int: %f ms " + "%f Gflops speedup: %f\n", + src.to_string().c_str(), filter.to_string().c_str(), + bias.to_string().c_str(), dst.to_string().c_str(), float_used, + computations / float_used, int_used, computations / int_used, + float_used / int_used); + }; + + run(1, 1, 28, 28, 28, 28, 3, 1); + run(1, 68, 68, 68, 14, 14, 3, 2); + run(1, 96, 96, 96, 14, 14, 3, 2); + run(1, 100, 100, 100, 7, 7, 3, 1); +} +#endif + +#if MEGDNN_WITH_BENCHMARK +TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_MATMUL) { + constexpr size_t RUNS = 10; + param::ConvBias param; + param.stride_h = 1; + param.stride_w = 1; + param.nonlineMode = param::ConvBias::NonlineMode::RELU; + Benchmarker benchmarker(handle()), benchmarker_fused(handle()); + benchmarker.set_times(RUNS) + .set_dtype(0, dtype::QuantizedS8(2.5f)) + .set_dtype(1, dtype::QuantizedS8(2.5f)) + .set_dtype(2, dtype::QuantizedS32(6.25f)) + .set_dtype(4, dtype::QuantizedS8(40.25f)) + .set_display(false); + benchmarker_fused.set_times(RUNS) + .set_dtype(0, dtype::QuantizedS8(2.5f)) + .set_dtype(1, dtype::QuantizedS8(2.5f)) + .set_dtype(2, dtype::QuantizedS32(6.25f)) + .set_dtype(4, dtype::QuantizedS8(40.25f)) + .set_display(false); + benchmarker_fused.set_before_exec_callback( + conv_bias::ConvBiasAlgoChecker("S8MATMUL")); + + auto run = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, + size_t FS) { + TensorShape src({N, IC, H, W}), filter({OC, IC, FS, FS}), + bias({1, OC, 1, 1}), dst({N, OC, H, W}); + param.pad_h = FS / 2; + param.pad_w = FS / 2; + auto default_used = benchmarker.set_param(param).exec( + {src, filter, bias, {}, dst}) / + RUNS; + auto fused_used = benchmarker_fused.set_param(param).exec( + {src, filter, bias, {}, dst}) / + RUNS; + float computations = + IC * (FS * FS + 1) * dst.total_nr_elems() * 2 * 1e-6; + printf("run: %s %s %s->%s \ndefault: %f ms %f Gflops fused: %f ms " + "%f Gflops speedup: %f\n", + src.to_string().c_str(), filter.to_string().c_str(), + bias.to_string().c_str(), dst.to_string().c_str(), default_used, + computations / default_used, fused_used, + computations / fused_used, default_used / fused_used); + }; + + run(1, 128, 128, 32, 32, 3); + + for (size_t IC : {36, 48}) { + for (size_t OC : {36, 48, 64}) { + for (size_t size : {56, 128, 256}) { + for (size_t FS : {1, 3, 5}) { + run(1, IC, OC, size, size, FS); + } + } + } + } +} +#endif +#if MEGDNN_WITH_BENCHMARK + +TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F23) { +#if MEGDNN_AARCH64 + benchmark_winograd("WINOGRAD:AARCH64_F32:1:2", handle(), 3); 
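+ // The winograd algo name appears to follow
+ // WINOGRAD:<matmul algo>:<pack size>:<output tile>, so ":1:2" selects F(2,3)
+ // without block packing, while the MK4/MK8 variants used in the tests below
+ // pick packed layouts and larger output tiles.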
+#else + benchmark_winograd("WINOGRAD:ARMV7_F32_:1:2", handle(), 3); +#endif +} + +TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F23_4x4) { +#if MEGDNN_AARCH64 + benchmark_winograd("WINOGRAD:AARCH64_F32_MK4_4x16:4:2", handle(), 3, 4); +#else + benchmark_winograd("WINOGRAD:ARMV7_F32_MK4_4x8:4:2", handle(), 3, 4); +#endif +} + +TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F63) { +#if MEGDNN_AARCH64 + benchmark_winograd("WINOGRAD:AARCH64_F32K8X12X1:1:6", handle(), 3); +#else + benchmark_winograd("WINOGRAD:ARMV7_F32:1:6", handle(), 3); +#endif +} + +TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F63_4x4) { +#if MEGDNN_AARCH64 + benchmark_winograd("WINOGRAD:AARCH64_F32_MK4_4x16:4:6", handle(), 3, 4); +#else + benchmark_winograd("WINOGRAD:ARMV7_F32_MK4_4x8:4:6", handle(), 3, 4); +#endif +} + +TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F54) { +#if MEGDNN_AARCH64 + benchmark_winograd("WINOGRAD:AARCH64_F32K8X12X1:1:5", handle(), 4); +#else + benchmark_winograd("WINOGRAD:ARMV7_F32:1:5", handle(), 4); +#endif +} + +TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F45) { +#if MEGDNN_AARCH64 + benchmark_winograd("WINOGRAD:AARCH64_F32K8X12X1:1:4", handle(), 5); +#else + benchmark_winograd("WINOGRAD:ARMV7_F32:1:4", handle(), 5); +#endif +} + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F16_F23) { +#if MEGDNN_AARCH64 + benchmark_winograd_fp16("WINOGRAD:AARCH64_F32_MK4_4x16:4:2", + "WINOGRAD:AARCH64_F16_K8X24X1:1:6", handle(), 3, 4); +#else + benchmark_winograd_fp16("WINOGRAD:ARMV7_F32:1:2", + "WINOGRAD:AARCH32_F16_K4X16X1:1:2", handle(), 3); +#endif +} + +TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F16_F45) { +#if MEGDNN_AARCH64 + benchmark_winograd_fp16("WINOGRAD:AARCH64_F32K8X12X1:1:4", + "WINOGRAD:AARCH64_F16_K8X24X1:1:4", handle(), 5); +#else + benchmark_winograd_fp16("WINOGRAD:ARMV7_F32:1:4", + "WINOGRAD:AARCH32_F16_K4X16X1:1:4", handle(), 5); +#endif +} +TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F16_F63) { +#if MEGDNN_AARCH64 + benchmark_winograd_fp16("WINOGRAD:AARCH64_F32K8X12X1:1:6", + "WINOGRAD:AARCH64_F16_K8X24X1:1:6", handle(), 3); +#else + benchmark_winograd_fp16("WINOGRAD:ARMV7_F32:1:6", + "WINOGRAD:AARCH32_F16_K4X16X1:1:6", handle(), 3); +#endif +} + +TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F16_F23_8x8) { +#if MEGDNN_AARCH64 + benchmark_winograd_fp16("WINOGRAD:AARCH64_F32_MK4_4x16:4:2", + "WINOGRAD:AARCH64_F16_MK8_8X8:8:2", handle(), 3, 8); +#else + benchmark_winograd_fp16("WINOGRAD:ARMV7_F32_MK4_4x8:4:2", + "WINOGRAD:AARCH32_F16_MK8_4X8:8:2", handle(), 3, 8); +#endif +} +#endif + +TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F23_8x8) { + auto benchmark_winograd_quantized = [](const char* algo_name_fp32, + const char* algo_name_quantized, + Handle* handle, size_t kernel) { + auto&& args = get_winograd_benchmark_args(kernel); + using namespace conv_bias; + constexpr size_t RUN = 10; + Benchmarker benchmark(handle); + benchmark.set_display(false); + benchmark.set_times(RUN); + + Benchmarker benchmark_winograd(handle); + benchmark_winograd.set_display(false).set_times(RUN); + benchmark_winograd.set_dtype(0, dtype::QuantizedS8(2.5f)) + .set_dtype(1, dtype::QuantizedS8(2.5f)) + .set_dtype(2, dtype::QuantizedS32(6.25f)) + .set_dtype(4, dtype::QuantizedS8(60.25f)); + + for (auto&& arg : args) { + TensorLayout dst_layout; + auto opr = handle->create_operator(); + opr->param() = arg.param; + opr->deduce_layout({arg.src, dtype::Float32()}, + {arg.filter, dtype::Float32()}, + {arg.bias, dtype::Float32()}, {}, dst_layout); + //! 
dst.nr_elems * IC * FH * FW * 2 + float computations = dst_layout.total_nr_elems() * arg.filter[1] * + arg.filter[2] * arg.filter[3] * 2.0 / + (1024 * 1024 * 1024) * 1e3; + + benchmark.set_param(arg.param); + auto used = algo_benchmark( + benchmark, {arg.src, arg.filter, {}, {}, {}}, + algo_name_fp32) / + RUN; + + benchmark_winograd.set_param(arg.param); + auto used_winograd = + algo_benchmark(benchmark_winograd, + {arg.src, arg.filter, {}, {}, {}}, + algo_name_quantized) / + RUN; + + printf("%s %s: normal: %f ms %f Gflops winograd: %f ms %f GFlops " + "speedup: " + "%f\n", + arg.src.to_string().c_str(), arg.filter.to_string().c_str(), + used, computations / used, used_winograd, + computations / used_winograd, used / used_winograd); + } + }; + +#if MEGDNN_AARCH64 + benchmark_winograd_quantized("WINOGRAD:AARCH64_F32_MK4_4x16:4:2", + "WINOGRAD:AARCH64_INT16X16X32_MK8_8X8:8:2", + handle(), 3); +#else + benchmark_winograd_quantized("WINOGRAD:ARMV7_F32_MK4_4x8:4:2", + "WINOGRAD:ARMV7_INT16X16X32_MK8_4X8:8:2", + handle(), 3); +#endif +} + +TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_INT8_STRIDE1) { + // have to remove preferred restrict in usable func before run the benchmark + using namespace conv_bias; + + std::vector args; + auto run = [&](size_t oc, size_t ic, size_t w, size_t h, size_t kernel, + size_t p, NonlineMode nonline_mode) { + if (w + 2 * p < kernel || h + 2 * p < kernel) + return; + param::ConvBias param; + param.stride_h = 1; + param.stride_w = 1; + param.pad_h = p; + param.pad_w = p; + param.nonlineMode = nonline_mode; + + //! channel bias + args.emplace_back(param, TensorShape{2, ic, h, w}, + TensorShape{oc, ic, kernel, kernel}, + TensorShape{1, oc, 1, 1}); + }; + + for (size_t kernel : {2, 3, 5, 7}) + for (size_t ic : {1, 8, 16, 32}) + for (size_t oc : {1, 8, 16, 32}) + for (size_t p : {1}) + for (NonlineMode nonline_mode : {NonlineMode::RELU}) { + run(oc, ic, 56, 56, kernel, p, nonline_mode); + run(oc, ic, 128, 128, kernel, p, nonline_mode); + run(oc, ic, 256, 256, kernel, p, nonline_mode); + } + constexpr size_t RUN = 50; + Benchmarker benchmark0(handle()); + benchmark0.set_dtype(0, dtype::QuantizedS8(2.5f)) + .set_dtype(1, dtype::QuantizedS8(2.5f)) + .set_dtype(2, dtype::QuantizedS32(6.25f)) + .set_dtype(4, dtype::QuantizedS8(60.25f)); + benchmark0.set_display(false); + benchmark0.set_times(RUN); + benchmark0.set_before_exec_callback( + conv_bias::ConvBiasAlgoChecker("S8STRD1")); + + Benchmarker benchmark1(handle()); + benchmark1.set_dtype(0, dtype::QuantizedS8(2.5f)) + .set_dtype(1, dtype::QuantizedS8(2.5f)) + .set_dtype(2, dtype::QuantizedS32(6.25f)) + .set_dtype(4, dtype::QuantizedS8(60.25f)); + benchmark1.set_display(false); + benchmark1.set_times(RUN); + + for (auto&& arg : args) { + TensorLayout dst_layout; + auto opr = handle()->create_operator(); + opr->param() = arg.param; + opr->deduce_layout({arg.src, dtype::Int8()}, + {arg.filter, dtype::Int8()}, + {arg.bias, dtype::Int32()}, {}, dst_layout); + //! 
dst.nr_elems * IC * FH * FW * 2 + float computations = dst_layout.total_nr_elems() * arg.filter[1] * + arg.filter[2] * arg.filter[3] * 2.0 / + (1024 * 1024 * 1024) * 1e3; + + auto used0 = benchmark0.set_param(arg.param).exec( + {arg.src, arg.filter, arg.bias, {}, {}}) / + RUN; + auto used1 = benchmark1.set_param(arg.param).exec( + {arg.src, arg.filter, arg.bias, {}, {}}) / + RUN; + + printf("%s %s: conv_bias: %f ms %f Gflops conv_elem: %f ms %f GFlops " + "speedup: %f\n", + arg.src.to_string().c_str(), arg.filter.to_string().c_str(), + used0, computations / used0, used1, computations / used1, + used1 / used0); + } +} + +TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_INT8_STRIDE2) { + // have to remove preferred restrict in usable func before run the benchmark + using namespace conv_bias; + + std::vector args; + auto run = [&](size_t oc, size_t ic, size_t w, size_t h, size_t kernel, + size_t p, NonlineMode nonline_mode) { + if (w + 2 * p < kernel || h + 2 * p < kernel) + return; + param::ConvBias param; + param.stride_h = 2; + param.stride_w = 2; + param.pad_h = p; + param.pad_w = p; + param.nonlineMode = nonline_mode; + + //! channel bias + args.emplace_back(param, TensorShape{2, ic, h, w}, + TensorShape{oc, ic, kernel, kernel}, + TensorShape{1, oc, 1, 1}); + }; + + for (size_t kernel : {2, 3, 5, 7}) + for (size_t ic : {1, 8, 16, 32}) + for (size_t oc : {1, 8, 16, 32}) + for (size_t p : {1}) + for (NonlineMode nonline_mode : {NonlineMode::RELU}) { + run(oc, ic, 56, 56, kernel, p, nonline_mode); + run(oc, ic, 128, 128, kernel, p, nonline_mode); + run(oc, ic, 256, 256, kernel, p, nonline_mode); + } + + constexpr size_t RUN = 50; + Benchmarker benchmark0(handle()); + benchmark0.set_dtype(0, dtype::QuantizedS8(2.5f)) + .set_dtype(1, dtype::QuantizedS8(2.5f)) + .set_dtype(2, dtype::QuantizedS32(6.25f)) + .set_dtype(4, dtype::QuantizedS8(60.25f)); + benchmark0.set_display(false); + benchmark0.set_times(RUN); + benchmark0.set_before_exec_callback( + conv_bias::ConvBiasAlgoChecker("S8STRD2")); + + Benchmarker benchmark1(handle()); + benchmark1.set_dtype(0, dtype::QuantizedS8(2.5f)) + .set_dtype(1, dtype::QuantizedS8(2.5f)) + .set_dtype(2, dtype::QuantizedS32(6.25f)) + .set_dtype(4, dtype::QuantizedS8(60.25f)); + benchmark1.set_display(false); + benchmark1.set_times(RUN); + + for (auto&& arg : args) { + TensorLayout dst_layout; + auto opr = handle()->create_operator(); + opr->param() = arg.param; + opr->deduce_layout({arg.src, dtype::Int8()}, + {arg.filter, dtype::Int8()}, + {arg.bias, dtype::Int32()}, {}, dst_layout); + //! 
dst.nr_elems * IC * FH * FW * 2 + float computations = dst_layout.total_nr_elems() * arg.filter[1] * + arg.filter[2] * arg.filter[3] * 2.0 / + (1024 * 1024 * 1024) * 1e3; + + auto used0 = benchmark0.set_param(arg.param).exec( + {arg.src, arg.filter, arg.bias, {}, {}}) / + RUN; + auto used1 = benchmark1.set_param(arg.param).exec( + {arg.src, arg.filter, arg.bias, {}, {}}) / + RUN; + + printf("%s %s: conv_bias: %f ms %f Gflops conv_elem: %f ms %f GFlops " + "speedup: %f\n", + arg.src.to_string().c_str(), arg.filter.to_string().c_str(), + used0, computations / used0, used1, computations / used1, + used1 / used0); + } +} + +TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_QUINT8_STRIDE1) { + // have to remove preferred restrict in usable func before run the benchmark + using namespace conv_bias; + + std::vector args; + auto run = [&](size_t oc, size_t ic, size_t w, size_t h, size_t kernel, + size_t p, NonlineMode nonline_mode) { + if (w + 2 * p < kernel || h + 2 * p < kernel) + return; + param::ConvBias param; + param.stride_h = 1; + param.stride_w = 1; + param.pad_h = p; + param.pad_w = p; + param.nonlineMode = nonline_mode; + + //! channel bias + args.emplace_back(param, TensorShape{2, ic, h, w}, + TensorShape{oc, ic, kernel, kernel}, + TensorShape{1, oc, 1, 1}); + }; + + for (size_t kernel : {2, 3, 5, 7}) + for (size_t ic : {1, 8, 16, 32}) + for (size_t oc : {1, 8, 16, 32}) + for (size_t p : {1}) + for (NonlineMode nonline_mode : {NonlineMode::RELU}) { + run(oc, ic, 56, 56, kernel, p, nonline_mode); + run(oc, ic, 128, 128, kernel, p, nonline_mode); + run(oc, ic, 256, 256, kernel, p, nonline_mode); + } + constexpr size_t RUN = 50; + Benchmarker benchmark0(handle()); + benchmark0 + .set_dtype(0, + dtype::Quantized8Asymm(0.2f, static_cast(100))) + .set_dtype(1, + dtype::Quantized8Asymm(0.2f, static_cast(120))) + .set_dtype(2, dtype::QuantizedS32(0.04f)) + .set_dtype(4, + dtype::Quantized8Asymm(1.4f, static_cast(110))); + benchmark0.set_display(false); + benchmark0.set_times(RUN); + benchmark0.set_before_exec_callback( + conv_bias::ConvBiasAlgoChecker("QU8STRD1")); + + Benchmarker benchmark1(handle()); + benchmark1 + .set_dtype(0, + dtype::Quantized8Asymm(0.2f, static_cast(100))) + .set_dtype(1, + dtype::Quantized8Asymm(0.2f, static_cast(120))) + .set_dtype(2, dtype::QuantizedS32(0.04f)) + .set_dtype(4, + dtype::Quantized8Asymm(1.4f, static_cast(110))); + benchmark1.set_display(false); + benchmark1.set_times(RUN); + + for (auto&& arg : args) { + TensorLayout dst_layout; + auto opr = handle()->create_operator(); + opr->param() = arg.param; + opr->deduce_layout({arg.src, dtype::Int8()}, + {arg.filter, dtype::Int8()}, + {arg.bias, dtype::Int32()}, {}, dst_layout); + //! 
dst.nr_elems * IC * FH * FW * 2 + float computations = dst_layout.total_nr_elems() * arg.filter[1] * + arg.filter[2] * arg.filter[3] * 2.0 / + (1024 * 1024 * 1024) * 1e3; + + auto used0 = benchmark0.set_param(arg.param).exec( + {arg.src, arg.filter, arg.bias, {}, {}}) / + RUN; + auto used1 = benchmark1.set_param(arg.param).exec( + {arg.src, arg.filter, arg.bias, {}, {}}) / + RUN; + + printf("%s %s: conv_bias: %f ms %f Gflops conv_elem: %f ms %f GFlops " + "speedup: %f\n", + arg.src.to_string().c_str(), arg.filter.to_string().c_str(), + used0, computations / used0, used1, computations / used1, + used1 / used0); + } +} + +TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_QUINT8_STRIDE2) { + // have to remove preferred restrict in usable func before run the benchmark + using namespace conv_bias; + + std::vector args; + auto run = [&](size_t oc, size_t ic, size_t w, size_t h, size_t kernel, + size_t p, NonlineMode nonline_mode) { + if (w + 2 * p < kernel || h + 2 * p < kernel) + return; + param::ConvBias param; + param.stride_h = 2; + param.stride_w = 2; + param.pad_h = p; + param.pad_w = p; + param.nonlineMode = nonline_mode; + + //! channel bias + args.emplace_back(param, TensorShape{2, ic, h, w}, + TensorShape{oc, ic, kernel, kernel}, + TensorShape{1, oc, 1, 1}); + }; + + for (size_t kernel : {2, 3, 5, 7}) + for (size_t ic : {1, 8, 16, 32}) + for (size_t oc : {1, 8, 16, 32}) + for (size_t p : {1}) + for (NonlineMode nonline_mode : {NonlineMode::RELU}) { + run(oc, ic, 56, 56, kernel, p, nonline_mode); + run(oc, ic, 128, 128, kernel, p, nonline_mode); + run(oc, ic, 256, 256, kernel, p, nonline_mode); + } + constexpr size_t RUN = 50; + Benchmarker benchmark0(handle()); + benchmark0 + .set_dtype(0, + dtype::Quantized8Asymm(0.2f, static_cast(100))) + .set_dtype(1, + dtype::Quantized8Asymm(0.2f, static_cast(120))) + .set_dtype(2, dtype::QuantizedS32(0.04f)) + .set_dtype(4, + dtype::Quantized8Asymm(1.4f, static_cast(110))); + benchmark0.set_display(false); + benchmark0.set_times(RUN); + benchmark0.set_before_exec_callback( + conv_bias::ConvBiasAlgoChecker("QU8STRD2")); + + Benchmarker benchmark1(handle()); + benchmark1 + .set_dtype(0, + dtype::Quantized8Asymm(0.2f, static_cast(100))) + .set_dtype(1, + dtype::Quantized8Asymm(0.2f, static_cast(120))) + .set_dtype(2, dtype::QuantizedS32(0.04f)) + .set_dtype(4, + dtype::Quantized8Asymm(1.4f, static_cast(110))); + benchmark1.set_display(false); + benchmark1.set_times(RUN); + + for (auto&& arg : args) { + TensorLayout dst_layout; + auto opr = handle()->create_operator(); + opr->param() = arg.param; + opr->deduce_layout({arg.src, dtype::Int8()}, + {arg.filter, dtype::Int8()}, + {arg.bias, dtype::Int32()}, {}, dst_layout); + //! 
dst.nr_elems * IC * FH * FW * 2 + float computations = dst_layout.total_nr_elems() * arg.filter[1] * + arg.filter[2] * arg.filter[3] * 2.0 / + (1024 * 1024 * 1024) * 1e3; + + auto used0 = benchmark0.set_param(arg.param).exec( + {arg.src, arg.filter, arg.bias, {}, {}}) / + RUN; + auto used1 = benchmark1.set_param(arg.param).exec( + {arg.src, arg.filter, arg.bias, {}, {}}) / + RUN; + + printf("%s %s: conv_bias: %f ms %f Gflops conv_elem: %f ms %f GFlops " + "speedup: %f\n", + arg.src.to_string().c_str(), arg.filter.to_string().c_str(), + used0, computations / used0, used1, computations / used1, + used1 / used0); + } +} + +TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_QINT8_STRIDE1_NCHW44) { + // have to remove preferred restrict in usable func before run the benchmark + using namespace conv_bias; + param::ConvBias param; + param.stride_h = 1; + param.stride_w = 1; + param.pad_h = 1; + param.pad_w = 1; + param.nonlineMode = NonlineMode::RELU; + param.sparse = param::ConvBias::Sparse::GROUP; + + constexpr size_t RUN = 50; + Benchmarker benchmark0(handle()); + benchmark0.set_dtype(0, dtype::QuantizedS8(0.2f)) + .set_dtype(1, dtype::QuantizedS8(0.2f)) + .set_dtype(2, dtype::QuantizedS32(0.04f)) + .set_dtype(4, dtype::QuantizedS8(1.4f)); + benchmark0.set_display(false); + benchmark0.set_param(param); + benchmark0.set_times(RUN); + benchmark0.set_before_exec_callback( + conv_bias::ConvBiasAlgoChecker( + "S8STRD1_LARGE_GROUP")); + + auto opr = handle()->create_operator(); + opr->param() = param; + + param.format = param::ConvBias::Format::NCHW44; + Benchmarker benchmark1(handle()); + benchmark1.set_dtype(0, dtype::QuantizedS8(0.2f)) + .set_dtype(1, dtype::QuantizedS8(0.2f)) + .set_dtype(2, dtype::QuantizedS32(0.04f)) + .set_dtype(4, dtype::QuantizedS8(1.4f)); + benchmark1.set_display(false); + benchmark1.set_param(param); + benchmark1.set_times(RUN); + benchmark1.set_before_exec_callback( + conv_bias::ConvBiasAlgoChecker( + "S8_CHAN_WISE_STRD1_NCHW44")); + auto run = [&](size_t group, size_t w, size_t h, size_t kernel) { + TensorLayout dst_layout; + opr->deduce_layout({{1, group * 4, h, w}, dtype::Int8()}, + {{group * 4, 1, 1, kernel, kernel}, dtype::Int8()}, + {{1, group * 4, 1, 1}, dtype::Int32()}, {}, + dst_layout); + //! 
dst.nr_elems * IC * FH * FW * 2 + float computations = dst_layout.total_nr_elems() * kernel * kernel * + 2.0 / (1024 * 1024 * 1024) * 1e3; + + auto used0 = benchmark0.exec({{1, group * 4, h, w}, + {group * 4, 1, 1, kernel, kernel}, + {1, group * 4, 1, 1}, + {}, + {}}) / + RUN; + auto used1 = benchmark1.exec({{1, group, h, w, 4}, + {group, 1, 1, kernel, kernel, 4}, + {1, group, 1, 1, 4}, + {}, + {}}) / + RUN; + printf("group/h/w/kernel:%zu,%zu,%zu,%zu: nchw: %f ms %f Gflops " + "nchw44: " + "%f ms %f GFlops " + "speedup: %f\n", + group, h, w, kernel, used0, computations / used0, used1, + computations / used1, used0 / used1); + }; + for (size_t group : {8, 16, 32, 64, 128}) { + for (size_t kerenl : {2, 3, 5}) { + run(group, 112, 112, kerenl); + run(group, 56, 56, kerenl); + run(group, 48, 48, kerenl); + run(group, 28, 28, kerenl); + run(group, 14, 14, kerenl); + } + } +} + +#endif + +#if __ARM_FEATURE_DOTPROD +#if MEGDNN_WITH_BENCHMARK +TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_INT8_STRIDE1_WITHDOTPROD) { + // have to remove preferred restrict in usable func before run the benchmark + using namespace conv_bias; + + std::vector args; + auto run = [&](size_t oc, size_t ic, size_t w, size_t h, size_t kernel, + size_t p, NonlineMode nonline_mode) { + if (w + 2 * p < kernel || h + 2 * p < kernel) + return; + param::ConvBias param; + param.stride_h = 1; + param.stride_w = 1; + param.pad_h = p; + param.pad_w = p; + param.nonlineMode = nonline_mode; + + //! channel bias + args.emplace_back(param, TensorShape{2, ic, h, w}, + TensorShape{oc, ic, kernel, kernel}, + TensorShape{1, oc, 1, 1}); + }; + + for (size_t kernel : {2, 3, 5, 7}) + for (size_t ic : {1, 8, 16, 32}) + for (size_t oc : {1, 8, 16, 32}) + for (size_t p : {1}) + for (NonlineMode nonline_mode : {NonlineMode::RELU}) { + run(oc, ic, 56, 56, kernel, p, nonline_mode); + run(oc, ic, 128, 128, kernel, p, nonline_mode); + run(oc, ic, 256, 256, kernel, p, nonline_mode); + } + constexpr size_t RUN = 50; + Benchmarker benchmark0(handle()); + benchmark0.set_dtype(0, dtype::QuantizedS8(2.5f)) + .set_dtype(1, dtype::QuantizedS8(2.5f)) + .set_dtype(2, dtype::QuantizedS32(6.25f)) + .set_dtype(4, dtype::QuantizedS8(60.25f)); + benchmark0.set_display(false); + benchmark0.set_times(RUN); + benchmark0.set_before_exec_callback( + conv_bias::ConvBiasAlgoChecker("ARMDOTS8STRD1")); + + Benchmarker benchmark1(handle()); + benchmark1.set_dtype(0, dtype::QuantizedS8(2.5f)) + .set_dtype(1, dtype::QuantizedS8(2.5f)) + .set_dtype(2, dtype::QuantizedS32(6.25f)) + .set_dtype(4, dtype::QuantizedS8(60.25f)); + benchmark1.set_display(false); + benchmark1.set_times(RUN); + + for (auto&& arg : args) { + TensorLayout dst_layout; + auto opr = handle()->create_operator(); + opr->param() = arg.param; + opr->deduce_layout({arg.src, dtype::Int8()}, + {arg.filter, dtype::Int8()}, + {arg.bias, dtype::Int32()}, {}, dst_layout); + //! 
dst.nr_elems * IC * FH * FW * 2 + float computations = dst_layout.total_nr_elems() * arg.filter[1] * + arg.filter[2] * arg.filter[3] * 2.0 / + (1024 * 1024 * 1024) * 1e3; + + auto used0 = benchmark0.set_param(arg.param).exec( + {arg.src, arg.filter, arg.bias, {}, {}}) / + RUN; + auto used1 = benchmark1.set_param(arg.param).exec( + {arg.src, arg.filter, arg.bias, {}, {}}) / + RUN; + + printf("%s %s: conv_bias: %f ms %f Gflops conv_elem: %f ms %f GFlops " + "speedup: %f\n", + arg.src.to_string().c_str(), arg.filter.to_string().c_str(), + used0, computations / used0, used1, computations / used1, + used1 / used0); + } +} + +TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_INT8_STRIDE2_WITHDOTPROD) { + // have to remove preferred restrict in usable func before run the benchmark + using namespace conv_bias; + + std::vector args; + auto run = [&](size_t oc, size_t ic, size_t w, size_t h, size_t kernel, + size_t p, NonlineMode nonline_mode) { + if (w + 2 * p < kernel || h + 2 * p < kernel) + return; + param::ConvBias param; + param.stride_h = 2; + param.stride_w = 2; + param.pad_h = p; + param.pad_w = p; + param.nonlineMode = nonline_mode; + + //! channel bias + args.emplace_back(param, TensorShape{2, ic, h, w}, + TensorShape{oc, ic, kernel, kernel}, + TensorShape{1, oc, 1, 1}); + }; + + for (size_t kernel : {2, 3, 5, 7}) + for (size_t ic : {1, 8, 16, 32}) + for (size_t oc : {1, 8, 16, 32}) + for (size_t p : {1}) + for (NonlineMode nonline_mode : {NonlineMode::RELU}) { + run(oc, ic, 56, 56, kernel, p, nonline_mode); + run(oc, ic, 128, 128, kernel, p, nonline_mode); + run(oc, ic, 256, 256, kernel, p, nonline_mode); + } + + constexpr size_t RUN = 50; + Benchmarker benchmark0(handle()); + benchmark0.set_dtype(0, dtype::QuantizedS8(2.5f)) + .set_dtype(1, dtype::QuantizedS8(2.5f)) + .set_dtype(2, dtype::QuantizedS32(6.25f)) + .set_dtype(4, dtype::QuantizedS8(60.25f)); + benchmark0.set_display(false); + benchmark0.set_times(RUN); + benchmark0.set_before_exec_callback( + conv_bias::ConvBiasAlgoChecker("ARMDOTS8STRD2")); + + Benchmarker benchmark1(handle()); + benchmark1.set_dtype(0, dtype::QuantizedS8(2.5f)) + .set_dtype(1, dtype::QuantizedS8(2.5f)) + .set_dtype(2, dtype::QuantizedS32(6.25f)) + .set_dtype(4, dtype::QuantizedS8(60.25f)); + benchmark1.set_display(false); + benchmark1.set_times(RUN); + + for (auto&& arg : args) { + TensorLayout dst_layout; + auto opr = handle()->create_operator(); + opr->param() = arg.param; + opr->deduce_layout({arg.src, dtype::Int8()}, + {arg.filter, dtype::Int8()}, + {arg.bias, dtype::Int32()}, {}, dst_layout); + //! 
dst.nr_elems * IC * FH * FW * 2 + float computations = dst_layout.total_nr_elems() * arg.filter[1] * + arg.filter[2] * arg.filter[3] * 2.0 / + (1024 * 1024 * 1024) * 1e3; + + auto used0 = benchmark0.set_param(arg.param).exec( + {arg.src, arg.filter, arg.bias, {}, {}}) / + RUN; + auto used1 = benchmark1.set_param(arg.param).exec( + {arg.src, arg.filter, arg.bias, {}, {}}) / + RUN; + + printf("%s %s: conv_bias: %f ms %f Gflops conv_elem: %f ms %f GFlops " + "speedup: %f\n", + arg.src.to_string().c_str(), arg.filter.to_string().c_str(), + used0, computations / used0, used1, computations / used1, + used1 / used0); + } +} + +TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_QUINT8_STRIDE1_WITHDOTPROD) { + // have to remove preferred restrict in usable func before run the benchmark + using namespace conv_bias; + + std::vector args; + auto run = [&](size_t oc, size_t ic, size_t w, size_t h, size_t kernel, + size_t p, NonlineMode nonline_mode) { + if (w + 2 * p < kernel || h + 2 * p < kernel) + return; + param::ConvBias param; + param.stride_h = 1; + param.stride_w = 1; + param.pad_h = p; + param.pad_w = p; + param.nonlineMode = nonline_mode; + + //! channel bias + args.emplace_back(param, TensorShape{2, ic, h, w}, + TensorShape{oc, ic, kernel, kernel}, + TensorShape{1, oc, 1, 1}); + }; + + // clang-format off + for (size_t kernel : {2, 3, 5, 7}) + for (size_t ic : {1, 8, 16, 32}) + for (size_t oc : {1, 8, 16, 32}) + for (size_t p : {1}) + for (NonlineMode nonline_mode : {NonlineMode::RELU}) { + run(oc, ic, 56, 56, kernel, p, nonline_mode); + run(oc, ic, 128, 128, kernel, p, nonline_mode); + run(oc, ic, 256, 256, kernel, p, nonline_mode); + } + // clang-format on + constexpr size_t RUN = 50; + Benchmarker benchmark0(handle()); + benchmark0 + .set_dtype(0, + dtype::Quantized8Asymm(0.2f, static_cast(100))) + .set_dtype(1, + dtype::Quantized8Asymm(0.2f, static_cast(120))) + .set_dtype(2, dtype::QuantizedS32(0.04f)) + .set_dtype(4, + dtype::Quantized8Asymm(1.4f, static_cast(110))); + benchmark0.set_display(false); + benchmark0.set_times(RUN); + benchmark0.set_before_exec_callback( + conv_bias::ConvBiasAlgoChecker("ARMDOTU8STRD1")); + + Benchmarker benchmark1(handle()); + benchmark1 + .set_dtype(0, + dtype::Quantized8Asymm(0.2f, static_cast(100))) + .set_dtype(1, + dtype::Quantized8Asymm(0.2f, static_cast(120))) + .set_dtype(2, dtype::QuantizedS32(0.04f)) + .set_dtype(4, + dtype::Quantized8Asymm(1.4f, static_cast(110))); + benchmark1.set_display(false); + benchmark1.set_times(RUN); + + for (auto&& arg : args) { + TensorLayout dst_layout; + auto opr = handle()->create_operator(); + opr->param() = arg.param; + opr->deduce_layout({arg.src, dtype::Int8()}, + {arg.filter, dtype::Int8()}, + {arg.bias, dtype::Int32()}, {}, dst_layout); + //! 
dst.nr_elems * IC * FH * FW * 2 + float computations = dst_layout.total_nr_elems() * arg.filter[1] * + arg.filter[2] * arg.filter[3] * 2.0 / + (1024 * 1024 * 1024) * 1e3; + + auto used0 = benchmark0.set_param(arg.param).exec( + {arg.src, arg.filter, arg.bias, {}, {}}) / + RUN; + auto used1 = benchmark1.set_param(arg.param).exec( + {arg.src, arg.filter, arg.bias, {}, {}}) / + RUN; + + printf("%s %s: conv_bias: %f ms %f Gflops conv_elem: %f ms %f GFlops " + "speedup: %f\n", + arg.src.to_string().c_str(), arg.filter.to_string().c_str(), + used0, computations / used0, used1, computations / used1, + used1 / used0); + } +} + +TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_QUINT8_STRIDE2_WITHDOTPROD) { + // have to remove preferred restrict in usable func before run the benchmark + using namespace conv_bias; + + std::vector args; + auto run = [&](size_t oc, size_t ic, size_t w, size_t h, size_t kernel, + size_t p, NonlineMode nonline_mode) { + if (w + 2 * p < kernel || h + 2 * p < kernel) + return; + param::ConvBias param; + param.stride_h = 2; + param.stride_w = 2; + param.pad_h = p; + param.pad_w = p; + param.nonlineMode = nonline_mode; + + //! channel bias + args.emplace_back(param, TensorShape{2, ic, h, w}, + TensorShape{oc, ic, kernel, kernel}, + TensorShape{1, oc, 1, 1}); + }; + + // clang-format off + for (size_t kernel : {2, 3, 5, 7}) + for (size_t ic : {1, 8, 16, 32}) + for (size_t oc : {1, 8, 16, 32}) + for (size_t p : {1}) + for (NonlineMode nonline_mode : {NonlineMode::RELU}) { + run(oc, ic, 56, 56, kernel, p, nonline_mode); + run(oc, ic, 128, 128, kernel, p, nonline_mode); + run(oc, ic, 256, 256, kernel, p, nonline_mode); + } + // clang-format on + constexpr size_t RUN = 50; + Benchmarker benchmark0(handle()); + benchmark0 + .set_dtype(0, + dtype::Quantized8Asymm(0.2f, static_cast(100))) + .set_dtype(1, + dtype::Quantized8Asymm(0.2f, static_cast(120))) + .set_dtype(2, dtype::QuantizedS32(0.04f)) + .set_dtype(4, + dtype::Quantized8Asymm(1.4f, static_cast(110))); + benchmark0.set_display(false); + benchmark0.set_times(RUN); + benchmark0.set_before_exec_callback( + conv_bias::ConvBiasAlgoChecker("ARMDOTU8STRD2")); + + Benchmarker benchmark1(handle()); + benchmark1 + .set_dtype(0, + dtype::Quantized8Asymm(0.2f, static_cast(100))) + .set_dtype(1, + dtype::Quantized8Asymm(0.2f, static_cast(120))) + .set_dtype(2, dtype::QuantizedS32(0.04f)) + .set_dtype(4, + dtype::Quantized8Asymm(1.4f, static_cast(110))); + benchmark1.set_display(false); + benchmark1.set_times(RUN); + + for (auto&& arg : args) { + TensorLayout dst_layout; + auto opr = handle()->create_operator(); + opr->param() = arg.param; + opr->deduce_layout({arg.src, dtype::Int8()}, + {arg.filter, dtype::Int8()}, + {arg.bias, dtype::Int32()}, {}, dst_layout); + //! 
dst.nr_elems * IC * FH * FW * 2 + float computations = dst_layout.total_nr_elems() * arg.filter[1] * + arg.filter[2] * arg.filter[3] * 2.0 / + (1024 * 1024 * 1024) * 1e3; + + auto used0 = benchmark0.set_param(arg.param).exec( + {arg.src, arg.filter, arg.bias, {}, {}}) / + RUN; + auto used1 = benchmark1.set_param(arg.param).exec( + {arg.src, arg.filter, arg.bias, {}, {}}) / + RUN; + + printf("%s %s: conv_bias: %f ms %f Gflops conv_elem: %f ms %f GFlops " + "speedup: %f\n", + arg.src.to_string().c_str(), arg.filter.to_string().c_str(), + used0, computations / used0, used1, computations / used1, + used1 / used0); + } +} +#endif +#endif + +/*====================== BENCHMARK CONV1X1 ===========================*/ +#if MEGDNN_WITH_BENCHMARK + +namespace { +std::vector get_conv_bias_1x1_benchmark_args(size_t pack_size = 1) { + using namespace conv_bias; + std::vector args; + param::ConvBias param; + param.stride_h = 1; + param.stride_w = 1; + param.pad_h = 0; + param.pad_w = 0; + param.nonlineMode = param::ConvBias::NonlineMode::IDENTITY; + auto bench_case = [&](size_t OC, size_t IC, size_t H, size_t W) { + if(pack_size == 1) + args.emplace_back(param, TensorShape{1, IC, H, W}, + TensorShape{OC, IC, 1, 1}, TensorShape{}); + else { + if(pack_size == 4) + param.format = param::ConvBias::Format::NCHW44; + args.emplace_back(param, TensorShape{1, IC / pack_size, H, W, pack_size}, + TensorShape{OC / pack_size, IC / pack_size, 1, 1, pack_size, pack_size}, + TensorShape{}); + } + }; + + //! MobileNetV1 + bench_case(64, 32, 112, 112); + bench_case(128, 64, 56, 56); + bench_case(128, 128, 56, 56); + bench_case(256, 128, 28, 28); + bench_case(256, 256, 28, 28); + bench_case(512, 256, 14, 14); + bench_case(512, 512, 14, 14); + bench_case(1024, 512, 7, 7); + bench_case(1024, 1024, 7, 7); + + //! MobileNetV2 + bench_case(16, 32, 112, 112); + bench_case(96, 16, 112, 112); + bench_case(144, 24, 56, 56); + bench_case(192, 32, 28, 28); + bench_case(384, 64, 28, 28); + bench_case(576, 96, 14, 14); + bench_case(960, 160, 7, 7); + bench_case(320, 960, 7, 7); + bench_case(1280, 320, 7, 7); + + //! MobileNetV3-Large + bench_case(64, 16, 112, 112); + bench_case(72, 24, 56, 56); + bench_case(120, 40, 28, 28); + bench_case(240, 40, 28, 28); + bench_case(200, 80, 14, 14); + bench_case(184, 80, 14, 14); + bench_case(480, 80, 14, 14); + bench_case(672, 112, 14, 14); + + //! MobileNetV3-Small + bench_case(72, 16, 56, 56); + bench_case(88, 24, 28, 28); + bench_case(96, 24, 28, 28); + bench_case(240, 40, 14, 14); + bench_case(120, 40, 14, 14); + bench_case(144, 48, 14, 14); + bench_case(288, 48, 14, 14); + bench_case(576, 96, 7, 7); + + //! 
resnet50 + bench_case(256, 64, 56, 56); + bench_case(512, 128, 28, 28); + bench_case(1024, 256, 14, 14); + bench_case(2048, 512, 7, 7); + + return args; +} + +void benchmark_conv1x1(const char* matmul_algo_name, Handle* handle, + DType stype, DType matmul_dtype, DType bias_type, + DType conv_dtype) { + using namespace conv_bias; + std::vector conv_bias_1x1_args = + get_conv_bias_1x1_benchmark_args(); + constexpr size_t RUNS = 50; + + param::MatrixMul param; + param.transposeA = false; + param.transposeB = false; + Benchmarker benchmark_matmul(handle); + benchmark_matmul.set_before_exec_callback( + AlgoChecker(matmul_algo_name)); + benchmark_matmul.set_times(RUNS) + .set_dtype(0, stype) + .set_dtype(1, stype) + .set_dtype(2, matmul_dtype) + .set_param(param) + .set_display(false); + + std::string conv1x1_algo_name = ssprintf("CONV1x1:%s:24", matmul_algo_name); + Benchmarker benchmark_conv1x1(handle); + benchmark_conv1x1.set_before_exec_callback( + conv_bias::ConvBiasAlgoChecker( + conv1x1_algo_name.c_str())); + benchmark_conv1x1.set_times(RUNS) + .set_dtype(0, stype) + .set_dtype(1, stype) + .set_dtype(2, bias_type) + .set_dtype(4, conv_dtype) + .set_display(false); + + for (auto&& arg : conv_bias_1x1_args) { + size_t IC = arg.src[1]; + size_t OH = arg.src[2]; + size_t OW = arg.src[3]; + size_t OC = arg.filter[0]; + size_t M = OC; + size_t K = IC; + size_t N = OH * OW; + + float computations = M * N * K * 2.f / (1024 * 1024 * 1024) * 1e3; + + TensorShape A, B; + A = TensorShape{M, K}; + B = TensorShape{K, N}; + + auto conv1x1_used = benchmark_conv1x1.set_param(arg.param).exec( + {arg.src, arg.filter, arg.bias, {}, {}}) / + RUNS; + auto matmul_used = benchmark_matmul.exec({A, B, {}}) / RUNS; + + printf("\n%s: ", matmul_algo_name); + printf("%s %s:\n matmul: %f ms %f Gflops\nconv1x1: %f ms %f GFlops " + "speedup: " + "%f\n", + arg.src.to_string().c_str(), arg.filter.to_string().c_str(), + matmul_used, computations / matmul_used, conv1x1_used, + computations / conv1x1_used, matmul_used / conv1x1_used); + } +} +} // namespace + +TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_CONV1X1_S1_F32) { +#if MEGDNN_AARCH64 + benchmark_conv1x1("AARCH64_F32K8X12X1", handle(), dtype::Float32{}, + dtype::Float32{}, dtype::Float32{}, dtype::Float32{}); +#else + benchmark_conv1x1("ARMV7_F32", handle(), dtype::Float32{}, dtype::Float32{}, + dtype::Float32{}, dtype::Float32{}); +#endif +} + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_CONV1X1_S1_F16) { +#if MEGDNN_AARCH64 + benchmark_conv1x1("AARCH64_F16_K8X24X1", handle(), dtype::Float16{}, + dtype::Float16{}, dtype::Float16{}, dtype::Float16{}); +#else + benchmark_conv1x1("AARCH32_F16_K4X16X1", handle(), dtype::Float16{}, + dtype::Float16{}, dtype::Float16{}, dtype::Float16{}); +#endif +} +#endif + +TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_CONV1X1_S1_QUANTIZEDSYM) { + dtype::QuantizedS8 stype(2.5f); + dtype::QuantizedS32 dtype(6.25f); +#if MEGDNN_AARCH64 +#if __ARM_FEATURE_DOTPROD + benchmark_conv1x1("AARCH64_INT8X8X32_K8X12X4_DOTPROD", handle(), stype, + dtype, dtype, dtype); +#else + benchmark_conv1x1("AARCH64_INT8X8X32_K8X8X8", handle(), stype, dtype, dtype, + dtype); + benchmark_conv1x1("AARCH64_INT8X8X32_K4X4X16", handle(), stype, dtype, + dtype, dtype); +#endif +#elif MEGDNN_ARMV7 + benchmark_conv1x1("ARMV7_INT8X8X32_K4X8X8", handle(), stype, dtype, dtype, + dtype); +#endif +} + +TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_CONV1X1_S1_QUANTIZEDASYM) { + dtype::Quantized8Asymm stype(1.2f, (uint8_t)125); + dtype::QuantizedS32 dtype(1.2 
* 1.2); + +#if MEGDNN_AARCH64 +#if __ARM_FEATURE_DOTPROD + benchmark_conv1x1("AARCH64_QUINT8_K8X8X4_DOTPROD", handle(), stype, dtype, + dtype, dtype); +#else + benchmark_conv1x1("AARCH64_QUINT8_K8X8X8", handle(), stype, dtype, dtype, + dtype); +#endif +#elif MEGDNN_ARMV7 + benchmark_conv1x1("ARMV7_QUINT8_K4X8X8", handle(), stype, dtype, dtype, + dtype); +#endif +} + +TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_CONV1X1_S1_INT8x8x16) { +#if MEGDNN_AARCH64 + benchmark_conv1x1("AARCH64_INT8X8X16_K8X8X8", handle(), dtype::Int8{}, + dtype::Int16{}, dtype::Int16{}, dtype::Int16{}); + benchmark_conv1x1("AARCH64_INT8X8X16_K4X4X16", handle(), dtype::Int8{}, + dtype::Int16{}, dtype::Int16{}, dtype::Int16{}); +#elif MEGDNN_ARMV7 + benchmark_conv1x1("ARMV7_INT8X8X16_K4X8X8", handle(), dtype::Int8{}, + dtype::Int16{}, dtype::Int16{}, dtype::Int16{}); + benchmark_conv1x1("ARMV7_INT8X8X16_K4X2X16", handle(), dtype::Int8{}, + dtype::Int16{}, dtype::Int16{}, dtype::Int16{}); +#endif +} + +#ifndef __ARM_FEATURE_DOTPROD +TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_1X1_S1_NCHW_VS_NCHW44_INT8x8x32) { + std::vector conv_bias_1x1_args_nchw44 = + get_conv_bias_1x1_benchmark_args(4); + std::vector conv_bias_1x1_args_nchw = + get_conv_bias_1x1_benchmark_args(1); + constexpr size_t RUNS = 50; + + Benchmarker benchmark_conv1x1_nchw44(handle()); + benchmark_conv1x1_nchw44.set_before_exec_callback( + conv_bias::ConvBiasAlgoChecker( + "CONV1x1:AARCH64_INT8X8X32_MK4_4X4X16:24")); + benchmark_conv1x1_nchw44.set_times(RUNS) + .set_dtype(0, dtype::Int8()) + .set_dtype(1, dtype::Int8()) + .set_dtype(2, dtype::Int32()) + .set_dtype(4, dtype::Int32()) + .set_display(false); + + Benchmarker benchmark_conv1x1_nchw(handle()); + benchmark_conv1x1_nchw.set_before_exec_callback( + conv_bias::ConvBiasAlgoChecker( + "CONV1x1:AARCH64_INT8X8X32_K4X4X16:24")); + benchmark_conv1x1_nchw.set_times(RUNS) + .set_dtype(0, dtype::Int8()) + .set_dtype(1, dtype::Int8()) + .set_dtype(2, dtype::Int32()) + .set_dtype(4, dtype::Int32()) + .set_display(false); + + for (size_t i = 0; i < conv_bias_1x1_args_nchw44.size(); ++i) { + auto&& arg_nchw = conv_bias_1x1_args_nchw[i]; + auto&& arg_nchw44 = conv_bias_1x1_args_nchw44[i]; + + size_t IC = arg_nchw.src[1]; + size_t OH = arg_nchw.src[2]; + size_t OW = arg_nchw.src[3]; + size_t OC = arg_nchw.filter[0]; + size_t M = OC; + size_t K = IC; + size_t N = OH * OW; + + float computations = M * N * K * 2.f / (1024 * 1024 * 1024) * 1e3; + + auto conv1x1_nchw = benchmark_conv1x1_nchw.set_param(arg_nchw.param) + .exec({arg_nchw.src, + arg_nchw.filter, + arg_nchw.bias, + {}, + {}}) / + RUNS; + auto conv1x1_nchw44 = + benchmark_conv1x1_nchw44.set_param(arg_nchw44.param) + .exec({arg_nchw44.src, + arg_nchw44.filter, + arg_nchw44.bias, + {}, + {}}) / + RUNS; + printf("%s %s:\n conv_1x1_nchw: %f ms %f Gflops\nconv1x1_nchw44: %f ms " + "%f GFlops " + "speedup: " + "%f\n", + arg_nchw.src.to_string().c_str(), + arg_nchw.filter.to_string().c_str(), conv1x1_nchw, + computations / conv1x1_nchw, conv1x1_nchw44, + computations / conv1x1_nchw44, conv1x1_nchw / conv1x1_nchw44); + } +} +#endif + +#endif + +// vim: syntax=cpp.doxygen diff --git a/dnn/test/arm_common/conv_bias_multi_thread.cpp b/dnn/test/arm_common/conv_bias_multi_thread.cpp new file mode 100644 index 00000000..6c0553f6 --- /dev/null +++ b/dnn/test/arm_common/conv_bias_multi_thread.cpp @@ -0,0 +1,1422 @@ +/** + * \file dnn/test/arm_common/conv_bias_multi_thread.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 
Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "test/arm_common/fixture.h" +#include "test/common/benchmarker.h" +#include "test/common/conv_bias.h" + +using namespace megdnn; +using namespace test; +using namespace conv_bias; + +std::vector get_int8_quint8_conv_bias_args( + std::vector kernel, size_t stride, bool no_pad, bool no_bias, + bool no_nonlinemode) { + using namespace conv_bias; + using Param = param::ConvBias; + using NLMode = param::ConvBias::NonlineMode; + std::vector args; + + auto pack = [&](size_t n, size_t oc, size_t ic, size_t w, size_t h, + size_t kernel, size_t stride, NLMode nlmode) { + Param param; + param.stride_h = stride; + param.stride_w = stride; + if (!no_pad) { + param.pad_h = kernel / 2; + param.pad_w = kernel / 2; + } else { + param.pad_h = 0; + param.pad_w = 0; + } + param.nonlineMode = nlmode; + + args.emplace_back(param, TensorShape{n, ic, h, w}, + TensorShape{oc, ic, kernel, kernel}, TensorShape{}); + if (!no_bias) { + args.emplace_back(param, TensorShape{n, ic, h, w}, + TensorShape{oc, ic, kernel, kernel}, + TensorShape{1, oc, 1, 1}); + } + }; + + std::vector nonlinemode = {NLMode::IDENTITY}; + if (!no_nonlinemode) { + nonlinemode.emplace_back(NLMode::RELU); + nonlinemode.emplace_back(NLMode::H_SWISH); + } + + for (size_t n : {1, 2}) { + for (auto nlmode : nonlinemode) { + for (size_t ic : {1, 3, 7}) { + for (size_t oc : {1, 3, 7}) { + for (size_t size : {4, 6, 8, 14, 16, 18}) { + for (size_t kern : kernel) { + pack(n, oc, ic, size, size, kern, stride, nlmode); + } + } + } + } + } + } + return args; +} +std::vector get_nchw44_conv_bias_args( + std::vector kernel_vec, size_t stride, bool no_pad = false, + bool no_bias = false, bool no_nonlinemode = false, + bool is_input_nchw = false) { + using namespace conv_bias; + using NLMode = param::ConvBias::NonlineMode; + std::vector args; + + auto pack = [&](size_t n, size_t oc, size_t ic, size_t h, size_t w, + size_t kernel, size_t stride, size_t group, NLMode nlmode) { + constexpr int pack_c = 4; + const size_t pad = no_pad ? 0 : kernel / 2; + auto bias_mode = no_bias ? 
megdnn::BiasMode::NO_BIAS + : megdnn::BiasMode::BROADCAST_CHANNEL_BIAS; + auto oc_per_group = oc / group; + auto ic_per_group = ic / group; + bool ok_group = (oc % group == 0 && ic % group == 0) && + oc_per_group % pack_c == 0 && oc_per_group > 0 && + ic_per_group > 0; + bool nchw_disable = group > 1 || ic_per_group >= 4; + bool nchw44_disable = ic_per_group % pack_c != 0; + if (!(ok_group)) { + return; + } + if ((is_input_nchw && nchw_disable) || + (!is_input_nchw && nchw44_disable)) { + return; + } + + size_t kernel_h = kernel; + size_t kernel_w = kernel; + param::ConvBias param; + param.format = param::ConvBias::Format::NCHW44; + param.stride_h = stride; + param.stride_w = stride; + param.pad_h = pad; + param.pad_w = pad; + param.nonlineMode = nlmode; + auto src_tensor_shape = TensorShape{n, ic / pack_c, h, w, pack_c}; + auto weight_tensor_shape = TensorShape{ + oc / pack_c, ic / pack_c, kernel_h, kernel_w, pack_c, pack_c}; + auto bias_tensor_shape = TensorShape{}; + if (bias_mode == megdnn::BiasMode::BROADCAST_CHANNEL_BIAS) { + bias_tensor_shape = {1, oc / pack_c, 1, 1, pack_c}; + } + if (group == 1) { + param.sparse = param::ConvBias::Sparse::DENSE; + } else if (group > 1 && ic / group == 1 && oc / group == 1) { + megdnn_assert(0, "not support channel wise"); + param.sparse = param::ConvBias::Sparse::GROUP; + weight_tensor_shape = TensorShape{group / pack_c, 1, 1, + kernel_h, kernel_w, pack_c}; + } else if (group > 1 && oc_per_group % pack_c == 0 && oc / group > 0 && + ic_per_group % pack_c == 0 && ic / group > 0) { + param.sparse = param::ConvBias::Sparse::GROUP; + weight_tensor_shape = TensorShape{group, + oc_per_group / pack_c, + ic_per_group / pack_c, + kernel_h, + kernel_w, + pack_c, + pack_c}; + } + if (is_input_nchw) { + src_tensor_shape = TensorShape{n, ic, h, w}; + weight_tensor_shape = + TensorShape{oc / pack_c, kernel_h, kernel_w, ic, pack_c}; + } + args.emplace_back(param, src_tensor_shape, weight_tensor_shape, + bias_tensor_shape); + }; + + std::vector nonlinemode = {NLMode::IDENTITY}; + if (!no_nonlinemode) { + nonlinemode.emplace_back(NLMode::RELU); + nonlinemode.emplace_back(NLMode::H_SWISH); + } + for (auto nlmode : nonlinemode) + for (size_t n : {1, 2}) + for (size_t kernel : kernel_vec) + for (size_t oc : {4, 12, 32}) + for (size_t ic : {1, 3, 4, 12, 32}) + for (size_t h : {3, 5, 12}) + for (size_t w : {7, 16, 23}) { + for (size_t group = 1; + group <= std::min(oc, ic); ++group) { + pack(n, oc, ic, h, w, kernel, stride, group, + nlmode); + } + } + return args; +} + +std::vector get_int8_quint8_nchw44_channel_wise_args( + std::vector kernel, size_t stride, bool no_bias, + bool no_nonlinemode) { + using namespace conv_bias; + using Param = param::ConvBias; + using NLMode = param::ConvBias::NonlineMode; + std::vector args; + + auto pack = [&](size_t n, size_t group, size_t w, size_t h, size_t kernel, + size_t stride, NLMode nlmode, bool pad) { + Param param; + param.stride_h = stride; + param.stride_w = stride; + if (pad) { + param.pad_h = kernel / 2; + param.pad_w = kernel / 2; + } else { + param.pad_h = 0; + param.pad_w = 0; + } + param.nonlineMode = nlmode; + param.format = param::ConvBias::Format::NCHW44; + param.sparse = param::ConvBias::Sparse::GROUP; + + args.emplace_back(param, TensorShape{n, group, h, w, 4}, + TensorShape{group, 1, 1, kernel, kernel, 4}, + TensorShape{}); + if (!no_bias) { + args.emplace_back(param, TensorShape{n, group, h, w, 4}, + TensorShape{group, 1, 1, kernel, kernel, 4}, + TensorShape{1, group, 1, 1, 4}); + } + }; + + std::vector 
nonlinemode = {NLMode::IDENTITY}; + if (!no_nonlinemode) { + nonlinemode.emplace_back(NLMode::RELU); + nonlinemode.emplace_back(NLMode::H_SWISH); + } + for (size_t n : {1, 2}) { + for (auto nlmode : nonlinemode) { + for (bool pad : {true}) { + for (size_t group : {1, 2, 4, 7, 128}) { + for (size_t size : {4, 5, 6, 7, 8, 9, 10, 15, 40}) { + for (size_t kern : kernel) { + pack(n, group, size, size, kern, stride, nlmode, + pad); + } + } + } + } + for (bool pad : {false}) { + for (size_t group : {1, 2, 7, 128}) { + for (size_t size : {7, 8, 9, 10, 15, 40}) { + for (size_t kern : kernel) { + pack(n, group, size, size, kern, stride, nlmode, + pad); + } + } + } + } + } + } + return args; +} + +void checker_conv_bias_qint8x8x8(std::vector args, + Handle* handle, const char* algo_name) { + Checker checker(handle); + checker.set_before_exec_callback( + conv_bias::ConvBiasAlgoChecker(algo_name)); +#if MEGDNN_ARMV7 + checker.set_epsilon(1); +#endif + UniformIntRNG rng{-50, 50}; + checker.set_dtype(0, dtype::QuantizedS8(0.41113496f)) + .set_dtype(1, dtype::QuantizedS8(0.01887994f)) + .set_dtype(2, dtype::QuantizedS32(0.41113496f * 0.01887994f)) + .set_dtype(4, dtype::QuantizedS8(0.49550694f)) + .set_rng(0, &rng) + .set_rng(1, &rng) + .set_rng(2, &rng); + for (auto&& arg : args) { + checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}}); + } +} +void checker_conv_bias_qint8x8x32(std::vector args, + Handle* handle, const char* algo_name) { + Checker checker(handle); + + UniformIntRNG rng{-50, 50}; + checker.set_dtype(0, dtype::QuantizedS8(2.5f)) + .set_dtype(1, dtype::QuantizedS8(2.5f)) + .set_dtype(2, dtype::QuantizedS32(6.25f)) + .set_dtype(4, {}); + checker.set_before_exec_callback( + conv_bias::ConvBiasAlgoChecker(algo_name)); + for (auto&& arg : args) { + checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}}); + } +} +void checker_conv_bias_quint8x8x8(std::vector args, + Handle* handle, const char* algo_name) { + Checker checker(handle); + checker.set_before_exec_callback( + conv_bias::ConvBiasAlgoChecker(algo_name)); + UniformIntRNG rng(0, 255); + checker.set_dtype(0, dtype::Quantized8Asymm(0.2f, 100)) + .set_dtype(1, dtype::Quantized8Asymm(0.2f, 120)) + .set_dtype(2, dtype::QuantizedS32(0.04f)) + .set_dtype(4, dtype::Quantized8Asymm(1.4f, 110)) + .set_rng(0, &rng) + .set_rng(1, &rng) + .set_rng(2, &rng); + + for (auto&& arg : args) { + checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}}); + } +} +void checker_conv_bias_quint8x8x32(std::vector args, + Handle* handle, const char* algo_name) { + Checker checker(handle); + checker.set_before_exec_callback( + conv_bias::ConvBiasAlgoChecker(algo_name)); + + NormalRNG rng(128.f); + checker.set_rng(0, &rng).set_rng(1, &rng); + checker.set_dtype(0, dtype::Quantized8Asymm(1.2f, (uint8_t)127)) + .set_dtype(1, dtype::Quantized8Asymm(1.3f, (uint8_t)129)) + .set_dtype(2, dtype::QuantizedS32(1.2 * 1.3)) + .set_dtype(4, {}); + for (auto&& arg : args) { + checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}}); + } +} +void checker_conv_bias_int8x8x32_multi(std::vector args, + Handle* handle, const char* algo_name) { + Checker checker(handle); + checker.set_before_exec_callback( + conv_bias::ConvBiasAlgoChecker(algo_name)); + checker.set_dtype(0, dtype::Int8()); + checker.set_dtype(1, dtype::Int8()); + checker.set_dtype(2, dtype::Int32()); + checker.set_dtype(4, dtype::Int32()); + for (auto&& arg : args) { + checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}}); + } +} + 
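+//! NOTE (annotation, not in the original patch): the checker helpers above
+//! wrap Checker with a fixed dtype / RNG setup so that each TEST_F below only
+//! has to supply the argument list and the algorithm name. For the quantized
+//! variants the accumulator / bias dtype is QuantizedS32 whose scale is the
+//! product of the two input scales (e.g. 2.5f * 2.5f = 6.25f, or
+//! 0.41113496f * 0.01887994f), since an int8 x int8 product accumulated in
+//! int32 carries the combined scale of source and filter.
+//! checker_conv_bias_qint8x8x8 additionally relaxes epsilon to 1 on armv7.
+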
+/**********************************F32 direct************************/ +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_LARGE_GROUP) { + check_conv_bias( + get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false), + handle(), "F32DIRECT_LARGE_GROUP"); +} +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_SMALL_GROUP) { + check_conv_bias( + get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false), + handle(), "F32DIRECT_SMALL_GROUP"); +} +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR1_LARGE_GROUP) { + check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 1, false, false, false), + handle(), "F32STRD1_LARGE_GROUP"); +} +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR1_SMALL_GROUP) { + check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 1, false, false, false), + handle(), "F32STRD1_SMALL_GROUP"); +} +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2_LARGE_GROUP) { + check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 2, false, false, false), + handle(), "F32STRD2_LARGE_GROUP"); +} +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2_SMALL_GROUP) { + check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 2, false, false, false), + handle(), "F32STRD2_SMALL_GROUP"); +} +/**********************************F16 direct************************/ +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_LARGE_GROUP) { + NormalRNG rng(1); + checker_conv_bias_f16( + get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false), + handle(), rng, "F16DIRECT_LARGE_GROUP", 0.03); +} +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_SMALL_GROUP) { + NormalRNG rng(1); + checker_conv_bias_f16( + get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false), + handle(), rng, "F16DIRECT_SMALL_GROUP", 0.03); +} +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_STR1_LARGE_GROUP) { + NormalRNG rng(1); + checker_conv_bias_f16(get_conv_bias_args({2, 3, 5}, 1, false, false, false), + handle(), rng, "F16STRD1_LARGE_GROUP", 0.03); +} +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_STR1_SMALL_GROUP) { + NormalRNG rng(1); + checker_conv_bias_f16(get_conv_bias_args({2, 3, 5}, 1, false, false, false), + handle(), rng, "F16STRD1_SMALL_GROUP", 0.03); +} +#endif + +/**********************************algo 8816 direct************************/ +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_DIRECT_LARGE_GROUP) { + checker_conv_bias_int8x8x16( + get_conv_bias_args({2, 3, 5}, 1, false, true, true), handle(), + "I8816DIRECT_LARGE_GROUP"); +} +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_DIRECT_SMALL_GROUP) { + checker_conv_bias_int8x8x16( + get_conv_bias_args({2, 3, 5}, 1, false, true, true), handle(), + "I8816DIRECT_SMALL_GROUP"); +} +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_STRIDE2_LARGE_GROUP) { + checker_conv_bias_int8x8x16( + get_conv_bias_args({2, 3, 5}, 2, false, true, true), handle(), + "I8816STRD2_LARGE_GROUP"); +} +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_STRIDE2_SMALL_GROUP) { + checker_conv_bias_int8x8x16( + get_conv_bias_args({2, 3, 5}, 2, false, true, true), handle(), + "I8816STRD2_SMALL_GROUP"); +} + +/**********************************algo 8-8-32 direct************************/ +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE1_LARGE_GROUP) { + checker_conv_bias_int8x8x32_multi( + get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(), + "S8STRD1_LARGE_GROUP"); +} +TEST_F(ARM_COMMON_MULTI_THREADS, 
CONVBIAS_INT8_INT8_INT32_STRIDE1_SMALL_GROUP) { + checker_conv_bias_int8x8x32_multi( + get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(), + "S8STRD1_SMALL_GROUP"); +} +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE2_LARGE_GROUP) { + checker_conv_bias_int8x8x32_multi( + get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(), + "S8STRD2_LARGE_GROUP"); +} +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE2_SMALL_GROUP) { + checker_conv_bias_int8x8x32_multi( + get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(), + "S8STRD2_SMALL_GROUP"); +} + +TEST_F(ARM_COMMON_MULTI_THREADS, + CONV_BIAS_INT8_INT8_INT32_CHANNEL_WISE_DIRECT1_NCHW44) { + checker_conv_bias_int8x8x32_multi( + get_int8_quint8_nchw44_channel_wise_args({2, 3, 5}, 1, false, true), + handle(), "S8_CHAN_WISE_STRD1_NCHW44"); +} + +TEST_F(ARM_COMMON_MULTI_THREADS, + CONV_BIAS_INT8_INT8_INT32_CHANNEL_WISE_DIRECT2_NCHW44) { + checker_conv_bias_int8x8x32_multi( + get_int8_quint8_nchw44_channel_wise_args({2, 3, 5}, 2, false, true), + handle(), "S8_CHAN_WISE_STRD2_NCHW44"); +} + +/********************************qint8 direct******************************/ +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_LARGE_GROUP) { + checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args( + {2, 3, 5, 7}, 1, false, false, false), + handle(), "S8STRD1_LARGE_GROUP"); +} +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_SMALL_GROUP) { + checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args( + {2, 3, 5, 7}, 1, false, false, false), + handle(), "S8STRD1_SMALL_GROUP"); +} +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_LARGE_GROUP) { + checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args( + {2, 3, 5, 7}, 2, false, false, false), + handle(), "S8STRD2_LARGE_GROUP"); +} +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_SMALL_GROUP) { + checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args( + {2, 3, 5, 7}, 2, false, false, false), + handle(), "S8STRD2_SMALL_GROUP"); +} +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44) { + checker_conv_bias_qint8x8x8( + get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, false), + handle(), "S8_NCHW44_DIRECT_STRD1"); +} +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44) { + checker_conv_bias_qint8x8x8( + get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, false), + handle(), "S8_NCHW44_DIRECT_STRD2"); +} +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QS8_CHANNEL_WISE_DIRECT1_NCHW44) { + checker_conv_bias_qint8x8x8(get_int8_quint8_nchw44_channel_wise_args( + {2, 3, 5}, 1, false, false), + handle(), "S8_CHAN_WISE_STRD1_NCHW44"); +} + +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QS8_CHANNEL_WISE_DIRECT2_NCHW44) { + checker_conv_bias_qint8x8x8(get_int8_quint8_nchw44_channel_wise_args( + {2, 3, 5}, 2, false, false), + handle(), "S8_CHAN_WISE_STRD2_NCHW44"); +} + +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44) { + checker_conv_bias_qint8x8x8( + get_nchw44_conv_bias_args({3, 5, 7}, 2, false, false, false, true), + handle(), "S8_CONV_NCHW_NCHW44"); +} + +/*****************************quint8 direct****************************/ +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE1_LARGE_GROUP) { + checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args( + {2, 3, 5, 7}, 1, false, false, false), + handle(), "QU8STRD1_LARGE_GROUP"); +} +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE1_SMALL_GROUP) { + checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args( + 
{2, 3, 5, 7}, 1, false, false, false), + handle(), "QU8STRD1_SMALL_GROUP"); +} +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2_LARGE_GROUP) { + checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args( + {2, 3, 5, 7}, 2, false, false, false), + handle(), "QU8STRD2_LARGE_GROUP"); +} +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2_SMALL_GROUP) { + checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args( + {2, 3, 5, 7}, 2, false, false, false), + handle(), "QU8STRD2_SMALL_GROUP"); +} + +/****************************dot qint8 direct*************************/ +#if __ARM_FEATURE_DOTPROD +TEST_F(ARM_COMMON_MULTI_THREADS, + CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_LARGE_GROUP) { + checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args( + {2, 3, 5, 7}, 1, false, false, false), + handle(), "ARMDOTS8STRD1_LARGE_GROUP"); +} +TEST_F(ARM_COMMON_MULTI_THREADS, + CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_SMALL_GROUP) { + checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args( + {2, 3, 5, 7}, 1, false, false, false), + handle(), "ARMDOTS8STRD1_SMALL_GROUP"); +} +TEST_F(ARM_COMMON_MULTI_THREADS, + CONV_BIAS_INT8_STRIDE2_WITHDOTPROD_LARGE_GROUP) { + checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args( + {2, 3, 5, 7}, 2, false, false, false), + handle(), "ARMDOTS8STRD2_LARGE_GROUP"); +} +TEST_F(ARM_COMMON_MULTI_THREADS, + CONV_BIAS_INT8_STRIDE2_WITHDOTPROD_SMALL_GROUP) { + checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args( + {2, 3, 5, 7}, 2, false, false, false), + handle(), "ARMDOTS8STRD2_SMALL_GROUP"); +} + +/****************************dot 8-8-32 direct*************************/ +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD1_WITHDOT_LARGE_GROUP) { + checker_conv_bias_qint8x8x32( + get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(), + "ARMDOTS8STRD1_LARGE_GROUP"); +} +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD1_WITHDOT_SMALL_GROUP) { + checker_conv_bias_qint8x8x32( + get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(), + "ARMDOTS8STRD1_SMALL_GROUP"); +} +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD2_WITHDOT_LARGE_GROUP) { + checker_conv_bias_qint8x8x32( + get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(), + "ARMDOTS8STRD2_LARGE_GROUP"); +} +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD2_WITHDOT_SMALL_GROUP) { + checker_conv_bias_qint8x8x32( + get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(), + "ARMDOTS8STRD2_SMALL_GROUP"); +} +/******************************dot quint8*****************************/ +TEST_F(ARM_COMMON_MULTI_THREADS, + CONV_BIAS_QUINT8_STRIDE1_WITHDOTPROD_LARGE_GROUP) { + checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args( + {2, 3, 5, 7}, 1, false, false, false), + handle(), "ARMDOTU8STRD1_LARGE_GROUP"); +} +TEST_F(ARM_COMMON_MULTI_THREADS, + CONV_BIAS_QUINT8_STRIDE1_WITHDOTPROD_SMALL_GROUP) { + checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args( + {2, 3, 5, 7}, 1, false, false, false), + handle(), "ARMDOTU8STRD1_SMALL_GROUP"); +} +TEST_F(ARM_COMMON_MULTI_THREADS, + CONV_BIAS_QUINT8_STRIDE2_WITHDOTPROD_LARGE_GROUP) { + checker_conv_bias_quint8x8x8( + get_int8_quint8_conv_bias_args({2, 5, 7}, 2, false, false, false), + handle(), "ARMDOTU8STRD2_LARGE_GROUP"); +} +TEST_F(ARM_COMMON_MULTI_THREADS, + CONV_BIAS_QUINT8_STRIDE2_WITHDOTPROD_SMALL_GROUP) { + checker_conv_bias_quint8x8x8( + get_int8_quint8_conv_bias_args({2, 5, 7}, 2, false, false, false), + handle(), "ARMDOTU8STRD2_SMALL_GROUP"); +} + +/******************************dot 
quint8x8x32***********************/ +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE1_LARGE_GROUP) { + checker_conv_bias_quint8x8x32( + get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(), + "ARMDOTU8STRD1_LARGE_GROUP"); +} +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE1_SMALL_GROUP) { + checker_conv_bias_quint8x8x32( + get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(), + "ARMDOTU8STRD1_SMALL_GROUP"); +} +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE2_LARGE_GROUP) { + checker_conv_bias_quint8x8x32( + get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(), + "ARMDOTU8STRD2_LARGE_GROUP"); +} +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE2_SMALL_GROUP) { + checker_conv_bias_quint8x8x32( + get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(), + "ARMDOTU8STRD2_SMALL_GROUP"); +} +#endif + +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F23_4) { + using namespace conv_bias; + std::vector args = get_winograd_mk_packed_args(); + Checker checker(handle()); + + check_winograd("4:2:32", checker, args, param::MatrixMul::Format::MK4); +} +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63) { + using namespace conv_bias; + std::vector args = get_winograd_args(3); + Checker checker(handle()); + + check_winograd("1:6:32", checker, args); +} + +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4) { + using namespace conv_bias; + std::vector args = get_winograd_mk_packed_args(); + Checker checker(handle()); + + check_winograd("4:6:32", checker, args, param::MatrixMul::Format::MK4); +} + +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F54) { + using namespace conv_bias; + std::vector args = get_winograd_args(4); + Checker checker(handle()); + + check_winograd("1:5:32", checker, args); +} + +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F45) { + using namespace conv_bias; + std::vector args = get_winograd_args(5); + Checker checker(handle()); + + check_winograd("1:4:32", checker, args); +} + +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD) { + using namespace conv_bias; + std::vector args = get_winograd_args(3); + + Checker checker(handle()); + + auto extra_impl = [](const TensorNDArray& tensors, uint32_t m, + param::ConvBias param, Handle* handle) { + megdnn_assert(param.format == param::ConvBias::Format::NCHW); + auto winograd_preprocess_opr = + handle->create_operator(); + winograd_preprocess_opr->param().output_block_size = m; + TensorLayout filter_transform_layout; + winograd_preprocess_opr->deduce_layout(tensors[1].layout, + filter_transform_layout); + size_t winograd_preprocess_workspace_in_bytes = + winograd_preprocess_opr->get_workspace_in_bytes( + tensors[1].layout, filter_transform_layout); + + auto conv_bias_opr = handle->create_operator(); + conv_bias_opr->param() = param; + conv_bias_opr->param().format = param::ConvBias::Format::NCHW_WINOGRAD; + conv_bias_opr->param().output_block_size = m; + size_t conv_bias_workspace_in_bytes = + conv_bias_opr->get_workspace_in_bytes( + tensors[0].layout, filter_transform_layout, + tensors[2].layout, tensors[3].layout, + tensors[4].layout); + + WorkspaceBundle wb(nullptr, {filter_transform_layout.span().dist_byte(), + conv_bias_workspace_in_bytes, + winograd_preprocess_workspace_in_bytes}); + wb.set(malloc(wb.total_size_in_bytes())); + + TensorND filter_transform_tensor(wb.get(0), + std::move(filter_transform_layout)); + winograd_preprocess_opr->exec(tensors[1], filter_transform_tensor, + wb.get_workspace(2)); + 
conv_bias_opr->exec(tensors[0], filter_transform_tensor, tensors[2], + tensors[3], tensors[4], wb.get_workspace(1)); + + free(wb.ptr()); + }; + + auto run = [&checker, &extra_impl]( + Handle* handle, const std::vector& args, + const std::vector& out_size, DType A_dtype, + DType B_dtype, DType C_dtype, DType D_dtype, + const float eps) { + for (auto&& arg : args) { + for (uint32_t m : out_size) { + checker.set_extra_opr_impl(std::bind(extra_impl, + std::placeholders::_1, m, + arg.param, handle)); + checker.set_dtype(0, A_dtype) + .set_dtype(1, B_dtype) + .set_dtype(2, C_dtype) + .set_dtype(4, D_dtype) + .set_epsilon(eps) + .set_param(arg.param) + .execs({arg.src, arg.filter, arg.bias, {}, {}}); + } + } + }; + run(handle(), args, {6}, dtype::Float32(), dtype::Float32(), + dtype::Float32(), dtype::Float32(), 1e-3f); +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00); + checker.set_rng(0, rng).set_rng(1, rng).set_rng(2, rng); + run(handle(), args, {6}, dtype::Float16(), dtype::Float16(), + dtype::Float16(), dtype::Float16(), 0.35f); +#endif +} + +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F32_1) { + using namespace conv_bias; + + Checker checker(handle()); + auto run = [&checker](Handle* handle, const std::vector& args, + const std::vector& out_size, DType A_dtype, + DType B_dtype, DType C_dtype, DType D_dtype, + param::MatrixMul::Format format, float eps) { + for (auto&& arg : args) { + for (uint32_t m : out_size) { + checker.set_extra_opr_impl(std::bind( + winograd_algo_extra_impl, std::placeholders::_1, m, + arg.param, handle, format)); + checker.set_dtype(0, A_dtype) + .set_dtype(1, B_dtype) + .set_dtype(2, C_dtype) + .set_dtype(4, D_dtype) + .set_epsilon(eps) + .set_param(arg.param) + .execs({arg.src, arg.filter, arg.bias, {}, {}}); + } + } + }; + std::vector args = get_winograd_mk_packed_args(8); + std::vector args_first_half(args.begin(), + args.begin() + args.size() / 2); + run(handle(), args_first_half, {2, 6}, dtype::Float32{}, dtype::Float32{}, + dtype::Float32{}, dtype::Float32{}, param::MatrixMul::Format::MK4, + 1e-3f); +} + +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F32_2) { + using namespace conv_bias; + + Checker checker(handle()); + auto run = [&checker](Handle* handle, const std::vector& args, + const std::vector& out_size, DType A_dtype, + DType B_dtype, DType C_dtype, DType D_dtype, + param::MatrixMul::Format format, float eps) { + for (auto&& arg : args) { + for (uint32_t m : out_size) { + checker.set_extra_opr_impl(std::bind( + winograd_algo_extra_impl, std::placeholders::_1, m, + arg.param, handle, format)); + checker.set_dtype(0, A_dtype) + .set_dtype(1, B_dtype) + .set_dtype(2, C_dtype) + .set_dtype(4, D_dtype) + .set_epsilon(eps) + .set_param(arg.param) + .execs({arg.src, arg.filter, arg.bias, {}, {}}); + } + } + }; + std::vector args = get_winograd_mk_packed_args(8); + std::vector args_second_half(args.begin() + args.size() / 2, + args.end()); + run(handle(), args_second_half, {2, 6}, dtype::Float32{}, dtype::Float32{}, + dtype::Float32{}, dtype::Float32{}, param::MatrixMul::Format::MK4, + 1e-3f); +} + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F16) { + using namespace conv_bias; + + Checker checker(handle()); + auto run = [&checker](Handle* handle, const std::vector& args, + const std::vector& out_size, DType A_dtype, + DType B_dtype, DType C_dtype, DType D_dtype, + param::MatrixMul::Format format, float eps) { + 
for (auto&& arg : args) { + for (uint32_t m : out_size) { + checker.set_extra_opr_impl(std::bind( + winograd_algo_extra_impl, std::placeholders::_1, m, + arg.param, handle, format)); + checker.set_dtype(0, A_dtype) + .set_dtype(1, B_dtype) + .set_dtype(2, C_dtype) + .set_dtype(4, D_dtype) + .set_epsilon(eps) + .set_param(arg.param) + .execs({arg.src, arg.filter, arg.bias, {}, {}}); + } + } + }; + + std::vector args = get_winograd_mk_packed_args(8); + Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00); + checker.set_rng(0, rng).set_rng(1, rng).set_rng(2, rng); + run(handle(), args, {2}, dtype::Float16{}, dtype::Float16{}, + dtype::Float16{}, dtype::Float16{}, param::MatrixMul::Format::MK8, + 0.25); +} +#endif +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_INT8) { + using namespace conv_bias; + + Checker checker(handle()); + auto run = [&checker](Handle* handle, const std::vector& args, + const std::vector& out_size, DType A_dtype, + DType B_dtype, DType C_dtype, DType D_dtype, + param::MatrixMul::Format format, float eps) { + for (auto&& arg : args) { + for (uint32_t m : out_size) { + checker.set_extra_opr_impl(std::bind( + winograd_algo_extra_impl, std::placeholders::_1, m, + arg.param, handle, format)); + checker.set_dtype(0, A_dtype) + .set_dtype(1, B_dtype) + .set_dtype(2, C_dtype) + .set_dtype(4, D_dtype) + .set_epsilon(eps) + .set_param(arg.param) + .execs({arg.src, arg.filter, arg.bias, {}, {}}); + } + } + }; + +#if MEGDNN_AARCH64 + const char* matmul_name = "AARCH64_INT16X16X32_MK8_8X8"; +#else + const char* matmul_name = "ARMV7_INT16X16X32_MK8_4X8"; +#endif + checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker( + ssprintf("WINOGRAD:%s:8:2:32", matmul_name).c_str())); + + std::vector args = get_winograd_mk_packed_args(8); + std::vector quantized_args = + get_quantized_winograd_mk_packed_args(8); + UniformIntRNG int_rng{-50, 50}; + checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng); + run(handle(), quantized_args, {2}, dtype::QuantizedS8(2.5f), + dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), + dtype::QuantizedS8(60.25f), param::MatrixMul::Format::MK8, 1e-3); +} + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F23) { + using namespace conv_bias; + std::vector args = get_winograd_mk_packed_args(); + Checker checker(handle()); + check_winograd_fp16("1:2:32", checker, args, NULL, 0.08); +} +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F45_1) { + using namespace conv_bias; + std::vector args = get_winograd_args(5); + std::vector args_head_half(args.begin(), + args.begin() + args.size() / 2); + Checker checker(handle()); + //! fp16 range -1.0 ~ 1.0 + Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00); + check_winograd_fp16("1:4:32", checker, args_head_half, rng, 0.25); +} +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F45_2) { + using namespace conv_bias; + std::vector args = get_winograd_args(5); + std::vector args_back_half(args.begin() + args.size() / 2, + args.end()); + Checker checker(handle()); + //! fp16 range -1.0 ~ 1.0 + Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00); + check_winograd_fp16("1:4:32", checker, args_back_half, rng, 0.25); +} +//! FIXME: This test may be failed if run `ARM_COMMON.CONV_BIAS_WINOGRAD*`, but +//! 
it will pass when run single testcase +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F63) { + using namespace conv_bias; + std::vector args = get_winograd_args(3); + Checker checker(handle()); + //! fp16 range -1.0 ~ 1.0 + Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00); + check_winograd_fp16("1:6:32", checker, args, rng, 0.3); +} +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_8x8_1) { + using namespace conv_bias; + std::vector args = get_winograd_mk_packed_args(8); + std::vector args_head_half(args.begin(), + args.begin() + args.size() / 2); + Checker checker(handle()); + Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00); + check_winograd_fp16("8:2:32", checker, args_head_half, rng, 0.25, + param::MatrixMul::Format::MK8); +} +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_8x8_2) { + using namespace conv_bias; + std::vector args = get_winograd_mk_packed_args(8); + std::vector args_back_half(args.begin() + args.size() / 2, + args.end()); + Checker checker(handle()); + Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00); + check_winograd_fp16("8:2:32", checker, args_back_half, rng, 0.25, + param::MatrixMul::Format::MK8); +} +#endif +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_INT8_8X8) { + using namespace conv_bias; + std::vector args = get_quantized_winograd_mk_packed_args(8); + Checker checker(handle()); + UniformIntRNG rng{-50, 50}; + checker.set_dtype(0, dtype::QuantizedS8(2.5f)) + .set_dtype(1, dtype::QuantizedS8(2.5f)) + .set_dtype(2, dtype::QuantizedS32(6.25f)) + .set_dtype(4, dtype::QuantizedS8(60.25f)) + .set_rng(0, &rng) + .set_rng(1, &rng) + .set_rng(2, &rng); + + check_winograd("8:2:32", checker, args, param::MatrixMul::Format::MK8); +} + +void checker_conv_bias(std::vector args, Handle* handle, + RNG* rng, float epsilon, DType type0, DType type1, + DType type2, DType type3, const char* algo_name) { + using namespace conv_bias; + + Checker checker(handle); + checker.set_before_exec_callback( + conv_bias::ConvBiasAlgoChecker(algo_name)); + checker.set_dtype(0, type0); + checker.set_dtype(1, type1); + checker.set_dtype(2, type2); + checker.set_dtype(4, type3); + checker.set_epsilon(epsilon); + if (NULL != rng) { + checker.set_rng(0, rng).set_rng(1, rng).set_rng(2, rng).set_rng(3, rng); + } + for (auto&& arg : args) { + checker.set_param(arg.param).execs( + {arg.src, arg.filter, arg.bias, {}, {}}); + } +} +// clang-format off +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COL_FP32_STRIDE2) { +#define cb(name) \ + check_conv_bias( \ + get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 2, false, false, false), \ + handle(), name); +#if MEGDNN_AARCH64 + cb("IM2COLMATMUL:AARCH64_F32K8X12X1") + cb("IM2COLMATMUL:AARCH64_F32K4X16X1") + cb("IM2COLMATMUL:FB_F32_K8X12X1") +#elif MEGDNN_ARMV7 + cb("IM2COLMATMUL:ARMV7_F32") +#endif +#undef cb +} + +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COL_FP32_STRIDE1) { +#define cb(name) \ + check_conv_bias( \ + get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, false), \ + handle(), name); +#if MEGDNN_AARCH64 + cb("IM2COLMATMUL:AARCH64_F32K8X12X1") + cb("IM2COLMATMUL:AARCH64_F32K4X16X1") + cb("IM2COLMATMUL:FB_F32_K8X12X1") +#elif MEGDNN_ARMV7 + cb("IM2COLMATMUL:ARMV7_F32") + cb("IM2COLMATMUL:FB_F32_K8X12X1") +#endif +#undef cb +} + +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM) { + UniformIntRNG rng{-50, 50}; + +#define cb(name) \ + checker_conv_bias(get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, \ + false, true, true), \ + handle(), &rng, epsilon, 
dtype::QuantizedS8(2.5f), \ + dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \ + dtype::QuantizedS8(60.25f), name); \ + checker_conv_bias( \ + get_conv_bias_args({1}, 2, false, false, false, true, true), \ + handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \ + dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \ + dtype::QuantizedS8(60.25f), name); + + float epsilon = 0.001; +#if MEGDNN_AARCH64 +#if __ARM_FEATURE_DOTPROD + cb("IM2COLMATMUL:AARCH64_INT8X8X32_K8X12X4_DOTPROD"); +#else + cb("IM2COLMATMUL:AARCH64_INT8X8X32_K8X8X8"); + cb("IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16"); +#endif +#elif MEGDNN_ARMV7 + epsilon = 1; + cb("IM2COLMATMUL:ARMV7_INT8X8X32_K4X8X8"); +#endif +#undef cb +} +// clang-format on +#if MEGDNN_AARCH64 || MEGDNN_ARMV7 +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDASYM) { + NormalRNG rng(128.f); + +#define cb(name) \ + checker_conv_bias(get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, \ + false, true, true), \ + handle(), &rng, epsilon, \ + dtype::Quantized8Asymm(1.2f, (uint8_t)125), \ + dtype::Quantized8Asymm(1.3f, (uint8_t)129), \ + dtype::QuantizedS32(1.2 * 1.3), \ + dtype::Quantized8Asymm(50.3f, (uint8_t)120), name); \ + checker_conv_bias( \ + get_conv_bias_args({1}, 2, false, false, false, true, true), \ + handle(), &rng, epsilon, \ + dtype::Quantized8Asymm(1.2f, (uint8_t)125), \ + dtype::Quantized8Asymm(1.3f, (uint8_t)129), \ + dtype::QuantizedS32(1.2 * 1.3), \ + dtype::Quantized8Asymm(50.3f, (uint8_t)120), name); + float epsilon = 0.001; +#if MEGDNN_AARCH64 +#if __ARM_FEATURE_DOTPROD + cb("IM2COLMATMUL:AARCH64_QUINT8_K8X8X4_DOTPROD"); +#else + cb("IM2COLMATMUL:AARCH64_QUINT8_K8X8X8"); +#endif +#elif MEGDNN_ARMV7 + epsilon = 1; + cb("IM2COLMATMUL:ARMV7_QUINT8_K4X8X8"); +#endif +#undef cb +} +#endif + +#if MEGDNN_AARCH64 || MEGDNN_ARMV7 +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUINT8x8x32) { + UniformIntRNG rng{-50, 50}; + float epsilon = 0.001; +#define cb(name) \ + checker_conv_bias( \ + get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true), \ + handle(), &rng, epsilon, \ + dtype::Quantized8Asymm(1.2f, (uint8_t)125), \ + dtype::Quantized8Asymm(1.3f, (uint8_t)129), \ + dtype::QuantizedS32(1.2 * 1.3), {}, name); \ + checker_conv_bias(get_conv_bias_args({1}, 2, false, true, true), handle(), \ + &rng, epsilon, \ + dtype::Quantized8Asymm(1.2f, (uint8_t)125), \ + dtype::Quantized8Asymm(1.3f, (uint8_t)129), \ + dtype::QuantizedS32(1.2 * 1.3), {}, name); + +#if MEGDNN_AARCH64 +#if __ARM_FEATURE_DOTPROD + cb("IM2COLMATMUL:AARCH64_QUINT8_K8X8X4_DOTPROD"); +#else + cb("IM2COLMATMUL:AARCH64_QUINT8_K8X8X8"); +#endif +#elif MEGDNN_ARMV7 +#if __ARM_FEATURE_DOTPROD + cb("IM2COLMATMUL:AARCH32_QUINT8_K4X8X4"); +#endif + cb("IM2COLMATMUL:ARMV7_QUINT8_K4X8X8"); +#endif +#undef cb +} +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16) { + UniformIntRNG rng{-50, 50}; + float epsilon = 0.001; +#define cb(name) \ + checker_conv_bias( \ + get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true), \ + handle(), &rng, epsilon, dtype::Int8{}, dtype::Int8{}, \ + dtype::Int16{}, dtype::Int16{}, name); \ + checker_conv_bias(get_conv_bias_args({1}, 2, false, true, true), handle(), \ + &rng, epsilon, dtype::Int8{}, dtype::Int8{}, \ + dtype::Int16{}, dtype::Int16{}, name); + +#if MEGDNN_AARCH64 + cb("IM2COLMATMUL:AARCH64_INT8X8X16_K8X8X8"); + cb("IM2COLMATMUL:AARCH64_INT8X8X16_K4X4X16"); + cb("IM2COLMATMUL:ARM_COMMON_INT8X8X16"); +#elif MEGDNN_ARMV7 + cb("IM2COLMATMUL:ARM_COMMON_INT8X8X16"); + 
cb("IM2COLMATMUL:ARMV7_INT8X8X16_K4X8X8"); + cb("IM2COLMATMUL:ARMV7_INT8X8X16_K4X2X16"); +#endif +#undef cb +} +#endif + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_FP16) { + using namespace conv_bias; + + param::ConvBias cur_param; + + std::vector args = + get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, false); + std::vector args1 = + get_conv_bias_args({1}, 2, false, false, false); + args.insert(args.begin(), args1.begin(), args1.end()); + + NormalRNG rng(1); +#define cb(name) \ + checker_conv_bias(args, handle(), &rng, 0.03, dtype::Float16{}, \ + dtype::Float16{}, dtype::Float16{}, dtype::Float16{}, \ + name); + +#if MEGDNN_AARCH64 + cb("IM2COLMATMUL:AARCH64_F16_K8X24X1"); +#elif MEGDNN_ARMV7 + cb("IM2COLMATMUL:AARCH32_F16_K4X16X1"); +#endif +#undef cb +} +#endif + +void checker_conv_bias_mul_int8x8x32(std::vector args, + Handle* handle, const char* algo_name) { + using namespace conv_bias; + + Checker checker(handle); + checker.set_before_exec_callback( + conv_bias::ConvBiasAlgoChecker(algo_name)); + checker.set_dtype(0, dtype::Int8()); + checker.set_dtype(1, dtype::Int8()); + checker.set_dtype(2, dtype::Int32()); + checker.set_dtype(4, dtype::Int32()); + for (auto&& arg : args) { + checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}}); + } + + UniformIntRNG rng{-50, 50}; + for (auto&& arg : args) { + checker.set_dtype(0, dtype::QuantizedS8(2.5f)) + .set_dtype(1, dtype::QuantizedS8(2.5f)) + .set_dtype(2, dtype::QuantizedS32(6.25f)) + .set_dtype(4, {}) + .set_rng(0, &rng) + .set_rng(1, &rng) + .set_rng(2, &rng) + .set_param(arg.param) + .execs({arg.src, arg.filter, {}, {}, {}}); + } +} + +#if MEGDNN_AARCH64 || MEGDNN_ARMV7 +#if !__ARM_FEATURE_DOTPROD +TEST_F(ARM_COMMON, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44) { + using namespace conv_bias; + std::vector args = + get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true); + +#define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name); +#if MEGDNN_AARCH64 + cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:96"); +#else + cb("IM2COLMATMUL:ARMV7_INT8X8X32_MK4_4X2X16:96"); +#endif +#undef cb +} + +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_MULTI) { + using namespace conv_bias; + std::vector args = + get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true); + +#define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name); +#if MEGDNN_AARCH64 + cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:96"); +#else + cb("IM2COLMATMUL:ARMV7_INT8X8X32_MK4_4X2X16:96"); +#endif + +#undef cb +} + +TEST_F(ARM_COMMON, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_NCHW44) { + UniformIntRNG rng{-50, 50}; + +#define cb(name) \ + checker_conv_bias(get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1), \ + handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \ + dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \ + dtype::QuantizedS8(60.25f), name); + float epsilon = 0.001; +#if MEGDNN_AARCH64 + cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:96"); +#else + cb("IM2COLMATMUL:ARMV7_INT8X8X32_MK4_4X2X16:96"); +#endif +#undef cb +} + +TEST_F(ARM_COMMON_MULTI_THREADS, + CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_NCHW44_MULTI) { + UniformIntRNG rng{-50, 50}; + +#define cb(name) \ + checker_conv_bias(get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1), \ + handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \ + dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \ + dtype::QuantizedS8(60.25f), name); + float epsilon = 0.001; +#if MEGDNN_AARCH64 + 
cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:96"); +#else + cb("IM2COLMATMUL:ARMV7_INT8X8X32_MK4_4X2X16:96"); +#endif +#undef cb +} + +#endif +#endif + +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32) { + using namespace conv_bias; + std::vector args = + get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true); + std::vector args1 = + get_conv_bias_args({1}, 2, false, true, true); + args.insert(args.begin(), args1.begin(), args1.end()); + +#define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name); + +#if MEGDNN_AARCH64 +#if __ARM_FEATURE_DOTPROD + cb("IM2COLMATMUL:AARCH64_INT8X8X32_K8X12X4_DOTPROD"); +#else + cb("IM2COLMATMUL:AARCH64_INT8X8X32_K8X8X8"); + cb("IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16"); +#endif +#elif MEGDNN_ARMV7 +#if __ARM_FEATURE_DOTPROD + cb("IM2COLMATMUL:AARCH32_INT8_K6X8X4"); +#endif + cb("IM2COLMATMUL:ARMV7_INT8X8X32_K4X8X8"); +#endif + +#if MEGDNN_ARMV7 + cb("IM2COLMATMUL:ARMV7_INT8X8X32_K4X2X16"); +#endif +#undef cb +} + +/***************************** Conv1x1 Algo Test ***********************/ +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F32) { + using namespace conv_bias; + std::vector args = get_conv_bias_1x1_args(false, false); +#if MEGDNN_AARCH64 + check_conv_bias(args, handle(), "CONV1x1:AARCH64_F32K8X12X1:24"); +#elif MEGDNN_ARMV7 + check_conv_bias(args, handle(), "CONV1x1:ARMV7_F32:48"); +#endif +} + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F16) { + using namespace conv_bias; + std::vector args = get_conv_bias_1x1_args(false, false); + NormalRNG rng(1); +#if MEGDNN_AARCH64 + checker_conv_bias(args, handle(), &rng, 0.03, dtype::Float16{}, + dtype::Float16{}, dtype::Float16{}, dtype::Float16{}, + "CONV1x1:AARCH64_F16_K8X24X1:48"); +#elif MEGDNN_ARMV7 + checker_conv_bias(args, handle(), &rng, 0.03, dtype::Float16{}, + dtype::Float16{}, dtype::Float16{}, dtype::Float16{}, + "CONV1x1:AARCH32_F16_K4X16X1:24"); +#endif +} +#endif + +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUANTIZEDSYM) { + UniformIntRNG rng{-50, 50}; + float epsilon = 0.001; +#define cb(name) \ + checker_conv_bias(get_conv_bias_1x1_args(false, false, true, true), \ + handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \ + dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \ + dtype::QuantizedS8(60.25f), name); +#if MEGDNN_AARCH64 +#if __ARM_FEATURE_DOTPROD + cb("CONV1x1:AARCH64_INT8X8X32_K8X12X4_DOTPROD:24"); +#else + cb("CONV1x1:AARCH64_INT8X8X32_K8X8X8:24"); + cb("CONV1x1:AARCH64_INT8X8X32_K4X4X16:48"); +#endif +#elif MEGDNN_ARMV7 + epsilon = 1; + cb("CONV1x1:ARMV7_INT8X8X32_K4X8X8:48"); +#endif +#undef cb +} + +#if MEGDNN_AARCH64 || MEGDNN_ARMV7 +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUANTIZEDASYM) { + NormalRNG rng(128.f); +#define cb(name) \ + checker_conv_bias(get_conv_bias_1x1_args(false, false, true, true), \ + handle(), &rng, epsilon, \ + dtype::Quantized8Asymm(1.2f, (uint8_t)125), \ + dtype::Quantized8Asymm(1.3f, (uint8_t)129), \ + dtype::QuantizedS32(1.2 * 1.3), \ + dtype::Quantized8Asymm(50.3f, (uint8_t)120), name); + float epsilon = 0.001; +#if MEGDNN_AARCH64 +#if __ARM_FEATURE_DOTPROD + cb("CONV1x1:AARCH64_QUINT8_K8X8X4_DOTPROD:48"); +#else + cb("CONV1x1:AARCH64_QUINT8_K8X8X8:24"); +#endif +#elif MEGDNN_ARMV7 + epsilon = 1; + cb("CONV1x1:ARMV7_QUINT8_K4X8X8:48"); +#endif +#undef cb +} +#endif + +#if MEGDNN_AARCH64 || MEGDNN_ARMV7 +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUINT8x8x32) { + UniformIntRNG rng{-50, 50}; + float epsilon = 0.001; +#define cb(name) \ 
+ checker_conv_bias(get_conv_bias_1x1_args(true, true), handle(), &rng, \ + epsilon, dtype::Quantized8Asymm(1.2f, (uint8_t)125), \ + dtype::Quantized8Asymm(1.3f, (uint8_t)129), \ + dtype::QuantizedS32(1.2 * 1.3), {}, name); + +#if MEGDNN_AARCH64 +#if __ARM_FEATURE_DOTPROD + cb("CONV1x1:AARCH64_QUINT8_K8X8X4_DOTPROD:24"); +#else + cb("CONV1x1:AARCH64_QUINT8_K8X8X8:48"); +#endif +#elif MEGDNN_ARMV7 +#if __ARM_FEATURE_DOTPROD + cb("CONV1x1:AARCH32_QUINT8_K4X8X4:48"); +#endif + cb("CONV1x1:ARMV7_QUINT8_K4X8X8:24"); +#endif +#undef cb +} + +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_1X1_S1_INT8x8x16) { + UniformIntRNG rng{-50, 50}; + float epsilon = 0.001; +#define cb(name) \ + checker_conv_bias(get_conv_bias_1x1_args(true, true), handle(), &rng, \ + epsilon, dtype::Int8{}, dtype::Int8{}, dtype::Int16{}, \ + dtype::Int16{}, name); + +#if MEGDNN_AARCH64 + cb("CONV1x1:AARCH64_INT8X8X16_K8X8X8:24"); + cb("CONV1x1:AARCH64_INT8X8X16_K4X4X16:24"); +#elif MEGDNN_ARMV7 + cb("CONV1x1:ARMV7_INT8X8X16_K4X8X8:24"); + cb("CONV1x1:ARMV7_INT8X8X16_K4X2X16:48"); +#endif + cb("CONV1x1:ARM_COMMON_INT8X8X16:48"); +#undef cb +} +#endif + +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32) { + using namespace conv_bias; + std::vector args = get_conv_bias_1x1_args(true, true); + +#define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name); + +#if MEGDNN_AARCH64 +#if __ARM_FEATURE_DOTPROD + cb("CONV1x1:AARCH64_INT8X8X32_K8X12X4_DOTPROD:48"); +#else + cb("CONV1x1:AARCH64_INT8X8X32_K8X8X8:24"); + cb("CONV1x1:AARCH64_INT8X8X32_K4X4X16:24"); +#endif +#elif MEGDNN_ARMV7 +#if __ARM_FEATURE_DOTPROD + cb("CONV1x1:AARCH32_INT8_K6X8X4:48"); +#endif + cb("CONV1x1:ARMV7_INT8X8X32_K4X8X8:24"); +#endif + +#if MEGDNN_ARMV7 + cb("CONV1x1:ARMV7_INT8X8X32_K4X2X16:48"); +#endif +#undef cb +} + +#ifndef __ARM_FEATURE_DOTPROD +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_MK4) { + using namespace conv_bias; + std::vector args = + get_nchw44_conv_bias_args({1}, 1, true, true, true); + +#define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name); + +#if MEGDNN_AARCH64 + cb("CONV1x1:AARCH64_INT8X8X32_MK4_4X4X16:24"); +#elif MEGDNN_ARMV7 + cb("CONV1x1:ARMV7_INT8X8X32_MK4_4X2X16:24"); +#endif +#undef cb + + UniformIntRNG rng{-50, 50}; + float epsilon = 0.001; +#define cb(name) \ + checker_conv_bias(get_nchw44_conv_bias_args({1}, 1, true, false, false), \ + handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \ + dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \ + dtype::QuantizedS8(60.25f), name); +#if MEGDNN_AARCH64 + cb("CONV1x1:AARCH64_INT8X8X32_MK4_4X4X16:24"); +#elif MEGDNN_ARMV7 + cb("CONV1x1:ARMV7_INT8X8X32_MK4_4X2X16:24"); +#endif +#undef cb +} +#endif + +// vim: syntax=cpp.doxygen diff --git a/dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp b/dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp new file mode 100644 index 00000000..e8accc27 --- /dev/null +++ b/dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp @@ -0,0 +1,1561 @@ +/** + * \file dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ +#include "test/arm_common/fixture.h" +#include "test/common/benchmarker.h" +#include "test/common/conv_bias.h" + +using namespace megdnn; +using namespace test; +using namespace conv_bias; +#if MEGDNN_WITH_BENCHMARK +namespace { +void benchmark_impl(const param::ConvBias param, + std::vector, float>>& + shapes_and_computation, + const std::string algo_name, size_t RUNS, + TaskExecutorConfig&& multi_thread_config, + TaskExecutorConfig&& single_thread_config, + std::vector& data_type) { + std::vector multi_thread_times, single_thread_times; + { + auto multi_thread_hanle = + create_cpu_handle(0, true, &multi_thread_config); + auto benchmarker = Benchmarker(multi_thread_hanle.get()); + benchmarker.set_times(RUNS) + .set_display(false) + .set_param(param) + .set_dtype(0, data_type[0]) + .set_dtype(1, data_type[1]) + .set_dtype(2, data_type[2]) + .set_dtype(4, data_type[3]) + .set_before_exec_callback( + conv_bias::ConvBiasAlgoChecker( + algo_name.c_str())); + for (auto shape : shapes_and_computation) { + multi_thread_times.push_back(benchmarker.exec(shape.first) / RUNS); + } + } + { + auto single_thread_handle = + create_cpu_handle(0, true, &single_thread_config); + auto benchmarker = Benchmarker(single_thread_handle.get()); + benchmarker.set_times(RUNS) + .set_display(false) + .set_param(param) + .set_dtype(0, data_type[0]) + .set_dtype(1, data_type[1]) + .set_dtype(2, data_type[2]) + .set_dtype(4, data_type[3]) + .set_before_exec_callback( + conv_bias::ConvBiasAlgoChecker( + algo_name.c_str())); + for (auto shape : shapes_and_computation) { + single_thread_times.push_back(benchmarker.exec(shape.first) / RUNS); + } + } + printf("Benchmark : Multi threads %zu, ", multi_thread_config.nr_thread); + printf("core_ids:"); + for (size_t i = 0; i < multi_thread_config.affinity_core_set.size(); i++) { + printf("%zu ", multi_thread_config.affinity_core_set[i]); + } + printf(", Single thread core_id %zu\n", + single_thread_config.affinity_core_set[0]); + for (size_t i = 0; i < shapes_and_computation.size(); i++) { + auto shapes = shapes_and_computation[i]; + printf("Bench case: "); + for (auto&& shape : shapes.first) { + printf("%s ", shape.to_string().c_str()); + } + float computations = shapes.second; + printf("%zu threads gflops: %f,\n single thread gflops: " + "%f. 
spead up = %f, speedup/cores=%f\n", + multi_thread_config.nr_thread, + computations / multi_thread_times[i], + computations / single_thread_times[i], + single_thread_times[i] / multi_thread_times[i], + single_thread_times[i] / multi_thread_times[i] / + multi_thread_config.nr_thread); + } +} +} // namespace + +TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF32) { + constexpr size_t RUNS = 50; + + param::ConvBias param; + param.nonlineMode = param::ConvBias::NonlineMode::RELU; + param.pad_h = 1; + param.pad_w = 1; + param.stride_h = 1; + param.stride_w = 1; + param.sparse = param::ConvBias::Sparse::GROUP; + + std::vector, float>> + shapes_and_computation; + auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, + size_t FS, size_t group) { + SmallVector shapes{{N, IC, H, W}, + {group, OC / group, IC / group, FS, FS}, + {1, OC, 1, 1}, + {}, + {N, OC, H, W}}; + TensorShape dst{N, OC, H, W}; + float computations = + ((IC / group) * FS * FS * dst.total_nr_elems() * 2 + + dst.total_nr_elems()) * + 1e-6; + shapes_and_computation.push_back(std::make_pair(shapes, computations)); + }; + + bench_case(1, 32, 32, 200, 200, 3, 4); + bench_case(1, 32, 32, 200, 200, 3, 32); + bench_case(1, 32, 32, 128, 128, 3, 4); + bench_case(1, 32, 32, 128, 128, 3, 32); + bench_case(1, 32, 32, 100, 100, 3, 4); + bench_case(1, 32, 32, 100, 100, 3, 32); + bench_case(1, 32, 32, 80, 80, 3, 4); + bench_case(1, 32, 32, 80, 80, 3, 32); + + std::string algo_name = "F32DIRECT_LARGE_GROUP"; + printf("Benchmark F32DIRECT_LARGE_GROUP algo\n"); + std::vector data_type = {dtype::Float32(), dtype::Float32(), + dtype::Float32(), dtype::Float32()}; + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {4}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {7}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, + {1, {4}}, data_type); + shapes_and_computation.clear(); + + algo_name = "F32DIRECT_SMALL_GROUP"; + printf("Benchmark F32DIRECT_SMALL_GROUP algo\n"); + bench_case(1, 32, 32, 200, 200, 3, 1); + bench_case(1, 32, 32, 128, 128, 3, 1); + bench_case(1, 32, 32, 100, 100, 3, 1); + bench_case(1, 32, 32, 80, 80, 3, 1); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {4}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {7}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, + {1, {4}}, data_type); +} +TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF32_STR1) { + constexpr size_t RUNS = 50; + param::ConvBias param; + param.nonlineMode = param::ConvBias::NonlineMode::RELU; + param.pad_h = 1; + param.pad_w = 1; + param.stride_h = 1; + param.stride_w = 1; + param.sparse = param::ConvBias::Sparse::GROUP; + + std::vector, float>> + shapes_and_computation; + auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, + size_t FS, size_t group) { + SmallVector shapes{{N, IC, H, W}, + {group, OC / group, IC / group, FS, FS}, + {1, OC, 1, 1}, + {}, + {N, OC, H, W}}; + TensorShape dst{N, OC, H, W}; + float computations = + ((IC / group) * FS * FS * dst.total_nr_elems() * 2 + + dst.total_nr_elems()) * + 1e-6; + shapes_and_computation.push_back(std::make_pair(shapes, computations)); + }; + + bench_case(1, 32, 32, 200, 200, 3, 4); + bench_case(1, 32, 32, 200, 200, 3, 32); + bench_case(1, 32, 32, 128, 128, 3, 4); + 
bench_case(1, 32, 32, 128, 128, 3, 32); + bench_case(1, 32, 32, 100, 100, 3, 4); + bench_case(1, 32, 32, 100, 100, 3, 32); + bench_case(1, 32, 32, 80, 80, 3, 4); + bench_case(1, 32, 32, 80, 80, 3, 32); + + std::string algo_name = "F32STRD1_LARGE_GROUP"; + printf("Benchmark F32STRD1_LARGE_GROUP algo\n"); + std::vector data_type = {dtype::Float32(), dtype::Float32(), + dtype::Float32(), dtype::Float32()}; + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {4}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {7}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, + {1, {4}}, data_type); + shapes_and_computation.clear(); + + algo_name = "F32STRD1_SMALL_GROUP"; + printf("Benchmark F32STRD1_SMALL_GROUP algo\n"); + bench_case(1, 32, 32, 200, 200, 3, 1); + bench_case(1, 32, 32, 128, 128, 3, 1); + bench_case(1, 32, 32, 100, 100, 3, 1); + bench_case(1, 32, 32, 80, 80, 3, 1); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {4}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {7}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, + {1, {4}}, data_type); +} +TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF32_STR2) { + constexpr size_t RUNS = 50; + + param::ConvBias param; + param.nonlineMode = param::ConvBias::NonlineMode::RELU; + param.pad_h = 1; + param.pad_w = 1; + param.stride_h = 2; + param.stride_w = 2; + param.sparse = param::ConvBias::Sparse::GROUP; + + std::vector, float>> + shapes_and_computation; + auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, + size_t FS, size_t group, size_t P, size_t S) { + SmallVector shapes{ + {N, IC, H, W}, + {group, OC / group, IC / group, FS, FS}, + {1, OC, 1, 1}, + {}, + {N, OC, (H + 2 * P - FS) / S + 1, (W + 2 * P - FS) / S + 1}}; + TensorShape dst{N, OC, H, W}; + float computations = + ((IC / group) * FS * FS * dst.total_nr_elems() * 2 + + dst.total_nr_elems()) * + 1e-6; + shapes_and_computation.push_back(std::make_pair(shapes, computations)); + }; + + bench_case(1, 32, 32, 200, 200, 3, 4, 1, 2); + bench_case(1, 32, 32, 200, 200, 3, 32, 1, 2); + bench_case(1, 32, 32, 128, 128, 3, 4, 1, 2); + bench_case(1, 32, 32, 128, 128, 3, 32, 1, 2); + bench_case(1, 32, 32, 100, 100, 3, 4, 1, 2); + bench_case(1, 32, 32, 100, 100, 3, 32, 1, 2); + bench_case(1, 32, 32, 80, 80, 3, 4, 1, 2); + bench_case(1, 32, 32, 80, 80, 3, 32, 1, 2); + + std::string algo_name = "F32STRD2_LARGE_GROUP"; + printf("Benchmark F32STRD2_LARGE_GROUP algo\n"); + std::vector data_type = {dtype::Float32(), dtype::Float32(), + dtype::Float32(), dtype::Float32()}; + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {4}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {7}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, + {1, {4}}, data_type); + shapes_and_computation.clear(); + + algo_name = "F32STRD2_SMALL_GROUP"; + printf("Benchmark F32STRD2_SMALL_GROUP algo\n"); + bench_case(1, 32, 32, 200, 200, 3, 1, 1, 2); + bench_case(1, 32, 32, 128, 128, 3, 1, 1, 2); + bench_case(1, 32, 32, 100, 100, 3, 1, 1, 2); + bench_case(1, 32, 32, 80, 80, 3, 1, 1, 2); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {4}}, data_type); + 
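    // The two TaskExecutorConfig arguments passed to benchmark_impl are
    // {nr_thread, {core ids}}: {4, {4, 5, 6, 7}} runs the multi-thread pass
    // with four threads bound to cores 4-7, while {1, {4}} / {1, {7}} pins the
    // single-thread baseline to one of those cores. The specific ids assume an
    // eight-core big.LITTLE test device where cores 4-7 form the big cluster;
    // adjust them when benchmarking on a different SoC.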
benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {7}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, + {1, {4}}, data_type); +} + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF16) { + constexpr size_t RUNS = 50; + + param::ConvBias param; + param.nonlineMode = param::ConvBias::NonlineMode::RELU; + param.pad_h = 1; + param.pad_w = 1; + param.stride_h = 1; + param.stride_w = 1; + param.sparse = param::ConvBias::Sparse::GROUP; + + std::vector, float>> + shapes_and_computation; + auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, + size_t FS, size_t group) { + SmallVector shapes{{N, IC, H, W}, + {group, OC / group, IC / group, FS, FS}, + {1, OC, 1, 1}, + {}, + {N, OC, H, W}}; + TensorShape dst{N, OC, H, W}; + float computations = + ((IC / group) * FS * FS * dst.total_nr_elems() * 2 + + dst.total_nr_elems()) * + 1e-6; + shapes_and_computation.push_back(std::make_pair(shapes, computations)); + }; + + bench_case(1, 32, 32, 200, 200, 3, 4); + bench_case(1, 32, 32, 200, 200, 3, 32); + bench_case(1, 32, 32, 128, 128, 3, 4); + bench_case(1, 32, 32, 128, 128, 3, 32); + bench_case(1, 32, 32, 100, 100, 3, 4); + bench_case(1, 32, 32, 100, 100, 3, 32); + bench_case(1, 32, 32, 80, 80, 3, 4); + bench_case(1, 32, 32, 80, 80, 3, 32); + + std::string algo_name = "F16DIRECT_LARGE_GROUP"; + printf("Benchmark F16DIRECT_LARGE_GROUP algo\n"); + std::vector data_type = {dtype::Float16(), dtype::Float16(), + dtype::Float16(), dtype::Float16()}; + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {4}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {7}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, + {1, {4}}, data_type); + shapes_and_computation.clear(); + + algo_name = "F16DIRECT_SMALL_GROUP"; + printf("Benchmark F16DIRECT_SMALL_GROUP algo\n"); + bench_case(1, 32, 32, 200, 200, 3, 1); + bench_case(1, 32, 32, 128, 128, 3, 1); + bench_case(1, 32, 32, 100, 100, 3, 1); + bench_case(1, 32, 32, 80, 80, 3, 1); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {4}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {7}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, + {1, {4}}, data_type); +} +TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF16_STR1) { + constexpr size_t RUNS = 50; + + param::ConvBias param; + param.nonlineMode = param::ConvBias::NonlineMode::RELU; + param.pad_h = 1; + param.pad_w = 1; + param.stride_h = 1; + param.stride_w = 1; + param.sparse = param::ConvBias::Sparse::GROUP; + + std::vector, float>> + shapes_and_computation; + auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, + size_t FS, size_t group) { + SmallVector shapes{{N, IC, H, W}, + {group, OC / group, IC / group, FS, FS}, + {1, OC, 1, 1}, + {}, + {N, OC, H, W}}; + TensorShape dst{N, OC, H, W}; + float computations = + ((IC / group) * FS * FS * dst.total_nr_elems() * 2 + + dst.total_nr_elems()) * + 1e-6; + shapes_and_computation.push_back(std::make_pair(shapes, computations)); + }; + + bench_case(1, 32, 32, 200, 200, 3, 4); + bench_case(1, 32, 32, 200, 200, 3, 32); + bench_case(1, 32, 32, 128, 128, 3, 4); + bench_case(1, 32, 32, 128, 128, 3, 32); + bench_case(1, 32, 
32, 100, 100, 3, 4); + bench_case(1, 32, 32, 100, 100, 3, 32); + bench_case(1, 32, 32, 80, 80, 3, 4); + bench_case(1, 32, 32, 80, 80, 3, 32); + + std::string algo_name = "F16STRD1_LARGE_GROUP"; + printf("Benchmark F16STRD1_LARGE_GROUP algo\n"); + std::vector data_type = {dtype::Float16(), dtype::Float16(), + dtype::Float16(), dtype::Float16()}; + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {4}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {7}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, + {1, {4}}, data_type); + shapes_and_computation.clear(); + + algo_name = "F16STRD1_SMALL_GROUP"; + printf("Benchmark F16STRD1_SMALL_GROUP algo\n"); + bench_case(1, 32, 32, 200, 200, 3, 1); + bench_case(1, 32, 32, 128, 128, 3, 1); + bench_case(1, 32, 32, 100, 100, 3, 1); + bench_case(1, 32, 32, 80, 80, 3, 1); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {4}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {7}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, + {1, {4}}, data_type); +} +#endif +TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, + BENCHMARK_CONVBIAS_DIRECT_INT8x8x16) { + constexpr size_t RUNS = 50; + + param::ConvBias param; + param.nonlineMode = param::ConvBias::NonlineMode::IDENTITY; + param.pad_h = 1; + param.pad_w = 1; + param.stride_h = 1; + param.stride_w = 1; + param.sparse = param::ConvBias::Sparse::GROUP; + + std::vector, float>> + shapes_and_computation; + auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, + size_t FS, size_t group) { + SmallVector shapes{{N, IC, H, W}, + {group, OC / group, IC / group, FS, FS}, + {}, + {}, + {N, OC, H, W}}; + TensorShape dst{N, OC, H, W}; + float computations = + ((IC / group) * FS * FS * dst.total_nr_elems() * 2 + + dst.total_nr_elems()) * + 1e-6; + shapes_and_computation.push_back(std::make_pair(shapes, computations)); + }; + + bench_case(1, 32, 32, 200, 200, 3, 4); + bench_case(1, 32, 32, 200, 200, 3, 32); + bench_case(1, 32, 32, 128, 128, 3, 4); + bench_case(1, 32, 32, 128, 128, 3, 32); + bench_case(1, 32, 32, 100, 100, 3, 4); + bench_case(1, 32, 32, 100, 100, 3, 32); + bench_case(1, 32, 32, 80, 80, 3, 4); + bench_case(1, 32, 32, 80, 80, 3, 32); + + std::string algo_name = "I8816DIRECT_LARGE_GROUP"; + printf("Benchmark I8816DIRECT_LARGE_GROUP algo\n"); + std::vector data_type = {dtype::Int8(), dtype::Int8(), + dtype::Int16(), dtype::Int16()}; + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {4}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {7}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, + {1, {4}}, data_type); + shapes_and_computation.clear(); + + algo_name = "I8816DIRECT_SMALL_GROUP"; + printf("Benchmark I8816DIRECT_SMALL_GROUP algo\n"); + bench_case(1, 32, 32, 200, 200, 3, 1); + bench_case(1, 32, 32, 128, 128, 3, 1); + bench_case(1, 32, 32, 100, 100, 3, 1); + bench_case(1, 32, 32, 80, 80, 3, 1); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {4}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {7}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, + {1, {4}}, 
data_type); +} +TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, + BENCHMARK_CONVBIAS_DIRECT_INT8x8x16_STR2) { + constexpr size_t RUNS = 50; + param::ConvBias param; + param.nonlineMode = param::ConvBias::NonlineMode::IDENTITY; + param.pad_h = 1; + param.pad_w = 1; + param.stride_h = 2; + param.stride_w = 2; + param.sparse = param::ConvBias::Sparse::GROUP; + + std::vector, float>> + shapes_and_computation; + auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, + size_t FS, size_t group, size_t P, size_t S) { + SmallVector shapes{ + {N, IC, H, W}, + {group, OC / group, IC / group, FS, FS}, + {}, + {}, + {N, OC, (H + 2 * P - FS) / S + 1, (W + 2 * P - FS) / S + 1}}; + TensorShape dst{N, OC, (H + 2 * P - FS) / S + 1, + (W + 2 * P - FS) / S + 1}; + float computations = + ((IC / group) * FS * FS * dst.total_nr_elems() * 2 + + dst.total_nr_elems()) * + 1e-6; + shapes_and_computation.push_back(std::make_pair(shapes, computations)); + }; + + bench_case(1, 32, 32, 200, 200, 3, 4, 1, 2); + bench_case(1, 32, 32, 200, 200, 3, 32, 1, 2); + bench_case(1, 32, 32, 128, 128, 3, 4, 1, 2); + bench_case(1, 32, 32, 128, 128, 3, 32, 1, 2); + bench_case(1, 32, 32, 100, 100, 3, 4, 1, 2); + bench_case(1, 32, 32, 100, 100, 3, 32, 1, 2); + bench_case(1, 32, 32, 80, 80, 3, 4, 1, 2); + bench_case(1, 32, 32, 80, 80, 3, 32, 1, 2); + + std::string algo_name = "I8816STRD2_LARGE_GROUP"; + printf("Benchmark I8816STRD2_LARGE_GROUP algo\n"); + std::vector data_type = {dtype::Int8(), dtype::Int8(), + dtype::Int16(), dtype::Int16()}; + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {4}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {7}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, + {1, {4}}, data_type); + shapes_and_computation.clear(); + + algo_name = "I8816STRD2_SMALL_GROUP"; + printf("Benchmark I8816STRD2_SMALL_GROUP algo\n"); + bench_case(1, 32, 32, 200, 200, 3, 1, 1, 2); + bench_case(1, 32, 32, 128, 128, 3, 1, 1, 2); + bench_case(1, 32, 32, 100, 100, 3, 1, 1, 2); + bench_case(1, 32, 32, 80, 80, 3, 1, 1, 2); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {4}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {7}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, + {1, {4}}, data_type); +} +TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, + BENCHMARK_CONVBIAS_INT8_INT8_INT8_STRIDE1) { + constexpr size_t RUNS = 50; + + param::ConvBias param; + param.nonlineMode = param::ConvBias::NonlineMode::RELU; + param.pad_h = 1; + param.pad_w = 1; + param.stride_h = 1; + param.stride_w = 1; + param.sparse = param::ConvBias::Sparse::GROUP; + + std::vector, float>> + shapes_and_computation; + auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, + size_t FS, size_t group, size_t P, size_t S) { + SmallVector shapes{ + {N, IC, H, W}, + {group, OC / group, IC / group, FS, FS}, + {1, OC, 1, 1}, + {}, + {N, OC, (H + 2 * P - FS) / S + 1, (W + 2 * P - FS) / S + 1}}; + TensorShape dst{N, OC, H, W}; + float computations = + ((IC / group) * FS * FS * dst.total_nr_elems() * 2 + + dst.total_nr_elems()) * + 1e-6; + shapes_and_computation.push_back(std::make_pair(shapes, computations)); + }; + + bench_case(1, 32, 32, 200, 200, 3, 4, 1, 1); + bench_case(1, 32, 32, 200, 200, 3, 32, 1, 1); + bench_case(1, 32, 32, 128, 128, 3, 4, 1, 1); + bench_case(1, 
32, 32, 128, 128, 3, 32, 1, 1); + bench_case(1, 32, 32, 100, 100, 3, 4, 1, 1); + bench_case(1, 32, 32, 100, 100, 3, 32, 1, 1); + bench_case(1, 32, 32, 80, 80, 3, 4, 1, 1); + bench_case(1, 32, 32, 80, 80, 3, 32, 1, 1); + + std::string algo_name = "S8STRD1_LARGE_GROUP"; + printf("Benchmark S8STRD1_LARGE_GROUP algo\n"); + std::vector data_type = { + dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f), + dtype::QuantizedS32(6.25f), dtype::QuantizedS8(60.25f)}; + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {4}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {7}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, + {1, {4}}, data_type); + shapes_and_computation.clear(); + + algo_name = "S8STRD1_SMALL_GROUP"; + printf("Benchmark S8STRD1_SMALL_GROUP algo\n"); + bench_case(1, 32, 32, 200, 200, 3, 1, 1, 1); + bench_case(1, 32, 32, 128, 128, 3, 1, 1, 1); + bench_case(1, 32, 32, 100, 100, 3, 1, 1, 1); + bench_case(1, 32, 32, 80, 80, 3, 1, 1, 1); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {4}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {7}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, + {1, {4}}, data_type); +} + +TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_INT8_NCHW44) { + constexpr size_t RUNS = 40; + std::vector data_type = { + dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f), + dtype::QuantizedS32(6.25f), dtype::QuantizedS8(60.25f)}; + auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, + size_t FS, size_t group, size_t P, size_t S, + bool is_nchw = false) { + param::ConvBias param; + param.nonlineMode = param::ConvBias::NonlineMode::RELU; + param.pad_h = P; + param.pad_w = P; + param.stride_h = S; + param.stride_w = S; + param.sparse = param::ConvBias::Sparse::DENSE; + param.format = param::ConvBias::Format::NCHW44; + auto OH = (H + 2 * P - FS) / static_cast(S) + 1; + auto OW = (W + 2 * P - FS) / static_cast(S) + 1; + TensorShape src = {N, IC / 4, H, W, 4}; + TensorShape filter = {OC / 4, IC / 4, FS, FS, 4, 4}; + if (group > 1) { + filter = {group, OC / group / 4, IC / group / 4, FS, FS, 4, 4}; + param.sparse = param::ConvBias::Sparse::GROUP; + } + if (is_nchw) { + src = {N, IC, H, W}; + filter = {OC / 4, FS, FS, IC, 4}; + } + TensorShape bias = {1, OC / 4, 1, 1, 4}; + TensorShape dst = {N, OC / 4, OH, OW, 4}; + + SmallVector shapes{src, filter, bias, {}, dst}; + float computations = + (((IC / group) * FS * FS + 1) * dst.total_nr_elems() * 2 + + dst.total_nr_elems()) * + 1e-6; + std::vector, float>> shape_arg = { + std::make_pair(shapes, computations)}; + benchmark_impl(param, shape_arg, ".+", RUNS, {4, {4, 5, 6, 7}}, + {1, {7}}, data_type); + }; + bench_case(1, 3, 64, 224, 224, 7, 1, 3, 2, true); + bench_case(1, 64, 64, 56, 56, 3, 1, 1, 1); + bench_case(1, 128, 128, 28, 28, 3, 1, 1, 1); + bench_case(1, 256, 256, 14, 14, 3, 1, 1, 1); + bench_case(1, 512, 512, 7, 7, 3, 1, 1, 1); + + bench_case(1, 64, 64, 56, 56, 3, 4, 1, 1); + bench_case(1, 128, 128, 28, 28, 3, 4, 1, 1); + bench_case(1, 256, 256, 14, 14, 3, 4, 1, 1); + bench_case(1, 512, 512, 7, 7, 3, 4, 1, 1); + + bench_case(1, 4, 64, 224, 224, 7, 1, 1, 2); + bench_case(1, 256, 128, 56, 56, 3, 1, 1, 2); + bench_case(1, 512, 256, 28, 28, 3, 1, 1, 2); + bench_case(1, 4, 32, 224, 224, 3, 1, 1, 2); + + bench_case(1, 256, 
128, 56, 56, 3, 4, 1, 2); + bench_case(1, 512, 256, 28, 28, 3, 4, 1, 2); +} + +TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, + BENCHMARK_CONVBIAS_INT8_INT8_INT8_STRIDE2) { + constexpr size_t RUNS = 50; + + param::ConvBias param; + param.nonlineMode = param::ConvBias::NonlineMode::RELU; + param.pad_h = 1; + param.pad_w = 1; + param.stride_h = 2; + param.stride_w = 2; + param.sparse = param::ConvBias::Sparse::GROUP; + + std::vector, float>> + shapes_and_computation; + auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, + size_t FS, size_t group, size_t P, size_t S) { + SmallVector shapes{ + {N, IC, H, W}, + {group, OC / group, IC / group, FS, FS}, + {1, OC, 1, 1}, + {}, + {N, OC, (H + 2 * P - FS) / S + 1, (W + 2 * P - FS) / S + 1}}; + TensorShape dst{N, OC, H, W}; + float computations = + ((IC / group) * FS * FS * dst.total_nr_elems() * 2 + + dst.total_nr_elems()) * + 1e-6; + shapes_and_computation.push_back(std::make_pair(shapes, computations)); + }; + + bench_case(1, 32, 32, 200, 200, 3, 4, 1, 2); + bench_case(1, 32, 32, 200, 200, 3, 32, 1, 2); + bench_case(1, 32, 32, 128, 128, 3, 4, 1, 2); + bench_case(1, 32, 32, 128, 128, 3, 32, 1, 2); + bench_case(1, 32, 32, 100, 100, 3, 4, 1, 2); + bench_case(1, 32, 32, 100, 100, 3, 32, 1, 2); + bench_case(1, 32, 32, 80, 80, 3, 4, 1, 2); + bench_case(1, 32, 32, 80, 80, 3, 32, 1, 2); + + std::string algo_name = "S8STRD2_LARGE_GROUP"; + printf("Benchmark S8STRD2_LARGE_GROUP algo\n"); + std::vector data_type = { + dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f), + dtype::QuantizedS32(6.25f), dtype::QuantizedS8(60.25f)}; + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {4}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {7}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, + {1, {4}}, data_type); + shapes_and_computation.clear(); + + algo_name = "S8STRD2_SMALL_GROUP"; + printf("Benchmark S8STRD2_SMALL_GROUP algo\n"); + bench_case(1, 32, 32, 200, 200, 3, 1, 1, 2); + bench_case(1, 32, 32, 128, 128, 3, 1, 1, 2); + bench_case(1, 32, 32, 100, 100, 3, 1, 1, 2); + bench_case(1, 32, 32, 80, 80, 3, 1, 1, 2); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {4}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {7}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, + {1, {4}}, data_type); +} +#if __ARM_FEATURE_DOTPROD +TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, + BENCHMARK_CONVBIAS_INT8_INT8_INT8_STRIDE1_WITHDOTPROD) { + constexpr size_t RUNS = 50; + + param::ConvBias param; + param.nonlineMode = param::ConvBias::NonlineMode::RELU; + param.pad_h = 1; + param.pad_w = 1; + param.stride_h = 1; + param.stride_w = 1; + param.sparse = param::ConvBias::Sparse::GROUP; + + std::vector, float>> + shapes_and_computation; + auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, + size_t FS, size_t group, size_t P, size_t S) { + SmallVector shapes{ + {N, IC, H, W}, + {group, OC / group, IC / group, FS, FS}, + {1, OC, 1, 1}, + {}, + {N, OC, (H + 2 * P - FS) / S + 1, (W + 2 * P - FS) / S + 1}}; + TensorShape dst{N, OC, H, W}; + float computations = + ((IC / group) * FS * FS * dst.total_nr_elems() * 2 + + dst.total_nr_elems()) * + 1e-6; + shapes_and_computation.push_back(std::make_pair(shapes, computations)); + }; + + bench_case(1, 32, 32, 200, 200, 3, 4, 1, 1); + 
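    // The ARMDOTS8* algorithms benchmarked in this block are only built when
    // the compiler defines __ARM_FEATURE_DOTPROD, i.e. when the target enables
    // the Armv8.2 dot-product extension. Conceptually one SDOT replaces a
    // chain of widening int8 multiplies and adds; a rough sketch with ACLE
    // intrinsics (src/filter are hypothetical int8_t pointers, illustration
    // only, not code from this patch):
    //     int32x4_t acc = vdupq_n_s32(0);
    //     acc = vdotq_s32(acc, vld1q_s8(src), vld1q_s8(filter));
    // which accumulates four groups of four int8 products into four int32
    // lanes per instruction.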
bench_case(1, 32, 32, 200, 200, 3, 32, 1, 1); + bench_case(1, 32, 32, 128, 128, 3, 4, 1, 1); + bench_case(1, 32, 32, 128, 128, 3, 32, 1, 1); + bench_case(1, 32, 32, 100, 100, 3, 4, 1, 1); + bench_case(1, 32, 32, 100, 100, 3, 32, 1, 1); + bench_case(1, 32, 32, 80, 80, 3, 4, 1, 1); + bench_case(1, 32, 32, 80, 80, 3, 32, 1, 1); + + std::string algo_name = "ARMDOTS8STRD1_LARGE_GROUP"; + printf("Benchmark ARMDOTS8STRD1_LARGE_GROUP algo\n"); + std::vector data_type = { + dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f), + dtype::QuantizedS32(6.25f), dtype::QuantizedS8(60.25f)}; + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {4}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {7}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, + {1, {4}}, data_type); + shapes_and_computation.clear(); + + algo_name = "ARMDOTS8STRD1_SMALL_GROUP"; + printf("Benchmark ARMDOTS8STRD1_SMALL_GROUP algo\n"); + bench_case(1, 32, 32, 200, 200, 3, 1, 1, 1); + bench_case(1, 32, 32, 128, 128, 3, 1, 1, 1); + bench_case(1, 32, 32, 100, 100, 3, 1, 1, 1); + bench_case(1, 32, 32, 80, 80, 3, 1, 1, 1); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {4}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {7}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, + {1, {4}}, data_type); +} +TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, + BENCHMARK_CONVBIAS_INT8_INT8_INT8_STRIDE2_WITHDOTPROD) { + constexpr size_t RUNS = 50; + + param::ConvBias param; + param.nonlineMode = param::ConvBias::NonlineMode::RELU; + param.pad_h = 1; + param.pad_w = 1; + param.stride_h = 2; + param.stride_w = 2; + param.sparse = param::ConvBias::Sparse::GROUP; + + std::vector, float>> + shapes_and_computation; + auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, + size_t FS, size_t group, size_t P, size_t S) { + SmallVector shapes{ + {N, IC, H, W}, + {group, OC / group, IC / group, FS, FS}, + {1, OC, 1, 1}, + {}, + {N, OC, (H + 2 * P - FS) / S + 1, (W + 2 * P - FS) / S + 1}}; + TensorShape dst{N, OC, H, W}; + float computations = + ((IC / group) * FS * FS * dst.total_nr_elems() * 2 + + dst.total_nr_elems()) * + 1e-6; + shapes_and_computation.push_back(std::make_pair(shapes, computations)); + }; + + bench_case(1, 32, 32, 200, 200, 3, 4, 1, 2); + bench_case(1, 32, 32, 200, 200, 3, 32, 1, 2); + bench_case(1, 32, 32, 128, 128, 3, 4, 1, 2); + bench_case(1, 32, 32, 128, 128, 3, 32, 1, 2); + bench_case(1, 32, 32, 100, 100, 3, 4, 1, 2); + bench_case(1, 32, 32, 100, 100, 3, 32, 1, 2); + bench_case(1, 32, 32, 80, 80, 3, 4, 1, 2); + bench_case(1, 32, 32, 80, 80, 3, 32, 1, 2); + + std::string algo_name = "ARMDOTS8STRD2_LARGE_GROUP"; + printf("Benchmark ARMDOTS8STRD2_LARGE_GROUP algo\n"); + std::vector data_type = { + dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f), + dtype::QuantizedS32(6.25f), dtype::QuantizedS8(60.25f)}; + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {4}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {7}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, + {1, {4}}, data_type); + shapes_and_computation.clear(); + + algo_name = "ARMDOTS8STRD2_SMALL_GROUP"; + printf("Benchmark ARMDOTS8STRD2_SMALL_GROUP algo\n"); + 
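    // As in the earlier tests, the *_LARGE_GROUP cases above use group = 4 or
    // 32, while the *_SMALL_GROUP cases below rerun the same shapes with
    // group = 1, so the two variants of each direct-convolution kernel are
    // compared on otherwise identical workloads.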
bench_case(1, 32, 32, 200, 200, 3, 1, 1, 2); + bench_case(1, 32, 32, 128, 128, 3, 1, 1, 2); + bench_case(1, 32, 32, 100, 100, 3, 1, 1, 2); + bench_case(1, 32, 32, 80, 80, 3, 1, 1, 2); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {4}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {7}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, + {1, {4}}, data_type); +} +#endif + +TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, + BENCHMARK_CONVBIAS_QUINT8_QUINT8_QUINT8_STRIDE1) { + constexpr size_t RUNS = 50; + + param::ConvBias param; + param.nonlineMode = param::ConvBias::NonlineMode::RELU; + param.pad_h = 1; + param.pad_w = 1; + param.stride_h = 1; + param.stride_w = 1; + param.sparse = param::ConvBias::Sparse::GROUP; + + std::vector, float>> + shapes_and_computation; + auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, + size_t FS, size_t group, size_t P, size_t S) { + SmallVector shapes{ + {N, IC, H, W}, + {group, OC / group, IC / group, FS, FS}, + {1, OC, 1, 1}, + {}, + {N, OC, (H + 2 * P - FS) / S + 1, (W + 2 * P - FS) / S + 1}}; + TensorShape dst{N, OC, H, W}; + float computations = + ((IC / group) * FS * FS * dst.total_nr_elems() * 2 + + dst.total_nr_elems()) * + 1e-6; + shapes_and_computation.push_back(std::make_pair(shapes, computations)); + }; + + bench_case(1, 32, 32, 200, 200, 3, 4, 1, 1); + bench_case(1, 32, 32, 200, 200, 3, 32, 1, 1); + bench_case(1, 32, 32, 128, 128, 3, 4, 1, 1); + bench_case(1, 32, 32, 128, 128, 3, 32, 1, 1); + bench_case(1, 32, 32, 100, 100, 3, 4, 1, 1); + bench_case(1, 32, 32, 100, 100, 3, 32, 1, 1); + bench_case(1, 32, 32, 80, 80, 3, 4, 1, 1); + bench_case(1, 32, 32, 80, 80, 3, 32, 1, 1); + + std::string algo_name = "QU8STRD1_LARGE_GROUP"; + printf("Benchmark QU8STRD1_LARGE_GROUP algo\n"); + std::vector data_type = {dtype::Quantized8Asymm(0.2f, 100), + dtype::Quantized8Asymm(0.2f, 120), + dtype::QuantizedS32(0.04f), + dtype::Quantized8Asymm(1.4f, 110)}; + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {4}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {7}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, + {1, {4}}, data_type); + shapes_and_computation.clear(); + + algo_name = "QU8STRD1_SMALL_GROUP"; + printf("Benchmark QU8STRD1_SMALL_GROUP algo\n"); + bench_case(1, 32, 32, 200, 200, 3, 1, 1, 1); + bench_case(1, 32, 32, 128, 128, 3, 1, 1, 1); + bench_case(1, 32, 32, 100, 100, 3, 1, 1, 1); + bench_case(1, 32, 32, 80, 80, 3, 1, 1, 1); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {4}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {7}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, + {1, {4}}, data_type); +} +TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, + BENCHMARK_CONVBIAS_QUINT8_QUINT8_QUINT8_STRIDE2) { + constexpr size_t RUNS = 50; + + param::ConvBias param; + param.nonlineMode = param::ConvBias::NonlineMode::RELU; + param.pad_h = 1; + param.pad_w = 1; + param.stride_h = 2; + param.stride_w = 2; + param.sparse = param::ConvBias::Sparse::GROUP; + + std::vector, float>> + shapes_and_computation; + auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, + size_t FS, size_t group, size_t P, size_t S) { + 
SmallVector shapes{ + {N, IC, H, W}, + {group, OC / group, IC / group, FS, FS}, + {1, OC, 1, 1}, + {}, + {N, OC, (H + 2 * P - FS) / S + 1, (W + 2 * P - FS) / S + 1}}; + TensorShape dst{N, OC, H, W}; + float computations = + ((IC / group) * FS * FS * dst.total_nr_elems() * 2 + + dst.total_nr_elems()) * + 1e-6; + shapes_and_computation.push_back(std::make_pair(shapes, computations)); + }; + + bench_case(1, 32, 32, 200, 200, 3, 4, 1, 2); + bench_case(1, 32, 32, 200, 200, 3, 32, 1, 2); + bench_case(1, 32, 32, 128, 128, 3, 4, 1, 2); + bench_case(1, 32, 32, 128, 128, 3, 32, 1, 2); + bench_case(1, 32, 32, 100, 100, 3, 4, 1, 2); + bench_case(1, 32, 32, 100, 100, 3, 32, 1, 2); + bench_case(1, 32, 32, 80, 80, 3, 4, 1, 2); + bench_case(1, 32, 32, 80, 80, 3, 32, 1, 2); + + std::string algo_name = "QU8STRD2_LARGE_GROUP"; + printf("Benchmark QU8STRD2_LARGE_GROUP algo\n"); + std::vector data_type = {dtype::Quantized8Asymm(0.2f, 100), + dtype::Quantized8Asymm(0.2f, 120), + dtype::QuantizedS32(0.04f), + dtype::Quantized8Asymm(1.4f, 110)}; + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {4}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {7}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, + {1, {4}}, data_type); + shapes_and_computation.clear(); + + algo_name = "QU8STRD2_SMALL_GROUP"; + printf("Benchmark QU8STRD2_SMALL_GROUP algo\n"); + bench_case(1, 32, 32, 200, 200, 3, 1, 1, 2); + bench_case(1, 32, 32, 128, 128, 3, 1, 1, 2); + bench_case(1, 32, 32, 100, 100, 3, 1, 1, 2); + bench_case(1, 32, 32, 80, 80, 3, 1, 1, 2); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {4}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {7}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, + {1, {4}}, data_type); +} +#if __ARM_FEATURE_DOTPROD +TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, + BENCHMARK_CONVBIAS_QUINT8_QUINT8_QUINT8_STRIDE1_WITHDOTPROD) { + constexpr size_t RUNS = 50; + + param::ConvBias param; + param.nonlineMode = param::ConvBias::NonlineMode::RELU; + param.pad_h = 1; + param.pad_w = 1; + param.stride_h = 1; + param.stride_w = 1; + param.sparse = param::ConvBias::Sparse::GROUP; + + std::vector, float>> + shapes_and_computation; + auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, + size_t FS, size_t group, size_t P, size_t S) { + SmallVector shapes{ + {N, IC, H, W}, + {group, OC / group, IC / group, FS, FS}, + {1, OC, 1, 1}, + {}, + {N, OC, (H + 2 * P - FS) / S + 1, (W + 2 * P - FS) / S + 1}}; + TensorShape dst{N, OC, (H + 2 * P - FS) / S + 1, + (W + 2 * P - FS) / S + 1}; + float computations = + ((IC / group) * FS * FS * dst.total_nr_elems() * 2 + + dst.total_nr_elems()) * + 1e-6; + shapes_and_computation.push_back(std::make_pair(shapes, computations)); + }; + + bench_case(1, 32, 32, 200, 200, 3, 4, 1, 1); + bench_case(1, 32, 32, 200, 200, 3, 32, 1, 1); + bench_case(1, 32, 32, 128, 128, 3, 4, 1, 1); + bench_case(1, 32, 32, 128, 128, 3, 32, 1, 1); + bench_case(1, 32, 32, 100, 100, 3, 4, 1, 1); + bench_case(1, 32, 32, 100, 100, 3, 32, 1, 1); + bench_case(1, 32, 32, 80, 80, 3, 4, 1, 1); + bench_case(1, 32, 32, 80, 80, 3, 32, 1, 1); + + std::string algo_name = "ARMDOTU8STRD1_LARGE_GROUP"; + printf("Benchmark ARMDOTU8STRD1_LARGE_GROUP algo\n"); + std::vector data_type = {dtype::Quantized8Asymm(0.2f, 100), + 
dtype::Quantized8Asymm(0.2f, 120), + dtype::QuantizedS32(0.04f), + dtype::Quantized8Asymm(1.4f, 110)}; + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {4}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {7}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, + {1, {4}}, data_type); + shapes_and_computation.clear(); + + algo_name = "ARMDOTU8STRD1_SMALL_GROUP"; + printf("Benchmark ARMDOTS8STRD1_SMALL_GROUP algo\n"); + bench_case(1, 32, 32, 200, 200, 3, 1, 1, 1); + bench_case(1, 32, 32, 128, 128, 3, 1, 1, 1); + bench_case(1, 32, 32, 100, 100, 3, 1, 1, 1); + bench_case(1, 32, 32, 80, 80, 3, 1, 1, 1); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {4}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {7}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, + {1, {4}}, data_type); +} +TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, + BENCHMARK_CONVBIAS_QUINT8_QUINT8_QUINT8_STRIDE2_WITHDOTPROD) { + constexpr size_t RUNS = 50; + + param::ConvBias param; + param.nonlineMode = param::ConvBias::NonlineMode::RELU; + param.pad_h = 1; + param.pad_w = 1; + param.stride_h = 2; + param.stride_w = 2; + param.sparse = param::ConvBias::Sparse::GROUP; + + std::vector, float>> + shapes_and_computation; + auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, + size_t FS, size_t group, size_t P, size_t S) { + SmallVector shapes{ + {N, IC, H, W}, + {group, OC / group, IC / group, FS, FS}, + {1, OC, 1, 1}, + {}, + {N, OC, (H + 2 * P - FS) / S + 1, (W + 2 * P - FS) / S + 1}}; + TensorShape dst{N, OC, (H + 2 * P - FS) / S + 1, + (W + 2 * P - FS) / S + 1}; + float computations = + ((IC / group) * FS * FS * dst.total_nr_elems() * 2 + + dst.total_nr_elems()) * + 1e-6; + shapes_and_computation.push_back(std::make_pair(shapes, computations)); + }; + + bench_case(1, 32, 32, 200, 200, 5, 4, 1, 2); + bench_case(1, 32, 32, 200, 200, 5, 32, 1, 2); + bench_case(1, 32, 32, 128, 128, 5, 4, 1, 2); + bench_case(1, 32, 32, 128, 128, 5, 32, 1, 2); + bench_case(1, 32, 32, 100, 100, 5, 4, 1, 2); + bench_case(1, 32, 32, 100, 100, 5, 32, 1, 2); + bench_case(1, 32, 32, 80, 80, 5, 4, 1, 2); + bench_case(1, 32, 32, 80, 80, 5, 32, 1, 2); + + std::string algo_name = "ARMDOTU8STRD2_LARGE_GROUP"; + printf("Benchmark ARMDOTU8STRD2_LARGE_GROUP algo\n"); + std::vector data_type = {dtype::Quantized8Asymm(0.2f, 100), + dtype::Quantized8Asymm(0.2f, 120), + dtype::QuantizedS32(0.04f), + dtype::Quantized8Asymm(1.4f, 110)}; + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {4}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {7}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, + {1, {4}}, data_type); + shapes_and_computation.clear(); + + algo_name = "ARMDOTU8STRD2_SMALL_GROUP"; + printf("Benchmark ARMDOTU8STRD2_SMALL_GROUP algo\n"); + bench_case(1, 32, 32, 200, 200, 5, 1, 1, 2); + bench_case(1, 32, 32, 128, 128, 5, 1, 1, 2); + bench_case(1, 32, 32, 100, 100, 5, 1, 1, 2); + bench_case(1, 32, 32, 80, 80, 5, 1, 1, 2); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {4}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {7}}, data_type); + 
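    // The quantized data_type lists in these cases follow the usual
    // convention that the int32 accumulator/bias scale is the product of the
    // src and filter scales (0.2f * 0.2f = 0.04f here, and 2.5f * 2.5f = 6.25f
    // in the QuantizedS8 cases above), while the output scale and the
    // asymmetric zero points (100 / 120 / 110) are picked independently for
    // the benchmark.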
benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, + {1, {4}}, data_type); +} +#endif + +TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_WINOGRAD_F32) { + constexpr size_t RUNS = 50; + + param::ConvBias param; + param.nonlineMode = param::ConvBias::NonlineMode::RELU; + param.pad_h = 1; + param.pad_w = 1; + param.stride_h = 1; + param.stride_w = 1; + param.sparse = param::ConvBias::Sparse::GROUP; + + std::vector, float>> + shapes_and_computation; + auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, + size_t FS, size_t group) { + SmallVector shapes{{N, IC, H, W}, + {group, OC / group, IC / group, FS, FS}, + {1, OC, 1, 1}, + {}, + {N, OC, H, W}}; + TensorShape dst{N, OC, H, W}; + float computations = + ((IC / group) * FS * FS * dst.total_nr_elems() * 2 + + dst.total_nr_elems()) * + 1e-6; + shapes_and_computation.push_back(std::make_pair(shapes, computations)); + }; + + bench_case(1, 32, 32, 200, 200, 3, 4); + bench_case(1, 32, 32, 200, 200, 3, 1); + bench_case(1, 32, 32, 128, 128, 3, 4); + bench_case(1, 32, 32, 128, 128, 3, 1); + bench_case(1, 32, 32, 100, 100, 3, 4); + bench_case(1, 32, 32, 100, 100, 3, 1); + bench_case(1, 32, 32, 80, 80, 3, 4); + + bench_case(1, 512, 512, 14, 14, 3, 1); + bench_case(1, 512, 256, 14, 14, 3, 1); + bench_case(1, 512, 128, 14, 14, 3, 1); + bench_case(1, 512, 64, 14, 14, 3, 1); + + bench_case(1, 512, 512, 7, 7, 3, 1); + bench_case(1, 512, 256, 7, 7, 3, 1); + bench_case(1, 512, 128, 7, 7, 3, 1); + bench_case(1, 512, 64, 7, 7, 3, 1); + + std::string algo_name; +#if MEGDNN_AARCH64 + algo_name = "WINOGRAD:AARCH64_F32_MK4_4x16:4:2"; +#else + algo_name = "WINOGRAD:ARMV7_F32_MK4_4x8:4:2"; +#endif + std::vector data_type = {dtype::Float32(), dtype::Float32(), + dtype::Float32(), dtype::Float32()}; + printf("Benchmark WINOGRAD_F32_MK4 algo\n"); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {4}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {7}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, + {1, {4}}, data_type); +} + +TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_IM2COL_FP32) { + constexpr size_t RUNS = 50; + + param::ConvBias param; + param.nonlineMode = param::ConvBias::NonlineMode::RELU; + param.pad_h = 1; + param.pad_w = 1; + param.stride_h = 1; + param.stride_w = 1; + + std::vector, float>> + shapes_and_computation; + auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, + size_t FS, size_t group) { + SmallVector shapes{{N, IC, H, W}, + {OC, IC / group, FS, FS}, + {1, OC, 1, 1}, + {}, + {N, OC, H, W}}; + TensorShape dst{N, OC, H, W}; + float computations = + ((IC / group) * FS * FS * dst.total_nr_elems() * 2 + + dst.total_nr_elems()) * + 1e-6; + shapes_and_computation.push_back(std::make_pair(shapes, computations)); + }; + std::vector data_type = {dtype::Float32(), dtype::Float32(), + dtype::Float32(), dtype::Float32()}; + bench_case(1, 32, 32, 300, 300, 3, 1); + bench_case(1, 32, 32, 400, 400, 3, 1); + bench_case(1, 32, 32, 100, 100, 3, 1); + bench_case(1, 32, 32, 80, 80, 3, 1); + bench_case(1, 32, 64, 200, 200, 3, 1); + bench_case(1, 32, 64, 128, 128, 3, 1); + bench_case(1, 32, 64, 100, 100, 3, 1); + bench_case(1, 32, 64, 80, 80, 3, 1); + bench_case(1, 32, 128, 200, 200, 3, 1); + bench_case(1, 32, 128, 128, 128, 3, 1); + bench_case(1, 32, 128, 100, 100, 3, 1); + bench_case(1, 32, 128, 80, 80, 3, 1); + + bench_case(1, 64, 
32, 7, 7, 3, 1); + bench_case(1, 64, 64, 7, 7, 3, 1); + bench_case(1, 64, 128, 7, 7, 3, 1); + bench_case(1, 64, 256, 7, 7, 3, 1); + bench_case(1, 64, 512, 7, 7, 3, 1); + bench_case(1, 64, 1024, 7, 7, 3, 1); + + bench_case(1, 64, 32, 14, 14, 3, 1); + bench_case(1, 64, 64, 14, 14, 3, 1); + bench_case(1, 64, 128, 14, 14, 3, 1); + bench_case(1, 64, 256, 14, 14, 3, 1); + bench_case(1, 64, 512, 14, 14, 3, 1); + + bench_case(1, 64, 1024, 14, 14, 3, 1); + bench_case(1, 128, 128, 14, 14, 3, 1); + bench_case(1, 128, 256, 14, 14, 3, 1); + bench_case(1, 512, 512, 14, 14, 3, 1); + bench_case(1, 256, 512, 14, 14, 3, 1); + bench_case(1, 512, 1024, 14, 14, 3, 1); + bench_case(1, 1024, 1024, 14, 14, 3, 1); + std::string algo_name = "IM2COLMATMUL:AARCH64_F32K8X12X1:96"; + printf("Benchmark IM2COLMATMUL:AARCH64_F32K8X12X1algo:96\n"); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {4}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {7}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, + {1, {4}}, data_type); + algo_name = "IM2COLMATMUL:AARCH64_F32K8X12X1:192"; + printf("Benchmark IM2COLMATMUL:AARCH64_F32K8X12X1algo:192\n"); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {4}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {7}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, + {1, {4}}, data_type); + algo_name = "IM2COLMATMUL:AARCH64_F32K8X12X1:384"; + printf("Benchmark IM2COLMATMUL:AARCH64_F32K8X12X1algo:384\n"); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {4}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {7}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, + {1, {4}}, data_type); + shapes_and_computation.clear(); +} +TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, + BENCHMARK_CHANNEL_WISE_INT8_INT8_INT8_STRIDE1) { + constexpr size_t RUNS = 50; + + param::ConvBias param; + param.nonlineMode = param::ConvBias::NonlineMode::RELU; + param.pad_h = 1; + param.pad_w = 1; + param.stride_h = 1; + param.stride_w = 1; + param.sparse = param::ConvBias::Sparse::GROUP; + param.format = param::ConvBias::Format::NCHW44; + + std::vector, float>> + shapes_and_computation; + auto bench_case = [&](size_t N, size_t IC, size_t H, size_t W, size_t FS, + size_t P) { + size_t group = IC; + size_t OC = IC; + size_t S = 1; + SmallVector shapes{ + {N, IC, H, W, 4}, + {group, 1, 1, FS, FS, 4}, + {1, OC, 1, 1, 4}, + {}, + {N, OC, (H + 2 * P - FS) / S + 1, (W + 2 * P - FS) / S + 1, 4}}; + TensorShape dst{N, OC, (H + 2 * P - FS) / S + 1, + (W + 2 * P - FS) / S + 1, 4}; + float computations = + ((IC / group) * FS * FS * dst.total_nr_elems() * 2 + + dst.total_nr_elems()) * + 1e-6; + shapes_and_computation.push_back(std::make_pair(shapes, computations)); + }; + bench_case(1, 128, 200, 200, 3, 1); + bench_case(1, 128, 128, 128, 3, 1); + bench_case(1, 128, 100, 100, 3, 1); + bench_case(1, 128, 80, 80, 3, 1); + bench_case(1, 128, 56, 56, 3, 1); + bench_case(1, 128, 28, 28, 3, 1); + bench_case(1, 128, 14, 14, 3, 1); + + bench_case(1, 64, 200, 200, 3, 1); + bench_case(1, 64, 128, 128, 3, 1); + bench_case(1, 64, 100, 100, 3, 1); + bench_case(1, 64, 80, 80, 3, 1); + bench_case(1, 64, 56, 56, 3, 1); + bench_case(1, 64, 28, 28, 3, 1); 
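    // NCHW44 packs channels in blocks of four, which is why bench_case builds
    // {N, C/4, H, W, 4} activations and, for this channel-wise case,
    // {group, 1, 1, FS, FS, 4} filters; the dense NCHW44 tests earlier in the
    // file use {OC/4, IC/4, FS, FS, 4, 4} instead. The trailing 4 keeps each
    // channel block contiguous so a full 128-bit NEON vector can be loaded per
    // access.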
+ bench_case(1, 64, 14, 14, 3, 1); + + bench_case(1, 32, 200, 200, 3, 1); + bench_case(1, 32, 128, 128, 3, 1); + bench_case(1, 32, 100, 100, 3, 1); + bench_case(1, 32, 80, 80, 3, 1); + bench_case(1, 32, 56, 56, 3, 1); + bench_case(1, 32, 28, 28, 3, 1); + bench_case(1, 32, 14, 14, 3, 1); + + std::string algo_name = "S8_CHAN_WISE_STRD1_NCHW44"; + printf("Benchmarker S8_CHAN_WISE_STRD1_NCHW44 algo\n"); + std::vector data_type = { + dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f), + dtype::QuantizedS32(6.25f), dtype::QuantizedS8(60.25f)}; + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {4}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {7}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, + {1, {4}}, data_type); +} + +TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, + BENCHMARK_IM2COL_NCHW44_INT8x8x32_STRIDE1) { + constexpr size_t RUNS = 50; + + param::ConvBias param; + param.nonlineMode = param::ConvBias::NonlineMode::IDENTITY; + param.pad_h = 1; + param.pad_w = 1; + param.stride_h = 1; + param.stride_w = 1; + param.sparse = param::ConvBias::Sparse::DENSE; + param.format = param::ConvBias::Format::NCHW44; + + + std::vector, float>> + shapes_and_computation; + auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, + size_t FS, size_t group=1) { + SmallVector shapes{{N, IC, H, W,4}, + {OC, IC / group, FS, FS,4,4}, + {/*1, OC, 1, 1*/}, + {}, + {N, OC, H, W,4}}; + TensorShape dst{N, OC, H, W,4}; + float computations = + ((4 * IC / group) * FS * FS * dst.total_nr_elems() * 2 + + dst.total_nr_elems()) * + 1e-6; + shapes_and_computation.push_back(std::make_pair(shapes, computations)); + }; + + bench_case(1, 32, 32, 300, 300, 3, 1); + bench_case(1, 32, 32, 400, 400, 3, 1); + bench_case(1, 32, 32, 100, 100, 3, 1); + bench_case(1, 32, 32, 80, 80, 3, 1); + bench_case(1, 32, 64, 200, 200, 3, 1); + bench_case(1, 32, 64, 128, 128, 3, 1); + bench_case(1, 32, 64, 100, 100, 3, 1); + bench_case(1, 32, 64, 80, 80, 3, 1); + bench_case(1, 32, 128, 200, 200, 3, 1); + bench_case(1, 32, 128, 128, 128, 3, 1); + bench_case(1, 32, 128, 100, 100, 3, 1); + bench_case(1, 32, 128, 80, 80, 3, 1); +#if 1 + bench_case(1, 64, 32, 7, 7, 3, 1); + bench_case(1, 64, 64, 7, 7, 3, 1); + bench_case(1, 64, 128, 7, 7, 3, 1); + bench_case(1, 64, 256, 7, 7, 3, 1); + bench_case(1, 64, 512, 7, 7, 3, 1); + bench_case(1, 64, 1024, 7, 7, 3, 1); + + bench_case(1, 64, 32, 14, 14, 3, 1); + bench_case(1, 64, 64, 14, 14, 3, 1); + bench_case(1, 64, 128, 14, 14, 3, 1); + bench_case(1, 64, 256, 14, 14, 3, 1); + bench_case(1, 64, 512, 14, 14, 3, 1); + + bench_case(1, 64, 1024, 14, 14, 3, 1); + bench_case(1, 128, 128, 14, 14, 3, 1); + bench_case(1, 128, 256, 14, 14, 3, 1); + bench_case(1, 512, 512, 14, 14, 3, 1); + bench_case(1, 256, 512, 14, 14, 3, 1); + bench_case(1, 512, 1024, 14, 14, 3, 1); + bench_case(1, 1024, 1024, 14, 14, 3, 1); +#endif + std::string algo_name = "IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:96"; + printf("Benchmarker IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:96 algo\n"); + std::vector data_type = { + dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f), + dtype::QuantizedS32(6.25f), {}}; + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {4}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {7}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, + 
{1, {4}}, data_type); + + + + algo_name = "IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:192"; + printf("Benchmarker IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:192 algo\n"); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {4}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {7}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, + {1, {4}}, data_type); + + algo_name = "IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:384"; + printf("Benchmarker IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:384 algo\n"); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {4}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {7}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, + {1, {4}}, data_type); + +} + +#endif + +/*================== BENCHMARK MULTITHREAD CONV1X1 =====================*/ +#if MEGDNN_WITH_BENCHMARK + +namespace { +std::vector, float>> +get_conv1x1_multithread_benchmark_args() { + std::vector, float>> + shapes_and_computation; + auto bench_case = [&](size_t IC, size_t OC, size_t H, size_t W) { + SmallVector shapes{{1, IC, H, W}, + {OC, IC, 1, 1}, + {1, OC, 1, 1}, + {}, + {1, OC, H, W}}; + TensorShape dst{1, OC, H, W}; + float computations = + (IC * dst.total_nr_elems() * 2 + dst.total_nr_elems()) * 1e-6; + shapes_and_computation.push_back(std::make_pair(shapes, computations)); + }; + bench_case(32, 32, 300, 300); + bench_case(32, 32, 400, 400); + bench_case(32, 32, 100, 100); + bench_case(32, 32, 80, 80); + bench_case(32, 64, 200, 200); + bench_case(32, 64, 128, 128); + bench_case(32, 64, 100, 100); + bench_case(32, 64, 80, 80); + bench_case(32, 128, 200, 200); + bench_case(32, 128, 128, 128); + bench_case(32, 128, 100, 100); + bench_case(32, 128, 80, 80); + + bench_case(64, 32, 7, 7); + bench_case(64, 64, 7, 7); + bench_case(64, 128, 7, 7); + bench_case(64, 256, 7, 7); + bench_case(64, 512, 7, 7); + bench_case(64, 1024, 7, 7); + + bench_case(64, 32, 14, 14); + bench_case(64, 64, 14, 14); + bench_case(64, 128, 14, 14); + bench_case(64, 256, 14, 14); + bench_case(64, 512, 14, 14); + + bench_case(64, 1024, 14, 14); + bench_case(128, 128, 14, 14); + bench_case(128, 256, 14, 14); + bench_case(512, 512, 14, 14); + bench_case(256, 512, 14, 14); + bench_case(512, 1024, 14, 14); + bench_case(1024, 1024, 14, 14); + return shapes_and_computation; +} + +void conv1x1_multithread_benchmark(const char* algo_name, DType stype, + DType ftype, DType btype, DType dtype) { + constexpr size_t RUNS = 50; + std::vector, float>> + shapes_and_computation = get_conv1x1_multithread_benchmark_args(); + + std::vector data_type = {stype, ftype, btype, dtype}; + + param::ConvBias param; + param.nonlineMode = param::ConvBias::NonlineMode::RELU; + param.pad_h = 0; + param.pad_w = 0; + param.stride_h = 1; + param.stride_w = 1; + + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {4}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {7}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, + {1, {4}}, data_type); + shapes_and_computation.clear(); +} +} // namespace + +TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_CONV1X1_S1_FP32) { +#if MEGDNN_AARCH64 + conv1x1_multithread_benchmark("CONV1x1:AARCH64_F32K8X12X1:8", + 
dtype::Float32(), dtype::Float32(), + dtype::Float32(), dtype::Float32()); +#else + conv1x1_multithread_benchmark("CONV1x1:ARMV7_F32:8", dtype::Float32(), + dtype::Float32(), dtype::Float32(), + dtype::Float32()); +#endif +} + +TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, + BENCHMARK_CONVBIAS_CONV1X1_S1_QUANTIZEDASYM) { + dtype::Quantized8Asymm stype(0.2f, 100); + dtype::Quantized8Asymm ftype(0.2f, 120); + dtype::QuantizedS32 btype(0.04f); + dtype::Quantized8Asymm dtype(1.4f, 110); +#if MEGDNN_AARCH64 +#if __ARM_FEATURE_DOTPROD + conv1x1_multithread_benchmark("CONV1x1:AARCH64_QUINT8_K8X8X4_DOTPROD:8", + stype, ftype, btype, dtype); +#else + conv1x1_multithread_benchmark("CONV1x1:AARCH64_QUINT8_K8X8X8:8", stype, + ftype, btype, dtype); +#endif +#else + conv1x1_multithread_benchmark("CONV1x1:ARMV7_QUINT8_K4X8X8:8", stype, ftype, + btype, dtype); +#endif +} + +#endif + +// vim: syntax=cpp.doxygen diff --git a/dnn/test/arm_common/convolution.cpp b/dnn/test/arm_common/convolution.cpp new file mode 100644 index 00000000..eda5ee9b --- /dev/null +++ b/dnn/test/arm_common/convolution.cpp @@ -0,0 +1,1238 @@ +/** + * \file dnn/test/arm_common/convolution.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "test/arm_common/fixture.h" + +#include "test/common/benchmarker.h" +#include "test/common/checker.h" +#include "test/common/convolution.h" +#include "test/common/timer.h" + +using namespace megdnn; +using namespace test; + +using Param = param::Convolution; + +#if __ARM_FEATURE_DOTPROD +TEST_F(ARM_COMMON, CONVOLUTION_BACKWARD_DATA_INT8_INT8_INT32) { + Checker checker(handle()); + using Param = ConvolutionBackwardData::Param; + Param param; + auto run = [&](size_t n, size_t ic, size_t oh, size_t ow, size_t oc, + size_t fh, size_t fw, size_t stride, size_t ph, size_t pw, + size_t group = 1) { + param.pad_h = ph; + param.pad_w = pw; + param.stride_h = param.stride_w = stride; + + TensorLayout diff = + TensorLayout{{n, oc * group, oh, ow}, dtype::Int8()}; + TensorLayout grad; + TensorLayout filter; + if (group == 1) { + param.sparse = Param::Sparse::DENSE; + filter = {{oc, ic, fh, fw}, dtype::Int8()}; + } else { + param.sparse = Param::Sparse::GROUP; + filter = {{group, oc, ic, fh, fw}, dtype::Int8()}; + } + // TensorLayout grad; + { + auto opr = handle()->create_operator(); + opr->param() = param; + opr->deduce_layout(filter, diff, grad); + } + if(stride == 1 ){ + checker.set_before_exec_callback(AlgoChecker< + ConvolutionBackwardData>( + "AARCH32_I8x8x32_DECONV_STRIDE1")); + } else { + checker.set_before_exec_callback(AlgoChecker< + ConvolutionBackwardData>( + "AARCH32_I8x8x32_DECONV_STRIDE2")); + } + checker.set_param(param) + .set_dtype(0, dtype::Int8()) + .set_dtype(1, dtype::Int8()) + .set_dtype(2, dtype::Int32()); + checker.exec(TensorLayoutArray{filter, diff, grad}); + }; + + // clang-format off + for (size_t f : {2, 3, 5, 7}) + for (size_t ih = 1; ih < f+1; ++ih) + for (size_t iw = 1; iw < 8*f+1; ++iw) + for (size_t s : {1, 2}) + for (size_t ph : {f/2, f-1}) + for (size_t pw : {f / 2, f - 1}) + if (f >= ph + 1 && f >= pw + 1 && (ih - 1) * s + f > 2 * ph && + (iw - 1) * s + f > 2 * pw) { + run(2, 3, ih, iw, 2, f, f, s, ph, pw, 1); + } + // clang-format on +} + 
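// In both backward-data tests in this file the grad layout is produced by
// ConvolutionBackwardData::deduce_layout rather than written out by hand. A
// minimal sketch of the per-axis relation it applies for these cases (the
// helper name is invented for illustration and is not referenced by the
// tests):
static inline size_t illustrative_deconv_spatial_size(size_t diff, size_t filter,
                                                      size_t stride, size_t pad) {
    // grad (= forward input) size recovered from the diff (= forward output)
    // size, e.g. diff = 3, filter = 5, stride = 2, pad = 2
    // -> (3 - 1) * 2 + 5 - 4 = 5.
    return (diff - 1) * stride + filter - 2 * pad;
}
// This is also why the loops above only keep cases with
// (ih - 1) * s + f > 2 * ph: it guarantees the deduced grad size stays
// positive.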
+TEST_F(ARM_COMMON, CONVOLUTION_BACKWARD_DATA_QUINT8) { + Checker checker(handle()); + using Param = ConvolutionBackwardData::Param; + Param param; + auto run = [&](size_t n, size_t ic, size_t oh, size_t ow, size_t oc, + size_t fh, size_t fw, size_t stride, size_t ph, size_t pw, + size_t group = 1) { + param.pad_h = ph; + param.pad_w = pw; + param.stride_h = param.stride_w = stride; + + TensorLayout diff = + TensorLayout{{n, oc * group, oh, ow}, dtype::Quantized8Asymm(1.3f, (uint8_t)129)}; + TensorLayout grad; + TensorLayout filter; + if (group == 1) { + param.sparse = Param::Sparse::DENSE; + filter = {{oc, ic, fh, fw}, dtype::Quantized8Asymm(1.2f, (uint8_t)127)}; + } else { + param.sparse = Param::Sparse::GROUP; + filter = {{group, oc, ic, fh, fw}, dtype::Quantized8Asymm(1.2f, (uint8_t)127)}; + } + // TensorLayout grad; + { + auto opr = handle()->create_operator(); + opr->param() = param; + opr->deduce_layout(filter, diff, grad); + } + NormalRNG rng(128.f); + + if(stride == 1 ){ + checker.set_before_exec_callback( + AlgoChecker( + "ARM_COMMON_QUINT8_DIRECT_" + "DECONV_STRIDE1")); + } else { + checker.set_before_exec_callback( + AlgoChecker( + "ARM_COMMON_QUINT8_DIRECT_" + "DECONV_STRIDE2")); + } + checker.set_param(param) + .set_dtype(0, dtype::Quantized8Asymm(1.2f, (uint8_t)127)) + .set_dtype(1, dtype::Quantized8Asymm(1.3f, (uint8_t)129)) + .set_dtype(2, {}); + checker.set_rng(0, &rng).set_rng(1, &rng); + checker.exec(TensorLayoutArray{filter, diff, grad}); + }; + + // clang-format off + for (size_t f : {2, 3, 5, 7}) + for (size_t ih = 1; ih < f+1; ++ih) + for (size_t iw = 1; iw < 8*f+1; ++iw) + for (size_t s : {1, 2}) + for (size_t ph : {f/2, f-1}) + for (size_t pw : {f/2, f-1}) + if (f >= ph + 1 && f >= pw + 1 && (ih - 1) * s + f > 2 * ph && + (iw - 1) * s + f > 2 * pw) { + run(2, 2, ih, iw, 2, f, f, s, ph, pw, 1); + } + // clang-format on +} +#endif + +#if MEGDNN_WITH_BENCHMARK +#if __ARM_FEATURE_DOTPROD +TEST_F(ARM_COMMON, BENCHMARK_CONVOLUTION_STRIDE1_I8x8x32_WITHDOTPROD) { + using namespace convolution; + using Param = param::Convolution; + + std::vector args; + auto run = [&](size_t oc, size_t ic, size_t w, size_t h, size_t kernel, + size_t stride) { + Param param; + param.stride_h = stride; + param.stride_w = stride; + param.pad_h = kernel / 2; + param.pad_w = kernel / 2; + + args.emplace_back(param, TensorShape{1, ic, h, w}, + TensorShape{oc, ic, kernel, kernel}); + + }; + + for (size_t kernel : {2, 3, 5, 7}) { + for (size_t ic : {1, 8, 16, 32, 64}) { + for (size_t oc : {1, 8, 16, 32, 64}) { + run(oc, ic, 56, 56, kernel, 1); + run(oc, ic, 128, 128, kernel, 1); + run(oc, ic, 256, 256, kernel, 1); + } + } + } + + constexpr size_t RUN = 50; + Benchmarker benchmark(handle()); + benchmark.set_before_exec_callback( + AlgoChecker("CONVOLUTION_DEFAULT_ARMDOTS8STRD1_SMALL_GROUP")); + benchmark.set_dtype(0, dtype::Int8()) + .set_dtype(1, dtype::Int8()) + .set_dtype(2, dtype::Int32()); + benchmark.set_display(false); + benchmark.set_times(RUN); + + Benchmarker benchmark_float(handle()); + benchmark_float.set_display(false); + benchmark_float.set_times(RUN); + + for (auto&& arg : args) { + TensorLayout dst_layout; + auto opr = handle()->create_operator(); + opr->param() = arg.param; + opr->deduce_layout({arg.src, dtype::Float32()}, + {arg.filter, dtype::Float32()}, dst_layout); + //! 
dst.nr_elems * IC * FH * FW * 2 + float computations = dst_layout.total_nr_elems() * arg.filter[1] * + arg.filter[2] * arg.filter[3] * 2.0 / + (1024 * 1024 * 1024) * 1e3; + + auto used_int = benchmark.set_param(arg.param).exec( + {arg.src, arg.filter, {}}) / + RUN; + auto used_float = benchmark_float.set_param(arg.param).exec( + {arg.src, arg.filter, {}}) / + RUN; + + printf("%s %s: int: %f ms %f Gflops float: %f ms %f GFlops speedup: " + "%f\n", + arg.src.to_string().c_str(), arg.filter.to_string().c_str(), + used_int, computations / used_int, used_float, + computations / used_float, used_float / used_int); + } +} +TEST_F(ARM_COMMON, BENCHMARK_CONVOLUTION_STRIDE2_I8x8x32_WITHDOTPROD) { + using namespace convolution; + using Param = param::Convolution; + + std::vector args; + auto run = [&](size_t oc, size_t ic, size_t w, size_t h, size_t kernel, + size_t stride) { + Param param; + param.stride_h = stride; + param.stride_w = stride; + param.pad_h = kernel / 2; + param.pad_w = kernel / 2; + + args.emplace_back(param, TensorShape{1, ic, h, w}, + TensorShape{oc, ic, kernel, kernel}); + + }; + + for (size_t kernel : {2, 3, 5, 7}) { + for (size_t ic : {1, 8, 16, 32, 64}) { + for (size_t oc : {1, 8, 16, 32, 64}) { + run(oc, ic, 56, 56, kernel, 2); + run(oc, ic, 128, 128, kernel, 2); + run(oc, ic, 256, 256, kernel, 2); + } + } + } + + constexpr size_t RUN = 10; + Benchmarker benchmark(handle()); + benchmark.set_before_exec_callback( + AlgoChecker("CONVOLUTION_DEFAULT_ARMDOTS8STRD2_SMALL_GROUP")); + benchmark.set_dtype(0, dtype::Int8()) + .set_dtype(1, dtype::Int8()) + .set_dtype(2, dtype::Int32()); + benchmark.set_display(false); + benchmark.set_times(RUN); + + Benchmarker benchmark_float(handle()); + benchmark_float.set_display(false); + benchmark_float.set_times(RUN); + + for (auto&& arg : args) { + TensorLayout dst_layout; + auto opr = handle()->create_operator(); + opr->param() = arg.param; + opr->deduce_layout({arg.src, dtype::Float32()}, + {arg.filter, dtype::Float32()}, dst_layout); + //! 
dst.nr_elems * IC * FH * FW * 2 + float computations = dst_layout.total_nr_elems() * arg.filter[1] * + arg.filter[2] * arg.filter[3] * 2.0 / + (1024 * 1024 * 1024) * 1e3; + + auto used_int = + benchmark.set_param(arg.param).exec({arg.src, arg.filter, {}}) / + RUN; + auto used_float = benchmark_float.set_param(arg.param).exec( + {arg.src, arg.filter, {}}) / + RUN; + + printf("%s %s: int: %f ms %f Gflops float: %f ms %f GFlops speedup: " + "%f\n", + arg.src.to_string().c_str(), arg.filter.to_string().c_str(), + used_int, computations / used_int, used_float, + computations / used_float, used_float / used_int); + } +} + +TEST_F(ARM_COMMON, BENCHMARK_CONVOLUTION_STRIDE1_QUINT8_WITHDOTPROD) { + using namespace convolution; + using Param = param::Convolution; + + std::vector args; + auto run = [&](size_t oc, size_t ic, size_t w, size_t h, size_t kernel, + size_t stride) { + Param param; + param.stride_h = stride; + param.stride_w = stride; + param.pad_h = kernel / 2; + param.pad_w = kernel / 2; + + args.emplace_back(param, TensorShape{1, ic, h, w}, + TensorShape{oc, ic, kernel, kernel}); + + }; + + for (size_t kernel : {2, 3, 5, 7}) { + for (size_t ic : {1, 8, 16, 32, 64}) { + for (size_t oc : {1, 8, 16, 32, 64}) { + run(oc, ic, 56, 56, kernel, 1); + run(oc, ic, 128, 128, kernel, 1); + run(oc, ic, 256, 256, kernel, 1); + } + } + } + + constexpr size_t RUN = 50; + Benchmarker benchmark(handle()); + benchmark.set_dtype(0, dtype::Quantized8Asymm(1.2f, (uint8_t)129)) + .set_dtype(1, dtype::Quantized8Asymm(1.3f, (uint8_t)127)) + .set_dtype(2, {}); + + benchmark.set_display(false); + benchmark.set_times(RUN); + benchmark.set_before_exec_callback(AlgoChecker( + "CONVOLUTION_DEFAULT_ARMDOTU8STRD1_SMALL_GROUP")); + + Benchmarker benchmark_float(handle()); + benchmark_float.set_display(false); + benchmark_float.set_times(RUN); + + for (auto&& arg : args) { + TensorLayout dst_layout; + auto opr = handle()->create_operator(); + opr->param() = arg.param; + opr->deduce_layout({arg.src, dtype::Float32()}, + {arg.filter, dtype::Float32()}, dst_layout); + //! 
dst.nr_elems * IC * FH * FW * 2 + float computations = dst_layout.total_nr_elems() * arg.filter[1] * + arg.filter[2] * arg.filter[3] * 2.0 / + (1024 * 1024 * 1024) * 1e3; + + auto used_int = benchmark.set_param(arg.param).exec( + {arg.src, arg.filter, {}}) / + RUN; + auto used_float = benchmark_float.set_param(arg.param).exec( + {arg.src, arg.filter, {}}) / + RUN; + + printf("%s %s: int: %f ms %f Gflops float: %f ms %f GFlops speedup: " + "%f\n", + arg.src.to_string().c_str(), arg.filter.to_string().c_str(), + used_int, computations / used_int, used_float, + computations / used_float, used_float / used_int); + + } +} + +TEST_F(ARM_COMMON, BENCHMARK_CONVOLUTION_STRIDE2_QUINT8_WITHDOTPROD) { + using namespace convolution; + using Param = param::Convolution; + + std::vector args; + auto run = [&](size_t oc, size_t ic, size_t w, size_t h, size_t kernel, + size_t stride) { + Param param; + param.stride_h = stride; + param.stride_w = stride; + param.pad_h = kernel / 2; + param.pad_w = kernel / 2; + + args.emplace_back(param, TensorShape{1, ic, h, w}, + TensorShape{oc, ic, kernel, kernel}); + + }; + + for (size_t kernel : {2, 3, 5, 7}) { + for (size_t ic : {1, 8, 16, 32, 64}) { + for (size_t oc : {1, 8, 16, 32, 64}) { + run(oc, ic, 56, 56, kernel, 2); + run(oc, ic, 128, 128, kernel, 2); + run(oc, ic, 256, 256, kernel, 2); + } + } + } + + constexpr size_t RUN = 50; + Benchmarker benchmark(handle()); + benchmark.set_dtype(0, dtype::Quantized8Asymm(1.2f, (uint8_t)129)) + .set_dtype(1, dtype::Quantized8Asymm(1.3f, (uint8_t)127)) + .set_dtype(2, {}); + + benchmark.set_display(false); + benchmark.set_times(RUN); + benchmark.set_before_exec_callback(AlgoChecker( + "CONVOLUTION_DEFAULT_ARMDOTU8STRD2_SMALL_GROUP")); + + Benchmarker benchmark_float(handle()); + benchmark_float.set_display(false); + benchmark_float.set_times(RUN); + + for (auto&& arg : args) { + TensorLayout dst_layout; + auto opr = handle()->create_operator(); + opr->param() = arg.param; + opr->deduce_layout({arg.src, dtype::Float32()}, + {arg.filter, dtype::Float32()}, dst_layout); + //! 
dst.nr_elems * IC * FH * FW * 2 + float computations = dst_layout.total_nr_elems() * arg.filter[1] * + arg.filter[2] * arg.filter[3] * 2.0 / + (1024 * 1024 * 1024) * 1e3; + + auto used_int = benchmark.set_param(arg.param).exec( + {arg.src, arg.filter, {}}) / + RUN; + auto used_float = benchmark_float.set_param(arg.param).exec( + {arg.src, arg.filter, {}}) / + RUN; + + printf("%s %s: int: %f ms %f Gflops float: %f ms %f GFlops speedup: " + "%f\n", + arg.src.to_string().c_str(), arg.filter.to_string().c_str(), + used_int, computations / used_int, used_float, + computations / used_float, used_float / used_int); + } +} + +TEST_F(ARM_COMMON, BENCHMARK_CONVOLUTION_BACKWARD_DATA_INT8_INT8_INT32) { + using Param = ConvolutionBackwardData::Param; + + auto run = [&](const TensorLayoutArray& tensors, Param param) { + Benchmarker benchmarker(handle()); + size_t RUN = 50; + auto time = benchmarker.set_display(false) + .set_dtype(0, dtype::Int8{}) + .set_dtype(1, dtype::Int8{}) + .set_dtype(2, dtype::Int32{}) + .set_times(RUN) + .set_param(param) + .exec(tensors); + + size_t OC = tensors[0][0]; + size_t FH = tensors[0][2]; + size_t FW = tensors[0][3]; + float computations = tensors[2].total_nr_elems() * OC * FH * FW * 2.0 / + (1024 * 1024 * 1024) * 1e3; + + printf("time = %f \n perf= %f gops\n", time, computations * RUN / time); + }; + + auto profile = [&](size_t n, size_t ic, size_t oh, size_t ow, size_t oc, + size_t fh, size_t fw, size_t s) { + Param param; + param.stride_h = param.stride_w = s; + printf("oc: %zd ic: %zd w: %zd h: %zd kernel_size: %zd sreide: %zd\n", + oc, ic, ow, oh, fh, s); + + TensorLayout diff = TensorLayout{{n, oc, oh, ow}, dtype::Int8()}; + TensorLayout filter = TensorLayout{{oc, ic, fh, fw}, dtype::Int8()}; + TensorLayout grad; + { + auto opr = handle()->create_operator(); + opr->param() = param; + opr->deduce_layout(filter, diff, grad); + } + run(TensorLayoutArray{filter, diff, grad}, param); + }; + + profile(1, 3, 120, 120, 2, 3, 3, 1); + profile(1, 3, 60, 60, 2, 3, 3, 2); + profile(1, 3, 224, 224, 2, 5, 5, 1); + profile(1, 3, 112, 112, 2, 5, 5, 2); + profile(1, 3, 224, 224, 2, 7, 7, 1); + profile(1, 3, 112, 112, 2, 7, 7, 2); +} +#endif + +TEST_F(ARM_COMMON, BENCHMARK_CHANWISE_CONVOLUTION) { + auto run = [&](const TensorShapeArray& shapes, Param param) { + auto handle_naive = create_cpu_handle(2); + Benchmarker benchmarker_naive(handle_naive.get()), + benchmarker_float(handle()), benchmarker_int(handle()); + benchmarker_int.set_dtype(0, dtype::Int8()); + benchmarker_int.set_dtype(1, dtype::Int8()); + benchmarker_int.set_dtype(2, dtype::Int16()); + size_t RUN = 10; + auto tfloat = benchmarker_float.set_display(false) + .set_times(RUN) + .set_param(param) + .exec(shapes); + auto tnaive = benchmarker_naive.set_display(false) + .set_times(RUN) + .set_param(param) + .exec(shapes); + auto iparam = param; + auto tint = benchmarker_int.set_display(false) + .set_times(RUN) + .set_param(iparam) + .exec(shapes); + float int_float_ratio = static_cast(tfloat) / tint; + printf("naive=%.3fms float=%.3fms int=%.3fms, int/float=%.3f\n", + tnaive / RUN, tfloat / RUN, tint / RUN, int_float_ratio); + EXPECT_GE(int_float_ratio, 1.5); + }; + Param param; + param.mode = Param::Mode::CROSS_CORRELATION; + param.sparse = Param::Sparse::GROUP; + run({{2, 12, 200, 100}, {12, 2, 1, 5, 5}, {}}, param); + run({{10, 24, 28, 28}, {24, 1, 1, 3, 3}, {}}, param); + param.stride_h = 2; + param.stride_w = 2; + param.pad_h = 1; + param.pad_w = 1; + run({{2, 12, 200, 100}, {12, 2, 1, 5, 5}, {}}, param); + run({{10, 
24, 28, 28}, {24, 1, 1, 3, 3}, {}}, param); +} + +TEST_F(ARM_COMMON, BENCHMARK_CONVOLUTION_INT8X8X32_STRD1_WITHOUT_DOTPROD) { + // have to remove preferred restrict in usable func before run the benchmark + using namespace convolution; + + std::vector args; + auto run = [&](size_t oc, size_t ic, size_t w, size_t h, size_t kernel, + size_t p) { + if (w + 2 * p < kernel || h + 2 * p < kernel) + return; + param::Convolution param; + param.stride_h = 1; + param.stride_w = 1; + param.pad_h = p; + param.pad_w = p; + + args.emplace_back(param, TensorShape{1, ic, h, w}, + TensorShape{oc, ic, kernel, kernel}); + + }; + + // compare to float direct conv here, + // but float direct conv don't support 7x7. + for (size_t kernel : {2, 3, 5}) + for (size_t ic : {1, 8, 16, 32, 64}) + for (size_t oc : {1, 8, 16, 32, 64}) + for (size_t p : {0, 1, 2, 3}) { + run(oc, ic, 56, 56, kernel, p); + run(oc, ic, 128, 128, kernel, p); + run(oc, ic, 256, 256, kernel, p); + } + + constexpr size_t RUN = 50; + Benchmarker benchmark(handle()); + benchmark.set_dtype(0, dtype::Int8()) + .set_dtype(1, dtype::Int8()) + .set_dtype(2, dtype::Int32()); + benchmark.set_display(false); + benchmark.set_times(RUN); + benchmark.set_before_exec_callback( + AlgoChecker("CONVOLUTION_DEFAULT_S8STRD1")); + + Benchmarker benchmark_float(handle()); + benchmark_float.set_display(false); + benchmark_float.set_times(RUN); + benchmark_float.set_before_exec_callback( + AlgoChecker("CONVOLUTION_DEFAULT_F32STRD1")); + + for (auto&& arg : args) { + TensorLayout dst_layout; + auto opr = handle()->create_operator(); + opr->param() = arg.param; + opr->deduce_layout({arg.src, dtype::Float32()}, + {arg.filter, dtype::Float32()}, dst_layout); + //! dst.nr_elems * IC * FH * FW * 2 + float computations = dst_layout.total_nr_elems() * arg.filter[1] * + arg.filter[2] * arg.filter[3] * 2.0 / + (1024 * 1024 * 1024) * 1e3; + + auto used_int = + benchmark.set_param(arg.param).exec({arg.src, arg.filter, {}}) / + RUN; + auto used_float = benchmark_float.set_param(arg.param).exec( + {arg.src, arg.filter, {}}) / + RUN; + + printf("%s %s: int: %f ms %f Gflops float: %f ms %f GFlops speedup: " + "%f\n", + arg.src.to_string().c_str(), arg.filter.to_string().c_str(), + used_int, computations / used_int, used_float, + computations / used_float, used_float / used_int); + } +} + +TEST_F(ARM_COMMON, BENCHMARK_CONVOLUTION_INT8X8X32_STRD2_WITHOUT_DOTPROD) { + // have to remove preferred restrict in usable func before run the benchmark + using namespace convolution; + + std::vector args; + auto run = [&](size_t oc, size_t ic, size_t w, size_t h, size_t kernel, + size_t p) { + if (w + 2 * p < kernel || h + 2 * p < kernel) + return; + param::Convolution param; + param.stride_h = 2; + param.stride_w = 2; + param.pad_h = p; + param.pad_w = p; + + args.emplace_back(param, TensorShape{1, ic, h, w}, + TensorShape{oc, ic, kernel, kernel}); + + }; + + for (size_t kernel : {2, 3, 5, 7}) + for (size_t ic : {1, 8, 16, 32, 64}) + for (size_t oc : {1, 8, 16, 32, 64}) + for (size_t p : {0, 1, 2, 3}) { + run(oc, ic, 56, 56, kernel, p); + run(oc, ic, 128, 128, kernel, p); + run(oc, ic, 256, 256, kernel, p); + } + + constexpr size_t RUN = 50; + Benchmarker benchmark(handle()); + benchmark.set_dtype(0, dtype::Int8()) + .set_dtype(1, dtype::Int8()) + .set_dtype(2, dtype::Int32()); + benchmark.set_display(false); + benchmark.set_times(RUN); + benchmark.set_before_exec_callback( + AlgoChecker("CONVOLUTION_DEFAULT_S8STRD2")); + + Benchmarker benchmark_float(handle()); + 
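// Both benchmarkers pin their algorithm with set_before_exec_callback(), so the
+ // comparison measures the named kernels rather than whatever the heuristic would
+ // pick; the float stride-2 baseline is registered under a different algorithm
+ // name on aarch64, hence the MEGDNN_AARCH64 switch just below.
+ 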
benchmark_float.set_display(false); + benchmark_float.set_times(RUN); +#if MEGDNN_AARCH64 + benchmark_float.set_before_exec_callback(AlgoChecker( + "CONVOLUTION_DEFAULT_ARMV8F32STRD2")); +#else + benchmark_float.set_before_exec_callback( + AlgoChecker("CONVOLUTION_DEFAULT_F32STRD2")); +#endif + + for (auto&& arg : args) { + TensorLayout dst_layout; + auto opr = handle()->create_operator(); + opr->param() = arg.param; + opr->deduce_layout({arg.src, dtype::Float32()}, + {arg.filter, dtype::Float32()}, dst_layout); + //! dst.nr_elems * IC * FH * FW * 2 + float computations = dst_layout.total_nr_elems() * arg.filter[1] * + arg.filter[2] * arg.filter[3] * 2.0 / + (1024 * 1024 * 1024) * 1e3; + + auto used_int = + benchmark.set_param(arg.param).exec({arg.src, arg.filter, {}}) / + RUN; + auto used_float = benchmark_float.set_param(arg.param).exec( + {arg.src, arg.filter, {}}) / + RUN; + + printf("%s %s: int: %f ms %f Gflops float: %f ms %f GFlops speedup: " + "%f\n", + arg.src.to_string().c_str(), arg.filter.to_string().c_str(), + used_int, computations / used_int, used_float, + computations / used_float, used_float / used_int); + } +} + +TEST_F(ARM_COMMON, + BENCHMARK_CONVOLUTION_INT8X8X32_STRD1_WITHOUT_DOTPROD_TO_MATMUL) { + // have to remove preferred restrict in usable func before run the benchmark + using namespace convolution; + + std::vector args; + auto run = [&](size_t oc, size_t ic, size_t w, size_t h, size_t kernel, + size_t p) { + if (w + 2 * p < kernel || h + 2 * p < kernel) + return; + param::Convolution param; + param.stride_h = 1; + param.stride_w = 1; + param.pad_h = p; + param.pad_w = p; + + args.emplace_back(param, TensorShape{1, ic, h, w}, + TensorShape{oc, ic, kernel, kernel}); + + }; + + for (size_t kernel : {2, 3, 5, 7}) + for (size_t p : {0, 1, 2}) + for (size_t ic : {1, 3, 4, 8, 12, 16, 32, 48, 64}) + for (size_t oc : {1, 3, 4, 8, 12, 16, 32, 48, 64}) + for (size_t size : {56, 128, 256}) { + run(oc, ic, size, size, kernel, p); + } + + constexpr size_t RUN = 50; + Benchmarker benchmark_conv(handle()); + benchmark_conv.set_dtype(0, dtype::Int8()) + .set_dtype(1, dtype::Int8()) + .set_dtype(2, dtype::Int32()); + benchmark_conv.set_display(false); + benchmark_conv.set_times(RUN); + benchmark_conv.set_before_exec_callback( + AlgoChecker("CONVOLUTION_DEFAULT_S8STRD1")); + + Benchmarker benchmark_matmul(handle()); + benchmark_matmul.set_dtype(0, dtype::Int8()) + .set_dtype(1, dtype::Int8()) + .set_dtype(2, dtype::Int32()); + benchmark_matmul.set_display(false); + benchmark_matmul.set_times(RUN); + + for (auto&& arg : args) { + TensorLayout dst_layout; + auto opr = handle()->create_operator(); + opr->param() = arg.param; + opr->deduce_layout({arg.src, dtype::Float32()}, + {arg.filter, dtype::Float32()}, dst_layout); + //! 
dst.nr_elems * IC * FH * FW * 2 + float computations = dst_layout.total_nr_elems() * arg.filter[1] * + arg.filter[2] * arg.filter[3] * 2.0 / + (1024 * 1024 * 1024) * 1e3; + + auto used_conv = benchmark_conv.set_param(arg.param).exec( + {arg.src, arg.filter, {}}) / + RUN; + auto used_matmul = benchmark_matmul.set_param(arg.param).exec( + {arg.src, arg.filter, {}}) / + RUN; + + printf("%s %s: conv: %f ms %f Gflops matmul: %f ms %f GFlops speedup: " + "%f\n", + arg.src.to_string().c_str(), arg.filter.to_string().c_str(), + used_conv, computations / used_conv, used_matmul, + computations / used_matmul, used_matmul / used_conv); + } +} + +TEST_F(ARM_COMMON, + BENCHMARK_CONVOLUTION_INT8X8X32_STRD2_WITHOUT_DOTPROD_TO_MATMUL) { + // have to remove preferred restrict in usable func before run the benchmark + using namespace convolution; + + std::vector args; + auto run = [&](size_t oc, size_t ic, size_t w, size_t h, size_t kernel, + size_t p) { + if (w + 2 * p < kernel || h + 2 * p < kernel) + return; + param::Convolution param; + param.stride_h = 2; + param.stride_w = 2; + param.pad_h = p; + param.pad_w = p; + + args.emplace_back(param, TensorShape{1, ic, h, w}, + TensorShape{oc, ic, kernel, kernel}); + + }; + + for (size_t kernel : {2, 3, 5, 7}) + for (size_t p : {0, 1, 2}) + for (size_t ic : {1, 3, 4, 8, 12, 16, 32, 48, 64}) + for (size_t oc : {1, 3, 4, 8, 12, 16, 32, 48, 64}) + for (size_t size : {56, 128, 256}) { + run(oc, ic, size, size, kernel, p); + } + + constexpr size_t RUN = 50; + Benchmarker benchmark_conv(handle()); + benchmark_conv.set_dtype(0, dtype::Int8()) + .set_dtype(1, dtype::Int8()) + .set_dtype(2, dtype::Int32()); + benchmark_conv.set_display(false); + benchmark_conv.set_times(RUN); + benchmark_conv.set_before_exec_callback( + AlgoChecker("CONVOLUTION_DEFAULT_S8STRD2")); + + Benchmarker benchmark_matmul(handle()); + benchmark_matmul.set_dtype(0, dtype::Int8()) + .set_dtype(1, dtype::Int8()) + .set_dtype(2, dtype::Int32()); + benchmark_matmul.set_display(false); + benchmark_matmul.set_times(RUN); + + for (auto&& arg : args) { + TensorLayout dst_layout; + auto opr = handle()->create_operator(); + opr->param() = arg.param; + opr->deduce_layout({arg.src, dtype::Float32()}, + {arg.filter, dtype::Float32()}, dst_layout); + //! 
dst.nr_elems * IC * FH * FW * 2 + float computations = dst_layout.total_nr_elems() * arg.filter[1] * + arg.filter[2] * arg.filter[3] * 2.0 / + (1024 * 1024 * 1024) * 1e3; + + auto used_conv = benchmark_conv.set_param(arg.param).exec( + {arg.src, arg.filter, {}}) / + RUN; + auto used_matmul = benchmark_matmul.set_param(arg.param).exec( + {arg.src, arg.filter, {}}) / + RUN; + + printf("%s %s: conv: %f ms %f Gflops matmul: %f ms %f GFlops speedup: " + "%f\n", + arg.src.to_string().c_str(), arg.filter.to_string().c_str(), + used_conv, computations / used_conv, used_matmul, + computations / used_matmul, used_matmul / used_conv); + } +} + +TEST_F(ARM_COMMON, BENCHMARK_CONVOLUTION_QUINT8X8X32_STRD1_WITHOUT_DOTPROD) { + // have to remove preferred restrict in usable func before run the benchmark + using namespace convolution; + + std::vector args; + auto run = [&](size_t oc, size_t ic, size_t w, size_t h, size_t kernel, + size_t p) { + if (w + 2 * p < kernel || h + 2 * p < kernel) + return; + param::Convolution param; + param.stride_h = 1; + param.stride_w = 1; + param.pad_h = p; + param.pad_w = p; + + args.emplace_back(param, TensorShape{1, ic, h, w}, + TensorShape{oc, ic, kernel, kernel}); + + }; + + // compare to float direct conv here, + // but float direct conv don't support 7x7. + for (size_t kernel : {2, 3, 5}) + for (size_t ic : {1, 8, 16, 32, 64}) + for (size_t oc : {1, 8, 16, 32, 64}) + for (size_t p : {0, 1, 2, 3}) { + run(oc, ic, 56, 56, kernel, p); + run(oc, ic, 128, 128, kernel, p); + run(oc, ic, 256, 256, kernel, p); + } + + constexpr size_t RUN = 50; + Benchmarker benchmark(handle()); + benchmark.set_dtype(0, dtype::Quantized8Asymm(0.1f, static_cast(120))) + .set_dtype(1, dtype::Quantized8Asymm(0.1f, static_cast(120))) + .set_dtype(2, dtype::QuantizedS32(0.01f)); + benchmark.set_display(false); + benchmark.set_times(RUN); + benchmark.set_before_exec_callback( + AlgoChecker("CONVOLUTION_DEFAULT_QU8STRD1")); + + Benchmarker benchmark_float(handle()); + benchmark_float.set_display(false); + benchmark_float.set_times(RUN); + benchmark_float.set_before_exec_callback( + AlgoChecker("CONVOLUTION_DEFAULT_F32STRD1")); + + for (auto&& arg : args) { + TensorLayout dst_layout; + auto opr = handle()->create_operator(); + opr->param() = arg.param; + opr->deduce_layout({arg.src, dtype::Float32()}, + {arg.filter, dtype::Float32()}, dst_layout); + //! 
dst.nr_elems * IC * FH * FW * 2 + float computations = dst_layout.total_nr_elems() * arg.filter[1] * + arg.filter[2] * arg.filter[3] * 2.0 / + (1024 * 1024 * 1024) * 1e3; + + auto used_int = + benchmark.set_param(arg.param).exec({arg.src, arg.filter, {}}) / + RUN; + auto used_float = benchmark_float.set_param(arg.param).exec( + {arg.src, arg.filter, {}}) / + RUN; + + printf("%s %s: int: %f ms %f Gflops float: %f ms %f GFlops speedup: " + "%f\n", + arg.src.to_string().c_str(), arg.filter.to_string().c_str(), + used_int, computations / used_int, used_float, + computations / used_float, used_float / used_int); + } +} + +TEST_F(ARM_COMMON, BENCHMARK_CONVOLUTION_QUINT8X8X32_STRD2_WITHOUT_DOTPROD) { + // have to remove preferred restrict in usable func before run the benchmark + using namespace convolution; + + std::vector args; + auto run = [&](size_t oc, size_t ic, size_t w, size_t h, size_t kernel, + size_t p) { + if (w + 2 * p < kernel || h + 2 * p < kernel) + return; + param::Convolution param; + param.stride_h = 2; + param.stride_w = 2; + param.pad_h = p; + param.pad_w = p; + + args.emplace_back(param, TensorShape{1, ic, h, w}, + TensorShape{oc, ic, kernel, kernel}); + + }; + + for (size_t kernel : {2, 3, 5, 7}) + for (size_t ic : {1, 8, 16, 32, 64}) + for (size_t oc : {1, 8, 16, 32, 64}) + for (size_t p : {0, 1, 2, 3}) { + run(oc, ic, 56, 56, kernel, p); + run(oc, ic, 128, 128, kernel, p); + run(oc, ic, 256, 256, kernel, p); + } + + constexpr size_t RUN = 50; + Benchmarker benchmark(handle()); + benchmark.set_dtype(0, dtype::Quantized8Asymm(0.1f, static_cast(120))) + .set_dtype(1, dtype::Quantized8Asymm(0.1f, static_cast(120))) + .set_dtype(2, dtype::QuantizedS32(0.01f)); + benchmark.set_display(false); + benchmark.set_times(RUN); + benchmark.set_before_exec_callback( + AlgoChecker("CONVOLUTION_DEFAULT_QU8STRD2")); + + Benchmarker benchmark_float(handle()); + benchmark_float.set_display(false); + benchmark_float.set_times(RUN); +#if MEGDNN_AARCH64 + benchmark_float.set_before_exec_callback(AlgoChecker( + "CONVOLUTION_DEFAULT_ARMV8F32STRD2")); +#else + benchmark_float.set_before_exec_callback( + AlgoChecker("CONVOLUTION_DEFAULT_F32STRD2")); +#endif + + for (auto&& arg : args) { + TensorLayout dst_layout; + auto opr = handle()->create_operator(); + opr->param() = arg.param; + opr->deduce_layout({arg.src, dtype::Float32()}, + {arg.filter, dtype::Float32()}, dst_layout); + //! 
dst.nr_elems * IC * FH * FW * 2 + float computations = dst_layout.total_nr_elems() * arg.filter[1] * + arg.filter[2] * arg.filter[3] * 2.0 / + (1024 * 1024 * 1024) * 1e3; + + auto used_int = + benchmark.set_param(arg.param).exec({arg.src, arg.filter, {}}) / + RUN; + auto used_float = benchmark_float.set_param(arg.param).exec( + {arg.src, arg.filter, {}}) / + RUN; + + printf("%s %s: int: %f ms %f Gflops float: %f ms %f GFlops speedup: " + "%f\n", + arg.src.to_string().c_str(), arg.filter.to_string().c_str(), + used_int, computations / used_int, used_float, + computations / used_float, used_float / used_int); + } +} + +TEST_F(ARM_COMMON, BENCHMARK_CONVOLUTION_INT8_INT8_INT16) { + using Param = param::Convolution; + auto run = [&](const TensorShapeArray& shapes, Param param) { + TensorLayoutArray layouts; + layouts.emplace_back(shapes[0], dtype::Int8()); + layouts.emplace_back(shapes[1], dtype::Int8()); + layouts.emplace_back(shapes[2], dtype::Int16()); + Benchmarker benchmarker_cpu(handle()), + benchmarker_float(handle()); + benchmarker_cpu.set_dtype(0, dtype::Int8()); + benchmarker_cpu.set_dtype(1, dtype::Int8()); + benchmarker_cpu.set_dtype(2, dtype::Int16()); + auto iparam = param; + size_t RUN = 10; + auto t2 = benchmarker_cpu.set_display(false) + .set_times(RUN) + .set_param(iparam) + .execl(layouts); + auto t4 = benchmarker_float.set_display(false) + .set_times(RUN) + .set_param(param) + .exec(shapes); + auto speedup = t4 / t2; + std::cout << "src=" << shapes[0].to_string() + << " filter=" << shapes[1].to_string() + << " stride=" << param.stride_h << " float=" << t4 << "ms" + << " int=" << t2 << "ms" + << " speedup=" << speedup << std::endl; + ASSERT_GE(speedup, 1); + }; + /* + for (size_t s: {1, 2}) + for (size_t k: {3}) + for (size_t c: {16}) + for (size_t h = 20; h <= 60; ++h) + { + Param param; + param.stride_h = param.stride_w = s; + run({{1, c, h, h}, {c, c, k, k}, {}}, param); + } + + for (size_t s: {1}) + for (size_t k: {1}) + for (size_t c: {16}) + for (size_t h = 16; h <= 1024; h*=2) + { + Param param; + param.stride_h = param.stride_w = s; + run({{1, c, h, h}, {c, c, k, k}, {}}, param); + } + */ + for (size_t s : {1}) { + Param param; + param.stride_h = param.stride_w = s; + + run({{2, 3, 480, 270}, {12, 3, 1, 1}, {}}, param); + run({{2, 12, 240, 135}, {48, 12, 1, 1}, {}}, param); + run({{2, 16, 240, 135}, {4, 16, 1, 1}, {}}, param); + run({{2, 4, 240, 135}, {16, 4, 1, 1}, {}}, param); + run({{2, 16, 240, 135}, {8, 16, 1, 1}, {}}, param); + run({{2, 8, 120, 68}, {32, 8, 1, 1}, {}}, param); + run({{2, 32, 120, 68}, {8, 32, 1, 1}, {}}, param); + run({{2, 64, 60, 34}, {16, 64, 1, 1}, {}}, param); + } +} + +TEST_F(ARM_COMMON, BENCHMARK_CONVOLUTION_INT8_INT8_INT32) { + using Param = param::Convolution; + auto run = [&](const TensorShapeArray& shapes, Param param) { + TensorLayoutArray layouts; + layouts.emplace_back(shapes[0], dtype::Int8()); + layouts.emplace_back(shapes[1], dtype::Int8()); + layouts.emplace_back(shapes[2], dtype::Int32()); + Benchmarker benchmarker_cpu(handle()), + benchmarker_float(handle()); + benchmarker_cpu.set_dtype(0, dtype::Int8()); + benchmarker_cpu.set_dtype(1, dtype::Int8()); + benchmarker_cpu.set_dtype(2, dtype::Int32()); + auto iparam = param; + size_t RUN = 10; + auto t2 = benchmarker_cpu.set_display(false) + .set_times(RUN) + .set_param(iparam) + .execl(layouts); + auto t4 = benchmarker_float.set_display(false) + .set_times(RUN) + .set_param(param) + .exec(shapes); + auto speedup = t4 / t2; + std::cout << "src=" << shapes[0].to_string() + << " 
filter=" << shapes[1].to_string() + << " stride=" << param.stride_h << " float=" << t4 << "ms" + << " int=" << t2 << "ms" + << " speedup=" << speedup << std::endl; + ASSERT_GE(speedup, 1); + }; + for (size_t s : {1, 2}) + for (size_t k : {3}) + for (size_t c : {16}) + for (size_t h = 20; h <= 60; ++h) { + Param param; + param.stride_h = param.stride_w = s; + run({{1, c, h, h}, {c, c, k, k}, {}}, param); + } + + for (size_t s : {1}) + for (size_t k : {1}) + for (size_t c : {16}) + for (size_t h = 16; h <= 1024; h *= 2) { + Param param; + param.stride_h = param.stride_w = s; + run({{1, c, h, h}, {c, c, k, k}, {}}, param); + } + for (size_t s : {1}) { + Param param; + param.stride_h = param.stride_w = s; + + run({{2, 3, 480, 270}, {12, 3, 1, 1}, {}}, param); + run({{2, 12, 240, 135}, {48, 12, 1, 1}, {}}, param); + run({{2, 16, 240, 135}, {4, 16, 1, 1}, {}}, param); + run({{2, 4, 240, 135}, {16, 4, 1, 1}, {}}, param); + run({{2, 16, 240, 135}, {8, 16, 1, 1}, {}}, param); + run({{2, 8, 120, 68}, {32, 8, 1, 1}, {}}, param); + run({{2, 32, 120, 68}, {8, 32, 1, 1}, {}}, param); + run({{2, 64, 60, 34}, {16, 64, 1, 1}, {}}, param); + } +} + +TEST_F(ARM_COMMON, BENCHMARK_CONVOLUTION_DIRECT) { + using Param = param::Convolution; + Benchmarker benchmarker_float(handle()); + Benchmarker benchmarker_half(handle()); + const size_t RUNS = 10; + benchmarker_float.set_display(false) + .set_times(RUNS) + .set_dtype(0, dtype::Float32{}) + .set_dtype(1, dtype::Float32{}) + .set_dtype(2, dtype::Float32{}) + .set_before_exec_callback( + AlgoChecker("CONVOLUTION_DEFAULT_F32DIRECT")); + benchmarker_half.set_display(false) + .set_times(RUNS) + .set_dtype(0, dtype::Float16{}) + .set_dtype(1, dtype::Float16{}) + .set_dtype(2, dtype::Float16{}) + .set_before_exec_callback( + AlgoChecker("CONVOLUTION_DEFAULT_F16DIRECT")); + + auto run = [&](const TensorShapeArray& shapes, Param param) { + auto tfloat = benchmarker_float.set_param(param).exec(shapes) / RUNS; + auto thalf = benchmarker_half.set_param(param).exec(shapes) / RUNS; + + TensorLayout dst_layout; + auto opr = handle()->create_operator(); + opr->param() = param; + opr->deduce_layout({shapes[0], dtype::Float32()}, + {shapes[1], dtype::Float32()}, dst_layout); + //! 
dst.nr_elems * IC * FH * FW * 2 + float computations = dst_layout.total_nr_elems() * shapes[1][1] * + shapes[1][2] * shapes[1][3] * 2.0 / + (1024 * 1024 * 1024); + printf("run:%s %s float: %f ms %f Gflops VS half: %f ms %f Gflops " + "speepup: %f\n", + shapes[0].to_string().c_str(), shapes[1].to_string().c_str(), + tfloat, computations / tfloat * 1e3, thalf, + computations / thalf * 1e3, tfloat / thalf); + }; + + auto profile = [&](size_t n, size_t oc, size_t ic, size_t w, size_t h, + size_t kernel, size_t stride) { + Param param; + param.stride_h = stride; + param.stride_w = stride; + param.pad_h = kernel / 2; + param.pad_w = kernel / 2; + + run({{n, ic, h, w}, {oc, ic, kernel, kernel}, {}}, param); + + }; + + for (size_t kernel : {1, 2, 3, 4, 5, 6, 7}) { + for (size_t ic : {12}) { + for (size_t oc : {4}) { + for (size_t size : {17, 28, 32, 34, 64, 112, 256}) { + profile(1, oc, ic, size, size, kernel, 1); + } + } + } + } + for (auto k : {1, 2, 3, 4, 5, 6, 7}) { + profile(2, 12, 3, 480, 270, k, 1); + profile(2, 48, 12, 240, 135, k, 1); + profile(2, 4, 16, 240, 135, k, 1); + profile(2, 16, 4, 240, 135, k, 1); + profile(2, 8, 16, 240, 135, k, 1); + profile(2, 32, 8, 240, 135, k, 1); + profile(2, 8, 32, 120, 68, k, 1); + profile(2, 16, 64, 60, 34, k, 1); + } +} + +TEST_F(ARM_COMMON, BENCHMARK_CONVOLUTION_STRIDE1) { + using Param = param::Convolution; + auto run_fp32 = [&](const TensorShapeArray& shapes, Param param) { + Benchmarker benchmarker_float(handle()); + size_t RUN = 50; + auto tfloat = + benchmarker_float.set_display(false) + .set_dtype(0, dtype::Float32()) + .set_dtype(1, dtype::Float32()) + .set_dtype(2, dtype::Float32()) + .set_before_exec_callback(AlgoChecker( + "CONVOLUTION_DEFAULT_F32STRD1")) + .set_times(RUN) + .set_param(param) + .exec(shapes); + size_t IC = shapes[1][1]; + size_t FH = shapes[1][2]; + size_t FW = shapes[1][3]; + TensorLayout dst_layout; + auto opr = handle()->create_operator(); + opr->param() = param; + opr->deduce_layout({shapes[0], dtype::Float32()}, + {shapes[1], dtype::Float32()}, dst_layout); + printf("fp32 flops: %.3f mflops\n", + (IC * dst_layout.total_nr_elems() * FH * FW * 2) / + (tfloat / RUN * 1000)); + }; +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + auto run_fp16 = [&](const TensorShapeArray& shapes, Param param) { + Benchmarker benchmarker_float(handle()); + size_t RUN = 50; + auto tfloat = + benchmarker_float.set_display(false) + .set_dtype(0, dtype::Float16()) + .set_dtype(1, dtype::Float16()) + .set_dtype(2, dtype::Float16()) + .set_before_exec_callback(AlgoChecker( + "CONVOLUTION_DEFAULT_F16STRD1")) + .set_times(RUN) + .set_param(param) + .exec(shapes); + size_t IC = shapes[1][1]; + size_t FH = shapes[1][2]; + size_t FW = shapes[1][3]; + TensorLayout dst_layout; + auto opr = handle()->create_operator(); + opr->param() = param; + opr->deduce_layout({shapes[0], dtype::Float16()}, + {shapes[1], dtype::Float16()}, dst_layout); + printf("fp16 flops: %.3f mflops\n", + (IC * dst_layout.total_nr_elems() * FH * FW * 2) / + (tfloat / RUN * 1000)); + }; +#endif + auto profile = [&](size_t oc, size_t ic, size_t w, size_t h, size_t kernel, + size_t stride) { + Param param; + param.stride_h = stride; + param.stride_w = stride; + param.pad_h = kernel / 2; + param.pad_w = kernel / 2; + printf("oc: %zd ic: %zd w: %zd h: %zd stride: %zd kernel_size: %zd\n", + oc, ic, w, h, stride, kernel); + + run_fp32({{1, ic, h, w}, {oc, ic, kernel, kernel}, {}}, param); +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + run_fp16({{1, ic, h, w}, {oc, ic, kernel, kernel}, {}}, 
param); +#endif + + }; + + for (size_t kernel : {2, 3, 5}) { + for (size_t ic : {3, 6, 12, 24}) { + for (size_t oc : {3, 6, 12, 24}) { + for (size_t size : {4, 7, 8, 14, 16, 17, 28, 32, 34, 64, 112}) { + profile(oc, ic, size, size, kernel, 1); + } + } + } + } +} +#endif + +// vim: syntax=cpp.doxygen diff --git a/dnn/test/arm_common/cvt_color.cpp b/dnn/test/arm_common/cvt_color.cpp new file mode 100644 index 00000000..75d8c5f6 --- /dev/null +++ b/dnn/test/arm_common/cvt_color.cpp @@ -0,0 +1,114 @@ +/** + * \file dnn/test/arm_common/cvt_color.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "test/common/checker.h" +#include "test/common/benchmarker.h" +#include "test/common/cvt_color.h" + +#include "test/arm_common/fixture.h" + +namespace megdnn { +namespace test { + +using Mode = param::CvtColor::Mode; + +TEST_F(ARM_COMMON, CVTCOLOR) +{ + using namespace cvt_color; + std::vector args = get_args(); + Checker checker(handle()); + + for (auto &&arg: args) { + checker.set_param(arg.param) + .set_dtype(0, arg.dtype) + .set_dtype(1, arg.dtype) + .execs({arg.src, {}}); + } +} + +#if MEGDNN_WITH_BENCHMARK +TEST_F(ARM_COMMON, BENCHMARK_CVTCOLOR_RGB2GRAY) +{ + using namespace cvt_color; + using Param = param::CvtColor; + +#define BENCHMARK_PARAM(benchmarker, dtype) \ + benchmarker.set_param(param); \ + benchmarker.set_dtype(0, dtype); + + auto run = [&](const TensorShapeArray& shapes, Param param) { + auto handle_naive = create_cpu_handle(2); + Benchmarker benchmarker(handle()); + Benchmarker benchmarker_naive(handle_naive.get()); + + BENCHMARK_PARAM(benchmarker, dtype::Uint8()); + BENCHMARK_PARAM(benchmarker_naive, dtype::Uint8()); + for (auto&& shape : shapes) { + printf("execute %s: current---naive\n", shape.to_string().c_str()); + benchmarker.execs({shape, {}}); + benchmarker_naive.execs({shape, {}}); + } + + BENCHMARK_PARAM(benchmarker, dtype::Float32()); + BENCHMARK_PARAM(benchmarker_naive, dtype::Float32()); + for (auto&& shape : shapes) { + printf("execute %s: current---naive\n", shape.to_string().c_str()); + benchmarker.execs({shape, {}}); + benchmarker_naive.execs({shape, {}}); + } + + }; + + Param param; + TensorShapeArray shapes = { + {1, 500, 512, 3}, + {2, 500, 512, 3}, + }; + + param.mode = Param::Mode::RGB2GRAY; + run(shapes, param); + +} + +TEST_F(ARM_COMMON, BENCHMARK_CVTCOLOR_BT601_YUV) { + using namespace cvt_color; + using Param = param::CvtColor; + +#define BENCHMARK_PARAM(benchmarker, dtype) \ + benchmarker.set_param(param); \ + benchmarker.set_dtype(0, dtype); + + auto run = [&](const TensorShapeArray& shapes, Param param) { + auto handle_naive = create_cpu_handle(2); + Benchmarker benchmarker(handle()); + Benchmarker benchmarker_naive(handle_naive.get()); + + BENCHMARK_PARAM(benchmarker, dtype::Uint8()); + BENCHMARK_PARAM(benchmarker_naive, dtype::Uint8()); + for (auto&& shape : shapes) { + printf("execute %s: current---naive\n", shape.to_string().c_str()); + benchmarker.execs({shape, {}}); + benchmarker_naive.execs({shape, {}}); + } + }; + + Param param; + TensorShapeArray shapes = { + {1, 300, 512, 1}, + }; + + param.mode = Param::Mode::BT601_YUV2RGB_NV21; + run(shapes, param); +} +#endif + +} // namespace test +} // namespace megdnn + // vim: 
syntax=cpp.doxygen diff --git a/dnn/test/arm_common/elemwise.cpp b/dnn/test/arm_common/elemwise.cpp new file mode 100644 index 00000000..e46a3efc --- /dev/null +++ b/dnn/test/arm_common/elemwise.cpp @@ -0,0 +1,369 @@ +/** + * \file dnn/test/arm_common/elemwise.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "test/common/elemwise.h" +#include "test/arm_common/fixture.h" +#include "test/common/benchmarker.h" +#include "test/common/checker.h" + +#include "megdnn/oprs/general.h" + +using namespace megdnn; +using namespace test; + +template +class ARM_ELEMWISE : public ARM_COMMON {}; +TYPED_TEST_CASE(ARM_ELEMWISE, elemwise::test_types); +TYPED_TEST(ARM_ELEMWISE, run) { + elemwise::run_test(this->handle()); +} + +#define TERNARY_COMPLATE_TEST_CASE(_optr) \ + printf("Check binary optr %s by all cases.\n", #_optr); \ + checker.set_param(Mode::_optr) \ + .execs({{3, 4, 7}, {3, 4, 7}, {3, 4, 7}, {}}); \ + checker.set_param(Mode::_optr) \ + .execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}, {}}); \ + checker.set_param(Mode::_optr) \ + .execs({{1, 4, 1}, {3, 4, 7}, {1, 4, 1}, {}}); \ + checker.set_param(Mode::_optr) \ + .execs({{3, 4, 5, 7}, {3, 4, 5, 7}, {1, 1, 1, 1}, {}}); \ + checker.set_param(Mode::_optr).execs({{1, 7}, {1, 7}, {1, 7}, {}}); \ + checker.set_param(Mode::_optr) \ + .execs({{1, 2, 1}, {1, 2, 2}, {1, 2, 1}, {}}); \ + checker.set_param(Mode::_optr) \ + .execs({{1, 2, 2}, {1, 2, 2}, {1, 1, 1}, {}}); \ + checker.set_param(Mode::_optr) \ + .execs({{3, 4, 1}, {3, 4, 1}, {3, 4, 1}, {}}); \ + checker.set_param(Mode::_optr).execs({{3, 4, 5}, {1}, {1}, {}}); \ + checker.set_param(Mode::_optr).execs({{1}, {3, 4, 5}, {1}, {}}); + +#define BUILD_TERNARY_COMPLATE_TEST_CASE \ + TERNARY_COMPLATE_TEST_CASE(FUSE_MUL_ADD3) + +TEST_F(ARM_COMMON, ELEMWISE_FORWARD_TERNARY) { + using Mode = ElemwiseForward::Param::Mode; + Checker checker(handle()); + // case int + checker.set_dtype(0, dtype::Int8()); + checker.set_dtype(1, dtype::Int8()); + checker.set_dtype(2, dtype::Int8()); + // BUILD_TERNARY_TEST_CASE + BUILD_TERNARY_COMPLATE_TEST_CASE + + checker.set_dtype(0, dtype::Int16()); + checker.set_dtype(1, dtype::Int16()); + checker.set_dtype(2, dtype::Int16()); + // BUILD_TERNARY_TEST_CASE + BUILD_TERNARY_COMPLATE_TEST_CASE + + checker.set_dtype(0, dtype::Int32()); + checker.set_dtype(1, dtype::Int32()); + checker.set_dtype(2, dtype::Int32()); + // BUILD_TERNARY_TEST_CASE + BUILD_TERNARY_COMPLATE_TEST_CASE + + // case float + UniformFloatRNG rng(1e-5, 7e1); + checker.set_rng(0, &rng); + checker.set_epsilon(1e-5); + checker.set_dtype(0, dtype::Float32()); + checker.set_dtype(1, dtype::Float32()); + checker.set_dtype(2, dtype::Float32()); + + // BUILD_TERNARY_TEST_CASE + BUILD_TERNARY_COMPLATE_TEST_CASE + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + // case half + UniformFloatRNG rng_float16(1, 10); + checker.set_rng(0, &rng_float16); + checker.set_epsilon(1e-2); + checker.set_dtype(0, dtype::Float16()); + checker.set_dtype(1, dtype::Float16()); + checker.set_dtype(2, dtype::Float16()); + + // BUILD_TERNARY_TEST_CASE + BUILD_TERNARY_COMPLATE_TEST_CASE +#endif +} + +TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NCHW44_INT8_INT16_INT32) { + using Mode = ElemwiseForward::Param::Mode; + 
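// NCHW44 packs channels in groups of 4: a logical NCHW element (n, c, h, w) is
+ // stored at [n][c/4][h][w][c%4], so a shape such as {1, 3, 1, 1, 4} below is a
+ // 12-channel per-channel operand broadcast against a full {N, 3, H, W, 4} tensor
+ // (the VEC_BCAST101x / BCAST101x_VEC cases called out in the comments).
+ 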
Checker checker(handle()); + + auto run = [&]() { + // VEC_BCAST101x not PowOp + checker.set_param(Mode::ADD).execs( + {{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); + checker.set_param(Mode::ADD).execs( + {{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); + checker.set_param(Mode::ADD).execs( + {{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}}); + checker.set_param(Mode::ADD).execs( + {{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}}); + checker.set_param(Mode::ADD).execs( + {{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}}); + checker.set_param(Mode::RMULH) + .execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); + checker.set_param(Mode::RMULH) + .execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); + checker.set_param(Mode::RMULH) + .execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}}); + checker.set_param(Mode::RMULH) + .execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}}); + checker.set_param(Mode::RMULH) + .execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}}); + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}}); + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}}); + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}}); + // BCAST101x_VEC not PowOp + checker.set_param(Mode::ADD).execs( + {{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}}); + checker.set_param(Mode::ADD).execs( + {{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}}); + checker.set_param(Mode::ADD).execs( + {{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}}); + checker.set_param(Mode::ADD).execs( + {{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}}); + checker.set_param(Mode::ADD).execs( + {{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}}); + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}}); + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}}); + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}}); + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}}); + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}}); + }; + checker.set_dtype(0, dtype::Int8()); + checker.set_dtype(1, dtype::Int8()); + run(); + checker.set_dtype(0, dtype::Int16()); + checker.set_dtype(1, dtype::Int16()); + run(); + checker.set_dtype(0, dtype::Int32()); + checker.set_dtype(1, dtype::Int32()); + run(); +} + +TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NCHW44_FP32) { + using Mode = ElemwiseForward::Param::Mode; + Checker checker(handle()); + + UniformFloatRNG rng(1e-5, 7e1); + checker.set_rng(0, &rng); + checker.set_epsilon(1e-5); + checker.set_dtype(0, dtype::Float32()); + checker.set_dtype(1, dtype::Float32()); + + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}}); + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}}); + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}}); + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}}); + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}}); + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}}); + 
checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}}); + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}}); + + auto run = [&](Mode mode) { + // VEC_BCAST101x + checker.set_param(mode).execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); + checker.set_param(mode).execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); + checker.set_param(mode).execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}}); + checker.set_param(mode).execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}}); + checker.set_param(mode).execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}}); + // BCAST101x_VEC not powOp + checker.set_param(mode).execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}}); + checker.set_param(mode).execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}}); + checker.set_param(mode).execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}}); + checker.set_param(mode).execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}}); + checker.set_param(mode).execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}}); + }; + run(Mode::ADD); + run(Mode::FUSE_ADD_H_SWISH); + run(Mode::FUSE_ADD_RELU); + run(Mode::MAX); + run(Mode::MIN); + run(Mode::MUL); + run(Mode::SUB); + run(Mode::TRUE_DIV); + run(Mode::POW); +} + +#if MEGDNN_WITH_BENCHMARK +namespace { +void run_elemwise_benchmark(const TensorShapeArray& shapes, + param::Elemwise::Mode mode, const char* mode_str, + DType type, Handle* handle_bench) { + auto handle_fallback = create_cpu_handle(1); + Benchmarker benchmarker_bench(handle_bench); + Benchmarker benchmarker_fallback(handle_fallback.get()); + + float throughput = 0; + SmallVector layouts; + std::string src_strs; + for (size_t i = 0; i < shapes.size(); i++) { + layouts.emplace_back(shapes[i], type); + throughput += layouts.back().span().dist_byte(); + src_strs += layouts.back().to_string(); + if (i != shapes.size() - 1) { + src_strs += ","; + } + } + constexpr size_t RUN = 50; + benchmarker_fallback.set_times(RUN).set_display(false); + benchmarker_bench.set_times(RUN).set_display(false); + + benchmarker_fallback.set_param(mode); + benchmarker_bench.set_param(mode); + + TensorLayout dst_layout; + auto opr = handle_bench->create_operator(); + opr->param() = mode; + opr->deduce_layout(layouts, dst_layout); + + float computations = dst_layout.total_nr_elems() * + (std::max(shapes.size(), 2) - 1); + throughput += dst_layout.span().dist_byte(); + computations *= (1e3 / (1024.0 * 1024)); + throughput *= (1e3 / (1024.0 * 1024)); + + layouts.emplace_back(dst_layout); + auto fallback_time = benchmarker_fallback.execl(layouts) / RUN; + auto bench_time = benchmarker_bench.execl(layouts) / RUN; + + float fallback_flops = computations / fallback_time; + float bench_flops = computations / bench_time; + float fallback_thr = throughput / fallback_time; + float bench_thr = throughput / bench_time; + + printf("%s = %s (type: %s, mode: %s) cpu=%fMFLOPS %fMB/s, bench=%fMFLOPS " + "%fMB/s " + "computations: %fx, throughput: %fx\n", + src_strs.c_str(), dst_layout.to_string().c_str(), type.name(), + mode_str, fallback_flops, fallback_thr, bench_flops, bench_thr, + bench_flops / fallback_flops, bench_thr / fallback_thr); +} +} // namespace + +#define INT_RUN(shape, mode) \ + run_elemwise_benchmark(shape, mode, #mode, dtype::Int8{}, handle()); \ + run_elemwise_benchmark(shape, mode, #mode, dtype::Int16{}, handle()); \ + run_elemwise_benchmark(shape, mode, #mode, dtype::Int32{}, handle()); + +#define FLOAT_RUN(shape, mode) \ + run_elemwise_benchmark(shape, mode, #mode, dtype::Float32{}, handle()); \ + run_elemwise_benchmark(shape, mode, #mode, 
dtype::Float16{}, handle()); + +#define BENCHMARK_CASES(shape) \ + INT_BENCHMARK_CASES(shape) \ + FLOAT_BENCHMARK_CASES(shape) + +TEST_F(ARM_COMMON, BENCHMARK_UNARY) { +#define INT_BENCHMARK_CASES(shape) \ + INT_RUN(shape, Mode::RELU); \ + INT_RUN(shape, Mode::ABS); + +#define FLOAT_BENCHMARK_CASES(shape) \ + FLOAT_RUN(shape, Mode::RELU); \ + FLOAT_RUN(shape, Mode::ABS); \ + FLOAT_RUN(shape, Mode::SIGMOID); \ + FLOAT_RUN(shape, Mode::EXP); \ + FLOAT_RUN(shape, Mode::TANH); \ + FLOAT_RUN(shape, Mode::FAST_TANH); + + using Mode = param::Elemwise::Mode; + BENCHMARK_CASES({{10000}}); + BENCHMARK_CASES({{50000}}); + +#undef INT_BENCHMARK_CASES +#undef FLOAT_BENCHMARK_CASES +} + +TEST_F(ARM_COMMON, BENCHMARK_BINARY) { +#define INT_BENCHMARK_CASES(shape) \ + INT_RUN(shape, Mode::MIN); \ + INT_RUN(shape, Mode::MAX); \ + INT_RUN(shape, Mode::ADD); \ + INT_RUN(shape, Mode::SUB); \ + INT_RUN(shape, Mode::MUL); \ + INT_RUN(shape, Mode::RMULH); \ + INT_RUN(shape, Mode::FUSE_ADD_RELU); + +#define FLOAT_BENCHMARK_CASES(shape) \ + FLOAT_RUN(shape, Mode::MIN); \ + FLOAT_RUN(shape, Mode::MAX); \ + FLOAT_RUN(shape, Mode::ADD); \ + FLOAT_RUN(shape, Mode::SUB); \ + FLOAT_RUN(shape, Mode::MUL); \ + FLOAT_RUN(shape, Mode::POW); \ + FLOAT_RUN(shape, Mode::TRUE_DIV); \ + FLOAT_RUN(shape, Mode::FUSE_ADD_RELU); + + using Mode = param::Elemwise::Mode; + TensorShapeArray shapes = {{1, 112, 28, 28}, {1, 112, 28, 28}}; + BENCHMARK_CASES(shapes); + shapes = {{1, 16, 1, 1}, {1, 16, 112, 112}}; + BENCHMARK_CASES(shapes); + shapes = {{1, 448, 7, 7}, {1, 448, 7, 7}}; + BENCHMARK_CASES(shapes); + +#undef INT_BENCHMARK_CASES +#undef FLOAT_BENCHMARK_CASES +} + +TEST_F(ARM_COMMON, BENCHMARK_TERNARY_FMA3) { +#define INT_BENCHMARK_CASES(shape) INT_RUN(shape, Mode::FUSE_MUL_ADD3); + +#define FLOAT_BENCHMARK_CASES(shape) FLOAT_RUN(shape, Mode::FUSE_MUL_ADD3); + + using Mode = param::Elemwise::Mode; + TensorShapeArray shapes = {{30, 40, 70}, {30, 40, 70}, {30, 40, 70}}; + BENCHMARK_CASES(shapes); + shapes = {{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}}; + BENCHMARK_CASES(shapes); + shapes = {{3, 4, 5, 7}, {3, 4, 5, 7}, {1, 1, 1, 1}}; + BENCHMARK_CASES(shapes); + +#undef INT_BENCHMARK_CASES +#undef FLOAT_BENCHMARK_CASES +} + +#undef BENCHMARK_CASES +#undef INT_RUN +#undef FLOAT_RUN + +#endif + +// vim: syntax=cpp.doxygen diff --git a/dnn/test/arm_common/elemwise_benchmark.cpp b/dnn/test/arm_common/elemwise_benchmark.cpp new file mode 100644 index 00000000..eafb7354 --- /dev/null +++ b/dnn/test/arm_common/elemwise_benchmark.cpp @@ -0,0 +1,255 @@ +/** + * \file dnn/test/arm_common/elemwise_benchmark.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ +#if MEGDNN_WITH_BENCHMARK +#include "test/arm_common/fixture.h" + +#include "megdnn/oprs.h" +#include "test/common/benchmarker.h" +#include "test/common/checker.h" +#include "test/common/rng.h" + +using namespace megdnn; +using namespace test; + +#define TEST_IN_DIFF_DISTRUBUTION(proportion_of_inf, dataset_number) \ + max_val = 88.3762626647949f / (1 - proportion_of_inf); \ + UniformFloatRNG rng##dataset_number(0.f, max_val); \ + B.set_rng(0, &rng##dataset_number); \ + B.execs({{355600}, {}}); + +TEST_F(ARM_COMMON, BENCHMARK_ELEM_UNARY_FLOATONLY) { + Benchmarker B(handle()); + using Mode = ElemwiseForward::Param::Mode; + // UniformFloatWithZeroRNG rng(80, 100, 0.1); + printf("Test Optr exp(x)\n"); + B.set_param(Mode::EXP); + B.execs({{355600}, {}}); + + B.set_param(Mode::EXP); + B.execs({{355600}, {}}); + float max_val = 0; + TEST_IN_DIFF_DISTRUBUTION(0.25, 1) + TEST_IN_DIFF_DISTRUBUTION(0.5, 2) + TEST_IN_DIFF_DISTRUBUTION(0.75, 3) + TEST_IN_DIFF_DISTRUBUTION(0.9999, 4) + + printf("Test Optr tanh(x)\n"); + B.set_param(Mode::TANH); + B.execs({{355600}, {}}); + + B.set_param(Mode::TANH); + B.execs({{355600}, {}}); + max_val = 0; + TEST_IN_DIFF_DISTRUBUTION(0.25, 5) + TEST_IN_DIFF_DISTRUBUTION(0.5, 6) + TEST_IN_DIFF_DISTRUBUTION(0.75, 7) + TEST_IN_DIFF_DISTRUBUTION(0.9999, 8) + + printf("Test Optr fast_tanh(x)\n"); + B.set_param(Mode::FAST_TANH); + B.execs({{355600}, {}}); + + printf("Test Optr sigmoid(x)\n"); + B.set_param(Mode::SIGMOID); + B.execs({{355600}, {}}); + TEST_IN_DIFF_DISTRUBUTION(0.25, 9) + TEST_IN_DIFF_DISTRUBUTION(0.5, 10) + TEST_IN_DIFF_DISTRUBUTION(0.75, 11) + TEST_IN_DIFF_DISTRUBUTION(0.9999, 12) + + B.set_param(Mode::SIGMOID); + B.execs({{355600}, {}}); + max_val = 0; + TEST_IN_DIFF_DISTRUBUTION(0.25, 13) + TEST_IN_DIFF_DISTRUBUTION(0.5, 14) + TEST_IN_DIFF_DISTRUBUTION(0.75, 15) + TEST_IN_DIFF_DISTRUBUTION(0.9999, 16) +} + +TEST_F(ARM_COMMON, BENCHMARK_ELEMWISE_UNARY) { + Benchmarker B(handle()); + using Mode = ElemwiseForward::Param::Mode; + + const size_t RUN_TIMES = 10; + B.set_times(RUN_TIMES).set_display(false); + + auto run_unary = [&](const TensorShape& shape, param::Elemwise::Mode mode, + const char* mode_str, DType dtype) { + B.set_param(mode).set_dtype(0, dtype); + float time = B.execs({shape, {}}) / RUN_TIMES; + float computations = + shape.total_nr_elems() * 2 / (1024.f * 1024.f * 1024.f); + printf("%s(%s):\tlayout(%s)\ttime(%fms)\tbandwidth(%fGBps)\n", mode_str, + dtype.name(), shape.to_string().c_str(), time, + computations * dtype.size() / time * 1e3); + }; +#define RUN(shape, mode, dtype) run_unary(shape, mode, #mode, dtype); + +#define BENCHMARK_CASES_INT(shape, dtype) \ + RUN(shape, Mode::RELU, dtype) \ + RUN(shape, Mode::ABS, dtype) + +#define BENCHMARK_CASES_FLOAT(shape, dtype) \ + BENCHMARK_CASES_INT(shape, dtype) \ + RUN(shape, Mode::SIGMOID, dtype) \ + RUN(shape, Mode::EXP, dtype) \ + RUN(shape, Mode::TANH, dtype) \ + RUN(shape, Mode::FAST_TANH, dtype) + + TensorShape shape = {10, 50, 10, 100}; + BENCHMARK_CASES_INT(shape, dtype::Int32()); + BENCHMARK_CASES_INT(shape, dtype::Int16()); + BENCHMARK_CASES_INT(shape, dtype::Int8()); + BENCHMARK_CASES_FLOAT(shape, dtype::Float32()); +#undef BENCHMARK_CASES_INT +#undef BENCHMARK_CASES_FLOAT +#undef RUN +} + +TEST_F(ARM_COMMON, BENCHMARK_ELEMWISE_UNARY_MULTI_TYPE) { + Benchmarker B(handle()); + using Mode = ElemwiseMultiType::Param::Mode; + + const size_t RUN_TIMES = 20; + B.set_times(RUN_TIMES).set_display(false); + + auto run_unary = [&](const TensorShape& shape, Mode mode, + const char* mode_str, 
DType src_dtype, + DType dst_dtype) { + B.set_param(mode).set_dtype(0, src_dtype).set_dtype(1, dst_dtype); + float time = B.execs({shape, {}}) / RUN_TIMES; + float computations = + shape.total_nr_elems() * 2 / (1024.f * 1024.f * 1024.f); + printf("type %s %s(%s) to %s \ttime(%fms)\tbandwidth(%fGBps)\n", + mode_str, src_dtype.name(), shape.to_string().c_str(), + dst_dtype.name(), time, + computations * src_dtype.size() / time * 1e3); + }; + +#define RUN(shape, mode, src_dtype, dst_dtye) \ + run_unary(shape, mode, #mode, src_dtype, dst_dtye); + +#define BENCHMARK_CASES_INT(shape, src_dtype, dst_dtye) \ + RUN(shape, Mode::QRELU, src_dtype, dst_dtye) \ + RUN(shape, Mode::QABS, src_dtype, dst_dtye) + + TensorShape shape = {10, 50, 10, 100}; + BENCHMARK_CASES_INT(shape, dtype::QuantizedS32(62.5f), + dtype::QuantizedS8(2.5f)); +#undef BENCHMARK_CASES_INT +#undef BENCHMARK_CASES_FLOAT +#undef RUN +} + +TEST_F(ARM_COMMON, BENCHMARK_ELEMWISE_BINARY) { + Benchmarker B(handle()); + using Mode = ElemwiseForward::Param::Mode; + + const size_t RUN_TIMES = 10; + B.set_times(RUN_TIMES).set_display(false); + + auto run_binary = [&](const TensorShape& shape0, const TensorShape& shape1, + param::Elemwise::Mode mode, const char* mode_str, + DType dtype) { + B.set_param(mode).set_dtype(0, dtype).set_dtype(1, dtype); + float time = B.execs({shape0, shape1, {}}) / RUN_TIMES; + float bandwidth = + (shape0.total_nr_elems() + shape1.total_nr_elems() + + std::max(shape0.total_nr_elems(), shape1.total_nr_elems())) / + (1024.f * 1024.f * 1024.f) * dtype.size() / time * 1e3; + printf("%s(%s):\tlayout(%s %s)\ttime(%fms)\tbandwidth(%fGBps)\n", + mode_str, dtype.name(), shape0.to_string().c_str(), + shape1.to_string().c_str(), time, bandwidth); + }; +#define RUN(shape0, shape1, mode, dtype) \ + run_binary(shape0, shape1, mode, #mode, dtype); + +#define BENCHMARK_CASES_INT(shape0, shape1, dtype) \ + RUN(shape0, shape1, Mode::ADD, dtype) \ + RUN(shape0, shape1, Mode::MIN, dtype) \ + RUN(shape0, shape1, Mode::MAX, dtype) \ + RUN(shape0, shape1, Mode::SUB, dtype) \ + RUN(shape0, shape1, Mode::MUL, dtype) \ + RUN(shape0, shape1, Mode::FUSE_ADD_RELU, dtype) + +#define BENCHMARK_CASES_FLOAT(shape0, shape1, dtype) \ + BENCHMARK_CASES_INT(shape0, shape1, dtype) \ + RUN(shape0, shape1, Mode::TRUE_DIV, dtype) \ + RUN(shape0, shape1, Mode::FUSE_ADD_SIGMOID, dtype) \ + RUN(shape0, shape1, Mode::FUSE_ADD_TANH, dtype) + +#define BENCHMARK_CASES_EVERY_DTYPE(shape0, shape1) \ + BENCHMARK_CASES_INT(shape0, shape1, dtype::Int32()); \ + BENCHMARK_CASES_INT(shape0, shape1, dtype::Int16()); \ + BENCHMARK_CASES_INT(shape0, shape1, dtype::Int8()); \ + BENCHMARK_CASES_FLOAT(shape0, shape1, dtype::Float32()); + + TensorShape shape0 = {10, 50, 10, 100}; + TensorShape shape1 = {10, 50, 10, 100}; + BENCHMARK_CASES_EVERY_DTYPE(shape0, shape1); + + shape1 = {1, 50, 1, 1}; + BENCHMARK_CASES_EVERY_DTYPE(shape0, shape1); + + shape1 = {1, 1, 1, 1}; + BENCHMARK_CASES_EVERY_DTYPE(shape0, shape1); +#undef BENCHMARK_CASES_EVERY_DTYPE +#undef BENCHMARK_CASES_FLOAT +#undef BENCHMARK_CASES_INT +#undef RUN +} + +TEST_F(ARM_COMMON, BENCHMARK_ELEMWISE_TERNARY) { + Benchmarker B(handle()); + using Mode = ElemwiseForward::Param::Mode; + + const size_t RUN_TIMES = 10; + B.set_times(RUN_TIMES).set_display(false); + + auto run_ternary = [&](const TensorShape& shape0, const TensorShape& shape1, + const TensorShape& shape2, + param::Elemwise::Mode mode, const char* mode_str, + DType dtype) { + B.set_param(mode).set_dtype(0, dtype).set_dtype(1, dtype).set_dtype( + 2, dtype); 
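+ // Rough bandwidth model: shape1 and shape2 are read once each, while shape0 is
+ // counted twice (presumably one read as an input plus the write of dst); the
+ // byte total is divided by 2^30 for GiB, and the trailing 1e3 converts the
+ // per-run time from milliseconds to seconds, so the printed figure is GB/s.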
+ float time = B.execs({shape0, shape1, shape2, {}}) / RUN_TIMES; + float bandwidth = (shape0.total_nr_elems() * 2 + + shape1.total_nr_elems() + shape2.total_nr_elems()) / + (1024.f * 1024.f * 1024.f) * dtype.size() / time * + 1e3; + printf("%s(%s):\tlayout(%s %s %s)\ttime(%fms)\tbandwidth(%fGBps)\n", + mode_str, dtype.name(), shape0.to_string().c_str(), + shape1.to_string().c_str(), shape2.to_string().c_str(), time, + bandwidth); + }; + + TensorShape shape = {10, 50, 10, 100}; + run_ternary(shape, shape, shape, Mode::FUSE_MUL_ADD3, "FUSE_MUL_ADD3", + dtype::Int32()); + run_ternary(shape, shape, shape, Mode::FUSE_MUL_ADD3, "FUSE_MUL_ADD3", + dtype::Int16()); + run_ternary(shape, shape, shape, Mode::FUSE_MUL_ADD3, "FUSE_MUL_ADD3", + dtype::Int8()); + run_ternary(shape, shape, shape, Mode::FUSE_MUL_ADD3, "FUSE_MUL_ADD3", + dtype::Float32()); +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + run_ternary(shape, {1}, {1}, Mode::FUSE_MUL_ADD3, "FUSE_MUL_ADD3", + dtype::Float32()); + run_ternary(shape, {1}, {1}, Mode::FUSE_MUL_ADD3, "FUSE_MUL_ADD3", + dtype::Float16()); + run_ternary({1}, shape, {1}, Mode::FUSE_MUL_ADD3, "FUSE_MUL_ADD3", + dtype::Float32()); + run_ternary({1}, shape, {1}, Mode::FUSE_MUL_ADD3, "FUSE_MUL_ADD3", + dtype::Float16()); +#endif +} +#endif diff --git a/dnn/test/arm_common/elemwise_multi_type.cpp b/dnn/test/arm_common/elemwise_multi_type.cpp new file mode 100644 index 00000000..0a14b2f3 --- /dev/null +++ b/dnn/test/arm_common/elemwise_multi_type.cpp @@ -0,0 +1,262 @@ +/** + * \file dnn/test/arm_common/elemwise_multi_type.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ + +#include "test/common/elemwise_multi_type.h" +#include "megdnn/oprs.h" +#include "test/arm_common/fixture.h" +#include "test/common/checker.h" +#include "test/common/timer.h" +#include "test/common/workspace_wrapper.h" + +using namespace megdnn; +using namespace test; + +namespace { +template +class ARM_COMMON_ELEMWISE_MULTI_TYPE : public ARM_COMMON {}; +TYPED_TEST_CASE(ARM_COMMON_ELEMWISE_MULTI_TYPE, + elemwise_multi_type::test_types); +} // anonymous namespace + +TYPED_TEST(ARM_COMMON_ELEMWISE_MULTI_TYPE, run) { + elemwise_multi_type::run_test(this->handle()); +} + +TEST_F(ARM_COMMON, ELEMWISE_QUANTIZED_MODE_UNARY) { + using Mode = ElemwiseMultiType::Param::Mode; + Checker checker(handle()); + + std::unique_ptr rng; + for (auto mode : {Mode::QRELU, Mode::QABS, Mode::QSIGMOID, Mode::QEXP, + Mode::QTANH, Mode::QFAST_TANH, Mode::QH_SWISH}) { + checker.set_param({mode}); + + for (DType src_type : std::vector{ + dtype::QuantizedS8(1.4f), + dtype::Quantized8Asymm(1.3f, static_cast(4)), + dtype::QuantizedS32(1.3f)}) { + checker.set_dtype(0, src_type); + if (src_type.enumv() == DTypeEnum::QuantizedS8) { + rng = std::make_unique(-127, 127); + checker.set_dtype(1, dtype::QuantizedS8(1.7f)); + } else if (src_type.enumv() == DTypeEnum::Quantized8Asymm) { + rng = std::make_unique(0, 255); + checker.set_dtype(1, dtype::Quantized8Asymm( + 1.7f, static_cast(10))); + } else { + rng = std::make_unique(INT16_MIN >> 1, + INT16_MAX >> 1); + } + + checker.set_rng(0, rng.get()); + auto run = [&]() { + checker.execs({{3, 4, 5, 6}, {}}); + + checker.execs({{3}, {}}); + checker.execs({{9}, {}}); + checker.execs({{17}, {}}); + }; + + if (src_type.enumv() == DTypeEnum::QuantizedS32) { + for (DType dst_type : std::vector{ + dtype::QuantizedS8(32718.6f), + dtype::Quantized8Asymm( + 32729.6f, static_cast(128))}) { + checker.set_dtype(1, dst_type); + run(); + } + } else { + run(); + } + } + } +} + +TEST_F(ARM_COMMON, ELEMWISE_QUANTIZED_MODE_BINARY) { + using Mode = ElemwiseMultiType::Param::Mode; + Checker checker(handle()); + auto run = [&]() { + //! nchw44 + checker.execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); + checker.execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}}); + checker.execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); + checker.execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); + checker.execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}}); + checker.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}}); + checker.execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}}); + checker.execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}}); + checker.execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}}); + checker.execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}}); + checker.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}}); + checker.execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}}); + //! VEC + SCALAR + checker.execs({{3, 4, 5, 6}, {1, 1, 1, 1}, {}}); + checker.execs({{1, 1, 1, 1}, {3, 4, 5, 6}, {}}); + + //! VEC + 1C11 + checker.execs({{3, 4, 5, 6}, {1, 4, 1, 1}, {}}); + checker.execs({{1, 4, 1, 1}, {3, 4, 5, 6}, {}}); + + //! 
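the remaining cases use two full same-shape tensors with no broadcasting: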
VEC + VEC + checker.execs({{3}, {3}, {}}); + checker.execs({{9}, {9}, {}}); + checker.execs({{17}, {17}, {}}); + checker.execs({{3, 4, 5, 6}, {3, 4, 5, 6}, {}}); + }; + + // qint32 to qint8/quint8 + for (auto mode : + {Mode::QADD, Mode::QFUSE_ADD_RELU, Mode::QFUSE_ADD_H_SWISH}) { + checker.set_param({mode}); + UniformIntRNG rng{INT16_MIN >> 1, INT16_MAX >> 1}; + checker.set_rng(0, &rng) + .set_rng(1, &rng) + .set_dtype(0, dtype::QuantizedS32(1.3f)) + .set_dtype(1, dtype::QuantizedS32(1.2f)); + + for (DType dst_type : + std::vector{dtype::QuantizedS8(32718.6f), + dtype::Quantized8Asymm( + 32729.6f, static_cast(128))}) { + checker.set_dtype(2, dst_type); + run(); + } + } + + for (auto mode : {Mode::QMUL, Mode::QADD, Mode::QMIN, Mode::QMAX, + Mode::QSUB, Mode::QFUSE_ADD_RELU, Mode::QFUSE_ADD_SIGMOID, + Mode::QFUSE_ADD_H_SWISH}) { + checker.set_param({mode}); + + // qint8 to qint8 + UniformIntRNG rng_int8{-127, 127}; + checker.set_rng(0, &rng_int8) + .set_rng(1, &rng_int8) + .set_dtype(0, dtype::QuantizedS8(1.35f)) + .set_dtype(1, dtype::QuantizedS8(1.15f)) + .set_dtype(2, dtype::QuantizedS8(1.75f)); + + run(); + // quint8 to quint8 + UniformIntRNG rng_uint8{0, 255}; + checker.set_rng(0, &rng_uint8) + .set_rng(1, &rng_uint8) + .set_dtype(0, dtype::Quantized8Asymm(1.35f, + static_cast(128))) + .set_dtype(1, dtype::Quantized8Asymm(1.15f, + static_cast(128))) + .set_dtype(2, dtype::Quantized8Asymm( + 1.75f, static_cast(128))); + + run(); + } + + //! TRUE_DIV : 0.0 / 0.0 will fail + checker.set_param({Mode::QTRUE_DIV}); + UniformIntRNG rng_int8_1{-127, 127}; + UniformIntRNG rng_int8_2{-127, -1}; + checker.set_rng(0, &rng_int8_1) + .set_rng(1, &rng_int8_2) + .set_dtype(0, dtype::QuantizedS8(1.4f)) + .set_dtype(1, dtype::QuantizedS8(1.1f)) + .set_dtype(2, dtype::QuantizedS8(1.7f)); + + run(); + + // quint8 to quint8 + UniformIntRNG rng_uint8_1{0, 255}; + UniformIntRNG rng_uint8_2{0, 127}; + checker.set_rng(0, &rng_uint8_1) + .set_rng(1, &rng_uint8_2) + .set_dtype(0, + dtype::Quantized8Asymm(1.35f, static_cast(128))) + .set_dtype(1, + dtype::Quantized8Asymm(1.15f, static_cast(128))) + .set_dtype(2, dtype::Quantized8Asymm(1.75f, + static_cast(128))); + + run(); + + //! 
TANH + checker.set_param({Mode::QFUSE_ADD_TANH}); + UniformIntRNG rng_int8{-5, 5}; + checker.set_rng(0, &rng_int8) + .set_rng(1, &rng_int8) + .set_dtype(0, dtype::QuantizedS8(1.1f)) + .set_dtype(1, dtype::QuantizedS8(1.4f)) + .set_dtype(2, dtype::QuantizedS8(1.7f)); + + run(); + + UniformIntRNG rng_uint8{123, 133}; + checker.set_rng(0, &rng_uint8) + .set_rng(1, &rng_uint8) + .set_dtype(0, + dtype::Quantized8Asymm(1.1f, static_cast(128))) + .set_dtype(1, + dtype::Quantized8Asymm(1.4f, static_cast(128))) + .set_dtype(2, + dtype::Quantized8Asymm(1.7f, static_cast(128))); + + run(); +} + +TEST_F(ARM_COMMON, ELEMWISE_QUANTIZED_MODE_TERNARY) { + using Mode = ElemwiseMultiType::Param::Mode; + Checker checker(handle()); + + for (auto mode : {Mode::QFUSE_MUL_ADD3}) { + checker.set_param({mode}); + + // qint8 to qint8 + UniformIntRNG rng_int8{-127, 127}; + checker.set_rng(0, &rng_int8) + .set_rng(1, &rng_int8) + .set_rng(2, &rng_int8) + .set_dtype(0, dtype::QuantizedS8(1.45f)) + .set_dtype(1, dtype::QuantizedS8(1.15f)) + .set_dtype(2, dtype::QuantizedS8(1.75f)) + .set_dtype(3, dtype::QuantizedS8(1.35f)); + + checker.execs({{3, 4, 5, 6}, {3, 4, 5, 6}, {1, 1, 1, 1}, {}}); + checker.execs({{1, 4, 1, 1}, {3, 4, 5, 6}, {1, 4, 1, 1}, {}}); + + checker.execs({{3}, {3}, {3}, {}}); + checker.execs({{9}, {9}, {9}, {}}); + checker.execs({{17}, {17}, {17}, {}}); + checker.execs({{3, 4, 5, 6}, {3, 4, 5, 6}, {3, 4, 5, 6}, {}}); + + // quint8 to quint8 + UniformIntRNG rng_uint8{0, 225}; + checker.set_rng(0, &rng_uint8) + .set_rng(1, &rng_uint8) + .set_rng(2, &rng_uint8) + .set_dtype(0, dtype::Quantized8Asymm(1.35f, + static_cast(128))) + .set_dtype(1, dtype::Quantized8Asymm(1.15f, + static_cast(128))) + .set_dtype(2, dtype::Quantized8Asymm(1.75f, + static_cast(128))) + .set_dtype(3, dtype::Quantized8Asymm( + 1.45f, static_cast(128))); + + checker.execs({{3, 4, 5, 6}, {3, 4, 5, 6}, {1, 1, 1, 1}, {}}); + checker.execs({{1, 4, 1, 1}, {3, 4, 5, 6}, {1, 4, 1, 1}, {}}); + + checker.execs({{3}, {3}, {3}, {}}); + checker.execs({{9}, {9}, {9}, {}}); + checker.execs({{17}, {17}, {17}, {}}); + checker.execs({{3, 4, 5, 6}, {3, 4, 5, 6}, {3, 4, 5, 6}, {}}); + } +} + +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/dnn/test/arm_common/fixture.h b/dnn/test/arm_common/fixture.h new file mode 100644 index 00000000..e1cabaf2 --- /dev/null +++ b/dnn/test/arm_common/fixture.h @@ -0,0 +1,28 @@ +/** + * \file dnn/test/arm_common/fixture.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +#pragma once +#include <gtest/gtest.h> +#include "test/cpu/fixture.h" + +namespace megdnn { +namespace test { + +class ARM_COMMON : public CPU {}; + +class ARM_COMMON_MULTI_THREADS : public CPU_MULTI_THREADS {}; + +class ARM_COMMON_BENCHMARK_MULTI_THREADS : public CPU_BENCHMARK_MULTI_THREADS { +}; + +} // namespace test +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/test/arm_common/group_local.cpp b/dnn/test/arm_common/group_local.cpp new file mode 100644 index 00000000..8361f3fd --- /dev/null +++ b/dnn/test/arm_common/group_local.cpp @@ -0,0 +1,45 @@ +/** + * \file dnn/test/arm_common/group_local.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "test/arm_common/fixture.h" + +#include "test/common/benchmarker.h" +#include "test/common/checker.h" +#include "test/common/group_local.h" +#include "test/common/timer.h" + +namespace megdnn { +namespace test { +using Param = param::Convolution; + +TEST_F(ARM_COMMON, GROUP_LOCAL_FORWARD) { + auto args = group_local::get_args(); + Checker<GroupLocalForward> checker(handle()); + for (auto&& arg : args) { + checker.set_param(arg.param).execs( + {arg.sshape(), arg.fshape(), arg.dshape()}); + } + + NormalRNG rng(10.f); + checker.set_rng(0, &rng).set_rng(1, &rng); + args = group_local::get_args_for_fp16(); + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + for (auto&& arg : args) { + checker.set_dtype(0, dtype::Float16()).set_dtype(1, dtype::Float16()).set_dtype(2, dtype::Float16()); + checker.set_epsilon(1e-2); + checker.set_param(arg.param).execs( + {arg.sshape(), arg.fshape(), arg.dshape()}); + } +#endif +} +} // namespace test +} // namespace megdnn +// vim: syntax=cpp.doxygen diff --git a/dnn/test/arm_common/local.cpp b/dnn/test/arm_common/local.cpp new file mode 100644 index 00000000..9b22445a --- /dev/null +++ b/dnn/test/arm_common/local.cpp @@ -0,0 +1,124 @@ +/** + * \file dnn/test/arm_common/local.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */ +#include "test/arm_common/fixture.h" + +#include "test/common/benchmarker.h" +#include "test/common/checker.h" +#include "test/common/local.h" +#include "test/common/timer.h" + +namespace megdnn { +namespace test { +using Param = param::Convolution; + +TEST_F(ARM_COMMON, LOCAL_FORWARD) { + auto args = local::get_args(); + Checker checker(handle()); + for (auto&& arg : args) { + checker.set_param(arg.param).execs( + {arg.sshape(), arg.fshape(), arg.dshape()}); + } + + NormalRNG rng(10.f); + checker.set_rng(0, &rng).set_rng(1, &rng); + args = local::get_args_for_fp16(); + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + for (auto&& arg : args) { + checker.set_dtype(0, dtype::Float16()) + .set_dtype(1, dtype::Float16()) + .set_dtype(2, dtype::Float16()); + checker.set_epsilon(1e-2); + checker.set_param(arg.param).execs( + {arg.sshape(), arg.fshape(), arg.dshape()}); + } +#endif +} + +#if MEGDNN_WITH_BENCHMARK +TEST_F(ARM_COMMON, BENCHMARK_LOCAL_FORWARD) { + auto run = [&](const TensorShapeArray& shapes, Param param) { + Benchmarker benchmarker(handle()); + size_t RUN = 50; + benchmarker.set_dtype(0, dtype::Float32()) + .set_dtype(1, dtype::Float32()) + .set_dtype(2, dtype::Float32()); + auto tfloat32 = benchmarker.set_display(true) + .set_times(RUN) + .set_param(param) + .exec(shapes); + int N = shapes[0][0]; + int IC = shapes[0][1]; + int IH = shapes[0][2]; + int IW = shapes[0][3]; + int OH = shapes[1][0]; + int OW = shapes[1][1]; + int FH = shapes[1][3]; + int FW = shapes[1][4]; + int OC = shapes[1][5]; + std::cout << "LOCAL FORWARD, src: {" << N << ", " << IC << ", " << IH + << ", " << IW << "}" << std::endl; + std::cout << "LOCAL FORWARD, filter: {" << OH << ", " << OW << ", " + << IC << ", " << FH << ", " << FW << ", " << OC << "}" + << std::endl; + std::cout << "LOCAL FORWARD (f32), bandwidth: " + << (1.f * N * OC * OH * OW * FH * FW * IC + + 1.f * N * IC * IH * IW) * + sizeof(float) * 1e-9 / (tfloat32 / RUN * 1e-3) + << "GBPS" << std::endl; + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + benchmarker.set_dtype(0, dtype::Float16()) + .set_dtype(1, dtype::Float16()) + .set_dtype(2, dtype::Float16()); + auto tfloat16 = benchmarker.set_display(true) + .set_times(RUN) + .set_param(param) + .exec(shapes); + std::cout << "LOCAL FORWARD (f16), bandwidth: " + << (1.f * N * OC * OH * OW * FH * FW * IC + + 1.f * N * IC * IH * IW) * + sizeof(dt_float16) * 1e-9 / (tfloat16 / RUN * 1e-3) + << "GBPS" << std::endl; +#endif + }; + + Param param; + param.mode = param::Convolution::Mode::CONVOLUTION; + param.pad_h = param.pad_w = 1; + param.stride_h = param.stride_w = 1; + run({{1, 4, 320, 256}, {320, 256, 4, 3, 3, 24}, {}}, param); + param.stride_h = param.stride_w = 2; + run({{1, 4, 320, 256}, {160, 128, 4, 3, 3, 24}, {}}, param); + + param.pad_h = param.pad_w = 2; + param.stride_h = param.stride_w = 1; + run({{1, 4, 64, 64}, {64, 64, 4, 5, 5, 24}, {}}, param); + param.stride_h = param.stride_w = 2; + run({{1, 4, 64, 64}, {32, 32, 4, 5, 5, 24}, {}}, param); + + param.pad_h = param.pad_w = 3; + param.stride_h = param.stride_w = 1; + run({{1, 4, 64, 64}, {64, 64, 4, 7, 7, 24}, {}}, param); + param.stride_h = param.stride_w = 2; + run({{1, 4, 64, 64}, {32, 32, 4, 7, 7, 24}, {}}, param); + + param.pad_h = param.pad_w = 1; + param.stride_h = param.stride_w = 1; + run({{2, 128, 8, 8}, {8, 8, 128, 3, 3, 128}, {}}, param); + run({{1, 16, 64, 64}, {64, 64, 16, 3, 3, 16}, {}}, param); +} +#endif + +} // namespace test +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git 
a/dnn/test/arm_common/matrix_mul.cpp b/dnn/test/arm_common/matrix_mul.cpp new file mode 100644 index 00000000..26874e1b --- /dev/null +++ b/dnn/test/arm_common/matrix_mul.cpp @@ -0,0 +1,368 @@ +/** + * \file dnn/test/arm_common/matrix_mul.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "test/arm_common/fixture.h" + +#include "test/common/benchmarker.h" +#include "test/common/checker.h" +#include "test/common/matrix_mul.h" +#include "test/common/rng.h" + +using namespace megdnn; +using namespace test; + +TEST_F(ARM_COMMON, MATRIX_MUL_INT8x8x32) { + matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{}, + handle()); +} + +TEST_F(ARM_COMMON, MATRIX_MUL_INT8x8x16) { + matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{}, + handle()); +} + +TEST_F(ARM_COMMON, MATRIX_MUL_QUINT8) { + matrix_mul::check_matrix_mul(dtype::Quantized8Asymm(1.2f, (uint8_t)127), + dtype::Quantized8Asymm(1.3f, (uint8_t)129), + {}, + handle()); +} + +TEST_F(ARM_COMMON, MATRIX_MUL_FP32) { + Checker checker(handle()); + using Param = MatrixMul::Param; + + auto run = [&](size_t M, size_t K, size_t N) { + Param param; + param.transposeA = false; + param.transposeB = false; + TensorShape A, B; + A = TensorShape{M, K}; + B = TensorShape{K, N}; + checker.set_param(param) + .set_dtype(0, dtype::Float32()) + .set_dtype(1, dtype::Float32()) + .set_dtype(2, dtype::Float32()) + .execs({A, B, {}}); + }; + + checker.set_before_exec_callback( + AlgoChecker("ARM_COMMON_F32_GEMV")); + // M < 8 + for (size_t M : {1, 2, 3, 4, 5, 6, 7}) + for (size_t K : {7, 1024, 2048}) + for (size_t N : {7, 1024, 2056}) + run(M, K, N); + // M = 8,K = 1, 2 + for (size_t M : {8}) + for (size_t K : {1, 2}) + for (size_t N : {7, 1024, 2056}) + run(M, K, N); + // N = 1 + for (size_t M : {1, 2, 3, 4, 5, 6, 7}) + for (size_t K : {7, 1024, 2048}) + for (size_t N : {1}) + run(M, K, N); +} +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +TEST_F(ARM_COMMON, MATRIX_MUL_FP16) { + Checker checker(handle()); + checker.set_epsilon(1e-2); + NormalRNG rng(2.f); + checker.set_rng(0, &rng).set_rng(1, &rng); + + using Param = MatrixMul::Param; + auto args = matrix_mul::get_matmul_args_no_mask(); + + for (auto& arg : args) { + size_t m = arg.m, n = arg.n, k = arg.k; + Param param; + param.transposeA = false; + param.transposeB = false; + TensorShape A, B; + A = TensorShape{m, k}; + B = TensorShape{k, n}; + checker.set_param(param) + .set_dtype(0, dtype::Float16()) + .set_dtype(1, dtype::Float16()) + .set_dtype(2, dtype::Float16()) + .execs({A, B, {}}); + } +} +TEST_F(ARM_COMMON, MATRIX_MUL_FP16_TEST) { + Checker checker(handle()); + using Param = MatrixMul::Param; + checker.set_epsilon(1e-2); + NormalRNG rng(2.f); + checker.set_rng(0, &rng).set_rng(1, &rng); + + auto run = [&](size_t M, size_t K, size_t N) { + Param param; + param.transposeA = false; + param.transposeB = false; + TensorShape A, B; + A = TensorShape{M, K}; + B = TensorShape{K, N}; + checker.set_param(param) + .set_dtype(0, dtype::Float16()) + .set_dtype(1, dtype::Float16()) + .set_dtype(2, dtype::Float16()) + .execs({A, B, {}}); + }; + checker.set_before_exec_callback( + AlgoChecker("ARM_COMMON_F16_GEMV")); + + // M = 1, 2, 3, 4 + for (size_t M 
: {1, 2, 3, 4}) + for (size_t K : {7, 512, 1024}) + for (size_t N : {13, 1024, 2048}) + run(M, K, N); + // N = 1 + for (size_t M : {1, 2, 3, 4}) + for (size_t K : {7, 512, 1024}) + for (size_t N : {1}) + run(M, K, N); +} +#endif + + +#if MEGDNN_WITH_BENCHMARK + +TEST_F(ARM_COMMON, BENCHMARK_SGEMV) { + int exec_times = 10; + Benchmarker benchmarker(handle()); + benchmarker.set_times(exec_times); + + auto run = [&](size_t M, size_t K, size_t N) { + std::cout << "SGEMV: (" << M << ", " << K << ", " << N << ")" + << std::endl; + benchmarker.set_dtype(0, dtype::Float32()) + .set_dtype(1, dtype::Float32()); + auto time = benchmarker.exec({{M, K}, {K, N}, {}}) / exec_times; + auto computations = 2.f * M * K * N * 1e-6; + auto perf = computations / time; + std::cout << "gemv fp32, Performance is " << perf << " Gflops" + << std::endl; + }; + + std::cout << "warm up:\n"; + for (int i = 0; i < 50; i++) { + benchmarker.set_dtype(0, dtype::Float32()) + .set_dtype(1, dtype::Float32()) + .set_display(false) + .exec({{2, 1024}, {1024, 512}, {}}); + benchmarker.set_display(true); + } + + // run gemv + for (size_t M : {1, 2, 4, 8}) + for (size_t K : {1024, 1536, 2048}) + for (size_t N : {512, 1024}) + run(M, K, N); +} +TEST_F(ARM_COMMON, BENCHMARK_SGEMV_FP16) { + int exec_times = 50; + Benchmarker benchmarker(handle()); + benchmarker.set_times(exec_times); + benchmarker.set_before_exec_callback( + AlgoChecker("ARM_COMMON_F16_GEMV")); + + auto run = [&](size_t M, size_t K, size_t N) { + std::cout << "SGEMV: (" << M << ", " << K << ", " << N << ")" + << std::endl; + benchmarker.set_dtype(0, dtype::Float16()) + .set_dtype(1, dtype::Float16()) + .set_dtype(2, dtype::Float16()); + auto time = benchmarker.exec({{M, K}, {K, N}, {}}) / exec_times; + auto computations = 2 * M * K * N * 1e-6; + auto perf = computations / time; + std::cout << "gemv fp16, Performance is " << perf << " Gflops" + << std::endl; + }; + + std::cout << "warm up:\n"; + for (int i = 0; i < 50; i++) { + benchmarker.set_dtype(0, dtype::Float16()) + .set_dtype(1, dtype::Float16()) + .set_dtype(2, dtype::Float16()) + .set_display(false) + .exec({{2, 1024}, {1024, 512}, {}}); + benchmarker.set_display(true); + } + + // run gemv + for (size_t M : {1, 2, 3, 4}) + for (size_t K : {1024, 1536, 2048}) + for (size_t N : {512, 1024}) + run(M, K, N); +} +TEST_F(ARM_COMMON, BENCHMARK_SGEMM) { + int exec_times = 10; + Benchmarker benchmarker(handle()); + benchmarker.set_times(exec_times); + + float mod = 1000 * exec_times / 1e9; + auto run = [&](size_t M, size_t K, size_t N) { + float time = 1.f, perf = 1.f; + std::cout << "SGEMM: (" << M << ", " << K << ", " << N << ")" + << std::endl; + benchmarker.set_dtype(0, dtype::Float32()) + .set_dtype(1, dtype::Float32()); + time = benchmarker.exec({{M, K}, {K, N}, {}}); + perf = 2.f * M * K * N / time * mod; + std::cout << "gemm fp32, Performance is " << perf << " Gflops" + << std::endl; + }; + + std::cout << "warm up:\n"; + for (int i = 0; i < 50; i++) { + benchmarker.set_dtype(0, dtype::Float32()) + .set_dtype(1, dtype::Float32()) + .set_display(false) + .exec({{2, 1024}, {1024, 512}, {}}); + benchmarker.set_display(true); + } + + run(256, 12 * 24, 256); + + //////////////////////// gemv ////////////////////////// + for (size_t M : {8, 64, 112, 256}) { + for (size_t K : {8, 64, 112, 256}) { + run (M, 1, K); + } + } + + //////////////////////// gemm ////////////////////////// + for (size_t M : {8, 64, 112, 256}) { + for (size_t K : {8, 16, 32, 64, 112, 256}) { + for (size_t N : {8, 64, 112, 256}) { + run(M, N, K); 
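+                // Note on the perf figure printed by run() in this benchmark:
+                // perf = 2*M*K*N / time * mod with mod = 1000 * exec_times / 1e9,
+                // i.e. the FLOP count of an (M,K)x(K,N) matmul divided by the
+                // average per-run time. As a rough worked example (the timing
+                // is assumed, not measured): for M = N = K = 256, 2*M*K*N is
+                // 33,554,432 FLOPs, so an average run time of 1 ms would be
+                // reported as roughly 33.5 GFLOPS.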
+ } + } + } + +} + + +TEST_F(ARM_COMMON, BENCHMARK_MATRIX_MUL_INT8x8x32) { + constexpr size_t RUNS = 50; + param::MatrixMul param; + Benchmarker benchmarker_int(handle()); + benchmarker_int.set_times(RUNS) + .set_dtype(0, dtype::Int8{}) + .set_dtype(1, dtype::Int8{}) + .set_dtype(2, dtype::Int32{}) + .set_param(param).set_display(false); + Benchmarker benchmarker_float(handle()); + benchmarker_float.set_display(false).set_times(RUNS); + + auto run = [&](size_t M, size_t N, size_t K) { + auto int_used = benchmarker_int.exec({{M, K}, {K, N}, {}}) / RUNS; + auto float_used = benchmarker_float.exec({{M, K}, {K, N}, {}}) / RUNS; + float computations = 2.f * M * K * N * 1e-6; + printf("run: {%zu{M} %zu{K} %zu{N}} float: %f ms %f Gflops int: %f ms " + "%f Gflops speedup: %f\n", + M, K, N, float_used, computations / float_used, int_used, + computations / int_used, float_used / int_used); + }; + + run(256, 12 * 24, 256); + + //////////////////////// gemv ////////////////////////// + for (size_t M : {8, 64, 112, 256}) { + for (size_t K : {8, 64, 112, 256}) { + run (M, 1, K); + } + } + + //////////////////////// gemm ////////////////////////// + for (size_t M : {8, 64, 112, 256}) { + for (size_t K : {8, 16, 32, 64, 112, 256}) { + for (size_t N : {8, 64, 112, 256}) { + run(M, N, K); + } + } + } +} + +TEST_F(ARM_COMMON, BENCHMARK_MATRIX_MUL_QUINT8) { + constexpr size_t RUNS = 50; + param::MatrixMul param; + Benchmarker benchmarker_int(handle()); + benchmarker_int.set_times(RUNS) + .set_dtype(0, dtype::Quantized8Asymm(1.2f, (uint8_t)127)) + .set_dtype(1, dtype::Quantized8Asymm(1.3f, (uint8_t)129)) + .set_dtype(2, {}) + .set_param(param) + .set_display(false); + Benchmarker benchmarker_float(handle()); + benchmarker_float.set_display(false).set_times(RUNS); + + auto run = [&](size_t M, size_t N, size_t K) { + auto int_used = benchmarker_int.exec({{M, K}, {K, N}, {}}) / RUNS; + auto float_used = benchmarker_float.exec({{M, K}, {K, N}, {}}) / RUNS; + float computations = 2.f * M * K * N * 1e-6; + printf("run: {%zu{M} %zu{K} %zu{N}} float: %f ms %f Gflops int: %f ms " + "%f Gflops speedup: %f\n", + M, K, N, float_used, computations / float_used, int_used, + computations / int_used, float_used / int_used); + }; + + run(256, 12 * 24, 256); + + for (size_t M : {8, 64, 112, 256}) { + for (size_t K : {8, 64, 112, 256}) { + for (size_t N : {8, 64, 112, 256}) { + run(M, N, K); + } + } + } +} + +TEST_F(ARM_COMMON, BENCHMARK_TRANSPOSED_MATRIX_MUL_QUINT8) { + constexpr size_t RUNS = 50; + param::MatrixMul param; + param.transposeA = param.transposeB = true; + Benchmarker benchmarker_int(handle()); + benchmarker_int.set_times(RUNS) + .set_dtype(0, dtype::Quantized8Asymm(1.2f, (uint8_t)127)) + .set_dtype(1, dtype::Quantized8Asymm(1.3f, (uint8_t)129)) + .set_dtype(2, {}) + .set_param(param) + .set_display(false); + Benchmarker benchmarker_float(handle()); + benchmarker_float.set_param(param).set_display(false).set_times(RUNS); + + auto run = [&](size_t M, size_t N, size_t K) { + auto int_used = benchmarker_int.exec({{K, M}, {N, K}, {}}) / RUNS; + auto float_used = benchmarker_float.exec({{K, M}, {N, K}, {}}) / RUNS; + float computations = 2.f * M * K * N * 1e-6; + printf("run: {%zu{M} %zu{K} %zu{N}} float: %f ms %f Gflops int: %f ms " + "%f Gflops speedup: %f\n", + M, K, N, float_used, computations / float_used, int_used, + computations / int_used, float_used / int_used); + }; + + run(256, 12 * 24, 256); + + for (size_t M : {8, 64, 112, 256}) { + for (size_t K : {8, 64, 112, 256}) { + for (size_t N : {8, 64, 112, 256}) 
{ + run(M, N, K); + } + } + } +} + +#endif + + +// vim: syntax=cpp.doxygen diff --git a/dnn/test/arm_common/pooling.cpp b/dnn/test/arm_common/pooling.cpp new file mode 100644 index 00000000..3b9bf576 --- /dev/null +++ b/dnn/test/arm_common/pooling.cpp @@ -0,0 +1,479 @@ +/** + * \file dnn/test/arm_common/pooling.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "test/arm_common/fixture.h" + +#include "test/common/pooling.h" +#include "test/common/checker.h" +#include "test/common/benchmarker.h" +#include "test/common/rng.h" + +namespace megdnn { +namespace test { + +TEST_F(ARM_COMMON, POOLING) +{ + using Param = param::Pooling; + // clang-format off + for (size_t ih: {2, 3, 5, 7, 11, 13, 17, 19, 23, 24, 25, 26, 27, 28, 29, 30}) + for (size_t iw: {2, 3, 5, 7, 11, 13, 17, 19, 23, 24, 25, 26, 27, 28, 29, 30}) + for (size_t p: {1, 2}) + { + Param param; + param.mode = Param::Mode::MAX; + param.window_h = param.window_w = 3; + param.stride_h = param.stride_w = 2; + param.pad_h = param.pad_w = p; + Checker checker(handle()); + checker.set_param(param).exec({{2, 3, ih, iw}, {}}); + + param.mode = Param::Mode::AVERAGE; + param.window_h = param.window_w = 3; + param.stride_h = param.stride_w = 2; + param.pad_h = param.pad_w = p; + checker.set_param(param).exec({{2, 3, ih, iw}, {}}); + + param.mode = Param::Mode::MAX; + param.window_h = param.window_w = 4; + param.stride_h = param.stride_w = 2; + param.pad_h = param.pad_w = p; + checker.set_param(param).exec({{2, 3, ih, iw}, {}}); + + param.mode = Param::Mode::MAX; + param.window_h = param.window_w = 5; + param.stride_h = param.stride_w = 2; + param.pad_h = param.pad_w = p; + if (ih + p * 2 >= 5 && iw + p * 2 >= 5) + checker.set_param(param).exec({{2, 3, ih, iw}, {}}); + } + // clang-format on +} + +TEST_F(ARM_COMMON, POOLING_INT8_W2x2_S2x2) +{ + // clang-format off + for (size_t ih: {2, 3, 7, 13, 52, 53, 54, 55}) + for (size_t iw: {2, 3, 6, 14, 53, 54, 55, 56}) + for (size_t ph: {0, 1}) + for (size_t pw: {0, 1}) + if (ih+2*ph >= 3 && iw+2*pw >= 3) + { + Checker checker(handle()); + checker.set_dtype(0, dtype::Int8()); + param::Pooling param; + param.mode = param::Pooling::Mode::MAX; + param.pad_h = ph; + param.pad_w = pw; + param.stride_h = param.stride_w = 2; + param.window_h = param.window_w = 2; + checker.set_param(param).exec(TensorShapeArray{{2, 3, ih, iw}, {}}); + } + // clang-format on +} + +TEST_F(ARM_COMMON, POOLING_INT8_W3x3_S2x2) +{ + // clang-format off + for (size_t ih: {2, 3, 7, 13, 52, 53, 54, 55}) + for (size_t iw: {2, 3, 6, 14, 53, 54, 55, 56}) + for (size_t ph: {0, 1, 2}) + for (size_t pw: {0, 1, 2}) + if (ih+2*ph >= 3 && iw+2*pw >= 3) + { + Checker checker(handle()); + checker.set_dtype(0, dtype::Int8()); + param::Pooling param; + param.mode = param::Pooling::Mode::MAX; + param.pad_h = ph; + param.pad_w = pw; + param.stride_h = param.stride_w = 2; + param.window_h = param.window_w = 3; + checker.set_param(param).exec(TensorShapeArray{{2, 3, ih, iw}, {}}); + } + // clang-format on +} + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +TEST_F(ARM_COMMON, POOLING_FP16) { + Checker checker(handle()); + checker.set_dtype(0, dtype::Float16{}) + .set_dtype(1, dtype::Float16{}) + .set_epsilon(3e-3); + + using Param 
= param::Pooling; + for (size_t ih : {2, 3, 5, 7, 11, 13, 17, 19, 23}) + for (size_t iw : {2, 3, 5, 7, 11, 13, 17, 19, 23}) + for (auto mode : {Param::Mode::AVERAGE, Param::Mode::MAX}) { + for (size_t window : {2, 3}) { + Param param; + param.mode = mode; + param.window_h = param.window_w = window; + param.stride_h = param.stride_w = 1; + param.pad_h = param.pad_w = window / 2; + //! test for SH == 1 && SW == 1 && FH == FW (FH == 2 || FH + //! == 3) + checker.set_param(param).exec({{2, 3, ih, iw}, {}}); + + //! test for SH = SW = 2 && FH = FW = 2 + param.stride_h = param.stride_w = 2; + checker.set_param(param).exec({{2, 3, ih, iw}, {}}); + } + } + + //! test for SH == 2 && SW == 2 && FH == FW == 3 max pooling + for (size_t ih : {2, 3, 7, 13, 52, 53, 54, 55}) + for (size_t iw : {2, 3, 6, 14, 53, 54, 55, 56}) + for (size_t ph : {0, 1, 2}) + for (size_t pw : {0, 1, 2}) + if (ih + 2 * ph >= 3 && iw + 2 * pw >= 3) { + param::Pooling param; + param.mode = param::Pooling::Mode::MAX; + param.pad_h = ph; + param.pad_w = pw; + param.stride_h = param.stride_w = 2; + param.window_h = param.window_w = 3; + checker.set_param(param).exec( + TensorShapeArray{{2, 3, ih, iw}, {}}); + } + + //! test for SH == 2 && SW == 2 && FH = FW = 4 max pooling + for (size_t ih : + {2, 3, 5, 7, 11, 13, 17, 19, 23, 24, 25, 26, 27, 28, 29, 30}) + for (size_t iw : + {2, 3, 5, 7, 11, 13, 17, 19, 23, 24, 25, 26, 27, 28, 29, 30}) + for (size_t p : {1, 2}) { + Param param; + param.mode = Param::Mode::MAX; + param.window_h = param.window_w = 4; + param.stride_h = param.stride_w = 2; + param.pad_h = param.pad_w = p; + checker.set_param(param).exec({{2, 3, ih, iw}, {}}); + } + + //! test for SH == 2 && SW == 2 && FH = FW = 5 max pooling + for (size_t ih : + {3, 5, 7, 11, 13, 17, 19, 23, 24, 25, 26, 27, 28, 29, 30}) + for (size_t iw : + {3, 5, 7, 11, 13, 17, 19, 23, 24, 25, 26, 27, 28, 29, 30}) + for (size_t p : {1, 2}) { + Param param; + param.mode = Param::Mode::MAX; + param.window_h = param.window_w = 5; + param.stride_h = param.stride_w = 2; + param.pad_h = param.pad_w = p; + checker.set_param(param).exec({{2, 3, ih, iw}, {}}); + } +} +#endif + +TEST_F(ARM_COMMON, POOLING_QUANTIZED) { + Checker checker(handle()); + UniformIntRNG rng1{INT8_MIN >> 1, INT8_MAX >> 1}; + UniformIntRNG rng2{0, UINT8_MAX >> 1}; + + using Param = param::Pooling; + + for (auto type : std::vector{ + dtype::QuantizedS8(1.1f), + dtype::Quantized8Asymm(1.1f, static_cast(3))}) { + if (type.enumv() == DTypeEnum::QuantizedS8) { + checker.set_rng(0, &rng1); + } else { + megdnn_assert(type.enumv() == DTypeEnum::Quantized8Asymm); + checker.set_rng(0, &rng2); + } + for (size_t ih : {2, 3, 5, 7, 11, 13, 17, 19, 23, 33, 49}) + for (size_t iw : {2, 3, 5, 7, 11, 13, 17, 19, 23, 33, 49}) + for (auto mode : {Param::Mode::AVERAGE, Param::Mode::MAX}) { + for (size_t window : {2, 3}) { + Param param; + param.mode = mode; + param.window_h = param.window_w = window; + param.stride_h = param.stride_w = 1; + param.pad_h = param.pad_w = window / 2; + //! test for SH == 1 && SW == 1 && FH == FW (FH == 2 || + //! FH + //! == 3) + checker.set_param(param).exec({{2, 3, ih, iw}, {}}); + + //! test for SH = SW = 2 && FH = FW = 2 + param.stride_h = param.stride_w = 2; + checker.set_param(param).exec({{2, 3, ih, iw}, {}}); + } + } + + //! 
test for SH == 2 && SW == 2 && FH == FW == 3 max pooling + for (size_t ih : {2, 3, 7, 13, 52, 53, 54, 55}) + for (size_t iw : {2, 3, 6, 14, 53, 54, 55, 56}) + for (size_t ph : {0, 1, 2}) + for (size_t pw : {0, 1, 2}) + if (ih + 2 * ph >= 3 && iw + 2 * pw >= 3) { + param::Pooling param; + param.mode = param::Pooling::Mode::MAX; + param.pad_h = ph; + param.pad_w = pw; + param.window_h = param.window_w = 3; + param.stride_h = param.stride_w = 2; + checker.set_param(param).exec( + TensorShapeArray{{2, 3, ih, iw}, {}}); + } + + //! test for SH == 2 && SW == 2 && FH == FW == 4 max pooling + for (size_t ih : + {2, 3, 5, 7, 11, 13, 17, 19, 23, 24, 25, 26, 27, 28, 29, 30}) + for (size_t iw : + {2, 3, 5, 7, 11, 13, 17, 19, 23, 24, 25, 26, 27, 28, 29, 30}) + for (size_t p : {1, 2}) { + Param param; + param.mode = Param::Mode::MAX; + param.window_h = param.window_w = 4; + param.stride_h = param.stride_w = 2; + param.pad_h = param.pad_w = p; + checker.set_param(param).exec({{2, 3, ih, iw}, {}}); + } + + //! test for SH == 2 && SW == 2 && FH == FW == 5 max pooling + for (size_t ih : + {3, 5, 7, 11, 13, 17, 19, 23, 24, 25, 26, 27, 28, 29, 30}) + for (size_t iw : + {3, 5, 7, 11, 13, 17, 19, 23, 24, 25, 26, 27, 28, 29, 30}) + for (size_t p : {1, 2}) { + Param param; + param.mode = Param::Mode::MAX; + param.window_h = param.window_w = 5; + param.stride_h = param.stride_w = 2; + param.pad_h = param.pad_w = p; + checker.set_param(param).exec({{2, 3, ih, iw}, {}}); + } + } +} + +#if MEGDNN_WITH_BENCHMARK +TEST_F(ARM_COMMON, BENCHMARK_POOLING_INT8_W3x3_S2x2) +{ + using Param = param::Pooling; + auto run = [&](const TensorShapeArray &shapes, + Param param) { + auto handle_naive = create_cpu_handle(2); + TensorLayoutArray layouts; + layouts.emplace_back(shapes[0], dtype::Int8()); + layouts.emplace_back(shapes[1], dtype::Int8()); + Benchmarker benchmarker_naive(handle_naive.get()); + Benchmarker benchmarker_float(handle()); + Benchmarker benchmarker_int(handle()); + size_t RUN = 10; + auto t1 = benchmarker_naive.set_display(false).set_times(RUN). + set_param(param).exec(shapes); + auto t2 = benchmarker_float.set_display(false).set_times(RUN). + set_param(param).exec(shapes); + auto t3 = benchmarker_int.set_display(false).set_times(RUN). + set_param(param).execl(layouts); + printf("naive=%.3fms float=%.3fms, int=%.3fms\n", + t1 / RUN, t2 / RUN, t3 / RUN); + auto speedup = t2/t3; + ASSERT_GE(speedup, 2.0); + }; + Param param; + param.window_h = param.window_w = 3; + param.stride_h = param.stride_w = 2; + param.pad_h = param.pad_w = 1; + std::cout << "3x3 with 2x2 stride max pooling:" << std::endl; + run({{1, 3, 640, 480}, {}}, param); +} + +TEST_F(ARM_COMMON, BENCHMARK_POOLING_W4x4_S2x2) +{ + using Param = param::Pooling; + auto run = [&](const TensorShapeArray &shapes, + Param param) { + std::cout << "N:" << shapes[0][0] << " " + << "IC:" << shapes[0][1] << " " + << "IH:" << shapes[0][2] << " " + << "IW:" << shapes[0][3] << std::endl; + auto handle_naive = create_cpu_handle(2); + Benchmarker benchmarker_naive(handle_naive.get()); + Benchmarker benchmarker_float(handle()); + size_t RUN = 10; + auto t1 = benchmarker_naive.set_display(false).set_times(RUN). + set_param(param).exec(shapes); + auto t2 = benchmarker_float.set_display(false).set_times(RUN). 
+ set_param(param).exec(shapes); + TensorLayout dst_layout; + auto opr = handle()->create_operator(); + opr->param() = param; + opr->deduce_layout({shapes[0], dtype::Float32()}, dst_layout); + float calc_amount = dst_layout.total_nr_elems() * + param.window_h * param.window_w; + printf("naive={%.3fms, %.3fMflops}, neon={%.3fms, %.3fMflops}\n", + t1 / RUN, calc_amount / (t1 / RUN * 1000), + t2 / RUN, calc_amount / (t2 / RUN * 1000)); + }; + Param param; + param.window_h = param.window_w = 4; + param.stride_h = param.stride_w = 2; + param.pad_h = param.pad_w = 1; + std::cout << "4x4 with 2x2 stride max pooling:" << std::endl; + run({{1, 24, 160, 128}, {}}, param); + run({{1, 4, 240, 135}, {}}, param); + run({{1, 32, 120, 67}, {}}, param); + run({{1, 64, 60, 33}, {}}, param); +} + +TEST_F(ARM_COMMON, BENCHMARK_POOLING_W5x5_S2x2) +{ + using Param = param::Pooling; + auto run = [&](const TensorShapeArray &shapes, + Param param) { + std::cout << "N:" << shapes[0][0] << " " + << "IC:" << shapes[0][1] << " " + << "IH:" << shapes[0][2] << " " + << "IW:" << shapes[0][3] << std::endl; + auto handle_naive = create_cpu_handle(2); + Benchmarker benchmarker_naive(handle_naive.get()); + Benchmarker benchmarker_float(handle()); + size_t RUN = 10; + auto t1 = benchmarker_naive.set_display(false).set_times(RUN). + set_param(param).exec(shapes); + auto t2 = benchmarker_float.set_display(false).set_times(RUN). + set_param(param).exec(shapes); + TensorLayout dst_layout; + auto opr = handle()->create_operator(); + opr->param() = param; + opr->deduce_layout({shapes[0], dtype::Float32()}, dst_layout); + float calc_amount = dst_layout.total_nr_elems() * + param.window_h * param.window_w; + printf("naive={%.3fms, %.3fMflops}, neon={%.3fms, %.3fMflops}\n", + t1 / RUN, calc_amount / (t1 / RUN * 1000), + t2 / RUN, calc_amount / (t2 / RUN * 1000)); + }; + Param param; + param.window_h = param.window_w = 5; + param.stride_h = param.stride_w = 2; + param.pad_h = param.pad_w = 1; + std::cout << "5x5 with 2x2 stride max pooling:" << std::endl; + run({{1, 24, 160, 128}, {}}, param); + run({{1, 4, 240, 135}, {}}, param); + run({{1, 32, 120, 67}, {}}, param); + run({{1, 64, 60, 33}, {}}, param); +} + + +TEST_F(ARM_COMMON, BENCHMARK_POOLING_FP16) { + using Param = param::Pooling; + auto run = [&](const TensorShapeArray& shapes, Param param) { + TensorLayoutArray layouts; + layouts.emplace_back(shapes[0], dtype::Float16()); + layouts.emplace_back(shapes[1], dtype::Float16()); + Benchmarker benchmarker_float(handle()); + Benchmarker benchmarker_half(handle()); + size_t RUN = 10; + auto tf = benchmarker_float.set_display(false) + .set_times(RUN) + .set_param(param) + .exec(shapes) / + RUN; + auto th = benchmarker_half.set_display(false) + .set_times(RUN) + .set_param(param) + .execl(layouts) / + RUN; + TensorLayout dst_layout; + auto opr = handle()->create_operator(); + opr->param() = param; + opr->deduce_layout({shapes[0], dtype::Float32()}, dst_layout); + + float computations = dst_layout.total_nr_elems() * param.window_h * + param.window_w / (1024.f * 1024 * 1024); + printf("float=%.3fms %f gflops, float16=%.3fms %f gflops speedup: %f\n", + tf, computations / tf * 1e3, th, computations / th * 1e3, + tf / th); + }; + Param param; + param.window_h = param.window_w = 2; + param.stride_h = param.stride_w = 1; + param.pad_h = param.pad_w = 1; + printf("2x2 with 1x1 stride max pooling:\n"); + run({{1, 3, 640, 480}, {}}, param); + + for (size_t oh : {640, 128}) + for (size_t ow : {480, 112}) { + param.window_h = param.window_w = 3; + 
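+            // The gflops figure printed by run() above counts window_h *
+            // window_w operations per element of the deduced output layout.
+            // As an illustration (derived from the shapes used here, not from
+            // any measurement): a 1x3x640x480 input with the 3x3 window,
+            // stride-2, pad-1 configuration set up in this loop gives a
+            // 1x3x320x240 output, i.e. 230,400 elements and about 2.07 M ops
+            // per forward pass, on which the fp32 and fp16 timings are compared.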
param.stride_h = param.stride_w = 2; + param.pad_h = param.pad_w = 1; + param.mode = Param::Mode::AVERAGE; + printf("3x3 with 2x2 stride average pooling.\n"); + run({{1, 3, oh, ow}, {}}, param); + + for (size_t pw : {2, 3, 4, 5}) { + param.window_h = param.window_w = pw; + param.stride_h = param.stride_w = 2; + param.pad_h = param.pad_w = 1; + param.mode = Param::Mode::MAX; + printf("%zux%zu with 2x2 stride max pooling:\n", pw, pw); + run({{1, 3, oh, ow}, {}}, param); + } + } +} + +TEST_F(ARM_COMMON, BENCHMARK_POOLING_QUANTIZED) { + using Param = param::Pooling; + auto run = [&](const TensorShapeArray& shapes, Param param) { + auto handle_naive = create_cpu_handle(2); + TensorLayoutArray layouts; + layouts.emplace_back(shapes[0], dtype::QuantizedS8(1.1f)); + layouts.emplace_back(shapes[1], dtype::QuantizedS8(1.1f)); + Benchmarker benchmarker_int(handle()); + Benchmarker benchmarker_naive(handle_naive.get()); + size_t RUN = 10; + auto time_int = benchmarker_int.set_display(false) + .set_times(RUN) + .set_param(param) + .exec(shapes) / + RUN; + auto time_naive = benchmarker_naive.set_display(false) + .set_times(RUN) + .set_param(param) + .execl(layouts) / + RUN; + TensorLayout dst_layout; + auto opr = handle()->create_operator(); + opr->param() = param; + opr->deduce_layout({shapes[0], dtype::QuantizedS8(1.1f)}, dst_layout); + + float computations = dst_layout.total_nr_elems() * param.window_h * + param.window_w / (1024.f * 1024 * 1024); + printf("naive=%.3fms %f gflops, int8=%.3fms %f gflops speedup: %f\n", + time_naive, computations / time_naive * 1e3, time_int, + computations / time_int * 1e3, time_naive / time_int); + }; + Param param; + param.window_h = param.window_w = 2; + param.stride_h = param.stride_w = 1; + param.pad_h = param.pad_w = 1; + printf("2x2 with 1x1 stride max pooling:\n"); + run({{1, 3, 640, 480}, {}}, param); + + // clang-format off + for (size_t oh : {640, 128}) + for (size_t ow : {480, 112}) + for (size_t pw : {2, 3, 4, 5}) { + param.window_h = param.window_w = pw; + param.stride_h = param.stride_w = 2; + param.pad_h = param.pad_w = 1; + printf("%zux%zu with 2x2 stride max pooling:\n", pw, pw); + run({{1, 3, oh, ow}, {}}, param); + } + // clang-format on +} +#endif + +} // namespace test +} // namespace megdnn + // vim: syntax=cpp.doxygen diff --git a/dnn/test/arm_common/pooling_multi_thread.cpp b/dnn/test/arm_common/pooling_multi_thread.cpp new file mode 100644 index 00000000..d9fc9224 --- /dev/null +++ b/dnn/test/arm_common/pooling_multi_thread.cpp @@ -0,0 +1,351 @@ +/** + * \file dnn/test/arm_common/pooling_multi_thread.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +#include "test/arm_common/fixture.h" + +#include "test/common/pooling.h" +#include "test/common/checker.h" +#include "test/common/benchmarker.h" +#include "test/common/rng.h" + +namespace megdnn { +namespace test { + +/*********************** mutli threads *********************************/ +TEST_F(ARM_COMMON_MULTI_THREADS, POOLING) { + using Param = param::Pooling; + for (size_t ih: {2, 3, 5, 7, 11, 13, 17, 19, 23, 24, 25, 26, 27, 28, 29, 30}) + for (size_t iw: {2, 3, 5, 7, 11, 13, 17, 19, 23, 24, 25, 26, 27, 28, 29, 30}) + for (size_t p: {1, 2}) + { + Param param; + param.mode = Param::Mode::MAX; + param.window_h = param.window_w = 3; + param.stride_h = param.stride_w = 2; + param.pad_h = param.pad_w = p; + Checker checker(handle()); + checker.set_param(param).exec({{2, 3, ih, iw}, {}}); + + param.mode = Param::Mode::AVERAGE; + param.window_h = param.window_w = 3; + param.stride_h = param.stride_w = 2; + param.pad_h = param.pad_w = p; + checker.set_param(param).exec({{2, 3, ih, iw}, {}}); + + param.mode = Param::Mode::MAX; + param.window_h = param.window_w = 4; + param.stride_h = param.stride_w = 2; + param.pad_h = param.pad_w = p; + checker.set_param(param).exec({{2, 3, ih, iw}, {}}); + + param.mode = Param::Mode::MAX; + param.window_h = param.window_w = 5; + param.stride_h = param.stride_w = 2; + param.pad_h = param.pad_w = p; + if (ih + p * 2 >= 5 && iw + p * 2 >= 5) + checker.set_param(param).exec({{2, 3, ih, iw}, {}}); + } +} + +TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_INT8_W3x3_S2x2) +{ + for (size_t ih: {2, 3, 7, 13, 52, 53, 54, 55}) + for (size_t iw: {2, 3, 6, 14, 53, 54, 55, 56}) + for (size_t ph: {0, 1, 2}) + for (size_t pw: {0, 1, 2}) + if (ih+2*ph >= 3 && iw+2*pw >= 3) + { + Checker checker(handle()); + checker.set_dtype(0, dtype::Int8()); + param::Pooling param; + param.mode = param::Pooling::Mode::MAX; + param.pad_h = ph; + param.pad_w = pw; + param.stride_h = param.stride_w = 2; + param.window_h = param.window_w = 3; + checker.set_param(param).exec(TensorShapeArray{ + {2, 3, ih, iw}, {}}); + } +} + +TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_INT8_W2x2_S2x2) +{ + for (size_t ih: {2, 3, 7, 13, 52, 53, 54, 55}) + for (size_t iw: {2, 3, 6, 14, 53, 54, 55, 56}) + for (size_t ph: {0, 1}) + for (size_t pw: {0, 1}) + if (ih+2*ph >= 3 && iw+2*pw >= 3) + { + Checker checker(handle()); + checker.set_dtype(0, dtype::Int8()); + param::Pooling param; + param.mode = param::Pooling::Mode::MAX; + param.pad_h = ph; + param.pad_w = pw; + param.stride_h = param.stride_w = 2; + param.window_h = param.window_w = 2; + checker.set_param(param).exec(TensorShapeArray{ + {2, 3, ih, iw}, {}}); + } +} + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_FP16) { + Checker checker(handle()); + checker.set_dtype(0, dtype::Float16{}) + .set_dtype(1, dtype::Float16{}) + .set_epsilon(3e-3); + + using Param = param::Pooling; + for (size_t ih : {2, 3, 5, 7, 11, 13, 17, 19, 23}) + for (size_t iw : {2, 3, 5, 7, 11, 13, 17, 19, 23}) + for (auto mode : {Param::Mode::AVERAGE, Param::Mode::MAX}) { + for (size_t window : {2, 3}) { + Param param; + param.mode = mode; + param.window_h = param.window_w = window; + param.stride_h = param.stride_w = 1; + param.pad_h = param.pad_w = window / 2; + //! test for SH == 1 && SW == 1 && FH == FW (FH == 2 || FH + //! == 3) + checker.set_param(param).exec({{2, 3, ih, iw}, {}}); + + //! test for SH = SW = 2 && FH = FW = 2 + param.stride_h = param.stride_w = 2; + checker.set_param(param).exec({{2, 3, ih, iw}, {}}); + } + } + + //! 
test for SH == 2 && SW == 2 && FH == FW == 3 max pooling + for (size_t ih : {2, 3, 7, 13, 52, 53, 54, 55}) + for (size_t iw : {2, 3, 6, 14, 53, 54, 55, 56}) + for (size_t ph : {0, 1, 2}) + for (size_t pw : {0, 1, 2}) + if (ih + 2 * ph >= 3 && iw + 2 * pw >= 3) { + param::Pooling param; + param.mode = param::Pooling::Mode::MAX; + param.pad_h = ph; + param.pad_w = pw; + param.stride_h = param.stride_w = 2; + param.window_h = param.window_w = 3; + checker.set_param(param).exec( + TensorShapeArray{{2, 3, ih, iw}, {}}); + } + + //! test for SH == 2 && SW == 2 && FH = FW = 4 max pooling + for (size_t ih : + {2, 3, 5, 7, 11, 13, 17, 19, 23, 24, 25, 26, 27, 28, 29, 30}) + for (size_t iw : + {2, 3, 5, 7, 11, 13, 17, 19, 23, 24, 25, 26, 27, 28, 29, 30}) + for (size_t p : {1, 2}) { + Param param; + param.mode = Param::Mode::MAX; + param.window_h = param.window_w = 4; + param.stride_h = param.stride_w = 2; + param.pad_h = param.pad_w = p; + checker.set_param(param).exec({{2, 3, ih, iw}, {}}); + } + + //! test for SH == 2 && SW == 2 && FH = FW = 5 max pooling + for (size_t ih : + {3, 5, 7, 11, 13, 17, 19, 23, 24, 25, 26, 27, 28, 29, 30}) + for (size_t iw : + {3, 5, 7, 11, 13, 17, 19, 23, 24, 25, 26, 27, 28, 29, 30}) + for (size_t p : {1, 2}) { + Param param; + param.mode = Param::Mode::MAX; + param.window_h = param.window_w = 5; + param.stride_h = param.stride_w = 2; + param.pad_h = param.pad_w = p; + checker.set_param(param).exec({{2, 3, ih, iw}, {}}); + } +} +#endif + +TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_QUANTIZED) { + Checker checker(handle()); + UniformIntRNG rng1{INT8_MIN >> 1, INT8_MAX >> 1}; + UniformIntRNG rng2{0, UINT8_MAX >> 1}; + + using Param = param::Pooling; + + for (auto type : std::vector{ + dtype::QuantizedS8(1.1f), + dtype::Quantized8Asymm(1.1f, static_cast(3))}) { + if (type.enumv() == DTypeEnum::QuantizedS8) { + checker.set_rng(0, &rng1); + } else { + megdnn_assert(type.enumv() == DTypeEnum::Quantized8Asymm); + checker.set_rng(0, &rng2); + } + for (size_t ih : {2, 3, 5, 7, 11, 13, 17, 19, 23, 33, 49}) + for (size_t iw : {2, 3, 5, 7, 11, 13, 17, 19, 23, 33, 49}) + for (auto mode : {Param::Mode::AVERAGE, Param::Mode::MAX}) { + for (size_t window : {2, 3}) { + Param param; + param.mode = mode; + param.window_h = param.window_w = window; + param.stride_h = param.stride_w = 1; + param.pad_h = param.pad_w = window / 2; + //! test for SH == 1 && SW == 1 && FH == FW (FH == 2 || + //! FH + //! == 3) + checker.set_param(param).exec({{2, 3, ih, iw}, {}}); + + //! test for SH = SW = 2 && FH = FW = 2 + param.stride_h = param.stride_w = 2; + checker.set_param(param).exec({{2, 3, ih, iw}, {}}); + } + } + + //! test for SH == 2 && SW == 2 && FH == FW == 3 max pooling + for (size_t ih : {2, 3, 7, 13, 52, 53, 54, 55}) + for (size_t iw : {2, 3, 6, 14, 53, 54, 55, 56}) + for (size_t ph : {0, 1, 2}) + for (size_t pw : {0, 1, 2}) + if (ih + 2 * ph >= 3 && iw + 2 * pw >= 3) { + param::Pooling param; + param.mode = param::Pooling::Mode::MAX; + param.pad_h = ph; + param.pad_w = pw; + param.window_h = param.window_w = 3; + param.stride_h = param.stride_w = 2; + checker.set_param(param).exec( + TensorShapeArray{{2, 3, ih, iw}, {}}); + } + + //! 
test for SH == 2 && SW == 2 && FH == FW == 4 max pooling + for (size_t ih : + {2, 3, 5, 7, 11, 13, 17, 19, 23, 24, 25, 26, 27, 28, 29, 30}) + for (size_t iw : + {2, 3, 5, 7, 11, 13, 17, 19, 23, 24, 25, 26, 27, 28, 29, 30}) + for (size_t p : {1, 2}) { + Param param; + param.mode = Param::Mode::MAX; + param.window_h = param.window_w = 4; + param.stride_h = param.stride_w = 2; + param.pad_h = param.pad_w = p; + checker.set_param(param).exec({{2, 3, ih, iw}, {}}); + } + + //! test for SH == 2 && SW == 2 && FH == FW == 5 max pooling + for (size_t ih : + {3, 5, 7, 11, 13, 17, 19, 23, 24, 25, 26, 27, 28, 29, 30}) + for (size_t iw : + {3, 5, 7, 11, 13, 17, 19, 23, 24, 25, 26, 27, 28, 29, 30}) + for (size_t p : {1, 2}) { + Param param; + param.mode = Param::Mode::MAX; + param.window_h = param.window_w = 5; + param.stride_h = param.stride_w = 2; + param.pad_h = param.pad_w = p; + checker.set_param(param).exec({{2, 3, ih, iw}, {}}); + } + } +} +TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_FALLBACK) { + using Param = param::Pooling; + for (size_t ih: {2, 3, 5, 7, 11, 13, 17, 19, 23, 24, 25, 26, 27, 28, 29, 30}) + for (size_t iw: {2, 3, 5, 7, 11, 13, 17, 19, 23, 24, 25, 26, 27, 28, 29, 30}) + for (size_t p: {1, 2}) + { + Param param; + param.mode = Param::Mode::MAX; + param.window_h = param.window_w = 3; + param.stride_h = param.stride_w = 2; + param.pad_h = param.pad_w = p; + Checker checker(handle()); + checker.set_param(param).exec({{2, 3, ih, iw}, {}}); + } +} + +#if MEGDNN_WITH_BENCHMARK +namespace { +template +void benchmark_impl(const typename Opr::Param& param, + std::vector> shapes, size_t RUNS, + TaskExecutorConfig&& multi_thread_config, + TaskExecutorConfig&& single_thread_config) { + std::vector multi_thread_times, single_thread_times; + { + auto multi_thread_hanle = + create_cpu_handle(0, true, &multi_thread_config); + auto benchmarker = Benchmarker(multi_thread_hanle.get()); + benchmarker.set_times(RUNS).set_display(false).set_param(param); + for (auto shape : shapes) { + multi_thread_times.push_back(benchmarker.exec(shape) / RUNS); + } + } + { + auto single_thread_handle = + create_cpu_handle(0, true, &single_thread_config); + auto benchmarker = Benchmarker(single_thread_handle.get()); + benchmarker.set_times(RUNS).set_display(false).set_param(param); + for (auto shape : shapes) { + single_thread_times.push_back(benchmarker.exec(shape) / RUNS); + } + } + printf("Benchmark : Multi threads %zu, ", multi_thread_config.nr_thread); + printf("core_ids:"); + for (size_t i = 0; i < multi_thread_config.affinity_core_set.size(); i++) { + printf("%zu ", multi_thread_config.affinity_core_set[i]); + } + printf(", Single thread core_id %zu\n", + single_thread_config.affinity_core_set[0]); + for (size_t i = 0; i < shapes.size(); i++) { + auto shape = shapes[i]; + printf("Case: "); + for (auto sh : shape) + printf("%s ", sh.to_string().c_str()); + printf("%zu threads time: %f,\n single thread time: " + "%f. 
speed up = %f, speedup/cores=%f\n", + multi_thread_config.nr_thread, multi_thread_times[i], + single_thread_times[i], + single_thread_times[i] / multi_thread_times[i], + single_thread_times[i] / multi_thread_times[i] / + multi_thread_config.nr_thread); + } +} +} // namespace + +TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_POOLING) { + constexpr size_t RUNS = 50; + + using Param = param::Pooling; + Param param; + param.window_h = param.window_w = 3; + param.stride_h = param.stride_w = 2; + param.pad_h = param.pad_w = 1; + + std::vector<SmallVector<TensorShape>> shapes; + + shapes.push_back({{32, 32, 215, 215}, {}}); + shapes.push_back({{32, 32, 128, 128}, {}}); + shapes.push_back({{8, 256, 100, 100}, {}}); + shapes.push_back({{1, 256, 100, 100}, {}}); + shapes.push_back({{1, 32, 100, 100}, {}}); + shapes.push_back({{1, 256, 80, 80}, {}}); + shapes.push_back({{1, 256, 60, 60}, {}}); + shapes.push_back({{1, 256, 30, 30}, {}}); + + param.window_h = param.window_w = 3; + param.stride_h = param.stride_w = 2; + param.pad_h = param.pad_w = 1; + printf("Benchmark POOLING kernel:%d*%d stride:%d pad:%d mode %d\n", param.window_h, + param.window_w, param.stride_h, param.pad_h, static_cast<int>(param.mode)); + benchmark_impl<Pooling>(param, shapes, RUNS, {4, {0, 1, 2, 3}}, {1, {0}}); + benchmark_impl<Pooling>(param, shapes, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}); + benchmark_impl<Pooling>(param, shapes, RUNS, {2, {0, 1}}, {1, {0}}); +} +#endif + +} // namespace test +} // namespace megdnn + // vim: syntax=cpp.doxygen diff --git a/dnn/test/arm_common/reduce.cpp b/dnn/test/arm_common/reduce.cpp new file mode 100644 index 00000000..29182739 --- /dev/null +++ b/dnn/test/arm_common/reduce.cpp @@ -0,0 +1,171 @@ +/** + * \file dnn/test/arm_common/reduce.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */ +#include "test/arm_common/fixture.h" + +#include "megdnn/oprs.h" +#include "test/common/checker.h" +#include "test/common/benchmarker.h" + +using namespace megdnn; +using namespace test; + +TEST_F(ARM_COMMON, REDUCE) { + using Param = Reduce::Param; + using Mode = Param::Mode; + Checker checker(handle()); + UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; + checker.set_rng(0, &rng); + struct Config { + Param param; + DType dtype; + TensorShape shape; + Config(Param param, DType dtype, TensorShape shape) + : param(param), dtype(dtype), shape(shape) {} + }; + std::vector configs; + for (auto mode : {Mode::MEAN, Mode::MAX, Mode::MIN}) + for (auto dtype : std::vector{ + dtype::Float32(), dtype::Float16(), + dtype::QuantizedS8(1.3f), + dtype::Quantized8Asymm(1.3f, static_cast(3))}) + for (int32_t axis : {0, 1, 2}) { + for (size_t A : {1, 3, 5}) { + for (size_t B : {4, 6, 9, 16, 33, 45}) { + for (size_t C : {4, 6, 9, 16, 33, 45}) { + TensorShape shape{A, B, C}; + Param param(mode, axis); + Config config(param, dtype, shape); + configs.push_back(config); + } + } + } + } + for (auto&& config : configs) { + auto&& dtype = config.dtype; + auto&& param = config.param; + auto&& shape = config.shape; + + checker.set_dtype(0, dtype).set_param(param).execs({shape, {}}); + } + configs.clear(); + for (auto mode : {Mode::SUM, Mode::PRODUCT, Mode::SUM_SQR}) + for (auto dtype : + std::vector{dtype::Float32(), dtype::Float16()}) + for (int32_t axis : {0, 1, 2}) { + for (size_t A : {1, 3, 5}) { + for (size_t B : {4, 6, 9, 16, 33, 45}) { + for (size_t C : {4, 6, 9, 16, 33, 45}) { + TensorShape shape{A, B, C}; + Param param(mode, axis); + Config config(param, dtype, shape); + configs.push_back(config); + } + } + } + } + + UniformFloatRNG rng_float(-2, 2); + checker.set_rng(0, &rng_float); + checker.set_epsilon(1e-1); + for (auto&& config : configs) { + auto&& dtype = config.dtype; + auto&& param = config.param; + auto&& shape = config.shape; + if(dtype == dtype::Float16()) + checker.set_epsilon(1e-1); + else + checker.set_epsilon(1e-3); + + checker.set_dtype(0, dtype).set_param(param).execs({shape, {}}); + } +} + +#if MEGDNN_WITH_BENCHMARK +TEST_F(ARM_COMMON, BENCHMARK_REDUCE) { + auto run = [&](size_t A, size_t B, size_t C, size_t axis, + megdnn::param::Reduce::Mode mode, megdnn::DType& dtype) { + auto handle_fallback = create_cpu_handle(1); + Benchmarker benchmarker(handle()); + Benchmarker benchmarker_fallback(handle_fallback.get()); + benchmarker_fallback.set_display(false); + benchmarker.set_display(false); + constexpr size_t RUNS = 50; + benchmarker_fallback.set_times(RUNS); + benchmarker.set_times(RUNS); + param::Reduce param; + param.axis = axis; + param.mode = mode; + benchmarker.set_param(param); + benchmarker_fallback.set_param(param); + + TensorLayout src({A, B, C}, dtype), dst; + auto opr = handle()->create_operator(); + opr->param() = param; + opr->deduce_layout(src, dst); + + auto bench = [&](const char* msg) { + auto cur = benchmarker.execs({src, dst}) / RUNS; + auto fallback = + benchmarker_fallback.execs({src, dst}) / RUNS; + float computation = + src.total_nr_elems() / 1024.0 / 1024.0 / 1024.0 * 1e3; + printf("run %s->%s %s: fallback: %fms %fGflops " + "cur: %fms %fGflops speedup=%f\n", + src.to_string().c_str(), dst.to_string().c_str(), msg, + fallback, computation / fallback, cur, computation / cur, + fallback / cur); + }; + + benchmarker_fallback.set_dtype(0, dtype); + benchmarker.set_dtype(0, dtype); + bench(dtype.name()); + }; + + for (auto mode : {param::Reduce::Mode::MEAN, 
param::Reduce::Mode::MAX, + param::Reduce::Mode::MIN}) + for (int32_t axis : {1, 2}) { + if (mode == param::Reduce::Mode::MEAN) + printf("testcase mean %s\n", axis == 2 ? "c == 1" : "c > 1"); + else if (mode == param::Reduce::Mode::MAX) + printf("testcase max %s\n", axis == 2 ? "c == 1" : "c > 1"); + else if (mode == param::Reduce::Mode::MIN) + printf("testcase min %s\n", axis == 2 ? "c == 1" : "c > 1"); + for (auto dtype : + std::vector{dtype::Float16(), dtype::Float32(), + dtype::QuantizedS8(4.2f), + dtype::Quantized8Asymm(3.2f, static_cast(10))}) { + run(1, 1024, 49, axis, mode, dtype); + run(2, 10, 10000, axis, mode, dtype); + run(2, 100, 10000, axis, mode, dtype); + run(2, 10, 100000, axis, mode, dtype); + } + } + for (auto mode : {param::Reduce::Mode::SUM, param::Reduce::Mode::PRODUCT, + param::Reduce::Mode::SUM_SQR}) + for (int32_t axis : {1, 2}) { + if (mode == param::Reduce::Mode::SUM) + printf("testcase sum %s\n", axis == 2 ? "c == 1" : "c > 1"); + else if (mode == param::Reduce::Mode::PRODUCT) + printf("testcase product %s\n", axis == 2 ? "c == 1" : "c > 1"); + else if (mode == param::Reduce::Mode::SUM_SQR) + printf("testcase sum SumSqr %s\n", + axis == 2 ? "c == 1" : "c > 1"); + for (auto dtype : std::vector{dtype::Float16(), + dtype::Float32()}) { + run(1, 1024, 49, axis, mode, dtype); + run(2, 10, 10000, axis, mode, dtype); + run(2, 100, 10000, axis, mode, dtype); + run(2, 10, 100000, axis, mode, dtype); + } + } +} +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/test/arm_common/resize.cpp b/dnn/test/arm_common/resize.cpp new file mode 100644 index 00000000..6e5c5fd3 --- /dev/null +++ b/dnn/test/arm_common/resize.cpp @@ -0,0 +1,45 @@ +/** + * \file dnn/test/arm_common/resize.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "test/arm_common/fixture.h" +#include "test/common/resize.h" +#include "test/common/checker.h" + +namespace megdnn { +namespace test { + +TEST_F(ARM_COMMON, RESIZE_CV) +{ + using namespace resize; + std::vector args = get_cv_args(); + Checker checker(handle()); + + for (auto &&arg: args) { + checker.set_param(arg.param) + .set_epsilon(1 + 1e-3) + .set_dtype(0, dtype::Uint8()) + .set_dtype(1, dtype::Uint8()) + .execs({arg.src, arg.dst}); + } + + for (auto &&arg: args) { + checker.set_param(arg.param) + .set_dtype(0, dtype::Float32()) + .set_dtype(1, dtype::Float32()) + .execs({arg.src, arg.dst}); + } + +} + +} // namespace test +} // namespace megdnn + +// vim: syntax=cpp.doxygen + diff --git a/dnn/test/arm_common/separable_filter.cpp b/dnn/test/arm_common/separable_filter.cpp new file mode 100644 index 00000000..9d364238 --- /dev/null +++ b/dnn/test/arm_common/separable_filter.cpp @@ -0,0 +1,41 @@ +/** + * \file dnn/test/arm_common/separable_filter.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +#include "test/arm_common/fixture.h" +#include "test/common/separable_filter.h" +#include "test/common/checker.h" +#include "test/common/rng.h" + +namespace megdnn { +namespace test { + +TEST_F(ARM_COMMON, SEPARABLE_FILTER) +{ + using namespace separable_filter; + std::vector args = get_args(); + Checker checker(handle()); + for (auto &&arg: args) { + checker.set_param(arg.param).execs({arg.src, arg.filter_x, arg.filter_y, {}}); + } + + checker.set_dtype(0, dtype::Uint8()) + .set_dtype(1, dtype::Float32()) + .set_dtype(2, dtype::Float32()) + .set_dtype(3, dtype::Uint8()) + .set_epsilon(1+1e-3); + for (auto&& arg : args) { + checker.set_param(arg.param).execs( + {arg.src, arg.filter_x, arg.filter_y, {}}); + } +} + +} // namespace test +} // namespace megdnn +// vim: syntax=cpp.doxygen diff --git a/dnn/test/arm_common/type_cvt.cpp b/dnn/test/arm_common/type_cvt.cpp new file mode 100644 index 00000000..36195f5b --- /dev/null +++ b/dnn/test/arm_common/type_cvt.cpp @@ -0,0 +1,185 @@ +/** + * \file dnn/test/arm_common/type_cvt.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "test/common/benchmarker.h" +#include "test/common/checker.h" + +#include "test/arm_common/fixture.h" + +namespace megdnn { +namespace test { + +TEST_F(ARM_COMMON, TYPE_CVT) { + Checker checker(handle()); + UniformIntRNG rng{INT32_MIN >> 1, INT32_MAX >> 1}; + UniformIntRNG rng8{INT8_MIN >> 1, INT8_MAX >> 1}; + + for (size_t size : {1, 7, 15, 33, 10000}) { + checker.set_rng(0, &rng); + checker.set_dtype(0, dtype::QuantizedS32(0.0000113264f)) + .set_dtype(1, dtype::Quantized8Asymm(0.018909f, + static_cast(3))) + .execs({{size}, {size}}); + + checker.set_dtype(0, dtype::QuantizedS32(0.0003f)) + .set_dtype(1, dtype::Quantized8Asymm(0.1f, + static_cast(3))) + .execs({{size}, {size}}); + + checker.set_dtype(0, dtype::QuantizedS32(0.000815917f)) + .set_dtype(1, dtype::QuantizedS8(0.245121f)) + .execs({{size}, {size}}); + + checker.set_dtype(0, dtype::QuantizedS32(0.0003f)) + .set_dtype(1, dtype::QuantizedS8(0.2f)) + .execs({{size}, {size}}); + + checker.set_rng(0, &rng8); + + //! we should not use so large random value, otherwise it may cause + //! 
compute error + checker.set_dtype(0, dtype::Float32()) + .set_dtype(1, dtype::QuantizedS8(0.245121f)) + .execs({{size}, {size}}); + + checker.set_dtype(0, dtype::Float32()) + .set_dtype(1, dtype::Quantized8Asymm(0.1f, + static_cast(3))) + .execs({{size}, {size}}); + + checker.set_dtype(0, dtype::QuantizedS32(0.0004f)) + .set_dtype(1, dtype::QuantizedS32(0.0002f)) + .execs({{size}, {size}}); + + checker.set_dtype(0, dtype::QuantizedS8(0.3f)) + .set_dtype(1, dtype::QuantizedS8(0.2f)) + .execs({{size}, {size}}); + + checker.set_dtype(0, + dtype::Quantized8Asymm(0.3f, static_cast(8))) + .set_dtype(1, dtype::Quantized8Asymm(0.1f, + static_cast(3))) + .execs({{size}, {size}}); + + checker.set_dtype(0, dtype::QuantizedS8(0.245121f)) + .set_dtype(1, dtype::QuantizedS32(0.000815917f)) + .execs({{size}, {size}}); + + checker.set_dtype(0, dtype::QuantizedS8(0.2f)) + .set_dtype(1, dtype::QuantizedS32(0.0003f)) + .execs({{size}, {size}}); + + checker.set_dtype(0, dtype::Float32()) + .set_dtype(1, dtype::Float16()) + .execs({{size}, {size}}); + + checker.set_dtype(0, dtype::Float16()) + .set_dtype(1, dtype::Float32()) + .execs({{size}, {size}}); + } + + UniformIntRNG narrow_rng{-40000, 40000}; + checker.set_rng(0, &narrow_rng); + checker.set_dtype(0, dtype::QuantizedS32(0.000163794f)) + .set_dtype(1, dtype::Quantized8Asymm(0.0479196f, + static_cast(144))) + .execs({{1, 32, 24, 128}, {1, 32, 24, 128}}); +} + +#if MEGDNN_WITH_BENCHMARK +TEST_F(ARM_COMMON, BENCHMARK_TYPE_CVT) { + auto run = [&](const TensorShapeArray& shapes) { + auto handle_fallback = create_cpu_handle(1); + Benchmarker benchmarker(handle()); + Benchmarker benchmarker_fallback(handle_fallback.get()); + benchmarker_fallback.set_display(false); + benchmarker.set_display(false); + constexpr size_t RUNS = 50; + benchmarker_fallback.set_times(RUNS); + benchmarker.set_times(RUNS); + + auto bench = [&](const char* msg) { + for (auto&& shape : shapes) { + auto fallback = + benchmarker_fallback.execs({shape, shape}) / RUNS; + auto cur = benchmarker.execs({shape, shape}) / RUNS; + printf("run %s %s: fallback=%fms " + "cur=%fms speedup=%f\n", + shape.to_string().c_str(), msg, fallback, cur, + fallback / cur); + } + }; + + benchmarker_fallback.set_dtype(0, dtype::QuantizedS32(0.25f)) + .set_dtype(1, dtype::Quantized8Asymm(1.3f, + static_cast(3))); + benchmarker.set_dtype(0, dtype::QuantizedS32(0.25f)) + .set_dtype(1, dtype::Quantized8Asymm(1.3f, + static_cast(3))); + bench("QuantizedS32->Quantized8Asymm"); + + benchmarker_fallback + .set_dtype(0, dtype::Quantized8Asymm(0.25f, + static_cast(9))) + .set_dtype(1, dtype::Quantized8Asymm(1.3f, + static_cast(3))); + benchmarker + .set_dtype(0, dtype::Quantized8Asymm(0.25f, + static_cast(9))) + .set_dtype(1, dtype::Quantized8Asymm(1.3f, + static_cast(3))); + bench("Quantized8Asymm->Quantized8Asymm"); + + benchmarker_fallback.set_dtype(0, dtype::QuantizedS32(0.25f)) + .set_dtype(1, dtype::QuantizedS8(1.3f)); + benchmarker.set_dtype(0, dtype::QuantizedS32(0.25f)) + .set_dtype(1, dtype::QuantizedS8(1.3f)); + bench("QuantizedS32->QuantizedS8"); + + benchmarker_fallback.set_dtype(0, dtype::QuantizedS8(1.3f)) + .set_dtype(1, dtype::QuantizedS32(0.25f)); + benchmarker.set_dtype(0, dtype::QuantizedS32(1.3f)) + .set_dtype(1, dtype::QuantizedS8(0.25f)); + bench("QuantizedS8->QuantizedS32"); + + benchmarker_fallback.set_dtype(0, dtype::Float16()) + .set_dtype(1, dtype::Float32()); + benchmarker.set_dtype(0, dtype::Float16()) + .set_dtype(1, dtype::Float32()); + bench("Float16->Float32"); + + 
benchmarker_fallback.set_dtype(0, dtype::Float32()) + .set_dtype(1, dtype::Float16()); + benchmarker.set_dtype(0, dtype::Float32()) + .set_dtype(1, dtype::Float16()); + bench("Float32->Float16"); + + benchmarker_fallback.set_dtype(0, dtype::Float32()) + .set_dtype(1, dtype::QuantizedS8(0.245121f)); + benchmarker.set_dtype(0, dtype::Float32()) + .set_dtype(1, dtype::QuantizedS8(0.245121f)); + bench("Float32->QuantizedS8"); + + benchmarker_fallback.set_dtype(0, dtype::Float32()) + .set_dtype(1, dtype::Quantized8Asymm(0.1f, static_cast(3))); + benchmarker.set_dtype(0, dtype::Float32()) + .set_dtype(1, dtype::Quantized8Asymm(0.1f, static_cast(3))); + bench("Float32->Quantized8Asymm"); + }; + + TensorShapeArray shapes = {{100000}, {1000000}}; + + run(shapes); +} +#endif + +} // namespace test +} // namespace megdnn + // vim: syntax=cpp.doxygen diff --git a/dnn/test/arm_common/warp_affine.cpp b/dnn/test/arm_common/warp_affine.cpp new file mode 100644 index 00000000..8e11fb05 --- /dev/null +++ b/dnn/test/arm_common/warp_affine.cpp @@ -0,0 +1,44 @@ +/** + * \file dnn/test/arm_common/warp_affine.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "test/arm_common/fixture.h" +#include "test/common/warp_affine.h" +#include "test/common/checker.h" + +namespace megdnn { +namespace test { + +TEST_F(ARM_COMMON_MULTI_THREADS, WARP_AFFINE_CV) { + using namespace warp_affine; + std::vector args = get_cv_args(); + Checker checker(handle()); + + for (auto &&arg : args) { + checker.set_param(arg.param) + .set_epsilon(1 + 1e-3) + .set_dtype(0, dtype::Uint8()) + .set_dtype(1, dtype::Float32()) + .set_dtype(2, dtype::Uint8()) + .execs({arg.src, arg.trans, arg.dst}); + } + + for (auto &&arg: args) { + checker.set_param(arg.param) + .set_dtype(0, dtype::Float32()) + .set_dtype(1, dtype::Float32()) + .set_dtype(2, dtype::Float32()) + .execs({arg.src, arg.trans, arg.dst}); + } + +} + +} // namespace test +} // namespace megdnn +// vim: syntax=cpp.doxygen diff --git a/dnn/test/arm_common/warp_perspective.cpp b/dnn/test/arm_common/warp_perspective.cpp new file mode 100644 index 00000000..262ca6b5 --- /dev/null +++ b/dnn/test/arm_common/warp_perspective.cpp @@ -0,0 +1,394 @@ +/** + * \file dnn/test/arm_common/warp_perspective.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include +#include +#include "test/arm_common/fixture.h" + +#include "test/common/benchmarker.h" +#include "test/common/checker.h" +#include "test/common/random_state.h" +#include "test/common/rng.h" + +#include "test/common/warp_perspective.h" + +namespace megdnn { +namespace test { + +TEST_F(ARM_COMMON, WARP_PERSPECTIVE_CV) { + //! 
Just for the format NHWC + Checker checker(handle()); + param::WarpPerspective param; + class ResizeMatRNG : public RNG { + void gen(const TensorND& tensor_) override { + auto& gen = RandomState::generator(); + std::uniform_real_distribution pdist3(1.9f, 3.1f); + std::uniform_real_distribution pdist(0.9f, 1.1f); + std::uniform_real_distribution pdisth(0.4f, 0.6f); + std::uniform_real_distribution ndist(-1.1f, -0.9f); + std::uniform_real_distribution ndist3(-3.1f, -1.9f); + std::uniform_real_distribution ndisth(-0.6f, -0.4f); + std::uniform_int_distribution dice(0, 5); + float* ptr = tensor_.ptr(); + auto N = tensor_.layout.shape[0]; + for (size_t n = 0; n < N; ++n) { + for (size_t i = 0; i < 9; ++i) { + switch (dice(gen)) { + case 0: + ptr[i] = pdist3(gen); + break; + case 1: + ptr[i] = pdist(gen); + break; + case 2: + ptr[i] = pdisth(gen); + break; + case 3: + ptr[i] = ndist(gen); + break; + case 4: + ptr[i] = ndist3(gen); + break; + case 5: + ptr[i] = ndisth(gen); + break; + } + } + // is resize? + if (n & 1) { + ptr[1] = 0; + ptr[3] = 0; + ptr[6] = ptr[7] = 0; + } + ptr += 9; + } + } + } rng; + + using BMode = param::WarpPerspective::BorderMode; + param.format = param::WarpPerspective::Format::NHWC; + // add for nearest test + param.imode = param::WarpPerspective::InterpolationMode::NEAREST; + for (auto mode : {BMode::REFLECT_101, BMode::REPLICATE, BMode::REFLECT, + BMode::WRAP, BMode::CONSTANT}) { + param.bmode = mode; + param.border_val = 1.737; + checker.set_param(param); + checker.exec({{10, 128, 108, 3}, {10, 3, 3}, {10, 56, 128, 3}}); + } + // resize nan case + UniformFloatRNG rng_zero(0, 0); + checker.set_rng(1, &rng_zero); + { + param.bmode = BMode::CONSTANT; + param.border_val = 1.737; + checker.set_param(param); + checker.exec({{1000, 2, 10, 3}, {1000, 3, 3}, {1000, 2, 12, 3}}); + } + + // add linear test + param.imode = param::WarpPerspective::InterpolationMode::INTER_LINEAR; + for (auto mode : {BMode::REFLECT_101, BMode::REPLICATE, BMode::REFLECT, + BMode::WRAP, BMode::CONSTANT}) { + param.bmode = mode; + param.border_val = 1.737; + checker.set_param(param); + checker.exec({{10, 128, 108, 3}, {10, 3, 3}, {10, 56, 128, 3}}); + } + // resize nan case + checker.set_rng(1, &rng_zero); + { + param.bmode = BMode::CONSTANT; + param.border_val = 1.737; + checker.set_param(param); + checker.exec({{1000, 2, 10, 3}, {1000, 3, 3}, {1000, 2, 12, 3}}); + } + + auto args = warp_perspective::get_cv_args(); + for (auto&& arg : args) { + checker.set_param(arg.param) + .set_dtype(0, dtype::Uint8()) + .set_dtype(1, dtype::Float32()) + .set_dtype(2, dtype::Uint8()) + .execs({arg.src, arg.trans, arg.dst}); + } + + for (auto&& arg : args) { + checker.set_param(arg.param) + .set_dtype(0, dtype::Float32()) + .set_dtype(1, dtype::Float32()) + .set_dtype(2, dtype::Float32()) + .execs({arg.src, arg.trans, arg.dst}); + } +} + +TEST_F(ARM_COMMON_MULTI_THREADS, WARP_PERSPECTIVE_CV) { + //! 
Just for the format NHWC + Checker checker(handle()); + param::WarpPerspective param; + class ResizeMatRNG : public RNG { + void gen(const TensorND& tensor_) override { + auto& gen = RandomState::generator(); + std::uniform_real_distribution pdist3(1.9f, 3.1f); + std::uniform_real_distribution pdist(0.9f, 1.1f); + std::uniform_real_distribution pdisth(0.4f, 0.6f); + std::uniform_real_distribution ndist(-1.1f, -0.9f); + std::uniform_real_distribution ndist3(-3.1f, -1.9f); + std::uniform_real_distribution ndisth(-0.6f, -0.4f); + std::uniform_int_distribution dice(0, 5); + float* ptr = tensor_.ptr(); + auto N = tensor_.layout.shape[0]; + for (size_t n = 0; n < N; ++n) { + for (size_t i = 0; i < 9; ++i) { + switch (dice(gen)) { + case 0: + ptr[i] = pdist3(gen); + break; + case 1: + ptr[i] = pdist(gen); + break; + case 2: + ptr[i] = pdisth(gen); + break; + case 3: + ptr[i] = ndist(gen); + break; + case 4: + ptr[i] = ndist3(gen); + break; + case 5: + ptr[i] = ndisth(gen); + break; + } + } + // is resize? + if (n & 1) { + ptr[1] = 0; + ptr[3] = 0; + ptr[6] = ptr[7] = 0; + } + ptr += 9; + } + } + } rng; + + using BMode = param::WarpPerspective::BorderMode; + param.format = param::WarpPerspective::Format::NHWC; + // add for nearest test + param.imode = param::WarpPerspective::InterpolationMode::NEAREST; + for (auto mode : {BMode::REFLECT_101, BMode::REPLICATE, BMode::REFLECT, + BMode::WRAP, BMode::CONSTANT}) { + param.bmode = mode; + param.border_val = 1.737; + checker.set_param(param); + checker.exec({{10, 128, 108, 3}, {10, 3, 3}, {10, 56, 128, 3}}); + } + // resize nan case + UniformFloatRNG rng_zero(0, 0); + checker.set_rng(1, &rng_zero); + { + param.bmode = BMode::CONSTANT; + param.border_val = 1.737; + checker.set_param(param); + checker.exec({{1000, 2, 10, 3}, {1000, 3, 3}, {1000, 2, 12, 3}}); + } + + // add linear test + param.imode = param::WarpPerspective::InterpolationMode::INTER_LINEAR; + for (auto mode : {BMode::REFLECT_101, BMode::REPLICATE, BMode::REFLECT, + BMode::WRAP, BMode::CONSTANT}) { + param.bmode = mode; + param.border_val = 1.737; + checker.set_param(param); + checker.exec({{10, 128, 108, 3}, {10, 3, 3}, {10, 56, 128, 3}}); + } + // resize nan case + checker.set_rng(1, &rng_zero); + { + param.bmode = BMode::CONSTANT; + param.border_val = 1.737; + checker.set_param(param); + checker.exec({{1000, 2, 10, 3}, {1000, 3, 3}, {1000, 2, 12, 3}}); + } + + auto args = warp_perspective::get_cv_args(); + for (auto&& arg : args) { + checker.set_param(arg.param) + .set_dtype(0, dtype::Uint8()) + .set_dtype(1, dtype::Float32()) + .set_dtype(2, dtype::Uint8()) + .execs({arg.src, arg.trans, arg.dst}); + } + + for (auto&& arg : args) { + checker.set_param(arg.param) + .set_dtype(0, dtype::Float32()) + .set_dtype(1, dtype::Float32()) + .set_dtype(2, dtype::Float32()) + .execs({arg.src, arg.trans, arg.dst}); + } +} + +#if MEGDNN_WITH_BENCHMARK +TEST_F(ARM_COMMON, BENCHMARK_WARP_PERSPECTIVE_FORWARD) { + Benchmarker benchmarker(handle()); + auto handle_naive = create_cpu_handle(2); + Benchmarker benchmarker_naive(handle_naive.get()); + constexpr size_t NR_RUN = 50; + + using BMode = param::WarpPerspective::BorderMode; + using IMode = param::WarpPerspective::InterpolationMode; + + WarpPerspective::Param param; + param.border_val = 0.3f; + param.format = param::WarpPerspective::Format::NHWC; + + auto run = [&](size_t N, size_t C, size_t IH, size_t IW, size_t OH, + size_t OW, size_t scale) { + printf("src={%zu, %zu, %zu, %zu}, dst={%zu, %zu, %zu, %zu}\n", N, IH, + IW, C, N, OH, OW, C); + auto 
time_ms = + benchmarker.exec({{N, IH, IW, C}, {N, 3, 3}, {N, OH, OW, C}}) / + NR_RUN; + auto time_naive_ms = + benchmarker_naive.exec( + {{N, IH, IW, C}, {N, 3, 3}, {N, OH, OW, C}}) / + NR_RUN; + auto bandwidth = N * C * (scale * OH * OW) * dtype::Float32().size(); + printf("aarch64: %.3f, perf: %.3f GBPS naive: %.3f, perf %.3f GBPS " + "speedup: %f\n", + time_ms, bandwidth / time_ms / 1e6, time_naive_ms, + bandwidth / time_naive_ms / 1e6, time_naive_ms / time_ms); + }; + + std::vector bmodestringmap = { + "REPLICATE", "REFLECT", "REFLECT_101", "WARP", "CONSTANT"}; + + std::vector imodestringmap = {"NEAREST", "INTER_LINEAR"}; + size_t scales[2] = {2, 5}; + + for (auto imode : {IMode::NEAREST, IMode::INTER_LINEAR}) { + for (auto bmode : {BMode::REFLECT_101, BMode::REPLICATE, BMode::REFLECT, + BMode::WRAP, BMode::CONSTANT}) { + param.imode = imode; + param.bmode = bmode; + benchmarker.set_param(param).set_display(false).set_times(NR_RUN); + benchmarker_naive.set_param(param).set_display(false).set_times( + NR_RUN); + size_t scale = scales[(int)imode]; + printf("\n\n\n warpperspective InterpolationMode::%s " + "BorderMode::%s start\n", + imodestringmap[(int)imode].c_str(), + bmodestringmap[(int)bmode].c_str()); + for (auto&& shape : + std::vector>{{700, 490}, + {500, 334}, + {472, 342}, + {448, 306}, + {626, 412}, + {140, 144}, + {120, 128}, + {180, 176}}) { + for (size_t ch : {1, 2, 3}) { + run(1, ch, shape.first, shape.second, 256, 256, scale); + } + } + } + } +} + +namespace { +void benchmark_impl(const typename WarpPerspective::Param& param, + std::vector> shapes, size_t RUNS, + TaskExecutorConfig&& multi_thread_config, + TaskExecutorConfig&& single_thread_config) { + std::vector multi_thread_times, single_thread_times; + { + auto multi_thread_hanle = + create_cpu_handle(0, true, &multi_thread_config); + auto benchmarker = + Benchmarker(multi_thread_hanle.get()); + benchmarker.set_times(RUNS).set_display(false).set_param(param); + for (auto shape : shapes) { + multi_thread_times.push_back(benchmarker.exec(shape) / RUNS); + } + } + { + auto single_thread_handle = + create_cpu_handle(0, true, &single_thread_config); + auto benchmarker = + Benchmarker(single_thread_handle.get()); + benchmarker.set_times(RUNS).set_display(false).set_param(param); + for (auto shape : shapes) { + single_thread_times.push_back(benchmarker.exec(shape) / RUNS); + } + } + printf("Benchmark : Multi threads %zu, ", multi_thread_config.nr_thread); + printf("core_ids:"); + for (size_t i = 0; i < multi_thread_config.affinity_core_set.size(); i++) { + printf("%zu ", multi_thread_config.affinity_core_set[i]); + } + printf(", Single thread core_id %zu\n", + single_thread_config.affinity_core_set[0]); + for (size_t i = 0; i < shapes.size(); i++) { + auto shape = shapes[i]; + printf("Case: "); + for (auto sh : shape) + printf("%s ", sh.to_string().c_str()); + printf("%zu threads time: %f,\n single thread time: " + "%f. 
spead up = %f, speedup/cores=%f\n", + multi_thread_config.nr_thread, multi_thread_times[i], + single_thread_times[i], + single_thread_times[i] / multi_thread_times[i], + single_thread_times[i] / multi_thread_times[i] / + multi_thread_config.nr_thread); + } +} +} // namespace + +TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_WARP_PERSPECTIVE) { + constexpr size_t RUNS = 50; + using BMode = param::WarpPerspective::BorderMode; + using IMode = param::WarpPerspective::InterpolationMode; + + WarpPerspective::Param param; + param.border_val = 0.3f; + param.format = param::WarpPerspective::Format::NHWC; + param.imode = IMode::INTER_LINEAR; + param.bmode = BMode::REPLICATE; + + std::vector> shapes; + auto bench_case = [&](size_t N, size_t H, size_t W, size_t C) { + SmallVector shape{ + {N, H, W, C}, {N, 3, 3}, {N, 224, 224, C}}; + shapes.push_back(shape); + }; + bench_case(1, 700, 490, 1); + bench_case(1, 700, 490, 2); + bench_case(1, 700, 490, 3); + bench_case(1, 500, 334, 1); + bench_case(1, 500, 334, 2); + bench_case(1, 500, 334, 3); + bench_case(1, 140, 144, 1); + bench_case(1, 140, 144, 2); + bench_case(1, 140, 114, 3); + + printf("Benchmark warp perspective\n"); + benchmark_impl(param, shapes, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}); + benchmark_impl(param, shapes, RUNS, {4, {4, 5, 6, 7}}, {1, {7}}); + benchmark_impl(param, shapes, RUNS, {2, {4, 5}}, {1, {4}}); +} +#endif + +} // namespace test +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/test/arm_common/winograd_filter_preprocess.cpp b/dnn/test/arm_common/winograd_filter_preprocess.cpp new file mode 100644 index 00000000..0126b1f6 --- /dev/null +++ b/dnn/test/arm_common/winograd_filter_preprocess.cpp @@ -0,0 +1,91 @@ +/** + * \file dnn/test/arm_common/winograd_filter_preprocess.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
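benchmark_impl above prints both the raw speedup and the speedup divided by the thread count; the latter approaching 1.0 indicates near-linear scaling across the pinned cores. A small standalone sketch of those two metrics follows; the timings in main are hypothetical.

    #include <cstddef>
    #include <cstdio>

    struct ScalingReport {
        double speedup;           // t_single / t_multi
        double speedup_per_core;  // speedup / nr_thread, ~1.0 means linear scaling
    };

    static ScalingReport scaling(double single_ms, double multi_ms, size_t nr_thread) {
        ScalingReport r;
        r.speedup = single_ms / multi_ms;
        r.speedup_per_core = r.speedup / nr_thread;
        return r;
    }

    int main() {
        ScalingReport r = scaling(/*single_ms=*/8.0, /*multi_ms=*/2.3, /*nr_thread=*/4);
        printf("speedup=%f speedup/cores=%f\n", r.speedup, r.speedup_per_core);
        return 0;
    }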
+ */ +#include "test/common/checker.h" +#include "test/common/benchmarker.h" +#include "test/common/winograd_filter_preprocess.h" + +#include "test/arm_common/fixture.h" + +using namespace megdnn; +using namespace test; + +TEST_F(ARM_COMMON, WinogradFilterPreprocessF32) { + using namespace winograd_filter_preprocess; + Checker checker(handle()); + // default + std::vector args = get_args(6, 3); + std::vector args54 = get_args(5, 4); + std::vector args45 = get_args(4, 5); + + // mk4 + std::vector args_mk4_out2 = + get_mk_packed_args(2, param::Winograd::Format::MK4, 4); + std::vector args_mk4_out6 = + get_mk_packed_args(6, param::Winograd::Format::MK4, 4); + + args.insert(args.end(), args54.begin(), args54.end()); + args.insert(args.end(), args45.begin(), args45.end()); + args.insert(args.end(), args_mk4_out2.begin(), args_mk4_out2.end()); + args.insert(args.end(), args_mk4_out6.begin(), args_mk4_out6.end()); + for (auto&& arg : args) { + checker.set_param(arg.param) + .set_dtype(0, dtype::Float32{}) + .set_dtype(1, dtype::Float32{}) + .execs({arg.src, {}}); + } +} + +TEST_F(ARM_COMMON, WinogradFilterPreprocessQs8) { + using namespace winograd_filter_preprocess; + std::vector args = + get_mk_packed_args(2, param::Winograd::Format::MK8, 8); + Checker checker(handle()); + UniformIntRNG rng{-50, 50}; + checker.set_rng(0, &rng).set_rng(1, &rng).set_rng(2, &rng); + for (auto&& arg : args) { + checker.set_param(arg.param) + .set_dtype(0, dtype::QuantizedS8(2.5f)) + .set_dtype(1, dtype::QuantizedS16(2.5f)) + .execs({arg.src, {}}); + } +} + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +TEST_F(ARM_COMMON, WinogradFilterPreprocessF16) { + using namespace winograd_filter_preprocess; + Checker checker(handle()); + // default + std::vector args = get_args(6, 3); + std::vector args_23 = + get_mk_packed_args(2, param::Winograd::Format::DEFAULT, 4); + std::vector args45 = get_args(4, 5); + + // mk8 + std::vector args_mk8_out2 = + get_mk_packed_args(2, param::Winograd::Format::MK8, 8); + + args.insert(args.end(), args_23.begin(), args_23.end()); + args.insert(args.end(), args45.begin(), args45.end()); + args.insert(args.end(), args_mk8_out2.begin(), args_mk8_out2.end()); + + Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00); + for (auto&& arg : args) { + checker.set_param(arg.param) + .set_rng(0, rng) + .set_dtype(0, dtype::Float16{}) + .set_dtype(1, dtype::Float16{}) + .execs({arg.src, {}}); + } +} + +#endif + +// vim: syntax=cpp.doxygen diff --git a/dnn/test/armv7/conv_bias.cpp b/dnn/test/armv7/conv_bias.cpp new file mode 100644 index 00000000..18a471e7 --- /dev/null +++ b/dnn/test/armv7/conv_bias.cpp @@ -0,0 +1,75 @@ +/** + * \file dnn/test/armv7/conv_bias.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
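The F16 winograd test above is compiled only when __ARM_FEATURE_FP16_VECTOR_ARITHMETIC is defined, which the compiler sets once -march=armv8.2-a+fp16 is in effect (the MGE_ARMV8_2_FEATURE_FP16 path in the CMake changes). Below is a minimal sketch of that gating pattern, assuming an armv8.2-a+fp16 toolchain for the guarded branch; the helper name is illustrative only.

    #include <cstdio>

    #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    #include <arm_neon.h>
    // Native half-precision NEON arithmetic, only reachable on armv8.2-a+fp16 builds.
    static float add_first_lane(float a, float b) {
        float16x4_t va = vdup_n_f16(static_cast<float16_t>(a));
        float16x4_t vb = vdup_n_f16(static_cast<float16_t>(b));
        return static_cast<float>(vget_lane_f16(vadd_f16(va, vb), 0));
    }
    #else
    // Portable fallback when fp16 vector arithmetic is unavailable.
    static float add_first_lane(float a, float b) {
        return a + b;
    }
    #endif

    int main() {
        printf("%f\n", add_first_lane(1.5f, 2.25f));
        return 0;
    }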
+ */ +#include "test/armv7/fixture.h" + +#include "test/common/convolution.h" +#include "test/common/conv_bias.h" +#include "test/common/checker.h" +#include "test/common/benchmarker.h" + +#include "test/common/rng.h" + +using namespace megdnn; +using namespace test; + +namespace{ + +TEST_F(ARMV7, CONV_BIAS_MATMUL_QU8) { + using namespace conv_bias; + std::vector args = get_quantized_args(); + Checker checker(handle()); + checker.set_before_exec_callback( + conv_bias::ConvBiasAlgoChecker("QU8MATMUL")); + + UniformIntRNG rng{0, 127}; + for (auto&& arg : args) { + if (arg.bias.ndim == 4 && arg.bias[2] != 1 && arg.bias[3] != 1) + continue; + checker.set_dtype(0, dtype::Quantized8Asymm(2.5f, + static_cast(127))) + .set_dtype(1, dtype::Quantized8Asymm(2.7f, + static_cast(129))) + .set_dtype(2, dtype::QuantizedS32(6.75f)) + .set_dtype(4, dtype::Quantized8Asymm(60.25f, + static_cast(125))) + .set_rng(0, &rng) + .set_rng(1, &rng) + .set_rng(2, &rng) + .set_param(arg.param) + .execs({arg.src, arg.filter, arg.bias, {}, {}}); + } +} + +TEST_F(ARMV7, CONV_BIAS_MATMUL_QS8) { + using namespace conv_bias; + std::vector args = get_quantized_args(); + Checker checker(handle()); + checker.set_before_exec_callback( + conv_bias::ConvBiasAlgoChecker("S8MATMUL")); + + UniformIntRNG rng{0, 127}; + for (auto&& arg : args) { + if (arg.bias.ndim == 4 && arg.bias[2] != 1 && arg.bias[3] != 1) + continue; + checker.set_dtype(0, dtype::QuantizedS8(2.5f)) + .set_dtype(1, dtype::QuantizedS8(2.7f)) + .set_dtype(2, dtype::QuantizedS32(6.75f)) + .set_dtype(4, dtype::QuantizedS8(60.25f)) + .set_rng(0, &rng) + .set_rng(1, &rng) + .set_rng(2, &rng) + .set_param(arg.param) + .set_epsilon(1.0f) + .execs({arg.src, arg.filter, arg.bias, {}, {}}); + } +} +} +// vim: syntax=cpp.doxygen diff --git a/dnn/test/armv7/convolution.cpp b/dnn/test/armv7/convolution.cpp new file mode 100644 index 00000000..8018ead2 --- /dev/null +++ b/dnn/test/armv7/convolution.cpp @@ -0,0 +1,152 @@ +/** + * \file dnn/test/armv7/convolution.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
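The quantized conv_bias tests above rely on the convention real = scale * (stored - zero_point), and the bias scale 6.75f is exactly src_scale * filter_scale (2.5f * 2.7f), which keeps the int32 accumulator consistent with its quantized inputs. A small sketch of that dequantization arithmetic; the function names are illustrative, not MegDNN API.

    #include <cstdint>
    #include <cstdio>

    // Asymmetric 8-bit: real = scale * (q - zero_point).
    static float dequantize_u8(uint8_t q, float scale, uint8_t zero_point) {
        return scale * (static_cast<int>(q) - static_cast<int>(zero_point));
    }

    // Signed 32-bit accumulators carry no zero point.
    static float dequantize_s32(int32_t q, float scale) {
        return scale * q;
    }

    int main() {
        printf("%f\n", dequantize_u8(130, 2.5f, 127));   // 7.5
        printf("%f\n", dequantize_s32(4, 2.5f * 2.7f));  // 27.0
        return 0;
    }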
+ */ +#include "test/armv7/fixture.h" + +#include "test/common/convolution.h" +#include "test/common/checker.h" +#include "test/common/benchmarker.h" + +#include "test/common/rng.h" + +using namespace megdnn; +using namespace test; + +#if MEGDNN_WITH_BENCHMARK +TEST_F(ARMV7, BENCHMARK_CONVOLUTION_STRIDE2) +{ + using Param = param::Convolution; + auto run = [&](const TensorShapeArray& shapes, Param param) { + Benchmarker benchmarker_float(handle()); + size_t RUN = 100; + auto tfloat = benchmarker_float.set_display(false) + .set_times(RUN) + .set_param(param) + .exec(shapes); + size_t IC = shapes[1][1]; + size_t FH = shapes[1][2]; + size_t FW = shapes[1][3]; + TensorLayout dst_layout; + auto opr = handle()->create_operator(); + opr->param() = param; + opr->deduce_layout({shapes[0], dtype::Float32()}, + {shapes[1], dtype::Float32()}, dst_layout); + printf("flops: %.3f mflops\n", + (IC * dst_layout.total_nr_elems() * FH * FW * 2) / + (tfloat / RUN * 1000)); + }; + + auto profile = [&](size_t oc, size_t ic, size_t w, size_t h, size_t kernel, + size_t stride) { + Param param; + param.stride_h = stride; + param.stride_w = stride; + param.pad_h = kernel / 2; + param.pad_w = kernel / 2; + printf("oc: %zd ic: %zd w: %zd h: %zd stride: %zd kernel_size: %zd\n", + oc, ic, w, h, stride, kernel); + + run({{1, ic, h, w}, {oc, ic, kernel, kernel}, {}}, + param); + + }; + + for (size_t kernel : {2, 3, 5, 7}) { + for (size_t ic : {3, 6, 12, 24}) { + for (size_t oc : {3, 6, 12, 24}) { + for (size_t size : {4, 7, 8, 14, 16, 17, 28, 32, 34, 64, 112}) { + profile(oc, ic, size, size, kernel, 2); + } + } + } + } +} +#endif + +TEST_F(ARMV7, BENCHMARK_CONVOLUTION_1X1) +{ + int exec_times = 50; + Benchmarker benchmarker_gemm(handle()); + benchmarker_gemm.set_times(exec_times); + + Benchmarker benchmarker(handle()); + benchmarker.set_times(exec_times); + + float mod = 1000 * exec_times / 1e9; + auto run = [&](size_t IC, size_t OC, size_t H, size_t W) { + float time = 1.f, perf = 1.f; + + std::cout< benchmarker_gconv1x1(handle()); + benchmarker_gconv1x1.set_times(exec_times); + + float mod = 1000 * exec_times / 1e9; + auto run = [&](size_t IC, size_t OC, size_t H, size_t W, size_t group){ + float time = 1.f, perf = 1.f; + + std::cout< + +#include "megdnn/handle.h" +#include "test/arm_common/fixture.h" + +#include + +namespace megdnn { +namespace test { + +class ARMV7 : public ARM_COMMON {}; + +class ARMV7_MULTI_THREADS : public ARM_COMMON_MULTI_THREADS {}; + +} // namespace test +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/test/armv7/matrix_mul.cpp b/dnn/test/armv7/matrix_mul.cpp new file mode 100644 index 00000000..56994f15 --- /dev/null +++ b/dnn/test/armv7/matrix_mul.cpp @@ -0,0 +1,410 @@ +/** + * \file dnn/test/armv7/matrix_mul.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
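BENCHMARK_CONVOLUTION_STRIDE2 above estimates work as IC * FH * FW * 2 operations per output element and divides by the per-run time in milliseconds times 1000 to obtain Mflop/s. A standalone sketch of that estimate follows; the shape and timing in main are invented for illustration.

    #include <cstddef>
    #include <cstdio>

    // One multiply and one add per (input channel, filter tap) for every output element.
    static double conv_flops(size_t dst_elems, size_t IC, size_t FH, size_t FW) {
        return 2.0 * IC * FH * FW * dst_elems;
    }

    int main() {
        // 1x24x56x56 output produced from 24 input channels with a 3x3 kernel
        double flops = conv_flops(24 * 56 * 56, 24, 3, 3);
        double time_ms = 1.2;  // hypothetical per-run time
        printf("%.3f Mflop/s\n", flops / (time_ms * 1000.0));  // flops/(ms*1000) == Mflop/s
        return 0;
    }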
+ */ +#include "test/armv7/fixture.h" +#include "test/common/benchmarker.h" +#include "test/common/checker.h" +#include "test/common/matrix_mul.h" +#include "test/common/rng.h" + +using namespace megdnn; +using namespace test; + +TEST_F(ARMV7, MATRIX_MUL) { + matrix_mul::check_matrix_mul(dtype::Float32{}, dtype::Float32{}, + dtype::Float32{}, handle(), "ARMV7_F32"); +} + +TEST_F(ARMV7, MATRIX_MUL_MK4) { + matrix_mul::check_matrix_mul( + dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), + "ARMV7_F32_MK4_4x8", param::MatrixMul::Format::MK4, 4); +} + +TEST_F(ARMV7, MATRIX_MUL_MK4_INT8) { + std::vector args; + for (size_t m : {1, 2, 3, 4, 5, 7, 10, 11}) + for (size_t n : {1, 2, 3, 4, 5, 8, 16, 24, 25, 32}) + for (size_t k : {1, 2, 3, 4, 5, 6, 7, 8, 16, 32, 33, 34}) + args.emplace_back(m, n, k, 0); + matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{}, + handle(), "ARMV7_INT8X8X32_MK4_4X2X16", + param::MatrixMul::Format::MK4, 1, 1e-3, + std::move(args)); +} + +TEST_F(ARMV7, MATRIX_MUL_INT8x8x16_K4x8x8) { + matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{}, + handle(), "ARMV7_INT8X8X16_K4X8X8"); +} + +TEST_F(ARMV7, MATRIX_MUL_INT16x16x32) { + matrix_mul::check_matrix_mul(dtype::Int16{}, dtype::Int16{}, dtype::Int32{}, + handle(),"ARMV7_INT16X16X32_K12X4X1"); +} + +TEST_F(ARMV7, MATRIX_MUL_INT16x16x32_MK8) { + matrix_mul::check_matrix_mul(dtype::Int16{}, dtype::Int16{}, dtype::Int32{}, + handle(), "ARMV7_INT16X16X32_MK8_4X8", + param::MatrixMul::Format::MK8, 4); +} + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +TEST_F(ARMV7, MATRIX_MUL_FP16) { + matrix_mul::check_matrix_mul(dtype::Float16{}, dtype::Float16{}, + dtype::Float16{}, handle(), + "AARCH32_F16_K4X16X1"); +} +TEST_F(ARMV7, MATRIX_MUL_F16_MK8) { + matrix_mul::check_matrix_mul( + dtype::Float16{}, dtype::Float16{}, dtype::Float16{}, handle(), + "AARCH32_F16_MK8_4X8", param::MatrixMul::Format::MK8, 4); +} +#endif + +#if __ARM_FEATURE_DOTPROD +TEST_F(ARMV7, MATRIX_MUL_SDOT) { + matrix_mul::check_matrix_mul(dtype::Int8(), dtype::Int8(), dtype::Int32(), + handle(), "AARCH32_INT8_K6X8X4"); +} + +TEST_F(ARMV7, MATRIX_MUL_UDOT) { + matrix_mul::check_matrix_mul( + dtype::Quantized8Asymm(4.0f, static_cast(10)), dtype::Quantized8Asymm(3.0f, static_cast(54)), + dtype::QuantizedS32(12.0f), handle(), "AARCH32_QUINT8_K4X8X4"); +} +#endif + +#if MEGDNN_WITH_BENCHMARK + +namespace { +void run_8x8x16_benchmark(const char* algo, Handle* handle) { + constexpr size_t RUNS = 50; + param::MatrixMul param; + Benchmarker benchmarker_int(handle); + Benchmarker benchmarker_int_kern_4x2x16(handle); + benchmarker_int.set_before_exec_callback( + AlgoChecker("ARM_COMMON_INT8X8X16")); + benchmarker_int.set_times(RUNS) + .set_dtype(0, dtype::Int8{}) + .set_dtype(1, dtype::Int8{}) + .set_dtype(2, dtype::Int16{}) + .set_param(param) + .set_display(false); + benchmarker_int_kern_4x2x16.set_before_exec_callback( + AlgoChecker(algo)); + benchmarker_int_kern_4x2x16.set_times(RUNS) + .set_dtype(0, dtype::Int8{}) + .set_dtype(1, dtype::Int8{}) + .set_dtype(2, dtype::Int16{}) + .set_param(param) + .set_display(false); + Benchmarker benchmarker_float(handle); + benchmarker_float.set_display(false).set_times(RUNS); + + auto run = [&](size_t M, size_t N, size_t K) { + auto int_used = benchmarker_int.exec({{M, K}, {K, N}, {}}) / RUNS; + auto int_kern_used = + benchmarker_int_kern_4x2x16.exec({{M, K}, {K, N}, {}}) / RUNS; + auto float_used = benchmarker_float.exec({{M, K}, {K, N}, {}}) / RUNS; + float computations = 2.f * M * K 
* N * 1e-6; + printf("run: {%zu{M} %zu{K} %zu{N}} float: %f ms %f Gflops int: %f " + "ms " + "%f Gflops %s: %f ms %f Gflops " + "speedup(%s/arm_common, %s/float): %f " + "%f\n", + M, K, N, float_used, computations / float_used, int_used, + computations / int_used, algo, int_kern_used, + computations / int_kern_used, algo, algo, + int_used / int_kern_used, float_used / int_kern_used); + }; + + run(256, 12 * 24, 256); + + //////////////////////// gemv ////////////////////////// + for (size_t M : {8, 64, 112, 256}) { + for (size_t K : {8, 64, 112, 256}) { + run(M, 1, K); + } + } + + //////////////////////// gemm ////////////////////////// + for (size_t M : {8, 64, 112, 256}) { + for (size_t K : {8, 16, 32, 64, 112, 256}) { + for (size_t N : {8, 64, 112, 256}) { + run(M, N, K); + } + } + } +} +void run_16x16x32_benchmark(const char* algo, Handle* handle) { + constexpr size_t RUNS = 50; + param::MatrixMul param; + Benchmarker benchmarker_int(handle); + benchmarker_int.set_before_exec_callback( + AlgoChecker("ARMV7_INT16X16X32_K12X4X1")); + benchmarker_int.set_times(RUNS) + .set_dtype(0, dtype::Int16{}) + .set_dtype(1, dtype::Int16{}) + .set_dtype(2, dtype::Int32{}) + .set_param(param) + .set_display(false); + Benchmarker benchmarker_float(handle); + benchmarker_float.set_display(false).set_times(RUNS); + + auto run = [&](size_t M, size_t N, size_t K) { + auto int_used = benchmarker_int.exec({{M, K}, {K, N}, {}}) / RUNS; + auto float_used = benchmarker_float.exec({{M, K}, {K, N}, {}}) / RUNS; + float computations = 2.f * M * K * N * 1e-6; + printf("run: {%zu{M} %zu{K} %zu{N}} float: %f ms %f Gflops \n" + "int: %f ms %f Gflops %s: \n" + "speedup(%s/arm_common, %s/float): %f\n", + M, K, N, float_used, computations / float_used, int_used, + computations / int_used,algo,algo,algo,float_used / int_used); + }; + + run(256, 12 * 24, 256); + + //////////////////////// gemv ////////////////////////// + for (size_t M : {8, 64, 112, 256}) { + for (size_t K : {8, 64, 112, 256}) { + run(M, 1, K); + } + } + + //////////////////////// gemm ////////////////////////// + for (size_t M : {8, 64, 112, 256}) { + for (size_t K : {8, 16, 32, 64, 112, 256}) { + for (size_t N : + {1, 2, 3, 4, 8, 64, 112, 113, 114, 115, 256, 257, 258, 259}) { + run(M, N, K); + } + } + } +} + +#if __ARM_FEATURE_DOTPROD +void run_8x8x32_benchmark(const char* algo, Handle* handle) { + constexpr size_t RUNS = 50; + param::MatrixMul param; + Benchmarker benchmarker_int8(handle); + benchmarker_int8.set_before_exec_callback(AlgoChecker(algo)); + benchmarker_int8.set_times(RUNS) + .set_dtype(0, dtype::Int8{}) + .set_dtype(1, dtype::Int8{}) + .set_dtype(2, dtype::Int32{}) + .set_param(param) + .set_display(false); + Benchmarker benchmarker_float(handle); + benchmarker_float.set_display(false).set_times(RUNS); + + auto run = [&](size_t M, size_t N, size_t K) { + auto int_used = benchmarker_int8.exec({{M, K}, {K, N}, {}}) / RUNS; + auto float_used = benchmarker_float.exec({{M, K}, {K, N}, {}}) / RUNS; + float computations = 2.f * M * K * N * 1e-6; + printf("run: {%zu{M} %zu{K} %zu{N}} float: %f ms %f Gflops \n" + "int: %f ms %f Gflops %s: \n" + "speedup(%s/arm_common, %s/float): %f\n", + M, K, N, float_used, computations / float_used, int_used, + computations / int_used,algo,algo,algo,float_used / int_used); + }; + + run(256, 12 * 24, 256); + //////////////////////// gemm ////////////////////////// + for (size_t M : {8, 64, 112, 256}) { + for (size_t K : {8, 16, 32, 64, 112, 256}) { + for (size_t N : {113, 114, 115, 256, 1024}) { + run(M, N, K); + 
} + } + } +} + +void run_8x8x32_quint_benchmark(Handle* handle) { + constexpr size_t RUNS = 50; + param::MatrixMul param; + Benchmarker benchmarker_quint8_dot(handle); + benchmarker_quint8_dot.set_before_exec_callback( + AlgoChecker("AARCH32_QUINT8_K4X8X4")); + benchmarker_quint8_dot.set_times(RUNS) + .set_dtype(0, dtype::Quantized8Asymm(2.3f, static_cast(20))) + .set_dtype(1, dtype::Quantized8Asymm(3.1f, static_cast(30))) + .set_dtype(2, dtype::QuantizedS32(2.3f*3.1f)) + .set_param(param) + .set_display(false); + + Benchmarker benchmarker_quint8(handle); + benchmarker_quint8.set_before_exec_callback( + AlgoChecker("ARMV7_QUINT8_K4X8X8")); + benchmarker_quint8.set_times(RUNS) + .set_dtype(0, dtype::Quantized8Asymm(2.3f, static_cast(20))) + .set_dtype(1, dtype::Quantized8Asymm(3.1f, static_cast(30))) + .set_dtype(2, dtype::QuantizedS32(2.3f*3.1f)) + .set_param(param) + .set_display(false); + + auto run = [&](size_t M, size_t N, size_t K) { + auto dot_used = benchmarker_quint8_dot.exec({{M, K}, {K, N}, {}}) / RUNS; + auto normal_used = benchmarker_quint8.exec({{M, K}, {K, N}, {}}) / RUNS; + float computations = 2.f * M * K * N * 1e-6; + printf("run: {%zu{M} %zu{K} %zu{N}} dot: %f ms %f Gflops \n" + "normal: %f ms %f Gflops.speedup: %f\n", + M, K, N, dot_used, computations / dot_used, normal_used, + computations / normal_used, normal_used / dot_used); + }; + + run(256, 12 * 24, 256); + //////////////////////// gemm ////////////////////////// + for (size_t M : {8, 64, 112, 256}) { + for (size_t K : {8, 16, 32, 64, 112, 256}) { + for (size_t N : {113, 114, 115, 256, 1024}) { + run(M, N, K); + } + } + } +} +#endif +} // namespace + +#if __ARM_FEATURE_DOTPROD +TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x32_K6x8x4) { + run_8x8x32_benchmark("AARCH32_INT8_K6X8X4", handle()); +} +TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_QUINT8x8x32_K4x8x4) { + run_8x8x32_quint_benchmark(handle()); +} +#endif + +TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x16_K4x2x16) { + run_8x8x16_benchmark("ARMV7_INT8X8X16_K4X2X16", handle()); +} + + +TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x16_K4x8x8) { + run_8x8x16_benchmark("ARMV7_INT8X8X16_K4X8X8", handle()); +} + +TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT16x16x32_K12x4x1) { + run_16x16x32_benchmark("ARMV7_INT16X16X32_K12X4X1", handle()); +} + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_FP16) { + constexpr size_t RUNS = 50; + param::MatrixMul param; + Benchmarker benchmarker_fp16(handle()); + benchmarker_fp16.set_times(RUNS) + .set_dtype(0, dtype::Float16()) + .set_dtype(1, dtype::Float16()) + .set_dtype(2, dtype::Float16()) + .set_param(param) + .set_display(false); + Benchmarker benchmarker_float(handle()); + benchmarker_float.set_param(param).set_display(false).set_times(RUNS); + + auto run = [&](size_t M, size_t N, size_t K) { + auto fp16_used = benchmarker_fp16.exec({{M, K}, {K, N}, {}}) / RUNS; + auto float_used = benchmarker_float.exec({{M, K}, {K, N}, {}}) / RUNS; + float computations = 2.f * M * K * N * 1e-6; + printf("run: {%zu{M} %zu{K} %zu{N}} float: %f ms %f Gflops fp16: %f ms " + "%f Gflops speedup: %f\n", + M, K, N, float_used, computations / float_used, fp16_used, + computations / fp16_used, float_used / fp16_used); + }; + + run(256, 12 * 24, 256); + + for (size_t M : {8, 64, 112, 256}) { + for (size_t K : {8, 64, 112, 256}) { + for (size_t N : {8, 64, 112, 256}) { + run(M, N, K); + } + } + } +} + +TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_F16_MK8) { + auto args = matrix_mul::get_benchmark_matmul_mk_packed_args(4); + 
matrix_mul::benchmark_with_contrast( + handle(), args, dtype::Float16{}, dtype::Float16{}, + dtype::Float16{}, "AARCH32_F16_MK8_4X8", + param::MatrixMul::Format::MK8, dtype::Float16{}, dtype::Float16{}, + dtype::Float16{}, "AARCH32_F16_K4X16X1"); +} +#endif + +TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_MK4) { + auto args = matrix_mul::get_benchmark_matmul_mk_packed_args(8); + matrix_mul::benchmark_with_contrast( + handle(), args, dtype::Float32{}, dtype::Float32{}, + dtype::Float32{}, "ARMV7_F32_MK4_4x8", + param::MatrixMul::Format::MK4, dtype::Float32{}, dtype::Float32{}, + dtype::Float32{}); +} + +TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT16x16x32_MK8) { + auto args = matrix_mul::get_benchmark_matmul_mk_packed_args(4); + matrix_mul::benchmark_with_contrast( + handle(), args, dtype::Int16{}, dtype::Int16{}, dtype::Int32{}, + "ARMV7_INT16X16X32_MK8_4X8", param::MatrixMul::Format::MK8, + dtype::Int16{}, dtype::Int16{}, dtype::Int32{}); +} +TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT32_MK_4X2X16) { + constexpr size_t RUNS = 50; + param::MatrixMul param; + param.transposeA = false; + param.transposeB = false; + Benchmarker benchmarker(handle()); + Benchmarker benchmarker_mk4(handle()); + benchmarker.set_times(RUNS) + .set_dtype(0, dtype::Int8{}) + .set_dtype(1, dtype::Int8{}) + .set_dtype(2, dtype::Int32{}) + .set_param(param) + .set_display(false); + benchmarker.set_before_exec_callback( + AlgoChecker("ARMV7_INT8X8X32_K4X2X16")); + + param.format = MatrixMul::Param::Format::MK4; + benchmarker_mk4.set_before_exec_callback( + AlgoChecker("ARMV7_INT8X8X32_MK4_4X2X16")); + benchmarker_mk4.set_times(RUNS) + .set_dtype(0, dtype::Int8{}) + .set_dtype(1, dtype::Int8{}) + .set_dtype(2, dtype::Int32{}) + .set_param(param) + .set_display(false); + + auto run = [&](size_t M, size_t N, size_t K) { + auto mk_used = benchmarker_mk4.exec( + {{M / 4, K / 4, 4, 4}, {K / 4, N, 4}, {}}) / + RUNS; + auto default_used = benchmarker.exec({{M, K}, {K, N}, {}}) / RUNS; + float computations = 2.f * M * K * N * 1e-6; + printf("run: {%zu{M} %zu{K} %zu{N}} normal: %f ms %f Gflops mk4: %f ms " + "%f Gflops speedup_vs_normal: %f\n", + M, K, N, default_used, computations / default_used, mk_used, + computations / mk_used, default_used / mk_used); + }; + + run(256, 256, 128); + for (size_t k = 4; k <= 512; k *= 2) { + for (size_t m = 4; m <= 512; m *= 2) { + for (size_t n = 4; n <= 512; n *= 2) { + run(m, n, k); + } + } + std::cout << std::endl; + } +} + +#endif + +// vim: syntax=cpp.doxygen diff --git a/dnn/test/armv7/pooling.cpp b/dnn/test/armv7/pooling.cpp new file mode 100644 index 00000000..1b4c051d --- /dev/null +++ b/dnn/test/armv7/pooling.cpp @@ -0,0 +1,33 @@ +/** + * \file dnn/test/armv7/pooling.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
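The MK4 benchmark above feeds the kernel an A operand reshaped to {M/4, K/4, 4, 4}, i.e. the matrix is tiled into 4x4 blocks so the kernel can load contiguous vectors. The sketch below shows one plausible repacking of a row-major matrix into that layout; the exact in-block lane order used by MegDNN's packing code is an assumption here, not taken from the source.

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Repack row-major A (M x K, both divisible by 4) into an [M/4][K/4][4][4] layout.
    std::vector<int8_t> pack_mk4(const std::vector<int8_t>& a, size_t M, size_t K) {
        std::vector<int8_t> out(M * K);
        for (size_t mb = 0; mb < M / 4; ++mb)
            for (size_t kb = 0; kb < K / 4; ++kb)
                for (size_t mi = 0; mi < 4; ++mi)
                    for (size_t ki = 0; ki < 4; ++ki) {
                        size_t dst = ((mb * (K / 4) + kb) * 4 + mi) * 4 + ki;
                        out[dst] = a[(mb * 4 + mi) * K + (kb * 4 + ki)];
                    }
        return out;
    }

    int main() {
        std::vector<int8_t> a(8 * 8);
        for (size_t i = 0; i < a.size(); ++i)
            a[i] = static_cast<int8_t>(i);
        std::vector<int8_t> packed = pack_mk4(a, 8, 8);
        return packed.empty() ? 1 : 0;
    }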
+ */ +#include "test/armv7/fixture.h" + +#include "test/common/pooling.h" +#include "test/common/checker.h" + +namespace megdnn { +namespace test { + +TEST_F(ARMV7, POOLING) +{ + auto args = pooling::get_args(); + for (auto &&arg: args) { + Checker checker(handle()); + checker.set_param(arg.param).exec(TensorShapeArray{ + arg.ishape, {}}); + } +} + +} // namespace test +} // namespace megdnn +// vim: syntax=cpp.doxygen + + diff --git a/dnn/test/armv7/rotate.cpp b/dnn/test/armv7/rotate.cpp new file mode 100644 index 00000000..af8ef081 --- /dev/null +++ b/dnn/test/armv7/rotate.cpp @@ -0,0 +1,90 @@ +/** + * \file dnn/test/armv7/rotate.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "test/common/rotate.h" +#include "test/common/checker.h" +#include "test/common/benchmarker.h" + +#include "test/armv7/fixture.h" + +namespace megdnn { +namespace test { + +TEST_F(ARMV7, ROTATE) +{ + using namespace rotate; + std::vector args = get_args(); + Checker checker(handle()); + + for (auto &&arg: args) { + checker.set_param(arg.param) + .set_dtype(0, arg.dtype) + .set_dtype(1, arg.dtype) + .execs({arg.src, {}}); + } +} + +TEST_F(ARMV7, BENCHMARK_ROTATE) +{ + using namespace rotate; + using Param = param::Rotate; + +#define BENCHMARK_PARAM(benchmarker, dtype) \ + benchmarker.set_param(param); \ + benchmarker.set_dtype(0, dtype); + + auto run = [&](const TensorShapeArray& shapes, Param param) { + auto handle_naive = create_cpu_handle(2); + Benchmarker benchmarker(handle()); + Benchmarker benchmarker_naive(handle_naive.get()); + + BENCHMARK_PARAM(benchmarker, dtype::Uint8()); + BENCHMARK_PARAM(benchmarker_naive, dtype::Uint8()); + for (auto&& shape : shapes) { + printf("execute %s: current---naive\n", shape.to_string().c_str()); + benchmarker.execs({shape, {}}); + benchmarker_naive.execs({shape, {}}); + } + + BENCHMARK_PARAM(benchmarker, dtype::Int32()); + BENCHMARK_PARAM(benchmarker_naive, dtype::Int32()); + for (auto&& shape : shapes) { + printf("execute %s: current---naive\n", shape.to_string().c_str()); + benchmarker.execs({shape, {}}); + benchmarker_naive.execs({shape, {}}); + } + + BENCHMARK_PARAM(benchmarker, dtype::Float32()); + BENCHMARK_PARAM(benchmarker_naive, dtype::Float32()); + for (auto&& shape : shapes) { + printf("execute %s: current---naive\n", shape.to_string().c_str()); + benchmarker.execs({shape, {}}); + benchmarker_naive.execs({shape, {}}); + } + + }; + + Param param; + TensorShapeArray shapes = { + {1, 100, 100, 1}, + {2, 100, 100, 3}, + }; + + param.clockwise = true; + run(shapes, param); + + param.clockwise = false; + run(shapes, param); +} + + +} // namespace test +} // namespace megdnn +// vim: syntax=cpp.doxygen diff --git a/scripts/cmake-build/cross_build_android_arm_inference.sh b/scripts/cmake-build/cross_build_android_arm_inference.sh new file mode 100755 index 00000000..ddb359be --- /dev/null +++ b/scripts/cmake-build/cross_build_android_arm_inference.sh @@ -0,0 +1,174 @@ +#!/usr/bin/env bash +set -e + +ARCHS=("arm64-v8a" "armeabi-v7a") +BUILD_TYPE=Release +MGE_ARMV8_2_FEATURE_FP16=OFF +MGE_ARMV8_2_FEATURE_DOTPROD=OFF +MGE_DISABLE_FLOAT16=OFF +ARCH=arm64-v8a + +function usage() { + echo "$0 args1 args2 .." 
+ echo "available args detail:" + echo "-d : Build with Debug mode, defaule Release mode" + echo "-f : enable MGE_ARMV8_2_FEATURE_FP16 for ARM64, need toolchain and hardware support" + echo "-p : enable MGE_ARMV8_2_FEATURE_DOTPROD for ARM64, need toolchain and hardware support" + echo "-k : open MGE_DISABLE_FLOAT16 for NEON " + echo "-a : config build arch available: ${ARCHS[@]}" + echo "-h : show usage" + echo "example: $0 -d" + exit -1 +} + +while getopts "khdfpa:" arg +do + case $arg in + d) + echo "Build with Debug mode" + BUILD_TYPE=Debug + ;; + f) + echo "enable MGE_ARMV8_2_FEATURE_FP16 for ARM64" + MGE_ARMV8_2_FEATURE_FP16=ON + ;; + p) + echo "enable MGE_ARMV8_2_FEATURE_DOTPROD for ARM64" + MGE_ARMV8_2_FEATURE_DOTPROD=ON + ;; + k) + echo "open MGE_DISABLE_FLOAT16 for NEON" + MGE_DISABLE_FLOAT16=ON + ;; + a) + tmp_arch=null + for arch in ${ARCHS[@]}; do + if [ "$arch" = "$OPTARG" ]; then + echo "CONFIG BUILD ARCH to : $OPTARG" + tmp_arch=$OPTARG + ARCH=$OPTARG + break + fi + done + if [ "$tmp_arch" = "null" ]; then + echo "ERR args for arch (-a)" + echo "available arch list: ${ARCHS[@]}" + usage + fi + ;; + h) + echo "show usage" + usage + ;; + ?) + echo "unkonw argument" + usage + ;; + esac +done +echo "----------------------------------------------------" +echo "build config summary:" +echo "BUILD_TYPE: $BUILD_TYPE" +echo "MGE_ARMV8_2_FEATURE_FP16: $MGE_ARMV8_2_FEATURE_FP16" +echo "MGE_ARMV8_2_FEATURE_DOTPROD: $MGE_ARMV8_2_FEATURE_DOTPROD" +echo "MGE_DISABLE_FLOAT16: $MGE_DISABLE_FLOAT16" +echo "ARCH: $ARCH" +echo "----------------------------------------------------" + +READLINK=readlink +OS=$(uname -s) + +if [ $OS = "Darwin" ];then + READLINK=greadlink +fi + +SRC_DIR=$($READLINK -f "`dirname $0`/../../") + +if [ -z $NDK_ROOT ];then + echo "can not find NDK_ROOT env, pls export you NDK root dir to NDK_ROOT" + exit -1 +fi + +function cmake_build() { + BUILD_DIR=$SRC_DIR/build_dir/android/$1/$BUILD_TYPE/build + INSTALL_DIR=$BUILD_DIR/../install + BUILD_ABI=$1 + BUILD_NATIVE_LEVEL=$2 + echo "build dir: $BUILD_DIR" + echo "install dir: $INSTALL_DIR" + echo "build type: $BUILD_TYPE" + echo "build ABI: $BUILD_ABI" + echo "build native level: $BUILD_NATIVE_LEVEL" + if [ -e $BUILD_DIR ];then + echo "clean old dir: $BUILD_DIR" + rm -rf $BUILD_DIR + fi + if [ -e $INSTALL_DIR ];then + echo "clean old dir: $INSTALL_DIR" + rm -rf $INSTALL_DIR + fi + + echo "create build dir" + mkdir -p $BUILD_DIR + mkdir -p $INSTALL_DIR + cd $BUILD_DIR + cmake -DCMAKE_TOOLCHAIN_FILE="$NDK_ROOT/build/cmake/android.toolchain.cmake" \ + -DANDROID_NDK="$NDK_ROOT" \ + -DCMAKE_BUILD_TYPE=$BUILD_TYPE \ + -DANDROID_ABI=$BUILD_ABI \ + -DANDROID_NATIVE_API_LEVEL=$BUILD_NATIVE_LEVEL \ + -DMGE_INFERENCE_ONLY=ON \ + -DMGE_WITH_CUDA=OFF \ + -DMGE_ARMV8_2_FEATURE_FP16= $MGE_ARMV8_2_FEATURE_FP16 \ + -DMGE_ARMV8_2_FEATURE_DOTPROD=$MGE_ARMV8_2_FEATURE_DOTPROD \ + -DMGE_DISABLE_FLOAT16=$MGE_DISABLE_FLOAT16 \ + -DCMAKE_INSTALL_PREFIX=$INSTALL_DIR \ + $SRC_DIR + + make -j$(nproc) + make install/strip +} + +function build_flatc() { + BUILD_DIR=$SRC_DIR/build_dir/host_flatc/build + INSTALL_DIR=$BUILD_DIR/../install + if [ -e $BUILD_DIR ];then + echo "clean old dir: $BUILD_DIR" + rm -rf $BUILD_DIR + fi + if [ -e $INSTALL_DIR ];then + echo "clean old dir: $INSTALL_DIR" + rm -rf $INSTALL_DIR + fi + + echo "create build dir" + mkdir -p $BUILD_DIR + mkdir -p $INSTALL_DIR + cd $BUILD_DIR + cmake -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_INSTALL_PREFIX=$INSTALL_DIR \ + -DFLATBUFFERS_BUILD_TESTS=OFF \ + -DFLATBUFFERS_BUILD_FLATHASH=OFF \ + 
-DFLATBUFFERS_BUILD_FLATLIB=OFF \ + -DFLATBUFFERS_LIBCXX_WITH_CLANG=OFF \ + $SRC_DIR/third_party/flatbuffers + + make -j$(nproc) + make install/strip +} +build_flatc + +api_level=16 +abi="armeabi-v7a with NEON" +IFS="" +if [ "$ARCH" = "arm64-v8a" ]; then + api_level=21 + abi="arm64-v8a" +elif [ "$ARCH" = "armeabi-v7a" ]; then + api_level=16 + abi="armeabi-v7a with NEON" +else + echo "ERR CONFIG ABORT NOW!!" + exit -1 +fi +cmake_build $abi $api_level diff --git a/scripts/cmake-build/cross_build_ios_arm_inference.sh b/scripts/cmake-build/cross_build_ios_arm_inference.sh new file mode 100755 index 00000000..53df15e2 --- /dev/null +++ b/scripts/cmake-build/cross_build_ios_arm_inference.sh @@ -0,0 +1,175 @@ +#!/usr/bin/env bash +set -e + +ARCHS=("arm64" "armv7") +BUILD_TYPE=Release +MGE_ARMV8_2_FEATURE_FP16=OFF +MGE_ARMV8_2_FEATURE_DOTPROD=OFF +MGE_DISABLE_FLOAT16=OFF +ARCH=arm64 + +function usage() { + echo "$0 args1 args2 .." + echo "available args detail:" + echo "-d : Build with Debug mode, defaule Release mode" + echo "-f : enable MGE_ARMV8_2_FEATURE_FP16 for ARM64, need toolchain and hardware support" + echo "-p : enable MGE_ARMV8_2_FEATURE_DOTPROD for ARM64, need toolchain and hardware support" + echo "-k : open MGE_DISABLE_FLOAT16 for NEON " + echo "-a : config build arch available: ${ARCHS[@]}" + echo "-h : show usage" + echo "example: $0 -d" + exit -1 +} + +while getopts "khdfpa:" arg +do + case $arg in + d) + echo "Build with Debug mode" + BUILD_TYPE=Debug + ;; + f) + echo "enable MGE_ARMV8_2_FEATURE_FP16 for ARM64" + MGE_ARMV8_2_FEATURE_FP16=ON + ;; + p) + echo "enable MGE_ARMV8_2_FEATURE_DOTPROD for ARM64" + MGE_ARMV8_2_FEATURE_DOTPROD=ON + ;; + k) + echo "open MGE_DISABLE_FLOAT16 for NEON" + MGE_DISABLE_FLOAT16=ON + ;; + a) + tmp_arch=null + for arch in ${ARCHS[@]}; do + if [ "$arch" = "$OPTARG" ]; then + echo "CONFIG BUILD ARCH to : $OPTARG" + tmp_arch=$OPTARG + ARCH=$OPTARG + break + fi + done + if [ "$tmp_arch" = "null" ]; then + echo "ERR args for arch (-a)" + echo "available arch list: ${ARCHS[@]}" + usage + fi + ;; + h) + echo "show usage" + usage + ;; + ?) + echo "unkonw argument" + usage + ;; + esac +done +echo "----------------------------------------------------" +echo "build config summary:" +echo "BUILD_TYPE: $BUILD_TYPE" +echo "MGE_ARMV8_2_FEATURE_FP16: $MGE_ARMV8_2_FEATURE_FP16" +echo "MGE_ARMV8_2_FEATURE_DOTPROD: $MGE_ARMV8_2_FEATURE_DOTPROD" +echo "MGE_DISABLE_FLOAT16: $MGE_DISABLE_FLOAT16" +echo "ARCH: $ARCH" +echo "----------------------------------------------------" + +READLINK=readlink +OS=$(uname -s) + +if [ $OS = "Darwin" ];then + READLINK=greadlink +else + echo "cross build ios only support on macos, abort now!!" 
+ exit -1 +fi + +SRC_DIR=$($READLINK -f "`dirname $0`/../../") + +function cmake_build() { + BUILD_DIR=$SRC_DIR/build_dir/apple/$3/$4/$1/$BUILD_TYPE/build + INSTALL_DIR=$BUILD_DIR/../install + TOOLCHAIN=$SRC_DIR/toolchains/$2 + OS_PLATFORM=$3 + XCODE_IOS_PLATFORM=$4 + IOS_ARCH=$1 + echo "build dir: $BUILD_DIR" + echo "install dir: $INSTALL_DIR" + echo "build type: $BUILD_TYPE" + echo "build toolchain: $TOOLCHAIN" + echo "build OS_PLATFORM: $OS_PLATFORM" + echo "build XCODE_IOS_PLATFORM: $XCODE_IOS_PLATFORM" + echo "build IOS_ARCH: $IOS_ARCH" + if [ -e $BUILD_DIR ];then + echo "clean old dir: $BUILD_DIR" + rm -rf $BUILD_DIR + fi + if [ -e $INSTALL_DIR ];then + echo "clean old dir: $INSTALL_DIR" + rm -rf $INSTALL_DIR + fi + + echo "create build dir" + mkdir -p $BUILD_DIR + mkdir -p $INSTALL_DIR + cd $BUILD_DIR + cmake -DCMAKE_TOOLCHAIN_FILE=$TOOLCHAIN \ + -DCMAKE_BUILD_TYPE=$BUILD_TYPE \ + -DIOS_TOOLCHAIN_ROOT=$TOOLCHAIN \ + -DOS_PLATFORM=$OS_PLATFORM \ + -DXCODE_IOS_PLATFORM=$XCODE_IOS_PLATFORM \ + -DIOS_ARCH=$IOS_ARCH \ + -DMGE_INFERENCE_ONLY=ON \ + -DPYTHON_EXECUTABLE=/usr/local/bin/python3 \ + -DMGE_WITH_CUDA=OFF \ + -DMGE_ARMV8_2_FEATURE_FP16= $MGE_ARMV8_2_FEATURE_FP16 \ + -DMGE_ARMV8_2_FEATURE_DOTPROD=$MGE_ARMV8_2_FEATURE_DOTPROD \ + -DMGE_DISABLE_FLOAT16=$MGE_DISABLE_FLOAT16 \ + -DCMAKE_INSTALL_PREFIX=$INSTALL_DIR \ + $SRC_DIR + + make -j$(nproc) + make install +} + +function build_flatc() { + BUILD_DIR=$SRC_DIR/build_dir/host_flatc/build + INSTALL_DIR=$BUILD_DIR/../install + if [ -e $BUILD_DIR ];then + echo "clean old dir: $BUILD_DIR" + rm -rf $BUILD_DIR + fi + if [ -e $INSTALL_DIR ];then + echo "clean old dir: $INSTALL_DIR" + rm -rf $INSTALL_DIR + fi + + echo "create build dir" + mkdir -p $BUILD_DIR + mkdir -p $INSTALL_DIR + cd $BUILD_DIR + cmake -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_INSTALL_PREFIX=$INSTALL_DIR \ + -DFLATBUFFERS_BUILD_TESTS=OFF \ + -DFLATBUFFERS_BUILD_FLATHASH=OFF \ + -DFLATBUFFERS_BUILD_FLATLIB=OFF \ + -DFLATBUFFERS_LIBCXX_WITH_CLANG=OFF \ + $SRC_DIR/third_party/flatbuffers + + make -j$(nproc) + make install/strip +} +build_flatc + +# refs for ../../toolchains/ios.toolchain.cmake +# to config this, if u want to build other, +# like simulator or for iwatch, please do manually modify +# OS_PLATFORM=("OS" "OS64" "SIMULATOR" "SIMULATOR64" "TVOS" "WATCHOS" "SIMULATOR_TVOS") +# XCODE_IOS_PLATFORM=("iphoneos" "iphonesimulator" "appletvos" "appletvsimulator" "watchos", "watchsimulator") +# IOS_ARCHS=("arm64" "armv7" "armv7k" "arm64e" "armv7s") + +#by defaut we only triger build arm64/armv7 for iphoneos +OS_PLATFORM=OS +XCODE_IOS_PLATFORM=iphoneos +cmake_build $ARCH ios.toolchain.cmake $OS_PLATFORM $XCODE_IOS_PLATFORM diff --git a/scripts/cmake-build/cross_build_linux_arm_inference.sh b/scripts/cmake-build/cross_build_linux_arm_inference.sh new file mode 100755 index 00000000..16cef54f --- /dev/null +++ b/scripts/cmake-build/cross_build_linux_arm_inference.sh @@ -0,0 +1,161 @@ +#!/usr/bin/env bash +set -e + +ARCHS=("arm64-v8a" "armeabi-v7a-softfp" "armeabi-v7a-hardfp") +BUILD_TYPE=Release +MGE_ARMV8_2_FEATURE_FP16=OFF +MGE_ARMV8_2_FEATURE_DOTPROD=OFF +MGE_DISABLE_FLOAT16=OFF +ARCH=arm64-v8a + +function usage() { + echo "$0 args1 args2 .." 
+ echo "available args detail:" + echo "-d : Build with Debug mode, defaule Release mode" + echo "-f : enable MGE_ARMV8_2_FEATURE_FP16 for ARM64, need toolchain and hardware support" + echo "-p : enable MGE_ARMV8_2_FEATURE_DOTPROD for ARM64, need toolchain and hardware support" + echo "-k : open MGE_DISABLE_FLOAT16 for NEON " + echo "-a : config build arch available: ${ARCHS[@]}" + echo "-h : show usage" + echo "example: $0 -d" + exit -1 +} + +while getopts "khdfpa:" arg +do + case $arg in + d) + echo "Build with Debug mode" + BUILD_TYPE=Debug + ;; + f) + echo "enable MGE_ARMV8_2_FEATURE_FP16 for ARM64" + MGE_ARMV8_2_FEATURE_FP16=ON + ;; + p) + echo "enable MGE_ARMV8_2_FEATURE_DOTPROD for ARM64" + MGE_ARMV8_2_FEATURE_DOTPROD=ON + ;; + k) + echo "open MGE_DISABLE_FLOAT16 for NEON" + MGE_DISABLE_FLOAT16=ON + ;; + a) + tmp_arch=null + for arch in ${ARCHS[@]}; do + if [ "$arch" = "$OPTARG" ]; then + echo "CONFIG BUILD ARCH to : $OPTARG" + tmp_arch=$OPTARG + ARCH=$OPTARG + break + fi + done + if [ "$tmp_arch" = "null" ]; then + echo "ERR args for arch (-a)" + echo "available arch list: ${ARCHS[@]}" + usage + fi + ;; + h) + echo "show usage" + usage + ;; + ?) + echo "unkonw argument" + usage + ;; + esac +done +echo "----------------------------------------------------" +echo "build config summary:" +echo "BUILD_TYPE: $BUILD_TYPE" +echo "MGE_ARMV8_2_FEATURE_FP16: $MGE_ARMV8_2_FEATURE_FP16" +echo "MGE_ARMV8_2_FEATURE_DOTPROD: $MGE_ARMV8_2_FEATURE_DOTPROD" +echo "MGE_DISABLE_FLOAT16: $MGE_DISABLE_FLOAT16" +echo "ARCH: $ARCH" +echo "----------------------------------------------------" + +READLINK=readlink +OS=$(uname -s) + +if [ $OS = "Darwin" ];then + READLINK=greadlink +fi + +SRC_DIR=$($READLINK -f "`dirname $0`/../../") +function cmake_build() { + BUILD_DIR=$SRC_DIR/build_dir/gnu-linux/$1/$BUILD_TYPE/build + INSTALL_DIR=$BUILD_DIR/../install + TOOLCHAIN=$SRC_DIR/toolchains/$2 + echo "build dir: $BUILD_DIR" + echo "install dir: $INSTALL_DIR" + echo "build type: $BUILD_TYPE" + echo "build toolchain: $TOOLCHAIN" + if [ -e $BUILD_DIR ];then + echo "clean old dir: $BUILD_DIR" + rm -rf $BUILD_DIR + fi + if [ -e $INSTALL_DIR ];then + echo "clean old dir: $INSTALL_DIR" + rm -rf $INSTALL_DIR + fi + + echo "create build dir" + mkdir -p $BUILD_DIR + mkdir -p $INSTALL_DIR + cd $BUILD_DIR + cmake -DCMAKE_TOOLCHAIN_FILE=$TOOLCHAIN \ + -DCMAKE_BUILD_TYPE=$BUILD_TYPE \ + -DMGE_INFERENCE_ONLY=ON \ + -DMGE_WITH_CUDA=OFF \ + -DMGE_ARMV8_2_FEATURE_FP16= $MGE_ARMV8_2_FEATURE_FP16 \ + -DMGE_ARMV8_2_FEATURE_DOTPROD=$MGE_ARMV8_2_FEATURE_DOTPROD \ + -DMGE_DISABLE_FLOAT16=$MGE_DISABLE_FLOAT16 \ + -DCMAKE_INSTALL_PREFIX=$INSTALL_DIR \ + $SRC_DIR + + make -j$(nproc) + make install/strip +} + +function build_flatc() { + BUILD_DIR=$SRC_DIR/build_dir/host_flatc/build + INSTALL_DIR=$BUILD_DIR/../install + if [ -e $BUILD_DIR ];then + echo "clean old dir: $BUILD_DIR" + rm -rf $BUILD_DIR + fi + if [ -e $INSTALL_DIR ];then + echo "clean old dir: $INSTALL_DIR" + rm -rf $INSTALL_DIR + fi + + echo "create build dir" + mkdir -p $BUILD_DIR + mkdir -p $INSTALL_DIR + cd $BUILD_DIR + cmake -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_INSTALL_PREFIX=$INSTALL_DIR \ + -DFLATBUFFERS_BUILD_TESTS=OFF \ + -DFLATBUFFERS_BUILD_FLATHASH=OFF \ + -DFLATBUFFERS_BUILD_FLATLIB=OFF \ + -DFLATBUFFERS_LIBCXX_WITH_CLANG=OFF \ + $SRC_DIR/third_party/flatbuffers + + make -j$(nproc) + make install/strip +} +build_flatc + +toolchain=null +if [ "$ARCH" = "arm64-v8a" ]; then + toolchain="aarch64-linux-gnu.toolchain.cmake" +elif [ "$ARCH" = "armeabi-v7a-hardfp" ]; then 
+ toolchain="arm-linux-gnueabihf.toolchain.cmake" +elif [ "$ARCH" = "armeabi-v7a-softfp" ]; then + toolchain="arm-linux-gnueabi.toolchain.cmake" +else + echo "ERR CONFIG ABORT NOW!!" + exit -1 +fi +cmake_build $ARCH $toolchain diff --git a/scripts/cmake-build/host_build.sh b/scripts/cmake-build/host_build.sh new file mode 100755 index 00000000..64c33dcc --- /dev/null +++ b/scripts/cmake-build/host_build.sh @@ -0,0 +1,94 @@ +#!/usr/bin/env bash +set -e + +function usage() { + echo "$0 args1 args2 .." + echo "available args detail:" + echo "-d : Build with Debug mode, defaule Release mode" + echo "-c : Build with CUDA, default without CUDA" + echo "-t : Build with training mode, default inference only" + echo "example: $0 -d" + exit -1 +} + +BUILD_TYPE=Release +MGE_WITH_CUDA=OFF +MGE_INFERENCE_ONLY=ON + +while getopts "dct" arg +do + case $arg in + d) + echo "Build with Debug mode" + BUILD_TYPE=Debug + ;; + c) + echo "Build with CUDA" + MGE_WITH_CUDA=ON + ;; + t) + echo "Build with training mode" + MGE_INFERENCE_ONLY=OFF + ;; + ?) + echo "unkonw argument" + usage + ;; + esac +done +echo "------------------------------------" +echo "build config summary:" +echo "BUILD_TYPE: $BUILD_TYPE" +echo "MGE_WITH_CUDA: $MGE_WITH_CUDA" +echo "MGE_INFERENCE_ONLY: $MGE_INFERENCE_ONLY" +echo "------------------------------------" +READLINK=readlink +OS=$(uname -s) + +if [ $OS = "Darwin" ];then + READLINK=greadlink + if [ $MGE_WITH_CUDA = "ON" ];then + echo "MACOS DO NOT SUPPORT TensorRT, ABORT NOW!!" + exit -1 + fi +fi + +SRC_DIR=$($READLINK -f "`dirname $0`/../../") + +function cmake_build() { + BUILD_DIR=$SRC_DIR/build_dir/host/MGE_WITH_CUDA_$1/MGE_INFERENCE_ONLY_$2/$3/build + INSTALL_DIR=$BUILD_DIR/../install + MGE_WITH_CUDA=$1 + MGE_INFERENCE_ONLY=$2 + BUILD_TYPE=$3 + echo "build dir: $BUILD_DIR" + echo "install dir: $INSTALL_DIR" + echo "build type: $BUILD_TYPE" + echo "MGE_WITH_CUDA: $MGE_WITH_CUDA" + echo "MGE_INFERENCE_ONLY: $MGE_INFERENCE_ONLY" + if [ -e $BUILD_DIR ];then + echo "clean old dir: $BUILD_DIR" + rm -rf $BUILD_DIR + fi + if [ -e $INSTALL_DIR ];then + echo "clean old dir: $INSTALL_DIR" + rm -rf $INSTALL_DIR + fi + + echo "create build dir" + mkdir -p $BUILD_DIR + mkdir -p $INSTALL_DIR + cd $BUILD_DIR + cmake \ + -DCMAKE_BUILD_TYPE=$BUILD_TYPE \ + -DMGE_INFERENCE_ONLY=$MGE_INFERENCE_ONLY \ + -DMGE_WITH_CUDA=$MGE_WITH_CUDA \ + -DCMAKE_INSTALL_PREFIX=$INSTALL_DIR \ + $SRC_DIR + + make -j$(nproc) + make install/strip + } + +cmake_build $MGE_WITH_CUDA $MGE_INFERENCE_ONLY $BUILD_TYPE + diff --git a/src/megbrain_build_config.h.in b/src/megbrain_build_config.h.in index 3e0327bb..6b853b68 100644 --- a/src/megbrain_build_config.h.in +++ b/src/megbrain_build_config.h.in @@ -32,6 +32,9 @@ // Platform macro's #cmakedefine01 MEGDNN_WITH_CUDA +#cmakedefine01 MEGDNN_ARMV7 +#cmakedefine01 MEGDNN_AARCH64 +#cmakedefine01 MEGDNN_ENABLE_FP16_NEON #cmakedefine01 MEGDNN_X86_WITH_MKL #cmakedefine01 MEGDNN_X86_WITH_OPENBLAS diff --git a/toolchains/aarch64-linux-gnu.toolchain.cmake b/toolchains/aarch64-linux-gnu.toolchain.cmake new file mode 100644 index 00000000..fc502708 --- /dev/null +++ b/toolchains/aarch64-linux-gnu.toolchain.cmake @@ -0,0 +1,6 @@ +set(ARM_CROSS_BUILD_ARCH aarch64) +set(CMAKE_C_COMPILER "aarch64-linux-gnu-gcc") +set(CMAKE_CXX_COMPILER "aarch64-linux-gnu-g++") +set(CMAKE_C_FLAGS "-Werror=unused-parameter -Wno-psabi") +set(CMAKE_CXX_FLAGS "-Werror=unused-parameter -Wno-psabi") +set(CMAKE_STRIP "aarch64-linux-gnu-strip") diff --git a/toolchains/arm-linux-gnueabi.toolchain.cmake 
b/toolchains/arm-linux-gnueabi.toolchain.cmake new file mode 100644 index 00000000..118fa941 --- /dev/null +++ b/toolchains/arm-linux-gnueabi.toolchain.cmake @@ -0,0 +1,6 @@ +set(ARM_CROSS_BUILD_ARCH armv7) +set(CMAKE_C_COMPILER "arm-linux-gnueabi-gcc") +set(CMAKE_CXX_COMPILER "arm-linux-gnueabi-g++") +set(CMAKE_C_FLAGS "-mfloat-abi=softfp -mfpu=neon-vfpv4 -Werror=unused-parameter -Wno-psabi") +set(CMAKE_CXX_FLAGS "-mfloat-abi=softfp -mfpu=neon-vfpv4 -Werror=unused-parameter -Wno-psabi") +set(CMAKE_STRIP "arm-linux-gnueabi-strip") diff --git a/toolchains/arm-linux-gnueabihf.toolchain.cmake b/toolchains/arm-linux-gnueabihf.toolchain.cmake new file mode 100644 index 00000000..ce691fea --- /dev/null +++ b/toolchains/arm-linux-gnueabihf.toolchain.cmake @@ -0,0 +1,6 @@ +set(ARM_CROSS_BUILD_ARCH armv7) +set(CMAKE_C_COMPILER "arm-linux-gnueabihf-gcc") +set(CMAKE_CXX_COMPILER "arm-linux-gnueabihf-g++") +set(CMAKE_C_FLAGS "-mfloat-abi=hard -mfpu=neon-vfpv4 -Werror=unused-parameter -Wno-psabi") +set(CMAKE_CXX_FLAGS "-mfloat-abi=hard -mfpu=neon-vfpv4 -Werror=unused-parameter -Wno-psabi") +set(CMAKE_STRIP "arm-linux-gnueabihf-strip") diff --git a/toolchains/ios.toolchain.cmake b/toolchains/ios.toolchain.cmake new file mode 100644 index 00000000..6109cf55 --- /dev/null +++ b/toolchains/ios.toolchain.cmake @@ -0,0 +1,486 @@ +# This file is part of the ios-cmake project. It was retrieved from +# https://github.com/cristeab/ios-cmake.git, which is a fork of +# https://code.google.com/p/ios-cmake/. Which in turn is based off of +# the Platform/Darwin.cmake and Platform/UnixPaths.cmake files which +# are included with CMake 2.8.4 +# +# The ios-cmake project is licensed under the new BSD license. +# +# Copyright (c) 2014, Bogdan Cristea and LTE Engineering Software, +# Kitware, Inc., Insight Software Consortium. All rights reserved. +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS +# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE +# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN +# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +# +# This file is based off of the Platform/Darwin.cmake and +# Platform/UnixPaths.cmake files which are included with CMake 2.8.4 +# It has been altered for iOS development. 
+# +# Updated by Alex Stewart (alexs.mac@gmail.com) +# +# ***************************************************************************** +# Now maintained by Alexander Widerberg (widerbergaren [at] gmail.com) +# under the BSD-3-Clause license +# https://github.com/leetal/ios-cmake +# ***************************************************************************** +# +# INFORMATION / HELP +# +# The following variables control the behaviour of this toolchain: +# +# IOS_PLATFORM: OS (default) or SIMULATOR or SIMULATOR64 or TVOS or SIMULATOR_TVOS or WATCHOS or SIMULATOR_WATCHOS +# OS = Build for iPhoneOS. +# OS64 = Build for arm64 arm64e iPhoneOS. +# SIMULATOR = Build for x86 i386 iPhone Simulator. +# SIMULATOR64 = Build for x86_64 iPhone Simulator. +# TVOS = Build for AppleTVOS. +# SIMULATOR_TVOS = Build for x86_64 AppleTV Simulator. +# WATCHOS = Build for armv7k arm64_32 for WatchOS. +# SIMULATOR_WATCHOS = Build for x86_64 for Watch Simulator. +# CMAKE_OSX_SYSROOT: Path to the iOS SDK to use. By default this is +# automatically determined from IOS_PLATFORM and xcodebuild, but +# can also be manually specified (although this should not be required). +# CMAKE_IOS_DEVELOPER_ROOT: Path to the Developer directory for the iOS platform +# being compiled for. By default this is automatically determined from +# CMAKE_OSX_SYSROOT, but can also be manually specified (although this should +# not be required). +# ENABLE_BITCODE: (1|0) Enables or disables bitcode support. Default 1 (true) +# ENABLE_ARC: (1|0) Enables or disables ARC support. Default 1 (true, ARC enabled by default) +# ENABLE_VISIBILITY: (1|0) Enables or disables symbol visibility support. Default 0 (false, visibility hidden by default) +# IOS_ARCH: (armv7 armv7s armv7k arm64 arm64e arm64_32 i386 x86_64) If specified, will override the default architectures for the given IOS_PLATFORM +# OS = armv7 armv7s arm64 arm64e (if applicable) +# OS64 = arm64 arm64e (if applicable) +# SIMULATOR = i386 x86_64 +# SIMULATOR64 = x86_64 +# TVOS = arm64 +# SIMULATOR_TVOS = x86_64 (i386 has since long been deprecated) +# WATCHOS = armv7k arm64_32 (if applicable) +# SIMULATOR_WATCHOS = x86_64 (i386 has since long been deprecated) +# +# This toolchain defines the following variables for use externally: +# +# XCODE_VERSION: Version number (not including Build version) of Xcode detected. +# IOS_SDK_VERSION: Version of iOS SDK being used. +# CMAKE_OSX_ARCHITECTURES: Architectures being compiled for (generated from +# IOS_PLATFORM). +# +# This toolchain defines the following macros for use externally: +# +# set_xcode_property (TARGET XCODE_PROPERTY XCODE_VALUE XCODE_VARIANT) +# A convenience macro for setting xcode specific properties on targets. +# Available variants are: All, Release, RelWithDebInfo, Debug, MinSizeRel +# example: set_xcode_property (myioslib IPHONEOS_DEPLOYMENT_TARGET "3.1" "all"). +# +# find_host_package (PROGRAM ARGS) +# A macro used to find executable programs on the host system, not within the +# iOS environment. Thanks to the android-cmake project for providing the +# command. + +# Fix for PThread library not in path +set(CMAKE_THREAD_LIBS_INIT "-lpthread") +set(CMAKE_HAVE_THREADS_LIBRARY 1) +set(CMAKE_USE_WIN32_THREADS_INIT 0) +set(CMAKE_USE_PTHREADS_INIT 1) + +# Get the Xcode version being used. 
+execute_process(COMMAND xcodebuild -version + OUTPUT_VARIABLE XCODE_VERSION + ERROR_QUIET + OUTPUT_STRIP_TRAILING_WHITESPACE) +string(REGEX MATCH "Xcode [0-9\\.]+" XCODE_VERSION "${XCODE_VERSION}") +string(REGEX REPLACE "Xcode ([0-9\\.]+)" "\\1" XCODE_VERSION "${XCODE_VERSION}") +message(STATUS "Building with Xcode version: ${XCODE_VERSION}") +# Default to building for iPhoneOS if not specified otherwise, and we cannot +# determine the platform from the CMAKE_OSX_ARCHITECTURES variable. The use +# of CMAKE_OSX_ARCHITECTURES is such that try_compile() projects can correctly +# determine the value of IOS_PLATFORM from the root project, as +# CMAKE_OSX_ARCHITECTURES is propagated to them by CMake. +if (NOT DEFINED IOS_PLATFORM) + if (CMAKE_OSX_ARCHITECTURES) + if (CMAKE_OSX_ARCHITECTURES MATCHES ".*arm.*") + set(IOS_PLATFORM "OS") + elseif (CMAKE_OSX_ARCHITECTURES MATCHES "i386") + set(IOS_PLATFORM "SIMULATOR") + elseif (CMAKE_OSX_ARCHITECTURES MATCHES "x86_64") + set(IOS_PLATFORM "SIMULATOR64") + elseif (CMAKE_OSX_ARCHITECTURES MATCHES "armv7k") + set(IOS_PLATFORM "WATCHOS") + endif() + endif() + if (NOT IOS_PLATFORM) + set(IOS_PLATFORM "OS") + endif() +endif() +set(IOS_PLATFORM ${IOS_PLATFORM} CACHE STRING + "Type of iOS platform for which to build.") +# Determine the platform name and architectures for use in xcodebuild commands +# from the specified IOS_PLATFORM name. +if (IOS_PLATFORM STREQUAL "OS") + set(XCODE_IOS_PLATFORM iphoneos) + if(NOT IOS_ARCH) + if (XCODE_VERSION VERSION_GREATER 10.0) + set(IOS_ARCH armv7 armv7s arm64 arm64e) + else() + set(IOS_ARCH armv7 armv7s arm64) + endif() + endif() + elseif (IOS_PLATFORM STREQUAL "OS64") + set(XCODE_IOS_PLATFORM iphoneos) + if(NOT IOS_ARCH) + if (XCODE_VERSION VERSION_GREATER 10.0) + set(IOS_ARCH arm64 arm64e) + else() + set(IOS_ARCH arm64) + endif() + endif() +elseif (IOS_PLATFORM STREQUAL "SIMULATOR") + set(XCODE_IOS_PLATFORM iphonesimulator) + if(NOT IOS_ARCH) + set(IOS_ARCH i386 x86_64) + endif() +elseif(IOS_PLATFORM STREQUAL "SIMULATOR64") + set(XCODE_IOS_PLATFORM iphonesimulator) + if(NOT IOS_ARCH) + set(IOS_ARCH x86_64) + endif() +elseif (IOS_PLATFORM STREQUAL "TVOS") + set(XCODE_IOS_PLATFORM appletvos) + if(NOT IOS_ARCH) + set(IOS_ARCH arm64) + endif() +elseif (IOS_PLATFORM STREQUAL "SIMULATOR_TVOS") + set(XCODE_IOS_PLATFORM appletvsimulator) + if(NOT IOS_ARCH) + set(IOS_ARCH x86_64) + endif() +elseif (IOS_PLATFORM STREQUAL "WATCHOS") + set(XCODE_IOS_PLATFORM watchos) + if(NOT IOS_ARCH) + if (XCODE_VERSION VERSION_GREATER 10.0) + set(IOS_ARCH armv7k arm64_32) + else() + set(IOS_ARCH armv7k) + endif() + endif() +elseif (IOS_PLATFORM STREQUAL "SIMULATOR_WATCHOS") + set(XCODE_IOS_PLATFORM watchsimulator) + if(NOT IOS_ARCH) + set(IOS_ARCH x86_64) + endif() +else() + message(FATAL_ERROR "Invalid IOS_PLATFORM: ${IOS_PLATFORM}") +endif() +message(STATUS "Configuring iOS build for platform: ${IOS_PLATFORM}, " + "architecture(s): ${IOS_ARCH}") +# If user did not specify the SDK root to use, then query xcodebuild for it. +execute_process(COMMAND xcodebuild -version -sdk ${XCODE_IOS_PLATFORM} Path + OUTPUT_VARIABLE CMAKE_OSX_SYSROOT_INT + OUTPUT_QUIET ERROR_QUIET + OUTPUT_STRIP_TRAILING_WHITESPACE) +# If user did not specify the SDK root to use, then query xcodebuild for it. 
+if (NOT DEFINED CMAKE_OSX_SYSROOT OR (NOT CMAKE_OSX_SYSROOT STREQUAL CMAKE_OSX_SYSROOT_INT)) + execute_process(COMMAND xcodebuild -version -sdk ${XCODE_IOS_PLATFORM} Path + OUTPUT_VARIABLE CMAKE_OSX_SYSROOT + ERROR_QUIET + OUTPUT_STRIP_TRAILING_WHITESPACE) +endif() +if (NOT EXISTS ${CMAKE_OSX_SYSROOT}) + message(SEND_ERROR "Please make sure that Xcode is installed and that the toolchain" + "is pointing to the correct path. Please run:" + "sudo xcode-select -s /Applications/Xcode.app/Contents/Developer" + "and see if that fixes the problem for you.") + message(FATAL_ERROR "Invalid CMAKE_OSX_SYSROOT: ${CMAKE_OSX_SYSROOT} " + "does not exist.") +elseif(DEFINED CMAKE_OSX_SYSROOT) + message(STATUS "Using manually set SDK path: ${CMAKE_OSX_SYSROOT} for platform: ${IOS_PLATFORM}") +else() + message(STATUS "Using SDK: ${CMAKE_OSX_SYSROOT} for platform: ${IOS_PLATFORM}") +endif() +# Specify minimum version of deployment target. +if (NOT DEFINED IOS_DEPLOYMENT_TARGET) + if (IOS_PLATFORM STREQUAL "WATCHOS" OR IOS_PLATFORM STREQUAL "SIMULATOR_WATCHOS") + # Unless specified, SDK version 2.0 is used by default as minimum target version (watchOS). + set(IOS_DEPLOYMENT_TARGET "2.0" + CACHE STRING "Minimum iOS version to build for." ) + else() + # Unless specified, SDK version 8.0 is used by default as minimum target version (iOS, tvOS). + set(IOS_DEPLOYMENT_TARGET "8.0" + CACHE STRING "Minimum iOS version to build for." ) + endif() + message(STATUS "Using the default min-version since IOS_DEPLOYMENT_TARGET not provided!") +endif() +# Use bitcode or not +if (NOT DEFINED ENABLE_BITCODE AND NOT IOS_ARCH MATCHES "((^|, )(i386|x86_64))+") + # Unless specified, enable bitcode support by default + set(ENABLE_BITCODE TRUE CACHE BOOL "Whether or not to enable bitcode") + message(STATUS "Enabling bitcode support by default. ENABLE_BITCODE not provided!") +endif() +if (NOT DEFINED ENABLE_BITCODE) + message(STATUS "Disabling bitcode support by default on simulators. ENABLE_BITCODE not provided for override!") +endif() +# Use ARC or not +if (NOT DEFINED ENABLE_ARC) + # Unless specified, enable ARC support by default + set(ENABLE_ARC TRUE CACHE BOOL "Whether or not to enable ARC") + message(STATUS "Enabling ARC support by default. ENABLE_ARC not provided!") +endif() +# Use hidden visibility or not +if (NOT DEFINED ENABLE_VISIBILITY) + # Unless specified, disable symbols visibility by default + set(ENABLE_VISIBILITY FALSE CACHE BOOL "Whether or not to hide symbols (-fvisibility=hidden)") + message(STATUS "Hiding symbols visibility by default. ENABLE_VISIBILITY not provided!") +endif() +# Get the SDK version information. +execute_process(COMMAND xcodebuild -sdk ${CMAKE_OSX_SYSROOT} -version SDKVersion + OUTPUT_VARIABLE IOS_SDK_VERSION + ERROR_QUIET + OUTPUT_STRIP_TRAILING_WHITESPACE) +# Find the Developer root for the specific iOS platform being compiled for +# from CMAKE_OSX_SYSROOT. Should be ../../ from SDK specified in +# CMAKE_OSX_SYSROOT. There does not appear to be a direct way to obtain +# this information from xcrun or xcodebuild. +if (NOT CMAKE_IOS_DEVELOPER_ROOT) + get_filename_component(IOS_PLATFORM_SDK_DIR ${CMAKE_OSX_SYSROOT} PATH) + get_filename_component(CMAKE_IOS_DEVELOPER_ROOT ${IOS_PLATFORM_SDK_DIR} PATH) +endif() +if (NOT EXISTS ${CMAKE_IOS_DEVELOPER_ROOT}) + message(FATAL_ERROR "Invalid CMAKE_IOS_DEVELOPER_ROOT: " + "${CMAKE_IOS_DEVELOPER_ROOT} does not exist.") +endif() +# Find the C & C++ compilers for the specified SDK. 
+if (NOT CMAKE_C_COMPILER) + execute_process(COMMAND xcrun -sdk ${CMAKE_OSX_SYSROOT} -find clang + OUTPUT_VARIABLE CMAKE_C_COMPILER + ERROR_QUIET + OUTPUT_STRIP_TRAILING_WHITESPACE) + message(STATUS "Using C compiler: ${CMAKE_C_COMPILER}") +endif() +if (NOT CMAKE_CXX_COMPILER) + execute_process(COMMAND xcrun -sdk ${CMAKE_OSX_SYSROOT} -find clang++ + OUTPUT_VARIABLE CMAKE_CXX_COMPILER + ERROR_QUIET + OUTPUT_STRIP_TRAILING_WHITESPACE) + message(STATUS "Using CXX compiler: ${CMAKE_CXX_COMPILER}") +endif() +# Find (Apple's) libtool. +execute_process(COMMAND xcrun -sdk ${CMAKE_OSX_SYSROOT} -find libtool + OUTPUT_VARIABLE IOS_LIBTOOL + ERROR_QUIET + OUTPUT_STRIP_TRAILING_WHITESPACE) +message(STATUS "Using libtool: ${IOS_LIBTOOL}") +# Configure libtool to be used instead of ar + ranlib to build static libraries. +# This is required on Xcode 7+, but should also work on previous versions of +# Xcode. +set(CMAKE_C_CREATE_STATIC_LIBRARY + "${IOS_LIBTOOL} -static -o ") +set(CMAKE_CXX_CREATE_STATIC_LIBRARY + "${IOS_LIBTOOL} -static -o ") +# Get the version of Darwin (OS X) of the host. +execute_process(COMMAND uname -r + OUTPUT_VARIABLE CMAKE_HOST_SYSTEM_VERSION + ERROR_QUIET + OUTPUT_STRIP_TRAILING_WHITESPACE) +# Standard settings. +set(CMAKE_SYSTEM_NAME Darwin CACHE INTERNAL "") +set(CMAKE_SYSTEM_VERSION ${IOS_SDK_VERSION} CACHE INTERNAL "") +set(UNIX TRUE CACHE BOOL "") +set(APPLE TRUE CACHE BOOL "") +set(IOS TRUE CACHE BOOL "") +set(CMAKE_AR ar CACHE FILEPATH "" FORCE) +set(CMAKE_RANLIB ranlib CACHE FILEPATH "" FORCE) +# Force unset of OS X-specific deployment target (otherwise autopopulated), +# required as of cmake 2.8.10. +set(CMAKE_OSX_DEPLOYMENT_TARGET "" CACHE STRING + "Must be empty for iOS builds." FORCE) +# Set the architectures for which to build. +set(CMAKE_OSX_ARCHITECTURES ${IOS_ARCH} CACHE STRING "Build architecture for iOS") +# Change the type of target generated for try_compile() so it'll work when cross-compiling +set(CMAKE_TRY_COMPILE_TARGET_TYPE STATIC_LIBRARY) +# All iOS/Darwin specific settings - some may be redundant. +set(CMAKE_SHARED_LIBRARY_PREFIX "lib") +set(CMAKE_SHARED_LIBRARY_SUFFIX ".dylib") +set(CMAKE_SHARED_MODULE_PREFIX "lib") +set(CMAKE_SHARED_MODULE_SUFFIX ".so") +set(CMAKE_C_COMPILER_ABI ELF) +set(CMAKE_CXX_COMPILER_ABI ELF) +set(CMAKE_C_HAS_ISYSROOT 1) +set(CMAKE_CXX_HAS_ISYSROOT 1) +set(CMAKE_MODULE_EXISTS 1) +set(CMAKE_DL_LIBS "") +set(CMAKE_C_OSX_COMPATIBILITY_VERSION_FLAG "-compatibility_version ") +set(CMAKE_C_OSX_CURRENT_VERSION_FLAG "-current_version ") +set(CMAKE_CXX_OSX_COMPATIBILITY_VERSION_FLAG "${CMAKE_C_OSX_COMPATIBILITY_VERSION_FLAG}") +set(CMAKE_CXX_OSX_CURRENT_VERSION_FLAG "${CMAKE_C_OSX_CURRENT_VERSION_FLAG}") + +if(IOS_ARCH MATCHES "((^|, )(arm64|arm64e|x86_64))+") + set(CMAKE_C_SIZEOF_DATA_PTR 8) + set(CMAKE_CXX_SIZEOF_DATA_PTR 8) + message(STATUS "Using a data_ptr size of 8") +else() + set(CMAKE_C_SIZEOF_DATA_PTR 4) + set(CMAKE_CXX_SIZEOF_DATA_PTR 4) + message(STATUS "Using a data_ptr size of 4") +endif() + +message(STATUS "Building for minimum iOS version: ${IOS_DEPLOYMENT_TARGET}" + " (SDK version: ${IOS_SDK_VERSION})") +# Note that only Xcode 7+ supports the newer more specific: +# -m${XCODE_IOS_PLATFORM}-version-min flags, older versions of Xcode use: +# -m(ios/ios-simulator)-version-min instead. 
+if (IOS_PLATFORM STREQUAL "OS" OR IOS_PLATFORM STREQUAL "OS64") + if (XCODE_VERSION VERSION_LESS 7.0) + set(XCODE_IOS_PLATFORM_VERSION_FLAGS + "-mios-version-min=${IOS_DEPLOYMENT_TARGET}") + else() + # Xcode 7.0+ uses flags we can build directly from XCODE_IOS_PLATFORM. + set(XCODE_IOS_PLATFORM_VERSION_FLAGS + "-m${XCODE_IOS_PLATFORM}-version-min=${IOS_DEPLOYMENT_TARGET}") + endif() +elseif (IOS_PLATFORM STREQUAL "TVOS") + set(XCODE_IOS_PLATFORM_VERSION_FLAGS + "-mtvos-version-min=${IOS_DEPLOYMENT_TARGET}") +elseif (IOS_PLATFORM STREQUAL "SIMULATOR_TVOS") + set(XCODE_IOS_PLATFORM_VERSION_FLAGS + "-mtvos-simulator-version-min=${IOS_DEPLOYMENT_TARGET}") +elseif (IOS_PLATFORM STREQUAL "WATCHOS") + set(XCODE_IOS_PLATFORM_VERSION_FLAGS + "-mwatchos-version-min=${IOS_DEPLOYMENT_TARGET}") +elseif (IOS_PLATFORM STREQUAL "SIMULATOR_WATCHOS") + set(XCODE_IOS_PLATFORM_VERSION_FLAGS + "-mwatchos-simulator-version-min=${IOS_DEPLOYMENT_TARGET}") +else() + # SIMULATOR or SIMULATOR64 both use -mios-simulator-version-min. + set(XCODE_IOS_PLATFORM_VERSION_FLAGS + "-mios-simulator-version-min=${IOS_DEPLOYMENT_TARGET}") +endif() +message(STATUS "Version flags set to: ${XCODE_IOS_PLATFORM_VERSION_FLAGS}") + +if (ENABLE_BITCODE) + set(BITCODE "-fembed-bitcode") + set(HEADER_PAD "") + message(STATUS "Enabling bitcode support.") +else() + set(BITCODE "") + set(HEADER_PAD "-headerpad_max_install_names") + message(STATUS "Disabling bitcode support.") +endif() + +if (ENABLE_ARC) + set(FOBJC_ARC "-fobjc-arc") + message(STATUS "Enabling ARC support.") +else() + set(FOBJC_ARC "-fno-objc-arc") + message(STATUS "Disabling ARC support.") +endif() + +if (NOT ENABLE_VISIBILITY) + set(VISIBILITY "-fvisibility=hidden") + message(STATUS "Hiding symbols (-fvisibility=hidden).") +else() + set(VISIBILITY "") +endif() + +set(CMAKE_C_FLAGS +"${XCODE_IOS_PLATFORM_VERSION_FLAGS} ${BITCODE} -fobjc-abi-version=2 ${FOBJC_ARC} ${CMAKE_C_FLAGS}") +# Hidden visibilty is required for C++ on iOS. +set(CMAKE_CXX_FLAGS +"${XCODE_IOS_PLATFORM_VERSION_FLAGS} ${BITCODE} ${VISIBILITY} -fvisibility-inlines-hidden -fobjc-abi-version=2 ${FOBJC_ARC} ${CMAKE_CXX_FLAGS}") +set(CMAKE_CXX_FLAGS_MINSIZEREL "${CMAKE_CXX_FLAGS} -DNDEBUG -Os -ffast-math ${BITCODE} ${CMAKE_CXX_FLAGS_MINSIZEREL}") +set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS} -DNDEBUG -O2 -g -ffast-math ${BITCODE} ${CMAKE_CXX_FLAGS_RELWITHDEBINFO}") +set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS} -DNDEBUG -O3 -ffast-math ${BITCODE} ${CMAKE_CXX_FLAGS_RELEASE}") +set(CMAKE_C_LINK_FLAGS "${XCODE_IOS_PLATFORM_VERSION_FLAGS} -Wl,-search_paths_first ${CMAKE_C_LINK_FLAGS}") +set(CMAKE_CXX_LINK_FLAGS "${XCODE_IOS_PLATFORM_VERSION_FLAGS} -Wl,-search_paths_first ${CMAKE_CXX_LINK_FLAGS}") + +# In order to ensure that the updated compiler flags are used in try_compile() +# tests, we have to forcibly set them in the CMake cache, not merely set them +# in the local scope. 
+list(APPEND VARS_TO_FORCE_IN_CACHE + CMAKE_C_FLAGS + CMAKE_CXX_FLAGS + CMAKE_CXX_FLAGS_RELWITHDEBINFO + CMAKE_CXX_FLAGS_MINSIZEREL + CMAKE_CXX_FLAGS_RELEASE + CMAKE_C_LINK_FLAGS + CMAKE_CXX_LINK_FLAGS) +foreach(VAR_TO_FORCE ${VARS_TO_FORCE_IN_CACHE}) + set(${VAR_TO_FORCE} "${${VAR_TO_FORCE}}" CACHE STRING "") +endforeach() + +set(CMAKE_PLATFORM_HAS_INSTALLNAME 1) +set (CMAKE_SHARED_LINKER_FLAGS "-rpath @executable_path/Frameworks -rpath @loader_path/Frameworks") +set(CMAKE_SHARED_LIBRARY_CREATE_C_FLAGS "-dynamiclib ${HEADER_PAD}") +set(CMAKE_SHARED_MODULE_CREATE_C_FLAGS "-bundle ${HEADER_PAD}") +set(CMAKE_SHARED_MODULE_LOADER_C_FLAG "-Wl,-bundle_loader,") +set(CMAKE_SHARED_MODULE_LOADER_CXX_FLAG "-Wl,-bundle_loader,") +set(CMAKE_FIND_LIBRARY_SUFFIXES ".dylib" ".so" ".a") + +# Hack: if a new cmake (which uses CMAKE_INSTALL_NAME_TOOL) runs on an old +# build tree (where install_name_tool was hardcoded) and where +# CMAKE_INSTALL_NAME_TOOL isn't in the cache and still cmake didn't fail in +# CMakeFindBinUtils.cmake (because it isn't rerun) hardcode +# CMAKE_INSTALL_NAME_TOOL here to install_name_tool, so it behaves as it did +# before, Alex. +if (NOT DEFINED CMAKE_INSTALL_NAME_TOOL) + find_program(CMAKE_INSTALL_NAME_TOOL install_name_tool) +endif (NOT DEFINED CMAKE_INSTALL_NAME_TOOL) + +# Set the find root to the iOS developer roots and to user defined paths. +set(CMAKE_FIND_ROOT_PATH ${CMAKE_IOS_DEVELOPER_ROOT} ${CMAKE_OSX_SYSROOT} + ${CMAKE_PREFIX_PATH} CACHE STRING "iOS find search path root" FORCE) +# Default to searching for frameworks first. +set(CMAKE_FIND_FRAMEWORK FIRST) +# Set up the default search directories for frameworks. +set(CMAKE_SYSTEM_FRAMEWORK_PATH + ${CMAKE_OSX_SYSROOT}/System/Library/Frameworks + ${CMAKE_OSX_SYSROOT}/System/Library/PrivateFrameworks + ${CMAKE_OSX_SYSROOT}/Developer/Library/Frameworks) +# Only search the specified iOS SDK, not the remainder of the host filesystem. +set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM ONLY) +set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY) +set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY) +# This little macro lets you set any XCode specific property. +macro(set_xcode_property TARGET XCODE_PROPERTY XCODE_VALUE XCODE_RELVERSION) + set(XCODE_RELVERSION_I "${XCODE_RELVERSION}") + if (XCODE_RELVERSION_I STREQUAL "All") + set_property(TARGET ${TARGET} PROPERTY + XCODE_ATTRIBUTE_${XCODE_PROPERTY} "${XCODE_VALUE}") + else() + set_property(TARGET ${TARGET} PROPERTY + XCODE_ATTRIBUTE_${XCODE_PROPERTY}[variant=${XCODE_RELVERSION_I}] "${XCODE_VALUE}") + endif() +endmacro(set_xcode_property) +# This macro lets you find executable programs on the host system. +macro(find_host_package) + set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER) + set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY NEVER) + set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE NEVER) + set(IOS FALSE) + find_package(${ARGN}) + set(IOS TRUE) + set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM ONLY) + set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY) + set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY) +endmacro(find_host_package)
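
Usage sketch for the new build entry points (an illustrative addition, not part of the patch): the commands below assume a Linux host, are run from the repository root, and assume the aarch64-linux-gnu-gcc/g++ compilers named in the toolchain file above are on PATH for the cross build. The ARM cross-build script whose body appears above is referred to by the hypothetical placeholder name cross_build_arm.sh, since its filename is not visible in this excerpt; note that it first builds a host flatc into build_dir/host_flatc/install via build_flatc() before configuring the target build.

    # Native build with host_build.sh defaults: Release, inference-only, no CUDA.
    ./scripts/cmake-build/host_build.sh

    # Native Debug build with CUDA and training support enabled.
    ./scripts/cmake-build/host_build.sh -d -c -t

    # Cross build for aarch64 Linux (arm64-v8a) with the armv8.2 fp16 and dotprod
    # extensions enabled; needs a toolchain and target CPU that support them.
    # cross_build_arm.sh is a placeholder for the ARM cross-build script above.
    ./scripts/cmake-build/cross_build_arm.sh -a arm64-v8a -f -p

    # Roughly equivalent manual configure step using the new aarch64 toolchain file
    # (mirrors cmake_build() above; the host flatc must already have been built).
    cmake -S . -B build_aarch64 \
        -DCMAKE_TOOLCHAIN_FILE=toolchains/aarch64-linux-gnu.toolchain.cmake \
        -DCMAKE_BUILD_TYPE=Release -DMGE_INFERENCE_ONLY=ON -DMGE_WITH_CUDA=OFF
    cmake --build build_aarch64 -j"$(nproc)"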