@@ -19,11 +19,14 @@ CHECK_CXX_COMPILER_FLAG(-Wclass-memaccess CXX_SUPPORT_WCLASS_MEMACCESS) | |||
set(MGE_ARCH AUTO CACHE STRING "Architecture for which MegEngine is built.")
set_property(CACHE MGE_ARCH PROPERTY STRINGS AUTO | |||
x86_64 i386 | |||
armv7 aarch64 | |||
naive fallback | |||
) | |||
option(MGE_WITH_JIT "Build MegEngine with JIT." ON) | |||
option(MGE_WITH_HALIDE "Build MegEngine with Halide JIT." ON)
option(MGE_ARMV8_2_FEATURE_FP16 "Enable armv8.2-a+fp16 support" OFF) | |||
option(MGE_ARMV8_2_FEATURE_DOTPROD "Enable armv8.2-a+dotprod support" OFF)
option(MGE_DISABLE_FLOAT16 "Disable MegEngine float16 support." OFF) | |||
option(MGE_WITH_CUDA "Enable MegEngine CUDA support." ON) | |||
option(MGE_CUDA_USE_STATIC "Enable MegEngine CUDA static linking." ON) | |||
@@ -31,12 +34,52 @@ option(MGE_WITH_TRT "Build MegEngine with TensorRT." ON) | |||
option(MGE_USE_SYSTEM_LIB "Build MegEngine with system libraries." OFF) | |||
option(MGB_WITH_FLATBUFFERS "Build MegBrain with FlatBuffers serialization support." ON) | |||
if(CMAKE_TOOLCHAIN_FILE) | |||
message("We are cross compiling.") | |||
message("config FLATBUFFERS_FLATC_EXECUTABLE to: ${PROJECT_SOURCE_DIR}/build_dir/host_flatc/install/bin/flatc") | |||
set(FLATBUFFERS_FLATC_EXECUTABLE "${PROJECT_SOURCE_DIR}/build_dir/host_flatc/install/bin/flatc") | |||
if(ANDROID_TOOLCHAIN_ROOT) | |||
if(NOT "${ANDROID_ARCH_NAME}" STREQUAL "") | |||
set(ANDROID_ARCH ${ANDROID_ARCH_NAME}) | |||
endif() | |||
if(${ANDROID_ARCH} STREQUAL "arm") | |||
set(MGE_ARCH "armv7") | |||
elseif(${ANDROID_ARCH} STREQUAL "arm64") | |||
set(MGE_ARCH "aarch64") | |||
else() | |||
message(FATAL_ERROR "DO NOT SUPPORT ANDROID ARCH NOW") | |||
endif() | |||
elseif(IOS_TOOLCHAIN_ROOT) | |||
if(${IOS_ARCH} STREQUAL "armv7") | |||
set(MGE_ARCH "armv7") | |||
elseif(${IOS_ARCH} STREQUAL "arm64") | |||
set(MGE_ARCH "aarch64") | |||
elseif(${IOS_ARCH} STREQUAL "armv7k") | |||
set(MGE_ARCH "armv7") | |||
elseif(${IOS_ARCH} STREQUAL "arm64e") | |||
set(MGE_ARCH "aarch64") | |||
elseif(${IOS_ARCH} STREQUAL "armv7s") | |||
set(MGE_ARCH "armv7") | |||
else() | |||
message(FATAL_ERROR "Unsupported IOS_ARCH.") | |||
endif() | |||
elseif(NOT "${ARM_CROSS_BUILD_ARCH}" STREQUAL "") | |||
set(MGE_ARCH ${ARM_CROSS_BUILD_ARCH}) | |||
else() | |||
message(FATAL_ERROR "Unknown cross-compiling settings.") | |||
endif() | |||
message("CONFIG MGE_ARCH TO ${MGE_ARCH}") | |||
endif() | |||
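# Illustrative invocation (a sketch; paths are placeholders, not from this
# diff): with the Android NDK toolchain file, the branch above derives
# MGE_ARCH from ANDROID_ARCH_NAME:
#   cmake -DCMAKE_TOOLCHAIN_FILE=$NDK/build/cmake/android.toolchain.cmake \
#         -DANDROID_ABI=arm64-v8a <src-dir>    # yields MGE_ARCH=aarch64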
if(${MGE_ARCH} STREQUAL "AUTO") | |||
if(${CMAKE_SYSTEM_PROCESSOR} STREQUAL "x86_64") | |||
set(MGE_ARCH "x86_64") | |||
elseif(${CMAKE_SYSTEM_PROCESSOR} STREQUAL "i386" OR ${CMAKE_SYSTEM_PROCESSOR} STREQUAL "i686") | |||
set(MGE_ARCH "i386") | |||
elseif(${CMAKE_SYSTEM_PROCESSOR} STREQUAL "aarch64" OR ${CMAKE_SYSTEM_PROCESSOR} STREQUAL "arm64") | |||
set(MGE_ARCH "aarch64") | |||
elseif(${CMAKE_SYSTEM_PROCESSOR} MATCHES "^arm") | |||
set(MGE_ARCH "armv7") | |||
else() | |||
message(FATAL "Unknown machine architecture for MegEngine.") | |||
endif() | |||
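# MGE_ARCH can also be set explicitly to skip AUTO detection, e.g. (usage
# sketch): cmake -DMGE_ARCH=armv7 <src-dir>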
@@ -399,6 +442,38 @@ if(MGE_ARCH STREQUAL "x86_64" OR MGE_ARCH STREQUAL "i386") | |||
endif() | |||
endif() | |||
if(MGE_ARCH STREQUAL "armv7") | |||
# -funsafe-math-optimizations enables NEON auto-vectorization (NEON is not fully IEEE 754 compliant, so GCC does not enable it by default).
if(ANDROID) | |||
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mfloat-abi=softfp -mfpu=neon") | |||
endif() | |||
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -funsafe-math-optimizations") | |||
set (MARCH "-march=armv7-a") | |||
set (MEGDNN_ARMV7 1) | |||
endif() | |||
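# A hedged alternative (not in this diff): probe the flag first, mirroring the
# CHECK_CXX_COMPILER_FLAG call at the top of this file; the result variable
# name CXX_SUPPORT_UNSAFE_MATH is illustrative.
#   CHECK_CXX_COMPILER_FLAG(-funsafe-math-optimizations CXX_SUPPORT_UNSAFE_MATH)
#   if(CXX_SUPPORT_UNSAFE_MATH)
#       set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -funsafe-math-optimizations")
#   endif()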
if(MGE_ARCH STREQUAL "aarch64") | |||
set(MEGDNN_AARCH64 1) | |||
set(MEGDNN_64_BIT 1) | |||
set(MARCH "-march=armv8-a") | |||
if(MGE_ARMV8_2_FEATURE_FP16) | |||
message("Enable fp16 feature support in armv8.2") | |||
if(NOT ${MGE_DISABLE_FLOAT16}) | |||
set(MEGDNN_ENABLE_FP16_NEON 1) | |||
endif() | |||
set(MARCH "-march=armv8.2-a+fp16") | |||
endif() | |||
if(MGE_ARMV8_2_FEATURE_DOTPROD) | |||
message("Enable dotprod feature support in armv8.2") | |||
if(MGE_ARMV8_2_FEATURE_FP16) | |||
set(MARCH "-march=armv8.2-a+fp16+dotprod") | |||
else() | |||
set(MARCH "-march=armv8.2-a+dotprod") | |||
endif() | |||
endif() | |||
endif() | |||
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${MARCH}") | |||
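# Downstream C++ sees these -march choices through ACLE feature-test macros:
# +fp16 defines __ARM_FEATURE_FP16_VECTOR_ARITHMETIC and +dotprod defines
# __ARM_FEATURE_DOTPROD, the two macros the kernels in this diff are gated on.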
@@ -29,6 +29,9 @@ class Handle { | |||
NAIVE = 0, | |||
FALLBACK = 1, | |||
X86 = 2, | |||
ARM_COMMON = 3, | |||
ARMV7 = 4, | |||
AARCH64 = 5, | |||
CUDA = 6, | |||
}; | |||
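// ARM_COMMON tags the kernels shared by both ARM backends; the build rules
// below compile arm_common/*.cpp for ARMV7 and AARCH64 alike.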
@@ -17,6 +17,22 @@ if(NOT ${MGE_ARCH} STREQUAL "naive") | |||
set_source_files_properties(${SOURCES_} PROPERTIES LANGUAGE C) | |||
list(APPEND SOURCES ${SOURCES_}) | |||
endif() | |||
elseif(${MGE_ARCH} STREQUAL "armv7") | |||
file(GLOB_RECURSE SOURCES_ armv7/*.cpp) | |||
list(APPEND SOURCES ${SOURCES_}) | |||
file(GLOB_RECURSE SOURCES_ arm_common/*.cpp) | |||
list(APPEND SOURCES ${SOURCES_}) | |||
file(GLOB_RECURSE SOURCES_ armv7/*.S) | |||
set_source_files_properties(${SOURCES_} PROPERTIES LANGUAGE C) | |||
list(APPEND SOURCES ${SOURCES_}) | |||
elseif(${MGE_ARCH} STREQUAL "aarch64") | |||
file(GLOB_RECURSE SOURCES_ aarch64/*.cpp) | |||
list(APPEND SOURCES ${SOURCES_}) | |||
file(GLOB_RECURSE SOURCES_ arm_common/*.cpp) | |||
list(APPEND SOURCES ${SOURCES_}) | |||
file(GLOB_RECURSE SOURCES_ aarch64/*.S) | |||
set_source_files_properties(${SOURCES_} PROPERTIES LANGUAGE C) | |||
list(APPEND SOURCES ${SOURCES_}) | |||
endif() | |||
endif() | |||
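# Note on the .S handling above: marking assembly sources as LANGUAGE C routes
# them through the C compiler driver, which preprocesses and assembles them.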
@@ -0,0 +1,139 @@ | |||
/** | |||
* \file dnn/src/aarch64/conv_bias/fp16/algos.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/aarch64/conv_bias/fp16/algos.h" | |||
#include "src/aarch64/conv_bias/fp16/stride2_kern.h" | |||
#include "src/arm_common/conv_bias/direct/multi_thread_common.h" | |||
#include "src/arm_common/conv_bias/postprocess_helper.h" | |||
using namespace megdnn; | |||
using namespace aarch64; | |||
#include "midout.h" | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
/* ===================== stride-2 algo ===================== */ | |||
MIDOUT_DECL(megdnn_aarch64_conv_bias_stride2_conv2357_fp16) | |||
bool ConvBiasImpl::AlgoF16DirectStride2::usable( | |||
FallbackConvBiasImpl*, const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const { | |||
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 0) { | |||
auto&& fm = param.filter_meta; | |||
auto FH = fm.spatial[0]; | |||
bool available =
param.filter_meta.format == param::Convolution::Format::NCHW && | |||
param.src_type.enumv() == DTypeEnum::Float16 && | |||
param.filter_type.enumv() == DTypeEnum::Float16 && | |||
param.dst_type.enumv() == DTypeEnum::Float16 && | |||
!fm.should_flip && fm.spatial_ndim == 2 && | |||
fm.dilation[0] == 1 && fm.dilation[1] == 1 && | |||
fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] && | |||
(FH == 2 || FH == 3 || FH == 5 || FH == 7); | |||
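//! The algo exists in a large-group and a small-group flavour; under
//! HEURISTIC selection only the flavour matching the actual shape
//! (group >= nr_threads or not) reports itself usable.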
if (algo_selection_strategy == AlgoSelectionStrategy::HEURISTIC) { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
aviliable &= (large_group == m_large_group); | |||
} | |||
return available;
} | |||
MIDOUT_END(); | |||
return false; | |||
} | |||
size_t ConvBiasImpl::AlgoF16DirectStride2::get_workspace( | |||
FallbackConvBiasImpl*, const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 1) { | |||
auto wbundle = arm_common::MultithreadDirectConvCommon< | |||
dt_float16, __fp16>::get_bundle_stride(param, m_large_group); | |||
return wbundle.total_size_in_bytes(); | |||
} | |||
MIDOUT_END(); | |||
return 0;
} | |||
SmallVector<ConvBiasImpl::NCBKern> | |||
ConvBiasImpl::AlgoF16DirectStride2::dispatch_kerns( | |||
FallbackConvBiasImpl*, const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 2) {
return get_kimpls(param); | |||
} | |||
MIDOUT_END(); | |||
return {}; | |||
} | |||
SmallVector<ConvBiasImpl::NCBKern> | |||
ConvBiasImpl::AlgoF16DirectStride2::get_kimpls( | |||
const NCBKernSizeParam& param) const { | |||
auto fm = param.filter_meta; | |||
auto FH = fm.spatial[0]; | |||
size_t N = param.n; | |||
size_t IC = param.filter_meta.icpg; | |||
size_t OC = param.filter_meta.ocpg; | |||
size_t group = fm.group; | |||
using Func = std::function<void(const __fp16*, const __fp16*, __fp16*, | |||
size_t, size_t, size_t, size_t, size_t)>; | |||
Func conv = nullptr; | |||
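// Select the hand-written stride-2 kernel for the (square) filter size
// already validated in usable().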
if (FH == 2) { | |||
conv = fp16::conv_stride2::do_conv_2x2_stride2; | |||
} else if (FH == 3) { | |||
conv = fp16::conv_stride2::do_conv_3x3_stride2; | |||
} else if (FH == 5) { | |||
conv = fp16::conv_stride2::do_conv_5x5_stride2; | |||
} else if (FH == 7) { | |||
conv = fp16::conv_stride2::do_conv_7x7_stride2; | |||
} | |||
WorkspaceBundle wbundle = arm_common::MultithreadDirectConvCommon< | |||
dt_float16, __fp16>::get_bundle_stride(param, m_large_group); | |||
SmallVector<NCBKern> ret_kerns; | |||
//! Large-group case: one kernel invocation handles an entire group (padding
//! for all its input channels, then conv for all its output channels), so
//! parallelism comes from the {group, N} dimensions.
if (m_large_group) {
auto exec_one_group = [wbundle, conv](const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index) { | |||
auto fm = kern_param.filter_meta; | |||
size_t IC = fm.icpg; | |||
size_t OC = fm.ocpg; | |||
WorkspaceBundle bundle = wbundle; | |||
for (size_t ic = 0; ic < IC; ic++) { | |||
arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>:: | |||
copy_padding_kern_stride(bundle, kern_param, ncb_index, | |||
{ncb_index.thread_id, 0, ic}); | |||
} | |||
for (size_t oc = 0; oc < OC; oc++) { | |||
arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>:: | |||
do_conv_kern_stride(bundle, kern_param, ncb_index, conv, | |||
{ncb_index.thread_id, 0, oc}); | |||
} | |||
}; | |||
ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); | |||
} else { | |||
WorkspaceBundle bundle = wbundle; | |||
auto copy_padding = [bundle](const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index) { | |||
arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>:: | |||
copy_padding_kern_stride(bundle, kern_param, ncb_index, | |||
ncb_index.ndrange_id); | |||
}; | |||
ret_kerns.push_back({copy_padding, {group, N, IC}}); | |||
auto do_conv = [bundle, conv](const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index) { | |||
arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>:: | |||
do_conv_kern_stride(bundle, kern_param, ncb_index, conv, | |||
ncb_index.ndrange_id); | |||
}; | |||
ret_kerns.push_back({do_conv, {group, N, OC}}); | |||
} | |||
return ret_kerns; | |||
} | |||
#endif | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,42 @@ | |||
/** | |||
* \file dnn/src/aarch64/conv_bias/fp16/algos.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "src/aarch64/conv_bias/opr_impl.h" | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
namespace megdnn { | |||
namespace aarch64 { | |||
/* ===================== stride-2 algo ===================== */ | |||
class ConvBiasImpl::AlgoF16DirectStride2 final : public AlgoBase { | |||
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | |||
bool m_large_group; | |||
public: | |||
AlgoF16DirectStride2(bool large_group) : m_large_group(large_group) {} | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { | |||
return m_large_group ? "ARMV8F16STRD2_LARGE_GROUP" | |||
: "ARMV8F16STRD2_SMALL_GROUP"; | |||
} | |||
bool usable(FallbackConvBiasImpl*, const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||
size_t get_workspace(FallbackConvBiasImpl*, | |||
const NCBKernSizeParam& param) const override; | |||
SmallVector<NCBKern> dispatch_kerns(FallbackConvBiasImpl*, | |||
const NCBKernSizeParam&) const override; | |||
}; | |||
} // namespace aarch64 | |||
} // namespace megdnn | |||
#endif | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,137 @@ | |||
/** | |||
* \file dnn/src/aarch64/conv_bias/fp32/algos.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/aarch64/conv_bias/fp32/algos.h" | |||
#include "src/aarch64/conv_bias/fp32/stride2_kern.h" | |||
#include "src/arm_common/conv_bias/direct/multi_thread_common.h" | |||
#include "src/arm_common/conv_bias/postprocess_helper.h" | |||
#include "src/fallback/conv_bias/common.h" | |||
#include "midout.h" | |||
using namespace megdnn; | |||
using namespace aarch64; | |||
MIDOUT_DECL(megdnn_aarch64_conv_bias_stride2_conv2357_fp32) | |||
bool ConvBiasImpl::AlgoF32DirectStride2::usable( | |||
FallbackConvBiasImpl*, const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const { | |||
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 0) { | |||
auto&& fm = param.filter_meta; | |||
auto FH = fm.spatial[0]; | |||
bool available =
param.filter_meta.format == param::ConvBias::Format::NCHW && | |||
param.src_type.enumv() == DTypeEnum::Float32 && | |||
param.filter_type.enumv() == DTypeEnum::Float32 && | |||
param.dst_type.enumv() == DTypeEnum::Float32 && | |||
!fm.should_flip && fm.spatial_ndim == 2 && | |||
fm.dilation[0] == 1 && fm.dilation[1] == 1 && | |||
fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] && | |||
(FH == 2 || FH == 3 || FH == 5 || FH == 7); | |||
if (algo_selection_strategy == AlgoSelectionStrategy::HEURISTIC) { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
aviliable &= (large_group == m_large_group); | |||
} | |||
return available;
} | |||
MIDOUT_END(); | |||
return false; | |||
} | |||
size_t ConvBiasImpl::AlgoF32DirectStride2::get_workspace( | |||
FallbackConvBiasImpl*, const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 1) { | |||
auto wbundle = arm_common::MultithreadDirectConvCommon< | |||
float, float>::get_bundle_stride(param, m_large_group); | |||
return wbundle.total_size_in_bytes(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
SmallVector<ConvBiasImpl::NCBKern> | |||
ConvBiasImpl::AlgoF32DirectStride2::dispatch_kerns( | |||
FallbackConvBiasImpl*, const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 2) { | |||
return get_kimpls(param); | |||
} | |||
MIDOUT_END(); | |||
return {}; | |||
} | |||
SmallVector<ConvBiasImpl::NCBKern> | |||
ConvBiasImpl::AlgoF32DirectStride2::get_kimpls( | |||
const NCBKernSizeParam& param) const { | |||
auto fm = param.filter_meta; | |||
auto FH = fm.spatial[0]; | |||
size_t N = param.n; | |||
size_t IC = param.filter_meta.icpg; | |||
size_t OC = param.filter_meta.ocpg; | |||
size_t group = fm.group; | |||
using Func = std::function<void(const float*, const float*, float*, size_t, | |||
size_t, size_t, size_t, size_t)>; | |||
Func conv = nullptr; | |||
if (FH == 2) { | |||
conv = fp32::conv_stride2::do_conv_2x2_stride2; | |||
} else if (FH == 3) { | |||
conv = fp32::conv_stride2::do_conv_3x3_stride2; | |||
} else if (FH == 5) { | |||
conv = fp32::conv_stride2::do_conv_5x5_stride2; | |||
} else if (FH == 7) { | |||
conv = fp32::conv_stride2::do_conv_7x7_stride2; | |||
} | |||
WorkspaceBundle wbundle = arm_common::MultithreadDirectConvCommon< | |||
float, float>::get_bundle_stride(param, m_large_group); | |||
SmallVector<NCBKern> ret_kerns; | |||
//! Large-group case: one kernel invocation handles an entire group (padding
//! for all its input channels, then conv for all its output channels), so
//! parallelism comes from the {group, N} dimensions.
if (m_large_group) {
auto exec_one_group = [wbundle, conv](const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index) { | |||
auto fm = kern_param.filter_meta; | |||
size_t IC = fm.icpg; | |||
size_t OC = fm.ocpg; | |||
WorkspaceBundle bundle = wbundle; | |||
for (size_t ic = 0; ic < IC; ic++) { | |||
arm_common::MultithreadDirectConvCommon<float, float>:: | |||
copy_padding_kern_stride(bundle, kern_param, ncb_index, | |||
{ncb_index.thread_id, 0, ic}); | |||
} | |||
for (size_t oc = 0; oc < OC; oc++) { | |||
arm_common::MultithreadDirectConvCommon< | |||
float, float>::do_conv_kern_stride(bundle, kern_param, | |||
ncb_index, conv, | |||
{ncb_index.thread_id, | |||
0, oc}); | |||
} | |||
}; | |||
ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); | |||
} else { | |||
WorkspaceBundle bundle = wbundle; | |||
auto copy_padding = [bundle](const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index) { | |||
arm_common::MultithreadDirectConvCommon<float, float>:: | |||
copy_padding_kern_stride(bundle, kern_param, ncb_index, | |||
ncb_index.ndrange_id); | |||
}; | |||
ret_kerns.push_back({copy_padding, {group, N, IC}}); | |||
auto do_conv = [bundle, conv](const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index) { | |||
arm_common::MultithreadDirectConvCommon< | |||
float, float>::do_conv_kern_stride(bundle, kern_param, | |||
ncb_index, conv, | |||
ncb_index.ndrange_id); | |||
}; | |||
ret_kerns.push_back({do_conv, {group, N, OC}}); | |||
} | |||
return ret_kerns; | |||
} |
@@ -0,0 +1,47 @@ | |||
/** | |||
* \file dnn/src/aarch64/conv_bias/fp32/algos.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "src/aarch64/conv_bias/opr_impl.h" | |||
#include "src/fallback/conv_bias/opr_impl.h" | |||
namespace megdnn { | |||
namespace aarch64 { | |||
using FallbackConvBiasImpl = fallback::ConvBiasImpl; | |||
/* ===================== stride-2 algo ===================== */ | |||
class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase { | |||
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | |||
bool m_large_group; | |||
public: | |||
AlgoF32DirectStride2(bool large_group) : m_large_group(large_group) {} | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { | |||
return m_large_group ? "ARMV8F32STRD2_LARGE_GROUP" | |||
: "ARMV8F32STRD2_SMALL_GROUP"; | |||
} | |||
bool usable(FallbackConvBiasImpl*, const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||
size_t get_workspace(FallbackConvBiasImpl*, | |||
const NCBKernSizeParam& param) const override; | |||
SmallVector<NCBKern> dispatch_kerns(FallbackConvBiasImpl*, | |||
const NCBKernSizeParam&) const override; | |||
}; | |||
} // namespace aarch64 | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,187 @@ | |||
/** | |||
* \file dnn/src/aarch64/conv_bias/int8/algos.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/aarch64/conv_bias/int8/algos.h" | |||
#include "src/aarch64/conv_bias/int8/strategy.h" | |||
#include "src/arm_common/convolution/img2col_helper.h" | |||
#include "src/arm_common/elemwise_op.h" | |||
#include "src/common/opr_delegate.h" | |||
#include "src/fallback/conv_bias/common.h" | |||
#include "src/fallback/matrix_mul/gemm_impl.h" | |||
#include "midout.h" | |||
MIDOUT_DECL(megdnn_aarch64_conv_bias_int8_gemm) | |||
using namespace megdnn; | |||
using namespace aarch64; | |||
using megdnn::arm_common::HSwishOp; | |||
using megdnn::arm_common::ReluOp; | |||
using megdnn::arm_common::TypeCvtOp; | |||
/* ===================== matrix mul algo ===================== */ | |||
bool ConvBiasImpl::AlgoS8MatrixMul::usable( | |||
FallbackConvBiasImpl* opr, const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy /*algo_selection_strategy*/) const { | |||
MEGDNN_MARK_USED_VAR(opr); | |||
auto&& fm = param.filter_meta; | |||
return param.src_type.enumv() == DTypeEnum::QuantizedS8 && | |||
param.dst_type.enumv() == DTypeEnum::QuantizedS8 && | |||
fm.format == param::ConvBias::Format::NCHW && fm.spatial_ndim == 2 && | |||
fm.dilation[0] == 1 && fm.dilation[1] == 1 && | |||
//! In the postprocess a full BIAS tensor is not read contiguously,
//! which hurts performance, so it is not handled in the fused kernel
param.bias_mode != BiasMode::BIAS &&
//! This algo only supports a single thread
param.nr_threads == 1_z; | |||
} | |||
WorkspaceBundle ConvBiasImpl::AlgoS8MatrixMul::get_bundle( | |||
const NCBKernSizeParam& param) { | |||
UNPACK_CONV_NCB_KERN_SIZES(param); | |||
MEGDNN_MARK_USED_VAR(N); | |||
auto IH2 = IH + 2 * PH;
auto IW2 = IW + 2 * PW;
bool can_matrix_mul_direct = | |||
(FH == 1 && FW == 1 && SH == 1 && SW == 1 && PH == 0 && PW == 0); | |||
// temp space to store padding-free src (with 16 extra int8) | |||
// temp space to store unrolled matrix (with 16 extra int8) | |||
// workspace for matrix mul opr | |||
size_t part0, part1, part2; | |||
if (can_matrix_mul_direct) { | |||
part0 = part1 = 0; | |||
} else { | |||
part0 = (IC * IH2 * IW2 + 16) * sizeof(int8_t); | |||
part1 = (IC * FH * FW * OH * OW + 16) * sizeof(int8_t); | |||
} | |||
{ | |||
size_t M = OC; | |||
size_t K = IC * FH * FW; | |||
size_t N = OH * OW; | |||
#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ | |||
_bias_midout_enum, _nonline, \ | |||
_nonline_midout_enum) \ | |||
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_int8_gemm, 0, _gemm_midout_enum, \ | |||
_bias_midout_enum, _nonline_midout_enum) { \ | |||
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | |||
M, N, K, param.filter_type, param.src_type, param.dst_type); \ | |||
part2 = megdnn::matmul::GemmInterleaved< \ | |||
matmul::gemm_##_gemm##_##_bias##_##_nonline>( \ | |||
M, N, K, false, false, strategy) \ | |||
.get_workspace_size(); \ | |||
} \ | |||
MIDOUT_END() | |||
#if !(__ARM_FEATURE_DOTPROD) | |||
DISPATCH_GEMM_BIAS(s8_4x4, 0) | |||
#else | |||
DISPATCH_GEMM_BIAS(s8_8x12, 1) | |||
#endif | |||
#undef DISPATCH_GEMM_STRATEGY | |||
} | |||
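// A bundle built on a null base pointer only records sizes; kimpl() later
// binds it to real memory via bundle.set(param.workspace_ptr).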
return {nullptr, {part0, part1, part2}}; | |||
} | |||
void ConvBiasImpl::AlgoS8MatrixMul::kimpl(const NCBKernParam& param, | |||
const NCBKernIndex& ncb_index) { | |||
auto is_xcorr = !param.filter_meta.should_flip; | |||
UNPACK_CONV_NCB_KERN_SIZES(param); | |||
auto bundle = get_bundle(param); | |||
bundle.set(param.workspace_ptr); | |||
auto IH2 = IH + 2 * PH; | |||
auto IW2 = IW + 2 * PW; | |||
size_t group_id = ncb_index.ndrange_id[0]; | |||
// workspace layout: part0 = padded src, part1 = im2col matrix, part2 = gemm workspace
for (size_t n = 0; n < N; ++n) { | |||
dt_int8* src = const_cast<dt_int8*>(param.src<dt_int8>(n, group_id)); | |||
dt_int8* filter = const_cast<dt_int8*>(param.filter<dt_int8>(group_id)); | |||
dt_int8* dst = static_cast<dt_int8*>(param.dst<dt_int8>(n, group_id)); | |||
dt_int32* bias = const_cast<dt_int32*>(param.bias<dt_int32>(n, group_id)); | |||
dt_int8 *B, *src2; | |||
if (FH == 1 && FW == 1 && SH == 1 && SW == 1 && PH == 0 && PW == 0) { | |||
// special case: 1x1 | |||
B = const_cast<dt_int8*>(src); | |||
} else { | |||
src2 = static_cast<dt_int8*>(bundle.get(0)); | |||
// copy src to src2; | |||
dt_int8* src2_ptr = src2; | |||
const dt_int8* src_ptr = src; | |||
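// Build a zero-padded (IH2 x IW2) copy per input channel: PH blank rows
// above and below, PW blank columns around each row, so the img2col pass
// below never needs bounds checks.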
rep(ic, IC) { | |||
if (PH != 0) { | |||
std::memset(src2_ptr, 0, sizeof(dt_int8) * PH * IW2); | |||
src2_ptr += PH * IW2; | |||
} | |||
rep(ih, IH) { | |||
if (PW != 0) | |||
rep(pw, PW) { *(src2_ptr++) = 0; }
std::memcpy(src2_ptr, src_ptr, sizeof(dt_int8) * IW); | |||
src2_ptr += IW; | |||
src_ptr += IW; | |||
if (PW != 0) | |||
rep(pw, PW) { *(src2_ptr++) = 0; }
} | |||
if (PH != 0) { | |||
std::memset(src2_ptr, 0, sizeof(dt_int8) * PH * IW2); | |||
src2_ptr += PH * IW2; | |||
} | |||
} | |||
B = static_cast<dt_int8*>(bundle.get(1)); | |||
if (SH == 1 && SW == 1) { | |||
if (is_xcorr) | |||
img2col<true>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW); | |||
else | |||
img2col<false>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW); | |||
} else { | |||
if (is_xcorr) | |||
img2col_stride<true>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, | |||
FW, SH, SW); | |||
else | |||
img2col_stride<false>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, | |||
FW, SH, SW); | |||
} | |||
} | |||
{ | |||
Workspace workspace(static_cast<dt_byte*>(bundle.get(2)), | |||
bundle.get_size(2)); | |||
size_t M = OC; | |||
size_t K = IC * FH * FW; | |||
size_t N = OH * OW; | |||
#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ | |||
_bias_midout_enum, _nonline, \ | |||
_nonline_midout_enum) \ | |||
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_int8_gemm, 1, _gemm_midout_enum, \ | |||
_bias_midout_enum, _nonline_midout_enum) { \ | |||
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | |||
M, N, K, param.filter_type, param.src_type, param.dst_type); \ | |||
megdnn::matmul::GemmInterleaved< \ | |||
matmul::gemm_##_gemm##_##_bias##_##_nonline> \ | |||
gemm_interleaved(M, N, K, false, false, strategy); \ | |||
gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, \ | |||
bias); \ | |||
} \ | |||
MIDOUT_END() | |||
#if !(__ARM_FEATURE_DOTPROD) | |||
DISPATCH_GEMM_BIAS(s8_4x4, 0) | |||
#else | |||
DISPATCH_GEMM_BIAS(s8_8x12, 1) | |||
#endif | |||
#undef DISPATCH_GEMM_STRATEGY | |||
} | |||
} | |||
} | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,52 @@ | |||
/** | |||
* \file dnn/src/aarch64/conv_bias/int8/algos.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "src/aarch64/conv_bias/opr_impl.h" | |||
#include "src/fallback/conv_bias/opr_impl.h" | |||
namespace megdnn { | |||
namespace aarch64 { | |||
using FallbackConvBiasImpl = fallback::ConvBiasImpl; | |||
class ConvBiasImpl::AlgoS8MatrixMul final : public AlgoBase { | |||
static WorkspaceBundle get_bundle(const NCBKernSizeParam& param); | |||
static void kimpl(const NCBKernParam& param, const NCBKernIndex& ncb_index); | |||
public: | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { return "S8MATMUL"; } | |||
bool usable(FallbackConvBiasImpl* opr, const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||
size_t get_workspace(FallbackConvBiasImpl*, | |||
const NCBKernSizeParam& param) const override { | |||
return get_bundle(param).total_size_in_bytes(); | |||
} | |||
SmallVector<NCBKern> dispatch_kerns( | |||
FallbackConvBiasImpl*, const NCBKernSizeParam& param) const override { | |||
size_t group = param.filter_meta.group; | |||
return {{kimpl, {group, 1_z, 1_z}}}; | |||
} | |||
//! give matmul the highest preference when the quantized-matmul heuristic prefers it
bool is_preferred(FallbackConvBiasImpl* opr, | |||
const NCBKernSizeParam& param) const override { | |||
return static_cast<arm_common::ConvBiasImpl*>(opr) | |||
->is_matmul_quantized_prefer(param); | |||
} | |||
}; | |||
} // namespace aarch64 | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,309 @@ | |||
/** | |||
* \file dnn/src/aarch64/conv_bias/int8/strategy.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/aarch64/conv_bias/int8/strategy.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/common/utils.h" | |||
#include "src/fallback/conv_bias/common.h" | |||
#include "src/aarch64/matrix_mul/int8/kernel_4x4x16.h" | |||
#include "src/aarch64/matrix_mul/int8_dot/kernel_8x12x4.h" | |||
#include "src/arm_common/conv_bias/matmul_postprocess.h" | |||
using namespace megdnn; | |||
using namespace aarch64; | |||
using namespace aarch64::matmul; | |||
namespace impl { | |||
template <BiasMode bmode, typename Op, int block_m, int block_n> | |||
struct KernCaller; | |||
#if __ARM_FEATURE_DOTPROD | |||
template <BiasMode bmode, typename Op> | |||
struct KernCaller<bmode, Op, 8, 12> { | |||
static void run(const dt_int8* packA, const dt_int8* packB, size_t M, | |||
size_t N, size_t K, dt_int8* C, size_t LDC, bool is_first_k, | |||
Op op, const dt_int32* bias, dt_int32* workspace) { | |||
megdnn_assert(is_first_k); | |||
constexpr size_t A_INTERLEAVE = 8; | |||
constexpr size_t B_INTERLEAVE = 12; | |||
//! K is padded to a multiple of 4
K = round_up<size_t>(K, 4); | |||
const int K8 = (K << 3); | |||
const int K12 = K * 12; | |||
const int K4 = K * 4; | |||
size_t m = 0; | |||
for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) { | |||
int8_t* output = C + (m * LDC); | |||
size_t n = 0; | |||
const dt_int8* cur_packB = packB; | |||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
matmul_8x12x4::kern_8x12(packA, cur_packB, K, workspace, 12, | |||
is_first_k); | |||
arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 8, 12, 8, | |||
12>::postprocess(bias, workspace, | |||
output, LDC, op); | |||
output += B_INTERLEAVE; | |||
cur_packB += K12; | |||
} | |||
for (; n < N; n += 4) { | |||
matmul_8x12x4::kern_8x4(packA, cur_packB, K, workspace, 4, | |||
is_first_k, std::min<size_t>(N - n, 4)); | |||
#define cb(m, n) \ | |||
arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 8, 4, 8, n>::postprocess( \ | |||
bias, workspace, output, LDC, op); | |||
DISPATCH_N(cb, 8, std::min<size_t>(N - n, 4)); | |||
#undef cb | |||
output += 4; | |||
cur_packB += K4; | |||
} | |||
packA += K8; | |||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
bias += A_INTERLEAVE; | |||
} | |||
} | |||
for (; m < M; m += 4) { | |||
int8_t* output = C + (m * LDC); | |||
const dt_int8* cur_packB = packB; | |||
size_t n = 0; | |||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
matmul_8x12x4::kern_4x12(packA, cur_packB, K, workspace, 12, | |||
is_first_k, | |||
std::min<size_t>(M - m, 4)); | |||
#define cb(m, n) \ | |||
arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 12, m, n>::postprocess( \ | |||
bias, workspace, output, LDC, op); | |||
DISPATCH_M_N(cb, std::min<size_t>(M - m, 4), 12); | |||
#undef cb | |||
output += B_INTERLEAVE; | |||
cur_packB += K12; | |||
} | |||
for (; n < N; n += 4) { | |||
matmul_8x12x4::kern_4x4(packA, cur_packB, K, workspace, 4, | |||
is_first_k, std::min<size_t>(M - m, 4), | |||
std::min<size_t>(N - n, 4)); | |||
#define cb(m, n) \ | |||
arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, m, n>::postprocess( \ | |||
bias, workspace, output, LDC, op); | |||
DISPATCH_M(cb, std::min<size_t>(M - m, 4), | |||
std::min<size_t>(N - n, 4)); | |||
#undef cb | |||
output += 4; | |||
cur_packB += K4; | |||
} | |||
packA += K4; | |||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
bias += 4; | |||
} | |||
} | |||
} | |||
}; | |||
#else | |||
template <BiasMode bmode, typename Op> | |||
struct KernCaller<bmode, Op, 4, 4> { | |||
static void run(const dt_int8* packA, const dt_int8* packB, size_t M, | |||
size_t N, size_t K, dt_int8* C, size_t LDC, bool is_first_k, | |||
Op op, const dt_int32* bias, dt_int32* workspace) { | |||
megdnn_assert(is_first_k); | |||
constexpr size_t A_INTERLEAVE = 4; | |||
constexpr size_t B_INTERLEAVE = 4; | |||
//! K is padded to a multiple of 16 (the 4x4x16 kernel's K block)
K = round_up<size_t>(K, 16); | |||
const int K4 = K * 4; | |||
size_t m = 0; | |||
for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) { | |||
int8_t* output = C + (m * LDC); | |||
size_t n = 0; | |||
const dt_int8* cur_packB = packB; | |||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
matmul_4x4x16::kern_4x4(packA, cur_packB, K, workspace, 4, | |||
is_first_k); | |||
arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, 4, | |||
4>::postprocess(bias, workspace, | |||
output, LDC, op); | |||
output += B_INTERLEAVE; | |||
cur_packB += K4; | |||
} | |||
for (; n < N; n += B_INTERLEAVE) { | |||
matmul_4x4x16::kern_4x4_remain(packA, cur_packB, K, workspace, | |||
4, is_first_k, 4, | |||
std::min<size_t>(N - n, 4)); | |||
#define cb(m, n) \ | |||
arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, 4, n>::postprocess( \ | |||
bias, workspace, output, LDC, op); | |||
DISPATCH_N(cb, 4, std::min<size_t>(N - n, 4)); | |||
#undef cb | |||
output += B_INTERLEAVE; | |||
cur_packB += K4; | |||
} | |||
packA += K4; | |||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
bias += A_INTERLEAVE; | |||
} | |||
} | |||
for (; m < M; m += A_INTERLEAVE) { | |||
int8_t* output = C + (m * LDC); | |||
size_t n = 0; | |||
const dt_int8* cur_packB = packB; | |||
for (; n < N; n += B_INTERLEAVE) { | |||
matmul_4x4x16::kern_4x4_remain( | |||
packA, cur_packB, K, workspace, 4, is_first_k, | |||
std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4)); | |||
#define cb(m, n) \ | |||
arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, m, n>::postprocess( \ | |||
bias, workspace, output, LDC, op); | |||
DISPATCH_M(cb, std::min<size_t>(M - m, 4), | |||
std::min<size_t>(N - n, 4)); | |||
#undef cb | |||
output += B_INTERLEAVE; | |||
cur_packB += K4; | |||
} | |||
packA += K4; | |||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
bias += A_INTERLEAVE; | |||
} | |||
} | |||
} | |||
}; | |||
#endif | |||
} // namespace impl | |||
#if !(__ARM_FEATURE_DOTPROD) | |||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_4x4_nobias_identity) | |||
void gemm_s8_4x4_nobias_identity::pack_A(dt_int8* outptr, const dt_int8* inptr, | |||
int ldin, int y0, int ymax, int k0, | |||
int kmax, bool transpose) const { | |||
if (transpose) { | |||
matmul_4x4x16::gemm_s8_4x4_pack_B_n(outptr, inptr, ldin, y0, ymax, k0, | |||
kmax); | |||
} else { | |||
matmul_4x4x16::gemm_s8_4x4_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, | |||
kmax); | |||
} | |||
} | |||
void gemm_s8_4x4_nobias_identity::pack_B(dt_int8* out, const dt_int8* in, | |||
int ldin, int x0, int xmax, int k0, | |||
int kmax, bool transpose) const { | |||
if (transpose) { | |||
matmul_4x4x16::gemm_s8_4x4_pack_A_n(out, in, ldin, x0, xmax, k0, kmax); | |||
} else { | |||
matmul_4x4x16::gemm_s8_4x4_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); | |||
} | |||
} | |||
size_t gemm_s8_4x4_nobias_identity::get_workspace_size() const { | |||
return 4 * 4 * sizeof(dt_int32); | |||
} | |||
#else | |||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_8x12_nobias_identity) | |||
void gemm_s8_8x12_nobias_identity::pack_A(dt_int8* outptr, const dt_int8* inptr, | |||
int ldin, int y0, int ymax, int k0, | |||
int kmax, bool transpose) const { | |||
MEGDNN_MARK_USED_VAR(matmul_8x12x4::gemm_s8_8x12_pack_A_t); | |||
MEGDNN_MARK_USED_VAR(matmul_8x12x4::gemm_s8_8x12_pack_B_t); | |||
if (transpose) { | |||
matmul_8x12x4::gemm_s8_8x12_pack_B_n(outptr, inptr, ldin, y0, ymax, k0, | |||
kmax); | |||
} else { | |||
matmul_8x12x4::gemm_s8_8x12_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, | |||
kmax); | |||
} | |||
} | |||
void gemm_s8_8x12_nobias_identity::pack_B(dt_int8* out, const dt_int8* in, | |||
int ldin, int x0, int xmax, int k0, | |||
int kmax, bool transpose) const { | |||
if (transpose) { | |||
matmul_8x12x4::gemm_s8_8x12_pack_A_n(out, in, ldin, x0, xmax, k0, kmax); | |||
} else { | |||
matmul_8x12x4::gemm_s8_8x12_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); | |||
} | |||
} | |||
size_t gemm_s8_8x12_nobias_identity::get_workspace_size() const { | |||
return 8 * 12 * sizeof(dt_int32); | |||
} | |||
#endif | |||
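// KERN stamps out one kern() per (bias-mode, nonlinearity) pair. The int32
// accumulator carries scale scale_A * scale_B; DEFINE_OP builds the epilogue
// op that requantizes it (and, in the bias variants, the channel bias, which
// shares the accumulator scale) to the output scale scale_C.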
#define KERN(_block_m, _block_n, _bias, _BIAS, _nonline, _OP) \ | |||
void gemm_s8_##_block_m##x##_block_n##_##_bias##_##_nonline::kern( \ | |||
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, \ | |||
size_t K, dt_int8* C, size_t LDC, bool is_first_k, \ | |||
const dt_int32* bias, dt_int32* workspace) const { \ | |||
float scale_A = A_dtype.param<dtype::QuantizedS8>().scale; \ | |||
float scale_B = B_dtype.param<dtype::QuantizedS8>().scale; \ | |||
float scale_C = C_dtype.param<dtype::QuantizedS8>().scale; \ | |||
DEFINE_OP(_OP); \ | |||
impl::KernCaller<_BIAS, decltype(op), _block_m, _block_n>::run( \ | |||
packA, packB, M, N, K, C, LDC, is_first_k, op, bias, \ | |||
workspace); \ | |||
} | |||
#define DEFINE_OP(_Op) \ | |||
arm_common::_Op<dt_qint32, dt_qint8> op(scale_A* scale_B, scale_C); | |||
#if !(__ARM_FEATURE_DOTPROD) | |||
KERN(4, 4, nobias, BiasMode::NO_BIAS, identity, TypeCvtOp) | |||
KERN(4, 4, nobias, BiasMode::NO_BIAS, relu, ReluOp) | |||
KERN(4, 4, nobias, BiasMode::NO_BIAS, hswish, HSwishOp) | |||
#else | |||
KERN(8, 12, nobias, BiasMode::NO_BIAS, identity, TypeCvtOp) | |||
KERN(8, 12, nobias, BiasMode::NO_BIAS, relu, ReluOp) | |||
KERN(8, 12, nobias, BiasMode::NO_BIAS, hswish, HSwishOp) | |||
#endif | |||
#undef DEFINE_OP | |||
#define DEFINE_OP(_Op) \ | |||
arm_common::_Op<dt_qint32, dt_qint8> op(scale_A* scale_B, \ | |||
scale_A* scale_B, scale_C); | |||
#if !(__ARM_FEATURE_DOTPROD) | |||
KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp) | |||
KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp) | |||
KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, | |||
FuseAddHSwishOp) | |||
#else | |||
KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp) | |||
KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp) | |||
KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, | |||
FuseAddHSwishOp) | |||
#endif | |||
#undef DEFINE_OP | |||
#undef KERN | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,69 @@ | |||
/** | |||
* \file dnn/src/aarch64/conv_bias/int8/strategy.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "src/fallback/matrix_mul/gemm_common.h" | |||
namespace megdnn { | |||
namespace aarch64 { | |||
namespace matmul { | |||
#if !(__ARM_FEATURE_DOTPROD) | |||
/** | |||
* \brief base strategy of gemm. | |||
* | |||
* \name gemm_<type>_<block>_biasmode_nolinemode | |||
*/ | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_int8, dt_int8, dt_int32, 4, 4, 16, | |||
false, true, | |||
gemm_s8_4x4_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_nobias_relu, | |||
gemm_s8_4x4_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_nobias_hswish, | |||
gemm_s8_4x4_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_bias_channel_identity, | |||
gemm_s8_4x4_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_bias_channel_relu, | |||
gemm_s8_4x4_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_bias_channel_hswish, | |||
gemm_s8_4x4_nobias_identity); | |||
#else | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_int8, dt_int8, dt_int32, 8, 12, 4, | |||
false, true, | |||
gemm_s8_8x12_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_nobias_relu, | |||
gemm_s8_8x12_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_nobias_hswish, | |||
gemm_s8_8x12_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_bias_channel_identity, | |||
gemm_s8_8x12_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_bias_channel_relu, | |||
gemm_s8_8x12_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_bias_channel_hswish, | |||
gemm_s8_8x12_nobias_identity); | |||
#endif | |||
} // namespace matmul | |||
} // namespace aarch64 | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,69 @@ | |||
/** | |||
* \file dnn/src/aarch64/conv_bias/opr_impl.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/aarch64/conv_bias/opr_impl.h" | |||
#include "src/aarch64/conv_bias/int8/algos.h" | |||
#include "src/aarch64/conv_bias/quint8/algos.h" | |||
#include "src/naive/handle.h" | |||
#include "src/common/utils.h" | |||
#include "src/common/metahelper.h" | |||
#include "src/fallback/convolution/opr_impl.h" | |||
#include "src/aarch64/conv_bias/fp32/algos.h" | |||
#include "src/aarch64/conv_bias/fp16/algos.h" | |||
using namespace megdnn; | |||
using namespace aarch64; | |||
class ConvBiasImpl::AlgoPack : NonCopyableObj { | |||
AlgoF32DirectStride2 f32_direct_stride2_large_group{true}; | |||
AlgoF32DirectStride2 f32_direct_stride2_small_group{false}; | |||
AlgoS8MatrixMul s8_matrix_mul; | |||
AlgoQU8MatrixMul qu8_matrix_mul; | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
AlgoF16DirectStride2 f16_direct_stride2_large_group{true}; | |||
AlgoF16DirectStride2 f16_direct_stride2_small_group{false}; | |||
#endif | |||
public: | |||
AlgoPack() { | |||
matmul_algos.emplace_back(&qu8_matrix_mul); | |||
matmul_algos.emplace_back(&s8_matrix_mul); | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
direct_algos.emplace_back(&f16_direct_stride2_large_group); | |||
direct_algos.emplace_back(&f16_direct_stride2_small_group); | |||
#endif | |||
direct_algos.emplace_back(&f32_direct_stride2_large_group); | |||
direct_algos.emplace_back(&f32_direct_stride2_small_group); | |||
} | |||
SmallVector<AlgoBase*> direct_algos; | |||
SmallVector<AlgoBase*> matmul_algos; | |||
}; | |||
SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() { | |||
static AlgoPack sl_algo_pack; | |||
auto&& algos = arm_common::ConvBiasImpl::algo_pack(); | |||
algos.insert(algos.begin(), sl_algo_pack.direct_algos.begin(), | |||
sl_algo_pack.direct_algos.end()); | |||
//! Matmul algos go at the end: a matmul algo takes precedence whenever
//! its is_preferred() returns true. See
//! fallback::ConvolutionImpl::ncb_1g_get_all_algorithms for details.
algos.insert(algos.end(), sl_algo_pack.matmul_algos.begin(), | |||
sl_algo_pack.matmul_algos.end()); | |||
return std::move(algos); | |||
} | |||
const char* ConvBiasImpl::get_algorithm_set_name() const { | |||
return "AARCH64"; | |||
} | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,41 @@ | |||
/** | |||
* \file dnn/src/aarch64/conv_bias/opr_impl.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "src/common/utils.h" | |||
#include "src/arm_common/conv_bias/opr_impl.h" | |||
namespace megdnn { | |||
namespace aarch64 { | |||
class ConvBiasImpl : public arm_common::ConvBiasImpl { | |||
public: | |||
using arm_common::ConvBiasImpl::ConvBiasImpl; | |||
SmallVector<AlgoBase*> algo_pack() override; | |||
protected: | |||
const char* get_algorithm_set_name() const override; | |||
private: | |||
class AlgoF32DirectStride2; | |||
class AlgoS8MatrixMul; | |||
class AlgoQU8MatrixMul; | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
class AlgoF16DirectStride2; | |||
#endif | |||
class AlgoPack; | |||
}; | |||
} // namespace aarch64 | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,181 @@ | |||
/** | |||
* \file dnn/src/aarch64/conv_bias/quint8/algos.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/aarch64/conv_bias/quint8/algos.h" | |||
#include "src/aarch64/conv_bias/quint8/strategy.h" | |||
#include "src/aarch64/matrix_mul/quint8_dot/gemv.h" | |||
#include "src/aarch64/matrix_mul/quint8_dot/strategy.h" | |||
#include "src/arm_common/convolution/img2col_helper.h" | |||
#include "src/arm_common/elemwise_op.h" | |||
#include "src/common/opr_delegate.h" | |||
#include "src/fallback/conv_bias/common.h" | |||
#include "src/fallback/matrix_mul/gemm_impl.h" | |||
#include "midout.h" | |||
MIDOUT_DECL(megdnn_aarch64_conv_bias_quint8_gemm) | |||
using namespace megdnn; | |||
using namespace aarch64; | |||
using megdnn::arm_common::HSwishOp; | |||
using megdnn::arm_common::ReluOp; | |||
using megdnn::arm_common::TypeCvtOp; | |||
/* ===================== matrix mul algo ===================== */ | |||
bool ConvBiasImpl::AlgoQU8MatrixMul::usable( | |||
FallbackConvBiasImpl* opr, const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy /*algo_selection_strategy*/) const { | |||
MEGDNN_MARK_USED_VAR(opr); | |||
auto&& fm = param.filter_meta; | |||
return param.src_type.enumv() == DTypeEnum::Quantized8Asymm && | |||
param.dst_type.enumv() == DTypeEnum::Quantized8Asymm && | |||
fm.format == param::ConvBias::Format::NCHW && fm.spatial_ndim == 2 && | |||
fm.dilation[0] == 1 && fm.dilation[1] == 1 && | |||
//! In the postprocess a full BIAS tensor is not read contiguously,
//! which hurts performance, so it is not handled in the fused kernel
param.bias_mode != BiasMode::BIAS &&
//! This algo only supports a single thread
param.nr_threads == 1_z; | |||
} | |||
WorkspaceBundle ConvBiasImpl::AlgoQU8MatrixMul::get_bundle( | |||
const NCBKernSizeParam& param) { | |||
UNPACK_CONV_NCB_KERN_SIZES(param); | |||
MEGDNN_MARK_USED_VAR(N); | |||
auto IH2 = IH + 2 * PH;
auto IW2 = IW + 2 * PW;
bool can_matrix_mul_direct = | |||
(FH == 1 && FW == 1 && SH == 1 && SW == 1 && PH == 0 && PW == 0); | |||
// temp space to store padding-free src (with 16 extra int8) | |||
// temp space to store unrolled matrix (with 16 extra int8) | |||
// workspace for matrix mul opr | |||
size_t part0, part1, part2; | |||
if (can_matrix_mul_direct) { | |||
part0 = part1 = 0; | |||
} else { | |||
part0 = (IC * IH2 * IW2 + 16) * sizeof(uint8_t); | |||
part1 = (IC * FH * FW * OH * OW + 16) * sizeof(uint8_t); | |||
} | |||
{ | |||
size_t M = OC; | |||
size_t K = IC * FH * FW; | |||
size_t N = OH * OW; | |||
#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ | |||
_bias_midout_enum, _nonline, \ | |||
_nonline_midout_enum) \ | |||
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_quint8_gemm, 0, _gemm_midout_enum, \ | |||
_bias_midout_enum, _nonline_midout_enum) { \ | |||
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | |||
M, N, K, param.filter_type, param.src_type, param.dst_type); \ | |||
part2 = megdnn::matmul::GemmInterleaved< \ | |||
matmul::gemm_##_gemm##_##_bias##_##_nonline>( \ | |||
M, N, K, false, false, strategy) \ | |||
.get_workspace_size(); \ | |||
} \ | |||
MIDOUT_END() | |||
DISPATCH_GEMM_BIAS(u8_8x8, 0) | |||
#undef DISPATCH_GEMM_STRATEGY | |||
} | |||
return {nullptr, {part0, part1, part2}}; | |||
} | |||
void ConvBiasImpl::AlgoQU8MatrixMul::kimpl(const NCBKernParam& param, | |||
const NCBKernIndex& ncb_index) { | |||
auto is_xcorr = !param.filter_meta.should_flip; | |||
UNPACK_CONV_NCB_KERN_SIZES(param); | |||
auto bundle = get_bundle(param); | |||
bundle.set(param.workspace_ptr); | |||
auto IH2 = IH + 2 * PH; | |||
auto IW2 = IW + 2 * PW; | |||
size_t group_id = ncb_index.ndrange_id[0]; | |||
uint8_t src_zp = param.src_type.param<dtype::Quantized8Asymm>().zero_point; | |||
// workspace layout: part0 = padded src, part1 = im2col matrix, part2 = gemm workspace
for (size_t n = 0; n < N; ++n) { | |||
uint8_t* src = const_cast<uint8_t*>(param.src<uint8_t>(n, group_id)); | |||
uint8_t* filter = const_cast<uint8_t*>(param.filter<uint8_t>(group_id)); | |||
uint8_t* dst = static_cast<uint8_t*>(param.dst<uint8_t>(n, group_id)); | |||
int32_t* bias = const_cast<int32_t*>(param.bias<int32_t>(n, group_id)); | |||
uint8_t *B, *src2; | |||
if (FH == 1 && FW == 1 && SH == 1 && SW == 1 && PH == 0 && PW == 0) { | |||
// special case: 1x1 | |||
B = const_cast<uint8_t*>(src); | |||
} else { | |||
src2 = static_cast<uint8_t*>(bundle.get(0)); | |||
// copy src to src2; | |||
uint8_t* src2_ptr = src2; | |||
const uint8_t* src_ptr = src; | |||
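// Padding is filled with src_zp rather than 0: under asymmetric quantization
// the zero point is the encoding of real 0.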
rep(ic, IC) { | |||
if (PH != 0) { | |||
std::memset(src2_ptr, src_zp, sizeof(uint8_t) * PH * IW2); | |||
src2_ptr += PH * IW2; | |||
} | |||
rep(ih, IH) { | |||
if (PW != 0) | |||
rep(pw, PW) { *(src2_ptr++) = src_zp; } | |||
std::memcpy(src2_ptr, src_ptr, sizeof(uint8_t) * IW); | |||
src2_ptr += IW; | |||
src_ptr += IW; | |||
if (PW != 0) | |||
rep(pw, PW) { *(src2_ptr++) = src_zp; } | |||
} | |||
if (PH != 0) { | |||
std::memset(src2_ptr, src_zp, sizeof(uint8_t) * PH * IW2); | |||
src2_ptr += PH * IW2; | |||
} | |||
} | |||
B = static_cast<uint8_t*>(bundle.get(1)); | |||
if (SH == 1 && SW == 1) { | |||
if (is_xcorr) | |||
img2col<true>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW); | |||
else | |||
img2col<false>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW); | |||
} else { | |||
if (is_xcorr) | |||
img2col_stride<true>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, | |||
FW, SH, SW); | |||
else | |||
img2col_stride<false>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, | |||
FW, SH, SW); | |||
} | |||
} | |||
{ | |||
Workspace workspace(static_cast<dt_byte*>(bundle.get(2)), | |||
bundle.get_size(2)); | |||
size_t M = OC; | |||
size_t K = IC * FH * FW; | |||
size_t N = OH * OW; | |||
#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ | |||
_bias_midout_enum, _nonline, \ | |||
_nonline_midout_enum) \ | |||
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_quint8_gemm, 1, _gemm_midout_enum, \ | |||
_bias_midout_enum, _nonline_midout_enum) { \ | |||
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | |||
M, N, K, param.filter_type, param.src_type, param.dst_type); \ | |||
megdnn::matmul::GemmInterleaved< \ | |||
matmul::gemm_##_gemm##_##_bias##_##_nonline> \ | |||
gemm_interleaved(M, N, K, false, false, strategy); \ | |||
gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, \ | |||
bias); \ | |||
} \ | |||
MIDOUT_END() | |||
DISPATCH_GEMM_BIAS(u8_8x8, 0) | |||
#undef DISPATCH_GEMM_STRATEGY | |||
} | |||
} | |||
} | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,52 @@ | |||
/** | |||
* \file dnn/src/aarch64/conv_bias/quint8/algos.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "src/aarch64/conv_bias/opr_impl.h" | |||
#include "src/fallback/conv_bias/opr_impl.h" | |||
namespace megdnn { | |||
namespace aarch64 { | |||
using FallbackConvBiasImpl = fallback::ConvBiasImpl; | |||
class ConvBiasImpl::AlgoQU8MatrixMul final : public AlgoBase { | |||
static WorkspaceBundle get_bundle(const NCBKernSizeParam& param); | |||
static void kimpl(const NCBKernParam& param, const NCBKernIndex&); | |||
public: | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { return "QU8MATMUL"; } | |||
bool usable(FallbackConvBiasImpl* opr, const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||
size_t get_workspace(FallbackConvBiasImpl*, | |||
const NCBKernSizeParam& param) const override { | |||
return get_bundle(param).total_size_in_bytes(); | |||
} | |||
SmallVector<NCBKern> dispatch_kerns( | |||
FallbackConvBiasImpl*, | |||
const NCBKernSizeParam& param) const override { | |||
size_t group = param.filter_meta.group; | |||
return {{kimpl, {group, 1_z, 1_z}}}; | |||
} | |||
//! give matmul the highest preference when the quantized-matmul heuristic prefers it
bool is_preferred(FallbackConvBiasImpl* opr, | |||
const NCBKernSizeParam& param) const override { | |||
return static_cast<arm_common::ConvBiasImpl*>(opr) | |||
->is_matmul_quantized_prefer(param); | |||
} | |||
}; | |||
} // namespace aarch64 | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,319 @@ | |||
/** | |||
* \file dnn/src/aarch64/conv_bias/quint8/strategy.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/aarch64/conv_bias/quint8/strategy.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/common/utils.h" | |||
#include "src/fallback/conv_bias/common.h" | |||
#include "src/aarch64/matrix_mul/quint8_dot/kernel_8x8x4.h" | |||
#include "src/aarch64/matrix_mul/quint8/kernel_8x8x8.h" | |||
#include "src/arm_common/conv_bias/matmul_postprocess.h" | |||
using namespace megdnn; | |||
using namespace aarch64; | |||
using namespace aarch64::matmul; | |||
namespace impl { | |||
template <BiasMode bmode, typename Op, int block_m, int block_n> | |||
struct KernCaller; | |||
#if __ARM_FEATURE_DOTPROD | |||
template <BiasMode bmode, typename Op> | |||
struct KernCaller<bmode, Op, 8, 8> { | |||
static void run(const dt_uint8* packA, const dt_uint8* packB, size_t M, | |||
size_t N, size_t K, dt_uint8* C, size_t LDC, | |||
bool is_first_k, Op op, const dt_int32* bias, | |||
dt_int32* workspace, uint8_t zp_A, uint8_t zp_B) { | |||
megdnn_assert(is_first_k); | |||
constexpr size_t A_INTERLEAVE = 8; | |||
constexpr size_t B_INTERLEAVE = 8; | |||
const uint32_t zAB = | |||
static_cast<uint32_t>(zp_A) * static_cast<uint32_t>(zp_B) * K; | |||
//! K is padded to a multiple of 4
K = round_up<size_t>(K, 4); | |||
const int K8 = (K << 3); | |||
const int K4 = K * 4; | |||
size_t m = 0; | |||
for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) { | |||
uint8_t* output = C + (m * LDC); | |||
size_t n = 0; | |||
const dt_uint8* cur_packB = packB; | |||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
matmul_8x8x4::kern_8x8(packA, cur_packB, K, workspace, 8, | |||
is_first_k, zp_A, zp_B, zAB); | |||
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 8, 8, | |||
8>::postprocess(bias, workspace, | |||
output, LDC, op); | |||
output += B_INTERLEAVE; | |||
cur_packB += K8; | |||
} | |||
for (; n < N; n += 4) { | |||
matmul_8x8x4::kern_8x4(packA, cur_packB, K, workspace, 4, | |||
is_first_k, std::min<size_t>(N - n, 4), | |||
zp_A, zp_B, zAB); | |||
#define cb(m, n) \ | |||
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 4, 8, n>::postprocess( \ | |||
bias, workspace, output, LDC, op); | |||
DISPATCH_N(cb, 8, std::min<size_t>(N - n, 4)); | |||
#undef cb | |||
output += 4; | |||
cur_packB += K4; | |||
} | |||
packA += K8; | |||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
bias += A_INTERLEAVE; | |||
} | |||
} | |||
for (; m < M; m += 4) { | |||
uint8_t* output = C + (m * LDC); | |||
const dt_uint8* cur_packB = packB; | |||
size_t n = 0; | |||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
matmul_8x8x4::kern_4x8(packA, cur_packB, K, workspace, 8, | |||
is_first_k, std::min<size_t>(M - m, 4), | |||
zp_A, zp_B, zAB); | |||
#define cb(m, n) \ | |||
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 4, 8, m, n>::postprocess( \ | |||
bias, workspace, output, LDC, op); | |||
DISPATCH_M_N(cb, std::min<size_t>(M - m, 4), 8); | |||
#undef cb | |||
output += B_INTERLEAVE; | |||
cur_packB += K8; | |||
} | |||
for (; n < N; n += 4) { | |||
matmul_8x8x4::kern_4x4(packA, cur_packB, K, workspace, 4, | |||
is_first_k, std::min<size_t>(M - m, 4), | |||
std::min<size_t>(N - n, 4), zp_A, zp_B, | |||
zAB); | |||
#define cb(m, n) \ | |||
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 4, 4, m, n>::postprocess( \ | |||
bias, workspace, output, LDC, op); | |||
DISPATCH_M(cb, std::min<size_t>(M - m, 4), | |||
std::min<size_t>(N - n, 4)); | |||
#undef cb | |||
output += 4; | |||
cur_packB += K4; | |||
} | |||
packA += K4; | |||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
bias += 4; | |||
} | |||
} | |||
} | |||
}; | |||
#else | |||
template <BiasMode bmode, typename Op> | |||
struct KernCaller<bmode, Op, 8, 8> { | |||
static void run(const dt_uint8* packA, const dt_uint8* packB, size_t M, | |||
size_t N, size_t K, dt_uint8* C, size_t LDC, | |||
bool is_first_k, Op op, const dt_int32* bias, | |||
dt_int32* workspace, uint8_t zp_A, uint8_t zp_B) { | |||
megdnn_assert(is_first_k); | |||
constexpr size_t A_INTERLEAVE = 8; | |||
constexpr size_t B_INTERLEAVE = 8; | |||
        //! K is padded up to a multiple of 8 | |||
K = round_up<size_t>(K, 8); | |||
const int K8 = K * 8; | |||
const int K4 = K * 4; | |||
size_t m = 0; | |||
for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) { | |||
uint8_t* output = C + (m * LDC); | |||
size_t n = 0; | |||
const dt_uint8* cur_packB = packB; | |||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
matmul_8x8x8::kern_8x8(packA, cur_packB, K, workspace, 8, | |||
is_first_k, zp_A, zp_B); | |||
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 8, 8, | |||
8>::postprocess(bias, workspace, | |||
output, LDC, op); | |||
output += B_INTERLEAVE; | |||
cur_packB += K8; | |||
} | |||
for (; n < N; n += 4) { | |||
matmul_8x8x8::kern_8x4(packA, cur_packB, K, workspace, 4, | |||
is_first_k, std::min<size_t>(N - n, 4), | |||
zp_A, zp_B); | |||
#define cb(m, n) \ | |||
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 4, 8, n>::postprocess( \ | |||
bias, workspace, output, LDC, op); | |||
DISPATCH_N(cb, 8, std::min<size_t>(N - n, 4)); | |||
#undef cb | |||
output += 4; | |||
cur_packB += K4; | |||
} | |||
packA += K8; | |||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
bias += A_INTERLEAVE; | |||
} | |||
} | |||
for (; m < M; m += 4) { | |||
uint8_t* output = C + (m * LDC); | |||
const dt_uint8* cur_packB = packB; | |||
size_t n = 0; | |||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
matmul_8x8x8::kern_4x8(packA, cur_packB, K, workspace, 8, | |||
is_first_k, std::min<size_t>(M - m, 4), | |||
zp_A, zp_B); | |||
#define cb(m, n) \ | |||
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 4, 8, m, n>::postprocess( \ | |||
bias, workspace, output, LDC, op); | |||
DISPATCH_M_N(cb, std::min<size_t>(M - m, 4), 8); | |||
#undef cb | |||
output += B_INTERLEAVE; | |||
cur_packB += K8; | |||
} | |||
for (; n < N; n += 4) { | |||
matmul_8x8x8::kern_4x4(packA, cur_packB, K, workspace, 4, | |||
is_first_k, std::min<size_t>(M - m, 4), | |||
std::min<size_t>(N - n, 4), zp_A, zp_B); | |||
#define cb(m, n) \ | |||
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 4, 4, m, n>::postprocess( \ | |||
bias, workspace, output, LDC, op); | |||
DISPATCH_M(cb, std::min<size_t>(M - m, 4), | |||
std::min<size_t>(N - n, 4)); | |||
#undef cb | |||
output += 4; | |||
cur_packB += K4; | |||
} | |||
packA += K4; | |||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
bias += 4; | |||
} | |||
} | |||
} | |||
}; | |||
#endif | |||
} // namespace impl | |||
#if __ARM_FEATURE_DOTPROD | |||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8_nobias_identity) | |||
void gemm_u8_8x8_nobias_identity::pack_A(uint8_t* outptr, const uint8_t* inptr, | |||
int ldin, int y0, int ymax, int k0, | |||
int kmax, bool transpose) const { | |||
if (transpose) { | |||
matmul_8x8x4::gemm_u8_8x8_transpose_pack_helper(outptr, inptr, ldin, y0, | |||
ymax, k0, kmax); | |||
} else { | |||
matmul_8x8x4::gemm_u8_8x8_interleave_pack_helper(outptr, inptr, ldin, | |||
y0, ymax, k0, kmax); | |||
} | |||
} | |||
void gemm_u8_8x8_nobias_identity::pack_B(uint8_t* out, const uint8_t* in, | |||
int ldin, int x0, int xmax, int k0, | |||
int kmax, bool transpose) const { | |||
if (transpose) { | |||
matmul_8x8x4::gemm_u8_8x8_interleave_pack_helper(out, in, ldin, x0, | |||
xmax, k0, kmax); | |||
} else { | |||
matmul_8x8x4::gemm_u8_8x8_transpose_pack_helper(out, in, ldin, x0, xmax, | |||
k0, kmax); | |||
} | |||
} | |||
#else | |||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8_nobias_identity) | |||
void gemm_u8_8x8_nobias_identity::pack_A(dt_uint8* outptr, | |||
const dt_uint8* inptr, int ldin, | |||
int y0, int ymax, int k0, int kmax, | |||
bool transpose) const { | |||
uint8_t zA = A_dtype.param<dtype::Quantized8Asymm>().zero_point; | |||
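    //! unlike the dotprod path, these pack helpers take the zero point, | |||
    //! presumably so ragged tiles can be padded with zp and padded lanes | |||
    //! contribute (x - zp) == 0 to the accumulation | |||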
if (transpose) { | |||
matmul_8x8x8::gemm_u8_8x8_transpose_pack_A_n(outptr, inptr, ldin, y0, | |||
ymax, k0, kmax, zA); | |||
} else { | |||
matmul_8x8x8::gemm_u8_8x8_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, | |||
kmax, zA); | |||
} | |||
} | |||
void gemm_u8_8x8_nobias_identity::pack_B(dt_uint8* out, const dt_uint8* in, | |||
int ldin, int x0, int xmax, int k0, | |||
int kmax, bool transpose) const { | |||
uint8_t zB = B_dtype.param<dtype::Quantized8Asymm>().zero_point; | |||
if (transpose) { | |||
matmul_8x8x8::gemm_u8_8x8_transpose_pack_B_n(out, in, ldin, x0, xmax, | |||
k0, kmax, zB); | |||
} else { | |||
matmul_8x8x8::gemm_u8_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax, | |||
zB); | |||
} | |||
} | |||
#endif | |||
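//! the kernels accumulate into a temporary 8x8 int32 tile placed in this | |||
//! workspace; the postprocess step then requantizes it to uint8 output | |||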
size_t gemm_u8_8x8_nobias_identity::get_workspace_size() const { | |||
return 8 * 8 * sizeof(dt_int32); | |||
} | |||
#define KERN(_block_m, _block_n, _bias, _BIAS, _nonline, _OP) \ | |||
void gemm_u8_##_block_m##x##_block_n##_##_bias##_##_nonline::kern( \ | |||
const dt_uint8* packA, const dt_uint8* packB, size_t M, size_t N, \ | |||
size_t K, dt_uint8* C, size_t LDC, bool is_first_k, \ | |||
const dt_int32* bias, dt_int32* workspace) const { \ | |||
float scale_A = A_dtype.param<dtype::Quantized8Asymm>().scale; \ | |||
uint8_t zp_A = A_dtype.param<dtype::Quantized8Asymm>().zero_point; \ | |||
float scale_B = B_dtype.param<dtype::Quantized8Asymm>().scale; \ | |||
uint8_t zp_B = B_dtype.param<dtype::Quantized8Asymm>().zero_point; \ | |||
float scale_C = C_dtype.param<dtype::Quantized8Asymm>().scale; \ | |||
uint8_t zp_C = C_dtype.param<dtype::Quantized8Asymm>().zero_point; \ | |||
DEFINE_OP(_OP); \ | |||
impl::KernCaller<_BIAS, decltype(op), _block_m, _block_n>::run( \ | |||
packA, packB, M, N, K, C, LDC, is_first_k, op, bias, \ | |||
workspace, zp_A, zp_B); \ | |||
} | |||
#define DEFINE_OP(_Op) \ | |||
arm_common::_Op<dt_qint32, dt_quint8> op(scale_A* scale_B, scale_C, zp_C); | |||
KERN(8, 8, nobias, BiasMode::NO_BIAS, identity, TypeCvtOp) | |||
KERN(8, 8, nobias, BiasMode::NO_BIAS, relu, ReluOp) | |||
KERN(8, 8, nobias, BiasMode::NO_BIAS, hswish, HSwishOp) | |||
#undef DEFINE_OP | |||
#define DEFINE_OP(_Op) \ | |||
arm_common::_Op<dt_qint32, dt_quint8> op(scale_A* scale_B, \ | |||
scale_A* scale_B, scale_C, zp_C); | |||
KERN(8, 8, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp) | |||
KERN(8, 8, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp) | |||
KERN(8, 8, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, | |||
FuseAddHSwishOp) | |||
#undef DEFINE_OP | |||
#undef KERN | |||
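//! For illustration, KERN(8, 8, nobias, BiasMode::NO_BIAS, relu, ReluOp) | |||
//! expands to gemm_u8_8x8_nobias_relu::kern, which reads the quantization | |||
//! parameters from the A/B/C dtypes, constructs an | |||
//! arm_common::ReluOp<dt_qint32, dt_quint8> and forwards to | |||
//! impl::KernCaller<BiasMode::NO_BIAS, decltype(op), 8, 8>::run. | |||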
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,48 @@ | |||
/** | |||
* \file dnn/src/aarch64/conv_bias/quint8/strategy.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "src/fallback/matrix_mul/gemm_common.h" | |||
namespace megdnn { | |||
namespace aarch64 { | |||
namespace matmul { | |||
#if __ARM_FEATURE_DOTPROD | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_uint8, dt_uint8, dt_int32, 8, 8, 4, | |||
false, true, | |||
gemm_u8_8x8_nobias_identity); | |||
#else | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_uint8, dt_uint8, dt_int32, 8, 8, 8, | |||
false, true, | |||
gemm_u8_8x8_nobias_identity); | |||
#endif | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nobias_relu, | |||
gemm_u8_8x8_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nobias_hswish, | |||
gemm_u8_8x8_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_bias_channel_identity, | |||
gemm_u8_8x8_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_bias_channel_relu, | |||
gemm_u8_8x8_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_bias_channel_hswish, | |||
gemm_u8_8x8_nobias_identity); | |||
} // namespace matmul | |||
} // namespace aarch64 | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,44 @@ | |||
/** | |||
* \file dnn/src/aarch64/handle.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/common/handle_impl.h" | |||
#include "src/aarch64/handle.h" | |||
#include "src/aarch64/matrix_mul/opr_impl.h" | |||
#include "src/aarch64/rotate/opr_impl.h" | |||
#include "src/aarch64/relayout/opr_impl.h" | |||
#include "src/aarch64/conv_bias/opr_impl.h" | |||
#include "src/aarch64/warp_perspective/opr_impl.h" | |||
namespace megdnn { | |||
namespace aarch64 { | |||
template <typename Opr> | |||
std::unique_ptr<Opr> HandleImpl::create_operator() { | |||
return arm_common::HandleImpl::create_operator<Opr>(); | |||
} | |||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(MatrixMul) | |||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Rotate) | |||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(RelayoutForward) | |||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvBias) | |||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(WarpPerspective) | |||
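//! the specializations above route these operators to the aarch64 | |||
//! implementations; all other operators fall back to arm_common through | |||
//! the generic template defined above | |||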
#pragma GCC diagnostic push | |||
#pragma GCC diagnostic ignored "-Wpragmas" | |||
#pragma GCC diagnostic ignored "-Winstantiation-after-specialization" | |||
MEGDNN_FOREACH_OPR_CLASS(MEGDNN_INST_CREATE_OPERATOR) | |||
#pragma GCC diagnostic pop | |||
} // namespace aarch64 | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,33 @@ | |||
/** | |||
* \file dnn/src/aarch64/handle.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "src/arm_common/handle.h" | |||
namespace megdnn { | |||
namespace aarch64 { | |||
class HandleImpl: public arm_common::HandleImpl { | |||
public: | |||
HandleImpl(megcoreComputingHandle_t computing_handle, | |||
HandleType type = HandleType::AARCH64): | |||
arm_common::HandleImpl::HandleImpl(computing_handle, type) | |||
{} | |||
template <typename Opr> | |||
std::unique_ptr<Opr> create_operator(); | |||
}; | |||
} // namespace aarch64 | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen | |||
@@ -0,0 +1,250 @@ | |||
/** | |||
* \file dnn/src/aarch64/matrix_mul/algos.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "src/aarch64/matrix_mul/opr_impl.h" | |||
#include "src/arm_common/matrix_mul/algos.h" | |||
#include "src/fallback/matrix_mul/gemm_common.h" | |||
namespace megdnn { | |||
namespace aarch64 { | |||
class MatrixMulImpl::AlgoF32K8x12x1 final : public AlgoBase { | |||
public: | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { return "AARCH64_F32K8X12X1"; } | |||
bool usable(const KernSizeParam&) const override; | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
void* type() const override { return sm_arm_common_algo_type; } | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
}; | |||
class MatrixMulImpl::AlgoF32K4x16x1 final : public AlgoBase { | |||
public: | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { return "AARCH64_F32K4X16X1"; } | |||
bool usable(const KernSizeParam&) const override; | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
void* type() const override { return sm_arm_common_algo_type; } | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
}; | |||
class MatrixMulImpl::AlgoF32MK4_4x16 final : public AlgoBase { | |||
public: | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { return "AARCH64_F32_MK4_4x16"; } | |||
bool usable(const KernSizeParam&) const override; | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
void* type() const override { return sm_arm_common_algo_type; } | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
}; | |||
class MatrixMulImpl::AlgoF32Gemv final | |||
: public arm_common::MatrixMulImpl::AlgoF32Gemv {}; | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
class MatrixMulImpl::AlgoF16K8x24x1 final : public AlgoBase { | |||
public: | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { return "AARCH64_F16_K8X24X1"; } | |||
bool usable(const KernSizeParam&) const override; | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
void* type() const override { return sm_arm_common_algo_type; } | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
}; | |||
class MatrixMulImpl::AlgoF16MK8_8x8 final : public AlgoBase { | |||
public: | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { return "AARCH64_F16_MK8_8X8"; } | |||
bool usable(const KernSizeParam&) const override; | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
void* type() const override { return sm_arm_common_algo_type; } | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
}; | |||
#endif | |||
#if __ARM_FEATURE_DOTPROD | |||
class MatrixMulImpl::AlgoInt8x8x32K8x12x4DotProd final : public AlgoBase { | |||
public: | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { | |||
return "AARCH64_INT8X8X32_K8X12X4_DOTPROD"; | |||
} | |||
bool usable(const KernSizeParam&) const override; | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
void* type() const override { return sm_arm_common_algo_type; } | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
}; | |||
class MatrixMulImpl::AlgoInt8x8x32GemvDotProd final : public AlgoBase { | |||
public: | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { | |||
return "AARCH64_INT8X8X32_GEMV_DOTPROD"; | |||
} | |||
bool usable(const KernSizeParam&) const override; | |||
bool preferred(const KernSizeParam&) const override; | |||
size_t get_workspace(const KernSizeParam&) const override { return 0; } | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
void* type() const override { return sm_arm_common_algo_type; } | |||
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
}; | |||
#else | |||
class MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16 final : public AlgoBase { | |||
public: | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { | |||
return "AARCH64_INT8X8X32_MK4_4X4X16"; | |||
} | |||
bool usable(const KernSizeParam&) const override; | |||
bool preferred(const KernSizeParam&) const override; | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
void* type() const override { return sm_arm_common_algo_type; } | |||
PackMode packmode() const override { return PackMode::DEFAULT; } | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
}; | |||
class MatrixMulImpl::AlgoInt8x8x32K4x4x16 final : public AlgoBase { | |||
public: | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { return "AARCH64_INT8X8X32_K4X4X16"; } | |||
bool usable(const KernSizeParam&) const override; | |||
bool preferred(const KernSizeParam&) const override; | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
void* type() const override { return sm_arm_common_algo_type; } | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
}; | |||
class MatrixMulImpl::AlgoInt8x8x32K8x8x8 final : public AlgoBase { | |||
public: | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { return "AARCH64_INT8X8X32_K8X8X8"; } | |||
bool usable(const KernSizeParam&) const override; | |||
bool preferred(const KernSizeParam&) const override; | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
void* type() const override { return sm_arm_common_algo_type; } | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
}; | |||
class MatrixMulImpl::AlgoInt8x8x32Gemv final | |||
: public arm_common::MatrixMulImpl::AlgoInt8x8x32Gemv {}; | |||
#endif | |||
class MatrixMulImpl::AlgoInt8x8x16K8x8x8 final : public AlgoBase { | |||
public: | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { return "AARCH64_INT8X8X16_K8X8X8"; } | |||
bool usable(const KernSizeParam&) const override; | |||
bool preferred(const KernSizeParam&) const override; | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
void* type() const override { return sm_arm_common_algo_type; } | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
}; | |||
class MatrixMulImpl::AlgoInt8x8x16K4x4x16 final : public AlgoBase { | |||
public: | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { return "AARCH64_INT8X8X16_K4X4X16"; } | |||
bool usable(const KernSizeParam&) const override; | |||
bool preferred(const KernSizeParam&) const override; | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
void* type() const override { return sm_arm_common_algo_type; } | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
}; | |||
class MatrixMulImpl::AlgoInt16x16x32K12x8x1 final : public AlgoBase { | |||
public: | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { return "AARCH64_INT16X16X32_K12X8X1"; } | |||
bool usable(const KernSizeParam&) const override; | |||
bool preferred(const KernSizeParam&) const override; | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
void* type() const override { return sm_arm_common_algo_type; } | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
}; | |||
class MatrixMulImpl::AlgoInt16x16x32MK8_8x8 final : public AlgoBase { | |||
public: | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { return "AARCH64_INT16X16X32_MK8_8X8"; } | |||
bool usable(const KernSizeParam&) const override; | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
void* type() const override { return sm_arm_common_algo_type; } | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
}; | |||
#if __ARM_FEATURE_DOTPROD | |||
class MatrixMulImpl::AlgoQuint8K8x8x4DotProd final : public AlgoBase { | |||
public: | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { | |||
return "AARCH64_QUINT8_K8X8X4_DOTPROD"; | |||
} | |||
bool usable(const KernSizeParam&) const override; | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
void* type() const override { return sm_arm_common_algo_type; } | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
}; | |||
class MatrixMulImpl::AlgoQuint8GemvDotProd final : public AlgoBase { | |||
public: | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { return "AARCH64_QUINT8_GEMV_DOTPROD"; } | |||
bool usable(const KernSizeParam&) const override; | |||
bool preferred(const KernSizeParam&) const override; | |||
size_t get_workspace(const KernSizeParam&) const override { return 0; } | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
void* type() const override { return sm_arm_common_algo_type; } | |||
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
}; | |||
#else | |||
class MatrixMulImpl::AlgoQuint8K8x8x8 final : public AlgoBase { | |||
public: | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { return "AARCH64_QUINT8_K8X8X8"; } | |||
bool usable(const KernSizeParam&) const override; | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
void* type() const override { return sm_arm_common_algo_type; } | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
}; | |||
#endif | |||
} // namespace aarch64 | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,29 @@ | |||
/** | |||
* \file dnn/src/aarch64/matrix_mul/fp16/strategy.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "src/fallback/matrix_mul/gemm_common.h" | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
namespace megdnn { | |||
namespace aarch64 { | |||
namespace matmul { | |||
MEGDNN_REG_GEMM_STRATEGY(dt_float16, dt_float16, dt_float16, 8, 24, 1, false, | |||
true, hgemm_8x24); | |||
MEGDNN_REG_GEMM_STRATEGY_NOPACK(dt_float16, dt_float16, dt_float16, 8, 8, 1, | |||
false, true, gemm_nopack_f16_8x8); | |||
} // namespace matmul | |||
} // namespace aarch64 | |||
} // namespace megdnn | |||
#endif | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,439 @@ | |||
/** | |||
* \file dnn/src/aarch64/matrix_mul/fp16/strategy_mk8_8x8.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/aarch64/matrix_mul/fp16/strategy.h" | |||
#include "src/aarch64/matrix_mul/asm/common.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/common/utils.h" | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
using namespace megdnn; | |||
using namespace aarch64; | |||
using namespace aarch64::matmul; | |||
namespace { | |||
// Overview of register layout: | |||
// | |||
// An 8x4 cell of Rhs is stored in 16bit in v0-v3 | |||
// An 8x8 cell of Lhs is stored in 16bit in v16-v23 | |||
// An 8x4 block of accumulators is stored in 16bit in v24-v27. | |||
// | |||
// Rhs +-------+ | |||
// |v0[0-7]| | |||
// |v1[0-7]| | |||
// |v2[0-7]| | |||
// |v3[0-7]| | |||
// +-------+ | |||
// Lhs | |||
// +--------+ | |||
// |v16[0-7]| | |||
// |v17[0-7]| | |||
// |v18[0-7]| | |||
// |v19[0-7]| +--------+ | |||
// |v20[0-7]| |v24[0-7]| | |||
// |v21[0-7]| |v25[0-7]| | |||
// |v22[0-7]| |v26[0-7]| | |||
// |v23[0-7]| |v27[0-7]| | |||
// +--------+ +--------+ | |||
// Accumulator | |||
void kern_8x4(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||
dt_float16* output) { | |||
    //! LDB is the number of elements in one row block of B. The first three | |||
    //! B loads below advance b_ptr by 24 elements (8 each), so subtract 24 | |||
    //! elements' worth of bytes from the stride here. | |||
    LDB = (LDB - 24) * sizeof(dt_float16); | |||
asm volatile( | |||
".arch armv8.2-a+fp16\n" | |||
"ld1 {v16.4s, v17.4s}, [%[a_ptr]], 32\n" | |||
"subs %w[K], %w[K], #8\n" | |||
"ld1 {v0.4s}, [%[b_ptr]], 16\n" | |||
"ld1 {v1.4s}, [%[b_ptr]], 16\n" | |||
"fmul v24.8h, v16.8h, v0.h[0]\n" | |||
"ld1 {v2.4s}, [%[b_ptr]], 16\n" | |||
"fmul v25.8h, v16.8h, v1.h[0]\n" | |||
"ld1 {v3.4s}, [%[b_ptr]], %x[LDB]\n" | |||
"fmul v26.8h, v16.8h, v2.h[0]\n" | |||
"ld1 {v18.4s}, [%[a_ptr]], 16\n" | |||
"fmul v27.8h, v16.8h, v3.h[0]\n" | |||
"fmla v24.8h, v17.8h, v0.h[1]\n" | |||
"fmla v25.8h, v17.8h, v1.h[1]\n" | |||
"fmla v26.8h, v17.8h, v2.h[1]\n" | |||
"fmla v27.8h, v17.8h, v3.h[1]\n" | |||
"ld1 {v19.4s}, [%[a_ptr]], 16\n" | |||
"fmla v24.8h, v18.8h, v0.h[2]\n" | |||
"fmla v25.8h, v18.8h, v1.h[2]\n" | |||
"fmla v26.8h, v18.8h, v2.h[2]\n" | |||
"fmla v27.8h, v18.8h, v3.h[2]\n" | |||
"ld1 {v20.4s}, [%[a_ptr]], 16\n" | |||
"fmla v24.8h, v19.8h, v0.h[3]\n" | |||
"fmla v25.8h, v19.8h, v1.h[3]\n" | |||
"fmla v26.8h, v19.8h, v2.h[3]\n" | |||
"fmla v27.8h, v19.8h, v3.h[3]\n" | |||
"ld1 {v21.4s}, [%[a_ptr]], 16\n" | |||
"fmla v24.8h, v20.8h, v0.h[4]\n" | |||
"fmla v25.8h, v20.8h, v1.h[4]\n" | |||
"fmla v26.8h, v20.8h, v2.h[4]\n" | |||
"fmla v27.8h, v20.8h, v3.h[4]\n" | |||
"ld1 {v22.4s}, [%[a_ptr]], 16\n" | |||
"fmla v24.8h, v21.8h, v0.h[5]\n" | |||
"fmla v25.8h, v21.8h, v1.h[5]\n" | |||
"fmla v26.8h, v21.8h, v2.h[5]\n" | |||
"fmla v27.8h, v21.8h, v3.h[5]\n" | |||
"ld1 {v23.4s}, [%[a_ptr]], 16\n" | |||
"fmla v24.8h, v22.8h, v0.h[6]\n" | |||
"fmla v25.8h, v22.8h, v1.h[6]\n" | |||
"fmla v26.8h, v22.8h, v2.h[6]\n" | |||
"fmla v27.8h, v22.8h, v3.h[6]\n" | |||
"beq 2f\n" | |||
"1:\n" | |||
"ld1 {v16.4s}, [%[a_ptr]], 16\n" | |||
"fmla v24.8h, v23.8h, v0.h[7]\n" | |||
"ld1 {v0.4s}, [%[b_ptr]], 16\n" | |||
"fmla v25.8h, v23.8h, v1.h[7]\n" | |||
"ld1 {v1.4s}, [%[b_ptr]], 16\n" | |||
"fmla v26.8h, v23.8h, v2.h[7]\n" | |||
"ld1 {v2.4s}, [%[b_ptr]], 16\n" | |||
"fmla v27.8h, v23.8h, v3.h[7]\n" | |||
"ld1 {v3.4s}, [%[b_ptr]], %x[LDB]\n" | |||
"ld1 {v17.4s}, [%[a_ptr]], 16\n" | |||
"fmla v24.8h, v16.8h, v0.h[0]\n" | |||
"fmla v25.8h, v16.8h, v1.h[0]\n" | |||
"fmla v26.8h, v16.8h, v2.h[0]\n" | |||
"fmla v27.8h, v16.8h, v3.h[0]\n" | |||
"ld1 {v18.4s}, [%[a_ptr]], 16\n" | |||
"fmla v24.8h, v17.8h, v0.h[1]\n" | |||
"fmla v25.8h, v17.8h, v1.h[1]\n" | |||
"fmla v26.8h, v17.8h, v2.h[1]\n" | |||
"fmla v27.8h, v17.8h, v3.h[1]\n" | |||
"ld1 {v19.4s}, [%[a_ptr]], 16\n" | |||
"fmla v24.8h, v18.8h, v0.h[2]\n" | |||
"fmla v25.8h, v18.8h, v1.h[2]\n" | |||
"fmla v26.8h, v18.8h, v2.h[2]\n" | |||
"fmla v27.8h, v18.8h, v3.h[2]\n" | |||
"ld1 {v20.4s}, [%[a_ptr]], 16\n" | |||
"fmla v24.8h, v19.8h, v0.h[3]\n" | |||
"fmla v25.8h, v19.8h, v1.h[3]\n" | |||
"fmla v26.8h, v19.8h, v2.h[3]\n" | |||
"fmla v27.8h, v19.8h, v3.h[3]\n" | |||
"ld1 {v21.4s}, [%[a_ptr]], 16\n" | |||
"fmla v24.8h, v20.8h, v0.h[4]\n" | |||
"fmla v25.8h, v20.8h, v1.h[4]\n" | |||
"fmla v26.8h, v20.8h, v2.h[4]\n" | |||
"fmla v27.8h, v20.8h, v3.h[4]\n" | |||
"ld1 {v22.4s}, [%[a_ptr]], 16\n" | |||
"fmla v24.8h, v21.8h, v0.h[5]\n" | |||
"fmla v25.8h, v21.8h, v1.h[5]\n" | |||
"fmla v26.8h, v21.8h, v2.h[5]\n" | |||
"fmla v27.8h, v21.8h, v3.h[5]\n" | |||
"ld1 {v23.4s}, [%[a_ptr]], 16\n" | |||
"fmla v24.8h, v22.8h, v0.h[6]\n" | |||
"fmla v25.8h, v22.8h, v1.h[6]\n" | |||
"fmla v26.8h, v22.8h, v2.h[6]\n" | |||
"fmla v27.8h, v22.8h, v3.h[6]\n" | |||
"subs %w[K], %w[K], #8\n" | |||
"bne 1b\n" | |||
"2:\n" | |||
"fmla v24.8h, v23.8h, v0.h[7]\n" | |||
"fmla v25.8h, v23.8h, v1.h[7]\n" | |||
"fmla v26.8h, v23.8h, v2.h[7]\n" | |||
"fmla v27.8h, v23.8h, v3.h[7]\n" | |||
"st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%[output]], 64\n" | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
[output] "+r"(output), [LDB] "+r"(LDB) | |||
: | |||
: "v0", "v1", "v2", "v3", "v16", "v17", "v18", "v19", "v20", "v21", | |||
"v22", "v23", "v24", "v25", "v26", "v27", "cc", "memory"); | |||
} | |||
// Overview of register layout: | |||
// | |||
// An 8x8 cell of Lhs is stored in 16bit in v8-v15 | |||
// An 8x8 cell of Rhs is stored in 16bit in v0-v7 | |||
// An 8x8 block of accumulators is stored in 16bit in v24-v31. | |||
// | |||
// Lhs +--------+ | |||
//     | v8[0-7]| | |||
//     | v9[0-7]| | |||
//     |v10[0-7]| | |||
//     |v11[0-7]| | |||
//     |v12[0-7]| | |||
//     |v13[0-7]| | |||
//     |v14[0-7]| | |||
//     |v15[0-7]| | |||
//     +--------+ | |||
// Rhs | |||
//     +--------+ - - - - -+--------+ | |||
//     | v0[0-7]|          |v24[0-7]| | |||
//     | v1[0-7]|          |v25[0-7]| | |||
//     | v2[0-7]|          |v26[0-7]| | |||
//     | v3[0-7]|          |v27[0-7]| | |||
//     | v4[0-7]|          |v28[0-7]| | |||
//     | v5[0-7]|          |v29[0-7]| | |||
//     | v6[0-7]|          |v30[0-7]| | |||
//     | v7[0-7]|          |v31[0-7]| | |||
//     +--------+          +--------+ | |||
//                         Accumulator | |||
void kern_8x8(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||
dt_float16* output) { | |||
    //! Each iteration loads B in two groups of 32 fp16 elements; the first | |||
    //! group's post-increment already advances b_ptr by 32 elements (64 | |||
    //! bytes), so subtract 32 from LDB before converting it to bytes. | |||
    LDB = (LDB - 32) * sizeof(dt_float16); | |||
asm volatile( | |||
".arch armv8.2-a+fp16\n" | |||
"ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[a_ptr]], 64\n" | |||
"subs %w[K], %w[K], #8\n" | |||
"ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b_ptr]], 64\n" | |||
"ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[b_ptr]], %x[LDB]\n" | |||
"fmul v24.8h, v8.8h, v0.h[0]\n" | |||
"fmul v25.8h, v8.8h, v1.h[0]\n" | |||
"fmul v26.8h, v8.8h, v2.h[0]\n" | |||
"fmul v27.8h, v8.8h, v3.h[0]\n" | |||
"fmul v28.8h, v8.8h, v4.h[0]\n" | |||
"fmul v29.8h, v8.8h, v5.h[0]\n" | |||
"fmul v30.8h, v8.8h, v6.h[0]\n" | |||
"fmul v31.8h, v8.8h, v7.h[0]\n" | |||
"fmla v24.8h, v9.8h, v0.h[1]\n" | |||
"fmla v25.8h, v9.8h, v1.h[1]\n" | |||
"fmla v26.8h, v9.8h, v2.h[1]\n" | |||
"fmla v27.8h, v9.8h, v3.h[1]\n" | |||
"fmla v28.8h, v9.8h, v4.h[1]\n" | |||
"fmla v29.8h, v9.8h, v5.h[1]\n" | |||
"fmla v30.8h, v9.8h, v6.h[1]\n" | |||
"fmla v31.8h, v9.8h, v7.h[1]\n" | |||
"ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%[a_ptr]], 64\n" | |||
"fmla v24.8h, v10.8h, v0.h[2]\n" | |||
"fmla v25.8h, v10.8h, v1.h[2]\n" | |||
"fmla v26.8h, v10.8h, v2.h[2]\n" | |||
"fmla v27.8h, v10.8h, v3.h[2]\n" | |||
"fmla v28.8h, v10.8h, v4.h[2]\n" | |||
"fmla v29.8h, v10.8h, v5.h[2]\n" | |||
"fmla v30.8h, v10.8h, v6.h[2]\n" | |||
"fmla v31.8h, v10.8h, v7.h[2]\n" | |||
"fmla v24.8h, v11.8h, v0.h[3]\n" | |||
"fmla v25.8h, v11.8h, v1.h[3]\n" | |||
"fmla v26.8h, v11.8h, v2.h[3]\n" | |||
"fmla v27.8h, v11.8h, v3.h[3]\n" | |||
"fmla v28.8h, v11.8h, v4.h[3]\n" | |||
"fmla v29.8h, v11.8h, v5.h[3]\n" | |||
"fmla v30.8h, v11.8h, v6.h[3]\n" | |||
"fmla v31.8h, v11.8h, v7.h[3]\n" | |||
"fmla v24.8h, v12.8h, v0.h[4]\n" | |||
"fmla v25.8h, v12.8h, v1.h[4]\n" | |||
"fmla v26.8h, v12.8h, v2.h[4]\n" | |||
"fmla v27.8h, v12.8h, v3.h[4]\n" | |||
"fmla v24.8h, v13.8h, v0.h[5]\n" | |||
"fmla v25.8h, v13.8h, v1.h[5]\n" | |||
"fmla v26.8h, v13.8h, v2.h[5]\n" | |||
"fmla v27.8h, v13.8h, v3.h[5]\n" | |||
"beq 2f\n" | |||
"1:\n" | |||
"ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[a_ptr]], 64\n" | |||
"fmla v24.8h, v15.8h, v0.h[7]\n" | |||
"fmla v25.8h, v15.8h, v1.h[7]\n" | |||
"fmla v26.8h, v15.8h, v2.h[7]\n" | |||
"fmla v27.8h, v15.8h, v3.h[7]\n" | |||
"fmla v24.8h, v14.8h, v0.h[6]\n" | |||
"fmla v25.8h, v14.8h, v1.h[6]\n" | |||
"fmla v26.8h, v14.8h, v2.h[6]\n" | |||
"fmla v27.8h, v14.8h, v3.h[6]\n" | |||
"fmla v28.8h, v12.8h, v4.h[4]\n" | |||
"fmla v29.8h, v12.8h, v5.h[4]\n" | |||
"fmla v30.8h, v12.8h, v6.h[4]\n" | |||
"fmla v31.8h, v12.8h, v7.h[4]\n" | |||
"ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b_ptr]], 64\n" | |||
"fmla v28.8h, v13.8h, v4.h[5]\n" | |||
"fmla v29.8h, v13.8h, v5.h[5]\n" | |||
"fmla v30.8h, v13.8h, v6.h[5]\n" | |||
"fmla v31.8h, v13.8h, v7.h[5]\n" | |||
"fmla v28.8h, v14.8h, v4.h[6]\n" | |||
"fmla v29.8h, v14.8h, v5.h[6]\n" | |||
"fmla v30.8h, v14.8h, v6.h[6]\n" | |||
"fmla v31.8h, v14.8h, v7.h[6]\n" | |||
"fmla v28.8h, v15.8h, v4.h[7]\n" | |||
"fmla v29.8h, v15.8h, v5.h[7]\n" | |||
"fmla v30.8h, v15.8h, v6.h[7]\n" | |||
"fmla v31.8h, v15.8h, v7.h[7]\n" | |||
"fmla v24.8h, v8.8h, v0.h[0]\n" | |||
"fmla v25.8h, v8.8h, v1.h[0]\n" | |||
"fmla v26.8h, v8.8h, v2.h[0]\n" | |||
"fmla v27.8h, v8.8h, v3.h[0]\n" | |||
"ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[b_ptr]], %x[LDB]\n" | |||
"fmla v24.8h, v9.8h, v0.h[1]\n" | |||
"fmla v25.8h, v9.8h, v1.h[1]\n" | |||
"fmla v26.8h, v9.8h, v2.h[1]\n" | |||
"fmla v27.8h, v9.8h, v3.h[1]\n" | |||
"fmla v24.8h, v10.8h, v0.h[2]\n" | |||
"fmla v25.8h, v10.8h, v1.h[2]\n" | |||
"fmla v26.8h, v10.8h, v2.h[2]\n" | |||
"fmla v27.8h, v10.8h, v3.h[2]\n" | |||
"fmla v24.8h, v11.8h, v0.h[3]\n" | |||
"fmla v25.8h, v11.8h, v1.h[3]\n" | |||
"fmla v26.8h, v11.8h, v2.h[3]\n" | |||
"fmla v27.8h, v11.8h, v3.h[3]\n" | |||
"ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%[a_ptr]], 64\n" | |||
"fmla v28.8h, v10.8h, v4.h[2]\n" | |||
"fmla v29.8h, v10.8h, v5.h[2]\n" | |||
"fmla v30.8h, v10.8h, v6.h[2]\n" | |||
"fmla v31.8h, v10.8h, v7.h[2]\n" | |||
"fmla v28.8h, v8.8h, v4.h[0]\n" | |||
"fmla v29.8h, v8.8h, v5.h[0]\n" | |||
"fmla v30.8h, v8.8h, v6.h[0]\n" | |||
"fmla v31.8h, v8.8h, v7.h[0]\n" | |||
"fmla v28.8h, v9.8h, v4.h[1]\n" | |||
"fmla v29.8h, v9.8h, v5.h[1]\n" | |||
"fmla v30.8h, v9.8h, v6.h[1]\n" | |||
"fmla v31.8h, v9.8h, v7.h[1]\n" | |||
"fmla v28.8h, v11.8h, v4.h[3]\n" | |||
"fmla v29.8h, v11.8h, v5.h[3]\n" | |||
"fmla v30.8h, v11.8h, v6.h[3]\n" | |||
"fmla v31.8h, v11.8h, v7.h[3]\n" | |||
"fmla v24.8h, v12.8h, v0.h[4]\n" | |||
"fmla v25.8h, v12.8h, v1.h[4]\n" | |||
"fmla v26.8h, v12.8h, v2.h[4]\n" | |||
"fmla v27.8h, v12.8h, v3.h[4]\n" | |||
"fmla v24.8h, v13.8h, v0.h[5]\n" | |||
"fmla v25.8h, v13.8h, v1.h[5]\n" | |||
"fmla v26.8h, v13.8h, v2.h[5]\n" | |||
"fmla v27.8h, v13.8h, v3.h[5]\n" | |||
"subs %w[K], %w[K], #8\n" | |||
"bne 1b\n" | |||
"2:\n" | |||
"fmla v24.8h, v14.8h, v0.h[6]\n" | |||
"fmla v25.8h, v14.8h, v1.h[6]\n" | |||
"fmla v26.8h, v14.8h, v2.h[6]\n" | |||
"fmla v27.8h, v14.8h, v3.h[6]\n" | |||
"fmla v24.8h, v15.8h, v0.h[7]\n" | |||
"fmla v25.8h, v15.8h, v1.h[7]\n" | |||
"fmla v26.8h, v15.8h, v2.h[7]\n" | |||
"fmla v27.8h, v15.8h, v3.h[7]\n" | |||
"fmla v28.8h, v12.8h, v4.h[4]\n" | |||
"fmla v29.8h, v12.8h, v5.h[4]\n" | |||
"fmla v28.8h, v13.8h, v4.h[5]\n" | |||
"fmla v29.8h, v13.8h, v5.h[5]\n" | |||
"st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%[output]], 64\n" | |||
"fmla v28.8h, v14.8h, v4.h[6]\n" | |||
"fmla v29.8h, v14.8h, v5.h[6]\n" | |||
"fmla v28.8h, v15.8h, v4.h[7]\n" | |||
"fmla v29.8h, v15.8h, v5.h[7]\n" | |||
"fmla v30.8h, v12.8h, v6.h[4]\n" | |||
"fmla v31.8h, v12.8h, v7.h[4]\n" | |||
"fmla v30.8h, v13.8h, v6.h[5]\n" | |||
"fmla v31.8h, v13.8h, v7.h[5]\n" | |||
"fmla v30.8h, v14.8h, v6.h[6]\n" | |||
"fmla v31.8h, v14.8h, v7.h[6]\n" | |||
"fmla v30.8h, v15.8h, v6.h[7]\n" | |||
"fmla v31.8h, v15.8h, v7.h[7]\n" | |||
"st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%[output]], 64\n" | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
[output] "+r"(output), [LDB] "+r"(LDB) | |||
: | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
"v11", "v12", "v13", "v14", "v15", "v24", "v25", "v26", "v27", | |||
"v28", "v29", "v30", "v31", "cc", "memory"); | |||
} | |||
} // anonymous namespace | |||
MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(gemm_nopack_f16_8x8); | |||
void gemm_nopack_f16_8x8::kern(const dt_float16* A, size_t LDA, | |||
const dt_float16* B, size_t LDB, dt_float16* C, | |||
size_t LDC, size_t M, size_t K, size_t N, | |||
const dt_float16*, void*, bool trA, | |||
bool trB) const { | |||
constexpr static size_t MB = 8; | |||
constexpr static size_t KB = 8; | |||
constexpr static size_t NB = 8; | |||
constexpr static size_t CALCBLK = 4; | |||
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0 && N % CALCBLK == 0); | |||
//! (m/8, k/8, 8, 8) * (k/8, n, 8) = (m/8, n, 8) | |||
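    //! A is laid out in (M/8, K/8, 8, 8) blocks and B in (K/8, N, 8); LDA | |||
    //! and LDC are the element strides between consecutive 8-row blocks, | |||
    //! and LDB is the stride between consecutive k-blocks of B, passed | |||
    //! down to the kernels. | |||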
for (size_t m = 0; m < M; m += MB) { | |||
dt_float16* output = C + (m / MB) * LDC; | |||
const dt_float16* cur_B = B; | |||
size_t n = 0; | |||
for (; n + NB - 1 < N; n += NB) { | |||
kern_8x8(A, cur_B, LDB, K, output); | |||
cur_B += KB * NB; | |||
output += MB * NB; | |||
} | |||
if (n < N) { | |||
kern_8x4(A, cur_B, LDB, K, output); | |||
} | |||
A += LDA; | |||
} | |||
} | |||
#endif | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,39 @@ | |||
/** | |||
* \file dnn/src/aarch64/matrix_mul/fp32/common.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include <cstddef> | |||
#include "megdnn/arch.h" | |||
#include "src/common/utils.h" | |||
namespace megdnn { | |||
namespace aarch64 { | |||
MEGDNN_NOINLINE void sgemm_packA_n(const float* A, float* Apacked, size_t M, | |||
size_t K, size_t LDA, const float* alpha); | |||
MEGDNN_NOINLINE void sgemm_packA_t(const float* A, float* Apacked, size_t M, | |||
size_t K, size_t LDA, const float* alpha); | |||
MEGDNN_NOINLINE void sgemm_packB_n(const float* B, float* Bpacked, size_t K, | |||
size_t N, size_t LDB); | |||
MEGDNN_NOINLINE void sgemm_packB_t(const float* B, float* Bpacked, size_t K, | |||
size_t N, size_t LDB); | |||
MEGDNN_NOINLINE void sgemm_kernel12x8(const float* A, const float* B, float* C, | |||
size_t LDC, size_t M, size_t N, size_t K, | |||
int type, const float* beta); | |||
} // namespace aarch64 | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,718 @@ | |||
/** | |||
* \file dnn/src/aarch64/matrix_mul/fp32/kernel_general_4x16.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/aarch64/matrix_mul/asm/common.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
namespace megdnn { | |||
namespace aarch64 { | |||
namespace matmul_general_4x16 { | |||
// Overview of register layout: | |||
// | |||
// A 1x16 cell of Rhs is stored in 32bit in v2-v5 | |||
// A 4x1 cell of Lhs is stored in 32bit in v0 | |||
// A 4x16 block of accumulators is stored in 32bit in v10-v25. | |||
// | |||
// +--------+--------+--------+--------+ | |||
// | v2[0-3]| v3[0-3]| v4[0-3]| v5[0-3]| | |||
// Rhs +--------+--------+--------+--------+ | |||
// | |||
// | | | | | | |||
// | |||
// Lhs | | | | | | |||
// | |||
// +--+ - - - - +--------+--------+--------+--------+ | |||
// |v0| |v10[0-3]|v11[0-3]|v12[0-3]|v13[0-3]| | |||
// |v0| |v14[0-3]|v15[0-3]|v16[0-3]|v17[0-3]| | |||
// |v0| |v18[0-3]|v19[0-3]|v20[0-3]|v21[0-3]| | |||
// |v0| |v22[0-3]|v23[0-3]|v24[0-3]|v25[0-3]| | |||
// +--+ - - - - +--------+--------+--------+--------+ | |||
// | |||
// Accumulator | |||
void kern_4x16(const float* packA, const float* packB, int K, | |||
float* output, int LDC, bool is_first_k, int m_remain) { | |||
const float* a_ptr = packA; | |||
const float* b_ptr = packB; | |||
int oddk = (K & 1); | |||
K = ((K + 1) / 2) - 1; | |||
LDC = LDC * sizeof(float); | |||
    register float* outptr asm("x0") = output; | |||
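    //! outptr is pinned to x0 so the asm block can derive the remaining | |||
    //! row pointers x1..x3 from it and the LOAD_/STORE_ macros can address | |||
    //! the four output rows as x0..x3 | |||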
// clang-format off | |||
#define LOAD_LINE(v0, v1, v2, v3, n) \ | |||
"cmp x10, #0\n" \ | |||
"beq 100f\n" \ | |||
"mov x9, x" n "\n" \ | |||
"ld1 {v" v0 ".4s, v" v1 ".4s, v" v2 ".4s, v" v3 ".4s}, [x9], 64\n" \ | |||
"subs x10, x10, #1\n" | |||
#define LOAD_C \ | |||
"mov x10, %x[m_remain]\n" \ | |||
LOAD_LINE("10", "11", "12", "13", "0") \ | |||
LOAD_LINE("14", "15", "16", "17", "1") \ | |||
LOAD_LINE("18", "19", "20", "21", "2") \ | |||
LOAD_LINE("22", "23", "24", "25", "3") \ | |||
"100:\n" | |||
#define STORE_LINE(v0, v1, v2, v3, n) \ | |||
"cmp x10, #0\n" \ | |||
"beq 101f\n" \ | |||
"mov x9, x" n "\n" \ | |||
"st1 {v" v0 ".4s, v" v1 ".4s, v" v2 ".4s, v" v3 ".4s}, [x9], 64\n" \ | |||
"subs x10, x10, #1\n" | |||
#define STORE_C \ | |||
"mov x10, %x[m_remain]\n" \ | |||
STORE_LINE("10", "11", "12", "13", "0") \ | |||
STORE_LINE("14", "15", "16", "17", "1") \ | |||
STORE_LINE("18", "19", "20", "21", "2") \ | |||
STORE_LINE("22", "23", "24", "25", "3") \ | |||
"101:\n" | |||
// clang-format on | |||
asm volatile( | |||
// load accumulator C | |||
"add x1, x0, %x[LDC]\n" | |||
"add x2, x1, %x[LDC]\n" | |||
"add x3, x2, %x[LDC]\n" | |||
"cmp %w[is_first_k], #1\n" | |||
"beq 1f\n" LOAD_C | |||
"b 2f\n" | |||
"1:\n" | |||
"eor v10.16b, v10.16b, v10.16b\n" | |||
"eor v11.16b, v11.16b, v11.16b\n" | |||
"eor v12.16b, v12.16b, v12.16b\n" | |||
"eor v13.16b, v13.16b, v13.16b\n" | |||
"eor v14.16b, v14.16b, v14.16b\n" | |||
"eor v15.16b, v15.16b, v15.16b\n" | |||
"eor v16.16b, v16.16b, v16.16b\n" | |||
"eor v17.16b, v17.16b, v17.16b\n" | |||
"eor v18.16b, v18.16b, v18.16b\n" | |||
"eor v19.16b, v19.16b, v19.16b\n" | |||
"eor v20.16b, v20.16b, v20.16b\n" | |||
"eor v21.16b, v21.16b, v21.16b\n" | |||
"eor v22.16b, v22.16b, v22.16b\n" | |||
"eor v23.16b, v23.16b, v23.16b\n" | |||
"eor v24.16b, v24.16b, v24.16b\n" | |||
"eor v25.16b, v25.16b, v25.16b\n" | |||
"2: \n" | |||
"ld1 {v2.4s, v3.4s, v4.4s, v5.4s}, [%[b_ptr]], 64\n" | |||
"cmp %w[K], #0\n" | |||
"beq 4f\n" | |||
"3:\n" | |||
"ld1 {v0.4s}, [%[a_ptr]], 16\n" | |||
"fmla v10.4s, v2.4s, v0.s[0]\n" | |||
"fmla v11.4s, v3.4s, v0.s[0]\n" | |||
"fmla v12.4s, v4.4s, v0.s[0]\n" | |||
"fmla v13.4s, v5.4s, v0.s[0]\n" | |||
"ld1 {v6.4s, v7.4s, v8.4s, v9.4s}, [%[b_ptr]], 64\n" | |||
"fmla v14.4s, v2.4s, v0.s[1]\n" | |||
"fmla v15.4s, v3.4s, v0.s[1]\n" | |||
"fmla v16.4s, v4.4s, v0.s[1]\n" | |||
"fmla v17.4s, v5.4s, v0.s[1]\n" | |||
"ld1 {v1.4s}, [%[a_ptr]], 16\n" | |||
"fmla v18.4s, v2.4s, v0.s[2]\n" | |||
"fmla v19.4s, v3.4s, v0.s[2]\n" | |||
"fmla v20.4s, v4.4s, v0.s[2]\n" | |||
"fmla v21.4s, v5.4s, v0.s[2]\n" | |||
"fmla v22.4s, v2.4s, v0.s[3]\n" | |||
"fmla v23.4s, v3.4s, v0.s[3]\n" | |||
"fmla v24.4s, v4.4s, v0.s[3]\n" | |||
"fmla v25.4s, v5.4s, v0.s[3]\n" | |||
"ld1 {v2.4s, v3.4s, v4.4s, v5.4s}, [%[b_ptr]], 64\n" | |||
"fmla v10.4s, v6.4s, v1.s[0]\n" | |||
"fmla v11.4s, v7.4s, v1.s[0]\n" | |||
"fmla v12.4s, v8.4s, v1.s[0]\n" | |||
"fmla v13.4s, v9.4s, v1.s[0]\n" | |||
"fmla v14.4s, v6.4s, v1.s[1]\n" | |||
"fmla v15.4s, v7.4s, v1.s[1]\n" | |||
"fmla v16.4s, v8.4s, v1.s[1]\n" | |||
"fmla v17.4s, v9.4s, v1.s[1]\n" | |||
"fmla v18.4s, v6.4s, v1.s[2]\n" | |||
"fmla v19.4s, v7.4s, v1.s[2]\n" | |||
"fmla v20.4s, v8.4s, v1.s[2]\n" | |||
"fmla v21.4s, v9.4s, v1.s[2]\n" | |||
"fmla v22.4s, v6.4s, v1.s[3]\n" | |||
"fmla v23.4s, v7.4s, v1.s[3]\n" | |||
"fmla v24.4s, v8.4s, v1.s[3]\n" | |||
"fmla v25.4s, v9.4s, v1.s[3]\n" | |||
"subs %w[K], %w[K], #1\n" | |||
"bne 3b\n" | |||
"4:\n" | |||
"cmp %w[oddk], #1\n" | |||
"beq 5f\n" | |||
// Even tail | |||
"ld1 {v0.4s}, [%[a_ptr]], 16\n" | |||
"fmla v10.4s, v2.4s, v0.s[0]\n" | |||
"fmla v11.4s, v3.4s, v0.s[0]\n" | |||
"fmla v12.4s, v4.4s, v0.s[0]\n" | |||
"fmla v13.4s, v5.4s, v0.s[0]\n" | |||
"ld1 {v6.4s, v7.4s, v8.4s, v9.4s}, [%[b_ptr]], 64\n" | |||
"fmla v14.4s, v2.4s, v0.s[1]\n" | |||
"fmla v15.4s, v3.4s, v0.s[1]\n" | |||
"fmla v16.4s, v4.4s, v0.s[1]\n" | |||
"fmla v17.4s, v5.4s, v0.s[1]\n" | |||
"ld1 {v1.4s}, [%[a_ptr]], 16\n" | |||
"fmla v18.4s, v2.4s, v0.s[2]\n" | |||
"fmla v19.4s, v3.4s, v0.s[2]\n" | |||
"fmla v20.4s, v4.4s, v0.s[2]\n" | |||
"fmla v21.4s, v5.4s, v0.s[2]\n" | |||
"fmla v22.4s, v2.4s, v0.s[3]\n" | |||
"fmla v23.4s, v3.4s, v0.s[3]\n" | |||
"fmla v24.4s, v4.4s, v0.s[3]\n" | |||
"fmla v25.4s, v5.4s, v0.s[3]\n" | |||
"fmla v10.4s, v6.4s, v1.s[0]\n" | |||
"fmla v11.4s, v7.4s, v1.s[0]\n" | |||
"fmla v12.4s, v8.4s, v1.s[0]\n" | |||
"fmla v13.4s, v9.4s, v1.s[0]\n" | |||
"fmla v14.4s, v6.4s, v1.s[1]\n" | |||
"fmla v15.4s, v7.4s, v1.s[1]\n" | |||
"fmla v16.4s, v8.4s, v1.s[1]\n" | |||
"fmla v17.4s, v9.4s, v1.s[1]\n" | |||
"fmla v18.4s, v6.4s, v1.s[2]\n" | |||
"fmla v19.4s, v7.4s, v1.s[2]\n" | |||
"fmla v20.4s, v8.4s, v1.s[2]\n" | |||
"fmla v21.4s, v9.4s, v1.s[2]\n" | |||
"fmla v22.4s, v6.4s, v1.s[3]\n" | |||
"fmla v23.4s, v7.4s, v1.s[3]\n" | |||
"fmla v24.4s, v8.4s, v1.s[3]\n" | |||
"fmla v25.4s, v9.4s, v1.s[3]\n" | |||
"b 6f\n" | |||
// odd tail | |||
"5:\n" | |||
"ld1 {v0.4s}, [%[a_ptr]], 16\n" | |||
"fmla v10.4s, v2.4s, v0.s[0]\n" | |||
"fmla v11.4s, v3.4s, v0.s[0]\n" | |||
"fmla v12.4s, v4.4s, v0.s[0]\n" | |||
"fmla v13.4s, v5.4s, v0.s[0]\n" | |||
"fmla v14.4s, v2.4s, v0.s[1]\n" | |||
"fmla v15.4s, v3.4s, v0.s[1]\n" | |||
"fmla v16.4s, v4.4s, v0.s[1]\n" | |||
"fmla v17.4s, v5.4s, v0.s[1]\n" | |||
"fmla v18.4s, v2.4s, v0.s[2]\n" | |||
"fmla v19.4s, v3.4s, v0.s[2]\n" | |||
"fmla v20.4s, v4.4s, v0.s[2]\n" | |||
"fmla v21.4s, v5.4s, v0.s[2]\n" | |||
"fmla v22.4s, v2.4s, v0.s[3]\n" | |||
"fmla v23.4s, v3.4s, v0.s[3]\n" | |||
"fmla v24.4s, v4.4s, v0.s[3]\n" | |||
"fmla v25.4s, v5.4s, v0.s[3]\n" | |||
"6:\n" STORE_C | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||
[m_remain] "+r"(m_remain), [outptr] "+r"(outptr) | |||
: | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||
"v20", "v21", "v22", "v23", "v24", "v25", "x1", "x2", "x3", "x9", | |||
"x10", "cc", "memory"); | |||
#undef LOAD_LINE | |||
#undef LOAD_C | |||
#undef STORE_LINE | |||
#undef STORE_C | |||
} | |||
// Overview of register layout: | |||
// | |||
// A 2x4 cell of Rhs is stored in 32bit in v2 - v3 | |||
// A 4x2 cell of Lhs is stored in 32bit in v0 - v1 | |||
// A 4x4 block of accumulators is stored in 32bit in v4-v7 | |||
// | |||
// +--------+ | |||
// | v2[0-3]| | |||
// Rhs +--------+ | |||
// | v3[0-3]| | |||
// +--------+ | |||
// | |||
// | | | |||
// | |||
// Lhs | | | |||
// | |||
// +--+--+ - - - - +--------+ | |||
// |v0|v1| | v4[0-3]| | |||
// |v0|v1| | v5[0-3]| | |||
// |v0|v1| | v6[0-3]| | |||
// |v0|v1| | v7[0-3]| | |||
// +--+--+ - - - - +--------+ | |||
// | |||
// Accumulator | |||
void kern_4x4(const float* packA, const float* packB, int K, float* output, | |||
int LDC, bool is_first_k, int m_remain, int n_remain) { | |||
const float* a_ptr = packA; | |||
const float* b_ptr = packB; | |||
int oddk = (K & 1); | |||
K = ((K + 1) / 2) - 1; | |||
LDC = LDC * sizeof(float); | |||
register float* outptr asm("x0") = output; | |||
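    //! unlike kern_4x16, this kernel also guards the n dimension: the | |||
    //! LOAD_LINE/STORE_LINE macros below transfer 1 to 4 lanes per row | |||
    //! depending on n_remain | |||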
// clang-format off | |||
#define LOAD_LINE(v0, n) \ | |||
"cmp x10, #0\n" \ | |||
"beq 102f\n" \ | |||
"cmp %w[n_remain], #4\n" \ | |||
"blt 100" n "f\n" \ | |||
"ld1 {v" v0 ".4s}, [x" n "], 16\n" \ | |||
"b 101" n "f\n" \ | |||
"100" n ":\n" \ | |||
"cmp %w[n_remain], #0\n" \ | |||
"beq 101" n "f\n" \ | |||
"ld1 {v" v0 ".s}[0], [x" n "], 4\n" \ | |||
"cmp %w[n_remain], #1\n" \ | |||
"beq 101" n "f\n" \ | |||
"ld1 {v" v0 ".s}[1], [x" n "], 4\n" \ | |||
"cmp %w[n_remain], #2\n" \ | |||
"beq 101" n "f\n" \ | |||
"ld1 {v" v0 ".s}[2], [x" n "], 4\n" \ | |||
"101" n ":\n" \ | |||
"subs x10, x10, #1\n" | |||
#define LOAD_C \ | |||
"mov x10, %x[m_remain]\n" \ | |||
LOAD_LINE("4", "0") \ | |||
LOAD_LINE("5", "1") \ | |||
LOAD_LINE("6", "2") \ | |||
LOAD_LINE("7", "3") \ | |||
"102:\n" | |||
#define STORE_LINE(v0, n) \ | |||
"cmp x10, #0 \n" \ | |||
"beq 105f\n" \ | |||
"cmp %w[n_remain], #4\n" \ | |||
"blt 103" n "f\n" \ | |||
"st1 {v" v0 ".4s}, [x" n " ], 16\n" \ | |||
"b 104" n "f\n" \ | |||
"103" n ":\n" \ | |||
"cmp %w[n_remain], #0\n" \ | |||
"beq 104" n "f\n" \ | |||
"st1 {v" v0 ".s}[0], [x" n "], 4\n" \ | |||
"cmp %w[n_remain], #1\n" \ | |||
"beq 104" n "f\n" \ | |||
"st1 {v" v0 ".s}[1], [x" n "], 4\n" \ | |||
"cmp %w[n_remain], #2\n" \ | |||
"beq 104" n "f\n" \ | |||
"st1 {v" v0 ".s}[2], [x" n "], 4\n" \ | |||
"104" n ":\n" \ | |||
"subs x10, x10, #1\n" | |||
#define STORE_C \ | |||
"mov x10, %x[m_remain]\n" \ | |||
STORE_LINE("4", "0") \ | |||
STORE_LINE("5", "1") \ | |||
STORE_LINE("6", "2") \ | |||
STORE_LINE("7", "3") \ | |||
"105:\n" | |||
// clang-format on | |||
asm volatile( | |||
// load accumulator C | |||
"add x1, x0, %x[LDC]\n" | |||
"add x2, x1, %x[LDC]\n" | |||
"add x3, x2, %x[LDC]\n" | |||
"cmp %w[is_first_k], #1\n" | |||
"beq 1f\n" LOAD_C | |||
"b 2f\n" | |||
"1:\n" | |||
"eor v4.16b, v4.16b, v4.16b\n" | |||
"eor v5.16b, v5.16b, v5.16b\n" | |||
"eor v6.16b, v6.16b, v6.16b\n" | |||
"eor v7.16b, v7.16b, v7.16b\n" | |||
"2: \n" | |||
"ld1 {v0.4s}, [%[a_ptr]], 16\n" | |||
"ld1 {v2.4s}, [%[b_ptr]], 16\n" | |||
"cmp %w[K], #0\n" | |||
"beq 4f\n" | |||
"3:\n" | |||
"ld1 {v1.4s}, [%[a_ptr]], 16\n" | |||
"ld1 {v3.4s}, [%[b_ptr]], 16\n" | |||
"fmla v4.4s, v2.4s, v0.s[0]\n" | |||
"fmla v5.4s, v2.4s, v0.s[1]\n" | |||
"fmla v6.4s, v2.4s, v0.s[2]\n" | |||
"fmla v7.4s, v2.4s, v0.s[3]\n" | |||
"ld1 {v0.4s}, [%[a_ptr]], 16\n" | |||
"ld1 {v2.4s}, [%[b_ptr]], 16\n" | |||
"fmla v4.4s, v3.4s, v1.s[0]\n" | |||
"fmla v5.4s, v3.4s, v1.s[1]\n" | |||
"fmla v6.4s, v3.4s, v1.s[2]\n" | |||
"fmla v7.4s, v3.4s, v1.s[3]\n" | |||
"subs %w[K], %w[K], #1\n" | |||
"bne 3b\n" | |||
"4:\n" | |||
"cmp %w[oddk], #1\n" | |||
"beq 5f\n" | |||
// Even tail | |||
"ld1 {v1.4s}, [%[a_ptr]], 16\n" | |||
"ld1 {v3.4s}, [%[b_ptr]], 16\n" | |||
"fmla v4.4s, v2.4s, v0.s[0]\n" | |||
"fmla v5.4s, v2.4s, v0.s[1]\n" | |||
"fmla v6.4s, v2.4s, v0.s[2]\n" | |||
"fmla v7.4s, v2.4s, v0.s[3]\n" | |||
"fmla v4.4s, v3.4s, v1.s[0]\n" | |||
"fmla v5.4s, v3.4s, v1.s[1]\n" | |||
"fmla v6.4s, v3.4s, v1.s[2]\n" | |||
"fmla v7.4s, v3.4s, v1.s[3]\n" | |||
"b 6f\n" | |||
// odd tail | |||
"5:\n" | |||
"fmla v4.4s, v2.4s, v0.s[0]\n" | |||
"fmla v5.4s, v2.4s, v0.s[1]\n" | |||
"fmla v6.4s, v2.4s, v0.s[2]\n" | |||
"fmla v7.4s, v2.4s, v0.s[3]\n" | |||
"6:\n" STORE_C | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||
[oddk] "+r"(oddk), [m_remain] "+r"(m_remain), | |||
[n_remain] "+r"(n_remain), [outptr] "+r"(outptr) | |||
: | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "x1", | |||
"x2", "x3", "x10", "cc", "memory"); | |||
#undef LOAD_LINE | |||
#undef LOAD_C | |||
#undef STORE_LINE | |||
#undef STORE_C | |||
} | |||
void sgemm_4x16_pack_A_n(float* outptr, const float* inptr, int ldin, int y0, | |||
                         int ymax, int k0, int kmax) { | |||
float zerobuff[4]; | |||
std::memset(zerobuff, 0, sizeof(float) * 4); | |||
constexpr int PACK_SIZE = 4*4; | |||
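    //! A is packed in 4-row panels; each 4x4 tile is transposed so the | |||
    //! kernel can fetch four consecutive Lhs values per k step | |||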
int y = y0; | |||
for (; y + 3 < ymax; y += 4) { | |||
// printf("main loop pack_a_n %p \n",outptr); | |||
const float* inptr0 = inptr + y * ldin + k0; | |||
const float* inptr1 = inptr0 + ldin; | |||
const float* inptr2 = inptr1 + ldin; | |||
const float* inptr3 = inptr2 + ldin; | |||
prefetch_2x(inptr0); | |||
prefetch_2x(inptr1); | |||
prefetch_2x(inptr2); | |||
prefetch_2x(inptr3); | |||
int K = (kmax - k0); | |||
for (; K > 3; K -= 4) { | |||
transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr); | |||
outptr += PACK_SIZE; | |||
} | |||
interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 1, K); | |||
} | |||
for (; y < ymax; y += 4) { | |||
const float* inptr0 = inptr + y * ldin + k0; | |||
const float* inptr1 = inptr0 + ldin; | |||
const float* inptr2 = inptr1 + ldin; | |||
const float* inptr3 = inptr2 + ldin; | |||
prefetch_2x(inptr0); | |||
prefetch_2x(inptr1); | |||
prefetch_2x(inptr2); | |||
prefetch_2x(inptr3); | |||
int K = (kmax - k0); | |||
for (; K > 3; K -= 4) { | |||
if ((y + 3) >= ymax) { | |||
switch ((y + 3) - ymax) { | |||
/* Everything falls through in here */ | |||
case 2: | |||
inptr1 = zerobuff; | |||
case 1: | |||
inptr2 = zerobuff; | |||
case 0: | |||
inptr3 = zerobuff; | |||
break; | |||
default: | |||
megdnn_assert(0); | |||
} | |||
} | |||
transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr); | |||
outptr += PACK_SIZE; | |||
} | |||
if (K > 0) { | |||
if (y + 3 >= ymax) { | |||
switch (y + 3 - ymax) { | |||
case 2: | |||
inptr1 = zerobuff; | |||
case 1: | |||
inptr2 = zerobuff; | |||
case 0: | |||
inptr3 = zerobuff; | |||
break; | |||
default: | |||
megdnn_assert(0); | |||
} | |||
} | |||
interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 1, K); | |||
} | |||
} | |||
} | |||
void sgemm_4x16_pack_A_t(float* out, const float* in, int ldin, int x0, | |||
int xmax, int k0, int kmax) { | |||
int ksize = kmax - k0; | |||
int ksize4 = (ksize << 2); | |||
float* outptr_base = out; | |||
int k = k0; | |||
for (; k + 3 < kmax; k += 4) { | |||
const float* inptr = in + k * ldin + x0; | |||
const float* inptr1 = inptr + ldin; | |||
const float* inptr2 = inptr1 + ldin; | |||
const float* inptr3 = inptr2 + ldin; | |||
prefetch_3x(inptr); | |||
prefetch_3x(inptr1); | |||
prefetch_3x(inptr2); | |||
prefetch_3x(inptr3); | |||
int x = x0; | |||
auto outptr = outptr_base; | |||
for (; x + 4 <= xmax; x += 4) { | |||
auto outptr_interleave = outptr; | |||
interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, | |||
outptr_interleave); | |||
outptr += ksize4; | |||
} | |||
if (x < xmax) { | |||
interleave_4(inptr, inptr1, inptr2, inptr3, outptr, 4, xmax - x); | |||
} | |||
outptr_base += 4 * 4; | |||
} | |||
for (; k < kmax; k++) { | |||
const float* inptr = in + k * ldin + x0; | |||
prefetch_3x(inptr); | |||
int x = x0; | |||
auto outptr = outptr_base; | |||
for (; x + 4 <= xmax; x += 4) { | |||
auto outptr_interleave = outptr; | |||
interleave_1x4_1_s(inptr, outptr_interleave); | |||
outptr += ksize4; | |||
} | |||
if (x < xmax) { | |||
interleave_1(inptr, outptr, 4, xmax - x); | |||
} | |||
outptr_base += 4; | |||
} | |||
} | |||
void sgemm_4x16_pack_B_n(float* out, const float* in, int ldin, | |||
int x0, int xmax, int k0, int kmax) { | |||
int ksize = kmax - k0; | |||
int ksize16 = ksize * 16; | |||
int ksize4 = (ksize << 2); | |||
float* outptr_base = out; | |||
float* outptr_base4 = outptr_base + (xmax - x0) / 16 * ksize16; | |||
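    //! B is packed as full 16-wide panels followed by 4-wide panels for | |||
    //! the ragged edge; outptr_base4 marks where the 4-wide region starts, | |||
    //! after the (xmax - x0) / 16 full panels | |||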
int k = k0; | |||
for (; k + 3 < kmax; k += 4) { | |||
const float* inptr = in + k * ldin + x0; | |||
const float* inptr1 = inptr + ldin; | |||
const float* inptr2 = inptr1 + ldin; | |||
const float* inptr3 = inptr2 + ldin; | |||
prefetch_3x(inptr); | |||
prefetch_3x(inptr1); | |||
prefetch_3x(inptr2); | |||
prefetch_3x(inptr3); | |||
int x = x0; | |||
auto outptr = outptr_base; | |||
for (; x + 16 <= xmax; x += 16) { | |||
auto outptr_interleave = outptr; | |||
interleave_4x16_1_s(inptr, inptr1, inptr2, inptr3, | |||
outptr_interleave); | |||
outptr += ksize16; | |||
} | |||
outptr = outptr_base4; | |||
for (; x + 4 <= xmax; x += 4) { | |||
auto outptr_interleave = outptr; | |||
interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, | |||
outptr_interleave); | |||
outptr += ksize4; | |||
} | |||
if (x < xmax) { | |||
interleave_4(inptr, inptr1, inptr2, inptr3, outptr, 4, xmax - x); | |||
} | |||
outptr_base += 16 * 4; | |||
outptr_base4 += 4 * 4; | |||
} | |||
for (; k < kmax; k++) { | |||
const float* inptr = in + k * ldin + x0; | |||
prefetch_3x(inptr); | |||
int x = x0; | |||
auto outptr = outptr_base; | |||
for (; x + 16 <= xmax; x += 16) { | |||
auto outptr_interleave = outptr; | |||
interleave_1x16_1_s(inptr, outptr_interleave); | |||
outptr += ksize16; | |||
} | |||
outptr = outptr_base4; | |||
for (; x + 4 <= xmax; x += 4) { | |||
auto outptr_interleave = outptr; | |||
interleave_1x4_1_s(inptr, outptr_interleave); | |||
outptr += ksize4; | |||
} | |||
if (x < xmax) { | |||
interleave_1(inptr, outptr, 4, xmax - x); | |||
} | |||
outptr_base += 16; | |||
outptr_base4 += 4; | |||
} | |||
} | |||
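//! Editorial sketch (assumption: not part of the original source; the helper | |||
//! name is hypothetical).  The packed B produced above is laid out in two | |||
//! regions: all full 16-column blocks first, then the 4-column blocks | |||
//! (including the zero-padded remainder).  The offset of source element | |||
//! (k, x) in the packed buffer works out to: | |||
static inline size_t sgemm_4x16_packed_B_n_offset(int x0, int xmax, int k0, | |||
                                                  int kmax, int k, int x) { | |||
    const size_t ksize = kmax - k0; | |||
    const size_t nblk16 = (xmax - x0) / 16;  //!< full 16-wide blocks | |||
    const size_t q = (k - k0) / 4;           //!< k-quad index | |||
    const size_t kr = (k - k0) % 4;          //!< position inside the quad | |||
    const size_t xr = x - x0; | |||
    if (xr < nblk16 * 16) {  //!< main region: 4x16 interleaved tiles | |||
        return (xr / 16) * ksize * 16 + q * 64 + kr * 16 + xr % 16; | |||
    } | |||
    const size_t xt = xr - nblk16 * 16;  //!< tail region: 4x4 tiles | |||
    return nblk16 * 16 * ksize + (xt / 4) * ksize * 4 + q * 16 + kr * 4 + | |||
           xt % 4; | |||
} | |||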
void sgemm_4x16_pack_B_t(float* out, const float* in, int ldin, | |||
int y0, int ymax, int k0, int kmax) { | |||
float* outptr = out; | |||
const float* inptr = in; | |||
float zerobuff[4]; | |||
std::memset(zerobuff, 0, sizeof(float) * 4); | |||
int K16 = 16 * (kmax - k0); | |||
int y = y0; | |||
for (; y + 16 <= ymax; y += 16) { | |||
int yi = y; | |||
for (; yi < y + 16; yi += 4) { | |||
const float* inptr0 = inptr + yi * ldin + k0; | |||
const float* inptr1 = inptr0 + ldin; | |||
const float* inptr2 = inptr1 + ldin; | |||
const float* inptr3 = inptr2 + ldin; | |||
float* outptr_inner = outptr + yi - y; | |||
prefetch_2x(inptr0); | |||
prefetch_2x(inptr1); | |||
prefetch_2x(inptr2); | |||
prefetch_2x(inptr3); | |||
int x = (kmax - k0); | |||
for (; x > 3; x -= 4) { | |||
transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr_inner, | |||
64); | |||
outptr_inner += 64; | |||
} | |||
for (; x > 0; x--) { | |||
*outptr_inner++ = *inptr0++; | |||
*outptr_inner++ = *inptr1++; | |||
*outptr_inner++ = *inptr2++; | |||
*outptr_inner++ = *inptr3++; | |||
outptr_inner += 12; | |||
} | |||
} | |||
outptr += K16; | |||
} | |||
for (; y < ymax; y += 4) { | |||
const float* inptr0 = inptr + y * ldin + k0; | |||
const float* inptr1 = inptr0 + ldin; | |||
const float* inptr2 = inptr1 + ldin; | |||
const float* inptr3 = inptr2 + ldin; | |||
prefetch_2x(inptr0); | |||
prefetch_2x(inptr1); | |||
prefetch_2x(inptr2); | |||
prefetch_2x(inptr3); | |||
        /* Cope with ragged cases by copying from a buffer of zeroes instead. */ | |||
int x = (kmax - k0); | |||
for (; x > 3; x -= 4) { | |||
if ((y + 3) >= ymax) { | |||
switch ((y + 3) - ymax) { | |||
/* Everything falls through in here */ | |||
case 2: | |||
inptr1 = zerobuff; | |||
case 1: | |||
inptr2 = zerobuff; | |||
case 0: | |||
inptr3 = zerobuff; | |||
break; | |||
default: | |||
megdnn_assert(0); | |||
} | |||
} | |||
transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr); | |||
outptr += 16; | |||
} | |||
if (x > 0) { | |||
if ((y + 3) >= ymax) { | |||
switch ((y + 3) - ymax) { | |||
/* Everything falls through in here */ | |||
case 2: | |||
inptr1 = zerobuff; | |||
case 1: | |||
inptr2 = zerobuff; | |||
case 0: | |||
inptr3 = zerobuff; | |||
break; | |||
default: | |||
megdnn_assert(0); | |||
} | |||
} | |||
interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 1, x); | |||
} | |||
} | |||
} | |||
} // matmul_general_4x16 | |||
} // aarch64 | |||
} // megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,166 @@ | |||
/** | |||
* \file dnn/src/aarch64/matrix_mul/fp32/strategy.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/aarch64/matrix_mul/fp32/strategy.h" | |||
#include "src/aarch64/matrix_mul/fp32/kernel_general_4x16.h" | |||
#include "src/aarch64/matrix_mul/fp32/kernel_general_8x12.h" | |||
#include "src/common/utils.h" | |||
using namespace megdnn; | |||
using namespace aarch64; | |||
using namespace aarch64::matmul; | |||
MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_4x16); | |||
void sgemm_4x16::pack_A(float* out, const float* in, int ldin, int y0, | |||
int ymax, int k0, int kmax, bool transpose_A) const { | |||
if (transpose_A) { | |||
matmul_general_4x16::sgemm_4x16_pack_A_t(out, in, ldin, y0, ymax, k0, kmax); | |||
} else { | |||
matmul_general_4x16::sgemm_4x16_pack_A_n(out, in, ldin, y0, ymax, k0, kmax); | |||
} | |||
} | |||
void sgemm_4x16::pack_B(float* out, const float* in, int ldin, int x0, int xmax, | |||
int k0, int kmax, bool transpose_B) const { | |||
if (transpose_B) { | |||
matmul_general_4x16::sgemm_4x16_pack_B_t(out, in, ldin, x0, xmax, k0, kmax); | |||
} else { | |||
matmul_general_4x16::sgemm_4x16_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); | |||
} | |||
} | |||
void sgemm_4x16::kern(const float* packA, const float* packB, | |||
size_t M, size_t N, size_t K, float* C, size_t LDC, | |||
bool is_first_k, const float*, float*) const { | |||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||
A_dtype.enumv() == C_dtype.enumv() && | |||
A_dtype.enumv() == DTypeEnum::Float32); | |||
MEGDNN_MARK_USED_VAR(A_dtype); | |||
MEGDNN_MARK_USED_VAR(B_dtype); | |||
MEGDNN_MARK_USED_VAR(C_dtype); | |||
constexpr size_t A_INTERLEAVE = 4; | |||
constexpr size_t B_INTERLEAVE = 16; | |||
const int K16 = K * 16; | |||
const int K4 = K * 4; | |||
size_t m = 0; | |||
for (; m < M; m += A_INTERLEAVE) { | |||
float* output = C + (m * LDC); | |||
size_t n = 0; | |||
const float* cur_packB = packB; | |||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
matmul_general_4x16::kern_4x16(packA, cur_packB, K, output, LDC, is_first_k, | |||
std::min<size_t>(M - m, 4)); | |||
output += B_INTERLEAVE; | |||
cur_packB += K16; | |||
} | |||
for (; n < N; n += 4) { | |||
matmul_general_4x16::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k, | |||
std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4)); | |||
output += 4; | |||
cur_packB += K4; | |||
} | |||
packA += K4; | |||
} | |||
} | |||
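//! Editorial sketch (assumption: not part of the original source; the helper | |||
//! name is hypothetical).  What each micro-kernel invoked by the kern | |||
//! routines in this file computes on one packed tile, and the is_first_k | |||
//! contract: the first K-slice overwrites the C tile, later slices | |||
//! accumulate into it.  Assumes the k-major tile layout produced by the pack | |||
//! routines (mr values of A, then nr values of B, per k). | |||
static inline void ref_kern_tile(const float* A, const float* B, size_t K, | |||
                                 float* C, size_t LDC, bool is_first_k, | |||
                                 size_t mr, size_t nr) { | |||
    for (size_t i = 0; i < mr; ++i) { | |||
        for (size_t j = 0; j < nr; ++j) { | |||
            float acc = is_first_k ? 0.f : C[i * LDC + j]; | |||
            for (size_t k = 0; k < K; ++k) | |||
                acc += A[k * mr + i] * B[k * nr + j]; | |||
            C[i * LDC + j] = acc; | |||
        } | |||
    } | |||
} | |||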
MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_8x12); | |||
void sgemm_8x12::pack_A(float* out, const float* in, int ldin, int y0, | |||
int ymax, int k0, int kmax, bool transpose_A) const { | |||
if (transpose_A) { | |||
matmul_general_8x12::sgemm_8x12_pack_A_t(out, in, ldin, y0, ymax, k0, | |||
kmax); | |||
} else { | |||
matmul_general_8x12::sgemm_8x12_pack_A_n(out, in, ldin, y0, ymax, k0, | |||
kmax); | |||
} | |||
} | |||
void sgemm_8x12::pack_B(float* out, const float* in, int ldin, int x0, int xmax, | |||
int k0, int kmax, bool transpose_B) const { | |||
if (transpose_B) { | |||
matmul_general_8x12::sgemm_8x12_pack_B_t(out, in, ldin, x0, xmax, k0, | |||
kmax); | |||
} else { | |||
matmul_general_8x12::sgemm_8x12_pack_B_n(out, in, ldin, x0, xmax, k0, | |||
kmax); | |||
} | |||
} | |||
void sgemm_8x12::kern(const float* packA, const float* packB, | |||
size_t M, size_t N, size_t K, float* C, size_t LDC, | |||
bool is_first_k, const float*, float*) const { | |||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||
A_dtype.enumv() == C_dtype.enumv() && | |||
A_dtype.enumv() == DTypeEnum::Float32); | |||
MEGDNN_MARK_USED_VAR(A_dtype); | |||
MEGDNN_MARK_USED_VAR(B_dtype); | |||
MEGDNN_MARK_USED_VAR(C_dtype); | |||
constexpr size_t A_INTERLEAVE = 8; | |||
constexpr size_t A_INTERLEAVE4 = 4; | |||
constexpr size_t B_INTERLEAVE = 12; | |||
const int K12 = K * 12; | |||
const int K8 = K * 8; | |||
const int K4 = K * 4; | |||
size_t m = 0; | |||
for (; m + A_INTERLEAVE <= M; m += A_INTERLEAVE) { | |||
float* output = C + (m * LDC); | |||
size_t n = 0; | |||
const float* cur_packB = packB; | |||
for (; n + B_INTERLEAVE <= N; n += B_INTERLEAVE) { | |||
matmul_general_8x12::kern_8x12(packA, cur_packB, K, output, LDC, | |||
is_first_k); | |||
output += B_INTERLEAVE; | |||
cur_packB += K12; | |||
} | |||
for (; n < N; n += 4) { | |||
matmul_general_8x12::kern_8x4(packA, cur_packB, K, output, LDC, | |||
is_first_k, | |||
std::min<size_t>(N - n, 4)); | |||
output += 4; | |||
cur_packB += K4; | |||
} | |||
packA += K8; | |||
} | |||
for (; m < M; m += A_INTERLEAVE4) { | |||
float* output = C + (m * LDC); | |||
size_t n = 0; | |||
const float* cur_packB = packB; | |||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
matmul_general_8x12::kern_4x12(packA, cur_packB, K, output, LDC, | |||
is_first_k, | |||
std::min<size_t>(M - m, 4)); | |||
output += B_INTERLEAVE; | |||
cur_packB += K12; | |||
} | |||
for (; n < N; n += 4) { | |||
matmul_general_8x12::kern_4x4( | |||
packA, cur_packB, K, output, LDC, is_first_k, | |||
std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4)); | |||
output += 4; | |||
cur_packB += K4; | |||
} | |||
packA += K4; | |||
} | |||
} | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,30 @@ | |||
/** | |||
* \file dnn/src/aarch64/matrix_mul/fp32/strategy.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "src/fallback/matrix_mul/gemm_common.h" | |||
namespace megdnn { | |||
namespace aarch64 { | |||
namespace matmul { | |||
MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, true, | |||
sgemm_8x12); | |||
MEGDNN_REG_GEMM_STRATEGY(float, float, float, 4, 16, 1, false, true, | |||
sgemm_4x16); | |||
MEGDNN_REG_GEMM_STRATEGY_NOPACK(float, float, float, 4, 16, 1, false, true, | |||
sgemm_nopack_4x16); | |||
} // namespace matmul | |||
} // namespace aarch64 | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,570 @@ | |||
/** | |||
* \file dnn/src/aarch64/matrix_mul/fp32/strategy_mk4_4x16.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/aarch64/matrix_mul/asm/common.h" | |||
#include "src/aarch64/matrix_mul/fp32/strategy.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/common/utils.h" | |||
using namespace megdnn; | |||
using namespace aarch64; | |||
using namespace aarch64::matmul; | |||
namespace { | |||
// Overview of register layout: | |||
// | |||
// A 4x4 block of A is stored in register v4-v7 | |||
// A 4x4 block of B is stored in register v0-v3 | |||
// A 4x4 block of accumulators is stored in v16-v19 | |||
// | |||
// A +--------+ | |||
// | v4[0-3]| | |||
// | v5[0-3]| | |||
// | v6[0-3]| | |||
// | v7[0-3]| | |||
// +--------+ | |||
// B | |||
// +--------+ - - - - -+--------+ | |||
// | v0[0-3]| |v16[0-3]| | |||
// | v1[0-3]| |v17[0-3]| | |||
// | v2[0-3]| |v18[0-3]| | |||
// | v3[0-3]| |v19[0-3]| | |||
// +--------+ - - - - -+--------+ | |||
// Accumulator | |||
void kern_4x4(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, | |||
float* output) { | |||
    //! Each iteration loads 16 floats from B, but the first three loads have | |||
    //! already advanced the pointer by 12 floats, so subtract 12 from the | |||
    //! stride here. | |||
LDB = (LDB - 12) * sizeof(float); | |||
asm volatile( | |||
"subs %w[K], %w[K], #4\n" | |||
"ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[a_ptr]], 64\n" | |||
"ld1 {v0.4s}, [%[b_ptr]], 16\n" | |||
"ld1 {v1.4s}, [%[b_ptr]], 16\n" | |||
"ld1 {v2.4s}, [%[b_ptr]], 16\n" | |||
"ld1 {v3.4s}, [%[b_ptr]], %x[LDB]\n" | |||
"fmul v16.4s, v4.4s, v0.s[0]\n" | |||
"fmul v17.4s, v4.4s, v1.s[0]\n" | |||
"fmul v18.4s, v4.4s, v2.s[0]\n" | |||
"fmul v19.4s, v4.4s, v3.s[0]\n" | |||
"fmla v16.4s, v5.4s, v0.s[1]\n" | |||
"fmla v17.4s, v5.4s, v1.s[1]\n" | |||
"fmla v18.4s, v5.4s, v2.s[1]\n" | |||
"fmla v19.4s, v5.4s, v3.s[1]\n" | |||
"beq 2f\n" | |||
"1:\n" | |||
"ld1 {v4.4s, v5.4s}, [%[a_ptr]], 32\n" | |||
"fmla v16.4s, v6.4s, v0.s[2]\n" | |||
"fmla v17.4s, v6.4s, v1.s[2]\n" | |||
"fmla v18.4s, v6.4s, v2.s[2]\n" | |||
"fmla v19.4s, v6.4s, v3.s[2]\n" | |||
"fmla v16.4s, v7.4s, v0.s[3]\n" | |||
"fmla v17.4s, v7.4s, v1.s[3]\n" | |||
"ld1 {v0.4s}, [%[b_ptr]], 16\n" | |||
"fmla v18.4s, v7.4s, v2.s[3]\n" | |||
"ld1 {v1.4s}, [%[b_ptr]], 16\n" | |||
"fmla v19.4s, v7.4s, v3.s[3]\n" | |||
"ld1 {v2.4s}, [%[b_ptr]], 16\n" | |||
"fmla v16.4s, v4.4s, v0.s[0]\n" | |||
"ld1 {v3.4s}, [%[b_ptr]], %x[LDB]\n" | |||
"fmla v17.4s, v4.4s, v1.s[0]\n" | |||
"fmla v18.4s, v4.4s, v2.s[0]\n" | |||
"fmla v19.4s, v4.4s, v3.s[0]\n" | |||
"ld1 {v6.4s, v7.4s}, [%[a_ptr]], 32\n" | |||
"fmla v16.4s, v5.4s, v0.s[1]\n" | |||
"fmla v17.4s, v5.4s, v1.s[1]\n" | |||
"fmla v18.4s, v5.4s, v2.s[1]\n" | |||
"fmla v19.4s, v5.4s, v3.s[1]\n" | |||
"subs %w[K], %w[K], #4\n" | |||
"bne 1b\n" | |||
"2:\n" | |||
"fmla v16.4s, v6.4s, v0.s[2]\n" | |||
"fmla v17.4s, v6.4s, v1.s[2]\n" | |||
"fmla v18.4s, v6.4s, v2.s[2]\n" | |||
"fmla v19.4s, v6.4s, v3.s[2]\n" | |||
"fmla v16.4s, v7.4s, v0.s[3]\n" | |||
"fmla v17.4s, v7.4s, v1.s[3]\n" | |||
"fmla v18.4s, v7.4s, v2.s[3]\n" | |||
"fmla v19.4s, v7.4s, v3.s[3]\n" | |||
"st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[output]], 64\n" | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
[output] "+r"(output), [LDB] "+r"(LDB) | |||
: | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "cc", "memory"); | |||
} | |||
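//! Editorial sketch (assumption: not part of the original source; the helper | |||
//! name is hypothetical).  A scalar equivalent of kern_4x4 above, with LDB in | |||
//! elements as passed by the caller: per K-quad, a 4x4 block of A multiplies | |||
//! four 4-element columns of B. | |||
static inline void kern_4x4_ref(const float* a, const float* b, size_t LDB, | |||
                                size_t K, float* out) { | |||
    for (int i = 0; i < 16; ++i) | |||
        out[i] = 0.f; | |||
    for (size_t q = 0; q < K / 4; ++q)       //!< K is a multiple of 4 (mk4) | |||
        for (int n = 0; n < 4; ++n)          //!< four output columns | |||
            for (int kk = 0; kk < 4; ++kk)   //!< position inside the K-quad | |||
                for (int i = 0; i < 4; ++i)  //!< four rows of the mk4 block | |||
                    out[n * 4 + i] += a[(q * 4 + kk) * 4 + i] * | |||
                                      b[q * LDB + n * 4 + kk]; | |||
} | |||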
// Overview of register layout: | |||
// | |||
// A 4x4 block of A is stored in registers v4-v7 | |||
// An 8x4 block of B is streamed through registers v0-v3 and v24-v27 | |||
// An 8x4 block of accumulators is stored in v16-v23. | |||
// | |||
// A +--------+ | |||
// | v4[0-3]| | |||
// | v5[0-3]| | |||
// | v6[0-3]| | |||
// | v7[0-3]| | |||
// +--------+ | |||
// B | |||
// +--------+ - - - - -+--------+ | |||
// | v0[0-3]| |v16[0-3]| | |||
// | v1[0-3]| |v17[0-3]| | |||
// | v2[0-3]| |v18[0-3]| | |||
// | v3[0-3]| |v19[0-3]| | |||
// +--------+ - - - - -+--------+ | |||
//                  |v24[0-3]|          |v20[0-3]| | |||
//                  |v25[0-3]|          |v21[0-3]| | |||
//                  |v26[0-3]|          |v22[0-3]| | |||
//                  |v27[0-3]|          |v23[0-3]| | |||
// +--------+ - - - - -+--------+ | |||
// Accumulator | |||
void kern_4x8(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, | |||
float* output) { | |||
    //! Each iteration loads 32 floats from B, but the loads have already | |||
    //! advanced the pointer by 24 floats, so subtract 24 from the stride | |||
    //! here. | |||
LDB = (LDB - 24) * sizeof(float); | |||
asm volatile( | |||
"subs %w[K], %w[K], #4\n" | |||
"ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[a_ptr]], 64\n" | |||
"ld1 {v0.4s}, [%[b_ptr]], 16\n" | |||
"fmul v16.4s, v4.4s, v0.s[0]\n" | |||
"ld1 {v1.4s}, [%[b_ptr]], 16\n" | |||
"fmla v16.4s, v5.4s, v0.s[1]\n" | |||
"fmul v17.4s, v4.4s, v1.s[0]\n" | |||
"ld1 {v2.4s, v3.4s}, [%[b_ptr]], 32\n" | |||
"fmla v17.4s, v5.4s, v1.s[1]\n" | |||
"fmla v16.4s, v6.4s, v0.s[2]\n" | |||
"fmla v17.4s, v6.4s, v1.s[2]\n" | |||
"fmul v18.4s, v4.4s, v2.s[0]\n" | |||
"fmla v16.4s, v7.4s, v0.s[3]\n" | |||
"fmla v18.4s, v5.4s, v2.s[1]\n" | |||
"fmla v17.4s, v7.4s, v1.s[3]\n" | |||
"fmul v19.4s, v4.4s, v3.s[0]\n" | |||
"ld1 {v24.4s, v25.4s}, [%[b_ptr]], 32\n" | |||
"fmla v18.4s, v7.4s, v2.s[3]\n" | |||
"fmla v19.4s, v5.4s, v3.s[1]\n" | |||
"fmul v20.4s, v4.4s, v24.s[0]\n" | |||
"fmla v19.4s, v6.4s, v3.s[2]\n" | |||
"ld1 {v26.4s, v27.4s}, [%[b_ptr]], %x[LDB]\n" | |||
"fmla v18.4s, v6.4s, v2.s[2]\n" | |||
"fmla v19.4s, v7.4s, v3.s[3]\n" | |||
"fmul v21.4s, v4.4s, v25.s[0]\n" | |||
"fmla v20.4s, v5.4s, v24.s[1]\n" | |||
"fmla v21.4s, v5.4s, v25.s[1]\n" | |||
"ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b_ptr]], 64\n" | |||
"fmla v20.4s, v6.4s, v24.s[2]\n" | |||
"fmul v22.4s, v4.4s, v26.s[0]\n" | |||
"fmla v21.4s, v6.4s, v25.s[2]\n" | |||
"fmla v22.4s, v5.4s, v26.s[1]\n" | |||
"fmla v21.4s, v7.4s, v25.s[3]\n" | |||
"fmul v23.4s, v4.4s, v27.s[0]\n" | |||
"fmla v20.4s, v7.4s, v24.s[3]\n" | |||
"fmla v22.4s, v6.4s, v26.s[2]\n" | |||
"fmla v23.4s, v5.4s, v27.s[1]\n" | |||
"beq 2f\n" | |||
"1:\n" | |||
"ld1 {v4.4s, v5.4s}, [%[a_ptr]], 32\n" | |||
"fmla v22.4s, v7.4s, v26.s[3]\n" | |||
"fmla v23.4s, v6.4s, v27.s[2]\n" | |||
"fmla v16.4s, v4.4s, v0.s[0]\n" | |||
"fmla v17.4s, v4.4s, v1.s[0]\n" | |||
"fmla v23.4s, v7.4s, v27.s[3]\n" | |||
"ld1 {v6.4s, v7.4s}, [%[a_ptr]], 32\n" | |||
"fmla v16.4s, v5.4s, v0.s[1]\n" | |||
"fmla v17.4s, v5.4s, v1.s[1]\n" | |||
"fmla v16.4s, v6.4s, v0.s[2]\n" | |||
"fmla v17.4s, v6.4s, v1.s[2]\n" | |||
"fmla v18.4s, v4.4s, v2.s[0]\n" | |||
"fmla v16.4s, v7.4s, v0.s[3]\n" | |||
"fmla v18.4s, v5.4s, v2.s[1]\n" | |||
"fmla v17.4s, v7.4s, v1.s[3]\n" | |||
"fmla v19.4s, v4.4s, v3.s[0]\n" | |||
"ld1 {v24.4s, v25.4s}, [%[b_ptr]], 32\n" | |||
"fmla v18.4s, v6.4s, v2.s[2]\n" | |||
"fmla v19.4s, v5.4s, v3.s[1]\n" | |||
"fmla v20.4s, v4.4s, v24.s[0]\n" | |||
"fmla v19.4s, v6.4s, v3.s[2]\n" | |||
"ld1 {v26.4s, v27.4s}, [%[b_ptr]], %x[LDB]\n" | |||
"fmla v18.4s, v7.4s, v2.s[3]\n" | |||
"fmla v19.4s, v7.4s, v3.s[3]\n" | |||
"fmla v21.4s, v4.4s, v25.s[0]\n" | |||
"fmla v20.4s, v5.4s, v24.s[1]\n" | |||
"fmla v21.4s, v5.4s, v25.s[1]\n" | |||
"ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b_ptr]], 64\n" | |||
"fmla v20.4s, v6.4s, v24.s[2]\n" | |||
"fmla v22.4s, v4.4s, v26.s[0]\n" | |||
"fmla v20.4s, v7.4s, v24.s[3]\n" | |||
"fmla v23.4s, v4.4s, v27.s[0]\n" | |||
"fmla v21.4s, v6.4s, v25.s[2]\n" | |||
"fmla v22.4s, v5.4s, v26.s[1]\n" | |||
"fmla v21.4s, v7.4s, v25.s[3]\n" | |||
"fmla v23.4s, v5.4s, v27.s[1]\n" | |||
"fmla v22.4s, v6.4s, v26.s[2]\n" | |||
"subs %w[K], %w[K], #4\n" | |||
"bne 1b\n" | |||
"2:\n" | |||
"fmla v22.4s, v7.4s, v26.s[3]\n" | |||
"st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[output]], 64\n" | |||
"fmla v23.4s, v6.4s, v27.s[2]\n" | |||
"fmla v23.4s, v7.4s, v27.s[3]\n" | |||
"st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%[output]], 64\n" | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
[output] "+r"(output), [LDB] "+r"(LDB) | |||
: | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", | |||
"v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", | |||
"v27", "cc", "memory"); | |||
} | |||
// Overview of register layout: | |||
// | |||
// A 4x4 block of A is stored in 32bit in v4-v7 | |||
// A 16x4 block of B is streamed through v0-v3 and v8-v11 (ping-pong) | |||
// A 4x16 block of accumulators is stored in 32bit in v16-v31. | |||
// | |||
//                    A +--------+ | |||
//                      | v4[0-3]| | |||
//                      | v5[0-3]| | |||
//                      | v6[0-3]| | |||
//                      | v7[0-3]| | |||
//                      +--------+ | |||
//                    B | |||
//  +---------+ - - - - -+--------+ | |||
//  | v0[0-3] |          |v16[0-3]| | |||
//  | v1[0-3] |          |v17[0-3]| | |||
//  | v2[0-3] |          |v18[0-3]| | |||
//  | v3[0-3] |          |v19[0-3]| | |||
//  | v8[0-3] |          |v20[0-3]| | |||
//  | v9[0-3] |          |v21[0-3]| | |||
//  | v10[0-3]|          |v22[0-3]| | |||
//  | v11[0-3]|          |v23[0-3]| | |||
//  +---------+          |v24[0-3]| | |||
//                       |v25[0-3]| | |||
//                       |v26[0-3]| | |||
//                       |v27[0-3]| | |||
//                       |v28[0-3]| | |||
//                       |v29[0-3]| | |||
//                       |v30[0-3]| | |||
//                       |v31[0-3]| | |||
//                       +--------+ | |||
// Accumulator | |||
void kern_4x16(const float* a_ptr, const float* b_ptr, int LDB, int K, | |||
float* output) { | |||
    //! Each iteration loads 64 floats from B, but the loads have already | |||
    //! advanced the pointer by 56 floats, so subtract 56 from the stride | |||
    //! here. | |||
LDB = (LDB - 56) * sizeof(float); | |||
asm volatile( | |||
"stp d8, d9, [sp, #-16]!\n" | |||
"stp d10, d11, [sp, #-16]!\n" | |||
"subs %w[K], %w[K], #4\n" | |||
"ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[a_ptr]], 64\n" | |||
"ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b_ptr]], 64\n" | |||
"fmul v16.4s, v4.4s, v0.s[0]\n" | |||
"fmul v17.4s, v4.4s, v1.s[0]\n" | |||
"fmul v18.4s, v4.4s, v2.s[0]\n" | |||
"fmul v19.4s, v4.4s, v3.s[0]\n" | |||
"fmla v16.4s, v5.4s, v0.s[1]\n" | |||
"fmla v17.4s, v5.4s, v1.s[1]\n" | |||
"fmla v18.4s, v5.4s, v2.s[1]\n" | |||
"fmla v19.4s, v5.4s, v3.s[1]\n" | |||
"ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[b_ptr]], 64\n" | |||
"fmla v16.4s, v6.4s, v0.s[2]\n" | |||
"fmla v17.4s, v6.4s, v1.s[2]\n" | |||
"fmla v18.4s, v6.4s, v2.s[2]\n" | |||
"fmla v19.4s, v6.4s, v3.s[2]\n" | |||
"fmla v16.4s, v7.4s, v0.s[3]\n" | |||
"fmla v17.4s, v7.4s, v1.s[3]\n" | |||
"fmla v18.4s, v7.4s, v2.s[3]\n" | |||
"fmla v19.4s, v7.4s, v3.s[3]\n" | |||
"fmul v20.4s, v4.4s, v8.s[0]\n" | |||
"fmul v21.4s, v4.4s, v9.s[0]\n" | |||
"fmul v22.4s, v4.4s, v10.s[0]\n" | |||
"fmul v23.4s, v4.4s, v11.s[0]\n" | |||
"fmla v20.4s, v5.4s, v8.s[1]\n" | |||
"fmla v21.4s, v5.4s, v9.s[1]\n" | |||
"fmla v22.4s, v5.4s, v10.s[1]\n" | |||
"fmla v23.4s, v5.4s, v11.s[1]\n" | |||
"ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b_ptr]], 64\n" | |||
"fmla v20.4s, v6.4s, v8.s[2]\n" | |||
"fmla v21.4s, v6.4s, v9.s[2]\n" | |||
"fmla v22.4s, v6.4s, v10.s[2]\n" | |||
"fmla v23.4s, v6.4s, v11.s[2]\n" | |||
"fmla v20.4s, v7.4s, v8.s[3]\n" | |||
"fmla v21.4s, v7.4s, v9.s[3]\n" | |||
"fmla v22.4s, v7.4s, v10.s[3]\n" | |||
"fmla v23.4s, v7.4s, v11.s[3]\n" | |||
"fmul v24.4s, v4.4s, v0.s[0]\n" | |||
"fmul v25.4s, v4.4s, v1.s[0]\n" | |||
"fmul v26.4s, v4.4s, v2.s[0]\n" | |||
"fmul v27.4s, v4.4s, v3.s[0]\n" | |||
"fmla v24.4s, v5.4s, v0.s[1]\n" | |||
"fmla v25.4s, v5.4s, v1.s[1]\n" | |||
"fmla v26.4s, v5.4s, v2.s[1]\n" | |||
"fmla v27.4s, v5.4s, v3.s[1]\n" | |||
"ld1 {v8.4s, v9.4s}, [%[b_ptr]], 32\n" | |||
"fmla v24.4s, v6.4s, v0.s[2]\n" | |||
"fmla v25.4s, v6.4s, v1.s[2]\n" | |||
"fmla v26.4s, v6.4s, v2.s[2]\n" | |||
"fmla v27.4s, v6.4s, v3.s[2]\n" | |||
"ld1 {v10.4s, v11.4s}, [%[b_ptr]], %x[LDB]\n" | |||
"fmla v24.4s, v7.4s, v0.s[3]\n" | |||
"fmla v25.4s, v7.4s, v1.s[3]\n" | |||
"fmla v26.4s, v7.4s, v2.s[3]\n" | |||
"fmla v27.4s, v7.4s, v3.s[3]\n" | |||
"fmul v28.4s, v4.4s, v8.s[0]\n" | |||
"fmul v29.4s, v4.4s, v9.s[0]\n" | |||
"fmul v30.4s, v4.4s, v10.s[0]\n" | |||
"fmul v31.4s, v4.4s, v11.s[0]\n" | |||
"beq 2f\n" | |||
"1:\n" | |||
"ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b_ptr]], 64\n" | |||
"fmla v28.4s, v5.4s, v8.s[1]\n" | |||
"fmla v29.4s, v5.4s, v9.s[1]\n" | |||
"fmla v30.4s, v5.4s, v10.s[1]\n" | |||
"fmla v31.4s, v5.4s, v11.s[1]\n" | |||
"ld1 {v4.4s}, [%[a_ptr]], 16\n" | |||
"fmla v28.4s, v6.4s, v8.s[2]\n" | |||
"fmla v29.4s, v6.4s, v9.s[2]\n" | |||
"fmla v30.4s, v6.4s, v10.s[2]\n" | |||
"fmla v31.4s, v6.4s, v11.s[2]\n" | |||
"ld1 {v5.4s}, [%[a_ptr]], 16\n" | |||
"fmla v28.4s, v7.4s, v8.s[3]\n" | |||
"fmla v29.4s, v7.4s, v9.s[3]\n" | |||
"fmla v30.4s, v7.4s, v10.s[3]\n" | |||
"fmla v31.4s, v7.4s, v11.s[3]\n" | |||
"ld1 {v6.4s}, [%[a_ptr]], 16\n" | |||
"fmla v16.4s, v4.4s, v0.s[0]\n" | |||
"fmla v17.4s, v4.4s, v1.s[0]\n" | |||
"fmla v18.4s, v4.4s, v2.s[0]\n" | |||
"fmla v19.4s, v4.4s, v3.s[0]\n" | |||
"ld1 {v7.4s}, [%[a_ptr]], 16\n" | |||
"fmla v16.4s, v5.4s, v0.s[1]\n" | |||
"fmla v17.4s, v5.4s, v1.s[1]\n" | |||
"fmla v18.4s, v5.4s, v2.s[1]\n" | |||
"fmla v19.4s, v5.4s, v3.s[1]\n" | |||
"ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[b_ptr]], 64\n" | |||
"fmla v16.4s, v6.4s, v0.s[2]\n" | |||
"fmla v17.4s, v6.4s, v1.s[2]\n" | |||
"fmla v18.4s, v6.4s, v2.s[2]\n" | |||
"fmla v19.4s, v6.4s, v3.s[2]\n" | |||
"fmla v16.4s, v7.4s, v0.s[3]\n" | |||
"fmla v17.4s, v7.4s, v1.s[3]\n" | |||
"fmla v18.4s, v7.4s, v2.s[3]\n" | |||
"fmla v19.4s, v7.4s, v3.s[3]\n" | |||
"fmla v20.4s, v4.4s, v8.s[0]\n" | |||
"fmla v21.4s, v4.4s, v9.s[0]\n" | |||
"fmla v22.4s, v4.4s, v10.s[0]\n" | |||
"fmla v23.4s, v4.4s, v11.s[0]\n" | |||
"fmla v20.4s, v5.4s, v8.s[1]\n" | |||
"fmla v21.4s, v5.4s, v9.s[1]\n" | |||
"fmla v22.4s, v5.4s, v10.s[1]\n" | |||
"fmla v23.4s, v5.4s, v11.s[1]\n" | |||
"ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b_ptr]], 64\n" | |||
"fmla v20.4s, v6.4s, v8.s[2]\n" | |||
"fmla v21.4s, v6.4s, v9.s[2]\n" | |||
"fmla v22.4s, v6.4s, v10.s[2]\n" | |||
"fmla v23.4s, v6.4s, v11.s[2]\n" | |||
"fmla v20.4s, v7.4s, v8.s[3]\n" | |||
"fmla v21.4s, v7.4s, v9.s[3]\n" | |||
"fmla v22.4s, v7.4s, v10.s[3]\n" | |||
"fmla v23.4s, v7.4s, v11.s[3]\n" | |||
"fmla v24.4s, v4.4s, v0.s[0]\n" | |||
"fmla v25.4s, v4.4s, v1.s[0]\n" | |||
"fmla v26.4s, v4.4s, v2.s[0]\n" | |||
"fmla v27.4s, v4.4s, v3.s[0]\n" | |||
"fmla v24.4s, v5.4s, v0.s[1]\n" | |||
"fmla v25.4s, v5.4s, v1.s[1]\n" | |||
"fmla v26.4s, v5.4s, v2.s[1]\n" | |||
"fmla v27.4s, v5.4s, v3.s[1]\n" | |||
"ld1 {v8.4s, v9.4s}, [%[b_ptr]], 32\n" | |||
"fmla v24.4s, v6.4s, v0.s[2]\n" | |||
"fmla v25.4s, v6.4s, v1.s[2]\n" | |||
"fmla v26.4s, v6.4s, v2.s[2]\n" | |||
"fmla v27.4s, v6.4s, v3.s[2]\n" | |||
"ld1 {v10.4s, v11.4s}, [%[b_ptr]], %x[LDB]\n" | |||
"fmla v24.4s, v7.4s, v0.s[3]\n" | |||
"fmla v25.4s, v7.4s, v1.s[3]\n" | |||
"fmla v26.4s, v7.4s, v2.s[3]\n" | |||
"fmla v27.4s, v7.4s, v3.s[3]\n" | |||
"fmla v28.4s, v4.4s, v8.s[0]\n" | |||
"fmla v29.4s, v4.4s, v9.s[0]\n" | |||
"fmla v30.4s, v4.4s, v10.s[0]\n" | |||
"fmla v31.4s, v4.4s, v11.s[0]\n" | |||
"subs %w[K], %w[K], #4\n" | |||
"bne 1b\n" | |||
"2:\n" | |||
"st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[output]], 64\n" | |||
"fmla v28.4s, v5.4s, v8.s[1]\n" | |||
"fmla v29.4s, v5.4s, v9.s[1]\n" | |||
"fmla v30.4s, v5.4s, v10.s[1]\n" | |||
"fmla v31.4s, v5.4s, v11.s[1]\n" | |||
"st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%[output]], 64\n" | |||
"fmla v28.4s, v6.4s, v8.s[2]\n" | |||
"fmla v29.4s, v6.4s, v9.s[2]\n" | |||
"fmla v30.4s, v6.4s, v10.s[2]\n" | |||
"fmla v31.4s, v6.4s, v11.s[2]\n" | |||
"st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%[output]], 64\n" | |||
"fmla v28.4s, v7.4s, v8.s[3]\n" | |||
"fmla v29.4s, v7.4s, v9.s[3]\n" | |||
"fmla v30.4s, v7.4s, v10.s[3]\n" | |||
"fmla v31.4s, v7.4s, v11.s[3]\n" | |||
"st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%[output]], 64\n" | |||
"ldp d10, d11, [sp], #16\n" | |||
"ldp d8, d9, [sp], #16\n" | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
[output] "+r"(output), [LDB] "+r"(LDB) | |||
: | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
"v11", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", | |||
"v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", | |||
"memory"); | |||
} | |||
} // namespace | |||
MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(sgemm_nopack_4x16); | |||
void sgemm_nopack_4x16::kern(const float* A, size_t LDA, const float* B, | |||
size_t LDB, float* C, size_t LDC, size_t M, | |||
size_t K, size_t N, const float*, void*, bool trA, | |||
bool trB) const { | |||
constexpr static size_t MB = 4; | |||
constexpr static size_t KB = 4; | |||
constexpr static size_t NB = 16; | |||
constexpr static size_t CALCBLK = 4; | |||
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0 && N % CALCBLK == 0); | |||
//! (m/4, k/4, 4, 4) * (k/4, n, 4) = (m/4, n, 4) | |||
for (size_t m = 0; m < M; m += MB) { | |||
float* output = C + (m / MB) * LDC; | |||
const float* cur_B = B; | |||
size_t n = 0; | |||
for (; n + NB - 1 < N; n += NB) { | |||
kern_4x16(A, cur_B, LDB, K, output); | |||
cur_B += KB * NB; | |||
output += MB * NB; | |||
} | |||
switch (N - n) { | |||
case 4: | |||
kern_4x4(A, cur_B, LDB, K, output); | |||
break; | |||
case 8: | |||
kern_4x8(A, cur_B, LDB, K, output); | |||
break; | |||
case 12: | |||
kern_4x8(A, cur_B, LDB, K, output); | |||
cur_B += KB * CALCBLK * 2; | |||
output += MB * CALCBLK * 2; | |||
kern_4x4(A, cur_B, LDB, K, output); | |||
break; | |||
default: | |||
break; | |||
} | |||
A += LDA; | |||
} | |||
} | |||
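//! Editorial sketch (assumption: not part of the original source; the helper | |||
//! name is hypothetical).  The whole kernel above in index form, matching the | |||
//! layout equation (m/4, k/4, 4, 4) * (k/4, n, 4) = (m/4, n, 4): | |||
static inline void mk4_gemm_ref(const float* A, size_t LDA, const float* B, | |||
                                size_t LDB, float* C, size_t LDC, size_t M, | |||
                                size_t K, size_t N) { | |||
    for (size_t mb = 0; mb < M / 4; ++mb) | |||
        for (size_t n = 0; n < N; ++n) | |||
            for (size_t i = 0; i < 4; ++i) { | |||
                float acc = 0.f; | |||
                for (size_t kb = 0; kb < K / 4; ++kb) | |||
                    for (size_t kk = 0; kk < 4; ++kk) | |||
                        acc += A[mb * LDA + (kb * 4 + kk) * 4 + i] * | |||
                               B[kb * LDB + n * 4 + kk]; | |||
                C[mb * LDC + n * 4 + i] = acc; | |||
            } | |||
} | |||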
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,132 @@ | |||
/** | |||
* \file dnn/src/aarch64/matrix_mul/int16/strategy.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/aarch64/matrix_mul/int16/strategy.h" | |||
#include "src/aarch64/matrix_mul/asm/common.h" | |||
#include "src/aarch64/matrix_mul/int16/kernel_12x8x1.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/common/utils.h" | |||
using namespace megdnn; | |||
using namespace aarch64; | |||
using namespace aarch64::matmul; | |||
///////////////////////// gemm_s16_12x8x1 //////////////////////////////////// | |||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s16_12x8x1); | |||
void gemm_s16_12x8x1::pack_A(dt_int16* outptr, const dt_int16* inptr, int ldin, | |||
int y0, int ymax, int k0, int kmax, | |||
bool transpose) const { | |||
if (transpose) { | |||
matmul_12x8x1::gemm_s16_12x8x1_transpose_pack_A_n(outptr, inptr, ldin, | |||
y0, ymax, k0, kmax); | |||
} else { | |||
matmul_12x8x1::gemm_s16_12x8x1_pack_A_n(outptr, inptr, ldin, y0, ymax, | |||
k0, kmax); | |||
} | |||
} | |||
void gemm_s16_12x8x1::pack_B(dt_int16* out, const dt_int16* in, int ldin, | |||
int x0, int xmax, int k0, int kmax, | |||
bool transpose) const { | |||
if (transpose) { | |||
matmul_12x8x1::gemm_s16_12x8x1_transpose_pack_B_n(out, in, ldin, x0, | |||
xmax, k0, kmax); | |||
} else { | |||
matmul_12x8x1::gemm_s16_12x8x1_pack_B_n(out, in, ldin, x0, xmax, k0, | |||
kmax); | |||
} | |||
} | |||
void gemm_s16_12x8x1::kern(const dt_int16* packA, const dt_int16* packB, | |||
size_t M, size_t N, size_t K, dt_int32* C, | |||
size_t LDC, bool is_first_k, const dt_int32*, | |||
dt_int32*) const { | |||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||
(A_dtype.enumv() == DTypeEnum::Int16 && | |||
C_dtype.enumv() == DTypeEnum::Int32), | |||
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), | |||
C_dtype.name()); | |||
MEGDNN_MARK_USED_VAR(A_dtype); | |||
MEGDNN_MARK_USED_VAR(B_dtype); | |||
MEGDNN_MARK_USED_VAR(C_dtype); | |||
constexpr size_t A_INTERLEAVE = 12; | |||
constexpr size_t B_INTERLEAVE = 8; | |||
const int K12 = K * 12; | |||
const int K8 = K * 8; | |||
const int K4 = K * 4; | |||
size_t m = 0; | |||
for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) { | |||
int32_t* output = C + (m * LDC); | |||
size_t n = 0; | |||
const dt_int16* cur_packB = packB; | |||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
matmul_12x8x1::kern_12x8(packA, cur_packB, K, output, LDC, | |||
is_first_k); | |||
output += B_INTERLEAVE; | |||
cur_packB += K8; | |||
} | |||
for (; n < N; n += 4) { | |||
matmul_12x8x1::kern_12x4(packA, cur_packB, K, output, LDC, | |||
is_first_k, std::min<size_t>(N - n, 4)); | |||
output += 4; | |||
cur_packB += K4; | |||
} | |||
packA += K12; | |||
} | |||
for (; m + 7 < M; m += 8) { | |||
int32_t* output = C + (m * LDC); | |||
const dt_int16* cur_packB = packB; | |||
size_t n = 0; | |||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
matmul_12x8x1::kern_8x8(packA, cur_packB, K, output, LDC, | |||
is_first_k); | |||
output += B_INTERLEAVE; | |||
cur_packB += K8; | |||
} | |||
for (; n < N; n += 4) { | |||
matmul_12x8x1::kern_8x4(packA, cur_packB, K, output, LDC, | |||
is_first_k, std::min<size_t>(N - n, 4)); | |||
output += 4; | |||
cur_packB += K4; | |||
} | |||
packA += K8; | |||
} | |||
for (; m < M; m += 4) { | |||
int32_t* output = C + (m * LDC); | |||
const dt_int16* cur_packB = packB; | |||
size_t n = 0; | |||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
matmul_12x8x1::kern_4x8(packA, cur_packB, K, output, LDC, | |||
is_first_k, std::min<size_t>(M - m, 4)); | |||
output += B_INTERLEAVE; | |||
cur_packB += K8; | |||
} | |||
for (; n < N; n += 4) { | |||
matmul_12x8x1::kern_4x4(packA, cur_packB, K, output, LDC, | |||
is_first_k, std::min<size_t>(M - m, 4), | |||
std::min<size_t>(N - n, 4)); | |||
output += 4; | |||
cur_packB += K4; | |||
} | |||
packA += K4; | |||
} | |||
} | |||
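//! Editorial note (not in the original source): the M loop above degrades in | |||
//! three tiers so that the packed-A stride always matches the tile height: | |||
//!   while (m + 12 <= M): kern_12x8 / kern_12x4 tiles, packA advances K*12 | |||
//!   while (m + 8 <= M):  kern_8x8  / kern_8x4  tiles, packA advances K*8 | |||
//!   while (m < M):       kern_4x8  / kern_4x4  tiles, packA advances K*4 | |||
//! with the last tier masking ragged rows via std::min<size_t>(M - m, 4). | |||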
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,29 @@ | |||
/** | |||
* \file dnn/src/aarch64/matrix_mul/int16/strategy.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "src/fallback/matrix_mul/gemm_common.h" | |||
namespace megdnn { | |||
namespace aarch64 { | |||
namespace matmul { | |||
MEGDNN_REG_GEMM_STRATEGY(dt_int16, dt_int32, dt_int32, 12, 8, 1, false, true, | |||
gemm_s16_12x8x1); | |||
MEGDNN_REG_GEMM_STRATEGY_NOPACK(dt_int16, dt_int32, dt_int32, 8, 8, 1, false, | |||
true, gemm_nopack_s16_8x8); | |||
} // namespace matmul | |||
} // namespace aarch64 | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,658 @@ | |||
/** | |||
* \file dnn/src/aarch64/matrix_mul/int16/strategy_mk8_8x8.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/aarch64/matrix_mul/int16/strategy.h" | |||
#include "src/aarch64/matrix_mul/asm/common.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/common/utils.h" | |||
using namespace megdnn; | |||
using namespace aarch64; | |||
using namespace aarch64::matmul; | |||
namespace { | |||
// Overview of register layout: | |||
// | |||
// An 8x8 cell of Lhs is stored in 16bit in v24-v31 | |||
// An 8x4 cell of Rhs is stored in 16bit in v0-v3 | |||
// An 8x4 block of accumulators is stored in 32bit in v16-v23. | |||
// | |||
//                 Rhs +--------+ | |||
//                     | v0[0-7]| | |||
//                     | v1[0-7]| | |||
//                     | v2[0-7]| | |||
//                     | v3[0-7]| | |||
//                     +--------+ | |||
//                Lhs | |||
//      +---------+ - - - - -+--------+ | |||
//      | v24[0-7]|          |v16[0-3]| | |||
//      | v25[0-7]|          |v17[0-3]| | |||
//      | v26[0-7]|          |v18[0-3]| | |||
//      | v27[0-7]|          |v19[0-3]| | |||
//      | v28[0-7]|          |v20[0-3]| | |||
//      | v29[0-7]|          |v21[0-3]| | |||
//      | v30[0-7]|          |v22[0-3]| | |||
//      | v31[0-7]|          |v23[0-3]| | |||
// +---------+ +--------+ | |||
// Accumulator | |||
void kern_8x4(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||
dt_int32* output) { | |||
    //! Each iteration loads 32 int16 values from B, but the first three | |||
    //! loads have already advanced the pointer by 24 values, so subtract 24 | |||
    //! from the stride here. | |||
LDB = (LDB - 24) * sizeof(dt_int16); | |||
asm volatile( | |||
"subs %w[K], %w[K], #8\n" | |||
"ld1 {v24.4s}, [%[a_ptr]], 16\n" | |||
"ld1 {v0.4s}, [%[b_ptr]], 16\n" | |||
"ld1 {v1.4s}, [%[b_ptr]], 16\n" | |||
"ld1 {v2.4s}, [%[b_ptr]], 16\n" | |||
"ld1 {v3.4s}, [%[b_ptr]], %x[LDB]\n" | |||
"smull v16.4s, v24.4h, v0.h[0]\n" | |||
"smull2 v17.4s, v24.8h, v0.h[0]\n" | |||
"smull v18.4s, v24.4h, v1.h[0]\n" | |||
"smull2 v19.4s, v24.8h, v1.h[0]\n" | |||
"ld1 {v25.4s}, [%[a_ptr]], 16\n" | |||
"smull v20.4s, v24.4h, v2.h[0]\n" | |||
"smull2 v21.4s, v24.8h, v2.h[0]\n" | |||
"smull v22.4s, v24.4h, v3.h[0]\n" | |||
"smull2 v23.4s, v24.8h, v3.h[0]\n" | |||
"smlal v16.4s, v25.4h, v0.h[1]\n" | |||
"smlal2 v17.4s, v25.8h, v0.h[1]\n" | |||
"smlal v18.4s, v25.4h, v1.h[1]\n" | |||
"smlal2 v19.4s, v25.8h, v1.h[1]\n" | |||
"ld1 {v26.4s}, [%[a_ptr]], 16\n" | |||
"smlal v20.4s, v25.4h, v2.h[1]\n" | |||
"smlal2 v21.4s, v25.8h, v2.h[1]\n" | |||
"smlal v22.4s, v25.4h, v3.h[1]\n" | |||
"smlal2 v23.4s, v25.8h, v3.h[1]\n" | |||
"smlal v16.4s, v26.4h, v0.h[2]\n" | |||
"smlal2 v17.4s, v26.8h, v0.h[2]\n" | |||
"smlal v18.4s, v26.4h, v1.h[2]\n" | |||
"smlal2 v19.4s, v26.8h, v1.h[2]\n" | |||
"ld1 {v27.4s}, [%[a_ptr]], 16\n" | |||
"smlal v20.4s, v26.4h, v2.h[2]\n" | |||
"smlal2 v21.4s, v26.8h, v2.h[2]\n" | |||
"smlal v22.4s, v26.4h, v3.h[2]\n" | |||
"smlal2 v23.4s, v26.8h, v3.h[2]\n" | |||
"smlal v16.4s, v27.4h, v0.h[3]\n" | |||
"smlal2 v17.4s, v27.8h, v0.h[3]\n" | |||
"smlal v18.4s, v27.4h, v1.h[3]\n" | |||
"smlal2 v19.4s, v27.8h, v1.h[3]\n" | |||
"ld1 {v28.4s}, [%[a_ptr]], 16\n" | |||
"smlal v20.4s, v27.4h, v2.h[3]\n" | |||
"smlal2 v21.4s, v27.8h, v2.h[3]\n" | |||
"smlal v22.4s, v27.4h, v3.h[3]\n" | |||
"smlal2 v23.4s, v27.8h, v3.h[3]\n" | |||
"smlal v16.4s, v28.4h, v0.h[4]\n" | |||
"smlal2 v17.4s, v28.8h, v0.h[4]\n" | |||
"smlal v18.4s, v28.4h, v1.h[4]\n" | |||
"smlal2 v19.4s, v28.8h, v1.h[4]\n" | |||
"ld1 {v29.4s}, [%[a_ptr]], 16\n" | |||
"smlal v20.4s, v28.4h, v2.h[4]\n" | |||
"smlal2 v21.4s, v28.8h, v2.h[4]\n" | |||
"smlal v22.4s, v28.4h, v3.h[4]\n" | |||
"smlal2 v23.4s, v28.8h, v3.h[4]\n" | |||
"smlal v16.4s, v29.4h, v0.h[5]\n" | |||
"smlal2 v17.4s, v29.8h, v0.h[5]\n" | |||
"smlal v18.4s, v29.4h, v1.h[5]\n" | |||
"smlal2 v19.4s, v29.8h, v1.h[5]\n" | |||
"ld1 {v30.4s}, [%[a_ptr]], 16\n" | |||
"smlal v20.4s, v29.4h, v2.h[5]\n" | |||
"smlal2 v21.4s, v29.8h, v2.h[5]\n" | |||
"smlal v22.4s, v29.4h, v3.h[5]\n" | |||
"smlal2 v23.4s, v29.8h, v3.h[5]\n" | |||
"smlal v16.4s, v30.4h, v0.h[6]\n" | |||
"smlal2 v17.4s, v30.8h, v0.h[6]\n" | |||
"smlal v18.4s, v30.4h, v1.h[6]\n" | |||
"smlal2 v19.4s, v30.8h, v1.h[6]\n" | |||
"ld1 {v31.4s}, [%[a_ptr]], 16\n" | |||
"smlal v20.4s, v30.4h, v2.h[6]\n" | |||
"smlal2 v21.4s, v30.8h, v2.h[6]\n" | |||
"smlal v22.4s, v30.4h, v3.h[6]\n" | |||
"smlal2 v23.4s, v30.8h, v3.h[6]\n" | |||
"beq 2f\n" | |||
"1:\n" | |||
"ld1 {v24.4s}, [%[a_ptr]], 16\n" | |||
"smlal v16.4s, v31.4h, v0.h[7]\n" | |||
"smlal2 v17.4s, v31.8h, v0.h[7]\n" | |||
"ld1 {v0.4s}, [%[b_ptr]], 16\n" | |||
"smlal v18.4s, v31.4h, v1.h[7]\n" | |||
"smlal2 v19.4s, v31.8h, v1.h[7]\n" | |||
"ld1 {v1.4s}, [%[b_ptr]], 16\n" | |||
"smlal v20.4s, v31.4h, v2.h[7]\n" | |||
"smlal2 v21.4s, v31.8h, v2.h[7]\n" | |||
"ld1 {v2.4s}, [%[b_ptr]], 16\n" | |||
"smlal v22.4s, v31.4h, v3.h[7]\n" | |||
"smlal2 v23.4s, v31.8h, v3.h[7]\n" | |||
"ld1 {v3.4s}, [%[b_ptr]], %x[LDB]\n" | |||
"smlal v16.4s, v24.4h, v0.h[0]\n" | |||
"smlal2 v17.4s, v24.8h, v0.h[0]\n" | |||
"smlal v18.4s, v24.4h, v1.h[0]\n" | |||
"smlal2 v19.4s, v24.8h, v1.h[0]\n" | |||
"ld1 {v25.4s}, [%[a_ptr]], 16\n" | |||
"smlal v20.4s, v24.4h, v2.h[0]\n" | |||
"smlal2 v21.4s, v24.8h, v2.h[0]\n" | |||
"smlal v22.4s, v24.4h, v3.h[0]\n" | |||
"smlal2 v23.4s, v24.8h, v3.h[0]\n" | |||
"smlal v16.4s, v25.4h, v0.h[1]\n" | |||
"smlal2 v17.4s, v25.8h, v0.h[1]\n" | |||
"smlal v18.4s, v25.4h, v1.h[1]\n" | |||
"smlal2 v19.4s, v25.8h, v1.h[1]\n" | |||
"ld1 {v26.4s}, [%[a_ptr]], 16\n" | |||
"smlal v20.4s, v25.4h, v2.h[1]\n" | |||
"smlal2 v21.4s, v25.8h, v2.h[1]\n" | |||
"smlal v22.4s, v25.4h, v3.h[1]\n" | |||
"smlal2 v23.4s, v25.8h, v3.h[1]\n" | |||
"smlal v16.4s, v26.4h, v0.h[2]\n" | |||
"smlal2 v17.4s, v26.8h, v0.h[2]\n" | |||
"smlal v18.4s, v26.4h, v1.h[2]\n" | |||
"smlal2 v19.4s, v26.8h, v1.h[2]\n" | |||
"ld1 {v27.4s}, [%[a_ptr]], 16\n" | |||
"smlal v20.4s, v26.4h, v2.h[2]\n" | |||
"smlal2 v21.4s, v26.8h, v2.h[2]\n" | |||
"smlal v22.4s, v26.4h, v3.h[2]\n" | |||
"smlal2 v23.4s, v26.8h, v3.h[2]\n" | |||
"smlal v16.4s, v27.4h, v0.h[3]\n" | |||
"smlal2 v17.4s, v27.8h, v0.h[3]\n" | |||
"smlal v18.4s, v27.4h, v1.h[3]\n" | |||
"smlal2 v19.4s, v27.8h, v1.h[3]\n" | |||
"ld1 {v28.4s}, [%[a_ptr]], 16\n" | |||
"smlal v20.4s, v27.4h, v2.h[3]\n" | |||
"smlal2 v21.4s, v27.8h, v2.h[3]\n" | |||
"smlal v22.4s, v27.4h, v3.h[3]\n" | |||
"smlal2 v23.4s, v27.8h, v3.h[3]\n" | |||
"smlal v16.4s, v28.4h, v0.h[4]\n" | |||
"smlal2 v17.4s, v28.8h, v0.h[4]\n" | |||
"smlal v18.4s, v28.4h, v1.h[4]\n" | |||
"smlal2 v19.4s, v28.8h, v1.h[4]\n" | |||
"ld1 {v29.4s}, [%[a_ptr]], 16\n" | |||
"smlal v20.4s, v28.4h, v2.h[4]\n" | |||
"smlal2 v21.4s, v28.8h, v2.h[4]\n" | |||
"smlal v22.4s, v28.4h, v3.h[4]\n" | |||
"smlal2 v23.4s, v28.8h, v3.h[4]\n" | |||
"smlal v16.4s, v29.4h, v0.h[5]\n" | |||
"smlal2 v17.4s, v29.8h, v0.h[5]\n" | |||
"smlal v18.4s, v29.4h, v1.h[5]\n" | |||
"smlal2 v19.4s, v29.8h, v1.h[5]\n" | |||
"ld1 {v30.4s}, [%[a_ptr]], 16\n" | |||
"smlal v20.4s, v29.4h, v2.h[5]\n" | |||
"smlal2 v21.4s, v29.8h, v2.h[5]\n" | |||
"smlal v22.4s, v29.4h, v3.h[5]\n" | |||
"smlal2 v23.4s, v29.8h, v3.h[5]\n" | |||
"smlal v16.4s, v30.4h, v0.h[6]\n" | |||
"smlal2 v17.4s, v30.8h, v0.h[6]\n" | |||
"smlal v18.4s, v30.4h, v1.h[6]\n" | |||
"smlal2 v19.4s, v30.8h, v1.h[6]\n" | |||
"ld1 {v31.4s}, [%[a_ptr]], 16\n" | |||
"smlal v20.4s, v30.4h, v2.h[6]\n" | |||
"smlal2 v21.4s, v30.8h, v2.h[6]\n" | |||
"smlal v22.4s, v30.4h, v3.h[6]\n" | |||
"smlal2 v23.4s, v30.8h, v3.h[6]\n" | |||
"subs %w[K], %w[K], #8\n" | |||
"bne 1b\n" | |||
"2:\n" | |||
"smlal v16.4s, v31.4h, v0.h[7]\n" | |||
"smlal2 v17.4s, v31.8h, v0.h[7]\n" | |||
"smlal v18.4s, v31.4h, v1.h[7]\n" | |||
"smlal2 v19.4s, v31.8h, v1.h[7]\n" | |||
"smlal v20.4s, v31.4h, v2.h[7]\n" | |||
"smlal2 v21.4s, v31.8h, v2.h[7]\n" | |||
"smlal v22.4s, v31.4h, v3.h[7]\n" | |||
"smlal2 v23.4s, v31.8h, v3.h[7]\n" | |||
"st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[output]], 64\n" | |||
"st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%[output]], 64\n" | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
[output] "+r"(output), [LDB] "+r"(LDB) | |||
: | |||
: "v0", "v1", "v2", "v3", "v16", "v17", "v18", "v19", | |||
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||
"v29", "v30", "v31", "cc", "memory"); | |||
} | |||
// Overview of register layout: | |||
// | |||
// An 8x8 cell of Lhs is stored in 16bit in v8-v15 | |||
// An 8x8 cell of Rhs is stored in 16bit in v0-v7 | |||
// An 8x8 block of accumulators is stored in 32bit in v16-v31. | |||
// | |||
//                 Lhs +--------+ | |||
//                     | v8[0-7]| | |||
//                     | v9[0-7]| | |||
//                     |v10[0-7]| | |||
//                     |v11[0-7]| | |||
//                     |v12[0-7]| | |||
//                     |v13[0-7]| | |||
//                     |v14[0-7]| | |||
//                     |v15[0-7]| | |||
//                     +--------+ | |||
//            Rhs | |||
// +--------+ - - - - -+--------+--------+ | |||
// | v0[0-7]| |v16[0-3]|v17[0-3]| | |||
// | v1[0-7]| |v18[0-3]|v19[0-3]| | |||
// | v2[0-7]| |v20[0-3]|v21[0-3]| | |||
// | v3[0-7]| |v22[0-3]|v23[0-3]| | |||
// | v4[0-7]| |v24[0-3]|v25[0-3]| | |||
// | v5[0-7]| |v26[0-3]|v27[0-3]| | |||
// | v6[0-7]| |v28[0-3]|v29[0-3]| | |||
// | v7[0-7]| |v30[0-3]|v31[0-3]| | |||
// +--------+ +--------+--------+ | |||
// Accumulator | |||
void kern_8x8(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||
dt_int32* output) { | |||
    //! Each iteration loads 64 int16 values from B, but the loads have | |||
    //! already advanced the pointer by 48 values, so subtract 48 from the | |||
    //! stride here. | |||
LDB = (LDB - 48) * sizeof(dt_int16); | |||
asm volatile( | |||
"subs %w[K], %w[K], #8\n" | |||
"ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[a_ptr]], 64\n" | |||
"ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%[a_ptr]], 64\n" | |||
"ld1 {v0.4s}, [%[b_ptr]], 16\n" | |||
"smull v16.4s, v8.4h, v0.h[0]\n" | |||
"ld1 {v1.4s}, [%[b_ptr]], 16\n" | |||
"smlal v16.4s, v9.4h, v0.h[1]\n" | |||
"smull v18.4s, v8.4h, v1.h[0]\n" | |||
"smull2 v17.4s, v8.8h, v0.h[0]\n" | |||
"smull2 v19.4s, v8.8h, v1.h[0]\n" | |||
"smlal v16.4s, v10.4h, v0.h[2]\n" | |||
"smlal v18.4s, v9.4h, v1.h[1]\n" | |||
"smlal2 v17.4s, v9.8h, v0.h[1]\n" | |||
"smlal2 v19.4s, v9.8h, v1.h[1]\n" | |||
"smlal v16.4s, v11.4h, v0.h[3]\n" | |||
"smlal v18.4s, v10.4h, v1.h[2]\n" | |||
"smlal2 v17.4s, v10.8h, v0.h[2]\n" | |||
"smlal2 v19.4s, v10.8h, v1.h[2]\n" | |||
"smlal v16.4s, v12.4h, v0.h[4]\n" | |||
"smlal v18.4s, v11.4h, v1.h[3]\n" | |||
"smlal2 v17.4s, v11.8h, v0.h[3]\n" | |||
"smlal2 v19.4s, v11.8h, v1.h[3]\n" | |||
"smlal v16.4s, v13.4h, v0.h[5]\n" | |||
"smlal v18.4s, v12.4h, v1.h[4]\n" | |||
"smlal2 v17.4s, v12.8h, v0.h[4]\n" | |||
"smlal2 v19.4s, v12.8h, v1.h[4]\n" | |||
"smlal2 v17.4s, v13.8h, v0.h[5]\n" | |||
"ld1 {v2.4s, v3.4s}, [%[b_ptr]], 32\n" | |||
"smlal v16.4s, v14.4h, v0.h[6]\n" | |||
"smlal v18.4s, v13.4h, v1.h[5]\n" | |||
"smlal2 v17.4s, v14.8h, v0.h[6]\n" | |||
"smlal2 v19.4s, v13.8h, v1.h[5]\n" | |||
"smull v20.4s, v8.4h, v2.h[0]\n" | |||
"smull v22.4s, v8.4h, v3.h[0]\n" | |||
"smull2 v21.4s, v8.8h, v2.h[0]\n" | |||
"smull2 v23.4s, v8.8h, v3.h[0]\n" | |||
"smlal v16.4s, v15.4h, v0.h[7]\n" | |||
"smlal v18.4s, v14.4h, v1.h[6]\n" | |||
"smlal2 v17.4s, v15.8h, v0.h[7]\n" | |||
"smlal2 v19.4s, v14.8h, v1.h[6]\n" | |||
"smlal v20.4s, v9.4h, v2.h[1]\n" | |||
"smlal v22.4s, v9.4h, v3.h[1]\n" | |||
"smlal2 v21.4s, v9.8h, v2.h[1]\n" | |||
"smlal2 v23.4s, v9.8h, v3.h[1]\n" | |||
"smlal v18.4s, v15.4h, v1.h[7]\n" | |||
"smlal v20.4s, v10.4h, v2.h[2]\n" | |||
"smlal v22.4s, v10.4h, v3.h[2]\n" | |||
"smlal2 v21.4s, v10.8h, v2.h[2]\n" | |||
"smlal2 v23.4s, v10.8h, v3.h[2]\n" | |||
"smlal2 v19.4s, v15.8h, v1.h[7]\n" | |||
"smlal v20.4s, v11.4h, v2.h[3]\n" | |||
"smlal v22.4s, v11.4h, v3.h[3]\n" | |||
"smlal2 v21.4s, v11.8h, v2.h[3]\n" | |||
"smlal2 v23.4s, v11.8h, v3.h[3]\n" | |||
"smlal v20.4s, v12.4h, v2.h[4]\n" | |||
"smlal v22.4s, v12.4h, v3.h[4]\n" | |||
"smlal2 v21.4s, v12.8h, v2.h[4]\n" | |||
"smlal2 v23.4s, v12.8h, v3.h[4]\n" | |||
"smlal v20.4s, v13.4h, v2.h[5]\n" | |||
"smlal v22.4s, v13.4h, v3.h[5]\n" | |||
"smlal2 v21.4s, v13.8h, v2.h[5]\n" | |||
"smlal2 v23.4s, v13.8h, v3.h[5]\n" | |||
"ld1 {v4.4s, v5.4s}, [%[b_ptr]], 32\n" | |||
"smlal v20.4s, v14.4h, v2.h[6]\n" | |||
"smlal v22.4s, v14.4h, v3.h[6]\n" | |||
"smlal2 v21.4s, v14.8h, v2.h[6]\n" | |||
"smlal2 v23.4s, v14.8h, v3.h[6]\n" | |||
"smull v24.4s, v8.4h, v4.h[0]\n" | |||
"smull v26.4s, v8.4h, v5.h[0]\n" | |||
"smull2 v25.4s, v8.8h, v4.h[0]\n" | |||
"smull2 v27.4s, v8.8h, v5.h[0]\n" | |||
"smlal v20.4s, v15.4h, v2.h[7]\n" | |||
"smlal v22.4s, v15.4h, v3.h[7]\n" | |||
"smlal2 v21.4s, v15.8h, v2.h[7]\n" | |||
"smlal2 v23.4s, v15.8h, v3.h[7]\n" | |||
"smlal v24.4s, v9.4h, v4.h[1]\n" | |||
"smlal v26.4s, v9.4h, v5.h[1]\n" | |||
"smlal2 v25.4s, v9.8h, v4.h[1]\n" | |||
"smlal2 v27.4s, v9.8h, v5.h[1]\n" | |||
"smlal v24.4s, v10.4h, v4.h[2]\n" | |||
"smlal v26.4s, v10.4h, v5.h[2]\n" | |||
"smlal2 v25.4s, v10.8h, v4.h[2]\n" | |||
"smlal2 v27.4s, v10.8h, v5.h[2]\n" | |||
"smlal v24.4s, v11.4h, v4.h[3]\n" | |||
"smlal v26.4s, v11.4h, v5.h[3]\n" | |||
"smlal2 v25.4s, v11.8h, v4.h[3]\n" | |||
"smlal2 v27.4s, v11.8h, v5.h[3]\n" | |||
"smlal v24.4s, v12.4h, v4.h[4]\n" | |||
"smlal v26.4s, v12.4h, v5.h[4]\n" | |||
"smlal2 v25.4s, v12.8h, v4.h[4]\n" | |||
"smlal2 v27.4s, v12.8h, v5.h[4]\n" | |||
"smlal v24.4s, v13.4h, v4.h[5]\n" | |||
"smlal v26.4s, v13.4h, v5.h[5]\n" | |||
"smlal2 v25.4s, v13.8h, v4.h[5]\n" | |||
"smlal2 v27.4s, v13.8h, v5.h[5]\n" | |||
"ld1 {v6.4s, v7.4s}, [%[b_ptr]], %x[LDB]\n" | |||
"smlal v24.4s, v14.4h, v4.h[6]\n" | |||
"smlal v26.4s, v14.4h, v5.h[6]\n" | |||
"smlal2 v25.4s, v14.8h, v4.h[6]\n" | |||
"smlal2 v27.4s, v14.8h, v5.h[6]\n" | |||
"smull v28.4s, v8.4h, v6.h[0]\n" | |||
"smull v30.4s, v8.4h, v7.h[0]\n" | |||
"smull2 v29.4s, v8.8h, v6.h[0]\n" | |||
"smull2 v31.4s, v8.8h, v7.h[0]\n" | |||
"smlal v28.4s, v9.4h, v6.h[1]\n" | |||
"smlal v30.4s, v9.4h, v7.h[1]\n" | |||
"smlal2 v29.4s, v9.8h, v6.h[1]\n" | |||
"smlal2 v31.4s, v9.8h, v7.h[1]\n" | |||
"smlal v28.4s, v10.4h, v6.h[2]\n" | |||
"smlal v30.4s, v10.4h, v7.h[2]\n" | |||
"smlal2 v29.4s, v10.8h, v6.h[2]\n" | |||
"smlal2 v31.4s, v10.8h, v7.h[2]\n" | |||
"smlal v28.4s, v11.4h, v6.h[3]\n" | |||
"smlal v30.4s, v11.4h, v7.h[3]\n" | |||
"smlal2 v29.4s, v11.8h, v6.h[3]\n" | |||
"smlal2 v31.4s, v11.8h, v7.h[3]\n" | |||
"smlal v28.4s, v12.4h, v6.h[4]\n" | |||
"smlal v30.4s, v12.4h, v7.h[4]\n" | |||
"smlal2 v29.4s, v12.8h, v6.h[4]\n" | |||
"smlal2 v31.4s, v12.8h, v7.h[4]\n" | |||
"smlal v28.4s, v13.4h, v6.h[5]\n" | |||
"smlal v30.4s, v13.4h, v7.h[5]\n" | |||
"smlal2 v29.4s, v13.8h, v6.h[5]\n" | |||
"smlal2 v31.4s, v13.8h, v7.h[5]\n" | |||
"beq 2f\n" | |||
"1:\n" | |||
"smlal v24.4s, v15.4h, v4.h[7]\n" | |||
"smlal v26.4s, v15.4h, v5.h[7]\n" | |||
"smlal2 v25.4s, v15.8h, v4.h[7]\n" | |||
"ld1 {v8.4s, v9.4s}, [%[a_ptr]], 32\n" | |||
"smlal2 v27.4s, v15.8h, v5.h[7]\n" | |||
"smlal v28.4s, v14.4h, v6.h[6]\n" | |||
"smlal v30.4s, v14.4h, v7.h[6]\n" | |||
"ld1 {v10.4s, v11.4s}, [%[a_ptr]], 32\n" | |||
"smlal2 v29.4s, v15.8h, v6.h[7]\n" | |||
"smlal2 v31.4s, v14.8h, v7.h[6]\n" | |||
"smlal v28.4s, v15.4h, v6.h[7]\n" | |||
"ld1 {v12.4s, v13.4s}, [%[a_ptr]], 32\n" | |||
"smlal v30.4s, v15.4h, v7.h[7]\n" | |||
"smlal2 v29.4s, v14.8h, v6.h[6]\n" | |||
"ld1 {v0.4s}, [%[b_ptr]], 16\n" | |||
"smlal2 v31.4s, v15.8h, v7.h[7]\n" | |||
"smlal v16.4s, v8.4h, v0.h[0]\n" | |||
"ld1 {v1.4s}, [%[b_ptr]], 16\n" | |||
"smlal v16.4s, v9.4h, v0.h[1]\n" | |||
"smlal2 v17.4s, v8.8h, v0.h[0]\n" | |||
"smlal v16.4s, v10.4h, v0.h[2]\n" | |||
"smlal v18.4s, v8.4h, v1.h[0]\n" | |||
"smlal2 v17.4s, v9.8h, v0.h[1]\n" | |||
"smlal2 v19.4s, v8.8h, v1.h[0]\n" | |||
"ld1 {v14.4s, v15.4s}, [%[a_ptr]], 32\n" | |||
"smlal v16.4s, v11.4h, v0.h[3]\n" | |||
"smlal v18.4s, v9.4h, v1.h[1]\n" | |||
"smlal2 v17.4s, v10.8h, v0.h[2]\n" | |||
"smlal2 v19.4s, v9.8h, v1.h[1]\n" | |||
"smlal v16.4s, v12.4h, v0.h[4]\n" | |||
"smlal v18.4s, v10.4h, v1.h[2]\n" | |||
"smlal2 v17.4s, v11.8h, v0.h[3]\n" | |||
"smlal2 v19.4s, v10.8h, v1.h[2]\n" | |||
"smlal v16.4s, v13.4h, v0.h[5]\n" | |||
"smlal v18.4s, v11.4h, v1.h[3]\n" | |||
"smlal2 v17.4s, v12.8h, v0.h[4]\n" | |||
"smlal2 v19.4s, v11.8h, v1.h[3]\n" | |||
"smlal v16.4s, v14.4h, v0.h[6]\n" | |||
"smlal v18.4s, v12.4h, v1.h[4]\n" | |||
"smlal2 v17.4s, v13.8h, v0.h[5]\n" | |||
"smlal2 v19.4s, v12.8h, v1.h[4]\n" | |||
"smlal v16.4s, v15.4h, v0.h[7]\n" | |||
"smlal v18.4s, v13.4h, v1.h[5]\n" | |||
"smlal2 v17.4s, v14.8h, v0.h[6]\n" | |||
"smlal2 v19.4s, v13.8h, v1.h[5]\n" | |||
"ld1 {v2.4s, v3.4s}, [%[b_ptr]], 32\n" | |||
"smlal v18.4s, v14.4h, v1.h[6]\n" | |||
"smlal2 v17.4s, v15.8h, v0.h[7]\n" | |||
"smlal2 v19.4s, v14.8h, v1.h[6]\n" | |||
"smlal v20.4s, v8.4h, v2.h[0]\n" | |||
"smlal v22.4s, v8.4h, v3.h[0]\n" | |||
"smlal2 v21.4s, v8.8h, v2.h[0]\n" | |||
"smlal2 v23.4s, v8.8h, v3.h[0]\n" | |||
"smlal v18.4s, v15.4h, v1.h[7]\n" | |||
"smlal v20.4s, v9.4h, v2.h[1]\n" | |||
"smlal v22.4s, v9.4h, v3.h[1]\n" | |||
"smlal2 v21.4s, v9.8h, v2.h[1]\n" | |||
"smlal2 v23.4s, v9.8h, v3.h[1]\n" | |||
"smlal2 v19.4s, v15.8h, v1.h[7]\n" | |||
"smlal v20.4s, v10.4h, v2.h[2]\n" | |||
"smlal v22.4s, v10.4h, v3.h[2]\n" | |||
"smlal2 v21.4s, v10.8h, v2.h[2]\n" | |||
"smlal2 v23.4s, v10.8h, v3.h[2]\n" | |||
"smlal v20.4s, v11.4h, v2.h[3]\n" | |||
"smlal v22.4s, v11.4h, v3.h[3]\n" | |||
"smlal2 v21.4s, v11.8h, v2.h[3]\n" | |||
"smlal2 v23.4s, v11.8h, v3.h[3]\n" | |||
"smlal v20.4s, v12.4h, v2.h[4]\n" | |||
"smlal v22.4s, v12.4h, v3.h[4]\n" | |||
"smlal2 v21.4s, v12.8h, v2.h[4]\n" | |||
"smlal2 v23.4s, v12.8h, v3.h[4]\n" | |||
"smlal v20.4s, v13.4h, v2.h[5]\n" | |||
"smlal v22.4s, v13.4h, v3.h[5]\n" | |||
"smlal2 v21.4s, v13.8h, v2.h[5]\n" | |||
"smlal2 v23.4s, v13.8h, v3.h[5]\n" | |||
"ld1 {v4.4s, v5.4s}, [%[b_ptr]], 32\n" | |||
"smlal v20.4s, v14.4h, v2.h[6]\n" | |||
"smlal v22.4s, v14.4h, v3.h[6]\n" | |||
"smlal2 v21.4s, v14.8h, v2.h[6]\n" | |||
"smlal2 v23.4s, v14.8h, v3.h[6]\n" | |||
"smlal v24.4s, v8.4h, v4.h[0]\n" | |||
"smlal v26.4s, v8.4h, v5.h[0]\n" | |||
"smlal2 v25.4s, v8.8h, v4.h[0]\n" | |||
"smlal2 v27.4s, v8.8h, v5.h[0]\n" | |||
"smlal v20.4s, v15.4h, v2.h[7]\n" | |||
"smlal2 v21.4s, v15.8h, v2.h[7]\n" | |||
"smlal v22.4s, v15.4h, v3.h[7]\n" | |||
"smlal2 v23.4s, v15.8h, v3.h[7]\n" | |||
"smlal v24.4s, v9.4h, v4.h[1]\n" | |||
"smlal v26.4s, v9.4h, v5.h[1]\n" | |||
"smlal2 v25.4s, v9.8h, v4.h[1]\n" | |||
"smlal2 v27.4s, v9.8h, v5.h[1]\n" | |||
"smlal v24.4s, v10.4h, v4.h[2]\n" | |||
"smlal v26.4s, v10.4h, v5.h[2]\n" | |||
"smlal2 v25.4s, v10.8h, v4.h[2]\n" | |||
"smlal2 v27.4s, v10.8h, v5.h[2]\n" | |||
"smlal v24.4s, v11.4h, v4.h[3]\n" | |||
"smlal v26.4s, v11.4h, v5.h[3]\n" | |||
"smlal2 v25.4s, v11.8h, v4.h[3]\n" | |||
"smlal2 v27.4s, v11.8h, v5.h[3]\n" | |||
"smlal v24.4s, v12.4h, v4.h[4]\n" | |||
"smlal v26.4s, v12.4h, v5.h[4]\n" | |||
"smlal2 v25.4s, v12.8h, v4.h[4]\n" | |||
"smlal2 v27.4s, v12.8h, v5.h[4]\n" | |||
"smlal v24.4s, v13.4h, v4.h[5]\n" | |||
"smlal v26.4s, v13.4h, v5.h[5]\n" | |||
"smlal2 v25.4s, v13.8h, v4.h[5]\n" | |||
"smlal2 v27.4s, v13.8h, v5.h[5]\n" | |||
"ld1 {v6.4s, v7.4s}, [%[b_ptr]], %x[LDB]\n" | |||
"smlal v24.4s, v14.4h, v4.h[6]\n" | |||
"smlal v26.4s, v14.4h, v5.h[6]\n" | |||
"smlal2 v25.4s, v14.8h, v4.h[6]\n" | |||
"smlal2 v27.4s, v14.8h, v5.h[6]\n" | |||
"smlal v28.4s, v8.4h, v6.h[0]\n" | |||
"smlal v30.4s, v8.4h, v7.h[0]\n" | |||
"smlal2 v29.4s, v8.8h, v6.h[0]\n" | |||
"smlal2 v31.4s, v8.8h, v7.h[0]\n" | |||
"smlal v28.4s, v9.4h, v6.h[1]\n" | |||
"smlal v30.4s, v9.4h, v7.h[1]\n" | |||
"smlal2 v29.4s, v9.8h, v6.h[1]\n" | |||
"smlal2 v31.4s, v9.8h, v7.h[1]\n" | |||
"smlal v28.4s, v10.4h, v6.h[2]\n" | |||
"smlal v30.4s, v10.4h, v7.h[2]\n" | |||
"smlal2 v29.4s, v10.8h, v6.h[2]\n" | |||
"smlal2 v31.4s, v10.8h, v7.h[2]\n" | |||
"smlal v28.4s, v11.4h, v6.h[3]\n" | |||
"smlal v30.4s, v11.4h, v7.h[3]\n" | |||
"smlal2 v29.4s, v11.8h, v6.h[3]\n" | |||
"smlal2 v31.4s, v11.8h, v7.h[3]\n" | |||
"smlal v28.4s, v12.4h, v6.h[4]\n" | |||
"smlal v30.4s, v12.4h, v7.h[4]\n" | |||
"smlal2 v29.4s, v12.8h, v6.h[4]\n" | |||
"smlal2 v31.4s, v12.8h, v7.h[4]\n" | |||
"smlal v28.4s, v13.4h, v6.h[5]\n" | |||
"smlal v30.4s, v13.4h, v7.h[5]\n" | |||
"smlal2 v29.4s, v13.8h, v6.h[5]\n" | |||
"smlal2 v31.4s, v13.8h, v7.h[5]\n" | |||
"subs %w[K], %w[K], #8\n" | |||
"bne 1b\n" | |||
"2:\n" | |||
"st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[output]], 64\n" | |||
"smlal v24.4s, v15.4h, v4.h[7]\n" | |||
"smlal v28.4s, v14.4h, v6.h[6]\n" | |||
"smlal v30.4s, v14.4h, v7.h[6]\n" | |||
"smlal v26.4s, v15.4h, v5.h[7]\n" | |||
"smlal2 v25.4s, v15.8h, v4.h[7]\n" | |||
"smlal2 v27.4s, v15.8h, v5.h[7]\n" | |||
"smlal2 v29.4s, v14.8h, v6.h[6]\n" | |||
"smlal2 v31.4s, v14.8h, v7.h[6]\n" | |||
"st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%[output]], 64\n" | |||
"smlal v28.4s, v15.4h, v6.h[7]\n" | |||
"smlal v30.4s, v15.4h, v7.h[7]\n" | |||
"smlal2 v29.4s, v15.8h, v6.h[7]\n" | |||
"smlal2 v31.4s, v15.8h, v7.h[7]\n" | |||
"st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%[output]], 64\n" | |||
"st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%[output]], 64\n" | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
[output] "+r"(output), [LDB] "+r"(LDB) | |||
: | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||
"v29", "v30", "v31", "cc", "memory"); | |||
} | |||
} // anonymous namespace | |||
MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(gemm_nopack_s16_8x8); | |||
void gemm_nopack_s16_8x8::kern(const dt_int16* A, size_t LDA, const dt_int16* B, | |||
size_t LDB, dt_int32* C, size_t LDC, size_t M, | |||
size_t K, size_t N, const dt_int32*, void*, | |||
bool trA, bool trB) const { | |||
constexpr static size_t MB = 8; | |||
constexpr static size_t KB = 8; | |||
constexpr static size_t NB = 8; | |||
constexpr static size_t CALCBLK = 4; | |||
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0 && N % CALCBLK == 0); | |||
//! (m/8, k/8, 8, 8) * (k/8, n, 8) = (m/8, n, 8) | |||
for (size_t m = 0; m < M; m += MB) { | |||
dt_int32* output = C + (m / MB) * LDC; | |||
const dt_int16* cur_B = B; | |||
size_t n = 0; | |||
for (; n + NB - 1 < N; n += NB) { | |||
kern_8x8(A, cur_B, LDB, K, output); | |||
cur_B += KB * NB; | |||
output += MB * NB; | |||
} | |||
if (n < N) { | |||
kern_8x4(A, cur_B, LDB, K, output); | |||
} | |||
A += LDA; | |||
} | |||
} | |||
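//! Editorial sketch (assumption: not part of the original source; the helper | |||
//! name is hypothetical).  The mk8 analogue of the fp32 layout equation: | |||
//! (m/8, k/8, 8, 8) int16 * (k/8, n, 8) int16 = (m/8, n, 8) int32, with each | |||
//! product widened to 32bit before accumulation, as smull/smlal do above. | |||
static inline void mk8_gemm_s16_ref(const dt_int16* A, size_t LDA, | |||
                                    const dt_int16* B, size_t LDB, | |||
                                    dt_int32* C, size_t LDC, size_t M, | |||
                                    size_t K, size_t N) { | |||
    for (size_t mb = 0; mb < M / 8; ++mb) | |||
        for (size_t n = 0; n < N; ++n) | |||
            for (size_t i = 0; i < 8; ++i) { | |||
                dt_int32 acc = 0; | |||
                for (size_t kb = 0; kb < K / 8; ++kb) | |||
                    for (size_t kk = 0; kk < 8; ++kk) | |||
                        acc += dt_int32(A[mb * LDA + (kb * 8 + kk) * 8 + i]) * | |||
                               dt_int32(B[kb * LDB + n * 8 + kk]); | |||
                C[mb * LDC + n * 8 + i] = acc; | |||
            } | |||
} | |||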
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,856 @@ | |||
/** | |||
* \file dnn/src/aarch64/matrix_mul/int8/kernel_4x4x16.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#if !(__ARM_FEATURE_DOTPROD) | |||
#include "src/aarch64/matrix_mul/asm/common.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
namespace megdnn { | |||
namespace aarch64 { | |||
namespace matmul_4x4x16 { | |||
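//! Editorial sketch (assumption: not part of the original source): the int16 | |||
//! range argument made in the kernel comment below, spelled out as compile | |||
//! time checks. | |||
static_assert(-128 * -128 + -128 * -128 == 32768, | |||
              "two unrestricted int8 products can overflow int16 (max 32767)"); | |||
static_assert(-127 * -128 + -127 * -128 <= 32767, | |||
              "with one operand clamped to [-127, 127] the sum fits in int16"); | |||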
/** | |||
* Overview of register layout: | |||
* | |||
 * A 16x4 cell of Rhs is stored in 8bit in v4-v7 (v8-v11 for the second | |||
 * unroll). | |||
 * A 16x4 cell of Lhs is stored in 8bit in v0-v3. | |||
 * A 4x4 block of accumulators is stored in 32bit in v16-v31. | |||
* | |||
* \warning Fast kernel operating on int8 operands. | |||
* It is assumed that one of the two int8 operands only takes values | |||
* in [-127, 127], while the other may freely range in [-128, 127]. | |||
* The issue with both operands taking the value -128 is that: | |||
 *   -128*-128 + -128*-128 == 32768 overflows int16. | |||
* Every other expression a*b + c*d, for any int8 a,b,c,d, fits in int16 | |||
* range. That is the basic idea of this kernel. | |||
* | |||
* | |||
* +--------+--------+---------+---------+ | |||
* |v4[0-16]|v5[0-16]| v6[0-16]| v7[0-16]| | |||
* Rhs +--------+--------+---------+---------+ | |||
* |v8[0-16]|v9[0-16]|v10[0-16]|v11[0-16]| | |||
* +--------+--------+---------+---------+ | |||
* | | | | | | |||
* | |||
* Lhs | | | | | | |||
* | |||
* +--------+ - - - - +-------------------------------------+ | |||
* |v0[0-16]| |v16[0-4]|v17[0-4]| v18[0-4]| v19[0-4]| | |||
* |v1[0-16]| |v20[0-4]|v21[0-4]| v22[0-4]| v23[0-4]| | |||
* |v2[0-16]| |v24[0-4]|v25[0-4]| v26[0-4]| v27[0-4]| | |||
* |v3[0-16]| |v28[0-4]|v29[0-4]| v30[0-4]| v31[0-4]| | |||
* +--------+ - - - - +-------------------------------------+ | |||
* | |||
* Accumulator | |||
*/ | |||
static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||
int32_t* output, int LDC, bool is_first_k) { | |||
K /= 16; | |||
const int8_t* a_ptr = packA; | |||
const int8_t* b_ptr = packB; | |||
// Fix up for odd lengths - set a flag if K is odd, but make | |||
// sure we round up the iteration count. | |||
int oddk = (K & 1); | |||
int k = ((K + 1) / 2) - 1; | |||
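    // e.g. K = 5 (16-deep blocks): oddk = 1, k = 2, so the main loop runs
    // twice (two blocks per iteration) and the odd-K tail handles the fifth.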
LDC = LDC * sizeof(int32_t); | |||
asm volatile ( | |||
// load accumulator C | |||
"add x1, %[output], %x[LDC]\n" | |||
"add x2, x1, %x[LDC]\n" | |||
"add x3, x2, %x[LDC]\n" | |||
"cmp %w[is_first_k], #1\n" | |||
"beq 1f\n" | |||
"ldr q16, [%[output]]\n" | |||
"ldr q17, [x1]\n" | |||
"ldr q18, [x2]\n" | |||
"ldr q19, [x3]\n" | |||
"b 2f\n" | |||
"1:\n" | |||
"eor v16.16b, v16.16b, v16.16b\n" | |||
"eor v17.16b, v17.16b, v17.16b\n" | |||
"eor v18.16b, v18.16b, v18.16b\n" | |||
"eor v19.16b, v19.16b, v19.16b\n" | |||
"2: \n" | |||
"ldr q0, [%[a_ptr]]\n" | |||
"ldr q4, [%[b_ptr]]\n" | |||
"ldr q5, [%[b_ptr], #16]\n" | |||
"ldr q6, [%[b_ptr], #32]\n" | |||
"movi v20.4s, #0x0\n" | |||
"ldr q7, [%[b_ptr], #48]\n" | |||
"movi v21.4s, #0x0\n" | |||
"ldr q1, [%[a_ptr], #16]\n" | |||
"movi v22.4s, #0x0\n" | |||
"ldr q2, [%[a_ptr], #32]\n" | |||
"movi v23.4s, #0x0\n" | |||
"ldr q3, [%[a_ptr], #48]\n" | |||
"movi v24.4s, #0x0\n" | |||
ASM_PREFETCH("[%[b_ptr], #64]") | |||
"movi v25.4s, #0x0\n" | |||
ASM_PREFETCH("[%[a_ptr], #64]") | |||
"movi v26.4s, #0x0\n" | |||
ASM_PREFETCH("[%[b_ptr], #128]") | |||
"movi v27.4s, #0x0\n" | |||
ASM_PREFETCH("[%[a_ptr], #128]") | |||
"movi v28.4s, #0x0\n" | |||
ASM_PREFETCH("[%[b_ptr], #192]") | |||
"movi v29.4s, #0x0\n" | |||
ASM_PREFETCH("[%[a_ptr], #192]") | |||
"movi v30.4s, #0x0\n" | |||
ASM_PREFETCH("[%[b_ptr], #256]") | |||
"movi v31.4s, #0x0\n" | |||
ASM_PREFETCH("[%[a_ptr], #256]") | |||
// Start of unroll 0 (first iteration) | |||
"smull v12.8h, v0.8b, v4.8b\n" | |||
"smull v13.8h, v0.8b, v5.8b\n" | |||
// Skip loop if we are doing zero iterations of it. | |||
"cbz %w[k], 4f\n" | |||
// Unroll 0 continuation (branch target) | |||
"3:\n" | |||
"smull v14.8h, v0.8b, v6.8b\n" | |||
"subs %w[k], %w[k], #1\n" | |||
"smull v15.8h, v0.8b, v7.8b\n" | |||
"ldr q8, [%[b_ptr], #64]\n" | |||
"smlal2 v12.8h, v0.16b, v4.16b\n" | |||
"smlal2 v13.8h, v0.16b, v5.16b\n" | |||
"ldr q9, [%[b_ptr], #80]\n" | |||
"smlal2 v14.8h, v0.16b, v6.16b\n" | |||
"smlal2 v15.8h, v0.16b, v7.16b\n" | |||
"ldr q0, [%[a_ptr], #64]\n" | |||
"sadalp v16.4s, v12.8h\n" | |||
"smull v12.8h, v1.8b, v4.8b\n" | |||
"sadalp v17.4s, v13.8h\n" | |||
"sadalp v18.4s, v14.8h\n" | |||
"smull v13.8h, v1.8b, v5.8b\n" | |||
"sadalp v19.4s, v15.8h\n" | |||
"smull v14.8h, v1.8b, v6.8b\n" | |||
"ldr q10, [%[b_ptr], #96]\n" | |||
"smull v15.8h, v1.8b, v7.8b\n" | |||
"smlal2 v12.8h, v1.16b, v4.16b\n" | |||
"ldr q11, [%[b_ptr], #112]\n" | |||
"smlal2 v13.8h, v1.16b, v5.16b\n" | |||
"add %[b_ptr], %[b_ptr], #128\n" | |||
"smlal2 v14.8h, v1.16b, v6.16b\n" | |||
"smlal2 v15.8h, v1.16b, v7.16b\n" | |||
"ldr q1, [%[a_ptr], #80]\n" | |||
"sadalp v20.4s, v12.8h\n" | |||
"smull v12.8h, v2.8b, v4.8b\n" | |||
"sadalp v21.4s, v13.8h\n" | |||
"sadalp v22.4s, v14.8h\n" | |||
"smull v13.8h, v2.8b, v5.8b\n" | |||
"sadalp v23.4s, v15.8h\n" | |||
"smull v14.8h, v2.8b, v6.8b\n" | |||
"smull v15.8h, v2.8b, v7.8b\n" | |||
"smlal2 v12.8h, v2.16b, v4.16b\n" | |||
ASM_PREFETCH("[%[b_ptr], #192]") | |||
"smlal2 v13.8h, v2.16b, v5.16b\n" | |||
"smlal2 v14.8h, v2.16b, v6.16b\n" | |||
ASM_PREFETCH("[%[a_ptr], #320]") | |||
"smlal2 v15.8h, v2.16b, v7.16b\n" | |||
"ldr q2, [%[a_ptr], #96]\n" | |||
"sadalp v24.4s, v12.8h\n" | |||
"smull v12.8h, v3.8b, v4.8b\n" | |||
"sadalp v25.4s, v13.8h\n" | |||
"sadalp v26.4s, v14.8h\n" | |||
"smull v13.8h, v3.8b, v5.8b\n" | |||
"sadalp v27.4s, v15.8h\n" | |||
"smull v14.8h, v3.8b, v6.8b\n" | |||
"smull v15.8h, v3.8b, v7.8b\n" | |||
"smlal2 v12.8h, v3.16b, v4.16b\n" | |||
"ldr q4, [%[b_ptr], #0]\n" | |||
"smlal2 v13.8h, v3.16b, v5.16b\n" | |||
"smlal2 v14.8h, v3.16b, v6.16b\n" | |||
"smlal2 v15.8h, v3.16b, v7.16b\n" | |||
"ldr q3, [%[a_ptr], #112]\n" | |||
// Unroll 1 | |||
"sadalp v28.4s, v12.8h\n" | |||
"smull v12.8h, v0.8b, v8.8b\n" | |||
"sadalp v29.4s, v13.8h\n" | |||
"sadalp v30.4s, v14.8h\n" | |||
"smull v13.8h, v0.8b, v9.8b\n" | |||
"sadalp v31.4s, v15.8h\n" | |||
"smull v14.8h, v0.8b, v10.8b\n" | |||
"smull v15.8h, v0.8b, v11.8b\n" | |||
"ldr q5, [%[b_ptr], #16]\n" | |||
"smlal2 v12.8h, v0.16b, v8.16b\n" | |||
"smlal2 v13.8h, v0.16b, v9.16b\n" | |||
"ldr q6, [%[b_ptr], #32]\n" | |||
"smlal2 v14.8h, v0.16b, v10.16b\n" | |||
"smlal2 v15.8h, v0.16b, v11.16b\n" | |||
"ldr q0, [%[a_ptr], #128]\n" | |||
"sadalp v16.4s, v12.8h\n" | |||
"smull v12.8h, v1.8b, v8.8b\n" | |||
"sadalp v17.4s, v13.8h\n" | |||
"sadalp v18.4s, v14.8h\n" | |||
"smull v13.8h, v1.8b, v9.8b\n" | |||
"sadalp v19.4s, v15.8h\n" | |||
"add %[a_ptr], %[a_ptr], #128\n" | |||
"smull v14.8h, v1.8b, v10.8b\n" | |||
"smull v15.8h, v1.8b, v11.8b\n" | |||
"ldr q7, [%[b_ptr], #48]\n" | |||
"smlal2 v12.8h, v1.16b, v8.16b\n" | |||
"smlal2 v13.8h, v1.16b, v9.16b\n" | |||
"smlal2 v14.8h, v1.16b, v10.16b\n" | |||
"smlal2 v15.8h, v1.16b, v11.16b\n" | |||
"ldr q1, [%[a_ptr], #16]\n" | |||
"sadalp v20.4s, v12.8h\n" | |||
"smull v12.8h, v2.8b, v8.8b\n" | |||
"sadalp v21.4s, v13.8h\n" | |||
"sadalp v22.4s, v14.8h\n" | |||
"smull v13.8h, v2.8b, v9.8b\n" | |||
"sadalp v23.4s, v15.8h\n" | |||
"smull v14.8h, v2.8b, v10.8b\n" | |||
"smull v15.8h, v2.8b, v11.8b\n" | |||
"smlal2 v12.8h, v2.16b, v8.16b\n" | |||
ASM_PREFETCH("[%[b_ptr], #256]") | |||
"smlal2 v13.8h, v2.16b, v9.16b\n" | |||
"smlal2 v14.8h, v2.16b, v10.16b\n" | |||
ASM_PREFETCH("[%[a_ptr], #256]") | |||
"smlal2 v15.8h, v2.16b, v11.16b\n" | |||
"ldr q2, [%[a_ptr], #32]\n" | |||
"sadalp v24.4s, v12.8h\n" | |||
"smull v12.8h, v3.8b, v8.8b\n" | |||
"sadalp v25.4s, v13.8h\n" | |||
"sadalp v26.4s, v14.8h\n" | |||
"smull v13.8h, v3.8b, v9.8b\n" | |||
"sadalp v27.4s, v15.8h\n" | |||
"smull v14.8h, v3.8b, v10.8b\n" | |||
"smull v15.8h, v3.8b, v11.8b\n" | |||
"smlal2 v12.8h, v3.16b, v8.16b\n" | |||
"smlal2 v13.8h, v3.16b, v9.16b\n" | |||
"smlal2 v14.8h, v3.16b, v10.16b\n" | |||
"smlal2 v15.8h, v3.16b, v11.16b\n" | |||
"ldr q3, [%[a_ptr], #48]\n" | |||
// Start of unroll 0 for next iteration. | |||
"sadalp v28.4s, v12.8h\n" | |||
"smull v12.8h, v0.8b, v4.8b\n" | |||
"sadalp v29.4s, v13.8h\n" | |||
"sadalp v30.4s, v14.8h\n" | |||
"smull v13.8h, v0.8b, v5.8b\n" | |||
"sadalp v31.4s, v15.8h\n" | |||
"bne 3b\n" | |||
// Target to use when K=1 or 2 (i.e. zero iterations of main loop) | |||
"4:\n" | |||
// Branch to alternative tail for odd K | |||
"cbnz %w[oddk], 5f\n" | |||
// Detached final iteration (even K) | |||
"smull v14.8h, v0.8b, v6.8b\n" | |||
"smull v15.8h, v0.8b, v7.8b\n" | |||
"ldr q8, [%[b_ptr], #64]\n" | |||
"smlal2 v12.8h, v0.16b, v4.16b\n" | |||
"smlal2 v13.8h, v0.16b, v5.16b\n" | |||
"ldr q9, [%[b_ptr], #80]\n" | |||
"smlal2 v14.8h, v0.16b, v6.16b\n" | |||
"smlal2 v15.8h, v0.16b, v7.16b\n" | |||
"ldr q0, [%[a_ptr], #64]\n" | |||
"sadalp v16.4s, v12.8h\n" | |||
"smull v12.8h, v1.8b, v4.8b\n" | |||
"sadalp v17.4s, v13.8h\n" | |||
"sadalp v18.4s, v14.8h\n" | |||
"smull v13.8h, v1.8b, v5.8b\n" | |||
"sadalp v19.4s, v15.8h\n" | |||
"smull v14.8h, v1.8b, v6.8b\n" | |||
"ldr q10, [%[b_ptr], #96]\n" | |||
"smull v15.8h, v1.8b, v7.8b\n" | |||
"smlal2 v12.8h, v1.16b, v4.16b\n" | |||
"ldr q11, [%[b_ptr], #112]\n" | |||
"smlal2 v13.8h, v1.16b, v5.16b\n" | |||
"add %[b_ptr], %[b_ptr], #128\n" | |||
"smlal2 v14.8h, v1.16b, v6.16b\n" | |||
"smlal2 v15.8h, v1.16b, v7.16b\n" | |||
"ldr q1, [%[a_ptr], #80]\n" | |||
"sadalp v20.4s, v12.8h\n" | |||
"smull v12.8h, v2.8b, v4.8b\n" | |||
"sadalp v21.4s, v13.8h\n" | |||
"sadalp v22.4s, v14.8h\n" | |||
"smull v13.8h, v2.8b, v5.8b\n" | |||
"sadalp v23.4s, v15.8h\n" | |||
"smull v14.8h, v2.8b, v6.8b\n" | |||
"smull v15.8h, v2.8b, v7.8b\n" | |||
"smlal2 v12.8h, v2.16b, v4.16b\n" | |||
"smlal2 v13.8h, v2.16b, v5.16b\n" | |||
"smlal2 v14.8h, v2.16b, v6.16b\n" | |||
"smlal2 v15.8h, v2.16b, v7.16b\n" | |||
"ldr q2, [%[a_ptr], #96]\n" | |||
"sadalp v24.4s, v12.8h\n" | |||
"smull v12.8h, v3.8b, v4.8b\n" | |||
"sadalp v25.4s, v13.8h\n" | |||
"sadalp v26.4s, v14.8h\n" | |||
"smull v13.8h, v3.8b, v5.8b\n" | |||
"sadalp v27.4s, v15.8h\n" | |||
"smull v14.8h, v3.8b, v6.8b\n" | |||
"smull v15.8h, v3.8b, v7.8b\n" | |||
"smlal2 v12.8h, v3.16b, v4.16b\n" | |||
"smlal2 v13.8h, v3.16b, v5.16b\n" | |||
"smlal2 v14.8h, v3.16b, v6.16b\n" | |||
"smlal2 v15.8h, v3.16b, v7.16b\n" | |||
"ldr q3, [%[a_ptr], #112]\n" | |||
// Unroll 1 | |||
"sadalp v28.4s, v12.8h\n" | |||
"smull v12.8h, v0.8b, v8.8b\n" | |||
"sadalp v29.4s, v13.8h\n" | |||
"sadalp v30.4s, v14.8h\n" | |||
"smull v13.8h, v0.8b, v9.8b\n" | |||
"sadalp v31.4s, v15.8h\n" | |||
"smull v14.8h, v0.8b, v10.8b\n" | |||
"add %[a_ptr], %[a_ptr], #128\n" | |||
"smull v15.8h, v0.8b, v11.8b\n" | |||
"smlal2 v12.8h, v0.16b, v8.16b\n" | |||
"smlal2 v13.8h, v0.16b, v9.16b\n" | |||
"smlal2 v14.8h, v0.16b, v10.16b\n" | |||
"smlal2 v15.8h, v0.16b, v11.16b\n" | |||
"sadalp v16.4s, v12.8h\n" | |||
"smull v12.8h, v1.8b, v8.8b\n" | |||
"sadalp v17.4s, v13.8h\n" | |||
"sadalp v18.4s, v14.8h\n" | |||
"smull v13.8h, v1.8b, v9.8b\n" | |||
"sadalp v19.4s, v15.8h\n" | |||
"smull v14.8h, v1.8b, v10.8b\n" | |||
"smull v15.8h, v1.8b, v11.8b\n" | |||
"smlal2 v12.8h, v1.16b, v8.16b\n" | |||
"addp v16.4s, v16.4s, v17.4s\n" | |||
"smlal2 v13.8h, v1.16b, v9.16b\n" | |||
"addp v17.4s, v18.4s, v19.4s\n" | |||
"smlal2 v14.8h, v1.16b, v10.16b\n" | |||
"smlal2 v15.8h, v1.16b, v11.16b\n" | |||
"sadalp v20.4s, v12.8h\n" | |||
"smull v12.8h, v2.8b, v8.8b\n" | |||
"sadalp v21.4s, v13.8h\n" | |||
"sadalp v22.4s, v14.8h\n" | |||
"smull v13.8h, v2.8b, v9.8b\n" | |||
"sadalp v23.4s, v15.8h\n" | |||
"addp v16.4s, v16.4s, v17.4s\n" | |||
"smull v14.8h, v2.8b, v10.8b\n" | |||
"addp v18.4s, v20.4s, v21.4s\n" | |||
"addp v19.4s, v22.4s, v23.4s\n" | |||
"smull v15.8h, v2.8b, v11.8b\n" | |||
"smlal2 v12.8h, v2.16b, v8.16b\n" | |||
"str q16, [%[output]]\n" | |||
"smlal2 v13.8h, v2.16b, v9.16b\n" | |||
"smlal2 v14.8h, v2.16b, v10.16b\n" | |||
"smlal2 v15.8h, v2.16b, v11.16b\n" | |||
"sadalp v24.4s, v12.8h\n" | |||
"smull v12.8h, v3.8b, v8.8b\n" | |||
"sadalp v25.4s, v13.8h\n" | |||
"sadalp v26.4s, v14.8h\n" | |||
"smull v13.8h, v3.8b, v9.8b\n" | |||
"sadalp v27.4s, v15.8h\n" | |||
"addp v17.4s, v18.4s, v19.4s\n" | |||
"smull v14.8h, v3.8b, v10.8b\n" | |||
"addp v20.4s, v24.4s, v25.4s\n" | |||
"addp v21.4s, v26.4s, v27.4s\n" | |||
"smull v15.8h, v3.8b, v11.8b\n" | |||
"smlal2 v12.8h, v3.16b, v8.16b\n" | |||
"str q17, [x1]\n" | |||
"smlal2 v13.8h, v3.16b, v9.16b\n" | |||
"smlal2 v14.8h, v3.16b, v10.16b\n" | |||
"addp v18.4s, v20.4s, v21.4s\n" | |||
"smlal2 v15.8h, v3.16b, v11.16b\n" | |||
"b 6f\n" | |||
// Detached final iteration (odd K) | |||
"5:\n" | |||
"smull v14.8h, v0.8b, v6.8b\n" | |||
"add %[a_ptr], %[a_ptr], #64\n" | |||
"smull v15.8h, v0.8b, v7.8b\n" | |||
"add %[b_ptr], %[b_ptr], #64\n" | |||
"smlal2 v12.8h, v0.16b, v4.16b\n" | |||
"smlal2 v13.8h, v0.16b, v5.16b\n" | |||
"smlal2 v14.8h, v0.16b, v6.16b\n" | |||
"smlal2 v15.8h, v0.16b, v7.16b\n" | |||
"sadalp v16.4s, v12.8h\n" | |||
"smull v12.8h, v1.8b, v4.8b\n" | |||
"sadalp v17.4s, v13.8h\n" | |||
"sadalp v18.4s, v14.8h\n" | |||
"smull v13.8h, v1.8b, v5.8b\n" | |||
"sadalp v19.4s, v15.8h\n" | |||
"smull v14.8h, v1.8b, v6.8b\n" | |||
"smull v15.8h, v1.8b, v7.8b\n" | |||
"smlal2 v12.8h, v1.16b, v4.16b\n" | |||
"addp v16.4s, v16.4s, v17.4s\n" | |||
"smlal2 v13.8h, v1.16b, v5.16b\n" | |||
"addp v17.4s, v18.4s, v19.4s\n" | |||
"smlal2 v14.8h, v1.16b, v6.16b\n" | |||
"smlal2 v15.8h, v1.16b, v7.16b\n" | |||
"sadalp v20.4s, v12.8h\n" | |||
"smull v12.8h, v2.8b, v4.8b\n" | |||
"sadalp v21.4s, v13.8h\n" | |||
"sadalp v22.4s, v14.8h\n" | |||
"smull v13.8h, v2.8b, v5.8b\n" | |||
"sadalp v23.4s, v15.8h\n" | |||
"addp v16.4s, v16.4s, v17.4s\n" | |||
"smull v14.8h, v2.8b, v6.8b\n" | |||
"addp v18.4s, v20.4s, v21.4s\n" | |||
"addp v19.4s, v22.4s, v23.4s\n" | |||
"smull v15.8h, v2.8b, v7.8b\n" | |||
"smlal2 v12.8h, v2.16b, v4.16b\n" | |||
"str q16, [%[output]]\n" | |||
"smlal2 v13.8h, v2.16b, v5.16b\n" | |||
"smlal2 v14.8h, v2.16b, v6.16b\n" | |||
"smlal2 v15.8h, v2.16b, v7.16b\n" | |||
"sadalp v24.4s, v12.8h\n" | |||
"smull v12.8h, v3.8b, v4.8b\n" | |||
"sadalp v25.4s, v13.8h\n" | |||
"sadalp v26.4s, v14.8h\n" | |||
"smull v13.8h, v3.8b, v5.8b\n" | |||
"sadalp v27.4s, v15.8h\n" | |||
"addp v17.4s, v18.4s, v19.4s\n" | |||
"smull v14.8h, v3.8b, v6.8b\n" | |||
"addp v20.4s, v24.4s, v25.4s\n" | |||
"addp v21.4s, v26.4s, v27.4s\n" | |||
"smull v15.8h, v3.8b, v7.8b\n" | |||
"smlal2 v12.8h, v3.16b, v4.16b\n" | |||
"str q17, [x1]\n" | |||
"smlal2 v13.8h, v3.16b, v5.16b\n" | |||
"smlal2 v14.8h, v3.16b, v6.16b\n" | |||
"addp v18.4s, v20.4s, v21.4s\n" | |||
"smlal2 v15.8h, v3.16b, v7.16b\n" | |||
"6:\n" | |||
// Final additions | |||
"sadalp v28.4s, v12.8h\n" | |||
"str q18, [x2]\n" | |||
"sadalp v29.4s, v13.8h\n" | |||
"sadalp v30.4s, v14.8h\n" | |||
"sadalp v31.4s, v15.8h\n" | |||
// Horizontal reduction, phase 1 | |||
"addp v22.4s, v28.4s, v29.4s\n" | |||
"addp v23.4s, v30.4s, v31.4s\n" | |||
// Horizontal reduction, phase 2 | |||
"addp v19.4s, v22.4s, v23.4s\n" | |||
"str q19, [x3]\n" | |||
: | |||
[a_ptr] "+r" (a_ptr), [b_ptr] "+r" (b_ptr), [oddk] "+r" (oddk), | |||
[is_first_k] "+r" (is_first_k), [k] "+r" (k), [LDC] "+r" (LDC), | |||
[output] "+r"(output) | |||
: | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", | |||
"v10", "v11", "v12", "v13", "v14", "v15", "v16", | |||
"v17", "v18", "v19", "v20","v21","v22","v23","v24","v25","v26", | |||
"v27","v28","v29","v30","v31", "x1", "x2", "x3", | |||
"cc", "memory" | |||
); | |||
} | |||
static void kern_4x4_remain(const int8_t* packA, const int8_t* packB, int K, | |||
int32_t* output, int LDC, bool is_first_k, | |||
int m_remain, int n_remain) { | |||
megdnn_assert(K > 0); | |||
K /= 16; | |||
const int8_t* a_ptr = packA; | |||
const int8_t* b_ptr = packB; | |||
LDC = LDC * sizeof(int32_t); | |||
// clang-format off | |||
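//! LOAD_LINE(reg_index, n): load one row of C into q<reg_index> from the row
//! base pointer x<n>; x4 holds the rows still to process (m_remain) and
//! n_remain selects a full 16-byte load or per-lane loads. STORE_LINE below
//! mirrors the same pattern for stores.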
#define LOAD_LINE(reg_index, n) \ | |||
"cbz x4, 102f\n" \ | |||
"mov x5, x" n "\n" \ | |||
"cmp %w[n_remain], #4\n" \ | |||
"blt 100" n "f\n" \ | |||
"ldr q" reg_index ", [x5]\n" \ | |||
"b 101" n "f\n" \ | |||
"100" n ":\n" \ | |||
"cmp %w[n_remain], #0\n" \ | |||
"beq 101" n "f\n" \ | |||
"ld1 {v" reg_index ".s}[0], [x5], #4\n" \ | |||
"cmp %w[n_remain], #1\n" \ | |||
"beq 101" n "f\n" \ | |||
"ld1 {v" reg_index ".s}[1], [x5], #4\n" \ | |||
"cmp %w[n_remain], #2\n" \ | |||
"beq 101" n "f\n" \ | |||
"ld1 {v" reg_index ".s}[2], [x5], #4\n" \ | |||
"101" n ":\n" \ | |||
"subs x4, x4, #1\n" | |||
#define LOAD_C \ | |||
"mov x4, %x[m_remain]\n" \ | |||
LOAD_LINE("16", "0") \ | |||
LOAD_LINE("17", "1") \ | |||
LOAD_LINE("18", "2") \ | |||
LOAD_LINE("19", "3") \ | |||
"102:\n" | |||
#define STORE_LINE(reg_index, n) \ | |||
"cbz x4, 105f\n" \ | |||
"mov x5, x" n "\n" \ | |||
"cmp %w[n_remain], #4\n" \ | |||
"blt 103" n "f\n" \ | |||
"str q" reg_index ", [x5]\n" \ | |||
"b 104" n "f\n" \ | |||
"103" n ":\n" \ | |||
"cmp %w[n_remain], #0\n" \ | |||
"beq 104" n "f\n" \ | |||
"st1 {v" reg_index ".s}[0], [x5], #4\n" \ | |||
"cmp %w[n_remain], #1\n" \ | |||
"beq 104" n "f\n" \ | |||
"st1 {v" reg_index ".s}[1], [x5], #4\n" \ | |||
"cmp %w[n_remain], #2\n" \ | |||
"beq 104" n "f\n" \ | |||
"st1 {v" reg_index ".s}[2], [x5], #4\n" \ | |||
"104" n ":\n" \ | |||
"subs x4, x4, #1\n" | |||
#define STORE_C \ | |||
"mov x4, %x[m_remain]\n" \ | |||
STORE_LINE("16", "0") \ | |||
STORE_LINE("17", "1") \ | |||
STORE_LINE("18", "2") \ | |||
STORE_LINE("19", "3") \ | |||
"105:\n" | |||
// clang-format on | |||
asm volatile( | |||
// load accumulator C | |||
"mov x0, %[output]\n" | |||
"add x1, x0, %x[LDC]\n" | |||
"add x2, x1, %x[LDC]\n" | |||
"add x3, x2, %x[LDC]\n" | |||
"cmp %w[is_first_k], #1\n" | |||
"beq 1f\n" | |||
LOAD_C // | |||
"b 2f\n" | |||
"1:\n" | |||
"eor v16.16b, v16.16b, v16.16b\n" | |||
"eor v17.16b, v17.16b, v17.16b\n" | |||
"eor v18.16b, v18.16b, v18.16b\n" | |||
"eor v19.16b, v19.16b, v19.16b\n" | |||
"eor v20.16b, v20.16b, v20.16b\n" | |||
"eor v21.16b, v21.16b, v21.16b\n" | |||
"eor v22.16b, v22.16b, v22.16b\n" | |||
"eor v23.16b, v23.16b, v23.16b\n" | |||
"eor v24.16b, v24.16b, v24.16b\n" | |||
"eor v25.16b, v25.16b, v25.16b\n" | |||
"eor v26.16b, v26.16b, v26.16b\n" | |||
"eor v27.16b, v27.16b, v27.16b\n" | |||
"eor v28.16b, v28.16b, v28.16b\n" | |||
"eor v29.16b, v29.16b, v29.16b\n" | |||
"eor v30.16b, v30.16b, v30.16b\n" | |||
"eor v31.16b, v31.16b, v31.16b\n" | |||
"2: \n" | |||
"ldr q4, [%[b_ptr]]\n" | |||
"ldr q5, [%[b_ptr], #16]\n" | |||
"ldr q6, [%[b_ptr], #32]\n" | |||
"ldr q7, [%[b_ptr], #48]\n" | |||
"ldr q0, [%[a_ptr]]\n" | |||
"ldr q1, [%[a_ptr], #16]\n" | |||
"ldr q2, [%[a_ptr], #32]\n" | |||
"ldr q3, [%[a_ptr], #48]\n" | |||
"smull v12.8h, v0.8b, v4.8b\n" | |||
"smull v13.8h, v0.8b, v5.8b\n" | |||
"smull v14.8h, v0.8b, v6.8b\n" | |||
"smull v15.8h, v0.8b, v7.8b\n" | |||
"smlal2 v12.8h, v0.16b, v4.16b\n" | |||
"smlal2 v13.8h, v0.16b, v5.16b\n" | |||
"smlal2 v14.8h, v0.16b, v6.16b\n" | |||
"smlal2 v15.8h, v0.16b, v7.16b\n" | |||
"sadalp v16.4s, v12.8h\n" | |||
"sadalp v17.4s, v13.8h\n" | |||
"sadalp v18.4s, v14.8h\n" | |||
"sadalp v19.4s, v15.8h\n" | |||
"smull v12.8h, v1.8b, v4.8b\n" | |||
"smull v13.8h, v1.8b, v5.8b\n" | |||
"smull v14.8h, v1.8b, v6.8b\n" | |||
"smull v15.8h, v1.8b, v7.8b\n" | |||
"smlal2 v12.8h, v1.16b, v4.16b\n" | |||
"smlal2 v13.8h, v1.16b, v5.16b\n" | |||
"smlal2 v14.8h, v1.16b, v6.16b\n" | |||
"smlal2 v15.8h, v1.16b, v7.16b\n" | |||
"sadalp v20.4s, v12.8h\n" | |||
"sadalp v21.4s, v13.8h\n" | |||
"sadalp v22.4s, v14.8h\n" | |||
"sadalp v23.4s, v15.8h\n" | |||
"smull v12.8h, v2.8b, v4.8b\n" | |||
"smull v13.8h, v2.8b, v5.8b\n" | |||
"smull v14.8h, v2.8b, v6.8b\n" | |||
"smull v15.8h, v2.8b, v7.8b\n" | |||
"smlal2 v12.8h, v2.16b, v4.16b\n" | |||
"smlal2 v13.8h, v2.16b, v5.16b\n" | |||
"smlal2 v14.8h, v2.16b, v6.16b\n" | |||
"smlal2 v15.8h, v2.16b, v7.16b\n" | |||
"sadalp v24.4s, v12.8h\n" | |||
"sadalp v25.4s, v13.8h\n" | |||
"sadalp v26.4s, v14.8h\n" | |||
"sadalp v27.4s, v15.8h\n" | |||
"smull v12.8h, v3.8b, v4.8b\n" | |||
"smull v13.8h, v3.8b, v5.8b\n" | |||
"smull v14.8h, v3.8b, v6.8b\n" | |||
"smull v15.8h, v3.8b, v7.8b\n" | |||
"smlal2 v12.8h, v3.16b, v4.16b\n" | |||
"smlal2 v13.8h, v3.16b, v5.16b\n" | |||
"smlal2 v14.8h, v3.16b, v6.16b\n" | |||
"smlal2 v15.8h, v3.16b, v7.16b\n" | |||
"sadalp v28.4s, v12.8h\n" | |||
"sadalp v29.4s, v13.8h\n" | |||
"sadalp v30.4s, v14.8h\n" | |||
"sadalp v31.4s, v15.8h\n" | |||
"add %[a_ptr], %[a_ptr], #64\n" | |||
"add %[b_ptr], %[b_ptr], #64\n" | |||
"subs %w[K], %w[K], #1\n" | |||
"cbnz %w[K], 2b\n" | |||
"3:\n" | |||
// reduction | |||
"addp v16.4s, v16.4s, v17.4s\n" | |||
"addp v17.4s, v18.4s, v19.4s\n" | |||
"addp v16.4s, v16.4s, v17.4s\n" | |||
"addp v18.4s, v20.4s, v21.4s\n" | |||
"addp v19.4s, v22.4s, v23.4s\n" | |||
"addp v17.4s, v18.4s, v19.4s\n" | |||
"addp v20.4s, v24.4s, v25.4s\n" | |||
"addp v21.4s, v26.4s, v27.4s\n" | |||
"addp v18.4s, v20.4s, v21.4s\n" | |||
"addp v22.4s, v28.4s, v29.4s\n" | |||
"addp v23.4s, v30.4s, v31.4s\n" | |||
"addp v19.4s, v22.4s, v23.4s\n" | |||
STORE_C | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||
[output] "+r"(output), [m_remain] "+r"(m_remain), | |||
[n_remain] "+r"(n_remain) | |||
: | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||
"v29", "v30", "v31", "x0", "x1", "x2", "x3", "x4", "x5", "cc", | |||
"memory"); | |||
#undef LOAD_LINE | |||
#undef LOAD_C | |||
#undef STORE_LINE | |||
#undef STORE_C | |||
} | |||
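//! Pack A (not transposed): interleave 4 rows into contiguous 4x16 blocks so
//! kern_4x4 reads 64 consecutive bytes per 16-deep K step; rows beyond ymax
//! and K tails shorter than 16 are zero-padded.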
static void gemm_s8_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||
int ldin, int y0, int ymax, int k0, int kmax) { | |||
int8_t zerobuff[16]; | |||
std::memset(zerobuff, 0, sizeof(int8_t) * 16); | |||
int y = y0; | |||
for (; y + 3 < ymax; y += 4) { | |||
const int8_t* inptr0 = inptr + y * ldin + k0; | |||
const int8_t* inptr1 = inptr0 + ldin; | |||
const int8_t* inptr2 = inptr1 + ldin; | |||
const int8_t* inptr3 = inptr2 + ldin; | |||
prefetch_2x(inptr0); | |||
prefetch_2x(inptr1); | |||
prefetch_2x(inptr2); | |||
prefetch_2x(inptr3); | |||
int K = kmax - k0; | |||
//! read 16 * 4 in each row | |||
for (; K > 15; K -= 16) { | |||
interleave_4x16_1_b(inptr0, inptr1, inptr2, inptr3, outptr); | |||
} | |||
if (K > 0) { | |||
interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 16, K); | |||
} | |||
} | |||
for (; y < ymax; y += 4) { | |||
const int8_t* inptr0 = inptr + y * ldin + k0; | |||
const int8_t* inptr1 = inptr0 + ldin; | |||
const int8_t* inptr2 = inptr1 + ldin; | |||
const int8_t* inptr3 = inptr2 + ldin; | |||
prefetch_2x(inptr0); | |||
prefetch_2x(inptr1); | |||
prefetch_2x(inptr2); | |||
prefetch_2x(inptr3); | |||
int K = kmax - k0; | |||
        //! read 16 * 4 in each row; rows past ymax read from zerobuff
for (; K > 15; K -= 16) { | |||
if (y + 3 >= ymax) { | |||
switch (y + 3 - ymax) { | |||
                    case 2:
                        inptr1 = zerobuff;
                        // fall through: rows past ymax all read zeros
                    case 1:
                        inptr2 = zerobuff;
                        // fall through
                    case 0:
                        inptr3 = zerobuff;
                        break;
default: | |||
megdnn_assert(0); | |||
} | |||
} | |||
interleave_4x16_1_b(inptr0, inptr1, inptr2, inptr3, outptr); | |||
} | |||
if (K > 0) { | |||
if (y + 3 >= ymax) { | |||
switch (y + 3 - ymax) { | |||
case 2: | |||
inptr1 = zerobuff; | |||
case 1: | |||
inptr2 = zerobuff; | |||
case 0: | |||
inptr3 = zerobuff; | |||
break; | |||
default: | |||
megdnn_assert(0); | |||
} | |||
} | |||
interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 16, K); | |||
} | |||
} | |||
} | |||
static void gemm_s8_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||
int x0, int xmax, int k0, int kmax) { | |||
int8_t zerobuff[16]; | |||
std::memset(zerobuff, 0, sizeof(int8_t) * 16); | |||
const int ksize = kmax - k0; | |||
const int ksize4 = round_up(ksize, 16) * 4; | |||
int8_t* outptr = out; | |||
int k = k0; | |||
for (; k < kmax; k += 16) { | |||
int ki = k; | |||
for (int cnt = 0; cnt < 2; ki += 8, cnt++) { | |||
const int8_t* inptr0 = in + ki * ldin + x0; | |||
const int8_t* inptr1 = inptr0 + ldin; | |||
const int8_t* inptr2 = inptr1 + ldin; | |||
const int8_t* inptr3 = inptr2 + ldin; | |||
const int8_t* inptr4 = inptr3 + ldin; | |||
const int8_t* inptr5 = inptr4 + ldin; | |||
const int8_t* inptr6 = inptr5 + ldin; | |||
const int8_t* inptr7 = inptr6 + ldin; | |||
int8_t* outptr_inner = outptr + ki - k; | |||
int remain = std::min(ki + 7 - kmax, 7); | |||
int x = x0; | |||
for (; x + 3 < xmax; x += 4) { | |||
if (remain >= 0) { | |||
switch (remain) { | |||
case 7: | |||
inptr0 = zerobuff; | |||
case 6: | |||
inptr1 = zerobuff; | |||
case 5: | |||
inptr2 = zerobuff; | |||
case 4: | |||
inptr3 = zerobuff; | |||
case 3: | |||
inptr4 = zerobuff; | |||
case 2: | |||
inptr5 = zerobuff; | |||
case 1: | |||
inptr6 = zerobuff; | |||
case 0: | |||
inptr7 = zerobuff; | |||
break; | |||
default: | |||
megdnn_assert(0); | |||
} | |||
} | |||
transpose_4x16_1_b_helper(inptr0, inptr1, inptr2, inptr3, | |||
inptr4, inptr5, inptr6, inptr7, | |||
outptr_inner); | |||
outptr_inner += ksize4; | |||
} | |||
if (x < xmax) { | |||
if (remain >= 0) { | |||
switch (remain) { | |||
case 7: | |||
inptr0 = zerobuff; | |||
case 6: | |||
inptr1 = zerobuff; | |||
case 5: | |||
inptr2 = zerobuff; | |||
case 4: | |||
inptr3 = zerobuff; | |||
case 3: | |||
inptr4 = zerobuff; | |||
case 2: | |||
inptr5 = zerobuff; | |||
case 1: | |||
inptr6 = zerobuff; | |||
case 0: | |||
inptr7 = zerobuff; | |||
break; | |||
default: | |||
megdnn_assert(0); | |||
} | |||
} | |||
for (; x < xmax; x++) { | |||
*outptr_inner++ = *inptr0++; | |||
*outptr_inner++ = *inptr1++; | |||
*outptr_inner++ = *inptr2++; | |||
*outptr_inner++ = *inptr3++; | |||
*outptr_inner++ = *inptr4++; | |||
*outptr_inner++ = *inptr5++; | |||
*outptr_inner++ = *inptr6++; | |||
*outptr_inner++ = *inptr7++; | |||
outptr_inner += 8; | |||
} | |||
} | |||
} | |||
outptr += 16 * 4; | |||
} | |||
} | |||
} // namespace matmul_4x4x16 | |||
} // namespace aarch64 | |||
} // namespace megdnn | |||
#endif | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,892 @@ | |||
/** | |||
* \file dnn/src/aarch64/matrix_mul/int8/kernel_mk4_4x4x16.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include <cstring> | |||
#if !(__ARM_FEATURE_DOTPROD) | |||
#include "src/aarch64/matrix_mul/asm/common.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
namespace megdnn { | |||
namespace aarch64 { | |||
namespace matmul_mk4_4x4x16 { | |||
/** | |||
* Overview of register layout: | |||
* | |||
 * A 4x16 cell of Lhs is stored in 8bit in q0-q3.
 * A 16x4 cell of Rhs is stored in 8bit in q4-q7.
 * The 4x4 block of int32 accumulators is spread over q16-q31, one 4-lane
 * partial sum per output element, reduced with addp at the end.
* | |||
* \warning Fast kernel operating on int8 operands. | |||
* It is assumed that one of the two int8 operands only takes values | |||
* in [-127, 127], while the other may freely range in [-128, 127]. | |||
* The issue with both operands taking the value -128 is that: | |||
* -128*-128 + -128*-128 == -32768 overflows int16. | |||
* Every other expression a*b + c*d, for any int8 a,b,c,d, fits in int16 | |||
* range. That is the basic idea of this kernel. | |||
* | |||
* | |||
* +--------+--------+---------+---------+ | |||
* |v4[0-16]|v5[0-16]| v6[0-16]| v7[0-16]| | |||
* Rhs +--------+--------+---------+---------+ | |||
* | | | | | | |||
* | |||
* Lhs | | | | | | |||
* | |||
* +--------+ - - - - +-------------------------------------+ | |||
* |v0[0-16]| |v16[0-4]|v17[0-4]| v18[0-4]| v19[0-4]| | |||
* |v1[0-16]| |v20[0-4]|v21[0-4]| v22[0-4]| v23[0-4]| | |||
* |v2[0-16]| |v24[0-4]|v25[0-4]| v26[0-4]| v27[0-4]| | |||
* |v3[0-16]| |v28[0-4]|v29[0-4]| v30[0-4]| v31[0-4]| | |||
* +--------+ - - - - +-------------------------------------+ | |||
* | |||
* Accumulator | |||
*/ | |||
static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||
int32_t* output, bool is_first_k) { | |||
K = div_ceil(K, 16); | |||
const int8_t* a_ptr = packA; | |||
const int8_t* b_ptr = packB; | |||
asm volatile( | |||
            // zero accumulators and preload A/B; C is accumulated at the end
"ld1 {v0.16b}, [%[a_ptr]], #16\n" | |||
"eor v16.16b, v16.16b, v16.16b\n" | |||
"eor v17.16b, v17.16b, v17.16b\n" | |||
"eor v18.16b, v18.16b, v18.16b\n" | |||
"ld1 {v1.16b}, [%[a_ptr]], #16\n" | |||
"eor v19.16b, v19.16b, v19.16b\n" | |||
"eor v20.16b, v19.16b, v19.16b\n" | |||
"eor v21.16b, v19.16b, v19.16b\n" | |||
"ld1 {v4.16b, v5.16b}, [%[b_ptr]], #32\n" | |||
"eor v22.16b, v19.16b, v19.16b\n" | |||
"PRFM PLDL1KEEP, [%[a_ptr], #32]\n" | |||
"eor v23.16b, v19.16b, v19.16b\n" | |||
"eor v24.16b, v19.16b, v19.16b\n" | |||
"PRFM PLDL1KEEP, [%[b_ptr], #32]\n" | |||
"eor v25.16b, v19.16b, v19.16b\n" | |||
"eor v26.16b, v19.16b, v19.16b\n" | |||
"PRFM PLDL1KEEP, [%[b_ptr], #64]\n" | |||
"eor v27.16b, v19.16b, v19.16b\n" | |||
"eor v28.16b, v19.16b, v19.16b\n" | |||
"PRFM PLDL1KEEP, [%[a_ptr], #64]\n" | |||
"eor v29.16b, v19.16b, v19.16b\n" | |||
"eor v30.16b, v19.16b, v19.16b\n" | |||
"PRFM PLDL1KEEP, [%[b_ptr], #128]\n" | |||
"eor v31.16b, v19.16b, v19.16b\n" | |||
            //! K==2 jumps to the two-block tail; K==1 to the last-block tail
"cmp %w[k], #2\n" | |||
"beq 2f\n" | |||
"blt 3f\n" | |||
//! K>2 | |||
"1:\n" | |||
//! First k | |||
"smull v8.8h, v0.8b, v4.8b\n" | |||
"smull v9.8h, v0.8b, v5.8b\n" | |||
"ld1 {v6.16b}, [%[b_ptr]], #16\n" | |||
"smull v12.8h, v1.8b, v4.8b\n" | |||
"smull v13.8h, v1.8b, v5.8b\n" | |||
"ld1 {v7.16b}, [%[b_ptr]], #16\n" | |||
"smlal2 v8.8h, v0.16b, v4.16b\n" | |||
"smlal2 v9.8h, v0.16b, v5.16b\n" | |||
"smlal2 v12.8h, v1.16b, v4.16b\n" | |||
"smlal2 v13.8h, v1.16b, v5.16b\n" | |||
"smull v10.8h, v0.8b, v6.8b\n" | |||
"ld1 {v2.16b}, [%[a_ptr]], #16\n" | |||
"smull v11.8h, v0.8b, v7.8b\n" | |||
"smull v14.8h, v1.8b, v6.8b\n" | |||
"ld1 {v3.16b}, [%[a_ptr]], #16\n" | |||
"smull v15.8h, v1.8b, v7.8b\n" | |||
"sadalp v16.4s, v8.8h\n" | |||
"smlal2 v10.8h, v0.16b, v6.16b\n" | |||
"sadalp v17.4s, v9.8h\n" | |||
"smlal2 v11.8h, v0.16b, v7.16b\n" | |||
"sadalp v20.4s, v12.8h\n" | |||
"smlal2 v14.8h, v1.16b, v6.16b\n" | |||
"sadalp v21.4s, v13.8h\n" | |||
"smlal2 v15.8h, v1.16b, v7.16b\n" | |||
"smull v8.8h, v2.8b, v4.8b\n" | |||
"smull v9.8h, v2.8b, v5.8b\n" | |||
"ld1 {v0.16b}, [%[a_ptr]], #16\n" | |||
"smull v12.8h, v3.8b, v4.8b\n" | |||
"ld1 {v1.16b}, [%[a_ptr]], #16\n" | |||
"smull v13.8h, v3.8b, v5.8b\n" | |||
"sadalp v18.4s, v10.8h\n" | |||
"smlal2 v8.8h, v2.16b, v4.16b\n" | |||
"sadalp v19.4s, v11.8h\n" | |||
"smlal2 v9.8h, v2.16b, v5.16b\n" | |||
"sadalp v22.4s, v14.8h\n" | |||
"smlal2 v12.8h, v3.16b, v4.16b\n" | |||
"sadalp v23.4s, v15.8h\n" | |||
"smlal2 v13.8h, v3.16b, v5.16b\n" | |||
"smull v10.8h, v2.8b, v6.8b\n" | |||
"smull v11.8h, v2.8b, v7.8b\n" | |||
"ld1 {v4.16b}, [%[b_ptr]], #16\n" | |||
"smull v14.8h, v3.8b, v6.8b\n" | |||
"ld1 {v5.16b}, [%[b_ptr]], #16\n" | |||
"smull v15.8h, v3.8b, v7.8b\n" | |||
"sadalp v24.4s, v8.8h\n" | |||
"smlal2 v10.8h, v2.16b, v6.16b\n" | |||
"sadalp v25.4s, v9.8h\n" | |||
"smlal2 v11.8h, v2.16b, v7.16b\n" | |||
"sadalp v28.4s, v12.8h\n" | |||
"smlal2 v14.8h, v3.16b, v6.16b\n" | |||
"sadalp v29.4s, v13.8h\n" | |||
"smlal2 v15.8h, v3.16b, v7.16b\n" | |||
//! Second k | |||
"smull v8.8h, v0.8b, v4.8b\n" | |||
"smull v9.8h, v0.8b, v5.8b\n" | |||
"ld1 {v6.16b}, [%[b_ptr]], #16\n" | |||
"smull v12.8h, v1.8b, v4.8b\n" | |||
"smull v13.8h, v1.8b, v5.8b\n" | |||
"ld1 {v7.16b}, [%[b_ptr]], #16\n" | |||
"smlal2 v8.8h, v0.16b, v4.16b\n" | |||
"sadalp v26.4s, v10.8h\n" | |||
"smlal2 v9.8h, v0.16b, v5.16b\n" | |||
"sadalp v27.4s, v11.8h\n" | |||
"smlal2 v12.8h, v1.16b, v4.16b\n" | |||
"sadalp v30.4s, v14.8h\n" | |||
"smlal2 v13.8h, v1.16b, v5.16b\n" | |||
"sadalp v31.4s, v15.8h\n" | |||
"smull v10.8h, v0.8b, v6.8b\n" | |||
"ld1 {v2.16b}, [%[a_ptr]], #16\n" | |||
"smull v11.8h, v0.8b, v7.8b\n" | |||
"smull v14.8h, v1.8b, v6.8b\n" | |||
"ld1 {v3.16b}, [%[a_ptr]], #16\n" | |||
"smull v15.8h, v1.8b, v7.8b\n" | |||
"sadalp v16.4s, v8.8h\n" | |||
"smlal2 v10.8h, v0.16b, v6.16b\n" | |||
"sadalp v17.4s, v9.8h\n" | |||
"smlal2 v11.8h, v0.16b, v7.16b\n" | |||
"sadalp v20.4s, v12.8h\n" | |||
"smlal2 v14.8h, v1.16b, v6.16b\n" | |||
"sadalp v21.4s, v13.8h\n" | |||
"smlal2 v15.8h, v1.16b, v7.16b\n" | |||
"smull v8.8h, v2.8b, v4.8b\n" | |||
"smull v9.8h, v2.8b, v5.8b\n" | |||
"ld1 {v0.16b}, [%[a_ptr]], #16\n" | |||
"smull v12.8h, v3.8b, v4.8b\n" | |||
"ld1 {v1.16b}, [%[a_ptr]], #16\n" | |||
"smull v13.8h, v3.8b, v5.8b\n" | |||
"sadalp v18.4s, v10.8h\n" | |||
"smlal2 v8.8h, v2.16b, v4.16b\n" | |||
"sadalp v19.4s, v11.8h\n" | |||
"smlal2 v9.8h, v2.16b, v5.16b\n" | |||
"sadalp v22.4s, v14.8h\n" | |||
"smlal2 v12.8h, v3.16b, v4.16b\n" | |||
"sadalp v23.4s, v15.8h\n" | |||
"smlal2 v13.8h, v3.16b, v5.16b\n" | |||
"sub %w[k], %w[k], #2\n" | |||
"cmp %w[k], #2\n" | |||
"smull v10.8h, v2.8b, v6.8b\n" | |||
"ld1 {v4.16b}, [%[b_ptr]], #16\n" | |||
"smull v11.8h, v2.8b, v7.8b\n" | |||
"ld1 {v5.16b}, [%[b_ptr]], #16\n" | |||
"smull v14.8h, v3.8b, v6.8b\n" | |||
"sadalp v24.4s, v8.8h\n" | |||
"smull v15.8h, v3.8b, v7.8b\n" | |||
"sadalp v25.4s, v9.8h\n" | |||
"smlal2 v10.8h, v2.16b, v6.16b\n" | |||
"sadalp v28.4s, v12.8h\n" | |||
"smlal2 v11.8h, v2.16b, v7.16b\n" | |||
"sadalp v29.4s, v13.8h\n" | |||
"smlal2 v14.8h, v3.16b, v6.16b\n" | |||
"sadalp v26.4s, v10.8h\n" | |||
"smlal2 v15.8h, v3.16b, v7.16b\n" | |||
"sadalp v27.4s, v11.8h\n" | |||
"sadalp v30.4s, v14.8h\n" | |||
"sadalp v31.4s, v15.8h\n" | |||
"bgt 1b\n" | |||
"blt 3f\n" | |||
//! K==2 | |||
"2:\n" | |||
"smull v8.8h, v0.8b, v4.8b\n" | |||
"smull v9.8h, v0.8b, v5.8b\n" | |||
"ld1 {v6.16b}, [%[b_ptr]], #16\n" | |||
"smull v12.8h, v1.8b, v4.8b\n" | |||
"smull v13.8h, v1.8b, v5.8b\n" | |||
"ld1 {v7.16b}, [%[b_ptr]], #16\n" | |||
"smlal2 v8.8h, v0.16b, v4.16b\n" | |||
"smlal2 v9.8h, v0.16b, v5.16b\n" | |||
"smlal2 v12.8h, v1.16b, v4.16b\n" | |||
"smlal2 v13.8h, v1.16b, v5.16b\n" | |||
"smull v10.8h, v0.8b, v6.8b\n" | |||
"ld1 {v2.16b}, [%[a_ptr]], #16\n" | |||
"smull v11.8h, v0.8b, v7.8b\n" | |||
"smull v14.8h, v1.8b, v6.8b\n" | |||
"ld1 {v3.16b}, [%[a_ptr]], #16\n" | |||
"smull v15.8h, v1.8b, v7.8b\n" | |||
"sadalp v16.4s, v8.8h\n" | |||
"smlal2 v10.8h, v0.16b, v6.16b\n" | |||
"sadalp v17.4s, v9.8h\n" | |||
"smlal2 v11.8h, v0.16b, v7.16b\n" | |||
"sadalp v20.4s, v12.8h\n" | |||
"smlal2 v14.8h, v1.16b, v6.16b\n" | |||
"sadalp v21.4s, v13.8h\n" | |||
"smlal2 v15.8h, v1.16b, v7.16b\n" | |||
"smull v8.8h, v2.8b, v4.8b\n" | |||
"smull v9.8h, v2.8b, v5.8b\n" | |||
"ld1 {v0.16b}, [%[a_ptr]], #16\n" | |||
"smull v12.8h, v3.8b, v4.8b\n" | |||
"ld1 {v1.16b}, [%[a_ptr]], #16\n" | |||
"smull v13.8h, v3.8b, v5.8b\n" | |||
"sadalp v18.4s, v10.8h\n" | |||
"smlal2 v8.8h, v2.16b, v4.16b\n" | |||
"sadalp v19.4s, v11.8h\n" | |||
"smlal2 v9.8h, v2.16b, v5.16b\n" | |||
"sadalp v22.4s, v14.8h\n" | |||
"smlal2 v12.8h, v3.16b, v4.16b\n" | |||
"sadalp v23.4s, v15.8h\n" | |||
"smlal2 v13.8h, v3.16b, v5.16b\n" | |||
"smull v10.8h, v2.8b, v6.8b\n" | |||
"smull v11.8h, v2.8b, v7.8b\n" | |||
"ld1 {v4.16b}, [%[b_ptr]], #16\n" | |||
"smull v14.8h, v3.8b, v6.8b\n" | |||
"ld1 {v5.16b}, [%[b_ptr]], #16\n" | |||
"smull v15.8h, v3.8b, v7.8b\n" | |||
"sadalp v24.4s, v8.8h\n" | |||
"smlal2 v10.8h, v2.16b, v6.16b\n" | |||
"sadalp v25.4s, v9.8h\n" | |||
"smlal2 v11.8h, v2.16b, v7.16b\n" | |||
"sadalp v28.4s, v12.8h\n" | |||
"smlal2 v14.8h, v3.16b, v6.16b\n" | |||
"sadalp v29.4s, v13.8h\n" | |||
"smlal2 v15.8h, v3.16b, v7.16b\n" | |||
"sadalp v26.4s, v10.8h\n" | |||
"sadalp v27.4s, v11.8h\n" | |||
"sadalp v30.4s, v14.8h\n" | |||
"sadalp v31.4s, v15.8h\n" | |||
//! K==1 | |||
"3:\n" | |||
"smull v8.8h, v0.8b, v4.8b\n" | |||
"smull v9.8h, v0.8b, v5.8b\n" | |||
"ld1 {v6.16b}, [%[b_ptr]], #16\n" | |||
"smull v12.8h, v1.8b, v4.8b\n" | |||
"smull v13.8h, v1.8b, v5.8b\n" | |||
"ld1 {v7.16b}, [%[b_ptr]], #16\n" | |||
"smlal2 v8.8h, v0.16b, v4.16b\n" | |||
"smlal2 v9.8h, v0.16b, v5.16b\n" | |||
"smlal2 v12.8h, v1.16b, v4.16b\n" | |||
"smlal2 v13.8h, v1.16b, v5.16b\n" | |||
"smull v10.8h, v0.8b, v6.8b\n" | |||
"ld1 {v2.16b}, [%[a_ptr]], #16\n" | |||
"smull v11.8h, v0.8b, v7.8b\n" | |||
"smull v14.8h, v1.8b, v6.8b\n" | |||
"ld1 {v3.16b}, [%[a_ptr]], #16\n" | |||
"smull v15.8h, v1.8b, v7.8b\n" | |||
"sadalp v16.4s, v8.8h\n" | |||
"smlal2 v10.8h, v0.16b, v6.16b\n" | |||
"sadalp v17.4s, v9.8h\n" | |||
"smlal2 v11.8h, v0.16b, v7.16b\n" | |||
"sadalp v20.4s, v12.8h\n" | |||
"smlal2 v14.8h, v1.16b, v6.16b\n" | |||
"sadalp v21.4s, v13.8h\n" | |||
"smlal2 v15.8h, v1.16b, v7.16b\n" | |||
"smull v8.8h, v2.8b, v4.8b\n" | |||
"smull v9.8h, v2.8b, v5.8b\n" | |||
"smull v12.8h, v3.8b, v4.8b\n" | |||
"smull v13.8h, v3.8b, v5.8b\n" | |||
"sadalp v18.4s, v10.8h\n" | |||
"smlal2 v8.8h, v2.16b, v4.16b\n" | |||
"sadalp v19.4s, v11.8h\n" | |||
"smlal2 v9.8h, v2.16b, v5.16b\n" | |||
"sadalp v22.4s, v14.8h\n" | |||
"smlal2 v12.8h, v3.16b, v4.16b\n" | |||
"sadalp v23.4s, v15.8h\n" | |||
"smlal2 v13.8h, v3.16b, v5.16b\n" | |||
"smull v10.8h, v2.8b, v6.8b\n" | |||
"sadalp v24.4s, v8.8h\n" | |||
"smull v11.8h, v2.8b, v7.8b\n" | |||
"sadalp v25.4s, v9.8h\n" | |||
"smull v14.8h, v3.8b, v6.8b\n" | |||
"sadalp v28.4s, v12.8h\n" | |||
"smull v15.8h, v3.8b, v7.8b\n" | |||
"sadalp v29.4s, v13.8h\n" | |||
"smlal2 v10.8h, v2.16b, v6.16b\n" | |||
"smlal2 v11.8h, v2.16b, v7.16b\n" | |||
"sadalp v26.4s, v10.8h\n" | |||
"smlal2 v14.8h, v3.16b, v6.16b\n" | |||
"sadalp v27.4s, v11.8h\n" | |||
"smlal2 v15.8h, v3.16b, v7.16b\n" | |||
"sadalp v30.4s, v14.8h\n" | |||
"sadalp v31.4s, v15.8h\n" | |||
"addp v4.4s, v16.4s, v20.4s\n" | |||
"addp v5.4s, v24.4s, v28.4s\n" | |||
"addp v6.4s, v17.4s, v21.4s\n" | |||
"addp v7.4s, v25.4s, v29.4s\n" | |||
"addp v8.4s, v18.4s, v22.4s\n" | |||
"addp v9.4s, v26.4s, v30.4s\n" | |||
"addp v10.4s, v19.4s, v23.4s\n" | |||
"addp v11.4s, v27.4s, v31.4s\n" | |||
"cmp %w[is_first_k], #1\n" | |||
"addp v0.4s, v4.4s, v5.4s\n" | |||
"addp v1.4s, v6.4s, v7.4s\n" | |||
"addp v2.4s, v8.4s, v9.4s\n" | |||
"addp v3.4s, v10.4s, v11.4s\n" | |||
"beq 6f\n" | |||
"ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[output]]\n" | |||
"add v0.4s, v0.4s, v8.4s\n" | |||
"add v1.4s, v1.4s, v9.4s\n" | |||
"add v2.4s, v2.4s, v10.4s\n" | |||
"add v3.4s, v3.4s, v11.4s\n" | |||
"6:\n" | |||
"st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[output]], #64\n" | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||
[is_first_k] "+r"(is_first_k), [k] "+r"(K), [output] "+r"(output) | |||
: | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||
"v29", "v30", "v31", "cc", "memory"); | |||
} | |||
static void kern_4x4_remain(const int8_t* packA, const int8_t* packB, int K, | |||
int32_t* output, bool is_first_k, size_t remain_n) { | |||
K = div_ceil(K, 16); | |||
const int8_t* a_ptr = packA; | |||
const int8_t* b_ptr = packB; | |||
asm volatile( | |||
            // zero accumulators and preload A/B; C is accumulated at the end
"ld1 {v0.16b}, [%[a_ptr]], #16\n" | |||
"eor v16.16b, v16.16b, v16.16b\n" | |||
"eor v17.16b, v17.16b, v17.16b\n" | |||
"eor v18.16b, v18.16b, v18.16b\n" | |||
"ld1 {v1.16b}, [%[a_ptr]], #16\n" | |||
"eor v19.16b, v19.16b, v19.16b\n" | |||
"eor v20.16b, v19.16b, v19.16b\n" | |||
"eor v21.16b, v19.16b, v19.16b\n" | |||
"ld1 {v4.16b, v5.16b}, [%[b_ptr]], #32\n" | |||
"eor v22.16b, v19.16b, v19.16b\n" | |||
"PRFM PLDL1KEEP, [%[a_ptr], #32]\n" | |||
"eor v23.16b, v19.16b, v19.16b\n" | |||
"eor v24.16b, v19.16b, v19.16b\n" | |||
"PRFM PLDL1KEEP, [%[b_ptr], #32]\n" | |||
"eor v25.16b, v19.16b, v19.16b\n" | |||
"eor v26.16b, v19.16b, v19.16b\n" | |||
"PRFM PLDL1KEEP, [%[b_ptr], #64]\n" | |||
"eor v27.16b, v19.16b, v19.16b\n" | |||
"eor v28.16b, v19.16b, v19.16b\n" | |||
"PRFM PLDL1KEEP, [%[a_ptr], #64]\n" | |||
"eor v29.16b, v19.16b, v19.16b\n" | |||
"eor v30.16b, v19.16b, v19.16b\n" | |||
"PRFM PLDL1KEEP, [%[b_ptr], #128]\n" | |||
"eor v31.16b, v19.16b, v19.16b\n" | |||
            //! K==2 jumps to the two-block tail; K==1 to the last-block tail
"cmp %w[k], #2\n" | |||
"beq 2f\n" | |||
"blt 3f\n" | |||
//! K>2 | |||
"1:\n" | |||
//! First k | |||
"smull v8.8h, v0.8b, v4.8b\n" | |||
"smull v9.8h, v0.8b, v5.8b\n" | |||
"ld1 {v6.16b}, [%[b_ptr]], #16\n" | |||
"smull v12.8h, v1.8b, v4.8b\n" | |||
"smull v13.8h, v1.8b, v5.8b\n" | |||
"ld1 {v7.16b}, [%[b_ptr]], #16\n" | |||
"smlal2 v8.8h, v0.16b, v4.16b\n" | |||
"smlal2 v9.8h, v0.16b, v5.16b\n" | |||
"smlal2 v12.8h, v1.16b, v4.16b\n" | |||
"smlal2 v13.8h, v1.16b, v5.16b\n" | |||
"smull v10.8h, v0.8b, v6.8b\n" | |||
"ld1 {v2.16b}, [%[a_ptr]], #16\n" | |||
"smull v11.8h, v0.8b, v7.8b\n" | |||
"smull v14.8h, v1.8b, v6.8b\n" | |||
"ld1 {v3.16b}, [%[a_ptr]], #16\n" | |||
"smull v15.8h, v1.8b, v7.8b\n" | |||
"sadalp v16.4s, v8.8h\n" | |||
"smlal2 v10.8h, v0.16b, v6.16b\n" | |||
"sadalp v17.4s, v9.8h\n" | |||
"smlal2 v11.8h, v0.16b, v7.16b\n" | |||
"sadalp v20.4s, v12.8h\n" | |||
"smlal2 v14.8h, v1.16b, v6.16b\n" | |||
"sadalp v21.4s, v13.8h\n" | |||
"smlal2 v15.8h, v1.16b, v7.16b\n" | |||
"smull v8.8h, v2.8b, v4.8b\n" | |||
"smull v9.8h, v2.8b, v5.8b\n" | |||
"ld1 {v0.16b}, [%[a_ptr]], #16\n" | |||
"smull v12.8h, v3.8b, v4.8b\n" | |||
"ld1 {v1.16b}, [%[a_ptr]], #16\n" | |||
"smull v13.8h, v3.8b, v5.8b\n" | |||
"sadalp v18.4s, v10.8h\n" | |||
"smlal2 v8.8h, v2.16b, v4.16b\n" | |||
"sadalp v19.4s, v11.8h\n" | |||
"smlal2 v9.8h, v2.16b, v5.16b\n" | |||
"sadalp v22.4s, v14.8h\n" | |||
"smlal2 v12.8h, v3.16b, v4.16b\n" | |||
"sadalp v23.4s, v15.8h\n" | |||
"smlal2 v13.8h, v3.16b, v5.16b\n" | |||
"smull v10.8h, v2.8b, v6.8b\n" | |||
"smull v11.8h, v2.8b, v7.8b\n" | |||
"ld1 {v4.16b}, [%[b_ptr]], #16\n" | |||
"smull v14.8h, v3.8b, v6.8b\n" | |||
"ld1 {v5.16b}, [%[b_ptr]], #16\n" | |||
"smull v15.8h, v3.8b, v7.8b\n" | |||
"sadalp v24.4s, v8.8h\n" | |||
"smlal2 v10.8h, v2.16b, v6.16b\n" | |||
"sadalp v25.4s, v9.8h\n" | |||
"smlal2 v11.8h, v2.16b, v7.16b\n" | |||
"sadalp v28.4s, v12.8h\n" | |||
"smlal2 v14.8h, v3.16b, v6.16b\n" | |||
"sadalp v29.4s, v13.8h\n" | |||
"smlal2 v15.8h, v3.16b, v7.16b\n" | |||
//! Second k | |||
"smull v8.8h, v0.8b, v4.8b\n" | |||
"smull v9.8h, v0.8b, v5.8b\n" | |||
"ld1 {v6.16b}, [%[b_ptr]], #16\n" | |||
"smull v12.8h, v1.8b, v4.8b\n" | |||
"smull v13.8h, v1.8b, v5.8b\n" | |||
"ld1 {v7.16b}, [%[b_ptr]], #16\n" | |||
"smlal2 v8.8h, v0.16b, v4.16b\n" | |||
"sadalp v26.4s, v10.8h\n" | |||
"smlal2 v9.8h, v0.16b, v5.16b\n" | |||
"sadalp v27.4s, v11.8h\n" | |||
"smlal2 v12.8h, v1.16b, v4.16b\n" | |||
"sadalp v30.4s, v14.8h\n" | |||
"smlal2 v13.8h, v1.16b, v5.16b\n" | |||
"sadalp v31.4s, v15.8h\n" | |||
"smull v10.8h, v0.8b, v6.8b\n" | |||
"ld1 {v2.16b}, [%[a_ptr]], #16\n" | |||
"smull v11.8h, v0.8b, v7.8b\n" | |||
"smull v14.8h, v1.8b, v6.8b\n" | |||
"ld1 {v3.16b}, [%[a_ptr]], #16\n" | |||
"smull v15.8h, v1.8b, v7.8b\n" | |||
"sadalp v16.4s, v8.8h\n" | |||
"smlal2 v10.8h, v0.16b, v6.16b\n" | |||
"sadalp v17.4s, v9.8h\n" | |||
"smlal2 v11.8h, v0.16b, v7.16b\n" | |||
"sadalp v20.4s, v12.8h\n" | |||
"smlal2 v14.8h, v1.16b, v6.16b\n" | |||
"sadalp v21.4s, v13.8h\n" | |||
"smlal2 v15.8h, v1.16b, v7.16b\n" | |||
"smull v8.8h, v2.8b, v4.8b\n" | |||
"smull v9.8h, v2.8b, v5.8b\n" | |||
"ld1 {v0.16b}, [%[a_ptr]], #16\n" | |||
"smull v12.8h, v3.8b, v4.8b\n" | |||
"ld1 {v1.16b}, [%[a_ptr]], #16\n" | |||
"smull v13.8h, v3.8b, v5.8b\n" | |||
"sadalp v18.4s, v10.8h\n" | |||
"smlal2 v8.8h, v2.16b, v4.16b\n" | |||
"sadalp v19.4s, v11.8h\n" | |||
"smlal2 v9.8h, v2.16b, v5.16b\n" | |||
"sadalp v22.4s, v14.8h\n" | |||
"smlal2 v12.8h, v3.16b, v4.16b\n" | |||
"sadalp v23.4s, v15.8h\n" | |||
"smlal2 v13.8h, v3.16b, v5.16b\n" | |||
"sub %w[k], %w[k], #2\n" | |||
"cmp %w[k], #2\n" | |||
"smull v10.8h, v2.8b, v6.8b\n" | |||
"ld1 {v4.16b}, [%[b_ptr]], #16\n" | |||
"smull v11.8h, v2.8b, v7.8b\n" | |||
"ld1 {v5.16b}, [%[b_ptr]], #16\n" | |||
"smull v14.8h, v3.8b, v6.8b\n" | |||
"sadalp v24.4s, v8.8h\n" | |||
"smull v15.8h, v3.8b, v7.8b\n" | |||
"sadalp v25.4s, v9.8h\n" | |||
"smlal2 v10.8h, v2.16b, v6.16b\n" | |||
"sadalp v28.4s, v12.8h\n" | |||
"smlal2 v11.8h, v2.16b, v7.16b\n" | |||
"sadalp v29.4s, v13.8h\n" | |||
"smlal2 v14.8h, v3.16b, v6.16b\n" | |||
"sadalp v26.4s, v10.8h\n" | |||
"smlal2 v15.8h, v3.16b, v7.16b\n" | |||
"sadalp v27.4s, v11.8h\n" | |||
"sadalp v30.4s, v14.8h\n" | |||
"sadalp v31.4s, v15.8h\n" | |||
"bgt 1b\n" | |||
"blt 3f\n" | |||
//! K==2 | |||
"2:\n" | |||
"smull v8.8h, v0.8b, v4.8b\n" | |||
"smull v9.8h, v0.8b, v5.8b\n" | |||
"ld1 {v6.16b}, [%[b_ptr]], #16\n" | |||
"smull v12.8h, v1.8b, v4.8b\n" | |||
"smull v13.8h, v1.8b, v5.8b\n" | |||
"ld1 {v7.16b}, [%[b_ptr]], #16\n" | |||
"smlal2 v8.8h, v0.16b, v4.16b\n" | |||
"smlal2 v9.8h, v0.16b, v5.16b\n" | |||
"smlal2 v12.8h, v1.16b, v4.16b\n" | |||
"smlal2 v13.8h, v1.16b, v5.16b\n" | |||
"smull v10.8h, v0.8b, v6.8b\n" | |||
"ld1 {v2.16b}, [%[a_ptr]], #16\n" | |||
"smull v11.8h, v0.8b, v7.8b\n" | |||
"smull v14.8h, v1.8b, v6.8b\n" | |||
"ld1 {v3.16b}, [%[a_ptr]], #16\n" | |||
"smull v15.8h, v1.8b, v7.8b\n" | |||
"sadalp v16.4s, v8.8h\n" | |||
"smlal2 v10.8h, v0.16b, v6.16b\n" | |||
"sadalp v17.4s, v9.8h\n" | |||
"smlal2 v11.8h, v0.16b, v7.16b\n" | |||
"sadalp v20.4s, v12.8h\n" | |||
"smlal2 v14.8h, v1.16b, v6.16b\n" | |||
"sadalp v21.4s, v13.8h\n" | |||
"smlal2 v15.8h, v1.16b, v7.16b\n" | |||
"smull v8.8h, v2.8b, v4.8b\n" | |||
"smull v9.8h, v2.8b, v5.8b\n" | |||
"ld1 {v0.16b}, [%[a_ptr]], #16\n" | |||
"smull v12.8h, v3.8b, v4.8b\n" | |||
"ld1 {v1.16b}, [%[a_ptr]], #16\n" | |||
"smull v13.8h, v3.8b, v5.8b\n" | |||
"sadalp v18.4s, v10.8h\n" | |||
"smlal2 v8.8h, v2.16b, v4.16b\n" | |||
"sadalp v19.4s, v11.8h\n" | |||
"smlal2 v9.8h, v2.16b, v5.16b\n" | |||
"sadalp v22.4s, v14.8h\n" | |||
"smlal2 v12.8h, v3.16b, v4.16b\n" | |||
"sadalp v23.4s, v15.8h\n" | |||
"smlal2 v13.8h, v3.16b, v5.16b\n" | |||
"smull v10.8h, v2.8b, v6.8b\n" | |||
"smull v11.8h, v2.8b, v7.8b\n" | |||
"ld1 {v4.16b}, [%[b_ptr]], #16\n" | |||
"smull v14.8h, v3.8b, v6.8b\n" | |||
"ld1 {v5.16b}, [%[b_ptr]], #16\n" | |||
"smull v15.8h, v3.8b, v7.8b\n" | |||
"sadalp v24.4s, v8.8h\n" | |||
"smlal2 v10.8h, v2.16b, v6.16b\n" | |||
"sadalp v25.4s, v9.8h\n" | |||
"smlal2 v11.8h, v2.16b, v7.16b\n" | |||
"sadalp v28.4s, v12.8h\n" | |||
"smlal2 v14.8h, v3.16b, v6.16b\n" | |||
"sadalp v29.4s, v13.8h\n" | |||
"smlal2 v15.8h, v3.16b, v7.16b\n" | |||
"sadalp v26.4s, v10.8h\n" | |||
"sadalp v27.4s, v11.8h\n" | |||
"sadalp v30.4s, v14.8h\n" | |||
"sadalp v31.4s, v15.8h\n" | |||
//! K==1 | |||
"3:\n" | |||
"smull v8.8h, v0.8b, v4.8b\n" | |||
"smull v9.8h, v0.8b, v5.8b\n" | |||
"ld1 {v6.16b}, [%[b_ptr]], #16\n" | |||
"smull v12.8h, v1.8b, v4.8b\n" | |||
"smull v13.8h, v1.8b, v5.8b\n" | |||
"ld1 {v7.16b}, [%[b_ptr]], #16\n" | |||
"smlal2 v8.8h, v0.16b, v4.16b\n" | |||
"smlal2 v9.8h, v0.16b, v5.16b\n" | |||
"smlal2 v12.8h, v1.16b, v4.16b\n" | |||
"smlal2 v13.8h, v1.16b, v5.16b\n" | |||
"smull v10.8h, v0.8b, v6.8b\n" | |||
"ld1 {v2.16b}, [%[a_ptr]], #16\n" | |||
"smull v11.8h, v0.8b, v7.8b\n" | |||
"smull v14.8h, v1.8b, v6.8b\n" | |||
"ld1 {v3.16b}, [%[a_ptr]], #16\n" | |||
"smull v15.8h, v1.8b, v7.8b\n" | |||
"sadalp v16.4s, v8.8h\n" | |||
"smlal2 v10.8h, v0.16b, v6.16b\n" | |||
"sadalp v17.4s, v9.8h\n" | |||
"smlal2 v11.8h, v0.16b, v7.16b\n" | |||
"sadalp v20.4s, v12.8h\n" | |||
"smlal2 v14.8h, v1.16b, v6.16b\n" | |||
"sadalp v21.4s, v13.8h\n" | |||
"smlal2 v15.8h, v1.16b, v7.16b\n" | |||
"smull v8.8h, v2.8b, v4.8b\n" | |||
"smull v9.8h, v2.8b, v5.8b\n" | |||
"smull v12.8h, v3.8b, v4.8b\n" | |||
"smull v13.8h, v3.8b, v5.8b\n" | |||
"sadalp v18.4s, v10.8h\n" | |||
"smlal2 v8.8h, v2.16b, v4.16b\n" | |||
"sadalp v19.4s, v11.8h\n" | |||
"smlal2 v9.8h, v2.16b, v5.16b\n" | |||
"sadalp v22.4s, v14.8h\n" | |||
"smlal2 v12.8h, v3.16b, v4.16b\n" | |||
"sadalp v23.4s, v15.8h\n" | |||
"smlal2 v13.8h, v3.16b, v5.16b\n" | |||
"smull v10.8h, v2.8b, v6.8b\n" | |||
"sadalp v24.4s, v8.8h\n" | |||
"smull v11.8h, v2.8b, v7.8b\n" | |||
"sadalp v25.4s, v9.8h\n" | |||
"smull v14.8h, v3.8b, v6.8b\n" | |||
"sadalp v28.4s, v12.8h\n" | |||
"smull v15.8h, v3.8b, v7.8b\n" | |||
"sadalp v29.4s, v13.8h\n" | |||
"smlal2 v10.8h, v2.16b, v6.16b\n" | |||
"smlal2 v11.8h, v2.16b, v7.16b\n" | |||
"sadalp v26.4s, v10.8h\n" | |||
"smlal2 v14.8h, v3.16b, v6.16b\n" | |||
"sadalp v27.4s, v11.8h\n" | |||
"smlal2 v15.8h, v3.16b, v7.16b\n" | |||
"sadalp v30.4s, v14.8h\n" | |||
"sadalp v31.4s, v15.8h\n" | |||
"addp v4.4s, v16.4s, v20.4s\n" | |||
"addp v5.4s, v24.4s, v28.4s\n" | |||
"addp v6.4s, v17.4s, v21.4s\n" | |||
"addp v7.4s, v25.4s, v29.4s\n" | |||
"addp v8.4s, v18.4s, v22.4s\n" | |||
"addp v9.4s, v26.4s, v30.4s\n" | |||
"addp v10.4s, v19.4s, v23.4s\n" | |||
"addp v11.4s, v27.4s, v31.4s\n" | |||
"addp v0.4s, v4.4s, v5.4s\n" | |||
"addp v1.4s, v6.4s, v7.4s\n" | |||
"addp v2.4s, v8.4s, v9.4s\n" | |||
"addp v3.4s, v10.4s, v11.4s\n" | |||
"cmp %w[is_first_k], #1\n" | |||
"beq 6f\n" | |||
"cmp %w[remain_n], #3\n" | |||
"beq 1003f\n" | |||
"cmp %w[remain_n], #2\n" | |||
"beq 1002f\n" | |||
"cmp %w[remain_n], #1\n" | |||
"beq 1001f\n" | |||
"1003:\n" | |||
"ld1 {v8.4s, v9.4s, v10.4s}, [%[output]]\n" | |||
"add v0.4s, v0.4s, v8.4s\n" | |||
"add v1.4s, v1.4s, v9.4s\n" | |||
"add v2.4s, v2.4s, v10.4s\n" | |||
"b 6f\n" | |||
"1002:\n" | |||
"ld1 {v8.4s, v9.4s}, [%[output]]\n" | |||
"add v0.4s, v0.4s, v8.4s\n" | |||
"add v1.4s, v1.4s, v9.4s\n" | |||
"b 6f\n" | |||
"1001:\n" | |||
"ld1 {v8.4s}, [%[output]]\n" | |||
"add v0.4s, v0.4s, v8.4s\n" | |||
"6:\n" | |||
"cmp %w[remain_n], #3\n" | |||
"beq 10003f\n" | |||
"cmp %w[remain_n], #2\n" | |||
"beq 10002f\n" | |||
"cmp %w[remain_n], #1\n" | |||
"beq 10001f\n" | |||
"10003:\n" | |||
"str q2, [%[output], #32]\n" | |||
"10002:\n" | |||
"str q1, [%[output], #16]\n" | |||
"10001:\n" | |||
"str q0, [%[output]]\n" | |||
"7:\n" | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||
[remain_n] "+r"(remain_n), [is_first_k] "+r"(is_first_k), | |||
[k] "+r"(K), [output] "+r"(output) | |||
: | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||
"v29", "v30", "v31", "cc", "memory"); | |||
} | |||
static void gemm_mk4_s8_4x4_pack_A(dt_int8* outptr, const dt_int8* inptr, | |||
int ldin, int y0, int ymax, int k0, | |||
int kmax) { | |||
    //! pack from {oc/4, ic/4, 4(ic), 4(oc)} to {oc/4, ic/16, 4(oc), 16(ic)}
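    //! i.e. 16 input channels are regrouped contiguously per output channel,
    //! so the kernel can pull one 16-byte ld1 per (oc, 16-ic) cell.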
int8_t zerobuff[4][64]; | |||
std::memset(zerobuff, 0, sizeof(int8_t) * 64 * 4); | |||
    megdnn_assert(ymax % 4 == 0 && y0 % 4 == 0 && (ymax - y0) % 4 == 0,
                  "mk4 matmul requires m to be a multiple of 4");
    megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0 && (kmax - k0) % 4 == 0,
                  "mk4 matmul requires k to be a multiple of 4");
size_t roundk = round_up(kmax - k0, 16); | |||
size_t out_offset = roundk * 4; | |||
int y = y0; | |||
int start_y = y0 / 4; | |||
for (; y + 15 < ymax; y += 16, start_y += 4) { | |||
const int8_t* inptr0 = inptr + start_y * ldin + k0 * 4; | |||
const int8_t* inptr1 = inptr0 + ldin; | |||
const int8_t* inptr2 = inptr1 + ldin; | |||
const int8_t* inptr3 = inptr2 + ldin; | |||
int8_t* output = outptr + start_y * out_offset; | |||
prefetch_2x(inptr0); | |||
prefetch_2x(inptr1); | |||
prefetch_2x(inptr2); | |||
prefetch_2x(inptr3); | |||
int K = kmax - k0; | |||
for (; K > 15; K -= 16) { | |||
transpose_interleave_4x4_4_b(inptr0, inptr1, inptr2, inptr3, output, | |||
out_offset); | |||
output += 64; | |||
} | |||
if (K > 0) { | |||
std::memcpy(zerobuff[0], inptr0, sizeof(int8_t) * K * 4); | |||
std::memcpy(zerobuff[1], inptr1, sizeof(int8_t) * K * 4); | |||
std::memcpy(zerobuff[2], inptr2, sizeof(int8_t) * K * 4); | |||
std::memcpy(zerobuff[3], inptr3, sizeof(int8_t) * K * 4); | |||
inptr0 = zerobuff[0]; | |||
inptr1 = zerobuff[1]; | |||
inptr2 = zerobuff[2]; | |||
inptr3 = zerobuff[3]; | |||
transpose_interleave_4x4_4_b(inptr0, inptr1, inptr2, inptr3, output, | |||
out_offset); | |||
output += 64; | |||
} | |||
} | |||
for (; y + 3 < ymax; y += 4, start_y++) { | |||
const int8_t* inptr0 = inptr + start_y * ldin + k0 * 4; | |||
int8_t* output = outptr + start_y * out_offset; | |||
prefetch_2x(inptr0); | |||
int K = kmax - k0; | |||
for (; K > 15; K -= 16) { | |||
transpose_interleave_1x4_4_b(inptr0, output); | |||
output += 64; | |||
} | |||
if (K > 0) { | |||
std::memcpy(zerobuff[0], inptr0, sizeof(int8_t) * K * 4); | |||
inptr0 = zerobuff[0]; | |||
transpose_interleave_1x4_4_b(inptr0, output); | |||
output += 64; | |||
} | |||
} | |||
} | |||
static void gemm_mk4_s8_4x4_pack_B(dt_int8* out, const dt_int8* in, int ldin, | |||
int x0, int xmax, int k0, int kmax) { | |||
int32_t zerobuff[4]; | |||
    std::memset(zerobuff, 0, sizeof(zerobuff));
const int ksize = kmax - k0; | |||
const int ICB = (ksize) / 4; | |||
const int ksize4 = round_up<int>(ICB, 4) * 4; | |||
int32_t* outptr = reinterpret_cast<int32_t*>(out); | |||
    megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0 && ksize % 4 == 0,
                  "mk4 matmul requires k to be a multiple of 4");
int k = k0 / 4; | |||
for (; k + 3 < ICB; k += 4) { | |||
const int32_t* inptr0 = | |||
reinterpret_cast<const int32_t*>(in + k * ldin + x0); | |||
const int32_t* inptr1 = | |||
reinterpret_cast<const int32_t*>(in + (k + 1) * ldin + x0); | |||
const int32_t* inptr2 = | |||
reinterpret_cast<const int32_t*>(in + (k + 2) * ldin + x0); | |||
const int32_t* inptr3 = | |||
reinterpret_cast<const int32_t*>(in + (k + 3) * ldin + x0); | |||
int32_t* outptr_inner = outptr; | |||
int x = x0; | |||
for (; x + 3 < xmax; x += 4) { | |||
transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr_inner); | |||
outptr_inner += ksize4; | |||
} | |||
if (x < xmax) { | |||
for (; x < xmax; x++) { | |||
*outptr_inner++ = *inptr0++; | |||
*outptr_inner++ = *inptr1++; | |||
*outptr_inner++ = *inptr2++; | |||
*outptr_inner++ = *inptr3++; | |||
} | |||
} | |||
outptr += 4 * 4; | |||
} | |||
if (k < ICB) { | |||
const int32_t* inptr0 = | |||
reinterpret_cast<const int32_t*>(in + k * ldin + x0); | |||
const int32_t* inptr1 = | |||
reinterpret_cast<const int32_t*>(in + (k + 1) * ldin + x0); | |||
const int32_t* inptr2 = | |||
reinterpret_cast<const int32_t*>(in + (k + 2) * ldin + x0); | |||
const int32_t* inptr3 = | |||
reinterpret_cast<const int32_t*>(in + (k + 3) * ldin + x0); | |||
int32_t* outptr_inner = outptr; | |||
int x = x0; | |||
for (; x + 3 < xmax; x += 4) { | |||
if (k + 3 >= ICB) { | |||
switch (k + 3 - ICB) { | |||
case 2: | |||
inptr1 = zerobuff; | |||
case 1: | |||
inptr2 = zerobuff; | |||
case 0: | |||
inptr3 = zerobuff; | |||
break; | |||
default: | |||
megdnn_assert(0); | |||
} | |||
} | |||
transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr_inner); | |||
outptr_inner += ksize4; | |||
} | |||
if (x < xmax) { | |||
if (k + 3 >= ICB) { | |||
switch (k + 3 - ICB) { | |||
case 2: | |||
inptr1 = zerobuff; | |||
case 1: | |||
inptr2 = zerobuff; | |||
case 0: | |||
inptr3 = zerobuff; | |||
break; | |||
default: | |||
megdnn_assert(0); | |||
} | |||
} | |||
for (; x < xmax; x++) { | |||
*outptr_inner++ = *inptr0++; | |||
*outptr_inner++ = *inptr1++; | |||
*outptr_inner++ = *inptr2++; | |||
*outptr_inner++ = *inptr3++; | |||
} | |||
} | |||
outptr += 4 * 4; | |||
} | |||
} | |||
} // namespace matmul_mk4_4x4x16
} // namespace aarch64 | |||
} // namespace megdnn | |||
#endif | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,263 @@ | |||
/** | |||
* \file dnn/src/aarch64/matrix_mul/int8/strategy.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#if !(__ARM_FEATURE_DOTPROD) | |||
#include "src/aarch64/matrix_mul/int8/strategy.h" | |||
#include "src/aarch64/matrix_mul/asm/common.h" | |||
#include "src/aarch64/matrix_mul/int8/kernel_4x4x16.h" | |||
#include "src/aarch64/matrix_mul/int8/kernel_8x8x8.h" | |||
#include "src/aarch64/matrix_mul/int8/kernel_mk4_4x4x16.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/common/utils.h" | |||
using namespace megdnn; | |||
using namespace aarch64; | |||
using namespace aarch64::matmul; | |||
///////////////////////// gemm_s8_4x4 //////////////////////////////////// | |||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_4x4); | |||
void gemm_s8_4x4::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin, | |||
int y0, int ymax, int k0, int kmax, | |||
bool transpose) const { | |||
if (transpose) { | |||
matmul_4x4x16::gemm_s8_4x4_pack_B_n(outptr, inptr, ldin, y0, ymax, k0, | |||
kmax); | |||
} else { | |||
matmul_4x4x16::gemm_s8_4x4_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, | |||
kmax); | |||
} | |||
} | |||
void gemm_s8_4x4::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, | |||
int xmax, int k0, int kmax, bool transpose) const { | |||
if (transpose) { | |||
matmul_4x4x16::gemm_s8_4x4_pack_A_n(out, in, ldin, x0, xmax, k0, kmax); | |||
} else { | |||
matmul_4x4x16::gemm_s8_4x4_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); | |||
} | |||
} | |||
void gemm_s8_4x4::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||
size_t N, size_t K, dt_int32* C, size_t LDC, | |||
bool is_first_k, const dt_int32*, dt_int32*) const { | |||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||
((A_dtype.enumv() == DTypeEnum::Int8 && | |||
C_dtype.enumv() == DTypeEnum::Int32) || | |||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||
C_dtype.enumv() == DTypeEnum::QuantizedS32)), | |||
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), | |||
C_dtype.name()); | |||
MEGDNN_MARK_USED_VAR(A_dtype); | |||
MEGDNN_MARK_USED_VAR(B_dtype); | |||
MEGDNN_MARK_USED_VAR(C_dtype); | |||
constexpr size_t A_INTERLEAVE = 4; | |||
constexpr size_t B_INTERLEAVE = 4; | |||
    //! K is packed to a multiple of 16
    K = round_up<size_t>(K, 16);
const int K4 = K * 4; | |||
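    //! K4 = bytes per packed 4-row/4-col int8 panel (K already rounded to 16)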
size_t m = 0; | |||
for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) { | |||
int32_t* output = C + (m * LDC); | |||
size_t n = 0; | |||
const dt_int8* cur_packB = packB; | |||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
matmul_4x4x16::kern_4x4(packA, cur_packB, K, output, LDC, | |||
is_first_k); | |||
output += B_INTERLEAVE; | |||
cur_packB += K4; | |||
} | |||
for (; n < N; n += B_INTERLEAVE) { | |||
matmul_4x4x16::kern_4x4_remain(packA, cur_packB, K, output, LDC, | |||
is_first_k, 4, | |||
std::min<size_t>(N - n, 4)); | |||
output += B_INTERLEAVE; | |||
cur_packB += K4; | |||
} | |||
packA += K4; | |||
} | |||
for (; m < M; m += 4) { | |||
int32_t* output = C + (m * LDC); | |||
size_t n = 0; | |||
const dt_int8* cur_packB = packB; | |||
for (; n < N; n += B_INTERLEAVE) { | |||
matmul_4x4x16::kern_4x4_remain( | |||
packA, cur_packB, K, output, LDC, is_first_k, | |||
std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4)); | |||
output += B_INTERLEAVE; | |||
cur_packB += K4; | |||
} | |||
packA += K4; | |||
} | |||
} | |||
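// A minimal sketch (hypothetical helper, not MegEngine API) of the panel
// stride used above: the pack routines zero-pad K up to a multiple of 16,
// and the padded zeros add nothing to the int32 accumulators, so the kernel
// always consumes full 16-deep chunks and K4 is the exact panel stride.
#include <cstddef>
constexpr size_t panel_stride_4x4x16(size_t K) {
    return ((K + 15) / 16 * 16) * 4;  // round K up to 16, times 4 rows/cols
}
static_assert(panel_stride_4x4x16(20) == 128,
              "K = 20 pads to 32; one 4-wide int8 panel spans 128 bytes");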
///////////////////////// gemm_mk4_s8_4x4 //////////////////////////////////// | |||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_mk4_s8_4x4); | |||
void gemm_mk4_s8_4x4::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin, | |||
int y0, int ymax, int k0, int kmax, | |||
bool transpose) const { | |||
megdnn_assert(!transpose, | |||
"the gemm_mk4_s8_4x4 strategy is not support transpose A"); | |||
matmul_mk4_4x4x16::gemm_mk4_s8_4x4_pack_A(outptr, inptr, ldin, y0, ymax, k0, | |||
kmax); | |||
} | |||
void gemm_mk4_s8_4x4::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, | |||
int xmax, int k0, int kmax, bool transpose) const { | |||
megdnn_assert(!transpose, | |||
"the gemm_mk4_s8_4x4 strategy is not support transpose B"); | |||
matmul_mk4_4x4x16::gemm_mk4_s8_4x4_pack_B(out, in, ldin, x0, xmax, k0, | |||
kmax); | |||
} | |||
void gemm_mk4_s8_4x4::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||
size_t N, size_t K, dt_int32* C, size_t LDC, | |||
bool is_first_k, const dt_int32*, dt_int32*) const { | |||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||
((A_dtype.enumv() == DTypeEnum::Int8 && | |||
C_dtype.enumv() == DTypeEnum::Int32) || | |||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||
C_dtype.enumv() == DTypeEnum::QuantizedS32)), | |||
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), | |||
C_dtype.name()); | |||
MEGDNN_MARK_USED_VAR(A_dtype); | |||
MEGDNN_MARK_USED_VAR(B_dtype); | |||
MEGDNN_MARK_USED_VAR(C_dtype); | |||
constexpr size_t A_INTERLEAVE = 4; | |||
constexpr size_t B_INTERLEAVE = 4; | |||
    //! MK4 requires K % 4 == 0; the packed panels pad K to a multiple of 16 | |||
    megdnn_assert(K % 4 == 0, "K is not a multiple of 4"); | |||
const size_t K4 = round_up<size_t>(K, 16) * 4; | |||
size_t m = 0; | |||
for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) { | |||
int32_t* output = C + (m / 4 * LDC); | |||
size_t n = 0; | |||
const dt_int8* cur_packB = packB; | |||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
matmul_mk4_4x4x16::kern_4x4(packA, cur_packB, K, output, | |||
is_first_k); | |||
output += B_INTERLEAVE * 4; | |||
cur_packB += K4; | |||
} | |||
if (n < N) { | |||
matmul_mk4_4x4x16::kern_4x4_remain(packA, cur_packB, K, output, | |||
is_first_k, N - n); | |||
} | |||
packA += K4; | |||
} | |||
} | |||
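// A minimal sketch (hypothetical helper) of the MK4 indexing assumed by the
// driver above: 4 consecutive rows are interleaved per column, so output is
// indexed by m / 4 * LDC and stepped by B_INTERLEAVE * 4 per column block,
// assuming LDC counts elements between consecutive 4-row blocks.
#include <cstddef>
constexpr size_t mk4_offset(size_t row, size_t col, size_t LDC) {
    return row / 4 * LDC + col * 4 + row % 4;
}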
///////////////////////// gemm_s8_8x8 //////////////////////////////////// | |||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_8x8); | |||
void gemm_s8_8x8::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin, | |||
int y0, int ymax, int k0, int kmax, | |||
bool transpose) const { | |||
if (transpose) { | |||
matmul_8x8x8::gemm_s8_8x8_transpose_pack_A_n(outptr, inptr, ldin, y0, | |||
ymax, k0, kmax); | |||
} else { | |||
matmul_8x8x8::gemm_s8_8x8_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, | |||
kmax); | |||
} | |||
} | |||
void gemm_s8_8x8::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, | |||
int xmax, int k0, int kmax, bool transpose) const { | |||
if (transpose) { | |||
matmul_8x8x8::gemm_s8_8x8_transpose_pack_B_n(out, in, ldin, x0, xmax, | |||
k0, kmax); | |||
} else { | |||
matmul_8x8x8::gemm_s8_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); | |||
} | |||
} | |||
void gemm_s8_8x8::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||
size_t N, size_t K, dt_int32* C, size_t LDC, | |||
bool is_first_k, const dt_int32*, dt_int32*) const { | |||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||
((A_dtype.enumv() == DTypeEnum::Int8 && | |||
C_dtype.enumv() == DTypeEnum::Int32) || | |||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||
C_dtype.enumv() == DTypeEnum::QuantizedS32)), | |||
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), | |||
C_dtype.name()); | |||
MEGDNN_MARK_USED_VAR(A_dtype); | |||
MEGDNN_MARK_USED_VAR(B_dtype); | |||
MEGDNN_MARK_USED_VAR(C_dtype); | |||
constexpr size_t A_INTERLEAVE = 8; | |||
constexpr size_t B_INTERLEAVE = 8; | |||
    //! K is packed to a multiple of 8 | |||
K = round_up<size_t>(K, 8); | |||
const int K8 = K * 8; | |||
const int K4 = K * 4; | |||
size_t m = 0; | |||
for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) { | |||
int32_t* output = C + (m * LDC); | |||
size_t n = 0; | |||
const dt_int8* cur_packB = packB; | |||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
matmul_8x8x8::kern_8x8(packA, cur_packB, K, output, LDC, | |||
is_first_k); | |||
output += B_INTERLEAVE; | |||
cur_packB += K8; | |||
} | |||
for (; n < N; n += 4) { | |||
matmul_8x8x8::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k, | |||
std::min<size_t>(N - n, 4)); | |||
output += 4; | |||
cur_packB += K4; | |||
} | |||
packA += K8; | |||
} | |||
for (; m < M; m += 4) { | |||
int32_t* output = C + (m * LDC); | |||
const dt_int8* cur_packB = packB; | |||
size_t n = 0; | |||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
matmul_8x8x8::kern_4x8(packA, cur_packB, K, output, LDC, is_first_k, | |||
std::min<size_t>(M - m, 4)); | |||
output += B_INTERLEAVE; | |||
cur_packB += K8; | |||
} | |||
for (; n < N; n += 4) { | |||
matmul_8x8x8::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k, | |||
std::min<size_t>(M - m, 4), | |||
std::min<size_t>(N - n, 4)); | |||
output += 4; | |||
cur_packB += K4; | |||
} | |||
packA += K4; | |||
} | |||
} | |||
#endif | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,34 @@ | |||
/** | |||
* \file dnn/src/aarch64/matrix_mul/int8/strategy.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#if !(__ARM_FEATURE_DOTPROD) | |||
#include "src/fallback/matrix_mul/gemm_common.h" | |||
namespace megdnn { | |||
namespace aarch64 { | |||
namespace matmul { | |||
MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 4, 4, 16, false, true, | |||
gemm_s8_4x4); | |||
MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 4, 4, 16, false, false, | |||
gemm_mk4_s8_4x4); | |||
MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 8, 8, 8, false, true, | |||
gemm_s8_8x8); | |||
} // namespace matmul | |||
} // namespace aarch64 | |||
} // namespace megdnn | |||
#endif | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,116 @@ | |||
/** | |||
* \file dnn/src/aarch64/matrix_mul/int8_dot/gemv.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/aarch64/matrix_mul/int8_dot/gemv.h" | |||
#include <cstddef> | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/common/utils.h" | |||
#include "src/common/unroll_macro.h" | |||
#if __ARM_FEATURE_DOTPROD | |||
namespace { | |||
void gemv_naive_n(const int8_t* __restrict A, const int8_t* __restrict B, | |||
int32_t* __restrict C, size_t M, size_t N, size_t K, | |||
size_t Astride, size_t Bstride, size_t Cstride) { | |||
megdnn_assert(N == 1 && Bstride == 1); | |||
size_t m = 0; | |||
for (; m + 2 <= M; m += 2) { | |||
int32_t acc[4]; | |||
int32x4_t acc_neon = vdupq_n_s32(0); | |||
size_t k = 0; | |||
for (; k + 16 <= K; k += 16) { | |||
int64x2_t a0 = vreinterpretq_s64_s8(vld1q_s8(A + m * Astride + k)); | |||
int64x2_t a1 = | |||
vreinterpretq_s64_s8(vld1q_s8(A + (m + 1) * Astride + k)); | |||
            //! the first 8 elements belong to row m, the last 8 to row m + 1 | |||
int8x16_t a2 = vreinterpretq_s8_s64(vzip1q_s64(a0, a1)); | |||
int8x16_t a3 = vreinterpretq_s8_s64(vzip2q_s64(a0, a1)); | |||
int64x2_t b0 = vreinterpretq_s64_s8(vld1q_s8(B + k)); | |||
int8x16_t b2 = vreinterpretq_s8_s64(vzip1q_s64(b0, b0)); | |||
int8x16_t b3 = vreinterpretq_s8_s64(vzip2q_s64(b0, b0)); | |||
acc_neon = vdotq_s32(acc_neon, a2, b2); | |||
acc_neon = vdotq_s32(acc_neon, a3, b3); | |||
} | |||
vst1q_s32(acc, acc_neon); | |||
for (; k + 8 <= K; k += 8) { | |||
int8x8_t a0 = vld1_s8(A + m * Astride + k); | |||
int8x8_t a1 = vld1_s8(A + (m + 1) * Astride + k); | |||
int8x8_t b0 = vld1_s8(B + k); | |||
            int32x2_t zero = vdup_n_s32(0); | |||
acc[0] += vaddv_s32(vdot_s32(zero, a0, b0)); | |||
zero = vdup_n_s32(0); | |||
acc[3] += vaddv_s32(vdot_s32(zero, a1, b0)); | |||
} | |||
for (; k < K; ++k) { | |||
acc[0] += static_cast<int32_t>(A[m * Astride + k]) * B[k]; | |||
acc[3] += static_cast<int32_t>(A[(m + 1) * Astride + k]) * B[k]; | |||
} | |||
C[m * Cstride] = acc[0] + acc[1]; | |||
C[(m + 1) * Cstride] = acc[2] + acc[3]; | |||
} | |||
for (; m < M; ++m) { | |||
int32_t acc[4]; | |||
int32x4_t acc_neon = vdupq_n_s32(0); | |||
size_t k = 0; | |||
for (; k + 16 <= K; k += 16) { | |||
int8x16_t a0 = vld1q_s8(A + m * Astride + k); | |||
int8x16_t b0 = vld1q_s8(B + k); | |||
acc_neon = vdotq_s32(acc_neon, a0, b0); | |||
} | |||
vst1q_s32(acc, acc_neon); | |||
for (; k + 8 <= K; k += 8) { | |||
int8x8_t a0 = vld1_s8(A + m * Astride + k); | |||
int8x8_t b0 = vld1_s8(B + k); | |||
            int32x2_t zero = vdup_n_s32(0); | |||
acc[0] += vaddv_s32(vdot_s32(zero, a0, b0)); | |||
} | |||
for (; k < K; ++k) { | |||
acc[0] += static_cast<int32_t>(A[m * Astride + k]) * B[k]; | |||
} | |||
C[m * Cstride] = acc[0] + acc[1] + acc[2] + acc[3]; | |||
} | |||
} | |||
} // namespace | |||
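// Reference semantics of gemv_naive_n above, as a plain scalar sketch
// (illustrative only): for N == 1, C[m * Cstride] is the dot product of row
// m of A with B; the NEON path computes the same sums 16 (or 8) elements at
// a time with vdotq_s32/vdot_s32.
inline void gemv_reference(const int8_t* A, const int8_t* B, int32_t* C,
                           size_t M, size_t K, size_t Astride,
                           size_t Cstride) {
    for (size_t m = 0; m < M; ++m) {
        int32_t acc = 0;
        for (size_t k = 0; k < K; ++k)
            acc += static_cast<int32_t>(A[m * Astride + k]) * B[k];
        C[m * Cstride] = acc;
    }
}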
bool megdnn::aarch64::matmul::is_gemv_like_preferred_int8( | |||
bool transposeA, bool transposeB, size_t M, size_t N, size_t K, | |||
size_t /* LDA */, size_t LDB, size_t /* LDC */) { | |||
if (transposeA) | |||
return false; | |||
if (transposeB) | |||
return false; | |||
MEGDNN_MARK_USED_VAR(K); | |||
MEGDNN_MARK_USED_VAR(M); | |||
return (N == 1 && LDB == 1); | |||
} | |||
void megdnn::aarch64::matmul::gemv_like_int8(const int8_t* __restrict A, | |||
const int8_t* __restrict B, | |||
int32_t* __restrict C, size_t M, | |||
size_t N, size_t K, size_t Astride, | |||
size_t Bstride, size_t Cstride) { | |||
megdnn_assert(N == 1); | |||
return gemv_naive_n(A, B, C, M, N, K, Astride, Bstride, Cstride); | |||
} | |||
#endif | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,34 @@ | |||
/** | |||
* \file dnn/src/aarch64/matrix_mul/int8_dot/gemv.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include <cstddef> | |||
#include <cstdint> | |||
#if __ARM_FEATURE_DOTPROD | |||
namespace megdnn { | |||
namespace aarch64 { | |||
namespace matmul { | |||
bool is_gemv_like_preferred_int8(bool transposeA, bool transposeB, size_t M, | |||
size_t N, size_t K, size_t LDA, size_t LDB, | |||
size_t LDC); | |||
void gemv_like_int8(const int8_t* __restrict A, const int8_t* __restrict B, | |||
int32_t* __restrict C, size_t M, size_t N, size_t K, | |||
size_t Astride, size_t Bstride, size_t Cstride); | |||
} // namespace matmul | |||
} // namespace aarch64 | |||
} // namespace megdnn | |||
#endif | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,113 @@ | |||
/** | |||
* \file dnn/src/aarch64/matrix_mul/int8_dot/strategy.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/aarch64/matrix_mul/int8_dot/strategy.h" | |||
#include "src/aarch64/matrix_mul/asm/common.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/common/utils.h" | |||
#include "src/aarch64/matrix_mul/int8_dot/kernel_8x12x4.h" | |||
#if __ARM_FEATURE_DOTPROD | |||
using namespace megdnn; | |||
using namespace aarch64; | |||
using namespace aarch64::matmul; | |||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_8x12); | |||
void gemm_s8_8x12::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin, | |||
int y0, int ymax, int k0, int kmax, | |||
bool transpose) const { | |||
if (transpose) { | |||
matmul_8x12x4::gemm_s8_8x12_pack_A_t(outptr, inptr, ldin, y0, ymax, k0, | |||
kmax); | |||
} else { | |||
matmul_8x12x4::gemm_s8_8x12_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, | |||
kmax); | |||
} | |||
} | |||
void gemm_s8_8x12::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, | |||
int xmax, int k0, int kmax, bool transpose) const { | |||
if (transpose) { | |||
matmul_8x12x4::gemm_s8_8x12_pack_B_t(out, in, ldin, x0, xmax, k0, kmax); | |||
} else { | |||
matmul_8x12x4::gemm_s8_8x12_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); | |||
} | |||
} | |||
void gemm_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||
size_t N, size_t K, dt_int32* C, size_t LDC, | |||
bool is_first_k, const dt_int32*, dt_int32*) const { | |||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||
((A_dtype.enumv() == DTypeEnum::Int8 && | |||
C_dtype.enumv() == DTypeEnum::Int32) || | |||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||
C_dtype.enumv() == DTypeEnum::QuantizedS32)), | |||
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), | |||
C_dtype.name()); | |||
MEGDNN_MARK_USED_VAR(A_dtype); | |||
MEGDNN_MARK_USED_VAR(B_dtype); | |||
MEGDNN_MARK_USED_VAR(C_dtype); | |||
constexpr size_t A_INTERLEAVE = 8; | |||
constexpr size_t B_INTERLEAVE = 12; | |||
    //! K is packed to a multiple of 4 | |||
K = round_up<size_t>(K, 4); | |||
const int K8 = (K << 3); | |||
const int K12 = K * 12; | |||
const int K4 = K * 4; | |||
size_t m = 0; | |||
for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) { | |||
int32_t* output = C + (m * LDC); | |||
size_t n = 0; | |||
const dt_int8* cur_packB = packB; | |||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
matmul_8x12x4::kern_8x12(packA, cur_packB, K, output, LDC, | |||
is_first_k); | |||
output += B_INTERLEAVE; | |||
cur_packB += K12; | |||
} | |||
for (; n < N; n += 4) { | |||
matmul_8x12x4::kern_8x4(packA, cur_packB, K, output, LDC, | |||
is_first_k, std::min<size_t>(N - n, 4)); | |||
output += 4; | |||
cur_packB += K4; | |||
} | |||
packA += K8; | |||
} | |||
for (; m < M; m += 4) { | |||
int32_t* output = C + (m * LDC); | |||
const dt_int8* cur_packB = packB; | |||
size_t n = 0; | |||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
matmul_8x12x4::kern_4x12(packA, cur_packB, K, output, LDC, | |||
is_first_k, std::min<size_t>(M - m, 4)); | |||
output += B_INTERLEAVE; | |||
cur_packB += K12; | |||
} | |||
for (; n < N; n += 4) { | |||
matmul_8x12x4::kern_4x4(packA, cur_packB, K, output, LDC, | |||
is_first_k, std::min<size_t>(M - m, 4), | |||
std::min<size_t>(N - n, 4)); | |||
output += 4; | |||
cur_packB += K4; | |||
} | |||
packA += K4; | |||
} | |||
} | |||
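// Worked example of the panel strides above (assumed K = 10): the packers
// pad K up to a multiple of 4, giving K = 12; one packed 8-row A panel then
// spans K8 = 96 bytes and one 12-column B panel K12 = 144 bytes.
static_assert((10 + 3) / 4 * 4 * 8 == 96 && (10 + 3) / 4 * 4 * 12 == 144,
              "K = 10 rounds to 12; panel strides are 96 and 144 bytes");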
#endif | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,26 @@ | |||
/** | |||
* \file dnn/src/aarch64/matrix_mul/int8_dot/strategy.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "src/fallback/matrix_mul/gemm_common.h" | |||
#if __ARM_FEATURE_DOTPROD | |||
namespace megdnn { | |||
namespace aarch64 { | |||
namespace matmul { | |||
MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 8, 12, 4, false, true, | |||
gemm_s8_8x12); | |||
} // namespace matmul | |||
} // namespace aarch64 | |||
} // namespace megdnn | |||
#endif | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,439 @@ | |||
/** | |||
* \file dnn/src/aarch64/matrix_mul/int8x8x16/kernel_4x4x16.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include <inttypes.h> | |||
#include "src/aarch64/matrix_mul/asm/common.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
namespace megdnn { | |||
namespace aarch64 { | |||
namespace matmul_4x4x16 { | |||
/** | |||
* Overview of register layout: | |||
* | |||
 *                    +---------+---------+---------+---------+ | |||
 *               Rhs  |v20[0-15]|v21[0-15]|v22[0-15]|v23[0-15]| | |||
 *                    +---------+---------+---------+---------+ | |||
 *    Lhs | |||
 *  +--------+ - - -  +---------+---------+---------+---------+ | |||
 *  |v0[0-15]|        | v4[0-7] | v8[0-7] | v12[0-7]| v16[0-7]| | |||
 *  |v1[0-15]|        | v5[0-7] | v9[0-7] | v13[0-7]| v17[0-7]| | |||
 *  |v2[0-15]|        | v6[0-7] | v10[0-7]| v14[0-7]| v18[0-7]| | |||
 *  |v3[0-15]|        | v7[0-7] | v11[0-7]| v15[0-7]| v19[0-7]| | |||
 *  +--------+ - - -  +---------+---------+---------+---------+ | |||
* | |||
* Accumulator | |||
*/ | |||
static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||
int16_t* output, int LDC, bool is_first_k, int m_remain, | |||
int n_remain) { | |||
K /= 16; | |||
const int8_t* a_ptr = packA; | |||
const int8_t* b_ptr = packB; | |||
LDC = LDC * sizeof(int16_t); | |||
// clang-format off | |||
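    // LOAD_LINE/STORE_LINE move one row of the 4x4 int16 output tile: x5
    // counts the rows still to process (m_remain), and n_remain selects a
    // full .4h access or per-lane accesses for partial columns.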
#define LOAD_LINE(reg_index, n) \ | |||
"cmp x5, #0 \n" \ | |||
"beq 105f\n" \ | |||
"cmp %w[n_remain], #4\n" \ | |||
"blt 100" n "f\n" \ | |||
"ld1 {v" reg_index ".4h}, [x" n "], #8\n" \ | |||
"b 101" n "f\n" \ | |||
"100" n ":\n" \ | |||
"cmp %w[n_remain], #0\n" \ | |||
"blt 101" n "f\n" \ | |||
"ld1 {v" reg_index ".h}[0], [x" n "], #2\n" \ | |||
"cmp %w[n_remain], #1\n" \ | |||
"beq 101" n "f\n" \ | |||
"ld1 {v" reg_index ".h}[1], [x" n "], #2\n" \ | |||
"cmp %w[n_remain], #2\n" \ | |||
"beq 101" n "f\n" \ | |||
"ld1 {v" reg_index ".h}[2], [x" n "], #2\n" \ | |||
"101" n ":\n" \ | |||
"sub x5, x5, #1\n" | |||
#define LOAD_C \ | |||
"mov x5, %x[m_remain]\n" \ | |||
LOAD_LINE("24", "0") \ | |||
LOAD_LINE("25", "1") \ | |||
LOAD_LINE("26", "2") \ | |||
LOAD_LINE("27", "3") \ | |||
"105:\n" | |||
#define STORE_LINE(reg_index, n) \ | |||
"cmp x5, #0 \n" \ | |||
"beq 105f\n" \ | |||
"cmp %w[n_remain], #4\n" \ | |||
"blt 102" n "f\n" \ | |||
"st1 {v" reg_index ".4h}, [x" n "], #8\n" \ | |||
"b 103" n "f\n" \ | |||
"102" n ":\n" \ | |||
"cmp %w[n_remain], #0\n" \ | |||
"beq 103" n "f\n" \ | |||
"st1 {v" reg_index ".h}[0], [x" n "], #2\n" \ | |||
"cmp %w[n_remain], #1\n" \ | |||
"beq 103" n "f\n" \ | |||
"st1 {v" reg_index ".h}[1], [x" n "], #2\n" \ | |||
"cmp %w[n_remain], #2\n" \ | |||
"beq 103" n "f\n" \ | |||
"st1 {v" reg_index ".h}[2], [x" n "], #2\n" \ | |||
"103" n ":\n" \ | |||
"sub x5, x5, #1\n" | |||
#define STORE_C \ | |||
"mov x5, %x[m_remain]\n" \ | |||
STORE_LINE("24", "0") \ | |||
STORE_LINE("25", "1") \ | |||
STORE_LINE("26", "2") \ | |||
STORE_LINE("27", "3") \ | |||
"105:\n" | |||
// clang-format on | |||
register int16_t* outptr asm("x0") = output; | |||
asm volatile( | |||
"add x1, x0, %x[LDC]\n" | |||
"add x2, x1, %x[LDC]\n" | |||
"add x3, x2, %x[LDC]\n" | |||
// Clear accumulators | |||
"eor v4.16b, v4.16b, v4.16b\n" | |||
"eor v5.16b, v5.16b, v5.16b\n" | |||
"eor v6.16b, v6.16b, v6.16b\n" | |||
"eor v7.16b, v7.16b, v7.16b\n" | |||
"eor v8.16b, v8.16b, v8.16b\n" | |||
"eor v9.16b, v9.16b, v9.16b\n" | |||
"eor v10.16b, v10.16b, v10.16b\n" | |||
"eor v11.16b, v11.16b, v11.16b\n" | |||
"eor v12.16b, v12.16b, v12.16b\n" | |||
"eor v13.16b, v13.16b, v13.16b\n" | |||
"eor v14.16b, v14.16b, v14.16b\n" | |||
"eor v15.16b, v15.16b, v15.16b\n" | |||
"eor v16.16b, v16.16b, v16.16b\n" | |||
"eor v17.16b, v17.16b, v17.16b\n" | |||
"eor v18.16b, v18.16b, v18.16b\n" | |||
"eor v19.16b, v19.16b, v19.16b\n" | |||
// General loop. | |||
"1:\n" | |||
"ld1 {v20.16b}, [%[b_ptr]], 16\n" | |||
"ld1 {v0.16b}, [%[a_ptr]], 16\n" | |||
"ld1 {v1.16b}, [%[a_ptr]], 16\n" | |||
"ld1 {v2.16b}, [%[a_ptr]], 16\n" | |||
"ld1 {v3.16b}, [%[a_ptr]], 16\n" | |||
"ld1 {v21.16b}, [%[b_ptr]], 16\n" | |||
"smlal v4.8h, v0.8b, v20.8b\n" | |||
"smlal v5.8h, v1.8b, v20.8b\n" | |||
"smlal v6.8h, v2.8b, v20.8b\n" | |||
"smlal v7.8h, v3.8b, v20.8b\n" | |||
"smlal2 v4.8h, v0.16b, v20.16b\n" | |||
"smlal2 v5.8h, v1.16b, v20.16b\n" | |||
"smlal2 v6.8h, v2.16b, v20.16b\n" | |||
"smlal2 v7.8h, v3.16b, v20.16b\n" | |||
"ld1 {v22.16b}, [%[b_ptr]], 16\n" | |||
"smlal v8.8h, v0.8b, v21.8b\n" | |||
"smlal v9.8h, v1.8b, v21.8b\n" | |||
"smlal v10.8h, v2.8b, v21.8b\n" | |||
"smlal v11.8h, v3.8b, v21.8b\n" | |||
"smlal2 v8.8h, v0.16b, v21.16b\n" | |||
"smlal2 v9.8h, v1.16b, v21.16b\n" | |||
"smlal2 v10.8h, v2.16b, v21.16b\n" | |||
"smlal2 v11.8h, v3.16b, v21.16b\n" | |||
"ld1 {v23.16b}, [%[b_ptr]], 16\n" | |||
"smlal v12.8h, v0.8b, v22.8b\n" | |||
"smlal v13.8h, v1.8b, v22.8b\n" | |||
"smlal v14.8h, v2.8b, v22.8b\n" | |||
"smlal v15.8h, v3.8b, v22.8b\n" | |||
"smlal2 v12.8h, v0.16b, v22.16b\n" | |||
"smlal2 v13.8h, v1.16b, v22.16b\n" | |||
"smlal2 v14.8h, v2.16b, v22.16b\n" | |||
"smlal2 v15.8h, v3.16b, v22.16b\n" | |||
"smlal v16.8h, v0.8b, v23.8b\n" | |||
"smlal v17.8h, v1.8b, v23.8b\n" | |||
"smlal v18.8h, v2.8b, v23.8b\n" | |||
"smlal v19.8h, v3.8b, v23.8b\n" | |||
"smlal2 v16.8h, v0.16b, v23.16b\n" | |||
"smlal2 v17.8h, v1.16b, v23.16b\n" | |||
"smlal2 v18.8h, v2.16b, v23.16b\n" | |||
"smlal2 v19.8h, v3.16b, v23.16b\n" | |||
"subs %w[K], %w[K], #1\n" | |||
"cbnz %w[K], 1b\n" | |||
"cmp %w[is_first_k], #1\n" | |||
"beq 2f\n" LOAD_C | |||
"b 3f\n" | |||
"2:\n" // Clear the C regs. | |||
"eor v24.16b, v24.16b, v24.16b\n" | |||
"eor v25.16b, v25.16b, v25.16b\n" | |||
"eor v26.16b, v26.16b, v26.16b\n" | |||
"eor v27.16b, v27.16b, v27.16b\n" | |||
"3:\n" | |||
// Reduce v4-v19 to v0-v3 | |||
"addv h20, v4.8h\n" | |||
"addv h21, v8.8h\n" | |||
"addv h22, v12.8h\n" | |||
"addv h23, v16.8h\n" | |||
"ins v0.h[0], v20.h[0]\n" | |||
"ins v0.h[1], v21.h[0]\n" | |||
"ins v0.h[2], v22.h[0]\n" | |||
"ins v0.h[3], v23.h[0]\n" | |||
"add v24.4h, v24.4h, v0.4h\n" | |||
"addv h28, v5.8h\n" | |||
"addv h29, v9.8h\n" | |||
"addv h30, v13.8h\n" | |||
"addv h31, v17.8h\n" | |||
"ins v1.h[0], v28.h[0]\n" | |||
"ins v1.h[1], v29.h[0]\n" | |||
"ins v1.h[2], v30.h[0]\n" | |||
"ins v1.h[3], v31.h[0]\n" | |||
"add v25.4h, v25.4h, v1.4h\n" | |||
"addv h20, v6.8h\n" | |||
"addv h21, v10.8h\n" | |||
"addv h22, v14.8h\n" | |||
"addv h23, v18.8h\n" | |||
"ins v2.h[0], v20.h[0]\n" | |||
"ins v2.h[1], v21.h[0]\n" | |||
"ins v2.h[2], v22.h[0]\n" | |||
"ins v2.h[3], v23.h[0]\n" | |||
"add v26.4h, v26.4h, v2.4h\n" | |||
"addv h28, v7.8h\n" | |||
"addv h29, v11.8h\n" | |||
"addv h30, v15.8h\n" | |||
"addv h31, v19.8h\n" | |||
"ins v3.h[0], v28.h[0]\n" | |||
"ins v3.h[1], v29.h[0]\n" | |||
"ins v3.h[2], v30.h[0]\n" | |||
"ins v3.h[3], v31.h[0]\n" | |||
"add v27.4h, v27.4h, v3.4h\n" | |||
// Store back into memory | |||
STORE_C | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||
[outptr] "+r"(outptr), [m_remain] "+r"(m_remain), | |||
[n_remain] "+r"(n_remain) | |||
: | |||
: "cc", "memory", "x1", "x2", "x3", "x4", "x5", "v0", "v1", "v2", | |||
"v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", | |||
"v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", | |||
"v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", | |||
"v31"); | |||
#undef LOAD_LINE | |||
#undef LOAD_C | |||
#undef STORE_LINE | |||
#undef STORE_C | |||
} | |||
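// A hedged intrinsics sketch (assumes <arm_neon.h>; not the code path used
// here) of the accumulation scheme in the asm above: each (row, column)
// pair keeps an int16x8_t accumulator fed by smlal/smlal2 and reduced with
// addv at the end. The int16 accumulator matches the kernel's own overflow
// behaviour for int8x8x16.
static inline int16_t dot16_s8_to_s16(int8x16_t a, int8x16_t b) {
    int16x8_t acc = vdupq_n_s16(0);
    acc = vmlal_s8(acc, vget_low_s8(a), vget_low_s8(b));  // smlal  (low 8)
    acc = vmlal_high_s8(acc, a, b);                       // smlal2 (high 8)
    return vaddvq_s16(acc);                               // addv h, v.8h
}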
static void gemm_s8x8x16_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||
int ldin, int y0, int ymax, int k0, | |||
int kmax) { | |||
int8_t zerobuff[16]; | |||
std::memset(zerobuff, 0, sizeof(int8_t) * 16); | |||
int y = y0; | |||
for (; y + 3 < ymax; y += 4) { | |||
const int8_t* inptr0 = inptr + y * ldin + k0; | |||
const int8_t* inptr1 = inptr0 + ldin; | |||
const int8_t* inptr2 = inptr1 + ldin; | |||
const int8_t* inptr3 = inptr2 + ldin; | |||
prefetch_2x(inptr0); | |||
prefetch_2x(inptr1); | |||
prefetch_2x(inptr2); | |||
prefetch_2x(inptr3); | |||
int K = kmax - k0; | |||
        //! read a 4 x 16 block (4 rows, 16 int8 values each) per iteration | |||
for (; K > 15; K -= 16) { | |||
interleave_4x16_1_b(inptr0, inptr1, inptr2, inptr3, outptr); | |||
} | |||
if (K > 0) { | |||
interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 16, K); | |||
} | |||
} | |||
for (; y < ymax; y += 4) { | |||
const int8_t* inptr0 = inptr + y * ldin + k0; | |||
const int8_t* inptr1 = inptr0 + ldin; | |||
const int8_t* inptr2 = inptr1 + ldin; | |||
const int8_t* inptr3 = inptr2 + ldin; | |||
prefetch_2x(inptr0); | |||
prefetch_2x(inptr1); | |||
prefetch_2x(inptr2); | |||
prefetch_2x(inptr3); | |||
int K = kmax - k0; | |||
        //! read a 4 x 16 block (4 rows, 16 int8 values each) per iteration | |||
for (; K > 15; K -= 16) { | |||
if (y + 3 >= ymax) { | |||
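                //! intentional switch fallthrough: every row pointer from
                //! the first row past ymax onwards is redirected to zerobuff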
switch (y + 3 - ymax) { | |||
case 2: | |||
inptr1 = zerobuff; | |||
case 1: | |||
inptr2 = zerobuff; | |||
case 0: | |||
inptr3 = zerobuff; | |||
break; | |||
default: | |||
megdnn_assert(0); | |||
} | |||
} | |||
interleave_4x16_1_b(inptr0, inptr1, inptr2, inptr3, outptr); | |||
} | |||
if (K > 0) { | |||
if (y + 3 >= ymax) { | |||
switch (y + 3 - ymax) { | |||
case 2: | |||
inptr1 = zerobuff; | |||
case 1: | |||
inptr2 = zerobuff; | |||
case 0: | |||
inptr3 = zerobuff; | |||
break; | |||
default: | |||
megdnn_assert(0); | |||
} | |||
} | |||
interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 16, K); | |||
} | |||
} | |||
} | |||
static void gemm_s8x8x16_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||
int x0, int xmax, int k0, int kmax) { | |||
int8_t zerobuff[16]; | |||
std::memset(zerobuff, 0, sizeof(int8_t) * 16); | |||
const int ksize = kmax - k0; | |||
const int ksize4 = round_up(ksize, 16) * 4; | |||
int8_t* outptr = out; | |||
int k = k0; | |||
for (; k < kmax; k += 16) { | |||
int ki = k; | |||
for (int cnt = 0; cnt < 2; ki += 8, cnt++) { | |||
const int8_t* inptr0 = in + ki * ldin + x0; | |||
const int8_t* inptr1 = inptr0 + ldin; | |||
const int8_t* inptr2 = inptr1 + ldin; | |||
const int8_t* inptr3 = inptr2 + ldin; | |||
const int8_t* inptr4 = inptr3 + ldin; | |||
const int8_t* inptr5 = inptr4 + ldin; | |||
const int8_t* inptr6 = inptr5 + ldin; | |||
const int8_t* inptr7 = inptr6 + ldin; | |||
prefetch_2x(inptr0); | |||
prefetch_2x(inptr1); | |||
prefetch_2x(inptr2); | |||
prefetch_2x(inptr3); | |||
prefetch_2x(inptr4); | |||
prefetch_2x(inptr5); | |||
prefetch_2x(inptr6); | |||
prefetch_2x(inptr7); | |||
int8_t* outptr_inner = outptr + ki - k; | |||
int remain = std::min(ki + 7 - kmax, 7); | |||
int x = x0; | |||
for (; x + 3 < xmax; x += 4) { | |||
if (remain >= 0) { | |||
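                    //! intentional switch fallthrough: rows past kmax read
                    //! from zerobuff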
switch (remain) { | |||
case 7: | |||
inptr0 = zerobuff; | |||
case 6: | |||
inptr1 = zerobuff; | |||
case 5: | |||
inptr2 = zerobuff; | |||
case 4: | |||
inptr3 = zerobuff; | |||
case 3: | |||
inptr4 = zerobuff; | |||
case 2: | |||
inptr5 = zerobuff; | |||
case 1: | |||
inptr6 = zerobuff; | |||
case 0: | |||
inptr7 = zerobuff; | |||
break; | |||
default: | |||
megdnn_assert(0); | |||
} | |||
} | |||
transpose_4x16_1_b_helper(inptr0, inptr1, inptr2, inptr3, | |||
inptr4, inptr5, inptr6, inptr7, | |||
outptr_inner); | |||
outptr_inner += ksize4; | |||
} | |||
if (x < xmax) { | |||
if (remain >= 0) { | |||
switch (remain) { | |||
case 7: | |||
inptr0 = zerobuff; | |||
case 6: | |||
inptr1 = zerobuff; | |||
case 5: | |||
inptr2 = zerobuff; | |||
case 4: | |||
inptr3 = zerobuff; | |||
case 3: | |||
inptr4 = zerobuff; | |||
case 2: | |||
inptr5 = zerobuff; | |||
case 1: | |||
inptr6 = zerobuff; | |||
case 0: | |||
inptr7 = zerobuff; | |||
break; | |||
default: | |||
megdnn_assert(0); | |||
} | |||
} | |||
for (; x < xmax; x++) { | |||
*outptr_inner++ = *inptr0++; | |||
*outptr_inner++ = *inptr1++; | |||
*outptr_inner++ = *inptr2++; | |||
*outptr_inner++ = *inptr3++; | |||
*outptr_inner++ = *inptr4++; | |||
*outptr_inner++ = *inptr5++; | |||
*outptr_inner++ = *inptr6++; | |||
*outptr_inner++ = *inptr7++; | |||
outptr_inner += 8; | |||
} | |||
} | |||
} | |||
outptr += 16 * 4; | |||
} | |||
} | |||
} // namespace matmul_4x4x16 | |||
} // namespace aarch64 | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,200 @@ | |||
/** | |||
* \file dnn/src/aarch64/matrix_mul/int8x8x16/strategy.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/aarch64/matrix_mul/asm/common.h" | |||
#include "src/aarch64/matrix_mul/int8x8x16/kernel_4x4x16.h" | |||
#include "src/aarch64/matrix_mul/int8x8x16/kernel_8x8x8.h" | |||
#include "src/aarch64/matrix_mul/int8x8x16/strategy.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/common/utils.h" | |||
#include "src/fallback/matrix_mul/gemm_common.h" | |||
using namespace megdnn; | |||
using namespace aarch64; | |||
using namespace aarch64::matmul; | |||
// ===========================gemm_s8x8x16_8x8================================== | |||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_8x8); | |||
void gemm_s8x8x16_8x8::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0, | |||
int ymax, int k0, int kmax, | |||
bool transpose) const { | |||
if (transpose) { | |||
matmul_8x8x8::gemm_s8x8x16_8x8_transpose_pack_A_n(out, in, ldin, y0, | |||
ymax, k0, kmax); | |||
} else { | |||
matmul_8x8x8::gemm_s8x8x16_8x8_pack_A_n(out, in, ldin, y0, ymax, k0, | |||
kmax); | |||
} | |||
} | |||
void gemm_s8x8x16_8x8::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, | |||
int xmax, int k0, int kmax, | |||
bool transpose) const { | |||
if (transpose) { | |||
matmul_8x8x8::gemm_s8x8x16_8x8_transpose_pack_B_n(out, in, ldin, x0, | |||
xmax, k0, kmax); | |||
} else { | |||
matmul_8x8x8::gemm_s8x8x16_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, | |||
kmax); | |||
} | |||
} | |||
void gemm_s8x8x16_8x8::kern(const dt_int8* packA, const dt_int8* packB, | |||
size_t M, size_t N, size_t K, dt_int16* C, | |||
size_t LDC, bool is_first_k, const dt_int16*, | |||
dt_int16*) const { | |||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||
(A_dtype.enumv() == DTypeEnum::Int8 && | |||
C_dtype.enumv() == DTypeEnum::Int16), | |||
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), | |||
C_dtype.name()); | |||
MEGDNN_MARK_USED_VAR(A_dtype); | |||
MEGDNN_MARK_USED_VAR(B_dtype); | |||
MEGDNN_MARK_USED_VAR(C_dtype); | |||
constexpr size_t A_INTERLEAVE = 8; | |||
constexpr size_t B_INTERLEAVE = 8; | |||
    //! K is packed to a multiple of 8 | |||
K = round_up<size_t>(K, 8); | |||
const int K8 = K * 8; | |||
const int K4 = K * 4; | |||
size_t m = 0; | |||
for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) { | |||
int16_t* output = C + (m * LDC); | |||
size_t n = 0; | |||
const dt_int8* cur_packB = packB; | |||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
matmul_8x8x8::kern_8x8(packA, cur_packB, K, output, LDC, | |||
is_first_k); | |||
output += B_INTERLEAVE; | |||
cur_packB += K8; | |||
} | |||
for (; n < N; n += 4) { | |||
matmul_8x8x8::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k, | |||
std::min<size_t>(N - n, 4)); | |||
output += 4; | |||
cur_packB += K4; | |||
} | |||
packA += K8; | |||
} | |||
for (; m < M; m += 4) { | |||
int16_t* output = C + (m * LDC); | |||
const dt_int8* cur_packB = packB; | |||
size_t n = 0; | |||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
matmul_8x8x8::kern_4x8(packA, cur_packB, K, output, LDC, is_first_k, | |||
std::min<size_t>(M - m, 4)); | |||
output += B_INTERLEAVE; | |||
cur_packB += K8; | |||
} | |||
for (; n < N; n += 4) { | |||
matmul_8x8x8::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k, | |||
std::min<size_t>(M - m, 4), | |||
std::min<size_t>(N - n, 4)); | |||
output += 4; | |||
cur_packB += K4; | |||
} | |||
packA += K4; | |||
} | |||
} | |||
// ===========================gemm_s8x8x16_4x4================================== | |||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_4x4); | |||
void gemm_s8x8x16_4x4::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0, | |||
int ymax, int k0, int kmax, | |||
bool transpose) const { | |||
if (transpose) { | |||
matmul_4x4x16::gemm_s8x8x16_4x4_pack_B_n(out, in, ldin, y0, ymax, k0, | |||
kmax); | |||
} else { | |||
matmul_4x4x16::gemm_s8x8x16_4x4_pack_A_n(out, in, ldin, y0, ymax, k0, | |||
kmax); | |||
} | |||
} | |||
void gemm_s8x8x16_4x4::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, | |||
int xmax, int k0, int kmax, | |||
bool transpose) const { | |||
if (transpose) { | |||
matmul_4x4x16::gemm_s8x8x16_4x4_pack_A_n(out, in, ldin, x0, xmax, k0, | |||
kmax); | |||
} else { | |||
matmul_4x4x16::gemm_s8x8x16_4x4_pack_B_n(out, in, ldin, x0, xmax, k0, | |||
kmax); | |||
} | |||
} | |||
void gemm_s8x8x16_4x4::kern(const dt_int8* packA, const dt_int8* packB, | |||
size_t M, size_t N, size_t K, dt_int16* C, | |||
size_t LDC, bool is_first_k, const dt_int16*, | |||
dt_int16*) const { | |||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||
(A_dtype.enumv() == DTypeEnum::Int8 && | |||
C_dtype.enumv() == DTypeEnum::Int16), | |||
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), | |||
C_dtype.name()); | |||
MEGDNN_MARK_USED_VAR(A_dtype); | |||
MEGDNN_MARK_USED_VAR(B_dtype); | |||
MEGDNN_MARK_USED_VAR(C_dtype); | |||
constexpr size_t A_INTERLEAVE = 4; | |||
constexpr size_t B_INTERLEAVE = 4; | |||
    //! K is packed to a multiple of 16 | |||
K = round_up<size_t>(K, 16); | |||
const int K4 = K * 4; | |||
size_t m = 0; | |||
for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) { | |||
int16_t* output = C + (m * LDC); | |||
size_t n = 0; | |||
const dt_int8* cur_packB = packB; | |||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
matmul_4x4x16::kern_4x4(packA, cur_packB, K, output, LDC, | |||
is_first_k, A_INTERLEAVE, B_INTERLEAVE); | |||
output += B_INTERLEAVE; | |||
cur_packB += K4; | |||
} | |||
for (; n < N; n += B_INTERLEAVE) { | |||
matmul_4x4x16::kern_4x4(packA, cur_packB, K, output, LDC, | |||
is_first_k, A_INTERLEAVE, | |||
std::min<size_t>(N - n, B_INTERLEAVE)); | |||
output += B_INTERLEAVE; | |||
cur_packB += K4; | |||
} | |||
packA += K4; | |||
} | |||
for (; m < M; m += A_INTERLEAVE) { | |||
int16_t* output = C + (m * LDC); | |||
size_t n = 0; | |||
const dt_int8* cur_packB = packB; | |||
for (; n < N; n += B_INTERLEAVE) { | |||
matmul_4x4x16::kern_4x4(packA, cur_packB, K, output, LDC, | |||
is_first_k, | |||
std::min<size_t>(M - m, A_INTERLEAVE), | |||
std::min<size_t>(N - n, B_INTERLEAVE)); | |||
output += B_INTERLEAVE; | |||
cur_packB += K4; | |||
} | |||
packA += K4; | |||
} | |||
} | |||
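// Unlike the 8x8 strategy above, this 4x4 strategy needs no dedicated edge
// kernels: kern_4x4 takes m_remain/n_remain directly, so the same kernel
// serves both full and partial tiles.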
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,27 @@ | |||
/** | |||
* \file dnn/src/aarch64/matrix_mul/int8x8x16/strategy.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "src/fallback/matrix_mul/gemm_common.h" | |||
namespace megdnn { | |||
namespace aarch64 { | |||
namespace matmul { | |||
MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int16, dt_int16, 8, 8, 8, false, true, | |||
gemm_s8x8x16_8x8); | |||
MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int16, dt_int16, 4, 4, 16, false, true, | |||
gemm_s8x8x16_4x4); | |||
} // namespace matmul | |||
} // namespace aarch64 | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,93 @@ | |||
/** | |||
* \file dnn/src/aarch64/matrix_mul/opr_impl.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/aarch64/matrix_mul/opr_impl.h" | |||
#include "src/aarch64/matrix_mul/algos.h" | |||
#include "src/common/metahelper.h" | |||
#include "src/common/utils.h" | |||
using namespace megdnn; | |||
using namespace aarch64; | |||
class MatrixMulImpl::AlgoPack : NonCopyableObj { | |||
AlgoF32K8x12x1 f32K8x12x1; | |||
AlgoF32K4x16x1 f32k4x16x1; | |||
AlgoF32MK4_4x16 f32mk4_4x16; | |||
AlgoF32Gemv f32_gemv; | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
AlgoF16K8x24x1 f16_k8x24x1; | |||
AlgoF16MK8_8x8 f16_mk8_8x8; | |||
#endif | |||
#if __ARM_FEATURE_DOTPROD | |||
AlgoInt8x8x32K8x12x4DotProd int8x8x32_k8x12x4_dotprod; | |||
AlgoInt8x8x32GemvDotProd int8x8x32_gemv_dotprod; | |||
#else | |||
AlgoInt8x8x32MK4_4x4x16 int8x8x32_mk4_4x4x16; | |||
AlgoInt8x8x32K4x4x16 int8x8x32_k4x4x16; | |||
AlgoInt8x8x32K8x8x8 int8x8x32_k8x8x8; | |||
AlgoInt8x8x32Gemv int8x8x32_gemv; | |||
#endif | |||
AlgoInt8x8x16K8x8x8 int8x8x16_k8x8x8; | |||
AlgoInt8x8x16K4x4x16 int8x8x16_k4x4x16; | |||
AlgoInt16x16x32K12x8x1 int16x16x32_k12x8x1; | |||
AlgoInt16x16x32MK8_8x8 int16x16x32_mk8_8x8; | |||
#if __ARM_FEATURE_DOTPROD | |||
AlgoQuint8K8x8x4DotProd quint8_k8x8x4_dotprod; | |||
AlgoQuint8GemvDotProd quint8_gemv_dotprod; | |||
#else | |||
AlgoQuint8K8x8x8 quint8_k8x8x8; | |||
#endif | |||
public: | |||
SmallVector<MatrixMulImpl::AlgoBase*> all_algos; | |||
AlgoPack() { | |||
all_algos.emplace_back(&f32_gemv); | |||
all_algos.emplace_back(&f32K8x12x1); | |||
all_algos.emplace_back(&f32k4x16x1); | |||
all_algos.emplace_back(&f32mk4_4x16); | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
all_algos.emplace_back(&f16_k8x24x1); | |||
all_algos.emplace_back(&f16_mk8_8x8); | |||
#endif | |||
#if __ARM_FEATURE_DOTPROD | |||
all_algos.emplace_back(&int8x8x32_gemv_dotprod); | |||
all_algos.emplace_back(&int8x8x32_k8x12x4_dotprod); | |||
#else | |||
all_algos.emplace_back(&int8x8x32_gemv); | |||
all_algos.emplace_back(&int8x8x32_k8x8x8); | |||
all_algos.emplace_back(&int8x8x32_k4x4x16); | |||
all_algos.emplace_back(&int8x8x32_mk4_4x4x16); | |||
#endif | |||
all_algos.emplace_back(&int8x8x16_k4x4x16); | |||
all_algos.emplace_back(&int8x8x16_k8x8x8); | |||
all_algos.emplace_back(&int16x16x32_k12x8x1); | |||
all_algos.emplace_back(&int16x16x32_mk8_8x8); | |||
#if __ARM_FEATURE_DOTPROD | |||
all_algos.emplace_back(&quint8_gemv_dotprod); | |||
all_algos.emplace_back(&quint8_k8x8x4_dotprod); | |||
#else | |||
all_algos.emplace_back(&quint8_k8x8x8); | |||
#endif | |||
} | |||
}; | |||
SmallVector<MatrixMulImpl::AlgoBase*> MatrixMulImpl::algo_pack() { | |||
static AlgoPack s_algo_pack; | |||
auto&& algos = arm_common::MatrixMulImpl::algo_pack(); | |||
algos.insert(algos.begin(), s_algo_pack.all_algos.begin(), | |||
s_algo_pack.all_algos.end()); | |||
return std::move(algos); | |||
} | |||
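// Note that the aarch64 algorithms are inserted at the front of the
// inherited arm_common list, so they take precedence when the list is
// searched in order.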
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,63 @@ | |||
/** | |||
* \file dnn/src/aarch64/matrix_mul/opr_impl.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "src/arm_common/matrix_mul/opr_impl.h" | |||
namespace megdnn { | |||
namespace aarch64 { | |||
class MatrixMulImpl : public arm_common::MatrixMulImpl { | |||
public: | |||
using arm_common::MatrixMulImpl::MatrixMulImpl; | |||
SmallVector<AlgoBase*> algo_pack() override; | |||
private: | |||
class AlgoF32K8x12x1; // Aarch64 F32 Kernel 8X12X1 | |||
class AlgoF32K4x16x1; // Aarch64 F32 Kernel 4x16x1 | |||
class AlgoF32MK4_4x16; // Aarch64 F32 Format MK4 block 16x4 | |||
class AlgoF32Gemv; // Aarch64 F32 Gemv | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
class AlgoF16K8x24x1; // Aarch64 F16 Kernel 8x24x1 | |||
class AlgoF16MK8_8x8; // Aarch64 F16 Format MK8 block 16x8 | |||
#endif | |||
#if __ARM_FEATURE_DOTPROD | |||
class AlgoInt8x8x32K8x12x4DotProd; // Aarch64 Int8x8x32 Kernel | |||
// 8x12x4 DotProduct | |||
class AlgoInt8x8x32GemvDotProd; // Aarch64 Int8x8x32 Gemv DotProduct | |||
#else | |||
class AlgoInt8x8x32MK4_4x4x16; // Aarch64 nchw44 Int8x8x32 Kernel 4x4x16 | |||
class AlgoInt8x8x32K4x4x16; // Aarch64 Int8x8x32 Kernel 4x4x16 | |||
class AlgoInt8x8x32K8x8x8; // Aarch64 Int8x8x32 Kernel 8x8x8 | |||
class AlgoInt8x8x32Gemv; // Aarch64 Int8x8x32 Gemv | |||
#endif | |||
class AlgoInt8x8x16K8x8x8; // Aarch64 Int8x8x16 Kernel 8x8x8 | |||
class AlgoInt8x8x16K4x4x16; // Aarch64 Int8x8x16 Kernel 4x4x16 | |||
class AlgoInt16x16x32K12x8x1; // Aarch64 Int16x16x32 Kernel 12x8x1 | |||
class AlgoInt16x16x32MK8_8x8; // Aarch64 Int16x16x32 Format MK8 block 8x8 | |||
#if __ARM_FEATURE_DOTPROD | |||
class AlgoQuint8K8x8x4DotProd; // Aarch64 Quint8 Kernel | |||
// 8x8x4 DotProduct | |||
class AlgoQuint8GemvDotProd; // Aarch64 Quint8 Gemv DotProduct | |||
#else | |||
class AlgoQuint8K8x8x8; // Aarch64 Quint8 Kernel 8x8x8 | |||
#endif | |||
class AlgoPack; | |||
}; | |||
} // namespace aarch64 | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,113 @@ | |||
/** | |||
* \file dnn/src/aarch64/matrix_mul/quint8/strategy.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#if !(__ARM_FEATURE_DOTPROD) | |||
#include "src/aarch64/matrix_mul/quint8/strategy.h" | |||
#include "src/aarch64/matrix_mul/asm/common.h" | |||
#include "src/aarch64/matrix_mul/quint8/kernel_8x8x8.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/common/utils.h" | |||
using namespace megdnn; | |||
using namespace aarch64; | |||
using namespace aarch64::matmul; | |||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8); | |||
void gemm_u8_8x8::pack_A(dt_uint8* outptr, const dt_uint8* inptr, int ldin, | |||
int y0, int ymax, int k0, int kmax, | |||
bool transpose) const { | |||
uint8_t zA = A_dtype.param<dtype::Quantized8Asymm>().zero_point; | |||
if (transpose) { | |||
matmul_8x8x8::gemm_u8_8x8_transpose_pack_A_n(outptr, inptr, ldin, y0, | |||
ymax, k0, kmax, zA); | |||
} else { | |||
matmul_8x8x8::gemm_u8_8x8_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, | |||
kmax, zA); | |||
} | |||
} | |||
void gemm_u8_8x8::pack_B(dt_uint8* out, const dt_uint8* in, int ldin, int x0, | |||
int xmax, int k0, int kmax, bool transpose) const { | |||
uint8_t zB = B_dtype.param<dtype::Quantized8Asymm>().zero_point; | |||
if (transpose) { | |||
matmul_8x8x8::gemm_u8_8x8_transpose_pack_B_n(out, in, ldin, x0, xmax, | |||
k0, kmax, zB); | |||
} else { | |||
matmul_8x8x8::gemm_u8_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax, | |||
zB); | |||
} | |||
} | |||
void gemm_u8_8x8::kern(const dt_uint8* packA, const dt_uint8* packB, size_t M, | |||
size_t N, size_t K, dt_int32* C, size_t LDC, | |||
bool is_first_k, const dt_int32*, dt_int32*) const { | |||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||
A_dtype.enumv() == DTypeEnum::Quantized8Asymm && | |||
C_dtype.enumv() == DTypeEnum::QuantizedS32, | |||
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), | |||
C_dtype.name()); | |||
uint8_t zA = A_dtype.param<dtype::Quantized8Asymm>().zero_point; | |||
uint8_t zB = B_dtype.param<dtype::Quantized8Asymm>().zero_point; | |||
constexpr size_t A_INTERLEAVE = 8; | |||
constexpr size_t B_INTERLEAVE = 8; | |||
    //! K is packed to a multiple of 8 | |||
K = round_up<size_t>(K, 8); | |||
const int K8 = K * 8; | |||
const int K4 = K * 4; | |||
size_t m = 0; | |||
for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) { | |||
int32_t* output = C + (m * LDC); | |||
size_t n = 0; | |||
const dt_uint8* cur_packB = packB; | |||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
matmul_8x8x8::kern_8x8(packA, cur_packB, K, output, LDC, is_first_k, | |||
zA, zB); | |||
output += B_INTERLEAVE; | |||
cur_packB += K8; | |||
} | |||
for (; n < N; n += 4) { | |||
matmul_8x8x8::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k, | |||
std::min<size_t>(N - n, 4), zA, zB); | |||
output += 4; | |||
cur_packB += K4; | |||
} | |||
packA += K8; | |||
} | |||
for (; m < M; m += 4) { | |||
int32_t* output = C + (m * LDC); | |||
const dt_uint8* cur_packB = packB; | |||
size_t n = 0; | |||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
matmul_8x8x8::kern_4x8(packA, cur_packB, K, output, LDC, is_first_k, | |||
std::min<size_t>(M - m, 4), zA, zB); | |||
output += B_INTERLEAVE; | |||
cur_packB += K8; | |||
} | |||
for (; n < N; n += 4) { | |||
matmul_8x8x8::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k, | |||
std::min<size_t>(M - m, 4), | |||
std::min<size_t>(N - n, 4), zA, zB); | |||
output += 4; | |||
cur_packB += K4; | |||
} | |||
packA += K4; | |||
} | |||
} | |||
#endif | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,28 @@ | |||
/** | |||
* \file dnn/src/aarch64/matrix_mul/quint8/strategy.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#if !(__ARM_FEATURE_DOTPROD) | |||
#include "src/fallback/matrix_mul/gemm_common.h" | |||
namespace megdnn { | |||
namespace aarch64 { | |||
namespace matmul { | |||
MEGDNN_REG_GEMM_STRATEGY(dt_uint8, dt_int32, dt_int32, 8, 8, 8, false, true, | |||
gemm_u8_8x8); | |||
} // namespace matmul | |||
} // namespace aarch64 | |||
} // namespace megdnn | |||
#endif | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,177 @@ | |||
/** | |||
* \file dnn/src/aarch64/matrix_mul/quint8_dot/gemv.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/aarch64/matrix_mul/quint8_dot/gemv.h" | |||
#include <cstddef> | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/common/utils.h" | |||
#include "src/common/unroll_macro.h" | |||
#if __ARM_FEATURE_DOTPROD | |||
namespace { | |||
void gemv_naive_n(const uint8_t* __restrict A, const uint8_t* __restrict B, | |||
int32_t* __restrict C, size_t M, size_t N, size_t K, | |||
size_t Astride, size_t Bstride, size_t Cstride, | |||
uint8_t zero_point_A, uint8_t zero_point_B) { | |||
int32_t zAB = static_cast<int32_t>(zero_point_A) * | |||
static_cast<int32_t>(zero_point_B) * K; | |||
uint8x16_t zAq = vdupq_n_u8(zero_point_A); | |||
uint8x16_t zBq = vdupq_n_u8(zero_point_B); | |||
uint8x8_t zA = vdup_n_u8(zero_point_A); | |||
uint8x8_t zB = vdup_n_u8(zero_point_B); | |||
megdnn_assert(N == 1 && Bstride == 1); | |||
size_t m = 0; | |||
for (; m + 2 <= M; m += 2) { | |||
int32_t acc_zA, acc_zB, acc_zB2; | |||
int32_t acc[4]; | |||
size_t k = 0; | |||
uint32x4_t acc_neon = vdupq_n_u32(0); | |||
{ | |||
uint32x4_t acc_zA_neon = vdupq_n_u32(0); | |||
uint32x4_t acc_zB_neon = vdupq_n_u32(0); | |||
uint32x4_t acc_zB2_neon = vdupq_n_u32(0); | |||
for (; k + 16 <= K; k += 16) { | |||
uint8x16_t elem = vld1q_u8(A + m * Astride + k); | |||
acc_zB_neon = vdotq_u32(acc_zB_neon, zBq, elem); | |||
uint64x2_t a0 = vreinterpretq_u64_u8(elem); | |||
elem = vld1q_u8(A + (m + 1) * Astride + k); | |||
acc_zB2_neon = vdotq_u32(acc_zB2_neon, zBq, elem); | |||
uint64x2_t a1 = vreinterpretq_u64_u8(elem); | |||
                //! the first 8 elements belong to row m, the last 8 to row m + 1 | |||
uint8x16_t a2 = vreinterpretq_u8_u64(vzip1q_u64(a0, a1)); | |||
uint8x16_t a3 = vreinterpretq_u8_u64(vzip2q_u64(a0, a1)); | |||
elem = vld1q_u8(B + k); | |||
acc_zA_neon = vdotq_u32(acc_zA_neon, zAq, elem); | |||
uint64x2_t b0 = vreinterpretq_u64_u8(elem); | |||
uint8x16_t b2 = vreinterpretq_u8_u64(vzip1q_u64(b0, b0)); | |||
uint8x16_t b3 = vreinterpretq_u8_u64(vzip2q_u64(b0, b0)); | |||
acc_neon = vdotq_u32(acc_neon, a2, b2); | |||
acc_neon = vdotq_u32(acc_neon, a3, b3); | |||
} | |||
vst1q_s32(acc, vreinterpretq_s32_u32(acc_neon)); | |||
acc_zA = vaddvq_u32(acc_zA_neon); | |||
acc_zB = vaddvq_u32(acc_zB_neon); | |||
acc_zB2 = vaddvq_u32(acc_zB2_neon); | |||
} | |||
{ | |||
uint32x2_t acc_zA_neon = vdup_n_u32(0); | |||
uint32x2_t acc_zB_neon = vdup_n_u32(0); | |||
uint32x2_t acc_zB2_neon = vdup_n_u32(0); | |||
for (; k + 8 <= K; k += 8) { | |||
uint8x8_t a0 = vld1_u8(A + m * Astride + k); | |||
uint8x8_t a1 = vld1_u8(A + (m + 1) * Astride + k); | |||
uint8x8_t b0 = vld1_u8(B + k); | |||
uint32x2_t zero = vdup_n_u32(0); | |||
acc[0] += vaddv_u32(vdot_u32(zero, a0, b0)); | |||
zero = vdup_n_u32(0); | |||
acc[3] += vaddv_u32(vdot_u32(zero, a1, b0)); | |||
acc_zB_neon = vdot_u32(acc_zB_neon, a0, zB); | |||
acc_zB2_neon = vdot_u32(acc_zB2_neon, a1, zB); | |||
acc_zA_neon = vdot_u32(acc_zA_neon, b0, zA); | |||
} | |||
acc_zA += vaddv_u32(acc_zA_neon); | |||
acc_zB += vaddv_u32(acc_zB_neon); | |||
acc_zB2 += vaddv_u32(acc_zB2_neon); | |||
} | |||
for (; k < K; ++k) { | |||
acc[0] += static_cast<int32_t>(A[m * Astride + k]) * B[k]; | |||
acc[3] += static_cast<int32_t>(A[(m + 1) * Astride + k]) * B[k]; | |||
acc_zA += static_cast<int32_t>(B[k]) * zero_point_A; | |||
acc_zB += static_cast<int32_t>(A[m * Astride + k]) * zero_point_B; | |||
acc_zB2 += static_cast<int32_t>(A[(m + 1) * Astride + k]) * | |||
zero_point_B; | |||
} | |||
C[m * Cstride] = acc[0] + acc[1] + zAB - acc_zA - acc_zB; | |||
C[(m + 1) * Cstride] = acc[2] + acc[3] + zAB - acc_zA - acc_zB2; | |||
} | |||
for (; m < M; ++m) { | |||
int32_t acc[4]; | |||
int32_t acc_zA, acc_zB; | |||
uint32x4_t acc_neon = vdupq_n_u32(0); | |||
size_t k = 0; | |||
{ | |||
uint32x4_t acc_zA_neon = vdupq_n_u32(0); | |||
uint32x4_t acc_zB_neon = vdupq_n_u32(0); | |||
for (; k + 16 <= K; k += 16) { | |||
uint8x16_t a0 = vld1q_u8(A + m * Astride + k); | |||
uint8x16_t b0 = vld1q_u8(B + k); | |||
acc_neon = vdotq_u32(acc_neon, a0, b0); | |||
acc_zB_neon = vdotq_u32(acc_zB_neon, zBq, a0); | |||
acc_zA_neon = vdotq_u32(acc_zA_neon, zAq, b0); | |||
} | |||
vst1q_s32(acc, vreinterpretq_s32_u32(acc_neon)); | |||
acc_zA = vaddvq_u32(acc_zA_neon); | |||
acc_zB = vaddvq_u32(acc_zB_neon); | |||
} | |||
{ | |||
uint32x2_t acc_zA_neon = vdup_n_u32(0); | |||
uint32x2_t acc_zB_neon = vdup_n_u32(0); | |||
for (; k + 8 <= K; k += 8) { | |||
uint8x8_t a0 = vld1_u8(A + m * Astride + k); | |||
uint8x8_t b0 = vld1_u8(B + k); | |||
uint32x2_t zero = vdup_n_u32(0); | |||
acc[0] += vaddv_u32(vdot_u32(zero, a0, b0)); | |||
acc_zB_neon = vdot_u32(acc_zB_neon, a0, zB); | |||
acc_zA_neon = vdot_u32(acc_zA_neon, b0, zA); | |||
} | |||
acc_zA += vaddv_u32(acc_zA_neon); | |||
acc_zB += vaddv_u32(acc_zB_neon); | |||
} | |||
for (; k < K; ++k) { | |||
acc[0] += static_cast<int32_t>(A[m * Astride + k]) * B[k]; | |||
acc_zA += static_cast<int32_t>(B[k]) * zero_point_A; | |||
acc_zB += static_cast<int32_t>(A[m * Astride + k]) * zero_point_B; | |||
} | |||
C[m * Cstride] = | |||
acc[0] + acc[1] + acc[2] + acc[3] + zAB - acc_zA - acc_zB; | |||
} | |||
} | |||
} // namespace | |||
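// The zero-point handling in gemv_naive_n above follows from expanding the
// quantized product (per output element, summing over k):
//
//   sum_k (A_k - zA) * (B_k - zB)
//       = sum_k A_k * B_k - zA * sum_k B_k - zB * sum_k A_k + K * zA * zB
//
// which is exactly acc + zAB - acc_zA - acc_zB in the code.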
bool megdnn::aarch64::matmul::is_gemv_like_preferred_quint8( | |||
bool transposeA, bool transposeB, size_t M, size_t N, size_t K, | |||
size_t /* LDA */, size_t LDB, size_t /* LDC */) { | |||
if (transposeA) | |||
return false; | |||
if (transposeB) | |||
return false; | |||
MEGDNN_MARK_USED_VAR(K); | |||
MEGDNN_MARK_USED_VAR(M); | |||
    //! TODO: re-benchmark gemv on SDM855 | |||
return (N == 1 && LDB == 1); | |||
} | |||
void megdnn::aarch64::matmul::gemv_like_quint8( | |||
const uint8_t* __restrict A, const uint8_t* __restrict B, | |||
int32_t* __restrict C, size_t M, size_t N, size_t K, size_t Astride, | |||
size_t Bstride, size_t Cstride, uint8_t zero_point_A, | |||
uint8_t zero_point_B) { | |||
megdnn_assert(N == 1); | |||
return gemv_naive_n(A, B, C, M, N, K, Astride, Bstride, Cstride, | |||
zero_point_A, zero_point_B); | |||
} | |||
#endif | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,35 @@ | |||
/** | |||
* \file dnn/src/aarch64/matrix_mul/quint8_dot/gemv.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include <cstddef> | |||
#include <cstdint> | |||
#if __ARM_FEATURE_DOTPROD | |||
namespace megdnn { | |||
namespace aarch64 { | |||
namespace matmul { | |||
bool is_gemv_like_preferred_quint8(bool transposeA, bool transposeB, size_t M, | |||
size_t N, size_t K, size_t LDA, size_t LDB, | |||
size_t LDC); | |||
void gemv_like_quint8(const uint8_t* __restrict A, const uint8_t* __restrict B, | |||
int32_t* __restrict C, size_t M, size_t N, size_t K, | |||
size_t Astride, size_t Bstride, size_t Cstride, | |||
uint8_t zero_point_A, uint8_t zero_point_B); | |||
} // namespace matmul | |||
} // namespace aarch64 | |||
} // namespace megdnn | |||
#endif | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,114 @@ | |||
/** | |||
* \file dnn/src/aarch64/matrix_mul/quint8_dot/strategy.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/aarch64/matrix_mul/quint8_dot/strategy.h" | |||
#include "megdnn/dtype.h" | |||
#include "src/aarch64/matrix_mul/asm/common.h" | |||
#include "src/aarch64/matrix_mul/quint8_dot/kernel_8x8x4.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/common/utils.h" | |||
#if __ARM_FEATURE_DOTPROD | |||
using namespace megdnn; | |||
using namespace aarch64; | |||
using namespace aarch64::matmul; | |||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8); | |||
void gemm_u8_8x8::pack_A(uint8_t* outptr, const uint8_t* inptr, int ldin, | |||
int y0, int ymax, int k0, int kmax, | |||
bool transpose) const { | |||
if (transpose) { | |||
matmul_8x8x4::gemm_u8_8x8_transpose_pack_helper(outptr, inptr, ldin, y0, | |||
ymax, k0, kmax); | |||
} else { | |||
matmul_8x8x4::gemm_u8_8x8_interleave_pack_helper(outptr, inptr, ldin, | |||
y0, ymax, k0, kmax); | |||
} | |||
} | |||
void gemm_u8_8x8::pack_B(uint8_t* out, const uint8_t* in, int ldin, int x0, | |||
int xmax, int k0, int kmax, bool transpose) const { | |||
if (transpose) { | |||
matmul_8x8x4::gemm_u8_8x8_interleave_pack_helper(out, in, ldin, x0, | |||
xmax, k0, kmax); | |||
} else { | |||
matmul_8x8x4::gemm_u8_8x8_transpose_pack_helper(out, in, ldin, x0, xmax, | |||
k0, kmax); | |||
} | |||
} | |||
void gemm_u8_8x8::kern(const uint8_t* packA, const uint8_t* packB, size_t M, | |||
size_t N, size_t K, dt_int32* C, size_t LDC, | |||
bool is_first_k, const dt_int32*, dt_int32*) const { | |||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||
A_dtype.enumv() == DTypeEnum::Quantized8Asymm && | |||
C_dtype.enumv() == DTypeEnum::QuantizedS32); | |||
MEGDNN_MARK_USED_VAR(C_dtype); | |||
size_t zero_point_A = A_dtype.param<dtype::Quantized8Asymm>().zero_point; | |||
size_t zero_point_B = B_dtype.param<dtype::Quantized8Asymm>().zero_point; | |||
constexpr size_t A_INTERLEAVE = 8; | |||
constexpr size_t B_INTERLEAVE = 8; | |||
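    //! K * zp_A * zp_B is the constant term from expanding | |||
    //! sum_k (a_k - zp_A) * (b_k - zp_B); it is computed with the original | |||
    //! (unpadded) K and passed down to the kernels to correct the raw u8 | |||
    //! dot products | |||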
const uint32_t zAB = static_cast<uint32_t>(zero_point_A) * | |||
static_cast<uint32_t>(zero_point_B) * K; | |||
//! K is packed to times of 4 | |||
K = round_up<size_t>(K, 4); | |||
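    //! K8 / K4: sizes (in elements) of one packed panel covering 8 / 4 | |||
    //! rows of A or columns of B over the rounded-up K | |||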
    const int K8 = K * 8; | |||
    const int K4 = K * 4; | |||
size_t m = 0; | |||
for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) { | |||
int32_t* output = C + (m * LDC); | |||
size_t n = 0; | |||
const dt_uint8* cur_packB = packB; | |||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
matmul_8x8x4::kern_8x8(packA, cur_packB, K, output, LDC, is_first_k, | |||
zero_point_A, zero_point_B, zAB); | |||
output += B_INTERLEAVE; | |||
cur_packB += K8; | |||
} | |||
for (; n < N; n += 4) { | |||
matmul_8x8x4::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k, | |||
std::min<size_t>(N - n, 4), zero_point_A, | |||
zero_point_B, zAB); | |||
output += 4; | |||
cur_packB += K4; | |||
} | |||
packA += K8; | |||
} | |||
for (; m < M; m += 4) { | |||
int32_t* output = C + (m * LDC); | |||
const dt_uint8* cur_packB = packB; | |||
size_t n = 0; | |||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
matmul_8x8x4::kern_4x8(packA, cur_packB, K, output, LDC, is_first_k, | |||
std::min<size_t>(M - m, 4), zero_point_A, | |||
zero_point_B, zAB); | |||
output += B_INTERLEAVE; | |||
cur_packB += K8; | |||
} | |||
for (; n < N; n += 4) { | |||
matmul_8x8x4::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k, | |||
std::min<size_t>(M - m, 4), | |||
std::min<size_t>(N - n, 4), zero_point_A, | |||
zero_point_B, zAB); | |||
output += 4; | |||
cur_packB += K4; | |||
} | |||
packA += K4; | |||
} | |||
} | |||
#endif | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,26 @@ | |||
/** | |||
* \file dnn/src/aarch64/matrix_mul/quint8_dot/strategy.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "src/fallback/matrix_mul/gemm_common.h" | |||
#if __ARM_FEATURE_DOTPROD | |||
namespace megdnn { | |||
namespace aarch64 { | |||
namespace matmul { | |||
MEGDNN_REG_GEMM_STRATEGY(uint8_t, int32_t, int32_t, 8, 8, 4, false, true, | |||
gemm_u8_8x8); | |||
}  // namespace matmul | |||
}  // namespace aarch64 | |||
} // namespace megdnn | |||
#endif | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,183 @@ | |||
/** | |||
* \file dnn/src/aarch64/relayout/opr_impl.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/common/utils.h" | |||
#include "src/common/relayout_helper.h" | |||
#include "src/aarch64/handle.h" | |||
#include "src/aarch64/relayout/opr_impl.h" | |||
using namespace megdnn; | |||
using namespace relayout; | |||
namespace { | |||
struct TransposeByte { | |||
uint8_t v; | |||
}; | |||
void trans_16x16_u8(const void* src, void* dst, const size_t src_step, | |||
const size_t dst_step) { | |||
asm volatile( | |||
"\n" | |||
"ld1 {v0.16b}, [%[src]], %[src_step] \n" | |||
"ld1 {v1.16b}, [%[src]], %[src_step] \n" | |||
"ld1 {v2.16b}, [%[src]], %[src_step] \n" | |||
"ld1 {v3.16b}, [%[src]], %[src_step] \n" | |||
"ld1 {v4.16b}, [%[src]], %[src_step] \n" | |||
"ld1 {v5.16b}, [%[src]], %[src_step] \n" | |||
"ld1 {v6.16b}, [%[src]], %[src_step] \n" | |||
"ld1 {v7.16b}, [%[src]], %[src_step] \n" | |||
"ld1 {v8.16b}, [%[src]], %[src_step] \n" | |||
"ld1 {v9.16b}, [%[src]], %[src_step] \n" | |||
"ld1 {v10.16b}, [%[src]], %[src_step] \n" | |||
"ld1 {v11.16b}, [%[src]], %[src_step] \n" | |||
"ld1 {v12.16b}, [%[src]], %[src_step] \n" | |||
"ld1 {v13.16b}, [%[src]], %[src_step] \n" | |||
"ld1 {v14.16b}, [%[src]], %[src_step] \n" | |||
"ld1 {v15.16b}, [%[src]], %[src_step] \n" | |||
"trn1 v16.16b, v0.16b, v1.16b \n" | |||
"trn2 v17.16b, v0.16b, v1.16b \n" | |||
"trn1 v18.16b, v2.16b, v3.16b \n" | |||
"trn2 v19.16b, v2.16b, v3.16b \n" | |||
"trn1 v20.16b, v4.16b, v5.16b \n" | |||
"trn2 v21.16b, v4.16b, v5.16b \n" | |||
"trn1 v22.16b, v6.16b, v7.16b \n" | |||
"trn2 v23.16b, v6.16b, v7.16b \n" | |||
"trn1 v24.16b, v8.16b, v9.16b \n" | |||
"trn2 v25.16b, v8.16b, v9.16b \n" | |||
"trn1 v26.16b, v10.16b, v11.16b \n" | |||
"trn2 v27.16b, v10.16b, v11.16b \n" | |||
"trn1 v28.16b, v12.16b, v13.16b \n" | |||
"trn2 v29.16b, v12.16b, v13.16b \n" | |||
"trn1 v30.16b, v14.16b, v15.16b \n" | |||
"trn2 v31.16b, v14.16b, v15.16b \n" | |||
"trn1 v0.8h, v16.8h, v18.8h \n" | |||
"trn2 v2.8h, v16.8h, v18.8h \n" | |||
"trn1 v4.8h, v20.8h, v22.8h \n" | |||
"trn2 v6.8h, v20.8h, v22.8h \n" | |||
"trn1 v8.8h, v24.8h, v26.8h \n" | |||
"trn2 v10.8h, v24.8h, v26.8h \n" | |||
"trn1 v12.8h, v28.8h, v30.8h \n" | |||
"trn2 v14.8h, v28.8h, v30.8h \n" | |||
"trn1 v1.8h, v17.8h, v19.8h \n" | |||
"trn2 v3.8h, v17.8h, v19.8h \n" | |||
"trn1 v5.8h, v21.8h, v23.8h \n" | |||
"trn2 v7.8h, v21.8h, v23.8h \n" | |||
"trn1 v9.8h, v25.8h, v27.8h \n" | |||
"trn2 v11.8h, v25.8h, v27.8h \n" | |||
"trn1 v13.8h, v29.8h, v31.8h \n" | |||
"trn2 v15.8h, v29.8h, v31.8h \n" | |||
"trn1 v16.4s, v0.4s, v4.4s \n" | |||
"trn2 v20.4s, v0.4s, v4.4s \n" | |||
"trn1 v24.4s, v8.4s, v12.4s \n" | |||
"trn2 v28.4s, v8.4s, v12.4s \n" | |||
"trn1 v17.4s, v1.4s, v5.4s \n" | |||
"trn2 v21.4s, v1.4s, v5.4s \n" | |||
"trn1 v25.4s, v9.4s, v13.4s \n" | |||
"trn2 v29.4s, v9.4s, v13.4s \n" | |||
"trn1 v18.4s, v2.4s, v6.4s \n" | |||
"trn2 v22.4s, v2.4s, v6.4s \n" | |||
"trn1 v26.4s, v10.4s, v14.4s \n" | |||
"trn2 v30.4s, v10.4s, v14.4s \n" | |||
"trn1 v19.4s, v3.4s, v7.4s \n" | |||
"trn2 v23.4s, v3.4s, v7.4s \n" | |||
"trn1 v27.4s, v11.4s, v15.4s \n" | |||
"trn2 v31.4s, v11.4s, v15.4s \n" | |||
"trn1 v0.2d, v16.2d, v24.2d \n" | |||
"trn2 v8.2d, v16.2d, v24.2d \n" | |||
"trn1 v1.2d, v17.2d, v25.2d \n" | |||
"trn2 v9.2d, v17.2d, v25.2d \n" | |||
"trn1 v2.2d, v18.2d, v26.2d \n" | |||
"trn2 v10.2d, v18.2d, v26.2d \n" | |||
"trn1 v3.2d, v19.2d, v27.2d \n" | |||
"trn2 v11.2d, v19.2d, v27.2d \n" | |||
"trn1 v4.2d, v20.2d, v28.2d \n" | |||
"trn2 v12.2d, v20.2d, v28.2d \n" | |||
"trn1 v5.2d, v21.2d, v29.2d \n" | |||
"trn2 v13.2d, v21.2d, v29.2d \n" | |||
"trn1 v6.2d, v22.2d, v30.2d \n" | |||
"trn2 v14.2d, v22.2d, v30.2d \n" | |||
"trn1 v7.2d, v23.2d, v31.2d \n" | |||
"trn2 v15.2d, v23.2d, v31.2d \n" | |||
"st1 {v0.16b}, [%[dst]], %[dst_step] \n" | |||
"st1 {v1.16b}, [%[dst]], %[dst_step] \n" | |||
"st1 {v2.16b}, [%[dst]], %[dst_step] \n" | |||
"st1 {v3.16b}, [%[dst]], %[dst_step] \n" | |||
"st1 {v4.16b}, [%[dst]], %[dst_step] \n" | |||
"st1 {v5.16b}, [%[dst]], %[dst_step] \n" | |||
"st1 {v6.16b}, [%[dst]], %[dst_step] \n" | |||
"st1 {v7.16b}, [%[dst]], %[dst_step] \n" | |||
"st1 {v8.16b}, [%[dst]], %[dst_step] \n" | |||
"st1 {v9.16b}, [%[dst]], %[dst_step] \n" | |||
"st1 {v10.16b}, [%[dst]], %[dst_step] \n" | |||
"st1 {v11.16b}, [%[dst]], %[dst_step] \n" | |||
"st1 {v12.16b}, [%[dst]], %[dst_step] \n" | |||
"st1 {v13.16b}, [%[dst]], %[dst_step] \n" | |||
"st1 {v14.16b}, [%[dst]], %[dst_step] \n" | |||
"st1 {v15.16b}, [%[dst]], %[dst_step] \n" | |||
: | |||
[src] "+r" (src), | |||
[dst] "+r" (dst) | |||
: | |||
[src_step] "r" (src_step), | |||
[dst_step] "r" (dst_step) | |||
: | |||
"d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", | |||
"d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "d20", | |||
"d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28", "d29", "d30", | |||
"d31"); | |||
} | |||
} // anonymous namespace | |||
namespace megdnn { | |||
namespace relayout { | |||
namespace transpose_fallback { | |||
template <> | |||
struct transpose_traits<TransposeByte> { | |||
static constexpr size_t block_size = 16; | |||
}; | |||
template <> | |||
void transpose_block<TransposeByte>(const TransposeByte* src, | |||
TransposeByte* dst, const size_t src_stride, | |||
const size_t dst_stride) { | |||
trans_16x16_u8(src, dst, src_stride, dst_stride); | |||
} | |||
} // namespace transpose_fallback | |||
} // namespace relayout | |||
} // namespace megdnn | |||
void aarch64::RelayoutForwardImpl::exec(_megdnn_tensor_in src0, | |||
_megdnn_tensor_out dst0, | |||
Handle* src_handle) { | |||
check_cpu_handle(src_handle); | |||
TensorND src = src0, dst = dst0; | |||
check_layout_and_canonize(src.layout, dst.layout); | |||
relayout::TransposeParam trans_param; | |||
bool trans = relayout::is_transpose(src.layout, dst.layout, trans_param); | |||
if (trans && trans_param.c == 1 && src0.layout.dtype.size() == 1) { | |||
auto sptr = static_cast<TransposeByte*>(src.raw_ptr), | |||
dptr = static_cast<TransposeByte*>(dst.raw_ptr); | |||
MEGDNN_DISPATCH_CPU_KERN_OPR( | |||
transpose_fallback::transpose<TransposeByte>( | |||
trans_param.batch, trans_param.m, trans_param.n, sptr, | |||
dptr)); | |||
return; | |||
} | |||
exec_after_preprocess(src, dst, trans ? &trans_param : nullptr); | |||
} | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,31 @@ | |||
/** | |||
* \file dnn/src/aarch64/relayout/opr_impl.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/oprs.h" | |||
#include "src/fallback/relayout/opr_impl.h" | |||
namespace megdnn { | |||
namespace aarch64 { | |||
class RelayoutForwardImpl final : public fallback::RelayoutForwardImpl { | |||
public: | |||
using fallback::RelayoutForwardImpl::RelayoutForwardImpl; | |||
void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
Handle *src_handle) override; | |||
bool is_thread_safe() const override { return true; } | |||
}; | |||
} // namespace aarch64 | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,393 @@ | |||
/** | |||
* \file dnn/src/aarch64/rotate/opr_impl.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include <cstring> | |||
#include "src/aarch64/rotate/opr_impl.h" | |||
#include "src/aarch64/handle.h" | |||
#include "src/common/cv/common.h" | |||
#include "src/common/cv/helper.h" | |||
#include "src/common/utils.h" | |||
namespace megdnn { | |||
namespace megcv { | |||
void rotate_8uc1_clockwise_16x16(const uchar *src, | |||
uchar *dst, | |||
size_t src_step, size_t dst_step) | |||
{ | |||
asm volatile ("\n" | |||
"ld1 {v0.16b}, [%[src]], %[src_step] \n" | |||
"ld1 {v1.16b}, [%[src]], %[src_step] \n" | |||
"ld1 {v2.16b}, [%[src]], %[src_step] \n" | |||
"ld1 {v3.16b}, [%[src]], %[src_step] \n" | |||
"ld1 {v4.16b}, [%[src]], %[src_step] \n" | |||
"ld1 {v5.16b}, [%[src]], %[src_step] \n" | |||
"ld1 {v6.16b}, [%[src]], %[src_step] \n" | |||
"ld1 {v7.16b}, [%[src]], %[src_step] \n" | |||
"ld1 {v8.16b}, [%[src]], %[src_step] \n" | |||
"ld1 {v9.16b}, [%[src]], %[src_step] \n" | |||
"ld1 {v10.16b}, [%[src]], %[src_step] \n" | |||
"ld1 {v11.16b}, [%[src]], %[src_step] \n" | |||
"ld1 {v12.16b}, [%[src]], %[src_step] \n" | |||
"ld1 {v13.16b}, [%[src]], %[src_step] \n" | |||
"ld1 {v14.16b}, [%[src]], %[src_step] \n" | |||
"ld1 {v15.16b}, [%[src]], %[src_step] \n" | |||
"trn1 v16.16b, v0.16b, v1.16b \n" | |||
"trn2 v17.16b, v0.16b, v1.16b \n" | |||
"trn1 v18.16b, v2.16b, v3.16b \n" | |||
"trn2 v19.16b, v2.16b, v3.16b \n" | |||
"trn1 v20.16b, v4.16b, v5.16b \n" | |||
"trn2 v21.16b, v4.16b, v5.16b \n" | |||
"trn1 v22.16b, v6.16b, v7.16b \n" | |||
"trn2 v23.16b, v6.16b, v7.16b \n" | |||
"trn1 v24.16b, v8.16b, v9.16b \n" | |||
"trn2 v25.16b, v8.16b, v9.16b \n" | |||
"trn1 v26.16b, v10.16b, v11.16b \n" | |||
"trn2 v27.16b, v10.16b, v11.16b \n" | |||
"trn1 v28.16b, v12.16b, v13.16b \n" | |||
"trn2 v29.16b, v12.16b, v13.16b \n" | |||
"trn1 v30.16b, v14.16b, v15.16b \n" | |||
"trn2 v31.16b, v14.16b, v15.16b \n" | |||
"trn1 v0.8h, v16.8h, v18.8h \n" | |||
"trn2 v2.8h, v16.8h, v18.8h \n" | |||
"trn1 v4.8h, v20.8h, v22.8h \n" | |||
"trn2 v6.8h, v20.8h, v22.8h \n" | |||
"trn1 v8.8h, v24.8h, v26.8h \n" | |||
"trn2 v10.8h, v24.8h, v26.8h \n" | |||
"trn1 v12.8h, v28.8h, v30.8h \n" | |||
"trn2 v14.8h, v28.8h, v30.8h \n" | |||
"trn1 v1.8h, v17.8h, v19.8h \n" | |||
"trn2 v3.8h, v17.8h, v19.8h \n" | |||
"trn1 v5.8h, v21.8h, v23.8h \n" | |||
"trn2 v7.8h, v21.8h, v23.8h \n" | |||
"trn1 v9.8h, v25.8h, v27.8h \n" | |||
"trn2 v11.8h, v25.8h, v27.8h \n" | |||
"trn1 v13.8h, v29.8h, v31.8h \n" | |||
"trn2 v15.8h, v29.8h, v31.8h \n" | |||
"trn1 v16.4s, v0.4s, v4.4s \n" | |||
"trn2 v20.4s, v0.4s, v4.4s \n" | |||
"trn1 v24.4s, v8.4s, v12.4s \n" | |||
"trn2 v28.4s, v8.4s, v12.4s \n" | |||
"trn1 v17.4s, v1.4s, v5.4s \n" | |||
"trn2 v21.4s, v1.4s, v5.4s \n" | |||
"trn1 v25.4s, v9.4s, v13.4s \n" | |||
"trn2 v29.4s, v9.4s, v13.4s \n" | |||
"trn1 v18.4s, v2.4s, v6.4s \n" | |||
"trn2 v22.4s, v2.4s, v6.4s \n" | |||
"trn1 v26.4s, v10.4s, v14.4s \n" | |||
"trn2 v30.4s, v10.4s, v14.4s \n" | |||
"trn1 v19.4s, v3.4s, v7.4s \n" | |||
"trn2 v23.4s, v3.4s, v7.4s \n" | |||
"trn1 v27.4s, v11.4s, v15.4s \n" | |||
"trn2 v31.4s, v11.4s, v15.4s \n" | |||
"trn1 v0.2d, v16.2d, v24.2d \n" | |||
"trn2 v8.2d, v16.2d, v24.2d \n" | |||
"trn1 v1.2d, v17.2d, v25.2d \n" | |||
"trn2 v9.2d, v17.2d, v25.2d \n" | |||
"trn1 v2.2d, v18.2d, v26.2d \n" | |||
"trn2 v10.2d, v18.2d, v26.2d \n" | |||
"trn1 v3.2d, v19.2d, v27.2d \n" | |||
"trn2 v11.2d, v19.2d, v27.2d \n" | |||
"trn1 v4.2d, v20.2d, v28.2d \n" | |||
"trn2 v12.2d, v20.2d, v28.2d \n" | |||
"trn1 v5.2d, v21.2d, v29.2d \n" | |||
"trn2 v13.2d, v21.2d, v29.2d \n" | |||
"trn1 v6.2d, v22.2d, v30.2d \n" | |||
"trn2 v14.2d, v22.2d, v30.2d \n" | |||
"trn1 v7.2d, v23.2d, v31.2d \n" | |||
"trn2 v15.2d, v23.2d, v31.2d \n" | |||
// There is no rev128 instruction, so we use rev64 and ext to simulate it. | |||
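            // rev64 reverses the bytes within each 64-bit half of a vector; | |||
            // the ext #8 that follows swaps the two halves, so the pair | |||
            // reverses all 16 bytes of the register. | |||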
"rev64 v0.16b, v0.16b \n" | |||
"rev64 v1.16b, v1.16b \n" | |||
"rev64 v2.16b, v2.16b \n" | |||
"rev64 v3.16b, v3.16b \n" | |||
"rev64 v4.16b, v4.16b \n" | |||
"rev64 v5.16b, v5.16b \n" | |||
"rev64 v6.16b, v6.16b \n" | |||
"rev64 v7.16b, v7.16b \n" | |||
"rev64 v8.16b, v8.16b \n" | |||
"rev64 v9.16b, v9.16b \n" | |||
"rev64 v10.16b, v10.16b \n" | |||
"rev64 v11.16b, v11.16b \n" | |||
"rev64 v12.16b, v12.16b \n" | |||
"rev64 v13.16b, v13.16b \n" | |||
"rev64 v14.16b, v14.16b \n" | |||
"rev64 v15.16b, v15.16b \n" | |||
"ext v0.16b, v0.16b, v0.16b, #8 \n" | |||
"ext v1.16b, v1.16b, v1.16b, #8 \n" | |||
"ext v2.16b, v2.16b, v2.16b, #8 \n" | |||
"ext v3.16b, v3.16b, v3.16b, #8 \n" | |||
"ext v4.16b, v4.16b, v4.16b, #8 \n" | |||
"ext v5.16b, v5.16b, v5.16b, #8 \n" | |||
"ext v6.16b, v6.16b, v6.16b, #8 \n" | |||
"ext v7.16b, v7.16b, v7.16b, #8 \n" | |||
"ext v8.16b, v8.16b, v8.16b, #8 \n" | |||
"ext v9.16b, v9.16b, v9.16b, #8 \n" | |||
"ext v10.16b, v10.16b, v10.16b, #8 \n" | |||
"ext v11.16b, v11.16b, v11.16b, #8 \n" | |||
"ext v12.16b, v12.16b, v12.16b, #8 \n" | |||
"ext v13.16b, v13.16b, v13.16b, #8 \n" | |||
"ext v14.16b, v14.16b, v14.16b, #8 \n" | |||
"ext v15.16b, v15.16b, v15.16b, #8 \n" | |||
"st1 {v0.16b}, [%[dst]], %[dst_step] \n" | |||
"st1 {v1.16b}, [%[dst]], %[dst_step] \n" | |||
"st1 {v2.16b}, [%[dst]], %[dst_step] \n" | |||
"st1 {v3.16b}, [%[dst]], %[dst_step] \n" | |||
"st1 {v4.16b}, [%[dst]], %[dst_step] \n" | |||
"st1 {v5.16b}, [%[dst]], %[dst_step] \n" | |||
"st1 {v6.16b}, [%[dst]], %[dst_step] \n" | |||
"st1 {v7.16b}, [%[dst]], %[dst_step] \n" | |||
"st1 {v8.16b}, [%[dst]], %[dst_step] \n" | |||
"st1 {v9.16b}, [%[dst]], %[dst_step] \n" | |||
"st1 {v10.16b}, [%[dst]], %[dst_step] \n" | |||
"st1 {v11.16b}, [%[dst]], %[dst_step] \n" | |||
"st1 {v12.16b}, [%[dst]], %[dst_step] \n" | |||
"st1 {v13.16b}, [%[dst]], %[dst_step] \n" | |||
"st1 {v14.16b}, [%[dst]], %[dst_step] \n" | |||
"st1 {v15.16b}, [%[dst]], %[dst_step] \n" | |||
: | |||
[src] "+r" (src), | |||
[dst] "+r" (dst) | |||
: | |||
[src_step] "r" (src_step), | |||
[dst_step] "r" (dst_step) | |||
: | |||
"v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", | |||
"v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" | |||
); | |||
} | |||
void rotate_8uc1_counterclockwise_16x16(const uchar *src, | |||
uchar *dst, | |||
size_t src_step, size_t dst_step) | |||
{ | |||
asm volatile ("\n" | |||
"ld1 {v0.16b}, [%[src]], %[src_step] \n" | |||
"ld1 {v1.16b}, [%[src]], %[src_step] \n" | |||
"ld1 {v2.16b}, [%[src]], %[src_step] \n" | |||
"ld1 {v3.16b}, [%[src]], %[src_step] \n" | |||
"ld1 {v4.16b}, [%[src]], %[src_step] \n" | |||
"ld1 {v5.16b}, [%[src]], %[src_step] \n" | |||
"ld1 {v6.16b}, [%[src]], %[src_step] \n" | |||
"ld1 {v7.16b}, [%[src]], %[src_step] \n" | |||
"ld1 {v8.16b}, [%[src]], %[src_step] \n" | |||
"ld1 {v9.16b}, [%[src]], %[src_step] \n" | |||
"ld1 {v10.16b}, [%[src]], %[src_step] \n" | |||
"ld1 {v11.16b}, [%[src]], %[src_step] \n" | |||
"ld1 {v12.16b}, [%[src]], %[src_step] \n" | |||
"ld1 {v13.16b}, [%[src]], %[src_step] \n" | |||
"ld1 {v14.16b}, [%[src]], %[src_step] \n" | |||
"ld1 {v15.16b}, [%[src]], %[src_step] \n" | |||
"trn1 v16.16b, v0.16b, v1.16b \n" | |||
"trn2 v17.16b, v0.16b, v1.16b \n" | |||
"trn1 v18.16b, v2.16b, v3.16b \n" | |||
"trn2 v19.16b, v2.16b, v3.16b \n" | |||
"trn1 v20.16b, v4.16b, v5.16b \n" | |||
"trn2 v21.16b, v4.16b, v5.16b \n" | |||
"trn1 v22.16b, v6.16b, v7.16b \n" | |||
"trn2 v23.16b, v6.16b, v7.16b \n" | |||
"trn1 v24.16b, v8.16b, v9.16b \n" | |||
"trn2 v25.16b, v8.16b, v9.16b \n" | |||
"trn1 v26.16b, v10.16b, v11.16b \n" | |||
"trn2 v27.16b, v10.16b, v11.16b \n" | |||
"trn1 v28.16b, v12.16b, v13.16b \n" | |||
"trn2 v29.16b, v12.16b, v13.16b \n" | |||
"trn1 v30.16b, v14.16b, v15.16b \n" | |||
"trn2 v31.16b, v14.16b, v15.16b \n" | |||
"trn1 v0.8h, v16.8h, v18.8h \n" | |||
"trn2 v2.8h, v16.8h, v18.8h \n" | |||
"trn1 v4.8h, v20.8h, v22.8h \n" | |||
"trn2 v6.8h, v20.8h, v22.8h \n" | |||
"trn1 v8.8h, v24.8h, v26.8h \n" | |||
"trn2 v10.8h, v24.8h, v26.8h \n" | |||
"trn1 v12.8h, v28.8h, v30.8h \n" | |||
"trn2 v14.8h, v28.8h, v30.8h \n" | |||
"trn1 v1.8h, v17.8h, v19.8h \n" | |||
"trn2 v3.8h, v17.8h, v19.8h \n" | |||
"trn1 v5.8h, v21.8h, v23.8h \n" | |||
"trn2 v7.8h, v21.8h, v23.8h \n" | |||
"trn1 v9.8h, v25.8h, v27.8h \n" | |||
"trn2 v11.8h, v25.8h, v27.8h \n" | |||
"trn1 v13.8h, v29.8h, v31.8h \n" | |||
"trn2 v15.8h, v29.8h, v31.8h \n" | |||
"trn1 v16.4s, v0.4s, v4.4s \n" | |||
"trn2 v20.4s, v0.4s, v4.4s \n" | |||
"trn1 v24.4s, v8.4s, v12.4s \n" | |||
"trn2 v28.4s, v8.4s, v12.4s \n" | |||
"trn1 v17.4s, v1.4s, v5.4s \n" | |||
"trn2 v21.4s, v1.4s, v5.4s \n" | |||
"trn1 v25.4s, v9.4s, v13.4s \n" | |||
"trn2 v29.4s, v9.4s, v13.4s \n" | |||
"trn1 v18.4s, v2.4s, v6.4s \n" | |||
"trn2 v22.4s, v2.4s, v6.4s \n" | |||
"trn1 v26.4s, v10.4s, v14.4s \n" | |||
"trn2 v30.4s, v10.4s, v14.4s \n" | |||
"trn1 v19.4s, v3.4s, v7.4s \n" | |||
"trn2 v23.4s, v3.4s, v7.4s \n" | |||
"trn1 v27.4s, v11.4s, v15.4s \n" | |||
"trn2 v31.4s, v11.4s, v15.4s \n" | |||
"trn1 v0.2d, v16.2d, v24.2d \n" | |||
"trn2 v8.2d, v16.2d, v24.2d \n" | |||
"trn1 v1.2d, v17.2d, v25.2d \n" | |||
"trn2 v9.2d, v17.2d, v25.2d \n" | |||
"trn1 v2.2d, v18.2d, v26.2d \n" | |||
"trn2 v10.2d, v18.2d, v26.2d \n" | |||
"trn1 v3.2d, v19.2d, v27.2d \n" | |||
"trn2 v11.2d, v19.2d, v27.2d \n" | |||
"trn1 v4.2d, v20.2d, v28.2d \n" | |||
"trn2 v12.2d, v20.2d, v28.2d \n" | |||
"trn1 v5.2d, v21.2d, v29.2d \n" | |||
"trn2 v13.2d, v21.2d, v29.2d \n" | |||
"trn1 v6.2d, v22.2d, v30.2d \n" | |||
"trn2 v14.2d, v22.2d, v30.2d \n" | |||
"trn1 v7.2d, v23.2d, v31.2d \n" | |||
"trn2 v15.2d, v23.2d, v31.2d \n" | |||
"st1 {v15.16b}, [%[dst]], %[dst_step] \n" | |||
"st1 {v14.16b}, [%[dst]], %[dst_step] \n" | |||
"st1 {v13.16b}, [%[dst]], %[dst_step] \n" | |||
"st1 {v12.16b}, [%[dst]], %[dst_step] \n" | |||
"st1 {v11.16b}, [%[dst]], %[dst_step] \n" | |||
"st1 {v10.16b}, [%[dst]], %[dst_step] \n" | |||
"st1 {v9.16b}, [%[dst]], %[dst_step] \n" | |||
"st1 {v8.16b}, [%[dst]], %[dst_step] \n" | |||
"st1 {v7.16b}, [%[dst]], %[dst_step] \n" | |||
"st1 {v6.16b}, [%[dst]], %[dst_step] \n" | |||
"st1 {v5.16b}, [%[dst]], %[dst_step] \n" | |||
"st1 {v4.16b}, [%[dst]], %[dst_step] \n" | |||
"st1 {v3.16b}, [%[dst]], %[dst_step] \n" | |||
"st1 {v2.16b}, [%[dst]], %[dst_step] \n" | |||
"st1 {v1.16b}, [%[dst]], %[dst_step] \n" | |||
"st1 {v0.16b}, [%[dst]], %[dst_step] \n" | |||
: | |||
[src] "+r" (src), | |||
[dst] "+r" (dst) | |||
: | |||
[src_step] "r" (src_step), | |||
[dst_step] "r" (dst_step) | |||
: | |||
"v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", | |||
"v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" | |||
); | |||
} | |||
void rotate_8uc1_clockwise(const uchar* src, uchar* dst, const size_t rows, | |||
const size_t cols, const size_t src_step, | |||
const size_t dst_step) { | |||
    const size_t block = 16; | |||
size_t i = 0; | |||
for (; i + block <= rows; i += block) { | |||
size_t j = 0; | |||
for (; j + block <= cols; j += block) { | |||
rotate_8uc1_clockwise_16x16( | |||
src + i * src_step + j, | |||
dst + j * dst_step + (rows - (i + block)), src_step, | |||
dst_step); | |||
} | |||
for (; j < cols; ++j) { | |||
for (size_t k = 0; k < block; ++k) { | |||
dst[j * dst_step + (rows - 1 - (i + k))] = | |||
src[(i + k) * src_step + j]; | |||
} | |||
} | |||
} | |||
for (; i < rows; ++i) { | |||
for (size_t j = 0; j < cols; ++j) { | |||
dst[j * dst_step + (rows - 1 - i)] = src[i * src_step + j]; | |||
} | |||
} | |||
} | |||
void rotate_8uc1_counterclockwise(const uchar* src, uchar* dst, | |||
const size_t rows, const size_t cols, | |||
const size_t src_step, | |||
const size_t dst_step) { | |||
    const size_t block = 16; | |||
size_t i = 0; | |||
for (; i + block <= rows; i += block) { | |||
size_t j = 0; | |||
for (; j + block <= cols; j += block) { | |||
rotate_8uc1_counterclockwise_16x16( | |||
src + i * src_step + j, | |||
dst + (cols - (j + block)) * dst_step + i, src_step, | |||
dst_step); | |||
} | |||
for (; j < cols; ++j) { | |||
for (size_t k = 0; k < block; ++k) { | |||
dst[(cols - 1 - j) * dst_step + (i + k)] = | |||
src[(i + k) * src_step + j]; | |||
} | |||
} | |||
} | |||
for (; i < rows; ++i) { | |||
for (size_t j = 0; j < cols; ++j) { | |||
dst[(cols - 1 - j) * dst_step + i] = src[i * src_step + j]; | |||
} | |||
} | |||
} | |||
void rotate(const Mat<uchar>& src, Mat<uchar>& dst, bool clockwise) { | |||
megdnn_assert(src.rows() == dst.cols()); | |||
megdnn_assert(src.cols() == dst.rows()); | |||
megdnn_assert(src.channels() == dst.channels()); | |||
megdnn_assert(src.channels() == 1_z); | |||
if (clockwise) { | |||
rotate_8uc1_clockwise(src.ptr(), dst.ptr(), src.rows(), src.cols(), | |||
src.step(), dst.step()); | |||
} else { | |||
rotate_8uc1_counterclockwise(src.ptr(), dst.ptr(), src.rows(), | |||
src.cols(), src.step(), dst.step()); | |||
} | |||
} | |||
} // namespace megcv | |||
namespace aarch64 { | |||
void RotateImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) { | |||
using namespace megcv; | |||
check_exec(src.layout, dst.layout, workspace.size); | |||
//! rotate only support data type is uchar and the channel size is 1 | |||
if (dst.layout.dtype != dtype::Uint8() || src.layout.shape[3] != 1) { | |||
return fallback::RotateImpl::exec(src, dst, workspace); | |||
} | |||
MEGDNN_DISPATCH_CPU_KERN_OPR({ | |||
for (size_t i = 0; i < src.layout.shape[0]; ++i) { | |||
Mat<uchar> src_mat = TensorND2Mat<uchar>(src, i); | |||
Mat<uchar> dst_mat = TensorND2Mat<uchar>(dst, i); | |||
rotate(src_mat, dst_mat, param().clockwise); | |||
} | |||
}); | |||
} | |||
} // namespace aarch64 | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,35 @@ | |||
/** | |||
* \file dnn/src/aarch64/rotate/opr_impl.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/oprs.h" | |||
#include "src/fallback/rotate/opr_impl.h" | |||
namespace megdnn { | |||
namespace aarch64 { | |||
class RotateImpl : public fallback::RotateImpl { | |||
public: | |||
using fallback::RotateImpl::RotateImpl; | |||
void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) override; | |||
size_t get_workspace_in_bytes(const TensorLayout&, | |||
const TensorLayout&) override { | |||
return 0; | |||
} | |||
}; | |||
} // namespace aarch64 | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,43 @@ | |||
/** | |||
* \file dnn/src/aarch64/warp_perspective/opr_impl.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/aarch64/warp_perspective/opr_impl.h" | |||
#include "src/aarch64/warp_perspective/warp_perspective_cv.h" | |||
#include "src/common/utils.h" | |||
#include "src/common/warp_common.h" | |||
#include "src/naive/handle.h" | |||
namespace megdnn { | |||
namespace aarch64 { | |||
void WarpPerspectiveImpl::exec(_megdnn_tensor_in src, | |||
_megdnn_tensor_in mat, | |||
_megdnn_tensor_in mat_idx, | |||
                               _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) | |||
{ | |||
check_exec(src.layout, mat.layout, mat_idx.layout, dst.layout, | |||
workspace.size); | |||
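    //! the megcv path handles NHWC images without a mat_idx tensor; | |||
    //! everything else falls through to the arm_common implementation | |||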
if (warp::is_cv_available(src.layout, mat.layout, dst.layout, param().imode, | |||
param().format) && !mat_idx.layout.ndim) { | |||
warp_perspective_cv_exec(src, mat, dst, param().border_val, | |||
param().bmode, param().imode, handle()); | |||
} else { | |||
//! Use arm_common implementation | |||
arm_common::WarpPerspectiveImpl::exec(src, mat, mat_idx, dst, workspace); | |||
} | |||
} | |||
} // namespace aarch64 | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,30 @@ | |||
/** | |||
* \file dnn/src/aarch64/warp_perspective/opr_impl.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/oprs.h" | |||
#include "src/arm_common/warp_perspective/opr_impl.h" | |||
namespace megdnn { | |||
namespace aarch64 { | |||
class WarpPerspectiveImpl : public arm_common::WarpPerspectiveImpl { | |||
public: | |||
using arm_common::WarpPerspectiveImpl::WarpPerspectiveImpl; | |||
void exec(_megdnn_tensor_in src, _megdnn_tensor_in mat, | |||
_megdnn_tensor_in mat_idx, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) override; | |||
}; | |||
} // namespace aarch64 | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,257 @@ | |||
/** | |||
* \file dnn/src/aarch64/warp_perspective/warp_perspective_cv.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/aarch64/handle.h" | |||
#include "src/aarch64/warp_perspective/warp_perspective_cv.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/common/cv/common.h" | |||
#include "src/common/cv/helper.h" | |||
#include "src/common/cv/interp_helper.h" | |||
#include "src/common/utils.h" | |||
#include "src/common/warp_common.h" | |||
using namespace megdnn; | |||
using namespace aarch64; | |||
using namespace megcv; | |||
using namespace warp; | |||
namespace { | |||
constexpr size_t BLOCK_SZ = 32u; | |||
template <typename T, InterpolationMode imode, BorderMode bmode, size_t CH> | |||
void warp_perspective_cv(const Mat<T>& src, Mat<T>& dst, const float* trans, | |||
const float border_value, size_t task_id) { | |||
// no extra padding | |||
double M[9]; | |||
rep(i, 9) M[i] = trans[i]; | |||
T bvalue[3] = {(T)border_value, (T)border_value, (T)border_value}; | |||
size_t x1, y1, width = dst.cols(), height = dst.rows(); | |||
size_t BLOCK_SZ_H = std::min(BLOCK_SZ / 2, height); | |||
size_t BLOCK_SZ_W = std::min(BLOCK_SZ * BLOCK_SZ / BLOCK_SZ_H, width); | |||
BLOCK_SZ_H = std::min(BLOCK_SZ * BLOCK_SZ / BLOCK_SZ_W, height); | |||
size_t width_block_size = div_ceil<size_t>(width, BLOCK_SZ_W); | |||
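    // each task processes one BLOCK_SZ_H x BLOCK_SZ_W tile; decompose | |||
    // task_id into the tile's (y, x) origin in row-major block order | |||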
size_t y = (task_id / width_block_size) * BLOCK_SZ_H; | |||
size_t x = (task_id % width_block_size) * BLOCK_SZ_W; | |||
// start invoke | |||
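    // XY holds the integer source coordinates of every dst pixel in the | |||
    // tile; A holds the bilinear interpolation table index (unused when | |||
    // imode == IMode::NEAREST) | |||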
short XY[BLOCK_SZ * BLOCK_SZ * 2], A[BLOCK_SZ * BLOCK_SZ]; | |||
float64x2_t vM6 = vdupq_n_f64(M[6]); | |||
float64x2_t vM0 = vdupq_n_f64(M[0]); | |||
float64x2_t vM3 = vdupq_n_f64(M[3]); | |||
float64x2_t v2M6 = vdupq_n_f64(M[6] * 2); | |||
float64x2_t v2M0 = vdupq_n_f64(M[0] * 2); | |||
float64x2_t v2M3 = vdupq_n_f64(M[3] * 2); | |||
float64x2_t v4f = vdupq_n_f64(4); | |||
float64x2_t v1f = vdupq_n_f64(1); | |||
float64x2_t v0f = vdupq_n_f64(0); | |||
float64x2_t vTABLE_SIZE = vdupq_n_f64(INTER_TAB_SIZE); | |||
float64x2_t vmin = vdupq_n_f64((double)INT_MIN); | |||
float64x2_t vmax = vdupq_n_f64((double)INT_MAX); | |||
int32x4_t vtabmask = vdupq_n_s32(INTER_TAB_SIZE - 1); | |||
size_t bw = std::min(BLOCK_SZ_W, width - x); | |||
size_t bh = std::min(BLOCK_SZ_H, height - y); // height | |||
Mat<short> _XY(bh, bw, 2, XY); | |||
Mat<T> dpart(dst, y, bh, x, bw); | |||
for (y1 = 0; y1 < bh; y1++) { | |||
short* xy = XY + y1 * bw * 2; | |||
double X0 = M[0] * x + M[1] * (y + y1) + M[2]; | |||
double Y0 = M[3] * x + M[4] * (y + y1) + M[5]; | |||
double W0 = M[6] * x + M[7] * (y + y1) + M[8]; | |||
float64x2_t vW0 = vdupq_n_f64(W0); | |||
float64x2_t vidx = {0.f, 1.f}; | |||
float64x2_t vX0 = vdupq_n_f64(X0); | |||
float64x2_t vY0 = vdupq_n_f64(Y0); | |||
if (imode == IMode::NEAREST) { | |||
for (x1 = 0; x1 + 4 <= bw; x1 += 4) { | |||
float64x2_t vw0 = vaddq_f64(vW0, vmulq_f64(vM6, vidx)); | |||
float64x2_t vw1 = vaddq_f64(vw0, v2M6); | |||
vw0 = vbitq_f64(vdivq_f64(v1f, vw0), v0f, vceqq_f64(vw0, v0f)); | |||
vw1 = vbitq_f64(vdivq_f64(v1f, vw1), v0f, vceqq_f64(vw1, v0f)); | |||
float64x2_t vtmp0 = vmlaq_f64(vX0, vM0, vidx); | |||
float64x2_t vtmp1 = vaddq_f64(vtmp0, v2M0); | |||
float64x2_t vfx0 = vmulq_f64(vtmp0, vw0); | |||
float64x2_t vfx1 = vmulq_f64(vtmp1, vw1); | |||
vfx0 = vmaxq_f64(vminq_f64(vfx0, vmax), vmin); | |||
vfx1 = vmaxq_f64(vminq_f64(vfx1, vmax), vmin); | |||
vtmp0 = vmlaq_f64(vY0, vM3, vidx); | |||
vtmp1 = vaddq_f64(vtmp0, v2M3); | |||
float64x2_t vfy0 = vmulq_f64(vtmp0, vw0); | |||
float64x2_t vfy1 = vmulq_f64(vtmp1, vw1); | |||
vfy0 = vmaxq_f64(vminq_f64(vfy0, vmax), vmin); | |||
vfy1 = vmaxq_f64(vminq_f64(vfy1, vmax), vmin); | |||
int32x2_t vx0 = vqmovn_s64(vcvtaq_s64_f64(vfx0)); | |||
int32x2_t vx1 = vqmovn_s64(vcvtaq_s64_f64(vfx1)); | |||
int32x2_t vy0 = vqmovn_s64(vcvtaq_s64_f64(vfy0)); | |||
int32x2_t vy1 = vqmovn_s64(vcvtaq_s64_f64(vfy1)); | |||
int32x4_t vx = vcombine_s32(vx0, vx1); | |||
int32x4_t vy = vcombine_s32(vy0, vy1); | |||
int16x4x2_t ret = {{vqmovn_s32(vx), vqmovn_s32(vy)}}; | |||
vst2_s16(xy + x1 * 2, ret); | |||
vidx = vaddq_f64(vidx, v4f); | |||
} | |||
for (; x1 < bw; x1++) { | |||
double W = W0 + M[6] * x1; | |||
W = W ? 1. / W : 0; | |||
double fX = std::max( | |||
(double)INT_MIN, | |||
std::min((double)INT_MAX, (X0 + M[0] * x1) * W)); | |||
double fY = std::max( | |||
(double)INT_MIN, | |||
std::min((double)INT_MAX, (Y0 + M[3] * x1) * W)); | |||
int X = saturate_cast<int>(fX); | |||
int Y = saturate_cast<int>(fY); | |||
xy[x1 * 2] = saturate_cast<short>(X); | |||
xy[x1 * 2 + 1] = saturate_cast<short>(Y); | |||
} | |||
} else { | |||
short* alpha = A + y1 * bw; | |||
for (x1 = 0; x1 + 4 <= bw; x1 += 4) { | |||
float64x2_t vw0 = vaddq_f64(vW0, vmulq_f64(vM6, vidx)); | |||
float64x2_t vw1 = vaddq_f64(vw0, v2M6); | |||
vw0 = vbitq_f64(vdivq_f64(vTABLE_SIZE, vw0), v0f, | |||
vceqq_f64(vw0, v0f)); | |||
vw1 = vbitq_f64(vdivq_f64(vTABLE_SIZE, vw1), v0f, | |||
vceqq_f64(vw1, v0f)); | |||
float64x2_t vtmp0 = vmlaq_f64(vX0, vM0, vidx); | |||
float64x2_t vtmp1 = vaddq_f64(vtmp0, v2M0); | |||
float64x2_t vfx0 = vmulq_f64(vtmp0, vw0); | |||
float64x2_t vfx1 = vmulq_f64(vtmp1, vw1); | |||
vfx0 = vmaxq_f64(vminq_f64(vfx0, vmax), vmin); | |||
vfx1 = vmaxq_f64(vminq_f64(vfx1, vmax), vmin); | |||
vtmp0 = vmlaq_f64(vY0, vM3, vidx); | |||
vtmp1 = vaddq_f64(vtmp0, v2M3); | |||
float64x2_t vfy0 = vmulq_f64(vtmp0, vw0); | |||
float64x2_t vfy1 = vmulq_f64(vtmp1, vw1); | |||
vfy0 = vmaxq_f64(vminq_f64(vfy0, vmax), vmin); | |||
vfy1 = vmaxq_f64(vminq_f64(vfy1, vmax), vmin); | |||
int32x2_t vx0 = vqmovn_s64(vcvtaq_s64_f64(vfx0)); | |||
int32x2_t vx1 = vqmovn_s64(vcvtaq_s64_f64(vfx1)); | |||
int32x2_t vy0 = vqmovn_s64(vcvtaq_s64_f64(vfy0)); | |||
int32x2_t vy1 = vqmovn_s64(vcvtaq_s64_f64(vfy1)); | |||
int32x4_t vx = vcombine_s32(vx0, vx1); | |||
int32x4_t vy = vcombine_s32(vy0, vy1); | |||
int16x4x2_t ret = {{vqshrn_n_s32(vx, INTER_BITS), | |||
vqshrn_n_s32(vy, INTER_BITS)}}; | |||
vst2_s16(xy + x1 * 2, ret); | |||
vidx = vaddq_f64(vidx, v4f); | |||
vx = vandq_s32(vx, vtabmask); | |||
vy = vandq_s32(vy, vtabmask); | |||
vst1_s16(&alpha[x1], | |||
vqmovn_s32(vmlaq_n_s32(vx, vy, INTER_TAB_SIZE))); | |||
} | |||
for (; x1 < bw; x1++) { | |||
double W = W0 + M[6] * x1; | |||
W = W ? INTER_TAB_SIZE / W : 0; | |||
double fX = std::max( | |||
(double)INT_MIN, | |||
std::min((double)INT_MAX, (X0 + M[0] * x1) * W)); | |||
double fY = std::max( | |||
(double)INT_MIN, | |||
std::min((double)INT_MAX, (Y0 + M[3] * x1) * W)); | |||
int X = saturate_cast<int>(fX); | |||
int Y = saturate_cast<int>(fY); | |||
xy[x1 * 2] = saturate_cast<short>(X >> INTER_BITS); | |||
xy[x1 * 2 + 1] = saturate_cast<short>(Y >> INTER_BITS); | |||
alpha[x1] = | |||
(short)((Y & (INTER_TAB_SIZE - 1)) * INTER_TAB_SIZE + | |||
(X & (INTER_TAB_SIZE - 1))); | |||
} | |||
} | |||
} | |||
Mat<ushort> _matA(bh, bw, 1, (ushort*)(A)); | |||
remap<T, imode, bmode, CH, RemapVec<T, CH>>(src, dpart, _XY, _matA, bvalue); | |||
} | |||
} // anonymous namespace | |||
void megdnn::aarch64::warp_perspective_cv_exec( | |||
_megdnn_tensor_in src, _megdnn_tensor_in trans, _megdnn_tensor_in dst, | |||
float border_value, BorderMode bmode, InterpolationMode imode, | |||
Handle* handle) { | |||
size_t ch = dst.layout[3]; | |||
size_t width = dst.layout[2]; | |||
size_t height = dst.layout[1]; | |||
const size_t batch = dst.layout.shape[0]; | |||
size_t BLOCK_SZ_H = std::min(BLOCK_SZ / 2, height); | |||
size_t BLOCK_SZ_W = std::min(BLOCK_SZ * BLOCK_SZ / BLOCK_SZ_H, width); | |||
BLOCK_SZ_H = std::min(BLOCK_SZ * BLOCK_SZ / BLOCK_SZ_W, height); | |||
size_t parallelism_batch = div_ceil<size_t>(height, BLOCK_SZ_H) * | |||
div_ceil<size_t>(width, BLOCK_SZ_W); | |||
megdnn_assert(ch == 1 || ch == 3 || ch == 2, | |||
"unsupported src channel: %zu, avaiable channel size: 1/2/3", | |||
ch); | |||
const float* trans_ptr = trans.ptr<dt_float32>(); | |||
if (dst.layout.dtype.enumv() == DTypeEnum::Float32) { | |||
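        //! cb instantiates one multi-threaded task per (imode, bmode, | |||
        //! channel) combination; DISPATCH_IMODE picks the matching | |||
        //! instantiation from the runtime parameters | |||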
#define cb(_imode, _bmode, _ch) \ | |||
auto task = [src, trans_ptr, dst, border_value, parallelism_batch]( \ | |||
size_t index, size_t) { \ | |||
size_t batch_id = index / parallelism_batch; \ | |||
size_t task_id = index % parallelism_batch; \ | |||
Mat<float> src_mat = TensorND2Mat<float>(src, batch_id); \ | |||
Mat<float> dst_mat = TensorND2Mat<float>(dst, batch_id); \ | |||
const float* task_trans_ptr = trans_ptr + batch_id * 3 * 3; \ | |||
warp_perspective_cv<float MEGDNN_COMMA _imode MEGDNN_COMMA _bmode \ | |||
MEGDNN_COMMA _ch>( \ | |||
src_mat MEGDNN_COMMA const_cast<Mat<float>&>(dst_mat) \ | |||
MEGDNN_COMMA task_trans_ptr MEGDNN_COMMA border_value, \ | |||
task_id); \ | |||
}; \ | |||
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ | |||
            static_cast<naive::HandleImpl*>(handle), batch * parallelism_batch, \ | |||
task); | |||
DISPATCH_IMODE(imode, bmode, ch, cb) | |||
#undef cb | |||
} else if (dst.layout.dtype.enumv() == DTypeEnum::Uint8) { | |||
#define cb(_imode, _bmode, _ch) \ | |||
auto task = [src, trans_ptr, dst, border_value, parallelism_batch]( \ | |||
size_t index, size_t) { \ | |||
size_t batch_id = index / parallelism_batch; \ | |||
size_t task_id = index % parallelism_batch; \ | |||
Mat<uchar> src_mat = TensorND2Mat<uchar>(src, batch_id); \ | |||
Mat<uchar> dst_mat = TensorND2Mat<uchar>(dst, batch_id); \ | |||
const float* task_trans_ptr = trans_ptr + batch_id * 3 * 3; \ | |||
warp_perspective_cv<uchar MEGDNN_COMMA _imode MEGDNN_COMMA _bmode \ | |||
MEGDNN_COMMA _ch>( \ | |||
src_mat MEGDNN_COMMA const_cast<Mat<uchar>&>(dst_mat) \ | |||
MEGDNN_COMMA task_trans_ptr MEGDNN_COMMA border_value, \ | |||
task_id); \ | |||
}; \ | |||
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ | |||
            static_cast<naive::HandleImpl*>(handle), batch * parallelism_batch, \ | |||
task); | |||
DISPATCH_IMODE(imode, bmode, ch, cb) | |||
#undef cb | |||
} else { | |||
megdnn_throw( | |||
megdnn_mangle("Unsupported datatype of WarpAffine optr.")); | |||
} | |||
} | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,32 @@ | |||
/** | |||
* \file dnn/src/aarch64/warp_perspective/warp_perspective_cv.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include <megdnn/oprs.h> | |||
#include "src/common/cv/helper.h" | |||
namespace megdnn { | |||
namespace aarch64 { | |||
/** | |||
* \fn warp_perspective_cv | |||
 * \brief Used when the format is NHWC; ported from megcv | |||
*/ | |||
void warp_perspective_cv_exec(_megdnn_tensor_in src, _megdnn_tensor_in trans, | |||
_megdnn_tensor_in dst, float border_value, | |||
param::WarpPerspective::BorderMode border_mode, | |||
param::WarpPerspective::InterpolationMode imode, | |||
Handle* handle); | |||
} // namespace aarch64 | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,380 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/direct/multi_thread_common.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/direct/multi_thread_common.h" | |||
#include "src/arm_common/conv_bias/postprocess_helper.h" | |||
#include "src/fallback/matrix_mul/opr_impl.h" | |||
using namespace megdnn; | |||
using namespace arm_common; | |||
namespace { | |||
bool need_dst_copy( | |||
const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { | |||
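    //! the kernels store whole vectors (4 floats / 8 fp16 values); when OW | |||
    //! is not a multiple of the vector width, the result is computed into | |||
    //! a padded scratch dst and the valid columns are copied back | |||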
auto align = param.src_type.enumv() == DTypeEnum::Float32 ? 4 : 8; | |||
return param.osz[1] % align; | |||
} | |||
bool need_src_copy( | |||
const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { | |||
if (param.filter_meta.padding[0] || param.filter_meta.padding[1]) { | |||
return true; | |||
} | |||
return need_dst_copy(param); | |||
} | |||
void get_rectified_size( | |||
const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, | |||
size_t IH, size_t IW, size_t OH, size_t OW, size_t FH, size_t FW, | |||
size_t PH, size_t PW, size_t& IH2, size_t& IW2, size_t& OW2) { | |||
MEGDNN_MARK_USED_VAR(PW); | |||
MEGDNN_MARK_USED_VAR(PH); | |||
auto&& fm = param.filter_meta; | |||
auto SW = fm.stride[1]; | |||
auto Align = param.src_type.enumv() == DTypeEnum::Float32 ? 3 : 7; | |||
OW2 = (OW + Align) & ~Align; | |||
IH2 = SW * OH + FH - SW; | |||
IW2 = SW * OW2 + FW - SW; | |||
    // When the stride is 2, the original IW can exceed the computed IW2 | |||
    // by one; take the max so the rectified sizes always cover the input. | |||
IH2 = std::max(IH2, IH); | |||
IW2 = std::max(IW2, IW); | |||
} | |||
} // namespace | |||
template <typename io_ctype, typename compute_ctype> | |||
WorkspaceBundle MultithreadDirectConvCommon<io_ctype, compute_ctype>::get_bundle( | |||
const ConvBiasImpl::NCBKernSizeParam& param, bool m_large_group) { | |||
auto&& fm = param.filter_meta; | |||
size_t nr_threads = param.nr_threads; | |||
size_t group = fm.group, batch = param.n; | |||
size_t IH2 = param.isz[0] + 2 * fm.padding[0]; | |||
size_t IW2 = param.isz[1] + 2 * fm.padding[1]; | |||
// part0: copied src | |||
// part1: copied filter | |||
size_t part0, part1; | |||
if (fm.padding[0] == 0 && fm.padding[1] == 0) { | |||
        //! only the last plane needs to be copied; add 16 bytes of extra | |||
        //! space to guard against out-of-bounds reads and writes | |||
part0 = (param.isz[0] * param.isz[1]) * sizeof(io_ctype) + 16; | |||
} else if (m_large_group) { | |||
        //! Serial in group: each thread processes one group, parallel by group | |||
part0 = (IH2 * IW2 * fm.icpg * nr_threads) * sizeof(io_ctype) + 16; | |||
} else { | |||
        //! Parallel within a group, so every input must be copied to the workspace | |||
part0 = (IH2 * IW2 * fm.icpg * group * batch) * sizeof(io_ctype) + 16; | |||
} | |||
if (param.filter_meta.should_flip) { | |||
if (m_large_group) { | |||
            //! Serial in group: each thread has its own workspace to reuse | |||
part1 = fm.spatial[0] * fm.spatial[1] * fm.ocpg * fm.icpg * | |||
nr_threads * sizeof(io_ctype); | |||
} else { | |||
part1 = fm.spatial[0] * fm.spatial[1] * fm.ocpg * fm.icpg * group * | |||
sizeof(io_ctype); | |||
} | |||
} else { | |||
part1 = 0; | |||
} | |||
return {nullptr, {part0, part1}}; | |||
} | |||
template <typename io_ctype, typename compute_ctype> | |||
WorkspaceBundle | |||
MultithreadDirectConvCommon<io_ctype, compute_ctype>::get_bundle_stride( | |||
const ConvBiasImpl::NCBKernSizeParam& param, bool m_large_group) { | |||
UNPACK_CONV_F32_NCB_KERN_SIZES(param); | |||
MEGDNN_MARK_USED_VAR(N); | |||
MEGDNN_MARK_USED_VAR(OC); | |||
MEGDNN_MARK_USED_VAR(SH); | |||
MEGDNN_MARK_USED_VAR(SW); | |||
auto&& fm = param.filter_meta; | |||
size_t nr_threads = param.nr_threads; | |||
size_t group = fm.group, batch = param.n; | |||
size_t IH2, IW2, OW2; | |||
get_rectified_size(param, IH, IW, OH, OW, FH, FW, PH, PW, IH2, IW2, OW2); | |||
size_t src_size = 0, dst_size = 0; | |||
// src_size: copied src | |||
// dst_size: copied dst | |||
if (need_src_copy(param)) { | |||
src_size = m_large_group | |||
? IC * IH2 * IW2 * sizeof(io_ctype) * nr_threads | |||
: IC * IH2 * IW2 * sizeof(io_ctype) * group * batch; | |||
    } | |||
if (need_dst_copy(param)) { | |||
        //! add 16 bytes of extra space to guard against out-of-bounds | |||
        //! reads and writes | |||
dst_size = OH * OW2 * sizeof(io_ctype) * nr_threads + 16; | |||
} | |||
return {nullptr, {src_size, dst_size}}; | |||
} | |||
//! Flip the filter weights of one output channel | |||
template <typename io_ctype, typename compute_ctype> | |||
void MultithreadDirectConvCommon<io_ctype, compute_ctype>::weight_flip_kern( | |||
WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param, | |||
const ConvBiasImpl::NCBKernIndex& ncb_index, | |||
const CpuNDRange& workspace_ids) { | |||
size_t FH = kern_param.filter_meta.spatial[0]; | |||
size_t FW = kern_param.filter_meta.spatial[1]; | |||
size_t IC = kern_param.filter_meta.icpg; | |||
size_t OC = kern_param.filter_meta.ocpg; | |||
    //! Used to get the workspace offset | |||
size_t workspace_group_id = workspace_ids[0], channel_id = workspace_ids[2], | |||
group_id = ncb_index.ndrange_id[0]; | |||
const io_ctype* filter = | |||
kern_param.filter<io_ctype>(group_id) + channel_id * FH * FW * IC; | |||
bundle.set(kern_param.workspace_ptr); | |||
io_ctype* filter_flip = | |||
static_cast<io_ctype*>(bundle.get(1)) + | |||
(workspace_group_id * IC * OC + channel_id * IC) * FH * FW; | |||
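    //! spatial flip: dst(fh, fw) = src(FH - 1 - fh, FW - 1 - fw) | |||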
rep(ic, IC) { | |||
const io_ctype* filter_plane = filter + ic * FH * FW; | |||
io_ctype* filter_flip_plane = filter_flip + ic * FH * FW; | |||
rep(fh, FH) rep(fw, FW) { | |||
filter_flip_plane[fh * FW + fw] = | |||
filter_plane[(FH - fh - 1) * FW + (FW - fw - 1)]; | |||
} | |||
} | |||
} | |||
//! Copy one input channel into the padded workspace | |||
template <typename io_ctype, typename compute_ctype> | |||
void MultithreadDirectConvCommon<io_ctype, compute_ctype>::copy_padding_kern( | |||
WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param, | |||
const ConvBiasImpl::NCBKernIndex& ncb_index, | |||
const CpuNDRange& workspace_ids) { | |||
size_t IH = kern_param.isz[0]; | |||
size_t IW = kern_param.isz[1]; | |||
size_t IC = kern_param.filter_meta.icpg; | |||
size_t PH = kern_param.filter_meta.padding[0]; | |||
size_t PW = kern_param.filter_meta.padding[1]; | |||
size_t IH2 = IH + 2 * PH; | |||
size_t IW2 = IW + 2 * PW; | |||
size_t padding_group_size = IH2 * IW2 * IC; | |||
size_t N = kern_param.n; | |||
size_t GROUP = kern_param.filter_meta.group; | |||
bundle.set(kern_param.workspace_ptr); | |||
    //! Used to get the workspace offset | |||
size_t workspace_group_id = workspace_ids[0], | |||
workspace_batch_id = workspace_ids[1], channel_id = workspace_ids[2]; | |||
size_t batch_id = ncb_index.ndrange_id[1], | |||
group_id = ncb_index.ndrange_id[0]; | |||
const io_ctype* sptr = static_cast<const io_ctype*>( | |||
kern_param.src<io_ctype>(batch_id, group_id, channel_id)); | |||
if (PH > 0 || PW > 0) { | |||
//! copy to sptr_base to eliminate padding effect | |||
io_ctype* sptr_base = static_cast<io_ctype*>(bundle.get(0)) + | |||
workspace_group_id * padding_group_size + | |||
workspace_batch_id * GROUP * padding_group_size + | |||
channel_id * IH2 * IW2; | |||
std::memset(sptr_base, 0, sizeof(io_ctype) * IH2 * IW2); | |||
rep(ih, IH) { | |||
std::memcpy(sptr_base + (ih + PH) * IW2 + PW, sptr + ih * IW, | |||
sizeof(io_ctype) * IW); | |||
} | |||
} else if (batch_id + 1 == N && channel_id + 1 == IC && | |||
group_id + 1 == GROUP) { | |||
//! copy last plane | |||
io_ctype* sptr_last_c = static_cast<io_ctype*>(bundle.get(0)); | |||
std::memcpy(sptr_last_c, sptr, sizeof(io_ctype) * IH2 * IW2); | |||
} | |||
} | |||
//! Copy one input channel into the padded workspace (strided variant) | |||
template <typename io_ctype, typename compute_ctype> | |||
void MultithreadDirectConvCommon<io_ctype, compute_ctype>:: | |||
copy_padding_kern_stride(WorkspaceBundle bundle, | |||
const ConvBiasImpl::NCBKernParam& kern_param, | |||
const ConvBiasImpl::NCBKernIndex& ncb_index, | |||
const CpuNDRange& workspace_ids) { | |||
size_t IH = kern_param.isz[0]; | |||
size_t IW = kern_param.isz[1]; | |||
size_t IC = kern_param.filter_meta.icpg; | |||
size_t PH = kern_param.filter_meta.padding[0]; | |||
size_t PW = kern_param.filter_meta.padding[1]; | |||
size_t FH = kern_param.filter_meta.spatial[0]; | |||
size_t FW = kern_param.filter_meta.spatial[1]; | |||
size_t OW = kern_param.osz[1]; | |||
size_t OH = kern_param.osz[0]; | |||
size_t IH2, IW2, OW2; | |||
size_t GROUP = kern_param.filter_meta.group; | |||
get_rectified_size(kern_param, IH, IW, OH, OW, FH, FW, PH, PW, IH2, IW2, OW2); | |||
size_t padding_group_size = IH2 * IW2 * IC; | |||
bundle.set(kern_param.workspace_ptr); | |||
    //! Used to get the workspace offset | |||
size_t workspace_group_id = workspace_ids[0], | |||
workspace_batch_id = workspace_ids[1]; | |||
size_t channel_id = workspace_ids[2], batch_id = ncb_index.ndrange_id[1], | |||
group_id = ncb_index.ndrange_id[0]; | |||
const io_ctype* sptr = static_cast<const io_ctype*>( | |||
kern_param.src<io_ctype>(batch_id, group_id, channel_id)); | |||
if (need_src_copy(kern_param)) { | |||
//! copy to sptr_base to eliminate padding effect | |||
io_ctype* sptr_base = static_cast<io_ctype*>(bundle.get(0)) + | |||
workspace_group_id * padding_group_size + | |||
workspace_batch_id * GROUP * padding_group_size + | |||
channel_id * IH2 * IW2; | |||
std::memset(sptr_base, 0, sizeof(io_ctype) * IH2 * IW2); | |||
rep(ih, IH) { | |||
std::memcpy(sptr_base + (ih + PH) * IW2 + PW, sptr + ih * IW, | |||
sizeof(io_ctype) * IW); | |||
} | |||
} | |||
} | |||
//! compute one output channel | |||
template <typename io_ctype, typename compute_ctype> | |||
void MultithreadDirectConvCommon<io_ctype, compute_ctype>::do_conv_kern( | |||
WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param, | |||
const ConvBiasImpl::NCBKernIndex& ncb_index, | |||
const kern_direct_conv_f32& fun, const CpuNDRange& workspace_ids) { | |||
size_t OH = kern_param.osz[0]; | |||
size_t OW = kern_param.osz[1]; | |||
size_t FH = kern_param.filter_meta.spatial[0]; | |||
size_t FW = kern_param.filter_meta.spatial[1]; | |||
size_t IC = kern_param.filter_meta.icpg; | |||
size_t OC = kern_param.filter_meta.ocpg; | |||
size_t PH = kern_param.filter_meta.padding[0]; | |||
size_t PW = kern_param.filter_meta.padding[1]; | |||
size_t IH2 = kern_param.isz[0] + 2 * PH; | |||
size_t IW2 = kern_param.isz[1] + 2 * PW; | |||
size_t padding_group_size = IH2 * IW2 * IC; | |||
size_t N = kern_param.n; | |||
size_t GROUP = kern_param.filter_meta.group; | |||
bundle.set(kern_param.workspace_ptr); | |||
size_t group_id = ncb_index.ndrange_id[0], | |||
batch_id = ncb_index.ndrange_id[1]; | |||
size_t channel_id = workspace_ids[2]; | |||
const io_ctype* sptr = kern_param.src<io_ctype>(batch_id, group_id); | |||
const io_ctype* filter = kern_param.filter<io_ctype>(group_id); | |||
const io_ctype* bias_ptr = | |||
kern_param.bias<io_ctype>(batch_id, group_id, channel_id); | |||
io_ctype* dptr = kern_param.dst<io_ctype>(batch_id, group_id, channel_id); | |||
    //! Used to get the workspace offset | |||
size_t workspace_batch_id = workspace_ids[1]; | |||
size_t workspace_group_id = workspace_ids[0]; | |||
io_ctype* sptr_base; | |||
io_ctype* sptr_last_c; | |||
auto fptr = | |||
kern_param.filter_meta.should_flip | |||
? static_cast<io_ctype*>(bundle.get(1)) + | |||
(workspace_group_id * OC * IC + channel_id * IC) * | |||
FH * FW | |||
: filter + channel_id * FH * FW * IC; | |||
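    //! pick the src base: the zero-padded workspace copy when padding is | |||
    //! present; otherwise read src in place, except that the very last | |||
    //! plane of the last batch/group reads from its workspace copy to | |||
    //! guard against out-of-bounds access (see copy_padding_kern) | |||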
if (PH > 0 || PW > 0) { | |||
sptr_base = static_cast<io_ctype*>(bundle.get(0)) + | |||
workspace_group_id * padding_group_size + | |||
workspace_batch_id * GROUP * padding_group_size; | |||
sptr_last_c = sptr_base + (IC - 1) * IH2 * IW2; | |||
//! Last batch, last group | |||
} else if (batch_id + 1 == N && group_id + 1 == GROUP) { | |||
sptr_base = const_cast<io_ctype*>(sptr); | |||
sptr_last_c = static_cast<io_ctype*>(bundle.get(0)); | |||
} else { | |||
sptr_base = const_cast<io_ctype*>(sptr); | |||
sptr_last_c = sptr_base + (IC - 1) * IH2 * IW2; | |||
} | |||
std::memset(dptr, 0, sizeof(io_ctype) * (OH * OW)); | |||
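    //! accumulate the 2D convolution of every input channel into dptr | |||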
rep(ic, IC) { | |||
io_ctype* sptr_cur = | |||
(ic + 1 == IC ? sptr_last_c : sptr_base + ic * IH2 * IW2); | |||
fun(reinterpret_cast<const compute_ctype*>(sptr_cur), | |||
reinterpret_cast<const compute_ctype*>(fptr + ic * FH * FW), | |||
reinterpret_cast<compute_ctype*>(dptr), IH2, IW2, OH, OW, FH, FW); | |||
} | |||
PostProcess<compute_ctype>::run(dptr, const_cast<io_ctype*>(bias_ptr), dptr, | |||
kern_param.bias_mode, kern_param.nonlineMode, | |||
kern_param.bias_type, kern_param.dst_type, 1_z, | |||
1_z, OH, OW); | |||
} | |||
//! compute one output channel | |||
template <typename io_ctype, typename compute_ctype> | |||
void MultithreadDirectConvCommon<io_ctype, compute_ctype>::do_conv_kern_stride( | |||
WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param, | |||
const ConvBiasImpl::NCBKernIndex& ncb_index, | |||
const kern_direct_conv_f32_stride& fun, | |||
const CpuNDRange& workspace_ids) { | |||
size_t IH = kern_param.isz[0]; | |||
size_t IW = kern_param.isz[1]; | |||
size_t OH = kern_param.osz[0]; | |||
size_t OW = kern_param.osz[1]; | |||
size_t FH = kern_param.filter_meta.spatial[0]; | |||
size_t FW = kern_param.filter_meta.spatial[1]; | |||
size_t IC = kern_param.filter_meta.icpg; | |||
size_t PH = kern_param.filter_meta.padding[0]; | |||
size_t PW = kern_param.filter_meta.padding[1]; | |||
size_t IH2, IW2, OW2; | |||
get_rectified_size(kern_param, IH, IW, OH, OW, FH, FW, PH, PW, IH2, IW2, OW2); | |||
size_t padding_group_size = IH2 * IW2 * IC; | |||
size_t GROUP = kern_param.filter_meta.group; | |||
bundle.set(kern_param.workspace_ptr); | |||
    //! Used to get the workspace offset | |||
size_t group_id = ncb_index.ndrange_id[0], | |||
batch_id = ncb_index.ndrange_id[1]; | |||
size_t channel_id = workspace_ids[2]; | |||
const io_ctype* sptr = kern_param.src<io_ctype>(batch_id, group_id); | |||
const io_ctype* fptr = | |||
kern_param.filter<io_ctype>(group_id) + channel_id * FH * FW * IC; | |||
const io_ctype* bias_ptr = | |||
kern_param.bias<io_ctype>(batch_id, group_id, channel_id); | |||
io_ctype* dptr = kern_param.dst<io_ctype>(batch_id, group_id, channel_id); | |||
size_t workspace_batch_id = workspace_ids[1]; | |||
size_t workspace_group_id = workspace_ids[0]; | |||
io_ctype* sptr_base; | |||
io_ctype* dptr_base; | |||
if (need_src_copy(kern_param)) { | |||
sptr_base = static_cast<io_ctype*>(bundle.get(0)) + | |||
workspace_group_id * padding_group_size + | |||
workspace_batch_id * GROUP * padding_group_size; | |||
} else { | |||
sptr_base = const_cast<io_ctype*>(sptr); | |||
} | |||
if (need_dst_copy(kern_param)) { | |||
dptr_base = static_cast<io_ctype*>(bundle.get(1)) + | |||
ncb_index.thread_id * OH * OW2; | |||
} else { | |||
dptr_base = dptr; | |||
} | |||
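    //! when OW is unaligned, run the kernel on the OW2-wide scratch buffer | |||
    //! and copy only the valid OW columns back into dst | |||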
if (need_dst_copy(kern_param)) { | |||
std::memset(dptr_base, 0, sizeof(io_ctype) * (OH * OW2)); | |||
fun(reinterpret_cast<const compute_ctype*>(sptr_base), | |||
reinterpret_cast<const compute_ctype*>(fptr), | |||
reinterpret_cast<compute_ctype*>(dptr_base), IH2, IW2, OH, OW2, IC); | |||
copy_plane_in_bytes(dptr, dptr_base, OH, OW * sizeof(io_ctype), | |||
OW * sizeof(io_ctype), OW2 * sizeof(io_ctype)); | |||
} else { | |||
std::memset(dptr_base, 0, sizeof(io_ctype) * (OH * OW)); | |||
fun(reinterpret_cast<const compute_ctype*>(sptr_base), | |||
reinterpret_cast<const compute_ctype*>(fptr), | |||
reinterpret_cast<compute_ctype*>(dptr_base), IH2, IW2, OH, OW, IC); | |||
} | |||
PostProcess<compute_ctype>::run(dptr, const_cast<io_ctype*>(bias_ptr), dptr, | |||
kern_param.bias_mode, kern_param.nonlineMode, | |||
kern_param.bias_type, kern_param.dst_type, 1_z, | |||
1_z, OH, OW); | |||
} | |||
template class megdnn::arm_common::MultithreadDirectConvCommon<float, float>; | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
template class megdnn::arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>; | |||
#endif | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,65 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/direct/multi_thread_common.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "src/arm_common/conv_bias/opr_impl.h" | |||
#include "src/fallback/matrix_mul/opr_impl.h" | |||
namespace megdnn { | |||
namespace arm_common { | |||
template <class io_ctype, class compute_ctype> | |||
class MultithreadDirectConvCommon { | |||
public: | |||
using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam; | |||
using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam; | |||
using NCBKernIndex = fallback::ConvBiasImpl::NCBKernIndex; | |||
using kern_direct_conv_f32 = | |||
std::function<void(const compute_ctype* src, | |||
const compute_ctype* filter, compute_ctype* dst, | |||
size_t, size_t, size_t, size_t, size_t, size_t)>; | |||
using kern_direct_conv_f32_stride = std::function<void( | |||
const compute_ctype* src, const compute_ctype* filter, | |||
compute_ctype* dst, size_t, size_t, size_t, size_t, size_t)>; | |||
static WorkspaceBundle get_bundle(const NCBKernSizeParam& param, | |||
bool m_large_group); | |||
static WorkspaceBundle get_bundle_stride(const NCBKernSizeParam& param, | |||
bool m_large_group); | |||
static void weight_flip_kern(WorkspaceBundle bundle, | |||
const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index, | |||
const CpuNDRange& workspace_ids); | |||
static void copy_padding_kern(WorkspaceBundle bundle, | |||
const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index, | |||
const CpuNDRange& workspace_ids); | |||
static void copy_padding_kern_stride(WorkspaceBundle bundle, | |||
const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index, | |||
const CpuNDRange& workspace_ids); | |||
static void do_conv_kern(WorkspaceBundle bundle, | |||
const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index, | |||
const kern_direct_conv_f32& fun, | |||
const CpuNDRange& workspace_ids); | |||
static void do_conv_kern_stride(WorkspaceBundle bundle, | |||
const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index, | |||
const kern_direct_conv_f32_stride& fun, | |||
const CpuNDRange& workspace_ids); | |||
}; | |||
} // namespace arm_common | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen
@@ -0,0 +1,561 @@
/** | |||
* \file dnn/src/arm_common/conv_bias/f16/algos.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/f16/algos.h" | |||
#include "src/arm_common/conv_bias/direct/multi_thread_common.h" | |||
#include "src/arm_common/conv_bias/f16/direct.h" | |||
#include "src/arm_common/conv_bias/f16/do_conv_stride1.h" | |||
#include "src/arm_common/conv_bias/f16/strategy.h" | |||
#include "src/arm_common/conv_bias/img2col_helper.h" | |||
#include "src/arm_common/conv_bias/postprocess_helper.h" | |||
#include "src/common/opr_delegate.h" | |||
#include "src/fallback/conv_bias/common.h" | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
#include "midout.h" | |||
MIDOUT_DECL(megdnn_arm_common_winograd_fp16) | |||
using namespace megdnn; | |||
using namespace arm_common; | |||
/* ======================= AlgoFP16WinogradF23 ======================== */ | |||
bool ConvBiasImpl::AlgoFP16WinogradF23::usable( | |||
fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy /*algo_selection_strategy*/) const { | |||
MEGDNN_MARK_USED_VAR(param); | |||
MEGDNN_MARK_USED_VAR(opr); | |||
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp16, 0, 0) { | |||
using Strategy = winograd::winograd_2x3_4x4_f16; | |||
Strategy strategy(param.src_type, param.filter_type, param.dst_type); | |||
auto&& matmul_param = | |||
megdnn::winograd::ConvBias<Strategy>( | |||
strategy, m_tile_size, param.nr_threads, param.osz[0], | |||
param.osz[1], param.filter_meta.ocpg) | |||
.get_matmul_kern_param(param); | |||
return m_matmul_algo->usable(matmul_param) && | |||
(opr->param().format == param::ConvBias::Format::NCHW || | |||
(opr->param().format == | |||
param::ConvBias::Format::NCHW_WINOGRAD && | |||
opr->param().output_block_size == 2 && | |||
param.winograd_matmul_format == | |||
param::MatrixMul::Format::DEFAULT)) && | |||
opr->param().mode == param::ConvBias::Mode::CROSS_CORRELATION && | |||
(param.filter_meta.spatial[0] == param.filter_meta.spatial[1] && | |||
param.filter_meta.spatial[0] == 3) && | |||
(param.filter_meta.stride[0] == param.filter_meta.stride[1] && | |||
param.filter_meta.stride[0] == 1) && | |||
(param.filter_meta.dilation[0] == | |||
param.filter_meta.dilation[1] && | |||
param.filter_meta.dilation[0] == 1) && | |||
param.compute_mode == param::ConvBias::ComputeMode::DEFAULT && | |||
param.src_type.enumv() == DTypeEnum::Float16 && | |||
param.filter_meta.icpg % 4 == 0 && | |||
param.filter_meta.ocpg % 4 == 0; | |||
} | |||
MIDOUT_END(); | |||
return false; | |||
} | |||
size_t ConvBiasImpl::AlgoFP16WinogradF23::get_workspace( | |||
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { | |||
MEGDNN_MARK_USED_VAR(param); | |||
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp16, 0, 1) { | |||
winograd::winograd_2x3_4x4_f16 strategy( | |||
param.src_type, param.filter_type, param.dst_type); | |||
return megdnn::winograd::ConvBias<winograd::winograd_2x3_4x4_f16>( | |||
strategy, m_tile_size, param.nr_threads, param.osz[0], | |||
param.osz[1], param.filter_meta.ocpg) | |||
.get_workspace_size(param, m_matmul_algo); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
SmallVector<ConvBiasImpl::NCBKern> | |||
ConvBiasImpl::AlgoFP16WinogradF23::dispatch_kerns( | |||
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { | |||
MEGDNN_MARK_USED_VAR(param); | |||
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp16, 0, 2) { | |||
winograd::winograd_2x3_4x4_f16 strategy( | |||
param.src_type, param.filter_type, param.dst_type); | |||
auto winograd_impl = | |||
megdnn::winograd::ConvBias<winograd::winograd_2x3_4x4_f16>( | |||
strategy, m_tile_size, param.nr_threads, param.osz[0], | |||
param.osz[1], param.filter_meta.ocpg); | |||
return winograd_impl.get_kerns(param, m_matmul_algo); | |||
} | |||
MIDOUT_END(); | |||
return {}; | |||
} | |||
/* ======================= AlgoFP16WinogradF45 ======================== */ | |||
bool ConvBiasImpl::AlgoFP16WinogradF45::usable( | |||
fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy /*algo_selection_strategy*/) const { | |||
MEGDNN_MARK_USED_VAR(param); | |||
MEGDNN_MARK_USED_VAR(opr); | |||
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp16, 1, 0) { | |||
using Strategy = winograd::winograd_4x5_1x1_f16; | |||
Strategy strategy(param.src_type, param.filter_type, param.dst_type); | |||
auto&& matmul_param = | |||
megdnn::winograd::ConvBias<Strategy>( | |||
strategy, m_tile_size, param.nr_threads, param.osz[0], | |||
param.osz[1], param.filter_meta.ocpg) | |||
.get_matmul_kern_param(param); | |||
return m_matmul_algo->usable(matmul_param) && | |||
(opr->param().format == param::ConvBias::Format::NCHW || | |||
(opr->param().format == | |||
param::ConvBias::Format::NCHW_WINOGRAD && | |||
opr->param().output_block_size == 4 && | |||
param.winograd_matmul_format == | |||
param::MatrixMul::Format::DEFAULT)) && | |||
opr->param().mode == param::ConvBias::Mode::CROSS_CORRELATION && | |||
(param.filter_meta.spatial[0] == param.filter_meta.spatial[1] && | |||
param.filter_meta.spatial[0] == 5) && | |||
(param.filter_meta.stride[0] == param.filter_meta.stride[1] && | |||
param.filter_meta.stride[0] == 1) && | |||
(param.filter_meta.dilation[0] == | |||
param.filter_meta.dilation[1] && | |||
param.filter_meta.dilation[0] == 1) && | |||
param.compute_mode == param::ConvBias::ComputeMode::DEFAULT && | |||
param.src_type.enumv() == DTypeEnum::Float16; | |||
} | |||
MIDOUT_END(); | |||
return false; | |||
} | |||
size_t ConvBiasImpl::AlgoFP16WinogradF45::get_workspace( | |||
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { | |||
MEGDNN_MARK_USED_VAR(param); | |||
winograd::winograd_4x5_1x1_f16 strategy(param.src_type, param.filter_type, | |||
param.dst_type); | |||
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp16, 1, 1) { | |||
return megdnn::winograd::ConvBias<winograd::winograd_4x5_1x1_f16>( | |||
strategy, m_tile_size, param.nr_threads, param.osz[0], | |||
param.osz[1], param.filter_meta.ocpg) | |||
.get_workspace_size(param, m_matmul_algo); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
SmallVector<ConvBiasImpl::NCBKern> | |||
ConvBiasImpl::AlgoFP16WinogradF45::dispatch_kerns( | |||
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { | |||
MEGDNN_MARK_USED_VAR(param); | |||
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp16, 1, 2) { | |||
winograd::winograd_4x5_1x1_f16 strategy( | |||
param.src_type, param.filter_type, param.dst_type); | |||
auto winograd_impl = | |||
megdnn::winograd::ConvBias<winograd::winograd_4x5_1x1_f16>( | |||
strategy, m_tile_size, param.nr_threads, param.osz[0], | |||
param.osz[1], param.filter_meta.ocpg); | |||
return winograd_impl.get_kerns(param, m_matmul_algo); | |||
} | |||
MIDOUT_END(); | |||
return {}; | |||
} | |||
/* ======================= AlgoFP16WinogradF63 ======================== */ | |||
bool ConvBiasImpl::AlgoFP16WinogradF63::usable( | |||
fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy /*algo_selection_strategy*/) const { | |||
MEGDNN_MARK_USED_VAR(param); | |||
MEGDNN_MARK_USED_VAR(opr); | |||
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp16, 2, 0) { | |||
using Strategy = winograd::winograd_6x3_1x1_f16; | |||
Strategy strategy(param.src_type, param.filter_type, param.dst_type); | |||
auto&& matmul_param = | |||
megdnn::winograd::ConvBias<Strategy>( | |||
strategy, m_tile_size, param.nr_threads, param.osz[0], | |||
param.osz[1], param.filter_meta.ocpg) | |||
.get_matmul_kern_param(param); | |||
return m_matmul_algo->usable(matmul_param) && | |||
(opr->param().format == param::ConvBias::Format::NCHW || | |||
(opr->param().format == | |||
param::ConvBias::Format::NCHW_WINOGRAD && | |||
opr->param().output_block_size == 6 && | |||
param.winograd_matmul_format == | |||
param::MatrixMul::Format::DEFAULT)) && | |||
opr->param().mode == param::ConvBias::Mode::CROSS_CORRELATION && | |||
(param.filter_meta.spatial[0] == param.filter_meta.spatial[1] && | |||
param.filter_meta.spatial[0] == 3) && | |||
(param.filter_meta.stride[0] == param.filter_meta.stride[1] && | |||
param.filter_meta.stride[0] == 1) && | |||
(param.filter_meta.dilation[0] == | |||
param.filter_meta.dilation[1] && | |||
param.filter_meta.dilation[0] == 1) && | |||
param.compute_mode == param::ConvBias::ComputeMode::DEFAULT && | |||
param.src_type.enumv() == DTypeEnum::Float16; | |||
} | |||
MIDOUT_END(); | |||
return false; | |||
} | |||
size_t ConvBiasImpl::AlgoFP16WinogradF63::get_workspace( | |||
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { | |||
MEGDNN_MARK_USED_VAR(param); | |||
winograd::winograd_6x3_1x1_f16 strategy(param.src_type, param.filter_type, | |||
param.dst_type); | |||
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp16, 2, 1) { | |||
return megdnn::winograd::ConvBias<winograd::winograd_6x3_1x1_f16>( | |||
strategy, m_tile_size, param.nr_threads, param.osz[0], | |||
param.osz[1], param.filter_meta.ocpg) | |||
.get_workspace_size(param, m_matmul_algo); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
SmallVector<ConvBiasImpl::NCBKern> | |||
ConvBiasImpl::AlgoFP16WinogradF63::dispatch_kerns( | |||
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { | |||
MEGDNN_MARK_USED_VAR(param); | |||
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp16, 2, 2) { | |||
winograd::winograd_6x3_1x1_f16 strategy( | |||
param.src_type, param.filter_type, param.dst_type); | |||
auto winograd_impl = | |||
megdnn::winograd::ConvBias<winograd::winograd_6x3_1x1_f16>( | |||
strategy, m_tile_size, param.nr_threads, param.osz[0], | |||
param.osz[1], param.filter_meta.ocpg); | |||
return winograd_impl.get_kerns(param, m_matmul_algo); | |||
} | |||
MIDOUT_END(); | |||
return {}; | |||
} | |||
/* ======================= AlgoFP16WinogradF23_8x8 ======================== */ | |||
bool ConvBiasImpl::AlgoFP16WinogradF23_8x8::usable( | |||
fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy /*algo_selection_strategy*/) const { | |||
MEGDNN_MARK_USED_VAR(param); | |||
MEGDNN_MARK_USED_VAR(opr); | |||
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp16, 3, 0) { | |||
if (param.filter_meta.icpg % 8 != 0 || param.filter_meta.ocpg % 8 != 0) | |||
return false; | |||
using Strategy = winograd::winograd_2x3_8x8_f16; | |||
Strategy strategy(param.src_type, param.filter_type, param.dst_type); | |||
auto&& matmul_param = | |||
megdnn::winograd::ConvBias<Strategy, | |||
param::MatrixMul::Format::MK8>( | |||
strategy, m_tile_size, param.nr_threads, param.osz[0], | |||
param.osz[1], param.filter_meta.ocpg) | |||
.get_matmul_kern_param(param); | |||
return m_matmul_algo->usable(matmul_param) && | |||
(opr->param().format == param::ConvBias::Format::NCHW || | |||
(opr->param().format == | |||
param::ConvBias::Format::NCHW_WINOGRAD && | |||
opr->param().output_block_size == 2 && | |||
param.winograd_matmul_format == | |||
param::MatrixMul::Format::MK8)) && | |||
opr->param().mode == param::ConvBias::Mode::CROSS_CORRELATION && | |||
(param.filter_meta.spatial[0] == param.filter_meta.spatial[1] && | |||
param.filter_meta.spatial[0] == 3) && | |||
(param.filter_meta.stride[0] == param.filter_meta.stride[1] && | |||
param.filter_meta.stride[0] == 1) && | |||
(param.filter_meta.dilation[0] == | |||
param.filter_meta.dilation[1] && | |||
param.filter_meta.dilation[0] == 1) && | |||
param.compute_mode == param::ConvBias::ComputeMode::DEFAULT && | |||
param.src_type.enumv() == DTypeEnum::Float16; | |||
} | |||
MIDOUT_END(); | |||
return false; | |||
} | |||
size_t ConvBiasImpl::AlgoFP16WinogradF23_8x8::get_workspace( | |||
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { | |||
MEGDNN_MARK_USED_VAR(param); | |||
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp16, 3, 1) { | |||
winograd::winograd_2x3_8x8_f16 strategy( | |||
param.src_type, param.filter_type, param.dst_type); | |||
return megdnn::winograd::ConvBias<winograd::winograd_2x3_8x8_f16, | |||
param::MatrixMul::Format::MK8>( | |||
strategy, m_tile_size, param.nr_threads, param.osz[0], | |||
param.osz[1], param.filter_meta.ocpg) | |||
.get_workspace_size(param, m_matmul_algo); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
SmallVector<ConvBiasImpl::NCBKern> | |||
ConvBiasImpl::AlgoFP16WinogradF23_8x8::dispatch_kerns( | |||
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { | |||
MEGDNN_MARK_USED_VAR(param); | |||
    MIDOUT_BEGIN(megdnn_arm_common_winograd_fp16, 3, 2) {
winograd::winograd_2x3_8x8_f16 strategy( | |||
param.src_type, param.filter_type, param.dst_type); | |||
auto winograd_impl = | |||
megdnn::winograd::ConvBias<winograd::winograd_2x3_8x8_f16, | |||
param::MatrixMul::Format::MK8>( | |||
strategy, m_tile_size, param.nr_threads, param.osz[0], | |||
param.osz[1], param.filter_meta.ocpg); | |||
return winograd_impl.get_kerns(param, m_matmul_algo); | |||
} | |||
MIDOUT_END(); | |||
return {}; | |||
} | |||
/*========================from Convolution=============================*/ | |||
MIDOUT_DECL(megdnn_arm_common_conv_bias_fp16_kimpl) | |||
bool ConvBiasImpl::AlgoF16Direct::usable( | |||
fallback::ConvBiasImpl*, const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_kimpl, 0, 0) { | |||
auto&& fm = param.filter_meta; | |||
auto FH = fm.spatial[0]; | |||
auto SH = fm.stride[0], SW = fm.stride[1]; | |||
        // the conditions ``param.isz[0]*param.isz[1] >= 8'' and
        // ``param.osz[0]*param.osz[1] >= 8'' come from the fact that the
        // kernel may read up to 8 fp16 values past the end of the memory
        // chunk.
        bool available = fm.format == param::ConvBias::Format::NCHW &&
param.src_type.enumv() == DTypeEnum::Float16 && | |||
param.filter_type.enumv() == DTypeEnum::Float16 && | |||
param.dst_type.enumv() == DTypeEnum::Float16 && | |||
fm.spatial_ndim == 2 && fm.dilation[0] == 1 && | |||
fm.dilation[1] == 1 && | |||
param.isz[0] * param.isz[1] >= 8 && | |||
param.osz[0] * param.osz[1] >= 8 && FH <= 7 && | |||
SH == 1 && SW == 1; | |||
if (algo_selection_strategy == AlgoSelectionStrategy::HEURISTIC) { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
            available &= (large_group == m_large_group);
} | |||
        return available;
} | |||
MIDOUT_END(); | |||
return false; | |||
} | |||
size_t ConvBiasImpl::AlgoF16Direct::get_workspace( | |||
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_kimpl, 0, 1) { | |||
auto wbundle = | |||
MultithreadDirectConvCommon<dt_float16, __fp16>::get_bundle( | |||
param, m_large_group); | |||
return wbundle.total_size_in_bytes(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF16Direct::get_kimpls( | |||
const NCBKernSizeParam& param) const { | |||
auto fm = param.filter_meta; | |||
size_t N = param.n; | |||
size_t IC = param.filter_meta.icpg; | |||
size_t OC = param.filter_meta.ocpg; | |||
size_t group = fm.group; | |||
WorkspaceBundle wbundle = | |||
MultithreadDirectConvCommon<dt_float16, __fp16>::get_bundle( | |||
param, m_large_group); | |||
SmallVector<NCBKern> ret_kerns; | |||
    //! When group >= nr_threads, treat it as a large group: each thread
    //! processes one group for better performance
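    //! e.g. group == 32 and nr_threads == 4 gives the NDRange {32, N, 1}:
    //! each (group, batch) task runs the whole flip/pad/conv pipeline for
    //! its group inside a single thread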
if (m_large_group) { | |||
        //! channel-wise conv and big groups
auto exec_one_group = [wbundle](const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index) { | |||
auto fm = kern_param.filter_meta; | |||
size_t IC = fm.icpg; | |||
size_t OC = fm.ocpg; | |||
WorkspaceBundle bundle = wbundle; | |||
if (fm.should_flip) { | |||
for (size_t oc = 0; oc < OC; oc++) { | |||
MultithreadDirectConvCommon<dt_float16, __fp16>:: | |||
weight_flip_kern(bundle, kern_param, ncb_index, | |||
{ncb_index.thread_id, 0, oc}); | |||
} | |||
} | |||
for (size_t ic = 0; ic < IC; ic++) { | |||
MultithreadDirectConvCommon<dt_float16, __fp16>:: | |||
copy_padding_kern(bundle, kern_param, ncb_index, | |||
{ncb_index.thread_id, 0, ic}); | |||
} | |||
for (size_t oc = 0; oc < OC; oc++) { | |||
MultithreadDirectConvCommon<dt_float16, __fp16>::do_conv_kern( | |||
bundle, kern_param, ncb_index, | |||
fp16::conv_bias::kern_direct_f16, | |||
{ncb_index.thread_id, 0, oc}); | |||
} | |||
}; | |||
ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); | |||
} else { | |||
WorkspaceBundle bundle = wbundle; | |||
if (fm.should_flip) { | |||
auto weight_flip = [bundle](const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index) { | |||
MultithreadDirectConvCommon<dt_float16, __fp16>:: | |||
weight_flip_kern(bundle, kern_param, ncb_index, | |||
ncb_index.ndrange_id); | |||
}; | |||
ret_kerns.push_back({weight_flip, {group, 1_z, OC}}); | |||
} | |||
auto copy_padding = [bundle](const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index) { | |||
MultithreadDirectConvCommon<dt_float16, __fp16>::copy_padding_kern( | |||
bundle, kern_param, ncb_index, ncb_index.ndrange_id); | |||
}; | |||
ret_kerns.push_back({copy_padding, {group, N, IC}}); | |||
auto do_conv = [bundle](const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index) { | |||
MultithreadDirectConvCommon<dt_float16, __fp16>::do_conv_kern( | |||
bundle, kern_param, ncb_index, | |||
fp16::conv_bias::kern_direct_f16, ncb_index.ndrange_id); | |||
}; | |||
ret_kerns.push_back({do_conv, {group, N, OC}}); | |||
} | |||
return ret_kerns; | |||
} | |||
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF16Direct::dispatch_kerns( | |||
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_kimpl, 0, 1) { | |||
return get_kimpls(param); | |||
} | |||
MIDOUT_END(); | |||
return {}; | |||
} | |||
/* ===================== stride-1 algo ===================== */ | |||
bool ConvBiasImpl::AlgoF16DirectStride1::usable( | |||
fallback::ConvBiasImpl*, const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_kimpl, 1, 0) { | |||
auto&& fm = param.filter_meta; | |||
auto FH = fm.spatial[0]; | |||
        bool available =
param.filter_meta.format == param::ConvBias::Format::NCHW && | |||
param.src_type.enumv() == DTypeEnum::Float16 && | |||
param.filter_type.enumv() == DTypeEnum::Float16 && | |||
param.dst_type.enumv() == DTypeEnum::Float16 && | |||
!fm.should_flip && fm.spatial_ndim == 2 && | |||
fm.dilation[0] == 1 && fm.dilation[1] == 1 && | |||
fm.stride[0] == 1 && fm.stride[1] == 1 && FH == fm.spatial[1] && | |||
(FH == 2 || FH == 3 || FH == 5); | |||
if (algo_selection_strategy == | |||
ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
            available &= (large_group == m_large_group);
} | |||
        return available;
} | |||
MIDOUT_END(); | |||
return false; | |||
} | |||
SmallVector<ConvBiasImpl::NCBKern> | |||
ConvBiasImpl::AlgoF16DirectStride1::get_kimpls( | |||
const NCBKernSizeParam& param) const { | |||
auto fm = param.filter_meta; | |||
auto FH = fm.spatial[0]; | |||
size_t N = param.n; | |||
size_t IC = param.filter_meta.icpg; | |||
size_t OC = param.filter_meta.ocpg; | |||
size_t group = fm.group; | |||
using Func = std::function<void(const __fp16*, const __fp16*, __fp16*, | |||
size_t, size_t, size_t, size_t, size_t)>; | |||
Func conv_kern_function = nullptr; | |||
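    //! pick the direct kernel by filter height; usable() only admits
    //! FH in {2, 3, 5}, so conv_kern_function cannot remain nullptr here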
#define SWITCH_KERN() \ | |||
switch (FH) { \ | |||
case 2: \ | |||
conv_kern_function = fp16::conv_stride1::do_conv_2x2_stride1; \ | |||
break; \ | |||
case 3: \ | |||
conv_kern_function = fp16::conv_stride1::do_conv_3x3_stride1; \ | |||
break; \ | |||
case 5: \ | |||
conv_kern_function = fp16::conv_stride1::do_conv_5x5_stride1; \ | |||
break; \ | |||
} | |||
    SWITCH_KERN();
#undef SWITCH_KERN
WorkspaceBundle wbundle = | |||
MultithreadDirectConvCommon<dt_float16, __fp16>::get_bundle_stride( | |||
param, m_large_group); | |||
SmallVector<NCBKern> ret_kerns; | |||
    //! When group >= nr_threads, treat it as a large group: each thread
    //! processes one group for better performance
if (m_large_group) { | |||
        //! channel-wise conv and big groups
auto exec_one_group = [wbundle, conv_kern_function]( | |||
const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index) { | |||
auto fm = kern_param.filter_meta; | |||
size_t IC = fm.icpg; | |||
size_t OC = fm.ocpg; | |||
WorkspaceBundle bundle = wbundle; | |||
for (size_t ic = 0; ic < IC; ic++) { | |||
MultithreadDirectConvCommon<dt_float16, __fp16>:: | |||
copy_padding_kern_stride(bundle, kern_param, ncb_index, | |||
{ncb_index.thread_id, 0, ic}); | |||
} | |||
for (size_t oc = 0; oc < OC; oc++) { | |||
MultithreadDirectConvCommon<dt_float16, __fp16>:: | |||
do_conv_kern_stride(bundle, kern_param, ncb_index, | |||
conv_kern_function, | |||
{ncb_index.thread_id, 0, oc}); | |||
} | |||
}; | |||
ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); | |||
} else { | |||
WorkspaceBundle bundle = wbundle; | |||
auto copy_padding = [bundle](const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index) { | |||
MultithreadDirectConvCommon<dt_float16, __fp16>:: | |||
copy_padding_kern_stride(bundle, kern_param, ncb_index, | |||
ncb_index.ndrange_id); | |||
}; | |||
ret_kerns.push_back({copy_padding, {group, N, IC}}); | |||
auto do_conv = [bundle, conv_kern_function]( | |||
const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index) { | |||
MultithreadDirectConvCommon<dt_float16, __fp16>:: | |||
do_conv_kern_stride(bundle, kern_param, ncb_index, | |||
conv_kern_function, | |||
ncb_index.ndrange_id); | |||
}; | |||
ret_kerns.push_back({do_conv, {group, N, OC}}); | |||
} | |||
return ret_kerns; | |||
} | |||
size_t ConvBiasImpl::AlgoF16DirectStride1::get_workspace( | |||
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_kimpl, 1, 1) { | |||
auto bundle = MultithreadDirectConvCommon< | |||
dt_float16, __fp16>::get_bundle_stride(param, m_large_group); | |||
return bundle.total_size_in_bytes(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
SmallVector<ConvBiasImpl::NCBKern> | |||
ConvBiasImpl::AlgoF16DirectStride1::dispatch_kerns( | |||
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_kimpl, 1, 2) { | |||
return get_kimpls(param); | |||
} | |||
MIDOUT_END(); | |||
return {}; | |||
} | |||
#endif | |||
// vim: syntax=cpp.doxygen
@@ -0,0 +1,185 @@
/** | |||
* \file dnn/src/arm_common/conv_bias/f16/algos.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "src/arm_common/conv_bias/opr_impl.h" | |||
#include "src/fallback/matrix_mul/opr_impl.h" | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
namespace megdnn { | |||
namespace arm_common { | |||
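//! fp16 winograd algorithms; the WinogradParam passed to algo_name appears to
//! encode {channel pack size, output block size, tile size}, e.g. {1, 2, tile}
//! for plain F(2x2, 3x3) and {8, 2, tile} for the MK8-packed variant below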
class ConvBiasImpl::AlgoFP16WinogradF23 final : public AlgoBase { | |||
public: | |||
AlgoFP16WinogradF23(fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
uint32_t tile_size) | |||
: m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {} | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { | |||
if (m_name.empty()) { | |||
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | |||
m_matmul_algo->name(), {1, 2, m_tile_size}); | |||
} | |||
return m_name.c_str(); | |||
} | |||
bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||
size_t get_workspace(fallback::ConvBiasImpl*, | |||
const NCBKernSizeParam& param) const override; | |||
virtual SmallVector<NCBKern> dispatch_kerns( | |||
fallback::ConvBiasImpl* opr, | |||
const NCBKernSizeParam& param) const override; | |||
static std::vector<fallback::MatrixMulImpl::Algorithm*> | |||
    get_available_matmul_algos(const NCBKernSizeParam& param);
private: | |||
fallback::MatrixMulImpl::AlgoBase* m_matmul_algo; | |||
mutable std::string m_name; | |||
uint32_t m_tile_size; | |||
}; | |||
class ConvBiasImpl::AlgoFP16WinogradF45 final : public AlgoBase { | |||
public: | |||
AlgoFP16WinogradF45(fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
uint32_t tile_size) | |||
: m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {} | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { | |||
if (m_name.empty()) { | |||
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | |||
m_matmul_algo->name(), {1, 4, m_tile_size}); | |||
} | |||
return m_name.c_str(); | |||
} | |||
bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||
size_t get_workspace(fallback::ConvBiasImpl*, | |||
const NCBKernSizeParam& param) const override; | |||
virtual SmallVector<NCBKern> dispatch_kerns( | |||
fallback::ConvBiasImpl* opr, | |||
const NCBKernSizeParam& param) const override; | |||
static std::vector<fallback::MatrixMulImpl::Algorithm*> | |||
    get_available_matmul_algos(const NCBKernSizeParam& param);
private: | |||
fallback::MatrixMulImpl::AlgoBase* m_matmul_algo; | |||
mutable std::string m_name; | |||
uint32_t m_tile_size; | |||
}; | |||
class ConvBiasImpl::AlgoFP16WinogradF63 final : public AlgoBase { | |||
public: | |||
AlgoFP16WinogradF63(fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
uint32_t tile_size) | |||
: m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {} | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { | |||
if (m_name.empty()) { | |||
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | |||
m_matmul_algo->name(), {1, 6, m_tile_size}); | |||
} | |||
return m_name.c_str(); | |||
} | |||
bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||
size_t get_workspace(fallback::ConvBiasImpl*, | |||
const NCBKernSizeParam& param) const override; | |||
virtual SmallVector<NCBKern> dispatch_kerns( | |||
fallback::ConvBiasImpl* opr, | |||
const NCBKernSizeParam& param) const override; | |||
static std::vector<fallback::MatrixMulImpl::Algorithm*> | |||
    get_available_matmul_algos(const NCBKernSizeParam& param);
private: | |||
fallback::MatrixMulImpl::AlgoBase* m_matmul_algo; | |||
mutable std::string m_name; | |||
uint32_t m_tile_size; | |||
}; | |||
class ConvBiasImpl::AlgoFP16WinogradF23_8x8 final : public AlgoBase { | |||
public: | |||
AlgoFP16WinogradF23_8x8(fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
uint32_t tile_size) | |||
: m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {} | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { | |||
if (m_name.empty()) { | |||
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | |||
m_matmul_algo->name(), {8, 2, m_tile_size}); | |||
} | |||
return m_name.c_str(); | |||
} | |||
bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||
size_t get_workspace(fallback::ConvBiasImpl*, | |||
const NCBKernSizeParam& param) const override; | |||
virtual SmallVector<NCBKern> dispatch_kerns( | |||
fallback::ConvBiasImpl* opr, | |||
const NCBKernSizeParam& param) const override; | |||
private: | |||
fallback::MatrixMulImpl::AlgoBase* m_matmul_algo; | |||
mutable std::string m_name; | |||
uint32_t m_tile_size; | |||
}; | |||
class ConvBiasImpl::AlgoF16Direct final : public AlgoBase { | |||
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | |||
bool m_large_group; | |||
public: | |||
AlgoF16Direct(bool is_large_group) : m_large_group{is_large_group} {} | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { | |||
return m_large_group ? "F16DIRECT_LARGE_GROUP" | |||
: "F16DIRECT_SMALL_GROUP"; | |||
} | |||
bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||
size_t get_workspace(fallback::ConvBiasImpl* opr, | |||
const NCBKernSizeParam& param) const override; | |||
virtual SmallVector<NCBKern> dispatch_kerns( | |||
fallback::ConvBiasImpl* opr, | |||
const NCBKernSizeParam& param) const override; | |||
}; | |||
class ConvBiasImpl::AlgoF16DirectStride1 final : public AlgoBase { | |||
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | |||
bool m_large_group; | |||
public: | |||
AlgoF16DirectStride1(bool is_large_group) : m_large_group{is_large_group} {} | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { | |||
return m_large_group ? "F16STRD1_LARGE_GROUP" : "F16STRD1_SMALL_GROUP"; | |||
} | |||
bool usable(fallback::ConvBiasImpl*, const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||
size_t get_workspace(fallback::ConvBiasImpl*, | |||
const NCBKernSizeParam& param) const override; | |||
virtual SmallVector<NCBKern> dispatch_kerns( | |||
fallback::ConvBiasImpl* opr, | |||
const NCBKernSizeParam& param) const override; | |||
}; | |||
} // namespace arm_common | |||
} // namespace megdnn | |||
#endif | |||
// vim: syntax=cpp.doxygen
@@ -0,0 +1,799 @@
/** | |||
* \file dnn/src/arm_common/conv_bias/f16/direct.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "./direct.h" | |||
#include "include/megdnn/oprs.h" | |||
#include "midout.h" | |||
#include "src/common/unroll_macro.h" | |||
#include "src/common/utils.h" | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
#include <cstring> | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/arm_common/simd_macro/neon_helper.h" | |||
MIDOUT_DECL(megdnn_arm_conv_f16) | |||
using namespace megdnn; | |||
using namespace arm_common; | |||
using namespace fp16; | |||
using namespace conv_bias; | |||
namespace { | |||
#define BLOCK_H 8 | |||
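//! LOAD_RESULT_VAL / STORE_RESULT_VAL move an up-to-8x8 tile of partial
//! outputs between dst and the out0..out7 registers; the width < 8 branch
//! handles the right-edge tail lane by lane so no lane beyond the valid OW
//! columns is touched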
#define LOAD_RESULT_VAL \ | |||
if (width < 8) { \ | |||
auto load_less_8 = [](__fp16* dst, float16x8_t& data) { \ | |||
if (width == 1u) { \ | |||
data = vld1q_lane_f16(dst, data, 0); \ | |||
} else if (width == 2u) { \ | |||
data = vld1q_lane_f16(dst + 0, data, 0); \ | |||
data = vld1q_lane_f16(dst + 1, data, 1); \ | |||
} else if (width == 3u) { \ | |||
data = vld1q_lane_f16(dst + 0, data, 0); \ | |||
data = vld1q_lane_f16(dst + 1, data, 1); \ | |||
data = vld1q_lane_f16(dst + 2, data, 2); \ | |||
} else if (width == 4u) { \ | |||
                data = vreinterpretq_f16_u64(vld1q_lane_u64(          \
                        reinterpret_cast<const uint64_t*>(dst),       \
                        vreinterpretq_u64_f16(data), 0));             \
} else if (width == 5u) { \ | |||
                data = vreinterpretq_f16_u64(vld1q_lane_u64(          \
                        reinterpret_cast<const uint64_t*>(dst),       \
                        vreinterpretq_u64_f16(data), 0));             \
data = vld1q_lane_f16(dst + 4, data, 4); \ | |||
} else if (width == 6u) { \ | |||
                data = vreinterpretq_f16_u64(vld1q_lane_u64(          \
                        reinterpret_cast<const uint64_t*>(dst),       \
                        vreinterpretq_u64_f16(data), 0));             \
data = vld1q_lane_f16(dst + 4, data, 4); \ | |||
data = vld1q_lane_f16(dst + 5, data, 5); \ | |||
} else if (width == 7u) { \ | |||
                data = vreinterpretq_f16_u64(vld1q_lane_u64(          \
                        reinterpret_cast<const uint64_t*>(dst),       \
                        vreinterpretq_u64_f16(data), 0));             \
data = vld1q_lane_f16(dst + 4, data, 4); \ | |||
data = vld1q_lane_f16(dst + 5, data, 5); \ | |||
data = vld1q_lane_f16(dst + 6, data, 6); \ | |||
} \ | |||
}; \ | |||
if (height >= 1) \ | |||
load_less_8(dst + 0 * OW, out0); \ | |||
if (height >= 2) \ | |||
load_less_8(dst + 1 * OW, out1); \ | |||
if (height >= 3) \ | |||
load_less_8(dst + 2 * OW, out2); \ | |||
if (height >= 4) \ | |||
load_less_8(dst + 3 * OW, out3); \ | |||
if (height >= 5) \ | |||
load_less_8(dst + 4 * OW, out4); \ | |||
if (height >= 6) \ | |||
load_less_8(dst + 5 * OW, out5); \ | |||
if (height >= 7) \ | |||
load_less_8(dst + 6 * OW, out6); \ | |||
if (height >= 8) \ | |||
load_less_8(dst + 7 * OW, out7); \ | |||
} else { \ | |||
if (height >= 1) \ | |||
out0 = vld1q_f16(dst + 0 * OW); \ | |||
if (height >= 2) \ | |||
out1 = vld1q_f16(dst + 1 * OW); \ | |||
if (height >= 3) \ | |||
out2 = vld1q_f16(dst + 2 * OW); \ | |||
if (height >= 4) \ | |||
out3 = vld1q_f16(dst + 3 * OW); \ | |||
if (height >= 5) \ | |||
out4 = vld1q_f16(dst + 4 * OW); \ | |||
if (height >= 6) \ | |||
out5 = vld1q_f16(dst + 5 * OW); \ | |||
if (height >= 7) \ | |||
out6 = vld1q_f16(dst + 6 * OW); \ | |||
if (height >= 8) \ | |||
out7 = vld1q_f16(dst + 7 * OW); \ | |||
} | |||
#define STORE_RESULT_VAL \ | |||
if (width < 8) { \ | |||
auto store_less_8 = [](__fp16* dst, float16x8_t& data) { \ | |||
if (width == 1u) { \ | |||
vst1q_lane_f16(dst, data, 0); \ | |||
} else if (width == 2u) { \ | |||
vst1q_lane_f16(dst + 0, data, 0); \ | |||
vst1q_lane_f16(dst + 1, data, 1); \ | |||
} else if (width == 3u) { \ | |||
vst1q_lane_f16(dst + 0, data, 0); \ | |||
vst1q_lane_f16(dst + 1, data, 1); \ | |||
vst1q_lane_f16(dst + 2, data, 2); \ | |||
} else if (width == 4u) { \ | |||
vst1_f16(dst, vget_low_f16(data)); \ | |||
} else if (width == 5u) { \ | |||
vst1_f16(dst, vget_low_f16(data)); \ | |||
vst1q_lane_f16(dst + 4, data, 4); \ | |||
} else if (width == 6u) { \ | |||
vst1_f16(dst, vget_low_f16(data)); \ | |||
vst1q_lane_f16(dst + 4, data, 4); \ | |||
vst1q_lane_f16(dst + 5, data, 5); \ | |||
} else if (width == 7u) { \ | |||
vst1_f16(dst, vget_low_f16(data)); \ | |||
vst1q_lane_f16(dst + 4, data, 4); \ | |||
vst1q_lane_f16(dst + 5, data, 5); \ | |||
vst1q_lane_f16(dst + 6, data, 6); \ | |||
} \ | |||
}; \ | |||
if (height >= 1) \ | |||
store_less_8(dst + 0 * OW, out0); \ | |||
if (height >= 2) \ | |||
store_less_8(dst + 1 * OW, out1); \ | |||
if (height >= 3) \ | |||
store_less_8(dst + 2 * OW, out2); \ | |||
if (height >= 4) \ | |||
store_less_8(dst + 3 * OW, out3); \ | |||
if (height >= 5) \ | |||
store_less_8(dst + 4 * OW, out4); \ | |||
if (height >= 6) \ | |||
store_less_8(dst + 5 * OW, out5); \ | |||
if (height >= 7) \ | |||
store_less_8(dst + 6 * OW, out6); \ | |||
if (height >= 8) \ | |||
store_less_8(dst + 7 * OW, out7); \ | |||
} else { \ | |||
if (height >= 1) \ | |||
vst1q_f16(dst + 0 * OW, out0); \ | |||
if (height >= 2) \ | |||
vst1q_f16(dst + 1 * OW, out1); \ | |||
if (height >= 3) \ | |||
vst1q_f16(dst + 2 * OW, out2); \ | |||
if (height >= 4) \ | |||
vst1q_f16(dst + 3 * OW, out3); \ | |||
if (height >= 5) \ | |||
vst1q_f16(dst + 4 * OW, out4); \ | |||
if (height >= 6) \ | |||
vst1q_f16(dst + 5 * OW, out5); \ | |||
if (height >= 7) \ | |||
vst1q_f16(dst + 6 * OW, out6); \ | |||
if (height >= 8) \ | |||
vst1q_f16(dst + 7 * OW, out7); \ | |||
} | |||
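//! do_pixel_proxy<FH, height, width> accumulates one input channel's full
//! FH x FW contribution into a height x width output tile: the tile is first
//! loaded from dst (so successive input channels accumulate across calls),
//! then for each fw the FH filter taps are broadcast and multiply-added over
//! the corresponding input rows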
template <int FH, int height, int width> | |||
struct do_pixel_proxy { | |||
static void exec(const __fp16* src, const __fp16* filter, __fp16* dst, | |||
const int IH, const int IW, const int OH, const int OW, | |||
const int FW, const int oh, const int ow); | |||
}; | |||
template <int height, int width> | |||
struct do_pixel_proxy<1, height, width> { | |||
static void exec(const __fp16* src, const __fp16* filter, __fp16* dst, | |||
const int IH, const int IW, const int OH, const int OW, | |||
const int FW, const int oh, const int ow) { | |||
MEGDNN_MARK_USED_VAR(IH); | |||
MEGDNN_MARK_USED_VAR(OH); | |||
const int ih = oh, iw = ow; | |||
#define cb(i) float16x8_t out##i{0}; | |||
UNROLL_CALL_NOWRAPPER(BLOCK_H, cb); | |||
#undef cb | |||
float16x8_t kr0, inp; | |||
src += ih * IW + iw; | |||
dst += oh * OW + ow; | |||
LOAD_RESULT_VAL; | |||
for (int fw = 0; fw < FW; ++fw) { | |||
const __fp16* src_dd = src + fw; | |||
kr0 = vdupq_n_f16(filter[0 * FW + fw]); | |||
#define cb(i) \ | |||
if (height > i) { \ | |||
inp = vld1q_f16(src_dd + i * IW); \ | |||
out##i = vmlaq_f16(out##i, inp, kr0); \ | |||
} | |||
UNROLL_CALL_NOWRAPPER(BLOCK_H, cb); | |||
#undef cb | |||
} | |||
STORE_RESULT_VAL; | |||
} | |||
}; | |||
template <int height, int width> | |||
struct do_pixel_proxy<2, height, width> { | |||
static void exec(const __fp16* src, const __fp16* filter, __fp16* dst, | |||
const int IH, const int IW, const int OH, const int OW, | |||
const int FW, const int oh, const int ow) { | |||
MEGDNN_MARK_USED_VAR(IH); | |||
MEGDNN_MARK_USED_VAR(OH); | |||
const int ih = oh, iw = ow; | |||
#define cb(i) float16x8_t out##i{0}; | |||
UNROLL_CALL_NOWRAPPER(BLOCK_H, cb); | |||
#undef cb | |||
float16x8_t kr0, kr1, inp; | |||
src += ih * IW + iw; | |||
dst += oh * OW + ow; | |||
LOAD_RESULT_VAL; | |||
for (int fw = 0; fw < FW; ++fw) { | |||
const __fp16* src_dd = src + fw; | |||
kr0 = vdupq_n_f16(filter[0 * FW + fw]); | |||
kr1 = vdupq_n_f16(filter[1 * FW + fw]); | |||
#define cb(i) \ | |||
if (height > i) { \ | |||
inp = vld1q_f16(src_dd + i * IW); \ | |||
out##i = vmlaq_f16(out##i, inp, kr0); \ | |||
inp = vld1q_f16(src_dd + (i + 1) * IW); \ | |||
out##i = vmlaq_f16(out##i, inp, kr1); \ | |||
} | |||
UNROLL_CALL_NOWRAPPER(BLOCK_H, cb); | |||
#undef cb | |||
} | |||
STORE_RESULT_VAL; | |||
} | |||
}; | |||
template <int height, int width> | |||
struct do_pixel_proxy<3, height, width> { | |||
static void exec(const __fp16* src, const __fp16* filter, __fp16* dst, | |||
const int IH, const int IW, const int OH, const int OW, | |||
const int FW, const int oh, const int ow) { | |||
MEGDNN_MARK_USED_VAR(IH); | |||
MEGDNN_MARK_USED_VAR(OH); | |||
const int ih = oh, iw = ow; | |||
#define cb(i) float16x8_t out##i{0}; | |||
UNROLL_CALL_NOWRAPPER(BLOCK_H, cb); | |||
#undef cb | |||
float16x8_t kr0, kr1, kr2, inp; | |||
src += ih * IW + iw; | |||
dst += oh * OW + ow; | |||
LOAD_RESULT_VAL; | |||
for (int fw = 0; fw < FW; ++fw) { | |||
const __fp16* src_dd = src + fw; | |||
kr0 = vdupq_n_f16(filter[0 * FW + fw]); | |||
kr1 = vdupq_n_f16(filter[1 * FW + fw]); | |||
kr2 = vdupq_n_f16(filter[2 * FW + fw]); | |||
#define cb(i) \ | |||
if (height > i) { \ | |||
inp = vld1q_f16(src_dd + i * IW); \ | |||
out##i = vmlaq_f16(out##i, inp, kr0); \ | |||
inp = vld1q_f16(src_dd + (i + 1) * IW); \ | |||
out##i = vmlaq_f16(out##i, inp, kr1); \ | |||
inp = vld1q_f16(src_dd + (i + 2) * IW); \ | |||
out##i = vmlaq_f16(out##i, inp, kr2); \ | |||
} | |||
UNROLL_CALL_NOWRAPPER(BLOCK_H, cb); | |||
#undef cb | |||
} | |||
STORE_RESULT_VAL; | |||
} | |||
}; | |||
template <int height, int width> | |||
struct do_pixel_proxy<4, height, width> { | |||
static void exec(const __fp16* src, const __fp16* filter, __fp16* dst, | |||
const int IH, const int IW, const int OH, const int OW, | |||
const int FW, const int oh, const int ow) { | |||
MEGDNN_MARK_USED_VAR(IH); | |||
MEGDNN_MARK_USED_VAR(OH); | |||
const int ih = oh, iw = ow; | |||
#define cb(i) float16x8_t out##i{0}; | |||
UNROLL_CALL_NOWRAPPER(BLOCK_H, cb); | |||
#undef cb | |||
float16x8_t kr0, kr1, kr2, kr3, inp; | |||
src += ih * IW + iw; | |||
dst += oh * OW + ow; | |||
LOAD_RESULT_VAL; | |||
for (int fw = 0; fw < FW; ++fw) { | |||
const __fp16* src_dd = src + fw; | |||
kr0 = vdupq_n_f16(filter[0 * FW + fw]); | |||
kr1 = vdupq_n_f16(filter[1 * FW + fw]); | |||
kr2 = vdupq_n_f16(filter[2 * FW + fw]); | |||
kr3 = vdupq_n_f16(filter[3 * FW + fw]); | |||
#define cb(i) \ | |||
if (height > i) { \ | |||
inp = vld1q_f16(src_dd + i * IW); \ | |||
out##i = vmlaq_f16(out##i, inp, kr0); \ | |||
inp = vld1q_f16(src_dd + (i + 1) * IW); \ | |||
out##i = vmlaq_f16(out##i, inp, kr1); \ | |||
inp = vld1q_f16(src_dd + (i + 2) * IW); \ | |||
out##i = vmlaq_f16(out##i, inp, kr2); \ | |||
inp = vld1q_f16(src_dd + (i + 3) * IW); \ | |||
out##i = vmlaq_f16(out##i, inp, kr3); \ | |||
} | |||
UNROLL_CALL_NOWRAPPER(BLOCK_H, cb); | |||
#undef cb | |||
} | |||
STORE_RESULT_VAL; | |||
} | |||
}; | |||
template <int height, int width> | |||
struct do_pixel_proxy<5, height, width> { | |||
static void exec(const __fp16* src, const __fp16* filter, __fp16* dst, | |||
const int IH, const int IW, const int OH, const int OW, | |||
const int FW, const int oh, const int ow) { | |||
MEGDNN_MARK_USED_VAR(IH); | |||
MEGDNN_MARK_USED_VAR(OH); | |||
const int ih = oh, iw = ow; | |||
#define cb(i) float16x8_t out##i{0}; | |||
UNROLL_CALL_NOWRAPPER(BLOCK_H, cb); | |||
#undef cb | |||
float16x8_t kr0, kr1, kr2, kr3, kr4, inp; | |||
src += ih * IW + iw; | |||
dst += oh * OW + ow; | |||
LOAD_RESULT_VAL; | |||
for (int fw = 0; fw < FW; ++fw) { | |||
const __fp16* src_dd = src + fw; | |||
kr0 = vdupq_n_f16(filter[0 * FW + fw]); | |||
kr1 = vdupq_n_f16(filter[1 * FW + fw]); | |||
kr2 = vdupq_n_f16(filter[2 * FW + fw]); | |||
kr3 = vdupq_n_f16(filter[3 * FW + fw]); | |||
kr4 = vdupq_n_f16(filter[4 * FW + fw]); | |||
#define cb(i) \ | |||
if (height > i) { \ | |||
inp = vld1q_f16(src_dd + i * IW); \ | |||
out##i = vmlaq_f16(out##i, inp, kr0); \ | |||
inp = vld1q_f16(src_dd + (i + 1) * IW); \ | |||
out##i = vmlaq_f16(out##i, inp, kr1); \ | |||
inp = vld1q_f16(src_dd + (i + 2) * IW); \ | |||
out##i = vmlaq_f16(out##i, inp, kr2); \ | |||
inp = vld1q_f16(src_dd + (i + 3) * IW); \ | |||
out##i = vmlaq_f16(out##i, inp, kr3); \ | |||
inp = vld1q_f16(src_dd + (i + 4) * IW); \ | |||
out##i = vmlaq_f16(out##i, inp, kr4); \ | |||
} | |||
UNROLL_CALL_NOWRAPPER(BLOCK_H, cb); | |||
#undef cb | |||
} | |||
STORE_RESULT_VAL; | |||
} | |||
}; | |||
template <int height, int width> | |||
struct do_pixel_proxy<6, height, width> { | |||
static void exec(const __fp16* src, const __fp16* filter, __fp16* dst, | |||
const int IH, const int IW, const int OH, const int OW, | |||
const int FW, const int oh, const int ow) { | |||
MEGDNN_MARK_USED_VAR(IH); | |||
MEGDNN_MARK_USED_VAR(OH); | |||
const int ih = oh, iw = ow; | |||
#define cb(i) float16x8_t out##i{0}; | |||
UNROLL_CALL_NOWRAPPER(BLOCK_H, cb); | |||
#undef cb | |||
float16x8_t kr0, kr1, kr2, kr3, kr4, kr5, inp; | |||
src += ih * IW + iw; | |||
dst += oh * OW + ow; | |||
LOAD_RESULT_VAL; | |||
for (int fw = 0; fw < FW; ++fw) { | |||
const __fp16* src_dd = src + fw; | |||
kr0 = vdupq_n_f16(filter[0 * FW + fw]); | |||
kr1 = vdupq_n_f16(filter[1 * FW + fw]); | |||
kr2 = vdupq_n_f16(filter[2 * FW + fw]); | |||
kr3 = vdupq_n_f16(filter[3 * FW + fw]); | |||
kr4 = vdupq_n_f16(filter[4 * FW + fw]); | |||
kr5 = vdupq_n_f16(filter[5 * FW + fw]); | |||
#define cb(i) \ | |||
if (height > i) { \ | |||
inp = vld1q_f16(src_dd + i * IW); \ | |||
out##i = vmlaq_f16(out##i, inp, kr0); \ | |||
inp = vld1q_f16(src_dd + (i + 1) * IW); \ | |||
out##i = vmlaq_f16(out##i, inp, kr1); \ | |||
inp = vld1q_f16(src_dd + (i + 2) * IW); \ | |||
out##i = vmlaq_f16(out##i, inp, kr2); \ | |||
inp = vld1q_f16(src_dd + (i + 3) * IW); \ | |||
out##i = vmlaq_f16(out##i, inp, kr3); \ | |||
inp = vld1q_f16(src_dd + (i + 4) * IW); \ | |||
out##i = vmlaq_f16(out##i, inp, kr4); \ | |||
inp = vld1q_f16(src_dd + (i + 5) * IW); \ | |||
out##i = vmlaq_f16(out##i, inp, kr5); \ | |||
} | |||
UNROLL_CALL_NOWRAPPER(BLOCK_H, cb); | |||
#undef cb | |||
} | |||
STORE_RESULT_VAL; | |||
} | |||
}; | |||
template <int height, int width> | |||
struct do_pixel_proxy<7, height, width> { | |||
static void exec(const __fp16* src, const __fp16* filter, __fp16* dst, | |||
const int IH, const int IW, const int OH, const int OW, | |||
const int FW, const int oh, const int ow) { | |||
MEGDNN_MARK_USED_VAR(IH); | |||
MEGDNN_MARK_USED_VAR(OH); | |||
const int ih = oh, iw = ow; | |||
#define cb(i) float16x8_t out##i{0}; | |||
UNROLL_CALL_NOWRAPPER(BLOCK_H, cb); | |||
#undef cb | |||
float16x8_t kr0, kr1, kr2, kr3, kr4, kr5, kr6, inp; | |||
src += ih * IW + iw; | |||
dst += oh * OW + ow; | |||
LOAD_RESULT_VAL; | |||
for (int fw = 0; fw < FW; ++fw) { | |||
const __fp16* src_dd = src + fw; | |||
kr0 = vdupq_n_f16(filter[0 * FW + fw]); | |||
kr1 = vdupq_n_f16(filter[1 * FW + fw]); | |||
kr2 = vdupq_n_f16(filter[2 * FW + fw]); | |||
kr3 = vdupq_n_f16(filter[3 * FW + fw]); | |||
kr4 = vdupq_n_f16(filter[4 * FW + fw]); | |||
kr5 = vdupq_n_f16(filter[5 * FW + fw]); | |||
kr6 = vdupq_n_f16(filter[6 * FW + fw]); | |||
#define cb(i) \ | |||
if (height > i) { \ | |||
inp = vld1q_f16(src_dd + i * IW); \ | |||
out##i = vmlaq_f16(out##i, inp, kr0); \ | |||
inp = vld1q_f16(src_dd + (i + 1) * IW); \ | |||
out##i = vmlaq_f16(out##i, inp, kr1); \ | |||
inp = vld1q_f16(src_dd + (i + 2) * IW); \ | |||
out##i = vmlaq_f16(out##i, inp, kr2); \ | |||
inp = vld1q_f16(src_dd + (i + 3) * IW); \ | |||
out##i = vmlaq_f16(out##i, inp, kr3); \ | |||
inp = vld1q_f16(src_dd + (i + 4) * IW); \ | |||
out##i = vmlaq_f16(out##i, inp, kr4); \ | |||
inp = vld1q_f16(src_dd + (i + 5) * IW); \ | |||
out##i = vmlaq_f16(out##i, inp, kr5); \ | |||
inp = vld1q_f16(src_dd + (i + 6) * IW); \ | |||
out##i = vmlaq_f16(out##i, inp, kr6); \ | |||
} | |||
UNROLL_CALL_NOWRAPPER(BLOCK_H, cb); | |||
#undef cb | |||
} | |||
STORE_RESULT_VAL; | |||
} | |||
}; | |||
#undef STORE_RESULT_VAL | |||
#undef LOAD_RESULT_VAL | |||
template <int FH, int height, int width> | |||
void do_pixel(const __fp16* src, const __fp16* filter, __fp16* dst, | |||
const int IH, const int IW, const int OH, const int OW, | |||
const int FW, const int oh, const int ow) { | |||
do_pixel_proxy<FH, height, width>::exec(src, filter, dst, IH, IW, OH, OW, | |||
FW, oh, ow); | |||
} | |||
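//! prefetch variant: while one 8-wide tile is computed, the next 16-column
//! block of src (FH + 3 rows deep) and the matching dst rows are prefetched;
//! this is expected to help mainly on large feature maps, see kern_direct_f16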
template <int FH> | |||
void do_conv_tpl_enable_prefetch(const __fp16* src, | |||
const __fp16* filter, __fp16* dst, | |||
const int IH, const int IW, const int OH, | |||
const int OW, const int FW) { | |||
const int hbeg = 0, hend = OH; | |||
const int wbeg = 0, wend = OW; | |||
int i, j; | |||
for (i = hbeg; i + BLOCK_H <= hend; i += BLOCK_H) { | |||
for (j = wbeg; j + 8 <= wend; j += 8) { | |||
// do prefetch | |||
const int prefetch_index_input = | |||
(j + 16) < wend | |||
? i * IW + j + 16 | |||
: (i + 8) * IW + (((j + 16 - wend) >> 2) << 2); | |||
const int prefetch_index_output = | |||
(j + 16) < wend | |||
? i * OW + j + 16 | |||
: (i + 8) * OW + (((j + 16 - wend) >> 2) << 2); | |||
const __fp16* src_prefetch = src + prefetch_index_input; | |||
const __fp16* dst_prefetch = dst + prefetch_index_output; | |||
for (int iw_id = 0; iw_id < FH + 3; ++iw_id) { | |||
__builtin_prefetch(src_prefetch + iw_id * IW, 0, 3); | |||
} | |||
#define unroll_prefetch_cb(i) __builtin_prefetch(dst_prefetch + i * OW, 1, 3); | |||
UNROLL_CALL_NOWRAPPER(BLOCK_H, unroll_prefetch_cb); | |||
do_pixel<FH, BLOCK_H, 8>(src, filter, dst, IH, IW, OH, OW, FW, i, | |||
j); | |||
} | |||
#define DISPATCH(width) \ | |||
do { \ | |||
const int prefetch_index_input = (i + 8) * IW + 12; \ | |||
const int prefetch_index_output = (i + 8) * OW + 12; \ | |||
const __fp16* src_prefetch = src + prefetch_index_input; \ | |||
const __fp16* dst_prefetch = dst + prefetch_index_output; \ | |||
for (int iw_id = 0; iw_id < FH + 3; ++iw_id) { \ | |||
__builtin_prefetch(src_prefetch + iw_id * IW, 0, 3); \ | |||
} \ | |||
UNROLL_CALL_NOWRAPPER(BLOCK_H, unroll_prefetch_cb); \ | |||
do_pixel<FH, BLOCK_H, width>(src, filter, dst, IH, IW, OH, OW, FW, i, \ | |||
j); \ | |||
} while (0) | |||
switch (wend - j) { | |||
case 1: | |||
DISPATCH(1); | |||
break; | |||
case 2: | |||
DISPATCH(2); | |||
break; | |||
case 3: | |||
DISPATCH(3); | |||
break; | |||
case 4: | |||
DISPATCH(4); | |||
break; | |||
case 5: | |||
DISPATCH(5); | |||
break; | |||
case 6: | |||
DISPATCH(6); | |||
break; | |||
case 7: | |||
DISPATCH(7); | |||
break; | |||
} | |||
#undef DISPATCH | |||
} | |||
#define DISPATCH2(height, width) \ | |||
do { \ | |||
const int prefetch_index_input = IH * IW + 12; \ | |||
const __fp16* src_prefetch = src + prefetch_index_input; \ | |||
for (int iw_id = 0; iw_id < FH + 3; ++iw_id) { \ | |||
__builtin_prefetch(src_prefetch + iw_id * IW, 0, 3); \ | |||
} \ | |||
do_pixel<FH, height, width>(src, filter, dst, IH, IW, OH, OW, FW, i, \ | |||
j); \ | |||
} while (0) | |||
#define DISPATCH1(height) \ | |||
do { \ | |||
for (j = wbeg; j + 8 <= wend; j += 8) { \ | |||
const int prefetch_index_input = \ | |||
(j + 16) < wend \ | |||
? i * IW + j + 16 \ | |||
: (i + 8) * IW + (((j + 16 - wend) >> 2) << 2); \ | |||
const int prefetch_index_output = \ | |||
(j + 16) < wend \ | |||
? i * OW + j + 16 \ | |||
: (i + 8) * OW + (((j + 16 - wend) >> 2) << 2); \ | |||
const __fp16* src_prefetch = src + prefetch_index_input; \ | |||
const __fp16* dst_prefetch = dst + prefetch_index_output; \ | |||
for (int iw_id = 0; iw_id < FH + 3; ++iw_id) { \ | |||
__builtin_prefetch(src_prefetch + iw_id * IW, 0, 3); \ | |||
} \ | |||
UNROLL_CALL_NOWRAPPER(BLOCK_H, unroll_prefetch_cb); \ | |||
do_pixel<FH, height, 8>(src, filter, dst, IH, IW, OH, OW, FW, i, \ | |||
j); \ | |||
} \ | |||
switch (wend - j) { \ | |||
case 1: \ | |||
DISPATCH2(height, 1); \ | |||
break; \ | |||
case 2: \ | |||
DISPATCH2(height, 2); \ | |||
break; \ | |||
case 3: \ | |||
DISPATCH2(height, 3); \ | |||
break; \ | |||
case 4: \ | |||
DISPATCH2(height, 4); \ | |||
break; \ | |||
case 5: \ | |||
DISPATCH2(height, 5); \ | |||
break; \ | |||
case 6: \ | |||
DISPATCH2(height, 6); \ | |||
break; \ | |||
case 7: \ | |||
DISPATCH2(height, 7); \ | |||
break; \ | |||
} \ | |||
} while (0) | |||
switch (hend - i) { | |||
case 1: | |||
DISPATCH1(1); | |||
break; | |||
case 2: | |||
DISPATCH1(2); | |||
break; | |||
case 3: | |||
DISPATCH1(3); | |||
break; | |||
#if BLOCK_H == 8 | |||
case 4: | |||
DISPATCH1(4); | |||
break; | |||
case 5: | |||
DISPATCH1(5); | |||
break; | |||
case 6: | |||
DISPATCH1(6); | |||
break; | |||
case 7: | |||
DISPATCH1(7); | |||
break; | |||
#endif | |||
} | |||
#undef DISPATCH1 | |||
#undef DISPATCH2 | |||
#undef unroll_prefetch_cb | |||
} | |||
template <int FH> | |||
void do_conv_tpl_disable_prefetch(const __fp16* src, | |||
const __fp16* filter, __fp16* dst, | |||
const int IH, const int IW, const int OH, | |||
const int OW, const int FW) { | |||
const int hbeg = 0, hend = OH; | |||
const int wbeg = 0, wend = OW; | |||
int i, j; | |||
for (i = hbeg; i + BLOCK_H <= hend; i += BLOCK_H) { | |||
for (j = wbeg; j + 8 <= wend; j += 8) { | |||
do_pixel<FH, BLOCK_H, 8>(src, filter, dst, IH, IW, OH, OW, FW, i, | |||
j); | |||
} | |||
#define DISPATCH(width) \ | |||
do { \ | |||
do_pixel<FH, BLOCK_H, width>(src, filter, dst, IH, IW, OH, OW, FW, i, \ | |||
j); \ | |||
} while (0) | |||
switch (wend - j) { | |||
case 1: | |||
DISPATCH(1); | |||
break; | |||
case 2: | |||
DISPATCH(2); | |||
break; | |||
case 3: | |||
DISPATCH(3); | |||
break; | |||
case 4: | |||
DISPATCH(4); | |||
break; | |||
case 5: | |||
DISPATCH(5); | |||
break; | |||
case 6: | |||
DISPATCH(6); | |||
break; | |||
case 7: | |||
DISPATCH(7); | |||
break; | |||
} | |||
#undef DISPATCH | |||
} | |||
#define DISPATCH2(height, width) \ | |||
do { \ | |||
do_pixel<FH, height, width>(src, filter, dst, IH, IW, OH, OW, FW, i, \ | |||
j); \ | |||
} while (0) | |||
#define DISPATCH1(height) \ | |||
do { \ | |||
for (j = wbeg; j + 8 <= wend; j += 8) { \ | |||
do_pixel<FH, height, 8>(src, filter, dst, IH, IW, OH, OW, FW, i, \ | |||
j); \ | |||
} \ | |||
switch (wend - j) { \ | |||
case 1: \ | |||
DISPATCH2(height, 1); \ | |||
break; \ | |||
case 2: \ | |||
DISPATCH2(height, 2); \ | |||
break; \ | |||
case 3: \ | |||
DISPATCH2(height, 3); \ | |||
break; \ | |||
case 4: \ | |||
DISPATCH2(height, 4); \ | |||
break; \ | |||
case 5: \ | |||
DISPATCH2(height, 5); \ | |||
break; \ | |||
case 6: \ | |||
DISPATCH2(height, 6); \ | |||
break; \ | |||
case 7: \ | |||
DISPATCH2(height, 7); \ | |||
break; \ | |||
} \ | |||
} while (0) | |||
switch (hend - i) { | |||
case 1: | |||
DISPATCH1(1); | |||
break; | |||
case 2: | |||
DISPATCH1(2); | |||
break; | |||
case 3: | |||
DISPATCH1(3); | |||
break; | |||
#if BLOCK_H == 8 | |||
case 4: | |||
DISPATCH1(4); | |||
break; | |||
case 5: | |||
DISPATCH1(5); | |||
break; | |||
case 6: | |||
DISPATCH1(6); | |||
break; | |||
case 7: | |||
DISPATCH1(7); | |||
break; | |||
#endif | |||
} | |||
#undef DISPATCH1 | |||
#undef DISPATCH2 | |||
} | |||
} // anonymous namespace | |||
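//! entry point: FH is dispatched to the templated kernels above; prefetching
//! is presumably only profitable on large planes, hence the heuristic
//! IH > 100 && IW > 100 gate below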
void conv_bias::kern_direct_f16(const __fp16* src, | |||
const __fp16* filter, __fp16* dst, | |||
const int IH, const int IW, const int OH, | |||
const int OW, const int FH, const int FW) { | |||
megdnn_assert_internal(FH <= 7); | |||
if (IH > 100 && IW > 100) { | |||
#define GAO(FH) \ | |||
do { \ | |||
return do_conv_tpl_enable_prefetch<FH>(src, filter, dst, IH, IW, OH, \ | |||
OW, FW); \ | |||
} while (0) | |||
switch (FH) { | |||
case 1: | |||
MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(0)) { GAO(1); } | |||
MIDOUT_END(); | |||
break; | |||
case 2: | |||
MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(1)) { GAO(2); } | |||
MIDOUT_END(); | |||
break; | |||
case 3: | |||
MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(2)) { GAO(3); } | |||
MIDOUT_END(); | |||
break; | |||
case 4: | |||
MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(3)) { GAO(4); } | |||
MIDOUT_END(); | |||
break; | |||
case 5: | |||
MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(4)) { GAO(5); } | |||
MIDOUT_END(); | |||
break; | |||
case 6: | |||
MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(5)) { GAO(6); } | |||
MIDOUT_END(); | |||
break; | |||
case 7: | |||
MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(6)) { GAO(7); } | |||
MIDOUT_END(); | |||
break; | |||
} | |||
#undef GAO | |||
} else { | |||
#define GAO(FH) \ | |||
do { \ | |||
return do_conv_tpl_disable_prefetch<FH>(src, filter, dst, IH, IW, OH, \ | |||
OW, FW); \ | |||
} while (0) | |||
switch (FH) { | |||
case 1: | |||
MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(0)) { GAO(1); } | |||
MIDOUT_END(); | |||
break; | |||
case 2: | |||
MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(1)) { GAO(2); } | |||
MIDOUT_END(); | |||
break; | |||
case 3: | |||
MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(2)) { GAO(3); } | |||
MIDOUT_END(); | |||
break; | |||
case 4: | |||
MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(3)) { GAO(4); } | |||
MIDOUT_END(); | |||
break; | |||
case 5: | |||
MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(4)) { GAO(5); } | |||
MIDOUT_END(); | |||
break; | |||
case 6: | |||
MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(5)) { GAO(6); } | |||
MIDOUT_END(); | |||
break; | |||
case 7: | |||
MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(6)) { GAO(7); } | |||
MIDOUT_END(); | |||
break; | |||
} | |||
#undef GAO | |||
} | |||
megdnn_assert_internal(0); | |||
} | |||
#endif | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,32 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/f16/direct.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include <cstddef> | |||
#include "megdnn/dtype.h" | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
namespace megdnn { | |||
namespace arm_common { | |||
namespace fp16 { | |||
namespace conv_bias { | |||
void kern_direct_f16(const __fp16* src, const __fp16* filter, | |||
__fp16* dst, const int IH, const int IW, const int OH, | |||
const int OW, const int FH, const int FW); | |||
}  // namespace conv_bias | |||
} // namespace fp16 | |||
} // namespace arm_common | |||
} // namespace megdnn | |||
#endif | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,522 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/f16/do_conv_stride1.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include <algorithm> | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
#include "./do_conv_stride1.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/arm_common/conv_bias/postprocess_helper.h" | |||
using namespace megdnn; | |||
using namespace arm_common; | |||
using namespace fp16; | |||
using namespace conv_stride1; | |||
using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam; | |||
using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam; | |||
void conv_stride1::do_conv_2x2_stride1(const __fp16* src, const __fp16* filter, __fp16* dst, | |||
size_t IH, size_t IW, size_t OH, size_t OW, | |||
size_t IC) { | |||
const size_t tail_step = IW - OW; | |||
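    // A stride-1 valid convolution consumes only OW of the IW columns in | |||
    // each input row, so the row pointers must skip the remaining | |||
    // IW - OW border columns at the end of every output row. | |||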
//! unroll of 2 | |||
size_t ic = 0; | |||
for (; ic + 1 < IC; ic += 2) { | |||
const __fp16* src_ptr = src + IW * IH * ic; | |||
const __fp16* src_ptr1 = src_ptr + IW * IH; | |||
__fp16* outptr = dst; | |||
const __fp16* r00 = src_ptr; | |||
const __fp16* r01 = src_ptr + IW; | |||
const __fp16* r10 = src_ptr1; | |||
const __fp16* r11 = src_ptr1 + IW; | |||
const __fp16* k0 = filter + ic * 4; | |||
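        // The 2x2 filters of the two unrolled input channels (8 weights in | |||
        // total) fit into a single q-register and are applied below via | |||
        // lane-indexed multiply-accumulate. | |||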
float16x8_t _k0 = vld1q_f16(k0); | |||
rep(h, OH) { | |||
int width = OW >> 3; | |||
rep(i, width) { | |||
float16x8_t _r000 = vld1q_f16(r00); | |||
float16x8_t _r010 = vld1q_f16(r01); | |||
float16x8_t _r001 = vld1q_f16(r00 + 1); | |||
float16x8_t _r011 = vld1q_f16(r01 + 1); | |||
float16x8_t _r100 = vld1q_f16(r10); | |||
float16x8_t _r110 = vld1q_f16(r11); | |||
float16x8_t _r101 = vld1q_f16(r10 + 1); | |||
float16x8_t _r111 = vld1q_f16(r11 + 1); | |||
float16x8_t _sum = vld1q_f16(outptr); | |||
_sum = vmlaq_lane_f16(_sum, _r000, vget_low_f16(_k0), 0); | |||
_sum = vmlaq_lane_f16(_sum, _r001, vget_low_f16(_k0), 1); | |||
_sum = vmlaq_lane_f16(_sum, _r010, vget_low_f16(_k0), 2); | |||
_sum = vmlaq_lane_f16(_sum, _r011, vget_low_f16(_k0), 3); | |||
_sum = vmlaq_lane_f16(_sum, _r100, vget_high_f16(_k0), 0); | |||
_sum = vmlaq_lane_f16(_sum, _r101, vget_high_f16(_k0), 1); | |||
_sum = vmlaq_lane_f16(_sum, _r110, vget_high_f16(_k0), 2); | |||
_sum = vmlaq_lane_f16(_sum, _r111, vget_high_f16(_k0), 3); | |||
vst1q_f16(outptr, _sum); | |||
r00 += 8; | |||
r01 += 8; | |||
r10 += 8; | |||
r11 += 8; | |||
outptr += 8; | |||
} | |||
r00 += tail_step; | |||
r01 += tail_step; | |||
r10 += tail_step; | |||
r11 += tail_step; | |||
} | |||
} | |||
for (; ic < IC; ic++) { | |||
const __fp16* src_ptr = src + IW * IH * ic; | |||
__fp16* outptr = dst; | |||
const __fp16* r0 = src_ptr; | |||
const __fp16* r1 = src_ptr + IW; | |||
const __fp16* k0 = filter + ic * 4; | |||
float16x8_t _k0 = vdupq_n_f16(k0[0]); | |||
float16x8_t _k1 = vdupq_n_f16(k0[1]); | |||
float16x8_t _k2 = vdupq_n_f16(k0[2]); | |||
float16x8_t _k3 = vdupq_n_f16(k0[3]); | |||
rep(h, OH) { | |||
int width = OW >> 3; | |||
rep(i, width) { | |||
float16x8_t _r00 = vld1q_f16(r0); | |||
float16x8_t _r10 = vld1q_f16(r1); | |||
float16x8_t _r01 = vld1q_f16(r0 + 1); | |||
float16x8_t _r11 = vld1q_f16(r1 + 1); | |||
float16x8_t _sum = vld1q_f16(outptr); | |||
float16x8_t _sum2; | |||
_sum = vmlaq_f16(_sum, _r00, _k0); | |||
_sum2 = vmulq_f16(_r01, _k1); | |||
_sum = vmlaq_f16(_sum, _r10, _k2); | |||
_sum2 = vmlaq_f16(_sum2, _r11, _k3); | |||
_sum = vaddq_f16(_sum, _sum2); | |||
vst1q_f16(outptr, _sum); | |||
r0 += 8; | |||
r1 += 8; | |||
outptr += 8; | |||
} | |||
r0 += tail_step; | |||
r1 += tail_step; | |||
} | |||
} | |||
} | |||
void conv_stride1::do_conv_3x3_stride1(const __fp16* src, const __fp16* filter, __fp16* dst, | |||
size_t IH, size_t IW, size_t OH, size_t OW, | |||
size_t IC) { | |||
const size_t tail_step = IW - OW; | |||
rep(ic, IC) { | |||
const __fp16* src_ptr = src + IW * IH * ic; | |||
__fp16* outptr = dst; | |||
__fp16* outptr2 = outptr + OW; | |||
const __fp16* r0 = src_ptr; | |||
const __fp16* r1 = src_ptr + IW; | |||
const __fp16* r2 = src_ptr + IW * 2; | |||
const __fp16* r3 = src_ptr + IW * 3; | |||
float16x8_t _k01234567 = vld1q_f16(filter); | |||
float16x8_t _k12345678 = vld1q_f16(filter + 1); | |||
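        // The 9 weights of the 3x3 filter are covered by two overlapping | |||
        // 8-lane loads: _k01234567 holds weights 0..7, and lane 7 of | |||
        // _k12345678 holds weight 8. | |||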
size_t h = 0; | |||
for (; h + 1 < OH; h += 2) { | |||
int width = OW >> 3; | |||
rep(i, width) { | |||
float16x8_t _sum1 = vld1q_f16(outptr); | |||
float16x8_t _sum2 = vdupq_n_f16(0.f); | |||
float16x8_t _sum3 = vld1q_f16(outptr2); | |||
float16x8_t _sum4 = vdupq_n_f16(0.f); | |||
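                // Two consecutive loads per input row feed vextq_f16, which | |||
                // builds the +1 and +2 shifted windows in registers instead | |||
                // of reloading unaligned data; the two output rows computed | |||
                // per iteration share the loads of r1 and r2. | |||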
float16x8_t _r00 = vld1q_f16(r0); | |||
float16x8_t _r00n = vld1q_f16(r0 + 8); | |||
float16x8_t _r01 = vextq_f16(_r00, _r00n, 1); | |||
float16x8_t _r02 = vextq_f16(_r00, _r00n, 2); | |||
float16x8_t _r10 = vld1q_f16(r1); | |||
float16x8_t _r10n = vld1q_f16(r1 + 8); | |||
float16x8_t _r11 = vextq_f16(_r10, _r10n, 1); | |||
float16x8_t _r12 = vextq_f16(_r10, _r10n, 2); | |||
float16x8_t _r20 = vld1q_f16(r2); | |||
float16x8_t _r20n = vld1q_f16(r2 + 8); | |||
float16x8_t _r21 = vextq_f16(_r20, _r20n, 1); | |||
float16x8_t _r22 = vextq_f16(_r20, _r20n, 2); | |||
float16x8_t _r30 = vld1q_f16(r3); | |||
float16x8_t _r30n = vld1q_f16(r3 + 8); | |||
float16x8_t _r31 = vextq_f16(_r30, _r30n, 1); | |||
float16x8_t _r32 = vextq_f16(_r30, _r30n, 2); | |||
_sum1 = vmlaq_low_lane_f16(_sum1, _r00, _k01234567, 0); | |||
_sum2 = vmlaq_low_lane_f16(_sum2, _r01, _k01234567, 1); | |||
_sum1 = vmlaq_low_lane_f16(_sum1, _r02, _k01234567, 2); | |||
_sum2 = vmlaq_low_lane_f16(_sum2, _r10, _k01234567, 3); | |||
_sum1 = vmlaq_high_lane_f16(_sum1, _r11, _k01234567, 4); | |||
_sum2 = vmlaq_high_lane_f16(_sum2, _r12, _k01234567, 5); | |||
_sum1 = vmlaq_high_lane_f16(_sum1, _r20, _k01234567, 6); | |||
_sum2 = vmlaq_high_lane_f16(_sum2, _r21, _k01234567, 7); | |||
_sum1 = vmlaq_high_lane_f16(_sum1, _r22, _k12345678, 7); | |||
_sum3 = vmlaq_low_lane_f16(_sum3, _r10, _k01234567, 0); | |||
_sum4 = vmlaq_low_lane_f16(_sum4, _r11, _k01234567, 1); | |||
_sum3 = vmlaq_low_lane_f16(_sum3, _r12, _k01234567, 2); | |||
_sum4 = vmlaq_low_lane_f16(_sum4, _r20, _k01234567, 3); | |||
_sum3 = vmlaq_high_lane_f16(_sum3, _r21, _k01234567, 4); | |||
_sum4 = vmlaq_high_lane_f16(_sum4, _r22, _k01234567, 5); | |||
_sum3 = vmlaq_high_lane_f16(_sum3, _r30, _k01234567, 6); | |||
_sum4 = vmlaq_high_lane_f16(_sum4, _r31, _k01234567, 7); | |||
_sum3 = vmlaq_high_lane_f16(_sum3, _r32, _k12345678, 7); | |||
_sum1 = vaddq_f16(_sum1, _sum2); | |||
_sum3 = vaddq_f16(_sum3, _sum4); | |||
vst1q_f16(outptr, _sum1); | |||
vst1q_f16(outptr2, _sum3); | |||
r0 += 8; | |||
r1 += 8; | |||
r2 += 8; | |||
r3 += 8; | |||
outptr += 8; | |||
outptr2 += 8; | |||
} | |||
r0 += tail_step + IW; | |||
r1 += tail_step + IW; | |||
r2 += tail_step + IW; | |||
r3 += tail_step + IW; | |||
outptr += OW; | |||
outptr2 += OW; | |||
} | |||
for (; h < OH; h++) { | |||
int width = OW >> 3; | |||
rep(i, width) { | |||
float16x8_t _sum1 = vld1q_f16(outptr); | |||
float16x8_t _sum2 = vdupq_n_f16(0.f); | |||
float16x8_t _r00 = vld1q_f16(r0); | |||
float16x8_t _r00n = vld1q_f16(r0 + 8); | |||
float16x8_t _r01 = vextq_f16(_r00, _r00n, 1); | |||
float16x8_t _r02 = vextq_f16(_r00, _r00n, 2); | |||
float16x8_t _r10 = vld1q_f16(r1); | |||
float16x8_t _r10n = vld1q_f16(r1 + 8); | |||
float16x8_t _r11 = vextq_f16(_r10, _r10n, 1); | |||
float16x8_t _r12 = vextq_f16(_r10, _r10n, 2); | |||
float16x8_t _r20 = vld1q_f16(r2); | |||
float16x8_t _r20n = vld1q_f16(r2 + 8); | |||
float16x8_t _r21 = vextq_f16(_r20, _r20n, 1); | |||
float16x8_t _r22 = vextq_f16(_r20, _r20n, 2); | |||
_sum1 = vmlaq_low_lane_f16(_sum1, _r00, _k01234567, 0); | |||
_sum2 = vmlaq_low_lane_f16(_sum2, _r01, _k01234567, 1); | |||
_sum1 = vmlaq_low_lane_f16(_sum1, _r02, _k01234567, 2); | |||
_sum2 = vmlaq_low_lane_f16(_sum2, _r10, _k01234567, 3); | |||
_sum1 = vmlaq_high_lane_f16(_sum1, _r11, _k01234567, 4); | |||
_sum2 = vmlaq_high_lane_f16(_sum2, _r12, _k01234567, 5); | |||
_sum1 = vmlaq_high_lane_f16(_sum1, _r20, _k01234567, 6); | |||
_sum2 = vmlaq_high_lane_f16(_sum2, _r21, _k01234567, 7); | |||
_sum1 = vmlaq_high_lane_f16(_sum1, _r22, _k12345678, 7); | |||
_sum1 = vaddq_f16(_sum1, _sum2); | |||
vst1q_f16(outptr, _sum1); | |||
r0 += 8; | |||
r1 += 8; | |||
r2 += 8; | |||
outptr += 8; | |||
} | |||
r0 += tail_step; | |||
r1 += tail_step; | |||
r2 += tail_step; | |||
} | |||
filter += 9; | |||
} | |||
} | |||
void conv_stride1::do_conv_5x5_stride1(const __fp16* src, const __fp16* filter, __fp16* dst, | |||
size_t IH, size_t IW, size_t OH, size_t OW, | |||
size_t IC) { | |||
const size_t tail_step = IW - OW; | |||
rep(ic, IC) { | |||
const __fp16* src_ptr = src + IW * IH * ic; | |||
__fp16* outptr = dst; | |||
__fp16* outptr2 = outptr + OW; | |||
const __fp16* r0 = src_ptr; | |||
const __fp16* r1 = src_ptr + IW; | |||
const __fp16* r2 = src_ptr + IW * 2; | |||
const __fp16* r3 = src_ptr + IW * 3; | |||
const __fp16* r4 = src_ptr + IW * 4; | |||
const __fp16* r5 = src_ptr + IW * 5; | |||
float16x8_t _k0 = vld1q_f16(filter); | |||
float16x8_t _k1 = vld1q_f16(filter + 8); | |||
float16x8_t _k2 = vld1q_f16(filter + 16); | |||
float16x8_t _k3 = vld1q_f16(filter + 17); | |||
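        // The 25 weights of the 5x5 filter occupy _k0.._k2 plus one extra | |||
        // element; _k3 is loaded from filter + 17 so that the last weight | |||
        // (index 24) lands in its lane 7. | |||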
size_t h = 0; | |||
for (; h + 1 < OH; h += 2) { | |||
int width = OW >> 3; | |||
rep(i, width) { | |||
float16x8_t _sum = vld1q_f16(outptr); | |||
float16x8_t _sum2 = vld1q_f16(outptr2); | |||
float16x8_t _r00 = vld1q_f16(r0); | |||
float16x8_t _r05 = vld1q_f16(r0 + 8); | |||
float16x8_t _r01 = vextq_f16(_r00, _r05, 1); | |||
float16x8_t _r02 = vextq_f16(_r00, _r05, 2); | |||
float16x8_t _r03 = vextq_f16(_r00, _r05, 3); | |||
float16x8_t _r04 = vextq_f16(_r00, _r05, 4); | |||
float16x8_t _r10 = vld1q_f16(r1); | |||
float16x8_t _r15 = vld1q_f16(r1 + 8); | |||
float16x8_t _r11 = vextq_f16(_r10, _r15, 1); | |||
float16x8_t _r12 = vextq_f16(_r10, _r15, 2); | |||
float16x8_t _r13 = vextq_f16(_r10, _r15, 3); | |||
float16x8_t _r14 = vextq_f16(_r10, _r15, 4); | |||
float16x8_t _r20 = vld1q_f16(r2); | |||
float16x8_t _r25 = vld1q_f16(r2 + 8); | |||
float16x8_t _r21 = vextq_f16(_r20, _r25, 1); | |||
float16x8_t _r22 = vextq_f16(_r20, _r25, 2); | |||
float16x8_t _r23 = vextq_f16(_r20, _r25, 3); | |||
float16x8_t _r24 = vextq_f16(_r20, _r25, 4); | |||
float16x8_t _r30 = vld1q_f16(r3); | |||
float16x8_t _r35 = vld1q_f16(r3 + 8); | |||
float16x8_t _r31 = vextq_f16(_r30, _r35, 1); | |||
float16x8_t _r32 = vextq_f16(_r30, _r35, 2); | |||
float16x8_t _r33 = vextq_f16(_r30, _r35, 3); | |||
float16x8_t _r34 = vextq_f16(_r30, _r35, 4); | |||
float16x8_t _r40 = vld1q_f16(r4); | |||
float16x8_t _r45 = vld1q_f16(r4 + 8); | |||
float16x8_t _r41 = vextq_f16(_r40, _r45, 1); | |||
float16x8_t _r42 = vextq_f16(_r40, _r45, 2); | |||
float16x8_t _r43 = vextq_f16(_r40, _r45, 3); | |||
float16x8_t _r44 = vextq_f16(_r40, _r45, 4); | |||
float16x8_t _r50 = vld1q_f16(r5); | |||
float16x8_t _r55 = vld1q_f16(r5 + 8); | |||
float16x8_t _r51 = vextq_f16(_r50, _r55, 1); | |||
float16x8_t _r52 = vextq_f16(_r50, _r55, 2); | |||
float16x8_t _r53 = vextq_f16(_r50, _r55, 3); | |||
float16x8_t _r54 = vextq_f16(_r50, _r55, 4); | |||
_sum = vmlaq_low_lane_f16(_sum, _r00, _k0, 0); | |||
_sum = vmlaq_low_lane_f16(_sum, _r01, _k0, 1); | |||
_sum = vmlaq_low_lane_f16(_sum, _r02, _k0, 2); | |||
_sum = vmlaq_low_lane_f16(_sum, _r03, _k0, 3); | |||
_sum = vmlaq_high_lane_f16(_sum, _r04, _k0, 4); | |||
_sum = vmlaq_high_lane_f16(_sum, _r10, _k0, 5); | |||
_sum = vmlaq_high_lane_f16(_sum, _r11, _k0, 6); | |||
_sum = vmlaq_high_lane_f16(_sum, _r12, _k0, 7); | |||
_sum = vmlaq_low_lane_f16(_sum, _r13, _k1, 0); | |||
_sum = vmlaq_low_lane_f16(_sum, _r14, _k1, 1); | |||
_sum = vmlaq_low_lane_f16(_sum, _r20, _k1, 2); | |||
_sum = vmlaq_low_lane_f16(_sum, _r21, _k1, 3); | |||
_sum = vmlaq_high_lane_f16(_sum, _r22, _k1, 4); | |||
_sum = vmlaq_high_lane_f16(_sum, _r23, _k1, 5); | |||
_sum = vmlaq_high_lane_f16(_sum, _r24, _k1, 6); | |||
_sum = vmlaq_high_lane_f16(_sum, _r30, _k1, 7); | |||
_sum = vmlaq_low_lane_f16(_sum, _r31, _k2, 0); | |||
_sum = vmlaq_low_lane_f16(_sum, _r32, _k2, 1); | |||
_sum = vmlaq_low_lane_f16(_sum, _r33, _k2, 2); | |||
_sum = vmlaq_low_lane_f16(_sum, _r34, _k2, 3); | |||
_sum = vmlaq_high_lane_f16(_sum, _r40, _k2, 4); | |||
_sum = vmlaq_high_lane_f16(_sum, _r41, _k2, 5); | |||
_sum = vmlaq_high_lane_f16(_sum, _r42, _k2, 6); | |||
_sum = vmlaq_high_lane_f16(_sum, _r43, _k2, 7); | |||
_sum = vmlaq_high_lane_f16(_sum, _r44, _k3, 7); | |||
_sum2 = vmlaq_low_lane_f16(_sum2, _r10, _k0, 0); | |||
_sum2 = vmlaq_low_lane_f16(_sum2, _r11, _k0, 1); | |||
_sum2 = vmlaq_low_lane_f16(_sum2, _r12, _k0, 2); | |||
_sum2 = vmlaq_low_lane_f16(_sum2, _r13, _k0, 3); | |||
_sum2 = vmlaq_high_lane_f16(_sum2, _r14, _k0, 4); | |||
_sum2 = vmlaq_high_lane_f16(_sum2, _r20, _k0, 5); | |||
_sum2 = vmlaq_high_lane_f16(_sum2, _r21, _k0, 6); | |||
_sum2 = vmlaq_high_lane_f16(_sum2, _r22, _k0, 7); | |||
_sum2 = vmlaq_low_lane_f16(_sum2, _r23, _k1, 0); | |||
_sum2 = vmlaq_low_lane_f16(_sum2, _r24, _k1, 1); | |||
_sum2 = vmlaq_low_lane_f16(_sum2, _r30, _k1, 2); | |||
_sum2 = vmlaq_low_lane_f16(_sum2, _r31, _k1, 3); | |||
_sum2 = vmlaq_high_lane_f16(_sum2, _r32, _k1, 4); | |||
_sum2 = vmlaq_high_lane_f16(_sum2, _r33, _k1, 5); | |||
_sum2 = vmlaq_high_lane_f16(_sum2, _r34, _k1, 6); | |||
_sum2 = vmlaq_high_lane_f16(_sum2, _r40, _k1, 7); | |||
_sum2 = vmlaq_low_lane_f16(_sum2, _r41, _k2, 0); | |||
_sum2 = vmlaq_low_lane_f16(_sum2, _r42, _k2, 1); | |||
_sum2 = vmlaq_low_lane_f16(_sum2, _r43, _k2, 2); | |||
_sum2 = vmlaq_low_lane_f16(_sum2, _r44, _k2, 3); | |||
_sum2 = vmlaq_high_lane_f16(_sum2, _r50, _k2, 4); | |||
_sum2 = vmlaq_high_lane_f16(_sum2, _r51, _k2, 5); | |||
_sum2 = vmlaq_high_lane_f16(_sum2, _r52, _k2, 6); | |||
_sum2 = vmlaq_high_lane_f16(_sum2, _r53, _k2, 7); | |||
_sum2 = vmlaq_high_lane_f16(_sum2, _r54, _k3, 7); | |||
vst1q_f16(outptr, _sum); | |||
vst1q_f16(outptr2, _sum2); | |||
r0 += 8; | |||
r1 += 8; | |||
r2 += 8; | |||
r3 += 8; | |||
r4 += 8; | |||
r5 += 8; | |||
outptr += 8; | |||
outptr2 += 8; | |||
} | |||
r0 += tail_step + IW; | |||
r1 += tail_step + IW; | |||
r2 += tail_step + IW; | |||
r3 += tail_step + IW; | |||
r4 += tail_step + IW; | |||
r5 += tail_step + IW; | |||
outptr += OW; | |||
outptr2 += OW; | |||
} | |||
for (; h < OH; h++) { | |||
int width = OW >> 3; | |||
rep(i, width) { | |||
float16x8_t _sum = vld1q_f16(outptr); | |||
float16x8_t _r00 = vld1q_f16(r0); | |||
float16x8_t _r05 = vld1q_f16(r0 + 8); | |||
float16x8_t _r01 = vextq_f16(_r00, _r05, 1); | |||
float16x8_t _r02 = vextq_f16(_r00, _r05, 2); | |||
float16x8_t _r03 = vextq_f16(_r00, _r05, 3); | |||
float16x8_t _r04 = vextq_f16(_r00, _r05, 4); | |||
float16x8_t _r10 = vld1q_f16(r1); | |||
float16x8_t _r15 = vld1q_f16(r1 + 8); | |||
float16x8_t _r11 = vextq_f16(_r10, _r15, 1); | |||
float16x8_t _r12 = vextq_f16(_r10, _r15, 2); | |||
float16x8_t _r13 = vextq_f16(_r10, _r15, 3); | |||
float16x8_t _r14 = vextq_f16(_r10, _r15, 4); | |||
float16x8_t _r20 = vld1q_f16(r2); | |||
float16x8_t _r25 = vld1q_f16(r2 + 8); | |||
float16x8_t _r21 = vextq_f16(_r20, _r25, 1); | |||
float16x8_t _r22 = vextq_f16(_r20, _r25, 2); | |||
float16x8_t _r23 = vextq_f16(_r20, _r25, 3); | |||
float16x8_t _r24 = vextq_f16(_r20, _r25, 4); | |||
float16x8_t _r30 = vld1q_f16(r3); | |||
float16x8_t _r35 = vld1q_f16(r3 + 8); | |||
float16x8_t _r31 = vextq_f16(_r30, _r35, 1); | |||
float16x8_t _r32 = vextq_f16(_r30, _r35, 2); | |||
float16x8_t _r33 = vextq_f16(_r30, _r35, 3); | |||
float16x8_t _r34 = vextq_f16(_r30, _r35, 4); | |||
float16x8_t _r40 = vld1q_f16(r4); | |||
float16x8_t _r45 = vld1q_f16(r4 + 8); | |||
float16x8_t _r41 = vextq_f16(_r40, _r45, 1); | |||
float16x8_t _r42 = vextq_f16(_r40, _r45, 2); | |||
float16x8_t _r43 = vextq_f16(_r40, _r45, 3); | |||
float16x8_t _r44 = vextq_f16(_r40, _r45, 4); | |||
_sum = vmlaq_low_lane_f16(_sum, _r00, _k0, 0); | |||
_sum = vmlaq_low_lane_f16(_sum, _r01, _k0, 1); | |||
_sum = vmlaq_low_lane_f16(_sum, _r02, _k0, 2); | |||
_sum = vmlaq_low_lane_f16(_sum, _r03, _k0, 3); | |||
_sum = vmlaq_high_lane_f16(_sum, _r04, _k0, 4); | |||
_sum = vmlaq_high_lane_f16(_sum, _r10, _k0, 5); | |||
_sum = vmlaq_high_lane_f16(_sum, _r11, _k0, 6); | |||
_sum = vmlaq_high_lane_f16(_sum, _r12, _k0, 7); | |||
_sum = vmlaq_low_lane_f16(_sum, _r13, _k1, 0); | |||
_sum = vmlaq_low_lane_f16(_sum, _r14, _k1, 1); | |||
_sum = vmlaq_low_lane_f16(_sum, _r20, _k1, 2); | |||
_sum = vmlaq_low_lane_f16(_sum, _r21, _k1, 3); | |||
_sum = vmlaq_high_lane_f16(_sum, _r22, _k1, 4); | |||
_sum = vmlaq_high_lane_f16(_sum, _r23, _k1, 5); | |||
_sum = vmlaq_high_lane_f16(_sum, _r24, _k1, 6); | |||
_sum = vmlaq_high_lane_f16(_sum, _r30, _k1, 7); | |||
_sum = vmlaq_low_lane_f16(_sum, _r31, _k2, 0); | |||
_sum = vmlaq_low_lane_f16(_sum, _r32, _k2, 1); | |||
_sum = vmlaq_low_lane_f16(_sum, _r33, _k2, 2); | |||
_sum = vmlaq_low_lane_f16(_sum, _r34, _k2, 3); | |||
_sum = vmlaq_high_lane_f16(_sum, _r40, _k2, 4); | |||
_sum = vmlaq_high_lane_f16(_sum, _r41, _k2, 5); | |||
_sum = vmlaq_high_lane_f16(_sum, _r42, _k2, 6); | |||
_sum = vmlaq_high_lane_f16(_sum, _r43, _k2, 7); | |||
_sum = vmlaq_high_lane_f16(_sum, _r44, _k3, 7); | |||
vst1q_f16(outptr, _sum); | |||
r0 += 8; | |||
r1 += 8; | |||
r2 += 8; | |||
r3 += 8; | |||
r4 += 8; | |||
outptr += 8; | |||
} | |||
r0 += tail_step; | |||
r1 += tail_step; | |||
r2 += tail_step; | |||
r3 += tail_step; | |||
r4 += tail_step; | |||
} | |||
filter += 25; | |||
} | |||
} | |||
#endif | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,32 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/f16/do_conv_stride1.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "src/fallback/conv_bias/opr_impl.h" | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
namespace megdnn { | |||
namespace arm_common { | |||
namespace fp16 { | |||
namespace conv_stride1 { | |||
void do_conv_2x2_stride1(const __fp16* src, const __fp16* filter, __fp16* dst, | |||
size_t IH, size_t IW, size_t OH, size_t OW, size_t IC); | |||
void do_conv_3x3_stride1(const __fp16* src, const __fp16* filter, __fp16* dst, | |||
size_t IH, size_t IW, size_t OH, size_t OW, size_t IC); | |||
void do_conv_5x5_stride1(const __fp16* src, const __fp16* filter, __fp16* dst, | |||
size_t IH, size_t IW, size_t OH, size_t OW, size_t IC); | |||
} // namespace conv_stride1 | |||
} // namespace fp16 | |||
} // namespace arm_common | |||
} // namespace megdnn | |||
#endif | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,341 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/f16/helper.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "src/common/unroll_macro.h" | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
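// Computes the 4x4 fp16 matrix product sum = a * b: each lane of row m of | |||
// `a` is broadcast against the corresponding row of `b` with | |||
// vmul_lane_f16 and accumulated with vadd_f16. | |||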
#define MATRIX_MUL4x4_fp16(sum, a, b) \ | |||
sum##0 = vmul_lane_f16(b##0, a##0, 0); \ | |||
sum##1 = vmul_lane_f16(b##0, a##1, 0); \ | |||
sum##2 = vmul_lane_f16(b##0, a##2, 0); \ | |||
sum##3 = vmul_lane_f16(b##0, a##3, 0); \ | |||
sum##0 = vadd_f16(sum##0, vmul_lane_f16(b##1, a##0, 1)); \ | |||
sum##1 = vadd_f16(sum##1, vmul_lane_f16(b##1, a##1, 1)); \ | |||
sum##2 = vadd_f16(sum##2, vmul_lane_f16(b##1, a##2, 1)); \ | |||
sum##3 = vadd_f16(sum##3, vmul_lane_f16(b##1, a##3, 1)); \ | |||
sum##0 = vadd_f16(sum##0, vmul_lane_f16(b##2, a##0, 2)); \ | |||
sum##1 = vadd_f16(sum##1, vmul_lane_f16(b##2, a##1, 2)); \ | |||
sum##2 = vadd_f16(sum##2, vmul_lane_f16(b##2, a##2, 2)); \ | |||
sum##3 = vadd_f16(sum##3, vmul_lane_f16(b##2, a##3, 2)); \ | |||
sum##0 = vadd_f16(sum##0, vmul_lane_f16(b##3, a##0, 3)); \ | |||
sum##1 = vadd_f16(sum##1, vmul_lane_f16(b##3, a##1, 3)); \ | |||
sum##2 = vadd_f16(sum##2, vmul_lane_f16(b##3, a##2, 3)); \ | |||
sum##3 = vadd_f16(sum##3, vmul_lane_f16(b##3, a##3, 3)); | |||
#define CONCAT(a, id) a##id | |||
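// Two variants of the fp16 transpose macros follow: AArch64 provides | |||
// vzip1/vzip2 intrinsics returning a single vector, while the ARMv7 | |||
// fallback uses vzip, which returns a two-element .val pair. | |||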
#if MEGDNN_AARCH64 | |||
#define TRANSPOSE_4x4(a, ret) \ | |||
do { \ | |||
auto b00 = vzip1_f16(CONCAT(a, 0).value, \ | |||
CONCAT(a, 1).value); /*a1b1a2b2*/ \ | |||
auto b01 = vzip2_f16(CONCAT(a, 0).value, \ | |||
CONCAT(a, 1).value); /*a3b3a4b4*/ \ | |||
auto b10 = vzip1_f16(CONCAT(a, 2).value, \ | |||
CONCAT(a, 3).value); /*c1d1c2d2*/ \ | |||
auto b11 = vzip2_f16(CONCAT(a, 2).value, \ | |||
CONCAT(a, 3).value); /*c3d3c4d4*/ \ | |||
auto s32b00 = vreinterpret_s32_f16(b00); \ | |||
auto s32b01 = vreinterpret_s32_f16(b01); \ | |||
auto s32b10 = vreinterpret_s32_f16(b10); \ | |||
auto s32b11 = vreinterpret_s32_f16(b11); \ | |||
CONCAT(ret, 0).value = \ | |||
vreinterpret_f16_s32(vzip1_s32(s32b00, s32b10)); \ | |||
CONCAT(ret, 1).value = \ | |||
vreinterpret_f16_s32(vzip2_s32(s32b00, s32b10)); \ | |||
CONCAT(ret, 2).value = \ | |||
vreinterpret_f16_s32(vzip1_s32(s32b01, s32b11)); \ | |||
CONCAT(ret, 3).value = \ | |||
vreinterpret_f16_s32(vzip2_s32(s32b01, s32b11)); \ | |||
} while (0); | |||
#define TRANSPOSE_4x8(a, ret) \ | |||
do { \ | |||
auto b00 = vzip1q_f16(CONCAT(a, 0).value, \ | |||
CONCAT(a, 1).value); /*a1b1a2b2a3b3a4b4*/ \ | |||
auto b01 = vzip2q_f16(CONCAT(a, 0).value, \ | |||
CONCAT(a, 1).value); /*a5b5a6b6a7b7a8b8*/ \ | |||
auto b10 = vzip1q_f16(CONCAT(a, 2).value, \ | |||
CONCAT(a, 3).value); /*c1d1c2d2c3d3c4d4*/ \ | |||
auto b11 = vzip2q_f16(CONCAT(a, 2).value, \ | |||
CONCAT(a, 3).value); /*c5d5c6d6c7d7c8d8*/ \ | |||
auto s32b00 = vreinterpretq_s32_f16(b00); \ | |||
auto s32b01 = vreinterpretq_s32_f16(b01); \ | |||
auto s32b10 = vreinterpretq_s32_f16(b10); \ | |||
auto s32b11 = vreinterpretq_s32_f16(b11); \ | |||
auto f16b00 = vreinterpretq_f16_s32( \ | |||
vzip1q_s32(s32b00, s32b10)); /*a1b1c1d1a2b2c2d2*/ \ | |||
auto f16b01 = vreinterpretq_f16_s32( \ | |||
                vzip2q_s32(s32b00, s32b10)); /*a3b3c3d3a4b4c4d4*/           \ | |||
auto f16b10 = vreinterpretq_f16_s32( \ | |||
vzip1q_s32(s32b01, s32b11)); /*a5b5c5d5a6b6c6d6*/ \ | |||
auto f16b11 = vreinterpretq_f16_s32( \ | |||
vzip2q_s32(s32b01, s32b11)); /*a7b7c7d7a8b8c8d8*/ \ | |||
CONCAT(ret, 0).value = vget_low_f16(f16b00); \ | |||
CONCAT(ret, 1).value = vget_high_f16(f16b00); \ | |||
CONCAT(ret, 2).value = vget_low_f16(f16b01); \ | |||
CONCAT(ret, 3).value = vget_high_f16(f16b01); \ | |||
CONCAT(ret, 4).value = vget_low_f16(f16b10); \ | |||
CONCAT(ret, 5).value = vget_high_f16(f16b10); \ | |||
CONCAT(ret, 6).value = vget_low_f16(f16b11); \ | |||
CONCAT(ret, 7).value = vget_high_f16(f16b11); \ | |||
} while (0); | |||
#define TRANSPOSE_8x4(a, ret) \ | |||
do { \ | |||
auto b00 = vzip1_f16(CONCAT(a, 0).value, \ | |||
CONCAT(a, 1).value); /*a1b1a2b2*/ \ | |||
auto b01 = vzip2_f16(CONCAT(a, 0).value, \ | |||
CONCAT(a, 1).value); /*a3b3a4b4*/ \ | |||
auto b10 = vzip1_f16(CONCAT(a, 2).value, \ | |||
CONCAT(a, 3).value); /*c1d1c2d2*/ \ | |||
auto b11 = vzip2_f16(CONCAT(a, 2).value, \ | |||
CONCAT(a, 3).value); /*c3d3c4d4*/ \ | |||
auto b20 = vzip1_f16(CONCAT(a, 4).value, \ | |||
CONCAT(a, 5).value); /*e1f1e2f2*/ \ | |||
auto b21 = vzip2_f16(CONCAT(a, 4).value, \ | |||
CONCAT(a, 5).value); /*e3f3e4f4*/ \ | |||
auto b30 = vzip1_f16(CONCAT(a, 6).value, \ | |||
CONCAT(a, 7).value); /*g1h1g2h2*/ \ | |||
auto b31 = vzip2_f16(CONCAT(a, 6).value, \ | |||
CONCAT(a, 7).value); /*g3h3g4h4*/ \ | |||
auto s32b00 = vreinterpret_s32_f16(b00); \ | |||
auto s32b01 = vreinterpret_s32_f16(b01); \ | |||
auto s32b10 = vreinterpret_s32_f16(b10); \ | |||
auto s32b11 = vreinterpret_s32_f16(b11); \ | |||
auto s32b20 = vreinterpret_s32_f16(b20); \ | |||
auto s32b21 = vreinterpret_s32_f16(b21); \ | |||
auto s32b30 = vreinterpret_s32_f16(b30); \ | |||
auto s32b31 = vreinterpret_s32_f16(b31); \ | |||
CONCAT(ret, 0).value = \ | |||
vcombine_f16(vreinterpret_f16_s32(vzip1_s32(s32b00, s32b10)), \ | |||
vreinterpret_f16_s32(vzip1_s32(s32b20, s32b30))); \ | |||
CONCAT(ret, 1).value = \ | |||
vcombine_f16(vreinterpret_f16_s32(vzip2_s32(s32b00, s32b10)), \ | |||
vreinterpret_f16_s32(vzip2_s32(s32b20, s32b30))); \ | |||
CONCAT(ret, 2).value = \ | |||
vcombine_f16(vreinterpret_f16_s32(vzip1_s32(s32b01, s32b11)), \ | |||
vreinterpret_f16_s32(vzip1_s32(s32b21, s32b31))); \ | |||
CONCAT(ret, 3).value = \ | |||
vcombine_f16(vreinterpret_f16_s32(vzip2_s32(s32b01, s32b11)), \ | |||
vreinterpret_f16_s32(vzip2_s32(s32b21, s32b31))); \ | |||
} while (0); | |||
#define TRANSPOSE_8x8(a, ret) \ | |||
do { \ | |||
auto b00 = vzip1q_f16(CONCAT(a, 0).value, \ | |||
CONCAT(a, 1).value); /*a1b1a2b2 a3b3a4b4*/ \ | |||
auto b01 = vzip2q_f16(CONCAT(a, 0).value, \ | |||
CONCAT(a, 1).value); /*a5b5a6b6 a7b7a8b8*/ \ | |||
auto b10 = vzip1q_f16(CONCAT(a, 2).value, \ | |||
CONCAT(a, 3).value); /*c1d1c2d2 c3d3c4d4*/ \ | |||
auto b11 = vzip2q_f16(CONCAT(a, 2).value, \ | |||
CONCAT(a, 3).value); /*c5d5c6d6 c7d7c8d8*/ \ | |||
auto b20 = vzip1q_f16(CONCAT(a, 4).value, \ | |||
CONCAT(a, 5).value); /*e1f1e2f2 e3f3e4f4*/ \ | |||
auto b21 = vzip2q_f16(CONCAT(a, 4).value, \ | |||
CONCAT(a, 5).value); /*e5f5e6f6 e7f7e8f8*/ \ | |||
auto b30 = vzip1q_f16(CONCAT(a, 6).value, \ | |||
CONCAT(a, 7).value); /*g1h1g2h2 g3h3g4h4*/ \ | |||
auto b31 = vzip2q_f16(CONCAT(a, 6).value, \ | |||
CONCAT(a, 7).value); /*g5h5g6h6 g7h7g8h8*/ \ | |||
auto s32b00 = vreinterpretq_s32_f16(b00); \ | |||
auto s32b01 = vreinterpretq_s32_f16(b01); \ | |||
auto s32b10 = vreinterpretq_s32_f16(b10); \ | |||
auto s32b11 = vreinterpretq_s32_f16(b11); \ | |||
auto s32b20 = vreinterpretq_s32_f16(b20); \ | |||
auto s32b21 = vreinterpretq_s32_f16(b21); \ | |||
auto s32b30 = vreinterpretq_s32_f16(b30); \ | |||
auto s32b31 = vreinterpretq_s32_f16(b31); \ | |||
auto s64b00 = vreinterpretq_s64_s32( \ | |||
vzip1q_s32(s32b00, s32b10)); /*a1b1c1d1 a2b2c2d2*/ \ | |||
auto s64b01 = vreinterpretq_s64_s32( \ | |||
vzip2q_s32(s32b00, s32b10)); /*a3b3c3d3 a4b4c4d4*/ \ | |||
auto s64b10 = vreinterpretq_s64_s32( \ | |||
vzip1q_s32(s32b01, s32b11)); /*a5b5c5d5 a6b6c6d6*/ \ | |||
auto s64b11 = vreinterpretq_s64_s32( \ | |||
vzip2q_s32(s32b01, s32b11)); /*a7b7c7d7 a8b8c8d8*/ \ | |||
auto s64b20 = vreinterpretq_s64_s32( \ | |||
vzip1q_s32(s32b20, s32b30)); /*e1f1g1h1 e2f2g2h2*/ \ | |||
auto s64b21 = vreinterpretq_s64_s32( \ | |||
vzip2q_s32(s32b20, s32b30)); /*e3f3g3h3 e4f4g4h4*/ \ | |||
auto s64b30 = vreinterpretq_s64_s32( \ | |||
vzip1q_s32(s32b21, s32b31)); /*e5f5g5h5 e6f6g6h6*/ \ | |||
auto s64b31 = vreinterpretq_s64_s32( \ | |||
vzip2q_s32(s32b21, s32b31)); /*e7f7g7h7 e8f8g8h8*/ \ | |||
CONCAT(ret, 0).value = \ | |||
vreinterpretq_f16_s64(vzip1q_s64(s64b00, s64b20)); \ | |||
CONCAT(ret, 1).value = \ | |||
vreinterpretq_f16_s64(vzip2q_s64(s64b00, s64b20)); \ | |||
CONCAT(ret, 2).value = \ | |||
vreinterpretq_f16_s64(vzip1q_s64(s64b01, s64b21)); \ | |||
CONCAT(ret, 3).value = \ | |||
vreinterpretq_f16_s64(vzip2q_s64(s64b01, s64b21)); \ | |||
CONCAT(ret, 4).value = \ | |||
vreinterpretq_f16_s64(vzip1q_s64(s64b10, s64b30)); \ | |||
CONCAT(ret, 5).value = \ | |||
vreinterpretq_f16_s64(vzip2q_s64(s64b10, s64b30)); \ | |||
CONCAT(ret, 6).value = \ | |||
vreinterpretq_f16_s64(vzip1q_s64(s64b11, s64b31)); \ | |||
CONCAT(ret, 7).value = \ | |||
vreinterpretq_f16_s64(vzip2q_s64(s64b11, s64b31)); \ | |||
} while (0); | |||
#else | |||
#define TRANSPOSE_4x4(a, ret) \ | |||
do { \ | |||
auto b0_01 = vzip_f16(CONCAT(a, 0).value, \ | |||
CONCAT(a, 1).value); /*a1b1a2b2 a3b3a4b4*/ \ | |||
auto b1_01 = vzip_f16(CONCAT(a, 2).value, \ | |||
CONCAT(a, 3).value); /*c1d1c2d2 c3d3c4d4*/ \ | |||
auto s32b00 = vreinterpret_s32_f16(b0_01.val[0]); \ | |||
auto s32b01 = vreinterpret_s32_f16(b0_01.val[1]); \ | |||
auto s32b10 = vreinterpret_s32_f16(b1_01.val[0]); \ | |||
auto s32b11 = vreinterpret_s32_f16(b1_01.val[1]); \ | |||
auto s32b00b10 = vzip_s32(s32b00, s32b10); /*a1b1c1d1 a2b2c2d2*/ \ | |||
auto s32b01b11 = vzip_s32(s32b01, s32b11); /*a3b3c3d3 a4b4c4d4*/ \ | |||
CONCAT(ret, 0).value = vreinterpret_f16_s32(s32b00b10.val[0]); \ | |||
CONCAT(ret, 1).value = vreinterpret_f16_s32(s32b00b10.val[1]); \ | |||
CONCAT(ret, 2).value = vreinterpret_f16_s32(s32b01b11.val[0]); \ | |||
CONCAT(ret, 3).value = vreinterpret_f16_s32(s32b01b11.val[1]); \ | |||
} while (0); | |||
#define TRANSPOSE_4x8(a, ret) \ | |||
do { \ | |||
auto b0_01 = vzipq_f16( \ | |||
CONCAT(a, 0).value, \ | |||
CONCAT(a, 1).value); /*a1b1a2b2a3b3a4b4 a5b5a6b6a7b7a8b8*/ \ | |||
auto b1_01 = vzipq_f16( \ | |||
CONCAT(a, 2).value, \ | |||
                CONCAT(a, 3).value); /*c1d1c2d2c3d3c4d4 c5d5c6d6c7d7c8d8*/  \ | |||
auto s32b00 = vreinterpretq_s32_f16(b0_01.val[0]); \ | |||
auto s32b01 = vreinterpretq_s32_f16(b0_01.val[1]); \ | |||
auto s32b10 = vreinterpretq_s32_f16(b1_01.val[0]); \ | |||
auto s32b11 = vreinterpretq_s32_f16(b1_01.val[1]); \ | |||
auto s32b00b10 = vzipq_s32( \ | |||
s32b00, s32b10); /*a1b1c1d1a2b2c2d2 a3b3c3d3a4b4c4d4*/ \ | |||
auto s32b01b11 = vzipq_s32( \ | |||
s32b01, s32b11); /*a5b5c5d5a6b6c6d6 a7b7c7d7a8b8c8d8*/ \ | |||
        CONCAT(ret, 0).value =                                            \ | |||
                vreinterpret_f16_s32(vget_low_s32(s32b00b10.val[0]));     \ | |||
        CONCAT(ret, 1).value =                                            \ | |||
                vreinterpret_f16_s32(vget_high_s32(s32b00b10.val[0]));    \ | |||
        CONCAT(ret, 2).value =                                            \ | |||
                vreinterpret_f16_s32(vget_low_s32(s32b00b10.val[1]));     \ | |||
        CONCAT(ret, 3).value =                                            \ | |||
                vreinterpret_f16_s32(vget_high_s32(s32b00b10.val[1]));    \ | |||
        CONCAT(ret, 4).value =                                            \ | |||
                vreinterpret_f16_s32(vget_low_s32(s32b01b11.val[0]));     \ | |||
        CONCAT(ret, 5).value =                                            \ | |||
                vreinterpret_f16_s32(vget_high_s32(s32b01b11.val[0]));    \ | |||
        CONCAT(ret, 6).value =                                            \ | |||
                vreinterpret_f16_s32(vget_low_s32(s32b01b11.val[1]));     \ | |||
        CONCAT(ret, 7).value =                                            \ | |||
                vreinterpret_f16_s32(vget_high_s32(s32b01b11.val[1]));    \ | |||
} while (0); | |||
#define TRANSPOSE_8x4(a, ret) \ | |||
do { \ | |||
auto b0_01 = vzip_f16(CONCAT(a, 0).value, \ | |||
CONCAT(a, 1).value); /*a1b1a2b2 a3b3a4b4*/ \ | |||
auto b1_01 = vzip_f16(CONCAT(a, 2).value, \ | |||
CONCAT(a, 3).value); /*c1d1c2d2 c3d3c4d4*/ \ | |||
auto b2_01 = vzip_f16(CONCAT(a, 4).value, \ | |||
CONCAT(a, 5).value); /*e1f1e2f2 e3f3e4f4*/ \ | |||
auto b3_01 = vzip_f16(CONCAT(a, 6).value, \ | |||
CONCAT(a, 7).value); /*g1h1g2h2 g3h3g4h4*/ \ | |||
auto s32b00 = vreinterpret_s32_f16(b0_01.val[0]); \ | |||
auto s32b01 = vreinterpret_s32_f16(b0_01.val[1]); \ | |||
auto s32b10 = vreinterpret_s32_f16(b1_01.val[0]); \ | |||
auto s32b11 = vreinterpret_s32_f16(b1_01.val[1]); \ | |||
auto s32b20 = vreinterpret_s32_f16(b2_01.val[0]); \ | |||
auto s32b21 = vreinterpret_s32_f16(b2_01.val[1]); \ | |||
auto s32b30 = vreinterpret_s32_f16(b3_01.val[0]); \ | |||
auto s32b31 = vreinterpret_s32_f16(b3_01.val[1]); \ | |||
auto s32b00b10 = vzip_s32(s32b00, s32b10); \ | |||
auto s32b01b11 = vzip_s32(s32b01, s32b11); \ | |||
auto s32b20b30 = vzip_s32(s32b20, s32b30); \ | |||
auto s32b21b31 = vzip_s32(s32b21, s32b31); \ | |||
CONCAT(ret, 0).value = \ | |||
vcombine_f16(vreinterpret_f16_s32(s32b00b10.val[0]), \ | |||
vreinterpret_f16_s32(s32b20b30.val[0])); \ | |||
CONCAT(ret, 1).value = \ | |||
vcombine_f16(vreinterpret_f16_s32(s32b00b10.val[1]), \ | |||
vreinterpret_f16_s32(s32b20b30.val[1])); \ | |||
CONCAT(ret, 2).value = \ | |||
vcombine_f16(vreinterpret_f16_s32(s32b01b11.val[0]), \ | |||
vreinterpret_f16_s32(s32b21b31.val[0])); \ | |||
CONCAT(ret, 3).value = \ | |||
vcombine_f16(vreinterpret_f16_s32(s32b01b11.val[1]), \ | |||
vreinterpret_f16_s32(s32b21b31.val[1])); \ | |||
} while (0); | |||
#define TRANSPOSE_8x8(a, ret) \ | |||
do { \ | |||
auto b00 = vzipq_f16(CONCAT(a, 0).value, \ | |||
CONCAT(a, 1).value); /*a1b1a2b2 a3b3a4b4*/ \ | |||
auto b01 = vzipq_f16(CONCAT(a, 0).value, \ | |||
CONCAT(a, 1).value); /*a5b5a6b6 a7b7a8b8*/ \ | |||
auto b10 = vzipq_f16(CONCAT(a, 2).value, \ | |||
CONCAT(a, 3).value); /*c1d1c2d2 c3d3c4d4*/ \ | |||
auto b11 = vzipq_f16(CONCAT(a, 2).value, \ | |||
CONCAT(a, 3).value); /*c5d5c6d6 c7d7c8d8*/ \ | |||
auto b20 = vzipq_f16(CONCAT(a, 4).value, \ | |||
CONCAT(a, 5).value); /*e1f1e2f2 e3f3e4f4*/ \ | |||
auto b21 = vzipq_f16(CONCAT(a, 4).value, \ | |||
CONCAT(a, 5).value); /*e5f5e6f6 e7f7e8f8*/ \ | |||
auto b30 = vzipq_f16(CONCAT(a, 6).value, \ | |||
CONCAT(a, 7).value); /*g1h1g2h2 g3h3g4h4*/ \ | |||
auto b31 = vzipq_f16(CONCAT(a, 6).value, \ | |||
CONCAT(a, 7).value); /*g5h5g6h6 g7h7g8h8*/ \ | |||
auto s32b00 = vreinterpretq_s32_f16(b00.val[0]); \ | |||
auto s32b01 = vreinterpretq_s32_f16(b01.val[1]); \ | |||
auto s32b10 = vreinterpretq_s32_f16(b10.val[0]); \ | |||
auto s32b11 = vreinterpretq_s32_f16(b11.val[1]); \ | |||
auto s32b20 = vreinterpretq_s32_f16(b20.val[0]); \ | |||
auto s32b21 = vreinterpretq_s32_f16(b21.val[1]); \ | |||
auto s32b30 = vreinterpretq_s32_f16(b30.val[0]); \ | |||
auto s32b31 = vreinterpretq_s32_f16(b31.val[1]); \ | |||
auto s32b00b10 = vzipq_s32(s32b00, s32b10); \ | |||
auto s32b01b11 = vzipq_s32(s32b01, s32b11); \ | |||
auto s32b20b30 = vzipq_s32(s32b20, s32b30); \ | |||
auto s32b21b31 = vzipq_s32(s32b21, s32b31); \ | |||
CONCAT(ret, 0).value = vreinterpretq_f16_s32( \ | |||
vcombine_s32(vget_low_s32(s32b00b10.val[0]), \ | |||
vget_low_s32(s32b20b30.val[0]))); \ | |||
CONCAT(ret, 1).value = vreinterpretq_f16_s32( \ | |||
vcombine_s32(vget_high_s32(s32b00b10.val[0]), \ | |||
vget_high_s32(s32b20b30.val[0]))); \ | |||
CONCAT(ret, 2).value = vreinterpretq_f16_s32( \ | |||
vcombine_s32(vget_low_s32(s32b00b10.val[1]), \ | |||
vget_low_s32(s32b20b30.val[1]))); \ | |||
CONCAT(ret, 3).value = vreinterpretq_f16_s32( \ | |||
vcombine_s32(vget_high_s32(s32b00b10.val[1]), \ | |||
vget_high_s32(s32b20b30.val[1]))); \ | |||
CONCAT(ret, 4).value = vreinterpretq_f16_s32( \ | |||
vcombine_s32(vget_low_s32(s32b01b11.val[0]), \ | |||
vget_low_s32(s32b21b31.val[0]))); \ | |||
CONCAT(ret, 5).value = vreinterpretq_f16_s32( \ | |||
vcombine_s32(vget_high_s32(s32b01b11.val[0]), \ | |||
vget_high_s32(s32b21b31.val[0]))); \ | |||
CONCAT(ret, 6).value = vreinterpretq_f16_s32( \ | |||
vcombine_s32(vget_low_s32(s32b01b11.val[1]), \ | |||
vget_low_s32(s32b21b31.val[1]))); \ | |||
CONCAT(ret, 7).value = vreinterpretq_f16_s32( \ | |||
vcombine_s32(vget_high_s32(s32b01b11.val[1]), \ | |||
vget_high_s32(s32b21b31.val[1]))); \ | |||
} while (0); | |||
#endif | |||
#endif | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,33 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/f16/strategy.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "src/arm_common/conv_bias/postprocess_helper.h" | |||
#include "src/fallback/conv_bias/winograd/winograd.h" | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
namespace megdnn { | |||
namespace arm_common { | |||
namespace winograd { | |||
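// fp16 Winograd strategy registrations. Reading the names as F(m, r) with | |||
// a pack-size suffix, winograd_2x3_4x4_f16 is F(2x2, 3x3) with channels | |||
// packed by 4; the numeric macro arguments appear to encode the output | |||
// block size, kernel size and the two pack sizes. | |||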
MEGDNN_REG_WINOGRAD_STRATEGY(dt_float16, dt_float16, dt_float16, dt_float16, 2, | |||
3, 4, 4, winograd_2x3_4x4_f16) | |||
MEGDNN_REG_WINOGRAD_STRATEGY(dt_float16, dt_float16, dt_float16, dt_float16, 4, | |||
5, 1, 1, winograd_4x5_1x1_f16) | |||
MEGDNN_REG_WINOGRAD_STRATEGY(dt_float16, dt_float16, dt_float16, dt_float16, 6, | |||
3, 1, 1, winograd_6x3_1x1_f16) | |||
MEGDNN_REG_WINOGRAD_STRATEGY(dt_float16, dt_float16, dt_float16, dt_float16, 2, | |||
3, 8, 8, winograd_2x3_8x8_f16) | |||
} // namespace winograd | |||
} // namespace arm_common | |||
} // namespace megdnn | |||
#endif | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,373 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/f16/strategy_2x3.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/f16/strategy.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/arm_common/utils.h" | |||
#include "src/common/unroll_macro.h" | |||
#include "src/common/utils.h" | |||
#include "src/fallback/conv_bias/winograd/winograd.h" | |||
#include "src/arm_common/conv_bias/f16/helper.h" | |||
#include "src/arm_common/elemwise_helper/op_unary.h" | |||
#include "src/naive/matrix_mul/matrix_mul_helper.h" | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
#include "midout.h" | |||
MIDOUT_DECL(megdnn_arm_common_winograd_fp16_F23) | |||
using namespace megdnn; | |||
using namespace arm_common; | |||
namespace { | |||
void transpose_4x4(const __fp16* src, __fp16* dst, int lda, int ldb) { | |||
float16x4x2_t a0, a1; | |||
a0.val[0] = vld1_f16(src + 0 * lda); // a0a1a2a3 | |||
a0.val[1] = vld1_f16(src + 1 * lda); // b0b1b2b3 | |||
a1.val[0] = vld1_f16(src + 2 * lda); // c0c1c2c3 | |||
a1.val[1] = vld1_f16(src + 3 * lda); // d0d1d2d3 | |||
float16x4x2_t b0 = vzip_f16(a0.val[0], a1.val[0]); // a0c0a1c1a2c2a3c3 | |||
float16x4x2_t b1 = vzip_f16(a0.val[1], a1.val[1]); // b0d0b1d1b2d2b3d3 | |||
float16x4x2_t c0 = vzip_f16(b0.val[0], b1.val[0]); // a0b0c0d0a1b1c1d1 | |||
float16x4x2_t c1 = vzip_f16(b0.val[1], b1.val[1]); // a2b2c2d2a3b3c3d3 | |||
vst1_f16(dst + 0 * ldb, c0.val[0]); | |||
vst1_f16(dst + 1 * ldb, c0.val[1]); | |||
vst1_f16(dst + 2 * ldb, c1.val[0]); | |||
vst1_f16(dst + 3 * ldb, c1.val[1]); | |||
} | |||
struct InputTransform2X3 { | |||
template <bool inner> | |||
static void prepare(const __fp16* input, __fp16* patch, __fp16* patchT, | |||
int ih_start, int iw_start, size_t IH, size_t IW, | |||
size_t ic, size_t IC) { | |||
constexpr size_t alpha = 2 + 3 - 1; | |||
constexpr size_t alpha4 = alpha * 4; | |||
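        // Gathers an alpha x alpha input patch for 4 input channels into | |||
        // `patch`, zero-filling the border when the tile is not fully | |||
        // inside the image, then transposes it into `patchT` so that the | |||
        // 4 channels become the contiguous innermost dimension for the | |||
        // vectorized transform. | |||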
if (!(inner && ic + 4 < IC)) { | |||
memset(patch, 0, sizeof(__fp16) * 4 * alpha * alpha); | |||
} | |||
if (inner) { | |||
const __fp16* input_ptr = | |||
input + ic * IH * IW + ih_start * IW + iw_start; | |||
for (size_t ico = 0; ico < 4; ++ico) { | |||
if (ic + ico < IC) { | |||
auto v0 = vld1_f16(input_ptr); | |||
auto v1 = vld1_f16(input_ptr + IW); | |||
auto v2 = vld1_f16(input_ptr + IW * 2); | |||
auto v3 = vld1_f16(input_ptr + IW * 3); | |||
vst1_f16(patch + ico * alpha4 + 0 * 4, v0); | |||
vst1_f16(patch + ico * alpha4 + 1 * 4, v1); | |||
vst1_f16(patch + ico * alpha4 + 2 * 4, v2); | |||
vst1_f16(patch + ico * alpha4 + 3 * 4, v3); | |||
input_ptr += IH * IW; | |||
} | |||
} | |||
} else { | |||
int ih0_act = std::max<int>(ih_start, 0), | |||
ih1_act = std::min<int>(ih_start + alpha, IH), | |||
iw0_act = std::max<int>(iw_start, 0), | |||
iw1_act = std::min<int>(iw_start + alpha, IW); | |||
// partial copy | |||
for (size_t ico = 0; ico < 4; ++ico) { | |||
if (ic + ico < IC) { | |||
for (int ih = ih0_act; ih < ih1_act; ++ih) { | |||
for (int iw = iw0_act; iw < iw1_act; ++iw) { | |||
size_t iho = ih - ih_start, iwo = iw - iw_start; | |||
patch[ico * alpha4 + iho * 4 + iwo] = | |||
input[(ic + ico) * IH * IW + ih * IW + iw]; | |||
} | |||
} | |||
} | |||
} | |||
} | |||
transpose_4x4(patch + 0 * 1, patchT + 0 * 4, 16, 4); | |||
transpose_4x4(patch + 4 * 1, patchT + 4 * 4, 16, 4); | |||
transpose_4x4(patch + 8 * 1, patchT + 8 * 4, 16, 4); | |||
transpose_4x4(patch + 12 * 1, patchT + 12 * 4, 16, 4); | |||
} | |||
static void transform(const __fp16* patchT, __fp16* input_transform_buf, | |||
size_t unit_idx, size_t nr_units_in_tile, size_t ic, | |||
size_t IC) { | |||
constexpr size_t alpha = 2 + 3 - 1; | |||
// BT * d * B | |||
#define cb(m, n) \ | |||
Vector<__fp16, 4> d##m##n = \ | |||
Vector<__fp16, 4>::load(patchT + m * 4 * 4 + n * 4); | |||
UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); | |||
#undef cb | |||
//! 1 0 -1 0 d00 d01 d02 d03 1 0 0 0 | |||
//! 0 1 1 0 d10 d11 d12 d13 0 1 -1 -1 | |||
//! 0 -1 1 0 d20 d21 d22 d23 -1 1 1 0 | |||
//! 0 -1 0 1 d30 d31 d32 d33 0 0 0 1 | |||
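    // BT * d * B is evaluated as two passes of the same butterfly: the | |||
    // first block applies BT to the rows of d, the second applies B to | |||
    // the columns of the intermediate t. | |||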
#define cb(m) \ | |||
auto t0##m = d0##m - d2##m; \ | |||
auto t1##m = d1##m + d2##m; \ | |||
auto t2##m = d2##m - d1##m; \ | |||
auto t3##m = d3##m - d1##m; | |||
UNROLL_CALL_NOWRAPPER(4, cb); | |||
#undef cb | |||
#define cb(m) \ | |||
d##m##0 = t##m##0 - t##m##2; \ | |||
d##m##1 = t##m##1 + t##m##2; \ | |||
d##m##2 = t##m##2 - t##m##1; \ | |||
d##m##3 = t##m##3 - t##m##1; | |||
UNROLL_CALL_NOWRAPPER(4, cb); | |||
#undef cb | |||
#define cb(m, n) \ | |||
d##m##n.save(input_transform_buf + \ | |||
(m * alpha + n) * nr_units_in_tile * IC + unit_idx * IC + \ | |||
ic); | |||
UNROLL_CALL_NOWRAPPER_D2(4, 4, cb) | |||
#undef cb | |||
} | |||
}; | |||
template <BiasMode bmode, typename Op> | |||
struct OutputTransform2X3 { | |||
static void transform(const dt_float16* output_transform_buf, | |||
const dt_float16* bias, dt_float16* output, | |||
dt_float16* transform_mid_buf, size_t oh_start, | |||
size_t ow_start, size_t OH, size_t OW, | |||
size_t oc_start, size_t oc_end, size_t oc_index, | |||
size_t unit_idx, size_t nr_units_in_tile, | |||
const DType& src_dtype, const DType& dst_dtype) { | |||
Op op(src_dtype, dst_dtype); | |||
const __fp16* output_transform_ptr = | |||
reinterpret_cast<const __fp16*>(output_transform_buf); | |||
const __fp16* bias_ptr = reinterpret_cast<const __fp16*>(bias); | |||
__fp16* output_ptr = reinterpret_cast<__fp16*>(output); | |||
__fp16* transform_mid_ptr = | |||
reinterpret_cast<__fp16*>(transform_mid_buf); | |||
//! AT * m * A | |||
constexpr size_t alpha = 2 + 3 - 1; | |||
size_t OC = oc_end - oc_start; | |||
size_t oc = oc_start + oc_index; | |||
#define cb(m, n) \ | |||
auto v##m##n = Vector<__fp16, 4>::load( \ | |||
output_transform_ptr + (m * alpha + n) * nr_units_in_tile * OC + \ | |||
unit_idx * OC + oc_index); | |||
UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); | |||
#undef cb | |||
//! 1 1 1 0 v00 v01 v02 v03 1 0 | |||
//! 0 1 -1 1 v10 v11 v12 v13 1 1 | |||
//! v20 v21 v22 v23 1 -1 | |||
//! v30 v31 v32 v33 0 1 | |||
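        // AT * m * A reduces the 4x4 transformed tile to the 2x2 output | |||
        // block; only v00, v01, v10 and v11 remain live afterwards. | |||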
#define cb(m) \ | |||
auto t0##m = v0##m + v1##m + v2##m; \ | |||
auto t1##m = v1##m - v2##m + v3##m; | |||
UNROLL_CALL_NOWRAPPER(4, cb); | |||
#undef cb | |||
v00 = t00 + t01 + t02; | |||
v10 = t10 + t11 + t12; | |||
v01 = t01 - t02 + t03; | |||
v11 = t11 - t12 + t13; | |||
Vector<__fp16, 4> vbias; | |||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
vbias = Vector<__fp16, 4>::load(bias_ptr + oc); | |||
v00 += vbias; | |||
v10 += vbias; | |||
v01 += vbias; | |||
v11 += vbias; | |||
} | |||
float16x8_t result01, result23; | |||
result01 = vcombine_f16(v00.value, v01.value); | |||
result23 = vcombine_f16(v10.value, v11.value); | |||
if (bmode != BiasMode::BIAS) { | |||
result01 = op(result01); | |||
result23 = op(result23); | |||
} | |||
vst1q_f16(transform_mid_ptr, result01); | |||
vst1q_f16(transform_mid_ptr + 8, result23); | |||
for (size_t oco = 0; oco < 4 && oc + oco < oc_end; ++oco) { | |||
for (size_t oho = 0; oho < 2 && oh_start + oho < OH; ++oho) { | |||
for (size_t owo = 0; owo < 2 && ow_start + owo < OW; ++owo) { | |||
size_t oh = oh_start + oho; | |||
size_t ow = ow_start + owo; | |||
__fp16 res = transform_mid_ptr[oho * 2 * 4 + owo * 4 + oco]; | |||
if (bmode == BiasMode::BIAS) { | |||
res += bias_ptr[(oc + oco) * OH * OW + oh * OW + ow]; | |||
res = op(res); | |||
} | |||
output_ptr[(oc + oco) * OH * OW + oh * OW + ow] = res; | |||
} | |||
} | |||
} | |||
} | |||
}; | |||
} // namespace | |||
namespace megdnn { | |||
namespace arm_common { | |||
namespace winograd { | |||
MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_2x3_4x4_f16) | |||
void winograd_2x3_4x4_f16::filter(const dt_float16* filter, | |||
dt_float16* filter_transform_buf, | |||
dt_float16* transform_mid_buf, size_t OC, | |||
size_t IC, size_t oc_start, size_t oc_end) { | |||
constexpr int alpha = 2 + 3 - 1; | |||
//! G * g * GT | |||
__fp16* filter_transbuf_ptr = | |||
reinterpret_cast<__fp16*>(filter_transform_buf); | |||
__fp16* filter_transmid_ptr = reinterpret_cast<__fp16*>(transform_mid_buf); | |||
for (size_t oc = oc_start; oc < oc_end; oc++) { | |||
rep(ic, IC) { | |||
const __fp16* filter_ptr = reinterpret_cast<const __fp16*>(filter) + | |||
(oc * IC + ic) * 3 * 3; | |||
/** | |||
* origin: (4x3) * (3 x 3) * (3 x 4) | |||
* pack to G and g to times of 4 | |||
* now: (4x4) * (4 x 4) * (4 x 4) | |||
*/ | |||
//! 1 0 0 0 v00 v01 v02 0 1 0.5 0.5 0 | |||
//! 0.5 0.5 0.5 0 v10 v11 v12 0 0 0.5 -0.5 0 | |||
//! 0.5 -0.5 0.5 0 v20 v21 v22 0 0 0.5 0.5 1 | |||
//! 0 0 1 0 0 0 0 0 0 0 0 0 | |||
float16x4_t v0 = vld1_f16(filter_ptr); // 0 1 2 3 | |||
float16x4_t v1 = vld1_f16(filter_ptr + 3); // 3 4 5 6 | |||
            float16x4_t v2 = vld1_f16(filter_ptr + 5);  // 5 6 7 8 | |||
float16x4_t v3 = vdup_n_f16(0); | |||
v2 = vext_f16(v2, v3, 1); | |||
v0 = vset_lane_f16(0, v0, 3); | |||
v1 = vset_lane_f16(0, v1, 3); | |||
#define cb(i) float16x4_t vsum##i; | |||
UNROLL_CALL_NOWRAPPER(4, cb); | |||
#undef cb | |||
vsum0 = v0; | |||
float16x4_t v0addv2 = vadd_f16(v0, v2); | |||
float16x4_t v02addv1 = vadd_f16(v0addv2, v1); | |||
float16x4_t v02subv1 = vsub_f16(v0addv2, v1); | |||
vsum1 = vmul_n_f16(v02addv1, 0.5); | |||
vsum2 = vmul_n_f16(v02subv1, 0.5); | |||
vsum3 = v2; | |||
#define cb(i) \ | |||
do { \ | |||
mid_buf1[0] = vget_lane_f16(vsum##i, 0); \ | |||
__fp16 a0a2 = vget_lane_f16(vsum##i, 0) + vget_lane_f16(vsum##i, 2); \ | |||
__fp16 a0a2adda1 = a0a2 + vget_lane_f16(vsum##i, 1); \ | |||
__fp16 a0a2suba1 = a0a2 - vget_lane_f16(vsum##i, 1); \ | |||
mid_buf1[1] = a0a2adda1 * 0.5; \ | |||
mid_buf1[2] = a0a2suba1 * 0.5; \ | |||
mid_buf1[3] = vget_lane_f16(vsum##i, 2); \ | |||
mid_buf1 += 4; \ | |||
} while (0); | |||
__fp16* mid_buf1 = filter_transmid_ptr; | |||
UNROLL_CALL_NOWRAPPER(4, cb); | |||
mid_buf1 = filter_transmid_ptr; | |||
#undef cb | |||
rep(i, alpha) rep(j, alpha) { | |||
filter_transbuf_ptr[(i * alpha + j) * OC * IC + ic * OC + oc] = | |||
filter_transmid_ptr[i * alpha + j]; | |||
} | |||
} | |||
} | |||
} | |||
void winograd_2x3_4x4_f16::input(const dt_float16* input, | |||
dt_float16* input_transform_buf, | |||
dt_float16* transform_mid_buf, size_t IH, | |||
size_t IW, size_t IC, size_t PH, size_t PW, | |||
size_t unit_start_idx, | |||
size_t nr_units_in_tile) { | |||
megdnn_assert(IC % 4 == 0); | |||
constexpr int alpha = 3 + 2 - 1; | |||
// OW = IW + 2 * PW - KERNEL_SIZE + 1 | |||
auto units_w = div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); | |||
dt_float16* patch = transform_mid_buf; | |||
dt_float16* patchT = transform_mid_buf + 4 * alpha * alpha; | |||
for (size_t ic = 0; ic < IC; ic += 4) { | |||
rep(unit_idx, nr_units_in_tile) { | |||
size_t index = unit_start_idx + unit_idx; | |||
size_t nh = index / units_w; | |||
size_t nw = index % units_w; | |||
int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; | |||
int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; | |||
if (ih_start >= 0 && ih_start + alpha <= static_cast<int>(IH) && | |||
iw_start >= 0 && iw_start + alpha <= static_cast<int>(IW)) { | |||
InputTransform2X3::prepare<true>( | |||
reinterpret_cast<const __fp16*>(input), | |||
reinterpret_cast<__fp16*>(patch), | |||
reinterpret_cast<__fp16*>(patchT), ih_start, iw_start, | |||
IH, IW, ic, IC); | |||
InputTransform2X3::transform( | |||
reinterpret_cast<const __fp16*>(patchT), | |||
reinterpret_cast<__fp16*>(input_transform_buf), | |||
unit_idx, nr_units_in_tile, ic, IC); | |||
} else { | |||
InputTransform2X3::prepare<false>( | |||
reinterpret_cast<const __fp16*>(input), | |||
reinterpret_cast<__fp16*>(patch), | |||
reinterpret_cast<__fp16*>(patchT), ih_start, iw_start, | |||
IH, IW, ic, IC); | |||
InputTransform2X3::transform( | |||
reinterpret_cast<const __fp16*>(patchT), | |||
reinterpret_cast<__fp16*>(input_transform_buf), | |||
unit_idx, nr_units_in_tile, ic, IC); | |||
} | |||
} | |||
} | |||
} | |||
void winograd_2x3_4x4_f16::output(const dt_float16* output_transform_buf, | |||
const dt_float16* bias, dt_float16* output, | |||
dt_float16* transform_mid_buf, BiasMode bmode, | |||
NonlineMode nonline_mode, size_t OH, size_t OW, | |||
size_t oc_start, size_t oc_end, | |||
size_t unit_start_idx, size_t nr_units_in_tile) { | |||
#define cb(_bmode, _nonline_op, ...) \ | |||
OutputTransform2X3<_bmode MEGDNN_COMMA _nonline_op>::transform(__VA_ARGS__); | |||
auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE); | |||
for (size_t oc = oc_start; oc < oc_end; oc += 4) { | |||
size_t oc_index = oc - oc_start; | |||
rep(unit_idx, nr_units_in_tile) { | |||
size_t index = unit_start_idx + unit_idx; | |||
auto nh = index / units_w; | |||
auto nw = index % units_w; | |||
size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | |||
size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | |||
DISPATCH_CONV_WINOGRAD_BIAS( | |||
megdnn_arm_common_winograd_fp16_F23, cb, __fp16, __fp16, bmode, | |||
nonline_mode, output_transform_buf, bias, output, transform_mid_buf, | |||
oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx, | |||
nr_units_in_tile, src_dtype, dst_dtype); | |||
} | |||
} | |||
#undef cb | |||
} | |||
} // namespace winograd | |||
} // namespace arm_common | |||
} // namespace megdnn | |||
#endif | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,407 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/f16/strategy_2x3_8x8.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
#include "src/fallback/conv_bias/winograd/winograd.h" | |||
#include "src/naive/matrix_mul/matrix_mul_helper.h" | |||
#include "src/arm_common/elemwise_helper/op_unary.h" | |||
#include "src/arm_common/conv_bias/f16/strategy.h" | |||
#include "src/arm_common/conv_bias/f16/helper.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/arm_common/utils.h" | |||
#include "src/common/winograd/winograd_generator.h" | |||
#include "src/common/unroll_macro.h" | |||
#include "src/common/utils.h" | |||
#include "midout.h" | |||
#include "src/common/winograd/winograd_helper.h" | |||
MIDOUT_DECL(megdnn_arm_common_winograd_f16_F23_8x8) | |||
using namespace megdnn; | |||
using namespace arm_common; | |||
namespace { | |||
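// Transposes an 8-row x 4-column __fp16 tile (rows read with stride lda) | |||
// into its 4x8 transpose: source column j is written as two consecutive | |||
// 4-lane stores starting at dst + 2 * j * ldb. | |||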
void transpose_8x4(const __fp16* src, __fp16* dst, int lda, int ldb) { | |||
float16x4x2_t a0, a1, a2, a3; | |||
a0.val[0] = vld1_f16(src + 0 * lda); | |||
a0.val[1] = vld1_f16(src + 1 * lda); | |||
a1.val[0] = vld1_f16(src + 2 * lda); | |||
a1.val[1] = vld1_f16(src + 3 * lda); | |||
a2.val[0] = vld1_f16(src + 4 * lda); | |||
a2.val[1] = vld1_f16(src + 5 * lda); | |||
a3.val[0] = vld1_f16(src + 6 * lda); | |||
a3.val[1] = vld1_f16(src + 7 * lda); | |||
float16x4x2_t b0 = vzip_f16(a0.val[0], a1.val[0]); | |||
float16x4x2_t b1 = vzip_f16(a0.val[1], a1.val[1]); | |||
float16x4x2_t b2 = vzip_f16(a2.val[0], a3.val[0]); | |||
float16x4x2_t b3 = vzip_f16(a2.val[1], a3.val[1]); | |||
float16x4x2_t c0 = vzip_f16(b0.val[0], b1.val[0]); | |||
float16x4x2_t c1 = vzip_f16(b0.val[1], b1.val[1]); | |||
float16x4x2_t c2 = vzip_f16(b2.val[0], b3.val[0]); | |||
float16x4x2_t c3 = vzip_f16(b2.val[1], b3.val[1]); | |||
vst1_f16(dst + 0 * ldb, c0.val[0]); | |||
vst1_f16(dst + 1 * ldb, c2.val[0]); | |||
vst1_f16(dst + 2 * ldb, c0.val[1]); | |||
vst1_f16(dst + 3 * ldb, c2.val[1]); | |||
vst1_f16(dst + 4 * ldb, c1.val[0]); | |||
vst1_f16(dst + 5 * ldb, c3.val[0]); | |||
vst1_f16(dst + 6 * ldb, c1.val[1]); | |||
vst1_f16(dst + 7 * ldb, c3.val[1]); | |||
} | |||
struct InputTransform2X3_8x8 { | |||
template <bool inner> | |||
static void prepare(const __fp16* input, __fp16* patch, __fp16* patchT, | |||
int ih_start, int iw_start, size_t IH, size_t IW, | |||
size_t ic, size_t IC) { | |||
constexpr size_t alpha = 2 + 3 - 1; | |||
if (!(inner && ic + 8 < IC)) { | |||
memset(patch, 0, sizeof(__fp16) * 8 * alpha * alpha); | |||
} | |||
if (inner) { | |||
const __fp16* input_ptr = | |||
input + ic * IH * IW + ih_start * IW + iw_start; | |||
for (size_t ico = 0; ico < 8; ++ico) { | |||
if (ic + ico < IC) { | |||
auto v0 = vld1_f16(input_ptr); | |||
auto v1 = vld1_f16(input_ptr + IW); | |||
auto v2 = vld1_f16(input_ptr + IW * 2); | |||
auto v3 = vld1_f16(input_ptr + IW * 3); | |||
vst1_f16(patch + ico * alpha * 4 + 0 * 4, v0); | |||
vst1_f16(patch + ico * alpha * 4 + 1 * 4, v1); | |||
vst1_f16(patch + ico * alpha * 4 + 2 * 4, v2); | |||
vst1_f16(patch + ico * alpha * 4 + 3 * 4, v3); | |||
input_ptr += IH * IW; | |||
} | |||
} | |||
} else { | |||
int ih0_act = std::max<int>(ih_start, 0), | |||
ih1_act = std::min<int>(ih_start + alpha, IH), | |||
iw0_act = std::max<int>(iw_start, 0), | |||
iw1_act = std::min<int>(iw_start + alpha, IW); | |||
// partial copy | |||
for (size_t ico = 0; ico < 8; ++ico) { | |||
if (ic + ico < IC) { | |||
for (int ih = ih0_act; ih < ih1_act; ++ih) { | |||
for (int iw = iw0_act; iw < iw1_act; ++iw) { | |||
size_t iho = ih - ih_start, iwo = iw - iw_start; | |||
patch[ico * alpha * alpha + iho * alpha + iwo] = | |||
input[(ic + ico) * IH * IW + ih * IW + iw]; | |||
} | |||
} | |||
} | |||
} | |||
} | |||
transpose_8x4(patch + 4 * 0, patchT + 32 * 0, 16, 4); | |||
transpose_8x4(patch + 4 * 1, patchT + 32 * 1, 16, 4); | |||
transpose_8x4(patch + 4 * 2, patchT + 32 * 2, 16, 4); | |||
transpose_8x4(patch + 4 * 3, patchT + 32 * 3, 16, 4); | |||
} | |||
static void transform(const __fp16* patchT, __fp16* input_transform_buf, | |||
size_t unit_idx, size_t nr_units_in_tile, size_t ic, | |||
size_t IC) { | |||
constexpr size_t alpha = 2 + 3 - 1; | |||
// BT * d * B | |||
#define cb(m, n) \ | |||
Vector<__fp16, 8> d##m##n = \ | |||
Vector<__fp16, 8>::load(patchT + 8 * (m * 4 + n)); | |||
UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); | |||
#undef cb | |||
//! 1 0 -1 0 d00 d01 d02 d03 1 0 0 0 | |||
//! 0 1 1 0 d10 d11 d12 d13 0 1 -1 -1 | |||
//! 0 -1 1 0 d20 d21 d22 d23 -1 1 1 0 | |||
//! 0 -1 0 1 d30 d31 d32 d33 0 0 0 1 | |||
#define cb(m) \ | |||
auto t0##m = d0##m - d2##m; \ | |||
auto t1##m = d1##m + d2##m; \ | |||
auto t2##m = d2##m - d1##m; \ | |||
auto t3##m = d3##m - d1##m; | |||
UNROLL_CALL_NOWRAPPER(4, cb); | |||
#undef cb | |||
#define cb(m) \ | |||
d##m##0 = t##m##0 - t##m##2; \ | |||
d##m##1 = t##m##1 + t##m##2; \ | |||
d##m##2 = t##m##2 - t##m##1; \ | |||
d##m##3 = t##m##3 - t##m##1; | |||
UNROLL_CALL_NOWRAPPER(4, cb); | |||
#undef cb | |||
size_t ICB = IC / 8; | |||
size_t icb = ic / 8; | |||
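        // The transformed input is laid out as [alpha^2][IC/8][units][8], | |||
        // keeping the 8 channels of a block contiguous so the following | |||
        // matrix multiply can load full 8-lane fp16 vectors. | |||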
#define cb(m, n) \ | |||
d##m##n.save(input_transform_buf + \ | |||
(m * alpha + n) * nr_units_in_tile * ICB * 8 + \ | |||
icb * nr_units_in_tile * 8 + unit_idx * 8); | |||
UNROLL_CALL_NOWRAPPER_D2(4, 4, cb) | |||
#undef cb | |||
} | |||
}; | |||
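//! OutputTransform2X3_8x8 applies AT * m * A, reducing each 4x4 block of | |||
//! matmul results to a 2x2 output tile for 8 output channels at once. | |||
//! Channel-broadcast bias and the activation run in vector registers; the | |||
//! per-element BIAS mode and clipping at the right/bottom edges are | |||
//! handled in the scalar tail. | |||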
template <BiasMode bmode, typename Op> | |||
struct OutputTransform2X3_8x8 { | |||
static void transform(const dt_float16* output_transform_buf, | |||
const dt_float16* bias, dt_float16* output, | |||
dt_float16* transform_mid_buf, size_t oh_start, | |||
size_t ow_start, size_t OH, size_t OW, | |||
size_t oc_start, size_t oc_end, size_t oc_index, | |||
size_t unit_idx, size_t nr_units_in_tile, | |||
const DType& src_dtype, const DType& dst_dtype) { | |||
Op op(src_dtype, dst_dtype); | |||
const __fp16* output_transform_ptr = | |||
reinterpret_cast<const __fp16*>(output_transform_buf); | |||
const __fp16* bias_ptr = reinterpret_cast<const __fp16*>(bias); | |||
__fp16* output_ptr = reinterpret_cast<__fp16*>(output); | |||
__fp16* transform_mid_ptr = | |||
reinterpret_cast<__fp16*>(transform_mid_buf); | |||
//! AT * m * A | |||
constexpr size_t alpha = 2 + 3 - 1; | |||
size_t oc = oc_start + oc_index; | |||
size_t OCB = (oc_end - oc_start) / 8; | |||
size_t ocb = oc_index / 8; | |||
#define cb(m, n) \ | |||
auto v##m##n = Vector<__fp16, 8>::load( \ | |||
output_transform_ptr + \ | |||
(m * alpha + n) * OCB * nr_units_in_tile * 8 + \ | |||
ocb * nr_units_in_tile * 8 + unit_idx * 8); | |||
UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); | |||
#undef cb | |||
//! 1 1 1 0 v00 v01 v02 v03 1 0 | |||
//! 0 1 -1 1 v10 v11 v12 v13 1 1 | |||
//! v20 v21 v22 v23 1 -1 | |||
//! v30 v31 v32 v33 0 1 | |||
#define cb(m) \ | |||
auto t0##m = v0##m + v1##m + v2##m; \ | |||
auto t1##m = v1##m - v2##m + v3##m; | |||
UNROLL_CALL_NOWRAPPER(4, cb); | |||
#undef cb | |||
v00 = t00 + t01 + t02; | |||
v10 = t10 + t11 + t12; | |||
v01 = t01 - t02 + t03; | |||
v11 = t11 - t12 + t13; | |||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
Vector<__fp16, 8> vbias; | |||
vbias = Vector<__fp16, 8>::load(bias_ptr + oc); | |||
v00 += vbias; | |||
v10 += vbias; | |||
v01 += vbias; | |||
v11 += vbias; | |||
} | |||
if (bmode != BiasMode::BIAS) { | |||
v00.value = op(v00.value); | |||
v01.value = op(v01.value); | |||
v10.value = op(v10.value); | |||
v11.value = op(v11.value); | |||
} | |||
v00.save(transform_mid_ptr + (0 * 2 + 0) * 8); | |||
v10.save(transform_mid_ptr + (1 * 2 + 0) * 8); | |||
v01.save(transform_mid_ptr + (0 * 2 + 1) * 8); | |||
v11.save(transform_mid_ptr + (1 * 2 + 1) * 8); | |||
for (size_t oco = 0; oco < 8 && oc + oco < oc_end; ++oco) { | |||
for (size_t oho = 0; oho < 2 && oh_start + oho < OH; ++oho) { | |||
for (size_t owo = 0; owo < 2 && ow_start + owo < OW; ++owo) { | |||
size_t oh = oh_start + oho; | |||
size_t ow = ow_start + owo; | |||
__fp16 res = transform_mid_ptr[oho * 2 * 8 + owo * 8 + oco]; | |||
if (bmode == BiasMode::BIAS) { | |||
res += bias_ptr[(oc + oco) * OH * OW + oh * OW + ow]; | |||
res = op(res); | |||
} | |||
output_ptr[(oc + oco) * OH * OW + oh * OW + ow] = res; | |||
} | |||
} | |||
} | |||
} | |||
}; | |||
} // namespace | |||
namespace megdnn { | |||
namespace arm_common { | |||
namespace winograd { | |||
MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_2x3_8x8_f16) | |||
void winograd_2x3_8x8_f16::filter(const dt_float16* filter, | |||
dt_float16* filter_transform_buf, | |||
dt_float16* transform_mid_buf, size_t OC, | |||
size_t IC, size_t oc_start, size_t oc_end) { | |||
constexpr int alpha = 2 + 3 - 1; | |||
//! G * g * GT | |||
__fp16* filter_transbuf_ptr = | |||
reinterpret_cast<__fp16*>(filter_transform_buf); | |||
__fp16* filter_transmid_ptr = reinterpret_cast<__fp16*>(transform_mid_buf); | |||
size_t OCB = OC / 8; | |||
size_t ICB = IC / 8; | |||
for (size_t oc = oc_start; oc < oc_end; oc++) { | |||
rep(ic, IC) { | |||
const __fp16* filter_ptr = reinterpret_cast<const __fp16*>(filter) + | |||
(oc * IC + ic) * 3 * 3; | |||
/** | |||
* origin: (4 x 3) * (3 x 3) * (3 x 4) | |||
* pad G and g up to multiples of 4 | |||
* now:    (4 x 4) * (4 x 4) * (4 x 4) | |||
*/ | |||
//! 1 0 0 0 v00 v01 v02 0 1 0.5 0.5 0 | |||
//! 0.5 0.5 0.5 0 v10 v11 v12 0 0 0.5 -0.5 0 | |||
//! 0.5 -0.5 0.5 0 v20 v21 v22 0 0 0.5 0.5 1 | |||
//! 0 0 1 0 0 0 0 0 0 0 0 0 | |||
float16x4_t v0 = vld1_f16(filter_ptr); // 0 1 2 3 | |||
float16x4_t v1 = vld1_f16(filter_ptr + 3); // 3 4 5 6 | |||
float16x4_t v2 = vld1_f16(filter_ptr + 5);  // 5 6 7 8 | |||
float16x4_t v3 = vdup_n_f16(0); | |||
v2 = vext_f16(v2, v3, 1); | |||
v0 = vset_lane_f16(0, v0, 3); | |||
v1 = vset_lane_f16(0, v1, 3); | |||
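//! v0/v1/v2 now hold the three rows of the 3x3 filter, each zero-padded | |||
//! to 4 lanes: the overlapping 4-lane loads above cover the 9 contiguous | |||
//! weights, and vext/vset_lane discard the overlap. | |||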
#define cb(i) float16x4_t vsum##i; | |||
UNROLL_CALL_NOWRAPPER(4, cb); | |||
#undef cb | |||
vsum0 = v0; | |||
float16x4_t v0addv2 = vadd_f16(v0, v2); | |||
float16x4_t v02addv1 = vadd_f16(v0addv2, v1); | |||
float16x4_t v02subv1 = vsub_f16(v0addv2, v1); | |||
vsum1 = vmul_n_f16(v02addv1, 0.5); | |||
vsum2 = vmul_n_f16(v02subv1, 0.5); | |||
vsum3 = v2; | |||
#define cb(i) \ | |||
do { \ | |||
mid_buf1[0] = vget_lane_f16(vsum##i, 0); \ | |||
__fp16 a0a2 = vget_lane_f16(vsum##i, 0) + vget_lane_f16(vsum##i, 2); \ | |||
__fp16 a0a2adda1 = a0a2 + vget_lane_f16(vsum##i, 1); \ | |||
__fp16 a0a2suba1 = a0a2 - vget_lane_f16(vsum##i, 1); \ | |||
mid_buf1[1] = a0a2adda1 * 0.5; \ | |||
mid_buf1[2] = a0a2suba1 * 0.5; \ | |||
mid_buf1[3] = vget_lane_f16(vsum##i, 2); \ | |||
mid_buf1 += 4; \ | |||
} while (0); | |||
__fp16* mid_buf1 = filter_transmid_ptr; | |||
UNROLL_CALL_NOWRAPPER(4, cb); | |||
mid_buf1 = filter_transmid_ptr; | |||
#undef cb | |||
size_t ocb = (oc) / 8; | |||
size_t oc8 = (oc) % 8; | |||
size_t icb = (ic) / 8; | |||
size_t ic8 = (ic) % 8; | |||
rep(i, alpha) rep(j, alpha) { | |||
filter_transbuf_ptr[(i * alpha + j) * OCB * ICB * 8 * 8 + | |||
ocb * ICB * 8 * 8 + icb * 8 * 8 + ic8 * 8 + | |||
oc8] = filter_transmid_ptr[i * alpha + j]; | |||
} | |||
} | |||
} | |||
} | |||
void winograd_2x3_8x8_f16::input(const dt_float16* input, | |||
dt_float16* input_transform_buf, | |||
dt_float16* transform_mid_buf, size_t IH, | |||
size_t IW, size_t IC, size_t PH, size_t PW, | |||
size_t unit_start_idx, size_t nr_units_in_tile) { | |||
megdnn_assert(IC % 8 == 0); | |||
constexpr int alpha = 2 + 3 - 1; | |||
// OW = IW + 2 * PW - KERNEL_SIZE + 1 | |||
auto units_w = div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); | |||
dt_float16* patch = transform_mid_buf; | |||
dt_float16* patchT = transform_mid_buf + 8 * alpha * alpha; | |||
for (size_t ic = 0; ic < IC; ic += 8) { | |||
rep(unit_idx, nr_units_in_tile) { | |||
size_t index = unit_start_idx + unit_idx; | |||
size_t nh = index / units_w; | |||
size_t nw = index % units_w; | |||
int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; | |||
int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; | |||
if (ih_start >= 0 && ih_start + alpha <= static_cast<int>(IH) && | |||
iw_start >= 0 && iw_start + alpha <= static_cast<int>(IW)) { | |||
InputTransform2X3_8x8::prepare<true>( | |||
reinterpret_cast<const __fp16*>(input), | |||
reinterpret_cast<__fp16*>(patch), | |||
reinterpret_cast<__fp16*>(patchT), ih_start, iw_start, | |||
IH, IW, ic, IC); | |||
InputTransform2X3_8x8::transform( | |||
reinterpret_cast<const __fp16*>(patchT), | |||
reinterpret_cast<__fp16*>(input_transform_buf), | |||
unit_idx, nr_units_in_tile, ic, IC); | |||
} else { | |||
InputTransform2X3_8x8::prepare<false>( | |||
reinterpret_cast<const __fp16*>(input), | |||
reinterpret_cast<__fp16*>(patch), | |||
reinterpret_cast<__fp16*>(patchT), ih_start, iw_start, | |||
IH, IW, ic, IC); | |||
InputTransform2X3_8x8::transform( | |||
reinterpret_cast<const __fp16*>(patchT), | |||
reinterpret_cast<__fp16*>(input_transform_buf), | |||
unit_idx, nr_units_in_tile, ic, IC); | |||
} | |||
} | |||
} | |||
} | |||
void winograd_2x3_8x8_f16::output(const dt_float16* output_transform_buf, | |||
const dt_float16* bias, dt_float16* output, | |||
dt_float16* transform_mid_buf, BiasMode bmode, | |||
NonlineMode nonline_mode, size_t OH, | |||
size_t OW, size_t oc_start, size_t oc_end, | |||
size_t unit_start_idx, | |||
size_t nr_units_in_tile) { | |||
#define cb(_bmode, _nonline_op, ...) \ | |||
OutputTransform2X3_8x8<_bmode MEGDNN_COMMA _nonline_op>::transform( \ | |||
__VA_ARGS__); | |||
auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE); | |||
for (size_t oc = oc_start; oc < oc_end; oc += 8) { | |||
size_t oc_index = oc - oc_start; | |||
rep(unit_idx, nr_units_in_tile) { | |||
size_t index = unit_start_idx + unit_idx; | |||
auto nh = index / units_w; | |||
auto nw = index % units_w; | |||
size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | |||
size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | |||
DISPATCH_CONV_WINOGRAD_BIAS( | |||
megdnn_arm_common_winograd_fp16_F23_8x8, cb, __fp16, __fp16, | |||
bmode, nonline_mode, output_transform_buf, bias, output, | |||
transform_mid_buf, oh_start, ow_start, OH, OW, oc_start, | |||
oc_end, oc_index, unit_idx, nr_units_in_tile, src_dtype, dst_dtype); | |||
} | |||
} | |||
#undef cb | |||
} | |||
} // namespace winograd | |||
} // namespace arm_common | |||
} // namespace megdnn | |||
#endif | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,488 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/f16/strategy_4x5.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/f16/helper.h" | |||
#include "src/arm_common/conv_bias/f16/strategy.h" | |||
#include "src/arm_common/elemwise_helper/op_unary.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/arm_common/utils.h" | |||
#include "src/common/unroll_macro.h" | |||
#include "src/common/utils.h" | |||
#include "src/fallback/conv_bias/winograd/winograd.h" | |||
#include "src/naive/matrix_mul/matrix_mul_helper.h" | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
#include "midout.h" | |||
MIDOUT_DECL(megdnn_arm_common_winograd_fp16_F45) | |||
using namespace megdnn; | |||
using namespace arm_common; | |||
namespace { | |||
struct FilterTransform4X5 { | |||
#define FILTER_TRANSFORM(d, wd) \ | |||
do { \ | |||
wd##0 = d##0; \ | |||
wd##r0 = d##r0; \ | |||
wd##1 = (d##0 + d##1 + d##2 + d##3 + d##4) * -0.222168; \ | |||
wd##r1 = (d##r0 + d##r1 + d##r2 + d##r3 + d##r4) * -0.222168; \ | |||
wd##2 = (d##0 - d##1 + d##2 - d##3 + d##4) * -0.222168; \ | |||
wd##r2 = (d##r0 - d##r1 + d##r2 - d##r3 + d##r4) * -0.222168; \ | |||
auto tmpd0 = d##0 * 0.710938; \ | |||
auto tmpd1 = d##1 * 0.355469; \ | |||
auto tmpd2 = d##2 * 0.177734; \ | |||
auto tmpd3 = d##3 * 0.088867; \ | |||
auto tmpd4 = d##4 * 0.044434; \ | |||
auto tmpdr0 = d##r0 * 0.710938; \ | |||
auto tmpdr1 = d##r1 * 0.355469; \ | |||
auto tmpdr2 = d##r2 * 0.177734; \ | |||
auto tmpdr3 = d##r3 * 0.088867; \ | |||
auto tmpdr4 = d##r4 * 0.044434; \ | |||
wd##3 = tmpd0 + tmpd1 + tmpd2 + tmpd3 + tmpd4; \ | |||
wd##r3 = tmpdr0 + tmpdr1 + tmpdr2 + tmpdr3 + tmpdr4; \ | |||
wd##4 = tmpd0 - tmpd1 + tmpd2 - tmpd3 + tmpd4; \ | |||
wd##r4 = tmpdr0 - tmpdr1 + tmpdr2 - tmpdr3 + tmpdr4; \ | |||
tmpd0 = d##0 * 0.011108; \ | |||
tmpd1 = d##1 * 0.022217; \ | |||
tmpd2 = d##2 * 0.044434; \ | |||
tmpd3 = d##3 * 0.088867; \ | |||
tmpd4 = d##4 * 0.177734; \ | |||
tmpdr0 = d##r0 * 0.011108; \ | |||
tmpdr1 = d##r1 * 0.022217; \ | |||
tmpdr2 = d##r2 * 0.044434; \ | |||
tmpdr3 = d##r3 * 0.088867; \ | |||
tmpdr4 = d##r4 * 0.177734; \ | |||
wd##5 = tmpd0 + tmpd1 + tmpd2 + tmpd3 + tmpd4; \ | |||
wd##r5 = tmpdr0 + tmpdr1 + tmpdr2 + tmpdr3 + tmpdr4; \ | |||
wd##6 = tmpd0 - tmpd1 + tmpd2 - tmpd3 + tmpd4; \ | |||
wd##r6 = tmpdr0 - tmpdr1 + tmpdr2 - tmpdr3 + tmpdr4; \ | |||
wd##7 = d##4; \ | |||
wd##r7 = d##r4; \ | |||
} while (0); | |||
#define FILTER_TRANSFORM_FINAL(d, wd) \ | |||
do { \ | |||
wd##0 = d##0; \ | |||
wd##1 = (d##0 + d##1 + d##2 + d##3 + d##4) * -0.222168; \ | |||
wd##2 = (d##0 - d##1 + d##2 - d##3 + d##4) * -0.222168; \ | |||
auto tmp0 = d##0 * 0.710938 + d##2 * 0.177734 + d##4 * 0.044434; \ | |||
auto tmp1 = d##1 * 0.355469 + d##3 * 0.088867; \ | |||
wd##3 = tmp0 + tmp1; \ | |||
wd##4 = tmp0 - tmp1; \ | |||
tmp0 = d##0 * 0.011108 + d##2 * 0.044434 + d##4 * 0.177734; \ | |||
tmp1 = d##1 * 0.022217 + d##3 * 0.088867; \ | |||
wd##5 = tmp0 + tmp1; \ | |||
wd##6 = tmp0 - tmp1; \ | |||
wd##7 = d##4; \ | |||
} while (0); | |||
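// Note: the constants above (-0.222168, 0.710938, 0.355469, 0.177734, | |||
// 0.088867, 0.044434, 0.022217, 0.011108) are the exact G entries listed | |||
// in transform() below (-0.2222222, 0.7111111, ...) rounded to the | |||
// nearest representable __fp16 values. | |||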
static void transform(const __fp16* filter, __fp16* filter_transform_buf, | |||
__fp16* transform_mid_buf, size_t OC, size_t IC, | |||
size_t oc_start, size_t oc_end) { | |||
// Gg * GT | |||
// G | |||
//[[ 1. 0. 0. 0. 0. ] | |||
// [-0.2222222 -0.2222222 -0.2222222 -0.2222222 -0.2222222] | |||
// [-0.2222222 0.2222222 -0.2222222 0.2222222 -0.2222222] | |||
// [ 0.7111111 0.3555556 0.1777778 0.0888889 0.0444444] | |||
// [ 0.7111111 -0.3555556 0.1777778 -0.0888889 0.0444444] | |||
// [ 0.0111111 0.0222222 0.0444444 0.0888889 0.1777778] | |||
// [ 0.0111111 -0.0222222 0.0444444 -0.0888889 0.1777778] | |||
// [ 0. 0. 0. 0. 1. ]] | |||
constexpr size_t alpha = 4 + 5 - 1; | |||
for (size_t oc = oc_start; oc < oc_end; oc++) { | |||
rep(ic, IC) { | |||
const __fp16* fptr = filter + (oc * IC + ic) * 5 * 5; | |||
#define cb(i) Vector<__fp16, 4> g##i = Vector<__fp16, 4>::load(fptr + 5 * i); | |||
UNROLL_CALL_NOWRAPPER(5, cb); | |||
#undef cb | |||
#define cb(i) __fp16 gr##i = *(fptr + 5 * i + 4); | |||
UNROLL_CALL_NOWRAPPER(5, cb); | |||
#undef cb | |||
#define cb(i) Vector<__fp16, 4> Gg##i; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
#define cb(i) __fp16 Ggr##i; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
#define cb(i) Vector<__fp16, 8> Ggt##i; | |||
UNROLL_CALL_NOWRAPPER(4, cb); | |||
#undef cb | |||
#define cb(i) Vector<__fp16, 8> result##i; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
FILTER_TRANSFORM(g, Gg) | |||
#if MEGDNN_AARCH64 | |||
float16x8_t vgr = {Ggr0, Ggr1, Ggr2, Ggr3, | |||
Ggr4, Ggr5, Ggr6, Ggr7}; | |||
Vector<__fp16, 8> Ggt4(vgr); | |||
TRANSPOSE_8x4(Gg, Ggt); | |||
FILTER_TRANSFORM_FINAL(Ggt, result); | |||
#define cb(i) result##i.save(transform_mid_buf + i * alpha); | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
rep(i, alpha) rep(j, alpha) { | |||
filter_transform_buf[(i * alpha + j) * OC * IC + ic * OC + | |||
oc] = transform_mid_buf[j * alpha + i]; | |||
} | |||
#else | |||
#define GET_VECTOR_FP16D_ELEM(s, i, idx) vget_lane_f16(CONCAT(s, i).value, idx) | |||
#define cb(i) \ | |||
do { \ | |||
mid_buf1[0] = GET_VECTOR_FP16D_ELEM(Gg, i, 0); \ | |||
auto tmp024 = GET_VECTOR_FP16D_ELEM(Gg, i, 0) + \ | |||
GET_VECTOR_FP16D_ELEM(Gg, i, 2) + Ggr##i; \ | |||
auto tmp13 = GET_VECTOR_FP16D_ELEM(Gg, i, 1) + \ | |||
GET_VECTOR_FP16D_ELEM(Gg, i, 3); \ | |||
mid_buf1[1] = (tmp024 + tmp13) * -0.2222222; \ | |||
mid_buf1[2] = (tmp024 - tmp13) * -0.2222222; \ | |||
auto tmp0 = GET_VECTOR_FP16D_ELEM(Gg, i, 0) * 0.7111111; \ | |||
auto tmp1 = GET_VECTOR_FP16D_ELEM(Gg, i, 1) * 0.3555556; \ | |||
auto tmp2 = GET_VECTOR_FP16D_ELEM(Gg, i, 2) * 0.1777778; \ | |||
auto tmp3 = GET_VECTOR_FP16D_ELEM(Gg, i, 3) * 0.0888889; \ | |||
auto tmp4 = Ggr##i * 0.0444444; \ | |||
tmp024 = tmp0 + tmp2 + tmp4; \ | |||
tmp13 = tmp1 + tmp3; \ | |||
mid_buf1[3] = tmp024 + tmp13; \ | |||
mid_buf1[4] = tmp024 - tmp13; \ | |||
tmp0 = GET_VECTOR_FP16D_ELEM(Gg, i, 0) * 0.0111111; \ | |||
tmp1 = GET_VECTOR_FP16D_ELEM(Gg, i, 1) * 0.0222222; \ | |||
tmp2 = GET_VECTOR_FP16D_ELEM(Gg, i, 2) * 0.0444444; \ | |||
tmp3 = GET_VECTOR_FP16D_ELEM(Gg, i, 3) * 0.0888889; \ | |||
tmp4 = Ggr##i * 0.1777778; \ | |||
tmp024 = tmp0 + tmp2 + tmp4; \ | |||
tmp13 = tmp1 + tmp3; \ | |||
mid_buf1[5] = tmp024 + tmp13; \ | |||
mid_buf1[6] = tmp024 - tmp13; \ | |||
mid_buf1[7] = Ggr##i; \ | |||
mid_buf1 += 8; \ | |||
} while (0); | |||
__fp16* mid_buf1 = transform_mid_buf; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
mid_buf1 = transform_mid_buf; | |||
#undef cb | |||
#undef GET_VECTOR_FP16D_ELEM | |||
rep(i, alpha) rep(j, alpha) { | |||
filter_transform_buf[(i * alpha + j) * OC * IC + ic * OC + | |||
oc] = transform_mid_buf[i * alpha + j]; | |||
} | |||
#endif | |||
} | |||
} | |||
} | |||
}; | |||
#undef FILTER_TRANSFORM | |||
#undef FILTER_TRANSFORM_FINAL | |||
struct InputTransform4X5 { | |||
#define INPUT_TRANSFORM(d, wd) \ | |||
do { \ | |||
wd##0 = (d##0 - d##6) + (d##4 - d##2) * 5.25f; \ | |||
auto tmp0 = d##2 - d##4 * 4.25f + d##6; \ | |||
auto tmp1 = d##1 - d##3 * 4.25f + d##5; \ | |||
wd##1 = tmp0 + tmp1; \ | |||
wd##2 = tmp0 - tmp1; \ | |||
tmp0 = d##2 * 4.0f - d##4 * 5.0f + d##6; \ | |||
tmp1 = d##1 * 2.0f - d##3 * 2.5f + d##5 * 0.5f; \ | |||
wd##3 = tmp0 + tmp1; \ | |||
wd##4 = tmp0 - tmp1; \ | |||
tmp0 = d##2 * 0.25f - d##4 * 1.25f + d##6; \ | |||
tmp1 = d##1 * 0.5f - d##3 * 2.5f + d##5 * 2.0f; \ | |||
wd##5 = tmp0 + tmp1; \ | |||
wd##6 = tmp0 - tmp1; \ | |||
wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f; \ | |||
} while (0) | |||
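// INPUT_TRANSFORM implements the BT matrix listed in transform() below; | |||
// it is applied to the eight input rows and then, after TRANSPOSE_8x8, to | |||
// the columns, yielding BT * d * B for one 8x8 tile. | |||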
#define GET_VECTOR_FP16Q_ELEM(s, i, idx) vgetq_lane_f16(CONCAT(s, i).value, idx) | |||
template <bool inner> | |||
static void transform(const __fp16* input, __fp16* input_transform_buf, | |||
__fp16* transform_mid_buf, int ih_start, int iw_start, | |||
size_t ic, size_t IH, size_t IW, size_t IC, | |||
size_t unit_idx, size_t nr_units_in_tile) { | |||
// BTd * B | |||
//([[ 1. , 0. , -5.25, 0. , 5.25, 0. , -1. , 0. ], | |||
// [ 0. , 1. , 1. , -4.25, -4.25, 1. , 1. , 0. ], | |||
// [ 0. , -1. , 1. , 4.25, -4.25, -1. , 1. , 0. ], | |||
// [ 0. , 2. , 4. , -2.5 , -5. , 0.5 , 1. , 0. ], | |||
// [ 0. , -2. , 4. , 2.5 , -5. , -0.5 , 1. , 0. ], | |||
// [ 0. , 0.5 , 0.25, -2.5 , -1.25, 2. , 1. , 0. ], | |||
// [ 0. , -0.5 , 0.25, 2.5 , -1.25, -2. , 1. , 0. ], | |||
// [ 0. , -1. , 0. , 5.25, 0. , -5.25, 0. , 1. ]])) | |||
constexpr size_t alpha = 4 + 5 - 1; | |||
if (!inner) { | |||
memset(transform_mid_buf, 0, sizeof(__fp16) * alpha * alpha); | |||
} | |||
#define cb(i) Vector<__fp16, 8> d##i; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
if (inner) { | |||
const __fp16* input_ptr = | |||
input + ic * IH * IW + ih_start * IW + iw_start; | |||
#define cb(i) d##i = Vector<__fp16, 8>::load(input_ptr + IW * i); | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
} else { | |||
int ih0_act = std::max<int>(ih_start, 0), | |||
ih1_act = std::min<int>(ih_start + alpha, IH), | |||
iw0_act = std::max<int>(iw_start, 0), | |||
iw1_act = std::min<int>(iw_start + alpha, IW); | |||
for (int ih = ih0_act; ih < ih1_act; ++ih) { | |||
for (int iw = iw0_act; iw < iw1_act; ++iw) { | |||
size_t iho = ih - ih_start, iwo = iw - iw_start; | |||
transform_mid_buf[iho * alpha + iwo] = | |||
input[ic * IH * IW + ih * IW + iw]; | |||
} | |||
} | |||
#define cb(i) d##i = Vector<__fp16, 8>::load(transform_mid_buf + alpha * i); | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
} | |||
#define cb(i) Vector<__fp16, 8> wd##i, ret##i; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
INPUT_TRANSFORM(d, wd); | |||
TRANSPOSE_8x8(wd, d); | |||
INPUT_TRANSFORM(d, ret); | |||
#define cb(i) ret##i.save(transform_mid_buf + i * alpha); | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
rep(i, alpha) rep(j, alpha) { | |||
input_transform_buf[(i * alpha + j) * nr_units_in_tile * IC + | |||
unit_idx * IC + ic] = | |||
transform_mid_buf[j * alpha + i]; | |||
} | |||
} | |||
}; | |||
#undef INPUT_TRANSFORM | |||
#define OUTPUT_TRANSFORM(m, s) \ | |||
do { \ | |||
s##0 = m##0 + m##1 + m##2 + m##3 + m##4 + m##5 + m##6; \ | |||
s##1 = m##1 - m##2 + m##3 * 0.5 - m##4 * 0.5 + m##5 * 2.0 - \ | |||
m##6 * 2.0; \ | |||
s##2 = m##1 + m##2 + m##3 * 0.25 + m##4 * 0.25 + m##5 * 4.0 + \ | |||
m##6 * 4.0; \ | |||
s##3 = m##1 - m##2 + m##3 * 0.125 - m##4 * 0.125 + m##5 * 8.0 - \ | |||
m##6 * 8.0 + m##7; \ | |||
} while (0) | |||
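// For F(4, 5) the AT matrix (listed in OutputTransform4X5::transform | |||
// below) is 4x8, so OUTPUT_TRANSFORM produces only s0..s3; applying it to | |||
// the rows, transposing, and applying it again to the columns reduces an | |||
// 8x8 block of matmul results to the final 4x4 output tile. | |||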
template <BiasMode bmode, typename Op> | |||
struct OutputTransform4X5 { | |||
static void transform(const dt_float16* output_transform_buf, | |||
const dt_float16* bias, dt_float16* output, | |||
dt_float16* transform_mid_buf, size_t oh_start, | |||
size_t ow_start, size_t OH, size_t OW, | |||
size_t oc_start, size_t oc_end, size_t oc_index, | |||
size_t unit_idx, size_t nr_units_in_tile, | |||
const DType& src_dtype, const DType& dst_dtype) { | |||
Op op(src_dtype, dst_dtype); | |||
//! AT * m * A | |||
// AT f45 | |||
// 1.0 1.0 1.0 1.000 1.000 1.0 1.0 0.0 | |||
// 0.0 1.0 -1.0 0.500 -0.500 2.0 -2.0 0.0 | |||
// 0.0 1.0 1.0 0.250 0.250 4.0 4.0 0.0 | |||
// 0.0 1.0 -1.0 0.125 -0.125 8.0 -8.0 1.0 | |||
constexpr size_t alpha = 4 + 5 - 1; | |||
const __fp16* fp16_output_transform_buf = | |||
reinterpret_cast<const __fp16*>(output_transform_buf); | |||
const __fp16* fp16_bias = reinterpret_cast<const __fp16*>(bias); | |||
__fp16* fp16_output = reinterpret_cast<__fp16*>(output); | |||
__fp16* fp16_transform_mid_buf = | |||
reinterpret_cast<__fp16*>(transform_mid_buf); | |||
__fp16* mid_buf1 = fp16_transform_mid_buf; | |||
size_t OC = oc_end - oc_start; | |||
size_t oc = oc_start + oc_index; | |||
#define cb(m, n) \ | |||
fp16_transform_mid_buf[m * alpha + n] = \ | |||
fp16_output_transform_buf[(m * alpha + n) * nr_units_in_tile * \ | |||
OC + \ | |||
unit_idx * OC + oc_index]; | |||
UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); | |||
#undef cb | |||
#define cb(i) \ | |||
auto m##i = Vector<__fp16, 8>::load(fp16_transform_mid_buf + alpha * i); | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
#define cb(i) Vector<__fp16, 8> s##i; | |||
UNROLL_CALL_NOWRAPPER(4, cb); | |||
#undef cb | |||
#define cb(i) Vector<__fp16, 4> st##i; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
#define cb(i) Vector<__fp16, 4> result##i; | |||
UNROLL_CALL_NOWRAPPER(4, cb); | |||
#undef cb | |||
OUTPUT_TRANSFORM(m, s); | |||
TRANSPOSE_4x8(s, st); | |||
OUTPUT_TRANSFORM(st, result); | |||
TRANSPOSE_4x4(result, result); | |||
if (oh_start + 4 <= OH && ow_start + 4 <= OW) { | |||
int index = (oc * OH + oh_start) * OW + ow_start; | |||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
float16x4_t bias0 = vdup_n_f16(fp16_bias[oc]); | |||
result0.value = vadd_f16(result0.value, bias0); | |||
result1.value = vadd_f16(result1.value, bias0); | |||
result2.value = vadd_f16(result2.value, bias0); | |||
result3.value = vadd_f16(result3.value, bias0); | |||
} else if (bmode == BiasMode::BIAS) { | |||
float16x4_t bmbias0 = vld1_f16(fp16_bias + index); | |||
float16x4_t bmbias1 = vld1_f16(fp16_bias + index + OW); | |||
float16x4_t bmbias2 = vld1_f16(fp16_bias + index + OW * 2); | |||
float16x4_t bmbias3 = vld1_f16(fp16_bias + index + OW * 3); | |||
result0.value = vadd_f16(result0.value, bmbias0); | |||
result1.value = vadd_f16(result1.value, bmbias1); | |||
result2.value = vadd_f16(result2.value, bmbias2); | |||
result3.value = vadd_f16(result3.value, bmbias3); | |||
} | |||
float16x8_t item01 = op(vcombine_f16(result0.value, result1.value)); | |||
float16x8_t item23 = op(vcombine_f16(result2.value, result3.value)); | |||
vst1_f16(fp16_output + index, vget_low_f16(item01)); | |||
vst1_f16(fp16_output + index + OW, vget_high_f16(item01)); | |||
vst1_f16(fp16_output + index + OW * 2, vget_low_f16(item23)); | |||
vst1_f16(fp16_output + index + OW * 3, vget_high_f16(item23)); | |||
} else { | |||
#define cb(i) result##i.save(mid_buf1 + i * 4); | |||
mid_buf1 = fp16_transform_mid_buf; | |||
UNROLL_CALL_NOWRAPPER(4, cb); | |||
mid_buf1 = fp16_transform_mid_buf; | |||
#undef cb | |||
for (size_t oho = 0; oho < 4 && oh_start + oho < OH; ++oho) { | |||
for (size_t owo = 0; owo < 4 && ow_start + owo < OW; ++owo) { | |||
size_t oh = oh_start + oho; | |||
size_t ow = ow_start + owo; | |||
__fp16 res = mid_buf1[oho * 4 + owo]; | |||
if (bmode == BiasMode::BIAS) { | |||
res += fp16_bias[oc * OH * OW + oh * OW + ow]; | |||
} else if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
res += fp16_bias[oc]; | |||
} | |||
res = op(res); | |||
fp16_output[oc * OH * OW + oh * OW + ow] = res; | |||
} | |||
} | |||
} | |||
} | |||
}; | |||
#undef OUTPUT_TRANSFORM | |||
#undef GET_VECTOR_FP16Q_ELEM | |||
} // namespace | |||
namespace megdnn { | |||
namespace arm_common { | |||
namespace winograd { | |||
MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_4x5_1x1_f16) | |||
void winograd_4x5_1x1_f16::filter(const dt_float16* filter, | |||
dt_float16* filter_transform_buf, | |||
dt_float16* transform_mid_buf, size_t OC, | |||
size_t IC, size_t oc_start, size_t oc_end) { | |||
FilterTransform4X5::transform( | |||
reinterpret_cast<const __fp16*>(filter), | |||
reinterpret_cast<__fp16*>(filter_transform_buf), | |||
reinterpret_cast<__fp16*>(transform_mid_buf), OC, IC, oc_start, | |||
oc_end); | |||
} | |||
void winograd_4x5_1x1_f16::input(const dt_float16* input, | |||
dt_float16* input_transform_buf, | |||
dt_float16* transform_mid_buf, size_t IH, | |||
size_t IW, size_t IC, size_t PH, size_t PW, | |||
size_t unit_start_idx, | |||
size_t nr_units_in_tile) { | |||
constexpr int alpha = 4 + 5 - 1; | |||
// OW = IW + 2 * PW - KERNEL_SIZE + 1 | |||
auto units_w = div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); | |||
rep(ic, IC) { | |||
rep(unit_idx, nr_units_in_tile) { | |||
size_t index = unit_start_idx + unit_idx; | |||
size_t nh = index / units_w; | |||
size_t nw = index % units_w; | |||
int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; | |||
int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; | |||
if (ih_start >= 0 && ih_start + alpha <= static_cast<int>(IH) && | |||
iw_start >= 0 && iw_start + alpha <= static_cast<int>(IW)) { | |||
InputTransform4X5::transform<true>( | |||
reinterpret_cast<const __fp16*>(input), | |||
reinterpret_cast<__fp16*>(input_transform_buf), | |||
reinterpret_cast<__fp16*>(transform_mid_buf), ih_start, | |||
iw_start, ic, IH, IW, IC, unit_idx, nr_units_in_tile); | |||
} else { | |||
InputTransform4X5::transform<false>( | |||
reinterpret_cast<const __fp16*>(input), | |||
reinterpret_cast<__fp16*>(input_transform_buf), | |||
reinterpret_cast<__fp16*>(transform_mid_buf), ih_start, | |||
iw_start, ic, IH, IW, IC, unit_idx, nr_units_in_tile); | |||
} | |||
} | |||
} | |||
} | |||
void winograd_4x5_1x1_f16::output(const dt_float16* output_transform_buf, | |||
const dt_float16* bias, dt_float16* output, | |||
dt_float16* transform_mid_buf, BiasMode bmode, | |||
NonlineMode nonline_mode, size_t OH, size_t OW, | |||
size_t oc_start, size_t oc_end, | |||
size_t unit_start_idx, size_t nr_units_in_tile) { | |||
#define cb(_bmode, _nonline_op, ...) \ | |||
OutputTransform4X5<_bmode MEGDNN_COMMA _nonline_op>::transform(__VA_ARGS__); | |||
auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE); | |||
for (size_t oc = oc_start; oc < oc_end; oc++) { | |||
size_t oc_index = oc - oc_start; | |||
rep(unit_idx, nr_units_in_tile) { | |||
size_t index = unit_start_idx + unit_idx; | |||
auto nh = index / units_w; | |||
auto nw = index % units_w; | |||
size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | |||
size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | |||
DISPATCH_CONV_WINOGRAD_BIAS( | |||
megdnn_arm_common_winograd_fp16_F45, cb, __fp16, __fp16, bmode, | |||
nonline_mode, output_transform_buf, bias, output, transform_mid_buf, | |||
oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx, | |||
nr_units_in_tile, src_dtype, dst_dtype); | |||
} | |||
} | |||
#undef cb | |||
} | |||
} // namespace winograd | |||
} // namespace arm_common | |||
} // namespace megdnn | |||
#endif | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,608 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/f16/strategy_6x3.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/f16/helper.h" | |||
#include "src/arm_common/conv_bias/f16/strategy.h" | |||
#include "src/arm_common/elemwise_helper/op_unary.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/arm_common/utils.h" | |||
#include "src/common/unroll_macro.h" | |||
#include "src/common/utils.h" | |||
#include "src/fallback/conv_bias/winograd/winograd.h" | |||
#include "src/naive/matrix_mul/matrix_mul_helper.h" | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
#include "midout.h" | |||
MIDOUT_DECL(megdnn_arm_common_winograd_fp16_F63) | |||
using namespace megdnn; | |||
using namespace arm_common; | |||
namespace { | |||
struct FilterTransform6X3 { | |||
// 1.0000000 0.0000000 0.0000000 | |||
//-0.2222222 -0.2222222 -0.2222222 | |||
//-0.2222222 0.2222222 -0.2222222 | |||
// 0.0111111 0.0222222 0.0444444 | |||
// 0.0111111 -0.0222222 0.0444444 | |||
// 0.7111111 0.3555556 0.1777778 | |||
// 0.7111111 -0.3555556 0.1777778 | |||
// 0.0000000 0.0000000 1.0000000 | |||
#define FILTER_TRANSFORM(d, wd) \ | |||
do { \ | |||
wd##0 = d##0; \ | |||
wd##1 = (d##0 + d##1 + d##2) * -0.222168; \ | |||
wd##2 = (d##0 - d##1 + d##2) * -0.222168; \ | |||
auto tmpd0 = d##0 * 0.011108; \ | |||
auto tmpd1 = d##1 * 0.022217; \ | |||
auto tmpd2 = d##2 * 0.044434; \ | |||
wd##3 = tmpd0 + tmpd1 + tmpd2; \ | |||
wd##4 = tmpd0 - tmpd1 + tmpd2; \ | |||
tmpd0 = d##0 * 0.710938; \ | |||
tmpd1 = d##1 * 0.355469; \ | |||
tmpd2 = d##2 * 0.177734; \ | |||
wd##5 = tmpd0 + tmpd1 + tmpd2; \ | |||
wd##6 = tmpd0 - tmpd1 + tmpd2; \ | |||
wd##7 = d##2; \ | |||
} while (0); | |||
#define FILTER_TRANSFORM_FINAL(d, wd) \ | |||
do { \ | |||
wd##0 = d##0; \ | |||
wd##1 = (d##0 + d##1 + d##2) * -0.222168; \ | |||
wd##2 = (d##0 - d##1 + d##2) * -0.222168; \ | |||
auto tmp0 = d##0 * 0.011108 + d##2 * 0.044434; \ | |||
auto tmp1 = d##1 * 0.022217; \ | |||
wd##3 = tmp0 + tmp1; \ | |||
wd##4 = tmp0 - tmp1; \ | |||
tmp0 = d##0 * 0.710938 + d##2 * 0.177734; \ | |||
tmp1 = d##1 * 0.355469; \ | |||
wd##5 = tmp0 + tmp1; \ | |||
wd##6 = tmp0 - tmp1; \ | |||
wd##7 = d##2; \ | |||
} while (0); | |||
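// FILTER_TRANSFORM applies one pass of G to the three filter rows; | |||
// FILTER_TRANSFORM_FINAL is the same pass with shared products folded | |||
// together, used after TRANSPOSE_8x4 on the aarch64 path to complete | |||
// G * g * GT (the armv7 path expands the second pass scalar-wise). | |||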
static void transform(const __fp16* filter, __fp16* filter_transform_buf, | |||
__fp16* transform_mid_buf, size_t OC, size_t IC, | |||
size_t oc_start, size_t oc_end) { | |||
// Gg * GT | |||
// G | |||
// 1.0000000 0.0000000 0.0000000 | |||
//-0.2222222 -0.2222222 -0.2222222 | |||
//-0.2222222 0.2222222 -0.2222222 | |||
// 0.0111111 0.0222222 0.0444444 | |||
// 0.0111111 -0.0222222 0.0444444 | |||
// 0.7111111 0.3555556 0.1777778 | |||
// 0.7111111 -0.3555556 0.1777778 | |||
// 0.0000000 0.0000000 1.0000000 | |||
constexpr size_t alpha = 6 + 3 - 1; | |||
for (size_t oc = oc_start; oc < oc_end; oc++) { | |||
rep(ic, IC) { | |||
const __fp16* fptr = filter + (oc * IC + ic) * 3 * 3; | |||
float16x4_t v0, v1, v2, v3; | |||
v0 = vld1_f16(fptr); // 0 1 2 3 | |||
v1 = vld1_f16(fptr + 3); // 3 4 5 6 | |||
v2 = vld1_f16(fptr + 5); // 5 6 7 8 | |||
v3 = vdup_n_f16(0); | |||
v2 = vext_f16(v2, v3, 1); | |||
v0 = vset_lane_f16(0, v0, 3); | |||
v1 = vset_lane_f16(0, v1, 3); | |||
#define cb(i) Vector<__fp16, 4> g##i(v##i); | |||
UNROLL_CALL_NOWRAPPER(3, cb); | |||
#undef cb | |||
#define cb(i) Vector<__fp16, 4> Gg##i; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
#define cb(i) Vector<__fp16, 8> Ggt##i; | |||
UNROLL_CALL_NOWRAPPER(4, cb); | |||
#undef cb | |||
#define cb(i) Vector<__fp16, 8> result##i; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
FILTER_TRANSFORM(g, Gg) | |||
#if MEGDNN_AARCH64 | |||
TRANSPOSE_8x4(Gg, Ggt); | |||
FILTER_TRANSFORM_FINAL(Ggt, result); | |||
#define cb(i) result##i.save(transform_mid_buf + i * alpha); | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
rep(i, alpha) rep(j, alpha) { | |||
filter_transform_buf[(i * alpha + j) * OC * IC + ic * OC + | |||
oc] = transform_mid_buf[j * alpha + i]; | |||
} | |||
#else | |||
/* GT (G above, transposed): | |||
1.0000000 -0.2222222 -0.2222222 0.0111111  0.0111111 0.7111111  0.7111111 0.0000000 | |||
0.0000000 -0.2222222  0.2222222 0.0222222 -0.0222222 0.3555556 -0.3555556 0.0000000 | |||
0.0000000 -0.2222222 -0.2222222 0.0444444  0.0444444 0.1777778  0.1777778 1.0000000 */ | |||
#define GET_VECTOR_FP16D_ELEM(s, i, idx) vget_lane_f16(CONCAT(s, i).value, idx) | |||
#define cb(i) \ | |||
do { \ | |||
mid_buf1[0] = GET_VECTOR_FP16D_ELEM(Gg, i, 0); \ | |||
auto tmp02 = GET_VECTOR_FP16D_ELEM(Gg, i, 0) + \ | |||
GET_VECTOR_FP16D_ELEM(Gg, i, 2); \ | |||
mid_buf1[1] = (tmp02 + GET_VECTOR_FP16D_ELEM(Gg, i, 1)) * -0.2222222; \ | |||
mid_buf1[2] = (tmp02 - GET_VECTOR_FP16D_ELEM(Gg, i, 1)) * -0.2222222; \ | |||
auto tmp0 = GET_VECTOR_FP16D_ELEM(Gg, i, 0) * 0.0111111; \ | |||
auto tmp1 = GET_VECTOR_FP16D_ELEM(Gg, i, 1) * 0.0222222; \ | |||
auto tmp2 = GET_VECTOR_FP16D_ELEM(Gg, i, 2) * 0.0444444; \ | |||
tmp02 = tmp0 + tmp2; \ | |||
mid_buf1[3] = tmp02 + tmp1; \ | |||
mid_buf1[4] = tmp02 - tmp1; \ | |||
tmp0 = GET_VECTOR_FP16D_ELEM(Gg, i, 0) * 0.7111111; \ | |||
tmp1 = GET_VECTOR_FP16D_ELEM(Gg, i, 1) * 0.3555556; \ | |||
tmp2 = GET_VECTOR_FP16D_ELEM(Gg, i, 2) * 0.1777778; \ | |||
tmp02 = tmp0 + tmp2; \ | |||
mid_buf1[5] = tmp02 + tmp1; \ | |||
mid_buf1[6] = tmp02 - tmp1; \ | |||
mid_buf1[7] = GET_VECTOR_FP16D_ELEM(Gg, i, 2); \ | |||
mid_buf1 += 8; \ | |||
} while (0); | |||
__fp16* mid_buf1 = transform_mid_buf; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
mid_buf1 = transform_mid_buf; | |||
#undef cb | |||
rep(i, alpha) rep(j, alpha) { | |||
filter_transform_buf[(i * alpha + j) * OC * IC + ic * OC + | |||
oc] = transform_mid_buf[i * alpha + j]; | |||
} | |||
#undef GET_VECTOR_FP16D_ELEM | |||
#endif | |||
} | |||
} | |||
} | |||
}; | |||
#undef FILTER_TRANSFORM | |||
#undef FILTER_TRANSFORM_FINAL | |||
/** | |||
* input transform | |||
* | |||
* wd0 = (d0 - d6) + 5.25 * (d4 - d2) | |||
* wd1 = (d6 + d2 - 4.25 * d4) + (d1 + d5 - 4.25 * d3) | |||
* wd2 = (d6 + d2 - 4.25 * d4) - (d1 + d5 - 4.25 * d3) | |||
* wd3 = (d6 + 0.25 * d2 - 1.25 * d4) + 2.0 * (d5 + 0.25 * d1 - 1.25 * d3) | |||
* wd4 = (d6 + 0.25 * d2 - 1.25 * d4) - 2.0 * (d5 + 0.25 * d1 - 1.25 * d3) | |||
* wd5 = (d6 - 5.0 * d4 + 4.0 * d2) + 2.0 * (d1 + 0.25 * d5 - 1.25 * d3) | |||
* wd6 = (d6 - 5.0 * d4 + 4.0 * d2) - 2.0 * (d1 + 0.25 * d5 - 1.25 * d3) | |||
* wd7 = (d7 - d1) + 5.25 * (d3 - d5) | |||
*/ | |||
#define INPUT_TRANSFORM(d, wd) \ | |||
do { \ | |||
wd##0 = (d##0 - d##6) + (d##4 - d##2) * 5.25; \ | |||
auto tmp0 = d##6 + d##2 - d##4 * 4.25; \ | |||
auto tmp1 = d##1 + d##5 - d##3 * 4.25; \ | |||
wd##1 = tmp0 + tmp1; \ | |||
wd##2 = tmp0 - tmp1; \ | |||
tmp0 = d##6 + d##2 * 0.25 - d##4 * 1.25; \ | |||
tmp1 = (d##5 + d##1 * 0.25 - d##3 * 1.25) * 2.0; \ | |||
wd##3 = tmp0 + tmp1; \ | |||
wd##4 = tmp0 - tmp1; \ | |||
tmp0 = d##6 - d##4 * 5.0 + d##2 * 4.0; \ | |||
tmp1 = (d##1 + d##5 * 0.25 - d##3 * 1.25) * 2.0; \ | |||
wd##5 = tmp0 + tmp1; \ | |||
wd##6 = tmp0 - tmp1; \ | |||
wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25; \ | |||
} while (0); | |||
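// On aarch64 this macro is applied twice (to the eight input rows, then | |||
// to the columns after TRANSPOSE_8x8), computing BT * d * B for one 8x8 | |||
// tile; the armv7 path below expands the second pass scalar-wise. | |||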
#define GET_VECTOR_FP16Q_ELEM(s, i, idx) vgetq_lane_f16(CONCAT(s, i).value, idx) | |||
struct InputTransform6x3 { | |||
template <bool inner> | |||
static void transform(const __fp16* input, __fp16* input_transform_buf, | |||
__fp16* transform_mid_buf, int ih_start, int iw_start, | |||
size_t ic, size_t IH, size_t IW, size_t IC, | |||
size_t unit_idx, size_t nr_units_in_tile) { | |||
// BTd * B | |||
// 1.000 0.000 -5.25 0.000 5.250 0.000 -1.0 0.00 | |||
// -0.00 1.000 1.000 -4.25 -4.25 1.000 1.00 -0.0 | |||
// -0.00 -1.00 1.000 4.250 -4.25 -1.00 1.00 -0.0 | |||
// 0.000 0.500 0.250 -2.50 -1.25 2.000 1.00 0.00 | |||
// 0.000 -0.50 0.250 2.500 -1.25 -2.00 1.00 0.00 | |||
// 0.000 2.000 4.000 -2.50 -5.00 0.500 1.00 0.00 | |||
// 0.000 -2.00 4.000 2.500 -5.00 -0.50 1.00 0.00 | |||
// 0.000 -1.00 0.000 5.250 0.000 -5.25 0.00 1.00 | |||
constexpr size_t alpha = 6 + 3 - 1; | |||
if (!inner) { | |||
memset(transform_mid_buf, 0, sizeof(__fp16) * alpha * alpha); | |||
} | |||
#define cb(i) Vector<__fp16, 8> d##i; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
if (inner) { | |||
const __fp16* input_ptr = | |||
input + ic * IH * IW + ih_start * IW + iw_start; | |||
#define cb(i) d##i = Vector<__fp16, 8>::load(input_ptr + IW * i); | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
} else { | |||
int ih0_act = std::max<int>(ih_start, 0), | |||
ih1_act = std::min<int>(ih_start + alpha, IH), | |||
iw0_act = std::max<int>(iw_start, 0), | |||
iw1_act = std::min<int>(iw_start + alpha, IW); | |||
for (int ih = ih0_act; ih < ih1_act; ++ih) { | |||
for (int iw = iw0_act; iw < iw1_act; ++iw) { | |||
size_t iho = ih - ih_start, iwo = iw - iw_start; | |||
transform_mid_buf[iho * alpha + iwo] = | |||
input[ic * IH * IW + ih * IW + iw]; | |||
} | |||
} | |||
#define cb(i) d##i = Vector<__fp16, 8>::load(transform_mid_buf + alpha * i); | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
} | |||
#define cb(i) Vector<__fp16, 8> wd##i, ret##i; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
INPUT_TRANSFORM(d, wd); | |||
#if MEGDNN_AARCH64 | |||
TRANSPOSE_8x8(wd, d); | |||
INPUT_TRANSFORM(d, ret); | |||
#define cb(i) ret##i.save(transform_mid_buf + i * alpha); | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
rep(i, alpha) rep(j, alpha) { | |||
input_transform_buf[(i * alpha + j) * nr_units_in_tile * IC + | |||
unit_idx * IC + ic] = | |||
transform_mid_buf[j * alpha + i]; | |||
} | |||
#else | |||
//! 1 0 0 0 0 0 0 0 | |||
//! 0 1 -1 0.5 -0.5 2 -2 -1 | |||
//! -5.25 1 1 0.25 0.25 4 4 0 | |||
//! 0 -4.25 4.25 -2.5 2.5 -2.5 2.5 5.25 | |||
//! 5.25 -4.25 -4.25 -1.25 -1.25 -5 -5 0 | |||
//! 0 1 -1 2 -2 0.5 -0.5 -5.25 | |||
//! -1 1 1 1 1 1 1 0 | |||
//! 0 0 0 0 0 0 0 1 | |||
#define cb(i) \ | |||
do { \ | |||
mid_buf1[0] = GET_VECTOR_FP16Q_ELEM(wd, i, 0) - \ | |||
GET_VECTOR_FP16Q_ELEM(wd, i, 6) + \ | |||
5.25 * (GET_VECTOR_FP16Q_ELEM(wd, i, 4) - \ | |||
GET_VECTOR_FP16Q_ELEM(wd, i, 2)); \ | |||
mid_buf1[7] = GET_VECTOR_FP16Q_ELEM(wd, i, 7) - \ | |||
GET_VECTOR_FP16Q_ELEM(wd, i, 1) + \ | |||
5.25 * (GET_VECTOR_FP16Q_ELEM(wd, i, 3) - \ | |||
GET_VECTOR_FP16Q_ELEM(wd, i, 5)); \ | |||
auto tmp0 = GET_VECTOR_FP16Q_ELEM(wd, i, 2) + \ | |||
GET_VECTOR_FP16Q_ELEM(wd, i, 6) - \ | |||
4.25 * GET_VECTOR_FP16Q_ELEM(wd, i, 4); \ | |||
auto tmp1 = GET_VECTOR_FP16Q_ELEM(wd, i, 1) - \ | |||
GET_VECTOR_FP16Q_ELEM(wd, i, 3) * 4.25 + \ | |||
GET_VECTOR_FP16Q_ELEM(wd, i, 5); \ | |||
mid_buf1[1] = tmp0 + tmp1; \ | |||
mid_buf1[2] = tmp0 - tmp1; \ | |||
tmp0 = GET_VECTOR_FP16Q_ELEM(wd, i, 2) * 0.25 + \ | |||
GET_VECTOR_FP16Q_ELEM(wd, i, 6) - \ | |||
GET_VECTOR_FP16Q_ELEM(wd, i, 4) * 1.25; \ | |||
tmp1 = GET_VECTOR_FP16Q_ELEM(wd, i, 1) * 0.5 - \ | |||
GET_VECTOR_FP16Q_ELEM(wd, i, 3) * 2.5 + \ | |||
GET_VECTOR_FP16Q_ELEM(wd, i, 5) * 2; \ | |||
mid_buf1[3] = tmp0 + tmp1; \ | |||
mid_buf1[4] = tmp0 - tmp1; \ | |||
tmp0 = GET_VECTOR_FP16Q_ELEM(wd, i, 6) + \ | |||
GET_VECTOR_FP16Q_ELEM(wd, i, 2) * 4.0 - \ | |||
GET_VECTOR_FP16Q_ELEM(wd, i, 4) * 5.0; \ | |||
tmp1 = GET_VECTOR_FP16Q_ELEM(wd, i, 1) * 2 - \ | |||
GET_VECTOR_FP16Q_ELEM(wd, i, 3) * 2.5 + \ | |||
GET_VECTOR_FP16Q_ELEM(wd, i, 5) * 0.5; \ | |||
mid_buf1[5] = tmp0 + tmp1; \ | |||
mid_buf1[6] = tmp0 - tmp1; \ | |||
mid_buf1 += 8; \ | |||
} while (0); | |||
__fp16* mid_buf1 = transform_mid_buf; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
mid_buf1 = transform_mid_buf; | |||
#undef cb | |||
rep(i, alpha) rep(j, alpha) { | |||
input_transform_buf[(i * alpha + j) * nr_units_in_tile * IC + | |||
unit_idx * IC + ic] = | |||
transform_mid_buf[i * alpha + j]; | |||
} | |||
#endif | |||
} | |||
}; | |||
#undef INPUT_TRANSFORM | |||
#define OUTPUT_TRANSFORM(m, r) \ | |||
do { \ | |||
auto m1addm2 = m##1 + m##2; \ | |||
auto m1subm2 = m##1 - m##2; \ | |||
auto m3addm4 = m##3 + m##4; \ | |||
auto m3subm4 = m##3 - m##4; \ | |||
auto m5addm6 = m##5 + m##6; \ | |||
auto m5subm6 = m##5 - m##6; \ | |||
r##0 = m##0 + m1addm2 + m3addm4 + m5addm6; \ | |||
r##1 = m1subm2 + m3subm4 * 2.0 + m5subm6 * 0.5; \ | |||
r##2 = m1addm2 + m3addm4 * 4.0 + m5addm6 * 0.25; \ | |||
r##3 = m1subm2 + m3subm4 * 8.0 + m5subm6 * 0.125; \ | |||
r##4 = m1addm2 + m3addm4 * 16.0 + m5addm6 * 0.0625; \ | |||
r##5 = m1subm2 + m3subm4 * 32.0 + m5subm6 * 0.03125 + m##7; \ | |||
} while (0) | |||
template <BiasMode bmode, typename Op> | |||
struct OutputTransform6X3 { | |||
static void transform(const dt_float16* output_transform_buf, | |||
const dt_float16* bias, dt_float16* output, | |||
dt_float16* transform_mid_buf, size_t oh_start, | |||
size_t ow_start, size_t OH, size_t OW, | |||
size_t oc_start, size_t oc_end, size_t oc_index, | |||
size_t unit_idx, size_t nr_units_in_tile, | |||
const DType& src_dtype, const DType& dst_dtype) { | |||
Op op(src_dtype, dst_dtype); | |||
//! AT * m * A | |||
// AT f63 | |||
// 1.0 1.0 1.0 1.0 1.0 1.00000 1.00000 0.0 | |||
// 0.0 1.0 -1.0 2.0 -2.0 0.50000 -0.50000 0.0 | |||
// 0.0 1.0 1.0 4.0 4.0 0.25000 0.25000 0.0 | |||
// 0.0 1.0 -1.0 8.0 -8.0 0.12500 -0.12500 0.0 | |||
// 0.0 1.0 1.0 16.0 16.0 0.06250 0.06250 0.0 | |||
// 0.0 1.0 -1.0 32.0 -32.0 0.03125 -0.03125 1.0 | |||
constexpr size_t alpha = 6 + 3 - 1; | |||
const __fp16* fp16_output_transform_buf = | |||
reinterpret_cast<const __fp16*>(output_transform_buf); | |||
const __fp16* fp16_bias = reinterpret_cast<const __fp16*>(bias); | |||
__fp16* fp16_output = reinterpret_cast<__fp16*>(output); | |||
__fp16* fp16_transform_mid_buf = | |||
reinterpret_cast<__fp16*>(transform_mid_buf); | |||
__fp16* mid_buf1 = fp16_transform_mid_buf; | |||
size_t OC = oc_end - oc_start; | |||
size_t oc = oc_start + oc_index; | |||
#define cb(m, n) \ | |||
fp16_transform_mid_buf[m * alpha + n] = \ | |||
fp16_output_transform_buf[(m * alpha + n) * nr_units_in_tile * \ | |||
OC + \ | |||
unit_idx * OC + oc_index]; | |||
UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); | |||
#undef cb | |||
#define cb(i) \ | |||
auto m##i = Vector<__fp16, 8>::load(fp16_transform_mid_buf + alpha * i); | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
#define cb(i) Vector<__fp16, 8> s##i; | |||
UNROLL_CALL_NOWRAPPER(6, cb); | |||
#undef cb | |||
/* 1.0 0.0 0.00 0.000 0.0000 0.00000 | |||
1.0 1.0 1.00 1.000 1.0000 1.00000 | |||
1.0 -1.0 1.00 -1.000 1.0000 -1.00000 | |||
1.0 2.0 4.00 8.000 16.000 32.00000 | |||
1.0 -2.0 4.00 -8.000 16.000 -32.00000 | |||
1.0 0.5 0.25 0.125 0.0625 0.03125 | |||
1.0 -0.5 0.25 -0.125 0.0625 -0.03125 | |||
0.0 0.0 0.00 0.000 0.0000 1.00000*/ | |||
OUTPUT_TRANSFORM(m, s); | |||
mid_buf1 = fp16_transform_mid_buf; | |||
#define cb(i) \ | |||
do { \ | |||
auto m1addm2 = GET_VECTOR_FP16Q_ELEM(s, i, 1) + \ | |||
GET_VECTOR_FP16Q_ELEM(s, i, 2); \ | |||
auto m1subm2 = GET_VECTOR_FP16Q_ELEM(s, i, 1) - \ | |||
GET_VECTOR_FP16Q_ELEM(s, i, 2); \ | |||
auto m3addm4 = GET_VECTOR_FP16Q_ELEM(s, i, 3) + \ | |||
GET_VECTOR_FP16Q_ELEM(s, i, 4); \ | |||
auto m3subm4 = GET_VECTOR_FP16Q_ELEM(s, i, 3) - \ | |||
GET_VECTOR_FP16Q_ELEM(s, i, 4); \ | |||
auto m5addm6 = GET_VECTOR_FP16Q_ELEM(s, i, 5) + \ | |||
GET_VECTOR_FP16Q_ELEM(s, i, 6); \ | |||
auto m5subm6 = GET_VECTOR_FP16Q_ELEM(s, i, 5) - \ | |||
GET_VECTOR_FP16Q_ELEM(s, i, 6); \ | |||
mid_buf1[0] = \ | |||
GET_VECTOR_FP16Q_ELEM(s, i, 0) + m1addm2 + m3addm4 + m5addm6; \ | |||
mid_buf1[1] = m1subm2 + m3subm4 * 2 + m5subm6 * 0.5; \ | |||
mid_buf1[2] = m1addm2 + m3addm4 * 4 + m5addm6 * 0.25; \ | |||
mid_buf1[3] = m1subm2 + m3subm4 * 8 + m5subm6 * 0.125; \ | |||
mid_buf1[4] = m1addm2 + m3addm4 * 16 + m5addm6 * 0.0625; \ | |||
mid_buf1[5] = m1subm2 + m3subm4 * 32 + m5subm6 * 0.03125 + \ | |||
GET_VECTOR_FP16Q_ELEM(s, i, 7); \ | |||
mid_buf1 += 6; \ | |||
} while (0); | |||
mid_buf1 = fp16_transform_mid_buf; | |||
UNROLL_CALL_NOWRAPPER(6, cb); | |||
mid_buf1 = fp16_transform_mid_buf; | |||
#undef cb | |||
if (oh_start + 6 <= OH && ow_start + 6 <= OW) { | |||
int index = (oc * OH + oh_start) * OW + ow_start; | |||
#define cb(i) float16x4_t vr##i = vld1_f16(mid_buf1 + i * 6); | |||
UNROLL_CALL_NOWRAPPER(6, cb); | |||
#undef cb | |||
float16x8_t vr0123_45 = {mid_buf1[4], mid_buf1[5], mid_buf1[10], | |||
mid_buf1[11], mid_buf1[16], mid_buf1[17], | |||
mid_buf1[22], mid_buf1[23]}; | |||
float16x4_t vr45_45 = {mid_buf1[28], mid_buf1[29], mid_buf1[34], | |||
mid_buf1[35]}; | |||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
float16x4_t bias0 = vdup_n_f16(fp16_bias[oc]); | |||
#define cb(i) vr##i = vadd_f16(vr##i, bias0); | |||
UNROLL_CALL_NOWRAPPER(6, cb); | |||
#undef cb | |||
vr45_45 = vadd_f16(vr45_45, bias0); | |||
vr0123_45 = vaddq_f16(vr0123_45, vcombine_f16(bias0, bias0)); | |||
} else if (bmode == BiasMode::BIAS) { | |||
#define cb(i) float16x4_t bmbias##i = vld1_f16(fp16_bias + index + OW * i); | |||
UNROLL_CALL_NOWRAPPER(6, cb); | |||
#undef cb | |||
#define cb(i) vr##i = vadd_f16(vr##i, bmbias##i); | |||
UNROLL_CALL_NOWRAPPER(6, cb); | |||
#undef cb | |||
float16x8_t vb0123_45 = {fp16_bias[index + 0 * OW + 4], | |||
fp16_bias[index + 0 * OW + 5], | |||
fp16_bias[index + 1 * OW + 4], | |||
fp16_bias[index + 1 * OW + 5], | |||
fp16_bias[index + 2 * OW + 4], | |||
fp16_bias[index + 2 * OW + 5], | |||
fp16_bias[index + 3 * OW + 4], | |||
fp16_bias[index + 3 * OW + 5]}; | |||
float16x4_t vb45_45 = {fp16_bias[index + 4 * OW + 4], | |||
fp16_bias[index + 4 * OW + 5], | |||
fp16_bias[index + 5 * OW + 4], | |||
fp16_bias[index + 5 * OW + 5]}; | |||
vr45_45 = vadd_f16(vr45_45, vb45_45); | |||
vr0123_45 = vaddq_f16(vr0123_45, vb0123_45); | |||
} | |||
float16x8_t item01 = op(vcombine_f16(vr0, vr1)); | |||
float16x8_t item23 = op(vcombine_f16(vr2, vr3)); | |||
float16x8_t item45 = op(vcombine_f16(vr4, vr5)); | |||
vst1_f16(fp16_output + index, vget_low_f16(item01)); | |||
vst1_f16(fp16_output + index + OW, vget_high_f16(item01)); | |||
vst1_f16(fp16_output + index + OW * 2, vget_low_f16(item23)); | |||
vst1_f16(fp16_output + index + OW * 3, vget_high_f16(item23)); | |||
vst1_f16(fp16_output + index + OW * 4, vget_low_f16(item45)); | |||
vst1_f16(fp16_output + index + OW * 5, vget_high_f16(item45)); | |||
vr0123_45 = op(vr0123_45); | |||
float16x8_t vr45 = op(vcombine_f16(vr45_45, vr45_45)); | |||
fp16_output[index + OW * 0 + 4] = vgetq_lane_f16(vr0123_45, 0); | |||
fp16_output[index + OW * 0 + 5] = vgetq_lane_f16(vr0123_45, 1); | |||
fp16_output[index + OW * 1 + 4] = vgetq_lane_f16(vr0123_45, 2); | |||
fp16_output[index + OW * 1 + 5] = vgetq_lane_f16(vr0123_45, 3); | |||
fp16_output[index + OW * 2 + 4] = vgetq_lane_f16(vr0123_45, 4); | |||
fp16_output[index + OW * 2 + 5] = vgetq_lane_f16(vr0123_45, 5); | |||
fp16_output[index + OW * 3 + 4] = vgetq_lane_f16(vr0123_45, 6); | |||
fp16_output[index + OW * 3 + 5] = vgetq_lane_f16(vr0123_45, 7); | |||
fp16_output[index + OW * 4 + 4] = vgetq_lane_f16(vr45, 0); | |||
fp16_output[index + OW * 4 + 5] = vgetq_lane_f16(vr45, 1); | |||
fp16_output[index + OW * 5 + 4] = vgetq_lane_f16(vr45, 2); | |||
fp16_output[index + OW * 5 + 5] = vgetq_lane_f16(vr45, 3); | |||
} else { | |||
for (size_t oho = 0; oho < 6 && oh_start + oho < OH; ++oho) { | |||
for (size_t owo = 0; owo < 6 && ow_start + owo < OW; ++owo) { | |||
size_t oh = oh_start + oho; | |||
size_t ow = ow_start + owo; | |||
__fp16 res = mid_buf1[oho * 6 + owo]; | |||
if (bmode == BiasMode::BIAS) { | |||
res += fp16_bias[oc * OH * OW + oh * OW + ow]; | |||
} else if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
res += fp16_bias[oc]; | |||
} | |||
res = op(res); | |||
fp16_output[oc * OH * OW + oh * OW + ow] = res; | |||
} | |||
} | |||
} | |||
} | |||
}; | |||
#undef GET_VECTOR_FP16Q_ELEM | |||
#undef OUTPUT_TRANSFORM | |||
} // namespace | |||
namespace megdnn { | |||
namespace arm_common { | |||
namespace winograd { | |||
MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_6x3_1x1_f16) | |||
void winograd_6x3_1x1_f16::filter(const dt_float16* filter, | |||
dt_float16* filter_transform_buf, | |||
dt_float16* transform_mid_buf, size_t OC, | |||
size_t IC, size_t oc_start, size_t oc_end) { | |||
FilterTransform6X3::transform( | |||
reinterpret_cast<const __fp16*>(filter), | |||
reinterpret_cast<__fp16*>(filter_transform_buf), | |||
reinterpret_cast<__fp16*>(transform_mid_buf), OC, IC, oc_start, | |||
oc_end); | |||
} | |||
void winograd_6x3_1x1_f16::input(const dt_float16* input, | |||
dt_float16* input_transform_buf, | |||
dt_float16* transform_mid_buf, size_t IH, | |||
size_t IW, size_t IC, size_t PH, size_t PW, | |||
size_t unit_start_idx, | |||
size_t nr_units_in_tile) { | |||
constexpr int alpha = 6 + 3 - 1; | |||
// OW = IW + 2 * PW - KERNEL_SIZE + 1 | |||
auto units_w = div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); | |||
rep(ic, IC) { | |||
rep(unit_idx, nr_units_in_tile) { | |||
size_t index = unit_start_idx + unit_idx; | |||
size_t nh = index / units_w; | |||
size_t nw = index % units_w; | |||
int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; | |||
int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; | |||
if (ih_start >= 0 && ih_start + alpha <= static_cast<int>(IH) && | |||
iw_start >= 0 && iw_start + alpha <= static_cast<int>(IW)) { | |||
InputTransform6x3::transform<true>( | |||
reinterpret_cast<const __fp16*>(input), | |||
reinterpret_cast<__fp16*>(input_transform_buf), | |||
reinterpret_cast<__fp16*>(transform_mid_buf), ih_start, | |||
iw_start, ic, IH, IW, IC, unit_idx, nr_units_in_tile); | |||
} else { | |||
InputTransform6x3::transform<false>( | |||
reinterpret_cast<const __fp16*>(input), | |||
reinterpret_cast<__fp16*>(input_transform_buf), | |||
reinterpret_cast<__fp16*>(transform_mid_buf), ih_start, | |||
iw_start, ic, IH, IW, IC, unit_idx, nr_units_in_tile); | |||
} | |||
} | |||
} | |||
} | |||
void winograd_6x3_1x1_f16::output(const dt_float16* output_transform_buf, | |||
const dt_float16* bias, dt_float16* output, | |||
dt_float16* transform_mid_buf, BiasMode bmode, | |||
NonlineMode nonline_mode, size_t OH, size_t OW, | |||
size_t oc_start, size_t oc_end, | |||
size_t unit_start_idx, size_t nr_units_in_tile) { | |||
#define cb(_bmode, _nonline_op, ...) \ | |||
OutputTransform6X3<_bmode MEGDNN_COMMA _nonline_op>::transform(__VA_ARGS__); | |||
auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE); | |||
for (size_t oc = oc_start; oc < oc_end; oc++) { | |||
size_t oc_index = oc - oc_start; | |||
rep(unit_idx, nr_units_in_tile) { | |||
size_t index = unit_start_idx + unit_idx; | |||
auto nh = index / units_w; | |||
auto nw = index % units_w; | |||
size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | |||
size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | |||
DISPATCH_CONV_WINOGRAD_BIAS( | |||
megdnn_arm_common_winograd_fp16_F63, cb, __fp16, __fp16, bmode, | |||
nonline_mode, output_transform_buf, bias, output, transform_mid_buf, | |||
oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx, | |||
nr_units_in_tile, src_dtype, dst_dtype); | |||
} | |||
} | |||
#undef cb | |||
} | |||
} // namespace winograd | |||
} // namespace arm_common | |||
} // namespace megdnn | |||
#endif | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,754 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/fp32/algos.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/algos.h" | |||
#include "src/arm_common/conv_bias/direct/multi_thread_common.h" | |||
#include "src/arm_common/conv_bias/fp32/direct.h" | |||
#include "src/arm_common/conv_bias/fp32/do_conv_stride1.h" | |||
#include "src/arm_common/conv_bias/fp32/do_conv_stride2.h" | |||
#include "src/arm_common/conv_bias/fp32/strategy.h" | |||
#include "src/arm_common/conv_bias/img2col_helper.h" | |||
#include "src/arm_common/conv_bias/postprocess_helper.h" | |||
#include "src/common/opr_delegate.h" | |||
#include "src/fallback/conv_bias/common.h" | |||
#include "midout.h" | |||
MIDOUT_DECL(megdnn_arm_common_winograd_fp32) | |||
using namespace megdnn; | |||
using namespace arm_common; | |||
/* ======================= AlgoFP32WinogradF23_4x4 ======================== */ | |||
bool ConvBiasImpl::AlgoFP32WinogradF23_4x4::usable( | |||
fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy /*algo_selection_strategy*/) const { | |||
MEGDNN_MARK_USED_VAR(opr); | |||
MEGDNN_MARK_USED_VAR(param); | |||
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 0, 0) { | |||
if (param.filter_meta.icpg % 4 != 0 || param.filter_meta.ocpg % 4 != 0) | |||
return false; | |||
using Strategy = winograd::winograd_2x3_4x4_f; | |||
Strategy strategy(param.src_type, param.filter_type, param.dst_type); | |||
auto&& matmul_param = | |||
megdnn::winograd::ConvBias<Strategy, | |||
param::MatrixMul::Format::MK4>( | |||
strategy, m_tile_size, param.nr_threads, param.osz[0], | |||
param.osz[1], param.filter_meta.ocpg) | |||
.get_matmul_kern_param(param); | |||
return m_matmul_algo->usable(matmul_param) && | |||
(opr->param().format == param::ConvBias::Format::NCHW || | |||
(opr->param().format == | |||
param::ConvBias::Format::NCHW_WINOGRAD && | |||
opr->param().output_block_size == 2 && | |||
param.winograd_matmul_format == | |||
param::MatrixMul::Format::MK4)) && | |||
opr->param().mode == param::ConvBias::Mode::CROSS_CORRELATION && | |||
(param.filter_meta.spatial[0] == param.filter_meta.spatial[1] && | |||
param.filter_meta.spatial[0] == 3) && | |||
(param.filter_meta.stride[0] == param.filter_meta.stride[1] && | |||
param.filter_meta.stride[0] == 1) && | |||
(param.filter_meta.dilation[0] == | |||
param.filter_meta.dilation[1] && | |||
param.filter_meta.dilation[0] == 1) && | |||
param.compute_mode == param::ConvBias::ComputeMode::DEFAULT && | |||
param.src_type.enumv() == DTypeEnum::Float32; | |||
} | |||
MIDOUT_END(); | |||
return false; | |||
} | |||
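// Note on usable(): this strategy is selected only when the wrapped | |||
// matmul algorithm accepts the MK4-format kernel parameter, the format is | |||
// plain NCHW or NCHW_WINOGRAD pre-transformed with output_block_size == 2, | |||
// and the convolution is a 3x3, stride-1, undilated float32 | |||
// cross-correlation whose icpg/ocpg are divisible by 4. The F63/F54 | |||
// variants below follow the same pattern with their own block sizes and | |||
// matmul formats. | |||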
size_t ConvBiasImpl::AlgoFP32WinogradF23_4x4::get_workspace( | |||
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { | |||
MEGDNN_MARK_USED_VAR(param); | |||
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 0, 1) { | |||
winograd::winograd_2x3_4x4_f strategy(param.src_type, param.filter_type, | |||
param.dst_type); | |||
return megdnn::winograd::ConvBias<winograd::winograd_2x3_4x4_f, | |||
param::MatrixMul::Format::MK4>( | |||
strategy, m_tile_size, param.nr_threads, param.osz[0], | |||
param.osz[1], param.filter_meta.ocpg) | |||
.get_workspace_size(param, m_matmul_algo); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
SmallVector<ConvBiasImpl::NCBKern> | |||
ConvBiasImpl::AlgoFP32WinogradF23_4x4::dispatch_kerns( | |||
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { | |||
MEGDNN_MARK_USED_VAR(param); | |||
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 0, 2) { | |||
winograd::winograd_2x3_4x4_f strategy(param.src_type, param.filter_type, | |||
param.dst_type); | |||
auto winograd_impl = | |||
megdnn::winograd::ConvBias<winograd::winograd_2x3_4x4_f, | |||
param::MatrixMul::Format::MK4>( | |||
strategy, m_tile_size, param.nr_threads, param.osz[0], | |||
param.osz[1], param.filter_meta.ocpg); | |||
return winograd_impl.get_kerns(param, m_matmul_algo); | |||
} | |||
MIDOUT_END(); | |||
return {}; | |||
} | |||
/* ======================= AlgoFP32WinogradF63 ======================== */ | |||
bool ConvBiasImpl::AlgoFP32WinogradF63::usable( | |||
fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy /*algo_selection_strategy*/) const { | |||
MEGDNN_MARK_USED_VAR(param); | |||
MEGDNN_MARK_USED_VAR(opr); | |||
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 1, 0) { | |||
using Strategy = winograd::winograd_6x3_1x1_f; | |||
Strategy strategy(param.src_type, param.filter_type, param.dst_type); | |||
auto&& matmul_param = | |||
megdnn::winograd::ConvBias<Strategy>( | |||
strategy, m_tile_size, param.nr_threads, param.osz[0], | |||
param.osz[1], param.filter_meta.ocpg) | |||
.get_matmul_kern_param(param); | |||
return m_matmul_algo->usable(matmul_param) && | |||
(opr->param().format == param::ConvBias::Format::NCHW || | |||
(opr->param().format == | |||
param::ConvBias::Format::NCHW_WINOGRAD && | |||
opr->param().output_block_size == 6 && | |||
param.winograd_matmul_format == | |||
param::MatrixMul::Format::DEFAULT)) && | |||
opr->param().mode == param::ConvBias::Mode::CROSS_CORRELATION && | |||
(param.filter_meta.spatial[0] == param.filter_meta.spatial[1] && | |||
param.filter_meta.spatial[0] == 3) && | |||
(param.filter_meta.stride[0] == param.filter_meta.stride[1] && | |||
param.filter_meta.stride[0] == 1) && | |||
(param.filter_meta.dilation[0] == | |||
param.filter_meta.dilation[1] && | |||
param.filter_meta.dilation[0] == 1) && | |||
param.compute_mode == param::ConvBias::ComputeMode::DEFAULT && | |||
param.src_type.enumv() == DTypeEnum::Float32; | |||
} | |||
MIDOUT_END(); | |||
return false; | |||
} | |||
size_t ConvBiasImpl::AlgoFP32WinogradF63::get_workspace( | |||
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { | |||
MEGDNN_MARK_USED_VAR(param); | |||
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 1, 1) { | |||
winograd::winograd_6x3_1x1_f strategy(param.src_type, param.filter_type, | |||
param.dst_type); | |||
return megdnn::winograd::ConvBias<winograd::winograd_6x3_1x1_f>( | |||
strategy, m_tile_size, param.nr_threads, param.osz[0], | |||
param.osz[1], param.filter_meta.ocpg) | |||
.get_workspace_size(param, m_matmul_algo); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
SmallVector<ConvBiasImpl::NCBKern> | |||
ConvBiasImpl::AlgoFP32WinogradF63::dispatch_kerns( | |||
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { | |||
MEGDNN_MARK_USED_VAR(param); | |||
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 1, 2) { | |||
winograd::winograd_6x3_1x1_f strategy(param.src_type, param.filter_type, | |||
param.dst_type); | |||
auto winograd_impl = | |||
megdnn::winograd::ConvBias<winograd::winograd_6x3_1x1_f>( | |||
strategy, m_tile_size, param.nr_threads, param.osz[0], | |||
param.osz[1], param.filter_meta.ocpg); | |||
return winograd_impl.get_kerns(param, m_matmul_algo); | |||
} | |||
MIDOUT_END(); | |||
return {}; | |||
} | |||
/* ======================= AlgoFP32WinogradF54 ======================== */ | |||
bool ConvBiasImpl::AlgoFP32WinogradF54::usable( | |||
fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy /*algo_selection_strategy*/) const { | |||
MEGDNN_MARK_USED_VAR(param); | |||
MEGDNN_MARK_USED_VAR(opr); | |||
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 2, 0) { | |||
using Strategy = winograd::winograd_5x4_1x1_f; | |||
Strategy strategy(param.src_type, param.filter_type, param.dst_type); | |||
auto&& matmul_param = | |||
megdnn::winograd::ConvBias<Strategy>( | |||
strategy, m_tile_size, param.nr_threads, param.osz[0], | |||
param.osz[1], param.filter_meta.ocpg) | |||
.get_matmul_kern_param(param); | |||
return m_matmul_algo->usable(matmul_param) && | |||
(opr->param().format == param::ConvBias::Format::NCHW || | |||
(opr->param().format == | |||
param::ConvBias::Format::NCHW_WINOGRAD && | |||
opr->param().output_block_size == 5 && | |||
param.winograd_matmul_format == | |||
param::MatrixMul::Format::DEFAULT)) && | |||
opr->param().mode == param::ConvBias::Mode::CROSS_CORRELATION && | |||
(param.filter_meta.spatial[0] == param.filter_meta.spatial[1] && | |||
param.filter_meta.spatial[0] == 4) && | |||
(param.filter_meta.stride[0] == param.filter_meta.stride[1] && | |||
param.filter_meta.stride[0] == 1) && | |||
(param.filter_meta.dilation[0] == | |||
param.filter_meta.dilation[1] && | |||
param.filter_meta.dilation[0] == 1) && | |||
param.compute_mode == param::ConvBias::ComputeMode::DEFAULT && | |||
param.src_type.enumv() == DTypeEnum::Float32; | |||
} | |||
MIDOUT_END(); | |||
return false; | |||
} | |||
size_t ConvBiasImpl::AlgoFP32WinogradF54::get_workspace( | |||
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { | |||
MEGDNN_MARK_USED_VAR(param); | |||
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 2, 1) { | |||
winograd::winograd_5x4_1x1_f strategy(param.src_type, param.filter_type, | |||
param.dst_type); | |||
return megdnn::winograd::ConvBias<winograd::winograd_5x4_1x1_f>( | |||
strategy, m_tile_size, param.nr_threads, param.osz[0], | |||
param.osz[1], param.filter_meta.ocpg) | |||
.get_workspace_size(param, m_matmul_algo); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
SmallVector<ConvBiasImpl::NCBKern> | |||
ConvBiasImpl::AlgoFP32WinogradF54::dispatch_kerns( | |||
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { | |||
MEGDNN_MARK_USED_VAR(param); | |||
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 2, 2) { | |||
winograd::winograd_5x4_1x1_f strategy(param.src_type, param.filter_type, | |||
param.dst_type); | |||
auto winograd_impl = | |||
megdnn::winograd::ConvBias<winograd::winograd_5x4_1x1_f>( | |||
strategy, m_tile_size, param.nr_threads, param.osz[0], | |||
param.osz[1], param.filter_meta.ocpg); | |||
return winograd_impl.get_kerns(param, m_matmul_algo); | |||
} | |||
MIDOUT_END(); | |||
return {}; | |||
} | |||
/* ======================= AlgoFP32WinogradF45 ======================== */ | |||
bool ConvBiasImpl::AlgoFP32WinogradF45::usable( | |||
fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy /*algo_selection_strategy*/) const { | |||
MEGDNN_MARK_USED_VAR(param); | |||
MEGDNN_MARK_USED_VAR(opr); | |||
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 3, 0) { | |||
using Strategy = winograd::winograd_4x5_1x1_f; | |||
Strategy strategy(param.src_type, param.filter_type, param.dst_type); | |||
auto&& matmul_param = | |||
megdnn::winograd::ConvBias<Strategy>( | |||
strategy, m_tile_size, param.nr_threads, param.osz[0], | |||
param.osz[1], param.filter_meta.ocpg) | |||
.get_matmul_kern_param(param); | |||
return m_matmul_algo->usable(matmul_param) && | |||
(opr->param().format == param::ConvBias::Format::NCHW || | |||
(opr->param().format == | |||
param::ConvBias::Format::NCHW_WINOGRAD && | |||
opr->param().output_block_size == 4 && | |||
param.winograd_matmul_format == | |||
param::MatrixMul::Format::DEFAULT)) && | |||
opr->param().mode == param::ConvBias::Mode::CROSS_CORRELATION && | |||
(param.filter_meta.spatial[0] == param.filter_meta.spatial[1] && | |||
param.filter_meta.spatial[0] == 5) && | |||
(param.filter_meta.stride[0] == param.filter_meta.stride[1] && | |||
param.filter_meta.stride[0] == 1) && | |||
(param.filter_meta.dilation[0] == | |||
param.filter_meta.dilation[1] && | |||
param.filter_meta.dilation[0] == 1) && | |||
param.compute_mode == param::ConvBias::ComputeMode::DEFAULT && | |||
param.src_type.enumv() == DTypeEnum::Float32; | |||
} | |||
MIDOUT_END(); | |||
return false; | |||
} | |||
size_t ConvBiasImpl::AlgoFP32WinogradF45::get_workspace( | |||
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { | |||
MEGDNN_MARK_USED_VAR(param); | |||
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 3, 1) { | |||
winograd::winograd_4x5_1x1_f strategy(param.src_type, param.filter_type, | |||
param.dst_type); | |||
return megdnn::winograd::ConvBias<winograd::winograd_4x5_1x1_f>( | |||
strategy, m_tile_size, param.nr_threads, param.osz[0], | |||
param.osz[1], param.filter_meta.ocpg) | |||
.get_workspace_size(param, m_matmul_algo); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
SmallVector<ConvBiasImpl::NCBKern> | |||
ConvBiasImpl::AlgoFP32WinogradF45::dispatch_kerns( | |||
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { | |||
MEGDNN_MARK_USED_VAR(param); | |||
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 3, 2) { | |||
winograd::winograd_4x5_1x1_f strategy(param.src_type, param.filter_type, | |||
param.dst_type); | |||
auto winograd_impl = | |||
megdnn::winograd::ConvBias<winograd::winograd_4x5_1x1_f>( | |||
strategy, m_tile_size, param.nr_threads, param.osz[0], | |||
param.osz[1], param.filter_meta.ocpg); | |||
return winograd_impl.get_kerns(param, m_matmul_algo); | |||
} | |||
MIDOUT_END(); | |||
return {}; | |||
} | |||
/* ======================= AlgoFP32WinogradF63_4x4 ======================== */ | |||
bool ConvBiasImpl::AlgoFP32WinogradF63_4x4::usable( | |||
fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy /*algo_selection_strategy*/) const { | |||
MEGDNN_MARK_USED_VAR(param); | |||
MEGDNN_MARK_USED_VAR(opr); | |||
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 4, 0) { | |||
if (param.filter_meta.icpg % 4 != 0 || param.filter_meta.ocpg % 4 != 0) | |||
return false; | |||
using Strategy = winograd::winograd_6x3_4x4_f; | |||
Strategy strategy(param.src_type, param.filter_type, param.dst_type); | |||
auto&& matmul_param = | |||
megdnn::winograd::ConvBias<Strategy, | |||
param::MatrixMul::Format::MK4>( | |||
strategy, m_tile_size, param.nr_threads, param.osz[0], | |||
param.osz[1], param.filter_meta.ocpg) | |||
.get_matmul_kern_param(param); | |||
return m_matmul_algo->usable(matmul_param) && | |||
(opr->param().format == param::ConvBias::Format::NCHW || | |||
(opr->param().format == | |||
param::ConvBias::Format::NCHW_WINOGRAD && | |||
opr->param().output_block_size == 6 && | |||
param.winograd_matmul_format == | |||
param::MatrixMul::Format::MK4)) && | |||
opr->param().mode == param::ConvBias::Mode::CROSS_CORRELATION && | |||
(param.filter_meta.spatial[0] == param.filter_meta.spatial[1] && | |||
param.filter_meta.spatial[0] == 3) && | |||
(param.filter_meta.stride[0] == param.filter_meta.stride[1] && | |||
param.filter_meta.stride[0] == 1) && | |||
(param.filter_meta.dilation[0] == | |||
param.filter_meta.dilation[1] && | |||
param.filter_meta.dilation[0] == 1) && | |||
param.compute_mode == param::ConvBias::ComputeMode::DEFAULT && | |||
               param.src_type.enumv() == DTypeEnum::Float32;
} | |||
MIDOUT_END(); | |||
return false; | |||
} | |||
size_t ConvBiasImpl::AlgoFP32WinogradF63_4x4::get_workspace( | |||
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { | |||
MEGDNN_MARK_USED_VAR(param); | |||
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 4, 1) { | |||
winograd::winograd_6x3_4x4_f strategy(param.src_type, param.filter_type, | |||
param.dst_type); | |||
return megdnn::winograd::ConvBias<winograd::winograd_6x3_4x4_f, | |||
param::MatrixMul::Format::MK4>( | |||
strategy, m_tile_size, param.nr_threads, param.osz[0], | |||
param.osz[1], param.filter_meta.ocpg) | |||
.get_workspace_size(param, m_matmul_algo); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
SmallVector<ConvBiasImpl::NCBKern> | |||
ConvBiasImpl::AlgoFP32WinogradF63_4x4::dispatch_kerns( | |||
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { | |||
MEGDNN_MARK_USED_VAR(param); | |||
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 4, 2) { | |||
winograd::winograd_6x3_4x4_f strategy(param.src_type, param.filter_type, | |||
param.dst_type); | |||
auto winograd_impl = | |||
megdnn::winograd::ConvBias<winograd::winograd_6x3_4x4_f, | |||
param::MatrixMul::Format::MK4>( | |||
strategy, m_tile_size, param.nr_threads, param.osz[0], | |||
param.osz[1], param.filter_meta.ocpg); | |||
return winograd_impl.get_kerns(param, m_matmul_algo); | |||
} | |||
MIDOUT_END(); | |||
return {}; | |||
} | |||
/* ===================== direct algo ===================== */ | |||
MIDOUT_DECL(megdnn_arm_common_conv_bias_f32_kimpl); | |||
bool ConvBiasImpl::AlgoF32Direct::usable( | |||
fallback::ConvBiasImpl*, const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 0, 0) { | |||
auto&& fm = param.filter_meta; | |||
auto FH = fm.spatial[0]; | |||
auto SH = fm.stride[0], SW = fm.stride[1]; | |||
        // The conditions ``param.isz[0]*param.isz[1] >= 4'' and
        // ``param.osz[0]*param.osz[1] >= 4'' come from the fact that the
        // kernel may read up to 4 floats past the end of the memory chunk.
        bool available = fm.format == param::ConvBias::Format::NCHW &&
param.src_type.enumv() == DTypeEnum::Float32 && | |||
param.filter_type.enumv() == DTypeEnum::Float32 && | |||
param.dst_type.enumv() == DTypeEnum::Float32 && | |||
fm.spatial_ndim == 2 && fm.dilation[0] == 1 && | |||
fm.dilation[1] == 1 && | |||
param.isz[0] * param.isz[1] >= 4 && | |||
param.osz[0] * param.osz[1] >= 4 && FH <= 7 && | |||
SH == 1 && SW == 1; | |||
if (algo_selection_strategy == AlgoSelectionStrategy::HEURISTIC) { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
            available &= (large_group == m_large_group);
        }
        return available;
} | |||
MIDOUT_END(); | |||
return false; | |||
} | |||
size_t ConvBiasImpl::AlgoF32Direct::get_workspace( | |||
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 0, 1) { | |||
auto wbundle = MultithreadDirectConvCommon<float, float>::get_bundle( | |||
param, m_large_group); | |||
return wbundle.total_size_in_bytes(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
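//! Build the kernel list. In large-group mode one kernel of shape
//! {group, N, 1} runs weight flip, padding copy and convolution for a whole
//! group, so each thread owns a group; in small-group mode the stages are
//! emitted as separate kernels ({group, 1, OC}, {group, N, IC} and
//! {group, N, OC}) so the thread pool can parallelize inside a group.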
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32Direct::get_kimpls( | |||
const NCBKernSizeParam& param) const { | |||
auto fm = param.filter_meta; | |||
size_t N = param.n; | |||
size_t IC = param.filter_meta.icpg; | |||
size_t OC = param.filter_meta.ocpg; | |||
size_t group = fm.group; | |||
WorkspaceBundle wbundle = | |||
MultithreadDirectConvCommon<float, float>::get_bundle( | |||
param, m_large_group); | |||
SmallVector<NCBKern> ret_kerns; | |||
    //! When group >= nr_threads, treat it as a large group: each thread
    //! processes one whole group for better performance
    if (m_large_group) {
        //! Channel-wise conv and large groups
auto exec_one_group = [wbundle](const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index) { | |||
auto fm = kern_param.filter_meta; | |||
size_t IC = fm.icpg; | |||
size_t OC = fm.ocpg; | |||
WorkspaceBundle bundle = wbundle; | |||
if (fm.should_flip) { | |||
for (size_t oc = 0; oc < OC; oc++) { | |||
MultithreadDirectConvCommon<float, float>::weight_flip_kern( | |||
bundle, kern_param, ncb_index, | |||
{ncb_index.thread_id, 0, oc}); | |||
} | |||
} | |||
for (size_t ic = 0; ic < IC; ic++) { | |||
MultithreadDirectConvCommon<float, float>::copy_padding_kern( | |||
bundle, kern_param, ncb_index, {ncb_index.thread_id, 0, ic}); | |||
} | |||
for (size_t oc = 0; oc < OC; oc++) { | |||
MultithreadDirectConvCommon<float, float>::do_conv_kern( | |||
bundle, kern_param, ncb_index, | |||
fp32::conv_bias::kern_direct, | |||
{ncb_index.thread_id, 0, oc}); | |||
} | |||
}; | |||
ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); | |||
} else { | |||
WorkspaceBundle bundle = wbundle; | |||
if (fm.should_flip) { | |||
auto weight_flip = [bundle](const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index) { | |||
MultithreadDirectConvCommon<float, float>::weight_flip_kern( | |||
bundle, kern_param, ncb_index, ncb_index.ndrange_id); | |||
}; | |||
ret_kerns.push_back({weight_flip, {group, 1_z, OC}}); | |||
} | |||
auto copy_padding = [bundle](const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index) { | |||
MultithreadDirectConvCommon<float, float>::copy_padding_kern( | |||
bundle, kern_param, ncb_index, ncb_index.ndrange_id); | |||
}; | |||
ret_kerns.push_back({copy_padding, {group, N, IC}}); | |||
auto do_conv = [bundle](const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index) { | |||
MultithreadDirectConvCommon<float, float>::do_conv_kern( | |||
bundle, kern_param, ncb_index, fp32::conv_bias::kern_direct, | |||
ncb_index.ndrange_id); | |||
}; | |||
ret_kerns.push_back({do_conv, {group, N, OC}}); | |||
} | |||
return ret_kerns; | |||
} | |||
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32Direct::dispatch_kerns( | |||
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { | |||
    MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 0, 2) {
return get_kimpls(param); | |||
} | |||
MIDOUT_END(); | |||
return {}; | |||
} | |||
/* ===================== stride-1 algo ===================== */ | |||
bool ConvBiasImpl::AlgoF32DirectStride1::usable( | |||
fallback::ConvBiasImpl*, const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const { | |||
    MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 1, 0) {
auto&& fm = param.filter_meta; | |||
auto FH = fm.spatial[0]; | |||
        bool available =
param.filter_meta.format == param::ConvBias::Format::NCHW && | |||
param.src_type.enumv() == DTypeEnum::Float32 && | |||
param.filter_type.enumv() == DTypeEnum::Float32 && | |||
param.dst_type.enumv() == DTypeEnum::Float32 && | |||
!fm.should_flip && fm.spatial_ndim == 2 && | |||
fm.dilation[0] == 1 && fm.dilation[1] == 1 && | |||
fm.stride[0] == 1 && fm.stride[1] == 1 && FH == fm.spatial[1] && | |||
(FH == 2 || FH == 3 || FH == 5 || FH == 7); | |||
if (algo_selection_strategy == | |||
ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
            available &= (large_group == m_large_group);
        }
        return available;
} | |||
MIDOUT_END(); | |||
return false; | |||
} | |||
size_t ConvBiasImpl::AlgoF32DirectStride1::get_workspace( | |||
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 1, 1) { | |||
auto bundle = | |||
MultithreadDirectConvCommon<float, float>::get_bundle_stride( | |||
param, m_large_group); | |||
return bundle.total_size_in_bytes(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
SmallVector<ConvBiasImpl::NCBKern> | |||
ConvBiasImpl::AlgoF32DirectStride1::get_kimpls( | |||
const NCBKernSizeParam& param) const { | |||
auto fm = param.filter_meta; | |||
auto FH = fm.spatial[0]; | |||
size_t N = param.n; | |||
size_t IC = param.filter_meta.icpg; | |||
size_t OC = param.filter_meta.ocpg; | |||
size_t group = fm.group; | |||
using Func = std::function<void(const float*, const float*, float*, size_t, | |||
size_t, size_t, size_t, size_t)>; | |||
Func conv_kern_function = nullptr; | |||
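//! Select the stride-1 NEON kernel by filter size; usable() only accepts
//! FH in {2, 3, 5, 7}, so conv_kern_function cannot remain null here.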
#define SWITCH_KERN_STR1() \ | |||
switch (FH) { \ | |||
case 2: \ | |||
conv_kern_function = fp32::conv_stride1::do_conv_2x2_stride1; \ | |||
break; \ | |||
case 3: \ | |||
conv_kern_function = fp32::conv_stride1::do_conv_3x3_stride1; \ | |||
break; \ | |||
case 5: \ | |||
conv_kern_function = fp32::conv_stride1::do_conv_5x5_stride1; \ | |||
break; \ | |||
case 7: \ | |||
conv_kern_function = fp32::conv_stride1::do_conv_7x7_stride1; \ | |||
break; \ | |||
} | |||
SWITCH_KERN_STR1(); | |||
WorkspaceBundle wbundle = | |||
MultithreadDirectConvCommon<float, float>::get_bundle_stride( | |||
param, m_large_group); | |||
SmallVector<NCBKern> ret_kerns; | |||
    //! When group >= nr_threads, treat it as a large group: each thread
    //! processes one whole group for better performance
    if (m_large_group) {
        //! Channel-wise conv and large groups
auto exec_one_group = [wbundle, conv_kern_function]( | |||
const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index) { | |||
auto fm = kern_param.filter_meta; | |||
size_t IC = fm.icpg; | |||
size_t OC = fm.ocpg; | |||
WorkspaceBundle bundle = wbundle; | |||
for (size_t ic = 0; ic < IC; ic++) { | |||
MultithreadDirectConvCommon<float, float>:: | |||
copy_padding_kern_stride(bundle, kern_param, ncb_index, | |||
{ncb_index.thread_id, 0, ic}); | |||
} | |||
for (size_t oc = 0; oc < OC; oc++) { | |||
MultithreadDirectConvCommon<float, float>::do_conv_kern_stride( | |||
bundle, kern_param, ncb_index, conv_kern_function, | |||
{ncb_index.thread_id, 0, oc}); | |||
} | |||
}; | |||
ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); | |||
} else { | |||
WorkspaceBundle bundle = wbundle; | |||
auto copy_padding = [bundle](const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index) { | |||
MultithreadDirectConvCommon<float, float>::copy_padding_kern_stride( | |||
bundle, kern_param, ncb_index, ncb_index.ndrange_id); | |||
}; | |||
ret_kerns.push_back({copy_padding, {group, N, IC}}); | |||
auto do_conv = [bundle, conv_kern_function]( | |||
const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index) { | |||
MultithreadDirectConvCommon<float, float>::do_conv_kern_stride( | |||
bundle, kern_param, ncb_index, conv_kern_function, | |||
ncb_index.ndrange_id); | |||
}; | |||
ret_kerns.push_back({do_conv, {group, N, OC}}); | |||
} | |||
return ret_kerns; | |||
} | |||
SmallVector<ConvBiasImpl::NCBKern> | |||
ConvBiasImpl::AlgoF32DirectStride1::dispatch_kerns( | |||
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 1, 2) { | |||
return get_kimpls(param); | |||
} | |||
MIDOUT_END(); | |||
return {}; | |||
} | |||
/* ===================== stride-2 algo ===================== */ | |||
bool ConvBiasImpl::AlgoF32DirectStride2::usable( | |||
fallback::ConvBiasImpl*, const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 2, 0) { | |||
auto&& fm = param.filter_meta; | |||
auto FH = fm.spatial[0]; | |||
        bool available =
param.filter_meta.format == param::ConvBias::Format::NCHW && | |||
param.src_type.enumv() == DTypeEnum::Float32 && | |||
param.filter_type.enumv() == DTypeEnum::Float32 && | |||
param.dst_type.enumv() == DTypeEnum::Float32 && | |||
!fm.should_flip && fm.spatial_ndim == 2 && | |||
fm.dilation[0] == 1 && fm.dilation[1] == 1 && | |||
fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] && | |||
(FH == 2 || FH == 3 || FH == 5 || FH == 7); | |||
if (algo_selection_strategy == | |||
ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
            available &= (large_group == m_large_group);
        }
        return available;
} | |||
MIDOUT_END(); | |||
return false; | |||
} | |||
size_t ConvBiasImpl::AlgoF32DirectStride2::get_workspace( | |||
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 2, 1) { | |||
auto bundle = | |||
MultithreadDirectConvCommon<float, float>::get_bundle_stride( | |||
param, m_large_group); | |||
return bundle.total_size_in_bytes(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
SmallVector<ConvBiasImpl::NCBKern> | |||
ConvBiasImpl::AlgoF32DirectStride2::get_kimpls( | |||
const NCBKernSizeParam& param) const { | |||
auto fm = param.filter_meta; | |||
auto FH = fm.spatial[0]; | |||
size_t N = param.n; | |||
size_t IC = param.filter_meta.icpg; | |||
size_t OC = param.filter_meta.ocpg; | |||
size_t group = fm.group; | |||
using Func = std::function<void(const float*, const float*, float*, size_t, | |||
size_t, size_t, size_t, size_t)>; | |||
Func conv_kern_function = nullptr; | |||
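//! Same dispatch scheme as the stride-1 algorithm, with stride-2 kernels.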
#define SWITCH_KERN_STR2() \ | |||
switch (FH) { \ | |||
case 2: \ | |||
conv_kern_function = fp32::conv_stride2::do_conv_2x2_stride2; \ | |||
break; \ | |||
case 3: \ | |||
conv_kern_function = fp32::conv_stride2::do_conv_3x3_stride2; \ | |||
break; \ | |||
case 5: \ | |||
conv_kern_function = fp32::conv_stride2::do_conv_5x5_stride2; \ | |||
break; \ | |||
case 7: \ | |||
conv_kern_function = fp32::conv_stride2::do_conv_7x7_stride2; \ | |||
break; \ | |||
} | |||
SWITCH_KERN_STR2(); | |||
WorkspaceBundle wbundle = | |||
MultithreadDirectConvCommon<float, float>::get_bundle_stride( | |||
param, m_large_group); | |||
SmallVector<NCBKern> ret_kerns; | |||
    //! When group >= nr_threads, treat it as a large group: each thread
    //! processes one whole group for better performance
    if (m_large_group) {
        //! Channel-wise conv and large groups
auto exec_one_group = [wbundle, conv_kern_function]( | |||
const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index) { | |||
auto fm = kern_param.filter_meta; | |||
size_t IC = fm.icpg; | |||
size_t OC = fm.ocpg; | |||
WorkspaceBundle bundle = wbundle; | |||
for (size_t ic = 0; ic < IC; ic++) { | |||
MultithreadDirectConvCommon<float, float>:: | |||
copy_padding_kern_stride(bundle, kern_param, ncb_index, | |||
{ncb_index.thread_id, 0, ic}); | |||
} | |||
for (size_t oc = 0; oc < OC; oc++) { | |||
MultithreadDirectConvCommon<float, float>::do_conv_kern_stride( | |||
bundle, kern_param, ncb_index, conv_kern_function, | |||
{ncb_index.thread_id, 0, oc}); | |||
} | |||
}; | |||
ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); | |||
} else { | |||
WorkspaceBundle bundle = wbundle; | |||
auto copy_padding = [bundle](const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index) { | |||
MultithreadDirectConvCommon<float, float>::copy_padding_kern_stride( | |||
bundle, kern_param, ncb_index, ncb_index.ndrange_id); | |||
}; | |||
ret_kerns.push_back({copy_padding, {group, N, IC}}); | |||
auto do_conv = [bundle, conv_kern_function]( | |||
const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index) { | |||
MultithreadDirectConvCommon<float, float>::do_conv_kern_stride( | |||
bundle, kern_param, ncb_index, conv_kern_function, | |||
ncb_index.ndrange_id); | |||
}; | |||
ret_kerns.push_back({do_conv, {group, N, OC}}); | |||
} | |||
return ret_kerns; | |||
} | |||
SmallVector<ConvBiasImpl::NCBKern> | |||
ConvBiasImpl::AlgoF32DirectStride2::dispatch_kerns( | |||
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 2, 2) { | |||
return get_kimpls(param); | |||
} | |||
MIDOUT_END(); | |||
return {}; | |||
} | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,223 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/fp32/algos.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "src/arm_common/conv_bias/opr_impl.h" | |||
#include "src/fallback/matrix_mul/opr_impl.h" | |||
namespace megdnn { | |||
namespace arm_common { | |||
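/**
 * FP32 winograd conv-bias algorithms. AlgoFP32WinogradF<m><r> implements
 * winograd F(m x m, r x r): an r x r filter producing m x m output blocks.
 * The _4x4 suffix marks variants built on the MK4 packed matmul format,
 * which require icpg and ocpg to be multiples of 4. Each algorithm wraps a
 * fallback matmul algorithm plus a tile size used when computing the
 * workspace and dispatching kernels, e.g. (hypothetical registration
 * sketch, with matmul_algo and tile obtained elsewhere):
 *     algos.emplace_back(new AlgoFP32WinogradF23_4x4(matmul_algo, tile));
 */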
class ConvBiasImpl::AlgoFP32WinogradF23_4x4 final : public AlgoBase { | |||
public: | |||
AlgoFP32WinogradF23_4x4(fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
uint32_t tile_size) | |||
: m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {} | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { | |||
if (m_name.empty()) { | |||
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | |||
m_matmul_algo->name(), {4, 2, m_tile_size}); | |||
} | |||
return m_name.c_str(); | |||
} | |||
bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||
size_t get_workspace(fallback::ConvBiasImpl*, | |||
const NCBKernSizeParam& param) const override; | |||
virtual SmallVector<NCBKern> dispatch_kerns( | |||
fallback::ConvBiasImpl* opr, | |||
const NCBKernSizeParam& param) const override; | |||
private: | |||
fallback::MatrixMulImpl::AlgoBase* m_matmul_algo; | |||
mutable std::string m_name; | |||
uint32_t m_tile_size; | |||
}; | |||
class ConvBiasImpl::AlgoFP32WinogradF63 final : public AlgoBase { | |||
public: | |||
AlgoFP32WinogradF63(fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
uint32_t tile_size) | |||
: m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {} | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { | |||
if (m_name.empty()) { | |||
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | |||
m_matmul_algo->name(), {1, 6, m_tile_size}); | |||
} | |||
return m_name.c_str(); | |||
} | |||
bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||
size_t get_workspace(fallback::ConvBiasImpl*, | |||
const NCBKernSizeParam& param) const override; | |||
virtual SmallVector<NCBKern> dispatch_kerns( | |||
fallback::ConvBiasImpl* opr, | |||
const NCBKernSizeParam& param) const override; | |||
private: | |||
fallback::MatrixMulImpl::AlgoBase* m_matmul_algo; | |||
mutable std::string m_name; | |||
uint32_t m_tile_size; | |||
}; | |||
class ConvBiasImpl::AlgoFP32WinogradF63_4x4 final : public AlgoBase { | |||
public: | |||
AlgoFP32WinogradF63_4x4(fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
uint32_t tile_size) | |||
: m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {} | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { | |||
if (m_name.empty()) { | |||
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | |||
m_matmul_algo->name(), {4, 6, m_tile_size}); | |||
} | |||
return m_name.c_str(); | |||
} | |||
bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||
size_t get_workspace(fallback::ConvBiasImpl*, | |||
const NCBKernSizeParam& param) const override; | |||
virtual SmallVector<NCBKern> dispatch_kerns( | |||
fallback::ConvBiasImpl* opr, | |||
const NCBKernSizeParam& param) const override; | |||
private: | |||
fallback::MatrixMulImpl::AlgoBase* m_matmul_algo; | |||
mutable std::string m_name; | |||
uint32_t m_tile_size; | |||
}; | |||
class ConvBiasImpl::AlgoFP32WinogradF54 final : public AlgoBase { | |||
public: | |||
AlgoFP32WinogradF54(fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
uint32_t tile_size) | |||
: m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {} | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { | |||
if (m_name.empty()) { | |||
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | |||
m_matmul_algo->name(), {1, 5, m_tile_size}); | |||
} | |||
return m_name.c_str(); | |||
} | |||
bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||
size_t get_workspace(fallback::ConvBiasImpl*, | |||
const NCBKernSizeParam& param) const override; | |||
virtual SmallVector<NCBKern> dispatch_kerns( | |||
fallback::ConvBiasImpl* opr, | |||
const NCBKernSizeParam& param) const override; | |||
private: | |||
fallback::MatrixMulImpl::AlgoBase* m_matmul_algo; | |||
mutable std::string m_name; | |||
uint32_t m_tile_size; | |||
}; | |||
class ConvBiasImpl::AlgoFP32WinogradF45 final : public AlgoBase { | |||
public: | |||
AlgoFP32WinogradF45(fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
uint32_t tile_size) | |||
: m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {} | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { | |||
if (m_name.empty()) { | |||
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | |||
m_matmul_algo->name(), {1, 4, m_tile_size}); | |||
} | |||
return m_name.c_str(); | |||
} | |||
bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||
size_t get_workspace(fallback::ConvBiasImpl*, | |||
const NCBKernSizeParam& param) const override; | |||
virtual SmallVector<NCBKern> dispatch_kerns( | |||
fallback::ConvBiasImpl* opr, | |||
const NCBKernSizeParam& param) const override; | |||
private: | |||
fallback::MatrixMulImpl::AlgoBase* m_matmul_algo; | |||
mutable std::string m_name; | |||
uint32_t m_tile_size; | |||
}; | |||
class ConvBiasImpl::AlgoF32Direct final : public AlgoBase { | |||
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | |||
bool m_large_group; | |||
public: | |||
AlgoF32Direct(bool is_large_group) : m_large_group{is_large_group} {} | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { | |||
return m_large_group ? "F32DIRECT_LARGE_GROUP" | |||
: "F32DIRECT_SMALL_GROUP"; | |||
} | |||
bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||
size_t get_workspace(fallback::ConvBiasImpl* opr, | |||
const NCBKernSizeParam& param) const override; | |||
virtual SmallVector<NCBKern> dispatch_kerns( | |||
fallback::ConvBiasImpl* opr, | |||
const NCBKernSizeParam& param) const override; | |||
}; | |||
class ConvBiasImpl::AlgoF32DirectStride1 final : public AlgoBase { | |||
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | |||
bool m_large_group; | |||
public: | |||
AlgoF32DirectStride1(bool is_large_group) : m_large_group{is_large_group} {} | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { | |||
return m_large_group ? "F32STRD1_LARGE_GROUP" : "F32STRD1_SMALL_GROUP"; | |||
} | |||
bool usable(fallback::ConvBiasImpl*, const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||
size_t get_workspace(fallback::ConvBiasImpl*, | |||
const NCBKernSizeParam& param) const override; | |||
virtual SmallVector<NCBKern> dispatch_kerns( | |||
fallback::ConvBiasImpl* opr, | |||
const NCBKernSizeParam& param) const override; | |||
}; | |||
class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase { | |||
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | |||
bool m_large_group; | |||
public: | |||
AlgoF32DirectStride2(bool is_large_group) : m_large_group{is_large_group} {} | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { | |||
return m_large_group ? "F32STRD2_LARGE_GROUP" : "F32STRD2_SMALL_GROUP"; | |||
} | |||
bool usable(fallback::ConvBiasImpl*, const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||
size_t get_workspace(fallback::ConvBiasImpl*, | |||
const NCBKernSizeParam& param) const override; | |||
virtual SmallVector<NCBKern> dispatch_kerns( | |||
fallback::ConvBiasImpl* opr, | |||
const NCBKernSizeParam& param) const override; | |||
}; | |||
} // namespace arm_common | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,911 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/fp32/direct.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include <cstring> | |||
#include "include/megdnn/oprs.h" | |||
#include "midout.h" | |||
#include "src/arm_common/conv_bias/fp32/direct.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/common/utils.h" | |||
#include "src/arm_common/conv_bias/postprocess_helper.h" | |||
#include "src/common/unroll_macro.h" | |||
MIDOUT_DECL(megdnn_arm_conv_f32) | |||
using namespace megdnn; | |||
using namespace arm_common; | |||
using namespace fp32; | |||
using namespace conv_bias; | |||
namespace { | |||
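//! Compute a height x width (at most 4x4) block of output pixels for a
//! fixed filter height FH, accumulating over the filter rows in NEON
//! registers: each 128-bit register holds 4 contiguous output columns, and
//! partial widths are handled lane by lane via LOAD_OUT/STORE_OUT below.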
template <int FH, int height, int width> | |||
struct do_pixel_proxy { | |||
static void exec(const float* src, const float* filter, float* dst, | |||
const int IH, const int IW, const int OH, const int OW, | |||
const int FW, const int oh, const int ow); | |||
}; | |||
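//! LOAD_OUT/STORE_OUT move up to four output rows between dst and the
//! out0..out3 registers; when width < 4 they fall back to per-lane
//! vld1q_lane_f32/vst1q_lane_f32 so no lane outside the row is touched.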
#define cb_load(i) data = vld1q_lane_f32(dst + i, data, i); | |||
#define LOAD_OUT \ | |||
if (width < 4) { \ | |||
auto load_less_4 = [](float* dst, float32x4_t& data) { \ | |||
if (width == 1u) { \ | |||
UNROLL_CALL_NOWRAPPER(1, cb_load); \ | |||
} else if (width == 2u) { \ | |||
UNROLL_CALL_NOWRAPPER(2, cb_load); \ | |||
} else if (width == 3u) { \ | |||
UNROLL_CALL_NOWRAPPER(3, cb_load); \ | |||
} \ | |||
}; \ | |||
if (height >= 1) \ | |||
load_less_4(dst + 0 * OW, out0); \ | |||
if (height >= 2) \ | |||
load_less_4(dst + 1 * OW, out1); \ | |||
if (height >= 3) \ | |||
load_less_4(dst + 2 * OW, out2); \ | |||
if (height >= 4) \ | |||
load_less_4(dst + 3 * OW, out3); \ | |||
} else { \ | |||
if (height > 0) \ | |||
out0 = vld1q_f32(dst + 0 * OW); \ | |||
if (height > 1) \ | |||
out1 = vld1q_f32(dst + 1 * OW); \ | |||
if (height > 2) \ | |||
out2 = vld1q_f32(dst + 2 * OW); \ | |||
if (height > 3) \ | |||
out3 = vld1q_f32(dst + 3 * OW); \ | |||
} | |||
#define cb_store(i) vst1q_lane_f32(dst + i, data, i); | |||
#define STORE_OUT \ | |||
if (width < 4) { \ | |||
auto store_less_4 = [](float* dst, float32x4_t& data) { \ | |||
if (width == 1u) { \ | |||
UNROLL_CALL_NOWRAPPER(1, cb_store); \ | |||
} else if (width == 2u) { \ | |||
UNROLL_CALL_NOWRAPPER(2, cb_store); \ | |||
} else if (width == 3u) { \ | |||
UNROLL_CALL_NOWRAPPER(3, cb_store); \ | |||
} \ | |||
}; \ | |||
if (height >= 1) \ | |||
store_less_4(dst + 0 * OW, out0); \ | |||
if (height >= 2) \ | |||
store_less_4(dst + 1 * OW, out1); \ | |||
if (height >= 3) \ | |||
store_less_4(dst + 2 * OW, out2); \ | |||
if (height >= 4) \ | |||
store_less_4(dst + 3 * OW, out3); \ | |||
} else { \ | |||
if (height >= 1) \ | |||
vst1q_f32(dst + 0 * OW, out0); \ | |||
if (height >= 2) \ | |||
vst1q_f32(dst + 1 * OW, out1); \ | |||
if (height >= 3) \ | |||
vst1q_f32(dst + 2 * OW, out2); \ | |||
if (height >= 4) \ | |||
vst1q_f32(dst + 3 * OW, out3); \ | |||
} | |||
template <int height, int width> | |||
struct do_pixel_proxy<1, height, width> { | |||
static void exec(const float* src, const float* filter, float* dst, | |||
const int IH, const int IW, const int OH, const int OW, | |||
const int FW, const int oh, const int ow) { | |||
(void)IH; | |||
(void)OH; | |||
const int ih = oh, iw = ow; | |||
float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, inp; | |||
src += ih * IW + iw; | |||
dst += oh * OW + ow; | |||
LOAD_OUT; | |||
for (int fw = 0; fw < FW; ++fw) { | |||
const float* src_dd = src + fw; | |||
kr0 = vdupq_n_f32(filter[0 * FW + fw]); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 0 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr0); | |||
if (height > 1) | |||
inp = vld1q_f32(src_dd + 1 * IW); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr0); | |||
if (height > 2) | |||
inp = vld1q_f32(src_dd + 2 * IW); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr0); | |||
if (height > 3) | |||
inp = vld1q_f32(src_dd + 3 * IW); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr0); | |||
} | |||
STORE_OUT; | |||
} | |||
}; | |||
template <int height, int width> | |||
struct do_pixel_proxy<2, height, width> { | |||
static void exec(const float* src, const float* filter, float* dst, | |||
const int IH, const int IW, const int OH, const int OW, | |||
const int FW, const int oh, const int ow) { | |||
(void)IH; | |||
(void)OH; | |||
const int ih = oh, iw = ow; | |||
float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, inp; | |||
src += ih * IW + iw; | |||
dst += oh * OW + ow; | |||
LOAD_OUT; | |||
for (int fw = 0; fw < FW; ++fw) { | |||
const float* src_dd = src + fw; | |||
kr0 = vdupq_n_f32(filter[0 * FW + fw]); | |||
kr1 = vdupq_n_f32(filter[1 * FW + fw]); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 0 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr0); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 1 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr1); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr0); | |||
if (height > 1) | |||
inp = vld1q_f32(src_dd + 2 * IW); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr1); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr0); | |||
if (height > 2) | |||
inp = vld1q_f32(src_dd + 3 * IW); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr1); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr0); | |||
if (height > 3) | |||
inp = vld1q_f32(src_dd + 4 * IW); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr1); | |||
} | |||
STORE_OUT; | |||
} | |||
}; | |||
template <int height, int width> | |||
struct do_pixel_proxy<3, height, width> { | |||
static void exec(const float* src, const float* filter, float* dst, | |||
const int IH, const int IW, const int OH, const int OW, | |||
const int FW, const int oh, const int ow) { | |||
(void)IH; | |||
(void)OH; | |||
const int ih = oh, iw = ow; | |||
float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, inp; | |||
src += ih * IW + iw; | |||
dst += oh * OW + ow; | |||
LOAD_OUT; | |||
for (int fw = 0; fw < FW; ++fw) { | |||
const float* src_dd = src + fw; | |||
kr0 = vdupq_n_f32(filter[0 * FW + fw]); | |||
kr1 = vdupq_n_f32(filter[1 * FW + fw]); | |||
kr2 = vdupq_n_f32(filter[2 * FW + fw]); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 0 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr0); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 1 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr1); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr0); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 2 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr2); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr1); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr0); | |||
if (height > 1) | |||
inp = vld1q_f32(src_dd + 3 * IW); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr2); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr1); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr0); | |||
if (height > 2) | |||
inp = vld1q_f32(src_dd + 4 * IW); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr2); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr1); | |||
if (height > 3) | |||
inp = vld1q_f32(src_dd + 5 * IW); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr2); | |||
} | |||
STORE_OUT; | |||
} | |||
}; | |||
template <int height, int width> | |||
struct do_pixel_proxy<4, height, width> { | |||
static void exec(const float* src, const float* filter, float* dst, | |||
const int IH, const int IW, const int OH, const int OW, | |||
const int FW, const int oh, const int ow) { | |||
(void)IH; | |||
(void)OH; | |||
const int ih = oh, iw = ow; | |||
float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, inp; | |||
src += ih * IW + iw; | |||
dst += oh * OW + ow; | |||
LOAD_OUT; | |||
for (int fw = 0; fw < FW; ++fw) { | |||
const float* src_dd = src + fw; | |||
kr0 = vdupq_n_f32(filter[0 * FW + fw]); | |||
kr1 = vdupq_n_f32(filter[1 * FW + fw]); | |||
kr2 = vdupq_n_f32(filter[2 * FW + fw]); | |||
kr3 = vdupq_n_f32(filter[3 * FW + fw]); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 0 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr0); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 1 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr1); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr0); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 2 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr2); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr1); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr0); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 3 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr3); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr2); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr1); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr0); | |||
if (height > 1) | |||
inp = vld1q_f32(src_dd + 4 * IW); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr3); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr2); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr1); | |||
if (height > 2) | |||
inp = vld1q_f32(src_dd + 5 * IW); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr3); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr2); | |||
if (height > 3) | |||
inp = vld1q_f32(src_dd + 6 * IW); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr3); | |||
} | |||
STORE_OUT; | |||
} | |||
}; | |||
template <int height, int width> | |||
struct do_pixel_proxy<5, height, width> { | |||
static void exec(const float* src, const float* filter, float* dst, | |||
const int IH, const int IW, const int OH, const int OW, | |||
const int FW, const int oh, const int ow) { | |||
(void)IH; | |||
(void)OH; | |||
const int ih = oh, iw = ow; | |||
float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, kr4, | |||
inp; | |||
src += ih * IW + iw; | |||
dst += oh * OW + ow; | |||
LOAD_OUT; | |||
for (int fw = 0; fw < FW; ++fw) { | |||
const float* src_dd = src + fw; | |||
kr0 = vdupq_n_f32(filter[0 * FW + fw]); | |||
kr1 = vdupq_n_f32(filter[1 * FW + fw]); | |||
kr2 = vdupq_n_f32(filter[2 * FW + fw]); | |||
kr3 = vdupq_n_f32(filter[3 * FW + fw]); | |||
kr4 = vdupq_n_f32(filter[4 * FW + fw]); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 0 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr0); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 1 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr1); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr0); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 2 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr2); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr1); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr0); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 3 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr3); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr2); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr1); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr0); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 4 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr4); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr3); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr2); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr1); | |||
if (height > 1) | |||
inp = vld1q_f32(src_dd + 5 * IW); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr4); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr3); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr2); | |||
if (height > 2) | |||
inp = vld1q_f32(src_dd + 6 * IW); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr4); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr3); | |||
if (height > 3) | |||
inp = vld1q_f32(src_dd + 7 * IW); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr4); | |||
} | |||
STORE_OUT; | |||
} | |||
}; | |||
template <int height, int width> | |||
struct do_pixel_proxy<6, height, width> { | |||
static void exec(const float* src, const float* filter, float* dst, | |||
const int IH, const int IW, const int OH, const int OW, | |||
const int FW, const int oh, const int ow) { | |||
(void)IH; | |||
(void)OH; | |||
const int ih = oh, iw = ow; | |||
float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, kr4, | |||
kr5, inp; | |||
src += ih * IW + iw; | |||
dst += oh * OW + ow; | |||
LOAD_OUT; | |||
for (int fw = 0; fw < FW; ++fw) { | |||
const float* src_dd = src + fw; | |||
kr0 = vdupq_n_f32(filter[0 * FW + fw]); | |||
kr1 = vdupq_n_f32(filter[1 * FW + fw]); | |||
kr2 = vdupq_n_f32(filter[2 * FW + fw]); | |||
kr3 = vdupq_n_f32(filter[3 * FW + fw]); | |||
kr4 = vdupq_n_f32(filter[4 * FW + fw]); | |||
kr5 = vdupq_n_f32(filter[5 * FW + fw]); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 0 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr0); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 1 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr1); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr0); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 2 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr2); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr1); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr0); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 3 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr3); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr2); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr1); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr0); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 4 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr4); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr3); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr2); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr1); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 5 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr5); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr4); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr3); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr2); | |||
if (height > 1) | |||
inp = vld1q_f32(src_dd + 6 * IW); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr5); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr4); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr3); | |||
if (height > 2) | |||
inp = vld1q_f32(src_dd + 7 * IW); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr5); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr4); | |||
if (height > 3) | |||
inp = vld1q_f32(src_dd + 8 * IW); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr5); | |||
} | |||
STORE_OUT; | |||
} | |||
}; | |||
template <int height, int width> | |||
struct do_pixel_proxy<7, height, width> { | |||
static void exec(const float* src, const float* filter, float* dst, | |||
const int IH, const int IW, const int OH, const int OW, | |||
const int FW, const int oh, const int ow) { | |||
(void)IH; | |||
(void)OH; | |||
const int ih = oh, iw = ow; | |||
float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, kr4, | |||
kr5, kr6, inp; | |||
src += ih * IW + iw; | |||
dst += oh * OW + ow; | |||
LOAD_OUT; | |||
for (int fw = 0; fw < FW; ++fw) { | |||
const float* src_dd = src + fw; | |||
kr0 = vdupq_n_f32(filter[0 * FW + fw]); | |||
kr1 = vdupq_n_f32(filter[1 * FW + fw]); | |||
kr2 = vdupq_n_f32(filter[2 * FW + fw]); | |||
kr3 = vdupq_n_f32(filter[3 * FW + fw]); | |||
kr4 = vdupq_n_f32(filter[4 * FW + fw]); | |||
kr5 = vdupq_n_f32(filter[5 * FW + fw]); | |||
kr6 = vdupq_n_f32(filter[6 * FW + fw]); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 0 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr0); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 1 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr1); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr0); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 2 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr2); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr1); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr0); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 3 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr3); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr2); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr1); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr0); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 4 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr4); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr3); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr2); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr1); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 5 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr5); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr4); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr3); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr2); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 6 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr6); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr5); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr4); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr3); | |||
if (height > 1) | |||
inp = vld1q_f32(src_dd + 7 * IW); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr6); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr5); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr4); | |||
if (height > 2) | |||
inp = vld1q_f32(src_dd + 8 * IW); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr6); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr5); | |||
if (height > 3) | |||
inp = vld1q_f32(src_dd + 9 * IW); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr6); | |||
} | |||
STORE_OUT; | |||
} | |||
}; | |||
#undef cb_load
#undef cb_store
#undef LOAD_OUT | |||
#undef STORE_OUT | |||
template <int FH, int height, int width> | |||
void do_pixel(const float* src, const float* filter, float* dst, const int IH, | |||
const int IW, const int OH, const int OW, const int FW, | |||
const int oh, const int ow) { | |||
do_pixel_proxy<FH, height, width>::exec(src, filter, dst, IH, IW, OH, OW, | |||
FW, oh, ow); | |||
} | |||
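//! Walk the output in 4x4 tiles, issuing __builtin_prefetch hints for the
//! FH + 3 input rows and the 4 output rows of the tile 16 columns ahead
//! (wrapping to the next 4-row strip at the end of a row); the DISPATCH
//! macros cover the remainder columns and rows.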
template <int FH> | |||
void do_conv_tpl_enable_prefetch(const float* src, const float* filter, | |||
float* dst, const int IH, const int IW, | |||
const int OH, const int OW, const int FW) { | |||
const int hbeg = 0, hend = OH; | |||
const int wbeg = 0, wend = OW; | |||
int i, j; | |||
for (i = hbeg; i + 4 <= hend; i += 4) { | |||
for (j = wbeg; j + 4 <= wend; j += 4) { | |||
// do prefetch | |||
const int prefetch_index_input = | |||
(j + 16) < wend | |||
? i * IW + j + 16 | |||
: (i + 4) * IW + (((j + 16 - wend) >> 2) << 2); | |||
const int prefetch_index_output = | |||
(j + 16) < wend | |||
? i * OW + j + 16 | |||
: (i + 4) * OW + (((j + 16 - wend) >> 2) << 2); | |||
const float* src_prefetch = src + prefetch_index_input; | |||
const float* dst_prefetch = dst + prefetch_index_output; | |||
for (int iw_id = 0; iw_id < FH + 3; ++iw_id) { | |||
__builtin_prefetch(src_prefetch + iw_id * IW, 0, 3); | |||
} | |||
__builtin_prefetch(dst_prefetch + 0 * OW, 1, 3); | |||
__builtin_prefetch(dst_prefetch + 1 * OW, 1, 3); | |||
__builtin_prefetch(dst_prefetch + 2 * OW, 1, 3); | |||
__builtin_prefetch(dst_prefetch + 3 * OW, 1, 3); | |||
do_pixel<FH, 4, 4>(src, filter, dst, IH, IW, OH, OW, FW, i, j); | |||
} | |||
#define DISPATCH(width) \ | |||
do { \ | |||
const int prefetch_index_input = (i + 4) * IW + 12; \ | |||
const int prefetch_index_output = (i + 4) * OW + 12; \ | |||
const float* src_prefetch = src + prefetch_index_input; \ | |||
const float* dst_prefetch = dst + prefetch_index_output; \ | |||
for (int iw_id = 0; iw_id < FH + 3; ++iw_id) { \ | |||
__builtin_prefetch(src_prefetch + iw_id * IW, 0, 3); \ | |||
} \ | |||
__builtin_prefetch(dst_prefetch + 0 * OW, 1, 3); \ | |||
__builtin_prefetch(dst_prefetch + 1 * OW, 1, 3); \ | |||
__builtin_prefetch(dst_prefetch + 2 * OW, 1, 3); \ | |||
__builtin_prefetch(dst_prefetch + 3 * OW, 1, 3); \ | |||
do_pixel<FH, 4, width>(src, filter, dst, IH, IW, OH, OW, FW, i, j); \ | |||
} while (0) | |||
switch (wend - j) { | |||
case 1: | |||
DISPATCH(1); | |||
break; | |||
case 2: | |||
DISPATCH(2); | |||
break; | |||
case 3: | |||
DISPATCH(3); | |||
break; | |||
} | |||
#undef DISPATCH | |||
} | |||
#define DISPATCH2(height, width) \ | |||
do { \ | |||
const int prefetch_index_input = IH * IW + 12; \ | |||
const float* src_prefetch = src + prefetch_index_input; \ | |||
for (int iw_id = 0; iw_id < FH + 3; ++iw_id) { \ | |||
__builtin_prefetch(src_prefetch + iw_id * IW, 0, 3); \ | |||
} \ | |||
do_pixel<FH, height, width>(src, filter, dst, IH, IW, OH, OW, FW, i, \ | |||
j); \ | |||
} while (0) | |||
#define DISPATCH1(height) \ | |||
do { \ | |||
for (j = wbeg; j + 4 <= wend; j += 4) { \ | |||
const int prefetch_index_input = \ | |||
(j + 16) < wend \ | |||
? i * IW + j + 16 \ | |||
: (i + 4) * IW + (((j + 16 - wend) >> 2) << 2); \ | |||
const int prefetch_index_output = \ | |||
(j + 16) < wend \ | |||
? i * OW + j + 16 \ | |||
: (i + 4) * OW + (((j + 16 - wend) >> 2) << 2); \ | |||
const float* src_prefetch = src + prefetch_index_input; \ | |||
const float* dst_prefetch = dst + prefetch_index_output; \ | |||
for (int iw_id = 0; iw_id < FH + 3; ++iw_id) { \ | |||
__builtin_prefetch(src_prefetch + iw_id * IW, 0, 3); \ | |||
} \ | |||
__builtin_prefetch(dst_prefetch + 0 * OW, 1, 3); \ | |||
__builtin_prefetch(dst_prefetch + 1 * OW, 1, 3); \ | |||
__builtin_prefetch(dst_prefetch + 2 * OW, 1, 3); \ | |||
__builtin_prefetch(dst_prefetch + 3 * OW, 1, 3); \ | |||
do_pixel<FH, height, 4>(src, filter, dst, IH, IW, OH, OW, FW, i, \ | |||
j); \ | |||
} \ | |||
switch (wend - j) { \ | |||
case 1: \ | |||
DISPATCH2(height, 1); \ | |||
break; \ | |||
case 2: \ | |||
DISPATCH2(height, 2); \ | |||
break; \ | |||
case 3: \ | |||
DISPATCH2(height, 3); \ | |||
break; \ | |||
} \ | |||
} while (0) | |||
switch (hend - i) { | |||
case 1: | |||
DISPATCH1(1); | |||
break; | |||
case 2: | |||
DISPATCH1(2); | |||
break; | |||
case 3: | |||
DISPATCH1(3); | |||
break; | |||
} | |||
#undef DISPATCH1 | |||
#undef DISPATCH2 | |||
} | |||
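//! Same 4x4 tiling as above but without prefetch hints, used for feature
//! maps small enough that prefetching does not pay off (see kern_direct).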
template <int FH> | |||
void do_conv_tpl_disable_prefetch(const float* src, const float* filter, | |||
float* dst, const int IH, const int IW, | |||
const int OH, const int OW, const int FW) { | |||
const int hbeg = 0, hend = OH; | |||
const int wbeg = 0, wend = OW; | |||
int i, j; | |||
for (i = hbeg; i + 4 <= hend; i += 4) { | |||
for (j = wbeg; j + 4 <= wend; j += 4) { | |||
do_pixel<FH, 4, 4>(src, filter, dst, IH, IW, OH, OW, FW, i, j); | |||
} | |||
#define DISPATCH(width) \ | |||
do { \ | |||
do_pixel<FH, 4, width>(src, filter, dst, IH, IW, OH, OW, FW, i, j); \ | |||
} while (0) | |||
switch (wend - j) { | |||
case 1: | |||
DISPATCH(1); | |||
break; | |||
case 2: | |||
DISPATCH(2); | |||
break; | |||
case 3: | |||
DISPATCH(3); | |||
break; | |||
} | |||
#undef DISPATCH | |||
} | |||
#define DISPATCH2(height, width) \ | |||
do { \ | |||
do_pixel<FH, height, width>(src, filter, dst, IH, IW, OH, OW, FW, i, \ | |||
j); \ | |||
} while (0) | |||
#define DISPATCH1(height) \ | |||
do { \ | |||
for (j = wbeg; j + 4 <= wend; j += 4) { \ | |||
do_pixel<FH, height, 4>(src, filter, dst, IH, IW, OH, OW, FW, i, \ | |||
j); \ | |||
} \ | |||
switch (wend - j) { \ | |||
case 1: \ | |||
DISPATCH2(height, 1); \ | |||
break; \ | |||
case 2: \ | |||
DISPATCH2(height, 2); \ | |||
break; \ | |||
case 3: \ | |||
DISPATCH2(height, 3); \ | |||
break; \ | |||
} \ | |||
} while (0) | |||
switch (hend - i) { | |||
case 1: | |||
DISPATCH1(1); | |||
break; | |||
case 2: | |||
DISPATCH1(2); | |||
break; | |||
case 3: | |||
DISPATCH1(3); | |||
break; | |||
} | |||
#undef DISPATCH1 | |||
#undef DISPATCH2 | |||
} | |||
} // anonymous namespace | |||
void conv_bias::kern_direct(const float* src, const float* filter, float* dst, | |||
const int IH, const int IW, const int OH, | |||
const int OW, const int FH, const int FW) { | |||
megdnn_assert_internal(FH <= 7); | |||
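    // Use the software-prefetch variant only for large feature maps, where
    // the extra prefetch instructions can hide memory latency; the 100x100
    // cutoff appears to be an empirical threshold rather than a derived one.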
if (IH > 100 && IW > 100) { | |||
#define GAO(FH) \ | |||
do { \ | |||
return do_conv_tpl_enable_prefetch<FH>(src, filter, dst, IH, IW, OH, \ | |||
OW, FW); \ | |||
} while (0) | |||
switch (FH) { | |||
case 1: | |||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(0)) { GAO(1); } | |||
MIDOUT_END(); | |||
break; | |||
case 2: | |||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(1)) { GAO(2); } | |||
MIDOUT_END(); | |||
break; | |||
case 3: | |||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(2)) { GAO(3); } | |||
MIDOUT_END(); | |||
break; | |||
case 4: | |||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(3)) { GAO(4); } | |||
MIDOUT_END(); | |||
break; | |||
case 5: | |||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(4)) { GAO(5); } | |||
MIDOUT_END(); | |||
break; | |||
case 6: | |||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(5)) { GAO(6); } | |||
MIDOUT_END(); | |||
break; | |||
case 7: | |||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(6)) { GAO(7); } | |||
MIDOUT_END(); | |||
break; | |||
} | |||
#undef GAO | |||
} else { | |||
#define GAO(FH) \ | |||
do { \ | |||
return do_conv_tpl_disable_prefetch<FH>(src, filter, dst, IH, IW, OH, \ | |||
OW, FW); \ | |||
} while (0) | |||
switch (FH) { | |||
case 1: | |||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(0)) { GAO(1); } | |||
MIDOUT_END(); | |||
break; | |||
case 2: | |||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(1)) { GAO(2); } | |||
MIDOUT_END(); | |||
break; | |||
case 3: | |||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(2)) { GAO(3); } | |||
MIDOUT_END(); | |||
break; | |||
case 4: | |||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(3)) { GAO(4); } | |||
MIDOUT_END(); | |||
break; | |||
case 5: | |||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(4)) { GAO(5); } | |||
MIDOUT_END(); | |||
break; | |||
case 6: | |||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(5)) { GAO(6); } | |||
MIDOUT_END(); | |||
break; | |||
case 7: | |||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(6)) { GAO(7); } | |||
MIDOUT_END(); | |||
break; | |||
} | |||
#undef GAO | |||
} | |||
megdnn_assert_internal(0); | |||
} | |||
// vim: syntax=cpp.doxygen |
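// For orientation, a plain scalar equivalent of the accumulation performed by
// the templated kernels above (a hedged sketch inferred from the loop bounds:
// valid convolution, stride 1, single channel, dst accumulated in place; IH
// is omitted because it is not needed for the indexing):
static inline void conv_direct_scalar_ref(const float* src, const float* filter,
                                          float* dst, int IW, int OH, int OW,
                                          int FH, int FW) {
    for (int oh = 0; oh < OH; ++oh) {
        for (int ow = 0; ow < OW; ++ow) {
            float acc = dst[oh * OW + ow];  // the kernels add on top of dst
            for (int fh = 0; fh < FH; ++fh)
                for (int fw = 0; fw < FW; ++fw)
                    acc += src[(oh + fh) * IW + (ow + fw)] *
                           filter[fh * FW + fw];
            dst[oh * OW + ow] = acc;
        }
    }
}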
@@ -0,0 +1,29 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/fp32/direct.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include <cstddef> | |||
namespace megdnn { | |||
namespace arm_common { | |||
namespace fp32 {
namespace conv_bias { | |||
void kern_direct(const float *src, const float *filter, float *dst, | |||
const int IH, const int IW, const int OH, const int OW, | |||
const int FH, const int FW); | |||
} // namespace conv_bias
} // namespace fp32 | |||
} // namespace arm_common | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
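// A minimal usage sketch for the declaration above (hedged: assumes a
// single-channel valid convolution whose dst must be pre-initialized, e.g.
// with the bias, because the kernel accumulates into it):
#include <vector>
inline void kern_direct_usage_sketch() {
    const int IH = 16, IW = 16, FH = 3, FW = 3;
    const int OH = IH - FH + 1, OW = IW - FW + 1;
    std::vector<float> src(IH * IW, 1.f), filter(FH * FW, 1.f);
    std::vector<float> dst(OH * OW, 0.f);  // start from zero (or the bias)
    megdnn::arm_common::fp32::conv_bias::kern_direct(
            src.data(), filter.data(), dst.data(), IH, IW, OH, OW, FH, FW);
}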
@@ -0,0 +1,735 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/fp32/do_conv_stride1.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include <algorithm> | |||
#include "src/arm_common/conv_bias/fp32/do_conv_stride1.h" | |||
#include "src/arm_common/simd_macro/neon_helper.h" | |||
#include "src/arm_common/conv_bias/postprocess_helper.h" | |||
#include "midout.h" | |||
MIDOUT_DECL(megdnn_arm_common_conv_bias_f32_convs1) | |||
using namespace megdnn; | |||
using namespace arm_common; | |||
using namespace fp32; | |||
using namespace conv_stride1; | |||
using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam; | |||
using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam; | |||
void conv_stride1::do_conv_2x2_stride1(const float* src, const float* filter, float* dst, | |||
size_t IH, size_t IW, size_t OH, size_t OW, | |||
size_t IC) { | |||
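    //! IW - OW rewinds from the end of one output row to the start of the
    //! next input row; the OW >> 2 vector loops appear to assume that OW is
    //! a multiple of 4.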
const size_t tail_step = IW - OW; | |||
//! unroll of 2 | |||
size_t ic = 0; | |||
for (; ic + 1 < IC; ic += 2) { | |||
const float* src_ptr = src + IW * IH * ic; | |||
const float* src_ptr1 = src_ptr + IW * IH; | |||
float* outptr = dst; | |||
const float* r00 = src_ptr; | |||
const float* r01 = src_ptr + IW; | |||
const float* r10 = src_ptr1; | |||
const float* r11 = src_ptr1 + IW; | |||
const float* k0 = filter + ic * 4; | |||
const float* k1 = k0 + 4; | |||
MEGDNN_SIMD_TYPE _k0 = MEGDNN_SIMD_LOADU(k0); | |||
MEGDNN_SIMD_TYPE _k1 = MEGDNN_SIMD_LOADU(k1); | |||
rep(h, OH) { | |||
int width = OW >> 2; | |||
rep(i, width) { | |||
MEGDNN_SIMD_TYPE _r000 = MEGDNN_SIMD_LOADU(r00); | |||
MEGDNN_SIMD_TYPE _r010 = MEGDNN_SIMD_LOADU(r01); | |||
MEGDNN_SIMD_TYPE _r001 = MEGDNN_SIMD_LOADU(r00 + 1); | |||
MEGDNN_SIMD_TYPE _r011 = MEGDNN_SIMD_LOADU(r01 + 1); | |||
MEGDNN_SIMD_TYPE _r100 = MEGDNN_SIMD_LOADU(r10); | |||
MEGDNN_SIMD_TYPE _r110 = MEGDNN_SIMD_LOADU(r11); | |||
MEGDNN_SIMD_TYPE _r101 = MEGDNN_SIMD_LOADU(r10 + 1); | |||
MEGDNN_SIMD_TYPE _r111 = MEGDNN_SIMD_LOADU(r11 + 1); | |||
MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr); | |||
_sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r000, | |||
MEGDNN_SIMD_GET_LOW(_k0), 0); | |||
_sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r001, | |||
MEGDNN_SIMD_GET_LOW(_k0), 1); | |||
_sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r010, | |||
MEGDNN_SIMD_GET_HIGH(_k0), 0); | |||
_sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r011, | |||
MEGDNN_SIMD_GET_HIGH(_k0), 1); | |||
_sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r100, | |||
MEGDNN_SIMD_GET_LOW(_k1), 0); | |||
_sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r101, | |||
MEGDNN_SIMD_GET_LOW(_k1), 1); | |||
_sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r110, | |||
MEGDNN_SIMD_GET_HIGH(_k1), 0); | |||
_sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r111, | |||
MEGDNN_SIMD_GET_HIGH(_k1), 1); | |||
MEGDNN_SIMD_STOREU(outptr, _sum); | |||
r00 += 4; | |||
r01 += 4; | |||
r10 += 4; | |||
r11 += 4; | |||
outptr += 4; | |||
} | |||
r00 += tail_step; | |||
r01 += tail_step; | |||
r10 += tail_step; | |||
r11 += tail_step; | |||
} | |||
} | |||
for (; ic < IC; ic++) { | |||
const float* src_ptr = src + IW * IH * ic; | |||
float* outptr = dst; | |||
const float* r0 = src_ptr; | |||
const float* r1 = src_ptr + IW; | |||
const float* k0 = filter + ic * 4; | |||
MEGDNN_SIMD_TYPE _k0 = MEGDNN_SIMD_SET1(k0[0]); | |||
MEGDNN_SIMD_TYPE _k1 = MEGDNN_SIMD_SET1(k0[1]); | |||
MEGDNN_SIMD_TYPE _k2 = MEGDNN_SIMD_SET1(k0[2]); | |||
MEGDNN_SIMD_TYPE _k3 = MEGDNN_SIMD_SET1(k0[3]); | |||
rep(h, OH) { | |||
int width = OW >> 2; | |||
rep(i, width) { | |||
MEGDNN_SIMD_TYPE _r00 = MEGDNN_SIMD_LOADU(r0); | |||
MEGDNN_SIMD_TYPE _r10 = MEGDNN_SIMD_LOADU(r1); | |||
MEGDNN_SIMD_TYPE _r01 = MEGDNN_SIMD_LOADU(r0 + 1); | |||
MEGDNN_SIMD_TYPE _r11 = MEGDNN_SIMD_LOADU(r1 + 1); | |||
MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr); | |||
MEGDNN_SIMD_TYPE _sum2; | |||
_sum = MEGDNN_SIMD_FMADD(_r00, _k0, _sum); | |||
_sum2 = MEGDNN_SIMD_MUL(_r01, _k1); | |||
_sum = MEGDNN_SIMD_FMADD(_r10, _k2, _sum); | |||
_sum2 = MEGDNN_SIMD_FMADD(_r11, _k3, _sum2); | |||
_sum = MEGDNN_SIMD_ADD(_sum, _sum2); | |||
MEGDNN_SIMD_STOREU(outptr, _sum); | |||
r0 += 4; | |||
r1 += 4; | |||
outptr += 4; | |||
} | |||
r0 += tail_step; | |||
r1 += tail_step; | |||
} | |||
} | |||
} | |||
void conv_stride1::do_conv_3x3_stride1(const float* src, const float* filter, float* dst, | |||
size_t IH, size_t IW, size_t OH, size_t OW, | |||
size_t IC) { | |||
const size_t tail_step = IW - OW; | |||
rep(ic, IC) { | |||
const float* src_ptr = src + IW * IH * ic; | |||
float* outptr = dst; | |||
float* outptr2 = outptr + OW; | |||
const float* r0 = src_ptr; | |||
const float* r1 = src_ptr + IW; | |||
const float* r2 = src_ptr + IW * 2; | |||
const float* r3 = src_ptr + IW * 3; | |||
const float* k0 = filter; | |||
const float* k1 = filter + 3; | |||
const float* k2 = filter + 5; | |||
MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(k0); | |||
MEGDNN_SIMD_TYPE _k3456 = MEGDNN_SIMD_LOADU(k1); | |||
MEGDNN_SIMD_TYPE _k5678 = MEGDNN_SIMD_LOADU(k2); | |||
MEGDNN_SIMD_TYPE _k6789 = MEGDNN_SIMD_EXT(_k5678, _k5678, 1); | |||
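        // _k5678 is loaded from filter + 5 so the 4-float load stays inside
        // the 9-element filter; rotating it once puts {k6, k7, k8} in lanes
        // 0-2 (lane 3 is unused).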
size_t h = 0; | |||
for (; h + 1 < OH; h += 2) { | |||
int width = OW >> 2; | |||
rep(i, width) { | |||
MEGDNN_SIMD_TYPE _sum1 = MEGDNN_SIMD_LOADU(outptr); | |||
MEGDNN_SIMD_TYPE _sum2 = MEGDNN_SIMD_SET1(0.f); | |||
MEGDNN_SIMD_TYPE _sum3 = MEGDNN_SIMD_LOADU(outptr2); | |||
MEGDNN_SIMD_TYPE _sum4 = MEGDNN_SIMD_SET1(0.f); | |||
MEGDNN_SIMD_TYPE _r00 = MEGDNN_SIMD_LOADU(r0); | |||
MEGDNN_SIMD_TYPE _r00n = MEGDNN_SIMD_LOADU(r0 + 4); | |||
MEGDNN_SIMD_TYPE _r01 = MEGDNN_SIMD_EXT(_r00, _r00n, 1); | |||
MEGDNN_SIMD_TYPE _r02 = MEGDNN_SIMD_EXT(_r00, _r00n, 2); | |||
MEGDNN_SIMD_TYPE _r10 = MEGDNN_SIMD_LOADU(r1); | |||
MEGDNN_SIMD_TYPE _r10n = MEGDNN_SIMD_LOADU(r1 + 4); | |||
MEGDNN_SIMD_TYPE _r11 = MEGDNN_SIMD_EXT(_r10, _r10n, 1); | |||
MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r10n, 2); | |||
MEGDNN_SIMD_TYPE _r20 = MEGDNN_SIMD_LOADU(r2); | |||
MEGDNN_SIMD_TYPE _r20n = MEGDNN_SIMD_LOADU(r2 + 4); | |||
MEGDNN_SIMD_TYPE _r21 = MEGDNN_SIMD_EXT(_r20, _r20n, 1); | |||
MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r20n, 2); | |||
MEGDNN_SIMD_TYPE _r30 = MEGDNN_SIMD_LOADU(r3); | |||
MEGDNN_SIMD_TYPE _r30n = MEGDNN_SIMD_LOADU(r3 + 4); | |||
MEGDNN_SIMD_TYPE _r31 = MEGDNN_SIMD_EXT(_r30, _r30n, 1); | |||
MEGDNN_SIMD_TYPE _r32 = MEGDNN_SIMD_EXT(_r30, _r30n, 2); | |||
_sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r00, _k0123, 0); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r01, _k0123, 1); | |||
_sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r02, _k0123, 2); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r10, _k3456, 0); | |||
_sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r11, _k3456, 1); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r12, _k3456, 2); | |||
_sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r20, _k6789, 0); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r21, _k6789, 1); | |||
_sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r22, _k6789, 2); | |||
_sum3 = MEGDNN_SIMD_FMA_LANE(_sum3, _r10, _k0123, 0); | |||
_sum4 = MEGDNN_SIMD_FMA_LANE(_sum4, _r11, _k0123, 1); | |||
_sum3 = MEGDNN_SIMD_FMA_LANE(_sum3, _r12, _k0123, 2); | |||
_sum4 = MEGDNN_SIMD_FMA_LANE(_sum4, _r20, _k3456, 0); | |||
_sum3 = MEGDNN_SIMD_FMA_LANE(_sum3, _r21, _k3456, 1); | |||
_sum4 = MEGDNN_SIMD_FMA_LANE(_sum4, _r22, _k3456, 2); | |||
_sum3 = MEGDNN_SIMD_FMA_LANE(_sum3, _r30, _k6789, 0); | |||
_sum4 = MEGDNN_SIMD_FMA_LANE(_sum4, _r31, _k6789, 1); | |||
_sum3 = MEGDNN_SIMD_FMA_LANE(_sum3, _r32, _k6789, 2); | |||
_sum1 = MEGDNN_SIMD_ADD(_sum1, _sum2); | |||
_sum3 = MEGDNN_SIMD_ADD(_sum3, _sum4); | |||
MEGDNN_SIMD_STOREU(outptr, _sum1); | |||
MEGDNN_SIMD_STOREU(outptr2, _sum3); | |||
r0 += 4; | |||
r1 += 4; | |||
r2 += 4; | |||
r3 += 4; | |||
outptr += 4; | |||
outptr2 += 4; | |||
} | |||
r0 += tail_step + IW; | |||
r1 += tail_step + IW; | |||
r2 += tail_step + IW; | |||
r3 += tail_step + IW; | |||
outptr += OW; | |||
outptr2 += OW; | |||
} | |||
for (; h < OH; h++) { | |||
int width = OW >> 2; | |||
rep(i, width) { | |||
MEGDNN_SIMD_TYPE _sum1 = MEGDNN_SIMD_LOADU(outptr); | |||
MEGDNN_SIMD_TYPE _sum2 = MEGDNN_SIMD_SET1(0.f); | |||
MEGDNN_SIMD_TYPE _r00 = MEGDNN_SIMD_LOADU(r0); | |||
MEGDNN_SIMD_TYPE _r00n = MEGDNN_SIMD_LOADU(r0 + 4); | |||
MEGDNN_SIMD_TYPE _r01 = MEGDNN_SIMD_EXT(_r00, _r00n, 1); | |||
MEGDNN_SIMD_TYPE _r02 = MEGDNN_SIMD_EXT(_r00, _r00n, 2); | |||
MEGDNN_SIMD_TYPE _r10 = MEGDNN_SIMD_LOADU(r1); | |||
MEGDNN_SIMD_TYPE _r10n = MEGDNN_SIMD_LOADU(r1 + 4); | |||
MEGDNN_SIMD_TYPE _r11 = MEGDNN_SIMD_EXT(_r10, _r10n, 1); | |||
MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r10n, 2); | |||
MEGDNN_SIMD_TYPE _r20 = MEGDNN_SIMD_LOADU(r2); | |||
MEGDNN_SIMD_TYPE _r20n = MEGDNN_SIMD_LOADU(r2 + 4); | |||
MEGDNN_SIMD_TYPE _r21 = MEGDNN_SIMD_EXT(_r20, _r20n, 1); | |||
MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r20n, 2); | |||
_sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r00, _k0123, 0); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r01, _k0123, 1); | |||
_sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r02, _k0123, 2); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r10, _k3456, 0); | |||
_sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r11, _k3456, 1); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r12, _k3456, 2); | |||
_sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r20, _k6789, 0); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r21, _k6789, 1); | |||
_sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r22, _k6789, 2); | |||
_sum1 = MEGDNN_SIMD_ADD(_sum1, _sum2); | |||
MEGDNN_SIMD_STOREU(outptr, _sum1); | |||
r0 += 4; | |||
r1 += 4; | |||
r2 += 4; | |||
outptr += 4; | |||
} | |||
r0 += tail_step; | |||
r1 += tail_step; | |||
r2 += tail_step; | |||
} | |||
filter += 9; | |||
} | |||
} | |||
void conv_stride1::do_conv_5x5_stride1(const float* src, const float* filter, float* dst, | |||
size_t IH, size_t IW, size_t OH, size_t OW, | |||
size_t IC) { | |||
const size_t tail_step = IW - OW; | |||
rep(ic, IC) { | |||
const float* src_ptr = src + IW * IH * ic; | |||
float* outptr = dst; | |||
float* outptr2 = outptr + OW; | |||
const float* r0 = src_ptr; | |||
const float* r1 = src_ptr + IW; | |||
const float* r2 = src_ptr + IW * 2; | |||
const float* r3 = src_ptr + IW * 3; | |||
const float* r4 = src_ptr + IW * 4; | |||
const float* r5 = src_ptr + IW * 5; | |||
MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(filter); | |||
MEGDNN_SIMD_TYPE _k4567 = MEGDNN_SIMD_LOADU(filter + 4); | |||
MEGDNN_SIMD_TYPE _k891011 = MEGDNN_SIMD_LOADU(filter + 8); | |||
MEGDNN_SIMD_TYPE _k12131415 = MEGDNN_SIMD_LOADU(filter + 12); | |||
MEGDNN_SIMD_TYPE _k16171819 = MEGDNN_SIMD_LOADU(filter + 16); | |||
MEGDNN_SIMD_TYPE _k20212223 = MEGDNN_SIMD_LOADU(filter + 20); | |||
MEGDNN_SIMD_TYPE _k24242424 = MEGDNN_SIMD_SET1(filter[24]); | |||
size_t h = 0; | |||
for (; h + 1 < OH; h += 2) { | |||
int width = OW >> 2; | |||
rep(i, width) { | |||
MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr); | |||
MEGDNN_SIMD_TYPE _sum2 = MEGDNN_SIMD_LOADU(outptr2); | |||
MEGDNN_SIMD_TYPE _r00 = MEGDNN_SIMD_LOADU(r0); | |||
MEGDNN_SIMD_TYPE _r04 = MEGDNN_SIMD_LOADU(r0 + 4); | |||
MEGDNN_SIMD_TYPE _r01 = MEGDNN_SIMD_EXT(_r00, _r04, 1); | |||
MEGDNN_SIMD_TYPE _r02 = MEGDNN_SIMD_EXT(_r00, _r04, 2); | |||
MEGDNN_SIMD_TYPE _r03 = MEGDNN_SIMD_EXT(_r00, _r04, 3); | |||
MEGDNN_SIMD_TYPE _r10 = MEGDNN_SIMD_LOADU(r1); | |||
MEGDNN_SIMD_TYPE _r14 = MEGDNN_SIMD_LOADU(r1 + 4); | |||
MEGDNN_SIMD_TYPE _r11 = MEGDNN_SIMD_EXT(_r10, _r14, 1); | |||
MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r14, 2); | |||
MEGDNN_SIMD_TYPE _r13 = MEGDNN_SIMD_EXT(_r10, _r14, 3); | |||
MEGDNN_SIMD_TYPE _r20 = MEGDNN_SIMD_LOADU(r2); | |||
MEGDNN_SIMD_TYPE _r24 = MEGDNN_SIMD_LOADU(r2 + 4); | |||
MEGDNN_SIMD_TYPE _r21 = MEGDNN_SIMD_EXT(_r20, _r24, 1); | |||
MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r24, 2); | |||
MEGDNN_SIMD_TYPE _r23 = MEGDNN_SIMD_EXT(_r20, _r24, 3); | |||
MEGDNN_SIMD_TYPE _r30 = MEGDNN_SIMD_LOADU(r3); | |||
MEGDNN_SIMD_TYPE _r34 = MEGDNN_SIMD_LOADU(r3 + 4); | |||
MEGDNN_SIMD_TYPE _r31 = MEGDNN_SIMD_EXT(_r30, _r34, 1); | |||
MEGDNN_SIMD_TYPE _r32 = MEGDNN_SIMD_EXT(_r30, _r34, 2); | |||
MEGDNN_SIMD_TYPE _r33 = MEGDNN_SIMD_EXT(_r30, _r34, 3); | |||
MEGDNN_SIMD_TYPE _r40 = MEGDNN_SIMD_LOADU(r4); | |||
MEGDNN_SIMD_TYPE _r44 = MEGDNN_SIMD_LOADU(r4 + 4); | |||
MEGDNN_SIMD_TYPE _r41 = MEGDNN_SIMD_EXT(_r40, _r44, 1); | |||
MEGDNN_SIMD_TYPE _r42 = MEGDNN_SIMD_EXT(_r40, _r44, 2); | |||
MEGDNN_SIMD_TYPE _r43 = MEGDNN_SIMD_EXT(_r40, _r44, 3); | |||
MEGDNN_SIMD_TYPE _r50 = MEGDNN_SIMD_LOADU(r5); | |||
MEGDNN_SIMD_TYPE _r54 = MEGDNN_SIMD_LOADU(r5 + 4); | |||
MEGDNN_SIMD_TYPE _r51 = MEGDNN_SIMD_EXT(_r50, _r54, 1); | |||
MEGDNN_SIMD_TYPE _r52 = MEGDNN_SIMD_EXT(_r50, _r54, 2); | |||
MEGDNN_SIMD_TYPE _r53 = MEGDNN_SIMD_EXT(_r50, _r54, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r00, _k0123, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r01, _k0123, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r02, _k0123, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r03, _k0123, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r04, _k4567, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r10, _k4567, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r11, _k4567, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r12, _k4567, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r13, _k891011, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r14, _k891011, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r20, _k891011, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r21, _k891011, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r22, _k12131415, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r23, _k12131415, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r24, _k12131415, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r30, _k12131415, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r31, _k16171819, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r32, _k16171819, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r33, _k16171819, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r34, _k16171819, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r40, _k20212223, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r41, _k20212223, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r42, _k20212223, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r43, _k20212223, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r44, _k24242424, 0); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r10, _k0123, 0); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r11, _k0123, 1); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r12, _k0123, 2); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r13, _k0123, 3); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r14, _k4567, 0); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r20, _k4567, 1); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r21, _k4567, 2); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r22, _k4567, 3); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r23, _k891011, 0); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r24, _k891011, 1); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r30, _k891011, 2); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r31, _k891011, 3); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r32, _k12131415, 0); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r33, _k12131415, 1); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r34, _k12131415, 2); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r40, _k12131415, 3); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r41, _k16171819, 0); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r42, _k16171819, 1); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r43, _k16171819, 2); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r44, _k16171819, 3); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r50, _k20212223, 0); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r51, _k20212223, 1); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r52, _k20212223, 2); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r53, _k20212223, 3); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r54, _k24242424, 0); | |||
MEGDNN_SIMD_STOREU(outptr, _sum); | |||
MEGDNN_SIMD_STOREU(outptr2, _sum2); | |||
r0 += 4; | |||
r1 += 4; | |||
r2 += 4; | |||
r3 += 4; | |||
r4 += 4; | |||
r5 += 4; | |||
outptr += 4; | |||
outptr2 += 4; | |||
} | |||
r0 += tail_step + IW; | |||
r1 += tail_step + IW; | |||
r2 += tail_step + IW; | |||
r3 += tail_step + IW; | |||
r4 += tail_step + IW; | |||
r5 += tail_step + IW; | |||
outptr += OW; | |||
outptr2 += OW; | |||
} | |||
for (; h < OH; h++) { | |||
int width = OW >> 2; | |||
rep(i, width) { | |||
MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr); | |||
MEGDNN_SIMD_TYPE _r00 = MEGDNN_SIMD_LOADU(r0); | |||
MEGDNN_SIMD_TYPE _r04 = MEGDNN_SIMD_LOADU(r0 + 4); | |||
MEGDNN_SIMD_TYPE _r01 = MEGDNN_SIMD_EXT(_r00, _r04, 1); | |||
MEGDNN_SIMD_TYPE _r02 = MEGDNN_SIMD_EXT(_r00, _r04, 2); | |||
MEGDNN_SIMD_TYPE _r03 = MEGDNN_SIMD_EXT(_r00, _r04, 3); | |||
MEGDNN_SIMD_TYPE _r10 = MEGDNN_SIMD_LOADU(r1); | |||
MEGDNN_SIMD_TYPE _r14 = MEGDNN_SIMD_LOADU(r1 + 4); | |||
MEGDNN_SIMD_TYPE _r11 = MEGDNN_SIMD_EXT(_r10, _r14, 1); | |||
MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r14, 2); | |||
MEGDNN_SIMD_TYPE _r13 = MEGDNN_SIMD_EXT(_r10, _r14, 3); | |||
MEGDNN_SIMD_TYPE _r20 = MEGDNN_SIMD_LOADU(r2); | |||
MEGDNN_SIMD_TYPE _r24 = MEGDNN_SIMD_LOADU(r2 + 4); | |||
MEGDNN_SIMD_TYPE _r21 = MEGDNN_SIMD_EXT(_r20, _r24, 1); | |||
MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r24, 2); | |||
MEGDNN_SIMD_TYPE _r23 = MEGDNN_SIMD_EXT(_r20, _r24, 3); | |||
MEGDNN_SIMD_TYPE _r30 = MEGDNN_SIMD_LOADU(r3); | |||
MEGDNN_SIMD_TYPE _r34 = MEGDNN_SIMD_LOADU(r3 + 4); | |||
MEGDNN_SIMD_TYPE _r31 = MEGDNN_SIMD_EXT(_r30, _r34, 1); | |||
MEGDNN_SIMD_TYPE _r32 = MEGDNN_SIMD_EXT(_r30, _r34, 2); | |||
MEGDNN_SIMD_TYPE _r33 = MEGDNN_SIMD_EXT(_r30, _r34, 3); | |||
MEGDNN_SIMD_TYPE _r40 = MEGDNN_SIMD_LOADU(r4); | |||
MEGDNN_SIMD_TYPE _r44 = MEGDNN_SIMD_LOADU(r4 + 4); | |||
MEGDNN_SIMD_TYPE _r41 = MEGDNN_SIMD_EXT(_r40, _r44, 1); | |||
MEGDNN_SIMD_TYPE _r42 = MEGDNN_SIMD_EXT(_r40, _r44, 2); | |||
MEGDNN_SIMD_TYPE _r43 = MEGDNN_SIMD_EXT(_r40, _r44, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r00, _k0123, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r01, _k0123, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r02, _k0123, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r03, _k0123, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r04, _k4567, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r10, _k4567, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r11, _k4567, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r12, _k4567, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r13, _k891011, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r14, _k891011, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r20, _k891011, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r21, _k891011, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r22, _k12131415, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r23, _k12131415, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r24, _k12131415, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r30, _k12131415, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r31, _k16171819, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r32, _k16171819, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r33, _k16171819, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r34, _k16171819, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r40, _k20212223, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r41, _k20212223, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r42, _k20212223, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r43, _k20212223, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r44, _k24242424, 0); | |||
MEGDNN_SIMD_STOREU(outptr, _sum); | |||
r0 += 4; | |||
r1 += 4; | |||
r2 += 4; | |||
r3 += 4; | |||
r4 += 4; | |||
outptr += 4; | |||
} | |||
r0 += tail_step; | |||
r1 += tail_step; | |||
r2 += tail_step; | |||
r3 += tail_step; | |||
r4 += tail_step; | |||
} | |||
filter += 25; | |||
} | |||
} | |||
void conv_stride1::do_conv_7x7_stride1(const float* src, const float* filter, float* dst, | |||
size_t IH, size_t IW, size_t OH, size_t OW, | |||
size_t IC) { | |||
const size_t tail_step = IW - OW; | |||
rep(ic, IC) { | |||
const float* src_ptr = src + IW * IH * ic; | |||
float* outptr = dst; | |||
const float* r0 = src_ptr; | |||
const float* r1 = src_ptr + IW; | |||
const float* r2 = src_ptr + IW * 2; | |||
const float* r3 = src_ptr + IW * 3; | |||
const float* r4 = src_ptr + IW * 4; | |||
const float* r5 = src_ptr + IW * 5; | |||
const float* r6 = src_ptr + IW * 6; | |||
const float* k0 = filter; | |||
const float* k1 = filter + 7; | |||
const float* k2 = filter + 14; | |||
const float* k3 = filter + 21; | |||
const float* k4 = filter + 28; | |||
const float* k5 = filter + 35; | |||
const float* k6 = filter + 42; | |||
        for (size_t h = 0; h < OH; h++) {
int width = OW >> 2; | |||
rep(i, width) { | |||
MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr); | |||
MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(k0); | |||
MEGDNN_SIMD_TYPE _k4567 = MEGDNN_SIMD_LOADU(k0 + 4); | |||
MEGDNN_SIMD_TYPE _r00 = MEGDNN_SIMD_LOADU(r0); // 0 1 2 3 | |||
MEGDNN_SIMD_TYPE _r04 = MEGDNN_SIMD_LOADU(r0 + 4); // 4 5 6 7 | |||
MEGDNN_SIMD_TYPE _r00n = | |||
MEGDNN_SIMD_LOADU(r0 + 8); // 8 9 10 11 | |||
MEGDNN_SIMD_TYPE _r01 = | |||
MEGDNN_SIMD_EXT(_r00, _r04, 1); // 1 2 3 4 | |||
MEGDNN_SIMD_TYPE _r02 = | |||
MEGDNN_SIMD_EXT(_r00, _r04, 2); // 2 3 4 5 | |||
MEGDNN_SIMD_TYPE _r03 = | |||
MEGDNN_SIMD_EXT(_r00, _r04, 3); // 3 4 5 6 | |||
MEGDNN_SIMD_TYPE _r05 = | |||
MEGDNN_SIMD_EXT(_r04, _r00n, 1); // 5 6 7 8 | |||
MEGDNN_SIMD_TYPE _r06 = | |||
MEGDNN_SIMD_EXT(_r04, _r00n, 2); // 6 7 8 9 | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r00, _k0123, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r01, _k0123, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r02, _k0123, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r03, _k0123, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r04, _k4567, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r05, _k4567, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r06, _k4567, 2); | |||
MEGDNN_SIMD_TYPE _k78910 = MEGDNN_SIMD_LOADU(k1); | |||
MEGDNN_SIMD_TYPE _k11121314 = MEGDNN_SIMD_LOADU(k1 + 4); | |||
MEGDNN_SIMD_TYPE _r10 = MEGDNN_SIMD_LOADU(r1); | |||
MEGDNN_SIMD_TYPE _r14 = MEGDNN_SIMD_LOADU(r1 + 4); | |||
MEGDNN_SIMD_TYPE _r10n = MEGDNN_SIMD_LOADU(r1 + 8); | |||
MEGDNN_SIMD_TYPE _r11 = MEGDNN_SIMD_EXT(_r10, _r14, 1); | |||
MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r14, 2); | |||
MEGDNN_SIMD_TYPE _r13 = MEGDNN_SIMD_EXT(_r10, _r14, 3); | |||
MEGDNN_SIMD_TYPE _r15 = MEGDNN_SIMD_EXT(_r14, _r10n, 1); | |||
MEGDNN_SIMD_TYPE _r16 = MEGDNN_SIMD_EXT(_r14, _r10n, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r10, _k78910, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r11, _k78910, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r12, _k78910, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r13, _k78910, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r14, _k11121314, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r15, _k11121314, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r16, _k11121314, 2); | |||
MEGDNN_SIMD_TYPE _k14151617 = MEGDNN_SIMD_LOADU(k2); | |||
MEGDNN_SIMD_TYPE _k18192021 = MEGDNN_SIMD_LOADU(k2 + 4); | |||
MEGDNN_SIMD_TYPE _r20 = MEGDNN_SIMD_LOADU(r2); | |||
MEGDNN_SIMD_TYPE _r24 = MEGDNN_SIMD_LOADU(r2 + 4); | |||
MEGDNN_SIMD_TYPE _r20n = MEGDNN_SIMD_LOADU(r2 + 8); | |||
MEGDNN_SIMD_TYPE _r21 = MEGDNN_SIMD_EXT(_r20, _r24, 1); | |||
MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r24, 2); | |||
MEGDNN_SIMD_TYPE _r23 = MEGDNN_SIMD_EXT(_r20, _r24, 3); | |||
MEGDNN_SIMD_TYPE _r25 = MEGDNN_SIMD_EXT(_r24, _r20n, 1); | |||
MEGDNN_SIMD_TYPE _r26 = MEGDNN_SIMD_EXT(_r24, _r20n, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r20, _k14151617, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r21, _k14151617, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r22, _k14151617, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r23, _k14151617, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r24, _k18192021, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r25, _k18192021, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r26, _k18192021, 2); | |||
MEGDNN_SIMD_TYPE _k21222324 = MEGDNN_SIMD_LOADU(k3); | |||
MEGDNN_SIMD_TYPE _k25262728 = MEGDNN_SIMD_LOADU(k3 + 4); | |||
MEGDNN_SIMD_TYPE _r30 = MEGDNN_SIMD_LOADU(r3); | |||
MEGDNN_SIMD_TYPE _r34 = MEGDNN_SIMD_LOADU(r3 + 4); | |||
MEGDNN_SIMD_TYPE _r30n = MEGDNN_SIMD_LOADU(r3 + 8); | |||
MEGDNN_SIMD_TYPE _r31 = MEGDNN_SIMD_EXT(_r30, _r34, 1); | |||
MEGDNN_SIMD_TYPE _r32 = MEGDNN_SIMD_EXT(_r30, _r34, 2); | |||
MEGDNN_SIMD_TYPE _r33 = MEGDNN_SIMD_EXT(_r30, _r34, 3); | |||
MEGDNN_SIMD_TYPE _r35 = MEGDNN_SIMD_EXT(_r34, _r30n, 1); | |||
MEGDNN_SIMD_TYPE _r36 = MEGDNN_SIMD_EXT(_r34, _r30n, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r30, _k21222324, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r31, _k21222324, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r32, _k21222324, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r33, _k21222324, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r34, _k25262728, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r35, _k25262728, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r36, _k25262728, 2); | |||
MEGDNN_SIMD_TYPE _k28293031 = MEGDNN_SIMD_LOADU(k4); | |||
MEGDNN_SIMD_TYPE _k32333435 = MEGDNN_SIMD_LOADU(k4 + 4); | |||
MEGDNN_SIMD_TYPE _r40 = MEGDNN_SIMD_LOADU(r4); | |||
MEGDNN_SIMD_TYPE _r44 = MEGDNN_SIMD_LOADU(r4 + 4); | |||
MEGDNN_SIMD_TYPE _r40n = MEGDNN_SIMD_LOADU(r4 + 8); | |||
MEGDNN_SIMD_TYPE _r41 = MEGDNN_SIMD_EXT(_r40, _r44, 1); | |||
MEGDNN_SIMD_TYPE _r42 = MEGDNN_SIMD_EXT(_r40, _r44, 2); | |||
MEGDNN_SIMD_TYPE _r43 = MEGDNN_SIMD_EXT(_r40, _r44, 3); | |||
MEGDNN_SIMD_TYPE _r45 = MEGDNN_SIMD_EXT(_r44, _r40n, 1); | |||
MEGDNN_SIMD_TYPE _r46 = MEGDNN_SIMD_EXT(_r44, _r40n, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r40, _k28293031, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r41, _k28293031, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r42, _k28293031, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r43, _k28293031, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r44, _k32333435, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r45, _k32333435, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r46, _k32333435, 2); | |||
MEGDNN_SIMD_TYPE _k35363738 = MEGDNN_SIMD_LOADU(k5); | |||
MEGDNN_SIMD_TYPE _k39404142 = MEGDNN_SIMD_LOADU(k5 + 4); | |||
MEGDNN_SIMD_TYPE _r50 = MEGDNN_SIMD_LOADU(r5); | |||
MEGDNN_SIMD_TYPE _r54 = MEGDNN_SIMD_LOADU(r5 + 4); | |||
MEGDNN_SIMD_TYPE _r50n = MEGDNN_SIMD_LOADU(r5 + 8); | |||
MEGDNN_SIMD_TYPE _r51 = MEGDNN_SIMD_EXT(_r50, _r54, 1); | |||
MEGDNN_SIMD_TYPE _r52 = MEGDNN_SIMD_EXT(_r50, _r54, 2); | |||
MEGDNN_SIMD_TYPE _r53 = MEGDNN_SIMD_EXT(_r50, _r54, 3); | |||
MEGDNN_SIMD_TYPE _r55 = MEGDNN_SIMD_EXT(_r54, _r50n, 1); | |||
MEGDNN_SIMD_TYPE _r56 = MEGDNN_SIMD_EXT(_r54, _r50n, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r50, _k35363738, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r51, _k35363738, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r52, _k35363738, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r53, _k35363738, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r54, _k39404142, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r55, _k39404142, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r56, _k39404142, 2); | |||
MEGDNN_SIMD_TYPE _k42434445 = MEGDNN_SIMD_LOADU(k6); | |||
MEGDNN_SIMD_TYPE _k46474849 = MEGDNN_SIMD_LOADU(k6 + 4); | |||
MEGDNN_SIMD_TYPE _r60 = MEGDNN_SIMD_LOADU(r6); | |||
MEGDNN_SIMD_TYPE _r64 = MEGDNN_SIMD_LOADU(r6 + 4); | |||
MEGDNN_SIMD_TYPE _r60n = MEGDNN_SIMD_LOADU(r6 + 8); | |||
MEGDNN_SIMD_TYPE _r61 = MEGDNN_SIMD_EXT(_r60, _r64, 1); | |||
MEGDNN_SIMD_TYPE _r62 = MEGDNN_SIMD_EXT(_r60, _r64, 2); | |||
MEGDNN_SIMD_TYPE _r63 = MEGDNN_SIMD_EXT(_r60, _r64, 3); | |||
MEGDNN_SIMD_TYPE _r65 = MEGDNN_SIMD_EXT(_r64, _r60n, 1); | |||
MEGDNN_SIMD_TYPE _r66 = MEGDNN_SIMD_EXT(_r64, _r60n, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r60, _k42434445, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r61, _k42434445, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r62, _k42434445, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r63, _k42434445, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r64, _k46474849, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r65, _k46474849, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r66, _k46474849, 2); | |||
MEGDNN_SIMD_STOREU(outptr, _sum); | |||
r0 += 4; | |||
r1 += 4; | |||
r2 += 4; | |||
r3 += 4; | |||
r4 += 4; | |||
r5 += 4; | |||
r6 += 4; | |||
outptr += 4; | |||
} | |||
r0 += tail_step; | |||
r1 += tail_step; | |||
r2 += tail_step; | |||
r3 += tail_step; | |||
r4 += tail_step; | |||
r5 += tail_step; | |||
r6 += tail_step; | |||
} | |||
filter += 49; | |||
} | |||
} | |||
#include "src/common/simd_macro/epilogue.h" | |||
// vim: syntax=cpp.doxygen |
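// Hedged scalar sketch of what the do_conv_KxK_stride1 kernels above compute:
// a valid KxK correlation summed over IC input channels and accumulated into
// a single-channel dst (bias and activation are applied by the postprocess
// helpers, not here):
static inline void conv_stride1_scalar_ref(const float* src, const float* filter,
                                           float* dst, size_t IH, size_t IW,
                                           size_t OH, size_t OW, size_t IC,
                                           size_t K) {
    for (size_t ic = 0; ic < IC; ++ic) {
        const float* s = src + ic * IH * IW;   // channel ic of the input
        const float* f = filter + ic * K * K;  // channel ic of the filter
        for (size_t oh = 0; oh < OH; ++oh)
            for (size_t ow = 0; ow < OW; ++ow) {
                float acc = dst[oh * OW + ow];
                for (size_t kh = 0; kh < K; ++kh)
                    for (size_t kw = 0; kw < K; ++kw)
                        acc += s[(oh + kh) * IW + (ow + kw)] * f[kh * K + kw];
                dst[oh * OW + ow] = acc;
            }
    }
}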
@@ -0,0 +1,34 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/fp32/do_conv_stride1.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include <cstddef> | |||
namespace megdnn { | |||
namespace arm_common { | |||
namespace fp32 { | |||
namespace conv_stride1 { | |||
void do_conv_2x2_stride1(const float* src, const float* filter, float* dst, | |||
size_t IH, size_t IW, size_t OH, size_t OW, size_t IC); | |||
void do_conv_3x3_stride1(const float* src, const float* filter, float* dst, | |||
size_t IH, size_t IW, size_t OH, size_t OW, size_t IC); | |||
void do_conv_5x5_stride1(const float* src, const float* filter, float* dst, | |||
size_t IH, size_t IW, size_t OH, size_t OW, size_t IC); | |||
void do_conv_7x7_stride1(const float* src, const float* filter, float* dst, | |||
size_t IH, size_t IW, size_t OH, size_t OW, size_t IC); | |||
} // namespace conv_stride1 | |||
} // namespace fp32 | |||
} // namespace arm_common | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen | |||
@@ -0,0 +1,513 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/fp32/do_conv_stride2.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include <algorithm> | |||
#include "./do_conv_stride2.h" | |||
#include "midout.h" | |||
#include "src/arm_common/simd_macro/neon_helper.h" | |||
#include "src/arm_common/conv_bias/postprocess_helper.h" | |||
MIDOUT_DECL(megdnn_arm_common_conv_bias_f32_convs2) | |||
using namespace megdnn; | |||
using namespace arm_common; | |||
using namespace fp32; | |||
using namespace conv_stride2; | |||
using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam; | |||
using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam; | |||
void conv_stride2::do_conv_2x2_stride2(const float* src, const float* filter, float* dst, | |||
size_t IH, size_t IW, size_t OH, size_t OW, | |||
size_t IC) { | |||
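    //! equals 2 * (IW - OW): the stride-2 row consumed 2 * OW input columns,
    //! and the next output row begins two input rows further down.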
const size_t tail_step = IW - 2 * OW + IW; | |||
rep(ic, IC) { | |||
const float* src_ptr = src + IW * IH * ic; | |||
float* outptr = dst; | |||
const float* r0 = src_ptr; | |||
const float* r1 = src_ptr + IW; | |||
const float* k0 = filter; | |||
MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(k0); | |||
rep(h, OH) { | |||
int nn = OW >> 2; | |||
rep(i, nn) { | |||
MEGDNN_SIMD_TYPE _outp = MEGDNN_SIMD_LOADU(outptr); | |||
MEGDNN_SIMD_TYPE2 _r0 = MEGDNN_SIMD_LOAD2(r0); | |||
MEGDNN_SIMD_TYPE _r00 = _r0.val[0]; // 0 2 4 6 | |||
MEGDNN_SIMD_TYPE _r01 = _r0.val[1]; // 1 3 5 7 | |||
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r00, _k0123, 0); | |||
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r01, _k0123, 1); | |||
MEGDNN_SIMD_TYPE2 _r1 = MEGDNN_SIMD_LOAD2(r1); | |||
MEGDNN_SIMD_TYPE _r10 = _r1.val[0]; | |||
MEGDNN_SIMD_TYPE _r11 = _r1.val[1]; | |||
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r10, _k0123, 2); | |||
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r11, _k0123, 3); | |||
MEGDNN_SIMD_STOREU(outptr, _outp); | |||
r0 += 8; | |||
r1 += 8; | |||
outptr += 4; | |||
} | |||
r0 += tail_step; | |||
r1 += tail_step; | |||
} | |||
filter += 4; | |||
} | |||
} | |||
void conv_stride2::do_conv_3x3_stride2(const float* src, const float* filter, float* dst, | |||
size_t IH, size_t IW, size_t OH, size_t OW, | |||
size_t IC) { | |||
const size_t tail_step = IW - 2 * OW + IW; | |||
rep(ic, IC) { | |||
const float* src_ptr = src + IW * IH * ic; | |||
float* outptr = dst; | |||
const float* r0 = src_ptr; | |||
const float* r1 = src_ptr + IW; | |||
const float* r2 = src_ptr + IW * 2; | |||
const float* k0 = filter; | |||
const float* k1 = filter + 3; | |||
const float* k2 = filter + 5; | |||
MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(k0); | |||
MEGDNN_SIMD_TYPE _k3456 = MEGDNN_SIMD_LOADU(k1); | |||
MEGDNN_SIMD_TYPE _k5678 = MEGDNN_SIMD_LOADU(k2); | |||
MEGDNN_SIMD_TYPE _k6789 = MEGDNN_SIMD_EXT(_k5678, _k5678, 1); | |||
rep(h, OH) { | |||
int nn = OW >> 2; | |||
rep(i, nn) { | |||
MEGDNN_SIMD_TYPE _outp = MEGDNN_SIMD_LOADU(outptr); | |||
MEGDNN_SIMD_TYPE2 _r0 = MEGDNN_SIMD_LOAD2(r0); | |||
MEGDNN_SIMD_TYPE2 _r0n = MEGDNN_SIMD_LOAD2(r0 + 8); | |||
MEGDNN_SIMD_TYPE _r00 = _r0.val[0]; // 0 2 4 6 | |||
MEGDNN_SIMD_TYPE _r01 = _r0.val[1]; // 1 3 5 7 | |||
MEGDNN_SIMD_TYPE _r02 = | |||
MEGDNN_SIMD_EXT(_r00, _r0n.val[0], 1); // 2 4 6 8 | |||
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r00, _k0123, 0); | |||
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r01, _k0123, 1); | |||
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r02, _k0123, 2); | |||
MEGDNN_SIMD_TYPE2 _r1 = MEGDNN_SIMD_LOAD2(r1); | |||
MEGDNN_SIMD_TYPE2 _r1n = MEGDNN_SIMD_LOAD2(r1 + 8); | |||
MEGDNN_SIMD_TYPE _r10 = _r1.val[0]; | |||
MEGDNN_SIMD_TYPE _r11 = _r1.val[1]; | |||
MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r1n.val[0], 1); | |||
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r10, _k3456, 0); | |||
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r11, _k3456, 1); | |||
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r12, _k3456, 2); | |||
MEGDNN_SIMD_TYPE2 _r2 = MEGDNN_SIMD_LOAD2(r2); | |||
MEGDNN_SIMD_TYPE2 _r2n = MEGDNN_SIMD_LOAD2(r2 + 8); | |||
MEGDNN_SIMD_TYPE _r20 = _r2.val[0]; | |||
MEGDNN_SIMD_TYPE _r21 = _r2.val[1]; | |||
MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r2n.val[0], 1); | |||
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r20, _k6789, 0); | |||
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r21, _k6789, 1); | |||
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r22, _k6789, 2); | |||
MEGDNN_SIMD_STOREU(outptr, _outp); | |||
r0 += 8; | |||
r1 += 8; | |||
r2 += 8; | |||
outptr += 4; | |||
} | |||
r0 += tail_step; | |||
r1 += tail_step; | |||
r2 += tail_step; | |||
} | |||
filter += 9; | |||
} | |||
} | |||
void conv_stride2::do_conv_5x5_stride2(const float* src, const float* filter, float* dst, | |||
size_t IH, size_t IW, size_t OH, size_t OW, | |||
size_t IC) { | |||
const size_t tail_step = IW - 2 * OW + IW; | |||
rep(ic, IC) { | |||
const float* src_ptr = src + IW * IH * ic; | |||
float* outptr = dst; | |||
const float* r0 = src_ptr; | |||
const float* r1 = src_ptr + IW; | |||
const float* r2 = src_ptr + IW * 2; | |||
const float* r3 = src_ptr + IW * 3; | |||
const float* r4 = src_ptr + IW * 4; | |||
MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(filter); | |||
MEGDNN_SIMD_TYPE _k4567 = MEGDNN_SIMD_LOADU(filter + 4); | |||
MEGDNN_SIMD_TYPE _k891011 = MEGDNN_SIMD_LOADU(filter + 8); | |||
MEGDNN_SIMD_TYPE _k12131415 = MEGDNN_SIMD_LOADU(filter + 12); | |||
MEGDNN_SIMD_TYPE _k16171819 = MEGDNN_SIMD_LOADU(filter + 16); | |||
MEGDNN_SIMD_TYPE _k20212223 = MEGDNN_SIMD_LOADU(filter + 20); | |||
MEGDNN_SIMD_TYPE _k24242424 = MEGDNN_SIMD_SET1(filter[24]); | |||
        for (size_t h = 0; h < OH; h++) {
int nn = OW >> 2; | |||
rep(i, nn) { | |||
MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr); | |||
MEGDNN_SIMD_TYPE2 _r00_02461357 = MEGDNN_SIMD_LOAD2(r0); | |||
MEGDNN_SIMD_TYPE2 _r00nx2 = MEGDNN_SIMD_LOAD2(r0 + 8); | |||
MEGDNN_SIMD_TYPE _r0_8101214 = _r00nx2.val[0]; // 8 10 12 14 | |||
MEGDNN_SIMD_TYPE _r0_9111315 = _r00nx2.val[1]; // 9 11 13 15 | |||
MEGDNN_SIMD_TYPE _r00 = _r00_02461357.val[0]; // 0 2 4 6 | |||
MEGDNN_SIMD_TYPE _r01 = _r00_02461357.val[1]; // 1 3 5 7 | |||
MEGDNN_SIMD_TYPE _r02 = | |||
MEGDNN_SIMD_EXT(_r00, _r0_8101214, 1); // 2 4 6 8 | |||
MEGDNN_SIMD_TYPE _r03 = | |||
MEGDNN_SIMD_EXT(_r01, _r0_9111315, 1); // 3 5 7 9 | |||
MEGDNN_SIMD_TYPE _r04 = | |||
MEGDNN_SIMD_EXT(_r00, _r0_8101214, 2); // 4 6 8 10 | |||
MEGDNN_SIMD_TYPE2 _r10_02461357 = MEGDNN_SIMD_LOAD2(r1); | |||
MEGDNN_SIMD_TYPE2 _r10nx2 = MEGDNN_SIMD_LOAD2(r1 + 8); | |||
MEGDNN_SIMD_TYPE _r1_8101214 = _r10nx2.val[0]; | |||
MEGDNN_SIMD_TYPE _r1_9111315 = _r10nx2.val[1]; | |||
MEGDNN_SIMD_TYPE _r10 = _r10_02461357.val[0]; | |||
MEGDNN_SIMD_TYPE _r11 = _r10_02461357.val[1]; | |||
MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r1_8101214, 1); | |||
MEGDNN_SIMD_TYPE _r13 = MEGDNN_SIMD_EXT(_r11, _r1_9111315, 1); | |||
MEGDNN_SIMD_TYPE _r14 = MEGDNN_SIMD_EXT(_r10, _r1_8101214, 2); | |||
MEGDNN_SIMD_TYPE2 _r20_02461357 = MEGDNN_SIMD_LOAD2(r2); | |||
MEGDNN_SIMD_TYPE2 _r20nx2 = MEGDNN_SIMD_LOAD2(r2 + 8); | |||
MEGDNN_SIMD_TYPE _r2_8101214 = _r20nx2.val[0]; | |||
MEGDNN_SIMD_TYPE _r2_9111315 = _r20nx2.val[1]; | |||
MEGDNN_SIMD_TYPE _r20 = _r20_02461357.val[0]; | |||
MEGDNN_SIMD_TYPE _r21 = _r20_02461357.val[1]; | |||
MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r2_8101214, 1); | |||
MEGDNN_SIMD_TYPE _r23 = MEGDNN_SIMD_EXT(_r21, _r2_9111315, 1); | |||
MEGDNN_SIMD_TYPE _r24 = MEGDNN_SIMD_EXT(_r20, _r2_8101214, 2); | |||
MEGDNN_SIMD_TYPE2 _r30_02461357 = MEGDNN_SIMD_LOAD2(r3); | |||
MEGDNN_SIMD_TYPE2 _r30nx2 = MEGDNN_SIMD_LOAD2(r3 + 8); | |||
MEGDNN_SIMD_TYPE _r3_8101214 = _r30nx2.val[0]; | |||
MEGDNN_SIMD_TYPE _r3_9111315 = _r30nx2.val[1]; | |||
MEGDNN_SIMD_TYPE _r30 = _r30_02461357.val[0]; | |||
MEGDNN_SIMD_TYPE _r31 = _r30_02461357.val[1]; | |||
MEGDNN_SIMD_TYPE _r32 = MEGDNN_SIMD_EXT(_r30, _r3_8101214, 1); | |||
MEGDNN_SIMD_TYPE _r33 = MEGDNN_SIMD_EXT(_r31, _r3_9111315, 1); | |||
MEGDNN_SIMD_TYPE _r34 = MEGDNN_SIMD_EXT(_r30, _r3_8101214, 2); | |||
MEGDNN_SIMD_TYPE2 _r40_02461357 = MEGDNN_SIMD_LOAD2(r4); | |||
MEGDNN_SIMD_TYPE2 _r40nx2 = MEGDNN_SIMD_LOAD2(r4 + 8); | |||
MEGDNN_SIMD_TYPE _r4_8101214 = _r40nx2.val[0]; | |||
MEGDNN_SIMD_TYPE _r4_9111315 = _r40nx2.val[1]; | |||
MEGDNN_SIMD_TYPE _r40 = _r40_02461357.val[0]; | |||
MEGDNN_SIMD_TYPE _r41 = _r40_02461357.val[1]; | |||
MEGDNN_SIMD_TYPE _r42 = MEGDNN_SIMD_EXT(_r40, _r4_8101214, 1); | |||
MEGDNN_SIMD_TYPE _r43 = MEGDNN_SIMD_EXT(_r41, _r4_9111315, 1); | |||
MEGDNN_SIMD_TYPE _r44 = MEGDNN_SIMD_EXT(_r40, _r4_8101214, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r00, _k0123, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r01, _k0123, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r02, _k0123, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r03, _k0123, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r04, _k4567, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r10, _k4567, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r11, _k4567, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r12, _k4567, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r13, _k891011, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r14, _k891011, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r20, _k891011, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r21, _k891011, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r22, _k12131415, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r23, _k12131415, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r24, _k12131415, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r30, _k12131415, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r31, _k16171819, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r32, _k16171819, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r33, _k16171819, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r34, _k16171819, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r40, _k20212223, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r41, _k20212223, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r42, _k20212223, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r43, _k20212223, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r44, _k24242424, 0); | |||
MEGDNN_SIMD_STOREU(outptr, _sum); | |||
r0 += 8; | |||
r1 += 8; | |||
r2 += 8; | |||
r3 += 8; | |||
r4 += 8; | |||
outptr += 4; | |||
} | |||
r0 += tail_step; | |||
r1 += tail_step; | |||
r2 += tail_step; | |||
r3 += tail_step; | |||
r4 += tail_step; | |||
} | |||
filter += 25; | |||
} | |||
} | |||
void conv_stride2::do_conv_7x7_stride2(const float* src, const float* filter, float* dst, | |||
size_t IH, size_t IW, size_t OH, size_t OW, | |||
size_t IC) { | |||
const size_t tail_step = IW - 2 * OW + IW; | |||
rep(ic, IC) { | |||
const float* src_ptr = src + IW * IH * ic; | |||
float* outptr = dst; | |||
const float* r0 = src_ptr; | |||
const float* r1 = src_ptr + IW; | |||
const float* r2 = src_ptr + IW * 2; | |||
const float* r3 = src_ptr + IW * 3; | |||
const float* r4 = src_ptr + IW * 4; | |||
const float* r5 = src_ptr + IW * 5; | |||
const float* r6 = src_ptr + IW * 6; | |||
const float* k0 = filter; | |||
const float* k1 = filter + 7; | |||
const float* k2 = filter + 14; | |||
const float* k3 = filter + 21; | |||
const float* k4 = filter + 28; | |||
const float* k5 = filter + 35; | |||
const float* k6 = filter + 42; | |||
        for (size_t h = 0; h < OH; h++) {
int nn = OW >> 2; | |||
rep(i, nn) { | |||
MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr); | |||
MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(k0); | |||
MEGDNN_SIMD_TYPE _k4567 = MEGDNN_SIMD_LOADU(k0 + 4); | |||
MEGDNN_SIMD_TYPE2 _r00_02461357 = MEGDNN_SIMD_LOAD2(r0); | |||
MEGDNN_SIMD_TYPE2 _r00nx2 = MEGDNN_SIMD_LOAD2(r0 + 8); | |||
MEGDNN_SIMD_TYPE _r0_8101214 = _r00nx2.val[0]; // 8 10 12 14 | |||
MEGDNN_SIMD_TYPE _r0_9111315 = _r00nx2.val[1]; // 9 11 13 15 | |||
MEGDNN_SIMD_TYPE _r00 = _r00_02461357.val[0]; // 0 2 4 6 | |||
MEGDNN_SIMD_TYPE _r01 = _r00_02461357.val[1]; // 1 3 5 7 | |||
MEGDNN_SIMD_TYPE _r02 = | |||
MEGDNN_SIMD_EXT(_r00, _r0_8101214, 1); // 2 4 6 8 | |||
MEGDNN_SIMD_TYPE _r03 = | |||
MEGDNN_SIMD_EXT(_r01, _r0_9111315, 1); // 3 5 7 9 | |||
MEGDNN_SIMD_TYPE _r04 = | |||
MEGDNN_SIMD_EXT(_r00, _r0_8101214, 2); // 4 6 8 10 | |||
MEGDNN_SIMD_TYPE _r05 = | |||
MEGDNN_SIMD_EXT(_r01, _r0_9111315, 2); // 5 7 9 11 | |||
MEGDNN_SIMD_TYPE _r06 = | |||
MEGDNN_SIMD_EXT(_r00, _r0_8101214, 3); // 6 8 10 12 | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r00, _k0123, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r01, _k0123, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r02, _k0123, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r03, _k0123, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r04, _k4567, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r05, _k4567, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r06, _k4567, 2); | |||
MEGDNN_SIMD_TYPE _k78910 = MEGDNN_SIMD_LOADU(k1); | |||
MEGDNN_SIMD_TYPE _k11121314 = MEGDNN_SIMD_LOADU(k1 + 4); | |||
MEGDNN_SIMD_TYPE2 _r10_02461357 = MEGDNN_SIMD_LOAD2(r1); | |||
MEGDNN_SIMD_TYPE2 _r10nx2 = MEGDNN_SIMD_LOAD2(r1 + 8); | |||
MEGDNN_SIMD_TYPE _r1_8101214 = _r10nx2.val[0]; | |||
MEGDNN_SIMD_TYPE _r1_9111315 = _r10nx2.val[1]; | |||
MEGDNN_SIMD_TYPE _r10 = _r10_02461357.val[0]; | |||
MEGDNN_SIMD_TYPE _r11 = _r10_02461357.val[1]; | |||
MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r1_8101214, 1); | |||
MEGDNN_SIMD_TYPE _r13 = MEGDNN_SIMD_EXT(_r11, _r1_9111315, 1); | |||
MEGDNN_SIMD_TYPE _r14 = MEGDNN_SIMD_EXT(_r10, _r1_8101214, 2); | |||
MEGDNN_SIMD_TYPE _r15 = MEGDNN_SIMD_EXT(_r11, _r1_9111315, 2); | |||
MEGDNN_SIMD_TYPE _r16 = MEGDNN_SIMD_EXT(_r10, _r1_8101214, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r10, _k78910, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r11, _k78910, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r12, _k78910, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r13, _k78910, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r14, _k11121314, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r15, _k11121314, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r16, _k11121314, 2); | |||
MEGDNN_SIMD_TYPE _k14151617 = MEGDNN_SIMD_LOADU(k2); | |||
MEGDNN_SIMD_TYPE _k18192021 = MEGDNN_SIMD_LOADU(k2 + 4); | |||
MEGDNN_SIMD_TYPE2 _r20_02461357 = MEGDNN_SIMD_LOAD2(r2); | |||
MEGDNN_SIMD_TYPE2 _r20nx2 = MEGDNN_SIMD_LOAD2(r2 + 8); | |||
MEGDNN_SIMD_TYPE _r2_8101214 = _r20nx2.val[0]; | |||
MEGDNN_SIMD_TYPE _r2_9111315 = _r20nx2.val[1]; | |||
MEGDNN_SIMD_TYPE _r20 = _r20_02461357.val[0]; | |||
MEGDNN_SIMD_TYPE _r21 = _r20_02461357.val[1]; | |||
MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r2_8101214, 1); | |||
MEGDNN_SIMD_TYPE _r23 = MEGDNN_SIMD_EXT(_r21, _r2_9111315, 1); | |||
MEGDNN_SIMD_TYPE _r24 = MEGDNN_SIMD_EXT(_r20, _r2_8101214, 2); | |||
MEGDNN_SIMD_TYPE _r25 = MEGDNN_SIMD_EXT(_r21, _r2_9111315, 2); | |||
MEGDNN_SIMD_TYPE _r26 = MEGDNN_SIMD_EXT(_r20, _r2_8101214, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r20, _k14151617, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r21, _k14151617, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r22, _k14151617, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r23, _k14151617, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r24, _k18192021, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r25, _k18192021, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r26, _k18192021, 2); | |||
MEGDNN_SIMD_TYPE _k21222324 = MEGDNN_SIMD_LOADU(k3); | |||
MEGDNN_SIMD_TYPE _k25262728 = MEGDNN_SIMD_LOADU(k3 + 4); | |||
MEGDNN_SIMD_TYPE2 _r30_02461357 = MEGDNN_SIMD_LOAD2(r3); | |||
MEGDNN_SIMD_TYPE2 _r30nx2 = MEGDNN_SIMD_LOAD2(r3 + 8); | |||
MEGDNN_SIMD_TYPE _r3_8101214 = _r30nx2.val[0]; | |||
MEGDNN_SIMD_TYPE _r3_9111315 = _r30nx2.val[1]; | |||
MEGDNN_SIMD_TYPE _r30 = _r30_02461357.val[0]; | |||
MEGDNN_SIMD_TYPE _r31 = _r30_02461357.val[1]; | |||
MEGDNN_SIMD_TYPE _r32 = MEGDNN_SIMD_EXT(_r30, _r3_8101214, 1); | |||
MEGDNN_SIMD_TYPE _r33 = MEGDNN_SIMD_EXT(_r31, _r3_9111315, 1); | |||
MEGDNN_SIMD_TYPE _r34 = MEGDNN_SIMD_EXT(_r30, _r3_8101214, 2); | |||
MEGDNN_SIMD_TYPE _r35 = MEGDNN_SIMD_EXT(_r31, _r3_9111315, 2); | |||
MEGDNN_SIMD_TYPE _r36 = MEGDNN_SIMD_EXT(_r30, _r3_8101214, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r30, _k21222324, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r31, _k21222324, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r32, _k21222324, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r33, _k21222324, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r34, _k25262728, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r35, _k25262728, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r36, _k25262728, 2); | |||
MEGDNN_SIMD_TYPE _k28293031 = MEGDNN_SIMD_LOADU(k4); | |||
MEGDNN_SIMD_TYPE _k32333435 = MEGDNN_SIMD_LOADU(k4 + 4); | |||
MEGDNN_SIMD_TYPE2 _r40_02461357 = MEGDNN_SIMD_LOAD2(r4); | |||
MEGDNN_SIMD_TYPE2 _r40nx2 = MEGDNN_SIMD_LOAD2(r4 + 8); | |||
MEGDNN_SIMD_TYPE _r4_8101214 = _r40nx2.val[0]; | |||
MEGDNN_SIMD_TYPE _r4_9111315 = _r40nx2.val[1]; | |||
MEGDNN_SIMD_TYPE _r40 = _r40_02461357.val[0]; | |||
MEGDNN_SIMD_TYPE _r41 = _r40_02461357.val[1]; | |||
MEGDNN_SIMD_TYPE _r42 = MEGDNN_SIMD_EXT(_r40, _r4_8101214, 1); | |||
MEGDNN_SIMD_TYPE _r43 = MEGDNN_SIMD_EXT(_r41, _r4_9111315, 1); | |||
MEGDNN_SIMD_TYPE _r44 = MEGDNN_SIMD_EXT(_r40, _r4_8101214, 2); | |||
MEGDNN_SIMD_TYPE _r45 = MEGDNN_SIMD_EXT(_r41, _r4_9111315, 2); | |||
MEGDNN_SIMD_TYPE _r46 = MEGDNN_SIMD_EXT(_r40, _r4_8101214, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r40, _k28293031, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r41, _k28293031, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r42, _k28293031, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r43, _k28293031, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r44, _k32333435, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r45, _k32333435, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r46, _k32333435, 2); | |||
MEGDNN_SIMD_TYPE _k35363738 = MEGDNN_SIMD_LOADU(k5); | |||
MEGDNN_SIMD_TYPE _k39404142 = MEGDNN_SIMD_LOADU(k5 + 4); | |||
MEGDNN_SIMD_TYPE2 _r50_02461357 = MEGDNN_SIMD_LOAD2(r5); | |||
MEGDNN_SIMD_TYPE2 _r50nx2 = MEGDNN_SIMD_LOAD2(r5 + 8); | |||
MEGDNN_SIMD_TYPE _r5_8101214 = _r50nx2.val[0]; | |||
MEGDNN_SIMD_TYPE _r5_9111315 = _r50nx2.val[1]; | |||
MEGDNN_SIMD_TYPE _r50 = _r50_02461357.val[0]; | |||
MEGDNN_SIMD_TYPE _r51 = _r50_02461357.val[1]; | |||
MEGDNN_SIMD_TYPE _r52 = MEGDNN_SIMD_EXT(_r50, _r5_8101214, 1); | |||
MEGDNN_SIMD_TYPE _r53 = MEGDNN_SIMD_EXT(_r51, _r5_9111315, 1); | |||
MEGDNN_SIMD_TYPE _r54 = MEGDNN_SIMD_EXT(_r50, _r5_8101214, 2); | |||
MEGDNN_SIMD_TYPE _r55 = MEGDNN_SIMD_EXT(_r51, _r5_9111315, 2); | |||
MEGDNN_SIMD_TYPE _r56 = MEGDNN_SIMD_EXT(_r50, _r5_8101214, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r50, _k35363738, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r51, _k35363738, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r52, _k35363738, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r53, _k35363738, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r54, _k39404142, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r55, _k39404142, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r56, _k39404142, 2); | |||
MEGDNN_SIMD_TYPE _k42434445 = MEGDNN_SIMD_LOADU(k6); | |||
MEGDNN_SIMD_TYPE _k45464748 = MEGDNN_SIMD_LOADU(k6 + 3); | |||
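                // Loading from k6 + 3 keeps the 4-float read inside the
                // 49-element filter; lanes 1-3 carry {k46, k47, k48}, hence
                // the lane indices 1, 2, 3 used below.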
MEGDNN_SIMD_TYPE2 _r60_02461357 = MEGDNN_SIMD_LOAD2(r6); | |||
MEGDNN_SIMD_TYPE2 _r60nx2 = MEGDNN_SIMD_LOAD2(r6 + 8); | |||
MEGDNN_SIMD_TYPE _r6_8101214 = _r60nx2.val[0]; | |||
MEGDNN_SIMD_TYPE _r6_9111315 = _r60nx2.val[1]; | |||
MEGDNN_SIMD_TYPE _r60 = _r60_02461357.val[0]; | |||
MEGDNN_SIMD_TYPE _r61 = _r60_02461357.val[1]; | |||
MEGDNN_SIMD_TYPE _r62 = MEGDNN_SIMD_EXT(_r60, _r6_8101214, 1); | |||
MEGDNN_SIMD_TYPE _r63 = MEGDNN_SIMD_EXT(_r61, _r6_9111315, 1); | |||
MEGDNN_SIMD_TYPE _r64 = MEGDNN_SIMD_EXT(_r60, _r6_8101214, 2); | |||
MEGDNN_SIMD_TYPE _r65 = MEGDNN_SIMD_EXT(_r61, _r6_9111315, 2); | |||
MEGDNN_SIMD_TYPE _r66 = MEGDNN_SIMD_EXT(_r60, _r6_8101214, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r60, _k42434445, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r61, _k42434445, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r62, _k42434445, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r63, _k42434445, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r64, _k45464748, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r65, _k45464748, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r66, _k45464748, 3); | |||
MEGDNN_SIMD_STOREU(outptr, _sum); | |||
r0 += 8; | |||
r1 += 8; | |||
r2 += 8; | |||
r3 += 8; | |||
r4 += 8; | |||
r5 += 8; | |||
r6 += 8; | |||
outptr += 4; | |||
} | |||
r0 += tail_step; | |||
r1 += tail_step; | |||
r2 += tail_step; | |||
r3 += tail_step; | |||
r4 += tail_step; | |||
r5 += tail_step; | |||
r6 += tail_step; | |||
} | |||
filter += 49; | |||
} | |||
} | |||
// vim: syntax=cpp.doxygen |
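// Hedged scalar sketch of the stride-2 kernels above: identical to the
// stride-1 reference except that each output steps two input pixels in both
// dimensions:
static inline void conv_stride2_scalar_ref(const float* src, const float* filter,
                                           float* dst, size_t IH, size_t IW,
                                           size_t OH, size_t OW, size_t IC,
                                           size_t K) {
    for (size_t ic = 0; ic < IC; ++ic) {
        const float* s = src + ic * IH * IW;
        const float* f = filter + ic * K * K;
        for (size_t oh = 0; oh < OH; ++oh)
            for (size_t ow = 0; ow < OW; ++ow) {
                float acc = dst[oh * OW + ow];
                for (size_t kh = 0; kh < K; ++kh)
                    for (size_t kw = 0; kw < K; ++kw)
                        acc += s[(2 * oh + kh) * IW + (2 * ow + kw)] *
                               f[kh * K + kw];
                dst[oh * OW + ow] = acc;
            }
    }
}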
@@ -0,0 +1,32 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/fp32/do_conv_stride2.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "src/fallback/conv_bias/opr_impl.h" | |||
namespace megdnn { | |||
namespace arm_common { | |||
namespace fp32 { | |||
namespace conv_stride2 { | |||
void do_conv_2x2_stride2(const float* src, const float* filter, float* dst, | |||
size_t IH, size_t IW, size_t OH, size_t OW, size_t IC); | |||
void do_conv_3x3_stride2(const float* src, const float* filter, float* dst, | |||
size_t IH, size_t IW, size_t OH, size_t OW, size_t IC); | |||
void do_conv_5x5_stride2(const float* src, const float* filter, float* dst, | |||
size_t IH, size_t IW, size_t OH, size_t OW, size_t IC); | |||
void do_conv_7x7_stride2(const float* src, const float* filter, float* dst, | |||
size_t IH, size_t IW, size_t OH, size_t OW, size_t IC); | |||
} // namespace conv_stride2 | |||
} // namespace fp32 | |||
} // namespace arm_common | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,164 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/fp32/filter_transform.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/arm_common/utils.h" | |||
#include "src/common/unroll_macro.h" | |||
#include "src/common/utils.h" | |||
#include "src/arm_common/conv_bias/fp32/helper.h" | |||
#include "megdnn/opr_param_defs.h" | |||
namespace megdnn { | |||
namespace arm_common { | |||
template <param::MatrixMul::Format format=param::MatrixMul::Format::DEFAULT> | |||
struct FilterTransform6X3 { | |||
#define FILTER_TRANSFORM(d, wd) \ | |||
do { \ | |||
wd##0 = d##0; \ | |||
auto tmp0 = (d##0 + d##2) * -0.2222222f; \ | |||
auto tmp1 = d##1 * -0.2222222f; \ | |||
wd##1 = tmp0 + tmp1; \ | |||
wd##2 = tmp0 - tmp1; \ | |||
tmp0 = d##0 * 0.0111111f + d##2 * 0.0444444f; \ | |||
tmp1 = d##1 * 0.0222222f; \ | |||
wd##3 = tmp0 + tmp1; \ | |||
wd##4 = tmp0 - tmp1; \ | |||
tmp0 = d##0 * 0.7111111f + d##2 * 0.1777778f; \ | |||
tmp1 = d##1 * 0.3555556f; \ | |||
wd##5 = tmp0 + tmp1; \ | |||
wd##6 = tmp0 - tmp1; \ | |||
wd##7 = d##2; \ | |||
} while (0); | |||
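//! Illustrative note (not in the original source): FILTER_TRANSFORM applies
//! the 8x3 matrix G (listed in transform() below) to the column (d0, d1, d2),
//! writing the eight results to wd0..wd7; the full 2D transform G * g * GT
//! is obtained by a second pass over the transposed intermediate (the
//! TRANSPOSE_8x3 path on AArch64, the scalar cb macro otherwise).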
static void transform(const float* filter, float* filter_transform_buf, | |||
float* transform_mid_buf, size_t OC, size_t IC, | |||
size_t oc_start, size_t oc_end) { | |||
// Gg * GT | |||
// G | |||
// 1.0000000 0.0000000 0.0000000 | |||
// -0.2222222 -0.2222222 -0.2222222 | |||
// -0.2222222 0.2222222 -0.2222222 | |||
// 0.0111111 0.0222222 0.0444444 | |||
// 0.0111111 -0.0222222 0.0444444 | |||
// 0.7111111 0.3555556 0.1777778 | |||
// 0.7111111 -0.3555556 0.1777778 | |||
// 0.0000000 0.0000000 1.0000000 | |||
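// The decimals above are rounded fractions: 0.2222222 = 2/9,
// 0.0111111 = 1/90, 0.0222222 = 1/45, 0.0444444 = 2/45,
// 0.7111111 = 32/45, 0.3555556 = 16/45, 0.1777778 = 8/45.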
constexpr size_t alpha = 6 + 3 - 1; | |||
size_t OCB = OC / 4; | |||
size_t ICB = IC / 4; | |||
for (size_t oc = oc_start; oc < oc_end; oc++) { | |||
rep(ic, IC) { | |||
const float* fptr = filter + (oc * IC + ic) * 3 * 3; | |||
Vector<float, 4> g0 = Vector<float, 4>::load(fptr); | |||
Vector<float, 4> g1 = Vector<float, 4>::load(fptr + 3); | |||
Vector<float, 4> g2 = Vector<float, 4>::load(fptr + 6 - 1); | |||
float32x4_t zeros = vdupq_n_f32(0.0f); | |||
g2.value = vextq_f32(g2.value, zeros, 1); | |||
#define cb(i) Vector<float, 4> wd##i; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
#define cb(i) Vector<float, 8> wdt##i; | |||
UNROLL_CALL_NOWRAPPER(3, cb); | |||
#undef cb | |||
#define cb(i) Vector<float, 8> ret##i; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
FILTER_TRANSFORM(g, wd); | |||
size_t ocb = oc / 4; | |||
size_t oc4 = oc % 4; | |||
size_t icb = ic / 4; | |||
size_t ic4 = ic % 4; | |||
#if MEGDNN_AARCH64 | |||
TRANSPOSE_8x3(wd, wdt); | |||
FILTER_TRANSFORM(wdt, ret); | |||
#define cb(i) ret##i.save(transform_mid_buf + i * alpha); | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
rep(i, alpha) rep(j, alpha) { | |||
if (format == param::MatrixMul::Format::DEFAULT) { | |||
filter_transform_buf[(i * alpha + j) * OC * IC + | |||
ic * OC + oc] = | |||
transform_mid_buf[j * alpha + i]; | |||
} else { | |||
filter_transform_buf[(i * alpha + j) * OCB * ICB * 4 * | |||
4 + | |||
ocb * ICB * 4 * 4 + icb * 4 * 4 + | |||
ic4 * 4 + oc4] = | |||
transform_mid_buf[j * alpha + i]; | |||
} | |||
} | |||
#else | |||
#define cb(i) \ | |||
do { \ | |||
mid_buf1[0] = GET_VECTOR_ELEM(wd, i, 0); \ | |||
auto tmp0 = (GET_VECTOR_ELEM(wd, i, 0) + GET_VECTOR_ELEM(wd, i, 2)) * \ | |||
-0.2222222f; \ | |||
auto tmp1 = GET_VECTOR_ELEM(wd, i, 1) * -0.2222222f; \ | |||
mid_buf1[1] = tmp0 + tmp1; \ | |||
mid_buf1[2] = tmp0 - tmp1; \ | |||
tmp0 = GET_VECTOR_ELEM(wd, i, 0) * 0.0111111f + \ | |||
GET_VECTOR_ELEM(wd, i, 2) * 0.0444444f; \ | |||
tmp1 = GET_VECTOR_ELEM(wd, i, 1) * 0.0222222f; \ | |||
mid_buf1[3] = tmp0 + tmp1; \ | |||
mid_buf1[4] = tmp0 - tmp1; \ | |||
tmp0 = GET_VECTOR_ELEM(wd, i, 0) * 0.7111111f + \ | |||
GET_VECTOR_ELEM(wd, i, 2) * 0.1777778f; \ | |||
tmp1 = GET_VECTOR_ELEM(wd, i, 1) * 0.3555556f; \ | |||
mid_buf1[5] = tmp0 + tmp1; \ | |||
mid_buf1[6] = tmp0 - tmp1; \ | |||
mid_buf1[7] = GET_VECTOR_ELEM(wd, i, 2); \ | |||
mid_buf1 += 8; \ | |||
} while (0); | |||
#define GET_VECTOR_ELEM(s, i, idx) vgetq_lane_f32(CONCAT(s, i).value, idx) | |||
float* mid_buf1 = transform_mid_buf; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
mid_buf1 = transform_mid_buf; | |||
#undef cb | |||
rep(i, alpha) rep(j, alpha) { | |||
if (format == param::MatrixMul::Format::DEFAULT) { | |||
filter_transform_buf[(i * alpha + j) * OC * IC + | |||
ic * OC + oc] = | |||
transform_mid_buf[i * alpha + j]; | |||
} else { | |||
filter_transform_buf[(i * alpha + j) * OCB * ICB * 4 * | |||
4 + | |||
ocb * ICB * 4 * 4 + icb * 4 * 4 + | |||
ic4 * 4 + oc4] = | |||
transform_mid_buf[i * alpha + j]; | |||
} | |||
} | |||
#endif | |||
} | |||
} | |||
} | |||
}; | |||
#undef FILTER_TRANSFORM | |||
#undef GET_VECTOR_ELEM | |||
} // namespace arm_common | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,204 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/fp32/helper.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "src/common/unroll_macro.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
namespace megdnn { | |||
namespace arm_common { | |||
inline void transpose_4x4(const float* src, float* dst, int lda, int ldb) { | |||
float32x4x2_t a0, a1; | |||
a0.val[0] = vld1q_f32(src + 0 * lda); | |||
a0.val[1] = vld1q_f32(src + 1 * lda); | |||
a1.val[0] = vld1q_f32(src + 2 * lda); | |||
a1.val[1] = vld1q_f32(src + 3 * lda); | |||
float32x4x2_t b0 = vzipq_f32(a0.val[0], a1.val[0]); | |||
float32x4x2_t b1 = vzipq_f32(a0.val[1], a1.val[1]); | |||
float32x4x2_t c0 = vzipq_f32(b0.val[0], b1.val[0]); | |||
float32x4x2_t c1 = vzipq_f32(b0.val[1], b1.val[1]); | |||
vst1q_f32(dst + 0 * ldb, c0.val[0]); | |||
vst1q_f32(dst + 1 * ldb, c0.val[1]); | |||
vst1q_f32(dst + 2 * ldb, c1.val[0]); | |||
vst1q_f32(dst + 3 * ldb, c1.val[1]); | |||
} | |||
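// Illustrative trace (not in the original source) of the zip network above,
// with input rows {a0..a3}, {b0..b3}, {c0..c3}, {d0..d3}:
//   b0 = vzipq(row0, row2) -> {a0,c0,a1,c1} {a2,c2,a3,c3}
//   b1 = vzipq(row1, row3) -> {b0,d0,b1,d1} {b2,d2,b3,d3}
//   c0 = vzipq(b0.val[0], b1.val[0]) -> {a0,b0,c0,d0} {a1,b1,c1,d1}
//   c1 = vzipq(b0.val[1], b1.val[1]) -> {a2,b2,c2,d2} {a3,b3,c3,d3}
// so the four stores write the columns, i.e. the transposed 4x4 tile.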
} // namespace arm_common | |||
} // namespace megdnn | |||
#define MATRIX_MUL4x4(sum, a, b) \ | |||
sum##0 = vmlaq_low_lane_f32(sum##0, b##0, a##0, 0); \ | |||
sum##0 = vmlaq_low_lane_f32(sum##0, b##1, a##0, 1); \ | |||
sum##0 = vmlaq_high_lane_f32(sum##0, b##2, a##0, 2); \ | |||
sum##0 = vmlaq_high_lane_f32(sum##0, b##3, a##0, 3); \ | |||
sum##1 = vmlaq_low_lane_f32(sum##1, b##0, a##1, 0); \ | |||
sum##1 = vmlaq_low_lane_f32(sum##1, b##1, a##1, 1); \ | |||
sum##1 = vmlaq_high_lane_f32(sum##1, b##2, a##1, 2); \ | |||
sum##1 = vmlaq_high_lane_f32(sum##1, b##3, a##1, 3); \ | |||
sum##2 = vmlaq_low_lane_f32(sum##2, b##0, a##2, 0); \ | |||
sum##2 = vmlaq_low_lane_f32(sum##2, b##1, a##2, 1); \ | |||
sum##2 = vmlaq_high_lane_f32(sum##2, b##2, a##2, 2); \ | |||
sum##2 = vmlaq_high_lane_f32(sum##2, b##3, a##2, 3); \ | |||
sum##3 = vmlaq_low_lane_f32(sum##3, b##0, a##3, 0); \ | |||
sum##3 = vmlaq_low_lane_f32(sum##3, b##1, a##3, 1); \ | |||
sum##3 = vmlaq_high_lane_f32(sum##3, b##2, a##3, 2); \ | |||
sum##3 = vmlaq_high_lane_f32(sum##3, b##3, a##3, 3); | |||
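// Illustrative note (not in the original source): with the rows of A in
// a0..a3 and the rows of B in b0..b3, each line broadcasts one lane of a##i
// and multiply-accumulates a full row of B, so sum##i += row_i(A) * B and
// the sixteen steps accumulate SUM += A * B for the 4x4 tile.
// vmlaq_low_lane_f32 / vmlaq_high_lane_f32 are presumed to be the
// lane-indexed fma helpers from the included marm_neon.h.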
#define CONCAT(a, idx) a##idx | |||
#if MEGDNN_AARCH64 | |||
//! ret and a are type Vector<float, 8> | |||
#define TRANSPOSE_8x8(a, ret) \ | |||
do { \ | |||
auto b0 = vzipq_f32(CONCAT(a, 0).value.val[0], \ | |||
CONCAT(a, 1).value.val[0]); \ | |||
auto b1 = vzipq_f32(CONCAT(a, 0).value.val[1], \ | |||
CONCAT(a, 1).value.val[1]); \ | |||
auto b2 = vzipq_f32(CONCAT(a, 2).value.val[0], \ | |||
CONCAT(a, 3).value.val[0]); \ | |||
auto b3 = vzipq_f32(CONCAT(a, 2).value.val[1], \ | |||
CONCAT(a, 3).value.val[1]); \ | |||
auto b4 = vzipq_f32(CONCAT(a, 4).value.val[0], \ | |||
CONCAT(a, 5).value.val[0]); \ | |||
auto b5 = vzipq_f32(CONCAT(a, 4).value.val[1], \ | |||
CONCAT(a, 5).value.val[1]); \ | |||
auto b6 = vzipq_f32(CONCAT(a, 6).value.val[0], \ | |||
CONCAT(a, 7).value.val[0]); \ | |||
auto b7 = vzipq_f32(CONCAT(a, 6).value.val[1], \ | |||
CONCAT(a, 7).value.val[1]); \ | |||
CONCAT(ret, 0).value.val[0] = vreinterpretq_f32_s64( \ | |||
vzip1q_s64(vreinterpretq_s64_f32(b0.val[0]), \ | |||
vreinterpretq_s64_f32(b2.val[0]))); \ | |||
CONCAT(ret, 0).value.val[1] = vreinterpretq_f32_s64( \ | |||
vzip1q_s64(vreinterpretq_s64_f32(b4.val[0]), \ | |||
vreinterpretq_s64_f32(b6.val[0]))); \ | |||
CONCAT(ret, 1).value.val[0] = vreinterpretq_f32_s64( \ | |||
vzip2q_s64(vreinterpretq_s64_f32(b0.val[0]), \ | |||
vreinterpretq_s64_f32(b2.val[0]))); \ | |||
CONCAT(ret, 1).value.val[1] = vreinterpretq_f32_s64( \ | |||
vzip2q_s64(vreinterpretq_s64_f32(b4.val[0]), \ | |||
vreinterpretq_s64_f32(b6.val[0]))); \ | |||
CONCAT(ret, 2).value.val[0] = vreinterpretq_f32_s64( \ | |||
vzip1q_s64(vreinterpretq_s64_f32(b0.val[1]), \ | |||
vreinterpretq_s64_f32(b2.val[1]))); \ | |||
CONCAT(ret, 2).value.val[1] = vreinterpretq_f32_s64( \ | |||
vzip1q_s64(vreinterpretq_s64_f32(b4.val[1]), \ | |||
vreinterpretq_s64_f32(b6.val[1]))); \ | |||
CONCAT(ret, 3).value.val[0] = vreinterpretq_f32_s64( \ | |||
vzip2q_s64(vreinterpretq_s64_f32(b0.val[1]), \ | |||
vreinterpretq_s64_f32(b2.val[1]))); \ | |||
CONCAT(ret, 3).value.val[1] = vreinterpretq_f32_s64( \ | |||
vzip2q_s64(vreinterpretq_s64_f32(b4.val[1]), \ | |||
vreinterpretq_s64_f32(b6.val[1]))); \ | |||
CONCAT(ret, 4).value.val[0] = vreinterpretq_f32_s64( \ | |||
vzip1q_s64(vreinterpretq_s64_f32(b1.val[0]), \ | |||
vreinterpretq_s64_f32(b3.val[0]))); \ | |||
CONCAT(ret, 4).value.val[1] = vreinterpretq_f32_s64( \ | |||
vzip1q_s64(vreinterpretq_s64_f32(b5.val[0]), \ | |||
vreinterpretq_s64_f32(b7.val[0]))); \ | |||
CONCAT(ret, 5).value.val[0] = vreinterpretq_f32_s64( \ | |||
vzip2q_s64(vreinterpretq_s64_f32(b1.val[0]), \ | |||
vreinterpretq_s64_f32(b3.val[0]))); \ | |||
CONCAT(ret, 5).value.val[1] = vreinterpretq_f32_s64( \ | |||
vzip2q_s64(vreinterpretq_s64_f32(b5.val[0]), \ | |||
vreinterpretq_s64_f32(b7.val[0]))); \ | |||
CONCAT(ret, 6).value.val[0] = vreinterpretq_f32_s64( \ | |||
vzip1q_s64(vreinterpretq_s64_f32(b1.val[1]), \ | |||
vreinterpretq_s64_f32(b3.val[1]))); \ | |||
CONCAT(ret, 6).value.val[1] = vreinterpretq_f32_s64( \ | |||
vzip1q_s64(vreinterpretq_s64_f32(b5.val[1]), \ | |||
vreinterpretq_s64_f32(b7.val[1]))); \ | |||
CONCAT(ret, 7).value.val[0] = vreinterpretq_f32_s64( \ | |||
vzip2q_s64(vreinterpretq_s64_f32(b1.val[1]), \ | |||
vreinterpretq_s64_f32(b3.val[1]))); \ | |||
CONCAT(ret, 7).value.val[1] = vreinterpretq_f32_s64( \ | |||
vzip2q_s64(vreinterpretq_s64_f32(b5.val[1]), \ | |||
vreinterpretq_s64_f32(b7.val[1]))); \ | |||
} while (0); | |||
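// Illustrative note (not in the original source): the 8x8 transpose above
// runs two butterfly stages -- vzipq_f32 interleaves 32-bit lanes, then
// vzip1q_s64/vzip2q_s64 pair the same bits reinterpreted as 64-bit lanes --
// after which ret##i holds column i of the eight Vector<float, 8> rows.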
#define TRANSPOSE_8x3(a, ret) \ | |||
auto b0 = vzipq_f32(CONCAT(a, 0).value, CONCAT(a, 1).value); \ | |||
auto b1 = vzipq_f32(CONCAT(a, 2).value, CONCAT(a, 3).value); \ | |||
auto b2 = vzipq_f32(CONCAT(a, 4).value, CONCAT(a, 5).value); \ | |||
auto b3 = vzipq_f32(CONCAT(a, 6).value, CONCAT(a, 7).value); \ | |||
CONCAT(ret, 0).value.val[0] = vreinterpretq_f32_s64( \ | |||
vzip1q_s64(vreinterpretq_s64_f32(b0.val[0]), \ | |||
vreinterpretq_s64_f32(b1.val[0]))); \ | |||
CONCAT(ret, 0).value.val[1] = vreinterpretq_f32_s64( \ | |||
vzip1q_s64(vreinterpretq_s64_f32(b2.val[0]), \ | |||
vreinterpretq_s64_f32(b3.val[0]))); \ | |||
CONCAT(ret, 1).value.val[0] = vreinterpretq_f32_s64( \ | |||
vzip2q_s64(vreinterpretq_s64_f32(b0.val[0]), \ | |||
vreinterpretq_s64_f32(b1.val[0]))); \ | |||
CONCAT(ret, 1).value.val[1] = vreinterpretq_f32_s64( \ | |||
vzip2q_s64(vreinterpretq_s64_f32(b2.val[0]), \ | |||
vreinterpretq_s64_f32(b3.val[0]))); \ | |||
CONCAT(ret, 2).value.val[0] = vreinterpretq_f32_s64( \ | |||
vzip1q_s64(vreinterpretq_s64_f32(b0.val[1]), \ | |||
vreinterpretq_s64_f32(b1.val[1]))); \ | |||
CONCAT(ret, 2).value.val[1] = vreinterpretq_f32_s64( \ | |||
vzip1q_s64(vreinterpretq_s64_f32(b2.val[1]), \ | |||
vreinterpretq_s64_f32(b3.val[1]))); | |||
#define TRANSPOSE_8x4(a, ret) \ | |||
auto b0 = vzipq_f32(CONCAT(a, 0).value, CONCAT(a, 1).value); \ | |||
auto b1 = vzipq_f32(CONCAT(a, 2).value, CONCAT(a, 3).value); \ | |||
auto b2 = vzipq_f32(CONCAT(a, 4).value, CONCAT(a, 5).value); \ | |||
auto b3 = vzipq_f32(CONCAT(a, 6).value, CONCAT(a, 7).value); \ | |||
CONCAT(ret, 0).value.val[0] = vreinterpretq_f32_s64( \ | |||
vzip1q_s64(vreinterpretq_s64_f32(b0.val[0]), \ | |||
vreinterpretq_s64_f32(b1.val[0]))); \ | |||
CONCAT(ret, 0).value.val[1] = vreinterpretq_f32_s64( \ | |||
vzip1q_s64(vreinterpretq_s64_f32(b2.val[0]), \ | |||
vreinterpretq_s64_f32(b3.val[0]))); \ | |||
CONCAT(ret, 1).value.val[0] = vreinterpretq_f32_s64( \ | |||
vzip2q_s64(vreinterpretq_s64_f32(b0.val[0]), \ | |||
vreinterpretq_s64_f32(b1.val[0]))); \ | |||
CONCAT(ret, 1).value.val[1] = vreinterpretq_f32_s64( \ | |||
vzip2q_s64(vreinterpretq_s64_f32(b2.val[0]), \ | |||
vreinterpretq_s64_f32(b3.val[0]))); \ | |||
CONCAT(ret, 2).value.val[0] = vreinterpretq_f32_s64( \ | |||
vzip1q_s64(vreinterpretq_s64_f32(b0.val[1]), \ | |||
vreinterpretq_s64_f32(b1.val[1]))); \ | |||
CONCAT(ret, 2).value.val[1] = vreinterpretq_f32_s64( \ | |||
vzip1q_s64(vreinterpretq_s64_f32(b2.val[1]), \ | |||
vreinterpretq_s64_f32(b3.val[1]))); \ | |||
CONCAT(ret, 3).value.val[0] = vreinterpretq_f32_s64( \ | |||
vzip2q_s64(vreinterpretq_s64_f32(b0.val[1]), \ | |||
vreinterpretq_s64_f32(b1.val[1]))); \ | |||
CONCAT(ret, 3).value.val[1] = vreinterpretq_f32_s64( \ | |||
vzip2q_s64(vreinterpretq_s64_f32(b2.val[1]), \ | |||
vreinterpretq_s64_f32(b3.val[1]))); | |||
#elif MEGDNN_ARMV7 | |||
#define TRANSPOSE_8x4(a, ret) \ | |||
auto b0 = vzipq_f32(CONCAT(a, 0).value, CONCAT(a, 1).value); \ | |||
auto b1 = vzipq_f32(CONCAT(a, 2).value, CONCAT(a, 3).value); \ | |||
auto b2 = vzipq_f32(CONCAT(a, 4).value, CONCAT(a, 5).value); \ | |||
auto b3 = vzipq_f32(CONCAT(a, 6).value, CONCAT(a, 7).value); \ | |||
CONCAT(ret, 0).value.val[0] = \ | |||
vcombine_f32(vget_low_f32(b0.val[0]), vget_low_f32(b1.val[0])); \ | |||
CONCAT(ret, 1).value.val[0] = \ | |||
vcombine_f32(vget_high_f32(b0.val[0]), vget_high_f32(b1.val[0])); \ | |||
CONCAT(ret, 2).value.val[0] = \ | |||
vcombine_f32(vget_low_f32(b0.val[1]), vget_low_f32(b1.val[1])); \ | |||
CONCAT(ret, 3).value.val[0] = \ | |||
vcombine_f32(vget_high_f32(b0.val[1]), vget_high_f32(b1.val[1])); \ | |||
CONCAT(ret, 0).value.val[1] = \ | |||
vcombine_f32(vget_low_f32(b2.val[0]), vget_low_f32(b3.val[0])); \ | |||
CONCAT(ret, 1).value.val[1] = \ | |||
vcombine_f32(vget_high_f32(b2.val[0]), vget_high_f32(b3.val[0])); \ | |||
CONCAT(ret, 2).value.val[1] = \ | |||
vcombine_f32(vget_low_f32(b2.val[1]), vget_low_f32(b3.val[1])); \ | |||
CONCAT(ret, 3).value.val[1] = \ | |||
vcombine_f32(vget_high_f32(b2.val[1]), vget_high_f32(b3.val[1])); | |||
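// Illustrative note (not in the original source): the ARMv7 variant reaches
// the same layout with vget_low/high + vcombine, since the 64-bit zip
// intrinsics (vzip1q_s64/vzip2q_s64) used above are A64-only.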
#endif | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,39 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/fp32/strategy.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "src/arm_common/conv_bias/postprocess_helper.h" | |||
#include "src/fallback/conv_bias/winograd/winograd.h" | |||
namespace megdnn { | |||
namespace arm_common { | |||
namespace winograd { | |||
MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 2, 3, 4, 4, | |||
winograd_2x3_4x4_f) | |||
MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 6, 3, 1, 1, | |||
winograd_6x3_1x1_f) | |||
MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 6, 3, 4, 4, | |||
winograd_6x3_4x4_f) | |||
MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 5, 4, 1, 1, | |||
winograd_5x4_1x1_f) | |||
MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 4, 5, 1, 1, | |||
winograd_4x5_1x1_f) | |||
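//! Illustrative gloss (not in the original header): reading the
//! registration names, winograd_MxR_AxB_f is the float strategy for
//! F(MxM, RxR) with an AxB channel packing -- e.g. winograd_6x3_4x4_f
//! produces 6x6 output blocks from 3x3 filters on 4x4-packed data, while
//! the _1x1_ variants presumably work on unpacked (DEFAULT format) data.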
} // namespace winograd | |||
} // namespace arm_common | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,346 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/fp32/strategy_2x3_4x4.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/strategy.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/arm_common/utils.h" | |||
#include "src/common/unroll_macro.h" | |||
#include "src/common/utils.h" | |||
#include "src/fallback/conv_bias/winograd/winograd.h" | |||
#include "src/naive/matrix_mul/matrix_mul_helper.h" | |||
#include "src/arm_common/elemwise_helper/op_unary.h" | |||
#include "src/arm_common/conv_bias/fp32/helper.h" | |||
#include "midout.h" | |||
MIDOUT_DECL(megdnn_arm_common_winograd_fp32_F23) | |||
using namespace megdnn; | |||
using namespace arm_common; | |||
namespace { | |||
struct InputTransform2X3 { | |||
template <bool inner> | |||
static void prepare(const float* input, float* patch, float* patchT, | |||
int ih_start, int iw_start, size_t IH, size_t IW, | |||
size_t ic, size_t IC) { | |||
constexpr size_t alpha = 2 + 3 - 1; | |||
if (!(inner && ic + 4 < IC)) { | |||
memset(patch, 0, sizeof(float) * 4 * alpha * alpha); | |||
} | |||
if (inner) { | |||
const float* input_ptr = | |||
input + ic * IH * IW + ih_start * IW + iw_start; | |||
for (size_t ico = 0; ico < 4; ++ico) { | |||
if (ic + ico < IC) { | |||
auto v0 = vld1q_f32(input_ptr); | |||
auto v1 = vld1q_f32(input_ptr + IW); | |||
auto v2 = vld1q_f32(input_ptr + IW * 2); | |||
auto v3 = vld1q_f32(input_ptr + IW * 3); | |||
vst1q_f32(patch + ico * 4 * alpha + 0 * 4, v0); | |||
vst1q_f32(patch + ico * 4 * alpha + 1 * 4, v1); | |||
vst1q_f32(patch + ico * 4 * alpha + 2 * 4, v2); | |||
vst1q_f32(patch + ico * 4 * alpha + 3 * 4, v3); | |||
input_ptr += IH * IW; | |||
} | |||
} | |||
} else { | |||
int ih0_act = std::max<int>(ih_start, 0), | |||
ih1_act = std::min<int>(ih_start + alpha, IH), | |||
iw0_act = std::max<int>(iw_start, 0), | |||
iw1_act = std::min<int>(iw_start + alpha, IW); | |||
// partial copy | |||
for (size_t ico = 0; ico < 4; ++ico) { | |||
if (ic + ico < IC) { | |||
for (int ih = ih0_act; ih < ih1_act; ++ih) { | |||
for (int iw = iw0_act; iw < iw1_act; ++iw) { | |||
size_t iho = ih - ih_start, iwo = iw - iw_start; | |||
patch[ico * alpha * 4 + iho * 4 + iwo] = | |||
input[(ic + ico) * IH * IW + ih * IW + iw]; | |||
} | |||
} | |||
} | |||
} | |||
} | |||
transpose_4x4(patch + 0 * 1, patchT + 0 * 4, 16, 4); | |||
transpose_4x4(patch + 4 * 1, patchT + 4 * 4, 16, 4); | |||
transpose_4x4(patch + 8 * 1, patchT + 8 * 4, 16, 4); | |||
transpose_4x4(patch + 12 * 1, patchT + 12 * 4, 16, 4); | |||
} | |||
static void transform(const float* patchT, float* input_transform_buf, | |||
size_t unit_idx, size_t nr_units_in_tile, size_t ic, | |||
size_t IC) { | |||
constexpr size_t alpha = 2 + 3 - 1; | |||
// BT * d * B | |||
#define cb(m, n) \ | |||
Vector<float, 4> d##m##n = \ | |||
Vector<float, 4>::load(patchT + m * 4 * 4 + n * 4); | |||
UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); | |||
#undef cb | |||
//! 1 0 -1 0 d00 d01 d02 d03 1 0 0 0 | |||
//! 0 1 1 0 d10 d11 d12 d13 0 1 -1 -1 | |||
//! 0 -1 1 0 d20 d21 d22 d23 -1 1 1 0 | |||
//! 0 -1 0 1 d30 d31 d32 d33 0 0 0 1 | |||
#define cb(m) \ | |||
auto t0##m = d0##m - d2##m; \ | |||
auto t1##m = d1##m + d2##m; \ | |||
auto t2##m = d2##m - d1##m; \ | |||
auto t3##m = d3##m - d1##m; | |||
UNROLL_CALL_NOWRAPPER(4, cb); | |||
#undef cb | |||
#define cb(m) \ | |||
d##m##0 = t##m##0 - t##m##2; \ | |||
d##m##1 = t##m##1 + t##m##2; \ | |||
d##m##2 = t##m##2 - t##m##1; \ | |||
d##m##3 = t##m##3 - t##m##1; | |||
UNROLL_CALL_NOWRAPPER(4, cb); | |||
#undef cb | |||
size_t ICB = IC / 4; | |||
size_t icb = ic / 4; | |||
#define cb(m, n) \ | |||
d##m##n.save(input_transform_buf + \ | |||
(m * alpha + n) * ICB * nr_units_in_tile * 4 + \ | |||
icb * nr_units_in_tile * 4 + unit_idx * 4); | |||
UNROLL_CALL_NOWRAPPER_D2(4, 4, cb) | |||
#undef cb | |||
} | |||
}; | |||
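// Illustrative recap (not in the original source) of F(2x2, 3x3), which the
// transforms in this file implement:
//   Y = AT * [ (G * g * GT) .* (BT * d * B) ] * A
// where d is a 4x4 input tile, g the 3x3 filter and .* the elementwise
// product (realized as a batched matmul across tiles); InputTransform2X3
// above computes BT * d * B for four channels at once, and
// OutputTransform2X3 below applies the AT/A pass plus bias and activation.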
template <BiasMode bmode, typename Op> | |||
struct OutputTransform2X3 { | |||
static void transform(const float* output_transform_buf, const float* bias, | |||
float* output, float* transform_mid_buf, | |||
size_t oh_start, size_t ow_start, size_t OH, | |||
size_t OW, size_t oc_start, size_t oc_end, | |||
size_t oc_index, size_t unit_idx, | |||
size_t nr_units_in_tile, const DType& src_dtype, | |||
const DType& dst_dtype) { | |||
Op op(src_dtype, dst_dtype); | |||
//! AT * m * A | |||
constexpr size_t alpha = 2 + 3 - 1; | |||
size_t oc = oc_start + oc_index; | |||
size_t OCB = (oc_end - oc_start) / 4; | |||
size_t ocb = oc_index / 4; | |||
#define cb(m, n) \ | |||
auto v##m##n = Vector<float, 4>::load( \ | |||
output_transform_buf + \ | |||
(m * alpha + n) * OCB * nr_units_in_tile * 4 + \ | |||
ocb * nr_units_in_tile * 4 + unit_idx * 4); | |||
UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); | |||
#undef cb | |||
//! 1 1 1 0 v00 v01 v02 v03 1 0 | |||
//! 0 1 -1 1 v10 v11 v12 v13 1 1 | |||
//! v20 v21 v22 v23 1 -1 | |||
//! v30 v31 v32 v33 0 1 | |||
#define cb(m) \ | |||
auto t0##m = v0##m + v1##m + v2##m; \ | |||
auto t1##m = v1##m - v2##m + v3##m; | |||
UNROLL_CALL_NOWRAPPER(4, cb); | |||
#undef cb | |||
v00 = t00 + t01 + t02; | |||
v10 = t10 + t11 + t12; | |||
v01 = t01 - t02 + t03; | |||
v11 = t11 - t12 + t13; | |||
Vector<float, 4> vbias; | |||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
vbias = Vector<float, 4>::load(bias + oc); | |||
v00 += vbias; | |||
v10 += vbias; | |||
v01 += vbias; | |||
v11 += vbias; | |||
} | |||
if (bmode != BiasMode::BIAS) { | |||
v00 = op(v00.value); | |||
v01 = op(v01.value); | |||
v10 = op(v10.value); | |||
v11 = op(v11.value); | |||
} | |||
v00.save(transform_mid_buf + (0 * 2 + 0) * 4); | |||
v10.save(transform_mid_buf + (1 * 2 + 0) * 4); | |||
v01.save(transform_mid_buf + (0 * 2 + 1) * 4); | |||
v11.save(transform_mid_buf + (1 * 2 + 1) * 4); | |||
for (size_t oco = 0; oco < 4 && oc + oco < oc_end; ++oco) { | |||
for (size_t oho = 0; oho < 2 && oh_start + oho < OH; ++oho) { | |||
for (size_t owo = 0; owo < 2 && ow_start + owo < OW; ++owo) { | |||
size_t oh = oh_start + oho; | |||
size_t ow = ow_start + owo; | |||
float res = transform_mid_buf[oho * 2 * 4 + owo * 4 + oco]; | |||
if (bmode == BiasMode::BIAS) { | |||
res += bias[(oc + oco) * OH * OW + oh * OW + ow]; | |||
res = op(res); | |||
} | |||
output[(oc + oco) * OH * OW + oh * OW + ow] = res; | |||
} | |||
} | |||
} | |||
} | |||
}; | |||
} // namespace | |||
namespace megdnn { | |||
namespace arm_common { | |||
namespace winograd { | |||
MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_2x3_4x4_f) | |||
void winograd_2x3_4x4_f::filter(const float* filter, | |||
float* filter_transform_buf, | |||
float* transform_mid_buf, size_t OC, size_t IC, | |||
size_t oc_start, size_t oc_end) { | |||
constexpr int alpha = 2 + 3 - 1; | |||
//! G * g * GT | |||
float32x4_t g0{1.f, 0, 0, 0}, g1{0.5, 0.5, 0.5, 0}, g2{0.5, -0.5, 0.5, 0}, | |||
g3{0, 0, 1, 0}; | |||
float32x4_t gt0{1, 0.5, 0.5, 0}, gt1{0, 0.5, -0.5, 0}, gt2{0, 0.5, 0.5, 1}, | |||
gt3{0, 0, 0, 0}; | |||
size_t OCB = OC / 4; | |||
size_t ICB = IC / 4; | |||
for (size_t oc = oc_start; oc < oc_end; oc++) | |||
rep(ic, IC) { | |||
const float* filter_ptr = filter + (oc * IC + ic) * 3 * 3; | |||
/**
* origin: (4x3) * (3x3) * (3x4)
* pad G and g with zeros up to multiples of 4,
* so the product becomes: (4x4) * (4x4) * (4x4)
*/
//! 1 0 0 0 v00 v01 v02 0 1 0.5 0.5 0 | |||
//! 0.5 0.5 0.5 0 v10 v11 v12 0 0 0.5 -0.5 0 | |||
//! 0.5 -0.5 0.5 0 v20 v21 v22 0 0 0.5 0.5 1 | |||
//! 0 0 1 0 0 0 0 0 0 0 0 0 | |||
float32x4_t vf0 = vld1q_f32(filter_ptr); | |||
float32x4_t vf1 = vld1q_f32(filter_ptr + 4); | |||
float32x4_t vf2 = vdupq_n_f32(filter_ptr[8]); | |||
float32x4_t v3(vdupq_n_f32(0)); | |||
auto vtmp = vextq_f32(vf1, vf2, 2); | |||
vtmp = vsetq_lane_f32(0, vtmp, 3); | |||
float32x4_t v2(vtmp); | |||
vtmp = vextq_f32(vf0, vf1, 3); | |||
vtmp = vsetq_lane_f32(0, vtmp, 3); | |||
float32x4_t v1(vtmp); | |||
vtmp = vsetq_lane_f32(0, vf0, 3); | |||
float32x4_t v0(vtmp); | |||
float32x4_t vsum0 = vdupq_n_f32(0), vsum1 = vdupq_n_f32(0), | |||
vsum2 = vdupq_n_f32(0), vsum3 = vdupq_n_f32(0); | |||
MATRIX_MUL4x4(vsum, g, v); | |||
float32x4_t vres0 = vdupq_n_f32(0), vres1 = vdupq_n_f32(0), | |||
vres2 = vdupq_n_f32(0), vres3 = vdupq_n_f32(0); | |||
MATRIX_MUL4x4(vres, vsum, gt); | |||
vst1q_f32(transform_mid_buf, vres0); | |||
vst1q_f32(transform_mid_buf + 4, vres1); | |||
vst1q_f32(transform_mid_buf + 8, vres2); | |||
vst1q_f32(transform_mid_buf + 12, vres3); | |||
size_t ocb = oc / 4; | |||
size_t oc4 = oc % 4; | |||
size_t icb = ic / 4; | |||
size_t ic4 = ic % 4; | |||
rep(i, alpha) rep(j, alpha) { | |||
filter_transform_buf[(i * alpha + j) * OCB * ICB * 4 * 4 + | |||
ocb * ICB * 4 * 4 + icb * 4 * 4 + ic4 * 4 + | |||
oc4] = transform_mid_buf[i * alpha + j]; | |||
} | |||
} | |||
} | |||
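// Layout note (illustrative, not in the original source): the scatter at
// the end of filter() stores each transformed alpha x alpha (4x4) tile in
// blocks of 4 output x 4 input channels (ocb/icb select the block,
// ic4/oc4 the element inside it), presumably matching the packed format
// the 4x4 strategy's batched matmul consumes.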
void winograd_2x3_4x4_f::input(const float* input, float* input_transform_buf, | |||
float* transform_mid_buf, size_t IH, size_t IW, | |||
size_t IC, size_t PH, size_t PW, | |||
size_t unit_start_idx, size_t nr_units_in_tile) { | |||
megdnn_assert(IC % 4 == 0); | |||
constexpr int alpha = 2 + 3 - 1;
// OW = IW + 2 * PW - KERNEL_SIZE + 1 | |||
auto units_w = div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); | |||
float* patch = transform_mid_buf; | |||
float* patchT = transform_mid_buf + 4 * alpha * alpha; | |||
for (size_t ic = 0; ic < IC; ic += 4) { | |||
rep(unit_idx, nr_units_in_tile) { | |||
size_t index = unit_start_idx + unit_idx; | |||
size_t nh = index / units_w; | |||
size_t nw = index % units_w; | |||
int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; | |||
int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; | |||
if (ih_start >= 0 && ih_start + alpha <= static_cast<int>(IH) && | |||
iw_start >= 0 && iw_start + alpha <= static_cast<int>(IW)) { | |||
InputTransform2X3::prepare<true>(input, patch, patchT, ih_start, | |||
iw_start, IH, IW, ic, IC); | |||
InputTransform2X3::transform(patchT, input_transform_buf, | |||
unit_idx, nr_units_in_tile, ic, | |||
IC); | |||
} else { | |||
InputTransform2X3::prepare<false>(input, patch, patchT, | |||
ih_start, iw_start, IH, IW, | |||
ic, IC); | |||
InputTransform2X3::transform(patchT, input_transform_buf, | |||
unit_idx, nr_units_in_tile, ic, | |||
IC); | |||
} | |||
} | |||
} | |||
} | |||
void winograd_2x3_4x4_f::output(const float* output_transform_buf, | |||
const float* bias, float* output, | |||
float* transform_mid_buf, BiasMode bmode, | |||
NonlineMode nonline_mode, size_t OH, size_t OW, | |||
size_t oc_start, size_t oc_end, size_t unit_start_idx, | |||
size_t nr_units_in_tile) { | |||
#define cb(_bmode, _nonline_op, ...) \ | |||
OutputTransform2X3<_bmode MEGDNN_COMMA _nonline_op>::transform(__VA_ARGS__); | |||
auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE); | |||
for (size_t oc = oc_start; oc < oc_end; oc += 4) { | |||
size_t oc_index = oc - oc_start; | |||
rep(unit_idx, nr_units_in_tile) { | |||
size_t index = unit_start_idx + unit_idx; | |||
auto nh = index / units_w; | |||
auto nw = index % units_w; | |||
size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | |||
size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | |||
DISPATCH_CONV_WINOGRAD_BIAS( | |||
megdnn_arm_common_winograd_fp32_F23, cb, float, float, bmode, | |||
nonline_mode, output_transform_buf, bias, output, transform_mid_buf, | |||
oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx, | |||
nr_units_in_tile, src_dtype, dst_dtype); | |||
} | |||
} | |||
#undef cb | |||
} | |||
} // namespace winograd | |||
} // namespace arm_common | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,483 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/fp32/strategy_4x5.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/strategy.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/arm_common/utils.h" | |||
#include "src/common/unroll_macro.h" | |||
#include "src/common/utils.h" | |||
#include "src/fallback/conv_bias/winograd/winograd.h" | |||
#include "src/arm_common/conv_bias/fp32/helper.h" | |||
#include "src/arm_common/elemwise_helper/op_unary.h" | |||
#include "src/naive/matrix_mul/matrix_mul_helper.h" | |||
#include "midout.h" | |||
MIDOUT_DECL(megdnn_arm_common_winograd_fp32_F45) | |||
using namespace megdnn; | |||
using namespace arm_common; | |||
namespace { | |||
struct FilterTransform4X5 { | |||
#define FILTER_TRANSFORM(d, wd) \ | |||
do { \ | |||
wd##0 = d##0; \ | |||
wd##r0 = d##r0; \ | |||
wd##1 = (d##0 + d##1 + d##2 + d##3 + d##4) * -0.2222222; \ | |||
wd##r1 = (d##r0 + d##r1 + d##r2 + d##r3 + d##r4) * -0.2222222; \ | |||
wd##2 = (d##0 - d##1 + d##2 - d##3 + d##4) * -0.2222222; \ | |||
wd##r2 = (d##r0 - d##r1 + d##r2 - d##r3 + d##r4) * -0.2222222; \ | |||
auto tmpd0 = d##0 * 0.7111111; \ | |||
auto tmpd1 = d##1 * 0.3555556; \ | |||
auto tmpd2 = d##2 * 0.1777778; \ | |||
auto tmpd3 = d##3 * 0.0888889; \ | |||
auto tmpd4 = d##4 * 0.0444444; \ | |||
auto tmpdr0 = d##r0 * 0.7111111; \ | |||
auto tmpdr1 = d##r1 * 0.3555556; \ | |||
auto tmpdr2 = d##r2 * 0.1777778; \ | |||
auto tmpdr3 = d##r3 * 0.0888889; \ | |||
auto tmpdr4 = d##r4 * 0.0444444; \ | |||
wd##3 = tmpd0 + tmpd1 + tmpd2 + tmpd3 + tmpd4; \ | |||
wd##r3 = tmpdr0 + tmpdr1 + tmpdr2 + tmpdr3 + tmpdr4; \ | |||
wd##4 = tmpd0 - tmpd1 + tmpd2 - tmpd3 + tmpd4; \ | |||
wd##r4 = tmpdr0 - tmpdr1 + tmpdr2 - tmpdr3 + tmpdr4; \ | |||
tmpd0 = d##0 * 0.0111111; \ | |||
tmpd1 = d##1 * 0.0222222; \ | |||
tmpd2 = d##2 * 0.0444444; \ | |||
tmpd3 = d##3 * 0.0888889; \ | |||
tmpd4 = d##4 * 0.1777778; \ | |||
tmpdr0 = d##r0 * 0.0111111; \ | |||
tmpdr1 = d##r1 * 0.0222222; \ | |||
tmpdr2 = d##r2 * 0.0444444; \ | |||
tmpdr3 = d##r3 * 0.0888889; \ | |||
tmpdr4 = d##r4 * 0.1777778; \ | |||
wd##5 = tmpd0 + tmpd1 + tmpd2 + tmpd3 + tmpd4; \ | |||
wd##r5 = tmpdr0 + tmpdr1 + tmpdr2 + tmpdr3 + tmpdr4; \ | |||
wd##6 = tmpd0 - tmpd1 + tmpd2 - tmpd3 + tmpd4; \ | |||
wd##r6 = tmpdr0 - tmpdr1 + tmpdr2 - tmpdr3 + tmpdr4; \ | |||
wd##7 = d##4; \ | |||
wd##r7 = d##r4; \ | |||
} while (0); | |||
#define FILTER_TRANSFORM_FINAL(d, wd) \ | |||
do { \ | |||
wd##0 = d##0; \ | |||
wd##1 = (d##0 + d##1 + d##2 + d##3 + d##4) * -0.2222222; \ | |||
wd##2 = (d##0 - d##1 + d##2 - d##3 + d##4) * -0.2222222; \ | |||
auto tmp0 = d##0 * 0.7111111 + d##2 * 0.1777778 + d##4 * 0.0444444; \ | |||
auto tmp1 = d##1 * 0.3555556 + d##3 * 0.0888889; \ | |||
wd##3 = tmp0 + tmp1; \ | |||
wd##4 = tmp0 - tmp1; \ | |||
tmp0 = d##0 * 0.0111111 + d##2 * 0.0444444 + d##4 * 0.1777778; \ | |||
tmp1 = d##1 * 0.0222222 + d##3 * 0.0888889; \ | |||
wd##5 = tmp0 + tmp1; \ | |||
wd##6 = tmp0 - tmp1; \ | |||
wd##7 = d##4; \ | |||
} while (0); | |||
static void transform(const float* filter, float* filter_transform_buf, | |||
float* transform_mid_buf, size_t OC, size_t IC, | |||
size_t oc_start, size_t oc_end) { | |||
// Gg * GT | |||
// G | |||
//[[ 1. 0. 0. 0. 0. ] | |||
// [-0.2222222 -0.2222222 -0.2222222 -0.2222222 -0.2222222] | |||
// [-0.2222222 0.2222222 -0.2222222 0.2222222 -0.2222222] | |||
// [ 0.7111111 0.3555556 0.1777778 0.0888889 0.0444444] | |||
// [ 0.7111111 -0.3555556 0.1777778 -0.0888889 0.0444444] | |||
// [ 0.0111111 0.0222222 0.0444444 0.0888889 0.1777778] | |||
// [ 0.0111111 -0.0222222 0.0444444 -0.0888889 0.1777778] | |||
// [ 0. 0. 0. 0. 1. ]] | |||
constexpr size_t alpha = 4 + 5 - 1; | |||
for (size_t oc = oc_start; oc < oc_end; oc++) | |||
rep(ic, IC) { | |||
const float* fptr = filter + (oc * IC + ic) * 5 * 5; | |||
#define cb(i) Vector<float, 4> g##i = Vector<float, 4>::load(fptr + 5 * i); | |||
UNROLL_CALL_NOWRAPPER(5, cb); | |||
#undef cb | |||
#define cb(i) float gr##i = *(fptr + 5 * i + 4); | |||
UNROLL_CALL_NOWRAPPER(5, cb); | |||
#undef cb | |||
#define cb(i) Vector<float, 4> Gg##i; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
#define cb(i) float Ggr##i; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
#define cb(i) Vector<float, 8> Ggt##i; | |||
UNROLL_CALL_NOWRAPPER(4, cb); | |||
#undef cb | |||
#define cb(i) Vector<float, 8> result##i; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
FILTER_TRANSFORM(g, Gg) | |||
float32x4x2_t vgr; | |||
float32x4_t vgr0 = {Ggr0, Ggr1, Ggr2, Ggr3}; | |||
float32x4_t vgr1 = {Ggr4, Ggr5, Ggr6, Ggr7}; | |||
vgr.val[0] = vgr0; //{Ggr0, Ggr1, Ggr2, Ggr3}; | |||
vgr.val[1] = vgr1; //{Ggr4, Ggr5, Ggr6, Ggr7}; | |||
Vector<float, 8> Ggt4(vgr); | |||
TRANSPOSE_8x4(Gg, Ggt); | |||
FILTER_TRANSFORM_FINAL(Ggt, result); | |||
#define cb(i) result##i.save(transform_mid_buf + i * alpha); | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
rep(i, alpha) rep(j, alpha) { | |||
filter_transform_buf[(i * alpha + j) * OC * IC + ic * OC + | |||
oc] = transform_mid_buf[j * alpha + i]; | |||
} | |||
} | |||
} | |||
}; | |||
#undef FILTER_TRANSFORM | |||
#undef FILTER_TRANSFORM_FINAL | |||
struct InputTransform4X5 { | |||
#define INPUT_TRANSFORM(d, wd) \ | |||
do { \ | |||
wd##0 = (d##0 - d##6) + (d##4 - d##2) * 5.25f; \ | |||
auto tmp0 = d##2 - d##4 * 4.25f + d##6; \ | |||
auto tmp1 = d##1 - d##3 * 4.25f + d##5; \ | |||
wd##1 = tmp0 + tmp1; \ | |||
wd##2 = tmp0 - tmp1; \ | |||
tmp0 = d##2 * 4.0f - d##4 * 5.0f + d##6; \ | |||
tmp1 = d##1 * 2.0f - d##3 * 2.5f + d##5 * 0.5f; \ | |||
wd##3 = tmp0 + tmp1; \ | |||
wd##4 = tmp0 - tmp1; \ | |||
tmp0 = d##2 * 0.25f - d##4 * 1.25f + d##6; \ | |||
tmp1 = d##1 * 0.5f - d##3 * 2.5f + d##5 * 2.0f; \ | |||
wd##5 = tmp0 + tmp1; \ | |||
wd##6 = tmp0 - tmp1; \ | |||
wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f; \ | |||
} while (0) | |||
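// Illustrative note (not in the original source): each wd##k above is row k
// of the 8x8 matrix BT (written out in transform() below) applied to the
// column (d0..d7); one pass over rows, a transpose, and a second pass give
// the full BT * d * B.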
#define GET_VECTOR_HIGH_ELEM(s, i, idx) \ | |||
vgetq_lane_f32(CONCAT(s, i).value.val[1], idx) | |||
#define GET_VECTOR_LOW_ELEM(s, i, idx) \ | |||
vgetq_lane_f32(CONCAT(s, i).value.val[0], idx) | |||
template <bool inner> | |||
static void transform(const float* input, float* input_transform_buf, | |||
float* transform_mid_buf, int ih_start, int iw_start, | |||
size_t ic, size_t IH, size_t IW, size_t IC, | |||
size_t unit_idx, size_t nr_units_in_tile) { | |||
// BTd * B | |||
//([[ 1. , 0. , -5.25, 0. , 5.25, 0. , -1. , 0. ], | |||
// [ 0. , 1. , 1. , -4.25, -4.25, 1. , 1. , 0. ], | |||
// [ 0. , -1. , 1. , 4.25, -4.25, -1. , 1. , 0. ], | |||
// [ 0. , 2. , 4. , -2.5 , -5. , 0.5 , 1. , 0. ], | |||
// [ 0. , -2. , 4. , 2.5 , -5. , -0.5 , 1. , 0. ], | |||
// [ 0. , 0.5 , 0.25, -2.5 , -1.25, 2. , 1. , 0. ], | |||
// [ 0. , -0.5 , 0.25, 2.5 , -1.25, -2. , 1. , 0. ], | |||
// [ 0. , -1. , 0. , 5.25, 0. , -5.25, 0. , 1. ]])) | |||
constexpr size_t alpha = 4 + 5 - 1; | |||
if (!inner) { | |||
memset(transform_mid_buf, 0, sizeof(float) * alpha * alpha); | |||
} | |||
#define cb(i) Vector<float, 8> d##i; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
if (inner) { | |||
const float* input_ptr = | |||
input + ic * IH * IW + ih_start * IW + iw_start; | |||
#define cb(i) d##i = Vector<float, 8>::load(input_ptr + IW * i); | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
} else { | |||
int ih0_act = std::max<int>(ih_start, 0), | |||
ih1_act = std::min<int>(ih_start + alpha, IH), | |||
iw0_act = std::max<int>(iw_start, 0), | |||
iw1_act = std::min<int>(iw_start + alpha, IW); | |||
for (int ih = ih0_act; ih < ih1_act; ++ih) { | |||
for (int iw = iw0_act; iw < iw1_act; ++iw) { | |||
size_t iho = ih - ih_start, iwo = iw - iw_start; | |||
transform_mid_buf[iho * alpha + iwo] = | |||
input[ic * IH * IW + ih * IW + iw]; | |||
} | |||
} | |||
#define cb(i) d##i = Vector<float, 8>::load(transform_mid_buf + alpha * i); | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
} | |||
#define cb(i) Vector<float, 8> wd##i, ret##i; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
INPUT_TRANSFORM(d, wd); | |||
#if MEGDNN_AARCH64 | |||
TRANSPOSE_8x8(wd, d); | |||
INPUT_TRANSFORM(d, ret); | |||
#define cb(i) ret##i.save(transform_mid_buf + i * alpha); | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
rep(i, alpha) rep(j, alpha) { | |||
input_transform_buf[(i * alpha + j) * nr_units_in_tile * IC + | |||
unit_idx * IC + ic] = | |||
transform_mid_buf[j * alpha + i]; | |||
} | |||
#else | |||
#define cb(i) \ | |||
do { \ | |||
mid_buf1[0] = GET_VECTOR_LOW_ELEM(wd, i, 0) - \ | |||
GET_VECTOR_HIGH_ELEM(wd, i, 2) + \ | |||
5.25 * (GET_VECTOR_HIGH_ELEM(wd, i, 0) - \ | |||
GET_VECTOR_LOW_ELEM(wd, i, 2)); \ | |||
mid_buf1[7] = GET_VECTOR_HIGH_ELEM(wd, i, 3) - \ | |||
GET_VECTOR_LOW_ELEM(wd, i, 1) + \ | |||
5.25 * (GET_VECTOR_LOW_ELEM(wd, i, 3) - \ | |||
GET_VECTOR_HIGH_ELEM(wd, i, 1)); \ | |||
auto tmp0 = 4 * GET_VECTOR_LOW_ELEM(wd, i, 2) + \ | |||
-5 * GET_VECTOR_HIGH_ELEM(wd, i, 0) + \ | |||
GET_VECTOR_HIGH_ELEM(wd, i, 2); \ | |||
auto tmp1 = 2 * GET_VECTOR_LOW_ELEM(wd, i, 1) + \ | |||
-2.5 * GET_VECTOR_LOW_ELEM(wd, i, 3) + \ | |||
0.5 * GET_VECTOR_HIGH_ELEM(wd, i, 1); \ | |||
mid_buf1[3] = tmp0 + tmp1; \ | |||
mid_buf1[4] = tmp0 - tmp1; \ | |||
tmp0 = GET_VECTOR_LOW_ELEM(wd, i, 2) + \ | |||
-4.25 * GET_VECTOR_HIGH_ELEM(wd, i, 0) + \ | |||
GET_VECTOR_HIGH_ELEM(wd, i, 2); \ | |||
tmp1 = GET_VECTOR_LOW_ELEM(wd, i, 1) + \ | |||
GET_VECTOR_LOW_ELEM(wd, i, 3) * -4.25 + \ | |||
GET_VECTOR_HIGH_ELEM(wd, i, 1); \ | |||
mid_buf1[1] = tmp0 + tmp1; \ | |||
mid_buf1[2] = tmp0 - tmp1; \ | |||
tmp0 = GET_VECTOR_LOW_ELEM(wd, i, 2) * 0.25 + \ | |||
GET_VECTOR_HIGH_ELEM(wd, i, 0) * -1.25 + \ | |||
GET_VECTOR_HIGH_ELEM(wd, i, 2); \ | |||
tmp1 = GET_VECTOR_LOW_ELEM(wd, i, 1) * 0.5 + \ | |||
GET_VECTOR_LOW_ELEM(wd, i, 3) * -2.5 + \ | |||
GET_VECTOR_HIGH_ELEM(wd, i, 1) * 2; \ | |||
mid_buf1[5] = tmp0 + tmp1; \ | |||
mid_buf1[6] = tmp0 - tmp1; \ | |||
mid_buf1 += 8; \ | |||
} while (0); | |||
float* mid_buf1 = transform_mid_buf; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
mid_buf1 = transform_mid_buf; | |||
#undef cb | |||
rep(i, alpha) rep(j, alpha) { | |||
input_transform_buf[(i * alpha + j) * nr_units_in_tile * IC + | |||
unit_idx * IC + ic] = | |||
transform_mid_buf[i * alpha + j]; | |||
} | |||
#endif | |||
} | |||
}; | |||
#undef INPUT_TRANSFORM | |||
#define OUTPUT_TRANSFORM(m, s) \ | |||
do { \ | |||
s0 = m0 + m1 + m2 + m3 + m4 + m5 + m6; \ | |||
s1 = m1 - m2 + m3 * 0.5 - m4 * 0.5 + m5 * 2.0 - m6 * 2.0; \ | |||
s2 = m1 + m2 + m3 * 0.25 + m4 * 0.25 + m5 * 4.0 + m6 * 4.0; \ | |||
s3 = m1 - m2 + m3 * 0.125 - m4 * 0.125 + m5 * 8.0 - m6 * 8.0 + m7; \ | |||
} while (0) | |||
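// Illustrative note (not in the original source): s0..s3 above are the four
// rows of the 4x8 matrix AT for F(4x4, 5x5) (written out in
// OutputTransform4X5 below) applied to the intermediate columns m0..m7.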
template <BiasMode bmode, typename Op> | |||
struct OutputTransform4X5 { | |||
static void transform(const float* output_transform_buf, const float* bias, | |||
float* output, float* transform_mid_buf, | |||
size_t oh_start, size_t ow_start, size_t OH, | |||
size_t OW, size_t oc_start, size_t oc_end, | |||
size_t oc_index, size_t unit_idx, | |||
size_t nr_units_in_tile, const DType& src_dtype, | |||
const DType& dst_dtype) { | |||
Op op(src_dtype, dst_dtype); | |||
//! AT * m * A | |||
// AT f45 | |||
// 1.0 1.0 1.0 1.000 1.000 1.0 1.0 0.0 | |||
// 0.0 1.0 -1.0 0.500 -0.500 2.0 -2.0 0.0 | |||
// 0.0 1.0 1.0 0.250 0.250 4.0 4.0 0.0 | |||
// 0.0 1.0 -1.0 0.125 -0.125 8.0 -8.0 1.0 | |||
constexpr size_t alpha = 4 + 5 - 1;
float* mid_buf1 = transform_mid_buf; | |||
size_t OC = oc_end - oc_start; | |||
size_t oc = oc_start + oc_index; | |||
#define cb(m, n) \ | |||
transform_mid_buf[m * alpha + n] = \ | |||
output_transform_buf[(m * alpha + n) * nr_units_in_tile * OC + \ | |||
unit_idx * OC + oc_index]; | |||
UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); | |||
#undef cb | |||
#define cb(i) auto m##i = Vector<float, 8>::load(transform_mid_buf + alpha * i); | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
#define cb(i) Vector<float, 8> s##i; | |||
UNROLL_CALL_NOWRAPPER(4, cb); | |||
#undef cb | |||
OUTPUT_TRANSFORM(m, s); | |||
#define cb(i) \ | |||
do { \ | |||
auto add12 = \ | |||
GET_VECTOR_LOW_ELEM(s, i, 1) + GET_VECTOR_LOW_ELEM(s, i, 2); \ | |||
auto add34 = \ | |||
GET_VECTOR_LOW_ELEM(s, i, 3) + GET_VECTOR_HIGH_ELEM(s, i, 0); \ | |||
auto add56 = \ | |||
GET_VECTOR_HIGH_ELEM(s, i, 1) + GET_VECTOR_HIGH_ELEM(s, i, 2); \ | |||
auto sub12 = \ | |||
GET_VECTOR_LOW_ELEM(s, i, 1) - GET_VECTOR_LOW_ELEM(s, i, 2); \ | |||
auto sub34 = \ | |||
GET_VECTOR_LOW_ELEM(s, i, 3) - GET_VECTOR_HIGH_ELEM(s, i, 0); \ | |||
auto sub56 = \ | |||
GET_VECTOR_HIGH_ELEM(s, i, 1) - GET_VECTOR_HIGH_ELEM(s, i, 2); \ | |||
mid_buf1[0] = GET_VECTOR_LOW_ELEM(s, i, 0) + add12 + add34 + add56; \ | |||
mid_buf1[1] = sub12 + sub34 * 0.5 + sub56 * 2.0; \ | |||
mid_buf1[2] = add12 + add34 * 0.25 + add56 * 4.0; \ | |||
mid_buf1[3] = sub12 + sub34 * 0.125 + sub56 * 8.0 + \ | |||
GET_VECTOR_HIGH_ELEM(s, i, 3); \ | |||
mid_buf1 += 4; \ | |||
} while (0); | |||
mid_buf1 = transform_mid_buf; | |||
UNROLL_CALL_NOWRAPPER(4, cb); | |||
mid_buf1 = transform_mid_buf; | |||
#undef cb | |||
if (oh_start + 4 <= OH && ow_start + 4 <= OW) { | |||
float32x4_t bias0; | |||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
bias0 = vdupq_n_f32(bias[oc]); | |||
} | |||
rep(i, 4) { | |||
size_t oh = oh_start + i; | |||
float32x4_t item0 = vld1q_f32(mid_buf1); | |||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
item0 = vaddq_f32(item0, bias0); | |||
} else if (bmode == BiasMode::BIAS) { | |||
bias0 = vld1q_f32(bias + oc * OH * OW + oh * OW + ow_start); | |||
item0 = vaddq_f32(item0, bias0); | |||
} | |||
item0 = op(item0); | |||
vst1q_f32(output + oc * OH * OW + oh * OW + ow_start, item0); | |||
mid_buf1 += 4; | |||
} | |||
} else { | |||
for (size_t oho = 0; oho < 4 && oh_start + oho < OH; ++oho) { | |||
for (size_t owo = 0; owo < 4 && ow_start + owo < OW; ++owo) { | |||
size_t oh = oh_start + oho; | |||
size_t ow = ow_start + owo; | |||
float res = mid_buf1[oho * 4 + owo]; | |||
if (bmode == BiasMode::BIAS) { | |||
res += bias[oc * OH * OW + oh * OW + ow]; | |||
} else if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
res += bias[oc]; | |||
} | |||
res = op(res); | |||
output[oc * OH * OW + oh * OW + ow] = res; | |||
} | |||
} | |||
} | |||
} | |||
}; | |||
#undef OUTPUT_TRANSFORM | |||
#undef GET_VECTOR_HIGH_ELEM | |||
#undef GET_VECTOR_LOW_ELEM | |||
} // namespace | |||
namespace megdnn { | |||
namespace arm_common { | |||
namespace winograd { | |||
MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_4x5_1x1_f) | |||
void winograd_4x5_1x1_f::filter(const float* filter, | |||
float* filter_transform_buf, | |||
float* transform_mid_buf, size_t OC, size_t IC, | |||
size_t oc_start, size_t oc_end) { | |||
FilterTransform4X5::transform(filter, filter_transform_buf, | |||
transform_mid_buf, OC, IC, oc_start, oc_end); | |||
} | |||
void winograd_4x5_1x1_f::input(const float* input, float* input_transform_buf, | |||
float* transform_mid_buf, size_t IH, size_t IW, | |||
size_t IC, size_t PH, size_t PW, | |||
size_t unit_start_idx, size_t nr_units_in_tile) { | |||
constexpr int alpha = 4 + 5 - 1; | |||
// OW = IW + 2 * PW - KERNEL_SIZE + 1 | |||
auto units_w = div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); | |||
rep(ic, IC) { | |||
rep(unit_idx, nr_units_in_tile) { | |||
size_t index = unit_start_idx + unit_idx; | |||
size_t nh = index / units_w; | |||
size_t nw = index % units_w; | |||
int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; | |||
int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; | |||
if (ih_start >= 0 && ih_start + alpha <= static_cast<int>(IH) && | |||
iw_start >= 0 && iw_start + alpha <= static_cast<int>(IW)) { | |||
InputTransform4X5::transform<true>( | |||
input, input_transform_buf, transform_mid_buf, ih_start, | |||
iw_start, ic, IH, IW, IC, unit_idx, nr_units_in_tile); | |||
} else { | |||
InputTransform4X5::transform<false>( | |||
input, input_transform_buf, transform_mid_buf, ih_start, | |||
iw_start, ic, IH, IW, IC, unit_idx, nr_units_in_tile); | |||
} | |||
} | |||
} | |||
} | |||
void winograd_4x5_1x1_f::output(const float* output_transform_buf, | |||
const float* bias, float* output, | |||
float* transform_mid_buf, BiasMode bmode, | |||
NonlineMode nonline_mode, size_t OH, size_t OW, | |||
size_t oc_start, size_t oc_end, size_t unit_start_idx, | |||
size_t nr_units_in_tile) { | |||
#define cb(_bmode, _nonline_op, ...) \ | |||
OutputTransform4X5<_bmode MEGDNN_COMMA _nonline_op>::transform(__VA_ARGS__); | |||
auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE); | |||
for (size_t oc = oc_start; oc < oc_end; oc++) { | |||
size_t oc_index = oc - oc_start; | |||
rep(unit_idx, nr_units_in_tile) { | |||
size_t index = unit_start_idx + unit_idx; | |||
auto nh = index / units_w; | |||
auto nw = index % units_w; | |||
size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | |||
size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | |||
DISPATCH_CONV_WINOGRAD_BIAS( | |||
megdnn_arm_common_winograd_fp32_F45, cb, float, float, bmode, | |||
nonline_mode, output_transform_buf, bias, output, transform_mid_buf, | |||
oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx, | |||
nr_units_in_tile, src_dtype, dst_dtype); | |||
} | |||
} | |||
#undef cb | |||
} | |||
} // namespace winograd | |||
} // namespace arm_common | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,500 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/fp32/strategy_5x4.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/strategy.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/arm_common/utils.h" | |||
#include "src/common/unroll_macro.h" | |||
#include "src/common/utils.h" | |||
#include "src/fallback/conv_bias/winograd/winograd.h" | |||
#include "src/arm_common/conv_bias/fp32/helper.h" | |||
#include "src/arm_common/elemwise_helper/op_unary.h" | |||
#include "src/naive/matrix_mul/matrix_mul_helper.h" | |||
#include "midout.h" | |||
MIDOUT_DECL(megdnn_arm_common_winograd_fp32_F54) | |||
using namespace megdnn; | |||
using namespace arm_common; | |||
namespace { | |||
struct FilterTransform5X4 { | |||
#define FILTER_TRANSFORM(d, wd) \ | |||
do { \ | |||
wd##0 = d##0; \ | |||
auto tmp0 = d##0 * 0.7111111f + d##2 * 0.1777778f; \ | |||
auto tmp1 = d##1 * 0.3555556f + d##3 * 0.0888889f; \ | |||
wd##1 = tmp0 + tmp1; \ | |||
wd##2 = tmp0 - tmp1; \ | |||
tmp0 = (d##0 + d##2) * -0.2222222f; \ | |||
tmp1 = (d##1 + d##3) * -0.2222222f; \ | |||
wd##3 = tmp0 + tmp1; \ | |||
wd##4 = tmp0 - tmp1; \ | |||
tmp0 = d##0 * 0.0111111f + d##2 * 0.0444444f; \ | |||
tmp1 = d##1 * 0.0222222f + d##3 * 0.0888889f; \ | |||
wd##5 = tmp0 + tmp1; \ | |||
wd##6 = tmp0 - tmp1; \ | |||
wd##7 = d##3; \ | |||
} while (0) | |||
static void transform(const float* filter, float* filter_transform_buf, | |||
float* transform_mid_buf, size_t OC, size_t IC, | |||
size_t oc_start, size_t oc_end) { | |||
// Gg * GT | |||
// G | |||
// 1 0 0 0 | |||
// 0.7111111 0.3555556 0.1777778 0.0888889 | |||
// 0.7111111 -0.3555556 0.1777778 -0.0888889 | |||
// -0.2222222 -0.2222222 -0.2222222 -0.2222222 | |||
// -0.2222222 0.2222222 -0.2222222 0.2222222 | |||
// 0.0111111 0.0222222 0.0444444 0.0888889 | |||
// 0.0111111 -0.0222222 0.0444444 -0.0888889 | |||
// 0 0 0 1 | |||
constexpr size_t alpha = 5 + 4 - 1;
for (size_t oc = oc_start; oc < oc_end; oc++) { | |||
rep(ic, IC) { | |||
const float* fptr = filter + (oc * IC + ic) * 4 * 4; | |||
#define cb(i) Vector<float, 4> g##i = Vector<float, 4>::load(fptr + 4 * i); | |||
UNROLL_CALL_NOWRAPPER(4, cb); | |||
#undef cb | |||
#define cb(i) Vector<float, 4> wd##i; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
#define cb(i) Vector<float, 8> wdt##i; | |||
UNROLL_CALL_NOWRAPPER(4, cb); | |||
#undef cb | |||
#define cb(i) Vector<float, 8> ret##i; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
FILTER_TRANSFORM(g, wd); | |||
#if MEGDNN_AARCH64 | |||
TRANSPOSE_8x4(wd, wdt); | |||
FILTER_TRANSFORM(wdt, ret); | |||
#define cb(i) ret##i.save(transform_mid_buf + i * alpha); | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
rep(i, alpha) rep(j, alpha) { | |||
filter_transform_buf[(i * alpha + j) * OC * IC + ic * OC + | |||
oc] = transform_mid_buf[j * alpha + i]; | |||
} | |||
#else | |||
#define cb(i) \ | |||
do { \ | |||
mid_buf1[0] = GET_VECTOR_ELEM(wd, i, 0); \ | |||
auto tmp0 = GET_VECTOR_ELEM(wd, i, 0) * 0.7111111f + \ | |||
GET_VECTOR_ELEM(wd, i, 2) * 0.1777778f; \ | |||
auto tmp1 = GET_VECTOR_ELEM(wd, i, 1) * 0.3555556f + \ | |||
GET_VECTOR_ELEM(wd, i, 3) * 0.0888889f; \ | |||
mid_buf1[1] = tmp0 + tmp1; \ | |||
mid_buf1[2] = tmp0 - tmp1; \ | |||
tmp0 = (GET_VECTOR_ELEM(wd, i, 0) + GET_VECTOR_ELEM(wd, i, 2)) * \ | |||
-0.2222222f; \ | |||
tmp1 = (GET_VECTOR_ELEM(wd, i, 1) + GET_VECTOR_ELEM(wd, i, 3)) * \ | |||
-0.2222222f; \ | |||
mid_buf1[3] = tmp0 + tmp1; \ | |||
mid_buf1[4] = tmp0 - tmp1; \ | |||
tmp0 = GET_VECTOR_ELEM(wd, i, 0) * 0.0111111f + \ | |||
GET_VECTOR_ELEM(wd, i, 2) * 0.0444444f; \ | |||
tmp1 = GET_VECTOR_ELEM(wd, i, 1) * 0.0222222f + \ | |||
GET_VECTOR_ELEM(wd, i, 3) * 0.0888889f; \ | |||
mid_buf1[5] = tmp0 + tmp1; \ | |||
mid_buf1[6] = tmp0 - tmp1; \ | |||
mid_buf1[7] = GET_VECTOR_ELEM(wd, i, 3); \ | |||
mid_buf1 += 8; \ | |||
} while (0); | |||
#define GET_VECTOR_ELEM(s, i, idx) vgetq_lane_f32(CONCAT(s, i).value, idx) | |||
float* mid_buf1 = transform_mid_buf; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
mid_buf1 = transform_mid_buf; | |||
#undef cb | |||
rep(i, alpha) rep(j, alpha) { | |||
filter_transform_buf[(i * alpha + j) * OC * IC + ic * OC + | |||
oc] = transform_mid_buf[i * alpha + j]; | |||
} | |||
#endif | |||
} | |||
} | |||
} | |||
}; | |||
#undef FILTER_TRANSFORM | |||
#undef GET_VECTOR_ELEM | |||
struct InputTransform5X4 { | |||
#define INPUT_TRANSFORM(d, wd) \ | |||
do { \ | |||
wd##0 = (d##0 - d##6) + (d##4 - d##2) * 5.25f; \ | |||
auto tmp0 = d##2 * 4.0f - d##4 * 5.0f + d##6; \ | |||
auto tmp1 = d##1 * 2.0f - d##3 * 2.5f + d##5 * 0.5f; \ | |||
wd##1 = tmp0 + tmp1; \ | |||
wd##2 = tmp0 - tmp1; \ | |||
tmp0 = d##2 - d##4 * 4.25f + d##6; \ | |||
tmp1 = d##1 - d##3 * 4.25f + d##5; \ | |||
wd##3 = tmp0 + tmp1; \ | |||
wd##4 = tmp0 - tmp1; \ | |||
tmp0 = d##2 * 0.25f - d##4 * 1.25f + d##6; \ | |||
tmp1 = d##1 * 0.5f - d##3 * 2.5f + d##5 * 2.0f; \ | |||
wd##5 = tmp0 + tmp1; \ | |||
wd##6 = tmp0 - tmp1; \ | |||
wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f; \ | |||
} while (0) | |||
#define GET_VECTOR_HIGH_ELEM(s, i, idx) \ | |||
vgetq_lane_f32(CONCAT(s, i).value.val[1], idx) | |||
#define GET_VECTOR_LOW_ELEM(s, i, idx) \ | |||
vgetq_lane_f32(CONCAT(s, i).value.val[0], idx) | |||
template <bool inner> | |||
static void transform(const float* input, float* input_transform_buf, | |||
float* transform_mid_buf, int ih_start, int iw_start, | |||
size_t ic, size_t IH, size_t IW, size_t IC, | |||
size_t unit_idx, size_t nr_units_in_tile) { | |||
// BTd * B | |||
// BT | |||
// 1 0 -5.25 0 5.25 0 -1 0 | |||
// 0 2 4 -2.5 -5 0.5 1 0 | |||
// 0 -2 4 2.5 -5 -0.5 1 0 | |||
// 0 1 1 -4.25 -4.25 1 1 0 | |||
// 0 -1 1 4.25 -4.25 -1 1 0 | |||
// 0 0.5 0.25 -2.5 -1.25 2 1 0 | |||
// 0 -0.5 0.25 2.5 -1.25 -2 1 0 | |||
// 0 -1 0 5.25 0 -5.25 0 1 | |||
constexpr size_t alpha = 5 + 4 - 1;
if (!inner) { | |||
memset(transform_mid_buf, 0, sizeof(float) * alpha * alpha); | |||
} | |||
#define cb(i) Vector<float, 8> d##i; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
if (inner) { | |||
const float* input_ptr = | |||
input + ic * IH * IW + ih_start * IW + iw_start; | |||
#define cb(i) d##i = Vector<float, 8>::load(input_ptr + IW * i); | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
} else { | |||
int ih0_act = std::max<int>(ih_start, 0), | |||
ih1_act = std::min<int>(ih_start + alpha, IH), | |||
iw0_act = std::max<int>(iw_start, 0), | |||
iw1_act = std::min<int>(iw_start + alpha, IW); | |||
for (int ih = ih0_act; ih < ih1_act; ++ih) { | |||
for (int iw = iw0_act; iw < iw1_act; ++iw) { | |||
size_t iho = ih - ih_start, iwo = iw - iw_start; | |||
transform_mid_buf[iho * alpha + iwo] = | |||
input[ic * IH * IW + ih * IW + iw]; | |||
} | |||
} | |||
#define cb(i) d##i = Vector<float, 8>::load(transform_mid_buf + alpha * i); | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
} | |||
#define cb(i) Vector<float, 8> wd##i, ret##i; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
INPUT_TRANSFORM(d, wd); | |||
#if MEGDNN_AARCH64 | |||
TRANSPOSE_8x8(wd, d); | |||
INPUT_TRANSFORM(d, ret); | |||
#define cb(i) ret##i.save(transform_mid_buf + i * alpha); | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
rep(i, alpha) rep(j, alpha) { | |||
input_transform_buf[(i * alpha + j) * nr_units_in_tile * IC + | |||
unit_idx * IC + ic] = | |||
transform_mid_buf[j * alpha + i]; | |||
} | |||
#else | |||
#define cb(i) \ | |||
do { \ | |||
mid_buf1[0] = GET_VECTOR_LOW_ELEM(wd, i, 0) - \ | |||
GET_VECTOR_HIGH_ELEM(wd, i, 2) + \ | |||
5.25 * (GET_VECTOR_HIGH_ELEM(wd, i, 0) - \ | |||
GET_VECTOR_LOW_ELEM(wd, i, 2)); \ | |||
mid_buf1[7] = GET_VECTOR_HIGH_ELEM(wd, i, 3) - \ | |||
GET_VECTOR_LOW_ELEM(wd, i, 1) + \ | |||
5.25 * (GET_VECTOR_LOW_ELEM(wd, i, 3) - \ | |||
GET_VECTOR_HIGH_ELEM(wd, i, 1)); \ | |||
auto tmp0 = 4 * GET_VECTOR_LOW_ELEM(wd, i, 2) + \ | |||
-5 * GET_VECTOR_HIGH_ELEM(wd, i, 0) + \ | |||
GET_VECTOR_HIGH_ELEM(wd, i, 2); \ | |||
auto tmp1 = 2 * GET_VECTOR_LOW_ELEM(wd, i, 1) + \ | |||
-2.5 * GET_VECTOR_LOW_ELEM(wd, i, 3) + \ | |||
0.5 * GET_VECTOR_HIGH_ELEM(wd, i, 1); \ | |||
mid_buf1[1] = tmp0 + tmp1; \ | |||
mid_buf1[2] = tmp0 - tmp1; \ | |||
tmp0 = GET_VECTOR_LOW_ELEM(wd, i, 2) + \ | |||
-4.25 * GET_VECTOR_HIGH_ELEM(wd, i, 0) + \ | |||
GET_VECTOR_HIGH_ELEM(wd, i, 2); \ | |||
tmp1 = GET_VECTOR_LOW_ELEM(wd, i, 1) + \ | |||
GET_VECTOR_LOW_ELEM(wd, i, 3) * -4.25 + \ | |||
GET_VECTOR_HIGH_ELEM(wd, i, 1); \ | |||
mid_buf1[3] = tmp0 + tmp1; \ | |||
mid_buf1[4] = tmp0 - tmp1; \ | |||
tmp0 = GET_VECTOR_LOW_ELEM(wd, i, 2) * 0.25 + \ | |||
GET_VECTOR_HIGH_ELEM(wd, i, 0) * -1.25 + \ | |||
GET_VECTOR_HIGH_ELEM(wd, i, 2); \ | |||
tmp1 = GET_VECTOR_LOW_ELEM(wd, i, 1) * 0.5 + \ | |||
GET_VECTOR_LOW_ELEM(wd, i, 3) * -2.5 + \ | |||
GET_VECTOR_HIGH_ELEM(wd, i, 1) * 2; \ | |||
mid_buf1[5] = tmp0 + tmp1; \ | |||
mid_buf1[6] = tmp0 - tmp1; \ | |||
mid_buf1 += 8; \ | |||
} while (0); | |||
float* mid_buf1 = transform_mid_buf; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
mid_buf1 = transform_mid_buf; | |||
#undef cb | |||
rep(i, alpha) rep(j, alpha) { | |||
input_transform_buf[(i * alpha + j) * nr_units_in_tile * IC + | |||
unit_idx * IC + ic] = | |||
transform_mid_buf[i * alpha + j]; | |||
} | |||
#endif | |||
} | |||
}; | |||
#undef INPUT_TRANSFORM | |||
#define OUTPUT_TRANSFORM(m, s) \ | |||
do { \ | |||
auto m1addm2 = m##1 + m##2; \ | |||
auto m1subm2 = m##1 - m##2; \ | |||
auto m3addm4 = m##3 + m##4; \ | |||
auto m3subm4 = m##3 - m##4; \ | |||
auto m5addm6 = (m##5 + m##6); \ | |||
auto m5subm6 = (m##5 - m##6); \ | |||
s##0 = m##0; \ | |||
CONCAT(s, 0).add(m1addm2).add(m3addm4).add(m5addm6); \ | |||
CONCAT(s, 1) = m3subm4; \ | |||
CONCAT(s, 1).mla(m1subm2, 0.5f).mla(m5subm6, 2.0f); \ | |||
CONCAT(s, 2) = m3addm4; \ | |||
CONCAT(s, 2).mla(m1addm2, 0.25f).mla(m5addm6, 4.0f); \ | |||
CONCAT(s, 3) = m3subm4; \ | |||
CONCAT(s, 3).mla(m1subm2, 0.125f).mla(m5subm6, 8.0f); \ | |||
CONCAT(s, 4) = m##7; \ | |||
CONCAT(s, 4).mla(m1addm2, 0.0625f).add(m3addm4).mla(m5addm6, 16.0f); \ | |||
} while (0) | |||
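// Illustrative note (not in the original source): the chained add()/mla()
// calls build the five rows of the 5x8 matrix AT for F(5x5, 4x4) (written
// out in OutputTransform5X4 below); factoring through m1+/-m2, m3+/-m4 and
// m5+/-m6 saves multiplies relative to expanding each row directly.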
template <BiasMode bmode, typename Op> | |||
struct OutputTransform5X4 { | |||
static void transform(const float* output_transform_buf, const float* bias, | |||
float* output, float* transform_mid_buf, | |||
size_t oh_start, size_t ow_start, size_t OH, | |||
size_t OW, size_t oc_start, size_t oc_end, | |||
size_t oc_index, size_t unit_idx, | |||
size_t nr_units_in_tile, const DType& src_dtype, | |||
const DType& dst_dtype) { | |||
Op op(src_dtype, dst_dtype); | |||
//! AT * m * A | |||
// AT | |||
// 1 1 1 1 1 1 1 0 | |||
// 0 0.5 -0.5 1 -1 2 -2 0 | |||
// 0 0.25 0.25 1 1 4 4 0 | |||
// 0 0.125 -0.125 1 -1 8 -8 0 | |||
// 0 0.0625 0.0625 1 1 16 16 1 | |||
constexpr size_t alpha = 5 + 4 - 1; | |||
float* mid_buf1 = transform_mid_buf; | |||
size_t OC = oc_end - oc_start; | |||
size_t oc = oc_start + oc_index; | |||
#define cb(m, n) \ | |||
transform_mid_buf[m * alpha + n] = \ | |||
output_transform_buf[(m * alpha + n) * nr_units_in_tile * OC + \ | |||
unit_idx * OC + oc_index]; | |||
UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); | |||
#undef cb | |||
#define cb(i) auto m##i = Vector<float, 8>::load(transform_mid_buf + alpha * i); | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
#define cb(i) Vector<float, 8> s##i, ret##i; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
OUTPUT_TRANSFORM(m, s); | |||
#define cb(i) \ | |||
do { \ | |||
auto m1addm2 = \ | |||
GET_VECTOR_LOW_ELEM(s, i, 1) + GET_VECTOR_LOW_ELEM(s, i, 2); \ | |||
auto m1subm2 = \ | |||
GET_VECTOR_LOW_ELEM(s, i, 1) - GET_VECTOR_LOW_ELEM(s, i, 2); \ | |||
auto m3addm4 = \ | |||
GET_VECTOR_LOW_ELEM(s, i, 3) + GET_VECTOR_HIGH_ELEM(s, i, 0); \ | |||
auto m3subm4 = \ | |||
GET_VECTOR_LOW_ELEM(s, i, 3) - GET_VECTOR_HIGH_ELEM(s, i, 0); \ | |||
auto m5addm6 = \ | |||
GET_VECTOR_HIGH_ELEM(s, i, 1) + GET_VECTOR_HIGH_ELEM(s, i, 2); \ | |||
auto m5subm6 = \ | |||
GET_VECTOR_HIGH_ELEM(s, i, 1) - GET_VECTOR_HIGH_ELEM(s, i, 2); \ | |||
mid_buf1[0] = \ | |||
GET_VECTOR_LOW_ELEM(s, i, 0) + m1addm2 + m3addm4 + m5addm6; \ | |||
mid_buf1[1] = 0.5f * m1subm2 + m3subm4 + 2.0f * m5subm6; \ | |||
mid_buf1[2] = 0.25f * m1addm2 + m3addm4 + 4.0f * m5addm6; \ | |||
mid_buf1[3] = 0.125f * m1subm2 + m3subm4 + 8.0f * m5subm6; \ | |||
mid_buf1[4] = 0.0625f * m1addm2 + m3addm4 + 16.0f * m5addm6 + \ | |||
GET_VECTOR_HIGH_ELEM(s, i, 3); \ | |||
mid_buf1 += 5; \ | |||
} while (0); | |||
mid_buf1 = transform_mid_buf; | |||
UNROLL_CALL_NOWRAPPER(5, cb); | |||
mid_buf1 = transform_mid_buf; | |||
#undef cb | |||
if (oh_start + 5 <= OH && ow_start + 5 <= OW) { | |||
float32x4_t bias0; | |||
float32_t bias1; | |||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
bias0 = vdupq_n_f32(bias[oc]); | |||
bias1 = bias[oc]; | |||
} | |||
rep(i, 5) { | |||
size_t oh = oh_start + i; | |||
float32x4_t item0 = vld1q_f32(mid_buf1); | |||
float32_t item1 = mid_buf1[4]; | |||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
item0 = vaddq_f32(item0, bias0); | |||
item1 = item1 + bias1; | |||
} else if (bmode == BiasMode::BIAS) { | |||
bias0 = vld1q_f32(bias + oc * OH * OW + oh * OW + ow_start); | |||
bias1 = bias[oc * OH * OW + oh * OW + ow_start + 4]; | |||
item0 = vaddq_f32(item0, bias0); | |||
item1 = item1 + bias1; | |||
} | |||
item0 = op(item0); | |||
item1 = op(item1); | |||
vst1q_f32(output + oc * OH * OW + oh * OW + ow_start, item0); | |||
output[oc * OH * OW + oh * OW + ow_start + 4] = item1; | |||
mid_buf1 += 5; | |||
} | |||
} else { | |||
for (size_t oho = 0; oho < 5 && oh_start + oho < OH; ++oho) { | |||
for (size_t owo = 0; owo < 5 && ow_start + owo < OW; ++owo) { | |||
size_t oh = oh_start + oho; | |||
size_t ow = ow_start + owo; | |||
float res = mid_buf1[oho * 5 + owo]; | |||
if (bmode == BiasMode::BIAS) { | |||
res += bias[oc * OH * OW + oh * OW + ow]; | |||
} else if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
res += bias[oc]; | |||
} | |||
res = op(res); | |||
output[oc * OH * OW + oh * OW + ow] = res; | |||
} | |||
} | |||
} | |||
} | |||
}; | |||
#undef OUTPUT_TRANSFORM | |||
#undef GET_VECTOR_HIGH_ELEM | |||
#undef GET_VECTOR_LOW_ELEM | |||
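// A minimal scalar sketch of one row of the F(5x5, 4x4) output transform
// A^T above: the same five combinations the NEON code computes per row.
// Illustrative reference only; the function name is not MegEngine API.
static inline void winograd_f54_output_row_ref(const float m[8], float s[5]) {
    float m1addm2 = m[1] + m[2], m1subm2 = m[1] - m[2];
    float m3addm4 = m[3] + m[4], m3subm4 = m[3] - m[4];
    float m5addm6 = m[5] + m[6], m5subm6 = m[5] - m[6];
    s[0] = m[0] + m1addm2 + m3addm4 + m5addm6;
    s[1] = 0.5f * m1subm2 + m3subm4 + 2.0f * m5subm6;
    s[2] = 0.25f * m1addm2 + m3addm4 + 4.0f * m5addm6;
    s[3] = 0.125f * m1subm2 + m3subm4 + 8.0f * m5subm6;
    s[4] = 0.0625f * m1addm2 + m3addm4 + 16.0f * m5addm6 + m[7];
}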
} // namespace | |||
namespace megdnn { | |||
namespace arm_common { | |||
namespace winograd { | |||
MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_5x4_1x1_f) | |||
void winograd_5x4_1x1_f::filter(const float* filter, | |||
float* filter_transform_buf, | |||
float* transform_mid_buf, size_t OC, size_t IC, | |||
size_t oc_start, size_t oc_end) { | |||
FilterTransform5X4::transform(filter, filter_transform_buf, | |||
transform_mid_buf, OC, IC, oc_start, oc_end); | |||
} | |||
void winograd_5x4_1x1_f::input(const float* input, float* input_transform_buf, | |||
float* transform_mid_buf, size_t IH, size_t IW, | |||
size_t IC, size_t PH, size_t PW, | |||
size_t unit_start_idx, size_t nr_units_in_tile) { | |||
constexpr int alpha = 5 + 4 - 1; | |||
// OW = IW + 2 * PW - KERNEL_SIZE + 1 | |||
auto units_w = div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); | |||
rep(ic, IC) { | |||
rep(unit_idx, nr_units_in_tile) { | |||
size_t index = unit_start_idx + unit_idx; | |||
size_t nh = index / units_w; | |||
size_t nw = index % units_w; | |||
int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; | |||
int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; | |||
if (ih_start >= 0 && ih_start + alpha <= static_cast<int>(IH) && | |||
iw_start >= 0 && iw_start + alpha <= static_cast<int>(IW)) { | |||
InputTransform5X4::transform<true>( | |||
input, input_transform_buf, transform_mid_buf, ih_start, | |||
iw_start, ic, IH, IW, IC, unit_idx, nr_units_in_tile); | |||
} else { | |||
InputTransform5X4::transform<false>( | |||
input, input_transform_buf, transform_mid_buf, ih_start, | |||
iw_start, ic, IH, IW, IC, unit_idx, nr_units_in_tile); | |||
} | |||
} | |||
} | |||
} | |||
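// A hedged sketch of the unit-index arithmetic used by input() above and
// output() below: a linear tile index decomposes into a tile row/column
// whose top-left input coordinate is shifted back by the padding; negative
// coordinates select the zero-padded (non-inner) path. Names are
// illustrative only.
struct TileOrigin {
    int ih_start, iw_start;
};
static inline TileOrigin tile_origin(size_t index, size_t units_w,
                                     size_t output_block_size, size_t PH,
                                     size_t PW) {
    size_t nh = index / units_w;  // tile row in the output plane
    size_t nw = index % units_w;  // tile column in the output plane
    return {static_cast<int>(nh * output_block_size) - static_cast<int>(PH),
            static_cast<int>(nw * output_block_size) - static_cast<int>(PW)};
}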
void winograd_5x4_1x1_f::output(const float* output_transform_buf, | |||
const float* bias, float* output, | |||
float* transform_mid_buf, BiasMode bmode, | |||
NonlineMode nonline_mode, size_t OH, size_t OW, | |||
size_t oc_start, size_t oc_end, | |||
size_t unit_start_idx, | |||
size_t nr_units_in_tile) { | |||
#define cb(_bmode, _nonline_op, ...) \ | |||
OutputTransform5X4<_bmode MEGDNN_COMMA _nonline_op>::transform(__VA_ARGS__); | |||
auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE); | |||
for (size_t oc = oc_start; oc < oc_end; oc++) { | |||
size_t oc_index = oc - oc_start; | |||
rep(unit_idx, nr_units_in_tile) { | |||
size_t index = unit_start_idx + unit_idx; | |||
auto nh = index / units_w; | |||
auto nw = index % units_w; | |||
size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | |||
size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | |||
DISPATCH_CONV_WINOGRAD_BIAS( | |||
megdnn_arm_common_winograd_fp32_F54, cb, float, float, bmode, | |||
nonline_mode, output_transform_buf, bias, output, transform_mid_buf, | |||
oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx, | |||
nr_units_in_tile, src_dtype, dst_dtype); | |||
} | |||
} | |||
#undef cb | |||
} | |||
} // namespace winograd | |||
} // namespace arm_common | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,424 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/fp32/strategy_6x3.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/filter_transform.h" | |||
#include "src/arm_common/conv_bias/fp32/helper.h" | |||
#include "src/arm_common/conv_bias/fp32/strategy.h" | |||
#include "src/arm_common/elemwise_helper/op_unary.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/arm_common/utils.h" | |||
#include "src/common/unroll_macro.h" | |||
#include "src/common/utils.h" | |||
#include "src/fallback/conv_bias/winograd/winograd.h" | |||
#include "src/naive/matrix_mul/matrix_mul_helper.h" | |||
#include "midout.h" | |||
MIDOUT_DECL(megdnn_arm_common_winograd_fp32_F63) | |||
using namespace megdnn; | |||
using namespace arm_common; | |||
namespace { | |||
/** | |||
* input transform | |||
* | |||
* wd0 = (d0 - d6) + 5.25 * (d4 - d2) | |||
* wd1 = (d6 + d2 - 4.25 * d4) + (d1 + d5 - 4.25 * d3) | |||
* wd2 = (d6 + d2 - 4.25 * d4) - (d1 + d5 - 4.25 * d3) | |||
* wd3 = (d6 + 0.25 * d2 - 1.25 * d4) + 2.0 * (d5 + 0.25 * d1 - 1.25 * d3) | |||
* wd4 = (d6 + 0.25 * d2 - 1.25 * d4) - 2.0 * (d5 + 0.25 * d1 - 1.25 * d3) | |||
* wd5 = (d6 - 5.0 * d4 + 4.0 * d2) + 2.0 * (d1 + 0.25 * d5 - 1.25 * d3) | |||
* wd6 = (d6 - 5.0 * d4 + 4.0 * d2) - 2.0 * (d1 + 0.25 * d5 - 1.25 * d3) | |||
* wd7 = (d7 - d1) + 5.25 * (d3 - d5) | |||
*/ | |||
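// The eight formulas above are one application of B^T to an 8-point vector.
// A minimal scalar reference mirroring the comment exactly (the name is
// illustrative, not MegEngine API):
static inline void winograd_f63_input_row_ref(const float d[8], float wd[8]) {
    wd[0] = (d[0] - d[6]) + 5.25f * (d[4] - d[2]);
    float tmp0 = d[6] + d[2] - 4.25f * d[4];
    float tmp1 = d[1] + d[5] - 4.25f * d[3];
    wd[1] = tmp0 + tmp1;
    wd[2] = tmp0 - tmp1;
    tmp0 = d[6] + 0.25f * d[2] - 1.25f * d[4];
    tmp1 = 2.0f * (d[5] + 0.25f * d[1] - 1.25f * d[3]);
    wd[3] = tmp0 + tmp1;
    wd[4] = tmp0 - tmp1;
    tmp0 = d[6] - 5.0f * d[4] + 4.0f * d[2];
    tmp1 = 2.0f * (d[1] + 0.25f * d[5] - 1.25f * d[3]);
    wd[5] = tmp0 + tmp1;
    wd[6] = tmp0 - tmp1;
    wd[7] = (d[7] - d[1]) + 5.25f * (d[3] - d[5]);
}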
#define INPUT_TRANSFORM(d, wd) \ | |||
do { \ | |||
wd##0 = (d##0 - d##6) + (d##4 - d##2) * 5.25f; \ | |||
auto tmp0 = d##6 + d##2 - d##4 * 4.25f; \ | |||
auto tmp1 = d##1 + d##5 - d##3 * 4.25f; \ | |||
wd##1 = tmp0 + tmp1; \ | |||
wd##2 = tmp0 - tmp1; \ | |||
tmp0 = d##6 + d##2 * 0.25f - d##4 * 1.25f; \ | |||
tmp1 = (d##5 + d##1 * 0.25f - d##3 * 1.25f) * 2.0f; \ | |||
wd##3 = tmp0 + tmp1; \ | |||
wd##4 = tmp0 - tmp1; \ | |||
        tmp0 = d##6 - d##4 * 5.0f + d##2 * 4.0f;                 \ | |||
        tmp1 = (d##1 + d##5 * 0.25f - d##3 * 1.25f) * 2.0f;      \ | |||
wd##5 = tmp0 + tmp1; \ | |||
wd##6 = tmp0 - tmp1; \ | |||
wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f; \ | |||
} while (0); | |||
#define GET_VECTOR_HIGH_ELEM(s, i, idx) \ | |||
vgetq_lane_f32(CONCAT(s, i).value.val[1], idx) | |||
#define GET_VECTOR_LOW_ELEM(s, i, idx) \ | |||
vgetq_lane_f32(CONCAT(s, i).value.val[0], idx) | |||
struct InputTransform6X3 { | |||
template <bool inner> | |||
static void transform(const float* input, float* input_transform_buf, | |||
float* transform_mid_buf, int ih_start, int iw_start, | |||
size_t ic, size_t IH, size_t IW, size_t IC, | |||
size_t unit_idx, size_t nr_units_in_tile) { | |||
constexpr size_t alpha = 6 + 3 - 1; | |||
if (!inner) { | |||
memset(transform_mid_buf, 0, sizeof(float) * alpha * alpha); | |||
} | |||
#define cb(i) Vector<float, 8> d##i; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
if (inner) { | |||
const float* input_ptr = | |||
input + ic * IH * IW + ih_start * IW + iw_start; | |||
#define cb(i) d##i = Vector<float, 8>::load(input_ptr + IW * i); | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
} else { | |||
int ih0_act = std::max<int>(ih_start, 0), | |||
ih1_act = std::min<int>(ih_start + alpha, IH), | |||
iw0_act = std::max<int>(iw_start, 0), | |||
iw1_act = std::min<int>(iw_start + alpha, IW); | |||
for (int ih = ih0_act; ih < ih1_act; ++ih) { | |||
for (int iw = iw0_act; iw < iw1_act; ++iw) { | |||
size_t iho = ih - ih_start, iwo = iw - iw_start; | |||
transform_mid_buf[iho * alpha + iwo] = | |||
input[ic * IH * IW + ih * IW + iw]; | |||
} | |||
} | |||
#define cb(i) d##i = Vector<float, 8>::load(transform_mid_buf + alpha * i); | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
} | |||
#define cb(i) Vector<float, 8> wd##i, ret##i; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
INPUT_TRANSFORM(d, wd); | |||
#if MEGDNN_AARCH64 | |||
TRANSPOSE_8x8(wd, d); | |||
INPUT_TRANSFORM(d, ret); | |||
#define cb(i) ret##i.save(transform_mid_buf + i * alpha); | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
rep(i, alpha) rep(j, alpha) { | |||
input_transform_buf[(i * alpha + j) * nr_units_in_tile * IC + | |||
unit_idx * IC + ic] = | |||
transform_mid_buf[j * alpha + i]; | |||
} | |||
#else | |||
//! 1 0 0 0 0 0 0 0 | |||
//! 0 1 -1 0.5 -0.5 2 -2 -1 | |||
//! -5.25 1 1 0.25 0.25 4 4 0 | |||
//! 0 -4.25 4.25 -2.5 2.5 -2.5 2.5 5.25 | |||
//! 5.25 -4.25 -4.25 -1.25 -1.25 -5 -5 0 | |||
//! 0 1 -1 2 -2 0.5 -0.5 -5.25 | |||
//! -1 1 1 1 1 1 1 0 | |||
//! 0 0 0 0 0 0 0 1 | |||
#define cb(i) \ | |||
do { \ | |||
mid_buf1[0] = GET_VECTOR_LOW_ELEM(wd, i, 0) - \ | |||
GET_VECTOR_HIGH_ELEM(wd, i, 2) + \ | |||
5.25f * (GET_VECTOR_HIGH_ELEM(wd, i, 0) - \ | |||
GET_VECTOR_LOW_ELEM(wd, i, 2)); \ | |||
mid_buf1[7] = GET_VECTOR_HIGH_ELEM(wd, i, 3) - \ | |||
GET_VECTOR_LOW_ELEM(wd, i, 1) + \ | |||
5.25f * (GET_VECTOR_LOW_ELEM(wd, i, 3) - \ | |||
GET_VECTOR_HIGH_ELEM(wd, i, 1)); \ | |||
auto tmp0 = GET_VECTOR_LOW_ELEM(wd, i, 2) + \ | |||
GET_VECTOR_HIGH_ELEM(wd, i, 2) - \ | |||
4.25f * GET_VECTOR_HIGH_ELEM(wd, i, 0); \ | |||
auto tmp1 = GET_VECTOR_LOW_ELEM(wd, i, 1) + \ | |||
GET_VECTOR_HIGH_ELEM(wd, i, 1) - \ | |||
4.25f * GET_VECTOR_LOW_ELEM(wd, i, 3); \ | |||
mid_buf1[1] = tmp0 + tmp1; \ | |||
mid_buf1[2] = tmp0 - tmp1; \ | |||
tmp0 = GET_VECTOR_HIGH_ELEM(wd, i, 2) + \ | |||
0.25f * GET_VECTOR_LOW_ELEM(wd, i, 2) - \ | |||
GET_VECTOR_HIGH_ELEM(wd, i, 0) * 1.25f; \ | |||
tmp1 = GET_VECTOR_LOW_ELEM(wd, i, 1) * 0.5f - \ | |||
GET_VECTOR_LOW_ELEM(wd, i, 3) * 2.5f + \ | |||
GET_VECTOR_HIGH_ELEM(wd, i, 1) * 2.f; \ | |||
mid_buf1[3] = tmp0 + tmp1; \ | |||
mid_buf1[4] = tmp0 - tmp1; \ | |||
tmp0 = GET_VECTOR_HIGH_ELEM(wd, i, 2) + \ | |||
(GET_VECTOR_LOW_ELEM(wd, i, 2) - \ | |||
GET_VECTOR_HIGH_ELEM(wd, i, 0) * 1.25f) * \ | |||
4; \ | |||
tmp1 = GET_VECTOR_LOW_ELEM(wd, i, 1) * 2.f - \ | |||
GET_VECTOR_LOW_ELEM(wd, i, 3) * 2.5f + \ | |||
GET_VECTOR_HIGH_ELEM(wd, i, 1) * 0.5f; \ | |||
mid_buf1[5] = tmp0 + tmp1; \ | |||
mid_buf1[6] = tmp0 - tmp1; \ | |||
mid_buf1 += 8; \ | |||
} while (0); | |||
float* mid_buf1 = transform_mid_buf; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
mid_buf1 = transform_mid_buf; | |||
#undef cb | |||
rep(i, alpha) rep(j, alpha) { | |||
input_transform_buf[(i * alpha + j) * nr_units_in_tile * IC + | |||
unit_idx * IC + ic] = | |||
transform_mid_buf[i * alpha + j]; | |||
} | |||
#endif | |||
} | |||
}; | |||
#undef INPUT_TRANSFORM | |||
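// A hedged sketch of how the MEGDNN_AARCH64 branch above stages the 2-D
// transform B^T * d * B: two 1-D column passes with a transpose in between,
// the final transpose being folded into the scatter store. `fn` could be
// the winograd_f63_input_row_ref sketch above; all names are illustrative.
typedef void (*Row8Fn)(const float in[8], float out[8]);
static inline void transform_cols_ref(const float (*in)[8], float (*out)[8],
                                      Row8Fn fn) {
    for (int j = 0; j < 8; ++j) {
        float col[8], res[8];
        for (int k = 0; k < 8; ++k) col[k] = in[k][j];
        fn(col, res);
        for (int k = 0; k < 8; ++k) out[k][j] = res[k];
    }
}
static inline void two_pass_transform_ref(const float (*d)[8],
                                          float (*out)[8], Row8Fn fn) {
    float wd[8][8], wdT[8][8], ret[8][8];
    transform_cols_ref(d, wd, fn);     // wd  = B^T * d
    for (int i = 0; i < 8; ++i)
        for (int j = 0; j < 8; ++j) wdT[i][j] = wd[j][i];
    transform_cols_ref(wdT, ret, fn);  // ret = B^T * d^T * B
    for (int i = 0; i < 8; ++i)
        for (int j = 0; j < 8; ++j) out[i][j] = ret[j][i];  // B^T * d * B
}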
/** | |||
* Output Transform: use fma | |||
* | |||
* s0 = m0 + (m1 + m2) + (m3 + m4) + 32 * (m5 + m6) / 32 | |||
* s1 = (m1 - m2) + 2 * (m3 - m4) + 16 * (m5 - m6) / 32 | |||
* s2 = (m1 + m2) + 4 * (m3 + m4) + 8 * (m5 + m6) / 32 | |||
* s3 = (m1 - m2) + 8 * (m3 - m4) + 4 * (m5 - m6) / 32 | |||
* s4 = (m1 + m2) + 16 * (m3 + m4) + 2 * (m5 + m6) / 32 | |||
* s5 = (m1 - m2) + 32 * (m3 - m4) + (m5 - m6) / 32 + m7 | |||
*/ | |||
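// Scalar sketch of the six A^T combinations listed above. The macro below
// pre-scales (m5 +/- m6) by 1/32 and multiplies the factor back per row, so
// its net coefficients match this reference (the name is illustrative, not
// MegEngine API):
static inline void winograd_f63_output_row_ref(const float m[8], float s[6]) {
    float m1addm2 = m[1] + m[2], m1subm2 = m[1] - m[2];
    float m3addm4 = m[3] + m[4], m3subm4 = m[3] - m[4];
    float m5addm6 = m[5] + m[6], m5subm6 = m[5] - m[6];
    s[0] = m[0] + m1addm2 + m3addm4 + m5addm6;
    s[1] = m1subm2 + 2.0f * m3subm4 + 0.5f * m5subm6;
    s[2] = m1addm2 + 4.0f * m3addm4 + 0.25f * m5addm6;
    s[3] = m1subm2 + 8.0f * m3subm4 + 0.125f * m5subm6;
    s[4] = m1addm2 + 16.0f * m3addm4 + 0.0625f * m5addm6;
    s[5] = m1subm2 + 32.0f * m3subm4 + 0.03125f * m5subm6 + m[7];
}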
#define OUTPUT_TRANSFORM(m, s) \ | |||
do { \ | |||
auto m1addm2 = m##1 + m##2; \ | |||
auto m1subm2 = m##1 - m##2; \ | |||
auto m3addm4 = m##3 + m##4; \ | |||
auto m3subm4 = m##3 - m##4; \ | |||
auto m5addm6 = (m##5 + m##6) * 0.03125f; \ | |||
auto m5subm6 = (m##5 - m##6) * 0.03125f; \ | |||
s##0 = m##0; \ | |||
CONCAT(s, 0).mla(m5addm6, 32.f).add(m3addm4).add(m1addm2); \ | |||
CONCAT(s, 1) = m1subm2; \ | |||
CONCAT(s, 1).mla(m3subm4, 2.f).mla(m5subm6, 16.f); \ | |||
CONCAT(s, 2) = m1addm2; \ | |||
CONCAT(s, 2).mla(m3addm4, 4.f).mla(m5addm6, 8.f); \ | |||
CONCAT(s, 3) = m1subm2; \ | |||
CONCAT(s, 3).mla(m3subm4, 8.f).mla(m5subm6, 4.f); \ | |||
CONCAT(s, 4) = m1addm2; \ | |||
CONCAT(s, 4).mla(m3addm4, 16.f).mla(m5addm6, 2.f); \ | |||
CONCAT(s, 5) = m1subm2; \ | |||
CONCAT(s, 5).mla(m3subm4, 32.f).add(m5subm6).add(m##7); \ | |||
} while (0); | |||
template <BiasMode bmode, typename Op> | |||
struct OutputTransform6X3 { | |||
static void transform(const float* output_transform_buf, const float* bias, | |||
float* output, float* transform_mid_buf, | |||
size_t oh_start, size_t ow_start, size_t OH, | |||
size_t OW, size_t oc_start, size_t oc_end, | |||
size_t oc_index, size_t unit_idx, | |||
size_t nr_units_in_tile, const DType& src_dtype, | |||
const DType& dst_dtype) { | |||
constexpr size_t alpha = 6 + 3 - 1; | |||
Op op(src_dtype, dst_dtype); | |||
float* mid_buf1 = transform_mid_buf; | |||
//! AT * m * A | |||
size_t OC = oc_end - oc_start; | |||
size_t oc = oc_start + oc_index; | |||
#define cb(m, n) \ | |||
transform_mid_buf[m * alpha + n] = \ | |||
output_transform_buf[(m * alpha + n) * nr_units_in_tile * OC + \ | |||
unit_idx * OC + oc_index]; | |||
UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); | |||
#undef cb | |||
#define cb(i) auto m##i = Vector<float, 8>::load(transform_mid_buf + alpha * i); | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
#define cb(i) Vector<float, 8> s##i, ret##i; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
OUTPUT_TRANSFORM(m, s); | |||
/** | |||
* Output transform: m * A | |||
* | |||
* 1 0 0 0 0 0 | |||
* 1 1 1 1 1 1 | |||
* 1 -1 1 -1 1 -1 | |||
* 1 2 4 8 16 32 | |||
* 1 -2 4 -8 16 -32 | |||
* 1 0.5 0.25 0.125 0.0625 0.03125 | |||
* 1 -0.5 0.25 -0.125 0.0625 -0.03125 | |||
 * 0 0 0 0 0 1 | |||
*/ | |||
#define cb(i) \ | |||
do { \ | |||
auto m1addm2 = \ | |||
GET_VECTOR_LOW_ELEM(s, i, 1) + GET_VECTOR_LOW_ELEM(s, i, 2); \ | |||
auto m1subm2 = \ | |||
GET_VECTOR_LOW_ELEM(s, i, 1) - GET_VECTOR_LOW_ELEM(s, i, 2); \ | |||
auto m3addm4 = \ | |||
GET_VECTOR_LOW_ELEM(s, i, 3) + GET_VECTOR_HIGH_ELEM(s, i, 0); \ | |||
auto m3subm4 = \ | |||
GET_VECTOR_LOW_ELEM(s, i, 3) - GET_VECTOR_HIGH_ELEM(s, i, 0); \ | |||
auto m5addm6 = \ | |||
GET_VECTOR_HIGH_ELEM(s, i, 1) + GET_VECTOR_HIGH_ELEM(s, i, 2); \ | |||
auto m5subm6 = \ | |||
GET_VECTOR_HIGH_ELEM(s, i, 1) - GET_VECTOR_HIGH_ELEM(s, i, 2); \ | |||
mid_buf1[0] = \ | |||
GET_VECTOR_LOW_ELEM(s, i, 0) + m1addm2 + m3addm4 + m5addm6; \ | |||
mid_buf1[1] = m1subm2 + 2.f * m3subm4 + 0.5f * m5subm6; \ | |||
mid_buf1[2] = m1addm2 + 4.f * m3addm4 + 0.25f * m5addm6; \ | |||
mid_buf1[3] = m1subm2 + 8.f * m3subm4 + 0.125f * m5subm6; \ | |||
mid_buf1[4] = m1addm2 + 16.f * m3addm4 + 0.0625f * m5addm6; \ | |||
mid_buf1[5] = m1subm2 + 32.f * m3subm4 + 0.03125f * m5subm6 + \ | |||
GET_VECTOR_HIGH_ELEM(s, i, 3); \ | |||
mid_buf1 += 6; \ | |||
} while (0); | |||
mid_buf1 = transform_mid_buf; | |||
UNROLL_CALL_NOWRAPPER(6, cb); | |||
mid_buf1 = transform_mid_buf; | |||
#undef cb | |||
if (oh_start + 6 <= OH && ow_start + 6 <= OW) { | |||
float32x4_t bias0; | |||
float32x2_t bias1; | |||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
bias0 = vdupq_n_f32(bias[oc]); | |||
bias1 = vdup_n_f32(bias[oc]); | |||
} | |||
rep(i, 6) { | |||
size_t oh = oh_start + i; | |||
float32x4_t item0 = vld1q_f32(mid_buf1); | |||
float32x2_t item1 = vld1_f32(mid_buf1 + 4); | |||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
item0 = vaddq_f32(item0, bias0); | |||
item1 = vadd_f32(item1, bias1); | |||
} else if (bmode == BiasMode::BIAS) { | |||
bias0 = vld1q_f32(bias + oc * OH * OW + oh * OW + ow_start); | |||
bias1 = vld1_f32(bias + oc * OH * OW + oh * OW + ow_start + | |||
4); | |||
item0 = vaddq_f32(item0, bias0); | |||
item1 = vadd_f32(item1, bias1); | |||
} | |||
item0 = op(item0); | |||
item1 = vset_lane_f32(op(vget_lane_f32(item1, 0)), item1, 0); | |||
item1 = vset_lane_f32(op(vget_lane_f32(item1, 1)), item1, 1); | |||
vst1q_f32(output + oc * OH * OW + oh * OW + ow_start, item0); | |||
vst1_f32(output + oc * OH * OW + oh * OW + ow_start + 4, item1); | |||
mid_buf1 += 6; | |||
} | |||
} else { | |||
for (size_t oho = 0; oho < 6 && oh_start + oho < OH; ++oho) { | |||
for (size_t owo = 0; owo < 6 && ow_start + owo < OW; ++owo) { | |||
size_t oh = oh_start + oho; | |||
size_t ow = ow_start + owo; | |||
float res = mid_buf1[oho * 6 + owo]; | |||
if (bmode == BiasMode::BIAS) { | |||
res += bias[oc * OH * OW + oh * OW + ow]; | |||
} else if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
res += bias[oc]; | |||
} | |||
res = op(res); | |||
output[oc * OH * OW + oh * OW + ow] = res; | |||
} | |||
} | |||
} | |||
} | |||
}; | |||
#undef GET_VECTOR_HIGH_ELEM | |||
#undef GET_VECTOR_LOW_ELEM | |||
#undef OUTPUT_TRANSFORM | |||
} // namespace | |||
namespace megdnn { | |||
namespace arm_common { | |||
namespace winograd { | |||
MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_6x3_1x1_f) | |||
void winograd_6x3_1x1_f::filter(const float* filter, | |||
float* filter_transform_buf, | |||
float* transform_mid_buf, size_t OC, size_t IC, | |||
size_t oc_start, size_t oc_end) { | |||
FilterTransform6X3<param::MatrixMul::Format::DEFAULT>::transform( | |||
filter, filter_transform_buf, transform_mid_buf, OC, IC, oc_start, | |||
oc_end); | |||
} | |||
void winograd_6x3_1x1_f::input(const float* input, float* input_transform_buf, | |||
float* transform_mid_buf, size_t IH, size_t IW, | |||
size_t IC, size_t PH, size_t PW, | |||
size_t unit_start_idx, size_t nr_units_in_tile) { | |||
    constexpr int alpha = 6 + 3 - 1; | |||
// OW = IW + 2 * PW - KERNEL_SIZE + 1 | |||
auto units_w = div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); | |||
rep(ic, IC) { | |||
rep(unit_idx, nr_units_in_tile) { | |||
size_t index = unit_start_idx + unit_idx; | |||
size_t nh = index / units_w; | |||
size_t nw = index % units_w; | |||
int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; | |||
int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; | |||
if (ih_start >= 0 && ih_start + alpha <= static_cast<int>(IH) && | |||
iw_start >= 0 && iw_start + alpha <= static_cast<int>(IW)) { | |||
InputTransform6X3::transform<true>( | |||
input, input_transform_buf, transform_mid_buf, ih_start, | |||
iw_start, ic, IH, IW, IC, unit_idx, nr_units_in_tile); | |||
} else { | |||
InputTransform6X3::transform<false>( | |||
input, input_transform_buf, transform_mid_buf, ih_start, | |||
iw_start, ic, IH, IW, IC, unit_idx, nr_units_in_tile); | |||
} | |||
} | |||
} | |||
} | |||
void winograd_6x3_1x1_f::output(const float* output_transform_buf, | |||
const float* bias, float* output, | |||
float* transform_mid_buf, BiasMode bmode, | |||
NonlineMode nonline_mode, size_t OH, size_t OW, | |||
size_t oc_start, size_t oc_end, | |||
size_t unit_start_idx, | |||
size_t nr_units_in_tile) { | |||
#define cb(_bmode, _nonline_op, ...) \ | |||
OutputTransform6X3<_bmode MEGDNN_COMMA _nonline_op>::transform(__VA_ARGS__); | |||
auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE); | |||
for (size_t oc = oc_start; oc < oc_end; oc++) { | |||
size_t oc_index = oc - oc_start; | |||
rep(unit_idx, nr_units_in_tile) { | |||
size_t index = unit_start_idx + unit_idx; | |||
auto nh = index / units_w; | |||
auto nw = index % units_w; | |||
size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | |||
size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | |||
DISPATCH_CONV_WINOGRAD_BIAS( | |||
megdnn_arm_common_winograd_fp32_F63, cb, float, float, bmode, | |||
nonline_mode, output_transform_buf, bias, output, transform_mid_buf, | |||
oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx, | |||
nr_units_in_tile, src_dtype, dst_dtype); | |||
} | |||
} | |||
#undef cb | |||
} | |||
} // namespace winograd | |||
} // namespace arm_common | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,351 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/fp32/strategy_6x3_4x4.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/filter_transform.h" | |||
#include "src/arm_common/conv_bias/fp32/helper.h" | |||
#include "src/arm_common/conv_bias/fp32/strategy.h" | |||
#include "src/arm_common/elemwise_helper/op_unary.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/arm_common/utils.h" | |||
#include "src/common/unroll_macro.h" | |||
#include "src/common/utils.h" | |||
#include "src/common/winograd/winograd_helper.h" | |||
#include "src/fallback/conv_bias/winograd/winograd.h" | |||
#include "midout.h" | |||
MIDOUT_DECL(megdnn_arm_common_winograd_fp32_F63_4x4) | |||
using namespace megdnn; | |||
using namespace arm_common; | |||
namespace { | |||
struct InputTransform6X3 { | |||
template <bool inner> | |||
static void prepare(const float* input, float* patch, float* patchT, | |||
int ih_start, int iw_start, size_t IH, size_t IW, | |||
size_t ic, size_t IC) { | |||
constexpr size_t alpha = 6 + 3 - 1; | |||
if (!(inner && ic + 4 < IC)) { | |||
memset(patch, 0, sizeof(float) * 4 * alpha * alpha); | |||
} | |||
if (inner) { | |||
const float* input_ptr = | |||
input + ic * IH * IW + ih_start * IW + iw_start; | |||
for (size_t ico = 0; ico < 4; ++ico) { | |||
if (ic + ico < IC) { | |||
#define cb(i) \ | |||
auto v##i##0 = vld1q_f32(input_ptr + IW * i); \ | |||
auto v##i##1 = vld1q_f32(input_ptr + IW * i + 4); | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
#define cb(i) \ | |||
vst1q_f32(patch + ico * 8 * alpha + i * 8, v##i##0); \ | |||
vst1q_f32(patch + ico * 8 * alpha + i * 8 + 4, v##i##1); | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
input_ptr += IH * IW; | |||
} | |||
} | |||
} else { | |||
int ih0_act = std::max<int>(ih_start, 0), | |||
ih1_act = std::min<int>(ih_start + alpha, IH), | |||
iw0_act = std::max<int>(iw_start, 0), | |||
iw1_act = std::min<int>(iw_start + alpha, IW); | |||
// partial copy | |||
for (size_t ico = 0; ico < 4; ++ico) { | |||
if (ic + ico < IC) { | |||
for (int ih = ih0_act; ih < ih1_act; ++ih) { | |||
for (int iw = iw0_act; iw < iw1_act; ++iw) { | |||
size_t iho = ih - ih_start, iwo = iw - iw_start; | |||
patch[ico * alpha * 8 + iho * 8 + iwo] = | |||
input[(ic + ico) * IH * IW + ih * IW + iw]; | |||
} | |||
} | |||
} | |||
} | |||
} | |||
#define cb(i) \ | |||
transpose_4x4(patch + 8 * i + 0, patchT + 8 * i * 4, 64, 4); \ | |||
transpose_4x4(patch + 8 * i + 4, patchT + 8 * i * 4 + 4 * 4, 64, 4); | |||
UNROLL_CALL_NOWRAPPER(8, cb) | |||
#undef cb | |||
} | |||
static void transform(const float* patchT, float* input_transform_buf, | |||
size_t unit_idx, size_t nr_units_in_tile, size_t ic, | |||
size_t IC) { | |||
constexpr size_t alpha = 6 + 3 - 1; | |||
// BT * d * B | |||
#define cb(m, n) \ | |||
Vector<float, 4> d##m##n = \ | |||
Vector<float, 4>::load(patchT + m * 8 * 4 + n * 4); | |||
UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); | |||
#undef cb | |||
//! B | |||
//! 1 0 0 0 0 0 0 0 | |||
//! 0 1 -1 0.5 -0.5 2 -2 -1 | |||
//! -5.25 1 1 0.25 0.25 4 4 0 | |||
//! 0 -4.25 4.25 -2.5 2.5 -2.5 2.5 5.25 | |||
//! 5.25 -4.25 -4.25 -1.25 -1.25 -5 -5 0 | |||
//! 0 1 -1 2 -2 0.5 -0.5 -5.25 | |||
//! -1 1 1 1 1 1 1 0 | |||
//! 0 0 0 0 0 0 0 1 | |||
#define cb(m) \ | |||
auto t0##m = d0##m + (d4##m - d2##m) * 5.25f - d6##m; \ | |||
auto t1##m = d1##m + d2##m + d5##m + d6##m - (d3##m + d4##m) * 4.25f; \ | |||
auto t2##m = d2##m + d6##m - (d1##m + d5##m) + (d3##m - d4##m) * 4.25f; \ | |||
auto t3##m = d1##m * 0.5f + d2##m * 0.25f - d3##m * 2.5f - d4##m * 1.25f + \ | |||
d5##m * 2.f + d6##m; \ | |||
auto t4##m = d1##m * (-0.5f) + d2##m * 0.25f + d3##m * 2.5f - \ | |||
d4##m * 1.25f - d5##m * 2.f + d6##m; \ | |||
auto t5##m = d1##m * 2.f + d2##m * 4.f - d3##m * 2.5f - d4##m * 5.f + \ | |||
d5##m * 0.5f + d6##m; \ | |||
auto t6##m = d1##m * (-2.f) + d2##m * 4.f + d3##m * 2.5f - d4##m * 5.f - \ | |||
d5##m * 0.5f + d6##m; \ | |||
auto t7##m = (d7##m - d1##m) + (d3##m - d5##m) * 5.25f; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
#define cb(m) \ | |||
d##m##0 = t##m##0 + (t##m##4 - t##m##2) * 5.25f - t##m##6; \ | |||
d##m##1 = t##m##1 + t##m##2 + t##m##5 + t##m##6 - \ | |||
(t##m##3 + t##m##4) * 4.25f; \ | |||
d##m##2 = t##m##2 + t##m##6 - (t##m##1 + t##m##5) + \ | |||
(t##m##3 - t##m##4) * 4.25f; \ | |||
d##m##3 = t##m##1 * 0.5f + t##m##2 * 0.25f - t##m##3 * 2.5f - \ | |||
t##m##4 * 1.25f + t##m##5 * 2.f + t##m##6; \ | |||
d##m##4 = t##m##1 * (-0.5f) + t##m##2 * 0.25f + t##m##3 * 2.5f - \ | |||
t##m##4 * 1.25f - t##m##5 * 2.f + t##m##6; \ | |||
d##m##5 = t##m##1 * 2.f + t##m##2 * 4.f - t##m##3 * 2.5f - t##m##4 * 5.f + \ | |||
t##m##5 * 0.5f + t##m##6; \ | |||
d##m##6 = t##m##1 * (-2.f) + t##m##2 * 4.f + t##m##3 * 2.5f - \ | |||
t##m##4 * 5.f - t##m##5 * 0.5f + t##m##6; \ | |||
d##m##7 = (t##m##7 - t##m##1) + (t##m##3 - t##m##5) * 5.25f; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
size_t ICB = IC / 4; | |||
size_t icb = ic / 4; | |||
#define cb(m, n) \ | |||
d##m##n.save(input_transform_buf + \ | |||
(m * alpha + n) * ICB * nr_units_in_tile * 4 + \ | |||
icb * nr_units_in_tile * 4 + unit_idx * 4); | |||
UNROLL_CALL_NOWRAPPER_D2(8, 8, cb) | |||
#undef cb | |||
} | |||
}; | |||
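// A hedged scalar sketch of the repacking prepare() performs before
// transform(): four input channels of an 8x8 patch are interleaved so the
// four values of one spatial position become contiguous, letting
// transform() handle four channels per Vector<float, 4> operation. Layout
// comments and the name are illustrative only.
static inline void pack_patch_4ch_ref(const float* patch /* [4][8][8] */,
                                      float* patchT /* [8][8][4] */) {
    for (int c = 0; c < 4; ++c)          // channel within the group of four
        for (int h = 0; h < 8; ++h)      // patch row
            for (int w = 0; w < 8; ++w)  // patch column
                patchT[(h * 8 + w) * 4 + c] = patch[(c * 8 + h) * 8 + w];
}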
template <BiasMode bmode, typename Op> | |||
struct OutputTransform6X3 { | |||
static void transform(const float* output_transform_buf, const float* bias, | |||
float* output, float* transform_mid_buf, | |||
size_t oh_start, size_t ow_start, size_t OH, | |||
size_t OW, size_t oc_start, size_t oc_end, | |||
size_t oc_index, size_t unit_idx, | |||
size_t nr_units_in_tile, const DType& src_dtype, | |||
const DType& dst_dtype) { | |||
Op op(src_dtype, dst_dtype); | |||
//! AT * m * A | |||
constexpr size_t alpha = 6 + 3 - 1; | |||
size_t oc = oc_start + oc_index; | |||
size_t OCB = (oc_end - oc_start) / 4; | |||
size_t ocb = oc_index / 4; | |||
#define cb(m, n) \ | |||
auto v##m##n = Vector<float, 4>::load( \ | |||
output_transform_buf + \ | |||
(m * alpha + n) * OCB * nr_units_in_tile * 4 + \ | |||
ocb * nr_units_in_tile * 4 + unit_idx * 4); | |||
UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); | |||
#undef cb | |||
/** | |||
* A | |||
* | |||
* 1 0 0 0 0 0 | |||
* 1 1 1 1 1 1 | |||
* 1 -1 1 -1 1 -1 | |||
* 1 2 4 8 16 32 | |||
* 1 -2 4 -8 16 -32 | |||
* 1 0.5 0.25 0.125 0.0625 0.03125 | |||
* 1 -0.5 0.25 -0.125 0.0625 -0.03125 | |||
 * 0 0 0 0 0 1 | |||
*/ | |||
Vector<float, 4> v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, v5subv6; | |||
#define cb(m) \ | |||
v1addv2 = v1##m + v2##m; \ | |||
v1subv2 = v1##m - v2##m; \ | |||
v3addv4 = v3##m + v4##m; \ | |||
v3subv4 = v3##m - v4##m; \ | |||
v5addv6 = v5##m + v6##m; \ | |||
v5subv6 = v5##m - v6##m; \ | |||
auto t0##m = v0##m + v1addv2 + v3addv4 + v5addv6; \ | |||
auto t1##m = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f; \ | |||
auto t2##m = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f; \ | |||
auto t3##m = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f; \ | |||
auto t4##m = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; \ | |||
auto t5##m = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + v7##m; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
#define cb(m) \ | |||
v1addv2 = t##m##1 + t##m##2; \ | |||
v1subv2 = t##m##1 - t##m##2; \ | |||
v3addv4 = t##m##3 + t##m##4; \ | |||
v3subv4 = t##m##3 - t##m##4; \ | |||
v5addv6 = t##m##5 + t##m##6; \ | |||
v5subv6 = t##m##5 - t##m##6; \ | |||
v##m##0 = t##m##0 + v1addv2 + v3addv4 + v5addv6; \ | |||
v##m##1 = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f; \ | |||
v##m##2 = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f; \ | |||
v##m##3 = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f; \ | |||
v##m##4 = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; \ | |||
v##m##5 = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + t##m##7; | |||
UNROLL_CALL_NOWRAPPER(6, cb); | |||
#undef cb | |||
Vector<float, 4> vbias; | |||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
vbias = Vector<float, 4>::load(bias + oc); | |||
#define cb(m, n) v##m##n += vbias; | |||
UNROLL_CALL_RAW_D2(6, 6, cb); | |||
#undef cb | |||
} | |||
if (bmode != BiasMode::BIAS) { | |||
#define cb(m, n) v##m##n = op(CONCAT(v##m, n).value); | |||
UNROLL_CALL_RAW_D2(6, 6, cb); | |||
#undef cb | |||
} | |||
#define cb(m, n) CONCAT(v##m, n).save(transform_mid_buf + (m * 6 + n) * 4); | |||
UNROLL_CALL_RAW_D2(6, 6, cb); | |||
#undef cb | |||
for (size_t oco = 0; oco < 4 && oc + oco < oc_end; ++oco) { | |||
for (size_t oho = 0; oho < 6 && oh_start + oho < OH; ++oho) { | |||
for (size_t owo = 0; owo < 6 && ow_start + owo < OW; ++owo) { | |||
size_t oh = oh_start + oho; | |||
size_t ow = ow_start + owo; | |||
float res = transform_mid_buf[oho * 6 * 4 + owo * 4 + oco]; | |||
if (bmode == BiasMode::BIAS) { | |||
res += bias[(oc + oco) * OH * OW + oh * OW + ow]; | |||
res = op(res); | |||
} | |||
output[(oc + oco) * OH * OW + oh * OW + ow] = res; | |||
} | |||
} | |||
} | |||
} | |||
}; | |||
} // namespace | |||
namespace megdnn { | |||
namespace arm_common { | |||
namespace winograd { | |||
MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_6x3_4x4_f) | |||
void winograd_6x3_4x4_f::filter(const float* filter, | |||
float* filter_transform_buf, | |||
float* transform_mid_buf, size_t OC, size_t IC, | |||
size_t oc_start, size_t oc_end) { | |||
FilterTransform6X3<param::MatrixMul::Format::MK4>::transform( | |||
filter, filter_transform_buf, transform_mid_buf, OC, IC, oc_start, | |||
oc_end); | |||
} | |||
void winograd_6x3_4x4_f::input(const float* input, float* input_transform_buf, | |||
float* transform_mid_buf, size_t IH, size_t IW, | |||
size_t IC, size_t PH, size_t PW, | |||
size_t unit_start_idx, size_t nr_units_in_tile) { | |||
megdnn_assert(IC % 4 == 0); | |||
    constexpr int alpha = 6 + 3 - 1; | |||
// OW = IW + 2 * PW - KERNEL_SIZE + 1 | |||
auto units_w = div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); | |||
float* patch = transform_mid_buf; | |||
float* patchT = transform_mid_buf + 4 * alpha * alpha; | |||
for (size_t ic = 0; ic < IC; ic += 4) { | |||
rep(unit_idx, nr_units_in_tile) { | |||
size_t index = unit_start_idx + unit_idx; | |||
size_t nh = index / units_w; | |||
size_t nw = index % units_w; | |||
int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; | |||
int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; | |||
if (ih_start >= 0 && ih_start + alpha <= static_cast<int>(IH) && | |||
iw_start >= 0 && iw_start + alpha <= static_cast<int>(IW)) { | |||
InputTransform6X3::prepare<true>(input, patch, patchT, ih_start, | |||
iw_start, IH, IW, ic, IC); | |||
InputTransform6X3::transform(patchT, input_transform_buf, | |||
unit_idx, nr_units_in_tile, ic, | |||
IC); | |||
} else { | |||
InputTransform6X3::prepare<false>(input, patch, patchT, | |||
ih_start, iw_start, IH, IW, | |||
ic, IC); | |||
InputTransform6X3::transform(patchT, input_transform_buf, | |||
unit_idx, nr_units_in_tile, ic, | |||
IC); | |||
} | |||
} | |||
} | |||
} | |||
void winograd_6x3_4x4_f::output(const float* output_transform_buf, | |||
const float* bias, float* output, | |||
float* transform_mid_buf, BiasMode bmode, | |||
NonlineMode nonline_mode, size_t OH, size_t OW, | |||
size_t oc_start, size_t oc_end, size_t unit_start_idx, | |||
size_t nr_units_in_tile) { | |||
#define cb(_bmode, _nonline_op, ...) \ | |||
OutputTransform6X3<_bmode MEGDNN_COMMA _nonline_op>::transform(__VA_ARGS__); | |||
auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE); | |||
for (size_t oc = oc_start; oc < oc_end; oc += 4) { | |||
size_t oc_index = oc - oc_start; | |||
rep(unit_idx, nr_units_in_tile) { | |||
size_t index = unit_start_idx + unit_idx; | |||
auto nh = index / units_w; | |||
auto nw = index % units_w; | |||
size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | |||
size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | |||
DISPATCH_CONV_WINOGRAD_BIAS( | |||
megdnn_arm_common_winograd_fp32_F63_4x4, cb, float, float, bmode, | |||
nonline_mode, output_transform_buf, bias, output, transform_mid_buf, | |||
oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx, | |||
nr_units_in_tile, src_dtype, dst_dtype); | |||
} | |||
} | |||
#undef cb | |||
} | |||
} // namespace winograd | |||
} // namespace arm_common | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,82 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/img2col_helper.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include <cstddef> | |||
#include "src/common/utils.h" | |||
namespace { | |||
template <bool is_xcorr, typename dtype> | |||
void img2col_stride(const dtype* __restrict src, | |||
dtype* __restrict dst, const int OC, const int OH, | |||
const int OW, const int IC, const int IH, const int IW, | |||
const int FH, const int FW, const int SH, const int SW) { | |||
(void)OC; | |||
size_t i = 0; | |||
rep(ic, IC) { | |||
rep(fh, FH) { | |||
rep(fw, FW) { | |||
rep(oh, OH) { | |||
rep(ow, OW) { | |||
int fh2, fw2; | |||
if (is_xcorr) { | |||
fh2 = fh; | |||
fw2 = fw; | |||
} else { | |||
fh2 = FH - fh - 1; | |||
fw2 = FW - fw - 1; | |||
} | |||
dst[i++] = src[ic * IH * IW + (oh * SH + fh2) * IW + | |||
(ow * SW + fw2)]; | |||
} | |||
} | |||
} | |||
} | |||
} | |||
} | |||
template <bool is_xcorr, typename dtype> | |||
void img2col(const dtype* src, dtype* dst, size_t /* OC */, size_t OH, | |||
size_t OW, size_t IC, size_t IH, size_t IW, size_t FH, size_t FW) { | |||
size_t offset = (4 - OW % 4) % 4; | |||
size_t i = 0; | |||
rep(ic, IC) { | |||
rep(fh, FH) { | |||
rep(fw, FW) { | |||
rep(oh, OH) { | |||
size_t ow = 0; | |||
for (; ow < OW; ow += 4) { | |||
size_t fh2, fw2; | |||
if (is_xcorr) { | |||
fh2 = fh; | |||
fw2 = fw; | |||
} else { | |||
fh2 = FH - fh - 1; | |||
fw2 = FW - fw - 1; | |||
} | |||
dst[i++] = src[ic * IH * IW + (oh + fh2) * IW + | |||
(ow + fw2) + 0]; | |||
dst[i++] = src[ic * IH * IW + (oh + fh2) * IW + | |||
(ow + fw2) + 1]; | |||
dst[i++] = src[ic * IH * IW + (oh + fh2) * IW + | |||
(ow + fw2) + 2]; | |||
dst[i++] = src[ic * IH * IW + (oh + fh2) * IW + | |||
(ow + fw2) + 3]; | |||
} | |||
i -= offset; | |||
} | |||
} | |||
} | |||
} | |||
} | |||
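// A hedged usage sketch, assuming stride 1 and no padding: img2col lays dst
// out as a row-major (IC*FH*FW) x (OH*OW) matrix, so the convolution
// reduces to out[oc][p] = sum_k filter[oc][k] * col[k][p]. The naive GEMM
// loop is for illustration only; `col` must hold IC * FH * FW * OH * OW
// elements, and the function name is not MegEngine API.
template <bool is_xcorr, typename dtype>
void conv_via_img2col_ref(const dtype* src, const dtype* filter, dtype* col,
                          dtype* out, size_t OC, size_t IC, size_t IH,
                          size_t IW, size_t FH, size_t FW) {
    size_t OH = IH - FH + 1, OW = IW - FW + 1;
    img2col<is_xcorr>(src, col, OC, OH, OW, IC, IH, IW, FH, FW);
    size_t K = IC * FH * FW, N = OH * OW;
    for (size_t oc = 0; oc < OC; ++oc)
        for (size_t p = 0; p < N; ++p) {
            dtype acc = 0;
            for (size_t k = 0; k < K; ++k)
                acc += filter[oc * K + k] * col[k * N + p];
            out[oc * N + p] = acc;
        }
}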
} // anonymous namespace | |||
// vim: syntax=cpp.doxygen |