@@ -19,11 +19,14 @@ CHECK_CXX_COMPILER_FLAG(-Wclass-memaccess CXX_SUPPORT_WCLASS_MEMACCESS) | |||||
# Target architecture; AUTO is resolved from the toolchain / host processor below.
set(MGE_ARCH AUTO CACHE STRING "Architecture on which MegEngine to be built.")
set_property(CACHE MGE_ARCH PROPERTY STRINGS AUTO
    x86_64 i386
    armv7 aarch64
    naive fallback
)
option(MGE_WITH_JIT "Build MegEngine with JIT." ON)
option(MGE_WITH_HALIDE "Build MegEngine with Halide JIT" ON)
option(MGE_ARMV8_2_FEATURE_FP16 "Enable armv8.2-a+fp16 support" OFF)
option(MGE_ARMV8_2_FEATURE_DOTPROD "enable armv8.2-a+dotprod support" OFF)
option(MGE_DISABLE_FLOAT16 "Disable MegEngine float16 support." OFF)
option(MGE_WITH_CUDA "Enable MegEngine CUDA support." ON)
option(MGE_CUDA_USE_STATIC "Enable MegEngine CUDA static linking." ON)
@@ -31,12 +34,52 @@ option(MGE_WITH_TRT "Build MegEngine with TensorRT." ON) | |||||
option(MGE_USE_SYSTEM_LIB "Build MegEngine with system libraries." OFF)
option(MGB_WITH_FLATBUFFERS "Build MegBrain with FlatBuffers serialization support." ON)

# When cross compiling, derive MGE_ARCH from the toolchain instead of the host.
if(CMAKE_TOOLCHAIN_FILE)
    message(STATUS "We are cross compiling.")
    message(STATUS "config FLATBUFFERS_FLATC_EXECUTABLE to: ${PROJECT_SOURCE_DIR}/build_dir/host_flatc/install/bin/flatc")
    # flatc must run on the build host, so point at the pre-built host binary.
    set(FLATBUFFERS_FLATC_EXECUTABLE "${PROJECT_SOURCE_DIR}/build_dir/host_flatc/install/bin/flatc")
    if(ANDROID_TOOLCHAIN_ROOT)
        if(NOT "${ANDROID_ARCH_NAME}" STREQUAL "")
            set(ANDROID_ARCH ${ANDROID_ARCH_NAME})
        endif()
        # Quote the expansion: if ANDROID_ARCH ends up empty/undefined, an
        # unquoted ${ANDROID_ARCH} makes if() a syntax error instead of
        # falling through to the FATAL_ERROR branch.
        if("${ANDROID_ARCH}" STREQUAL "arm")
            set(MGE_ARCH "armv7")
        elseif("${ANDROID_ARCH}" STREQUAL "arm64")
            set(MGE_ARCH "aarch64")
        else()
            message(FATAL_ERROR "DO NOT SUPPORT ANDROID ARCH NOW")
        endif()
    elseif(IOS_TOOLCHAIN_ROOT)
        # 32-bit iOS ARM variants (armv7/armv7k/armv7s) -> armv7;
        # 64-bit variants (arm64/arm64e) -> aarch64.
        if("${IOS_ARCH}" MATCHES "^armv7")
            set(MGE_ARCH "armv7")
        elseif("${IOS_ARCH}" MATCHES "^arm64")
            set(MGE_ARCH "aarch64")
        else()
            message(FATAL_ERROR "Unsupported IOS_ARCH.")
        endif()
    elseif(NOT "${ARM_CROSS_BUILD_ARCH}" STREQUAL "")
        set(MGE_ARCH ${ARM_CROSS_BUILD_ARCH})
    else()
        message(FATAL_ERROR "Unknown cross-compiling settings.")
    endif()
    message(STATUS "CONFIG MGE_ARCH TO ${MGE_ARCH}")
endif()
# Resolve MGE_ARCH=AUTO from the (possibly toolchain-provided) processor name.
if("${MGE_ARCH}" STREQUAL "AUTO")
    if("${CMAKE_SYSTEM_PROCESSOR}" STREQUAL "x86_64")
        set(MGE_ARCH "x86_64")
    elseif("${CMAKE_SYSTEM_PROCESSOR}" STREQUAL "i386" OR "${CMAKE_SYSTEM_PROCESSOR}" STREQUAL "i686")
        set(MGE_ARCH "i386")
    elseif("${CMAKE_SYSTEM_PROCESSOR}" STREQUAL "aarch64" OR "${CMAKE_SYSTEM_PROCESSOR}" STREQUAL "arm64")
        set(MGE_ARCH "aarch64")
    elseif("${CMAKE_SYSTEM_PROCESSOR}" MATCHES "^arm")
        set(MGE_ARCH "armv7")
    else()
        # FATAL_ERROR (not bare "FATAL", which is an ordinary message string)
        # is required to actually abort configuration on unknown machines.
        message(FATAL_ERROR "Unknown machine architecture for MegEngine.")
    endif()
@@ -399,6 +442,38 @@ if(MGE_ARCH STREQUAL "x86_64" OR MGE_ARCH STREQUAL "i386") | |||||
endif() | endif() | ||||
endif() | endif() | ||||
if("${MGE_ARCH}" STREQUAL "armv7")
    # -funsafe-math-optimizations to enable neon auto-vectorization (since neon
    # is not fully IEEE 754 compatible, GCC does not turn on neon
    # auto-vectorization by default).
    if(ANDROID)
        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mfloat-abi=softfp -mfpu=neon")
    endif()
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -funsafe-math-optimizations")
    set(MARCH "-march=armv7-a")
    set(MEGDNN_ARMV7 1)
endif()

if("${MGE_ARCH}" STREQUAL "aarch64")
    set(MEGDNN_AARCH64 1)
    set(MEGDNN_64_BIT 1)
    set(MARCH "-march=armv8-a")
    if(MGE_ARMV8_2_FEATURE_FP16)
        message(STATUS "Enable fp16 feature support in armv8.2")
        # NOT <name> (no ${}): if() dereferences the name itself; ${}-expanding
        # an option here is the classic double-dereference footgun.
        if(NOT MGE_DISABLE_FLOAT16)
            set(MEGDNN_ENABLE_FP16_NEON 1)
        endif()
        set(MARCH "-march=armv8.2-a+fp16")
    endif()
    if(MGE_ARMV8_2_FEATURE_DOTPROD)
        message(STATUS "Enable dotprod feature support in armv8.2")
        if(MGE_ARMV8_2_FEATURE_FP16)
            set(MARCH "-march=armv8.2-a+fp16+dotprod")
        else()
            set(MARCH "-march=armv8.2-a+dotprod")
        endif()
    endif()
endif()

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${MARCH}")
@@ -29,6 +29,9 @@ class Handle { | |||||
NAIVE = 0, | NAIVE = 0, | ||||
FALLBACK = 1, | FALLBACK = 1, | ||||
X86 = 2, | X86 = 2, | ||||
ARM_COMMON = 3, | |||||
ARMV7 = 4, | |||||
AARCH64 = 5, | |||||
CUDA = 6, | CUDA = 6, | ||||
}; | }; | ||||
@@ -17,6 +17,22 @@ if(NOT ${MGE_ARCH} STREQUAL "naive") | |||||
set_source_files_properties(${SOURCES_} PROPERTIES LANGUAGE C) | set_source_files_properties(${SOURCES_} PROPERTIES LANGUAGE C) | ||||
list(APPEND SOURCES ${SOURCES_}) | list(APPEND SOURCES ${SOURCES_}) | ||||
endif() | endif() | ||||
elseif("${MGE_ARCH}" STREQUAL "armv7")
    # armv7 backend: arch-specific kernels plus the shared arm_common layer.
    file(GLOB_RECURSE SOURCES_ armv7/*.cpp)
    list(APPEND SOURCES ${SOURCES_})
    file(GLOB_RECURSE SOURCES_ arm_common/*.cpp)
    list(APPEND SOURCES ${SOURCES_})
    # Assembly kernels are fed through the C compiler driver.
    file(GLOB_RECURSE SOURCES_ armv7/*.S)
    set_source_files_properties(${SOURCES_} PROPERTIES LANGUAGE C)
    list(APPEND SOURCES ${SOURCES_})
elseif("${MGE_ARCH}" STREQUAL "aarch64")
    # aarch64 backend: arch-specific kernels plus the shared arm_common layer.
    file(GLOB_RECURSE SOURCES_ aarch64/*.cpp)
    list(APPEND SOURCES ${SOURCES_})
    file(GLOB_RECURSE SOURCES_ arm_common/*.cpp)
    list(APPEND SOURCES ${SOURCES_})
    file(GLOB_RECURSE SOURCES_ aarch64/*.S)
    set_source_files_properties(${SOURCES_} PROPERTIES LANGUAGE C)
    list(APPEND SOURCES ${SOURCES_})
endif() | endif() | ||||
endif() | endif() | ||||
@@ -0,0 +1,139 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/conv_bias/fp16/algos.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "src/aarch64/conv_bias/fp16/algos.h" | |||||
#include "src/aarch64/conv_bias/fp16/stride2_kern.h" | |||||
#include "src/arm_common/conv_bias/direct/multi_thread_common.h" | |||||
#include "src/arm_common/conv_bias/postprocess_helper.h" | |||||
using namespace megdnn; | |||||
using namespace aarch64; | |||||
#include "midout.h" | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
/* ===================== stride-2 algo ===================== */ | |||||
MIDOUT_DECL(megdnn_aarch64_conv_bias_stride2_conv2357_fp16) | |||||
bool ConvBiasImpl::AlgoF16DirectStride2::usable(
        FallbackConvBiasImpl*, const NCBKernSizeParam& param,
        AlgoSelectionStrategy algo_selection_strategy) const {
    MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 0) {
        auto&& fm = param.filter_meta;
        auto FH = fm.spatial[0];
        //! Direct stride-2 fp16 conv: plain NCHW, fp16 everywhere, no filter
        //! flip, no dilation, square filter of size 2/3/5/7.
        //! (param::ConvBias::Format matches the fp32 counterpart; the original
        //! used param::Convolution::Format here.)
        bool available =
                param.filter_meta.format == param::ConvBias::Format::NCHW &&
                param.src_type.enumv() == DTypeEnum::Float16 &&
                param.filter_type.enumv() == DTypeEnum::Float16 &&
                param.dst_type.enumv() == DTypeEnum::Float16 &&
                !fm.should_flip && fm.spatial_ndim == 2 &&
                fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
                fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] &&
                (FH == 2 || FH == 3 || FH == 5 || FH == 7);
        if (algo_selection_strategy == AlgoSelectionStrategy::HEURISTIC) {
            //! Each algo instance serves exactly one group-size regime.
            bool large_group = param.filter_meta.group >= param.nr_threads;
            available &= (large_group == m_large_group);
        }
        return available;
    }
    MIDOUT_END();
    return false;
}
//! Workspace: the multithread direct-conv bundle (padded-src staging buffers).
size_t ConvBiasImpl::AlgoF16DirectStride2::get_workspace(
        FallbackConvBiasImpl*, const NCBKernSizeParam& param) const {
    MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 1) {
        auto wbundle = arm_common::MultithreadDirectConvCommon<
                dt_float16, __fp16>::get_bundle_stride(param, m_large_group);
        return wbundle.total_size_in_bytes();
    }
    MIDOUT_END();
    // Fallback when midout strips this path: no workspace. The original
    // returned `false`, which only worked via bool->size_t conversion; the
    // fp32 twin already returns 0.
    return 0;
}
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoF16DirectStride2::dispatch_kerns(
        FallbackConvBiasImpl*, const NCBKernSizeParam& param) const {
    // Use the fp16 midout tag declared above: the original used the fp32 key
    // here, so midout-based trimming would wrongly strip this fp16 kernel.
    MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 2) {
        return get_kimpls(param);
    }
    MIDOUT_END();
    return {};
}
//! Build the NCB kernel list for the fp16 stride-2 direct convolution.
//! Large groups run padding + conv fused per group; small groups split the
//! two phases so they also parallelize over channels.
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoF16DirectStride2::get_kimpls(
        const NCBKernSizeParam& param) const {
    auto fm = param.filter_meta;
    auto FH = fm.spatial[0];
    size_t batch = param.n;
    size_t icpg = param.filter_meta.icpg;
    size_t ocpg = param.filter_meta.ocpg;
    size_t nr_group = fm.group;
    using Func = std::function<void(const __fp16*, const __fp16*, __fp16*,
                                    size_t, size_t, size_t, size_t, size_t)>;
    //! Pick the direct kernel matching the (square) filter size; usable()
    //! guarantees FH is one of 2/3/5/7.
    Func conv_kern = nullptr;
    switch (FH) {
        case 2:
            conv_kern = fp16::conv_stride2::do_conv_2x2_stride2;
            break;
        case 3:
            conv_kern = fp16::conv_stride2::do_conv_3x3_stride2;
            break;
        case 5:
            conv_kern = fp16::conv_stride2::do_conv_5x5_stride2;
            break;
        case 7:
            conv_kern = fp16::conv_stride2::do_conv_7x7_stride2;
            break;
        default:
            break;
    }
    WorkspaceBundle whole_bundle = arm_common::MultithreadDirectConvCommon<
            dt_float16, __fp16>::get_bundle_stride(param, m_large_group);
    SmallVector<NCBKern> kerns;
    if (m_large_group) {
        //! Large groups: one invocation per (group, batch) pads every input
        //! channel then convolves every output channel.
        auto exec_group = [whole_bundle, conv_kern](
                                  const NCBKernParam& kern_param,
                                  const NCBKernIndex& ncb_index) {
            auto fm = kern_param.filter_meta;
            size_t nr_ic = fm.icpg;
            size_t nr_oc = fm.ocpg;
            WorkspaceBundle bundle = whole_bundle;
            for (size_t ic = 0; ic < nr_ic; ic++) {
                arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>::
                        copy_padding_kern_stride(bundle, kern_param, ncb_index,
                                                 {ncb_index.thread_id, 0, ic});
            }
            for (size_t oc = 0; oc < nr_oc; oc++) {
                arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>::
                        do_conv_kern_stride(bundle, kern_param, ncb_index,
                                            conv_kern,
                                            {ncb_index.thread_id, 0, oc});
            }
        };
        kerns.push_back({exec_group, {nr_group, batch, 1_z}});
    } else {
        //! Small groups: separate padding and conv kernels, parallel over
        //! input resp. output channels.
        WorkspaceBundle bundle = whole_bundle;
        auto pad_src = [bundle](const NCBKernParam& kern_param,
                                const NCBKernIndex& ncb_index) {
            arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>::
                    copy_padding_kern_stride(bundle, kern_param, ncb_index,
                                             ncb_index.ndrange_id);
        };
        kerns.push_back({pad_src, {nr_group, batch, icpg}});
        auto run_conv = [bundle, conv_kern](const NCBKernParam& kern_param,
                                            const NCBKernIndex& ncb_index) {
            arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>::
                    do_conv_kern_stride(bundle, kern_param, ncb_index,
                                        conv_kern, ncb_index.ndrange_id);
        };
        kerns.push_back({run_conv, {nr_group, batch, ocpg}});
    }
    return kerns;
}
#endif | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,42 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/conv_bias/fp16/algos.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#pragma once | |||||
#include "src/aarch64/conv_bias/opr_impl.h" | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
namespace megdnn { | |||||
namespace aarch64 { | |||||
/* ===================== stride-2 algo ===================== */ | |||||
//! Direct stride-2 fp16 convolution (filter sizes 2/3/5/7, NCHW).
//! Instantiated twice: once for the large-group regime and once for the
//! small-group regime; usable() ensures only the matching instance is picked.
class ConvBiasImpl::AlgoF16DirectStride2 final : public AlgoBase {
    //! Build the padding + direct-conv NCB kernel list for \p param.
    SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
    //! True when this instance targets group >= nr_threads workloads.
    bool m_large_group;

public:
    AlgoF16DirectStride2(bool large_group) : m_large_group(large_group) {}
    bool is_reproducible() const override { return true; }
    const char* name() const override {
        return m_large_group ? "ARMV8F16STRD2_LARGE_GROUP"
                             : "ARMV8F16STRD2_SMALL_GROUP";
    }
    bool usable(FallbackConvBiasImpl*, const NCBKernSizeParam& param,
                AlgoSelectionStrategy algo_selection_strategy) const override;
    size_t get_workspace(FallbackConvBiasImpl*,
                         const NCBKernSizeParam& param) const override;
    SmallVector<NCBKern> dispatch_kerns(FallbackConvBiasImpl*,
                                        const NCBKernSizeParam&) const override;
};
} // namespace aarch64 | |||||
} // namespace megdnn | |||||
#endif | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,137 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/conv_bias/fp32/algos.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "src/aarch64/conv_bias/fp32/algos.h" | |||||
#include "src/aarch64/conv_bias/fp32/stride2_kern.h" | |||||
#include "src/arm_common/conv_bias/direct/multi_thread_common.h" | |||||
#include "src/arm_common/conv_bias/postprocess_helper.h" | |||||
#include "src/fallback/conv_bias/common.h" | |||||
#include "midout.h" | |||||
using namespace megdnn; | |||||
using namespace aarch64; | |||||
MIDOUT_DECL(megdnn_aarch64_conv_bias_stride2_conv2357_fp32) | |||||
bool ConvBiasImpl::AlgoF32DirectStride2::usable(
        FallbackConvBiasImpl*, const NCBKernSizeParam& param,
        AlgoSelectionStrategy algo_selection_strategy) const {
    MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 0) {
        auto&& fm = param.filter_meta;
        auto FH = fm.spatial[0];
        //! Direct stride-2 fp32 conv: plain NCHW, fp32 everywhere, no filter
        //! flip, no dilation, square filter of size 2/3/5/7.
        bool available =
                param.filter_meta.format == param::ConvBias::Format::NCHW &&
                param.src_type.enumv() == DTypeEnum::Float32 &&
                param.filter_type.enumv() == DTypeEnum::Float32 &&
                param.dst_type.enumv() == DTypeEnum::Float32 &&
                !fm.should_flip && fm.spatial_ndim == 2 &&
                fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
                fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] &&
                (FH == 2 || FH == 3 || FH == 5 || FH == 7);
        if (algo_selection_strategy == AlgoSelectionStrategy::HEURISTIC) {
            //! Each algo instance serves exactly one group-size regime.
            bool large_group = param.filter_meta.group >= param.nr_threads;
            available &= (large_group == m_large_group);
        }
        return available;
    }
    MIDOUT_END();
    return false;
}
//! Workspace: the multithread direct-conv bundle (padded-src staging buffers).
size_t ConvBiasImpl::AlgoF32DirectStride2::get_workspace(
        FallbackConvBiasImpl*, const NCBKernSizeParam& param) const {
    MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 1) {
        return arm_common::MultithreadDirectConvCommon<float, float>::
                get_bundle_stride(param, m_large_group)
                        .total_size_in_bytes();
    }
    MIDOUT_END();
    return 0;
}
//! Delegate kernel construction to get_kimpls(); midout records the path.
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoF32DirectStride2::dispatch_kerns(
        FallbackConvBiasImpl*, const NCBKernSizeParam& param) const {
    MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 2) {
        return get_kimpls(param);
    }
    MIDOUT_END();
    return {};
}
//! Build the NCB kernel list for the fp32 stride-2 direct convolution.
//! Large groups run padding + conv fused per group; small groups split the
//! two phases so they also parallelize over channels.
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoF32DirectStride2::get_kimpls(
        const NCBKernSizeParam& param) const {
    auto fm = param.filter_meta;
    auto FH = fm.spatial[0];
    size_t batch = param.n;
    size_t icpg = param.filter_meta.icpg;
    size_t ocpg = param.filter_meta.ocpg;
    size_t nr_group = fm.group;
    using Func = std::function<void(const float*, const float*, float*, size_t,
                                    size_t, size_t, size_t, size_t)>;
    //! Pick the direct kernel matching the (square) filter size; usable()
    //! guarantees FH is one of 2/3/5/7.
    Func conv_kern = nullptr;
    switch (FH) {
        case 2:
            conv_kern = fp32::conv_stride2::do_conv_2x2_stride2;
            break;
        case 3:
            conv_kern = fp32::conv_stride2::do_conv_3x3_stride2;
            break;
        case 5:
            conv_kern = fp32::conv_stride2::do_conv_5x5_stride2;
            break;
        case 7:
            conv_kern = fp32::conv_stride2::do_conv_7x7_stride2;
            break;
        default:
            break;
    }
    WorkspaceBundle whole_bundle = arm_common::MultithreadDirectConvCommon<
            float, float>::get_bundle_stride(param, m_large_group);
    SmallVector<NCBKern> kerns;
    if (m_large_group) {
        //! Large groups: one invocation per (group, batch) pads every input
        //! channel then convolves every output channel.
        auto exec_group = [whole_bundle, conv_kern](
                                  const NCBKernParam& kern_param,
                                  const NCBKernIndex& ncb_index) {
            auto fm = kern_param.filter_meta;
            size_t nr_ic = fm.icpg;
            size_t nr_oc = fm.ocpg;
            WorkspaceBundle bundle = whole_bundle;
            for (size_t ic = 0; ic < nr_ic; ic++) {
                arm_common::MultithreadDirectConvCommon<float, float>::
                        copy_padding_kern_stride(bundle, kern_param, ncb_index,
                                                 {ncb_index.thread_id, 0, ic});
            }
            for (size_t oc = 0; oc < nr_oc; oc++) {
                arm_common::MultithreadDirectConvCommon<float, float>::
                        do_conv_kern_stride(bundle, kern_param, ncb_index,
                                            conv_kern,
                                            {ncb_index.thread_id, 0, oc});
            }
        };
        kerns.push_back({exec_group, {nr_group, batch, 1_z}});
    } else {
        //! Small groups: separate padding and conv kernels, parallel over
        //! input resp. output channels.
        WorkspaceBundle bundle = whole_bundle;
        auto pad_src = [bundle](const NCBKernParam& kern_param,
                                const NCBKernIndex& ncb_index) {
            arm_common::MultithreadDirectConvCommon<float, float>::
                    copy_padding_kern_stride(bundle, kern_param, ncb_index,
                                             ncb_index.ndrange_id);
        };
        kerns.push_back({pad_src, {nr_group, batch, icpg}});
        auto run_conv = [bundle, conv_kern](const NCBKernParam& kern_param,
                                            const NCBKernIndex& ncb_index) {
            arm_common::MultithreadDirectConvCommon<float, float>::
                    do_conv_kern_stride(bundle, kern_param, ncb_index,
                                        conv_kern, ncb_index.ndrange_id);
        };
        kerns.push_back({run_conv, {nr_group, batch, ocpg}});
    }
    return kerns;
}
@@ -0,0 +1,47 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/conv_bias/fp32/algos.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#pragma once | |||||
#include "src/aarch64/conv_bias/opr_impl.h" | |||||
#include "src/fallback/conv_bias/opr_impl.h" | |||||
namespace megdnn { | |||||
namespace aarch64 { | |||||
using FallbackConvBiasImpl = fallback::ConvBiasImpl; | |||||
/* ===================== stride-2 algo ===================== */ | |||||
//! Direct stride-2 fp32 convolution (filter sizes 2/3/5/7, NCHW).
//! Instantiated twice: once for the large-group regime and once for the
//! small-group regime; usable() ensures only the matching instance is picked.
class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase {
    //! Build the padding + direct-conv NCB kernel list for \p param.
    SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
    //! True when this instance targets group >= nr_threads workloads.
    bool m_large_group;

public:
    AlgoF32DirectStride2(bool large_group) : m_large_group(large_group) {}
    bool is_reproducible() const override { return true; }
    const char* name() const override {
        return m_large_group ? "ARMV8F32STRD2_LARGE_GROUP"
                             : "ARMV8F32STRD2_SMALL_GROUP";
    }
    bool usable(FallbackConvBiasImpl*, const NCBKernSizeParam& param,
                AlgoSelectionStrategy algo_selection_strategy) const override;
    size_t get_workspace(FallbackConvBiasImpl*,
                         const NCBKernSizeParam& param) const override;
    SmallVector<NCBKern> dispatch_kerns(FallbackConvBiasImpl*,
                                        const NCBKernSizeParam&) const override;
};
} // namespace aarch64 | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,187 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/conv_bias/int8/algos.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "src/aarch64/conv_bias/int8/algos.h" | |||||
#include "src/aarch64/conv_bias/int8/strategy.h" | |||||
#include "src/arm_common/convolution/img2col_helper.h" | |||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/common/opr_delegate.h" | |||||
#include "src/fallback/conv_bias/common.h" | |||||
#include "src/fallback/matrix_mul/gemm_impl.h" | |||||
#include "midout.h" | |||||
MIDOUT_DECL(megdnn_aarch64_conv_bias_int8_gemm) | |||||
using namespace megdnn; | |||||
using namespace aarch64; | |||||
using megdnn::arm_common::HSwishOp; | |||||
using megdnn::arm_common::ReluOp; | |||||
using megdnn::arm_common::TypeCvtOp; | |||||
/* ===================== matrix mul algo ===================== */ | |||||
//! Matmul-based quantized-s8 conv_bias: plain NCHW, no dilation.
bool ConvBiasImpl::AlgoS8MatrixMul::usable(
        FallbackConvBiasImpl* opr, const NCBKernSizeParam& param,
        AlgoSelectionStrategy /*algo_selection_strategy*/) const {
    MEGDNN_MARK_USED_VAR(opr);
    auto&& fm = param.filter_meta;
    bool quantized_s8 = param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
                        param.dst_type.enumv() == DTypeEnum::QuantizedS8;
    bool plain_nchw_2d =
            fm.format == param::ConvBias::Format::NCHW && fm.spatial_ndim == 2;
    bool no_dilation = fm.dilation[0] == 1 && fm.dilation[1] == 1;
    return quantized_s8 && plain_nchw_2d && no_dilation &&
           //! As postprocess, the bias is not contigous read, make the
           //! performance bad, so we do not process it in fused kernel
           param.bias_mode != BiasMode::BIAS &&
           //! This algo is only support single thread
           param.nr_threads == 1_z;
}
//! Workspace layout: part0 = zero-padded src copy, part1 = img2col matrix,
//! part2 = matmul workspace. 1x1/stride-1/no-pad skips parts 0 and 1.
WorkspaceBundle ConvBiasImpl::AlgoS8MatrixMul::get_bundle(
        const NCBKernSizeParam& param) {
    UNPACK_CONV_NCB_KERN_SIZES(param);
    MEGDNN_MARK_USED_VAR(N);
    //! Padded spatial sizes: height grows by 2*PH, width by 2*PW. The
    //! original swapped the two assignments (IW2 = IH + 2*PH etc.); the
    //! IC*IH2*IW2 product below masked it, but the names now agree with
    //! kimpl().
    auto IH2 = IH + 2 * PH;
    auto IW2 = IW + 2 * PW;
    bool can_matrix_mul_direct =
            (FH == 1 && FW == 1 && SH == 1 && SW == 1 && PH == 0 && PW == 0);
    // temp space to store padding-free src (with 16 extra int8)
    // temp space to store unrolled matrix (with 16 extra int8)
    // workspace for matrix mul opr
    size_t part0, part1, part2;
    if (can_matrix_mul_direct) {
        part0 = part1 = 0;
    } else {
        part0 = (IC * IH2 * IW2 + 16) * sizeof(int8_t);
        part1 = (IC * FH * FW * OH * OW + 16) * sizeof(int8_t);
    }
    {
        size_t M = OC;
        size_t K = IC * FH * FW;
        size_t N = OH * OW;
#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias,              \
                               _bias_midout_enum, _nonline,                  \
                               _nonline_midout_enum)                         \
    MIDOUT_BEGIN(megdnn_aarch64_conv_bias_int8_gemm, 0, _gemm_midout_enum,   \
                 _bias_midout_enum, _nonline_midout_enum) {                  \
        matmul::gemm_##_gemm##_##_bias##_##_nonline strategy(                \
                M, N, K, param.filter_type, param.src_type, param.dst_type); \
        part2 = megdnn::matmul::GemmInterleaved<                             \
                        matmul::gemm_##_gemm##_##_bias##_##_nonline>(        \
                        M, N, K, false, false, strategy)                     \
                        .get_workspace_size();                               \
    }                                                                        \
    MIDOUT_END()

#if !(__ARM_FEATURE_DOTPROD)
        DISPATCH_GEMM_BIAS(s8_4x4, 0)
#else
        DISPATCH_GEMM_BIAS(s8_8x12, 1)
#endif
#undef DISPATCH_GEMM_STRATEGY
    }
    return {nullptr, {part0, part1, part2}};
}
void ConvBiasImpl::AlgoS8MatrixMul::kimpl(const NCBKernParam& param, | |||||
const NCBKernIndex& ncb_index) { | |||||
auto is_xcorr = !param.filter_meta.should_flip; | |||||
UNPACK_CONV_NCB_KERN_SIZES(param); | |||||
auto bundle = get_bundle(param); | |||||
bundle.set(param.workspace_ptr); | |||||
auto IH2 = IH + 2 * PH; | |||||
auto IW2 = IW + 2 * PW; | |||||
size_t group_id = ncb_index.ndrange_id[0]; | |||||
// workspace = tmp..src2 | |||||
for (size_t n = 0; n < N; ++n) { | |||||
dt_int8* src = const_cast<dt_int8*>(param.src<dt_int8>(n, group_id)); | |||||
dt_int8* filter = const_cast<dt_int8*>(param.filter<dt_int8>(group_id)); | |||||
dt_int8* dst = static_cast<dt_int8*>(param.dst<dt_int8>(n, group_id)); | |||||
dt_int32* bias = const_cast<dt_int32*>(param.bias<dt_int32>(n, group_id)); | |||||
dt_int8 *B, *src2; | |||||
if (FH == 1 && FW == 1 && SH == 1 && SW == 1 && PH == 0 && PW == 0) { | |||||
// special case: 1x1 | |||||
B = const_cast<dt_int8*>(src); | |||||
} else { | |||||
src2 = static_cast<dt_int8*>(bundle.get(0)); | |||||
// copy src to src2; | |||||
dt_int8* src2_ptr = src2; | |||||
const dt_int8* src_ptr = src; | |||||
rep(ic, IC) { | |||||
if (PH != 0) { | |||||
std::memset(src2_ptr, 0, sizeof(dt_int8) * PH * IW2); | |||||
src2_ptr += PH * IW2; | |||||
} | |||||
rep(ih, IH) { | |||||
if (PW != 0) | |||||
rep(pw, PW) { *(src2_ptr++) = 0.0f; } | |||||
std::memcpy(src2_ptr, src_ptr, sizeof(dt_int8) * IW); | |||||
src2_ptr += IW; | |||||
src_ptr += IW; | |||||
if (PW != 0) | |||||
rep(pw, PW) { *(src2_ptr++) = 0.0f; } | |||||
} | |||||
if (PH != 0) { | |||||
std::memset(src2_ptr, 0, sizeof(dt_int8) * PH * IW2); | |||||
src2_ptr += PH * IW2; | |||||
} | |||||
} | |||||
B = static_cast<dt_int8*>(bundle.get(1)); | |||||
if (SH == 1 && SW == 1) { | |||||
if (is_xcorr) | |||||
img2col<true>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW); | |||||
else | |||||
img2col<false>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW); | |||||
} else { | |||||
if (is_xcorr) | |||||
img2col_stride<true>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, | |||||
FW, SH, SW); | |||||
else | |||||
img2col_stride<false>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, | |||||
FW, SH, SW); | |||||
} | |||||
} | |||||
{ | |||||
Workspace workspace(static_cast<dt_byte*>(bundle.get(2)), | |||||
bundle.get_size(2)); | |||||
size_t M = OC; | |||||
size_t K = IC * FH * FW; | |||||
size_t N = OH * OW; | |||||
#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ | |||||
_bias_midout_enum, _nonline, \ | |||||
_nonline_midout_enum) \ | |||||
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_int8_gemm, 1, _gemm_midout_enum, \ | |||||
_bias_midout_enum, _nonline_midout_enum) { \ | |||||
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | |||||
M, N, K, param.filter_type, param.src_type, param.dst_type); \ | |||||
megdnn::matmul::GemmInterleaved< \ | |||||
matmul::gemm_##_gemm##_##_bias##_##_nonline> \ | |||||
gemm_interleaved(M, N, K, false, false, strategy); \ | |||||
gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, \ | |||||
bias); \ | |||||
} \ | |||||
MIDOUT_END() | |||||
#if !(__ARM_FEATURE_DOTPROD) | |||||
DISPATCH_GEMM_BIAS(s8_4x4, 0) | |||||
#else | |||||
DISPATCH_GEMM_BIAS(s8_8x12, 1) | |||||
#endif | |||||
#undef DISPATCH_GEMM_STRATEGY | |||||
} | |||||
} | |||||
} | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,52 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/conv_bias/int8/algos.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#pragma once | |||||
#include "src/aarch64/conv_bias/opr_impl.h" | |||||
#include "src/fallback/conv_bias/opr_impl.h" | |||||
namespace megdnn { | |||||
namespace aarch64 { | |||||
using FallbackConvBiasImpl = fallback::ConvBiasImpl; | |||||
//! img2col + int8 GEMM conv_bias algorithm (single-threaded, NCHW, qs8).
class ConvBiasImpl::AlgoS8MatrixMul final : public AlgoBase {
    //! Compute the three-part workspace (padded src / unrolled matrix /
    //! matmul scratch) for \p param.
    static WorkspaceBundle get_bundle(const NCBKernSizeParam& param);
    //! Kernel body executed once per group (see dispatch_kerns).
    static void kimpl(const NCBKernParam& param, const NCBKernIndex& ncb_index);

public:
    bool is_reproducible() const override { return true; }
    const char* name() const override { return "S8MATMUL"; }
    bool usable(FallbackConvBiasImpl* opr, const NCBKernSizeParam& param,
                AlgoSelectionStrategy algo_selection_strategy) const override;
    size_t get_workspace(FallbackConvBiasImpl*,
                         const NCBKernSizeParam& param) const override {
        return get_bundle(param).total_size_in_bytes();
    }
    SmallVector<NCBKern> dispatch_kerns(
            FallbackConvBiasImpl*, const NCBKernSizeParam& param) const override {
        size_t group = param.filter_meta.group;
        return {{kimpl, {group, 1_z, 1_z}}};
    }
    //! select matmul to the highest preference
    bool is_preferred(FallbackConvBiasImpl* opr,
                      const NCBKernSizeParam& param) const override {
        return static_cast<arm_common::ConvBiasImpl*>(opr)
                ->is_matmul_quantized_prefer(param);
    }
};
} // namespace aarch64 | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,309 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/conv_bias/int8/strategy.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "src/aarch64/conv_bias/int8/strategy.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/common/utils.h" | |||||
#include "src/fallback/conv_bias/common.h" | |||||
#include "src/aarch64/matrix_mul/int8/kernel_4x4x16.h" | |||||
#include "src/aarch64/matrix_mul/int8_dot/kernel_8x12x4.h" | |||||
#include "src/arm_common/conv_bias/matmul_postprocess.h" | |||||
using namespace megdnn; | |||||
using namespace aarch64; | |||||
using namespace aarch64::matmul; | |||||
namespace impl { | |||||
template <BiasMode bmode, typename Op, int block_m, int block_n> | |||||
struct KernCaller; | |||||
#if __ARM_FEATURE_DOTPROD | |||||
//! dotprod path: 8x12-blocked int8 GEMM; each micro-kernel accumulates into
//! \p workspace (int32) and ConvBiasMatmul::postprocess fuses bias +
//! nonlinearity while writing the quantized result to \p C.
template <BiasMode bmode, typename Op>
struct KernCaller<bmode, Op, 8, 12> {
    static void run(const dt_int8* packA, const dt_int8* packB, size_t M,
                    size_t N, size_t K, dt_int8* C, size_t LDC, bool is_first_k,
                    Op op, const dt_int32* bias, dt_int32* workspace) {
        megdnn_assert(is_first_k);
        constexpr size_t A_INTERLEAVE = 8;
        constexpr size_t B_INTERLEAVE = 12;
        //! K is packed to times of 4
        K = round_up<size_t>(K, 4);
        //! per-panel strides of the packed operands (elements of dt_int8)
        const int K8 = (K << 3);
        const int K12 = K * 12;
        const int K4 = K * 4;
        size_t m = 0;
        //! main loop: full 8-row panels of A
        for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) {
            int8_t* output = C + (m * LDC);
            size_t n = 0;
            const dt_int8* cur_packB = packB;
            for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
                matmul_8x12x4::kern_8x12(packA, cur_packB, K, workspace, 12,
                                         is_first_k);
                arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 8, 12, 8,
                                           12>::postprocess(bias, workspace,
                                                            output, LDC, op);
                output += B_INTERLEAVE;
                cur_packB += K12;
            }
            //! N-tail: 4-column panels, possibly partial on the last one
            for (; n < N; n += 4) {
                matmul_8x12x4::kern_8x4(packA, cur_packB, K, workspace, 4,
                                        is_first_k, std::min<size_t>(N - n, 4));
#define cb(m, n)                                                             \
    arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 8, 4, 8, n>::postprocess( \
            bias, workspace, output, LDC, op);
                DISPATCH_N(cb, 8, std::min<size_t>(N - n, 4));
#undef cb
                output += 4;
                cur_packB += K4;
            }
            packA += K8;
            if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
                bias += A_INTERLEAVE;
            }
        }
        //! M-tail: remaining rows handled in 4-row panels
        for (; m < M; m += 4) {
            int8_t* output = C + (m * LDC);
            const dt_int8* cur_packB = packB;
            size_t n = 0;
            for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
                matmul_8x12x4::kern_4x12(packA, cur_packB, K, workspace, 12,
                                         is_first_k,
                                         std::min<size_t>(M - m, 4));
#define cb(m, n)                                                              \
    arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 12, m, n>::postprocess( \
            bias, workspace, output, LDC, op);
                DISPATCH_M_N(cb, std::min<size_t>(M - m, 4), 12);
#undef cb
                output += B_INTERLEAVE;
                cur_packB += K12;
            }
            //! bottom-right corner: partial in both M and N
            for (; n < N; n += 4) {
                matmul_8x12x4::kern_4x4(packA, cur_packB, K, workspace, 4,
                                        is_first_k, std::min<size_t>(M - m, 4),
                                        std::min<size_t>(N - n, 4));
#define cb(m, n)                                                             \
    arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, m, n>::postprocess( \
            bias, workspace, output, LDC, op);
                DISPATCH_M(cb, std::min<size_t>(M - m, 4),
                           std::min<size_t>(N - n, 4));
#undef cb
                output += 4;
                cur_packB += K4;
            }
            packA += K4;
            if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
                bias += 4;
            }
        }
    }
};
#else | |||||
//! non-dotprod fallback: 4x4-blocked int8 GEMM using the 4x4x16 micro-kernel;
//! bias + nonlinearity are fused in ConvBiasMatmul::postprocess.
template <BiasMode bmode, typename Op>
struct KernCaller<bmode, Op, 4, 4> {
    static void run(const dt_int8* packA, const dt_int8* packB, size_t M,
                    size_t N, size_t K, dt_int8* C, size_t LDC, bool is_first_k,
                    Op op, const dt_int32* bias, dt_int32* workspace) {
        megdnn_assert(is_first_k);
        constexpr size_t A_INTERLEAVE = 4;
        constexpr size_t B_INTERLEAVE = 4;
        //! K is packed to times of 16 (the 4x4x16 kernel's k-block)
        K = round_up<size_t>(K, 16);
        //! stride of one packed 4-row/4-column panel
        const int K4 = K * 4;
        size_t m = 0;
        //! main loop: full 4-row panels of A
        for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) {
            int8_t* output = C + (m * LDC);
            size_t n = 0;
            const dt_int8* cur_packB = packB;
            for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
                matmul_4x4x16::kern_4x4(packA, cur_packB, K, workspace, 4,
                                        is_first_k);
                arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, 4,
                                           4>::postprocess(bias, workspace,
                                                           output, LDC, op);
                output += B_INTERLEAVE;
                cur_packB += K4;
            }
            //! N-tail: partial columns via the remainder kernel
            for (; n < N; n += B_INTERLEAVE) {
                matmul_4x4x16::kern_4x4_remain(packA, cur_packB, K, workspace,
                                               4, is_first_k, 4,
                                               std::min<size_t>(N - n, 4));
#define cb(m, n)                                                             \
    arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, 4, n>::postprocess( \
            bias, workspace, output, LDC, op);
                DISPATCH_N(cb, 4, std::min<size_t>(N - n, 4));
#undef cb
                output += B_INTERLEAVE;
                cur_packB += K4;
            }
            packA += K4;
            if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
                bias += A_INTERLEAVE;
            }
        }
        //! M-tail: partial rows (and possibly columns) via remainder kernel
        for (; m < M; m += A_INTERLEAVE) {
            int8_t* output = C + (m * LDC);
            size_t n = 0;
            const dt_int8* cur_packB = packB;
            for (; n < N; n += B_INTERLEAVE) {
                matmul_4x4x16::kern_4x4_remain(
                        packA, cur_packB, K, workspace, 4, is_first_k,
                        std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4));
#define cb(m, n)                                                             \
    arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, m, n>::postprocess( \
            bias, workspace, output, LDC, op);
                DISPATCH_M(cb, std::min<size_t>(M - m, 4),
                           std::min<size_t>(N - n, 4));
#undef cb
                output += B_INTERLEAVE;
                cur_packB += K4;
            }
            packA += K4;
            if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
                bias += A_INTERLEAVE;
            }
        }
    }
};
#endif | |||||
} // namespace impl | |||||
#if !(__ARM_FEATURE_DOTPROD) | |||||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_4x4_nobias_identity) | |||||
void gemm_s8_4x4_nobias_identity::pack_A(dt_int8* outptr, const dt_int8* inptr,
                                         int ldin, int y0, int ymax, int k0,
                                         int kmax, bool transpose) const {
    //! a transposed A panel is laid out like a B panel, so it reuses the
    //! B-side packer; the normal case uses the A-side packer
    if (!transpose) {
        matmul_4x4x16::gemm_s8_4x4_pack_A_n(outptr, inptr, ldin, y0, ymax, k0,
                                            kmax);
    } else {
        matmul_4x4x16::gemm_s8_4x4_pack_B_n(outptr, inptr, ldin, y0, ymax, k0,
                                            kmax);
    }
}
void gemm_s8_4x4_nobias_identity::pack_B(dt_int8* out, const dt_int8* in,
                                         int ldin, int x0, int xmax, int k0,
                                         int kmax, bool transpose) const {
    //! mirror of pack_A: a transposed B panel uses the A-side packer
    if (!transpose) {
        matmul_4x4x16::gemm_s8_4x4_pack_B_n(out, in, ldin, x0, xmax, k0, kmax);
    } else {
        matmul_4x4x16::gemm_s8_4x4_pack_A_n(out, in, ldin, x0, xmax, k0, kmax);
    }
}
size_t gemm_s8_4x4_nobias_identity::get_workspace_size() const {
    //! one 4x4 int32 accumulator tile for the micro-kernel output
    constexpr size_t tile_elems = 4 * 4;
    return tile_elems * sizeof(dt_int32);
}
#else | |||||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_8x12_nobias_identity) | |||||
void gemm_s8_8x12_nobias_identity::pack_A(dt_int8* outptr, const dt_int8* inptr,
                                          int ldin, int y0, int ymax, int k0,
                                          int kmax, bool transpose) const {
    //! reference the transposed packers so the symbols count as used even
    //! though this path only calls the _n variants
    MEGDNN_MARK_USED_VAR(matmul_8x12x4::gemm_s8_8x12_pack_A_t);
    MEGDNN_MARK_USED_VAR(matmul_8x12x4::gemm_s8_8x12_pack_B_t);
    //! a transposed A panel is laid out like a B panel
    if (!transpose) {
        matmul_8x12x4::gemm_s8_8x12_pack_A_n(outptr, inptr, ldin, y0, ymax, k0,
                                             kmax);
    } else {
        matmul_8x12x4::gemm_s8_8x12_pack_B_n(outptr, inptr, ldin, y0, ymax, k0,
                                             kmax);
    }
}
void gemm_s8_8x12_nobias_identity::pack_B(dt_int8* out, const dt_int8* in,
                                          int ldin, int x0, int xmax, int k0,
                                          int kmax, bool transpose) const {
    //! mirror of pack_A: a transposed B panel uses the A-side packer
    if (!transpose) {
        matmul_8x12x4::gemm_s8_8x12_pack_B_n(out, in, ldin, x0, xmax, k0, kmax);
    } else {
        matmul_8x12x4::gemm_s8_8x12_pack_A_n(out, in, ldin, x0, xmax, k0, kmax);
    }
}
size_t gemm_s8_8x12_nobias_identity::get_workspace_size() const {
    //! one 8x12 int32 accumulator tile for the micro-kernel output
    constexpr size_t tile_elems = 8 * 12;
    return tile_elems * sizeof(dt_int32);
}
#endif | |||||
//! Generate gemm_s8_<m>x<n>_<bias>_<nonline>::kern: read the quantization
//! scales from the operand dtypes, construct the post-process Op
//! (via DEFINE_OP) and forward to impl::KernCaller::run.
#define KERN(_block_m, _block_n, _bias, _BIAS, _nonline, _OP)          \
    void gemm_s8_##_block_m##x##_block_n##_##_bias##_##_nonline::kern( \
            const dt_int8* packA, const dt_int8* packB, size_t M,      \
            size_t N, size_t K, dt_int8* C, size_t LDC,                \
            bool is_first_k, const dt_int32* bias,                     \
            dt_int32* workspace) const {                               \
        float scale_A = A_dtype.param<dtype::QuantizedS8>().scale;     \
        float scale_B = B_dtype.param<dtype::QuantizedS8>().scale;     \
        float scale_C = C_dtype.param<dtype::QuantizedS8>().scale;     \
        DEFINE_OP(_OP);                                                \
        impl::KernCaller<_BIAS, decltype(op), _block_m, _block_n>::run(\
                packA, packB, M, N, K, C, LDC, is_first_k, op, bias,   \
                workspace);                                            \
    }
//! no-bias ops take (scale_A * scale_B, scale_C)
#define DEFINE_OP(_Op) \
    arm_common::_Op<dt_qint32, dt_qint8> op(scale_A* scale_B, scale_C);
#if !(__ARM_FEATURE_DOTPROD)
KERN(4, 4, nobias, BiasMode::NO_BIAS, identity, TypeCvtOp)
KERN(4, 4, nobias, BiasMode::NO_BIAS, relu, ReluOp)
KERN(4, 4, nobias, BiasMode::NO_BIAS, hswish, HSwishOp)
#else
KERN(8, 12, nobias, BiasMode::NO_BIAS, identity, TypeCvtOp)
KERN(8, 12, nobias, BiasMode::NO_BIAS, relu, ReluOp)
KERN(8, 12, nobias, BiasMode::NO_BIAS, hswish, HSwishOp)
#endif
#undef DEFINE_OP
//! fused-add ops additionally take the bias scale (same as scale_A*scale_B)
#define DEFINE_OP(_Op)                                          \
    arm_common::_Op<dt_qint32, dt_qint8> op(scale_A* scale_B,   \
                                            scale_A* scale_B, scale_C);
#if !(__ARM_FEATURE_DOTPROD)
KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp)
KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp)
KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish,
     FuseAddHSwishOp)
#else
KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp)
KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp)
KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish,
     FuseAddHSwishOp)
#endif
#undef DEFINE_OP
#undef KERN
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,69 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/conv_bias/int8/strategy.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#pragma once | |||||
#include "src/fallback/matrix_mul/gemm_common.h" | |||||
namespace megdnn { | |||||
namespace aarch64 { | |||||
namespace matmul { | |||||
#if !(__ARM_FEATURE_DOTPROD)
/**
 * \brief base strategy of gemm.
 *
 * \name gemm_<type>_<block>_biasmode_nonlinemode
 */
MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_int8, dt_int8, dt_int32, 4, 4, 16,
                                        false, true,
                                        gemm_s8_4x4_nobias_identity);
//! the remaining strategies derive from the base one, differing only in the
//! fused bias mode / nonlinearity
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_nobias_relu,
                                    gemm_s8_4x4_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_nobias_hswish,
                                    gemm_s8_4x4_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_bias_channel_identity,
                                    gemm_s8_4x4_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_bias_channel_relu,
                                    gemm_s8_4x4_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_bias_channel_hswish,
                                    gemm_s8_4x4_nobias_identity);
#else
//! dotprod-capable CPUs use the larger 8x12 block with k-block 4
MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_int8, dt_int8, dt_int32, 8, 12, 4,
                                        false, true,
                                        gemm_s8_8x12_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_nobias_relu,
                                    gemm_s8_8x12_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_nobias_hswish,
                                    gemm_s8_8x12_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_bias_channel_identity,
                                    gemm_s8_8x12_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_bias_channel_relu,
                                    gemm_s8_8x12_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_bias_channel_hswish,
                                    gemm_s8_8x12_nobias_identity);
#endif
} // namespace matmul | |||||
} // namespace aarch64 | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,69 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/conv_bias/opr_impl.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "src/aarch64/conv_bias/opr_impl.h" | |||||
#include "src/aarch64/conv_bias/int8/algos.h" | |||||
#include "src/aarch64/conv_bias/quint8/algos.h" | |||||
#include "src/naive/handle.h" | |||||
#include "src/common/utils.h" | |||||
#include "src/common/metahelper.h" | |||||
#include "src/fallback/convolution/opr_impl.h" | |||||
#include "src/aarch64/conv_bias/fp32/algos.h" | |||||
#include "src/aarch64/conv_bias/fp16/algos.h" | |||||
using namespace megdnn; | |||||
using namespace aarch64; | |||||
//! owns one instance of every aarch64 conv_bias algorithm and groups them
//! into direct and matmul lists for algo_pack()
class ConvBiasImpl::AlgoPack : NonCopyableObj {
    //! true/false ctor flag selects the large-group / small-group variant
    AlgoF32DirectStride2 f32_direct_stride2_large_group{true};
    AlgoF32DirectStride2 f32_direct_stride2_small_group{false};
    AlgoS8MatrixMul s8_matrix_mul;
    AlgoQU8MatrixMul qu8_matrix_mul;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    AlgoF16DirectStride2 f16_direct_stride2_large_group{true};
    AlgoF16DirectStride2 f16_direct_stride2_small_group{false};
#endif
public:
    AlgoPack() {
        //! NOTE(review): insertion order presumably determines in-category
        //! selection priority (cf. the ordering comment in algo_pack) —
        //! keep it stable
        matmul_algos.emplace_back(&qu8_matrix_mul);
        matmul_algos.emplace_back(&s8_matrix_mul);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        direct_algos.emplace_back(&f16_direct_stride2_large_group);
        direct_algos.emplace_back(&f16_direct_stride2_small_group);
#endif
        direct_algos.emplace_back(&f32_direct_stride2_large_group);
        direct_algos.emplace_back(&f32_direct_stride2_small_group);
    }
    SmallVector<AlgoBase*> direct_algos;
    SmallVector<AlgoBase*> matmul_algos;
};
//! collect all usable algorithms: aarch64 direct algos are inserted in front
//! of the inherited arm_common list, matmul algos are appended at the back
SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() {
    //! algorithm objects are created once and shared across calls
    static AlgoPack sl_algo_pack;
    auto&& algos = arm_common::ConvBiasImpl::algo_pack();
    algos.insert(algos.begin(), sl_algo_pack.direct_algos.begin(),
                 sl_algo_pack.direct_algos.end());
    //! We put matmul algos at the end. Because matmul will get privilege when
    //! prefer return true. See
    //! fallback::ConvolutionImpl::ncb_1g_get_all_algorithms for more details.
    algos.insert(algos.end(), sl_algo_pack.matmul_algos.begin(),
                 sl_algo_pack.matmul_algos.end());
    return std::move(algos);
}
//! identifier of this backend's algorithm set
const char* ConvBiasImpl::get_algorithm_set_name() const {
    return "AARCH64";
}
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,41 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/conv_bias/opr_impl.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#pragma once | |||||
#include "src/common/utils.h" | |||||
#include "src/arm_common/conv_bias/opr_impl.h" | |||||
namespace megdnn { | |||||
namespace aarch64 { | |||||
//! aarch64 conv_bias operator: extends the arm_common implementation with
//! aarch64-specific algorithms (see algo_pack in opr_impl.cpp)
class ConvBiasImpl : public arm_common::ConvBiasImpl {
public:
    using arm_common::ConvBiasImpl::ConvBiasImpl;
    //! returns aarch64 algos merged with the inherited arm_common ones
    SmallVector<AlgoBase*> algo_pack() override;
protected:
    const char* get_algorithm_set_name() const override;
private:
    //! concrete algorithm classes are defined in the per-dtype algos.h files
    class AlgoF32DirectStride2;
    class AlgoS8MatrixMul;
    class AlgoQU8MatrixMul;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    class AlgoF16DirectStride2;
#endif
    class AlgoPack;
};
} // namespace aarch64 | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,181 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/conv_bias/quint8/algos.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "src/aarch64/conv_bias/quint8/algos.h" | |||||
#include "src/aarch64/conv_bias/quint8/strategy.h" | |||||
#include "src/aarch64/matrix_mul/quint8_dot/gemv.h" | |||||
#include "src/aarch64/matrix_mul/quint8_dot/strategy.h" | |||||
#include "src/arm_common/convolution/img2col_helper.h" | |||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/common/opr_delegate.h" | |||||
#include "src/fallback/conv_bias/common.h" | |||||
#include "src/fallback/matrix_mul/gemm_impl.h" | |||||
#include "midout.h" | |||||
MIDOUT_DECL(megdnn_aarch64_conv_bias_quint8_gemm) | |||||
using namespace megdnn; | |||||
using namespace aarch64; | |||||
using megdnn::arm_common::HSwishOp; | |||||
using megdnn::arm_common::ReluOp; | |||||
using megdnn::arm_common::TypeCvtOp; | |||||
/* ===================== matrix mul algo ===================== */ | |||||
bool ConvBiasImpl::AlgoQU8MatrixMul::usable(
        FallbackConvBiasImpl* opr, const NCBKernSizeParam& param,
        AlgoSelectionStrategy /*algo_selection_strategy*/) const {
    MEGDNN_MARK_USED_VAR(opr);
    auto&& fm = param.filter_meta;
    //! only quint8 asymmetric src/dst are handled by this algo
    if (param.src_type.enumv() != DTypeEnum::Quantized8Asymm ||
        param.dst_type.enumv() != DTypeEnum::Quantized8Asymm) {
        return false;
    }
    //! plain NCHW 2-D convolution without dilation
    if (fm.format != param::ConvBias::Format::NCHW || fm.spatial_ndim != 2 ||
        fm.dilation[0] != 1 || fm.dilation[1] != 1) {
        return false;
    }
    //! As postprocess, the bias is not contigous read, make the
    //! performance bad, so we do not process it in fused kernel
    if (param.bias_mode == BiasMode::BIAS) {
        return false;
    }
    //! This algo is only support single thread
    return param.nr_threads == 1_z;
}
//! compute the three-part workspace for the im2col + matmul path:
//! part0 = padded source copy, part1 = im2col matrix, part2 = gemm workspace.
//! part0/part1 are skipped for the direct 1x1/stride-1/no-padding case.
WorkspaceBundle ConvBiasImpl::AlgoQU8MatrixMul::get_bundle(
        const NCBKernSizeParam& param) {
    UNPACK_CONV_NCB_KERN_SIZES(param);
    MEGDNN_MARK_USED_VAR(N);
    //! padded input spatial extents. Fix: these two assignments were swapped
    //! (IW2 was computed from IH/PH and IH2 from IW/PW). Since only the
    //! product IC * IH2 * IW2 is used below, the computed size was
    //! unaffected, but the names now match kimpl().
    auto IH2 = IH + 2 * PH;
    auto IW2 = IW + 2 * PW;
    bool can_matrix_mul_direct =
            (FH == 1 && FW == 1 && SH == 1 && SW == 1 && PH == 0 && PW == 0);
    // temp space to store padding-free src (with 16 extra int8)
    // temp space to store unrolled matrix (with 16 extra int8)
    // workspace for matrix mul opr
    size_t part0, part1, part2;
    if (can_matrix_mul_direct) {
        part0 = part1 = 0;
    } else {
        part0 = (IC * IH2 * IW2 + 16) * sizeof(uint8_t);
        part1 = (IC * FH * FW * OH * OW + 16) * sizeof(uint8_t);
    }
    {
        size_t M = OC;
        size_t K = IC * FH * FW;
        size_t N = OH * OW;
#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias,              \
                               _bias_midout_enum, _nonline,                  \
                               _nonline_midout_enum)                         \
    MIDOUT_BEGIN(megdnn_aarch64_conv_bias_quint8_gemm, 0, _gemm_midout_enum, \
                 _bias_midout_enum, _nonline_midout_enum) {                  \
        matmul::gemm_##_gemm##_##_bias##_##_nonline strategy(                \
                M, N, K, param.filter_type, param.src_type, param.dst_type); \
        part2 = megdnn::matmul::GemmInterleaved<                             \
                        matmul::gemm_##_gemm##_##_bias##_##_nonline>(        \
                        M, N, K, false, false, strategy)                     \
                        .get_workspace_size();                               \
    }                                                                        \
    MIDOUT_END()
        DISPATCH_GEMM_BIAS(u8_8x8, 0)
#undef DISPATCH_GEMM_STRATEGY
    }
    return {nullptr, {part0, part1, part2}};
}
//! kernel body: per batch sample, build a zero-point-padded copy of the
//! source, unroll it with im2col, then run the quint8 gemm with fused
//! bias / nonlinearity writing directly to dst.
void ConvBiasImpl::AlgoQU8MatrixMul::kimpl(const NCBKernParam& param,
                                           const NCBKernIndex& ncb_index) {
    auto is_xcorr = !param.filter_meta.should_flip;
    UNPACK_CONV_NCB_KERN_SIZES(param);
    auto bundle = get_bundle(param);
    bundle.set(param.workspace_ptr);
    //! padded input spatial extents
    auto IH2 = IH + 2 * PH;
    auto IW2 = IW + 2 * PW;
    size_t group_id = ncb_index.ndrange_id[0];
    //! padding is filled with the source zero point so padded pixels
    //! dequantize to zero
    uint8_t src_zp = param.src_type.param<dtype::Quantized8Asymm>().zero_point;
    // workspace = tmp..src2
    for (size_t n = 0; n < N; ++n) {
        uint8_t* src = const_cast<uint8_t*>(param.src<uint8_t>(n, group_id));
        uint8_t* filter = const_cast<uint8_t*>(param.filter<uint8_t>(group_id));
        uint8_t* dst = static_cast<uint8_t*>(param.dst<uint8_t>(n, group_id));
        int32_t* bias = const_cast<int32_t*>(param.bias<int32_t>(n, group_id));
        uint8_t *B, *src2;
        if (FH == 1 && FW == 1 && SH == 1 && SW == 1 && PH == 0 && PW == 0) {
            // special case: 1x1 — src already is the gemm B matrix
            B = const_cast<uint8_t*>(src);
        } else {
            src2 = static_cast<uint8_t*>(bundle.get(0));
            // copy src to src2, surrounding each row/plane with zero-point
            // padding
            uint8_t* src2_ptr = src2;
            const uint8_t* src_ptr = src;
            rep(ic, IC) {
                //! top padding rows
                if (PH != 0) {
                    std::memset(src2_ptr, src_zp, sizeof(uint8_t) * PH * IW2);
                    src2_ptr += PH * IW2;
                }
                rep(ih, IH) {
                    //! left padding
                    if (PW != 0)
                        rep(pw, PW) { *(src2_ptr++) = src_zp; }
                    std::memcpy(src2_ptr, src_ptr, sizeof(uint8_t) * IW);
                    src2_ptr += IW;
                    src_ptr += IW;
                    //! right padding
                    if (PW != 0)
                        rep(pw, PW) { *(src2_ptr++) = src_zp; }
                }
                //! bottom padding rows
                if (PH != 0) {
                    std::memset(src2_ptr, src_zp, sizeof(uint8_t) * PH * IW2);
                    src2_ptr += PH * IW2;
                }
            }
            //! unroll the padded image into the gemm B matrix
            B = static_cast<uint8_t*>(bundle.get(1));
            if (SH == 1 && SW == 1) {
                if (is_xcorr)
                    img2col<true>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW);
                else
                    img2col<false>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW);
            } else {
                if (is_xcorr)
                    img2col_stride<true>(src2, B, OC, OH, OW, IC, IH2, IW2, FH,
                                         FW, SH, SW);
                else
                    img2col_stride<false>(src2, B, OC, OH, OW, IC, IH2, IW2, FH,
                                          FW, SH, SW);
            }
        }
        {
            Workspace workspace(static_cast<dt_byte*>(bundle.get(2)),
                                bundle.get_size(2));
            size_t M = OC;
            size_t K = IC * FH * FW;
            size_t N = OH * OW;
#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias,              \
                               _bias_midout_enum, _nonline,                  \
                               _nonline_midout_enum)                         \
    MIDOUT_BEGIN(megdnn_aarch64_conv_bias_quint8_gemm, 1, _gemm_midout_enum, \
                 _bias_midout_enum, _nonline_midout_enum) {                  \
        matmul::gemm_##_gemm##_##_bias##_##_nonline strategy(                \
                M, N, K, param.filter_type, param.src_type, param.dst_type); \
        megdnn::matmul::GemmInterleaved<                                     \
                matmul::gemm_##_gemm##_##_bias##_##_nonline>                 \
                gemm_interleaved(M, N, K, false, false, strategy);           \
        gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, \
                                 bias);                                      \
    }                                                                        \
    MIDOUT_END()
            DISPATCH_GEMM_BIAS(u8_8x8, 0)
#undef DISPATCH_GEMM_STRATEGY
        }
    }
}
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,52 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/conv_bias/quint8/algos.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#pragma once | |||||
#include "src/aarch64/conv_bias/opr_impl.h" | |||||
#include "src/fallback/conv_bias/opr_impl.h" | |||||
namespace megdnn { | |||||
namespace aarch64 { | |||||
using FallbackConvBiasImpl = fallback::ConvBiasImpl; | |||||
//! quint8 conv_bias via im2col + quantized matmul (see quint8/algos.cpp)
class ConvBiasImpl::AlgoQU8MatrixMul final : public AlgoBase {
    //! three-part workspace: padded src, im2col matrix, gemm workspace
    static WorkspaceBundle get_bundle(const NCBKernSizeParam& param);
    //! actual kernel body; ndrange dim 0 is the group index
    static void kimpl(const NCBKernParam& param, const NCBKernIndex&);
public:
    bool is_reproducible() const override { return true; }
    const char* name() const override { return "QU8MATMUL"; }
    bool usable(FallbackConvBiasImpl* opr, const NCBKernSizeParam& param,
                AlgoSelectionStrategy algo_selection_strategy) const override;
    //! workspace requirement is fully described by the bundle
    size_t get_workspace(FallbackConvBiasImpl*,
                         const NCBKernSizeParam& param) const override {
        return get_bundle(param).total_size_in_bytes();
    }
    //! dispatch a single kernel over an ndrange of {group, 1, 1}
    SmallVector<NCBKern> dispatch_kerns(
            FallbackConvBiasImpl*,
            const NCBKernSizeParam& param) const override {
        size_t group = param.filter_meta.group;
        return {{kimpl, {group, 1_z, 1_z}}};
    }
    //! select matmul to the highest preference
    bool is_preferred(FallbackConvBiasImpl* opr,
                      const NCBKernSizeParam& param) const override {
        return static_cast<arm_common::ConvBiasImpl*>(opr)
                ->is_matmul_quantized_prefer(param);
    }
};
} // namespace aarch64 | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,319 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/conv_bias/quint8/strategy.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "src/aarch64/conv_bias/quint8/strategy.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/common/utils.h" | |||||
#include "src/fallback/conv_bias/common.h" | |||||
#include "src/aarch64/matrix_mul/quint8_dot/kernel_8x8x4.h" | |||||
#include "src/aarch64/matrix_mul/quint8/kernel_8x8x8.h" | |||||
#include "src/arm_common/conv_bias/matmul_postprocess.h" | |||||
using namespace megdnn; | |||||
using namespace aarch64; | |||||
using namespace aarch64::matmul; | |||||
namespace impl { | |||||
template <BiasMode bmode, typename Op, int block_m, int block_n> | |||||
struct KernCaller; | |||||
#if __ARM_FEATURE_DOTPROD | |||||
//! dotprod path: 8x8-blocked quint8 GEMM; zero-point compensation is passed
//! to the micro-kernels, and bias/nonlinearity are fused in
//! ConvBiasMatmul::postprocess.
template <BiasMode bmode, typename Op>
struct KernCaller<bmode, Op, 8, 8> {
    static void run(const dt_uint8* packA, const dt_uint8* packB, size_t M,
                    size_t N, size_t K, dt_uint8* C, size_t LDC,
                    bool is_first_k, Op op, const dt_int32* bias,
                    dt_int32* workspace, uint8_t zp_A, uint8_t zp_B) {
        megdnn_assert(is_first_k);
        constexpr size_t A_INTERLEAVE = 8;
        constexpr size_t B_INTERLEAVE = 8;
        //! cross zero-point term zp_A * zp_B * K; NOTE(review): computed
        //! from K before the round_up below — confirm padded lanes are
        //! compensated inside the micro-kernels
        const uint32_t zAB =
                static_cast<uint32_t>(zp_A) * static_cast<uint32_t>(zp_B) * K;
        //! K is packed to times of 4
        K = round_up<size_t>(K, 4);
        //! per-panel strides of the packed operands
        const int K8 = (K << 3);
        const int K4 = K * 4;
        size_t m = 0;
        //! main loop: full 8-row panels of A
        for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) {
            uint8_t* output = C + (m * LDC);
            size_t n = 0;
            const dt_uint8* cur_packB = packB;
            for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
                matmul_8x8x4::kern_8x8(packA, cur_packB, K, workspace, 8,
                                       is_first_k, zp_A, zp_B, zAB);
                arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 8, 8,
                                           8>::postprocess(bias, workspace,
                                                           output, LDC, op);
                output += B_INTERLEAVE;
                cur_packB += K8;
            }
            //! N-tail: 4-column panels, possibly partial
            for (; n < N; n += 4) {
                matmul_8x8x4::kern_8x4(packA, cur_packB, K, workspace, 4,
                                       is_first_k, std::min<size_t>(N - n, 4),
                                       zp_A, zp_B, zAB);
#define cb(m, n)                                                              \
    arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 4, 8, n>::postprocess( \
            bias, workspace, output, LDC, op);
                DISPATCH_N(cb, 8, std::min<size_t>(N - n, 4));
#undef cb
                output += 4;
                cur_packB += K4;
            }
            packA += K8;
            if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
                bias += A_INTERLEAVE;
            }
        }
        //! M-tail: remaining rows handled in 4-row panels
        for (; m < M; m += 4) {
            uint8_t* output = C + (m * LDC);
            const dt_uint8* cur_packB = packB;
            size_t n = 0;
            for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
                matmul_8x8x4::kern_4x8(packA, cur_packB, K, workspace, 8,
                                       is_first_k, std::min<size_t>(M - m, 4),
                                       zp_A, zp_B, zAB);
#define cb(m, n)                                                              \
    arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 4, 8, m, n>::postprocess( \
            bias, workspace, output, LDC, op);
                DISPATCH_M_N(cb, std::min<size_t>(M - m, 4), 8);
#undef cb
                output += B_INTERLEAVE;
                cur_packB += K8;
            }
            //! bottom-right corner: partial in both M and N
            for (; n < N; n += 4) {
                matmul_8x8x4::kern_4x4(packA, cur_packB, K, workspace, 4,
                                       is_first_k, std::min<size_t>(M - m, 4),
                                       std::min<size_t>(N - n, 4), zp_A, zp_B,
                                       zAB);
#define cb(m, n)                                                              \
    arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 4, 4, m, n>::postprocess( \
            bias, workspace, output, LDC, op);
                DISPATCH_M(cb, std::min<size_t>(M - m, 4),
                           std::min<size_t>(N - n, 4));
#undef cb
                output += 4;
                cur_packB += K4;
            }
            packA += K4;
            if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
                bias += 4;
            }
        }
    }
};
#else | |||||
//! non-dotprod fallback: 8x8-blocked quint8 GEMM using the 8x8x8
//! micro-kernel; zero points are passed per kernel, bias/nonlinearity fused
//! in ConvBiasMatmul::postprocess.
template <BiasMode bmode, typename Op>
struct KernCaller<bmode, Op, 8, 8> {
    static void run(const dt_uint8* packA, const dt_uint8* packB, size_t M,
                    size_t N, size_t K, dt_uint8* C, size_t LDC,
                    bool is_first_k, Op op, const dt_int32* bias,
                    dt_int32* workspace, uint8_t zp_A, uint8_t zp_B) {
        megdnn_assert(is_first_k);
        constexpr size_t A_INTERLEAVE = 8;
        constexpr size_t B_INTERLEAVE = 8;
        //! K is packed to times of 8
        K = round_up<size_t>(K, 8);
        //! per-panel strides of the packed operands
        const int K8 = K * 8;
        const int K4 = K * 4;
        size_t m = 0;
        //! main loop: full 8-row panels of A
        for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) {
            uint8_t* output = C + (m * LDC);
            size_t n = 0;
            const dt_uint8* cur_packB = packB;
            for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
                matmul_8x8x8::kern_8x8(packA, cur_packB, K, workspace, 8,
                                       is_first_k, zp_A, zp_B);
                arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 8, 8,
                                           8>::postprocess(bias, workspace,
                                                           output, LDC, op);
                output += B_INTERLEAVE;
                cur_packB += K8;
            }
            //! N-tail: 4-column panels, possibly partial
            for (; n < N; n += 4) {
                matmul_8x8x8::kern_8x4(packA, cur_packB, K, workspace, 4,
                                       is_first_k, std::min<size_t>(N - n, 4),
                                       zp_A, zp_B);
#define cb(m, n)                                                              \
    arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 4, 8, n>::postprocess( \
            bias, workspace, output, LDC, op);
                DISPATCH_N(cb, 8, std::min<size_t>(N - n, 4));
#undef cb
                output += 4;
                cur_packB += K4;
            }
            packA += K8;
            if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
                bias += A_INTERLEAVE;
            }
        }
        //! M-tail: remaining rows handled in 4-row panels
        for (; m < M; m += 4) {
            uint8_t* output = C + (m * LDC);
            const dt_uint8* cur_packB = packB;
            size_t n = 0;
            for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
                matmul_8x8x8::kern_4x8(packA, cur_packB, K, workspace, 8,
                                       is_first_k, std::min<size_t>(M - m, 4),
                                       zp_A, zp_B);
#define cb(m, n)                                                              \
    arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 4, 8, m, n>::postprocess( \
            bias, workspace, output, LDC, op);
                DISPATCH_M_N(cb, std::min<size_t>(M - m, 4), 8);
#undef cb
                output += B_INTERLEAVE;
                cur_packB += K8;
            }
            //! bottom-right corner: partial in both M and N
            for (; n < N; n += 4) {
                matmul_8x8x8::kern_4x4(packA, cur_packB, K, workspace, 4,
                                       is_first_k, std::min<size_t>(M - m, 4),
                                       std::min<size_t>(N - n, 4), zp_A, zp_B);
#define cb(m, n)                                                              \
    arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 4, 4, m, n>::postprocess( \
            bias, workspace, output, LDC, op);
                DISPATCH_M(cb, std::min<size_t>(M - m, 4),
                           std::min<size_t>(N - n, 4));
#undef cb
                output += 4;
                cur_packB += K4;
            }
            packA += K4;
            if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
                bias += 4;
            }
        }
    }
};
#endif | |||||
} // namespace impl | |||||
#if __ARM_FEATURE_DOTPROD | |||||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8_nobias_identity) | |||||
//! Pack matrix A for the dotprod 8x8x4 kernel. The non-transposed source is
//! interleave-packed; a transposed source goes through the transpose packer.
void gemm_u8_8x8_nobias_identity::pack_A(uint8_t* outptr, const uint8_t* inptr,
                                         int ldin, int y0, int ymax, int k0,
                                         int kmax, bool transpose) const {
    if (!transpose) {
        matmul_8x8x4::gemm_u8_8x8_interleave_pack_helper(outptr, inptr, ldin,
                                                         y0, ymax, k0, kmax);
        return;
    }
    matmul_8x8x4::gemm_u8_8x8_transpose_pack_helper(outptr, inptr, ldin, y0,
                                                    ymax, k0, kmax);
}
//! Pack matrix B for the dotprod 8x8x4 kernel. Note the helpers are swapped
//! relative to pack_A: a transposed B uses the interleave packer.
void gemm_u8_8x8_nobias_identity::pack_B(uint8_t* out, const uint8_t* in,
                                         int ldin, int x0, int xmax, int k0,
                                         int kmax, bool transpose) const {
    if (!transpose) {
        matmul_8x8x4::gemm_u8_8x8_transpose_pack_helper(out, in, ldin, x0,
                                                        xmax, k0, kmax);
        return;
    }
    matmul_8x8x4::gemm_u8_8x8_interleave_pack_helper(out, in, ldin, x0, xmax,
                                                     k0, kmax);
}
#else | |||||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8_nobias_identity) | |||||
//! Pack matrix A for the non-dotprod 8x8x8 kernel; the packer also needs A's
//! quantization zero point to pad remainders consistently.
void gemm_u8_8x8_nobias_identity::pack_A(dt_uint8* outptr,
                                         const dt_uint8* inptr, int ldin,
                                         int y0, int ymax, int k0, int kmax,
                                         bool transpose) const {
    const uint8_t zero_point_A =
            A_dtype.param<dtype::Quantized8Asymm>().zero_point;
    if (!transpose) {
        matmul_8x8x8::gemm_u8_8x8_pack_A_n(outptr, inptr, ldin, y0, ymax, k0,
                                           kmax, zero_point_A);
        return;
    }
    matmul_8x8x8::gemm_u8_8x8_transpose_pack_A_n(outptr, inptr, ldin, y0,
                                                 ymax, k0, kmax, zero_point_A);
}
//! Pack matrix B for the non-dotprod 8x8x8 kernel, forwarding B's
//! quantization zero point to the packer.
void gemm_u8_8x8_nobias_identity::pack_B(dt_uint8* out, const dt_uint8* in,
                                         int ldin, int x0, int xmax, int k0,
                                         int kmax, bool transpose) const {
    const uint8_t zero_point_B =
            B_dtype.param<dtype::Quantized8Asymm>().zero_point;
    if (!transpose) {
        matmul_8x8x8::gemm_u8_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax,
                                           zero_point_B);
        return;
    }
    matmul_8x8x8::gemm_u8_8x8_transpose_pack_B_n(out, in, ldin, x0, xmax, k0,
                                                 kmax, zero_point_B);
}
#endif | |||||
//! Workspace holds one 8x8 tile of int32 accumulators for the micro-kernels.
size_t gemm_u8_8x8_nobias_identity::get_workspace_size() const {
    constexpr size_t tile_rows = 8;
    constexpr size_t tile_cols = 8;
    return tile_rows * tile_cols * sizeof(dt_int32);
}
//! Generates the `kern` member of a conv-bias GEMM strategy class named
//! gemm_u8_<m>x<n>_<bias>_<nonline>: reads the quantization parameters
//! (scale / zero point) of A, B and C, builds the requantizing epilogue op
//! via DEFINE_OP, and forwards to impl::KernCaller.
#define KERN(_block_m, _block_n, _bias, _BIAS, _nonline, _OP)                 \
    void gemm_u8_##_block_m##x##_block_n##_##_bias##_##_nonline::kern(        \
            const dt_uint8* packA, const dt_uint8* packB, size_t M, size_t N, \
            size_t K, dt_uint8* C, size_t LDC, bool is_first_k,               \
            const dt_int32* bias, dt_int32* workspace) const {                \
        float scale_A = A_dtype.param<dtype::Quantized8Asymm>().scale;        \
        uint8_t zp_A = A_dtype.param<dtype::Quantized8Asymm>().zero_point;    \
        float scale_B = B_dtype.param<dtype::Quantized8Asymm>().scale;        \
        uint8_t zp_B = B_dtype.param<dtype::Quantized8Asymm>().zero_point;    \
        float scale_C = C_dtype.param<dtype::Quantized8Asymm>().scale;        \
        uint8_t zp_C = C_dtype.param<dtype::Quantized8Asymm>().zero_point;    \
        DEFINE_OP(_OP);                                                       \
        impl::KernCaller<_BIAS, decltype(op), _block_m, _block_n>::run(       \
                packA, packB, M, N, K, C, LDC, is_first_k, op, bias,          \
                workspace, zp_A, zp_B);                                       \
    }
// No-bias epilogue ops take (src_scale, dst_scale, dst_zero_point).
#define DEFINE_OP(_Op) \
    arm_common::_Op<dt_qint32, dt_quint8> op(scale_A* scale_B, scale_C, zp_C);
KERN(8, 8, nobias, BiasMode::NO_BIAS, identity, TypeCvtOp)
KERN(8, 8, nobias, BiasMode::NO_BIAS, relu, ReluOp)
KERN(8, 8, nobias, BiasMode::NO_BIAS, hswish, HSwishOp)
#undef DEFINE_OP
// Fused-add epilogue ops take an extra bias scale; the int32 bias shares the
// product scale scale_A * scale_B.
#define DEFINE_OP(_Op)                                         \
    arm_common::_Op<dt_qint32, dt_quint8> op(scale_A* scale_B, \
                                             scale_A* scale_B, scale_C, zp_C);
KERN(8, 8, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp)
KERN(8, 8, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp)
KERN(8, 8, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish,
     FuseAddHSwishOp)
#undef DEFINE_OP
#undef KERN
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,48 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/conv_bias/quint8/strategy.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#pragma once
#include "src/fallback/matrix_mul/gemm_common.h"
namespace megdnn {
namespace aarch64 {
namespace matmul {
#if __ARM_FEATURE_DOTPROD
// Base quint8 conv-bias GEMM strategy, 8x8 output block with K step 4, for
// CPUs with the Armv8.2 dot-product extension.
MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_uint8, dt_uint8, dt_int32, 8, 8, 4,
                                        false, true,
                                        gemm_u8_8x8_nobias_identity);
#else
// Fallback base strategy with K step 8 when dotprod is unavailable.
MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_uint8, dt_uint8, dt_int32, 8, 8, 8,
                                        false, true,
                                        gemm_u8_8x8_nobias_identity);
#endif
// Bias/activation variants derive from the base strategy, sharing its
// packing routines and micro-kernels.
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nobias_relu,
                                    gemm_u8_8x8_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nobias_hswish,
                                    gemm_u8_8x8_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_bias_channel_identity,
                                    gemm_u8_8x8_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_bias_channel_relu,
                                    gemm_u8_8x8_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_bias_channel_hswish,
                                    gemm_u8_8x8_nobias_identity);
}  // namespace matmul
}  // namespace aarch64
}  // namespace megdnn
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,44 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/handle.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "src/common/handle_impl.h" | |||||
#include "src/aarch64/handle.h" | |||||
#include "src/aarch64/matrix_mul/opr_impl.h" | |||||
#include "src/aarch64/rotate/opr_impl.h" | |||||
#include "src/aarch64/relayout/opr_impl.h" | |||||
#include "src/aarch64/conv_bias/opr_impl.h" | |||||
#include "src/aarch64/warp_perspective/opr_impl.h" | |||||
namespace megdnn {
namespace aarch64 {
// Default: delegate operator creation to the arm_common handle. The
// MEGDNN_SPECIALIZE_CREATE_OPERATOR lines below override this for operators
// with aarch64-specific implementations.
template <typename Opr>
std::unique_ptr<Opr> HandleImpl::create_operator() {
    return arm_common::HandleImpl::create_operator<Opr>();
}
MEGDNN_SPECIALIZE_CREATE_OPERATOR(MatrixMul)
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Rotate)
MEGDNN_SPECIALIZE_CREATE_OPERATOR(RelayoutForward)
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvBias)
MEGDNN_SPECIALIZE_CREATE_OPERATOR(WarpPerspective)
// Instantiate create_operator for every operator class. The pragmas silence
// the instantiation-after-specialization warning (and -Wpragmas on compilers
// that do not know that warning).
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wpragmas"
#pragma GCC diagnostic ignored "-Winstantiation-after-specialization"
MEGDNN_FOREACH_OPR_CLASS(MEGDNN_INST_CREATE_OPERATOR)
#pragma GCC diagnostic pop
} // namespace aarch64
} // namespace megdnn
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,33 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/handle.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#pragma once | |||||
#include "src/arm_common/handle.h" | |||||
namespace megdnn { | |||||
namespace aarch64 { | |||||
// aarch64 handle: extends arm_common::HandleImpl, overriding operator
// creation for operators that have aarch64-specific implementations.
class HandleImpl: public arm_common::HandleImpl {
    public:
        // Delegates construction to arm_common::HandleImpl; `type` defaults
        // to AARCH64 so the handle identifies this backend.
        HandleImpl(megcoreComputingHandle_t computing_handle,
                   HandleType type = HandleType::AARCH64):
            arm_common::HandleImpl::HandleImpl(computing_handle, type)
        {}
        // Factory for megdnn operators; specialized per operator in the
        // corresponding .cpp file.
        template <typename Opr>
        std::unique_ptr<Opr> create_operator();
};
} // namespace aarch64 | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen | |||||
@@ -0,0 +1,250 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/matrix_mul/algos.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#pragma once | |||||
#include "src/aarch64/matrix_mul/opr_impl.h" | |||||
#include "src/arm_common/matrix_mul/algos.h" | |||||
#include "src/fallback/matrix_mul/gemm_common.h" | |||||
namespace megdnn { | |||||
namespace aarch64 { | |||||
//! f32 GEMM algorithm "AARCH64_F32K8X12X1"; exposes the im2col GEMM hooks.
class MatrixMulImpl::AlgoF32K8x12x1 final : public AlgoBase {
public:
    bool is_reproducible() const override { return true; }
    const char* name() const override { return "AARCH64_F32K8X12X1"; }
    bool usable(const KernSizeParam&) const override;
    size_t get_workspace(const KernSizeParam&) const override;
    kern_t get_kern(const KernSizeParam&) const override;
    void* type() const override { return sm_arm_common_algo_type; }
    MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
//! f32 GEMM algorithm "AARCH64_F32K4X16X1"; exposes the im2col GEMM hooks.
class MatrixMulImpl::AlgoF32K4x16x1 final : public AlgoBase {
public:
    bool is_reproducible() const override { return true; }
    const char* name() const override { return "AARCH64_F32K4X16X1"; }
    bool usable(const KernSizeParam&) const override;
    size_t get_workspace(const KernSizeParam&) const override;
    kern_t get_kern(const KernSizeParam&) const override;
    void* type() const override { return sm_arm_common_algo_type; }
    MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
//! f32 GEMM algorithm "AARCH64_F32_MK4_4x16" for the MK4 layout; operands
//! are consumed without packing (PackMode::NO_PACK).
class MatrixMulImpl::AlgoF32MK4_4x16 final : public AlgoBase {
public:
    bool is_reproducible() const override { return true; }
    const char* name() const override { return "AARCH64_F32_MK4_4x16"; }
    bool usable(const KernSizeParam&) const override;
    size_t get_workspace(const KernSizeParam&) const override;
    kern_t get_kern(const KernSizeParam&) const override;
    void* type() const override { return sm_arm_common_algo_type; }
    PackMode packmode() const override { return PackMode::NO_PACK; }
};
//! Reuses the arm_common f32 GEMV implementation unchanged.
class MatrixMulImpl::AlgoF32Gemv final
        : public arm_common::MatrixMulImpl::AlgoF32Gemv {};
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
//! f16 GEMM algorithm "AARCH64_F16_K8X24X1" (only compiled when fp16 vector
//! arithmetic is available); exposes the im2col GEMM hooks.
class MatrixMulImpl::AlgoF16K8x24x1 final : public AlgoBase {
public:
    bool is_reproducible() const override { return true; }
    const char* name() const override { return "AARCH64_F16_K8X24X1"; }
    bool usable(const KernSizeParam&) const override;
    size_t get_workspace(const KernSizeParam&) const override;
    kern_t get_kern(const KernSizeParam&) const override;
    void* type() const override { return sm_arm_common_algo_type; }
    MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
//! f16 GEMM algorithm "AARCH64_F16_MK8_8X8" for the MK8 layout; operands are
//! consumed without packing (PackMode::NO_PACK).
class MatrixMulImpl::AlgoF16MK8_8x8 final : public AlgoBase {
public:
    bool is_reproducible() const override { return true; }
    const char* name() const override { return "AARCH64_F16_MK8_8X8"; }
    bool usable(const KernSizeParam&) const override;
    size_t get_workspace(const KernSizeParam&) const override;
    kern_t get_kern(const KernSizeParam&) const override;
    void* type() const override { return sm_arm_common_algo_type; }
    PackMode packmode() const override { return PackMode::NO_PACK; }
};
#endif | |||||
#if __ARM_FEATURE_DOTPROD | |||||
//! int8x8x32 GEMM "AARCH64_INT8X8X32_K8X12X4_DOTPROD"; requires the Armv8.2
//! dot-product extension. Exposes the im2col GEMM hooks.
class MatrixMulImpl::AlgoInt8x8x32K8x12x4DotProd final : public AlgoBase {
public:
    bool is_reproducible() const override { return true; }
    const char* name() const override {
        return "AARCH64_INT8X8X32_K8X12X4_DOTPROD";
    }
    bool usable(const KernSizeParam&) const override;
    size_t get_workspace(const KernSizeParam&) const override;
    kern_t get_kern(const KernSizeParam&) const override;
    void* type() const override { return sm_arm_common_algo_type; }
    MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
//! int8x8x32 GEMV "AARCH64_INT8X8X32_GEMV_DOTPROD": matrix-vector product
//! using dotprod; needs no workspace and no packing (ALGO_TYPE_GEMV).
class MatrixMulImpl::AlgoInt8x8x32GemvDotProd final : public AlgoBase {
public:
    bool is_reproducible() const override { return true; }
    const char* name() const override {
        return "AARCH64_INT8X8X32_GEMV_DOTPROD";
    }
    bool usable(const KernSizeParam&) const override;
    bool preferred(const KernSizeParam&) const override;
    size_t get_workspace(const KernSizeParam&) const override { return 0; }
    kern_t get_kern(const KernSizeParam&) const override;
    void* type() const override { return sm_arm_common_algo_type; }
    AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
    PackMode packmode() const override { return PackMode::NO_PACK; }
};
#else | |||||
//! int8x8x32 GEMM "AARCH64_INT8X8X32_MK4_4X4X16" for the MK4 layout
//! (non-dotprod build); uses the default pack mode and im2col hooks.
class MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16 final : public AlgoBase {
public:
    bool is_reproducible() const override { return true; }
    const char* name() const override {
        return "AARCH64_INT8X8X32_MK4_4X4X16";
    }
    bool usable(const KernSizeParam&) const override;
    bool preferred(const KernSizeParam&) const override;
    size_t get_workspace(const KernSizeParam&) const override;
    kern_t get_kern(const KernSizeParam&) const override;
    void* type() const override { return sm_arm_common_algo_type; }
    PackMode packmode() const override { return PackMode::DEFAULT; }
    MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
//! int8x8x32 GEMM "AARCH64_INT8X8X32_K4X4X16" (non-dotprod build); exposes
//! the im2col GEMM hooks.
class MatrixMulImpl::AlgoInt8x8x32K4x4x16 final : public AlgoBase {
public:
    bool is_reproducible() const override { return true; }
    const char* name() const override { return "AARCH64_INT8X8X32_K4X4X16"; }
    bool usable(const KernSizeParam&) const override;
    bool preferred(const KernSizeParam&) const override;
    size_t get_workspace(const KernSizeParam&) const override;
    kern_t get_kern(const KernSizeParam&) const override;
    void* type() const override { return sm_arm_common_algo_type; }
    MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
//! int8x8x32 GEMM "AARCH64_INT8X8X32_K8X8X8" (non-dotprod build); exposes
//! the im2col GEMM hooks.
class MatrixMulImpl::AlgoInt8x8x32K8x8x8 final : public AlgoBase {
public:
    bool is_reproducible() const override { return true; }
    const char* name() const override { return "AARCH64_INT8X8X32_K8X8X8"; }
    bool usable(const KernSizeParam&) const override;
    bool preferred(const KernSizeParam&) const override;
    size_t get_workspace(const KernSizeParam&) const override;
    kern_t get_kern(const KernSizeParam&) const override;
    void* type() const override { return sm_arm_common_algo_type; }
    MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
//! Reuses the arm_common int8x8x32 GEMV implementation unchanged.
class MatrixMulImpl::AlgoInt8x8x32Gemv final
        : public arm_common::MatrixMulImpl::AlgoInt8x8x32Gemv {};
#endif | |||||
//! int8x8x16 GEMM "AARCH64_INT8X8X16_K8X8X8"; exposes the im2col GEMM hooks.
class MatrixMulImpl::AlgoInt8x8x16K8x8x8 final : public AlgoBase {
public:
    bool is_reproducible() const override { return true; }
    const char* name() const override { return "AARCH64_INT8X8X16_K8X8X8"; }
    bool usable(const KernSizeParam&) const override;
    bool preferred(const KernSizeParam&) const override;
    size_t get_workspace(const KernSizeParam&) const override;
    kern_t get_kern(const KernSizeParam&) const override;
    void* type() const override { return sm_arm_common_algo_type; }
    MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
//! int8x8x16 GEMM "AARCH64_INT8X8X16_K4X4X16"; exposes the im2col GEMM hooks.
class MatrixMulImpl::AlgoInt8x8x16K4x4x16 final : public AlgoBase {
public:
    bool is_reproducible() const override { return true; }
    const char* name() const override { return "AARCH64_INT8X8X16_K4X4X16"; }
    bool usable(const KernSizeParam&) const override;
    bool preferred(const KernSizeParam&) const override;
    size_t get_workspace(const KernSizeParam&) const override;
    kern_t get_kern(const KernSizeParam&) const override;
    void* type() const override { return sm_arm_common_algo_type; }
    MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
//! int16x16x32 GEMM "AARCH64_INT16X16X32_K12X8X1"; exposes the im2col GEMM
//! hooks.
class MatrixMulImpl::AlgoInt16x16x32K12x8x1 final : public AlgoBase {
public:
    bool is_reproducible() const override { return true; }
    const char* name() const override { return "AARCH64_INT16X16X32_K12X8X1"; }
    bool usable(const KernSizeParam&) const override;
    bool preferred(const KernSizeParam&) const override;
    size_t get_workspace(const KernSizeParam&) const override;
    kern_t get_kern(const KernSizeParam&) const override;
    void* type() const override { return sm_arm_common_algo_type; }
    MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
//! int16x16x32 GEMM "AARCH64_INT16X16X32_MK8_8X8" for the MK8 layout;
//! operands are consumed without packing (PackMode::NO_PACK).
class MatrixMulImpl::AlgoInt16x16x32MK8_8x8 final : public AlgoBase {
public:
    bool is_reproducible() const override { return true; }
    const char* name() const override { return "AARCH64_INT16X16X32_MK8_8X8"; }
    bool usable(const KernSizeParam&) const override;
    size_t get_workspace(const KernSizeParam&) const override;
    kern_t get_kern(const KernSizeParam&) const override;
    void* type() const override { return sm_arm_common_algo_type; }
    PackMode packmode() const override { return PackMode::NO_PACK; }
};
#if __ARM_FEATURE_DOTPROD | |||||
//! quint8 GEMM "AARCH64_QUINT8_K8X8X4_DOTPROD"; requires the Armv8.2
//! dot-product extension. Exposes the im2col GEMM hooks.
class MatrixMulImpl::AlgoQuint8K8x8x4DotProd final : public AlgoBase {
public:
    bool is_reproducible() const override { return true; }
    const char* name() const override {
        return "AARCH64_QUINT8_K8X8X4_DOTPROD";
    }
    bool usable(const KernSizeParam&) const override;
    size_t get_workspace(const KernSizeParam&) const override;
    kern_t get_kern(const KernSizeParam&) const override;
    void* type() const override { return sm_arm_common_algo_type; }
    MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
//! quint8 GEMV "AARCH64_QUINT8_GEMV_DOTPROD": matrix-vector product using
//! dotprod; needs no workspace and no packing (ALGO_TYPE_GEMV).
class MatrixMulImpl::AlgoQuint8GemvDotProd final : public AlgoBase {
public:
    bool is_reproducible() const override { return true; }
    const char* name() const override { return "AARCH64_QUINT8_GEMV_DOTPROD"; }
    bool usable(const KernSizeParam&) const override;
    bool preferred(const KernSizeParam&) const override;
    size_t get_workspace(const KernSizeParam&) const override { return 0; }
    kern_t get_kern(const KernSizeParam&) const override;
    void* type() const override { return sm_arm_common_algo_type; }
    AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
    PackMode packmode() const override { return PackMode::NO_PACK; }
};
#else | |||||
//! quint8 GEMM "AARCH64_QUINT8_K8X8X8" (non-dotprod build); exposes the
//! im2col GEMM hooks.
class MatrixMulImpl::AlgoQuint8K8x8x8 final : public AlgoBase {
public:
    bool is_reproducible() const override { return true; }
    const char* name() const override { return "AARCH64_QUINT8_K8X8X8"; }
    bool usable(const KernSizeParam&) const override;
    size_t get_workspace(const KernSizeParam&) const override;
    kern_t get_kern(const KernSizeParam&) const override;
    void* type() const override { return sm_arm_common_algo_type; }
    MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
#endif | |||||
} // namespace aarch64 | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,29 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/matrix_mul/fp16/strategy.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#pragma once | |||||
#include "src/fallback/matrix_mul/gemm_common.h" | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
namespace megdnn {
namespace aarch64 {
namespace matmul {
// fp16 GEMM strategy with packing: 8x24 output block, K step 1.
MEGDNN_REG_GEMM_STRATEGY(dt_float16, dt_float16, dt_float16, 8, 24, 1, false,
                         true, hgemm_8x24);
// No-pack fp16 strategy (8x8 block) used by the MK8-layout algorithm.
MEGDNN_REG_GEMM_STRATEGY_NOPACK(dt_float16, dt_float16, dt_float16, 8, 8, 1,
                                false, true, gemm_nopack_f16_8x8);
}  // namespace matmul
}  // namespace aarch64
}  // namespace megdnn
#endif
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,439 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/matrix_mul/fp16/strategy_mk8_8x8.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "src/aarch64/matrix_mul/fp16/strategy.h" | |||||
#include "src/aarch64/matrix_mul/asm/common.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/common/utils.h" | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
using namespace megdnn; | |||||
using namespace aarch64; | |||||
using namespace aarch64::matmul; | |||||
namespace { | |||||
// Overview of register layout: | |||||
// | |||||
// A 8x1 cell of Rhs is stored in 16bit in v0-v3 | |||||
// A 8x1 cell of Lhs is stored in 16bit in v16-v23 | |||||
// An 8x4 block of accumulators is stored in 16bit in v24-v27.
// | |||||
// Rhs +-------+ | |||||
// |v0[0-7]| | |||||
// |v1[0-7]| | |||||
// |v2[0-7]| | |||||
// |v3[0-7]| | |||||
// +-------+ | |||||
// Lhs | |||||
// +--------+ | |||||
// |v16[0-7]| | |||||
// |v17[0-7]| | |||||
// |v18[0-7]| | |||||
// |v19[0-7]| +--------+ | |||||
// |v20[0-7]| |v24[0-7]| | |||||
// |v21[0-7]| |v25[0-7]| | |||||
// |v22[0-7]| |v26[0-7]| | |||||
// |v23[0-7]| |v27[0-7]| | |||||
// +--------+ +--------+ | |||||
// Accumulator | |||||
//! fp16 micro-kernel computing an 8x4 output tile: 8 rows from packed A (one
//! 8-wide vector per k step) against 4 columns of B (v0-v3). K is consumed 8
//! at a time: the prologue handles the first 8 k values (fmul initializes the
//! accumulators), label 1 is the steady-state loop, label 2 the epilogue that
//! applies the final h[7] term and stores v24-v27. Caller must supply K as a
//! positive multiple of 8 (the first `subs #8` happens before the loop test).
void kern_8x4(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K,
              dt_float16* output) {
    //! LDB means number of elements in one block in B. we will read 24 numbers
    //! first. so minus 24 * 2 bytes here.
    LDB = (LDB - 24) * sizeof(dt_float16);
    asm volatile(
            ".arch armv8.2-a+fp16\n"
            // Prologue: first 8 k-steps; fmul overwrites (initializes) the
            // accumulators, subsequent fmla accumulate.
            "ld1 {v16.4s, v17.4s}, [%[a_ptr]], 32\n"
            "subs %w[K], %w[K], #8\n"
            "ld1 {v0.4s}, [%[b_ptr]], 16\n"
            "ld1 {v1.4s}, [%[b_ptr]], 16\n"
            "fmul v24.8h, v16.8h, v0.h[0]\n"
            "ld1 {v2.4s}, [%[b_ptr]], 16\n"
            "fmul v25.8h, v16.8h, v1.h[0]\n"
            "ld1 {v3.4s}, [%[b_ptr]], %x[LDB]\n"
            "fmul v26.8h, v16.8h, v2.h[0]\n"
            "ld1 {v18.4s}, [%[a_ptr]], 16\n"
            "fmul v27.8h, v16.8h, v3.h[0]\n"
            "fmla v24.8h, v17.8h, v0.h[1]\n"
            "fmla v25.8h, v17.8h, v1.h[1]\n"
            "fmla v26.8h, v17.8h, v2.h[1]\n"
            "fmla v27.8h, v17.8h, v3.h[1]\n"
            "ld1 {v19.4s}, [%[a_ptr]], 16\n"
            "fmla v24.8h, v18.8h, v0.h[2]\n"
            "fmla v25.8h, v18.8h, v1.h[2]\n"
            "fmla v26.8h, v18.8h, v2.h[2]\n"
            "fmla v27.8h, v18.8h, v3.h[2]\n"
            "ld1 {v20.4s}, [%[a_ptr]], 16\n"
            "fmla v24.8h, v19.8h, v0.h[3]\n"
            "fmla v25.8h, v19.8h, v1.h[3]\n"
            "fmla v26.8h, v19.8h, v2.h[3]\n"
            "fmla v27.8h, v19.8h, v3.h[3]\n"
            "ld1 {v21.4s}, [%[a_ptr]], 16\n"
            "fmla v24.8h, v20.8h, v0.h[4]\n"
            "fmla v25.8h, v20.8h, v1.h[4]\n"
            "fmla v26.8h, v20.8h, v2.h[4]\n"
            "fmla v27.8h, v20.8h, v3.h[4]\n"
            "ld1 {v22.4s}, [%[a_ptr]], 16\n"
            "fmla v24.8h, v21.8h, v0.h[5]\n"
            "fmla v25.8h, v21.8h, v1.h[5]\n"
            "fmla v26.8h, v21.8h, v2.h[5]\n"
            "fmla v27.8h, v21.8h, v3.h[5]\n"
            "ld1 {v23.4s}, [%[a_ptr]], 16\n"
            "fmla v24.8h, v22.8h, v0.h[6]\n"
            "fmla v25.8h, v22.8h, v1.h[6]\n"
            "fmla v26.8h, v22.8h, v2.h[6]\n"
            "fmla v27.8h, v22.8h, v3.h[6]\n"
            "beq 2f\n"
            // Steady-state loop: finish the previous chunk's h[7] term while
            // the next chunk's B columns are being loaded.
            "1:\n"
            "ld1 {v16.4s}, [%[a_ptr]], 16\n"
            "fmla v24.8h, v23.8h, v0.h[7]\n"
            "ld1 {v0.4s}, [%[b_ptr]], 16\n"
            "fmla v25.8h, v23.8h, v1.h[7]\n"
            "ld1 {v1.4s}, [%[b_ptr]], 16\n"
            "fmla v26.8h, v23.8h, v2.h[7]\n"
            "ld1 {v2.4s}, [%[b_ptr]], 16\n"
            "fmla v27.8h, v23.8h, v3.h[7]\n"
            "ld1 {v3.4s}, [%[b_ptr]], %x[LDB]\n"
            "ld1 {v17.4s}, [%[a_ptr]], 16\n"
            "fmla v24.8h, v16.8h, v0.h[0]\n"
            "fmla v25.8h, v16.8h, v1.h[0]\n"
            "fmla v26.8h, v16.8h, v2.h[0]\n"
            "fmla v27.8h, v16.8h, v3.h[0]\n"
            "ld1 {v18.4s}, [%[a_ptr]], 16\n"
            "fmla v24.8h, v17.8h, v0.h[1]\n"
            "fmla v25.8h, v17.8h, v1.h[1]\n"
            "fmla v26.8h, v17.8h, v2.h[1]\n"
            "fmla v27.8h, v17.8h, v3.h[1]\n"
            "ld1 {v19.4s}, [%[a_ptr]], 16\n"
            "fmla v24.8h, v18.8h, v0.h[2]\n"
            "fmla v25.8h, v18.8h, v1.h[2]\n"
            "fmla v26.8h, v18.8h, v2.h[2]\n"
            "fmla v27.8h, v18.8h, v3.h[2]\n"
            "ld1 {v20.4s}, [%[a_ptr]], 16\n"
            "fmla v24.8h, v19.8h, v0.h[3]\n"
            "fmla v25.8h, v19.8h, v1.h[3]\n"
            "fmla v26.8h, v19.8h, v2.h[3]\n"
            "fmla v27.8h, v19.8h, v3.h[3]\n"
            "ld1 {v21.4s}, [%[a_ptr]], 16\n"
            "fmla v24.8h, v20.8h, v0.h[4]\n"
            "fmla v25.8h, v20.8h, v1.h[4]\n"
            "fmla v26.8h, v20.8h, v2.h[4]\n"
            "fmla v27.8h, v20.8h, v3.h[4]\n"
            "ld1 {v22.4s}, [%[a_ptr]], 16\n"
            "fmla v24.8h, v21.8h, v0.h[5]\n"
            "fmla v25.8h, v21.8h, v1.h[5]\n"
            "fmla v26.8h, v21.8h, v2.h[5]\n"
            "fmla v27.8h, v21.8h, v3.h[5]\n"
            "ld1 {v23.4s}, [%[a_ptr]], 16\n"
            "fmla v24.8h, v22.8h, v0.h[6]\n"
            "fmla v25.8h, v22.8h, v1.h[6]\n"
            "fmla v26.8h, v22.8h, v2.h[6]\n"
            "fmla v27.8h, v22.8h, v3.h[6]\n"
            "subs %w[K], %w[K], #8\n"
            "bne 1b\n"
            // Epilogue: last h[7] term, then store the four accumulators.
            "2:\n"
            "fmla v24.8h, v23.8h, v0.h[7]\n"
            "fmla v25.8h, v23.8h, v1.h[7]\n"
            "fmla v26.8h, v23.8h, v2.h[7]\n"
            "fmla v27.8h, v23.8h, v3.h[7]\n"
            "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%[output]], 64\n"
            : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
              [output] "+r"(output), [LDB] "+r"(LDB)
            :
            : "v0", "v1", "v2", "v3", "v16", "v17", "v18", "v19", "v20", "v21",
              "v22", "v23", "v24", "v25", "v26", "v27", "cc", "memory");
}
// Overview of register layout: | |||||
// | |||||
// A 8x1 cell of Rhs is stored in 16bit in v8-v15 | |||||
// A 8x1 cell of Lhs is stored in 16bit in v0-v7 | |||||
// An 8x8 block of accumulators is stored in 16bit in v24-v31.
// | |||||
// Rhs +--------+ | |||||
// | v8[0-7]| | |||||
// | v9[0-7]| | |||||
// |v10[0-7]| | |||||
// |v11[0-7]| | |||||
// |v12[0-7]| | |||||
// |v13[0-7]| | |||||
// |v14[0-7]| | |||||
// |v15[0-7]| | |||||
// +--------+ | |||||
// Lhs | |||||
// +--------+ - - - - -+--------+ | |||||
// | v0[0-7]| |v24[0-7]| | |||||
// | v1[0-7]| |v25[0-7]| | |||||
// | v2[0-7]| |v26[0-7]| | |||||
// | v3[0-7]| |v27[0-7]| | |||||
// | v4[0-7]| |v28[0-7]| | |||||
// | v5[0-7]| |v29[0-7]| | |||||
// | v6[0-7]| |v30[0-7]| | |||||
// | v7[0-7]| |v31[0-7]| | |||||
// +--------+ +--------+ | |||||
// Accumulator | |||||
void kern_8x8(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||||
dt_float16* output) { | |||||
    //! The first B load advances the pointer by 32 fp16 elements (64 bytes)
    //! before the LDB-relative load, so subtract 32 from LDB (in elements)
    //! before converting it to bytes.
LDB = (LDB - 32) * sizeof(dt_float16); | |||||
asm volatile( | |||||
".arch armv8.2-a+fp16\n" | |||||
"ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[a_ptr]], 64\n" | |||||
"subs %w[K], %w[K], #8\n" | |||||
"ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b_ptr]], 64\n" | |||||
"ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[b_ptr]], %x[LDB]\n" | |||||
"fmul v24.8h, v8.8h, v0.h[0]\n" | |||||
"fmul v25.8h, v8.8h, v1.h[0]\n" | |||||
"fmul v26.8h, v8.8h, v2.h[0]\n" | |||||
"fmul v27.8h, v8.8h, v3.h[0]\n" | |||||
"fmul v28.8h, v8.8h, v4.h[0]\n" | |||||
"fmul v29.8h, v8.8h, v5.h[0]\n" | |||||
"fmul v30.8h, v8.8h, v6.h[0]\n" | |||||
"fmul v31.8h, v8.8h, v7.h[0]\n" | |||||
"fmla v24.8h, v9.8h, v0.h[1]\n" | |||||
"fmla v25.8h, v9.8h, v1.h[1]\n" | |||||
"fmla v26.8h, v9.8h, v2.h[1]\n" | |||||
"fmla v27.8h, v9.8h, v3.h[1]\n" | |||||
"fmla v28.8h, v9.8h, v4.h[1]\n" | |||||
"fmla v29.8h, v9.8h, v5.h[1]\n" | |||||
"fmla v30.8h, v9.8h, v6.h[1]\n" | |||||
"fmla v31.8h, v9.8h, v7.h[1]\n" | |||||
"ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%[a_ptr]], 64\n" | |||||
"fmla v24.8h, v10.8h, v0.h[2]\n" | |||||
"fmla v25.8h, v10.8h, v1.h[2]\n" | |||||
"fmla v26.8h, v10.8h, v2.h[2]\n" | |||||
"fmla v27.8h, v10.8h, v3.h[2]\n" | |||||
"fmla v28.8h, v10.8h, v4.h[2]\n" | |||||
"fmla v29.8h, v10.8h, v5.h[2]\n" | |||||
"fmla v30.8h, v10.8h, v6.h[2]\n" | |||||
"fmla v31.8h, v10.8h, v7.h[2]\n" | |||||
"fmla v24.8h, v11.8h, v0.h[3]\n" | |||||
"fmla v25.8h, v11.8h, v1.h[3]\n" | |||||
"fmla v26.8h, v11.8h, v2.h[3]\n" | |||||
"fmla v27.8h, v11.8h, v3.h[3]\n" | |||||
"fmla v28.8h, v11.8h, v4.h[3]\n" | |||||
"fmla v29.8h, v11.8h, v5.h[3]\n" | |||||
"fmla v30.8h, v11.8h, v6.h[3]\n" | |||||
"fmla v31.8h, v11.8h, v7.h[3]\n" | |||||
"fmla v24.8h, v12.8h, v0.h[4]\n" | |||||
"fmla v25.8h, v12.8h, v1.h[4]\n" | |||||
"fmla v26.8h, v12.8h, v2.h[4]\n" | |||||
"fmla v27.8h, v12.8h, v3.h[4]\n" | |||||
"fmla v24.8h, v13.8h, v0.h[5]\n" | |||||
"fmla v25.8h, v13.8h, v1.h[5]\n" | |||||
"fmla v26.8h, v13.8h, v2.h[5]\n" | |||||
"fmla v27.8h, v13.8h, v3.h[5]\n" | |||||
"beq 2f\n" | |||||
"1:\n" | |||||
"ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[a_ptr]], 64\n" | |||||
"fmla v24.8h, v15.8h, v0.h[7]\n" | |||||
"fmla v25.8h, v15.8h, v1.h[7]\n" | |||||
"fmla v26.8h, v15.8h, v2.h[7]\n" | |||||
"fmla v27.8h, v15.8h, v3.h[7]\n" | |||||
"fmla v24.8h, v14.8h, v0.h[6]\n" | |||||
"fmla v25.8h, v14.8h, v1.h[6]\n" | |||||
"fmla v26.8h, v14.8h, v2.h[6]\n" | |||||
"fmla v27.8h, v14.8h, v3.h[6]\n" | |||||
"fmla v28.8h, v12.8h, v4.h[4]\n" | |||||
"fmla v29.8h, v12.8h, v5.h[4]\n" | |||||
"fmla v30.8h, v12.8h, v6.h[4]\n" | |||||
"fmla v31.8h, v12.8h, v7.h[4]\n" | |||||
"ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b_ptr]], 64\n" | |||||
"fmla v28.8h, v13.8h, v4.h[5]\n" | |||||
"fmla v29.8h, v13.8h, v5.h[5]\n" | |||||
"fmla v30.8h, v13.8h, v6.h[5]\n" | |||||
"fmla v31.8h, v13.8h, v7.h[5]\n" | |||||
"fmla v28.8h, v14.8h, v4.h[6]\n" | |||||
"fmla v29.8h, v14.8h, v5.h[6]\n" | |||||
"fmla v30.8h, v14.8h, v6.h[6]\n" | |||||
"fmla v31.8h, v14.8h, v7.h[6]\n" | |||||
"fmla v28.8h, v15.8h, v4.h[7]\n" | |||||
"fmla v29.8h, v15.8h, v5.h[7]\n" | |||||
"fmla v30.8h, v15.8h, v6.h[7]\n" | |||||
"fmla v31.8h, v15.8h, v7.h[7]\n" | |||||
"fmla v24.8h, v8.8h, v0.h[0]\n" | |||||
"fmla v25.8h, v8.8h, v1.h[0]\n" | |||||
"fmla v26.8h, v8.8h, v2.h[0]\n" | |||||
"fmla v27.8h, v8.8h, v3.h[0]\n" | |||||
"ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[b_ptr]], %x[LDB]\n" | |||||
"fmla v24.8h, v9.8h, v0.h[1]\n" | |||||
"fmla v25.8h, v9.8h, v1.h[1]\n" | |||||
"fmla v26.8h, v9.8h, v2.h[1]\n" | |||||
"fmla v27.8h, v9.8h, v3.h[1]\n" | |||||
"fmla v24.8h, v10.8h, v0.h[2]\n" | |||||
"fmla v25.8h, v10.8h, v1.h[2]\n" | |||||
"fmla v26.8h, v10.8h, v2.h[2]\n" | |||||
"fmla v27.8h, v10.8h, v3.h[2]\n" | |||||
"fmla v24.8h, v11.8h, v0.h[3]\n" | |||||
"fmla v25.8h, v11.8h, v1.h[3]\n" | |||||
"fmla v26.8h, v11.8h, v2.h[3]\n" | |||||
"fmla v27.8h, v11.8h, v3.h[3]\n" | |||||
"ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%[a_ptr]], 64\n" | |||||
"fmla v28.8h, v10.8h, v4.h[2]\n" | |||||
"fmla v29.8h, v10.8h, v5.h[2]\n" | |||||
"fmla v30.8h, v10.8h, v6.h[2]\n" | |||||
"fmla v31.8h, v10.8h, v7.h[2]\n" | |||||
"fmla v28.8h, v8.8h, v4.h[0]\n" | |||||
"fmla v29.8h, v8.8h, v5.h[0]\n" | |||||
"fmla v30.8h, v8.8h, v6.h[0]\n" | |||||
"fmla v31.8h, v8.8h, v7.h[0]\n" | |||||
"fmla v28.8h, v9.8h, v4.h[1]\n" | |||||
"fmla v29.8h, v9.8h, v5.h[1]\n" | |||||
"fmla v30.8h, v9.8h, v6.h[1]\n" | |||||
"fmla v31.8h, v9.8h, v7.h[1]\n" | |||||
"fmla v28.8h, v11.8h, v4.h[3]\n" | |||||
"fmla v29.8h, v11.8h, v5.h[3]\n" | |||||
"fmla v30.8h, v11.8h, v6.h[3]\n" | |||||
"fmla v31.8h, v11.8h, v7.h[3]\n" | |||||
"fmla v24.8h, v12.8h, v0.h[4]\n" | |||||
"fmla v25.8h, v12.8h, v1.h[4]\n" | |||||
"fmla v26.8h, v12.8h, v2.h[4]\n" | |||||
"fmla v27.8h, v12.8h, v3.h[4]\n" | |||||
"fmla v24.8h, v13.8h, v0.h[5]\n" | |||||
"fmla v25.8h, v13.8h, v1.h[5]\n" | |||||
"fmla v26.8h, v13.8h, v2.h[5]\n" | |||||
"fmla v27.8h, v13.8h, v3.h[5]\n" | |||||
"subs %w[K], %w[K], #8\n" | |||||
"bne 1b\n" | |||||
"2:\n" | |||||
"fmla v24.8h, v14.8h, v0.h[6]\n" | |||||
"fmla v25.8h, v14.8h, v1.h[6]\n" | |||||
"fmla v26.8h, v14.8h, v2.h[6]\n" | |||||
"fmla v27.8h, v14.8h, v3.h[6]\n" | |||||
"fmla v24.8h, v15.8h, v0.h[7]\n" | |||||
"fmla v25.8h, v15.8h, v1.h[7]\n" | |||||
"fmla v26.8h, v15.8h, v2.h[7]\n" | |||||
"fmla v27.8h, v15.8h, v3.h[7]\n" | |||||
"fmla v28.8h, v12.8h, v4.h[4]\n" | |||||
"fmla v29.8h, v12.8h, v5.h[4]\n" | |||||
"fmla v28.8h, v13.8h, v4.h[5]\n" | |||||
"fmla v29.8h, v13.8h, v5.h[5]\n" | |||||
"st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%[output]], 64\n" | |||||
"fmla v28.8h, v14.8h, v4.h[6]\n" | |||||
"fmla v29.8h, v14.8h, v5.h[6]\n" | |||||
"fmla v28.8h, v15.8h, v4.h[7]\n" | |||||
"fmla v29.8h, v15.8h, v5.h[7]\n" | |||||
"fmla v30.8h, v12.8h, v6.h[4]\n" | |||||
"fmla v31.8h, v12.8h, v7.h[4]\n" | |||||
"fmla v30.8h, v13.8h, v6.h[5]\n" | |||||
"fmla v31.8h, v13.8h, v7.h[5]\n" | |||||
"fmla v30.8h, v14.8h, v6.h[6]\n" | |||||
"fmla v31.8h, v14.8h, v7.h[6]\n" | |||||
"fmla v30.8h, v15.8h, v6.h[7]\n" | |||||
"fmla v31.8h, v15.8h, v7.h[7]\n" | |||||
"st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%[output]], 64\n" | |||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||||
[output] "+r"(output), [LDB] "+r"(LDB) | |||||
: | |||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
"v11", "v12", "v13", "v14", "v15", "v24", "v25", "v26", "v27", | |||||
"v28", "v29", "v30", "v31", "cc", "memory"); | |||||
} | |||||
} // anonymous namespace | |||||
MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(gemm_nopack_f16_8x8); | |||||
void gemm_nopack_f16_8x8::kern(const dt_float16* A, size_t LDA, | |||||
const dt_float16* B, size_t LDB, dt_float16* C, | |||||
size_t LDC, size_t M, size_t K, size_t N, | |||||
const dt_float16*, void*, bool trA, | |||||
bool trB) const { | |||||
constexpr static size_t MB = 8; | |||||
constexpr static size_t KB = 8; | |||||
constexpr static size_t NB = 8; | |||||
constexpr static size_t CALCBLK = 4; | |||||
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0 && N % CALCBLK == 0); | |||||
//! (m/8, k/8, 8, 8) * (k/8, n, 8) = (m/8, n, 8) | |||||
for (size_t m = 0; m < M; m += MB) { | |||||
dt_float16* output = C + (m / MB) * LDC; | |||||
const dt_float16* cur_B = B; | |||||
size_t n = 0; | |||||
for (; n + NB - 1 < N; n += NB) { | |||||
kern_8x8(A, cur_B, LDB, K, output); | |||||
cur_B += KB * NB; | |||||
output += MB * NB; | |||||
} | |||||
if (n < N) { | |||||
kern_8x4(A, cur_B, LDB, K, output); | |||||
} | |||||
A += LDA; | |||||
} | |||||
} | |||||
#endif | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,39 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/matrix_mul/fp32/common.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#pragma once | |||||
#include <cstddef> | |||||
#include "megdnn/arch.h" | |||||
#include "src/common/utils.h" | |||||
namespace megdnn { | |||||
namespace aarch64 { | |||||
//! Pack an M x K panel of A (row-major, "n" = not transposed) into Apacked.
//! NOTE(review): `alpha` is presumably a scale applied while packing --
//! confirm in the implementation.
MEGDNN_NOINLINE void sgemm_packA_n(const float* A, float* Apacked, size_t M,
                                   size_t K, size_t LDA, const float* alpha);
//! Same as sgemm_packA_n, but A is stored transposed ("t").
MEGDNN_NOINLINE void sgemm_packA_t(const float* A, float* Apacked, size_t M,
                                   size_t K, size_t LDA, const float* alpha);
//! Pack a K x N panel of B (not transposed) into Bpacked.
MEGDNN_NOINLINE void sgemm_packB_n(const float* B, float* Bpacked, size_t K,
                                   size_t N, size_t LDB);
//! Same as sgemm_packB_n, but B is stored transposed.
MEGDNN_NOINLINE void sgemm_packB_t(const float* B, float* Bpacked, size_t K,
                                   size_t N, size_t LDB);
//! 12x8 fp32 GEMM micro-kernel over packed panels.
//! NOTE(review): the semantics of `type` and `beta` are not visible from
//! this header -- see the implementation before relying on them.
MEGDNN_NOINLINE void sgemm_kernel12x8(const float* A, const float* B, float* C,
                                      size_t LDC, size_t M, size_t N, size_t K,
                                      int type, const float* beta);
} // namespace aarch64 | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,718 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/matrix_mul/fp32/kernel_general_4x16.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "src/aarch64/matrix_mul/asm/common.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
namespace megdnn { | |||||
namespace aarch64 { | |||||
namespace matmul_general_4x16 { | |||||
// Overview of register layout: | |||||
// | |||||
// A 1x16 cell of Rhs is stored in 32bit in v1-v4 | |||||
// A 4x1 cell of Lhs is stored in 32bit in v0 | |||||
// A 4x16 block of accumulators is stored in 32bit in v10-v25. | |||||
// | |||||
// +--------+--------+--------+--------+ | |||||
// | v2[0-3]| v3[0-3]| v4[0-3]| v5[0-3]| | |||||
// Rhs +--------+--------+--------+--------+ | |||||
// | |||||
// | | | | | | |||||
// | |||||
// Lhs | | | | | | |||||
// | |||||
// +--+ - - - - +--------+--------+--------+--------+ | |||||
// |v0| |v10[0-3]|v11[0-3]|v12[0-3]|v13[0-3]| | |||||
// |v0| |v14[0-3]|v15[0-3]|v16[0-3]|v17[0-3]| | |||||
// |v0| |v18[0-3]|v19[0-3]|v20[0-3]|v21[0-3]| | |||||
// |v0| |v22[0-3]|v23[0-3]|v24[0-3]|v25[0-3]| | |||||
// +--+ - - - - +--------+--------+--------+--------+ | |||||
// | |||||
// Accumulator | |||||
//! Compute an (up to 4) x 16 block of C from packed A/B panels (fp32).
//! \param packA   A panel packed by sgemm_4x16_pack_A_* (4 rows per step)
//! \param packB   B panel packed by sgemm_4x16_pack_B_* (16 cols per step)
//! \param K       depth of the multiply; processed two steps per loop
//! \param output  top-left of the C sub-block
//! \param LDC     row stride of C in elements (converted to bytes below)
//! \param is_first_k  when true the accumulators start from zero ("eor"
//!                block); otherwise the existing C rows are loaded and
//!                accumulated into (LOAD_C)
//! \param m_remain    number of valid C rows (<= 4); x10 counts it down so
//!                rows past it are neither loaded nor stored
void kern_4x16(const float* packA, const float* packB, int K,
               float* output, int LDC, bool is_first_k, int m_remain) {
    const float* a_ptr = packA;
    const float* b_ptr = packB;
    int oddk = (K & 1);
    K = ((K + 1) / 2) - 1;
    LDC = LDC * sizeof(float);
    //! outptr is pinned to x0 because the macros below name the row
    //! pointers literally as x0..x3
    register float* outptr asm("x0") = reinterpret_cast<float*>(output);

// clang-format off
//! Load one C row (4 vectors) from the row pointer x<n> unless x10 (the
//! m_remain countdown) has reached zero; label 100 is the shared early-out.
#define LOAD_LINE(v0, v1, v2, v3, n)                \
    "cmp x10, #0\n"                                 \
    "beq 100f\n"                                    \
    "mov x9, x" n "\n"                              \
    "ld1 {v" v0 ".4s, v" v1 ".4s, v" v2 ".4s, v" v3 ".4s}, [x9], 64\n" \
    "subs x10, x10, #1\n"

#define LOAD_C                                      \
    "mov x10, %x[m_remain]\n"                       \
    LOAD_LINE("10", "11", "12", "13", "0")          \
    LOAD_LINE("14", "15", "16", "17", "1")          \
    LOAD_LINE("18", "19", "20", "21", "2")          \
    LOAD_LINE("22", "23", "24", "25", "3")          \
    "100:\n"

//! Mirror of LOAD_LINE for the store path; label 101 is the early-out.
#define STORE_LINE(v0, v1, v2, v3, n)               \
    "cmp x10, #0\n"                                 \
    "beq 101f\n"                                    \
    "mov x9, x" n "\n"                              \
    "st1 {v" v0 ".4s, v" v1 ".4s, v" v2 ".4s, v" v3 ".4s}, [x9], 64\n" \
    "subs x10, x10, #1\n"

#define STORE_C                                     \
    "mov x10, %x[m_remain]\n"                       \
    STORE_LINE("10", "11", "12", "13", "0")         \
    STORE_LINE("14", "15", "16", "17", "1")         \
    STORE_LINE("18", "19", "20", "21", "2")         \
    STORE_LINE("22", "23", "24", "25", "3")         \
    "101:\n"
// clang-format on

    asm volatile(
            // compute the three other row pointers, then either load C or
            // zero the accumulators
            "add x1, x0, %x[LDC]\n"
            "add x2, x1, %x[LDC]\n"
            "add x3, x2, %x[LDC]\n"
            "cmp %w[is_first_k], #1\n"
            "beq 1f\n" LOAD_C
            "b 2f\n"
            "1:\n"
            "eor v10.16b, v10.16b, v10.16b\n"
            "eor v11.16b, v11.16b, v11.16b\n"
            "eor v12.16b, v12.16b, v12.16b\n"
            "eor v13.16b, v13.16b, v13.16b\n"
            "eor v14.16b, v14.16b, v14.16b\n"
            "eor v15.16b, v15.16b, v15.16b\n"
            "eor v16.16b, v16.16b, v16.16b\n"
            "eor v17.16b, v17.16b, v17.16b\n"
            "eor v18.16b, v18.16b, v18.16b\n"
            "eor v19.16b, v19.16b, v19.16b\n"
            "eor v20.16b, v20.16b, v20.16b\n"
            "eor v21.16b, v21.16b, v21.16b\n"
            "eor v22.16b, v22.16b, v22.16b\n"
            "eor v23.16b, v23.16b, v23.16b\n"
            "eor v24.16b, v24.16b, v24.16b\n"
            "eor v25.16b, v25.16b, v25.16b\n"

            // main loop: two K-steps per iteration (v0/v2-v5 and v1/v6-v9)
            "2: \n"
            "ld1 {v2.4s, v3.4s, v4.4s, v5.4s}, [%[b_ptr]], 64\n"
            "cmp %w[K], #0\n"
            "beq 4f\n"

            "3:\n"
            "ld1 {v0.4s}, [%[a_ptr]], 16\n"
            "fmla v10.4s, v2.4s, v0.s[0]\n"
            "fmla v11.4s, v3.4s, v0.s[0]\n"
            "fmla v12.4s, v4.4s, v0.s[0]\n"
            "fmla v13.4s, v5.4s, v0.s[0]\n"
            "ld1 {v6.4s, v7.4s, v8.4s, v9.4s}, [%[b_ptr]], 64\n"
            "fmla v14.4s, v2.4s, v0.s[1]\n"
            "fmla v15.4s, v3.4s, v0.s[1]\n"
            "fmla v16.4s, v4.4s, v0.s[1]\n"
            "fmla v17.4s, v5.4s, v0.s[1]\n"
            "ld1 {v1.4s}, [%[a_ptr]], 16\n"
            "fmla v18.4s, v2.4s, v0.s[2]\n"
            "fmla v19.4s, v3.4s, v0.s[2]\n"
            "fmla v20.4s, v4.4s, v0.s[2]\n"
            "fmla v21.4s, v5.4s, v0.s[2]\n"
            "fmla v22.4s, v2.4s, v0.s[3]\n"
            "fmla v23.4s, v3.4s, v0.s[3]\n"
            "fmla v24.4s, v4.4s, v0.s[3]\n"
            "fmla v25.4s, v5.4s, v0.s[3]\n"
            "ld1 {v2.4s, v3.4s, v4.4s, v5.4s}, [%[b_ptr]], 64\n"
            "fmla v10.4s, v6.4s, v1.s[0]\n"
            "fmla v11.4s, v7.4s, v1.s[0]\n"
            "fmla v12.4s, v8.4s, v1.s[0]\n"
            "fmla v13.4s, v9.4s, v1.s[0]\n"
            "fmla v14.4s, v6.4s, v1.s[1]\n"
            "fmla v15.4s, v7.4s, v1.s[1]\n"
            "fmla v16.4s, v8.4s, v1.s[1]\n"
            "fmla v17.4s, v9.4s, v1.s[1]\n"
            "fmla v18.4s, v6.4s, v1.s[2]\n"
            "fmla v19.4s, v7.4s, v1.s[2]\n"
            "fmla v20.4s, v8.4s, v1.s[2]\n"
            "fmla v21.4s, v9.4s, v1.s[2]\n"
            "fmla v22.4s, v6.4s, v1.s[3]\n"
            "fmla v23.4s, v7.4s, v1.s[3]\n"
            "fmla v24.4s, v8.4s, v1.s[3]\n"
            "fmla v25.4s, v9.4s, v1.s[3]\n"
            "subs %w[K], %w[K], #1\n"
            "bne 3b\n"

            // tail: two remaining K-steps when K was even, one when odd
            "4:\n"
            "cmp %w[oddk], #1\n"
            "beq 5f\n"

            // Even tail
            "ld1 {v0.4s}, [%[a_ptr]], 16\n"
            "fmla v10.4s, v2.4s, v0.s[0]\n"
            "fmla v11.4s, v3.4s, v0.s[0]\n"
            "fmla v12.4s, v4.4s, v0.s[0]\n"
            "fmla v13.4s, v5.4s, v0.s[0]\n"
            "ld1 {v6.4s, v7.4s, v8.4s, v9.4s}, [%[b_ptr]], 64\n"
            "fmla v14.4s, v2.4s, v0.s[1]\n"
            "fmla v15.4s, v3.4s, v0.s[1]\n"
            "fmla v16.4s, v4.4s, v0.s[1]\n"
            "fmla v17.4s, v5.4s, v0.s[1]\n"
            "ld1 {v1.4s}, [%[a_ptr]], 16\n"
            "fmla v18.4s, v2.4s, v0.s[2]\n"
            "fmla v19.4s, v3.4s, v0.s[2]\n"
            "fmla v20.4s, v4.4s, v0.s[2]\n"
            "fmla v21.4s, v5.4s, v0.s[2]\n"
            "fmla v22.4s, v2.4s, v0.s[3]\n"
            "fmla v23.4s, v3.4s, v0.s[3]\n"
            "fmla v24.4s, v4.4s, v0.s[3]\n"
            "fmla v25.4s, v5.4s, v0.s[3]\n"
            "fmla v10.4s, v6.4s, v1.s[0]\n"
            "fmla v11.4s, v7.4s, v1.s[0]\n"
            "fmla v12.4s, v8.4s, v1.s[0]\n"
            "fmla v13.4s, v9.4s, v1.s[0]\n"
            "fmla v14.4s, v6.4s, v1.s[1]\n"
            "fmla v15.4s, v7.4s, v1.s[1]\n"
            "fmla v16.4s, v8.4s, v1.s[1]\n"
            "fmla v17.4s, v9.4s, v1.s[1]\n"
            "fmla v18.4s, v6.4s, v1.s[2]\n"
            "fmla v19.4s, v7.4s, v1.s[2]\n"
            "fmla v20.4s, v8.4s, v1.s[2]\n"
            "fmla v21.4s, v9.4s, v1.s[2]\n"
            "fmla v22.4s, v6.4s, v1.s[3]\n"
            "fmla v23.4s, v7.4s, v1.s[3]\n"
            "fmla v24.4s, v8.4s, v1.s[3]\n"
            "fmla v25.4s, v9.4s, v1.s[3]\n"
            "b 6f\n"

            // odd tail
            "5:\n"
            "ld1 {v0.4s}, [%[a_ptr]], 16\n"
            "fmla v10.4s, v2.4s, v0.s[0]\n"
            "fmla v11.4s, v3.4s, v0.s[0]\n"
            "fmla v12.4s, v4.4s, v0.s[0]\n"
            "fmla v13.4s, v5.4s, v0.s[0]\n"
            "fmla v14.4s, v2.4s, v0.s[1]\n"
            "fmla v15.4s, v3.4s, v0.s[1]\n"
            "fmla v16.4s, v4.4s, v0.s[1]\n"
            "fmla v17.4s, v5.4s, v0.s[1]\n"
            "fmla v18.4s, v2.4s, v0.s[2]\n"
            "fmla v19.4s, v3.4s, v0.s[2]\n"
            "fmla v20.4s, v4.4s, v0.s[2]\n"
            "fmla v21.4s, v5.4s, v0.s[2]\n"
            "fmla v22.4s, v2.4s, v0.s[3]\n"
            "fmla v23.4s, v3.4s, v0.s[3]\n"
            "fmla v24.4s, v4.4s, v0.s[3]\n"
            "fmla v25.4s, v5.4s, v0.s[3]\n"

            "6:\n" STORE_C
            : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
              [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
              [m_remain] "+r"(m_remain), [outptr] "+r"(outptr)
            :
            : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
              "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
              "v20", "v21", "v22", "v23", "v24", "v25", "x1", "x2", "x3", "x9",
              "x10", "cc", "memory");

#undef LOAD_LINE
#undef LOAD_C
#undef STORE_LINE
#undef STORE_C
}
// Overview of register layout: | |||||
// | |||||
// A 2x4 cell of Rhs is stored in 32bit in v2 - v3 | |||||
// A 4x2 cell of Lhs is stored in 32bit in v0 - v1 | |||||
// A 4x4 block of accumulators is stored in 32bit in v4-v6 | |||||
// | |||||
// +--------+ | |||||
// | v2[0-3]| | |||||
// Rhs +--------+ | |||||
// | v3[0-3]| | |||||
// +--------+ | |||||
// | |||||
// | | | |||||
// | |||||
// Lhs | | | |||||
// | |||||
// +--+--+ - - - - +--------+ | |||||
// |v0|v1| | v4[0-3]| | |||||
// |v0|v1| | v5[0-3]| | |||||
// |v0|v1| | v6[0-3]| | |||||
// |v0|v1| | v7[0-3]| | |||||
// +--+--+ - - - - +--------+ | |||||
// | |||||
// Accumulator | |||||
//! Compute an (up to 4) x (up to 4) block of C from packed A/B panels.
//! \param m_remain  number of valid rows (<= 4); counted down in x10
//! \param n_remain  number of valid columns (<= 4); partial rows are
//!                  loaded/stored lane by lane in the macros below
//! Other parameters match kern_4x16.
void kern_4x4(const float* packA, const float* packB, int K, float* output,
              int LDC, bool is_first_k, int m_remain, int n_remain) {
    const float* a_ptr = packA;
    const float* b_ptr = packB;
    int oddk = (K & 1);
    K = ((K + 1) / 2) - 1;
    LDC = LDC * sizeof(float);
    //! pinned to x0: the macros name row pointers literally as x0..x3
    register float* outptr asm("x0") = output;

// clang-format off
//! Load one C row: a full 4-wide ld1 when n_remain >= 4, otherwise up to
//! three single-lane loads; x10 counts the remaining valid rows.
#define LOAD_LINE(v0, n)                            \
    "cmp x10, #0\n"                                 \
    "beq 102f\n"                                    \
    "cmp %w[n_remain], #4\n"                        \
    "blt 100" n "f\n"                               \
    "ld1 {v" v0 ".4s}, [x" n "], 16\n"              \
    "b 101" n "f\n"                                 \
    "100" n ":\n"                                   \
    "cmp %w[n_remain], #0\n"                        \
    "beq 101" n "f\n"                               \
    "ld1 {v" v0 ".s}[0], [x" n "], 4\n"             \
    "cmp %w[n_remain], #1\n"                        \
    "beq 101" n "f\n"                               \
    "ld1 {v" v0 ".s}[1], [x" n "], 4\n"             \
    "cmp %w[n_remain], #2\n"                        \
    "beq 101" n "f\n"                               \
    "ld1 {v" v0 ".s}[2], [x" n "], 4\n"             \
    "101" n ":\n"                                   \
    "subs x10, x10, #1\n"

#define LOAD_C                                      \
    "mov x10, %x[m_remain]\n"                       \
    LOAD_LINE("4", "0")                             \
    LOAD_LINE("5", "1")                             \
    LOAD_LINE("6", "2")                             \
    LOAD_LINE("7", "3")                             \
    "102:\n"

//! Mirror of LOAD_LINE for the store path.
#define STORE_LINE(v0, n)                           \
    "cmp x10, #0 \n"                                \
    "beq 105f\n"                                    \
    "cmp %w[n_remain], #4\n"                        \
    "blt 103" n "f\n"                               \
    "st1 {v" v0 ".4s}, [x" n " ], 16\n"             \
    "b 104" n "f\n"                                 \
    "103" n ":\n"                                   \
    "cmp %w[n_remain], #0\n"                        \
    "beq 104" n "f\n"                               \
    "st1 {v" v0 ".s}[0], [x" n "], 4\n"             \
    "cmp %w[n_remain], #1\n"                        \
    "beq 104" n "f\n"                               \
    "st1 {v" v0 ".s}[1], [x" n "], 4\n"             \
    "cmp %w[n_remain], #2\n"                        \
    "beq 104" n "f\n"                               \
    "st1 {v" v0 ".s}[2], [x" n "], 4\n"             \
    "104" n ":\n"                                   \
    "subs x10, x10, #1\n"

#define STORE_C                                     \
    "mov x10, %x[m_remain]\n"                       \
    STORE_LINE("4", "0")                            \
    STORE_LINE("5", "1")                            \
    STORE_LINE("6", "2")                            \
    STORE_LINE("7", "3")                            \
    "105:\n"
// clang-format on

    asm volatile(
            // compute row pointers, then load C or zero the accumulators
            "add x1, x0, %x[LDC]\n"
            "add x2, x1, %x[LDC]\n"
            "add x3, x2, %x[LDC]\n"
            "cmp %w[is_first_k], #1\n"
            "beq 1f\n" LOAD_C
            "b 2f\n"
            "1:\n"
            "eor v4.16b, v4.16b, v4.16b\n"
            "eor v5.16b, v5.16b, v5.16b\n"
            "eor v6.16b, v6.16b, v6.16b\n"
            "eor v7.16b, v7.16b, v7.16b\n"

            // main loop: two K-steps per iteration (v0/v2 then v1/v3)
            "2: \n"
            "ld1 {v0.4s}, [%[a_ptr]], 16\n"
            "ld1 {v2.4s}, [%[b_ptr]], 16\n"
            "cmp %w[K], #0\n"
            "beq 4f\n"

            "3:\n"
            "ld1 {v1.4s}, [%[a_ptr]], 16\n"
            "ld1 {v3.4s}, [%[b_ptr]], 16\n"
            "fmla v4.4s, v2.4s, v0.s[0]\n"
            "fmla v5.4s, v2.4s, v0.s[1]\n"
            "fmla v6.4s, v2.4s, v0.s[2]\n"
            "fmla v7.4s, v2.4s, v0.s[3]\n"
            "ld1 {v0.4s}, [%[a_ptr]], 16\n"
            "ld1 {v2.4s}, [%[b_ptr]], 16\n"
            "fmla v4.4s, v3.4s, v1.s[0]\n"
            "fmla v5.4s, v3.4s, v1.s[1]\n"
            "fmla v6.4s, v3.4s, v1.s[2]\n"
            "fmla v7.4s, v3.4s, v1.s[3]\n"
            "subs %w[K], %w[K], #1\n"
            "bne 3b\n"

            "4:\n"
            "cmp %w[oddk], #1\n"
            "beq 5f\n"

            // Even tail
            "ld1 {v1.4s}, [%[a_ptr]], 16\n"
            "ld1 {v3.4s}, [%[b_ptr]], 16\n"
            "fmla v4.4s, v2.4s, v0.s[0]\n"
            "fmla v5.4s, v2.4s, v0.s[1]\n"
            "fmla v6.4s, v2.4s, v0.s[2]\n"
            "fmla v7.4s, v2.4s, v0.s[3]\n"
            "fmla v4.4s, v3.4s, v1.s[0]\n"
            "fmla v5.4s, v3.4s, v1.s[1]\n"
            "fmla v6.4s, v3.4s, v1.s[2]\n"
            "fmla v7.4s, v3.4s, v1.s[3]\n"
            "b 6f\n"

            // odd tail
            "5:\n"
            "fmla v4.4s, v2.4s, v0.s[0]\n"
            "fmla v5.4s, v2.4s, v0.s[1]\n"
            "fmla v6.4s, v2.4s, v0.s[2]\n"
            "fmla v7.4s, v2.4s, v0.s[3]\n"

            "6:\n" STORE_C
            : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
              [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k),
              [oddk] "+r"(oddk), [m_remain] "+r"(m_remain),
              [n_remain] "+r"(n_remain), [outptr] "+r"(outptr)
            :
            : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "x1",
              "x2", "x3", "x10", "cc", "memory");

#undef LOAD_LINE
#undef LOAD_C
#undef STORE_LINE
#undef STORE_C
}
void sgemm_4x16_pack_A_n(float * outptr, const float * inptr, int ldin, int y0, | |||||
int ymax, int k0, int kmax) { | |||||
float zerobuff[4]; | |||||
std::memset(zerobuff, 0, sizeof(float) * 4); | |||||
constexpr int PACK_SIZE = 4*4; | |||||
int y = y0; | |||||
for (; y + 3 < ymax; y += 4) { | |||||
// printf("main loop pack_a_n %p \n",outptr); | |||||
const float* inptr0 = inptr + y * ldin + k0; | |||||
const float* inptr1 = inptr0 + ldin; | |||||
const float* inptr2 = inptr1 + ldin; | |||||
const float* inptr3 = inptr2 + ldin; | |||||
prefetch_2x(inptr0); | |||||
prefetch_2x(inptr1); | |||||
prefetch_2x(inptr2); | |||||
prefetch_2x(inptr3); | |||||
int K = (kmax - k0); | |||||
for (; K > 3; K -= 4) { | |||||
transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr); | |||||
outptr += PACK_SIZE; | |||||
} | |||||
interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 1, K); | |||||
} | |||||
for (; y < ymax; y += 4) { | |||||
const float* inptr0 = inptr + y * ldin + k0; | |||||
const float* inptr1 = inptr0 + ldin; | |||||
const float* inptr2 = inptr1 + ldin; | |||||
const float* inptr3 = inptr2 + ldin; | |||||
prefetch_2x(inptr0); | |||||
prefetch_2x(inptr1); | |||||
prefetch_2x(inptr2); | |||||
prefetch_2x(inptr3); | |||||
int K = (kmax - k0); | |||||
for (; K > 3; K -= 4) { | |||||
if ((y + 3) >= ymax) { | |||||
switch ((y + 3) - ymax) { | |||||
/* Everything falls through in here */ | |||||
case 2: | |||||
inptr1 = zerobuff; | |||||
case 1: | |||||
inptr2 = zerobuff; | |||||
case 0: | |||||
inptr3 = zerobuff; | |||||
break; | |||||
default: | |||||
megdnn_assert(0); | |||||
} | |||||
} | |||||
transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr); | |||||
outptr += PACK_SIZE; | |||||
} | |||||
if (K > 0) { | |||||
if (y + 3 >= ymax) { | |||||
switch (y + 3 - ymax) { | |||||
case 2: | |||||
inptr1 = zerobuff; | |||||
case 1: | |||||
inptr2 = zerobuff; | |||||
case 0: | |||||
inptr3 = zerobuff; | |||||
break; | |||||
default: | |||||
megdnn_assert(0); | |||||
} | |||||
} | |||||
interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 1, K); | |||||
} | |||||
} | |||||
} | |||||
void sgemm_4x16_pack_A_t(float* out, const float* in, int ldin, int x0, | |||||
int xmax, int k0, int kmax) { | |||||
int ksize = kmax - k0; | |||||
int ksize4 = (ksize << 2); | |||||
float* outptr_base = out; | |||||
int k = k0; | |||||
for (; k + 3 < kmax; k += 4) { | |||||
const float* inptr = in + k * ldin + x0; | |||||
const float* inptr1 = inptr + ldin; | |||||
const float* inptr2 = inptr1 + ldin; | |||||
const float* inptr3 = inptr2 + ldin; | |||||
prefetch_3x(inptr); | |||||
prefetch_3x(inptr1); | |||||
prefetch_3x(inptr2); | |||||
prefetch_3x(inptr3); | |||||
int x = x0; | |||||
auto outptr = outptr_base; | |||||
for (; x + 4 <= xmax; x += 4) { | |||||
auto outptr_interleave = outptr; | |||||
interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, | |||||
outptr_interleave); | |||||
outptr += ksize4; | |||||
} | |||||
if (x < xmax) { | |||||
interleave_4(inptr, inptr1, inptr2, inptr3, outptr, 4, xmax - x); | |||||
} | |||||
outptr_base += 4 * 4; | |||||
} | |||||
for (; k < kmax; k++) { | |||||
const float* inptr = in + k * ldin + x0; | |||||
prefetch_3x(inptr); | |||||
int x = x0; | |||||
auto outptr = outptr_base; | |||||
for (; x + 4 <= xmax; x += 4) { | |||||
auto outptr_interleave = outptr; | |||||
interleave_1x4_1_s(inptr, outptr_interleave); | |||||
outptr += ksize4; | |||||
} | |||||
if (x < xmax) { | |||||
interleave_1(inptr, outptr, 4, xmax - x); | |||||
} | |||||
outptr_base += 4; | |||||
} | |||||
} | |||||
void sgemm_4x16_pack_B_n(float* out, const float* in, int ldin, | |||||
int x0, int xmax, int k0, int kmax) { | |||||
int ksize = kmax - k0; | |||||
int ksize16 = ksize * 16; | |||||
int ksize4 = (ksize << 2); | |||||
float* outptr_base = out; | |||||
float* outptr_base4 = outptr_base + (xmax - x0) / 16 * ksize16; | |||||
int k = k0; | |||||
for (; k + 3 < kmax; k += 4) { | |||||
const float* inptr = in + k * ldin + x0; | |||||
const float* inptr1 = inptr + ldin; | |||||
const float* inptr2 = inptr1 + ldin; | |||||
const float* inptr3 = inptr2 + ldin; | |||||
prefetch_3x(inptr); | |||||
prefetch_3x(inptr1); | |||||
prefetch_3x(inptr2); | |||||
prefetch_3x(inptr3); | |||||
int x = x0; | |||||
auto outptr = outptr_base; | |||||
for (; x + 16 <= xmax; x += 16) { | |||||
auto outptr_interleave = outptr; | |||||
interleave_4x16_1_s(inptr, inptr1, inptr2, inptr3, | |||||
outptr_interleave); | |||||
outptr += ksize16; | |||||
} | |||||
outptr = outptr_base4; | |||||
for (; x + 4 <= xmax; x += 4) { | |||||
auto outptr_interleave = outptr; | |||||
interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, | |||||
outptr_interleave); | |||||
outptr += ksize4; | |||||
} | |||||
if (x < xmax) { | |||||
interleave_4(inptr, inptr1, inptr2, inptr3, outptr, 4, xmax - x); | |||||
} | |||||
outptr_base += 16 * 4; | |||||
outptr_base4 += 4 * 4; | |||||
} | |||||
for (; k < kmax; k++) { | |||||
const float* inptr = in + k * ldin + x0; | |||||
prefetch_3x(inptr); | |||||
int x = x0; | |||||
auto outptr = outptr_base; | |||||
for (; x + 16 <= xmax; x += 16) { | |||||
auto outptr_interleave = outptr; | |||||
interleave_1x16_1_s(inptr, outptr_interleave); | |||||
outptr += ksize16; | |||||
} | |||||
outptr = outptr_base4; | |||||
for (; x + 4 <= xmax; x += 4) { | |||||
auto outptr_interleave = outptr; | |||||
interleave_1x4_1_s(inptr, outptr_interleave); | |||||
outptr += ksize4; | |||||
} | |||||
if (x < xmax) { | |||||
interleave_1(inptr, outptr, 4, xmax - x); | |||||
} | |||||
outptr_base += 16; | |||||
outptr_base4 += 4; | |||||
} | |||||
} | |||||
//! Pack a transposed B matrix into 16-wide panels: sixteen source rows are
//! transposed into each panel, 4 rows at a time; a ragged row tail is
//! substituted by a zero buffer.
void sgemm_4x16_pack_B_t(float* out, const float* in, int ldin,
                         int y0, int ymax, int k0, int kmax) {
    float* outptr = out;
    const float* inptr = in;
    float zerobuff[4];
    std::memset(zerobuff, 0, sizeof(float) * 4);
    int K16 = 16 * (kmax - k0);  // elements written per full 16-row panel
    int y = y0;
    // full 16-row panels, transposed four source rows at a time
    for (; y + 16 <= ymax; y += 16) {
        int yi = y;
        for (; yi < y + 16; yi += 4) {
            const float* inptr0 = inptr + yi * ldin + k0;
            const float* inptr1 = inptr0 + ldin;
            const float* inptr2 = inptr1 + ldin;
            const float* inptr3 = inptr2 + ldin;
            // each group of 4 rows starts at its column offset in the panel
            float* outptr_inner = outptr + yi - y;
            prefetch_2x(inptr0);
            prefetch_2x(inptr1);
            prefetch_2x(inptr2);
            prefetch_2x(inptr3);
            int x = (kmax - k0);
            for (; x > 3; x -= 4) {
                // stride 64: 4 packed rows of the 16-wide panel
                transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr_inner,
                                  64);
                outptr_inner += 64;
            }
            for (; x > 0; x--) {
                *outptr_inner++ = *inptr0++;
                *outptr_inner++ = *inptr1++;
                *outptr_inner++ = *inptr2++;
                *outptr_inner++ = *inptr3++;
                // skip the remaining 12 columns of this 16-wide row
                outptr_inner += 12;
            }
        }
        outptr += K16;
    }
    // ragged tail: fewer than 16 rows left, packed 4 at a time
    for (; y < ymax; y += 4) {
        const float* inptr0 = inptr + y * ldin + k0;
        const float* inptr1 = inptr0 + ldin;
        const float* inptr2 = inptr1 + ldin;
        const float* inptr3 = inptr2 + ldin;
        prefetch_2x(inptr0);
        prefetch_2x(inptr1);
        prefetch_2x(inptr2);
        prefetch_2x(inptr3);
        /* Cope with ragged cases by copying from a buffer of zeroes instead
         */
        int x = (kmax - k0);
        for (; x > 3; x -= 4) {
            // re-applied every iteration: the helpers advance the pointers
            if ((y + 3) >= ymax) {
                switch ((y + 3) - ymax) {
                    /* Everything falls through in here */
                    case 2:
                        inptr1 = zerobuff;
                    case 1:
                        inptr2 = zerobuff;
                    case 0:
                        inptr3 = zerobuff;
                        break;
                    default:
                        megdnn_assert(0);
                }
            }
            transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr);
            outptr += 16;
        }
        if (x > 0) {
            if ((y + 3) >= ymax) {
                switch ((y + 3) - ymax) {
                    /* Everything falls through in here */
                    case 2:
                        inptr1 = zerobuff;
                    case 1:
                        inptr2 = zerobuff;
                    case 0:
                        inptr3 = zerobuff;
                        break;
                    default:
                        megdnn_assert(0);
                }
            }
            interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 1, x);
        }
    }
}
} // matmul_general_4x16 | |||||
} // aarch64 | |||||
} // megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,166 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/matrix_mul/fp32/strategy.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "src/aarch64/matrix_mul/fp32/strategy.h" | |||||
#include "src/aarch64/matrix_mul/fp32/kernel_general_4x16.h" | |||||
#include "src/aarch64/matrix_mul/fp32/kernel_general_8x12.h" | |||||
#include "src/common/utils.h" | |||||
using namespace megdnn; | |||||
using namespace aarch64; | |||||
using namespace aarch64::matmul; | |||||
MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_4x16); | |||||
void sgemm_4x16::pack_A(float* out, const float* in, int ldin, int y0, | |||||
int ymax, int k0, int kmax, bool transpose_A) const { | |||||
if (transpose_A) { | |||||
matmul_general_4x16::sgemm_4x16_pack_A_t(out, in, ldin, y0, ymax, k0, kmax); | |||||
} else { | |||||
matmul_general_4x16::sgemm_4x16_pack_A_n(out, in, ldin, y0, ymax, k0, kmax); | |||||
} | |||||
} | |||||
void sgemm_4x16::pack_B(float* out, const float* in, int ldin, int x0, int xmax, | |||||
int k0, int kmax, bool transpose_B) const { | |||||
if (transpose_B) { | |||||
matmul_general_4x16::sgemm_4x16_pack_B_t(out, in, ldin, x0, xmax, k0, kmax); | |||||
} else { | |||||
matmul_general_4x16::sgemm_4x16_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); | |||||
} | |||||
} | |||||
//! Run the packed 4x16 fp32 GEMM: M is tiled in strips of 4 rows, N in
//! panels of 16 columns, with a 4x4 kernel covering the N remainder.
//! is_first_k is forwarded to the kernels (presumably init-vs-accumulate
//! into C; confirm in kernel implementation).
void sgemm_4x16::kern(const float* packA, const float* packB,
                      size_t M, size_t N, size_t K, float* C, size_t LDC,
                      bool is_first_k, const float*, float*) const {
    // This strategy only supports fp32 for A, B and C.
    megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
                  A_dtype.enumv() == C_dtype.enumv() &&
                  A_dtype.enumv() == DTypeEnum::Float32);
    MEGDNN_MARK_USED_VAR(A_dtype);
    MEGDNN_MARK_USED_VAR(B_dtype);
    MEGDNN_MARK_USED_VAR(C_dtype);
    constexpr size_t A_INTERLEAVE = 4;
    constexpr size_t B_INTERLEAVE = 16;
    // Size in floats of one packed-B panel of width 16 resp. 4.
    const int K16 = K * 16;
    const int K4 = K * 4;
    size_t m = 0;
    for (; m < M; m += A_INTERLEAVE) {
        float* output = C + (m * LDC);
        size_t n = 0;
        const float* cur_packB = packB;
        // Full-width 16-column panels.
        for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
            matmul_general_4x16::kern_4x16(packA, cur_packB, K, output, LDC, is_first_k,
                                           std::min<size_t>(M - m, 4));
            output += B_INTERLEAVE;
            cur_packB += K16;
        }
        // Remainder columns handled 4 at a time (kernel masks N - n < 4).
        for (; n < N; n += 4) {
            matmul_general_4x16::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k,
                                          std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4));
            output += 4;
            cur_packB += K4;
        }
        // Advance to the next packed 4-row strip of A.
        packA += K4;
    }
}
MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_8x12); | |||||
void sgemm_8x12::pack_A(float* out, const float* in, int ldin, int y0, | |||||
int ymax, int k0, int kmax, bool transpose_A) const { | |||||
if (transpose_A) { | |||||
matmul_general_8x12::sgemm_8x12_pack_A_t(out, in, ldin, y0, ymax, k0, | |||||
kmax); | |||||
} else { | |||||
matmul_general_8x12::sgemm_8x12_pack_A_n(out, in, ldin, y0, ymax, k0, | |||||
kmax); | |||||
} | |||||
} | |||||
void sgemm_8x12::pack_B(float* out, const float* in, int ldin, int x0, int xmax, | |||||
int k0, int kmax, bool transpose_B) const { | |||||
if (transpose_B) { | |||||
matmul_general_8x12::sgemm_8x12_pack_B_t(out, in, ldin, x0, xmax, k0, | |||||
kmax); | |||||
} else { | |||||
matmul_general_8x12::sgemm_8x12_pack_B_n(out, in, ldin, x0, xmax, k0, | |||||
kmax); | |||||
} | |||||
} | |||||
//! Run the packed 8x12 fp32 GEMM: M is tiled first in strips of 8 rows
//! (8x12 / 8x4 kernels), then the M remainder in strips of 4 rows
//! (4x12 / 4x4 kernels). N is tiled in panels of 12 with a 4-wide tail.
void sgemm_8x12::kern(const float* packA, const float* packB,
                      size_t M, size_t N, size_t K, float* C, size_t LDC,
                      bool is_first_k, const float*, float*) const {
    // This strategy only supports fp32 for A, B and C.
    megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
                  A_dtype.enumv() == C_dtype.enumv() &&
                  A_dtype.enumv() == DTypeEnum::Float32);
    MEGDNN_MARK_USED_VAR(A_dtype);
    MEGDNN_MARK_USED_VAR(B_dtype);
    MEGDNN_MARK_USED_VAR(C_dtype);
    constexpr size_t A_INTERLEAVE = 8;
    constexpr size_t A_INTERLEAVE4 = 4;
    constexpr size_t B_INTERLEAVE = 12;
    // Sizes in floats of one packed panel of width 12 / 8 / 4.
    const int K12 = K * 12;
    const int K8 = K * 8;
    const int K4 = K * 4;
    size_t m = 0;
    // Full 8-row strips of A.
    for (; m + A_INTERLEAVE <= M; m += A_INTERLEAVE) {
        float* output = C + (m * LDC);
        size_t n = 0;
        const float* cur_packB = packB;
        for (; n + B_INTERLEAVE <= N; n += B_INTERLEAVE) {
            matmul_general_8x12::kern_8x12(packA, cur_packB, K, output, LDC,
                                           is_first_k);
            output += B_INTERLEAVE;
            cur_packB += K12;
        }
        // N remainder, 4 columns at a time (kernel masks N - n < 4).
        for (; n < N; n += 4) {
            matmul_general_8x12::kern_8x4(packA, cur_packB, K, output, LDC,
                                          is_first_k,
                                          std::min<size_t>(N - n, 4));
            output += 4;
            cur_packB += K4;
        }
        packA += K8;
    }
    // M remainder: 4-row strips (kernels mask M - m < 4).
    for (; m < M; m += A_INTERLEAVE4) {
        float* output = C + (m * LDC);
        size_t n = 0;
        const float* cur_packB = packB;
        for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
            matmul_general_8x12::kern_4x12(packA, cur_packB, K, output, LDC,
                                           is_first_k,
                                           std::min<size_t>(M - m, 4));
            output += B_INTERLEAVE;
            cur_packB += K12;
        }
        for (; n < N; n += 4) {
            matmul_general_8x12::kern_4x4(
                    packA, cur_packB, K, output, LDC, is_first_k,
                    std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4));
            output += 4;
            cur_packB += K4;
        }
        packA += K4;
    }
}
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,30 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/matrix_mul/fp32/strategy.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#pragma once | |||||
#include "src/fallback/matrix_mul/gemm_common.h" | |||||
namespace megdnn {
namespace aarch64 {
namespace matmul {

// fp32 packed GEMM strategy with an 8-row x 12-column register tile.
MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, true,
                         sgemm_8x12);

// fp32 packed GEMM strategy with a 4-row x 16-column register tile.
MEGDNN_REG_GEMM_STRATEGY(float, float, float, 4, 16, 1, false, true,
                         sgemm_4x16);

// fp32 no-pack strategy (MK4 layout), 4x16 tile.
MEGDNN_REG_GEMM_STRATEGY_NOPACK(float, float, float, 4, 16, 1, false, true,
                                sgemm_nopack_4x16);

}  // namespace matmul
}  // namespace aarch64
}  // namespace megdnn
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,570 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/matrix_mul/fp32/strategy_mk4_4x16.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "src/aarch64/matrix_mul/asm/common.h" | |||||
#include "src/aarch64/matrix_mul/fp32/strategy.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/common/utils.h" | |||||
using namespace megdnn; | |||||
using namespace aarch64; | |||||
using namespace aarch64::matmul; | |||||
namespace { | |||||
// Overview of register layout:
//
// A 4x4 block of A is stored in register v4-v7
// A 4x4 block of B is stored in register v0-v3
// A 4x4 block of accumulators is stored in v16-v19
//
// A +--------+
//   | v4[0-3]|
//   | v5[0-3]|
//   | v6[0-3]|
//   | v7[0-3]|
//   +--------+
// B
//   +--------+ - - - - -+--------+
//   | v0[0-3]|          |v16[0-3]|
//   | v1[0-3]|          |v17[0-3]|
//   | v2[0-3]|          |v18[0-3]|
//   | v3[0-3]|          |v19[0-3]|
//   +--------+ - - - - -+--------+
//                       Accumulator
//! Compute one 4x4 output tile in the MK4 layout. K must be a multiple of 4
//! (the loop consumes 4 depth steps per iteration). LDB is the stride in
//! floats between consecutive 4-row groups of B.
void kern_4x4(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K,
              float* output) {
    //! As each load 16 number from B, but the pos add 12 * 4, so we minus 12
    //! here.
    LDB = (LDB - 12) * sizeof(float);
    asm volatile(
            // Prologue: first depth-group of A/B and the first two fma steps.
            "subs %w[K], %w[K], #4\n"
            "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[a_ptr]], 64\n"
            "ld1 {v0.4s}, [%[b_ptr]], 16\n"
            "ld1 {v1.4s}, [%[b_ptr]], 16\n"
            "ld1 {v2.4s}, [%[b_ptr]], 16\n"
            "ld1 {v3.4s}, [%[b_ptr]], %x[LDB]\n"
            "fmul v16.4s, v4.4s, v0.s[0]\n"
            "fmul v17.4s, v4.4s, v1.s[0]\n"
            "fmul v18.4s, v4.4s, v2.s[0]\n"
            "fmul v19.4s, v4.4s, v3.s[0]\n"
            "fmla v16.4s, v5.4s, v0.s[1]\n"
            "fmla v17.4s, v5.4s, v1.s[1]\n"
            "fmla v18.4s, v5.4s, v2.s[1]\n"
            "fmla v19.4s, v5.4s, v3.s[1]\n"
            "beq 2f\n"

            // Main loop: finish the previous group while loading the next.
            "1:\n"
            "ld1 {v4.4s, v5.4s}, [%[a_ptr]], 32\n"
            "fmla v16.4s, v6.4s, v0.s[2]\n"
            "fmla v17.4s, v6.4s, v1.s[2]\n"
            "fmla v18.4s, v6.4s, v2.s[2]\n"
            "fmla v19.4s, v6.4s, v3.s[2]\n"
            "fmla v16.4s, v7.4s, v0.s[3]\n"
            "fmla v17.4s, v7.4s, v1.s[3]\n"
            "ld1 {v0.4s}, [%[b_ptr]], 16\n"
            "fmla v18.4s, v7.4s, v2.s[3]\n"
            "ld1 {v1.4s}, [%[b_ptr]], 16\n"
            "fmla v19.4s, v7.4s, v3.s[3]\n"
            "ld1 {v2.4s}, [%[b_ptr]], 16\n"
            "fmla v16.4s, v4.4s, v0.s[0]\n"
            "ld1 {v3.4s}, [%[b_ptr]], %x[LDB]\n"
            "fmla v17.4s, v4.4s, v1.s[0]\n"
            "fmla v18.4s, v4.4s, v2.s[0]\n"
            "fmla v19.4s, v4.4s, v3.s[0]\n"
            "ld1 {v6.4s, v7.4s}, [%[a_ptr]], 32\n"
            "fmla v16.4s, v5.4s, v0.s[1]\n"
            "fmla v17.4s, v5.4s, v1.s[1]\n"
            "fmla v18.4s, v5.4s, v2.s[1]\n"
            "fmla v19.4s, v5.4s, v3.s[1]\n"
            "subs %w[K], %w[K], #4\n"
            "bne 1b\n"

            // Epilogue: drain the last group and store the 4x4 tile.
            "2:\n"
            "fmla v16.4s, v6.4s, v0.s[2]\n"
            "fmla v17.4s, v6.4s, v1.s[2]\n"
            "fmla v18.4s, v6.4s, v2.s[2]\n"
            "fmla v19.4s, v6.4s, v3.s[2]\n"
            "fmla v16.4s, v7.4s, v0.s[3]\n"
            "fmla v17.4s, v7.4s, v1.s[3]\n"
            "fmla v18.4s, v7.4s, v2.s[3]\n"
            "fmla v19.4s, v7.4s, v3.s[3]\n"
            "st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[output]], 64\n"
            : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
              [output] "+r"(output), [LDB] "+r"(LDB)
            :
            //! v16-v19 hold the accumulators and are written by the asm, so
            //! they must appear in the clobber list as well; omitting them
            //! lets the compiler keep live values there across the asm.
            : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17",
              "v18", "v19", "cc", "memory");
}
// Overview of register layout:
//
// A 4x4 block of A is stored in register v4-v7
// A 4x4 block of B is stored in register v0-v3, slipping until 8x4
// A 8x4 block of accumulators store in v16-v23.
//
// A +--------+
//   | v4[0-3]|
//   | v5[0-3]|
//   | v6[0-3]|
//   | v7[0-3]|
//   +--------+
// B
//   +--------+ - - - - -+--------+
//   | v0[0-3]|          |v16[0-3]|
//   | v1[0-3]|          |v17[0-3]|
//   | v2[0-3]|          |v18[0-3]|
//   | v3[0-3]|          |v19[0-3]|
//   +--------+ - - - - -+--------+
//   | v0[0-3]|          |v20[0-3]|
//   | v1[0-3]|          |v21[0-3]|
//   | v2[0-3]|          |v22[0-3]|
//   | v3[0-3]|          |v23[0-3]|
//   +--------+ - - - - -+--------+
//                       Accumulator
//! Compute a 4x8 output tile in the MK4 layout. K must be a multiple of 4.
//! B columns 4-7 are staged through v24-v27.
void kern_4x8(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K,
              float* output) {
    //! As each load 32 number from B, but the pos add 24 * 4, so we minus 24
    //! here.
    LDB = (LDB - 24) * sizeof(float);
    asm volatile(
            // Prologue: first depth-group, accumulators initialised via fmul.
            "subs %w[K], %w[K], #4\n"
            "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[a_ptr]], 64\n"
            "ld1 {v0.4s}, [%[b_ptr]], 16\n"
            "fmul v16.4s, v4.4s, v0.s[0]\n"
            "ld1 {v1.4s}, [%[b_ptr]], 16\n"
            "fmla v16.4s, v5.4s, v0.s[1]\n"
            "fmul v17.4s, v4.4s, v1.s[0]\n"
            "ld1 {v2.4s, v3.4s}, [%[b_ptr]], 32\n"
            "fmla v17.4s, v5.4s, v1.s[1]\n"
            "fmla v16.4s, v6.4s, v0.s[2]\n"
            "fmla v17.4s, v6.4s, v1.s[2]\n"
            "fmul v18.4s, v4.4s, v2.s[0]\n"
            "fmla v16.4s, v7.4s, v0.s[3]\n"
            "fmla v18.4s, v5.4s, v2.s[1]\n"
            "fmla v17.4s, v7.4s, v1.s[3]\n"
            "fmul v19.4s, v4.4s, v3.s[0]\n"
            "ld1 {v24.4s, v25.4s}, [%[b_ptr]], 32\n"
            "fmla v18.4s, v7.4s, v2.s[3]\n"
            "fmla v19.4s, v5.4s, v3.s[1]\n"
            "fmul v20.4s, v4.4s, v24.s[0]\n"
            "fmla v19.4s, v6.4s, v3.s[2]\n"
            "ld1 {v26.4s, v27.4s}, [%[b_ptr]], %x[LDB]\n"
            "fmla v18.4s, v6.4s, v2.s[2]\n"
            "fmla v19.4s, v7.4s, v3.s[3]\n"
            "fmul v21.4s, v4.4s, v25.s[0]\n"
            "fmla v20.4s, v5.4s, v24.s[1]\n"
            "fmla v21.4s, v5.4s, v25.s[1]\n"
            "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b_ptr]], 64\n"
            "fmla v20.4s, v6.4s, v24.s[2]\n"
            "fmul v22.4s, v4.4s, v26.s[0]\n"
            "fmla v21.4s, v6.4s, v25.s[2]\n"
            "fmla v22.4s, v5.4s, v26.s[1]\n"
            "fmla v21.4s, v7.4s, v25.s[3]\n"
            "fmul v23.4s, v4.4s, v27.s[0]\n"
            "fmla v20.4s, v7.4s, v24.s[3]\n"
            "fmla v22.4s, v6.4s, v26.s[2]\n"
            "fmla v23.4s, v5.4s, v27.s[1]\n"
            "beq 2f\n"

            // Main loop: drain previous group while streaming the next.
            "1:\n"
            "ld1 {v4.4s, v5.4s}, [%[a_ptr]], 32\n"
            "fmla v22.4s, v7.4s, v26.s[3]\n"
            "fmla v23.4s, v6.4s, v27.s[2]\n"
            "fmla v16.4s, v4.4s, v0.s[0]\n"
            "fmla v17.4s, v4.4s, v1.s[0]\n"
            "fmla v23.4s, v7.4s, v27.s[3]\n"
            "ld1 {v6.4s, v7.4s}, [%[a_ptr]], 32\n"
            "fmla v16.4s, v5.4s, v0.s[1]\n"
            "fmla v17.4s, v5.4s, v1.s[1]\n"
            "fmla v16.4s, v6.4s, v0.s[2]\n"
            "fmla v17.4s, v6.4s, v1.s[2]\n"
            "fmla v18.4s, v4.4s, v2.s[0]\n"
            "fmla v16.4s, v7.4s, v0.s[3]\n"
            "fmla v18.4s, v5.4s, v2.s[1]\n"
            "fmla v17.4s, v7.4s, v1.s[3]\n"
            "fmla v19.4s, v4.4s, v3.s[0]\n"
            "ld1 {v24.4s, v25.4s}, [%[b_ptr]], 32\n"
            "fmla v18.4s, v6.4s, v2.s[2]\n"
            "fmla v19.4s, v5.4s, v3.s[1]\n"
            "fmla v20.4s, v4.4s, v24.s[0]\n"
            "fmla v19.4s, v6.4s, v3.s[2]\n"
            "ld1 {v26.4s, v27.4s}, [%[b_ptr]], %x[LDB]\n"
            "fmla v18.4s, v7.4s, v2.s[3]\n"
            "fmla v19.4s, v7.4s, v3.s[3]\n"
            "fmla v21.4s, v4.4s, v25.s[0]\n"
            "fmla v20.4s, v5.4s, v24.s[1]\n"
            "fmla v21.4s, v5.4s, v25.s[1]\n"
            "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b_ptr]], 64\n"
            "fmla v20.4s, v6.4s, v24.s[2]\n"
            "fmla v22.4s, v4.4s, v26.s[0]\n"
            "fmla v20.4s, v7.4s, v24.s[3]\n"
            "fmla v23.4s, v4.4s, v27.s[0]\n"
            "fmla v21.4s, v6.4s, v25.s[2]\n"
            "fmla v22.4s, v5.4s, v26.s[1]\n"
            "fmla v21.4s, v7.4s, v25.s[3]\n"
            "fmla v23.4s, v5.4s, v27.s[1]\n"
            "fmla v22.4s, v6.4s, v26.s[2]\n"
            "subs %w[K], %w[K], #4\n"
            "bne 1b\n"

            // Epilogue: last fmas interleaved with the two 4x4 tile stores.
            "2:\n"
            "fmla v22.4s, v7.4s, v26.s[3]\n"
            "st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[output]], 64\n"
            "fmla v23.4s, v6.4s, v27.s[2]\n"
            "fmla v23.4s, v7.4s, v27.s[3]\n"
            "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%[output]], 64\n"
            : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
              [output] "+r"(output), [LDB] "+r"(LDB)
            :
            : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17",
              "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26",
              "v27", "cc", "memory");
}
// Overview of register layout:
//
// A 4x1 cell of Rhs is stored in 32bit in v4-v7 (v8-v11 for ping pong)
// A 4x1 cell of Lhs is stored in 32bit in v0-v3
// A 16x1 block of accumulators is stored in 32bit in v16-v31.
//
// Rhs +--------+
//     | v4[0-3]|
//     | v5[0-3]|
//     | v6[0-3]|
//     | v7[0-3]|
//     +--------+
// Lhs
//     +--------+ - - - - -+--------+
//     | v0[0-3] |         |v16[0-3]|
//     | v1[0-3] |         |v17[0-3]|
//     | v2[0-3] |         |v18[0-3]|
//     | v3[0-3] |         |v19[0-3]|
//     | v8[0-3] |         |v20[0-3]|
//     | v9[0-3] |         |v21[0-3]|
//     | v10[0-3]|         |v22[0-3]|
//     | v11[0-3]|         |v23[0-3]|
//     +--------+          |v24[0-3]|
//                         |v25[0-3]|
//                         |v26[0-3]|
//                         |v27[0-3]|
//                         |v28[0-3]|
//                         |v29[0-3]|
//                         |v30[0-3]|
//                         |v31[0-3]|
//                         +--------+
//                         Accumulator
//! Compute a 4x16 output tile in the MK4 layout. K must be a multiple of 4.
//! NOTE(review): LDB/K are int here but size_t in kern_4x4/kern_4x8 — the
//! caller passes size_t, so values narrow implicitly; confirm intended.
void kern_4x16(const float* a_ptr, const float* b_ptr, int LDB, int K,
               float* output) {
    //! As each load 64 number from B, but the pos add 56 * 4, so we minus 56
    //! here.
    LDB = (LDB - 56) * sizeof(float);
    asm volatile(
            // v8-v11 (d8-d11) are callee-saved per AAPCS64; spill them.
            "stp d8, d9, [sp, #-16]!\n"
            "stp d10, d11, [sp, #-16]!\n"
            // Prologue: first depth-group across all 16 columns.
            "subs %w[K], %w[K], #4\n"
            "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[a_ptr]], 64\n"
            "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b_ptr]], 64\n"
            "fmul v16.4s, v4.4s, v0.s[0]\n"
            "fmul v17.4s, v4.4s, v1.s[0]\n"
            "fmul v18.4s, v4.4s, v2.s[0]\n"
            "fmul v19.4s, v4.4s, v3.s[0]\n"
            "fmla v16.4s, v5.4s, v0.s[1]\n"
            "fmla v17.4s, v5.4s, v1.s[1]\n"
            "fmla v18.4s, v5.4s, v2.s[1]\n"
            "fmla v19.4s, v5.4s, v3.s[1]\n"
            "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[b_ptr]], 64\n"
            "fmla v16.4s, v6.4s, v0.s[2]\n"
            "fmla v17.4s, v6.4s, v1.s[2]\n"
            "fmla v18.4s, v6.4s, v2.s[2]\n"
            "fmla v19.4s, v6.4s, v3.s[2]\n"
            "fmla v16.4s, v7.4s, v0.s[3]\n"
            "fmla v17.4s, v7.4s, v1.s[3]\n"
            "fmla v18.4s, v7.4s, v2.s[3]\n"
            "fmla v19.4s, v7.4s, v3.s[3]\n"
            "fmul v20.4s, v4.4s, v8.s[0]\n"
            "fmul v21.4s, v4.4s, v9.s[0]\n"
            "fmul v22.4s, v4.4s, v10.s[0]\n"
            "fmul v23.4s, v4.4s, v11.s[0]\n"
            "fmla v20.4s, v5.4s, v8.s[1]\n"
            "fmla v21.4s, v5.4s, v9.s[1]\n"
            "fmla v22.4s, v5.4s, v10.s[1]\n"
            "fmla v23.4s, v5.4s, v11.s[1]\n"
            "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b_ptr]], 64\n"
            "fmla v20.4s, v6.4s, v8.s[2]\n"
            "fmla v21.4s, v6.4s, v9.s[2]\n"
            "fmla v22.4s, v6.4s, v10.s[2]\n"
            "fmla v23.4s, v6.4s, v11.s[2]\n"
            "fmla v20.4s, v7.4s, v8.s[3]\n"
            "fmla v21.4s, v7.4s, v9.s[3]\n"
            "fmla v22.4s, v7.4s, v10.s[3]\n"
            "fmla v23.4s, v7.4s, v11.s[3]\n"
            "fmul v24.4s, v4.4s, v0.s[0]\n"
            "fmul v25.4s, v4.4s, v1.s[0]\n"
            "fmul v26.4s, v4.4s, v2.s[0]\n"
            "fmul v27.4s, v4.4s, v3.s[0]\n"
            "fmla v24.4s, v5.4s, v0.s[1]\n"
            "fmla v25.4s, v5.4s, v1.s[1]\n"
            "fmla v26.4s, v5.4s, v2.s[1]\n"
            "fmla v27.4s, v5.4s, v3.s[1]\n"
            "ld1 {v8.4s, v9.4s}, [%[b_ptr]], 32\n"
            "fmla v24.4s, v6.4s, v0.s[2]\n"
            "fmla v25.4s, v6.4s, v1.s[2]\n"
            "fmla v26.4s, v6.4s, v2.s[2]\n"
            "fmla v27.4s, v6.4s, v3.s[2]\n"
            "ld1 {v10.4s, v11.4s}, [%[b_ptr]], %x[LDB]\n"
            "fmla v24.4s, v7.4s, v0.s[3]\n"
            "fmla v25.4s, v7.4s, v1.s[3]\n"
            "fmla v26.4s, v7.4s, v2.s[3]\n"
            "fmla v27.4s, v7.4s, v3.s[3]\n"
            "fmul v28.4s, v4.4s, v8.s[0]\n"
            "fmul v29.4s, v4.4s, v9.s[0]\n"
            "fmul v30.4s, v4.4s, v10.s[0]\n"
            "fmul v31.4s, v4.4s, v11.s[0]\n"
            "beq 2f\n"

            // Main loop: 4 depth steps per iteration, B ping-ponged through
            // v0-v3 / v8-v11.
            "1:\n"
            "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b_ptr]], 64\n"
            "fmla v28.4s, v5.4s, v8.s[1]\n"
            "fmla v29.4s, v5.4s, v9.s[1]\n"
            "fmla v30.4s, v5.4s, v10.s[1]\n"
            "fmla v31.4s, v5.4s, v11.s[1]\n"
            "ld1 {v4.4s}, [%[a_ptr]], 16\n"
            "fmla v28.4s, v6.4s, v8.s[2]\n"
            "fmla v29.4s, v6.4s, v9.s[2]\n"
            "fmla v30.4s, v6.4s, v10.s[2]\n"
            "fmla v31.4s, v6.4s, v11.s[2]\n"
            "ld1 {v5.4s}, [%[a_ptr]], 16\n"
            "fmla v28.4s, v7.4s, v8.s[3]\n"
            "fmla v29.4s, v7.4s, v9.s[3]\n"
            "fmla v30.4s, v7.4s, v10.s[3]\n"
            "fmla v31.4s, v7.4s, v11.s[3]\n"
            "ld1 {v6.4s}, [%[a_ptr]], 16\n"
            "fmla v16.4s, v4.4s, v0.s[0]\n"
            "fmla v17.4s, v4.4s, v1.s[0]\n"
            "fmla v18.4s, v4.4s, v2.s[0]\n"
            "fmla v19.4s, v4.4s, v3.s[0]\n"
            "ld1 {v7.4s}, [%[a_ptr]], 16\n"
            "fmla v16.4s, v5.4s, v0.s[1]\n"
            "fmla v17.4s, v5.4s, v1.s[1]\n"
            "fmla v18.4s, v5.4s, v2.s[1]\n"
            "fmla v19.4s, v5.4s, v3.s[1]\n"
            "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[b_ptr]], 64\n"
            "fmla v16.4s, v6.4s, v0.s[2]\n"
            "fmla v17.4s, v6.4s, v1.s[2]\n"
            "fmla v18.4s, v6.4s, v2.s[2]\n"
            "fmla v19.4s, v6.4s, v3.s[2]\n"
            "fmla v16.4s, v7.4s, v0.s[3]\n"
            "fmla v17.4s, v7.4s, v1.s[3]\n"
            "fmla v18.4s, v7.4s, v2.s[3]\n"
            "fmla v19.4s, v7.4s, v3.s[3]\n"
            "fmla v20.4s, v4.4s, v8.s[0]\n"
            "fmla v21.4s, v4.4s, v9.s[0]\n"
            "fmla v22.4s, v4.4s, v10.s[0]\n"
            "fmla v23.4s, v4.4s, v11.s[0]\n"
            "fmla v20.4s, v5.4s, v8.s[1]\n"
            "fmla v21.4s, v5.4s, v9.s[1]\n"
            "fmla v22.4s, v5.4s, v10.s[1]\n"
            "fmla v23.4s, v5.4s, v11.s[1]\n"
            "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b_ptr]], 64\n"
            "fmla v20.4s, v6.4s, v8.s[2]\n"
            "fmla v21.4s, v6.4s, v9.s[2]\n"
            "fmla v22.4s, v6.4s, v10.s[2]\n"
            "fmla v23.4s, v6.4s, v11.s[2]\n"
            "fmla v20.4s, v7.4s, v8.s[3]\n"
            "fmla v21.4s, v7.4s, v9.s[3]\n"
            "fmla v22.4s, v7.4s, v10.s[3]\n"
            "fmla v23.4s, v7.4s, v11.s[3]\n"
            "fmla v24.4s, v4.4s, v0.s[0]\n"
            "fmla v25.4s, v4.4s, v1.s[0]\n"
            "fmla v26.4s, v4.4s, v2.s[0]\n"
            "fmla v27.4s, v4.4s, v3.s[0]\n"
            "fmla v24.4s, v5.4s, v0.s[1]\n"
            "fmla v25.4s, v5.4s, v1.s[1]\n"
            "fmla v26.4s, v5.4s, v2.s[1]\n"
            "fmla v27.4s, v5.4s, v3.s[1]\n"
            "ld1 {v8.4s, v9.4s}, [%[b_ptr]], 32\n"
            "fmla v24.4s, v6.4s, v0.s[2]\n"
            "fmla v25.4s, v6.4s, v1.s[2]\n"
            "fmla v26.4s, v6.4s, v2.s[2]\n"
            "fmla v27.4s, v6.4s, v3.s[2]\n"
            "ld1 {v10.4s, v11.4s}, [%[b_ptr]], %x[LDB]\n"
            "fmla v24.4s, v7.4s, v0.s[3]\n"
            "fmla v25.4s, v7.4s, v1.s[3]\n"
            "fmla v26.4s, v7.4s, v2.s[3]\n"
            "fmla v27.4s, v7.4s, v3.s[3]\n"
            "fmla v28.4s, v4.4s, v8.s[0]\n"
            "fmla v29.4s, v4.4s, v9.s[0]\n"
            "fmla v30.4s, v4.4s, v10.s[0]\n"
            "fmla v31.4s, v4.4s, v11.s[0]\n"
            "subs %w[K], %w[K], #4\n"
            "bne 1b\n"

            // Epilogue: finish the v28-v31 columns, store all four 4x4
            // column groups, restore d8-d11.
            "2:\n"
            "st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[output]], 64\n"
            "fmla v28.4s, v5.4s, v8.s[1]\n"
            "fmla v29.4s, v5.4s, v9.s[1]\n"
            "fmla v30.4s, v5.4s, v10.s[1]\n"
            "fmla v31.4s, v5.4s, v11.s[1]\n"
            "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%[output]], 64\n"
            "fmla v28.4s, v6.4s, v8.s[2]\n"
            "fmla v29.4s, v6.4s, v9.s[2]\n"
            "fmla v30.4s, v6.4s, v10.s[2]\n"
            "fmla v31.4s, v6.4s, v11.s[2]\n"
            "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%[output]], 64\n"
            "fmla v28.4s, v7.4s, v8.s[3]\n"
            "fmla v29.4s, v7.4s, v9.s[3]\n"
            "fmla v30.4s, v7.4s, v10.s[3]\n"
            "fmla v31.4s, v7.4s, v11.s[3]\n"
            "st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%[output]], 64\n"
            "ldp d10, d11, [sp], #16\n"
            "ldp d8, d9, [sp], #16\n"
            : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
              [output] "+r"(output), [LDB] "+r"(LDB)
            :
            : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
              "v11", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23",
              "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc",
              "memory");
}
} // namespace | |||||
MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(sgemm_nopack_4x16); | |||||
//! No-pack fp32 GEMM over the MK4 blocked layout:
//! (m/4, k/4, 4, 4) * (k/4, n, 4) = (m/4, n, 4).
//! Requires non-transposed operands, M % 4 == 0, K % 4 == 0, N % 4 == 0;
//! the N tail (< 16) is dispatched to the 4x8/4x4 kernels.
void sgemm_nopack_4x16::kern(const float* A, size_t LDA, const float* B,
                             size_t LDB, float* C, size_t LDC, size_t M,
                             size_t K, size_t N, const float*, void*, bool trA,
                             bool trB) const {
    constexpr static size_t MB = 4;
    constexpr static size_t KB = 4;
    constexpr static size_t NB = 16;
    constexpr static size_t CALCBLK = 4;
    megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0 && N % CALCBLK == 0);
    //! (m/4, k/4, 4, 4) * (k/4, n, 4) = (m/4, n, 4)
    for (size_t m = 0; m < M; m += MB) {
        float* output = C + (m / MB) * LDC;
        const float* cur_B = B;
        size_t n = 0;
        // Full 16-column tiles.
        for (; n + NB - 1 < N; n += NB) {
            kern_4x16(A, cur_B, LDB, K, output);
            cur_B += KB * NB;
            output += MB * NB;
        }
        // N remainder is a multiple of 4 (asserted above), so it is one of
        // 0 / 4 / 8 / 12; 12 is split into an 8-wide and a 4-wide call.
        switch (N - n) {
            case 4:
                kern_4x4(A, cur_B, LDB, K, output);
                break;
            case 8:
                kern_4x8(A, cur_B, LDB, K, output);
                break;
            case 12:
                kern_4x8(A, cur_B, LDB, K, output);
                cur_B += KB * CALCBLK * 2;
                output += MB * CALCBLK * 2;
                kern_4x4(A, cur_B, LDB, K, output);
                break;
            default:
                // remainder 0: nothing left in this row strip.
                break;
        }
        // Next 4-row strip of A (LDA floats apart in the blocked layout).
        A += LDA;
    }
}
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,132 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/matrix_mul/int16/strategy.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "src/aarch64/matrix_mul/int16/strategy.h" | |||||
#include "src/aarch64/matrix_mul/asm/common.h" | |||||
#include "src/aarch64/matrix_mul/int16/kernel_12x8x1.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/common/utils.h" | |||||
using namespace megdnn; | |||||
using namespace aarch64; | |||||
using namespace aarch64::matmul; | |||||
///////////////////////// gemm_s16_12x8x1 //////////////////////////////////// | |||||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s16_12x8x1); | |||||
//! Pack an int16 A panel (rows [y0, ymax), depth [k0, kmax)) for the 12x8x1
//! kernel, dispatching on whether A is stored transposed.
void gemm_s16_12x8x1::pack_A(dt_int16* outptr, const dt_int16* inptr, int ldin,
                             int y0, int ymax, int k0, int kmax,
                             bool transpose) const {
    if (!transpose) {
        matmul_12x8x1::gemm_s16_12x8x1_pack_A_n(outptr, inptr, ldin, y0, ymax,
                                                k0, kmax);
        return;
    }
    matmul_12x8x1::gemm_s16_12x8x1_transpose_pack_A_n(outptr, inptr, ldin, y0,
                                                      ymax, k0, kmax);
}
//! Pack an int16 B panel (cols [x0, xmax), depth [k0, kmax)) for the 12x8x1
//! kernel, dispatching on whether B is stored transposed.
void gemm_s16_12x8x1::pack_B(dt_int16* out, const dt_int16* in, int ldin,
                             int x0, int xmax, int k0, int kmax,
                             bool transpose) const {
    if (!transpose) {
        matmul_12x8x1::gemm_s16_12x8x1_pack_B_n(out, in, ldin, x0, xmax, k0,
                                                kmax);
        return;
    }
    matmul_12x8x1::gemm_s16_12x8x1_transpose_pack_B_n(out, in, ldin, x0, xmax,
                                                      k0, kmax);
}
//! Run the packed int16->int32 GEMM: M is tiled in strips of 12, then 8,
//! then 4 rows; N in panels of 8 with a 4-wide tail. Kernels mask the
//! partial M/N edges via the std::min arguments.
void gemm_s16_12x8x1::kern(const dt_int16* packA, const dt_int16* packB,
                           size_t M, size_t N, size_t K, dt_int32* C,
                           size_t LDC, bool is_first_k, const dt_int32*,
                           dt_int32*) const {
    // Only s16 inputs accumulating into s32 are supported.
    megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
                          (A_dtype.enumv() == DTypeEnum::Int16 &&
                           C_dtype.enumv() == DTypeEnum::Int32),
                  "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
                  C_dtype.name());
    MEGDNN_MARK_USED_VAR(A_dtype);
    MEGDNN_MARK_USED_VAR(B_dtype);
    MEGDNN_MARK_USED_VAR(C_dtype);
    constexpr size_t A_INTERLEAVE = 12;
    constexpr size_t B_INTERLEAVE = 8;
    // Sizes in elements of one packed panel of width 12 / 8 / 4.
    const int K12 = K * 12;
    const int K8 = K * 8;
    const int K4 = K * 4;
    size_t m = 0;
    // Full 12-row strips of A.
    for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) {
        int32_t* output = C + (m * LDC);
        size_t n = 0;
        const dt_int16* cur_packB = packB;
        for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
            matmul_12x8x1::kern_12x8(packA, cur_packB, K, output, LDC,
                                     is_first_k);
            output += B_INTERLEAVE;
            cur_packB += K8;
        }
        for (; n < N; n += 4) {
            matmul_12x8x1::kern_12x4(packA, cur_packB, K, output, LDC,
                                     is_first_k, std::min<size_t>(N - n, 4));
            output += 4;
            cur_packB += K4;
        }
        packA += K12;
    }
    // 8-row strips for the first part of the M remainder.
    for (; m + 7 < M; m += 8) {
        int32_t* output = C + (m * LDC);
        const dt_int16* cur_packB = packB;
        size_t n = 0;
        for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
            matmul_12x8x1::kern_8x8(packA, cur_packB, K, output, LDC,
                                    is_first_k);
            output += B_INTERLEAVE;
            cur_packB += K8;
        }
        for (; n < N; n += 4) {
            matmul_12x8x1::kern_8x4(packA, cur_packB, K, output, LDC,
                                    is_first_k, std::min<size_t>(N - n, 4));
            output += 4;
            cur_packB += K4;
        }
        packA += K8;
    }
    // Final strips of at most 4 rows (kernels mask M - m < 4).
    for (; m < M; m += 4) {
        int32_t* output = C + (m * LDC);
        const dt_int16* cur_packB = packB;
        size_t n = 0;
        for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
            matmul_12x8x1::kern_4x8(packA, cur_packB, K, output, LDC,
                                    is_first_k, std::min<size_t>(M - m, 4));
            output += B_INTERLEAVE;
            cur_packB += K8;
        }
        for (; n < N; n += 4) {
            matmul_12x8x1::kern_4x4(packA, cur_packB, K, output, LDC,
                                    is_first_k, std::min<size_t>(M - m, 4),
                                    std::min<size_t>(N - n, 4));
            output += 4;
            cur_packB += K4;
        }
        packA += K4;
    }
}
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,29 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/matrix_mul/int16/strategy.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#pragma once | |||||
#include "src/fallback/matrix_mul/gemm_common.h" | |||||
namespace megdnn {
namespace aarch64 {
namespace matmul {

// int16 -> int32 packed GEMM strategy with a 12-row x 8-column tile.
MEGDNN_REG_GEMM_STRATEGY(dt_int16, dt_int32, dt_int32, 12, 8, 1, false, true,
                         gemm_s16_12x8x1);

// int16 -> int32 no-pack strategy (MK8 layout), 8x8 tile.
MEGDNN_REG_GEMM_STRATEGY_NOPACK(dt_int16, dt_int32, dt_int32, 8, 8, 1, false,
                                true, gemm_nopack_s16_8x8);

}  // namespace matmul
}  // namespace aarch64
}  // namespace megdnn
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,658 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/matrix_mul/int16/strategy_mk8_8x8.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "src/aarch64/matrix_mul/int16/strategy.h" | |||||
#include "src/aarch64/matrix_mul/asm/common.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/common/utils.h" | |||||
using namespace megdnn; | |||||
using namespace aarch64; | |||||
using namespace aarch64::matmul; | |||||
namespace { | |||||
// Overview of register layout: | |||||
// | |||||
// A 8x1 cell of Lhs is stored in 16bit in v24-v27 | |||||
// A 8x1 cell of Rhs is stored in 16bit in v0-v7 | |||||
// A 8x2 block of accumulators is stored in 32bit in v16-v23. | |||||
// | |||||
// Lhs +--------+ | |||||
// |v0[0-7]| | |||||
// |v1[0-7]| | |||||
// |v2[0-7]| | |||||
// |v3[0-7]| | |||||
// +--------+ | |||||
// Rhs | |||||
// +---------+ - - - - -+--------+ | |||||
// | v24[0-7]| |v16[0-3]| | |||||
// | v25[0-7]| |v17[0-3]| | |||||
// | v26[0-7]| |v18[0-3]| | |||||
// | v27[0-7]| |v19[0-3]| | |||||
// | v28[0-7]| |v20[0-3]| | |||||
// | v29[0-7]| |v21[0-3]| | |||||
// | v30[0-7]| |v22[0-3]| | |||||
// | v31[0-7]| |v23[0-3]| | |||||
// +---------+ +--------+ | |||||
// Accumulator | |||||
void kern_8x4(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||||
dt_int32* output) { | |||||
//! As each load 32 number from B, but the pos add 24 * 2, so we minus 24 | |||||
//! here. | |||||
LDB = (LDB - 24) * sizeof(dt_int16); | |||||
asm volatile( | |||||
"subs %w[K], %w[K], #8\n" | |||||
"ld1 {v24.4s}, [%[a_ptr]], 16\n" | |||||
"ld1 {v0.4s}, [%[b_ptr]], 16\n" | |||||
"ld1 {v1.4s}, [%[b_ptr]], 16\n" | |||||
"ld1 {v2.4s}, [%[b_ptr]], 16\n" | |||||
"ld1 {v3.4s}, [%[b_ptr]], %x[LDB]\n" | |||||
"smull v16.4s, v24.4h, v0.h[0]\n" | |||||
"smull2 v17.4s, v24.8h, v0.h[0]\n" | |||||
"smull v18.4s, v24.4h, v1.h[0]\n" | |||||
"smull2 v19.4s, v24.8h, v1.h[0]\n" | |||||
"ld1 {v25.4s}, [%[a_ptr]], 16\n" | |||||
"smull v20.4s, v24.4h, v2.h[0]\n" | |||||
"smull2 v21.4s, v24.8h, v2.h[0]\n" | |||||
"smull v22.4s, v24.4h, v3.h[0]\n" | |||||
"smull2 v23.4s, v24.8h, v3.h[0]\n" | |||||
"smlal v16.4s, v25.4h, v0.h[1]\n" | |||||
"smlal2 v17.4s, v25.8h, v0.h[1]\n" | |||||
"smlal v18.4s, v25.4h, v1.h[1]\n" | |||||
"smlal2 v19.4s, v25.8h, v1.h[1]\n" | |||||
"ld1 {v26.4s}, [%[a_ptr]], 16\n" | |||||
"smlal v20.4s, v25.4h, v2.h[1]\n" | |||||
"smlal2 v21.4s, v25.8h, v2.h[1]\n" | |||||
"smlal v22.4s, v25.4h, v3.h[1]\n" | |||||
"smlal2 v23.4s, v25.8h, v3.h[1]\n" | |||||
"smlal v16.4s, v26.4h, v0.h[2]\n" | |||||
"smlal2 v17.4s, v26.8h, v0.h[2]\n" | |||||
"smlal v18.4s, v26.4h, v1.h[2]\n" | |||||
"smlal2 v19.4s, v26.8h, v1.h[2]\n" | |||||
"ld1 {v27.4s}, [%[a_ptr]], 16\n" | |||||
"smlal v20.4s, v26.4h, v2.h[2]\n" | |||||
"smlal2 v21.4s, v26.8h, v2.h[2]\n" | |||||
"smlal v22.4s, v26.4h, v3.h[2]\n" | |||||
"smlal2 v23.4s, v26.8h, v3.h[2]\n" | |||||
"smlal v16.4s, v27.4h, v0.h[3]\n" | |||||
"smlal2 v17.4s, v27.8h, v0.h[3]\n" | |||||
"smlal v18.4s, v27.4h, v1.h[3]\n" | |||||
"smlal2 v19.4s, v27.8h, v1.h[3]\n" | |||||
"ld1 {v28.4s}, [%[a_ptr]], 16\n" | |||||
"smlal v20.4s, v27.4h, v2.h[3]\n" | |||||
"smlal2 v21.4s, v27.8h, v2.h[3]\n" | |||||
"smlal v22.4s, v27.4h, v3.h[3]\n" | |||||
"smlal2 v23.4s, v27.8h, v3.h[3]\n" | |||||
"smlal v16.4s, v28.4h, v0.h[4]\n" | |||||
"smlal2 v17.4s, v28.8h, v0.h[4]\n" | |||||
"smlal v18.4s, v28.4h, v1.h[4]\n" | |||||
"smlal2 v19.4s, v28.8h, v1.h[4]\n" | |||||
"ld1 {v29.4s}, [%[a_ptr]], 16\n" | |||||
"smlal v20.4s, v28.4h, v2.h[4]\n" | |||||
"smlal2 v21.4s, v28.8h, v2.h[4]\n" | |||||
"smlal v22.4s, v28.4h, v3.h[4]\n" | |||||
"smlal2 v23.4s, v28.8h, v3.h[4]\n" | |||||
"smlal v16.4s, v29.4h, v0.h[5]\n" | |||||
"smlal2 v17.4s, v29.8h, v0.h[5]\n" | |||||
"smlal v18.4s, v29.4h, v1.h[5]\n" | |||||
"smlal2 v19.4s, v29.8h, v1.h[5]\n" | |||||
"ld1 {v30.4s}, [%[a_ptr]], 16\n" | |||||
"smlal v20.4s, v29.4h, v2.h[5]\n" | |||||
"smlal2 v21.4s, v29.8h, v2.h[5]\n" | |||||
"smlal v22.4s, v29.4h, v3.h[5]\n" | |||||
"smlal2 v23.4s, v29.8h, v3.h[5]\n" | |||||
"smlal v16.4s, v30.4h, v0.h[6]\n" | |||||
"smlal2 v17.4s, v30.8h, v0.h[6]\n" | |||||
"smlal v18.4s, v30.4h, v1.h[6]\n" | |||||
"smlal2 v19.4s, v30.8h, v1.h[6]\n" | |||||
"ld1 {v31.4s}, [%[a_ptr]], 16\n" | |||||
"smlal v20.4s, v30.4h, v2.h[6]\n" | |||||
"smlal2 v21.4s, v30.8h, v2.h[6]\n" | |||||
"smlal v22.4s, v30.4h, v3.h[6]\n" | |||||
"smlal2 v23.4s, v30.8h, v3.h[6]\n" | |||||
"beq 2f\n" | |||||
"1:\n" | |||||
"ld1 {v24.4s}, [%[a_ptr]], 16\n" | |||||
"smlal v16.4s, v31.4h, v0.h[7]\n" | |||||
"smlal2 v17.4s, v31.8h, v0.h[7]\n" | |||||
"ld1 {v0.4s}, [%[b_ptr]], 16\n" | |||||
"smlal v18.4s, v31.4h, v1.h[7]\n" | |||||
"smlal2 v19.4s, v31.8h, v1.h[7]\n" | |||||
"ld1 {v1.4s}, [%[b_ptr]], 16\n" | |||||
"smlal v20.4s, v31.4h, v2.h[7]\n" | |||||
"smlal2 v21.4s, v31.8h, v2.h[7]\n" | |||||
"ld1 {v2.4s}, [%[b_ptr]], 16\n" | |||||
"smlal v22.4s, v31.4h, v3.h[7]\n" | |||||
"smlal2 v23.4s, v31.8h, v3.h[7]\n" | |||||
"ld1 {v3.4s}, [%[b_ptr]], %x[LDB]\n" | |||||
"smlal v16.4s, v24.4h, v0.h[0]\n" | |||||
"smlal2 v17.4s, v24.8h, v0.h[0]\n" | |||||
"smlal v18.4s, v24.4h, v1.h[0]\n" | |||||
"smlal2 v19.4s, v24.8h, v1.h[0]\n" | |||||
"ld1 {v25.4s}, [%[a_ptr]], 16\n" | |||||
"smlal v20.4s, v24.4h, v2.h[0]\n" | |||||
"smlal2 v21.4s, v24.8h, v2.h[0]\n" | |||||
"smlal v22.4s, v24.4h, v3.h[0]\n" | |||||
"smlal2 v23.4s, v24.8h, v3.h[0]\n" | |||||
"smlal v16.4s, v25.4h, v0.h[1]\n" | |||||
"smlal2 v17.4s, v25.8h, v0.h[1]\n" | |||||
"smlal v18.4s, v25.4h, v1.h[1]\n" | |||||
"smlal2 v19.4s, v25.8h, v1.h[1]\n" | |||||
"ld1 {v26.4s}, [%[a_ptr]], 16\n" | |||||
"smlal v20.4s, v25.4h, v2.h[1]\n" | |||||
"smlal2 v21.4s, v25.8h, v2.h[1]\n" | |||||
"smlal v22.4s, v25.4h, v3.h[1]\n" | |||||
"smlal2 v23.4s, v25.8h, v3.h[1]\n" | |||||
"smlal v16.4s, v26.4h, v0.h[2]\n" | |||||
"smlal2 v17.4s, v26.8h, v0.h[2]\n" | |||||
"smlal v18.4s, v26.4h, v1.h[2]\n" | |||||
"smlal2 v19.4s, v26.8h, v1.h[2]\n" | |||||
"ld1 {v27.4s}, [%[a_ptr]], 16\n" | |||||
"smlal v20.4s, v26.4h, v2.h[2]\n" | |||||
"smlal2 v21.4s, v26.8h, v2.h[2]\n" | |||||
"smlal v22.4s, v26.4h, v3.h[2]\n" | |||||
"smlal2 v23.4s, v26.8h, v3.h[2]\n" | |||||
"smlal v16.4s, v27.4h, v0.h[3]\n" | |||||
"smlal2 v17.4s, v27.8h, v0.h[3]\n" | |||||
"smlal v18.4s, v27.4h, v1.h[3]\n" | |||||
"smlal2 v19.4s, v27.8h, v1.h[3]\n" | |||||
"ld1 {v28.4s}, [%[a_ptr]], 16\n" | |||||
"smlal v20.4s, v27.4h, v2.h[3]\n" | |||||
"smlal2 v21.4s, v27.8h, v2.h[3]\n" | |||||
"smlal v22.4s, v27.4h, v3.h[3]\n" | |||||
"smlal2 v23.4s, v27.8h, v3.h[3]\n" | |||||
"smlal v16.4s, v28.4h, v0.h[4]\n" | |||||
"smlal2 v17.4s, v28.8h, v0.h[4]\n" | |||||
"smlal v18.4s, v28.4h, v1.h[4]\n" | |||||
"smlal2 v19.4s, v28.8h, v1.h[4]\n" | |||||
"ld1 {v29.4s}, [%[a_ptr]], 16\n" | |||||
"smlal v20.4s, v28.4h, v2.h[4]\n" | |||||
"smlal2 v21.4s, v28.8h, v2.h[4]\n" | |||||
"smlal v22.4s, v28.4h, v3.h[4]\n" | |||||
"smlal2 v23.4s, v28.8h, v3.h[4]\n" | |||||
"smlal v16.4s, v29.4h, v0.h[5]\n" | |||||
"smlal2 v17.4s, v29.8h, v0.h[5]\n" | |||||
"smlal v18.4s, v29.4h, v1.h[5]\n" | |||||
"smlal2 v19.4s, v29.8h, v1.h[5]\n" | |||||
"ld1 {v30.4s}, [%[a_ptr]], 16\n" | |||||
"smlal v20.4s, v29.4h, v2.h[5]\n" | |||||
"smlal2 v21.4s, v29.8h, v2.h[5]\n" | |||||
"smlal v22.4s, v29.4h, v3.h[5]\n" | |||||
"smlal2 v23.4s, v29.8h, v3.h[5]\n" | |||||
"smlal v16.4s, v30.4h, v0.h[6]\n" | |||||
"smlal2 v17.4s, v30.8h, v0.h[6]\n" | |||||
"smlal v18.4s, v30.4h, v1.h[6]\n" | |||||
"smlal2 v19.4s, v30.8h, v1.h[6]\n" | |||||
"ld1 {v31.4s}, [%[a_ptr]], 16\n" | |||||
"smlal v20.4s, v30.4h, v2.h[6]\n" | |||||
"smlal2 v21.4s, v30.8h, v2.h[6]\n" | |||||
"smlal v22.4s, v30.4h, v3.h[6]\n" | |||||
"smlal2 v23.4s, v30.8h, v3.h[6]\n" | |||||
"subs %w[K], %w[K], #8\n" | |||||
"bne 1b\n" | |||||
"2:\n" | |||||
"smlal v16.4s, v31.4h, v0.h[7]\n" | |||||
"smlal2 v17.4s, v31.8h, v0.h[7]\n" | |||||
"smlal v18.4s, v31.4h, v1.h[7]\n" | |||||
"smlal2 v19.4s, v31.8h, v1.h[7]\n" | |||||
"smlal v20.4s, v31.4h, v2.h[7]\n" | |||||
"smlal2 v21.4s, v31.8h, v2.h[7]\n" | |||||
"smlal v22.4s, v31.4h, v3.h[7]\n" | |||||
"smlal2 v23.4s, v31.8h, v3.h[7]\n" | |||||
"st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[output]], 64\n" | |||||
"st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%[output]], 64\n" | |||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||||
[output] "+r"(output), [LDB] "+r"(LDB) | |||||
: | |||||
: "v0", "v1", "v2", "v3", "v16", "v17", "v18", "v19", | |||||
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||||
"v29", "v30", "v31", "cc", "memory"); | |||||
} | |||||
// Overview of register layout:
//
// An 8x8 cell of Lhs is stored in 16bit in v8-v15
// An 8x8 cell of Rhs is stored in 16bit in v0-v7
// An 8x8 block of accumulators is stored in 32bit in v16-v31.
//
// Lhs +--------+
//     | v8[0-7]|
//     | v9[0-7]|
//     |v10[0-7]|
//     |v11[0-7]|
//     |v12[0-7]|
//     |v13[0-7]|
//     |v14[0-7]|
//     |v15[0-7]|
//     +--------+
// Rhs
//     +--------+ - - - - -+--------+--------+
//     | v0[0-7]|          |v16[0-3]|v17[0-3]|
//     | v1[0-7]|          |v18[0-3]|v19[0-3]|
//     | v2[0-7]|          |v20[0-3]|v21[0-3]|
//     | v3[0-7]|          |v22[0-3]|v23[0-3]|
//     | v4[0-7]|          |v24[0-3]|v25[0-3]|
//     | v5[0-7]|          |v26[0-3]|v27[0-3]|
//     | v6[0-7]|          |v28[0-3]|v29[0-3]|
//     | v7[0-7]|          |v30[0-3]|v31[0-3]|
//     +--------+          +--------+--------+
//                         Accumulator
void kern_8x8(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||||
dt_int32* output) { | |||||
//! As each load 64 number from B, but the pos add 48 * 2, so we minus 48 | |||||
//! here. | |||||
LDB = (LDB - 48) * sizeof(dt_int16); | |||||
asm volatile( | |||||
"subs %w[K], %w[K], #8\n" | |||||
"ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[a_ptr]], 64\n" | |||||
"ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%[a_ptr]], 64\n" | |||||
"ld1 {v0.4s}, [%[b_ptr]], 16\n" | |||||
"smull v16.4s, v8.4h, v0.h[0]\n" | |||||
"ld1 {v1.4s}, [%[b_ptr]], 16\n" | |||||
"smlal v16.4s, v9.4h, v0.h[1]\n" | |||||
"smull v18.4s, v8.4h, v1.h[0]\n" | |||||
"smull2 v17.4s, v8.8h, v0.h[0]\n" | |||||
"smull2 v19.4s, v8.8h, v1.h[0]\n" | |||||
"smlal v16.4s, v10.4h, v0.h[2]\n" | |||||
"smlal v18.4s, v9.4h, v1.h[1]\n" | |||||
"smlal2 v17.4s, v9.8h, v0.h[1]\n" | |||||
"smlal2 v19.4s, v9.8h, v1.h[1]\n" | |||||
"smlal v16.4s, v11.4h, v0.h[3]\n" | |||||
"smlal v18.4s, v10.4h, v1.h[2]\n" | |||||
"smlal2 v17.4s, v10.8h, v0.h[2]\n" | |||||
"smlal2 v19.4s, v10.8h, v1.h[2]\n" | |||||
"smlal v16.4s, v12.4h, v0.h[4]\n" | |||||
"smlal v18.4s, v11.4h, v1.h[3]\n" | |||||
"smlal2 v17.4s, v11.8h, v0.h[3]\n" | |||||
"smlal2 v19.4s, v11.8h, v1.h[3]\n" | |||||
"smlal v16.4s, v13.4h, v0.h[5]\n" | |||||
"smlal v18.4s, v12.4h, v1.h[4]\n" | |||||
"smlal2 v17.4s, v12.8h, v0.h[4]\n" | |||||
"smlal2 v19.4s, v12.8h, v1.h[4]\n" | |||||
"smlal2 v17.4s, v13.8h, v0.h[5]\n" | |||||
"ld1 {v2.4s, v3.4s}, [%[b_ptr]], 32\n" | |||||
"smlal v16.4s, v14.4h, v0.h[6]\n" | |||||
"smlal v18.4s, v13.4h, v1.h[5]\n" | |||||
"smlal2 v17.4s, v14.8h, v0.h[6]\n" | |||||
"smlal2 v19.4s, v13.8h, v1.h[5]\n" | |||||
"smull v20.4s, v8.4h, v2.h[0]\n" | |||||
"smull v22.4s, v8.4h, v3.h[0]\n" | |||||
"smull2 v21.4s, v8.8h, v2.h[0]\n" | |||||
"smull2 v23.4s, v8.8h, v3.h[0]\n" | |||||
"smlal v16.4s, v15.4h, v0.h[7]\n" | |||||
"smlal v18.4s, v14.4h, v1.h[6]\n" | |||||
"smlal2 v17.4s, v15.8h, v0.h[7]\n" | |||||
"smlal2 v19.4s, v14.8h, v1.h[6]\n" | |||||
"smlal v20.4s, v9.4h, v2.h[1]\n" | |||||
"smlal v22.4s, v9.4h, v3.h[1]\n" | |||||
"smlal2 v21.4s, v9.8h, v2.h[1]\n" | |||||
"smlal2 v23.4s, v9.8h, v3.h[1]\n" | |||||
"smlal v18.4s, v15.4h, v1.h[7]\n" | |||||
"smlal v20.4s, v10.4h, v2.h[2]\n" | |||||
"smlal v22.4s, v10.4h, v3.h[2]\n" | |||||
"smlal2 v21.4s, v10.8h, v2.h[2]\n" | |||||
"smlal2 v23.4s, v10.8h, v3.h[2]\n" | |||||
"smlal2 v19.4s, v15.8h, v1.h[7]\n" | |||||
"smlal v20.4s, v11.4h, v2.h[3]\n" | |||||
"smlal v22.4s, v11.4h, v3.h[3]\n" | |||||
"smlal2 v21.4s, v11.8h, v2.h[3]\n" | |||||
"smlal2 v23.4s, v11.8h, v3.h[3]\n" | |||||
"smlal v20.4s, v12.4h, v2.h[4]\n" | |||||
"smlal v22.4s, v12.4h, v3.h[4]\n" | |||||
"smlal2 v21.4s, v12.8h, v2.h[4]\n" | |||||
"smlal2 v23.4s, v12.8h, v3.h[4]\n" | |||||
"smlal v20.4s, v13.4h, v2.h[5]\n" | |||||
"smlal v22.4s, v13.4h, v3.h[5]\n" | |||||
"smlal2 v21.4s, v13.8h, v2.h[5]\n" | |||||
"smlal2 v23.4s, v13.8h, v3.h[5]\n" | |||||
"ld1 {v4.4s, v5.4s}, [%[b_ptr]], 32\n" | |||||
"smlal v20.4s, v14.4h, v2.h[6]\n" | |||||
"smlal v22.4s, v14.4h, v3.h[6]\n" | |||||
"smlal2 v21.4s, v14.8h, v2.h[6]\n" | |||||
"smlal2 v23.4s, v14.8h, v3.h[6]\n" | |||||
"smull v24.4s, v8.4h, v4.h[0]\n" | |||||
"smull v26.4s, v8.4h, v5.h[0]\n" | |||||
"smull2 v25.4s, v8.8h, v4.h[0]\n" | |||||
"smull2 v27.4s, v8.8h, v5.h[0]\n" | |||||
"smlal v20.4s, v15.4h, v2.h[7]\n" | |||||
"smlal v22.4s, v15.4h, v3.h[7]\n" | |||||
"smlal2 v21.4s, v15.8h, v2.h[7]\n" | |||||
"smlal2 v23.4s, v15.8h, v3.h[7]\n" | |||||
"smlal v24.4s, v9.4h, v4.h[1]\n" | |||||
"smlal v26.4s, v9.4h, v5.h[1]\n" | |||||
"smlal2 v25.4s, v9.8h, v4.h[1]\n" | |||||
"smlal2 v27.4s, v9.8h, v5.h[1]\n" | |||||
"smlal v24.4s, v10.4h, v4.h[2]\n" | |||||
"smlal v26.4s, v10.4h, v5.h[2]\n" | |||||
"smlal2 v25.4s, v10.8h, v4.h[2]\n" | |||||
"smlal2 v27.4s, v10.8h, v5.h[2]\n" | |||||
"smlal v24.4s, v11.4h, v4.h[3]\n" | |||||
"smlal v26.4s, v11.4h, v5.h[3]\n" | |||||
"smlal2 v25.4s, v11.8h, v4.h[3]\n" | |||||
"smlal2 v27.4s, v11.8h, v5.h[3]\n" | |||||
"smlal v24.4s, v12.4h, v4.h[4]\n" | |||||
"smlal v26.4s, v12.4h, v5.h[4]\n" | |||||
"smlal2 v25.4s, v12.8h, v4.h[4]\n" | |||||
"smlal2 v27.4s, v12.8h, v5.h[4]\n" | |||||
"smlal v24.4s, v13.4h, v4.h[5]\n" | |||||
"smlal v26.4s, v13.4h, v5.h[5]\n" | |||||
"smlal2 v25.4s, v13.8h, v4.h[5]\n" | |||||
"smlal2 v27.4s, v13.8h, v5.h[5]\n" | |||||
"ld1 {v6.4s, v7.4s}, [%[b_ptr]], %x[LDB]\n" | |||||
"smlal v24.4s, v14.4h, v4.h[6]\n" | |||||
"smlal v26.4s, v14.4h, v5.h[6]\n" | |||||
"smlal2 v25.4s, v14.8h, v4.h[6]\n" | |||||
"smlal2 v27.4s, v14.8h, v5.h[6]\n" | |||||
"smull v28.4s, v8.4h, v6.h[0]\n" | |||||
"smull v30.4s, v8.4h, v7.h[0]\n" | |||||
"smull2 v29.4s, v8.8h, v6.h[0]\n" | |||||
"smull2 v31.4s, v8.8h, v7.h[0]\n" | |||||
"smlal v28.4s, v9.4h, v6.h[1]\n" | |||||
"smlal v30.4s, v9.4h, v7.h[1]\n" | |||||
"smlal2 v29.4s, v9.8h, v6.h[1]\n" | |||||
"smlal2 v31.4s, v9.8h, v7.h[1]\n" | |||||
"smlal v28.4s, v10.4h, v6.h[2]\n" | |||||
"smlal v30.4s, v10.4h, v7.h[2]\n" | |||||
"smlal2 v29.4s, v10.8h, v6.h[2]\n" | |||||
"smlal2 v31.4s, v10.8h, v7.h[2]\n" | |||||
"smlal v28.4s, v11.4h, v6.h[3]\n" | |||||
"smlal v30.4s, v11.4h, v7.h[3]\n" | |||||
"smlal2 v29.4s, v11.8h, v6.h[3]\n" | |||||
"smlal2 v31.4s, v11.8h, v7.h[3]\n" | |||||
"smlal v28.4s, v12.4h, v6.h[4]\n" | |||||
"smlal v30.4s, v12.4h, v7.h[4]\n" | |||||
"smlal2 v29.4s, v12.8h, v6.h[4]\n" | |||||
"smlal2 v31.4s, v12.8h, v7.h[4]\n" | |||||
"smlal v28.4s, v13.4h, v6.h[5]\n" | |||||
"smlal v30.4s, v13.4h, v7.h[5]\n" | |||||
"smlal2 v29.4s, v13.8h, v6.h[5]\n" | |||||
"smlal2 v31.4s, v13.8h, v7.h[5]\n" | |||||
"beq 2f\n" | |||||
"1:\n" | |||||
"smlal v24.4s, v15.4h, v4.h[7]\n" | |||||
"smlal v26.4s, v15.4h, v5.h[7]\n" | |||||
"smlal2 v25.4s, v15.8h, v4.h[7]\n" | |||||
"ld1 {v8.4s, v9.4s}, [%[a_ptr]], 32\n" | |||||
"smlal2 v27.4s, v15.8h, v5.h[7]\n" | |||||
"smlal v28.4s, v14.4h, v6.h[6]\n" | |||||
"smlal v30.4s, v14.4h, v7.h[6]\n" | |||||
"ld1 {v10.4s, v11.4s}, [%[a_ptr]], 32\n" | |||||
"smlal2 v29.4s, v15.8h, v6.h[7]\n" | |||||
"smlal2 v31.4s, v14.8h, v7.h[6]\n" | |||||
"smlal v28.4s, v15.4h, v6.h[7]\n" | |||||
"ld1 {v12.4s, v13.4s}, [%[a_ptr]], 32\n" | |||||
"smlal v30.4s, v15.4h, v7.h[7]\n" | |||||
"smlal2 v29.4s, v14.8h, v6.h[6]\n" | |||||
"ld1 {v0.4s}, [%[b_ptr]], 16\n" | |||||
"smlal2 v31.4s, v15.8h, v7.h[7]\n" | |||||
"smlal v16.4s, v8.4h, v0.h[0]\n" | |||||
"ld1 {v1.4s}, [%[b_ptr]], 16\n" | |||||
"smlal v16.4s, v9.4h, v0.h[1]\n" | |||||
"smlal2 v17.4s, v8.8h, v0.h[0]\n" | |||||
"smlal v16.4s, v10.4h, v0.h[2]\n" | |||||
"smlal v18.4s, v8.4h, v1.h[0]\n" | |||||
"smlal2 v17.4s, v9.8h, v0.h[1]\n" | |||||
"smlal2 v19.4s, v8.8h, v1.h[0]\n" | |||||
"ld1 {v14.4s, v15.4s}, [%[a_ptr]], 32\n" | |||||
"smlal v16.4s, v11.4h, v0.h[3]\n" | |||||
"smlal v18.4s, v9.4h, v1.h[1]\n" | |||||
"smlal2 v17.4s, v10.8h, v0.h[2]\n" | |||||
"smlal2 v19.4s, v9.8h, v1.h[1]\n" | |||||
"smlal v16.4s, v12.4h, v0.h[4]\n" | |||||
"smlal v18.4s, v10.4h, v1.h[2]\n" | |||||
"smlal2 v17.4s, v11.8h, v0.h[3]\n" | |||||
"smlal2 v19.4s, v10.8h, v1.h[2]\n" | |||||
"smlal v16.4s, v13.4h, v0.h[5]\n" | |||||
"smlal v18.4s, v11.4h, v1.h[3]\n" | |||||
"smlal2 v17.4s, v12.8h, v0.h[4]\n" | |||||
"smlal2 v19.4s, v11.8h, v1.h[3]\n" | |||||
"smlal v16.4s, v14.4h, v0.h[6]\n" | |||||
"smlal v18.4s, v12.4h, v1.h[4]\n" | |||||
"smlal2 v17.4s, v13.8h, v0.h[5]\n" | |||||
"smlal2 v19.4s, v12.8h, v1.h[4]\n" | |||||
"smlal v16.4s, v15.4h, v0.h[7]\n" | |||||
"smlal v18.4s, v13.4h, v1.h[5]\n" | |||||
"smlal2 v17.4s, v14.8h, v0.h[6]\n" | |||||
"smlal2 v19.4s, v13.8h, v1.h[5]\n" | |||||
"ld1 {v2.4s, v3.4s}, [%[b_ptr]], 32\n" | |||||
"smlal v18.4s, v14.4h, v1.h[6]\n" | |||||
"smlal2 v17.4s, v15.8h, v0.h[7]\n" | |||||
"smlal2 v19.4s, v14.8h, v1.h[6]\n" | |||||
"smlal v20.4s, v8.4h, v2.h[0]\n" | |||||
"smlal v22.4s, v8.4h, v3.h[0]\n" | |||||
"smlal2 v21.4s, v8.8h, v2.h[0]\n" | |||||
"smlal2 v23.4s, v8.8h, v3.h[0]\n" | |||||
"smlal v18.4s, v15.4h, v1.h[7]\n" | |||||
"smlal v20.4s, v9.4h, v2.h[1]\n" | |||||
"smlal v22.4s, v9.4h, v3.h[1]\n" | |||||
"smlal2 v21.4s, v9.8h, v2.h[1]\n" | |||||
"smlal2 v23.4s, v9.8h, v3.h[1]\n" | |||||
"smlal2 v19.4s, v15.8h, v1.h[7]\n" | |||||
"smlal v20.4s, v10.4h, v2.h[2]\n" | |||||
"smlal v22.4s, v10.4h, v3.h[2]\n" | |||||
"smlal2 v21.4s, v10.8h, v2.h[2]\n" | |||||
"smlal2 v23.4s, v10.8h, v3.h[2]\n" | |||||
"smlal v20.4s, v11.4h, v2.h[3]\n" | |||||
"smlal v22.4s, v11.4h, v3.h[3]\n" | |||||
"smlal2 v21.4s, v11.8h, v2.h[3]\n" | |||||
"smlal2 v23.4s, v11.8h, v3.h[3]\n" | |||||
"smlal v20.4s, v12.4h, v2.h[4]\n" | |||||
"smlal v22.4s, v12.4h, v3.h[4]\n" | |||||
"smlal2 v21.4s, v12.8h, v2.h[4]\n" | |||||
"smlal2 v23.4s, v12.8h, v3.h[4]\n" | |||||
"smlal v20.4s, v13.4h, v2.h[5]\n" | |||||
"smlal v22.4s, v13.4h, v3.h[5]\n" | |||||
"smlal2 v21.4s, v13.8h, v2.h[5]\n" | |||||
"smlal2 v23.4s, v13.8h, v3.h[5]\n" | |||||
"ld1 {v4.4s, v5.4s}, [%[b_ptr]], 32\n" | |||||
"smlal v20.4s, v14.4h, v2.h[6]\n" | |||||
"smlal v22.4s, v14.4h, v3.h[6]\n" | |||||
"smlal2 v21.4s, v14.8h, v2.h[6]\n" | |||||
"smlal2 v23.4s, v14.8h, v3.h[6]\n" | |||||
"smlal v24.4s, v8.4h, v4.h[0]\n" | |||||
"smlal v26.4s, v8.4h, v5.h[0]\n" | |||||
"smlal2 v25.4s, v8.8h, v4.h[0]\n" | |||||
"smlal2 v27.4s, v8.8h, v5.h[0]\n" | |||||
"smlal v20.4s, v15.4h, v2.h[7]\n" | |||||
"smlal2 v21.4s, v15.8h, v2.h[7]\n" | |||||
"smlal v22.4s, v15.4h, v3.h[7]\n" | |||||
"smlal2 v23.4s, v15.8h, v3.h[7]\n" | |||||
"smlal v24.4s, v9.4h, v4.h[1]\n" | |||||
"smlal v26.4s, v9.4h, v5.h[1]\n" | |||||
"smlal2 v25.4s, v9.8h, v4.h[1]\n" | |||||
"smlal2 v27.4s, v9.8h, v5.h[1]\n" | |||||
"smlal v24.4s, v10.4h, v4.h[2]\n" | |||||
"smlal v26.4s, v10.4h, v5.h[2]\n" | |||||
"smlal2 v25.4s, v10.8h, v4.h[2]\n" | |||||
"smlal2 v27.4s, v10.8h, v5.h[2]\n" | |||||
"smlal v24.4s, v11.4h, v4.h[3]\n" | |||||
"smlal v26.4s, v11.4h, v5.h[3]\n" | |||||
"smlal2 v25.4s, v11.8h, v4.h[3]\n" | |||||
"smlal2 v27.4s, v11.8h, v5.h[3]\n" | |||||
"smlal v24.4s, v12.4h, v4.h[4]\n" | |||||
"smlal v26.4s, v12.4h, v5.h[4]\n" | |||||
"smlal2 v25.4s, v12.8h, v4.h[4]\n" | |||||
"smlal2 v27.4s, v12.8h, v5.h[4]\n" | |||||
"smlal v24.4s, v13.4h, v4.h[5]\n" | |||||
"smlal v26.4s, v13.4h, v5.h[5]\n" | |||||
"smlal2 v25.4s, v13.8h, v4.h[5]\n" | |||||
"smlal2 v27.4s, v13.8h, v5.h[5]\n" | |||||
"ld1 {v6.4s, v7.4s}, [%[b_ptr]], %x[LDB]\n" | |||||
"smlal v24.4s, v14.4h, v4.h[6]\n" | |||||
"smlal v26.4s, v14.4h, v5.h[6]\n" | |||||
"smlal2 v25.4s, v14.8h, v4.h[6]\n" | |||||
"smlal2 v27.4s, v14.8h, v5.h[6]\n" | |||||
"smlal v28.4s, v8.4h, v6.h[0]\n" | |||||
"smlal v30.4s, v8.4h, v7.h[0]\n" | |||||
"smlal2 v29.4s, v8.8h, v6.h[0]\n" | |||||
"smlal2 v31.4s, v8.8h, v7.h[0]\n" | |||||
"smlal v28.4s, v9.4h, v6.h[1]\n" | |||||
"smlal v30.4s, v9.4h, v7.h[1]\n" | |||||
"smlal2 v29.4s, v9.8h, v6.h[1]\n" | |||||
"smlal2 v31.4s, v9.8h, v7.h[1]\n" | |||||
"smlal v28.4s, v10.4h, v6.h[2]\n" | |||||
"smlal v30.4s, v10.4h, v7.h[2]\n" | |||||
"smlal2 v29.4s, v10.8h, v6.h[2]\n" | |||||
"smlal2 v31.4s, v10.8h, v7.h[2]\n" | |||||
"smlal v28.4s, v11.4h, v6.h[3]\n" | |||||
"smlal v30.4s, v11.4h, v7.h[3]\n" | |||||
"smlal2 v29.4s, v11.8h, v6.h[3]\n" | |||||
"smlal2 v31.4s, v11.8h, v7.h[3]\n" | |||||
"smlal v28.4s, v12.4h, v6.h[4]\n" | |||||
"smlal v30.4s, v12.4h, v7.h[4]\n" | |||||
"smlal2 v29.4s, v12.8h, v6.h[4]\n" | |||||
"smlal2 v31.4s, v12.8h, v7.h[4]\n" | |||||
"smlal v28.4s, v13.4h, v6.h[5]\n" | |||||
"smlal v30.4s, v13.4h, v7.h[5]\n" | |||||
"smlal2 v29.4s, v13.8h, v6.h[5]\n" | |||||
"smlal2 v31.4s, v13.8h, v7.h[5]\n" | |||||
"subs %w[K], %w[K], #8\n" | |||||
"bne 1b\n" | |||||
"2:\n" | |||||
"st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[output]], 64\n" | |||||
"smlal v24.4s, v15.4h, v4.h[7]\n" | |||||
"smlal v28.4s, v14.4h, v6.h[6]\n" | |||||
"smlal v30.4s, v14.4h, v7.h[6]\n" | |||||
"smlal v26.4s, v15.4h, v5.h[7]\n" | |||||
"smlal2 v25.4s, v15.8h, v4.h[7]\n" | |||||
"smlal2 v27.4s, v15.8h, v5.h[7]\n" | |||||
"smlal2 v29.4s, v14.8h, v6.h[6]\n" | |||||
"smlal2 v31.4s, v14.8h, v7.h[6]\n" | |||||
"st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%[output]], 64\n" | |||||
"smlal v28.4s, v15.4h, v6.h[7]\n" | |||||
"smlal v30.4s, v15.4h, v7.h[7]\n" | |||||
"smlal2 v29.4s, v15.8h, v6.h[7]\n" | |||||
"smlal2 v31.4s, v15.8h, v7.h[7]\n" | |||||
"st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%[output]], 64\n" | |||||
"st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%[output]], 64\n" | |||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||||
[output] "+r"(output), [LDB] "+r"(LDB) | |||||
: | |||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||||
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||||
"v29", "v30", "v31", "cc", "memory"); | |||||
} | |||||
} // anonymous namespace | |||||
MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(gemm_nopack_s16_8x8); | |||||
void gemm_nopack_s16_8x8::kern(const dt_int16* A, size_t LDA, const dt_int16* B, | |||||
size_t LDB, dt_int32* C, size_t LDC, size_t M, | |||||
size_t K, size_t N, const dt_int32*, void*, | |||||
bool trA, bool trB) const { | |||||
constexpr static size_t MB = 8; | |||||
constexpr static size_t KB = 8; | |||||
constexpr static size_t NB = 8; | |||||
constexpr static size_t CALCBLK = 4; | |||||
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0 && N % CALCBLK == 0); | |||||
//! (m/8, k/8, 8, 8) * (k/8, n, 8) = (m/8, n, 8) | |||||
for (size_t m = 0; m < M; m += MB) { | |||||
dt_int32* output = C + (m / MB) * LDC; | |||||
const dt_int16* cur_B = B; | |||||
size_t n = 0; | |||||
for (; n + NB - 1 < N; n += NB) { | |||||
kern_8x8(A, cur_B, LDB, K, output); | |||||
cur_B += KB * NB; | |||||
output += MB * NB; | |||||
} | |||||
if (n < N) { | |||||
kern_8x4(A, cur_B, LDB, K, output); | |||||
} | |||||
A += LDA; | |||||
} | |||||
} | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,856 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/matrix_mul/int8/kernel_4x4x16.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#if !(__ARM_FEATURE_DOTPROD) | |||||
#include "src/aarch64/matrix_mul/asm/common.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
namespace megdnn { | |||||
namespace aarch64 { | |||||
namespace matmul_4x4x16 { | |||||
/** | |||||
* Overview of register layout: | |||||
* | |||||
 * A 16x4 cell of Rhs is stored in 8bit in v4-v11.
 * A 4x16 cell of Lhs is stored in 8bit in v0-v3.
 * A 4x4 block of accumulators is stored in 32bit in v16-v31.
* | |||||
* \warning Fast kernel operating on int8 operands. | |||||
* It is assumed that one of the two int8 operands only takes values | |||||
* in [-127, 127], while the other may freely range in [-128, 127]. | |||||
* The issue with both operands taking the value -128 is that: | |||||
* -128*-128 + -128*-128 == -32768 overflows int16. | |||||
* Every other expression a*b + c*d, for any int8 a,b,c,d, fits in int16 | |||||
* range. That is the basic idea of this kernel. | |||||
* | |||||
* | |||||
* +--------+--------+---------+---------+ | |||||
* |v4[0-16]|v5[0-16]| v6[0-16]| v7[0-16]| | |||||
* Rhs +--------+--------+---------+---------+ | |||||
* |v8[0-16]|v9[0-16]|v10[0-16]|v11[0-16]| | |||||
* +--------+--------+---------+---------+ | |||||
* | | | | | | |||||
* | |||||
* Lhs | | | | | | |||||
* | |||||
* +--------+ - - - - +-------------------------------------+ | |||||
* |v0[0-16]| |v16[0-4]|v17[0-4]| v18[0-4]| v19[0-4]| | |||||
* |v1[0-16]| |v20[0-4]|v21[0-4]| v22[0-4]| v23[0-4]| | |||||
* |v2[0-16]| |v24[0-4]|v25[0-4]| v26[0-4]| v27[0-4]| | |||||
* |v3[0-16]| |v28[0-4]|v29[0-4]| v30[0-4]| v31[0-4]| | |||||
* +--------+ - - - - +-------------------------------------+ | |||||
* | |||||
* Accumulator | |||||
*/ | |||||
static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||||
int32_t* output, int LDC, bool is_first_k) { | |||||
K /= 16; | |||||
const int8_t* a_ptr = packA; | |||||
const int8_t* b_ptr = packB; | |||||
// Fix up for odd lengths - set a flag if K is odd, but make | |||||
// sure we round up the iteration count. | |||||
int oddk = (K & 1); | |||||
int k = ((K + 1) / 2) - 1; | |||||
LDC = LDC * sizeof(int32_t); | |||||
asm volatile ( | |||||
// load accumulator C | |||||
"add x1, %[output], %x[LDC]\n" | |||||
"add x2, x1, %x[LDC]\n" | |||||
"add x3, x2, %x[LDC]\n" | |||||
"cmp %w[is_first_k], #1\n" | |||||
"beq 1f\n" | |||||
"ldr q16, [%[output]]\n" | |||||
"ldr q17, [x1]\n" | |||||
"ldr q18, [x2]\n" | |||||
"ldr q19, [x3]\n" | |||||
"b 2f\n" | |||||
"1:\n" | |||||
"eor v16.16b, v16.16b, v16.16b\n" | |||||
"eor v17.16b, v17.16b, v17.16b\n" | |||||
"eor v18.16b, v18.16b, v18.16b\n" | |||||
"eor v19.16b, v19.16b, v19.16b\n" | |||||
"2: \n" | |||||
"ldr q0, [%[a_ptr]]\n" | |||||
"ldr q4, [%[b_ptr]]\n" | |||||
"ldr q5, [%[b_ptr], #16]\n" | |||||
"ldr q6, [%[b_ptr], #32]\n" | |||||
"movi v20.4s, #0x0\n" | |||||
"ldr q7, [%[b_ptr], #48]\n" | |||||
"movi v21.4s, #0x0\n" | |||||
"ldr q1, [%[a_ptr], #16]\n" | |||||
"movi v22.4s, #0x0\n" | |||||
"ldr q2, [%[a_ptr], #32]\n" | |||||
"movi v23.4s, #0x0\n" | |||||
"ldr q3, [%[a_ptr], #48]\n" | |||||
"movi v24.4s, #0x0\n" | |||||
ASM_PREFETCH("[%[b_ptr], #64]") | |||||
"movi v25.4s, #0x0\n" | |||||
ASM_PREFETCH("[%[a_ptr], #64]") | |||||
"movi v26.4s, #0x0\n" | |||||
ASM_PREFETCH("[%[b_ptr], #128]") | |||||
"movi v27.4s, #0x0\n" | |||||
ASM_PREFETCH("[%[a_ptr], #128]") | |||||
"movi v28.4s, #0x0\n" | |||||
ASM_PREFETCH("[%[b_ptr], #192]") | |||||
"movi v29.4s, #0x0\n" | |||||
ASM_PREFETCH("[%[a_ptr], #192]") | |||||
"movi v30.4s, #0x0\n" | |||||
ASM_PREFETCH("[%[b_ptr], #256]") | |||||
"movi v31.4s, #0x0\n" | |||||
ASM_PREFETCH("[%[a_ptr], #256]") | |||||
// Start of unroll 0 (first iteration) | |||||
"smull v12.8h, v0.8b, v4.8b\n" | |||||
"smull v13.8h, v0.8b, v5.8b\n" | |||||
// Skip loop if we are doing zero iterations of it. | |||||
"cbz %w[k], 4f\n" | |||||
// Unroll 0 continuation (branch target) | |||||
"3:\n" | |||||
"smull v14.8h, v0.8b, v6.8b\n" | |||||
"subs %w[k], %w[k], #1\n" | |||||
"smull v15.8h, v0.8b, v7.8b\n" | |||||
"ldr q8, [%[b_ptr], #64]\n" | |||||
"smlal2 v12.8h, v0.16b, v4.16b\n" | |||||
"smlal2 v13.8h, v0.16b, v5.16b\n" | |||||
"ldr q9, [%[b_ptr], #80]\n" | |||||
"smlal2 v14.8h, v0.16b, v6.16b\n" | |||||
"smlal2 v15.8h, v0.16b, v7.16b\n" | |||||
"ldr q0, [%[a_ptr], #64]\n" | |||||
"sadalp v16.4s, v12.8h\n" | |||||
"smull v12.8h, v1.8b, v4.8b\n" | |||||
"sadalp v17.4s, v13.8h\n" | |||||
"sadalp v18.4s, v14.8h\n" | |||||
"smull v13.8h, v1.8b, v5.8b\n" | |||||
"sadalp v19.4s, v15.8h\n" | |||||
"smull v14.8h, v1.8b, v6.8b\n" | |||||
"ldr q10, [%[b_ptr], #96]\n" | |||||
"smull v15.8h, v1.8b, v7.8b\n" | |||||
"smlal2 v12.8h, v1.16b, v4.16b\n" | |||||
"ldr q11, [%[b_ptr], #112]\n" | |||||
"smlal2 v13.8h, v1.16b, v5.16b\n" | |||||
"add %[b_ptr], %[b_ptr], #128\n" | |||||
"smlal2 v14.8h, v1.16b, v6.16b\n" | |||||
"smlal2 v15.8h, v1.16b, v7.16b\n" | |||||
"ldr q1, [%[a_ptr], #80]\n" | |||||
"sadalp v20.4s, v12.8h\n" | |||||
"smull v12.8h, v2.8b, v4.8b\n" | |||||
"sadalp v21.4s, v13.8h\n" | |||||
"sadalp v22.4s, v14.8h\n" | |||||
"smull v13.8h, v2.8b, v5.8b\n" | |||||
"sadalp v23.4s, v15.8h\n" | |||||
"smull v14.8h, v2.8b, v6.8b\n" | |||||
"smull v15.8h, v2.8b, v7.8b\n" | |||||
"smlal2 v12.8h, v2.16b, v4.16b\n" | |||||
ASM_PREFETCH("[%[b_ptr], #192]") | |||||
"smlal2 v13.8h, v2.16b, v5.16b\n" | |||||
"smlal2 v14.8h, v2.16b, v6.16b\n" | |||||
ASM_PREFETCH("[%[a_ptr], #320]") | |||||
"smlal2 v15.8h, v2.16b, v7.16b\n" | |||||
"ldr q2, [%[a_ptr], #96]\n" | |||||
"sadalp v24.4s, v12.8h\n" | |||||
"smull v12.8h, v3.8b, v4.8b\n" | |||||
"sadalp v25.4s, v13.8h\n" | |||||
"sadalp v26.4s, v14.8h\n" | |||||
"smull v13.8h, v3.8b, v5.8b\n" | |||||
"sadalp v27.4s, v15.8h\n" | |||||
"smull v14.8h, v3.8b, v6.8b\n" | |||||
"smull v15.8h, v3.8b, v7.8b\n" | |||||
"smlal2 v12.8h, v3.16b, v4.16b\n" | |||||
"ldr q4, [%[b_ptr], #0]\n" | |||||
"smlal2 v13.8h, v3.16b, v5.16b\n" | |||||
"smlal2 v14.8h, v3.16b, v6.16b\n" | |||||
"smlal2 v15.8h, v3.16b, v7.16b\n" | |||||
"ldr q3, [%[a_ptr], #112]\n" | |||||
// Unroll 1 | |||||
"sadalp v28.4s, v12.8h\n" | |||||
"smull v12.8h, v0.8b, v8.8b\n" | |||||
"sadalp v29.4s, v13.8h\n" | |||||
"sadalp v30.4s, v14.8h\n" | |||||
"smull v13.8h, v0.8b, v9.8b\n" | |||||
"sadalp v31.4s, v15.8h\n" | |||||
"smull v14.8h, v0.8b, v10.8b\n" | |||||
"smull v15.8h, v0.8b, v11.8b\n" | |||||
"ldr q5, [%[b_ptr], #16]\n" | |||||
"smlal2 v12.8h, v0.16b, v8.16b\n" | |||||
"smlal2 v13.8h, v0.16b, v9.16b\n" | |||||
"ldr q6, [%[b_ptr], #32]\n" | |||||
"smlal2 v14.8h, v0.16b, v10.16b\n" | |||||
"smlal2 v15.8h, v0.16b, v11.16b\n" | |||||
"ldr q0, [%[a_ptr], #128]\n" | |||||
"sadalp v16.4s, v12.8h\n" | |||||
"smull v12.8h, v1.8b, v8.8b\n" | |||||
"sadalp v17.4s, v13.8h\n" | |||||
"sadalp v18.4s, v14.8h\n" | |||||
"smull v13.8h, v1.8b, v9.8b\n" | |||||
"sadalp v19.4s, v15.8h\n" | |||||
"add %[a_ptr], %[a_ptr], #128\n" | |||||
"smull v14.8h, v1.8b, v10.8b\n" | |||||
"smull v15.8h, v1.8b, v11.8b\n" | |||||
"ldr q7, [%[b_ptr], #48]\n" | |||||
"smlal2 v12.8h, v1.16b, v8.16b\n" | |||||
"smlal2 v13.8h, v1.16b, v9.16b\n" | |||||
"smlal2 v14.8h, v1.16b, v10.16b\n" | |||||
"smlal2 v15.8h, v1.16b, v11.16b\n" | |||||
"ldr q1, [%[a_ptr], #16]\n" | |||||
"sadalp v20.4s, v12.8h\n" | |||||
"smull v12.8h, v2.8b, v8.8b\n" | |||||
"sadalp v21.4s, v13.8h\n" | |||||
"sadalp v22.4s, v14.8h\n" | |||||
"smull v13.8h, v2.8b, v9.8b\n" | |||||
"sadalp v23.4s, v15.8h\n" | |||||
"smull v14.8h, v2.8b, v10.8b\n" | |||||
"smull v15.8h, v2.8b, v11.8b\n" | |||||
"smlal2 v12.8h, v2.16b, v8.16b\n" | |||||
ASM_PREFETCH("[%[b_ptr], #256]") | |||||
"smlal2 v13.8h, v2.16b, v9.16b\n" | |||||
"smlal2 v14.8h, v2.16b, v10.16b\n" | |||||
ASM_PREFETCH("[%[a_ptr], #256]") | |||||
"smlal2 v15.8h, v2.16b, v11.16b\n" | |||||
"ldr q2, [%[a_ptr], #32]\n" | |||||
"sadalp v24.4s, v12.8h\n" | |||||
"smull v12.8h, v3.8b, v8.8b\n" | |||||
"sadalp v25.4s, v13.8h\n" | |||||
"sadalp v26.4s, v14.8h\n" | |||||
"smull v13.8h, v3.8b, v9.8b\n" | |||||
"sadalp v27.4s, v15.8h\n" | |||||
"smull v14.8h, v3.8b, v10.8b\n" | |||||
"smull v15.8h, v3.8b, v11.8b\n" | |||||
"smlal2 v12.8h, v3.16b, v8.16b\n" | |||||
"smlal2 v13.8h, v3.16b, v9.16b\n" | |||||
"smlal2 v14.8h, v3.16b, v10.16b\n" | |||||
"smlal2 v15.8h, v3.16b, v11.16b\n" | |||||
"ldr q3, [%[a_ptr], #48]\n" | |||||
// Start of unroll 0 for next iteration. | |||||
"sadalp v28.4s, v12.8h\n" | |||||
"smull v12.8h, v0.8b, v4.8b\n" | |||||
"sadalp v29.4s, v13.8h\n" | |||||
"sadalp v30.4s, v14.8h\n" | |||||
"smull v13.8h, v0.8b, v5.8b\n" | |||||
"sadalp v31.4s, v15.8h\n" | |||||
"bne 3b\n" | |||||
// Target to use when K=1 or 2 (i.e. zero iterations of main loop) | |||||
"4:\n" | |||||
// Branch to alternative tail for odd K | |||||
"cbnz %w[oddk], 5f\n" | |||||
// Detached final iteration (even K) | |||||
"smull v14.8h, v0.8b, v6.8b\n" | |||||
"smull v15.8h, v0.8b, v7.8b\n" | |||||
"ldr q8, [%[b_ptr], #64]\n" | |||||
"smlal2 v12.8h, v0.16b, v4.16b\n" | |||||
"smlal2 v13.8h, v0.16b, v5.16b\n" | |||||
"ldr q9, [%[b_ptr], #80]\n" | |||||
"smlal2 v14.8h, v0.16b, v6.16b\n" | |||||
"smlal2 v15.8h, v0.16b, v7.16b\n" | |||||
"ldr q0, [%[a_ptr], #64]\n" | |||||
"sadalp v16.4s, v12.8h\n" | |||||
"smull v12.8h, v1.8b, v4.8b\n" | |||||
"sadalp v17.4s, v13.8h\n" | |||||
"sadalp v18.4s, v14.8h\n" | |||||
"smull v13.8h, v1.8b, v5.8b\n" | |||||
"sadalp v19.4s, v15.8h\n" | |||||
"smull v14.8h, v1.8b, v6.8b\n" | |||||
"ldr q10, [%[b_ptr], #96]\n" | |||||
"smull v15.8h, v1.8b, v7.8b\n" | |||||
"smlal2 v12.8h, v1.16b, v4.16b\n" | |||||
"ldr q11, [%[b_ptr], #112]\n" | |||||
"smlal2 v13.8h, v1.16b, v5.16b\n" | |||||
"add %[b_ptr], %[b_ptr], #128\n" | |||||
"smlal2 v14.8h, v1.16b, v6.16b\n" | |||||
"smlal2 v15.8h, v1.16b, v7.16b\n" | |||||
"ldr q1, [%[a_ptr], #80]\n" | |||||
"sadalp v20.4s, v12.8h\n" | |||||
"smull v12.8h, v2.8b, v4.8b\n" | |||||
"sadalp v21.4s, v13.8h\n" | |||||
"sadalp v22.4s, v14.8h\n" | |||||
"smull v13.8h, v2.8b, v5.8b\n" | |||||
"sadalp v23.4s, v15.8h\n" | |||||
"smull v14.8h, v2.8b, v6.8b\n" | |||||
"smull v15.8h, v2.8b, v7.8b\n" | |||||
"smlal2 v12.8h, v2.16b, v4.16b\n" | |||||
"smlal2 v13.8h, v2.16b, v5.16b\n" | |||||
"smlal2 v14.8h, v2.16b, v6.16b\n" | |||||
"smlal2 v15.8h, v2.16b, v7.16b\n" | |||||
"ldr q2, [%[a_ptr], #96]\n" | |||||
"sadalp v24.4s, v12.8h\n" | |||||
"smull v12.8h, v3.8b, v4.8b\n" | |||||
"sadalp v25.4s, v13.8h\n" | |||||
"sadalp v26.4s, v14.8h\n" | |||||
"smull v13.8h, v3.8b, v5.8b\n" | |||||
"sadalp v27.4s, v15.8h\n" | |||||
"smull v14.8h, v3.8b, v6.8b\n" | |||||
"smull v15.8h, v3.8b, v7.8b\n" | |||||
"smlal2 v12.8h, v3.16b, v4.16b\n" | |||||
"smlal2 v13.8h, v3.16b, v5.16b\n" | |||||
"smlal2 v14.8h, v3.16b, v6.16b\n" | |||||
"smlal2 v15.8h, v3.16b, v7.16b\n" | |||||
"ldr q3, [%[a_ptr], #112]\n" | |||||
// Unroll 1 | |||||
"sadalp v28.4s, v12.8h\n" | |||||
"smull v12.8h, v0.8b, v8.8b\n" | |||||
"sadalp v29.4s, v13.8h\n" | |||||
"sadalp v30.4s, v14.8h\n" | |||||
"smull v13.8h, v0.8b, v9.8b\n" | |||||
"sadalp v31.4s, v15.8h\n" | |||||
"smull v14.8h, v0.8b, v10.8b\n" | |||||
"add %[a_ptr], %[a_ptr], #128\n" | |||||
"smull v15.8h, v0.8b, v11.8b\n" | |||||
"smlal2 v12.8h, v0.16b, v8.16b\n" | |||||
"smlal2 v13.8h, v0.16b, v9.16b\n" | |||||
"smlal2 v14.8h, v0.16b, v10.16b\n" | |||||
"smlal2 v15.8h, v0.16b, v11.16b\n" | |||||
"sadalp v16.4s, v12.8h\n" | |||||
"smull v12.8h, v1.8b, v8.8b\n" | |||||
"sadalp v17.4s, v13.8h\n" | |||||
"sadalp v18.4s, v14.8h\n" | |||||
"smull v13.8h, v1.8b, v9.8b\n" | |||||
"sadalp v19.4s, v15.8h\n" | |||||
"smull v14.8h, v1.8b, v10.8b\n" | |||||
"smull v15.8h, v1.8b, v11.8b\n" | |||||
"smlal2 v12.8h, v1.16b, v8.16b\n" | |||||
"addp v16.4s, v16.4s, v17.4s\n" | |||||
"smlal2 v13.8h, v1.16b, v9.16b\n" | |||||
"addp v17.4s, v18.4s, v19.4s\n" | |||||
"smlal2 v14.8h, v1.16b, v10.16b\n" | |||||
"smlal2 v15.8h, v1.16b, v11.16b\n" | |||||
"sadalp v20.4s, v12.8h\n" | |||||
"smull v12.8h, v2.8b, v8.8b\n" | |||||
"sadalp v21.4s, v13.8h\n" | |||||
"sadalp v22.4s, v14.8h\n" | |||||
"smull v13.8h, v2.8b, v9.8b\n" | |||||
"sadalp v23.4s, v15.8h\n" | |||||
"addp v16.4s, v16.4s, v17.4s\n" | |||||
"smull v14.8h, v2.8b, v10.8b\n" | |||||
"addp v18.4s, v20.4s, v21.4s\n" | |||||
"addp v19.4s, v22.4s, v23.4s\n" | |||||
"smull v15.8h, v2.8b, v11.8b\n" | |||||
"smlal2 v12.8h, v2.16b, v8.16b\n" | |||||
"str q16, [%[output]]\n" | |||||
"smlal2 v13.8h, v2.16b, v9.16b\n" | |||||
"smlal2 v14.8h, v2.16b, v10.16b\n" | |||||
"smlal2 v15.8h, v2.16b, v11.16b\n" | |||||
"sadalp v24.4s, v12.8h\n" | |||||
"smull v12.8h, v3.8b, v8.8b\n" | |||||
"sadalp v25.4s, v13.8h\n" | |||||
"sadalp v26.4s, v14.8h\n" | |||||
"smull v13.8h, v3.8b, v9.8b\n" | |||||
"sadalp v27.4s, v15.8h\n" | |||||
"addp v17.4s, v18.4s, v19.4s\n" | |||||
"smull v14.8h, v3.8b, v10.8b\n" | |||||
"addp v20.4s, v24.4s, v25.4s\n" | |||||
"addp v21.4s, v26.4s, v27.4s\n" | |||||
"smull v15.8h, v3.8b, v11.8b\n" | |||||
"smlal2 v12.8h, v3.16b, v8.16b\n" | |||||
"str q17, [x1]\n" | |||||
"smlal2 v13.8h, v3.16b, v9.16b\n" | |||||
"smlal2 v14.8h, v3.16b, v10.16b\n" | |||||
"addp v18.4s, v20.4s, v21.4s\n" | |||||
"smlal2 v15.8h, v3.16b, v11.16b\n" | |||||
"b 6f\n" | |||||
// Detached final iteration (odd K) | |||||
"5:\n" | |||||
"smull v14.8h, v0.8b, v6.8b\n" | |||||
"add %[a_ptr], %[a_ptr], #64\n" | |||||
"smull v15.8h, v0.8b, v7.8b\n" | |||||
"add %[b_ptr], %[b_ptr], #64\n" | |||||
"smlal2 v12.8h, v0.16b, v4.16b\n" | |||||
"smlal2 v13.8h, v0.16b, v5.16b\n" | |||||
"smlal2 v14.8h, v0.16b, v6.16b\n" | |||||
"smlal2 v15.8h, v0.16b, v7.16b\n" | |||||
"sadalp v16.4s, v12.8h\n" | |||||
"smull v12.8h, v1.8b, v4.8b\n" | |||||
"sadalp v17.4s, v13.8h\n" | |||||
"sadalp v18.4s, v14.8h\n" | |||||
"smull v13.8h, v1.8b, v5.8b\n" | |||||
"sadalp v19.4s, v15.8h\n" | |||||
"smull v14.8h, v1.8b, v6.8b\n" | |||||
"smull v15.8h, v1.8b, v7.8b\n" | |||||
"smlal2 v12.8h, v1.16b, v4.16b\n" | |||||
"addp v16.4s, v16.4s, v17.4s\n" | |||||
"smlal2 v13.8h, v1.16b, v5.16b\n" | |||||
"addp v17.4s, v18.4s, v19.4s\n" | |||||
"smlal2 v14.8h, v1.16b, v6.16b\n" | |||||
"smlal2 v15.8h, v1.16b, v7.16b\n" | |||||
"sadalp v20.4s, v12.8h\n" | |||||
"smull v12.8h, v2.8b, v4.8b\n" | |||||
"sadalp v21.4s, v13.8h\n" | |||||
"sadalp v22.4s, v14.8h\n" | |||||
"smull v13.8h, v2.8b, v5.8b\n" | |||||
"sadalp v23.4s, v15.8h\n" | |||||
"addp v16.4s, v16.4s, v17.4s\n" | |||||
"smull v14.8h, v2.8b, v6.8b\n" | |||||
"addp v18.4s, v20.4s, v21.4s\n" | |||||
"addp v19.4s, v22.4s, v23.4s\n" | |||||
"smull v15.8h, v2.8b, v7.8b\n" | |||||
"smlal2 v12.8h, v2.16b, v4.16b\n" | |||||
"str q16, [%[output]]\n" | |||||
"smlal2 v13.8h, v2.16b, v5.16b\n" | |||||
"smlal2 v14.8h, v2.16b, v6.16b\n" | |||||
"smlal2 v15.8h, v2.16b, v7.16b\n" | |||||
"sadalp v24.4s, v12.8h\n" | |||||
"smull v12.8h, v3.8b, v4.8b\n" | |||||
"sadalp v25.4s, v13.8h\n" | |||||
"sadalp v26.4s, v14.8h\n" | |||||
"smull v13.8h, v3.8b, v5.8b\n" | |||||
"sadalp v27.4s, v15.8h\n" | |||||
"addp v17.4s, v18.4s, v19.4s\n" | |||||
"smull v14.8h, v3.8b, v6.8b\n" | |||||
"addp v20.4s, v24.4s, v25.4s\n" | |||||
"addp v21.4s, v26.4s, v27.4s\n" | |||||
"smull v15.8h, v3.8b, v7.8b\n" | |||||
"smlal2 v12.8h, v3.16b, v4.16b\n" | |||||
"str q17, [x1]\n" | |||||
"smlal2 v13.8h, v3.16b, v5.16b\n" | |||||
"smlal2 v14.8h, v3.16b, v6.16b\n" | |||||
"addp v18.4s, v20.4s, v21.4s\n" | |||||
"smlal2 v15.8h, v3.16b, v7.16b\n" | |||||
"6:\n" | |||||
// Final additions | |||||
"sadalp v28.4s, v12.8h\n" | |||||
"str q18, [x2]\n" | |||||
"sadalp v29.4s, v13.8h\n" | |||||
"sadalp v30.4s, v14.8h\n" | |||||
"sadalp v31.4s, v15.8h\n" | |||||
// Horizontal reduction, phase 1 | |||||
"addp v22.4s, v28.4s, v29.4s\n" | |||||
"addp v23.4s, v30.4s, v31.4s\n" | |||||
// Horizontal reduction, phase 2 | |||||
"addp v19.4s, v22.4s, v23.4s\n" | |||||
"str q19, [x3]\n" | |||||
: | |||||
[a_ptr] "+r" (a_ptr), [b_ptr] "+r" (b_ptr), [oddk] "+r" (oddk), | |||||
[is_first_k] "+r" (is_first_k), [k] "+r" (k), [LDC] "+r" (LDC), | |||||
[output] "+r"(output) | |||||
: | |||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", | |||||
"v10", "v11", "v12", "v13", "v14", "v15", "v16", | |||||
"v17", "v18", "v19", "v20","v21","v22","v23","v24","v25","v26", | |||||
"v27","v28","v29","v30","v31", "x1", "x2", "x3", | |||||
"cc", "memory" | |||||
); | |||||
} | |||||
// Border variant of the s8 4x4x16 kernel: computes an m_remain x n_remain
// (each 1..4) corner tile of C from packed A (packA, 4 rows x 16-deep
// blocks) and packed B (packB, 4 cols x 16-deep blocks).
// K is the raw depth and must be > 0; the packing guarantees whole blocks,
// so K is converted to a block count below. When is_first_k the
// accumulators are zeroed, otherwise the existing C tile is loaded first.
// NOTE(review): in the non-first-k path, LOAD_C places each C row in the
// lanes of v16-v19, but the final result of each output element is the
// *horizontal* sum of one accumulator register — verify that callers only
// use is_first_k==true here, or that this accumulation is intended.
static void kern_4x4_remain(const int8_t* packA, const int8_t* packB, int K,
                            int32_t* output, int LDC, bool is_first_k,
                            int m_remain, int n_remain) {
    megdnn_assert(K > 0);
    // Convert raw depth to the number of 16-deep K blocks.
    K /= 16;
    const int8_t* a_ptr = packA;
    const int8_t* b_ptr = packB;
    // Row stride of C in bytes (C is int32).
    LDC = LDC * sizeof(int32_t);
    // clang-format off
// LOAD_LINE(reg_index, n): load one row of the C tile into q<reg_index>
// from the row pointer held in x<n>. A full q-register load is used when
// n_remain == 4, otherwise single 32-bit lanes are loaded; x4 counts the
// rows still to process (initialised from m_remain) and 102 is the
// early-out label once all valid rows are done.
#define LOAD_LINE(reg_index, n)                 \
    "cbz x4, 102f\n"                            \
    "mov x5, x" n "\n"                          \
    "cmp %w[n_remain], #4\n"                    \
    "blt 100" n "f\n"                           \
    "ldr q" reg_index ", [x5]\n"                \
    "b 101" n "f\n"                             \
    "100" n ":\n"                               \
    "cmp %w[n_remain], #0\n"                    \
    "beq 101" n "f\n"                           \
    "ld1 {v" reg_index ".s}[0], [x5], #4\n"     \
    "cmp %w[n_remain], #1\n"                    \
    "beq 101" n "f\n"                           \
    "ld1 {v" reg_index ".s}[1], [x5], #4\n"     \
    "cmp %w[n_remain], #2\n"                    \
    "beq 101" n "f\n"                           \
    "ld1 {v" reg_index ".s}[2], [x5], #4\n"     \
    "101" n ":\n"                               \
    "subs x4, x4, #1\n"
// Load up to m_remain rows of C into q16-q19.
#define LOAD_C                      \
    "mov x4, %x[m_remain]\n"        \
    LOAD_LINE("16", "0")            \
    LOAD_LINE("17", "1")            \
    LOAD_LINE("18", "2")            \
    LOAD_LINE("19", "3")            \
    "102:\n"
// STORE_LINE(reg_index, n): mirror of LOAD_LINE for writing one row of C,
// honouring n_remain columns and the x4 row counter.
#define STORE_LINE(reg_index, n)                \
    "cbz x4, 105f\n"                            \
    "mov x5, x" n "\n"                          \
    "cmp %w[n_remain], #4\n"                    \
    "blt 103" n "f\n"                           \
    "str q" reg_index ", [x5]\n"                \
    "b 104" n "f\n"                             \
    "103" n ":\n"                               \
    "cmp %w[n_remain], #0\n"                    \
    "beq 104" n "f\n"                           \
    "st1 {v" reg_index ".s}[0], [x5], #4\n"     \
    "cmp %w[n_remain], #1\n"                    \
    "beq 104" n "f\n"                           \
    "st1 {v" reg_index ".s}[1], [x5], #4\n"     \
    "cmp %w[n_remain], #2\n"                    \
    "beq 104" n "f\n"                           \
    "st1 {v" reg_index ".s}[2], [x5], #4\n"     \
    "104" n ":\n"                               \
    "subs x4, x4, #1\n"
// Store up to m_remain rows of C from q16-q19.
#define STORE_C                     \
    "mov x4, %x[m_remain]\n"        \
    STORE_LINE("16", "0")           \
    STORE_LINE("17", "1")           \
    STORE_LINE("18", "2")           \
    STORE_LINE("19", "3")           \
    "105:\n"
    // clang-format on
    asm volatile(
            // x0..x3 = pointers to the four C rows (stride LDC bytes).
            "mov x0, %[output]\n"
            "add x1, x0, %x[LDC]\n"
            "add x2, x1, %x[LDC]\n"
            "add x3, x2, %x[LDC]\n"
            "cmp %w[is_first_k], #1\n"
            "beq 1f\n" LOAD_C  //
            "b 2f\n"
            // First K slice: clear all 16 int32 accumulators v16-v31.
            "1:\n"
            "eor v16.16b, v16.16b, v16.16b\n"
            "eor v17.16b, v17.16b, v17.16b\n"
            "eor v18.16b, v18.16b, v18.16b\n"
            "eor v19.16b, v19.16b, v19.16b\n"
            "eor v20.16b, v20.16b, v20.16b\n"
            "eor v21.16b, v21.16b, v21.16b\n"
            "eor v22.16b, v22.16b, v22.16b\n"
            "eor v23.16b, v23.16b, v23.16b\n"
            "eor v24.16b, v24.16b, v24.16b\n"
            "eor v25.16b, v25.16b, v25.16b\n"
            "eor v26.16b, v26.16b, v26.16b\n"
            "eor v27.16b, v27.16b, v27.16b\n"
            "eor v28.16b, v28.16b, v28.16b\n"
            "eor v29.16b, v29.16b, v29.16b\n"
            "eor v30.16b, v30.16b, v30.16b\n"
            "eor v31.16b, v31.16b, v31.16b\n"
            // Main K loop: one 16-deep block per iteration. Products are
            // widened to int16 (smull low half, smlal2 high half) and
            // folded into int32 with sadalp.
            "2: \n"
            "ldr q4, [%[b_ptr]]\n"
            "ldr q5, [%[b_ptr], #16]\n"
            "ldr q6, [%[b_ptr], #32]\n"
            "ldr q7, [%[b_ptr], #48]\n"
            "ldr q0, [%[a_ptr]]\n"
            "ldr q1, [%[a_ptr], #16]\n"
            "ldr q2, [%[a_ptr], #32]\n"
            "ldr q3, [%[a_ptr], #48]\n"
            // Row 0 of A against all four B columns -> v16-v19.
            "smull v12.8h, v0.8b, v4.8b\n"
            "smull v13.8h, v0.8b, v5.8b\n"
            "smull v14.8h, v0.8b, v6.8b\n"
            "smull v15.8h, v0.8b, v7.8b\n"
            "smlal2 v12.8h, v0.16b, v4.16b\n"
            "smlal2 v13.8h, v0.16b, v5.16b\n"
            "smlal2 v14.8h, v0.16b, v6.16b\n"
            "smlal2 v15.8h, v0.16b, v7.16b\n"
            "sadalp v16.4s, v12.8h\n"
            "sadalp v17.4s, v13.8h\n"
            "sadalp v18.4s, v14.8h\n"
            "sadalp v19.4s, v15.8h\n"
            // Row 1 -> v20-v23.
            "smull v12.8h, v1.8b, v4.8b\n"
            "smull v13.8h, v1.8b, v5.8b\n"
            "smull v14.8h, v1.8b, v6.8b\n"
            "smull v15.8h, v1.8b, v7.8b\n"
            "smlal2 v12.8h, v1.16b, v4.16b\n"
            "smlal2 v13.8h, v1.16b, v5.16b\n"
            "smlal2 v14.8h, v1.16b, v6.16b\n"
            "smlal2 v15.8h, v1.16b, v7.16b\n"
            "sadalp v20.4s, v12.8h\n"
            "sadalp v21.4s, v13.8h\n"
            "sadalp v22.4s, v14.8h\n"
            "sadalp v23.4s, v15.8h\n"
            // Row 2 -> v24-v27.
            "smull v12.8h, v2.8b, v4.8b\n"
            "smull v13.8h, v2.8b, v5.8b\n"
            "smull v14.8h, v2.8b, v6.8b\n"
            "smull v15.8h, v2.8b, v7.8b\n"
            "smlal2 v12.8h, v2.16b, v4.16b\n"
            "smlal2 v13.8h, v2.16b, v5.16b\n"
            "smlal2 v14.8h, v2.16b, v6.16b\n"
            "smlal2 v15.8h, v2.16b, v7.16b\n"
            "sadalp v24.4s, v12.8h\n"
            "sadalp v25.4s, v13.8h\n"
            "sadalp v26.4s, v14.8h\n"
            "sadalp v27.4s, v15.8h\n"
            // Row 3 -> v28-v31.
            "smull v12.8h, v3.8b, v4.8b\n"
            "smull v13.8h, v3.8b, v5.8b\n"
            "smull v14.8h, v3.8b, v6.8b\n"
            "smull v15.8h, v3.8b, v7.8b\n"
            "smlal2 v12.8h, v3.16b, v4.16b\n"
            "smlal2 v13.8h, v3.16b, v5.16b\n"
            "smlal2 v14.8h, v3.16b, v6.16b\n"
            "smlal2 v15.8h, v3.16b, v7.16b\n"
            "sadalp v28.4s, v12.8h\n"
            "sadalp v29.4s, v13.8h\n"
            "sadalp v30.4s, v14.8h\n"
            "sadalp v31.4s, v15.8h\n"
            "add %[a_ptr], %[a_ptr], #64\n"
            "add %[b_ptr], %[b_ptr], #64\n"
            "subs %w[K], %w[K], #1\n"
            "cbnz %w[K], 2b\n"
            "3:\n"
            // reduction: each output element is the horizontal sum of one
            // accumulator; two addp levels collapse v16-v31 into the 4x4
            // tile held in v16-v19 (one register per row).
            "addp v16.4s, v16.4s, v17.4s\n"
            "addp v17.4s, v18.4s, v19.4s\n"
            "addp v16.4s, v16.4s, v17.4s\n"
            "addp v18.4s, v20.4s, v21.4s\n"
            "addp v19.4s, v22.4s, v23.4s\n"
            "addp v17.4s, v18.4s, v19.4s\n"
            "addp v20.4s, v24.4s, v25.4s\n"
            "addp v21.4s, v26.4s, v27.4s\n"
            "addp v18.4s, v20.4s, v21.4s\n"
            "addp v22.4s, v28.4s, v29.4s\n"
            "addp v23.4s, v30.4s, v31.4s\n"
            "addp v19.4s, v22.4s, v23.4s\n" STORE_C
            : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr),
              [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC),
              [output] "+r"(output), [m_remain] "+r"(m_remain),
              [n_remain] "+r"(n_remain)
            :
            : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
              "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
              "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
              "v29", "v30", "v31", "x0", "x1", "x2", "x3", "x4", "x5", "cc",
              "memory");
#undef LOAD_LINE
#undef LOAD_C
#undef STORE_LINE
#undef STORE_C
}
//! Pack matrix A for the s8 4x4x16 kernel: rows [y0, ymax) and depth
//! [k0, kmax) are emitted as 4-row groups, 16 k-values per row at a time.
//! Rows past ymax in the last (partial) group are sourced from a shared
//! zero buffer so the packed stream always holds whole 4-row groups.
static void gemm_s8_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr,
                                 int ldin, int y0, int ymax, int k0, int kmax) {
    int8_t zerobuff[16];
    std::memset(zerobuff, 0, sizeof(zerobuff));
    int y = y0;
    //! Fast path: a full group of four rows is available.
    for (; y + 3 < ymax; y += 4) {
        const int8_t* row0 = inptr + y * ldin + k0;
        const int8_t* row1 = row0 + ldin;
        const int8_t* row2 = row1 + ldin;
        const int8_t* row3 = row2 + ldin;
        prefetch_2x(row0);
        prefetch_2x(row1);
        prefetch_2x(row2);
        prefetch_2x(row3);
        //! read 16 * 4 in each row
        int k_left = kmax - k0;
        while (k_left > 15) {
            interleave_4x16_1_b(row0, row1, row2, row3, outptr);
            k_left -= 16;
        }
        if (k_left > 0) {
            interleave_4(row0, row1, row2, row3, outptr, 16, k_left);
        }
    }
    //! Tail: fewer than four rows remain; rows beyond ymax are redirected
    //! to zerobuff before every interleave call (the helpers advance the
    //! pointers, so the redirection is refreshed each iteration).
    for (; y < ymax; y += 4) {
        const int8_t* row0 = inptr + y * ldin + k0;
        const int8_t* row1 = row0 + ldin;
        const int8_t* row2 = row1 + ldin;
        const int8_t* row3 = row2 + ldin;
        prefetch_2x(row0);
        prefetch_2x(row1);
        prefetch_2x(row2);
        prefetch_2x(row3);
        //! overflow = number of extra zero rows beyond row3's slot - 1;
        //! it can only be 0, 1 or 2 here since at least one row is valid.
        auto substitute_zero_rows = [&]() {
            if (y + 3 >= ymax) {
                const int overflow = y + 3 - ymax;
                megdnn_assert(overflow <= 2);
                if (overflow >= 2)
                    row1 = zerobuff;
                if (overflow >= 1)
                    row2 = zerobuff;
                row3 = zerobuff;
            }
        };
        //! read 4 * 4 in each row
        int k_left = kmax - k0;
        while (k_left > 15) {
            substitute_zero_rows();
            interleave_4x16_1_b(row0, row1, row2, row3, outptr);
            k_left -= 16;
        }
        if (k_left > 0) {
            substitute_zero_rows();
            interleave_4(row0, row1, row2, row3, outptr, 16, k_left);
        }
    }
}
//! Pack matrix B for the s8 4x4x16 kernel: columns [x0, xmax) and depth
//! [k0, kmax) are emitted as 4-column panels of 16 k-values. k positions
//! past kmax are sourced from a shared zero buffer so every panel is a
//! full 16-deep block.
static void gemm_s8_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin,
                                 int x0, int xmax, int k0, int kmax) {
    int8_t zerobuff[16];
    std::memset(zerobuff, 0, sizeof(zerobuff));
    const int ksize = kmax - k0;
    const int ksize4 = round_up(ksize, 16) * 4;
    int8_t* outptr = out;
    for (int k = k0; k < kmax; k += 16) {
        //! Each 16-deep block is processed as two stacked groups of 8 rows.
        int ki = k;
        for (int half = 0; half < 2; ++half, ki += 8) {
            //! Eight consecutive k-rows of B for this group.
            const int8_t* rows[8];
            rows[0] = in + ki * ldin + x0;
            for (int j = 1; j < 8; ++j) {
                rows[j] = rows[j - 1] + ldin;
            }
            int8_t* outptr_inner = outptr + ki - k;
            //! remain >= 0 means the last (remain + 1) rows of this group
            //! lie at or beyond kmax and must read zeros instead.
            const int remain = std::min(ki + 7 - kmax, 7);
            //! Redirect out-of-range rows to zerobuff; re-run before every
            //! transpose since the helper advances the pointers.
            auto mask_tail_rows = [&]() {
                if (remain >= 0) {
                    megdnn_assert(remain <= 7);
                    for (int j = 0; j < 8; ++j) {
                        if (remain >= 7 - j) {
                            rows[j] = zerobuff;
                        }
                    }
                }
            };
            int x = x0;
            for (; x + 3 < xmax; x += 4) {
                mask_tail_rows();
                transpose_4x16_1_b_helper(rows[0], rows[1], rows[2], rows[3],
                                          rows[4], rows[5], rows[6], rows[7],
                                          outptr_inner);
                outptr_inner += ksize4;
            }
            if (x < xmax) {
                mask_tail_rows();
                //! Scalar tail: fewer than 4 columns left. Each column
                //! contributes one byte per row; the extra += 8 skips the
                //! slot reserved for the other 8-row half of the block.
                for (; x < xmax; x++) {
                    for (int j = 0; j < 8; ++j) {
                        *outptr_inner++ = *rows[j]++;
                    }
                    outptr_inner += 8;
                }
            }
        }
        outptr += 16 * 4;
    }
}
} // namespace matmul_4x4x16 | |||||
} // namespace aarch64 | |||||
} // namespace megdnn | |||||
#endif | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,892 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/matrix_mul/int8/kernel_mk4_4x4x16.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. | |||||
*/ | |||||
#include <cstring> | |||||
#if !(__ARM_FEATURE_DOTPROD) | |||||
#include "src/aarch64/matrix_mul/asm/common.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
namespace megdnn { | |||||
namespace aarch64 { | |||||
namespace matmul_mk4_4x4x16 { | |||||
/** | |||||
* Overview of register layout: | |||||
* | |||||
 * A 16x4 cell of Lhs is stored in 8bit in v0-v3.
 * B 16x4 cell of Rhs is stored in 8bit in v4-v7.
 * C 4x4 block of accumulators is stored in 32bit in v16-v31.
* | |||||
* \warning Fast kernel operating on int8 operands. | |||||
* It is assumed that one of the two int8 operands only takes values | |||||
* in [-127, 127], while the other may freely range in [-128, 127]. | |||||
* The issue with both operands taking the value -128 is that: | |||||
* -128*-128 + -128*-128 == -32768 overflows int16. | |||||
* Every other expression a*b + c*d, for any int8 a,b,c,d, fits in int16 | |||||
* range. That is the basic idea of this kernel. | |||||
* | |||||
* | |||||
* +--------+--------+---------+---------+ | |||||
* |v4[0-16]|v5[0-16]| v6[0-16]| v7[0-16]| | |||||
* Rhs +--------+--------+---------+---------+ | |||||
* | | | | | | |||||
* | |||||
* Lhs | | | | | | |||||
* | |||||
* +--------+ - - - - +-------------------------------------+ | |||||
* |v0[0-16]| |v16[0-4]|v17[0-4]| v18[0-4]| v19[0-4]| | |||||
* |v1[0-16]| |v20[0-4]|v21[0-4]| v22[0-4]| v23[0-4]| | |||||
* |v2[0-16]| |v24[0-4]|v25[0-4]| v26[0-4]| v27[0-4]| | |||||
* |v3[0-16]| |v28[0-4]|v29[0-4]| v30[0-4]| v31[0-4]| | |||||
* +--------+ - - - - +-------------------------------------+ | |||||
* | |||||
* Accumulator | |||||
*/ | |||||
// Compute one 4x4 output tile of the MK4-layout int8 GEMM:
// C = A * B when is_first_k, otherwise C += A * B (the previous tile is
// reloaded from `output` and added just before the final store).
// packA supplies 4 rows and packB 4 columns, both pre-packed in 16-deep
// int8 K-blocks; K is the raw depth, rounded up to whole blocks below.
// Products are widened with smull (low 8 lanes) / smlal2 (high 8 lanes)
// into int16 and folded into the int32 accumulators v16-v31 with sadalp;
// a final addp cascade reduces each accumulator's horizontal sum into the
// output lanes. The main loop is unrolled over two K-blocks, with loads
// interleaved between multiplies to hide memory latency.
static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
                     int32_t* output, bool is_first_k) {
    // Number of 16-deep K blocks consumed by the loop below.
    K = div_ceil(K, 16);
    const int8_t* a_ptr = packA;
    const int8_t* b_ptr = packB;

    asm volatile(
            // Preload the first A/B blocks, prefetch ahead, and zero the 16
            // accumulators (eor of v19 with itself also yields zero, so all
            // of v16-v31 start at 0).
            "ld1 {v0.16b}, [%[a_ptr]], #16\n"
            "eor v16.16b, v16.16b, v16.16b\n"
            "eor v17.16b, v17.16b, v17.16b\n"
            "eor v18.16b, v18.16b, v18.16b\n"
            "ld1 {v1.16b}, [%[a_ptr]], #16\n"
            "eor v19.16b, v19.16b, v19.16b\n"
            "eor v20.16b, v19.16b, v19.16b\n"
            "eor v21.16b, v19.16b, v19.16b\n"
            "ld1 {v4.16b, v5.16b}, [%[b_ptr]], #32\n"
            "eor v22.16b, v19.16b, v19.16b\n"
            "PRFM PLDL1KEEP, [%[a_ptr], #32]\n"
            "eor v23.16b, v19.16b, v19.16b\n"
            "eor v24.16b, v19.16b, v19.16b\n"
            "PRFM PLDL1KEEP, [%[b_ptr], #32]\n"
            "eor v25.16b, v19.16b, v19.16b\n"
            "eor v26.16b, v19.16b, v19.16b\n"
            "PRFM PLDL1KEEP, [%[b_ptr], #64]\n"
            "eor v27.16b, v19.16b, v19.16b\n"
            "eor v28.16b, v19.16b, v19.16b\n"
            "PRFM PLDL1KEEP, [%[a_ptr], #64]\n"
            "eor v29.16b, v19.16b, v19.16b\n"
            "eor v30.16b, v19.16b, v19.16b\n"
            "PRFM PLDL1KEEP, [%[b_ptr], #128]\n"
            "eor v31.16b, v19.16b, v19.16b\n"
            //! if K==1 jump to compute last K
            "cmp %w[k], #2\n"
            "beq 2f\n"
            "blt 3f\n"
            //! K>2: main loop, two K-blocks per iteration.
            "1:\n"
            //! First k
            "smull v8.8h, v0.8b, v4.8b\n"
            "smull v9.8h, v0.8b, v5.8b\n"
            "ld1 {v6.16b}, [%[b_ptr]], #16\n"
            "smull v12.8h, v1.8b, v4.8b\n"
            "smull v13.8h, v1.8b, v5.8b\n"
            "ld1 {v7.16b}, [%[b_ptr]], #16\n"
            "smlal2 v8.8h, v0.16b, v4.16b\n"
            "smlal2 v9.8h, v0.16b, v5.16b\n"
            "smlal2 v12.8h, v1.16b, v4.16b\n"
            "smlal2 v13.8h, v1.16b, v5.16b\n"
            "smull v10.8h, v0.8b, v6.8b\n"
            "ld1 {v2.16b}, [%[a_ptr]], #16\n"
            "smull v11.8h, v0.8b, v7.8b\n"
            "smull v14.8h, v1.8b, v6.8b\n"
            "ld1 {v3.16b}, [%[a_ptr]], #16\n"
            "smull v15.8h, v1.8b, v7.8b\n"
            "sadalp v16.4s, v8.8h\n"
            "smlal2 v10.8h, v0.16b, v6.16b\n"
            "sadalp v17.4s, v9.8h\n"
            "smlal2 v11.8h, v0.16b, v7.16b\n"
            "sadalp v20.4s, v12.8h\n"
            "smlal2 v14.8h, v1.16b, v6.16b\n"
            "sadalp v21.4s, v13.8h\n"
            "smlal2 v15.8h, v1.16b, v7.16b\n"
            "smull v8.8h, v2.8b, v4.8b\n"
            "smull v9.8h, v2.8b, v5.8b\n"
            "ld1 {v0.16b}, [%[a_ptr]], #16\n"
            "smull v12.8h, v3.8b, v4.8b\n"
            "ld1 {v1.16b}, [%[a_ptr]], #16\n"
            "smull v13.8h, v3.8b, v5.8b\n"
            "sadalp v18.4s, v10.8h\n"
            "smlal2 v8.8h, v2.16b, v4.16b\n"
            "sadalp v19.4s, v11.8h\n"
            "smlal2 v9.8h, v2.16b, v5.16b\n"
            "sadalp v22.4s, v14.8h\n"
            "smlal2 v12.8h, v3.16b, v4.16b\n"
            "sadalp v23.4s, v15.8h\n"
            "smlal2 v13.8h, v3.16b, v5.16b\n"
            "smull v10.8h, v2.8b, v6.8b\n"
            "smull v11.8h, v2.8b, v7.8b\n"
            "ld1 {v4.16b}, [%[b_ptr]], #16\n"
            "smull v14.8h, v3.8b, v6.8b\n"
            "ld1 {v5.16b}, [%[b_ptr]], #16\n"
            "smull v15.8h, v3.8b, v7.8b\n"
            "sadalp v24.4s, v8.8h\n"
            "smlal2 v10.8h, v2.16b, v6.16b\n"
            "sadalp v25.4s, v9.8h\n"
            "smlal2 v11.8h, v2.16b, v7.16b\n"
            "sadalp v28.4s, v12.8h\n"
            "smlal2 v14.8h, v3.16b, v6.16b\n"
            "sadalp v29.4s, v13.8h\n"
            "smlal2 v15.8h, v3.16b, v7.16b\n"
            //! Second k
            "smull v8.8h, v0.8b, v4.8b\n"
            "smull v9.8h, v0.8b, v5.8b\n"
            "ld1 {v6.16b}, [%[b_ptr]], #16\n"
            "smull v12.8h, v1.8b, v4.8b\n"
            "smull v13.8h, v1.8b, v5.8b\n"
            "ld1 {v7.16b}, [%[b_ptr]], #16\n"
            "smlal2 v8.8h, v0.16b, v4.16b\n"
            "sadalp v26.4s, v10.8h\n"
            "smlal2 v9.8h, v0.16b, v5.16b\n"
            "sadalp v27.4s, v11.8h\n"
            "smlal2 v12.8h, v1.16b, v4.16b\n"
            "sadalp v30.4s, v14.8h\n"
            "smlal2 v13.8h, v1.16b, v5.16b\n"
            "sadalp v31.4s, v15.8h\n"
            "smull v10.8h, v0.8b, v6.8b\n"
            "ld1 {v2.16b}, [%[a_ptr]], #16\n"
            "smull v11.8h, v0.8b, v7.8b\n"
            "smull v14.8h, v1.8b, v6.8b\n"
            "ld1 {v3.16b}, [%[a_ptr]], #16\n"
            "smull v15.8h, v1.8b, v7.8b\n"
            "sadalp v16.4s, v8.8h\n"
            "smlal2 v10.8h, v0.16b, v6.16b\n"
            "sadalp v17.4s, v9.8h\n"
            "smlal2 v11.8h, v0.16b, v7.16b\n"
            "sadalp v20.4s, v12.8h\n"
            "smlal2 v14.8h, v1.16b, v6.16b\n"
            "sadalp v21.4s, v13.8h\n"
            "smlal2 v15.8h, v1.16b, v7.16b\n"
            "smull v8.8h, v2.8b, v4.8b\n"
            "smull v9.8h, v2.8b, v5.8b\n"
            "ld1 {v0.16b}, [%[a_ptr]], #16\n"
            "smull v12.8h, v3.8b, v4.8b\n"
            "ld1 {v1.16b}, [%[a_ptr]], #16\n"
            "smull v13.8h, v3.8b, v5.8b\n"
            "sadalp v18.4s, v10.8h\n"
            "smlal2 v8.8h, v2.16b, v4.16b\n"
            "sadalp v19.4s, v11.8h\n"
            "smlal2 v9.8h, v2.16b, v5.16b\n"
            "sadalp v22.4s, v14.8h\n"
            "smlal2 v12.8h, v3.16b, v4.16b\n"
            "sadalp v23.4s, v15.8h\n"
            "smlal2 v13.8h, v3.16b, v5.16b\n"
            // Loop bookkeeping: two K-blocks consumed per pass.
            "sub %w[k], %w[k], #2\n"
            "cmp %w[k], #2\n"
            "smull v10.8h, v2.8b, v6.8b\n"
            "ld1 {v4.16b}, [%[b_ptr]], #16\n"
            "smull v11.8h, v2.8b, v7.8b\n"
            "ld1 {v5.16b}, [%[b_ptr]], #16\n"
            "smull v14.8h, v3.8b, v6.8b\n"
            "sadalp v24.4s, v8.8h\n"
            "smull v15.8h, v3.8b, v7.8b\n"
            "sadalp v25.4s, v9.8h\n"
            "smlal2 v10.8h, v2.16b, v6.16b\n"
            "sadalp v28.4s, v12.8h\n"
            "smlal2 v11.8h, v2.16b, v7.16b\n"
            "sadalp v29.4s, v13.8h\n"
            "smlal2 v14.8h, v3.16b, v6.16b\n"
            "sadalp v26.4s, v10.8h\n"
            "smlal2 v15.8h, v3.16b, v7.16b\n"
            "sadalp v27.4s, v11.8h\n"
            "sadalp v30.4s, v14.8h\n"
            "sadalp v31.4s, v15.8h\n"
            "bgt 1b\n"
            "blt 3f\n"
            //! K==2: one full block here, then fall through to the K==1 tail.
            "2:\n"
            "smull v8.8h, v0.8b, v4.8b\n"
            "smull v9.8h, v0.8b, v5.8b\n"
            "ld1 {v6.16b}, [%[b_ptr]], #16\n"
            "smull v12.8h, v1.8b, v4.8b\n"
            "smull v13.8h, v1.8b, v5.8b\n"
            "ld1 {v7.16b}, [%[b_ptr]], #16\n"
            "smlal2 v8.8h, v0.16b, v4.16b\n"
            "smlal2 v9.8h, v0.16b, v5.16b\n"
            "smlal2 v12.8h, v1.16b, v4.16b\n"
            "smlal2 v13.8h, v1.16b, v5.16b\n"
            "smull v10.8h, v0.8b, v6.8b\n"
            "ld1 {v2.16b}, [%[a_ptr]], #16\n"
            "smull v11.8h, v0.8b, v7.8b\n"
            "smull v14.8h, v1.8b, v6.8b\n"
            "ld1 {v3.16b}, [%[a_ptr]], #16\n"
            "smull v15.8h, v1.8b, v7.8b\n"
            "sadalp v16.4s, v8.8h\n"
            "smlal2 v10.8h, v0.16b, v6.16b\n"
            "sadalp v17.4s, v9.8h\n"
            "smlal2 v11.8h, v0.16b, v7.16b\n"
            "sadalp v20.4s, v12.8h\n"
            "smlal2 v14.8h, v1.16b, v6.16b\n"
            "sadalp v21.4s, v13.8h\n"
            "smlal2 v15.8h, v1.16b, v7.16b\n"
            "smull v8.8h, v2.8b, v4.8b\n"
            "smull v9.8h, v2.8b, v5.8b\n"
            "ld1 {v0.16b}, [%[a_ptr]], #16\n"
            "smull v12.8h, v3.8b, v4.8b\n"
            "ld1 {v1.16b}, [%[a_ptr]], #16\n"
            "smull v13.8h, v3.8b, v5.8b\n"
            "sadalp v18.4s, v10.8h\n"
            "smlal2 v8.8h, v2.16b, v4.16b\n"
            "sadalp v19.4s, v11.8h\n"
            "smlal2 v9.8h, v2.16b, v5.16b\n"
            "sadalp v22.4s, v14.8h\n"
            "smlal2 v12.8h, v3.16b, v4.16b\n"
            "sadalp v23.4s, v15.8h\n"
            "smlal2 v13.8h, v3.16b, v5.16b\n"
            "smull v10.8h, v2.8b, v6.8b\n"
            "smull v11.8h, v2.8b, v7.8b\n"
            "ld1 {v4.16b}, [%[b_ptr]], #16\n"
            "smull v14.8h, v3.8b, v6.8b\n"
            "ld1 {v5.16b}, [%[b_ptr]], #16\n"
            "smull v15.8h, v3.8b, v7.8b\n"
            "sadalp v24.4s, v8.8h\n"
            "smlal2 v10.8h, v2.16b, v6.16b\n"
            "sadalp v25.4s, v9.8h\n"
            "smlal2 v11.8h, v2.16b, v7.16b\n"
            "sadalp v28.4s, v12.8h\n"
            "smlal2 v14.8h, v3.16b, v6.16b\n"
            "sadalp v29.4s, v13.8h\n"
            "smlal2 v15.8h, v3.16b, v7.16b\n"
            "sadalp v26.4s, v10.8h\n"
            "sadalp v27.4s, v11.8h\n"
            "sadalp v30.4s, v14.8h\n"
            "sadalp v31.4s, v15.8h\n"
            //! K==1: last block (no further loads of A beyond it).
            "3:\n"
            "smull v8.8h, v0.8b, v4.8b\n"
            "smull v9.8h, v0.8b, v5.8b\n"
            "ld1 {v6.16b}, [%[b_ptr]], #16\n"
            "smull v12.8h, v1.8b, v4.8b\n"
            "smull v13.8h, v1.8b, v5.8b\n"
            "ld1 {v7.16b}, [%[b_ptr]], #16\n"
            "smlal2 v8.8h, v0.16b, v4.16b\n"
            "smlal2 v9.8h, v0.16b, v5.16b\n"
            "smlal2 v12.8h, v1.16b, v4.16b\n"
            "smlal2 v13.8h, v1.16b, v5.16b\n"
            "smull v10.8h, v0.8b, v6.8b\n"
            "ld1 {v2.16b}, [%[a_ptr]], #16\n"
            "smull v11.8h, v0.8b, v7.8b\n"
            "smull v14.8h, v1.8b, v6.8b\n"
            "ld1 {v3.16b}, [%[a_ptr]], #16\n"
            "smull v15.8h, v1.8b, v7.8b\n"
            "sadalp v16.4s, v8.8h\n"
            "smlal2 v10.8h, v0.16b, v6.16b\n"
            "sadalp v17.4s, v9.8h\n"
            "smlal2 v11.8h, v0.16b, v7.16b\n"
            "sadalp v20.4s, v12.8h\n"
            "smlal2 v14.8h, v1.16b, v6.16b\n"
            "sadalp v21.4s, v13.8h\n"
            "smlal2 v15.8h, v1.16b, v7.16b\n"
            "smull v8.8h, v2.8b, v4.8b\n"
            "smull v9.8h, v2.8b, v5.8b\n"
            "smull v12.8h, v3.8b, v4.8b\n"
            "smull v13.8h, v3.8b, v5.8b\n"
            "sadalp v18.4s, v10.8h\n"
            "smlal2 v8.8h, v2.16b, v4.16b\n"
            "sadalp v19.4s, v11.8h\n"
            "smlal2 v9.8h, v2.16b, v5.16b\n"
            "sadalp v22.4s, v14.8h\n"
            "smlal2 v12.8h, v3.16b, v4.16b\n"
            "sadalp v23.4s, v15.8h\n"
            "smlal2 v13.8h, v3.16b, v5.16b\n"
            "smull v10.8h, v2.8b, v6.8b\n"
            "sadalp v24.4s, v8.8h\n"
            "smull v11.8h, v2.8b, v7.8b\n"
            "sadalp v25.4s, v9.8h\n"
            "smull v14.8h, v3.8b, v6.8b\n"
            "sadalp v28.4s, v12.8h\n"
            "smull v15.8h, v3.8b, v7.8b\n"
            "sadalp v29.4s, v13.8h\n"
            "smlal2 v10.8h, v2.16b, v6.16b\n"
            "smlal2 v11.8h, v2.16b, v7.16b\n"
            "sadalp v26.4s, v10.8h\n"
            "smlal2 v14.8h, v3.16b, v6.16b\n"
            "sadalp v27.4s, v11.8h\n"
            "smlal2 v15.8h, v3.16b, v7.16b\n"
            "sadalp v30.4s, v14.8h\n"
            "sadalp v31.4s, v15.8h\n"
            // Reduction: two addp levels turn each accumulator's horizontal
            // sum into one lane of v0-v3 (one register per output row).
            "addp v4.4s, v16.4s, v20.4s\n"
            "addp v5.4s, v24.4s, v28.4s\n"
            "addp v6.4s, v17.4s, v21.4s\n"
            "addp v7.4s, v25.4s, v29.4s\n"
            "addp v8.4s, v18.4s, v22.4s\n"
            "addp v9.4s, v26.4s, v30.4s\n"
            "addp v10.4s, v19.4s, v23.4s\n"
            "addp v11.4s, v27.4s, v31.4s\n"
            "cmp %w[is_first_k], #1\n"
            "addp v0.4s, v4.4s, v5.4s\n"
            "addp v1.4s, v6.4s, v7.4s\n"
            "addp v2.4s, v8.4s, v9.4s\n"
            "addp v3.4s, v10.4s, v11.4s\n"
            "beq 6f\n"
            // Not the first K slice: add the previously stored tile.
            "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[output]]\n"
            "add v0.4s, v0.4s, v8.4s\n"
            "add v1.4s, v1.4s, v9.4s\n"
            "add v2.4s, v2.4s, v10.4s\n"
            "add v3.4s, v3.4s, v11.4s\n"
            "6:\n"
            "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[output]], #64\n"
            : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr),
              [is_first_k] "+r"(is_first_k), [k] "+r"(K), [output] "+r"(output)
            :
            : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
              "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
              "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
              "v29", "v30", "v31", "cc", "memory");
}
//! Tail kernel for the MK4 int8 4x4x16 strategy: computes one 4-row output
//! block when only remain_n (expected 1..3) of the 4 output columns exist.
//! The accumulation is identical to kern_4x4; only the final
//! read-modify-write of C is masked by remain_n.
//!
//! packA/packB: packed int8 operand panels, consumed in-order.
//! K: depth in elements, converted below to a count of 16-deep K blocks
//!    (the packing routines zero-pad the K tail).
//! output: int32 C block in MK4 layout (4 values per column).
//! is_first_k: when true, C is overwritten; otherwise accumulated into.
//! remain_n: number of valid output columns.
//! NOTE(review): remain_n == 4 would fall through both branch chains into
//! the 3-column paths and silently drop the 4th column; callers appear to
//! pass only 1..3 here -- confirm.
static void kern_4x4_remain(const int8_t* packA, const int8_t* packB, int K,
                            int32_t* output, bool is_first_k, size_t remain_n) {
    K = div_ceil(K, 16);
    const int8_t* a_ptr = packA;
    const int8_t* b_ptr = packB;
    asm volatile(
            // zero the 16 int32 accumulators (v16..v31) and preload A/B
            "ld1 {v0.16b}, [%[a_ptr]], #16\n"
            "eor v16.16b, v16.16b, v16.16b\n"
            "eor v17.16b, v17.16b, v17.16b\n"
            "eor v18.16b, v18.16b, v18.16b\n"
            "ld1 {v1.16b}, [%[a_ptr]], #16\n"
            "eor v19.16b, v19.16b, v19.16b\n"
            "eor v20.16b, v19.16b, v19.16b\n"
            "eor v21.16b, v19.16b, v19.16b\n"
            "ld1 {v4.16b, v5.16b}, [%[b_ptr]], #32\n"
            "eor v22.16b, v19.16b, v19.16b\n"
            "PRFM PLDL1KEEP, [%[a_ptr], #32]\n"
            "eor v23.16b, v19.16b, v19.16b\n"
            "eor v24.16b, v19.16b, v19.16b\n"
            "PRFM PLDL1KEEP, [%[b_ptr], #32]\n"
            "eor v25.16b, v19.16b, v19.16b\n"
            "eor v26.16b, v19.16b, v19.16b\n"
            "PRFM PLDL1KEEP, [%[b_ptr], #64]\n"
            "eor v27.16b, v19.16b, v19.16b\n"
            "eor v28.16b, v19.16b, v19.16b\n"
            "PRFM PLDL1KEEP, [%[a_ptr], #64]\n"
            "eor v29.16b, v19.16b, v19.16b\n"
            "eor v30.16b, v19.16b, v19.16b\n"
            "PRFM PLDL1KEEP, [%[b_ptr], #128]\n"
            "eor v31.16b, v19.16b, v19.16b\n"
            //! dispatch on remaining 16-deep K blocks:
            //! k == 2 -> label 2 (last two blocks), k == 1 -> label 3
            "cmp %w[k], #2\n"
            "beq 2f\n"
            "blt 3f\n"
            //! K>2
            "1:\n"
            //! First k
            "smull v8.8h, v0.8b, v4.8b\n"
            "smull v9.8h, v0.8b, v5.8b\n"
            "ld1 {v6.16b}, [%[b_ptr]], #16\n"
            "smull v12.8h, v1.8b, v4.8b\n"
            "smull v13.8h, v1.8b, v5.8b\n"
            "ld1 {v7.16b}, [%[b_ptr]], #16\n"
            "smlal2 v8.8h, v0.16b, v4.16b\n"
            "smlal2 v9.8h, v0.16b, v5.16b\n"
            "smlal2 v12.8h, v1.16b, v4.16b\n"
            "smlal2 v13.8h, v1.16b, v5.16b\n"
            "smull v10.8h, v0.8b, v6.8b\n"
            "ld1 {v2.16b}, [%[a_ptr]], #16\n"
            "smull v11.8h, v0.8b, v7.8b\n"
            "smull v14.8h, v1.8b, v6.8b\n"
            "ld1 {v3.16b}, [%[a_ptr]], #16\n"
            "smull v15.8h, v1.8b, v7.8b\n"
            "sadalp v16.4s, v8.8h\n"
            "smlal2 v10.8h, v0.16b, v6.16b\n"
            "sadalp v17.4s, v9.8h\n"
            "smlal2 v11.8h, v0.16b, v7.16b\n"
            "sadalp v20.4s, v12.8h\n"
            "smlal2 v14.8h, v1.16b, v6.16b\n"
            "sadalp v21.4s, v13.8h\n"
            "smlal2 v15.8h, v1.16b, v7.16b\n"
            "smull v8.8h, v2.8b, v4.8b\n"
            "smull v9.8h, v2.8b, v5.8b\n"
            "ld1 {v0.16b}, [%[a_ptr]], #16\n"
            "smull v12.8h, v3.8b, v4.8b\n"
            "ld1 {v1.16b}, [%[a_ptr]], #16\n"
            "smull v13.8h, v3.8b, v5.8b\n"
            "sadalp v18.4s, v10.8h\n"
            "smlal2 v8.8h, v2.16b, v4.16b\n"
            "sadalp v19.4s, v11.8h\n"
            "smlal2 v9.8h, v2.16b, v5.16b\n"
            "sadalp v22.4s, v14.8h\n"
            "smlal2 v12.8h, v3.16b, v4.16b\n"
            "sadalp v23.4s, v15.8h\n"
            "smlal2 v13.8h, v3.16b, v5.16b\n"
            "smull v10.8h, v2.8b, v6.8b\n"
            "smull v11.8h, v2.8b, v7.8b\n"
            "ld1 {v4.16b}, [%[b_ptr]], #16\n"
            "smull v14.8h, v3.8b, v6.8b\n"
            "ld1 {v5.16b}, [%[b_ptr]], #16\n"
            "smull v15.8h, v3.8b, v7.8b\n"
            "sadalp v24.4s, v8.8h\n"
            "smlal2 v10.8h, v2.16b, v6.16b\n"
            "sadalp v25.4s, v9.8h\n"
            "smlal2 v11.8h, v2.16b, v7.16b\n"
            "sadalp v28.4s, v12.8h\n"
            "smlal2 v14.8h, v3.16b, v6.16b\n"
            "sadalp v29.4s, v13.8h\n"
            "smlal2 v15.8h, v3.16b, v7.16b\n"
            //! Second k
            "smull v8.8h, v0.8b, v4.8b\n"
            "smull v9.8h, v0.8b, v5.8b\n"
            "ld1 {v6.16b}, [%[b_ptr]], #16\n"
            "smull v12.8h, v1.8b, v4.8b\n"
            "smull v13.8h, v1.8b, v5.8b\n"
            "ld1 {v7.16b}, [%[b_ptr]], #16\n"
            "smlal2 v8.8h, v0.16b, v4.16b\n"
            "sadalp v26.4s, v10.8h\n"
            "smlal2 v9.8h, v0.16b, v5.16b\n"
            "sadalp v27.4s, v11.8h\n"
            "smlal2 v12.8h, v1.16b, v4.16b\n"
            "sadalp v30.4s, v14.8h\n"
            "smlal2 v13.8h, v1.16b, v5.16b\n"
            "sadalp v31.4s, v15.8h\n"
            "smull v10.8h, v0.8b, v6.8b\n"
            "ld1 {v2.16b}, [%[a_ptr]], #16\n"
            "smull v11.8h, v0.8b, v7.8b\n"
            "smull v14.8h, v1.8b, v6.8b\n"
            "ld1 {v3.16b}, [%[a_ptr]], #16\n"
            "smull v15.8h, v1.8b, v7.8b\n"
            "sadalp v16.4s, v8.8h\n"
            "smlal2 v10.8h, v0.16b, v6.16b\n"
            "sadalp v17.4s, v9.8h\n"
            "smlal2 v11.8h, v0.16b, v7.16b\n"
            "sadalp v20.4s, v12.8h\n"
            "smlal2 v14.8h, v1.16b, v6.16b\n"
            "sadalp v21.4s, v13.8h\n"
            "smlal2 v15.8h, v1.16b, v7.16b\n"
            "smull v8.8h, v2.8b, v4.8b\n"
            "smull v9.8h, v2.8b, v5.8b\n"
            "ld1 {v0.16b}, [%[a_ptr]], #16\n"
            "smull v12.8h, v3.8b, v4.8b\n"
            "ld1 {v1.16b}, [%[a_ptr]], #16\n"
            "smull v13.8h, v3.8b, v5.8b\n"
            "sadalp v18.4s, v10.8h\n"
            "smlal2 v8.8h, v2.16b, v4.16b\n"
            "sadalp v19.4s, v11.8h\n"
            "smlal2 v9.8h, v2.16b, v5.16b\n"
            "sadalp v22.4s, v14.8h\n"
            "smlal2 v12.8h, v3.16b, v4.16b\n"
            "sadalp v23.4s, v15.8h\n"
            "smlal2 v13.8h, v3.16b, v5.16b\n"
            "sub %w[k], %w[k], #2\n"
            "cmp %w[k], #2\n"
            "smull v10.8h, v2.8b, v6.8b\n"
            "ld1 {v4.16b}, [%[b_ptr]], #16\n"
            "smull v11.8h, v2.8b, v7.8b\n"
            "ld1 {v5.16b}, [%[b_ptr]], #16\n"
            "smull v14.8h, v3.8b, v6.8b\n"
            "sadalp v24.4s, v8.8h\n"
            "smull v15.8h, v3.8b, v7.8b\n"
            "sadalp v25.4s, v9.8h\n"
            "smlal2 v10.8h, v2.16b, v6.16b\n"
            "sadalp v28.4s, v12.8h\n"
            "smlal2 v11.8h, v2.16b, v7.16b\n"
            "sadalp v29.4s, v13.8h\n"
            "smlal2 v14.8h, v3.16b, v6.16b\n"
            "sadalp v26.4s, v10.8h\n"
            "smlal2 v15.8h, v3.16b, v7.16b\n"
            "sadalp v27.4s, v11.8h\n"
            "sadalp v30.4s, v14.8h\n"
            "sadalp v31.4s, v15.8h\n"
            "bgt 1b\n"
            "blt 3f\n"
            //! K==2
            "2:\n"
            "smull v8.8h, v0.8b, v4.8b\n"
            "smull v9.8h, v0.8b, v5.8b\n"
            "ld1 {v6.16b}, [%[b_ptr]], #16\n"
            "smull v12.8h, v1.8b, v4.8b\n"
            "smull v13.8h, v1.8b, v5.8b\n"
            "ld1 {v7.16b}, [%[b_ptr]], #16\n"
            "smlal2 v8.8h, v0.16b, v4.16b\n"
            "smlal2 v9.8h, v0.16b, v5.16b\n"
            "smlal2 v12.8h, v1.16b, v4.16b\n"
            "smlal2 v13.8h, v1.16b, v5.16b\n"
            "smull v10.8h, v0.8b, v6.8b\n"
            "ld1 {v2.16b}, [%[a_ptr]], #16\n"
            "smull v11.8h, v0.8b, v7.8b\n"
            "smull v14.8h, v1.8b, v6.8b\n"
            "ld1 {v3.16b}, [%[a_ptr]], #16\n"
            "smull v15.8h, v1.8b, v7.8b\n"
            "sadalp v16.4s, v8.8h\n"
            "smlal2 v10.8h, v0.16b, v6.16b\n"
            "sadalp v17.4s, v9.8h\n"
            "smlal2 v11.8h, v0.16b, v7.16b\n"
            "sadalp v20.4s, v12.8h\n"
            "smlal2 v14.8h, v1.16b, v6.16b\n"
            "sadalp v21.4s, v13.8h\n"
            "smlal2 v15.8h, v1.16b, v7.16b\n"
            "smull v8.8h, v2.8b, v4.8b\n"
            "smull v9.8h, v2.8b, v5.8b\n"
            "ld1 {v0.16b}, [%[a_ptr]], #16\n"
            "smull v12.8h, v3.8b, v4.8b\n"
            "ld1 {v1.16b}, [%[a_ptr]], #16\n"
            "smull v13.8h, v3.8b, v5.8b\n"
            "sadalp v18.4s, v10.8h\n"
            "smlal2 v8.8h, v2.16b, v4.16b\n"
            "sadalp v19.4s, v11.8h\n"
            "smlal2 v9.8h, v2.16b, v5.16b\n"
            "sadalp v22.4s, v14.8h\n"
            "smlal2 v12.8h, v3.16b, v4.16b\n"
            "sadalp v23.4s, v15.8h\n"
            "smlal2 v13.8h, v3.16b, v5.16b\n"
            "smull v10.8h, v2.8b, v6.8b\n"
            "smull v11.8h, v2.8b, v7.8b\n"
            "ld1 {v4.16b}, [%[b_ptr]], #16\n"
            "smull v14.8h, v3.8b, v6.8b\n"
            "ld1 {v5.16b}, [%[b_ptr]], #16\n"
            "smull v15.8h, v3.8b, v7.8b\n"
            "sadalp v24.4s, v8.8h\n"
            "smlal2 v10.8h, v2.16b, v6.16b\n"
            "sadalp v25.4s, v9.8h\n"
            "smlal2 v11.8h, v2.16b, v7.16b\n"
            "sadalp v28.4s, v12.8h\n"
            "smlal2 v14.8h, v3.16b, v6.16b\n"
            "sadalp v29.4s, v13.8h\n"
            "smlal2 v15.8h, v3.16b, v7.16b\n"
            "sadalp v26.4s, v10.8h\n"
            "sadalp v27.4s, v11.8h\n"
            "sadalp v30.4s, v14.8h\n"
            "sadalp v31.4s, v15.8h\n"
            //! K==1
            "3:\n"
            "smull v8.8h, v0.8b, v4.8b\n"
            "smull v9.8h, v0.8b, v5.8b\n"
            "ld1 {v6.16b}, [%[b_ptr]], #16\n"
            "smull v12.8h, v1.8b, v4.8b\n"
            "smull v13.8h, v1.8b, v5.8b\n"
            "ld1 {v7.16b}, [%[b_ptr]], #16\n"
            "smlal2 v8.8h, v0.16b, v4.16b\n"
            "smlal2 v9.8h, v0.16b, v5.16b\n"
            "smlal2 v12.8h, v1.16b, v4.16b\n"
            "smlal2 v13.8h, v1.16b, v5.16b\n"
            "smull v10.8h, v0.8b, v6.8b\n"
            "ld1 {v2.16b}, [%[a_ptr]], #16\n"
            "smull v11.8h, v0.8b, v7.8b\n"
            "smull v14.8h, v1.8b, v6.8b\n"
            "ld1 {v3.16b}, [%[a_ptr]], #16\n"
            "smull v15.8h, v1.8b, v7.8b\n"
            "sadalp v16.4s, v8.8h\n"
            "smlal2 v10.8h, v0.16b, v6.16b\n"
            "sadalp v17.4s, v9.8h\n"
            "smlal2 v11.8h, v0.16b, v7.16b\n"
            "sadalp v20.4s, v12.8h\n"
            "smlal2 v14.8h, v1.16b, v6.16b\n"
            "sadalp v21.4s, v13.8h\n"
            "smlal2 v15.8h, v1.16b, v7.16b\n"
            "smull v8.8h, v2.8b, v4.8b\n"
            "smull v9.8h, v2.8b, v5.8b\n"
            "smull v12.8h, v3.8b, v4.8b\n"
            "smull v13.8h, v3.8b, v5.8b\n"
            "sadalp v18.4s, v10.8h\n"
            "smlal2 v8.8h, v2.16b, v4.16b\n"
            "sadalp v19.4s, v11.8h\n"
            "smlal2 v9.8h, v2.16b, v5.16b\n"
            "sadalp v22.4s, v14.8h\n"
            "smlal2 v12.8h, v3.16b, v4.16b\n"
            "sadalp v23.4s, v15.8h\n"
            "smlal2 v13.8h, v3.16b, v5.16b\n"
            "smull v10.8h, v2.8b, v6.8b\n"
            "sadalp v24.4s, v8.8h\n"
            "smull v11.8h, v2.8b, v7.8b\n"
            "sadalp v25.4s, v9.8h\n"
            "smull v14.8h, v3.8b, v6.8b\n"
            "sadalp v28.4s, v12.8h\n"
            "smull v15.8h, v3.8b, v7.8b\n"
            "sadalp v29.4s, v13.8h\n"
            "smlal2 v10.8h, v2.16b, v6.16b\n"
            "smlal2 v11.8h, v2.16b, v7.16b\n"
            "sadalp v26.4s, v10.8h\n"
            "smlal2 v14.8h, v3.16b, v6.16b\n"
            "sadalp v27.4s, v11.8h\n"
            "smlal2 v15.8h, v3.16b, v7.16b\n"
            "sadalp v30.4s, v14.8h\n"
            "sadalp v31.4s, v15.8h\n"
            //! pairwise-add the 16 accumulators down to the 4x4 int32 result
            "addp v4.4s, v16.4s, v20.4s\n"
            "addp v5.4s, v24.4s, v28.4s\n"
            "addp v6.4s, v17.4s, v21.4s\n"
            "addp v7.4s, v25.4s, v29.4s\n"
            "addp v8.4s, v18.4s, v22.4s\n"
            "addp v9.4s, v26.4s, v30.4s\n"
            "addp v10.4s, v19.4s, v23.4s\n"
            "addp v11.4s, v27.4s, v31.4s\n"
            "addp v0.4s, v4.4s, v5.4s\n"
            "addp v1.4s, v6.4s, v7.4s\n"
            "addp v2.4s, v8.4s, v9.4s\n"
            "addp v3.4s, v10.4s, v11.4s\n"
            //! first K slice: nothing to accumulate from C, skip the loads
            "cmp %w[is_first_k], #1\n"
            "beq 6f\n"
            //! otherwise add only the remain_n existing columns of C
            "cmp %w[remain_n], #3\n"
            "beq 1003f\n"
            "cmp %w[remain_n], #2\n"
            "beq 1002f\n"
            "cmp %w[remain_n], #1\n"
            "beq 1001f\n"
            "1003:\n"
            "ld1 {v8.4s, v9.4s, v10.4s}, [%[output]]\n"
            "add v0.4s, v0.4s, v8.4s\n"
            "add v1.4s, v1.4s, v9.4s\n"
            "add v2.4s, v2.4s, v10.4s\n"
            "b 6f\n"
            "1002:\n"
            "ld1 {v8.4s, v9.4s}, [%[output]]\n"
            "add v0.4s, v0.4s, v8.4s\n"
            "add v1.4s, v1.4s, v9.4s\n"
            "b 6f\n"
            "1001:\n"
            "ld1 {v8.4s}, [%[output]]\n"
            "add v0.4s, v0.4s, v8.4s\n"
            "6:\n"
            //! store only the remain_n valid columns (labels fall through:
            //! 3 columns stores q2 then q1 then q0, etc.)
            "cmp %w[remain_n], #3\n"
            "beq 10003f\n"
            "cmp %w[remain_n], #2\n"
            "beq 10002f\n"
            "cmp %w[remain_n], #1\n"
            "beq 10001f\n"
            "10003:\n"
            "str q2, [%[output], #32]\n"
            "10002:\n"
            "str q1, [%[output], #16]\n"
            "10001:\n"
            "str q0, [%[output]]\n"
            "7:\n"
            : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr),
              [remain_n] "+r"(remain_n), [is_first_k] "+r"(is_first_k),
              [k] "+r"(K), [output] "+r"(output)
            :
            : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
              "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
              "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
              "v29", "v30", "v31", "cc", "memory");
}
//! Pack matrix A for the MK4 int8 4x4x16 kernel.
//! Rows are processed 16 at a time (4 MK4 row-groups) and the K tail is
//! zero-padded up to a full 16-deep block via a staging buffer.
//! NOTE(review): transpose_interleave_4x4_4_b / transpose_interleave_1x4_4_b
//! appear to advance the input pointers internally -- the K loop never
//! advances them explicitly.  Confirm against their definitions.
static void gemm_mk4_s8_4x4_pack_A(dt_int8* outptr, const dt_int8* inptr,
                                   int ldin, int y0, int ymax, int k0,
                                   int kmax) {
    //! pack form {oc/4, ic/4, 4(ic), 4(oc)} to {oc/4, ic/16, 4(oc), 16(ic)}
    //! zero-initialised staging buffer; the tail memcpy below only fills
    //! the first K*4 bytes, the rest stays zero as padding.
    int8_t zerobuff[4][64];
    std::memset(zerobuff, 0, sizeof(int8_t) * 64 * 4);
    megdnn_assert(ymax % 4 == 0 && y0 % 4 == 0 && (ymax - y0) % 4 == 0,
                  "mk4 matmul with m is not times of 4");
    megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0 && (kmax - k0) % 4 == 0,
                  "mk4 matmul with k is not times of 4");
    //! bytes of one packed MK4 row-group: K rounded to 16, times 4 rows
    size_t roundk = round_up(kmax - k0, 16);
    size_t out_offset = roundk * 4;
    int y = y0;
    int start_y = y0 / 4;
    //! main loop: 16 rows (four MK4 groups) per iteration
    for (; y + 15 < ymax; y += 16, start_y += 4) {
        const int8_t* inptr0 = inptr + start_y * ldin + k0 * 4;
        const int8_t* inptr1 = inptr0 + ldin;
        const int8_t* inptr2 = inptr1 + ldin;
        const int8_t* inptr3 = inptr2 + ldin;
        int8_t* output = outptr + start_y * out_offset;
        prefetch_2x(inptr0);
        prefetch_2x(inptr1);
        prefetch_2x(inptr2);
        prefetch_2x(inptr3);
        int K = kmax - k0;
        for (; K > 15; K -= 16) {
            transpose_interleave_4x4_4_b(inptr0, inptr1, inptr2, inptr3, output,
                                         out_offset);
            output += 64;
        }
        //! K tail (< 16 elements): stage through the zero-padded buffer
        if (K > 0) {
            std::memcpy(zerobuff[0], inptr0, sizeof(int8_t) * K * 4);
            std::memcpy(zerobuff[1], inptr1, sizeof(int8_t) * K * 4);
            std::memcpy(zerobuff[2], inptr2, sizeof(int8_t) * K * 4);
            std::memcpy(zerobuff[3], inptr3, sizeof(int8_t) * K * 4);
            inptr0 = zerobuff[0];
            inptr1 = zerobuff[1];
            inptr2 = zerobuff[2];
            inptr3 = zerobuff[3];
            transpose_interleave_4x4_4_b(inptr0, inptr1, inptr2, inptr3, output,
                                         out_offset);
            output += 64;
        }
    }
    //! remaining rows: one MK4 group (4 rows) per iteration
    for (; y + 3 < ymax; y += 4, start_y++) {
        const int8_t* inptr0 = inptr + start_y * ldin + k0 * 4;
        int8_t* output = outptr + start_y * out_offset;
        prefetch_2x(inptr0);
        int K = kmax - k0;
        for (; K > 15; K -= 16) {
            transpose_interleave_1x4_4_b(inptr0, output);
            output += 64;
        }
        if (K > 0) {
            std::memcpy(zerobuff[0], inptr0, sizeof(int8_t) * K * 4);
            inptr0 = zerobuff[0];
            transpose_interleave_1x4_4_b(inptr0, output);
            output += 64;
        }
    }
}
//! Pack matrix B for the MK4 int8 4x4x16 kernel.
//! The K dimension is walked in groups of four 4-deep IC blocks; missing
//! trailing blocks are substituted with a small zero buffer.
//! NOTE(review): transpose_4x4_1_s presumably advances the input pointers
//! it is given; inptr1..3 are re-pointed at zerobuff on every x iteration
//! of the remainder block, which keeps them inside the 4-element buffer --
//! confirm against transpose_4x4_1_s's definition.
static void gemm_mk4_s8_4x4_pack_B(dt_int8* out, const dt_int8* in, int ldin,
                                   int x0, int xmax, int k0, int kmax) {
    //! one packed 4x int32 column group of zeros, used to pad missing IC
    //! blocks at the K tail
    int32_t zerobuff[4];
    std::memset(zerobuff, 0, sizeof(int8_t) * 16);
    const int ksize = kmax - k0;
    //! number of 4-deep IC blocks in the K range
    const int ICB = (ksize) / 4;
    //! output stride: ICB rounded to a multiple of 4 blocks, 4 int32 each
    const int ksize4 = round_up<int>(ICB, 4) * 4;
    int32_t* outptr = reinterpret_cast<int32_t*>(out);
    megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0 && ksize % 4 == 0,
                  "mk4 matmul with k is not times of 4");
    int k = k0 / 4;
    //! main loop: four full IC blocks per iteration
    for (; k + 3 < ICB; k += 4) {
        const int32_t* inptr0 =
                reinterpret_cast<const int32_t*>(in + k * ldin + x0);
        const int32_t* inptr1 =
                reinterpret_cast<const int32_t*>(in + (k + 1) * ldin + x0);
        const int32_t* inptr2 =
                reinterpret_cast<const int32_t*>(in + (k + 2) * ldin + x0);
        const int32_t* inptr3 =
                reinterpret_cast<const int32_t*>(in + (k + 3) * ldin + x0);
        int32_t* outptr_inner = outptr;
        int x = x0;
        for (; x + 3 < xmax; x += 4) {
            transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr_inner);
            outptr_inner += ksize4;
        }
        //! x tail (< 4 columns): interleave scalars column by column
        if (x < xmax) {
            for (; x < xmax; x++) {
                *outptr_inner++ = *inptr0++;
                *outptr_inner++ = *inptr1++;
                *outptr_inner++ = *inptr2++;
                *outptr_inner++ = *inptr3++;
            }
        }
        outptr += 4 * 4;
    }
    //! K tail: 1..3 IC blocks remain; missing rows read from zerobuff
    if (k < ICB) {
        const int32_t* inptr0 =
                reinterpret_cast<const int32_t*>(in + k * ldin + x0);
        const int32_t* inptr1 =
                reinterpret_cast<const int32_t*>(in + (k + 1) * ldin + x0);
        const int32_t* inptr2 =
                reinterpret_cast<const int32_t*>(in + (k + 2) * ldin + x0);
        const int32_t* inptr3 =
                reinterpret_cast<const int32_t*>(in + (k + 3) * ldin + x0);
        int32_t* outptr_inner = outptr;
        int x = x0;
        for (; x + 3 < xmax; x += 4) {
            if (k + 3 >= ICB) {
                //! intentional fall-through: zero every pointer past ICB
                switch (k + 3 - ICB) {
                    case 2:
                        inptr1 = zerobuff;
                    case 1:
                        inptr2 = zerobuff;
                    case 0:
                        inptr3 = zerobuff;
                        break;
                    default:
                        megdnn_assert(0);
                }
            }
            transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr_inner);
            outptr_inner += ksize4;
        }
        if (x < xmax) {
            if (k + 3 >= ICB) {
                //! intentional fall-through, as above
                switch (k + 3 - ICB) {
                    case 2:
                        inptr1 = zerobuff;
                    case 1:
                        inptr2 = zerobuff;
                    case 0:
                        inptr3 = zerobuff;
                        break;
                    default:
                        megdnn_assert(0);
                }
            }
            for (; x < xmax; x++) {
                *outptr_inner++ = *inptr0++;
                *outptr_inner++ = *inptr1++;
                *outptr_inner++ = *inptr2++;
                *outptr_inner++ = *inptr3++;
            }
        }
        outptr += 4 * 4;
    }
}
} // namespace matmul_4x4x16 | |||||
} // namespace aarch64 | |||||
} // namespace megdnn | |||||
#endif | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,263 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/matrix_mul/int8/strategy.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#if !(__ARM_FEATURE_DOTPROD) | |||||
#include "src/aarch64/matrix_mul/int8/strategy.h" | |||||
#include "src/aarch64/matrix_mul/asm/common.h" | |||||
#include "src/aarch64/matrix_mul/int8/kernel_4x4x16.h" | |||||
#include "src/aarch64/matrix_mul/int8/kernel_8x8x8.h" | |||||
#include "src/aarch64/matrix_mul/int8/kernel_mk4_4x4x16.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/common/utils.h" | |||||
using namespace megdnn; | |||||
using namespace aarch64; | |||||
using namespace aarch64::matmul; | |||||
///////////////////////// gemm_s8_4x4 //////////////////////////////////// | |||||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_4x4); | |||||
//! Stage the A-side operand into the layout expected by the 4x4x16
//! kernels.  A transposed panel swaps the roles of rows and columns, so
//! the B-side packing routine produces the correct layout for it.
void gemm_s8_4x4::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin,
                         int y0, int ymax, int k0, int kmax,
                         bool transpose) const {
    if (!transpose) {
        matmul_4x4x16::gemm_s8_4x4_pack_A_n(outptr, inptr, ldin, y0, ymax,
                                            k0, kmax);
    } else {
        matmul_4x4x16::gemm_s8_4x4_pack_B_n(outptr, inptr, ldin, y0, ymax,
                                            k0, kmax);
    }
}
//! Stage the B-side operand; a transposed B panel is packed with the
//! A-side routine (mirror of pack_A above).
void gemm_s8_4x4::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0,
                         int xmax, int k0, int kmax, bool transpose) const {
    if (!transpose) {
        matmul_4x4x16::gemm_s8_4x4_pack_B_n(out, in, ldin, x0, xmax, k0, kmax);
    } else {
        matmul_4x4x16::gemm_s8_4x4_pack_A_n(out, in, ldin, x0, xmax, k0, kmax);
    }
}
//! Run the packed int8 4x4x16 GEMM: full 4x4 tiles go through kern_4x4,
//! edge tiles (fewer than 4 rows and/or columns) through kern_4x4_remain.
void gemm_s8_4x4::kern(const dt_int8* packA, const dt_int8* packB, size_t M,
                       size_t N, size_t K, dt_int32* C, size_t LDC,
                       bool is_first_k, const dt_int32*, dt_int32*) const {
    //! only int8 x int8 -> int32 (plain or quantized) is supported
    megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
                          ((A_dtype.enumv() == DTypeEnum::Int8 &&
                            C_dtype.enumv() == DTypeEnum::Int32) ||
                           (A_dtype.enumv() == DTypeEnum::QuantizedS8 &&
                            C_dtype.enumv() == DTypeEnum::QuantizedS32)),
                  "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
                  C_dtype.name());
    MEGDNN_MARK_USED_VAR(A_dtype);
    MEGDNN_MARK_USED_VAR(B_dtype);
    MEGDNN_MARK_USED_VAR(C_dtype);
    constexpr size_t A_INTERLEAVE = 4;
    constexpr size_t B_INTERLEAVE = 4;
    //! K was padded to a multiple of 16 by the packing routines
    //! (the strategy's K-block depth), not 4 as previously stated.
    K = round_up<size_t>(K, 16);
    //! bytes of one packed 4-row (or 4-column) panel
    const int K4 = K * 4;
    size_t m = 0;
    for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) {
        int32_t* output = C + (m * LDC);
        size_t n = 0;
        const dt_int8* cur_packB = packB;
        for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
            matmul_4x4x16::kern_4x4(packA, cur_packB, K, output, LDC,
                                    is_first_k);
            output += B_INTERLEAVE;
            cur_packB += K4;
        }
        //! right edge: full 4 rows, 1..3 columns
        for (; n < N; n += B_INTERLEAVE) {
            matmul_4x4x16::kern_4x4_remain(packA, cur_packB, K, output, LDC,
                                           is_first_k, 4,
                                           std::min<size_t>(N - n, 4));
            output += B_INTERLEAVE;
            cur_packB += K4;
        }
        packA += K4;
    }
    //! bottom edge: 1..3 rows, any number of columns
    for (; m < M; m += 4) {
        int32_t* output = C + (m * LDC);
        size_t n = 0;
        const dt_int8* cur_packB = packB;
        for (; n < N; n += B_INTERLEAVE) {
            matmul_4x4x16::kern_4x4_remain(
                    packA, cur_packB, K, output, LDC, is_first_k,
                    std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4));
            output += B_INTERLEAVE;
            cur_packB += K4;
        }
        packA += K4;
    }
}
///////////////////////// gemm_mk4_s8_4x4 //////////////////////////////////// | |||||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_mk4_s8_4x4); | |||||
//! Pack the A panel for the MK4 int8 4x4x16 strategy.
//! The MK4 layout fixes the storage order of A, so a transposed panel
//! cannot be expressed and is rejected up front.
//! Fix: the assert diagnostic was ungrammatical ("is not support").
void gemm_mk4_s8_4x4::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin,
                             int y0, int ymax, int k0, int kmax,
                             bool transpose) const {
    megdnn_assert(!transpose,
                  "the gemm_mk4_s8_4x4 strategy does not support transposed A");
    matmul_mk4_4x4x16::gemm_mk4_s8_4x4_pack_A(outptr, inptr, ldin, y0, ymax, k0,
                                              kmax);
}
//! Pack the B panel for the MK4 int8 4x4x16 strategy.  As with pack_A,
//! the MK4 layout does not admit a transposed operand.
//! Fix: the assert diagnostic was ungrammatical ("is not support").
void gemm_mk4_s8_4x4::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0,
                             int xmax, int k0, int kmax, bool transpose) const {
    megdnn_assert(!transpose,
                  "the gemm_mk4_s8_4x4 strategy does not support transposed B");
    matmul_mk4_4x4x16::gemm_mk4_s8_4x4_pack_B(out, in, ldin, x0, xmax, k0,
                                              kmax);
}
//! Run the packed MK4 int8 4x4x16 GEMM.  Full 4-column tiles use
//! kern_4x4; the right edge (1..3 columns) uses kern_4x4_remain.
//! NOTE(review): there is no tail loop over M -- M is assumed to be a
//! multiple of 4, which the pack_A routine asserts; confirm all callers
//! go through that path.
void gemm_mk4_s8_4x4::kern(const dt_int8* packA, const dt_int8* packB, size_t M,
                           size_t N, size_t K, dt_int32* C, size_t LDC,
                           bool is_first_k, const dt_int32*, dt_int32*) const {
    //! only int8 x int8 -> int32 (plain or quantized) is supported
    megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
                          ((A_dtype.enumv() == DTypeEnum::Int8 &&
                            C_dtype.enumv() == DTypeEnum::Int32) ||
                           (A_dtype.enumv() == DTypeEnum::QuantizedS8 &&
                            C_dtype.enumv() == DTypeEnum::QuantizedS32)),
                  "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
                  C_dtype.name());
    MEGDNN_MARK_USED_VAR(A_dtype);
    MEGDNN_MARK_USED_VAR(B_dtype);
    MEGDNN_MARK_USED_VAR(C_dtype);
    constexpr size_t A_INTERLEAVE = 4;
    constexpr size_t B_INTERLEAVE = 4;
    //! the MK4 layout requires K to be a multiple of 4
    megdnn_assert(K % 4 == 0, "K is not time of 4");
    //! bytes of one packed panel: K padded to a multiple of 16, 4 wide
    const size_t K4 = round_up<size_t>(K, 16) * 4;
    size_t m = 0;
    for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) {
        //! MK4: one C "row" holds 4 interleaved output channels
        int32_t* output = C + (m / 4 * LDC);
        size_t n = 0;
        const dt_int8* cur_packB = packB;
        for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
            matmul_mk4_4x4x16::kern_4x4(packA, cur_packB, K, output,
                                        is_first_k);
            output += B_INTERLEAVE * 4;
            cur_packB += K4;
        }
        //! right edge: N - n is 1..3 columns here
        if (n < N) {
            matmul_mk4_4x4x16::kern_4x4_remain(packA, cur_packB, K, output,
                                               is_first_k, N - n);
        }
        packA += K4;
    }
}
///////////////////////// gemm_s8_8x8 //////////////////////////////////// | |||||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_8x8); | |||||
//! Pack the A panel for the 8x8x8 int8 kernels; a transposed panel goes
//! through the dedicated transpose packing routine.
void gemm_s8_8x8::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin,
                         int y0, int ymax, int k0, int kmax,
                         bool transpose) const {
    if (!transpose) {
        matmul_8x8x8::gemm_s8_8x8_pack_A_n(outptr, inptr, ldin, y0, ymax, k0,
                                           kmax);
    } else {
        matmul_8x8x8::gemm_s8_8x8_transpose_pack_A_n(outptr, inptr, ldin, y0,
                                                     ymax, k0, kmax);
    }
}
//! Pack the B panel for the 8x8x8 int8 kernels (mirror of pack_A).
void gemm_s8_8x8::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0,
                         int xmax, int k0, int kmax, bool transpose) const {
    if (!transpose) {
        matmul_8x8x8::gemm_s8_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax);
    } else {
        matmul_8x8x8::gemm_s8_8x8_transpose_pack_B_n(out, in, ldin, x0, xmax,
                                                     k0, kmax);
    }
}
//! Run the packed int8 8x8x8 GEMM.  Full 8x8 tiles use kern_8x8; the
//! right edge uses 8x4 kernels, the bottom edge 4x8 / 4x4 kernels.
void gemm_s8_8x8::kern(const dt_int8* packA, const dt_int8* packB, size_t M,
                       size_t N, size_t K, dt_int32* C, size_t LDC,
                       bool is_first_k, const dt_int32*, dt_int32*) const {
    //! only int8 x int8 -> int32 (plain or quantized) is supported
    megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
                          ((A_dtype.enumv() == DTypeEnum::Int8 &&
                            C_dtype.enumv() == DTypeEnum::Int32) ||
                           (A_dtype.enumv() == DTypeEnum::QuantizedS8 &&
                            C_dtype.enumv() == DTypeEnum::QuantizedS32)),
                  "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
                  C_dtype.name());
    MEGDNN_MARK_USED_VAR(A_dtype);
    MEGDNN_MARK_USED_VAR(B_dtype);
    MEGDNN_MARK_USED_VAR(C_dtype);
    constexpr size_t A_INTERLEAVE = 8;
    constexpr size_t B_INTERLEAVE = 8;
    //! K was padded to a multiple of 8 by the packing routines
    //! (this strategy's K-block depth), not 4 as previously stated.
    K = round_up<size_t>(K, 8);
    //! packed bytes of one 8-wide and one 4-wide panel respectively
    const int K8 = K * 8;
    const int K4 = K * 4;
    size_t m = 0;
    for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) {
        int32_t* output = C + (m * LDC);
        size_t n = 0;
        const dt_int8* cur_packB = packB;
        for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
            matmul_8x8x8::kern_8x8(packA, cur_packB, K, output, LDC,
                                   is_first_k);
            output += B_INTERLEAVE;
            cur_packB += K8;
        }
        //! right edge: 8 rows, 1..4 columns per step
        for (; n < N; n += 4) {
            matmul_8x8x8::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k,
                                   std::min<size_t>(N - n, 4));
            output += 4;
            cur_packB += K4;
        }
        packA += K8;
    }
    //! bottom edge: 1..4 rows per step
    for (; m < M; m += 4) {
        int32_t* output = C + (m * LDC);
        const dt_int8* cur_packB = packB;
        size_t n = 0;
        for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
            matmul_8x8x8::kern_4x8(packA, cur_packB, K, output, LDC, is_first_k,
                                   std::min<size_t>(M - m, 4));
            output += B_INTERLEAVE;
            cur_packB += K8;
        }
        for (; n < N; n += 4) {
            matmul_8x8x8::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k,
                                   std::min<size_t>(M - m, 4),
                                   std::min<size_t>(N - n, 4));
            output += 4;
            cur_packB += K4;
        }
        packA += K4;
    }
}
#endif | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,34 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/matrix_mul/int8/strategy.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#pragma once | |||||
#if !(__ARM_FEATURE_DOTPROD) | |||||
#include "src/fallback/matrix_mul/gemm_common.h" | |||||
namespace megdnn { | |||||
namespace aarch64 { | |||||
namespace matmul { | |||||
MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 4, 4, 16, false, true, | |||||
gemm_s8_4x4); | |||||
MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 4, 4, 16, false, false, | |||||
gemm_mk4_s8_4x4); | |||||
MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 8, 8, 8, false, true, | |||||
gemm_s8_8x8); | |||||
} // namespace matmul | |||||
} // namespace aarch64 | |||||
} // namespace megdnn | |||||
#endif | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,116 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/matrix_mul/int8_dot/gemv.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "src/aarch64/matrix_mul/int8_dot/gemv.h" | |||||
#include <cstddef> | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/common/utils.h" | |||||
#include "src/common/unroll_macro.h" | |||||
#if __ARM_FEATURE_DOTPROD | |||||
namespace { | |||||
void gemv_naive_n(const int8_t* __restrict A, const int8_t* __restrict B, | |||||
int32_t* __restrict C, size_t M, size_t N, size_t K, | |||||
size_t Astride, size_t Bstride, size_t Cstride) { | |||||
megdnn_assert(N == 1 && Bstride == 1); | |||||
size_t m = 0; | |||||
for (; m + 2 <= M; m += 2) { | |||||
int32_t acc[4]; | |||||
int32x4_t acc_neon = vdupq_n_s32(0); | |||||
size_t k = 0; | |||||
for (; k + 16 <= K; k += 16) { | |||||
int64x2_t a0 = vreinterpretq_s64_s8(vld1q_s8(A + m * Astride + k)); | |||||
int64x2_t a1 = | |||||
vreinterpretq_s64_s8(vld1q_s8(A + (m + 1) * Astride + k)); | |||||
//! the first 8 elements is m, the last 8 elements is m + 1 | |||||
int8x16_t a2 = vreinterpretq_s8_s64(vzip1q_s64(a0, a1)); | |||||
int8x16_t a3 = vreinterpretq_s8_s64(vzip2q_s64(a0, a1)); | |||||
int64x2_t b0 = vreinterpretq_s64_s8(vld1q_s8(B + k)); | |||||
int8x16_t b2 = vreinterpretq_s8_s64(vzip1q_s64(b0, b0)); | |||||
int8x16_t b3 = vreinterpretq_s8_s64(vzip2q_s64(b0, b0)); | |||||
acc_neon = vdotq_s32(acc_neon, a2, b2); | |||||
acc_neon = vdotq_s32(acc_neon, a3, b3); | |||||
} | |||||
vst1q_s32(acc, acc_neon); | |||||
for (; k + 8 <= K; k += 8) { | |||||
int8x8_t a0 = vld1_s8(A + m * Astride + k); | |||||
int8x8_t a1 = vld1_s8(A + (m + 1) * Astride + k); | |||||
int8x8_t b0 = vld1_s8(B + k); | |||||
uint32x2_t zero = vdup_n_s32(0); | |||||
acc[0] += vaddv_s32(vdot_s32(zero, a0, b0)); | |||||
zero = vdup_n_s32(0); | |||||
acc[3] += vaddv_s32(vdot_s32(zero, a1, b0)); | |||||
} | |||||
for (; k < K; ++k) { | |||||
acc[0] += static_cast<int32_t>(A[m * Astride + k]) * B[k]; | |||||
acc[3] += static_cast<int32_t>(A[(m + 1) * Astride + k]) * B[k]; | |||||
} | |||||
C[m * Cstride] = acc[0] + acc[1]; | |||||
C[(m + 1) * Cstride] = acc[2] + acc[3]; | |||||
} | |||||
for (; m < M; ++m) { | |||||
int32_t acc[4]; | |||||
int32x4_t acc_neon = vdupq_n_s32(0); | |||||
size_t k = 0; | |||||
for (; k + 16 <= K; k += 16) { | |||||
int8x16_t a0 = vld1q_s8(A + m * Astride + k); | |||||
int8x16_t b0 = vld1q_s8(B + k); | |||||
acc_neon = vdotq_s32(acc_neon, a0, b0); | |||||
} | |||||
vst1q_s32(acc, acc_neon); | |||||
for (; k + 8 <= K; k += 8) { | |||||
int8x8_t a0 = vld1_s8(A + m * Astride + k); | |||||
int8x8_t b0 = vld1_s8(B + k); | |||||
uint32x2_t zero = vdup_n_s32(0); | |||||
acc[0] += vaddv_s32(vdot_s32(zero, a0, b0)); | |||||
} | |||||
for (; k < K; ++k) { | |||||
acc[0] += static_cast<int32_t>(A[m * Astride + k]) * B[k]; | |||||
} | |||||
C[m * Cstride] = acc[0] + acc[1] + acc[2] + acc[3]; | |||||
} | |||||
} | |||||
} // namespace | |||||
bool megdnn::aarch64::matmul::is_gemv_like_preferred_int8( | |||||
bool transposeA, bool transposeB, size_t M, size_t N, size_t K, | |||||
size_t /* LDA */, size_t LDB, size_t /* LDC */) { | |||||
if (transposeA) | |||||
return false; | |||||
if (transposeB) | |||||
return false; | |||||
MEGDNN_MARK_USED_VAR(K); | |||||
MEGDNN_MARK_USED_VAR(M); | |||||
return (N == 1 && LDB == 1); | |||||
} | |||||
void megdnn::aarch64::matmul::gemv_like_int8(const int8_t* __restrict A, | |||||
const int8_t* __restrict B, | |||||
int32_t* __restrict C, size_t M, | |||||
size_t N, size_t K, size_t Astride, | |||||
size_t Bstride, size_t Cstride) { | |||||
megdnn_assert(N == 1); | |||||
return gemv_naive_n(A, B, C, M, N, K, Astride, Bstride, Cstride); | |||||
} | |||||
#endif | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,34 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/matrix_mul/int8_dot/gemv.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#pragma once | |||||
#include <cstddef> | |||||
#include <cstdint> | |||||
#if __ARM_FEATURE_DOTPROD
namespace megdnn {
namespace aarch64 {
namespace matmul {
//! Returns true when the dot-product int8 GEMV kernel is applicable:
//! neither operand transposed and B is a single contiguous column
//! (N == 1, LDB == 1). M, K, LDA and LDC are accepted for interface
//! uniformity with the GEMM preference hooks.
bool is_gemv_like_preferred_int8(bool transposeA, bool transposeB, size_t M,
                                 size_t N, size_t K, size_t LDA, size_t LDB,
                                 size_t LDC);
//! Computes C = A * B for int8 inputs with int32 accumulation.
//! Requires N == 1 (asserted by the implementation).
void gemv_like_int8(const int8_t* __restrict A, const int8_t* __restrict B,
                    int32_t* __restrict C, size_t M, size_t N, size_t K,
                    size_t Astride, size_t Bstride, size_t Cstride);
} // namespace matmul
} // namespace aarch64
} // namespace megdnn
#endif
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,113 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/matrix_mul/int8_dot/strategy.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "src/aarch64/matrix_mul/int8_dot/strategy.h" | |||||
#include "src/aarch64/matrix_mul/asm/common.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/common/utils.h" | |||||
#include "src/aarch64/matrix_mul/int8_dot/kernel_8x12x4.h" | |||||
#if __ARM_FEATURE_DOTPROD | |||||
using namespace megdnn; | |||||
using namespace aarch64; | |||||
using namespace aarch64::matmul; | |||||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_8x12); | |||||
//! Pack rows [y0, ymax) x cols [k0, kmax) of A into the interleaved layout
//! consumed by the 8x12x4 dot-product micro-kernels.
void gemm_s8_8x12::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin,
                          int y0, int ymax, int k0, int kmax,
                          bool transpose) const {
    if (!transpose) {
        matmul_8x12x4::gemm_s8_8x12_pack_A_n(outptr, inptr, ldin, y0, ymax, k0,
                                             kmax);
    } else {
        matmul_8x12x4::gemm_s8_8x12_pack_A_t(outptr, inptr, ldin, y0, ymax, k0,
                                             kmax);
    }
}
//! Pack cols [x0, xmax) x rows [k0, kmax) of B into the interleaved layout
//! consumed by the 8x12x4 dot-product micro-kernels.
void gemm_s8_8x12::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0,
                          int xmax, int k0, int kmax, bool transpose) const {
    if (!transpose) {
        matmul_8x12x4::gemm_s8_8x12_pack_B_n(out, in, ldin, x0, xmax, k0, kmax);
    } else {
        matmul_8x12x4::gemm_s8_8x12_pack_B_t(out, in, ldin, x0, xmax, k0, kmax);
    }
}
//! GEMM driver for the 8x12x4 dot-product strategy. Walks the packed
//! operands in 8-row (A) by 12-column (B) panels, dispatching to the
//! matching micro-kernel; ragged edges fall back to 4-wide kernels with
//! m/n remainder clamps.
void gemm_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB, size_t M,
                        size_t N, size_t K, dt_int32* C, size_t LDC,
                        bool is_first_k, const dt_int32*, dt_int32*) const {
    //! accepts plain Int8->Int32 or the quantized equivalents
    megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
                          ((A_dtype.enumv() == DTypeEnum::Int8 &&
                            C_dtype.enumv() == DTypeEnum::Int32) ||
                           (A_dtype.enumv() == DTypeEnum::QuantizedS8 &&
                            C_dtype.enumv() == DTypeEnum::QuantizedS32)),
                  "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
                  C_dtype.name());
    MEGDNN_MARK_USED_VAR(A_dtype);
    MEGDNN_MARK_USED_VAR(B_dtype);
    MEGDNN_MARK_USED_VAR(C_dtype);
    constexpr size_t A_INTERLEAVE = 8;
    constexpr size_t B_INTERLEAVE = 12;
    //! K is packed to times of 4
    K = round_up<size_t>(K, 4);
    //! bytes consumed from a packed buffer per 8-row / 12-col / 4-wide panel
    const int K8 = (K << 3);
    const int K12 = K * 12;
    const int K4 = K * 4;
    size_t m = 0;
    for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) {
        int32_t* output = C + (m * LDC);
        size_t n = 0;
        const dt_int8* cur_packB = packB;
        for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
            matmul_8x12x4::kern_8x12(packA, cur_packB, K, output, LDC,
                                     is_first_k);
            output += B_INTERLEAVE;
            cur_packB += K12;
        }
        //! ragged columns: 4-wide kernel, clamped to the remaining N
        for (; n < N; n += 4) {
            matmul_8x12x4::kern_8x4(packA, cur_packB, K, output, LDC,
                                    is_first_k, std::min<size_t>(N - n, 4));
            output += 4;
            cur_packB += K4;
        }
        packA += K8;
    }
    //! ragged rows: 4-row kernels, clamped to the remaining M
    for (; m < M; m += 4) {
        int32_t* output = C + (m * LDC);
        const dt_int8* cur_packB = packB;
        size_t n = 0;
        for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
            matmul_8x12x4::kern_4x12(packA, cur_packB, K, output, LDC,
                                     is_first_k, std::min<size_t>(M - m, 4));
            output += B_INTERLEAVE;
            cur_packB += K12;
        }
        for (; n < N; n += 4) {
            matmul_8x12x4::kern_4x4(packA, cur_packB, K, output, LDC,
                                    is_first_k, std::min<size_t>(M - m, 4),
                                    std::min<size_t>(N - n, 4));
            output += 4;
            cur_packB += K4;
        }
        packA += K4;
    }
}
#endif | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,26 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/matrix_mul/int8_dot/strategy.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#pragma once | |||||
#include "src/fallback/matrix_mul/gemm_common.h" | |||||
#if __ARM_FEATURE_DOTPROD | |||||
namespace megdnn { | |||||
namespace aarch64 { | |||||
namespace matmul { | |||||
MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 8, 12, 4, false, true, | |||||
gemm_s8_8x12); | |||||
} // namespace matmul
} // namespace aarch64
} // namespace megdnn | |||||
#endif | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,439 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/matrix_mul/int8x8x16/kernel_4x4x16.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include <inttypes.h> | |||||
#include "src/aarch64/matrix_mul/asm/common.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
namespace megdnn { | |||||
namespace aarch64 { | |||||
namespace matmul_4x4x16 { | |||||
/** | |||||
* Overview of register layout: | |||||
* | |||||
* +---------+---------+---------+---------+ | |||||
* |v20[0-15]|v21[0-15]|v22[0-15]|v23[0-15]| | |||||
* Rhs +---------+---------+---------+---------+ | |||||
* Lhs | | | | |||||
* | |||||
* +--------+ - - - - +---------+---------+---------+---------+ | |||||
* |v0[0-15]| | v4[0-8] | v8[0-8]| v12[0-8]| v16[0-8]| | |||||
* |v1[0-15]| | v5[0-8] | v9[0-8]| v13[0-8]| v17[0-8]| | |||||
* |v2[0-15]| | v6[0-8] | v10[0-8]| v14[0-8]| v18[0-8]| | |||||
* |v3[0-15]| | v7[0-8] | v11[0-8]| v15[0-8]| v19[0-8]| | |||||
* +--------+ - - - - +---------+---------+---------+---------+ | |||||
* | |||||
* Accumulator | |||||
*/ | |||||
//! 4x4 int8 -> int16 micro-kernel; K is consumed 16 elements at a time.
//! Each (row, col) pair accumulates into one of v4-v19 as 8 int16 lanes via
//! smlal/smlal2, then each accumulator is horizontally reduced with addv and
//! the 4x4 tile is assembled in v24-v27 before being stored (or added to the
//! previously stored tile when !is_first_k).
//! NOTE(review): accumulation stays in 16 bits throughout, so large K can
//! wrap -- presumably accepted by the s8x8x16 output contract; confirm.
static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
                     int16_t* output, int LDC, bool is_first_k, int m_remain,
                     int n_remain) {
    K /= 16;
    const int8_t* a_ptr = packA;
    const int8_t* b_ptr = packB;
    //! LDC becomes a byte stride for the x1-x3 row-pointer setup below
    LDC = LDC * sizeof(int16_t);

// clang-format off
//! LOAD_LINE: load up to 4 int16 C values of row `n` into v<reg_index>,
//! honoring n_remain for partial columns and x5 (rows left) for partial rows.
#define LOAD_LINE(reg_index, n)                \
    "cmp x5, #0 \n"                            \
    "beq 105f\n"                               \
    "cmp %w[n_remain], #4\n"                   \
    "blt 100" n "f\n"                          \
    "ld1 {v" reg_index ".4h}, [x" n "], #8\n"  \
    "b 101" n "f\n"                            \
    "100" n ":\n"                              \
    "cmp %w[n_remain], #0\n"                   \
    "blt 101" n "f\n"                          \
    "ld1 {v" reg_index ".h}[0], [x" n "], #2\n" \
    "cmp %w[n_remain], #1\n"                   \
    "beq 101" n "f\n"                          \
    "ld1 {v" reg_index ".h}[1], [x" n "], #2\n" \
    "cmp %w[n_remain], #2\n"                   \
    "beq 101" n "f\n"                          \
    "ld1 {v" reg_index ".h}[2], [x" n "], #2\n" \
    "101" n ":\n"                              \
    "sub x5, x5, #1\n"

//! LOAD_C: load the existing 4x4 C tile into v24-v27 (rows 0-3).
#define LOAD_C                 \
    "mov x5, %x[m_remain]\n"   \
    LOAD_LINE("24", "0")       \
    LOAD_LINE("25", "1")       \
    LOAD_LINE("26", "2")       \
    LOAD_LINE("27", "3")       \
    "105:\n"

//! STORE_LINE: mirror of LOAD_LINE for writing row `n` back out.
#define STORE_LINE(reg_index, n)               \
    "cmp x5, #0 \n"                            \
    "beq 105f\n"                               \
    "cmp %w[n_remain], #4\n"                   \
    "blt 102" n "f\n"                          \
    "st1 {v" reg_index ".4h}, [x" n "], #8\n"  \
    "b 103" n "f\n"                            \
    "102" n ":\n"                              \
    "cmp %w[n_remain], #0\n"                   \
    "beq 103" n "f\n"                          \
    "st1 {v" reg_index ".h}[0], [x" n "], #2\n" \
    "cmp %w[n_remain], #1\n"                   \
    "beq 103" n "f\n"                          \
    "st1 {v" reg_index ".h}[1], [x" n "], #2\n" \
    "cmp %w[n_remain], #2\n"                   \
    "beq 103" n "f\n"                          \
    "st1 {v" reg_index ".h}[2], [x" n "], #2\n" \
    "103" n ":\n"                              \
    "sub x5, x5, #1\n"

//! STORE_C: write the accumulated 4x4 tile in v24-v27 back to C.
#define STORE_C                \
    "mov x5, %x[m_remain]\n"   \
    STORE_LINE("24", "0")      \
    STORE_LINE("25", "1")      \
    STORE_LINE("26", "2")      \
    STORE_LINE("27", "3")      \
    "105:\n"
    // clang-format on

    //! pin the row-0 output pointer to x0 so x1-x3 can be derived from it
    register int16_t* outptr asm("x0") = output;
    asm volatile(
            //! x1-x3 point at output rows 1-3 (LDC is already in bytes)
            "add x1, x0, %x[LDC]\n"
            "add x2, x1, %x[LDC]\n"
            "add x3, x2, %x[LDC]\n"
            // Clear accumulators
            "eor v4.16b, v4.16b, v4.16b\n"
            "eor v5.16b, v5.16b, v5.16b\n"
            "eor v6.16b, v6.16b, v6.16b\n"
            "eor v7.16b, v7.16b, v7.16b\n"
            "eor v8.16b, v8.16b, v8.16b\n"
            "eor v9.16b, v9.16b, v9.16b\n"
            "eor v10.16b, v10.16b, v10.16b\n"
            "eor v11.16b, v11.16b, v11.16b\n"
            "eor v12.16b, v12.16b, v12.16b\n"
            "eor v13.16b, v13.16b, v13.16b\n"
            "eor v14.16b, v14.16b, v14.16b\n"
            "eor v15.16b, v15.16b, v15.16b\n"
            "eor v16.16b, v16.16b, v16.16b\n"
            "eor v17.16b, v17.16b, v17.16b\n"
            "eor v18.16b, v18.16b, v18.16b\n"
            "eor v19.16b, v19.16b, v19.16b\n"
            // General loop.
            "1:\n"
            //! each iteration consumes 16 K-elements of 4 A rows / 4 B cols
            "ld1 {v20.16b}, [%[b_ptr]], 16\n"
            "ld1 {v0.16b}, [%[a_ptr]], 16\n"
            "ld1 {v1.16b}, [%[a_ptr]], 16\n"
            "ld1 {v2.16b}, [%[a_ptr]], 16\n"
            "ld1 {v3.16b}, [%[a_ptr]], 16\n"
            "ld1 {v21.16b}, [%[b_ptr]], 16\n"
            "smlal v4.8h, v0.8b, v20.8b\n"
            "smlal v5.8h, v1.8b, v20.8b\n"
            "smlal v6.8h, v2.8b, v20.8b\n"
            "smlal v7.8h, v3.8b, v20.8b\n"
            "smlal2 v4.8h, v0.16b, v20.16b\n"
            "smlal2 v5.8h, v1.16b, v20.16b\n"
            "smlal2 v6.8h, v2.16b, v20.16b\n"
            "smlal2 v7.8h, v3.16b, v20.16b\n"
            "ld1 {v22.16b}, [%[b_ptr]], 16\n"
            "smlal v8.8h, v0.8b, v21.8b\n"
            "smlal v9.8h, v1.8b, v21.8b\n"
            "smlal v10.8h, v2.8b, v21.8b\n"
            "smlal v11.8h, v3.8b, v21.8b\n"
            "smlal2 v8.8h, v0.16b, v21.16b\n"
            "smlal2 v9.8h, v1.16b, v21.16b\n"
            "smlal2 v10.8h, v2.16b, v21.16b\n"
            "smlal2 v11.8h, v3.16b, v21.16b\n"
            "ld1 {v23.16b}, [%[b_ptr]], 16\n"
            "smlal v12.8h, v0.8b, v22.8b\n"
            "smlal v13.8h, v1.8b, v22.8b\n"
            "smlal v14.8h, v2.8b, v22.8b\n"
            "smlal v15.8h, v3.8b, v22.8b\n"
            "smlal2 v12.8h, v0.16b, v22.16b\n"
            "smlal2 v13.8h, v1.16b, v22.16b\n"
            "smlal2 v14.8h, v2.16b, v22.16b\n"
            "smlal2 v15.8h, v3.16b, v22.16b\n"
            "smlal v16.8h, v0.8b, v23.8b\n"
            "smlal v17.8h, v1.8b, v23.8b\n"
            "smlal v18.8h, v2.8b, v23.8b\n"
            "smlal v19.8h, v3.8b, v23.8b\n"
            "smlal2 v16.8h, v0.16b, v23.16b\n"
            "smlal2 v17.8h, v1.16b, v23.16b\n"
            "smlal2 v18.8h, v2.16b, v23.16b\n"
            "smlal2 v19.8h, v3.16b, v23.16b\n"
            "subs %w[K], %w[K], #1\n"
            "cbnz %w[K], 1b\n"
            //! first K-slice starts from zero; later slices add onto stored C
            "cmp %w[is_first_k], #1\n"
            "beq 2f\n" LOAD_C
            "b 3f\n"
            "2:\n"  // Clear the C regs.
            "eor v24.16b, v24.16b, v24.16b\n"
            "eor v25.16b, v25.16b, v25.16b\n"
            "eor v26.16b, v26.16b, v26.16b\n"
            "eor v27.16b, v27.16b, v27.16b\n"
            "3:\n"
            // Reduce v4-v19 to v0-v3
            "addv h20, v4.8h\n"
            "addv h21, v8.8h\n"
            "addv h22, v12.8h\n"
            "addv h23, v16.8h\n"
            "ins v0.h[0], v20.h[0]\n"
            "ins v0.h[1], v21.h[0]\n"
            "ins v0.h[2], v22.h[0]\n"
            "ins v0.h[3], v23.h[0]\n"
            "add v24.4h, v24.4h, v0.4h\n"
            "addv h28, v5.8h\n"
            "addv h29, v9.8h\n"
            "addv h30, v13.8h\n"
            "addv h31, v17.8h\n"
            "ins v1.h[0], v28.h[0]\n"
            "ins v1.h[1], v29.h[0]\n"
            "ins v1.h[2], v30.h[0]\n"
            "ins v1.h[3], v31.h[0]\n"
            "add v25.4h, v25.4h, v1.4h\n"
            "addv h20, v6.8h\n"
            "addv h21, v10.8h\n"
            "addv h22, v14.8h\n"
            "addv h23, v18.8h\n"
            "ins v2.h[0], v20.h[0]\n"
            "ins v2.h[1], v21.h[0]\n"
            "ins v2.h[2], v22.h[0]\n"
            "ins v2.h[3], v23.h[0]\n"
            "add v26.4h, v26.4h, v2.4h\n"
            "addv h28, v7.8h\n"
            "addv h29, v11.8h\n"
            "addv h30, v15.8h\n"
            "addv h31, v19.8h\n"
            "ins v3.h[0], v28.h[0]\n"
            "ins v3.h[1], v29.h[0]\n"
            "ins v3.h[2], v30.h[0]\n"
            "ins v3.h[3], v31.h[0]\n"
            "add v27.4h, v27.4h, v3.4h\n"
            // Store back into memory
            STORE_C
            : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr),
              [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC),
              [outptr] "+r"(outptr), [m_remain] "+r"(m_remain),
              [n_remain] "+r"(n_remain)
            :
            : "cc", "memory", "x1", "x2", "x3", "x4", "x5", "v0", "v1", "v2",
              "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
              "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21",
              "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30",
              "v31");
#undef LOAD_LINE
#undef LOAD_C
#undef STORE_LINE
#undef STORE_C
}
//! Pack rows [y0, ymax) x cols [k0, kmax) of row-major A for the 4x4x16
//! kernel: rows are grouped in 4 and interleaved 16 K-elements at a time.
//! Rows past ymax in the ragged tail read from a shared zero buffer so the
//! kernel always consumes 4 rows. The interleave helpers advance the input
//! pointers they are given.
static void gemm_s8x8x16_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr,
                                      int ldin, int y0, int ymax, int k0,
                                      int kmax) {
    int8_t zerobuff[16];
    std::memset(zerobuff, 0, sizeof(int8_t) * 16);
    int y = y0;
    //! full 4-row groups: no padding needed
    for (; y + 3 < ymax; y += 4) {
        const int8_t* inptr0 = inptr + y * ldin + k0;
        const int8_t* inptr1 = inptr0 + ldin;
        const int8_t* inptr2 = inptr1 + ldin;
        const int8_t* inptr3 = inptr2 + ldin;
        prefetch_2x(inptr0);
        prefetch_2x(inptr1);
        prefetch_2x(inptr2);
        prefetch_2x(inptr3);
        int K = kmax - k0;
        //! read 4 * 16 in each row
        for (; K > 15; K -= 16) {
            interleave_4x16_1_b(inptr0, inptr1, inptr2, inptr3, outptr);
        }
        if (K > 0) {
            interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 16, K);
        }
    }
    //! ragged tail (1-3 rows): out-of-range rows are redirected to zerobuff
    //! each iteration, since the interleave helpers advance the pointers
    for (; y < ymax; y += 4) {
        const int8_t* inptr0 = inptr + y * ldin + k0;
        const int8_t* inptr1 = inptr0 + ldin;
        const int8_t* inptr2 = inptr1 + ldin;
        const int8_t* inptr3 = inptr2 + ldin;
        prefetch_2x(inptr0);
        prefetch_2x(inptr1);
        prefetch_2x(inptr2);
        prefetch_2x(inptr3);
        int K = kmax - k0;
        //! read 4 * 16 in each row
        for (; K > 15; K -= 16) {
            if (y + 3 >= ymax) {
                //! fallthrough intended: case N zeroes rows N+1..3 as well
                switch (y + 3 - ymax) {
                    case 2:
                        inptr1 = zerobuff;
                    case 1:
                        inptr2 = zerobuff;
                    case 0:
                        inptr3 = zerobuff;
                        break;
                    default:
                        megdnn_assert(0);
                }
            }
            interleave_4x16_1_b(inptr0, inptr1, inptr2, inptr3, outptr);
        }
        if (K > 0) {
            if (y + 3 >= ymax) {
                //! fallthrough intended (same mapping as above)
                switch (y + 3 - ymax) {
                    case 2:
                        inptr1 = zerobuff;
                    case 1:
                        inptr2 = zerobuff;
                    case 0:
                        inptr3 = zerobuff;
                        break;
                    default:
                        megdnn_assert(0);
                }
            }
            interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 16, K);
        }
    }
}
//! Pack cols [x0, xmax) x rows [k0, kmax) of row-major B for the 4x4x16
//! kernel. K is processed in 16-row slices (two 8-row halves per slice);
//! columns are transposed in groups of 4, with a scalar path for the ragged
//! column tail. K-rows past kmax read from a shared zero buffer.
static void gemm_s8x8x16_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin,
                                      int x0, int xmax, int k0, int kmax) {
    int8_t zerobuff[16];
    std::memset(zerobuff, 0, sizeof(int8_t) * 16);
    const int ksize = kmax - k0;
    //! bytes per packed 4-column panel (K rounded up to 16)
    const int ksize4 = round_up(ksize, 16) * 4;
    int8_t* outptr = out;
    int k = k0;
    for (; k < kmax; k += 16) {
        int ki = k;
        //! two 8-row halves of the 16-row K slice
        for (int cnt = 0; cnt < 2; ki += 8, cnt++) {
            const int8_t* inptr0 = in + ki * ldin + x0;
            const int8_t* inptr1 = inptr0 + ldin;
            const int8_t* inptr2 = inptr1 + ldin;
            const int8_t* inptr3 = inptr2 + ldin;
            const int8_t* inptr4 = inptr3 + ldin;
            const int8_t* inptr5 = inptr4 + ldin;
            const int8_t* inptr6 = inptr5 + ldin;
            const int8_t* inptr7 = inptr6 + ldin;
            prefetch_2x(inptr0);
            prefetch_2x(inptr1);
            prefetch_2x(inptr2);
            prefetch_2x(inptr3);
            prefetch_2x(inptr4);
            prefetch_2x(inptr5);
            prefetch_2x(inptr6);
            prefetch_2x(inptr7);
            int8_t* outptr_inner = outptr + ki - k;
            //! remain >= 0 iff this 8-row half runs past kmax; its value
            //! selects how many trailing rows must be zero-padded
            int remain = std::min(ki + 7 - kmax, 7);
            int x = x0;
            for (; x + 3 < xmax; x += 4) {
                if (remain >= 0) {
                    //! fallthrough intended: case N zeroes rows 7-N..7
                    switch (remain) {
                        case 7:
                            inptr0 = zerobuff;
                        case 6:
                            inptr1 = zerobuff;
                        case 5:
                            inptr2 = zerobuff;
                        case 4:
                            inptr3 = zerobuff;
                        case 3:
                            inptr4 = zerobuff;
                        case 2:
                            inptr5 = zerobuff;
                        case 1:
                            inptr6 = zerobuff;
                        case 0:
                            inptr7 = zerobuff;
                            break;
                        default:
                            megdnn_assert(0);
                    }
                }
                transpose_4x16_1_b_helper(inptr0, inptr1, inptr2, inptr3,
                                          inptr4, inptr5, inptr6, inptr7,
                                          outptr_inner);
                outptr_inner += ksize4;
            }
            //! scalar tail for the last (xmax - x) columns
            if (x < xmax) {
                if (remain >= 0) {
                    //! fallthrough intended (same mapping as above)
                    switch (remain) {
                        case 7:
                            inptr0 = zerobuff;
                        case 6:
                            inptr1 = zerobuff;
                        case 5:
                            inptr2 = zerobuff;
                        case 4:
                            inptr3 = zerobuff;
                        case 3:
                            inptr4 = zerobuff;
                        case 2:
                            inptr5 = zerobuff;
                        case 1:
                            inptr6 = zerobuff;
                        case 0:
                            inptr7 = zerobuff;
                            break;
                        default:
                            megdnn_assert(0);
                    }
                }
                for (; x < xmax; x++) {
                    *outptr_inner++ = *inptr0++;
                    *outptr_inner++ = *inptr1++;
                    *outptr_inner++ = *inptr2++;
                    *outptr_inner++ = *inptr3++;
                    *outptr_inner++ = *inptr4++;
                    *outptr_inner++ = *inptr5++;
                    *outptr_inner++ = *inptr6++;
                    *outptr_inner++ = *inptr7++;
                    //! skip the other 8-row half's slot in the 16-wide layout
                    outptr_inner += 8;
                }
            }
        }
        outptr += 16 * 4;
    }
}
} // namespace matmul_4x4x16 | |||||
} // namespace aarch64 | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,200 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/matrix_mul/int8x8x16/strategy.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "src/aarch64/matrix_mul/asm/common.h" | |||||
#include "src/aarch64/matrix_mul/int8x8x16/kernel_4x4x16.h" | |||||
#include "src/aarch64/matrix_mul/int8x8x16/kernel_8x8x8.h" | |||||
#include "src/aarch64/matrix_mul/int8x8x16/strategy.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/common/utils.h" | |||||
#include "src/fallback/matrix_mul/gemm_common.h" | |||||
using namespace megdnn; | |||||
using namespace aarch64; | |||||
using namespace aarch64::matmul; | |||||
// ===========================gemm_s8x8x16_8x8==================================
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_8x8); | |||||
//! Pack rows [y0, ymax) x cols [k0, kmax) of A into the layout consumed by
//! the 8x8x8 micro-kernels.
void gemm_s8x8x16_8x8::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0,
                              int ymax, int k0, int kmax,
                              bool transpose) const {
    if (!transpose) {
        matmul_8x8x8::gemm_s8x8x16_8x8_pack_A_n(out, in, ldin, y0, ymax, k0,
                                                kmax);
    } else {
        matmul_8x8x8::gemm_s8x8x16_8x8_transpose_pack_A_n(out, in, ldin, y0,
                                                          ymax, k0, kmax);
    }
}
//! Pack cols [x0, xmax) x rows [k0, kmax) of B into the layout consumed by
//! the 8x8x8 micro-kernels.
void gemm_s8x8x16_8x8::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0,
                              int xmax, int k0, int kmax,
                              bool transpose) const {
    if (!transpose) {
        matmul_8x8x8::gemm_s8x8x16_8x8_pack_B_n(out, in, ldin, x0, xmax, k0,
                                                kmax);
    } else {
        matmul_8x8x8::gemm_s8x8x16_8x8_transpose_pack_B_n(out, in, ldin, x0,
                                                          xmax, k0, kmax);
    }
}
//! GEMM driver for the 8x8x8 int8x8x16 strategy: 8-row by 8-column panels
//! with 4-wide fallback kernels for the ragged M/N edges.
void gemm_s8x8x16_8x8::kern(const dt_int8* packA, const dt_int8* packB,
                            size_t M, size_t N, size_t K, dt_int16* C,
                            size_t LDC, bool is_first_k, const dt_int16*,
                            dt_int16*) const {
    //! only plain Int8 inputs with Int16 output are supported
    megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
                          (A_dtype.enumv() == DTypeEnum::Int8 &&
                           C_dtype.enumv() == DTypeEnum::Int16),
                  "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
                  C_dtype.name());
    MEGDNN_MARK_USED_VAR(A_dtype);
    MEGDNN_MARK_USED_VAR(B_dtype);
    MEGDNN_MARK_USED_VAR(C_dtype);
    constexpr size_t A_INTERLEAVE = 8;
    constexpr size_t B_INTERLEAVE = 8;
    //! K is packed to times of 8
    K = round_up<size_t>(K, 8);
    //! bytes consumed from a packed buffer per 8-wide / 4-wide panel
    const int K8 = K * 8;
    const int K4 = K * 4;
    size_t m = 0;
    for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) {
        int16_t* output = C + (m * LDC);
        size_t n = 0;
        const dt_int8* cur_packB = packB;
        for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
            matmul_8x8x8::kern_8x8(packA, cur_packB, K, output, LDC,
                                   is_first_k);
            output += B_INTERLEAVE;
            cur_packB += K8;
        }
        //! ragged columns: 4-wide kernel, clamped to the remaining N
        for (; n < N; n += 4) {
            matmul_8x8x8::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k,
                                   std::min<size_t>(N - n, 4));
            output += 4;
            cur_packB += K4;
        }
        packA += K8;
    }
    //! ragged rows: 4-row kernels, clamped to the remaining M
    for (; m < M; m += 4) {
        int16_t* output = C + (m * LDC);
        const dt_int8* cur_packB = packB;
        size_t n = 0;
        for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
            matmul_8x8x8::kern_4x8(packA, cur_packB, K, output, LDC, is_first_k,
                                   std::min<size_t>(M - m, 4));
            output += B_INTERLEAVE;
            cur_packB += K8;
        }
        for (; n < N; n += 4) {
            matmul_8x8x8::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k,
                                   std::min<size_t>(M - m, 4),
                                   std::min<size_t>(N - n, 4));
            output += 4;
            cur_packB += K4;
        }
        packA += K4;
    }
}
// ===========================gemm_s8x8x16_4x4================================== | |||||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_4x4); | |||||
//! Pack A for the 4x4x16 kernel. The block shape is symmetric (4x4), so
//! packing a transposed A is the same operation as packing a non-transposed
//! B — the two helpers are deliberately shared.
void gemm_s8x8x16_4x4::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0,
                              int ymax, int k0, int kmax,
                              bool transpose) const {
    if (!transpose) {
        matmul_4x4x16::gemm_s8x8x16_4x4_pack_A_n(out, in, ldin, y0, ymax, k0,
                                                 kmax);
    } else {
        matmul_4x4x16::gemm_s8x8x16_4x4_pack_B_n(out, in, ldin, y0, ymax, k0,
                                                 kmax);
    }
}
//! Pack B for the 4x4x16 kernel. Mirrors pack_A: with the symmetric 4x4
//! block shape, a transposed B packs exactly like a non-transposed A.
void gemm_s8x8x16_4x4::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0,
                              int xmax, int k0, int kmax,
                              bool transpose) const {
    if (!transpose) {
        matmul_4x4x16::gemm_s8x8x16_4x4_pack_B_n(out, in, ldin, x0, xmax, k0,
                                                 kmax);
    } else {
        matmul_4x4x16::gemm_s8x8x16_4x4_pack_A_n(out, in, ldin, x0, xmax, k0,
                                                 kmax);
    }
}
//! GEMM driver for the 4x4x16 int8x8x16 strategy: a single 4x4 micro-kernel
//! handles full tiles and (via m/n remainder clamps) both ragged edges.
void gemm_s8x8x16_4x4::kern(const dt_int8* packA, const dt_int8* packB,
                            size_t M, size_t N, size_t K, dt_int16* C,
                            size_t LDC, bool is_first_k, const dt_int16*,
                            dt_int16*) const {
    //! only plain Int8 inputs with Int16 output are supported
    megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
                          (A_dtype.enumv() == DTypeEnum::Int8 &&
                           C_dtype.enumv() == DTypeEnum::Int16),
                  "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
                  C_dtype.name());
    MEGDNN_MARK_USED_VAR(A_dtype);
    MEGDNN_MARK_USED_VAR(B_dtype);
    MEGDNN_MARK_USED_VAR(C_dtype);
    constexpr size_t A_INTERLEAVE = 4;
    constexpr size_t B_INTERLEAVE = 4;
    //! K is packed to times of 16 (the kernel consumes K in 16-element steps)
    K = round_up<size_t>(K, 16);
    //! bytes consumed from a packed buffer per 4-wide panel
    const int K4 = K * 4;
    size_t m = 0;
    for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) {
        int16_t* output = C + (m * LDC);
        size_t n = 0;
        const dt_int8* cur_packB = packB;
        for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
            matmul_4x4x16::kern_4x4(packA, cur_packB, K, output, LDC,
                                    is_first_k, A_INTERLEAVE, B_INTERLEAVE);
            output += B_INTERLEAVE;
            cur_packB += K4;
        }
        //! ragged columns: same kernel with an n_remain clamp
        for (; n < N; n += B_INTERLEAVE) {
            matmul_4x4x16::kern_4x4(packA, cur_packB, K, output, LDC,
                                    is_first_k, A_INTERLEAVE,
                                    std::min<size_t>(N - n, B_INTERLEAVE));
            output += B_INTERLEAVE;
            cur_packB += K4;
        }
        packA += K4;
    }
    //! ragged rows: same kernel with an m_remain clamp
    for (; m < M; m += A_INTERLEAVE) {
        int16_t* output = C + (m * LDC);
        size_t n = 0;
        const dt_int8* cur_packB = packB;
        for (; n < N; n += B_INTERLEAVE) {
            matmul_4x4x16::kern_4x4(packA, cur_packB, K, output, LDC,
                                    is_first_k,
                                    std::min<size_t>(M - m, A_INTERLEAVE),
                                    std::min<size_t>(N - n, B_INTERLEAVE));
            output += B_INTERLEAVE;
            cur_packB += K4;
        }
        packA += K4;
    }
}
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,27 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/matrix_mul/int8x8x16/strategy.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#pragma once | |||||
#include "src/fallback/matrix_mul/gemm_common.h" | |||||
namespace megdnn {
namespace aarch64 {
namespace matmul {
//! int8 x int8 -> int16 strategy, 8x8 block, K unrolled by 8
MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int16, dt_int16, 8, 8, 8, false, true,
                         gemm_s8x8x16_8x8);
//! int8 x int8 -> int16 strategy, 4x4 block, K unrolled by 16
MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int16, dt_int16, 4, 4, 16, false, true,
                         gemm_s8x8x16_4x4);
}  // namespace matmul
}  // namespace aarch64
}  // namespace megdnn
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,93 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/matrix_mul/opr_impl.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "src/aarch64/matrix_mul/opr_impl.h" | |||||
#include "src/aarch64/matrix_mul/algos.h" | |||||
#include "src/common/metahelper.h" | |||||
#include "src/common/utils.h" | |||||
using namespace megdnn; | |||||
using namespace aarch64; | |||||
//! Owns one instance of every aarch64 matmul algorithm and exposes them as
//! a flat list. NOTE(review): the emplace order in the constructor appears
//! to define the preference order when algorithms are searched — confirm
//! with the consumers of algo_pack() before reordering.
class MatrixMulImpl::AlgoPack : NonCopyableObj {
    AlgoF32K8x12x1 f32K8x12x1;
    AlgoF32K4x16x1 f32k4x16x1;
    AlgoF32MK4_4x16 f32mk4_4x16;
    AlgoF32Gemv f32_gemv;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    AlgoF16K8x24x1 f16_k8x24x1;
    AlgoF16MK8_8x8 f16_mk8_8x8;
#endif
#if __ARM_FEATURE_DOTPROD
    //! with dot-product support the dotprod kernels fully replace the
    //! generic int8x8x32 / quint8 kernels
    AlgoInt8x8x32K8x12x4DotProd int8x8x32_k8x12x4_dotprod;
    AlgoInt8x8x32GemvDotProd int8x8x32_gemv_dotprod;
#else
    AlgoInt8x8x32MK4_4x4x16 int8x8x32_mk4_4x4x16;
    AlgoInt8x8x32K4x4x16 int8x8x32_k4x4x16;
    AlgoInt8x8x32K8x8x8 int8x8x32_k8x8x8;
    AlgoInt8x8x32Gemv int8x8x32_gemv;
#endif
    AlgoInt8x8x16K8x8x8 int8x8x16_k8x8x8;
    AlgoInt8x8x16K4x4x16 int8x8x16_k4x4x16;
    AlgoInt16x16x32K12x8x1 int16x16x32_k12x8x1;
    AlgoInt16x16x32MK8_8x8 int16x16x32_mk8_8x8;
#if __ARM_FEATURE_DOTPROD
    AlgoQuint8K8x8x4DotProd quint8_k8x8x4_dotprod;
    AlgoQuint8GemvDotProd quint8_gemv_dotprod;
#else
    AlgoQuint8K8x8x8 quint8_k8x8x8;
#endif
public:
    //! populated once at construction; pointers stay valid for the lifetime
    //! of the (static) AlgoPack instance
    SmallVector<MatrixMulImpl::AlgoBase*> all_algos;
    AlgoPack() {
        all_algos.emplace_back(&f32_gemv);
        all_algos.emplace_back(&f32K8x12x1);
        all_algos.emplace_back(&f32k4x16x1);
        all_algos.emplace_back(&f32mk4_4x16);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        all_algos.emplace_back(&f16_k8x24x1);
        all_algos.emplace_back(&f16_mk8_8x8);
#endif
#if __ARM_FEATURE_DOTPROD
        all_algos.emplace_back(&int8x8x32_gemv_dotprod);
        all_algos.emplace_back(&int8x8x32_k8x12x4_dotprod);
#else
        all_algos.emplace_back(&int8x8x32_gemv);
        all_algos.emplace_back(&int8x8x32_k8x8x8);
        all_algos.emplace_back(&int8x8x32_k4x4x16);
        all_algos.emplace_back(&int8x8x32_mk4_4x4x16);
#endif
        all_algos.emplace_back(&int8x8x16_k4x4x16);
        all_algos.emplace_back(&int8x8x16_k8x8x8);
        all_algos.emplace_back(&int16x16x32_k12x8x1);
        all_algos.emplace_back(&int16x16x32_mk8_8x8);
#if __ARM_FEATURE_DOTPROD
        all_algos.emplace_back(&quint8_gemv_dotprod);
        all_algos.emplace_back(&quint8_k8x8x4_dotprod);
#else
        all_algos.emplace_back(&quint8_k8x8x8);
#endif
    }
};
//! Returns all available matmul algorithms: the aarch64-specific ones are
//! prepended to the arm_common list so they take precedence.
SmallVector<MatrixMulImpl::AlgoBase*> MatrixMulImpl::algo_pack() {
    //! constructed once on first call; the pointers it hands out must
    //! outlive every returned list
    static AlgoPack s_algo_pack;
    auto&& algos = arm_common::MatrixMulImpl::algo_pack();
    algos.insert(algos.begin(), s_algo_pack.all_algos.begin(),
                 s_algo_pack.all_algos.end());
    return std::move(algos);
}
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,63 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/matrix_mul/opr_impl.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#pragma once | |||||
#include "src/arm_common/matrix_mul/opr_impl.h" | |||||
namespace megdnn { | |||||
namespace aarch64 { | |||||
//! Aarch64 matmul: extends arm_common::MatrixMulImpl with kernels that use
//! aarch64-only instructions. Which algorithm classes exist depends on the
//! compile-time CPU features (fp16 arithmetic, dot-product).
class MatrixMulImpl : public arm_common::MatrixMulImpl {
public:
    using arm_common::MatrixMulImpl::MatrixMulImpl;
    //! Returns the aarch64 algorithms followed by the inherited ones.
    SmallVector<AlgoBase*> algo_pack() override;
private:
    class AlgoF32K8x12x1;  // Aarch64 F32 Kernel 8X12X1
    class AlgoF32K4x16x1;  // Aarch64 F32 Kernel 4x16x1
    class AlgoF32MK4_4x16; // Aarch64 F32 Format MK4 block 16x4
    class AlgoF32Gemv;     // Aarch64 F32 Gemv
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    //! fp16 kernels: only built when armv8.2-a+fp16 is enabled.
    class AlgoF16K8x24x1; // Aarch64 F16 Kernel 8x24x1
    class AlgoF16MK8_8x8; // Aarch64 F16 Format MK8 block 16x8
#endif
#if __ARM_FEATURE_DOTPROD
    //! With the dot-product extension, the UDOT/SDOT kernels replace the
    //! plain int8 kernels entirely.
    class AlgoInt8x8x32K8x12x4DotProd; // Aarch64 Int8x8x32 Kernel
                                       // 8x12x4 DotProduct
    class AlgoInt8x8x32GemvDotProd;    // Aarch64 Int8x8x32 Gemv DotProduct
#else
    class AlgoInt8x8x32MK4_4x4x16; // Aarch64 nchw44 Int8x8x32 Kernel 4x4x16
    class AlgoInt8x8x32K4x4x16;    // Aarch64 Int8x8x32 Kernel 4x4x16
    class AlgoInt8x8x32K8x8x8;     // Aarch64 Int8x8x32 Kernel 8x8x8
    class AlgoInt8x8x32Gemv;       // Aarch64 Int8x8x32 Gemv
#endif
    class AlgoInt8x8x16K8x8x8;   // Aarch64 Int8x8x16 Kernel 8x8x8
    class AlgoInt8x8x16K4x4x16;  // Aarch64 Int8x8x16 Kernel 4x4x16
    class AlgoInt16x16x32K12x8x1;  // Aarch64 Int16x16x32 Kernel 12x8x1
    class AlgoInt16x16x32MK8_8x8;  // Aarch64 Int16x16x32 Format MK8 block 8x8
#if __ARM_FEATURE_DOTPROD
    class AlgoQuint8K8x8x4DotProd; // Aarch64 Quint8 Kernel
                                   // 8x8x4 DotProduct
    class AlgoQuint8GemvDotProd;   // Aarch64 Quint8 Gemv DotProduct
#else
    class AlgoQuint8K8x8x8; // Aarch64 Quint8 Kernel 8x8x8
#endif
    //! Aggregates instances of all the algorithm classes above.
    class AlgoPack;
};
} // namespace aarch64 | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,113 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/matrix_mul/quint8/strategy.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#if !(__ARM_FEATURE_DOTPROD) | |||||
#include "src/aarch64/matrix_mul/quint8/strategy.h" | |||||
#include "src/aarch64/matrix_mul/asm/common.h" | |||||
#include "src/aarch64/matrix_mul/quint8/kernel_8x8x8.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/common/utils.h" | |||||
using namespace megdnn; | |||||
using namespace aarch64; | |||||
using namespace aarch64::matmul; | |||||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8); | |||||
void gemm_u8_8x8::pack_A(dt_uint8* outptr, const dt_uint8* inptr, int ldin,
                         int y0, int ymax, int k0, int kmax,
                         bool transpose) const {
    //! Pack rows [y0, ymax) x depth [k0, kmax) of A into the panel layout the
    //! 8x8x8 kernel expects; the zero point of A is forwarded to the helpers.
    const uint8_t zero_point_a =
            A_dtype.param<dtype::Quantized8Asymm>().zero_point;
    if (!transpose) {
        matmul_8x8x8::gemm_u8_8x8_pack_A_n(outptr, inptr, ldin, y0, ymax, k0,
                                           kmax, zero_point_a);
    } else {
        matmul_8x8x8::gemm_u8_8x8_transpose_pack_A_n(
                outptr, inptr, ldin, y0, ymax, k0, kmax, zero_point_a);
    }
}
void gemm_u8_8x8::pack_B(dt_uint8* out, const dt_uint8* in, int ldin, int x0,
                         int xmax, int k0, int kmax, bool transpose) const {
    //! Pack columns [x0, xmax) x depth [k0, kmax) of B; the zero point of B
    //! is forwarded to the packing helpers.
    const uint8_t zero_point_b =
            B_dtype.param<dtype::Quantized8Asymm>().zero_point;
    if (!transpose) {
        matmul_8x8x8::gemm_u8_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax,
                                           zero_point_b);
    } else {
        matmul_8x8x8::gemm_u8_8x8_transpose_pack_B_n(
                out, in, ldin, x0, xmax, k0, kmax, zero_point_b);
    }
}
void gemm_u8_8x8::kern(const dt_uint8* packA, const dt_uint8* packB, size_t M,
                       size_t N, size_t K, dt_int32* C, size_t LDC,
                       bool is_first_k, const dt_int32*, dt_int32*) const {
    //! Multiply the packed panels of A (MxK) and B (KxN) into C (int32,
    //! row stride LDC). Only Quantized8Asymm inputs with QuantizedS32
    //! output are supported.
    megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
                  A_dtype.enumv() == DTypeEnum::Quantized8Asymm &&
                  C_dtype.enumv() == DTypeEnum::QuantizedS32,
                  "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
                  C_dtype.name());
    //! Zero points of both operands, forwarded to the inner kernels.
    uint8_t zA = A_dtype.param<dtype::Quantized8Asymm>().zero_point;
    uint8_t zB = B_dtype.param<dtype::Quantized8Asymm>().zero_point;
    constexpr size_t A_INTERLEAVE = 8;
    constexpr size_t B_INTERLEAVE = 8;
    //! K is packed to times of 8
    K = round_up<size_t>(K, 8);
    //! Bytes consumed per 8-wide (K8) / 4-wide (K4) packed panel strip.
    const int K8 = K * 8;
    const int K4 = K * 4;
    size_t m = 0;
    //! Main loop over full 8-row blocks of A.
    for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) {
        int32_t* output = C + (m * LDC);
        size_t n = 0;
        const dt_uint8* cur_packB = packB;
        //! Full 8x8 output tiles.
        for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
            matmul_8x8x8::kern_8x8(packA, cur_packB, K, output, LDC, is_first_k,
                                   zA, zB);
            output += B_INTERLEAVE;
            cur_packB += K8;
        }
        //! Right edge: at most 4 remaining columns per step.
        for (; n < N; n += 4) {
            matmul_8x8x8::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k,
                                   std::min<size_t>(N - n, 4), zA, zB);
            output += 4;
            cur_packB += K4;
        }
        packA += K8;
    }
    //! Bottom edge: remaining rows of A, handled up to 4 at a time.
    for (; m < M; m += 4) {
        int32_t* output = C + (m * LDC);
        const dt_uint8* cur_packB = packB;
        size_t n = 0;
        for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
            matmul_8x8x8::kern_4x8(packA, cur_packB, K, output, LDC, is_first_k,
                                   std::min<size_t>(M - m, 4), zA, zB);
            output += B_INTERLEAVE;
            cur_packB += K8;
        }
        //! Bottom-right corner: partial rows and partial columns.
        for (; n < N; n += 4) {
            matmul_8x8x8::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k,
                                   std::min<size_t>(M - m, 4),
                                   std::min<size_t>(N - n, 4), zA, zB);
            output += 4;
            cur_packB += K4;
        }
        packA += K4;
    }
}
#endif | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,28 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/matrix_mul/quint8/strategy.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#pragma once
//! Non-dotprod build only; the dotprod variant is declared in
//! quint8_dot/strategy.h.
#if !(__ARM_FEATURE_DOTPROD)
#include "src/fallback/matrix_mul/gemm_common.h"
namespace megdnn {
namespace aarch64 {
namespace matmul {
//! Declares the gemm_u8_8x8 strategy: dt_uint8 inputs, dt_int32 output,
//! with block parameters 8, 8, 8 (see MEGDNN_REG_GEMM_STRATEGY in
//! gemm_common.h for the exact meaning of each argument).
MEGDNN_REG_GEMM_STRATEGY(dt_uint8, dt_int32, dt_int32, 8, 8, 8, false, true,
                         gemm_u8_8x8);
}  // namespace matmul
}  // namespace aarch64
}  // namespace megdnn
#endif
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,177 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/matrix_mul/quint8_dot/gemv.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "src/aarch64/matrix_mul/quint8_dot/gemv.h" | |||||
#include <cstddef> | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/common/utils.h" | |||||
#include "src/common/unroll_macro.h" | |||||
#if __ARM_FEATURE_DOTPROD | |||||
namespace { | |||||
//! GEMV (N == 1) for asymmetric-quantized u8 inputs using the UDOT
//! dot-product instruction. The quantized product is expanded as
//!     sum_k (a - zA)(b - zB)
//!         = sum_k a*b - zA*sum_k b - zB*sum_k a + K*zA*zB,
//! so the raw dot product and the two zero-point correction sums are
//! accumulated separately and combined at the end of each row.
void gemv_naive_n(const uint8_t* __restrict A, const uint8_t* __restrict B,
                  int32_t* __restrict C, size_t M, size_t N, size_t K,
                  size_t Astride, size_t Bstride, size_t Cstride,
                  uint8_t zero_point_A, uint8_t zero_point_B) {
    //! Constant term K * zA * zB of the expansion above.
    int32_t zAB = static_cast<int32_t>(zero_point_A) *
                  static_cast<int32_t>(zero_point_B) * K;
    uint8x16_t zAq = vdupq_n_u8(zero_point_A);
    uint8x16_t zBq = vdupq_n_u8(zero_point_B);
    uint8x8_t zA = vdup_n_u8(zero_point_A);
    uint8x8_t zB = vdup_n_u8(zero_point_B);
    megdnn_assert(N == 1 && Bstride == 1);
    size_t m = 0;
    //! Main loop: two rows of A per iteration.
    for (; m + 2 <= M; m += 2) {
        //! acc_zB / acc_zB2 hold zB*sum(A_row) for rows m / m+1; acc_zA holds
        //! zA*sum(B). All are fully initialized by the 16-wide block below.
        int32_t acc_zA, acc_zB, acc_zB2;
        int32_t acc[4];
        size_t k = 0;
        uint32x4_t acc_neon = vdupq_n_u32(0);
        {
            uint32x4_t acc_zA_neon = vdupq_n_u32(0);
            uint32x4_t acc_zB_neon = vdupq_n_u32(0);
            uint32x4_t acc_zB2_neon = vdupq_n_u32(0);
            //! 16 elements of K at a time; rows m and m+1 are zipped so that
            //! lanes 0..1 of acc_neon accumulate row m and lanes 2..3 row m+1.
            for (; k + 16 <= K; k += 16) {
                uint8x16_t elem = vld1q_u8(A + m * Astride + k);
                acc_zB_neon = vdotq_u32(acc_zB_neon, zBq, elem);
                uint64x2_t a0 = vreinterpretq_u64_u8(elem);
                elem = vld1q_u8(A + (m + 1) * Astride + k);
                acc_zB2_neon = vdotq_u32(acc_zB2_neon, zBq, elem);
                uint64x2_t a1 = vreinterpretq_u64_u8(elem);
                //! the first 8 elements is m, the last 8 elements is m + 1
                uint8x16_t a2 = vreinterpretq_u8_u64(vzip1q_u64(a0, a1));
                uint8x16_t a3 = vreinterpretq_u8_u64(vzip2q_u64(a0, a1));
                elem = vld1q_u8(B + k);
                acc_zA_neon = vdotq_u32(acc_zA_neon, zAq, elem);
                uint64x2_t b0 = vreinterpretq_u64_u8(elem);
                //! Duplicate B into both halves to match the zipped A layout.
                uint8x16_t b2 = vreinterpretq_u8_u64(vzip1q_u64(b0, b0));
                uint8x16_t b3 = vreinterpretq_u8_u64(vzip2q_u64(b0, b0));
                acc_neon = vdotq_u32(acc_neon, a2, b2);
                acc_neon = vdotq_u32(acc_neon, a3, b3);
            }
            vst1q_s32(acc, vreinterpretq_s32_u32(acc_neon));
            acc_zA = vaddvq_u32(acc_zA_neon);
            acc_zB = vaddvq_u32(acc_zB_neon);
            acc_zB2 = vaddvq_u32(acc_zB2_neon);
        }
        {
            uint32x2_t acc_zA_neon = vdup_n_u32(0);
            uint32x2_t acc_zB_neon = vdup_n_u32(0);
            uint32x2_t acc_zB2_neon = vdup_n_u32(0);
            //! 8-wide tail: accumulate row m into acc[0] and row m+1 into
            //! acc[3] (both slots were initialized by the vst1q above).
            for (; k + 8 <= K; k += 8) {
                uint8x8_t a0 = vld1_u8(A + m * Astride + k);
                uint8x8_t a1 = vld1_u8(A + (m + 1) * Astride + k);
                uint8x8_t b0 = vld1_u8(B + k);
                uint32x2_t zero = vdup_n_u32(0);
                acc[0] += vaddv_u32(vdot_u32(zero, a0, b0));
                zero = vdup_n_u32(0);
                acc[3] += vaddv_u32(vdot_u32(zero, a1, b0));
                acc_zB_neon = vdot_u32(acc_zB_neon, a0, zB);
                acc_zB2_neon = vdot_u32(acc_zB2_neon, a1, zB);
                acc_zA_neon = vdot_u32(acc_zA_neon, b0, zA);
            }
            acc_zA += vaddv_u32(acc_zA_neon);
            acc_zB += vaddv_u32(acc_zB_neon);
            acc_zB2 += vaddv_u32(acc_zB2_neon);
        }
        //! Scalar tail for the remaining K % 8 elements.
        for (; k < K; ++k) {
            acc[0] += static_cast<int32_t>(A[m * Astride + k]) * B[k];
            acc[3] += static_cast<int32_t>(A[(m + 1) * Astride + k]) * B[k];
            acc_zA += static_cast<int32_t>(B[k]) * zero_point_A;
            acc_zB += static_cast<int32_t>(A[m * Astride + k]) * zero_point_B;
            acc_zB2 += static_cast<int32_t>(A[(m + 1) * Astride + k]) *
                       zero_point_B;
        }
        //! Combine: lanes {0,1} belong to row m, lanes {2,3} to row m+1.
        C[m * Cstride] = acc[0] + acc[1] + zAB - acc_zA - acc_zB;
        C[(m + 1) * Cstride] = acc[2] + acc[3] + zAB - acc_zA - acc_zB2;
    }
    //! Tail: one remaining row (at most one iteration since the loop above
    //! advances by 2).
    for (; m < M; ++m) {
        int32_t acc[4];
        int32_t acc_zA, acc_zB;
        uint32x4_t acc_neon = vdupq_n_u32(0);
        size_t k = 0;
        {
            uint32x4_t acc_zA_neon = vdupq_n_u32(0);
            uint32x4_t acc_zB_neon = vdupq_n_u32(0);
            for (; k + 16 <= K; k += 16) {
                uint8x16_t a0 = vld1q_u8(A + m * Astride + k);
                uint8x16_t b0 = vld1q_u8(B + k);
                acc_neon = vdotq_u32(acc_neon, a0, b0);
                acc_zB_neon = vdotq_u32(acc_zB_neon, zBq, a0);
                acc_zA_neon = vdotq_u32(acc_zA_neon, zAq, b0);
            }
            //! Initializes all four acc slots and both correction sums.
            vst1q_s32(acc, vreinterpretq_s32_u32(acc_neon));
            acc_zA = vaddvq_u32(acc_zA_neon);
            acc_zB = vaddvq_u32(acc_zB_neon);
        }
        {
            uint32x2_t acc_zA_neon = vdup_n_u32(0);
            uint32x2_t acc_zB_neon = vdup_n_u32(0);
            for (; k + 8 <= K; k += 8) {
                uint8x8_t a0 = vld1_u8(A + m * Astride + k);
                uint8x8_t b0 = vld1_u8(B + k);
                uint32x2_t zero = vdup_n_u32(0);
                acc[0] += vaddv_u32(vdot_u32(zero, a0, b0));
                acc_zB_neon = vdot_u32(acc_zB_neon, a0, zB);
                acc_zA_neon = vdot_u32(acc_zA_neon, b0, zA);
            }
            acc_zA += vaddv_u32(acc_zA_neon);
            acc_zB += vaddv_u32(acc_zB_neon);
        }
        for (; k < K; ++k) {
            acc[0] += static_cast<int32_t>(A[m * Astride + k]) * B[k];
            acc_zA += static_cast<int32_t>(B[k]) * zero_point_A;
            acc_zB += static_cast<int32_t>(A[m * Astride + k]) * zero_point_B;
        }
        //! Single row: all four lanes of the 16-wide accumulator contribute.
        C[m * Cstride] =
                acc[0] + acc[1] + acc[2] + acc[3] + zAB - acc_zA - acc_zB;
    }
}
} // namespace | |||||
bool megdnn::aarch64::matmul::is_gemv_like_preferred_quint8( | |||||
bool transposeA, bool transposeB, size_t M, size_t N, size_t K, | |||||
size_t /* LDA */, size_t LDB, size_t /* LDC */) { | |||||
if (transposeA) | |||||
return false; | |||||
if (transposeB) | |||||
return false; | |||||
MEGDNN_MARK_USED_VAR(K); | |||||
MEGDNN_MARK_USED_VAR(M); | |||||
//! rebenchmark gemv in sdm855 | |||||
return (N == 1 && LDB == 1); | |||||
} | |||||
void megdnn::aarch64::matmul::gemv_like_quint8( | |||||
const uint8_t* __restrict A, const uint8_t* __restrict B, | |||||
int32_t* __restrict C, size_t M, size_t N, size_t K, size_t Astride, | |||||
size_t Bstride, size_t Cstride, uint8_t zero_point_A, | |||||
uint8_t zero_point_B) { | |||||
megdnn_assert(N == 1); | |||||
return gemv_naive_n(A, B, C, M, N, K, Astride, Bstride, Cstride, | |||||
zero_point_A, zero_point_B); | |||||
} | |||||
#endif | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,35 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/matrix_mul/quint8_dot/gemv.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#pragma once
#include <cstddef>
#include <cstdint>
//! Only available when built with the Armv8.2 dot-product extension.
#if __ARM_FEATURE_DOTPROD
namespace megdnn {
namespace aarch64 {
namespace matmul {
//! Heuristic: true when the quint8 dotprod GEMV kernel should be used
//! (non-transposed operands, N == 1, contiguous B, i.e. LDB == 1).
bool is_gemv_like_preferred_quint8(bool transposeA, bool transposeB, size_t M,
                                   size_t N, size_t K, size_t LDA, size_t LDB,
                                   size_t LDC);
//! C = A * B for asymmetric-quantized u8 inputs; N must be 1. The zero
//! points of A and B are corrected inside the kernel.
void gemv_like_quint8(const uint8_t* __restrict A, const uint8_t* __restrict B,
                      int32_t* __restrict C, size_t M, size_t N, size_t K,
                      size_t Astride, size_t Bstride, size_t Cstride,
                      uint8_t zero_point_A, uint8_t zero_point_B);
}  // namespace matmul
}  // namespace aarch64
}  // namespace megdnn
#endif
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,114 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/matrix_mul/quint8_dot/strategy.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "src/aarch64/matrix_mul/quint8_dot/strategy.h" | |||||
#include "megdnn/dtype.h" | |||||
#include "src/aarch64/matrix_mul/asm/common.h" | |||||
#include "src/aarch64/matrix_mul/quint8_dot/kernel_8x8x4.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/common/utils.h" | |||||
#if __ARM_FEATURE_DOTPROD | |||||
using namespace megdnn; | |||||
using namespace aarch64; | |||||
using namespace aarch64::matmul; | |||||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8); | |||||
void gemm_u8_8x8::pack_A(uint8_t* outptr, const uint8_t* inptr, int ldin, | |||||
int y0, int ymax, int k0, int kmax, | |||||
bool transpose) const { | |||||
if (transpose) { | |||||
matmul_8x8x4::gemm_u8_8x8_transpose_pack_helper(outptr, inptr, ldin, y0, | |||||
ymax, k0, kmax); | |||||
} else { | |||||
matmul_8x8x4::gemm_u8_8x8_interleave_pack_helper(outptr, inptr, ldin, | |||||
y0, ymax, k0, kmax); | |||||
} | |||||
} | |||||
void gemm_u8_8x8::pack_B(uint8_t* out, const uint8_t* in, int ldin, int x0, | |||||
int xmax, int k0, int kmax, bool transpose) const { | |||||
if (transpose) { | |||||
matmul_8x8x4::gemm_u8_8x8_interleave_pack_helper(out, in, ldin, x0, | |||||
xmax, k0, kmax); | |||||
} else { | |||||
matmul_8x8x4::gemm_u8_8x8_transpose_pack_helper(out, in, ldin, x0, xmax, | |||||
k0, kmax); | |||||
} | |||||
} | |||||
void gemm_u8_8x8::kern(const uint8_t* packA, const uint8_t* packB, size_t M,
                       size_t N, size_t K, dt_int32* C, size_t LDC,
                       bool is_first_k, const dt_int32*, dt_int32*) const {
    //! Dotprod matmul kernel: multiplies the packed panels of A and B into
    //! C (int32, row stride LDC). Only Quantized8Asymm inputs with
    //! QuantizedS32 output are supported.
    megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
                  A_dtype.enumv() == DTypeEnum::Quantized8Asymm &&
                  C_dtype.enumv() == DTypeEnum::QuantizedS32);
    MEGDNN_MARK_USED_VAR(C_dtype);
    size_t zero_point_A = A_dtype.param<dtype::Quantized8Asymm>().zero_point;
    size_t zero_point_B = B_dtype.param<dtype::Quantized8Asymm>().zero_point;
    constexpr size_t A_INTERLEAVE = 8;
    constexpr size_t B_INTERLEAVE = 8;
    //! Constant zero-point term zA * zB * K, computed with the original
    //! (un-padded) K before the round_up below.
    const uint32_t zAB = static_cast<uint32_t>(zero_point_A) *
                         static_cast<uint32_t>(zero_point_B) * K;
    //! K is packed to times of 4
    K = round_up<size_t>(K, 4);
    //! Bytes consumed per 8-wide (K8) / 4-wide (K4) packed panel strip.
    const int K8 = (K << 3);
    const int K4 = K * 4;
    size_t m = 0;
    //! Main loop over full 8-row blocks of A.
    for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) {
        int32_t* output = C + (m * LDC);
        size_t n = 0;
        const dt_uint8* cur_packB = packB;
        //! Full 8x8 output tiles.
        for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
            matmul_8x8x4::kern_8x8(packA, cur_packB, K, output, LDC, is_first_k,
                                   zero_point_A, zero_point_B, zAB);
            output += B_INTERLEAVE;
            cur_packB += K8;
        }
        //! Right edge: at most 4 remaining columns per step.
        for (; n < N; n += 4) {
            matmul_8x8x4::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k,
                                   std::min<size_t>(N - n, 4), zero_point_A,
                                   zero_point_B, zAB);
            output += 4;
            cur_packB += K4;
        }
        packA += K8;
    }
    //! Bottom edge: remaining rows of A, handled up to 4 at a time.
    for (; m < M; m += 4) {
        int32_t* output = C + (m * LDC);
        const dt_uint8* cur_packB = packB;
        size_t n = 0;
        for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
            matmul_8x8x4::kern_4x8(packA, cur_packB, K, output, LDC, is_first_k,
                                   std::min<size_t>(M - m, 4), zero_point_A,
                                   zero_point_B, zAB);
            output += B_INTERLEAVE;
            cur_packB += K8;
        }
        //! Bottom-right corner: partial rows and partial columns.
        for (; n < N; n += 4) {
            matmul_8x8x4::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k,
                                   std::min<size_t>(M - m, 4),
                                   std::min<size_t>(N - n, 4), zero_point_A,
                                   zero_point_B, zAB);
            output += 4;
            cur_packB += K4;
        }
        packA += K4;
    }
}
#endif | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,26 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/matrix_mul/quint8_dot/strategy.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#pragma once
#include "src/fallback/matrix_mul/gemm_common.h"
//! Only compiled when the Armv8.2 dot-product extension is available.
#if __ARM_FEATURE_DOTPROD
namespace megdnn {
namespace aarch64 {
namespace matmul {
//! Declares the dotprod gemm_u8_8x8 strategy: uint8_t inputs, int32_t
//! output, with block parameters 8, 8, 4 (K step 4, vs. 8 in the
//! non-dotprod variant; see MEGDNN_REG_GEMM_STRATEGY in gemm_common.h).
MEGDNN_REG_GEMM_STRATEGY(uint8_t, int32_t, int32_t, 8, 8, 4, false, true,
                         gemm_u8_8x8);
}  // namespace matmul
}  // namespace aarch64
}  // namespace megdnn
#endif
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,183 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/relayout/opr_impl.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "src/common/utils.h" | |||||
#include "src/common/relayout_helper.h" | |||||
#include "src/aarch64/handle.h" | |||||
#include "src/aarch64/relayout/opr_impl.h" | |||||
using namespace megdnn; | |||||
using namespace relayout; | |||||
namespace { | |||||
//! One-byte wrapper type used to instantiate the generic transpose helpers
//! for arbitrary 8-bit data.
struct TransposeByte {
    uint8_t v;
};
//! Transpose a 16x16 block of bytes: loads 16 rows from \p src (row stride
//! \p src_step bytes, post-incremented by ld1) and stores the transposed
//! 16 rows to \p dst (row stride \p dst_step bytes). The transpose is done
//! in four stages of trn1/trn2 at byte, half-word, word and double-word
//! granularity.
void trans_16x16_u8(const void* src, void* dst, const size_t src_step,
                    const size_t dst_step) {
    asm volatile(
        "\n"
        // Stage 0: load the 16 source rows into v0..v15.
        "ld1 {v0.16b}, [%[src]], %[src_step] \n"
        "ld1 {v1.16b}, [%[src]], %[src_step] \n"
        "ld1 {v2.16b}, [%[src]], %[src_step] \n"
        "ld1 {v3.16b}, [%[src]], %[src_step] \n"
        "ld1 {v4.16b}, [%[src]], %[src_step] \n"
        "ld1 {v5.16b}, [%[src]], %[src_step] \n"
        "ld1 {v6.16b}, [%[src]], %[src_step] \n"
        "ld1 {v7.16b}, [%[src]], %[src_step] \n"
        "ld1 {v8.16b}, [%[src]], %[src_step] \n"
        "ld1 {v9.16b}, [%[src]], %[src_step] \n"
        "ld1 {v10.16b}, [%[src]], %[src_step] \n"
        "ld1 {v11.16b}, [%[src]], %[src_step] \n"
        "ld1 {v12.16b}, [%[src]], %[src_step] \n"
        "ld1 {v13.16b}, [%[src]], %[src_step] \n"
        "ld1 {v14.16b}, [%[src]], %[src_step] \n"
        "ld1 {v15.16b}, [%[src]], %[src_step] \n"
        // Stage 1: interleave adjacent row pairs at byte granularity.
        "trn1 v16.16b, v0.16b, v1.16b \n"
        "trn2 v17.16b, v0.16b, v1.16b \n"
        "trn1 v18.16b, v2.16b, v3.16b \n"
        "trn2 v19.16b, v2.16b, v3.16b \n"
        "trn1 v20.16b, v4.16b, v5.16b \n"
        "trn2 v21.16b, v4.16b, v5.16b \n"
        "trn1 v22.16b, v6.16b, v7.16b \n"
        "trn2 v23.16b, v6.16b, v7.16b \n"
        "trn1 v24.16b, v8.16b, v9.16b \n"
        "trn2 v25.16b, v8.16b, v9.16b \n"
        "trn1 v26.16b, v10.16b, v11.16b \n"
        "trn2 v27.16b, v10.16b, v11.16b \n"
        "trn1 v28.16b, v12.16b, v13.16b \n"
        "trn2 v29.16b, v12.16b, v13.16b \n"
        "trn1 v30.16b, v14.16b, v15.16b \n"
        "trn2 v31.16b, v14.16b, v15.16b \n"
        // Stage 2: interleave at 16-bit granularity.
        "trn1 v0.8h, v16.8h, v18.8h \n"
        "trn2 v2.8h, v16.8h, v18.8h \n"
        "trn1 v4.8h, v20.8h, v22.8h \n"
        "trn2 v6.8h, v20.8h, v22.8h \n"
        "trn1 v8.8h, v24.8h, v26.8h \n"
        "trn2 v10.8h, v24.8h, v26.8h \n"
        "trn1 v12.8h, v28.8h, v30.8h \n"
        "trn2 v14.8h, v28.8h, v30.8h \n"
        "trn1 v1.8h, v17.8h, v19.8h \n"
        "trn2 v3.8h, v17.8h, v19.8h \n"
        "trn1 v5.8h, v21.8h, v23.8h \n"
        "trn2 v7.8h, v21.8h, v23.8h \n"
        "trn1 v9.8h, v25.8h, v27.8h \n"
        "trn2 v11.8h, v25.8h, v27.8h \n"
        "trn1 v13.8h, v29.8h, v31.8h \n"
        "trn2 v15.8h, v29.8h, v31.8h \n"
        // Stage 3: interleave at 32-bit granularity.
        "trn1 v16.4s, v0.4s, v4.4s \n"
        "trn2 v20.4s, v0.4s, v4.4s \n"
        "trn1 v24.4s, v8.4s, v12.4s \n"
        "trn2 v28.4s, v8.4s, v12.4s \n"
        "trn1 v17.4s, v1.4s, v5.4s \n"
        "trn2 v21.4s, v1.4s, v5.4s \n"
        "trn1 v25.4s, v9.4s, v13.4s \n"
        "trn2 v29.4s, v9.4s, v13.4s \n"
        "trn1 v18.4s, v2.4s, v6.4s \n"
        "trn2 v22.4s, v2.4s, v6.4s \n"
        "trn1 v26.4s, v10.4s, v14.4s \n"
        "trn2 v30.4s, v10.4s, v14.4s \n"
        "trn1 v19.4s, v3.4s, v7.4s \n"
        "trn2 v23.4s, v3.4s, v7.4s \n"
        "trn1 v27.4s, v11.4s, v15.4s \n"
        "trn2 v31.4s, v11.4s, v15.4s \n"
        // Stage 4: interleave at 64-bit granularity, producing the final
        // transposed rows back in v0..v15.
        "trn1 v0.2d, v16.2d, v24.2d \n"
        "trn2 v8.2d, v16.2d, v24.2d \n"
        "trn1 v1.2d, v17.2d, v25.2d \n"
        "trn2 v9.2d, v17.2d, v25.2d \n"
        "trn1 v2.2d, v18.2d, v26.2d \n"
        "trn2 v10.2d, v18.2d, v26.2d \n"
        "trn1 v3.2d, v19.2d, v27.2d \n"
        "trn2 v11.2d, v19.2d, v27.2d \n"
        "trn1 v4.2d, v20.2d, v28.2d \n"
        "trn2 v12.2d, v20.2d, v28.2d \n"
        "trn1 v5.2d, v21.2d, v29.2d \n"
        "trn2 v13.2d, v21.2d, v29.2d \n"
        "trn1 v6.2d, v22.2d, v30.2d \n"
        "trn2 v14.2d, v22.2d, v30.2d \n"
        "trn1 v7.2d, v23.2d, v31.2d \n"
        "trn2 v15.2d, v23.2d, v31.2d \n"
        // Stage 5: store the 16 transposed rows.
        "st1 {v0.16b}, [%[dst]], %[dst_step] \n"
        "st1 {v1.16b}, [%[dst]], %[dst_step] \n"
        "st1 {v2.16b}, [%[dst]], %[dst_step] \n"
        "st1 {v3.16b}, [%[dst]], %[dst_step] \n"
        "st1 {v4.16b}, [%[dst]], %[dst_step] \n"
        "st1 {v5.16b}, [%[dst]], %[dst_step] \n"
        "st1 {v6.16b}, [%[dst]], %[dst_step] \n"
        "st1 {v7.16b}, [%[dst]], %[dst_step] \n"
        "st1 {v8.16b}, [%[dst]], %[dst_step] \n"
        "st1 {v9.16b}, [%[dst]], %[dst_step] \n"
        "st1 {v10.16b}, [%[dst]], %[dst_step] \n"
        "st1 {v11.16b}, [%[dst]], %[dst_step] \n"
        "st1 {v12.16b}, [%[dst]], %[dst_step] \n"
        "st1 {v13.16b}, [%[dst]], %[dst_step] \n"
        "st1 {v14.16b}, [%[dst]], %[dst_step] \n"
        "st1 {v15.16b}, [%[dst]], %[dst_step] \n"
        :
        [src] "+r" (src),
        [dst] "+r" (dst)
        :
        [src_step] "r" (src_step),
        [dst_step] "r" (dst_step)
        :
        // Clobbers name d0..d31, the low halves of v0..v31.
        // NOTE(review): there is no "memory" clobber although the asm writes
        // through dst — confirm the surrounding dispatch makes this safe.
        "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
        "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "d20",
        "d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28", "d29", "d30",
        "d31");
}
} // anonymous namespace | |||||
namespace megdnn { | |||||
namespace relayout { | |||||
namespace transpose_fallback { | |||||
//! Tell the generic transpose framework to process TransposeByte data in
//! 16x16 tiles, the size handled by trans_16x16_u8.
template <>
struct transpose_traits<TransposeByte> {
    static constexpr size_t block_size = 16;
};
//! 16x16 tile transpose for TransposeByte, delegating to the NEON asm
//! kernel. Strides are in elements; TransposeByte is one byte, so they
//! equal byte strides.
template <>
void transpose_block<TransposeByte>(const TransposeByte* src,
                                    TransposeByte* dst, const size_t src_stride,
                                    const size_t dst_stride) {
    trans_16x16_u8(src, dst, src_stride, dst_stride);
}
} // namespace transpose_fallback | |||||
} // namespace relayout | |||||
} // namespace megdnn | |||||
//! Relayout entry point. A plain transpose of 1-byte elements with channel
//! count 1 takes the NEON 16x16 block-transpose fast path; everything else
//! falls through to the generic implementation.
void aarch64::RelayoutForwardImpl::exec(_megdnn_tensor_in src0,
                                        _megdnn_tensor_out dst0,
                                        Handle* src_handle) {
    check_cpu_handle(src_handle);
    TensorND src = src0, dst = dst0;
    check_layout_and_canonize(src.layout, dst.layout);
    relayout::TransposeParam trans_param;
    bool trans = relayout::is_transpose(src.layout, dst.layout, trans_param);
    //! Fast path: batched 2D byte transpose (c == 1, dtype size 1 byte).
    if (trans && trans_param.c == 1 && src0.layout.dtype.size() == 1) {
        auto sptr = static_cast<TransposeByte*>(src.raw_ptr),
             dptr = static_cast<TransposeByte*>(dst.raw_ptr);
        //! Run the byte-transpose kernel through the handle's dispatcher.
        MEGDNN_DISPATCH_CPU_KERN_OPR(
                transpose_fallback::transpose<TransposeByte>(
                        trans_param.batch, trans_param.m, trans_param.n, sptr,
                        dptr));
        return;
    }
    //! Generic path; pass the transpose parameters along when detected.
    exec_after_preprocess(src, dst, trans ? &trans_param : nullptr);
}
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,31 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/relayout/opr_impl.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#pragma once | |||||
#include "megdnn/oprs.h" | |||||
#include "src/fallback/relayout/opr_impl.h" | |||||
namespace megdnn { | |||||
namespace aarch64 { | |||||
//! Aarch64 relayout operator: adds a NEON fast path for 1-byte transposes
//! on top of the fallback implementation.
class RelayoutForwardImpl final : public fallback::RelayoutForwardImpl {
public:
    using fallback::RelayoutForwardImpl::RelayoutForwardImpl;
    void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
              Handle *src_handle) override;
    //! No mutable shared state; safe to call concurrently.
    bool is_thread_safe() const override { return true; }
};
} // namespace aarch64 | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,393 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/rotate/opr_impl.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include <cstring> | |||||
#include "src/aarch64/rotate/opr_impl.h" | |||||
#include "src/aarch64/handle.h" | |||||
#include "src/common/cv/common.h" | |||||
#include "src/common/cv/helper.h" | |||||
#include "src/common/utils.h" | |||||
namespace megdnn { | |||||
namespace megcv { | |||||
void rotate_8uc1_clockwise_16x16(const uchar *src, | |||||
uchar *dst, | |||||
size_t src_step, size_t dst_step) | |||||
{ | |||||
asm volatile ("\n" | |||||
"ld1 {v0.16b}, [%[src]], %[src_step] \n" | |||||
"ld1 {v1.16b}, [%[src]], %[src_step] \n" | |||||
"ld1 {v2.16b}, [%[src]], %[src_step] \n" | |||||
"ld1 {v3.16b}, [%[src]], %[src_step] \n" | |||||
"ld1 {v4.16b}, [%[src]], %[src_step] \n" | |||||
"ld1 {v5.16b}, [%[src]], %[src_step] \n" | |||||
"ld1 {v6.16b}, [%[src]], %[src_step] \n" | |||||
"ld1 {v7.16b}, [%[src]], %[src_step] \n" | |||||
"ld1 {v8.16b}, [%[src]], %[src_step] \n" | |||||
"ld1 {v9.16b}, [%[src]], %[src_step] \n" | |||||
"ld1 {v10.16b}, [%[src]], %[src_step] \n" | |||||
"ld1 {v11.16b}, [%[src]], %[src_step] \n" | |||||
"ld1 {v12.16b}, [%[src]], %[src_step] \n" | |||||
"ld1 {v13.16b}, [%[src]], %[src_step] \n" | |||||
"ld1 {v14.16b}, [%[src]], %[src_step] \n" | |||||
"ld1 {v15.16b}, [%[src]], %[src_step] \n" | |||||
"trn1 v16.16b, v0.16b, v1.16b \n" | |||||
"trn2 v17.16b, v0.16b, v1.16b \n" | |||||
"trn1 v18.16b, v2.16b, v3.16b \n" | |||||
"trn2 v19.16b, v2.16b, v3.16b \n" | |||||
"trn1 v20.16b, v4.16b, v5.16b \n" | |||||
"trn2 v21.16b, v4.16b, v5.16b \n" | |||||
"trn1 v22.16b, v6.16b, v7.16b \n" | |||||
"trn2 v23.16b, v6.16b, v7.16b \n" | |||||
"trn1 v24.16b, v8.16b, v9.16b \n" | |||||
"trn2 v25.16b, v8.16b, v9.16b \n" | |||||
"trn1 v26.16b, v10.16b, v11.16b \n" | |||||
"trn2 v27.16b, v10.16b, v11.16b \n" | |||||
"trn1 v28.16b, v12.16b, v13.16b \n" | |||||
"trn2 v29.16b, v12.16b, v13.16b \n" | |||||
"trn1 v30.16b, v14.16b, v15.16b \n" | |||||
"trn2 v31.16b, v14.16b, v15.16b \n" | |||||
"trn1 v0.8h, v16.8h, v18.8h \n" | |||||
"trn2 v2.8h, v16.8h, v18.8h \n" | |||||
"trn1 v4.8h, v20.8h, v22.8h \n" | |||||
"trn2 v6.8h, v20.8h, v22.8h \n" | |||||
"trn1 v8.8h, v24.8h, v26.8h \n" | |||||
"trn2 v10.8h, v24.8h, v26.8h \n" | |||||
"trn1 v12.8h, v28.8h, v30.8h \n" | |||||
"trn2 v14.8h, v28.8h, v30.8h \n" | |||||
"trn1 v1.8h, v17.8h, v19.8h \n" | |||||
"trn2 v3.8h, v17.8h, v19.8h \n" | |||||
"trn1 v5.8h, v21.8h, v23.8h \n" | |||||
"trn2 v7.8h, v21.8h, v23.8h \n" | |||||
"trn1 v9.8h, v25.8h, v27.8h \n" | |||||
"trn2 v11.8h, v25.8h, v27.8h \n" | |||||
"trn1 v13.8h, v29.8h, v31.8h \n" | |||||
"trn2 v15.8h, v29.8h, v31.8h \n" | |||||
"trn1 v16.4s, v0.4s, v4.4s \n" | |||||
"trn2 v20.4s, v0.4s, v4.4s \n" | |||||
"trn1 v24.4s, v8.4s, v12.4s \n" | |||||
"trn2 v28.4s, v8.4s, v12.4s \n" | |||||
"trn1 v17.4s, v1.4s, v5.4s \n" | |||||
"trn2 v21.4s, v1.4s, v5.4s \n" | |||||
"trn1 v25.4s, v9.4s, v13.4s \n" | |||||
"trn2 v29.4s, v9.4s, v13.4s \n" | |||||
"trn1 v18.4s, v2.4s, v6.4s \n" | |||||
"trn2 v22.4s, v2.4s, v6.4s \n" | |||||
"trn1 v26.4s, v10.4s, v14.4s \n" | |||||
"trn2 v30.4s, v10.4s, v14.4s \n" | |||||
"trn1 v19.4s, v3.4s, v7.4s \n" | |||||
"trn2 v23.4s, v3.4s, v7.4s \n" | |||||
"trn1 v27.4s, v11.4s, v15.4s \n" | |||||
"trn2 v31.4s, v11.4s, v15.4s \n" | |||||
"trn1 v0.2d, v16.2d, v24.2d \n" | |||||
"trn2 v8.2d, v16.2d, v24.2d \n" | |||||
"trn1 v1.2d, v17.2d, v25.2d \n" | |||||
"trn2 v9.2d, v17.2d, v25.2d \n" | |||||
"trn1 v2.2d, v18.2d, v26.2d \n" | |||||
"trn2 v10.2d, v18.2d, v26.2d \n" | |||||
"trn1 v3.2d, v19.2d, v27.2d \n" | |||||
"trn2 v11.2d, v19.2d, v27.2d \n" | |||||
"trn1 v4.2d, v20.2d, v28.2d \n" | |||||
"trn2 v12.2d, v20.2d, v28.2d \n" | |||||
"trn1 v5.2d, v21.2d, v29.2d \n" | |||||
"trn2 v13.2d, v21.2d, v29.2d \n" | |||||
"trn1 v6.2d, v22.2d, v30.2d \n" | |||||
"trn2 v14.2d, v22.2d, v30.2d \n" | |||||
"trn1 v7.2d, v23.2d, v31.2d \n" | |||||
"trn2 v15.2d, v23.2d, v31.2d \n" | |||||
// There is no rev128 instruction, so we use rev64 and ext to simulate it. | |||||
"rev64 v0.16b, v0.16b \n" | |||||
"rev64 v1.16b, v1.16b \n" | |||||
"rev64 v2.16b, v2.16b \n" | |||||
"rev64 v3.16b, v3.16b \n" | |||||
"rev64 v4.16b, v4.16b \n" | |||||
"rev64 v5.16b, v5.16b \n" | |||||
"rev64 v6.16b, v6.16b \n" | |||||
"rev64 v7.16b, v7.16b \n" | |||||
"rev64 v8.16b, v8.16b \n" | |||||
"rev64 v9.16b, v9.16b \n" | |||||
"rev64 v10.16b, v10.16b \n" | |||||
"rev64 v11.16b, v11.16b \n" | |||||
"rev64 v12.16b, v12.16b \n" | |||||
"rev64 v13.16b, v13.16b \n" | |||||
"rev64 v14.16b, v14.16b \n" | |||||
"rev64 v15.16b, v15.16b \n" | |||||
"ext v0.16b, v0.16b, v0.16b, #8 \n" | |||||
"ext v1.16b, v1.16b, v1.16b, #8 \n" | |||||
"ext v2.16b, v2.16b, v2.16b, #8 \n" | |||||
"ext v3.16b, v3.16b, v3.16b, #8 \n" | |||||
"ext v4.16b, v4.16b, v4.16b, #8 \n" | |||||
"ext v5.16b, v5.16b, v5.16b, #8 \n" | |||||
"ext v6.16b, v6.16b, v6.16b, #8 \n" | |||||
"ext v7.16b, v7.16b, v7.16b, #8 \n" | |||||
"ext v8.16b, v8.16b, v8.16b, #8 \n" | |||||
"ext v9.16b, v9.16b, v9.16b, #8 \n" | |||||
"ext v10.16b, v10.16b, v10.16b, #8 \n" | |||||
"ext v11.16b, v11.16b, v11.16b, #8 \n" | |||||
"ext v12.16b, v12.16b, v12.16b, #8 \n" | |||||
"ext v13.16b, v13.16b, v13.16b, #8 \n" | |||||
"ext v14.16b, v14.16b, v14.16b, #8 \n" | |||||
"ext v15.16b, v15.16b, v15.16b, #8 \n" | |||||
"st1 {v0.16b}, [%[dst]], %[dst_step] \n" | |||||
"st1 {v1.16b}, [%[dst]], %[dst_step] \n" | |||||
"st1 {v2.16b}, [%[dst]], %[dst_step] \n" | |||||
"st1 {v3.16b}, [%[dst]], %[dst_step] \n" | |||||
"st1 {v4.16b}, [%[dst]], %[dst_step] \n" | |||||
"st1 {v5.16b}, [%[dst]], %[dst_step] \n" | |||||
"st1 {v6.16b}, [%[dst]], %[dst_step] \n" | |||||
"st1 {v7.16b}, [%[dst]], %[dst_step] \n" | |||||
"st1 {v8.16b}, [%[dst]], %[dst_step] \n" | |||||
"st1 {v9.16b}, [%[dst]], %[dst_step] \n" | |||||
"st1 {v10.16b}, [%[dst]], %[dst_step] \n" | |||||
"st1 {v11.16b}, [%[dst]], %[dst_step] \n" | |||||
"st1 {v12.16b}, [%[dst]], %[dst_step] \n" | |||||
"st1 {v13.16b}, [%[dst]], %[dst_step] \n" | |||||
"st1 {v14.16b}, [%[dst]], %[dst_step] \n" | |||||
"st1 {v15.16b}, [%[dst]], %[dst_step] \n" | |||||
: | |||||
[src] "+r" (src), | |||||
[dst] "+r" (dst) | |||||
: | |||||
[src_step] "r" (src_step), | |||||
[dst_step] "r" (dst_step) | |||||
: | |||||
"v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", | |||||
"v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" | |||||
); | |||||
} | |||||
void rotate_8uc1_counterclockwise_16x16(const uchar *src, | |||||
uchar *dst, | |||||
size_t src_step, size_t dst_step) | |||||
{ | |||||
asm volatile ("\n" | |||||
"ld1 {v0.16b}, [%[src]], %[src_step] \n" | |||||
"ld1 {v1.16b}, [%[src]], %[src_step] \n" | |||||
"ld1 {v2.16b}, [%[src]], %[src_step] \n" | |||||
"ld1 {v3.16b}, [%[src]], %[src_step] \n" | |||||
"ld1 {v4.16b}, [%[src]], %[src_step] \n" | |||||
"ld1 {v5.16b}, [%[src]], %[src_step] \n" | |||||
"ld1 {v6.16b}, [%[src]], %[src_step] \n" | |||||
"ld1 {v7.16b}, [%[src]], %[src_step] \n" | |||||
"ld1 {v8.16b}, [%[src]], %[src_step] \n" | |||||
"ld1 {v9.16b}, [%[src]], %[src_step] \n" | |||||
"ld1 {v10.16b}, [%[src]], %[src_step] \n" | |||||
"ld1 {v11.16b}, [%[src]], %[src_step] \n" | |||||
"ld1 {v12.16b}, [%[src]], %[src_step] \n" | |||||
"ld1 {v13.16b}, [%[src]], %[src_step] \n" | |||||
"ld1 {v14.16b}, [%[src]], %[src_step] \n" | |||||
"ld1 {v15.16b}, [%[src]], %[src_step] \n" | |||||
"trn1 v16.16b, v0.16b, v1.16b \n" | |||||
"trn2 v17.16b, v0.16b, v1.16b \n" | |||||
"trn1 v18.16b, v2.16b, v3.16b \n" | |||||
"trn2 v19.16b, v2.16b, v3.16b \n" | |||||
"trn1 v20.16b, v4.16b, v5.16b \n" | |||||
"trn2 v21.16b, v4.16b, v5.16b \n" | |||||
"trn1 v22.16b, v6.16b, v7.16b \n" | |||||
"trn2 v23.16b, v6.16b, v7.16b \n" | |||||
"trn1 v24.16b, v8.16b, v9.16b \n" | |||||
"trn2 v25.16b, v8.16b, v9.16b \n" | |||||
"trn1 v26.16b, v10.16b, v11.16b \n" | |||||
"trn2 v27.16b, v10.16b, v11.16b \n" | |||||
"trn1 v28.16b, v12.16b, v13.16b \n" | |||||
"trn2 v29.16b, v12.16b, v13.16b \n" | |||||
"trn1 v30.16b, v14.16b, v15.16b \n" | |||||
"trn2 v31.16b, v14.16b, v15.16b \n" | |||||
"trn1 v0.8h, v16.8h, v18.8h \n" | |||||
"trn2 v2.8h, v16.8h, v18.8h \n" | |||||
"trn1 v4.8h, v20.8h, v22.8h \n" | |||||
"trn2 v6.8h, v20.8h, v22.8h \n" | |||||
"trn1 v8.8h, v24.8h, v26.8h \n" | |||||
"trn2 v10.8h, v24.8h, v26.8h \n" | |||||
"trn1 v12.8h, v28.8h, v30.8h \n" | |||||
"trn2 v14.8h, v28.8h, v30.8h \n" | |||||
"trn1 v1.8h, v17.8h, v19.8h \n" | |||||
"trn2 v3.8h, v17.8h, v19.8h \n" | |||||
"trn1 v5.8h, v21.8h, v23.8h \n" | |||||
"trn2 v7.8h, v21.8h, v23.8h \n" | |||||
"trn1 v9.8h, v25.8h, v27.8h \n" | |||||
"trn2 v11.8h, v25.8h, v27.8h \n" | |||||
"trn1 v13.8h, v29.8h, v31.8h \n" | |||||
"trn2 v15.8h, v29.8h, v31.8h \n" | |||||
"trn1 v16.4s, v0.4s, v4.4s \n" | |||||
"trn2 v20.4s, v0.4s, v4.4s \n" | |||||
"trn1 v24.4s, v8.4s, v12.4s \n" | |||||
"trn2 v28.4s, v8.4s, v12.4s \n" | |||||
"trn1 v17.4s, v1.4s, v5.4s \n" | |||||
"trn2 v21.4s, v1.4s, v5.4s \n" | |||||
"trn1 v25.4s, v9.4s, v13.4s \n" | |||||
"trn2 v29.4s, v9.4s, v13.4s \n" | |||||
"trn1 v18.4s, v2.4s, v6.4s \n" | |||||
"trn2 v22.4s, v2.4s, v6.4s \n" | |||||
"trn1 v26.4s, v10.4s, v14.4s \n" | |||||
"trn2 v30.4s, v10.4s, v14.4s \n" | |||||
"trn1 v19.4s, v3.4s, v7.4s \n" | |||||
"trn2 v23.4s, v3.4s, v7.4s \n" | |||||
"trn1 v27.4s, v11.4s, v15.4s \n" | |||||
"trn2 v31.4s, v11.4s, v15.4s \n" | |||||
"trn1 v0.2d, v16.2d, v24.2d \n" | |||||
"trn2 v8.2d, v16.2d, v24.2d \n" | |||||
"trn1 v1.2d, v17.2d, v25.2d \n" | |||||
"trn2 v9.2d, v17.2d, v25.2d \n" | |||||
"trn1 v2.2d, v18.2d, v26.2d \n" | |||||
"trn2 v10.2d, v18.2d, v26.2d \n" | |||||
"trn1 v3.2d, v19.2d, v27.2d \n" | |||||
"trn2 v11.2d, v19.2d, v27.2d \n" | |||||
"trn1 v4.2d, v20.2d, v28.2d \n" | |||||
"trn2 v12.2d, v20.2d, v28.2d \n" | |||||
"trn1 v5.2d, v21.2d, v29.2d \n" | |||||
"trn2 v13.2d, v21.2d, v29.2d \n" | |||||
"trn1 v6.2d, v22.2d, v30.2d \n" | |||||
"trn2 v14.2d, v22.2d, v30.2d \n" | |||||
"trn1 v7.2d, v23.2d, v31.2d \n" | |||||
"trn2 v15.2d, v23.2d, v31.2d \n" | |||||
"st1 {v15.16b}, [%[dst]], %[dst_step] \n" | |||||
"st1 {v14.16b}, [%[dst]], %[dst_step] \n" | |||||
"st1 {v13.16b}, [%[dst]], %[dst_step] \n" | |||||
"st1 {v12.16b}, [%[dst]], %[dst_step] \n" | |||||
"st1 {v11.16b}, [%[dst]], %[dst_step] \n" | |||||
"st1 {v10.16b}, [%[dst]], %[dst_step] \n" | |||||
"st1 {v9.16b}, [%[dst]], %[dst_step] \n" | |||||
"st1 {v8.16b}, [%[dst]], %[dst_step] \n" | |||||
"st1 {v7.16b}, [%[dst]], %[dst_step] \n" | |||||
"st1 {v6.16b}, [%[dst]], %[dst_step] \n" | |||||
"st1 {v5.16b}, [%[dst]], %[dst_step] \n" | |||||
"st1 {v4.16b}, [%[dst]], %[dst_step] \n" | |||||
"st1 {v3.16b}, [%[dst]], %[dst_step] \n" | |||||
"st1 {v2.16b}, [%[dst]], %[dst_step] \n" | |||||
"st1 {v1.16b}, [%[dst]], %[dst_step] \n" | |||||
"st1 {v0.16b}, [%[dst]], %[dst_step] \n" | |||||
: | |||||
[src] "+r" (src), | |||||
[dst] "+r" (dst) | |||||
: | |||||
[src_step] "r" (src_step), | |||||
[dst_step] "r" (dst_step) | |||||
: | |||||
"v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", | |||||
"v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" | |||||
); | |||||
} | |||||
void rotate_8uc1_clockwise(const uchar* src, uchar* dst, const size_t rows, | |||||
const size_t cols, const size_t src_step, | |||||
const size_t dst_step) { | |||||
const size_t block = 16; | |||||
(void)block; | |||||
size_t i = 0; | |||||
for (; i + block <= rows; i += block) { | |||||
size_t j = 0; | |||||
for (; j + block <= cols; j += block) { | |||||
rotate_8uc1_clockwise_16x16( | |||||
src + i * src_step + j, | |||||
dst + j * dst_step + (rows - (i + block)), src_step, | |||||
dst_step); | |||||
} | |||||
for (; j < cols; ++j) { | |||||
for (size_t k = 0; k < block; ++k) { | |||||
dst[j * dst_step + (rows - 1 - (i + k))] = | |||||
src[(i + k) * src_step + j]; | |||||
} | |||||
} | |||||
} | |||||
for (; i < rows; ++i) { | |||||
for (size_t j = 0; j < cols; ++j) { | |||||
dst[j * dst_step + (rows - 1 - i)] = src[i * src_step + j]; | |||||
} | |||||
} | |||||
} | |||||
void rotate_8uc1_counterclockwise(const uchar* src, uchar* dst, | |||||
const size_t rows, const size_t cols, | |||||
const size_t src_step, | |||||
const size_t dst_step) { | |||||
const size_t block = 16; | |||||
(void)block; | |||||
size_t i = 0; | |||||
for (; i + block <= rows; i += block) { | |||||
size_t j = 0; | |||||
for (; j + block <= cols; j += block) { | |||||
rotate_8uc1_counterclockwise_16x16( | |||||
src + i * src_step + j, | |||||
dst + (cols - (j + block)) * dst_step + i, src_step, | |||||
dst_step); | |||||
} | |||||
for (; j < cols; ++j) { | |||||
for (size_t k = 0; k < block; ++k) { | |||||
dst[(cols - 1 - j) * dst_step + (i + k)] = | |||||
src[(i + k) * src_step + j]; | |||||
} | |||||
} | |||||
} | |||||
for (; i < rows; ++i) { | |||||
for (size_t j = 0; j < cols; ++j) { | |||||
dst[(cols - 1 - j) * dst_step + i] = src[i * src_step + j]; | |||||
} | |||||
} | |||||
} | |||||
void rotate(const Mat<uchar>& src, Mat<uchar>& dst, bool clockwise) { | |||||
megdnn_assert(src.rows() == dst.cols()); | |||||
megdnn_assert(src.cols() == dst.rows()); | |||||
megdnn_assert(src.channels() == dst.channels()); | |||||
megdnn_assert(src.channels() == 1_z); | |||||
if (clockwise) { | |||||
rotate_8uc1_clockwise(src.ptr(), dst.ptr(), src.rows(), src.cols(), | |||||
src.step(), dst.step()); | |||||
} else { | |||||
rotate_8uc1_counterclockwise(src.ptr(), dst.ptr(), src.rows(), | |||||
src.cols(), src.step(), dst.step()); | |||||
} | |||||
} | |||||
} // namespace megcv | |||||
namespace aarch64 { | |||||
//! aarch64 rotate operator entry point.
//! Dispatches Uint8 single-channel inputs to the NEON megcv::rotate path,
//! anything else to the fallback implementation.
//! NOTE(review): the header declares dst as _megdnn_tensor_out while this
//! definition uses _megdnn_tensor_in — confirm the two macros alias the
//! same type so the override matches.
void RotateImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
                      _megdnn_workspace workspace) {
    using namespace megcv;
    check_exec(src.layout, dst.layout, workspace.size);
    //! rotate only supports uchar data with a single channel (layout is
    //! NHWC, so shape[3] is the channel count)
    if (dst.layout.dtype != dtype::Uint8() || src.layout.shape[3] != 1) {
        return fallback::RotateImpl::exec(src, dst, workspace);
    }
    // Rotate each image of the batch independently on the CPU dispatcher.
    MEGDNN_DISPATCH_CPU_KERN_OPR({
        for (size_t i = 0; i < src.layout.shape[0]; ++i) {
            Mat<uchar> src_mat = TensorND2Mat<uchar>(src, i);
            Mat<uchar> dst_mat = TensorND2Mat<uchar>(dst, i);
            rotate(src_mat, dst_mat, param().clockwise);
        }
    });
}
} // namespace aarch64 | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,35 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/rotate/opr_impl.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#pragma once | |||||
#include "megdnn/oprs.h" | |||||
#include "src/fallback/rotate/opr_impl.h" | |||||
namespace megdnn { | |||||
namespace aarch64 { | |||||
//! aarch64 Rotate operator. Overrides exec() with a NEON fast path for
//! Uint8 single-channel images and defers everything else to the
//! fallback implementation it inherits from.
class RotateImpl : public fallback::RotateImpl {
public:
    using fallback::RotateImpl::RotateImpl;
    void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
              _megdnn_workspace workspace) override;
    //! The NEON path rotates in place from src to dst, so no scratch
    //! workspace is ever required.
    size_t get_workspace_in_bytes(const TensorLayout&,
                                  const TensorLayout&) override {
        return 0;
    }
};
} // namespace aarch64 | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,43 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/warp_perspective/opr_impl.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "src/aarch64/warp_perspective/opr_impl.h" | |||||
#include "src/aarch64/warp_perspective/warp_perspective_cv.h" | |||||
#include "src/common/utils.h" | |||||
#include "src/common/warp_common.h" | |||||
#include "src/naive/handle.h" | |||||
namespace megdnn { | |||||
namespace aarch64 { | |||||
//! aarch64 warp-perspective entry point.
//! Takes the megcv NEON path when warp::is_cv_available accepts the
//! layouts/modes and no mat_idx is given; otherwise defers to arm_common.
//! NOTE(review): the header declares dst as _megdnn_tensor_out while this
//! definition uses _megdnn_tensor_in — confirm the macros alias the same
//! type so the override matches.
void WarpPerspectiveImpl::exec(_megdnn_tensor_in src,
                               _megdnn_tensor_in mat,
                               _megdnn_tensor_in mat_idx,
                               _megdnn_tensor_in dst,
                               _megdnn_workspace workspace)
{
    check_exec(src.layout, mat.layout, mat_idx.layout, dst.layout,
               workspace.size);
    // mat_idx.layout.ndim == 0 means no batch-index remapping was supplied,
    // which is a precondition of the cv path.
    if (warp::is_cv_available(src.layout, mat.layout, dst.layout, param().imode,
                              param().format) && !mat_idx.layout.ndim) {
        warp_perspective_cv_exec(src, mat, dst, param().border_val,
                                 param().bmode, param().imode, handle());
    } else {
        //! Use arm_common implementation
        arm_common::WarpPerspectiveImpl::exec(src, mat, mat_idx, dst, workspace);
    }
}
} // namespace aarch64 | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,30 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/warp_perspective/opr_impl.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#pragma once | |||||
#include "megdnn/oprs.h" | |||||
#include "src/arm_common/warp_perspective/opr_impl.h" | |||||
namespace megdnn { | |||||
namespace aarch64 { | |||||
//! aarch64 WarpPerspective operator. Overrides exec() with a megcv-based
//! NEON path for supported layouts/modes; all other cases go through the
//! inherited arm_common implementation.
class WarpPerspectiveImpl : public arm_common::WarpPerspectiveImpl {
public:
    using arm_common::WarpPerspectiveImpl::WarpPerspectiveImpl;
    void exec(_megdnn_tensor_in src, _megdnn_tensor_in mat,
              _megdnn_tensor_in mat_idx, _megdnn_tensor_out dst,
              _megdnn_workspace workspace) override;
};
} // namespace aarch64 | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,257 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/warp_perspective/warp_perspective_cv.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "src/aarch64/handle.h" | |||||
#include "src/aarch64/warp_perspective/warp_perspective_cv.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/common/cv/common.h" | |||||
#include "src/common/cv/helper.h" | |||||
#include "src/common/cv/interp_helper.h" | |||||
#include "src/common/utils.h" | |||||
#include "src/common/warp_common.h" | |||||
using namespace megdnn; | |||||
using namespace aarch64; | |||||
using namespace megcv; | |||||
using namespace warp; | |||||
namespace { | |||||
constexpr size_t BLOCK_SZ = 32u; | |||||
//! Compute one BLOCK_SZ-sized tile of a perspective warp, NEON-accelerated.
//!
//! \tparam T      element type (uchar or float)
//! \tparam imode  interpolation mode (NEAREST takes the first branch,
//!                all other modes build the sub-pixel table A as well)
//! \tparam bmode  border handling mode, forwarded to remap
//! \tparam CH     channel count
//! \param trans   row-major 3x3 perspective matrix
//! \param task_id linear tile index; decoded below into the (x, y) origin
//!                of this tile in the destination image
//!
//! The function fills XY with per-pixel integer source coordinates (and A
//! with INTER_TAB_SIZE sub-pixel indices in the non-nearest case), then
//! delegates the actual resampling to remap<>.
template <typename T, InterpolationMode imode, BorderMode bmode, size_t CH>
void warp_perspective_cv(const Mat<T>& src, Mat<T>& dst, const float* trans,
                         const float border_value, size_t task_id) {
    // no extra padding
    double M[9];
    rep(i, 9) M[i] = trans[i];
    T bvalue[3] = {(T)border_value, (T)border_value, (T)border_value};
    size_t x1, y1, width = dst.cols(), height = dst.rows();
    // Derive the tile shape: roughly BLOCK_SZ*BLOCK_SZ pixels per tile,
    // clamped to the image dimensions.
    size_t BLOCK_SZ_H = std::min(BLOCK_SZ / 2, height);
    size_t BLOCK_SZ_W = std::min(BLOCK_SZ * BLOCK_SZ / BLOCK_SZ_H, width);
    BLOCK_SZ_H = std::min(BLOCK_SZ * BLOCK_SZ / BLOCK_SZ_W, height);
    // Decode task_id into this tile's top-left corner (x, y).
    size_t width_block_size = div_ceil<size_t>(width, BLOCK_SZ_W);
    size_t y = (task_id / width_block_size) * BLOCK_SZ_H;
    size_t x = (task_id % width_block_size) * BLOCK_SZ_W;
    // start invoke
    // XY: interleaved (x, y) short coordinates; A: sub-pixel table indices.
    short XY[BLOCK_SZ * BLOCK_SZ * 2], A[BLOCK_SZ * BLOCK_SZ];
    // Broadcast matrix entries (and 2x-strided variants for processing two
    // lanes per step) into f64 vectors.
    float64x2_t vM6 = vdupq_n_f64(M[6]);
    float64x2_t vM0 = vdupq_n_f64(M[0]);
    float64x2_t vM3 = vdupq_n_f64(M[3]);
    float64x2_t v2M6 = vdupq_n_f64(M[6] * 2);
    float64x2_t v2M0 = vdupq_n_f64(M[0] * 2);
    float64x2_t v2M3 = vdupq_n_f64(M[3] * 2);
    float64x2_t v4f = vdupq_n_f64(4);
    float64x2_t v1f = vdupq_n_f64(1);
    float64x2_t v0f = vdupq_n_f64(0);
    float64x2_t vTABLE_SIZE = vdupq_n_f64(INTER_TAB_SIZE);
    float64x2_t vmin = vdupq_n_f64((double)INT_MIN);
    float64x2_t vmax = vdupq_n_f64((double)INT_MAX);
    int32x4_t vtabmask = vdupq_n_s32(INTER_TAB_SIZE - 1);
    // Actual tile extent (edge tiles may be smaller).
    size_t bw = std::min(BLOCK_SZ_W, width - x);
    size_t bh = std::min(BLOCK_SZ_H, height - y);  // height
    Mat<short> _XY(bh, bw, 2, XY);
    Mat<T> dpart(dst, y, bh, x, bw);
    for (y1 = 0; y1 < bh; y1++) {
        short* xy = XY + y1 * bw * 2;
        // Row-start homogeneous coordinates for destination row (y + y1).
        double X0 = M[0] * x + M[1] * (y + y1) + M[2];
        double Y0 = M[3] * x + M[4] * (y + y1) + M[5];
        double W0 = M[6] * x + M[7] * (y + y1) + M[8];
        float64x2_t vW0 = vdupq_n_f64(W0);
        float64x2_t vidx = {0.f, 1.f};
        float64x2_t vX0 = vdupq_n_f64(X0);
        float64x2_t vY0 = vdupq_n_f64(Y0);
        if (imode == IMode::NEAREST) {
            // Vector loop: 4 pixels per iteration as two f64x2 lane pairs.
            for (x1 = 0; x1 + 4 <= bw; x1 += 4) {
                // w = 1 / W, with W == 0 mapped to 0 via the compare+select.
                float64x2_t vw0 = vaddq_f64(vW0, vmulq_f64(vM6, vidx));
                float64x2_t vw1 = vaddq_f64(vw0, v2M6);
                vw0 = vbitq_f64(vdivq_f64(v1f, vw0), v0f, vceqq_f64(vw0, v0f));
                vw1 = vbitq_f64(vdivq_f64(v1f, vw1), v0f, vceqq_f64(vw1, v0f));
                // fX = clamp((X0 + M[0]*x1) * w, INT_MIN, INT_MAX)
                float64x2_t vtmp0 = vmlaq_f64(vX0, vM0, vidx);
                float64x2_t vtmp1 = vaddq_f64(vtmp0, v2M0);
                float64x2_t vfx0 = vmulq_f64(vtmp0, vw0);
                float64x2_t vfx1 = vmulq_f64(vtmp1, vw1);
                vfx0 = vmaxq_f64(vminq_f64(vfx0, vmax), vmin);
                vfx1 = vmaxq_f64(vminq_f64(vfx1, vmax), vmin);
                // Same for fY.
                vtmp0 = vmlaq_f64(vY0, vM3, vidx);
                vtmp1 = vaddq_f64(vtmp0, v2M3);
                float64x2_t vfy0 = vmulq_f64(vtmp0, vw0);
                float64x2_t vfy1 = vmulq_f64(vtmp1, vw1);
                vfy0 = vmaxq_f64(vminq_f64(vfy0, vmax), vmin);
                vfy1 = vmaxq_f64(vminq_f64(vfy1, vmax), vmin);
                // Round-to-nearest, narrow s64 -> s32 -> s16, store interleaved.
                int32x2_t vx0 = vqmovn_s64(vcvtaq_s64_f64(vfx0));
                int32x2_t vx1 = vqmovn_s64(vcvtaq_s64_f64(vfx1));
                int32x2_t vy0 = vqmovn_s64(vcvtaq_s64_f64(vfy0));
                int32x2_t vy1 = vqmovn_s64(vcvtaq_s64_f64(vfy1));
                int32x4_t vx = vcombine_s32(vx0, vx1);
                int32x4_t vy = vcombine_s32(vy0, vy1);
                int16x4x2_t ret = {{vqmovn_s32(vx), vqmovn_s32(vy)}};
                vst2_s16(xy + x1 * 2, ret);
                vidx = vaddq_f64(vidx, v4f);
            }
            // Scalar tail for the remaining < 4 pixels.
            for (; x1 < bw; x1++) {
                double W = W0 + M[6] * x1;
                W = W ? 1. / W : 0;
                double fX = std::max(
                        (double)INT_MIN,
                        std::min((double)INT_MAX, (X0 + M[0] * x1) * W));
                double fY = std::max(
                        (double)INT_MIN,
                        std::min((double)INT_MAX, (Y0 + M[3] * x1) * W));
                int X = saturate_cast<int>(fX);
                int Y = saturate_cast<int>(fY);
                xy[x1 * 2] = saturate_cast<short>(X);
                xy[x1 * 2 + 1] = saturate_cast<short>(Y);
            }
        } else {
            // Interpolating modes: coordinates carry INTER_BITS fractional
            // bits; the low bits index the interpolation table via alpha.
            short* alpha = A + y1 * bw;
            for (x1 = 0; x1 + 4 <= bw; x1 += 4) {
                // w = INTER_TAB_SIZE / W (0 where W == 0), so the integer
                // result already contains the fixed-point fraction.
                float64x2_t vw0 = vaddq_f64(vW0, vmulq_f64(vM6, vidx));
                float64x2_t vw1 = vaddq_f64(vw0, v2M6);
                vw0 = vbitq_f64(vdivq_f64(vTABLE_SIZE, vw0), v0f,
                                vceqq_f64(vw0, v0f));
                vw1 = vbitq_f64(vdivq_f64(vTABLE_SIZE, vw1), v0f,
                                vceqq_f64(vw1, v0f));
                float64x2_t vtmp0 = vmlaq_f64(vX0, vM0, vidx);
                float64x2_t vtmp1 = vaddq_f64(vtmp0, v2M0);
                float64x2_t vfx0 = vmulq_f64(vtmp0, vw0);
                float64x2_t vfx1 = vmulq_f64(vtmp1, vw1);
                vfx0 = vmaxq_f64(vminq_f64(vfx0, vmax), vmin);
                vfx1 = vmaxq_f64(vminq_f64(vfx1, vmax), vmin);
                vtmp0 = vmlaq_f64(vY0, vM3, vidx);
                vtmp1 = vaddq_f64(vtmp0, v2M3);
                float64x2_t vfy0 = vmulq_f64(vtmp0, vw0);
                float64x2_t vfy1 = vmulq_f64(vtmp1, vw1);
                vfy0 = vmaxq_f64(vminq_f64(vfy0, vmax), vmin);
                vfy1 = vmaxq_f64(vminq_f64(vfy1, vmax), vmin);
                int32x2_t vx0 = vqmovn_s64(vcvtaq_s64_f64(vfx0));
                int32x2_t vx1 = vqmovn_s64(vcvtaq_s64_f64(vfx1));
                int32x2_t vy0 = vqmovn_s64(vcvtaq_s64_f64(vfy0));
                int32x2_t vy1 = vqmovn_s64(vcvtaq_s64_f64(vfy1));
                int32x4_t vx = vcombine_s32(vx0, vx1);
                int32x4_t vy = vcombine_s32(vy0, vy1);
                // Integer part (>> INTER_BITS) goes into XY ...
                int16x4x2_t ret = {{vqshrn_n_s32(vx, INTER_BITS),
                                    vqshrn_n_s32(vy, INTER_BITS)}};
                vst2_s16(xy + x1 * 2, ret);
                vidx = vaddq_f64(vidx, v4f);
                // ... and the fractional part (y*TAB + x) into alpha.
                vx = vandq_s32(vx, vtabmask);
                vy = vandq_s32(vy, vtabmask);
                vst1_s16(&alpha[x1],
                         vqmovn_s32(vmlaq_n_s32(vx, vy, INTER_TAB_SIZE)));
            }
            // Scalar tail, mirroring the vector loop above.
            for (; x1 < bw; x1++) {
                double W = W0 + M[6] * x1;
                W = W ? INTER_TAB_SIZE / W : 0;
                double fX = std::max(
                        (double)INT_MIN,
                        std::min((double)INT_MAX, (X0 + M[0] * x1) * W));
                double fY = std::max(
                        (double)INT_MIN,
                        std::min((double)INT_MAX, (Y0 + M[3] * x1) * W));
                int X = saturate_cast<int>(fX);
                int Y = saturate_cast<int>(fY);
                xy[x1 * 2] = saturate_cast<short>(X >> INTER_BITS);
                xy[x1 * 2 + 1] = saturate_cast<short>(Y >> INTER_BITS);
                alpha[x1] =
                        (short)((Y & (INTER_TAB_SIZE - 1)) * INTER_TAB_SIZE +
                                (X & (INTER_TAB_SIZE - 1)));
            }
        }
    }
    // Resample the tile through the coordinate/alpha maps built above.
    Mat<ushort> _matA(bh, bw, 1, (ushort*)(A));
    remap<T, imode, bmode, CH, RemapVec<T, CH>>(src, dpart, _XY, _matA, bvalue);
}
} // anonymous namespace | |||||
//! Dispatch a batched NHWC perspective warp onto the multithreaded CPU
//! kernel. Work is split into batch * parallelism_batch independent tasks;
//! each task processes one BLOCK_SZ tile of one image via
//! warp_perspective_cv<> above. Supports Float32 and Uint8 with 1/2/3
//! channels; anything else throws.
void megdnn::aarch64::warp_perspective_cv_exec(
        _megdnn_tensor_in src, _megdnn_tensor_in trans, _megdnn_tensor_in dst,
        float border_value, BorderMode bmode, InterpolationMode imode,
        Handle* handle) {
    // NHWC layout: [batch, height, width, channel].
    size_t ch = dst.layout[3];
    size_t width = dst.layout[2];
    size_t height = dst.layout[1];
    const size_t batch = dst.layout.shape[0];
    // Recompute the tile grid exactly as warp_perspective_cv does, so that
    // task ids decode to the same tile origins.
    size_t BLOCK_SZ_H = std::min(BLOCK_SZ / 2, height);
    size_t BLOCK_SZ_W = std::min(BLOCK_SZ * BLOCK_SZ / BLOCK_SZ_H, width);
    BLOCK_SZ_H = std::min(BLOCK_SZ * BLOCK_SZ / BLOCK_SZ_W, height);
    size_t parallelism_batch = div_ceil<size_t>(height, BLOCK_SZ_H) *
                               div_ceil<size_t>(width, BLOCK_SZ_W);
    megdnn_assert(ch == 1 || ch == 3 || ch == 2,
                  "unsupported src channel: %zu, avaiable channel size: 1/2/3",
                  ch);
    // trans holds one 3x3 float matrix per batch image.
    const float* trans_ptr = trans.ptr<dt_float32>();
    if (dst.layout.dtype.enumv() == DTypeEnum::Float32) {
// The cb macro builds a capture-by-value task lambda for one
// (imode, bmode, ch) combination and hands it to the multithread
// dispatcher; DISPATCH_IMODE expands cb for every supported combination.
#define cb(_imode, _bmode, _ch)                                              \
    auto task = [src, trans_ptr, dst, border_value, parallelism_batch](      \
                        size_t index, size_t) {                              \
        size_t batch_id = index / parallelism_batch;                         \
        size_t task_id = index % parallelism_batch;                          \
        Mat<float> src_mat = TensorND2Mat<float>(src, batch_id);             \
        Mat<float> dst_mat = TensorND2Mat<float>(dst, batch_id);             \
        const float* task_trans_ptr = trans_ptr + batch_id * 3 * 3;          \
        warp_perspective_cv<float MEGDNN_COMMA _imode MEGDNN_COMMA _bmode    \
                                    MEGDNN_COMMA _ch>(                       \
                src_mat MEGDNN_COMMA const_cast<Mat<float>&>(dst_mat)        \
                        MEGDNN_COMMA task_trans_ptr MEGDNN_COMMA border_value, \
                task_id);                                                    \
    };                                                                       \
    MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN(                                   \
            static_cast<naive::HandleImpl*>(handle), batch* parallelism_batch, \
            task);
        DISPATCH_IMODE(imode, bmode, ch, cb)
#undef cb
    } else if (dst.layout.dtype.enumv() == DTypeEnum::Uint8) {
// Same dispatch as above, instantiated for uchar pixels.
#define cb(_imode, _bmode, _ch)                                              \
    auto task = [src, trans_ptr, dst, border_value, parallelism_batch](      \
                        size_t index, size_t) {                              \
        size_t batch_id = index / parallelism_batch;                         \
        size_t task_id = index % parallelism_batch;                          \
        Mat<uchar> src_mat = TensorND2Mat<uchar>(src, batch_id);             \
        Mat<uchar> dst_mat = TensorND2Mat<uchar>(dst, batch_id);             \
        const float* task_trans_ptr = trans_ptr + batch_id * 3 * 3;          \
        warp_perspective_cv<uchar MEGDNN_COMMA _imode MEGDNN_COMMA _bmode    \
                                    MEGDNN_COMMA _ch>(                       \
                src_mat MEGDNN_COMMA const_cast<Mat<uchar>&>(dst_mat)        \
                        MEGDNN_COMMA task_trans_ptr MEGDNN_COMMA border_value, \
                task_id);                                                    \
    };                                                                       \
    MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN(                                   \
            static_cast<naive::HandleImpl*>(handle), batch* parallelism_batch, \
            task);
        DISPATCH_IMODE(imode, bmode, ch, cb)
#undef cb
    } else {
        megdnn_throw(
                megdnn_mangle("Unsupported datatype of WarpAffine optr."));
    }
}
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,32 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/warp_perspective/warp_perspective_cv.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
// Fix: this header was missing an include guard (#pragma once), unlike
// the sibling headers in this backend; the doc comment also named the
// wrong function (warp_perspective_cv instead of warp_perspective_cv_exec).
#pragma once
#include <megdnn/oprs.h>
#include "src/common/cv/helper.h"
namespace megdnn {
namespace aarch64 {
/**
 * \fn warp_perspective_cv_exec
 * \brief Used if the format is NHWC, transfer from megcv
 */
void warp_perspective_cv_exec(_megdnn_tensor_in src, _megdnn_tensor_in trans,
                              _megdnn_tensor_in dst, float border_value,
                              param::WarpPerspective::BorderMode border_mode,
                              param::WarpPerspective::InterpolationMode imode,
                              Handle* handle);
}  // namespace aarch64
}  // namespace megdnn
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,380 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/conv_bias/direct/multi_thread_common.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "src/arm_common/conv_bias/direct/multi_thread_common.h" | |||||
#include "src/arm_common/conv_bias/postprocess_helper.h" | |||||
#include "src/fallback/matrix_mul/opr_impl.h" | |||||
using namespace megdnn; | |||||
using namespace arm_common; | |||||
namespace { | |||||
bool need_dst_copy( | |||||
const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { | |||||
auto align = param.src_type.enumv() == DTypeEnum::Float32 ? 4 : 8; | |||||
return param.osz[1] % align; | |||||
} | |||||
bool need_src_copy( | |||||
const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { | |||||
if (param.filter_meta.padding[0] || param.filter_meta.padding[1]) { | |||||
return true; | |||||
} | |||||
return need_dst_copy(param); | |||||
} | |||||
void get_rectified_size( | |||||
const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, | |||||
size_t IH, size_t IW, size_t OH, size_t OW, size_t FH, size_t FW, | |||||
size_t PH, size_t PW, size_t& IH2, size_t& IW2, size_t& OW2) { | |||||
MEGDNN_MARK_USED_VAR(PW); | |||||
MEGDNN_MARK_USED_VAR(PH); | |||||
auto&& fm = param.filter_meta; | |||||
auto SW = fm.stride[1]; | |||||
auto Align = param.src_type.enumv() == DTypeEnum::Float32 ? 3 : 7; | |||||
OW2 = (OW + Align) & ~Align; | |||||
IH2 = SW * OH + FH - SW; | |||||
IW2 = SW * OW2 + FW - SW; | |||||
// Because stride is 2, sometimes IW == IW2+1. Do a max update to | |||||
// handle this case. | |||||
IH2 = std::max(IH2, IH); | |||||
IW2 = std::max(IW2, IW); | |||||
} | |||||
} // namespace | |||||
//! Workspace layout for the direct (non-strided) kernel:
//! part0 = copied/padded source, part1 = flipped filter (only when the
//! filter must be flipped). Sizes depend on the parallelization scheme:
//! with m_large_group each thread owns one reusable slot, otherwise every
//! (batch, group) combination gets its own slot.
template <typename io_ctype, typename compute_ctype>
WorkspaceBundle MultithreadDirectConvCommon<io_ctype, compute_ctype>::get_bundle(
        const ConvBiasImpl::NCBKernSizeParam& param, bool m_large_group) {
    auto&& fm = param.filter_meta;
    size_t nr_threads = param.nr_threads;
    size_t group = fm.group, batch = param.n;
    // Input extents including padding on both sides.
    size_t IH2 = param.isz[0] + 2 * fm.padding[0];
    size_t IW2 = param.isz[1] + 2 * fm.padding[1];
    // part0: copied src
    // part1: copied filter
    size_t part0, part1;
    if (fm.padding[0] == 0 && fm.padding[1] == 0) {
        //! only the last plane need to be copied, add 16 Byte extra space in
        //! case of invalid read and write
        part0 = (param.isz[0] * param.isz[1]) * sizeof(io_ctype) + 16;
    } else if (m_large_group) {
        //! Serial in group, each thread process one group, parallel by group
        part0 = (IH2 * IW2 * fm.icpg * nr_threads) * sizeof(io_ctype) + 16;
    } else {
        //! Parallel in group, Then should copy every inputs to workspace
        part0 = (IH2 * IW2 * fm.icpg * group * batch) * sizeof(io_ctype) + 16;
    }
    if (param.filter_meta.should_flip) {
        if (m_large_group) {
            //! Serial in group, each thread has own workspace and then reuse
            part1 = fm.spatial[0] * fm.spatial[1] * fm.ocpg * fm.icpg *
                    nr_threads * sizeof(io_ctype);
        } else {
            // One flipped copy per group, shared across the batch.
            part1 = fm.spatial[0] * fm.spatial[1] * fm.ocpg * fm.icpg * group *
                    sizeof(io_ctype);
        }
    } else {
        // Cross-correlation kernels need no flipped copy.
        part1 = 0;
    }
    return {nullptr, {part0, part1}};
}
//! Workspace layout for the strided kernel: slot 0 holds the rectified
//! (padded/aligned) source copy, slot 1 the aligned destination bounce
//! buffer; either may be empty when the corresponding copy is unneeded.
template <typename io_ctype, typename compute_ctype>
WorkspaceBundle
MultithreadDirectConvCommon<io_ctype, compute_ctype>::get_bundle_stride(
        const ConvBiasImpl::NCBKernSizeParam& param, bool m_large_group) {
    // Unpacks N/IC/IH/IW/OC/OH/OW/FH/FW/SH/SW/PH/PW locals from param.
    UNPACK_CONV_F32_NCB_KERN_SIZES(param);
    MEGDNN_MARK_USED_VAR(N);
    MEGDNN_MARK_USED_VAR(OC);
    MEGDNN_MARK_USED_VAR(SH);
    MEGDNN_MARK_USED_VAR(SW);
    auto&& fm = param.filter_meta;
    size_t nr_threads = param.nr_threads;
    size_t group = fm.group, batch = param.n;
    size_t IH2, IW2, OW2;
    get_rectified_size(param, IH, IW, OH, OW, FH, FW, PH, PW, IH2, IW2, OW2);
    size_t src_size = 0, dst_size = 0;
    // src_size: copied src
    // dst_size: copied dst
    if (need_src_copy(param)) {
        // Per-thread slots when serial-in-group, otherwise one slot per
        // (batch, group) combination.
        src_size = m_large_group
                           ? IC * IH2 * IW2 * sizeof(io_ctype) * nr_threads
                           : IC * IH2 * IW2 * sizeof(io_ctype) * group * batch;
    };
    if (need_dst_copy(param)) {
        //! add 16 Byte extra space in case of invalid read and write
        dst_size = OH * OW2 * sizeof(io_ctype) * nr_threads + 16;
    }
    return {nullptr, {src_size, dst_size}};
}
//! Process one output channel weight flip
//! Flips one output channel's FH x FW filter planes both vertically and
//! horizontally (true-convolution kernel from a cross-correlation one) and
//! writes the result into workspace bundle slot 1.
template <typename io_ctype, typename compute_ctype>
void MultithreadDirectConvCommon<io_ctype, compute_ctype>::weight_flip_kern(
        WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param,
        const ConvBiasImpl::NCBKernIndex& ncb_index,
        const CpuNDRange& workspace_ids) {
    size_t FH = kern_param.filter_meta.spatial[0];
    size_t FW = kern_param.filter_meta.spatial[1];
    size_t IC = kern_param.filter_meta.icpg;
    size_t OC = kern_param.filter_meta.ocpg;
    //! Used for get the workspace offset
    size_t workspace_group_id = workspace_ids[0], channel_id = workspace_ids[2],
           group_id = ncb_index.ndrange_id[0];
    // Source filter for this (group, output channel).
    const io_ctype* filter =
            kern_param.filter<io_ctype>(group_id) + channel_id * FH * FW * IC;
    bundle.set(kern_param.workspace_ptr);
    // Destination slot in workspace part 1 (see get_bundle).
    io_ctype* filter_flip =
            static_cast<io_ctype*>(bundle.get(1)) +
            (workspace_group_id * IC * OC + channel_id * IC) * FH * FW;
    rep(ic, IC) {
        const io_ctype* filter_plane = filter + ic * FH * FW;
        io_ctype* filter_flip_plane = filter_flip + ic * FH * FW;
        // 180-degree rotation: (fh, fw) <- (FH-1-fh, FW-1-fw).
        rep(fh, FH) rep(fw, FW) {
            filter_flip_plane[fh * FW + fw] =
                    filter_plane[(FH - fh - 1) * FW + (FW - fw - 1)];
        }
    }
}
//! Process one input channel copy padding | |||||
template <typename io_ctype, typename compute_ctype> | |||||
void MultithreadDirectConvCommon<io_ctype, compute_ctype>::copy_padding_kern( | |||||
WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param, | |||||
const ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
const CpuNDRange& workspace_ids) { | |||||
size_t IH = kern_param.isz[0]; | |||||
size_t IW = kern_param.isz[1]; | |||||
size_t IC = kern_param.filter_meta.icpg; | |||||
size_t PH = kern_param.filter_meta.padding[0]; | |||||
size_t PW = kern_param.filter_meta.padding[1]; | |||||
size_t IH2 = IH + 2 * PH; | |||||
size_t IW2 = IW + 2 * PW; | |||||
size_t padding_group_size = IH2 * IW2 * IC; | |||||
size_t N = kern_param.n; | |||||
size_t GROUP = kern_param.filter_meta.group; | |||||
bundle.set(kern_param.workspace_ptr); | |||||
//! Used for get the workspace offset | |||||
size_t workspace_group_id = workspace_ids[0], | |||||
workspace_batch_id = workspace_ids[1], channel_id = workspace_ids[2]; | |||||
size_t batch_id = ncb_index.ndrange_id[1], | |||||
group_id = ncb_index.ndrange_id[0]; | |||||
const io_ctype* sptr = static_cast<const io_ctype*>( | |||||
kern_param.src<io_ctype>(batch_id, group_id, channel_id)); | |||||
if (PH > 0 || PW > 0) { | |||||
//! copy to sptr_base to eliminate padding effect | |||||
io_ctype* sptr_base = static_cast<io_ctype*>(bundle.get(0)) + | |||||
workspace_group_id * padding_group_size + | |||||
workspace_batch_id * GROUP * padding_group_size + | |||||
channel_id * IH2 * IW2; | |||||
std::memset(sptr_base, 0, sizeof(io_ctype) * IH2 * IW2); | |||||
rep(ih, IH) { | |||||
std::memcpy(sptr_base + (ih + PH) * IW2 + PW, sptr + ih * IW, | |||||
sizeof(io_ctype) * IW); | |||||
} | |||||
} else if (batch_id + 1 == N && channel_id + 1 == IC && | |||||
group_id + 1 == GROUP) { | |||||
//! copy last plane | |||||
io_ctype* sptr_last_c = static_cast<io_ctype*>(bundle.get(0)); | |||||
std::memcpy(sptr_last_c, sptr, sizeof(io_ctype) * IH2 * IW2); | |||||
} | |||||
}; | |||||
//! Process one input channel copy padding | |||||
template <typename io_ctype, typename compute_ctype> | |||||
void MultithreadDirectConvCommon<io_ctype, compute_ctype>:: | |||||
copy_padding_kern_stride(WorkspaceBundle bundle, | |||||
const ConvBiasImpl::NCBKernParam& kern_param, | |||||
const ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
const CpuNDRange& workspace_ids) { | |||||
size_t IH = kern_param.isz[0]; | |||||
size_t IW = kern_param.isz[1]; | |||||
size_t IC = kern_param.filter_meta.icpg; | |||||
size_t PH = kern_param.filter_meta.padding[0]; | |||||
size_t PW = kern_param.filter_meta.padding[1]; | |||||
size_t FH = kern_param.filter_meta.spatial[0]; | |||||
size_t FW = kern_param.filter_meta.spatial[1]; | |||||
size_t OW = kern_param.osz[1]; | |||||
size_t OH = kern_param.osz[0]; | |||||
size_t IH2, IW2, OW2; | |||||
size_t GROUP = kern_param.filter_meta.group; | |||||
get_rectified_size(kern_param, IH, IW, OH, OW, FH, FW, PH, PW, IH2, IW2, OW2); | |||||
size_t padding_group_size = IH2 * IW2 * IC; | |||||
bundle.set(kern_param.workspace_ptr); | |||||
//! Used for get the workspace offset | |||||
size_t workspace_group_id = workspace_ids[0], | |||||
workspace_batch_id = workspace_ids[1]; | |||||
size_t channel_id = workspace_ids[2], batch_id = ncb_index.ndrange_id[1], | |||||
group_id = ncb_index.ndrange_id[0]; | |||||
const io_ctype* sptr = static_cast<const io_ctype*>( | |||||
kern_param.src<io_ctype>(batch_id, group_id, channel_id)); | |||||
if (need_src_copy(kern_param)) { | |||||
//! copy to sptr_base to eliminate padding effect | |||||
io_ctype* sptr_base = static_cast<io_ctype*>(bundle.get(0)) + | |||||
workspace_group_id * padding_group_size + | |||||
workspace_batch_id * GROUP * padding_group_size + | |||||
channel_id * IH2 * IW2; | |||||
std::memset(sptr_base, 0, sizeof(io_ctype) * IH2 * IW2); | |||||
rep(ih, IH) { | |||||
std::memcpy(sptr_base + (ih + PH) * IW2 + PW, sptr + ih * IW, | |||||
sizeof(io_ctype) * IW); | |||||
} | |||||
} | |||||
}; | |||||
//! compute one output channel | |||||
template <typename io_ctype, typename compute_ctype> | |||||
void MultithreadDirectConvCommon<io_ctype, compute_ctype>::do_conv_kern( | |||||
WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param, | |||||
const ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
const kern_direct_conv_f32& fun, const CpuNDRange& workspace_ids) { | |||||
size_t OH = kern_param.osz[0]; | |||||
size_t OW = kern_param.osz[1]; | |||||
size_t FH = kern_param.filter_meta.spatial[0]; | |||||
size_t FW = kern_param.filter_meta.spatial[1]; | |||||
size_t IC = kern_param.filter_meta.icpg; | |||||
size_t OC = kern_param.filter_meta.ocpg; | |||||
size_t PH = kern_param.filter_meta.padding[0]; | |||||
size_t PW = kern_param.filter_meta.padding[1]; | |||||
size_t IH2 = kern_param.isz[0] + 2 * PH; | |||||
size_t IW2 = kern_param.isz[1] + 2 * PW; | |||||
size_t padding_group_size = IH2 * IW2 * IC; | |||||
size_t N = kern_param.n; | |||||
size_t GROUP = kern_param.filter_meta.group; | |||||
bundle.set(kern_param.workspace_ptr); | |||||
size_t group_id = ncb_index.ndrange_id[0], | |||||
batch_id = ncb_index.ndrange_id[1]; | |||||
size_t channel_id = workspace_ids[2]; | |||||
const io_ctype* sptr = kern_param.src<io_ctype>(batch_id, group_id); | |||||
const io_ctype* filter = kern_param.filter<io_ctype>(group_id); | |||||
const io_ctype* bias_ptr = | |||||
kern_param.bias<io_ctype>(batch_id, group_id, channel_id); | |||||
io_ctype* dptr = kern_param.dst<io_ctype>(batch_id, group_id, channel_id); | |||||
//! Used for get the workspace offset | |||||
size_t workspace_batch_id = workspace_ids[1]; | |||||
size_t workspace_group_id = workspace_ids[0]; | |||||
io_ctype* sptr_base; | |||||
io_ctype* sptr_last_c; | |||||
auto fptr = | |||||
kern_param.filter_meta.should_flip | |||||
? static_cast<io_ctype*>(bundle.get(1)) + | |||||
(workspace_group_id * OC * IC + channel_id * IC) * | |||||
FH * FW | |||||
: filter + channel_id * FH * FW * IC; | |||||
if (PH > 0 || PW > 0) { | |||||
sptr_base = static_cast<io_ctype*>(bundle.get(0)) + | |||||
workspace_group_id * padding_group_size + | |||||
workspace_batch_id * GROUP * padding_group_size; | |||||
sptr_last_c = sptr_base + (IC - 1) * IH2 * IW2; | |||||
//! Last batch, last group | |||||
} else if (batch_id + 1 == N && group_id + 1 == GROUP) { | |||||
sptr_base = const_cast<io_ctype*>(sptr); | |||||
sptr_last_c = static_cast<io_ctype*>(bundle.get(0)); | |||||
} else { | |||||
sptr_base = const_cast<io_ctype*>(sptr); | |||||
sptr_last_c = sptr_base + (IC - 1) * IH2 * IW2; | |||||
} | |||||
std::memset(dptr, 0, sizeof(io_ctype) * (OH * OW)); | |||||
rep(ic, IC) { | |||||
io_ctype* sptr_cur = | |||||
(ic + 1 == IC ? sptr_last_c : sptr_base + ic * IH2 * IW2); | |||||
fun(reinterpret_cast<const compute_ctype*>(sptr_cur), | |||||
reinterpret_cast<const compute_ctype*>(fptr + ic * FH * FW), | |||||
reinterpret_cast<compute_ctype*>(dptr), IH2, IW2, OH, OW, FH, FW); | |||||
} | |||||
PostProcess<compute_ctype>::run(dptr, const_cast<io_ctype*>(bias_ptr), dptr, | |||||
kern_param.bias_mode, kern_param.nonlineMode, | |||||
kern_param.bias_type, kern_param.dst_type, 1_z, | |||||
1_z, OH, OW); | |||||
}; | |||||
//! compute one output channel | |||||
template <typename io_ctype, typename compute_ctype> | |||||
void MultithreadDirectConvCommon<io_ctype, compute_ctype>::do_conv_kern_stride( | |||||
WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param, | |||||
const ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
const kern_direct_conv_f32_stride& fun, | |||||
const CpuNDRange& workspace_ids) { | |||||
size_t IH = kern_param.isz[0]; | |||||
size_t IW = kern_param.isz[1]; | |||||
size_t OH = kern_param.osz[0]; | |||||
size_t OW = kern_param.osz[1]; | |||||
size_t FH = kern_param.filter_meta.spatial[0]; | |||||
size_t FW = kern_param.filter_meta.spatial[1]; | |||||
size_t IC = kern_param.filter_meta.icpg; | |||||
size_t PH = kern_param.filter_meta.padding[0]; | |||||
size_t PW = kern_param.filter_meta.padding[1]; | |||||
size_t IH2, IW2, OW2; | |||||
get_rectified_size(kern_param, IH, IW, OH, OW, FH, FW, PH, PW, IH2, IW2, OW2); | |||||
size_t padding_group_size = IH2 * IW2 * IC; | |||||
size_t GROUP = kern_param.filter_meta.group; | |||||
bundle.set(kern_param.workspace_ptr); | |||||
//! Used for get the workspace offset | |||||
size_t group_id = ncb_index.ndrange_id[0], | |||||
batch_id = ncb_index.ndrange_id[1]; | |||||
size_t channel_id = workspace_ids[2]; | |||||
const io_ctype* sptr = kern_param.src<io_ctype>(batch_id, group_id); | |||||
const io_ctype* fptr = | |||||
kern_param.filter<io_ctype>(group_id) + channel_id * FH * FW * IC; | |||||
const io_ctype* bias_ptr = | |||||
kern_param.bias<io_ctype>(batch_id, group_id, channel_id); | |||||
io_ctype* dptr = kern_param.dst<io_ctype>(batch_id, group_id, channel_id); | |||||
size_t workspace_batch_id = workspace_ids[1]; | |||||
size_t workspace_group_id = workspace_ids[0]; | |||||
io_ctype* sptr_base; | |||||
io_ctype* dptr_base; | |||||
if (need_src_copy(kern_param)) { | |||||
sptr_base = static_cast<io_ctype*>(bundle.get(0)) + | |||||
workspace_group_id * padding_group_size + | |||||
workspace_batch_id * GROUP * padding_group_size; | |||||
} else { | |||||
sptr_base = const_cast<io_ctype*>(sptr); | |||||
} | |||||
if (need_dst_copy(kern_param)) { | |||||
dptr_base = static_cast<io_ctype*>(bundle.get(1)) + | |||||
ncb_index.thread_id * OH * OW2; | |||||
} else { | |||||
dptr_base = dptr; | |||||
} | |||||
if (need_dst_copy(kern_param)) { | |||||
std::memset(dptr_base, 0, sizeof(io_ctype) * (OH * OW2)); | |||||
fun(reinterpret_cast<const compute_ctype*>(sptr_base), | |||||
reinterpret_cast<const compute_ctype*>(fptr), | |||||
reinterpret_cast<compute_ctype*>(dptr_base), IH2, IW2, OH, OW2, IC); | |||||
copy_plane_in_bytes(dptr, dptr_base, OH, OW * sizeof(io_ctype), | |||||
OW * sizeof(io_ctype), OW2 * sizeof(io_ctype)); | |||||
} else { | |||||
std::memset(dptr_base, 0, sizeof(io_ctype) * (OH * OW)); | |||||
fun(reinterpret_cast<const compute_ctype*>(sptr_base), | |||||
reinterpret_cast<const compute_ctype*>(fptr), | |||||
reinterpret_cast<compute_ctype*>(dptr_base), IH2, IW2, OH, OW, IC); | |||||
} | |||||
PostProcess<compute_ctype>::run(dptr, const_cast<io_ctype*>(bias_ptr), dptr, | |||||
kern_param.bias_mode, kern_param.nonlineMode, | |||||
kern_param.bias_type, kern_param.dst_type, 1_z, | |||||
1_z, OH, OW); | |||||
}; | |||||
//! Explicit instantiations for the dtype pairs used by the direct conv algos.
template class megdnn::arm_common::MultithreadDirectConvCommon<float, float>;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
//! fp16 storage with native __fp16 arithmetic (fp16-capable toolchains only).
template class megdnn::arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>;
#endif
// vim: syntax=cpp.doxygen
@@ -0,0 +1,65 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/conv_bias/direct/multi_thread_common.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#pragma once | |||||
#include "src/arm_common/conv_bias/opr_impl.h" | |||||
#include "src/fallback/matrix_mul/opr_impl.h" | |||||
namespace megdnn { | |||||
namespace arm_common { | |||||
//! Shared helper kernels (workspace sizing, weight flip, padding copy and
//! the per-channel conv compute loop) used by the multi-threaded direct
//! convolution algorithms. io_ctype is the storage dtype, compute_ctype the
//! arithmetic dtype (e.g. dt_float16 storage with __fp16 compute).
template <class io_ctype, class compute_ctype>
class MultithreadDirectConvCommon {
public:
    using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam;
    using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam;
    using NCBKernIndex = fallback::ConvBiasImpl::NCBKernIndex;
    //! compute-function signature for the non-strided direct conv kernels
    using kern_direct_conv_f32 =
            std::function<void(const compute_ctype* src,
                               const compute_ctype* filter, compute_ctype* dst,
                               size_t, size_t, size_t, size_t, size_t, size_t)>;
    //! compute-function signature for the strided direct conv kernels
    using kern_direct_conv_f32_stride = std::function<void(
            const compute_ctype* src, const compute_ctype* filter,
            compute_ctype* dst, size_t, size_t, size_t, size_t, size_t)>;
    //! workspace bundle for the non-strided kernels
    static WorkspaceBundle get_bundle(const NCBKernSizeParam& param,
                                      bool m_large_group);
    //! workspace bundle for the strided kernels
    static WorkspaceBundle get_bundle_stride(const NCBKernSizeParam& param,
                                             bool m_large_group);
    //! flip one output channel's weights into workspace (used when
    //! filter_meta.should_flip is set)
    static void weight_flip_kern(WorkspaceBundle bundle,
                                 const NCBKernParam& kern_param,
                                 const NCBKernIndex& ncb_index,
                                 const CpuNDRange& workspace_ids);
    //! stage one input channel into the zero-padded workspace
    static void copy_padding_kern(WorkspaceBundle bundle,
                                  const NCBKernParam& kern_param,
                                  const NCBKernIndex& ncb_index,
                                  const CpuNDRange& workspace_ids);
    //! strided variant of copy_padding_kern
    static void copy_padding_kern_stride(WorkspaceBundle bundle,
                                         const NCBKernParam& kern_param,
                                         const NCBKernIndex& ncb_index,
                                         const CpuNDRange& workspace_ids);
    //! compute one output channel via `fun`, then run post-processing
    static void do_conv_kern(WorkspaceBundle bundle,
                             const NCBKernParam& kern_param,
                             const NCBKernIndex& ncb_index,
                             const kern_direct_conv_f32& fun,
                             const CpuNDRange& workspace_ids);
    //! strided variant of do_conv_kern
    static void do_conv_kern_stride(WorkspaceBundle bundle,
                                    const NCBKernParam& kern_param,
                                    const NCBKernIndex& ncb_index,
                                    const kern_direct_conv_f32_stride& fun,
                                    const CpuNDRange& workspace_ids);
};
} // namespace arm_common | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,561 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/conv_bias/f16/algos.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "src/arm_common/conv_bias/f16/algos.h" | |||||
#include "src/arm_common/conv_bias/direct/multi_thread_common.h" | |||||
#include "src/arm_common/conv_bias/f16/direct.h" | |||||
#include "src/arm_common/conv_bias/f16/do_conv_stride1.h" | |||||
#include "src/arm_common/conv_bias/f16/strategy.h" | |||||
#include "src/arm_common/conv_bias/img2col_helper.h" | |||||
#include "src/arm_common/conv_bias/postprocess_helper.h" | |||||
#include "src/common/opr_delegate.h" | |||||
#include "src/fallback/conv_bias/common.h" | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
#include "midout.h" | |||||
MIDOUT_DECL(megdnn_arm_common_winograd_fp16) | |||||
using namespace megdnn; | |||||
using namespace arm_common; | |||||
/* ======================= AlgoFP16WinogradF23 ======================== */ | |||||
//! Usability check for the F(2x2, 3x3) fp16 winograd algo: requires a usable
//! matmul, NCHW (or pre-transformed NCHW_WINOGRAD with block size 2 and
//! DEFAULT matmul format), cross-correlation, 3x3 filter, stride 1, no
//! dilation, fp16 src, and IC/OC divisible by 4.
bool ConvBiasImpl::AlgoFP16WinogradF23::usable(
        fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
        AlgoSelectionStrategy /*algo_selection_strategy*/) const {
    MEGDNN_MARK_USED_VAR(param);
    MEGDNN_MARK_USED_VAR(opr);
    MIDOUT_BEGIN(megdnn_arm_common_winograd_fp16, 0, 0) {
        using Strategy = winograd::winograd_2x3_4x4_f16;
        Strategy strategy(param.src_type, param.filter_type, param.dst_type);
        // Derive the matmul kernel param the winograd wrapper would use, so
        // the matmul algo can veto unsupported shapes up front.
        auto&& matmul_param =
                megdnn::winograd::ConvBias<Strategy>(
                        strategy, m_tile_size, param.nr_threads, param.osz[0],
                        param.osz[1], param.filter_meta.ocpg)
                        .get_matmul_kern_param(param);
        return m_matmul_algo->usable(matmul_param) &&
               (opr->param().format == param::ConvBias::Format::NCHW ||
                (opr->param().format ==
                         param::ConvBias::Format::NCHW_WINOGRAD &&
                 opr->param().output_block_size == 2 &&
                 param.winograd_matmul_format ==
                         param::MatrixMul::Format::DEFAULT)) &&
               opr->param().mode == param::ConvBias::Mode::CROSS_CORRELATION &&
               (param.filter_meta.spatial[0] == param.filter_meta.spatial[1] &&
                param.filter_meta.spatial[0] == 3) &&
               (param.filter_meta.stride[0] == param.filter_meta.stride[1] &&
                param.filter_meta.stride[0] == 1) &&
               (param.filter_meta.dilation[0] ==
                        param.filter_meta.dilation[1] &&
                param.filter_meta.dilation[0] == 1) &&
               param.compute_mode == param::ConvBias::ComputeMode::DEFAULT &&
               param.src_type.enumv() == DTypeEnum::Float16 &&
               param.filter_meta.icpg % 4 == 0 &&
               param.filter_meta.ocpg % 4 == 0;
    }
    MIDOUT_END();
    return false;
}
//! Workspace size for the F(2x2, 3x3) fp16 winograd algo, delegated to the
//! generic winograd ConvBias wrapper.
size_t ConvBiasImpl::AlgoFP16WinogradF23::get_workspace(
        fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
    MEGDNN_MARK_USED_VAR(param);
    MIDOUT_BEGIN(megdnn_arm_common_winograd_fp16, 0, 1) {
        using Strategy = winograd::winograd_2x3_4x4_f16;
        Strategy strategy(param.src_type, param.filter_type, param.dst_type);
        auto conv_bias = megdnn::winograd::ConvBias<Strategy>(
                strategy, m_tile_size, param.nr_threads, param.osz[0],
                param.osz[1], param.filter_meta.ocpg);
        return conv_bias.get_workspace_size(param, m_matmul_algo);
    }
    MIDOUT_END();
    return 0;
}
//! Build the kernel list for the F(2x2, 3x3) fp16 winograd algo.
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoFP16WinogradF23::dispatch_kerns(
        fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
    MEGDNN_MARK_USED_VAR(param);
    MIDOUT_BEGIN(megdnn_arm_common_winograd_fp16, 0, 2) {
        using Strategy = winograd::winograd_2x3_4x4_f16;
        Strategy strategy(param.src_type, param.filter_type, param.dst_type);
        return megdnn::winograd::ConvBias<Strategy>(
                       strategy, m_tile_size, param.nr_threads, param.osz[0],
                       param.osz[1], param.filter_meta.ocpg)
                .get_kerns(param, m_matmul_algo);
    }
    MIDOUT_END();
    return {};
}
/* ======================= AlgoFP16WinogradF45 ======================== */ | |||||
//! Usability check for the F(4x4, 5x5) fp16 winograd algo: requires a usable
//! matmul, NCHW (or pre-transformed NCHW_WINOGRAD with block size 4 and
//! DEFAULT matmul format), cross-correlation, 5x5 filter, stride 1, no
//! dilation, and fp16 src.
bool ConvBiasImpl::AlgoFP16WinogradF45::usable(
        fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
        AlgoSelectionStrategy /*algo_selection_strategy*/) const {
    MEGDNN_MARK_USED_VAR(param);
    MEGDNN_MARK_USED_VAR(opr);
    MIDOUT_BEGIN(megdnn_arm_common_winograd_fp16, 1, 0) {
        using Strategy = winograd::winograd_4x5_1x1_f16;
        Strategy strategy(param.src_type, param.filter_type, param.dst_type);
        // Let the matmul algo veto unsupported shapes up front.
        auto&& matmul_param =
                megdnn::winograd::ConvBias<Strategy>(
                        strategy, m_tile_size, param.nr_threads, param.osz[0],
                        param.osz[1], param.filter_meta.ocpg)
                        .get_matmul_kern_param(param);
        return m_matmul_algo->usable(matmul_param) &&
               (opr->param().format == param::ConvBias::Format::NCHW ||
                (opr->param().format ==
                         param::ConvBias::Format::NCHW_WINOGRAD &&
                 opr->param().output_block_size == 4 &&
                 param.winograd_matmul_format ==
                         param::MatrixMul::Format::DEFAULT)) &&
               opr->param().mode == param::ConvBias::Mode::CROSS_CORRELATION &&
               (param.filter_meta.spatial[0] == param.filter_meta.spatial[1] &&
                param.filter_meta.spatial[0] == 5) &&
               (param.filter_meta.stride[0] == param.filter_meta.stride[1] &&
                param.filter_meta.stride[0] == 1) &&
               (param.filter_meta.dilation[0] ==
                        param.filter_meta.dilation[1] &&
                param.filter_meta.dilation[0] == 1) &&
               param.compute_mode == param::ConvBias::ComputeMode::DEFAULT &&
               param.src_type.enumv() == DTypeEnum::Float16;
    }
    MIDOUT_END();
    return false;
}
//! Workspace size for the F(4x4, 5x5) fp16 winograd algo.
//! Consistency fix: the strategy is now constructed inside the MIDOUT region,
//! matching the F23 / F23_8x8 siblings (previously it was built outside).
size_t ConvBiasImpl::AlgoFP16WinogradF45::get_workspace(
        fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
    MEGDNN_MARK_USED_VAR(param);
    MIDOUT_BEGIN(megdnn_arm_common_winograd_fp16, 1, 1) {
        winograd::winograd_4x5_1x1_f16 strategy(
                param.src_type, param.filter_type, param.dst_type);
        return megdnn::winograd::ConvBias<winograd::winograd_4x5_1x1_f16>(
                       strategy, m_tile_size, param.nr_threads, param.osz[0],
                       param.osz[1], param.filter_meta.ocpg)
                .get_workspace_size(param, m_matmul_algo);
    }
    MIDOUT_END();
    return 0;
}
//! Build the kernel list for the F(4x4, 5x5) fp16 winograd algo.
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoFP16WinogradF45::dispatch_kerns(
        fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
    MEGDNN_MARK_USED_VAR(param);
    MIDOUT_BEGIN(megdnn_arm_common_winograd_fp16, 1, 2) {
        using Strategy = winograd::winograd_4x5_1x1_f16;
        Strategy strategy(param.src_type, param.filter_type, param.dst_type);
        return megdnn::winograd::ConvBias<Strategy>(
                       strategy, m_tile_size, param.nr_threads, param.osz[0],
                       param.osz[1], param.filter_meta.ocpg)
                .get_kerns(param, m_matmul_algo);
    }
    MIDOUT_END();
    return {};
}
/* ======================= AlgoFP16WinogradF63 ======================== */ | |||||
//! Usability check for the F(6x6, 3x3) fp16 winograd algo: requires a usable
//! matmul, NCHW (or pre-transformed NCHW_WINOGRAD with block size 6 and
//! DEFAULT matmul format), cross-correlation, 3x3 filter, stride 1, no
//! dilation, and fp16 src.
bool ConvBiasImpl::AlgoFP16WinogradF63::usable(
        fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
        AlgoSelectionStrategy /*algo_selection_strategy*/) const {
    MEGDNN_MARK_USED_VAR(param);
    MEGDNN_MARK_USED_VAR(opr);
    MIDOUT_BEGIN(megdnn_arm_common_winograd_fp16, 2, 0) {
        using Strategy = winograd::winograd_6x3_1x1_f16;
        Strategy strategy(param.src_type, param.filter_type, param.dst_type);
        // Let the matmul algo veto unsupported shapes up front.
        auto&& matmul_param =
                megdnn::winograd::ConvBias<Strategy>(
                        strategy, m_tile_size, param.nr_threads, param.osz[0],
                        param.osz[1], param.filter_meta.ocpg)
                        .get_matmul_kern_param(param);
        return m_matmul_algo->usable(matmul_param) &&
               (opr->param().format == param::ConvBias::Format::NCHW ||
                (opr->param().format ==
                         param::ConvBias::Format::NCHW_WINOGRAD &&
                 opr->param().output_block_size == 6 &&
                 param.winograd_matmul_format ==
                         param::MatrixMul::Format::DEFAULT)) &&
               opr->param().mode == param::ConvBias::Mode::CROSS_CORRELATION &&
               (param.filter_meta.spatial[0] == param.filter_meta.spatial[1] &&
                param.filter_meta.spatial[0] == 3) &&
               (param.filter_meta.stride[0] == param.filter_meta.stride[1] &&
                param.filter_meta.stride[0] == 1) &&
               (param.filter_meta.dilation[0] ==
                        param.filter_meta.dilation[1] &&
                param.filter_meta.dilation[0] == 1) &&
               param.compute_mode == param::ConvBias::ComputeMode::DEFAULT &&
               param.src_type.enumv() == DTypeEnum::Float16;
    }
    MIDOUT_END();
    return false;
}
//! Workspace size for the F(6x6, 3x3) fp16 winograd algo.
//! Consistency fix: the strategy is now constructed inside the MIDOUT region,
//! matching the F23 / F23_8x8 siblings (previously it was built outside).
size_t ConvBiasImpl::AlgoFP16WinogradF63::get_workspace(
        fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
    MEGDNN_MARK_USED_VAR(param);
    MIDOUT_BEGIN(megdnn_arm_common_winograd_fp16, 2, 1) {
        winograd::winograd_6x3_1x1_f16 strategy(
                param.src_type, param.filter_type, param.dst_type);
        return megdnn::winograd::ConvBias<winograd::winograd_6x3_1x1_f16>(
                       strategy, m_tile_size, param.nr_threads, param.osz[0],
                       param.osz[1], param.filter_meta.ocpg)
                .get_workspace_size(param, m_matmul_algo);
    }
    MIDOUT_END();
    return 0;
}
//! Build the kernel list for the F(6x6, 3x3) fp16 winograd algo.
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoFP16WinogradF63::dispatch_kerns(
        fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
    MEGDNN_MARK_USED_VAR(param);
    MIDOUT_BEGIN(megdnn_arm_common_winograd_fp16, 2, 2) {
        using Strategy = winograd::winograd_6x3_1x1_f16;
        Strategy strategy(param.src_type, param.filter_type, param.dst_type);
        return megdnn::winograd::ConvBias<Strategy>(
                       strategy, m_tile_size, param.nr_threads, param.osz[0],
                       param.osz[1], param.filter_meta.ocpg)
                .get_kerns(param, m_matmul_algo);
    }
    MIDOUT_END();
    return {};
}
/* ======================= AlgoFP16WinogradF23_8x8 ======================== */ | |||||
//! Usability check for the F(2x2, 3x3) fp16 winograd algo with MK8 packed
//! matmul: requires IC and OC divisible by 8, a usable matmul, NCHW (or
//! pre-transformed NCHW_WINOGRAD with block size 2 and MK8 matmul format),
//! cross-correlation, 3x3 filter, stride 1, no dilation, and fp16 src.
bool ConvBiasImpl::AlgoFP16WinogradF23_8x8::usable(
        fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
        AlgoSelectionStrategy /*algo_selection_strategy*/) const {
    MEGDNN_MARK_USED_VAR(param);
    MEGDNN_MARK_USED_VAR(opr);
    MIDOUT_BEGIN(megdnn_arm_common_winograd_fp16, 3, 0) {
        // MK8 packing needs whole 8-channel blocks in both IC and OC.
        if (param.filter_meta.icpg % 8 != 0 || param.filter_meta.ocpg % 8 != 0)
            return false;
        using Strategy = winograd::winograd_2x3_8x8_f16;
        Strategy strategy(param.src_type, param.filter_type, param.dst_type);
        // Let the matmul algo veto unsupported shapes up front.
        auto&& matmul_param =
                megdnn::winograd::ConvBias<Strategy,
                                           param::MatrixMul::Format::MK8>(
                        strategy, m_tile_size, param.nr_threads, param.osz[0],
                        param.osz[1], param.filter_meta.ocpg)
                        .get_matmul_kern_param(param);
        return m_matmul_algo->usable(matmul_param) &&
               (opr->param().format == param::ConvBias::Format::NCHW ||
                (opr->param().format ==
                         param::ConvBias::Format::NCHW_WINOGRAD &&
                 opr->param().output_block_size == 2 &&
                 param.winograd_matmul_format ==
                         param::MatrixMul::Format::MK8)) &&
               opr->param().mode == param::ConvBias::Mode::CROSS_CORRELATION &&
               (param.filter_meta.spatial[0] == param.filter_meta.spatial[1] &&
                param.filter_meta.spatial[0] == 3) &&
               (param.filter_meta.stride[0] == param.filter_meta.stride[1] &&
                param.filter_meta.stride[0] == 1) &&
               (param.filter_meta.dilation[0] ==
                        param.filter_meta.dilation[1] &&
                param.filter_meta.dilation[0] == 1) &&
               param.compute_mode == param::ConvBias::ComputeMode::DEFAULT &&
               param.src_type.enumv() == DTypeEnum::Float16;
    }
    MIDOUT_END();
    return false;
}
//! Workspace size for the F(2x2, 3x3) fp16 winograd algo with MK8 matmul.
size_t ConvBiasImpl::AlgoFP16WinogradF23_8x8::get_workspace(
        fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
    MEGDNN_MARK_USED_VAR(param);
    MIDOUT_BEGIN(megdnn_arm_common_winograd_fp16, 3, 1) {
        using Strategy = winograd::winograd_2x3_8x8_f16;
        Strategy strategy(param.src_type, param.filter_type, param.dst_type);
        auto conv_bias =
                megdnn::winograd::ConvBias<Strategy,
                                           param::MatrixMul::Format::MK8>(
                        strategy, m_tile_size, param.nr_threads, param.osz[0],
                        param.osz[1], param.filter_meta.ocpg);
        return conv_bias.get_workspace_size(param, m_matmul_algo);
    }
    MIDOUT_END();
    return 0;
}
//! Build the kernel list for the F(2x2, 3x3) fp16 winograd algo (MK8).
//! Fix: the MIDOUT tag previously referenced megdnn_arm_common_winograd_fp32,
//! which is not declared in this file (only ..._winograd_fp16 is, above) and
//! mis-attributed this region to the fp32 implementation.
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoFP16WinogradF23_8x8::dispatch_kerns(
        fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
    MEGDNN_MARK_USED_VAR(param);
    MIDOUT_BEGIN(megdnn_arm_common_winograd_fp16, 3, 2) {
        winograd::winograd_2x3_8x8_f16 strategy(
                param.src_type, param.filter_type, param.dst_type);
        auto winograd_impl =
                megdnn::winograd::ConvBias<winograd::winograd_2x3_8x8_f16,
                                           param::MatrixMul::Format::MK8>(
                        strategy, m_tile_size, param.nr_threads, param.osz[0],
                        param.osz[1], param.filter_meta.ocpg);
        return winograd_impl.get_kerns(param, m_matmul_algo);
    }
    MIDOUT_END();
    return {};
}
/*========================from Convolution=============================*/ | |||||
MIDOUT_DECL(megdnn_arm_common_conv_bias_fp16_kimpl) | |||||
//! Usability check for the fp16 direct conv algo: NCHW fp16 in/out, 2-D
//! non-dilated filter of height <= 7, unit stride, and spatial sizes of at
//! least 8 elements.
bool ConvBiasImpl::AlgoF16Direct::usable(
        fallback::ConvBiasImpl*, const NCBKernSizeParam& param,
        AlgoSelectionStrategy algo_selection_strategy) const {
    MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_kimpl, 0, 0) {
        auto&& fm = param.filter_meta;
        auto FH = fm.spatial[0];
        auto SH = fm.stride[0], SW = fm.stride[1];
        // The lower bounds ``param.isz[0]*param.isz[1] >= 8'' and
        // ``param.osz[0]*param.osz[1] >= 8'' exist because the kernel may
        // access up to 8 fp16 values past the end of the memory chunk.
        bool available = fm.format == param::ConvBias::Format::NCHW &&
                         param.src_type.enumv() == DTypeEnum::Float16 &&
                         param.filter_type.enumv() == DTypeEnum::Float16 &&
                         param.dst_type.enumv() == DTypeEnum::Float16 &&
                         fm.spatial_ndim == 2 && fm.dilation[0] == 1 &&
                         fm.dilation[1] == 1 &&
                         param.isz[0] * param.isz[1] >= 8 &&
                         param.osz[0] * param.osz[1] >= 8 && FH <= 7 &&
                         SH == 1 && SW == 1;
        if (algo_selection_strategy == AlgoSelectionStrategy::HEURISTIC) {
            // Only offer the variant matching the group-size heuristic.
            bool large_group = param.filter_meta.group >= param.nr_threads;
            available &= (large_group == m_large_group);
        }
        return available;
    }
    MIDOUT_END();
    return false;
}
//! Workspace size for the fp16 direct conv algo, taken from the shared
//! multi-thread direct conv workspace bundle.
size_t ConvBiasImpl::AlgoF16Direct::get_workspace(
        fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
    MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_kimpl, 0, 1) {
        return MultithreadDirectConvCommon<dt_float16, __fp16>::get_bundle(
                       param, m_large_group)
                .total_size_in_bytes();
    }
    MIDOUT_END();
    return 0;
}
//! Build the kernel list for the fp16 direct conv algo. Large-group mode
//! fuses flip/pad/conv into one kernel parallelized over (group, batch);
//! small-group mode emits separate kernels parallelized over channels too.
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF16Direct::get_kimpls(
        const NCBKernSizeParam& param) const {
    auto fm = param.filter_meta;
    size_t N = param.n;
    size_t IC = param.filter_meta.icpg;
    size_t OC = param.filter_meta.ocpg;
    size_t group = fm.group;
    WorkspaceBundle wbundle =
            MultithreadDirectConvCommon<dt_float16, __fp16>::get_bundle(
                    param, m_large_group);
    SmallVector<NCBKern> ret_kerns;
    //! When group >= nr_threads, treat it as large_group, each thread process
    //! one group for better performance
    if (m_large_group) {
        //! Channel wise conv and big groups
        // One kernel does flip + padding + conv for a whole group; workspace
        // is addressed by thread_id, so each thread has private scratch.
        auto exec_one_group = [wbundle](const NCBKernParam& kern_param,
                                        const NCBKernIndex& ncb_index) {
            auto fm = kern_param.filter_meta;
            size_t IC = fm.icpg;
            size_t OC = fm.ocpg;
            WorkspaceBundle bundle = wbundle;
            if (fm.should_flip) {
                for (size_t oc = 0; oc < OC; oc++) {
                    MultithreadDirectConvCommon<dt_float16, __fp16>::
                            weight_flip_kern(bundle, kern_param, ncb_index,
                                             {ncb_index.thread_id, 0, oc});
                }
            }
            for (size_t ic = 0; ic < IC; ic++) {
                MultithreadDirectConvCommon<dt_float16, __fp16>::
                        copy_padding_kern(bundle, kern_param, ncb_index,
                                          {ncb_index.thread_id, 0, ic});
            }
            for (size_t oc = 0; oc < OC; oc++) {
                MultithreadDirectConvCommon<dt_float16, __fp16>::do_conv_kern(
                        bundle, kern_param, ncb_index,
                        fp16::conv_bias::kern_direct_f16,
                        {ncb_index.thread_id, 0, oc});
            }
        };
        ret_kerns.push_back({exec_one_group, {group, N, 1_z}});
    } else {
        // Separate stages, each with its own parallelization range; the
        // workspace is addressed by the ndrange id instead of thread id.
        WorkspaceBundle bundle = wbundle;
        if (fm.should_flip) {
            auto weight_flip = [bundle](const NCBKernParam& kern_param,
                                        const NCBKernIndex& ncb_index) {
                MultithreadDirectConvCommon<dt_float16, __fp16>::
                        weight_flip_kern(bundle, kern_param, ncb_index,
                                         ncb_index.ndrange_id);
            };
            ret_kerns.push_back({weight_flip, {group, 1_z, OC}});
        }
        auto copy_padding = [bundle](const NCBKernParam& kern_param,
                                     const NCBKernIndex& ncb_index) {
            MultithreadDirectConvCommon<dt_float16, __fp16>::copy_padding_kern(
                    bundle, kern_param, ncb_index, ncb_index.ndrange_id);
        };
        ret_kerns.push_back({copy_padding, {group, N, IC}});
        auto do_conv = [bundle](const NCBKernParam& kern_param,
                                const NCBKernIndex& ncb_index) {
            MultithreadDirectConvCommon<dt_float16, __fp16>::do_conv_kern(
                    bundle, kern_param, ncb_index,
                    fp16::conv_bias::kern_direct_f16, ncb_index.ndrange_id);
        };
        ret_kerns.push_back({do_conv, {group, N, OC}});
    }
    return ret_kerns;
}
//! Dispatcher entry point: forwards to get_kimpls() inside a MIDOUT
//! region (instrumentation from midout.h — presumably used to strip
//! unused kernels from release builds; see midout docs to confirm).
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF16Direct::dispatch_kerns(
        fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
    MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_kimpl, 0, 1) {
        return get_kimpls(param);
    }
    MIDOUT_END();
    //! only reached when the midout region is compiled out
    return {};
}
/* ===================== stride-1 algo ===================== */ | |||||
//! Whether the stride-1 fp16 direct algorithm can handle `param`:
//! NCHW fp16 tensors, square non-dilated 2x2/3x3/5x5 filters with unit
//! stride and no weight flip. Under HEURISTIC selection the large/small
//! group variant must additionally match the actual group count.
bool ConvBiasImpl::AlgoF16DirectStride1::usable(
        fallback::ConvBiasImpl*, const NCBKernSizeParam& param,
        AlgoSelectionStrategy algo_selection_strategy) const {
    MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_kimpl, 1, 0) {
        auto&& fm = param.filter_meta;
        auto FH = fm.spatial[0];
        //! layout / dtype gate: NCHW with fp16 src, filter and dst
        const bool type_ok =
                param.filter_meta.format == param::ConvBias::Format::NCHW &&
                param.src_type.enumv() == DTypeEnum::Float16 &&
                param.filter_type.enumv() == DTypeEnum::Float16 &&
                param.dst_type.enumv() == DTypeEnum::Float16;
        //! filter gate: square FHxFH in {2,3,5}, stride 1, dilation 1,
        //! correlation only (no flip)
        const bool filter_ok =
                !fm.should_flip && fm.spatial_ndim == 2 &&
                fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
                fm.stride[0] == 1 && fm.stride[1] == 1 &&
                FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5);
        bool usable_flag = type_ok && filter_ok;
        if (algo_selection_strategy ==
            ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) {
            //! heuristic picking must select the matching group variant
            bool large_group = param.filter_meta.group >= param.nr_threads;
            usable_flag &= (large_group == m_large_group);
        }
        return usable_flag;
    }
    MIDOUT_END();
    return false;
}
//! Build the NCB kernel list for the stride-1 fp16 direct convolution.
//! Selects the 2x2/3x3/5x5 inner kernel by filter height, then mirrors
//! the large-group / small-group dispatch split used by AlgoF16Direct
//! (usable() guarantees FH is one of 2/3/5 and that flip is off).
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoF16DirectStride1::get_kimpls(
        const NCBKernSizeParam& param) const {
    auto fm = param.filter_meta;
    auto FH = fm.spatial[0];
    size_t N = param.n;
    size_t IC = param.filter_meta.icpg;
    size_t OC = param.filter_meta.ocpg;
    size_t group = fm.group;
    using Func = std::function<void(const __fp16*, const __fp16*, __fp16*,
                                    size_t, size_t, size_t, size_t, size_t)>;
    Func conv_kern_function = nullptr;
#define SWITCH_KERN()                                                     \
    switch (FH) {                                                         \
        case 2:                                                           \
            conv_kern_function = fp16::conv_stride1::do_conv_2x2_stride1; \
            break;                                                        \
        case 3:                                                           \
            conv_kern_function = fp16::conv_stride1::do_conv_3x3_stride1; \
            break;                                                        \
        case 5:                                                           \
            conv_kern_function = fp16::conv_stride1::do_conv_5x5_stride1; \
            break;                                                        \
    }
    SWITCH_KERN();
    //! bundle layout must match get_workspace (same get_bundle_stride call)
    WorkspaceBundle wbundle =
            MultithreadDirectConvCommon<dt_float16, __fp16>::get_bundle_stride(
                    param, m_large_group);
    SmallVector<NCBKern> ret_kerns;
    //! When group >= nr_threads, treat it as large_group, each thread process
    //! one group for better performance
    if (m_large_group) {
        //! Channel wise conv and big groups
        auto exec_one_group = [wbundle, conv_kern_function](
                                      const NCBKernParam& kern_param,
                                      const NCBKernIndex& ncb_index) {
            auto fm = kern_param.filter_meta;
            size_t IC = fm.icpg;
            size_t OC = fm.ocpg;
            //! copy of the bundle so each thread binds its own workspace
            WorkspaceBundle bundle = wbundle;
            for (size_t ic = 0; ic < IC; ic++) {
                MultithreadDirectConvCommon<dt_float16, __fp16>::
                        copy_padding_kern_stride(bundle, kern_param, ncb_index,
                                                 {ncb_index.thread_id, 0, ic});
            }
            for (size_t oc = 0; oc < OC; oc++) {
                MultithreadDirectConvCommon<dt_float16, __fp16>::
                        do_conv_kern_stride(bundle, kern_param, ncb_index,
                                            conv_kern_function,
                                            {ncb_index.thread_id, 0, oc});
            }
        };
        ret_kerns.push_back({exec_one_group, {group, N, 1_z}});
    } else {
        WorkspaceBundle bundle = wbundle;
        auto copy_padding = [bundle](const NCBKernParam& kern_param,
                                     const NCBKernIndex& ncb_index) {
            MultithreadDirectConvCommon<dt_float16, __fp16>::
                    copy_padding_kern_stride(bundle, kern_param, ncb_index,
                                             ncb_index.ndrange_id);
        };
        ret_kerns.push_back({copy_padding, {group, N, IC}});
        auto do_conv = [bundle, conv_kern_function](
                               const NCBKernParam& kern_param,
                               const NCBKernIndex& ncb_index) {
            MultithreadDirectConvCommon<dt_float16, __fp16>::
                    do_conv_kern_stride(bundle, kern_param, ncb_index,
                                        conv_kern_function,
                                        ncb_index.ndrange_id);
        };
        ret_kerns.push_back({do_conv, {group, N, OC}});
    }
    return ret_kerns;
}
//! Total scratch size in bytes required by the stride-1 fp16 direct conv.
//! Uses the same bundle helper as get_kimpls so the sizes always agree.
size_t ConvBiasImpl::AlgoF16DirectStride1::get_workspace(
        fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
    MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_kimpl, 1, 1) {
        WorkspaceBundle wbundle =
                MultithreadDirectConvCommon<dt_float16, __fp16>::
                        get_bundle_stride(param, m_large_group);
        return wbundle.total_size_in_bytes();
    }
    MIDOUT_END();
    //! only reached when the midout region is compiled out
    return 0;
}
//! Dispatcher entry point: forwards to get_kimpls() inside a MIDOUT
//! region (instrumentation from midout.h — presumably used to strip
//! unused kernels from release builds; see midout docs to confirm).
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoF16DirectStride1::dispatch_kerns(
        fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
    MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_kimpl, 1, 2) {
        return get_kimpls(param);
    }
    MIDOUT_END();
    //! only reached when the midout region is compiled out
    return {};
}
#endif | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,185 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/conv_bias/f16/algos.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#pragma once | |||||
#include "src/arm_common/conv_bias/opr_impl.h" | |||||
#include "src/fallback/matrix_mul/opr_impl.h" | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
namespace megdnn { | |||||
namespace arm_common { | |||||
//! fp16 Winograd conv-bias algorithm backed by a matrix-mul algorithm.
//! name() lazily builds (and caches in the mutable m_name) a string from
//! the matmul algo name and the triple {1, 2, m_tile_size}; the triple's
//! exact meaning is defined by ConvBias::WinogradParam — presumably
//! output block size 2 (the "F23" in the class name) — TODO confirm.
class ConvBiasImpl::AlgoFP16WinogradF23 final : public AlgoBase {
public:
    //! \param matmul_algo backing matrix-mul implementation (not owned)
    //! \param tile_size   winograd tile-size parameter baked into the name
    AlgoFP16WinogradF23(fallback::MatrixMulImpl::AlgoBase* matmul_algo,
                        uint32_t tile_size)
            : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {}
    bool is_reproducible() const override { return true; }
    //! built on first call, cached afterwards
    const char* name() const override {
        if (m_name.empty()) {
            m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
                    m_matmul_algo->name(), {1, 2, m_tile_size});
        }
        return m_name.c_str();
    }
    bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
                AlgoSelectionStrategy algo_selection_strategy) const override;
    size_t get_workspace(fallback::ConvBiasImpl*,
                         const NCBKernSizeParam& param) const override;
    virtual SmallVector<NCBKern> dispatch_kerns(
            fallback::ConvBiasImpl* opr,
            const NCBKernSizeParam& param) const override;
    //! NOTE(review): "avaiable" is misspelled but the name is part of the
    //! public interface used elsewhere, so it is kept as-is.
    static std::vector<fallback::MatrixMulImpl::Algorithm*>
    get_avaiable_matmul_algos(const NCBKernSizeParam& param);
private:
    fallback::MatrixMulImpl::AlgoBase* m_matmul_algo;
    mutable std::string m_name;
    uint32_t m_tile_size;
};
//! fp16 Winograd conv-bias algorithm; identical in structure to
//! AlgoFP16WinogradF23 but names itself with the triple {1, 4, tile}
//! (the "F45" variant — exact WinogradParam semantics TODO confirm).
class ConvBiasImpl::AlgoFP16WinogradF45 final : public AlgoBase {
public:
    //! \param matmul_algo backing matrix-mul implementation (not owned)
    //! \param tile_size   winograd tile-size parameter baked into the name
    AlgoFP16WinogradF45(fallback::MatrixMulImpl::AlgoBase* matmul_algo,
                        uint32_t tile_size)
            : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {}
    bool is_reproducible() const override { return true; }
    //! built on first call, cached afterwards
    const char* name() const override {
        if (m_name.empty()) {
            m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
                    m_matmul_algo->name(), {1, 4, m_tile_size});
        }
        return m_name.c_str();
    }
    bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
                AlgoSelectionStrategy algo_selection_strategy) const override;
    size_t get_workspace(fallback::ConvBiasImpl*,
                         const NCBKernSizeParam& param) const override;
    virtual SmallVector<NCBKern> dispatch_kerns(
            fallback::ConvBiasImpl* opr,
            const NCBKernSizeParam& param) const override;
    //! NOTE(review): misspelled name kept for interface compatibility
    static std::vector<fallback::MatrixMulImpl::Algorithm*>
    get_avaiable_matmul_algos(const NCBKernSizeParam& param);
private:
    fallback::MatrixMulImpl::AlgoBase* m_matmul_algo;
    mutable std::string m_name;
    uint32_t m_tile_size;
};
//! fp16 Winograd conv-bias algorithm; identical in structure to
//! AlgoFP16WinogradF23 but names itself with the triple {1, 6, tile}
//! (the "F63" variant — exact WinogradParam semantics TODO confirm).
class ConvBiasImpl::AlgoFP16WinogradF63 final : public AlgoBase {
public:
    //! \param matmul_algo backing matrix-mul implementation (not owned)
    //! \param tile_size   winograd tile-size parameter baked into the name
    AlgoFP16WinogradF63(fallback::MatrixMulImpl::AlgoBase* matmul_algo,
                        uint32_t tile_size)
            : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {}
    bool is_reproducible() const override { return true; }
    //! built on first call, cached afterwards
    const char* name() const override {
        if (m_name.empty()) {
            m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
                    m_matmul_algo->name(), {1, 6, m_tile_size});
        }
        return m_name.c_str();
    }
    bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
                AlgoSelectionStrategy algo_selection_strategy) const override;
    size_t get_workspace(fallback::ConvBiasImpl*,
                         const NCBKernSizeParam& param) const override;
    virtual SmallVector<NCBKern> dispatch_kerns(
            fallback::ConvBiasImpl* opr,
            const NCBKernSizeParam& param) const override;
    //! NOTE(review): misspelled name kept for interface compatibility
    static std::vector<fallback::MatrixMulImpl::Algorithm*>
    get_avaiable_matmul_algos(const NCBKernSizeParam& param);
private:
    fallback::MatrixMulImpl::AlgoBase* m_matmul_algo;
    mutable std::string m_name;
    uint32_t m_tile_size;
};
//! fp16 Winograd conv-bias algorithm, {8, 2, tile} variant (the leading 8
//! presumably denotes an 8x8 packing — TODO confirm against
//! ConvBias::WinogradParam). Unlike the other winograd algos above it
//! declares no static get_avaiable_matmul_algos helper.
class ConvBiasImpl::AlgoFP16WinogradF23_8x8 final : public AlgoBase {
public:
    //! \param matmul_algo backing matrix-mul implementation (not owned)
    //! \param tile_size   winograd tile-size parameter baked into the name
    AlgoFP16WinogradF23_8x8(fallback::MatrixMulImpl::AlgoBase* matmul_algo,
                            uint32_t tile_size)
            : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {}
    bool is_reproducible() const override { return true; }
    //! built on first call, cached afterwards
    const char* name() const override {
        if (m_name.empty()) {
            m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
                    m_matmul_algo->name(), {8, 2, m_tile_size});
        }
        return m_name.c_str();
    }
    bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
                AlgoSelectionStrategy algo_selection_strategy) const override;
    size_t get_workspace(fallback::ConvBiasImpl*,
                         const NCBKernSizeParam& param) const override;
    virtual SmallVector<NCBKern> dispatch_kerns(
            fallback::ConvBiasImpl* opr,
            const NCBKernSizeParam& param) const override;
private:
    fallback::MatrixMulImpl::AlgoBase* m_matmul_algo;
    mutable std::string m_name;
    uint32_t m_tile_size;
};
//! fp16 direct convolution algorithm. Instantiated twice (large / small
//! group); the flag picks both the reported name and the kernel
//! dispatch shape produced by get_kimpls.
class ConvBiasImpl::AlgoF16Direct final : public AlgoBase {
    //! builds the NCB kernel list for this variant
    SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
    //! true: one thread handles an entire group end-to-end
    bool m_large_group;
public:
    AlgoF16Direct(bool is_large_group) : m_large_group{is_large_group} {}
    bool is_reproducible() const override { return true; }
    const char* name() const override {
        return m_large_group ? "F16DIRECT_LARGE_GROUP"
                             : "F16DIRECT_SMALL_GROUP";
    }
    bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
                AlgoSelectionStrategy algo_selection_strategy) const override;
    size_t get_workspace(fallback::ConvBiasImpl* opr,
                         const NCBKernSizeParam& param) const override;
    virtual SmallVector<NCBKern> dispatch_kerns(
            fallback::ConvBiasImpl* opr,
            const NCBKernSizeParam& param) const override;
};
//! fp16 stride-1 direct convolution algorithm (specialised 2x2/3x3/5x5
//! kernels). Instantiated twice (large / small group), mirroring
//! AlgoF16Direct.
class ConvBiasImpl::AlgoF16DirectStride1 final : public AlgoBase {
    //! builds the NCB kernel list for this variant
    SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
    //! true: one thread handles an entire group end-to-end
    bool m_large_group;
public:
    AlgoF16DirectStride1(bool is_large_group) : m_large_group{is_large_group} {}
    bool is_reproducible() const override { return true; }
    const char* name() const override {
        return m_large_group ? "F16STRD1_LARGE_GROUP" : "F16STRD1_SMALL_GROUP";
    }
    bool usable(fallback::ConvBiasImpl*, const NCBKernSizeParam& param,
                AlgoSelectionStrategy algo_selection_strategy) const override;
    size_t get_workspace(fallback::ConvBiasImpl*,
                         const NCBKernSizeParam& param) const override;
    virtual SmallVector<NCBKern> dispatch_kerns(
            fallback::ConvBiasImpl* opr,
            const NCBKernSizeParam& param) const override;
};
} // namespace arm_common | |||||
} // namespace megdnn | |||||
#endif | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,799 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/conv_bias/f16/direct.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "./direct.h" | |||||
#include "include/megdnn/oprs.h" | |||||
#include "midout.h" | |||||
#include "src/common/unroll_macro.h" | |||||
#include "src/common/utils.h" | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
#include <cstring> | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/arm_common/simd_macro/neon_helper.h" | |||||
MIDOUT_DECL(megdnn_arm_conv_f16) | |||||
using namespace megdnn; | |||||
using namespace arm_common; | |||||
using namespace fp16; | |||||
using namespace conv_bias; | |||||
namespace { | |||||
#define BLOCK_H 8 | |||||
//! Load the current partial results for a BLOCK_H x `width` output tile
//! from dst into out0..out7 (one float16x8_t per row). `width` and
//! `height` are compile-time constants of the enclosing do_pixel_proxy,
//! so the branches below fold away after instantiation.
//!
//! Fix: widths 4..7 previously loaded the low half via
//! vld1q_lane_u64(dst, data, 0), punning __fp16 data / float16x8_t
//! vectors through the uint64 intrinsic; that call does not type-check
//! against the NEON intrinsic signatures (const uint64_t*, uint64x2_t).
//! Replaced with a plain fp16 half-vector load recombined with the
//! untouched high half — same lanes, same result, type-correct, and
//! symmetric with STORE_RESULT_VAL's vst1_f16 path.
#define LOAD_RESULT_VAL                                                      \
    if (width < 8) {                                                         \
        auto load_less_8 = [](__fp16* dst, float16x8_t& data) {              \
            if (width == 1u) {                                               \
                data = vld1q_lane_f16(dst, data, 0);                         \
            } else if (width == 2u) {                                        \
                data = vld1q_lane_f16(dst + 0, data, 0);                     \
                data = vld1q_lane_f16(dst + 1, data, 1);                     \
            } else if (width == 3u) {                                        \
                data = vld1q_lane_f16(dst + 0, data, 0);                     \
                data = vld1q_lane_f16(dst + 1, data, 1);                     \
                data = vld1q_lane_f16(dst + 2, data, 2);                     \
            } else if (width == 4u) {                                        \
                data = vcombine_f16(vld1_f16(dst), vget_high_f16(data));     \
            } else if (width == 5u) {                                        \
                data = vcombine_f16(vld1_f16(dst), vget_high_f16(data));     \
                data = vld1q_lane_f16(dst + 4, data, 4);                     \
            } else if (width == 6u) {                                        \
                data = vcombine_f16(vld1_f16(dst), vget_high_f16(data));     \
                data = vld1q_lane_f16(dst + 4, data, 4);                     \
                data = vld1q_lane_f16(dst + 5, data, 5);                     \
            } else if (width == 7u) {                                        \
                data = vcombine_f16(vld1_f16(dst), vget_high_f16(data));     \
                data = vld1q_lane_f16(dst + 4, data, 4);                     \
                data = vld1q_lane_f16(dst + 5, data, 5);                     \
                data = vld1q_lane_f16(dst + 6, data, 6);                     \
            }                                                                \
        };                                                                   \
        if (height >= 1)                                                     \
            load_less_8(dst + 0 * OW, out0);                                 \
        if (height >= 2)                                                     \
            load_less_8(dst + 1 * OW, out1);                                 \
        if (height >= 3)                                                     \
            load_less_8(dst + 2 * OW, out2);                                 \
        if (height >= 4)                                                     \
            load_less_8(dst + 3 * OW, out3);                                 \
        if (height >= 5)                                                     \
            load_less_8(dst + 4 * OW, out4);                                 \
        if (height >= 6)                                                     \
            load_less_8(dst + 5 * OW, out5);                                 \
        if (height >= 7)                                                     \
            load_less_8(dst + 6 * OW, out6);                                 \
        if (height >= 8)                                                     \
            load_less_8(dst + 7 * OW, out7);                                 \
    } else {                                                                 \
        if (height >= 1)                                                     \
            out0 = vld1q_f16(dst + 0 * OW);                                  \
        if (height >= 2)                                                     \
            out1 = vld1q_f16(dst + 1 * OW);                                  \
        if (height >= 3)                                                     \
            out2 = vld1q_f16(dst + 2 * OW);                                  \
        if (height >= 4)                                                     \
            out3 = vld1q_f16(dst + 3 * OW);                                  \
        if (height >= 5)                                                     \
            out4 = vld1q_f16(dst + 4 * OW);                                  \
        if (height >= 6)                                                     \
            out5 = vld1q_f16(dst + 5 * OW);                                  \
        if (height >= 7)                                                     \
            out6 = vld1q_f16(dst + 6 * OW);                                  \
        if (height >= 8)                                                     \
            out7 = vld1q_f16(dst + 7 * OW);                                  \
    }
//! Store a BLOCK_H x `width` output tile from out0..out7 back to dst,
//! writing only `width` lanes per row (partial rows use per-lane stores,
//! with a vst1_f16 half-vector store covering lanes 0..3 when width >= 4).
//! `width` and `height` are compile-time constants of the enclosing
//! do_pixel_proxy, so the branches fold away after instantiation.
#define STORE_RESULT_VAL                                         \
    if (width < 8) {                                             \
        auto store_less_8 = [](__fp16* dst, float16x8_t& data) { \
            if (width == 1u) {                                   \
                vst1q_lane_f16(dst, data, 0);                    \
            } else if (width == 2u) {                            \
                vst1q_lane_f16(dst + 0, data, 0);                \
                vst1q_lane_f16(dst + 1, data, 1);                \
            } else if (width == 3u) {                            \
                vst1q_lane_f16(dst + 0, data, 0);                \
                vst1q_lane_f16(dst + 1, data, 1);                \
                vst1q_lane_f16(dst + 2, data, 2);                \
            } else if (width == 4u) {                            \
                vst1_f16(dst, vget_low_f16(data));               \
            } else if (width == 5u) {                            \
                vst1_f16(dst, vget_low_f16(data));               \
                vst1q_lane_f16(dst + 4, data, 4);                \
            } else if (width == 6u) {                            \
                vst1_f16(dst, vget_low_f16(data));               \
                vst1q_lane_f16(dst + 4, data, 4);                \
                vst1q_lane_f16(dst + 5, data, 5);                \
            } else if (width == 7u) {                            \
                vst1_f16(dst, vget_low_f16(data));               \
                vst1q_lane_f16(dst + 4, data, 4);                \
                vst1q_lane_f16(dst + 5, data, 5);                \
                vst1q_lane_f16(dst + 6, data, 6);                \
            }                                                    \
        };                                                       \
        if (height >= 1)                                         \
            store_less_8(dst + 0 * OW, out0);                    \
        if (height >= 2)                                         \
            store_less_8(dst + 1 * OW, out1);                    \
        if (height >= 3)                                         \
            store_less_8(dst + 2 * OW, out2);                    \
        if (height >= 4)                                         \
            store_less_8(dst + 3 * OW, out3);                    \
        if (height >= 5)                                         \
            store_less_8(dst + 4 * OW, out4);                    \
        if (height >= 6)                                         \
            store_less_8(dst + 5 * OW, out5);                    \
        if (height >= 7)                                         \
            store_less_8(dst + 6 * OW, out6);                    \
        if (height >= 8)                                         \
            store_less_8(dst + 7 * OW, out7);                    \
    } else {                                                     \
        if (height >= 1)                                         \
            vst1q_f16(dst + 0 * OW, out0);                       \
        if (height >= 2)                                         \
            vst1q_f16(dst + 1 * OW, out1);                       \
        if (height >= 3)                                         \
            vst1q_f16(dst + 2 * OW, out2);                       \
        if (height >= 4)                                         \
            vst1q_f16(dst + 3 * OW, out3);                       \
        if (height >= 5)                                         \
            vst1q_f16(dst + 4 * OW, out4);                       \
        if (height >= 6)                                         \
            vst1q_f16(dst + 5 * OW, out5);                       \
        if (height >= 7)                                         \
            vst1q_f16(dst + 6 * OW, out6);                       \
        if (height >= 8)                                         \
            vst1q_f16(dst + 7 * OW, out7);                       \
    }
//! Computes one BLOCK_H x `width` tile of output pixels for a filter
//! with FH rows; only the FH = 1..7 specializations below are defined.
//! (oh, ow) is the tile origin in the output; dst already holds partial
//! sums which the kernels accumulate onto.
template <int FH, int height, int width>
struct do_pixel_proxy {
    static void exec(const __fp16* src, const __fp16* filter, __fp16* dst,
                     const int IH, const int IW, const int OH, const int OW,
                     const int FW, const int oh, const int ow);
};
//! FH == 1 specialization: accumulates a single filter row into the tile.
template <int height, int width>
struct do_pixel_proxy<1, height, width> {
    static void exec(const __fp16* src, const __fp16* filter, __fp16* dst,
                     const int IH, const int IW, const int OH, const int OW,
                     const int FW, const int oh, const int ow) {
        MEGDNN_MARK_USED_VAR(IH);
        MEGDNN_MARK_USED_VAR(OH);
        //! stride 1: input tile origin coincides with output origin
        const int ih = oh, iw = ow;
#define cb(i) float16x8_t out##i{0};
        UNROLL_CALL_NOWRAPPER(BLOCK_H, cb);
#undef cb
        float16x8_t kr0, inp;
        src += ih * IW + iw;
        dst += oh * OW + ow;
        //! start from the partial sums already present in dst
        LOAD_RESULT_VAL;
        for (int fw = 0; fw < FW; ++fw) {
            const __fp16* src_dd = src + fw;
            //! broadcast one filter tap, fma across the BLOCK_H rows
            kr0 = vdupq_n_f16(filter[0 * FW + fw]);
#define cb(i)                                 \
    if (height > i) {                         \
        inp = vld1q_f16(src_dd + i * IW);     \
        out##i = vmlaq_f16(out##i, inp, kr0); \
    }
            UNROLL_CALL_NOWRAPPER(BLOCK_H, cb);
#undef cb
        }
        STORE_RESULT_VAL;
    }
};
//! FH == 2 specialization: accumulates two filter rows into the tile.
template <int height, int width>
struct do_pixel_proxy<2, height, width> {
    static void exec(const __fp16* src, const __fp16* filter, __fp16* dst,
                     const int IH, const int IW, const int OH, const int OW,
                     const int FW, const int oh, const int ow) {
        MEGDNN_MARK_USED_VAR(IH);
        MEGDNN_MARK_USED_VAR(OH);
        //! stride 1: input tile origin coincides with output origin
        const int ih = oh, iw = ow;
#define cb(i) float16x8_t out##i{0};
        UNROLL_CALL_NOWRAPPER(BLOCK_H, cb);
#undef cb
        float16x8_t kr0, kr1, inp;
        src += ih * IW + iw;
        dst += oh * OW + ow;
        //! start from the partial sums already present in dst
        LOAD_RESULT_VAL;
        for (int fw = 0; fw < FW; ++fw) {
            const __fp16* src_dd = src + fw;
            kr0 = vdupq_n_f16(filter[0 * FW + fw]);
            kr1 = vdupq_n_f16(filter[1 * FW + fw]);
#define cb(i)                                   \
    if (height > i) {                           \
        inp = vld1q_f16(src_dd + i * IW);       \
        out##i = vmlaq_f16(out##i, inp, kr0);   \
        inp = vld1q_f16(src_dd + (i + 1) * IW); \
        out##i = vmlaq_f16(out##i, inp, kr1);   \
    }
            UNROLL_CALL_NOWRAPPER(BLOCK_H, cb);
#undef cb
        }
        STORE_RESULT_VAL;
    }
};
//! FH == 3 specialization: accumulates three filter rows into the tile.
template <int height, int width>
struct do_pixel_proxy<3, height, width> {
    static void exec(const __fp16* src, const __fp16* filter, __fp16* dst,
                     const int IH, const int IW, const int OH, const int OW,
                     const int FW, const int oh, const int ow) {
        MEGDNN_MARK_USED_VAR(IH);
        MEGDNN_MARK_USED_VAR(OH);
        //! stride 1: input tile origin coincides with output origin
        const int ih = oh, iw = ow;
#define cb(i) float16x8_t out##i{0};
        UNROLL_CALL_NOWRAPPER(BLOCK_H, cb);
#undef cb
        float16x8_t kr0, kr1, kr2, inp;
        src += ih * IW + iw;
        dst += oh * OW + ow;
        //! start from the partial sums already present in dst
        LOAD_RESULT_VAL;
        for (int fw = 0; fw < FW; ++fw) {
            const __fp16* src_dd = src + fw;
            kr0 = vdupq_n_f16(filter[0 * FW + fw]);
            kr1 = vdupq_n_f16(filter[1 * FW + fw]);
            kr2 = vdupq_n_f16(filter[2 * FW + fw]);
#define cb(i)                                   \
    if (height > i) {                           \
        inp = vld1q_f16(src_dd + i * IW);       \
        out##i = vmlaq_f16(out##i, inp, kr0);   \
        inp = vld1q_f16(src_dd + (i + 1) * IW); \
        out##i = vmlaq_f16(out##i, inp, kr1);   \
        inp = vld1q_f16(src_dd + (i + 2) * IW); \
        out##i = vmlaq_f16(out##i, inp, kr2);   \
    }
            UNROLL_CALL_NOWRAPPER(BLOCK_H, cb);
#undef cb
        }
        STORE_RESULT_VAL;
    }
};
//! FH == 4 specialization: accumulates four filter rows into the tile.
template <int height, int width>
struct do_pixel_proxy<4, height, width> {
    static void exec(const __fp16* src, const __fp16* filter, __fp16* dst,
                     const int IH, const int IW, const int OH, const int OW,
                     const int FW, const int oh, const int ow) {
        MEGDNN_MARK_USED_VAR(IH);
        MEGDNN_MARK_USED_VAR(OH);
        //! stride 1: input tile origin coincides with output origin
        const int ih = oh, iw = ow;
#define cb(i) float16x8_t out##i{0};
        UNROLL_CALL_NOWRAPPER(BLOCK_H, cb);
#undef cb
        float16x8_t kr0, kr1, kr2, kr3, inp;
        src += ih * IW + iw;
        dst += oh * OW + ow;
        //! start from the partial sums already present in dst
        LOAD_RESULT_VAL;
        for (int fw = 0; fw < FW; ++fw) {
            const __fp16* src_dd = src + fw;
            kr0 = vdupq_n_f16(filter[0 * FW + fw]);
            kr1 = vdupq_n_f16(filter[1 * FW + fw]);
            kr2 = vdupq_n_f16(filter[2 * FW + fw]);
            kr3 = vdupq_n_f16(filter[3 * FW + fw]);
#define cb(i)                                   \
    if (height > i) {                           \
        inp = vld1q_f16(src_dd + i * IW);       \
        out##i = vmlaq_f16(out##i, inp, kr0);   \
        inp = vld1q_f16(src_dd + (i + 1) * IW); \
        out##i = vmlaq_f16(out##i, inp, kr1);   \
        inp = vld1q_f16(src_dd + (i + 2) * IW); \
        out##i = vmlaq_f16(out##i, inp, kr2);   \
        inp = vld1q_f16(src_dd + (i + 3) * IW); \
        out##i = vmlaq_f16(out##i, inp, kr3);   \
    }
            UNROLL_CALL_NOWRAPPER(BLOCK_H, cb);
#undef cb
        }
        STORE_RESULT_VAL;
    }
};
//! FH == 5 specialization: accumulates five filter rows into the tile.
template <int height, int width>
struct do_pixel_proxy<5, height, width> {
    static void exec(const __fp16* src, const __fp16* filter, __fp16* dst,
                     const int IH, const int IW, const int OH, const int OW,
                     const int FW, const int oh, const int ow) {
        MEGDNN_MARK_USED_VAR(IH);
        MEGDNN_MARK_USED_VAR(OH);
        //! stride 1: input tile origin coincides with output origin
        const int ih = oh, iw = ow;
#define cb(i) float16x8_t out##i{0};
        UNROLL_CALL_NOWRAPPER(BLOCK_H, cb);
#undef cb
        float16x8_t kr0, kr1, kr2, kr3, kr4, inp;
        src += ih * IW + iw;
        dst += oh * OW + ow;
        //! start from the partial sums already present in dst
        LOAD_RESULT_VAL;
        for (int fw = 0; fw < FW; ++fw) {
            const __fp16* src_dd = src + fw;
            kr0 = vdupq_n_f16(filter[0 * FW + fw]);
            kr1 = vdupq_n_f16(filter[1 * FW + fw]);
            kr2 = vdupq_n_f16(filter[2 * FW + fw]);
            kr3 = vdupq_n_f16(filter[3 * FW + fw]);
            kr4 = vdupq_n_f16(filter[4 * FW + fw]);
#define cb(i)                                   \
    if (height > i) {                           \
        inp = vld1q_f16(src_dd + i * IW);       \
        out##i = vmlaq_f16(out##i, inp, kr0);   \
        inp = vld1q_f16(src_dd + (i + 1) * IW); \
        out##i = vmlaq_f16(out##i, inp, kr1);   \
        inp = vld1q_f16(src_dd + (i + 2) * IW); \
        out##i = vmlaq_f16(out##i, inp, kr2);   \
        inp = vld1q_f16(src_dd + (i + 3) * IW); \
        out##i = vmlaq_f16(out##i, inp, kr3);   \
        inp = vld1q_f16(src_dd + (i + 4) * IW); \
        out##i = vmlaq_f16(out##i, inp, kr4);   \
    }
            UNROLL_CALL_NOWRAPPER(BLOCK_H, cb);
#undef cb
        }
        STORE_RESULT_VAL;
    }
};
//! FH == 6 specialization: accumulates six filter rows into the tile.
template <int height, int width>
struct do_pixel_proxy<6, height, width> {
    static void exec(const __fp16* src, const __fp16* filter, __fp16* dst,
                     const int IH, const int IW, const int OH, const int OW,
                     const int FW, const int oh, const int ow) {
        MEGDNN_MARK_USED_VAR(IH);
        MEGDNN_MARK_USED_VAR(OH);
        //! stride 1: input tile origin coincides with output origin
        const int ih = oh, iw = ow;
#define cb(i) float16x8_t out##i{0};
        UNROLL_CALL_NOWRAPPER(BLOCK_H, cb);
#undef cb
        float16x8_t kr0, kr1, kr2, kr3, kr4, kr5, inp;
        src += ih * IW + iw;
        dst += oh * OW + ow;
        //! start from the partial sums already present in dst
        LOAD_RESULT_VAL;
        for (int fw = 0; fw < FW; ++fw) {
            const __fp16* src_dd = src + fw;
            kr0 = vdupq_n_f16(filter[0 * FW + fw]);
            kr1 = vdupq_n_f16(filter[1 * FW + fw]);
            kr2 = vdupq_n_f16(filter[2 * FW + fw]);
            kr3 = vdupq_n_f16(filter[3 * FW + fw]);
            kr4 = vdupq_n_f16(filter[4 * FW + fw]);
            kr5 = vdupq_n_f16(filter[5 * FW + fw]);
#define cb(i)                                   \
    if (height > i) {                           \
        inp = vld1q_f16(src_dd + i * IW);       \
        out##i = vmlaq_f16(out##i, inp, kr0);   \
        inp = vld1q_f16(src_dd + (i + 1) * IW); \
        out##i = vmlaq_f16(out##i, inp, kr1);   \
        inp = vld1q_f16(src_dd + (i + 2) * IW); \
        out##i = vmlaq_f16(out##i, inp, kr2);   \
        inp = vld1q_f16(src_dd + (i + 3) * IW); \
        out##i = vmlaq_f16(out##i, inp, kr3);   \
        inp = vld1q_f16(src_dd + (i + 4) * IW); \
        out##i = vmlaq_f16(out##i, inp, kr4);   \
        inp = vld1q_f16(src_dd + (i + 5) * IW); \
        out##i = vmlaq_f16(out##i, inp, kr5);   \
    }
            UNROLL_CALL_NOWRAPPER(BLOCK_H, cb);
#undef cb
        }
        STORE_RESULT_VAL;
    }
};
//! FH == 7 specialization: accumulates seven filter rows into the tile.
template <int height, int width>
struct do_pixel_proxy<7, height, width> {
    static void exec(const __fp16* src, const __fp16* filter, __fp16* dst,
                     const int IH, const int IW, const int OH, const int OW,
                     const int FW, const int oh, const int ow) {
        MEGDNN_MARK_USED_VAR(IH);
        MEGDNN_MARK_USED_VAR(OH);
        //! stride 1: input tile origin coincides with output origin
        const int ih = oh, iw = ow;
#define cb(i) float16x8_t out##i{0};
        UNROLL_CALL_NOWRAPPER(BLOCK_H, cb);
#undef cb
        float16x8_t kr0, kr1, kr2, kr3, kr4, kr5, kr6, inp;
        src += ih * IW + iw;
        dst += oh * OW + ow;
        //! start from the partial sums already present in dst
        LOAD_RESULT_VAL;
        for (int fw = 0; fw < FW; ++fw) {
            const __fp16* src_dd = src + fw;
            kr0 = vdupq_n_f16(filter[0 * FW + fw]);
            kr1 = vdupq_n_f16(filter[1 * FW + fw]);
            kr2 = vdupq_n_f16(filter[2 * FW + fw]);
            kr3 = vdupq_n_f16(filter[3 * FW + fw]);
            kr4 = vdupq_n_f16(filter[4 * FW + fw]);
            kr5 = vdupq_n_f16(filter[5 * FW + fw]);
            kr6 = vdupq_n_f16(filter[6 * FW + fw]);
#define cb(i)                                   \
    if (height > i) {                           \
        inp = vld1q_f16(src_dd + i * IW);       \
        out##i = vmlaq_f16(out##i, inp, kr0);   \
        inp = vld1q_f16(src_dd + (i + 1) * IW); \
        out##i = vmlaq_f16(out##i, inp, kr1);   \
        inp = vld1q_f16(src_dd + (i + 2) * IW); \
        out##i = vmlaq_f16(out##i, inp, kr2);   \
        inp = vld1q_f16(src_dd + (i + 3) * IW); \
        out##i = vmlaq_f16(out##i, inp, kr3);   \
        inp = vld1q_f16(src_dd + (i + 4) * IW); \
        out##i = vmlaq_f16(out##i, inp, kr4);   \
        inp = vld1q_f16(src_dd + (i + 5) * IW); \
        out##i = vmlaq_f16(out##i, inp, kr5);   \
        inp = vld1q_f16(src_dd + (i + 6) * IW); \
        out##i = vmlaq_f16(out##i, inp, kr6);   \
    }
            UNROLL_CALL_NOWRAPPER(BLOCK_H, cb);
#undef cb
        }
        STORE_RESULT_VAL;
    }
};
#undef STORE_RESULT_VAL | |||||
#undef LOAD_RESULT_VAL | |||||
template <int FH, int height, int width> | |||||
void do_pixel(const __fp16* src, const __fp16* filter, __fp16* dst, | |||||
const int IH, const int IW, const int OH, const int OW, | |||||
const int FW, const int oh, const int ow) { | |||||
do_pixel_proxy<FH, height, width>::exec(src, filter, dst, IH, IW, OH, OW, | |||||
FW, oh, ow); | |||||
} | |||||
// Direct fp16 convolution over the whole output plane, with software
// prefetch hints (__builtin_prefetch) issued ahead of each tile.
// FH is the compile-time filter height; FW is a runtime argument.
// The plane is walked in BLOCK_H x 8 output tiles; the DISPATCH /
// DISPATCH1 / DISPATCH2 macros instantiate do_pixel for the 1..7-wide
// and 1..7-high remainder tiles at the right and bottom borders.
// NOTE(review): dst is read-modify-written by do_pixel (it is also
// prefetched for write below) -- presumably the caller pre-fills it with
// bias/zero; confirm against the caller.
template <int FH>
void do_conv_tpl_enable_prefetch(const __fp16* src,
                                 const __fp16* filter, __fp16* dst,
                                 const int IH, const int IW, const int OH,
                                 const int OW, const int FW) {
    const int hbeg = 0, hend = OH;
    const int wbeg = 0, wend = OW;
    int i, j;
    for (i = hbeg; i + BLOCK_H <= hend; i += BLOCK_H) {
        for (j = wbeg; j + 8 <= wend; j += 8) {
            // do prefetch
            // Prefetch two tiles (16 columns) ahead; when that would run
            // past the end of the row band, wrap to the next row band
            // instead.  The ">> 2 << 2" rounds the wrapped column down to
            // a multiple of 4.
            const int prefetch_index_input =
                    (j + 16) < wend
                            ? i * IW + j + 16
                            : (i + 8) * IW + (((j + 16 - wend) >> 2) << 2);
            const int prefetch_index_output =
                    (j + 16) < wend
                            ? i * OW + j + 16
                            : (i + 8) * OW + (((j + 16 - wend) >> 2) << 2);
            const __fp16* src_prefetch = src + prefetch_index_input;
            const __fp16* dst_prefetch = dst + prefetch_index_output;
            // Prefetch FH + 3 input rows for this tile; with locality
            // hint 3 (keep in all cache levels), read-only (0).
            for (int iw_id = 0; iw_id < FH + 3; ++iw_id) {
                __builtin_prefetch(src_prefetch + iw_id * IW, 0, 3);
            }
// Prefetch one output row per unrolled step, for write (1), hint 3.
#define unroll_prefetch_cb(i) __builtin_prefetch(dst_prefetch + i * OW, 1, 3);
            UNROLL_CALL_NOWRAPPER(BLOCK_H, unroll_prefetch_cb);
            do_pixel<FH, BLOCK_H, 8>(src, filter, dst, IH, IW, OH, OW, FW, i,
                                     j);
        }
// Right border of a full-height band: one BLOCK_H x width tile,
// width in 1..7.  Prefetch targets the next row band.
#define DISPATCH(width)                                                       \
    do {                                                                      \
        const int prefetch_index_input = (i + 8) * IW + 12;                   \
        const int prefetch_index_output = (i + 8) * OW + 12;                  \
        const __fp16* src_prefetch = src + prefetch_index_input;              \
        const __fp16* dst_prefetch = dst + prefetch_index_output;             \
        for (int iw_id = 0; iw_id < FH + 3; ++iw_id) {                        \
            __builtin_prefetch(src_prefetch + iw_id * IW, 0, 3);              \
        }                                                                     \
        UNROLL_CALL_NOWRAPPER(BLOCK_H, unroll_prefetch_cb);                   \
        do_pixel<FH, BLOCK_H, width>(src, filter, dst, IH, IW, OH, OW, FW, i, \
                                     j);                                      \
    } while (0)
        switch (wend - j) {
            case 1:
                DISPATCH(1);
                break;
            case 2:
                DISPATCH(2);
                break;
            case 3:
                DISPATCH(3);
                break;
            case 4:
                DISPATCH(4);
                break;
            case 5:
                DISPATCH(5);
                break;
            case 6:
                DISPATCH(6);
                break;
            case 7:
                DISPATCH(7);
                break;
        }
#undef DISPATCH
    }
// Bottom-right corner tile of size height x width (both 1..7).
// NOTE(review): the prefetch address IH * IW + 12 points past the current
// input plane; __builtin_prefetch is a hint only and does not fault, so
// this is safe even if the address is not mapped.
#define DISPATCH2(height, width)                                              \
    do {                                                                      \
        const int prefetch_index_input = IH * IW + 12;                        \
        const __fp16* src_prefetch = src + prefetch_index_input;              \
        for (int iw_id = 0; iw_id < FH + 3; ++iw_id) {                        \
            __builtin_prefetch(src_prefetch + iw_id * IW, 0, 3);              \
        }                                                                     \
        do_pixel<FH, height, width>(src, filter, dst, IH, IW, OH, OW, FW, i,  \
                                    j);                                       \
    } while (0)
// Bottom border: one band of reduced height (1..7); full-width 8-column
// tiles first (same prefetch scheme as the main loop), then the corner
// remainder via DISPATCH2.
#define DISPATCH1(height)                                                     \
    do {                                                                      \
        for (j = wbeg; j + 8 <= wend; j += 8) {                               \
            const int prefetch_index_input =                                  \
                    (j + 16) < wend                                           \
                            ? i * IW + j + 16                                 \
                            : (i + 8) * IW + (((j + 16 - wend) >> 2) << 2);   \
            const int prefetch_index_output =                                 \
                    (j + 16) < wend                                           \
                            ? i * OW + j + 16                                 \
                            : (i + 8) * OW + (((j + 16 - wend) >> 2) << 2);   \
            const __fp16* src_prefetch = src + prefetch_index_input;          \
            const __fp16* dst_prefetch = dst + prefetch_index_output;         \
            for (int iw_id = 0; iw_id < FH + 3; ++iw_id) {                    \
                __builtin_prefetch(src_prefetch + iw_id * IW, 0, 3);          \
            }                                                                 \
            UNROLL_CALL_NOWRAPPER(BLOCK_H, unroll_prefetch_cb);               \
            do_pixel<FH, height, 8>(src, filter, dst, IH, IW, OH, OW, FW, i,  \
                                    j);                                       \
        }                                                                     \
        switch (wend - j) {                                                   \
            case 1:                                                           \
                DISPATCH2(height, 1);                                         \
                break;                                                        \
            case 2:                                                           \
                DISPATCH2(height, 2);                                         \
                break;                                                        \
            case 3:                                                           \
                DISPATCH2(height, 3);                                         \
                break;                                                        \
            case 4:                                                           \
                DISPATCH2(height, 4);                                         \
                break;                                                        \
            case 5:                                                           \
                DISPATCH2(height, 5);                                         \
                break;                                                        \
            case 6:                                                           \
                DISPATCH2(height, 6);                                         \
                break;                                                        \
            case 7:                                                           \
                DISPATCH2(height, 7);                                         \
                break;                                                        \
        }                                                                     \
    } while (0)
    // Bottom-border heights 4..7 can only occur when BLOCK_H == 8.
    switch (hend - i) {
        case 1:
            DISPATCH1(1);
            break;
        case 2:
            DISPATCH1(2);
            break;
        case 3:
            DISPATCH1(3);
            break;
#if BLOCK_H == 8
        case 4:
            DISPATCH1(4);
            break;
        case 5:
            DISPATCH1(5);
            break;
        case 6:
            DISPATCH1(6);
            break;
        case 7:
            DISPATCH1(7);
            break;
#endif
    }
#undef DISPATCH1
#undef DISPATCH2
#undef unroll_prefetch_cb
}
// Same BLOCK_H x 8 tiling as do_conv_tpl_enable_prefetch, but without any
// software prefetch hints.  kern_direct_f16 selects this variant for
// small inputs (IH <= 100 or IW <= 100), where prefetching is presumably
// not worth its overhead.
template <int FH>
void do_conv_tpl_disable_prefetch(const __fp16* src,
                                  const __fp16* filter, __fp16* dst,
                                  const int IH, const int IW, const int OH,
                                  const int OW, const int FW) {
    const int hbeg = 0, hend = OH;
    const int wbeg = 0, wend = OW;
    int i, j;
    for (i = hbeg; i + BLOCK_H <= hend; i += BLOCK_H) {
        // Full-size BLOCK_H x 8 tiles across the row band.
        for (j = wbeg; j + 8 <= wend; j += 8) {
            do_pixel<FH, BLOCK_H, 8>(src, filter, dst, IH, IW, OH, OW, FW, i,
                                     j);
        }
// Right border: one BLOCK_H x width tile, width in 1..7.
#define DISPATCH(width)                                                       \
    do {                                                                      \
        do_pixel<FH, BLOCK_H, width>(src, filter, dst, IH, IW, OH, OW, FW, i, \
                                     j);                                      \
    } while (0)
        switch (wend - j) {
            case 1:
                DISPATCH(1);
                break;
            case 2:
                DISPATCH(2);
                break;
            case 3:
                DISPATCH(3);
                break;
            case 4:
                DISPATCH(4);
                break;
            case 5:
                DISPATCH(5);
                break;
            case 6:
                DISPATCH(6);
                break;
            case 7:
                DISPATCH(7);
                break;
        }
#undef DISPATCH
    }
// Bottom-right corner tile of size height x width (both 1..7).
#define DISPATCH2(height, width)                                              \
    do {                                                                      \
        do_pixel<FH, height, width>(src, filter, dst, IH, IW, OH, OW, FW, i,  \
                                    j);                                       \
    } while (0)
// Bottom border: one band of reduced height (1..7): full-width 8-column
// tiles first, then the corner remainder via DISPATCH2.
#define DISPATCH1(height)                                                     \
    do {                                                                      \
        for (j = wbeg; j + 8 <= wend; j += 8) {                               \
            do_pixel<FH, height, 8>(src, filter, dst, IH, IW, OH, OW, FW, i,  \
                                    j);                                       \
        }                                                                     \
        switch (wend - j) {                                                   \
            case 1:                                                           \
                DISPATCH2(height, 1);                                         \
                break;                                                        \
            case 2:                                                           \
                DISPATCH2(height, 2);                                         \
                break;                                                        \
            case 3:                                                           \
                DISPATCH2(height, 3);                                         \
                break;                                                        \
            case 4:                                                           \
                DISPATCH2(height, 4);                                         \
                break;                                                        \
            case 5:                                                           \
                DISPATCH2(height, 5);                                         \
                break;                                                        \
            case 6:                                                           \
                DISPATCH2(height, 6);                                         \
                break;                                                        \
            case 7:                                                           \
                DISPATCH2(height, 7);                                         \
                break;                                                        \
        }                                                                     \
    } while (0)
    // Bottom-border heights 4..7 can only occur when BLOCK_H == 8.
    switch (hend - i) {
        case 1:
            DISPATCH1(1);
            break;
        case 2:
            DISPATCH1(2);
            break;
        case 3:
            DISPATCH1(3);
            break;
#if BLOCK_H == 8
        case 4:
            DISPATCH1(4);
            break;
        case 5:
            DISPATCH1(5);
            break;
        case 6:
            DISPATCH1(6);
            break;
        case 7:
            DISPATCH1(7);
            break;
#endif
    }
#undef DISPATCH1
#undef DISPATCH2
}
} // anonymous namespace | |||||
// Entry point for the direct fp16 convolution kernel.
// Dispatches the runtime filter height FH (must be 1..7) to a
// compile-time template instantiation, and picks the prefetching variant
// when both input dimensions exceed 100 (large working set).
// MIDOUT_BEGIN/END mark each instantiation for megdnn's midout
// dead-code-elimination tooling; GAO returns out of the function once the
// selected template has run.
void conv_bias::kern_direct_f16(const __fp16* src,
                                const __fp16* filter, __fp16* dst,
                                const int IH, const int IW, const int OH,
                                const int OW, const int FH, const int FW) {
    megdnn_assert_internal(FH <= 7);
    if (IH > 100 && IW > 100) {
#define GAO(FH)                                                               \
    do {                                                                      \
        return do_conv_tpl_enable_prefetch<FH>(src, filter, dst, IH, IW, OH,  \
                                               OW, FW);                       \
    } while (0)
        switch (FH) {
            case 1:
                MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(0)) { GAO(1); }
                MIDOUT_END();
                break;
            case 2:
                MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(1)) { GAO(2); }
                MIDOUT_END();
                break;
            case 3:
                MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(2)) { GAO(3); }
                MIDOUT_END();
                break;
            case 4:
                MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(3)) { GAO(4); }
                MIDOUT_END();
                break;
            case 5:
                MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(4)) { GAO(5); }
                MIDOUT_END();
                break;
            case 6:
                MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(5)) { GAO(6); }
                MIDOUT_END();
                break;
            case 7:
                MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(6)) { GAO(7); }
                MIDOUT_END();
                break;
        }
#undef GAO
    } else {
#define GAO(FH)                                                               \
    do {                                                                      \
        return do_conv_tpl_disable_prefetch<FH>(src, filter, dst, IH, IW, OH, \
                                                OW, FW);                      \
    } while (0)
        switch (FH) {
            case 1:
                MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(0)) { GAO(1); }
                MIDOUT_END();
                break;
            case 2:
                MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(1)) { GAO(2); }
                MIDOUT_END();
                break;
            case 3:
                MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(2)) { GAO(3); }
                MIDOUT_END();
                break;
            case 4:
                MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(3)) { GAO(4); }
                MIDOUT_END();
                break;
            case 5:
                MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(4)) { GAO(5); }
                MIDOUT_END();
                break;
            case 6:
                MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(5)) { GAO(6); }
                MIDOUT_END();
                break;
            case 7:
                MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(6)) { GAO(7); }
                MIDOUT_END();
                break;
        }
#undef GAO
    }
    // Reached only when FH matched no case above (e.g. FH < 1), or when a
    // midout-stripped instantiation fell through: internal error.
    megdnn_assert_internal(0);
}
#endif | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,32 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/conv_bias/f16/direct.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#pragma once | |||||
#include <cstddef> | |||||
#include "megdnn/dtype.h" | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
namespace megdnn { | |||||
namespace arm_common { | |||||
namespace fp16{ | |||||
namespace conv_bias { | |||||
void kern_direct_f16(const __fp16* src, const __fp16* filter, | |||||
__fp16* dst, const int IH, const int IW, const int OH, | |||||
const int OW, const int FH, const int FW); | |||||
} // namespace conv_bias
} // namespace fp16 | |||||
} // namespace arm_common | |||||
} // namespace megdnn | |||||
#endif | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,522 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/conv_bias/f16/do_conv_stride1.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include <algorithm> | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
#include "./do_conv_stride1.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/arm_common/conv_bias/postprocess_helper.h" | |||||
using namespace megdnn; | |||||
using namespace arm_common; | |||||
using namespace fp16; | |||||
using namespace conv_stride1; | |||||
using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam; | |||||
using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam; | |||||
// Accumulating 2x2 stride-1 fp16 convolution summed over all IC input
// channels.  dst already holds the running value (each tile is loaded,
// accumulated into, and stored back), so the caller presumably pre-fills
// it with bias/zero -- confirm against the caller.
// Channels are processed two at a time so both channels' 2x2 filters
// (4 + 4 taps) fit in a single float16x8_t register.
// NOTE(review): the inner loop covers OW in chunks of 8 and leaves any
// OW % 8 tail untouched -- presumably OW is padded to a multiple of 8 by
// the caller; TODO confirm.
void conv_stride1::do_conv_2x2_stride1(const __fp16* src, const __fp16* filter, __fp16* dst,
                                       size_t IH, size_t IW, size_t OH, size_t OW,
                                       size_t IC) {
    // Advance from the end of one consumed output row to the start of the
    // next input row (valid-mode convolution: IW >= OW).
    const size_t tail_step = IW - OW;
    //! unroll of 2
    size_t ic = 0;
    for (; ic + 1 < IC; ic += 2) {
        const __fp16* src_ptr = src + IW * IH * ic;
        const __fp16* src_ptr1 = src_ptr + IW * IH;
        __fp16* outptr = dst;
        // r00/r01: rows 0 and 1 of channel ic; r10/r11: rows 0 and 1 of
        // channel ic + 1.
        const __fp16* r00 = src_ptr;
        const __fp16* r01 = src_ptr + IW;
        const __fp16* r10 = src_ptr1;
        const __fp16* r11 = src_ptr1 + IW;
        // 4 filter taps per channel for a 2x2 kernel; lanes 0..3 belong
        // to channel ic, lanes 4..7 to channel ic + 1.
        const __fp16* k0 = filter + ic * 4;
        float16x8_t _k0 = vld1q_f16(k0);
        rep(h, OH) {
            int width = OW >> 3;
            rep(i, width) {
                float16x8_t _r000 = vld1q_f16(r00);
                float16x8_t _r010 = vld1q_f16(r01);
                float16x8_t _r001 = vld1q_f16(r00 + 1);
                float16x8_t _r011 = vld1q_f16(r01 + 1);
                float16x8_t _r100 = vld1q_f16(r10);
                float16x8_t _r110 = vld1q_f16(r11);
                float16x8_t _r101 = vld1q_f16(r10 + 1);
                float16x8_t _r111 = vld1q_f16(r11 + 1);
                // Accumulate both channels' 2x2 contributions onto the
                // existing dst values.
                float16x8_t _sum = vld1q_f16(outptr);
                _sum = vmlaq_lane_f16(_sum, _r000, vget_low_f16(_k0), 0);
                _sum = vmlaq_lane_f16(_sum, _r001, vget_low_f16(_k0), 1);
                _sum = vmlaq_lane_f16(_sum, _r010, vget_low_f16(_k0), 2);
                _sum = vmlaq_lane_f16(_sum, _r011, vget_low_f16(_k0), 3);
                _sum = vmlaq_lane_f16(_sum, _r100, vget_high_f16(_k0), 0);
                _sum = vmlaq_lane_f16(_sum, _r101, vget_high_f16(_k0), 1);
                _sum = vmlaq_lane_f16(_sum, _r110, vget_high_f16(_k0), 2);
                _sum = vmlaq_lane_f16(_sum, _r111, vget_high_f16(_k0), 3);
                vst1q_f16(outptr, _sum);
                r00 += 8;
                r01 += 8;
                r10 += 8;
                r11 += 8;
                outptr += 8;
            }
            r00 += tail_step;
            r01 += tail_step;
            r10 += tail_step;
            r11 += tail_step;
        }
    }
    // Remainder: single leftover channel when IC is odd.
    for (; ic < IC; ic++) {
        const __fp16* src_ptr = src + IW * IH * ic;
        __fp16* outptr = dst;
        const __fp16* r0 = src_ptr;
        const __fp16* r1 = src_ptr + IW;
        const __fp16* k0 = filter + ic * 4;
        // Broadcast each of the 4 taps into its own register.
        float16x8_t _k0 = vdupq_n_f16(k0[0]);
        float16x8_t _k1 = vdupq_n_f16(k0[1]);
        float16x8_t _k2 = vdupq_n_f16(k0[2]);
        float16x8_t _k3 = vdupq_n_f16(k0[3]);
        rep(h, OH) {
            int width = OW >> 3;
            rep(i, width) {
                float16x8_t _r00 = vld1q_f16(r0);
                float16x8_t _r10 = vld1q_f16(r1);
                float16x8_t _r01 = vld1q_f16(r0 + 1);
                float16x8_t _r11 = vld1q_f16(r1 + 1);
                // Two independent accumulator chains (_sum, _sum2) to
                // shorten the dependency chain, merged at the end.
                float16x8_t _sum = vld1q_f16(outptr);
                float16x8_t _sum2;
                _sum = vmlaq_f16(_sum, _r00, _k0);
                _sum2 = vmulq_f16(_r01, _k1);
                _sum = vmlaq_f16(_sum, _r10, _k2);
                _sum2 = vmlaq_f16(_sum2, _r11, _k3);
                _sum = vaddq_f16(_sum, _sum2);
                vst1q_f16(outptr, _sum);
                r0 += 8;
                r1 += 8;
                outptr += 8;
            }
            r0 += tail_step;
            r1 += tail_step;
        }
    }
}
// Accumulating 3x3 stride-1 fp16 convolution summed over all IC input
// channels; dst is read-modify-written (caller presumably pre-fills it
// with bias/zero -- confirm).  Two output rows are produced per
// iteration of the main loop to reuse the shared input rows r1/r2.
// The 9 filter taps are held in two overlapping registers:
// _k01234567 = taps 0..7, _k12345678 = taps 1..8, so lane 7 of the
// latter is tap 8 (the 9th coefficient).
// NOTE(review): only full 8-wide column chunks are processed (OW >> 3);
// any OW % 8 tail is presumably handled by caller-side padding -- confirm.
void conv_stride1::do_conv_3x3_stride1(const __fp16* src, const __fp16* filter, __fp16* dst,
                                       size_t IH, size_t IW, size_t OH, size_t OW,
                                       size_t IC) {
    // Advance from the end of one consumed output row to the start of the
    // next input row (valid-mode convolution: IW >= OW).
    const size_t tail_step = IW - OW;
    rep(ic, IC) {
        const __fp16* src_ptr = src + IW * IH * ic;
        __fp16* outptr = dst;
        __fp16* outptr2 = outptr + OW;
        const __fp16* r0 = src_ptr;
        const __fp16* r1 = src_ptr + IW;
        const __fp16* r2 = src_ptr + IW * 2;
        const __fp16* r3 = src_ptr + IW * 3;
        float16x8_t _k01234567 = vld1q_f16(filter);
        float16x8_t _k12345678 = vld1q_f16(filter + 1);
        size_t h = 0;
        // Main loop: two output rows per iteration; rows r1/r2 feed both.
        for (; h + 1 < OH; h += 2) {
            int width = OW >> 3;
            rep(i, width) {
                // _sum1/_sum3 accumulate onto existing dst values;
                // _sum2/_sum4 are secondary chains merged at the end to
                // shorten the dependency chain.
                float16x8_t _sum1 = vld1q_f16(outptr);
                float16x8_t _sum2 = vdupq_n_f16(0.f);
                float16x8_t _sum3 = vld1q_f16(outptr2);
                float16x8_t _sum4 = vdupq_n_f16(0.f);
                // For each input row, build the three shifted views
                // (offset 0, 1, 2) from two adjacent 8-lane loads.
                float16x8_t _r00 = vld1q_f16(r0);
                float16x8_t _r00n = vld1q_f16(r0 + 8);
                float16x8_t _r01 = vextq_f16(_r00, _r00n, 1);
                float16x8_t _r02 = vextq_f16(_r00, _r00n, 2);
                float16x8_t _r10 = vld1q_f16(r1);
                float16x8_t _r10n = vld1q_f16(r1 + 8);
                float16x8_t _r11 = vextq_f16(_r10, _r10n, 1);
                float16x8_t _r12 = vextq_f16(_r10, _r10n, 2);
                float16x8_t _r20 = vld1q_f16(r2);
                float16x8_t _r20n = vld1q_f16(r2 + 8);
                float16x8_t _r21 = vextq_f16(_r20, _r20n, 1);
                float16x8_t _r22 = vextq_f16(_r20, _r20n, 2);
                float16x8_t _r30 = vld1q_f16(r3);
                float16x8_t _r30n = vld1q_f16(r3 + 8);
                float16x8_t _r31 = vextq_f16(_r30, _r30n, 1);
                float16x8_t _r32 = vextq_f16(_r30, _r30n, 2);
                // Output row h: taps 0..8 against rows r0/r1/r2.
                _sum1 = vmlaq_low_lane_f16(_sum1, _r00, _k01234567, 0);
                _sum2 = vmlaq_low_lane_f16(_sum2, _r01, _k01234567, 1);
                _sum1 = vmlaq_low_lane_f16(_sum1, _r02, _k01234567, 2);
                _sum2 = vmlaq_low_lane_f16(_sum2, _r10, _k01234567, 3);
                _sum1 = vmlaq_high_lane_f16(_sum1, _r11, _k01234567, 4);
                _sum2 = vmlaq_high_lane_f16(_sum2, _r12, _k01234567, 5);
                _sum1 = vmlaq_high_lane_f16(_sum1, _r20, _k01234567, 6);
                _sum2 = vmlaq_high_lane_f16(_sum2, _r21, _k01234567, 7);
                _sum1 = vmlaq_high_lane_f16(_sum1, _r22, _k12345678, 7);
                // Output row h + 1: same taps against rows r1/r2/r3.
                _sum3 = vmlaq_low_lane_f16(_sum3, _r10, _k01234567, 0);
                _sum4 = vmlaq_low_lane_f16(_sum4, _r11, _k01234567, 1);
                _sum3 = vmlaq_low_lane_f16(_sum3, _r12, _k01234567, 2);
                _sum4 = vmlaq_low_lane_f16(_sum4, _r20, _k01234567, 3);
                _sum3 = vmlaq_high_lane_f16(_sum3, _r21, _k01234567, 4);
                _sum4 = vmlaq_high_lane_f16(_sum4, _r22, _k01234567, 5);
                _sum3 = vmlaq_high_lane_f16(_sum3, _r30, _k01234567, 6);
                _sum4 = vmlaq_high_lane_f16(_sum4, _r31, _k01234567, 7);
                _sum3 = vmlaq_high_lane_f16(_sum3, _r32, _k12345678, 7);
                _sum1 = vaddq_f16(_sum1, _sum2);
                _sum3 = vaddq_f16(_sum3, _sum4);
                vst1q_f16(outptr, _sum1);
                vst1q_f16(outptr2, _sum3);
                r0 += 8;
                r1 += 8;
                r2 += 8;
                r3 += 8;
                outptr += 8;
                outptr2 += 8;
            }
            // Two rows consumed: skip tail plus one full extra row.
            r0 += tail_step + IW;
            r1 += tail_step + IW;
            r2 += tail_step + IW;
            r3 += tail_step + IW;
            outptr += OW;
            outptr2 += OW;
        }
        // Remainder: single leftover output row when OH is odd.
        for (; h < OH; h++) {
            int width = OW >> 3;
            rep(i, width) {
                float16x8_t _sum1 = vld1q_f16(outptr);
                float16x8_t _sum2 = vdupq_n_f16(0.f);
                float16x8_t _r00 = vld1q_f16(r0);
                float16x8_t _r00n = vld1q_f16(r0 + 8);
                float16x8_t _r01 = vextq_f16(_r00, _r00n, 1);
                float16x8_t _r02 = vextq_f16(_r00, _r00n, 2);
                float16x8_t _r10 = vld1q_f16(r1);
                float16x8_t _r10n = vld1q_f16(r1 + 8);
                float16x8_t _r11 = vextq_f16(_r10, _r10n, 1);
                float16x8_t _r12 = vextq_f16(_r10, _r10n, 2);
                float16x8_t _r20 = vld1q_f16(r2);
                float16x8_t _r20n = vld1q_f16(r2 + 8);
                float16x8_t _r21 = vextq_f16(_r20, _r20n, 1);
                float16x8_t _r22 = vextq_f16(_r20, _r20n, 2);
                _sum1 = vmlaq_low_lane_f16(_sum1, _r00, _k01234567, 0);
                _sum2 = vmlaq_low_lane_f16(_sum2, _r01, _k01234567, 1);
                _sum1 = vmlaq_low_lane_f16(_sum1, _r02, _k01234567, 2);
                _sum2 = vmlaq_low_lane_f16(_sum2, _r10, _k01234567, 3);
                _sum1 = vmlaq_high_lane_f16(_sum1, _r11, _k01234567, 4);
                _sum2 = vmlaq_high_lane_f16(_sum2, _r12, _k01234567, 5);
                _sum1 = vmlaq_high_lane_f16(_sum1, _r20, _k01234567, 6);
                _sum2 = vmlaq_high_lane_f16(_sum2, _r21, _k01234567, 7);
                _sum1 = vmlaq_high_lane_f16(_sum1, _r22, _k12345678, 7);
                _sum1 = vaddq_f16(_sum1, _sum2);
                vst1q_f16(outptr, _sum1);
                r0 += 8;
                r1 += 8;
                r2 += 8;
                outptr += 8;
            }
            r0 += tail_step;
            r1 += tail_step;
            r2 += tail_step;
        }
        // 9 taps per channel for a 3x3 kernel.
        filter += 9;
    }
}
// Accumulating 5x5 stride-1 fp16 convolution summed over all IC input
// channels; dst is read-modify-written (caller presumably pre-fills it
// with bias/zero -- confirm).  Two output rows per main-loop iteration,
// sharing the four overlapping input rows r1..r4.
// The 25 filter taps are held in four registers: _k0 = taps 0..7,
// _k1 = taps 8..15, _k2 = taps 16..23, and _k3 = vld1q_f16(filter + 17)
// so that lane 7 of _k3 is tap 24 (the 25th coefficient).
// NOTE(review): only full 8-wide column chunks are processed (OW >> 3);
// any OW % 8 tail is presumably handled by caller-side padding -- confirm.
void conv_stride1::do_conv_5x5_stride1(const __fp16* src, const __fp16* filter, __fp16* dst,
                                       size_t IH, size_t IW, size_t OH, size_t OW,
                                       size_t IC) {
    // Advance from the end of one consumed output row to the start of the
    // next input row (valid-mode convolution: IW >= OW).
    const size_t tail_step = IW - OW;
    rep(ic, IC) {
        const __fp16* src_ptr = src + IW * IH * ic;
        __fp16* outptr = dst;
        __fp16* outptr2 = outptr + OW;
        const __fp16* r0 = src_ptr;
        const __fp16* r1 = src_ptr + IW;
        const __fp16* r2 = src_ptr + IW * 2;
        const __fp16* r3 = src_ptr + IW * 3;
        const __fp16* r4 = src_ptr + IW * 4;
        const __fp16* r5 = src_ptr + IW * 5;
        float16x8_t _k0 = vld1q_f16(filter);
        float16x8_t _k1 = vld1q_f16(filter + 8);
        float16x8_t _k2 = vld1q_f16(filter + 16);
        float16x8_t _k3 = vld1q_f16(filter + 17);
        size_t h = 0;
        // Main loop: two output rows per iteration; rows r1..r4 feed both.
        for (; h + 1 < OH; h += 2) {
            int width = OW >> 3;
            rep(i, width) {
                float16x8_t _sum = vld1q_f16(outptr);
                float16x8_t _sum2 = vld1q_f16(outptr2);
                // For each input row, build the five shifted views
                // (offset 0..4) from two adjacent 8-lane loads.
                float16x8_t _r00 = vld1q_f16(r0);
                float16x8_t _r05 = vld1q_f16(r0 + 8);
                float16x8_t _r01 = vextq_f16(_r00, _r05, 1);
                float16x8_t _r02 = vextq_f16(_r00, _r05, 2);
                float16x8_t _r03 = vextq_f16(_r00, _r05, 3);
                float16x8_t _r04 = vextq_f16(_r00, _r05, 4);
                float16x8_t _r10 = vld1q_f16(r1);
                float16x8_t _r15 = vld1q_f16(r1 + 8);
                float16x8_t _r11 = vextq_f16(_r10, _r15, 1);
                float16x8_t _r12 = vextq_f16(_r10, _r15, 2);
                float16x8_t _r13 = vextq_f16(_r10, _r15, 3);
                float16x8_t _r14 = vextq_f16(_r10, _r15, 4);
                float16x8_t _r20 = vld1q_f16(r2);
                float16x8_t _r25 = vld1q_f16(r2 + 8);
                float16x8_t _r21 = vextq_f16(_r20, _r25, 1);
                float16x8_t _r22 = vextq_f16(_r20, _r25, 2);
                float16x8_t _r23 = vextq_f16(_r20, _r25, 3);
                float16x8_t _r24 = vextq_f16(_r20, _r25, 4);
                float16x8_t _r30 = vld1q_f16(r3);
                float16x8_t _r35 = vld1q_f16(r3 + 8);
                float16x8_t _r31 = vextq_f16(_r30, _r35, 1);
                float16x8_t _r32 = vextq_f16(_r30, _r35, 2);
                float16x8_t _r33 = vextq_f16(_r30, _r35, 3);
                float16x8_t _r34 = vextq_f16(_r30, _r35, 4);
                float16x8_t _r40 = vld1q_f16(r4);
                float16x8_t _r45 = vld1q_f16(r4 + 8);
                float16x8_t _r41 = vextq_f16(_r40, _r45, 1);
                float16x8_t _r42 = vextq_f16(_r40, _r45, 2);
                float16x8_t _r43 = vextq_f16(_r40, _r45, 3);
                float16x8_t _r44 = vextq_f16(_r40, _r45, 4);
                float16x8_t _r50 = vld1q_f16(r5);
                float16x8_t _r55 = vld1q_f16(r5 + 8);
                float16x8_t _r51 = vextq_f16(_r50, _r55, 1);
                float16x8_t _r52 = vextq_f16(_r50, _r55, 2);
                float16x8_t _r53 = vextq_f16(_r50, _r55, 3);
                float16x8_t _r54 = vextq_f16(_r50, _r55, 4);
                // Output row h: taps 0..24 against rows r0..r4.
                _sum = vmlaq_low_lane_f16(_sum, _r00, _k0, 0);
                _sum = vmlaq_low_lane_f16(_sum, _r01, _k0, 1);
                _sum = vmlaq_low_lane_f16(_sum, _r02, _k0, 2);
                _sum = vmlaq_low_lane_f16(_sum, _r03, _k0, 3);
                _sum = vmlaq_high_lane_f16(_sum, _r04, _k0, 4);
                _sum = vmlaq_high_lane_f16(_sum, _r10, _k0, 5);
                _sum = vmlaq_high_lane_f16(_sum, _r11, _k0, 6);
                _sum = vmlaq_high_lane_f16(_sum, _r12, _k0, 7);
                _sum = vmlaq_low_lane_f16(_sum, _r13, _k1, 0);
                _sum = vmlaq_low_lane_f16(_sum, _r14, _k1, 1);
                _sum = vmlaq_low_lane_f16(_sum, _r20, _k1, 2);
                _sum = vmlaq_low_lane_f16(_sum, _r21, _k1, 3);
                _sum = vmlaq_high_lane_f16(_sum, _r22, _k1, 4);
                _sum = vmlaq_high_lane_f16(_sum, _r23, _k1, 5);
                _sum = vmlaq_high_lane_f16(_sum, _r24, _k1, 6);
                _sum = vmlaq_high_lane_f16(_sum, _r30, _k1, 7);
                _sum = vmlaq_low_lane_f16(_sum, _r31, _k2, 0);
                _sum = vmlaq_low_lane_f16(_sum, _r32, _k2, 1);
                _sum = vmlaq_low_lane_f16(_sum, _r33, _k2, 2);
                _sum = vmlaq_low_lane_f16(_sum, _r34, _k2, 3);
                _sum = vmlaq_high_lane_f16(_sum, _r40, _k2, 4);
                _sum = vmlaq_high_lane_f16(_sum, _r41, _k2, 5);
                _sum = vmlaq_high_lane_f16(_sum, _r42, _k2, 6);
                _sum = vmlaq_high_lane_f16(_sum, _r43, _k2, 7);
                _sum = vmlaq_high_lane_f16(_sum, _r44, _k3, 7);
                // Output row h + 1: same taps against rows r1..r5.
                _sum2 = vmlaq_low_lane_f16(_sum2, _r10, _k0, 0);
                _sum2 = vmlaq_low_lane_f16(_sum2, _r11, _k0, 1);
                _sum2 = vmlaq_low_lane_f16(_sum2, _r12, _k0, 2);
                _sum2 = vmlaq_low_lane_f16(_sum2, _r13, _k0, 3);
                _sum2 = vmlaq_high_lane_f16(_sum2, _r14, _k0, 4);
                _sum2 = vmlaq_high_lane_f16(_sum2, _r20, _k0, 5);
                _sum2 = vmlaq_high_lane_f16(_sum2, _r21, _k0, 6);
                _sum2 = vmlaq_high_lane_f16(_sum2, _r22, _k0, 7);
                _sum2 = vmlaq_low_lane_f16(_sum2, _r23, _k1, 0);
                _sum2 = vmlaq_low_lane_f16(_sum2, _r24, _k1, 1);
                _sum2 = vmlaq_low_lane_f16(_sum2, _r30, _k1, 2);
                _sum2 = vmlaq_low_lane_f16(_sum2, _r31, _k1, 3);
                _sum2 = vmlaq_high_lane_f16(_sum2, _r32, _k1, 4);
                _sum2 = vmlaq_high_lane_f16(_sum2, _r33, _k1, 5);
                _sum2 = vmlaq_high_lane_f16(_sum2, _r34, _k1, 6);
                _sum2 = vmlaq_high_lane_f16(_sum2, _r40, _k1, 7);
                _sum2 = vmlaq_low_lane_f16(_sum2, _r41, _k2, 0);
                _sum2 = vmlaq_low_lane_f16(_sum2, _r42, _k2, 1);
                _sum2 = vmlaq_low_lane_f16(_sum2, _r43, _k2, 2);
                _sum2 = vmlaq_low_lane_f16(_sum2, _r44, _k2, 3);
                _sum2 = vmlaq_high_lane_f16(_sum2, _r50, _k2, 4);
                _sum2 = vmlaq_high_lane_f16(_sum2, _r51, _k2, 5);
                _sum2 = vmlaq_high_lane_f16(_sum2, _r52, _k2, 6);
                _sum2 = vmlaq_high_lane_f16(_sum2, _r53, _k2, 7);
                _sum2 = vmlaq_high_lane_f16(_sum2, _r54, _k3, 7);
                vst1q_f16(outptr, _sum);
                vst1q_f16(outptr2, _sum2);
                r0 += 8;
                r1 += 8;
                r2 += 8;
                r3 += 8;
                r4 += 8;
                r5 += 8;
                outptr += 8;
                outptr2 += 8;
            }
            // Two rows consumed: skip tail plus one full extra row.
            r0 += tail_step + IW;
            r1 += tail_step + IW;
            r2 += tail_step + IW;
            r3 += tail_step + IW;
            r4 += tail_step + IW;
            r5 += tail_step + IW;
            outptr += OW;
            outptr2 += OW;
        }
        // Remainder: single leftover output row when OH is odd.
        for (; h < OH; h++) {
            int width = OW >> 3;
            rep(i, width) {
                float16x8_t _sum = vld1q_f16(outptr);
                float16x8_t _r00 = vld1q_f16(r0);
                float16x8_t _r05 = vld1q_f16(r0 + 8);
                float16x8_t _r01 = vextq_f16(_r00, _r05, 1);
                float16x8_t _r02 = vextq_f16(_r00, _r05, 2);
                float16x8_t _r03 = vextq_f16(_r00, _r05, 3);
                float16x8_t _r04 = vextq_f16(_r00, _r05, 4);
                float16x8_t _r10 = vld1q_f16(r1);
                float16x8_t _r15 = vld1q_f16(r1 + 8);
                float16x8_t _r11 = vextq_f16(_r10, _r15, 1);
                float16x8_t _r12 = vextq_f16(_r10, _r15, 2);
                float16x8_t _r13 = vextq_f16(_r10, _r15, 3);
                float16x8_t _r14 = vextq_f16(_r10, _r15, 4);
                float16x8_t _r20 = vld1q_f16(r2);
                float16x8_t _r25 = vld1q_f16(r2 + 8);
                float16x8_t _r21 = vextq_f16(_r20, _r25, 1);
                float16x8_t _r22 = vextq_f16(_r20, _r25, 2);
                float16x8_t _r23 = vextq_f16(_r20, _r25, 3);
                float16x8_t _r24 = vextq_f16(_r20, _r25, 4);
                float16x8_t _r30 = vld1q_f16(r3);
                float16x8_t _r35 = vld1q_f16(r3 + 8);
                float16x8_t _r31 = vextq_f16(_r30, _r35, 1);
                float16x8_t _r32 = vextq_f16(_r30, _r35, 2);
                float16x8_t _r33 = vextq_f16(_r30, _r35, 3);
                float16x8_t _r34 = vextq_f16(_r30, _r35, 4);
                float16x8_t _r40 = vld1q_f16(r4);
                float16x8_t _r45 = vld1q_f16(r4 + 8);
                float16x8_t _r41 = vextq_f16(_r40, _r45, 1);
                float16x8_t _r42 = vextq_f16(_r40, _r45, 2);
                float16x8_t _r43 = vextq_f16(_r40, _r45, 3);
                float16x8_t _r44 = vextq_f16(_r40, _r45, 4);
                _sum = vmlaq_low_lane_f16(_sum, _r00, _k0, 0);
                _sum = vmlaq_low_lane_f16(_sum, _r01, _k0, 1);
                _sum = vmlaq_low_lane_f16(_sum, _r02, _k0, 2);
                _sum = vmlaq_low_lane_f16(_sum, _r03, _k0, 3);
                _sum = vmlaq_high_lane_f16(_sum, _r04, _k0, 4);
                _sum = vmlaq_high_lane_f16(_sum, _r10, _k0, 5);
                _sum = vmlaq_high_lane_f16(_sum, _r11, _k0, 6);
                _sum = vmlaq_high_lane_f16(_sum, _r12, _k0, 7);
                _sum = vmlaq_low_lane_f16(_sum, _r13, _k1, 0);
                _sum = vmlaq_low_lane_f16(_sum, _r14, _k1, 1);
                _sum = vmlaq_low_lane_f16(_sum, _r20, _k1, 2);
                _sum = vmlaq_low_lane_f16(_sum, _r21, _k1, 3);
                _sum = vmlaq_high_lane_f16(_sum, _r22, _k1, 4);
                _sum = vmlaq_high_lane_f16(_sum, _r23, _k1, 5);
                _sum = vmlaq_high_lane_f16(_sum, _r24, _k1, 6);
                _sum = vmlaq_high_lane_f16(_sum, _r30, _k1, 7);
                _sum = vmlaq_low_lane_f16(_sum, _r31, _k2, 0);
                _sum = vmlaq_low_lane_f16(_sum, _r32, _k2, 1);
                _sum = vmlaq_low_lane_f16(_sum, _r33, _k2, 2);
                _sum = vmlaq_low_lane_f16(_sum, _r34, _k2, 3);
                _sum = vmlaq_high_lane_f16(_sum, _r40, _k2, 4);
                _sum = vmlaq_high_lane_f16(_sum, _r41, _k2, 5);
                _sum = vmlaq_high_lane_f16(_sum, _r42, _k2, 6);
                _sum = vmlaq_high_lane_f16(_sum, _r43, _k2, 7);
                _sum = vmlaq_high_lane_f16(_sum, _r44, _k3, 7);
                vst1q_f16(outptr, _sum);
                r0 += 8;
                r1 += 8;
                r2 += 8;
                r3 += 8;
                r4 += 8;
                outptr += 8;
            }
            r0 += tail_step;
            r1 += tail_step;
            r2 += tail_step;
            r3 += tail_step;
            r4 += tail_step;
        }
        // 25 taps per channel for a 5x5 kernel.
        filter += 25;
    }
}
#endif | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,32 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/conv_bias/f16/do_conv_stride1.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#pragma once | |||||
#include "src/fallback/conv_bias/opr_impl.h" | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
namespace megdnn { | |||||
namespace arm_common { | |||||
namespace fp16 { | |||||
namespace conv_stride1 { | |||||
void do_conv_2x2_stride1(const __fp16* src, const __fp16* filter, __fp16* dst, | |||||
size_t IH, size_t IW, size_t OH, size_t OW, size_t IC); | |||||
void do_conv_3x3_stride1(const __fp16* src, const __fp16* filter, __fp16* dst, | |||||
size_t IH, size_t IW, size_t OH, size_t OW, size_t IC); | |||||
void do_conv_5x5_stride1(const __fp16* src, const __fp16* filter, __fp16* dst, | |||||
size_t IH, size_t IW, size_t OH, size_t OW, size_t IC); | |||||
} // namespace conv_stride1 | |||||
} // namespace fp16 | |||||
} // namespace arm_common | |||||
} // namespace megdnn | |||||
#endif | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,341 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/conv_bias/f16/helper.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#pragma once | |||||
#include "src/common/unroll_macro.h" | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
// 4x4 fp16 matrix product on float16x4_t row vectors: sum = a * b.
// `sum`, `a`, `b` are name prefixes -- the macro operates on sum##0..3,
// a##0..3, b##0..3, each holding one 4-lane row.  Row r of the result is
// the sum over k of a[r][k] (a lane of a##r) times row b##k.
// (Comments cannot be placed inside the continuation lines below.)
#define MATRIX_MUL4x4_fp16(sum, a, b)                        \
    sum##0 = vmul_lane_f16(b##0, a##0, 0);                   \
    sum##1 = vmul_lane_f16(b##0, a##1, 0);                   \
    sum##2 = vmul_lane_f16(b##0, a##2, 0);                   \
    sum##3 = vmul_lane_f16(b##0, a##3, 0);                   \
    sum##0 = vadd_f16(sum##0, vmul_lane_f16(b##1, a##0, 1)); \
    sum##1 = vadd_f16(sum##1, vmul_lane_f16(b##1, a##1, 1)); \
    sum##2 = vadd_f16(sum##2, vmul_lane_f16(b##1, a##2, 1)); \
    sum##3 = vadd_f16(sum##3, vmul_lane_f16(b##1, a##3, 1)); \
    sum##0 = vadd_f16(sum##0, vmul_lane_f16(b##2, a##0, 2)); \
    sum##1 = vadd_f16(sum##1, vmul_lane_f16(b##2, a##1, 2)); \
    sum##2 = vadd_f16(sum##2, vmul_lane_f16(b##2, a##2, 2)); \
    sum##3 = vadd_f16(sum##3, vmul_lane_f16(b##2, a##3, 2)); \
    sum##0 = vadd_f16(sum##0, vmul_lane_f16(b##3, a##0, 3)); \
    sum##1 = vadd_f16(sum##1, vmul_lane_f16(b##3, a##1, 3)); \
    sum##2 = vadd_f16(sum##2, vmul_lane_f16(b##3, a##2, 3)); \
    sum##3 = vadd_f16(sum##3, vmul_lane_f16(b##3, a##3, 3));
#define CONCAT(a, id) a##id | |||||
#if MEGDNN_AARCH64 | |||||
// Transpose a 4x4 fp16 matrix held in four float16x4_t registers
// (a##0..a##3, each wrapped in a struct with a `.value` member) into
// ret##0..ret##3.  AArch64 only (vzip1/vzip2 single-result forms).
// Stage 1 interleaves fp16 lane pairs; stage 2 reinterprets as s32 so
// that each zip moves a pair of fp16 lanes at once.
#define TRANSPOSE_4x4(a, ret)                                    \
    do {                                                         \
        auto b00 = vzip1_f16(CONCAT(a, 0).value,                 \
                             CONCAT(a, 1).value); /*a1b1a2b2*/   \
        auto b01 = vzip2_f16(CONCAT(a, 0).value,                 \
                             CONCAT(a, 1).value); /*a3b3a4b4*/   \
        auto b10 = vzip1_f16(CONCAT(a, 2).value,                 \
                             CONCAT(a, 3).value); /*c1d1c2d2*/   \
        auto b11 = vzip2_f16(CONCAT(a, 2).value,                 \
                             CONCAT(a, 3).value); /*c3d3c4d4*/   \
        auto s32b00 = vreinterpret_s32_f16(b00);                 \
        auto s32b01 = vreinterpret_s32_f16(b01);                 \
        auto s32b10 = vreinterpret_s32_f16(b10);                 \
        auto s32b11 = vreinterpret_s32_f16(b11);                 \
        CONCAT(ret, 0).value =                                   \
                vreinterpret_f16_s32(vzip1_s32(s32b00, s32b10)); \
        CONCAT(ret, 1).value =                                   \
                vreinterpret_f16_s32(vzip2_s32(s32b00, s32b10)); \
        CONCAT(ret, 2).value =                                   \
                vreinterpret_f16_s32(vzip1_s32(s32b01, s32b11)); \
        CONCAT(ret, 3).value =                                   \
                vreinterpret_f16_s32(vzip2_s32(s32b01, s32b11)); \
    } while (0);
//! AArch64: transpose a 4x8 fp16 tile.  Four float16x8_t rows
//! (CONCAT(a, 0..3).value) become eight float16x4_t columns
//! (CONCAT(ret, 0..7).value): fp16 lane zips, then s32 pair zips, then
//! the eight 4-lane halves are split out with vget_low/high_f16.
#define TRANSPOSE_4x8(a, ret)                                              \
    do {                                                                   \
        auto b00 = vzip1q_f16(CONCAT(a, 0).value,                          \
                              CONCAT(a, 1).value); /*a1b1a2b2a3b3a4b4*/    \
        auto b01 = vzip2q_f16(CONCAT(a, 0).value,                          \
                              CONCAT(a, 1).value); /*a5b5a6b6a7b7a8b8*/    \
        auto b10 = vzip1q_f16(CONCAT(a, 2).value,                          \
                              CONCAT(a, 3).value); /*c1d1c2d2c3d3c4d4*/    \
        auto b11 = vzip2q_f16(CONCAT(a, 2).value,                          \
                              CONCAT(a, 3).value); /*c5d5c6d6c7d7c8d8*/    \
        auto s32b00 = vreinterpretq_s32_f16(b00);                          \
        auto s32b01 = vreinterpretq_s32_f16(b01);                          \
        auto s32b10 = vreinterpretq_s32_f16(b10);                          \
        auto s32b11 = vreinterpretq_s32_f16(b11);                          \
        auto f16b00 = vreinterpretq_f16_s32(                               \
                vzip1q_s32(s32b00, s32b10)); /*a1b1c1d1a2b2c2d2*/          \
        auto f16b01 = vreinterpretq_f16_s32(                               \
                vzip2q_s32(s32b00, s32b10)); /*a3b3c3d3a4b4c4d4*/          \
        auto f16b10 = vreinterpretq_f16_s32(                               \
                vzip1q_s32(s32b01, s32b11)); /*a5b5c5d5a6b6c6d6*/          \
        auto f16b11 = vreinterpretq_f16_s32(                               \
                vzip2q_s32(s32b01, s32b11)); /*a7b7c7d7a8b8c8d8*/          \
        CONCAT(ret, 0).value = vget_low_f16(f16b00);                       \
        CONCAT(ret, 1).value = vget_high_f16(f16b00);                      \
        CONCAT(ret, 2).value = vget_low_f16(f16b01);                       \
        CONCAT(ret, 3).value = vget_high_f16(f16b01);                      \
        CONCAT(ret, 4).value = vget_low_f16(f16b10);                       \
        CONCAT(ret, 5).value = vget_high_f16(f16b10);                      \
        CONCAT(ret, 6).value = vget_low_f16(f16b11);                       \
        CONCAT(ret, 7).value = vget_high_f16(f16b11);                      \
    } while (0);
//! AArch64: transpose an 8x4 fp16 tile.  Eight float16x4_t rows
//! (CONCAT(a, 0..7).value) become four float16x8_t rows
//! (CONCAT(ret, 0..3).value): fp16 lane zips per row pair, s32 pair zips,
//! and a final vcombine of the two 4-lane halves.
#define TRANSPOSE_8x4(a, ret)                                                 \
    do {                                                                      \
        auto b00 = vzip1_f16(CONCAT(a, 0).value,                              \
                             CONCAT(a, 1).value); /*a1b1a2b2*/                \
        auto b01 = vzip2_f16(CONCAT(a, 0).value,                              \
                             CONCAT(a, 1).value); /*a3b3a4b4*/                \
        auto b10 = vzip1_f16(CONCAT(a, 2).value,                              \
                             CONCAT(a, 3).value); /*c1d1c2d2*/                \
        auto b11 = vzip2_f16(CONCAT(a, 2).value,                              \
                             CONCAT(a, 3).value); /*c3d3c4d4*/                \
        auto b20 = vzip1_f16(CONCAT(a, 4).value,                              \
                             CONCAT(a, 5).value); /*e1f1e2f2*/                \
        auto b21 = vzip2_f16(CONCAT(a, 4).value,                              \
                             CONCAT(a, 5).value); /*e3f3e4f4*/                \
        auto b30 = vzip1_f16(CONCAT(a, 6).value,                              \
                             CONCAT(a, 7).value); /*g1h1g2h2*/                \
        auto b31 = vzip2_f16(CONCAT(a, 6).value,                              \
                             CONCAT(a, 7).value); /*g3h3g4h4*/                \
        auto s32b00 = vreinterpret_s32_f16(b00);                              \
        auto s32b01 = vreinterpret_s32_f16(b01);                              \
        auto s32b10 = vreinterpret_s32_f16(b10);                              \
        auto s32b11 = vreinterpret_s32_f16(b11);                              \
        auto s32b20 = vreinterpret_s32_f16(b20);                              \
        auto s32b21 = vreinterpret_s32_f16(b21);                              \
        auto s32b30 = vreinterpret_s32_f16(b30);                              \
        auto s32b31 = vreinterpret_s32_f16(b31);                              \
        CONCAT(ret, 0).value =                                                \
                vcombine_f16(vreinterpret_f16_s32(vzip1_s32(s32b00, s32b10)), \
                             vreinterpret_f16_s32(vzip1_s32(s32b20, s32b30))); \
        CONCAT(ret, 1).value =                                                \
                vcombine_f16(vreinterpret_f16_s32(vzip2_s32(s32b00, s32b10)), \
                             vreinterpret_f16_s32(vzip2_s32(s32b20, s32b30))); \
        CONCAT(ret, 2).value =                                                \
                vcombine_f16(vreinterpret_f16_s32(vzip1_s32(s32b01, s32b11)), \
                             vreinterpret_f16_s32(vzip1_s32(s32b21, s32b31))); \
        CONCAT(ret, 3).value =                                                \
                vcombine_f16(vreinterpret_f16_s32(vzip2_s32(s32b01, s32b11)), \
                             vreinterpret_f16_s32(vzip2_s32(s32b21, s32b31))); \
    } while (0);
//! AArch64: transpose an 8x8 fp16 tile held in eight float16x8_t rows.
//! Three zip stages at growing element width: fp16 lanes, then s32 pairs,
//! then s64 halves; each stage doubles the stride of gathered elements.
#define TRANSPOSE_8x8(a, ret)                                              \
    do {                                                                   \
        auto b00 = vzip1q_f16(CONCAT(a, 0).value,                          \
                              CONCAT(a, 1).value); /*a1b1a2b2 a3b3a4b4*/   \
        auto b01 = vzip2q_f16(CONCAT(a, 0).value,                          \
                              CONCAT(a, 1).value); /*a5b5a6b6 a7b7a8b8*/   \
        auto b10 = vzip1q_f16(CONCAT(a, 2).value,                          \
                              CONCAT(a, 3).value); /*c1d1c2d2 c3d3c4d4*/   \
        auto b11 = vzip2q_f16(CONCAT(a, 2).value,                          \
                              CONCAT(a, 3).value); /*c5d5c6d6 c7d7c8d8*/   \
        auto b20 = vzip1q_f16(CONCAT(a, 4).value,                          \
                              CONCAT(a, 5).value); /*e1f1e2f2 e3f3e4f4*/   \
        auto b21 = vzip2q_f16(CONCAT(a, 4).value,                          \
                              CONCAT(a, 5).value); /*e5f5e6f6 e7f7e8f8*/   \
        auto b30 = vzip1q_f16(CONCAT(a, 6).value,                          \
                              CONCAT(a, 7).value); /*g1h1g2h2 g3h3g4h4*/   \
        auto b31 = vzip2q_f16(CONCAT(a, 6).value,                          \
                              CONCAT(a, 7).value); /*g5h5g6h6 g7h7g8h8*/   \
        auto s32b00 = vreinterpretq_s32_f16(b00);                          \
        auto s32b01 = vreinterpretq_s32_f16(b01);                          \
        auto s32b10 = vreinterpretq_s32_f16(b10);                          \
        auto s32b11 = vreinterpretq_s32_f16(b11);                          \
        auto s32b20 = vreinterpretq_s32_f16(b20);                          \
        auto s32b21 = vreinterpretq_s32_f16(b21);                          \
        auto s32b30 = vreinterpretq_s32_f16(b30);                          \
        auto s32b31 = vreinterpretq_s32_f16(b31);                          \
        auto s64b00 = vreinterpretq_s64_s32(                               \
                vzip1q_s32(s32b00, s32b10)); /*a1b1c1d1 a2b2c2d2*/         \
        auto s64b01 = vreinterpretq_s64_s32(                               \
                vzip2q_s32(s32b00, s32b10)); /*a3b3c3d3 a4b4c4d4*/         \
        auto s64b10 = vreinterpretq_s64_s32(                               \
                vzip1q_s32(s32b01, s32b11)); /*a5b5c5d5 a6b6c6d6*/         \
        auto s64b11 = vreinterpretq_s64_s32(                               \
                vzip2q_s32(s32b01, s32b11)); /*a7b7c7d7 a8b8c8d8*/         \
        auto s64b20 = vreinterpretq_s64_s32(                               \
                vzip1q_s32(s32b20, s32b30)); /*e1f1g1h1 e2f2g2h2*/         \
        auto s64b21 = vreinterpretq_s64_s32(                               \
                vzip2q_s32(s32b20, s32b30)); /*e3f3g3h3 e4f4g4h4*/         \
        auto s64b30 = vreinterpretq_s64_s32(                               \
                vzip1q_s32(s32b21, s32b31)); /*e5f5g5h5 e6f6g6h6*/         \
        auto s64b31 = vreinterpretq_s64_s32(                               \
                vzip2q_s32(s32b21, s32b31)); /*e7f7g7h7 e8f8g8h8*/         \
        CONCAT(ret, 0).value =                                             \
                vreinterpretq_f16_s64(vzip1q_s64(s64b00, s64b20));         \
        CONCAT(ret, 1).value =                                             \
                vreinterpretq_f16_s64(vzip2q_s64(s64b00, s64b20));         \
        CONCAT(ret, 2).value =                                             \
                vreinterpretq_f16_s64(vzip1q_s64(s64b01, s64b21));         \
        CONCAT(ret, 3).value =                                             \
                vreinterpretq_f16_s64(vzip2q_s64(s64b01, s64b21));         \
        CONCAT(ret, 4).value =                                             \
                vreinterpretq_f16_s64(vzip1q_s64(s64b10, s64b30));         \
        CONCAT(ret, 5).value =                                             \
                vreinterpretq_f16_s64(vzip2q_s64(s64b10, s64b30));         \
        CONCAT(ret, 6).value =                                             \
                vreinterpretq_f16_s64(vzip1q_s64(s64b11, s64b31));         \
        CONCAT(ret, 7).value =                                             \
                vreinterpretq_f16_s64(vzip2q_s64(s64b11, s64b31));         \
    } while (0);
#else | |||||
//! ARMv7: transpose a 4x4 fp16 tile.  vzip_f16 on ARMv7 returns both zip
//! halves in one float16x4x2_t, so one call per row pair replaces the
//! separate vzip1/vzip2 used on AArch64.
#define TRANSPOSE_4x4(a, ret)                                             \
    do {                                                                  \
        auto b0_01 = vzip_f16(CONCAT(a, 0).value,                         \
                              CONCAT(a, 1).value); /*a1b1a2b2 a3b3a4b4*/  \
        auto b1_01 = vzip_f16(CONCAT(a, 2).value,                         \
                              CONCAT(a, 3).value); /*c1d1c2d2 c3d3c4d4*/  \
        auto s32b00 = vreinterpret_s32_f16(b0_01.val[0]);                 \
        auto s32b01 = vreinterpret_s32_f16(b0_01.val[1]);                 \
        auto s32b10 = vreinterpret_s32_f16(b1_01.val[0]);                 \
        auto s32b11 = vreinterpret_s32_f16(b1_01.val[1]);                 \
        auto s32b00b10 = vzip_s32(s32b00, s32b10); /*a1b1c1d1 a2b2c2d2*/  \
        auto s32b01b11 = vzip_s32(s32b01, s32b11); /*a3b3c3d3 a4b4c4d4*/  \
        CONCAT(ret, 0).value = vreinterpret_f16_s32(s32b00b10.val[0]);    \
        CONCAT(ret, 1).value = vreinterpret_f16_s32(s32b00b10.val[1]);    \
        CONCAT(ret, 2).value = vreinterpret_f16_s32(s32b01b11.val[0]);    \
        CONCAT(ret, 3).value = vreinterpret_f16_s32(s32b01b11.val[1]);    \
    } while (0);
//! ARMv7: transpose a 4x8 fp16 tile.  Four float16x8_t rows
//! (CONCAT(a, 0..3).value) become eight float16x4_t columns
//! (CONCAT(ret, 0..7).value).
//!
//! Fix: the zipped intermediates (s32b00b10 / s32b01b11) are int32x4_t, so
//! their halves must be taken with vget_low_s32/vget_high_s32.  The
//! previous code used vget_low_f16/vget_high_f16 on them, which does not
//! type-check (vget_*_f16 expects float16x8_t, and its float16x4_t result
//! is not a valid argument to vreinterpret_f16_s32).
#define TRANSPOSE_4x8(a, ret)                                              \
    do {                                                                   \
        auto b0_01 = vzipq_f16(                                            \
                CONCAT(a, 0).value,                                        \
                CONCAT(a, 1).value); /*a1b1a2b2a3b3a4b4 a5b5a6b6a7b7a8b8*/ \
        auto b1_01 = vzipq_f16(                                            \
                CONCAT(a, 2).value,                                        \
                CONCAT(a, 3).value); /*c1d1c2d2c3d3c4d4 c5d5c6d6c7d7c8d8*/ \
        auto s32b00 = vreinterpretq_s32_f16(b0_01.val[0]);                 \
        auto s32b01 = vreinterpretq_s32_f16(b0_01.val[1]);                 \
        auto s32b10 = vreinterpretq_s32_f16(b1_01.val[0]);                 \
        auto s32b11 = vreinterpretq_s32_f16(b1_01.val[1]);                 \
        auto s32b00b10 = vzipq_s32(                                        \
                s32b00, s32b10); /*a1b1c1d1a2b2c2d2 a3b3c3d3a4b4c4d4*/     \
        auto s32b01b11 = vzipq_s32(                                        \
                s32b01, s32b11); /*a5b5c5d5a6b6c6d6 a7b7c7d7a8b8c8d8*/     \
        CONCAT(ret, 0).value =                                             \
                vreinterpret_f16_s32(vget_low_s32(s32b00b10.val[0]));      \
        CONCAT(ret, 1).value =                                             \
                vreinterpret_f16_s32(vget_high_s32(s32b00b10.val[0]));     \
        CONCAT(ret, 2).value =                                             \
                vreinterpret_f16_s32(vget_low_s32(s32b00b10.val[1]));      \
        CONCAT(ret, 3).value =                                             \
                vreinterpret_f16_s32(vget_high_s32(s32b00b10.val[1]));     \
        CONCAT(ret, 4).value =                                             \
                vreinterpret_f16_s32(vget_low_s32(s32b01b11.val[0]));      \
        CONCAT(ret, 5).value =                                             \
                vreinterpret_f16_s32(vget_high_s32(s32b01b11.val[0]));     \
        CONCAT(ret, 6).value =                                             \
                vreinterpret_f16_s32(vget_low_s32(s32b01b11.val[1]));      \
        CONCAT(ret, 7).value =                                             \
                vreinterpret_f16_s32(vget_high_s32(s32b01b11.val[1]));     \
    } while (0);
//! ARMv7: transpose an 8x4 fp16 tile.  Eight float16x4_t rows become four
//! float16x8_t rows; vzip_f16/vzip_s32 return both halves at once, and the
//! two 4-lane halves of each output row are joined with vcombine_f16.
#define TRANSPOSE_8x4(a, ret)                                             \
    do {                                                                  \
        auto b0_01 = vzip_f16(CONCAT(a, 0).value,                         \
                              CONCAT(a, 1).value); /*a1b1a2b2 a3b3a4b4*/  \
        auto b1_01 = vzip_f16(CONCAT(a, 2).value,                         \
                              CONCAT(a, 3).value); /*c1d1c2d2 c3d3c4d4*/  \
        auto b2_01 = vzip_f16(CONCAT(a, 4).value,                         \
                              CONCAT(a, 5).value); /*e1f1e2f2 e3f3e4f4*/  \
        auto b3_01 = vzip_f16(CONCAT(a, 6).value,                         \
                              CONCAT(a, 7).value); /*g1h1g2h2 g3h3g4h4*/  \
        auto s32b00 = vreinterpret_s32_f16(b0_01.val[0]);                 \
        auto s32b01 = vreinterpret_s32_f16(b0_01.val[1]);                 \
        auto s32b10 = vreinterpret_s32_f16(b1_01.val[0]);                 \
        auto s32b11 = vreinterpret_s32_f16(b1_01.val[1]);                 \
        auto s32b20 = vreinterpret_s32_f16(b2_01.val[0]);                 \
        auto s32b21 = vreinterpret_s32_f16(b2_01.val[1]);                 \
        auto s32b30 = vreinterpret_s32_f16(b3_01.val[0]);                 \
        auto s32b31 = vreinterpret_s32_f16(b3_01.val[1]);                 \
        auto s32b00b10 = vzip_s32(s32b00, s32b10);                        \
        auto s32b01b11 = vzip_s32(s32b01, s32b11);                        \
        auto s32b20b30 = vzip_s32(s32b20, s32b30);                        \
        auto s32b21b31 = vzip_s32(s32b21, s32b31);                        \
        CONCAT(ret, 0).value =                                            \
                vcombine_f16(vreinterpret_f16_s32(s32b00b10.val[0]),      \
                             vreinterpret_f16_s32(s32b20b30.val[0]));     \
        CONCAT(ret, 1).value =                                            \
                vcombine_f16(vreinterpret_f16_s32(s32b00b10.val[1]),      \
                             vreinterpret_f16_s32(s32b20b30.val[1]));     \
        CONCAT(ret, 2).value =                                            \
                vcombine_f16(vreinterpret_f16_s32(s32b01b11.val[0]),      \
                             vreinterpret_f16_s32(s32b21b31.val[0]));     \
        CONCAT(ret, 3).value =                                            \
                vcombine_f16(vreinterpret_f16_s32(s32b01b11.val[1]),      \
                             vreinterpret_f16_s32(s32b21b31.val[1]));     \
    } while (0);
//! ARMv7: transpose an 8x8 fp16 tile held in eight float16x8_t rows.
//!
//! Cleanup: the original computed every vzipq_f16 twice with identical
//! arguments (b00/b01, b10/b11, b20/b21, b30/b31 were pairwise the same
//! value) and then read .val[0] from one copy and .val[1] from the other.
//! Each zip is now computed once and both halves of the single result are
//! used; the lanes consumed are exactly the same, so behavior is unchanged.
#define TRANSPOSE_8x8(a, ret)                                              \
    do {                                                                   \
        auto b0 = vzipq_f16(                                               \
                CONCAT(a, 0).value,                                        \
                CONCAT(a, 1).value); /*a1b1a2b2 a3b3a4b4 | a5b5.. a8b8*/   \
        auto b1 = vzipq_f16(                                               \
                CONCAT(a, 2).value,                                        \
                CONCAT(a, 3).value); /*c1d1c2d2 c3d3c4d4 | c5d5.. c8d8*/   \
        auto b2 = vzipq_f16(                                               \
                CONCAT(a, 4).value,                                        \
                CONCAT(a, 5).value); /*e1f1e2f2 e3f3e4f4 | e5f5.. e8f8*/   \
        auto b3 = vzipq_f16(                                               \
                CONCAT(a, 6).value,                                        \
                CONCAT(a, 7).value); /*g1h1g2h2 g3h3g4h4 | g5h5.. g8h8*/   \
        auto s32b00 = vreinterpretq_s32_f16(b0.val[0]);                    \
        auto s32b01 = vreinterpretq_s32_f16(b0.val[1]);                    \
        auto s32b10 = vreinterpretq_s32_f16(b1.val[0]);                    \
        auto s32b11 = vreinterpretq_s32_f16(b1.val[1]);                    \
        auto s32b20 = vreinterpretq_s32_f16(b2.val[0]);                    \
        auto s32b21 = vreinterpretq_s32_f16(b2.val[1]);                    \
        auto s32b30 = vreinterpretq_s32_f16(b3.val[0]);                    \
        auto s32b31 = vreinterpretq_s32_f16(b3.val[1]);                    \
        auto s32b00b10 = vzipq_s32(s32b00, s32b10);                        \
        auto s32b01b11 = vzipq_s32(s32b01, s32b11);                        \
        auto s32b20b30 = vzipq_s32(s32b20, s32b30);                        \
        auto s32b21b31 = vzipq_s32(s32b21, s32b31);                        \
        CONCAT(ret, 0).value = vreinterpretq_f16_s32(                      \
                vcombine_s32(vget_low_s32(s32b00b10.val[0]),               \
                             vget_low_s32(s32b20b30.val[0])));             \
        CONCAT(ret, 1).value = vreinterpretq_f16_s32(                      \
                vcombine_s32(vget_high_s32(s32b00b10.val[0]),              \
                             vget_high_s32(s32b20b30.val[0])));            \
        CONCAT(ret, 2).value = vreinterpretq_f16_s32(                      \
                vcombine_s32(vget_low_s32(s32b00b10.val[1]),               \
                             vget_low_s32(s32b20b30.val[1])));             \
        CONCAT(ret, 3).value = vreinterpretq_f16_s32(                      \
                vcombine_s32(vget_high_s32(s32b00b10.val[1]),              \
                             vget_high_s32(s32b20b30.val[1])));            \
        CONCAT(ret, 4).value = vreinterpretq_f16_s32(                      \
                vcombine_s32(vget_low_s32(s32b01b11.val[0]),               \
                             vget_low_s32(s32b21b31.val[0])));             \
        CONCAT(ret, 5).value = vreinterpretq_f16_s32(                      \
                vcombine_s32(vget_high_s32(s32b01b11.val[0]),              \
                             vget_high_s32(s32b21b31.val[0])));            \
        CONCAT(ret, 6).value = vreinterpretq_f16_s32(                      \
                vcombine_s32(vget_low_s32(s32b01b11.val[1]),               \
                             vget_low_s32(s32b21b31.val[1])));             \
        CONCAT(ret, 7).value = vreinterpretq_f16_s32(                      \
                vcombine_s32(vget_high_s32(s32b01b11.val[1]),              \
                             vget_high_s32(s32b21b31.val[1])));            \
    } while (0);
#endif | |||||
#endif | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,33 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/conv_bias/f16/strategy.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include "src/arm_common/conv_bias/postprocess_helper.h" | |||||
#include "src/fallback/conv_bias/winograd/winograd.h" | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
namespace megdnn {
namespace arm_common {
namespace winograd {

//! Registrations of the fp16 winograd strategies.  The numeric arguments
//! appear to be (output_block_size, kernel_size, pack sizes) -- e.g.
//! winograd_2x3_4x4_f16 handles F(2x2, 3x3) with 4x4 packing; confirm
//! against the MEGDNN_REG_WINOGRAD_STRATEGY macro definition.
MEGDNN_REG_WINOGRAD_STRATEGY(dt_float16, dt_float16, dt_float16, dt_float16, 2,
                             3, 4, 4, winograd_2x3_4x4_f16)

MEGDNN_REG_WINOGRAD_STRATEGY(dt_float16, dt_float16, dt_float16, dt_float16, 4,
                             5, 1, 1, winograd_4x5_1x1_f16)

MEGDNN_REG_WINOGRAD_STRATEGY(dt_float16, dt_float16, dt_float16, dt_float16, 6,
                             3, 1, 1, winograd_6x3_1x1_f16)

MEGDNN_REG_WINOGRAD_STRATEGY(dt_float16, dt_float16, dt_float16, dt_float16, 2,
                             3, 8, 8, winograd_2x3_8x8_f16)
}  // namespace winograd
}  // namespace arm_common
}  // namespace megdnn
#endif | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,373 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/conv_bias/f16/strategy_2x3.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "src/arm_common/conv_bias/f16/strategy.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/arm_common/utils.h" | |||||
#include "src/common/unroll_macro.h" | |||||
#include "src/common/utils.h" | |||||
#include "src/fallback/conv_bias/winograd/winograd.h" | |||||
#include "src/arm_common/conv_bias/f16/helper.h" | |||||
#include "src/arm_common/elemwise_helper/op_unary.h" | |||||
#include "src/naive/matrix_mul/matrix_mul_helper.h" | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
#include "midout.h" | |||||
MIDOUT_DECL(megdnn_arm_common_winograd_fp16_F23) | |||||
using namespace megdnn; | |||||
using namespace arm_common; | |||||
namespace { | |||||
//! In-register transpose of a 4x4 fp16 tile: reads four rows of `src`
//! (row stride `lda` halves) and stores the transposed rows to `dst`
//! (row stride `ldb` halves).
void transpose_4x4(const __fp16* src, __fp16* dst, int lda, int ldb) {
    float16x4_t row0 = vld1_f16(src + 0 * lda);  // a0 a1 a2 a3
    float16x4_t row1 = vld1_f16(src + 1 * lda);  // b0 b1 b2 b3
    float16x4_t row2 = vld1_f16(src + 2 * lda);  // c0 c1 c2 c3
    float16x4_t row3 = vld1_f16(src + 3 * lda);  // d0 d1 d2 d3
    // Stage 1: interleave rows 0/2 and 1/3.
    float16x4x2_t zip02 = vzip_f16(row0, row2);  // a0c0a1c1 | a2c2a3c3
    float16x4x2_t zip13 = vzip_f16(row1, row3);  // b0d0b1d1 | b2d2b3d3
    // Stage 2: interleaving the stage-1 halves yields the transposed rows.
    float16x4x2_t out01 = vzip_f16(zip02.val[0], zip13.val[0]);  // a0b0c0d0 | a1b1c1d1
    float16x4x2_t out23 = vzip_f16(zip02.val[1], zip13.val[1]);  // a2b2c2d2 | a3b3c3d3
    vst1_f16(dst + 0 * ldb, out01.val[0]);
    vst1_f16(dst + 1 * ldb, out01.val[1]);
    vst1_f16(dst + 2 * ldb, out23.val[0]);
    vst1_f16(dst + 3 * ldb, out23.val[1]);
}
//! Input transform for the F(2x2, 3x3) winograd strategy with 4-channel
//! packing: copies a 4x4 input patch for 4 input channels and applies the
//! BT * d * B transform.
struct InputTransform2X3 {
    //! Gather a 4x4 spatial patch for 4 consecutive input channels into
    //! `patch` (channel-major), then transpose it into `patchT` so the
    //! 4 channels become the innermost dimension.
    //!
    //! \tparam inner  true when the whole 4x4 window lies inside the image,
    //!                allowing straight vector loads; false triggers a
    //!                zero-fill + scalar partial copy for the padded border.
    //! \param ih_start/iw_start  top-left corner of the window (may be
    //!                negative when inner == false, due to padding).
    template <bool inner>
    static void prepare(const __fp16* input, __fp16* patch, __fp16* patchT,
                        int ih_start, int iw_start, size_t IH, size_t IW,
                        size_t ic, size_t IC) {
        constexpr size_t alpha = 2 + 3 - 1;
        constexpr size_t alpha4 = alpha * 4;
        // Zero the staging buffer unless all 4 channels will be fully
        // overwritten by the vector loads below.
        if (!(inner && ic + 4 < IC)) {
            memset(patch, 0, sizeof(__fp16) * 4 * alpha * alpha);
        }
        if (inner) {
            const __fp16* input_ptr =
                    input + ic * IH * IW + ih_start * IW + iw_start;
            for (size_t ico = 0; ico < 4; ++ico) {
                if (ic + ico < IC) {
                    // Four 4-lane loads cover the 4x4 window of one channel.
                    auto v0 = vld1_f16(input_ptr);
                    auto v1 = vld1_f16(input_ptr + IW);
                    auto v2 = vld1_f16(input_ptr + IW * 2);
                    auto v3 = vld1_f16(input_ptr + IW * 3);

                    vst1_f16(patch + ico * alpha4 + 0 * 4, v0);
                    vst1_f16(patch + ico * alpha4 + 1 * 4, v1);
                    vst1_f16(patch + ico * alpha4 + 2 * 4, v2);
                    vst1_f16(patch + ico * alpha4 + 3 * 4, v3);
                    input_ptr += IH * IW;
                }
            }
        } else {
            // Clamp the window to the valid image region; out-of-range
            // elements stay zero from the memset above.
            int ih0_act = std::max<int>(ih_start, 0),
                ih1_act = std::min<int>(ih_start + alpha, IH),
                iw0_act = std::max<int>(iw_start, 0),
                iw1_act = std::min<int>(iw_start + alpha, IW);
            // partial copy
            for (size_t ico = 0; ico < 4; ++ico) {
                if (ic + ico < IC) {
                    for (int ih = ih0_act; ih < ih1_act; ++ih) {
                        for (int iw = iw0_act; iw < iw1_act; ++iw) {
                            size_t iho = ih - ih_start, iwo = iw - iw_start;
                            patch[ico * alpha4 + iho * 4 + iwo] =
                                    input[(ic + ico) * IH * IW + ih * IW + iw];
                        }
                    }
                }
            }
        }

        // Transpose each group of 4 elements across the 4 channels
        // (src stride 16 = one channel's 4x4 patch) so patchT is
        // element-major with channels innermost.
        transpose_4x4(patch + 0 * 1, patchT + 0 * 4, 16, 4);
        transpose_4x4(patch + 4 * 1, patchT + 4 * 4, 16, 4);
        transpose_4x4(patch + 8 * 1, patchT + 8 * 4, 16, 4);
        transpose_4x4(patch + 12 * 1, patchT + 12 * 4, 16, 4);
    }

    //! Apply BT * d * B to the channel-packed patch and scatter the 16
    //! transformed coefficients into input_transform_buf with layout
    //! [alpha*alpha][nr_units_in_tile][IC].
    static void transform(const __fp16* patchT, __fp16* input_transform_buf,
                          size_t unit_idx, size_t nr_units_in_tile, size_t ic,
                          size_t IC) {
        constexpr size_t alpha = 2 + 3 - 1;
        // BT * d * B
#define cb(m, n)                                                          \
    Vector<__fp16, 4> d##m##n =                                           \
            Vector<__fp16, 4>::load(patchT + m * 4 * 4 + n * 4);

        UNROLL_CALL_NOWRAPPER_D2(4, 4, cb);
#undef cb

        //! 1   0 -1 0    d00 d01 d02 d03     1 0  0  0
        //! 0   1  1 0    d10 d11 d12 d13     0 1 -1 -1
        //! 0  -1  1 0    d20 d21 d22 d23    -1 1  1  0
        //! 0  -1  0 1    d30 d31 d32 d33     0 0  0  1
        // Row pass: t = BT * d.
#define cb(m)                   \
    auto t0##m = d0##m - d2##m; \
    auto t1##m = d1##m + d2##m; \
    auto t2##m = d2##m - d1##m; \
    auto t3##m = d3##m - d1##m;

        UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb

        // Column pass: d = t * B (results reuse the d registers).
#define cb(m)                     \
    d##m##0 = t##m##0 - t##m##2;  \
    d##m##1 = t##m##1 + t##m##2;  \
    d##m##2 = t##m##2 - t##m##1;  \
    d##m##3 = t##m##3 - t##m##1;

        UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb

#define cb(m, n)                                                           \
    d##m##n.save(input_transform_buf +                                     \
                 (m * alpha + n) * nr_units_in_tile * IC + unit_idx * IC + \
                 ic);

        UNROLL_CALL_NOWRAPPER_D2(4, 4, cb)
#undef cb
    }
};
//! Output transform for F(2x2, 3x3): applies AT * m * A to a 4x4 tile of
//! transformed coefficients, adds bias, applies the nonlinearity `Op`, and
//! scatters the resulting 2x2 output block (with boundary clipping).
template <BiasMode bmode, typename Op>
struct OutputTransform2X3 {
    static void transform(const dt_float16* output_transform_buf,
                          const dt_float16* bias, dt_float16* output,
                          dt_float16* transform_mid_buf, size_t oh_start,
                          size_t ow_start, size_t OH, size_t OW,
                          size_t oc_start, size_t oc_end, size_t oc_index,
                          size_t unit_idx, size_t nr_units_in_tile,
                          const DType& src_dtype, const DType& dst_dtype) {
        Op op(src_dtype, dst_dtype);
        const __fp16* output_transform_ptr =
                reinterpret_cast<const __fp16*>(output_transform_buf);
        const __fp16* bias_ptr = reinterpret_cast<const __fp16*>(bias);
        __fp16* output_ptr = reinterpret_cast<__fp16*>(output);
        __fp16* transform_mid_ptr =
                reinterpret_cast<__fp16*>(transform_mid_buf);

        //! AT * m * A
        constexpr size_t alpha = 2 + 3 - 1;
        size_t OC = oc_end - oc_start;
        size_t oc = oc_start + oc_index;

        // Load the 4x4 coefficient tile; each vector packs 4 output
        // channels (buffer layout [alpha*alpha][nr_units_in_tile][OC]).
#define cb(m, n)                                                            \
    auto v##m##n = Vector<__fp16, 4>::load(                                 \
            output_transform_ptr + (m * alpha + n) * nr_units_in_tile * OC + \
            unit_idx * OC + oc_index);
        UNROLL_CALL_NOWRAPPER_D2(4, 4, cb);
#undef cb
        //! 1  1  1 0  v00 v01 v02 v03    1  0
        //! 0  1 -1 1  v10 v11 v12 v13    1  1
        //!            v20 v21 v22 v23    1 -1
        //!            v30 v31 v32 v33    0  1
        // Row pass: t = AT * v.
#define cb(m)                                   \
    auto t0##m = v0##m + v1##m + v2##m;         \
    auto t1##m = v1##m - v2##m + v3##m;

        UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb
        // Column pass: the 2x2 result lands in v00, v01, v10, v11.
        v00 = t00 + t01 + t02;
        v10 = t10 + t11 + t12;
        v01 = t01 - t02 + t03;
        v11 = t11 - t12 + t13;

        Vector<__fp16, 4> vbias;
        if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
            vbias = Vector<__fp16, 4>::load(bias_ptr + oc);

            v00 += vbias;
            v10 += vbias;
            v01 += vbias;
            v11 += vbias;
        }
        float16x8_t result01, result23;
        result01 = vcombine_f16(v00.value, v01.value);
        result23 = vcombine_f16(v10.value, v11.value);
        // For per-element BIAS the nonlinearity is applied after the bias
        // add in the scalar loop below; otherwise apply it vectorized here.
        if (bmode != BiasMode::BIAS) {
            result01 = op(result01);
            result23 = op(result23);
        }
        vst1q_f16(transform_mid_ptr, result01);
        vst1q_f16(transform_mid_ptr + 8, result23);

        // Scatter the 2x2 block per channel, clipping at OH/OW and oc_end.
        for (size_t oco = 0; oco < 4 && oc + oco < oc_end; ++oco) {
            for (size_t oho = 0; oho < 2 && oh_start + oho < OH; ++oho) {
                for (size_t owo = 0; owo < 2 && ow_start + owo < OW; ++owo) {
                    size_t oh = oh_start + oho;
                    size_t ow = ow_start + owo;
                    __fp16 res = transform_mid_ptr[oho * 2 * 4 + owo * 4 + oco];
                    if (bmode == BiasMode::BIAS) {
                        res += bias_ptr[(oc + oco) * OH * OW + oh * OW + ow];
                        res = op(res);
                    }
                    output_ptr[(oc + oco) * OH * OW + oh * OW + ow] = res;
                }
            }
        }
    }
};
} // namespace | |||||
namespace megdnn { | |||||
namespace arm_common { | |||||
namespace winograd { | |||||
MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_2x3_4x4_f16) | |||||
//! Filter transform for F(2x2, 3x3): computes G * g * GT for every
//! (oc, ic) 3x3 kernel in [oc_start, oc_end) and stores the 4x4 result in
//! filter_transform_buf with layout [alpha*alpha][IC][OC].
//! transform_mid_buf holds one 4x4 intermediate tile.
void winograd_2x3_4x4_f16::filter(const dt_float16* filter,
                                  dt_float16* filter_transform_buf,
                                  dt_float16* transform_mid_buf, size_t OC,
                                  size_t IC, size_t oc_start, size_t oc_end) {
    constexpr int alpha = 2 + 3 - 1;
    //! G * g * GT
    __fp16* filter_transbuf_ptr =
            reinterpret_cast<__fp16*>(filter_transform_buf);
    __fp16* filter_transmid_ptr = reinterpret_cast<__fp16*>(transform_mid_buf);

    for (size_t oc = oc_start; oc < oc_end; oc++) {
        rep(ic, IC) {
            const __fp16* filter_ptr = reinterpret_cast<const __fp16*>(filter) +
                                       (oc * IC + ic) * 3 * 3;
            /**
             * origin: (4x3) * (3 x 3) * (3 x 4)
             * pack to G and g to times of 4
             * now: (4x4) * (4 x 4) * (4 x 4)
             */
            //! 1      0    0 0   v00 v01 v02 0  1 0.5  0.5 0
            //! 0.5  0.5  0.5 0   v10 v11 v12 0  0 0.5 -0.5 0
            //! 0.5 -0.5  0.5 0   v20 v21 v22 0  0 0.5  0.5 1
            //! 0      0    1 0   0   0   0   0  0 0    0   0
            // Load the three kernel rows with overlapping vector loads;
            // the stray fourth lane of each row is zeroed below.
            float16x4_t v0 = vld1_f16(filter_ptr);      // 0 1 2 3
            float16x4_t v1 = vld1_f16(filter_ptr + 3);  // 3 4 5 6
            float16x4_t v2 = vld1_f16(filter_ptr + 5);  // 5678
            float16x4_t v3 = vdup_n_f16(0);
            v2 = vext_f16(v2, v3, 1);  // shift out elem 5 -> 6 7 8 0
            v0 = vset_lane_f16(0, v0, 3);
            v1 = vset_lane_f16(0, v1, 3);
            // Row pass: vsum = G * g.
#define cb(i) float16x4_t vsum##i;
            UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb
            vsum0 = v0;
            float16x4_t v0addv2 = vadd_f16(v0, v2);
            float16x4_t v02addv1 = vadd_f16(v0addv2, v1);
            float16x4_t v02subv1 = vsub_f16(v0addv2, v1);
            vsum1 = vmul_n_f16(v02addv1, 0.5);
            vsum2 = vmul_n_f16(v02subv1, 0.5);
            vsum3 = v2;
            // Column pass: apply GT per row, scalar since each row has
            // only 3 significant lanes; writes one 4-wide row of the
            // 4x4 tile and advances mid_buf1.
#define cb(i)                                                                 \
    do {                                                                      \
        mid_buf1[0] = vget_lane_f16(vsum##i, 0);                              \
        __fp16 a0a2 = vget_lane_f16(vsum##i, 0) + vget_lane_f16(vsum##i, 2);  \
        __fp16 a0a2adda1 = a0a2 + vget_lane_f16(vsum##i, 1);                  \
        __fp16 a0a2suba1 = a0a2 - vget_lane_f16(vsum##i, 1);                  \
        mid_buf1[1] = a0a2adda1 * 0.5;                                        \
        mid_buf1[2] = a0a2suba1 * 0.5;                                        \
        mid_buf1[3] = vget_lane_f16(vsum##i, 2);                              \
        mid_buf1 += 4;                                                        \
    } while (0);

            __fp16* mid_buf1 = filter_transmid_ptr;
            UNROLL_CALL_NOWRAPPER(4, cb);
            mid_buf1 = filter_transmid_ptr;
#undef cb
            // Scatter the 4x4 tile into the interleaved output layout.
            rep(i, alpha) rep(j, alpha) {
                filter_transbuf_ptr[(i * alpha + j) * OC * IC + ic * OC + oc] =
                        filter_transmid_ptr[i * alpha + j];
            }
        }
    }
}
void winograd_2x3_4x4_f16::input(const dt_float16* input, | |||||
dt_float16* input_transform_buf, | |||||
dt_float16* transform_mid_buf, size_t IH, | |||||
size_t IW, size_t IC, size_t PH, size_t PW, | |||||
size_t unit_start_idx, | |||||
size_t nr_units_in_tile) { | |||||
megdnn_assert(IC % 4 == 0); | |||||
constexpr int alpha = 3 + 2 - 1; | |||||
// OW = IW + 2 * PW - KERNEL_SIZE + 1 | |||||
auto units_w = div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); | |||||
dt_float16* patch = transform_mid_buf; | |||||
dt_float16* patchT = transform_mid_buf + 4 * alpha * alpha; | |||||
for (size_t ic = 0; ic < IC; ic += 4) { | |||||
rep(unit_idx, nr_units_in_tile) { | |||||
size_t index = unit_start_idx + unit_idx; | |||||
size_t nh = index / units_w; | |||||
size_t nw = index % units_w; | |||||
int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; | |||||
int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; | |||||
if (ih_start >= 0 && ih_start + alpha <= static_cast<int>(IH) && | |||||
iw_start >= 0 && iw_start + alpha <= static_cast<int>(IW)) { | |||||
InputTransform2X3::prepare<true>( | |||||
reinterpret_cast<const __fp16*>(input), | |||||
reinterpret_cast<__fp16*>(patch), | |||||
reinterpret_cast<__fp16*>(patchT), ih_start, iw_start, | |||||
IH, IW, ic, IC); | |||||
InputTransform2X3::transform( | |||||
reinterpret_cast<const __fp16*>(patchT), | |||||
reinterpret_cast<__fp16*>(input_transform_buf), | |||||
unit_idx, nr_units_in_tile, ic, IC); | |||||
} else { | |||||
InputTransform2X3::prepare<false>( | |||||
reinterpret_cast<const __fp16*>(input), | |||||
reinterpret_cast<__fp16*>(patch), | |||||
reinterpret_cast<__fp16*>(patchT), ih_start, iw_start, | |||||
IH, IW, ic, IC); | |||||
InputTransform2X3::transform( | |||||
reinterpret_cast<const __fp16*>(patchT), | |||||
reinterpret_cast<__fp16*>(input_transform_buf), | |||||
unit_idx, nr_units_in_tile, ic, IC); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
//! Output transform driver: for every group of 4 output channels and every
//! tile unit, dispatch to the OutputTransform2X3 instantiation matching the
//! runtime (bias mode, nonlinearity) pair via DISPATCH_CONV_WINOGRAD_BIAS.
//! src_dtype / dst_dtype are presumably members of the strategy base class
//! brought in by MEGDNN_REG_WINOGRAD_STRATEGY_IMPL -- confirm there.
void winograd_2x3_4x4_f16::output(const dt_float16* output_transform_buf,
                                  const dt_float16* bias, dt_float16* output,
                                  dt_float16* transform_mid_buf, BiasMode bmode,
                                  NonlineMode nonline_mode, size_t OH, size_t OW,
                                  size_t oc_start, size_t oc_end,
                                  size_t unit_start_idx, size_t nr_units_in_tile) {
#define cb(_bmode, _nonline_op, ...) \
    OutputTransform2X3<_bmode MEGDNN_COMMA _nonline_op>::transform(__VA_ARGS__);

    auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE);

    for (size_t oc = oc_start; oc < oc_end; oc += 4) {
        size_t oc_index = oc - oc_start;
        rep(unit_idx, nr_units_in_tile) {
            size_t index = unit_start_idx + unit_idx;
            // Map the linear unit index back to its 2x2 output block origin.
            auto nh = index / units_w;
            auto nw = index % units_w;
            size_t oh_start = nh * OUTPUT_BLOCK_SIZE;
            size_t ow_start = nw * OUTPUT_BLOCK_SIZE;
            DISPATCH_CONV_WINOGRAD_BIAS(
                    megdnn_arm_common_winograd_fp16_F23, cb, __fp16, __fp16, bmode,
                    nonline_mode, output_transform_buf, bias, output, transform_mid_buf,
                    oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx,
                    nr_units_in_tile, src_dtype, dst_dtype);
        }
    }
#undef cb
}
} // namespace winograd | |||||
} // namespace arm_common | |||||
} // namespace megdnn | |||||
#endif | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,407 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/conv_bias/f16/strategy_2x3_8x8.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
#include "src/fallback/conv_bias/winograd/winograd.h" | |||||
#include "src/naive/matrix_mul/matrix_mul_helper.h" | |||||
#include "src/arm_common/elemwise_helper/op_unary.h" | |||||
#include "src/arm_common/conv_bias/f16/strategy.h" | |||||
#include "src/arm_common/conv_bias/f16/helper.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/arm_common/utils.h" | |||||
#include "src/common/winograd/winograd_generator.h" | |||||
#include "src/common/unroll_macro.h" | |||||
#include "src/common/utils.h" | |||||
#include "midout.h" | |||||
#include "src/common/winograd/winograd_helper.h" | |||||
MIDOUT_DECL(megdnn_arm_common_winograd_f16_F23_8x8) | |||||
using namespace megdnn; | |||||
using namespace arm_common; | |||||
namespace { | |||||
//! In-register transpose of an 8x4 fp16 tile: reads eight 4-element rows
//! from `src` (row stride `lda` halves) and writes the eight transposed
//! 4-element groups to `dst` (stride `ldb` halves); with ldb == 4 the
//! output is the contiguous 4x8 transpose.
void transpose_8x4(const __fp16* src, __fp16* dst, int lda, int ldb) {
    float16x4_t r0 = vld1_f16(src + 0 * lda);
    float16x4_t r1 = vld1_f16(src + 1 * lda);
    float16x4_t r2 = vld1_f16(src + 2 * lda);
    float16x4_t r3 = vld1_f16(src + 3 * lda);
    float16x4_t r4 = vld1_f16(src + 4 * lda);
    float16x4_t r5 = vld1_f16(src + 5 * lda);
    float16x4_t r6 = vld1_f16(src + 6 * lda);
    float16x4_t r7 = vld1_f16(src + 7 * lda);
    // Stage 1: zip rows two apart within each half of the tile.
    float16x4x2_t p02 = vzip_f16(r0, r2);
    float16x4x2_t p13 = vzip_f16(r1, r3);
    float16x4x2_t p46 = vzip_f16(r4, r6);
    float16x4x2_t p57 = vzip_f16(r5, r7);
    // Stage 2: zip the stage-1 halves; each .val is one transposed group.
    float16x4x2_t q0 = vzip_f16(p02.val[0], p13.val[0]);
    float16x4x2_t q1 = vzip_f16(p02.val[1], p13.val[1]);
    float16x4x2_t q2 = vzip_f16(p46.val[0], p57.val[0]);
    float16x4x2_t q3 = vzip_f16(p46.val[1], p57.val[1]);
    // Interleave groups from the top half (rows 0-3) and bottom half
    // (rows 4-7) so each source column lands in two consecutive slots.
    vst1_f16(dst + 0 * ldb, q0.val[0]);
    vst1_f16(dst + 1 * ldb, q2.val[0]);
    vst1_f16(dst + 2 * ldb, q0.val[1]);
    vst1_f16(dst + 3 * ldb, q2.val[1]);
    vst1_f16(dst + 4 * ldb, q1.val[0]);
    vst1_f16(dst + 5 * ldb, q3.val[0]);
    vst1_f16(dst + 6 * ldb, q1.val[1]);
    vst1_f16(dst + 7 * ldb, q3.val[1]);
}
//! Input transform for winograd F(2x2, 3x3) with input channels packed by 8.
struct InputTransform2X3_8x8 {
    //! Stage one alpha x alpha (4x4) input window for 8 consecutive input
    //! channels into \p patch (one 4x4 tile per channel), then repack it
    //! into \p patchT so the channel index becomes innermost, via four
    //! 8-row x 4-column transposes.
    //!
    //! \tparam inner  true when the window lies fully inside the image, so
    //!                rows can be fetched with unchecked vector loads.
    template <bool inner>
    static void prepare(const __fp16* input, __fp16* patch, __fp16* patchT,
                        int ih_start, int iw_start, size_t IH, size_t IW,
                        size_t ic, size_t IC) {
        constexpr size_t alpha = 2 + 3 - 1;
        //! zero the staging buffer whenever part of it may stay unwritten
        //! (channel tail or padded border). NOTE(review): `ic + 8 < IC`
        //! also clears the exact-fit tail (ic + 8 == IC) — conservative
        //! but correct.
        if (!(inner && ic + 8 < IC)) {
            memset(patch, 0, sizeof(__fp16) * 8 * alpha * alpha);
        }
        if (inner) {
            const __fp16* input_ptr =
                    input + ic * IH * IW + ih_start * IW + iw_start;
            for (size_t ico = 0; ico < 8; ++ico) {
                if (ic + ico < IC) {
                    //! load the four 4-wide rows of this channel's window
                    auto v0 = vld1_f16(input_ptr);
                    auto v1 = vld1_f16(input_ptr + IW);
                    auto v2 = vld1_f16(input_ptr + IW * 2);
                    auto v3 = vld1_f16(input_ptr + IW * 3);
                    vst1_f16(patch + ico * alpha * 4 + 0 * 4, v0);
                    vst1_f16(patch + ico * alpha * 4 + 1 * 4, v1);
                    vst1_f16(patch + ico * alpha * 4 + 2 * 4, v2);
                    vst1_f16(patch + ico * alpha * 4 + 3 * 4, v3);
                    input_ptr += IH * IW;
                }
            }
        } else {
            //! clamp the window to the valid image region; out-of-range
            //! pixels keep the zeros written by the memset above
            int ih0_act = std::max<int>(ih_start, 0),
                ih1_act = std::min<int>(ih_start + alpha, IH),
                iw0_act = std::max<int>(iw_start, 0),
                iw1_act = std::min<int>(iw_start + alpha, IW);
            // partial copy
            for (size_t ico = 0; ico < 8; ++ico) {
                if (ic + ico < IC) {
                    for (int ih = ih0_act; ih < ih1_act; ++ih) {
                        for (int iw = iw0_act; iw < iw1_act; ++iw) {
                            size_t iho = ih - ih_start, iwo = iw - iw_start;
                            patch[ico * alpha * alpha + iho * alpha + iwo] =
                                    input[(ic + ico) * IH * IW + ih * IW + iw];
                        }
                    }
                }
            }
        }
        //! transpose each group of 4 columns across the 8 channels so that
        //! patchT has the 8-channel index innermost (lda = 16 elements per
        //! channel, 32 outputs per column group)
        transpose_8x4(patch + 4 * 0, patchT + 32 * 0, 16, 4);
        transpose_8x4(patch + 4 * 1, patchT + 32 * 1, 16, 4);
        transpose_8x4(patch + 4 * 2, patchT + 32 * 2, 16, 4);
        transpose_8x4(patch + 4 * 3, patchT + 32 * 3, 16, 4);
    }
    //! Compute B^T * d * B on the repacked 4x4x8 patch and scatter the 16
    //! transformed 8-channel vectors into \p input_transform_buf, laid out
    //! as (alpha * alpha, ICB, nr_units_in_tile, 8).
    static void transform(const __fp16* patchT, __fp16* input_transform_buf,
                          size_t unit_idx, size_t nr_units_in_tile, size_t ic,
                          size_t IC) {
        constexpr size_t alpha = 2 + 3 - 1;
        // BT * d * B
        //! load the 16 eight-channel input vectors d00..d33
#define cb(m, n)                                                \
    Vector<__fp16, 8> d##m##n =                                 \
            Vector<__fp16, 8>::load(patchT + 8 * (m * 4 + n));
        UNROLL_CALL_NOWRAPPER_D2(4, 4, cb);
#undef cb
        //! 1 0 -1 0 d00 d01 d02 d03 1 0 0 0
        //! 0 1 1 0 d10 d11 d12 d13 0 1 -1 -1
        //! 0 -1 1 0 d20 d21 d22 d23 -1 1 1 0
        //! 0 -1 0 1 d30 d31 d32 d33 0 0 0 1
        //! left-multiply by B^T (works on rows)
#define cb(m)                        \
    auto t0##m = d0##m - d2##m;      \
    auto t1##m = d1##m + d2##m;      \
    auto t2##m = d2##m - d1##m;      \
    auto t3##m = d3##m - d1##m;
        UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb
        //! right-multiply by B (works on columns), reusing d## as output
#define cb(m)                            \
    d##m##0 = t##m##0 - t##m##2;         \
    d##m##1 = t##m##1 + t##m##2;         \
    d##m##2 = t##m##2 - t##m##1;         \
    d##m##3 = t##m##3 - t##m##1;
        UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb
        size_t ICB = IC / 8;
        size_t icb = ic / 8;
        //! scatter into the (alpha^2, ICB, nr_units_in_tile, 8) layout
#define cb(m, n)                                                       \
    d##m##n.save(input_transform_buf +                                 \
                 (m * alpha + n) * nr_units_in_tile * ICB * 8 +        \
                 icb * nr_units_in_tile * 8 + unit_idx * 8);
        UNROLL_CALL_NOWRAPPER_D2(4, 4, cb)
#undef cb
    }
};
//! Output transform for winograd F(2x2, 3x3), 8 output channels at a time:
//! computes A^T * m * A on one 4x4 tile of transformed coefficients, adds
//! bias, applies the nonlinearity \p Op, and writes the valid pixels of the
//! resulting 2x2 output block.
template <BiasMode bmode, typename Op>
struct OutputTransform2X3_8x8 {
    static void transform(const dt_float16* output_transform_buf,
                          const dt_float16* bias, dt_float16* output,
                          dt_float16* transform_mid_buf, size_t oh_start,
                          size_t ow_start, size_t OH, size_t OW,
                          size_t oc_start, size_t oc_end, size_t oc_index,
                          size_t unit_idx, size_t nr_units_in_tile,
                          const DType& src_dtype, const DType& dst_dtype) {
        Op op(src_dtype, dst_dtype);
        const __fp16* output_transform_ptr =
                reinterpret_cast<const __fp16*>(output_transform_buf);
        const __fp16* bias_ptr = reinterpret_cast<const __fp16*>(bias);
        __fp16* output_ptr = reinterpret_cast<__fp16*>(output);
        __fp16* transform_mid_ptr =
                reinterpret_cast<__fp16*>(transform_mid_buf);
        //! AT * m * A
        constexpr size_t alpha = 2 + 3 - 1;
        size_t oc = oc_start + oc_index;
        size_t OCB = (oc_end - oc_start) / 8;
        size_t ocb = oc_index / 8;
        //! gather the 16 eight-channel coefficient vectors for this unit
        //! from the (alpha^2, OCB, nr_units_in_tile, 8) layout
#define cb(m, n)                                                    \
    auto v##m##n = Vector<__fp16, 8>::load(                         \
            output_transform_ptr +                                  \
            (m * alpha + n) * OCB * nr_units_in_tile * 8 +          \
            ocb * nr_units_in_tile * 8 + unit_idx * 8);
        UNROLL_CALL_NOWRAPPER_D2(4, 4, cb);
#undef cb
        //! 1 1 1 0 v00 v01 v02 v03 1 0
        //! 0 1 -1 1 v10 v11 v12 v13 1 1
        //! v20 v21 v22 v23 1 -1
        //! v30 v31 v32 v33 0 1
        //! left-multiply by A^T (rows)
#define cb(m)                                    \
    auto t0##m = v0##m + v1##m + v2##m;          \
    auto t1##m = v1##m - v2##m + v3##m;
        UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb
        //! right-multiply by A (columns) — yields the 2x2 output block
        v00 = t00 + t01 + t02;
        v10 = t10 + t11 + t12;
        v01 = t01 - t02 + t03;
        v11 = t11 - t12 + t13;
        if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
            Vector<__fp16, 8> vbias;
            vbias = Vector<__fp16, 8>::load(bias_ptr + oc);
            v00 += vbias;
            v10 += vbias;
            v01 += vbias;
            v11 += vbias;
        }
        //! for full (per-pixel) BIAS the nonlinearity is deferred to the
        //! scalar loop below, after the bias has been added
        if (bmode != BiasMode::BIAS) {
            v00.value = op(v00.value);
            v01.value = op(v01.value);
            v10.value = op(v10.value);
            v11.value = op(v11.value);
        }
        v00.save(transform_mid_ptr + (0 * 2 + 0) * 8);
        v10.save(transform_mid_ptr + (1 * 2 + 0) * 8);
        v01.save(transform_mid_ptr + (0 * 2 + 1) * 8);
        v11.save(transform_mid_ptr + (1 * 2 + 1) * 8);
        //! scalar writeback, clipped to the valid channel/height/width range
        for (size_t oco = 0; oco < 8 && oc + oco < oc_end; ++oco) {
            for (size_t oho = 0; oho < 2 && oh_start + oho < OH; ++oho) {
                for (size_t owo = 0; owo < 2 && ow_start + owo < OW; ++owo) {
                    size_t oh = oh_start + oho;
                    size_t ow = ow_start + owo;
                    __fp16 res = transform_mid_ptr[oho * 2 * 8 + owo * 8 + oco];
                    if (bmode == BiasMode::BIAS) {
                        res += bias_ptr[(oc + oco) * OH * OW + oh * OW + ow];
                        res = op(res);
                    }
                    output_ptr[(oc + oco) * OH * OW + oh * OW + ow] = res;
                }
            }
        }
    }
};
} // namespace | |||||
namespace megdnn { | |||||
namespace arm_common { | |||||
namespace winograd { | |||||
MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_2x3_8x8_f16) | |||||
//! Filter transform for winograd F(2x2, 3x3): computes G * g * G^T for each
//! (oc, ic) 3x3 kernel and scatters the 4x4 result into the interleaved
//! (alpha^2, OCB, ICB, 8, 8) layout of \p filter_transform_buf.
void winograd_2x3_8x8_f16::filter(const dt_float16* filter,
                                  dt_float16* filter_transform_buf,
                                  dt_float16* transform_mid_buf, size_t OC,
                                  size_t IC, size_t oc_start, size_t oc_end) {
    constexpr int alpha = 2 + 3 - 1;
    //! G * g * GT
    __fp16* filter_transbuf_ptr =
            reinterpret_cast<__fp16*>(filter_transform_buf);
    __fp16* filter_transmid_ptr = reinterpret_cast<__fp16*>(transform_mid_buf);
    size_t OCB = OC / 8;
    size_t ICB = IC / 8;
    for (size_t oc = oc_start; oc < oc_end; oc++) {
        rep(ic, IC) {
            const __fp16* filter_ptr = reinterpret_cast<const __fp16*>(filter) +
                                       (oc * IC + ic) * 3 * 3;
            /**
             * origin: (4x3) * (3 x 3) * (3 x 4)
             * pack to G and g to times of 4
             * now: (4x4) * (4 x 4) * (4 x 4)
             */
            //! 1 0 0 0 v00 v01 v02 0 1 0.5 0.5 0
            //! 0.5 0.5 0.5 0 v10 v11 v12 0 0 0.5 -0.5 0
            //! 0.5 -0.5 0.5 0 v20 v21 v22 0 0 0.5 0.5 1
            //! 0 0 1 0 0 0 0 0 0 0 0 0
            //! overlapping loads of the row-major 3x3 kernel; padding lanes
            //! are cleared below so each vector holds one zero-padded row
            float16x4_t v0 = vld1_f16(filter_ptr);      // elems 0 1 2 3
            float16x4_t v1 = vld1_f16(filter_ptr + 3);  // elems 3 4 5 6
            float16x4_t v2 = vld1_f16(filter_ptr + 5);  // elems 5 6 7 8
            float16x4_t v3 = vdup_n_f16(0);
            v2 = vext_f16(v2, v3, 1);      //! shift -> elems 6 7 8, 0
            v0 = vset_lane_f16(0, v0, 3);  //! row 0: g00 g01 g02 0
            v1 = vset_lane_f16(0, v1, 3);  //! row 1: g10 g11 g12 0
#define cb(i) float16x4_t vsum##i;
            UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb
            //! vsum = G * g (rows of the padded G applied to kernel rows)
            vsum0 = v0;
            float16x4_t v0addv2 = vadd_f16(v0, v2);
            float16x4_t v02addv1 = vadd_f16(v0addv2, v1);
            float16x4_t v02subv1 = vsub_f16(v0addv2, v1);
            vsum1 = vmul_n_f16(v02addv1, 0.5);
            vsum2 = vmul_n_f16(v02subv1, 0.5);
            vsum3 = v2;
            //! apply G^T along the other axis, scalar per lane, writing one
            //! 4-element row of the 4x4 result per invocation
#define cb(i)                                                                 \
    do {                                                                      \
        mid_buf1[0] = vget_lane_f16(vsum##i, 0);                              \
        __fp16 a0a2 = vget_lane_f16(vsum##i, 0) + vget_lane_f16(vsum##i, 2);  \
        __fp16 a0a2adda1 = a0a2 + vget_lane_f16(vsum##i, 1);                  \
        __fp16 a0a2suba1 = a0a2 - vget_lane_f16(vsum##i, 1);                  \
        mid_buf1[1] = a0a2adda1 * 0.5;                                        \
        mid_buf1[2] = a0a2suba1 * 0.5;                                        \
        mid_buf1[3] = vget_lane_f16(vsum##i, 2);                              \
        mid_buf1 += 4;                                                        \
    } while (0);
            __fp16* mid_buf1 = filter_transmid_ptr;
            UNROLL_CALL_NOWRAPPER(4, cb);
            mid_buf1 = filter_transmid_ptr;
#undef cb
            //! scatter into the blocked (alpha^2, OCB, ICB, 8, 8) layout
            size_t ocb = (oc) / 8;
            size_t oc8 = (oc) % 8;
            size_t icb = (ic) / 8;
            size_t ic8 = (ic) % 8;
            rep(i, alpha) rep(j, alpha) {
                filter_transbuf_ptr[(i * alpha + j) * OCB * ICB * 8 * 8 +
                                    ocb * ICB * 8 * 8 + icb * 8 * 8 + ic8 * 8 +
                                    oc8] = filter_transmid_ptr[i * alpha + j];
            }
        }
    }
}
void winograd_2x3_8x8_f16::input(const dt_float16* input, | |||||
dt_float16* input_transform_buf, | |||||
dt_float16* transform_mid_buf, size_t IH, | |||||
size_t IW, size_t IC, size_t PH, size_t PW, | |||||
size_t unit_start_idx, size_t nr_units_in_tile) { | |||||
megdnn_assert(IC % 8 == 0); | |||||
constexpr int alpha = 3 + 2 - 1; | |||||
// OW = IW + 2 * PW - KERNEL_SIZE + 1 | |||||
auto units_w = div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); | |||||
dt_float16* patch = transform_mid_buf; | |||||
dt_float16* patchT = transform_mid_buf + 8 * alpha * alpha; | |||||
for (size_t ic = 0; ic < IC; ic += 8) { | |||||
rep(unit_idx, nr_units_in_tile) { | |||||
size_t index = unit_start_idx + unit_idx; | |||||
size_t nh = index / units_w; | |||||
size_t nw = index % units_w; | |||||
int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; | |||||
int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; | |||||
if (ih_start >= 0 && ih_start + alpha <= static_cast<int>(IH) && | |||||
iw_start >= 0 && iw_start + alpha <= static_cast<int>(IW)) { | |||||
InputTransform2X3_8x8::prepare<true>( | |||||
reinterpret_cast<const __fp16*>(input), | |||||
reinterpret_cast<__fp16*>(patch), | |||||
reinterpret_cast<__fp16*>(patchT), ih_start, iw_start, | |||||
IH, IW, ic, IC); | |||||
InputTransform2X3_8x8::transform( | |||||
reinterpret_cast<const __fp16*>(patchT), | |||||
reinterpret_cast<__fp16*>(input_transform_buf), | |||||
unit_idx, nr_units_in_tile, ic, IC); | |||||
} else { | |||||
InputTransform2X3_8x8::prepare<false>( | |||||
reinterpret_cast<const __fp16*>(input), | |||||
reinterpret_cast<__fp16*>(patch), | |||||
reinterpret_cast<__fp16*>(patchT), ih_start, iw_start, | |||||
IH, IW, ic, IC); | |||||
InputTransform2X3_8x8::transform( | |||||
reinterpret_cast<const __fp16*>(patchT), | |||||
reinterpret_cast<__fp16*>(input_transform_buf), | |||||
unit_idx, nr_units_in_tile, ic, IC); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
//! Output transform for winograd F(2x2, 3x3): iterates over 8-channel
//! output groups and units, instantiating OutputTransform2X3_8x8 for the
//! requested bias mode / nonlinearity via DISPATCH_CONV_WINOGRAD_BIAS.
//! NOTE(review): src_dtype/dst_dtype are not declared here — presumably
//! members provided by MEGDNN_REG_WINOGRAD_STRATEGY_IMPL; confirm.
void winograd_2x3_8x8_f16::output(const dt_float16* output_transform_buf,
                                  const dt_float16* bias, dt_float16* output,
                                  dt_float16* transform_mid_buf, BiasMode bmode,
                                  NonlineMode nonline_mode, size_t OH,
                                  size_t OW, size_t oc_start, size_t oc_end,
                                  size_t unit_start_idx,
                                  size_t nr_units_in_tile) {
#define cb(_bmode, _nonline_op, ...)                                     \
    OutputTransform2X3_8x8<_bmode MEGDNN_COMMA _nonline_op>::transform(  \
            __VA_ARGS__);
    auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE);
    //! 8 output channels are written per transform call
    for (size_t oc = oc_start; oc < oc_end; oc += 8) {
        size_t oc_index = oc - oc_start;
        rep(unit_idx, nr_units_in_tile) {
            size_t index = unit_start_idx + unit_idx;
            auto nh = index / units_w;
            auto nw = index % units_w;
            size_t oh_start = nh * OUTPUT_BLOCK_SIZE;
            size_t ow_start = nw * OUTPUT_BLOCK_SIZE;
            DISPATCH_CONV_WINOGRAD_BIAS(
                    megdnn_arm_common_winograd_fp16_F23_8x8, cb, __fp16, __fp16,
                    bmode, nonline_mode, output_transform_buf, bias, output,
                    transform_mid_buf, oh_start, ow_start, OH, OW, oc_start,
                    oc_end, oc_index, unit_idx, nr_units_in_tile, src_dtype, dst_dtype);
        }
    }
#undef cb
}
} // namespace winograd | |||||
} // namespace arm_common | |||||
} // namespace megdnn | |||||
#endif | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,488 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/conv_bias/f16/strategy_4x5.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "src/arm_common/conv_bias/f16/helper.h" | |||||
#include "src/arm_common/conv_bias/f16/strategy.h" | |||||
#include "src/arm_common/elemwise_helper/op_unary.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/arm_common/utils.h" | |||||
#include "src/common/unroll_macro.h" | |||||
#include "src/common/utils.h" | |||||
#include "src/fallback/conv_bias/winograd/winograd.h" | |||||
#include "src/naive/matrix_mul/matrix_mul_helper.h" | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
#include "midout.h" | |||||
MIDOUT_DECL(megdnn_arm_common_winograd_fp16_F45) | |||||
using namespace megdnn; | |||||
using namespace arm_common; | |||||
namespace { | |||||
//! Filter transform for winograd F(4x4, 5x5): computes G * g * G^T for each
//! (oc, ic) 5x5 kernel (alpha = 8). The 5th kernel column is carried in the
//! scalar "r" variants (d##r*, gr*, Ggr*) because the vector width is 4.
//! NOTE(review): constants such as -0.222168 / 0.710938 appear to be the
//! exact G entries (-2/9, 32/45, ...) pre-rounded to fp16 precision —
//! confirm against the reference G matrix below. The lone `;` continuation
//! lines inside FILTER_TRANSFORM are empty statements left in place.
struct FilterTransform4X5 {
#define FILTER_TRANSFORM(d, wd)                                    \
    do {                                                           \
        wd##0 = d##0;                                              \
        wd##r0 = d##r0;                                            \
        wd##1 = (d##0 + d##1 + d##2 + d##3 + d##4) * -0.222168;    \
        wd##r1 = (d##r0 + d##r1 + d##r2 + d##r3 + d##r4) * -0.222168; \
        wd##2 = (d##0 - d##1 + d##2 - d##3 + d##4) * -0.222168;    \
        wd##r2 = (d##r0 - d##r1 + d##r2 - d##r3 + d##r4) * -0.222168; \
        auto tmpd0 = d##0 * 0.710938;                              \
        auto tmpd1 = d##1 * 0.355469;                              \
        auto tmpd2 = d##2 * 0.177734;                              \
        auto tmpd3 = d##3 * 0.088867;                              \
        auto tmpd4 = d##4 * 0.044434;                              \
        auto tmpdr0 = d##r0 * 0.710938;                            \
        auto tmpdr1 = d##r1 * 0.355469;                            \
        auto tmpdr2 = d##r2 * 0.177734;                            \
        auto tmpdr3 = d##r3 * 0.088867;                            \
        auto tmpdr4 = d##r4 * 0.044434;                            \
        wd##3 = tmpd0 + tmpd1 + tmpd2 + tmpd3 + tmpd4;             \
        wd##r3 = tmpdr0 + tmpdr1 + tmpdr2 + tmpdr3 + tmpdr4;       \
        wd##4 = tmpd0 - tmpd1 + tmpd2 - tmpd3 + tmpd4;             \
        wd##r4 = tmpdr0 - tmpdr1 + tmpdr2 - tmpdr3 + tmpdr4;       \
        tmpd0 = d##0 * 0.011108;                                   \
        tmpd1 = d##1 * 0.022217;                                   \
        tmpd2 = d##2 * 0.044434;                                   \
        tmpd3 = d##3 * 0.088867;                                   \
        tmpd4 = d##4 * 0.177734;                                   \
        tmpdr0 = d##r0 * 0.011108;                                 \
        tmpdr1 = d##r1 * 0.022217;                                 \
        tmpdr2 = d##r2 * 0.044434;                                 \
        tmpdr3 = d##r3 * 0.088867;                                 \
        tmpdr4 = d##r4 * 0.177734;                                 \
        wd##5 = tmpd0 + tmpd1 + tmpd2 + tmpd3 + tmpd4;             \
        ;                                                          \
        wd##r5 = tmpdr0 + tmpdr1 + tmpdr2 + tmpdr3 + tmpdr4;       \
        ;                                                          \
        wd##6 = tmpd0 - tmpd1 + tmpd2 - tmpd3 + tmpd4;             \
        ;                                                          \
        wd##r6 = tmpdr0 - tmpdr1 + tmpdr2 - tmpdr3 + tmpdr4;       \
        ;                                                          \
        wd##7 = d##4;                                              \
        wd##r7 = d##r4;                                            \
    } while (0);
//! Same G application but without the scalar "r" column — used on the
//! transposed intermediate where all 5 inputs fit in vectors of 8.
#define FILTER_TRANSFORM_FINAL(d, wd)                                 \
    do {                                                              \
        wd##0 = d##0;                                                 \
        wd##1 = (d##0 + d##1 + d##2 + d##3 + d##4) * -0.222168;       \
        wd##2 = (d##0 - d##1 + d##2 - d##3 + d##4) * -0.222168;       \
        auto tmp0 = d##0 * 0.710938 + d##2 * 0.177734 + d##4 * 0.044434; \
        auto tmp1 = d##1 * 0.355469 + d##3 * 0.088867;                \
        wd##3 = tmp0 + tmp1;                                          \
        wd##4 = tmp0 - tmp1;                                          \
        tmp0 = d##0 * 0.011108 + d##2 * 0.044434 + d##4 * 0.177734;   \
        tmp1 = d##1 * 0.022217 + d##3 * 0.088867;                     \
        wd##5 = tmp0 + tmp1;                                          \
        wd##6 = tmp0 - tmp1;                                          \
        wd##7 = d##4;                                                 \
    } while (0);
    //! Transform kernels for output channels [oc_start, oc_end) and all
    //! input channels; result layout is (alpha^2, IC, OC) — see scatter
    //! loops at the bottom.
    static void transform(const __fp16* filter, __fp16* filter_transform_buf,
                          __fp16* transform_mid_buf, size_t OC, size_t IC,
                          size_t oc_start, size_t oc_end) {
        // Gg * GT
        // G
        //[[ 1. 0. 0. 0. 0. ]
        // [-0.2222222 -0.2222222 -0.2222222 -0.2222222 -0.2222222]
        // [-0.2222222 0.2222222 -0.2222222 0.2222222 -0.2222222]
        // [ 0.7111111 0.3555556 0.1777778 0.0888889 0.0444444]
        // [ 0.7111111 -0.3555556 0.1777778 -0.0888889 0.0444444]
        // [ 0.0111111 0.0222222 0.0444444 0.0888889 0.1777778]
        // [ 0.0111111 -0.0222222 0.0444444 -0.0888889 0.1777778]
        // [ 0. 0. 0. 0. 1. ]]
        constexpr size_t alpha = 4 + 5 - 1;
        for (size_t oc = oc_start; oc < oc_end; oc++) {
            rep(ic, IC) {
                const __fp16* fptr = filter + (oc * IC + ic) * 5 * 5;
                //! g0..g4: first four columns of each 5-wide kernel row
#define cb(i) Vector<__fp16, 4> g##i = Vector<__fp16, 4>::load(fptr + 5 * i);
                UNROLL_CALL_NOWRAPPER(5, cb);
#undef cb
                //! gr0..gr4: the 5th (remainder) column, handled as scalars
#define cb(i) __fp16 gr##i = *(fptr + 5 * i + 4);
                UNROLL_CALL_NOWRAPPER(5, cb);
#undef cb
#define cb(i) Vector<__fp16, 4> Gg##i;
                UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
#define cb(i) __fp16 Ggr##i;
                UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
#define cb(i) Vector<__fp16, 8> Ggt##i;
                UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb
#define cb(i) Vector<__fp16, 8> result##i;
                UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
                //! Gg = G * g (8 rows; remainder column in Ggr*)
                FILTER_TRANSFORM(g, Gg)
#if MEGDNN_AARCH64
                //! pack the remainder column into an 8-vector, transpose,
                //! and apply G again along the other axis
                float16x8_t vgr = {Ggr0, Ggr1, Ggr2, Ggr3,
                                   Ggr4, Ggr5, Ggr6, Ggr7};
                Vector<__fp16, 8> Ggt4(vgr);
                TRANSPOSE_8x4(Gg, Ggt);
                FILTER_TRANSFORM_FINAL(Ggt, result);
#define cb(i) result##i.save(transform_mid_buf + i * alpha);
                UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
                //! note the j/i swap: the aarch64 path stores transposed
                rep(i, alpha) rep(j, alpha) {
                    filter_transform_buf[(i * alpha + j) * OC * IC + ic * OC +
                                         oc] = transform_mid_buf[j * alpha + i];
                }
#else
#define GET_VECTOR_FP16D_ELEM(s, i, idx) vget_lane_f16(CONCAT(s, i).value, idx)
//! armv7 fallback: apply G^T row by row with scalar lane extraction,
//! writing one 8-wide output row per invocation
#define cb(i)                                                             \
    do {                                                                  \
        mid_buf1[0] = GET_VECTOR_FP16D_ELEM(Gg, i, 0);                    \
        auto tmp024 = GET_VECTOR_FP16D_ELEM(Gg, i, 0) +                   \
                      GET_VECTOR_FP16D_ELEM(Gg, i, 2) + Ggr##i;           \
        auto tmp13 = GET_VECTOR_FP16D_ELEM(Gg, i, 1) +                    \
                     GET_VECTOR_FP16D_ELEM(Gg, i, 3);                     \
        mid_buf1[1] = (tmp024 + tmp13) * -0.2222222;                      \
        mid_buf1[2] = (tmp024 - tmp13) * -0.2222222;                      \
        auto tmp0 = GET_VECTOR_FP16D_ELEM(Gg, i, 0) * 0.7111111;          \
        auto tmp1 = GET_VECTOR_FP16D_ELEM(Gg, i, 1) * 0.3555556;          \
        auto tmp2 = GET_VECTOR_FP16D_ELEM(Gg, i, 2) * 0.1777778;          \
        auto tmp3 = GET_VECTOR_FP16D_ELEM(Gg, i, 3) * 0.0888889;          \
        auto tmp4 = Ggr##i * 0.0444444;                                   \
        tmp024 = tmp0 + tmp2 + tmp4;                                      \
        tmp13 = tmp1 + tmp3;                                              \
        mid_buf1[3] = tmp024 + tmp13;                                     \
        mid_buf1[4] = tmp024 - tmp13;                                     \
        tmp0 = GET_VECTOR_FP16D_ELEM(Gg, i, 0) * 0.0111111;               \
        tmp1 = GET_VECTOR_FP16D_ELEM(Gg, i, 1) * 0.0222222;               \
        tmp2 = GET_VECTOR_FP16D_ELEM(Gg, i, 2) * 0.0444444;               \
        tmp3 = GET_VECTOR_FP16D_ELEM(Gg, i, 3) * 0.0888889;               \
        tmp4 = Ggr##i * 0.1777778;                                        \
        tmp024 = tmp0 + tmp2 + tmp4;                                      \
        tmp13 = tmp1 + tmp3;                                              \
        mid_buf1[5] = tmp024 + tmp13;                                     \
        mid_buf1[6] = tmp024 - tmp13;                                     \
        mid_buf1[7] = Ggr##i;                                             \
        mid_buf1 += 8;                                                    \
    } while (0);
                __fp16* mid_buf1 = transform_mid_buf;
                UNROLL_CALL_NOWRAPPER(8, cb);
                mid_buf1 = transform_mid_buf;
#undef cb
#undef GET_VECTOR_FP16D_ELEM
                rep(i, alpha) rep(j, alpha) {
                    filter_transform_buf[(i * alpha + j) * OC * IC + ic * OC +
                                         oc] = transform_mid_buf[i * alpha + j];
                }
#endif
            }
        }
    }
};
#undef FILTER_TRANSFORM | |||||
#undef FILTER_TRANSFORM_FINAL | |||||
//! Input transform for winograd F(4x4, 5x5) (alpha = 8): computes
//! B^T * d * B on one 8x8 input window, one channel at a time.
struct InputTransform4X5 {
//! Apply B^T (one axis of the transform) to 8 row/column vectors.
#define INPUT_TRANSFORM(d, wd)                                  \
    do {                                                        \
        wd##0 = (d##0 - d##6) + (d##4 - d##2) * 5.25f;          \
        auto tmp0 = d##2 - d##4 * 4.25f + d##6;                 \
        auto tmp1 = d##1 - d##3 * 4.25f + d##5;                 \
        wd##1 = tmp0 + tmp1;                                    \
        wd##2 = tmp0 - tmp1;                                    \
        tmp0 = d##2 * 4.0f - d##4 * 5.0f + d##6;                \
        tmp1 = d##1 * 2.0f - d##3 * 2.5f + d##5 * 0.5f;         \
        wd##3 = tmp0 + tmp1;                                    \
        wd##4 = tmp0 - tmp1;                                    \
        tmp0 = d##2 * 0.25f - d##4 * 1.25f + d##6;              \
        tmp1 = d##1 * 0.5f - d##3 * 2.5f + d##5 * 2.0f;         \
        wd##5 = tmp0 + tmp1;                                    \
        wd##6 = tmp0 - tmp1;                                    \
        wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f;          \
    } while (0)
#define GET_VECTOR_FP16Q_ELEM(s, i, idx) vgetq_lane_f16(CONCAT(s, i).value, idx)
    //! Transform the alpha x alpha window of channel \p ic whose top-left
    //! corner is (ih_start, iw_start) (may be negative for padding) and
    //! scatter into input_transform_buf laid out (alpha^2, units, IC).
    //! \tparam inner  window fully inside the image -> direct vector loads.
    template <bool inner>
    static void transform(const __fp16* input, __fp16* input_transform_buf,
                          __fp16* transform_mid_buf, int ih_start, int iw_start,
                          size_t ic, size_t IH, size_t IW, size_t IC,
                          size_t unit_idx, size_t nr_units_in_tile) {
        // BTd * B
        //([[ 1. , 0. , -5.25, 0. , 5.25, 0. , -1. , 0. ],
        // [ 0. , 1. , 1. , -4.25, -4.25, 1. , 1. , 0. ],
        // [ 0. , -1. , 1. , 4.25, -4.25, -1. , 1. , 0. ],
        // [ 0. , 2. , 4. , -2.5 , -5. , 0.5 , 1. , 0. ],
        // [ 0. , -2. , 4. , 2.5 , -5. , -0.5 , 1. , 0. ],
        // [ 0. , 0.5 , 0.25, -2.5 , -1.25, 2. , 1. , 0. ],
        // [ 0. , -0.5 , 0.25, 2.5 , -1.25, -2. , 1. , 0. ],
        // [ 0. , -1. , 0. , 5.25, 0. , -5.25, 0. , 1. ]]))
        constexpr size_t alpha = 4 + 5 - 1;
        //! border units stage through a zeroed scratch tile so that pixels
        //! outside the image contribute zeros
        if (!inner) {
            memset(transform_mid_buf, 0, sizeof(__fp16) * alpha * alpha);
        }
#define cb(i) Vector<__fp16, 8> d##i;
        UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
        if (inner) {
            const __fp16* input_ptr =
                    input + ic * IH * IW + ih_start * IW + iw_start;
#define cb(i) d##i = Vector<__fp16, 8>::load(input_ptr + IW * i);
            UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
        } else {
            //! clamp to the valid image region, copy element-wise
            int ih0_act = std::max<int>(ih_start, 0),
                ih1_act = std::min<int>(ih_start + alpha, IH),
                iw0_act = std::max<int>(iw_start, 0),
                iw1_act = std::min<int>(iw_start + alpha, IW);
            for (int ih = ih0_act; ih < ih1_act; ++ih) {
                for (int iw = iw0_act; iw < iw1_act; ++iw) {
                    size_t iho = ih - ih_start, iwo = iw - iw_start;
                    transform_mid_buf[iho * alpha + iwo] =
                            input[ic * IH * IW + ih * IW + iw];
                }
            }
#define cb(i) d##i = Vector<__fp16, 8>::load(transform_mid_buf + alpha * i);
            UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
        }
#define cb(i) Vector<__fp16, 8> wd##i, ret##i;
        UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
        //! B^T on rows, transpose, then B^T again == B^T * d * B
        INPUT_TRANSFORM(d, wd);
        TRANSPOSE_8x8(wd, d);
        INPUT_TRANSFORM(d, ret);
#define cb(i) ret##i.save(transform_mid_buf + i * alpha);
        UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
        //! scatter (note the j/i swap: result is stored transposed)
        rep(i, alpha) rep(j, alpha) {
            input_transform_buf[(i * alpha + j) * nr_units_in_tile * IC +
                                unit_idx * IC + ic] =
                    transform_mid_buf[j * alpha + i];
        }
    }
};
#undef INPUT_TRANSFORM | |||||
#define OUTPUT_TRANSFORM(m, s) \ | |||||
do { \ | |||||
s##0 = m##0 + m##1 + m##2 + m##3 + m##4 + m##5 + m##6; \ | |||||
s##1 = m##1 - m##2 + m##3 * 0.5 - m##4 * 0.5 + m##5 * 2.0 - \ | |||||
m##6 * 2.0; \ | |||||
s##2 = m##1 + m##2 + m##3 * 0.25 + m##4 * 0.25 + m##5 * 4.0 + \ | |||||
m##6 * 4.0; \ | |||||
s##3 = m##1 - m##2 + m##3 * 0.125 - m##4 * 0.125 + m##5 * 8.0 - \ | |||||
m##6 * 8.0 + m##7; \ | |||||
} while (0) | |||||
//! Output transform for winograd F(4x4, 5x5): computes A^T * m * A on one
//! 8x8 coefficient tile for a single output channel, adds bias, applies
//! \p Op, and writes the 4x4 output block (fast vector path when the block
//! is fully inside the output, scalar clipped path otherwise).
template <BiasMode bmode, typename Op>
struct OutputTransform4X5 {
    static void transform(const dt_float16* output_transform_buf,
                          const dt_float16* bias, dt_float16* output,
                          dt_float16* transform_mid_buf, size_t oh_start,
                          size_t ow_start, size_t OH, size_t OW,
                          size_t oc_start, size_t oc_end, size_t oc_index,
                          size_t unit_idx, size_t nr_units_in_tile,
                          const DType& src_dtype, const DType& dst_dtype) {
        Op op(src_dtype, dst_dtype);
        //! AT * m * A
        // AT f45
        // 1.0 1.0 1.0 1.000 1.000 1.0 1.0 0.0
        // 0.0 1.0 -1.0 0.500 -0.500 2.0 -2.0 0.0
        // 0.0 1.0 1.0 0.250 0.250 4.0 4.0 0.0
        // 0.0 1.0 -1.0 0.125 -0.125 8.0 -8.0 1.0
        constexpr size_t alpha = 5 + 4 - 1;
        const __fp16* fp16_output_transform_buf =
                reinterpret_cast<const __fp16*>(output_transform_buf);
        const __fp16* fp16_bias = reinterpret_cast<const __fp16*>(bias);
        __fp16* fp16_output = reinterpret_cast<__fp16*>(output);
        __fp16* fp16_transform_mid_buf =
                reinterpret_cast<__fp16*>(transform_mid_buf);
        __fp16* mid_buf1 = fp16_transform_mid_buf;
        size_t OC = oc_end - oc_start;
        size_t oc = oc_start + oc_index;
        //! gather this unit's 8x8 coefficients from the
        //! (alpha^2, units, OC) layout into a contiguous tile
#define cb(m, n)                                                          \
    fp16_transform_mid_buf[m * alpha + n] =                               \
            fp16_output_transform_buf[(m * alpha + n) * nr_units_in_tile * \
                                              OC +                        \
                                      unit_idx * OC + oc_index];
        UNROLL_CALL_NOWRAPPER_D2(8, 8, cb);
#undef cb
#define cb(i) \
    auto m##i = Vector<__fp16, 8>::load(fp16_transform_mid_buf + alpha * i);
        UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
#define cb(i) Vector<__fp16, 8> s##i;
        UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb
#define cb(i) Vector<__fp16, 4> st##i;
        UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
#define cb(i) Vector<__fp16, 4> result##i;
        UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb
        //! A^T on rows, transpose, A^T again, then transpose back:
        //! result0..3 hold the 4x4 output block in row order
        OUTPUT_TRANSFORM(m, s);
        TRANSPOSE_4x8(s, st);
        OUTPUT_TRANSFORM(st, result);
        TRANSPOSE_4x4(result, result);
        if (oh_start + 4 <= OH && ow_start + 4 <= OW) {
            //! fast path: whole 4x4 block inside the output.
            //! NOTE(review): `index` is int — could overflow for very
            //! large OC*OH*OW; confirm upstream size limits.
            int index = (oc * OH + oh_start) * OW + ow_start;
            if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
                float16x4_t bias0 = vdup_n_f16(fp16_bias[oc]);
                result0.value = vadd_f16(result0.value, bias0);
                result1.value = vadd_f16(result1.value, bias0);
                result2.value = vadd_f16(result2.value, bias0);
                result3.value = vadd_f16(result3.value, bias0);
            } else if (bmode == BiasMode::BIAS) {
                float16x4_t bmbias0 = vld1_f16(fp16_bias + index);
                float16x4_t bmbias1 = vld1_f16(fp16_bias + index + OW);
                float16x4_t bmbias2 = vld1_f16(fp16_bias + index + OW * 2);
                float16x4_t bmbias3 = vld1_f16(fp16_bias + index + OW * 3);
                result0.value = vadd_f16(result0.value, bmbias0);
                result1.value = vadd_f16(result1.value, bmbias1);
                result2.value = vadd_f16(result2.value, bmbias2);
                result3.value = vadd_f16(result3.value, bmbias3);
            }
            //! nonlinearity applied on two rows at a time (8-wide op)
            float16x8_t item01 = op(vcombine_f16(result0.value, result1.value));
            float16x8_t item23 = op(vcombine_f16(result2.value, result3.value));
            vst1_f16(fp16_output + index, vget_low_f16(item01));
            vst1_f16(fp16_output + index + OW, vget_high_f16(item01));
            vst1_f16(fp16_output + index + OW * 2, vget_low_f16(item23));
            vst1_f16(fp16_output + index + OW * 3, vget_high_f16(item23));
        } else {
            //! slow path: spill the block to scratch, then clipped scalar
            //! writeback with per-element bias + nonlinearity
#define cb(i) result##i.save(mid_buf1 + i * 4);
            mid_buf1 = fp16_transform_mid_buf;
            UNROLL_CALL_NOWRAPPER(4, cb);
            mid_buf1 = fp16_transform_mid_buf;
#undef cb
            for (size_t oho = 0; oho < 4 && oh_start + oho < OH; ++oho) {
                for (size_t owo = 0; owo < 4 && ow_start + owo < OW; ++owo) {
                    size_t oh = oh_start + oho;
                    size_t ow = ow_start + owo;
                    __fp16 res = mid_buf1[oho * 4 + owo];
                    if (bmode == BiasMode::BIAS) {
                        res += fp16_bias[oc * OH * OW + oh * OW + ow];
                    } else if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
                        res += fp16_bias[oc];
                    }
                    res = op(res);
                    fp16_output[oc * OH * OW + oh * OW + ow] = res;
                }
            }
        }
    }
};
#undef OUTPUT_TRANSFORM | |||||
#undef GET_VECTOR_FP16Q_ELEM | |||||
} // namespace | |||||
namespace megdnn { | |||||
namespace arm_common { | |||||
namespace winograd { | |||||
MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_4x5_1x1_f16) | |||||
void winograd_4x5_1x1_f16::filter(const dt_float16* filter, | |||||
dt_float16* filter_transform_buf, | |||||
dt_float16* transform_mid_buf, size_t OC, | |||||
size_t IC, size_t oc_start, size_t oc_end) { | |||||
FilterTransform4X5::transform( | |||||
reinterpret_cast<const __fp16*>(filter), | |||||
reinterpret_cast<__fp16*>(filter_transform_buf), | |||||
reinterpret_cast<__fp16*>(transform_mid_buf), OC, IC, oc_start, | |||||
oc_end); | |||||
} | |||||
void winograd_4x5_1x1_f16::input(const dt_float16* input, | |||||
dt_float16* input_transform_buf, | |||||
dt_float16* transform_mid_buf, size_t IH, | |||||
size_t IW, size_t IC, size_t PH, size_t PW, | |||||
size_t unit_start_idx, | |||||
size_t nr_units_in_tile) { | |||||
constexpr int alpha = 4 + 5 - 1; | |||||
// OW = IW + 2 * PW - KERNEL_SIZE + 1 | |||||
auto units_w = div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); | |||||
rep(ic, IC) { | |||||
rep(unit_idx, nr_units_in_tile) { | |||||
size_t index = unit_start_idx + unit_idx; | |||||
size_t nh = index / units_w; | |||||
size_t nw = index % units_w; | |||||
int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; | |||||
int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; | |||||
if (ih_start >= 0 && ih_start + alpha <= static_cast<int>(IH) && | |||||
iw_start >= 0 && iw_start + alpha <= static_cast<int>(IW)) { | |||||
InputTransform4X5::transform<true>( | |||||
reinterpret_cast<const __fp16*>(input), | |||||
reinterpret_cast<__fp16*>(input_transform_buf), | |||||
reinterpret_cast<__fp16*>(transform_mid_buf), ih_start, | |||||
iw_start, ic, IH, IW, IC, unit_idx, nr_units_in_tile); | |||||
} else { | |||||
InputTransform4X5::transform<false>( | |||||
reinterpret_cast<const __fp16*>(input), | |||||
reinterpret_cast<__fp16*>(input_transform_buf), | |||||
reinterpret_cast<__fp16*>(transform_mid_buf), ih_start, | |||||
iw_start, ic, IH, IW, IC, unit_idx, nr_units_in_tile); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
//! Output transform for winograd F(4x4, 5x5): iterates over output channels
//! and units, instantiating OutputTransform4X5 for the requested bias mode /
//! nonlinearity via DISPATCH_CONV_WINOGRAD_BIAS. NOTE(review):
//! src_dtype/dst_dtype are presumably members declared by
//! MEGDNN_REG_WINOGRAD_STRATEGY_IMPL — confirm.
void winograd_4x5_1x1_f16::output(const dt_float16* output_transform_buf,
                                  const dt_float16* bias, dt_float16* output,
                                  dt_float16* transform_mid_buf, BiasMode bmode,
                                  NonlineMode nonline_mode, size_t OH, size_t OW,
                                  size_t oc_start, size_t oc_end,
                                  size_t unit_start_idx, size_t nr_units_in_tile) {
#define cb(_bmode, _nonline_op, ...) \
    OutputTransform4X5<_bmode MEGDNN_COMMA _nonline_op>::transform(__VA_ARGS__);
    auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE);
    //! one output channel per transform call (unlike the 8x8 variants)
    for (size_t oc = oc_start; oc < oc_end; oc++) {
        size_t oc_index = oc - oc_start;
        rep(unit_idx, nr_units_in_tile) {
            size_t index = unit_start_idx + unit_idx;
            auto nh = index / units_w;
            auto nw = index % units_w;
            size_t oh_start = nh * OUTPUT_BLOCK_SIZE;
            size_t ow_start = nw * OUTPUT_BLOCK_SIZE;
            DISPATCH_CONV_WINOGRAD_BIAS(
                    megdnn_arm_common_winograd_fp16_F45, cb, __fp16, __fp16, bmode,
                    nonline_mode, output_transform_buf, bias, output, transform_mid_buf,
                    oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx,
                    nr_units_in_tile, src_dtype, dst_dtype);
        }
    }
#undef cb
}
} // namespace winograd | |||||
} // namespace arm_common | |||||
} // namespace megdnn | |||||
#endif | |||||
// vim: syntax=cpp.doxygen |
// (extraction artifact — original diff hunk header: @@ -0,0 +1,608 @@)
/** | |||||
* \file dnn/src/arm_common/conv_bias/f16/strategy_6x3.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "src/arm_common/conv_bias/f16/helper.h" | |||||
#include "src/arm_common/conv_bias/f16/strategy.h" | |||||
#include "src/arm_common/elemwise_helper/op_unary.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/arm_common/utils.h" | |||||
#include "src/common/unroll_macro.h" | |||||
#include "src/common/utils.h" | |||||
#include "src/fallback/conv_bias/winograd/winograd.h" | |||||
#include "src/naive/matrix_mul/matrix_mul_helper.h" | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
#include "midout.h" | |||||
MIDOUT_DECL(megdnn_arm_common_winograd_fp16_F63) | |||||
using namespace megdnn; | |||||
using namespace arm_common; | |||||
namespace { | |||||
//! Winograd F(6x6, 3x3) filter transform for fp16: computes G * g * G^T for
//! every 3x3 filter g with oc in [oc_start, oc_end) and ic in [0, IC), and
//! scatters the 8x8 result into filter_transform_buf with layout
//! [alpha * alpha][IC][OC] (alpha = 8).
struct FilterTransform6X3 {
    //! G matrix of F(6, 3), exact values (the macros below use constants
    //! rounded for half precision):
    // 1.0000000 0.0000000 0.0000000
    //-0.2222222 -0.2222222 -0.2222222
    //-0.2222222 0.2222222 -0.2222222
    // 0.0111111 0.0222222 0.0444444
    // 0.0111111 -0.0222222 0.0444444
    // 0.7111111 0.3555556 0.1777778
    // 0.7111111 -0.3555556 0.1777778
    // 0.0000000 0.0000000 1.0000000
//! Left-multiplies the 3-row bundle d0..d2 by G, producing wd0..wd7.
//! NOTE(review): the constants (-0.222168, 0.011108, ...) appear to be the
//! fp16 roundings of the exact G entries listed above — confirm intentional.
#define FILTER_TRANSFORM(d, wd)                   \
    do {                                          \
        wd##0 = d##0;                             \
        wd##1 = (d##0 + d##1 + d##2) * -0.222168; \
        wd##2 = (d##0 - d##1 + d##2) * -0.222168; \
        auto tmpd0 = d##0 * 0.011108;             \
        auto tmpd1 = d##1 * 0.022217;             \
        auto tmpd2 = d##2 * 0.044434;             \
        wd##3 = tmpd0 + tmpd1 + tmpd2;            \
        wd##4 = tmpd0 - tmpd1 + tmpd2;            \
        tmpd0 = d##0 * 0.710938;                  \
        tmpd1 = d##1 * 0.355469;                  \
        tmpd2 = d##2 * 0.177734;                  \
        wd##5 = tmpd0 + tmpd1 + tmpd2;            \
        wd##6 = tmpd0 - tmpd1 + tmpd2;            \
        wd##7 = d##2;                             \
    } while (0);
//! Same product with fewer temporaries; applied after the transpose to
//! complete Gg * G^T (only rows 0..2 of the transposed input contribute).
#define FILTER_TRANSFORM_FINAL(d, wd)                  \
    do {                                               \
        wd##0 = d##0;                                  \
        wd##1 = (d##0 + d##1 + d##2) * -0.222168;      \
        wd##2 = (d##0 - d##1 + d##2) * -0.222168;      \
        auto tmp0 = d##0 * 0.011108 + d##2 * 0.044434; \
        auto tmp1 = d##1 * 0.022217;                   \
        wd##3 = tmp0 + tmp1;                           \
        wd##4 = tmp0 - tmp1;                           \
        tmp0 = d##0 * 0.710938 + d##2 * 0.177734;      \
        tmp1 = d##1 * 0.355469;                        \
        wd##5 = tmp0 + tmp1;                           \
        wd##6 = tmp0 - tmp1;                           \
        wd##7 = d##2;                                  \
    } while (0);
    //! \param filter              OC*IC 3x3 fp16 filters, row-major
    //! \param filter_transform_buf output buffer, [alpha*alpha][IC][OC]
    //! \param transform_mid_buf   scratch of at least alpha*alpha elements
    //! \param oc_start,oc_end     output-channel range handled by this call
    static void transform(const __fp16* filter, __fp16* filter_transform_buf,
                          __fp16* transform_mid_buf, size_t OC, size_t IC,
                          size_t oc_start, size_t oc_end) {
        // Gg * GT
        // G
        // 1.0000000 0.0000000 0.0000000
        //-0.2222222 -0.2222222 -0.2222222
        //-0.2222222 0.2222222 -0.2222222
        // 0.0111111 0.0222222 0.0444444
        // 0.0111111 -0.0222222 0.0444444
        // 0.7111111 0.3555556 0.1777778
        // 0.7111111 -0.3555556 0.1777778
        // 0.0000000 0.0000000 1.0000000
        constexpr size_t alpha = 6 + 3 - 1;
        for (size_t oc = oc_start; oc < oc_end; oc++) {
            rep(ic, IC) {
                const __fp16* fptr = filter + (oc * IC + ic) * 3 * 3;
                //! load the 9 taps into three 4-lane rows, 4th lane zeroed
                float16x4_t v0, v1, v2, v3;
                v0 = vld1_f16(fptr);      // 0 1 2 3
                v1 = vld1_f16(fptr + 3);  // 3 4 5 6
                v2 = vld1_f16(fptr + 5);  // 5 6 7 8
                v3 = vdup_n_f16(0);
                v2 = vext_f16(v2, v3, 1);  // -> 6 7 8 0
                v0 = vset_lane_f16(0, v0, 3);
                v1 = vset_lane_f16(0, v1, 3);
#define cb(i) Vector<__fp16, 4> g##i(v##i);
                UNROLL_CALL_NOWRAPPER(3, cb);
#undef cb
#define cb(i) Vector<__fp16, 4> Gg##i;
                UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
#define cb(i) Vector<__fp16, 8> Ggt##i;
                UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb
#define cb(i) Vector<__fp16, 8> result##i;
                UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
                //! first pass: Gg = G * g (8x3)
                FILTER_TRANSFORM(g, Gg)
#if MEGDNN_AARCH64
                //! second pass on the transpose completes (G * g) * G^T
                TRANSPOSE_8x4(Gg, Ggt);
                FILTER_TRANSFORM_FINAL(Ggt, result);
#define cb(i) result##i.save(transform_mid_buf + i * alpha);
                UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
                //! swapped i/j on the read side undoes the transpose
                rep(i, alpha) rep(j, alpha) {
                    filter_transform_buf[(i * alpha + j) * OC * IC + ic * OC +
                                         oc] = transform_mid_buf[j * alpha + i];
                }
#else
                //! armv7 fallback: apply G^T per row with scalar lanes
                /* 1.0000000 -0.2222222 -0.2222222 0.0111111 0.0111111
                 0.7111111 0.7111111 0.0000000 0.0000000 -0.2222222
                 0.2222222 0.0222222 -0.0222222 0.3555556 -0.3555556
                 0.0000000 0.0000000 -0.2222222 -0.2222222 0.0444444
                 0.0444444 0.1777778 0.1777778 1.0000000*/
#define GET_VECTOR_FP16D_ELEM(s, i, idx) vget_lane_f16(CONCAT(s, i).value, idx)
#define cb(i)                                                                 \
    do {                                                                      \
        mid_buf1[0] = GET_VECTOR_FP16D_ELEM(Gg, i, 0);                        \
        auto tmp02 = GET_VECTOR_FP16D_ELEM(Gg, i, 0) +                        \
                     GET_VECTOR_FP16D_ELEM(Gg, i, 2);                         \
        mid_buf1[1] = (tmp02 + GET_VECTOR_FP16D_ELEM(Gg, i, 1)) * -0.2222222; \
        mid_buf1[2] = (tmp02 - GET_VECTOR_FP16D_ELEM(Gg, i, 1)) * -0.2222222; \
        auto tmp0 = GET_VECTOR_FP16D_ELEM(Gg, i, 0) * 0.0111111;              \
        auto tmp1 = GET_VECTOR_FP16D_ELEM(Gg, i, 1) * 0.0222222;              \
        auto tmp2 = GET_VECTOR_FP16D_ELEM(Gg, i, 2) * 0.0444444;              \
        tmp02 = tmp0 + tmp2;                                                  \
        mid_buf1[3] = tmp02 + tmp1;                                           \
        mid_buf1[4] = tmp02 - tmp1;                                           \
        tmp0 = GET_VECTOR_FP16D_ELEM(Gg, i, 0) * 0.7111111;                   \
        tmp1 = GET_VECTOR_FP16D_ELEM(Gg, i, 1) * 0.3555556;                   \
        tmp2 = GET_VECTOR_FP16D_ELEM(Gg, i, 2) * 0.1777778;                   \
        tmp02 = tmp0 + tmp2;                                                  \
        mid_buf1[5] = tmp02 + tmp1;                                           \
        mid_buf1[6] = tmp02 - tmp1;                                           \
        mid_buf1[7] = GET_VECTOR_FP16D_ELEM(Gg, i, 2);                        \
        mid_buf1 += 8;                                                        \
    } while (0);
                __fp16* mid_buf1 = transform_mid_buf;
                UNROLL_CALL_NOWRAPPER(8, cb);
                mid_buf1 = transform_mid_buf;
#undef cb
                rep(i, alpha) rep(j, alpha) {
                    filter_transform_buf[(i * alpha + j) * OC * IC + ic * OC +
                                         oc] = transform_mid_buf[i * alpha + j];
                }
#undef GET_VECTOR_FP16D_ELEM
#endif
            }
        }
    }
};
#undef FILTER_TRANSFORM | |||||
#undef FILTER_TRANSFORM_FINAL | |||||
/**
 * input transform
 *
 * wd0 = (d0 - d6) + 5.25 * (d4 - d2)
 * wd1 = (d6 + d2 - 4.25 * d4) + (d1 + d5 - 4.25 * d3)
 * wd2 = (d6 + d2 - 4.25 * d4) - (d1 + d5 - 4.25 * d3)
 * wd3 = (d6 + 0.25 * d2 - 1.25 * d4) + 2.0 * (d5 + 0.25 * d1 - 1.25 * d3)
 * wd4 = (d6 + 0.25 * d2 - 1.25 * d4) - 2.0 * (d5 + 0.25 * d1 - 1.25 * d3)
 * wd5 = (d6 - 5.0 * d4 + 4.0 * d2) + 2.0 * (d1 + 0.25 * d5 - 1.25 * d3)
 * wd6 = (d6 - 5.0 * d4 + 4.0 * d2) - 2.0 * (d1 + 0.25 * d5 - 1.25 * d3)
 * wd7 = (d7 - d1) + 5.25 * (d3 - d5)
 */
//! Applies B^T of winograd F(6, 3) to the 8-row bundle d##0..d##7, writing
//! wd##0..wd##7.  Fix: the wd5/wd6 rows previously referenced the bare
//! names d1..d6 instead of the pasted tokens d##1..d##6; that compiled only
//! because every call site happens to pass `d` as the first argument.  Token
//! pasting is now used consistently so the macro works with any prefix.
#define INPUT_TRANSFORM(d, wd)                           \
    do {                                                 \
        wd##0 = (d##0 - d##6) + (d##4 - d##2) * 5.25;    \
        auto tmp0 = d##6 + d##2 - d##4 * 4.25;           \
        auto tmp1 = d##1 + d##5 - d##3 * 4.25;           \
        wd##1 = tmp0 + tmp1;                             \
        wd##2 = tmp0 - tmp1;                             \
        tmp0 = d##6 + d##2 * 0.25 - d##4 * 1.25;         \
        tmp1 = (d##5 + d##1 * 0.25 - d##3 * 1.25) * 2.0; \
        wd##3 = tmp0 + tmp1;                             \
        wd##4 = tmp0 - tmp1;                             \
        tmp0 = d##6 - d##4 * 5.0 + d##2 * 4.0;           \
        tmp1 = (d##1 + d##5 * 0.25 - d##3 * 1.25) * 2.0; \
        wd##5 = tmp0 + tmp1;                             \
        wd##6 = tmp0 - tmp1;                             \
        wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25;    \
    } while (0);
#define GET_VECTOR_FP16Q_ELEM(s, i, idx) vgetq_lane_f16(CONCAT(s, i).value, idx) | |||||
//! Winograd F(6x6, 3x3) input transform: computes B^T * d * B for one 8x8
//! input patch of channel `ic` and stores the result into
//! input_transform_buf with layout [alpha*alpha][nr_units_in_tile][IC].
//!
//! \tparam inner true when the 8x8 patch lies entirely inside the input
//!         plane; when false the patch may overlap the padding and is first
//!         staged (zero-filled) through transform_mid_buf.
struct InputTransform6x3 {
    template <bool inner>
    static void transform(const __fp16* input, __fp16* input_transform_buf,
                          __fp16* transform_mid_buf, int ih_start, int iw_start,
                          size_t ic, size_t IH, size_t IW, size_t IC,
                          size_t unit_idx, size_t nr_units_in_tile) {
        // BTd * B
        // 1.000 0.000 -5.25 0.000 5.250 0.000 -1.0 0.00
        // -0.00 1.000 1.000 -4.25 -4.25 1.000 1.00 -0.0
        // -0.00 -1.00 1.000 4.250 -4.25 -1.00 1.00 -0.0
        // 0.000 0.500 0.250 -2.50 -1.25 2.000 1.00 0.00
        // 0.000 -0.50 0.250 2.500 -1.25 -2.00 1.00 0.00
        // 0.000 2.000 4.000 -2.50 -5.00 0.500 1.00 0.00
        // 0.000 -2.00 4.000 2.500 -5.00 -0.50 1.00 0.00
        // 0.000 -1.00 0.000 5.250 0.000 -5.25 0.00 1.00
        constexpr size_t alpha = 6 + 3 - 1;
        if (!inner) {
            //! zero the staging tile so padded rows/cols stay 0
            memset(transform_mid_buf, 0, sizeof(__fp16) * alpha * alpha);
        }
#define cb(i) Vector<__fp16, 8> d##i;
        UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
        if (inner) {
            //! patch fully inside: load the 8 rows straight from the input
            const __fp16* input_ptr =
                    input + ic * IH * IW + ih_start * IW + iw_start;
#define cb(i) d##i = Vector<__fp16, 8>::load(input_ptr + IW * i);
            UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
        } else {
            //! clip the patch to the valid region, copy into the zeroed
            //! staging buffer, then load the rows from there
            int ih0_act = std::max<int>(ih_start, 0),
                ih1_act = std::min<int>(ih_start + alpha, IH),
                iw0_act = std::max<int>(iw_start, 0),
                iw1_act = std::min<int>(iw_start + alpha, IW);
            for (int ih = ih0_act; ih < ih1_act; ++ih) {
                for (int iw = iw0_act; iw < iw1_act; ++iw) {
                    size_t iho = ih - ih_start, iwo = iw - iw_start;
                    transform_mid_buf[iho * alpha + iwo] =
                            input[ic * IH * IW + ih * IW + iw];
                }
            }
#define cb(i) d##i = Vector<__fp16, 8>::load(transform_mid_buf + alpha * i);
            UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
        }
#define cb(i) Vector<__fp16, 8> wd##i, ret##i;
        UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
        //! row pass: wd = B^T * d
        INPUT_TRANSFORM(d, wd);
#if MEGDNN_AARCH64
        //! column pass via transpose: ret = B^T * wd^T
        TRANSPOSE_8x8(wd, d);
        INPUT_TRANSFORM(d, ret);
#define cb(i) ret##i.save(transform_mid_buf + i * alpha);
        UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
        //! swapped i/j on the read side undoes the transpose while scattering
        rep(i, alpha) rep(j, alpha) {
            input_transform_buf[(i * alpha + j) * nr_units_in_tile * IC +
                                unit_idx * IC + ic] =
                    transform_mid_buf[j * alpha + i];
        }
#else
        //! armv7 fallback: apply B per row with scalar lane extraction
        //! 1     0     0     0     0    0    0     0
        //! 0     1    -1    0.5  -0.5   2   -2    -1
        //! -5.25 1     1    0.25  0.25  4    4     0
        //! 0    -4.25  4.25 -2.5  2.5  -2.5  2.5   5.25
        //! 5.25 -4.25 -4.25 -1.25 -1.25 -5  -5     0
        //! 0     1    -1    2    -2    0.5  -0.5  -5.25
        //! -1    1     1    1     1    1    1      0
        //! 0     0     0    0     0    0    0      1
#define cb(i)                                                   \
    do {                                                        \
        mid_buf1[0] = GET_VECTOR_FP16Q_ELEM(wd, i, 0) -         \
                      GET_VECTOR_FP16Q_ELEM(wd, i, 6) +         \
                      5.25 * (GET_VECTOR_FP16Q_ELEM(wd, i, 4) - \
                              GET_VECTOR_FP16Q_ELEM(wd, i, 2)); \
        mid_buf1[7] = GET_VECTOR_FP16Q_ELEM(wd, i, 7) -         \
                      GET_VECTOR_FP16Q_ELEM(wd, i, 1) +         \
                      5.25 * (GET_VECTOR_FP16Q_ELEM(wd, i, 3) - \
                              GET_VECTOR_FP16Q_ELEM(wd, i, 5)); \
        auto tmp0 = GET_VECTOR_FP16Q_ELEM(wd, i, 2) +           \
                    GET_VECTOR_FP16Q_ELEM(wd, i, 6) -           \
                    4.25 * GET_VECTOR_FP16Q_ELEM(wd, i, 4);     \
        auto tmp1 = GET_VECTOR_FP16Q_ELEM(wd, i, 1) -           \
                    GET_VECTOR_FP16Q_ELEM(wd, i, 3) * 4.25 +    \
                    GET_VECTOR_FP16Q_ELEM(wd, i, 5);            \
        mid_buf1[1] = tmp0 + tmp1;                              \
        mid_buf1[2] = tmp0 - tmp1;                              \
        tmp0 = GET_VECTOR_FP16Q_ELEM(wd, i, 2) * 0.25 +         \
               GET_VECTOR_FP16Q_ELEM(wd, i, 6) -                \
               GET_VECTOR_FP16Q_ELEM(wd, i, 4) * 1.25;          \
        tmp1 = GET_VECTOR_FP16Q_ELEM(wd, i, 1) * 0.5 -          \
               GET_VECTOR_FP16Q_ELEM(wd, i, 3) * 2.5 +          \
               GET_VECTOR_FP16Q_ELEM(wd, i, 5) * 2;             \
        mid_buf1[3] = tmp0 + tmp1;                              \
        mid_buf1[4] = tmp0 - tmp1;                              \
        tmp0 = GET_VECTOR_FP16Q_ELEM(wd, i, 6) +                \
               GET_VECTOR_FP16Q_ELEM(wd, i, 2) * 4.0 -          \
               GET_VECTOR_FP16Q_ELEM(wd, i, 4) * 5.0;           \
        tmp1 = GET_VECTOR_FP16Q_ELEM(wd, i, 1) * 2 -            \
               GET_VECTOR_FP16Q_ELEM(wd, i, 3) * 2.5 +          \
               GET_VECTOR_FP16Q_ELEM(wd, i, 5) * 0.5;           \
        mid_buf1[5] = tmp0 + tmp1;                              \
        mid_buf1[6] = tmp0 - tmp1;                              \
        mid_buf1 += 8;                                          \
    } while (0);
        __fp16* mid_buf1 = transform_mid_buf;
        UNROLL_CALL_NOWRAPPER(8, cb);
        mid_buf1 = transform_mid_buf;
#undef cb
        rep(i, alpha) rep(j, alpha) {
            input_transform_buf[(i * alpha + j) * nr_units_in_tile * IC +
                                unit_idx * IC + ic] =
                    transform_mid_buf[i * alpha + j];
        }
#endif
    }
};
#undef INPUT_TRANSFORM | |||||
//! Applies A^T (6x8) of winograd F(6, 3) to the rows m##0..m##7, producing
//! the six output rows r##0..r##5; m##7 contributes only to r##5.
#define OUTPUT_TRANSFORM(m, r)                                      \
    do {                                                            \
        auto m1addm2 = m##1 + m##2;                                 \
        auto m1subm2 = m##1 - m##2;                                 \
        auto m3addm4 = m##3 + m##4;                                 \
        auto m3subm4 = m##3 - m##4;                                 \
        auto m5addm6 = m##5 + m##6;                                 \
        auto m5subm6 = m##5 - m##6;                                 \
        r##0 = m##0 + m1addm2 + m3addm4 + m5addm6;                  \
        r##1 = m1subm2 + m3subm4 * 2.0 + m5subm6 * 0.5;             \
        r##2 = m1addm2 + m3addm4 * 4.0 + m5addm6 * 0.25;            \
        r##3 = m1subm2 + m3subm4 * 8.0 + m5subm6 * 0.125;           \
        r##4 = m1addm2 + m3addm4 * 16.0 + m5addm6 * 0.0625;         \
        r##5 = m1subm2 + m3subm4 * 32.0 + m5subm6 * 0.03125 + m##7; \
    } while (0)
//! Winograd F(6x6, 3x3) output transform: computes A^T * m * A for one 8x8
//! accumulated tile, adds bias according to \p bmode, applies the
//! nonlinearity Op, and writes the (up to) 6x6 result into the output.
template <BiasMode bmode, typename Op>
struct OutputTransform6X3 {
    static void transform(const dt_float16* output_transform_buf,
                          const dt_float16* bias, dt_float16* output,
                          dt_float16* transform_mid_buf, size_t oh_start,
                          size_t ow_start, size_t OH, size_t OW,
                          size_t oc_start, size_t oc_end, size_t oc_index,
                          size_t unit_idx, size_t nr_units_in_tile,
                          const DType& src_dtype, const DType& dst_dtype) {
        Op op(src_dtype, dst_dtype);
        //! AT * m * A
        // AT f63
        // 1.0 1.0 1.0 1.0 1.0 1.00000 1.00000 0.0
        // 0.0 1.0 -1.0 2.0 -2.0 0.50000 -0.50000 0.0
        // 0.0 1.0 1.0 4.0 4.0 0.25000 0.25000 0.0
        // 0.0 1.0 -1.0 8.0 -8.0 0.12500 -0.12500 0.0
        // 0.0 1.0 1.0 16.0 16.0 0.06250 0.06250 0.0
        // 0.0 1.0 -1.0 32.0 -32.0 0.03125 -0.03125 1.0
        constexpr size_t alpha = 3 + 6 - 1;
        const __fp16* fp16_output_transform_buf =
                reinterpret_cast<const __fp16*>(output_transform_buf);
        const __fp16* fp16_bias = reinterpret_cast<const __fp16*>(bias);
        __fp16* fp16_output = reinterpret_cast<__fp16*>(output);
        __fp16* fp16_transform_mid_buf =
                reinterpret_cast<__fp16*>(transform_mid_buf);
        __fp16* mid_buf1 = fp16_transform_mid_buf;
        size_t OC = oc_end - oc_start;
        size_t oc = oc_start + oc_index;
        //! gather this unit's 8x8 tile from the matmul result, which is laid
        //! out as [alpha*alpha][nr_units_in_tile][OC]
#define cb(m, n)                                                           \
    fp16_transform_mid_buf[m * alpha + n] =                                \
            fp16_output_transform_buf[(m * alpha + n) * nr_units_in_tile * \
                                              OC +                         \
                                      unit_idx * OC + oc_index];
        UNROLL_CALL_NOWRAPPER_D2(8, 8, cb);
#undef cb
#define cb(i) \
    auto m##i = Vector<__fp16, 8>::load(fp16_transform_mid_buf + alpha * i);
        UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
#define cb(i) Vector<__fp16, 8> s##i;
        UNROLL_CALL_NOWRAPPER(6, cb);
#undef cb
        /* 1.0 0.0 0.00 0.000 0.0000 0.00000
        1.0 1.0 1.00 1.000 1.0000 1.00000
        1.0 -1.0 1.00 -1.000 1.0000 -1.00000
        1.0 2.0 4.00 8.000 16.000 32.00000
        1.0 -2.0 4.00 -8.000 16.000 -32.00000
        1.0 0.5 0.25 0.125 0.0625 0.03125
        1.0 -0.5 0.25 -0.125 0.0625 -0.03125
        0.0 0.0 0.00 0.000 0.0000 1.00000*/
        //! row pass: s = A^T * m
        OUTPUT_TRANSFORM(m, s);
        mid_buf1 = fp16_transform_mid_buf;
        //! column pass (scalar lanes): mid_buf1 becomes the 6x6 tile s * A
#define cb(i)                                                                \
    do {                                                                     \
        auto m1addm2 = GET_VECTOR_FP16Q_ELEM(s, i, 1) +                      \
                       GET_VECTOR_FP16Q_ELEM(s, i, 2);                       \
        auto m1subm2 = GET_VECTOR_FP16Q_ELEM(s, i, 1) -                      \
                       GET_VECTOR_FP16Q_ELEM(s, i, 2);                       \
        auto m3addm4 = GET_VECTOR_FP16Q_ELEM(s, i, 3) +                      \
                       GET_VECTOR_FP16Q_ELEM(s, i, 4);                       \
        auto m3subm4 = GET_VECTOR_FP16Q_ELEM(s, i, 3) -                      \
                       GET_VECTOR_FP16Q_ELEM(s, i, 4);                       \
        auto m5addm6 = GET_VECTOR_FP16Q_ELEM(s, i, 5) +                      \
                       GET_VECTOR_FP16Q_ELEM(s, i, 6);                       \
        auto m5subm6 = GET_VECTOR_FP16Q_ELEM(s, i, 5) -                      \
                       GET_VECTOR_FP16Q_ELEM(s, i, 6);                       \
        mid_buf1[0] =                                                        \
                GET_VECTOR_FP16Q_ELEM(s, i, 0) + m1addm2 + m3addm4 + m5addm6; \
        mid_buf1[1] = m1subm2 + m3subm4 * 2 + m5subm6 * 0.5;                 \
        mid_buf1[2] = m1addm2 + m3addm4 * 4 + m5addm6 * 0.25;                \
        mid_buf1[3] = m1subm2 + m3subm4 * 8 + m5subm6 * 0.125;               \
        mid_buf1[4] = m1addm2 + m3addm4 * 16 + m5addm6 * 0.0625;             \
        mid_buf1[5] = m1subm2 + m3subm4 * 32 + m5subm6 * 0.03125 +           \
                      GET_VECTOR_FP16Q_ELEM(s, i, 7);                        \
        mid_buf1 += 6;                                                       \
    } while (0);
        mid_buf1 = fp16_transform_mid_buf;
        UNROLL_CALL_NOWRAPPER(6, cb);
        mid_buf1 = fp16_transform_mid_buf;
#undef cb
        if (oh_start + 6 <= OH && ow_start + 6 <= OW) {
            //! fast path: the full 6x6 tile fits inside the output plane
            int index = (oc * OH + oh_start) * OW + ow_start;
#define cb(i) float16x4_t vr##i = vld1_f16(mid_buf1 + i * 6);
            UNROLL_CALL_NOWRAPPER(6, cb);
#undef cb
            //! a 6-wide row does not fill 4-lane vectors evenly, so columns
            //! 4..5 of each row are handled in the two extra vectors below
            float16x8_t vr0123_45 = {mid_buf1[4], mid_buf1[5], mid_buf1[10],
                                     mid_buf1[11], mid_buf1[16], mid_buf1[17],
                                     mid_buf1[22], mid_buf1[23]};
            float16x4_t vr45_45 = {mid_buf1[28], mid_buf1[29], mid_buf1[34],
                                   mid_buf1[35]};
            if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
                float16x4_t bias0 = vdup_n_f16(fp16_bias[oc]);
#define cb(i) vr##i = vadd_f16(vr##i, bias0);
                UNROLL_CALL_NOWRAPPER(6, cb);
#undef cb
                vr45_45 = vadd_f16(vr45_45, bias0);
                vr0123_45 = vaddq_f16(vr0123_45, vcombine_f16(bias0, bias0));
            } else if (bmode == BiasMode::BIAS) {
#define cb(i) float16x4_t bmbias##i = vld1_f16(fp16_bias + index + OW * i);
                UNROLL_CALL_NOWRAPPER(6, cb);
#undef cb
#define cb(i) vr##i = vadd_f16(vr##i, bmbias##i);
                UNROLL_CALL_NOWRAPPER(6, cb);
#undef cb
                float16x8_t vb0123_45 = {fp16_bias[index + 0 * OW + 4],
                                         fp16_bias[index + 0 * OW + 5],
                                         fp16_bias[index + 1 * OW + 4],
                                         fp16_bias[index + 1 * OW + 5],
                                         fp16_bias[index + 2 * OW + 4],
                                         fp16_bias[index + 2 * OW + 5],
                                         fp16_bias[index + 3 * OW + 4],
                                         fp16_bias[index + 3 * OW + 5]};
                float16x4_t vb45_45 = {fp16_bias[index + 4 * OW + 4],
                                       fp16_bias[index + 4 * OW + 5],
                                       fp16_bias[index + 5 * OW + 4],
                                       fp16_bias[index + 5 * OW + 5]};
                vr45_45 = vadd_f16(vr45_45, vb45_45);
                vr0123_45 = vaddq_f16(vr0123_45, vb0123_45);
            }
            //! nonlinearity applied two rows at a time
            float16x8_t item01 = op(vcombine_f16(vr0, vr1));
            float16x8_t item23 = op(vcombine_f16(vr2, vr3));
            float16x8_t item45 = op(vcombine_f16(vr4, vr5));
            vst1_f16(fp16_output + index, vget_low_f16(item01));
            vst1_f16(fp16_output + index + OW, vget_high_f16(item01));
            vst1_f16(fp16_output + index + OW * 2, vget_low_f16(item23));
            vst1_f16(fp16_output + index + OW * 3, vget_high_f16(item23));
            vst1_f16(fp16_output + index + OW * 4, vget_low_f16(item45));
            vst1_f16(fp16_output + index + OW * 5, vget_high_f16(item45));
            vr0123_45 = op(vr0123_45);
            float16x8_t vr45 = op(vcombine_f16(vr45_45, vr45_45));
            //! scatter the two rightmost columns lane by lane
            fp16_output[index + OW * 0 + 4] = vgetq_lane_f16(vr0123_45, 0);
            fp16_output[index + OW * 0 + 5] = vgetq_lane_f16(vr0123_45, 1);
            fp16_output[index + OW * 1 + 4] = vgetq_lane_f16(vr0123_45, 2);
            fp16_output[index + OW * 1 + 5] = vgetq_lane_f16(vr0123_45, 3);
            fp16_output[index + OW * 2 + 4] = vgetq_lane_f16(vr0123_45, 4);
            fp16_output[index + OW * 2 + 5] = vgetq_lane_f16(vr0123_45, 5);
            fp16_output[index + OW * 3 + 4] = vgetq_lane_f16(vr0123_45, 6);
            fp16_output[index + OW * 3 + 5] = vgetq_lane_f16(vr0123_45, 7);
            fp16_output[index + OW * 4 + 4] = vgetq_lane_f16(vr45, 0);
            fp16_output[index + OW * 4 + 5] = vgetq_lane_f16(vr45, 1);
            fp16_output[index + OW * 5 + 4] = vgetq_lane_f16(vr45, 2);
            fp16_output[index + OW * 5 + 5] = vgetq_lane_f16(vr45, 3);
        } else {
            //! boundary path: scalar writes clipped to the OH x OW plane
            for (size_t oho = 0; oho < 6 && oh_start + oho < OH; ++oho) {
                for (size_t owo = 0; owo < 6 && ow_start + owo < OW; ++owo) {
                    size_t oh = oh_start + oho;
                    size_t ow = ow_start + owo;
                    __fp16 res = mid_buf1[oho * 6 + owo];
                    if (bmode == BiasMode::BIAS) {
                        res += fp16_bias[oc * OH * OW + oh * OW + ow];
                    } else if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
                        res += fp16_bias[oc];
                    }
                    res = op(res);
                    fp16_output[oc * OH * OW + oh * OW + ow] = res;
                }
            }
        }
    }
};
#undef GET_VECTOR_FP16Q_ELEM | |||||
#undef OUTPUT_TRANSFORM | |||||
} // namespace | |||||
namespace megdnn { | |||||
namespace arm_common { | |||||
namespace winograd { | |||||
MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_6x3_1x1_f16) | |||||
void winograd_6x3_1x1_f16::filter(const dt_float16* filter, | |||||
dt_float16* filter_transform_buf, | |||||
dt_float16* transform_mid_buf, size_t OC, | |||||
size_t IC, size_t oc_start, size_t oc_end) { | |||||
FilterTransform6X3::transform( | |||||
reinterpret_cast<const __fp16*>(filter), | |||||
reinterpret_cast<__fp16*>(filter_transform_buf), | |||||
reinterpret_cast<__fp16*>(transform_mid_buf), OC, IC, oc_start, | |||||
oc_end); | |||||
} | |||||
void winograd_6x3_1x1_f16::input(const dt_float16* input, | |||||
dt_float16* input_transform_buf, | |||||
dt_float16* transform_mid_buf, size_t IH, | |||||
size_t IW, size_t IC, size_t PH, size_t PW, | |||||
size_t unit_start_idx, | |||||
size_t nr_units_in_tile) { | |||||
constexpr int alpha = 6 + 3 - 1; | |||||
// OW = IW + 2 * PW - KERNEL_SIZE + 1 | |||||
auto units_w = div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); | |||||
rep(ic, IC) { | |||||
rep(unit_idx, nr_units_in_tile) { | |||||
size_t index = unit_start_idx + unit_idx; | |||||
size_t nh = index / units_w; | |||||
size_t nw = index % units_w; | |||||
int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; | |||||
int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; | |||||
if (ih_start >= 0 && ih_start + alpha <= static_cast<int>(IH) && | |||||
iw_start >= 0 && iw_start + alpha <= static_cast<int>(IW)) { | |||||
InputTransform6x3::transform<true>( | |||||
reinterpret_cast<const __fp16*>(input), | |||||
reinterpret_cast<__fp16*>(input_transform_buf), | |||||
reinterpret_cast<__fp16*>(transform_mid_buf), ih_start, | |||||
iw_start, ic, IH, IW, IC, unit_idx, nr_units_in_tile); | |||||
} else { | |||||
InputTransform6x3::transform<false>( | |||||
reinterpret_cast<const __fp16*>(input), | |||||
reinterpret_cast<__fp16*>(input_transform_buf), | |||||
reinterpret_cast<__fp16*>(transform_mid_buf), ih_start, | |||||
iw_start, ic, IH, IW, IC, unit_idx, nr_units_in_tile); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
void winograd_6x3_1x1_f16::output(const dt_float16* output_transform_buf, | |||||
const dt_float16* bias, dt_float16* output, | |||||
dt_float16* transform_mid_buf, BiasMode bmode, | |||||
NonlineMode nonline_mode, size_t OH, size_t OW, | |||||
size_t oc_start, size_t oc_end, | |||||
size_t unit_start_idx, size_t nr_units_in_tile) { | |||||
#define cb(_bmode, _nonline_op, ...) \ | |||||
OutputTransform6X3<_bmode MEGDNN_COMMA _nonline_op>::transform(__VA_ARGS__); | |||||
auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE); | |||||
for (size_t oc = oc_start; oc < oc_end; oc++) { | |||||
size_t oc_index = oc - oc_start; | |||||
rep(unit_idx, nr_units_in_tile) { | |||||
size_t index = unit_start_idx + unit_idx; | |||||
auto nh = index / units_w; | |||||
auto nw = index % units_w; | |||||
size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | |||||
size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | |||||
DISPATCH_CONV_WINOGRAD_BIAS( | |||||
megdnn_arm_common_winograd_fp16_F63, cb, __fp16, __fp16, bmode, | |||||
nonline_mode, output_transform_buf, bias, output, transform_mid_buf, | |||||
oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx, | |||||
nr_units_in_tile, src_dtype, dst_dtype); | |||||
} | |||||
} | |||||
#undef cb | |||||
} | |||||
} // namespace winograd | |||||
} // namespace arm_common | |||||
} // namespace megdnn | |||||
#endif | |||||
// vim: syntax=cpp.doxygen
// (extraction artifact — original diff hunk header: @@ -0,0 +1,754 @@)
/** | |||||
* \file dnn/src/arm_common/conv_bias/fp32/algos.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "src/arm_common/conv_bias/fp32/algos.h" | |||||
#include "src/arm_common/conv_bias/direct/multi_thread_common.h" | |||||
#include "src/arm_common/conv_bias/fp32/direct.h" | |||||
#include "src/arm_common/conv_bias/fp32/do_conv_stride1.h" | |||||
#include "src/arm_common/conv_bias/fp32/do_conv_stride2.h" | |||||
#include "src/arm_common/conv_bias/fp32/strategy.h" | |||||
#include "src/arm_common/conv_bias/img2col_helper.h" | |||||
#include "src/arm_common/conv_bias/postprocess_helper.h" | |||||
#include "src/common/opr_delegate.h" | |||||
#include "src/fallback/conv_bias/common.h" | |||||
#include "midout.h" | |||||
MIDOUT_DECL(megdnn_arm_common_winograd_fp32) | |||||
using namespace megdnn; | |||||
using namespace arm_common; | |||||
/* ======================= AlgoFP32WinogradF23_4x4 ======================== */ | |||||
//! Whether winograd F(2, 3) with MK4 matmul packing can handle \p param.
bool ConvBiasImpl::AlgoFP32WinogradF23_4x4::usable(
        fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
        AlgoSelectionStrategy /*algo_selection_strategy*/) const {
    MEGDNN_MARK_USED_VAR(opr);
    MEGDNN_MARK_USED_VAR(param);
    MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 0, 0) {
        //! MK4 packing processes input/output channels in groups of 4
        if (param.filter_meta.icpg % 4 != 0 || param.filter_meta.ocpg % 4 != 0)
            return false;
        using Strategy = winograd::winograd_2x3_4x4_f;
        Strategy strategy(param.src_type, param.filter_type, param.dst_type);
        //! derive the shapes the inner matmul would see for this problem
        auto&& matmul_param =
                megdnn::winograd::ConvBias<Strategy,
                                           param::MatrixMul::Format::MK4>(
                        strategy, m_tile_size, param.nr_threads, param.osz[0],
                        param.osz[1], param.filter_meta.ocpg)
                        .get_matmul_kern_param(param);
        //! usable iff the inner matmul accepts those shapes, the layout is
        //! plain NCHW or NCHW_WINOGRAD pre-transformed with output block 2
        //! and MK4 format, and the conv is a dense fp32 3x3 stride-1
        //! dilation-1 cross-correlation
        return m_matmul_algo->usable(matmul_param) &&
               (opr->param().format == param::ConvBias::Format::NCHW ||
                (opr->param().format ==
                         param::ConvBias::Format::NCHW_WINOGRAD &&
                 opr->param().output_block_size == 2 &&
                 param.winograd_matmul_format ==
                         param::MatrixMul::Format::MK4)) &&
               opr->param().mode == param::ConvBias::Mode::CROSS_CORRELATION &&
               (param.filter_meta.spatial[0] == param.filter_meta.spatial[1] &&
                param.filter_meta.spatial[0] == 3) &&
               (param.filter_meta.stride[0] == param.filter_meta.stride[1] &&
                param.filter_meta.stride[0] == 1) &&
               (param.filter_meta.dilation[0] ==
                        param.filter_meta.dilation[1] &&
                param.filter_meta.dilation[0] == 1) &&
               param.compute_mode == param::ConvBias::ComputeMode::DEFAULT &&
               param.src_type.enumv() == DTypeEnum::Float32;
    }
    MIDOUT_END();
    return false;
}
//! Workspace (in bytes) needed by the F(2, 3) MK4 winograd implementation
//! for \p param with the bound matmul algorithm.
size_t ConvBiasImpl::AlgoFP32WinogradF23_4x4::get_workspace(
        fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
    MEGDNN_MARK_USED_VAR(param);
    MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 0, 1) {
        winograd::winograd_2x3_4x4_f strategy(param.src_type, param.filter_type,
                                              param.dst_type);
        //! build the generic winograd wrapper and query it
        auto winograd_conv =
                megdnn::winograd::ConvBias<winograd::winograd_2x3_4x4_f,
                                           param::MatrixMul::Format::MK4>(
                        strategy, m_tile_size, param.nr_threads, param.osz[0],
                        param.osz[1], param.filter_meta.ocpg);
        return winograd_conv.get_workspace_size(param, m_matmul_algo);
    }
    MIDOUT_END();
    return 0;
}
//! Kernels implementing the F(2, 3) MK4 winograd convolution for \p param.
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoFP32WinogradF23_4x4::dispatch_kerns(
        fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
    MEGDNN_MARK_USED_VAR(param);
    MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 0, 2) {
        winograd::winograd_2x3_4x4_f strategy(param.src_type, param.filter_type,
                                              param.dst_type);
        return megdnn::winograd::ConvBias<winograd::winograd_2x3_4x4_f,
                                          param::MatrixMul::Format::MK4>(
                       strategy, m_tile_size, param.nr_threads, param.osz[0],
                       param.osz[1], param.filter_meta.ocpg)
                .get_kerns(param, m_matmul_algo);
    }
    MIDOUT_END();
    return {};
}
/* ======================= AlgoFP32WinogradF63 ======================== */ | |||||
//! Whether winograd F(6, 3) with the DEFAULT matmul format can handle
//! \p param.
bool ConvBiasImpl::AlgoFP32WinogradF63::usable(
        fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
        AlgoSelectionStrategy /*algo_selection_strategy*/) const {
    MEGDNN_MARK_USED_VAR(param);
    MEGDNN_MARK_USED_VAR(opr);
    MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 1, 0) {
        using Strategy = winograd::winograd_6x3_1x1_f;
        Strategy strategy(param.src_type, param.filter_type, param.dst_type);
        //! derive the shapes the inner matmul would see for this problem
        auto&& matmul_param =
                megdnn::winograd::ConvBias<Strategy>(
                        strategy, m_tile_size, param.nr_threads, param.osz[0],
                        param.osz[1], param.filter_meta.ocpg)
                        .get_matmul_kern_param(param);
        //! usable iff the inner matmul accepts those shapes, the layout is
        //! plain NCHW or NCHW_WINOGRAD pre-transformed with output block 6
        //! and DEFAULT format, and the conv is a dense fp32 3x3 stride-1
        //! dilation-1 cross-correlation
        return m_matmul_algo->usable(matmul_param) &&
               (opr->param().format == param::ConvBias::Format::NCHW ||
                (opr->param().format ==
                         param::ConvBias::Format::NCHW_WINOGRAD &&
                 opr->param().output_block_size == 6 &&
                 param.winograd_matmul_format ==
                         param::MatrixMul::Format::DEFAULT)) &&
               opr->param().mode == param::ConvBias::Mode::CROSS_CORRELATION &&
               (param.filter_meta.spatial[0] == param.filter_meta.spatial[1] &&
                param.filter_meta.spatial[0] == 3) &&
               (param.filter_meta.stride[0] == param.filter_meta.stride[1] &&
                param.filter_meta.stride[0] == 1) &&
               (param.filter_meta.dilation[0] ==
                        param.filter_meta.dilation[1] &&
                param.filter_meta.dilation[0] == 1) &&
               param.compute_mode == param::ConvBias::ComputeMode::DEFAULT &&
               param.src_type.enumv() == DTypeEnum::Float32;
    }
    MIDOUT_END();
    return false;
}
//! Workspace (in bytes) needed by the F(6, 3) winograd implementation for
//! \p param with the bound matmul algorithm.
size_t ConvBiasImpl::AlgoFP32WinogradF63::get_workspace(
        fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
    MEGDNN_MARK_USED_VAR(param);
    MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 1, 1) {
        winograd::winograd_6x3_1x1_f strategy(param.src_type, param.filter_type,
                                              param.dst_type);
        //! build the generic winograd wrapper and query it
        auto winograd_conv =
                megdnn::winograd::ConvBias<winograd::winograd_6x3_1x1_f>(
                        strategy, m_tile_size, param.nr_threads, param.osz[0],
                        param.osz[1], param.filter_meta.ocpg);
        return winograd_conv.get_workspace_size(param, m_matmul_algo);
    }
    MIDOUT_END();
    return 0;
}
//! Kernels implementing the F(6, 3) winograd convolution for \p param.
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoFP32WinogradF63::dispatch_kerns(
        fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
    MEGDNN_MARK_USED_VAR(param);
    MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 1, 2) {
        winograd::winograd_6x3_1x1_f strategy(param.src_type, param.filter_type,
                                              param.dst_type);
        return megdnn::winograd::ConvBias<winograd::winograd_6x3_1x1_f>(
                       strategy, m_tile_size, param.nr_threads, param.osz[0],
                       param.osz[1], param.filter_meta.ocpg)
                .get_kerns(param, m_matmul_algo);
    }
    MIDOUT_END();
    return {};
}
/* ======================= AlgoFP32WinogradF54 ======================== */ | |||||
//! Usability check for the F(5,4) winograd algorithm.
//! Accepts NCHW float32 cross-correlation with a square 4x4 filter, unit
//! stride and no dilation; NCHW_WINOGRAD input must carry
//! output_block_size == 5 with the DEFAULT matmul format.
bool ConvBiasImpl::AlgoFP32WinogradF54::usable(
        fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
        AlgoSelectionStrategy /*algo_selection_strategy*/) const {
    MEGDNN_MARK_USED_VAR(param);
    MEGDNN_MARK_USED_VAR(opr);
    MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 2, 0) {
        using Strategy = winograd::winograd_5x4_1x1_f;
        Strategy strategy(param.src_type, param.filter_type, param.dst_type);
        auto&& matmul_param =
                megdnn::winograd::ConvBias<Strategy>(
                        strategy, m_tile_size, param.nr_threads, param.osz[0],
                        param.osz[1], param.filter_meta.ocpg)
                        .get_matmul_kern_param(param);
        //! all sub-conditions are pure comparisons, so evaluating them
        //! eagerly here is behaviorally identical to the short-circuit form
        bool format_ok =
                opr->param().format == param::ConvBias::Format::NCHW ||
                (opr->param().format ==
                         param::ConvBias::Format::NCHW_WINOGRAD &&
                 opr->param().output_block_size == 5 &&
                 param.winograd_matmul_format ==
                         param::MatrixMul::Format::DEFAULT);
        bool filter_ok =
                param.filter_meta.spatial[0] == param.filter_meta.spatial[1] &&
                param.filter_meta.spatial[0] == 4;
        bool stride_ok =
                param.filter_meta.stride[0] == param.filter_meta.stride[1] &&
                param.filter_meta.stride[0] == 1;
        bool dilate_ok =
                param.filter_meta.dilation[0] ==
                        param.filter_meta.dilation[1] &&
                param.filter_meta.dilation[0] == 1;
        return m_matmul_algo->usable(matmul_param) && format_ok &&
               opr->param().mode == param::ConvBias::Mode::CROSS_CORRELATION &&
               filter_ok && stride_ok && dilate_ok &&
               param.compute_mode == param::ConvBias::ComputeMode::DEFAULT &&
               param.src_type.enumv() == DTypeEnum::Float32;
    }
    MIDOUT_END();
    return false;
}
//! Workspace query for the F(5,4) winograd algorithm: delegated to the
//! generic winograd ConvBias helper instantiated with this strategy.
size_t ConvBiasImpl::AlgoFP32WinogradF54::get_workspace(
        fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
    MEGDNN_MARK_USED_VAR(param);
    MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 2, 1) {
        using Strategy = winograd::winograd_5x4_1x1_f;
        Strategy strategy(param.src_type, param.filter_type, param.dst_type);
        auto conv_bias = megdnn::winograd::ConvBias<Strategy>(
                strategy, m_tile_size, param.nr_threads, param.osz[0],
                param.osz[1], param.filter_meta.ocpg);
        return conv_bias.get_workspace_size(param, m_matmul_algo);
    }
    MIDOUT_END();
    return 0;
}
//! Build the NCB kernels for the F(5,4) winograd algorithm.
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoFP32WinogradF54::dispatch_kerns(
        fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
    MEGDNN_MARK_USED_VAR(param);
    MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 2, 2) {
        using Strategy = winograd::winograd_5x4_1x1_f;
        Strategy strategy(param.src_type, param.filter_type, param.dst_type);
        return megdnn::winograd::ConvBias<Strategy>(
                       strategy, m_tile_size, param.nr_threads, param.osz[0],
                       param.osz[1], param.filter_meta.ocpg)
                .get_kerns(param, m_matmul_algo);
    }
    MIDOUT_END();
    return {};
}
/* ======================= AlgoFP32WinogradF45 ======================== */ | |||||
//! Usability check for the F(4,5) winograd algorithm.
//! Accepts NCHW float32 cross-correlation with a square 5x5 filter, unit
//! stride and no dilation; NCHW_WINOGRAD input must carry
//! output_block_size == 4 with the DEFAULT matmul format.
bool ConvBiasImpl::AlgoFP32WinogradF45::usable(
        fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
        AlgoSelectionStrategy /*algo_selection_strategy*/) const {
    MEGDNN_MARK_USED_VAR(param);
    MEGDNN_MARK_USED_VAR(opr);
    MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 3, 0) {
        using Strategy = winograd::winograd_4x5_1x1_f;
        Strategy strategy(param.src_type, param.filter_type, param.dst_type);
        auto&& matmul_param =
                megdnn::winograd::ConvBias<Strategy>(
                        strategy, m_tile_size, param.nr_threads, param.osz[0],
                        param.osz[1], param.filter_meta.ocpg)
                        .get_matmul_kern_param(param);
        //! all sub-conditions are pure comparisons, so evaluating them
        //! eagerly here is behaviorally identical to the short-circuit form
        bool format_ok =
                opr->param().format == param::ConvBias::Format::NCHW ||
                (opr->param().format ==
                         param::ConvBias::Format::NCHW_WINOGRAD &&
                 opr->param().output_block_size == 4 &&
                 param.winograd_matmul_format ==
                         param::MatrixMul::Format::DEFAULT);
        bool filter_ok =
                param.filter_meta.spatial[0] == param.filter_meta.spatial[1] &&
                param.filter_meta.spatial[0] == 5;
        bool stride_ok =
                param.filter_meta.stride[0] == param.filter_meta.stride[1] &&
                param.filter_meta.stride[0] == 1;
        bool dilate_ok =
                param.filter_meta.dilation[0] ==
                        param.filter_meta.dilation[1] &&
                param.filter_meta.dilation[0] == 1;
        return m_matmul_algo->usable(matmul_param) && format_ok &&
               opr->param().mode == param::ConvBias::Mode::CROSS_CORRELATION &&
               filter_ok && stride_ok && dilate_ok &&
               param.compute_mode == param::ConvBias::ComputeMode::DEFAULT &&
               param.src_type.enumv() == DTypeEnum::Float32;
    }
    MIDOUT_END();
    return false;
}
//! Workspace query for the F(4,5) winograd algorithm: delegated to the
//! generic winograd ConvBias helper instantiated with this strategy.
size_t ConvBiasImpl::AlgoFP32WinogradF45::get_workspace(
        fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
    MEGDNN_MARK_USED_VAR(param);
    MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 3, 1) {
        using Strategy = winograd::winograd_4x5_1x1_f;
        Strategy strategy(param.src_type, param.filter_type, param.dst_type);
        auto conv_bias = megdnn::winograd::ConvBias<Strategy>(
                strategy, m_tile_size, param.nr_threads, param.osz[0],
                param.osz[1], param.filter_meta.ocpg);
        return conv_bias.get_workspace_size(param, m_matmul_algo);
    }
    MIDOUT_END();
    return 0;
}
//! Build the NCB kernels for the F(4,5) winograd algorithm.
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoFP32WinogradF45::dispatch_kerns(
        fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
    MEGDNN_MARK_USED_VAR(param);
    MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 3, 2) {
        using Strategy = winograd::winograd_4x5_1x1_f;
        Strategy strategy(param.src_type, param.filter_type, param.dst_type);
        return megdnn::winograd::ConvBias<Strategy>(
                       strategy, m_tile_size, param.nr_threads, param.osz[0],
                       param.osz[1], param.filter_meta.ocpg)
                .get_kerns(param, m_matmul_algo);
    }
    MIDOUT_END();
    return {};
}
/* ======================= AlgoFP32WinogradF63_4x4 ======================== */ | |||||
/**
 * \brief Usability check for the F(6,3) winograd algorithm with 4x4
 *        channel blocking (MK4 matmul packing).
 *
 * MK4 packing requires both icpg and ocpg to be multiples of 4; this is
 * enforced once by the early return (the original duplicated the same
 * check again inside the final condition — the redundancy is removed).
 */
bool ConvBiasImpl::AlgoFP32WinogradF63_4x4::usable(
        fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
        AlgoSelectionStrategy /*algo_selection_strategy*/) const {
    MEGDNN_MARK_USED_VAR(param);
    MEGDNN_MARK_USED_VAR(opr);
    MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 4, 0) {
        //! MK4 packing needs 4-aligned channel counts per group
        if (param.filter_meta.icpg % 4 != 0 || param.filter_meta.ocpg % 4 != 0)
            return false;
        using Strategy = winograd::winograd_6x3_4x4_f;
        Strategy strategy(param.src_type, param.filter_type, param.dst_type);
        auto&& matmul_param =
                megdnn::winograd::ConvBias<Strategy,
                                           param::MatrixMul::Format::MK4>(
                        strategy, m_tile_size, param.nr_threads, param.osz[0],
                        param.osz[1], param.filter_meta.ocpg)
                        .get_matmul_kern_param(param);
        //! square 3x3 filter, unit stride, no dilation, fp32 only
        return m_matmul_algo->usable(matmul_param) &&
               (opr->param().format == param::ConvBias::Format::NCHW ||
                (opr->param().format ==
                         param::ConvBias::Format::NCHW_WINOGRAD &&
                 opr->param().output_block_size == 6 &&
                 param.winograd_matmul_format ==
                         param::MatrixMul::Format::MK4)) &&
               opr->param().mode == param::ConvBias::Mode::CROSS_CORRELATION &&
               (param.filter_meta.spatial[0] == param.filter_meta.spatial[1] &&
                param.filter_meta.spatial[0] == 3) &&
               (param.filter_meta.stride[0] == param.filter_meta.stride[1] &&
                param.filter_meta.stride[0] == 1) &&
               (param.filter_meta.dilation[0] ==
                        param.filter_meta.dilation[1] &&
                param.filter_meta.dilation[0] == 1) &&
               param.compute_mode == param::ConvBias::ComputeMode::DEFAULT &&
               param.src_type.enumv() == DTypeEnum::Float32;
    }
    MIDOUT_END();
    return false;
}
//! Workspace query for the F(6,3) 4x4-blocked winograd algorithm
//! (MK4 matmul packing).
size_t ConvBiasImpl::AlgoFP32WinogradF63_4x4::get_workspace(
        fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
    MEGDNN_MARK_USED_VAR(param);
    MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 4, 1) {
        using Strategy = winograd::winograd_6x3_4x4_f;
        Strategy strategy(param.src_type, param.filter_type, param.dst_type);
        auto conv_bias =
                megdnn::winograd::ConvBias<Strategy,
                                           param::MatrixMul::Format::MK4>(
                        strategy, m_tile_size, param.nr_threads, param.osz[0],
                        param.osz[1], param.filter_meta.ocpg);
        return conv_bias.get_workspace_size(param, m_matmul_algo);
    }
    MIDOUT_END();
    return 0;
}
//! Build the NCB kernels for the F(6,3) 4x4-blocked winograd algorithm.
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoFP32WinogradF63_4x4::dispatch_kerns(
        fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
    MEGDNN_MARK_USED_VAR(param);
    MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 4, 2) {
        using Strategy = winograd::winograd_6x3_4x4_f;
        Strategy strategy(param.src_type, param.filter_type, param.dst_type);
        return megdnn::winograd::ConvBias<Strategy,
                                          param::MatrixMul::Format::MK4>(
                       strategy, m_tile_size, param.nr_threads, param.osz[0],
                       param.osz[1], param.filter_meta.ocpg)
                .get_kerns(param, m_matmul_algo);
    }
    MIDOUT_END();
    return {};
}
/* ===================== direct algo ===================== */
//! midout instrumentation scope shared by the direct / stride-1 / stride-2
//! f32 conv-bias kernels below
MIDOUT_DECL(megdnn_arm_common_conv_bias_f32_kimpl);
/**
 * \brief Whether the direct f32 convolution can handle \p param.
 *
 * Requires NCHW float32 tensors, 2-D spatial filter of height <= 7, unit
 * stride, no dilation, and input/output planes of at least 4 elements.
 * Under HEURISTIC selection only the variant whose large/small-group flag
 * matches the workload is reported usable.
 *
 * (Fix: local variable renamed from the misspelled ``aviliable``.)
 */
bool ConvBiasImpl::AlgoF32Direct::usable(
        fallback::ConvBiasImpl*, const NCBKernSizeParam& param,
        AlgoSelectionStrategy algo_selection_strategy) const {
    MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 0, 0) {
        auto&& fm = param.filter_meta;
        auto FH = fm.spatial[0];
        auto SH = fm.stride[0], SW = fm.stride[1];
        // the condition ``param.isz[0]*param.isz[1] >= 4'' and
        // ``param.osz[0]*param.osz[1] >= 4'' comes from the fact that the
        // kernel may have access to up to 4 floats after the end of the memory
        // chunk.
        bool available = fm.format == param::ConvBias::Format::NCHW &&
                         param.src_type.enumv() == DTypeEnum::Float32 &&
                         param.filter_type.enumv() == DTypeEnum::Float32 &&
                         param.dst_type.enumv() == DTypeEnum::Float32 &&
                         fm.spatial_ndim == 2 && fm.dilation[0] == 1 &&
                         fm.dilation[1] == 1 &&
                         param.isz[0] * param.isz[1] >= 4 &&
                         param.osz[0] * param.osz[1] >= 4 && FH <= 7 &&
                         SH == 1 && SW == 1;
        if (algo_selection_strategy == AlgoSelectionStrategy::HEURISTIC) {
            //! heuristic mode only selects the variant matching the group size
            bool large_group = param.filter_meta.group >= param.nr_threads;
            available &= (large_group == m_large_group);
        }
        return available;
    }
    MIDOUT_END();
    return false;
}
//! Scratch-memory requirement of the direct f32 algorithm, taken from the
//! shared multithread-direct-conv workspace bundle.
size_t ConvBiasImpl::AlgoF32Direct::get_workspace(
        fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
    MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 0, 1) {
        return MultithreadDirectConvCommon<float, float>::get_bundle(
                       param, m_large_group)
                .total_size_in_bytes();
    }
    MIDOUT_END();
    return 0;
}
//! Build the NCB kernel list for the direct f32 convolution.
//!
//! Two scheduling modes:
//!  - large-group: a single kernel entry per (group, batch); the lambda runs
//!    weight flip (if any), padding copy and convolution sequentially for the
//!    whole group on one thread;
//!  - small-group: separate kernel entries for weight flip, padding copy and
//!    convolution, each parallelized over its own (group, batch, channel)
//!    ndrange.
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32Direct::get_kimpls(
        const NCBKernSizeParam& param) const {
    auto fm = param.filter_meta;
    size_t N = param.n;
    size_t IC = param.filter_meta.icpg;
    size_t OC = param.filter_meta.ocpg;
    size_t group = fm.group;
    WorkspaceBundle wbundle =
            MultithreadDirectConvCommon<float, float>::get_bundle(
                    param, m_large_group);
    SmallVector<NCBKern> ret_kerns;
    //! When group >= nr_threads, treat it as large_group, each thread process
    //! one group for better performance
    if (m_large_group) {
        //! Channel wise conv and big groups
        auto exec_one_group = [wbundle](const NCBKernParam& kern_param,
                                        const NCBKernIndex& ncb_index) {
            auto fm = kern_param.filter_meta;
            size_t IC = fm.icpg;
            size_t OC = fm.ocpg;
            //! per-thread copy of the workspace bundle
            WorkspaceBundle bundle = wbundle;
            if (fm.should_flip) {
                for (size_t oc = 0; oc < OC; oc++) {
                    MultithreadDirectConvCommon<float, float>::weight_flip_kern(
                            bundle, kern_param, ncb_index,
                            {ncb_index.thread_id, 0, oc});
                }
            }
            for (size_t ic = 0; ic < IC; ic++) {
                MultithreadDirectConvCommon<float, float>::copy_padding_kern(
                        bundle, kern_param, ncb_index, {ncb_index.thread_id, 0, ic});
            }
            for (size_t oc = 0; oc < OC; oc++) {
                MultithreadDirectConvCommon<float, float>::do_conv_kern(
                        bundle, kern_param, ncb_index,
                        fp32::conv_bias::kern_direct,
                        {ncb_index.thread_id, 0, oc});
            }
        };
        ret_kerns.push_back({exec_one_group, {group, N, 1_z}});
    } else {
        WorkspaceBundle bundle = wbundle;
        if (fm.should_flip) {
            auto weight_flip = [bundle](const NCBKernParam& kern_param,
                                        const NCBKernIndex& ncb_index) {
                MultithreadDirectConvCommon<float, float>::weight_flip_kern(
                        bundle, kern_param, ncb_index, ncb_index.ndrange_id);
            };
            ret_kerns.push_back({weight_flip, {group, 1_z, OC}});
        }
        auto copy_padding = [bundle](const NCBKernParam& kern_param,
                                     const NCBKernIndex& ncb_index) {
            MultithreadDirectConvCommon<float, float>::copy_padding_kern(
                    bundle, kern_param, ncb_index, ncb_index.ndrange_id);
        };
        ret_kerns.push_back({copy_padding, {group, N, IC}});
        auto do_conv = [bundle](const NCBKernParam& kern_param,
                                const NCBKernIndex& ncb_index) {
            MultithreadDirectConvCommon<float, float>::do_conv_kern(
                    bundle, kern_param, ncb_index, fp32::conv_bias::kern_direct,
                    ncb_index.ndrange_id);
        };
        ret_kerns.push_back({do_conv, {group, N, OC}});
    }
    return ret_kerns;
}
//! Dispatch kernels for the direct f32 algorithm.
//! Fix: midout tag changed from (0, 1) to (0, 2) — (0, 1) was already used
//! by get_workspace, so the two trace points were indistinguishable; every
//! other algorithm in this file uses X,0 / X,1 / X,2 for
//! usable / get_workspace / dispatch_kerns.
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32Direct::dispatch_kerns(
        fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
    MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 0, 2) {
        return get_kimpls(param);
    }
    MIDOUT_END();
    return {};
}
/* ===================== stride-1 algo ===================== */ | |||||
/**
 * \brief Whether the stride-1 direct f32 algorithm supports \p param.
 *
 * Requires NCHW float32 tensors, a square 2/3/5/7 filter, unit stride, no
 * dilation and no filter flipping.
 *
 * Fixes: midout tag changed from (1, 1) to (1, 0) — the old value collided
 * with get_workspace's tag (stride-2 uses 2,0 / 2,1 / 2,2 consistently);
 * local variable renamed from the misspelled ``aviliable``.
 */
bool ConvBiasImpl::AlgoF32DirectStride1::usable(
        fallback::ConvBiasImpl*, const NCBKernSizeParam& param,
        AlgoSelectionStrategy algo_selection_strategy) const {
    MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 1, 0) {
        auto&& fm = param.filter_meta;
        auto FH = fm.spatial[0];
        bool available =
                param.filter_meta.format == param::ConvBias::Format::NCHW &&
                param.src_type.enumv() == DTypeEnum::Float32 &&
                param.filter_type.enumv() == DTypeEnum::Float32 &&
                param.dst_type.enumv() == DTypeEnum::Float32 &&
                !fm.should_flip && fm.spatial_ndim == 2 &&
                fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
                fm.stride[0] == 1 && fm.stride[1] == 1 && FH == fm.spatial[1] &&
                (FH == 2 || FH == 3 || FH == 5 || FH == 7);
        if (algo_selection_strategy ==
            ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) {
            //! heuristic mode only selects the variant matching the group size
            bool large_group = param.filter_meta.group >= param.nr_threads;
            available &= (large_group == m_large_group);
        }
        return available;
    }
    MIDOUT_END();
    return false;
}
//! Scratch-memory requirement of the stride-1 direct f32 algorithm.
size_t ConvBiasImpl::AlgoF32DirectStride1::get_workspace(
        fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
    MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 1, 1) {
        return MultithreadDirectConvCommon<float, float>::get_bundle_stride(
                       param, m_large_group)
                .total_size_in_bytes();
    }
    MIDOUT_END();
    return 0;
}
//! Build the NCB kernel list for the stride-1 direct f32 algorithm.
//! The filter-size-specific conv function is chosen by SWITCH_KERN_STR1
//! (usable() has already restricted FH to 2/3/5/7), then the work is
//! scheduled either one-group-per-thread (large-group) or as separate
//! padding/conv kernels (small-group).
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoF32DirectStride1::get_kimpls(
        const NCBKernSizeParam& param) const {
    auto fm = param.filter_meta;
    auto FH = fm.spatial[0];
    size_t N = param.n;
    size_t IC = param.filter_meta.icpg;
    size_t OC = param.filter_meta.ocpg;
    size_t group = fm.group;
    using Func = std::function<void(const float*, const float*, float*, size_t,
                                    size_t, size_t, size_t, size_t)>;
    Func conv_kern_function = nullptr;
//! select the specialized stride-1 kernel for the current filter size
#define SWITCH_KERN_STR1()                                                \
    switch (FH) {                                                         \
        case 2:                                                           \
            conv_kern_function = fp32::conv_stride1::do_conv_2x2_stride1; \
            break;                                                        \
        case 3:                                                           \
            conv_kern_function = fp32::conv_stride1::do_conv_3x3_stride1; \
            break;                                                        \
        case 5:                                                           \
            conv_kern_function = fp32::conv_stride1::do_conv_5x5_stride1; \
            break;                                                        \
        case 7:                                                           \
            conv_kern_function = fp32::conv_stride1::do_conv_7x7_stride1; \
            break;                                                        \
    }
    SWITCH_KERN_STR1();
    WorkspaceBundle wbundle =
            MultithreadDirectConvCommon<float, float>::get_bundle_stride(
                    param, m_large_group);
    SmallVector<NCBKern> ret_kerns;
    //! When group >= nr_threads, treat it as large_group, each thread process
    //! one group for better performance
    if (m_large_group) {
        //! Channel wise conv and big groups
        auto exec_one_group = [wbundle, conv_kern_function](
                                      const NCBKernParam& kern_param,
                                      const NCBKernIndex& ncb_index) {
            auto fm = kern_param.filter_meta;
            size_t IC = fm.icpg;
            size_t OC = fm.ocpg;
            //! per-thread copy of the workspace bundle
            WorkspaceBundle bundle = wbundle;
            for (size_t ic = 0; ic < IC; ic++) {
                MultithreadDirectConvCommon<float, float>::
                        copy_padding_kern_stride(bundle, kern_param, ncb_index,
                                                 {ncb_index.thread_id, 0, ic});
            }
            for (size_t oc = 0; oc < OC; oc++) {
                MultithreadDirectConvCommon<float, float>::do_conv_kern_stride(
                        bundle, kern_param, ncb_index, conv_kern_function,
                        {ncb_index.thread_id, 0, oc});
            }
        };
        ret_kerns.push_back({exec_one_group, {group, N, 1_z}});
    } else {
        WorkspaceBundle bundle = wbundle;
        auto copy_padding = [bundle](const NCBKernParam& kern_param,
                                     const NCBKernIndex& ncb_index) {
            MultithreadDirectConvCommon<float, float>::copy_padding_kern_stride(
                    bundle, kern_param, ncb_index, ncb_index.ndrange_id);
        };
        ret_kerns.push_back({copy_padding, {group, N, IC}});
        auto do_conv = [bundle, conv_kern_function](
                               const NCBKernParam& kern_param,
                               const NCBKernIndex& ncb_index) {
            MultithreadDirectConvCommon<float, float>::do_conv_kern_stride(
                    bundle, kern_param, ncb_index, conv_kern_function,
                    ncb_index.ndrange_id);
        };
        ret_kerns.push_back({do_conv, {group, N, OC}});
    }
    return ret_kerns;
}
//! Dispatch the precomputed kernel list for the stride-1 algorithm.
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoF32DirectStride1::dispatch_kerns(
        fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
    MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 1, 2) {
        auto kimpls = get_kimpls(param);
        return kimpls;
    }
    MIDOUT_END();
    return {};
}
/* ===================== stride-2 algo ===================== */ | |||||
/**
 * \brief Whether the stride-2 direct f32 algorithm supports \p param.
 *
 * Requires NCHW float32 tensors, a square 2/3/5/7 filter, stride 2, no
 * dilation and no filter flipping.
 *
 * (Fix: local variable renamed from the misspelled ``aviliable``.)
 */
bool ConvBiasImpl::AlgoF32DirectStride2::usable(
        fallback::ConvBiasImpl*, const NCBKernSizeParam& param,
        AlgoSelectionStrategy algo_selection_strategy) const {
    MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 2, 0) {
        auto&& fm = param.filter_meta;
        auto FH = fm.spatial[0];
        bool available =
                param.filter_meta.format == param::ConvBias::Format::NCHW &&
                param.src_type.enumv() == DTypeEnum::Float32 &&
                param.filter_type.enumv() == DTypeEnum::Float32 &&
                param.dst_type.enumv() == DTypeEnum::Float32 &&
                !fm.should_flip && fm.spatial_ndim == 2 &&
                fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
                fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] &&
                (FH == 2 || FH == 3 || FH == 5 || FH == 7);
        if (algo_selection_strategy ==
            ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) {
            //! heuristic mode only selects the variant matching the group size
            bool large_group = param.filter_meta.group >= param.nr_threads;
            available &= (large_group == m_large_group);
        }
        return available;
    }
    MIDOUT_END();
    return false;
}
//! Scratch-memory requirement of the stride-2 direct f32 algorithm.
size_t ConvBiasImpl::AlgoF32DirectStride2::get_workspace(
        fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
    MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 2, 1) {
        return MultithreadDirectConvCommon<float, float>::get_bundle_stride(
                       param, m_large_group)
                .total_size_in_bytes();
    }
    MIDOUT_END();
    return 0;
}
//! Build the NCB kernel list for the stride-2 direct f32 algorithm.
//! The filter-size-specific conv function is chosen by SWITCH_KERN_STR2
//! (usable() has already restricted FH to 2/3/5/7), then the work is
//! scheduled either one-group-per-thread (large-group) or as separate
//! padding/conv kernels (small-group).
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoF32DirectStride2::get_kimpls(
        const NCBKernSizeParam& param) const {
    auto fm = param.filter_meta;
    auto FH = fm.spatial[0];
    size_t N = param.n;
    size_t IC = param.filter_meta.icpg;
    size_t OC = param.filter_meta.ocpg;
    size_t group = fm.group;
    using Func = std::function<void(const float*, const float*, float*, size_t,
                                    size_t, size_t, size_t, size_t)>;
    Func conv_kern_function = nullptr;
//! select the specialized stride-2 kernel for the current filter size
#define SWITCH_KERN_STR2()                                                \
    switch (FH) {                                                         \
        case 2:                                                           \
            conv_kern_function = fp32::conv_stride2::do_conv_2x2_stride2; \
            break;                                                        \
        case 3:                                                           \
            conv_kern_function = fp32::conv_stride2::do_conv_3x3_stride2; \
            break;                                                        \
        case 5:                                                           \
            conv_kern_function = fp32::conv_stride2::do_conv_5x5_stride2; \
            break;                                                        \
        case 7:                                                           \
            conv_kern_function = fp32::conv_stride2::do_conv_7x7_stride2; \
            break;                                                        \
    }
    SWITCH_KERN_STR2();
    WorkspaceBundle wbundle =
            MultithreadDirectConvCommon<float, float>::get_bundle_stride(
                    param, m_large_group);
    SmallVector<NCBKern> ret_kerns;
    //! When group >= nr_threads, treat it as large_group, each thread process
    //! one group for better performance
    if (m_large_group) {
        //! Channel wise conv and big groups
        auto exec_one_group = [wbundle, conv_kern_function](
                                      const NCBKernParam& kern_param,
                                      const NCBKernIndex& ncb_index) {
            auto fm = kern_param.filter_meta;
            size_t IC = fm.icpg;
            size_t OC = fm.ocpg;
            //! per-thread copy of the workspace bundle
            WorkspaceBundle bundle = wbundle;
            for (size_t ic = 0; ic < IC; ic++) {
                MultithreadDirectConvCommon<float, float>::
                        copy_padding_kern_stride(bundle, kern_param, ncb_index,
                                                 {ncb_index.thread_id, 0, ic});
            }
            for (size_t oc = 0; oc < OC; oc++) {
                MultithreadDirectConvCommon<float, float>::do_conv_kern_stride(
                        bundle, kern_param, ncb_index, conv_kern_function,
                        {ncb_index.thread_id, 0, oc});
            }
        };
        ret_kerns.push_back({exec_one_group, {group, N, 1_z}});
    } else {
        WorkspaceBundle bundle = wbundle;
        auto copy_padding = [bundle](const NCBKernParam& kern_param,
                                     const NCBKernIndex& ncb_index) {
            MultithreadDirectConvCommon<float, float>::copy_padding_kern_stride(
                    bundle, kern_param, ncb_index, ncb_index.ndrange_id);
        };
        ret_kerns.push_back({copy_padding, {group, N, IC}});
        auto do_conv = [bundle, conv_kern_function](
                               const NCBKernParam& kern_param,
                               const NCBKernIndex& ncb_index) {
            MultithreadDirectConvCommon<float, float>::do_conv_kern_stride(
                    bundle, kern_param, ncb_index, conv_kern_function,
                    ncb_index.ndrange_id);
        };
        ret_kerns.push_back({do_conv, {group, N, OC}});
    }
    return ret_kerns;
}
//! Dispatch the precomputed kernel list for the stride-2 algorithm.
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoF32DirectStride2::dispatch_kerns(
        fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
    MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 2, 2) {
        auto kimpls = get_kimpls(param);
        return kimpls;
    }
    MIDOUT_END();
    return {};
}
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,223 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/conv_bias/fp32/algos.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#pragma once | |||||
#include "src/arm_common/conv_bias/opr_impl.h" | |||||
#include "src/fallback/matrix_mul/opr_impl.h" | |||||
namespace megdnn { | |||||
namespace arm_common { | |||||
//! Winograd F(2,3) float32 conv-bias with 4x4 channel blocking.
class ConvBiasImpl::AlgoFP32WinogradF23_4x4 final : public AlgoBase {
public:
    AlgoFP32WinogradF23_4x4(fallback::MatrixMulImpl::AlgoBase* matmul_algo,
                            uint32_t tile_size)
            : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {}
    bool is_reproducible() const override { return true; }
    //! lazily-built name embedding the matmul algo name and the winograd
    //! params {4, 2, tile} (presumably {channel_block, output_block, tile}
    //! -- confirm against ConvBiasImpl::algo_name)
    const char* name() const override {
        if (m_name.empty()) {
            m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
                    m_matmul_algo->name(), {4, 2, m_tile_size});
        }
        return m_name.c_str();
    }
    bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
                AlgoSelectionStrategy algo_selection_strategy) const override;
    size_t get_workspace(fallback::ConvBiasImpl*,
                         const NCBKernSizeParam& param) const override;
    virtual SmallVector<NCBKern> dispatch_kerns(
            fallback::ConvBiasImpl* opr,
            const NCBKernSizeParam& param) const override;
private:
    fallback::MatrixMulImpl::AlgoBase* m_matmul_algo;  //! backing matmul algo
    mutable std::string m_name;                        //! cached name buffer
    uint32_t m_tile_size;
};
//! Winograd F(6,3) float32 conv-bias without channel blocking.
class ConvBiasImpl::AlgoFP32WinogradF63 final : public AlgoBase {
public:
    AlgoFP32WinogradF63(fallback::MatrixMulImpl::AlgoBase* matmul_algo,
                        uint32_t tile_size)
            : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {}
    bool is_reproducible() const override { return true; }
    //! lazily-built name embedding the matmul algo name and the winograd
    //! params {1, 6, tile}
    const char* name() const override {
        if (m_name.empty()) {
            m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
                    m_matmul_algo->name(), {1, 6, m_tile_size});
        }
        return m_name.c_str();
    }
    bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
                AlgoSelectionStrategy algo_selection_strategy) const override;
    size_t get_workspace(fallback::ConvBiasImpl*,
                         const NCBKernSizeParam& param) const override;
    virtual SmallVector<NCBKern> dispatch_kerns(
            fallback::ConvBiasImpl* opr,
            const NCBKernSizeParam& param) const override;
private:
    fallback::MatrixMulImpl::AlgoBase* m_matmul_algo;  //! backing matmul algo
    mutable std::string m_name;                        //! cached name buffer
    uint32_t m_tile_size;
};
//! Winograd F(6,3) float32 conv-bias with 4x4 channel blocking (MK4).
class ConvBiasImpl::AlgoFP32WinogradF63_4x4 final : public AlgoBase {
public:
    AlgoFP32WinogradF63_4x4(fallback::MatrixMulImpl::AlgoBase* matmul_algo,
                            uint32_t tile_size)
            : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {}
    bool is_reproducible() const override { return true; }
    //! lazily-built name embedding the matmul algo name and the winograd
    //! params {4, 6, tile}
    const char* name() const override {
        if (m_name.empty()) {
            m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
                    m_matmul_algo->name(), {4, 6, m_tile_size});
        }
        return m_name.c_str();
    }
    bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
                AlgoSelectionStrategy algo_selection_strategy) const override;
    size_t get_workspace(fallback::ConvBiasImpl*,
                         const NCBKernSizeParam& param) const override;
    virtual SmallVector<NCBKern> dispatch_kerns(
            fallback::ConvBiasImpl* opr,
            const NCBKernSizeParam& param) const override;
private:
    fallback::MatrixMulImpl::AlgoBase* m_matmul_algo;  //! backing matmul algo
    mutable std::string m_name;                        //! cached name buffer
    uint32_t m_tile_size;
};
//! Winograd F(5,4) float32 conv-bias without channel blocking.
class ConvBiasImpl::AlgoFP32WinogradF54 final : public AlgoBase {
public:
    AlgoFP32WinogradF54(fallback::MatrixMulImpl::AlgoBase* matmul_algo,
                        uint32_t tile_size)
            : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {}
    bool is_reproducible() const override { return true; }
    //! lazily-built name embedding the matmul algo name and the winograd
    //! params {1, 5, tile}
    const char* name() const override {
        if (m_name.empty()) {
            m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
                    m_matmul_algo->name(), {1, 5, m_tile_size});
        }
        return m_name.c_str();
    }
    bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
                AlgoSelectionStrategy algo_selection_strategy) const override;
    size_t get_workspace(fallback::ConvBiasImpl*,
                         const NCBKernSizeParam& param) const override;
    virtual SmallVector<NCBKern> dispatch_kerns(
            fallback::ConvBiasImpl* opr,
            const NCBKernSizeParam& param) const override;
private:
    fallback::MatrixMulImpl::AlgoBase* m_matmul_algo;  //! backing matmul algo
    mutable std::string m_name;                        //! cached name buffer
    uint32_t m_tile_size;
};
//! Winograd F(4,5) float32 conv-bias without channel blocking.
class ConvBiasImpl::AlgoFP32WinogradF45 final : public AlgoBase {
public:
    AlgoFP32WinogradF45(fallback::MatrixMulImpl::AlgoBase* matmul_algo,
                        uint32_t tile_size)
            : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {}
    bool is_reproducible() const override { return true; }
    //! lazily-built name embedding the matmul algo name and the winograd
    //! params {1, 4, tile}
    const char* name() const override {
        if (m_name.empty()) {
            m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
                    m_matmul_algo->name(), {1, 4, m_tile_size});
        }
        return m_name.c_str();
    }
    bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
                AlgoSelectionStrategy algo_selection_strategy) const override;
    size_t get_workspace(fallback::ConvBiasImpl*,
                         const NCBKernSizeParam& param) const override;
    virtual SmallVector<NCBKern> dispatch_kerns(
            fallback::ConvBiasImpl* opr,
            const NCBKernSizeParam& param) const override;
private:
    fallback::MatrixMulImpl::AlgoBase* m_matmul_algo;  //! backing matmul algo
    mutable std::string m_name;                        //! cached name buffer
    uint32_t m_tile_size;
};
//! Direct float32 convolution; instantiated twice, once for large-group
//! workloads (group >= nr_threads) and once for small-group ones.
class ConvBiasImpl::AlgoF32Direct final : public AlgoBase {
    SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
    bool m_large_group;  //! which group-size variant this instance serves
public:
    AlgoF32Direct(bool is_large_group) : m_large_group{is_large_group} {}
    bool is_reproducible() const override { return true; }
    const char* name() const override {
        return m_large_group ? "F32DIRECT_LARGE_GROUP"
                             : "F32DIRECT_SMALL_GROUP";
    }
    bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
                AlgoSelectionStrategy algo_selection_strategy) const override;
    size_t get_workspace(fallback::ConvBiasImpl* opr,
                         const NCBKernSizeParam& param) const override;
    virtual SmallVector<NCBKern> dispatch_kerns(
            fallback::ConvBiasImpl* opr,
            const NCBKernSizeParam& param) const override;
};
//! Direct fp32 convolution specialized for stride 1; large/small group
//! variants are distinguished per instance.
class ConvBiasImpl::AlgoF32DirectStride1 final : public AlgoBase {
    //! builds the kernel list for a given problem size
    SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
    //! selects which of the two kernel variants this instance exposes
    bool m_large_group;
public:
    AlgoF32DirectStride1(bool is_large_group) : m_large_group{is_large_group} {}
    bool is_reproducible() const override { return true; }
    const char* name() const override {
        if (m_large_group) {
            return "F32STRD1_LARGE_GROUP";
        }
        return "F32STRD1_SMALL_GROUP";
    }
    bool usable(fallback::ConvBiasImpl*, const NCBKernSizeParam& param,
                AlgoSelectionStrategy algo_selection_strategy) const override;
    size_t get_workspace(fallback::ConvBiasImpl*,
                         const NCBKernSizeParam& param) const override;
    virtual SmallVector<NCBKern> dispatch_kerns(
            fallback::ConvBiasImpl* opr,
            const NCBKernSizeParam& param) const override;
};
//! Direct fp32 convolution specialized for stride 2; large/small group
//! variants are distinguished per instance.
class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase {
    //! builds the kernel list for a given problem size
    SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
    //! selects which of the two kernel variants this instance exposes
    bool m_large_group;
public:
    AlgoF32DirectStride2(bool is_large_group) : m_large_group{is_large_group} {}
    bool is_reproducible() const override { return true; }
    const char* name() const override {
        if (m_large_group) {
            return "F32STRD2_LARGE_GROUP";
        }
        return "F32STRD2_SMALL_GROUP";
    }
    bool usable(fallback::ConvBiasImpl*, const NCBKernSizeParam& param,
                AlgoSelectionStrategy algo_selection_strategy) const override;
    size_t get_workspace(fallback::ConvBiasImpl*,
                         const NCBKernSizeParam& param) const override;
    virtual SmallVector<NCBKern> dispatch_kerns(
            fallback::ConvBiasImpl* opr,
            const NCBKernSizeParam& param) const override;
};
} // namespace arm_common | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,911 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/conv_bias/fp32/direct.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include <cstring> | |||||
#include "include/megdnn/oprs.h" | |||||
#include "midout.h" | |||||
#include "src/arm_common/conv_bias/fp32/direct.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/common/utils.h" | |||||
#include "src/arm_common/conv_bias/postprocess_helper.h" | |||||
#include "src/common/unroll_macro.h" | |||||
MIDOUT_DECL(megdnn_arm_conv_f32) | |||||
using namespace megdnn; | |||||
using namespace arm_common; | |||||
using namespace fp32; | |||||
using namespace conv_bias; | |||||
namespace { | |||||
//! Primary template: computes one output tile of `height` rows by
//! `width` columns for a filter with FH rows.  Only the explicit
//! specializations for FH = 1..7 below are ever instantiated.
template <int FH, int height, int width>
struct do_pixel_proxy {
    static void exec(const float* src, const float* filter, float* dst,
                     const int IH, const int IW, const int OH, const int OW,
                     const int FW, const int oh, const int ow);
};
//! Load one lane of `data` from dst[i]; expanded via
//! UNROLL_CALL_NOWRAPPER to read partial (width < 4) rows lane by lane.
#define cb_load(i) data = vld1q_lane_f32(dst + i, data, i);
//! Load the current dst tile (up to 4 rows of 4 floats) into out0..out3
//! so the kernel can accumulate on top of existing output.  `height`
//! and `width` are compile-time template constants, so all branches
//! fold away; rows narrower than 4 floats are loaded lane by lane.
#define LOAD_OUT                                                  \
    if (width < 4) {                                              \
        auto load_less_4 = [](float* dst, float32x4_t& data) {    \
            if (width == 1u) {                                    \
                UNROLL_CALL_NOWRAPPER(1, cb_load);                \
            } else if (width == 2u) {                             \
                UNROLL_CALL_NOWRAPPER(2, cb_load);                \
            } else if (width == 3u) {                             \
                UNROLL_CALL_NOWRAPPER(3, cb_load);                \
            }                                                     \
        };                                                        \
        if (height >= 1)                                          \
            load_less_4(dst + 0 * OW, out0);                      \
        if (height >= 2)                                          \
            load_less_4(dst + 1 * OW, out1);                      \
        if (height >= 3)                                          \
            load_less_4(dst + 2 * OW, out2);                      \
        if (height >= 4)                                          \
            load_less_4(dst + 3 * OW, out3);                      \
    } else {                                                      \
        if (height > 0)                                           \
            out0 = vld1q_f32(dst + 0 * OW);                       \
        if (height > 1)                                           \
            out1 = vld1q_f32(dst + 1 * OW);                       \
        if (height > 2)                                           \
            out2 = vld1q_f32(dst + 2 * OW);                       \
        if (height > 3)                                           \
            out3 = vld1q_f32(dst + 3 * OW);                       \
    }
//! Store one lane of `data` to dst[i]; expanded via
//! UNROLL_CALL_NOWRAPPER to write partial (width < 4) rows lane by lane.
#define cb_store(i) vst1q_lane_f32(dst + i, data, i);
//! Write the accumulated out0..out3 tile back to dst; mirror of
//! LOAD_OUT (full 4-float rows use vst1q_f32, narrower rows store
//! lane by lane so bytes past the row end are never touched).
#define STORE_OUT                                                 \
    if (width < 4) {                                              \
        auto store_less_4 = [](float* dst, float32x4_t& data) {   \
            if (width == 1u) {                                    \
                UNROLL_CALL_NOWRAPPER(1, cb_store);               \
            } else if (width == 2u) {                             \
                UNROLL_CALL_NOWRAPPER(2, cb_store);               \
            } else if (width == 3u) {                             \
                UNROLL_CALL_NOWRAPPER(3, cb_store);               \
            }                                                     \
        };                                                        \
        if (height >= 1)                                          \
            store_less_4(dst + 0 * OW, out0);                     \
        if (height >= 2)                                          \
            store_less_4(dst + 1 * OW, out1);                     \
        if (height >= 3)                                          \
            store_less_4(dst + 2 * OW, out2);                     \
        if (height >= 4)                                          \
            store_less_4(dst + 3 * OW, out3);                     \
    } else {                                                      \
        if (height >= 1)                                          \
            vst1q_f32(dst + 0 * OW, out0);                        \
        if (height >= 2)                                          \
            vst1q_f32(dst + 1 * OW, out1);                        \
        if (height >= 3)                                          \
            vst1q_f32(dst + 2 * OW, out2);                        \
        if (height >= 4)                                          \
            vst1q_f32(dst + 3 * OW, out3);                        \
    }
//! FH == 1 specialization: with a single filter row, each output row
//! depends on exactly one source row, so every row is an independent
//! load + multiply-accumulate pair (no register reuse across rows).
template <int height, int width>
struct do_pixel_proxy<1, height, width> {
    static void exec(const float* src, const float* filter, float* dst,
                     const int IH, const int IW, const int OH, const int OW,
                     const int FW, const int oh, const int ow) {
        (void)IH;
        (void)OH;
        //! stride is 1, so the tile origin in src matches the one in dst
        const int ih = oh, iw = ow;
        float32x4_t out0{0}, out1{0}, out2{0}, out3{0};
        src += ih * IW + iw;
        dst += oh * OW + ow;
        LOAD_OUT;
        for (int fw = 0; fw < FW; ++fw) {
            const float* cur_src = src + fw;
            const float32x4_t wgt = vdupq_n_f32(filter[0 * FW + fw]);
            if (height > 0) {
                float32x4_t in = vld1q_f32(cur_src + 0 * IW);
                out0 = vmlaq_f32(out0, in, wgt);
            }
            if (height > 1) {
                float32x4_t in = vld1q_f32(cur_src + 1 * IW);
                out1 = vmlaq_f32(out1, in, wgt);
            }
            if (height > 2) {
                float32x4_t in = vld1q_f32(cur_src + 2 * IW);
                out2 = vmlaq_f32(out2, in, wgt);
            }
            if (height > 3) {
                float32x4_t in = vld1q_f32(cur_src + 3 * IW);
                out3 = vmlaq_f32(out3, in, wgt);
            }
        }
        STORE_OUT;
    }
};
//! FH == 2 specialization.  Inside the fw loop each source row loaded
//! into `inp` is reused for every (output row, filter row) pair that
//! reads it, so the load/accumulate statements are deliberately
//! interleaved; their order must not be changed.  Guards use the
//! smallest `height` that needs the given source row.
template <int height, int width>
struct do_pixel_proxy<2, height, width> {
    static void exec(const float* src, const float* filter, float* dst,
                     const int IH, const int IW, const int OH, const int OW,
                     const int FW, const int oh, const int ow) {
        (void)IH;
        (void)OH;
        const int ih = oh, iw = ow;
        // out0..out3 accumulate output rows oh..oh+3; LOAD_OUT/STORE_OUT
        // access them by these exact names.
        float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, inp;
        src += ih * IW + iw;
        dst += oh * OW + ow;
        LOAD_OUT;
        for (int fw = 0; fw < FW; ++fw) {
            const float* src_dd = src + fw;
            kr0 = vdupq_n_f32(filter[0 * FW + fw]);
            kr1 = vdupq_n_f32(filter[1 * FW + fw]);
            if (height > 0)
                inp = vld1q_f32(src_dd + 0 * IW);
            if (height > 0)
                out0 = vmlaq_f32(out0, inp, kr0);
            if (height > 0)
                inp = vld1q_f32(src_dd + 1 * IW);
            if (height > 0)
                out0 = vmlaq_f32(out0, inp, kr1);
            if (height > 1)
                out1 = vmlaq_f32(out1, inp, kr0);
            if (height > 1)
                inp = vld1q_f32(src_dd + 2 * IW);
            if (height > 1)
                out1 = vmlaq_f32(out1, inp, kr1);
            if (height > 2)
                out2 = vmlaq_f32(out2, inp, kr0);
            if (height > 2)
                inp = vld1q_f32(src_dd + 3 * IW);
            if (height > 2)
                out2 = vmlaq_f32(out2, inp, kr1);
            if (height > 3)
                out3 = vmlaq_f32(out3, inp, kr0);
            if (height > 3)
                inp = vld1q_f32(src_dd + 4 * IW);
            if (height > 3)
                out3 = vmlaq_f32(out3, inp, kr1);
        }
        STORE_OUT;
    }
};
//! FH == 3 specialization.  Same interleaving scheme as FH == 2: each
//! `inp` load is shared by all accumulations that read that source
//! row, so statement order is significant and must be preserved.
template <int height, int width>
struct do_pixel_proxy<3, height, width> {
    static void exec(const float* src, const float* filter, float* dst,
                     const int IH, const int IW, const int OH, const int OW,
                     const int FW, const int oh, const int ow) {
        (void)IH;
        (void)OH;
        const int ih = oh, iw = ow;
        float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, inp;
        src += ih * IW + iw;
        dst += oh * OW + ow;
        LOAD_OUT;
        for (int fw = 0; fw < FW; ++fw) {
            const float* src_dd = src + fw;
            kr0 = vdupq_n_f32(filter[0 * FW + fw]);
            kr1 = vdupq_n_f32(filter[1 * FW + fw]);
            kr2 = vdupq_n_f32(filter[2 * FW + fw]);
            if (height > 0)
                inp = vld1q_f32(src_dd + 0 * IW);
            if (height > 0)
                out0 = vmlaq_f32(out0, inp, kr0);
            if (height > 0)
                inp = vld1q_f32(src_dd + 1 * IW);
            if (height > 0)
                out0 = vmlaq_f32(out0, inp, kr1);
            if (height > 1)
                out1 = vmlaq_f32(out1, inp, kr0);
            if (height > 0)
                inp = vld1q_f32(src_dd + 2 * IW);
            if (height > 0)
                out0 = vmlaq_f32(out0, inp, kr2);
            if (height > 1)
                out1 = vmlaq_f32(out1, inp, kr1);
            if (height > 2)
                out2 = vmlaq_f32(out2, inp, kr0);
            if (height > 1)
                inp = vld1q_f32(src_dd + 3 * IW);
            if (height > 1)
                out1 = vmlaq_f32(out1, inp, kr2);
            if (height > 2)
                out2 = vmlaq_f32(out2, inp, kr1);
            if (height > 3)
                out3 = vmlaq_f32(out3, inp, kr0);
            if (height > 2)
                inp = vld1q_f32(src_dd + 4 * IW);
            if (height > 2)
                out2 = vmlaq_f32(out2, inp, kr2);
            if (height > 3)
                out3 = vmlaq_f32(out3, inp, kr1);
            if (height > 3)
                inp = vld1q_f32(src_dd + 5 * IW);
            if (height > 3)
                out3 = vmlaq_f32(out3, inp, kr2);
        }
        STORE_OUT;
    }
};
//! FH == 4 specialization.  Interleaved single-register (`inp`) reuse
//! across output rows, as in the smaller-FH specializations; statement
//! order is significant and must be preserved.
template <int height, int width>
struct do_pixel_proxy<4, height, width> {
    static void exec(const float* src, const float* filter, float* dst,
                     const int IH, const int IW, const int OH, const int OW,
                     const int FW, const int oh, const int ow) {
        (void)IH;
        (void)OH;
        const int ih = oh, iw = ow;
        float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, inp;
        src += ih * IW + iw;
        dst += oh * OW + ow;
        LOAD_OUT;
        for (int fw = 0; fw < FW; ++fw) {
            const float* src_dd = src + fw;
            kr0 = vdupq_n_f32(filter[0 * FW + fw]);
            kr1 = vdupq_n_f32(filter[1 * FW + fw]);
            kr2 = vdupq_n_f32(filter[2 * FW + fw]);
            kr3 = vdupq_n_f32(filter[3 * FW + fw]);
            if (height > 0)
                inp = vld1q_f32(src_dd + 0 * IW);
            if (height > 0)
                out0 = vmlaq_f32(out0, inp, kr0);
            if (height > 0)
                inp = vld1q_f32(src_dd + 1 * IW);
            if (height > 0)
                out0 = vmlaq_f32(out0, inp, kr1);
            if (height > 1)
                out1 = vmlaq_f32(out1, inp, kr0);
            if (height > 0)
                inp = vld1q_f32(src_dd + 2 * IW);
            if (height > 0)
                out0 = vmlaq_f32(out0, inp, kr2);
            if (height > 1)
                out1 = vmlaq_f32(out1, inp, kr1);
            if (height > 2)
                out2 = vmlaq_f32(out2, inp, kr0);
            if (height > 0)
                inp = vld1q_f32(src_dd + 3 * IW);
            if (height > 0)
                out0 = vmlaq_f32(out0, inp, kr3);
            if (height > 1)
                out1 = vmlaq_f32(out1, inp, kr2);
            if (height > 2)
                out2 = vmlaq_f32(out2, inp, kr1);
            if (height > 3)
                out3 = vmlaq_f32(out3, inp, kr0);
            if (height > 1)
                inp = vld1q_f32(src_dd + 4 * IW);
            if (height > 1)
                out1 = vmlaq_f32(out1, inp, kr3);
            if (height > 2)
                out2 = vmlaq_f32(out2, inp, kr2);
            if (height > 3)
                out3 = vmlaq_f32(out3, inp, kr1);
            if (height > 2)
                inp = vld1q_f32(src_dd + 5 * IW);
            if (height > 2)
                out2 = vmlaq_f32(out2, inp, kr3);
            if (height > 3)
                out3 = vmlaq_f32(out3, inp, kr2);
            if (height > 3)
                inp = vld1q_f32(src_dd + 6 * IW);
            if (height > 3)
                out3 = vmlaq_f32(out3, inp, kr3);
        }
        STORE_OUT;
    }
};
//! FH == 5 specialization.  Interleaved single-register (`inp`) reuse
//! across output rows; statement order is significant and must be
//! preserved.
template <int height, int width>
struct do_pixel_proxy<5, height, width> {
    static void exec(const float* src, const float* filter, float* dst,
                     const int IH, const int IW, const int OH, const int OW,
                     const int FW, const int oh, const int ow) {
        (void)IH;
        (void)OH;
        const int ih = oh, iw = ow;
        float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, kr4,
                inp;
        src += ih * IW + iw;
        dst += oh * OW + ow;
        LOAD_OUT;
        for (int fw = 0; fw < FW; ++fw) {
            const float* src_dd = src + fw;
            kr0 = vdupq_n_f32(filter[0 * FW + fw]);
            kr1 = vdupq_n_f32(filter[1 * FW + fw]);
            kr2 = vdupq_n_f32(filter[2 * FW + fw]);
            kr3 = vdupq_n_f32(filter[3 * FW + fw]);
            kr4 = vdupq_n_f32(filter[4 * FW + fw]);
            if (height > 0)
                inp = vld1q_f32(src_dd + 0 * IW);
            if (height > 0)
                out0 = vmlaq_f32(out0, inp, kr0);
            if (height > 0)
                inp = vld1q_f32(src_dd + 1 * IW);
            if (height > 0)
                out0 = vmlaq_f32(out0, inp, kr1);
            if (height > 1)
                out1 = vmlaq_f32(out1, inp, kr0);
            if (height > 0)
                inp = vld1q_f32(src_dd + 2 * IW);
            if (height > 0)
                out0 = vmlaq_f32(out0, inp, kr2);
            if (height > 1)
                out1 = vmlaq_f32(out1, inp, kr1);
            if (height > 2)
                out2 = vmlaq_f32(out2, inp, kr0);
            if (height > 0)
                inp = vld1q_f32(src_dd + 3 * IW);
            if (height > 0)
                out0 = vmlaq_f32(out0, inp, kr3);
            if (height > 1)
                out1 = vmlaq_f32(out1, inp, kr2);
            if (height > 2)
                out2 = vmlaq_f32(out2, inp, kr1);
            if (height > 3)
                out3 = vmlaq_f32(out3, inp, kr0);
            if (height > 0)
                inp = vld1q_f32(src_dd + 4 * IW);
            if (height > 0)
                out0 = vmlaq_f32(out0, inp, kr4);
            if (height > 1)
                out1 = vmlaq_f32(out1, inp, kr3);
            if (height > 2)
                out2 = vmlaq_f32(out2, inp, kr2);
            if (height > 3)
                out3 = vmlaq_f32(out3, inp, kr1);
            if (height > 1)
                inp = vld1q_f32(src_dd + 5 * IW);
            if (height > 1)
                out1 = vmlaq_f32(out1, inp, kr4);
            if (height > 2)
                out2 = vmlaq_f32(out2, inp, kr3);
            if (height > 3)
                out3 = vmlaq_f32(out3, inp, kr2);
            if (height > 2)
                inp = vld1q_f32(src_dd + 6 * IW);
            if (height > 2)
                out2 = vmlaq_f32(out2, inp, kr4);
            if (height > 3)
                out3 = vmlaq_f32(out3, inp, kr3);
            if (height > 3)
                inp = vld1q_f32(src_dd + 7 * IW);
            if (height > 3)
                out3 = vmlaq_f32(out3, inp, kr4);
        }
        STORE_OUT;
    }
};
//! FH == 6 specialization.  Interleaved single-register (`inp`) reuse
//! across output rows; statement order is significant and must be
//! preserved.
template <int height, int width>
struct do_pixel_proxy<6, height, width> {
    static void exec(const float* src, const float* filter, float* dst,
                     const int IH, const int IW, const int OH, const int OW,
                     const int FW, const int oh, const int ow) {
        (void)IH;
        (void)OH;
        const int ih = oh, iw = ow;
        float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, kr4,
                kr5, inp;
        src += ih * IW + iw;
        dst += oh * OW + ow;
        LOAD_OUT;
        for (int fw = 0; fw < FW; ++fw) {
            const float* src_dd = src + fw;
            kr0 = vdupq_n_f32(filter[0 * FW + fw]);
            kr1 = vdupq_n_f32(filter[1 * FW + fw]);
            kr2 = vdupq_n_f32(filter[2 * FW + fw]);
            kr3 = vdupq_n_f32(filter[3 * FW + fw]);
            kr4 = vdupq_n_f32(filter[4 * FW + fw]);
            kr5 = vdupq_n_f32(filter[5 * FW + fw]);
            if (height > 0)
                inp = vld1q_f32(src_dd + 0 * IW);
            if (height > 0)
                out0 = vmlaq_f32(out0, inp, kr0);
            if (height > 0)
                inp = vld1q_f32(src_dd + 1 * IW);
            if (height > 0)
                out0 = vmlaq_f32(out0, inp, kr1);
            if (height > 1)
                out1 = vmlaq_f32(out1, inp, kr0);
            if (height > 0)
                inp = vld1q_f32(src_dd + 2 * IW);
            if (height > 0)
                out0 = vmlaq_f32(out0, inp, kr2);
            if (height > 1)
                out1 = vmlaq_f32(out1, inp, kr1);
            if (height > 2)
                out2 = vmlaq_f32(out2, inp, kr0);
            if (height > 0)
                inp = vld1q_f32(src_dd + 3 * IW);
            if (height > 0)
                out0 = vmlaq_f32(out0, inp, kr3);
            if (height > 1)
                out1 = vmlaq_f32(out1, inp, kr2);
            if (height > 2)
                out2 = vmlaq_f32(out2, inp, kr1);
            if (height > 3)
                out3 = vmlaq_f32(out3, inp, kr0);
            if (height > 0)
                inp = vld1q_f32(src_dd + 4 * IW);
            if (height > 0)
                out0 = vmlaq_f32(out0, inp, kr4);
            if (height > 1)
                out1 = vmlaq_f32(out1, inp, kr3);
            if (height > 2)
                out2 = vmlaq_f32(out2, inp, kr2);
            if (height > 3)
                out3 = vmlaq_f32(out3, inp, kr1);
            if (height > 0)
                inp = vld1q_f32(src_dd + 5 * IW);
            if (height > 0)
                out0 = vmlaq_f32(out0, inp, kr5);
            if (height > 1)
                out1 = vmlaq_f32(out1, inp, kr4);
            if (height > 2)
                out2 = vmlaq_f32(out2, inp, kr3);
            if (height > 3)
                out3 = vmlaq_f32(out3, inp, kr2);
            if (height > 1)
                inp = vld1q_f32(src_dd + 6 * IW);
            if (height > 1)
                out1 = vmlaq_f32(out1, inp, kr5);
            if (height > 2)
                out2 = vmlaq_f32(out2, inp, kr4);
            if (height > 3)
                out3 = vmlaq_f32(out3, inp, kr3);
            if (height > 2)
                inp = vld1q_f32(src_dd + 7 * IW);
            if (height > 2)
                out2 = vmlaq_f32(out2, inp, kr5);
            if (height > 3)
                out3 = vmlaq_f32(out3, inp, kr4);
            if (height > 3)
                inp = vld1q_f32(src_dd + 8 * IW);
            if (height > 3)
                out3 = vmlaq_f32(out3, inp, kr5);
        }
        STORE_OUT;
    }
};
//! FH == 7 specialization (largest supported filter height).
//! Interleaved single-register (`inp`) reuse across output rows;
//! statement order is significant and must be preserved.
template <int height, int width>
struct do_pixel_proxy<7, height, width> {
    static void exec(const float* src, const float* filter, float* dst,
                     const int IH, const int IW, const int OH, const int OW,
                     const int FW, const int oh, const int ow) {
        (void)IH;
        (void)OH;
        const int ih = oh, iw = ow;
        float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, kr4,
                kr5, kr6, inp;
        src += ih * IW + iw;
        dst += oh * OW + ow;
        LOAD_OUT;
        for (int fw = 0; fw < FW; ++fw) {
            const float* src_dd = src + fw;
            kr0 = vdupq_n_f32(filter[0 * FW + fw]);
            kr1 = vdupq_n_f32(filter[1 * FW + fw]);
            kr2 = vdupq_n_f32(filter[2 * FW + fw]);
            kr3 = vdupq_n_f32(filter[3 * FW + fw]);
            kr4 = vdupq_n_f32(filter[4 * FW + fw]);
            kr5 = vdupq_n_f32(filter[5 * FW + fw]);
            kr6 = vdupq_n_f32(filter[6 * FW + fw]);
            if (height > 0)
                inp = vld1q_f32(src_dd + 0 * IW);
            if (height > 0)
                out0 = vmlaq_f32(out0, inp, kr0);
            if (height > 0)
                inp = vld1q_f32(src_dd + 1 * IW);
            if (height > 0)
                out0 = vmlaq_f32(out0, inp, kr1);
            if (height > 1)
                out1 = vmlaq_f32(out1, inp, kr0);
            if (height > 0)
                inp = vld1q_f32(src_dd + 2 * IW);
            if (height > 0)
                out0 = vmlaq_f32(out0, inp, kr2);
            if (height > 1)
                out1 = vmlaq_f32(out1, inp, kr1);
            if (height > 2)
                out2 = vmlaq_f32(out2, inp, kr0);
            if (height > 0)
                inp = vld1q_f32(src_dd + 3 * IW);
            if (height > 0)
                out0 = vmlaq_f32(out0, inp, kr3);
            if (height > 1)
                out1 = vmlaq_f32(out1, inp, kr2);
            if (height > 2)
                out2 = vmlaq_f32(out2, inp, kr1);
            if (height > 3)
                out3 = vmlaq_f32(out3, inp, kr0);
            if (height > 0)
                inp = vld1q_f32(src_dd + 4 * IW);
            if (height > 0)
                out0 = vmlaq_f32(out0, inp, kr4);
            if (height > 1)
                out1 = vmlaq_f32(out1, inp, kr3);
            if (height > 2)
                out2 = vmlaq_f32(out2, inp, kr2);
            if (height > 3)
                out3 = vmlaq_f32(out3, inp, kr1);
            if (height > 0)
                inp = vld1q_f32(src_dd + 5 * IW);
            if (height > 0)
                out0 = vmlaq_f32(out0, inp, kr5);
            if (height > 1)
                out1 = vmlaq_f32(out1, inp, kr4);
            if (height > 2)
                out2 = vmlaq_f32(out2, inp, kr3);
            if (height > 3)
                out3 = vmlaq_f32(out3, inp, kr2);
            if (height > 0)
                inp = vld1q_f32(src_dd + 6 * IW);
            if (height > 0)
                out0 = vmlaq_f32(out0, inp, kr6);
            if (height > 1)
                out1 = vmlaq_f32(out1, inp, kr5);
            if (height > 2)
                out2 = vmlaq_f32(out2, inp, kr4);
            if (height > 3)
                out3 = vmlaq_f32(out3, inp, kr3);
            if (height > 1)
                inp = vld1q_f32(src_dd + 7 * IW);
            if (height > 1)
                out1 = vmlaq_f32(out1, inp, kr6);
            if (height > 2)
                out2 = vmlaq_f32(out2, inp, kr5);
            if (height > 3)
                out3 = vmlaq_f32(out3, inp, kr4);
            if (height > 2)
                inp = vld1q_f32(src_dd + 8 * IW);
            if (height > 2)
                out2 = vmlaq_f32(out2, inp, kr6);
            if (height > 3)
                out3 = vmlaq_f32(out3, inp, kr5);
            if (height > 3)
                inp = vld1q_f32(src_dd + 9 * IW);
            if (height > 3)
                out3 = vmlaq_f32(out3, inp, kr6);
        }
        STORE_OUT;
    }
};
// Clean up the tile-load/store helper macros.
// Fix: the second line was a duplicated `#undef cb_load` typo, which left
// `cb_store` defined and leaking into the rest of the translation unit.
#undef cb_load
#undef cb_store
#undef LOAD_OUT
#undef STORE_OUT
template <int FH, int height, int width> | |||||
void do_pixel(const float* src, const float* filter, float* dst, const int IH, | |||||
const int IW, const int OH, const int OW, const int FW, | |||||
const int oh, const int ow) { | |||||
do_pixel_proxy<FH, height, width>::exec(src, filter, dst, IH, IW, OH, OW, | |||||
FW, oh, ow); | |||||
} | |||||
//! Full-image convolution driver with software prefetching: walks the
//! output in 4x4 tiles (remainders handled by the DISPATCH* macros),
//! issuing __builtin_prefetch for the src/dst lines of the *next* tile
//! before computing the current one via do_pixel.
template <int FH>
void do_conv_tpl_enable_prefetch(const float* src, const float* filter,
                                 float* dst, const int IH, const int IW,
                                 const int OH, const int OW, const int FW) {
    const int hbeg = 0, hend = OH;
    const int wbeg = 0, wend = OW;
    int i, j;
    for (i = hbeg; i + 4 <= hend; i += 4) {
        for (j = wbeg; j + 4 <= wend; j += 4) {
            // do prefetch
            // Prefetch target: 16 columns ahead in the same tile row, or
            // wrap to the next 4-row band when that runs past wend.
            const int prefetch_index_input =
                    (j + 16) < wend
                            ? i * IW + j + 16
                            : (i + 4) * IW + (((j + 16 - wend) >> 2) << 2);
            const int prefetch_index_output =
                    (j + 16) < wend
                            ? i * OW + j + 16
                            : (i + 4) * OW + (((j + 16 - wend) >> 2) << 2);
            const float* src_prefetch = src + prefetch_index_input;
            const float* dst_prefetch = dst + prefetch_index_output;
            // FH + 3 source rows feed a 4-row output tile.
            for (int iw_id = 0; iw_id < FH + 3; ++iw_id) {
                __builtin_prefetch(src_prefetch + iw_id * IW, 0, 3);
            }
            __builtin_prefetch(dst_prefetch + 0 * OW, 1, 3);
            __builtin_prefetch(dst_prefetch + 1 * OW, 1, 3);
            __builtin_prefetch(dst_prefetch + 2 * OW, 1, 3);
            __builtin_prefetch(dst_prefetch + 3 * OW, 1, 3);
            do_pixel<FH, 4, 4>(src, filter, dst, IH, IW, OH, OW, FW, i, j);
        }
// Remainder columns (1..3 wide) of a full-height (4-row) band.
#define DISPATCH(width) \
    do { \
        const int prefetch_index_input = (i + 4) * IW + 12; \
        const int prefetch_index_output = (i + 4) * OW + 12; \
        const float* src_prefetch = src + prefetch_index_input; \
        const float* dst_prefetch = dst + prefetch_index_output; \
        for (int iw_id = 0; iw_id < FH + 3; ++iw_id) { \
            __builtin_prefetch(src_prefetch + iw_id * IW, 0, 3); \
        } \
        __builtin_prefetch(dst_prefetch + 0 * OW, 1, 3); \
        __builtin_prefetch(dst_prefetch + 1 * OW, 1, 3); \
        __builtin_prefetch(dst_prefetch + 2 * OW, 1, 3); \
        __builtin_prefetch(dst_prefetch + 3 * OW, 1, 3); \
        do_pixel<FH, 4, width>(src, filter, dst, IH, IW, OH, OW, FW, i, j); \
    } while (0)
        switch (wend - j) {
            case 1:
                DISPATCH(1);
                break;
            case 2:
                DISPATCH(2);
                break;
            case 3:
                DISPATCH(3);
                break;
        }
#undef DISPATCH
    }
// Final partial-height tile at a remainder column position.
#define DISPATCH2(height, width) \
    do { \
        const int prefetch_index_input = IH * IW + 12; \
        const float* src_prefetch = src + prefetch_index_input; \
        for (int iw_id = 0; iw_id < FH + 3; ++iw_id) { \
            __builtin_prefetch(src_prefetch + iw_id * IW, 0, 3); \
        } \
        do_pixel<FH, height, width>(src, filter, dst, IH, IW, OH, OW, FW, i, \
                                    j); \
    } while (0)
// Remainder rows (1..3 high): sweep full-width tiles, then the corner.
#define DISPATCH1(height) \
    do { \
        for (j = wbeg; j + 4 <= wend; j += 4) { \
            const int prefetch_index_input = \
                    (j + 16) < wend \
                            ? i * IW + j + 16 \
                            : (i + 4) * IW + (((j + 16 - wend) >> 2) << 2); \
            const int prefetch_index_output = \
                    (j + 16) < wend \
                            ? i * OW + j + 16 \
                            : (i + 4) * OW + (((j + 16 - wend) >> 2) << 2); \
            const float* src_prefetch = src + prefetch_index_input; \
            const float* dst_prefetch = dst + prefetch_index_output; \
            for (int iw_id = 0; iw_id < FH + 3; ++iw_id) { \
                __builtin_prefetch(src_prefetch + iw_id * IW, 0, 3); \
            } \
            __builtin_prefetch(dst_prefetch + 0 * OW, 1, 3); \
            __builtin_prefetch(dst_prefetch + 1 * OW, 1, 3); \
            __builtin_prefetch(dst_prefetch + 2 * OW, 1, 3); \
            __builtin_prefetch(dst_prefetch + 3 * OW, 1, 3); \
            do_pixel<FH, height, 4>(src, filter, dst, IH, IW, OH, OW, FW, i, \
                                    j); \
        } \
        switch (wend - j) { \
            case 1: \
                DISPATCH2(height, 1); \
                break; \
            case 2: \
                DISPATCH2(height, 2); \
                break; \
            case 3: \
                DISPATCH2(height, 3); \
                break; \
        } \
    } while (0)
    switch (hend - i) {
        case 1:
            DISPATCH1(1);
            break;
        case 2:
            DISPATCH1(2);
            break;
        case 3:
            DISPATCH1(3);
            break;
    }
#undef DISPATCH1
#undef DISPATCH2
}
//! Full-image convolution driver without prefetching: same 4x4 tiling
//! and remainder handling as the prefetch variant, minus the
//! __builtin_prefetch calls (used for small feature maps).
template <int FH>
void do_conv_tpl_disable_prefetch(const float* src, const float* filter,
                                  float* dst, const int IH, const int IW,
                                  const int OH, const int OW, const int FW) {
    const int hbeg = 0, hend = OH;
    const int wbeg = 0, wend = OW;
    int i, j;
    for (i = hbeg; i + 4 <= hend; i += 4) {
        for (j = wbeg; j + 4 <= wend; j += 4) {
            do_pixel<FH, 4, 4>(src, filter, dst, IH, IW, OH, OW, FW, i, j);
        }
// Remainder columns (1..3 wide) of a full-height (4-row) band.
#define DISPATCH(width) \
    do { \
        do_pixel<FH, 4, width>(src, filter, dst, IH, IW, OH, OW, FW, i, j); \
    } while (0)
        switch (wend - j) {
            case 1:
                DISPATCH(1);
                break;
            case 2:
                DISPATCH(2);
                break;
            case 3:
                DISPATCH(3);
                break;
        }
#undef DISPATCH
    }
// Final partial-height tile at a remainder column position.
#define DISPATCH2(height, width) \
    do { \
        do_pixel<FH, height, width>(src, filter, dst, IH, IW, OH, OW, FW, i, \
                                    j); \
    } while (0)
// Remainder rows (1..3 high): sweep full-width tiles, then the corner.
#define DISPATCH1(height) \
    do { \
        for (j = wbeg; j + 4 <= wend; j += 4) { \
            do_pixel<FH, height, 4>(src, filter, dst, IH, IW, OH, OW, FW, i, \
                                    j); \
        } \
        switch (wend - j) { \
            case 1: \
                DISPATCH2(height, 1); \
                break; \
            case 2: \
                DISPATCH2(height, 2); \
                break; \
            case 3: \
                DISPATCH2(height, 3); \
                break; \
        } \
    } while (0)
    switch (hend - i) {
        case 1:
            DISPATCH1(1);
            break;
        case 2:
            DISPATCH1(2);
            break;
        case 3:
            DISPATCH1(3);
            break;
    }
#undef DISPATCH1
#undef DISPATCH2
}
} // anonymous namespace | |||||
void conv_bias::kern_direct(const float* src, const float* filter, float* dst, | |||||
const int IH, const int IW, const int OH, | |||||
const int OW, const int FH, const int FW) { | |||||
megdnn_assert_internal(FH <= 7); | |||||
if (IH > 100 && IW > 100) { | |||||
#define GAO(FH) \ | |||||
do { \ | |||||
return do_conv_tpl_enable_prefetch<FH>(src, filter, dst, IH, IW, OH, \ | |||||
OW, FW); \ | |||||
} while (0) | |||||
switch (FH) { | |||||
case 1: | |||||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(0)) { GAO(1); } | |||||
MIDOUT_END(); | |||||
break; | |||||
case 2: | |||||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(1)) { GAO(2); } | |||||
MIDOUT_END(); | |||||
break; | |||||
case 3: | |||||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(2)) { GAO(3); } | |||||
MIDOUT_END(); | |||||
break; | |||||
case 4: | |||||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(3)) { GAO(4); } | |||||
MIDOUT_END(); | |||||
break; | |||||
case 5: | |||||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(4)) { GAO(5); } | |||||
MIDOUT_END(); | |||||
break; | |||||
case 6: | |||||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(5)) { GAO(6); } | |||||
MIDOUT_END(); | |||||
break; | |||||
case 7: | |||||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(6)) { GAO(7); } | |||||
MIDOUT_END(); | |||||
break; | |||||
} | |||||
#undef GAO | |||||
} else { | |||||
#define GAO(FH) \ | |||||
do { \ | |||||
return do_conv_tpl_disable_prefetch<FH>(src, filter, dst, IH, IW, OH, \ | |||||
OW, FW); \ | |||||
} while (0) | |||||
switch (FH) { | |||||
case 1: | |||||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(0)) { GAO(1); } | |||||
MIDOUT_END(); | |||||
break; | |||||
case 2: | |||||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(1)) { GAO(2); } | |||||
MIDOUT_END(); | |||||
break; | |||||
case 3: | |||||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(2)) { GAO(3); } | |||||
MIDOUT_END(); | |||||
break; | |||||
case 4: | |||||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(3)) { GAO(4); } | |||||
MIDOUT_END(); | |||||
break; | |||||
case 5: | |||||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(4)) { GAO(5); } | |||||
MIDOUT_END(); | |||||
break; | |||||
case 6: | |||||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(5)) { GAO(6); } | |||||
MIDOUT_END(); | |||||
break; | |||||
case 7: | |||||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(6)) { GAO(7); } | |||||
MIDOUT_END(); | |||||
break; | |||||
} | |||||
#undef GAO | |||||
} | |||||
megdnn_assert_internal(0); | |||||
} | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,29 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/conv_bias/fp32/direct.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#pragma once
#include <cstddef>
namespace megdnn {
namespace arm_common {
namespace fp32{
namespace conv_bias {
//! Direct fp32 convolution over a single input/output channel plane:
//! src is IH x IW, dst is OH x OW, filter is FH x FW (FH <= 7).
//! Implemented in direct.cpp.
void kern_direct(const float *src, const float *filter, float *dst,
                 const int IH, const int IW, const int OH, const int OW,
                 const int FH, const int FW);
} // namespace conv_bias
} // namespace fp32
} // namespace arm_common
} // namespace megdnn
// vim: syntax=cpp.doxygen
@@ -0,0 +1,735 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/conv_bias/fp32/do_conv_stride1.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include <algorithm> | |||||
#include "src/arm_common/conv_bias/fp32/do_conv_stride1.h" | |||||
#include "src/arm_common/simd_macro/neon_helper.h" | |||||
#include "src/arm_common/conv_bias/postprocess_helper.h" | |||||
#include "midout.h" | |||||
MIDOUT_DECL(megdnn_arm_common_conv_bias_f32_convs1) | |||||
using namespace megdnn; | |||||
using namespace arm_common; | |||||
using namespace fp32; | |||||
using namespace conv_stride1; | |||||
using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam; | |||||
using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam; | |||||
// Direct 2x2 stride-1 fp32 convolution, accumulating over all IC input
// channels into dst. dst is read-modify-write (_sum is initialized from a
// load of outptr), so it must already hold the bias / previous partial sum.
//
// src:    input, contiguous channels of IH x IW floats
// filter: 4 weights (2x2 kernel) per input channel, contiguous
// dst:    output, OH x OW floats (single output channel)
//
// NOTE(review): the inner loop covers OW in groups of 4 (OW >> 2) with no
// scalar tail — presumably the caller guarantees the leftover columns are
// handled (or OW is padded); TODO confirm at call sites.
void conv_stride1::do_conv_2x2_stride1(const float* src, const float* filter, float* dst,
                                       size_t IH, size_t IW, size_t OH, size_t OW,
                                       size_t IC) {
    // Pointer advance needed to jump from the end of one processed output
    // row to the start of the next input row.
    const size_t tail_step = IW - OW;
    //! unroll of 2
    size_t ic = 0;
    // Main loop: two input channels per iteration so both channels' 2x2
    // kernels (_k0, _k1 — 8 weights total) stay resident in registers.
    for (; ic + 1 < IC; ic += 2) {
        const float* src_ptr = src + IW * IH * ic;
        const float* src_ptr1 = src_ptr + IW * IH;
        float* outptr = dst;
        // r00/r01: kernel rows 0 and 1 of channel ic;
        // r10/r11: kernel rows 0 and 1 of channel ic+1.
        const float* r00 = src_ptr;
        const float* r01 = src_ptr + IW;
        const float* r10 = src_ptr1;
        const float* r11 = src_ptr1 + IW;
        const float* k0 = filter + ic * 4;  // 4 weights per channel
        const float* k1 = k0 + 4;
        MEGDNN_SIMD_TYPE _k0 = MEGDNN_SIMD_LOADU(k0);
        MEGDNN_SIMD_TYPE _k1 = MEGDNN_SIMD_LOADU(k1);
        rep(h, OH) {
            int width = OW >> 2;  // 4 outputs per SIMD iteration
            rep(i, width) {
                // For each input row, a load at +0 and a load at +1 give
                // the two horizontal taps of the 2x2 kernel.
                MEGDNN_SIMD_TYPE _r000 = MEGDNN_SIMD_LOADU(r00);
                MEGDNN_SIMD_TYPE _r010 = MEGDNN_SIMD_LOADU(r01);
                MEGDNN_SIMD_TYPE _r001 = MEGDNN_SIMD_LOADU(r00 + 1);
                MEGDNN_SIMD_TYPE _r011 = MEGDNN_SIMD_LOADU(r01 + 1);
                MEGDNN_SIMD_TYPE _r100 = MEGDNN_SIMD_LOADU(r10);
                MEGDNN_SIMD_TYPE _r110 = MEGDNN_SIMD_LOADU(r11);
                MEGDNN_SIMD_TYPE _r101 = MEGDNN_SIMD_LOADU(r10 + 1);
                MEGDNN_SIMD_TYPE _r111 = MEGDNN_SIMD_LOADU(r11 + 1);
                // Accumulate on top of whatever dst already holds.
                MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr);
                // Multiply-accumulate each tap against its kernel lane:
                // _k0 lanes 0..3 are channel ic's 2x2 weights,
                // _k1 lanes 0..3 are channel ic+1's.
                _sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r000,
                                              MEGDNN_SIMD_GET_LOW(_k0), 0);
                _sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r001,
                                              MEGDNN_SIMD_GET_LOW(_k0), 1);
                _sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r010,
                                              MEGDNN_SIMD_GET_HIGH(_k0), 0);
                _sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r011,
                                              MEGDNN_SIMD_GET_HIGH(_k0), 1);
                _sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r100,
                                              MEGDNN_SIMD_GET_LOW(_k1), 0);
                _sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r101,
                                              MEGDNN_SIMD_GET_LOW(_k1), 1);
                _sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r110,
                                              MEGDNN_SIMD_GET_HIGH(_k1), 0);
                _sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r111,
                                              MEGDNN_SIMD_GET_HIGH(_k1), 1);
                MEGDNN_SIMD_STOREU(outptr, _sum);
                r00 += 4;
                r01 += 4;
                r10 += 4;
                r11 += 4;
                outptr += 4;
            }
            // Advance all row pointers to the next input row.
            r00 += tail_step;
            r01 += tail_step;
            r10 += tail_step;
            r11 += tail_step;
        }
    }
    // Remainder: last channel when IC is odd — same computation, one
    // channel at a time, with the 4 weights broadcast into registers.
    for (; ic < IC; ic++) {
        const float* src_ptr = src + IW * IH * ic;
        float* outptr = dst;
        const float* r0 = src_ptr;
        const float* r1 = src_ptr + IW;
        const float* k0 = filter + ic * 4;
        MEGDNN_SIMD_TYPE _k0 = MEGDNN_SIMD_SET1(k0[0]);
        MEGDNN_SIMD_TYPE _k1 = MEGDNN_SIMD_SET1(k0[1]);
        MEGDNN_SIMD_TYPE _k2 = MEGDNN_SIMD_SET1(k0[2]);
        MEGDNN_SIMD_TYPE _k3 = MEGDNN_SIMD_SET1(k0[3]);
        rep(h, OH) {
            int width = OW >> 2;
            rep(i, width) {
                MEGDNN_SIMD_TYPE _r00 = MEGDNN_SIMD_LOADU(r0);
                MEGDNN_SIMD_TYPE _r10 = MEGDNN_SIMD_LOADU(r1);
                MEGDNN_SIMD_TYPE _r01 = MEGDNN_SIMD_LOADU(r0 + 1);
                MEGDNN_SIMD_TYPE _r11 = MEGDNN_SIMD_LOADU(r1 + 1);
                MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr);
                MEGDNN_SIMD_TYPE _sum2;
                // Two independent accumulators (_sum/_sum2) to shorten the
                // FMA dependency chain; combined with one final ADD.
                _sum = MEGDNN_SIMD_FMADD(_r00, _k0, _sum);
                _sum2 = MEGDNN_SIMD_MUL(_r01, _k1);
                _sum = MEGDNN_SIMD_FMADD(_r10, _k2, _sum);
                _sum2 = MEGDNN_SIMD_FMADD(_r11, _k3, _sum2);
                _sum = MEGDNN_SIMD_ADD(_sum, _sum2);
                MEGDNN_SIMD_STOREU(outptr, _sum);
                r0 += 4;
                r1 += 4;
                outptr += 4;
            }
            r0 += tail_step;
            r1 += tail_step;
        }
    }
}
// Direct 3x3 stride-1 fp32 convolution, accumulating over all IC input
// channels into dst (dst is read-modify-write — it must hold the bias or
// previous partial sum on entry). Filter layout: 9 contiguous weights per
// input channel (filter is advanced by 9 at the end of each channel).
//
// Rows are processed two at a time (outptr / outptr2) so the loaded input
// rows r1/r2 are reused for both output rows; a single-row loop handles an
// odd final row.
//
// NOTE(review): inner loops cover OW in groups of 4 with no scalar tail —
// presumably the caller handles leftover columns; TODO confirm.
void conv_stride1::do_conv_3x3_stride1(const float* src, const float* filter, float* dst,
                                       size_t IH, size_t IW, size_t OH, size_t OW,
                                       size_t IC) {
    // Advance from the end of one processed output row to the next input row.
    const size_t tail_step = IW - OW;
    rep(ic, IC) {
        const float* src_ptr = src + IW * IH * ic;
        float* outptr = dst;
        float* outptr2 = outptr + OW;
        const float* r0 = src_ptr;
        const float* r1 = src_ptr + IW;
        const float* r2 = src_ptr + IW * 2;
        const float* r3 = src_ptr + IW * 3;
        const float* k0 = filter;
        const float* k1 = filter + 3;
        // k2 deliberately points at filter+5 (not +6): loading 4 floats at
        // filter+6 would read filter[9], one past the 9-element kernel.
        // Instead we load filter[5..8] and rotate by one lane below.
        const float* k2 = filter + 5;
        MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(k0);
        MEGDNN_SIMD_TYPE _k3456 = MEGDNN_SIMD_LOADU(k1);
        MEGDNN_SIMD_TYPE _k5678 = MEGDNN_SIMD_LOADU(k2);
        // _k6789 lanes 0..2 = filter[6..8]; lane 3 wraps to filter[5] and is
        // never used (only lanes 0..2 are referenced below).
        MEGDNN_SIMD_TYPE _k6789 = MEGDNN_SIMD_EXT(_k5678, _k5678, 1);
        size_t h = 0;
        // Two output rows per iteration.
        for (; h + 1 < OH; h += 2) {
            int width = OW >> 2;  // 4 outputs per SIMD iteration
            rep(i, width) {
                // _sum1/_sum2 accumulate output row h; _sum3/_sum4 row h+1.
                // Dual accumulators shorten the FMA dependency chains.
                MEGDNN_SIMD_TYPE _sum1 = MEGDNN_SIMD_LOADU(outptr);
                MEGDNN_SIMD_TYPE _sum2 = MEGDNN_SIMD_SET1(0.f);
                MEGDNN_SIMD_TYPE _sum3 = MEGDNN_SIMD_LOADU(outptr2);
                MEGDNN_SIMD_TYPE _sum4 = MEGDNN_SIMD_SET1(0.f);
                // Per input row: load 8 contiguous values, derive the +1 and
                // +2 shifted vectors with EXT (the three horizontal taps).
                MEGDNN_SIMD_TYPE _r00 = MEGDNN_SIMD_LOADU(r0);
                MEGDNN_SIMD_TYPE _r00n = MEGDNN_SIMD_LOADU(r0 + 4);
                MEGDNN_SIMD_TYPE _r01 = MEGDNN_SIMD_EXT(_r00, _r00n, 1);
                MEGDNN_SIMD_TYPE _r02 = MEGDNN_SIMD_EXT(_r00, _r00n, 2);
                MEGDNN_SIMD_TYPE _r10 = MEGDNN_SIMD_LOADU(r1);
                MEGDNN_SIMD_TYPE _r10n = MEGDNN_SIMD_LOADU(r1 + 4);
                MEGDNN_SIMD_TYPE _r11 = MEGDNN_SIMD_EXT(_r10, _r10n, 1);
                MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r10n, 2);
                MEGDNN_SIMD_TYPE _r20 = MEGDNN_SIMD_LOADU(r2);
                MEGDNN_SIMD_TYPE _r20n = MEGDNN_SIMD_LOADU(r2 + 4);
                MEGDNN_SIMD_TYPE _r21 = MEGDNN_SIMD_EXT(_r20, _r20n, 1);
                MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r20n, 2);
                MEGDNN_SIMD_TYPE _r30 = MEGDNN_SIMD_LOADU(r3);
                MEGDNN_SIMD_TYPE _r30n = MEGDNN_SIMD_LOADU(r3 + 4);
                MEGDNN_SIMD_TYPE _r31 = MEGDNN_SIMD_EXT(_r30, _r30n, 1);
                MEGDNN_SIMD_TYPE _r32 = MEGDNN_SIMD_EXT(_r30, _r30n, 2);
                // Output row h: kernel rows 0..2 against input rows r0..r2.
                _sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r00, _k0123, 0);
                _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r01, _k0123, 1);
                _sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r02, _k0123, 2);
                _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r10, _k3456, 0);
                _sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r11, _k3456, 1);
                _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r12, _k3456, 2);
                _sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r20, _k6789, 0);
                _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r21, _k6789, 1);
                _sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r22, _k6789, 2);
                // Output row h+1: same kernel against input rows r1..r3.
                _sum3 = MEGDNN_SIMD_FMA_LANE(_sum3, _r10, _k0123, 0);
                _sum4 = MEGDNN_SIMD_FMA_LANE(_sum4, _r11, _k0123, 1);
                _sum3 = MEGDNN_SIMD_FMA_LANE(_sum3, _r12, _k0123, 2);
                _sum4 = MEGDNN_SIMD_FMA_LANE(_sum4, _r20, _k3456, 0);
                _sum3 = MEGDNN_SIMD_FMA_LANE(_sum3, _r21, _k3456, 1);
                _sum4 = MEGDNN_SIMD_FMA_LANE(_sum4, _r22, _k3456, 2);
                _sum3 = MEGDNN_SIMD_FMA_LANE(_sum3, _r30, _k6789, 0);
                _sum4 = MEGDNN_SIMD_FMA_LANE(_sum4, _r31, _k6789, 1);
                _sum3 = MEGDNN_SIMD_FMA_LANE(_sum3, _r32, _k6789, 2);
                _sum1 = MEGDNN_SIMD_ADD(_sum1, _sum2);
                _sum3 = MEGDNN_SIMD_ADD(_sum3, _sum4);
                MEGDNN_SIMD_STOREU(outptr, _sum1);
                MEGDNN_SIMD_STOREU(outptr2, _sum3);
                r0 += 4;
                r1 += 4;
                r2 += 4;
                r3 += 4;
                outptr += 4;
                outptr2 += 4;
            }
            // Two output rows consumed: skip one extra input row.
            r0 += tail_step + IW;
            r1 += tail_step + IW;
            r2 += tail_step + IW;
            r3 += tail_step + IW;
            outptr += OW;
            outptr2 += OW;
        }
        // Remaining single output row when OH is odd.
        for (; h < OH; h++) {
            int width = OW >> 2;
            rep(i, width) {
                MEGDNN_SIMD_TYPE _sum1 = MEGDNN_SIMD_LOADU(outptr);
                MEGDNN_SIMD_TYPE _sum2 = MEGDNN_SIMD_SET1(0.f);
                MEGDNN_SIMD_TYPE _r00 = MEGDNN_SIMD_LOADU(r0);
                MEGDNN_SIMD_TYPE _r00n = MEGDNN_SIMD_LOADU(r0 + 4);
                MEGDNN_SIMD_TYPE _r01 = MEGDNN_SIMD_EXT(_r00, _r00n, 1);
                MEGDNN_SIMD_TYPE _r02 = MEGDNN_SIMD_EXT(_r00, _r00n, 2);
                MEGDNN_SIMD_TYPE _r10 = MEGDNN_SIMD_LOADU(r1);
                MEGDNN_SIMD_TYPE _r10n = MEGDNN_SIMD_LOADU(r1 + 4);
                MEGDNN_SIMD_TYPE _r11 = MEGDNN_SIMD_EXT(_r10, _r10n, 1);
                MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r10n, 2);
                MEGDNN_SIMD_TYPE _r20 = MEGDNN_SIMD_LOADU(r2);
                MEGDNN_SIMD_TYPE _r20n = MEGDNN_SIMD_LOADU(r2 + 4);
                MEGDNN_SIMD_TYPE _r21 = MEGDNN_SIMD_EXT(_r20, _r20n, 1);
                MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r20n, 2);
                _sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r00, _k0123, 0);
                _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r01, _k0123, 1);
                _sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r02, _k0123, 2);
                _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r10, _k3456, 0);
                _sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r11, _k3456, 1);
                _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r12, _k3456, 2);
                _sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r20, _k6789, 0);
                _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r21, _k6789, 1);
                _sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r22, _k6789, 2);
                _sum1 = MEGDNN_SIMD_ADD(_sum1, _sum2);
                MEGDNN_SIMD_STOREU(outptr, _sum1);
                r0 += 4;
                r1 += 4;
                r2 += 4;
                outptr += 4;
            }
            r0 += tail_step;
            r1 += tail_step;
            r2 += tail_step;
        }
        filter += 9;  // next input channel's 3x3 kernel
    }
}
// Direct 5x5 stride-1 fp32 convolution, accumulating over all IC input
// channels into dst (dst is read-modify-write — it must hold the bias or
// previous partial sum on entry). Filter layout: 25 contiguous weights per
// input channel (filter advanced by 25 per channel).
//
// All 25 weights are pre-loaded into 7 registers (_k0123.._k20212223 plus
// the broadcast _k24242424 for the last weight). Rows are processed two at
// a time to reuse the shared input rows r1..r4.
//
// NOTE(review): inner loops cover OW in groups of 4 with no scalar tail —
// presumably the caller handles leftover columns; TODO confirm.
void conv_stride1::do_conv_5x5_stride1(const float* src, const float* filter, float* dst,
                                       size_t IH, size_t IW, size_t OH, size_t OW,
                                       size_t IC) {
    // Advance from the end of one processed output row to the next input row.
    const size_t tail_step = IW - OW;
    rep(ic, IC) {
        const float* src_ptr = src + IW * IH * ic;
        float* outptr = dst;
        float* outptr2 = outptr + OW;
        const float* r0 = src_ptr;
        const float* r1 = src_ptr + IW;
        const float* r2 = src_ptr + IW * 2;
        const float* r3 = src_ptr + IW * 3;
        const float* r4 = src_ptr + IW * 4;
        const float* r5 = src_ptr + IW * 5;
        MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(filter);
        MEGDNN_SIMD_TYPE _k4567 = MEGDNN_SIMD_LOADU(filter + 4);
        MEGDNN_SIMD_TYPE _k891011 = MEGDNN_SIMD_LOADU(filter + 8);
        MEGDNN_SIMD_TYPE _k12131415 = MEGDNN_SIMD_LOADU(filter + 12);
        MEGDNN_SIMD_TYPE _k16171819 = MEGDNN_SIMD_LOADU(filter + 16);
        MEGDNN_SIMD_TYPE _k20212223 = MEGDNN_SIMD_LOADU(filter + 20);
        // Last weight broadcast to all lanes: loading 4 floats at filter+24
        // would read past the 25-element kernel.
        MEGDNN_SIMD_TYPE _k24242424 = MEGDNN_SIMD_SET1(filter[24]);
        size_t h = 0;
        // Two output rows per iteration (_sum for row h, _sum2 for row h+1).
        for (; h + 1 < OH; h += 2) {
            int width = OW >> 2;  // 4 outputs per SIMD iteration
            rep(i, width) {
                MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr);
                MEGDNN_SIMD_TYPE _sum2 = MEGDNN_SIMD_LOADU(outptr2);
                // Per input row: one 8-float window, with EXT producing the
                // +1/+2/+3 shifted vectors (the five horizontal taps are the
                // base load, the three EXTs, and the +4 load).
                MEGDNN_SIMD_TYPE _r00 = MEGDNN_SIMD_LOADU(r0);
                MEGDNN_SIMD_TYPE _r04 = MEGDNN_SIMD_LOADU(r0 + 4);
                MEGDNN_SIMD_TYPE _r01 = MEGDNN_SIMD_EXT(_r00, _r04, 1);
                MEGDNN_SIMD_TYPE _r02 = MEGDNN_SIMD_EXT(_r00, _r04, 2);
                MEGDNN_SIMD_TYPE _r03 = MEGDNN_SIMD_EXT(_r00, _r04, 3);
                MEGDNN_SIMD_TYPE _r10 = MEGDNN_SIMD_LOADU(r1);
                MEGDNN_SIMD_TYPE _r14 = MEGDNN_SIMD_LOADU(r1 + 4);
                MEGDNN_SIMD_TYPE _r11 = MEGDNN_SIMD_EXT(_r10, _r14, 1);
                MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r14, 2);
                MEGDNN_SIMD_TYPE _r13 = MEGDNN_SIMD_EXT(_r10, _r14, 3);
                MEGDNN_SIMD_TYPE _r20 = MEGDNN_SIMD_LOADU(r2);
                MEGDNN_SIMD_TYPE _r24 = MEGDNN_SIMD_LOADU(r2 + 4);
                MEGDNN_SIMD_TYPE _r21 = MEGDNN_SIMD_EXT(_r20, _r24, 1);
                MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r24, 2);
                MEGDNN_SIMD_TYPE _r23 = MEGDNN_SIMD_EXT(_r20, _r24, 3);
                MEGDNN_SIMD_TYPE _r30 = MEGDNN_SIMD_LOADU(r3);
                MEGDNN_SIMD_TYPE _r34 = MEGDNN_SIMD_LOADU(r3 + 4);
                MEGDNN_SIMD_TYPE _r31 = MEGDNN_SIMD_EXT(_r30, _r34, 1);
                MEGDNN_SIMD_TYPE _r32 = MEGDNN_SIMD_EXT(_r30, _r34, 2);
                MEGDNN_SIMD_TYPE _r33 = MEGDNN_SIMD_EXT(_r30, _r34, 3);
                MEGDNN_SIMD_TYPE _r40 = MEGDNN_SIMD_LOADU(r4);
                MEGDNN_SIMD_TYPE _r44 = MEGDNN_SIMD_LOADU(r4 + 4);
                MEGDNN_SIMD_TYPE _r41 = MEGDNN_SIMD_EXT(_r40, _r44, 1);
                MEGDNN_SIMD_TYPE _r42 = MEGDNN_SIMD_EXT(_r40, _r44, 2);
                MEGDNN_SIMD_TYPE _r43 = MEGDNN_SIMD_EXT(_r40, _r44, 3);
                MEGDNN_SIMD_TYPE _r50 = MEGDNN_SIMD_LOADU(r5);
                MEGDNN_SIMD_TYPE _r54 = MEGDNN_SIMD_LOADU(r5 + 4);
                MEGDNN_SIMD_TYPE _r51 = MEGDNN_SIMD_EXT(_r50, _r54, 1);
                MEGDNN_SIMD_TYPE _r52 = MEGDNN_SIMD_EXT(_r50, _r54, 2);
                MEGDNN_SIMD_TYPE _r53 = MEGDNN_SIMD_EXT(_r50, _r54, 3);
                // Output row h: 25 taps, kernel rows 0..4 x input rows r0..r4.
                _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r00, _k0123, 0);
                _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r01, _k0123, 1);
                _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r02, _k0123, 2);
                _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r03, _k0123, 3);
                _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r04, _k4567, 0);
                _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r10, _k4567, 1);
                _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r11, _k4567, 2);
                _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r12, _k4567, 3);
                _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r13, _k891011, 0);
                _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r14, _k891011, 1);
                _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r20, _k891011, 2);
                _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r21, _k891011, 3);
                _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r22, _k12131415, 0);
                _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r23, _k12131415, 1);
                _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r24, _k12131415, 2);
                _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r30, _k12131415, 3);
                _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r31, _k16171819, 0);
                _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r32, _k16171819, 1);
                _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r33, _k16171819, 2);
                _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r34, _k16171819, 3);
                _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r40, _k20212223, 0);
                _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r41, _k20212223, 1);
                _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r42, _k20212223, 2);
                _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r43, _k20212223, 3);
                _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r44, _k24242424, 0);
                // Output row h+1: same kernel against input rows r1..r5.
                _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r10, _k0123, 0);
                _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r11, _k0123, 1);
                _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r12, _k0123, 2);
                _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r13, _k0123, 3);
                _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r14, _k4567, 0);
                _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r20, _k4567, 1);
                _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r21, _k4567, 2);
                _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r22, _k4567, 3);
                _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r23, _k891011, 0);
                _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r24, _k891011, 1);
                _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r30, _k891011, 2);
                _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r31, _k891011, 3);
                _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r32, _k12131415, 0);
                _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r33, _k12131415, 1);
                _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r34, _k12131415, 2);
                _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r40, _k12131415, 3);
                _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r41, _k16171819, 0);
                _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r42, _k16171819, 1);
                _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r43, _k16171819, 2);
                _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r44, _k16171819, 3);
                _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r50, _k20212223, 0);
                _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r51, _k20212223, 1);
                _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r52, _k20212223, 2);
                _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r53, _k20212223, 3);
                _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r54, _k24242424, 0);
                MEGDNN_SIMD_STOREU(outptr, _sum);
                MEGDNN_SIMD_STOREU(outptr2, _sum2);
                r0 += 4;
                r1 += 4;
                r2 += 4;
                r3 += 4;
                r4 += 4;
                r5 += 4;
                outptr += 4;
                outptr2 += 4;
            }
            // Two output rows consumed: skip one extra input row.
            r0 += tail_step + IW;
            r1 += tail_step + IW;
            r2 += tail_step + IW;
            r3 += tail_step + IW;
            r4 += tail_step + IW;
            r5 += tail_step + IW;
            outptr += OW;
            outptr2 += OW;
        }
        // Remaining single output row when OH is odd.
        for (; h < OH; h++) {
            int width = OW >> 2;
            rep(i, width) {
                MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr);
                MEGDNN_SIMD_TYPE _r00 = MEGDNN_SIMD_LOADU(r0);
                MEGDNN_SIMD_TYPE _r04 = MEGDNN_SIMD_LOADU(r0 + 4);
                MEGDNN_SIMD_TYPE _r01 = MEGDNN_SIMD_EXT(_r00, _r04, 1);
                MEGDNN_SIMD_TYPE _r02 = MEGDNN_SIMD_EXT(_r00, _r04, 2);
                MEGDNN_SIMD_TYPE _r03 = MEGDNN_SIMD_EXT(_r00, _r04, 3);
                MEGDNN_SIMD_TYPE _r10 = MEGDNN_SIMD_LOADU(r1);
                MEGDNN_SIMD_TYPE _r14 = MEGDNN_SIMD_LOADU(r1 + 4);
                MEGDNN_SIMD_TYPE _r11 = MEGDNN_SIMD_EXT(_r10, _r14, 1);
                MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r14, 2);
                MEGDNN_SIMD_TYPE _r13 = MEGDNN_SIMD_EXT(_r10, _r14, 3);
                MEGDNN_SIMD_TYPE _r20 = MEGDNN_SIMD_LOADU(r2);
                MEGDNN_SIMD_TYPE _r24 = MEGDNN_SIMD_LOADU(r2 + 4);
                MEGDNN_SIMD_TYPE _r21 = MEGDNN_SIMD_EXT(_r20, _r24, 1);
                MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r24, 2);
                MEGDNN_SIMD_TYPE _r23 = MEGDNN_SIMD_EXT(_r20, _r24, 3);
                MEGDNN_SIMD_TYPE _r30 = MEGDNN_SIMD_LOADU(r3);
                MEGDNN_SIMD_TYPE _r34 = MEGDNN_SIMD_LOADU(r3 + 4);
                MEGDNN_SIMD_TYPE _r31 = MEGDNN_SIMD_EXT(_r30, _r34, 1);
                MEGDNN_SIMD_TYPE _r32 = MEGDNN_SIMD_EXT(_r30, _r34, 2);
                MEGDNN_SIMD_TYPE _r33 = MEGDNN_SIMD_EXT(_r30, _r34, 3);
                MEGDNN_SIMD_TYPE _r40 = MEGDNN_SIMD_LOADU(r4);
                MEGDNN_SIMD_TYPE _r44 = MEGDNN_SIMD_LOADU(r4 + 4);
                MEGDNN_SIMD_TYPE _r41 = MEGDNN_SIMD_EXT(_r40, _r44, 1);
                MEGDNN_SIMD_TYPE _r42 = MEGDNN_SIMD_EXT(_r40, _r44, 2);
                MEGDNN_SIMD_TYPE _r43 = MEGDNN_SIMD_EXT(_r40, _r44, 3);
                _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r00, _k0123, 0);
                _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r01, _k0123, 1);
                _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r02, _k0123, 2);
                _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r03, _k0123, 3);
                _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r04, _k4567, 0);
                _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r10, _k4567, 1);
                _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r11, _k4567, 2);
                _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r12, _k4567, 3);
                _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r13, _k891011, 0);
                _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r14, _k891011, 1);
                _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r20, _k891011, 2);
                _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r21, _k891011, 3);
                _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r22, _k12131415, 0);
                _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r23, _k12131415, 1);
                _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r24, _k12131415, 2);
                _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r30, _k12131415, 3);
                _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r31, _k16171819, 0);
                _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r32, _k16171819, 1);
                _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r33, _k16171819, 2);
                _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r34, _k16171819, 3);
                _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r40, _k20212223, 0);
                _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r41, _k20212223, 1);
                _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r42, _k20212223, 2);
                _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r43, _k20212223, 3);
                _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r44, _k24242424, 0);
                MEGDNN_SIMD_STOREU(outptr, _sum);
                r0 += 4;
                r1 += 4;
                r2 += 4;
                r3 += 4;
                r4 += 4;
                outptr += 4;
            }
            r0 += tail_step;
            r1 += tail_step;
            r2 += tail_step;
            r3 += tail_step;
            r4 += tail_step;
        }
        filter += 25;  // next input channel's 5x5 kernel
    }
}
void conv_stride1::do_conv_7x7_stride1(const float* src, const float* filter, float* dst, | |||||
size_t IH, size_t IW, size_t OH, size_t OW, | |||||
size_t IC) { | |||||
const size_t tail_step = IW - OW; | |||||
rep(ic, IC) { | |||||
const float* src_ptr = src + IW * IH * ic; | |||||
float* outptr = dst; | |||||
const float* r0 = src_ptr; | |||||
const float* r1 = src_ptr + IW; | |||||
const float* r2 = src_ptr + IW * 2; | |||||
const float* r3 = src_ptr + IW * 3; | |||||
const float* r4 = src_ptr + IW * 4; | |||||
const float* r5 = src_ptr + IW * 5; | |||||
const float* r6 = src_ptr + IW * 6; | |||||
const float* k0 = filter; | |||||
const float* k1 = filter + 7; | |||||
const float* k2 = filter + 14; | |||||
const float* k3 = filter + 21; | |||||
const float* k4 = filter + 28; | |||||
const float* k5 = filter + 35; | |||||
const float* k6 = filter + 42; | |||||
for (size_t i = 0; i < OH; i++) { | |||||
int width = OW >> 2; | |||||
rep(i, width) { | |||||
MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr); | |||||
MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(k0); | |||||
MEGDNN_SIMD_TYPE _k4567 = MEGDNN_SIMD_LOADU(k0 + 4); | |||||
MEGDNN_SIMD_TYPE _r00 = MEGDNN_SIMD_LOADU(r0); // 0 1 2 3 | |||||
MEGDNN_SIMD_TYPE _r04 = MEGDNN_SIMD_LOADU(r0 + 4); // 4 5 6 7 | |||||
MEGDNN_SIMD_TYPE _r00n = | |||||
MEGDNN_SIMD_LOADU(r0 + 8); // 8 9 10 11 | |||||
MEGDNN_SIMD_TYPE _r01 = | |||||
MEGDNN_SIMD_EXT(_r00, _r04, 1); // 1 2 3 4 | |||||
MEGDNN_SIMD_TYPE _r02 = | |||||
MEGDNN_SIMD_EXT(_r00, _r04, 2); // 2 3 4 5 | |||||
MEGDNN_SIMD_TYPE _r03 = | |||||
MEGDNN_SIMD_EXT(_r00, _r04, 3); // 3 4 5 6 | |||||
MEGDNN_SIMD_TYPE _r05 = | |||||
MEGDNN_SIMD_EXT(_r04, _r00n, 1); // 5 6 7 8 | |||||
MEGDNN_SIMD_TYPE _r06 = | |||||
MEGDNN_SIMD_EXT(_r04, _r00n, 2); // 6 7 8 9 | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r00, _k0123, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r01, _k0123, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r02, _k0123, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r03, _k0123, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r04, _k4567, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r05, _k4567, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r06, _k4567, 2); | |||||
MEGDNN_SIMD_TYPE _k78910 = MEGDNN_SIMD_LOADU(k1); | |||||
MEGDNN_SIMD_TYPE _k11121314 = MEGDNN_SIMD_LOADU(k1 + 4); | |||||
MEGDNN_SIMD_TYPE _r10 = MEGDNN_SIMD_LOADU(r1); | |||||
MEGDNN_SIMD_TYPE _r14 = MEGDNN_SIMD_LOADU(r1 + 4); | |||||
MEGDNN_SIMD_TYPE _r10n = MEGDNN_SIMD_LOADU(r1 + 8); | |||||
MEGDNN_SIMD_TYPE _r11 = MEGDNN_SIMD_EXT(_r10, _r14, 1); | |||||
MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r14, 2); | |||||
MEGDNN_SIMD_TYPE _r13 = MEGDNN_SIMD_EXT(_r10, _r14, 3); | |||||
MEGDNN_SIMD_TYPE _r15 = MEGDNN_SIMD_EXT(_r14, _r10n, 1); | |||||
MEGDNN_SIMD_TYPE _r16 = MEGDNN_SIMD_EXT(_r14, _r10n, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r10, _k78910, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r11, _k78910, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r12, _k78910, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r13, _k78910, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r14, _k11121314, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r15, _k11121314, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r16, _k11121314, 2); | |||||
MEGDNN_SIMD_TYPE _k14151617 = MEGDNN_SIMD_LOADU(k2); | |||||
MEGDNN_SIMD_TYPE _k18192021 = MEGDNN_SIMD_LOADU(k2 + 4); | |||||
MEGDNN_SIMD_TYPE _r20 = MEGDNN_SIMD_LOADU(r2); | |||||
MEGDNN_SIMD_TYPE _r24 = MEGDNN_SIMD_LOADU(r2 + 4); | |||||
MEGDNN_SIMD_TYPE _r20n = MEGDNN_SIMD_LOADU(r2 + 8); | |||||
MEGDNN_SIMD_TYPE _r21 = MEGDNN_SIMD_EXT(_r20, _r24, 1); | |||||
MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r24, 2); | |||||
MEGDNN_SIMD_TYPE _r23 = MEGDNN_SIMD_EXT(_r20, _r24, 3); | |||||
MEGDNN_SIMD_TYPE _r25 = MEGDNN_SIMD_EXT(_r24, _r20n, 1); | |||||
MEGDNN_SIMD_TYPE _r26 = MEGDNN_SIMD_EXT(_r24, _r20n, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r20, _k14151617, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r21, _k14151617, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r22, _k14151617, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r23, _k14151617, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r24, _k18192021, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r25, _k18192021, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r26, _k18192021, 2); | |||||
MEGDNN_SIMD_TYPE _k21222324 = MEGDNN_SIMD_LOADU(k3); | |||||
MEGDNN_SIMD_TYPE _k25262728 = MEGDNN_SIMD_LOADU(k3 + 4); | |||||
MEGDNN_SIMD_TYPE _r30 = MEGDNN_SIMD_LOADU(r3); | |||||
MEGDNN_SIMD_TYPE _r34 = MEGDNN_SIMD_LOADU(r3 + 4); | |||||
MEGDNN_SIMD_TYPE _r30n = MEGDNN_SIMD_LOADU(r3 + 8); | |||||
MEGDNN_SIMD_TYPE _r31 = MEGDNN_SIMD_EXT(_r30, _r34, 1); | |||||
MEGDNN_SIMD_TYPE _r32 = MEGDNN_SIMD_EXT(_r30, _r34, 2); | |||||
MEGDNN_SIMD_TYPE _r33 = MEGDNN_SIMD_EXT(_r30, _r34, 3); | |||||
MEGDNN_SIMD_TYPE _r35 = MEGDNN_SIMD_EXT(_r34, _r30n, 1); | |||||
MEGDNN_SIMD_TYPE _r36 = MEGDNN_SIMD_EXT(_r34, _r30n, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r30, _k21222324, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r31, _k21222324, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r32, _k21222324, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r33, _k21222324, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r34, _k25262728, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r35, _k25262728, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r36, _k25262728, 2); | |||||
MEGDNN_SIMD_TYPE _k28293031 = MEGDNN_SIMD_LOADU(k4); | |||||
MEGDNN_SIMD_TYPE _k32333435 = MEGDNN_SIMD_LOADU(k4 + 4); | |||||
MEGDNN_SIMD_TYPE _r40 = MEGDNN_SIMD_LOADU(r4); | |||||
MEGDNN_SIMD_TYPE _r44 = MEGDNN_SIMD_LOADU(r4 + 4); | |||||
MEGDNN_SIMD_TYPE _r40n = MEGDNN_SIMD_LOADU(r4 + 8); | |||||
MEGDNN_SIMD_TYPE _r41 = MEGDNN_SIMD_EXT(_r40, _r44, 1); | |||||
MEGDNN_SIMD_TYPE _r42 = MEGDNN_SIMD_EXT(_r40, _r44, 2); | |||||
MEGDNN_SIMD_TYPE _r43 = MEGDNN_SIMD_EXT(_r40, _r44, 3); | |||||
MEGDNN_SIMD_TYPE _r45 = MEGDNN_SIMD_EXT(_r44, _r40n, 1); | |||||
MEGDNN_SIMD_TYPE _r46 = MEGDNN_SIMD_EXT(_r44, _r40n, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r40, _k28293031, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r41, _k28293031, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r42, _k28293031, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r43, _k28293031, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r44, _k32333435, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r45, _k32333435, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r46, _k32333435, 2); | |||||
MEGDNN_SIMD_TYPE _k35363738 = MEGDNN_SIMD_LOADU(k5); | |||||
MEGDNN_SIMD_TYPE _k39404142 = MEGDNN_SIMD_LOADU(k5 + 4); | |||||
MEGDNN_SIMD_TYPE _r50 = MEGDNN_SIMD_LOADU(r5); | |||||
MEGDNN_SIMD_TYPE _r54 = MEGDNN_SIMD_LOADU(r5 + 4); | |||||
MEGDNN_SIMD_TYPE _r50n = MEGDNN_SIMD_LOADU(r5 + 8); | |||||
MEGDNN_SIMD_TYPE _r51 = MEGDNN_SIMD_EXT(_r50, _r54, 1); | |||||
MEGDNN_SIMD_TYPE _r52 = MEGDNN_SIMD_EXT(_r50, _r54, 2); | |||||
MEGDNN_SIMD_TYPE _r53 = MEGDNN_SIMD_EXT(_r50, _r54, 3); | |||||
MEGDNN_SIMD_TYPE _r55 = MEGDNN_SIMD_EXT(_r54, _r50n, 1); | |||||
MEGDNN_SIMD_TYPE _r56 = MEGDNN_SIMD_EXT(_r54, _r50n, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r50, _k35363738, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r51, _k35363738, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r52, _k35363738, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r53, _k35363738, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r54, _k39404142, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r55, _k39404142, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r56, _k39404142, 2); | |||||
MEGDNN_SIMD_TYPE _k42434445 = MEGDNN_SIMD_LOADU(k6); | |||||
MEGDNN_SIMD_TYPE _k46474849 = MEGDNN_SIMD_LOADU(k6 + 4); | |||||
MEGDNN_SIMD_TYPE _r60 = MEGDNN_SIMD_LOADU(r6); | |||||
MEGDNN_SIMD_TYPE _r64 = MEGDNN_SIMD_LOADU(r6 + 4); | |||||
MEGDNN_SIMD_TYPE _r60n = MEGDNN_SIMD_LOADU(r6 + 8); | |||||
MEGDNN_SIMD_TYPE _r61 = MEGDNN_SIMD_EXT(_r60, _r64, 1); | |||||
MEGDNN_SIMD_TYPE _r62 = MEGDNN_SIMD_EXT(_r60, _r64, 2); | |||||
MEGDNN_SIMD_TYPE _r63 = MEGDNN_SIMD_EXT(_r60, _r64, 3); | |||||
MEGDNN_SIMD_TYPE _r65 = MEGDNN_SIMD_EXT(_r64, _r60n, 1); | |||||
MEGDNN_SIMD_TYPE _r66 = MEGDNN_SIMD_EXT(_r64, _r60n, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r60, _k42434445, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r61, _k42434445, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r62, _k42434445, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r63, _k42434445, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r64, _k46474849, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r65, _k46474849, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r66, _k46474849, 2); | |||||
MEGDNN_SIMD_STOREU(outptr, _sum); | |||||
r0 += 4; | |||||
r1 += 4; | |||||
r2 += 4; | |||||
r3 += 4; | |||||
r4 += 4; | |||||
r5 += 4; | |||||
r6 += 4; | |||||
outptr += 4; | |||||
} | |||||
r0 += tail_step; | |||||
r1 += tail_step; | |||||
r2 += tail_step; | |||||
r3 += tail_step; | |||||
r4 += tail_step; | |||||
r5 += tail_step; | |||||
r6 += tail_step; | |||||
} | |||||
filter += 49; | |||||
} | |||||
} | |||||
#include "src/common/simd_macro/epilogue.h" | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,34 @@ | |||||
/**
 * \file dnn/src/arm_common/conv_bias/fp32/do_conv_stride1.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once
#include <cstddef>

namespace megdnn {
namespace arm_common {
namespace fp32 {
namespace conv_stride1 {

// Direct stride-1 fp32 convolution kernels, one per filter size.
// Common contract (see do_conv_stride1.cpp):
//   src:    input, IC contiguous planes of IH x IW floats
//   filter: FH*FW contiguous weights per input channel
//   dst:    single OH x OW output plane; ACCUMULATED into (must be
//           pre-initialized with bias / zeros by the caller)
void do_conv_2x2_stride1(const float* src, const float* filter, float* dst,
                         size_t IH, size_t IW, size_t OH, size_t OW, size_t IC);
void do_conv_3x3_stride1(const float* src, const float* filter, float* dst,
                         size_t IH, size_t IW, size_t OH, size_t OW, size_t IC);
void do_conv_5x5_stride1(const float* src, const float* filter, float* dst,
                         size_t IH, size_t IW, size_t OH, size_t OW, size_t IC);
void do_conv_7x7_stride1(const float* src, const float* filter, float* dst,
                         size_t IH, size_t IW, size_t OH, size_t OW, size_t IC);

}  // namespace conv_stride1
}  // namespace fp32
}  // namespace arm_common
}  // namespace megdnn
// vim: syntax=cpp.doxygen
@@ -0,0 +1,513 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/conv_bias/fp32/do_conv_stride2.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include <algorithm> | |||||
#include "./do_conv_stride2.h" | |||||
#include "midout.h" | |||||
#include "src/arm_common/simd_macro/neon_helper.h" | |||||
#include "src/arm_common/conv_bias/postprocess_helper.h" | |||||
MIDOUT_DECL(megdnn_arm_common_conv_bias_f32_convs2) | |||||
using namespace megdnn; | |||||
using namespace arm_common; | |||||
using namespace fp32; | |||||
using namespace conv_stride2; | |||||
using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam; | |||||
using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam; | |||||
void conv_stride2::do_conv_2x2_stride2(const float* src, const float* filter, float* dst, | |||||
size_t IH, size_t IW, size_t OH, size_t OW, | |||||
size_t IC) { | |||||
const size_t tail_step = IW - 2 * OW + IW; | |||||
rep(ic, IC) { | |||||
const float* src_ptr = src + IW * IH * ic; | |||||
float* outptr = dst; | |||||
const float* r0 = src_ptr; | |||||
const float* r1 = src_ptr + IW; | |||||
const float* k0 = filter; | |||||
MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(k0); | |||||
rep(h, OH) { | |||||
int nn = OW >> 2; | |||||
rep(i, nn) { | |||||
MEGDNN_SIMD_TYPE _outp = MEGDNN_SIMD_LOADU(outptr); | |||||
MEGDNN_SIMD_TYPE2 _r0 = MEGDNN_SIMD_LOAD2(r0); | |||||
MEGDNN_SIMD_TYPE _r00 = _r0.val[0]; // 0 2 4 6 | |||||
MEGDNN_SIMD_TYPE _r01 = _r0.val[1]; // 1 3 5 7 | |||||
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r00, _k0123, 0); | |||||
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r01, _k0123, 1); | |||||
MEGDNN_SIMD_TYPE2 _r1 = MEGDNN_SIMD_LOAD2(r1); | |||||
MEGDNN_SIMD_TYPE _r10 = _r1.val[0]; | |||||
MEGDNN_SIMD_TYPE _r11 = _r1.val[1]; | |||||
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r10, _k0123, 2); | |||||
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r11, _k0123, 3); | |||||
MEGDNN_SIMD_STOREU(outptr, _outp); | |||||
r0 += 8; | |||||
r1 += 8; | |||||
outptr += 4; | |||||
} | |||||
r0 += tail_step; | |||||
r1 += tail_step; | |||||
} | |||||
filter += 4; | |||||
} | |||||
} | |||||
//! 3x3 stride-2 direct convolution: accumulates the contribution of all IC
//! input channels into \p dst for one output channel. \p dst is loaded,
//! updated and stored back, so it must be pre-initialized by the caller.
//! Processes 4 output columns per inner iteration; columns beyond the last
//! multiple of 4 are not produced here (assumes OW % 4 == 0 — TODO confirm
//! callers guarantee this).
void conv_stride2::do_conv_3x3_stride2(const float* src, const float* filter, float* dst,
                               size_t IH, size_t IW, size_t OH, size_t OW,
                               size_t IC) {
    // 2*IW - 2*OW: after consuming 2*OW input columns per output row,
    // skip to the start of the input row two rows below (vertical stride 2).
    const size_t tail_step = IW - 2 * OW + IW;
    rep(ic, IC) {
        const float* src_ptr = src + IW * IH * ic;
        float* outptr = dst;
        const float* r0 = src_ptr;
        const float* r1 = src_ptr + IW;
        const float* r2 = src_ptr + IW * 2;
        const float* k0 = filter;
        const float* k1 = filter + 3;
        const float* k2 = filter + 5;
        // Overlapping loads of the 9 filter taps:
        //   _k0123 = f[0..3], _k3456 = f[3..6], _k5678 = f[5..8].
        // _k6789 rotates _k5678 left by one lane so that lanes 0..2 hold
        // f[6..8]; its lane 3 wraps around to f[5] but is never used below.
        MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(k0);
        MEGDNN_SIMD_TYPE _k3456 = MEGDNN_SIMD_LOADU(k1);
        MEGDNN_SIMD_TYPE _k5678 = MEGDNN_SIMD_LOADU(k2);
        MEGDNN_SIMD_TYPE _k6789 = MEGDNN_SIMD_EXT(_k5678, _k5678, 1);
        rep(h, OH) {
            int nn = OW >> 2;  // 4 output columns per iteration
            rep(i, nn) {
                // Accumulate onto the existing output values.
                MEGDNN_SIMD_TYPE _outp = MEGDNN_SIMD_LOADU(outptr);
                // De-interleaving loads give the three stride-2 column
                // streams needed by a 3-wide filter row.
                MEGDNN_SIMD_TYPE2 _r0 = MEGDNN_SIMD_LOAD2(r0);
                MEGDNN_SIMD_TYPE2 _r0n = MEGDNN_SIMD_LOAD2(r0 + 8);
                MEGDNN_SIMD_TYPE _r00 = _r0.val[0];  // 0 2 4 6
                MEGDNN_SIMD_TYPE _r01 = _r0.val[1];  // 1 3 5 7
                MEGDNN_SIMD_TYPE _r02 =
                        MEGDNN_SIMD_EXT(_r00, _r0n.val[0], 1);  // 2 4 6 8
                _outp = MEGDNN_SIMD_FMA_LANE(_outp, _r00, _k0123, 0);
                _outp = MEGDNN_SIMD_FMA_LANE(_outp, _r01, _k0123, 1);
                _outp = MEGDNN_SIMD_FMA_LANE(_outp, _r02, _k0123, 2);
                MEGDNN_SIMD_TYPE2 _r1 = MEGDNN_SIMD_LOAD2(r1);
                MEGDNN_SIMD_TYPE2 _r1n = MEGDNN_SIMD_LOAD2(r1 + 8);
                MEGDNN_SIMD_TYPE _r10 = _r1.val[0];
                MEGDNN_SIMD_TYPE _r11 = _r1.val[1];
                MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r1n.val[0], 1);
                // Second filter row: f[3], f[4], f[5] live in lanes 0..2
                // of _k3456.
                _outp = MEGDNN_SIMD_FMA_LANE(_outp, _r10, _k3456, 0);
                _outp = MEGDNN_SIMD_FMA_LANE(_outp, _r11, _k3456, 1);
                _outp = MEGDNN_SIMD_FMA_LANE(_outp, _r12, _k3456, 2);
                MEGDNN_SIMD_TYPE2 _r2 = MEGDNN_SIMD_LOAD2(r2);
                MEGDNN_SIMD_TYPE2 _r2n = MEGDNN_SIMD_LOAD2(r2 + 8);
                MEGDNN_SIMD_TYPE _r20 = _r2.val[0];
                MEGDNN_SIMD_TYPE _r21 = _r2.val[1];
                MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r2n.val[0], 1);
                // Third filter row: f[6], f[7], f[8] in lanes 0..2 of _k6789.
                _outp = MEGDNN_SIMD_FMA_LANE(_outp, _r20, _k6789, 0);
                _outp = MEGDNN_SIMD_FMA_LANE(_outp, _r21, _k6789, 1);
                _outp = MEGDNN_SIMD_FMA_LANE(_outp, _r22, _k6789, 2);
                MEGDNN_SIMD_STOREU(outptr, _outp);
                r0 += 8;  // 8 input columns consumed per 4 outputs (stride 2)
                r1 += 8;
                r2 += 8;
                outptr += 4;
            }
            r0 += tail_step;
            r1 += tail_step;
            r2 += tail_step;
        }
        // Next input channel uses the next 3x3 = 9 filter taps.
        filter += 9;
    }
}
void conv_stride2::do_conv_5x5_stride2(const float* src, const float* filter, float* dst, | |||||
size_t IH, size_t IW, size_t OH, size_t OW, | |||||
size_t IC) { | |||||
const size_t tail_step = IW - 2 * OW + IW; | |||||
rep(ic, IC) { | |||||
const float* src_ptr = src + IW * IH * ic; | |||||
float* outptr = dst; | |||||
const float* r0 = src_ptr; | |||||
const float* r1 = src_ptr + IW; | |||||
const float* r2 = src_ptr + IW * 2; | |||||
const float* r3 = src_ptr + IW * 3; | |||||
const float* r4 = src_ptr + IW * 4; | |||||
MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(filter); | |||||
MEGDNN_SIMD_TYPE _k4567 = MEGDNN_SIMD_LOADU(filter + 4); | |||||
MEGDNN_SIMD_TYPE _k891011 = MEGDNN_SIMD_LOADU(filter + 8); | |||||
MEGDNN_SIMD_TYPE _k12131415 = MEGDNN_SIMD_LOADU(filter + 12); | |||||
MEGDNN_SIMD_TYPE _k16171819 = MEGDNN_SIMD_LOADU(filter + 16); | |||||
MEGDNN_SIMD_TYPE _k20212223 = MEGDNN_SIMD_LOADU(filter + 20); | |||||
MEGDNN_SIMD_TYPE _k24242424 = MEGDNN_SIMD_SET1(filter[24]); | |||||
for (size_t i = 0; i < OH; i++) { | |||||
int nn = OW >> 2; | |||||
rep(i, nn) { | |||||
MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr); | |||||
MEGDNN_SIMD_TYPE2 _r00_02461357 = MEGDNN_SIMD_LOAD2(r0); | |||||
MEGDNN_SIMD_TYPE2 _r00nx2 = MEGDNN_SIMD_LOAD2(r0 + 8); | |||||
MEGDNN_SIMD_TYPE _r0_8101214 = _r00nx2.val[0]; // 8 10 12 14 | |||||
MEGDNN_SIMD_TYPE _r0_9111315 = _r00nx2.val[1]; // 9 11 13 15 | |||||
MEGDNN_SIMD_TYPE _r00 = _r00_02461357.val[0]; // 0 2 4 6 | |||||
MEGDNN_SIMD_TYPE _r01 = _r00_02461357.val[1]; // 1 3 5 7 | |||||
MEGDNN_SIMD_TYPE _r02 = | |||||
MEGDNN_SIMD_EXT(_r00, _r0_8101214, 1); // 2 4 6 8 | |||||
MEGDNN_SIMD_TYPE _r03 = | |||||
MEGDNN_SIMD_EXT(_r01, _r0_9111315, 1); // 3 5 7 9 | |||||
MEGDNN_SIMD_TYPE _r04 = | |||||
MEGDNN_SIMD_EXT(_r00, _r0_8101214, 2); // 4 6 8 10 | |||||
MEGDNN_SIMD_TYPE2 _r10_02461357 = MEGDNN_SIMD_LOAD2(r1); | |||||
MEGDNN_SIMD_TYPE2 _r10nx2 = MEGDNN_SIMD_LOAD2(r1 + 8); | |||||
MEGDNN_SIMD_TYPE _r1_8101214 = _r10nx2.val[0]; | |||||
MEGDNN_SIMD_TYPE _r1_9111315 = _r10nx2.val[1]; | |||||
MEGDNN_SIMD_TYPE _r10 = _r10_02461357.val[0]; | |||||
MEGDNN_SIMD_TYPE _r11 = _r10_02461357.val[1]; | |||||
MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r1_8101214, 1); | |||||
MEGDNN_SIMD_TYPE _r13 = MEGDNN_SIMD_EXT(_r11, _r1_9111315, 1); | |||||
MEGDNN_SIMD_TYPE _r14 = MEGDNN_SIMD_EXT(_r10, _r1_8101214, 2); | |||||
MEGDNN_SIMD_TYPE2 _r20_02461357 = MEGDNN_SIMD_LOAD2(r2); | |||||
MEGDNN_SIMD_TYPE2 _r20nx2 = MEGDNN_SIMD_LOAD2(r2 + 8); | |||||
MEGDNN_SIMD_TYPE _r2_8101214 = _r20nx2.val[0]; | |||||
MEGDNN_SIMD_TYPE _r2_9111315 = _r20nx2.val[1]; | |||||
MEGDNN_SIMD_TYPE _r20 = _r20_02461357.val[0]; | |||||
MEGDNN_SIMD_TYPE _r21 = _r20_02461357.val[1]; | |||||
MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r2_8101214, 1); | |||||
MEGDNN_SIMD_TYPE _r23 = MEGDNN_SIMD_EXT(_r21, _r2_9111315, 1); | |||||
MEGDNN_SIMD_TYPE _r24 = MEGDNN_SIMD_EXT(_r20, _r2_8101214, 2); | |||||
MEGDNN_SIMD_TYPE2 _r30_02461357 = MEGDNN_SIMD_LOAD2(r3); | |||||
MEGDNN_SIMD_TYPE2 _r30nx2 = MEGDNN_SIMD_LOAD2(r3 + 8); | |||||
MEGDNN_SIMD_TYPE _r3_8101214 = _r30nx2.val[0]; | |||||
MEGDNN_SIMD_TYPE _r3_9111315 = _r30nx2.val[1]; | |||||
MEGDNN_SIMD_TYPE _r30 = _r30_02461357.val[0]; | |||||
MEGDNN_SIMD_TYPE _r31 = _r30_02461357.val[1]; | |||||
MEGDNN_SIMD_TYPE _r32 = MEGDNN_SIMD_EXT(_r30, _r3_8101214, 1); | |||||
MEGDNN_SIMD_TYPE _r33 = MEGDNN_SIMD_EXT(_r31, _r3_9111315, 1); | |||||
MEGDNN_SIMD_TYPE _r34 = MEGDNN_SIMD_EXT(_r30, _r3_8101214, 2); | |||||
MEGDNN_SIMD_TYPE2 _r40_02461357 = MEGDNN_SIMD_LOAD2(r4); | |||||
MEGDNN_SIMD_TYPE2 _r40nx2 = MEGDNN_SIMD_LOAD2(r4 + 8); | |||||
MEGDNN_SIMD_TYPE _r4_8101214 = _r40nx2.val[0]; | |||||
MEGDNN_SIMD_TYPE _r4_9111315 = _r40nx2.val[1]; | |||||
MEGDNN_SIMD_TYPE _r40 = _r40_02461357.val[0]; | |||||
MEGDNN_SIMD_TYPE _r41 = _r40_02461357.val[1]; | |||||
MEGDNN_SIMD_TYPE _r42 = MEGDNN_SIMD_EXT(_r40, _r4_8101214, 1); | |||||
MEGDNN_SIMD_TYPE _r43 = MEGDNN_SIMD_EXT(_r41, _r4_9111315, 1); | |||||
MEGDNN_SIMD_TYPE _r44 = MEGDNN_SIMD_EXT(_r40, _r4_8101214, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r00, _k0123, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r01, _k0123, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r02, _k0123, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r03, _k0123, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r04, _k4567, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r10, _k4567, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r11, _k4567, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r12, _k4567, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r13, _k891011, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r14, _k891011, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r20, _k891011, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r21, _k891011, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r22, _k12131415, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r23, _k12131415, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r24, _k12131415, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r30, _k12131415, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r31, _k16171819, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r32, _k16171819, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r33, _k16171819, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r34, _k16171819, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r40, _k20212223, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r41, _k20212223, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r42, _k20212223, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r43, _k20212223, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r44, _k24242424, 0); | |||||
MEGDNN_SIMD_STOREU(outptr, _sum); | |||||
r0 += 8; | |||||
r1 += 8; | |||||
r2 += 8; | |||||
r3 += 8; | |||||
r4 += 8; | |||||
outptr += 4; | |||||
} | |||||
r0 += tail_step; | |||||
r1 += tail_step; | |||||
r2 += tail_step; | |||||
r3 += tail_step; | |||||
r4 += tail_step; | |||||
} | |||||
filter += 25; | |||||
} | |||||
} | |||||
void conv_stride2::do_conv_7x7_stride2(const float* src, const float* filter, float* dst, | |||||
size_t IH, size_t IW, size_t OH, size_t OW, | |||||
size_t IC) { | |||||
const size_t tail_step = IW - 2 * OW + IW; | |||||
rep(ic, IC) { | |||||
const float* src_ptr = src + IW * IH * ic; | |||||
float* outptr = dst; | |||||
const float* r0 = src_ptr; | |||||
const float* r1 = src_ptr + IW; | |||||
const float* r2 = src_ptr + IW * 2; | |||||
const float* r3 = src_ptr + IW * 3; | |||||
const float* r4 = src_ptr + IW * 4; | |||||
const float* r5 = src_ptr + IW * 5; | |||||
const float* r6 = src_ptr + IW * 6; | |||||
const float* k0 = filter; | |||||
const float* k1 = filter + 7; | |||||
const float* k2 = filter + 14; | |||||
const float* k3 = filter + 21; | |||||
const float* k4 = filter + 28; | |||||
const float* k5 = filter + 35; | |||||
const float* k6 = filter + 42; | |||||
for (size_t i = 0; i < OH; i++) { | |||||
int nn = OW >> 2; | |||||
rep(i, nn) { | |||||
MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr); | |||||
MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(k0); | |||||
MEGDNN_SIMD_TYPE _k4567 = MEGDNN_SIMD_LOADU(k0 + 4); | |||||
MEGDNN_SIMD_TYPE2 _r00_02461357 = MEGDNN_SIMD_LOAD2(r0); | |||||
MEGDNN_SIMD_TYPE2 _r00nx2 = MEGDNN_SIMD_LOAD2(r0 + 8); | |||||
MEGDNN_SIMD_TYPE _r0_8101214 = _r00nx2.val[0]; // 8 10 12 14 | |||||
MEGDNN_SIMD_TYPE _r0_9111315 = _r00nx2.val[1]; // 9 11 13 15 | |||||
MEGDNN_SIMD_TYPE _r00 = _r00_02461357.val[0]; // 0 2 4 6 | |||||
MEGDNN_SIMD_TYPE _r01 = _r00_02461357.val[1]; // 1 3 5 7 | |||||
MEGDNN_SIMD_TYPE _r02 = | |||||
MEGDNN_SIMD_EXT(_r00, _r0_8101214, 1); // 2 4 6 8 | |||||
MEGDNN_SIMD_TYPE _r03 = | |||||
MEGDNN_SIMD_EXT(_r01, _r0_9111315, 1); // 3 5 7 9 | |||||
MEGDNN_SIMD_TYPE _r04 = | |||||
MEGDNN_SIMD_EXT(_r00, _r0_8101214, 2); // 4 6 8 10 | |||||
MEGDNN_SIMD_TYPE _r05 = | |||||
MEGDNN_SIMD_EXT(_r01, _r0_9111315, 2); // 5 7 9 11 | |||||
MEGDNN_SIMD_TYPE _r06 = | |||||
MEGDNN_SIMD_EXT(_r00, _r0_8101214, 3); // 6 8 10 12 | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r00, _k0123, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r01, _k0123, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r02, _k0123, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r03, _k0123, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r04, _k4567, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r05, _k4567, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r06, _k4567, 2); | |||||
MEGDNN_SIMD_TYPE _k78910 = MEGDNN_SIMD_LOADU(k1); | |||||
MEGDNN_SIMD_TYPE _k11121314 = MEGDNN_SIMD_LOADU(k1 + 4); | |||||
MEGDNN_SIMD_TYPE2 _r10_02461357 = MEGDNN_SIMD_LOAD2(r1); | |||||
MEGDNN_SIMD_TYPE2 _r10nx2 = MEGDNN_SIMD_LOAD2(r1 + 8); | |||||
MEGDNN_SIMD_TYPE _r1_8101214 = _r10nx2.val[0]; | |||||
MEGDNN_SIMD_TYPE _r1_9111315 = _r10nx2.val[1]; | |||||
MEGDNN_SIMD_TYPE _r10 = _r10_02461357.val[0]; | |||||
MEGDNN_SIMD_TYPE _r11 = _r10_02461357.val[1]; | |||||
MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r1_8101214, 1); | |||||
MEGDNN_SIMD_TYPE _r13 = MEGDNN_SIMD_EXT(_r11, _r1_9111315, 1); | |||||
MEGDNN_SIMD_TYPE _r14 = MEGDNN_SIMD_EXT(_r10, _r1_8101214, 2); | |||||
MEGDNN_SIMD_TYPE _r15 = MEGDNN_SIMD_EXT(_r11, _r1_9111315, 2); | |||||
MEGDNN_SIMD_TYPE _r16 = MEGDNN_SIMD_EXT(_r10, _r1_8101214, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r10, _k78910, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r11, _k78910, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r12, _k78910, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r13, _k78910, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r14, _k11121314, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r15, _k11121314, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r16, _k11121314, 2); | |||||
MEGDNN_SIMD_TYPE _k14151617 = MEGDNN_SIMD_LOADU(k2); | |||||
MEGDNN_SIMD_TYPE _k18192021 = MEGDNN_SIMD_LOADU(k2 + 4); | |||||
MEGDNN_SIMD_TYPE2 _r20_02461357 = MEGDNN_SIMD_LOAD2(r2); | |||||
MEGDNN_SIMD_TYPE2 _r20nx2 = MEGDNN_SIMD_LOAD2(r2 + 8); | |||||
MEGDNN_SIMD_TYPE _r2_8101214 = _r20nx2.val[0]; | |||||
MEGDNN_SIMD_TYPE _r2_9111315 = _r20nx2.val[1]; | |||||
MEGDNN_SIMD_TYPE _r20 = _r20_02461357.val[0]; | |||||
MEGDNN_SIMD_TYPE _r21 = _r20_02461357.val[1]; | |||||
MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r2_8101214, 1); | |||||
MEGDNN_SIMD_TYPE _r23 = MEGDNN_SIMD_EXT(_r21, _r2_9111315, 1); | |||||
MEGDNN_SIMD_TYPE _r24 = MEGDNN_SIMD_EXT(_r20, _r2_8101214, 2); | |||||
MEGDNN_SIMD_TYPE _r25 = MEGDNN_SIMD_EXT(_r21, _r2_9111315, 2); | |||||
MEGDNN_SIMD_TYPE _r26 = MEGDNN_SIMD_EXT(_r20, _r2_8101214, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r20, _k14151617, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r21, _k14151617, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r22, _k14151617, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r23, _k14151617, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r24, _k18192021, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r25, _k18192021, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r26, _k18192021, 2); | |||||
MEGDNN_SIMD_TYPE _k21222324 = MEGDNN_SIMD_LOADU(k3); | |||||
MEGDNN_SIMD_TYPE _k25262728 = MEGDNN_SIMD_LOADU(k3 + 4); | |||||
MEGDNN_SIMD_TYPE2 _r30_02461357 = MEGDNN_SIMD_LOAD2(r3); | |||||
MEGDNN_SIMD_TYPE2 _r30nx2 = MEGDNN_SIMD_LOAD2(r3 + 8); | |||||
MEGDNN_SIMD_TYPE _r3_8101214 = _r30nx2.val[0]; | |||||
MEGDNN_SIMD_TYPE _r3_9111315 = _r30nx2.val[1]; | |||||
MEGDNN_SIMD_TYPE _r30 = _r30_02461357.val[0]; | |||||
MEGDNN_SIMD_TYPE _r31 = _r30_02461357.val[1]; | |||||
MEGDNN_SIMD_TYPE _r32 = MEGDNN_SIMD_EXT(_r30, _r3_8101214, 1); | |||||
MEGDNN_SIMD_TYPE _r33 = MEGDNN_SIMD_EXT(_r31, _r3_9111315, 1); | |||||
MEGDNN_SIMD_TYPE _r34 = MEGDNN_SIMD_EXT(_r30, _r3_8101214, 2); | |||||
MEGDNN_SIMD_TYPE _r35 = MEGDNN_SIMD_EXT(_r31, _r3_9111315, 2); | |||||
MEGDNN_SIMD_TYPE _r36 = MEGDNN_SIMD_EXT(_r30, _r3_8101214, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r30, _k21222324, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r31, _k21222324, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r32, _k21222324, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r33, _k21222324, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r34, _k25262728, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r35, _k25262728, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r36, _k25262728, 2); | |||||
MEGDNN_SIMD_TYPE _k28293031 = MEGDNN_SIMD_LOADU(k4); | |||||
MEGDNN_SIMD_TYPE _k32333435 = MEGDNN_SIMD_LOADU(k4 + 4); | |||||
MEGDNN_SIMD_TYPE2 _r40_02461357 = MEGDNN_SIMD_LOAD2(r4); | |||||
MEGDNN_SIMD_TYPE2 _r40nx2 = MEGDNN_SIMD_LOAD2(r4 + 8); | |||||
MEGDNN_SIMD_TYPE _r4_8101214 = _r40nx2.val[0]; | |||||
MEGDNN_SIMD_TYPE _r4_9111315 = _r40nx2.val[1]; | |||||
MEGDNN_SIMD_TYPE _r40 = _r40_02461357.val[0]; | |||||
MEGDNN_SIMD_TYPE _r41 = _r40_02461357.val[1]; | |||||
MEGDNN_SIMD_TYPE _r42 = MEGDNN_SIMD_EXT(_r40, _r4_8101214, 1); | |||||
MEGDNN_SIMD_TYPE _r43 = MEGDNN_SIMD_EXT(_r41, _r4_9111315, 1); | |||||
MEGDNN_SIMD_TYPE _r44 = MEGDNN_SIMD_EXT(_r40, _r4_8101214, 2); | |||||
MEGDNN_SIMD_TYPE _r45 = MEGDNN_SIMD_EXT(_r41, _r4_9111315, 2); | |||||
MEGDNN_SIMD_TYPE _r46 = MEGDNN_SIMD_EXT(_r40, _r4_8101214, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r40, _k28293031, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r41, _k28293031, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r42, _k28293031, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r43, _k28293031, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r44, _k32333435, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r45, _k32333435, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r46, _k32333435, 2); | |||||
MEGDNN_SIMD_TYPE _k35363738 = MEGDNN_SIMD_LOADU(k5); | |||||
MEGDNN_SIMD_TYPE _k39404142 = MEGDNN_SIMD_LOADU(k5 + 4); | |||||
MEGDNN_SIMD_TYPE2 _r50_02461357 = MEGDNN_SIMD_LOAD2(r5); | |||||
MEGDNN_SIMD_TYPE2 _r50nx2 = MEGDNN_SIMD_LOAD2(r5 + 8); | |||||
MEGDNN_SIMD_TYPE _r5_8101214 = _r50nx2.val[0]; | |||||
MEGDNN_SIMD_TYPE _r5_9111315 = _r50nx2.val[1]; | |||||
MEGDNN_SIMD_TYPE _r50 = _r50_02461357.val[0]; | |||||
MEGDNN_SIMD_TYPE _r51 = _r50_02461357.val[1]; | |||||
MEGDNN_SIMD_TYPE _r52 = MEGDNN_SIMD_EXT(_r50, _r5_8101214, 1); | |||||
MEGDNN_SIMD_TYPE _r53 = MEGDNN_SIMD_EXT(_r51, _r5_9111315, 1); | |||||
MEGDNN_SIMD_TYPE _r54 = MEGDNN_SIMD_EXT(_r50, _r5_8101214, 2); | |||||
MEGDNN_SIMD_TYPE _r55 = MEGDNN_SIMD_EXT(_r51, _r5_9111315, 2); | |||||
MEGDNN_SIMD_TYPE _r56 = MEGDNN_SIMD_EXT(_r50, _r5_8101214, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r50, _k35363738, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r51, _k35363738, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r52, _k35363738, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r53, _k35363738, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r54, _k39404142, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r55, _k39404142, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r56, _k39404142, 2); | |||||
MEGDNN_SIMD_TYPE _k42434445 = MEGDNN_SIMD_LOADU(k6); | |||||
MEGDNN_SIMD_TYPE _k45464748 = MEGDNN_SIMD_LOADU(k6 + 3); | |||||
MEGDNN_SIMD_TYPE2 _r60_02461357 = MEGDNN_SIMD_LOAD2(r6); | |||||
MEGDNN_SIMD_TYPE2 _r60nx2 = MEGDNN_SIMD_LOAD2(r6 + 8); | |||||
MEGDNN_SIMD_TYPE _r6_8101214 = _r60nx2.val[0]; | |||||
MEGDNN_SIMD_TYPE _r6_9111315 = _r60nx2.val[1]; | |||||
MEGDNN_SIMD_TYPE _r60 = _r60_02461357.val[0]; | |||||
MEGDNN_SIMD_TYPE _r61 = _r60_02461357.val[1]; | |||||
MEGDNN_SIMD_TYPE _r62 = MEGDNN_SIMD_EXT(_r60, _r6_8101214, 1); | |||||
MEGDNN_SIMD_TYPE _r63 = MEGDNN_SIMD_EXT(_r61, _r6_9111315, 1); | |||||
MEGDNN_SIMD_TYPE _r64 = MEGDNN_SIMD_EXT(_r60, _r6_8101214, 2); | |||||
MEGDNN_SIMD_TYPE _r65 = MEGDNN_SIMD_EXT(_r61, _r6_9111315, 2); | |||||
MEGDNN_SIMD_TYPE _r66 = MEGDNN_SIMD_EXT(_r60, _r6_8101214, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r60, _k42434445, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r61, _k42434445, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r62, _k42434445, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r63, _k42434445, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r64, _k45464748, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r65, _k45464748, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r66, _k45464748, 3); | |||||
MEGDNN_SIMD_STOREU(outptr, _sum); | |||||
r0 += 8; | |||||
r1 += 8; | |||||
r2 += 8; | |||||
r3 += 8; | |||||
r4 += 8; | |||||
r5 += 8; | |||||
r6 += 8; | |||||
outptr += 4; | |||||
} | |||||
r0 += tail_step; | |||||
r1 += tail_step; | |||||
r2 += tail_step; | |||||
r3 += tail_step; | |||||
r4 += tail_step; | |||||
r5 += tail_step; | |||||
r6 += tail_step; | |||||
} | |||||
filter += 49; | |||||
} | |||||
} | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,32 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/conv_bias/fp32/do_conv_stride2.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#pragma once | |||||
#include "src/fallback/conv_bias/opr_impl.h" | |||||
namespace megdnn {
namespace arm_common {
namespace fp32 {
namespace conv_stride2 {

//! Direct FP32 stride-2 convolution kernels, one routine per filter size
//! (2x2 / 3x3 / 5x5 / 7x7). Each accumulates the contribution of all IC
//! input channels into \p dst for a single output channel; \p dst is
//! read-modify-written, so the caller must pre-initialize it (e.g. with
//! the bias). The implementations process 4 output columns at a time and
//! do not handle an OW % 4 remainder — TODO confirm callers guarantee
//! OW % 4 == 0.

//! 2x2 filter, stride 2.
void do_conv_2x2_stride2(const float* src, const float* filter, float* dst,
                         size_t IH, size_t IW, size_t OH, size_t OW, size_t IC);
//! 3x3 filter, stride 2.
void do_conv_3x3_stride2(const float* src, const float* filter, float* dst,
                         size_t IH, size_t IW, size_t OH, size_t OW, size_t IC);
//! 5x5 filter, stride 2.
void do_conv_5x5_stride2(const float* src, const float* filter, float* dst,
                         size_t IH, size_t IW, size_t OH, size_t OW, size_t IC);
//! 7x7 filter, stride 2.
void do_conv_7x7_stride2(const float* src, const float* filter, float* dst,
                         size_t IH, size_t IW, size_t OH, size_t OW, size_t IC);
}  // namespace conv_stride2
}  // namespace fp32
}  // namespace arm_common
}  // namespace megdnn
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,164 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/conv_bias/fp32/filter_transform.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#pragma once | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/arm_common/utils.h" | |||||
#include "src/common/unroll_macro.h" | |||||
#include "src/common/utils.h" | |||||
#include "src/arm_common/conv_bias/fp32/helper.h" | |||||
#include "megdnn/opr_param_defs.h" | |||||
namespace megdnn { | |||||
namespace arm_common { | |||||
//! Winograd F(6x6, 3x3) filter transform: computes G * g * G^T for each
//! 3x3 filter g (the G matrix is spelled out inside transform()), writing
//! the 8x8 transformed taps into filter_transform_buf.
//! \tparam format output layout: DEFAULT writes [alpha*alpha][IC][OC];
//!         otherwise a 4x4-blocked [alpha*alpha][OCB][ICB][4][4] layout
//!         (NOTE(review): presumably MK4 — confirm against the caller).
template <param::MatrixMul::Format format=param::MatrixMul::Format::DEFAULT>
struct FilterTransform6X3 {
// One pass of multiplication by G (8x3): maps the 3 input rows d0..d2 to
// the 8 output rows wd0..wd7 using the constant entries of G below.
#define FILTER_TRANSFORM(d, wd)                          \
    do {                                                 \
        wd##0 = d##0;                                    \
        auto tmp0 = (d##0 + d##2) * -0.2222222f;         \
        auto tmp1 = d##1 * -0.2222222f;                  \
        wd##1 = tmp0 + tmp1;                             \
        wd##2 = tmp0 - tmp1;                             \
        tmp0 = d##0 * 0.0111111f + d##2 * 0.0444444f;    \
        tmp1 = d##1 * 0.0222222f;                        \
        wd##3 = tmp0 + tmp1;                             \
        wd##4 = tmp0 - tmp1;                             \
        tmp0 = d##0 * 0.7111111f + d##2 * 0.1777778f;    \
        tmp1 = d##1 * 0.3555556f;                        \
        wd##5 = tmp0 + tmp1;                             \
        wd##6 = tmp0 - tmp1;                             \
        wd##7 = d##2;                                    \
    } while (0);
    //! Transform filters for output channels [oc_start, oc_end) across all
    //! IC input channels.
    //! \param filter            raw 3x3 filters, laid out [OC][IC][3][3]
    //! \param filter_transform_buf  destination (layout depends on format)
    //! \param transform_mid_buf scratch, at least alpha*alpha floats
    static void transform(const float* filter, float* filter_transform_buf,
                          float* transform_mid_buf, size_t OC, size_t IC,
                          size_t oc_start, size_t oc_end) {
        // Gg * GT
        // G
        // 1.0000000	 0.0000000	 0.0000000
        // -0.2222222	-0.2222222	-0.2222222
        // -0.2222222	 0.2222222	-0.2222222
        // 0.0111111	 0.0222222	 0.0444444
        // 0.0111111	-0.0222222	 0.0444444
        // 0.7111111	 0.3555556	 0.1777778
        // 0.7111111	-0.3555556	 0.1777778
        // 0.0000000	 0.0000000	 1.0000000
        constexpr size_t alpha = 6 + 3 - 1;  // 8: output tile + filter - 1
        size_t OCB = OC / 4;  // channel-block counts for the blocked layout
        size_t ICB = IC / 4;
        for (size_t oc = oc_start; oc < oc_end; oc++) {
            rep(ic, IC) {
                const float* fptr = filter + (oc * IC + ic) * 3 * 3;
                Vector<float, 4> g0 = Vector<float, 4>::load(fptr);
                Vector<float, 4> g1 = Vector<float, 4>::load(fptr + 3);
                // Load from fptr + 5 (not + 6) and shift left by one lane:
                // a 4-float load at fptr + 6 would read one element past
                // the 9-tap filter. Result: f[6] f[7] f[8] 0.
                Vector<float, 4> g2 = Vector<float, 4>::load(fptr + 6 - 1);
                float32x4_t zeros = vdupq_n_f32(0.0f);
                g2.value = vextq_f32(g2.value, zeros, 1);
#define cb(i) Vector<float, 4> wd##i;
                UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
#define cb(i) Vector<float, 8> wdt##i;
                UNROLL_CALL_NOWRAPPER(3, cb);
#undef cb
#define cb(i) Vector<float, 8> ret##i;
                UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
                // wd = G * g (8x4; lane 3 of each wd row is the zero pad).
                FILTER_TRANSFORM(g, wd);
                size_t ocb = oc / 4;
                size_t oc4 = oc % 4;
                size_t icb = ic / 4;
                size_t ic4 = ic % 4;
#if MEGDNN_AARCH64
                // aarch64: transpose wd and apply G again vectorized,
                // giving ret = (G * (G*g)^T), i.e. (G*g*G^T)^T.
                TRANSPOSE_8x3(wd, wdt);
                FILTER_TRANSFORM(wdt, ret);
#define cb(i) ret##i.save(transform_mid_buf + i * alpha);
                UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
                // Scatter with indices swapped (mid_buf[j*alpha + i]) to
                // undo the transpose introduced above.
                rep(i, alpha) rep(j, alpha) {
                    if (format == param::MatrixMul::Format::DEFAULT) {
                        filter_transform_buf[(i * alpha + j) * OC * IC +
                                             ic * OC + oc] =
                                transform_mid_buf[j * alpha + i];
                    } else {
                        filter_transform_buf[(i * alpha + j) * OCB * ICB * 4 *
                                                     4 +
                                             ocb * ICB * 4 * 4 + icb * 4 * 4 +
                                             ic4 * 4 + oc4] =
                                transform_mid_buf[j * alpha + i];
                    }
                }
#else
// armv7 fallback: multiply each row wd_i by G^T scalar-wise via lane
// extraction, writing one 8-float row of the mid buffer per invocation.
// (cb expands at the UNROLL_CALL below, after GET_VECTOR_ELEM is defined.)
#define cb(i)                                                                \
    do {                                                                     \
        mid_buf1[0] = GET_VECTOR_ELEM(wd, i, 0);                             \
        auto tmp0 = (GET_VECTOR_ELEM(wd, i, 0) + GET_VECTOR_ELEM(wd, i, 2)) * \
                    -0.2222222f;                                             \
        auto tmp1 = GET_VECTOR_ELEM(wd, i, 1) * -0.2222222f;                 \
        mid_buf1[1] = tmp0 + tmp1;                                           \
        mid_buf1[2] = tmp0 - tmp1;                                           \
        tmp0 = GET_VECTOR_ELEM(wd, i, 0) * 0.0111111f +                      \
               GET_VECTOR_ELEM(wd, i, 2) * 0.0444444f;                       \
        tmp1 = GET_VECTOR_ELEM(wd, i, 1) * 0.0222222f;                       \
        mid_buf1[3] = tmp0 + tmp1;                                           \
        mid_buf1[4] = tmp0 - tmp1;                                           \
        tmp0 = GET_VECTOR_ELEM(wd, i, 0) * 0.7111111f +                      \
               GET_VECTOR_ELEM(wd, i, 2) * 0.1777778f;                       \
        tmp1 = GET_VECTOR_ELEM(wd, i, 1) * 0.3555556f;                       \
        mid_buf1[5] = tmp0 + tmp1;                                           \
        mid_buf1[6] = tmp0 - tmp1;                                           \
        mid_buf1[7] = GET_VECTOR_ELEM(wd, i, 2);                             \
        mid_buf1 += 8;                                                       \
    } while (0);
#define GET_VECTOR_ELEM(s, i, idx) vgetq_lane_f32(CONCAT(s, i).value, idx)
                float* mid_buf1 = transform_mid_buf;
                UNROLL_CALL_NOWRAPPER(8, cb);
                mid_buf1 = transform_mid_buf;
#undef cb
                // No transpose on this path, so indices are not swapped:
                // mid_buf[i*alpha + j] is already (G*g*G^T)[i][j].
                rep(i, alpha) rep(j, alpha) {
                    if (format == param::MatrixMul::Format::DEFAULT) {
                        filter_transform_buf[(i * alpha + j) * OC * IC +
                                             ic * OC + oc] =
                                transform_mid_buf[i * alpha + j];
                    } else {
                        filter_transform_buf[(i * alpha + j) * OCB * ICB * 4 *
                                                     4 +
                                             ocb * ICB * 4 * 4 + icb * 4 * 4 +
                                             ic4 * 4 + oc4] =
                                transform_mid_buf[i * alpha + j];
                    }
                }
#endif
            }
        }
    }
};
#undef FILTER_TRANSFORM
#undef GET_VECTOR_ELEM
} // namespace arm_common | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,204 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/conv_bias/fp32/helper.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#pragma once | |||||
#include "src/common/unroll_macro.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
namespace megdnn { | |||||
namespace arm_common { | |||||
inline void transpose_4x4(const float* src, float* dst, int lda, int ldb) { | |||||
float32x4x2_t a0, a1; | |||||
a0.val[0] = vld1q_f32(src + 0 * lda); | |||||
a0.val[1] = vld1q_f32(src + 1 * lda); | |||||
a1.val[0] = vld1q_f32(src + 2 * lda); | |||||
a1.val[1] = vld1q_f32(src + 3 * lda); | |||||
float32x4x2_t b0 = vzipq_f32(a0.val[0], a1.val[0]); | |||||
float32x4x2_t b1 = vzipq_f32(a0.val[1], a1.val[1]); | |||||
float32x4x2_t c0 = vzipq_f32(b0.val[0], b1.val[0]); | |||||
float32x4x2_t c1 = vzipq_f32(b0.val[1], b1.val[1]); | |||||
vst1q_f32(dst + 0 * ldb, c0.val[0]); | |||||
vst1q_f32(dst + 1 * ldb, c0.val[1]); | |||||
vst1q_f32(dst + 2 * ldb, c1.val[0]); | |||||
vst1q_f32(dst + 3 * ldb, c1.val[1]); | |||||
} | |||||
} // namespace arm_common | |||||
} // namespace megdnn | |||||
// Accumulating 4x4 matrix product on NEON registers:
//   sum##r += a##r[lane c] * b##c   for r, c in 0..3
// i.e. each result row r accumulates the four rows of `b` weighted by the
// four lanes of a##r. vmlaq_low_lane_f32 / vmlaq_high_lane_f32 are the
// lane-FMA wrappers presumably provided by marm_neon.h (included above).
#define MATRIX_MUL4x4(sum, a, b)                        \
    sum##0 = vmlaq_low_lane_f32(sum##0, b##0, a##0, 0); \
    sum##0 = vmlaq_low_lane_f32(sum##0, b##1, a##0, 1); \
    sum##0 = vmlaq_high_lane_f32(sum##0, b##2, a##0, 2); \
    sum##0 = vmlaq_high_lane_f32(sum##0, b##3, a##0, 3); \
    sum##1 = vmlaq_low_lane_f32(sum##1, b##0, a##1, 0); \
    sum##1 = vmlaq_low_lane_f32(sum##1, b##1, a##1, 1); \
    sum##1 = vmlaq_high_lane_f32(sum##1, b##2, a##1, 2); \
    sum##1 = vmlaq_high_lane_f32(sum##1, b##3, a##1, 3); \
    sum##2 = vmlaq_low_lane_f32(sum##2, b##0, a##2, 0); \
    sum##2 = vmlaq_low_lane_f32(sum##2, b##1, a##2, 1); \
    sum##2 = vmlaq_high_lane_f32(sum##2, b##2, a##2, 2); \
    sum##2 = vmlaq_high_lane_f32(sum##2, b##3, a##2, 3); \
    sum##3 = vmlaq_low_lane_f32(sum##3, b##0, a##3, 0); \
    sum##3 = vmlaq_low_lane_f32(sum##3, b##1, a##3, 1); \
    sum##3 = vmlaq_high_lane_f32(sum##3, b##2, a##3, 2); \
    sum##3 = vmlaq_high_lane_f32(sum##3, b##3, a##3, 3);
// Token-pasting helper used by the transpose macros below.
#define CONCAT(a, idx) a##idx
#if MEGDNN_AARCH64
//! ret and a are type Vector<float, 8>
//! Transpose the 8x8 float matrix held in a0..a7 (rows) into ret0..ret7.
//! Stage 1 interleaves 32-bit lanes of row pairs with vzipq_f32; stage 2
//! merges 64-bit halves with vzip1q_s64 / vzip2q_s64 through reinterpret
//! casts — those s64 zip intrinsics exist only on aarch64, hence the guard.
#define TRANSPOSE_8x8(a, ret)                                     \
    do {                                                          \
        auto b0 = vzipq_f32(CONCAT(a, 0).value.val[0],            \
                            CONCAT(a, 1).value.val[0]);           \
        auto b1 = vzipq_f32(CONCAT(a, 0).value.val[1],            \
                            CONCAT(a, 1).value.val[1]);           \
        auto b2 = vzipq_f32(CONCAT(a, 2).value.val[0],            \
                            CONCAT(a, 3).value.val[0]);           \
        auto b3 = vzipq_f32(CONCAT(a, 2).value.val[1],            \
                            CONCAT(a, 3).value.val[1]);           \
        auto b4 = vzipq_f32(CONCAT(a, 4).value.val[0],            \
                            CONCAT(a, 5).value.val[0]);           \
        auto b5 = vzipq_f32(CONCAT(a, 4).value.val[1],            \
                            CONCAT(a, 5).value.val[1]);           \
        auto b6 = vzipq_f32(CONCAT(a, 6).value.val[0],            \
                            CONCAT(a, 7).value.val[0]);           \
        auto b7 = vzipq_f32(CONCAT(a, 6).value.val[1],            \
                            CONCAT(a, 7).value.val[1]);           \
        CONCAT(ret, 0).value.val[0] = vreinterpretq_f32_s64(      \
                vzip1q_s64(vreinterpretq_s64_f32(b0.val[0]),      \
                           vreinterpretq_s64_f32(b2.val[0])));    \
        CONCAT(ret, 0).value.val[1] = vreinterpretq_f32_s64(      \
                vzip1q_s64(vreinterpretq_s64_f32(b4.val[0]),      \
                           vreinterpretq_s64_f32(b6.val[0])));    \
        CONCAT(ret, 1).value.val[0] = vreinterpretq_f32_s64(      \
                vzip2q_s64(vreinterpretq_s64_f32(b0.val[0]),      \
                           vreinterpretq_s64_f32(b2.val[0])));    \
        CONCAT(ret, 1).value.val[1] = vreinterpretq_f32_s64(      \
                vzip2q_s64(vreinterpretq_s64_f32(b4.val[0]),      \
                           vreinterpretq_s64_f32(b6.val[0])));    \
        CONCAT(ret, 2).value.val[0] = vreinterpretq_f32_s64(      \
                vzip1q_s64(vreinterpretq_s64_f32(b0.val[1]),      \
                           vreinterpretq_s64_f32(b2.val[1])));    \
        CONCAT(ret, 2).value.val[1] = vreinterpretq_f32_s64(      \
                vzip1q_s64(vreinterpretq_s64_f32(b4.val[1]),      \
                           vreinterpretq_s64_f32(b6.val[1])));    \
        CONCAT(ret, 3).value.val[0] = vreinterpretq_f32_s64(      \
                vzip2q_s64(vreinterpretq_s64_f32(b0.val[1]),      \
                           vreinterpretq_s64_f32(b2.val[1])));    \
        CONCAT(ret, 3).value.val[1] = vreinterpretq_f32_s64(      \
                vzip2q_s64(vreinterpretq_s64_f32(b4.val[1]),      \
                           vreinterpretq_s64_f32(b6.val[1])));    \
        CONCAT(ret, 4).value.val[0] = vreinterpretq_f32_s64(      \
                vzip1q_s64(vreinterpretq_s64_f32(b1.val[0]),      \
                           vreinterpretq_s64_f32(b3.val[0])));    \
        CONCAT(ret, 4).value.val[1] = vreinterpretq_f32_s64(      \
                vzip1q_s64(vreinterpretq_s64_f32(b5.val[0]),      \
                           vreinterpretq_s64_f32(b7.val[0])));    \
        CONCAT(ret, 5).value.val[0] = vreinterpretq_f32_s64(      \
                vzip2q_s64(vreinterpretq_s64_f32(b1.val[0]),      \
                           vreinterpretq_s64_f32(b3.val[0])));    \
        CONCAT(ret, 5).value.val[1] = vreinterpretq_f32_s64(      \
                vzip2q_s64(vreinterpretq_s64_f32(b5.val[0]),      \
                           vreinterpretq_s64_f32(b7.val[0])));    \
        CONCAT(ret, 6).value.val[0] = vreinterpretq_f32_s64(      \
                vzip1q_s64(vreinterpretq_s64_f32(b1.val[1]),      \
                           vreinterpretq_s64_f32(b3.val[1])));    \
        CONCAT(ret, 6).value.val[1] = vreinterpretq_f32_s64(      \
                vzip1q_s64(vreinterpretq_s64_f32(b5.val[1]),      \
                           vreinterpretq_s64_f32(b7.val[1])));    \
        CONCAT(ret, 7).value.val[0] = vreinterpretq_f32_s64(      \
                vzip2q_s64(vreinterpretq_s64_f32(b1.val[1]),      \
                           vreinterpretq_s64_f32(b3.val[1])));    \
        CONCAT(ret, 7).value.val[1] = vreinterpretq_f32_s64(      \
                vzip2q_s64(vreinterpretq_s64_f32(b5.val[1]),      \
                           vreinterpretq_s64_f32(b7.val[1])));    \
    } while (0);
//! Transpose an 8x4 matrix whose rows are a0..a7 (each a Vector<float, 4>,
//! .value is float32x4_t), producing only the first 3 rows of the 4x8
//! result in ret0..ret2 (each a Vector<float, 8>).  aarch64-only
//! (vzip1q_s64 / vzip2q_s64).  Not wrapped in do/while: it declares b0..b3
//! in the enclosing scope, so use at most once per scope.
#define TRANSPOSE_8x3(a, ret)                                        \
    auto b0 = vzipq_f32(CONCAT(a, 0).value, CONCAT(a, 1).value);     \
    auto b1 = vzipq_f32(CONCAT(a, 2).value, CONCAT(a, 3).value);     \
    auto b2 = vzipq_f32(CONCAT(a, 4).value, CONCAT(a, 5).value);     \
    auto b3 = vzipq_f32(CONCAT(a, 6).value, CONCAT(a, 7).value);     \
    CONCAT(ret, 0).value.val[0] = vreinterpretq_f32_s64(             \
            vzip1q_s64(vreinterpretq_s64_f32(b0.val[0]),             \
                       vreinterpretq_s64_f32(b1.val[0])));           \
    CONCAT(ret, 0).value.val[1] = vreinterpretq_f32_s64(             \
            vzip1q_s64(vreinterpretq_s64_f32(b2.val[0]),             \
                       vreinterpretq_s64_f32(b3.val[0])));           \
    CONCAT(ret, 1).value.val[0] = vreinterpretq_f32_s64(             \
            vzip2q_s64(vreinterpretq_s64_f32(b0.val[0]),             \
                       vreinterpretq_s64_f32(b1.val[0])));           \
    CONCAT(ret, 1).value.val[1] = vreinterpretq_f32_s64(             \
            vzip2q_s64(vreinterpretq_s64_f32(b2.val[0]),             \
                       vreinterpretq_s64_f32(b3.val[0])));           \
    CONCAT(ret, 2).value.val[0] = vreinterpretq_f32_s64(             \
            vzip1q_s64(vreinterpretq_s64_f32(b0.val[1]),             \
                       vreinterpretq_s64_f32(b1.val[1])));           \
    CONCAT(ret, 2).value.val[1] = vreinterpretq_f32_s64(             \
            vzip1q_s64(vreinterpretq_s64_f32(b2.val[1]),             \
                       vreinterpretq_s64_f32(b3.val[1])));
//! Transpose an 8x4 matrix whose rows are a0..a7 (each a Vector<float, 4>)
//! into the full 4x8 result ret0..ret3 (each a Vector<float, 8>).
//! aarch64 variant using vzip1q_s64 / vzip2q_s64 on reinterpreted 64-bit
//! lanes.  Declares b0..b3 in the enclosing scope — use once per scope.
#define TRANSPOSE_8x4(a, ret)                                        \
    auto b0 = vzipq_f32(CONCAT(a, 0).value, CONCAT(a, 1).value);     \
    auto b1 = vzipq_f32(CONCAT(a, 2).value, CONCAT(a, 3).value);     \
    auto b2 = vzipq_f32(CONCAT(a, 4).value, CONCAT(a, 5).value);     \
    auto b3 = vzipq_f32(CONCAT(a, 6).value, CONCAT(a, 7).value);     \
    CONCAT(ret, 0).value.val[0] = vreinterpretq_f32_s64(             \
            vzip1q_s64(vreinterpretq_s64_f32(b0.val[0]),             \
                       vreinterpretq_s64_f32(b1.val[0])));           \
    CONCAT(ret, 0).value.val[1] = vreinterpretq_f32_s64(             \
            vzip1q_s64(vreinterpretq_s64_f32(b2.val[0]),             \
                       vreinterpretq_s64_f32(b3.val[0])));           \
    CONCAT(ret, 1).value.val[0] = vreinterpretq_f32_s64(             \
            vzip2q_s64(vreinterpretq_s64_f32(b0.val[0]),             \
                       vreinterpretq_s64_f32(b1.val[0])));           \
    CONCAT(ret, 1).value.val[1] = vreinterpretq_f32_s64(             \
            vzip2q_s64(vreinterpretq_s64_f32(b2.val[0]),             \
                       vreinterpretq_s64_f32(b3.val[0])));           \
    CONCAT(ret, 2).value.val[0] = vreinterpretq_f32_s64(             \
            vzip1q_s64(vreinterpretq_s64_f32(b0.val[1]),             \
                       vreinterpretq_s64_f32(b1.val[1])));           \
    CONCAT(ret, 2).value.val[1] = vreinterpretq_f32_s64(             \
            vzip1q_s64(vreinterpretq_s64_f32(b2.val[1]),             \
                       vreinterpretq_s64_f32(b3.val[1])));           \
    CONCAT(ret, 3).value.val[0] = vreinterpretq_f32_s64(             \
            vzip2q_s64(vreinterpretq_s64_f32(b0.val[1]),             \
                       vreinterpretq_s64_f32(b1.val[1])));           \
    CONCAT(ret, 3).value.val[1] = vreinterpretq_f32_s64(             \
            vzip2q_s64(vreinterpretq_s64_f32(b2.val[1]),             \
                       vreinterpretq_s64_f32(b3.val[1])));
#elif MEGDNN_ARMV7
//! armv7 variant of TRANSPOSE_8x4: same 8x4 -> 4x8 transpose as the
//! aarch64 version above, but built from vcombine_f32 + vget_low_f32 /
//! vget_high_f32 because the vzip{1,2}q_s64 intrinsics are unavailable
//! on armv7.  Declares b0..b3 in the enclosing scope — use once per scope.
#define TRANSPOSE_8x4(a, ret)                                               \
    auto b0 = vzipq_f32(CONCAT(a, 0).value, CONCAT(a, 1).value);            \
    auto b1 = vzipq_f32(CONCAT(a, 2).value, CONCAT(a, 3).value);            \
    auto b2 = vzipq_f32(CONCAT(a, 4).value, CONCAT(a, 5).value);            \
    auto b3 = vzipq_f32(CONCAT(a, 6).value, CONCAT(a, 7).value);            \
    CONCAT(ret, 0).value.val[0] =                                           \
            vcombine_f32(vget_low_f32(b0.val[0]), vget_low_f32(b1.val[0])); \
    CONCAT(ret, 1).value.val[0] =                                           \
            vcombine_f32(vget_high_f32(b0.val[0]), vget_high_f32(b1.val[0])); \
    CONCAT(ret, 2).value.val[0] =                                           \
            vcombine_f32(vget_low_f32(b0.val[1]), vget_low_f32(b1.val[1])); \
    CONCAT(ret, 3).value.val[0] =                                           \
            vcombine_f32(vget_high_f32(b0.val[1]), vget_high_f32(b1.val[1])); \
    CONCAT(ret, 0).value.val[1] =                                           \
            vcombine_f32(vget_low_f32(b2.val[0]), vget_low_f32(b3.val[0])); \
    CONCAT(ret, 1).value.val[1] =                                           \
            vcombine_f32(vget_high_f32(b2.val[0]), vget_high_f32(b3.val[0])); \
    CONCAT(ret, 2).value.val[1] =                                           \
            vcombine_f32(vget_low_f32(b2.val[1]), vget_low_f32(b3.val[1])); \
    CONCAT(ret, 3).value.val[1] =                                           \
            vcombine_f32(vget_high_f32(b2.val[1]), vget_high_f32(b3.val[1]));
#endif
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,39 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/conv_bias/fp32/strategy.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#pragma once | |||||
#include "src/arm_common/conv_bias/postprocess_helper.h" | |||||
#include "src/fallback/conv_bias/winograd/winograd.h" | |||||
namespace megdnn {
namespace arm_common {
namespace winograd {
//! fp32 Winograd strategy declarations.  Naming follows
//! winograd_<m>x<r>_<p>x<p>_f where m = output block size, r = kernel size
//! (F(m, r)), and p x p is the pack size (presumably the data-layout packing
//! factor — confirm against MEGDNN_REG_WINOGRAD_STRATEGY's definition in
//! src/fallback/conv_bias/winograd/winograd.h).
MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 2, 3, 4, 4,
                             winograd_2x3_4x4_f)
MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 6, 3, 1, 1,
                             winograd_6x3_1x1_f)
MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 6, 3, 4, 4,
                             winograd_6x3_4x4_f)
MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 5, 4, 1, 1,
                             winograd_5x4_1x1_f)
MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 4, 5, 1, 1,
                             winograd_4x5_1x1_f)
} // namespace winograd
} // namespace arm_common
} // namespace megdnn
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,346 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/conv_bias/fp32/strategy_2x3_4x4.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "src/arm_common/conv_bias/fp32/strategy.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/arm_common/utils.h" | |||||
#include "src/common/unroll_macro.h" | |||||
#include "src/common/utils.h" | |||||
#include "src/fallback/conv_bias/winograd/winograd.h" | |||||
#include "src/naive/matrix_mul/matrix_mul_helper.h" | |||||
#include "src/arm_common/elemwise_helper/op_unary.h" | |||||
#include "src/arm_common/conv_bias/fp32/helper.h" | |||||
#include "midout.h" | |||||
MIDOUT_DECL(megdnn_arm_common_winograd_fp32_F23) | |||||
using namespace megdnn; | |||||
using namespace arm_common; | |||||
namespace { | |||||
//! Input transform for Winograd F(2,3) with 4-channel packing:
//! gathers a 4x4 (alpha x alpha) input patch for 4 input channels and
//! applies B^T * d * B.
struct InputTransform2X3 {
    //! Copy one 4x4 patch for up to 4 consecutive input channels into
    //! `patch` (channel-major) and transpose to `patchT` (pixel-major,
    //! 4 channels per pixel) for transform().
    //! \tparam inner  true when the patch lies fully inside the image, so
    //!                unguarded vector loads of 4 floats per row are safe.
    //! \param ih_start,iw_start  patch top-left in input coords; may be
    //!                negative (padding) only when inner == false.
    template <bool inner>
    static void prepare(const float* input, float* patch, float* patchT,
                        int ih_start, int iw_start, size_t IH, size_t IW,
                        size_t ic, size_t IC) {
        constexpr size_t alpha = 2 + 3 - 1;
        //! zero scratch unless all 4 channels will be fully overwritten;
        //! `<` (not `<=`) also zeroes when exactly 4 channels remain —
        //! redundant work but harmless.
        if (!(inner && ic + 4 < IC)) {
            memset(patch, 0, sizeof(float) * 4 * alpha * alpha);
        }
        if (inner) {
            const float* input_ptr =
                    input + ic * IH * IW + ih_start * IW + iw_start;
            for (size_t ico = 0; ico < 4; ++ico) {
                if (ic + ico < IC) {
                    //! four row loads cover the whole 4x4 patch
                    auto v0 = vld1q_f32(input_ptr);
                    auto v1 = vld1q_f32(input_ptr + IW);
                    auto v2 = vld1q_f32(input_ptr + IW * 2);
                    auto v3 = vld1q_f32(input_ptr + IW * 3);
                    vst1q_f32(patch + ico * 4 * alpha + 0 * 4, v0);
                    vst1q_f32(patch + ico * 4 * alpha + 1 * 4, v1);
                    vst1q_f32(patch + ico * 4 * alpha + 2 * 4, v2);
                    vst1q_f32(patch + ico * 4 * alpha + 3 * 4, v3);
                    input_ptr += IH * IW;
                }
            }
        } else {
            //! clamp the patch window to the valid image region
            int ih0_act = std::max<int>(ih_start, 0),
                ih1_act = std::min<int>(ih_start + alpha, IH),
                iw0_act = std::max<int>(iw_start, 0),
                iw1_act = std::min<int>(iw_start + alpha, IW);
            // partial copy
            for (size_t ico = 0; ico < 4; ++ico) {
                if (ic + ico < IC) {
                    for (int ih = ih0_act; ih < ih1_act; ++ih) {
                        for (int iw = iw0_act; iw < iw1_act; ++iw) {
                            size_t iho = ih - ih_start, iwo = iw - iw_start;
                            patch[ico * alpha * 4 + iho * 4 + iwo] =
                                    input[(ic + ico) * IH * IW + ih * IW + iw];
                        }
                    }
                }
            }
        }
        //! reorder 4 channel-blocks of 16 floats into pixel-major layout:
        //! patchT[pixel][channel], as consumed by transform()
        transpose_4x4(patch + 0 * 1, patchT + 0 * 4, 16, 4);
        transpose_4x4(patch + 4 * 1, patchT + 4 * 4, 16, 4);
        transpose_4x4(patch + 8 * 1, patchT + 8 * 4, 16, 4);
        transpose_4x4(patch + 12 * 1, patchT + 12 * 4, 16, 4);
    }
    //! Apply B^T * d * B to the 4-channel patch in `patchT` and scatter to
    //! `input_transform_buf`, laid out [alpha*alpha][IC/4][nr_units_in_tile][4].
    static void transform(const float* patchT, float* input_transform_buf,
                          size_t unit_idx, size_t nr_units_in_tile, size_t ic,
                          size_t IC) {
        constexpr size_t alpha = 2 + 3 - 1;
        // BT * d * B
#define cb(m, n)                                                 \
    Vector<float, 4> d##m##n =                                   \
            Vector<float, 4>::load(patchT + m * 4 * 4 + n * 4);
        UNROLL_CALL_NOWRAPPER_D2(4, 4, cb);
#undef cb
        //! 1  0 -1  0    d00 d01 d02 d03     1  0  0  0
        //! 0  1  1  0    d10 d11 d12 d13     0  1 -1 -1
        //! 0 -1  1  0    d20 d21 d22 d23    -1  1  1  0
        //! 0 -1  0  1    d30 d31 d32 d33     0  0  0  1
#define cb(m)                       \
    auto t0##m = d0##m - d2##m;     \
    auto t1##m = d1##m + d2##m;     \
    auto t2##m = d2##m - d1##m;     \
    auto t3##m = d3##m - d1##m;
        UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb
#define cb(m)                        \
    d##m##0 = t##m##0 - t##m##2;     \
    d##m##1 = t##m##1 + t##m##2;     \
    d##m##2 = t##m##2 - t##m##1;     \
    d##m##3 = t##m##3 - t##m##1;
        UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb
        size_t ICB = IC / 4;
        size_t icb = ic / 4;
#define cb(m, n)                                                   \
    d##m##n.save(input_transform_buf +                             \
                 (m * alpha + n) * ICB * nr_units_in_tile * 4 +    \
                 icb * nr_units_in_tile * 4 + unit_idx * 4);
        UNROLL_CALL_NOWRAPPER_D2(4, 4, cb)
#undef cb
    }
};
//! Output transform for Winograd F(2,3): computes A^T * m * A on a 4-channel
//! block, adds bias per `bmode`, applies the nonlinearity `Op`, and writes
//! the 2x2 output tile (border-clipped).
template <BiasMode bmode, typename Op>
struct OutputTransform2X3 {
    static void transform(const float* output_transform_buf, const float* bias,
                          float* output, float* transform_mid_buf,
                          size_t oh_start, size_t ow_start, size_t OH,
                          size_t OW, size_t oc_start, size_t oc_end,
                          size_t oc_index, size_t unit_idx,
                          size_t nr_units_in_tile, const DType& src_dtype,
                          const DType& dst_dtype) {
        Op op(src_dtype, dst_dtype);
        //! AT * m * A
        constexpr size_t alpha = 2 + 3 - 1;
        size_t oc = oc_start + oc_index;
        size_t OCB = (oc_end - oc_start) / 4;
        size_t ocb = oc_index / 4;
        //! gather the 4x4 transformed tile for this unit / channel block;
        //! buffer layout mirrors InputTransform2X3::transform's output
#define cb(m, n)                                                  \
    auto v##m##n = Vector<float, 4>::load(                        \
            output_transform_buf +                                \
            (m * alpha + n) * OCB * nr_units_in_tile * 4 +        \
            ocb * nr_units_in_tile * 4 + unit_idx * 4);
        UNROLL_CALL_NOWRAPPER_D2(4, 4, cb);
#undef cb
        //! 1 1  1 0    v00 v01 v02 v03    1  0
        //! 0 1 -1 1    v10 v11 v12 v13    1  1
        //!             v20 v21 v22 v23    1 -1
        //!             v30 v31 v32 v33    0  1
#define cb(m)                               \
    auto t0##m = v0##m + v1##m + v2##m;     \
    auto t1##m = v1##m - v2##m + v3##m;
        UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb
        v00 = t00 + t01 + t02;
        v10 = t10 + t11 + t12;
        v01 = t01 - t02 + t03;
        v11 = t11 - t12 + t13;
        Vector<float, 4> vbias;
        //! channel-broadcast bias is added vector-wise here
        if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
            vbias = Vector<float, 4>::load(bias + oc);
            v00 += vbias;
            v10 += vbias;
            v01 += vbias;
            v11 += vbias;
        }
        //! for elementwise BIAS the op runs per-pixel in the scatter loop
        //! below (after the bias add); otherwise apply it on vectors now
        if (bmode != BiasMode::BIAS) {
            v00 = op(v00.value);
            v01 = op(v01.value);
            v10 = op(v10.value);
            v11 = op(v11.value);
        }
        v00.save(transform_mid_buf + (0 * 2 + 0) * 4);
        v10.save(transform_mid_buf + (1 * 2 + 0) * 4);
        v01.save(transform_mid_buf + (0 * 2 + 1) * 4);
        v11.save(transform_mid_buf + (1 * 2 + 1) * 4);
        //! guarded scatter: clip the 2x2 tile and the 4-channel block at
        //! the output borders
        for (size_t oco = 0; oco < 4 && oc + oco < oc_end; ++oco) {
            for (size_t oho = 0; oho < 2 && oh_start + oho < OH; ++oho) {
                for (size_t owo = 0; owo < 2 && ow_start + owo < OW; ++owo) {
                    size_t oh = oh_start + oho;
                    size_t ow = ow_start + owo;
                    float res = transform_mid_buf[oho * 2 * 4 + owo * 4 + oco];
                    if (bmode == BiasMode::BIAS) {
                        res += bias[(oc + oco) * OH * OW + oh * OW + ow];
                        res = op(res);
                    }
                    output[(oc + oco) * OH * OW + oh * OW + ow] = res;
                }
            }
        }
    }
};
} // namespace | |||||
namespace megdnn { | |||||
namespace arm_common { | |||||
namespace winograd { | |||||
MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_2x3_4x4_f) | |||||
//! Filter transform for F(2,3): computes G * g * G^T for every (oc, ic)
//! 3x3 kernel, padding g to 4x4 so MATRIX_MUL4x4 can be used, and stores
//! the 4x4 result in blocked layout
//! [alpha*alpha][OC/4][IC/4][ic%4][oc%4] (note the i/j swap at the store —
//! the result is written transposed).
void winograd_2x3_4x4_f::filter(const float* filter,
                                float* filter_transform_buf,
                                float* transform_mid_buf, size_t OC, size_t IC,
                                size_t oc_start, size_t oc_end) {
    constexpr int alpha = 2 + 3 - 1;
    //! G * g * GT
    //! rows of G (padded to 4 columns) and of G^T (padded with a zero row)
    float32x4_t g0{1.f, 0, 0, 0}, g1{0.5, 0.5, 0.5, 0}, g2{0.5, -0.5, 0.5, 0},
            g3{0, 0, 1, 0};
    float32x4_t gt0{1, 0.5, 0.5, 0}, gt1{0, 0.5, -0.5, 0}, gt2{0, 0.5, 0.5, 1},
            gt3{0, 0, 0, 0};
    size_t OCB = OC / 4;
    size_t ICB = IC / 4;
    for (size_t oc = oc_start; oc < oc_end; oc++)
        rep(ic, IC) {
            const float* filter_ptr = filter + (oc * IC + ic) * 3 * 3;
            /**
             * origin: (4x3) * (3 x 3) * (3 x 4)
             * pack to G and g to times of 4
             * now: (4x4) * (4 x 4) * (4 x 4)
             */
            //! 1    0    0   0    v00 v01 v02 0    1 0.5  0.5 0
            //! 0.5  0.5  0.5 0    v10 v11 v12 0    0 0.5 -0.5 0
            //! 0.5 -0.5  0.5 0    v20 v21 v22 0    0 0.5  0.5 1
            //! 0    0    1   0    0   0   0   0    0 0    0   0
            //! unpack the 9 contiguous kernel values into 4 zero-padded rows
            //! v0..v3; vextq/vsetq splice rows out of the packed loads
            float32x4_t vf0 = vld1q_f32(filter_ptr);
            float32x4_t vf1 = vld1q_f32(filter_ptr + 4);
            float32x4_t vf2 = vdupq_n_f32(filter_ptr[8]);
            float32x4_t v3(vdupq_n_f32(0));
            auto vtmp = vextq_f32(vf1, vf2, 2);
            vtmp = vsetq_lane_f32(0, vtmp, 3);
            float32x4_t v2(vtmp);
            vtmp = vextq_f32(vf0, vf1, 3);
            vtmp = vsetq_lane_f32(0, vtmp, 3);
            float32x4_t v1(vtmp);
            vtmp = vsetq_lane_f32(0, vf0, 3);
            float32x4_t v0(vtmp);
            //! vsum = G * g, vres = vsum * G^T
            float32x4_t vsum0 = vdupq_n_f32(0), vsum1 = vdupq_n_f32(0),
                        vsum2 = vdupq_n_f32(0), vsum3 = vdupq_n_f32(0);
            MATRIX_MUL4x4(vsum, g, v);
            float32x4_t vres0 = vdupq_n_f32(0), vres1 = vdupq_n_f32(0),
                        vres2 = vdupq_n_f32(0), vres3 = vdupq_n_f32(0);
            MATRIX_MUL4x4(vres, vsum, gt);
            vst1q_f32(transform_mid_buf, vres0);
            vst1q_f32(transform_mid_buf + 4, vres1);
            vst1q_f32(transform_mid_buf + 8, vres2);
            vst1q_f32(transform_mid_buf + 12, vres3);
            size_t ocb = oc / 4;
            size_t oc4 = oc % 4;
            size_t icb = ic / 4;
            size_t ic4 = ic % 4;
            rep(i, alpha) rep(j, alpha) {
                filter_transform_buf[(i * alpha + j) * OCB * ICB * 4 * 4 +
                                     ocb * ICB * 4 * 4 + icb * 4 * 4 + ic4 * 4 +
                                     oc4] = transform_mid_buf[i * alpha + j];
            }
        }
}
//! Input transform driver for F(2,3): walks the tile's units, decides per
//! unit whether the 4x4 patch is fully inside the (unpadded) image, and
//! dispatches to the fast (`inner == true`, unguarded loads) or the
//! border-handling path.  Requires IC to be a multiple of 4 (4-channel
//! packing).  `transform_mid_buf` is split into `patch` (channel-major
//! scratch) and `patchT` (pixel-major scratch).
void winograd_2x3_4x4_f::input(const float* input, float* input_transform_buf,
                               float* transform_mid_buf, size_t IH, size_t IW,
                               size_t IC, size_t PH, size_t PW,
                               size_t unit_start_idx, size_t nr_units_in_tile) {
    megdnn_assert(IC % 4 == 0);
    constexpr int alpha = 3 + 2 - 1;
    // OW = IW + 2 * PW - KERNEL_SIZE + 1
    auto units_w = div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE);
    float* patch = transform_mid_buf;
    float* patchT = transform_mid_buf + 4 * alpha * alpha;
    for (size_t ic = 0; ic < IC; ic += 4) {
        rep(unit_idx, nr_units_in_tile) {
            size_t index = unit_start_idx + unit_idx;
            size_t nh = index / units_w;
            size_t nw = index % units_w;
            //! top-left of the patch; negative when the unit overlaps padding
            int ih_start = nh * OUTPUT_BLOCK_SIZE - PH;
            int iw_start = nw * OUTPUT_BLOCK_SIZE - PW;
            if (ih_start >= 0 && ih_start + alpha <= static_cast<int>(IH) &&
                iw_start >= 0 && iw_start + alpha <= static_cast<int>(IW)) {
                InputTransform2X3::prepare<true>(input, patch, patchT, ih_start,
                                                 iw_start, IH, IW, ic, IC);
                InputTransform2X3::transform(patchT, input_transform_buf,
                                             unit_idx, nr_units_in_tile, ic,
                                             IC);
            } else {
                InputTransform2X3::prepare<false>(input, patch, patchT,
                                                  ih_start, iw_start, IH, IW,
                                                  ic, IC);
                InputTransform2X3::transform(patchT, input_transform_buf,
                                             unit_idx, nr_units_in_tile, ic,
                                             IC);
            }
        }
    }
}
//! Output transform driver for F(2,3): for each 4-channel output block and
//! each unit in the tile, dispatches OutputTransform2X3 specialized on
//! (bias mode, nonlinearity) via DISPATCH_CONV_WINOGRAD_BIAS.
//! src_dtype / dst_dtype are not parameters here — presumably members
//! provided by MEGDNN_REG_WINOGRAD_STRATEGY_IMPL; confirm in the macro.
void winograd_2x3_4x4_f::output(const float* output_transform_buf,
                                const float* bias, float* output,
                                float* transform_mid_buf, BiasMode bmode,
                                NonlineMode nonline_mode, size_t OH, size_t OW,
                                size_t oc_start, size_t oc_end, size_t unit_start_idx,
                                size_t nr_units_in_tile) {
#define cb(_bmode, _nonline_op, ...) \
    OutputTransform2X3<_bmode MEGDNN_COMMA _nonline_op>::transform(__VA_ARGS__);
    auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE);
    for (size_t oc = oc_start; oc < oc_end; oc += 4) {
        size_t oc_index = oc - oc_start;
        rep(unit_idx, nr_units_in_tile) {
            size_t index = unit_start_idx + unit_idx;
            auto nh = index / units_w;
            auto nw = index % units_w;
            size_t oh_start = nh * OUTPUT_BLOCK_SIZE;
            size_t ow_start = nw * OUTPUT_BLOCK_SIZE;
            DISPATCH_CONV_WINOGRAD_BIAS(
                    megdnn_arm_common_winograd_fp32_F23, cb, float, float, bmode,
                    nonline_mode, output_transform_buf, bias, output, transform_mid_buf,
                    oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx,
                    nr_units_in_tile, src_dtype, dst_dtype);
        }
    }
#undef cb
}
} // namespace winograd | |||||
} // namespace arm_common | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,483 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/conv_bias/fp32/strategy_4x5.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "src/arm_common/conv_bias/fp32/strategy.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/arm_common/utils.h" | |||||
#include "src/common/unroll_macro.h" | |||||
#include "src/common/utils.h" | |||||
#include "src/fallback/conv_bias/winograd/winograd.h" | |||||
#include "src/arm_common/conv_bias/fp32/helper.h" | |||||
#include "src/arm_common/elemwise_helper/op_unary.h" | |||||
#include "src/naive/matrix_mul/matrix_mul_helper.h" | |||||
#include "midout.h" | |||||
MIDOUT_DECL(megdnn_arm_common_winograd_fp32_F45) | |||||
using namespace megdnn; | |||||
using namespace arm_common; | |||||
namespace { | |||||
//! Filter transform for Winograd F(4,5): computes G * g * G^T for each
//! 5x5 kernel.  The 5-column kernel is handled as a 4-wide vector part
//! (g0..g4, Vector<float, 4>) plus a scalar remainder column (gr0..gr4),
//! which FILTER_TRANSFORM processes in parallel via the d##rN names.
struct FilterTransform4X5 {
//! First stage G * g: wd0..wd7 (vector columns 0..3) and wdr0..wdr7
//! (remainder column 4).  The decimal constants are the G entries
//! (-2/9, 32/45, 16/45, ... truncated to 7 digits).
#define FILTER_TRANSFORM(d, wd)                                        \
    do {                                                               \
        wd##0 = d##0;                                                  \
        wd##r0 = d##r0;                                                \
        wd##1 = (d##0 + d##1 + d##2 + d##3 + d##4) * -0.2222222;       \
        wd##r1 = (d##r0 + d##r1 + d##r2 + d##r3 + d##r4) * -0.2222222; \
        wd##2 = (d##0 - d##1 + d##2 - d##3 + d##4) * -0.2222222;       \
        wd##r2 = (d##r0 - d##r1 + d##r2 - d##r3 + d##r4) * -0.2222222; \
        auto tmpd0 = d##0 * 0.7111111;                                 \
        auto tmpd1 = d##1 * 0.3555556;                                 \
        auto tmpd2 = d##2 * 0.1777778;                                 \
        auto tmpd3 = d##3 * 0.0888889;                                 \
        auto tmpd4 = d##4 * 0.0444444;                                 \
        auto tmpdr0 = d##r0 * 0.7111111;                               \
        auto tmpdr1 = d##r1 * 0.3555556;                               \
        auto tmpdr2 = d##r2 * 0.1777778;                               \
        auto tmpdr3 = d##r3 * 0.0888889;                               \
        auto tmpdr4 = d##r4 * 0.0444444;                               \
        wd##3 = tmpd0 + tmpd1 + tmpd2 + tmpd3 + tmpd4;                 \
        wd##r3 = tmpdr0 + tmpdr1 + tmpdr2 + tmpdr3 + tmpdr4;           \
        wd##4 = tmpd0 - tmpd1 + tmpd2 - tmpd3 + tmpd4;                 \
        wd##r4 = tmpdr0 - tmpdr1 + tmpdr2 - tmpdr3 + tmpdr4;           \
        tmpd0 = d##0 * 0.0111111;                                      \
        tmpd1 = d##1 * 0.0222222;                                      \
        tmpd2 = d##2 * 0.0444444;                                      \
        tmpd3 = d##3 * 0.0888889;                                      \
        tmpd4 = d##4 * 0.1777778;                                      \
        tmpdr0 = d##r0 * 0.0111111;                                    \
        tmpdr1 = d##r1 * 0.0222222;                                    \
        tmpdr2 = d##r2 * 0.0444444;                                    \
        tmpdr3 = d##r3 * 0.0888889;                                    \
        tmpdr4 = d##r4 * 0.1777778;                                    \
        wd##5 = tmpd0 + tmpd1 + tmpd2 + tmpd3 + tmpd4;                 \
        wd##r5 = tmpdr0 + tmpdr1 + tmpdr2 + tmpdr3 + tmpdr4;           \
        wd##6 = tmpd0 - tmpd1 + tmpd2 - tmpd3 + tmpd4;                 \
        wd##r6 = tmpdr0 - tmpdr1 + tmpdr2 - tmpdr3 + tmpdr4;           \
        wd##7 = d##4;                                                  \
        wd##r7 = d##r4;                                                \
    } while (0);
//! Second stage (result = Ggt * G^T): same coefficients, applied to the
//! already-transposed 8-row input; no remainder column needed because the
//! transpose folded it in (Ggt4 below).
#define FILTER_TRANSFORM_FINAL(d, wd)                                    \
    do {                                                                 \
        wd##0 = d##0;                                                    \
        wd##1 = (d##0 + d##1 + d##2 + d##3 + d##4) * -0.2222222;         \
        wd##2 = (d##0 - d##1 + d##2 - d##3 + d##4) * -0.2222222;         \
        auto tmp0 = d##0 * 0.7111111 + d##2 * 0.1777778 + d##4 * 0.0444444; \
        auto tmp1 = d##1 * 0.3555556 + d##3 * 0.0888889;                 \
        wd##3 = tmp0 + tmp1;                                             \
        wd##4 = tmp0 - tmp1;                                             \
        tmp0 = d##0 * 0.0111111 + d##2 * 0.0444444 + d##4 * 0.1777778;   \
        tmp1 = d##1 * 0.0222222 + d##3 * 0.0888889;                      \
        wd##5 = tmp0 + tmp1;                                             \
        wd##6 = tmp0 - tmp1;                                             \
        wd##7 = d##4;                                                    \
    } while (0);
    //! Transform each (oc, ic) 5x5 kernel and store the 8x8 result
    //! transposed into filter_transform_buf[(i*8+j)*OC*IC + ic*OC + oc].
    static void transform(const float* filter, float* filter_transform_buf,
                          float* transform_mid_buf, size_t OC, size_t IC,
                          size_t oc_start, size_t oc_end) {
        // Gg * GT
        // G
        //[[ 1.         0.         0.         0.         0.       ]
        // [-0.2222222 -0.2222222 -0.2222222 -0.2222222 -0.2222222]
        // [-0.2222222  0.2222222 -0.2222222  0.2222222 -0.2222222]
        // [ 0.7111111  0.3555556  0.1777778  0.0888889  0.0444444]
        // [ 0.7111111 -0.3555556  0.1777778 -0.0888889  0.0444444]
        // [ 0.0111111  0.0222222  0.0444444  0.0888889  0.1777778]
        // [ 0.0111111 -0.0222222  0.0444444 -0.0888889  0.1777778]
        // [ 0.         0.         0.         0.         1.       ]]
        constexpr size_t alpha = 4 + 5 - 1;
        for (size_t oc = oc_start; oc < oc_end; oc++)
            rep(ic, IC) {
                const float* fptr = filter + (oc * IC + ic) * 5 * 5;
#define cb(i) Vector<float, 4> g##i = Vector<float, 4>::load(fptr + 5 * i);
                UNROLL_CALL_NOWRAPPER(5, cb);
#undef cb
#define cb(i) float gr##i = *(fptr + 5 * i + 4);
                UNROLL_CALL_NOWRAPPER(5, cb);
#undef cb
#define cb(i) Vector<float, 4> Gg##i;
                UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
#define cb(i) float Ggr##i;
                UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
#define cb(i) Vector<float, 8> Ggt##i;
                UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb
#define cb(i) Vector<float, 8> result##i;
                UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
                FILTER_TRANSFORM(g, Gg)
                //! pack the 8 remainder scalars Ggr0..7 (5th kernel column
                //! after stage 1) into Ggt4, the 5th row after transposition
                float32x4x2_t vgr;
                float32x4_t vgr0 = {Ggr0, Ggr1, Ggr2, Ggr3};
                float32x4_t vgr1 = {Ggr4, Ggr5, Ggr6, Ggr7};
                vgr.val[0] = vgr0;  //{Ggr0, Ggr1, Ggr2, Ggr3};
                vgr.val[1] = vgr1;  //{Ggr4, Ggr5, Ggr6, Ggr7};
                Vector<float, 8> Ggt4(vgr);
                TRANSPOSE_8x4(Gg, Ggt);
                FILTER_TRANSFORM_FINAL(Ggt, result);
#define cb(i) result##i.save(transform_mid_buf + i * alpha);
                UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
                //! note the j/i swap: the buffer receives the transpose
                rep(i, alpha) rep(j, alpha) {
                    filter_transform_buf[(i * alpha + j) * OC * IC + ic * OC +
                                         oc] = transform_mid_buf[j * alpha + i];
                }
            }
    }
};
#undef FILTER_TRANSFORM | |||||
#undef FILTER_TRANSFORM_FINAL | |||||
//! Input transform for Winograd F(4,5): applies B^T * d * B to an 8x8
//! input patch, one channel at a time (no channel packing here).
struct InputTransform4X5 {
//! One pass of the transform over 8 Vector<float, 8> rows: wd = B^T-combined
//! rows of d, with the B^T coefficients shown in transform() below.
#define INPUT_TRANSFORM(d, wd)                                  \
    do {                                                        \
        wd##0 = (d##0 - d##6) + (d##4 - d##2) * 5.25f;          \
        auto tmp0 = d##2 - d##4 * 4.25f + d##6;                 \
        auto tmp1 = d##1 - d##3 * 4.25f + d##5;                 \
        wd##1 = tmp0 + tmp1;                                    \
        wd##2 = tmp0 - tmp1;                                    \
        tmp0 = d##2 * 4.0f - d##4 * 5.0f + d##6;                \
        tmp1 = d##1 * 2.0f - d##3 * 2.5f + d##5 * 0.5f;         \
        wd##3 = tmp0 + tmp1;                                    \
        wd##4 = tmp0 - tmp1;                                    \
        tmp0 = d##2 * 0.25f - d##4 * 1.25f + d##6;              \
        tmp1 = d##1 * 0.5f - d##3 * 2.5f + d##5 * 2.0f;         \
        wd##5 = tmp0 + tmp1;                                    \
        wd##6 = tmp0 - tmp1;                                    \
        wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f;          \
    } while (0)
//! Scalar accessors for lane idx of the high / low float32x4_t half of a
//! Vector<float, 8> named s##i.
#define GET_VECTOR_HIGH_ELEM(s, i, idx) \
    vgetq_lane_f32(CONCAT(s, i).value.val[1], idx)
#define GET_VECTOR_LOW_ELEM(s, i, idx) \
    vgetq_lane_f32(CONCAT(s, i).value.val[0], idx)
    //! Load one 8x8 patch of channel `ic` (zero-padding out-of-image pixels
    //! when !inner), transform it, and scatter to input_transform_buf at
    //! [(i*8+j)*nr_units_in_tile*IC + unit_idx*IC + ic].
    template <bool inner>
    static void transform(const float* input, float* input_transform_buf,
                          float* transform_mid_buf, int ih_start, int iw_start,
                          size_t ic, size_t IH, size_t IW, size_t IC,
                          size_t unit_idx, size_t nr_units_in_tile) {
        // BTd * B
        //([[ 1.  ,  0.  , -5.25,  0.  ,  5.25,  0.  , -1.  ,  0.  ],
        //  [ 0.  ,  1.  ,  1.  , -4.25, -4.25,  1.  ,  1.  ,  0.  ],
        //  [ 0.  , -1.  ,  1.  ,  4.25, -4.25, -1.  ,  1.  ,  0.  ],
        //  [ 0.  ,  2.  ,  4.  , -2.5 , -5.  ,  0.5 ,  1.  ,  0.  ],
        //  [ 0.  , -2.  ,  4.  ,  2.5 , -5.  , -0.5 ,  1.  ,  0.  ],
        //  [ 0.  ,  0.5 ,  0.25, -2.5 , -1.25,  2.  ,  1.  ,  0.  ],
        //  [ 0.  , -0.5 ,  0.25,  2.5 , -1.25, -2.  ,  1.  ,  0.  ],
        //  [ 0.  , -1.  ,  0.  ,  5.25,  0.  , -5.25,  0.  ,  1.  ]]))
        constexpr size_t alpha = 4 + 5 - 1;
        if (!inner) {
            memset(transform_mid_buf, 0, sizeof(float) * alpha * alpha);
        }
#define cb(i) Vector<float, 8> d##i;
        UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
        if (inner) {
            //! patch fully inside the image: 8 direct row loads
            const float* input_ptr =
                    input + ic * IH * IW + ih_start * IW + iw_start;
#define cb(i) d##i = Vector<float, 8>::load(input_ptr + IW * i);
            UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
        } else {
            //! border patch: copy the valid sub-window into the zeroed
            //! scratch, then load rows from there
            int ih0_act = std::max<int>(ih_start, 0),
                ih1_act = std::min<int>(ih_start + alpha, IH),
                iw0_act = std::max<int>(iw_start, 0),
                iw1_act = std::min<int>(iw_start + alpha, IW);
            for (int ih = ih0_act; ih < ih1_act; ++ih) {
                for (int iw = iw0_act; iw < iw1_act; ++iw) {
                    size_t iho = ih - ih_start, iwo = iw - iw_start;
                    transform_mid_buf[iho * alpha + iwo] =
                            input[ic * IH * IW + ih * IW + iw];
                }
            }
#define cb(i) d##i = Vector<float, 8>::load(transform_mid_buf + alpha * i);
            UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
        }
#define cb(i) Vector<float, 8> wd##i, ret##i;
        UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
        INPUT_TRANSFORM(d, wd);
#if MEGDNN_AARCH64
        //! aarch64: transpose, run the same row transform again, and store
        //! transposed (j/i swap) to complete B^T * d * B
        TRANSPOSE_8x8(wd, d);
        INPUT_TRANSFORM(d, ret);
#define cb(i) ret##i.save(transform_mid_buf + i * alpha);
        UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
        rep(i, alpha) rep(j, alpha) {
            input_transform_buf[(i * alpha + j) * nr_units_in_tile * IC +
                                unit_idx * IC + ic] =
                    transform_mid_buf[j * alpha + i];
        }
#else
//! armv7: no 8x8 vector transpose available — apply the column transform
//! scalar-wise, one wd row per cb(i) invocation, writing 8 outputs and
//! advancing mid_buf1 as a side effect.
#define cb(i)                                                  \
    do {                                                       \
        mid_buf1[0] = GET_VECTOR_LOW_ELEM(wd, i, 0) -          \
                      GET_VECTOR_HIGH_ELEM(wd, i, 2) +         \
                      5.25 * (GET_VECTOR_HIGH_ELEM(wd, i, 0) - \
                              GET_VECTOR_LOW_ELEM(wd, i, 2));  \
        mid_buf1[7] = GET_VECTOR_HIGH_ELEM(wd, i, 3) -         \
                      GET_VECTOR_LOW_ELEM(wd, i, 1) +          \
                      5.25 * (GET_VECTOR_LOW_ELEM(wd, i, 3) -  \
                              GET_VECTOR_HIGH_ELEM(wd, i, 1)); \
        auto tmp0 = 4 * GET_VECTOR_LOW_ELEM(wd, i, 2) +        \
                    -5 * GET_VECTOR_HIGH_ELEM(wd, i, 0) +      \
                    GET_VECTOR_HIGH_ELEM(wd, i, 2);            \
        auto tmp1 = 2 * GET_VECTOR_LOW_ELEM(wd, i, 1) +        \
                    -2.5 * GET_VECTOR_LOW_ELEM(wd, i, 3) +     \
                    0.5 * GET_VECTOR_HIGH_ELEM(wd, i, 1);      \
        mid_buf1[3] = tmp0 + tmp1;                             \
        mid_buf1[4] = tmp0 - tmp1;                             \
        tmp0 = GET_VECTOR_LOW_ELEM(wd, i, 2) +                 \
               -4.25 * GET_VECTOR_HIGH_ELEM(wd, i, 0) +        \
               GET_VECTOR_HIGH_ELEM(wd, i, 2);                 \
        tmp1 = GET_VECTOR_LOW_ELEM(wd, i, 1) +                 \
               GET_VECTOR_LOW_ELEM(wd, i, 3) * -4.25 +         \
               GET_VECTOR_HIGH_ELEM(wd, i, 1);                 \
        mid_buf1[1] = tmp0 + tmp1;                             \
        mid_buf1[2] = tmp0 - tmp1;                             \
        tmp0 = GET_VECTOR_LOW_ELEM(wd, i, 2) * 0.25 +          \
               GET_VECTOR_HIGH_ELEM(wd, i, 0) * -1.25 +        \
               GET_VECTOR_HIGH_ELEM(wd, i, 2);                 \
        tmp1 = GET_VECTOR_LOW_ELEM(wd, i, 1) * 0.5 +           \
               GET_VECTOR_LOW_ELEM(wd, i, 3) * -2.5 +          \
               GET_VECTOR_HIGH_ELEM(wd, i, 1) * 2;             \
        mid_buf1[5] = tmp0 + tmp1;                             \
        mid_buf1[6] = tmp0 - tmp1;                             \
        mid_buf1 += 8;                                         \
    } while (0);
        float* mid_buf1 = transform_mid_buf;
        UNROLL_CALL_NOWRAPPER(8, cb);
        mid_buf1 = transform_mid_buf;
#undef cb
        rep(i, alpha) rep(j, alpha) {
            input_transform_buf[(i * alpha + j) * nr_units_in_tile * IC +
                                unit_idx * IC + ic] =
                    transform_mid_buf[i * alpha + j];
        }
#endif
    }
};
#undef INPUT_TRANSFORM | |||||
//! Fused output transform for F(4,5): computes the four AT rows (see the AT
//! matrix in OutputTransform4X5) over the eight intermediate values m0..m7.
//! NOTE(review): the original body referenced the literal tokens `m0`/`s0`,
//! which only worked because the single invocation passes prefixes named
//! exactly `m` and `s`; token-paste the parameters (`m##0`) so the macro is
//! correct for any argument names. Expansion at the existing call site is
//! byte-identical.
#define OUTPUT_TRANSFORM(m, s)                                               \
    do {                                                                     \
        s##0 = m##0 + m##1 + m##2 + m##3 + m##4 + m##5 + m##6;               \
        s##1 = m##1 - m##2 + m##3 * 0.5 - m##4 * 0.5 + m##5 * 2.0 -          \
               m##6 * 2.0;                                                   \
        s##2 = m##1 + m##2 + m##3 * 0.25 + m##4 * 0.25 + m##5 * 4.0 +        \
               m##6 * 4.0;                                                   \
        s##3 = m##1 - m##2 + m##3 * 0.125 - m##4 * 0.125 + m##5 * 8.0 -      \
               m##6 * 8.0 + m##7;                                            \
    } while (0)
//! Inverse (output) transform for winograd F(4,5): gathers one 8x8 tile of
//! GEMM results for a single (output channel, unit) pair, applies AT * m * A,
//! adds the bias selected by `bmode`, applies the activation `Op`, and writes
//! a 4x4 tile into `output`. `transform_mid_buf` is scratch, reused first for
//! the gathered 8x8 tile and then for the 4x4 result.
template <BiasMode bmode, typename Op>
struct OutputTransform4X5 {
    static void transform(const float* output_transform_buf, const float* bias,
                          float* output, float* transform_mid_buf,
                          size_t oh_start, size_t ow_start, size_t OH,
                          size_t OW, size_t oc_start, size_t oc_end,
                          size_t oc_index, size_t unit_idx,
                          size_t nr_units_in_tile, const DType& src_dtype,
                          const DType& dst_dtype) {
        Op op(src_dtype, dst_dtype);
        //! AT * m * A
        // AT f45
        // 1.0 1.0 1.0 1.000 1.000 1.0 1.0 0.0
        // 0.0 1.0 -1.0 0.500 -0.500 2.0 -2.0 0.0
        // 0.0 1.0 1.0 0.250 0.250 4.0 4.0 0.0
        // 0.0 1.0 -1.0 0.125 -0.125 8.0 -8.0 1.0
        constexpr size_t alpha = 5 + 4 - 1;
        float* mid_buf1 = transform_mid_buf;
        size_t OC = oc_end - oc_start;
        size_t oc = oc_start + oc_index;
// Gather the 8x8 tile of this (oc, unit) from the interleaved
// output_transform_buf layout into contiguous transform_mid_buf.
#define cb(m, n)                                                           \
    transform_mid_buf[m * alpha + n] =                                     \
            output_transform_buf[(m * alpha + n) * nr_units_in_tile * OC + \
                                 unit_idx * OC + oc_index];
        UNROLL_CALL_NOWRAPPER_D2(8, 8, cb);
#undef cb
// Load the eight rows as 8-wide vectors.
#define cb(i) auto m##i = Vector<float, 8>::load(transform_mid_buf + alpha * i);
        UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
#define cb(i) Vector<float, 8> s##i;
        UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb
        // Column pass: s = AT * m (vectorized across the 8 columns).
        OUTPUT_TRANSFORM(m, s);
// Row pass (scalar): applies A on each of the 4 rows of s, emitting the
// final 4 values per row into transform_mid_buf via mid_buf1.
#define cb(i)                                                                  \
    do {                                                                       \
        auto add12 =                                                           \
                GET_VECTOR_LOW_ELEM(s, i, 1) + GET_VECTOR_LOW_ELEM(s, i, 2);   \
        auto add34 =                                                           \
                GET_VECTOR_LOW_ELEM(s, i, 3) + GET_VECTOR_HIGH_ELEM(s, i, 0);  \
        auto add56 =                                                           \
                GET_VECTOR_HIGH_ELEM(s, i, 1) + GET_VECTOR_HIGH_ELEM(s, i, 2); \
        auto sub12 =                                                           \
                GET_VECTOR_LOW_ELEM(s, i, 1) - GET_VECTOR_LOW_ELEM(s, i, 2);   \
        auto sub34 =                                                           \
                GET_VECTOR_LOW_ELEM(s, i, 3) - GET_VECTOR_HIGH_ELEM(s, i, 0);  \
        auto sub56 =                                                           \
                GET_VECTOR_HIGH_ELEM(s, i, 1) - GET_VECTOR_HIGH_ELEM(s, i, 2); \
        mid_buf1[0] = GET_VECTOR_LOW_ELEM(s, i, 0) + add12 + add34 + add56;    \
        mid_buf1[1] = sub12 + sub34 * 0.5 + sub56 * 2.0;                       \
        mid_buf1[2] = add12 + add34 * 0.25 + add56 * 4.0;                      \
        mid_buf1[3] = sub12 + sub34 * 0.125 + sub56 * 8.0 +                    \
                      GET_VECTOR_HIGH_ELEM(s, i, 3);                           \
        mid_buf1 += 4;                                                         \
    } while (0);
        mid_buf1 = transform_mid_buf;
        UNROLL_CALL_NOWRAPPER(4, cb);
        mid_buf1 = transform_mid_buf;
#undef cb
        // Fast path: the full 4x4 tile lies inside the output image, so each
        // row can be stored with one 4-lane vector store.
        if (oh_start + 4 <= OH && ow_start + 4 <= OW) {
            float32x4_t bias0;
            if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
                bias0 = vdupq_n_f32(bias[oc]);
            }
            rep(i, 4) {
                size_t oh = oh_start + i;
                float32x4_t item0 = vld1q_f32(mid_buf1);
                if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
                    item0 = vaddq_f32(item0, bias0);
                } else if (bmode == BiasMode::BIAS) {
                    // Per-element bias: reload for every output row.
                    bias0 = vld1q_f32(bias + oc * OH * OW + oh * OW + ow_start);
                    item0 = vaddq_f32(item0, bias0);
                }
                item0 = op(item0);
                vst1q_f32(output + oc * OH * OW + oh * OW + ow_start, item0);
                mid_buf1 += 4;
            }
        } else {
            // Border path: the tile is clipped by OH/OW; store element-wise.
            for (size_t oho = 0; oho < 4 && oh_start + oho < OH; ++oho) {
                for (size_t owo = 0; owo < 4 && ow_start + owo < OW; ++owo) {
                    size_t oh = oh_start + oho;
                    size_t ow = ow_start + owo;
                    float res = mid_buf1[oho * 4 + owo];
                    if (bmode == BiasMode::BIAS) {
                        res += bias[oc * OH * OW + oh * OW + ow];
                    } else if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
                        res += bias[oc];
                    }
                    res = op(res);
                    output[oc * OH * OW + oh * OW + ow] = res;
                }
            }
        }
    }
};
#undef OUTPUT_TRANSFORM | |||||
#undef GET_VECTOR_HIGH_ELEM | |||||
#undef GET_VECTOR_LOW_ELEM | |||||
} // namespace | |||||
namespace megdnn { | |||||
namespace arm_common { | |||||
namespace winograd { | |||||
MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_4x5_1x1_f) | |||||
void winograd_4x5_1x1_f::filter(const float* filter, | |||||
float* filter_transform_buf, | |||||
float* transform_mid_buf, size_t OC, size_t IC, | |||||
size_t oc_start, size_t oc_end) { | |||||
FilterTransform4X5::transform(filter, filter_transform_buf, | |||||
transform_mid_buf, OC, IC, oc_start, oc_end); | |||||
} | |||||
void winograd_4x5_1x1_f::input(const float* input, float* input_transform_buf, | |||||
float* transform_mid_buf, size_t IH, size_t IW, | |||||
size_t IC, size_t PH, size_t PW, | |||||
size_t unit_start_idx, size_t nr_units_in_tile) { | |||||
constexpr int alpha = 4 + 5 - 1; | |||||
// OW = IW + 2 * PW - KERNEL_SIZE + 1 | |||||
auto units_w = div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); | |||||
rep(ic, IC) { | |||||
rep(unit_idx, nr_units_in_tile) { | |||||
size_t index = unit_start_idx + unit_idx; | |||||
size_t nh = index / units_w; | |||||
size_t nw = index % units_w; | |||||
int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; | |||||
int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; | |||||
if (ih_start >= 0 && ih_start + alpha <= static_cast<int>(IH) && | |||||
iw_start >= 0 && iw_start + alpha <= static_cast<int>(IW)) { | |||||
InputTransform4X5::transform<true>( | |||||
input, input_transform_buf, transform_mid_buf, ih_start, | |||||
iw_start, ic, IH, IW, IC, unit_idx, nr_units_in_tile); | |||||
} else { | |||||
InputTransform4X5::transform<false>( | |||||
input, input_transform_buf, transform_mid_buf, ih_start, | |||||
iw_start, ic, IH, IW, IC, unit_idx, nr_units_in_tile); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
void winograd_4x5_1x1_f::output(const float* output_transform_buf, | |||||
const float* bias, float* output, | |||||
float* transform_mid_buf, BiasMode bmode, | |||||
NonlineMode nonline_mode, size_t OH, size_t OW, | |||||
size_t oc_start, size_t oc_end, size_t unit_start_idx, | |||||
size_t nr_units_in_tile) { | |||||
#define cb(_bmode, _nonline_op, ...) \ | |||||
OutputTransform4X5<_bmode MEGDNN_COMMA _nonline_op>::transform(__VA_ARGS__); | |||||
auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE); | |||||
for (size_t oc = oc_start; oc < oc_end; oc++) { | |||||
size_t oc_index = oc - oc_start; | |||||
rep(unit_idx, nr_units_in_tile) { | |||||
size_t index = unit_start_idx + unit_idx; | |||||
auto nh = index / units_w; | |||||
auto nw = index % units_w; | |||||
size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | |||||
size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | |||||
DISPATCH_CONV_WINOGRAD_BIAS( | |||||
megdnn_arm_common_winograd_fp32_F45, cb, float, float, bmode, | |||||
nonline_mode, output_transform_buf, bias, output, transform_mid_buf, | |||||
oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx, | |||||
nr_units_in_tile, src_dtype, dst_dtype); | |||||
} | |||||
} | |||||
#undef cb | |||||
} | |||||
} // namespace winograd | |||||
} // namespace arm_common | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,500 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/conv_bias/fp32/strategy_5x4.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "src/arm_common/conv_bias/fp32/strategy.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/arm_common/utils.h" | |||||
#include "src/common/unroll_macro.h" | |||||
#include "src/common/utils.h" | |||||
#include "src/fallback/conv_bias/winograd/winograd.h" | |||||
#include "src/arm_common/conv_bias/fp32/helper.h" | |||||
#include "src/arm_common/elemwise_helper/op_unary.h" | |||||
#include "src/naive/matrix_mul/matrix_mul_helper.h" | |||||
#include "midout.h" | |||||
MIDOUT_DECL(megdnn_arm_common_winograd_fp32_F54) | |||||
using namespace megdnn; | |||||
using namespace arm_common; | |||||
namespace { | |||||
//! Filter (weight) transform for winograd F(5,4): computes G * g * GT for the
//! 4x4 kernel g of every (oc, ic) pair, producing an 8x8 transformed filter
//! stored in filter_transform_buf with layout [alpha*alpha][IC][OC].
struct FilterTransform5X4 {
//! One pass of the filter transform: wd0..wd7 = G * (d0..d3), with the G
//! coefficients folded into the constants below.
#define FILTER_TRANSFORM(d, wd)                            \
    do {                                                   \
        wd##0 = d##0;                                      \
        auto tmp0 = d##0 * 0.7111111f + d##2 * 0.1777778f; \
        auto tmp1 = d##1 * 0.3555556f + d##3 * 0.0888889f; \
        wd##1 = tmp0 + tmp1;                               \
        wd##2 = tmp0 - tmp1;                               \
        tmp0 = (d##0 + d##2) * -0.2222222f;                \
        tmp1 = (d##1 + d##3) * -0.2222222f;                \
        wd##3 = tmp0 + tmp1;                               \
        wd##4 = tmp0 - tmp1;                               \
        tmp0 = d##0 * 0.0111111f + d##2 * 0.0444444f;      \
        tmp1 = d##1 * 0.0222222f + d##3 * 0.0888889f;      \
        wd##5 = tmp0 + tmp1;                               \
        wd##6 = tmp0 - tmp1;                               \
        wd##7 = d##3;                                      \
    } while (0)
    static void transform(const float* filter, float* filter_transform_buf,
                          float* transform_mid_buf, size_t OC, size_t IC,
                          size_t oc_start, size_t oc_end) {
        // Gg * GT
        // G
        // 1 0 0 0
        // 0.7111111 0.3555556 0.1777778 0.0888889
        // 0.7111111 -0.3555556 0.1777778 -0.0888889
        // -0.2222222 -0.2222222 -0.2222222 -0.2222222
        // -0.2222222 0.2222222 -0.2222222 0.2222222
        // 0.0111111 0.0222222 0.0444444 0.0888889
        // 0.0111111 -0.0222222 0.0444444 -0.0888889
        // 0 0 0 1
        constexpr size_t alpha = 4 + 5 - 1;
        for (size_t oc = oc_start; oc < oc_end; oc++) {
            rep(ic, IC) {
                // 4x4 kernel of this (oc, ic) pair, row-major.
                const float* fptr = filter + (oc * IC + ic) * 4 * 4;
#define cb(i) Vector<float, 4> g##i = Vector<float, 4>::load(fptr + 4 * i);
                UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb
#define cb(i) Vector<float, 4> wd##i;
                UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
#define cb(i) Vector<float, 8> wdt##i;
                UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb
#define cb(i) Vector<float, 8> ret##i;
                UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
                // Column pass: wd = G * g.
                FILTER_TRANSFORM(g, wd);
#if MEGDNN_AARCH64
                // AArch64: transpose and run the same pass again in vector
                // registers (equivalent to multiplying by GT on the right).
                TRANSPOSE_8x4(wd, wdt);
                FILTER_TRANSFORM(wdt, ret);
#define cb(i) ret##i.save(transform_mid_buf + i * alpha);
                UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
                // Scatter with transposed indexing (j, i) to undo the
                // transpose done before the second pass.
                rep(i, alpha) rep(j, alpha) {
                    filter_transform_buf[(i * alpha + j) * OC * IC + ic * OC +
                                         oc] = transform_mid_buf[j * alpha + i];
                }
#else
// AArch32 fallback: second pass done scalar, row by row, with the transpose
// fused into the per-lane reads (GET_VECTOR_ELEM expands at the use site,
// so defining it after cb is fine).
#define cb(i)                                                            \
    do {                                                                 \
        mid_buf1[0] = GET_VECTOR_ELEM(wd, i, 0);                         \
        auto tmp0 = GET_VECTOR_ELEM(wd, i, 0) * 0.7111111f +             \
                    GET_VECTOR_ELEM(wd, i, 2) * 0.1777778f;              \
        auto tmp1 = GET_VECTOR_ELEM(wd, i, 1) * 0.3555556f +             \
                    GET_VECTOR_ELEM(wd, i, 3) * 0.0888889f;              \
        mid_buf1[1] = tmp0 + tmp1;                                       \
        mid_buf1[2] = tmp0 - tmp1;                                       \
        tmp0 = (GET_VECTOR_ELEM(wd, i, 0) + GET_VECTOR_ELEM(wd, i, 2)) * \
               -0.2222222f;                                              \
        tmp1 = (GET_VECTOR_ELEM(wd, i, 1) + GET_VECTOR_ELEM(wd, i, 3)) * \
               -0.2222222f;                                              \
        mid_buf1[3] = tmp0 + tmp1;                                       \
        mid_buf1[4] = tmp0 - tmp1;                                       \
        tmp0 = GET_VECTOR_ELEM(wd, i, 0) * 0.0111111f +                  \
               GET_VECTOR_ELEM(wd, i, 2) * 0.0444444f;                   \
        tmp1 = GET_VECTOR_ELEM(wd, i, 1) * 0.0222222f +                  \
               GET_VECTOR_ELEM(wd, i, 3) * 0.0888889f;                   \
        mid_buf1[5] = tmp0 + tmp1;                                       \
        mid_buf1[6] = tmp0 - tmp1;                                       \
        mid_buf1[7] = GET_VECTOR_ELEM(wd, i, 3);                         \
        mid_buf1 += 8;                                                   \
    } while (0);
#define GET_VECTOR_ELEM(s, i, idx) vgetq_lane_f32(CONCAT(s, i).value, idx)
                float* mid_buf1 = transform_mid_buf;
                UNROLL_CALL_NOWRAPPER(8, cb);
                mid_buf1 = transform_mid_buf;
#undef cb
                rep(i, alpha) rep(j, alpha) {
                    filter_transform_buf[(i * alpha + j) * OC * IC + ic * OC +
                                         oc] = transform_mid_buf[i * alpha + j];
                }
#endif
            }
        }
    }
};
#undef FILTER_TRANSFORM | |||||
#undef GET_VECTOR_ELEM | |||||
//! Input transform for winograd F(5,4): computes BT * d * B over one 8x8
//! input window per unit and scatters the result into input_transform_buf
//! with layout [alpha*alpha][nr_units_in_tile][IC].
struct InputTransform5X4 {
//! One pass of the input transform: wd0..wd7 = BT * (d0..d7).
#define INPUT_TRANSFORM(d, wd)                               \
    do {                                                     \
        wd##0 = (d##0 - d##6) + (d##4 - d##2) * 5.25f;       \
        auto tmp0 = d##2 * 4.0f - d##4 * 5.0f + d##6;        \
        auto tmp1 = d##1 * 2.0f - d##3 * 2.5f + d##5 * 0.5f; \
        wd##1 = tmp0 + tmp1;                                 \
        wd##2 = tmp0 - tmp1;                                 \
        tmp0 = d##2 - d##4 * 4.25f + d##6;                   \
        tmp1 = d##1 - d##3 * 4.25f + d##5;                   \
        wd##3 = tmp0 + tmp1;                                 \
        wd##4 = tmp0 - tmp1;                                 \
        tmp0 = d##2 * 0.25f - d##4 * 1.25f + d##6;           \
        tmp1 = d##1 * 0.5f - d##3 * 2.5f + d##5 * 2.0f;      \
        wd##5 = tmp0 + tmp1;                                 \
        wd##6 = tmp0 - tmp1;                                 \
        wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f;       \
    } while (0)
// Lane accessors for the two halves of a Vector<float, 8> (a float32x4x2_t).
#define GET_VECTOR_HIGH_ELEM(s, i, idx) \
    vgetq_lane_f32(CONCAT(s, i).value.val[1], idx)
#define GET_VECTOR_LOW_ELEM(s, i, idx) \
    vgetq_lane_f32(CONCAT(s, i).value.val[0], idx)
    //! `inner` selects the fast path where the whole 8x8 window lies inside
    //! the image; otherwise the window is zero-padded via transform_mid_buf.
    template <bool inner>
    static void transform(const float* input, float* input_transform_buf,
                          float* transform_mid_buf, int ih_start, int iw_start,
                          size_t ic, size_t IH, size_t IW, size_t IC,
                          size_t unit_idx, size_t nr_units_in_tile) {
        // BTd * B
        // BT
        // 1 0 -5.25 0 5.25 0 -1 0
        // 0 2 4 -2.5 -5 0.5 1 0
        // 0 -2 4 2.5 -5 -0.5 1 0
        // 0 1 1 -4.25 -4.25 1 1 0
        // 0 -1 1 4.25 -4.25 -1 1 0
        // 0 0.5 0.25 -2.5 -1.25 2 1 0
        // 0 -0.5 0.25 2.5 -1.25 -2 1 0
        // 0 -1 0 5.25 0 -5.25 0 1
        constexpr size_t alpha = 4 + 5 - 1;
        if (!inner) {
            // Border unit: start from an all-zero window.
            memset(transform_mid_buf, 0, sizeof(float) * alpha * alpha);
        }
#define cb(i) Vector<float, 8> d##i;
        UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
        if (inner) {
            const float* input_ptr =
                    input + ic * IH * IW + ih_start * IW + iw_start;
#define cb(i) d##i = Vector<float, 8>::load(input_ptr + IW * i);
            UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
        } else {
            // Copy the in-bounds part of the window; the rest stays zero.
            int ih0_act = std::max<int>(ih_start, 0),
                ih1_act = std::min<int>(ih_start + alpha, IH),
                iw0_act = std::max<int>(iw_start, 0),
                iw1_act = std::min<int>(iw_start + alpha, IW);
            for (int ih = ih0_act; ih < ih1_act; ++ih) {
                for (int iw = iw0_act; iw < iw1_act; ++iw) {
                    size_t iho = ih - ih_start, iwo = iw - iw_start;
                    transform_mid_buf[iho * alpha + iwo] =
                            input[ic * IH * IW + ih * IW + iw];
                }
            }
#define cb(i) d##i = Vector<float, 8>::load(transform_mid_buf + alpha * i);
            UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
        }
#define cb(i) Vector<float, 8> wd##i, ret##i;
        UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
        // Column pass: wd = BT * d.
        INPUT_TRANSFORM(d, wd);
#if MEGDNN_AARCH64
        // AArch64: transpose and repeat the pass in vector registers
        // (equivalent to multiplying by B on the right).
        TRANSPOSE_8x8(wd, d);
        INPUT_TRANSFORM(d, ret);
#define cb(i) ret##i.save(transform_mid_buf + i * alpha);
        UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
        // Transposed (j, i) read compensates for the TRANSPOSE_8x8 above.
        rep(i, alpha) rep(j, alpha) {
            input_transform_buf[(i * alpha + j) * nr_units_in_tile * IC +
                                unit_idx * IC + ic] =
                    transform_mid_buf[j * alpha + i];
        }
#else
// AArch32 fallback: second pass done scalar, one row of wd at a time, with
// the transpose fused into the per-lane reads.
#define cb(i)                                                  \
    do {                                                       \
        mid_buf1[0] = GET_VECTOR_LOW_ELEM(wd, i, 0) -          \
                      GET_VECTOR_HIGH_ELEM(wd, i, 2) +         \
                      5.25 * (GET_VECTOR_HIGH_ELEM(wd, i, 0) - \
                              GET_VECTOR_LOW_ELEM(wd, i, 2));  \
        mid_buf1[7] = GET_VECTOR_HIGH_ELEM(wd, i, 3) -         \
                      GET_VECTOR_LOW_ELEM(wd, i, 1) +          \
                      5.25 * (GET_VECTOR_LOW_ELEM(wd, i, 3) -  \
                              GET_VECTOR_HIGH_ELEM(wd, i, 1)); \
        auto tmp0 = 4 * GET_VECTOR_LOW_ELEM(wd, i, 2) +        \
                    -5 * GET_VECTOR_HIGH_ELEM(wd, i, 0) +      \
                    GET_VECTOR_HIGH_ELEM(wd, i, 2);            \
        auto tmp1 = 2 * GET_VECTOR_LOW_ELEM(wd, i, 1) +        \
                    -2.5 * GET_VECTOR_LOW_ELEM(wd, i, 3) +     \
                    0.5 * GET_VECTOR_HIGH_ELEM(wd, i, 1);      \
        mid_buf1[1] = tmp0 + tmp1;                             \
        mid_buf1[2] = tmp0 - tmp1;                             \
        tmp0 = GET_VECTOR_LOW_ELEM(wd, i, 2) +                 \
               -4.25 * GET_VECTOR_HIGH_ELEM(wd, i, 0) +        \
               GET_VECTOR_HIGH_ELEM(wd, i, 2);                 \
        tmp1 = GET_VECTOR_LOW_ELEM(wd, i, 1) +                 \
               GET_VECTOR_LOW_ELEM(wd, i, 3) * -4.25 +         \
               GET_VECTOR_HIGH_ELEM(wd, i, 1);                 \
        mid_buf1[3] = tmp0 + tmp1;                             \
        mid_buf1[4] = tmp0 - tmp1;                             \
        tmp0 = GET_VECTOR_LOW_ELEM(wd, i, 2) * 0.25 +          \
               GET_VECTOR_HIGH_ELEM(wd, i, 0) * -1.25 +        \
               GET_VECTOR_HIGH_ELEM(wd, i, 2);                 \
        tmp1 = GET_VECTOR_LOW_ELEM(wd, i, 1) * 0.5 +           \
               GET_VECTOR_LOW_ELEM(wd, i, 3) * -2.5 +          \
               GET_VECTOR_HIGH_ELEM(wd, i, 1) * 2;             \
        mid_buf1[5] = tmp0 + tmp1;                             \
        mid_buf1[6] = tmp0 - tmp1;                             \
        mid_buf1 += 8;                                         \
    } while (0);
        float* mid_buf1 = transform_mid_buf;
        UNROLL_CALL_NOWRAPPER(8, cb);
        mid_buf1 = transform_mid_buf;
#undef cb
        rep(i, alpha) rep(j, alpha) {
            input_transform_buf[(i * alpha + j) * nr_units_in_tile * IC +
                                unit_idx * IC + ic] =
                    transform_mid_buf[i * alpha + j];
        }
#endif
    }
};
#undef INPUT_TRANSFORM | |||||
//! Fused output transform for F(5,4): s0..s4 = AT * (m0..m7), using the
//! Vector add/mla (multiply-accumulate) methods; the AT coefficient matrix
//! is listed in OutputTransform5X4 below.
#define OUTPUT_TRANSFORM(m, s)                                               \
    do {                                                                     \
        auto m1addm2 = m##1 + m##2;                                          \
        auto m1subm2 = m##1 - m##2;                                          \
        auto m3addm4 = m##3 + m##4;                                          \
        auto m3subm4 = m##3 - m##4;                                          \
        auto m5addm6 = (m##5 + m##6);                                        \
        auto m5subm6 = (m##5 - m##6);                                        \
        s##0 = m##0;                                                         \
        CONCAT(s, 0).add(m1addm2).add(m3addm4).add(m5addm6);                 \
        CONCAT(s, 1) = m3subm4;                                              \
        CONCAT(s, 1).mla(m1subm2, 0.5f).mla(m5subm6, 2.0f);                  \
        CONCAT(s, 2) = m3addm4;                                              \
        CONCAT(s, 2).mla(m1addm2, 0.25f).mla(m5addm6, 4.0f);                 \
        CONCAT(s, 3) = m3subm4;                                              \
        CONCAT(s, 3).mla(m1subm2, 0.125f).mla(m5subm6, 8.0f);                \
        CONCAT(s, 4) = m##7;                                                 \
        CONCAT(s, 4).mla(m1addm2, 0.0625f).add(m3addm4).mla(m5addm6, 16.0f); \
    } while (0)
//! Inverse (output) transform for winograd F(5,4): gathers one 8x8 tile of
//! GEMM results for a single (output channel, unit) pair, applies AT * m * A,
//! adds the bias selected by `bmode`, applies the activation `Op`, and writes
//! a 5x5 tile into `output`.
template <BiasMode bmode, typename Op>
struct OutputTransform5X4 {
    static void transform(const float* output_transform_buf, const float* bias,
                          float* output, float* transform_mid_buf,
                          size_t oh_start, size_t ow_start, size_t OH,
                          size_t OW, size_t oc_start, size_t oc_end,
                          size_t oc_index, size_t unit_idx,
                          size_t nr_units_in_tile, const DType& src_dtype,
                          const DType& dst_dtype) {
        Op op(src_dtype, dst_dtype);
        //! AT * m * A
        // AT
        // 1 1 1 1 1 1 1 0
        // 0 0.5 -0.5 1 -1 2 -2 0
        // 0 0.25 0.25 1 1 4 4 0
        // 0 0.125 -0.125 1 -1 8 -8 0
        // 0 0.0625 0.0625 1 1 16 16 1
        constexpr size_t alpha = 5 + 4 - 1;
        float* mid_buf1 = transform_mid_buf;
        size_t OC = oc_end - oc_start;
        size_t oc = oc_start + oc_index;
// Gather the 8x8 tile of this (oc, unit) from the interleaved
// output_transform_buf layout into contiguous transform_mid_buf.
#define cb(m, n)                                                           \
    transform_mid_buf[m * alpha + n] =                                     \
            output_transform_buf[(m * alpha + n) * nr_units_in_tile * OC + \
                                 unit_idx * OC + oc_index];
        UNROLL_CALL_NOWRAPPER_D2(8, 8, cb);
#undef cb
#define cb(i) auto m##i = Vector<float, 8>::load(transform_mid_buf + alpha * i);
        UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
// NOTE: only s0..s4 are written by OUTPUT_TRANSFORM; ret##i is declared here
// but not referenced in this struct.
#define cb(i) Vector<float, 8> s##i, ret##i;
        UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
        // Column pass: s = AT * m (vectorized across the 8 columns).
        OUTPUT_TRANSFORM(m, s);
// Row pass (scalar): applies A on each of the 5 rows of s, emitting the
// final 5 values per row into transform_mid_buf via mid_buf1.
#define cb(i)                                                                  \
    do {                                                                       \
        auto m1addm2 =                                                         \
                GET_VECTOR_LOW_ELEM(s, i, 1) + GET_VECTOR_LOW_ELEM(s, i, 2);   \
        auto m1subm2 =                                                         \
                GET_VECTOR_LOW_ELEM(s, i, 1) - GET_VECTOR_LOW_ELEM(s, i, 2);   \
        auto m3addm4 =                                                         \
                GET_VECTOR_LOW_ELEM(s, i, 3) + GET_VECTOR_HIGH_ELEM(s, i, 0);  \
        auto m3subm4 =                                                         \
                GET_VECTOR_LOW_ELEM(s, i, 3) - GET_VECTOR_HIGH_ELEM(s, i, 0);  \
        auto m5addm6 =                                                         \
                GET_VECTOR_HIGH_ELEM(s, i, 1) + GET_VECTOR_HIGH_ELEM(s, i, 2); \
        auto m5subm6 =                                                         \
                GET_VECTOR_HIGH_ELEM(s, i, 1) - GET_VECTOR_HIGH_ELEM(s, i, 2); \
        mid_buf1[0] =                                                          \
                GET_VECTOR_LOW_ELEM(s, i, 0) + m1addm2 + m3addm4 + m5addm6;    \
        mid_buf1[1] = 0.5f * m1subm2 + m3subm4 + 2.0f * m5subm6;               \
        mid_buf1[2] = 0.25f * m1addm2 + m3addm4 + 4.0f * m5addm6;              \
        mid_buf1[3] = 0.125f * m1subm2 + m3subm4 + 8.0f * m5subm6;             \
        mid_buf1[4] = 0.0625f * m1addm2 + m3addm4 + 16.0f * m5addm6 +          \
                      GET_VECTOR_HIGH_ELEM(s, i, 3);                           \
        mid_buf1 += 5;                                                         \
    } while (0);
        mid_buf1 = transform_mid_buf;
        UNROLL_CALL_NOWRAPPER(5, cb);
        mid_buf1 = transform_mid_buf;
#undef cb
        // Fast path: full 5x5 tile inside the image. Each row is stored as
        // one 4-lane vector store plus a scalar for the fifth element.
        if (oh_start + 5 <= OH && ow_start + 5 <= OW) {
            float32x4_t bias0;
            float32_t bias1;
            if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
                bias0 = vdupq_n_f32(bias[oc]);
                bias1 = bias[oc];
            }
            rep(i, 5) {
                size_t oh = oh_start + i;
                float32x4_t item0 = vld1q_f32(mid_buf1);
                float32_t item1 = mid_buf1[4];
                if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
                    item0 = vaddq_f32(item0, bias0);
                    item1 = item1 + bias1;
                } else if (bmode == BiasMode::BIAS) {
                    // Per-element bias: reload for every output row.
                    bias0 = vld1q_f32(bias + oc * OH * OW + oh * OW + ow_start);
                    bias1 = bias[oc * OH * OW + oh * OW + ow_start + 4];
                    item0 = vaddq_f32(item0, bias0);
                    item1 = item1 + bias1;
                }
                item0 = op(item0);
                item1 = op(item1);
                vst1q_f32(output + oc * OH * OW + oh * OW + ow_start, item0);
                output[oc * OH * OW + oh * OW + ow_start + 4] = item1;
                mid_buf1 += 5;
            }
        } else {
            // Border path: the tile is clipped by OH/OW; store element-wise.
            for (size_t oho = 0; oho < 5 && oh_start + oho < OH; ++oho) {
                for (size_t owo = 0; owo < 5 && ow_start + owo < OW; ++owo) {
                    size_t oh = oh_start + oho;
                    size_t ow = ow_start + owo;
                    float res = mid_buf1[oho * 5 + owo];
                    if (bmode == BiasMode::BIAS) {
                        res += bias[oc * OH * OW + oh * OW + ow];
                    } else if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
                        res += bias[oc];
                    }
                    res = op(res);
                    output[oc * OH * OW + oh * OW + ow] = res;
                }
            }
        }
    }
};
#undef OUTPUT_TRANSFORM | |||||
#undef GET_VECTOR_HIGH_ELEM | |||||
#undef GET_VECTOR_LOW_ELEM | |||||
} // namespace | |||||
namespace megdnn { | |||||
namespace arm_common { | |||||
namespace winograd { | |||||
MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_5x4_1x1_f) | |||||
void winograd_5x4_1x1_f::filter(const float* filter, | |||||
float* filter_transform_buf, | |||||
float* transform_mid_buf, size_t OC, size_t IC, | |||||
size_t oc_start, size_t oc_end) { | |||||
FilterTransform5X4::transform(filter, filter_transform_buf, | |||||
transform_mid_buf, OC, IC, oc_start, oc_end); | |||||
} | |||||
void winograd_5x4_1x1_f::input(const float* input, float* input_transform_buf, | |||||
float* transform_mid_buf, size_t IH, size_t IW, | |||||
size_t IC, size_t PH, size_t PW, | |||||
size_t unit_start_idx, size_t nr_units_in_tile) { | |||||
constexpr int alpha = 5 + 4 - 1; | |||||
// OW = IW + 2 * PW - KERNEL_SIZE + 1 | |||||
auto units_w = div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); | |||||
rep(ic, IC) { | |||||
rep(unit_idx, nr_units_in_tile) { | |||||
size_t index = unit_start_idx + unit_idx; | |||||
size_t nh = index / units_w; | |||||
size_t nw = index % units_w; | |||||
int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; | |||||
int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; | |||||
if (ih_start >= 0 && ih_start + alpha <= static_cast<int>(IH) && | |||||
iw_start >= 0 && iw_start + alpha <= static_cast<int>(IW)) { | |||||
InputTransform5X4::transform<true>( | |||||
input, input_transform_buf, transform_mid_buf, ih_start, | |||||
iw_start, ic, IH, IW, IC, unit_idx, nr_units_in_tile); | |||||
} else { | |||||
InputTransform5X4::transform<false>( | |||||
input, input_transform_buf, transform_mid_buf, ih_start, | |||||
iw_start, ic, IH, IW, IC, unit_idx, nr_units_in_tile); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
void winograd_5x4_1x1_f::output(const float* output_transform_buf, | |||||
const float* bias, float* output, | |||||
float* transform_mid_buf, BiasMode bmode, | |||||
NonlineMode nonline_mode, size_t OH, size_t OW, | |||||
size_t oc_start, size_t oc_end, | |||||
size_t unit_start_idx, | |||||
size_t nr_units_in_tile) { | |||||
#define cb(_bmode, _nonline_op, ...) \ | |||||
OutputTransform5X4<_bmode MEGDNN_COMMA _nonline_op>::transform(__VA_ARGS__); | |||||
auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE); | |||||
for (size_t oc = oc_start; oc < oc_end; oc++) { | |||||
size_t oc_index = oc - oc_start; | |||||
rep(unit_idx, nr_units_in_tile) { | |||||
size_t index = unit_start_idx + unit_idx; | |||||
auto nh = index / units_w; | |||||
auto nw = index % units_w; | |||||
size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | |||||
size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | |||||
DISPATCH_CONV_WINOGRAD_BIAS( | |||||
megdnn_arm_common_winograd_fp32_F54, cb, float, float, bmode, | |||||
nonline_mode, output_transform_buf, bias, output, transform_mid_buf, | |||||
oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx, | |||||
nr_units_in_tile, src_dtype, dst_dtype); | |||||
} | |||||
} | |||||
#undef cb | |||||
} | |||||
} // namespace winograd | |||||
} // namespace arm_common | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,424 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/conv_bias/fp32/strategy_6x3.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "src/arm_common/conv_bias/fp32/filter_transform.h" | |||||
#include "src/arm_common/conv_bias/fp32/helper.h" | |||||
#include "src/arm_common/conv_bias/fp32/strategy.h" | |||||
#include "src/arm_common/elemwise_helper/op_unary.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/arm_common/utils.h" | |||||
#include "src/common/unroll_macro.h" | |||||
#include "src/common/utils.h" | |||||
#include "src/fallback/conv_bias/winograd/winograd.h" | |||||
#include "src/naive/matrix_mul/matrix_mul_helper.h" | |||||
#include "midout.h" | |||||
MIDOUT_DECL(megdnn_arm_common_winograd_fp32_F63) | |||||
using namespace megdnn; | |||||
using namespace arm_common; | |||||
namespace { | |||||
/** | |||||
* input transform | |||||
* | |||||
* wd0 = (d0 - d6) + 5.25 * (d4 - d2) | |||||
* wd1 = (d6 + d2 - 4.25 * d4) + (d1 + d5 - 4.25 * d3) | |||||
* wd2 = (d6 + d2 - 4.25 * d4) - (d1 + d5 - 4.25 * d3) | |||||
* wd3 = (d6 + 0.25 * d2 - 1.25 * d4) + 2.0 * (d5 + 0.25 * d1 - 1.25 * d3) | |||||
* wd4 = (d6 + 0.25 * d2 - 1.25 * d4) - 2.0 * (d5 + 0.25 * d1 - 1.25 * d3) | |||||
* wd5 = (d6 - 5.0 * d4 + 4.0 * d2) + 2.0 * (d1 + 0.25 * d5 - 1.25 * d3) | |||||
* wd6 = (d6 - 5.0 * d4 + 4.0 * d2) - 2.0 * (d1 + 0.25 * d5 - 1.25 * d3) | |||||
* wd7 = (d7 - d1) + 5.25 * (d3 - d5) | |||||
*/ | |||||
//! One pass of the F(6,3) input transform: wd0..wd7 = BT * (d0..d7), matching
//! the derivation in the comment block above.
//! NOTE(review): the original wd5/wd6 lines referenced the bare tokens
//! `d6`, `d4`, `d2`, `d1`, `d5`, `d3` without token pasting, which only
//! worked because both invocations pass a prefix named exactly `d`; use
//! `d##N` throughout so the macro is correct for any argument names. The
//! stray `;` after `while (0)` is also dropped so the do-while(0) idiom
//! works in unbraced if/else. Expansion at the existing call sites is
//! semantically identical.
#define INPUT_TRANSFORM(d, wd)                              \
    do {                                                    \
        wd##0 = (d##0 - d##6) + (d##4 - d##2) * 5.25f;      \
        auto tmp0 = d##6 + d##2 - d##4 * 4.25f;             \
        auto tmp1 = d##1 + d##5 - d##3 * 4.25f;             \
        wd##1 = tmp0 + tmp1;                                \
        wd##2 = tmp0 - tmp1;                                \
        tmp0 = d##6 + d##2 * 0.25f - d##4 * 1.25f;          \
        tmp1 = (d##5 + d##1 * 0.25f - d##3 * 1.25f) * 2.0f; \
        wd##3 = tmp0 + tmp1;                                \
        wd##4 = tmp0 - tmp1;                                \
        tmp0 = d##6 - d##4 * 5.0f + d##2 * 4.0f;            \
        tmp1 = (d##1 + d##5 * 0.25f - d##3 * 1.25f) * 2.0f; \
        wd##5 = tmp0 + tmp1;                                \
        wd##6 = tmp0 - tmp1;                                \
        wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f;      \
    } while (0)
// Lane accessors for the two halves of a Vector<float, 8> (a float32x4x2_t).
#define GET_VECTOR_HIGH_ELEM(s, i, idx) \
    vgetq_lane_f32(CONCAT(s, i).value.val[1], idx)
#define GET_VECTOR_LOW_ELEM(s, i, idx) \
    vgetq_lane_f32(CONCAT(s, i).value.val[0], idx)
//! Input transform for winograd F(6,3): computes BT * d * B over one 8x8
//! input window per unit and scatters the result into input_transform_buf
//! with layout [alpha*alpha][nr_units_in_tile][IC].
struct InputTransform6X3 {
    //! `inner` selects the fast path where the whole 8x8 window lies inside
    //! the image; otherwise the window is zero-padded via transform_mid_buf.
    template <bool inner>
    static void transform(const float* input, float* input_transform_buf,
                          float* transform_mid_buf, int ih_start, int iw_start,
                          size_t ic, size_t IH, size_t IW, size_t IC,
                          size_t unit_idx, size_t nr_units_in_tile) {
        constexpr size_t alpha = 6 + 3 - 1;
        if (!inner) {
            // Border unit: start from an all-zero window.
            memset(transform_mid_buf, 0, sizeof(float) * alpha * alpha);
        }
#define cb(i) Vector<float, 8> d##i;
        UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
        if (inner) {
            const float* input_ptr =
                    input + ic * IH * IW + ih_start * IW + iw_start;
#define cb(i) d##i = Vector<float, 8>::load(input_ptr + IW * i);
            UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
        } else {
            // Copy the in-bounds part of the window; the rest stays zero.
            int ih0_act = std::max<int>(ih_start, 0),
                ih1_act = std::min<int>(ih_start + alpha, IH),
                iw0_act = std::max<int>(iw_start, 0),
                iw1_act = std::min<int>(iw_start + alpha, IW);
            for (int ih = ih0_act; ih < ih1_act; ++ih) {
                for (int iw = iw0_act; iw < iw1_act; ++iw) {
                    size_t iho = ih - ih_start, iwo = iw - iw_start;
                    transform_mid_buf[iho * alpha + iwo] =
                            input[ic * IH * IW + ih * IW + iw];
                }
            }
#define cb(i) d##i = Vector<float, 8>::load(transform_mid_buf + alpha * i);
            UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
        }
#define cb(i) Vector<float, 8> wd##i, ret##i;
        UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
        // Column pass: wd = BT * d.
        INPUT_TRANSFORM(d, wd);
#if MEGDNN_AARCH64
        // AArch64: transpose and repeat the pass in vector registers
        // (equivalent to multiplying by B on the right).
        TRANSPOSE_8x8(wd, d);
        INPUT_TRANSFORM(d, ret);
#define cb(i) ret##i.save(transform_mid_buf + i * alpha);
        UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
        // Transposed (j, i) read compensates for the TRANSPOSE_8x8 above.
        rep(i, alpha) rep(j, alpha) {
            input_transform_buf[(i * alpha + j) * nr_units_in_tile * IC +
                                unit_idx * IC + ic] =
                    transform_mid_buf[j * alpha + i];
        }
#else
        // AArch32 fallback: second pass done scalar per row of wd, with the
        // transpose fused into the per-lane reads. The matrix below is B.
        //! 1      0      0      0      0      0      0     0
        //! 0      1      -1     0.5    -0.5   2      -2    -1
        //! -5.25  1      1      0.25   0.25   4      4     0
        //! 0      -4.25  4.25   -2.5   2.5    -2.5   2.5   5.25
        //! 5.25   -4.25  -4.25  -1.25  -1.25  -5     -5    0
        //! 0      1      -1     2      -2     0.5    -0.5  -5.25
        //! -1     1      1      1      1      1      1     0
        //! 0      0      0      0      0      0      0     1
#define cb(i)                                                   \
    do {                                                        \
        mid_buf1[0] = GET_VECTOR_LOW_ELEM(wd, i, 0) -           \
                      GET_VECTOR_HIGH_ELEM(wd, i, 2) +          \
                      5.25f * (GET_VECTOR_HIGH_ELEM(wd, i, 0) - \
                               GET_VECTOR_LOW_ELEM(wd, i, 2));  \
        mid_buf1[7] = GET_VECTOR_HIGH_ELEM(wd, i, 3) -          \
                      GET_VECTOR_LOW_ELEM(wd, i, 1) +           \
                      5.25f * (GET_VECTOR_LOW_ELEM(wd, i, 3) -  \
                               GET_VECTOR_HIGH_ELEM(wd, i, 1)); \
        auto tmp0 = GET_VECTOR_LOW_ELEM(wd, i, 2) +             \
                    GET_VECTOR_HIGH_ELEM(wd, i, 2) -            \
                    4.25f * GET_VECTOR_HIGH_ELEM(wd, i, 0);     \
        auto tmp1 = GET_VECTOR_LOW_ELEM(wd, i, 1) +             \
                    GET_VECTOR_HIGH_ELEM(wd, i, 1) -            \
                    4.25f * GET_VECTOR_LOW_ELEM(wd, i, 3);      \
        mid_buf1[1] = tmp0 + tmp1;                              \
        mid_buf1[2] = tmp0 - tmp1;                              \
        tmp0 = GET_VECTOR_HIGH_ELEM(wd, i, 2) +                 \
               0.25f * GET_VECTOR_LOW_ELEM(wd, i, 2) -          \
               GET_VECTOR_HIGH_ELEM(wd, i, 0) * 1.25f;          \
        tmp1 = GET_VECTOR_LOW_ELEM(wd, i, 1) * 0.5f -           \
               GET_VECTOR_LOW_ELEM(wd, i, 3) * 2.5f +           \
               GET_VECTOR_HIGH_ELEM(wd, i, 1) * 2.f;            \
        mid_buf1[3] = tmp0 + tmp1;                              \
        mid_buf1[4] = tmp0 - tmp1;                              \
        tmp0 = GET_VECTOR_HIGH_ELEM(wd, i, 2) +                 \
               (GET_VECTOR_LOW_ELEM(wd, i, 2) -                 \
                GET_VECTOR_HIGH_ELEM(wd, i, 0) * 1.25f) *       \
                       4;                                       \
        tmp1 = GET_VECTOR_LOW_ELEM(wd, i, 1) * 2.f -            \
               GET_VECTOR_LOW_ELEM(wd, i, 3) * 2.5f +           \
               GET_VECTOR_HIGH_ELEM(wd, i, 1) * 0.5f;           \
        mid_buf1[5] = tmp0 + tmp1;                              \
        mid_buf1[6] = tmp0 - tmp1;                              \
        mid_buf1 += 8;                                          \
    } while (0);
        float* mid_buf1 = transform_mid_buf;
        UNROLL_CALL_NOWRAPPER(8, cb);
        mid_buf1 = transform_mid_buf;
#undef cb
        rep(i, alpha) rep(j, alpha) {
            input_transform_buf[(i * alpha + j) * nr_units_in_tile * IC +
                                unit_idx * IC + ic] =
                    transform_mid_buf[i * alpha + j];
        }
#endif
    }
};
#undef INPUT_TRANSFORM | |||||
/**
 * Output Transform: use fma
 *
 * Row pass of the inverse Winograd transform for F(6x6, 3x3): maps the
 * eight input vectors m0..m7 (prefix `m`) to the six output vectors
 * s0..s5 (prefix `s`), all Vector<float, 8>:
 *
 * s0 = m0 + (m1 + m2) + (m3 + m4) + 32 * (m5 + m6) / 32
 * s1 = (m1 - m2) + 2 * (m3 - m4) + 16 * (m5 - m6) / 32
 * s2 = (m1 + m2) + 4 * (m3 + m4) + 8 * (m5 + m6) / 32
 * s3 = (m1 - m2) + 8 * (m3 - m4) + 4 * (m5 - m6) / 32
 * s4 = (m1 + m2) + 16 * (m3 + m4) + 2 * (m5 + m6) / 32
 * s5 = (m1 - m2) + 32 * (m3 - m4) + (m5 - m6) / 32 + m7
 *
 * (m5 +/- m6) is pre-scaled once by 1/32 (0.03125f) and the per-row
 * weight folded back in through the Vector::mla multiply-accumulate,
 * matching the formulas above.
 */
#define OUTPUT_TRANSFORM(m, s)                                     \
    do {                                                           \
        auto m1addm2 = m##1 + m##2;                                \
        auto m1subm2 = m##1 - m##2;                                \
        auto m3addm4 = m##3 + m##4;                                \
        auto m3subm4 = m##3 - m##4;                                \
        auto m5addm6 = (m##5 + m##6) * 0.03125f;                   \
        auto m5subm6 = (m##5 - m##6) * 0.03125f;                   \
        s##0 = m##0;                                               \
        CONCAT(s, 0).mla(m5addm6, 32.f).add(m3addm4).add(m1addm2); \
        CONCAT(s, 1) = m1subm2;                                    \
        CONCAT(s, 1).mla(m3subm4, 2.f).mla(m5subm6, 16.f);         \
        CONCAT(s, 2) = m1addm2;                                    \
        CONCAT(s, 2).mla(m3addm4, 4.f).mla(m5addm6, 8.f);          \
        CONCAT(s, 3) = m1subm2;                                    \
        CONCAT(s, 3).mla(m3subm4, 8.f).mla(m5subm6, 4.f);          \
        CONCAT(s, 4) = m1addm2;                                    \
        CONCAT(s, 4).mla(m3addm4, 16.f).mla(m5addm6, 2.f);         \
        CONCAT(s, 5) = m1subm2;                                    \
        CONCAT(s, 5).mla(m3subm4, 32.f).add(m5subm6).add(m##7);    \
    } while (0);
/**
 * OutputTransform6X3 (DEFAULT layout): inverse Winograd transform
 * AT * m * A for F(6x6, 3x3), one output channel per call.
 *
 * Reads the 8x8 coefficients of tile `unit_idx` / channel `oc_index`
 * from `output_transform_buf`, applies the vectorized row pass
 * (OUTPUT_TRANSFORM) followed by a scalar column pass, adds bias
 * according to `bmode`, applies the element-wise activation `Op`, and
 * writes the resulting 6x6 spatial block to `output` at
 * (oh_start, ow_start).  Border tiles take the bounds-checked scalar
 * path.
 */
template <BiasMode bmode, typename Op>
struct OutputTransform6X3 {
    static void transform(const float* output_transform_buf, const float* bias,
                          float* output, float* transform_mid_buf,
                          size_t oh_start, size_t ow_start, size_t OH,
                          size_t OW, size_t oc_start, size_t oc_end,
                          size_t oc_index, size_t unit_idx,
                          size_t nr_units_in_tile, const DType& src_dtype,
                          const DType& dst_dtype) {
        constexpr size_t alpha = 6 + 3 - 1;
        Op op(src_dtype, dst_dtype);
        float* mid_buf1 = transform_mid_buf;
        //! AT * m * A
        size_t OC = oc_end - oc_start;
        size_t oc = oc_start + oc_index;
        //! Gather this tile's 8x8 coefficients into contiguous scratch.
#define cb(m, n)                                                            \
    transform_mid_buf[m * alpha + n] =                                      \
            output_transform_buf[(m * alpha + n) * nr_units_in_tile * OC +  \
                                 unit_idx * OC + oc_index];
        UNROLL_CALL_NOWRAPPER_D2(8, 8, cb);
#undef cb
#define cb(i) auto m##i = Vector<float, 8>::load(transform_mid_buf + alpha * i);
        UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
#define cb(i) Vector<float, 8> s##i, ret##i;
        UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
        //! Row pass: s0..s5 = AT applied to the rows of m.
        OUTPUT_TRANSFORM(m, s);
        /**
         * Column pass (scalar): multiply on the right by A, i.e. s * A,
         * with A =
         *
         * 1    0      0       0       0        0
         * 1    1      1       1       1        1
         * 1    -1     1       -1      1        -1
         * 1    2      4       8       16       32
         * 1    -2     4       -8      16       -32
         * 1    0.5    0.25    0.125   0.0625   0.03125
         * 1    -0.5   0.25    -0.125  0.0625   -0.03125
         * 0    0.0    0       0       0        1
         */
#define cb(i)                                                                 \
    do {                                                                      \
        auto m1addm2 =                                                        \
                GET_VECTOR_LOW_ELEM(s, i, 1) + GET_VECTOR_LOW_ELEM(s, i, 2);  \
        auto m1subm2 =                                                        \
                GET_VECTOR_LOW_ELEM(s, i, 1) - GET_VECTOR_LOW_ELEM(s, i, 2);  \
        auto m3addm4 =                                                        \
                GET_VECTOR_LOW_ELEM(s, i, 3) + GET_VECTOR_HIGH_ELEM(s, i, 0); \
        auto m3subm4 =                                                        \
                GET_VECTOR_LOW_ELEM(s, i, 3) - GET_VECTOR_HIGH_ELEM(s, i, 0); \
        auto m5addm6 =                                                        \
                GET_VECTOR_HIGH_ELEM(s, i, 1) + GET_VECTOR_HIGH_ELEM(s, i, 2); \
        auto m5subm6 =                                                        \
                GET_VECTOR_HIGH_ELEM(s, i, 1) - GET_VECTOR_HIGH_ELEM(s, i, 2); \
        mid_buf1[0] =                                                         \
                GET_VECTOR_LOW_ELEM(s, i, 0) + m1addm2 + m3addm4 + m5addm6;   \
        mid_buf1[1] = m1subm2 + 2.f * m3subm4 + 0.5f * m5subm6;               \
        mid_buf1[2] = m1addm2 + 4.f * m3addm4 + 0.25f * m5addm6;              \
        mid_buf1[3] = m1subm2 + 8.f * m3subm4 + 0.125f * m5subm6;             \
        mid_buf1[4] = m1addm2 + 16.f * m3addm4 + 0.0625f * m5addm6;           \
        mid_buf1[5] = m1subm2 + 32.f * m3subm4 + 0.03125f * m5subm6 +         \
                      GET_VECTOR_HIGH_ELEM(s, i, 3);                          \
        mid_buf1 += 6;                                                        \
    } while (0);
        mid_buf1 = transform_mid_buf;
        UNROLL_CALL_NOWRAPPER(6, cb);
        mid_buf1 = transform_mid_buf;
#undef cb
        //! Fast path: the whole 6x6 block lies inside the output plane,
        //! store each row as one 4-lane vector plus a 2-lane tail.
        if (oh_start + 6 <= OH && ow_start + 6 <= OW) {
            float32x4_t bias0;
            float32x2_t bias1;
            if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
                bias0 = vdupq_n_f32(bias[oc]);
                bias1 = vdup_n_f32(bias[oc]);
            }
            rep(i, 6) {
                size_t oh = oh_start + i;
                float32x4_t item0 = vld1q_f32(mid_buf1);
                float32x2_t item1 = vld1_f32(mid_buf1 + 4);
                if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
                    item0 = vaddq_f32(item0, bias0);
                    item1 = vadd_f32(item1, bias1);
                } else if (bmode == BiasMode::BIAS) {
                    //! per-element bias: reload the matching output row
                    bias0 = vld1q_f32(bias + oc * OH * OW + oh * OW + ow_start);
                    bias1 = vld1_f32(bias + oc * OH * OW + oh * OW + ow_start +
                                     4);
                    item0 = vaddq_f32(item0, bias0);
                    item1 = vadd_f32(item1, bias1);
                }
                item0 = op(item0);
                //! activation applied lane-wise on the 2-lane tail
                item1 = vset_lane_f32(op(vget_lane_f32(item1, 0)), item1, 0);
                item1 = vset_lane_f32(op(vget_lane_f32(item1, 1)), item1, 1);
                vst1q_f32(output + oc * OH * OW + oh * OW + ow_start, item0);
                vst1_f32(output + oc * OH * OW + oh * OW + ow_start + 4, item1);
                mid_buf1 += 6;
            }
        } else {
            //! Border tile: scalar stores with per-element bounds checks.
            for (size_t oho = 0; oho < 6 && oh_start + oho < OH; ++oho) {
                for (size_t owo = 0; owo < 6 && ow_start + owo < OW; ++owo) {
                    size_t oh = oh_start + oho;
                    size_t ow = ow_start + owo;
                    float res = mid_buf1[oho * 6 + owo];
                    if (bmode == BiasMode::BIAS) {
                        res += bias[oc * OH * OW + oh * OW + ow];
                    } else if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
                        res += bias[oc];
                    }
                    res = op(res);
                    output[oc * OH * OW + oh * OW + ow] = res;
                }
            }
        }
    }
};
#undef GET_VECTOR_HIGH_ELEM | |||||
#undef GET_VECTOR_LOW_ELEM | |||||
#undef OUTPUT_TRANSFORM | |||||
} // namespace | |||||
namespace megdnn { | |||||
namespace arm_common { | |||||
namespace winograd { | |||||
MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_6x3_1x1_f) | |||||
void winograd_6x3_1x1_f::filter(const float* filter, | |||||
float* filter_transform_buf, | |||||
float* transform_mid_buf, size_t OC, size_t IC, | |||||
size_t oc_start, size_t oc_end) { | |||||
FilterTransform6X3<param::MatrixMul::Format::DEFAULT>::transform( | |||||
filter, filter_transform_buf, transform_mid_buf, OC, IC, oc_start, | |||||
oc_end); | |||||
} | |||||
void winograd_6x3_1x1_f::input(const float* input, float* input_transform_buf, | |||||
float* transform_mid_buf, size_t IH, size_t IW, | |||||
size_t IC, size_t PH, size_t PW, | |||||
size_t unit_start_idx, size_t nr_units_in_tile) { | |||||
constexpr int alpha = 3 + 6 - 1; | |||||
// OW = IW + 2 * PW - KERNEL_SIZE + 1 | |||||
auto units_w = div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); | |||||
rep(ic, IC) { | |||||
rep(unit_idx, nr_units_in_tile) { | |||||
size_t index = unit_start_idx + unit_idx; | |||||
size_t nh = index / units_w; | |||||
size_t nw = index % units_w; | |||||
int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; | |||||
int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; | |||||
if (ih_start >= 0 && ih_start + alpha <= static_cast<int>(IH) && | |||||
iw_start >= 0 && iw_start + alpha <= static_cast<int>(IW)) { | |||||
InputTransform6X3::transform<true>( | |||||
input, input_transform_buf, transform_mid_buf, ih_start, | |||||
iw_start, ic, IH, IW, IC, unit_idx, nr_units_in_tile); | |||||
} else { | |||||
InputTransform6X3::transform<false>( | |||||
input, input_transform_buf, transform_mid_buf, ih_start, | |||||
iw_start, ic, IH, IW, IC, unit_idx, nr_units_in_tile); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
void winograd_6x3_1x1_f::output(const float* output_transform_buf, | |||||
const float* bias, float* output, | |||||
float* transform_mid_buf, BiasMode bmode, | |||||
NonlineMode nonline_mode, size_t OH, size_t OW, | |||||
size_t oc_start, size_t oc_end, | |||||
size_t unit_start_idx, | |||||
size_t nr_units_in_tile) { | |||||
#define cb(_bmode, _nonline_op, ...) \ | |||||
OutputTransform6X3<_bmode MEGDNN_COMMA _nonline_op>::transform(__VA_ARGS__); | |||||
auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE); | |||||
for (size_t oc = oc_start; oc < oc_end; oc++) { | |||||
size_t oc_index = oc - oc_start; | |||||
rep(unit_idx, nr_units_in_tile) { | |||||
size_t index = unit_start_idx + unit_idx; | |||||
auto nh = index / units_w; | |||||
auto nw = index % units_w; | |||||
size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | |||||
size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | |||||
DISPATCH_CONV_WINOGRAD_BIAS( | |||||
megdnn_arm_common_winograd_fp32_F63, cb, float, float, bmode, | |||||
nonline_mode, output_transform_buf, bias, output, transform_mid_buf, | |||||
oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx, | |||||
nr_units_in_tile, src_dtype, dst_dtype); | |||||
} | |||||
} | |||||
#undef cb | |||||
} | |||||
} // namespace winograd | |||||
} // namespace arm_common | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,351 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/conv_bias/fp32/strategy_6x3_4x4.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "src/arm_common/conv_bias/fp32/filter_transform.h" | |||||
#include "src/arm_common/conv_bias/fp32/helper.h" | |||||
#include "src/arm_common/conv_bias/fp32/strategy.h" | |||||
#include "src/arm_common/elemwise_helper/op_unary.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/arm_common/utils.h" | |||||
#include "src/common/unroll_macro.h" | |||||
#include "src/common/utils.h" | |||||
#include "src/common/winograd/winograd_helper.h" | |||||
#include "src/fallback/conv_bias/winograd/winograd.h" | |||||
#include "midout.h" | |||||
MIDOUT_DECL(megdnn_arm_common_winograd_fp32_F63_4x4) | |||||
using namespace megdnn; | |||||
using namespace arm_common; | |||||
namespace { | |||||
/**
 * InputTransform6X3 (MK4 layout): input transform for F(6x6, 3x3)
 * Winograd convolution, processing 4 input channels per call.
 *
 * prepare<inner>() gathers an 8x8 input patch for 4 consecutive channels
 * into `patch` (zero-filling when the tile crosses the padding or fewer
 * than 4 channels remain) and transposes it channel-innermost into
 * `patchT`.  transform() then applies BT * d * B and stores the result
 * into `input_transform_buf` in MK4 blocked layout.
 */
struct InputTransform6X3 {
    template <bool inner>
    static void prepare(const float* input, float* patch, float* patchT,
                        int ih_start, int iw_start, size_t IH, size_t IW,
                        size_t ic, size_t IC) {
        constexpr size_t alpha = 6 + 3 - 1;
        //! Zero the scratch patch whenever part of it may stay unwritten
        //! (border tile, or a trailing channel group smaller than 4).
        //! NOTE(review): the condition uses `ic + 4 < IC` (strict), so the
        //! last complete group of 4 channels also takes the memset even
        //! though its loads below overwrite everything — harmless, but
        //! possibly intended as `<=`; confirm.
        if (!(inner && ic + 4 < IC)) {
            memset(patch, 0, sizeof(float) * 4 * alpha * alpha);
        }
        if (inner) {
            //! Fully-inside tile: copy each of the 8 rows with two 4-lane
            //! vector loads/stores per channel.
            const float* input_ptr =
                    input + ic * IH * IW + ih_start * IW + iw_start;
            for (size_t ico = 0; ico < 4; ++ico) {
                if (ic + ico < IC) {
#define cb(i)                                        \
    auto v##i##0 = vld1q_f32(input_ptr + IW * i);    \
    auto v##i##1 = vld1q_f32(input_ptr + IW * i + 4);
                    UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
#define cb(i)                                                  \
    vst1q_f32(patch + ico * 8 * alpha + i * 8, v##i##0);       \
    vst1q_f32(patch + ico * 8 * alpha + i * 8 + 4, v##i##1);
                    UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
                    input_ptr += IH * IW;
                }
            }
        } else {
            //! Border tile: clamp the copy window to the valid region of
            //! the input plane; everything else stays zero from the memset.
            int ih0_act = std::max<int>(ih_start, 0),
                ih1_act = std::min<int>(ih_start + alpha, IH),
                iw0_act = std::max<int>(iw_start, 0),
                iw1_act = std::min<int>(iw_start + alpha, IW);
            // partial copy
            for (size_t ico = 0; ico < 4; ++ico) {
                if (ic + ico < IC) {
                    for (int ih = ih0_act; ih < ih1_act; ++ih) {
                        for (int iw = iw0_act; iw < iw1_act; ++iw) {
                            size_t iho = ih - ih_start, iwo = iw - iw_start;
                            patch[ico * alpha * 8 + iho * 8 + iwo] =
                                    input[(ic + ico) * IH * IW + ih * IW + iw];
                        }
                    }
                }
            }
        }
        //! Re-layout from [channel][row][col] to channel-innermost 4x4
        //! blocks so transform() can load one Vector<float, 4> per cell.
#define cb(i)                                                            \
    transpose_4x4(patch + 8 * i + 0, patchT + 8 * i * 4, 64, 4);         \
    transpose_4x4(patch + 8 * i + 4, patchT + 8 * i * 4 + 4 * 4, 64, 4);
        UNROLL_CALL_NOWRAPPER(8, cb)
#undef cb
    }

    static void transform(const float* patchT, float* input_transform_buf,
                          size_t unit_idx, size_t nr_units_in_tile, size_t ic,
                          size_t IC) {
        constexpr size_t alpha = 6 + 3 - 1;
        // BT * d * B
        //! Load the 8x8 patch, 4 channels per cell.
#define cb(m, n)                                             \
    Vector<float, 4> d##m##n =                               \
            Vector<float, 4>::load(patchT + m * 8 * 4 + n * 4);
        UNROLL_CALL_NOWRAPPER_D2(8, 8, cb);
#undef cb
        //! B
        //! 1     0     0     0     0    0    0     0
        //! 0     1     -1    0.5   -0.5 2    -2    -1
        //! -5.25 1     1     0.25  0.25 4    4     0
        //! 0     -4.25 4.25  -2.5  2.5  -2.5 2.5   5.25
        //! 5.25  -4.25 -4.25 -1.25 -1.25 -5  -5    0
        //! 0     1     -1    2     -2   0.5  -0.5  -5.25
        //! -1    1     1     1     1    1    1     0
        //! 0     0     0     0     0    0    0     1
        //! First pass: t = BT * d (column-wise).
#define cb(m)                                                               \
    auto t0##m = d0##m + (d4##m - d2##m) * 5.25f - d6##m;                   \
    auto t1##m = d1##m + d2##m + d5##m + d6##m - (d3##m + d4##m) * 4.25f;   \
    auto t2##m = d2##m + d6##m - (d1##m + d5##m) + (d3##m - d4##m) * 4.25f; \
    auto t3##m = d1##m * 0.5f + d2##m * 0.25f - d3##m * 2.5f - d4##m * 1.25f + \
                 d5##m * 2.f + d6##m;                                       \
    auto t4##m = d1##m * (-0.5f) + d2##m * 0.25f + d3##m * 2.5f -           \
                 d4##m * 1.25f - d5##m * 2.f + d6##m;                       \
    auto t5##m = d1##m * 2.f + d2##m * 4.f - d3##m * 2.5f - d4##m * 5.f +   \
                 d5##m * 0.5f + d6##m;                                      \
    auto t6##m = d1##m * (-2.f) + d2##m * 4.f + d3##m * 2.5f - d4##m * 5.f - \
                 d5##m * 0.5f + d6##m;                                      \
    auto t7##m = (d7##m - d1##m) + (d3##m - d5##m) * 5.25f;
        UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
        //! Second pass: d = t * B (row-wise), reusing the d registers.
#define cb(m)                                                               \
    d##m##0 = t##m##0 + (t##m##4 - t##m##2) * 5.25f - t##m##6;              \
    d##m##1 = t##m##1 + t##m##2 + t##m##5 + t##m##6 -                       \
              (t##m##3 + t##m##4) * 4.25f;                                  \
    d##m##2 = t##m##2 + t##m##6 - (t##m##1 + t##m##5) +                     \
              (t##m##3 - t##m##4) * 4.25f;                                  \
    d##m##3 = t##m##1 * 0.5f + t##m##2 * 0.25f - t##m##3 * 2.5f -           \
              t##m##4 * 1.25f + t##m##5 * 2.f + t##m##6;                    \
    d##m##4 = t##m##1 * (-0.5f) + t##m##2 * 0.25f + t##m##3 * 2.5f -        \
              t##m##4 * 1.25f - t##m##5 * 2.f + t##m##6;                    \
    d##m##5 = t##m##1 * 2.f + t##m##2 * 4.f - t##m##3 * 2.5f - t##m##4 * 5.f + \
              t##m##5 * 0.5f + t##m##6;                                     \
    d##m##6 = t##m##1 * (-2.f) + t##m##2 * 4.f + t##m##3 * 2.5f -           \
              t##m##4 * 5.f - t##m##5 * 0.5f + t##m##6;                     \
    d##m##7 = (t##m##7 - t##m##1) + (t##m##3 - t##m##5) * 5.25f;
        UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
        //! Store in MK4 blocked layout:
        //! [alpha * alpha][ICB][nr_units_in_tile][4].
        size_t ICB = IC / 4;
        size_t icb = ic / 4;
#define cb(m, n)                                                  \
    d##m##n.save(input_transform_buf +                            \
                 (m * alpha + n) * ICB * nr_units_in_tile * 4 +   \
                 icb * nr_units_in_tile * 4 + unit_idx * 4);
        UNROLL_CALL_NOWRAPPER_D2(8, 8, cb)
#undef cb
    }
};
/**
 * OutputTransform6X3 (MK4 layout): inverse Winograd transform for
 * F(6x6, 3x3), processing 4 output channels at once.
 *
 * Loads the 8x8 tile (4 channels deep) from `output_transform_buf` in
 * MK4 blocked layout, applies AT on rows then columns with
 * Vector<float, 4> arithmetic, adds bias and the activation `Op`
 * according to `bmode`, stages the 6x6x4 result in `transform_mid_buf`,
 * and scatters the valid region into the NCHW `output`.
 */
template <BiasMode bmode, typename Op>
struct OutputTransform6X3 {
    static void transform(const float* output_transform_buf, const float* bias,
                          float* output, float* transform_mid_buf,
                          size_t oh_start, size_t ow_start, size_t OH,
                          size_t OW, size_t oc_start, size_t oc_end,
                          size_t oc_index, size_t unit_idx,
                          size_t nr_units_in_tile, const DType& src_dtype,
                          const DType& dst_dtype) {
        Op op(src_dtype, dst_dtype);
        //! AT * m * A
        constexpr size_t alpha = 6 + 3 - 1;
        size_t oc = oc_start + oc_index;
        size_t OCB = (oc_end - oc_start) / 4;
        size_t ocb = oc_index / 4;
        //! Load the 8x8 coefficients of this tile, 4 channels per cell,
        //! from the [alpha * alpha][OCB][nr_units_in_tile][4] layout.
#define cb(m, n)                                                   \
    auto v##m##n = Vector<float, 4>::load(                         \
            output_transform_buf +                                 \
            (m * alpha + n) * OCB * nr_units_in_tile * 4 +         \
            ocb * nr_units_in_tile * 4 + unit_idx * 4);
        UNROLL_CALL_NOWRAPPER_D2(8, 8, cb);
#undef cb
        /**
         * A
         *
         * 1    0      0       0       0        0
         * 1    1      1       1       1        1
         * 1    -1     1       -1      1        -1
         * 1    2      4       8       16       32
         * 1    -2     4       -8      16       -32
         * 1    0.5    0.25    0.125   0.0625   0.03125
         * 1    -0.5   0.25    -0.125  0.0625   -0.03125
         * 0    0.0    0       0       0        1
         */
        Vector<float, 4> v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, v5subv6;
        //! Row pass: t[m][0..5] = AT applied to column m of v.
#define cb(m)                                                        \
    v1addv2 = v1##m + v2##m;                                         \
    v1subv2 = v1##m - v2##m;                                         \
    v3addv4 = v3##m + v4##m;                                         \
    v3subv4 = v3##m - v4##m;                                         \
    v5addv6 = v5##m + v6##m;                                         \
    v5subv6 = v5##m - v6##m;                                         \
    auto t0##m = v0##m + v1addv2 + v3addv4 + v5addv6;                \
    auto t1##m = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f;           \
    auto t2##m = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f;          \
    auto t3##m = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f;         \
    auto t4##m = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f;       \
    auto t5##m = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + v7##m;
        UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
        //! Column pass: v[m][0..5] = t[m] * A, reusing the v registers.
#define cb(m)                                                        \
    v1addv2 = t##m##1 + t##m##2;                                     \
    v1subv2 = t##m##1 - t##m##2;                                     \
    v3addv4 = t##m##3 + t##m##4;                                     \
    v3subv4 = t##m##3 - t##m##4;                                     \
    v5addv6 = t##m##5 + t##m##6;                                     \
    v5subv6 = t##m##5 - t##m##6;                                     \
    v##m##0 = t##m##0 + v1addv2 + v3addv4 + v5addv6;                 \
    v##m##1 = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f;              \
    v##m##2 = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f;             \
    v##m##3 = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f;            \
    v##m##4 = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f;          \
    v##m##5 = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + t##m##7;
        UNROLL_CALL_NOWRAPPER(6, cb);
#undef cb
        Vector<float, 4> vbias;
        //! Channel-broadcast bias is added vectorized here.
        if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
            vbias = Vector<float, 4>::load(bias + oc);
#define cb(m, n) v##m##n += vbias;
            UNROLL_CALL_RAW_D2(6, 6, cb);
#undef cb
        }
        //! For non-BIAS modes the activation is applied vectorized now;
        //! BIAS mode defers it to the scalar loop below, after the
        //! per-element bias has been added.
        if (bmode != BiasMode::BIAS) {
#define cb(m, n) v##m##n = op(CONCAT(v##m, n).value);
            UNROLL_CALL_RAW_D2(6, 6, cb);
#undef cb
        }
#define cb(m, n) CONCAT(v##m, n).save(transform_mid_buf + (m * 6 + n) * 4);
        UNROLL_CALL_RAW_D2(6, 6, cb);
#undef cb
        //! Scatter the valid part (bounded by OH, OW and oc_end) of the
        //! 6x6x4 staging buffer into the NCHW output.
        for (size_t oco = 0; oco < 4 && oc + oco < oc_end; ++oco) {
            for (size_t oho = 0; oho < 6 && oh_start + oho < OH; ++oho) {
                for (size_t owo = 0; owo < 6 && ow_start + owo < OW; ++owo) {
                    size_t oh = oh_start + oho;
                    size_t ow = ow_start + owo;
                    float res = transform_mid_buf[oho * 6 * 4 + owo * 4 + oco];
                    if (bmode == BiasMode::BIAS) {
                        res += bias[(oc + oco) * OH * OW + oh * OW + ow];
                        res = op(res);
                    }
                    output[(oc + oco) * OH * OW + oh * OW + ow] = res;
                }
            }
        }
    }
};
} // namespace | |||||
namespace megdnn { | |||||
namespace arm_common { | |||||
namespace winograd { | |||||
MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_6x3_4x4_f) | |||||
void winograd_6x3_4x4_f::filter(const float* filter, | |||||
float* filter_transform_buf, | |||||
float* transform_mid_buf, size_t OC, size_t IC, | |||||
size_t oc_start, size_t oc_end) { | |||||
FilterTransform6X3<param::MatrixMul::Format::MK4>::transform( | |||||
filter, filter_transform_buf, transform_mid_buf, OC, IC, oc_start, | |||||
oc_end); | |||||
} | |||||
void winograd_6x3_4x4_f::input(const float* input, float* input_transform_buf, | |||||
float* transform_mid_buf, size_t IH, size_t IW, | |||||
size_t IC, size_t PH, size_t PW, | |||||
size_t unit_start_idx, size_t nr_units_in_tile) { | |||||
megdnn_assert(IC % 4 == 0); | |||||
constexpr int alpha = 3 + 6 - 1; | |||||
// OW = IW + 2 * PW - KERNEL_SIZE + 1 | |||||
auto units_w = div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); | |||||
float* patch = transform_mid_buf; | |||||
float* patchT = transform_mid_buf + 4 * alpha * alpha; | |||||
for (size_t ic = 0; ic < IC; ic += 4) { | |||||
rep(unit_idx, nr_units_in_tile) { | |||||
size_t index = unit_start_idx + unit_idx; | |||||
size_t nh = index / units_w; | |||||
size_t nw = index % units_w; | |||||
int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; | |||||
int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; | |||||
if (ih_start >= 0 && ih_start + alpha <= static_cast<int>(IH) && | |||||
iw_start >= 0 && iw_start + alpha <= static_cast<int>(IW)) { | |||||
InputTransform6X3::prepare<true>(input, patch, patchT, ih_start, | |||||
iw_start, IH, IW, ic, IC); | |||||
InputTransform6X3::transform(patchT, input_transform_buf, | |||||
unit_idx, nr_units_in_tile, ic, | |||||
IC); | |||||
} else { | |||||
InputTransform6X3::prepare<false>(input, patch, patchT, | |||||
ih_start, iw_start, IH, IW, | |||||
ic, IC); | |||||
InputTransform6X3::transform(patchT, input_transform_buf, | |||||
unit_idx, nr_units_in_tile, ic, | |||||
IC); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
void winograd_6x3_4x4_f::output(const float* output_transform_buf, | |||||
const float* bias, float* output, | |||||
float* transform_mid_buf, BiasMode bmode, | |||||
NonlineMode nonline_mode, size_t OH, size_t OW, | |||||
size_t oc_start, size_t oc_end, size_t unit_start_idx, | |||||
size_t nr_units_in_tile) { | |||||
#define cb(_bmode, _nonline_op, ...) \ | |||||
OutputTransform6X3<_bmode MEGDNN_COMMA _nonline_op>::transform(__VA_ARGS__); | |||||
auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE); | |||||
for (size_t oc = oc_start; oc < oc_end; oc += 4) { | |||||
size_t oc_index = oc - oc_start; | |||||
rep(unit_idx, nr_units_in_tile) { | |||||
size_t index = unit_start_idx + unit_idx; | |||||
auto nh = index / units_w; | |||||
auto nw = index % units_w; | |||||
size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | |||||
size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | |||||
DISPATCH_CONV_WINOGRAD_BIAS( | |||||
megdnn_arm_common_winograd_fp32_F63_4x4, cb, float, float, bmode, | |||||
nonline_mode, output_transform_buf, bias, output, transform_mid_buf, | |||||
oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx, | |||||
nr_units_in_tile, src_dtype, dst_dtype); | |||||
} | |||||
} | |||||
#undef cb | |||||
} | |||||
} // namespace winograd | |||||
} // namespace arm_common | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,82 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/conv_bias/img2col_helper.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include <cstddef> | |||||
#include "src/common/utils.h" | |||||
namespace { | |||||
//! Strided img2col: lay out every (ic, fh, fw) input plane sampled at
//! stride (SH, SW) as one contiguous OH*OW row of `dst`, so convolution
//! becomes a matrix multiplication.  When is_xcorr is false the kernel
//! offsets are flipped (true convolution instead of cross-correlation).
//! `dst` must hold IC * FH * FW * OH * OW elements.
template <bool is_xcorr, typename dtype>
void img2col_stride(const dtype* __restrict src,
                    dtype* __restrict dst, const int OC, const int OH,
                    const int OW, const int IC, const int IH, const int IW,
                    const int FH, const int FW, const int SH, const int SW) {
    (void)OC;  // kept only for interface compatibility
    size_t out_idx = 0;
    for (int ic = 0; ic < IC; ++ic) {
        for (int fh = 0; fh < FH; ++fh) {
            for (int fw = 0; fw < FW; ++fw) {
                //! flipped kernel offsets for true convolution
                const int fh2 = is_xcorr ? fh : FH - fh - 1;
                const int fw2 = is_xcorr ? fw : FW - fw - 1;
                for (int oh = 0; oh < OH; ++oh) {
                    for (int ow = 0; ow < OW; ++ow) {
                        dst[out_idx++] = src[ic * IH * IW +
                                             (oh * SH + fh2) * IW +
                                             (ow * SW + fw2)];
                    }
                }
            }
        }
    }
}
//! Unit-stride img2col, manually unrolled by 4 along OW.  The inner loop
//! always emits a full group of 4 and the cursor is rewound by
//! `overshoot` after each output row so that extra elements are
//! overwritten by the next row.
//! NOTE(review): when OW % 4 != 0 the final group still stores 4 slots,
//! i.e. up to 3 elements past the logical end of dst before the rewind —
//! callers must over-allocate dst accordingly; confirm against callers.
template <bool is_xcorr, typename dtype>
void img2col(const dtype* src, dtype* dst, size_t /* OC */, size_t OH,
             size_t OW, size_t IC, size_t IH, size_t IW, size_t FH, size_t FW) {
    size_t overshoot = (4 - OW % 4) % 4;
    size_t out_idx = 0;
    for (size_t ic = 0; ic < IC; ++ic) {
        for (size_t fh = 0; fh < FH; ++fh) {
            for (size_t fw = 0; fw < FW; ++fw) {
                //! flipped kernel offsets for true convolution
                size_t fh2 = is_xcorr ? fh : FH - fh - 1;
                size_t fw2 = is_xcorr ? fw : FW - fw - 1;
                for (size_t oh = 0; oh < OH; ++oh) {
                    const dtype* row =
                            src + ic * IH * IW + (oh + fh2) * IW + fw2;
                    for (size_t ow = 0; ow < OW; ow += 4) {
                        dst[out_idx + 0] = row[ow + 0];
                        dst[out_idx + 1] = row[ow + 1];
                        dst[out_idx + 2] = row[ow + 2];
                        dst[out_idx + 3] = row[ow + 3];
                        out_idx += 4;
                    }
                    out_idx -= overshoot;
                }
            }
        }
    }
}
} // anonymous namespace | |||||
// vim: syntax=cpp.doxygen |