Browse Source

ci(copybara): fix copybara of arm

GitOrigin-RevId: 2aa113ef47
tags/v0.4.0
Megvii Engine Team Xinran Xu 5 years ago
parent
commit
3ef308e71c
100 changed files with 36936 additions and 0 deletions
  1. +75
    -0
      CMakeLists.txt
  2. +3
    -0
      dnn/include/megdnn/handle.h
  3. +16
    -0
      dnn/src/CMakeLists.txt
  4. +139
    -0
      dnn/src/aarch64/conv_bias/fp16/algos.cpp
  5. +42
    -0
      dnn/src/aarch64/conv_bias/fp16/algos.h
  6. +1037
    -0
      dnn/src/aarch64/conv_bias/fp16/stride2_kern.h
  7. +137
    -0
      dnn/src/aarch64/conv_bias/fp32/algos.cpp
  8. +47
    -0
      dnn/src/aarch64/conv_bias/fp32/algos.h
  9. +1024
    -0
      dnn/src/aarch64/conv_bias/fp32/stride2_kern.h
  10. +187
    -0
      dnn/src/aarch64/conv_bias/int8/algos.cpp
  11. +52
    -0
      dnn/src/aarch64/conv_bias/int8/algos.h
  12. +309
    -0
      dnn/src/aarch64/conv_bias/int8/strategy.cpp
  13. +69
    -0
      dnn/src/aarch64/conv_bias/int8/strategy.h
  14. +69
    -0
      dnn/src/aarch64/conv_bias/opr_impl.cpp
  15. +41
    -0
      dnn/src/aarch64/conv_bias/opr_impl.h
  16. +181
    -0
      dnn/src/aarch64/conv_bias/quint8/algos.cpp
  17. +52
    -0
      dnn/src/aarch64/conv_bias/quint8/algos.h
  18. +319
    -0
      dnn/src/aarch64/conv_bias/quint8/strategy.cpp
  19. +48
    -0
      dnn/src/aarch64/conv_bias/quint8/strategy.h
  20. +44
    -0
      dnn/src/aarch64/handle.cpp
  21. +33
    -0
      dnn/src/aarch64/handle.h
  22. +1038
    -0
      dnn/src/aarch64/matrix_mul/algos.cpp
  23. +250
    -0
      dnn/src/aarch64/matrix_mul/algos.h
  24. +1888
    -0
      dnn/src/aarch64/matrix_mul/asm/common.h
  25. +2589
    -0
      dnn/src/aarch64/matrix_mul/fp16/strategy.cpp
  26. +29
    -0
      dnn/src/aarch64/matrix_mul/fp16/strategy.h
  27. +439
    -0
      dnn/src/aarch64/matrix_mul/fp16/strategy_mk8_8x8.cpp
  28. +39
    -0
      dnn/src/aarch64/matrix_mul/fp32/common.h
  29. +718
    -0
      dnn/src/aarch64/matrix_mul/fp32/kernel_general_4x16.h
  30. +1242
    -0
      dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12.h
  31. +166
    -0
      dnn/src/aarch64/matrix_mul/fp32/strategy.cpp
  32. +30
    -0
      dnn/src/aarch64/matrix_mul/fp32/strategy.h
  33. +570
    -0
      dnn/src/aarch64/matrix_mul/fp32/strategy_mk4_4x16.cpp
  34. +1175
    -0
      dnn/src/aarch64/matrix_mul/int16/kernel_12x8x1.h
  35. +132
    -0
      dnn/src/aarch64/matrix_mul/int16/strategy.cpp
  36. +29
    -0
      dnn/src/aarch64/matrix_mul/int16/strategy.h
  37. +658
    -0
      dnn/src/aarch64/matrix_mul/int16/strategy_mk8_8x8.cpp
  38. +856
    -0
      dnn/src/aarch64/matrix_mul/int8/kernel_4x4x16.h
  39. +1375
    -0
      dnn/src/aarch64/matrix_mul/int8/kernel_8x8x8.h
  40. +892
    -0
      dnn/src/aarch64/matrix_mul/int8/kernel_mk4_4x4x16.h
  41. +263
    -0
      dnn/src/aarch64/matrix_mul/int8/strategy.cpp
  42. +34
    -0
      dnn/src/aarch64/matrix_mul/int8/strategy.h
  43. +116
    -0
      dnn/src/aarch64/matrix_mul/int8_dot/gemv.cpp
  44. +34
    -0
      dnn/src/aarch64/matrix_mul/int8_dot/gemv.h
  45. +1552
    -0
      dnn/src/aarch64/matrix_mul/int8_dot/kernel_8x12x4.h
  46. +113
    -0
      dnn/src/aarch64/matrix_mul/int8_dot/strategy.cpp
  47. +26
    -0
      dnn/src/aarch64/matrix_mul/int8_dot/strategy.h
  48. +439
    -0
      dnn/src/aarch64/matrix_mul/int8x8x16/kernel_4x4x16.h
  49. +1300
    -0
      dnn/src/aarch64/matrix_mul/int8x8x16/kernel_8x8x8.h
  50. +200
    -0
      dnn/src/aarch64/matrix_mul/int8x8x16/strategy.cpp
  51. +27
    -0
      dnn/src/aarch64/matrix_mul/int8x8x16/strategy.h
  52. +93
    -0
      dnn/src/aarch64/matrix_mul/opr_impl.cpp
  53. +63
    -0
      dnn/src/aarch64/matrix_mul/opr_impl.h
  54. +1398
    -0
      dnn/src/aarch64/matrix_mul/quint8/kernel_8x8x8.h
  55. +113
    -0
      dnn/src/aarch64/matrix_mul/quint8/strategy.cpp
  56. +28
    -0
      dnn/src/aarch64/matrix_mul/quint8/strategy.h
  57. +177
    -0
      dnn/src/aarch64/matrix_mul/quint8_dot/gemv.cpp
  58. +35
    -0
      dnn/src/aarch64/matrix_mul/quint8_dot/gemv.h
  59. +1092
    -0
      dnn/src/aarch64/matrix_mul/quint8_dot/kernel_8x8x4.h
  60. +114
    -0
      dnn/src/aarch64/matrix_mul/quint8_dot/strategy.cpp
  61. +26
    -0
      dnn/src/aarch64/matrix_mul/quint8_dot/strategy.h
  62. +183
    -0
      dnn/src/aarch64/relayout/opr_impl.cpp
  63. +31
    -0
      dnn/src/aarch64/relayout/opr_impl.h
  64. +393
    -0
      dnn/src/aarch64/rotate/opr_impl.cpp
  65. +35
    -0
      dnn/src/aarch64/rotate/opr_impl.h
  66. +43
    -0
      dnn/src/aarch64/warp_perspective/opr_impl.cpp
  67. +30
    -0
      dnn/src/aarch64/warp_perspective/opr_impl.h
  68. +257
    -0
      dnn/src/aarch64/warp_perspective/warp_perspective_cv.cpp
  69. +32
    -0
      dnn/src/aarch64/warp_perspective/warp_perspective_cv.h
  70. +380
    -0
      dnn/src/arm_common/conv_bias/direct/multi_thread_common.cpp
  71. +65
    -0
      dnn/src/arm_common/conv_bias/direct/multi_thread_common.h
  72. +561
    -0
      dnn/src/arm_common/conv_bias/f16/algos.cpp
  73. +185
    -0
      dnn/src/arm_common/conv_bias/f16/algos.h
  74. +799
    -0
      dnn/src/arm_common/conv_bias/f16/direct.cpp
  75. +32
    -0
      dnn/src/arm_common/conv_bias/f16/direct.h
  76. +522
    -0
      dnn/src/arm_common/conv_bias/f16/do_conv_stride1.cpp
  77. +32
    -0
      dnn/src/arm_common/conv_bias/f16/do_conv_stride1.h
  78. +341
    -0
      dnn/src/arm_common/conv_bias/f16/helper.h
  79. +33
    -0
      dnn/src/arm_common/conv_bias/f16/strategy.h
  80. +373
    -0
      dnn/src/arm_common/conv_bias/f16/strategy_2x3.cpp
  81. +407
    -0
      dnn/src/arm_common/conv_bias/f16/strategy_2x3_8x8.cpp
  82. +488
    -0
      dnn/src/arm_common/conv_bias/f16/strategy_4x5.cpp
  83. +608
    -0
      dnn/src/arm_common/conv_bias/f16/strategy_6x3.cpp
  84. +754
    -0
      dnn/src/arm_common/conv_bias/fp32/algos.cpp
  85. +223
    -0
      dnn/src/arm_common/conv_bias/fp32/algos.h
  86. +911
    -0
      dnn/src/arm_common/conv_bias/fp32/direct.cpp
  87. +29
    -0
      dnn/src/arm_common/conv_bias/fp32/direct.h
  88. +735
    -0
      dnn/src/arm_common/conv_bias/fp32/do_conv_stride1.cpp
  89. +34
    -0
      dnn/src/arm_common/conv_bias/fp32/do_conv_stride1.h
  90. +513
    -0
      dnn/src/arm_common/conv_bias/fp32/do_conv_stride2.cpp
  91. +32
    -0
      dnn/src/arm_common/conv_bias/fp32/do_conv_stride2.h
  92. +164
    -0
      dnn/src/arm_common/conv_bias/fp32/filter_transform.h
  93. +204
    -0
      dnn/src/arm_common/conv_bias/fp32/helper.h
  94. +39
    -0
      dnn/src/arm_common/conv_bias/fp32/strategy.h
  95. +346
    -0
      dnn/src/arm_common/conv_bias/fp32/strategy_2x3_4x4.cpp
  96. +483
    -0
      dnn/src/arm_common/conv_bias/fp32/strategy_4x5.cpp
  97. +500
    -0
      dnn/src/arm_common/conv_bias/fp32/strategy_5x4.cpp
  98. +424
    -0
      dnn/src/arm_common/conv_bias/fp32/strategy_6x3.cpp
  99. +351
    -0
      dnn/src/arm_common/conv_bias/fp32/strategy_6x3_4x4.cpp
  100. +82
    -0
      dnn/src/arm_common/conv_bias/img2col_helper.h

+ 75
- 0
CMakeLists.txt View File

@@ -19,11 +19,14 @@ CHECK_CXX_COMPILER_FLAG(-Wclass-memaccess CXX_SUPPORT_WCLASS_MEMACCESS)
set(MGE_ARCH AUTO CACHE STRING "Architecture on which MegEngine to be built.")
set_property(CACHE MGE_ARCH PROPERTY STRINGS AUTO
x86_64 i386
armv7 aarch64
naive fallback
)

option(MGE_WITH_JIT "Build MegEngine with JIT." ON)
option(MGE_WITH_HALIDE "Build MegEngine with Halide JIT" ON)
option(MGE_ARMV8_2_FEATURE_FP16 "Enable armv8.2-a+fp16 support" OFF)
option(MGE_ARMV8_2_FEATURE_DOTPROD "enable armv8.2-a+dotprod support" OFF)
option(MGE_DISABLE_FLOAT16 "Disable MegEngine float16 support." OFF)
option(MGE_WITH_CUDA "Enable MegEngine CUDA support." ON)
option(MGE_CUDA_USE_STATIC "Enable MegEngine CUDA static linking." ON)
@@ -31,12 +34,52 @@ option(MGE_WITH_TRT "Build MegEngine with TensorRT." ON)
option(MGE_USE_SYSTEM_LIB "Build MegEngine with system libraries." OFF)
option(MGB_WITH_FLATBUFFERS "Build MegBrain with FlatBuffers serialization support." ON)

if(CMAKE_TOOLCHAIN_FILE)
message("We are cross compiling.")
message("config FLATBUFFERS_FLATC_EXECUTABLE to: ${PROJECT_SOURCE_DIR}/build_dir/host_flatc/install/bin/flatc")
set(FLATBUFFERS_FLATC_EXECUTABLE "${PROJECT_SOURCE_DIR}/build_dir/host_flatc/install/bin/flatc")
if(ANDROID_TOOLCHAIN_ROOT)
if(NOT "${ANDROID_ARCH_NAME}" STREQUAL "")
set(ANDROID_ARCH ${ANDROID_ARCH_NAME})
endif()
if(${ANDROID_ARCH} STREQUAL "arm")
set(MGE_ARCH "armv7")
elseif(${ANDROID_ARCH} STREQUAL "arm64")
set(MGE_ARCH "aarch64")
else()
message(FATAL_ERROR "DO NOT SUPPORT ANDROID ARCH NOW")
endif()
elseif(IOS_TOOLCHAIN_ROOT)
if(${IOS_ARCH} STREQUAL "armv7")
set(MGE_ARCH "armv7")
elseif(${IOS_ARCH} STREQUAL "arm64")
set(MGE_ARCH "aarch64")
elseif(${IOS_ARCH} STREQUAL "armv7k")
set(MGE_ARCH "armv7")
elseif(${IOS_ARCH} STREQUAL "arm64e")
set(MGE_ARCH "aarch64")
elseif(${IOS_ARCH} STREQUAL "armv7s")
set(MGE_ARCH "armv7")
else()
message(FATAL_ERROR "Unsupported IOS_ARCH.")
endif()
elseif(NOT "${ARM_CROSS_BUILD_ARCH}" STREQUAL "")
set(MGE_ARCH ${ARM_CROSS_BUILD_ARCH})
else()
message(FATAL_ERROR "Unknown cross-compiling settings.")
endif()
message("CONFIG MGE_ARCH TO ${MGE_ARCH}")
endif()

if(${MGE_ARCH} STREQUAL "AUTO")
if(${CMAKE_SYSTEM_PROCESSOR} STREQUAL "x86_64")
set(MGE_ARCH "x86_64")
elseif(${CMAKE_SYSTEM_PROCESSOR} STREQUAL "i386" OR ${CMAKE_SYSTEM_PROCESSOR} STREQUAL "i686")
set(MGE_ARCH "i386")
elseif(${CMAKE_SYSTEM_PROCESSOR} STREQUAL "aarch64" OR ${CMAKE_SYSTEM_PROCESSOR} STREQUAL "arm64")
set(MGE_ARCH "aarch64")
elseif(${CMAKE_SYSTEM_PROCESSOR} MATCHES "^arm")
set(MGE_ARCH "armv7")
else()
message(FATAL "Unknown machine architecture for MegEngine.")
endif()
@@ -399,6 +442,38 @@ if(MGE_ARCH STREQUAL "x86_64" OR MGE_ARCH STREQUAL "i386")
endif()
endif()

if(MGE_ARCH STREQUAL "armv7")
# -funsafe-math-optimizations to enable neon auto-vectorization (since neon is not fully IEEE 754 compatible, GCC does not turn on neon auto-vectorization by default.
if(ANDROID)
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mfloat-abi=softfp -mfpu=neon")
endif()
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -funsafe-math-optimizations")
set (MARCH "-march=armv7-a")
set (MEGDNN_ARMV7 1)
endif()

if(MGE_ARCH STREQUAL "aarch64")
set(MEGDNN_AARCH64 1)
set(MEGDNN_64_BIT 1)
set(MARCH "-march=armv8-a")
if(MGE_ARMV8_2_FEATURE_FP16)
message("Enable fp16 feature support in armv8.2")
if(NOT ${MGE_DISABLE_FLOAT16})
set(MEGDNN_ENABLE_FP16_NEON 1)
endif()
set(MARCH "-march=armv8.2-a+fp16")
endif()

if(MGE_ARMV8_2_FEATURE_DOTPROD)
message("Enable dotprod feature support in armv8.2")
if(MGE_ARMV8_2_FEATURE_FP16)
set(MARCH "-march=armv8.2-a+fp16+dotprod")
else()
set(MARCH "-march=armv8.2-a+dotprod")
endif()
endif()

endif()

set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${MARCH}")



+ 3
- 0
dnn/include/megdnn/handle.h View File

@@ -29,6 +29,9 @@ class Handle {
NAIVE = 0,
FALLBACK = 1,
X86 = 2,
ARM_COMMON = 3,
ARMV7 = 4,
AARCH64 = 5,
CUDA = 6,
};



+ 16
- 0
dnn/src/CMakeLists.txt View File

@@ -17,6 +17,22 @@ if(NOT ${MGE_ARCH} STREQUAL "naive")
set_source_files_properties(${SOURCES_} PROPERTIES LANGUAGE C)
list(APPEND SOURCES ${SOURCES_})
endif()
elseif(${MGE_ARCH} STREQUAL "armv7")
file(GLOB_RECURSE SOURCES_ armv7/*.cpp)
list(APPEND SOURCES ${SOURCES_})
file(GLOB_RECURSE SOURCES_ arm_common/*.cpp)
list(APPEND SOURCES ${SOURCES_})
file(GLOB_RECURSE SOURCES_ armv7/*.S)
set_source_files_properties(${SOURCES_} PROPERTIES LANGUAGE C)
list(APPEND SOURCES ${SOURCES_})
elseif(${MGE_ARCH} STREQUAL "aarch64")
file(GLOB_RECURSE SOURCES_ aarch64/*.cpp)
list(APPEND SOURCES ${SOURCES_})
file(GLOB_RECURSE SOURCES_ arm_common/*.cpp)
list(APPEND SOURCES ${SOURCES_})
file(GLOB_RECURSE SOURCES_ aarch64/*.S)
set_source_files_properties(${SOURCES_} PROPERTIES LANGUAGE C)
list(APPEND SOURCES ${SOURCES_})
endif()
endif()



+ 139
- 0
dnn/src/aarch64/conv_bias/fp16/algos.cpp View File

@@ -0,0 +1,139 @@
/**
* \file dnn/src/aarch64/conv_bias/fp16/algos.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/aarch64/conv_bias/fp16/algos.h"
#include "src/aarch64/conv_bias/fp16/stride2_kern.h"
#include "src/arm_common/conv_bias/direct/multi_thread_common.h"
#include "src/arm_common/conv_bias/postprocess_helper.h"

using namespace megdnn;
using namespace aarch64;
#include "midout.h"
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/* ===================== stride-2 algo ===================== */
MIDOUT_DECL(megdnn_aarch64_conv_bias_stride2_conv2357_fp16)

bool ConvBiasImpl::AlgoF16DirectStride2::usable(
FallbackConvBiasImpl*, const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const {
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 0) {
auto&& fm = param.filter_meta;
auto FH = fm.spatial[0];
bool aviliable =
param.filter_meta.format == param::Convolution::Format::NCHW &&
param.src_type.enumv() == DTypeEnum::Float16 &&
param.filter_type.enumv() == DTypeEnum::Float16 &&
param.dst_type.enumv() == DTypeEnum::Float16 &&
!fm.should_flip && fm.spatial_ndim == 2 &&
fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] &&
(FH == 2 || FH == 3 || FH == 5 || FH == 7);
if (algo_selection_strategy == AlgoSelectionStrategy::HEURISTIC) {
bool large_group = param.filter_meta.group >= param.nr_threads;
aviliable &= (large_group == m_large_group);
}
return aviliable;
}
MIDOUT_END();
return false;
}

size_t ConvBiasImpl::AlgoF16DirectStride2::get_workspace(
FallbackConvBiasImpl*, const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 1) {
auto wbundle = arm_common::MultithreadDirectConvCommon<
dt_float16, __fp16>::get_bundle_stride(param, m_large_group);
return wbundle.total_size_in_bytes();
}
MIDOUT_END();
return false;
}

SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoF16DirectStride2::dispatch_kerns(
FallbackConvBiasImpl*, const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 2) {
return get_kimpls(param);
}
MIDOUT_END();
return {};
}

SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoF16DirectStride2::get_kimpls(
const NCBKernSizeParam& param) const {
auto fm = param.filter_meta;
auto FH = fm.spatial[0];
size_t N = param.n;
size_t IC = param.filter_meta.icpg;
size_t OC = param.filter_meta.ocpg;
size_t group = fm.group;
using Func = std::function<void(const __fp16*, const __fp16*, __fp16*,
size_t, size_t, size_t, size_t, size_t)>;
Func conv = nullptr;
if (FH == 2) {
conv = fp16::conv_stride2::do_conv_2x2_stride2;
} else if (FH == 3) {
conv = fp16::conv_stride2::do_conv_3x3_stride2;
} else if (FH == 5) {
conv = fp16::conv_stride2::do_conv_5x5_stride2;
} else if (FH == 7) {
conv = fp16::conv_stride2::do_conv_7x7_stride2;
}

WorkspaceBundle wbundle = arm_common::MultithreadDirectConvCommon<
dt_float16, __fp16>::get_bundle_stride(param, m_large_group);
SmallVector<NCBKern> ret_kerns;

//! Dense conv and small group
if (m_large_group) {
//! Channel wise conv and big groups
auto exec_one_group = [wbundle, conv](const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
auto fm = kern_param.filter_meta;
size_t IC = fm.icpg;
size_t OC = fm.ocpg;
WorkspaceBundle bundle = wbundle;
for (size_t ic = 0; ic < IC; ic++) {
arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>::
copy_padding_kern_stride(bundle, kern_param, ncb_index,
{ncb_index.thread_id, 0, ic});
}
for (size_t oc = 0; oc < OC; oc++) {
arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>::
do_conv_kern_stride(bundle, kern_param, ncb_index, conv,
{ncb_index.thread_id, 0, oc});
}
};
ret_kerns.push_back({exec_one_group, {group, N, 1_z}});
} else {
WorkspaceBundle bundle = wbundle;
auto copy_padding = [bundle](const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>::
copy_padding_kern_stride(bundle, kern_param, ncb_index,
ncb_index.ndrange_id);
};
ret_kerns.push_back({copy_padding, {group, N, IC}});
auto do_conv = [bundle, conv](const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>::
do_conv_kern_stride(bundle, kern_param, ncb_index, conv,
ncb_index.ndrange_id);
};
ret_kerns.push_back({do_conv, {group, N, OC}});
}
return ret_kerns;
}

#endif

// vim: syntax=cpp.doxygen

+ 42
- 0
dnn/src/aarch64/conv_bias/fp16/algos.h View File

@@ -0,0 +1,42 @@
/**
* \file dnn/src/aarch64/conv_bias/fp16/algos.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#pragma once

#include "src/aarch64/conv_bias/opr_impl.h"
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
namespace megdnn {
namespace aarch64 {
/* ===================== stride-2 algo ===================== */
class ConvBiasImpl::AlgoF16DirectStride2 final : public AlgoBase {
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
bool m_large_group;
public:
AlgoF16DirectStride2(bool large_group) : m_large_group(large_group) {}
bool is_reproducible() const override { return true; }
const char* name() const override {
return m_large_group ? "ARMV8F16STRD2_LARGE_GROUP"
: "ARMV8F16STRD2_SMALL_GROUP";
}

bool usable(FallbackConvBiasImpl*, const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;

size_t get_workspace(FallbackConvBiasImpl*,
const NCBKernSizeParam& param) const override;

SmallVector<NCBKern> dispatch_kerns(FallbackConvBiasImpl*,
const NCBKernSizeParam&) const override;
};
} // namespace aarch64
} // namespace megdnn
#endif
// vim: syntax=cpp.doxygen

+ 1037
- 0
dnn/src/aarch64/conv_bias/fp16/stride2_kern.h
File diff suppressed because it is too large
View File


+ 137
- 0
dnn/src/aarch64/conv_bias/fp32/algos.cpp View File

@@ -0,0 +1,137 @@
/**
* \file dnn/src/aarch64/conv_bias/fp32/algos.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/aarch64/conv_bias/fp32/algos.h"
#include "src/aarch64/conv_bias/fp32/stride2_kern.h"
#include "src/arm_common/conv_bias/direct/multi_thread_common.h"
#include "src/arm_common/conv_bias/postprocess_helper.h"
#include "src/fallback/conv_bias/common.h"

#include "midout.h"

using namespace megdnn;
using namespace aarch64;

MIDOUT_DECL(megdnn_aarch64_conv_bias_stride2_conv2357_fp32)
bool ConvBiasImpl::AlgoF32DirectStride2::usable(
FallbackConvBiasImpl*, const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const {
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 0) {
auto&& fm = param.filter_meta;
auto FH = fm.spatial[0];
bool aviliable =
param.filter_meta.format == param::ConvBias::Format::NCHW &&
param.src_type.enumv() == DTypeEnum::Float32 &&
param.filter_type.enumv() == DTypeEnum::Float32 &&
param.dst_type.enumv() == DTypeEnum::Float32 &&
!fm.should_flip && fm.spatial_ndim == 2 &&
fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] &&
(FH == 2 || FH == 3 || FH == 5 || FH == 7);
if (algo_selection_strategy == AlgoSelectionStrategy::HEURISTIC) {
bool large_group = param.filter_meta.group >= param.nr_threads;
aviliable &= (large_group == m_large_group);
}
return aviliable;
}
MIDOUT_END();
return false;
}

size_t ConvBiasImpl::AlgoF32DirectStride2::get_workspace(
FallbackConvBiasImpl*, const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 1) {
auto wbundle = arm_common::MultithreadDirectConvCommon<
float, float>::get_bundle_stride(param, m_large_group);
return wbundle.total_size_in_bytes();
}
MIDOUT_END();
return 0;
}
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoF32DirectStride2::dispatch_kerns(
FallbackConvBiasImpl*, const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 2) {
return get_kimpls(param);
}
MIDOUT_END();
return {};
}

SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoF32DirectStride2::get_kimpls(
const NCBKernSizeParam& param) const {
auto fm = param.filter_meta;
auto FH = fm.spatial[0];
size_t N = param.n;
size_t IC = param.filter_meta.icpg;
size_t OC = param.filter_meta.ocpg;
size_t group = fm.group;
using Func = std::function<void(const float*, const float*, float*, size_t,
size_t, size_t, size_t, size_t)>;
Func conv = nullptr;
if (FH == 2) {
conv = fp32::conv_stride2::do_conv_2x2_stride2;
} else if (FH == 3) {
conv = fp32::conv_stride2::do_conv_3x3_stride2;
} else if (FH == 5) {
conv = fp32::conv_stride2::do_conv_5x5_stride2;
} else if (FH == 7) {
conv = fp32::conv_stride2::do_conv_7x7_stride2;
}

WorkspaceBundle wbundle = arm_common::MultithreadDirectConvCommon<
float, float>::get_bundle_stride(param, m_large_group);
SmallVector<NCBKern> ret_kerns;

//! Dense conv and small group
if (m_large_group) {
//! Channel wise conv and big groups
auto exec_one_group = [wbundle, conv](const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
auto fm = kern_param.filter_meta;
size_t IC = fm.icpg;
size_t OC = fm.ocpg;
WorkspaceBundle bundle = wbundle;
for (size_t ic = 0; ic < IC; ic++) {
arm_common::MultithreadDirectConvCommon<float, float>::
copy_padding_kern_stride(bundle, kern_param, ncb_index,
{ncb_index.thread_id, 0, ic});
}
for (size_t oc = 0; oc < OC; oc++) {
arm_common::MultithreadDirectConvCommon<
float, float>::do_conv_kern_stride(bundle, kern_param,
ncb_index, conv,
{ncb_index.thread_id,
0, oc});
}
};
ret_kerns.push_back({exec_one_group, {group, N, 1_z}});
} else {
WorkspaceBundle bundle = wbundle;
auto copy_padding = [bundle](const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
arm_common::MultithreadDirectConvCommon<float, float>::
copy_padding_kern_stride(bundle, kern_param, ncb_index,
ncb_index.ndrange_id);
};
ret_kerns.push_back({copy_padding, {group, N, IC}});
auto do_conv = [bundle, conv](const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
arm_common::MultithreadDirectConvCommon<
float, float>::do_conv_kern_stride(bundle, kern_param,
ncb_index, conv,
ncb_index.ndrange_id);
};
ret_kerns.push_back({do_conv, {group, N, OC}});
}
return ret_kerns;
}

+ 47
- 0
dnn/src/aarch64/conv_bias/fp32/algos.h View File

@@ -0,0 +1,47 @@
/**
* \file dnn/src/aarch64/conv_bias/fp32/algos.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#pragma once

#include "src/aarch64/conv_bias/opr_impl.h"
#include "src/fallback/conv_bias/opr_impl.h"

namespace megdnn {
namespace aarch64 {

using FallbackConvBiasImpl = fallback::ConvBiasImpl;
/* ===================== stride-2 algo ===================== */

class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase {
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
bool m_large_group;
public:
AlgoF32DirectStride2(bool large_group) : m_large_group(large_group) {}
bool is_reproducible() const override { return true; }
const char* name() const override {
return m_large_group ? "ARMV8F32STRD2_LARGE_GROUP"
: "ARMV8F32STRD2_SMALL_GROUP";
}

bool usable(FallbackConvBiasImpl*, const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;

size_t get_workspace(FallbackConvBiasImpl*,
const NCBKernSizeParam& param) const override;

SmallVector<NCBKern> dispatch_kerns(FallbackConvBiasImpl*,
const NCBKernSizeParam&) const override;
};

} // namespace aarch64
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 1024
- 0
dnn/src/aarch64/conv_bias/fp32/stride2_kern.h
File diff suppressed because it is too large
View File


+ 187
- 0
dnn/src/aarch64/conv_bias/int8/algos.cpp View File

@@ -0,0 +1,187 @@
/**
* \file dnn/src/aarch64/conv_bias/int8/algos.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/aarch64/conv_bias/int8/algos.h"
#include "src/aarch64/conv_bias/int8/strategy.h"
#include "src/arm_common/convolution/img2col_helper.h"
#include "src/arm_common/elemwise_op.h"
#include "src/common/opr_delegate.h"
#include "src/fallback/conv_bias/common.h"
#include "src/fallback/matrix_mul/gemm_impl.h"

#include "midout.h"

MIDOUT_DECL(megdnn_aarch64_conv_bias_int8_gemm)

using namespace megdnn;
using namespace aarch64;
using megdnn::arm_common::HSwishOp;
using megdnn::arm_common::ReluOp;
using megdnn::arm_common::TypeCvtOp;

/* ===================== matrix mul algo ===================== */

bool ConvBiasImpl::AlgoS8MatrixMul::usable(
FallbackConvBiasImpl* opr, const NCBKernSizeParam& param,
AlgoSelectionStrategy /*algo_selection_strategy*/) const {
MEGDNN_MARK_USED_VAR(opr);
auto&& fm = param.filter_meta;
return param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.dst_type.enumv() == DTypeEnum::QuantizedS8 &&
fm.format == param::ConvBias::Format::NCHW && fm.spatial_ndim == 2 &&
fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
//! As postprocess, the bias is not contigous read, make the
//! performance bad, so we do not process it in fused kernel
param.bias_mode != BiasMode::BIAS &&
//! This algo is only support single thread
param.nr_threads == 1_z;
}

WorkspaceBundle ConvBiasImpl::AlgoS8MatrixMul::get_bundle(
const NCBKernSizeParam& param) {
UNPACK_CONV_NCB_KERN_SIZES(param);
MEGDNN_MARK_USED_VAR(N);
auto IW2 = IH + 2 * PH;
auto IH2 = IW + 2 * PW;
bool can_matrix_mul_direct =
(FH == 1 && FW == 1 && SH == 1 && SW == 1 && PH == 0 && PW == 0);
// temp space to store padding-free src (with 16 extra int8)
// temp space to store unrolled matrix (with 16 extra int8)
// workspace for matrix mul opr
size_t part0, part1, part2;
if (can_matrix_mul_direct) {
part0 = part1 = 0;
} else {
part0 = (IC * IH2 * IW2 + 16) * sizeof(int8_t);
part1 = (IC * FH * FW * OH * OW + 16) * sizeof(int8_t);
}
{
size_t M = OC;
size_t K = IC * FH * FW;
size_t N = OH * OW;

#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \
_bias_midout_enum, _nonline, \
_nonline_midout_enum) \
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_int8_gemm, 0, _gemm_midout_enum, \
_bias_midout_enum, _nonline_midout_enum) { \
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \
M, N, K, param.filter_type, param.src_type, param.dst_type); \
part2 = megdnn::matmul::GemmInterleaved< \
matmul::gemm_##_gemm##_##_bias##_##_nonline>( \
M, N, K, false, false, strategy) \
.get_workspace_size(); \
} \
MIDOUT_END()

#if !(__ARM_FEATURE_DOTPROD)
DISPATCH_GEMM_BIAS(s8_4x4, 0)
#else
DISPATCH_GEMM_BIAS(s8_8x12, 1)
#endif
#undef DISPATCH_GEMM_STRATEGY
}
return {nullptr, {part0, part1, part2}};
}

void ConvBiasImpl::AlgoS8MatrixMul::kimpl(const NCBKernParam& param,
const NCBKernIndex& ncb_index) {
auto is_xcorr = !param.filter_meta.should_flip;
UNPACK_CONV_NCB_KERN_SIZES(param);
auto bundle = get_bundle(param);
bundle.set(param.workspace_ptr);
auto IH2 = IH + 2 * PH;
auto IW2 = IW + 2 * PW;
size_t group_id = ncb_index.ndrange_id[0];
// workspace = tmp..src2
for (size_t n = 0; n < N; ++n) {
dt_int8* src = const_cast<dt_int8*>(param.src<dt_int8>(n, group_id));
dt_int8* filter = const_cast<dt_int8*>(param.filter<dt_int8>(group_id));
dt_int8* dst = static_cast<dt_int8*>(param.dst<dt_int8>(n, group_id));
dt_int32* bias = const_cast<dt_int32*>(param.bias<dt_int32>(n, group_id));

dt_int8 *B, *src2;
if (FH == 1 && FW == 1 && SH == 1 && SW == 1 && PH == 0 && PW == 0) {
// special case: 1x1
B = const_cast<dt_int8*>(src);
} else {
src2 = static_cast<dt_int8*>(bundle.get(0));
// copy src to src2;
dt_int8* src2_ptr = src2;
const dt_int8* src_ptr = src;
rep(ic, IC) {
if (PH != 0) {
std::memset(src2_ptr, 0, sizeof(dt_int8) * PH * IW2);
src2_ptr += PH * IW2;
}
rep(ih, IH) {
if (PW != 0)
rep(pw, PW) { *(src2_ptr++) = 0.0f; }
std::memcpy(src2_ptr, src_ptr, sizeof(dt_int8) * IW);
src2_ptr += IW;
src_ptr += IW;
if (PW != 0)
rep(pw, PW) { *(src2_ptr++) = 0.0f; }
}
if (PH != 0) {
std::memset(src2_ptr, 0, sizeof(dt_int8) * PH * IW2);
src2_ptr += PH * IW2;
}
}

B = static_cast<dt_int8*>(bundle.get(1));
if (SH == 1 && SW == 1) {
if (is_xcorr)
img2col<true>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW);
else
img2col<false>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW);
} else {
if (is_xcorr)
img2col_stride<true>(src2, B, OC, OH, OW, IC, IH2, IW2, FH,
FW, SH, SW);
else
img2col_stride<false>(src2, B, OC, OH, OW, IC, IH2, IW2, FH,
FW, SH, SW);
}
}
{
Workspace workspace(static_cast<dt_byte*>(bundle.get(2)),
bundle.get_size(2));
size_t M = OC;
size_t K = IC * FH * FW;
size_t N = OH * OW;

#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \
_bias_midout_enum, _nonline, \
_nonline_midout_enum) \
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_int8_gemm, 1, _gemm_midout_enum, \
_bias_midout_enum, _nonline_midout_enum) { \
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \
M, N, K, param.filter_type, param.src_type, param.dst_type); \
megdnn::matmul::GemmInterleaved< \
matmul::gemm_##_gemm##_##_bias##_##_nonline> \
gemm_interleaved(M, N, K, false, false, strategy); \
gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, \
bias); \
} \
MIDOUT_END()

#if !(__ARM_FEATURE_DOTPROD)
DISPATCH_GEMM_BIAS(s8_4x4, 0)
#else
DISPATCH_GEMM_BIAS(s8_8x12, 1)
#endif
#undef DISPATCH_GEMM_STRATEGY
}
}
}

// vim: syntax=cpp.doxygen

+ 52
- 0
dnn/src/aarch64/conv_bias/int8/algos.h View File

@@ -0,0 +1,52 @@
/**
* \file dnn/src/aarch64/conv_bias/int8/algos.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#pragma once

#include "src/aarch64/conv_bias/opr_impl.h"
#include "src/fallback/conv_bias/opr_impl.h"

namespace megdnn {
namespace aarch64 {

using FallbackConvBiasImpl = fallback::ConvBiasImpl;

class ConvBiasImpl::AlgoS8MatrixMul final : public AlgoBase {
static WorkspaceBundle get_bundle(const NCBKernSizeParam& param);
static void kimpl(const NCBKernParam& param, const NCBKernIndex& ncb_index);

public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "S8MATMUL"; }

bool usable(FallbackConvBiasImpl* opr, const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
size_t get_workspace(FallbackConvBiasImpl*,
const NCBKernSizeParam& param) const override {
return get_bundle(param).total_size_in_bytes();
}
SmallVector<NCBKern> dispatch_kerns(
FallbackConvBiasImpl*, const NCBKernSizeParam& param) const override {
size_t group = param.filter_meta.group;
return {{kimpl, {group, 1_z, 1_z}}};
}
//! select matmul to the highest preference
bool is_preferred(FallbackConvBiasImpl* opr,
const NCBKernSizeParam& param) const override {
return static_cast<arm_common::ConvBiasImpl*>(opr)
->is_matmul_quantized_prefer(param);
}
};

} // namespace aarch64
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 309
- 0
dnn/src/aarch64/conv_bias/int8/strategy.cpp View File

@@ -0,0 +1,309 @@
/**
* \file dnn/src/aarch64/conv_bias/int8/strategy.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/aarch64/conv_bias/int8/strategy.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h"

#include "src/aarch64/matrix_mul/int8/kernel_4x4x16.h"
#include "src/aarch64/matrix_mul/int8_dot/kernel_8x12x4.h"
#include "src/arm_common/conv_bias/matmul_postprocess.h"

using namespace megdnn;
using namespace aarch64;
using namespace aarch64::matmul;

namespace impl {
template <BiasMode bmode, typename Op, int block_m, int block_n>
struct KernCaller;

#if __ARM_FEATURE_DOTPROD
template <BiasMode bmode, typename Op>
struct KernCaller<bmode, Op, 8, 12> {
static void run(const dt_int8* packA, const dt_int8* packB, size_t M,
size_t N, size_t K, dt_int8* C, size_t LDC, bool is_first_k,
Op op, const dt_int32* bias, dt_int32* workspace) {
megdnn_assert(is_first_k);

constexpr size_t A_INTERLEAVE = 8;
constexpr size_t B_INTERLEAVE = 12;
//! K is packed to times of 4
K = round_up<size_t>(K, 4);
const int K8 = (K << 3);
const int K12 = K * 12;
const int K4 = K * 4;

size_t m = 0;
for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) {
int8_t* output = C + (m * LDC);

size_t n = 0;
const dt_int8* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_8x12x4::kern_8x12(packA, cur_packB, K, workspace, 12,
is_first_k);

arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 8, 12, 8,
12>::postprocess(bias, workspace,
output, LDC, op);
output += B_INTERLEAVE;
cur_packB += K12;
}

for (; n < N; n += 4) {
matmul_8x12x4::kern_8x4(packA, cur_packB, K, workspace, 4,
is_first_k, std::min<size_t>(N - n, 4));

#define cb(m, n) \
arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 8, 4, 8, n>::postprocess( \
bias, workspace, output, LDC, op);
DISPATCH_N(cb, 8, std::min<size_t>(N - n, 4));
#undef cb
output += 4;
cur_packB += K4;
}
packA += K8;

if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
bias += A_INTERLEAVE;
}
}

for (; m < M; m += 4) {
int8_t* output = C + (m * LDC);
const dt_int8* cur_packB = packB;
size_t n = 0;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_8x12x4::kern_4x12(packA, cur_packB, K, workspace, 12,
is_first_k,
std::min<size_t>(M - m, 4));
#define cb(m, n) \
arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 12, m, n>::postprocess( \
bias, workspace, output, LDC, op);
DISPATCH_M_N(cb, std::min<size_t>(M - m, 4), 12);
#undef cb

output += B_INTERLEAVE;
cur_packB += K12;
}

for (; n < N; n += 4) {
matmul_8x12x4::kern_4x4(packA, cur_packB, K, workspace, 4,
is_first_k, std::min<size_t>(M - m, 4),
std::min<size_t>(N - n, 4));
#define cb(m, n) \
arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, m, n>::postprocess( \
bias, workspace, output, LDC, op);
DISPATCH_M(cb, std::min<size_t>(M - m, 4),
std::min<size_t>(N - n, 4));
#undef cb

output += 4;
cur_packB += K4;
}
packA += K4;
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
bias += 4;
}
}
}
};

#else

template <BiasMode bmode, typename Op>
struct KernCaller<bmode, Op, 4, 4> {
static void run(const dt_int8* packA, const dt_int8* packB, size_t M,
size_t N, size_t K, dt_int8* C, size_t LDC, bool is_first_k,
Op op, const dt_int32* bias, dt_int32* workspace) {
megdnn_assert(is_first_k);

constexpr size_t A_INTERLEAVE = 4;
constexpr size_t B_INTERLEAVE = 4;
//! K is packed to times of 4
K = round_up<size_t>(K, 16);
const int K4 = K * 4;

size_t m = 0;
for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) {
int8_t* output = C + (m * LDC);

size_t n = 0;
const dt_int8* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_4x4x16::kern_4x4(packA, cur_packB, K, workspace, 4,
is_first_k);
arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, 4,
4>::postprocess(bias, workspace,
output, LDC, op);

output += B_INTERLEAVE;
cur_packB += K4;
}

for (; n < N; n += B_INTERLEAVE) {
matmul_4x4x16::kern_4x4_remain(packA, cur_packB, K, workspace,
4, is_first_k, 4,
std::min<size_t>(N - n, 4));
#define cb(m, n) \
arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, 4, n>::postprocess( \
bias, workspace, output, LDC, op);
DISPATCH_N(cb, 4, std::min<size_t>(N - n, 4));
#undef cb
output += B_INTERLEAVE;
cur_packB += K4;
}

packA += K4;
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
bias += A_INTERLEAVE;
}
}

for (; m < M; m += A_INTERLEAVE) {
int8_t* output = C + (m * LDC);

size_t n = 0;
const dt_int8* cur_packB = packB;
for (; n < N; n += B_INTERLEAVE) {
matmul_4x4x16::kern_4x4_remain(
packA, cur_packB, K, workspace, 4, is_first_k,
std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4));

#define cb(m, n) \
arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, m, n>::postprocess( \
bias, workspace, output, LDC, op);
DISPATCH_M(cb, std::min<size_t>(M - m, 4),
std::min<size_t>(N - n, 4));
#undef cb
output += B_INTERLEAVE;
cur_packB += K4;
}
packA += K4;
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
bias += A_INTERLEAVE;
}
}
}
};

#endif

} // namespace impl
#if !(__ARM_FEATURE_DOTPROD)
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_4x4_nobias_identity)

void gemm_s8_4x4_nobias_identity::pack_A(dt_int8* outptr, const dt_int8* inptr,
int ldin, int y0, int ymax, int k0,
int kmax, bool transpose) const {
if (transpose) {
matmul_4x4x16::gemm_s8_4x4_pack_B_n(outptr, inptr, ldin, y0, ymax, k0,
kmax);
} else {
matmul_4x4x16::gemm_s8_4x4_pack_A_n(outptr, inptr, ldin, y0, ymax, k0,
kmax);
}
}

void gemm_s8_4x4_nobias_identity::pack_B(dt_int8* out, const dt_int8* in,
int ldin, int x0, int xmax, int k0,
int kmax, bool transpose) const {
if (transpose) {
matmul_4x4x16::gemm_s8_4x4_pack_A_n(out, in, ldin, x0, xmax, k0, kmax);
} else {
matmul_4x4x16::gemm_s8_4x4_pack_B_n(out, in, ldin, x0, xmax, k0, kmax);
}
}

size_t gemm_s8_4x4_nobias_identity::get_workspace_size() const {
return 4 * 4 * sizeof(dt_int32);
}
#else
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_8x12_nobias_identity)

void gemm_s8_8x12_nobias_identity::pack_A(dt_int8* outptr, const dt_int8* inptr,
int ldin, int y0, int ymax, int k0,
int kmax, bool transpose) const {
MEGDNN_MARK_USED_VAR(matmul_8x12x4::gemm_s8_8x12_pack_A_t);
MEGDNN_MARK_USED_VAR(matmul_8x12x4::gemm_s8_8x12_pack_B_t);
if (transpose) {
matmul_8x12x4::gemm_s8_8x12_pack_B_n(outptr, inptr, ldin, y0, ymax, k0,
kmax);
} else {
matmul_8x12x4::gemm_s8_8x12_pack_A_n(outptr, inptr, ldin, y0, ymax, k0,
kmax);
}
}

void gemm_s8_8x12_nobias_identity::pack_B(dt_int8* out, const dt_int8* in,
int ldin, int x0, int xmax, int k0,
int kmax, bool transpose) const {
if (transpose) {
matmul_8x12x4::gemm_s8_8x12_pack_A_n(out, in, ldin, x0, xmax, k0, kmax);
} else {
matmul_8x12x4::gemm_s8_8x12_pack_B_n(out, in, ldin, x0, xmax, k0, kmax);
}
}

size_t gemm_s8_8x12_nobias_identity::get_workspace_size() const {
return 8 * 12 * sizeof(dt_int32);
}

#endif

#define KERN(_block_m, _block_n, _bias, _BIAS, _nonline, _OP) \
void gemm_s8_##_block_m##x##_block_n##_##_bias##_##_nonline::kern( \
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, \
size_t K, dt_int8* C, size_t LDC, bool is_first_k, \
const dt_int32* bias, dt_int32* workspace) const { \
float scale_A = A_dtype.param<dtype::QuantizedS8>().scale; \
float scale_B = B_dtype.param<dtype::QuantizedS8>().scale; \
float scale_C = C_dtype.param<dtype::QuantizedS8>().scale; \
DEFINE_OP(_OP); \
impl::KernCaller<_BIAS, decltype(op), _block_m, _block_n>::run( \
packA, packB, M, N, K, C, LDC, is_first_k, op, bias, \
workspace); \
}

#define DEFINE_OP(_Op) \
arm_common::_Op<dt_qint32, dt_qint8> op(scale_A* scale_B, scale_C);

#if !(__ARM_FEATURE_DOTPROD)
KERN(4, 4, nobias, BiasMode::NO_BIAS, identity, TypeCvtOp)
KERN(4, 4, nobias, BiasMode::NO_BIAS, relu, ReluOp)
KERN(4, 4, nobias, BiasMode::NO_BIAS, hswish, HSwishOp)
#else
KERN(8, 12, nobias, BiasMode::NO_BIAS, identity, TypeCvtOp)
KERN(8, 12, nobias, BiasMode::NO_BIAS, relu, ReluOp)
KERN(8, 12, nobias, BiasMode::NO_BIAS, hswish, HSwishOp)
#endif
#undef DEFINE_OP

#define DEFINE_OP(_Op) \
arm_common::_Op<dt_qint32, dt_qint8> op(scale_A* scale_B, \
scale_A* scale_B, scale_C);
#if !(__ARM_FEATURE_DOTPROD)
KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp)
KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp)
KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish,
FuseAddHSwishOp)
#else
KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp)
KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp)
KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish,
FuseAddHSwishOp)
#endif
#undef DEFINE_OP

#undef KERN

// vim: syntax=cpp.doxygen

+ 69
- 0
dnn/src/aarch64/conv_bias/int8/strategy.h View File

@@ -0,0 +1,69 @@
/**
* \file dnn/src/aarch64/conv_bias/int8/strategy.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "src/fallback/matrix_mul/gemm_common.h"

namespace megdnn {
namespace aarch64 {
namespace matmul {

#if !(__ARM_FEATURE_DOTPROD)
/**
* \brief base strategy of gemm.
*
* \name gemm_<type>_<block>_biasmode_nolinemode
*/
MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_int8, dt_int8, dt_int32, 4, 4, 16,
false, true,
gemm_s8_4x4_nobias_identity);

MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_nobias_relu,
gemm_s8_4x4_nobias_identity);

MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_nobias_hswish,
gemm_s8_4x4_nobias_identity);

MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_bias_channel_identity,
gemm_s8_4x4_nobias_identity);

MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_bias_channel_relu,
gemm_s8_4x4_nobias_identity);

MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_bias_channel_hswish,
gemm_s8_4x4_nobias_identity);

#else
MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_int8, dt_int8, dt_int32, 8, 12, 4,
false, true,
gemm_s8_8x12_nobias_identity);

MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_nobias_relu,
gemm_s8_8x12_nobias_identity);

MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_nobias_hswish,
gemm_s8_8x12_nobias_identity);

MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_bias_channel_identity,
gemm_s8_8x12_nobias_identity);

MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_bias_channel_relu,
gemm_s8_8x12_nobias_identity);

MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_bias_channel_hswish,
gemm_s8_8x12_nobias_identity);

#endif

} // namespace matmul
} // namespace aarch64
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 69
- 0
dnn/src/aarch64/conv_bias/opr_impl.cpp View File

@@ -0,0 +1,69 @@
/**
* \file dnn/src/aarch64/conv_bias/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/aarch64/conv_bias/opr_impl.h"
#include "src/aarch64/conv_bias/int8/algos.h"
#include "src/aarch64/conv_bias/quint8/algos.h"

#include "src/naive/handle.h"
#include "src/common/utils.h"
#include "src/common/metahelper.h"

#include "src/fallback/convolution/opr_impl.h"
#include "src/aarch64/conv_bias/fp32/algos.h"
#include "src/aarch64/conv_bias/fp16/algos.h"

using namespace megdnn;
using namespace aarch64;

class ConvBiasImpl::AlgoPack : NonCopyableObj {
AlgoF32DirectStride2 f32_direct_stride2_large_group{true};
AlgoF32DirectStride2 f32_direct_stride2_small_group{false};
AlgoS8MatrixMul s8_matrix_mul;
AlgoQU8MatrixMul qu8_matrix_mul;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
AlgoF16DirectStride2 f16_direct_stride2_large_group{true};
AlgoF16DirectStride2 f16_direct_stride2_small_group{false};
#endif

public:
AlgoPack() {
matmul_algos.emplace_back(&qu8_matrix_mul);
matmul_algos.emplace_back(&s8_matrix_mul);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
direct_algos.emplace_back(&f16_direct_stride2_large_group);
direct_algos.emplace_back(&f16_direct_stride2_small_group);
#endif
direct_algos.emplace_back(&f32_direct_stride2_large_group);
direct_algos.emplace_back(&f32_direct_stride2_small_group);
}
SmallVector<AlgoBase*> direct_algos;
SmallVector<AlgoBase*> matmul_algos;
};

SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() {
static AlgoPack sl_algo_pack;
auto&& algos = arm_common::ConvBiasImpl::algo_pack();
algos.insert(algos.begin(), sl_algo_pack.direct_algos.begin(),
sl_algo_pack.direct_algos.end());
//! We put matmul algos at the end. Because matmul will get privilege when
//! prefer return true. See
//! fallback::ConvolutionImpl::ncb_1g_get_all_algorithms for more details.
algos.insert(algos.end(), sl_algo_pack.matmul_algos.begin(),
sl_algo_pack.matmul_algos.end());
return std::move(algos);
}

const char* ConvBiasImpl::get_algorithm_set_name() const {
return "AARCH64";
}

// vim: syntax=cpp.doxygen

+ 41
- 0
dnn/src/aarch64/conv_bias/opr_impl.h View File

@@ -0,0 +1,41 @@
/**
* \file dnn/src/aarch64/conv_bias/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "src/common/utils.h"
#include "src/arm_common/conv_bias/opr_impl.h"

namespace megdnn {
namespace aarch64 {

class ConvBiasImpl : public arm_common::ConvBiasImpl {
public:
using arm_common::ConvBiasImpl::ConvBiasImpl;

SmallVector<AlgoBase*> algo_pack() override;

protected:

const char* get_algorithm_set_name() const override;

private:
class AlgoF32DirectStride2;
class AlgoS8MatrixMul;
class AlgoQU8MatrixMul;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
class AlgoF16DirectStride2;
#endif
class AlgoPack;
};

} // namespace aarch64
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 181
- 0
dnn/src/aarch64/conv_bias/quint8/algos.cpp View File

@@ -0,0 +1,181 @@
/**
* \file dnn/src/aarch64/conv_bias/quint8/algos.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/aarch64/conv_bias/quint8/algos.h"
#include "src/aarch64/conv_bias/quint8/strategy.h"
#include "src/aarch64/matrix_mul/quint8_dot/gemv.h"
#include "src/aarch64/matrix_mul/quint8_dot/strategy.h"
#include "src/arm_common/convolution/img2col_helper.h"
#include "src/arm_common/elemwise_op.h"
#include "src/common/opr_delegate.h"
#include "src/fallback/conv_bias/common.h"
#include "src/fallback/matrix_mul/gemm_impl.h"

#include "midout.h"

MIDOUT_DECL(megdnn_aarch64_conv_bias_quint8_gemm)

using namespace megdnn;
using namespace aarch64;
using megdnn::arm_common::HSwishOp;
using megdnn::arm_common::ReluOp;
using megdnn::arm_common::TypeCvtOp;

/* ===================== matrix mul algo ===================== */

bool ConvBiasImpl::AlgoQU8MatrixMul::usable(
FallbackConvBiasImpl* opr, const NCBKernSizeParam& param,
AlgoSelectionStrategy /*algo_selection_strategy*/) const {
MEGDNN_MARK_USED_VAR(opr);
auto&& fm = param.filter_meta;
return param.src_type.enumv() == DTypeEnum::Quantized8Asymm &&
param.dst_type.enumv() == DTypeEnum::Quantized8Asymm &&
fm.format == param::ConvBias::Format::NCHW && fm.spatial_ndim == 2 &&
fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
//! As postprocess, the bias is not contigous read, make the
//! performance bad, so we do not process it in fused kernel
param.bias_mode != BiasMode::BIAS &&
//! This algo is only support single thread
param.nr_threads == 1_z;
}

WorkspaceBundle ConvBiasImpl::AlgoQU8MatrixMul::get_bundle(
const NCBKernSizeParam& param) {
UNPACK_CONV_NCB_KERN_SIZES(param);
MEGDNN_MARK_USED_VAR(N);
auto IW2 = IH + 2 * PH;
auto IH2 = IW + 2 * PW;
bool can_matrix_mul_direct =
(FH == 1 && FW == 1 && SH == 1 && SW == 1 && PH == 0 && PW == 0);
// temp space to store padding-free src (with 16 extra int8)
// temp space to store unrolled matrix (with 16 extra int8)
// workspace for matrix mul opr
size_t part0, part1, part2;
if (can_matrix_mul_direct) {
part0 = part1 = 0;
} else {
part0 = (IC * IH2 * IW2 + 16) * sizeof(uint8_t);
part1 = (IC * FH * FW * OH * OW + 16) * sizeof(uint8_t);
}
{
size_t M = OC;
size_t K = IC * FH * FW;
size_t N = OH * OW;

#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \
_bias_midout_enum, _nonline, \
_nonline_midout_enum) \
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_quint8_gemm, 0, _gemm_midout_enum, \
_bias_midout_enum, _nonline_midout_enum) { \
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \
M, N, K, param.filter_type, param.src_type, param.dst_type); \
part2 = megdnn::matmul::GemmInterleaved< \
matmul::gemm_##_gemm##_##_bias##_##_nonline>( \
M, N, K, false, false, strategy) \
.get_workspace_size(); \
} \
MIDOUT_END()

DISPATCH_GEMM_BIAS(u8_8x8, 0)
#undef DISPATCH_GEMM_STRATEGY
}
return {nullptr, {part0, part1, part2}};
}

void ConvBiasImpl::AlgoQU8MatrixMul::kimpl(const NCBKernParam& param,
const NCBKernIndex& ncb_index) {
auto is_xcorr = !param.filter_meta.should_flip;
UNPACK_CONV_NCB_KERN_SIZES(param);
auto bundle = get_bundle(param);
bundle.set(param.workspace_ptr);
auto IH2 = IH + 2 * PH;
auto IW2 = IW + 2 * PW;
size_t group_id = ncb_index.ndrange_id[0];
uint8_t src_zp = param.src_type.param<dtype::Quantized8Asymm>().zero_point;
// workspace = tmp..src2
for (size_t n = 0; n < N; ++n) {
uint8_t* src = const_cast<uint8_t*>(param.src<uint8_t>(n, group_id));
uint8_t* filter = const_cast<uint8_t*>(param.filter<uint8_t>(group_id));
uint8_t* dst = static_cast<uint8_t*>(param.dst<uint8_t>(n, group_id));
int32_t* bias = const_cast<int32_t*>(param.bias<int32_t>(n, group_id));

uint8_t *B, *src2;
if (FH == 1 && FW == 1 && SH == 1 && SW == 1 && PH == 0 && PW == 0) {
// special case: 1x1
B = const_cast<uint8_t*>(src);
} else {
src2 = static_cast<uint8_t*>(bundle.get(0));
// copy src to src2;
uint8_t* src2_ptr = src2;
const uint8_t* src_ptr = src;
rep(ic, IC) {
if (PH != 0) {
std::memset(src2_ptr, src_zp, sizeof(uint8_t) * PH * IW2);
src2_ptr += PH * IW2;
}
rep(ih, IH) {
if (PW != 0)
rep(pw, PW) { *(src2_ptr++) = src_zp; }
std::memcpy(src2_ptr, src_ptr, sizeof(uint8_t) * IW);
src2_ptr += IW;
src_ptr += IW;
if (PW != 0)
rep(pw, PW) { *(src2_ptr++) = src_zp; }
}
if (PH != 0) {
std::memset(src2_ptr, src_zp, sizeof(uint8_t) * PH * IW2);
src2_ptr += PH * IW2;
}
}

B = static_cast<uint8_t*>(bundle.get(1));
if (SH == 1 && SW == 1) {
if (is_xcorr)
img2col<true>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW);
else
img2col<false>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW);
} else {
if (is_xcorr)
img2col_stride<true>(src2, B, OC, OH, OW, IC, IH2, IW2, FH,
FW, SH, SW);
else
img2col_stride<false>(src2, B, OC, OH, OW, IC, IH2, IW2, FH,
FW, SH, SW);
}
}
{
Workspace workspace(static_cast<dt_byte*>(bundle.get(2)),
bundle.get_size(2));
size_t M = OC;
size_t K = IC * FH * FW;
size_t N = OH * OW;

#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \
_bias_midout_enum, _nonline, \
_nonline_midout_enum) \
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_quint8_gemm, 1, _gemm_midout_enum, \
_bias_midout_enum, _nonline_midout_enum) { \
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \
M, N, K, param.filter_type, param.src_type, param.dst_type); \
megdnn::matmul::GemmInterleaved< \
matmul::gemm_##_gemm##_##_bias##_##_nonline> \
gemm_interleaved(M, N, K, false, false, strategy); \
gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, \
bias); \
} \
MIDOUT_END()

DISPATCH_GEMM_BIAS(u8_8x8, 0)
#undef DISPATCH_GEMM_STRATEGY
}
}
}
// vim: syntax=cpp.doxygen

+ 52
- 0
dnn/src/aarch64/conv_bias/quint8/algos.h View File

@@ -0,0 +1,52 @@
/**
* \file dnn/src/aarch64/conv_bias/quint8/algos.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#pragma once

#include "src/aarch64/conv_bias/opr_impl.h"
#include "src/fallback/conv_bias/opr_impl.h"

namespace megdnn {
namespace aarch64 {

using FallbackConvBiasImpl = fallback::ConvBiasImpl;

class ConvBiasImpl::AlgoQU8MatrixMul final : public AlgoBase {
static WorkspaceBundle get_bundle(const NCBKernSizeParam& param);
static void kimpl(const NCBKernParam& param, const NCBKernIndex&);

public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "QU8MATMUL"; }

bool usable(FallbackConvBiasImpl* opr, const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
size_t get_workspace(FallbackConvBiasImpl*,
const NCBKernSizeParam& param) const override {
return get_bundle(param).total_size_in_bytes();
}
SmallVector<NCBKern> dispatch_kerns(
FallbackConvBiasImpl*,
const NCBKernSizeParam& param) const override {
size_t group = param.filter_meta.group;
return {{kimpl, {group, 1_z, 1_z}}};
}
//! select matmul to the highest preference
bool is_preferred(FallbackConvBiasImpl* opr,
const NCBKernSizeParam& param) const override {
return static_cast<arm_common::ConvBiasImpl*>(opr)
->is_matmul_quantized_prefer(param);
}
};
} // namespace aarch64
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 319
- 0
dnn/src/aarch64/conv_bias/quint8/strategy.cpp View File

@@ -0,0 +1,319 @@
/**
* \file dnn/src/aarch64/conv_bias/quint8/strategy.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/aarch64/conv_bias/quint8/strategy.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h"

#include "src/aarch64/matrix_mul/quint8_dot/kernel_8x8x4.h"
#include "src/aarch64/matrix_mul/quint8/kernel_8x8x8.h"
#include "src/arm_common/conv_bias/matmul_postprocess.h"

using namespace megdnn;
using namespace aarch64;
using namespace aarch64::matmul;

namespace impl {
template <BiasMode bmode, typename Op, int block_m, int block_n>
struct KernCaller;

#if __ARM_FEATURE_DOTPROD
template <BiasMode bmode, typename Op>
struct KernCaller<bmode, Op, 8, 8> {
static void run(const dt_uint8* packA, const dt_uint8* packB, size_t M,
size_t N, size_t K, dt_uint8* C, size_t LDC,
bool is_first_k, Op op, const dt_int32* bias,
dt_int32* workspace, uint8_t zp_A, uint8_t zp_B) {
megdnn_assert(is_first_k);
constexpr size_t A_INTERLEAVE = 8;
constexpr size_t B_INTERLEAVE = 8;
const uint32_t zAB =
static_cast<uint32_t>(zp_A) * static_cast<uint32_t>(zp_B) * K;
//! K is packed to times of 4
K = round_up<size_t>(K, 4);
const int K8 = (K << 3);
const int K4 = K * 4;

size_t m = 0;
for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) {
uint8_t* output = C + (m * LDC);

size_t n = 0;
const dt_uint8* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_8x8x4::kern_8x8(packA, cur_packB, K, workspace, 8,
is_first_k, zp_A, zp_B, zAB);

arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 8, 8,
8>::postprocess(bias, workspace,
output, LDC, op);
output += B_INTERLEAVE;
cur_packB += K8;
}

for (; n < N; n += 4) {
matmul_8x8x4::kern_8x4(packA, cur_packB, K, workspace, 4,
is_first_k, std::min<size_t>(N - n, 4),
zp_A, zp_B, zAB);
#define cb(m, n) \
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 4, 8, n>::postprocess( \
bias, workspace, output, LDC, op);
DISPATCH_N(cb, 8, std::min<size_t>(N - n, 4));
#undef cb

output += 4;
cur_packB += K4;
}
packA += K8;
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
bias += A_INTERLEAVE;
}
}

for (; m < M; m += 4) {
uint8_t* output = C + (m * LDC);
const dt_uint8* cur_packB = packB;
size_t n = 0;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_8x8x4::kern_4x8(packA, cur_packB, K, workspace, 8,
is_first_k, std::min<size_t>(M - m, 4),
zp_A, zp_B, zAB);
#define cb(m, n) \
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 4, 8, m, n>::postprocess( \
bias, workspace, output, LDC, op);
DISPATCH_M_N(cb, std::min<size_t>(M - m, 4), 8);
#undef cb

output += B_INTERLEAVE;
cur_packB += K8;
}

for (; n < N; n += 4) {
matmul_8x8x4::kern_4x4(packA, cur_packB, K, workspace, 4,
is_first_k, std::min<size_t>(M - m, 4),
std::min<size_t>(N - n, 4), zp_A, zp_B,
zAB);
#define cb(m, n) \
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 4, 4, m, n>::postprocess( \
bias, workspace, output, LDC, op);
DISPATCH_M(cb, std::min<size_t>(M - m, 4),
std::min<size_t>(N - n, 4));
#undef cb

output += 4;
cur_packB += K4;
}
packA += K4;
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
bias += 4;
}
}
}
};

#else

template <BiasMode bmode, typename Op>
struct KernCaller<bmode, Op, 8, 8> {
static void run(const dt_uint8* packA, const dt_uint8* packB, size_t M,
size_t N, size_t K, dt_uint8* C, size_t LDC,
bool is_first_k, Op op, const dt_int32* bias,
dt_int32* workspace, uint8_t zp_A, uint8_t zp_B) {
megdnn_assert(is_first_k);

constexpr size_t A_INTERLEAVE = 8;
constexpr size_t B_INTERLEAVE = 8;
//! K is packed to times of 8
K = round_up<size_t>(K, 8);
const int K8 = K * 8;
const int K4 = K * 4;

size_t m = 0;
for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) {
uint8_t* output = C + (m * LDC);

size_t n = 0;
const dt_uint8* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_8x8x8::kern_8x8(packA, cur_packB, K, workspace, 8,
is_first_k, zp_A, zp_B);

arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 8, 8,
8>::postprocess(bias, workspace,
output, LDC, op);
output += B_INTERLEAVE;
cur_packB += K8;
}

for (; n < N; n += 4) {
matmul_8x8x8::kern_8x4(packA, cur_packB, K, workspace, 4,
is_first_k, std::min<size_t>(N - n, 4),
zp_A, zp_B);
#define cb(m, n) \
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 4, 8, n>::postprocess( \
bias, workspace, output, LDC, op);
DISPATCH_N(cb, 8, std::min<size_t>(N - n, 4));
#undef cb


output += 4;
cur_packB += K4;
}
packA += K8;
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
bias += A_INTERLEAVE;
}
}

for (; m < M; m += 4) {
uint8_t* output = C + (m * LDC);
const dt_uint8* cur_packB = packB;
size_t n = 0;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_8x8x8::kern_4x8(packA, cur_packB, K, workspace, 8,
is_first_k, std::min<size_t>(M - m, 4),
zp_A, zp_B);
#define cb(m, n) \
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 4, 8, m, n>::postprocess( \
bias, workspace, output, LDC, op);
DISPATCH_M_N(cb, std::min<size_t>(M - m, 4), 8);
#undef cb

output += B_INTERLEAVE;
cur_packB += K8;
}

for (; n < N; n += 4) {
matmul_8x8x8::kern_4x4(packA, cur_packB, K, workspace, 4,
is_first_k, std::min<size_t>(M - m, 4),
std::min<size_t>(N - n, 4), zp_A, zp_B);
#define cb(m, n) \
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 4, 4, m, n>::postprocess( \
bias, workspace, output, LDC, op);
DISPATCH_M(cb, std::min<size_t>(M - m, 4),
std::min<size_t>(N - n, 4));
#undef cb


output += 4;
cur_packB += K4;
}
packA += K4;
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
bias += 4;
}
}
}
};

#endif

} // namespace impl
#if __ARM_FEATURE_DOTPROD
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8_nobias_identity)

void gemm_u8_8x8_nobias_identity::pack_A(uint8_t* outptr, const uint8_t* inptr,
int ldin, int y0, int ymax, int k0,
int kmax, bool transpose) const {
if (transpose) {
matmul_8x8x4::gemm_u8_8x8_transpose_pack_helper(outptr, inptr, ldin, y0,
ymax, k0, kmax);
} else {
matmul_8x8x4::gemm_u8_8x8_interleave_pack_helper(outptr, inptr, ldin,
y0, ymax, k0, kmax);
}
}

void gemm_u8_8x8_nobias_identity::pack_B(uint8_t* out, const uint8_t* in,
int ldin, int x0, int xmax, int k0,
int kmax, bool transpose) const {
if (transpose) {
matmul_8x8x4::gemm_u8_8x8_interleave_pack_helper(out, in, ldin, x0,
xmax, k0, kmax);
} else {
matmul_8x8x4::gemm_u8_8x8_transpose_pack_helper(out, in, ldin, x0, xmax,
k0, kmax);
}
}

#else

MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8_nobias_identity)
void gemm_u8_8x8_nobias_identity::pack_A(dt_uint8* outptr,
const dt_uint8* inptr, int ldin,
int y0, int ymax, int k0, int kmax,
bool transpose) const {
uint8_t zA = A_dtype.param<dtype::Quantized8Asymm>().zero_point;
if (transpose) {
matmul_8x8x8::gemm_u8_8x8_transpose_pack_A_n(outptr, inptr, ldin, y0,
ymax, k0, kmax, zA);
} else {
matmul_8x8x8::gemm_u8_8x8_pack_A_n(outptr, inptr, ldin, y0, ymax, k0,
kmax, zA);
}
}

void gemm_u8_8x8_nobias_identity::pack_B(dt_uint8* out, const dt_uint8* in,
int ldin, int x0, int xmax, int k0,
int kmax, bool transpose) const {
uint8_t zB = B_dtype.param<dtype::Quantized8Asymm>().zero_point;
if (transpose) {
matmul_8x8x8::gemm_u8_8x8_transpose_pack_B_n(out, in, ldin, x0, xmax,
k0, kmax, zB);
} else {
matmul_8x8x8::gemm_u8_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax,
zB);
}
}

#endif
size_t gemm_u8_8x8_nobias_identity::get_workspace_size() const {
return 8 * 8 * sizeof(dt_int32);
}

#define KERN(_block_m, _block_n, _bias, _BIAS, _nonline, _OP) \
void gemm_u8_##_block_m##x##_block_n##_##_bias##_##_nonline::kern( \
const dt_uint8* packA, const dt_uint8* packB, size_t M, size_t N, \
size_t K, dt_uint8* C, size_t LDC, bool is_first_k, \
const dt_int32* bias, dt_int32* workspace) const { \
float scale_A = A_dtype.param<dtype::Quantized8Asymm>().scale; \
uint8_t zp_A = A_dtype.param<dtype::Quantized8Asymm>().zero_point; \
float scale_B = B_dtype.param<dtype::Quantized8Asymm>().scale; \
uint8_t zp_B = B_dtype.param<dtype::Quantized8Asymm>().zero_point; \
float scale_C = C_dtype.param<dtype::Quantized8Asymm>().scale; \
uint8_t zp_C = C_dtype.param<dtype::Quantized8Asymm>().zero_point; \
DEFINE_OP(_OP); \
impl::KernCaller<_BIAS, decltype(op), _block_m, _block_n>::run( \
packA, packB, M, N, K, C, LDC, is_first_k, op, bias, \
workspace, zp_A, zp_B); \
}

#define DEFINE_OP(_Op) \
arm_common::_Op<dt_qint32, dt_quint8> op(scale_A* scale_B, scale_C, zp_C);

KERN(8, 8, nobias, BiasMode::NO_BIAS, identity, TypeCvtOp)
KERN(8, 8, nobias, BiasMode::NO_BIAS, relu, ReluOp)
KERN(8, 8, nobias, BiasMode::NO_BIAS, hswish, HSwishOp)
#undef DEFINE_OP

#define DEFINE_OP(_Op) \
arm_common::_Op<dt_qint32, dt_quint8> op(scale_A* scale_B, \
scale_A* scale_B, scale_C, zp_C);
KERN(8, 8, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp)
KERN(8, 8, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp)
KERN(8, 8, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish,
FuseAddHSwishOp)
#undef DEFINE_OP

#undef KERN

// vim: syntax=cpp.doxygen

+ 48
- 0
dnn/src/aarch64/conv_bias/quint8/strategy.h View File

@@ -0,0 +1,48 @@
/**
* \file dnn/src/aarch64/conv_bias/quint8/strategy.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "src/fallback/matrix_mul/gemm_common.h"

namespace megdnn {
namespace aarch64 {
namespace matmul {

#if __ARM_FEATURE_DOTPROD
MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_uint8, dt_uint8, dt_int32, 8, 8, 4,
false, true,
gemm_u8_8x8_nobias_identity);
#else
MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_uint8, dt_uint8, dt_int32, 8, 8, 8,
false, true,
gemm_u8_8x8_nobias_identity);
#endif

MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nobias_relu,
gemm_u8_8x8_nobias_identity);

MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nobias_hswish,
gemm_u8_8x8_nobias_identity);

MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_bias_channel_identity,
gemm_u8_8x8_nobias_identity);

MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_bias_channel_relu,
gemm_u8_8x8_nobias_identity);

MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_bias_channel_hswish,
gemm_u8_8x8_nobias_identity);


} // namespace matmul
} // namespace aarch64
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 44
- 0
dnn/src/aarch64/handle.cpp View File

@@ -0,0 +1,44 @@
/**
* \file dnn/src/aarch64/handle.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/common/handle_impl.h"

#include "src/aarch64/handle.h"
#include "src/aarch64/matrix_mul/opr_impl.h"
#include "src/aarch64/rotate/opr_impl.h"
#include "src/aarch64/relayout/opr_impl.h"
#include "src/aarch64/conv_bias/opr_impl.h"
#include "src/aarch64/warp_perspective/opr_impl.h"

namespace megdnn {
namespace aarch64 {

template <typename Opr>
std::unique_ptr<Opr> HandleImpl::create_operator() {
return arm_common::HandleImpl::create_operator<Opr>();
}

MEGDNN_SPECIALIZE_CREATE_OPERATOR(MatrixMul)
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Rotate)
MEGDNN_SPECIALIZE_CREATE_OPERATOR(RelayoutForward)
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvBias)
MEGDNN_SPECIALIZE_CREATE_OPERATOR(WarpPerspective)

#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wpragmas"
#pragma GCC diagnostic ignored "-Winstantiation-after-specialization"
MEGDNN_FOREACH_OPR_CLASS(MEGDNN_INST_CREATE_OPERATOR)
#pragma GCC diagnostic pop

} // namespace aarch64
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 33
- 0
dnn/src/aarch64/handle.h View File

@@ -0,0 +1,33 @@
/**
* \file dnn/src/aarch64/handle.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "src/arm_common/handle.h"

namespace megdnn {
namespace aarch64 {

class HandleImpl: public arm_common::HandleImpl {
public:
HandleImpl(megcoreComputingHandle_t computing_handle,
HandleType type = HandleType::AARCH64):
arm_common::HandleImpl::HandleImpl(computing_handle, type)
{}

template <typename Opr>
std::unique_ptr<Opr> create_operator();
};

} // namespace aarch64
} // namespace megdnn

// vim: syntax=cpp.doxygen



+ 1038
- 0
dnn/src/aarch64/matrix_mul/algos.cpp
File diff suppressed because it is too large
View File


+ 250
- 0
dnn/src/aarch64/matrix_mul/algos.h View File

@@ -0,0 +1,250 @@
/**
* \file dnn/src/aarch64/matrix_mul/algos.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#pragma once

#include "src/aarch64/matrix_mul/opr_impl.h"
#include "src/arm_common/matrix_mul/algos.h"
#include "src/fallback/matrix_mul/gemm_common.h"

namespace megdnn {
namespace aarch64 {

class MatrixMulImpl::AlgoF32K8x12x1 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "AARCH64_F32K8X12X1"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};

class MatrixMulImpl::AlgoF32K4x16x1 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "AARCH64_F32K4X16X1"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};

class MatrixMulImpl::AlgoF32MK4_4x16 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "AARCH64_F32_MK4_4x16"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
PackMode packmode() const override { return PackMode::NO_PACK; }
};

class MatrixMulImpl::AlgoF32Gemv final
: public arm_common::MatrixMulImpl::AlgoF32Gemv {};

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
class MatrixMulImpl::AlgoF16K8x24x1 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "AARCH64_F16_K8X24X1"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};

class MatrixMulImpl::AlgoF16MK8_8x8 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "AARCH64_F16_MK8_8X8"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
PackMode packmode() const override { return PackMode::NO_PACK; }
};

#endif

#if __ARM_FEATURE_DOTPROD
class MatrixMulImpl::AlgoInt8x8x32K8x12x4DotProd final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override {
return "AARCH64_INT8X8X32_K8X12X4_DOTPROD";
}
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};

class MatrixMulImpl::AlgoInt8x8x32GemvDotProd final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override {
return "AARCH64_INT8X8X32_GEMV_DOTPROD";
}
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override { return 0; }
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; }
};
#else

class MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16 final : public AlgoBase {

public:
bool is_reproducible() const override { return true; }
const char* name() const override {
return "AARCH64_INT8X8X32_MK4_4X4X16";
}
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
PackMode packmode() const override { return PackMode::DEFAULT; }

MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};

class MatrixMulImpl::AlgoInt8x8x32K4x4x16 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "AARCH64_INT8X8X32_K4X4X16"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }

MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};

class MatrixMulImpl::AlgoInt8x8x32K8x8x8 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "AARCH64_INT8X8X32_K8X8X8"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};

class MatrixMulImpl::AlgoInt8x8x32Gemv final
: public arm_common::MatrixMulImpl::AlgoInt8x8x32Gemv {};

#endif

class MatrixMulImpl::AlgoInt8x8x16K8x8x8 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "AARCH64_INT8X8X16_K8X8X8"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }

MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};

class MatrixMulImpl::AlgoInt8x8x16K4x4x16 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "AARCH64_INT8X8X16_K4X4X16"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};

class MatrixMulImpl::AlgoInt16x16x32K12x8x1 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "AARCH64_INT16X16X32_K12X8X1"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};

class MatrixMulImpl::AlgoInt16x16x32MK8_8x8 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "AARCH64_INT16X16X32_MK8_8X8"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
PackMode packmode() const override { return PackMode::NO_PACK; }
};

#if __ARM_FEATURE_DOTPROD
class MatrixMulImpl::AlgoQuint8K8x8x4DotProd final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override {
return "AARCH64_QUINT8_K8X8X4_DOTPROD";
}
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};

class MatrixMulImpl::AlgoQuint8GemvDotProd final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "AARCH64_QUINT8_GEMV_DOTPROD"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override { return 0; }
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; }
};
#else

class MatrixMulImpl::AlgoQuint8K8x8x8 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "AARCH64_QUINT8_K8X8X8"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
#endif

} // namespace aarch64
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 1888
- 0
dnn/src/aarch64/matrix_mul/asm/common.h
File diff suppressed because it is too large
View File


+ 2589
- 0
dnn/src/aarch64/matrix_mul/fp16/strategy.cpp
File diff suppressed because it is too large
View File


+ 29
- 0
dnn/src/aarch64/matrix_mul/fp16/strategy.h View File

@@ -0,0 +1,29 @@
/**
* \file dnn/src/aarch64/matrix_mul/fp16/strategy.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "src/fallback/matrix_mul/gemm_common.h"

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
namespace megdnn {
namespace aarch64 {
namespace matmul {

MEGDNN_REG_GEMM_STRATEGY(dt_float16, dt_float16, dt_float16, 8, 24, 1, false,
true, hgemm_8x24);

MEGDNN_REG_GEMM_STRATEGY_NOPACK(dt_float16, dt_float16, dt_float16, 8, 8, 1,
false, true, gemm_nopack_f16_8x8);

} // namespace matmul
} // namespace aarch64
} // namespace megdnn
#endif
// vim: syntax=cpp.doxygen

+ 439
- 0
dnn/src/aarch64/matrix_mul/fp16/strategy_mk8_8x8.cpp View File

@@ -0,0 +1,439 @@
/**
* \file dnn/src/aarch64/matrix_mul/fp16/strategy_mk8_8x8.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/aarch64/matrix_mul/fp16/strategy.h"
#include "src/aarch64/matrix_mul/asm/common.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
using namespace megdnn;
using namespace aarch64;
using namespace aarch64::matmul;

namespace {

// Overview of register layout:
//
// A 8x1 cell of Rhs is stored in 16bit in v0-v3
// A 8x1 cell of Lhs is stored in 16bit in v16-v23
// A 8x1 block of accumulators is stored in 16bit in v24-v27.
//
// Rhs +-------+
// |v0[0-7]|
// |v1[0-7]|
// |v2[0-7]|
// |v3[0-7]|
// +-------+
// Lhs
// +--------+
// |v16[0-7]|
// |v17[0-7]|
// |v18[0-7]|
// |v19[0-7]| +--------+
// |v20[0-7]| |v24[0-7]|
// |v21[0-7]| |v25[0-7]|
// |v22[0-7]| |v26[0-7]|
// |v23[0-7]| |v27[0-7]|
// +--------+ +--------+
// Accumulator
void kern_8x4(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K,
dt_float16* output) {
//! LDB means number of elements in one block in B. we will read 24 numbers
//! first. so minus 24 * 2 bytes here.
LDB = (LDB - 24) * sizeof(dt_float16);

asm volatile(
".arch armv8.2-a+fp16\n"

"ld1 {v16.4s, v17.4s}, [%[a_ptr]], 32\n"

"subs %w[K], %w[K], #8\n"
"ld1 {v0.4s}, [%[b_ptr]], 16\n"

"ld1 {v1.4s}, [%[b_ptr]], 16\n"
"fmul v24.8h, v16.8h, v0.h[0]\n"

"ld1 {v2.4s}, [%[b_ptr]], 16\n"
"fmul v25.8h, v16.8h, v1.h[0]\n"

"ld1 {v3.4s}, [%[b_ptr]], %x[LDB]\n"
"fmul v26.8h, v16.8h, v2.h[0]\n"

"ld1 {v18.4s}, [%[a_ptr]], 16\n"
"fmul v27.8h, v16.8h, v3.h[0]\n"

"fmla v24.8h, v17.8h, v0.h[1]\n"
"fmla v25.8h, v17.8h, v1.h[1]\n"
"fmla v26.8h, v17.8h, v2.h[1]\n"
"fmla v27.8h, v17.8h, v3.h[1]\n"

"ld1 {v19.4s}, [%[a_ptr]], 16\n"

"fmla v24.8h, v18.8h, v0.h[2]\n"
"fmla v25.8h, v18.8h, v1.h[2]\n"
"fmla v26.8h, v18.8h, v2.h[2]\n"
"fmla v27.8h, v18.8h, v3.h[2]\n"

"ld1 {v20.4s}, [%[a_ptr]], 16\n"

"fmla v24.8h, v19.8h, v0.h[3]\n"
"fmla v25.8h, v19.8h, v1.h[3]\n"
"fmla v26.8h, v19.8h, v2.h[3]\n"
"fmla v27.8h, v19.8h, v3.h[3]\n"

"ld1 {v21.4s}, [%[a_ptr]], 16\n"

"fmla v24.8h, v20.8h, v0.h[4]\n"
"fmla v25.8h, v20.8h, v1.h[4]\n"
"fmla v26.8h, v20.8h, v2.h[4]\n"
"fmla v27.8h, v20.8h, v3.h[4]\n"

"ld1 {v22.4s}, [%[a_ptr]], 16\n"

"fmla v24.8h, v21.8h, v0.h[5]\n"
"fmla v25.8h, v21.8h, v1.h[5]\n"
"fmla v26.8h, v21.8h, v2.h[5]\n"
"fmla v27.8h, v21.8h, v3.h[5]\n"

"ld1 {v23.4s}, [%[a_ptr]], 16\n"

"fmla v24.8h, v22.8h, v0.h[6]\n"
"fmla v25.8h, v22.8h, v1.h[6]\n"
"fmla v26.8h, v22.8h, v2.h[6]\n"
"fmla v27.8h, v22.8h, v3.h[6]\n"

"beq 2f\n"

"1:\n"

"ld1 {v16.4s}, [%[a_ptr]], 16\n"

"fmla v24.8h, v23.8h, v0.h[7]\n"
"ld1 {v0.4s}, [%[b_ptr]], 16\n"

"fmla v25.8h, v23.8h, v1.h[7]\n"
"ld1 {v1.4s}, [%[b_ptr]], 16\n"

"fmla v26.8h, v23.8h, v2.h[7]\n"
"ld1 {v2.4s}, [%[b_ptr]], 16\n"

"fmla v27.8h, v23.8h, v3.h[7]\n"
"ld1 {v3.4s}, [%[b_ptr]], %x[LDB]\n"

"ld1 {v17.4s}, [%[a_ptr]], 16\n"

"fmla v24.8h, v16.8h, v0.h[0]\n"
"fmla v25.8h, v16.8h, v1.h[0]\n"
"fmla v26.8h, v16.8h, v2.h[0]\n"
"fmla v27.8h, v16.8h, v3.h[0]\n"

"ld1 {v18.4s}, [%[a_ptr]], 16\n"

"fmla v24.8h, v17.8h, v0.h[1]\n"
"fmla v25.8h, v17.8h, v1.h[1]\n"
"fmla v26.8h, v17.8h, v2.h[1]\n"
"fmla v27.8h, v17.8h, v3.h[1]\n"

"ld1 {v19.4s}, [%[a_ptr]], 16\n"

"fmla v24.8h, v18.8h, v0.h[2]\n"
"fmla v25.8h, v18.8h, v1.h[2]\n"
"fmla v26.8h, v18.8h, v2.h[2]\n"
"fmla v27.8h, v18.8h, v3.h[2]\n"

"ld1 {v20.4s}, [%[a_ptr]], 16\n"

"fmla v24.8h, v19.8h, v0.h[3]\n"
"fmla v25.8h, v19.8h, v1.h[3]\n"
"fmla v26.8h, v19.8h, v2.h[3]\n"
"fmla v27.8h, v19.8h, v3.h[3]\n"

"ld1 {v21.4s}, [%[a_ptr]], 16\n"

"fmla v24.8h, v20.8h, v0.h[4]\n"
"fmla v25.8h, v20.8h, v1.h[4]\n"
"fmla v26.8h, v20.8h, v2.h[4]\n"
"fmla v27.8h, v20.8h, v3.h[4]\n"

"ld1 {v22.4s}, [%[a_ptr]], 16\n"

"fmla v24.8h, v21.8h, v0.h[5]\n"
"fmla v25.8h, v21.8h, v1.h[5]\n"
"fmla v26.8h, v21.8h, v2.h[5]\n"
"fmla v27.8h, v21.8h, v3.h[5]\n"

"ld1 {v23.4s}, [%[a_ptr]], 16\n"

"fmla v24.8h, v22.8h, v0.h[6]\n"
"fmla v25.8h, v22.8h, v1.h[6]\n"
"fmla v26.8h, v22.8h, v2.h[6]\n"
"fmla v27.8h, v22.8h, v3.h[6]\n"

"subs %w[K], %w[K], #8\n"
"bne 1b\n"

"2:\n"

"fmla v24.8h, v23.8h, v0.h[7]\n"
"fmla v25.8h, v23.8h, v1.h[7]\n"
"fmla v26.8h, v23.8h, v2.h[7]\n"
"fmla v27.8h, v23.8h, v3.h[7]\n"

"st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%[output]], 64\n"

: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[output] "+r"(output), [LDB] "+r"(LDB)
:
: "v0", "v1", "v2", "v3", "v16", "v17", "v18", "v19", "v20", "v21",
"v22", "v23", "v24", "v25", "v26", "v27", "cc", "memory");
}

// Overview of register layout:
//
// A 8x1 cell of Rhs is stored in 16bit in v8-v15
// A 8x1 cell of Lhs is stored in 16bit in v0-v7
// A 8x1 block of accumulators is stored in 16bit in v24-v31.
//
// Rhs +--------+
// | v8[0-7]|
// | v9[0-7]|
// |v10[0-7]|
// |v11[0-7]|
// |v12[0-7]|
// |v13[0-7]|
// |v14[0-7]|
// |v15[0-7]|
// +--------+
// Lhs
// +--------+ - - - - -+--------+
// | v0[0-7]| |v24[0-7]|
// | v1[0-7]| |v25[0-7]|
// | v2[0-7]| |v26[0-7]|
// | v3[0-7]| |v27[0-7]|
// | v4[0-7]| |v28[0-7]|
// | v5[0-7]| |v29[0-7]|
// | v6[0-7]| |v30[0-7]|
// | v7[0-7]| |v31[0-7]|
// +--------+ +--------+
// Accumulator
void kern_8x8(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K,
dt_float16* output) {
//! As each load 128 number from B, but the pos add 112 * 2, so we minus 112
//! here.
LDB = (LDB - 32) * sizeof(dt_float16);

asm volatile(
".arch armv8.2-a+fp16\n"

"ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[a_ptr]], 64\n"
"subs %w[K], %w[K], #8\n"
"ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b_ptr]], 64\n"
"ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[b_ptr]], %x[LDB]\n"

"fmul v24.8h, v8.8h, v0.h[0]\n"
"fmul v25.8h, v8.8h, v1.h[0]\n"
"fmul v26.8h, v8.8h, v2.h[0]\n"
"fmul v27.8h, v8.8h, v3.h[0]\n"
"fmul v28.8h, v8.8h, v4.h[0]\n"
"fmul v29.8h, v8.8h, v5.h[0]\n"
"fmul v30.8h, v8.8h, v6.h[0]\n"
"fmul v31.8h, v8.8h, v7.h[0]\n"

"fmla v24.8h, v9.8h, v0.h[1]\n"
"fmla v25.8h, v9.8h, v1.h[1]\n"
"fmla v26.8h, v9.8h, v2.h[1]\n"
"fmla v27.8h, v9.8h, v3.h[1]\n"
"fmla v28.8h, v9.8h, v4.h[1]\n"
"fmla v29.8h, v9.8h, v5.h[1]\n"
"fmla v30.8h, v9.8h, v6.h[1]\n"
"fmla v31.8h, v9.8h, v7.h[1]\n"

"ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%[a_ptr]], 64\n"
"fmla v24.8h, v10.8h, v0.h[2]\n"
"fmla v25.8h, v10.8h, v1.h[2]\n"
"fmla v26.8h, v10.8h, v2.h[2]\n"
"fmla v27.8h, v10.8h, v3.h[2]\n"
"fmla v28.8h, v10.8h, v4.h[2]\n"
"fmla v29.8h, v10.8h, v5.h[2]\n"
"fmla v30.8h, v10.8h, v6.h[2]\n"
"fmla v31.8h, v10.8h, v7.h[2]\n"

"fmla v24.8h, v11.8h, v0.h[3]\n"
"fmla v25.8h, v11.8h, v1.h[3]\n"
"fmla v26.8h, v11.8h, v2.h[3]\n"
"fmla v27.8h, v11.8h, v3.h[3]\n"
"fmla v28.8h, v11.8h, v4.h[3]\n"
"fmla v29.8h, v11.8h, v5.h[3]\n"
"fmla v30.8h, v11.8h, v6.h[3]\n"
"fmla v31.8h, v11.8h, v7.h[3]\n"

"fmla v24.8h, v12.8h, v0.h[4]\n"
"fmla v25.8h, v12.8h, v1.h[4]\n"
"fmla v26.8h, v12.8h, v2.h[4]\n"
"fmla v27.8h, v12.8h, v3.h[4]\n"
"fmla v24.8h, v13.8h, v0.h[5]\n"
"fmla v25.8h, v13.8h, v1.h[5]\n"
"fmla v26.8h, v13.8h, v2.h[5]\n"
"fmla v27.8h, v13.8h, v3.h[5]\n"

"beq 2f\n"

"1:\n"

"ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[a_ptr]], 64\n"
"fmla v24.8h, v15.8h, v0.h[7]\n"
"fmla v25.8h, v15.8h, v1.h[7]\n"
"fmla v26.8h, v15.8h, v2.h[7]\n"
"fmla v27.8h, v15.8h, v3.h[7]\n"
"fmla v24.8h, v14.8h, v0.h[6]\n"
"fmla v25.8h, v14.8h, v1.h[6]\n"
"fmla v26.8h, v14.8h, v2.h[6]\n"
"fmla v27.8h, v14.8h, v3.h[6]\n"
"fmla v28.8h, v12.8h, v4.h[4]\n"
"fmla v29.8h, v12.8h, v5.h[4]\n"
"fmla v30.8h, v12.8h, v6.h[4]\n"
"fmla v31.8h, v12.8h, v7.h[4]\n"

"ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b_ptr]], 64\n"
"fmla v28.8h, v13.8h, v4.h[5]\n"
"fmla v29.8h, v13.8h, v5.h[5]\n"
"fmla v30.8h, v13.8h, v6.h[5]\n"
"fmla v31.8h, v13.8h, v7.h[5]\n"
"fmla v28.8h, v14.8h, v4.h[6]\n"
"fmla v29.8h, v14.8h, v5.h[6]\n"
"fmla v30.8h, v14.8h, v6.h[6]\n"
"fmla v31.8h, v14.8h, v7.h[6]\n"
"fmla v28.8h, v15.8h, v4.h[7]\n"
"fmla v29.8h, v15.8h, v5.h[7]\n"
"fmla v30.8h, v15.8h, v6.h[7]\n"
"fmla v31.8h, v15.8h, v7.h[7]\n"
"fmla v24.8h, v8.8h, v0.h[0]\n"
"fmla v25.8h, v8.8h, v1.h[0]\n"
"fmla v26.8h, v8.8h, v2.h[0]\n"
"fmla v27.8h, v8.8h, v3.h[0]\n"

"ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[b_ptr]], %x[LDB]\n"
"fmla v24.8h, v9.8h, v0.h[1]\n"
"fmla v25.8h, v9.8h, v1.h[1]\n"
"fmla v26.8h, v9.8h, v2.h[1]\n"
"fmla v27.8h, v9.8h, v3.h[1]\n"
"fmla v24.8h, v10.8h, v0.h[2]\n"
"fmla v25.8h, v10.8h, v1.h[2]\n"
"fmla v26.8h, v10.8h, v2.h[2]\n"
"fmla v27.8h, v10.8h, v3.h[2]\n"
"fmla v24.8h, v11.8h, v0.h[3]\n"
"fmla v25.8h, v11.8h, v1.h[3]\n"
"fmla v26.8h, v11.8h, v2.h[3]\n"
"fmla v27.8h, v11.8h, v3.h[3]\n"

"ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%[a_ptr]], 64\n"
"fmla v28.8h, v10.8h, v4.h[2]\n"
"fmla v29.8h, v10.8h, v5.h[2]\n"
"fmla v30.8h, v10.8h, v6.h[2]\n"
"fmla v31.8h, v10.8h, v7.h[2]\n"
"fmla v28.8h, v8.8h, v4.h[0]\n"
"fmla v29.8h, v8.8h, v5.h[0]\n"
"fmla v30.8h, v8.8h, v6.h[0]\n"
"fmla v31.8h, v8.8h, v7.h[0]\n"
"fmla v28.8h, v9.8h, v4.h[1]\n"
"fmla v29.8h, v9.8h, v5.h[1]\n"
"fmla v30.8h, v9.8h, v6.h[1]\n"
"fmla v31.8h, v9.8h, v7.h[1]\n"

"fmla v28.8h, v11.8h, v4.h[3]\n"
"fmla v29.8h, v11.8h, v5.h[3]\n"
"fmla v30.8h, v11.8h, v6.h[3]\n"
"fmla v31.8h, v11.8h, v7.h[3]\n"

"fmla v24.8h, v12.8h, v0.h[4]\n"
"fmla v25.8h, v12.8h, v1.h[4]\n"
"fmla v26.8h, v12.8h, v2.h[4]\n"
"fmla v27.8h, v12.8h, v3.h[4]\n"
"fmla v24.8h, v13.8h, v0.h[5]\n"
"fmla v25.8h, v13.8h, v1.h[5]\n"
"fmla v26.8h, v13.8h, v2.h[5]\n"
"fmla v27.8h, v13.8h, v3.h[5]\n"

"subs %w[K], %w[K], #8\n"
"bne 1b\n"

"2:\n"
"fmla v24.8h, v14.8h, v0.h[6]\n"
"fmla v25.8h, v14.8h, v1.h[6]\n"
"fmla v26.8h, v14.8h, v2.h[6]\n"
"fmla v27.8h, v14.8h, v3.h[6]\n"
"fmla v24.8h, v15.8h, v0.h[7]\n"
"fmla v25.8h, v15.8h, v1.h[7]\n"
"fmla v26.8h, v15.8h, v2.h[7]\n"
"fmla v27.8h, v15.8h, v3.h[7]\n"
"fmla v28.8h, v12.8h, v4.h[4]\n"
"fmla v29.8h, v12.8h, v5.h[4]\n"
"fmla v28.8h, v13.8h, v4.h[5]\n"
"fmla v29.8h, v13.8h, v5.h[5]\n"
"st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%[output]], 64\n"
"fmla v28.8h, v14.8h, v4.h[6]\n"
"fmla v29.8h, v14.8h, v5.h[6]\n"
"fmla v28.8h, v15.8h, v4.h[7]\n"
"fmla v29.8h, v15.8h, v5.h[7]\n"
"fmla v30.8h, v12.8h, v6.h[4]\n"
"fmla v31.8h, v12.8h, v7.h[4]\n"
"fmla v30.8h, v13.8h, v6.h[5]\n"
"fmla v31.8h, v13.8h, v7.h[5]\n"
"fmla v30.8h, v14.8h, v6.h[6]\n"
"fmla v31.8h, v14.8h, v7.h[6]\n"
"fmla v30.8h, v15.8h, v6.h[7]\n"
"fmla v31.8h, v15.8h, v7.h[7]\n"
"st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%[output]], 64\n"

: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[output] "+r"(output), [LDB] "+r"(LDB)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v24", "v25", "v26", "v27",
"v28", "v29", "v30", "v31", "cc", "memory");
}

} // anonymous namespace

MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(gemm_nopack_f16_8x8);

void gemm_nopack_f16_8x8::kern(const dt_float16* A, size_t LDA,
const dt_float16* B, size_t LDB, dt_float16* C,
size_t LDC, size_t M, size_t K, size_t N,
const dt_float16*, void*, bool trA,
bool trB) const {
constexpr static size_t MB = 8;
constexpr static size_t KB = 8;
constexpr static size_t NB = 8;
constexpr static size_t CALCBLK = 4;

megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0 && N % CALCBLK == 0);

//! (m/8, k/8, 8, 8) * (k/8, n, 8) = (m/8, n, 8)
for (size_t m = 0; m < M; m += MB) {
dt_float16* output = C + (m / MB) * LDC;
const dt_float16* cur_B = B;
size_t n = 0;
for (; n + NB - 1 < N; n += NB) {
kern_8x8(A, cur_B, LDB, K, output);
cur_B += KB * NB;
output += MB * NB;
}
if (n < N) {
kern_8x4(A, cur_B, LDB, K, output);
}
A += LDA;
}
}

#endif
// vim: syntax=cpp.doxygen

+ 39
- 0
dnn/src/aarch64/matrix_mul/fp32/common.h View File

@@ -0,0 +1,39 @@
/**
* \file dnn/src/aarch64/matrix_mul/fp32/common.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once

#include <cstddef>
#include "megdnn/arch.h"
#include "src/common/utils.h"

namespace megdnn {
namespace aarch64 {

MEGDNN_NOINLINE void sgemm_packA_n(const float* A, float* Apacked, size_t M,
size_t K, size_t LDA, const float* alpha);

MEGDNN_NOINLINE void sgemm_packA_t(const float* A, float* Apacked, size_t M,
size_t K, size_t LDA, const float* alpha);

MEGDNN_NOINLINE void sgemm_packB_n(const float* B, float* Bpacked, size_t K,
size_t N, size_t LDB);

MEGDNN_NOINLINE void sgemm_packB_t(const float* B, float* Bpacked, size_t K,
size_t N, size_t LDB);

MEGDNN_NOINLINE void sgemm_kernel12x8(const float* A, const float* B, float* C,
size_t LDC, size_t M, size_t N, size_t K,
int type, const float* beta);

} // namespace aarch64
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 718
- 0
dnn/src/aarch64/matrix_mul/fp32/kernel_general_4x16.h View File

@@ -0,0 +1,718 @@
/**
* \file dnn/src/aarch64/matrix_mul/fp32/kernel_general_4x16.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/aarch64/matrix_mul/asm/common.h"
#include "src/arm_common/simd_macro/marm_neon.h"


namespace megdnn {
namespace aarch64 {
namespace matmul_general_4x16 {

// Overview of register layout:
//
// A 1x16 cell of Rhs is stored in 32bit in v1-v4
// A 4x1 cell of Lhs is stored in 32bit in v0
// A 4x16 block of accumulators is stored in 32bit in v10-v25.
//
// +--------+--------+--------+--------+
// | v2[0-3]| v3[0-3]| v4[0-3]| v5[0-3]|
// Rhs +--------+--------+--------+--------+
//
// | | | | |
//
// Lhs | | | | |
//
// +--+ - - - - +--------+--------+--------+--------+
// |v0| |v10[0-3]|v11[0-3]|v12[0-3]|v13[0-3]|
// |v0| |v14[0-3]|v15[0-3]|v16[0-3]|v17[0-3]|
// |v0| |v18[0-3]|v19[0-3]|v20[0-3]|v21[0-3]|
// |v0| |v22[0-3]|v23[0-3]|v24[0-3]|v25[0-3]|
// +--+ - - - - +--------+--------+--------+--------+
//
// Accumulator
void kern_4x16(const float* packA, const float* packB, int K,
float* output, int LDC, bool is_first_k, int m_remain) {
const float* a_ptr = packA;
const float* b_ptr = packB;
int oddk = (K & 1);
K = ((K + 1) / 2) - 1;

LDC = LDC * sizeof(float);
register float* outptr asm("x0") = reinterpret_cast<float*>(output);

// clang-format off
#define LOAD_LINE(v0, v1, v2, v3, n) \
"cmp x10, #0\n" \
"beq 100f\n" \
"mov x9, x" n "\n" \
"ld1 {v" v0 ".4s, v" v1 ".4s, v" v2 ".4s, v" v3 ".4s}, [x9], 64\n" \
"subs x10, x10, #1\n"

#define LOAD_C \
"mov x10, %x[m_remain]\n" \
LOAD_LINE("10", "11", "12", "13", "0") \
LOAD_LINE("14", "15", "16", "17", "1") \
LOAD_LINE("18", "19", "20", "21", "2") \
LOAD_LINE("22", "23", "24", "25", "3") \
"100:\n"

#define STORE_LINE(v0, v1, v2, v3, n) \
"cmp x10, #0\n" \
"beq 101f\n" \
"mov x9, x" n "\n" \
"st1 {v" v0 ".4s, v" v1 ".4s, v" v2 ".4s, v" v3 ".4s}, [x9], 64\n" \
"subs x10, x10, #1\n"

#define STORE_C \
"mov x10, %x[m_remain]\n" \
STORE_LINE("10", "11", "12", "13", "0") \
STORE_LINE("14", "15", "16", "17", "1") \
STORE_LINE("18", "19", "20", "21", "2") \
STORE_LINE("22", "23", "24", "25", "3") \
"101:\n"
// clang-format on

asm volatile(
// load accumulator C
"add x1, x0, %x[LDC]\n"
"add x2, x1, %x[LDC]\n"
"add x3, x2, %x[LDC]\n"

"cmp %w[is_first_k], #1\n"
"beq 1f\n" LOAD_C

"b 2f\n"

"1:\n"
"eor v10.16b, v10.16b, v10.16b\n"
"eor v11.16b, v11.16b, v11.16b\n"
"eor v12.16b, v12.16b, v12.16b\n"
"eor v13.16b, v13.16b, v13.16b\n"
"eor v14.16b, v14.16b, v14.16b\n"
"eor v15.16b, v15.16b, v15.16b\n"
"eor v16.16b, v16.16b, v16.16b\n"
"eor v17.16b, v17.16b, v17.16b\n"
"eor v18.16b, v18.16b, v18.16b\n"
"eor v19.16b, v19.16b, v19.16b\n"
"eor v20.16b, v20.16b, v20.16b\n"
"eor v21.16b, v21.16b, v21.16b\n"
"eor v22.16b, v22.16b, v22.16b\n"
"eor v23.16b, v23.16b, v23.16b\n"
"eor v24.16b, v24.16b, v24.16b\n"
"eor v25.16b, v25.16b, v25.16b\n"

"2: \n"
"ld1 {v2.4s, v3.4s, v4.4s, v5.4s}, [%[b_ptr]], 64\n"

"cmp %w[K], #0\n"
"beq 4f\n"

"3:\n"
"ld1 {v0.4s}, [%[a_ptr]], 16\n"
"fmla v10.4s, v2.4s, v0.s[0]\n"
"fmla v11.4s, v3.4s, v0.s[0]\n"
"fmla v12.4s, v4.4s, v0.s[0]\n"
"fmla v13.4s, v5.4s, v0.s[0]\n"
"ld1 {v6.4s, v7.4s, v8.4s, v9.4s}, [%[b_ptr]], 64\n"
"fmla v14.4s, v2.4s, v0.s[1]\n"
"fmla v15.4s, v3.4s, v0.s[1]\n"
"fmla v16.4s, v4.4s, v0.s[1]\n"
"fmla v17.4s, v5.4s, v0.s[1]\n"
"ld1 {v1.4s}, [%[a_ptr]], 16\n"
"fmla v18.4s, v2.4s, v0.s[2]\n"
"fmla v19.4s, v3.4s, v0.s[2]\n"
"fmla v20.4s, v4.4s, v0.s[2]\n"
"fmla v21.4s, v5.4s, v0.s[2]\n"
"fmla v22.4s, v2.4s, v0.s[3]\n"
"fmla v23.4s, v3.4s, v0.s[3]\n"
"fmla v24.4s, v4.4s, v0.s[3]\n"
"fmla v25.4s, v5.4s, v0.s[3]\n"

"ld1 {v2.4s, v3.4s, v4.4s, v5.4s}, [%[b_ptr]], 64\n"
"fmla v10.4s, v6.4s, v1.s[0]\n"
"fmla v11.4s, v7.4s, v1.s[0]\n"
"fmla v12.4s, v8.4s, v1.s[0]\n"
"fmla v13.4s, v9.4s, v1.s[0]\n"
"fmla v14.4s, v6.4s, v1.s[1]\n"
"fmla v15.4s, v7.4s, v1.s[1]\n"
"fmla v16.4s, v8.4s, v1.s[1]\n"
"fmla v17.4s, v9.4s, v1.s[1]\n"
"fmla v18.4s, v6.4s, v1.s[2]\n"
"fmla v19.4s, v7.4s, v1.s[2]\n"
"fmla v20.4s, v8.4s, v1.s[2]\n"
"fmla v21.4s, v9.4s, v1.s[2]\n"
"fmla v22.4s, v6.4s, v1.s[3]\n"
"fmla v23.4s, v7.4s, v1.s[3]\n"
"fmla v24.4s, v8.4s, v1.s[3]\n"
"fmla v25.4s, v9.4s, v1.s[3]\n"

"subs %w[K], %w[K], #1\n"
"bne 3b\n"

"4:\n"
"cmp %w[oddk], #1\n"
"beq 5f\n"

// Even tail
"ld1 {v0.4s}, [%[a_ptr]], 16\n"
"fmla v10.4s, v2.4s, v0.s[0]\n"
"fmla v11.4s, v3.4s, v0.s[0]\n"
"fmla v12.4s, v4.4s, v0.s[0]\n"
"fmla v13.4s, v5.4s, v0.s[0]\n"
"ld1 {v6.4s, v7.4s, v8.4s, v9.4s}, [%[b_ptr]], 64\n"
"fmla v14.4s, v2.4s, v0.s[1]\n"
"fmla v15.4s, v3.4s, v0.s[1]\n"
"fmla v16.4s, v4.4s, v0.s[1]\n"
"fmla v17.4s, v5.4s, v0.s[1]\n"
"ld1 {v1.4s}, [%[a_ptr]], 16\n"
"fmla v18.4s, v2.4s, v0.s[2]\n"
"fmla v19.4s, v3.4s, v0.s[2]\n"
"fmla v20.4s, v4.4s, v0.s[2]\n"
"fmla v21.4s, v5.4s, v0.s[2]\n"
"fmla v22.4s, v2.4s, v0.s[3]\n"
"fmla v23.4s, v3.4s, v0.s[3]\n"
"fmla v24.4s, v4.4s, v0.s[3]\n"
"fmla v25.4s, v5.4s, v0.s[3]\n"

"fmla v10.4s, v6.4s, v1.s[0]\n"
"fmla v11.4s, v7.4s, v1.s[0]\n"
"fmla v12.4s, v8.4s, v1.s[0]\n"
"fmla v13.4s, v9.4s, v1.s[0]\n"
"fmla v14.4s, v6.4s, v1.s[1]\n"
"fmla v15.4s, v7.4s, v1.s[1]\n"
"fmla v16.4s, v8.4s, v1.s[1]\n"
"fmla v17.4s, v9.4s, v1.s[1]\n"
"fmla v18.4s, v6.4s, v1.s[2]\n"
"fmla v19.4s, v7.4s, v1.s[2]\n"
"fmla v20.4s, v8.4s, v1.s[2]\n"
"fmla v21.4s, v9.4s, v1.s[2]\n"
"fmla v22.4s, v6.4s, v1.s[3]\n"
"fmla v23.4s, v7.4s, v1.s[3]\n"
"fmla v24.4s, v8.4s, v1.s[3]\n"
"fmla v25.4s, v9.4s, v1.s[3]\n"

"b 6f\n"

// odd tail
"5:\n"
"ld1 {v0.4s}, [%[a_ptr]], 16\n"
"fmla v10.4s, v2.4s, v0.s[0]\n"
"fmla v11.4s, v3.4s, v0.s[0]\n"
"fmla v12.4s, v4.4s, v0.s[0]\n"
"fmla v13.4s, v5.4s, v0.s[0]\n"
"fmla v14.4s, v2.4s, v0.s[1]\n"
"fmla v15.4s, v3.4s, v0.s[1]\n"
"fmla v16.4s, v4.4s, v0.s[1]\n"
"fmla v17.4s, v5.4s, v0.s[1]\n"
"fmla v18.4s, v2.4s, v0.s[2]\n"
"fmla v19.4s, v3.4s, v0.s[2]\n"
"fmla v20.4s, v4.4s, v0.s[2]\n"
"fmla v21.4s, v5.4s, v0.s[2]\n"
"fmla v22.4s, v2.4s, v0.s[3]\n"
"fmla v23.4s, v3.4s, v0.s[3]\n"
"fmla v24.4s, v4.4s, v0.s[3]\n"
"fmla v25.4s, v5.4s, v0.s[3]\n"

"6:\n" STORE_C

: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
[m_remain] "+r"(m_remain), [outptr] "+r"(outptr)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "x1", "x2", "x3", "x9",
"x10", "cc", "memory");

#undef LOAD_LINE
#undef LOAD_C
#undef STORE_LINE
#undef STORE_C
}

// Overview of register layout:
//
// A 2x4 cell of Rhs is stored in 32bit in v2 - v3
// A 4x2 cell of Lhs is stored in 32bit in v0 - v1
// A 4x4 block of accumulators is stored in 32bit in v4-v6
//
// +--------+
// | v2[0-3]|
// Rhs +--------+
// | v3[0-3]|
// +--------+
//
// | |
//
// Lhs | |
//
// +--+--+ - - - - +--------+
// |v0|v1| | v4[0-3]|
// |v0|v1| | v5[0-3]|
// |v0|v1| | v6[0-3]|
// |v0|v1| | v7[0-3]|
// +--+--+ - - - - +--------+
//
// Accumulator
void kern_4x4(const float* packA, const float* packB, int K, float* output,
int LDC, bool is_first_k, int m_remain, int n_remain) {
const float* a_ptr = packA;
const float* b_ptr = packB;
int oddk = (K & 1);
K = ((K + 1) / 2) - 1;

LDC = LDC * sizeof(float);
register float* outptr asm("x0") = output;

// clang-format off
#define LOAD_LINE(v0, n) \
"cmp x10, #0\n" \
"beq 102f\n" \
"cmp %w[n_remain], #4\n" \
"blt 100" n "f\n" \
"ld1 {v" v0 ".4s}, [x" n "], 16\n" \
"b 101" n "f\n" \
"100" n ":\n" \
"cmp %w[n_remain], #0\n" \
"beq 101" n "f\n" \
"ld1 {v" v0 ".s}[0], [x" n "], 4\n" \
"cmp %w[n_remain], #1\n" \
"beq 101" n "f\n" \
"ld1 {v" v0 ".s}[1], [x" n "], 4\n" \
"cmp %w[n_remain], #2\n" \
"beq 101" n "f\n" \
"ld1 {v" v0 ".s}[2], [x" n "], 4\n" \
"101" n ":\n" \
"subs x10, x10, #1\n"

#define LOAD_C \
"mov x10, %x[m_remain]\n" \
LOAD_LINE("4", "0") \
LOAD_LINE("5", "1") \
LOAD_LINE("6", "2") \
LOAD_LINE("7", "3") \
"102:\n"

#define STORE_LINE(v0, n) \
"cmp x10, #0 \n" \
"beq 105f\n" \
"cmp %w[n_remain], #4\n" \
"blt 103" n "f\n" \
"st1 {v" v0 ".4s}, [x" n " ], 16\n" \
"b 104" n "f\n" \
"103" n ":\n" \
"cmp %w[n_remain], #0\n" \
"beq 104" n "f\n" \
"st1 {v" v0 ".s}[0], [x" n "], 4\n" \
"cmp %w[n_remain], #1\n" \
"beq 104" n "f\n" \
"st1 {v" v0 ".s}[1], [x" n "], 4\n" \
"cmp %w[n_remain], #2\n" \
"beq 104" n "f\n" \
"st1 {v" v0 ".s}[2], [x" n "], 4\n" \
"104" n ":\n" \
"subs x10, x10, #1\n"


#define STORE_C \
"mov x10, %x[m_remain]\n" \
STORE_LINE("4", "0") \
STORE_LINE("5", "1") \
STORE_LINE("6", "2") \
STORE_LINE("7", "3") \
"105:\n"
// clang-format on

asm volatile(
// load accumulator C
"add x1, x0, %x[LDC]\n"
"add x2, x1, %x[LDC]\n"
"add x3, x2, %x[LDC]\n"

"cmp %w[is_first_k], #1\n"
"beq 1f\n" LOAD_C

"b 2f\n"

"1:\n"
"eor v4.16b, v4.16b, v4.16b\n"
"eor v5.16b, v5.16b, v5.16b\n"
"eor v6.16b, v6.16b, v6.16b\n"
"eor v7.16b, v7.16b, v7.16b\n"

"2: \n"
"ld1 {v0.4s}, [%[a_ptr]], 16\n"
"ld1 {v2.4s}, [%[b_ptr]], 16\n"
"cmp %w[K], #0\n"
"beq 4f\n"

"3:\n"
"ld1 {v1.4s}, [%[a_ptr]], 16\n"
"ld1 {v3.4s}, [%[b_ptr]], 16\n"
"fmla v4.4s, v2.4s, v0.s[0]\n"
"fmla v5.4s, v2.4s, v0.s[1]\n"
"fmla v6.4s, v2.4s, v0.s[2]\n"
"fmla v7.4s, v2.4s, v0.s[3]\n"

"ld1 {v0.4s}, [%[a_ptr]], 16\n"
"ld1 {v2.4s}, [%[b_ptr]], 16\n"
"fmla v4.4s, v3.4s, v1.s[0]\n"
"fmla v5.4s, v3.4s, v1.s[1]\n"
"fmla v6.4s, v3.4s, v1.s[2]\n"
"fmla v7.4s, v3.4s, v1.s[3]\n"

"subs %w[K], %w[K], #1\n"
"bne 3b\n"

"4:\n"
"cmp %w[oddk], #1\n"
"beq 5f\n"

// Even tail
"ld1 {v1.4s}, [%[a_ptr]], 16\n"
"ld1 {v3.4s}, [%[b_ptr]], 16\n"
"fmla v4.4s, v2.4s, v0.s[0]\n"
"fmla v5.4s, v2.4s, v0.s[1]\n"
"fmla v6.4s, v2.4s, v0.s[2]\n"
"fmla v7.4s, v2.4s, v0.s[3]\n"

"fmla v4.4s, v3.4s, v1.s[0]\n"
"fmla v5.4s, v3.4s, v1.s[1]\n"
"fmla v6.4s, v3.4s, v1.s[2]\n"
"fmla v7.4s, v3.4s, v1.s[3]\n"

"b 6f\n"

// odd tail
"5:\n"
"fmla v4.4s, v2.4s, v0.s[0]\n"
"fmla v5.4s, v2.4s, v0.s[1]\n"
"fmla v6.4s, v2.4s, v0.s[2]\n"
"fmla v7.4s, v2.4s, v0.s[3]\n"

"6:\n" STORE_C

: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k),
[oddk] "+r"(oddk), [m_remain] "+r"(m_remain),
[n_remain] "+r"(n_remain), [outptr] "+r"(outptr)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "x1",
"x2", "x3", "x10", "cc", "memory");
#undef LOAD_LINE
#undef LOAD_C
#undef STORE_LINE
#undef STORE_C
}

void sgemm_4x16_pack_A_n(float * outptr, const float * inptr, int ldin, int y0,
int ymax, int k0, int kmax) {
float zerobuff[4];
std::memset(zerobuff, 0, sizeof(float) * 4);
constexpr int PACK_SIZE = 4*4;

int y = y0;
for (; y + 3 < ymax; y += 4) {
// printf("main loop pack_a_n %p \n",outptr);
const float* inptr0 = inptr + y * ldin + k0;
const float* inptr1 = inptr0 + ldin;
const float* inptr2 = inptr1 + ldin;
const float* inptr3 = inptr2 + ldin;

prefetch_2x(inptr0);
prefetch_2x(inptr1);
prefetch_2x(inptr2);
prefetch_2x(inptr3);

int K = (kmax - k0);
for (; K > 3; K -= 4) {
transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr);
outptr += PACK_SIZE;
}

interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 1, K);
}

for (; y < ymax; y += 4) {
const float* inptr0 = inptr + y * ldin + k0;
const float* inptr1 = inptr0 + ldin;
const float* inptr2 = inptr1 + ldin;
const float* inptr3 = inptr2 + ldin;

prefetch_2x(inptr0);
prefetch_2x(inptr1);
prefetch_2x(inptr2);
prefetch_2x(inptr3);

int K = (kmax - k0);
for (; K > 3; K -= 4) {
if ((y + 3) >= ymax) {
switch ((y + 3) - ymax) {
/* Everything falls through in here */
case 2:
inptr1 = zerobuff;
case 1:
inptr2 = zerobuff;
case 0:
inptr3 = zerobuff;
break;
default:
megdnn_assert(0);
}
}

transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr);
outptr += PACK_SIZE;
}

if (K > 0) {
if (y + 3 >= ymax) {
switch (y + 3 - ymax) {
case 2:
inptr1 = zerobuff;
case 1:
inptr2 = zerobuff;
case 0:
inptr3 = zerobuff;
break;
default:
megdnn_assert(0);
}
}
interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 1, K);
}
}
}

void sgemm_4x16_pack_A_t(float* out, const float* in, int ldin, int x0,
int xmax, int k0, int kmax) {
int ksize = kmax - k0;
int ksize4 = (ksize << 2);
float* outptr_base = out;

int k = k0;
for (; k + 3 < kmax; k += 4) {
const float* inptr = in + k * ldin + x0;
const float* inptr1 = inptr + ldin;
const float* inptr2 = inptr1 + ldin;
const float* inptr3 = inptr2 + ldin;

prefetch_3x(inptr);
prefetch_3x(inptr1);
prefetch_3x(inptr2);
prefetch_3x(inptr3);

int x = x0;
auto outptr = outptr_base;
for (; x + 4 <= xmax; x += 4) {
auto outptr_interleave = outptr;
interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3,
outptr_interleave);
outptr += ksize4;
}

if (x < xmax) {
interleave_4(inptr, inptr1, inptr2, inptr3, outptr, 4, xmax - x);
}

outptr_base += 4 * 4;
}

for (; k < kmax; k++) {
const float* inptr = in + k * ldin + x0;
prefetch_3x(inptr);
int x = x0;
auto outptr = outptr_base;
for (; x + 4 <= xmax; x += 4) {
auto outptr_interleave = outptr;
interleave_1x4_1_s(inptr, outptr_interleave);
outptr += ksize4;
}

if (x < xmax) {
interleave_1(inptr, outptr, 4, xmax - x);
}

outptr_base += 4;
}
}

void sgemm_4x16_pack_B_n(float* out, const float* in, int ldin,
int x0, int xmax, int k0, int kmax) {
int ksize = kmax - k0;
int ksize16 = ksize * 16;
int ksize4 = (ksize << 2);
float* outptr_base = out;
float* outptr_base4 = outptr_base + (xmax - x0) / 16 * ksize16;

int k = k0;
for (; k + 3 < kmax; k += 4) {
const float* inptr = in + k * ldin + x0;
const float* inptr1 = inptr + ldin;
const float* inptr2 = inptr1 + ldin;
const float* inptr3 = inptr2 + ldin;

prefetch_3x(inptr);
prefetch_3x(inptr1);
prefetch_3x(inptr2);
prefetch_3x(inptr3);

int x = x0;
auto outptr = outptr_base;
for (; x + 16 <= xmax; x += 16) {
auto outptr_interleave = outptr;
interleave_4x16_1_s(inptr, inptr1, inptr2, inptr3,
outptr_interleave);
outptr += ksize16;
}
outptr = outptr_base4;
for (; x + 4 <= xmax; x += 4) {
auto outptr_interleave = outptr;
interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3,
outptr_interleave);
outptr += ksize4;
}

if (x < xmax) {
interleave_4(inptr, inptr1, inptr2, inptr3, outptr, 4, xmax - x);
}

outptr_base += 16 * 4;
outptr_base4 += 4 * 4;
}

for (; k < kmax; k++) {
const float* inptr = in + k * ldin + x0;
prefetch_3x(inptr);
int x = x0;
auto outptr = outptr_base;
for (; x + 16 <= xmax; x += 16) {
auto outptr_interleave = outptr;
interleave_1x16_1_s(inptr, outptr_interleave);
outptr += ksize16;
}
outptr = outptr_base4;
for (; x + 4 <= xmax; x += 4) {
auto outptr_interleave = outptr;
interleave_1x4_1_s(inptr, outptr_interleave);
outptr += ksize4;
}

if (x < xmax) {
interleave_1(inptr, outptr, 4, xmax - x);
}

outptr_base += 16;
outptr_base4 += 4;
}
}

void sgemm_4x16_pack_B_t(float* out, const float* in, int ldin,
int y0, int ymax, int k0, int kmax) {
float* outptr = out;
const float* inptr = in;
float zerobuff[4];
std::memset(zerobuff, 0, sizeof(float) * 4);
int K16 = 16 * (kmax - k0);

int y = y0;

for (; y + 16 <= ymax; y += 16) {
int yi = y;
for (; yi < y + 16; yi += 4) {
const float* inptr0 = inptr + yi * ldin + k0;
const float* inptr1 = inptr0 + ldin;
const float* inptr2 = inptr1 + ldin;
const float* inptr3 = inptr2 + ldin;
float* outptr_inner = outptr + yi - y;

prefetch_2x(inptr0);
prefetch_2x(inptr1);
prefetch_2x(inptr2);
prefetch_2x(inptr3);

int x = (kmax - k0);
for (; x > 3; x -= 4) {
transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr_inner,
64);
outptr_inner += 64;
}
for (; x > 0; x--) {
*outptr_inner++ = *inptr0++;
*outptr_inner++ = *inptr1++;
*outptr_inner++ = *inptr2++;
*outptr_inner++ = *inptr3++;
outptr_inner += 12;
}
}
outptr += K16;
}

for (; y < ymax; y += 4) {
const float* inptr0 = inptr + y * ldin + k0;
const float* inptr1 = inptr0 + ldin;
const float* inptr2 = inptr1 + ldin;
const float* inptr3 = inptr2 + ldin;

prefetch_2x(inptr0);
prefetch_2x(inptr1);
prefetch_2x(inptr2);
prefetch_2x(inptr3);

/* Cope with ragged cases by copying from a buffer of zeroes instead
*/
int x = (kmax - k0);
for (; x > 3; x -= 4) {
if ((y + 3) >= ymax) {
switch ((y + 3) - ymax) {
/* Everything falls through in here */
case 2:
inptr1 = zerobuff;
case 1:
inptr2 = zerobuff;
case 0:
inptr3 = zerobuff;
break;
default:
megdnn_assert(0);
}
}

transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr);
outptr += 16;
}

if (x > 0) {
if ((y + 3) >= ymax) {
switch ((y + 3) - ymax) {
/* Everything falls through in here */
case 2:
inptr1 = zerobuff;
case 1:
inptr2 = zerobuff;
case 0:
inptr3 = zerobuff;
break;
default:
megdnn_assert(0);
}
}
interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 1, x);
}
}
}

} // matmul_general_4x16
} // aarch64
} // megdnn

// vim: syntax=cpp.doxygen

+ 1242
- 0
dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12.h
File diff suppressed because it is too large
View File


+ 166
- 0
dnn/src/aarch64/matrix_mul/fp32/strategy.cpp View File

@@ -0,0 +1,166 @@
/**
* \file dnn/src/aarch64/matrix_mul/fp32/strategy.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/aarch64/matrix_mul/fp32/strategy.h"
#include "src/aarch64/matrix_mul/fp32/kernel_general_4x16.h"
#include "src/aarch64/matrix_mul/fp32/kernel_general_8x12.h"
#include "src/common/utils.h"

using namespace megdnn;
using namespace aarch64;
using namespace aarch64::matmul;

MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_4x16);

void sgemm_4x16::pack_A(float* out, const float* in, int ldin, int y0,
int ymax, int k0, int kmax, bool transpose_A) const {
if (transpose_A) {
matmul_general_4x16::sgemm_4x16_pack_A_t(out, in, ldin, y0, ymax, k0, kmax);
} else {
matmul_general_4x16::sgemm_4x16_pack_A_n(out, in, ldin, y0, ymax, k0, kmax);
}
}

void sgemm_4x16::pack_B(float* out, const float* in, int ldin, int x0, int xmax,
int k0, int kmax, bool transpose_B) const {
if (transpose_B) {
matmul_general_4x16::sgemm_4x16_pack_B_t(out, in, ldin, x0, xmax, k0, kmax);
} else {
matmul_general_4x16::sgemm_4x16_pack_B_n(out, in, ldin, x0, xmax, k0, kmax);
}
}

void sgemm_4x16::kern(const float* packA, const float* packB,
size_t M, size_t N, size_t K, float* C, size_t LDC,
bool is_first_k, const float*, float*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
A_dtype.enumv() == C_dtype.enumv() &&
A_dtype.enumv() == DTypeEnum::Float32);
MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype);
MEGDNN_MARK_USED_VAR(C_dtype);

constexpr size_t A_INTERLEAVE = 4;
constexpr size_t B_INTERLEAVE = 16;
const int K16 = K * 16;
const int K4 = K * 4;

size_t m = 0;
for (; m < M; m += A_INTERLEAVE) {
float* output = C + (m * LDC);

size_t n = 0;
const float* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_general_4x16::kern_4x16(packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4));
output += B_INTERLEAVE;
cur_packB += K16;
}

for (; n < N; n += 4) {
matmul_general_4x16::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4));
output += 4;
cur_packB += K4;
}

packA += K4;
}
}

MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_8x12);

void sgemm_8x12::pack_A(float* out, const float* in, int ldin, int y0,
int ymax, int k0, int kmax, bool transpose_A) const {
if (transpose_A) {
matmul_general_8x12::sgemm_8x12_pack_A_t(out, in, ldin, y0, ymax, k0,
kmax);
} else {
matmul_general_8x12::sgemm_8x12_pack_A_n(out, in, ldin, y0, ymax, k0,
kmax);
}
}

void sgemm_8x12::pack_B(float* out, const float* in, int ldin, int x0, int xmax,
int k0, int kmax, bool transpose_B) const {
if (transpose_B) {
matmul_general_8x12::sgemm_8x12_pack_B_t(out, in, ldin, x0, xmax, k0,
kmax);
} else {
matmul_general_8x12::sgemm_8x12_pack_B_n(out, in, ldin, x0, xmax, k0,
kmax);
}
}

void sgemm_8x12::kern(const float* packA, const float* packB,
size_t M, size_t N, size_t K, float* C, size_t LDC,
bool is_first_k, const float*, float*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
A_dtype.enumv() == C_dtype.enumv() &&
A_dtype.enumv() == DTypeEnum::Float32);
MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype);
MEGDNN_MARK_USED_VAR(C_dtype);

constexpr size_t A_INTERLEAVE = 8;
constexpr size_t A_INTERLEAVE4 = 4;
constexpr size_t B_INTERLEAVE = 12;
const int K12 = K * 12;
const int K8 = K * 8;
const int K4 = K * 4;

size_t m = 0;
for (; m + A_INTERLEAVE <= M; m += A_INTERLEAVE) {
float* output = C + (m * LDC);

size_t n = 0;
const float* cur_packB = packB;
for (; n + B_INTERLEAVE <= N; n += B_INTERLEAVE) {
matmul_general_8x12::kern_8x12(packA, cur_packB, K, output, LDC,
is_first_k);
output += B_INTERLEAVE;
cur_packB += K12;
}

for (; n < N; n += 4) {
matmul_general_8x12::kern_8x4(packA, cur_packB, K, output, LDC,
is_first_k,
std::min<size_t>(N - n, 4));
output += 4;
cur_packB += K4;
}
packA += K8;
}
for (; m < M; m += A_INTERLEAVE4) {
float* output = C + (m * LDC);
size_t n = 0;
const float* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_general_8x12::kern_4x12(packA, cur_packB, K, output, LDC,
is_first_k,
std::min<size_t>(M - m, 4));
output += B_INTERLEAVE;
cur_packB += K12;
}

for (; n < N; n += 4) {
matmul_general_8x12::kern_4x4(
packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4));
output += 4;
cur_packB += K4;
}
packA += K4;
}
}

// vim: syntax=cpp.doxygen

+ 30
- 0
dnn/src/aarch64/matrix_mul/fp32/strategy.h View File

@@ -0,0 +1,30 @@
/**
* \file dnn/src/aarch64/matrix_mul/fp32/strategy.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "src/fallback/matrix_mul/gemm_common.h"

namespace megdnn {
namespace aarch64 {
namespace matmul {
MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, true,
sgemm_8x12);

MEGDNN_REG_GEMM_STRATEGY(float, float, float, 4, 16, 1, false, true,
sgemm_4x16);

MEGDNN_REG_GEMM_STRATEGY_NOPACK(float, float, float, 4, 16, 1, false, true,
sgemm_nopack_4x16);

} // namespace matmul
} // namespace aarch64
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 570
- 0
dnn/src/aarch64/matrix_mul/fp32/strategy_mk4_4x16.cpp View File

@@ -0,0 +1,570 @@
/**
* \file dnn/src/aarch64/matrix_mul/fp32/strategy_mk4_4x16.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/aarch64/matrix_mul/asm/common.h"
#include "src/aarch64/matrix_mul/fp32/strategy.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"

using namespace megdnn;
using namespace aarch64;
using namespace aarch64::matmul;

namespace {

// Overview of register layout:
//
// A 4x4 block of A is stored in register v4-v7
// A 4x4 block of B is stored in register v0-v3
// A 8x4 block of accumulators store in v16-v19
//
// A +--------+
// | v4[0-3]|
// | v5[0-3]|
// | v6[0-3]|
// | v7[0-3]|
// +--------+
// B
// +--------+ - - - - -+--------+
// | v0[0-3]| |v16[0-3]|
// | v1[0-3]| |v17[0-3]|
// | v2[0-3]| |v18[0-3]|
// | v3[0-3]| |v19[0-3]|
// +--------+ - - - - -+--------+
// Accumulator

void kern_4x4(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K,
float* output) {
//! As each load 16 number from B, but the pos add 12 * 4, so we minus 12
//! here.
LDB = (LDB - 12) * sizeof(float);
asm volatile(
"subs %w[K], %w[K], #4\n"
"ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[a_ptr]], 64\n"

"ld1 {v0.4s}, [%[b_ptr]], 16\n"
"ld1 {v1.4s}, [%[b_ptr]], 16\n"
"ld1 {v2.4s}, [%[b_ptr]], 16\n"
"ld1 {v3.4s}, [%[b_ptr]], %x[LDB]\n"

"fmul v16.4s, v4.4s, v0.s[0]\n"
"fmul v17.4s, v4.4s, v1.s[0]\n"
"fmul v18.4s, v4.4s, v2.s[0]\n"
"fmul v19.4s, v4.4s, v3.s[0]\n"

"fmla v16.4s, v5.4s, v0.s[1]\n"
"fmla v17.4s, v5.4s, v1.s[1]\n"
"fmla v18.4s, v5.4s, v2.s[1]\n"
"fmla v19.4s, v5.4s, v3.s[1]\n"

"beq 2f\n"

"1:\n"

"ld1 {v4.4s, v5.4s}, [%[a_ptr]], 32\n"

"fmla v16.4s, v6.4s, v0.s[2]\n"
"fmla v17.4s, v6.4s, v1.s[2]\n"
"fmla v18.4s, v6.4s, v2.s[2]\n"
"fmla v19.4s, v6.4s, v3.s[2]\n"

"fmla v16.4s, v7.4s, v0.s[3]\n"
"fmla v17.4s, v7.4s, v1.s[3]\n"
"ld1 {v0.4s}, [%[b_ptr]], 16\n"
"fmla v18.4s, v7.4s, v2.s[3]\n"
"ld1 {v1.4s}, [%[b_ptr]], 16\n"
"fmla v19.4s, v7.4s, v3.s[3]\n"
"ld1 {v2.4s}, [%[b_ptr]], 16\n"

"fmla v16.4s, v4.4s, v0.s[0]\n"
"ld1 {v3.4s}, [%[b_ptr]], %x[LDB]\n"
"fmla v17.4s, v4.4s, v1.s[0]\n"
"fmla v18.4s, v4.4s, v2.s[0]\n"
"fmla v19.4s, v4.4s, v3.s[0]\n"

"ld1 {v6.4s, v7.4s}, [%[a_ptr]], 32\n"

"fmla v16.4s, v5.4s, v0.s[1]\n"
"fmla v17.4s, v5.4s, v1.s[1]\n"
"fmla v18.4s, v5.4s, v2.s[1]\n"
"fmla v19.4s, v5.4s, v3.s[1]\n"

"subs %w[K], %w[K], #4\n"
"bne 1b\n"

"2:\n"

"fmla v16.4s, v6.4s, v0.s[2]\n"
"fmla v17.4s, v6.4s, v1.s[2]\n"
"fmla v18.4s, v6.4s, v2.s[2]\n"
"fmla v19.4s, v6.4s, v3.s[2]\n"

"fmla v16.4s, v7.4s, v0.s[3]\n"
"fmla v17.4s, v7.4s, v1.s[3]\n"
"fmla v18.4s, v7.4s, v2.s[3]\n"
"fmla v19.4s, v7.4s, v3.s[3]\n"

"st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[output]], 64\n"

: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[output] "+r"(output), [LDB] "+r"(LDB)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "cc", "memory");
}

// Overview of register layout:
//
// A 4x4 block of A is stored in register v4-v7
// A 4x4 block of B is stored in register v0-v3, slipping until 8x4
// A 8x4 block of accumulators store in v16-v23.
//
// A +--------+
// | v4[0-3]|
// | v5[0-3]|
// | v6[0-3]|
// | v7[0-3]|
// +--------+
// B
// +--------+ - - - - -+--------+
// | v0[0-3]| |v16[0-3]|
// | v1[0-3]| |v17[0-3]|
// | v2[0-3]| |v18[0-3]|
// | v3[0-3]| |v19[0-3]|
// +--------+ - - - - -+--------+
// | v0[0-3]| |v20[0-3]|
// | v1[0-3]| |v21[0-3]|
// | v2[0-3]| |v22[0-3]|
// | v3[0-3]| |v23[0-3]|
// +--------+ - - - - -+--------+
// Accumulator

void kern_4x8(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K,
float* output) {
//! As each load 32 number from B, but the pos add 24 * 4, so we minus 24
//! here.
LDB = (LDB - 24) * sizeof(float);
asm volatile(
"subs %w[K], %w[K], #4\n"
"ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[a_ptr]], 64\n"

"ld1 {v0.4s}, [%[b_ptr]], 16\n"
"fmul v16.4s, v4.4s, v0.s[0]\n"

"ld1 {v1.4s}, [%[b_ptr]], 16\n"
"fmla v16.4s, v5.4s, v0.s[1]\n"
"fmul v17.4s, v4.4s, v1.s[0]\n"

"ld1 {v2.4s, v3.4s}, [%[b_ptr]], 32\n"
"fmla v17.4s, v5.4s, v1.s[1]\n"
"fmla v16.4s, v6.4s, v0.s[2]\n"
"fmla v17.4s, v6.4s, v1.s[2]\n"
"fmul v18.4s, v4.4s, v2.s[0]\n"
"fmla v16.4s, v7.4s, v0.s[3]\n"
"fmla v18.4s, v5.4s, v2.s[1]\n"
"fmla v17.4s, v7.4s, v1.s[3]\n"
"fmul v19.4s, v4.4s, v3.s[0]\n"

"ld1 {v24.4s, v25.4s}, [%[b_ptr]], 32\n"
"fmla v18.4s, v7.4s, v2.s[3]\n"
"fmla v19.4s, v5.4s, v3.s[1]\n"
"fmul v20.4s, v4.4s, v24.s[0]\n"
"fmla v19.4s, v6.4s, v3.s[2]\n"

"ld1 {v26.4s, v27.4s}, [%[b_ptr]], %x[LDB]\n"
"fmla v18.4s, v6.4s, v2.s[2]\n"
"fmla v19.4s, v7.4s, v3.s[3]\n"
"fmul v21.4s, v4.4s, v25.s[0]\n"
"fmla v20.4s, v5.4s, v24.s[1]\n"
"fmla v21.4s, v5.4s, v25.s[1]\n"

"ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b_ptr]], 64\n"
"fmla v20.4s, v6.4s, v24.s[2]\n"
"fmul v22.4s, v4.4s, v26.s[0]\n"
"fmla v21.4s, v6.4s, v25.s[2]\n"
"fmla v22.4s, v5.4s, v26.s[1]\n"

"fmla v21.4s, v7.4s, v25.s[3]\n"
"fmul v23.4s, v4.4s, v27.s[0]\n"
"fmla v20.4s, v7.4s, v24.s[3]\n"
"fmla v22.4s, v6.4s, v26.s[2]\n"
"fmla v23.4s, v5.4s, v27.s[1]\n"

"beq 2f\n"

"1:\n"
"ld1 {v4.4s, v5.4s}, [%[a_ptr]], 32\n"
"fmla v22.4s, v7.4s, v26.s[3]\n"
"fmla v23.4s, v6.4s, v27.s[2]\n"
"fmla v16.4s, v4.4s, v0.s[0]\n"
"fmla v17.4s, v4.4s, v1.s[0]\n"
"fmla v23.4s, v7.4s, v27.s[3]\n"

"ld1 {v6.4s, v7.4s}, [%[a_ptr]], 32\n"
"fmla v16.4s, v5.4s, v0.s[1]\n"
"fmla v17.4s, v5.4s, v1.s[1]\n"
"fmla v16.4s, v6.4s, v0.s[2]\n"
"fmla v17.4s, v6.4s, v1.s[2]\n"
"fmla v18.4s, v4.4s, v2.s[0]\n"
"fmla v16.4s, v7.4s, v0.s[3]\n"
"fmla v18.4s, v5.4s, v2.s[1]\n"
"fmla v17.4s, v7.4s, v1.s[3]\n"
"fmla v19.4s, v4.4s, v3.s[0]\n"

"ld1 {v24.4s, v25.4s}, [%[b_ptr]], 32\n"
"fmla v18.4s, v6.4s, v2.s[2]\n"
"fmla v19.4s, v5.4s, v3.s[1]\n"
"fmla v20.4s, v4.4s, v24.s[0]\n"
"fmla v19.4s, v6.4s, v3.s[2]\n"

"ld1 {v26.4s, v27.4s}, [%[b_ptr]], %x[LDB]\n"
"fmla v18.4s, v7.4s, v2.s[3]\n"
"fmla v19.4s, v7.4s, v3.s[3]\n"
"fmla v21.4s, v4.4s, v25.s[0]\n"
"fmla v20.4s, v5.4s, v24.s[1]\n"
"fmla v21.4s, v5.4s, v25.s[1]\n"

"ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b_ptr]], 64\n"
"fmla v20.4s, v6.4s, v24.s[2]\n"
"fmla v22.4s, v4.4s, v26.s[0]\n"
"fmla v20.4s, v7.4s, v24.s[3]\n"
"fmla v23.4s, v4.4s, v27.s[0]\n"
"fmla v21.4s, v6.4s, v25.s[2]\n"
"fmla v22.4s, v5.4s, v26.s[1]\n"
"fmla v21.4s, v7.4s, v25.s[3]\n"
"fmla v23.4s, v5.4s, v27.s[1]\n"
"fmla v22.4s, v6.4s, v26.s[2]\n"

"subs %w[K], %w[K], #4\n"
"bne 1b\n"

"2:\n"
"fmla v22.4s, v7.4s, v26.s[3]\n"
"st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[output]], 64\n"
"fmla v23.4s, v6.4s, v27.s[2]\n"
"fmla v23.4s, v7.4s, v27.s[3]\n"
"st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%[output]], 64\n"

: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[output] "+r"(output), [LDB] "+r"(LDB)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17",
"v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26",
"v27", "cc", "memory");
}

// Overview of register layout:
//
// A 4x1 cell of Rhs is stored in 32bit in v4-v7 (v8-v11 for ping pong)
// A 4x1 cell of Lhs is stored in 32bit in v0-v3
// A 16x1 block of accumulators is stored in 32bit in v16-v31.
//
// Rhs +--------+
// | v4[0-3]|
// | v5[0-3]|
// | v6[0-3]|
// | v7[0-3]|
// +--------+
// Lhs
// +--------+ - - - - -+--------+
// | v0[0-3] | |v16[0-3]|
// | v1[0-3] | |v17[0-3]|
// | v2[0-3] | |v18[0-3]|
// | v3[0-3] | |v19[0-3]|
// | v8[0-3] | |v20[0-3]|
// | v9[0-3] | |v21[0-3]|
// | v10[0-3]| |v22[0-3]|
// | v11[0-3]| |v23[0-3]|
// +--------+ |v24[0-3]|
// |v25[0-3]|
// |v26[0-3]|
// |v27[0-3]|
// |v28[0-3]|
// |v29[0-3]|
// |v30[0-3]|
// |v31[0-3]|
// +--------+
// Accumulator

void kern_4x16(const float* a_ptr, const float* b_ptr, int LDB, int K,
float* output) {
//! As each load 64 number from B, but the pos add 56 * 4, so we minus 56
//! here.
LDB = (LDB - 56) * sizeof(float);

asm volatile(
"stp d8, d9, [sp, #-16]!\n"
"stp d10, d11, [sp, #-16]!\n"

"subs %w[K], %w[K], #4\n"
"ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[a_ptr]], 64\n"
"ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b_ptr]], 64\n"

"fmul v16.4s, v4.4s, v0.s[0]\n"
"fmul v17.4s, v4.4s, v1.s[0]\n"
"fmul v18.4s, v4.4s, v2.s[0]\n"
"fmul v19.4s, v4.4s, v3.s[0]\n"

"fmla v16.4s, v5.4s, v0.s[1]\n"
"fmla v17.4s, v5.4s, v1.s[1]\n"
"fmla v18.4s, v5.4s, v2.s[1]\n"
"fmla v19.4s, v5.4s, v3.s[1]\n"

"ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[b_ptr]], 64\n"

"fmla v16.4s, v6.4s, v0.s[2]\n"
"fmla v17.4s, v6.4s, v1.s[2]\n"
"fmla v18.4s, v6.4s, v2.s[2]\n"
"fmla v19.4s, v6.4s, v3.s[2]\n"

"fmla v16.4s, v7.4s, v0.s[3]\n"
"fmla v17.4s, v7.4s, v1.s[3]\n"
"fmla v18.4s, v7.4s, v2.s[3]\n"
"fmla v19.4s, v7.4s, v3.s[3]\n"

"fmul v20.4s, v4.4s, v8.s[0]\n"
"fmul v21.4s, v4.4s, v9.s[0]\n"
"fmul v22.4s, v4.4s, v10.s[0]\n"
"fmul v23.4s, v4.4s, v11.s[0]\n"

"fmla v20.4s, v5.4s, v8.s[1]\n"
"fmla v21.4s, v5.4s, v9.s[1]\n"
"fmla v22.4s, v5.4s, v10.s[1]\n"
"fmla v23.4s, v5.4s, v11.s[1]\n"

"ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b_ptr]], 64\n"

"fmla v20.4s, v6.4s, v8.s[2]\n"
"fmla v21.4s, v6.4s, v9.s[2]\n"
"fmla v22.4s, v6.4s, v10.s[2]\n"
"fmla v23.4s, v6.4s, v11.s[2]\n"

"fmla v20.4s, v7.4s, v8.s[3]\n"
"fmla v21.4s, v7.4s, v9.s[3]\n"
"fmla v22.4s, v7.4s, v10.s[3]\n"
"fmla v23.4s, v7.4s, v11.s[3]\n"

"fmul v24.4s, v4.4s, v0.s[0]\n"
"fmul v25.4s, v4.4s, v1.s[0]\n"
"fmul v26.4s, v4.4s, v2.s[0]\n"
"fmul v27.4s, v4.4s, v3.s[0]\n"

"fmla v24.4s, v5.4s, v0.s[1]\n"
"fmla v25.4s, v5.4s, v1.s[1]\n"
"fmla v26.4s, v5.4s, v2.s[1]\n"
"fmla v27.4s, v5.4s, v3.s[1]\n"

"ld1 {v8.4s, v9.4s}, [%[b_ptr]], 32\n"

"fmla v24.4s, v6.4s, v0.s[2]\n"
"fmla v25.4s, v6.4s, v1.s[2]\n"
"fmla v26.4s, v6.4s, v2.s[2]\n"
"fmla v27.4s, v6.4s, v3.s[2]\n"

"ld1 {v10.4s, v11.4s}, [%[b_ptr]], %x[LDB]\n"

"fmla v24.4s, v7.4s, v0.s[3]\n"
"fmla v25.4s, v7.4s, v1.s[3]\n"
"fmla v26.4s, v7.4s, v2.s[3]\n"
"fmla v27.4s, v7.4s, v3.s[3]\n"

"fmul v28.4s, v4.4s, v8.s[0]\n"
"fmul v29.4s, v4.4s, v9.s[0]\n"
"fmul v30.4s, v4.4s, v10.s[0]\n"
"fmul v31.4s, v4.4s, v11.s[0]\n"

"beq 2f\n"

"1:\n"

"ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b_ptr]], 64\n"

"fmla v28.4s, v5.4s, v8.s[1]\n"
"fmla v29.4s, v5.4s, v9.s[1]\n"
"fmla v30.4s, v5.4s, v10.s[1]\n"
"fmla v31.4s, v5.4s, v11.s[1]\n"

"ld1 {v4.4s}, [%[a_ptr]], 16\n"

"fmla v28.4s, v6.4s, v8.s[2]\n"
"fmla v29.4s, v6.4s, v9.s[2]\n"
"fmla v30.4s, v6.4s, v10.s[2]\n"
"fmla v31.4s, v6.4s, v11.s[2]\n"

"ld1 {v5.4s}, [%[a_ptr]], 16\n"

"fmla v28.4s, v7.4s, v8.s[3]\n"
"fmla v29.4s, v7.4s, v9.s[3]\n"
"fmla v30.4s, v7.4s, v10.s[3]\n"
"fmla v31.4s, v7.4s, v11.s[3]\n"

"ld1 {v6.4s}, [%[a_ptr]], 16\n"

"fmla v16.4s, v4.4s, v0.s[0]\n"
"fmla v17.4s, v4.4s, v1.s[0]\n"
"fmla v18.4s, v4.4s, v2.s[0]\n"
"fmla v19.4s, v4.4s, v3.s[0]\n"

"ld1 {v7.4s}, [%[a_ptr]], 16\n"

"fmla v16.4s, v5.4s, v0.s[1]\n"
"fmla v17.4s, v5.4s, v1.s[1]\n"
"fmla v18.4s, v5.4s, v2.s[1]\n"
"fmla v19.4s, v5.4s, v3.s[1]\n"

"ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[b_ptr]], 64\n"

"fmla v16.4s, v6.4s, v0.s[2]\n"
"fmla v17.4s, v6.4s, v1.s[2]\n"
"fmla v18.4s, v6.4s, v2.s[2]\n"
"fmla v19.4s, v6.4s, v3.s[2]\n"

"fmla v16.4s, v7.4s, v0.s[3]\n"
"fmla v17.4s, v7.4s, v1.s[3]\n"
"fmla v18.4s, v7.4s, v2.s[3]\n"
"fmla v19.4s, v7.4s, v3.s[3]\n"

"fmla v20.4s, v4.4s, v8.s[0]\n"
"fmla v21.4s, v4.4s, v9.s[0]\n"
"fmla v22.4s, v4.4s, v10.s[0]\n"
"fmla v23.4s, v4.4s, v11.s[0]\n"

"fmla v20.4s, v5.4s, v8.s[1]\n"
"fmla v21.4s, v5.4s, v9.s[1]\n"
"fmla v22.4s, v5.4s, v10.s[1]\n"
"fmla v23.4s, v5.4s, v11.s[1]\n"

"ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b_ptr]], 64\n"

"fmla v20.4s, v6.4s, v8.s[2]\n"
"fmla v21.4s, v6.4s, v9.s[2]\n"
"fmla v22.4s, v6.4s, v10.s[2]\n"
"fmla v23.4s, v6.4s, v11.s[2]\n"

"fmla v20.4s, v7.4s, v8.s[3]\n"
"fmla v21.4s, v7.4s, v9.s[3]\n"
"fmla v22.4s, v7.4s, v10.s[3]\n"
"fmla v23.4s, v7.4s, v11.s[3]\n"

"fmla v24.4s, v4.4s, v0.s[0]\n"
"fmla v25.4s, v4.4s, v1.s[0]\n"
"fmla v26.4s, v4.4s, v2.s[0]\n"
"fmla v27.4s, v4.4s, v3.s[0]\n"

"fmla v24.4s, v5.4s, v0.s[1]\n"
"fmla v25.4s, v5.4s, v1.s[1]\n"
"fmla v26.4s, v5.4s, v2.s[1]\n"
"fmla v27.4s, v5.4s, v3.s[1]\n"

"ld1 {v8.4s, v9.4s}, [%[b_ptr]], 32\n"

"fmla v24.4s, v6.4s, v0.s[2]\n"
"fmla v25.4s, v6.4s, v1.s[2]\n"
"fmla v26.4s, v6.4s, v2.s[2]\n"
"fmla v27.4s, v6.4s, v3.s[2]\n"

"ld1 {v10.4s, v11.4s}, [%[b_ptr]], %x[LDB]\n"

"fmla v24.4s, v7.4s, v0.s[3]\n"
"fmla v25.4s, v7.4s, v1.s[3]\n"
"fmla v26.4s, v7.4s, v2.s[3]\n"
"fmla v27.4s, v7.4s, v3.s[3]\n"

"fmla v28.4s, v4.4s, v8.s[0]\n"
"fmla v29.4s, v4.4s, v9.s[0]\n"
"fmla v30.4s, v4.4s, v10.s[0]\n"
"fmla v31.4s, v4.4s, v11.s[0]\n"

"subs %w[K], %w[K], #4\n"
"bne 1b\n"

"2:\n"

"st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[output]], 64\n"

"fmla v28.4s, v5.4s, v8.s[1]\n"
"fmla v29.4s, v5.4s, v9.s[1]\n"
"fmla v30.4s, v5.4s, v10.s[1]\n"
"fmla v31.4s, v5.4s, v11.s[1]\n"

"st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%[output]], 64\n"

"fmla v28.4s, v6.4s, v8.s[2]\n"
"fmla v29.4s, v6.4s, v9.s[2]\n"
"fmla v30.4s, v6.4s, v10.s[2]\n"
"fmla v31.4s, v6.4s, v11.s[2]\n"

"st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%[output]], 64\n"

"fmla v28.4s, v7.4s, v8.s[3]\n"
"fmla v29.4s, v7.4s, v9.s[3]\n"
"fmla v30.4s, v7.4s, v10.s[3]\n"
"fmla v31.4s, v7.4s, v11.s[3]\n"

"st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%[output]], 64\n"

"ldp d10, d11, [sp], #16\n"
"ldp d8, d9, [sp], #16\n"

: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[output] "+r"(output), [LDB] "+r"(LDB)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23",
"v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc",
"memory");
}

} // namespace

MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(sgemm_nopack_4x16);

void sgemm_nopack_4x16::kern(const float* A, size_t LDA, const float* B,
size_t LDB, float* C, size_t LDC, size_t M,
size_t K, size_t N, const float*, void*, bool trA,
bool trB) const {
constexpr static size_t MB = 4;
constexpr static size_t KB = 4;
constexpr static size_t NB = 16;
constexpr static size_t CALCBLK = 4;

megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0 && N % CALCBLK == 0);

//! (m/4, k/4, 4, 4) * (k/4, n, 4) = (m/4, n, 4)
for (size_t m = 0; m < M; m += MB) {
float* output = C + (m / MB) * LDC;
const float* cur_B = B;
size_t n = 0;
for (; n + NB - 1 < N; n += NB) {
kern_4x16(A, cur_B, LDB, K, output);
cur_B += KB * NB;
output += MB * NB;
}
switch (N - n) {
case 4:
kern_4x4(A, cur_B, LDB, K, output);
break;
case 8:
kern_4x8(A, cur_B, LDB, K, output);
break;
case 12:
kern_4x8(A, cur_B, LDB, K, output);
cur_B += KB * CALCBLK * 2;
output += MB * CALCBLK * 2;
kern_4x4(A, cur_B, LDB, K, output);
break;
default:
break;
}
A += LDA;
}
}

// vim: syntax=cpp.doxygen

+ 1175
- 0
dnn/src/aarch64/matrix_mul/int16/kernel_12x8x1.h
File diff suppressed because it is too large
View File


+ 132
- 0
dnn/src/aarch64/matrix_mul/int16/strategy.cpp View File

@@ -0,0 +1,132 @@
/**
* \file dnn/src/aarch64/matrix_mul/int16/strategy.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/aarch64/matrix_mul/int16/strategy.h"
#include "src/aarch64/matrix_mul/asm/common.h"
#include "src/aarch64/matrix_mul/int16/kernel_12x8x1.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"

using namespace megdnn;
using namespace aarch64;
using namespace aarch64::matmul;

///////////////////////// gemm_s16_12x8x1 ////////////////////////////////////
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s16_12x8x1);

void gemm_s16_12x8x1::pack_A(dt_int16* outptr, const dt_int16* inptr, int ldin,
int y0, int ymax, int k0, int kmax,
bool transpose) const {
if (transpose) {
matmul_12x8x1::gemm_s16_12x8x1_transpose_pack_A_n(outptr, inptr, ldin,
y0, ymax, k0, kmax);
} else {
matmul_12x8x1::gemm_s16_12x8x1_pack_A_n(outptr, inptr, ldin, y0, ymax,
k0, kmax);
}
}

void gemm_s16_12x8x1::pack_B(dt_int16* out, const dt_int16* in, int ldin,
int x0, int xmax, int k0, int kmax,
bool transpose) const {
if (transpose) {
matmul_12x8x1::gemm_s16_12x8x1_transpose_pack_B_n(out, in, ldin, x0,
xmax, k0, kmax);
} else {
matmul_12x8x1::gemm_s16_12x8x1_pack_B_n(out, in, ldin, x0, xmax, k0,
kmax);
}
}

void gemm_s16_12x8x1::kern(const dt_int16* packA, const dt_int16* packB,
size_t M, size_t N, size_t K, dt_int32* C,
size_t LDC, bool is_first_k, const dt_int32*,
dt_int32*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
(A_dtype.enumv() == DTypeEnum::Int16 &&
C_dtype.enumv() == DTypeEnum::Int32),
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
C_dtype.name());
MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype);
MEGDNN_MARK_USED_VAR(C_dtype);

constexpr size_t A_INTERLEAVE = 12;
constexpr size_t B_INTERLEAVE = 8;
const int K12 = K * 12;
const int K8 = K * 8;
const int K4 = K * 4;

size_t m = 0;
for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) {
int32_t* output = C + (m * LDC);

size_t n = 0;
const dt_int16* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_12x8x1::kern_12x8(packA, cur_packB, K, output, LDC,
is_first_k);
output += B_INTERLEAVE;
cur_packB += K8;
}

for (; n < N; n += 4) {
matmul_12x8x1::kern_12x4(packA, cur_packB, K, output, LDC,
is_first_k, std::min<size_t>(N - n, 4));
output += 4;
cur_packB += K4;
}
packA += K12;
}

for (; m + 7 < M; m += 8) {
int32_t* output = C + (m * LDC);
const dt_int16* cur_packB = packB;
size_t n = 0;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_12x8x1::kern_8x8(packA, cur_packB, K, output, LDC,
is_first_k);
output += B_INTERLEAVE;
cur_packB += K8;
}

for (; n < N; n += 4) {
matmul_12x8x1::kern_8x4(packA, cur_packB, K, output, LDC,
is_first_k, std::min<size_t>(N - n, 4));
output += 4;
cur_packB += K4;
}
packA += K8;
}

for (; m < M; m += 4) {
int32_t* output = C + (m * LDC);
const dt_int16* cur_packB = packB;
size_t n = 0;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_12x8x1::kern_4x8(packA, cur_packB, K, output, LDC,
is_first_k, std::min<size_t>(M - m, 4));
output += B_INTERLEAVE;
cur_packB += K8;
}

for (; n < N; n += 4) {
matmul_12x8x1::kern_4x4(packA, cur_packB, K, output, LDC,
is_first_k, std::min<size_t>(M - m, 4),
std::min<size_t>(N - n, 4));
output += 4;
cur_packB += K4;
}
packA += K4;
}
}

// vim: syntax=cpp.doxygen

+ 29
- 0
dnn/src/aarch64/matrix_mul/int16/strategy.h View File

@@ -0,0 +1,29 @@
/**
* \file dnn/src/aarch64/matrix_mul/int16/strategy.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once

#include "src/fallback/matrix_mul/gemm_common.h"

namespace megdnn {
namespace aarch64 {
namespace matmul {

MEGDNN_REG_GEMM_STRATEGY(dt_int16, dt_int32, dt_int32, 12, 8, 1, false, true,
gemm_s16_12x8x1);

MEGDNN_REG_GEMM_STRATEGY_NOPACK(dt_int16, dt_int32, dt_int32, 8, 8, 1, false,
true, gemm_nopack_s16_8x8);

} // namespace matmul
} // namespace aarch64
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 658
- 0
dnn/src/aarch64/matrix_mul/int16/strategy_mk8_8x8.cpp View File

@@ -0,0 +1,658 @@
/**
* \file dnn/src/aarch64/matrix_mul/int16/strategy_mk8_8x8.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/aarch64/matrix_mul/int16/strategy.h"
#include "src/aarch64/matrix_mul/asm/common.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"

using namespace megdnn;
using namespace aarch64;
using namespace aarch64::matmul;

namespace {

// Overview of register layout:
//
// A 8x1 cell of Lhs is stored in 16bit in v24-v27
// A 8x1 cell of Rhs is stored in 16bit in v0-v7
// A 8x2 block of accumulators is stored in 32bit in v16-v23.
//
// Lhs +--------+
// |v0[0-7]|
// |v1[0-7]|
// |v2[0-7]|
// |v3[0-7]|
// +--------+
// Rhs
// +---------+ - - - - -+--------+
// | v24[0-7]| |v16[0-3]|
// | v25[0-7]| |v17[0-3]|
// | v26[0-7]| |v18[0-3]|
// | v27[0-7]| |v19[0-3]|
// | v28[0-7]| |v20[0-3]|
// | v29[0-7]| |v21[0-3]|
// | v30[0-7]| |v22[0-3]|
// | v31[0-7]| |v23[0-3]|
// +---------+ +--------+
// Accumulator
void kern_8x4(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K,
dt_int32* output) {
//! As each load 32 number from B, but the pos add 24 * 2, so we minus 24
//! here.
LDB = (LDB - 24) * sizeof(dt_int16);

asm volatile(
"subs %w[K], %w[K], #8\n"

"ld1 {v24.4s}, [%[a_ptr]], 16\n"
"ld1 {v0.4s}, [%[b_ptr]], 16\n"
"ld1 {v1.4s}, [%[b_ptr]], 16\n"
"ld1 {v2.4s}, [%[b_ptr]], 16\n"
"ld1 {v3.4s}, [%[b_ptr]], %x[LDB]\n"

"smull v16.4s, v24.4h, v0.h[0]\n"
"smull2 v17.4s, v24.8h, v0.h[0]\n"
"smull v18.4s, v24.4h, v1.h[0]\n"
"smull2 v19.4s, v24.8h, v1.h[0]\n"

"ld1 {v25.4s}, [%[a_ptr]], 16\n"

"smull v20.4s, v24.4h, v2.h[0]\n"
"smull2 v21.4s, v24.8h, v2.h[0]\n"
"smull v22.4s, v24.4h, v3.h[0]\n"
"smull2 v23.4s, v24.8h, v3.h[0]\n"

"smlal v16.4s, v25.4h, v0.h[1]\n"
"smlal2 v17.4s, v25.8h, v0.h[1]\n"
"smlal v18.4s, v25.4h, v1.h[1]\n"
"smlal2 v19.4s, v25.8h, v1.h[1]\n"

"ld1 {v26.4s}, [%[a_ptr]], 16\n"

"smlal v20.4s, v25.4h, v2.h[1]\n"
"smlal2 v21.4s, v25.8h, v2.h[1]\n"
"smlal v22.4s, v25.4h, v3.h[1]\n"
"smlal2 v23.4s, v25.8h, v3.h[1]\n"

"smlal v16.4s, v26.4h, v0.h[2]\n"
"smlal2 v17.4s, v26.8h, v0.h[2]\n"
"smlal v18.4s, v26.4h, v1.h[2]\n"
"smlal2 v19.4s, v26.8h, v1.h[2]\n"

"ld1 {v27.4s}, [%[a_ptr]], 16\n"

"smlal v20.4s, v26.4h, v2.h[2]\n"
"smlal2 v21.4s, v26.8h, v2.h[2]\n"
"smlal v22.4s, v26.4h, v3.h[2]\n"
"smlal2 v23.4s, v26.8h, v3.h[2]\n"

"smlal v16.4s, v27.4h, v0.h[3]\n"
"smlal2 v17.4s, v27.8h, v0.h[3]\n"
"smlal v18.4s, v27.4h, v1.h[3]\n"
"smlal2 v19.4s, v27.8h, v1.h[3]\n"

"ld1 {v28.4s}, [%[a_ptr]], 16\n"

"smlal v20.4s, v27.4h, v2.h[3]\n"
"smlal2 v21.4s, v27.8h, v2.h[3]\n"
"smlal v22.4s, v27.4h, v3.h[3]\n"
"smlal2 v23.4s, v27.8h, v3.h[3]\n"

"smlal v16.4s, v28.4h, v0.h[4]\n"
"smlal2 v17.4s, v28.8h, v0.h[4]\n"
"smlal v18.4s, v28.4h, v1.h[4]\n"
"smlal2 v19.4s, v28.8h, v1.h[4]\n"

"ld1 {v29.4s}, [%[a_ptr]], 16\n"

"smlal v20.4s, v28.4h, v2.h[4]\n"
"smlal2 v21.4s, v28.8h, v2.h[4]\n"
"smlal v22.4s, v28.4h, v3.h[4]\n"
"smlal2 v23.4s, v28.8h, v3.h[4]\n"

"smlal v16.4s, v29.4h, v0.h[5]\n"
"smlal2 v17.4s, v29.8h, v0.h[5]\n"
"smlal v18.4s, v29.4h, v1.h[5]\n"
"smlal2 v19.4s, v29.8h, v1.h[5]\n"

"ld1 {v30.4s}, [%[a_ptr]], 16\n"

"smlal v20.4s, v29.4h, v2.h[5]\n"
"smlal2 v21.4s, v29.8h, v2.h[5]\n"
"smlal v22.4s, v29.4h, v3.h[5]\n"
"smlal2 v23.4s, v29.8h, v3.h[5]\n"

"smlal v16.4s, v30.4h, v0.h[6]\n"
"smlal2 v17.4s, v30.8h, v0.h[6]\n"
"smlal v18.4s, v30.4h, v1.h[6]\n"
"smlal2 v19.4s, v30.8h, v1.h[6]\n"

"ld1 {v31.4s}, [%[a_ptr]], 16\n"

"smlal v20.4s, v30.4h, v2.h[6]\n"
"smlal2 v21.4s, v30.8h, v2.h[6]\n"
"smlal v22.4s, v30.4h, v3.h[6]\n"
"smlal2 v23.4s, v30.8h, v3.h[6]\n"

"beq 2f\n"

"1:\n"

"ld1 {v24.4s}, [%[a_ptr]], 16\n"

"smlal v16.4s, v31.4h, v0.h[7]\n"
"smlal2 v17.4s, v31.8h, v0.h[7]\n"

"ld1 {v0.4s}, [%[b_ptr]], 16\n"

"smlal v18.4s, v31.4h, v1.h[7]\n"
"smlal2 v19.4s, v31.8h, v1.h[7]\n"

"ld1 {v1.4s}, [%[b_ptr]], 16\n"

"smlal v20.4s, v31.4h, v2.h[7]\n"
"smlal2 v21.4s, v31.8h, v2.h[7]\n"

"ld1 {v2.4s}, [%[b_ptr]], 16\n"

"smlal v22.4s, v31.4h, v3.h[7]\n"
"smlal2 v23.4s, v31.8h, v3.h[7]\n"

"ld1 {v3.4s}, [%[b_ptr]], %x[LDB]\n"

"smlal v16.4s, v24.4h, v0.h[0]\n"
"smlal2 v17.4s, v24.8h, v0.h[0]\n"
"smlal v18.4s, v24.4h, v1.h[0]\n"
"smlal2 v19.4s, v24.8h, v1.h[0]\n"

"ld1 {v25.4s}, [%[a_ptr]], 16\n"

"smlal v20.4s, v24.4h, v2.h[0]\n"
"smlal2 v21.4s, v24.8h, v2.h[0]\n"
"smlal v22.4s, v24.4h, v3.h[0]\n"
"smlal2 v23.4s, v24.8h, v3.h[0]\n"

"smlal v16.4s, v25.4h, v0.h[1]\n"
"smlal2 v17.4s, v25.8h, v0.h[1]\n"
"smlal v18.4s, v25.4h, v1.h[1]\n"
"smlal2 v19.4s, v25.8h, v1.h[1]\n"

"ld1 {v26.4s}, [%[a_ptr]], 16\n"

"smlal v20.4s, v25.4h, v2.h[1]\n"
"smlal2 v21.4s, v25.8h, v2.h[1]\n"
"smlal v22.4s, v25.4h, v3.h[1]\n"
"smlal2 v23.4s, v25.8h, v3.h[1]\n"

"smlal v16.4s, v26.4h, v0.h[2]\n"
"smlal2 v17.4s, v26.8h, v0.h[2]\n"
"smlal v18.4s, v26.4h, v1.h[2]\n"
"smlal2 v19.4s, v26.8h, v1.h[2]\n"

"ld1 {v27.4s}, [%[a_ptr]], 16\n"

"smlal v20.4s, v26.4h, v2.h[2]\n"
"smlal2 v21.4s, v26.8h, v2.h[2]\n"
"smlal v22.4s, v26.4h, v3.h[2]\n"
"smlal2 v23.4s, v26.8h, v3.h[2]\n"

"smlal v16.4s, v27.4h, v0.h[3]\n"
"smlal2 v17.4s, v27.8h, v0.h[3]\n"
"smlal v18.4s, v27.4h, v1.h[3]\n"
"smlal2 v19.4s, v27.8h, v1.h[3]\n"

"ld1 {v28.4s}, [%[a_ptr]], 16\n"

"smlal v20.4s, v27.4h, v2.h[3]\n"
"smlal2 v21.4s, v27.8h, v2.h[3]\n"
"smlal v22.4s, v27.4h, v3.h[3]\n"
"smlal2 v23.4s, v27.8h, v3.h[3]\n"

"smlal v16.4s, v28.4h, v0.h[4]\n"
"smlal2 v17.4s, v28.8h, v0.h[4]\n"
"smlal v18.4s, v28.4h, v1.h[4]\n"
"smlal2 v19.4s, v28.8h, v1.h[4]\n"

"ld1 {v29.4s}, [%[a_ptr]], 16\n"

"smlal v20.4s, v28.4h, v2.h[4]\n"
"smlal2 v21.4s, v28.8h, v2.h[4]\n"
"smlal v22.4s, v28.4h, v3.h[4]\n"
"smlal2 v23.4s, v28.8h, v3.h[4]\n"

"smlal v16.4s, v29.4h, v0.h[5]\n"
"smlal2 v17.4s, v29.8h, v0.h[5]\n"
"smlal v18.4s, v29.4h, v1.h[5]\n"
"smlal2 v19.4s, v29.8h, v1.h[5]\n"

"ld1 {v30.4s}, [%[a_ptr]], 16\n"

"smlal v20.4s, v29.4h, v2.h[5]\n"
"smlal2 v21.4s, v29.8h, v2.h[5]\n"
"smlal v22.4s, v29.4h, v3.h[5]\n"
"smlal2 v23.4s, v29.8h, v3.h[5]\n"

"smlal v16.4s, v30.4h, v0.h[6]\n"
"smlal2 v17.4s, v30.8h, v0.h[6]\n"
"smlal v18.4s, v30.4h, v1.h[6]\n"
"smlal2 v19.4s, v30.8h, v1.h[6]\n"

"ld1 {v31.4s}, [%[a_ptr]], 16\n"

"smlal v20.4s, v30.4h, v2.h[6]\n"
"smlal2 v21.4s, v30.8h, v2.h[6]\n"
"smlal v22.4s, v30.4h, v3.h[6]\n"
"smlal2 v23.4s, v30.8h, v3.h[6]\n"

"subs %w[K], %w[K], #8\n"
"bne 1b\n"

"2:\n"

"smlal v16.4s, v31.4h, v0.h[7]\n"
"smlal2 v17.4s, v31.8h, v0.h[7]\n"
"smlal v18.4s, v31.4h, v1.h[7]\n"
"smlal2 v19.4s, v31.8h, v1.h[7]\n"
"smlal v20.4s, v31.4h, v2.h[7]\n"
"smlal2 v21.4s, v31.8h, v2.h[7]\n"
"smlal v22.4s, v31.4h, v3.h[7]\n"
"smlal2 v23.4s, v31.8h, v3.h[7]\n"

"st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[output]], 64\n"
"st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%[output]], 64\n"

: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[output] "+r"(output), [LDB] "+r"(LDB)
:
: "v0", "v1", "v2", "v3", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
"v29", "v30", "v31", "cc", "memory");
}

// Overview of register layout:
//
// A 8x1 cell of Rhs is stored in 16bit in v8-v15
// A 8x1 cell of Lhs is stored in 16bit in v0-v7
// A 8x2 block of accumulators is stored in 32bit in v16-v31.
//
// Rhs +--------+
// | v8[0-7]|
// | v9[0-7]|
// |v10[0-7]|
// |v11[0-7]|
// |v12[0-7]|
// |v13[0-7]|
// |v14[0-7]|
// |v15[0-7]|
// +--------+
// Lhs
// +--------+ - - - - -+--------+--------+
// | v0[0-7]| |v16[0-3]|v17[0-3]|
// | v1[0-7]| |v18[0-3]|v19[0-3]|
// | v2[0-7]| |v20[0-3]|v21[0-3]|
// | v3[0-7]| |v22[0-3]|v23[0-3]|
// | v4[0-7]| |v24[0-3]|v25[0-3]|
// | v5[0-7]| |v26[0-3]|v27[0-3]|
// | v6[0-7]| |v28[0-3]|v29[0-3]|
// | v7[0-7]| |v30[0-3]|v31[0-3]|
// +--------+ +--------+--------+
// Accumulator
void kern_8x8(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K,
dt_int32* output) {
//! As each load 64 number from B, but the pos add 48 * 2, so we minus 48
//! here.
LDB = (LDB - 48) * sizeof(dt_int16);

asm volatile(
"subs %w[K], %w[K], #8\n"
"ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[a_ptr]], 64\n"
"ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%[a_ptr]], 64\n"

"ld1 {v0.4s}, [%[b_ptr]], 16\n"
"smull v16.4s, v8.4h, v0.h[0]\n"
"ld1 {v1.4s}, [%[b_ptr]], 16\n"
"smlal v16.4s, v9.4h, v0.h[1]\n"
"smull v18.4s, v8.4h, v1.h[0]\n"
"smull2 v17.4s, v8.8h, v0.h[0]\n"
"smull2 v19.4s, v8.8h, v1.h[0]\n"
"smlal v16.4s, v10.4h, v0.h[2]\n"
"smlal v18.4s, v9.4h, v1.h[1]\n"
"smlal2 v17.4s, v9.8h, v0.h[1]\n"
"smlal2 v19.4s, v9.8h, v1.h[1]\n"
"smlal v16.4s, v11.4h, v0.h[3]\n"
"smlal v18.4s, v10.4h, v1.h[2]\n"
"smlal2 v17.4s, v10.8h, v0.h[2]\n"
"smlal2 v19.4s, v10.8h, v1.h[2]\n"
"smlal v16.4s, v12.4h, v0.h[4]\n"
"smlal v18.4s, v11.4h, v1.h[3]\n"
"smlal2 v17.4s, v11.8h, v0.h[3]\n"
"smlal2 v19.4s, v11.8h, v1.h[3]\n"
"smlal v16.4s, v13.4h, v0.h[5]\n"
"smlal v18.4s, v12.4h, v1.h[4]\n"
"smlal2 v17.4s, v12.8h, v0.h[4]\n"
"smlal2 v19.4s, v12.8h, v1.h[4]\n"
"smlal2 v17.4s, v13.8h, v0.h[5]\n"

"ld1 {v2.4s, v3.4s}, [%[b_ptr]], 32\n"
"smlal v16.4s, v14.4h, v0.h[6]\n"
"smlal v18.4s, v13.4h, v1.h[5]\n"
"smlal2 v17.4s, v14.8h, v0.h[6]\n"
"smlal2 v19.4s, v13.8h, v1.h[5]\n"
"smull v20.4s, v8.4h, v2.h[0]\n"
"smull v22.4s, v8.4h, v3.h[0]\n"
"smull2 v21.4s, v8.8h, v2.h[0]\n"
"smull2 v23.4s, v8.8h, v3.h[0]\n"
"smlal v16.4s, v15.4h, v0.h[7]\n"
"smlal v18.4s, v14.4h, v1.h[6]\n"
"smlal2 v17.4s, v15.8h, v0.h[7]\n"
"smlal2 v19.4s, v14.8h, v1.h[6]\n"
"smlal v20.4s, v9.4h, v2.h[1]\n"
"smlal v22.4s, v9.4h, v3.h[1]\n"
"smlal2 v21.4s, v9.8h, v2.h[1]\n"
"smlal2 v23.4s, v9.8h, v3.h[1]\n"
"smlal v18.4s, v15.4h, v1.h[7]\n"
"smlal v20.4s, v10.4h, v2.h[2]\n"
"smlal v22.4s, v10.4h, v3.h[2]\n"
"smlal2 v21.4s, v10.8h, v2.h[2]\n"
"smlal2 v23.4s, v10.8h, v3.h[2]\n"
"smlal2 v19.4s, v15.8h, v1.h[7]\n"
"smlal v20.4s, v11.4h, v2.h[3]\n"
"smlal v22.4s, v11.4h, v3.h[3]\n"
"smlal2 v21.4s, v11.8h, v2.h[3]\n"
"smlal2 v23.4s, v11.8h, v3.h[3]\n"
"smlal v20.4s, v12.4h, v2.h[4]\n"
"smlal v22.4s, v12.4h, v3.h[4]\n"
"smlal2 v21.4s, v12.8h, v2.h[4]\n"
"smlal2 v23.4s, v12.8h, v3.h[4]\n"
"smlal v20.4s, v13.4h, v2.h[5]\n"
"smlal v22.4s, v13.4h, v3.h[5]\n"
"smlal2 v21.4s, v13.8h, v2.h[5]\n"
"smlal2 v23.4s, v13.8h, v3.h[5]\n"

"ld1 {v4.4s, v5.4s}, [%[b_ptr]], 32\n"
"smlal v20.4s, v14.4h, v2.h[6]\n"
"smlal v22.4s, v14.4h, v3.h[6]\n"
"smlal2 v21.4s, v14.8h, v2.h[6]\n"
"smlal2 v23.4s, v14.8h, v3.h[6]\n"
"smull v24.4s, v8.4h, v4.h[0]\n"
"smull v26.4s, v8.4h, v5.h[0]\n"
"smull2 v25.4s, v8.8h, v4.h[0]\n"
"smull2 v27.4s, v8.8h, v5.h[0]\n"
"smlal v20.4s, v15.4h, v2.h[7]\n"
"smlal v22.4s, v15.4h, v3.h[7]\n"
"smlal2 v21.4s, v15.8h, v2.h[7]\n"
"smlal2 v23.4s, v15.8h, v3.h[7]\n"
"smlal v24.4s, v9.4h, v4.h[1]\n"
"smlal v26.4s, v9.4h, v5.h[1]\n"
"smlal2 v25.4s, v9.8h, v4.h[1]\n"
"smlal2 v27.4s, v9.8h, v5.h[1]\n"
"smlal v24.4s, v10.4h, v4.h[2]\n"
"smlal v26.4s, v10.4h, v5.h[2]\n"
"smlal2 v25.4s, v10.8h, v4.h[2]\n"
"smlal2 v27.4s, v10.8h, v5.h[2]\n"
"smlal v24.4s, v11.4h, v4.h[3]\n"
"smlal v26.4s, v11.4h, v5.h[3]\n"
"smlal2 v25.4s, v11.8h, v4.h[3]\n"
"smlal2 v27.4s, v11.8h, v5.h[3]\n"
"smlal v24.4s, v12.4h, v4.h[4]\n"
"smlal v26.4s, v12.4h, v5.h[4]\n"
"smlal2 v25.4s, v12.8h, v4.h[4]\n"
"smlal2 v27.4s, v12.8h, v5.h[4]\n"
"smlal v24.4s, v13.4h, v4.h[5]\n"
"smlal v26.4s, v13.4h, v5.h[5]\n"
"smlal2 v25.4s, v13.8h, v4.h[5]\n"
"smlal2 v27.4s, v13.8h, v5.h[5]\n"

"ld1 {v6.4s, v7.4s}, [%[b_ptr]], %x[LDB]\n"
"smlal v24.4s, v14.4h, v4.h[6]\n"
"smlal v26.4s, v14.4h, v5.h[6]\n"
"smlal2 v25.4s, v14.8h, v4.h[6]\n"
"smlal2 v27.4s, v14.8h, v5.h[6]\n"
"smull v28.4s, v8.4h, v6.h[0]\n"
"smull v30.4s, v8.4h, v7.h[0]\n"
"smull2 v29.4s, v8.8h, v6.h[0]\n"
"smull2 v31.4s, v8.8h, v7.h[0]\n"
"smlal v28.4s, v9.4h, v6.h[1]\n"
"smlal v30.4s, v9.4h, v7.h[1]\n"
"smlal2 v29.4s, v9.8h, v6.h[1]\n"
"smlal2 v31.4s, v9.8h, v7.h[1]\n"
"smlal v28.4s, v10.4h, v6.h[2]\n"
"smlal v30.4s, v10.4h, v7.h[2]\n"
"smlal2 v29.4s, v10.8h, v6.h[2]\n"
"smlal2 v31.4s, v10.8h, v7.h[2]\n"
"smlal v28.4s, v11.4h, v6.h[3]\n"
"smlal v30.4s, v11.4h, v7.h[3]\n"
"smlal2 v29.4s, v11.8h, v6.h[3]\n"
"smlal2 v31.4s, v11.8h, v7.h[3]\n"
"smlal v28.4s, v12.4h, v6.h[4]\n"
"smlal v30.4s, v12.4h, v7.h[4]\n"
"smlal2 v29.4s, v12.8h, v6.h[4]\n"
"smlal2 v31.4s, v12.8h, v7.h[4]\n"
"smlal v28.4s, v13.4h, v6.h[5]\n"
"smlal v30.4s, v13.4h, v7.h[5]\n"
"smlal2 v29.4s, v13.8h, v6.h[5]\n"
"smlal2 v31.4s, v13.8h, v7.h[5]\n"

"beq 2f\n"

"1:\n"

"smlal v24.4s, v15.4h, v4.h[7]\n"
"smlal v26.4s, v15.4h, v5.h[7]\n"
"smlal2 v25.4s, v15.8h, v4.h[7]\n"

"ld1 {v8.4s, v9.4s}, [%[a_ptr]], 32\n"
"smlal2 v27.4s, v15.8h, v5.h[7]\n"
"smlal v28.4s, v14.4h, v6.h[6]\n"
"smlal v30.4s, v14.4h, v7.h[6]\n"

"ld1 {v10.4s, v11.4s}, [%[a_ptr]], 32\n"
"smlal2 v29.4s, v15.8h, v6.h[7]\n"
"smlal2 v31.4s, v14.8h, v7.h[6]\n"
"smlal v28.4s, v15.4h, v6.h[7]\n"

"ld1 {v12.4s, v13.4s}, [%[a_ptr]], 32\n"
"smlal v30.4s, v15.4h, v7.h[7]\n"
"smlal2 v29.4s, v14.8h, v6.h[6]\n"

"ld1 {v0.4s}, [%[b_ptr]], 16\n"
"smlal2 v31.4s, v15.8h, v7.h[7]\n"
"smlal v16.4s, v8.4h, v0.h[0]\n"

"ld1 {v1.4s}, [%[b_ptr]], 16\n"
"smlal v16.4s, v9.4h, v0.h[1]\n"
"smlal2 v17.4s, v8.8h, v0.h[0]\n"
"smlal v16.4s, v10.4h, v0.h[2]\n"
"smlal v18.4s, v8.4h, v1.h[0]\n"
"smlal2 v17.4s, v9.8h, v0.h[1]\n"
"smlal2 v19.4s, v8.8h, v1.h[0]\n"

"ld1 {v14.4s, v15.4s}, [%[a_ptr]], 32\n"
"smlal v16.4s, v11.4h, v0.h[3]\n"
"smlal v18.4s, v9.4h, v1.h[1]\n"
"smlal2 v17.4s, v10.8h, v0.h[2]\n"
"smlal2 v19.4s, v9.8h, v1.h[1]\n"
"smlal v16.4s, v12.4h, v0.h[4]\n"
"smlal v18.4s, v10.4h, v1.h[2]\n"
"smlal2 v17.4s, v11.8h, v0.h[3]\n"
"smlal2 v19.4s, v10.8h, v1.h[2]\n"
"smlal v16.4s, v13.4h, v0.h[5]\n"
"smlal v18.4s, v11.4h, v1.h[3]\n"
"smlal2 v17.4s, v12.8h, v0.h[4]\n"
"smlal2 v19.4s, v11.8h, v1.h[3]\n"
"smlal v16.4s, v14.4h, v0.h[6]\n"
"smlal v18.4s, v12.4h, v1.h[4]\n"
"smlal2 v17.4s, v13.8h, v0.h[5]\n"
"smlal2 v19.4s, v12.8h, v1.h[4]\n"
"smlal v16.4s, v15.4h, v0.h[7]\n"
"smlal v18.4s, v13.4h, v1.h[5]\n"
"smlal2 v17.4s, v14.8h, v0.h[6]\n"
"smlal2 v19.4s, v13.8h, v1.h[5]\n"

"ld1 {v2.4s, v3.4s}, [%[b_ptr]], 32\n"
"smlal v18.4s, v14.4h, v1.h[6]\n"
"smlal2 v17.4s, v15.8h, v0.h[7]\n"
"smlal2 v19.4s, v14.8h, v1.h[6]\n"
"smlal v20.4s, v8.4h, v2.h[0]\n"
"smlal v22.4s, v8.4h, v3.h[0]\n"
"smlal2 v21.4s, v8.8h, v2.h[0]\n"
"smlal2 v23.4s, v8.8h, v3.h[0]\n"
"smlal v18.4s, v15.4h, v1.h[7]\n"
"smlal v20.4s, v9.4h, v2.h[1]\n"
"smlal v22.4s, v9.4h, v3.h[1]\n"
"smlal2 v21.4s, v9.8h, v2.h[1]\n"
"smlal2 v23.4s, v9.8h, v3.h[1]\n"
"smlal2 v19.4s, v15.8h, v1.h[7]\n"
"smlal v20.4s, v10.4h, v2.h[2]\n"
"smlal v22.4s, v10.4h, v3.h[2]\n"
"smlal2 v21.4s, v10.8h, v2.h[2]\n"
"smlal2 v23.4s, v10.8h, v3.h[2]\n"
"smlal v20.4s, v11.4h, v2.h[3]\n"
"smlal v22.4s, v11.4h, v3.h[3]\n"
"smlal2 v21.4s, v11.8h, v2.h[3]\n"
"smlal2 v23.4s, v11.8h, v3.h[3]\n"
"smlal v20.4s, v12.4h, v2.h[4]\n"
"smlal v22.4s, v12.4h, v3.h[4]\n"
"smlal2 v21.4s, v12.8h, v2.h[4]\n"
"smlal2 v23.4s, v12.8h, v3.h[4]\n"
"smlal v20.4s, v13.4h, v2.h[5]\n"
"smlal v22.4s, v13.4h, v3.h[5]\n"
"smlal2 v21.4s, v13.8h, v2.h[5]\n"
"smlal2 v23.4s, v13.8h, v3.h[5]\n"

"ld1 {v4.4s, v5.4s}, [%[b_ptr]], 32\n"
"smlal v20.4s, v14.4h, v2.h[6]\n"
"smlal v22.4s, v14.4h, v3.h[6]\n"
"smlal2 v21.4s, v14.8h, v2.h[6]\n"
"smlal2 v23.4s, v14.8h, v3.h[6]\n"
"smlal v24.4s, v8.4h, v4.h[0]\n"
"smlal v26.4s, v8.4h, v5.h[0]\n"
"smlal2 v25.4s, v8.8h, v4.h[0]\n"
"smlal2 v27.4s, v8.8h, v5.h[0]\n"
"smlal v20.4s, v15.4h, v2.h[7]\n"
"smlal2 v21.4s, v15.8h, v2.h[7]\n"
"smlal v22.4s, v15.4h, v3.h[7]\n"
"smlal2 v23.4s, v15.8h, v3.h[7]\n"
"smlal v24.4s, v9.4h, v4.h[1]\n"
"smlal v26.4s, v9.4h, v5.h[1]\n"
"smlal2 v25.4s, v9.8h, v4.h[1]\n"
"smlal2 v27.4s, v9.8h, v5.h[1]\n"
"smlal v24.4s, v10.4h, v4.h[2]\n"
"smlal v26.4s, v10.4h, v5.h[2]\n"
"smlal2 v25.4s, v10.8h, v4.h[2]\n"
"smlal2 v27.4s, v10.8h, v5.h[2]\n"
"smlal v24.4s, v11.4h, v4.h[3]\n"
"smlal v26.4s, v11.4h, v5.h[3]\n"
"smlal2 v25.4s, v11.8h, v4.h[3]\n"
"smlal2 v27.4s, v11.8h, v5.h[3]\n"
"smlal v24.4s, v12.4h, v4.h[4]\n"
"smlal v26.4s, v12.4h, v5.h[4]\n"
"smlal2 v25.4s, v12.8h, v4.h[4]\n"
"smlal2 v27.4s, v12.8h, v5.h[4]\n"
"smlal v24.4s, v13.4h, v4.h[5]\n"
"smlal v26.4s, v13.4h, v5.h[5]\n"
"smlal2 v25.4s, v13.8h, v4.h[5]\n"
"smlal2 v27.4s, v13.8h, v5.h[5]\n"

"ld1 {v6.4s, v7.4s}, [%[b_ptr]], %x[LDB]\n"
"smlal v24.4s, v14.4h, v4.h[6]\n"
"smlal v26.4s, v14.4h, v5.h[6]\n"
"smlal2 v25.4s, v14.8h, v4.h[6]\n"
"smlal2 v27.4s, v14.8h, v5.h[6]\n"
"smlal v28.4s, v8.4h, v6.h[0]\n"
"smlal v30.4s, v8.4h, v7.h[0]\n"
"smlal2 v29.4s, v8.8h, v6.h[0]\n"
"smlal2 v31.4s, v8.8h, v7.h[0]\n"
"smlal v28.4s, v9.4h, v6.h[1]\n"
"smlal v30.4s, v9.4h, v7.h[1]\n"
"smlal2 v29.4s, v9.8h, v6.h[1]\n"
"smlal2 v31.4s, v9.8h, v7.h[1]\n"
"smlal v28.4s, v10.4h, v6.h[2]\n"
"smlal v30.4s, v10.4h, v7.h[2]\n"
"smlal2 v29.4s, v10.8h, v6.h[2]\n"
"smlal2 v31.4s, v10.8h, v7.h[2]\n"
"smlal v28.4s, v11.4h, v6.h[3]\n"
"smlal v30.4s, v11.4h, v7.h[3]\n"
"smlal2 v29.4s, v11.8h, v6.h[3]\n"
"smlal2 v31.4s, v11.8h, v7.h[3]\n"
"smlal v28.4s, v12.4h, v6.h[4]\n"
"smlal v30.4s, v12.4h, v7.h[4]\n"
"smlal2 v29.4s, v12.8h, v6.h[4]\n"
"smlal2 v31.4s, v12.8h, v7.h[4]\n"
"smlal v28.4s, v13.4h, v6.h[5]\n"
"smlal v30.4s, v13.4h, v7.h[5]\n"
"smlal2 v29.4s, v13.8h, v6.h[5]\n"
"smlal2 v31.4s, v13.8h, v7.h[5]\n"

"subs %w[K], %w[K], #8\n"
"bne 1b\n"

"2:\n"
"st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[output]], 64\n"
"smlal v24.4s, v15.4h, v4.h[7]\n"
"smlal v28.4s, v14.4h, v6.h[6]\n"
"smlal v30.4s, v14.4h, v7.h[6]\n"
"smlal v26.4s, v15.4h, v5.h[7]\n"
"smlal2 v25.4s, v15.8h, v4.h[7]\n"
"smlal2 v27.4s, v15.8h, v5.h[7]\n"
"smlal2 v29.4s, v14.8h, v6.h[6]\n"
"smlal2 v31.4s, v14.8h, v7.h[6]\n"
"st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%[output]], 64\n"
"smlal v28.4s, v15.4h, v6.h[7]\n"
"smlal v30.4s, v15.4h, v7.h[7]\n"
"smlal2 v29.4s, v15.8h, v6.h[7]\n"
"smlal2 v31.4s, v15.8h, v7.h[7]\n"
"st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%[output]], 64\n"
"st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%[output]], 64\n"

: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[output] "+r"(output), [LDB] "+r"(LDB)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
"v29", "v30", "v31", "cc", "memory");
}

} // anonymous namespace

MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(gemm_nopack_s16_8x8);

void gemm_nopack_s16_8x8::kern(const dt_int16* A, size_t LDA, const dt_int16* B,
size_t LDB, dt_int32* C, size_t LDC, size_t M,
size_t K, size_t N, const dt_int32*, void*,
bool trA, bool trB) const {
constexpr static size_t MB = 8;
constexpr static size_t KB = 8;
constexpr static size_t NB = 8;
constexpr static size_t CALCBLK = 4;

megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0 && N % CALCBLK == 0);

//! (m/8, k/8, 8, 8) * (k/8, n, 8) = (m/8, n, 8)
for (size_t m = 0; m < M; m += MB) {
dt_int32* output = C + (m / MB) * LDC;
const dt_int16* cur_B = B;
size_t n = 0;
for (; n + NB - 1 < N; n += NB) {
kern_8x8(A, cur_B, LDB, K, output);
cur_B += KB * NB;
output += MB * NB;
}
if (n < N) {
kern_8x4(A, cur_B, LDB, K, output);
}
A += LDA;
}
}

// vim: syntax=cpp.doxygen

+ 856
- 0
dnn/src/aarch64/matrix_mul/int8/kernel_4x4x16.h View File

@@ -0,0 +1,856 @@
/**
* \file dnn/src/aarch64/matrix_mul/int8/kernel_4x4x16.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#if !(__ARM_FEATURE_DOTPROD)
#include "src/aarch64/matrix_mul/asm/common.h"
#include "src/arm_common/simd_macro/marm_neon.h"

namespace megdnn {
namespace aarch64 {
namespace matmul_4x4x16 {

/**
* Overview of register layout:
*
* A 16x2 cell of Rhs is stored in 8bit in q2-q4.
* A 8x2x2 cell of Lhs is stored in 8bit in q0-q1
* A 8x16 block of accumulators is stored in 8bit in q8--q31.
*
* \warning Fast kernel operating on int8 operands.
* It is assumed that one of the two int8 operands only takes values
* in [-127, 127], while the other may freely range in [-128, 127].
* The issue with both operands taking the value -128 is that:
* -128*-128 + -128*-128 == -32768 overflows int16.
* Every other expression a*b + c*d, for any int8 a,b,c,d, fits in int16
* range. That is the basic idea of this kernel.
*
*
* +--------+--------+---------+---------+
* |v4[0-16]|v5[0-16]| v6[0-16]| v7[0-16]|
* Rhs +--------+--------+---------+---------+
* |v8[0-16]|v9[0-16]|v10[0-16]|v11[0-16]|
* +--------+--------+---------+---------+
* | | | | |
*
* Lhs | | | | |
*
* +--------+ - - - - +-------------------------------------+
* |v0[0-16]| |v16[0-4]|v17[0-4]| v18[0-4]| v19[0-4]|
* |v1[0-16]| |v20[0-4]|v21[0-4]| v22[0-4]| v23[0-4]|
* |v2[0-16]| |v24[0-4]|v25[0-4]| v26[0-4]| v27[0-4]|
* |v3[0-16]| |v28[0-4]|v29[0-4]| v30[0-4]| v31[0-4]|
* +--------+ - - - - +-------------------------------------+
*
* Accumulator
*/

static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
int32_t* output, int LDC, bool is_first_k) {
K /= 16;
const int8_t* a_ptr = packA;
const int8_t* b_ptr = packB;
// Fix up for odd lengths - set a flag if K is odd, but make
// sure we round up the iteration count.
int oddk = (K & 1);
int k = ((K + 1) / 2) - 1;

LDC = LDC * sizeof(int32_t);

asm volatile (
// load accumulator C
"add x1, %[output], %x[LDC]\n"
"add x2, x1, %x[LDC]\n"
"add x3, x2, %x[LDC]\n"
"cmp %w[is_first_k], #1\n"
"beq 1f\n"

"ldr q16, [%[output]]\n"
"ldr q17, [x1]\n"
"ldr q18, [x2]\n"
"ldr q19, [x3]\n"
"b 2f\n"

"1:\n"
"eor v16.16b, v16.16b, v16.16b\n"
"eor v17.16b, v17.16b, v17.16b\n"
"eor v18.16b, v18.16b, v18.16b\n"
"eor v19.16b, v19.16b, v19.16b\n"

"2: \n"
"ldr q0, [%[a_ptr]]\n"
"ldr q4, [%[b_ptr]]\n"
"ldr q5, [%[b_ptr], #16]\n"
"ldr q6, [%[b_ptr], #32]\n"
"movi v20.4s, #0x0\n"
"ldr q7, [%[b_ptr], #48]\n"
"movi v21.4s, #0x0\n"
"ldr q1, [%[a_ptr], #16]\n"
"movi v22.4s, #0x0\n"
"ldr q2, [%[a_ptr], #32]\n"
"movi v23.4s, #0x0\n"
"ldr q3, [%[a_ptr], #48]\n"
"movi v24.4s, #0x0\n"
ASM_PREFETCH("[%[b_ptr], #64]")
"movi v25.4s, #0x0\n"
ASM_PREFETCH("[%[a_ptr], #64]")
"movi v26.4s, #0x0\n"
ASM_PREFETCH("[%[b_ptr], #128]")
"movi v27.4s, #0x0\n"
ASM_PREFETCH("[%[a_ptr], #128]")
"movi v28.4s, #0x0\n"
ASM_PREFETCH("[%[b_ptr], #192]")
"movi v29.4s, #0x0\n"
ASM_PREFETCH("[%[a_ptr], #192]")
"movi v30.4s, #0x0\n"
ASM_PREFETCH("[%[b_ptr], #256]")
"movi v31.4s, #0x0\n"
ASM_PREFETCH("[%[a_ptr], #256]")

// Start of unroll 0 (first iteration)
"smull v12.8h, v0.8b, v4.8b\n"
"smull v13.8h, v0.8b, v5.8b\n"

// Skip loop if we are doing zero iterations of it.
"cbz %w[k], 4f\n"

// Unroll 0 continuation (branch target)
"3:\n"
"smull v14.8h, v0.8b, v6.8b\n"
"subs %w[k], %w[k], #1\n"
"smull v15.8h, v0.8b, v7.8b\n"
"ldr q8, [%[b_ptr], #64]\n"
"smlal2 v12.8h, v0.16b, v4.16b\n"
"smlal2 v13.8h, v0.16b, v5.16b\n"
"ldr q9, [%[b_ptr], #80]\n"
"smlal2 v14.8h, v0.16b, v6.16b\n"
"smlal2 v15.8h, v0.16b, v7.16b\n"
"ldr q0, [%[a_ptr], #64]\n"

"sadalp v16.4s, v12.8h\n"
"smull v12.8h, v1.8b, v4.8b\n"
"sadalp v17.4s, v13.8h\n"
"sadalp v18.4s, v14.8h\n"
"smull v13.8h, v1.8b, v5.8b\n"
"sadalp v19.4s, v15.8h\n"
"smull v14.8h, v1.8b, v6.8b\n"
"ldr q10, [%[b_ptr], #96]\n"
"smull v15.8h, v1.8b, v7.8b\n"
"smlal2 v12.8h, v1.16b, v4.16b\n"
"ldr q11, [%[b_ptr], #112]\n"
"smlal2 v13.8h, v1.16b, v5.16b\n"
"add %[b_ptr], %[b_ptr], #128\n"
"smlal2 v14.8h, v1.16b, v6.16b\n"
"smlal2 v15.8h, v1.16b, v7.16b\n"
"ldr q1, [%[a_ptr], #80]\n"

"sadalp v20.4s, v12.8h\n"
"smull v12.8h, v2.8b, v4.8b\n"
"sadalp v21.4s, v13.8h\n"
"sadalp v22.4s, v14.8h\n"
"smull v13.8h, v2.8b, v5.8b\n"
"sadalp v23.4s, v15.8h\n"
"smull v14.8h, v2.8b, v6.8b\n"
"smull v15.8h, v2.8b, v7.8b\n"
"smlal2 v12.8h, v2.16b, v4.16b\n"
ASM_PREFETCH("[%[b_ptr], #192]")
"smlal2 v13.8h, v2.16b, v5.16b\n"
"smlal2 v14.8h, v2.16b, v6.16b\n"
ASM_PREFETCH("[%[a_ptr], #320]")
"smlal2 v15.8h, v2.16b, v7.16b\n"
"ldr q2, [%[a_ptr], #96]\n"

"sadalp v24.4s, v12.8h\n"
"smull v12.8h, v3.8b, v4.8b\n"
"sadalp v25.4s, v13.8h\n"
"sadalp v26.4s, v14.8h\n"
"smull v13.8h, v3.8b, v5.8b\n"
"sadalp v27.4s, v15.8h\n"
"smull v14.8h, v3.8b, v6.8b\n"
"smull v15.8h, v3.8b, v7.8b\n"
"smlal2 v12.8h, v3.16b, v4.16b\n"
"ldr q4, [%[b_ptr], #0]\n"
"smlal2 v13.8h, v3.16b, v5.16b\n"
"smlal2 v14.8h, v3.16b, v6.16b\n"
"smlal2 v15.8h, v3.16b, v7.16b\n"
"ldr q3, [%[a_ptr], #112]\n"

// Unroll 1
"sadalp v28.4s, v12.8h\n"
"smull v12.8h, v0.8b, v8.8b\n"
"sadalp v29.4s, v13.8h\n"
"sadalp v30.4s, v14.8h\n"
"smull v13.8h, v0.8b, v9.8b\n"
"sadalp v31.4s, v15.8h\n"
"smull v14.8h, v0.8b, v10.8b\n"
"smull v15.8h, v0.8b, v11.8b\n"
"ldr q5, [%[b_ptr], #16]\n"
"smlal2 v12.8h, v0.16b, v8.16b\n"
"smlal2 v13.8h, v0.16b, v9.16b\n"
"ldr q6, [%[b_ptr], #32]\n"
"smlal2 v14.8h, v0.16b, v10.16b\n"
"smlal2 v15.8h, v0.16b, v11.16b\n"
"ldr q0, [%[a_ptr], #128]\n"

"sadalp v16.4s, v12.8h\n"
"smull v12.8h, v1.8b, v8.8b\n"
"sadalp v17.4s, v13.8h\n"
"sadalp v18.4s, v14.8h\n"
"smull v13.8h, v1.8b, v9.8b\n"
"sadalp v19.4s, v15.8h\n"
"add %[a_ptr], %[a_ptr], #128\n"
"smull v14.8h, v1.8b, v10.8b\n"
"smull v15.8h, v1.8b, v11.8b\n"
"ldr q7, [%[b_ptr], #48]\n"
"smlal2 v12.8h, v1.16b, v8.16b\n"
"smlal2 v13.8h, v1.16b, v9.16b\n"
"smlal2 v14.8h, v1.16b, v10.16b\n"
"smlal2 v15.8h, v1.16b, v11.16b\n"
"ldr q1, [%[a_ptr], #16]\n"

"sadalp v20.4s, v12.8h\n"
"smull v12.8h, v2.8b, v8.8b\n"
"sadalp v21.4s, v13.8h\n"
"sadalp v22.4s, v14.8h\n"
"smull v13.8h, v2.8b, v9.8b\n"
"sadalp v23.4s, v15.8h\n"
"smull v14.8h, v2.8b, v10.8b\n"
"smull v15.8h, v2.8b, v11.8b\n"
"smlal2 v12.8h, v2.16b, v8.16b\n"
ASM_PREFETCH("[%[b_ptr], #256]")
"smlal2 v13.8h, v2.16b, v9.16b\n"
"smlal2 v14.8h, v2.16b, v10.16b\n"
ASM_PREFETCH("[%[a_ptr], #256]")
"smlal2 v15.8h, v2.16b, v11.16b\n"
"ldr q2, [%[a_ptr], #32]\n"

"sadalp v24.4s, v12.8h\n"
"smull v12.8h, v3.8b, v8.8b\n"
"sadalp v25.4s, v13.8h\n"
"sadalp v26.4s, v14.8h\n"
"smull v13.8h, v3.8b, v9.8b\n"
"sadalp v27.4s, v15.8h\n"
"smull v14.8h, v3.8b, v10.8b\n"
"smull v15.8h, v3.8b, v11.8b\n"
"smlal2 v12.8h, v3.16b, v8.16b\n"
"smlal2 v13.8h, v3.16b, v9.16b\n"
"smlal2 v14.8h, v3.16b, v10.16b\n"
"smlal2 v15.8h, v3.16b, v11.16b\n"
"ldr q3, [%[a_ptr], #48]\n"

// Start of unroll 0 for next iteration.
"sadalp v28.4s, v12.8h\n"
"smull v12.8h, v0.8b, v4.8b\n"
"sadalp v29.4s, v13.8h\n"
"sadalp v30.4s, v14.8h\n"
"smull v13.8h, v0.8b, v5.8b\n"
"sadalp v31.4s, v15.8h\n"
"bne 3b\n"

// Target to use when K=1 or 2 (i.e. zero iterations of main loop)
"4:\n"

// Branch to alternative tail for odd K
"cbnz %w[oddk], 5f\n"

// Detached final iteration (even K)
"smull v14.8h, v0.8b, v6.8b\n"
"smull v15.8h, v0.8b, v7.8b\n"
"ldr q8, [%[b_ptr], #64]\n"
"smlal2 v12.8h, v0.16b, v4.16b\n"
"smlal2 v13.8h, v0.16b, v5.16b\n"
"ldr q9, [%[b_ptr], #80]\n"
"smlal2 v14.8h, v0.16b, v6.16b\n"
"smlal2 v15.8h, v0.16b, v7.16b\n"
"ldr q0, [%[a_ptr], #64]\n"

"sadalp v16.4s, v12.8h\n"
"smull v12.8h, v1.8b, v4.8b\n"
"sadalp v17.4s, v13.8h\n"
"sadalp v18.4s, v14.8h\n"
"smull v13.8h, v1.8b, v5.8b\n"
"sadalp v19.4s, v15.8h\n"
"smull v14.8h, v1.8b, v6.8b\n"
"ldr q10, [%[b_ptr], #96]\n"
"smull v15.8h, v1.8b, v7.8b\n"
"smlal2 v12.8h, v1.16b, v4.16b\n"
"ldr q11, [%[b_ptr], #112]\n"
"smlal2 v13.8h, v1.16b, v5.16b\n"
"add %[b_ptr], %[b_ptr], #128\n"
"smlal2 v14.8h, v1.16b, v6.16b\n"
"smlal2 v15.8h, v1.16b, v7.16b\n"
"ldr q1, [%[a_ptr], #80]\n"

"sadalp v20.4s, v12.8h\n"
"smull v12.8h, v2.8b, v4.8b\n"
"sadalp v21.4s, v13.8h\n"
"sadalp v22.4s, v14.8h\n"
"smull v13.8h, v2.8b, v5.8b\n"
"sadalp v23.4s, v15.8h\n"
"smull v14.8h, v2.8b, v6.8b\n"
"smull v15.8h, v2.8b, v7.8b\n"
"smlal2 v12.8h, v2.16b, v4.16b\n"
"smlal2 v13.8h, v2.16b, v5.16b\n"
"smlal2 v14.8h, v2.16b, v6.16b\n"
"smlal2 v15.8h, v2.16b, v7.16b\n"
"ldr q2, [%[a_ptr], #96]\n"

"sadalp v24.4s, v12.8h\n"
"smull v12.8h, v3.8b, v4.8b\n"
"sadalp v25.4s, v13.8h\n"
"sadalp v26.4s, v14.8h\n"
"smull v13.8h, v3.8b, v5.8b\n"
"sadalp v27.4s, v15.8h\n"
"smull v14.8h, v3.8b, v6.8b\n"
"smull v15.8h, v3.8b, v7.8b\n"
"smlal2 v12.8h, v3.16b, v4.16b\n"
"smlal2 v13.8h, v3.16b, v5.16b\n"
"smlal2 v14.8h, v3.16b, v6.16b\n"
"smlal2 v15.8h, v3.16b, v7.16b\n"
"ldr q3, [%[a_ptr], #112]\n"

// Unroll 1
"sadalp v28.4s, v12.8h\n"
"smull v12.8h, v0.8b, v8.8b\n"
"sadalp v29.4s, v13.8h\n"
"sadalp v30.4s, v14.8h\n"
"smull v13.8h, v0.8b, v9.8b\n"
"sadalp v31.4s, v15.8h\n"
"smull v14.8h, v0.8b, v10.8b\n"
"add %[a_ptr], %[a_ptr], #128\n"
"smull v15.8h, v0.8b, v11.8b\n"
"smlal2 v12.8h, v0.16b, v8.16b\n"
"smlal2 v13.8h, v0.16b, v9.16b\n"
"smlal2 v14.8h, v0.16b, v10.16b\n"
"smlal2 v15.8h, v0.16b, v11.16b\n"

"sadalp v16.4s, v12.8h\n"
"smull v12.8h, v1.8b, v8.8b\n"
"sadalp v17.4s, v13.8h\n"
"sadalp v18.4s, v14.8h\n"
"smull v13.8h, v1.8b, v9.8b\n"
"sadalp v19.4s, v15.8h\n"
"smull v14.8h, v1.8b, v10.8b\n"
"smull v15.8h, v1.8b, v11.8b\n"
"smlal2 v12.8h, v1.16b, v8.16b\n"
"addp v16.4s, v16.4s, v17.4s\n"
"smlal2 v13.8h, v1.16b, v9.16b\n"
"addp v17.4s, v18.4s, v19.4s\n"
"smlal2 v14.8h, v1.16b, v10.16b\n"
"smlal2 v15.8h, v1.16b, v11.16b\n"

"sadalp v20.4s, v12.8h\n"
"smull v12.8h, v2.8b, v8.8b\n"
"sadalp v21.4s, v13.8h\n"
"sadalp v22.4s, v14.8h\n"
"smull v13.8h, v2.8b, v9.8b\n"
"sadalp v23.4s, v15.8h\n"
"addp v16.4s, v16.4s, v17.4s\n"
"smull v14.8h, v2.8b, v10.8b\n"
"addp v18.4s, v20.4s, v21.4s\n"
"addp v19.4s, v22.4s, v23.4s\n"
"smull v15.8h, v2.8b, v11.8b\n"
"smlal2 v12.8h, v2.16b, v8.16b\n"
"str q16, [%[output]]\n"
"smlal2 v13.8h, v2.16b, v9.16b\n"
"smlal2 v14.8h, v2.16b, v10.16b\n"
"smlal2 v15.8h, v2.16b, v11.16b\n"

"sadalp v24.4s, v12.8h\n"
"smull v12.8h, v3.8b, v8.8b\n"
"sadalp v25.4s, v13.8h\n"
"sadalp v26.4s, v14.8h\n"
"smull v13.8h, v3.8b, v9.8b\n"
"sadalp v27.4s, v15.8h\n"
"addp v17.4s, v18.4s, v19.4s\n"
"smull v14.8h, v3.8b, v10.8b\n"
"addp v20.4s, v24.4s, v25.4s\n"
"addp v21.4s, v26.4s, v27.4s\n"
"smull v15.8h, v3.8b, v11.8b\n"
"smlal2 v12.8h, v3.16b, v8.16b\n"
"str q17, [x1]\n"
"smlal2 v13.8h, v3.16b, v9.16b\n"
"smlal2 v14.8h, v3.16b, v10.16b\n"
"addp v18.4s, v20.4s, v21.4s\n"
"smlal2 v15.8h, v3.16b, v11.16b\n"
"b 6f\n"

// Detached final iteration (odd K)
"5:\n"
"smull v14.8h, v0.8b, v6.8b\n"
"add %[a_ptr], %[a_ptr], #64\n"
"smull v15.8h, v0.8b, v7.8b\n"
"add %[b_ptr], %[b_ptr], #64\n"
"smlal2 v12.8h, v0.16b, v4.16b\n"
"smlal2 v13.8h, v0.16b, v5.16b\n"
"smlal2 v14.8h, v0.16b, v6.16b\n"
"smlal2 v15.8h, v0.16b, v7.16b\n"

"sadalp v16.4s, v12.8h\n"
"smull v12.8h, v1.8b, v4.8b\n"
"sadalp v17.4s, v13.8h\n"
"sadalp v18.4s, v14.8h\n"
"smull v13.8h, v1.8b, v5.8b\n"
"sadalp v19.4s, v15.8h\n"
"smull v14.8h, v1.8b, v6.8b\n"
"smull v15.8h, v1.8b, v7.8b\n"
"smlal2 v12.8h, v1.16b, v4.16b\n"
"addp v16.4s, v16.4s, v17.4s\n"
"smlal2 v13.8h, v1.16b, v5.16b\n"
"addp v17.4s, v18.4s, v19.4s\n"
"smlal2 v14.8h, v1.16b, v6.16b\n"
"smlal2 v15.8h, v1.16b, v7.16b\n"

"sadalp v20.4s, v12.8h\n"
"smull v12.8h, v2.8b, v4.8b\n"
"sadalp v21.4s, v13.8h\n"
"sadalp v22.4s, v14.8h\n"
"smull v13.8h, v2.8b, v5.8b\n"
"sadalp v23.4s, v15.8h\n"
"addp v16.4s, v16.4s, v17.4s\n"
"smull v14.8h, v2.8b, v6.8b\n"
"addp v18.4s, v20.4s, v21.4s\n"
"addp v19.4s, v22.4s, v23.4s\n"
"smull v15.8h, v2.8b, v7.8b\n"
"smlal2 v12.8h, v2.16b, v4.16b\n"
"str q16, [%[output]]\n"
"smlal2 v13.8h, v2.16b, v5.16b\n"
"smlal2 v14.8h, v2.16b, v6.16b\n"
"smlal2 v15.8h, v2.16b, v7.16b\n"

"sadalp v24.4s, v12.8h\n"
"smull v12.8h, v3.8b, v4.8b\n"
"sadalp v25.4s, v13.8h\n"
"sadalp v26.4s, v14.8h\n"
"smull v13.8h, v3.8b, v5.8b\n"
"sadalp v27.4s, v15.8h\n"
"addp v17.4s, v18.4s, v19.4s\n"
"smull v14.8h, v3.8b, v6.8b\n"
"addp v20.4s, v24.4s, v25.4s\n"
"addp v21.4s, v26.4s, v27.4s\n"
"smull v15.8h, v3.8b, v7.8b\n"
"smlal2 v12.8h, v3.16b, v4.16b\n"
"str q17, [x1]\n"
"smlal2 v13.8h, v3.16b, v5.16b\n"
"smlal2 v14.8h, v3.16b, v6.16b\n"
"addp v18.4s, v20.4s, v21.4s\n"
"smlal2 v15.8h, v3.16b, v7.16b\n"

"6:\n"

// Final additions
"sadalp v28.4s, v12.8h\n"
"str q18, [x2]\n"
"sadalp v29.4s, v13.8h\n"
"sadalp v30.4s, v14.8h\n"
"sadalp v31.4s, v15.8h\n"

// Horizontal reduction, phase 1
"addp v22.4s, v28.4s, v29.4s\n"
"addp v23.4s, v30.4s, v31.4s\n"

// Horizontal reduction, phase 2
"addp v19.4s, v22.4s, v23.4s\n"
"str q19, [x3]\n"

:
[a_ptr] "+r" (a_ptr), [b_ptr] "+r" (b_ptr), [oddk] "+r" (oddk),
[is_first_k] "+r" (is_first_k), [k] "+r" (k), [LDC] "+r" (LDC),
[output] "+r"(output)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
"v10", "v11", "v12", "v13", "v14", "v15", "v16",
"v17", "v18", "v19", "v20","v21","v22","v23","v24","v25","v26",
"v27","v28","v29","v30","v31", "x1", "x2", "x3",
"cc", "memory"
);
}

static void kern_4x4_remain(const int8_t* packA, const int8_t* packB, int K,
int32_t* output, int LDC, bool is_first_k,
int m_remain, int n_remain) {
megdnn_assert(K > 0);
K /= 16;
const int8_t* a_ptr = packA;
const int8_t* b_ptr = packB;

LDC = LDC * sizeof(int32_t);

// clang-format off
#define LOAD_LINE(reg_index, n) \
"cbz x4, 102f\n" \
"mov x5, x" n "\n" \
"cmp %w[n_remain], #4\n" \
"blt 100" n "f\n" \
"ldr q" reg_index ", [x5]\n" \
"b 101" n "f\n" \
"100" n ":\n" \
"cmp %w[n_remain], #0\n" \
"beq 101" n "f\n" \
"ld1 {v" reg_index ".s}[0], [x5], #4\n" \
"cmp %w[n_remain], #1\n" \
"beq 101" n "f\n" \
"ld1 {v" reg_index ".s}[1], [x5], #4\n" \
"cmp %w[n_remain], #2\n" \
"beq 101" n "f\n" \
"ld1 {v" reg_index ".s}[2], [x5], #4\n" \
"101" n ":\n" \
"subs x4, x4, #1\n"

#define LOAD_C \
"mov x4, %x[m_remain]\n" \
LOAD_LINE("16", "0") \
LOAD_LINE("17", "1") \
LOAD_LINE("18", "2") \
LOAD_LINE("19", "3") \
"102:\n"

#define STORE_LINE(reg_index, n) \
"cbz x4, 105f\n" \
"mov x5, x" n "\n" \
"cmp %w[n_remain], #4\n" \
"blt 103" n "f\n" \
"str q" reg_index ", [x5]\n" \
"b 104" n "f\n" \
"103" n ":\n" \
"cmp %w[n_remain], #0\n" \
"beq 104" n "f\n" \
"st1 {v" reg_index ".s}[0], [x5], #4\n" \
"cmp %w[n_remain], #1\n" \
"beq 104" n "f\n" \
"st1 {v" reg_index ".s}[1], [x5], #4\n" \
"cmp %w[n_remain], #2\n" \
"beq 104" n "f\n" \
"st1 {v" reg_index ".s}[2], [x5], #4\n" \
"104" n ":\n" \
"subs x4, x4, #1\n"

#define STORE_C \
"mov x4, %x[m_remain]\n" \
STORE_LINE("16", "0") \
STORE_LINE("17", "1") \
STORE_LINE("18", "2") \
STORE_LINE("19", "3") \
"105:\n"

// clang-format on

asm volatile(
// load accumulator C
"mov x0, %[output]\n"
"add x1, x0, %x[LDC]\n"
"add x2, x1, %x[LDC]\n"
"add x3, x2, %x[LDC]\n"
"cmp %w[is_first_k], #1\n"
"beq 1f\n"

LOAD_C //
"b 2f\n"

"1:\n"
"eor v16.16b, v16.16b, v16.16b\n"
"eor v17.16b, v17.16b, v17.16b\n"
"eor v18.16b, v18.16b, v18.16b\n"
"eor v19.16b, v19.16b, v19.16b\n"
"eor v20.16b, v20.16b, v20.16b\n"
"eor v21.16b, v21.16b, v21.16b\n"
"eor v22.16b, v22.16b, v22.16b\n"
"eor v23.16b, v23.16b, v23.16b\n"
"eor v24.16b, v24.16b, v24.16b\n"
"eor v25.16b, v25.16b, v25.16b\n"
"eor v26.16b, v26.16b, v26.16b\n"
"eor v27.16b, v27.16b, v27.16b\n"
"eor v28.16b, v28.16b, v28.16b\n"
"eor v29.16b, v29.16b, v29.16b\n"
"eor v30.16b, v30.16b, v30.16b\n"
"eor v31.16b, v31.16b, v31.16b\n"

"2: \n"
"ldr q4, [%[b_ptr]]\n"
"ldr q5, [%[b_ptr], #16]\n"
"ldr q6, [%[b_ptr], #32]\n"
"ldr q7, [%[b_ptr], #48]\n"
"ldr q0, [%[a_ptr]]\n"
"ldr q1, [%[a_ptr], #16]\n"
"ldr q2, [%[a_ptr], #32]\n"
"ldr q3, [%[a_ptr], #48]\n"

"smull v12.8h, v0.8b, v4.8b\n"
"smull v13.8h, v0.8b, v5.8b\n"
"smull v14.8h, v0.8b, v6.8b\n"
"smull v15.8h, v0.8b, v7.8b\n"
"smlal2 v12.8h, v0.16b, v4.16b\n"
"smlal2 v13.8h, v0.16b, v5.16b\n"
"smlal2 v14.8h, v0.16b, v6.16b\n"
"smlal2 v15.8h, v0.16b, v7.16b\n"
"sadalp v16.4s, v12.8h\n"
"sadalp v17.4s, v13.8h\n"
"sadalp v18.4s, v14.8h\n"
"sadalp v19.4s, v15.8h\n"

"smull v12.8h, v1.8b, v4.8b\n"
"smull v13.8h, v1.8b, v5.8b\n"
"smull v14.8h, v1.8b, v6.8b\n"
"smull v15.8h, v1.8b, v7.8b\n"
"smlal2 v12.8h, v1.16b, v4.16b\n"
"smlal2 v13.8h, v1.16b, v5.16b\n"
"smlal2 v14.8h, v1.16b, v6.16b\n"
"smlal2 v15.8h, v1.16b, v7.16b\n"
"sadalp v20.4s, v12.8h\n"
"sadalp v21.4s, v13.8h\n"
"sadalp v22.4s, v14.8h\n"
"sadalp v23.4s, v15.8h\n"

"smull v12.8h, v2.8b, v4.8b\n"
"smull v13.8h, v2.8b, v5.8b\n"
"smull v14.8h, v2.8b, v6.8b\n"
"smull v15.8h, v2.8b, v7.8b\n"
"smlal2 v12.8h, v2.16b, v4.16b\n"
"smlal2 v13.8h, v2.16b, v5.16b\n"
"smlal2 v14.8h, v2.16b, v6.16b\n"
"smlal2 v15.8h, v2.16b, v7.16b\n"
"sadalp v24.4s, v12.8h\n"
"sadalp v25.4s, v13.8h\n"
"sadalp v26.4s, v14.8h\n"
"sadalp v27.4s, v15.8h\n"

"smull v12.8h, v3.8b, v4.8b\n"
"smull v13.8h, v3.8b, v5.8b\n"
"smull v14.8h, v3.8b, v6.8b\n"
"smull v15.8h, v3.8b, v7.8b\n"
"smlal2 v12.8h, v3.16b, v4.16b\n"
"smlal2 v13.8h, v3.16b, v5.16b\n"
"smlal2 v14.8h, v3.16b, v6.16b\n"
"smlal2 v15.8h, v3.16b, v7.16b\n"
"sadalp v28.4s, v12.8h\n"
"sadalp v29.4s, v13.8h\n"
"sadalp v30.4s, v14.8h\n"
"sadalp v31.4s, v15.8h\n"
"add %[a_ptr], %[a_ptr], #64\n"
"add %[b_ptr], %[b_ptr], #64\n"

"subs %w[K], %w[K], #1\n"
"cbnz %w[K], 2b\n"

"3:\n"
// reduction
"addp v16.4s, v16.4s, v17.4s\n"
"addp v17.4s, v18.4s, v19.4s\n"
"addp v16.4s, v16.4s, v17.4s\n"
"addp v18.4s, v20.4s, v21.4s\n"
"addp v19.4s, v22.4s, v23.4s\n"
"addp v17.4s, v18.4s, v19.4s\n"
"addp v20.4s, v24.4s, v25.4s\n"
"addp v21.4s, v26.4s, v27.4s\n"
"addp v18.4s, v20.4s, v21.4s\n"
"addp v22.4s, v28.4s, v29.4s\n"
"addp v23.4s, v30.4s, v31.4s\n"
"addp v19.4s, v22.4s, v23.4s\n"

STORE_C

: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr),
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC),
[output] "+r"(output), [m_remain] "+r"(m_remain),
[n_remain] "+r"(n_remain)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
"v29", "v30", "v31", "x0", "x1", "x2", "x3", "x4", "x5", "cc",
"memory");

#undef LOAD_LINE
#undef LOAD_C
#undef STORE_LINE
#undef STORE_C
}

static void gemm_s8_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr,
int ldin, int y0, int ymax, int k0, int kmax) {
int8_t zerobuff[16];
std::memset(zerobuff, 0, sizeof(int8_t) * 16);

int y = y0;
for (; y + 3 < ymax; y += 4) {
const int8_t* inptr0 = inptr + y * ldin + k0;
const int8_t* inptr1 = inptr0 + ldin;
const int8_t* inptr2 = inptr1 + ldin;
const int8_t* inptr3 = inptr2 + ldin;

prefetch_2x(inptr0);
prefetch_2x(inptr1);
prefetch_2x(inptr2);
prefetch_2x(inptr3);

int K = kmax - k0;
//! read 16 * 4 in each row
for (; K > 15; K -= 16) {
interleave_4x16_1_b(inptr0, inptr1, inptr2, inptr3, outptr);
}

if (K > 0) {
interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 16, K);
}
}
for (; y < ymax; y += 4) {
const int8_t* inptr0 = inptr + y * ldin + k0;
const int8_t* inptr1 = inptr0 + ldin;
const int8_t* inptr2 = inptr1 + ldin;
const int8_t* inptr3 = inptr2 + ldin;

prefetch_2x(inptr0);
prefetch_2x(inptr1);
prefetch_2x(inptr2);
prefetch_2x(inptr3);

int K = kmax - k0;
//! read 4 * 4 in each row
for (; K > 15; K -= 16) {
if (y + 3 >= ymax) {
switch (y + 3 - ymax) {
case 2:
inptr1 = zerobuff;
case 1:
inptr2 = zerobuff;
case 0:
inptr3 = zerobuff;
break;
default:
megdnn_assert(0);
}
}

interleave_4x16_1_b(inptr0, inptr1, inptr2, inptr3, outptr);
}

if (K > 0) {
if (y + 3 >= ymax) {
switch (y + 3 - ymax) {
case 2:
inptr1 = zerobuff;
case 1:
inptr2 = zerobuff;
case 0:
inptr3 = zerobuff;
break;
default:
megdnn_assert(0);
}
}
interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 16, K);
}
}
}

static void gemm_s8_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin,
int x0, int xmax, int k0, int kmax) {
int8_t zerobuff[16];
std::memset(zerobuff, 0, sizeof(int8_t) * 16);
const int ksize = kmax - k0;
const int ksize4 = round_up(ksize, 16) * 4;
int8_t* outptr = out;

int k = k0;
for (; k < kmax; k += 16) {
int ki = k;
for (int cnt = 0; cnt < 2; ki += 8, cnt++) {
const int8_t* inptr0 = in + ki * ldin + x0;
const int8_t* inptr1 = inptr0 + ldin;
const int8_t* inptr2 = inptr1 + ldin;
const int8_t* inptr3 = inptr2 + ldin;
const int8_t* inptr4 = inptr3 + ldin;
const int8_t* inptr5 = inptr4 + ldin;
const int8_t* inptr6 = inptr5 + ldin;
const int8_t* inptr7 = inptr6 + ldin;
int8_t* outptr_inner = outptr + ki - k;

int remain = std::min(ki + 7 - kmax, 7);
int x = x0;
for (; x + 3 < xmax; x += 4) {
if (remain >= 0) {
switch (remain) {
case 7:
inptr0 = zerobuff;
case 6:
inptr1 = zerobuff;
case 5:
inptr2 = zerobuff;
case 4:
inptr3 = zerobuff;
case 3:
inptr4 = zerobuff;
case 2:
inptr5 = zerobuff;
case 1:
inptr6 = zerobuff;
case 0:
inptr7 = zerobuff;
break;
default:
megdnn_assert(0);
}
}

transpose_4x16_1_b_helper(inptr0, inptr1, inptr2, inptr3,
inptr4, inptr5, inptr6, inptr7,
outptr_inner);
outptr_inner += ksize4;
}

if (x < xmax) {
if (remain >= 0) {
switch (remain) {
case 7:
inptr0 = zerobuff;
case 6:
inptr1 = zerobuff;
case 5:
inptr2 = zerobuff;
case 4:
inptr3 = zerobuff;
case 3:
inptr4 = zerobuff;
case 2:
inptr5 = zerobuff;
case 1:
inptr6 = zerobuff;
case 0:
inptr7 = zerobuff;
break;
default:
megdnn_assert(0);
}
}

for (; x < xmax; x++) {
*outptr_inner++ = *inptr0++;
*outptr_inner++ = *inptr1++;
*outptr_inner++ = *inptr2++;
*outptr_inner++ = *inptr3++;
*outptr_inner++ = *inptr4++;
*outptr_inner++ = *inptr5++;
*outptr_inner++ = *inptr6++;
*outptr_inner++ = *inptr7++;
outptr_inner += 8;
}
}
}

outptr += 16 * 4;
}
}

} // namespace matmul_4x4x16
} // namespace aarch64
} // namespace megdnn
#endif

// vim: syntax=cpp.doxygen

+ 1375
- 0
dnn/src/aarch64/matrix_mul/int8/kernel_8x8x8.h
File diff suppressed because it is too large
View File


+ 892
- 0
dnn/src/aarch64/matrix_mul/int8/kernel_mk4_4x4x16.h View File

@@ -0,0 +1,892 @@
/**
* \file dnn/src/aarch64/matrix_mul/int8/kernel_mk4_4x4x16.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include <cstring>
#if !(__ARM_FEATURE_DOTPROD)
#include "src/aarch64/matrix_mul/asm/common.h"
#include "src/arm_common/simd_macro/marm_neon.h"

namespace megdnn {
namespace aarch64 {
namespace matmul_mk4_4x4x16 {

/**
* Overview of register layout:
*
* A 16x4 cell of Rhs is stored in 8bit in v0-q3.
* B 16x4 cell of Lhs is stored in 8bit in q4-q7
* C 8x16 block of accumulators is stored in 8bit in q8--q31.
*
* \warning Fast kernel operating on int8 operands.
* It is assumed that one of the two int8 operands only takes values
* in [-127, 127], while the other may freely range in [-128, 127].
* The issue with both operands taking the value -128 is that:
* -128*-128 + -128*-128 == -32768 overflows int16.
* Every other expression a*b + c*d, for any int8 a,b,c,d, fits in int16
* range. That is the basic idea of this kernel.
*
*
* +--------+--------+---------+---------+
* |v4[0-16]|v5[0-16]| v6[0-16]| v7[0-16]|
* Rhs +--------+--------+---------+---------+
* | | | | |
*
* Lhs | | | | |
*
* +--------+ - - - - +-------------------------------------+
* |v0[0-16]| |v16[0-4]|v17[0-4]| v18[0-4]| v19[0-4]|
* |v1[0-16]| |v20[0-4]|v21[0-4]| v22[0-4]| v23[0-4]|
* |v2[0-16]| |v24[0-4]|v25[0-4]| v26[0-4]| v27[0-4]|
* |v3[0-16]| |v28[0-4]|v29[0-4]| v30[0-4]| v31[0-4]|
* +--------+ - - - - +-------------------------------------+
*
* Accumulator
*/

static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
int32_t* output, bool is_first_k) {
K = div_ceil(K, 16);
const int8_t* a_ptr = packA;
const int8_t* b_ptr = packB;

asm volatile(
// load accumulator C
"ld1 {v0.16b}, [%[a_ptr]], #16\n"
"eor v16.16b, v16.16b, v16.16b\n"
"eor v17.16b, v17.16b, v17.16b\n"
"eor v18.16b, v18.16b, v18.16b\n"
"ld1 {v1.16b}, [%[a_ptr]], #16\n"
"eor v19.16b, v19.16b, v19.16b\n"
"eor v20.16b, v19.16b, v19.16b\n"
"eor v21.16b, v19.16b, v19.16b\n"
"ld1 {v4.16b, v5.16b}, [%[b_ptr]], #32\n"
"eor v22.16b, v19.16b, v19.16b\n"
"PRFM PLDL1KEEP, [%[a_ptr], #32]\n"
"eor v23.16b, v19.16b, v19.16b\n"
"eor v24.16b, v19.16b, v19.16b\n"
"PRFM PLDL1KEEP, [%[b_ptr], #32]\n"
"eor v25.16b, v19.16b, v19.16b\n"
"eor v26.16b, v19.16b, v19.16b\n"
"PRFM PLDL1KEEP, [%[b_ptr], #64]\n"
"eor v27.16b, v19.16b, v19.16b\n"
"eor v28.16b, v19.16b, v19.16b\n"
"PRFM PLDL1KEEP, [%[a_ptr], #64]\n"
"eor v29.16b, v19.16b, v19.16b\n"
"eor v30.16b, v19.16b, v19.16b\n"
"PRFM PLDL1KEEP, [%[b_ptr], #128]\n"
"eor v31.16b, v19.16b, v19.16b\n"

//! if K==1 jump to compute last K
"cmp %w[k], #2\n"
"beq 2f\n"
"blt 3f\n"

//! K>2
"1:\n"
//! First k
"smull v8.8h, v0.8b, v4.8b\n"
"smull v9.8h, v0.8b, v5.8b\n"
"ld1 {v6.16b}, [%[b_ptr]], #16\n"
"smull v12.8h, v1.8b, v4.8b\n"
"smull v13.8h, v1.8b, v5.8b\n"
"ld1 {v7.16b}, [%[b_ptr]], #16\n"
"smlal2 v8.8h, v0.16b, v4.16b\n"
"smlal2 v9.8h, v0.16b, v5.16b\n"
"smlal2 v12.8h, v1.16b, v4.16b\n"
"smlal2 v13.8h, v1.16b, v5.16b\n"

"smull v10.8h, v0.8b, v6.8b\n"
"ld1 {v2.16b}, [%[a_ptr]], #16\n"
"smull v11.8h, v0.8b, v7.8b\n"
"smull v14.8h, v1.8b, v6.8b\n"
"ld1 {v3.16b}, [%[a_ptr]], #16\n"
"smull v15.8h, v1.8b, v7.8b\n"
"sadalp v16.4s, v8.8h\n"
"smlal2 v10.8h, v0.16b, v6.16b\n"
"sadalp v17.4s, v9.8h\n"
"smlal2 v11.8h, v0.16b, v7.16b\n"
"sadalp v20.4s, v12.8h\n"
"smlal2 v14.8h, v1.16b, v6.16b\n"
"sadalp v21.4s, v13.8h\n"
"smlal2 v15.8h, v1.16b, v7.16b\n"

"smull v8.8h, v2.8b, v4.8b\n"
"smull v9.8h, v2.8b, v5.8b\n"
"ld1 {v0.16b}, [%[a_ptr]], #16\n"
"smull v12.8h, v3.8b, v4.8b\n"
"ld1 {v1.16b}, [%[a_ptr]], #16\n"
"smull v13.8h, v3.8b, v5.8b\n"
"sadalp v18.4s, v10.8h\n"
"smlal2 v8.8h, v2.16b, v4.16b\n"
"sadalp v19.4s, v11.8h\n"
"smlal2 v9.8h, v2.16b, v5.16b\n"
"sadalp v22.4s, v14.8h\n"
"smlal2 v12.8h, v3.16b, v4.16b\n"
"sadalp v23.4s, v15.8h\n"
"smlal2 v13.8h, v3.16b, v5.16b\n"

"smull v10.8h, v2.8b, v6.8b\n"
"smull v11.8h, v2.8b, v7.8b\n"
"ld1 {v4.16b}, [%[b_ptr]], #16\n"
"smull v14.8h, v3.8b, v6.8b\n"
"ld1 {v5.16b}, [%[b_ptr]], #16\n"
"smull v15.8h, v3.8b, v7.8b\n"
"sadalp v24.4s, v8.8h\n"
"smlal2 v10.8h, v2.16b, v6.16b\n"
"sadalp v25.4s, v9.8h\n"
"smlal2 v11.8h, v2.16b, v7.16b\n"
"sadalp v28.4s, v12.8h\n"
"smlal2 v14.8h, v3.16b, v6.16b\n"
"sadalp v29.4s, v13.8h\n"
"smlal2 v15.8h, v3.16b, v7.16b\n"

//! Second k
"smull v8.8h, v0.8b, v4.8b\n"
"smull v9.8h, v0.8b, v5.8b\n"
"ld1 {v6.16b}, [%[b_ptr]], #16\n"
"smull v12.8h, v1.8b, v4.8b\n"
"smull v13.8h, v1.8b, v5.8b\n"
"ld1 {v7.16b}, [%[b_ptr]], #16\n"
"smlal2 v8.8h, v0.16b, v4.16b\n"
"sadalp v26.4s, v10.8h\n"
"smlal2 v9.8h, v0.16b, v5.16b\n"
"sadalp v27.4s, v11.8h\n"
"smlal2 v12.8h, v1.16b, v4.16b\n"
"sadalp v30.4s, v14.8h\n"
"smlal2 v13.8h, v1.16b, v5.16b\n"
"sadalp v31.4s, v15.8h\n"

"smull v10.8h, v0.8b, v6.8b\n"
"ld1 {v2.16b}, [%[a_ptr]], #16\n"
"smull v11.8h, v0.8b, v7.8b\n"
"smull v14.8h, v1.8b, v6.8b\n"
"ld1 {v3.16b}, [%[a_ptr]], #16\n"
"smull v15.8h, v1.8b, v7.8b\n"
"sadalp v16.4s, v8.8h\n"
"smlal2 v10.8h, v0.16b, v6.16b\n"
"sadalp v17.4s, v9.8h\n"
"smlal2 v11.8h, v0.16b, v7.16b\n"
"sadalp v20.4s, v12.8h\n"
"smlal2 v14.8h, v1.16b, v6.16b\n"
"sadalp v21.4s, v13.8h\n"
"smlal2 v15.8h, v1.16b, v7.16b\n"

"smull v8.8h, v2.8b, v4.8b\n"
"smull v9.8h, v2.8b, v5.8b\n"
"ld1 {v0.16b}, [%[a_ptr]], #16\n"
"smull v12.8h, v3.8b, v4.8b\n"
"ld1 {v1.16b}, [%[a_ptr]], #16\n"
"smull v13.8h, v3.8b, v5.8b\n"
"sadalp v18.4s, v10.8h\n"
"smlal2 v8.8h, v2.16b, v4.16b\n"
"sadalp v19.4s, v11.8h\n"
"smlal2 v9.8h, v2.16b, v5.16b\n"
"sadalp v22.4s, v14.8h\n"
"smlal2 v12.8h, v3.16b, v4.16b\n"
"sadalp v23.4s, v15.8h\n"
"smlal2 v13.8h, v3.16b, v5.16b\n"

"sub %w[k], %w[k], #2\n"
"cmp %w[k], #2\n"

"smull v10.8h, v2.8b, v6.8b\n"
"ld1 {v4.16b}, [%[b_ptr]], #16\n"
"smull v11.8h, v2.8b, v7.8b\n"
"ld1 {v5.16b}, [%[b_ptr]], #16\n"
"smull v14.8h, v3.8b, v6.8b\n"
"sadalp v24.4s, v8.8h\n"
"smull v15.8h, v3.8b, v7.8b\n"
"sadalp v25.4s, v9.8h\n"
"smlal2 v10.8h, v2.16b, v6.16b\n"
"sadalp v28.4s, v12.8h\n"
"smlal2 v11.8h, v2.16b, v7.16b\n"
"sadalp v29.4s, v13.8h\n"
"smlal2 v14.8h, v3.16b, v6.16b\n"
"sadalp v26.4s, v10.8h\n"
"smlal2 v15.8h, v3.16b, v7.16b\n"

"sadalp v27.4s, v11.8h\n"
"sadalp v30.4s, v14.8h\n"
"sadalp v31.4s, v15.8h\n"

"bgt 1b\n"
"blt 3f\n"

//! K==2
"2:\n"
"smull v8.8h, v0.8b, v4.8b\n"
"smull v9.8h, v0.8b, v5.8b\n"
"ld1 {v6.16b}, [%[b_ptr]], #16\n"
"smull v12.8h, v1.8b, v4.8b\n"
"smull v13.8h, v1.8b, v5.8b\n"
"ld1 {v7.16b}, [%[b_ptr]], #16\n"
"smlal2 v8.8h, v0.16b, v4.16b\n"
"smlal2 v9.8h, v0.16b, v5.16b\n"
"smlal2 v12.8h, v1.16b, v4.16b\n"
"smlal2 v13.8h, v1.16b, v5.16b\n"

"smull v10.8h, v0.8b, v6.8b\n"
"ld1 {v2.16b}, [%[a_ptr]], #16\n"
"smull v11.8h, v0.8b, v7.8b\n"
"smull v14.8h, v1.8b, v6.8b\n"
"ld1 {v3.16b}, [%[a_ptr]], #16\n"
"smull v15.8h, v1.8b, v7.8b\n"
"sadalp v16.4s, v8.8h\n"
"smlal2 v10.8h, v0.16b, v6.16b\n"
"sadalp v17.4s, v9.8h\n"
"smlal2 v11.8h, v0.16b, v7.16b\n"
"sadalp v20.4s, v12.8h\n"
"smlal2 v14.8h, v1.16b, v6.16b\n"
"sadalp v21.4s, v13.8h\n"
"smlal2 v15.8h, v1.16b, v7.16b\n"

"smull v8.8h, v2.8b, v4.8b\n"
"smull v9.8h, v2.8b, v5.8b\n"
"ld1 {v0.16b}, [%[a_ptr]], #16\n"
"smull v12.8h, v3.8b, v4.8b\n"
"ld1 {v1.16b}, [%[a_ptr]], #16\n"
"smull v13.8h, v3.8b, v5.8b\n"
"sadalp v18.4s, v10.8h\n"
"smlal2 v8.8h, v2.16b, v4.16b\n"
"sadalp v19.4s, v11.8h\n"
"smlal2 v9.8h, v2.16b, v5.16b\n"
"sadalp v22.4s, v14.8h\n"
"smlal2 v12.8h, v3.16b, v4.16b\n"
"sadalp v23.4s, v15.8h\n"
"smlal2 v13.8h, v3.16b, v5.16b\n"

"smull v10.8h, v2.8b, v6.8b\n"
"smull v11.8h, v2.8b, v7.8b\n"
"ld1 {v4.16b}, [%[b_ptr]], #16\n"
"smull v14.8h, v3.8b, v6.8b\n"
"ld1 {v5.16b}, [%[b_ptr]], #16\n"
"smull v15.8h, v3.8b, v7.8b\n"
"sadalp v24.4s, v8.8h\n"
"smlal2 v10.8h, v2.16b, v6.16b\n"
"sadalp v25.4s, v9.8h\n"
"smlal2 v11.8h, v2.16b, v7.16b\n"
"sadalp v28.4s, v12.8h\n"
"smlal2 v14.8h, v3.16b, v6.16b\n"
"sadalp v29.4s, v13.8h\n"
"smlal2 v15.8h, v3.16b, v7.16b\n"
"sadalp v26.4s, v10.8h\n"
"sadalp v27.4s, v11.8h\n"
"sadalp v30.4s, v14.8h\n"
"sadalp v31.4s, v15.8h\n"

//! K==1
"3:\n"
"smull v8.8h, v0.8b, v4.8b\n"
"smull v9.8h, v0.8b, v5.8b\n"
"ld1 {v6.16b}, [%[b_ptr]], #16\n"
"smull v12.8h, v1.8b, v4.8b\n"
"smull v13.8h, v1.8b, v5.8b\n"
"ld1 {v7.16b}, [%[b_ptr]], #16\n"
"smlal2 v8.8h, v0.16b, v4.16b\n"
"smlal2 v9.8h, v0.16b, v5.16b\n"
"smlal2 v12.8h, v1.16b, v4.16b\n"
"smlal2 v13.8h, v1.16b, v5.16b\n"

"smull v10.8h, v0.8b, v6.8b\n"
"ld1 {v2.16b}, [%[a_ptr]], #16\n"
"smull v11.8h, v0.8b, v7.8b\n"
"smull v14.8h, v1.8b, v6.8b\n"
"ld1 {v3.16b}, [%[a_ptr]], #16\n"
"smull v15.8h, v1.8b, v7.8b\n"
"sadalp v16.4s, v8.8h\n"
"smlal2 v10.8h, v0.16b, v6.16b\n"
"sadalp v17.4s, v9.8h\n"
"smlal2 v11.8h, v0.16b, v7.16b\n"
"sadalp v20.4s, v12.8h\n"
"smlal2 v14.8h, v1.16b, v6.16b\n"
"sadalp v21.4s, v13.8h\n"
"smlal2 v15.8h, v1.16b, v7.16b\n"

"smull v8.8h, v2.8b, v4.8b\n"
"smull v9.8h, v2.8b, v5.8b\n"
"smull v12.8h, v3.8b, v4.8b\n"
"smull v13.8h, v3.8b, v5.8b\n"
"sadalp v18.4s, v10.8h\n"
"smlal2 v8.8h, v2.16b, v4.16b\n"
"sadalp v19.4s, v11.8h\n"
"smlal2 v9.8h, v2.16b, v5.16b\n"
"sadalp v22.4s, v14.8h\n"
"smlal2 v12.8h, v3.16b, v4.16b\n"
"sadalp v23.4s, v15.8h\n"
"smlal2 v13.8h, v3.16b, v5.16b\n"

"smull v10.8h, v2.8b, v6.8b\n"
"sadalp v24.4s, v8.8h\n"
"smull v11.8h, v2.8b, v7.8b\n"
"sadalp v25.4s, v9.8h\n"
"smull v14.8h, v3.8b, v6.8b\n"
"sadalp v28.4s, v12.8h\n"
"smull v15.8h, v3.8b, v7.8b\n"
"sadalp v29.4s, v13.8h\n"
"smlal2 v10.8h, v2.16b, v6.16b\n"
"smlal2 v11.8h, v2.16b, v7.16b\n"
"sadalp v26.4s, v10.8h\n"
"smlal2 v14.8h, v3.16b, v6.16b\n"
"sadalp v27.4s, v11.8h\n"
"smlal2 v15.8h, v3.16b, v7.16b\n"
"sadalp v30.4s, v14.8h\n"
"sadalp v31.4s, v15.8h\n"

"addp v4.4s, v16.4s, v20.4s\n"
"addp v5.4s, v24.4s, v28.4s\n"
"addp v6.4s, v17.4s, v21.4s\n"
"addp v7.4s, v25.4s, v29.4s\n"
"addp v8.4s, v18.4s, v22.4s\n"
"addp v9.4s, v26.4s, v30.4s\n"
"addp v10.4s, v19.4s, v23.4s\n"
"addp v11.4s, v27.4s, v31.4s\n"

"cmp %w[is_first_k], #1\n"

"addp v0.4s, v4.4s, v5.4s\n"
"addp v1.4s, v6.4s, v7.4s\n"
"addp v2.4s, v8.4s, v9.4s\n"
"addp v3.4s, v10.4s, v11.4s\n"

"beq 6f\n"

"ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[output]]\n"
"add v0.4s, v0.4s, v8.4s\n"
"add v1.4s, v1.4s, v9.4s\n"
"add v2.4s, v2.4s, v10.4s\n"
"add v3.4s, v3.4s, v11.4s\n"

"6:\n"
"st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[output]], #64\n"

: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr),
[is_first_k] "+r"(is_first_k), [k] "+r"(K), [output] "+r"(output)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
"v29", "v30", "v31", "cc", "memory");
}

static void kern_4x4_remain(const int8_t* packA, const int8_t* packB, int K,
int32_t* output, bool is_first_k, size_t remain_n) {
K = div_ceil(K, 16);
const int8_t* a_ptr = packA;
const int8_t* b_ptr = packB;

asm volatile(
// load accumulator C
"ld1 {v0.16b}, [%[a_ptr]], #16\n"
"eor v16.16b, v16.16b, v16.16b\n"
"eor v17.16b, v17.16b, v17.16b\n"
"eor v18.16b, v18.16b, v18.16b\n"
"ld1 {v1.16b}, [%[a_ptr]], #16\n"
"eor v19.16b, v19.16b, v19.16b\n"
"eor v20.16b, v19.16b, v19.16b\n"
"eor v21.16b, v19.16b, v19.16b\n"
"ld1 {v4.16b, v5.16b}, [%[b_ptr]], #32\n"
"eor v22.16b, v19.16b, v19.16b\n"
"PRFM PLDL1KEEP, [%[a_ptr], #32]\n"
"eor v23.16b, v19.16b, v19.16b\n"
"eor v24.16b, v19.16b, v19.16b\n"
"PRFM PLDL1KEEP, [%[b_ptr], #32]\n"
"eor v25.16b, v19.16b, v19.16b\n"
"eor v26.16b, v19.16b, v19.16b\n"
"PRFM PLDL1KEEP, [%[b_ptr], #64]\n"
"eor v27.16b, v19.16b, v19.16b\n"
"eor v28.16b, v19.16b, v19.16b\n"
"PRFM PLDL1KEEP, [%[a_ptr], #64]\n"
"eor v29.16b, v19.16b, v19.16b\n"
"eor v30.16b, v19.16b, v19.16b\n"
"PRFM PLDL1KEEP, [%[b_ptr], #128]\n"
"eor v31.16b, v19.16b, v19.16b\n"

//! if K==1 jump to compute last K
"cmp %w[k], #2\n"
"beq 2f\n"
"blt 3f\n"

//! K>2
"1:\n"
//! First k
"smull v8.8h, v0.8b, v4.8b\n"
"smull v9.8h, v0.8b, v5.8b\n"
"ld1 {v6.16b}, [%[b_ptr]], #16\n"
"smull v12.8h, v1.8b, v4.8b\n"
"smull v13.8h, v1.8b, v5.8b\n"
"ld1 {v7.16b}, [%[b_ptr]], #16\n"
"smlal2 v8.8h, v0.16b, v4.16b\n"
"smlal2 v9.8h, v0.16b, v5.16b\n"
"smlal2 v12.8h, v1.16b, v4.16b\n"
"smlal2 v13.8h, v1.16b, v5.16b\n"

"smull v10.8h, v0.8b, v6.8b\n"
"ld1 {v2.16b}, [%[a_ptr]], #16\n"
"smull v11.8h, v0.8b, v7.8b\n"
"smull v14.8h, v1.8b, v6.8b\n"
"ld1 {v3.16b}, [%[a_ptr]], #16\n"
"smull v15.8h, v1.8b, v7.8b\n"
"sadalp v16.4s, v8.8h\n"
"smlal2 v10.8h, v0.16b, v6.16b\n"
"sadalp v17.4s, v9.8h\n"
"smlal2 v11.8h, v0.16b, v7.16b\n"
"sadalp v20.4s, v12.8h\n"
"smlal2 v14.8h, v1.16b, v6.16b\n"
"sadalp v21.4s, v13.8h\n"
"smlal2 v15.8h, v1.16b, v7.16b\n"

"smull v8.8h, v2.8b, v4.8b\n"
"smull v9.8h, v2.8b, v5.8b\n"
"ld1 {v0.16b}, [%[a_ptr]], #16\n"
"smull v12.8h, v3.8b, v4.8b\n"
"ld1 {v1.16b}, [%[a_ptr]], #16\n"
"smull v13.8h, v3.8b, v5.8b\n"
"sadalp v18.4s, v10.8h\n"
"smlal2 v8.8h, v2.16b, v4.16b\n"
"sadalp v19.4s, v11.8h\n"
"smlal2 v9.8h, v2.16b, v5.16b\n"
"sadalp v22.4s, v14.8h\n"
"smlal2 v12.8h, v3.16b, v4.16b\n"
"sadalp v23.4s, v15.8h\n"
"smlal2 v13.8h, v3.16b, v5.16b\n"

"smull v10.8h, v2.8b, v6.8b\n"
"smull v11.8h, v2.8b, v7.8b\n"
"ld1 {v4.16b}, [%[b_ptr]], #16\n"
"smull v14.8h, v3.8b, v6.8b\n"
"ld1 {v5.16b}, [%[b_ptr]], #16\n"
"smull v15.8h, v3.8b, v7.8b\n"
"sadalp v24.4s, v8.8h\n"
"smlal2 v10.8h, v2.16b, v6.16b\n"
"sadalp v25.4s, v9.8h\n"
"smlal2 v11.8h, v2.16b, v7.16b\n"
"sadalp v28.4s, v12.8h\n"
"smlal2 v14.8h, v3.16b, v6.16b\n"
"sadalp v29.4s, v13.8h\n"
"smlal2 v15.8h, v3.16b, v7.16b\n"

//! Second k
"smull v8.8h, v0.8b, v4.8b\n"
"smull v9.8h, v0.8b, v5.8b\n"
"ld1 {v6.16b}, [%[b_ptr]], #16\n"
"smull v12.8h, v1.8b, v4.8b\n"
"smull v13.8h, v1.8b, v5.8b\n"
"ld1 {v7.16b}, [%[b_ptr]], #16\n"
"smlal2 v8.8h, v0.16b, v4.16b\n"
"sadalp v26.4s, v10.8h\n"
"smlal2 v9.8h, v0.16b, v5.16b\n"
"sadalp v27.4s, v11.8h\n"
"smlal2 v12.8h, v1.16b, v4.16b\n"
"sadalp v30.4s, v14.8h\n"
"smlal2 v13.8h, v1.16b, v5.16b\n"
"sadalp v31.4s, v15.8h\n"

"smull v10.8h, v0.8b, v6.8b\n"
"ld1 {v2.16b}, [%[a_ptr]], #16\n"
"smull v11.8h, v0.8b, v7.8b\n"
"smull v14.8h, v1.8b, v6.8b\n"
"ld1 {v3.16b}, [%[a_ptr]], #16\n"
"smull v15.8h, v1.8b, v7.8b\n"
"sadalp v16.4s, v8.8h\n"
"smlal2 v10.8h, v0.16b, v6.16b\n"
"sadalp v17.4s, v9.8h\n"
"smlal2 v11.8h, v0.16b, v7.16b\n"
"sadalp v20.4s, v12.8h\n"
"smlal2 v14.8h, v1.16b, v6.16b\n"
"sadalp v21.4s, v13.8h\n"
"smlal2 v15.8h, v1.16b, v7.16b\n"

"smull v8.8h, v2.8b, v4.8b\n"
"smull v9.8h, v2.8b, v5.8b\n"
"ld1 {v0.16b}, [%[a_ptr]], #16\n"
"smull v12.8h, v3.8b, v4.8b\n"
"ld1 {v1.16b}, [%[a_ptr]], #16\n"
"smull v13.8h, v3.8b, v5.8b\n"
"sadalp v18.4s, v10.8h\n"
"smlal2 v8.8h, v2.16b, v4.16b\n"
"sadalp v19.4s, v11.8h\n"
"smlal2 v9.8h, v2.16b, v5.16b\n"
"sadalp v22.4s, v14.8h\n"
"smlal2 v12.8h, v3.16b, v4.16b\n"
"sadalp v23.4s, v15.8h\n"
"smlal2 v13.8h, v3.16b, v5.16b\n"

"sub %w[k], %w[k], #2\n"
"cmp %w[k], #2\n"

"smull v10.8h, v2.8b, v6.8b\n"
"ld1 {v4.16b}, [%[b_ptr]], #16\n"
"smull v11.8h, v2.8b, v7.8b\n"
"ld1 {v5.16b}, [%[b_ptr]], #16\n"
"smull v14.8h, v3.8b, v6.8b\n"
"sadalp v24.4s, v8.8h\n"
"smull v15.8h, v3.8b, v7.8b\n"
"sadalp v25.4s, v9.8h\n"
"smlal2 v10.8h, v2.16b, v6.16b\n"
"sadalp v28.4s, v12.8h\n"
"smlal2 v11.8h, v2.16b, v7.16b\n"
"sadalp v29.4s, v13.8h\n"
"smlal2 v14.8h, v3.16b, v6.16b\n"
"sadalp v26.4s, v10.8h\n"
"smlal2 v15.8h, v3.16b, v7.16b\n"

"sadalp v27.4s, v11.8h\n"
"sadalp v30.4s, v14.8h\n"
"sadalp v31.4s, v15.8h\n"

"bgt 1b\n"
"blt 3f\n"

//! K==2
"2:\n"
"smull v8.8h, v0.8b, v4.8b\n"
"smull v9.8h, v0.8b, v5.8b\n"
"ld1 {v6.16b}, [%[b_ptr]], #16\n"
"smull v12.8h, v1.8b, v4.8b\n"
"smull v13.8h, v1.8b, v5.8b\n"
"ld1 {v7.16b}, [%[b_ptr]], #16\n"
"smlal2 v8.8h, v0.16b, v4.16b\n"
"smlal2 v9.8h, v0.16b, v5.16b\n"
"smlal2 v12.8h, v1.16b, v4.16b\n"
"smlal2 v13.8h, v1.16b, v5.16b\n"

"smull v10.8h, v0.8b, v6.8b\n"
"ld1 {v2.16b}, [%[a_ptr]], #16\n"
"smull v11.8h, v0.8b, v7.8b\n"
"smull v14.8h, v1.8b, v6.8b\n"
"ld1 {v3.16b}, [%[a_ptr]], #16\n"
"smull v15.8h, v1.8b, v7.8b\n"
"sadalp v16.4s, v8.8h\n"
"smlal2 v10.8h, v0.16b, v6.16b\n"
"sadalp v17.4s, v9.8h\n"
"smlal2 v11.8h, v0.16b, v7.16b\n"
"sadalp v20.4s, v12.8h\n"
"smlal2 v14.8h, v1.16b, v6.16b\n"
"sadalp v21.4s, v13.8h\n"
"smlal2 v15.8h, v1.16b, v7.16b\n"

"smull v8.8h, v2.8b, v4.8b\n"
"smull v9.8h, v2.8b, v5.8b\n"
"ld1 {v0.16b}, [%[a_ptr]], #16\n"
"smull v12.8h, v3.8b, v4.8b\n"
"ld1 {v1.16b}, [%[a_ptr]], #16\n"
"smull v13.8h, v3.8b, v5.8b\n"
"sadalp v18.4s, v10.8h\n"
"smlal2 v8.8h, v2.16b, v4.16b\n"
"sadalp v19.4s, v11.8h\n"
"smlal2 v9.8h, v2.16b, v5.16b\n"
"sadalp v22.4s, v14.8h\n"
"smlal2 v12.8h, v3.16b, v4.16b\n"
"sadalp v23.4s, v15.8h\n"
"smlal2 v13.8h, v3.16b, v5.16b\n"

"smull v10.8h, v2.8b, v6.8b\n"
"smull v11.8h, v2.8b, v7.8b\n"
"ld1 {v4.16b}, [%[b_ptr]], #16\n"
"smull v14.8h, v3.8b, v6.8b\n"
"ld1 {v5.16b}, [%[b_ptr]], #16\n"
"smull v15.8h, v3.8b, v7.8b\n"
"sadalp v24.4s, v8.8h\n"
"smlal2 v10.8h, v2.16b, v6.16b\n"
"sadalp v25.4s, v9.8h\n"
"smlal2 v11.8h, v2.16b, v7.16b\n"
"sadalp v28.4s, v12.8h\n"
"smlal2 v14.8h, v3.16b, v6.16b\n"
"sadalp v29.4s, v13.8h\n"
"smlal2 v15.8h, v3.16b, v7.16b\n"
"sadalp v26.4s, v10.8h\n"
"sadalp v27.4s, v11.8h\n"
"sadalp v30.4s, v14.8h\n"
"sadalp v31.4s, v15.8h\n"

//! K==1
"3:\n"
"smull v8.8h, v0.8b, v4.8b\n"
"smull v9.8h, v0.8b, v5.8b\n"
"ld1 {v6.16b}, [%[b_ptr]], #16\n"
"smull v12.8h, v1.8b, v4.8b\n"
"smull v13.8h, v1.8b, v5.8b\n"
"ld1 {v7.16b}, [%[b_ptr]], #16\n"
"smlal2 v8.8h, v0.16b, v4.16b\n"
"smlal2 v9.8h, v0.16b, v5.16b\n"
"smlal2 v12.8h, v1.16b, v4.16b\n"
"smlal2 v13.8h, v1.16b, v5.16b\n"

"smull v10.8h, v0.8b, v6.8b\n"
"ld1 {v2.16b}, [%[a_ptr]], #16\n"
"smull v11.8h, v0.8b, v7.8b\n"
"smull v14.8h, v1.8b, v6.8b\n"
"ld1 {v3.16b}, [%[a_ptr]], #16\n"
"smull v15.8h, v1.8b, v7.8b\n"
"sadalp v16.4s, v8.8h\n"
"smlal2 v10.8h, v0.16b, v6.16b\n"
"sadalp v17.4s, v9.8h\n"
"smlal2 v11.8h, v0.16b, v7.16b\n"
"sadalp v20.4s, v12.8h\n"
"smlal2 v14.8h, v1.16b, v6.16b\n"
"sadalp v21.4s, v13.8h\n"
"smlal2 v15.8h, v1.16b, v7.16b\n"

"smull v8.8h, v2.8b, v4.8b\n"
"smull v9.8h, v2.8b, v5.8b\n"
"smull v12.8h, v3.8b, v4.8b\n"
"smull v13.8h, v3.8b, v5.8b\n"
"sadalp v18.4s, v10.8h\n"
"smlal2 v8.8h, v2.16b, v4.16b\n"
"sadalp v19.4s, v11.8h\n"
"smlal2 v9.8h, v2.16b, v5.16b\n"
"sadalp v22.4s, v14.8h\n"
"smlal2 v12.8h, v3.16b, v4.16b\n"
"sadalp v23.4s, v15.8h\n"
"smlal2 v13.8h, v3.16b, v5.16b\n"

"smull v10.8h, v2.8b, v6.8b\n"
"sadalp v24.4s, v8.8h\n"
"smull v11.8h, v2.8b, v7.8b\n"
"sadalp v25.4s, v9.8h\n"
"smull v14.8h, v3.8b, v6.8b\n"
"sadalp v28.4s, v12.8h\n"
"smull v15.8h, v3.8b, v7.8b\n"
"sadalp v29.4s, v13.8h\n"
"smlal2 v10.8h, v2.16b, v6.16b\n"
"smlal2 v11.8h, v2.16b, v7.16b\n"
"sadalp v26.4s, v10.8h\n"
"smlal2 v14.8h, v3.16b, v6.16b\n"
"sadalp v27.4s, v11.8h\n"
"smlal2 v15.8h, v3.16b, v7.16b\n"
"sadalp v30.4s, v14.8h\n"
"sadalp v31.4s, v15.8h\n"

"addp v4.4s, v16.4s, v20.4s\n"
"addp v5.4s, v24.4s, v28.4s\n"
"addp v6.4s, v17.4s, v21.4s\n"
"addp v7.4s, v25.4s, v29.4s\n"
"addp v8.4s, v18.4s, v22.4s\n"
"addp v9.4s, v26.4s, v30.4s\n"
"addp v10.4s, v19.4s, v23.4s\n"
"addp v11.4s, v27.4s, v31.4s\n"

"addp v0.4s, v4.4s, v5.4s\n"
"addp v1.4s, v6.4s, v7.4s\n"
"addp v2.4s, v8.4s, v9.4s\n"
"addp v3.4s, v10.4s, v11.4s\n"

"cmp %w[is_first_k], #1\n"
"beq 6f\n"

"cmp %w[remain_n], #3\n"
"beq 1003f\n"
"cmp %w[remain_n], #2\n"
"beq 1002f\n"
"cmp %w[remain_n], #1\n"
"beq 1001f\n"
"1003:\n"
"ld1 {v8.4s, v9.4s, v10.4s}, [%[output]]\n"
"add v0.4s, v0.4s, v8.4s\n"
"add v1.4s, v1.4s, v9.4s\n"
"add v2.4s, v2.4s, v10.4s\n"
"b 6f\n"
"1002:\n"
"ld1 {v8.4s, v9.4s}, [%[output]]\n"
"add v0.4s, v0.4s, v8.4s\n"
"add v1.4s, v1.4s, v9.4s\n"
"b 6f\n"
"1001:\n"
"ld1 {v8.4s}, [%[output]]\n"
"add v0.4s, v0.4s, v8.4s\n"

"6:\n"
"cmp %w[remain_n], #3\n"
"beq 10003f\n"
"cmp %w[remain_n], #2\n"
"beq 10002f\n"
"cmp %w[remain_n], #1\n"
"beq 10001f\n"
"10003:\n"
"str q2, [%[output], #32]\n"
"10002:\n"
"str q1, [%[output], #16]\n"
"10001:\n"
"str q0, [%[output]]\n"

"7:\n"

: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr),
[remain_n] "+r"(remain_n), [is_first_k] "+r"(is_first_k),
[k] "+r"(K), [output] "+r"(output)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
"v29", "v30", "v31", "cc", "memory");
}

static void gemm_mk4_s8_4x4_pack_A(dt_int8* outptr, const dt_int8* inptr,
int ldin, int y0, int ymax, int k0,
int kmax) {
//! pack form {oc/4, ic/4, 4(ic), 4(oc)} to {oc/4, ic/16, 4(oc), 16(ic)}
int8_t zerobuff[4][64];
std::memset(zerobuff, 0, sizeof(int8_t) * 64 * 4);
megdnn_assert(ymax % 4 == 0 && y0 % 4 == 0 && (ymax - y0) % 4 == 0,
"mk4 matmul with m is not times of 4");
megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0 && (kmax - k0) % 4 == 0,
"mk4 matmul with k is not times of 4");
size_t roundk = round_up(kmax - k0, 16);
size_t out_offset = roundk * 4;
int y = y0;
int start_y = y0 / 4;
for (; y + 15 < ymax; y += 16, start_y += 4) {
const int8_t* inptr0 = inptr + start_y * ldin + k0 * 4;
const int8_t* inptr1 = inptr0 + ldin;
const int8_t* inptr2 = inptr1 + ldin;
const int8_t* inptr3 = inptr2 + ldin;
int8_t* output = outptr + start_y * out_offset;
prefetch_2x(inptr0);
prefetch_2x(inptr1);
prefetch_2x(inptr2);
prefetch_2x(inptr3);
int K = kmax - k0;
for (; K > 15; K -= 16) {
transpose_interleave_4x4_4_b(inptr0, inptr1, inptr2, inptr3, output,
out_offset);
output += 64;
}
if (K > 0) {
std::memcpy(zerobuff[0], inptr0, sizeof(int8_t) * K * 4);
std::memcpy(zerobuff[1], inptr1, sizeof(int8_t) * K * 4);
std::memcpy(zerobuff[2], inptr2, sizeof(int8_t) * K * 4);
std::memcpy(zerobuff[3], inptr3, sizeof(int8_t) * K * 4);
inptr0 = zerobuff[0];
inptr1 = zerobuff[1];
inptr2 = zerobuff[2];
inptr3 = zerobuff[3];
transpose_interleave_4x4_4_b(inptr0, inptr1, inptr2, inptr3, output,
out_offset);
output += 64;
}
}
for (; y + 3 < ymax; y += 4, start_y++) {
const int8_t* inptr0 = inptr + start_y * ldin + k0 * 4;
int8_t* output = outptr + start_y * out_offset;
prefetch_2x(inptr0);
int K = kmax - k0;
for (; K > 15; K -= 16) {
transpose_interleave_1x4_4_b(inptr0, output);
output += 64;
}
if (K > 0) {
std::memcpy(zerobuff[0], inptr0, sizeof(int8_t) * K * 4);
inptr0 = zerobuff[0];
transpose_interleave_1x4_4_b(inptr0, output);
output += 64;
}
}
}

static void gemm_mk4_s8_4x4_pack_B(dt_int8* out, const dt_int8* in, int ldin,
int x0, int xmax, int k0, int kmax) {
int32_t zerobuff[4];
std::memset(zerobuff, 0, sizeof(int8_t) * 16);
const int ksize = kmax - k0;
const int ICB = (ksize) / 4;
const int ksize4 = round_up<int>(ICB, 4) * 4;
int32_t* outptr = reinterpret_cast<int32_t*>(out);
megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0 && ksize % 4 == 0,
"mk4 matmul with k is not times of 4");

int k = k0 / 4;
for (; k + 3 < ICB; k += 4) {
const int32_t* inptr0 =
reinterpret_cast<const int32_t*>(in + k * ldin + x0);
const int32_t* inptr1 =
reinterpret_cast<const int32_t*>(in + (k + 1) * ldin + x0);
const int32_t* inptr2 =
reinterpret_cast<const int32_t*>(in + (k + 2) * ldin + x0);
const int32_t* inptr3 =
reinterpret_cast<const int32_t*>(in + (k + 3) * ldin + x0);
int32_t* outptr_inner = outptr;

int x = x0;
for (; x + 3 < xmax; x += 4) {
transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr_inner);
outptr_inner += ksize4;
}
if (x < xmax) {
for (; x < xmax; x++) {
*outptr_inner++ = *inptr0++;
*outptr_inner++ = *inptr1++;
*outptr_inner++ = *inptr2++;
*outptr_inner++ = *inptr3++;
}
}
outptr += 4 * 4;
}
if (k < ICB) {
const int32_t* inptr0 =
reinterpret_cast<const int32_t*>(in + k * ldin + x0);
const int32_t* inptr1 =
reinterpret_cast<const int32_t*>(in + (k + 1) * ldin + x0);
const int32_t* inptr2 =
reinterpret_cast<const int32_t*>(in + (k + 2) * ldin + x0);
const int32_t* inptr3 =
reinterpret_cast<const int32_t*>(in + (k + 3) * ldin + x0);
int32_t* outptr_inner = outptr;

int x = x0;
for (; x + 3 < xmax; x += 4) {
if (k + 3 >= ICB) {
switch (k + 3 - ICB) {
case 2:
inptr1 = zerobuff;
case 1:
inptr2 = zerobuff;
case 0:
inptr3 = zerobuff;
break;
default:
megdnn_assert(0);
}
}
transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr_inner);
outptr_inner += ksize4;
}
if (x < xmax) {
if (k + 3 >= ICB) {
switch (k + 3 - ICB) {
case 2:
inptr1 = zerobuff;
case 1:
inptr2 = zerobuff;
case 0:
inptr3 = zerobuff;
break;
default:
megdnn_assert(0);
}
}
for (; x < xmax; x++) {
*outptr_inner++ = *inptr0++;
*outptr_inner++ = *inptr1++;
*outptr_inner++ = *inptr2++;
*outptr_inner++ = *inptr3++;
}
}
outptr += 4 * 4;
}
}

} // namespace matmul_4x4x16
} // namespace aarch64
} // namespace megdnn
#endif

// vim: syntax=cpp.doxygen

+ 263
- 0
dnn/src/aarch64/matrix_mul/int8/strategy.cpp View File

@@ -0,0 +1,263 @@
/**
* \file dnn/src/aarch64/matrix_mul/int8/strategy.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#if !(__ARM_FEATURE_DOTPROD)
#include "src/aarch64/matrix_mul/int8/strategy.h"
#include "src/aarch64/matrix_mul/asm/common.h"
#include "src/aarch64/matrix_mul/int8/kernel_4x4x16.h"
#include "src/aarch64/matrix_mul/int8/kernel_8x8x8.h"
#include "src/aarch64/matrix_mul/int8/kernel_mk4_4x4x16.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"

using namespace megdnn;
using namespace aarch64;
using namespace aarch64::matmul;

///////////////////////// gemm_s8_4x4 ////////////////////////////////////
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_4x4);

void gemm_s8_4x4::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin,
int y0, int ymax, int k0, int kmax,
bool transpose) const {
if (transpose) {
matmul_4x4x16::gemm_s8_4x4_pack_B_n(outptr, inptr, ldin, y0, ymax, k0,
kmax);
} else {
matmul_4x4x16::gemm_s8_4x4_pack_A_n(outptr, inptr, ldin, y0, ymax, k0,
kmax);
}
}

void gemm_s8_4x4::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0,
int xmax, int k0, int kmax, bool transpose) const {
if (transpose) {
matmul_4x4x16::gemm_s8_4x4_pack_A_n(out, in, ldin, x0, xmax, k0, kmax);
} else {
matmul_4x4x16::gemm_s8_4x4_pack_B_n(out, in, ldin, x0, xmax, k0, kmax);
}
}

void gemm_s8_4x4::kern(const dt_int8* packA, const dt_int8* packB, size_t M,
size_t N, size_t K, dt_int32* C, size_t LDC,
bool is_first_k, const dt_int32*, dt_int32*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
((A_dtype.enumv() == DTypeEnum::Int8 &&
C_dtype.enumv() == DTypeEnum::Int32) ||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 &&
C_dtype.enumv() == DTypeEnum::QuantizedS32)),
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
C_dtype.name());
MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype);
MEGDNN_MARK_USED_VAR(C_dtype);

constexpr size_t A_INTERLEAVE = 4;
constexpr size_t B_INTERLEAVE = 4;
//! K is packed to times of 4
K = round_up<size_t>(K, 16);
const int K4 = K * 4;

size_t m = 0;
for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) {
int32_t* output = C + (m * LDC);

size_t n = 0;
const dt_int8* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_4x4x16::kern_4x4(packA, cur_packB, K, output, LDC,
is_first_k);
output += B_INTERLEAVE;
cur_packB += K4;
}

for (; n < N; n += B_INTERLEAVE) {
matmul_4x4x16::kern_4x4_remain(packA, cur_packB, K, output, LDC,
is_first_k, 4,
std::min<size_t>(N - n, 4));
output += B_INTERLEAVE;
cur_packB += K4;
}

packA += K4;
}

for (; m < M; m += 4) {
int32_t* output = C + (m * LDC);

size_t n = 0;
const dt_int8* cur_packB = packB;
for (; n < N; n += B_INTERLEAVE) {
matmul_4x4x16::kern_4x4_remain(
packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4));
output += B_INTERLEAVE;
cur_packB += K4;
}
packA += K4;
}
}

///////////////////////// gemm_mk4_s8_4x4 ////////////////////////////////////
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_mk4_s8_4x4);

void gemm_mk4_s8_4x4::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin,
int y0, int ymax, int k0, int kmax,
bool transpose) const {
megdnn_assert(!transpose,
"the gemm_mk4_s8_4x4 strategy is not support transpose A");
matmul_mk4_4x4x16::gemm_mk4_s8_4x4_pack_A(outptr, inptr, ldin, y0, ymax, k0,
kmax);
}

void gemm_mk4_s8_4x4::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0,
int xmax, int k0, int kmax, bool transpose) const {
megdnn_assert(!transpose,
"the gemm_mk4_s8_4x4 strategy is not support transpose B");
matmul_mk4_4x4x16::gemm_mk4_s8_4x4_pack_B(out, in, ldin, x0, xmax, k0,
kmax);
}

void gemm_mk4_s8_4x4::kern(const dt_int8* packA, const dt_int8* packB, size_t M,
size_t N, size_t K, dt_int32* C, size_t LDC,
bool is_first_k, const dt_int32*, dt_int32*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
((A_dtype.enumv() == DTypeEnum::Int8 &&
C_dtype.enumv() == DTypeEnum::Int32) ||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 &&
C_dtype.enumv() == DTypeEnum::QuantizedS32)),
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
C_dtype.name());
MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype);
MEGDNN_MARK_USED_VAR(C_dtype);

constexpr size_t A_INTERLEAVE = 4;
constexpr size_t B_INTERLEAVE = 4;
//! K is packed to times of 4
megdnn_assert(K % 4 == 0, "K is not time of 4");
const size_t K4 = round_up<size_t>(K, 16) * 4;

size_t m = 0;
for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) {
int32_t* output = C + (m / 4 * LDC);

size_t n = 0;
const dt_int8* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_mk4_4x4x16::kern_4x4(packA, cur_packB, K, output,
is_first_k);
output += B_INTERLEAVE * 4;
cur_packB += K4;
}

if (n < N) {
matmul_mk4_4x4x16::kern_4x4_remain(packA, cur_packB, K, output,
is_first_k, N - n);
}

packA += K4;
}
}


///////////////////////// gemm_s8_8x8 ////////////////////////////////////
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_8x8);

void gemm_s8_8x8::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin,
int y0, int ymax, int k0, int kmax,
bool transpose) const {
if (transpose) {
matmul_8x8x8::gemm_s8_8x8_transpose_pack_A_n(outptr, inptr, ldin, y0,
ymax, k0, kmax);
} else {
matmul_8x8x8::gemm_s8_8x8_pack_A_n(outptr, inptr, ldin, y0, ymax, k0,
kmax);
}
}

void gemm_s8_8x8::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0,
int xmax, int k0, int kmax, bool transpose) const {
if (transpose) {
matmul_8x8x8::gemm_s8_8x8_transpose_pack_B_n(out, in, ldin, x0, xmax,
k0, kmax);
} else {
matmul_8x8x8::gemm_s8_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax);
}
}

void gemm_s8_8x8::kern(const dt_int8* packA, const dt_int8* packB, size_t M,
size_t N, size_t K, dt_int32* C, size_t LDC,
bool is_first_k, const dt_int32*, dt_int32*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
((A_dtype.enumv() == DTypeEnum::Int8 &&
C_dtype.enumv() == DTypeEnum::Int32) ||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 &&
C_dtype.enumv() == DTypeEnum::QuantizedS32)),
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
C_dtype.name());
MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype);
MEGDNN_MARK_USED_VAR(C_dtype);

constexpr size_t A_INTERLEAVE = 8;
constexpr size_t B_INTERLEAVE = 8;
//! K is packed to times of 4
K = round_up<size_t>(K, 8);
const int K8 = K * 8;
const int K4 = K * 4;

size_t m = 0;
for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) {
int32_t* output = C + (m * LDC);

size_t n = 0;
const dt_int8* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_8x8x8::kern_8x8(packA, cur_packB, K, output, LDC,
is_first_k);
output += B_INTERLEAVE;
cur_packB += K8;
}

for (; n < N; n += 4) {
matmul_8x8x8::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(N - n, 4));
output += 4;
cur_packB += K4;
}
packA += K8;
}

for (; m < M; m += 4) {
int32_t* output = C + (m * LDC);
const dt_int8* cur_packB = packB;
size_t n = 0;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_8x8x8::kern_4x8(packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4));
output += B_INTERLEAVE;
cur_packB += K8;
}

for (; n < N; n += 4) {
matmul_8x8x8::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4),
std::min<size_t>(N - n, 4));
output += 4;
cur_packB += K4;
}
packA += K4;
}
}
#endif

// vim: syntax=cpp.doxygen

+ 34
- 0
dnn/src/aarch64/matrix_mul/int8/strategy.h View File

@@ -0,0 +1,34 @@
/**
* \file dnn/src/aarch64/matrix_mul/int8/strategy.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once

#if !(__ARM_FEATURE_DOTPROD)
#include "src/fallback/matrix_mul/gemm_common.h"

namespace megdnn {
namespace aarch64 {
namespace matmul {

MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 4, 4, 16, false, true,
gemm_s8_4x4);

MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 4, 4, 16, false, false,
gemm_mk4_s8_4x4);

MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 8, 8, 8, false, true,
gemm_s8_8x8);

} // namespace matmul
} // namespace aarch64
} // namespace megdnn

#endif
// vim: syntax=cpp.doxygen

+ 116
- 0
dnn/src/aarch64/matrix_mul/int8_dot/gemv.cpp View File

@@ -0,0 +1,116 @@
/**
* \file dnn/src/aarch64/matrix_mul/int8_dot/gemv.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/aarch64/matrix_mul/int8_dot/gemv.h"
#include <cstddef>
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"
#include "src/common/unroll_macro.h"

#if __ARM_FEATURE_DOTPROD

namespace {

void gemv_naive_n(const int8_t* __restrict A, const int8_t* __restrict B,
int32_t* __restrict C, size_t M, size_t N, size_t K,
size_t Astride, size_t Bstride, size_t Cstride) {
megdnn_assert(N == 1 && Bstride == 1);
size_t m = 0;
for (; m + 2 <= M; m += 2) {
int32_t acc[4];
int32x4_t acc_neon = vdupq_n_s32(0);
size_t k = 0;
for (; k + 16 <= K; k += 16) {
int64x2_t a0 = vreinterpretq_s64_s8(vld1q_s8(A + m * Astride + k));
int64x2_t a1 =
vreinterpretq_s64_s8(vld1q_s8(A + (m + 1) * Astride + k));
//! the first 8 elements is m, the last 8 elements is m + 1
int8x16_t a2 = vreinterpretq_s8_s64(vzip1q_s64(a0, a1));
int8x16_t a3 = vreinterpretq_s8_s64(vzip2q_s64(a0, a1));

int64x2_t b0 = vreinterpretq_s64_s8(vld1q_s8(B + k));
int8x16_t b2 = vreinterpretq_s8_s64(vzip1q_s64(b0, b0));
int8x16_t b3 = vreinterpretq_s8_s64(vzip2q_s64(b0, b0));

acc_neon = vdotq_s32(acc_neon, a2, b2);
acc_neon = vdotq_s32(acc_neon, a3, b3);
}
vst1q_s32(acc, acc_neon);

for (; k + 8 <= K; k += 8) {
int8x8_t a0 = vld1_s8(A + m * Astride + k);
int8x8_t a1 = vld1_s8(A + (m + 1) * Astride + k);
int8x8_t b0 = vld1_s8(B + k);
uint32x2_t zero = vdup_n_s32(0);
acc[0] += vaddv_s32(vdot_s32(zero, a0, b0));
zero = vdup_n_s32(0);
acc[3] += vaddv_s32(vdot_s32(zero, a1, b0));
}

for (; k < K; ++k) {
acc[0] += static_cast<int32_t>(A[m * Astride + k]) * B[k];
acc[3] += static_cast<int32_t>(A[(m + 1) * Astride + k]) * B[k];
}
C[m * Cstride] = acc[0] + acc[1];
C[(m + 1) * Cstride] = acc[2] + acc[3];
}

for (; m < M; ++m) {
int32_t acc[4];
int32x4_t acc_neon = vdupq_n_s32(0);
size_t k = 0;
for (; k + 16 <= K; k += 16) {
int8x16_t a0 = vld1q_s8(A + m * Astride + k);
int8x16_t b0 = vld1q_s8(B + k);
acc_neon = vdotq_s32(acc_neon, a0, b0);
}
vst1q_s32(acc, acc_neon);

for (; k + 8 <= K; k += 8) {
int8x8_t a0 = vld1_s8(A + m * Astride + k);
int8x8_t b0 = vld1_s8(B + k);
uint32x2_t zero = vdup_n_s32(0);
acc[0] += vaddv_s32(vdot_s32(zero, a0, b0));
}

for (; k < K; ++k) {
acc[0] += static_cast<int32_t>(A[m * Astride + k]) * B[k];
}
C[m * Cstride] = acc[0] + acc[1] + acc[2] + acc[3];
}
}

} // namespace

bool megdnn::aarch64::matmul::is_gemv_like_preferred_int8(
bool transposeA, bool transposeB, size_t M, size_t N, size_t K,
size_t /* LDA */, size_t LDB, size_t /* LDC */) {
if (transposeA)
return false;
if (transposeB)
return false;
MEGDNN_MARK_USED_VAR(K);
MEGDNN_MARK_USED_VAR(M);
return (N == 1 && LDB == 1);
}

void megdnn::aarch64::matmul::gemv_like_int8(const int8_t* __restrict A,
const int8_t* __restrict B,
int32_t* __restrict C, size_t M,
size_t N, size_t K, size_t Astride,
size_t Bstride, size_t Cstride) {
megdnn_assert(N == 1);
return gemv_naive_n(A, B, C, M, N, K, Astride, Bstride, Cstride);
}

#endif

// vim: syntax=cpp.doxygen

+ 34
- 0
dnn/src/aarch64/matrix_mul/int8_dot/gemv.h View File

@@ -0,0 +1,34 @@
/**
* \file dnn/src/aarch64/matrix_mul/int8_dot/gemv.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once

#include <cstddef>
#include <cstdint>

#if __ARM_FEATURE_DOTPROD
namespace megdnn {
namespace aarch64 {
namespace matmul {

bool is_gemv_like_preferred_int8(bool transposeA, bool transposeB, size_t M,
size_t N, size_t K, size_t LDA, size_t LDB,
size_t LDC);

void gemv_like_int8(const int8_t* __restrict A, const int8_t* __restrict B,
int32_t* __restrict C, size_t M, size_t N, size_t K,
size_t Astride, size_t Bstride, size_t Cstride);

} // namespace matmul
} // namespace aarch64
} // namespace megdnn
#endif

// vim: syntax=cpp.doxygen

+ 1552
- 0
dnn/src/aarch64/matrix_mul/int8_dot/kernel_8x12x4.h
File diff suppressed because it is too large
View File


+ 113
- 0
dnn/src/aarch64/matrix_mul/int8_dot/strategy.cpp View File

@@ -0,0 +1,113 @@
/**
* \file dnn/src/aarch64/matrix_mul/int8_dot/strategy.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/aarch64/matrix_mul/int8_dot/strategy.h"
#include "src/aarch64/matrix_mul/asm/common.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"
#include "src/aarch64/matrix_mul/int8_dot/kernel_8x12x4.h"

#if __ARM_FEATURE_DOTPROD
using namespace megdnn;
using namespace aarch64;
using namespace aarch64::matmul;

MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_8x12);

void gemm_s8_8x12::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin,
int y0, int ymax, int k0, int kmax,
bool transpose) const {
if (transpose) {
matmul_8x12x4::gemm_s8_8x12_pack_A_t(outptr, inptr, ldin, y0, ymax, k0,
kmax);
} else {
matmul_8x12x4::gemm_s8_8x12_pack_A_n(outptr, inptr, ldin, y0, ymax, k0,
kmax);
}
}

void gemm_s8_8x12::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0,
int xmax, int k0, int kmax, bool transpose) const {
if (transpose) {
matmul_8x12x4::gemm_s8_8x12_pack_B_t(out, in, ldin, x0, xmax, k0, kmax);
} else {
matmul_8x12x4::gemm_s8_8x12_pack_B_n(out, in, ldin, x0, xmax, k0, kmax);
}
}

void gemm_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB, size_t M,
size_t N, size_t K, dt_int32* C, size_t LDC,
bool is_first_k, const dt_int32*, dt_int32*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
((A_dtype.enumv() == DTypeEnum::Int8 &&
C_dtype.enumv() == DTypeEnum::Int32) ||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 &&
C_dtype.enumv() == DTypeEnum::QuantizedS32)),
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
C_dtype.name());

MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype);
MEGDNN_MARK_USED_VAR(C_dtype);

constexpr size_t A_INTERLEAVE = 8;
constexpr size_t B_INTERLEAVE = 12;
//! K is packed to times of 4
K = round_up<size_t>(K, 4);
const int K8 = (K << 3);
const int K12 = K * 12;
const int K4 = K * 4;

size_t m = 0;
for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) {
int32_t* output = C + (m * LDC);

size_t n = 0;
const dt_int8* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_8x12x4::kern_8x12(packA, cur_packB, K, output, LDC,
is_first_k);
output += B_INTERLEAVE;
cur_packB += K12;
}

for (; n < N; n += 4) {
matmul_8x12x4::kern_8x4(packA, cur_packB, K, output, LDC,
is_first_k, std::min<size_t>(N - n, 4));
output += 4;
cur_packB += K4;
}
packA += K8;
}

for (; m < M; m += 4) {
int32_t* output = C + (m * LDC);
const dt_int8* cur_packB = packB;
size_t n = 0;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_8x12x4::kern_4x12(packA, cur_packB, K, output, LDC,
is_first_k, std::min<size_t>(M - m, 4));
output += B_INTERLEAVE;
cur_packB += K12;
}

for (; n < N; n += 4) {
matmul_8x12x4::kern_4x4(packA, cur_packB, K, output, LDC,
is_first_k, std::min<size_t>(M - m, 4),
std::min<size_t>(N - n, 4));
output += 4;
cur_packB += K4;
}
packA += K4;
}
}
#endif
// vim: syntax=cpp.doxygen

+ 26
- 0
dnn/src/aarch64/matrix_mul/int8_dot/strategy.h View File

@@ -0,0 +1,26 @@
/**
* \file dnn/src/aarch64/matrix_mul/int8_dot/strategy.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "src/fallback/matrix_mul/gemm_common.h"

#if __ARM_FEATURE_DOTPROD
namespace megdnn {
namespace aarch64 {
namespace matmul {

MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 8, 12, 4, false, true,
gemm_s8_8x12);

} // namespace aarch64
} // namespace matmul
} // namespace megdnn
#endif
// vim: syntax=cpp.doxygen

+ 439
- 0
dnn/src/aarch64/matrix_mul/int8x8x16/kernel_4x4x16.h View File

@@ -0,0 +1,439 @@
/**
* \file dnn/src/aarch64/matrix_mul/int8x8x16/kernel_4x4x16.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include <inttypes.h>
#include "src/aarch64/matrix_mul/asm/common.h"
#include "src/arm_common/simd_macro/marm_neon.h"

namespace megdnn {
namespace aarch64 {
namespace matmul_4x4x16 {

/**
* Overview of register layout:
*
* +---------+---------+---------+---------+
* |v20[0-15]|v21[0-15]|v22[0-15]|v23[0-15]|
* Rhs +---------+---------+---------+---------+
* Lhs | | |
*
* +--------+ - - - - +---------+---------+---------+---------+
* |v0[0-15]| | v4[0-8] | v8[0-8]| v12[0-8]| v16[0-8]|
* |v1[0-15]| | v5[0-8] | v9[0-8]| v13[0-8]| v17[0-8]|
* |v2[0-15]| | v6[0-8] | v10[0-8]| v14[0-8]| v18[0-8]|
* |v3[0-15]| | v7[0-8] | v11[0-8]| v15[0-8]| v19[0-8]|
* +--------+ - - - - +---------+---------+---------+---------+
*
* Accumulator
*/
static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
int16_t* output, int LDC, bool is_first_k, int m_remain,
int n_remain) {
K /= 16;
const int8_t* a_ptr = packA;
const int8_t* b_ptr = packB;

LDC = LDC * sizeof(int16_t);
// clang-format off
#define LOAD_LINE(reg_index, n) \
"cmp x5, #0 \n" \
"beq 105f\n" \
"cmp %w[n_remain], #4\n" \
"blt 100" n "f\n" \
"ld1 {v" reg_index ".4h}, [x" n "], #8\n" \
"b 101" n "f\n" \
"100" n ":\n" \
"cmp %w[n_remain], #0\n" \
"blt 101" n "f\n" \
"ld1 {v" reg_index ".h}[0], [x" n "], #2\n" \
"cmp %w[n_remain], #1\n" \
"beq 101" n "f\n" \
"ld1 {v" reg_index ".h}[1], [x" n "], #2\n" \
"cmp %w[n_remain], #2\n" \
"beq 101" n "f\n" \
"ld1 {v" reg_index ".h}[2], [x" n "], #2\n" \
"101" n ":\n" \
"sub x5, x5, #1\n"

#define LOAD_C \
"mov x5, %x[m_remain]\n" \
LOAD_LINE("24", "0") \
LOAD_LINE("25", "1") \
LOAD_LINE("26", "2") \
LOAD_LINE("27", "3") \
"105:\n"


#define STORE_LINE(reg_index, n) \
"cmp x5, #0 \n" \
"beq 105f\n" \
"cmp %w[n_remain], #4\n" \
"blt 102" n "f\n" \
"st1 {v" reg_index ".4h}, [x" n "], #8\n" \
"b 103" n "f\n" \
"102" n ":\n" \
"cmp %w[n_remain], #0\n" \
"beq 103" n "f\n" \
"st1 {v" reg_index ".h}[0], [x" n "], #2\n" \
"cmp %w[n_remain], #1\n" \
"beq 103" n "f\n" \
"st1 {v" reg_index ".h}[1], [x" n "], #2\n" \
"cmp %w[n_remain], #2\n" \
"beq 103" n "f\n" \
"st1 {v" reg_index ".h}[2], [x" n "], #2\n" \
"103" n ":\n" \
"sub x5, x5, #1\n"

#define STORE_C \
"mov x5, %x[m_remain]\n" \
STORE_LINE("24", "0") \
STORE_LINE("25", "1") \
STORE_LINE("26", "2") \
STORE_LINE("27", "3") \
"105:\n"
// clang-format on

register int16_t* outptr asm("x0") = output;
asm volatile(
"add x1, x0, %x[LDC]\n"
"add x2, x1, %x[LDC]\n"
"add x3, x2, %x[LDC]\n"

// Clear accumulators
"eor v4.16b, v4.16b, v4.16b\n"
"eor v5.16b, v5.16b, v5.16b\n"
"eor v6.16b, v6.16b, v6.16b\n"
"eor v7.16b, v7.16b, v7.16b\n"
"eor v8.16b, v8.16b, v8.16b\n"
"eor v9.16b, v9.16b, v9.16b\n"
"eor v10.16b, v10.16b, v10.16b\n"
"eor v11.16b, v11.16b, v11.16b\n"
"eor v12.16b, v12.16b, v12.16b\n"
"eor v13.16b, v13.16b, v13.16b\n"
"eor v14.16b, v14.16b, v14.16b\n"
"eor v15.16b, v15.16b, v15.16b\n"
"eor v16.16b, v16.16b, v16.16b\n"
"eor v17.16b, v17.16b, v17.16b\n"
"eor v18.16b, v18.16b, v18.16b\n"
"eor v19.16b, v19.16b, v19.16b\n"

// General loop.
"1:\n"
"ld1 {v20.16b}, [%[b_ptr]], 16\n"
"ld1 {v0.16b}, [%[a_ptr]], 16\n"
"ld1 {v1.16b}, [%[a_ptr]], 16\n"
"ld1 {v2.16b}, [%[a_ptr]], 16\n"
"ld1 {v3.16b}, [%[a_ptr]], 16\n"

"ld1 {v21.16b}, [%[b_ptr]], 16\n"
"smlal v4.8h, v0.8b, v20.8b\n"
"smlal v5.8h, v1.8b, v20.8b\n"
"smlal v6.8h, v2.8b, v20.8b\n"
"smlal v7.8h, v3.8b, v20.8b\n"
"smlal2 v4.8h, v0.16b, v20.16b\n"
"smlal2 v5.8h, v1.16b, v20.16b\n"
"smlal2 v6.8h, v2.16b, v20.16b\n"
"smlal2 v7.8h, v3.16b, v20.16b\n"

"ld1 {v22.16b}, [%[b_ptr]], 16\n"
"smlal v8.8h, v0.8b, v21.8b\n"
"smlal v9.8h, v1.8b, v21.8b\n"
"smlal v10.8h, v2.8b, v21.8b\n"
"smlal v11.8h, v3.8b, v21.8b\n"
"smlal2 v8.8h, v0.16b, v21.16b\n"
"smlal2 v9.8h, v1.16b, v21.16b\n"
"smlal2 v10.8h, v2.16b, v21.16b\n"
"smlal2 v11.8h, v3.16b, v21.16b\n"

"ld1 {v23.16b}, [%[b_ptr]], 16\n"
"smlal v12.8h, v0.8b, v22.8b\n"
"smlal v13.8h, v1.8b, v22.8b\n"
"smlal v14.8h, v2.8b, v22.8b\n"
"smlal v15.8h, v3.8b, v22.8b\n"
"smlal2 v12.8h, v0.16b, v22.16b\n"
"smlal2 v13.8h, v1.16b, v22.16b\n"
"smlal2 v14.8h, v2.16b, v22.16b\n"
"smlal2 v15.8h, v3.16b, v22.16b\n"

"smlal v16.8h, v0.8b, v23.8b\n"
"smlal v17.8h, v1.8b, v23.8b\n"
"smlal v18.8h, v2.8b, v23.8b\n"
"smlal v19.8h, v3.8b, v23.8b\n"
"smlal2 v16.8h, v0.16b, v23.16b\n"
"smlal2 v17.8h, v1.16b, v23.16b\n"
"smlal2 v18.8h, v2.16b, v23.16b\n"
"smlal2 v19.8h, v3.16b, v23.16b\n"

"subs %w[K], %w[K], #1\n"
"cbnz %w[K], 1b\n"

"cmp %w[is_first_k], #1\n"
"beq 2f\n" LOAD_C
"b 3f\n"

"2:\n" // Clear the C regs.
"eor v24.16b, v24.16b, v24.16b\n"
"eor v25.16b, v25.16b, v25.16b\n"
"eor v26.16b, v26.16b, v26.16b\n"
"eor v27.16b, v27.16b, v27.16b\n"

"3:\n"
// Reduce v4-v19 to v0-v3
"addv h20, v4.8h\n"
"addv h21, v8.8h\n"
"addv h22, v12.8h\n"
"addv h23, v16.8h\n"
"ins v0.h[0], v20.h[0]\n"
"ins v0.h[1], v21.h[0]\n"
"ins v0.h[2], v22.h[0]\n"
"ins v0.h[3], v23.h[0]\n"
"add v24.4h, v24.4h, v0.4h\n"

"addv h28, v5.8h\n"
"addv h29, v9.8h\n"
"addv h30, v13.8h\n"
"addv h31, v17.8h\n"
"ins v1.h[0], v28.h[0]\n"
"ins v1.h[1], v29.h[0]\n"
"ins v1.h[2], v30.h[0]\n"
"ins v1.h[3], v31.h[0]\n"
"add v25.4h, v25.4h, v1.4h\n"

"addv h20, v6.8h\n"
"addv h21, v10.8h\n"
"addv h22, v14.8h\n"
"addv h23, v18.8h\n"
"ins v2.h[0], v20.h[0]\n"
"ins v2.h[1], v21.h[0]\n"
"ins v2.h[2], v22.h[0]\n"
"ins v2.h[3], v23.h[0]\n"
"add v26.4h, v26.4h, v2.4h\n"

"addv h28, v7.8h\n"
"addv h29, v11.8h\n"
"addv h30, v15.8h\n"
"addv h31, v19.8h\n"
"ins v3.h[0], v28.h[0]\n"
"ins v3.h[1], v29.h[0]\n"
"ins v3.h[2], v30.h[0]\n"
"ins v3.h[3], v31.h[0]\n"
"add v27.4h, v27.4h, v3.4h\n"

// Store back into memory
STORE_C

: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr),
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC),
[outptr] "+r"(outptr), [m_remain] "+r"(m_remain),
[n_remain] "+r"(n_remain)
:
: "cc", "memory", "x1", "x2", "x3", "x4", "x5", "v0", "v1", "v2",
"v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
"v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21",
"v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30",
"v31");

#undef LOAD_LINE
#undef LOAD_C
#undef STORE_LINE
#undef STORE_C
}

static void gemm_s8x8x16_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr,
int ldin, int y0, int ymax, int k0,
int kmax) {
int8_t zerobuff[16];
std::memset(zerobuff, 0, sizeof(int8_t) * 16);

int y = y0;
for (; y + 3 < ymax; y += 4) {
const int8_t* inptr0 = inptr + y * ldin + k0;
const int8_t* inptr1 = inptr0 + ldin;
const int8_t* inptr2 = inptr1 + ldin;
const int8_t* inptr3 = inptr2 + ldin;

prefetch_2x(inptr0);
prefetch_2x(inptr1);
prefetch_2x(inptr2);
prefetch_2x(inptr3);

int K = kmax - k0;
//! read 4 * 16 in each row
for (; K > 15; K -= 16) {
interleave_4x16_1_b(inptr0, inptr1, inptr2, inptr3, outptr);
}

if (K > 0) {
interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 16, K);
}
}
for (; y < ymax; y += 4) {
const int8_t* inptr0 = inptr + y * ldin + k0;
const int8_t* inptr1 = inptr0 + ldin;
const int8_t* inptr2 = inptr1 + ldin;
const int8_t* inptr3 = inptr2 + ldin;

prefetch_2x(inptr0);
prefetch_2x(inptr1);
prefetch_2x(inptr2);
prefetch_2x(inptr3);

int K = kmax - k0;
//! read 4 * 16 in each row
for (; K > 15; K -= 16) {
if (y + 3 >= ymax) {
switch (y + 3 - ymax) {
case 2:
inptr1 = zerobuff;
case 1:
inptr2 = zerobuff;
case 0:
inptr3 = zerobuff;
break;
default:
megdnn_assert(0);
}
}
interleave_4x16_1_b(inptr0, inptr1, inptr2, inptr3, outptr);
}

if (K > 0) {
if (y + 3 >= ymax) {
switch (y + 3 - ymax) {
case 2:
inptr1 = zerobuff;
case 1:
inptr2 = zerobuff;
case 0:
inptr3 = zerobuff;
break;
default:
megdnn_assert(0);
}
}
interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 16, K);
}
}
}

static void gemm_s8x8x16_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin,
int x0, int xmax, int k0, int kmax) {
int8_t zerobuff[16];
std::memset(zerobuff, 0, sizeof(int8_t) * 16);
const int ksize = kmax - k0;
const int ksize4 = round_up(ksize, 16) * 4;
int8_t* outptr = out;

int k = k0;
for (; k < kmax; k += 16) {
int ki = k;
for (int cnt = 0; cnt < 2; ki += 8, cnt++) {
const int8_t* inptr0 = in + ki * ldin + x0;
const int8_t* inptr1 = inptr0 + ldin;
const int8_t* inptr2 = inptr1 + ldin;
const int8_t* inptr3 = inptr2 + ldin;
const int8_t* inptr4 = inptr3 + ldin;
const int8_t* inptr5 = inptr4 + ldin;
const int8_t* inptr6 = inptr5 + ldin;
const int8_t* inptr7 = inptr6 + ldin;

prefetch_2x(inptr0);
prefetch_2x(inptr1);
prefetch_2x(inptr2);
prefetch_2x(inptr3);
prefetch_2x(inptr4);
prefetch_2x(inptr5);
prefetch_2x(inptr6);
prefetch_2x(inptr7);

int8_t* outptr_inner = outptr + ki - k;

int remain = std::min(ki + 7 - kmax, 7);
int x = x0;
for (; x + 3 < xmax; x += 4) {
if (remain >= 0) {
switch (remain) {
case 7:
inptr0 = zerobuff;
case 6:
inptr1 = zerobuff;
case 5:
inptr2 = zerobuff;
case 4:
inptr3 = zerobuff;
case 3:
inptr4 = zerobuff;
case 2:
inptr5 = zerobuff;
case 1:
inptr6 = zerobuff;
case 0:
inptr7 = zerobuff;
break;
default:
megdnn_assert(0);
}
}

transpose_4x16_1_b_helper(inptr0, inptr1, inptr2, inptr3,
inptr4, inptr5, inptr6, inptr7,
outptr_inner);
outptr_inner += ksize4;
}

if (x < xmax) {
if (remain >= 0) {
switch (remain) {
case 7:
inptr0 = zerobuff;
case 6:
inptr1 = zerobuff;
case 5:
inptr2 = zerobuff;
case 4:
inptr3 = zerobuff;
case 3:
inptr4 = zerobuff;
case 2:
inptr5 = zerobuff;
case 1:
inptr6 = zerobuff;
case 0:
inptr7 = zerobuff;
break;
default:
megdnn_assert(0);
}
}

for (; x < xmax; x++) {
*outptr_inner++ = *inptr0++;
*outptr_inner++ = *inptr1++;
*outptr_inner++ = *inptr2++;
*outptr_inner++ = *inptr3++;
*outptr_inner++ = *inptr4++;
*outptr_inner++ = *inptr5++;
*outptr_inner++ = *inptr6++;
*outptr_inner++ = *inptr7++;
outptr_inner += 8;
}
}
}

outptr += 16 * 4;
}
}

} // namespace matmul_4x4x16
} // namespace aarch64
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 1300
- 0
dnn/src/aarch64/matrix_mul/int8x8x16/kernel_8x8x8.h
File diff suppressed because it is too large
View File


+ 200
- 0
dnn/src/aarch64/matrix_mul/int8x8x16/strategy.cpp View File

@@ -0,0 +1,200 @@
/**
* \file dnn/src/aarch64/matrix_mul/int8x8x16/strategy.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/aarch64/matrix_mul/asm/common.h"
#include "src/aarch64/matrix_mul/int8x8x16/kernel_4x4x16.h"
#include "src/aarch64/matrix_mul/int8x8x16/kernel_8x8x8.h"
#include "src/aarch64/matrix_mul/int8x8x16/strategy.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"
#include "src/fallback/matrix_mul/gemm_common.h"

using namespace megdnn;
using namespace aarch64;
using namespace aarch64::matmul;

// ===========================gemm_s8x8x16_4x4==================================
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_8x8);

void gemm_s8x8x16_8x8::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0,
int ymax, int k0, int kmax,
bool transpose) const {
if (transpose) {
matmul_8x8x8::gemm_s8x8x16_8x8_transpose_pack_A_n(out, in, ldin, y0,
ymax, k0, kmax);
} else {
matmul_8x8x8::gemm_s8x8x16_8x8_pack_A_n(out, in, ldin, y0, ymax, k0,
kmax);
}
}

void gemm_s8x8x16_8x8::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0,
int xmax, int k0, int kmax,
bool transpose) const {
if (transpose) {
matmul_8x8x8::gemm_s8x8x16_8x8_transpose_pack_B_n(out, in, ldin, x0,
xmax, k0, kmax);
} else {
matmul_8x8x8::gemm_s8x8x16_8x8_pack_B_n(out, in, ldin, x0, xmax, k0,
kmax);
}
}

void gemm_s8x8x16_8x8::kern(const dt_int8* packA, const dt_int8* packB,
size_t M, size_t N, size_t K, dt_int16* C,
size_t LDC, bool is_first_k, const dt_int16*,
dt_int16*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
(A_dtype.enumv() == DTypeEnum::Int8 &&
C_dtype.enumv() == DTypeEnum::Int16),
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
C_dtype.name());
MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype);
MEGDNN_MARK_USED_VAR(C_dtype);

constexpr size_t A_INTERLEAVE = 8;
constexpr size_t B_INTERLEAVE = 8;
//! K is packed to times of 4
K = round_up<size_t>(K, 8);
const int K8 = K * 8;
const int K4 = K * 4;

size_t m = 0;
for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) {
int16_t* output = C + (m * LDC);

size_t n = 0;
const dt_int8* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_8x8x8::kern_8x8(packA, cur_packB, K, output, LDC,
is_first_k);
output += B_INTERLEAVE;
cur_packB += K8;
}

for (; n < N; n += 4) {
matmul_8x8x8::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(N - n, 4));
output += 4;
cur_packB += K4;
}
packA += K8;
}

for (; m < M; m += 4) {
int16_t* output = C + (m * LDC);
const dt_int8* cur_packB = packB;
size_t n = 0;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_8x8x8::kern_4x8(packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4));
output += B_INTERLEAVE;
cur_packB += K8;
}

for (; n < N; n += 4) {
matmul_8x8x8::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4),
std::min<size_t>(N - n, 4));
output += 4;
cur_packB += K4;
}
packA += K4;
}
}

// ===========================gemm_s8x8x16_4x4==================================
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_4x4);

void gemm_s8x8x16_4x4::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0,
int ymax, int k0, int kmax,
bool transpose) const {
if (transpose) {
matmul_4x4x16::gemm_s8x8x16_4x4_pack_B_n(out, in, ldin, y0, ymax, k0,
kmax);
} else {
matmul_4x4x16::gemm_s8x8x16_4x4_pack_A_n(out, in, ldin, y0, ymax, k0,
kmax);
}
}

void gemm_s8x8x16_4x4::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0,
int xmax, int k0, int kmax,
bool transpose) const {
if (transpose) {
matmul_4x4x16::gemm_s8x8x16_4x4_pack_A_n(out, in, ldin, x0, xmax, k0,
kmax);
} else {
matmul_4x4x16::gemm_s8x8x16_4x4_pack_B_n(out, in, ldin, x0, xmax, k0,
kmax);
}
}

void gemm_s8x8x16_4x4::kern(const dt_int8* packA, const dt_int8* packB,
size_t M, size_t N, size_t K, dt_int16* C,
size_t LDC, bool is_first_k, const dt_int16*,
dt_int16*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
(A_dtype.enumv() == DTypeEnum::Int8 &&
C_dtype.enumv() == DTypeEnum::Int16),
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
C_dtype.name());
MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype);
MEGDNN_MARK_USED_VAR(C_dtype);

constexpr size_t A_INTERLEAVE = 4;
constexpr size_t B_INTERLEAVE = 4;
//! K is packed to times of 4
K = round_up<size_t>(K, 16);
const int K4 = K * 4;

size_t m = 0;
for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) {
int16_t* output = C + (m * LDC);

size_t n = 0;
const dt_int8* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_4x4x16::kern_4x4(packA, cur_packB, K, output, LDC,
is_first_k, A_INTERLEAVE, B_INTERLEAVE);
output += B_INTERLEAVE;
cur_packB += K4;
}

for (; n < N; n += B_INTERLEAVE) {
matmul_4x4x16::kern_4x4(packA, cur_packB, K, output, LDC,
is_first_k, A_INTERLEAVE,
std::min<size_t>(N - n, B_INTERLEAVE));
output += B_INTERLEAVE;
cur_packB += K4;
}

packA += K4;
}

for (; m < M; m += A_INTERLEAVE) {
int16_t* output = C + (m * LDC);
size_t n = 0;
const dt_int8* cur_packB = packB;
for (; n < N; n += B_INTERLEAVE) {
matmul_4x4x16::kern_4x4(packA, cur_packB, K, output, LDC,
is_first_k,
std::min<size_t>(M - m, A_INTERLEAVE),
std::min<size_t>(N - n, B_INTERLEAVE));
output += B_INTERLEAVE;
cur_packB += K4;
}
packA += K4;
}
}
// vim: syntax=cpp.doxygen

+ 27
- 0
dnn/src/aarch64/matrix_mul/int8x8x16/strategy.h View File

@@ -0,0 +1,27 @@
/**
* \file dnn/src/aarch64/matrix_mul/int8x8x16/strategy.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once

#include "src/fallback/matrix_mul/gemm_common.h"

namespace megdnn {
namespace aarch64 {
namespace matmul {

MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int16, dt_int16, 8, 8, 8, false, true,
gemm_s8x8x16_8x8);
MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int16, dt_int16, 4, 4, 16, false, true,
gemm_s8x8x16_4x4);

} // namespace matmul
} // namespace aarch64
} // namespace megdnn
// vim: syntax=cpp.doxygen

+ 93
- 0
dnn/src/aarch64/matrix_mul/opr_impl.cpp View File

@@ -0,0 +1,93 @@
/**
* \file dnn/src/aarch64/matrix_mul/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/aarch64/matrix_mul/opr_impl.h"
#include "src/aarch64/matrix_mul/algos.h"
#include "src/common/metahelper.h"
#include "src/common/utils.h"

using namespace megdnn;
using namespace aarch64;

class MatrixMulImpl::AlgoPack : NonCopyableObj {
AlgoF32K8x12x1 f32K8x12x1;
AlgoF32K4x16x1 f32k4x16x1;
AlgoF32MK4_4x16 f32mk4_4x16;
AlgoF32Gemv f32_gemv;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
AlgoF16K8x24x1 f16_k8x24x1;
AlgoF16MK8_8x8 f16_mk8_8x8;
#endif
#if __ARM_FEATURE_DOTPROD
AlgoInt8x8x32K8x12x4DotProd int8x8x32_k8x12x4_dotprod;
AlgoInt8x8x32GemvDotProd int8x8x32_gemv_dotprod;
#else
AlgoInt8x8x32MK4_4x4x16 int8x8x32_mk4_4x4x16;
AlgoInt8x8x32K4x4x16 int8x8x32_k4x4x16;
AlgoInt8x8x32K8x8x8 int8x8x32_k8x8x8;
AlgoInt8x8x32Gemv int8x8x32_gemv;
#endif
AlgoInt8x8x16K8x8x8 int8x8x16_k8x8x8;
AlgoInt8x8x16K4x4x16 int8x8x16_k4x4x16;

AlgoInt16x16x32K12x8x1 int16x16x32_k12x8x1;
AlgoInt16x16x32MK8_8x8 int16x16x32_mk8_8x8;

#if __ARM_FEATURE_DOTPROD
AlgoQuint8K8x8x4DotProd quint8_k8x8x4_dotprod;
AlgoQuint8GemvDotProd quint8_gemv_dotprod;
#else
AlgoQuint8K8x8x8 quint8_k8x8x8;
#endif

public:
SmallVector<MatrixMulImpl::AlgoBase*> all_algos;

AlgoPack() {
all_algos.emplace_back(&f32_gemv);
all_algos.emplace_back(&f32K8x12x1);
all_algos.emplace_back(&f32k4x16x1);
all_algos.emplace_back(&f32mk4_4x16);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
all_algos.emplace_back(&f16_k8x24x1);
all_algos.emplace_back(&f16_mk8_8x8);
#endif
#if __ARM_FEATURE_DOTPROD
all_algos.emplace_back(&int8x8x32_gemv_dotprod);
all_algos.emplace_back(&int8x8x32_k8x12x4_dotprod);
#else
all_algos.emplace_back(&int8x8x32_gemv);
all_algos.emplace_back(&int8x8x32_k8x8x8);
all_algos.emplace_back(&int8x8x32_k4x4x16);
all_algos.emplace_back(&int8x8x32_mk4_4x4x16);
#endif
all_algos.emplace_back(&int8x8x16_k4x4x16);
all_algos.emplace_back(&int8x8x16_k8x8x8);

all_algos.emplace_back(&int16x16x32_k12x8x1);
all_algos.emplace_back(&int16x16x32_mk8_8x8);
#if __ARM_FEATURE_DOTPROD
all_algos.emplace_back(&quint8_gemv_dotprod);
all_algos.emplace_back(&quint8_k8x8x4_dotprod);
#else
all_algos.emplace_back(&quint8_k8x8x8);
#endif
}
};

SmallVector<MatrixMulImpl::AlgoBase*> MatrixMulImpl::algo_pack() {
static AlgoPack s_algo_pack;
auto&& algos = arm_common::MatrixMulImpl::algo_pack();
algos.insert(algos.begin(), s_algo_pack.all_algos.begin(),
s_algo_pack.all_algos.end());
return std::move(algos);
}

// vim: syntax=cpp.doxygen

+ 63
- 0
dnn/src/aarch64/matrix_mul/opr_impl.h View File

@@ -0,0 +1,63 @@
/**
* \file dnn/src/aarch64/matrix_mul/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "src/arm_common/matrix_mul/opr_impl.h"

namespace megdnn {
namespace aarch64 {

class MatrixMulImpl : public arm_common::MatrixMulImpl {
public:
using arm_common::MatrixMulImpl::MatrixMulImpl;

SmallVector<AlgoBase*> algo_pack() override;

private:
class AlgoF32K8x12x1; // Aarch64 F32 Kernel 8X12X1
class AlgoF32K4x16x1; // Aarch64 F32 Kernel 4x16x1
class AlgoF32MK4_4x16; // Aarch64 F32 Format MK4 block 16x4
class AlgoF32Gemv; // Aarch64 F32 Gemv
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
class AlgoF16K8x24x1; // Aarch64 F16 Kernel 8x24x1
class AlgoF16MK8_8x8; // Aarch64 F16 Format MK8 block 16x8
#endif

#if __ARM_FEATURE_DOTPROD
class AlgoInt8x8x32K8x12x4DotProd; // Aarch64 Int8x8x32 Kernel
// 8x12x4 DotProduct
class AlgoInt8x8x32GemvDotProd; // Aarch64 Int8x8x32 Gemv DotProduct
#else
class AlgoInt8x8x32MK4_4x4x16; // Aarch64 nchw44 Int8x8x32 Kernel 4x4x16
class AlgoInt8x8x32K4x4x16; // Aarch64 Int8x8x32 Kernel 4x4x16
class AlgoInt8x8x32K8x8x8; // Aarch64 Int8x8x32 Kernel 8x8x8
class AlgoInt8x8x32Gemv; // Aarch64 Int8x8x32 Gemv
#endif
class AlgoInt8x8x16K8x8x8; // Aarch64 Int8x8x16 Kernel 8x8x8
class AlgoInt8x8x16K4x4x16; // Aarch64 Int8x8x16 Kernel 4x4x16

class AlgoInt16x16x32K12x8x1; // Aarch64 Int16x16x32 Kernel 12x8x1
class AlgoInt16x16x32MK8_8x8; // Aarch64 Int16x16x32 Format MK8 block 8x8

#if __ARM_FEATURE_DOTPROD
class AlgoQuint8K8x8x4DotProd; // Aarch64 Quint8 Kernel
// 8x8x4 DotProduct
class AlgoQuint8GemvDotProd; // Aarch64 Quint8 Gemv DotProduct
#else
class AlgoQuint8K8x8x8; // Aarch64 Quint8 Kernel 8x8x8
#endif

class AlgoPack;
};

} // namespace aarch64
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 1398
- 0
dnn/src/aarch64/matrix_mul/quint8/kernel_8x8x8.h
File diff suppressed because it is too large
View File


+ 113
- 0
dnn/src/aarch64/matrix_mul/quint8/strategy.cpp View File

@@ -0,0 +1,113 @@
/**
* \file dnn/src/aarch64/matrix_mul/quint8/strategy.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#if !(__ARM_FEATURE_DOTPROD)
#include "src/aarch64/matrix_mul/quint8/strategy.h"
#include "src/aarch64/matrix_mul/asm/common.h"
#include "src/aarch64/matrix_mul/quint8/kernel_8x8x8.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"

using namespace megdnn;
using namespace aarch64;
using namespace aarch64::matmul;

MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8);

void gemm_u8_8x8::pack_A(dt_uint8* outptr, const dt_uint8* inptr, int ldin,
int y0, int ymax, int k0, int kmax,
bool transpose) const {
uint8_t zA = A_dtype.param<dtype::Quantized8Asymm>().zero_point;
if (transpose) {
matmul_8x8x8::gemm_u8_8x8_transpose_pack_A_n(outptr, inptr, ldin, y0,
ymax, k0, kmax, zA);
} else {
matmul_8x8x8::gemm_u8_8x8_pack_A_n(outptr, inptr, ldin, y0, ymax, k0,
kmax, zA);
}
}

void gemm_u8_8x8::pack_B(dt_uint8* out, const dt_uint8* in, int ldin, int x0,
int xmax, int k0, int kmax, bool transpose) const {
uint8_t zB = B_dtype.param<dtype::Quantized8Asymm>().zero_point;
if (transpose) {
matmul_8x8x8::gemm_u8_8x8_transpose_pack_B_n(out, in, ldin, x0, xmax,
k0, kmax, zB);
} else {
matmul_8x8x8::gemm_u8_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax,
zB);
}
}

void gemm_u8_8x8::kern(const dt_uint8* packA, const dt_uint8* packB, size_t M,
size_t N, size_t K, dt_int32* C, size_t LDC,
bool is_first_k, const dt_int32*, dt_int32*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
A_dtype.enumv() == DTypeEnum::Quantized8Asymm &&
C_dtype.enumv() == DTypeEnum::QuantizedS32,
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
C_dtype.name());
uint8_t zA = A_dtype.param<dtype::Quantized8Asymm>().zero_point;
uint8_t zB = B_dtype.param<dtype::Quantized8Asymm>().zero_point;

constexpr size_t A_INTERLEAVE = 8;
constexpr size_t B_INTERLEAVE = 8;
//! K is packed to times of 8
K = round_up<size_t>(K, 8);
const int K8 = K * 8;
const int K4 = K * 4;

size_t m = 0;
for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) {
int32_t* output = C + (m * LDC);

size_t n = 0;
const dt_uint8* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_8x8x8::kern_8x8(packA, cur_packB, K, output, LDC, is_first_k,
zA, zB);
output += B_INTERLEAVE;
cur_packB += K8;
}

for (; n < N; n += 4) {
matmul_8x8x8::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(N - n, 4), zA, zB);
output += 4;
cur_packB += K4;
}
packA += K8;
}

for (; m < M; m += 4) {
int32_t* output = C + (m * LDC);
const dt_uint8* cur_packB = packB;
size_t n = 0;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_8x8x8::kern_4x8(packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4), zA, zB);
output += B_INTERLEAVE;
cur_packB += K8;
}

for (; n < N; n += 4) {
matmul_8x8x8::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4),
std::min<size_t>(N - n, 4), zA, zB);
output += 4;
cur_packB += K4;
}
packA += K4;
}
}
#endif

// vim: syntax=cpp.doxygen

+ 28
- 0
dnn/src/aarch64/matrix_mul/quint8/strategy.h View File

@@ -0,0 +1,28 @@
/**
* \file dnn/src/aarch64/matrix_mul/quint8/strategy.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once

#if !(__ARM_FEATURE_DOTPROD)
#include "src/fallback/matrix_mul/gemm_common.h"

namespace megdnn {
namespace aarch64 {
namespace matmul {

MEGDNN_REG_GEMM_STRATEGY(dt_uint8, dt_int32, dt_int32, 8, 8, 8, false, true,
gemm_u8_8x8);

} // namespace matmul
} // namespace aarch64
} // namespace megdnn
#endif

// vim: syntax=cpp.doxygen

+ 177
- 0
dnn/src/aarch64/matrix_mul/quint8_dot/gemv.cpp View File

@@ -0,0 +1,177 @@
/**
* \file dnn/src/aarch64/matrix_mul/quint8_dot/gemv.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/aarch64/matrix_mul/quint8_dot/gemv.h"
#include <cstddef>
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"
#include "src/common/unroll_macro.h"

#if __ARM_FEATURE_DOTPROD

namespace {

void gemv_naive_n(const uint8_t* __restrict A, const uint8_t* __restrict B,
int32_t* __restrict C, size_t M, size_t N, size_t K,
size_t Astride, size_t Bstride, size_t Cstride,
uint8_t zero_point_A, uint8_t zero_point_B) {
int32_t zAB = static_cast<int32_t>(zero_point_A) *
static_cast<int32_t>(zero_point_B) * K;
uint8x16_t zAq = vdupq_n_u8(zero_point_A);
uint8x16_t zBq = vdupq_n_u8(zero_point_B);
uint8x8_t zA = vdup_n_u8(zero_point_A);
uint8x8_t zB = vdup_n_u8(zero_point_B);
megdnn_assert(N == 1 && Bstride == 1);
size_t m = 0;
for (; m + 2 <= M; m += 2) {
int32_t acc_zA, acc_zB, acc_zB2;
int32_t acc[4];
size_t k = 0;
uint32x4_t acc_neon = vdupq_n_u32(0);
{
uint32x4_t acc_zA_neon = vdupq_n_u32(0);
uint32x4_t acc_zB_neon = vdupq_n_u32(0);
uint32x4_t acc_zB2_neon = vdupq_n_u32(0);
for (; k + 16 <= K; k += 16) {
uint8x16_t elem = vld1q_u8(A + m * Astride + k);
acc_zB_neon = vdotq_u32(acc_zB_neon, zBq, elem);
uint64x2_t a0 = vreinterpretq_u64_u8(elem);
elem = vld1q_u8(A + (m + 1) * Astride + k);
acc_zB2_neon = vdotq_u32(acc_zB2_neon, zBq, elem);
uint64x2_t a1 = vreinterpretq_u64_u8(elem);
//! the first 8 elements is m, the last 8 elements is m + 1
uint8x16_t a2 = vreinterpretq_u8_u64(vzip1q_u64(a0, a1));
uint8x16_t a3 = vreinterpretq_u8_u64(vzip2q_u64(a0, a1));

elem = vld1q_u8(B + k);
acc_zA_neon = vdotq_u32(acc_zA_neon, zAq, elem);
uint64x2_t b0 = vreinterpretq_u64_u8(elem);
uint8x16_t b2 = vreinterpretq_u8_u64(vzip1q_u64(b0, b0));
uint8x16_t b3 = vreinterpretq_u8_u64(vzip2q_u64(b0, b0));

acc_neon = vdotq_u32(acc_neon, a2, b2);
acc_neon = vdotq_u32(acc_neon, a3, b3);
}
vst1q_s32(acc, vreinterpretq_s32_u32(acc_neon));
acc_zA = vaddvq_u32(acc_zA_neon);
acc_zB = vaddvq_u32(acc_zB_neon);
acc_zB2 = vaddvq_u32(acc_zB2_neon);
}

{
uint32x2_t acc_zA_neon = vdup_n_u32(0);
uint32x2_t acc_zB_neon = vdup_n_u32(0);
uint32x2_t acc_zB2_neon = vdup_n_u32(0);
for (; k + 8 <= K; k += 8) {
uint8x8_t a0 = vld1_u8(A + m * Astride + k);
uint8x8_t a1 = vld1_u8(A + (m + 1) * Astride + k);
uint8x8_t b0 = vld1_u8(B + k);
uint32x2_t zero = vdup_n_u32(0);
acc[0] += vaddv_u32(vdot_u32(zero, a0, b0));
zero = vdup_n_u32(0);
acc[3] += vaddv_u32(vdot_u32(zero, a1, b0));

acc_zB_neon = vdot_u32(acc_zB_neon, a0, zB);
acc_zB2_neon = vdot_u32(acc_zB2_neon, a1, zB);
acc_zA_neon = vdot_u32(acc_zA_neon, b0, zA);
}

acc_zA += vaddv_u32(acc_zA_neon);
acc_zB += vaddv_u32(acc_zB_neon);
acc_zB2 += vaddv_u32(acc_zB2_neon);
}

for (; k < K; ++k) {
acc[0] += static_cast<int32_t>(A[m * Astride + k]) * B[k];
acc[3] += static_cast<int32_t>(A[(m + 1) * Astride + k]) * B[k];
acc_zA += static_cast<int32_t>(B[k]) * zero_point_A;
acc_zB += static_cast<int32_t>(A[m * Astride + k]) * zero_point_B;
acc_zB2 += static_cast<int32_t>(A[(m + 1) * Astride + k]) *
zero_point_B;
}
C[m * Cstride] = acc[0] + acc[1] + zAB - acc_zA - acc_zB;
C[(m + 1) * Cstride] = acc[2] + acc[3] + zAB - acc_zA - acc_zB2;
}

for (; m < M; ++m) {
int32_t acc[4];
int32_t acc_zA, acc_zB;
uint32x4_t acc_neon = vdupq_n_u32(0);
size_t k = 0;
{
uint32x4_t acc_zA_neon = vdupq_n_u32(0);
uint32x4_t acc_zB_neon = vdupq_n_u32(0);
for (; k + 16 <= K; k += 16) {
uint8x16_t a0 = vld1q_u8(A + m * Astride + k);
uint8x16_t b0 = vld1q_u8(B + k);
acc_neon = vdotq_u32(acc_neon, a0, b0);
acc_zB_neon = vdotq_u32(acc_zB_neon, zBq, a0);
acc_zA_neon = vdotq_u32(acc_zA_neon, zAq, b0);
}
vst1q_s32(acc, vreinterpretq_s32_u32(acc_neon));
acc_zA = vaddvq_u32(acc_zA_neon);
acc_zB = vaddvq_u32(acc_zB_neon);
}

{
uint32x2_t acc_zA_neon = vdup_n_u32(0);
uint32x2_t acc_zB_neon = vdup_n_u32(0);
for (; k + 8 <= K; k += 8) {
uint8x8_t a0 = vld1_u8(A + m * Astride + k);
uint8x8_t b0 = vld1_u8(B + k);
uint32x2_t zero = vdup_n_u32(0);
acc[0] += vaddv_u32(vdot_u32(zero, a0, b0));

acc_zB_neon = vdot_u32(acc_zB_neon, a0, zB);
acc_zA_neon = vdot_u32(acc_zA_neon, b0, zA);
}
acc_zA += vaddv_u32(acc_zA_neon);
acc_zB += vaddv_u32(acc_zB_neon);
}

for (; k < K; ++k) {
acc[0] += static_cast<int32_t>(A[m * Astride + k]) * B[k];
acc_zA += static_cast<int32_t>(B[k]) * zero_point_A;
acc_zB += static_cast<int32_t>(A[m * Astride + k]) * zero_point_B;
}
C[m * Cstride] =
acc[0] + acc[1] + acc[2] + acc[3] + zAB - acc_zA - acc_zB;
}
}

} // namespace

bool megdnn::aarch64::matmul::is_gemv_like_preferred_quint8(
bool transposeA, bool transposeB, size_t M, size_t N, size_t K,
size_t /* LDA */, size_t LDB, size_t /* LDC */) {
if (transposeA)
return false;
if (transposeB)
return false;
MEGDNN_MARK_USED_VAR(K);
MEGDNN_MARK_USED_VAR(M);
//! rebenchmark gemv in sdm855
return (N == 1 && LDB == 1);
}

void megdnn::aarch64::matmul::gemv_like_quint8(
const uint8_t* __restrict A, const uint8_t* __restrict B,
int32_t* __restrict C, size_t M, size_t N, size_t K, size_t Astride,
size_t Bstride, size_t Cstride, uint8_t zero_point_A,
uint8_t zero_point_B) {
megdnn_assert(N == 1);
return gemv_naive_n(A, B, C, M, N, K, Astride, Bstride, Cstride,
zero_point_A, zero_point_B);
}

#endif

// vim: syntax=cpp.doxygen

+ 35
- 0
dnn/src/aarch64/matrix_mul/quint8_dot/gemv.h View File

@@ -0,0 +1,35 @@
/**
* \file dnn/src/aarch64/matrix_mul/quint8_dot/gemv.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once

#include <cstddef>
#include <cstdint>

#if __ARM_FEATURE_DOTPROD
namespace megdnn {
namespace aarch64 {
namespace matmul {

bool is_gemv_like_preferred_quint8(bool transposeA, bool transposeB, size_t M,
size_t N, size_t K, size_t LDA, size_t LDB,
size_t LDC);

void gemv_like_quint8(const uint8_t* __restrict A, const uint8_t* __restrict B,
int32_t* __restrict C, size_t M, size_t N, size_t K,
size_t Astride, size_t Bstride, size_t Cstride,
uint8_t zero_point_A, uint8_t zero_point_B);

} // namespace matmul
} // namespace aarch64
} // namespace megdnn
#endif

// vim: syntax=cpp.doxygen

+ 1092
- 0
dnn/src/aarch64/matrix_mul/quint8_dot/kernel_8x8x4.h
File diff suppressed because it is too large
View File


+ 114
- 0
dnn/src/aarch64/matrix_mul/quint8_dot/strategy.cpp View File

@@ -0,0 +1,114 @@
/**
* \file dnn/src/aarch64/matrix_mul/quint8_dot/strategy.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/aarch64/matrix_mul/quint8_dot/strategy.h"
#include "megdnn/dtype.h"
#include "src/aarch64/matrix_mul/asm/common.h"
#include "src/aarch64/matrix_mul/quint8_dot/kernel_8x8x4.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"

#if __ARM_FEATURE_DOTPROD
using namespace megdnn;
using namespace aarch64;
using namespace aarch64::matmul;

MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8);

void gemm_u8_8x8::pack_A(uint8_t* outptr, const uint8_t* inptr, int ldin,
int y0, int ymax, int k0, int kmax,
bool transpose) const {
if (transpose) {
matmul_8x8x4::gemm_u8_8x8_transpose_pack_helper(outptr, inptr, ldin, y0,
ymax, k0, kmax);
} else {
matmul_8x8x4::gemm_u8_8x8_interleave_pack_helper(outptr, inptr, ldin,
y0, ymax, k0, kmax);
}
}

void gemm_u8_8x8::pack_B(uint8_t* out, const uint8_t* in, int ldin, int x0,
int xmax, int k0, int kmax, bool transpose) const {
if (transpose) {
matmul_8x8x4::gemm_u8_8x8_interleave_pack_helper(out, in, ldin, x0,
xmax, k0, kmax);
} else {
matmul_8x8x4::gemm_u8_8x8_transpose_pack_helper(out, in, ldin, x0, xmax,
k0, kmax);
}
}

void gemm_u8_8x8::kern(const uint8_t* packA, const uint8_t* packB, size_t M,
size_t N, size_t K, dt_int32* C, size_t LDC,
bool is_first_k, const dt_int32*, dt_int32*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
A_dtype.enumv() == DTypeEnum::Quantized8Asymm &&
C_dtype.enumv() == DTypeEnum::QuantizedS32);
MEGDNN_MARK_USED_VAR(C_dtype);
size_t zero_point_A = A_dtype.param<dtype::Quantized8Asymm>().zero_point;
size_t zero_point_B = B_dtype.param<dtype::Quantized8Asymm>().zero_point;
constexpr size_t A_INTERLEAVE = 8;
constexpr size_t B_INTERLEAVE = 8;
const uint32_t zAB = static_cast<uint32_t>(zero_point_A) *
static_cast<uint32_t>(zero_point_B) * K;
//! K is packed to times of 4
K = round_up<size_t>(K, 4);
const int K8 = (K << 3);
const int K4 = K * 4;

size_t m = 0;
for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) {
int32_t* output = C + (m * LDC);

size_t n = 0;
const dt_uint8* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_8x8x4::kern_8x8(packA, cur_packB, K, output, LDC, is_first_k,
zero_point_A, zero_point_B, zAB);
output += B_INTERLEAVE;
cur_packB += K8;
}

for (; n < N; n += 4) {
matmul_8x8x4::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(N - n, 4), zero_point_A,
zero_point_B, zAB);
output += 4;
cur_packB += K4;
}
packA += K8;
}

for (; m < M; m += 4) {
int32_t* output = C + (m * LDC);
const dt_uint8* cur_packB = packB;
size_t n = 0;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_8x8x4::kern_4x8(packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4), zero_point_A,
zero_point_B, zAB);
output += B_INTERLEAVE;
cur_packB += K8;
}

for (; n < N; n += 4) {
matmul_8x8x4::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4),
std::min<size_t>(N - n, 4), zero_point_A,
zero_point_B, zAB);
output += 4;
cur_packB += K4;
}
packA += K4;
}
}
#endif
// vim: syntax=cpp.doxygen

+ 26
- 0
dnn/src/aarch64/matrix_mul/quint8_dot/strategy.h View File

@@ -0,0 +1,26 @@
/**
* \file dnn/src/aarch64/matrix_mul/quint8_dot/strategy.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "src/fallback/matrix_mul/gemm_common.h"

#if __ARM_FEATURE_DOTPROD
namespace megdnn {
namespace aarch64 {
namespace matmul {

MEGDNN_REG_GEMM_STRATEGY(uint8_t, int32_t, int32_t, 8, 8, 4, false, true,
gemm_u8_8x8);

} // namespace aarch64
} // namespace matmul
} // namespace megdnn
#endif
// vim: syntax=cpp.doxygen

+ 183
- 0
dnn/src/aarch64/relayout/opr_impl.cpp View File

@@ -0,0 +1,183 @@
/**
* \file dnn/src/aarch64/relayout/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/common/utils.h"
#include "src/common/relayout_helper.h"

#include "src/aarch64/handle.h"
#include "src/aarch64/relayout/opr_impl.h"

using namespace megdnn;
using namespace relayout;

namespace {

struct TransposeByte {
uint8_t v;
};

void trans_16x16_u8(const void* src, void* dst, const size_t src_step,
const size_t dst_step) {
asm volatile(
"\n"
"ld1 {v0.16b}, [%[src]], %[src_step] \n"
"ld1 {v1.16b}, [%[src]], %[src_step] \n"
"ld1 {v2.16b}, [%[src]], %[src_step] \n"
"ld1 {v3.16b}, [%[src]], %[src_step] \n"
"ld1 {v4.16b}, [%[src]], %[src_step] \n"
"ld1 {v5.16b}, [%[src]], %[src_step] \n"
"ld1 {v6.16b}, [%[src]], %[src_step] \n"
"ld1 {v7.16b}, [%[src]], %[src_step] \n"
"ld1 {v8.16b}, [%[src]], %[src_step] \n"
"ld1 {v9.16b}, [%[src]], %[src_step] \n"
"ld1 {v10.16b}, [%[src]], %[src_step] \n"
"ld1 {v11.16b}, [%[src]], %[src_step] \n"
"ld1 {v12.16b}, [%[src]], %[src_step] \n"
"ld1 {v13.16b}, [%[src]], %[src_step] \n"
"ld1 {v14.16b}, [%[src]], %[src_step] \n"
"ld1 {v15.16b}, [%[src]], %[src_step] \n"
"trn1 v16.16b, v0.16b, v1.16b \n"
"trn2 v17.16b, v0.16b, v1.16b \n"
"trn1 v18.16b, v2.16b, v3.16b \n"
"trn2 v19.16b, v2.16b, v3.16b \n"
"trn1 v20.16b, v4.16b, v5.16b \n"
"trn2 v21.16b, v4.16b, v5.16b \n"
"trn1 v22.16b, v6.16b, v7.16b \n"
"trn2 v23.16b, v6.16b, v7.16b \n"
"trn1 v24.16b, v8.16b, v9.16b \n"
"trn2 v25.16b, v8.16b, v9.16b \n"
"trn1 v26.16b, v10.16b, v11.16b \n"
"trn2 v27.16b, v10.16b, v11.16b \n"
"trn1 v28.16b, v12.16b, v13.16b \n"
"trn2 v29.16b, v12.16b, v13.16b \n"
"trn1 v30.16b, v14.16b, v15.16b \n"
"trn2 v31.16b, v14.16b, v15.16b \n"
"trn1 v0.8h, v16.8h, v18.8h \n"
"trn2 v2.8h, v16.8h, v18.8h \n"
"trn1 v4.8h, v20.8h, v22.8h \n"
"trn2 v6.8h, v20.8h, v22.8h \n"
"trn1 v8.8h, v24.8h, v26.8h \n"
"trn2 v10.8h, v24.8h, v26.8h \n"
"trn1 v12.8h, v28.8h, v30.8h \n"
"trn2 v14.8h, v28.8h, v30.8h \n"
"trn1 v1.8h, v17.8h, v19.8h \n"
"trn2 v3.8h, v17.8h, v19.8h \n"
"trn1 v5.8h, v21.8h, v23.8h \n"
"trn2 v7.8h, v21.8h, v23.8h \n"
"trn1 v9.8h, v25.8h, v27.8h \n"
"trn2 v11.8h, v25.8h, v27.8h \n"
"trn1 v13.8h, v29.8h, v31.8h \n"
"trn2 v15.8h, v29.8h, v31.8h \n"
"trn1 v16.4s, v0.4s, v4.4s \n"
"trn2 v20.4s, v0.4s, v4.4s \n"
"trn1 v24.4s, v8.4s, v12.4s \n"
"trn2 v28.4s, v8.4s, v12.4s \n"
"trn1 v17.4s, v1.4s, v5.4s \n"
"trn2 v21.4s, v1.4s, v5.4s \n"
"trn1 v25.4s, v9.4s, v13.4s \n"
"trn2 v29.4s, v9.4s, v13.4s \n"
"trn1 v18.4s, v2.4s, v6.4s \n"
"trn2 v22.4s, v2.4s, v6.4s \n"
"trn1 v26.4s, v10.4s, v14.4s \n"
"trn2 v30.4s, v10.4s, v14.4s \n"
"trn1 v19.4s, v3.4s, v7.4s \n"
"trn2 v23.4s, v3.4s, v7.4s \n"
"trn1 v27.4s, v11.4s, v15.4s \n"
"trn2 v31.4s, v11.4s, v15.4s \n"
"trn1 v0.2d, v16.2d, v24.2d \n"
"trn2 v8.2d, v16.2d, v24.2d \n"
"trn1 v1.2d, v17.2d, v25.2d \n"
"trn2 v9.2d, v17.2d, v25.2d \n"
"trn1 v2.2d, v18.2d, v26.2d \n"
"trn2 v10.2d, v18.2d, v26.2d \n"
"trn1 v3.2d, v19.2d, v27.2d \n"
"trn2 v11.2d, v19.2d, v27.2d \n"
"trn1 v4.2d, v20.2d, v28.2d \n"
"trn2 v12.2d, v20.2d, v28.2d \n"
"trn1 v5.2d, v21.2d, v29.2d \n"
"trn2 v13.2d, v21.2d, v29.2d \n"
"trn1 v6.2d, v22.2d, v30.2d \n"
"trn2 v14.2d, v22.2d, v30.2d \n"
"trn1 v7.2d, v23.2d, v31.2d \n"
"trn2 v15.2d, v23.2d, v31.2d \n"
"st1 {v0.16b}, [%[dst]], %[dst_step] \n"
"st1 {v1.16b}, [%[dst]], %[dst_step] \n"
"st1 {v2.16b}, [%[dst]], %[dst_step] \n"
"st1 {v3.16b}, [%[dst]], %[dst_step] \n"
"st1 {v4.16b}, [%[dst]], %[dst_step] \n"
"st1 {v5.16b}, [%[dst]], %[dst_step] \n"
"st1 {v6.16b}, [%[dst]], %[dst_step] \n"
"st1 {v7.16b}, [%[dst]], %[dst_step] \n"
"st1 {v8.16b}, [%[dst]], %[dst_step] \n"
"st1 {v9.16b}, [%[dst]], %[dst_step] \n"
"st1 {v10.16b}, [%[dst]], %[dst_step] \n"
"st1 {v11.16b}, [%[dst]], %[dst_step] \n"
"st1 {v12.16b}, [%[dst]], %[dst_step] \n"
"st1 {v13.16b}, [%[dst]], %[dst_step] \n"
"st1 {v14.16b}, [%[dst]], %[dst_step] \n"
"st1 {v15.16b}, [%[dst]], %[dst_step] \n"
:
[src] "+r" (src),
[dst] "+r" (dst)
:
[src_step] "r" (src_step),
[dst_step] "r" (dst_step)
:
"d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "d20",
"d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28", "d29", "d30",
"d31");

}

} // anonymous namespace

namespace megdnn {
namespace relayout {
namespace transpose_fallback {
template <>
struct transpose_traits<TransposeByte> {
static constexpr size_t block_size = 16;
};

template <>
void transpose_block<TransposeByte>(const TransposeByte* src,
TransposeByte* dst, const size_t src_stride,
const size_t dst_stride) {
trans_16x16_u8(src, dst, src_stride, dst_stride);
}

} // namespace transpose_fallback
} // namespace relayout
} // namespace megdnn

void aarch64::RelayoutForwardImpl::exec(_megdnn_tensor_in src0,
_megdnn_tensor_out dst0,
Handle* src_handle) {
check_cpu_handle(src_handle);
TensorND src = src0, dst = dst0;
check_layout_and_canonize(src.layout, dst.layout);

relayout::TransposeParam trans_param;
bool trans = relayout::is_transpose(src.layout, dst.layout, trans_param);
if (trans && trans_param.c == 1 && src0.layout.dtype.size() == 1) {
auto sptr = static_cast<TransposeByte*>(src.raw_ptr),
dptr = static_cast<TransposeByte*>(dst.raw_ptr);
MEGDNN_DISPATCH_CPU_KERN_OPR(
transpose_fallback::transpose<TransposeByte>(
trans_param.batch, trans_param.m, trans_param.n, sptr,
dptr));
return;
}
exec_after_preprocess(src, dst, trans ? &trans_param : nullptr);
}

// vim: syntax=cpp.doxygen

+ 31
- 0
dnn/src/aarch64/relayout/opr_impl.h View File

@@ -0,0 +1,31 @@
/**
* \file dnn/src/aarch64/relayout/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megdnn/oprs.h"
#include "src/fallback/relayout/opr_impl.h"

namespace megdnn {
namespace aarch64 {

class RelayoutForwardImpl final : public fallback::RelayoutForwardImpl {
public:
using fallback::RelayoutForwardImpl::RelayoutForwardImpl;

void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
Handle *src_handle) override;

bool is_thread_safe() const override { return true; }
};

} // namespace aarch64
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 393
- 0
dnn/src/aarch64/rotate/opr_impl.cpp View File

@@ -0,0 +1,393 @@
/**
* \file dnn/src/aarch64/rotate/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include <cstring>

#include "src/aarch64/rotate/opr_impl.h"
#include "src/aarch64/handle.h"
#include "src/common/cv/common.h"
#include "src/common/cv/helper.h"
#include "src/common/utils.h"


namespace megdnn {
namespace megcv {

void rotate_8uc1_clockwise_16x16(const uchar *src,
uchar *dst,
size_t src_step, size_t dst_step)
{
asm volatile ("\n"
"ld1 {v0.16b}, [%[src]], %[src_step] \n"
"ld1 {v1.16b}, [%[src]], %[src_step] \n"
"ld1 {v2.16b}, [%[src]], %[src_step] \n"
"ld1 {v3.16b}, [%[src]], %[src_step] \n"
"ld1 {v4.16b}, [%[src]], %[src_step] \n"
"ld1 {v5.16b}, [%[src]], %[src_step] \n"
"ld1 {v6.16b}, [%[src]], %[src_step] \n"
"ld1 {v7.16b}, [%[src]], %[src_step] \n"
"ld1 {v8.16b}, [%[src]], %[src_step] \n"
"ld1 {v9.16b}, [%[src]], %[src_step] \n"
"ld1 {v10.16b}, [%[src]], %[src_step] \n"
"ld1 {v11.16b}, [%[src]], %[src_step] \n"
"ld1 {v12.16b}, [%[src]], %[src_step] \n"
"ld1 {v13.16b}, [%[src]], %[src_step] \n"
"ld1 {v14.16b}, [%[src]], %[src_step] \n"
"ld1 {v15.16b}, [%[src]], %[src_step] \n"

"trn1 v16.16b, v0.16b, v1.16b \n"
"trn2 v17.16b, v0.16b, v1.16b \n"
"trn1 v18.16b, v2.16b, v3.16b \n"
"trn2 v19.16b, v2.16b, v3.16b \n"
"trn1 v20.16b, v4.16b, v5.16b \n"
"trn2 v21.16b, v4.16b, v5.16b \n"
"trn1 v22.16b, v6.16b, v7.16b \n"
"trn2 v23.16b, v6.16b, v7.16b \n"
"trn1 v24.16b, v8.16b, v9.16b \n"
"trn2 v25.16b, v8.16b, v9.16b \n"
"trn1 v26.16b, v10.16b, v11.16b \n"
"trn2 v27.16b, v10.16b, v11.16b \n"
"trn1 v28.16b, v12.16b, v13.16b \n"
"trn2 v29.16b, v12.16b, v13.16b \n"
"trn1 v30.16b, v14.16b, v15.16b \n"
"trn2 v31.16b, v14.16b, v15.16b \n"

"trn1 v0.8h, v16.8h, v18.8h \n"
"trn2 v2.8h, v16.8h, v18.8h \n"
"trn1 v4.8h, v20.8h, v22.8h \n"
"trn2 v6.8h, v20.8h, v22.8h \n"
"trn1 v8.8h, v24.8h, v26.8h \n"
"trn2 v10.8h, v24.8h, v26.8h \n"
"trn1 v12.8h, v28.8h, v30.8h \n"
"trn2 v14.8h, v28.8h, v30.8h \n"
"trn1 v1.8h, v17.8h, v19.8h \n"
"trn2 v3.8h, v17.8h, v19.8h \n"
"trn1 v5.8h, v21.8h, v23.8h \n"
"trn2 v7.8h, v21.8h, v23.8h \n"
"trn1 v9.8h, v25.8h, v27.8h \n"
"trn2 v11.8h, v25.8h, v27.8h \n"
"trn1 v13.8h, v29.8h, v31.8h \n"
"trn2 v15.8h, v29.8h, v31.8h \n"

"trn1 v16.4s, v0.4s, v4.4s \n"
"trn2 v20.4s, v0.4s, v4.4s \n"
"trn1 v24.4s, v8.4s, v12.4s \n"
"trn2 v28.4s, v8.4s, v12.4s \n"
"trn1 v17.4s, v1.4s, v5.4s \n"
"trn2 v21.4s, v1.4s, v5.4s \n"
"trn1 v25.4s, v9.4s, v13.4s \n"
"trn2 v29.4s, v9.4s, v13.4s \n"
"trn1 v18.4s, v2.4s, v6.4s \n"
"trn2 v22.4s, v2.4s, v6.4s \n"
"trn1 v26.4s, v10.4s, v14.4s \n"
"trn2 v30.4s, v10.4s, v14.4s \n"
"trn1 v19.4s, v3.4s, v7.4s \n"
"trn2 v23.4s, v3.4s, v7.4s \n"
"trn1 v27.4s, v11.4s, v15.4s \n"
"trn2 v31.4s, v11.4s, v15.4s \n"

"trn1 v0.2d, v16.2d, v24.2d \n"
"trn2 v8.2d, v16.2d, v24.2d \n"
"trn1 v1.2d, v17.2d, v25.2d \n"
"trn2 v9.2d, v17.2d, v25.2d \n"
"trn1 v2.2d, v18.2d, v26.2d \n"
"trn2 v10.2d, v18.2d, v26.2d \n"
"trn1 v3.2d, v19.2d, v27.2d \n"
"trn2 v11.2d, v19.2d, v27.2d \n"
"trn1 v4.2d, v20.2d, v28.2d \n"
"trn2 v12.2d, v20.2d, v28.2d \n"
"trn1 v5.2d, v21.2d, v29.2d \n"
"trn2 v13.2d, v21.2d, v29.2d \n"
"trn1 v6.2d, v22.2d, v30.2d \n"
"trn2 v14.2d, v22.2d, v30.2d \n"
"trn1 v7.2d, v23.2d, v31.2d \n"
"trn2 v15.2d, v23.2d, v31.2d \n"
// There is no rev128 instruction, so we use rev64 and ext to simulate it.
"rev64 v0.16b, v0.16b \n"
"rev64 v1.16b, v1.16b \n"
"rev64 v2.16b, v2.16b \n"
"rev64 v3.16b, v3.16b \n"
"rev64 v4.16b, v4.16b \n"
"rev64 v5.16b, v5.16b \n"
"rev64 v6.16b, v6.16b \n"
"rev64 v7.16b, v7.16b \n"
"rev64 v8.16b, v8.16b \n"
"rev64 v9.16b, v9.16b \n"
"rev64 v10.16b, v10.16b \n"
"rev64 v11.16b, v11.16b \n"
"rev64 v12.16b, v12.16b \n"
"rev64 v13.16b, v13.16b \n"
"rev64 v14.16b, v14.16b \n"
"rev64 v15.16b, v15.16b \n"
"ext v0.16b, v0.16b, v0.16b, #8 \n"
"ext v1.16b, v1.16b, v1.16b, #8 \n"
"ext v2.16b, v2.16b, v2.16b, #8 \n"
"ext v3.16b, v3.16b, v3.16b, #8 \n"
"ext v4.16b, v4.16b, v4.16b, #8 \n"
"ext v5.16b, v5.16b, v5.16b, #8 \n"
"ext v6.16b, v6.16b, v6.16b, #8 \n"
"ext v7.16b, v7.16b, v7.16b, #8 \n"
"ext v8.16b, v8.16b, v8.16b, #8 \n"
"ext v9.16b, v9.16b, v9.16b, #8 \n"
"ext v10.16b, v10.16b, v10.16b, #8 \n"
"ext v11.16b, v11.16b, v11.16b, #8 \n"
"ext v12.16b, v12.16b, v12.16b, #8 \n"
"ext v13.16b, v13.16b, v13.16b, #8 \n"
"ext v14.16b, v14.16b, v14.16b, #8 \n"
"ext v15.16b, v15.16b, v15.16b, #8 \n"

"st1 {v0.16b}, [%[dst]], %[dst_step] \n"
"st1 {v1.16b}, [%[dst]], %[dst_step] \n"
"st1 {v2.16b}, [%[dst]], %[dst_step] \n"
"st1 {v3.16b}, [%[dst]], %[dst_step] \n"
"st1 {v4.16b}, [%[dst]], %[dst_step] \n"
"st1 {v5.16b}, [%[dst]], %[dst_step] \n"
"st1 {v6.16b}, [%[dst]], %[dst_step] \n"
"st1 {v7.16b}, [%[dst]], %[dst_step] \n"
"st1 {v8.16b}, [%[dst]], %[dst_step] \n"
"st1 {v9.16b}, [%[dst]], %[dst_step] \n"
"st1 {v10.16b}, [%[dst]], %[dst_step] \n"
"st1 {v11.16b}, [%[dst]], %[dst_step] \n"
"st1 {v12.16b}, [%[dst]], %[dst_step] \n"
"st1 {v13.16b}, [%[dst]], %[dst_step] \n"
"st1 {v14.16b}, [%[dst]], %[dst_step] \n"
"st1 {v15.16b}, [%[dst]], %[dst_step] \n"
:
[src] "+r" (src),
[dst] "+r" (dst)
:
[src_step] "r" (src_step),
[dst_step] "r" (dst_step)
:
"v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
"v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
);
}

void rotate_8uc1_counterclockwise_16x16(const uchar *src,
uchar *dst,
size_t src_step, size_t dst_step)
{
asm volatile ("\n"
"ld1 {v0.16b}, [%[src]], %[src_step] \n"
"ld1 {v1.16b}, [%[src]], %[src_step] \n"
"ld1 {v2.16b}, [%[src]], %[src_step] \n"
"ld1 {v3.16b}, [%[src]], %[src_step] \n"
"ld1 {v4.16b}, [%[src]], %[src_step] \n"
"ld1 {v5.16b}, [%[src]], %[src_step] \n"
"ld1 {v6.16b}, [%[src]], %[src_step] \n"
"ld1 {v7.16b}, [%[src]], %[src_step] \n"
"ld1 {v8.16b}, [%[src]], %[src_step] \n"
"ld1 {v9.16b}, [%[src]], %[src_step] \n"
"ld1 {v10.16b}, [%[src]], %[src_step] \n"
"ld1 {v11.16b}, [%[src]], %[src_step] \n"
"ld1 {v12.16b}, [%[src]], %[src_step] \n"
"ld1 {v13.16b}, [%[src]], %[src_step] \n"
"ld1 {v14.16b}, [%[src]], %[src_step] \n"
"ld1 {v15.16b}, [%[src]], %[src_step] \n"

"trn1 v16.16b, v0.16b, v1.16b \n"
"trn2 v17.16b, v0.16b, v1.16b \n"
"trn1 v18.16b, v2.16b, v3.16b \n"
"trn2 v19.16b, v2.16b, v3.16b \n"
"trn1 v20.16b, v4.16b, v5.16b \n"
"trn2 v21.16b, v4.16b, v5.16b \n"
"trn1 v22.16b, v6.16b, v7.16b \n"
"trn2 v23.16b, v6.16b, v7.16b \n"
"trn1 v24.16b, v8.16b, v9.16b \n"
"trn2 v25.16b, v8.16b, v9.16b \n"
"trn1 v26.16b, v10.16b, v11.16b \n"
"trn2 v27.16b, v10.16b, v11.16b \n"
"trn1 v28.16b, v12.16b, v13.16b \n"
"trn2 v29.16b, v12.16b, v13.16b \n"
"trn1 v30.16b, v14.16b, v15.16b \n"
"trn2 v31.16b, v14.16b, v15.16b \n"

"trn1 v0.8h, v16.8h, v18.8h \n"
"trn2 v2.8h, v16.8h, v18.8h \n"
"trn1 v4.8h, v20.8h, v22.8h \n"
"trn2 v6.8h, v20.8h, v22.8h \n"
"trn1 v8.8h, v24.8h, v26.8h \n"
"trn2 v10.8h, v24.8h, v26.8h \n"
"trn1 v12.8h, v28.8h, v30.8h \n"
"trn2 v14.8h, v28.8h, v30.8h \n"
"trn1 v1.8h, v17.8h, v19.8h \n"
"trn2 v3.8h, v17.8h, v19.8h \n"
"trn1 v5.8h, v21.8h, v23.8h \n"
"trn2 v7.8h, v21.8h, v23.8h \n"
"trn1 v9.8h, v25.8h, v27.8h \n"
"trn2 v11.8h, v25.8h, v27.8h \n"
"trn1 v13.8h, v29.8h, v31.8h \n"
"trn2 v15.8h, v29.8h, v31.8h \n"

"trn1 v16.4s, v0.4s, v4.4s \n"
"trn2 v20.4s, v0.4s, v4.4s \n"
"trn1 v24.4s, v8.4s, v12.4s \n"
"trn2 v28.4s, v8.4s, v12.4s \n"
"trn1 v17.4s, v1.4s, v5.4s \n"
"trn2 v21.4s, v1.4s, v5.4s \n"
"trn1 v25.4s, v9.4s, v13.4s \n"
"trn2 v29.4s, v9.4s, v13.4s \n"
"trn1 v18.4s, v2.4s, v6.4s \n"
"trn2 v22.4s, v2.4s, v6.4s \n"
"trn1 v26.4s, v10.4s, v14.4s \n"
"trn2 v30.4s, v10.4s, v14.4s \n"
"trn1 v19.4s, v3.4s, v7.4s \n"
"trn2 v23.4s, v3.4s, v7.4s \n"
"trn1 v27.4s, v11.4s, v15.4s \n"
"trn2 v31.4s, v11.4s, v15.4s \n"

"trn1 v0.2d, v16.2d, v24.2d \n"
"trn2 v8.2d, v16.2d, v24.2d \n"
"trn1 v1.2d, v17.2d, v25.2d \n"
"trn2 v9.2d, v17.2d, v25.2d \n"
"trn1 v2.2d, v18.2d, v26.2d \n"
"trn2 v10.2d, v18.2d, v26.2d \n"
"trn1 v3.2d, v19.2d, v27.2d \n"
"trn2 v11.2d, v19.2d, v27.2d \n"
"trn1 v4.2d, v20.2d, v28.2d \n"
"trn2 v12.2d, v20.2d, v28.2d \n"
"trn1 v5.2d, v21.2d, v29.2d \n"
"trn2 v13.2d, v21.2d, v29.2d \n"
"trn1 v6.2d, v22.2d, v30.2d \n"
"trn2 v14.2d, v22.2d, v30.2d \n"
"trn1 v7.2d, v23.2d, v31.2d \n"
"trn2 v15.2d, v23.2d, v31.2d \n"

"st1 {v15.16b}, [%[dst]], %[dst_step] \n"
"st1 {v14.16b}, [%[dst]], %[dst_step] \n"
"st1 {v13.16b}, [%[dst]], %[dst_step] \n"
"st1 {v12.16b}, [%[dst]], %[dst_step] \n"
"st1 {v11.16b}, [%[dst]], %[dst_step] \n"
"st1 {v10.16b}, [%[dst]], %[dst_step] \n"
"st1 {v9.16b}, [%[dst]], %[dst_step] \n"
"st1 {v8.16b}, [%[dst]], %[dst_step] \n"
"st1 {v7.16b}, [%[dst]], %[dst_step] \n"
"st1 {v6.16b}, [%[dst]], %[dst_step] \n"
"st1 {v5.16b}, [%[dst]], %[dst_step] \n"
"st1 {v4.16b}, [%[dst]], %[dst_step] \n"
"st1 {v3.16b}, [%[dst]], %[dst_step] \n"
"st1 {v2.16b}, [%[dst]], %[dst_step] \n"
"st1 {v1.16b}, [%[dst]], %[dst_step] \n"
"st1 {v0.16b}, [%[dst]], %[dst_step] \n"
:
[src] "+r" (src),
[dst] "+r" (dst)
:
[src_step] "r" (src_step),
[dst_step] "r" (dst_step)
:
"v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
"v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
);
}

void rotate_8uc1_clockwise(const uchar* src, uchar* dst, const size_t rows,
const size_t cols, const size_t src_step,
const size_t dst_step) {
const size_t block = 16;
(void)block;
size_t i = 0;

for (; i + block <= rows; i += block) {
size_t j = 0;
for (; j + block <= cols; j += block) {
rotate_8uc1_clockwise_16x16(
src + i * src_step + j,
dst + j * dst_step + (rows - (i + block)), src_step,
dst_step);
}
for (; j < cols; ++j) {
for (size_t k = 0; k < block; ++k) {
dst[j * dst_step + (rows - 1 - (i + k))] =
src[(i + k) * src_step + j];
}
}
}

for (; i < rows; ++i) {
for (size_t j = 0; j < cols; ++j) {
dst[j * dst_step + (rows - 1 - i)] = src[i * src_step + j];
}
}
}

void rotate_8uc1_counterclockwise(const uchar* src, uchar* dst,
const size_t rows, const size_t cols,
const size_t src_step,
const size_t dst_step) {
const size_t block = 16;
(void)block;
size_t i = 0;

for (; i + block <= rows; i += block) {
size_t j = 0;
for (; j + block <= cols; j += block) {
rotate_8uc1_counterclockwise_16x16(
src + i * src_step + j,
dst + (cols - (j + block)) * dst_step + i, src_step,
dst_step);
}
for (; j < cols; ++j) {
for (size_t k = 0; k < block; ++k) {
dst[(cols - 1 - j) * dst_step + (i + k)] =
src[(i + k) * src_step + j];
}
}
}

for (; i < rows; ++i) {
for (size_t j = 0; j < cols; ++j) {
dst[(cols - 1 - j) * dst_step + i] = src[i * src_step + j];
}
}
}

void rotate(const Mat<uchar>& src, Mat<uchar>& dst, bool clockwise) {
megdnn_assert(src.rows() == dst.cols());
megdnn_assert(src.cols() == dst.rows());
megdnn_assert(src.channels() == dst.channels());
megdnn_assert(src.channels() == 1_z);
if (clockwise) {
rotate_8uc1_clockwise(src.ptr(), dst.ptr(), src.rows(), src.cols(),
src.step(), dst.step());
} else {
rotate_8uc1_counterclockwise(src.ptr(), dst.ptr(), src.rows(),
src.cols(), src.step(), dst.step());
}
}

} // namespace megcv

namespace aarch64 {

void RotateImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
_megdnn_workspace workspace) {
using namespace megcv;
check_exec(src.layout, dst.layout, workspace.size);

//! rotate only support data type is uchar and the channel size is 1
if (dst.layout.dtype != dtype::Uint8() || src.layout.shape[3] != 1) {
return fallback::RotateImpl::exec(src, dst, workspace);
}

MEGDNN_DISPATCH_CPU_KERN_OPR({
for (size_t i = 0; i < src.layout.shape[0]; ++i) {
Mat<uchar> src_mat = TensorND2Mat<uchar>(src, i);
Mat<uchar> dst_mat = TensorND2Mat<uchar>(dst, i);
rotate(src_mat, dst_mat, param().clockwise);
}
});
}

} // namespace aarch64
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 35
- 0
dnn/src/aarch64/rotate/opr_impl.h View File

@@ -0,0 +1,35 @@
/**
* \file dnn/src/aarch64/rotate/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once

#include "megdnn/oprs.h"
#include "src/fallback/rotate/opr_impl.h"

namespace megdnn {
namespace aarch64 {

class RotateImpl : public fallback::RotateImpl {
public:
using fallback::RotateImpl::RotateImpl;

void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) override;

size_t get_workspace_in_bytes(const TensorLayout&,
const TensorLayout&) override {
return 0;
}
};

} // namespace aarch64
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 43
- 0
dnn/src/aarch64/warp_perspective/opr_impl.cpp View File

@@ -0,0 +1,43 @@
/**
* \file dnn/src/aarch64/warp_perspective/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/aarch64/warp_perspective/opr_impl.h"

#include "src/aarch64/warp_perspective/warp_perspective_cv.h"

#include "src/common/utils.h"
#include "src/common/warp_common.h"
#include "src/naive/handle.h"

namespace megdnn {
namespace aarch64 {

void WarpPerspectiveImpl::exec(_megdnn_tensor_in src,
_megdnn_tensor_in mat,
_megdnn_tensor_in mat_idx,
_megdnn_tensor_in dst,
_megdnn_workspace workspace)
{
check_exec(src.layout, mat.layout, mat_idx.layout, dst.layout,
workspace.size);
if (warp::is_cv_available(src.layout, mat.layout, dst.layout, param().imode,
param().format) && !mat_idx.layout.ndim) {
warp_perspective_cv_exec(src, mat, dst, param().border_val,
param().bmode, param().imode, handle());
} else {
//! Use arm_common implementation
arm_common::WarpPerspectiveImpl::exec(src, mat, mat_idx, dst, workspace);
}
}

} // namespace aarch64
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 30
- 0
dnn/src/aarch64/warp_perspective/opr_impl.h View File

@@ -0,0 +1,30 @@
/**
* \file dnn/src/aarch64/warp_perspective/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megdnn/oprs.h"
#include "src/arm_common/warp_perspective/opr_impl.h"

namespace megdnn {
namespace aarch64 {

class WarpPerspectiveImpl : public arm_common::WarpPerspectiveImpl {
public:
using arm_common::WarpPerspectiveImpl::WarpPerspectiveImpl;

void exec(_megdnn_tensor_in src, _megdnn_tensor_in mat,
_megdnn_tensor_in mat_idx, _megdnn_tensor_out dst,
_megdnn_workspace workspace) override;
};

} // namespace aarch64
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 257
- 0
dnn/src/aarch64/warp_perspective/warp_perspective_cv.cpp View File

@@ -0,0 +1,257 @@
/**
* \file dnn/src/aarch64/warp_perspective/warp_perspective_cv.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/aarch64/handle.h"
#include "src/aarch64/warp_perspective/warp_perspective_cv.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/cv/common.h"
#include "src/common/cv/helper.h"
#include "src/common/cv/interp_helper.h"
#include "src/common/utils.h"
#include "src/common/warp_common.h"

using namespace megdnn;
using namespace aarch64;
using namespace megcv;
using namespace warp;
namespace {

constexpr size_t BLOCK_SZ = 32u;
template <typename T, InterpolationMode imode, BorderMode bmode, size_t CH>
void warp_perspective_cv(const Mat<T>& src, Mat<T>& dst, const float* trans,
const float border_value, size_t task_id) {
// no extra padding
double M[9];
rep(i, 9) M[i] = trans[i];
T bvalue[3] = {(T)border_value, (T)border_value, (T)border_value};

size_t x1, y1, width = dst.cols(), height = dst.rows();
size_t BLOCK_SZ_H = std::min(BLOCK_SZ / 2, height);
size_t BLOCK_SZ_W = std::min(BLOCK_SZ * BLOCK_SZ / BLOCK_SZ_H, width);
BLOCK_SZ_H = std::min(BLOCK_SZ * BLOCK_SZ / BLOCK_SZ_W, height);

size_t width_block_size = div_ceil<size_t>(width, BLOCK_SZ_W);
size_t y = (task_id / width_block_size) * BLOCK_SZ_H;
size_t x = (task_id % width_block_size) * BLOCK_SZ_W;
// start invoke
short XY[BLOCK_SZ * BLOCK_SZ * 2], A[BLOCK_SZ * BLOCK_SZ];

float64x2_t vM6 = vdupq_n_f64(M[6]);
float64x2_t vM0 = vdupq_n_f64(M[0]);
float64x2_t vM3 = vdupq_n_f64(M[3]);
float64x2_t v2M6 = vdupq_n_f64(M[6] * 2);
float64x2_t v2M0 = vdupq_n_f64(M[0] * 2);
float64x2_t v2M3 = vdupq_n_f64(M[3] * 2);
float64x2_t v4f = vdupq_n_f64(4);
float64x2_t v1f = vdupq_n_f64(1);
float64x2_t v0f = vdupq_n_f64(0);
float64x2_t vTABLE_SIZE = vdupq_n_f64(INTER_TAB_SIZE);
float64x2_t vmin = vdupq_n_f64((double)INT_MIN);
float64x2_t vmax = vdupq_n_f64((double)INT_MAX);
int32x4_t vtabmask = vdupq_n_s32(INTER_TAB_SIZE - 1);

size_t bw = std::min(BLOCK_SZ_W, width - x);
size_t bh = std::min(BLOCK_SZ_H, height - y); // height
Mat<short> _XY(bh, bw, 2, XY);
Mat<T> dpart(dst, y, bh, x, bw);

for (y1 = 0; y1 < bh; y1++) {
short* xy = XY + y1 * bw * 2;
double X0 = M[0] * x + M[1] * (y + y1) + M[2];
double Y0 = M[3] * x + M[4] * (y + y1) + M[5];
double W0 = M[6] * x + M[7] * (y + y1) + M[8];
float64x2_t vW0 = vdupq_n_f64(W0);
float64x2_t vidx = {0.f, 1.f};
float64x2_t vX0 = vdupq_n_f64(X0);
float64x2_t vY0 = vdupq_n_f64(Y0);
if (imode == IMode::NEAREST) {
for (x1 = 0; x1 + 4 <= bw; x1 += 4) {
float64x2_t vw0 = vaddq_f64(vW0, vmulq_f64(vM6, vidx));
float64x2_t vw1 = vaddq_f64(vw0, v2M6);

vw0 = vbitq_f64(vdivq_f64(v1f, vw0), v0f, vceqq_f64(vw0, v0f));
vw1 = vbitq_f64(vdivq_f64(v1f, vw1), v0f, vceqq_f64(vw1, v0f));

float64x2_t vtmp0 = vmlaq_f64(vX0, vM0, vidx);
float64x2_t vtmp1 = vaddq_f64(vtmp0, v2M0);
float64x2_t vfx0 = vmulq_f64(vtmp0, vw0);
float64x2_t vfx1 = vmulq_f64(vtmp1, vw1);
vfx0 = vmaxq_f64(vminq_f64(vfx0, vmax), vmin);
vfx1 = vmaxq_f64(vminq_f64(vfx1, vmax), vmin);

vtmp0 = vmlaq_f64(vY0, vM3, vidx);
vtmp1 = vaddq_f64(vtmp0, v2M3);
float64x2_t vfy0 = vmulq_f64(vtmp0, vw0);
float64x2_t vfy1 = vmulq_f64(vtmp1, vw1);
vfy0 = vmaxq_f64(vminq_f64(vfy0, vmax), vmin);
vfy1 = vmaxq_f64(vminq_f64(vfy1, vmax), vmin);

int32x2_t vx0 = vqmovn_s64(vcvtaq_s64_f64(vfx0));
int32x2_t vx1 = vqmovn_s64(vcvtaq_s64_f64(vfx1));
int32x2_t vy0 = vqmovn_s64(vcvtaq_s64_f64(vfy0));
int32x2_t vy1 = vqmovn_s64(vcvtaq_s64_f64(vfy1));

int32x4_t vx = vcombine_s32(vx0, vx1);
int32x4_t vy = vcombine_s32(vy0, vy1);

int16x4x2_t ret = {{vqmovn_s32(vx), vqmovn_s32(vy)}};
vst2_s16(xy + x1 * 2, ret);

vidx = vaddq_f64(vidx, v4f);
}

for (; x1 < bw; x1++) {
double W = W0 + M[6] * x1;
W = W ? 1. / W : 0;
double fX = std::max(
(double)INT_MIN,
std::min((double)INT_MAX, (X0 + M[0] * x1) * W));
double fY = std::max(
(double)INT_MIN,
std::min((double)INT_MAX, (Y0 + M[3] * x1) * W));
int X = saturate_cast<int>(fX);
int Y = saturate_cast<int>(fY);
xy[x1 * 2] = saturate_cast<short>(X);
xy[x1 * 2 + 1] = saturate_cast<short>(Y);
}
} else {
short* alpha = A + y1 * bw;
for (x1 = 0; x1 + 4 <= bw; x1 += 4) {
float64x2_t vw0 = vaddq_f64(vW0, vmulq_f64(vM6, vidx));
float64x2_t vw1 = vaddq_f64(vw0, v2M6);

vw0 = vbitq_f64(vdivq_f64(vTABLE_SIZE, vw0), v0f,
vceqq_f64(vw0, v0f));
vw1 = vbitq_f64(vdivq_f64(vTABLE_SIZE, vw1), v0f,
vceqq_f64(vw1, v0f));

float64x2_t vtmp0 = vmlaq_f64(vX0, vM0, vidx);
float64x2_t vtmp1 = vaddq_f64(vtmp0, v2M0);
float64x2_t vfx0 = vmulq_f64(vtmp0, vw0);
float64x2_t vfx1 = vmulq_f64(vtmp1, vw1);
vfx0 = vmaxq_f64(vminq_f64(vfx0, vmax), vmin);
vfx1 = vmaxq_f64(vminq_f64(vfx1, vmax), vmin);

vtmp0 = vmlaq_f64(vY0, vM3, vidx);
vtmp1 = vaddq_f64(vtmp0, v2M3);
float64x2_t vfy0 = vmulq_f64(vtmp0, vw0);
float64x2_t vfy1 = vmulq_f64(vtmp1, vw1);
vfy0 = vmaxq_f64(vminq_f64(vfy0, vmax), vmin);
vfy1 = vmaxq_f64(vminq_f64(vfy1, vmax), vmin);

int32x2_t vx0 = vqmovn_s64(vcvtaq_s64_f64(vfx0));
int32x2_t vx1 = vqmovn_s64(vcvtaq_s64_f64(vfx1));
int32x2_t vy0 = vqmovn_s64(vcvtaq_s64_f64(vfy0));
int32x2_t vy1 = vqmovn_s64(vcvtaq_s64_f64(vfy1));

int32x4_t vx = vcombine_s32(vx0, vx1);
int32x4_t vy = vcombine_s32(vy0, vy1);

int16x4x2_t ret = {{vqshrn_n_s32(vx, INTER_BITS),
vqshrn_n_s32(vy, INTER_BITS)}};
vst2_s16(xy + x1 * 2, ret);

vidx = vaddq_f64(vidx, v4f);

vx = vandq_s32(vx, vtabmask);
vy = vandq_s32(vy, vtabmask);

vst1_s16(&alpha[x1],
vqmovn_s32(vmlaq_n_s32(vx, vy, INTER_TAB_SIZE)));
}
for (; x1 < bw; x1++) {
double W = W0 + M[6] * x1;
W = W ? INTER_TAB_SIZE / W : 0;
double fX = std::max(
(double)INT_MIN,
std::min((double)INT_MAX, (X0 + M[0] * x1) * W));
double fY = std::max(
(double)INT_MIN,
std::min((double)INT_MAX, (Y0 + M[3] * x1) * W));
int X = saturate_cast<int>(fX);
int Y = saturate_cast<int>(fY);
xy[x1 * 2] = saturate_cast<short>(X >> INTER_BITS);
xy[x1 * 2 + 1] = saturate_cast<short>(Y >> INTER_BITS);
alpha[x1] =
(short)((Y & (INTER_TAB_SIZE - 1)) * INTER_TAB_SIZE +
(X & (INTER_TAB_SIZE - 1)));
}
}
}
Mat<ushort> _matA(bh, bw, 1, (ushort*)(A));
remap<T, imode, bmode, CH, RemapVec<T, CH>>(src, dpart, _XY, _matA, bvalue);
}
} // anonymous namespace
void megdnn::aarch64::warp_perspective_cv_exec(
_megdnn_tensor_in src, _megdnn_tensor_in trans, _megdnn_tensor_in dst,
float border_value, BorderMode bmode, InterpolationMode imode,
Handle* handle) {
size_t ch = dst.layout[3];
size_t width = dst.layout[2];
size_t height = dst.layout[1];
const size_t batch = dst.layout.shape[0];

size_t BLOCK_SZ_H = std::min(BLOCK_SZ / 2, height);
size_t BLOCK_SZ_W = std::min(BLOCK_SZ * BLOCK_SZ / BLOCK_SZ_H, width);
BLOCK_SZ_H = std::min(BLOCK_SZ * BLOCK_SZ / BLOCK_SZ_W, height);

size_t parallelism_batch = div_ceil<size_t>(height, BLOCK_SZ_H) *
div_ceil<size_t>(width, BLOCK_SZ_W);
megdnn_assert(ch == 1 || ch == 3 || ch == 2,
"unsupported src channel: %zu, avaiable channel size: 1/2/3",
ch);
const float* trans_ptr = trans.ptr<dt_float32>();
if (dst.layout.dtype.enumv() == DTypeEnum::Float32) {
#define cb(_imode, _bmode, _ch) \
auto task = [src, trans_ptr, dst, border_value, parallelism_batch]( \
size_t index, size_t) { \
size_t batch_id = index / parallelism_batch; \
size_t task_id = index % parallelism_batch; \
Mat<float> src_mat = TensorND2Mat<float>(src, batch_id); \
Mat<float> dst_mat = TensorND2Mat<float>(dst, batch_id); \
const float* task_trans_ptr = trans_ptr + batch_id * 3 * 3; \
warp_perspective_cv<float MEGDNN_COMMA _imode MEGDNN_COMMA _bmode \
MEGDNN_COMMA _ch>( \
src_mat MEGDNN_COMMA const_cast<Mat<float>&>(dst_mat) \
MEGDNN_COMMA task_trans_ptr MEGDNN_COMMA border_value, \
task_id); \
}; \
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
static_cast<naive::HandleImpl*>(handle), batch* parallelism_batch, \
task);
DISPATCH_IMODE(imode, bmode, ch, cb)
#undef cb
} else if (dst.layout.dtype.enumv() == DTypeEnum::Uint8) {
#define cb(_imode, _bmode, _ch) \
auto task = [src, trans_ptr, dst, border_value, parallelism_batch]( \
size_t index, size_t) { \
size_t batch_id = index / parallelism_batch; \
size_t task_id = index % parallelism_batch; \
Mat<uchar> src_mat = TensorND2Mat<uchar>(src, batch_id); \
Mat<uchar> dst_mat = TensorND2Mat<uchar>(dst, batch_id); \
const float* task_trans_ptr = trans_ptr + batch_id * 3 * 3; \
warp_perspective_cv<uchar MEGDNN_COMMA _imode MEGDNN_COMMA _bmode \
MEGDNN_COMMA _ch>( \
src_mat MEGDNN_COMMA const_cast<Mat<uchar>&>(dst_mat) \
MEGDNN_COMMA task_trans_ptr MEGDNN_COMMA border_value, \
task_id); \
}; \
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
static_cast<naive::HandleImpl*>(handle), batch* parallelism_batch, \
task);
DISPATCH_IMODE(imode, bmode, ch, cb)
#undef cb
} else {
megdnn_throw(
megdnn_mangle("Unsupported datatype of WarpAffine optr."));
}

}
// vim: syntax=cpp.doxygen

+ 32
- 0
dnn/src/aarch64/warp_perspective/warp_perspective_cv.h View File

@@ -0,0 +1,32 @@
/**
* \file dnn/src/aarch64/warp_perspective/warp_perspective_cv.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include <megdnn/oprs.h>

#include "src/common/cv/helper.h"

namespace megdnn {
namespace aarch64 {

/**
* \fn warp_perspective_cv
* \brief Used if the format is NHWC, transfer from megcv
*/
void warp_perspective_cv_exec(_megdnn_tensor_in src, _megdnn_tensor_in trans,
_megdnn_tensor_in dst, float border_value,
param::WarpPerspective::BorderMode border_mode,
param::WarpPerspective::InterpolationMode imode,
Handle* handle);

} // namespace aarch64
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 380
- 0
dnn/src/arm_common/conv_bias/direct/multi_thread_common.cpp View File

@@ -0,0 +1,380 @@
/**
* \file dnn/src/arm_common/conv_bias/direct/multi_thread_common.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/arm_common/conv_bias/direct/multi_thread_common.h"
#include "src/arm_common/conv_bias/postprocess_helper.h"
#include "src/fallback/matrix_mul/opr_impl.h"

using namespace megdnn;
using namespace arm_common;

namespace {
bool need_dst_copy(
const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) {
auto align = param.src_type.enumv() == DTypeEnum::Float32 ? 4 : 8;
return param.osz[1] % align;
}
bool need_src_copy(
const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) {
if (param.filter_meta.padding[0] || param.filter_meta.padding[1]) {
return true;
}
return need_dst_copy(param);
}
void get_rectified_size(
const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param,
size_t IH, size_t IW, size_t OH, size_t OW, size_t FH, size_t FW,
size_t PH, size_t PW, size_t& IH2, size_t& IW2, size_t& OW2) {
MEGDNN_MARK_USED_VAR(PW);
MEGDNN_MARK_USED_VAR(PH);
auto&& fm = param.filter_meta;
auto SW = fm.stride[1];

auto Align = param.src_type.enumv() == DTypeEnum::Float32 ? 3 : 7;
OW2 = (OW + Align) & ~Align;
IH2 = SW * OH + FH - SW;
IW2 = SW * OW2 + FW - SW;
// Because stride is 2, sometimes IW == IW2+1. Do a max update to
// handle this case.
IH2 = std::max(IH2, IH);
IW2 = std::max(IW2, IW);
}

} // namespace

template <typename io_ctype, typename compute_ctype>
WorkspaceBundle MultithreadDirectConvCommon<io_ctype, compute_ctype>::get_bundle(
const ConvBiasImpl::NCBKernSizeParam& param, bool m_large_group) {
auto&& fm = param.filter_meta;
size_t nr_threads = param.nr_threads;
size_t group = fm.group, batch = param.n;
size_t IH2 = param.isz[0] + 2 * fm.padding[0];
size_t IW2 = param.isz[1] + 2 * fm.padding[1];
// part0: copied src
// part1: copied filter
size_t part0, part1;
if (fm.padding[0] == 0 && fm.padding[1] == 0) {
//! only the last plane need to be copied, add 16 Byte extra space in
//! case of invalid read and write
part0 = (param.isz[0] * param.isz[1]) * sizeof(io_ctype) + 16;
} else if (m_large_group) {
//! Serial in group, each thread process one group, parallel by group
part0 = (IH2 * IW2 * fm.icpg * nr_threads) * sizeof(io_ctype) + 16;
} else {
//! Parallel in group, Then should copy every inputs to workspace
part0 = (IH2 * IW2 * fm.icpg * group * batch) * sizeof(io_ctype) + 16;
}
if (param.filter_meta.should_flip) {
if (m_large_group) {
//! Serial in group, each thread has own workspace and then reuse
part1 = fm.spatial[0] * fm.spatial[1] * fm.ocpg * fm.icpg *
nr_threads * sizeof(io_ctype);
} else {
part1 = fm.spatial[0] * fm.spatial[1] * fm.ocpg * fm.icpg * group *
sizeof(io_ctype);
}
} else {
part1 = 0;
}
return {nullptr, {part0, part1}};
}
template <typename io_ctype, typename compute_ctype>
WorkspaceBundle
MultithreadDirectConvCommon<io_ctype, compute_ctype>::get_bundle_stride(
const ConvBiasImpl::NCBKernSizeParam& param, bool m_large_group) {
UNPACK_CONV_F32_NCB_KERN_SIZES(param);
MEGDNN_MARK_USED_VAR(N);
MEGDNN_MARK_USED_VAR(OC);
MEGDNN_MARK_USED_VAR(SH);
MEGDNN_MARK_USED_VAR(SW);
auto&& fm = param.filter_meta;
size_t nr_threads = param.nr_threads;
size_t group = fm.group, batch = param.n;
size_t IH2, IW2, OW2;
get_rectified_size(param, IH, IW, OH, OW, FH, FW, PH, PW, IH2, IW2, OW2);

size_t src_size = 0, dst_size = 0;
// src_size: copied src
// dst_size: copied dst
if (need_src_copy(param)) {
src_size = m_large_group
? IC * IH2 * IW2 * sizeof(io_ctype) * nr_threads
: IC * IH2 * IW2 * sizeof(io_ctype) * group * batch;
};
if (need_dst_copy(param)) {
//! add 16 Byte extra space in case of invalid read and write
dst_size = OH * OW2 * sizeof(io_ctype) * nr_threads + 16;
}
return {nullptr, {src_size, dst_size}};
}

//! Process one output channel weight flip
template <typename io_ctype, typename compute_ctype>
void MultithreadDirectConvCommon<io_ctype, compute_ctype>::weight_flip_kern(
WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index,
const CpuNDRange& workspace_ids) {
size_t FH = kern_param.filter_meta.spatial[0];
size_t FW = kern_param.filter_meta.spatial[1];
size_t IC = kern_param.filter_meta.icpg;
size_t OC = kern_param.filter_meta.ocpg;
//! Used for get the workspace offset
size_t workspace_group_id = workspace_ids[0], channel_id = workspace_ids[2],
group_id = ncb_index.ndrange_id[0];
const io_ctype* filter =
kern_param.filter<io_ctype>(group_id) + channel_id * FH * FW * IC;
bundle.set(kern_param.workspace_ptr);
io_ctype* filter_flip =
static_cast<io_ctype*>(bundle.get(1)) +
(workspace_group_id * IC * OC + channel_id * IC) * FH * FW;
rep(ic, IC) {
const io_ctype* filter_plane = filter + ic * FH * FW;
io_ctype* filter_flip_plane = filter_flip + ic * FH * FW;
rep(fh, FH) rep(fw, FW) {
filter_flip_plane[fh * FW + fw] =
filter_plane[(FH - fh - 1) * FW + (FW - fw - 1)];
}
}
}

//! Process one input channel copy padding
template <typename io_ctype, typename compute_ctype>
void MultithreadDirectConvCommon<io_ctype, compute_ctype>::copy_padding_kern(
WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index,
const CpuNDRange& workspace_ids) {
size_t IH = kern_param.isz[0];
size_t IW = kern_param.isz[1];
size_t IC = kern_param.filter_meta.icpg;
size_t PH = kern_param.filter_meta.padding[0];
size_t PW = kern_param.filter_meta.padding[1];
size_t IH2 = IH + 2 * PH;
size_t IW2 = IW + 2 * PW;
size_t padding_group_size = IH2 * IW2 * IC;
size_t N = kern_param.n;
size_t GROUP = kern_param.filter_meta.group;
bundle.set(kern_param.workspace_ptr);

//! Used for get the workspace offset
size_t workspace_group_id = workspace_ids[0],
workspace_batch_id = workspace_ids[1], channel_id = workspace_ids[2];
size_t batch_id = ncb_index.ndrange_id[1],
group_id = ncb_index.ndrange_id[0];
const io_ctype* sptr = static_cast<const io_ctype*>(
kern_param.src<io_ctype>(batch_id, group_id, channel_id));
if (PH > 0 || PW > 0) {
//! copy to sptr_base to eliminate padding effect
io_ctype* sptr_base = static_cast<io_ctype*>(bundle.get(0)) +
workspace_group_id * padding_group_size +
workspace_batch_id * GROUP * padding_group_size +
channel_id * IH2 * IW2;
std::memset(sptr_base, 0, sizeof(io_ctype) * IH2 * IW2);
rep(ih, IH) {
std::memcpy(sptr_base + (ih + PH) * IW2 + PW, sptr + ih * IW,
sizeof(io_ctype) * IW);
}
} else if (batch_id + 1 == N && channel_id + 1 == IC &&
group_id + 1 == GROUP) {
//! copy last plane
io_ctype* sptr_last_c = static_cast<io_ctype*>(bundle.get(0));
std::memcpy(sptr_last_c, sptr, sizeof(io_ctype) * IH2 * IW2);
}
};
//! Process one input channel copy padding
template <typename io_ctype, typename compute_ctype>
void MultithreadDirectConvCommon<io_ctype, compute_ctype>::
copy_padding_kern_stride(WorkspaceBundle bundle,
const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index,
const CpuNDRange& workspace_ids) {
size_t IH = kern_param.isz[0];
size_t IW = kern_param.isz[1];
size_t IC = kern_param.filter_meta.icpg;
size_t PH = kern_param.filter_meta.padding[0];
size_t PW = kern_param.filter_meta.padding[1];
size_t FH = kern_param.filter_meta.spatial[0];
size_t FW = kern_param.filter_meta.spatial[1];
size_t OW = kern_param.osz[1];
size_t OH = kern_param.osz[0];
size_t IH2, IW2, OW2;
size_t GROUP = kern_param.filter_meta.group;
get_rectified_size(kern_param, IH, IW, OH, OW, FH, FW, PH, PW, IH2, IW2, OW2);
size_t padding_group_size = IH2 * IW2 * IC;
bundle.set(kern_param.workspace_ptr);

//! Used for get the workspace offset
size_t workspace_group_id = workspace_ids[0],
workspace_batch_id = workspace_ids[1];
size_t channel_id = workspace_ids[2], batch_id = ncb_index.ndrange_id[1],
group_id = ncb_index.ndrange_id[0];

const io_ctype* sptr = static_cast<const io_ctype*>(
kern_param.src<io_ctype>(batch_id, group_id, channel_id));
if (need_src_copy(kern_param)) {
//! copy to sptr_base to eliminate padding effect
io_ctype* sptr_base = static_cast<io_ctype*>(bundle.get(0)) +
workspace_group_id * padding_group_size +
workspace_batch_id * GROUP * padding_group_size +
channel_id * IH2 * IW2;
std::memset(sptr_base, 0, sizeof(io_ctype) * IH2 * IW2);
rep(ih, IH) {
std::memcpy(sptr_base + (ih + PH) * IW2 + PW, sptr + ih * IW,
sizeof(io_ctype) * IW);
}
}
};

//! compute one output channel
template <typename io_ctype, typename compute_ctype>
void MultithreadDirectConvCommon<io_ctype, compute_ctype>::do_conv_kern(
WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index,
const kern_direct_conv_f32& fun, const CpuNDRange& workspace_ids) {
size_t OH = kern_param.osz[0];
size_t OW = kern_param.osz[1];
size_t FH = kern_param.filter_meta.spatial[0];
size_t FW = kern_param.filter_meta.spatial[1];
size_t IC = kern_param.filter_meta.icpg;
size_t OC = kern_param.filter_meta.ocpg;
size_t PH = kern_param.filter_meta.padding[0];
size_t PW = kern_param.filter_meta.padding[1];
size_t IH2 = kern_param.isz[0] + 2 * PH;
size_t IW2 = kern_param.isz[1] + 2 * PW;
size_t padding_group_size = IH2 * IW2 * IC;
size_t N = kern_param.n;
size_t GROUP = kern_param.filter_meta.group;
bundle.set(kern_param.workspace_ptr);

size_t group_id = ncb_index.ndrange_id[0],
batch_id = ncb_index.ndrange_id[1];
size_t channel_id = workspace_ids[2];

const io_ctype* sptr = kern_param.src<io_ctype>(batch_id, group_id);
const io_ctype* filter = kern_param.filter<io_ctype>(group_id);
const io_ctype* bias_ptr =
kern_param.bias<io_ctype>(batch_id, group_id, channel_id);
io_ctype* dptr = kern_param.dst<io_ctype>(batch_id, group_id, channel_id);

//! Used for get the workspace offset
size_t workspace_batch_id = workspace_ids[1];
size_t workspace_group_id = workspace_ids[0];

io_ctype* sptr_base;
io_ctype* sptr_last_c;
auto fptr =
kern_param.filter_meta.should_flip
? static_cast<io_ctype*>(bundle.get(1)) +
(workspace_group_id * OC * IC + channel_id * IC) *
FH * FW
: filter + channel_id * FH * FW * IC;
if (PH > 0 || PW > 0) {
sptr_base = static_cast<io_ctype*>(bundle.get(0)) +
workspace_group_id * padding_group_size +
workspace_batch_id * GROUP * padding_group_size;
sptr_last_c = sptr_base + (IC - 1) * IH2 * IW2;
//! Last batch, last group
} else if (batch_id + 1 == N && group_id + 1 == GROUP) {
sptr_base = const_cast<io_ctype*>(sptr);
sptr_last_c = static_cast<io_ctype*>(bundle.get(0));
} else {
sptr_base = const_cast<io_ctype*>(sptr);
sptr_last_c = sptr_base + (IC - 1) * IH2 * IW2;
}
std::memset(dptr, 0, sizeof(io_ctype) * (OH * OW));
rep(ic, IC) {
io_ctype* sptr_cur =
(ic + 1 == IC ? sptr_last_c : sptr_base + ic * IH2 * IW2);
fun(reinterpret_cast<const compute_ctype*>(sptr_cur),
reinterpret_cast<const compute_ctype*>(fptr + ic * FH * FW),
reinterpret_cast<compute_ctype*>(dptr), IH2, IW2, OH, OW, FH, FW);
}
PostProcess<compute_ctype>::run(dptr, const_cast<io_ctype*>(bias_ptr), dptr,
kern_param.bias_mode, kern_param.nonlineMode,
kern_param.bias_type, kern_param.dst_type, 1_z,
1_z, OH, OW);
};

//! compute one output channel
template <typename io_ctype, typename compute_ctype>
void MultithreadDirectConvCommon<io_ctype, compute_ctype>::do_conv_kern_stride(
WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index,
const kern_direct_conv_f32_stride& fun,
const CpuNDRange& workspace_ids) {
size_t IH = kern_param.isz[0];
size_t IW = kern_param.isz[1];
size_t OH = kern_param.osz[0];
size_t OW = kern_param.osz[1];
size_t FH = kern_param.filter_meta.spatial[0];
size_t FW = kern_param.filter_meta.spatial[1];
size_t IC = kern_param.filter_meta.icpg;
size_t PH = kern_param.filter_meta.padding[0];
size_t PW = kern_param.filter_meta.padding[1];
size_t IH2, IW2, OW2;
get_rectified_size(kern_param, IH, IW, OH, OW, FH, FW, PH, PW, IH2, IW2, OW2);

size_t padding_group_size = IH2 * IW2 * IC;
size_t GROUP = kern_param.filter_meta.group;
bundle.set(kern_param.workspace_ptr);

//! Used for get the workspace offset
size_t group_id = ncb_index.ndrange_id[0],
batch_id = ncb_index.ndrange_id[1];
size_t channel_id = workspace_ids[2];

const io_ctype* sptr = kern_param.src<io_ctype>(batch_id, group_id);
const io_ctype* fptr =
kern_param.filter<io_ctype>(group_id) + channel_id * FH * FW * IC;
const io_ctype* bias_ptr =
kern_param.bias<io_ctype>(batch_id, group_id, channel_id);
io_ctype* dptr = kern_param.dst<io_ctype>(batch_id, group_id, channel_id);

size_t workspace_batch_id = workspace_ids[1];
size_t workspace_group_id = workspace_ids[0];

io_ctype* sptr_base;
io_ctype* dptr_base;
if (need_src_copy(kern_param)) {
sptr_base = static_cast<io_ctype*>(bundle.get(0)) +
workspace_group_id * padding_group_size +
workspace_batch_id * GROUP * padding_group_size;
} else {
sptr_base = const_cast<io_ctype*>(sptr);
}
if (need_dst_copy(kern_param)) {
dptr_base = static_cast<io_ctype*>(bundle.get(1)) +
ncb_index.thread_id * OH * OW2;
} else {
dptr_base = dptr;
}
if (need_dst_copy(kern_param)) {
std::memset(dptr_base, 0, sizeof(io_ctype) * (OH * OW2));
fun(reinterpret_cast<const compute_ctype*>(sptr_base),
reinterpret_cast<const compute_ctype*>(fptr),
reinterpret_cast<compute_ctype*>(dptr_base), IH2, IW2, OH, OW2, IC);
copy_plane_in_bytes(dptr, dptr_base, OH, OW * sizeof(io_ctype),
OW * sizeof(io_ctype), OW2 * sizeof(io_ctype));
} else {
std::memset(dptr_base, 0, sizeof(io_ctype) * (OH * OW));
fun(reinterpret_cast<const compute_ctype*>(sptr_base),
reinterpret_cast<const compute_ctype*>(fptr),
reinterpret_cast<compute_ctype*>(dptr_base), IH2, IW2, OH, OW, IC);
}
PostProcess<compute_ctype>::run(dptr, const_cast<io_ctype*>(bias_ptr), dptr,
kern_param.bias_mode, kern_param.nonlineMode,
kern_param.bias_type, kern_param.dst_type, 1_z,
1_z, OH, OW);
};
template class megdnn::arm_common::MultithreadDirectConvCommon<float, float>;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
template class megdnn::arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>;
#endif
// vim: syntax=cpp.doxygen

+ 65
- 0
dnn/src/arm_common/conv_bias/direct/multi_thread_common.h View File

@@ -0,0 +1,65 @@
/**
* \file dnn/src/arm_common/conv_bias/direct/multi_thread_common.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once

#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/fallback/matrix_mul/opr_impl.h"

namespace megdnn {
namespace arm_common {

template <class io_ctype, class compute_ctype>
class MultithreadDirectConvCommon {
public:
using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam;
using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam;
using NCBKernIndex = fallback::ConvBiasImpl::NCBKernIndex;

using kern_direct_conv_f32 =
std::function<void(const compute_ctype* src,
const compute_ctype* filter, compute_ctype* dst,
size_t, size_t, size_t, size_t, size_t, size_t)>;
using kern_direct_conv_f32_stride = std::function<void(
const compute_ctype* src, const compute_ctype* filter,
compute_ctype* dst, size_t, size_t, size_t, size_t, size_t)>;

static WorkspaceBundle get_bundle(const NCBKernSizeParam& param,
bool m_large_group);
static WorkspaceBundle get_bundle_stride(const NCBKernSizeParam& param,
bool m_large_group);
static void weight_flip_kern(WorkspaceBundle bundle,
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index,
const CpuNDRange& workspace_ids);
static void copy_padding_kern(WorkspaceBundle bundle,
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index,
const CpuNDRange& workspace_ids);
static void copy_padding_kern_stride(WorkspaceBundle bundle,
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index,
const CpuNDRange& workspace_ids);
static void do_conv_kern(WorkspaceBundle bundle,
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index,
const kern_direct_conv_f32& fun,
const CpuNDRange& workspace_ids);
static void do_conv_kern_stride(WorkspaceBundle bundle,
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index,
const kern_direct_conv_f32_stride& fun,
const CpuNDRange& workspace_ids);
};

} // namespace arm_common
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 561
- 0
dnn/src/arm_common/conv_bias/f16/algos.cpp View File

@@ -0,0 +1,561 @@
/**
* \file dnn/src/arm_common/conv_bias/f16/algos.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/arm_common/conv_bias/f16/algos.h"
#include "src/arm_common/conv_bias/direct/multi_thread_common.h"
#include "src/arm_common/conv_bias/f16/direct.h"
#include "src/arm_common/conv_bias/f16/do_conv_stride1.h"
#include "src/arm_common/conv_bias/f16/strategy.h"
#include "src/arm_common/conv_bias/img2col_helper.h"
#include "src/arm_common/conv_bias/postprocess_helper.h"
#include "src/common/opr_delegate.h"
#include "src/fallback/conv_bias/common.h"
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#include "midout.h"
MIDOUT_DECL(megdnn_arm_common_winograd_fp16)
using namespace megdnn;
using namespace arm_common;

/* ======================= AlgoFP16WinogradF23 ======================== */

bool ConvBiasImpl::AlgoFP16WinogradF23::usable(
fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
AlgoSelectionStrategy /*algo_selection_strategy*/) const {
MEGDNN_MARK_USED_VAR(param);
MEGDNN_MARK_USED_VAR(opr);
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp16, 0, 0) {
using Strategy = winograd::winograd_2x3_4x4_f16;
Strategy strategy(param.src_type, param.filter_type, param.dst_type);
auto&& matmul_param =
megdnn::winograd::ConvBias<Strategy>(
strategy, m_tile_size, param.nr_threads, param.osz[0],
param.osz[1], param.filter_meta.ocpg)
.get_matmul_kern_param(param);
return m_matmul_algo->usable(matmul_param) &&
(opr->param().format == param::ConvBias::Format::NCHW ||
(opr->param().format ==
param::ConvBias::Format::NCHW_WINOGRAD &&
opr->param().output_block_size == 2 &&
param.winograd_matmul_format ==
param::MatrixMul::Format::DEFAULT)) &&
opr->param().mode == param::ConvBias::Mode::CROSS_CORRELATION &&
(param.filter_meta.spatial[0] == param.filter_meta.spatial[1] &&
param.filter_meta.spatial[0] == 3) &&
(param.filter_meta.stride[0] == param.filter_meta.stride[1] &&
param.filter_meta.stride[0] == 1) &&
(param.filter_meta.dilation[0] ==
param.filter_meta.dilation[1] &&
param.filter_meta.dilation[0] == 1) &&
param.compute_mode == param::ConvBias::ComputeMode::DEFAULT &&
param.src_type.enumv() == DTypeEnum::Float16 &&
param.filter_meta.icpg % 4 == 0 &&
param.filter_meta.ocpg % 4 == 0;
}
MIDOUT_END();
return false;
}

size_t ConvBiasImpl::AlgoFP16WinogradF23::get_workspace(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
MEGDNN_MARK_USED_VAR(param);
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp16, 0, 1) {
winograd::winograd_2x3_4x4_f16 strategy(
param.src_type, param.filter_type, param.dst_type);
return megdnn::winograd::ConvBias<winograd::winograd_2x3_4x4_f16>(
strategy, m_tile_size, param.nr_threads, param.osz[0],
param.osz[1], param.filter_meta.ocpg)
.get_workspace_size(param, m_matmul_algo);
}
MIDOUT_END();
return 0;
}

SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoFP16WinogradF23::dispatch_kerns(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
MEGDNN_MARK_USED_VAR(param);
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp16, 0, 2) {
winograd::winograd_2x3_4x4_f16 strategy(
param.src_type, param.filter_type, param.dst_type);

auto winograd_impl =
megdnn::winograd::ConvBias<winograd::winograd_2x3_4x4_f16>(
strategy, m_tile_size, param.nr_threads, param.osz[0],
param.osz[1], param.filter_meta.ocpg);
return winograd_impl.get_kerns(param, m_matmul_algo);
}
MIDOUT_END();
return {};
}

/* ======================= AlgoFP16WinogradF45 ======================== */

bool ConvBiasImpl::AlgoFP16WinogradF45::usable(
fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
AlgoSelectionStrategy /*algo_selection_strategy*/) const {
MEGDNN_MARK_USED_VAR(param);
MEGDNN_MARK_USED_VAR(opr);
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp16, 1, 0) {
using Strategy = winograd::winograd_4x5_1x1_f16;
Strategy strategy(param.src_type, param.filter_type, param.dst_type);
auto&& matmul_param =
megdnn::winograd::ConvBias<Strategy>(
strategy, m_tile_size, param.nr_threads, param.osz[0],
param.osz[1], param.filter_meta.ocpg)
.get_matmul_kern_param(param);
return m_matmul_algo->usable(matmul_param) &&
(opr->param().format == param::ConvBias::Format::NCHW ||
(opr->param().format ==
param::ConvBias::Format::NCHW_WINOGRAD &&
opr->param().output_block_size == 4 &&
param.winograd_matmul_format ==
param::MatrixMul::Format::DEFAULT)) &&
opr->param().mode == param::ConvBias::Mode::CROSS_CORRELATION &&
(param.filter_meta.spatial[0] == param.filter_meta.spatial[1] &&
param.filter_meta.spatial[0] == 5) &&
(param.filter_meta.stride[0] == param.filter_meta.stride[1] &&
param.filter_meta.stride[0] == 1) &&
(param.filter_meta.dilation[0] ==
param.filter_meta.dilation[1] &&
param.filter_meta.dilation[0] == 1) &&
param.compute_mode == param::ConvBias::ComputeMode::DEFAULT &&
param.src_type.enumv() == DTypeEnum::Float16;
}
MIDOUT_END();
return false;
}

size_t ConvBiasImpl::AlgoFP16WinogradF45::get_workspace(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
MEGDNN_MARK_USED_VAR(param);
winograd::winograd_4x5_1x1_f16 strategy(param.src_type, param.filter_type,
param.dst_type);
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp16, 1, 1) {
return megdnn::winograd::ConvBias<winograd::winograd_4x5_1x1_f16>(
strategy, m_tile_size, param.nr_threads, param.osz[0],
param.osz[1], param.filter_meta.ocpg)
.get_workspace_size(param, m_matmul_algo);
}
MIDOUT_END();
return 0;
}

SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoFP16WinogradF45::dispatch_kerns(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
MEGDNN_MARK_USED_VAR(param);
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp16, 1, 2) {
winograd::winograd_4x5_1x1_f16 strategy(
param.src_type, param.filter_type, param.dst_type);
auto winograd_impl =
megdnn::winograd::ConvBias<winograd::winograd_4x5_1x1_f16>(
strategy, m_tile_size, param.nr_threads, param.osz[0],
param.osz[1], param.filter_meta.ocpg);
return winograd_impl.get_kerns(param, m_matmul_algo);
}
MIDOUT_END();
return {};
}
/* ======================= AlgoFP16WinogradF63 ======================== */

bool ConvBiasImpl::AlgoFP16WinogradF63::usable(
fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
AlgoSelectionStrategy /*algo_selection_strategy*/) const {
MEGDNN_MARK_USED_VAR(param);
MEGDNN_MARK_USED_VAR(opr);
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp16, 2, 0) {
using Strategy = winograd::winograd_6x3_1x1_f16;
Strategy strategy(param.src_type, param.filter_type, param.dst_type);
auto&& matmul_param =
megdnn::winograd::ConvBias<Strategy>(
strategy, m_tile_size, param.nr_threads, param.osz[0],
param.osz[1], param.filter_meta.ocpg)
.get_matmul_kern_param(param);
return m_matmul_algo->usable(matmul_param) &&
(opr->param().format == param::ConvBias::Format::NCHW ||
(opr->param().format ==
param::ConvBias::Format::NCHW_WINOGRAD &&
opr->param().output_block_size == 6 &&
param.winograd_matmul_format ==
param::MatrixMul::Format::DEFAULT)) &&
opr->param().mode == param::ConvBias::Mode::CROSS_CORRELATION &&
(param.filter_meta.spatial[0] == param.filter_meta.spatial[1] &&
param.filter_meta.spatial[0] == 3) &&
(param.filter_meta.stride[0] == param.filter_meta.stride[1] &&
param.filter_meta.stride[0] == 1) &&
(param.filter_meta.dilation[0] ==
param.filter_meta.dilation[1] &&
param.filter_meta.dilation[0] == 1) &&
param.compute_mode == param::ConvBias::ComputeMode::DEFAULT &&
param.src_type.enumv() == DTypeEnum::Float16;
}
MIDOUT_END();
return false;
}

size_t ConvBiasImpl::AlgoFP16WinogradF63::get_workspace(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
MEGDNN_MARK_USED_VAR(param);
winograd::winograd_6x3_1x1_f16 strategy(param.src_type, param.filter_type,
param.dst_type);
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp16, 2, 1) {
return megdnn::winograd::ConvBias<winograd::winograd_6x3_1x1_f16>(
strategy, m_tile_size, param.nr_threads, param.osz[0],
param.osz[1], param.filter_meta.ocpg)
.get_workspace_size(param, m_matmul_algo);
}
MIDOUT_END();
return 0;
}

SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoFP16WinogradF63::dispatch_kerns(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
MEGDNN_MARK_USED_VAR(param);
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp16, 2, 2) {
winograd::winograd_6x3_1x1_f16 strategy(
param.src_type, param.filter_type, param.dst_type);
auto winograd_impl =
megdnn::winograd::ConvBias<winograd::winograd_6x3_1x1_f16>(
strategy, m_tile_size, param.nr_threads, param.osz[0],
param.osz[1], param.filter_meta.ocpg);
return winograd_impl.get_kerns(param, m_matmul_algo);
}
MIDOUT_END();
return {};
}

/* ======================= AlgoFP16WinogradF23_8x8 ======================== */

bool ConvBiasImpl::AlgoFP16WinogradF23_8x8::usable(
fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
AlgoSelectionStrategy /*algo_selection_strategy*/) const {
MEGDNN_MARK_USED_VAR(param);
MEGDNN_MARK_USED_VAR(opr);
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp16, 3, 0) {
if (param.filter_meta.icpg % 8 != 0 || param.filter_meta.ocpg % 8 != 0)
return false;
using Strategy = winograd::winograd_2x3_8x8_f16;
Strategy strategy(param.src_type, param.filter_type, param.dst_type);
auto&& matmul_param =
megdnn::winograd::ConvBias<Strategy,
param::MatrixMul::Format::MK8>(
strategy, m_tile_size, param.nr_threads, param.osz[0],
param.osz[1], param.filter_meta.ocpg)
.get_matmul_kern_param(param);
return m_matmul_algo->usable(matmul_param) &&
(opr->param().format == param::ConvBias::Format::NCHW ||
(opr->param().format ==
param::ConvBias::Format::NCHW_WINOGRAD &&
opr->param().output_block_size == 2 &&
param.winograd_matmul_format ==
param::MatrixMul::Format::MK8)) &&
opr->param().mode == param::ConvBias::Mode::CROSS_CORRELATION &&
(param.filter_meta.spatial[0] == param.filter_meta.spatial[1] &&
param.filter_meta.spatial[0] == 3) &&
(param.filter_meta.stride[0] == param.filter_meta.stride[1] &&
param.filter_meta.stride[0] == 1) &&
(param.filter_meta.dilation[0] ==
param.filter_meta.dilation[1] &&
param.filter_meta.dilation[0] == 1) &&
param.compute_mode == param::ConvBias::ComputeMode::DEFAULT &&
param.src_type.enumv() == DTypeEnum::Float16;
}
MIDOUT_END();
return false;
}

size_t ConvBiasImpl::AlgoFP16WinogradF23_8x8::get_workspace(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
MEGDNN_MARK_USED_VAR(param);
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp16, 3, 1) {
winograd::winograd_2x3_8x8_f16 strategy(
param.src_type, param.filter_type, param.dst_type);
return megdnn::winograd::ConvBias<winograd::winograd_2x3_8x8_f16,
param::MatrixMul::Format::MK8>(
strategy, m_tile_size, param.nr_threads, param.osz[0],
param.osz[1], param.filter_meta.ocpg)
.get_workspace_size(param, m_matmul_algo);
}
MIDOUT_END();
return 0;
}

SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoFP16WinogradF23_8x8::dispatch_kerns(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
MEGDNN_MARK_USED_VAR(param);
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 3, 2) {
winograd::winograd_2x3_8x8_f16 strategy(
param.src_type, param.filter_type, param.dst_type);
auto winograd_impl =
megdnn::winograd::ConvBias<winograd::winograd_2x3_8x8_f16,
param::MatrixMul::Format::MK8>(
strategy, m_tile_size, param.nr_threads, param.osz[0],
param.osz[1], param.filter_meta.ocpg);
return winograd_impl.get_kerns(param, m_matmul_algo);
}
MIDOUT_END();
return {};
}

/*========================from Convolution=============================*/

MIDOUT_DECL(megdnn_arm_common_conv_bias_fp16_kimpl)

bool ConvBiasImpl::AlgoF16Direct::usable(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_kimpl, 0, 0) {
auto&& fm = param.filter_meta;
auto FH = fm.spatial[0];
auto SH = fm.stride[0], SW = fm.stride[1];
// the condition ``param.isz[0]*param.isz[1] >= 8'' and
// ``param.osz[0]*param.osz[1] >= 8'' comes from the fact that the
// kernel may have access to up to 8 fp16 after the end of the memory
// chunk.
bool aviliable = fm.format == param::ConvBias::Format::NCHW &&
param.src_type.enumv() == DTypeEnum::Float16 &&
param.filter_type.enumv() == DTypeEnum::Float16 &&
param.dst_type.enumv() == DTypeEnum::Float16 &&
fm.spatial_ndim == 2 && fm.dilation[0] == 1 &&
fm.dilation[1] == 1 &&
param.isz[0] * param.isz[1] >= 8 &&
param.osz[0] * param.osz[1] >= 8 && FH <= 7 &&
SH == 1 && SW == 1;
if (algo_selection_strategy == AlgoSelectionStrategy::HEURISTIC) {
bool large_group = param.filter_meta.group >= param.nr_threads;
aviliable &= (large_group == m_large_group);
}
return aviliable;
}
MIDOUT_END();
return false;
}

size_t ConvBiasImpl::AlgoF16Direct::get_workspace(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_kimpl, 0, 1) {
auto wbundle =
MultithreadDirectConvCommon<dt_float16, __fp16>::get_bundle(
param, m_large_group);
return wbundle.total_size_in_bytes();
}
MIDOUT_END();
return 0;
}

SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF16Direct::get_kimpls(
const NCBKernSizeParam& param) const {
auto fm = param.filter_meta;
size_t N = param.n;
size_t IC = param.filter_meta.icpg;
size_t OC = param.filter_meta.ocpg;
size_t group = fm.group;
WorkspaceBundle wbundle =
MultithreadDirectConvCommon<dt_float16, __fp16>::get_bundle(
param, m_large_group);
SmallVector<NCBKern> ret_kerns;
//! When group >= nr_threads, treat it as large_group, each thread process
//! one group for better performance
if (m_large_group) {
//! Channel wise conv and big groups
auto exec_one_group = [wbundle](const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
auto fm = kern_param.filter_meta;
size_t IC = fm.icpg;
size_t OC = fm.ocpg;
WorkspaceBundle bundle = wbundle;
if (fm.should_flip) {
for (size_t oc = 0; oc < OC; oc++) {
MultithreadDirectConvCommon<dt_float16, __fp16>::
weight_flip_kern(bundle, kern_param, ncb_index,
{ncb_index.thread_id, 0, oc});
}
}
for (size_t ic = 0; ic < IC; ic++) {
MultithreadDirectConvCommon<dt_float16, __fp16>::
copy_padding_kern(bundle, kern_param, ncb_index,
{ncb_index.thread_id, 0, ic});
}
for (size_t oc = 0; oc < OC; oc++) {
MultithreadDirectConvCommon<dt_float16, __fp16>::do_conv_kern(
bundle, kern_param, ncb_index,
fp16::conv_bias::kern_direct_f16,
{ncb_index.thread_id, 0, oc});
}
};
ret_kerns.push_back({exec_one_group, {group, N, 1_z}});
} else {
WorkspaceBundle bundle = wbundle;
if (fm.should_flip) {
auto weight_flip = [bundle](const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
MultithreadDirectConvCommon<dt_float16, __fp16>::
weight_flip_kern(bundle, kern_param, ncb_index,
ncb_index.ndrange_id);
};
ret_kerns.push_back({weight_flip, {group, 1_z, OC}});
}
auto copy_padding = [bundle](const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
MultithreadDirectConvCommon<dt_float16, __fp16>::copy_padding_kern(
bundle, kern_param, ncb_index, ncb_index.ndrange_id);
};
ret_kerns.push_back({copy_padding, {group, N, IC}});
auto do_conv = [bundle](const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
MultithreadDirectConvCommon<dt_float16, __fp16>::do_conv_kern(
bundle, kern_param, ncb_index,
fp16::conv_bias::kern_direct_f16, ncb_index.ndrange_id);
};
ret_kerns.push_back({do_conv, {group, N, OC}});
}
return ret_kerns;
}

SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF16Direct::dispatch_kerns(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_kimpl, 0, 1) {
return get_kimpls(param);
}
MIDOUT_END();
return {};
}

/* ===================== stride-1 algo ===================== */

bool ConvBiasImpl::AlgoF16DirectStride1::usable(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_kimpl, 1, 0) {
auto&& fm = param.filter_meta;
auto FH = fm.spatial[0];
bool aviliable =
param.filter_meta.format == param::ConvBias::Format::NCHW &&
param.src_type.enumv() == DTypeEnum::Float16 &&
param.filter_type.enumv() == DTypeEnum::Float16 &&
param.dst_type.enumv() == DTypeEnum::Float16 &&
!fm.should_flip && fm.spatial_ndim == 2 &&
fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
fm.stride[0] == 1 && fm.stride[1] == 1 && FH == fm.spatial[1] &&
(FH == 2 || FH == 3 || FH == 5);
if (algo_selection_strategy ==
ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) {
bool large_group = param.filter_meta.group >= param.nr_threads;
aviliable &= (large_group == m_large_group);
}
return aviliable;
}
MIDOUT_END();
return false;
}

SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoF16DirectStride1::get_kimpls(
const NCBKernSizeParam& param) const {
auto fm = param.filter_meta;
auto FH = fm.spatial[0];
size_t N = param.n;
size_t IC = param.filter_meta.icpg;
size_t OC = param.filter_meta.ocpg;
size_t group = fm.group;
using Func = std::function<void(const __fp16*, const __fp16*, __fp16*,
size_t, size_t, size_t, size_t, size_t)>;
Func conv_kern_function = nullptr;

#define SWITCH_KERN() \
switch (FH) { \
case 2: \
conv_kern_function = fp16::conv_stride1::do_conv_2x2_stride1; \
break; \
case 3: \
conv_kern_function = fp16::conv_stride1::do_conv_3x3_stride1; \
break; \
case 5: \
conv_kern_function = fp16::conv_stride1::do_conv_5x5_stride1; \
break; \
}
SWITCH_KERN();

WorkspaceBundle wbundle =
MultithreadDirectConvCommon<dt_float16, __fp16>::get_bundle_stride(
param, m_large_group);
SmallVector<NCBKern> ret_kerns;
//! When group >= nr_threads, treat it as large_group, each thread process
//! one group for better performance
if (m_large_group) {
//! Channel wise conv and big groups
auto exec_one_group = [wbundle, conv_kern_function](
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
auto fm = kern_param.filter_meta;
size_t IC = fm.icpg;
size_t OC = fm.ocpg;
WorkspaceBundle bundle = wbundle;
for (size_t ic = 0; ic < IC; ic++) {
MultithreadDirectConvCommon<dt_float16, __fp16>::
copy_padding_kern_stride(bundle, kern_param, ncb_index,
{ncb_index.thread_id, 0, ic});
}
for (size_t oc = 0; oc < OC; oc++) {
MultithreadDirectConvCommon<dt_float16, __fp16>::
do_conv_kern_stride(bundle, kern_param, ncb_index,
conv_kern_function,
{ncb_index.thread_id, 0, oc});
}
};
ret_kerns.push_back({exec_one_group, {group, N, 1_z}});
} else {
WorkspaceBundle bundle = wbundle;
auto copy_padding = [bundle](const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
MultithreadDirectConvCommon<dt_float16, __fp16>::
copy_padding_kern_stride(bundle, kern_param, ncb_index,
ncb_index.ndrange_id);
};
ret_kerns.push_back({copy_padding, {group, N, IC}});
auto do_conv = [bundle, conv_kern_function](
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
MultithreadDirectConvCommon<dt_float16, __fp16>::
do_conv_kern_stride(bundle, kern_param, ncb_index,
conv_kern_function,
ncb_index.ndrange_id);
};
ret_kerns.push_back({do_conv, {group, N, OC}});
}
return ret_kerns;
}

size_t ConvBiasImpl::AlgoF16DirectStride1::get_workspace(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_kimpl, 1, 1) {
auto bundle = MultithreadDirectConvCommon<
dt_float16, __fp16>::get_bundle_stride(param, m_large_group);
return bundle.total_size_in_bytes();
}
MIDOUT_END();
return 0;
}

SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoF16DirectStride1::dispatch_kerns(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_kimpl, 1, 2) {
return get_kimpls(param);
}
MIDOUT_END();
return {};
}

#endif
// vim: syntax=cpp.doxygen

+ 185
- 0
dnn/src/arm_common/conv_bias/f16/algos.h View File

@@ -0,0 +1,185 @@
/**
* \file dnn/src/arm_common/conv_bias/f16/algos.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#pragma once

#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/fallback/matrix_mul/opr_impl.h"
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
namespace megdnn {
namespace arm_common {

class ConvBiasImpl::AlgoFP16WinogradF23 final : public AlgoBase {
public:
AlgoFP16WinogradF23(fallback::MatrixMulImpl::AlgoBase* matmul_algo,
uint32_t tile_size)
: m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {}
bool is_reproducible() const override { return true; }
const char* name() const override {
if (m_name.empty()) {
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
m_matmul_algo->name(), {1, 2, m_tile_size});
}
return m_name.c_str();
}
bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
size_t get_workspace(fallback::ConvBiasImpl*,
const NCBKernSizeParam& param) const override;
virtual SmallVector<NCBKern> dispatch_kerns(
fallback::ConvBiasImpl* opr,
const NCBKernSizeParam& param) const override;

static std::vector<fallback::MatrixMulImpl::Algorithm*>
get_avaiable_matmul_algos(const NCBKernSizeParam& param);

private:
fallback::MatrixMulImpl::AlgoBase* m_matmul_algo;
mutable std::string m_name;

uint32_t m_tile_size;
};

class ConvBiasImpl::AlgoFP16WinogradF45 final : public AlgoBase {
public:
AlgoFP16WinogradF45(fallback::MatrixMulImpl::AlgoBase* matmul_algo,
uint32_t tile_size)
: m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {}
bool is_reproducible() const override { return true; }
const char* name() const override {
if (m_name.empty()) {
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
m_matmul_algo->name(), {1, 4, m_tile_size});
}
return m_name.c_str();
}
bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
size_t get_workspace(fallback::ConvBiasImpl*,
const NCBKernSizeParam& param) const override;

virtual SmallVector<NCBKern> dispatch_kerns(
fallback::ConvBiasImpl* opr,
const NCBKernSizeParam& param) const override;

static std::vector<fallback::MatrixMulImpl::Algorithm*>
get_avaiable_matmul_algos(const NCBKernSizeParam& param);

private:
fallback::MatrixMulImpl::AlgoBase* m_matmul_algo;
mutable std::string m_name;

uint32_t m_tile_size;
};
class ConvBiasImpl::AlgoFP16WinogradF63 final : public AlgoBase {
public:
AlgoFP16WinogradF63(fallback::MatrixMulImpl::AlgoBase* matmul_algo,
uint32_t tile_size)
: m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {}
bool is_reproducible() const override { return true; }
const char* name() const override {
if (m_name.empty()) {
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
m_matmul_algo->name(), {1, 6, m_tile_size});
}
return m_name.c_str();
}

bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
size_t get_workspace(fallback::ConvBiasImpl*,
const NCBKernSizeParam& param) const override;
virtual SmallVector<NCBKern> dispatch_kerns(
fallback::ConvBiasImpl* opr,
const NCBKernSizeParam& param) const override;

static std::vector<fallback::MatrixMulImpl::Algorithm*>
get_avaiable_matmul_algos(const NCBKernSizeParam& param);

private:
fallback::MatrixMulImpl::AlgoBase* m_matmul_algo;
mutable std::string m_name;

uint32_t m_tile_size;
};
class ConvBiasImpl::AlgoFP16WinogradF23_8x8 final : public AlgoBase {
public:
AlgoFP16WinogradF23_8x8(fallback::MatrixMulImpl::AlgoBase* matmul_algo,
uint32_t tile_size)
: m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {}
bool is_reproducible() const override { return true; }
const char* name() const override {
if (m_name.empty()) {
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
m_matmul_algo->name(), {8, 2, m_tile_size});
}
return m_name.c_str();
}
bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
size_t get_workspace(fallback::ConvBiasImpl*,
const NCBKernSizeParam& param) const override;

virtual SmallVector<NCBKern> dispatch_kerns(
fallback::ConvBiasImpl* opr,
const NCBKernSizeParam& param) const override;

private:
fallback::MatrixMulImpl::AlgoBase* m_matmul_algo;
mutable std::string m_name;
uint32_t m_tile_size;
};

class ConvBiasImpl::AlgoF16Direct final : public AlgoBase {
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
bool m_large_group;

public:
AlgoF16Direct(bool is_large_group) : m_large_group{is_large_group} {}
bool is_reproducible() const override { return true; }
const char* name() const override {
return m_large_group ? "F16DIRECT_LARGE_GROUP"
: "F16DIRECT_SMALL_GROUP";
}
bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;

size_t get_workspace(fallback::ConvBiasImpl* opr,
const NCBKernSizeParam& param) const override;

virtual SmallVector<NCBKern> dispatch_kerns(
fallback::ConvBiasImpl* opr,
const NCBKernSizeParam& param) const override;
};

class ConvBiasImpl::AlgoF16DirectStride1 final : public AlgoBase {
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
bool m_large_group;

public:
AlgoF16DirectStride1(bool is_large_group) : m_large_group{is_large_group} {}
bool is_reproducible() const override { return true; }
const char* name() const override {
return m_large_group ? "F16STRD1_LARGE_GROUP" : "F16STRD1_SMALL_GROUP";
}
bool usable(fallback::ConvBiasImpl*, const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
size_t get_workspace(fallback::ConvBiasImpl*,
const NCBKernSizeParam& param) const override;
virtual SmallVector<NCBKern> dispatch_kerns(
fallback::ConvBiasImpl* opr,
const NCBKernSizeParam& param) const override;
};

} // namespace arm_common
} // namespace megdnn
#endif
// vim: syntax=cpp.doxygen

+ 799
- 0
dnn/src/arm_common/conv_bias/f16/direct.cpp View File

@@ -0,0 +1,799 @@
/**
* \file dnn/src/arm_common/conv_bias/f16/direct.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "./direct.h"
#include "include/megdnn/oprs.h"
#include "midout.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#include <cstring>
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/simd_macro/neon_helper.h"

MIDOUT_DECL(megdnn_arm_conv_f16)

using namespace megdnn;
using namespace arm_common;
using namespace fp16;
using namespace conv_bias;
namespace {

#define BLOCK_H 8

#define LOAD_RESULT_VAL \
if (width < 8) { \
auto load_less_8 = [](__fp16* dst, float16x8_t& data) { \
if (width == 1u) { \
data = vld1q_lane_f16(dst, data, 0); \
} else if (width == 2u) { \
data = vld1q_lane_f16(dst + 0, data, 0); \
data = vld1q_lane_f16(dst + 1, data, 1); \
} else if (width == 3u) { \
data = vld1q_lane_f16(dst + 0, data, 0); \
data = vld1q_lane_f16(dst + 1, data, 1); \
data = vld1q_lane_f16(dst + 2, data, 2); \
} else if (width == 4u) { \
data = vld1q_lane_u64(dst, data, 0); \
} else if (width == 5u) { \
data = vld1q_lane_u64(dst, data, 0); \
data = vld1q_lane_f16(dst + 4, data, 4); \
} else if (width == 6u) { \
data = vld1q_lane_u64(dst, data, 0); \
data = vld1q_lane_f16(dst + 4, data, 4); \
data = vld1q_lane_f16(dst + 5, data, 5); \
} else if (width == 7u) { \
data = vld1q_lane_u64(dst, data, 0); \
data = vld1q_lane_f16(dst + 4, data, 4); \
data = vld1q_lane_f16(dst + 5, data, 5); \
data = vld1q_lane_f16(dst + 6, data, 6); \
} \
}; \
if (height >= 1) \
load_less_8(dst + 0 * OW, out0); \
if (height >= 2) \
load_less_8(dst + 1 * OW, out1); \
if (height >= 3) \
load_less_8(dst + 2 * OW, out2); \
if (height >= 4) \
load_less_8(dst + 3 * OW, out3); \
if (height >= 5) \
load_less_8(dst + 4 * OW, out4); \
if (height >= 6) \
load_less_8(dst + 5 * OW, out5); \
if (height >= 7) \
load_less_8(dst + 6 * OW, out6); \
if (height >= 8) \
load_less_8(dst + 7 * OW, out7); \
} else { \
if (height >= 1) \
out0 = vld1q_f16(dst + 0 * OW); \
if (height >= 2) \
out1 = vld1q_f16(dst + 1 * OW); \
if (height >= 3) \
out2 = vld1q_f16(dst + 2 * OW); \
if (height >= 4) \
out3 = vld1q_f16(dst + 3 * OW); \
if (height >= 5) \
out4 = vld1q_f16(dst + 4 * OW); \
if (height >= 6) \
out5 = vld1q_f16(dst + 5 * OW); \
if (height >= 7) \
out6 = vld1q_f16(dst + 6 * OW); \
if (height >= 8) \
out7 = vld1q_f16(dst + 7 * OW); \
}

#define STORE_RESULT_VAL \
if (width < 8) { \
auto store_less_8 = [](__fp16* dst, float16x8_t& data) { \
if (width == 1u) { \
vst1q_lane_f16(dst, data, 0); \
} else if (width == 2u) { \
vst1q_lane_f16(dst + 0, data, 0); \
vst1q_lane_f16(dst + 1, data, 1); \
} else if (width == 3u) { \
vst1q_lane_f16(dst + 0, data, 0); \
vst1q_lane_f16(dst + 1, data, 1); \
vst1q_lane_f16(dst + 2, data, 2); \
} else if (width == 4u) { \
vst1_f16(dst, vget_low_f16(data)); \
} else if (width == 5u) { \
vst1_f16(dst, vget_low_f16(data)); \
vst1q_lane_f16(dst + 4, data, 4); \
} else if (width == 6u) { \
vst1_f16(dst, vget_low_f16(data)); \
vst1q_lane_f16(dst + 4, data, 4); \
vst1q_lane_f16(dst + 5, data, 5); \
} else if (width == 7u) { \
vst1_f16(dst, vget_low_f16(data)); \
vst1q_lane_f16(dst + 4, data, 4); \
vst1q_lane_f16(dst + 5, data, 5); \
vst1q_lane_f16(dst + 6, data, 6); \
} \
}; \
if (height >= 1) \
store_less_8(dst + 0 * OW, out0); \
if (height >= 2) \
store_less_8(dst + 1 * OW, out1); \
if (height >= 3) \
store_less_8(dst + 2 * OW, out2); \
if (height >= 4) \
store_less_8(dst + 3 * OW, out3); \
if (height >= 5) \
store_less_8(dst + 4 * OW, out4); \
if (height >= 6) \
store_less_8(dst + 5 * OW, out5); \
if (height >= 7) \
store_less_8(dst + 6 * OW, out6); \
if (height >= 8) \
store_less_8(dst + 7 * OW, out7); \
} else { \
if (height >= 1) \
vst1q_f16(dst + 0 * OW, out0); \
if (height >= 2) \
vst1q_f16(dst + 1 * OW, out1); \
if (height >= 3) \
vst1q_f16(dst + 2 * OW, out2); \
if (height >= 4) \
vst1q_f16(dst + 3 * OW, out3); \
if (height >= 5) \
vst1q_f16(dst + 4 * OW, out4); \
if (height >= 6) \
vst1q_f16(dst + 5 * OW, out5); \
if (height >= 7) \
vst1q_f16(dst + 6 * OW, out6); \
if (height >= 8) \
vst1q_f16(dst + 7 * OW, out7); \
}

template <int FH, int height, int width>
struct do_pixel_proxy {
static void exec(const __fp16* src, const __fp16* filter, __fp16* dst,
const int IH, const int IW, const int OH, const int OW,
const int FW, const int oh, const int ow);
};

template <int height, int width>
struct do_pixel_proxy<1, height, width> {
static void exec(const __fp16* src, const __fp16* filter, __fp16* dst,
const int IH, const int IW, const int OH, const int OW,
const int FW, const int oh, const int ow) {
MEGDNN_MARK_USED_VAR(IH);
MEGDNN_MARK_USED_VAR(OH);
const int ih = oh, iw = ow;
#define cb(i) float16x8_t out##i{0};
UNROLL_CALL_NOWRAPPER(BLOCK_H, cb);
#undef cb
float16x8_t kr0, inp;
src += ih * IW + iw;
dst += oh * OW + ow;
LOAD_RESULT_VAL;
for (int fw = 0; fw < FW; ++fw) {
const __fp16* src_dd = src + fw;
kr0 = vdupq_n_f16(filter[0 * FW + fw]);

#define cb(i) \
if (height > i) { \
inp = vld1q_f16(src_dd + i * IW); \
out##i = vmlaq_f16(out##i, inp, kr0); \
}
UNROLL_CALL_NOWRAPPER(BLOCK_H, cb);

#undef cb
}
STORE_RESULT_VAL;
}
};

template <int height, int width>
struct do_pixel_proxy<2, height, width> {
static void exec(const __fp16* src, const __fp16* filter, __fp16* dst,
const int IH, const int IW, const int OH, const int OW,
const int FW, const int oh, const int ow) {
MEGDNN_MARK_USED_VAR(IH);
MEGDNN_MARK_USED_VAR(OH);
const int ih = oh, iw = ow;
#define cb(i) float16x8_t out##i{0};
UNROLL_CALL_NOWRAPPER(BLOCK_H, cb);
#undef cb
float16x8_t kr0, kr1, inp;
src += ih * IW + iw;
dst += oh * OW + ow;
LOAD_RESULT_VAL;
for (int fw = 0; fw < FW; ++fw) {
const __fp16* src_dd = src + fw;
kr0 = vdupq_n_f16(filter[0 * FW + fw]);
kr1 = vdupq_n_f16(filter[1 * FW + fw]);

#define cb(i) \
if (height > i) { \
inp = vld1q_f16(src_dd + i * IW); \
out##i = vmlaq_f16(out##i, inp, kr0); \
inp = vld1q_f16(src_dd + (i + 1) * IW); \
out##i = vmlaq_f16(out##i, inp, kr1); \
}
UNROLL_CALL_NOWRAPPER(BLOCK_H, cb);
#undef cb
}
STORE_RESULT_VAL;
}
};

template <int height, int width>
struct do_pixel_proxy<3, height, width> {
static void exec(const __fp16* src, const __fp16* filter, __fp16* dst,
const int IH, const int IW, const int OH, const int OW,
const int FW, const int oh, const int ow) {
MEGDNN_MARK_USED_VAR(IH);
MEGDNN_MARK_USED_VAR(OH);
const int ih = oh, iw = ow;
#define cb(i) float16x8_t out##i{0};
UNROLL_CALL_NOWRAPPER(BLOCK_H, cb);
#undef cb
float16x8_t kr0, kr1, kr2, inp;
src += ih * IW + iw;
dst += oh * OW + ow;
LOAD_RESULT_VAL;
for (int fw = 0; fw < FW; ++fw) {
const __fp16* src_dd = src + fw;
kr0 = vdupq_n_f16(filter[0 * FW + fw]);
kr1 = vdupq_n_f16(filter[1 * FW + fw]);
kr2 = vdupq_n_f16(filter[2 * FW + fw]);
#define cb(i) \
if (height > i) { \
inp = vld1q_f16(src_dd + i * IW); \
out##i = vmlaq_f16(out##i, inp, kr0); \
inp = vld1q_f16(src_dd + (i + 1) * IW); \
out##i = vmlaq_f16(out##i, inp, kr1); \
inp = vld1q_f16(src_dd + (i + 2) * IW); \
out##i = vmlaq_f16(out##i, inp, kr2); \
}
UNROLL_CALL_NOWRAPPER(BLOCK_H, cb);

#undef cb
}
STORE_RESULT_VAL;
}
};

template <int height, int width>
struct do_pixel_proxy<4, height, width> {
static void exec(const __fp16* src, const __fp16* filter, __fp16* dst,
const int IH, const int IW, const int OH, const int OW,
const int FW, const int oh, const int ow) {
MEGDNN_MARK_USED_VAR(IH);
MEGDNN_MARK_USED_VAR(OH);
const int ih = oh, iw = ow;
#define cb(i) float16x8_t out##i{0};
UNROLL_CALL_NOWRAPPER(BLOCK_H, cb);
#undef cb
float16x8_t kr0, kr1, kr2, kr3, inp;
src += ih * IW + iw;
dst += oh * OW + ow;
LOAD_RESULT_VAL;
for (int fw = 0; fw < FW; ++fw) {
const __fp16* src_dd = src + fw;
kr0 = vdupq_n_f16(filter[0 * FW + fw]);
kr1 = vdupq_n_f16(filter[1 * FW + fw]);
kr2 = vdupq_n_f16(filter[2 * FW + fw]);
kr3 = vdupq_n_f16(filter[3 * FW + fw]);
#define cb(i) \
if (height > i) { \
inp = vld1q_f16(src_dd + i * IW); \
out##i = vmlaq_f16(out##i, inp, kr0); \
inp = vld1q_f16(src_dd + (i + 1) * IW); \
out##i = vmlaq_f16(out##i, inp, kr1); \
inp = vld1q_f16(src_dd + (i + 2) * IW); \
out##i = vmlaq_f16(out##i, inp, kr2); \
inp = vld1q_f16(src_dd + (i + 3) * IW); \
out##i = vmlaq_f16(out##i, inp, kr3); \
}
UNROLL_CALL_NOWRAPPER(BLOCK_H, cb);

#undef cb
}
STORE_RESULT_VAL;
}
};

template <int height, int width>
struct do_pixel_proxy<5, height, width> {
static void exec(const __fp16* src, const __fp16* filter, __fp16* dst,
const int IH, const int IW, const int OH, const int OW,
const int FW, const int oh, const int ow) {
MEGDNN_MARK_USED_VAR(IH);
MEGDNN_MARK_USED_VAR(OH);
const int ih = oh, iw = ow;
#define cb(i) float16x8_t out##i{0};
UNROLL_CALL_NOWRAPPER(BLOCK_H, cb);
#undef cb
float16x8_t kr0, kr1, kr2, kr3, kr4, inp;
src += ih * IW + iw;
dst += oh * OW + ow;
LOAD_RESULT_VAL;
for (int fw = 0; fw < FW; ++fw) {
const __fp16* src_dd = src + fw;
kr0 = vdupq_n_f16(filter[0 * FW + fw]);
kr1 = vdupq_n_f16(filter[1 * FW + fw]);
kr2 = vdupq_n_f16(filter[2 * FW + fw]);
kr3 = vdupq_n_f16(filter[3 * FW + fw]);
kr4 = vdupq_n_f16(filter[4 * FW + fw]);
#define cb(i) \
if (height > i) { \
inp = vld1q_f16(src_dd + i * IW); \
out##i = vmlaq_f16(out##i, inp, kr0); \
inp = vld1q_f16(src_dd + (i + 1) * IW); \
out##i = vmlaq_f16(out##i, inp, kr1); \
inp = vld1q_f16(src_dd + (i + 2) * IW); \
out##i = vmlaq_f16(out##i, inp, kr2); \
inp = vld1q_f16(src_dd + (i + 3) * IW); \
out##i = vmlaq_f16(out##i, inp, kr3); \
inp = vld1q_f16(src_dd + (i + 4) * IW); \
out##i = vmlaq_f16(out##i, inp, kr4); \
}
UNROLL_CALL_NOWRAPPER(BLOCK_H, cb);
#undef cb
}
STORE_RESULT_VAL;
}
};

template <int height, int width>
struct do_pixel_proxy<6, height, width> {
static void exec(const __fp16* src, const __fp16* filter, __fp16* dst,
const int IH, const int IW, const int OH, const int OW,
const int FW, const int oh, const int ow) {
MEGDNN_MARK_USED_VAR(IH);
MEGDNN_MARK_USED_VAR(OH);
const int ih = oh, iw = ow;
#define cb(i) float16x8_t out##i{0};
UNROLL_CALL_NOWRAPPER(BLOCK_H, cb);
#undef cb
float16x8_t kr0, kr1, kr2, kr3, kr4, kr5, inp;
src += ih * IW + iw;
dst += oh * OW + ow;
LOAD_RESULT_VAL;
for (int fw = 0; fw < FW; ++fw) {
const __fp16* src_dd = src + fw;
kr0 = vdupq_n_f16(filter[0 * FW + fw]);
kr1 = vdupq_n_f16(filter[1 * FW + fw]);
kr2 = vdupq_n_f16(filter[2 * FW + fw]);
kr3 = vdupq_n_f16(filter[3 * FW + fw]);
kr4 = vdupq_n_f16(filter[4 * FW + fw]);
kr5 = vdupq_n_f16(filter[5 * FW + fw]);
#define cb(i) \
if (height > i) { \
inp = vld1q_f16(src_dd + i * IW); \
out##i = vmlaq_f16(out##i, inp, kr0); \
inp = vld1q_f16(src_dd + (i + 1) * IW); \
out##i = vmlaq_f16(out##i, inp, kr1); \
inp = vld1q_f16(src_dd + (i + 2) * IW); \
out##i = vmlaq_f16(out##i, inp, kr2); \
inp = vld1q_f16(src_dd + (i + 3) * IW); \
out##i = vmlaq_f16(out##i, inp, kr3); \
inp = vld1q_f16(src_dd + (i + 4) * IW); \
out##i = vmlaq_f16(out##i, inp, kr4); \
inp = vld1q_f16(src_dd + (i + 5) * IW); \
out##i = vmlaq_f16(out##i, inp, kr5); \
}
UNROLL_CALL_NOWRAPPER(BLOCK_H, cb);
#undef cb
}
STORE_RESULT_VAL;
}
};

template <int height, int width>
struct do_pixel_proxy<7, height, width> {
static void exec(const __fp16* src, const __fp16* filter, __fp16* dst,
const int IH, const int IW, const int OH, const int OW,
const int FW, const int oh, const int ow) {
MEGDNN_MARK_USED_VAR(IH);
MEGDNN_MARK_USED_VAR(OH);
const int ih = oh, iw = ow;
#define cb(i) float16x8_t out##i{0};
UNROLL_CALL_NOWRAPPER(BLOCK_H, cb);
#undef cb
float16x8_t kr0, kr1, kr2, kr3, kr4, kr5, kr6, inp;
src += ih * IW + iw;
dst += oh * OW + ow;
LOAD_RESULT_VAL;
for (int fw = 0; fw < FW; ++fw) {
const __fp16* src_dd = src + fw;
kr0 = vdupq_n_f16(filter[0 * FW + fw]);
kr1 = vdupq_n_f16(filter[1 * FW + fw]);
kr2 = vdupq_n_f16(filter[2 * FW + fw]);
kr3 = vdupq_n_f16(filter[3 * FW + fw]);
kr4 = vdupq_n_f16(filter[4 * FW + fw]);
kr5 = vdupq_n_f16(filter[5 * FW + fw]);
kr6 = vdupq_n_f16(filter[6 * FW + fw]);
#define cb(i) \
if (height > i) { \
inp = vld1q_f16(src_dd + i * IW); \
out##i = vmlaq_f16(out##i, inp, kr0); \
inp = vld1q_f16(src_dd + (i + 1) * IW); \
out##i = vmlaq_f16(out##i, inp, kr1); \
inp = vld1q_f16(src_dd + (i + 2) * IW); \
out##i = vmlaq_f16(out##i, inp, kr2); \
inp = vld1q_f16(src_dd + (i + 3) * IW); \
out##i = vmlaq_f16(out##i, inp, kr3); \
inp = vld1q_f16(src_dd + (i + 4) * IW); \
out##i = vmlaq_f16(out##i, inp, kr4); \
inp = vld1q_f16(src_dd + (i + 5) * IW); \
out##i = vmlaq_f16(out##i, inp, kr5); \
inp = vld1q_f16(src_dd + (i + 6) * IW); \
out##i = vmlaq_f16(out##i, inp, kr6); \
}
UNROLL_CALL_NOWRAPPER(BLOCK_H, cb);
#undef cb
}
STORE_RESULT_VAL;
}
};

#undef STORE_RESULT_VAL
#undef LOAD_RESULT_VAL

template <int FH, int height, int width>
void do_pixel(const __fp16* src, const __fp16* filter, __fp16* dst,
const int IH, const int IW, const int OH, const int OW,
const int FW, const int oh, const int ow) {
do_pixel_proxy<FH, height, width>::exec(src, filter, dst, IH, IW, OH, OW,
FW, oh, ow);
}

template <int FH>
void do_conv_tpl_enable_prefetch(const __fp16* src,
const __fp16* filter, __fp16* dst,
const int IH, const int IW, const int OH,
const int OW, const int FW) {
const int hbeg = 0, hend = OH;
const int wbeg = 0, wend = OW;
int i, j;
for (i = hbeg; i + BLOCK_H <= hend; i += BLOCK_H) {
for (j = wbeg; j + 8 <= wend; j += 8) {
// do prefetch
const int prefetch_index_input =
(j + 16) < wend
? i * IW + j + 16
: (i + 8) * IW + (((j + 16 - wend) >> 2) << 2);
const int prefetch_index_output =
(j + 16) < wend
? i * OW + j + 16
: (i + 8) * OW + (((j + 16 - wend) >> 2) << 2);
const __fp16* src_prefetch = src + prefetch_index_input;
const __fp16* dst_prefetch = dst + prefetch_index_output;
for (int iw_id = 0; iw_id < FH + 3; ++iw_id) {
__builtin_prefetch(src_prefetch + iw_id * IW, 0, 3);
}
#define unroll_prefetch_cb(i) __builtin_prefetch(dst_prefetch + i * OW, 1, 3);
UNROLL_CALL_NOWRAPPER(BLOCK_H, unroll_prefetch_cb);
do_pixel<FH, BLOCK_H, 8>(src, filter, dst, IH, IW, OH, OW, FW, i,
j);
}
#define DISPATCH(width) \
do { \
const int prefetch_index_input = (i + 8) * IW + 12; \
const int prefetch_index_output = (i + 8) * OW + 12; \
const __fp16* src_prefetch = src + prefetch_index_input; \
const __fp16* dst_prefetch = dst + prefetch_index_output; \
for (int iw_id = 0; iw_id < FH + 3; ++iw_id) { \
__builtin_prefetch(src_prefetch + iw_id * IW, 0, 3); \
} \
UNROLL_CALL_NOWRAPPER(BLOCK_H, unroll_prefetch_cb); \
do_pixel<FH, BLOCK_H, width>(src, filter, dst, IH, IW, OH, OW, FW, i, \
j); \
} while (0)
switch (wend - j) {
case 1:
DISPATCH(1);
break;
case 2:
DISPATCH(2);
break;
case 3:
DISPATCH(3);
break;
case 4:
DISPATCH(4);
break;
case 5:
DISPATCH(5);
break;
case 6:
DISPATCH(6);
break;
case 7:
DISPATCH(7);
break;
}
#undef DISPATCH
}

#define DISPATCH2(height, width) \
do { \
const int prefetch_index_input = IH * IW + 12; \
const __fp16* src_prefetch = src + prefetch_index_input; \
for (int iw_id = 0; iw_id < FH + 3; ++iw_id) { \
__builtin_prefetch(src_prefetch + iw_id * IW, 0, 3); \
} \
do_pixel<FH, height, width>(src, filter, dst, IH, IW, OH, OW, FW, i, \
j); \
} while (0)

#define DISPATCH1(height) \
do { \
for (j = wbeg; j + 8 <= wend; j += 8) { \
const int prefetch_index_input = \
(j + 16) < wend \
? i * IW + j + 16 \
: (i + 8) * IW + (((j + 16 - wend) >> 2) << 2); \
const int prefetch_index_output = \
(j + 16) < wend \
? i * OW + j + 16 \
: (i + 8) * OW + (((j + 16 - wend) >> 2) << 2); \
const __fp16* src_prefetch = src + prefetch_index_input; \
const __fp16* dst_prefetch = dst + prefetch_index_output; \
for (int iw_id = 0; iw_id < FH + 3; ++iw_id) { \
__builtin_prefetch(src_prefetch + iw_id * IW, 0, 3); \
} \
UNROLL_CALL_NOWRAPPER(BLOCK_H, unroll_prefetch_cb); \
do_pixel<FH, height, 8>(src, filter, dst, IH, IW, OH, OW, FW, i, \
j); \
} \
switch (wend - j) { \
case 1: \
DISPATCH2(height, 1); \
break; \
case 2: \
DISPATCH2(height, 2); \
break; \
case 3: \
DISPATCH2(height, 3); \
break; \
case 4: \
DISPATCH2(height, 4); \
break; \
case 5: \
DISPATCH2(height, 5); \
break; \
case 6: \
DISPATCH2(height, 6); \
break; \
case 7: \
DISPATCH2(height, 7); \
break; \
} \
} while (0)
switch (hend - i) {
case 1:
DISPATCH1(1);
break;
case 2:
DISPATCH1(2);
break;
case 3:
DISPATCH1(3);
break;
#if BLOCK_H == 8
case 4:
DISPATCH1(4);
break;
case 5:
DISPATCH1(5);
break;
case 6:
DISPATCH1(6);
break;
case 7:
DISPATCH1(7);
break;
#endif
}
#undef DISPATCH1
#undef DISPATCH2
#undef unroll_prefetch_cb
}
template <int FH>
void do_conv_tpl_disable_prefetch(const __fp16* src,
const __fp16* filter, __fp16* dst,
const int IH, const int IW, const int OH,
const int OW, const int FW) {
const int hbeg = 0, hend = OH;
const int wbeg = 0, wend = OW;
int i, j;
for (i = hbeg; i + BLOCK_H <= hend; i += BLOCK_H) {
for (j = wbeg; j + 8 <= wend; j += 8) {
do_pixel<FH, BLOCK_H, 8>(src, filter, dst, IH, IW, OH, OW, FW, i,
j);
}
#define DISPATCH(width) \
do { \
do_pixel<FH, BLOCK_H, width>(src, filter, dst, IH, IW, OH, OW, FW, i, \
j); \
} while (0)
switch (wend - j) {
case 1:
DISPATCH(1);
break;
case 2:
DISPATCH(2);
break;
case 3:
DISPATCH(3);
break;
case 4:
DISPATCH(4);
break;
case 5:
DISPATCH(5);
break;
case 6:
DISPATCH(6);
break;
case 7:
DISPATCH(7);
break;
}
#undef DISPATCH
}
#define DISPATCH2(height, width) \
do { \
do_pixel<FH, height, width>(src, filter, dst, IH, IW, OH, OW, FW, i, \
j); \
} while (0)
#define DISPATCH1(height) \
do { \
for (j = wbeg; j + 8 <= wend; j += 8) { \
do_pixel<FH, height, 8>(src, filter, dst, IH, IW, OH, OW, FW, i, \
j); \
} \
switch (wend - j) { \
case 1: \
DISPATCH2(height, 1); \
break; \
case 2: \
DISPATCH2(height, 2); \
break; \
case 3: \
DISPATCH2(height, 3); \
break; \
case 4: \
DISPATCH2(height, 4); \
break; \
case 5: \
DISPATCH2(height, 5); \
break; \
case 6: \
DISPATCH2(height, 6); \
break; \
case 7: \
DISPATCH2(height, 7); \
break; \
} \
} while (0)
switch (hend - i) {
case 1:
DISPATCH1(1);
break;
case 2:
DISPATCH1(2);
break;
case 3:
DISPATCH1(3);
break;
#if BLOCK_H == 8
case 4:
DISPATCH1(4);
break;
case 5:
DISPATCH1(5);
break;
case 6:
DISPATCH1(6);
break;
case 7:
DISPATCH1(7);
break;
#endif
}
#undef DISPATCH1
#undef DISPATCH2
}
} // anonymous namespace

void conv_bias::kern_direct_f16(const __fp16* src,
const __fp16* filter, __fp16* dst,
const int IH, const int IW, const int OH,
const int OW, const int FH, const int FW) {
megdnn_assert_internal(FH <= 7);
if (IH > 100 && IW > 100) {
#define GAO(FH) \
do { \
return do_conv_tpl_enable_prefetch<FH>(src, filter, dst, IH, IW, OH, \
OW, FW); \
} while (0)
switch (FH) {
case 1:
MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(0)) { GAO(1); }
MIDOUT_END();
break;
case 2:
MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(1)) { GAO(2); }
MIDOUT_END();
break;
case 3:
MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(2)) { GAO(3); }
MIDOUT_END();
break;
case 4:
MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(3)) { GAO(4); }
MIDOUT_END();
break;
case 5:
MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(4)) { GAO(5); }
MIDOUT_END();
break;
case 6:
MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(5)) { GAO(6); }
MIDOUT_END();
break;
case 7:
MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(6)) { GAO(7); }
MIDOUT_END();
break;
}
#undef GAO
} else {
#define GAO(FH) \
do { \
return do_conv_tpl_disable_prefetch<FH>(src, filter, dst, IH, IW, OH, \
OW, FW); \
} while (0)
switch (FH) {
case 1:
MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(0)) { GAO(1); }
MIDOUT_END();
break;
case 2:
MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(1)) { GAO(2); }
MIDOUT_END();
break;
case 3:
MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(2)) { GAO(3); }
MIDOUT_END();
break;
case 4:
MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(3)) { GAO(4); }
MIDOUT_END();
break;
case 5:
MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(4)) { GAO(5); }
MIDOUT_END();
break;
case 6:
MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(5)) { GAO(6); }
MIDOUT_END();
break;
case 7:
MIDOUT_BEGIN(megdnn_arm_conv_f16, midout_iv(6)) { GAO(7); }
MIDOUT_END();
break;
}
#undef GAO
}
megdnn_assert_internal(0);
}
#endif

// vim: syntax=cpp.doxygen

+ 32
- 0
dnn/src/arm_common/conv_bias/f16/direct.h View File

@@ -0,0 +1,32 @@
/**
* \file dnn/src/arm_common/conv_bias/f16/direct.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once

#include <cstddef>
#include "megdnn/dtype.h"
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

namespace megdnn {
namespace arm_common {
namespace fp16{
namespace conv_bias {

void kern_direct_f16(const __fp16* src, const __fp16* filter,
__fp16* dst, const int IH, const int IW, const int OH,
const int OW, const int FH, const int FW);

} // namespace convolution
} // namespace fp16
} // namespace arm_common
} // namespace megdnn
#endif

// vim: syntax=cpp.doxygen

+ 522
- 0
dnn/src/arm_common/conv_bias/f16/do_conv_stride1.cpp View File

@@ -0,0 +1,522 @@
/**
* \file dnn/src/arm_common/conv_bias/f16/do_conv_stride1.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include <algorithm>

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#include "./do_conv_stride1.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/conv_bias/postprocess_helper.h"

using namespace megdnn;
using namespace arm_common;
using namespace fp16;
using namespace conv_stride1;

using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam;
using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam;


void conv_stride1::do_conv_2x2_stride1(const __fp16* src, const __fp16* filter, __fp16* dst,
size_t IH, size_t IW, size_t OH, size_t OW,
size_t IC) {
const size_t tail_step = IW - OW;
//! unroll of 2
size_t ic = 0;
for (; ic + 1 < IC; ic += 2) {
const __fp16* src_ptr = src + IW * IH * ic;
const __fp16* src_ptr1 = src_ptr + IW * IH;
__fp16* outptr = dst;

const __fp16* r00 = src_ptr;
const __fp16* r01 = src_ptr + IW;
const __fp16* r10 = src_ptr1;
const __fp16* r11 = src_ptr1 + IW;

const __fp16* k0 = filter + ic * 4;

float16x8_t _k0 = vld1q_f16(k0);
rep(h, OH) {
int width = OW >> 3;

rep(i, width) {
float16x8_t _r000 = vld1q_f16(r00);
float16x8_t _r010 = vld1q_f16(r01);
float16x8_t _r001 = vld1q_f16(r00 + 1);
float16x8_t _r011 = vld1q_f16(r01 + 1);

float16x8_t _r100 = vld1q_f16(r10);
float16x8_t _r110 = vld1q_f16(r11);
float16x8_t _r101 = vld1q_f16(r10 + 1);
float16x8_t _r111 = vld1q_f16(r11 + 1);

float16x8_t _sum = vld1q_f16(outptr);

_sum = vmlaq_lane_f16(_sum, _r000, vget_low_f16(_k0), 0);
_sum = vmlaq_lane_f16(_sum, _r001, vget_low_f16(_k0), 1);
_sum = vmlaq_lane_f16(_sum, _r010, vget_low_f16(_k0), 2);
_sum = vmlaq_lane_f16(_sum, _r011, vget_low_f16(_k0), 3);

_sum = vmlaq_lane_f16(_sum, _r100, vget_high_f16(_k0), 0);
_sum = vmlaq_lane_f16(_sum, _r101, vget_high_f16(_k0), 1);
_sum = vmlaq_lane_f16(_sum, _r110, vget_high_f16(_k0), 2);
_sum = vmlaq_lane_f16(_sum, _r111, vget_high_f16(_k0), 3);

vst1q_f16(outptr, _sum);

r00 += 8;
r01 += 8;
r10 += 8;
r11 += 8;
outptr += 8;
}

r00 += tail_step;
r01 += tail_step;
r10 += tail_step;
r11 += tail_step;
}
}
for (; ic < IC; ic++) {
const __fp16* src_ptr = src + IW * IH * ic;
__fp16* outptr = dst;

const __fp16* r0 = src_ptr;
const __fp16* r1 = src_ptr + IW;

const __fp16* k0 = filter + ic * 4;

float16x8_t _k0 = vdupq_n_f16(k0[0]);
float16x8_t _k1 = vdupq_n_f16(k0[1]);
float16x8_t _k2 = vdupq_n_f16(k0[2]);
float16x8_t _k3 = vdupq_n_f16(k0[3]);
rep(h, OH) {
int width = OW >> 3;

rep(i, width) {
float16x8_t _r00 = vld1q_f16(r0);
float16x8_t _r10 = vld1q_f16(r1);
float16x8_t _r01 = vld1q_f16(r0 + 1);
float16x8_t _r11 = vld1q_f16(r1 + 1);

float16x8_t _sum = vld1q_f16(outptr);
float16x8_t _sum2;

_sum = vmlaq_f16(_sum, _r00, _k0);
_sum2 = vmulq_f16(_r01, _k1);
_sum = vmlaq_f16(_sum, _r10, _k2);
_sum2 = vmlaq_f16(_sum2, _r11, _k3);

_sum = vaddq_f16(_sum, _sum2);

vst1q_f16(outptr, _sum);

r0 += 8;
r1 += 8;
outptr += 8;
}

r0 += tail_step;
r1 += tail_step;
}
}
}

void conv_stride1::do_conv_3x3_stride1(const __fp16* src, const __fp16* filter, __fp16* dst,
size_t IH, size_t IW, size_t OH, size_t OW,
size_t IC) {
const size_t tail_step = IW - OW;

rep(ic, IC) {
const __fp16* src_ptr = src + IW * IH * ic;
__fp16* outptr = dst;
__fp16* outptr2 = outptr + OW;

const __fp16* r0 = src_ptr;
const __fp16* r1 = src_ptr + IW;
const __fp16* r2 = src_ptr + IW * 2;
const __fp16* r3 = src_ptr + IW * 3;

float16x8_t _k01234567 = vld1q_f16(filter);
float16x8_t _k12345678 = vld1q_f16(filter + 1);

size_t h = 0;
for (; h + 1 < OH; h += 2) {
int width = OW >> 3;

rep(i, width) {
float16x8_t _sum1 = vld1q_f16(outptr);
float16x8_t _sum2 = vdupq_n_f16(0.f);
float16x8_t _sum3 = vld1q_f16(outptr2);
float16x8_t _sum4 = vdupq_n_f16(0.f);

float16x8_t _r00 = vld1q_f16(r0);
float16x8_t _r00n = vld1q_f16(r0 + 8);
float16x8_t _r01 = vextq_f16(_r00, _r00n, 1);
float16x8_t _r02 = vextq_f16(_r00, _r00n, 2);

float16x8_t _r10 = vld1q_f16(r1);
float16x8_t _r10n = vld1q_f16(r1 + 8);
float16x8_t _r11 = vextq_f16(_r10, _r10n, 1);
float16x8_t _r12 = vextq_f16(_r10, _r10n, 2);

float16x8_t _r20 = vld1q_f16(r2);
float16x8_t _r20n = vld1q_f16(r2 + 8);
float16x8_t _r21 = vextq_f16(_r20, _r20n, 1);
float16x8_t _r22 = vextq_f16(_r20, _r20n, 2);

float16x8_t _r30 = vld1q_f16(r3);
float16x8_t _r30n = vld1q_f16(r3 + 8);
float16x8_t _r31 = vextq_f16(_r30, _r30n, 1);
float16x8_t _r32 = vextq_f16(_r30, _r30n, 2);

_sum1 = vmlaq_low_lane_f16(_sum1, _r00, _k01234567, 0);
_sum2 = vmlaq_low_lane_f16(_sum2, _r01, _k01234567, 1);
_sum1 = vmlaq_low_lane_f16(_sum1, _r02, _k01234567, 2);
_sum2 = vmlaq_low_lane_f16(_sum2, _r10, _k01234567, 3);
_sum1 = vmlaq_high_lane_f16(_sum1, _r11, _k01234567, 4);
_sum2 = vmlaq_high_lane_f16(_sum2, _r12, _k01234567, 5);
_sum1 = vmlaq_high_lane_f16(_sum1, _r20, _k01234567, 6);
_sum2 = vmlaq_high_lane_f16(_sum2, _r21, _k01234567, 7);
_sum1 = vmlaq_high_lane_f16(_sum1, _r22, _k12345678, 7);

_sum3 = vmlaq_low_lane_f16(_sum3, _r10, _k01234567, 0);
_sum4 = vmlaq_low_lane_f16(_sum4, _r11, _k01234567, 1);
_sum3 = vmlaq_low_lane_f16(_sum3, _r12, _k01234567, 2);
_sum4 = vmlaq_low_lane_f16(_sum4, _r20, _k01234567, 3);
_sum3 = vmlaq_high_lane_f16(_sum3, _r21, _k01234567, 4);
_sum4 = vmlaq_high_lane_f16(_sum4, _r22, _k01234567, 5);
_sum3 = vmlaq_high_lane_f16(_sum3, _r30, _k01234567, 6);
_sum4 = vmlaq_high_lane_f16(_sum4, _r31, _k01234567, 7);
_sum3 = vmlaq_high_lane_f16(_sum3, _r32, _k12345678, 7);

_sum1 = vaddq_f16(_sum1, _sum2);
_sum3 = vaddq_f16(_sum3, _sum4);

vst1q_f16(outptr, _sum1);
vst1q_f16(outptr2, _sum3);

r0 += 8;
r1 += 8;
r2 += 8;
r3 += 8;
outptr += 8;
outptr2 += 8;
}

r0 += tail_step + IW;
r1 += tail_step + IW;
r2 += tail_step + IW;
r3 += tail_step + IW;

outptr += OW;
outptr2 += OW;
}

for (; h < OH; h++) {
int width = OW >> 3;

rep(i, width) {
float16x8_t _sum1 = vld1q_f16(outptr);
float16x8_t _sum2 = vdupq_n_f16(0.f);

float16x8_t _r00 = vld1q_f16(r0);
float16x8_t _r00n = vld1q_f16(r0 + 8);
float16x8_t _r01 = vextq_f16(_r00, _r00n, 1);
float16x8_t _r02 = vextq_f16(_r00, _r00n, 2);

float16x8_t _r10 = vld1q_f16(r1);
float16x8_t _r10n = vld1q_f16(r1 + 8);
float16x8_t _r11 = vextq_f16(_r10, _r10n, 1);
float16x8_t _r12 = vextq_f16(_r10, _r10n, 2);

float16x8_t _r20 = vld1q_f16(r2);
float16x8_t _r20n = vld1q_f16(r2 + 8);
float16x8_t _r21 = vextq_f16(_r20, _r20n, 1);
float16x8_t _r22 = vextq_f16(_r20, _r20n, 2);

_sum1 = vmlaq_low_lane_f16(_sum1, _r00, _k01234567, 0);
_sum2 = vmlaq_low_lane_f16(_sum2, _r01, _k01234567, 1);
_sum1 = vmlaq_low_lane_f16(_sum1, _r02, _k01234567, 2);
_sum2 = vmlaq_low_lane_f16(_sum2, _r10, _k01234567, 3);
_sum1 = vmlaq_high_lane_f16(_sum1, _r11, _k01234567, 4);
_sum2 = vmlaq_high_lane_f16(_sum2, _r12, _k01234567, 5);
_sum1 = vmlaq_high_lane_f16(_sum1, _r20, _k01234567, 6);
_sum2 = vmlaq_high_lane_f16(_sum2, _r21, _k01234567, 7);
_sum1 = vmlaq_high_lane_f16(_sum1, _r22, _k12345678, 7);

_sum1 = vaddq_f16(_sum1, _sum2);

vst1q_f16(outptr, _sum1);

r0 += 8;
r1 += 8;
r2 += 8;
outptr += 8;
}
r0 += tail_step;
r1 += tail_step;
r2 += tail_step;
}

filter += 9;
}
}

void conv_stride1::do_conv_5x5_stride1(const __fp16* src, const __fp16* filter, __fp16* dst,
size_t IH, size_t IW, size_t OH, size_t OW,
size_t IC) {
const size_t tail_step = IW - OW;

rep(ic, IC) {
const __fp16* src_ptr = src + IW * IH * ic;
__fp16* outptr = dst;
__fp16* outptr2 = outptr + OW;

const __fp16* r0 = src_ptr;
const __fp16* r1 = src_ptr + IW;
const __fp16* r2 = src_ptr + IW * 2;
const __fp16* r3 = src_ptr + IW * 3;
const __fp16* r4 = src_ptr + IW * 4;
const __fp16* r5 = src_ptr + IW * 5;

float16x8_t _k0 = vld1q_f16(filter);
float16x8_t _k1 = vld1q_f16(filter + 8);
float16x8_t _k2 = vld1q_f16(filter + 16);
float16x8_t _k3 = vld1q_f16(filter + 17);

size_t h = 0;
for (; h + 1 < OH; h += 2) {
int width = OW >> 3;

rep(i, width) {
float16x8_t _sum = vld1q_f16(outptr);
float16x8_t _sum2 = vld1q_f16(outptr2);

float16x8_t _r00 = vld1q_f16(r0);
float16x8_t _r05 = vld1q_f16(r0 + 8);
float16x8_t _r01 = vextq_f16(_r00, _r05, 1);
float16x8_t _r02 = vextq_f16(_r00, _r05, 2);
float16x8_t _r03 = vextq_f16(_r00, _r05, 3);
float16x8_t _r04 = vextq_f16(_r00, _r05, 4);

float16x8_t _r10 = vld1q_f16(r1);
float16x8_t _r15 = vld1q_f16(r1 + 8);
float16x8_t _r11 = vextq_f16(_r10, _r15, 1);
float16x8_t _r12 = vextq_f16(_r10, _r15, 2);
float16x8_t _r13 = vextq_f16(_r10, _r15, 3);
float16x8_t _r14 = vextq_f16(_r10, _r15, 4);

float16x8_t _r20 = vld1q_f16(r2);
float16x8_t _r25 = vld1q_f16(r2 + 8);
float16x8_t _r21 = vextq_f16(_r20, _r25, 1);
float16x8_t _r22 = vextq_f16(_r20, _r25, 2);
float16x8_t _r23 = vextq_f16(_r20, _r25, 3);
float16x8_t _r24 = vextq_f16(_r20, _r25, 4);

float16x8_t _r30 = vld1q_f16(r3);
float16x8_t _r35 = vld1q_f16(r3 + 8);
float16x8_t _r31 = vextq_f16(_r30, _r35, 1);
float16x8_t _r32 = vextq_f16(_r30, _r35, 2);
float16x8_t _r33 = vextq_f16(_r30, _r35, 3);
float16x8_t _r34 = vextq_f16(_r30, _r35, 4);

float16x8_t _r40 = vld1q_f16(r4);
float16x8_t _r45 = vld1q_f16(r4 + 8);
float16x8_t _r41 = vextq_f16(_r40, _r45, 1);
float16x8_t _r42 = vextq_f16(_r40, _r45, 2);
float16x8_t _r43 = vextq_f16(_r40, _r45, 3);
float16x8_t _r44 = vextq_f16(_r40, _r45, 4);

float16x8_t _r50 = vld1q_f16(r5);
float16x8_t _r55 = vld1q_f16(r5 + 8);
float16x8_t _r51 = vextq_f16(_r50, _r55, 1);
float16x8_t _r52 = vextq_f16(_r50, _r55, 2);
float16x8_t _r53 = vextq_f16(_r50, _r55, 3);
float16x8_t _r54 = vextq_f16(_r50, _r55, 4);

_sum = vmlaq_low_lane_f16(_sum, _r00, _k0, 0);
_sum = vmlaq_low_lane_f16(_sum, _r01, _k0, 1);
_sum = vmlaq_low_lane_f16(_sum, _r02, _k0, 2);
_sum = vmlaq_low_lane_f16(_sum, _r03, _k0, 3);
_sum = vmlaq_high_lane_f16(_sum, _r04, _k0, 4);

_sum = vmlaq_high_lane_f16(_sum, _r10, _k0, 5);
_sum = vmlaq_high_lane_f16(_sum, _r11, _k0, 6);
_sum = vmlaq_high_lane_f16(_sum, _r12, _k0, 7);
_sum = vmlaq_low_lane_f16(_sum, _r13, _k1, 0);
_sum = vmlaq_low_lane_f16(_sum, _r14, _k1, 1);

_sum = vmlaq_low_lane_f16(_sum, _r20, _k1, 2);
_sum = vmlaq_low_lane_f16(_sum, _r21, _k1, 3);
_sum = vmlaq_high_lane_f16(_sum, _r22, _k1, 4);
_sum = vmlaq_high_lane_f16(_sum, _r23, _k1, 5);
_sum = vmlaq_high_lane_f16(_sum, _r24, _k1, 6);

_sum = vmlaq_high_lane_f16(_sum, _r30, _k1, 7);
_sum = vmlaq_low_lane_f16(_sum, _r31, _k2, 0);
_sum = vmlaq_low_lane_f16(_sum, _r32, _k2, 1);
_sum = vmlaq_low_lane_f16(_sum, _r33, _k2, 2);
_sum = vmlaq_low_lane_f16(_sum, _r34, _k2, 3);

_sum = vmlaq_high_lane_f16(_sum, _r40, _k2, 4);
_sum = vmlaq_high_lane_f16(_sum, _r41, _k2, 5);
_sum = vmlaq_high_lane_f16(_sum, _r42, _k2, 6);
_sum = vmlaq_high_lane_f16(_sum, _r43, _k2, 7);
_sum = vmlaq_high_lane_f16(_sum, _r44, _k3, 7);

_sum2 = vmlaq_low_lane_f16(_sum2, _r10, _k0, 0);
_sum2 = vmlaq_low_lane_f16(_sum2, _r11, _k0, 1);
_sum2 = vmlaq_low_lane_f16(_sum2, _r12, _k0, 2);
_sum2 = vmlaq_low_lane_f16(_sum2, _r13, _k0, 3);
_sum2 = vmlaq_high_lane_f16(_sum2, _r14, _k0, 4);

_sum2 = vmlaq_high_lane_f16(_sum2, _r20, _k0, 5);
_sum2 = vmlaq_high_lane_f16(_sum2, _r21, _k0, 6);
_sum2 = vmlaq_high_lane_f16(_sum2, _r22, _k0, 7);
_sum2 = vmlaq_low_lane_f16(_sum2, _r23, _k1, 0);
_sum2 = vmlaq_low_lane_f16(_sum2, _r24, _k1, 1);

_sum2 = vmlaq_low_lane_f16(_sum2, _r30, _k1, 2);
_sum2 = vmlaq_low_lane_f16(_sum2, _r31, _k1, 3);
_sum2 = vmlaq_high_lane_f16(_sum2, _r32, _k1, 4);
_sum2 = vmlaq_high_lane_f16(_sum2, _r33, _k1, 5);
_sum2 = vmlaq_high_lane_f16(_sum2, _r34, _k1, 6);

_sum2 = vmlaq_high_lane_f16(_sum2, _r40, _k1, 7);
_sum2 = vmlaq_low_lane_f16(_sum2, _r41, _k2, 0);
_sum2 = vmlaq_low_lane_f16(_sum2, _r42, _k2, 1);
_sum2 = vmlaq_low_lane_f16(_sum2, _r43, _k2, 2);
_sum2 = vmlaq_low_lane_f16(_sum2, _r44, _k2, 3);

_sum2 = vmlaq_high_lane_f16(_sum2, _r50, _k2, 4);
_sum2 = vmlaq_high_lane_f16(_sum2, _r51, _k2, 5);
_sum2 = vmlaq_high_lane_f16(_sum2, _r52, _k2, 6);
_sum2 = vmlaq_high_lane_f16(_sum2, _r53, _k2, 7);
_sum2 = vmlaq_high_lane_f16(_sum2, _r54, _k3, 7);

vst1q_f16(outptr, _sum);
vst1q_f16(outptr2, _sum2);

r0 += 8;
r1 += 8;
r2 += 8;
r3 += 8;
r4 += 8;
r5 += 8;
outptr += 8;
outptr2 += 8;
}

r0 += tail_step + IW;
r1 += tail_step + IW;
r2 += tail_step + IW;
r3 += tail_step + IW;
r4 += tail_step + IW;
r5 += tail_step + IW;

outptr += OW;
outptr2 += OW;
}

for (; h < OH; h++) {
int width = OW >> 3;

rep(i, width) {
float16x8_t _sum = vld1q_f16(outptr);

float16x8_t _r00 = vld1q_f16(r0);
float16x8_t _r05 = vld1q_f16(r0 + 8);
float16x8_t _r01 = vextq_f16(_r00, _r05, 1);
float16x8_t _r02 = vextq_f16(_r00, _r05, 2);
float16x8_t _r03 = vextq_f16(_r00, _r05, 3);
float16x8_t _r04 = vextq_f16(_r00, _r05, 4);

float16x8_t _r10 = vld1q_f16(r1);
float16x8_t _r15 = vld1q_f16(r1 + 8);
float16x8_t _r11 = vextq_f16(_r10, _r15, 1);
float16x8_t _r12 = vextq_f16(_r10, _r15, 2);
float16x8_t _r13 = vextq_f16(_r10, _r15, 3);
float16x8_t _r14 = vextq_f16(_r10, _r15, 4);

float16x8_t _r20 = vld1q_f16(r2);
float16x8_t _r25 = vld1q_f16(r2 + 8);
float16x8_t _r21 = vextq_f16(_r20, _r25, 1);
float16x8_t _r22 = vextq_f16(_r20, _r25, 2);
float16x8_t _r23 = vextq_f16(_r20, _r25, 3);
float16x8_t _r24 = vextq_f16(_r20, _r25, 4);

float16x8_t _r30 = vld1q_f16(r3);
float16x8_t _r35 = vld1q_f16(r3 + 8);
float16x8_t _r31 = vextq_f16(_r30, _r35, 1);
float16x8_t _r32 = vextq_f16(_r30, _r35, 2);
float16x8_t _r33 = vextq_f16(_r30, _r35, 3);
float16x8_t _r34 = vextq_f16(_r30, _r35, 4);

float16x8_t _r40 = vld1q_f16(r4);
float16x8_t _r45 = vld1q_f16(r4 + 8);
float16x8_t _r41 = vextq_f16(_r40, _r45, 1);
float16x8_t _r42 = vextq_f16(_r40, _r45, 2);
float16x8_t _r43 = vextq_f16(_r40, _r45, 3);
float16x8_t _r44 = vextq_f16(_r40, _r45, 4);

_sum = vmlaq_low_lane_f16(_sum, _r00, _k0, 0);
_sum = vmlaq_low_lane_f16(_sum, _r01, _k0, 1);
_sum = vmlaq_low_lane_f16(_sum, _r02, _k0, 2);
_sum = vmlaq_low_lane_f16(_sum, _r03, _k0, 3);
_sum = vmlaq_high_lane_f16(_sum, _r04, _k0, 4);

_sum = vmlaq_high_lane_f16(_sum, _r10, _k0, 5);
_sum = vmlaq_high_lane_f16(_sum, _r11, _k0, 6);
_sum = vmlaq_high_lane_f16(_sum, _r12, _k0, 7);
_sum = vmlaq_low_lane_f16(_sum, _r13, _k1, 0);
_sum = vmlaq_low_lane_f16(_sum, _r14, _k1, 1);

_sum = vmlaq_low_lane_f16(_sum, _r20, _k1, 2);
_sum = vmlaq_low_lane_f16(_sum, _r21, _k1, 3);
_sum = vmlaq_high_lane_f16(_sum, _r22, _k1, 4);
_sum = vmlaq_high_lane_f16(_sum, _r23, _k1, 5);
_sum = vmlaq_high_lane_f16(_sum, _r24, _k1, 6);

_sum = vmlaq_high_lane_f16(_sum, _r30, _k1, 7);
_sum = vmlaq_low_lane_f16(_sum, _r31, _k2, 0);
_sum = vmlaq_low_lane_f16(_sum, _r32, _k2, 1);
_sum = vmlaq_low_lane_f16(_sum, _r33, _k2, 2);
_sum = vmlaq_low_lane_f16(_sum, _r34, _k2, 3);

_sum = vmlaq_high_lane_f16(_sum, _r40, _k2, 4);
_sum = vmlaq_high_lane_f16(_sum, _r41, _k2, 5);
_sum = vmlaq_high_lane_f16(_sum, _r42, _k2, 6);
_sum = vmlaq_high_lane_f16(_sum, _r43, _k2, 7);
_sum = vmlaq_high_lane_f16(_sum, _r44, _k3, 7);

vst1q_f16(outptr, _sum);

r0 += 8;
r1 += 8;
r2 += 8;
r3 += 8;
r4 += 8;
outptr += 8;
}

r0 += tail_step;
r1 += tail_step;
r2 += tail_step;
r3 += tail_step;
r4 += tail_step;
}

filter += 25;
}
}
#endif
// vim: syntax=cpp.doxygen

+ 32
- 0
dnn/src/arm_common/conv_bias/f16/do_conv_stride1.h View File

@@ -0,0 +1,32 @@
/**
* \file dnn/src/arm_common/conv_bias/f16/do_conv_stride1.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once

#include "src/fallback/conv_bias/opr_impl.h"

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
namespace megdnn {
namespace arm_common {
namespace fp16 {
namespace conv_stride1 {
void do_conv_2x2_stride1(const __fp16* src, const __fp16* filter, __fp16* dst,
size_t IH, size_t IW, size_t OH, size_t OW, size_t IC);
void do_conv_3x3_stride1(const __fp16* src, const __fp16* filter, __fp16* dst,
size_t IH, size_t IW, size_t OH, size_t OW, size_t IC);
void do_conv_5x5_stride1(const __fp16* src, const __fp16* filter, __fp16* dst,
size_t IH, size_t IW, size_t OH, size_t OW, size_t IC);
} // namespace conv_stride1
} // namespace fp16
} // namespace arm_common
} // namespace megdnn
#endif

// vim: syntax=cpp.doxygen

+ 341
- 0
dnn/src/arm_common/conv_bias/f16/helper.h View File

@@ -0,0 +1,341 @@
/**
* \file dnn/src/arm_common/conv_bias/f16/helper.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#pragma once
#include "src/common/unroll_macro.h"
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#define MATRIX_MUL4x4_fp16(sum, a, b) \
sum##0 = vmul_lane_f16(b##0, a##0, 0); \
sum##1 = vmul_lane_f16(b##0, a##1, 0); \
sum##2 = vmul_lane_f16(b##0, a##2, 0); \
sum##3 = vmul_lane_f16(b##0, a##3, 0); \
sum##0 = vadd_f16(sum##0, vmul_lane_f16(b##1, a##0, 1)); \
sum##1 = vadd_f16(sum##1, vmul_lane_f16(b##1, a##1, 1)); \
sum##2 = vadd_f16(sum##2, vmul_lane_f16(b##1, a##2, 1)); \
sum##3 = vadd_f16(sum##3, vmul_lane_f16(b##1, a##3, 1)); \
sum##0 = vadd_f16(sum##0, vmul_lane_f16(b##2, a##0, 2)); \
sum##1 = vadd_f16(sum##1, vmul_lane_f16(b##2, a##1, 2)); \
sum##2 = vadd_f16(sum##2, vmul_lane_f16(b##2, a##2, 2)); \
sum##3 = vadd_f16(sum##3, vmul_lane_f16(b##2, a##3, 2)); \
sum##0 = vadd_f16(sum##0, vmul_lane_f16(b##3, a##0, 3)); \
sum##1 = vadd_f16(sum##1, vmul_lane_f16(b##3, a##1, 3)); \
sum##2 = vadd_f16(sum##2, vmul_lane_f16(b##3, a##2, 3)); \
sum##3 = vadd_f16(sum##3, vmul_lane_f16(b##3, a##3, 3));

#define CONCAT(a, id) a##id

#if MEGDNN_AARCH64

#define TRANSPOSE_4x4(a, ret) \
do { \
auto b00 = vzip1_f16(CONCAT(a, 0).value, \
CONCAT(a, 1).value); /*a1b1a2b2*/ \
auto b01 = vzip2_f16(CONCAT(a, 0).value, \
CONCAT(a, 1).value); /*a3b3a4b4*/ \
auto b10 = vzip1_f16(CONCAT(a, 2).value, \
CONCAT(a, 3).value); /*c1d1c2d2*/ \
auto b11 = vzip2_f16(CONCAT(a, 2).value, \
CONCAT(a, 3).value); /*c3d3c4d4*/ \
auto s32b00 = vreinterpret_s32_f16(b00); \
auto s32b01 = vreinterpret_s32_f16(b01); \
auto s32b10 = vreinterpret_s32_f16(b10); \
auto s32b11 = vreinterpret_s32_f16(b11); \
CONCAT(ret, 0).value = \
vreinterpret_f16_s32(vzip1_s32(s32b00, s32b10)); \
CONCAT(ret, 1).value = \
vreinterpret_f16_s32(vzip2_s32(s32b00, s32b10)); \
CONCAT(ret, 2).value = \
vreinterpret_f16_s32(vzip1_s32(s32b01, s32b11)); \
CONCAT(ret, 3).value = \
vreinterpret_f16_s32(vzip2_s32(s32b01, s32b11)); \
} while (0);

#define TRANSPOSE_4x8(a, ret) \
do { \
auto b00 = vzip1q_f16(CONCAT(a, 0).value, \
CONCAT(a, 1).value); /*a1b1a2b2a3b3a4b4*/ \
auto b01 = vzip2q_f16(CONCAT(a, 0).value, \
CONCAT(a, 1).value); /*a5b5a6b6a7b7a8b8*/ \
auto b10 = vzip1q_f16(CONCAT(a, 2).value, \
CONCAT(a, 3).value); /*c1d1c2d2c3d3c4d4*/ \
auto b11 = vzip2q_f16(CONCAT(a, 2).value, \
CONCAT(a, 3).value); /*c5d5c6d6c7d7c8d8*/ \
auto s32b00 = vreinterpretq_s32_f16(b00); \
auto s32b01 = vreinterpretq_s32_f16(b01); \
auto s32b10 = vreinterpretq_s32_f16(b10); \
auto s32b11 = vreinterpretq_s32_f16(b11); \
auto f16b00 = vreinterpretq_f16_s32( \
vzip1q_s32(s32b00, s32b10)); /*a1b1c1d1a2b2c2d2*/ \
auto f16b01 = vreinterpretq_f16_s32( \
vzip2q_s32(s32b00, s32b10)); /*a3b3c3d3a4b4a4d4*/ \
auto f16b10 = vreinterpretq_f16_s32( \
vzip1q_s32(s32b01, s32b11)); /*a5b5c5d5a6b6c6d6*/ \
auto f16b11 = vreinterpretq_f16_s32( \
vzip2q_s32(s32b01, s32b11)); /*a7b7c7d7a8b8c8d8*/ \
CONCAT(ret, 0).value = vget_low_f16(f16b00); \
CONCAT(ret, 1).value = vget_high_f16(f16b00); \
CONCAT(ret, 2).value = vget_low_f16(f16b01); \
CONCAT(ret, 3).value = vget_high_f16(f16b01); \
CONCAT(ret, 4).value = vget_low_f16(f16b10); \
CONCAT(ret, 5).value = vget_high_f16(f16b10); \
CONCAT(ret, 6).value = vget_low_f16(f16b11); \
CONCAT(ret, 7).value = vget_high_f16(f16b11); \
} while (0);

#define TRANSPOSE_8x4(a, ret) \
do { \
auto b00 = vzip1_f16(CONCAT(a, 0).value, \
CONCAT(a, 1).value); /*a1b1a2b2*/ \
auto b01 = vzip2_f16(CONCAT(a, 0).value, \
CONCAT(a, 1).value); /*a3b3a4b4*/ \
auto b10 = vzip1_f16(CONCAT(a, 2).value, \
CONCAT(a, 3).value); /*c1d1c2d2*/ \
auto b11 = vzip2_f16(CONCAT(a, 2).value, \
CONCAT(a, 3).value); /*c3d3c4d4*/ \
auto b20 = vzip1_f16(CONCAT(a, 4).value, \
CONCAT(a, 5).value); /*e1f1e2f2*/ \
auto b21 = vzip2_f16(CONCAT(a, 4).value, \
CONCAT(a, 5).value); /*e3f3e4f4*/ \
auto b30 = vzip1_f16(CONCAT(a, 6).value, \
CONCAT(a, 7).value); /*g1h1g2h2*/ \
auto b31 = vzip2_f16(CONCAT(a, 6).value, \
CONCAT(a, 7).value); /*g3h3g4h4*/ \
auto s32b00 = vreinterpret_s32_f16(b00); \
auto s32b01 = vreinterpret_s32_f16(b01); \
auto s32b10 = vreinterpret_s32_f16(b10); \
auto s32b11 = vreinterpret_s32_f16(b11); \
auto s32b20 = vreinterpret_s32_f16(b20); \
auto s32b21 = vreinterpret_s32_f16(b21); \
auto s32b30 = vreinterpret_s32_f16(b30); \
auto s32b31 = vreinterpret_s32_f16(b31); \
CONCAT(ret, 0).value = \
vcombine_f16(vreinterpret_f16_s32(vzip1_s32(s32b00, s32b10)), \
vreinterpret_f16_s32(vzip1_s32(s32b20, s32b30))); \
CONCAT(ret, 1).value = \
vcombine_f16(vreinterpret_f16_s32(vzip2_s32(s32b00, s32b10)), \
vreinterpret_f16_s32(vzip2_s32(s32b20, s32b30))); \
CONCAT(ret, 2).value = \
vcombine_f16(vreinterpret_f16_s32(vzip1_s32(s32b01, s32b11)), \
vreinterpret_f16_s32(vzip1_s32(s32b21, s32b31))); \
CONCAT(ret, 3).value = \
vcombine_f16(vreinterpret_f16_s32(vzip2_s32(s32b01, s32b11)), \
vreinterpret_f16_s32(vzip2_s32(s32b21, s32b31))); \
} while (0);

#define TRANSPOSE_8x8(a, ret) \
do { \
auto b00 = vzip1q_f16(CONCAT(a, 0).value, \
CONCAT(a, 1).value); /*a1b1a2b2 a3b3a4b4*/ \
auto b01 = vzip2q_f16(CONCAT(a, 0).value, \
CONCAT(a, 1).value); /*a5b5a6b6 a7b7a8b8*/ \
auto b10 = vzip1q_f16(CONCAT(a, 2).value, \
CONCAT(a, 3).value); /*c1d1c2d2 c3d3c4d4*/ \
auto b11 = vzip2q_f16(CONCAT(a, 2).value, \
CONCAT(a, 3).value); /*c5d5c6d6 c7d7c8d8*/ \
auto b20 = vzip1q_f16(CONCAT(a, 4).value, \
CONCAT(a, 5).value); /*e1f1e2f2 e3f3e4f4*/ \
auto b21 = vzip2q_f16(CONCAT(a, 4).value, \
CONCAT(a, 5).value); /*e5f5e6f6 e7f7e8f8*/ \
auto b30 = vzip1q_f16(CONCAT(a, 6).value, \
CONCAT(a, 7).value); /*g1h1g2h2 g3h3g4h4*/ \
auto b31 = vzip2q_f16(CONCAT(a, 6).value, \
CONCAT(a, 7).value); /*g5h5g6h6 g7h7g8h8*/ \
auto s32b00 = vreinterpretq_s32_f16(b00); \
auto s32b01 = vreinterpretq_s32_f16(b01); \
auto s32b10 = vreinterpretq_s32_f16(b10); \
auto s32b11 = vreinterpretq_s32_f16(b11); \
auto s32b20 = vreinterpretq_s32_f16(b20); \
auto s32b21 = vreinterpretq_s32_f16(b21); \
auto s32b30 = vreinterpretq_s32_f16(b30); \
auto s32b31 = vreinterpretq_s32_f16(b31); \
auto s64b00 = vreinterpretq_s64_s32( \
vzip1q_s32(s32b00, s32b10)); /*a1b1c1d1 a2b2c2d2*/ \
auto s64b01 = vreinterpretq_s64_s32( \
vzip2q_s32(s32b00, s32b10)); /*a3b3c3d3 a4b4c4d4*/ \
auto s64b10 = vreinterpretq_s64_s32( \
vzip1q_s32(s32b01, s32b11)); /*a5b5c5d5 a6b6c6d6*/ \
auto s64b11 = vreinterpretq_s64_s32( \
vzip2q_s32(s32b01, s32b11)); /*a7b7c7d7 a8b8c8d8*/ \
auto s64b20 = vreinterpretq_s64_s32( \
vzip1q_s32(s32b20, s32b30)); /*e1f1g1h1 e2f2g2h2*/ \
auto s64b21 = vreinterpretq_s64_s32( \
vzip2q_s32(s32b20, s32b30)); /*e3f3g3h3 e4f4g4h4*/ \
auto s64b30 = vreinterpretq_s64_s32( \
vzip1q_s32(s32b21, s32b31)); /*e5f5g5h5 e6f6g6h6*/ \
auto s64b31 = vreinterpretq_s64_s32( \
vzip2q_s32(s32b21, s32b31)); /*e7f7g7h7 e8f8g8h8*/ \
CONCAT(ret, 0).value = \
vreinterpretq_f16_s64(vzip1q_s64(s64b00, s64b20)); \
CONCAT(ret, 1).value = \
vreinterpretq_f16_s64(vzip2q_s64(s64b00, s64b20)); \
CONCAT(ret, 2).value = \
vreinterpretq_f16_s64(vzip1q_s64(s64b01, s64b21)); \
CONCAT(ret, 3).value = \
vreinterpretq_f16_s64(vzip2q_s64(s64b01, s64b21)); \
CONCAT(ret, 4).value = \
vreinterpretq_f16_s64(vzip1q_s64(s64b10, s64b30)); \
CONCAT(ret, 5).value = \
vreinterpretq_f16_s64(vzip2q_s64(s64b10, s64b30)); \
CONCAT(ret, 6).value = \
vreinterpretq_f16_s64(vzip1q_s64(s64b11, s64b31)); \
CONCAT(ret, 7).value = \
vreinterpretq_f16_s64(vzip2q_s64(s64b11, s64b31)); \
} while (0);

#else

#define TRANSPOSE_4x4(a, ret) \
do { \
auto b0_01 = vzip_f16(CONCAT(a, 0).value, \
CONCAT(a, 1).value); /*a1b1a2b2 a3b3a4b4*/ \
auto b1_01 = vzip_f16(CONCAT(a, 2).value, \
CONCAT(a, 3).value); /*c1d1c2d2 c3d3c4d4*/ \
auto s32b00 = vreinterpret_s32_f16(b0_01.val[0]); \
auto s32b01 = vreinterpret_s32_f16(b0_01.val[1]); \
auto s32b10 = vreinterpret_s32_f16(b1_01.val[0]); \
auto s32b11 = vreinterpret_s32_f16(b1_01.val[1]); \
auto s32b00b10 = vzip_s32(s32b00, s32b10); /*a1b1c1d1 a2b2c2d2*/ \
auto s32b01b11 = vzip_s32(s32b01, s32b11); /*a3b3c3d3 a4b4c4d4*/ \
CONCAT(ret, 0).value = vreinterpret_f16_s32(s32b00b10.val[0]); \
CONCAT(ret, 1).value = vreinterpret_f16_s32(s32b00b10.val[1]); \
CONCAT(ret, 2).value = vreinterpret_f16_s32(s32b01b11.val[0]); \
CONCAT(ret, 3).value = vreinterpret_f16_s32(s32b01b11.val[1]); \
} while (0);

#define TRANSPOSE_4x8(a, ret) \
do { \
auto b0_01 = vzipq_f16( \
CONCAT(a, 0).value, \
CONCAT(a, 1).value); /*a1b1a2b2a3b3a4b4 a5b5a6b6a7b7a8b8*/ \
auto b1_01 = vzipq_f16( \
CONCAT(a, 2).value, \
CONCAT(a, 3).value); /*c1d1c2d2c3d3c4d4 c5d6c6d6c7d7c8d8*/ \
auto s32b00 = vreinterpretq_s32_f16(b0_01.val[0]); \
auto s32b01 = vreinterpretq_s32_f16(b0_01.val[1]); \
auto s32b10 = vreinterpretq_s32_f16(b1_01.val[0]); \
auto s32b11 = vreinterpretq_s32_f16(b1_01.val[1]); \
auto s32b00b10 = vzipq_s32( \
s32b00, s32b10); /*a1b1c1d1a2b2c2d2 a3b3c3d3a4b4c4d4*/ \
auto s32b01b11 = vzipq_s32( \
s32b01, s32b11); /*a5b5c5d5a6b6c6d6 a7b7c7d7a8b8c8d8*/ \
CONCAT(ret, 0).value = \
vreinterpret_f16_s32(vget_low_f16(s32b00b10.val[0])); \
CONCAT(ret, 1).value = \
vreinterpret_f16_s32(vget_high_f16(s32b00b10.val[0])); \
CONCAT(ret, 2).value = \
vreinterpret_f16_s32(vget_low_f16(s32b00b10.val[1])); \
CONCAT(ret, 3).value = \
vreinterpret_f16_s32(vget_high_f16(s32b00b10.val[1])); \
CONCAT(ret, 4).value = \
vreinterpret_f16_s32(vget_low_f16(s32b01b11.val[0])); \
CONCAT(ret, 5).value = \
vreinterpret_f16_s32(vget_high_f16(s32b01b11.val[0])); \
CONCAT(ret, 6).value = \
vreinterpret_f16_s32(vget_low_f16(s32b01b11.val[1])); \
CONCAT(ret, 7).value = \
vreinterpret_f16_s32(vget_high_f16(s32b01b11.val[1])); \
} while (0);

#define TRANSPOSE_8x4(a, ret) \
do { \
auto b0_01 = vzip_f16(CONCAT(a, 0).value, \
CONCAT(a, 1).value); /*a1b1a2b2 a3b3a4b4*/ \
auto b1_01 = vzip_f16(CONCAT(a, 2).value, \
CONCAT(a, 3).value); /*c1d1c2d2 c3d3c4d4*/ \
auto b2_01 = vzip_f16(CONCAT(a, 4).value, \
CONCAT(a, 5).value); /*e1f1e2f2 e3f3e4f4*/ \
auto b3_01 = vzip_f16(CONCAT(a, 6).value, \
CONCAT(a, 7).value); /*g1h1g2h2 g3h3g4h4*/ \
auto s32b00 = vreinterpret_s32_f16(b0_01.val[0]); \
auto s32b01 = vreinterpret_s32_f16(b0_01.val[1]); \
auto s32b10 = vreinterpret_s32_f16(b1_01.val[0]); \
auto s32b11 = vreinterpret_s32_f16(b1_01.val[1]); \
auto s32b20 = vreinterpret_s32_f16(b2_01.val[0]); \
auto s32b21 = vreinterpret_s32_f16(b2_01.val[1]); \
auto s32b30 = vreinterpret_s32_f16(b3_01.val[0]); \
auto s32b31 = vreinterpret_s32_f16(b3_01.val[1]); \
auto s32b00b10 = vzip_s32(s32b00, s32b10); \
auto s32b01b11 = vzip_s32(s32b01, s32b11); \
auto s32b20b30 = vzip_s32(s32b20, s32b30); \
auto s32b21b31 = vzip_s32(s32b21, s32b31); \
CONCAT(ret, 0).value = \
vcombine_f16(vreinterpret_f16_s32(s32b00b10.val[0]), \
vreinterpret_f16_s32(s32b20b30.val[0])); \
CONCAT(ret, 1).value = \
vcombine_f16(vreinterpret_f16_s32(s32b00b10.val[1]), \
vreinterpret_f16_s32(s32b20b30.val[1])); \
CONCAT(ret, 2).value = \
vcombine_f16(vreinterpret_f16_s32(s32b01b11.val[0]), \
vreinterpret_f16_s32(s32b21b31.val[0])); \
CONCAT(ret, 3).value = \
vcombine_f16(vreinterpret_f16_s32(s32b01b11.val[1]), \
vreinterpret_f16_s32(s32b21b31.val[1])); \
} while (0);

#define TRANSPOSE_8x8(a, ret) \
do { \
auto b00 = vzipq_f16(CONCAT(a, 0).value, \
CONCAT(a, 1).value); /*a1b1a2b2 a3b3a4b4*/ \
auto b01 = vzipq_f16(CONCAT(a, 0).value, \
CONCAT(a, 1).value); /*a5b5a6b6 a7b7a8b8*/ \
auto b10 = vzipq_f16(CONCAT(a, 2).value, \
CONCAT(a, 3).value); /*c1d1c2d2 c3d3c4d4*/ \
auto b11 = vzipq_f16(CONCAT(a, 2).value, \
CONCAT(a, 3).value); /*c5d5c6d6 c7d7c8d8*/ \
auto b20 = vzipq_f16(CONCAT(a, 4).value, \
CONCAT(a, 5).value); /*e1f1e2f2 e3f3e4f4*/ \
auto b21 = vzipq_f16(CONCAT(a, 4).value, \
CONCAT(a, 5).value); /*e5f5e6f6 e7f7e8f8*/ \
auto b30 = vzipq_f16(CONCAT(a, 6).value, \
CONCAT(a, 7).value); /*g1h1g2h2 g3h3g4h4*/ \
auto b31 = vzipq_f16(CONCAT(a, 6).value, \
CONCAT(a, 7).value); /*g5h5g6h6 g7h7g8h8*/ \
auto s32b00 = vreinterpretq_s32_f16(b00.val[0]); \
auto s32b01 = vreinterpretq_s32_f16(b01.val[1]); \
auto s32b10 = vreinterpretq_s32_f16(b10.val[0]); \
auto s32b11 = vreinterpretq_s32_f16(b11.val[1]); \
auto s32b20 = vreinterpretq_s32_f16(b20.val[0]); \
auto s32b21 = vreinterpretq_s32_f16(b21.val[1]); \
auto s32b30 = vreinterpretq_s32_f16(b30.val[0]); \
auto s32b31 = vreinterpretq_s32_f16(b31.val[1]); \
auto s32b00b10 = vzipq_s32(s32b00, s32b10); \
auto s32b01b11 = vzipq_s32(s32b01, s32b11); \
auto s32b20b30 = vzipq_s32(s32b20, s32b30); \
auto s32b21b31 = vzipq_s32(s32b21, s32b31); \
CONCAT(ret, 0).value = vreinterpretq_f16_s32( \
vcombine_s32(vget_low_s32(s32b00b10.val[0]), \
vget_low_s32(s32b20b30.val[0]))); \
CONCAT(ret, 1).value = vreinterpretq_f16_s32( \
vcombine_s32(vget_high_s32(s32b00b10.val[0]), \
vget_high_s32(s32b20b30.val[0]))); \
CONCAT(ret, 2).value = vreinterpretq_f16_s32( \
vcombine_s32(vget_low_s32(s32b00b10.val[1]), \
vget_low_s32(s32b20b30.val[1]))); \
CONCAT(ret, 3).value = vreinterpretq_f16_s32( \
vcombine_s32(vget_high_s32(s32b00b10.val[1]), \
vget_high_s32(s32b20b30.val[1]))); \
CONCAT(ret, 4).value = vreinterpretq_f16_s32( \
vcombine_s32(vget_low_s32(s32b01b11.val[0]), \
vget_low_s32(s32b21b31.val[0]))); \
CONCAT(ret, 5).value = vreinterpretq_f16_s32( \
vcombine_s32(vget_high_s32(s32b01b11.val[0]), \
vget_high_s32(s32b21b31.val[0]))); \
CONCAT(ret, 6).value = vreinterpretq_f16_s32( \
vcombine_s32(vget_low_s32(s32b01b11.val[1]), \
vget_low_s32(s32b21b31.val[1]))); \
CONCAT(ret, 7).value = vreinterpretq_f16_s32( \
vcombine_s32(vget_high_s32(s32b01b11.val[1]), \
vget_high_s32(s32b21b31.val[1]))); \
} while (0);

#endif
#endif
// vim: syntax=cpp.doxygen

+ 33
- 0
dnn/src/arm_common/conv_bias/f16/strategy.h View File

@@ -0,0 +1,33 @@
/**
* \file dnn/src/arm_common/conv_bias/f16/strategy.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#pragma once

#include "src/arm_common/conv_bias/postprocess_helper.h"
#include "src/fallback/conv_bias/winograd/winograd.h"
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
namespace megdnn {
namespace arm_common {
namespace winograd {

MEGDNN_REG_WINOGRAD_STRATEGY(dt_float16, dt_float16, dt_float16, dt_float16, 2,
3, 4, 4, winograd_2x3_4x4_f16)
MEGDNN_REG_WINOGRAD_STRATEGY(dt_float16, dt_float16, dt_float16, dt_float16, 4,
5, 1, 1, winograd_4x5_1x1_f16)
MEGDNN_REG_WINOGRAD_STRATEGY(dt_float16, dt_float16, dt_float16, dt_float16, 6,
3, 1, 1, winograd_6x3_1x1_f16)
MEGDNN_REG_WINOGRAD_STRATEGY(dt_float16, dt_float16, dt_float16, dt_float16, 2,
3, 8, 8, winograd_2x3_8x8_f16)
} // namespace winograd
} // namespace arm_common
} // namespace megdnn
#endif
// vim: syntax=cpp.doxygen

+ 373
- 0
dnn/src/arm_common/conv_bias/f16/strategy_2x3.cpp View File

@@ -0,0 +1,373 @@
/**
* \file dnn/src/arm_common/conv_bias/f16/strategy_2x3.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/arm_common/conv_bias/f16/strategy.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/utils.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/winograd/winograd.h"

#include "src/arm_common/conv_bias/f16/helper.h"
#include "src/arm_common/elemwise_helper/op_unary.h"
#include "src/naive/matrix_mul/matrix_mul_helper.h"

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#include "midout.h"
MIDOUT_DECL(megdnn_arm_common_winograd_fp16_F23)

using namespace megdnn;
using namespace arm_common;

namespace {
void transpose_4x4(const __fp16* src, __fp16* dst, int lda, int ldb) {
float16x4x2_t a0, a1;
a0.val[0] = vld1_f16(src + 0 * lda); // a0a1a2a3
a0.val[1] = vld1_f16(src + 1 * lda); // b0b1b2b3
a1.val[0] = vld1_f16(src + 2 * lda); // c0c1c2c3
a1.val[1] = vld1_f16(src + 3 * lda); // d0d1d2d3
float16x4x2_t b0 = vzip_f16(a0.val[0], a1.val[0]); // a0c0a1c1a2c2a3c3
float16x4x2_t b1 = vzip_f16(a0.val[1], a1.val[1]); // b0d0b1d1b2d2b3d3
float16x4x2_t c0 = vzip_f16(b0.val[0], b1.val[0]); // a0b0c0d0a1b1c1d1
float16x4x2_t c1 = vzip_f16(b0.val[1], b1.val[1]); // a2b2c2d2a3b3c3d3
vst1_f16(dst + 0 * ldb, c0.val[0]);
vst1_f16(dst + 1 * ldb, c0.val[1]);
vst1_f16(dst + 2 * ldb, c1.val[0]);
vst1_f16(dst + 3 * ldb, c1.val[1]);
}

struct InputTransform2X3 {
template <bool inner>
static void prepare(const __fp16* input, __fp16* patch, __fp16* patchT,
int ih_start, int iw_start, size_t IH, size_t IW,
size_t ic, size_t IC) {
constexpr size_t alpha = 2 + 3 - 1;
constexpr size_t alpha4 = alpha * 4;
if (!(inner && ic + 4 < IC)) {
memset(patch, 0, sizeof(__fp16) * 4 * alpha * alpha);
}
if (inner) {
const __fp16* input_ptr =
input + ic * IH * IW + ih_start * IW + iw_start;
for (size_t ico = 0; ico < 4; ++ico) {
if (ic + ico < IC) {
auto v0 = vld1_f16(input_ptr);
auto v1 = vld1_f16(input_ptr + IW);
auto v2 = vld1_f16(input_ptr + IW * 2);
auto v3 = vld1_f16(input_ptr + IW * 3);

vst1_f16(patch + ico * alpha4 + 0 * 4, v0);
vst1_f16(patch + ico * alpha4 + 1 * 4, v1);
vst1_f16(patch + ico * alpha4 + 2 * 4, v2);
vst1_f16(patch + ico * alpha4 + 3 * 4, v3);
input_ptr += IH * IW;
}
}
} else {
int ih0_act = std::max<int>(ih_start, 0),
ih1_act = std::min<int>(ih_start + alpha, IH),
iw0_act = std::max<int>(iw_start, 0),
iw1_act = std::min<int>(iw_start + alpha, IW);
// partial copy
for (size_t ico = 0; ico < 4; ++ico) {
if (ic + ico < IC) {
for (int ih = ih0_act; ih < ih1_act; ++ih) {
for (int iw = iw0_act; iw < iw1_act; ++iw) {
size_t iho = ih - ih_start, iwo = iw - iw_start;
patch[ico * alpha4 + iho * 4 + iwo] =
input[(ic + ico) * IH * IW + ih * IW + iw];
}
}
}
}
}

transpose_4x4(patch + 0 * 1, patchT + 0 * 4, 16, 4);
transpose_4x4(patch + 4 * 1, patchT + 4 * 4, 16, 4);
transpose_4x4(patch + 8 * 1, patchT + 8 * 4, 16, 4);
transpose_4x4(patch + 12 * 1, patchT + 12 * 4, 16, 4);
}

static void transform(const __fp16* patchT, __fp16* input_transform_buf,
size_t unit_idx, size_t nr_units_in_tile, size_t ic,
size_t IC) {
constexpr size_t alpha = 2 + 3 - 1;
// BT * d * B
#define cb(m, n) \
Vector<__fp16, 4> d##m##n = \
Vector<__fp16, 4>::load(patchT + m * 4 * 4 + n * 4);

UNROLL_CALL_NOWRAPPER_D2(4, 4, cb);
#undef cb

//! 1 0 -1 0 d00 d01 d02 d03 1 0 0 0
//! 0 1 1 0 d10 d11 d12 d13 0 1 -1 -1
//! 0 -1 1 0 d20 d21 d22 d23 -1 1 1 0
//! 0 -1 0 1 d30 d31 d32 d33 0 0 0 1
#define cb(m) \
auto t0##m = d0##m - d2##m; \
auto t1##m = d1##m + d2##m; \
auto t2##m = d2##m - d1##m; \
auto t3##m = d3##m - d1##m;

UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb

#define cb(m) \
d##m##0 = t##m##0 - t##m##2; \
d##m##1 = t##m##1 + t##m##2; \
d##m##2 = t##m##2 - t##m##1; \
d##m##3 = t##m##3 - t##m##1;

UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb

#define cb(m, n) \
d##m##n.save(input_transform_buf + \
(m * alpha + n) * nr_units_in_tile * IC + unit_idx * IC + \
ic);
UNROLL_CALL_NOWRAPPER_D2(4, 4, cb)
#undef cb
}
};

template <BiasMode bmode, typename Op>
struct OutputTransform2X3 {
static void transform(const dt_float16* output_transform_buf,
const dt_float16* bias, dt_float16* output,
dt_float16* transform_mid_buf, size_t oh_start,
size_t ow_start, size_t OH, size_t OW,
size_t oc_start, size_t oc_end, size_t oc_index,
size_t unit_idx, size_t nr_units_in_tile,
const DType& src_dtype, const DType& dst_dtype) {
Op op(src_dtype, dst_dtype);
const __fp16* output_transform_ptr =
reinterpret_cast<const __fp16*>(output_transform_buf);
const __fp16* bias_ptr = reinterpret_cast<const __fp16*>(bias);
__fp16* output_ptr = reinterpret_cast<__fp16*>(output);
__fp16* transform_mid_ptr =
reinterpret_cast<__fp16*>(transform_mid_buf);

//! AT * m * A
constexpr size_t alpha = 2 + 3 - 1;
size_t OC = oc_end - oc_start;
size_t oc = oc_start + oc_index;

#define cb(m, n) \
auto v##m##n = Vector<__fp16, 4>::load( \
output_transform_ptr + (m * alpha + n) * nr_units_in_tile * OC + \
unit_idx * OC + oc_index);
UNROLL_CALL_NOWRAPPER_D2(4, 4, cb);
#undef cb
//! 1 1 1 0 v00 v01 v02 v03 1 0
//! 0 1 -1 1 v10 v11 v12 v13 1 1
//! v20 v21 v22 v23 1 -1
//! v30 v31 v32 v33 0 1
#define cb(m) \
auto t0##m = v0##m + v1##m + v2##m; \
auto t1##m = v1##m - v2##m + v3##m;

UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb
v00 = t00 + t01 + t02;
v10 = t10 + t11 + t12;
v01 = t01 - t02 + t03;
v11 = t11 - t12 + t13;

Vector<__fp16, 4> vbias;
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
vbias = Vector<__fp16, 4>::load(bias_ptr + oc);

v00 += vbias;
v10 += vbias;
v01 += vbias;
v11 += vbias;
}
float16x8_t result01, result23;
result01 = vcombine_f16(v00.value, v01.value);
result23 = vcombine_f16(v10.value, v11.value);
if (bmode != BiasMode::BIAS) {
result01 = op(result01);
result23 = op(result23);
}
vst1q_f16(transform_mid_ptr, result01);
vst1q_f16(transform_mid_ptr + 8, result23);

for (size_t oco = 0; oco < 4 && oc + oco < oc_end; ++oco) {
for (size_t oho = 0; oho < 2 && oh_start + oho < OH; ++oho) {
for (size_t owo = 0; owo < 2 && ow_start + owo < OW; ++owo) {
size_t oh = oh_start + oho;
size_t ow = ow_start + owo;
__fp16 res = transform_mid_ptr[oho * 2 * 4 + owo * 4 + oco];
if (bmode == BiasMode::BIAS) {
res += bias_ptr[(oc + oco) * OH * OW + oh * OW + ow];
res = op(res);
}
output_ptr[(oc + oco) * OH * OW + oh * OW + ow] = res;
}
}
}
}
};
} // namespace

namespace megdnn {
namespace arm_common {
namespace winograd {

MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_2x3_4x4_f16)

void winograd_2x3_4x4_f16::filter(const dt_float16* filter,
dt_float16* filter_transform_buf,
dt_float16* transform_mid_buf, size_t OC,
size_t IC, size_t oc_start, size_t oc_end) {
constexpr int alpha = 2 + 3 - 1;
//! G * g * GT
__fp16* filter_transbuf_ptr =
reinterpret_cast<__fp16*>(filter_transform_buf);
__fp16* filter_transmid_ptr = reinterpret_cast<__fp16*>(transform_mid_buf);

for (size_t oc = oc_start; oc < oc_end; oc++) {
rep(ic, IC) {
const __fp16* filter_ptr = reinterpret_cast<const __fp16*>(filter) +
(oc * IC + ic) * 3 * 3;
/**
* origin: (4x3) * (3 x 3) * (3 x 4)
* pack to G and g to times of 4
* now: (4x4) * (4 x 4) * (4 x 4)
*/
//! 1 0 0 0 v00 v01 v02 0 1 0.5 0.5 0
//! 0.5 0.5 0.5 0 v10 v11 v12 0 0 0.5 -0.5 0
//! 0.5 -0.5 0.5 0 v20 v21 v22 0 0 0.5 0.5 1
//! 0 0 1 0 0 0 0 0 0 0 0 0
float16x4_t v0 = vld1_f16(filter_ptr); // 0 1 2 3
float16x4_t v1 = vld1_f16(filter_ptr + 3); // 3 4 5 6
float16x4_t v2 = vld1_f16(filter_ptr + 5); // 5678
float16x4_t v3 = vdup_n_f16(0);
v2 = vext_f16(v2, v3, 1);
v0 = vset_lane_f16(0, v0, 3);
v1 = vset_lane_f16(0, v1, 3);
#define cb(i) float16x4_t vsum##i;
UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb
vsum0 = v0;
float16x4_t v0addv2 = vadd_f16(v0, v2);
float16x4_t v02addv1 = vadd_f16(v0addv2, v1);
float16x4_t v02subv1 = vsub_f16(v0addv2, v1);
vsum1 = vmul_n_f16(v02addv1, 0.5);
vsum2 = vmul_n_f16(v02subv1, 0.5);
vsum3 = v2;

#define cb(i) \
do { \
mid_buf1[0] = vget_lane_f16(vsum##i, 0); \
__fp16 a0a2 = vget_lane_f16(vsum##i, 0) + vget_lane_f16(vsum##i, 2); \
__fp16 a0a2adda1 = a0a2 + vget_lane_f16(vsum##i, 1); \
__fp16 a0a2suba1 = a0a2 - vget_lane_f16(vsum##i, 1); \
mid_buf1[1] = a0a2adda1 * 0.5; \
mid_buf1[2] = a0a2suba1 * 0.5; \
mid_buf1[3] = vget_lane_f16(vsum##i, 2); \
mid_buf1 += 4; \
} while (0);

__fp16* mid_buf1 = filter_transmid_ptr;
UNROLL_CALL_NOWRAPPER(4, cb);
mid_buf1 = filter_transmid_ptr;
#undef cb
rep(i, alpha) rep(j, alpha) {
filter_transbuf_ptr[(i * alpha + j) * OC * IC + ic * OC + oc] =
filter_transmid_ptr[i * alpha + j];
}
}
}
}

void winograd_2x3_4x4_f16::input(const dt_float16* input,
dt_float16* input_transform_buf,
dt_float16* transform_mid_buf, size_t IH,
size_t IW, size_t IC, size_t PH, size_t PW,
size_t unit_start_idx,
size_t nr_units_in_tile) {
megdnn_assert(IC % 4 == 0);
constexpr int alpha = 3 + 2 - 1;

// OW = IW + 2 * PW - KERNEL_SIZE + 1
auto units_w = div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE);
dt_float16* patch = transform_mid_buf;
dt_float16* patchT = transform_mid_buf + 4 * alpha * alpha;

for (size_t ic = 0; ic < IC; ic += 4) {
rep(unit_idx, nr_units_in_tile) {
size_t index = unit_start_idx + unit_idx;
size_t nh = index / units_w;
size_t nw = index % units_w;
int ih_start = nh * OUTPUT_BLOCK_SIZE - PH;
int iw_start = nw * OUTPUT_BLOCK_SIZE - PW;
if (ih_start >= 0 && ih_start + alpha <= static_cast<int>(IH) &&
iw_start >= 0 && iw_start + alpha <= static_cast<int>(IW)) {
InputTransform2X3::prepare<true>(
reinterpret_cast<const __fp16*>(input),
reinterpret_cast<__fp16*>(patch),
reinterpret_cast<__fp16*>(patchT), ih_start, iw_start,
IH, IW, ic, IC);
InputTransform2X3::transform(
reinterpret_cast<const __fp16*>(patchT),
reinterpret_cast<__fp16*>(input_transform_buf),
unit_idx, nr_units_in_tile, ic, IC);
} else {
InputTransform2X3::prepare<false>(
reinterpret_cast<const __fp16*>(input),
reinterpret_cast<__fp16*>(patch),
reinterpret_cast<__fp16*>(patchT), ih_start, iw_start,
IH, IW, ic, IC);
InputTransform2X3::transform(
reinterpret_cast<const __fp16*>(patchT),
reinterpret_cast<__fp16*>(input_transform_buf),
unit_idx, nr_units_in_tile, ic, IC);
}
}
}
}

void winograd_2x3_4x4_f16::output(const dt_float16* output_transform_buf,
const dt_float16* bias, dt_float16* output,
dt_float16* transform_mid_buf, BiasMode bmode,
NonlineMode nonline_mode, size_t OH, size_t OW,
size_t oc_start, size_t oc_end,
size_t unit_start_idx, size_t nr_units_in_tile) {
#define cb(_bmode, _nonline_op, ...) \
OutputTransform2X3<_bmode MEGDNN_COMMA _nonline_op>::transform(__VA_ARGS__);
auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE);

for (size_t oc = oc_start; oc < oc_end; oc += 4) {
size_t oc_index = oc - oc_start;
rep(unit_idx, nr_units_in_tile) {
size_t index = unit_start_idx + unit_idx;
auto nh = index / units_w;
auto nw = index % units_w;
size_t oh_start = nh * OUTPUT_BLOCK_SIZE;
size_t ow_start = nw * OUTPUT_BLOCK_SIZE;
DISPATCH_CONV_WINOGRAD_BIAS(
megdnn_arm_common_winograd_fp16_F23, cb, __fp16, __fp16, bmode,
nonline_mode, output_transform_buf, bias, output, transform_mid_buf,
oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx,
nr_units_in_tile, src_dtype, dst_dtype);
}
}
#undef cb
}

} // namespace winograd
} // namespace arm_common
} // namespace megdnn
#endif
// vim: syntax=cpp.doxygen

+ 407
- 0
dnn/src/arm_common/conv_bias/f16/strategy_2x3_8x8.cpp View File

@@ -0,0 +1,407 @@
/**
* \file dnn/src/arm_common/conv_bias/f16/strategy_2x3_8x8.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#include "src/fallback/conv_bias/winograd/winograd.h"
#include "src/naive/matrix_mul/matrix_mul_helper.h"

#include "src/arm_common/elemwise_helper/op_unary.h"
#include "src/arm_common/conv_bias/f16/strategy.h"
#include "src/arm_common/conv_bias/f16/helper.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/utils.h"

#include "src/common/winograd/winograd_generator.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"
#include "midout.h"

#include "src/common/winograd/winograd_helper.h"

MIDOUT_DECL(megdnn_arm_common_winograd_f16_F23_8x8)

using namespace megdnn;
using namespace arm_common;
namespace {
void transpose_8x4(const __fp16* src, __fp16* dst, int lda, int ldb) {
float16x4x2_t a0, a1, a2, a3;
a0.val[0] = vld1_f16(src + 0 * lda);
a0.val[1] = vld1_f16(src + 1 * lda);
a1.val[0] = vld1_f16(src + 2 * lda);
a1.val[1] = vld1_f16(src + 3 * lda);
a2.val[0] = vld1_f16(src + 4 * lda);
a2.val[1] = vld1_f16(src + 5 * lda);
a3.val[0] = vld1_f16(src + 6 * lda);
a3.val[1] = vld1_f16(src + 7 * lda);
float16x4x2_t b0 = vzip_f16(a0.val[0], a1.val[0]);
float16x4x2_t b1 = vzip_f16(a0.val[1], a1.val[1]);
float16x4x2_t b2 = vzip_f16(a2.val[0], a3.val[0]);
float16x4x2_t b3 = vzip_f16(a2.val[1], a3.val[1]);

float16x4x2_t c0 = vzip_f16(b0.val[0], b1.val[0]);
float16x4x2_t c1 = vzip_f16(b0.val[1], b1.val[1]);
float16x4x2_t c2 = vzip_f16(b2.val[0], b3.val[0]);
float16x4x2_t c3 = vzip_f16(b2.val[1], b3.val[1]);

vst1_f16(dst + 0 * ldb, c0.val[0]);
vst1_f16(dst + 1 * ldb, c2.val[0]);
vst1_f16(dst + 2 * ldb, c0.val[1]);
vst1_f16(dst + 3 * ldb, c2.val[1]);
vst1_f16(dst + 4 * ldb, c1.val[0]);
vst1_f16(dst + 5 * ldb, c3.val[0]);
vst1_f16(dst + 6 * ldb, c1.val[1]);
vst1_f16(dst + 7 * ldb, c3.val[1]);
}

struct InputTransform2X3_8x8 {
template <bool inner>
static void prepare(const __fp16* input, __fp16* patch, __fp16* patchT,
int ih_start, int iw_start, size_t IH, size_t IW,
size_t ic, size_t IC) {
constexpr size_t alpha = 2 + 3 - 1;
if (!(inner && ic + 8 < IC)) {
memset(patch, 0, sizeof(__fp16) * 8 * alpha * alpha);
}
if (inner) {
const __fp16* input_ptr =
input + ic * IH * IW + ih_start * IW + iw_start;
for (size_t ico = 0; ico < 8; ++ico) {
if (ic + ico < IC) {
auto v0 = vld1_f16(input_ptr);
auto v1 = vld1_f16(input_ptr + IW);
auto v2 = vld1_f16(input_ptr + IW * 2);
auto v3 = vld1_f16(input_ptr + IW * 3);

vst1_f16(patch + ico * alpha * 4 + 0 * 4, v0);
vst1_f16(patch + ico * alpha * 4 + 1 * 4, v1);
vst1_f16(patch + ico * alpha * 4 + 2 * 4, v2);
vst1_f16(patch + ico * alpha * 4 + 3 * 4, v3);
input_ptr += IH * IW;
}
}
} else {
int ih0_act = std::max<int>(ih_start, 0),
ih1_act = std::min<int>(ih_start + alpha, IH),
iw0_act = std::max<int>(iw_start, 0),
iw1_act = std::min<int>(iw_start + alpha, IW);
// partial copy
for (size_t ico = 0; ico < 8; ++ico) {
if (ic + ico < IC) {
for (int ih = ih0_act; ih < ih1_act; ++ih) {
for (int iw = iw0_act; iw < iw1_act; ++iw) {
size_t iho = ih - ih_start, iwo = iw - iw_start;
patch[ico * alpha * alpha + iho * alpha + iwo] =

input[(ic + ico) * IH * IW + ih * IW + iw];
}
}
}
}
}
transpose_8x4(patch + 4 * 0, patchT + 32 * 0, 16, 4);
transpose_8x4(patch + 4 * 1, patchT + 32 * 1, 16, 4);
transpose_8x4(patch + 4 * 2, patchT + 32 * 2, 16, 4);
transpose_8x4(patch + 4 * 3, patchT + 32 * 3, 16, 4);
}

static void transform(const __fp16* patchT, __fp16* input_transform_buf,
size_t unit_idx, size_t nr_units_in_tile, size_t ic,
size_t IC) {
constexpr size_t alpha = 2 + 3 - 1;
// BT * d * B
#define cb(m, n) \
Vector<__fp16, 8> d##m##n = \
Vector<__fp16, 8>::load(patchT + 8 * (m * 4 + n));

UNROLL_CALL_NOWRAPPER_D2(4, 4, cb);
#undef cb

//! 1 0 -1 0 d00 d01 d02 d03 1 0 0 0
//! 0 1 1 0 d10 d11 d12 d13 0 1 -1 -1
//! 0 -1 1 0 d20 d21 d22 d23 -1 1 1 0
//! 0 -1 0 1 d30 d31 d32 d33 0 0 0 1
#define cb(m) \
auto t0##m = d0##m - d2##m; \
auto t1##m = d1##m + d2##m; \
auto t2##m = d2##m - d1##m; \
auto t3##m = d3##m - d1##m;

UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb

#define cb(m) \
d##m##0 = t##m##0 - t##m##2; \
d##m##1 = t##m##1 + t##m##2; \
d##m##2 = t##m##2 - t##m##1; \
d##m##3 = t##m##3 - t##m##1;

UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb

size_t ICB = IC / 8;
size_t icb = ic / 8;
#define cb(m, n) \
d##m##n.save(input_transform_buf + \
(m * alpha + n) * nr_units_in_tile * ICB * 8 + \
icb * nr_units_in_tile * 8 + unit_idx * 8);
UNROLL_CALL_NOWRAPPER_D2(4, 4, cb)
#undef cb
}
};

template <BiasMode bmode, typename Op>
struct OutputTransform2X3_8x8 {
static void transform(const dt_float16* output_transform_buf,
const dt_float16* bias, dt_float16* output,
dt_float16* transform_mid_buf, size_t oh_start,
size_t ow_start, size_t OH, size_t OW,
size_t oc_start, size_t oc_end, size_t oc_index,
size_t unit_idx, size_t nr_units_in_tile,
const DType& src_dtype, const DType& dst_dtype) {
Op op(src_dtype, dst_dtype);
const __fp16* output_transform_ptr =
reinterpret_cast<const __fp16*>(output_transform_buf);
const __fp16* bias_ptr = reinterpret_cast<const __fp16*>(bias);
__fp16* output_ptr = reinterpret_cast<__fp16*>(output);
__fp16* transform_mid_ptr =
reinterpret_cast<__fp16*>(transform_mid_buf);

//! AT * m * A
constexpr size_t alpha = 2 + 3 - 1;

size_t oc = oc_start + oc_index;
size_t OCB = (oc_end - oc_start) / 8;
size_t ocb = oc_index / 8;
#define cb(m, n) \
auto v##m##n = Vector<__fp16, 8>::load( \
output_transform_ptr + \
(m * alpha + n) * OCB * nr_units_in_tile * 8 + \
ocb * nr_units_in_tile * 8 + unit_idx * 8);
UNROLL_CALL_NOWRAPPER_D2(4, 4, cb);
#undef cb

//! 1 1 1 0 v00 v01 v02 v03 1 0
//! 0 1 -1 1 v10 v11 v12 v13 1 1
//! v20 v21 v22 v23 1 -1
//! v30 v31 v32 v33 0 1
#define cb(m) \
auto t0##m = v0##m + v1##m + v2##m; \
auto t1##m = v1##m - v2##m + v3##m;

UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb
v00 = t00 + t01 + t02;
v10 = t10 + t11 + t12;
v01 = t01 - t02 + t03;
v11 = t11 - t12 + t13;

if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
Vector<__fp16, 8> vbias;
vbias = Vector<__fp16, 8>::load(bias_ptr + oc);

v00 += vbias;
v10 += vbias;
v01 += vbias;
v11 += vbias;
}
if (bmode != BiasMode::BIAS) {
v00.value = op(v00.value);
v01.value = op(v01.value);
v10.value = op(v10.value);
v11.value = op(v11.value);
}

v00.save(transform_mid_ptr + (0 * 2 + 0) * 8);
v10.save(transform_mid_ptr + (1 * 2 + 0) * 8);
v01.save(transform_mid_ptr + (0 * 2 + 1) * 8);
v11.save(transform_mid_ptr + (1 * 2 + 1) * 8);

for (size_t oco = 0; oco < 8 && oc + oco < oc_end; ++oco) {
for (size_t oho = 0; oho < 2 && oh_start + oho < OH; ++oho) {
for (size_t owo = 0; owo < 2 && ow_start + owo < OW; ++owo) {
size_t oh = oh_start + oho;
size_t ow = ow_start + owo;
__fp16 res = transform_mid_ptr[oho * 2 * 8 + owo * 8 + oco];
if (bmode == BiasMode::BIAS) {
res += bias_ptr[(oc + oco) * OH * OW + oh * OW + ow];
res = op(res);
}
output_ptr[(oc + oco) * OH * OW + oh * OW + ow] = res;
}
}
}
}
};
} // namespace

namespace megdnn {
namespace arm_common {
namespace winograd {

MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_2x3_8x8_f16)

void winograd_2x3_8x8_f16::filter(const dt_float16* filter,
dt_float16* filter_transform_buf,
dt_float16* transform_mid_buf, size_t OC,
size_t IC, size_t oc_start, size_t oc_end) {
constexpr int alpha = 2 + 3 - 1;
//! G * g * GT
__fp16* filter_transbuf_ptr =
reinterpret_cast<__fp16*>(filter_transform_buf);
__fp16* filter_transmid_ptr = reinterpret_cast<__fp16*>(transform_mid_buf);

size_t OCB = OC / 8;
size_t ICB = IC / 8;
for (size_t oc = oc_start; oc < oc_end; oc++) {
rep(ic, IC) {
const __fp16* filter_ptr = reinterpret_cast<const __fp16*>(filter) +
(oc * IC + ic) * 3 * 3;
/**
* origin: (4x3) * (3 x 3) * (3 x 4)
* pack to G and g to times of 4
* now: (4x4) * (4 x 4) * (4 x 4)
*/
//! 1 0 0 0 v00 v01 v02 0 1 0.5 0.5 0
//! 0.5 0.5 0.5 0 v10 v11 v12 0 0 0.5 -0.5 0
//! 0.5 -0.5 0.5 0 v20 v21 v22 0 0 0.5 0.5 1
//! 0 0 1 0 0 0 0 0 0 0 0 0
float16x4_t v0 = vld1_f16(filter_ptr); // 0 1 2 3
float16x4_t v1 = vld1_f16(filter_ptr + 3); // 3 4 5 6
float16x4_t v2 = vld1_f16(filter_ptr + 5); // 5678
float16x4_t v3 = vdup_n_f16(0);
v2 = vext_f16(v2, v3, 1);
v0 = vset_lane_f16(0, v0, 3);
v1 = vset_lane_f16(0, v1, 3);
#define cb(i) float16x4_t vsum##i;
UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb
vsum0 = v0;
float16x4_t v0addv2 = vadd_f16(v0, v2);
float16x4_t v02addv1 = vadd_f16(v0addv2, v1);
float16x4_t v02subv1 = vsub_f16(v0addv2, v1);
vsum1 = vmul_n_f16(v02addv1, 0.5);
vsum2 = vmul_n_f16(v02subv1, 0.5);
vsum3 = v2;

#define cb(i) \
do { \
mid_buf1[0] = vget_lane_f16(vsum##i, 0); \
__fp16 a0a2 = vget_lane_f16(vsum##i, 0) + vget_lane_f16(vsum##i, 2); \
__fp16 a0a2adda1 = a0a2 + vget_lane_f16(vsum##i, 1); \
__fp16 a0a2suba1 = a0a2 - vget_lane_f16(vsum##i, 1); \
mid_buf1[1] = a0a2adda1 * 0.5; \
mid_buf1[2] = a0a2suba1 * 0.5; \
mid_buf1[3] = vget_lane_f16(vsum##i, 2); \
mid_buf1 += 4; \
} while (0);

__fp16* mid_buf1 = filter_transmid_ptr;
UNROLL_CALL_NOWRAPPER(4, cb);
mid_buf1 = filter_transmid_ptr;
#undef cb
size_t ocb = (oc) / 8;
size_t oc8 = (oc) % 8;
size_t icb = (ic) / 8;
size_t ic8 = (ic) % 8;
rep(i, alpha) rep(j, alpha) {
filter_transbuf_ptr[(i * alpha + j) * OCB * ICB * 8 * 8 +
ocb * ICB * 8 * 8 + icb * 8 * 8 + ic8 * 8 +
oc8] = filter_transmid_ptr[i * alpha + j];
}
}
}
}

void winograd_2x3_8x8_f16::input(const dt_float16* input,
dt_float16* input_transform_buf,
dt_float16* transform_mid_buf, size_t IH,
size_t IW, size_t IC, size_t PH, size_t PW,
size_t unit_start_idx, size_t nr_units_in_tile) {
megdnn_assert(IC % 8 == 0);
constexpr int alpha = 3 + 2 - 1;

// OW = IW + 2 * PW - KERNEL_SIZE + 1
auto units_w = div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE);
dt_float16* patch = transform_mid_buf;
dt_float16* patchT = transform_mid_buf + 8 * alpha * alpha;

for (size_t ic = 0; ic < IC; ic += 8) {
rep(unit_idx, nr_units_in_tile) {
size_t index = unit_start_idx + unit_idx;
size_t nh = index / units_w;
size_t nw = index % units_w;
int ih_start = nh * OUTPUT_BLOCK_SIZE - PH;
int iw_start = nw * OUTPUT_BLOCK_SIZE - PW;
if (ih_start >= 0 && ih_start + alpha <= static_cast<int>(IH) &&
iw_start >= 0 && iw_start + alpha <= static_cast<int>(IW)) {
InputTransform2X3_8x8::prepare<true>(
reinterpret_cast<const __fp16*>(input),
reinterpret_cast<__fp16*>(patch),
reinterpret_cast<__fp16*>(patchT), ih_start, iw_start,
IH, IW, ic, IC);
InputTransform2X3_8x8::transform(
reinterpret_cast<const __fp16*>(patchT),
reinterpret_cast<__fp16*>(input_transform_buf),
unit_idx, nr_units_in_tile, ic, IC);

} else {
InputTransform2X3_8x8::prepare<false>(
reinterpret_cast<const __fp16*>(input),
reinterpret_cast<__fp16*>(patch),
reinterpret_cast<__fp16*>(patchT), ih_start, iw_start,
IH, IW, ic, IC);
InputTransform2X3_8x8::transform(
reinterpret_cast<const __fp16*>(patchT),
reinterpret_cast<__fp16*>(input_transform_buf),
unit_idx, nr_units_in_tile, ic, IC);
}
}
}
}

void winograd_2x3_8x8_f16::output(const dt_float16* output_transform_buf,
const dt_float16* bias, dt_float16* output,
dt_float16* transform_mid_buf, BiasMode bmode,
NonlineMode nonline_mode, size_t OH,
size_t OW, size_t oc_start, size_t oc_end,
size_t unit_start_idx,
size_t nr_units_in_tile) {
#define cb(_bmode, _nonline_op, ...) \
OutputTransform2X3_8x8<_bmode MEGDNN_COMMA _nonline_op>::transform( \
__VA_ARGS__);

auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE);

for (size_t oc = oc_start; oc < oc_end; oc += 8) {
size_t oc_index = oc - oc_start;
rep(unit_idx, nr_units_in_tile) {
size_t index = unit_start_idx + unit_idx;
auto nh = index / units_w;
auto nw = index % units_w;
size_t oh_start = nh * OUTPUT_BLOCK_SIZE;
size_t ow_start = nw * OUTPUT_BLOCK_SIZE;
DISPATCH_CONV_WINOGRAD_BIAS(
megdnn_arm_common_winograd_fp16_F23_8x8, cb, __fp16, __fp16,
bmode, nonline_mode, output_transform_buf, bias, output,
transform_mid_buf, oh_start, ow_start, OH, OW, oc_start,
oc_end, oc_index, unit_idx, nr_units_in_tile, src_dtype, dst_dtype);
}
}
#undef cb
}

} // namespace winograd
} // namespace arm_common
} // namespace megdnn
#endif

// vim: syntax=cpp.doxygen

+ 488
- 0
dnn/src/arm_common/conv_bias/f16/strategy_4x5.cpp View File

@@ -0,0 +1,488 @@
/**
* \file dnn/src/arm_common/conv_bias/f16/strategy_4x5.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/arm_common/conv_bias/f16/helper.h"
#include "src/arm_common/conv_bias/f16/strategy.h"
#include "src/arm_common/elemwise_helper/op_unary.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/utils.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/winograd/winograd.h"
#include "src/naive/matrix_mul/matrix_mul_helper.h"

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#include "midout.h"
MIDOUT_DECL(megdnn_arm_common_winograd_fp16_F45)
using namespace megdnn;
using namespace arm_common;
namespace {

struct FilterTransform4X5 {
#define FILTER_TRANSFORM(d, wd) \
do { \
wd##0 = d##0; \
wd##r0 = d##r0; \
wd##1 = (d##0 + d##1 + d##2 + d##3 + d##4) * -0.222168; \
wd##r1 = (d##r0 + d##r1 + d##r2 + d##r3 + d##r4) * -0.222168; \
wd##2 = (d##0 - d##1 + d##2 - d##3 + d##4) * -0.222168; \
wd##r2 = (d##r0 - d##r1 + d##r2 - d##r3 + d##r4) * -0.222168; \
auto tmpd0 = d##0 * 0.710938; \
auto tmpd1 = d##1 * 0.355469; \
auto tmpd2 = d##2 * 0.177734; \
auto tmpd3 = d##3 * 0.088867; \
auto tmpd4 = d##4 * 0.044434; \
auto tmpdr0 = d##r0 * 0.710938; \
auto tmpdr1 = d##r1 * 0.355469; \
auto tmpdr2 = d##r2 * 0.177734; \
auto tmpdr3 = d##r3 * 0.088867; \
auto tmpdr4 = d##r4 * 0.044434; \
wd##3 = tmpd0 + tmpd1 + tmpd2 + tmpd3 + tmpd4; \
wd##r3 = tmpdr0 + tmpdr1 + tmpdr2 + tmpdr3 + tmpdr4; \
wd##4 = tmpd0 - tmpd1 + tmpd2 - tmpd3 + tmpd4; \
wd##r4 = tmpdr0 - tmpdr1 + tmpdr2 - tmpdr3 + tmpdr4; \
tmpd0 = d##0 * 0.011108; \
tmpd1 = d##1 * 0.022217; \
tmpd2 = d##2 * 0.044434; \
tmpd3 = d##3 * 0.088867; \
tmpd4 = d##4 * 0.177734; \
tmpdr0 = d##r0 * 0.011108; \
tmpdr1 = d##r1 * 0.022217; \
tmpdr2 = d##r2 * 0.044434; \
tmpdr3 = d##r3 * 0.088867; \
tmpdr4 = d##r4 * 0.177734; \
wd##5 = tmpd0 + tmpd1 + tmpd2 + tmpd3 + tmpd4; \
; \
wd##r5 = tmpdr0 + tmpdr1 + tmpdr2 + tmpdr3 + tmpdr4; \
; \
wd##6 = tmpd0 - tmpd1 + tmpd2 - tmpd3 + tmpd4; \
; \
wd##r6 = tmpdr0 - tmpdr1 + tmpdr2 - tmpdr3 + tmpdr4; \
; \
wd##7 = d##4; \
wd##r7 = d##r4; \
} while (0);

#define FILTER_TRANSFORM_FINAL(d, wd) \
do { \
wd##0 = d##0; \
wd##1 = (d##0 + d##1 + d##2 + d##3 + d##4) * -0.222168; \
wd##2 = (d##0 - d##1 + d##2 - d##3 + d##4) * -0.222168; \
auto tmp0 = d##0 * 0.710938 + d##2 * 0.177734 + d##4 * 0.044434; \
auto tmp1 = d##1 * 0.355469 + d##3 * 0.088867; \
wd##3 = tmp0 + tmp1; \
wd##4 = tmp0 - tmp1; \
tmp0 = d##0 * 0.011108 + d##2 * 0.044434 + d##4 * 0.177734; \
tmp1 = d##1 * 0.022217 + d##3 * 0.088867; \
wd##5 = tmp0 + tmp1; \
wd##6 = tmp0 - tmp1; \
wd##7 = d##4; \
} while (0);
static void transform(const __fp16* filter, __fp16* filter_transform_buf,
__fp16* transform_mid_buf, size_t OC, size_t IC,
size_t oc_start, size_t oc_end) {
// Gg * GT
// G
//[[ 1. 0. 0. 0. 0. ]
// [-0.2222222 -0.2222222 -0.2222222 -0.2222222 -0.2222222]
// [-0.2222222 0.2222222 -0.2222222 0.2222222 -0.2222222]
// [ 0.7111111 0.3555556 0.1777778 0.0888889 0.0444444]
// [ 0.7111111 -0.3555556 0.1777778 -0.0888889 0.0444444]
// [ 0.0111111 0.0222222 0.0444444 0.0888889 0.1777778]
// [ 0.0111111 -0.0222222 0.0444444 -0.0888889 0.1777778]
// [ 0. 0. 0. 0. 1. ]]
constexpr size_t alpha = 4 + 5 - 1;
for (size_t oc = oc_start; oc < oc_end; oc++) {
rep(ic, IC) {
const __fp16* fptr = filter + (oc * IC + ic) * 5 * 5;

#define cb(i) Vector<__fp16, 4> g##i = Vector<__fp16, 4>::load(fptr + 5 * i);
UNROLL_CALL_NOWRAPPER(5, cb);
#undef cb

#define cb(i) __fp16 gr##i = *(fptr + 5 * i + 4);
UNROLL_CALL_NOWRAPPER(5, cb);

#undef cb
#define cb(i) Vector<__fp16, 4> Gg##i;
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb

#define cb(i) __fp16 Ggr##i;
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb

#define cb(i) Vector<__fp16, 8> Ggt##i;
UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb

#define cb(i) Vector<__fp16, 8> result##i;
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
FILTER_TRANSFORM(g, Gg)
#if MEGDNN_AARCH64
float16x8_t vgr = {Ggr0, Ggr1, Ggr2, Ggr3,
Ggr4, Ggr5, Ggr6, Ggr7};
Vector<__fp16, 8> Ggt4(vgr);
TRANSPOSE_8x4(Gg, Ggt);
FILTER_TRANSFORM_FINAL(Ggt, result);
#define cb(i) result##i.save(transform_mid_buf + i * alpha);
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
rep(i, alpha) rep(j, alpha) {
filter_transform_buf[(i * alpha + j) * OC * IC + ic * OC +
oc] = transform_mid_buf[j * alpha + i];
}
#else

#define GET_VECTOR_FP16D_ELEM(s, i, idx) vget_lane_f16(CONCAT(s, i).value, idx)

#define cb(i) \
do { \
mid_buf1[0] = GET_VECTOR_FP16D_ELEM(Gg, i, 0); \
auto tmp024 = GET_VECTOR_FP16D_ELEM(Gg, i, 0) + \
GET_VECTOR_FP16D_ELEM(Gg, i, 2) + Ggr##i; \
auto tmp13 = GET_VECTOR_FP16D_ELEM(Gg, i, 1) + \
GET_VECTOR_FP16D_ELEM(Gg, i, 3); \
mid_buf1[1] = (tmp024 + tmp13) * -0.2222222; \
mid_buf1[2] = (tmp024 - tmp13) * -0.2222222; \
auto tmp0 = GET_VECTOR_FP16D_ELEM(Gg, i, 0) * 0.7111111; \
auto tmp1 = GET_VECTOR_FP16D_ELEM(Gg, i, 1) * 0.3555556; \
auto tmp2 = GET_VECTOR_FP16D_ELEM(Gg, i, 2) * 0.1777778; \
auto tmp3 = GET_VECTOR_FP16D_ELEM(Gg, i, 3) * 0.0888889; \
auto tmp4 = Ggr##i * 0.0444444; \
tmp024 = tmp0 + tmp2 + tmp4; \
tmp13 = tmp1 + tmp3; \
mid_buf1[3] = tmp024 + tmp13; \
mid_buf1[4] = tmp024 - tmp13; \
tmp0 = GET_VECTOR_FP16D_ELEM(Gg, i, 0) * 0.0111111; \
tmp1 = GET_VECTOR_FP16D_ELEM(Gg, i, 1) * 0.0222222; \
tmp2 = GET_VECTOR_FP16D_ELEM(Gg, i, 2) * 0.0444444; \
tmp3 = GET_VECTOR_FP16D_ELEM(Gg, i, 3) * 0.0888889; \
tmp4 = Ggr##i * 0.1777778; \
tmp024 = tmp0 + tmp2 + tmp4; \
tmp13 = tmp1 + tmp3; \
mid_buf1[5] = tmp024 + tmp13; \
mid_buf1[6] = tmp024 - tmp13; \
mid_buf1[7] = Ggr##i; \
mid_buf1 += 8; \
} while (0);
__fp16* mid_buf1 = transform_mid_buf;
UNROLL_CALL_NOWRAPPER(8, cb);
mid_buf1 = transform_mid_buf;
#undef cb
#undef GET_VECTOR_FP16D_ELEM
rep(i, alpha) rep(j, alpha) {
filter_transform_buf[(i * alpha + j) * OC * IC + ic * OC +
oc] = transform_mid_buf[i * alpha + j];
}
#endif
}
}
}
};
#undef FILTER_TRANSFORM
#undef FILTER_TRANSFORM_FINAL

struct InputTransform4X5 {
#define INPUT_TRANSFORM(d, wd) \
do { \
wd##0 = (d##0 - d##6) + (d##4 - d##2) * 5.25f; \
auto tmp0 = d##2 - d##4 * 4.25f + d##6; \
auto tmp1 = d##1 - d##3 * 4.25f + d##5; \
wd##1 = tmp0 + tmp1; \
wd##2 = tmp0 - tmp1; \
tmp0 = d##2 * 4.0f - d##4 * 5.0f + d##6; \
tmp1 = d##1 * 2.0f - d##3 * 2.5f + d##5 * 0.5f; \
wd##3 = tmp0 + tmp1; \
wd##4 = tmp0 - tmp1; \
tmp0 = d##2 * 0.25f - d##4 * 1.25f + d##6; \
tmp1 = d##1 * 0.5f - d##3 * 2.5f + d##5 * 2.0f; \
wd##5 = tmp0 + tmp1; \
wd##6 = tmp0 - tmp1; \
wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f; \
} while (0)

#define GET_VECTOR_FP16Q_ELEM(s, i, idx) vgetq_lane_f16(CONCAT(s, i).value, idx)

template <bool inner>
static void transform(const __fp16* input, __fp16* input_transform_buf,
__fp16* transform_mid_buf, int ih_start, int iw_start,
size_t ic, size_t IH, size_t IW, size_t IC,
size_t unit_idx, size_t nr_units_in_tile) {
// BTd * B
//([[ 1. , 0. , -5.25, 0. , 5.25, 0. , -1. , 0. ],
// [ 0. , 1. , 1. , -4.25, -4.25, 1. , 1. , 0. ],
// [ 0. , -1. , 1. , 4.25, -4.25, -1. , 1. , 0. ],
// [ 0. , 2. , 4. , -2.5 , -5. , 0.5 , 1. , 0. ],
// [ 0. , -2. , 4. , 2.5 , -5. , -0.5 , 1. , 0. ],
// [ 0. , 0.5 , 0.25, -2.5 , -1.25, 2. , 1. , 0. ],
// [ 0. , -0.5 , 0.25, 2.5 , -1.25, -2. , 1. , 0. ],
// [ 0. , -1. , 0. , 5.25, 0. , -5.25, 0. , 1. ]]))

constexpr size_t alpha = 4 + 5 - 1;
if (!inner) {
memset(transform_mid_buf, 0, sizeof(__fp16) * alpha * alpha);
}

#define cb(i) Vector<__fp16, 8> d##i;
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb

if (inner) {
const __fp16* input_ptr =
input + ic * IH * IW + ih_start * IW + iw_start;
#define cb(i) d##i = Vector<__fp16, 8>::load(input_ptr + IW * i);
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
} else {
int ih0_act = std::max<int>(ih_start, 0),
ih1_act = std::min<int>(ih_start + alpha, IH),
iw0_act = std::max<int>(iw_start, 0),
iw1_act = std::min<int>(iw_start + alpha, IW);
for (int ih = ih0_act; ih < ih1_act; ++ih) {
for (int iw = iw0_act; iw < iw1_act; ++iw) {
size_t iho = ih - ih_start, iwo = iw - iw_start;
transform_mid_buf[iho * alpha + iwo] =
input[ic * IH * IW + ih * IW + iw];
}
}
#define cb(i) d##i = Vector<__fp16, 8>::load(transform_mid_buf + alpha * i);
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
}

#define cb(i) Vector<__fp16, 8> wd##i, ret##i;
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb

INPUT_TRANSFORM(d, wd);
TRANSPOSE_8x8(wd, d);
INPUT_TRANSFORM(d, ret);

#define cb(i) ret##i.save(transform_mid_buf + i * alpha);
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb

rep(i, alpha) rep(j, alpha) {
input_transform_buf[(i * alpha + j) * nr_units_in_tile * IC +
unit_idx * IC + ic] =
transform_mid_buf[j * alpha + i];
}
}
};
#undef INPUT_TRANSFORM

#define OUTPUT_TRANSFORM(m, s) \
do { \
s##0 = m##0 + m##1 + m##2 + m##3 + m##4 + m##5 + m##6; \
s##1 = m##1 - m##2 + m##3 * 0.5 - m##4 * 0.5 + m##5 * 2.0 - \
m##6 * 2.0; \
s##2 = m##1 + m##2 + m##3 * 0.25 + m##4 * 0.25 + m##5 * 4.0 + \
m##6 * 4.0; \
s##3 = m##1 - m##2 + m##3 * 0.125 - m##4 * 0.125 + m##5 * 8.0 - \
m##6 * 8.0 + m##7; \
} while (0)
template <BiasMode bmode, typename Op>
struct OutputTransform4X5 {
static void transform(const dt_float16* output_transform_buf,
const dt_float16* bias, dt_float16* output,
dt_float16* transform_mid_buf, size_t oh_start,
size_t ow_start, size_t OH, size_t OW,
size_t oc_start, size_t oc_end, size_t oc_index,
size_t unit_idx, size_t nr_units_in_tile,
const DType& src_dtype, const DType& dst_dtype) {
Op op(src_dtype, dst_dtype);
//! AT * m * A
// AT f45
// 1.0 1.0 1.0 1.000 1.000 1.0 1.0 0.0
// 0.0 1.0 -1.0 0.500 -0.500 2.0 -2.0 0.0
// 0.0 1.0 1.0 0.250 0.250 4.0 4.0 0.0
// 0.0 1.0 -1.0 0.125 -0.125 8.0 -8.0 1.0
constexpr size_t alpha = 5 + 4 - 1;
const __fp16* fp16_output_transform_buf =
reinterpret_cast<const __fp16*>(output_transform_buf);
const __fp16* fp16_bias = reinterpret_cast<const __fp16*>(bias);
__fp16* fp16_output = reinterpret_cast<__fp16*>(output);
__fp16* fp16_transform_mid_buf =
reinterpret_cast<__fp16*>(transform_mid_buf);

__fp16* mid_buf1 = fp16_transform_mid_buf;

size_t OC = oc_end - oc_start;
size_t oc = oc_start + oc_index;

#define cb(m, n) \
fp16_transform_mid_buf[m * alpha + n] = \
fp16_output_transform_buf[(m * alpha + n) * nr_units_in_tile * \
OC + \
unit_idx * OC + oc_index];
UNROLL_CALL_NOWRAPPER_D2(8, 8, cb);
#undef cb

#define cb(i) \
auto m##i = Vector<__fp16, 8>::load(fp16_transform_mid_buf + alpha * i);
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
#define cb(i) Vector<__fp16, 8> s##i;
UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb
#define cb(i) Vector<__fp16, 4> st##i;
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
#define cb(i) Vector<__fp16, 4> result##i;
UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb

OUTPUT_TRANSFORM(m, s);
TRANSPOSE_4x8(s, st);
OUTPUT_TRANSFORM(st, result);
TRANSPOSE_4x4(result, result);

if (oh_start + 4 <= OH && ow_start + 4 <= OW) {
int index = (oc * OH + oh_start) * OW + ow_start;
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
float16x4_t bias0 = vdup_n_f16(fp16_bias[oc]);
result0.value = vadd_f16(result0.value, bias0);
result1.value = vadd_f16(result1.value, bias0);
result2.value = vadd_f16(result2.value, bias0);
result3.value = vadd_f16(result3.value, bias0);
} else if (bmode == BiasMode::BIAS) {
float16x4_t bmbias0 = vld1_f16(fp16_bias + index);
float16x4_t bmbias1 = vld1_f16(fp16_bias + index + OW);
float16x4_t bmbias2 = vld1_f16(fp16_bias + index + OW * 2);
float16x4_t bmbias3 = vld1_f16(fp16_bias + index + OW * 3);
result0.value = vadd_f16(result0.value, bmbias0);
result1.value = vadd_f16(result1.value, bmbias1);
result2.value = vadd_f16(result2.value, bmbias2);
result3.value = vadd_f16(result3.value, bmbias3);
}

float16x8_t item01 = op(vcombine_f16(result0.value, result1.value));
float16x8_t item23 = op(vcombine_f16(result2.value, result3.value));

vst1_f16(fp16_output + index, vget_low_f16(item01));
vst1_f16(fp16_output + index + OW, vget_high_f16(item01));
vst1_f16(fp16_output + index + OW * 2, vget_low_f16(item23));
vst1_f16(fp16_output + index + OW * 3, vget_high_f16(item23));
} else {
#define cb(i) result##i.save(mid_buf1 + i * 4);
mid_buf1 = fp16_transform_mid_buf;
UNROLL_CALL_NOWRAPPER(4, cb);
mid_buf1 = fp16_transform_mid_buf;
#undef cb
for (size_t oho = 0; oho < 4 && oh_start + oho < OH; ++oho) {
for (size_t owo = 0; owo < 4 && ow_start + owo < OW; ++owo) {
size_t oh = oh_start + oho;
size_t ow = ow_start + owo;
__fp16 res = mid_buf1[oho * 4 + owo];
if (bmode == BiasMode::BIAS) {
res += fp16_bias[oc * OH * OW + oh * OW + ow];
} else if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
res += fp16_bias[oc];
}
res = op(res);
fp16_output[oc * OH * OW + oh * OW + ow] = res;
}
}
}
}
};
#undef OUTPUT_TRANSFORM
#undef GET_VECTOR_FP16Q_ELEM
} // namespace

namespace megdnn {
namespace arm_common {
namespace winograd {

MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_4x5_1x1_f16)

void winograd_4x5_1x1_f16::filter(const dt_float16* filter,
dt_float16* filter_transform_buf,
dt_float16* transform_mid_buf, size_t OC,
size_t IC, size_t oc_start, size_t oc_end) {
FilterTransform4X5::transform(
reinterpret_cast<const __fp16*>(filter),
reinterpret_cast<__fp16*>(filter_transform_buf),
reinterpret_cast<__fp16*>(transform_mid_buf), OC, IC, oc_start,
oc_end);
}

void winograd_4x5_1x1_f16::input(const dt_float16* input,
dt_float16* input_transform_buf,
dt_float16* transform_mid_buf, size_t IH,
size_t IW, size_t IC, size_t PH, size_t PW,
size_t unit_start_idx,
size_t nr_units_in_tile) {
constexpr int alpha = 4 + 5 - 1;
// OW = IW + 2 * PW - KERNEL_SIZE + 1
auto units_w = div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE);
rep(ic, IC) {
rep(unit_idx, nr_units_in_tile) {
size_t index = unit_start_idx + unit_idx;
size_t nh = index / units_w;
size_t nw = index % units_w;
int ih_start = nh * OUTPUT_BLOCK_SIZE - PH;
int iw_start = nw * OUTPUT_BLOCK_SIZE - PW;
if (ih_start >= 0 && ih_start + alpha <= static_cast<int>(IH) &&
iw_start >= 0 && iw_start + alpha <= static_cast<int>(IW)) {
InputTransform4X5::transform<true>(
reinterpret_cast<const __fp16*>(input),
reinterpret_cast<__fp16*>(input_transform_buf),
reinterpret_cast<__fp16*>(transform_mid_buf), ih_start,
iw_start, ic, IH, IW, IC, unit_idx, nr_units_in_tile);

} else {
InputTransform4X5::transform<false>(
reinterpret_cast<const __fp16*>(input),
reinterpret_cast<__fp16*>(input_transform_buf),
reinterpret_cast<__fp16*>(transform_mid_buf), ih_start,
iw_start, ic, IH, IW, IC, unit_idx, nr_units_in_tile);
}
}
}
}

void winograd_4x5_1x1_f16::output(const dt_float16* output_transform_buf,
const dt_float16* bias, dt_float16* output,
dt_float16* transform_mid_buf, BiasMode bmode,
NonlineMode nonline_mode, size_t OH, size_t OW,
size_t oc_start, size_t oc_end,
size_t unit_start_idx, size_t nr_units_in_tile) {
#define cb(_bmode, _nonline_op, ...) \
OutputTransform4X5<_bmode MEGDNN_COMMA _nonline_op>::transform(__VA_ARGS__);

auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE);

for (size_t oc = oc_start; oc < oc_end; oc++) {
size_t oc_index = oc - oc_start;
rep(unit_idx, nr_units_in_tile) {
size_t index = unit_start_idx + unit_idx;
auto nh = index / units_w;
auto nw = index % units_w;
size_t oh_start = nh * OUTPUT_BLOCK_SIZE;
size_t ow_start = nw * OUTPUT_BLOCK_SIZE;
DISPATCH_CONV_WINOGRAD_BIAS(
megdnn_arm_common_winograd_fp16_F45, cb, __fp16, __fp16, bmode,
nonline_mode, output_transform_buf, bias, output, transform_mid_buf,
oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx,
nr_units_in_tile, src_dtype, dst_dtype);
}
}
#undef cb
}

} // namespace winograd
} // namespace arm_common
} // namespace megdnn
#endif
// vim: syntax=cpp.doxygen

+ 608
- 0
dnn/src/arm_common/conv_bias/f16/strategy_6x3.cpp View File

@@ -0,0 +1,608 @@
/**
* \file dnn/src/arm_common/conv_bias/f16/strategy_6x3.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/arm_common/conv_bias/f16/helper.h"
#include "src/arm_common/conv_bias/f16/strategy.h"
#include "src/arm_common/elemwise_helper/op_unary.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/utils.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/winograd/winograd.h"
#include "src/naive/matrix_mul/matrix_mul_helper.h"
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#include "midout.h"
MIDOUT_DECL(megdnn_arm_common_winograd_fp16_F63)
using namespace megdnn;
using namespace arm_common;
namespace {
struct FilterTransform6X3 {
// 1.0000000 0.0000000 0.0000000
//-0.2222222 -0.2222222 -0.2222222
//-0.2222222 0.2222222 -0.2222222
// 0.0111111 0.0222222 0.0444444
// 0.0111111 -0.0222222 0.0444444
// 0.7111111 0.3555556 0.1777778
// 0.7111111 -0.3555556 0.1777778
// 0.0000000 0.0000000 1.0000000
#define FILTER_TRANSFORM(d, wd) \
do { \
wd##0 = d##0; \
wd##1 = (d##0 + d##1 + d##2) * -0.222168; \
wd##2 = (d##0 - d##1 + d##2) * -0.222168; \
auto tmpd0 = d##0 * 0.011108; \
auto tmpd1 = d##1 * 0.022217; \
auto tmpd2 = d##2 * 0.044434; \
wd##3 = tmpd0 + tmpd1 + tmpd2; \
wd##4 = tmpd0 - tmpd1 + tmpd2; \
tmpd0 = d##0 * 0.710938; \
tmpd1 = d##1 * 0.355469; \
tmpd2 = d##2 * 0.177734; \
wd##5 = tmpd0 + tmpd1 + tmpd2; \
wd##6 = tmpd0 - tmpd1 + tmpd2; \
wd##7 = d##2; \
} while (0);

#define FILTER_TRANSFORM_FINAL(d, wd) \
do { \
wd##0 = d##0; \
wd##1 = (d##0 + d##1 + d##2) * -0.222168; \
wd##2 = (d##0 - d##1 + d##2) * -0.222168; \
auto tmp0 = d##0 * 0.011108 + d##2 * 0.044434; \
auto tmp1 = d##1 * 0.022217; \
wd##3 = tmp0 + tmp1; \
wd##4 = tmp0 - tmp1; \
tmp0 = d##0 * 0.710938 + d##2 * 0.177734; \
tmp1 = d##1 * 0.355469; \
wd##5 = tmp0 + tmp1; \
wd##6 = tmp0 - tmp1; \
wd##7 = d##2; \
} while (0);
static void transform(const __fp16* filter, __fp16* filter_transform_buf,
__fp16* transform_mid_buf, size_t OC, size_t IC,
size_t oc_start, size_t oc_end) {
// Gg * GT
// G
// 1.0000000 0.0000000 0.0000000
//-0.2222222 -0.2222222 -0.2222222
//-0.2222222 0.2222222 -0.2222222
// 0.0111111 0.0222222 0.0444444
// 0.0111111 -0.0222222 0.0444444
// 0.7111111 0.3555556 0.1777778
// 0.7111111 -0.3555556 0.1777778
// 0.0000000 0.0000000 1.0000000
constexpr size_t alpha = 6 + 3 - 1;
for (size_t oc = oc_start; oc < oc_end; oc++) {
rep(ic, IC) {
const __fp16* fptr = filter + (oc * IC + ic) * 3 * 3;

float16x4_t v0, v1, v2, v3;
v0 = vld1_f16(fptr); // 0 1 2 3
v1 = vld1_f16(fptr + 3); // 3 4 5 6
v2 = vld1_f16(fptr + 5); // 5 6 7 8
v3 = vdup_n_f16(0);
v2 = vext_f16(v2, v3, 1);
v0 = vset_lane_f16(0, v0, 3);
v1 = vset_lane_f16(0, v1, 3);
#define cb(i) Vector<__fp16, 4> g##i(v##i);
UNROLL_CALL_NOWRAPPER(3, cb);
#undef cb

#define cb(i) Vector<__fp16, 4> Gg##i;
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb

#define cb(i) Vector<__fp16, 8> Ggt##i;
UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb

#define cb(i) Vector<__fp16, 8> result##i;
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
FILTER_TRANSFORM(g, Gg)
#if MEGDNN_AARCH64
TRANSPOSE_8x4(Gg, Ggt);
FILTER_TRANSFORM_FINAL(Ggt, result);
#define cb(i) result##i.save(transform_mid_buf + i * alpha);
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
rep(i, alpha) rep(j, alpha) {
filter_transform_buf[(i * alpha + j) * OC * IC + ic * OC +
oc] = transform_mid_buf[j * alpha + i];
}
#else
/* 1.0000000 -0.2222222 -0.2222222 0.0111111 0.0111111
0.7111111 0.7111111 0.0000000 0.0000000 -0.2222222
0.2222222 0.0222222 -0.0222222 0.3555556 -0.3555556
0.0000000 0.0000000 -0.2222222 -0.2222222 0.0444444
0.0444444 0.1777778 0.1777778 1.0000000*/

#define GET_VECTOR_FP16D_ELEM(s, i, idx) vget_lane_f16(CONCAT(s, i).value, idx)
#define cb(i) \
do { \
mid_buf1[0] = GET_VECTOR_FP16D_ELEM(Gg, i, 0); \
auto tmp02 = GET_VECTOR_FP16D_ELEM(Gg, i, 0) + \
GET_VECTOR_FP16D_ELEM(Gg, i, 2); \
mid_buf1[1] = (tmp02 + GET_VECTOR_FP16D_ELEM(Gg, i, 1)) * -0.2222222; \
mid_buf1[2] = (tmp02 - GET_VECTOR_FP16D_ELEM(Gg, i, 1)) * -0.2222222; \
auto tmp0 = GET_VECTOR_FP16D_ELEM(Gg, i, 0) * 0.0111111; \
auto tmp1 = GET_VECTOR_FP16D_ELEM(Gg, i, 1) * 0.0222222; \
auto tmp2 = GET_VECTOR_FP16D_ELEM(Gg, i, 2) * 0.0444444; \
tmp02 = tmp0 + tmp2; \
mid_buf1[3] = tmp02 + tmp1; \
mid_buf1[4] = tmp02 - tmp1; \
tmp0 = GET_VECTOR_FP16D_ELEM(Gg, i, 0) * 0.7111111; \
tmp1 = GET_VECTOR_FP16D_ELEM(Gg, i, 1) * 0.3555556; \
tmp2 = GET_VECTOR_FP16D_ELEM(Gg, i, 2) * 0.1777778; \
tmp02 = tmp0 + tmp2; \
mid_buf1[5] = tmp02 + tmp1; \
mid_buf1[6] = tmp02 - tmp1; \
mid_buf1[7] = GET_VECTOR_FP16D_ELEM(Gg, i, 2); \
mid_buf1 += 8; \
} while (0);
__fp16* mid_buf1 = transform_mid_buf;
UNROLL_CALL_NOWRAPPER(8, cb);
mid_buf1 = transform_mid_buf;
#undef cb
rep(i, alpha) rep(j, alpha) {
filter_transform_buf[(i * alpha + j) * OC * IC + ic * OC +
oc] = transform_mid_buf[i * alpha + j];
}
#undef GET_VECTOR_FP16D_ELEM
#endif
}
}
}
};
#undef FILTER_TRANSFORM
#undef FILTER_TRANSFORM_FINAL
/**
* input transform
*
* wd0 = (d0 - d6) + 5.25 * (d4 - d2)
* wd1 = (d6 + d2 - 4.25 * d4) + (d1 + d5 - 4.25 * d3)
* wd2 = (d6 + d2 - 4.25 * d4) - (d1 + d5 - 4.25 * d3)
* wd3 = (d6 + 0.25 * d2 - 1.25 * d4) + 2.0 * (d5 + 0.25 * d1 - 1.25 * d3)
* wd4 = (d6 + 0.25 * d2 - 1.25 * d4) - 2.0 * (d5 + 0.25 * d1 - 1.25 * d3)
* wd5 = (d6 - 5.0 * d4 + 4.0 * d2) + 2.0 * (d1 + 0.25 * d5 - 1.25 * d3)
* wd6 = (d6 - 5.0 * d4 + 4.0 * d2) - 2.0 * (d1 + 0.25 * d5 - 1.25 * d3)
* wd7 = (d7 - d1) + 5.25 * (d3 - d5)
*/
#define INPUT_TRANSFORM(d, wd) \
do { \
wd##0 = (d##0 - d##6) + (d##4 - d##2) * 5.25; \
auto tmp0 = d##6 + d##2 - d##4 * 4.25; \
auto tmp1 = d##1 + d##5 - d##3 * 4.25; \
wd##1 = tmp0 + tmp1; \
wd##2 = tmp0 - tmp1; \
tmp0 = d##6 + d##2 * 0.25 - d##4 * 1.25; \
tmp1 = (d##5 + d##1 * 0.25 - d##3 * 1.25) * 2.0; \
wd##3 = tmp0 + tmp1; \
wd##4 = tmp0 - tmp1; \
tmp0 = d6 - d4 * 5.0 + d2 * 4.0; \
tmp1 = (d1 + d5 * 0.25 - d3 * 1.25) * 2.0; \
wd##5 = tmp0 + tmp1; \
wd##6 = tmp0 - tmp1; \
wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25; \
} while (0);

#define GET_VECTOR_FP16Q_ELEM(s, i, idx) vgetq_lane_f16(CONCAT(s, i).value, idx)
struct InputTransform6x3 {
template <bool inner>
static void transform(const __fp16* input, __fp16* input_transform_buf,
__fp16* transform_mid_buf, int ih_start, int iw_start,
size_t ic, size_t IH, size_t IW, size_t IC,
size_t unit_idx, size_t nr_units_in_tile) {
// BTd * B
// 1.000 0.000 -5.25 0.000 5.250 0.000 -1.0 0.00
// -0.00 1.000 1.000 -4.25 -4.25 1.000 1.00 -0.0
// -0.00 -1.00 1.000 4.250 -4.25 -1.00 1.00 -0.0
// 0.000 0.500 0.250 -2.50 -1.25 2.000 1.00 0.00
// 0.000 -0.50 0.250 2.500 -1.25 -2.00 1.00 0.00
// 0.000 2.000 4.000 -2.50 -5.00 0.500 1.00 0.00
// 0.000 -2.00 4.000 2.500 -5.00 -0.50 1.00 0.00
// 0.000 -1.00 0.000 5.250 0.000 -5.25 0.00 1.00
constexpr size_t alpha = 6 + 3 - 1;
if (!inner) {
memset(transform_mid_buf, 0, sizeof(__fp16) * alpha * alpha);
}

#define cb(i) Vector<__fp16, 8> d##i;
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb

if (inner) {
const __fp16* input_ptr =
input + ic * IH * IW + ih_start * IW + iw_start;
#define cb(i) d##i = Vector<__fp16, 8>::load(input_ptr + IW * i);
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
} else {
int ih0_act = std::max<int>(ih_start, 0),
ih1_act = std::min<int>(ih_start + alpha, IH),
iw0_act = std::max<int>(iw_start, 0),
iw1_act = std::min<int>(iw_start + alpha, IW);
for (int ih = ih0_act; ih < ih1_act; ++ih) {
for (int iw = iw0_act; iw < iw1_act; ++iw) {
size_t iho = ih - ih_start, iwo = iw - iw_start;
transform_mid_buf[iho * alpha + iwo] =
input[ic * IH * IW + ih * IW + iw];
}
}
#define cb(i) d##i = Vector<__fp16, 8>::load(transform_mid_buf + alpha * i);
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
}

#define cb(i) Vector<__fp16, 8> wd##i, ret##i;
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb

INPUT_TRANSFORM(d, wd);

#if MEGDNN_AARCH64
TRANSPOSE_8x8(wd, d);
INPUT_TRANSFORM(d, ret);

#define cb(i) ret##i.save(transform_mid_buf + i * alpha);
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb

rep(i, alpha) rep(j, alpha) {
input_transform_buf[(i * alpha + j) * nr_units_in_tile * IC +
unit_idx * IC + ic] =
transform_mid_buf[j * alpha + i];
}
#else
//! 1 0 0 0 0 0 0 0
//! 0 1 -1 0.5 -0.5 2 -2 -1
//! -5.25 1 1 0.25 0.25 4 4 0
//! 0 -4.25 4.25 -2.5 2.5 -2.5 2.5 5.25
//! 5.25 -4.25 -4.25 -1.25 -1.25 -5 -5 0
//! 0 1 -1 2 -2 0.5 -0.5 -5.25
//! -1 1 1 1 1 1 1 0
//! 0 0 0 0 0 0 0 1
#define cb(i) \
do { \
mid_buf1[0] = GET_VECTOR_FP16Q_ELEM(wd, i, 0) - \
GET_VECTOR_FP16Q_ELEM(wd, i, 6) + \
5.25 * (GET_VECTOR_FP16Q_ELEM(wd, i, 4) - \
GET_VECTOR_FP16Q_ELEM(wd, i, 2)); \
mid_buf1[7] = GET_VECTOR_FP16Q_ELEM(wd, i, 7) - \
GET_VECTOR_FP16Q_ELEM(wd, i, 1) + \
5.25 * (GET_VECTOR_FP16Q_ELEM(wd, i, 3) - \
GET_VECTOR_FP16Q_ELEM(wd, i, 5)); \
auto tmp0 = GET_VECTOR_FP16Q_ELEM(wd, i, 2) + \
GET_VECTOR_FP16Q_ELEM(wd, i, 6) - \
4.25 * GET_VECTOR_FP16Q_ELEM(wd, i, 4); \
auto tmp1 = GET_VECTOR_FP16Q_ELEM(wd, i, 1) - \
GET_VECTOR_FP16Q_ELEM(wd, i, 3) * 4.25 + \
GET_VECTOR_FP16Q_ELEM(wd, i, 5); \
mid_buf1[1] = tmp0 + tmp1; \
mid_buf1[2] = tmp0 - tmp1; \
tmp0 = GET_VECTOR_FP16Q_ELEM(wd, i, 2) * 0.25 + \
GET_VECTOR_FP16Q_ELEM(wd, i, 6) - \
GET_VECTOR_FP16Q_ELEM(wd, i, 4) * 1.25; \
tmp1 = GET_VECTOR_FP16Q_ELEM(wd, i, 1) * 0.5 - \
GET_VECTOR_FP16Q_ELEM(wd, i, 3) * 2.5 + \
GET_VECTOR_FP16Q_ELEM(wd, i, 5) * 2; \
mid_buf1[3] = tmp0 + tmp1; \
mid_buf1[4] = tmp0 - tmp1; \
tmp0 = GET_VECTOR_FP16Q_ELEM(wd, i, 6) + \
GET_VECTOR_FP16Q_ELEM(wd, i, 2) * 4.0 - \
GET_VECTOR_FP16Q_ELEM(wd, i, 4) * 5.0; \
tmp1 = GET_VECTOR_FP16Q_ELEM(wd, i, 1) * 2 - \
GET_VECTOR_FP16Q_ELEM(wd, i, 3) * 2.5 + \
GET_VECTOR_FP16Q_ELEM(wd, i, 5) * 0.5; \
mid_buf1[5] = tmp0 + tmp1; \
mid_buf1[6] = tmp0 - tmp1; \
mid_buf1 += 8; \
} while (0);

__fp16* mid_buf1 = transform_mid_buf;
UNROLL_CALL_NOWRAPPER(8, cb);
mid_buf1 = transform_mid_buf;

#undef cb
rep(i, alpha) rep(j, alpha) {
input_transform_buf[(i * alpha + j) * nr_units_in_tile * IC +
unit_idx * IC + ic] =
transform_mid_buf[i * alpha + j];
}
#endif
}
};

#undef INPUT_TRANSFORM

#define OUTPUT_TRANSFORM(m, r) \
do { \
auto m1addm2 = m##1 + m##2; \
auto m1subm2 = m##1 - m##2; \
auto m3addm4 = m##3 + m##4; \
auto m3subm4 = m##3 - m##4; \
auto m5addm6 = m##5 + m##6; \
auto m5subm6 = m##5 - m##6; \
r##0 = m##0 + m1addm2 + m3addm4 + m5addm6; \
r##1 = m1subm2 + m3subm4 * 2.0 + m5subm6 * 0.5; \
r##2 = m1addm2 + m3addm4 * 4.0 + m5addm6 * 0.25; \
r##3 = m1subm2 + m3subm4 * 8.0 + m5subm6 * 0.125; \
r##4 = m1addm2 + m3addm4 * 16.0 + m5addm6 * 0.0625; \
r##5 = m1subm2 + m3subm4 * 32.0 + m5subm6 * 0.03125 + m##7; \
} while (0)
template <BiasMode bmode, typename Op>
struct OutputTransform6X3 {
static void transform(const dt_float16* output_transform_buf,
const dt_float16* bias, dt_float16* output,
dt_float16* transform_mid_buf, size_t oh_start,
size_t ow_start, size_t OH, size_t OW,
size_t oc_start, size_t oc_end, size_t oc_index,
size_t unit_idx, size_t nr_units_in_tile,
const DType& src_dtype, const DType& dst_dtype) {
Op op(src_dtype, dst_dtype);
//! AT * m * A
// AT f45
// 1.0 1.0 1.0 1.0 1.0 1.00000 1.00000 0.0
// 0.0 1.0 -1.0 2.0 -2.0 0.50000 -0.50000 0.0
// 0.0 1.0 1.0 4.0 4.0 0.25000 0.25000 0.0
// 0.0 1.0 -1.0 8.0 -8.0 0.12500 -0.12500 0.0
// 0.0 1.0 1.0 16.0 16.0 0.06250 0.06250 0.0
// 0.0 1.0 -1.0 32.0 -32.0 0.03125 -0.03125 1.0
constexpr size_t alpha = 3 + 6 - 1;
const __fp16* fp16_output_transform_buf =
reinterpret_cast<const __fp16*>(output_transform_buf);
const __fp16* fp16_bias = reinterpret_cast<const __fp16*>(bias);
__fp16* fp16_output = reinterpret_cast<__fp16*>(output);
__fp16* fp16_transform_mid_buf =
reinterpret_cast<__fp16*>(transform_mid_buf);

__fp16* mid_buf1 = fp16_transform_mid_buf;

size_t OC = oc_end - oc_start;
size_t oc = oc_start + oc_index;

#define cb(m, n) \
fp16_transform_mid_buf[m * alpha + n] = \
fp16_output_transform_buf[(m * alpha + n) * nr_units_in_tile * \
OC + \
unit_idx * OC + oc_index];
UNROLL_CALL_NOWRAPPER_D2(8, 8, cb);
#undef cb

#define cb(i) \
auto m##i = Vector<__fp16, 8>::load(fp16_transform_mid_buf + alpha * i);
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
#define cb(i) Vector<__fp16, 8> s##i;
UNROLL_CALL_NOWRAPPER(6, cb);
#undef cb
/* 1.0 0.0 0.00 0.000 0.0000 0.00000
1.0 1.0 1.00 1.000 1.0000 1.00000
1.0 -1.0 1.00 -1.000 1.0000 -1.00000
1.0 2.0 4.00 8.000 16.000 32.00000
1.0 -2.0 4.00 -8.000 16.000 -32.00000
1.0 0.5 0.25 0.125 0.0625 0.03125
1.0 -0.5 0.25 -0.125 0.0625 -0.03125
0.0 0.0 0.00 0.000 0.0000 1.00000*/

OUTPUT_TRANSFORM(m, s);
mid_buf1 = fp16_transform_mid_buf;

#define cb(i) \
do { \
auto m1addm2 = GET_VECTOR_FP16Q_ELEM(s, i, 1) + \
GET_VECTOR_FP16Q_ELEM(s, i, 2); \
auto m1subm2 = GET_VECTOR_FP16Q_ELEM(s, i, 1) - \
GET_VECTOR_FP16Q_ELEM(s, i, 2); \
auto m3addm4 = GET_VECTOR_FP16Q_ELEM(s, i, 3) + \
GET_VECTOR_FP16Q_ELEM(s, i, 4); \
auto m3subm4 = GET_VECTOR_FP16Q_ELEM(s, i, 3) - \
GET_VECTOR_FP16Q_ELEM(s, i, 4); \
auto m5addm6 = GET_VECTOR_FP16Q_ELEM(s, i, 5) + \
GET_VECTOR_FP16Q_ELEM(s, i, 6); \
auto m5subm6 = GET_VECTOR_FP16Q_ELEM(s, i, 5) - \
GET_VECTOR_FP16Q_ELEM(s, i, 6); \
mid_buf1[0] = \
GET_VECTOR_FP16Q_ELEM(s, i, 0) + m1addm2 + m3addm4 + m5addm6; \
mid_buf1[1] = m1subm2 + m3subm4 * 2 + m5subm6 * 0.5; \
mid_buf1[2] = m1addm2 + m3addm4 * 4 + m5addm6 * 0.25; \
mid_buf1[3] = m1subm2 + m3subm4 * 8 + m5subm6 * 0.125; \
mid_buf1[4] = m1addm2 + m3addm4 * 16 + m5addm6 * 0.0625; \
mid_buf1[5] = m1subm2 + m3subm4 * 32 + m5subm6 * 0.03125 + \
GET_VECTOR_FP16Q_ELEM(s, i, 7); \
mid_buf1 += 6; \
} while (0);
mid_buf1 = fp16_transform_mid_buf;
UNROLL_CALL_NOWRAPPER(6, cb);
mid_buf1 = fp16_transform_mid_buf;

#undef cb

if (oh_start + 6 <= OH && ow_start + 6 <= OW) {
int index = (oc * OH + oh_start) * OW + ow_start;

#define cb(i) float16x4_t vr##i = vld1_f16(mid_buf1 + i * 6);

UNROLL_CALL_NOWRAPPER(6, cb);
#undef cb
float16x8_t vr0123_45 = {mid_buf1[4], mid_buf1[5], mid_buf1[10],
mid_buf1[11], mid_buf1[16], mid_buf1[17],
mid_buf1[22], mid_buf1[23]};
float16x4_t vr45_45 = {mid_buf1[28], mid_buf1[29], mid_buf1[34],
mid_buf1[35]};

if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
float16x4_t bias0 = vdup_n_f16(fp16_bias[oc]);
#define cb(i) vr##i = vadd_f16(vr##i, bias0);

UNROLL_CALL_NOWRAPPER(6, cb);
#undef cb
vr45_45 = vadd_f16(vr45_45, bias0);
vr0123_45 = vaddq_f16(vr0123_45, vcombine_f16(bias0, bias0));

} else if (bmode == BiasMode::BIAS) {
#define cb(i) float16x4_t bmbias##i = vld1_f16(fp16_bias + index + OW * i);

UNROLL_CALL_NOWRAPPER(6, cb);
#undef cb

#define cb(i) vr##i = vadd_f16(vr##i, bmbias##i);

UNROLL_CALL_NOWRAPPER(6, cb);
#undef cb
float16x8_t vb0123_45 = {fp16_bias[index + 0 * OW + 4],
fp16_bias[index + 0 * OW + 5],
fp16_bias[index + 1 * OW + 4],
fp16_bias[index + 1 * OW + 5],
fp16_bias[index + 2 * OW + 4],
fp16_bias[index + 2 * OW + 5],
fp16_bias[index + 3 * OW + 4],
fp16_bias[index + 3 * OW + 5]};
float16x4_t vb45_45 = {fp16_bias[index + 4 * OW + 4],
fp16_bias[index + 4 * OW + 5],
fp16_bias[index + 5 * OW + 4],
fp16_bias[index + 5 * OW + 5]};
vr45_45 = vadd_f16(vr45_45, vb45_45);
vr0123_45 = vaddq_f16(vr0123_45, vb0123_45);
}

float16x8_t item01 = op(vcombine_f16(vr0, vr1));
float16x8_t item23 = op(vcombine_f16(vr2, vr3));
float16x8_t item45 = op(vcombine_f16(vr4, vr5));

vst1_f16(fp16_output + index, vget_low_f16(item01));
vst1_f16(fp16_output + index + OW, vget_high_f16(item01));
vst1_f16(fp16_output + index + OW * 2, vget_low_f16(item23));
vst1_f16(fp16_output + index + OW * 3, vget_high_f16(item23));
vst1_f16(fp16_output + index + OW * 4, vget_low_f16(item45));
vst1_f16(fp16_output + index + OW * 5, vget_high_f16(item45));
vr0123_45 = op(vr0123_45);
float16x8_t vr45 = op(vcombine_f16(vr45_45, vr45_45));

fp16_output[index + OW * 0 + 4] = vgetq_lane_f16(vr0123_45, 0);
fp16_output[index + OW * 0 + 5] = vgetq_lane_f16(vr0123_45, 1);
fp16_output[index + OW * 1 + 4] = vgetq_lane_f16(vr0123_45, 2);
fp16_output[index + OW * 1 + 5] = vgetq_lane_f16(vr0123_45, 3);
fp16_output[index + OW * 2 + 4] = vgetq_lane_f16(vr0123_45, 4);
fp16_output[index + OW * 2 + 5] = vgetq_lane_f16(vr0123_45, 5);
fp16_output[index + OW * 3 + 4] = vgetq_lane_f16(vr0123_45, 6);
fp16_output[index + OW * 3 + 5] = vgetq_lane_f16(vr0123_45, 7);
fp16_output[index + OW * 4 + 4] = vgetq_lane_f16(vr45, 0);
fp16_output[index + OW * 4 + 5] = vgetq_lane_f16(vr45, 1);
fp16_output[index + OW * 5 + 4] = vgetq_lane_f16(vr45, 2);
fp16_output[index + OW * 5 + 5] = vgetq_lane_f16(vr45, 3);
} else {
for (size_t oho = 0; oho < 6 && oh_start + oho < OH; ++oho) {
for (size_t owo = 0; owo < 6 && ow_start + owo < OW; ++owo) {
size_t oh = oh_start + oho;
size_t ow = ow_start + owo;
__fp16 res = mid_buf1[oho * 6 + owo];
if (bmode == BiasMode::BIAS) {
res += fp16_bias[oc * OH * OW + oh * OW + ow];
} else if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
res += fp16_bias[oc];
}
res = op(res);
fp16_output[oc * OH * OW + oh * OW + ow] = res;
}
}
}
}
};
#undef GET_VECTOR_FP16Q_ELEM
#undef OUTPUT_TRANSFORM
} // namespace

namespace megdnn {
namespace arm_common {
namespace winograd {

MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_6x3_1x1_f16)

void winograd_6x3_1x1_f16::filter(const dt_float16* filter,
dt_float16* filter_transform_buf,
dt_float16* transform_mid_buf, size_t OC,
size_t IC, size_t oc_start, size_t oc_end) {
FilterTransform6X3::transform(
reinterpret_cast<const __fp16*>(filter),
reinterpret_cast<__fp16*>(filter_transform_buf),
reinterpret_cast<__fp16*>(transform_mid_buf), OC, IC, oc_start,
oc_end);
}

void winograd_6x3_1x1_f16::input(const dt_float16* input,
dt_float16* input_transform_buf,
dt_float16* transform_mid_buf, size_t IH,
size_t IW, size_t IC, size_t PH, size_t PW,
size_t unit_start_idx,
size_t nr_units_in_tile) {
constexpr int alpha = 6 + 3 - 1;
// OW = IW + 2 * PW - KERNEL_SIZE + 1
auto units_w = div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE);
rep(ic, IC) {
rep(unit_idx, nr_units_in_tile) {
size_t index = unit_start_idx + unit_idx;
size_t nh = index / units_w;
size_t nw = index % units_w;
int ih_start = nh * OUTPUT_BLOCK_SIZE - PH;
int iw_start = nw * OUTPUT_BLOCK_SIZE - PW;
if (ih_start >= 0 && ih_start + alpha <= static_cast<int>(IH) &&
iw_start >= 0 && iw_start + alpha <= static_cast<int>(IW)) {
InputTransform6x3::transform<true>(
reinterpret_cast<const __fp16*>(input),
reinterpret_cast<__fp16*>(input_transform_buf),
reinterpret_cast<__fp16*>(transform_mid_buf), ih_start,
iw_start, ic, IH, IW, IC, unit_idx, nr_units_in_tile);

} else {
InputTransform6x3::transform<false>(
reinterpret_cast<const __fp16*>(input),
reinterpret_cast<__fp16*>(input_transform_buf),
reinterpret_cast<__fp16*>(transform_mid_buf), ih_start,
iw_start, ic, IH, IW, IC, unit_idx, nr_units_in_tile);
}
}
}
}

void winograd_6x3_1x1_f16::output(const dt_float16* output_transform_buf,
const dt_float16* bias, dt_float16* output,
dt_float16* transform_mid_buf, BiasMode bmode,
NonlineMode nonline_mode, size_t OH, size_t OW,
size_t oc_start, size_t oc_end,
size_t unit_start_idx, size_t nr_units_in_tile) {
#define cb(_bmode, _nonline_op, ...) \
OutputTransform6X3<_bmode MEGDNN_COMMA _nonline_op>::transform(__VA_ARGS__);

auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE);

for (size_t oc = oc_start; oc < oc_end; oc++) {
size_t oc_index = oc - oc_start;
rep(unit_idx, nr_units_in_tile) {
size_t index = unit_start_idx + unit_idx;
auto nh = index / units_w;
auto nw = index % units_w;
size_t oh_start = nh * OUTPUT_BLOCK_SIZE;
size_t ow_start = nw * OUTPUT_BLOCK_SIZE;
DISPATCH_CONV_WINOGRAD_BIAS(
megdnn_arm_common_winograd_fp16_F63, cb, __fp16, __fp16, bmode,
nonline_mode, output_transform_buf, bias, output, transform_mid_buf,
oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx,
nr_units_in_tile, src_dtype, dst_dtype);
}
}
#undef cb
}
} // namespace winograd
} // namespace arm_common
} // namespace megdnn
#endif
// vim: syntax=cpp.doxygen

+ 754
- 0
dnn/src/arm_common/conv_bias/fp32/algos.cpp View File

@@ -0,0 +1,754 @@
/**
* \file dnn/src/arm_common/conv_bias/fp32/algos.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/arm_common/conv_bias/fp32/algos.h"
#include "src/arm_common/conv_bias/direct/multi_thread_common.h"
#include "src/arm_common/conv_bias/fp32/direct.h"
#include "src/arm_common/conv_bias/fp32/do_conv_stride1.h"
#include "src/arm_common/conv_bias/fp32/do_conv_stride2.h"
#include "src/arm_common/conv_bias/fp32/strategy.h"
#include "src/arm_common/conv_bias/img2col_helper.h"
#include "src/arm_common/conv_bias/postprocess_helper.h"
#include "src/common/opr_delegate.h"
#include "src/fallback/conv_bias/common.h"

#include "midout.h"

MIDOUT_DECL(megdnn_arm_common_winograd_fp32)

using namespace megdnn;
using namespace arm_common;

/* ======================= AlgoFP32WinogradF23_4x4 ======================== */

bool ConvBiasImpl::AlgoFP32WinogradF23_4x4::usable(
fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
AlgoSelectionStrategy /*algo_selection_strategy*/) const {
MEGDNN_MARK_USED_VAR(opr);
MEGDNN_MARK_USED_VAR(param);
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 0, 0) {
if (param.filter_meta.icpg % 4 != 0 || param.filter_meta.ocpg % 4 != 0)
return false;
using Strategy = winograd::winograd_2x3_4x4_f;
Strategy strategy(param.src_type, param.filter_type, param.dst_type);
auto&& matmul_param =
megdnn::winograd::ConvBias<Strategy,
param::MatrixMul::Format::MK4>(
strategy, m_tile_size, param.nr_threads, param.osz[0],
param.osz[1], param.filter_meta.ocpg)
.get_matmul_kern_param(param);
return m_matmul_algo->usable(matmul_param) &&
(opr->param().format == param::ConvBias::Format::NCHW ||
(opr->param().format ==
param::ConvBias::Format::NCHW_WINOGRAD &&
opr->param().output_block_size == 2 &&
param.winograd_matmul_format ==
param::MatrixMul::Format::MK4)) &&
opr->param().mode == param::ConvBias::Mode::CROSS_CORRELATION &&
(param.filter_meta.spatial[0] == param.filter_meta.spatial[1] &&
param.filter_meta.spatial[0] == 3) &&
(param.filter_meta.stride[0] == param.filter_meta.stride[1] &&
param.filter_meta.stride[0] == 1) &&
(param.filter_meta.dilation[0] ==
param.filter_meta.dilation[1] &&
param.filter_meta.dilation[0] == 1) &&
param.compute_mode == param::ConvBias::ComputeMode::DEFAULT &&
param.src_type.enumv() == DTypeEnum::Float32;
}
MIDOUT_END();
return false;
}

size_t ConvBiasImpl::AlgoFP32WinogradF23_4x4::get_workspace(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
MEGDNN_MARK_USED_VAR(param);
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 0, 1) {
winograd::winograd_2x3_4x4_f strategy(param.src_type, param.filter_type,
param.dst_type);
return megdnn::winograd::ConvBias<winograd::winograd_2x3_4x4_f,
param::MatrixMul::Format::MK4>(
strategy, m_tile_size, param.nr_threads, param.osz[0],
param.osz[1], param.filter_meta.ocpg)
.get_workspace_size(param, m_matmul_algo);
}
MIDOUT_END();
return 0;
}

SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoFP32WinogradF23_4x4::dispatch_kerns(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
MEGDNN_MARK_USED_VAR(param);
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 0, 2) {
winograd::winograd_2x3_4x4_f strategy(param.src_type, param.filter_type,
param.dst_type);
auto winograd_impl =
megdnn::winograd::ConvBias<winograd::winograd_2x3_4x4_f,
param::MatrixMul::Format::MK4>(
strategy, m_tile_size, param.nr_threads, param.osz[0],
param.osz[1], param.filter_meta.ocpg);
return winograd_impl.get_kerns(param, m_matmul_algo);
}
MIDOUT_END();
return {};
}

/* ======================= AlgoFP32WinogradF63 ======================== */

bool ConvBiasImpl::AlgoFP32WinogradF63::usable(
fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
AlgoSelectionStrategy /*algo_selection_strategy*/) const {
MEGDNN_MARK_USED_VAR(param);
MEGDNN_MARK_USED_VAR(opr);
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 1, 0) {
using Strategy = winograd::winograd_6x3_1x1_f;
Strategy strategy(param.src_type, param.filter_type, param.dst_type);
auto&& matmul_param =
megdnn::winograd::ConvBias<Strategy>(
strategy, m_tile_size, param.nr_threads, param.osz[0],
param.osz[1], param.filter_meta.ocpg)
.get_matmul_kern_param(param);
return m_matmul_algo->usable(matmul_param) &&
(opr->param().format == param::ConvBias::Format::NCHW ||
(opr->param().format ==
param::ConvBias::Format::NCHW_WINOGRAD &&
opr->param().output_block_size == 6 &&
param.winograd_matmul_format ==
param::MatrixMul::Format::DEFAULT)) &&
opr->param().mode == param::ConvBias::Mode::CROSS_CORRELATION &&
(param.filter_meta.spatial[0] == param.filter_meta.spatial[1] &&
param.filter_meta.spatial[0] == 3) &&
(param.filter_meta.stride[0] == param.filter_meta.stride[1] &&
param.filter_meta.stride[0] == 1) &&
(param.filter_meta.dilation[0] ==
param.filter_meta.dilation[1] &&
param.filter_meta.dilation[0] == 1) &&
param.compute_mode == param::ConvBias::ComputeMode::DEFAULT &&
param.src_type.enumv() == DTypeEnum::Float32;
}
MIDOUT_END();
return false;
}

size_t ConvBiasImpl::AlgoFP32WinogradF63::get_workspace(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
MEGDNN_MARK_USED_VAR(param);
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 1, 1) {
winograd::winograd_6x3_1x1_f strategy(param.src_type, param.filter_type,
param.dst_type);
return megdnn::winograd::ConvBias<winograd::winograd_6x3_1x1_f>(
strategy, m_tile_size, param.nr_threads, param.osz[0],
param.osz[1], param.filter_meta.ocpg)
.get_workspace_size(param, m_matmul_algo);
}
MIDOUT_END();
return 0;
}

SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoFP32WinogradF63::dispatch_kerns(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
MEGDNN_MARK_USED_VAR(param);
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 1, 2) {
winograd::winograd_6x3_1x1_f strategy(param.src_type, param.filter_type,
param.dst_type);
auto winograd_impl =
megdnn::winograd::ConvBias<winograd::winograd_6x3_1x1_f>(
strategy, m_tile_size, param.nr_threads, param.osz[0],
param.osz[1], param.filter_meta.ocpg);
return winograd_impl.get_kerns(param, m_matmul_algo);
}
MIDOUT_END();
return {};
}

/* ======================= AlgoFP32WinogradF54 ======================== */

bool ConvBiasImpl::AlgoFP32WinogradF54::usable(
fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
AlgoSelectionStrategy /*algo_selection_strategy*/) const {
MEGDNN_MARK_USED_VAR(param);
MEGDNN_MARK_USED_VAR(opr);
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 2, 0) {
using Strategy = winograd::winograd_5x4_1x1_f;
Strategy strategy(param.src_type, param.filter_type, param.dst_type);
auto&& matmul_param =
megdnn::winograd::ConvBias<Strategy>(
strategy, m_tile_size, param.nr_threads, param.osz[0],
param.osz[1], param.filter_meta.ocpg)
.get_matmul_kern_param(param);
return m_matmul_algo->usable(matmul_param) &&
(opr->param().format == param::ConvBias::Format::NCHW ||
(opr->param().format ==
param::ConvBias::Format::NCHW_WINOGRAD &&
opr->param().output_block_size == 5 &&
param.winograd_matmul_format ==
param::MatrixMul::Format::DEFAULT)) &&
opr->param().mode == param::ConvBias::Mode::CROSS_CORRELATION &&
(param.filter_meta.spatial[0] == param.filter_meta.spatial[1] &&
param.filter_meta.spatial[0] == 4) &&
(param.filter_meta.stride[0] == param.filter_meta.stride[1] &&
param.filter_meta.stride[0] == 1) &&
(param.filter_meta.dilation[0] ==
param.filter_meta.dilation[1] &&
param.filter_meta.dilation[0] == 1) &&
param.compute_mode == param::ConvBias::ComputeMode::DEFAULT &&
param.src_type.enumv() == DTypeEnum::Float32;
}
MIDOUT_END();
return false;
}

size_t ConvBiasImpl::AlgoFP32WinogradF54::get_workspace(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
MEGDNN_MARK_USED_VAR(param);
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 2, 1) {
winograd::winograd_5x4_1x1_f strategy(param.src_type, param.filter_type,
param.dst_type);
return megdnn::winograd::ConvBias<winograd::winograd_5x4_1x1_f>(
strategy, m_tile_size, param.nr_threads, param.osz[0],
param.osz[1], param.filter_meta.ocpg)
.get_workspace_size(param, m_matmul_algo);
}
MIDOUT_END();
return 0;
}

SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoFP32WinogradF54::dispatch_kerns(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
MEGDNN_MARK_USED_VAR(param);
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 2, 2) {
winograd::winograd_5x4_1x1_f strategy(param.src_type, param.filter_type,
param.dst_type);
auto winograd_impl =
megdnn::winograd::ConvBias<winograd::winograd_5x4_1x1_f>(
strategy, m_tile_size, param.nr_threads, param.osz[0],
param.osz[1], param.filter_meta.ocpg);
return winograd_impl.get_kerns(param, m_matmul_algo);
}
MIDOUT_END();
return {};
}

/* ======================= AlgoFP32WinogradF45 ======================== */

bool ConvBiasImpl::AlgoFP32WinogradF45::usable(
fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
AlgoSelectionStrategy /*algo_selection_strategy*/) const {
MEGDNN_MARK_USED_VAR(param);
MEGDNN_MARK_USED_VAR(opr);
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 3, 0) {
using Strategy = winograd::winograd_4x5_1x1_f;
Strategy strategy(param.src_type, param.filter_type, param.dst_type);
auto&& matmul_param =
megdnn::winograd::ConvBias<Strategy>(
strategy, m_tile_size, param.nr_threads, param.osz[0],
param.osz[1], param.filter_meta.ocpg)
.get_matmul_kern_param(param);
return m_matmul_algo->usable(matmul_param) &&
(opr->param().format == param::ConvBias::Format::NCHW ||
(opr->param().format ==
param::ConvBias::Format::NCHW_WINOGRAD &&
opr->param().output_block_size == 4 &&
param.winograd_matmul_format ==
param::MatrixMul::Format::DEFAULT)) &&
opr->param().mode == param::ConvBias::Mode::CROSS_CORRELATION &&
(param.filter_meta.spatial[0] == param.filter_meta.spatial[1] &&
param.filter_meta.spatial[0] == 5) &&
(param.filter_meta.stride[0] == param.filter_meta.stride[1] &&
param.filter_meta.stride[0] == 1) &&
(param.filter_meta.dilation[0] ==
param.filter_meta.dilation[1] &&
param.filter_meta.dilation[0] == 1) &&
param.compute_mode == param::ConvBias::ComputeMode::DEFAULT &&
param.src_type.enumv() == DTypeEnum::Float32;
}
MIDOUT_END();
return false;
}

size_t ConvBiasImpl::AlgoFP32WinogradF45::get_workspace(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
MEGDNN_MARK_USED_VAR(param);
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 3, 1) {
winograd::winograd_4x5_1x1_f strategy(param.src_type, param.filter_type,
param.dst_type);
return megdnn::winograd::ConvBias<winograd::winograd_4x5_1x1_f>(
strategy, m_tile_size, param.nr_threads, param.osz[0],
param.osz[1], param.filter_meta.ocpg)
.get_workspace_size(param, m_matmul_algo);
}
MIDOUT_END();
return 0;
}

SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoFP32WinogradF45::dispatch_kerns(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
MEGDNN_MARK_USED_VAR(param);
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 3, 2) {
winograd::winograd_4x5_1x1_f strategy(param.src_type, param.filter_type,
param.dst_type);
auto winograd_impl =
megdnn::winograd::ConvBias<winograd::winograd_4x5_1x1_f>(
strategy, m_tile_size, param.nr_threads, param.osz[0],
param.osz[1], param.filter_meta.ocpg);
return winograd_impl.get_kerns(param, m_matmul_algo);
}
MIDOUT_END();
return {};
}

/* ======================= AlgoFP32WinogradF63_4x4 ======================== */

bool ConvBiasImpl::AlgoFP32WinogradF63_4x4::usable(
fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
AlgoSelectionStrategy /*algo_selection_strategy*/) const {
MEGDNN_MARK_USED_VAR(param);
MEGDNN_MARK_USED_VAR(opr);
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 4, 0) {
if (param.filter_meta.icpg % 4 != 0 || param.filter_meta.ocpg % 4 != 0)
return false;
using Strategy = winograd::winograd_6x3_4x4_f;
Strategy strategy(param.src_type, param.filter_type, param.dst_type);
auto&& matmul_param =
megdnn::winograd::ConvBias<Strategy,
param::MatrixMul::Format::MK4>(
strategy, m_tile_size, param.nr_threads, param.osz[0],
param.osz[1], param.filter_meta.ocpg)
.get_matmul_kern_param(param);
return m_matmul_algo->usable(matmul_param) &&
(opr->param().format == param::ConvBias::Format::NCHW ||
(opr->param().format ==
param::ConvBias::Format::NCHW_WINOGRAD &&
opr->param().output_block_size == 6 &&
param.winograd_matmul_format ==
param::MatrixMul::Format::MK4)) &&
opr->param().mode == param::ConvBias::Mode::CROSS_CORRELATION &&
(param.filter_meta.spatial[0] == param.filter_meta.spatial[1] &&
param.filter_meta.spatial[0] == 3) &&
(param.filter_meta.stride[0] == param.filter_meta.stride[1] &&
param.filter_meta.stride[0] == 1) &&
(param.filter_meta.dilation[0] ==
param.filter_meta.dilation[1] &&
param.filter_meta.dilation[0] == 1) &&
param.compute_mode == param::ConvBias::ComputeMode::DEFAULT &&
param.src_type.enumv() == DTypeEnum::Float32 &&
param.filter_meta.icpg % 4 == 0 &&
param.filter_meta.ocpg % 4 == 0;
}
MIDOUT_END();
return false;
}

size_t ConvBiasImpl::AlgoFP32WinogradF63_4x4::get_workspace(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
MEGDNN_MARK_USED_VAR(param);
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 4, 1) {
winograd::winograd_6x3_4x4_f strategy(param.src_type, param.filter_type,
param.dst_type);
return megdnn::winograd::ConvBias<winograd::winograd_6x3_4x4_f,
param::MatrixMul::Format::MK4>(
strategy, m_tile_size, param.nr_threads, param.osz[0],
param.osz[1], param.filter_meta.ocpg)
.get_workspace_size(param, m_matmul_algo);
}
MIDOUT_END();
return 0;
}

SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoFP32WinogradF63_4x4::dispatch_kerns(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
MEGDNN_MARK_USED_VAR(param);
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 4, 2) {
winograd::winograd_6x3_4x4_f strategy(param.src_type, param.filter_type,
param.dst_type);
auto winograd_impl =
megdnn::winograd::ConvBias<winograd::winograd_6x3_4x4_f,
param::MatrixMul::Format::MK4>(
strategy, m_tile_size, param.nr_threads, param.osz[0],
param.osz[1], param.filter_meta.ocpg);
return winograd_impl.get_kerns(param, m_matmul_algo);
}
MIDOUT_END();
return {};
}

/* ===================== direct algo ===================== */
MIDOUT_DECL(megdnn_arm_common_conv_bias_f32_kimpl);

bool ConvBiasImpl::AlgoF32Direct::usable(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 0, 0) {
auto&& fm = param.filter_meta;
auto FH = fm.spatial[0];
auto SH = fm.stride[0], SW = fm.stride[1];
// the condition ``param.isz[0]*param.isz[1] >= 4'' and
// ``param.osz[0]*param.osz[1] >= 4'' comes from the fact that the
// kernel may have access to up to 4 floats after the end of the memory
// chunk.
bool aviliable = fm.format == param::ConvBias::Format::NCHW &&
param.src_type.enumv() == DTypeEnum::Float32 &&
param.filter_type.enumv() == DTypeEnum::Float32 &&
param.dst_type.enumv() == DTypeEnum::Float32 &&
fm.spatial_ndim == 2 && fm.dilation[0] == 1 &&
fm.dilation[1] == 1 &&
param.isz[0] * param.isz[1] >= 4 &&
param.osz[0] * param.osz[1] >= 4 && FH <= 7 &&
SH == 1 && SW == 1;
if (algo_selection_strategy == AlgoSelectionStrategy::HEURISTIC) {
bool large_group = param.filter_meta.group >= param.nr_threads;
aviliable &= (large_group == m_large_group);
}
return aviliable;
}
MIDOUT_END();
return false;
}
size_t ConvBiasImpl::AlgoF32Direct::get_workspace(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 0, 1) {
auto wbundle = MultithreadDirectConvCommon<float, float>::get_bundle(
param, m_large_group);
return wbundle.total_size_in_bytes();
}
MIDOUT_END();
return 0;
}
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32Direct::get_kimpls(
const NCBKernSizeParam& param) const {
auto fm = param.filter_meta;
size_t N = param.n;
size_t IC = param.filter_meta.icpg;
size_t OC = param.filter_meta.ocpg;
size_t group = fm.group;
WorkspaceBundle wbundle =
MultithreadDirectConvCommon<float, float>::get_bundle(
param, m_large_group);
SmallVector<NCBKern> ret_kerns;
//! When group >= nr_threads, treat it as large_group, each thread process
//! one group for better performance
if (m_large_group) {
//! Channel wise conv and big groups
auto exec_one_group = [wbundle](const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
auto fm = kern_param.filter_meta;
size_t IC = fm.icpg;
size_t OC = fm.ocpg;
WorkspaceBundle bundle = wbundle;
if (fm.should_flip) {
for (size_t oc = 0; oc < OC; oc++) {
MultithreadDirectConvCommon<float, float>::weight_flip_kern(
bundle, kern_param, ncb_index,
{ncb_index.thread_id, 0, oc});
}
}
for (size_t ic = 0; ic < IC; ic++) {
MultithreadDirectConvCommon<float, float>::copy_padding_kern(
bundle, kern_param, ncb_index, {ncb_index.thread_id, 0, ic});
}
for (size_t oc = 0; oc < OC; oc++) {
MultithreadDirectConvCommon<float, float>::do_conv_kern(
bundle, kern_param, ncb_index,
fp32::conv_bias::kern_direct,
{ncb_index.thread_id, 0, oc});
}
};
ret_kerns.push_back({exec_one_group, {group, N, 1_z}});
} else {
WorkspaceBundle bundle = wbundle;
if (fm.should_flip) {
auto weight_flip = [bundle](const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
MultithreadDirectConvCommon<float, float>::weight_flip_kern(
bundle, kern_param, ncb_index, ncb_index.ndrange_id);
};
ret_kerns.push_back({weight_flip, {group, 1_z, OC}});
}
auto copy_padding = [bundle](const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
MultithreadDirectConvCommon<float, float>::copy_padding_kern(
bundle, kern_param, ncb_index, ncb_index.ndrange_id);
};
ret_kerns.push_back({copy_padding, {group, N, IC}});
auto do_conv = [bundle](const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
MultithreadDirectConvCommon<float, float>::do_conv_kern(
bundle, kern_param, ncb_index, fp32::conv_bias::kern_direct,
ncb_index.ndrange_id);
};
ret_kerns.push_back({do_conv, {group, N, OC}});
}
return ret_kerns;
}

SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32Direct::dispatch_kerns(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 0, 1) {
return get_kimpls(param);
}
MIDOUT_END();
return {};
}
/* ===================== stride-1 algo ===================== */
bool ConvBiasImpl::AlgoF32DirectStride1::usable(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 1, 1) {
auto&& fm = param.filter_meta;
auto FH = fm.spatial[0];
bool aviliable =
param.filter_meta.format == param::ConvBias::Format::NCHW &&
param.src_type.enumv() == DTypeEnum::Float32 &&
param.filter_type.enumv() == DTypeEnum::Float32 &&
param.dst_type.enumv() == DTypeEnum::Float32 &&
!fm.should_flip && fm.spatial_ndim == 2 &&
fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
fm.stride[0] == 1 && fm.stride[1] == 1 && FH == fm.spatial[1] &&
(FH == 2 || FH == 3 || FH == 5 || FH == 7);
if (algo_selection_strategy ==
ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) {
bool large_group = param.filter_meta.group >= param.nr_threads;
aviliable &= (large_group == m_large_group);
}
return aviliable;
}
MIDOUT_END();
return false;
}

size_t ConvBiasImpl::AlgoF32DirectStride1::get_workspace(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 1, 1) {
auto bundle =
MultithreadDirectConvCommon<float, float>::get_bundle_stride(
param, m_large_group);
return bundle.total_size_in_bytes();
}
MIDOUT_END();
return 0;
}

SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoF32DirectStride1::get_kimpls(
const NCBKernSizeParam& param) const {
auto fm = param.filter_meta;
auto FH = fm.spatial[0];
size_t N = param.n;
size_t IC = param.filter_meta.icpg;
size_t OC = param.filter_meta.ocpg;
size_t group = fm.group;
using Func = std::function<void(const float*, const float*, float*, size_t,
size_t, size_t, size_t, size_t)>;
Func conv_kern_function = nullptr;

#define SWITCH_KERN_STR1() \
switch (FH) { \
case 2: \
conv_kern_function = fp32::conv_stride1::do_conv_2x2_stride1; \
break; \
case 3: \
conv_kern_function = fp32::conv_stride1::do_conv_3x3_stride1; \
break; \
case 5: \
conv_kern_function = fp32::conv_stride1::do_conv_5x5_stride1; \
break; \
case 7: \
conv_kern_function = fp32::conv_stride1::do_conv_7x7_stride1; \
break; \
}
SWITCH_KERN_STR1();

WorkspaceBundle wbundle =
MultithreadDirectConvCommon<float, float>::get_bundle_stride(
param, m_large_group);
SmallVector<NCBKern> ret_kerns;
//! When group >= nr_threads, treat it as large_group, each thread process
//! one group for better performance
if (m_large_group) {
//! Channel wise conv and big groups
auto exec_one_group = [wbundle, conv_kern_function](
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
auto fm = kern_param.filter_meta;
size_t IC = fm.icpg;
size_t OC = fm.ocpg;
WorkspaceBundle bundle = wbundle;
for (size_t ic = 0; ic < IC; ic++) {
MultithreadDirectConvCommon<float, float>::
copy_padding_kern_stride(bundle, kern_param, ncb_index,
{ncb_index.thread_id, 0, ic});
}
for (size_t oc = 0; oc < OC; oc++) {
MultithreadDirectConvCommon<float, float>::do_conv_kern_stride(
bundle, kern_param, ncb_index, conv_kern_function,
{ncb_index.thread_id, 0, oc});
}
};
ret_kerns.push_back({exec_one_group, {group, N, 1_z}});
} else {
WorkspaceBundle bundle = wbundle;
auto copy_padding = [bundle](const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
MultithreadDirectConvCommon<float, float>::copy_padding_kern_stride(
bundle, kern_param, ncb_index, ncb_index.ndrange_id);
};
ret_kerns.push_back({copy_padding, {group, N, IC}});
auto do_conv = [bundle, conv_kern_function](
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
MultithreadDirectConvCommon<float, float>::do_conv_kern_stride(
bundle, kern_param, ncb_index, conv_kern_function,
ncb_index.ndrange_id);
};
ret_kerns.push_back({do_conv, {group, N, OC}});
}
return ret_kerns;
}

SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoF32DirectStride1::dispatch_kerns(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 1, 2) {
return get_kimpls(param);
}
MIDOUT_END();
return {};
}

/* ===================== stride-2 algo ===================== */

bool ConvBiasImpl::AlgoF32DirectStride2::usable(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 2, 0) {
auto&& fm = param.filter_meta;
auto FH = fm.spatial[0];
bool aviliable =
param.filter_meta.format == param::ConvBias::Format::NCHW &&
param.src_type.enumv() == DTypeEnum::Float32 &&
param.filter_type.enumv() == DTypeEnum::Float32 &&
param.dst_type.enumv() == DTypeEnum::Float32 &&
!fm.should_flip && fm.spatial_ndim == 2 &&
fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] &&
(FH == 2 || FH == 3 || FH == 5 || FH == 7);
if (algo_selection_strategy ==
ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) {
bool large_group = param.filter_meta.group >= param.nr_threads;
aviliable &= (large_group == m_large_group);
}
return aviliable;
}
MIDOUT_END();
return false;
}
size_t ConvBiasImpl::AlgoF32DirectStride2::get_workspace(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 2, 1) {
auto bundle =
MultithreadDirectConvCommon<float, float>::get_bundle_stride(
param, m_large_group);
return bundle.total_size_in_bytes();
}
MIDOUT_END();
return 0;
}
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoF32DirectStride2::get_kimpls(
const NCBKernSizeParam& param) const {
auto fm = param.filter_meta;
auto FH = fm.spatial[0];
size_t N = param.n;
size_t IC = param.filter_meta.icpg;
size_t OC = param.filter_meta.ocpg;
size_t group = fm.group;
using Func = std::function<void(const float*, const float*, float*, size_t,
size_t, size_t, size_t, size_t)>;
Func conv_kern_function = nullptr;

#define SWITCH_KERN_STR2() \
switch (FH) { \
case 2: \
conv_kern_function = fp32::conv_stride2::do_conv_2x2_stride2; \
break; \
case 3: \
conv_kern_function = fp32::conv_stride2::do_conv_3x3_stride2; \
break; \
case 5: \
conv_kern_function = fp32::conv_stride2::do_conv_5x5_stride2; \
break; \
case 7: \
conv_kern_function = fp32::conv_stride2::do_conv_7x7_stride2; \
break; \
}
SWITCH_KERN_STR2();

WorkspaceBundle wbundle =
MultithreadDirectConvCommon<float, float>::get_bundle_stride(
param, m_large_group);
SmallVector<NCBKern> ret_kerns;
//! When group >= nr_threads, treat it as large_group, each thread process
//! one group for better performance
if (m_large_group) {
//! Channel wise conv and big groups
auto exec_one_group = [wbundle, conv_kern_function](
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
auto fm = kern_param.filter_meta;
size_t IC = fm.icpg;
size_t OC = fm.ocpg;
WorkspaceBundle bundle = wbundle;
for (size_t ic = 0; ic < IC; ic++) {
MultithreadDirectConvCommon<float, float>::
copy_padding_kern_stride(bundle, kern_param, ncb_index,
{ncb_index.thread_id, 0, ic});
}
for (size_t oc = 0; oc < OC; oc++) {
MultithreadDirectConvCommon<float, float>::do_conv_kern_stride(
bundle, kern_param, ncb_index, conv_kern_function,
{ncb_index.thread_id, 0, oc});
}
};
ret_kerns.push_back({exec_one_group, {group, N, 1_z}});
} else {
WorkspaceBundle bundle = wbundle;
auto copy_padding = [bundle](const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
MultithreadDirectConvCommon<float, float>::copy_padding_kern_stride(
bundle, kern_param, ncb_index, ncb_index.ndrange_id);
};
ret_kerns.push_back({copy_padding, {group, N, IC}});
auto do_conv = [bundle, conv_kern_function](
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
MultithreadDirectConvCommon<float, float>::do_conv_kern_stride(
bundle, kern_param, ncb_index, conv_kern_function,
ncb_index.ndrange_id);
};
ret_kerns.push_back({do_conv, {group, N, OC}});
}
return ret_kerns;
}

SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoF32DirectStride2::dispatch_kerns(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 2, 2) {
return get_kimpls(param);
}
MIDOUT_END();
return {};
}
// vim: syntax=cpp.doxygen

+ 223
- 0
dnn/src/arm_common/conv_bias/fp32/algos.h View File

@@ -0,0 +1,223 @@
/**
* \file dnn/src/arm_common/conv_bias/fp32/algos.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#pragma once

#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/fallback/matrix_mul/opr_impl.h"

namespace megdnn {
namespace arm_common {

class ConvBiasImpl::AlgoFP32WinogradF23_4x4 final : public AlgoBase {
public:
AlgoFP32WinogradF23_4x4(fallback::MatrixMulImpl::AlgoBase* matmul_algo,
uint32_t tile_size)
: m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {}
bool is_reproducible() const override { return true; }
const char* name() const override {
if (m_name.empty()) {
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
m_matmul_algo->name(), {4, 2, m_tile_size});
}
return m_name.c_str();
}
bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
size_t get_workspace(fallback::ConvBiasImpl*,
const NCBKernSizeParam& param) const override;
virtual SmallVector<NCBKern> dispatch_kerns(
fallback::ConvBiasImpl* opr,
const NCBKernSizeParam& param) const override;

private:
fallback::MatrixMulImpl::AlgoBase* m_matmul_algo;
mutable std::string m_name;
uint32_t m_tile_size;
};

class ConvBiasImpl::AlgoFP32WinogradF63 final : public AlgoBase {
public:
AlgoFP32WinogradF63(fallback::MatrixMulImpl::AlgoBase* matmul_algo,
uint32_t tile_size)
: m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {}
bool is_reproducible() const override { return true; }
const char* name() const override {
if (m_name.empty()) {
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
m_matmul_algo->name(), {1, 6, m_tile_size});
}
return m_name.c_str();
}
bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
size_t get_workspace(fallback::ConvBiasImpl*,
const NCBKernSizeParam& param) const override;
virtual SmallVector<NCBKern> dispatch_kerns(
fallback::ConvBiasImpl* opr,
const NCBKernSizeParam& param) const override;

private:
fallback::MatrixMulImpl::AlgoBase* m_matmul_algo;
mutable std::string m_name;

uint32_t m_tile_size;
};

class ConvBiasImpl::AlgoFP32WinogradF63_4x4 final : public AlgoBase {
public:
AlgoFP32WinogradF63_4x4(fallback::MatrixMulImpl::AlgoBase* matmul_algo,
uint32_t tile_size)
: m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {}
bool is_reproducible() const override { return true; }
const char* name() const override {
if (m_name.empty()) {
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
m_matmul_algo->name(), {4, 6, m_tile_size});
}
return m_name.c_str();
}
bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
size_t get_workspace(fallback::ConvBiasImpl*,
const NCBKernSizeParam& param) const override;
virtual SmallVector<NCBKern> dispatch_kerns(
fallback::ConvBiasImpl* opr,
const NCBKernSizeParam& param) const override;

private:
fallback::MatrixMulImpl::AlgoBase* m_matmul_algo;
mutable std::string m_name;

uint32_t m_tile_size;
};

class ConvBiasImpl::AlgoFP32WinogradF54 final : public AlgoBase {
public:
AlgoFP32WinogradF54(fallback::MatrixMulImpl::AlgoBase* matmul_algo,
uint32_t tile_size)
: m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {}
bool is_reproducible() const override { return true; }
const char* name() const override {
if (m_name.empty()) {
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
m_matmul_algo->name(), {1, 5, m_tile_size});
}
return m_name.c_str();
}
bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
size_t get_workspace(fallback::ConvBiasImpl*,
const NCBKernSizeParam& param) const override;
virtual SmallVector<NCBKern> dispatch_kerns(
fallback::ConvBiasImpl* opr,
const NCBKernSizeParam& param) const override;

private:
fallback::MatrixMulImpl::AlgoBase* m_matmul_algo;
mutable std::string m_name;

uint32_t m_tile_size;
};

class ConvBiasImpl::AlgoFP32WinogradF45 final : public AlgoBase {
public:
AlgoFP32WinogradF45(fallback::MatrixMulImpl::AlgoBase* matmul_algo,
uint32_t tile_size)
: m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {}
bool is_reproducible() const override { return true; }
const char* name() const override {
if (m_name.empty()) {
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
m_matmul_algo->name(), {1, 4, m_tile_size});
}
return m_name.c_str();
}
bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
size_t get_workspace(fallback::ConvBiasImpl*,
const NCBKernSizeParam& param) const override;
virtual SmallVector<NCBKern> dispatch_kerns(
fallback::ConvBiasImpl* opr,
const NCBKernSizeParam& param) const override;

private:
fallback::MatrixMulImpl::AlgoBase* m_matmul_algo;
mutable std::string m_name;

uint32_t m_tile_size;
};


class ConvBiasImpl::AlgoF32Direct final : public AlgoBase {
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
bool m_large_group;

public:
AlgoF32Direct(bool is_large_group) : m_large_group{is_large_group} {}
bool is_reproducible() const override { return true; }
const char* name() const override {
return m_large_group ? "F32DIRECT_LARGE_GROUP"
: "F32DIRECT_SMALL_GROUP";
}
bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;

size_t get_workspace(fallback::ConvBiasImpl* opr,
const NCBKernSizeParam& param) const override;
virtual SmallVector<NCBKern> dispatch_kerns(
fallback::ConvBiasImpl* opr,
const NCBKernSizeParam& param) const override;
};

class ConvBiasImpl::AlgoF32DirectStride1 final : public AlgoBase {
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
bool m_large_group;

public:
AlgoF32DirectStride1(bool is_large_group) : m_large_group{is_large_group} {}
bool is_reproducible() const override { return true; }
const char* name() const override {
return m_large_group ? "F32STRD1_LARGE_GROUP" : "F32STRD1_SMALL_GROUP";
}
bool usable(fallback::ConvBiasImpl*, const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;

size_t get_workspace(fallback::ConvBiasImpl*,
const NCBKernSizeParam& param) const override;
virtual SmallVector<NCBKern> dispatch_kerns(
fallback::ConvBiasImpl* opr,
const NCBKernSizeParam& param) const override;
};

class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase {
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
bool m_large_group;

public:
AlgoF32DirectStride2(bool is_large_group) : m_large_group{is_large_group} {}
bool is_reproducible() const override { return true; }
const char* name() const override {
return m_large_group ? "F32STRD2_LARGE_GROUP" : "F32STRD2_SMALL_GROUP";
}
bool usable(fallback::ConvBiasImpl*, const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;

size_t get_workspace(fallback::ConvBiasImpl*,
const NCBKernSizeParam& param) const override;
virtual SmallVector<NCBKern> dispatch_kerns(
fallback::ConvBiasImpl* opr,
const NCBKernSizeParam& param) const override;
};
} // namespace arm_common
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 911
- 0
dnn/src/arm_common/conv_bias/fp32/direct.cpp View File

@@ -0,0 +1,911 @@
/**
* \file dnn/src/arm_common/conv_bias/fp32/direct.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include <cstring>
#include "include/megdnn/oprs.h"
#include "midout.h"
#include "src/arm_common/conv_bias/fp32/direct.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"
#include "src/arm_common/conv_bias/postprocess_helper.h"
#include "src/common/unroll_macro.h"
MIDOUT_DECL(megdnn_arm_conv_f32)

using namespace megdnn;
using namespace arm_common;
using namespace fp32;
using namespace conv_bias;

namespace {

template <int FH, int height, int width>
struct do_pixel_proxy {
static void exec(const float* src, const float* filter, float* dst,
const int IH, const int IW, const int OH, const int OW,
const int FW, const int oh, const int ow);
};

#define cb_load(i) data = vld1q_lane_f32(dst + i, data, i);
#define LOAD_OUT \
if (width < 4) { \
auto load_less_4 = [](float* dst, float32x4_t& data) { \
if (width == 1u) { \
UNROLL_CALL_NOWRAPPER(1, cb_load); \
} else if (width == 2u) { \
UNROLL_CALL_NOWRAPPER(2, cb_load); \
} else if (width == 3u) { \
UNROLL_CALL_NOWRAPPER(3, cb_load); \
} \
}; \
if (height >= 1) \
load_less_4(dst + 0 * OW, out0); \
if (height >= 2) \
load_less_4(dst + 1 * OW, out1); \
if (height >= 3) \
load_less_4(dst + 2 * OW, out2); \
if (height >= 4) \
load_less_4(dst + 3 * OW, out3); \
} else { \
if (height > 0) \
out0 = vld1q_f32(dst + 0 * OW); \
if (height > 1) \
out1 = vld1q_f32(dst + 1 * OW); \
if (height > 2) \
out2 = vld1q_f32(dst + 2 * OW); \
if (height > 3) \
out3 = vld1q_f32(dst + 3 * OW); \
}
#define cb_store(i) vst1q_lane_f32(dst + i, data, i);
#define STORE_OUT \
if (width < 4) { \
auto store_less_4 = [](float* dst, float32x4_t& data) { \
if (width == 1u) { \
UNROLL_CALL_NOWRAPPER(1, cb_store); \
} else if (width == 2u) { \
UNROLL_CALL_NOWRAPPER(2, cb_store); \
} else if (width == 3u) { \
UNROLL_CALL_NOWRAPPER(3, cb_store); \
} \
}; \
if (height >= 1) \
store_less_4(dst + 0 * OW, out0); \
if (height >= 2) \
store_less_4(dst + 1 * OW, out1); \
if (height >= 3) \
store_less_4(dst + 2 * OW, out2); \
if (height >= 4) \
store_less_4(dst + 3 * OW, out3); \
} else { \
if (height >= 1) \
vst1q_f32(dst + 0 * OW, out0); \
if (height >= 2) \
vst1q_f32(dst + 1 * OW, out1); \
if (height >= 3) \
vst1q_f32(dst + 2 * OW, out2); \
if (height >= 4) \
vst1q_f32(dst + 3 * OW, out3); \
}

template <int height, int width>
struct do_pixel_proxy<1, height, width> {
static void exec(const float* src, const float* filter, float* dst,
const int IH, const int IW, const int OH, const int OW,
const int FW, const int oh, const int ow) {
(void)IH;
(void)OH;
const int ih = oh, iw = ow;
float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, inp;
src += ih * IW + iw;
dst += oh * OW + ow;
LOAD_OUT;
for (int fw = 0; fw < FW; ++fw) {
const float* src_dd = src + fw;
kr0 = vdupq_n_f32(filter[0 * FW + fw]);

if (height > 0)
inp = vld1q_f32(src_dd + 0 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr0);

if (height > 1)
inp = vld1q_f32(src_dd + 1 * IW);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr0);

if (height > 2)
inp = vld1q_f32(src_dd + 2 * IW);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr0);

if (height > 3)
inp = vld1q_f32(src_dd + 3 * IW);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr0);
}
STORE_OUT;
}
};

template <int height, int width>
struct do_pixel_proxy<2, height, width> {
static void exec(const float* src, const float* filter, float* dst,
const int IH, const int IW, const int OH, const int OW,
const int FW, const int oh, const int ow) {
(void)IH;
(void)OH;
const int ih = oh, iw = ow;
float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, inp;
src += ih * IW + iw;
dst += oh * OW + ow;
LOAD_OUT;
for (int fw = 0; fw < FW; ++fw) {
const float* src_dd = src + fw;
kr0 = vdupq_n_f32(filter[0 * FW + fw]);
kr1 = vdupq_n_f32(filter[1 * FW + fw]);

if (height > 0)
inp = vld1q_f32(src_dd + 0 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr0);

if (height > 0)
inp = vld1q_f32(src_dd + 1 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr1);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr0);

if (height > 1)
inp = vld1q_f32(src_dd + 2 * IW);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr1);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr0);

if (height > 2)
inp = vld1q_f32(src_dd + 3 * IW);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr1);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr0);

if (height > 3)
inp = vld1q_f32(src_dd + 4 * IW);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr1);
}
STORE_OUT;
}
};

template <int height, int width>
struct do_pixel_proxy<3, height, width> {
static void exec(const float* src, const float* filter, float* dst,
const int IH, const int IW, const int OH, const int OW,
const int FW, const int oh, const int ow) {
(void)IH;
(void)OH;
const int ih = oh, iw = ow;
float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, inp;
src += ih * IW + iw;
dst += oh * OW + ow;
LOAD_OUT;
for (int fw = 0; fw < FW; ++fw) {
const float* src_dd = src + fw;
kr0 = vdupq_n_f32(filter[0 * FW + fw]);
kr1 = vdupq_n_f32(filter[1 * FW + fw]);
kr2 = vdupq_n_f32(filter[2 * FW + fw]);

if (height > 0)
inp = vld1q_f32(src_dd + 0 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr0);

if (height > 0)
inp = vld1q_f32(src_dd + 1 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr1);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr0);

if (height > 0)
inp = vld1q_f32(src_dd + 2 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr2);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr1);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr0);

if (height > 1)
inp = vld1q_f32(src_dd + 3 * IW);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr2);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr1);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr0);

if (height > 2)
inp = vld1q_f32(src_dd + 4 * IW);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr2);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr1);

if (height > 3)
inp = vld1q_f32(src_dd + 5 * IW);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr2);
}
STORE_OUT;
}
};

template <int height, int width>
struct do_pixel_proxy<4, height, width> {
static void exec(const float* src, const float* filter, float* dst,
const int IH, const int IW, const int OH, const int OW,
const int FW, const int oh, const int ow) {
(void)IH;
(void)OH;
const int ih = oh, iw = ow;
float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, inp;
src += ih * IW + iw;
dst += oh * OW + ow;
LOAD_OUT;
for (int fw = 0; fw < FW; ++fw) {
const float* src_dd = src + fw;
kr0 = vdupq_n_f32(filter[0 * FW + fw]);
kr1 = vdupq_n_f32(filter[1 * FW + fw]);
kr2 = vdupq_n_f32(filter[2 * FW + fw]);
kr3 = vdupq_n_f32(filter[3 * FW + fw]);

if (height > 0)
inp = vld1q_f32(src_dd + 0 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr0);

if (height > 0)
inp = vld1q_f32(src_dd + 1 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr1);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr0);

if (height > 0)
inp = vld1q_f32(src_dd + 2 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr2);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr1);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr0);

if (height > 0)
inp = vld1q_f32(src_dd + 3 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr3);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr2);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr1);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr0);

if (height > 1)
inp = vld1q_f32(src_dd + 4 * IW);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr3);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr2);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr1);

if (height > 2)
inp = vld1q_f32(src_dd + 5 * IW);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr3);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr2);

if (height > 3)
inp = vld1q_f32(src_dd + 6 * IW);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr3);
}
STORE_OUT;
}
};

template <int height, int width>
struct do_pixel_proxy<5, height, width> {
static void exec(const float* src, const float* filter, float* dst,
const int IH, const int IW, const int OH, const int OW,
const int FW, const int oh, const int ow) {
(void)IH;
(void)OH;
const int ih = oh, iw = ow;
float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, kr4,
inp;
src += ih * IW + iw;
dst += oh * OW + ow;
LOAD_OUT;
for (int fw = 0; fw < FW; ++fw) {
const float* src_dd = src + fw;
kr0 = vdupq_n_f32(filter[0 * FW + fw]);
kr1 = vdupq_n_f32(filter[1 * FW + fw]);
kr2 = vdupq_n_f32(filter[2 * FW + fw]);
kr3 = vdupq_n_f32(filter[3 * FW + fw]);
kr4 = vdupq_n_f32(filter[4 * FW + fw]);

if (height > 0)
inp = vld1q_f32(src_dd + 0 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr0);

if (height > 0)
inp = vld1q_f32(src_dd + 1 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr1);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr0);

if (height > 0)
inp = vld1q_f32(src_dd + 2 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr2);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr1);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr0);

if (height > 0)
inp = vld1q_f32(src_dd + 3 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr3);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr2);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr1);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr0);

if (height > 0)
inp = vld1q_f32(src_dd + 4 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr4);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr3);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr2);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr1);

if (height > 1)
inp = vld1q_f32(src_dd + 5 * IW);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr4);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr3);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr2);

if (height > 2)
inp = vld1q_f32(src_dd + 6 * IW);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr4);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr3);

if (height > 3)
inp = vld1q_f32(src_dd + 7 * IW);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr4);
}
STORE_OUT;
}
};

template <int height, int width>
struct do_pixel_proxy<6, height, width> {
static void exec(const float* src, const float* filter, float* dst,
const int IH, const int IW, const int OH, const int OW,
const int FW, const int oh, const int ow) {
(void)IH;
(void)OH;
const int ih = oh, iw = ow;
float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, kr4,
kr5, inp;
src += ih * IW + iw;
dst += oh * OW + ow;
LOAD_OUT;
for (int fw = 0; fw < FW; ++fw) {
const float* src_dd = src + fw;
kr0 = vdupq_n_f32(filter[0 * FW + fw]);
kr1 = vdupq_n_f32(filter[1 * FW + fw]);
kr2 = vdupq_n_f32(filter[2 * FW + fw]);
kr3 = vdupq_n_f32(filter[3 * FW + fw]);
kr4 = vdupq_n_f32(filter[4 * FW + fw]);
kr5 = vdupq_n_f32(filter[5 * FW + fw]);

if (height > 0)
inp = vld1q_f32(src_dd + 0 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr0);

if (height > 0)
inp = vld1q_f32(src_dd + 1 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr1);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr0);

if (height > 0)
inp = vld1q_f32(src_dd + 2 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr2);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr1);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr0);

if (height > 0)
inp = vld1q_f32(src_dd + 3 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr3);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr2);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr1);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr0);

if (height > 0)
inp = vld1q_f32(src_dd + 4 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr4);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr3);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr2);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr1);

if (height > 0)
inp = vld1q_f32(src_dd + 5 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr5);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr4);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr3);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr2);

if (height > 1)
inp = vld1q_f32(src_dd + 6 * IW);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr5);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr4);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr3);

if (height > 2)
inp = vld1q_f32(src_dd + 7 * IW);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr5);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr4);

if (height > 3)
inp = vld1q_f32(src_dd + 8 * IW);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr5);
}
STORE_OUT;
}
};

template <int height, int width>
struct do_pixel_proxy<7, height, width> {
static void exec(const float* src, const float* filter, float* dst,
const int IH, const int IW, const int OH, const int OW,
const int FW, const int oh, const int ow) {
(void)IH;
(void)OH;
const int ih = oh, iw = ow;
float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, kr4,
kr5, kr6, inp;
src += ih * IW + iw;
dst += oh * OW + ow;
LOAD_OUT;
for (int fw = 0; fw < FW; ++fw) {
const float* src_dd = src + fw;
kr0 = vdupq_n_f32(filter[0 * FW + fw]);
kr1 = vdupq_n_f32(filter[1 * FW + fw]);
kr2 = vdupq_n_f32(filter[2 * FW + fw]);
kr3 = vdupq_n_f32(filter[3 * FW + fw]);
kr4 = vdupq_n_f32(filter[4 * FW + fw]);
kr5 = vdupq_n_f32(filter[5 * FW + fw]);
kr6 = vdupq_n_f32(filter[6 * FW + fw]);

if (height > 0)
inp = vld1q_f32(src_dd + 0 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr0);

if (height > 0)
inp = vld1q_f32(src_dd + 1 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr1);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr0);

if (height > 0)
inp = vld1q_f32(src_dd + 2 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr2);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr1);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr0);

if (height > 0)
inp = vld1q_f32(src_dd + 3 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr3);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr2);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr1);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr0);

if (height > 0)
inp = vld1q_f32(src_dd + 4 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr4);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr3);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr2);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr1);

if (height > 0)
inp = vld1q_f32(src_dd + 5 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr5);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr4);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr3);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr2);

if (height > 0)
inp = vld1q_f32(src_dd + 6 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr6);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr5);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr4);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr3);

if (height > 1)
inp = vld1q_f32(src_dd + 7 * IW);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr6);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr5);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr4);

if (height > 2)
inp = vld1q_f32(src_dd + 8 * IW);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr6);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr5);

if (height > 3)
inp = vld1q_f32(src_dd + 9 * IW);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr6);
}
STORE_OUT;
}
};
#undef cb_load
#undef cb_load
#undef LOAD_OUT
#undef STORE_OUT

template <int FH, int height, int width>
void do_pixel(const float* src, const float* filter, float* dst, const int IH,
const int IW, const int OH, const int OW, const int FW,
const int oh, const int ow) {
do_pixel_proxy<FH, height, width>::exec(src, filter, dst, IH, IW, OH, OW,
FW, oh, ow);
}

template <int FH>
void do_conv_tpl_enable_prefetch(const float* src, const float* filter,
float* dst, const int IH, const int IW,
const int OH, const int OW, const int FW) {
const int hbeg = 0, hend = OH;
const int wbeg = 0, wend = OW;
int i, j;
for (i = hbeg; i + 4 <= hend; i += 4) {
for (j = wbeg; j + 4 <= wend; j += 4) {
// do prefetch
const int prefetch_index_input =
(j + 16) < wend
? i * IW + j + 16
: (i + 4) * IW + (((j + 16 - wend) >> 2) << 2);
const int prefetch_index_output =
(j + 16) < wend
? i * OW + j + 16
: (i + 4) * OW + (((j + 16 - wend) >> 2) << 2);
const float* src_prefetch = src + prefetch_index_input;
const float* dst_prefetch = dst + prefetch_index_output;
for (int iw_id = 0; iw_id < FH + 3; ++iw_id) {
__builtin_prefetch(src_prefetch + iw_id * IW, 0, 3);
}
__builtin_prefetch(dst_prefetch + 0 * OW, 1, 3);
__builtin_prefetch(dst_prefetch + 1 * OW, 1, 3);
__builtin_prefetch(dst_prefetch + 2 * OW, 1, 3);
__builtin_prefetch(dst_prefetch + 3 * OW, 1, 3);
do_pixel<FH, 4, 4>(src, filter, dst, IH, IW, OH, OW, FW, i, j);
}
#define DISPATCH(width) \
do { \
const int prefetch_index_input = (i + 4) * IW + 12; \
const int prefetch_index_output = (i + 4) * OW + 12; \
const float* src_prefetch = src + prefetch_index_input; \
const float* dst_prefetch = dst + prefetch_index_output; \
for (int iw_id = 0; iw_id < FH + 3; ++iw_id) { \
__builtin_prefetch(src_prefetch + iw_id * IW, 0, 3); \
} \
__builtin_prefetch(dst_prefetch + 0 * OW, 1, 3); \
__builtin_prefetch(dst_prefetch + 1 * OW, 1, 3); \
__builtin_prefetch(dst_prefetch + 2 * OW, 1, 3); \
__builtin_prefetch(dst_prefetch + 3 * OW, 1, 3); \
do_pixel<FH, 4, width>(src, filter, dst, IH, IW, OH, OW, FW, i, j); \
} while (0)
switch (wend - j) {
case 1:
DISPATCH(1);
break;
case 2:
DISPATCH(2);
break;
case 3:
DISPATCH(3);
break;
}
#undef DISPATCH
}

#define DISPATCH2(height, width) \
do { \
const int prefetch_index_input = IH * IW + 12; \
const float* src_prefetch = src + prefetch_index_input; \
for (int iw_id = 0; iw_id < FH + 3; ++iw_id) { \
__builtin_prefetch(src_prefetch + iw_id * IW, 0, 3); \
} \
do_pixel<FH, height, width>(src, filter, dst, IH, IW, OH, OW, FW, i, \
j); \
} while (0)

#define DISPATCH1(height) \
do { \
for (j = wbeg; j + 4 <= wend; j += 4) { \
const int prefetch_index_input = \
(j + 16) < wend \
? i * IW + j + 16 \
: (i + 4) * IW + (((j + 16 - wend) >> 2) << 2); \
const int prefetch_index_output = \
(j + 16) < wend \
? i * OW + j + 16 \
: (i + 4) * OW + (((j + 16 - wend) >> 2) << 2); \
const float* src_prefetch = src + prefetch_index_input; \
const float* dst_prefetch = dst + prefetch_index_output; \
for (int iw_id = 0; iw_id < FH + 3; ++iw_id) { \
__builtin_prefetch(src_prefetch + iw_id * IW, 0, 3); \
} \
__builtin_prefetch(dst_prefetch + 0 * OW, 1, 3); \
__builtin_prefetch(dst_prefetch + 1 * OW, 1, 3); \
__builtin_prefetch(dst_prefetch + 2 * OW, 1, 3); \
__builtin_prefetch(dst_prefetch + 3 * OW, 1, 3); \
do_pixel<FH, height, 4>(src, filter, dst, IH, IW, OH, OW, FW, i, \
j); \
} \
switch (wend - j) { \
case 1: \
DISPATCH2(height, 1); \
break; \
case 2: \
DISPATCH2(height, 2); \
break; \
case 3: \
DISPATCH2(height, 3); \
break; \
} \
} while (0)
switch (hend - i) {
case 1:
DISPATCH1(1);
break;
case 2:
DISPATCH1(2);
break;
case 3:
DISPATCH1(3);
break;
}
#undef DISPATCH1
#undef DISPATCH2
}
template <int FH>
void do_conv_tpl_disable_prefetch(const float* src, const float* filter,
float* dst, const int IH, const int IW,
const int OH, const int OW, const int FW) {
const int hbeg = 0, hend = OH;
const int wbeg = 0, wend = OW;
int i, j;
for (i = hbeg; i + 4 <= hend; i += 4) {
for (j = wbeg; j + 4 <= wend; j += 4) {
do_pixel<FH, 4, 4>(src, filter, dst, IH, IW, OH, OW, FW, i, j);
}
#define DISPATCH(width) \
do { \
do_pixel<FH, 4, width>(src, filter, dst, IH, IW, OH, OW, FW, i, j); \
} while (0)
switch (wend - j) {
case 1:
DISPATCH(1);
break;
case 2:
DISPATCH(2);
break;
case 3:
DISPATCH(3);
break;
}
#undef DISPATCH
}
#define DISPATCH2(height, width) \
do { \
do_pixel<FH, height, width>(src, filter, dst, IH, IW, OH, OW, FW, i, \
j); \
} while (0)
#define DISPATCH1(height) \
do { \
for (j = wbeg; j + 4 <= wend; j += 4) { \
do_pixel<FH, height, 4>(src, filter, dst, IH, IW, OH, OW, FW, i, \
j); \
} \
switch (wend - j) { \
case 1: \
DISPATCH2(height, 1); \
break; \
case 2: \
DISPATCH2(height, 2); \
break; \
case 3: \
DISPATCH2(height, 3); \
break; \
} \
} while (0)
switch (hend - i) {
case 1:
DISPATCH1(1);
break;
case 2:
DISPATCH1(2);
break;
case 3:
DISPATCH1(3);
break;
}
#undef DISPATCH1
#undef DISPATCH2
}
} // anonymous namespace

void conv_bias::kern_direct(const float* src, const float* filter, float* dst,
const int IH, const int IW, const int OH,
const int OW, const int FH, const int FW) {
megdnn_assert_internal(FH <= 7);
if (IH > 100 && IW > 100) {
#define GAO(FH) \
do { \
return do_conv_tpl_enable_prefetch<FH>(src, filter, dst, IH, IW, OH, \
OW, FW); \
} while (0)
switch (FH) {
case 1:
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(0)) { GAO(1); }
MIDOUT_END();
break;
case 2:
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(1)) { GAO(2); }
MIDOUT_END();
break;
case 3:
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(2)) { GAO(3); }
MIDOUT_END();
break;
case 4:
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(3)) { GAO(4); }
MIDOUT_END();
break;
case 5:
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(4)) { GAO(5); }
MIDOUT_END();
break;
case 6:
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(5)) { GAO(6); }
MIDOUT_END();
break;
case 7:
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(6)) { GAO(7); }
MIDOUT_END();
break;
}
#undef GAO
} else {
#define GAO(FH) \
do { \
return do_conv_tpl_disable_prefetch<FH>(src, filter, dst, IH, IW, OH, \
OW, FW); \
} while (0)
switch (FH) {
case 1:
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(0)) { GAO(1); }
MIDOUT_END();
break;
case 2:
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(1)) { GAO(2); }
MIDOUT_END();
break;
case 3:
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(2)) { GAO(3); }
MIDOUT_END();
break;
case 4:
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(3)) { GAO(4); }
MIDOUT_END();
break;
case 5:
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(4)) { GAO(5); }
MIDOUT_END();
break;
case 6:
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(5)) { GAO(6); }
MIDOUT_END();
break;
case 7:
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(6)) { GAO(7); }
MIDOUT_END();
break;
}
#undef GAO
}
megdnn_assert_internal(0);
}

// vim: syntax=cpp.doxygen

+ 29
- 0
dnn/src/arm_common/conv_bias/fp32/direct.h View File

@@ -0,0 +1,29 @@
/**
* \file dnn/src/arm_common/conv_bias/fp32/direct.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once

#include <cstddef>

namespace megdnn {
namespace arm_common {
namespace fp32{
namespace conv_bias {

void kern_direct(const float *src, const float *filter, float *dst,
const int IH, const int IW, const int OH, const int OW,
const int FH, const int FW);

} // namespace convolution
} // namespace fp32
} // namespace arm_common
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 735
- 0
dnn/src/arm_common/conv_bias/fp32/do_conv_stride1.cpp View File

@@ -0,0 +1,735 @@
/**
* \file dnn/src/arm_common/conv_bias/fp32/do_conv_stride1.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include <algorithm>

#include "src/arm_common/conv_bias/fp32/do_conv_stride1.h"
#include "src/arm_common/simd_macro/neon_helper.h"
#include "src/arm_common/conv_bias/postprocess_helper.h"

#include "midout.h"

MIDOUT_DECL(megdnn_arm_common_conv_bias_f32_convs1)

using namespace megdnn;
using namespace arm_common;
using namespace fp32;
using namespace conv_stride1;

using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam;
using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam;


void conv_stride1::do_conv_2x2_stride1(const float* src, const float* filter, float* dst,
size_t IH, size_t IW, size_t OH, size_t OW,
size_t IC) {
const size_t tail_step = IW - OW;
//! unroll of 2
size_t ic = 0;
for (; ic + 1 < IC; ic += 2) {
const float* src_ptr = src + IW * IH * ic;
const float* src_ptr1 = src_ptr + IW * IH;
float* outptr = dst;

const float* r00 = src_ptr;
const float* r01 = src_ptr + IW;
const float* r10 = src_ptr1;
const float* r11 = src_ptr1 + IW;

const float* k0 = filter + ic * 4;
const float* k1 = k0 + 4;

MEGDNN_SIMD_TYPE _k0 = MEGDNN_SIMD_LOADU(k0);
MEGDNN_SIMD_TYPE _k1 = MEGDNN_SIMD_LOADU(k1);
rep(h, OH) {
int width = OW >> 2;

rep(i, width) {
MEGDNN_SIMD_TYPE _r000 = MEGDNN_SIMD_LOADU(r00);
MEGDNN_SIMD_TYPE _r010 = MEGDNN_SIMD_LOADU(r01);
MEGDNN_SIMD_TYPE _r001 = MEGDNN_SIMD_LOADU(r00 + 1);
MEGDNN_SIMD_TYPE _r011 = MEGDNN_SIMD_LOADU(r01 + 1);

MEGDNN_SIMD_TYPE _r100 = MEGDNN_SIMD_LOADU(r10);
MEGDNN_SIMD_TYPE _r110 = MEGDNN_SIMD_LOADU(r11);
MEGDNN_SIMD_TYPE _r101 = MEGDNN_SIMD_LOADU(r10 + 1);
MEGDNN_SIMD_TYPE _r111 = MEGDNN_SIMD_LOADU(r11 + 1);

MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr);

_sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r000,
MEGDNN_SIMD_GET_LOW(_k0), 0);
_sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r001,
MEGDNN_SIMD_GET_LOW(_k0), 1);
_sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r010,
MEGDNN_SIMD_GET_HIGH(_k0), 0);
_sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r011,
MEGDNN_SIMD_GET_HIGH(_k0), 1);

_sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r100,
MEGDNN_SIMD_GET_LOW(_k1), 0);
_sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r101,
MEGDNN_SIMD_GET_LOW(_k1), 1);
_sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r110,
MEGDNN_SIMD_GET_HIGH(_k1), 0);
_sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r111,
MEGDNN_SIMD_GET_HIGH(_k1), 1);

MEGDNN_SIMD_STOREU(outptr, _sum);

r00 += 4;
r01 += 4;
r10 += 4;
r11 += 4;
outptr += 4;
}

r00 += tail_step;
r01 += tail_step;
r10 += tail_step;
r11 += tail_step;
}
}
for (; ic < IC; ic++) {
const float* src_ptr = src + IW * IH * ic;
float* outptr = dst;

const float* r0 = src_ptr;
const float* r1 = src_ptr + IW;

const float* k0 = filter + ic * 4;

MEGDNN_SIMD_TYPE _k0 = MEGDNN_SIMD_SET1(k0[0]);
MEGDNN_SIMD_TYPE _k1 = MEGDNN_SIMD_SET1(k0[1]);
MEGDNN_SIMD_TYPE _k2 = MEGDNN_SIMD_SET1(k0[2]);
MEGDNN_SIMD_TYPE _k3 = MEGDNN_SIMD_SET1(k0[3]);
rep(h, OH) {
int width = OW >> 2;

rep(i, width) {
MEGDNN_SIMD_TYPE _r00 = MEGDNN_SIMD_LOADU(r0);
MEGDNN_SIMD_TYPE _r10 = MEGDNN_SIMD_LOADU(r1);
MEGDNN_SIMD_TYPE _r01 = MEGDNN_SIMD_LOADU(r0 + 1);
MEGDNN_SIMD_TYPE _r11 = MEGDNN_SIMD_LOADU(r1 + 1);

MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr);
MEGDNN_SIMD_TYPE _sum2;

_sum = MEGDNN_SIMD_FMADD(_r00, _k0, _sum);
_sum2 = MEGDNN_SIMD_MUL(_r01, _k1);
_sum = MEGDNN_SIMD_FMADD(_r10, _k2, _sum);
_sum2 = MEGDNN_SIMD_FMADD(_r11, _k3, _sum2);

_sum = MEGDNN_SIMD_ADD(_sum, _sum2);

MEGDNN_SIMD_STOREU(outptr, _sum);

r0 += 4;
r1 += 4;
outptr += 4;
}

r0 += tail_step;
r1 += tail_step;
}
}
}

void conv_stride1::do_conv_3x3_stride1(const float* src, const float* filter, float* dst,
size_t IH, size_t IW, size_t OH, size_t OW,
size_t IC) {
const size_t tail_step = IW - OW;

rep(ic, IC) {
const float* src_ptr = src + IW * IH * ic;
float* outptr = dst;
float* outptr2 = outptr + OW;

const float* r0 = src_ptr;
const float* r1 = src_ptr + IW;
const float* r2 = src_ptr + IW * 2;
const float* r3 = src_ptr + IW * 3;

const float* k0 = filter;
const float* k1 = filter + 3;
const float* k2 = filter + 5;

MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(k0);
MEGDNN_SIMD_TYPE _k3456 = MEGDNN_SIMD_LOADU(k1);
MEGDNN_SIMD_TYPE _k5678 = MEGDNN_SIMD_LOADU(k2);
MEGDNN_SIMD_TYPE _k6789 = MEGDNN_SIMD_EXT(_k5678, _k5678, 1);

size_t h = 0;
for (; h + 1 < OH; h += 2) {
int width = OW >> 2;

rep(i, width) {
MEGDNN_SIMD_TYPE _sum1 = MEGDNN_SIMD_LOADU(outptr);
MEGDNN_SIMD_TYPE _sum2 = MEGDNN_SIMD_SET1(0.f);
MEGDNN_SIMD_TYPE _sum3 = MEGDNN_SIMD_LOADU(outptr2);
MEGDNN_SIMD_TYPE _sum4 = MEGDNN_SIMD_SET1(0.f);

MEGDNN_SIMD_TYPE _r00 = MEGDNN_SIMD_LOADU(r0);
MEGDNN_SIMD_TYPE _r00n = MEGDNN_SIMD_LOADU(r0 + 4);
MEGDNN_SIMD_TYPE _r01 = MEGDNN_SIMD_EXT(_r00, _r00n, 1);
MEGDNN_SIMD_TYPE _r02 = MEGDNN_SIMD_EXT(_r00, _r00n, 2);

MEGDNN_SIMD_TYPE _r10 = MEGDNN_SIMD_LOADU(r1);
MEGDNN_SIMD_TYPE _r10n = MEGDNN_SIMD_LOADU(r1 + 4);
MEGDNN_SIMD_TYPE _r11 = MEGDNN_SIMD_EXT(_r10, _r10n, 1);
MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r10n, 2);

MEGDNN_SIMD_TYPE _r20 = MEGDNN_SIMD_LOADU(r2);
MEGDNN_SIMD_TYPE _r20n = MEGDNN_SIMD_LOADU(r2 + 4);
MEGDNN_SIMD_TYPE _r21 = MEGDNN_SIMD_EXT(_r20, _r20n, 1);
MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r20n, 2);

MEGDNN_SIMD_TYPE _r30 = MEGDNN_SIMD_LOADU(r3);
MEGDNN_SIMD_TYPE _r30n = MEGDNN_SIMD_LOADU(r3 + 4);
MEGDNN_SIMD_TYPE _r31 = MEGDNN_SIMD_EXT(_r30, _r30n, 1);
MEGDNN_SIMD_TYPE _r32 = MEGDNN_SIMD_EXT(_r30, _r30n, 2);

_sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r00, _k0123, 0);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r01, _k0123, 1);
_sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r02, _k0123, 2);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r10, _k3456, 0);
_sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r11, _k3456, 1);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r12, _k3456, 2);
_sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r20, _k6789, 0);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r21, _k6789, 1);
_sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r22, _k6789, 2);

_sum3 = MEGDNN_SIMD_FMA_LANE(_sum3, _r10, _k0123, 0);
_sum4 = MEGDNN_SIMD_FMA_LANE(_sum4, _r11, _k0123, 1);
_sum3 = MEGDNN_SIMD_FMA_LANE(_sum3, _r12, _k0123, 2);
_sum4 = MEGDNN_SIMD_FMA_LANE(_sum4, _r20, _k3456, 0);
_sum3 = MEGDNN_SIMD_FMA_LANE(_sum3, _r21, _k3456, 1);
_sum4 = MEGDNN_SIMD_FMA_LANE(_sum4, _r22, _k3456, 2);
_sum3 = MEGDNN_SIMD_FMA_LANE(_sum3, _r30, _k6789, 0);
_sum4 = MEGDNN_SIMD_FMA_LANE(_sum4, _r31, _k6789, 1);
_sum3 = MEGDNN_SIMD_FMA_LANE(_sum3, _r32, _k6789, 2);

_sum1 = MEGDNN_SIMD_ADD(_sum1, _sum2);
_sum3 = MEGDNN_SIMD_ADD(_sum3, _sum4);

MEGDNN_SIMD_STOREU(outptr, _sum1);
MEGDNN_SIMD_STOREU(outptr2, _sum3);

r0 += 4;
r1 += 4;
r2 += 4;
r3 += 4;
outptr += 4;
outptr2 += 4;
}

r0 += tail_step + IW;
r1 += tail_step + IW;
r2 += tail_step + IW;
r3 += tail_step + IW;

outptr += OW;
outptr2 += OW;
}

for (; h < OH; h++) {
int width = OW >> 2;

rep(i, width) {
MEGDNN_SIMD_TYPE _sum1 = MEGDNN_SIMD_LOADU(outptr);
MEGDNN_SIMD_TYPE _sum2 = MEGDNN_SIMD_SET1(0.f);

MEGDNN_SIMD_TYPE _r00 = MEGDNN_SIMD_LOADU(r0);
MEGDNN_SIMD_TYPE _r00n = MEGDNN_SIMD_LOADU(r0 + 4);
MEGDNN_SIMD_TYPE _r01 = MEGDNN_SIMD_EXT(_r00, _r00n, 1);
MEGDNN_SIMD_TYPE _r02 = MEGDNN_SIMD_EXT(_r00, _r00n, 2);

MEGDNN_SIMD_TYPE _r10 = MEGDNN_SIMD_LOADU(r1);
MEGDNN_SIMD_TYPE _r10n = MEGDNN_SIMD_LOADU(r1 + 4);
MEGDNN_SIMD_TYPE _r11 = MEGDNN_SIMD_EXT(_r10, _r10n, 1);
MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r10n, 2);

MEGDNN_SIMD_TYPE _r20 = MEGDNN_SIMD_LOADU(r2);
MEGDNN_SIMD_TYPE _r20n = MEGDNN_SIMD_LOADU(r2 + 4);
MEGDNN_SIMD_TYPE _r21 = MEGDNN_SIMD_EXT(_r20, _r20n, 1);
MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r20n, 2);

_sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r00, _k0123, 0);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r01, _k0123, 1);
_sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r02, _k0123, 2);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r10, _k3456, 0);
_sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r11, _k3456, 1);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r12, _k3456, 2);
_sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r20, _k6789, 0);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r21, _k6789, 1);
_sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r22, _k6789, 2);

_sum1 = MEGDNN_SIMD_ADD(_sum1, _sum2);

MEGDNN_SIMD_STOREU(outptr, _sum1);

r0 += 4;
r1 += 4;
r2 += 4;
outptr += 4;
}
r0 += tail_step;
r1 += tail_step;
r2 += tail_step;
}

filter += 9;
}
}

void conv_stride1::do_conv_5x5_stride1(const float* src, const float* filter, float* dst,
size_t IH, size_t IW, size_t OH, size_t OW,
size_t IC) {
const size_t tail_step = IW - OW;

rep(ic, IC) {
const float* src_ptr = src + IW * IH * ic;
float* outptr = dst;
float* outptr2 = outptr + OW;

const float* r0 = src_ptr;
const float* r1 = src_ptr + IW;
const float* r2 = src_ptr + IW * 2;
const float* r3 = src_ptr + IW * 3;
const float* r4 = src_ptr + IW * 4;
const float* r5 = src_ptr + IW * 5;

MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(filter);
MEGDNN_SIMD_TYPE _k4567 = MEGDNN_SIMD_LOADU(filter + 4);
MEGDNN_SIMD_TYPE _k891011 = MEGDNN_SIMD_LOADU(filter + 8);
MEGDNN_SIMD_TYPE _k12131415 = MEGDNN_SIMD_LOADU(filter + 12);
MEGDNN_SIMD_TYPE _k16171819 = MEGDNN_SIMD_LOADU(filter + 16);
MEGDNN_SIMD_TYPE _k20212223 = MEGDNN_SIMD_LOADU(filter + 20);
MEGDNN_SIMD_TYPE _k24242424 = MEGDNN_SIMD_SET1(filter[24]);

size_t h = 0;
for (; h + 1 < OH; h += 2) {
int width = OW >> 2;

rep(i, width) {
MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr);
MEGDNN_SIMD_TYPE _sum2 = MEGDNN_SIMD_LOADU(outptr2);

MEGDNN_SIMD_TYPE _r00 = MEGDNN_SIMD_LOADU(r0);
MEGDNN_SIMD_TYPE _r04 = MEGDNN_SIMD_LOADU(r0 + 4);
MEGDNN_SIMD_TYPE _r01 = MEGDNN_SIMD_EXT(_r00, _r04, 1);
MEGDNN_SIMD_TYPE _r02 = MEGDNN_SIMD_EXT(_r00, _r04, 2);
MEGDNN_SIMD_TYPE _r03 = MEGDNN_SIMD_EXT(_r00, _r04, 3);

MEGDNN_SIMD_TYPE _r10 = MEGDNN_SIMD_LOADU(r1);
MEGDNN_SIMD_TYPE _r14 = MEGDNN_SIMD_LOADU(r1 + 4);
MEGDNN_SIMD_TYPE _r11 = MEGDNN_SIMD_EXT(_r10, _r14, 1);
MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r14, 2);
MEGDNN_SIMD_TYPE _r13 = MEGDNN_SIMD_EXT(_r10, _r14, 3);

MEGDNN_SIMD_TYPE _r20 = MEGDNN_SIMD_LOADU(r2);
MEGDNN_SIMD_TYPE _r24 = MEGDNN_SIMD_LOADU(r2 + 4);
MEGDNN_SIMD_TYPE _r21 = MEGDNN_SIMD_EXT(_r20, _r24, 1);
MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r24, 2);
MEGDNN_SIMD_TYPE _r23 = MEGDNN_SIMD_EXT(_r20, _r24, 3);

MEGDNN_SIMD_TYPE _r30 = MEGDNN_SIMD_LOADU(r3);
MEGDNN_SIMD_TYPE _r34 = MEGDNN_SIMD_LOADU(r3 + 4);
MEGDNN_SIMD_TYPE _r31 = MEGDNN_SIMD_EXT(_r30, _r34, 1);
MEGDNN_SIMD_TYPE _r32 = MEGDNN_SIMD_EXT(_r30, _r34, 2);
MEGDNN_SIMD_TYPE _r33 = MEGDNN_SIMD_EXT(_r30, _r34, 3);

MEGDNN_SIMD_TYPE _r40 = MEGDNN_SIMD_LOADU(r4);
MEGDNN_SIMD_TYPE _r44 = MEGDNN_SIMD_LOADU(r4 + 4);
MEGDNN_SIMD_TYPE _r41 = MEGDNN_SIMD_EXT(_r40, _r44, 1);
MEGDNN_SIMD_TYPE _r42 = MEGDNN_SIMD_EXT(_r40, _r44, 2);
MEGDNN_SIMD_TYPE _r43 = MEGDNN_SIMD_EXT(_r40, _r44, 3);

MEGDNN_SIMD_TYPE _r50 = MEGDNN_SIMD_LOADU(r5);
MEGDNN_SIMD_TYPE _r54 = MEGDNN_SIMD_LOADU(r5 + 4);
MEGDNN_SIMD_TYPE _r51 = MEGDNN_SIMD_EXT(_r50, _r54, 1);
MEGDNN_SIMD_TYPE _r52 = MEGDNN_SIMD_EXT(_r50, _r54, 2);
MEGDNN_SIMD_TYPE _r53 = MEGDNN_SIMD_EXT(_r50, _r54, 3);

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r00, _k0123, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r01, _k0123, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r02, _k0123, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r03, _k0123, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r04, _k4567, 0);

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r10, _k4567, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r11, _k4567, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r12, _k4567, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r13, _k891011, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r14, _k891011, 1);

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r20, _k891011, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r21, _k891011, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r22, _k12131415, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r23, _k12131415, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r24, _k12131415, 2);

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r30, _k12131415, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r31, _k16171819, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r32, _k16171819, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r33, _k16171819, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r34, _k16171819, 3);

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r40, _k20212223, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r41, _k20212223, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r42, _k20212223, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r43, _k20212223, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r44, _k24242424, 0);

_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r10, _k0123, 0);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r11, _k0123, 1);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r12, _k0123, 2);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r13, _k0123, 3);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r14, _k4567, 0);

_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r20, _k4567, 1);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r21, _k4567, 2);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r22, _k4567, 3);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r23, _k891011, 0);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r24, _k891011, 1);

_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r30, _k891011, 2);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r31, _k891011, 3);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r32, _k12131415, 0);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r33, _k12131415, 1);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r34, _k12131415, 2);

_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r40, _k12131415, 3);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r41, _k16171819, 0);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r42, _k16171819, 1);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r43, _k16171819, 2);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r44, _k16171819, 3);

_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r50, _k20212223, 0);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r51, _k20212223, 1);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r52, _k20212223, 2);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r53, _k20212223, 3);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r54, _k24242424, 0);

MEGDNN_SIMD_STOREU(outptr, _sum);
MEGDNN_SIMD_STOREU(outptr2, _sum2);

r0 += 4;
r1 += 4;
r2 += 4;
r3 += 4;
r4 += 4;
r5 += 4;
outptr += 4;
outptr2 += 4;
}

r0 += tail_step + IW;
r1 += tail_step + IW;
r2 += tail_step + IW;
r3 += tail_step + IW;
r4 += tail_step + IW;
r5 += tail_step + IW;

outptr += OW;
outptr2 += OW;
}

for (; h < OH; h++) {
int width = OW >> 2;

rep(i, width) {
MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr);

MEGDNN_SIMD_TYPE _r00 = MEGDNN_SIMD_LOADU(r0);
MEGDNN_SIMD_TYPE _r04 = MEGDNN_SIMD_LOADU(r0 + 4);
MEGDNN_SIMD_TYPE _r01 = MEGDNN_SIMD_EXT(_r00, _r04, 1);
MEGDNN_SIMD_TYPE _r02 = MEGDNN_SIMD_EXT(_r00, _r04, 2);
MEGDNN_SIMD_TYPE _r03 = MEGDNN_SIMD_EXT(_r00, _r04, 3);

MEGDNN_SIMD_TYPE _r10 = MEGDNN_SIMD_LOADU(r1);
MEGDNN_SIMD_TYPE _r14 = MEGDNN_SIMD_LOADU(r1 + 4);
MEGDNN_SIMD_TYPE _r11 = MEGDNN_SIMD_EXT(_r10, _r14, 1);
MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r14, 2);
MEGDNN_SIMD_TYPE _r13 = MEGDNN_SIMD_EXT(_r10, _r14, 3);

MEGDNN_SIMD_TYPE _r20 = MEGDNN_SIMD_LOADU(r2);
MEGDNN_SIMD_TYPE _r24 = MEGDNN_SIMD_LOADU(r2 + 4);
MEGDNN_SIMD_TYPE _r21 = MEGDNN_SIMD_EXT(_r20, _r24, 1);
MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r24, 2);
MEGDNN_SIMD_TYPE _r23 = MEGDNN_SIMD_EXT(_r20, _r24, 3);

MEGDNN_SIMD_TYPE _r30 = MEGDNN_SIMD_LOADU(r3);
MEGDNN_SIMD_TYPE _r34 = MEGDNN_SIMD_LOADU(r3 + 4);
MEGDNN_SIMD_TYPE _r31 = MEGDNN_SIMD_EXT(_r30, _r34, 1);
MEGDNN_SIMD_TYPE _r32 = MEGDNN_SIMD_EXT(_r30, _r34, 2);
MEGDNN_SIMD_TYPE _r33 = MEGDNN_SIMD_EXT(_r30, _r34, 3);

MEGDNN_SIMD_TYPE _r40 = MEGDNN_SIMD_LOADU(r4);
MEGDNN_SIMD_TYPE _r44 = MEGDNN_SIMD_LOADU(r4 + 4);
MEGDNN_SIMD_TYPE _r41 = MEGDNN_SIMD_EXT(_r40, _r44, 1);
MEGDNN_SIMD_TYPE _r42 = MEGDNN_SIMD_EXT(_r40, _r44, 2);
MEGDNN_SIMD_TYPE _r43 = MEGDNN_SIMD_EXT(_r40, _r44, 3);

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r00, _k0123, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r01, _k0123, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r02, _k0123, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r03, _k0123, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r04, _k4567, 0);

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r10, _k4567, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r11, _k4567, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r12, _k4567, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r13, _k891011, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r14, _k891011, 1);

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r20, _k891011, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r21, _k891011, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r22, _k12131415, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r23, _k12131415, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r24, _k12131415, 2);

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r30, _k12131415, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r31, _k16171819, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r32, _k16171819, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r33, _k16171819, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r34, _k16171819, 3);

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r40, _k20212223, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r41, _k20212223, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r42, _k20212223, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r43, _k20212223, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r44, _k24242424, 0);

MEGDNN_SIMD_STOREU(outptr, _sum);

r0 += 4;
r1 += 4;
r2 += 4;
r3 += 4;
r4 += 4;
outptr += 4;
}

r0 += tail_step;
r1 += tail_step;
r2 += tail_step;
r3 += tail_step;
r4 += tail_step;
}

filter += 25;
}
}

void conv_stride1::do_conv_7x7_stride1(const float* src, const float* filter, float* dst,
size_t IH, size_t IW, size_t OH, size_t OW,
size_t IC) {
const size_t tail_step = IW - OW;

rep(ic, IC) {
const float* src_ptr = src + IW * IH * ic;
float* outptr = dst;

const float* r0 = src_ptr;
const float* r1 = src_ptr + IW;
const float* r2 = src_ptr + IW * 2;
const float* r3 = src_ptr + IW * 3;
const float* r4 = src_ptr + IW * 4;
const float* r5 = src_ptr + IW * 5;
const float* r6 = src_ptr + IW * 6;

const float* k0 = filter;
const float* k1 = filter + 7;
const float* k2 = filter + 14;
const float* k3 = filter + 21;
const float* k4 = filter + 28;
const float* k5 = filter + 35;
const float* k6 = filter + 42;

for (size_t i = 0; i < OH; i++) {
int width = OW >> 2;

rep(i, width) {
MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr);

MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(k0);
MEGDNN_SIMD_TYPE _k4567 = MEGDNN_SIMD_LOADU(k0 + 4);

MEGDNN_SIMD_TYPE _r00 = MEGDNN_SIMD_LOADU(r0); // 0 1 2 3
MEGDNN_SIMD_TYPE _r04 = MEGDNN_SIMD_LOADU(r0 + 4); // 4 5 6 7
MEGDNN_SIMD_TYPE _r00n =
MEGDNN_SIMD_LOADU(r0 + 8); // 8 9 10 11
MEGDNN_SIMD_TYPE _r01 =
MEGDNN_SIMD_EXT(_r00, _r04, 1); // 1 2 3 4
MEGDNN_SIMD_TYPE _r02 =
MEGDNN_SIMD_EXT(_r00, _r04, 2); // 2 3 4 5
MEGDNN_SIMD_TYPE _r03 =
MEGDNN_SIMD_EXT(_r00, _r04, 3); // 3 4 5 6
MEGDNN_SIMD_TYPE _r05 =
MEGDNN_SIMD_EXT(_r04, _r00n, 1); // 5 6 7 8
MEGDNN_SIMD_TYPE _r06 =
MEGDNN_SIMD_EXT(_r04, _r00n, 2); // 6 7 8 9

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r00, _k0123, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r01, _k0123, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r02, _k0123, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r03, _k0123, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r04, _k4567, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r05, _k4567, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r06, _k4567, 2);

MEGDNN_SIMD_TYPE _k78910 = MEGDNN_SIMD_LOADU(k1);
MEGDNN_SIMD_TYPE _k11121314 = MEGDNN_SIMD_LOADU(k1 + 4);

MEGDNN_SIMD_TYPE _r10 = MEGDNN_SIMD_LOADU(r1);
MEGDNN_SIMD_TYPE _r14 = MEGDNN_SIMD_LOADU(r1 + 4);
MEGDNN_SIMD_TYPE _r10n = MEGDNN_SIMD_LOADU(r1 + 8);
MEGDNN_SIMD_TYPE _r11 = MEGDNN_SIMD_EXT(_r10, _r14, 1);
MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r14, 2);
MEGDNN_SIMD_TYPE _r13 = MEGDNN_SIMD_EXT(_r10, _r14, 3);
MEGDNN_SIMD_TYPE _r15 = MEGDNN_SIMD_EXT(_r14, _r10n, 1);
MEGDNN_SIMD_TYPE _r16 = MEGDNN_SIMD_EXT(_r14, _r10n, 2);

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r10, _k78910, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r11, _k78910, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r12, _k78910, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r13, _k78910, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r14, _k11121314, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r15, _k11121314, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r16, _k11121314, 2);

MEGDNN_SIMD_TYPE _k14151617 = MEGDNN_SIMD_LOADU(k2);
MEGDNN_SIMD_TYPE _k18192021 = MEGDNN_SIMD_LOADU(k2 + 4);

MEGDNN_SIMD_TYPE _r20 = MEGDNN_SIMD_LOADU(r2);
MEGDNN_SIMD_TYPE _r24 = MEGDNN_SIMD_LOADU(r2 + 4);
MEGDNN_SIMD_TYPE _r20n = MEGDNN_SIMD_LOADU(r2 + 8);
MEGDNN_SIMD_TYPE _r21 = MEGDNN_SIMD_EXT(_r20, _r24, 1);
MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r24, 2);
MEGDNN_SIMD_TYPE _r23 = MEGDNN_SIMD_EXT(_r20, _r24, 3);
MEGDNN_SIMD_TYPE _r25 = MEGDNN_SIMD_EXT(_r24, _r20n, 1);
MEGDNN_SIMD_TYPE _r26 = MEGDNN_SIMD_EXT(_r24, _r20n, 2);

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r20, _k14151617, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r21, _k14151617, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r22, _k14151617, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r23, _k14151617, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r24, _k18192021, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r25, _k18192021, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r26, _k18192021, 2);

MEGDNN_SIMD_TYPE _k21222324 = MEGDNN_SIMD_LOADU(k3);
MEGDNN_SIMD_TYPE _k25262728 = MEGDNN_SIMD_LOADU(k3 + 4);

MEGDNN_SIMD_TYPE _r30 = MEGDNN_SIMD_LOADU(r3);
MEGDNN_SIMD_TYPE _r34 = MEGDNN_SIMD_LOADU(r3 + 4);
MEGDNN_SIMD_TYPE _r30n = MEGDNN_SIMD_LOADU(r3 + 8);
MEGDNN_SIMD_TYPE _r31 = MEGDNN_SIMD_EXT(_r30, _r34, 1);
MEGDNN_SIMD_TYPE _r32 = MEGDNN_SIMD_EXT(_r30, _r34, 2);
MEGDNN_SIMD_TYPE _r33 = MEGDNN_SIMD_EXT(_r30, _r34, 3);
MEGDNN_SIMD_TYPE _r35 = MEGDNN_SIMD_EXT(_r34, _r30n, 1);
MEGDNN_SIMD_TYPE _r36 = MEGDNN_SIMD_EXT(_r34, _r30n, 2);

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r30, _k21222324, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r31, _k21222324, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r32, _k21222324, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r33, _k21222324, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r34, _k25262728, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r35, _k25262728, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r36, _k25262728, 2);

MEGDNN_SIMD_TYPE _k28293031 = MEGDNN_SIMD_LOADU(k4);
MEGDNN_SIMD_TYPE _k32333435 = MEGDNN_SIMD_LOADU(k4 + 4);

MEGDNN_SIMD_TYPE _r40 = MEGDNN_SIMD_LOADU(r4);
MEGDNN_SIMD_TYPE _r44 = MEGDNN_SIMD_LOADU(r4 + 4);
MEGDNN_SIMD_TYPE _r40n = MEGDNN_SIMD_LOADU(r4 + 8);
MEGDNN_SIMD_TYPE _r41 = MEGDNN_SIMD_EXT(_r40, _r44, 1);
MEGDNN_SIMD_TYPE _r42 = MEGDNN_SIMD_EXT(_r40, _r44, 2);
MEGDNN_SIMD_TYPE _r43 = MEGDNN_SIMD_EXT(_r40, _r44, 3);
MEGDNN_SIMD_TYPE _r45 = MEGDNN_SIMD_EXT(_r44, _r40n, 1);
MEGDNN_SIMD_TYPE _r46 = MEGDNN_SIMD_EXT(_r44, _r40n, 2);

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r40, _k28293031, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r41, _k28293031, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r42, _k28293031, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r43, _k28293031, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r44, _k32333435, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r45, _k32333435, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r46, _k32333435, 2);

MEGDNN_SIMD_TYPE _k35363738 = MEGDNN_SIMD_LOADU(k5);
MEGDNN_SIMD_TYPE _k39404142 = MEGDNN_SIMD_LOADU(k5 + 4);

MEGDNN_SIMD_TYPE _r50 = MEGDNN_SIMD_LOADU(r5);
MEGDNN_SIMD_TYPE _r54 = MEGDNN_SIMD_LOADU(r5 + 4);
MEGDNN_SIMD_TYPE _r50n = MEGDNN_SIMD_LOADU(r5 + 8);
MEGDNN_SIMD_TYPE _r51 = MEGDNN_SIMD_EXT(_r50, _r54, 1);
MEGDNN_SIMD_TYPE _r52 = MEGDNN_SIMD_EXT(_r50, _r54, 2);
MEGDNN_SIMD_TYPE _r53 = MEGDNN_SIMD_EXT(_r50, _r54, 3);
MEGDNN_SIMD_TYPE _r55 = MEGDNN_SIMD_EXT(_r54, _r50n, 1);
MEGDNN_SIMD_TYPE _r56 = MEGDNN_SIMD_EXT(_r54, _r50n, 2);

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r50, _k35363738, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r51, _k35363738, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r52, _k35363738, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r53, _k35363738, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r54, _k39404142, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r55, _k39404142, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r56, _k39404142, 2);

MEGDNN_SIMD_TYPE _k42434445 = MEGDNN_SIMD_LOADU(k6);
MEGDNN_SIMD_TYPE _k46474849 = MEGDNN_SIMD_LOADU(k6 + 4);

MEGDNN_SIMD_TYPE _r60 = MEGDNN_SIMD_LOADU(r6);
MEGDNN_SIMD_TYPE _r64 = MEGDNN_SIMD_LOADU(r6 + 4);
MEGDNN_SIMD_TYPE _r60n = MEGDNN_SIMD_LOADU(r6 + 8);
MEGDNN_SIMD_TYPE _r61 = MEGDNN_SIMD_EXT(_r60, _r64, 1);
MEGDNN_SIMD_TYPE _r62 = MEGDNN_SIMD_EXT(_r60, _r64, 2);
MEGDNN_SIMD_TYPE _r63 = MEGDNN_SIMD_EXT(_r60, _r64, 3);
MEGDNN_SIMD_TYPE _r65 = MEGDNN_SIMD_EXT(_r64, _r60n, 1);
MEGDNN_SIMD_TYPE _r66 = MEGDNN_SIMD_EXT(_r64, _r60n, 2);

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r60, _k42434445, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r61, _k42434445, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r62, _k42434445, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r63, _k42434445, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r64, _k46474849, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r65, _k46474849, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r66, _k46474849, 2);

MEGDNN_SIMD_STOREU(outptr, _sum);

r0 += 4;
r1 += 4;
r2 += 4;
r3 += 4;
r4 += 4;
r5 += 4;
r6 += 4;
outptr += 4;
}

r0 += tail_step;
r1 += tail_step;
r2 += tail_step;
r3 += tail_step;
r4 += tail_step;
r5 += tail_step;
r6 += tail_step;
}
filter += 49;
}
}

#include "src/common/simd_macro/epilogue.h"
// vim: syntax=cpp.doxygen

+ 34
- 0
dnn/src/arm_common/conv_bias/fp32/do_conv_stride1.h View File

@@ -0,0 +1,34 @@
/**
* \file dnn/src/arm_common/conv_bias/fp32/do_conv_stride1.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once

#include <cstddef>

namespace megdnn {
namespace arm_common {
namespace fp32 {
namespace conv_stride1 {

void do_conv_2x2_stride1(const float* src, const float* filter, float* dst,
size_t IH, size_t IW, size_t OH, size_t OW, size_t IC);
void do_conv_3x3_stride1(const float* src, const float* filter, float* dst,
size_t IH, size_t IW, size_t OH, size_t OW, size_t IC);
void do_conv_5x5_stride1(const float* src, const float* filter, float* dst,
size_t IH, size_t IW, size_t OH, size_t OW, size_t IC);
void do_conv_7x7_stride1(const float* src, const float* filter, float* dst,
size_t IH, size_t IW, size_t OH, size_t OW, size_t IC);
} // namespace conv_stride1
} // namespace fp32
} // namespace arm_common
} // namespace megdnn

// vim: syntax=cpp.doxygen


+ 513
- 0
dnn/src/arm_common/conv_bias/fp32/do_conv_stride2.cpp View File

@@ -0,0 +1,513 @@
/**
* \file dnn/src/arm_common/conv_bias/fp32/do_conv_stride2.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include <algorithm>

#include "./do_conv_stride2.h"
#include "midout.h"
#include "src/arm_common/simd_macro/neon_helper.h"
#include "src/arm_common/conv_bias/postprocess_helper.h"

MIDOUT_DECL(megdnn_arm_common_conv_bias_f32_convs2)

using namespace megdnn;
using namespace arm_common;
using namespace fp32;
using namespace conv_stride2;

using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam;
using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam;


void conv_stride2::do_conv_2x2_stride2(const float* src, const float* filter, float* dst,
size_t IH, size_t IW, size_t OH, size_t OW,
size_t IC) {
const size_t tail_step = IW - 2 * OW + IW;

rep(ic, IC) {
const float* src_ptr = src + IW * IH * ic;
float* outptr = dst;

const float* r0 = src_ptr;
const float* r1 = src_ptr + IW;

const float* k0 = filter;

MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(k0);
rep(h, OH) {
int nn = OW >> 2;

rep(i, nn) {
MEGDNN_SIMD_TYPE _outp = MEGDNN_SIMD_LOADU(outptr);

MEGDNN_SIMD_TYPE2 _r0 = MEGDNN_SIMD_LOAD2(r0);

MEGDNN_SIMD_TYPE _r00 = _r0.val[0]; // 0 2 4 6
MEGDNN_SIMD_TYPE _r01 = _r0.val[1]; // 1 3 5 7

_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r00, _k0123, 0);
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r01, _k0123, 1);

MEGDNN_SIMD_TYPE2 _r1 = MEGDNN_SIMD_LOAD2(r1);

MEGDNN_SIMD_TYPE _r10 = _r1.val[0];
MEGDNN_SIMD_TYPE _r11 = _r1.val[1];

_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r10, _k0123, 2);
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r11, _k0123, 3);

MEGDNN_SIMD_STOREU(outptr, _outp);

r0 += 8;
r1 += 8;
outptr += 4;
}

r0 += tail_step;
r1 += tail_step;
}

filter += 4;
}
}

void conv_stride2::do_conv_3x3_stride2(const float* src, const float* filter, float* dst,
size_t IH, size_t IW, size_t OH, size_t OW,
size_t IC) {
const size_t tail_step = IW - 2 * OW + IW;

rep(ic, IC) {
const float* src_ptr = src + IW * IH * ic;
float* outptr = dst;

const float* r0 = src_ptr;
const float* r1 = src_ptr + IW;
const float* r2 = src_ptr + IW * 2;

const float* k0 = filter;
const float* k1 = filter + 3;
const float* k2 = filter + 5;

MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(k0);
MEGDNN_SIMD_TYPE _k3456 = MEGDNN_SIMD_LOADU(k1);
MEGDNN_SIMD_TYPE _k5678 = MEGDNN_SIMD_LOADU(k2);
MEGDNN_SIMD_TYPE _k6789 = MEGDNN_SIMD_EXT(_k5678, _k5678, 1);
rep(h, OH) {
int nn = OW >> 2;

rep(i, nn) {
MEGDNN_SIMD_TYPE _outp = MEGDNN_SIMD_LOADU(outptr);

MEGDNN_SIMD_TYPE2 _r0 = MEGDNN_SIMD_LOAD2(r0);
MEGDNN_SIMD_TYPE2 _r0n = MEGDNN_SIMD_LOAD2(r0 + 8);

MEGDNN_SIMD_TYPE _r00 = _r0.val[0]; // 0 2 4 6
MEGDNN_SIMD_TYPE _r01 = _r0.val[1]; // 1 3 5 7
MEGDNN_SIMD_TYPE _r02 =
MEGDNN_SIMD_EXT(_r00, _r0n.val[0], 1); // 2 4 6 8

_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r00, _k0123, 0);
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r01, _k0123, 1);
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r02, _k0123, 2);

MEGDNN_SIMD_TYPE2 _r1 = MEGDNN_SIMD_LOAD2(r1);
MEGDNN_SIMD_TYPE2 _r1n = MEGDNN_SIMD_LOAD2(r1 + 8);

MEGDNN_SIMD_TYPE _r10 = _r1.val[0];
MEGDNN_SIMD_TYPE _r11 = _r1.val[1];
MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r1n.val[0], 1);

_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r10, _k3456, 0);
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r11, _k3456, 1);
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r12, _k3456, 2);

MEGDNN_SIMD_TYPE2 _r2 = MEGDNN_SIMD_LOAD2(r2);
MEGDNN_SIMD_TYPE2 _r2n = MEGDNN_SIMD_LOAD2(r2 + 8);

MEGDNN_SIMD_TYPE _r20 = _r2.val[0];
MEGDNN_SIMD_TYPE _r21 = _r2.val[1];
MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r2n.val[0], 1);

_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r20, _k6789, 0);
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r21, _k6789, 1);
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r22, _k6789, 2);

MEGDNN_SIMD_STOREU(outptr, _outp);

r0 += 8;
r1 += 8;
r2 += 8;
outptr += 4;
}

r0 += tail_step;
r1 += tail_step;
r2 += tail_step;
}

filter += 9;
}
}

void conv_stride2::do_conv_5x5_stride2(const float* src, const float* filter, float* dst,
size_t IH, size_t IW, size_t OH, size_t OW,
size_t IC) {
const size_t tail_step = IW - 2 * OW + IW;

rep(ic, IC) {
const float* src_ptr = src + IW * IH * ic;
float* outptr = dst;

const float* r0 = src_ptr;
const float* r1 = src_ptr + IW;
const float* r2 = src_ptr + IW * 2;
const float* r3 = src_ptr + IW * 3;
const float* r4 = src_ptr + IW * 4;

MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(filter);
MEGDNN_SIMD_TYPE _k4567 = MEGDNN_SIMD_LOADU(filter + 4);
MEGDNN_SIMD_TYPE _k891011 = MEGDNN_SIMD_LOADU(filter + 8);
MEGDNN_SIMD_TYPE _k12131415 = MEGDNN_SIMD_LOADU(filter + 12);
MEGDNN_SIMD_TYPE _k16171819 = MEGDNN_SIMD_LOADU(filter + 16);
MEGDNN_SIMD_TYPE _k20212223 = MEGDNN_SIMD_LOADU(filter + 20);
MEGDNN_SIMD_TYPE _k24242424 = MEGDNN_SIMD_SET1(filter[24]);

for (size_t i = 0; i < OH; i++) {
int nn = OW >> 2;

rep(i, nn) {
MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr);

MEGDNN_SIMD_TYPE2 _r00_02461357 = MEGDNN_SIMD_LOAD2(r0);
MEGDNN_SIMD_TYPE2 _r00nx2 = MEGDNN_SIMD_LOAD2(r0 + 8);
MEGDNN_SIMD_TYPE _r0_8101214 = _r00nx2.val[0]; // 8 10 12 14
MEGDNN_SIMD_TYPE _r0_9111315 = _r00nx2.val[1]; // 9 11 13 15
MEGDNN_SIMD_TYPE _r00 = _r00_02461357.val[0]; // 0 2 4 6
MEGDNN_SIMD_TYPE _r01 = _r00_02461357.val[1]; // 1 3 5 7
MEGDNN_SIMD_TYPE _r02 =
MEGDNN_SIMD_EXT(_r00, _r0_8101214, 1); // 2 4 6 8
MEGDNN_SIMD_TYPE _r03 =
MEGDNN_SIMD_EXT(_r01, _r0_9111315, 1); // 3 5 7 9
MEGDNN_SIMD_TYPE _r04 =
MEGDNN_SIMD_EXT(_r00, _r0_8101214, 2); // 4 6 8 10

MEGDNN_SIMD_TYPE2 _r10_02461357 = MEGDNN_SIMD_LOAD2(r1);
MEGDNN_SIMD_TYPE2 _r10nx2 = MEGDNN_SIMD_LOAD2(r1 + 8);
MEGDNN_SIMD_TYPE _r1_8101214 = _r10nx2.val[0];
MEGDNN_SIMD_TYPE _r1_9111315 = _r10nx2.val[1];
MEGDNN_SIMD_TYPE _r10 = _r10_02461357.val[0];
MEGDNN_SIMD_TYPE _r11 = _r10_02461357.val[1];
MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r1_8101214, 1);
MEGDNN_SIMD_TYPE _r13 = MEGDNN_SIMD_EXT(_r11, _r1_9111315, 1);
MEGDNN_SIMD_TYPE _r14 = MEGDNN_SIMD_EXT(_r10, _r1_8101214, 2);

MEGDNN_SIMD_TYPE2 _r20_02461357 = MEGDNN_SIMD_LOAD2(r2);
MEGDNN_SIMD_TYPE2 _r20nx2 = MEGDNN_SIMD_LOAD2(r2 + 8);
MEGDNN_SIMD_TYPE _r2_8101214 = _r20nx2.val[0];
MEGDNN_SIMD_TYPE _r2_9111315 = _r20nx2.val[1];
MEGDNN_SIMD_TYPE _r20 = _r20_02461357.val[0];
MEGDNN_SIMD_TYPE _r21 = _r20_02461357.val[1];
MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r2_8101214, 1);
MEGDNN_SIMD_TYPE _r23 = MEGDNN_SIMD_EXT(_r21, _r2_9111315, 1);
MEGDNN_SIMD_TYPE _r24 = MEGDNN_SIMD_EXT(_r20, _r2_8101214, 2);

MEGDNN_SIMD_TYPE2 _r30_02461357 = MEGDNN_SIMD_LOAD2(r3);
MEGDNN_SIMD_TYPE2 _r30nx2 = MEGDNN_SIMD_LOAD2(r3 + 8);
MEGDNN_SIMD_TYPE _r3_8101214 = _r30nx2.val[0];
MEGDNN_SIMD_TYPE _r3_9111315 = _r30nx2.val[1];
MEGDNN_SIMD_TYPE _r30 = _r30_02461357.val[0];
MEGDNN_SIMD_TYPE _r31 = _r30_02461357.val[1];
MEGDNN_SIMD_TYPE _r32 = MEGDNN_SIMD_EXT(_r30, _r3_8101214, 1);
MEGDNN_SIMD_TYPE _r33 = MEGDNN_SIMD_EXT(_r31, _r3_9111315, 1);
MEGDNN_SIMD_TYPE _r34 = MEGDNN_SIMD_EXT(_r30, _r3_8101214, 2);

MEGDNN_SIMD_TYPE2 _r40_02461357 = MEGDNN_SIMD_LOAD2(r4);
MEGDNN_SIMD_TYPE2 _r40nx2 = MEGDNN_SIMD_LOAD2(r4 + 8);
MEGDNN_SIMD_TYPE _r4_8101214 = _r40nx2.val[0];
MEGDNN_SIMD_TYPE _r4_9111315 = _r40nx2.val[1];
MEGDNN_SIMD_TYPE _r40 = _r40_02461357.val[0];
MEGDNN_SIMD_TYPE _r41 = _r40_02461357.val[1];
MEGDNN_SIMD_TYPE _r42 = MEGDNN_SIMD_EXT(_r40, _r4_8101214, 1);
MEGDNN_SIMD_TYPE _r43 = MEGDNN_SIMD_EXT(_r41, _r4_9111315, 1);
MEGDNN_SIMD_TYPE _r44 = MEGDNN_SIMD_EXT(_r40, _r4_8101214, 2);

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r00, _k0123, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r01, _k0123, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r02, _k0123, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r03, _k0123, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r04, _k4567, 0);

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r10, _k4567, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r11, _k4567, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r12, _k4567, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r13, _k891011, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r14, _k891011, 1);

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r20, _k891011, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r21, _k891011, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r22, _k12131415, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r23, _k12131415, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r24, _k12131415, 2);

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r30, _k12131415, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r31, _k16171819, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r32, _k16171819, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r33, _k16171819, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r34, _k16171819, 3);

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r40, _k20212223, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r41, _k20212223, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r42, _k20212223, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r43, _k20212223, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r44, _k24242424, 0);

MEGDNN_SIMD_STOREU(outptr, _sum);

r0 += 8;
r1 += 8;
r2 += 8;
r3 += 8;
r4 += 8;
outptr += 4;
}

r0 += tail_step;
r1 += tail_step;
r2 += tail_step;
r3 += tail_step;
r4 += tail_step;
}

filter += 25;
}
}

void conv_stride2::do_conv_7x7_stride2(const float* src, const float* filter, float* dst,
size_t IH, size_t IW, size_t OH, size_t OW,
size_t IC) {
const size_t tail_step = IW - 2 * OW + IW;

rep(ic, IC) {
const float* src_ptr = src + IW * IH * ic;
float* outptr = dst;

const float* r0 = src_ptr;
const float* r1 = src_ptr + IW;
const float* r2 = src_ptr + IW * 2;
const float* r3 = src_ptr + IW * 3;
const float* r4 = src_ptr + IW * 4;
const float* r5 = src_ptr + IW * 5;
const float* r6 = src_ptr + IW * 6;

const float* k0 = filter;
const float* k1 = filter + 7;
const float* k2 = filter + 14;
const float* k3 = filter + 21;
const float* k4 = filter + 28;
const float* k5 = filter + 35;
const float* k6 = filter + 42;

for (size_t i = 0; i < OH; i++) {
int nn = OW >> 2;

rep(i, nn) {
MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr);

MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(k0);
MEGDNN_SIMD_TYPE _k4567 = MEGDNN_SIMD_LOADU(k0 + 4);

MEGDNN_SIMD_TYPE2 _r00_02461357 = MEGDNN_SIMD_LOAD2(r0);
MEGDNN_SIMD_TYPE2 _r00nx2 = MEGDNN_SIMD_LOAD2(r0 + 8);
MEGDNN_SIMD_TYPE _r0_8101214 = _r00nx2.val[0]; // 8 10 12 14
MEGDNN_SIMD_TYPE _r0_9111315 = _r00nx2.val[1]; // 9 11 13 15
MEGDNN_SIMD_TYPE _r00 = _r00_02461357.val[0]; // 0 2 4 6
MEGDNN_SIMD_TYPE _r01 = _r00_02461357.val[1]; // 1 3 5 7
MEGDNN_SIMD_TYPE _r02 =
MEGDNN_SIMD_EXT(_r00, _r0_8101214, 1); // 2 4 6 8
MEGDNN_SIMD_TYPE _r03 =
MEGDNN_SIMD_EXT(_r01, _r0_9111315, 1); // 3 5 7 9
MEGDNN_SIMD_TYPE _r04 =
MEGDNN_SIMD_EXT(_r00, _r0_8101214, 2); // 4 6 8 10
MEGDNN_SIMD_TYPE _r05 =
MEGDNN_SIMD_EXT(_r01, _r0_9111315, 2); // 5 7 9 11
MEGDNN_SIMD_TYPE _r06 =
MEGDNN_SIMD_EXT(_r00, _r0_8101214, 3); // 6 8 10 12

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r00, _k0123, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r01, _k0123, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r02, _k0123, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r03, _k0123, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r04, _k4567, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r05, _k4567, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r06, _k4567, 2);

MEGDNN_SIMD_TYPE _k78910 = MEGDNN_SIMD_LOADU(k1);
MEGDNN_SIMD_TYPE _k11121314 = MEGDNN_SIMD_LOADU(k1 + 4);

MEGDNN_SIMD_TYPE2 _r10_02461357 = MEGDNN_SIMD_LOAD2(r1);
MEGDNN_SIMD_TYPE2 _r10nx2 = MEGDNN_SIMD_LOAD2(r1 + 8);
MEGDNN_SIMD_TYPE _r1_8101214 = _r10nx2.val[0];
MEGDNN_SIMD_TYPE _r1_9111315 = _r10nx2.val[1];
MEGDNN_SIMD_TYPE _r10 = _r10_02461357.val[0];
MEGDNN_SIMD_TYPE _r11 = _r10_02461357.val[1];
MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r1_8101214, 1);
MEGDNN_SIMD_TYPE _r13 = MEGDNN_SIMD_EXT(_r11, _r1_9111315, 1);
MEGDNN_SIMD_TYPE _r14 = MEGDNN_SIMD_EXT(_r10, _r1_8101214, 2);
MEGDNN_SIMD_TYPE _r15 = MEGDNN_SIMD_EXT(_r11, _r1_9111315, 2);
MEGDNN_SIMD_TYPE _r16 = MEGDNN_SIMD_EXT(_r10, _r1_8101214, 3);

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r10, _k78910, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r11, _k78910, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r12, _k78910, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r13, _k78910, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r14, _k11121314, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r15, _k11121314, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r16, _k11121314, 2);

MEGDNN_SIMD_TYPE _k14151617 = MEGDNN_SIMD_LOADU(k2);
MEGDNN_SIMD_TYPE _k18192021 = MEGDNN_SIMD_LOADU(k2 + 4);

MEGDNN_SIMD_TYPE2 _r20_02461357 = MEGDNN_SIMD_LOAD2(r2);
MEGDNN_SIMD_TYPE2 _r20nx2 = MEGDNN_SIMD_LOAD2(r2 + 8);
MEGDNN_SIMD_TYPE _r2_8101214 = _r20nx2.val[0];
MEGDNN_SIMD_TYPE _r2_9111315 = _r20nx2.val[1];
MEGDNN_SIMD_TYPE _r20 = _r20_02461357.val[0];
MEGDNN_SIMD_TYPE _r21 = _r20_02461357.val[1];
MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r2_8101214, 1);
MEGDNN_SIMD_TYPE _r23 = MEGDNN_SIMD_EXT(_r21, _r2_9111315, 1);
MEGDNN_SIMD_TYPE _r24 = MEGDNN_SIMD_EXT(_r20, _r2_8101214, 2);
MEGDNN_SIMD_TYPE _r25 = MEGDNN_SIMD_EXT(_r21, _r2_9111315, 2);
MEGDNN_SIMD_TYPE _r26 = MEGDNN_SIMD_EXT(_r20, _r2_8101214, 3);

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r20, _k14151617, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r21, _k14151617, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r22, _k14151617, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r23, _k14151617, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r24, _k18192021, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r25, _k18192021, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r26, _k18192021, 2);

MEGDNN_SIMD_TYPE _k21222324 = MEGDNN_SIMD_LOADU(k3);
MEGDNN_SIMD_TYPE _k25262728 = MEGDNN_SIMD_LOADU(k3 + 4);

MEGDNN_SIMD_TYPE2 _r30_02461357 = MEGDNN_SIMD_LOAD2(r3);
MEGDNN_SIMD_TYPE2 _r30nx2 = MEGDNN_SIMD_LOAD2(r3 + 8);
MEGDNN_SIMD_TYPE _r3_8101214 = _r30nx2.val[0];
MEGDNN_SIMD_TYPE _r3_9111315 = _r30nx2.val[1];
MEGDNN_SIMD_TYPE _r30 = _r30_02461357.val[0];
MEGDNN_SIMD_TYPE _r31 = _r30_02461357.val[1];
MEGDNN_SIMD_TYPE _r32 = MEGDNN_SIMD_EXT(_r30, _r3_8101214, 1);
MEGDNN_SIMD_TYPE _r33 = MEGDNN_SIMD_EXT(_r31, _r3_9111315, 1);
MEGDNN_SIMD_TYPE _r34 = MEGDNN_SIMD_EXT(_r30, _r3_8101214, 2);
MEGDNN_SIMD_TYPE _r35 = MEGDNN_SIMD_EXT(_r31, _r3_9111315, 2);
MEGDNN_SIMD_TYPE _r36 = MEGDNN_SIMD_EXT(_r30, _r3_8101214, 3);

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r30, _k21222324, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r31, _k21222324, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r32, _k21222324, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r33, _k21222324, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r34, _k25262728, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r35, _k25262728, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r36, _k25262728, 2);

MEGDNN_SIMD_TYPE _k28293031 = MEGDNN_SIMD_LOADU(k4);
MEGDNN_SIMD_TYPE _k32333435 = MEGDNN_SIMD_LOADU(k4 + 4);

MEGDNN_SIMD_TYPE2 _r40_02461357 = MEGDNN_SIMD_LOAD2(r4);
MEGDNN_SIMD_TYPE2 _r40nx2 = MEGDNN_SIMD_LOAD2(r4 + 8);
MEGDNN_SIMD_TYPE _r4_8101214 = _r40nx2.val[0];
MEGDNN_SIMD_TYPE _r4_9111315 = _r40nx2.val[1];
MEGDNN_SIMD_TYPE _r40 = _r40_02461357.val[0];
MEGDNN_SIMD_TYPE _r41 = _r40_02461357.val[1];
MEGDNN_SIMD_TYPE _r42 = MEGDNN_SIMD_EXT(_r40, _r4_8101214, 1);
MEGDNN_SIMD_TYPE _r43 = MEGDNN_SIMD_EXT(_r41, _r4_9111315, 1);
MEGDNN_SIMD_TYPE _r44 = MEGDNN_SIMD_EXT(_r40, _r4_8101214, 2);
MEGDNN_SIMD_TYPE _r45 = MEGDNN_SIMD_EXT(_r41, _r4_9111315, 2);
MEGDNN_SIMD_TYPE _r46 = MEGDNN_SIMD_EXT(_r40, _r4_8101214, 3);

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r40, _k28293031, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r41, _k28293031, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r42, _k28293031, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r43, _k28293031, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r44, _k32333435, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r45, _k32333435, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r46, _k32333435, 2);

MEGDNN_SIMD_TYPE _k35363738 = MEGDNN_SIMD_LOADU(k5);
MEGDNN_SIMD_TYPE _k39404142 = MEGDNN_SIMD_LOADU(k5 + 4);

MEGDNN_SIMD_TYPE2 _r50_02461357 = MEGDNN_SIMD_LOAD2(r5);
MEGDNN_SIMD_TYPE2 _r50nx2 = MEGDNN_SIMD_LOAD2(r5 + 8);
MEGDNN_SIMD_TYPE _r5_8101214 = _r50nx2.val[0];
MEGDNN_SIMD_TYPE _r5_9111315 = _r50nx2.val[1];
MEGDNN_SIMD_TYPE _r50 = _r50_02461357.val[0];
MEGDNN_SIMD_TYPE _r51 = _r50_02461357.val[1];
MEGDNN_SIMD_TYPE _r52 = MEGDNN_SIMD_EXT(_r50, _r5_8101214, 1);
MEGDNN_SIMD_TYPE _r53 = MEGDNN_SIMD_EXT(_r51, _r5_9111315, 1);
MEGDNN_SIMD_TYPE _r54 = MEGDNN_SIMD_EXT(_r50, _r5_8101214, 2);
MEGDNN_SIMD_TYPE _r55 = MEGDNN_SIMD_EXT(_r51, _r5_9111315, 2);
MEGDNN_SIMD_TYPE _r56 = MEGDNN_SIMD_EXT(_r50, _r5_8101214, 3);

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r50, _k35363738, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r51, _k35363738, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r52, _k35363738, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r53, _k35363738, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r54, _k39404142, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r55, _k39404142, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r56, _k39404142, 2);

MEGDNN_SIMD_TYPE _k42434445 = MEGDNN_SIMD_LOADU(k6);
MEGDNN_SIMD_TYPE _k45464748 = MEGDNN_SIMD_LOADU(k6 + 3);

MEGDNN_SIMD_TYPE2 _r60_02461357 = MEGDNN_SIMD_LOAD2(r6);
MEGDNN_SIMD_TYPE2 _r60nx2 = MEGDNN_SIMD_LOAD2(r6 + 8);
MEGDNN_SIMD_TYPE _r6_8101214 = _r60nx2.val[0];
MEGDNN_SIMD_TYPE _r6_9111315 = _r60nx2.val[1];
MEGDNN_SIMD_TYPE _r60 = _r60_02461357.val[0];
MEGDNN_SIMD_TYPE _r61 = _r60_02461357.val[1];
MEGDNN_SIMD_TYPE _r62 = MEGDNN_SIMD_EXT(_r60, _r6_8101214, 1);
MEGDNN_SIMD_TYPE _r63 = MEGDNN_SIMD_EXT(_r61, _r6_9111315, 1);
MEGDNN_SIMD_TYPE _r64 = MEGDNN_SIMD_EXT(_r60, _r6_8101214, 2);
MEGDNN_SIMD_TYPE _r65 = MEGDNN_SIMD_EXT(_r61, _r6_9111315, 2);
MEGDNN_SIMD_TYPE _r66 = MEGDNN_SIMD_EXT(_r60, _r6_8101214, 3);

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r60, _k42434445, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r61, _k42434445, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r62, _k42434445, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r63, _k42434445, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r64, _k45464748, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r65, _k45464748, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r66, _k45464748, 3);

MEGDNN_SIMD_STOREU(outptr, _sum);

r0 += 8;
r1 += 8;
r2 += 8;
r3 += 8;
r4 += 8;
r5 += 8;
r6 += 8;
outptr += 4;
}

r0 += tail_step;
r1 += tail_step;
r2 += tail_step;
r3 += tail_step;
r4 += tail_step;
r5 += tail_step;
r6 += tail_step;
}
filter += 49;
}
}
// vim: syntax=cpp.doxygen

+ 32
- 0
dnn/src/arm_common/conv_bias/fp32/do_conv_stride2.h View File

@@ -0,0 +1,32 @@
/**
* \file dnn/src/arm_common/conv_bias/fp32/do_conv_stride2.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once

#include "src/fallback/conv_bias/opr_impl.h"

namespace megdnn {
namespace arm_common {
namespace fp32 {
namespace conv_stride2 {
void do_conv_2x2_stride2(const float* src, const float* filter, float* dst,
size_t IH, size_t IW, size_t OH, size_t OW, size_t IC);
void do_conv_3x3_stride2(const float* src, const float* filter, float* dst,
size_t IH, size_t IW, size_t OH, size_t OW, size_t IC);
void do_conv_5x5_stride2(const float* src, const float* filter, float* dst,
size_t IH, size_t IW, size_t OH, size_t OW, size_t IC);
void do_conv_7x7_stride2(const float* src, const float* filter, float* dst,
size_t IH, size_t IW, size_t OH, size_t OW, size_t IC);
} // namespace conv_stride2
} // namespace fp32
} // namespace arm_common
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 164
- 0
dnn/src/arm_common/conv_bias/fp32/filter_transform.h View File

@@ -0,0 +1,164 @@
/**
* \file dnn/src/arm_common/conv_bias/fp32/filter_transform.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#pragma once
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/utils.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"
#include "src/arm_common/conv_bias/fp32/helper.h"
#include "megdnn/opr_param_defs.h"

namespace megdnn {
namespace arm_common {

template <param::MatrixMul::Format format=param::MatrixMul::Format::DEFAULT>
struct FilterTransform6X3 {
#define FILTER_TRANSFORM(d, wd) \
do { \
wd##0 = d##0; \
auto tmp0 = (d##0 + d##2) * -0.2222222f; \
auto tmp1 = d##1 * -0.2222222f; \
wd##1 = tmp0 + tmp1; \
wd##2 = tmp0 - tmp1; \
tmp0 = d##0 * 0.0111111f + d##2 * 0.0444444f; \
tmp1 = d##1 * 0.0222222f; \
wd##3 = tmp0 + tmp1; \
wd##4 = tmp0 - tmp1; \
tmp0 = d##0 * 0.7111111f + d##2 * 0.1777778f; \
tmp1 = d##1 * 0.3555556f; \
wd##5 = tmp0 + tmp1; \
wd##6 = tmp0 - tmp1; \
wd##7 = d##2; \
} while (0);

static void transform(const float* filter, float* filter_transform_buf,
float* transform_mid_buf, size_t OC, size_t IC,
size_t oc_start, size_t oc_end) {
// Gg * GT
// G
// 1.0000000 0.0000000 0.0000000
// -0.2222222 -0.2222222 -0.2222222
// -0.2222222 0.2222222 -0.2222222
// 0.0111111 0.0222222 0.0444444
// 0.0111111 -0.0222222 0.0444444
// 0.7111111 0.3555556 0.1777778
// 0.7111111 -0.3555556 0.1777778
// 0.0000000 0.0000000 1.0000000

constexpr size_t alpha = 6 + 3 - 1;
size_t OCB = OC / 4;
size_t ICB = IC / 4;
for (size_t oc = oc_start; oc < oc_end; oc++) {
rep(ic, IC) {
const float* fptr = filter + (oc * IC + ic) * 3 * 3;

Vector<float, 4> g0 = Vector<float, 4>::load(fptr);
Vector<float, 4> g1 = Vector<float, 4>::load(fptr + 3);

Vector<float, 4> g2 = Vector<float, 4>::load(fptr + 6 - 1);
float32x4_t zeros = vdupq_n_f32(0.0f);
g2.value = vextq_f32(g2.value, zeros, 1);

#define cb(i) Vector<float, 4> wd##i;
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb

#define cb(i) Vector<float, 8> wdt##i;
UNROLL_CALL_NOWRAPPER(3, cb);
#undef cb

#define cb(i) Vector<float, 8> ret##i;
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb

FILTER_TRANSFORM(g, wd);

size_t ocb = oc / 4;
size_t oc4 = oc % 4;
size_t icb = ic / 4;
size_t ic4 = ic % 4;
#if MEGDNN_AARCH64
TRANSPOSE_8x3(wd, wdt);
FILTER_TRANSFORM(wdt, ret);

#define cb(i) ret##i.save(transform_mid_buf + i * alpha);
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
rep(i, alpha) rep(j, alpha) {
if (format == param::MatrixMul::Format::DEFAULT) {
filter_transform_buf[(i * alpha + j) * OC * IC +
ic * OC + oc] =
transform_mid_buf[j * alpha + i];
} else {
filter_transform_buf[(i * alpha + j) * OCB * ICB * 4 *
4 +
ocb * ICB * 4 * 4 + icb * 4 * 4 +
ic4 * 4 + oc4] =
transform_mid_buf[j * alpha + i];
}
}

#else

#define cb(i) \
do { \
mid_buf1[0] = GET_VECTOR_ELEM(wd, i, 0); \
auto tmp0 = (GET_VECTOR_ELEM(wd, i, 0) + GET_VECTOR_ELEM(wd, i, 2)) * \
-0.2222222f; \
auto tmp1 = GET_VECTOR_ELEM(wd, i, 1) * -0.2222222f; \
mid_buf1[1] = tmp0 + tmp1; \
mid_buf1[2] = tmp0 - tmp1; \
tmp0 = GET_VECTOR_ELEM(wd, i, 0) * 0.0111111f + \
GET_VECTOR_ELEM(wd, i, 2) * 0.0444444f; \
tmp1 = GET_VECTOR_ELEM(wd, i, 1) * 0.0222222f; \
mid_buf1[3] = tmp0 + tmp1; \
mid_buf1[4] = tmp0 - tmp1; \
tmp0 = GET_VECTOR_ELEM(wd, i, 0) * 0.7111111f + \
GET_VECTOR_ELEM(wd, i, 2) * 0.1777778f; \
tmp1 = GET_VECTOR_ELEM(wd, i, 1) * 0.3555556f; \
mid_buf1[5] = tmp0 + tmp1; \
mid_buf1[6] = tmp0 - tmp1; \
mid_buf1[7] = GET_VECTOR_ELEM(wd, i, 2); \
mid_buf1 += 8; \
} while (0);
#define GET_VECTOR_ELEM(s, i, idx) vgetq_lane_f32(CONCAT(s, i).value, idx)

float* mid_buf1 = transform_mid_buf;
UNROLL_CALL_NOWRAPPER(8, cb);
mid_buf1 = transform_mid_buf;
#undef cb

rep(i, alpha) rep(j, alpha) {
if (format == param::MatrixMul::Format::DEFAULT) {
filter_transform_buf[(i * alpha + j) * OC * IC +
ic * OC + oc] =
transform_mid_buf[i * alpha + j];
} else {
filter_transform_buf[(i * alpha + j) * OCB * ICB * 4 *
4 +
ocb * ICB * 4 * 4 + icb * 4 * 4 +
ic4 * 4 + oc4] =
transform_mid_buf[i * alpha + j];
}
}
#endif
}
}
}
};
#undef FILTER_TRANSFORM
#undef GET_VECTOR_ELEM

} // namespace arm_common
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 204
- 0
dnn/src/arm_common/conv_bias/fp32/helper.h View File

@@ -0,0 +1,204 @@
/**
* \file dnn/src/arm_common/conv_bias/fp32/helper.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#pragma once
#include "src/common/unroll_macro.h"
#include "src/arm_common/simd_macro/marm_neon.h"

namespace megdnn {
namespace arm_common {
inline void transpose_4x4(const float* src, float* dst, int lda, int ldb) {
float32x4x2_t a0, a1;
a0.val[0] = vld1q_f32(src + 0 * lda);
a0.val[1] = vld1q_f32(src + 1 * lda);
a1.val[0] = vld1q_f32(src + 2 * lda);
a1.val[1] = vld1q_f32(src + 3 * lda);
float32x4x2_t b0 = vzipq_f32(a0.val[0], a1.val[0]);
float32x4x2_t b1 = vzipq_f32(a0.val[1], a1.val[1]);
float32x4x2_t c0 = vzipq_f32(b0.val[0], b1.val[0]);
float32x4x2_t c1 = vzipq_f32(b0.val[1], b1.val[1]);
vst1q_f32(dst + 0 * ldb, c0.val[0]);
vst1q_f32(dst + 1 * ldb, c0.val[1]);
vst1q_f32(dst + 2 * ldb, c1.val[0]);
vst1q_f32(dst + 3 * ldb, c1.val[1]);
}
} // namespace arm_common
} // namespace megdnn

#define MATRIX_MUL4x4(sum, a, b) \
sum##0 = vmlaq_low_lane_f32(sum##0, b##0, a##0, 0); \
sum##0 = vmlaq_low_lane_f32(sum##0, b##1, a##0, 1); \
sum##0 = vmlaq_high_lane_f32(sum##0, b##2, a##0, 2); \
sum##0 = vmlaq_high_lane_f32(sum##0, b##3, a##0, 3); \
sum##1 = vmlaq_low_lane_f32(sum##1, b##0, a##1, 0); \
sum##1 = vmlaq_low_lane_f32(sum##1, b##1, a##1, 1); \
sum##1 = vmlaq_high_lane_f32(sum##1, b##2, a##1, 2); \
sum##1 = vmlaq_high_lane_f32(sum##1, b##3, a##1, 3); \
sum##2 = vmlaq_low_lane_f32(sum##2, b##0, a##2, 0); \
sum##2 = vmlaq_low_lane_f32(sum##2, b##1, a##2, 1); \
sum##2 = vmlaq_high_lane_f32(sum##2, b##2, a##2, 2); \
sum##2 = vmlaq_high_lane_f32(sum##2, b##3, a##2, 3); \
sum##3 = vmlaq_low_lane_f32(sum##3, b##0, a##3, 0); \
sum##3 = vmlaq_low_lane_f32(sum##3, b##1, a##3, 1); \
sum##3 = vmlaq_high_lane_f32(sum##3, b##2, a##3, 2); \
sum##3 = vmlaq_high_lane_f32(sum##3, b##3, a##3, 3);

#define CONCAT(a, idx) a##idx

#if MEGDNN_AARCH64
//! ret and a are type Vector<float, 8>
#define TRANSPOSE_8x8(a, ret) \
do { \
auto b0 = vzipq_f32(CONCAT(a, 0).value.val[0], \
CONCAT(a, 1).value.val[0]); \
auto b1 = vzipq_f32(CONCAT(a, 0).value.val[1], \
CONCAT(a, 1).value.val[1]); \
auto b2 = vzipq_f32(CONCAT(a, 2).value.val[0], \
CONCAT(a, 3).value.val[0]); \
auto b3 = vzipq_f32(CONCAT(a, 2).value.val[1], \
CONCAT(a, 3).value.val[1]); \
auto b4 = vzipq_f32(CONCAT(a, 4).value.val[0], \
CONCAT(a, 5).value.val[0]); \
auto b5 = vzipq_f32(CONCAT(a, 4).value.val[1], \
CONCAT(a, 5).value.val[1]); \
auto b6 = vzipq_f32(CONCAT(a, 6).value.val[0], \
CONCAT(a, 7).value.val[0]); \
auto b7 = vzipq_f32(CONCAT(a, 6).value.val[1], \
CONCAT(a, 7).value.val[1]); \
CONCAT(ret, 0).value.val[0] = vreinterpretq_f32_s64( \
vzip1q_s64(vreinterpretq_s64_f32(b0.val[0]), \
vreinterpretq_s64_f32(b2.val[0]))); \
CONCAT(ret, 0).value.val[1] = vreinterpretq_f32_s64( \
vzip1q_s64(vreinterpretq_s64_f32(b4.val[0]), \
vreinterpretq_s64_f32(b6.val[0]))); \
CONCAT(ret, 1).value.val[0] = vreinterpretq_f32_s64( \
vzip2q_s64(vreinterpretq_s64_f32(b0.val[0]), \
vreinterpretq_s64_f32(b2.val[0]))); \
CONCAT(ret, 1).value.val[1] = vreinterpretq_f32_s64( \
vzip2q_s64(vreinterpretq_s64_f32(b4.val[0]), \
vreinterpretq_s64_f32(b6.val[0]))); \
CONCAT(ret, 2).value.val[0] = vreinterpretq_f32_s64( \
vzip1q_s64(vreinterpretq_s64_f32(b0.val[1]), \
vreinterpretq_s64_f32(b2.val[1]))); \
CONCAT(ret, 2).value.val[1] = vreinterpretq_f32_s64( \
vzip1q_s64(vreinterpretq_s64_f32(b4.val[1]), \
vreinterpretq_s64_f32(b6.val[1]))); \
CONCAT(ret, 3).value.val[0] = vreinterpretq_f32_s64( \
vzip2q_s64(vreinterpretq_s64_f32(b0.val[1]), \
vreinterpretq_s64_f32(b2.val[1]))); \
CONCAT(ret, 3).value.val[1] = vreinterpretq_f32_s64( \
vzip2q_s64(vreinterpretq_s64_f32(b4.val[1]), \
vreinterpretq_s64_f32(b6.val[1]))); \
CONCAT(ret, 4).value.val[0] = vreinterpretq_f32_s64( \
vzip1q_s64(vreinterpretq_s64_f32(b1.val[0]), \
vreinterpretq_s64_f32(b3.val[0]))); \
CONCAT(ret, 4).value.val[1] = vreinterpretq_f32_s64( \
vzip1q_s64(vreinterpretq_s64_f32(b5.val[0]), \
vreinterpretq_s64_f32(b7.val[0]))); \
CONCAT(ret, 5).value.val[0] = vreinterpretq_f32_s64( \
vzip2q_s64(vreinterpretq_s64_f32(b1.val[0]), \
vreinterpretq_s64_f32(b3.val[0]))); \
CONCAT(ret, 5).value.val[1] = vreinterpretq_f32_s64( \
vzip2q_s64(vreinterpretq_s64_f32(b5.val[0]), \
vreinterpretq_s64_f32(b7.val[0]))); \
CONCAT(ret, 6).value.val[0] = vreinterpretq_f32_s64( \
vzip1q_s64(vreinterpretq_s64_f32(b1.val[1]), \
vreinterpretq_s64_f32(b3.val[1]))); \
CONCAT(ret, 6).value.val[1] = vreinterpretq_f32_s64( \
vzip1q_s64(vreinterpretq_s64_f32(b5.val[1]), \
vreinterpretq_s64_f32(b7.val[1]))); \
CONCAT(ret, 7).value.val[0] = vreinterpretq_f32_s64( \
vzip2q_s64(vreinterpretq_s64_f32(b1.val[1]), \
vreinterpretq_s64_f32(b3.val[1]))); \
CONCAT(ret, 7).value.val[1] = vreinterpretq_f32_s64( \
vzip2q_s64(vreinterpretq_s64_f32(b5.val[1]), \
vreinterpretq_s64_f32(b7.val[1]))); \
} while (0);

#define TRANSPOSE_8x3(a, ret) \
auto b0 = vzipq_f32(CONCAT(a, 0).value, CONCAT(a, 1).value); \
auto b1 = vzipq_f32(CONCAT(a, 2).value, CONCAT(a, 3).value); \
auto b2 = vzipq_f32(CONCAT(a, 4).value, CONCAT(a, 5).value); \
auto b3 = vzipq_f32(CONCAT(a, 6).value, CONCAT(a, 7).value); \
CONCAT(ret, 0).value.val[0] = vreinterpretq_f32_s64( \
vzip1q_s64(vreinterpretq_s64_f32(b0.val[0]), \
vreinterpretq_s64_f32(b1.val[0]))); \
CONCAT(ret, 0).value.val[1] = vreinterpretq_f32_s64( \
vzip1q_s64(vreinterpretq_s64_f32(b2.val[0]), \
vreinterpretq_s64_f32(b3.val[0]))); \
CONCAT(ret, 1).value.val[0] = vreinterpretq_f32_s64( \
vzip2q_s64(vreinterpretq_s64_f32(b0.val[0]), \
vreinterpretq_s64_f32(b1.val[0]))); \
CONCAT(ret, 1).value.val[1] = vreinterpretq_f32_s64( \
vzip2q_s64(vreinterpretq_s64_f32(b2.val[0]), \
vreinterpretq_s64_f32(b3.val[0]))); \
CONCAT(ret, 2).value.val[0] = vreinterpretq_f32_s64( \
vzip1q_s64(vreinterpretq_s64_f32(b0.val[1]), \
vreinterpretq_s64_f32(b1.val[1]))); \
CONCAT(ret, 2).value.val[1] = vreinterpretq_f32_s64( \
vzip1q_s64(vreinterpretq_s64_f32(b2.val[1]), \
vreinterpretq_s64_f32(b3.val[1])));

#define TRANSPOSE_8x4(a, ret) \
auto b0 = vzipq_f32(CONCAT(a, 0).value, CONCAT(a, 1).value); \
auto b1 = vzipq_f32(CONCAT(a, 2).value, CONCAT(a, 3).value); \
auto b2 = vzipq_f32(CONCAT(a, 4).value, CONCAT(a, 5).value); \
auto b3 = vzipq_f32(CONCAT(a, 6).value, CONCAT(a, 7).value); \
CONCAT(ret, 0).value.val[0] = vreinterpretq_f32_s64( \
vzip1q_s64(vreinterpretq_s64_f32(b0.val[0]), \
vreinterpretq_s64_f32(b1.val[0]))); \
CONCAT(ret, 0).value.val[1] = vreinterpretq_f32_s64( \
vzip1q_s64(vreinterpretq_s64_f32(b2.val[0]), \
vreinterpretq_s64_f32(b3.val[0]))); \
CONCAT(ret, 1).value.val[0] = vreinterpretq_f32_s64( \
vzip2q_s64(vreinterpretq_s64_f32(b0.val[0]), \
vreinterpretq_s64_f32(b1.val[0]))); \
CONCAT(ret, 1).value.val[1] = vreinterpretq_f32_s64( \
vzip2q_s64(vreinterpretq_s64_f32(b2.val[0]), \
vreinterpretq_s64_f32(b3.val[0]))); \
CONCAT(ret, 2).value.val[0] = vreinterpretq_f32_s64( \
vzip1q_s64(vreinterpretq_s64_f32(b0.val[1]), \
vreinterpretq_s64_f32(b1.val[1]))); \
CONCAT(ret, 2).value.val[1] = vreinterpretq_f32_s64( \
vzip1q_s64(vreinterpretq_s64_f32(b2.val[1]), \
vreinterpretq_s64_f32(b3.val[1]))); \
CONCAT(ret, 3).value.val[0] = vreinterpretq_f32_s64( \
vzip2q_s64(vreinterpretq_s64_f32(b0.val[1]), \
vreinterpretq_s64_f32(b1.val[1]))); \
CONCAT(ret, 3).value.val[1] = vreinterpretq_f32_s64( \
vzip2q_s64(vreinterpretq_s64_f32(b2.val[1]), \
vreinterpretq_s64_f32(b3.val[1])));

#elif MEGDNN_ARMV7
#define TRANSPOSE_8x4(a, ret) \
auto b0 = vzipq_f32(CONCAT(a, 0).value, CONCAT(a, 1).value); \
auto b1 = vzipq_f32(CONCAT(a, 2).value, CONCAT(a, 3).value); \
auto b2 = vzipq_f32(CONCAT(a, 4).value, CONCAT(a, 5).value); \
auto b3 = vzipq_f32(CONCAT(a, 6).value, CONCAT(a, 7).value); \
CONCAT(ret, 0).value.val[0] = \
vcombine_f32(vget_low_f32(b0.val[0]), vget_low_f32(b1.val[0])); \
CONCAT(ret, 1).value.val[0] = \
vcombine_f32(vget_high_f32(b0.val[0]), vget_high_f32(b1.val[0])); \
CONCAT(ret, 2).value.val[0] = \
vcombine_f32(vget_low_f32(b0.val[1]), vget_low_f32(b1.val[1])); \
CONCAT(ret, 3).value.val[0] = \
vcombine_f32(vget_high_f32(b0.val[1]), vget_high_f32(b1.val[1])); \
CONCAT(ret, 0).value.val[1] = \
vcombine_f32(vget_low_f32(b2.val[0]), vget_low_f32(b3.val[0])); \
CONCAT(ret, 1).value.val[1] = \
vcombine_f32(vget_high_f32(b2.val[0]), vget_high_f32(b3.val[0])); \
CONCAT(ret, 2).value.val[1] = \
vcombine_f32(vget_low_f32(b2.val[1]), vget_low_f32(b3.val[1])); \
CONCAT(ret, 3).value.val[1] = \
vcombine_f32(vget_high_f32(b2.val[1]), vget_high_f32(b3.val[1]));

#endif
// vim: syntax=cpp.doxygen

+ 39
- 0
dnn/src/arm_common/conv_bias/fp32/strategy.h View File

@@ -0,0 +1,39 @@
/**
* \file dnn/src/arm_common/conv_bias/fp32/strategy.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#pragma once

#include "src/arm_common/conv_bias/postprocess_helper.h"
#include "src/fallback/conv_bias/winograd/winograd.h"

namespace megdnn {
namespace arm_common {
namespace winograd {

MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 2, 3, 4, 4,
winograd_2x3_4x4_f)

MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 6, 3, 1, 1,
winograd_6x3_1x1_f)

MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 6, 3, 4, 4,
winograd_6x3_4x4_f)

MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 5, 4, 1, 1,
winograd_5x4_1x1_f)

MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 4, 5, 1, 1,
winograd_4x5_1x1_f)
} // namespace winograd
} // namespace arm_common
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 346
- 0
dnn/src/arm_common/conv_bias/fp32/strategy_2x3_4x4.cpp View File

@@ -0,0 +1,346 @@
/**
* \file dnn/src/arm_common/conv_bias/fp32/strategy_2x3_4x4.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/arm_common/conv_bias/fp32/strategy.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/utils.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/winograd/winograd.h"

#include "src/naive/matrix_mul/matrix_mul_helper.h"
#include "src/arm_common/elemwise_helper/op_unary.h"
#include "src/arm_common/conv_bias/fp32/helper.h"

#include "midout.h"
MIDOUT_DECL(megdnn_arm_common_winograd_fp32_F23)

using namespace megdnn;
using namespace arm_common;
namespace {

struct InputTransform2X3 {
template <bool inner>
static void prepare(const float* input, float* patch, float* patchT,
int ih_start, int iw_start, size_t IH, size_t IW,
size_t ic, size_t IC) {
constexpr size_t alpha = 2 + 3 - 1;
if (!(inner && ic + 4 < IC)) {
memset(patch, 0, sizeof(float) * 4 * alpha * alpha);
}
if (inner) {
const float* input_ptr =
input + ic * IH * IW + ih_start * IW + iw_start;
for (size_t ico = 0; ico < 4; ++ico) {
if (ic + ico < IC) {
auto v0 = vld1q_f32(input_ptr);
auto v1 = vld1q_f32(input_ptr + IW);
auto v2 = vld1q_f32(input_ptr + IW * 2);
auto v3 = vld1q_f32(input_ptr + IW * 3);

vst1q_f32(patch + ico * 4 * alpha + 0 * 4, v0);
vst1q_f32(patch + ico * 4 * alpha + 1 * 4, v1);
vst1q_f32(patch + ico * 4 * alpha + 2 * 4, v2);
vst1q_f32(patch + ico * 4 * alpha + 3 * 4, v3);
input_ptr += IH * IW;
}
}
} else {
int ih0_act = std::max<int>(ih_start, 0),
ih1_act = std::min<int>(ih_start + alpha, IH),
iw0_act = std::max<int>(iw_start, 0),
iw1_act = std::min<int>(iw_start + alpha, IW);
// partial copy
for (size_t ico = 0; ico < 4; ++ico) {
if (ic + ico < IC) {
for (int ih = ih0_act; ih < ih1_act; ++ih) {
for (int iw = iw0_act; iw < iw1_act; ++iw) {
size_t iho = ih - ih_start, iwo = iw - iw_start;
patch[ico * alpha * 4 + iho * 4 + iwo] =
input[(ic + ico) * IH * IW + ih * IW + iw];
}
}
}
}
}

transpose_4x4(patch + 0 * 1, patchT + 0 * 4, 16, 4);
transpose_4x4(patch + 4 * 1, patchT + 4 * 4, 16, 4);
transpose_4x4(patch + 8 * 1, patchT + 8 * 4, 16, 4);
transpose_4x4(patch + 12 * 1, patchT + 12 * 4, 16, 4);
}

static void transform(const float* patchT, float* input_transform_buf,
size_t unit_idx, size_t nr_units_in_tile, size_t ic,
size_t IC) {
constexpr size_t alpha = 2 + 3 - 1;
// BT * d * B
#define cb(m, n) \
Vector<float, 4> d##m##n = \
Vector<float, 4>::load(patchT + m * 4 * 4 + n * 4);

UNROLL_CALL_NOWRAPPER_D2(4, 4, cb);
#undef cb

//! 1 0 -1 0 d00 d01 d02 d03 1 0 0 0
//! 0 1 1 0 d10 d11 d12 d13 0 1 -1 -1
//! 0 -1 1 0 d20 d21 d22 d23 -1 1 1 0
//! 0 -1 0 1 d30 d31 d32 d33 0 0 0 1
#define cb(m) \
auto t0##m = d0##m - d2##m; \
auto t1##m = d1##m + d2##m; \
auto t2##m = d2##m - d1##m; \
auto t3##m = d3##m - d1##m;

UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb

#define cb(m) \
d##m##0 = t##m##0 - t##m##2; \
d##m##1 = t##m##1 + t##m##2; \
d##m##2 = t##m##2 - t##m##1; \
d##m##3 = t##m##3 - t##m##1;

UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb

size_t ICB = IC / 4;
size_t icb = ic / 4;
#define cb(m, n) \
d##m##n.save(input_transform_buf + \
(m * alpha + n) * ICB * nr_units_in_tile * 4 + \
icb * nr_units_in_tile * 4 + unit_idx * 4);
UNROLL_CALL_NOWRAPPER_D2(4, 4, cb)
#undef cb
}
};

template <BiasMode bmode, typename Op>
struct OutputTransform2X3 {
static void transform(const float* output_transform_buf, const float* bias,
float* output, float* transform_mid_buf,
size_t oh_start, size_t ow_start, size_t OH,
size_t OW, size_t oc_start, size_t oc_end,
size_t oc_index, size_t unit_idx,
size_t nr_units_in_tile, const DType& src_dtype,
const DType& dst_dtype) {
Op op(src_dtype, dst_dtype);
//! AT * m * A
constexpr size_t alpha = 2 + 3 - 1;

size_t oc = oc_start + oc_index;
size_t OCB = (oc_end - oc_start) / 4;
size_t ocb = oc_index / 4;
#define cb(m, n) \
auto v##m##n = Vector<float, 4>::load( \
output_transform_buf + \
(m * alpha + n) * OCB * nr_units_in_tile * 4 + \
ocb * nr_units_in_tile * 4 + unit_idx * 4);
UNROLL_CALL_NOWRAPPER_D2(4, 4, cb);
#undef cb
//! 1 1 1 0 v00 v01 v02 v03 1 0
//! 0 1 -1 1 v10 v11 v12 v13 1 1
//! v20 v21 v22 v23 1 -1
//! v30 v31 v32 v33 0 1
#define cb(m) \
auto t0##m = v0##m + v1##m + v2##m; \
auto t1##m = v1##m - v2##m + v3##m;

UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb
v00 = t00 + t01 + t02;
v10 = t10 + t11 + t12;
v01 = t01 - t02 + t03;
v11 = t11 - t12 + t13;

Vector<float, 4> vbias;
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
vbias = Vector<float, 4>::load(bias + oc);

v00 += vbias;
v10 += vbias;
v01 += vbias;
v11 += vbias;
}
if (bmode != BiasMode::BIAS) {
v00 = op(v00.value);
v01 = op(v01.value);
v10 = op(v10.value);
v11 = op(v11.value);
}

v00.save(transform_mid_buf + (0 * 2 + 0) * 4);
v10.save(transform_mid_buf + (1 * 2 + 0) * 4);
v01.save(transform_mid_buf + (0 * 2 + 1) * 4);
v11.save(transform_mid_buf + (1 * 2 + 1) * 4);

for (size_t oco = 0; oco < 4 && oc + oco < oc_end; ++oco) {
for (size_t oho = 0; oho < 2 && oh_start + oho < OH; ++oho) {
for (size_t owo = 0; owo < 2 && ow_start + owo < OW; ++owo) {
size_t oh = oh_start + oho;
size_t ow = ow_start + owo;
float res = transform_mid_buf[oho * 2 * 4 + owo * 4 + oco];
if (bmode == BiasMode::BIAS) {
res += bias[(oc + oco) * OH * OW + oh * OW + ow];
res = op(res);
}
output[(oc + oco) * OH * OW + oh * OW + ow] = res;
}
}
}
}
};
} // namespace

namespace megdnn {
namespace arm_common {
namespace winograd {

MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_2x3_4x4_f)
void winograd_2x3_4x4_f::filter(const float* filter,
float* filter_transform_buf,
float* transform_mid_buf, size_t OC, size_t IC,
size_t oc_start, size_t oc_end) {
constexpr int alpha = 2 + 3 - 1;
//! G * g * GT
float32x4_t g0{1.f, 0, 0, 0}, g1{0.5, 0.5, 0.5, 0}, g2{0.5, -0.5, 0.5, 0},
g3{0, 0, 1, 0};
float32x4_t gt0{1, 0.5, 0.5, 0}, gt1{0, 0.5, -0.5, 0}, gt2{0, 0.5, 0.5, 1},
gt3{0, 0, 0, 0};
size_t OCB = OC / 4;
size_t ICB = IC / 4;

for (size_t oc = oc_start; oc < oc_end; oc++)
rep(ic, IC) {
const float* filter_ptr = filter + (oc * IC + ic) * 3 * 3;
/**
* origin: (4x3) * (3 x 3) * (3 x 4)
* pack to G and g to times of 4
* now: (4x4) * (4 x 4) * (4 x 4)
*/
//! 1 0 0 0 v00 v01 v02 0 1 0.5 0.5 0
//! 0.5 0.5 0.5 0 v10 v11 v12 0 0 0.5 -0.5 0
//! 0.5 -0.5 0.5 0 v20 v21 v22 0 0 0.5 0.5 1
//! 0 0 1 0 0 0 0 0 0 0 0 0
float32x4_t vf0 = vld1q_f32(filter_ptr);
float32x4_t vf1 = vld1q_f32(filter_ptr + 4);
float32x4_t vf2 = vdupq_n_f32(filter_ptr[8]);

float32x4_t v3(vdupq_n_f32(0));
auto vtmp = vextq_f32(vf1, vf2, 2);
vtmp = vsetq_lane_f32(0, vtmp, 3);
float32x4_t v2(vtmp);
vtmp = vextq_f32(vf0, vf1, 3);
vtmp = vsetq_lane_f32(0, vtmp, 3);
float32x4_t v1(vtmp);
vtmp = vsetq_lane_f32(0, vf0, 3);
float32x4_t v0(vtmp);

float32x4_t vsum0 = vdupq_n_f32(0), vsum1 = vdupq_n_f32(0),
vsum2 = vdupq_n_f32(0), vsum3 = vdupq_n_f32(0);

MATRIX_MUL4x4(vsum, g, v);

float32x4_t vres0 = vdupq_n_f32(0), vres1 = vdupq_n_f32(0),
vres2 = vdupq_n_f32(0), vres3 = vdupq_n_f32(0);
MATRIX_MUL4x4(vres, vsum, gt);

vst1q_f32(transform_mid_buf, vres0);
vst1q_f32(transform_mid_buf + 4, vres1);
vst1q_f32(transform_mid_buf + 8, vres2);
vst1q_f32(transform_mid_buf + 12, vres3);

size_t ocb = oc / 4;
size_t oc4 = oc % 4;
size_t icb = ic / 4;
size_t ic4 = ic % 4;
rep(i, alpha) rep(j, alpha) {
filter_transform_buf[(i * alpha + j) * OCB * ICB * 4 * 4 +
ocb * ICB * 4 * 4 + icb * 4 * 4 + ic4 * 4 +
oc4] = transform_mid_buf[i * alpha + j];
}
}
}

void winograd_2x3_4x4_f::input(const float* input, float* input_transform_buf,
float* transform_mid_buf, size_t IH, size_t IW,
size_t IC, size_t PH, size_t PW,
size_t unit_start_idx, size_t nr_units_in_tile) {
megdnn_assert(IC % 4 == 0);
constexpr int alpha = 3 + 2 - 1;

// OW = IW + 2 * PW - KERNEL_SIZE + 1
auto units_w = div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE);
float* patch = transform_mid_buf;
float* patchT = transform_mid_buf + 4 * alpha * alpha;

for (size_t ic = 0; ic < IC; ic += 4) {
rep(unit_idx, nr_units_in_tile) {
size_t index = unit_start_idx + unit_idx;
size_t nh = index / units_w;
size_t nw = index % units_w;
int ih_start = nh * OUTPUT_BLOCK_SIZE - PH;
int iw_start = nw * OUTPUT_BLOCK_SIZE - PW;
if (ih_start >= 0 && ih_start + alpha <= static_cast<int>(IH) &&
iw_start >= 0 && iw_start + alpha <= static_cast<int>(IW)) {
InputTransform2X3::prepare<true>(input, patch, patchT, ih_start,
iw_start, IH, IW, ic, IC);
InputTransform2X3::transform(patchT, input_transform_buf,
unit_idx, nr_units_in_tile, ic,
IC);

} else {
InputTransform2X3::prepare<false>(input, patch, patchT,
ih_start, iw_start, IH, IW,
ic, IC);
InputTransform2X3::transform(patchT, input_transform_buf,
unit_idx, nr_units_in_tile, ic,
IC);
}
}
}
}

void winograd_2x3_4x4_f::output(const float* output_transform_buf,
const float* bias, float* output,
float* transform_mid_buf, BiasMode bmode,
NonlineMode nonline_mode, size_t OH, size_t OW,
size_t oc_start, size_t oc_end, size_t unit_start_idx,
size_t nr_units_in_tile) {
#define cb(_bmode, _nonline_op, ...) \
OutputTransform2X3<_bmode MEGDNN_COMMA _nonline_op>::transform(__VA_ARGS__);

auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE);

for (size_t oc = oc_start; oc < oc_end; oc += 4) {
size_t oc_index = oc - oc_start;
rep(unit_idx, nr_units_in_tile) {
size_t index = unit_start_idx + unit_idx;
auto nh = index / units_w;
auto nw = index % units_w;
size_t oh_start = nh * OUTPUT_BLOCK_SIZE;
size_t ow_start = nw * OUTPUT_BLOCK_SIZE;
DISPATCH_CONV_WINOGRAD_BIAS(
megdnn_arm_common_winograd_fp32_F23, cb, float, float, bmode,
nonline_mode, output_transform_buf, bias, output, transform_mid_buf,
oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx,
nr_units_in_tile, src_dtype, dst_dtype);
}
}
#undef cb
}

} // namespace winograd
} // namespace arm_common
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 483
- 0
dnn/src/arm_common/conv_bias/fp32/strategy_4x5.cpp View File

@@ -0,0 +1,483 @@
/**
* \file dnn/src/arm_common/conv_bias/fp32/strategy_4x5.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/arm_common/conv_bias/fp32/strategy.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/utils.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/winograd/winograd.h"

#include "src/arm_common/conv_bias/fp32/helper.h"
#include "src/arm_common/elemwise_helper/op_unary.h"
#include "src/naive/matrix_mul/matrix_mul_helper.h"

#include "midout.h"
MIDOUT_DECL(megdnn_arm_common_winograd_fp32_F45)

using namespace megdnn;
using namespace arm_common;
namespace {

struct FilterTransform4X5 {
#define FILTER_TRANSFORM(d, wd) \
do { \
wd##0 = d##0; \
wd##r0 = d##r0; \
wd##1 = (d##0 + d##1 + d##2 + d##3 + d##4) * -0.2222222; \
wd##r1 = (d##r0 + d##r1 + d##r2 + d##r3 + d##r4) * -0.2222222; \
wd##2 = (d##0 - d##1 + d##2 - d##3 + d##4) * -0.2222222; \
wd##r2 = (d##r0 - d##r1 + d##r2 - d##r3 + d##r4) * -0.2222222; \
auto tmpd0 = d##0 * 0.7111111; \
auto tmpd1 = d##1 * 0.3555556; \
auto tmpd2 = d##2 * 0.1777778; \
auto tmpd3 = d##3 * 0.0888889; \
auto tmpd4 = d##4 * 0.0444444; \
auto tmpdr0 = d##r0 * 0.7111111; \
auto tmpdr1 = d##r1 * 0.3555556; \
auto tmpdr2 = d##r2 * 0.1777778; \
auto tmpdr3 = d##r3 * 0.0888889; \
auto tmpdr4 = d##r4 * 0.0444444; \
wd##3 = tmpd0 + tmpd1 + tmpd2 + tmpd3 + tmpd4; \
wd##r3 = tmpdr0 + tmpdr1 + tmpdr2 + tmpdr3 + tmpdr4; \
wd##4 = tmpd0 - tmpd1 + tmpd2 - tmpd3 + tmpd4; \
wd##r4 = tmpdr0 - tmpdr1 + tmpdr2 - tmpdr3 + tmpdr4; \
tmpd0 = d##0 * 0.0111111; \
tmpd1 = d##1 * 0.0222222; \
tmpd2 = d##2 * 0.0444444; \
tmpd3 = d##3 * 0.0888889; \
tmpd4 = d##4 * 0.1777778; \
tmpdr0 = d##r0 * 0.0111111; \
tmpdr1 = d##r1 * 0.0222222; \
tmpdr2 = d##r2 * 0.0444444; \
tmpdr3 = d##r3 * 0.0888889; \
tmpdr4 = d##r4 * 0.1777778; \
wd##5 = tmpd0 + tmpd1 + tmpd2 + tmpd3 + tmpd4; \
wd##r5 = tmpdr0 + tmpdr1 + tmpdr2 + tmpdr3 + tmpdr4; \
wd##6 = tmpd0 - tmpd1 + tmpd2 - tmpd3 + tmpd4; \
wd##r6 = tmpdr0 - tmpdr1 + tmpdr2 - tmpdr3 + tmpdr4; \
wd##7 = d##4; \
wd##r7 = d##r4; \
} while (0);

#define FILTER_TRANSFORM_FINAL(d, wd) \
do { \
wd##0 = d##0; \
wd##1 = (d##0 + d##1 + d##2 + d##3 + d##4) * -0.2222222; \
wd##2 = (d##0 - d##1 + d##2 - d##3 + d##4) * -0.2222222; \
auto tmp0 = d##0 * 0.7111111 + d##2 * 0.1777778 + d##4 * 0.0444444; \
auto tmp1 = d##1 * 0.3555556 + d##3 * 0.0888889; \
wd##3 = tmp0 + tmp1; \
wd##4 = tmp0 - tmp1; \
tmp0 = d##0 * 0.0111111 + d##2 * 0.0444444 + d##4 * 0.1777778; \
tmp1 = d##1 * 0.0222222 + d##3 * 0.0888889; \
wd##5 = tmp0 + tmp1; \
wd##6 = tmp0 - tmp1; \
wd##7 = d##4; \
} while (0);
static void transform(const float* filter, float* filter_transform_buf,
float* transform_mid_buf, size_t OC, size_t IC,
size_t oc_start, size_t oc_end) {
// Gg * GT
// G
//[[ 1. 0. 0. 0. 0. ]
// [-0.2222222 -0.2222222 -0.2222222 -0.2222222 -0.2222222]
// [-0.2222222 0.2222222 -0.2222222 0.2222222 -0.2222222]
// [ 0.7111111 0.3555556 0.1777778 0.0888889 0.0444444]
// [ 0.7111111 -0.3555556 0.1777778 -0.0888889 0.0444444]
// [ 0.0111111 0.0222222 0.0444444 0.0888889 0.1777778]
// [ 0.0111111 -0.0222222 0.0444444 -0.0888889 0.1777778]
// [ 0. 0. 0. 0. 1. ]]
constexpr size_t alpha = 4 + 5 - 1;
for (size_t oc = oc_start; oc < oc_end; oc++)
rep(ic, IC) {
const float* fptr = filter + (oc * IC + ic) * 5 * 5;

#define cb(i) Vector<float, 4> g##i = Vector<float, 4>::load(fptr + 5 * i);
UNROLL_CALL_NOWRAPPER(5, cb);
#undef cb

#define cb(i) float gr##i = *(fptr + 5 * i + 4);
UNROLL_CALL_NOWRAPPER(5, cb);

#undef cb
#define cb(i) Vector<float, 4> Gg##i;
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb

#define cb(i) float Ggr##i;
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb

#define cb(i) Vector<float, 8> Ggt##i;
UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb

#define cb(i) Vector<float, 8> result##i;
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb

FILTER_TRANSFORM(g, Gg)
float32x4x2_t vgr;
float32x4_t vgr0 = {Ggr0, Ggr1, Ggr2, Ggr3};
float32x4_t vgr1 = {Ggr4, Ggr5, Ggr6, Ggr7};
vgr.val[0] = vgr0; //{Ggr0, Ggr1, Ggr2, Ggr3};
vgr.val[1] = vgr1; //{Ggr4, Ggr5, Ggr6, Ggr7};
Vector<float, 8> Ggt4(vgr);
TRANSPOSE_8x4(Gg, Ggt);
FILTER_TRANSFORM_FINAL(Ggt, result);

#define cb(i) result##i.save(transform_mid_buf + i * alpha);
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
rep(i, alpha) rep(j, alpha) {
filter_transform_buf[(i * alpha + j) * OC * IC + ic * OC +
oc] = transform_mid_buf[j * alpha + i];
}
}
}
};
#undef FILTER_TRANSFORM
#undef FILTER_TRANSFORM_FINAL

struct InputTransform4X5 {
#define INPUT_TRANSFORM(d, wd) \
do { \
wd##0 = (d##0 - d##6) + (d##4 - d##2) * 5.25f; \
auto tmp0 = d##2 - d##4 * 4.25f + d##6; \
auto tmp1 = d##1 - d##3 * 4.25f + d##5; \
wd##1 = tmp0 + tmp1; \
wd##2 = tmp0 - tmp1; \
tmp0 = d##2 * 4.0f - d##4 * 5.0f + d##6; \
tmp1 = d##1 * 2.0f - d##3 * 2.5f + d##5 * 0.5f; \
wd##3 = tmp0 + tmp1; \
wd##4 = tmp0 - tmp1; \
tmp0 = d##2 * 0.25f - d##4 * 1.25f + d##6; \
tmp1 = d##1 * 0.5f - d##3 * 2.5f + d##5 * 2.0f; \
wd##5 = tmp0 + tmp1; \
wd##6 = tmp0 - tmp1; \
wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f; \
} while (0)

#define GET_VECTOR_HIGH_ELEM(s, i, idx) \
vgetq_lane_f32(CONCAT(s, i).value.val[1], idx)
#define GET_VECTOR_LOW_ELEM(s, i, idx) \
vgetq_lane_f32(CONCAT(s, i).value.val[0], idx)

template <bool inner>
static void transform(const float* input, float* input_transform_buf,
float* transform_mid_buf, int ih_start, int iw_start,
size_t ic, size_t IH, size_t IW, size_t IC,
size_t unit_idx, size_t nr_units_in_tile) {
// BTd * B
//([[ 1. , 0. , -5.25, 0. , 5.25, 0. , -1. , 0. ],
// [ 0. , 1. , 1. , -4.25, -4.25, 1. , 1. , 0. ],
// [ 0. , -1. , 1. , 4.25, -4.25, -1. , 1. , 0. ],
// [ 0. , 2. , 4. , -2.5 , -5. , 0.5 , 1. , 0. ],
// [ 0. , -2. , 4. , 2.5 , -5. , -0.5 , 1. , 0. ],
// [ 0. , 0.5 , 0.25, -2.5 , -1.25, 2. , 1. , 0. ],
// [ 0. , -0.5 , 0.25, 2.5 , -1.25, -2. , 1. , 0. ],
// [ 0. , -1. , 0. , 5.25, 0. , -5.25, 0. , 1. ]]))

constexpr size_t alpha = 4 + 5 - 1;
if (!inner) {
memset(transform_mid_buf, 0, sizeof(float) * alpha * alpha);
}

#define cb(i) Vector<float, 8> d##i;
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb

if (inner) {
const float* input_ptr =
input + ic * IH * IW + ih_start * IW + iw_start;
#define cb(i) d##i = Vector<float, 8>::load(input_ptr + IW * i);
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
} else {
int ih0_act = std::max<int>(ih_start, 0),
ih1_act = std::min<int>(ih_start + alpha, IH),
iw0_act = std::max<int>(iw_start, 0),
iw1_act = std::min<int>(iw_start + alpha, IW);
for (int ih = ih0_act; ih < ih1_act; ++ih) {
for (int iw = iw0_act; iw < iw1_act; ++iw) {
size_t iho = ih - ih_start, iwo = iw - iw_start;
transform_mid_buf[iho * alpha + iwo] =
input[ic * IH * IW + ih * IW + iw];
}
}
#define cb(i) d##i = Vector<float, 8>::load(transform_mid_buf + alpha * i);
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
}

#define cb(i) Vector<float, 8> wd##i, ret##i;
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb

INPUT_TRANSFORM(d, wd);
#if MEGDNN_AARCH64
TRANSPOSE_8x8(wd, d);
INPUT_TRANSFORM(d, ret);

#define cb(i) ret##i.save(transform_mid_buf + i * alpha);
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
rep(i, alpha) rep(j, alpha) {
input_transform_buf[(i * alpha + j) * nr_units_in_tile * IC +
unit_idx * IC + ic] =
transform_mid_buf[j * alpha + i];
}
#else
#define cb(i) \
do { \
mid_buf1[0] = GET_VECTOR_LOW_ELEM(wd, i, 0) - \
GET_VECTOR_HIGH_ELEM(wd, i, 2) + \
5.25 * (GET_VECTOR_HIGH_ELEM(wd, i, 0) - \
GET_VECTOR_LOW_ELEM(wd, i, 2)); \
mid_buf1[7] = GET_VECTOR_HIGH_ELEM(wd, i, 3) - \
GET_VECTOR_LOW_ELEM(wd, i, 1) + \
5.25 * (GET_VECTOR_LOW_ELEM(wd, i, 3) - \
GET_VECTOR_HIGH_ELEM(wd, i, 1)); \
auto tmp0 = 4 * GET_VECTOR_LOW_ELEM(wd, i, 2) + \
-5 * GET_VECTOR_HIGH_ELEM(wd, i, 0) + \
GET_VECTOR_HIGH_ELEM(wd, i, 2); \
auto tmp1 = 2 * GET_VECTOR_LOW_ELEM(wd, i, 1) + \
-2.5 * GET_VECTOR_LOW_ELEM(wd, i, 3) + \
0.5 * GET_VECTOR_HIGH_ELEM(wd, i, 1); \
mid_buf1[3] = tmp0 + tmp1; \
mid_buf1[4] = tmp0 - tmp1; \
tmp0 = GET_VECTOR_LOW_ELEM(wd, i, 2) + \
-4.25 * GET_VECTOR_HIGH_ELEM(wd, i, 0) + \
GET_VECTOR_HIGH_ELEM(wd, i, 2); \
tmp1 = GET_VECTOR_LOW_ELEM(wd, i, 1) + \
GET_VECTOR_LOW_ELEM(wd, i, 3) * -4.25 + \
GET_VECTOR_HIGH_ELEM(wd, i, 1); \
mid_buf1[1] = tmp0 + tmp1; \
mid_buf1[2] = tmp0 - tmp1; \
tmp0 = GET_VECTOR_LOW_ELEM(wd, i, 2) * 0.25 + \
GET_VECTOR_HIGH_ELEM(wd, i, 0) * -1.25 + \
GET_VECTOR_HIGH_ELEM(wd, i, 2); \
tmp1 = GET_VECTOR_LOW_ELEM(wd, i, 1) * 0.5 + \
GET_VECTOR_LOW_ELEM(wd, i, 3) * -2.5 + \
GET_VECTOR_HIGH_ELEM(wd, i, 1) * 2; \
mid_buf1[5] = tmp0 + tmp1; \
mid_buf1[6] = tmp0 - tmp1; \
mid_buf1 += 8; \
} while (0);

float* mid_buf1 = transform_mid_buf;
UNROLL_CALL_NOWRAPPER(8, cb);
mid_buf1 = transform_mid_buf;

#undef cb
rep(i, alpha) rep(j, alpha) {
input_transform_buf[(i * alpha + j) * nr_units_in_tile * IC +
unit_idx * IC + ic] =
transform_mid_buf[i * alpha + j];
}
#endif
}
};
#undef INPUT_TRANSFORM

#define OUTPUT_TRANSFORM(m, s) \
do { \
s0 = m0 + m1 + m2 + m3 + m4 + m5 + m6; \
s1 = m1 - m2 + m3 * 0.5 - m4 * 0.5 + m5 * 2.0 - m6 * 2.0; \
s2 = m1 + m2 + m3 * 0.25 + m4 * 0.25 + m5 * 4.0 + m6 * 4.0; \
s3 = m1 - m2 + m3 * 0.125 - m4 * 0.125 + m5 * 8.0 - m6 * 8.0 + m7; \
} while (0)
template <BiasMode bmode, typename Op>
struct OutputTransform4X5 {
static void transform(const float* output_transform_buf, const float* bias,
float* output, float* transform_mid_buf,
size_t oh_start, size_t ow_start, size_t OH,
size_t OW, size_t oc_start, size_t oc_end,
size_t oc_index, size_t unit_idx,
size_t nr_units_in_tile, const DType& src_dtype,
const DType& dst_dtype) {
Op op(src_dtype, dst_dtype);
//! AT * m * A
// AT f45
// 1.0 1.0 1.0 1.000 1.000 1.0 1.0 0.0
// 0.0 1.0 -1.0 0.500 -0.500 2.0 -2.0 0.0
// 0.0 1.0 1.0 0.250 0.250 4.0 4.0 0.0
// 0.0 1.0 -1.0 0.125 -0.125 8.0 -8.0 1.0
constexpr size_t alpha = 5 + 4 - 1;
float* mid_buf1 = transform_mid_buf;

size_t OC = oc_end - oc_start;
size_t oc = oc_start + oc_index;

#define cb(m, n) \
transform_mid_buf[m * alpha + n] = \
output_transform_buf[(m * alpha + n) * nr_units_in_tile * OC + \
unit_idx * OC + oc_index];
UNROLL_CALL_NOWRAPPER_D2(8, 8, cb);
#undef cb

#define cb(i) auto m##i = Vector<float, 8>::load(transform_mid_buf + alpha * i);
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
#define cb(i) Vector<float, 8> s##i;
UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb

OUTPUT_TRANSFORM(m, s);
#define cb(i) \
do { \
auto add12 = \
GET_VECTOR_LOW_ELEM(s, i, 1) + GET_VECTOR_LOW_ELEM(s, i, 2); \
auto add34 = \
GET_VECTOR_LOW_ELEM(s, i, 3) + GET_VECTOR_HIGH_ELEM(s, i, 0); \
auto add56 = \
GET_VECTOR_HIGH_ELEM(s, i, 1) + GET_VECTOR_HIGH_ELEM(s, i, 2); \
auto sub12 = \
GET_VECTOR_LOW_ELEM(s, i, 1) - GET_VECTOR_LOW_ELEM(s, i, 2); \
auto sub34 = \
GET_VECTOR_LOW_ELEM(s, i, 3) - GET_VECTOR_HIGH_ELEM(s, i, 0); \
auto sub56 = \
GET_VECTOR_HIGH_ELEM(s, i, 1) - GET_VECTOR_HIGH_ELEM(s, i, 2); \
mid_buf1[0] = GET_VECTOR_LOW_ELEM(s, i, 0) + add12 + add34 + add56; \
mid_buf1[1] = sub12 + sub34 * 0.5 + sub56 * 2.0; \
mid_buf1[2] = add12 + add34 * 0.25 + add56 * 4.0; \
mid_buf1[3] = sub12 + sub34 * 0.125 + sub56 * 8.0 + \
GET_VECTOR_HIGH_ELEM(s, i, 3); \
mid_buf1 += 4; \
} while (0);

mid_buf1 = transform_mid_buf;
UNROLL_CALL_NOWRAPPER(4, cb);
mid_buf1 = transform_mid_buf;
#undef cb

if (oh_start + 4 <= OH && ow_start + 4 <= OW) {
float32x4_t bias0;
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
bias0 = vdupq_n_f32(bias[oc]);
}
rep(i, 4) {
size_t oh = oh_start + i;
float32x4_t item0 = vld1q_f32(mid_buf1);

if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
item0 = vaddq_f32(item0, bias0);
} else if (bmode == BiasMode::BIAS) {
bias0 = vld1q_f32(bias + oc * OH * OW + oh * OW + ow_start);
item0 = vaddq_f32(item0, bias0);
}
item0 = op(item0);
vst1q_f32(output + oc * OH * OW + oh * OW + ow_start, item0);
mid_buf1 += 4;
}
} else {
for (size_t oho = 0; oho < 4 && oh_start + oho < OH; ++oho) {
for (size_t owo = 0; owo < 4 && ow_start + owo < OW; ++owo) {
size_t oh = oh_start + oho;
size_t ow = ow_start + owo;
float res = mid_buf1[oho * 4 + owo];
if (bmode == BiasMode::BIAS) {
res += bias[oc * OH * OW + oh * OW + ow];
} else if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
res += bias[oc];
}
res = op(res);
output[oc * OH * OW + oh * OW + ow] = res;
}
}
}
}
};
#undef OUTPUT_TRANSFORM
#undef GET_VECTOR_HIGH_ELEM
#undef GET_VECTOR_LOW_ELEM

} // namespace

namespace megdnn {
namespace arm_common {
namespace winograd {

MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_4x5_1x1_f)

void winograd_4x5_1x1_f::filter(const float* filter,
float* filter_transform_buf,
float* transform_mid_buf, size_t OC, size_t IC,
size_t oc_start, size_t oc_end) {
FilterTransform4X5::transform(filter, filter_transform_buf,
transform_mid_buf, OC, IC, oc_start, oc_end);
}

void winograd_4x5_1x1_f::input(const float* input, float* input_transform_buf,
float* transform_mid_buf, size_t IH, size_t IW,
size_t IC, size_t PH, size_t PW,
size_t unit_start_idx, size_t nr_units_in_tile) {
constexpr int alpha = 4 + 5 - 1;

// OW = IW + 2 * PW - KERNEL_SIZE + 1
auto units_w = div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE);
rep(ic, IC) {
rep(unit_idx, nr_units_in_tile) {
size_t index = unit_start_idx + unit_idx;
size_t nh = index / units_w;
size_t nw = index % units_w;
int ih_start = nh * OUTPUT_BLOCK_SIZE - PH;
int iw_start = nw * OUTPUT_BLOCK_SIZE - PW;
if (ih_start >= 0 && ih_start + alpha <= static_cast<int>(IH) &&
iw_start >= 0 && iw_start + alpha <= static_cast<int>(IW)) {
InputTransform4X5::transform<true>(
input, input_transform_buf, transform_mid_buf, ih_start,
iw_start, ic, IH, IW, IC, unit_idx, nr_units_in_tile);

} else {
InputTransform4X5::transform<false>(
input, input_transform_buf, transform_mid_buf, ih_start,
iw_start, ic, IH, IW, IC, unit_idx, nr_units_in_tile);
}
}
}
}

void winograd_4x5_1x1_f::output(const float* output_transform_buf,
const float* bias, float* output,
float* transform_mid_buf, BiasMode bmode,
NonlineMode nonline_mode, size_t OH, size_t OW,
size_t oc_start, size_t oc_end, size_t unit_start_idx,
size_t nr_units_in_tile) {
#define cb(_bmode, _nonline_op, ...) \
OutputTransform4X5<_bmode MEGDNN_COMMA _nonline_op>::transform(__VA_ARGS__);

auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE);

for (size_t oc = oc_start; oc < oc_end; oc++) {
size_t oc_index = oc - oc_start;
rep(unit_idx, nr_units_in_tile) {
size_t index = unit_start_idx + unit_idx;
auto nh = index / units_w;
auto nw = index % units_w;
size_t oh_start = nh * OUTPUT_BLOCK_SIZE;
size_t ow_start = nw * OUTPUT_BLOCK_SIZE;
DISPATCH_CONV_WINOGRAD_BIAS(
megdnn_arm_common_winograd_fp32_F45, cb, float, float, bmode,
nonline_mode, output_transform_buf, bias, output, transform_mid_buf,
oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx,
nr_units_in_tile, src_dtype, dst_dtype);
}
}
#undef cb
}

} // namespace winograd
} // namespace arm_common
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 500
- 0
dnn/src/arm_common/conv_bias/fp32/strategy_5x4.cpp View File

@@ -0,0 +1,500 @@
/**
* \file dnn/src/arm_common/conv_bias/fp32/strategy_5x4.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/arm_common/conv_bias/fp32/strategy.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/utils.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/winograd/winograd.h"

#include "src/arm_common/conv_bias/fp32/helper.h"
#include "src/arm_common/elemwise_helper/op_unary.h"
#include "src/naive/matrix_mul/matrix_mul_helper.h"

#include "midout.h"
MIDOUT_DECL(megdnn_arm_common_winograd_fp32_F54)

using namespace megdnn;
using namespace arm_common;
namespace {

struct FilterTransform5X4 {
#define FILTER_TRANSFORM(d, wd) \
do { \
wd##0 = d##0; \
auto tmp0 = d##0 * 0.7111111f + d##2 * 0.1777778f; \
auto tmp1 = d##1 * 0.3555556f + d##3 * 0.0888889f; \
wd##1 = tmp0 + tmp1; \
wd##2 = tmp0 - tmp1; \
tmp0 = (d##0 + d##2) * -0.2222222f; \
tmp1 = (d##1 + d##3) * -0.2222222f; \
wd##3 = tmp0 + tmp1; \
wd##4 = tmp0 - tmp1; \
tmp0 = d##0 * 0.0111111f + d##2 * 0.0444444f; \
tmp1 = d##1 * 0.0222222f + d##3 * 0.0888889f; \
wd##5 = tmp0 + tmp1; \
wd##6 = tmp0 - tmp1; \
wd##7 = d##3; \
} while (0)

static void transform(const float* filter, float* filter_transform_buf,
float* transform_mid_buf, size_t OC, size_t IC,
size_t oc_start, size_t oc_end) {
// Gg * GT
// G
// 1 0 0 0
// 0.7111111 0.3555556 0.1777778 0.0888889
// 0.7111111 -0.3555556 0.1777778 -0.0888889
// -0.2222222 -0.2222222 -0.2222222 -0.2222222
// -0.2222222 0.2222222 -0.2222222 0.2222222
// 0.0111111 0.0222222 0.0444444 0.0888889
// 0.0111111 -0.0222222 0.0444444 -0.0888889
// 0 0 0 1

constexpr size_t alpha = 4 + 5 - 1;
for (size_t oc = oc_start; oc < oc_end; oc++) {
rep(ic, IC) {
const float* fptr = filter + (oc * IC + ic) * 4 * 4;

#define cb(i) Vector<float, 4> g##i = Vector<float, 4>::load(fptr + 4 * i);
UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb

#define cb(i) Vector<float, 4> wd##i;
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb

#define cb(i) Vector<float, 8> wdt##i;
UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb

#define cb(i) Vector<float, 8> ret##i;
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb

FILTER_TRANSFORM(g, wd);
#if MEGDNN_AARCH64
TRANSPOSE_8x4(wd, wdt);
FILTER_TRANSFORM(wdt, ret);

#define cb(i) ret##i.save(transform_mid_buf + i * alpha);
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
rep(i, alpha) rep(j, alpha) {
filter_transform_buf[(i * alpha + j) * OC * IC + ic * OC +
oc] = transform_mid_buf[j * alpha + i];
}
#else

#define cb(i) \
do { \
mid_buf1[0] = GET_VECTOR_ELEM(wd, i, 0); \
auto tmp0 = GET_VECTOR_ELEM(wd, i, 0) * 0.7111111f + \
GET_VECTOR_ELEM(wd, i, 2) * 0.1777778f; \
auto tmp1 = GET_VECTOR_ELEM(wd, i, 1) * 0.3555556f + \
GET_VECTOR_ELEM(wd, i, 3) * 0.0888889f; \
mid_buf1[1] = tmp0 + tmp1; \
mid_buf1[2] = tmp0 - tmp1; \
tmp0 = (GET_VECTOR_ELEM(wd, i, 0) + GET_VECTOR_ELEM(wd, i, 2)) * \
-0.2222222f; \
tmp1 = (GET_VECTOR_ELEM(wd, i, 1) + GET_VECTOR_ELEM(wd, i, 3)) * \
-0.2222222f; \
mid_buf1[3] = tmp0 + tmp1; \
mid_buf1[4] = tmp0 - tmp1; \
tmp0 = GET_VECTOR_ELEM(wd, i, 0) * 0.0111111f + \
GET_VECTOR_ELEM(wd, i, 2) * 0.0444444f; \
tmp1 = GET_VECTOR_ELEM(wd, i, 1) * 0.0222222f + \
GET_VECTOR_ELEM(wd, i, 3) * 0.0888889f; \
mid_buf1[5] = tmp0 + tmp1; \
mid_buf1[6] = tmp0 - tmp1; \
mid_buf1[7] = GET_VECTOR_ELEM(wd, i, 3); \
mid_buf1 += 8; \
} while (0);
#define GET_VECTOR_ELEM(s, i, idx) vgetq_lane_f32(CONCAT(s, i).value, idx)

float* mid_buf1 = transform_mid_buf;
UNROLL_CALL_NOWRAPPER(8, cb);
mid_buf1 = transform_mid_buf;
#undef cb
rep(i, alpha) rep(j, alpha) {
filter_transform_buf[(i * alpha + j) * OC * IC + ic * OC +
oc] = transform_mid_buf[i * alpha + j];
}
#endif
}
}
}
};
#undef FILTER_TRANSFORM
#undef GET_VECTOR_ELEM

struct InputTransform5X4 {
#define INPUT_TRANSFORM(d, wd) \
do { \
wd##0 = (d##0 - d##6) + (d##4 - d##2) * 5.25f; \
auto tmp0 = d##2 * 4.0f - d##4 * 5.0f + d##6; \
auto tmp1 = d##1 * 2.0f - d##3 * 2.5f + d##5 * 0.5f; \
wd##1 = tmp0 + tmp1; \
wd##2 = tmp0 - tmp1; \
tmp0 = d##2 - d##4 * 4.25f + d##6; \
tmp1 = d##1 - d##3 * 4.25f + d##5; \
wd##3 = tmp0 + tmp1; \
wd##4 = tmp0 - tmp1; \
tmp0 = d##2 * 0.25f - d##4 * 1.25f + d##6; \
tmp1 = d##1 * 0.5f - d##3 * 2.5f + d##5 * 2.0f; \
wd##5 = tmp0 + tmp1; \
wd##6 = tmp0 - tmp1; \
wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f; \
} while (0)

#define GET_VECTOR_HIGH_ELEM(s, i, idx) \
vgetq_lane_f32(CONCAT(s, i).value.val[1], idx)
#define GET_VECTOR_LOW_ELEM(s, i, idx) \
vgetq_lane_f32(CONCAT(s, i).value.val[0], idx)

template <bool inner>
static void transform(const float* input, float* input_transform_buf,
float* transform_mid_buf, int ih_start, int iw_start,
size_t ic, size_t IH, size_t IW, size_t IC,
size_t unit_idx, size_t nr_units_in_tile) {
// BTd * B
// BT
// 1 0 -5.25 0 5.25 0 -1 0
// 0 2 4 -2.5 -5 0.5 1 0
// 0 -2 4 2.5 -5 -0.5 1 0
// 0 1 1 -4.25 -4.25 1 1 0
// 0 -1 1 4.25 -4.25 -1 1 0
// 0 0.5 0.25 -2.5 -1.25 2 1 0
// 0 -0.5 0.25 2.5 -1.25 -2 1 0
// 0 -1 0 5.25 0 -5.25 0 1

constexpr size_t alpha = 4 + 5 - 1;
if (!inner) {
memset(transform_mid_buf, 0, sizeof(float) * alpha * alpha);
}

#define cb(i) Vector<float, 8> d##i;
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb

if (inner) {
const float* input_ptr =
input + ic * IH * IW + ih_start * IW + iw_start;
#define cb(i) d##i = Vector<float, 8>::load(input_ptr + IW * i);
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
} else {
int ih0_act = std::max<int>(ih_start, 0),
ih1_act = std::min<int>(ih_start + alpha, IH),
iw0_act = std::max<int>(iw_start, 0),
iw1_act = std::min<int>(iw_start + alpha, IW);
for (int ih = ih0_act; ih < ih1_act; ++ih) {
for (int iw = iw0_act; iw < iw1_act; ++iw) {
size_t iho = ih - ih_start, iwo = iw - iw_start;
transform_mid_buf[iho * alpha + iwo] =
input[ic * IH * IW + ih * IW + iw];
}
}
#define cb(i) d##i = Vector<float, 8>::load(transform_mid_buf + alpha * i);
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
}

#define cb(i) Vector<float, 8> wd##i, ret##i;
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb

INPUT_TRANSFORM(d, wd);
#if MEGDNN_AARCH64
TRANSPOSE_8x8(wd, d);
INPUT_TRANSFORM(d, ret);

#define cb(i) ret##i.save(transform_mid_buf + i * alpha);
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
rep(i, alpha) rep(j, alpha) {
input_transform_buf[(i * alpha + j) * nr_units_in_tile * IC +
unit_idx * IC + ic] =
transform_mid_buf[j * alpha + i];
}
#else
#define cb(i) \
do { \
mid_buf1[0] = GET_VECTOR_LOW_ELEM(wd, i, 0) - \
GET_VECTOR_HIGH_ELEM(wd, i, 2) + \
5.25 * (GET_VECTOR_HIGH_ELEM(wd, i, 0) - \
GET_VECTOR_LOW_ELEM(wd, i, 2)); \
mid_buf1[7] = GET_VECTOR_HIGH_ELEM(wd, i, 3) - \
GET_VECTOR_LOW_ELEM(wd, i, 1) + \
5.25 * (GET_VECTOR_LOW_ELEM(wd, i, 3) - \
GET_VECTOR_HIGH_ELEM(wd, i, 1)); \
auto tmp0 = 4 * GET_VECTOR_LOW_ELEM(wd, i, 2) + \
-5 * GET_VECTOR_HIGH_ELEM(wd, i, 0) + \
GET_VECTOR_HIGH_ELEM(wd, i, 2); \
auto tmp1 = 2 * GET_VECTOR_LOW_ELEM(wd, i, 1) + \
-2.5 * GET_VECTOR_LOW_ELEM(wd, i, 3) + \
0.5 * GET_VECTOR_HIGH_ELEM(wd, i, 1); \
mid_buf1[1] = tmp0 + tmp1; \
mid_buf1[2] = tmp0 - tmp1; \
tmp0 = GET_VECTOR_LOW_ELEM(wd, i, 2) + \
-4.25 * GET_VECTOR_HIGH_ELEM(wd, i, 0) + \
GET_VECTOR_HIGH_ELEM(wd, i, 2); \
tmp1 = GET_VECTOR_LOW_ELEM(wd, i, 1) + \
GET_VECTOR_LOW_ELEM(wd, i, 3) * -4.25 + \
GET_VECTOR_HIGH_ELEM(wd, i, 1); \
mid_buf1[3] = tmp0 + tmp1; \
mid_buf1[4] = tmp0 - tmp1; \
tmp0 = GET_VECTOR_LOW_ELEM(wd, i, 2) * 0.25 + \
GET_VECTOR_HIGH_ELEM(wd, i, 0) * -1.25 + \
GET_VECTOR_HIGH_ELEM(wd, i, 2); \
tmp1 = GET_VECTOR_LOW_ELEM(wd, i, 1) * 0.5 + \
GET_VECTOR_LOW_ELEM(wd, i, 3) * -2.5 + \
GET_VECTOR_HIGH_ELEM(wd, i, 1) * 2; \
mid_buf1[5] = tmp0 + tmp1; \
mid_buf1[6] = tmp0 - tmp1; \
mid_buf1 += 8; \
} while (0);

float* mid_buf1 = transform_mid_buf;
UNROLL_CALL_NOWRAPPER(8, cb);
mid_buf1 = transform_mid_buf;

#undef cb
rep(i, alpha) rep(j, alpha) {
input_transform_buf[(i * alpha + j) * nr_units_in_tile * IC +
unit_idx * IC + ic] =
transform_mid_buf[i * alpha + j];
}
#endif
}
};
#undef INPUT_TRANSFORM

#define OUTPUT_TRANSFORM(m, s) \
do { \
auto m1addm2 = m##1 + m##2; \
auto m1subm2 = m##1 - m##2; \
auto m3addm4 = m##3 + m##4; \
auto m3subm4 = m##3 - m##4; \
auto m5addm6 = (m##5 + m##6); \
auto m5subm6 = (m##5 - m##6); \
s##0 = m##0; \
CONCAT(s, 0).add(m1addm2).add(m3addm4).add(m5addm6); \
CONCAT(s, 1) = m3subm4; \
CONCAT(s, 1).mla(m1subm2, 0.5f).mla(m5subm6, 2.0f); \
CONCAT(s, 2) = m3addm4; \
CONCAT(s, 2).mla(m1addm2, 0.25f).mla(m5addm6, 4.0f); \
CONCAT(s, 3) = m3subm4; \
CONCAT(s, 3).mla(m1subm2, 0.125f).mla(m5subm6, 8.0f); \
CONCAT(s, 4) = m##7; \
CONCAT(s, 4).mla(m1addm2, 0.0625f).add(m3addm4).mla(m5addm6, 16.0f); \
} while (0)

template <BiasMode bmode, typename Op>
struct OutputTransform5X4 {
static void transform(const float* output_transform_buf, const float* bias,
float* output, float* transform_mid_buf,
size_t oh_start, size_t ow_start, size_t OH,
size_t OW, size_t oc_start, size_t oc_end,
size_t oc_index, size_t unit_idx,
size_t nr_units_in_tile, const DType& src_dtype,
const DType& dst_dtype) {
Op op(src_dtype, dst_dtype);
//! AT * m * A
// AT
// 1 1 1 1 1 1 1 0
// 0 0.5 -0.5 1 -1 2 -2 0
// 0 0.25 0.25 1 1 4 4 0
// 0 0.125 -0.125 1 -1 8 -8 0
// 0 0.0625 0.0625 1 1 16 16 1
constexpr size_t alpha = 5 + 4 - 1;
float* mid_buf1 = transform_mid_buf;

size_t OC = oc_end - oc_start;
size_t oc = oc_start + oc_index;

#define cb(m, n) \
transform_mid_buf[m * alpha + n] = \
output_transform_buf[(m * alpha + n) * nr_units_in_tile * OC + \
unit_idx * OC + oc_index];
UNROLL_CALL_NOWRAPPER_D2(8, 8, cb);
#undef cb

#define cb(i) auto m##i = Vector<float, 8>::load(transform_mid_buf + alpha * i);
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
#define cb(i) Vector<float, 8> s##i, ret##i;
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb

OUTPUT_TRANSFORM(m, s);
#define cb(i) \
do { \
auto m1addm2 = \
GET_VECTOR_LOW_ELEM(s, i, 1) + GET_VECTOR_LOW_ELEM(s, i, 2); \
auto m1subm2 = \
GET_VECTOR_LOW_ELEM(s, i, 1) - GET_VECTOR_LOW_ELEM(s, i, 2); \
auto m3addm4 = \
GET_VECTOR_LOW_ELEM(s, i, 3) + GET_VECTOR_HIGH_ELEM(s, i, 0); \
auto m3subm4 = \
GET_VECTOR_LOW_ELEM(s, i, 3) - GET_VECTOR_HIGH_ELEM(s, i, 0); \
auto m5addm6 = \
GET_VECTOR_HIGH_ELEM(s, i, 1) + GET_VECTOR_HIGH_ELEM(s, i, 2); \
auto m5subm6 = \
GET_VECTOR_HIGH_ELEM(s, i, 1) - GET_VECTOR_HIGH_ELEM(s, i, 2); \
mid_buf1[0] = \
GET_VECTOR_LOW_ELEM(s, i, 0) + m1addm2 + m3addm4 + m5addm6; \
mid_buf1[1] = 0.5f * m1subm2 + m3subm4 + 2.0f * m5subm6; \
mid_buf1[2] = 0.25f * m1addm2 + m3addm4 + 4.0f * m5addm6; \
mid_buf1[3] = 0.125f * m1subm2 + m3subm4 + 8.0f * m5subm6; \
mid_buf1[4] = 0.0625f * m1addm2 + m3addm4 + 16.0f * m5addm6 + \
GET_VECTOR_HIGH_ELEM(s, i, 3); \
mid_buf1 += 5; \
} while (0);

mid_buf1 = transform_mid_buf;
UNROLL_CALL_NOWRAPPER(5, cb);
mid_buf1 = transform_mid_buf;
#undef cb

if (oh_start + 5 <= OH && ow_start + 5 <= OW) {
float32x4_t bias0;
float32_t bias1;
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
bias0 = vdupq_n_f32(bias[oc]);
bias1 = bias[oc];
}
rep(i, 5) {
size_t oh = oh_start + i;
float32x4_t item0 = vld1q_f32(mid_buf1);
float32_t item1 = mid_buf1[4];

if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
item0 = vaddq_f32(item0, bias0);
item1 = item1 + bias1;
} else if (bmode == BiasMode::BIAS) {
bias0 = vld1q_f32(bias + oc * OH * OW + oh * OW + ow_start);
bias1 = bias[oc * OH * OW + oh * OW + ow_start + 4];
item0 = vaddq_f32(item0, bias0);
item1 = item1 + bias1;
}
item0 = op(item0);
item1 = op(item1);
vst1q_f32(output + oc * OH * OW + oh * OW + ow_start, item0);
output[oc * OH * OW + oh * OW + ow_start + 4] = item1;

mid_buf1 += 5;
}
} else {
for (size_t oho = 0; oho < 5 && oh_start + oho < OH; ++oho) {
for (size_t owo = 0; owo < 5 && ow_start + owo < OW; ++owo) {
size_t oh = oh_start + oho;
size_t ow = ow_start + owo;
float res = mid_buf1[oho * 5 + owo];
if (bmode == BiasMode::BIAS) {
res += bias[oc * OH * OW + oh * OW + ow];
} else if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
res += bias[oc];
}
res = op(res);
output[oc * OH * OW + oh * OW + ow] = res;
}
}
}
}
};
#undef OUTPUT_TRANSFORM
#undef GET_VECTOR_HIGH_ELEM
#undef GET_VECTOR_LOW_ELEM

} // namespace

namespace megdnn {
namespace arm_common {
namespace winograd {

MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_5x4_1x1_f)

void winograd_5x4_1x1_f::filter(const float* filter,
float* filter_transform_buf,
float* transform_mid_buf, size_t OC, size_t IC,
size_t oc_start, size_t oc_end) {
FilterTransform5X4::transform(filter, filter_transform_buf,
transform_mid_buf, OC, IC, oc_start, oc_end);
}

void winograd_5x4_1x1_f::input(const float* input, float* input_transform_buf,
float* transform_mid_buf, size_t IH, size_t IW,
size_t IC, size_t PH, size_t PW,
size_t unit_start_idx, size_t nr_units_in_tile) {
constexpr int alpha = 5 + 4 - 1;

// OW = IW + 2 * PW - KERNEL_SIZE + 1
auto units_w = div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE);

rep(ic, IC) {
rep(unit_idx, nr_units_in_tile) {
size_t index = unit_start_idx + unit_idx;
size_t nh = index / units_w;
size_t nw = index % units_w;
int ih_start = nh * OUTPUT_BLOCK_SIZE - PH;
int iw_start = nw * OUTPUT_BLOCK_SIZE - PW;
if (ih_start >= 0 && ih_start + alpha <= static_cast<int>(IH) &&
iw_start >= 0 && iw_start + alpha <= static_cast<int>(IW)) {
InputTransform5X4::transform<true>(
input, input_transform_buf, transform_mid_buf, ih_start,
iw_start, ic, IH, IW, IC, unit_idx, nr_units_in_tile);

} else {
InputTransform5X4::transform<false>(
input, input_transform_buf, transform_mid_buf, ih_start,
iw_start, ic, IH, IW, IC, unit_idx, nr_units_in_tile);
}
}
}
}

void winograd_5x4_1x1_f::output(const float* output_transform_buf,
const float* bias, float* output,
float* transform_mid_buf, BiasMode bmode,
NonlineMode nonline_mode, size_t OH, size_t OW,
size_t oc_start, size_t oc_end,
size_t unit_start_idx,
size_t nr_units_in_tile) {
#define cb(_bmode, _nonline_op, ...) \
OutputTransform5X4<_bmode MEGDNN_COMMA _nonline_op>::transform(__VA_ARGS__);

auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE);

for (size_t oc = oc_start; oc < oc_end; oc++) {
size_t oc_index = oc - oc_start;
rep(unit_idx, nr_units_in_tile) {
size_t index = unit_start_idx + unit_idx;
auto nh = index / units_w;
auto nw = index % units_w;
size_t oh_start = nh * OUTPUT_BLOCK_SIZE;
size_t ow_start = nw * OUTPUT_BLOCK_SIZE;
DISPATCH_CONV_WINOGRAD_BIAS(
megdnn_arm_common_winograd_fp32_F54, cb, float, float, bmode,
nonline_mode, output_transform_buf, bias, output, transform_mid_buf,
oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx,
nr_units_in_tile, src_dtype, dst_dtype);
}
}
#undef cb
}

} // namespace winograd
} // namespace arm_common
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 424
- 0
dnn/src/arm_common/conv_bias/fp32/strategy_6x3.cpp View File

@@ -0,0 +1,424 @@
/**
* \file dnn/src/arm_common/conv_bias/fp32/strategy_6x3.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/arm_common/conv_bias/fp32/filter_transform.h"
#include "src/arm_common/conv_bias/fp32/helper.h"
#include "src/arm_common/conv_bias/fp32/strategy.h"
#include "src/arm_common/elemwise_helper/op_unary.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/utils.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/winograd/winograd.h"
#include "src/naive/matrix_mul/matrix_mul_helper.h"

#include "midout.h"
MIDOUT_DECL(megdnn_arm_common_winograd_fp32_F63)

using namespace megdnn;
using namespace arm_common;
namespace {

/**
* input transform
*
* wd0 = (d0 - d6) + 5.25 * (d4 - d2)
* wd1 = (d6 + d2 - 4.25 * d4) + (d1 + d5 - 4.25 * d3)
* wd2 = (d6 + d2 - 4.25 * d4) - (d1 + d5 - 4.25 * d3)
* wd3 = (d6 + 0.25 * d2 - 1.25 * d4) + 2.0 * (d5 + 0.25 * d1 - 1.25 * d3)
* wd4 = (d6 + 0.25 * d2 - 1.25 * d4) - 2.0 * (d5 + 0.25 * d1 - 1.25 * d3)
* wd5 = (d6 - 5.0 * d4 + 4.0 * d2) + 2.0 * (d1 + 0.25 * d5 - 1.25 * d3)
* wd6 = (d6 - 5.0 * d4 + 4.0 * d2) - 2.0 * (d1 + 0.25 * d5 - 1.25 * d3)
* wd7 = (d7 - d1) + 5.25 * (d3 - d5)
*/
#define INPUT_TRANSFORM(d, wd) \
do { \
wd##0 = (d##0 - d##6) + (d##4 - d##2) * 5.25f; \
auto tmp0 = d##6 + d##2 - d##4 * 4.25f; \
auto tmp1 = d##1 + d##5 - d##3 * 4.25f; \
wd##1 = tmp0 + tmp1; \
wd##2 = tmp0 - tmp1; \
tmp0 = d##6 + d##2 * 0.25f - d##4 * 1.25f; \
tmp1 = (d##5 + d##1 * 0.25f - d##3 * 1.25f) * 2.0f; \
wd##3 = tmp0 + tmp1; \
wd##4 = tmp0 - tmp1; \
tmp0 = d6 - d4 * 5.0f + d2 * 4.0f; \
tmp1 = (d1 + d5 * 0.25f - d3 * 1.25f) * 2.0f; \
wd##5 = tmp0 + tmp1; \
wd##6 = tmp0 - tmp1; \
wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f; \
} while (0);

#define GET_VECTOR_HIGH_ELEM(s, i, idx) \
vgetq_lane_f32(CONCAT(s, i).value.val[1], idx)
#define GET_VECTOR_LOW_ELEM(s, i, idx) \
vgetq_lane_f32(CONCAT(s, i).value.val[0], idx)
struct InputTransform6X3 {
template <bool inner>
static void transform(const float* input, float* input_transform_buf,
float* transform_mid_buf, int ih_start, int iw_start,
size_t ic, size_t IH, size_t IW, size_t IC,
size_t unit_idx, size_t nr_units_in_tile) {
constexpr size_t alpha = 6 + 3 - 1;
if (!inner) {
memset(transform_mid_buf, 0, sizeof(float) * alpha * alpha);
}

#define cb(i) Vector<float, 8> d##i;
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb

if (inner) {
const float* input_ptr =
input + ic * IH * IW + ih_start * IW + iw_start;
#define cb(i) d##i = Vector<float, 8>::load(input_ptr + IW * i);
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
} else {
int ih0_act = std::max<int>(ih_start, 0),
ih1_act = std::min<int>(ih_start + alpha, IH),
iw0_act = std::max<int>(iw_start, 0),
iw1_act = std::min<int>(iw_start + alpha, IW);
for (int ih = ih0_act; ih < ih1_act; ++ih) {
for (int iw = iw0_act; iw < iw1_act; ++iw) {
size_t iho = ih - ih_start, iwo = iw - iw_start;
transform_mid_buf[iho * alpha + iwo] =
input[ic * IH * IW + ih * IW + iw];
}
}
#define cb(i) d##i = Vector<float, 8>::load(transform_mid_buf + alpha * i);
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
}

#define cb(i) Vector<float, 8> wd##i, ret##i;
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb

INPUT_TRANSFORM(d, wd);

#if MEGDNN_AARCH64
TRANSPOSE_8x8(wd, d);
INPUT_TRANSFORM(d, ret);

#define cb(i) ret##i.save(transform_mid_buf + i * alpha);
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb

rep(i, alpha) rep(j, alpha) {
input_transform_buf[(i * alpha + j) * nr_units_in_tile * IC +
unit_idx * IC + ic] =
transform_mid_buf[j * alpha + i];
}
#else
//! 1 0 0 0 0 0 0 0
//! 0 1 -1 0.5 -0.5 2 -2 -1
//! -5.25 1 1 0.25 0.25 4 4 0
//! 0 -4.25 4.25 -2.5 2.5 -2.5 2.5 5.25
//! 5.25 -4.25 -4.25 -1.25 -1.25 -5 -5 0
//! 0 1 -1 2 -2 0.5 -0.5 -5.25
//! -1 1 1 1 1 1 1 0
//! 0 0 0 0 0 0 0 1
#define cb(i) \
do { \
mid_buf1[0] = GET_VECTOR_LOW_ELEM(wd, i, 0) - \
GET_VECTOR_HIGH_ELEM(wd, i, 2) + \
5.25f * (GET_VECTOR_HIGH_ELEM(wd, i, 0) - \
GET_VECTOR_LOW_ELEM(wd, i, 2)); \
mid_buf1[7] = GET_VECTOR_HIGH_ELEM(wd, i, 3) - \
GET_VECTOR_LOW_ELEM(wd, i, 1) + \
5.25f * (GET_VECTOR_LOW_ELEM(wd, i, 3) - \
GET_VECTOR_HIGH_ELEM(wd, i, 1)); \
auto tmp0 = GET_VECTOR_LOW_ELEM(wd, i, 2) + \
GET_VECTOR_HIGH_ELEM(wd, i, 2) - \
4.25f * GET_VECTOR_HIGH_ELEM(wd, i, 0); \
auto tmp1 = GET_VECTOR_LOW_ELEM(wd, i, 1) + \
GET_VECTOR_HIGH_ELEM(wd, i, 1) - \
4.25f * GET_VECTOR_LOW_ELEM(wd, i, 3); \
mid_buf1[1] = tmp0 + tmp1; \
mid_buf1[2] = tmp0 - tmp1; \
tmp0 = GET_VECTOR_HIGH_ELEM(wd, i, 2) + \
0.25f * GET_VECTOR_LOW_ELEM(wd, i, 2) - \
GET_VECTOR_HIGH_ELEM(wd, i, 0) * 1.25f; \
tmp1 = GET_VECTOR_LOW_ELEM(wd, i, 1) * 0.5f - \
GET_VECTOR_LOW_ELEM(wd, i, 3) * 2.5f + \
GET_VECTOR_HIGH_ELEM(wd, i, 1) * 2.f; \
mid_buf1[3] = tmp0 + tmp1; \
mid_buf1[4] = tmp0 - tmp1; \
tmp0 = GET_VECTOR_HIGH_ELEM(wd, i, 2) + \
(GET_VECTOR_LOW_ELEM(wd, i, 2) - \
GET_VECTOR_HIGH_ELEM(wd, i, 0) * 1.25f) * \
4; \
tmp1 = GET_VECTOR_LOW_ELEM(wd, i, 1) * 2.f - \
GET_VECTOR_LOW_ELEM(wd, i, 3) * 2.5f + \
GET_VECTOR_HIGH_ELEM(wd, i, 1) * 0.5f; \
mid_buf1[5] = tmp0 + tmp1; \
mid_buf1[6] = tmp0 - tmp1; \
mid_buf1 += 8; \
} while (0);

float* mid_buf1 = transform_mid_buf;
UNROLL_CALL_NOWRAPPER(8, cb);
mid_buf1 = transform_mid_buf;

#undef cb
rep(i, alpha) rep(j, alpha) {
input_transform_buf[(i * alpha + j) * nr_units_in_tile * IC +
unit_idx * IC + ic] =
transform_mid_buf[i * alpha + j];
}
#endif
}
};

#undef INPUT_TRANSFORM

/**
* Output Transform: use fma
*
* s0 = m0 + (m1 + m2) + (m3 + m4) + 32 * (m5 + m6) / 32
* s1 = (m1 - m2) + 2 * (m3 - m4) + 16 * (m5 - m6) / 32
* s2 = (m1 + m2) + 4 * (m3 + m4) + 8 * (m5 + m6) / 32
* s3 = (m1 - m2) + 8 * (m3 - m4) + 4 * (m5 - m6) / 32
* s4 = (m1 + m2) + 16 * (m3 + m4) + 2 * (m5 + m6) / 32
* s5 = (m1 - m2) + 32 * (m3 - m4) + (m5 - m6) / 32 + m7
*/
#define OUTPUT_TRANSFORM(m, s) \
do { \
auto m1addm2 = m##1 + m##2; \
auto m1subm2 = m##1 - m##2; \
auto m3addm4 = m##3 + m##4; \
auto m3subm4 = m##3 - m##4; \
auto m5addm6 = (m##5 + m##6) * 0.03125f; \
auto m5subm6 = (m##5 - m##6) * 0.03125f; \
s##0 = m##0; \
CONCAT(s, 0).mla(m5addm6, 32.f).add(m3addm4).add(m1addm2); \
CONCAT(s, 1) = m1subm2; \
CONCAT(s, 1).mla(m3subm4, 2.f).mla(m5subm6, 16.f); \
CONCAT(s, 2) = m1addm2; \
CONCAT(s, 2).mla(m3addm4, 4.f).mla(m5addm6, 8.f); \
CONCAT(s, 3) = m1subm2; \
CONCAT(s, 3).mla(m3subm4, 8.f).mla(m5subm6, 4.f); \
CONCAT(s, 4) = m1addm2; \
CONCAT(s, 4).mla(m3addm4, 16.f).mla(m5addm6, 2.f); \
CONCAT(s, 5) = m1subm2; \
CONCAT(s, 5).mla(m3subm4, 32.f).add(m5subm6).add(m##7); \
} while (0);

template <BiasMode bmode, typename Op>
struct OutputTransform6X3 {
static void transform(const float* output_transform_buf, const float* bias,
float* output, float* transform_mid_buf,
size_t oh_start, size_t ow_start, size_t OH,
size_t OW, size_t oc_start, size_t oc_end,
size_t oc_index, size_t unit_idx,
size_t nr_units_in_tile, const DType& src_dtype,
const DType& dst_dtype) {
constexpr size_t alpha = 6 + 3 - 1;
Op op(src_dtype, dst_dtype);
float* mid_buf1 = transform_mid_buf;

//! AT * m * A
size_t OC = oc_end - oc_start;
size_t oc = oc_start + oc_index;

#define cb(m, n) \
transform_mid_buf[m * alpha + n] = \
output_transform_buf[(m * alpha + n) * nr_units_in_tile * OC + \
unit_idx * OC + oc_index];
UNROLL_CALL_NOWRAPPER_D2(8, 8, cb);
#undef cb

#define cb(i) auto m##i = Vector<float, 8>::load(transform_mid_buf + alpha * i);
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
#define cb(i) Vector<float, 8> s##i, ret##i;
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb

OUTPUT_TRANSFORM(m, s);
/**
* Output transform: m * A
*
* 1 0 0 0 0 0
* 1 1 1 1 1 1
* 1 -1 1 -1 1 -1
* 1 2 4 8 16 32
* 1 -2 4 -8 16 -32
* 1 0.5 0.25 0.125 0.0625 0.03125
* 1 -0.5 0.25 -0.125 0.0625 -0.03125
* 0 0.0 0 0 0 1
*/
#define cb(i) \
do { \
auto m1addm2 = \
GET_VECTOR_LOW_ELEM(s, i, 1) + GET_VECTOR_LOW_ELEM(s, i, 2); \
auto m1subm2 = \
GET_VECTOR_LOW_ELEM(s, i, 1) - GET_VECTOR_LOW_ELEM(s, i, 2); \
auto m3addm4 = \
GET_VECTOR_LOW_ELEM(s, i, 3) + GET_VECTOR_HIGH_ELEM(s, i, 0); \
auto m3subm4 = \
GET_VECTOR_LOW_ELEM(s, i, 3) - GET_VECTOR_HIGH_ELEM(s, i, 0); \
auto m5addm6 = \
GET_VECTOR_HIGH_ELEM(s, i, 1) + GET_VECTOR_HIGH_ELEM(s, i, 2); \
auto m5subm6 = \
GET_VECTOR_HIGH_ELEM(s, i, 1) - GET_VECTOR_HIGH_ELEM(s, i, 2); \
mid_buf1[0] = \
GET_VECTOR_LOW_ELEM(s, i, 0) + m1addm2 + m3addm4 + m5addm6; \
mid_buf1[1] = m1subm2 + 2.f * m3subm4 + 0.5f * m5subm6; \
mid_buf1[2] = m1addm2 + 4.f * m3addm4 + 0.25f * m5addm6; \
mid_buf1[3] = m1subm2 + 8.f * m3subm4 + 0.125f * m5subm6; \
mid_buf1[4] = m1addm2 + 16.f * m3addm4 + 0.0625f * m5addm6; \
mid_buf1[5] = m1subm2 + 32.f * m3subm4 + 0.03125f * m5subm6 + \
GET_VECTOR_HIGH_ELEM(s, i, 3); \
mid_buf1 += 6; \
} while (0);

mid_buf1 = transform_mid_buf;
UNROLL_CALL_NOWRAPPER(6, cb);
mid_buf1 = transform_mid_buf;
#undef cb

if (oh_start + 6 <= OH && ow_start + 6 <= OW) {
float32x4_t bias0;
float32x2_t bias1;
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
bias0 = vdupq_n_f32(bias[oc]);
bias1 = vdup_n_f32(bias[oc]);
}
rep(i, 6) {
size_t oh = oh_start + i;
float32x4_t item0 = vld1q_f32(mid_buf1);
float32x2_t item1 = vld1_f32(mid_buf1 + 4);

if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
item0 = vaddq_f32(item0, bias0);
item1 = vadd_f32(item1, bias1);
} else if (bmode == BiasMode::BIAS) {
bias0 = vld1q_f32(bias + oc * OH * OW + oh * OW + ow_start);
bias1 = vld1_f32(bias + oc * OH * OW + oh * OW + ow_start +
4);
item0 = vaddq_f32(item0, bias0);
item1 = vadd_f32(item1, bias1);
}
item0 = op(item0);
item1 = vset_lane_f32(op(vget_lane_f32(item1, 0)), item1, 0);
item1 = vset_lane_f32(op(vget_lane_f32(item1, 1)), item1, 1);
vst1q_f32(output + oc * OH * OW + oh * OW + ow_start, item0);
vst1_f32(output + oc * OH * OW + oh * OW + ow_start + 4, item1);

mid_buf1 += 6;
}
} else {
for (size_t oho = 0; oho < 6 && oh_start + oho < OH; ++oho) {
for (size_t owo = 0; owo < 6 && ow_start + owo < OW; ++owo) {
size_t oh = oh_start + oho;
size_t ow = ow_start + owo;
float res = mid_buf1[oho * 6 + owo];
if (bmode == BiasMode::BIAS) {
res += bias[oc * OH * OW + oh * OW + ow];
} else if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
res += bias[oc];
}
res = op(res);
output[oc * OH * OW + oh * OW + ow] = res;
}
}
}
}
};

#undef GET_VECTOR_HIGH_ELEM
#undef GET_VECTOR_LOW_ELEM
#undef OUTPUT_TRANSFORM

} // namespace

namespace megdnn {
namespace arm_common {
namespace winograd {

MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_6x3_1x1_f)

void winograd_6x3_1x1_f::filter(const float* filter,
float* filter_transform_buf,
float* transform_mid_buf, size_t OC, size_t IC,
size_t oc_start, size_t oc_end) {
FilterTransform6X3<param::MatrixMul::Format::DEFAULT>::transform(
filter, filter_transform_buf, transform_mid_buf, OC, IC, oc_start,
oc_end);
}

void winograd_6x3_1x1_f::input(const float* input, float* input_transform_buf,
float* transform_mid_buf, size_t IH, size_t IW,
size_t IC, size_t PH, size_t PW,
size_t unit_start_idx, size_t nr_units_in_tile) {
constexpr int alpha = 3 + 6 - 1;

// OW = IW + 2 * PW - KERNEL_SIZE + 1
auto units_w = div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE);
rep(ic, IC) {
rep(unit_idx, nr_units_in_tile) {
size_t index = unit_start_idx + unit_idx;
size_t nh = index / units_w;
size_t nw = index % units_w;
int ih_start = nh * OUTPUT_BLOCK_SIZE - PH;
int iw_start = nw * OUTPUT_BLOCK_SIZE - PW;
if (ih_start >= 0 && ih_start + alpha <= static_cast<int>(IH) &&
iw_start >= 0 && iw_start + alpha <= static_cast<int>(IW)) {
InputTransform6X3::transform<true>(
input, input_transform_buf, transform_mid_buf, ih_start,
iw_start, ic, IH, IW, IC, unit_idx, nr_units_in_tile);

} else {
InputTransform6X3::transform<false>(
input, input_transform_buf, transform_mid_buf, ih_start,
iw_start, ic, IH, IW, IC, unit_idx, nr_units_in_tile);
}
}
}
}

void winograd_6x3_1x1_f::output(const float* output_transform_buf,
const float* bias, float* output,
float* transform_mid_buf, BiasMode bmode,
NonlineMode nonline_mode, size_t OH, size_t OW,
size_t oc_start, size_t oc_end,
size_t unit_start_idx,
size_t nr_units_in_tile) {
#define cb(_bmode, _nonline_op, ...) \
OutputTransform6X3<_bmode MEGDNN_COMMA _nonline_op>::transform(__VA_ARGS__);

auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE);

for (size_t oc = oc_start; oc < oc_end; oc++) {
size_t oc_index = oc - oc_start;
rep(unit_idx, nr_units_in_tile) {
size_t index = unit_start_idx + unit_idx;
auto nh = index / units_w;
auto nw = index % units_w;
size_t oh_start = nh * OUTPUT_BLOCK_SIZE;
size_t ow_start = nw * OUTPUT_BLOCK_SIZE;
DISPATCH_CONV_WINOGRAD_BIAS(
megdnn_arm_common_winograd_fp32_F63, cb, float, float, bmode,
nonline_mode, output_transform_buf, bias, output, transform_mid_buf,
oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx,
nr_units_in_tile, src_dtype, dst_dtype);
}
}
#undef cb
}

} // namespace winograd
} // namespace arm_common
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 351
- 0
dnn/src/arm_common/conv_bias/fp32/strategy_6x3_4x4.cpp View File

@@ -0,0 +1,351 @@
/**
* \file dnn/src/arm_common/conv_bias/fp32/strategy_6x3_4x4.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/arm_common/conv_bias/fp32/filter_transform.h"
#include "src/arm_common/conv_bias/fp32/helper.h"
#include "src/arm_common/conv_bias/fp32/strategy.h"
#include "src/arm_common/elemwise_helper/op_unary.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/utils.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"
#include "src/common/winograd/winograd_helper.h"
#include "src/fallback/conv_bias/winograd/winograd.h"

#include "midout.h"
MIDOUT_DECL(megdnn_arm_common_winograd_fp32_F63_4x4)

using namespace megdnn;
using namespace arm_common;

namespace {

struct InputTransform6X3 {
template <bool inner>
static void prepare(const float* input, float* patch, float* patchT,
int ih_start, int iw_start, size_t IH, size_t IW,
size_t ic, size_t IC) {
constexpr size_t alpha = 6 + 3 - 1;
if (!(inner && ic + 4 < IC)) {
memset(patch, 0, sizeof(float) * 4 * alpha * alpha);
}
if (inner) {
const float* input_ptr =
input + ic * IH * IW + ih_start * IW + iw_start;
for (size_t ico = 0; ico < 4; ++ico) {
if (ic + ico < IC) {
#define cb(i) \
auto v##i##0 = vld1q_f32(input_ptr + IW * i); \
auto v##i##1 = vld1q_f32(input_ptr + IW * i + 4);

UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb

#define cb(i) \
vst1q_f32(patch + ico * 8 * alpha + i * 8, v##i##0); \
vst1q_f32(patch + ico * 8 * alpha + i * 8 + 4, v##i##1);

UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
input_ptr += IH * IW;
}
}
} else {
int ih0_act = std::max<int>(ih_start, 0),
ih1_act = std::min<int>(ih_start + alpha, IH),
iw0_act = std::max<int>(iw_start, 0),
iw1_act = std::min<int>(iw_start + alpha, IW);
// partial copy
for (size_t ico = 0; ico < 4; ++ico) {
if (ic + ico < IC) {
for (int ih = ih0_act; ih < ih1_act; ++ih) {
for (int iw = iw0_act; iw < iw1_act; ++iw) {
size_t iho = ih - ih_start, iwo = iw - iw_start;
patch[ico * alpha * 8 + iho * 8 + iwo] =
input[(ic + ico) * IH * IW + ih * IW + iw];
}
}
}
}
}

#define cb(i) \
transpose_4x4(patch + 8 * i + 0, patchT + 8 * i * 4, 64, 4); \
transpose_4x4(patch + 8 * i + 4, patchT + 8 * i * 4 + 4 * 4, 64, 4);
UNROLL_CALL_NOWRAPPER(8, cb)
#undef cb
}

static void transform(const float* patchT, float* input_transform_buf,
size_t unit_idx, size_t nr_units_in_tile, size_t ic,
size_t IC) {
constexpr size_t alpha = 6 + 3 - 1;
// BT * d * B
#define cb(m, n) \
Vector<float, 4> d##m##n = \
Vector<float, 4>::load(patchT + m * 8 * 4 + n * 4);

UNROLL_CALL_NOWRAPPER_D2(8, 8, cb);
#undef cb

//! B
//! 1 0 0 0 0 0 0 0
//! 0 1 -1 0.5 -0.5 2 -2 -1
//! -5.25 1 1 0.25 0.25 4 4 0
//! 0 -4.25 4.25 -2.5 2.5 -2.5 2.5 5.25
//! 5.25 -4.25 -4.25 -1.25 -1.25 -5 -5 0
//! 0 1 -1 2 -2 0.5 -0.5 -5.25
//! -1 1 1 1 1 1 1 0
//! 0 0 0 0 0 0 0 1
#define cb(m) \
auto t0##m = d0##m + (d4##m - d2##m) * 5.25f - d6##m; \
auto t1##m = d1##m + d2##m + d5##m + d6##m - (d3##m + d4##m) * 4.25f; \
auto t2##m = d2##m + d6##m - (d1##m + d5##m) + (d3##m - d4##m) * 4.25f; \
auto t3##m = d1##m * 0.5f + d2##m * 0.25f - d3##m * 2.5f - d4##m * 1.25f + \
d5##m * 2.f + d6##m; \
auto t4##m = d1##m * (-0.5f) + d2##m * 0.25f + d3##m * 2.5f - \
d4##m * 1.25f - d5##m * 2.f + d6##m; \
auto t5##m = d1##m * 2.f + d2##m * 4.f - d3##m * 2.5f - d4##m * 5.f + \
d5##m * 0.5f + d6##m; \
auto t6##m = d1##m * (-2.f) + d2##m * 4.f + d3##m * 2.5f - d4##m * 5.f - \
d5##m * 0.5f + d6##m; \
auto t7##m = (d7##m - d1##m) + (d3##m - d5##m) * 5.25f;

UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb

#define cb(m) \
d##m##0 = t##m##0 + (t##m##4 - t##m##2) * 5.25f - t##m##6; \
d##m##1 = t##m##1 + t##m##2 + t##m##5 + t##m##6 - \
(t##m##3 + t##m##4) * 4.25f; \
d##m##2 = t##m##2 + t##m##6 - (t##m##1 + t##m##5) + \
(t##m##3 - t##m##4) * 4.25f; \
d##m##3 = t##m##1 * 0.5f + t##m##2 * 0.25f - t##m##3 * 2.5f - \
t##m##4 * 1.25f + t##m##5 * 2.f + t##m##6; \
d##m##4 = t##m##1 * (-0.5f) + t##m##2 * 0.25f + t##m##3 * 2.5f - \
t##m##4 * 1.25f - t##m##5 * 2.f + t##m##6; \
d##m##5 = t##m##1 * 2.f + t##m##2 * 4.f - t##m##3 * 2.5f - t##m##4 * 5.f + \
t##m##5 * 0.5f + t##m##6; \
d##m##6 = t##m##1 * (-2.f) + t##m##2 * 4.f + t##m##3 * 2.5f - \
t##m##4 * 5.f - t##m##5 * 0.5f + t##m##6; \
d##m##7 = (t##m##7 - t##m##1) + (t##m##3 - t##m##5) * 5.25f;

UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb

size_t ICB = IC / 4;
size_t icb = ic / 4;
#define cb(m, n) \
d##m##n.save(input_transform_buf + \
(m * alpha + n) * ICB * nr_units_in_tile * 4 + \
icb * nr_units_in_tile * 4 + unit_idx * 4);
UNROLL_CALL_NOWRAPPER_D2(8, 8, cb)
#undef cb
}
};

template <BiasMode bmode, typename Op>
struct OutputTransform6X3 {
static void transform(const float* output_transform_buf, const float* bias,
float* output, float* transform_mid_buf,
size_t oh_start, size_t ow_start, size_t OH,
size_t OW, size_t oc_start, size_t oc_end,
size_t oc_index, size_t unit_idx,
size_t nr_units_in_tile, const DType& src_dtype,
const DType& dst_dtype) {
Op op(src_dtype, dst_dtype);
//! AT * m * A
constexpr size_t alpha = 6 + 3 - 1;

size_t oc = oc_start + oc_index;
size_t OCB = (oc_end - oc_start) / 4;
size_t ocb = oc_index / 4;

#define cb(m, n) \
auto v##m##n = Vector<float, 4>::load( \
output_transform_buf + \
(m * alpha + n) * OCB * nr_units_in_tile * 4 + \
ocb * nr_units_in_tile * 4 + unit_idx * 4);
UNROLL_CALL_NOWRAPPER_D2(8, 8, cb);
#undef cb

/**
* A
*
* 1 0 0 0 0 0
* 1 1 1 1 1 1
* 1 -1 1 -1 1 -1
* 1 2 4 8 16 32
* 1 -2 4 -8 16 -32
* 1 0.5 0.25 0.125 0.0625 0.03125
* 1 -0.5 0.25 -0.125 0.0625 -0.03125
* 0 0.0 0 0 0 1
*/

Vector<float, 4> v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, v5subv6;
#define cb(m) \
v1addv2 = v1##m + v2##m; \
v1subv2 = v1##m - v2##m; \
v3addv4 = v3##m + v4##m; \
v3subv4 = v3##m - v4##m; \
v5addv6 = v5##m + v6##m; \
v5subv6 = v5##m - v6##m; \
auto t0##m = v0##m + v1addv2 + v3addv4 + v5addv6; \
auto t1##m = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f; \
auto t2##m = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f; \
auto t3##m = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f; \
auto t4##m = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; \
auto t5##m = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + v7##m;

UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb

#define cb(m) \
v1addv2 = t##m##1 + t##m##2; \
v1subv2 = t##m##1 - t##m##2; \
v3addv4 = t##m##3 + t##m##4; \
v3subv4 = t##m##3 - t##m##4; \
v5addv6 = t##m##5 + t##m##6; \
v5subv6 = t##m##5 - t##m##6; \
v##m##0 = t##m##0 + v1addv2 + v3addv4 + v5addv6; \
v##m##1 = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f; \
v##m##2 = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f; \
v##m##3 = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f; \
v##m##4 = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; \
v##m##5 = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + t##m##7;

UNROLL_CALL_NOWRAPPER(6, cb);
#undef cb

Vector<float, 4> vbias;
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
vbias = Vector<float, 4>::load(bias + oc);

#define cb(m, n) v##m##n += vbias;
UNROLL_CALL_RAW_D2(6, 6, cb);
#undef cb
}
if (bmode != BiasMode::BIAS) {
#define cb(m, n) v##m##n = op(CONCAT(v##m, n).value);
UNROLL_CALL_RAW_D2(6, 6, cb);
#undef cb
}

#define cb(m, n) CONCAT(v##m, n).save(transform_mid_buf + (m * 6 + n) * 4);
UNROLL_CALL_RAW_D2(6, 6, cb);
#undef cb

for (size_t oco = 0; oco < 4 && oc + oco < oc_end; ++oco) {
for (size_t oho = 0; oho < 6 && oh_start + oho < OH; ++oho) {
for (size_t owo = 0; owo < 6 && ow_start + owo < OW; ++owo) {
size_t oh = oh_start + oho;
size_t ow = ow_start + owo;
float res = transform_mid_buf[oho * 6 * 4 + owo * 4 + oco];
if (bmode == BiasMode::BIAS) {
res += bias[(oc + oco) * OH * OW + oh * OW + ow];
res = op(res);
}
output[(oc + oco) * OH * OW + oh * OW + ow] = res;
}
}
}
}
};
} // namespace

namespace megdnn {
namespace arm_common {
namespace winograd {

MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_6x3_4x4_f)

void winograd_6x3_4x4_f::filter(const float* filter,
float* filter_transform_buf,
float* transform_mid_buf, size_t OC, size_t IC,
size_t oc_start, size_t oc_end) {
FilterTransform6X3<param::MatrixMul::Format::MK4>::transform(
filter, filter_transform_buf, transform_mid_buf, OC, IC, oc_start,
oc_end);
}

void winograd_6x3_4x4_f::input(const float* input, float* input_transform_buf,
float* transform_mid_buf, size_t IH, size_t IW,
size_t IC, size_t PH, size_t PW,
size_t unit_start_idx, size_t nr_units_in_tile) {
megdnn_assert(IC % 4 == 0);
constexpr int alpha = 3 + 6 - 1;

// OW = IW + 2 * PW - KERNEL_SIZE + 1
auto units_w = div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE);
float* patch = transform_mid_buf;
float* patchT = transform_mid_buf + 4 * alpha * alpha;

for (size_t ic = 0; ic < IC; ic += 4) {
rep(unit_idx, nr_units_in_tile) {
size_t index = unit_start_idx + unit_idx;
size_t nh = index / units_w;
size_t nw = index % units_w;
int ih_start = nh * OUTPUT_BLOCK_SIZE - PH;
int iw_start = nw * OUTPUT_BLOCK_SIZE - PW;
if (ih_start >= 0 && ih_start + alpha <= static_cast<int>(IH) &&
iw_start >= 0 && iw_start + alpha <= static_cast<int>(IW)) {
InputTransform6X3::prepare<true>(input, patch, patchT, ih_start,
iw_start, IH, IW, ic, IC);
InputTransform6X3::transform(patchT, input_transform_buf,
unit_idx, nr_units_in_tile, ic,
IC);

} else {
InputTransform6X3::prepare<false>(input, patch, patchT,
ih_start, iw_start, IH, IW,
ic, IC);
InputTransform6X3::transform(patchT, input_transform_buf,
unit_idx, nr_units_in_tile, ic,
IC);
}
}
}
}

void winograd_6x3_4x4_f::output(const float* output_transform_buf,
const float* bias, float* output,
float* transform_mid_buf, BiasMode bmode,
NonlineMode nonline_mode, size_t OH, size_t OW,
size_t oc_start, size_t oc_end, size_t unit_start_idx,
size_t nr_units_in_tile) {
#define cb(_bmode, _nonline_op, ...) \
OutputTransform6X3<_bmode MEGDNN_COMMA _nonline_op>::transform(__VA_ARGS__);

auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE);

for (size_t oc = oc_start; oc < oc_end; oc += 4) {
size_t oc_index = oc - oc_start;
rep(unit_idx, nr_units_in_tile) {
size_t index = unit_start_idx + unit_idx;
auto nh = index / units_w;
auto nw = index % units_w;
size_t oh_start = nh * OUTPUT_BLOCK_SIZE;
size_t ow_start = nw * OUTPUT_BLOCK_SIZE;
DISPATCH_CONV_WINOGRAD_BIAS(
megdnn_arm_common_winograd_fp32_F63_4x4, cb, float, float, bmode,
nonline_mode, output_transform_buf, bias, output, transform_mid_buf,
oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx,
nr_units_in_tile, src_dtype, dst_dtype);
}
}
#undef cb
}

} // namespace winograd
} // namespace arm_common
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 82
- 0
dnn/src/arm_common/conv_bias/img2col_helper.h View File

@@ -0,0 +1,82 @@
/**
* \file dnn/src/arm_common/conv_bias/img2col_helper.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include <cstddef>
#include "src/common/utils.h"

namespace {

template <bool is_xcorr, typename dtype>
void img2col_stride(const dtype* __restrict src,
dtype* __restrict dst, const int OC, const int OH,
const int OW, const int IC, const int IH, const int IW,
const int FH, const int FW, const int SH, const int SW) {
(void)OC;
size_t i = 0;
rep(ic, IC) {
rep(fh, FH) {
rep(fw, FW) {
rep(oh, OH) {
rep(ow, OW) {
int fh2, fw2;
if (is_xcorr) {
fh2 = fh;
fw2 = fw;
} else {
fh2 = FH - fh - 1;
fw2 = FW - fw - 1;
}
dst[i++] = src[ic * IH * IW + (oh * SH + fh2) * IW +
(ow * SW + fw2)];
}
}
}
}
}
}

template <bool is_xcorr, typename dtype>
void img2col(const dtype* src, dtype* dst, size_t /* OC */, size_t OH,
size_t OW, size_t IC, size_t IH, size_t IW, size_t FH, size_t FW) {
size_t offset = (4 - OW % 4) % 4;
size_t i = 0;
rep(ic, IC) {
rep(fh, FH) {
rep(fw, FW) {
rep(oh, OH) {
size_t ow = 0;
for (; ow < OW; ow += 4) {
size_t fh2, fw2;
if (is_xcorr) {
fh2 = fh;
fw2 = fw;
} else {
fh2 = FH - fh - 1;
fw2 = FW - fw - 1;
}
dst[i++] = src[ic * IH * IW + (oh + fh2) * IW +
(ow + fw2) + 0];
dst[i++] = src[ic * IH * IW + (oh + fh2) * IW +
(ow + fw2) + 1];
dst[i++] = src[ic * IH * IW + (oh + fh2) * IW +
(ow + fw2) + 2];
dst[i++] = src[ic * IH * IW + (oh + fh2) * IW +
(ow + fw2) + 3];
}
i -= offset;
}
}
}
}
}

} // anonymous namespace

// vim: syntax=cpp.doxygen

Some files were not shown because too many files changed in this diff

Loading…
Cancel
Save