Browse Source

fix(dnn): fix can not inline small function with GCC compiler

GitOrigin-RevId: a23605c9e2
release-0.6
Megvii Engine Team 5 years ago
parent
commit
9e904f683b
13 changed files with 298 additions and 187 deletions
  1. +3
    -0
      CMakeLists.txt
  2. +7
    -0
      dnn/include/megdnn/arch.h
  3. +5
    -5
      dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h
  4. +7
    -6
      dnn/src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.cpp
  5. +11
    -9
      dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.cpp
  6. +19
    -15
      dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_kern.h
  7. +19
    -15
      dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h
  8. +95
    -93
      dnn/src/arm_common/conv_bias/intrinsic_helper.h
  9. +8
    -7
      dnn/src/arm_common/intrinsic_helper.h
  10. +22
    -12
      dnn/src/arm_common/neon_struct.h
  11. +81
    -14
      dnn/src/arm_common/simd_macro/marm_neon.h
  12. +15
    -11
      dnn/test/arm_common/conv_bias_multi_thread.cpp
  13. +6
    -0
      toolchains/aarch64-none-linux-gnu.toolchain.cmake

+ 3
- 0
CMakeLists.txt View File

@@ -128,8 +128,11 @@ else()
set(CMAKE_CXX_FLAGS_DEBUG "-O0 -g") set(CMAKE_CXX_FLAGS_DEBUG "-O0 -g")
if(ANDROID) if(ANDROID)
set(CMAKE_CXX_FLAGS_RELEASE "-Ofast -DNDEBUG") set(CMAKE_CXX_FLAGS_RELEASE "-Ofast -DNDEBUG")
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-Ofast -DNDEBUG -g")
else() else()
set(CMAKE_CXX_FLAGS_RELEASE "-O3 -DNDEBUG") set(CMAKE_CXX_FLAGS_RELEASE "-O3 -DNDEBUG")
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-O3 -DNDEBUG -g")
endif() endif()
endif() endif()




+ 7
- 0
dnn/include/megdnn/arch.h View File

@@ -29,6 +29,13 @@
#define megdnn_likely(v) __builtin_expect(bool(v), 1) #define megdnn_likely(v) __builtin_expect(bool(v), 1)
#define megdnn_unlikely(v) __builtin_expect(bool(v), 0) #define megdnn_unlikely(v) __builtin_expect(bool(v), 0)


#if !defined(__clang__) && MEGDNN_ARMV7 && !defined(NDEBUG)
//! Thumb2 limit code length
#define MEGDNN_ALWAYS_INLINE
#else
#define MEGDNN_ALWAYS_INLINE inline __attribute__((__always_inline__))
#endif

#define MEGDNN_DEPRECATED __attribute__((deprecated)) #define MEGDNN_DEPRECATED __attribute__((deprecated))
#define MEGDNN_PACKED __attribute__((packed)) #define MEGDNN_PACKED __attribute__((packed))
#define MEGDNN_CONSTEXPR constexpr #define MEGDNN_CONSTEXPR constexpr


+ 5
- 5
dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h View File

@@ -10,6 +10,7 @@
* implied. * implied.
*/ */
#pragma once #pragma once
#include "megdnn/arch.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h" #include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/conv_bias/opr_impl.h" #include "src/arm_common/conv_bias/opr_impl.h"
#include "src/arm_common/elemwise_op.h" #include "src/arm_common/elemwise_op.h"
@@ -17,7 +18,6 @@
#include "src/common/unroll_macro.h" #include "src/common/unroll_macro.h"
#include "src/common/utils.h" #include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h" #include "src/fallback/conv_bias/common.h"

namespace megdnn { namespace megdnn {
namespace arm_common { namespace arm_common {
namespace { namespace {
@@ -32,13 +32,13 @@ namespace {
template <int src_idx, int weight_idx, int c_dim, typename Func, int stride, template <int src_idx, int weight_idx, int c_dim, typename Func, int stride,
typename T, typename T2, typename T3> typename T, typename T2, typename T3>
struct ShiftCalHelper { struct ShiftCalHelper {
static void impl(T& c, T2& src, T3& weight);
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight);
}; };


template <int src_idx, int weight_idx, typename Func, int stride, typename T, template <int src_idx, int weight_idx, typename Func, int stride, typename T,
typename T2, typename T3> typename T2, typename T3>
struct ShiftCalHelper<src_idx, weight_idx, 2, Func, stride, T, T2, T3> { struct ShiftCalHelper<src_idx, weight_idx, 2, Func, stride, T, T2, T3> {
static void impl(T& c, T2& src, T3& weight) {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) {
#define cb(step) \ #define cb(step) \
c[0][step] = Func::template impl<(step * stride + src_idx) % 4>( \ c[0][step] = Func::template impl<(step * stride + src_idx) % 4>( \
c[0][step], weight[0][weight_idx], \ c[0][step], weight[0][weight_idx], \
@@ -54,7 +54,7 @@ struct ShiftCalHelper<src_idx, weight_idx, 2, Func, stride, T, T2, T3> {
template <int src_idx, int weight_idx, typename Func, int stride, typename T, template <int src_idx, int weight_idx, typename Func, int stride, typename T,
typename T2, typename T3> typename T2, typename T3>
struct ShiftCalHelper<src_idx, weight_idx, 1, Func, stride, T, T2, T3> { struct ShiftCalHelper<src_idx, weight_idx, 1, Func, stride, T, T2, T3> {
static void impl(T& c, T2& src, T3& weight) {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) {
#define cb(step) \ #define cb(step) \
c[0][step] = Func::template impl<(step * stride + src_idx) % 4>( \ c[0][step] = Func::template impl<(step * stride + src_idx) % 4>( \
c[0][step], weight[0][weight_idx], \ c[0][step], weight[0][weight_idx], \
@@ -67,7 +67,7 @@ struct ShiftCalHelper<src_idx, weight_idx, 1, Func, stride, T, T2, T3> {


template <int src_idx, int weight_idx, int c_dim, typename FUNC, int stride, template <int src_idx, int weight_idx, int c_dim, typename FUNC, int stride,
typename T, typename T2, typename T3> typename T, typename T2, typename T3>
inline void cal_helper(T& c, T2& src, T3& weight) {
MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight) {
ShiftCalHelper<src_idx, weight_idx, c_dim, FUNC, stride, T, T2, T3>::impl( ShiftCalHelper<src_idx, weight_idx, c_dim, FUNC, stride, T, T2, T3>::impl(
c, src, weight); c, src, weight);
}; };


+ 7
- 6
dnn/src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.cpp View File

@@ -11,6 +11,7 @@
* implied. * implied.
*/ */


#include "megdnn/arch.h"
#include "src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.h" #include "src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h" #include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/elemwise_op.h" #include "src/arm_common/elemwise_op.h"
@@ -26,13 +27,13 @@ namespace {
template <int src_idx, int weight_idx, int c_dim, typename Func, int ow_block, template <int src_idx, int weight_idx, int c_dim, typename Func, int ow_block,
typename T, typename T2, typename T3, typename T4> typename T, typename T2, typename T3, typename T4>
struct ShiftCalHelper { struct ShiftCalHelper {
static void impl(T& c, T2& src, T3& weight);
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight);
}; };


template <int src_idx, int weight_idx, typename Func, typename T, typename T2, template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
typename T3, typename T4> typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 8, T, T2, T3, T4> { struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 8, T, T2, T3, T4> {
static void impl(T& c, T2& src, T3& weight) {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) {
#define cb(step, lane) \ #define cb(step, lane) \
c[0][step] = Func::template impl<lane>(c[0][step], weight[0][lane], \ c[0][step] = Func::template impl<lane>(c[0][step], weight[0][lane], \
src[(step + src_idx) % 8]); \ src[(step + src_idx) % 8]); \
@@ -49,7 +50,7 @@ struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 8, T, T2, T3, T4> {
template <int src_idx, int weight_idx, typename Func, typename T, typename T2, template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
typename T3, typename T4> typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 4, T, T2, T3, T4> { struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 4, T, T2, T3, T4> {
static void impl(T& c, T2& src, T3& weight) {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) {
#define cb(step, lane) \ #define cb(step, lane) \
c[0][step] = Func::template impl<lane>(c[0][step], weight[0][lane], \ c[0][step] = Func::template impl<lane>(c[0][step], weight[0][lane], \
src[(step + src_idx) % 4]); \ src[(step + src_idx) % 4]); \
@@ -66,7 +67,7 @@ struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 4, T, T2, T3, T4> {
template <int src_idx, int weight_idx, typename Func, typename T, typename T2, template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
typename T3, typename T4> typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, T, T2, T3, T4> { struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, T, T2, T3, T4> {
static void impl(T& c, T2& src, T3& weight) {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) {
#define cb(step, lane) \ #define cb(step, lane) \
c[0][step] = Func::template impl<lane>(c[0][step], weight[0][lane], \ c[0][step] = Func::template impl<lane>(c[0][step], weight[0][lane], \
src[(step + src_idx) % 8]); src[(step + src_idx) % 8]);
@@ -81,7 +82,7 @@ struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, T, T2, T3, T4> {
template <int src_idx, int weight_idx, typename Func, typename T, typename T2, template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
typename T3, typename T4> typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 4, T, T2, T3, T4> { struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 4, T, T2, T3, T4> {
static void impl(T& c, T2& src, T3& weight) {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) {
#define cb(step, lane) \ #define cb(step, lane) \
c[0][step] = Func::template impl<lane>(c[0][step], weight[0][lane], \ c[0][step] = Func::template impl<lane>(c[0][step], weight[0][lane], \
src[(step + src_idx) % 4]); src[(step + src_idx) % 4]);
@@ -96,7 +97,7 @@ struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 4, T, T2, T3, T4> {


template <int src_idx, int weight_idx, int c_dim, typename FUNC, int ow_block, template <int src_idx, int weight_idx, int c_dim, typename FUNC, int ow_block,
typename T, typename T2, typename T3> typename T, typename T2, typename T3>
inline void cal_helper(T& c, T2& src, T3& weight) {
MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight) {
ShiftCalHelper<src_idx, weight_idx, c_dim, FUNC, ow_block, T, T2, T3, ShiftCalHelper<src_idx, weight_idx, c_dim, FUNC, ow_block, T, T2, T3,
int>::impl(c, src, weight); int>::impl(c, src, weight);
}; };


+ 11
- 9
dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.cpp View File

@@ -11,6 +11,7 @@
* implied. * implied.
*/ */


#include "megdnn/arch.h"
#include "src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.h" #include "src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h" #include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/elemwise_op.h" #include "src/arm_common/elemwise_op.h"
@@ -26,13 +27,13 @@ namespace {
template <int src_idx, int weight_idx, int c_dim, typename Func, int ow_block, template <int src_idx, int weight_idx, int c_dim, typename Func, int ow_block,
typename T, typename T2, typename T3, typename T4> typename T, typename T2, typename T3, typename T4>
struct ShiftCalHelper { struct ShiftCalHelper {
static void impl(T& c, T2& src, T3& weight);
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight);
}; };


template <int src_idx, int weight_idx, typename Func, typename T, typename T2, template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
typename T3, typename T4> typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 8, T, T2, T3, T4> { struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 8, T, T2, T3, T4> {
static void impl(T& c, T2& src, T3& weight) {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) {
#define cb(step, lane) \ #define cb(step, lane) \
c[0][step] = Func::template impl<lane>(c[0][step], weight[0][lane], \ c[0][step] = Func::template impl<lane>(c[0][step], weight[0][lane], \
src[(step + src_idx) % 8]); \ src[(step + src_idx) % 8]); \
@@ -49,7 +50,7 @@ struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 8, T, T2, T3, T4> {
template <int src_idx, int weight_idx, typename Func, typename T, typename T2, template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
typename T3, typename T4> typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 4, T, T2, T3, T4> { struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 4, T, T2, T3, T4> {
static void impl(T& c, T2& src, T3& weight) {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) {
#define cb(step, lane) \ #define cb(step, lane) \
c[0][step] = Func::template impl<lane>(c[0][step], weight[0][lane], \ c[0][step] = Func::template impl<lane>(c[0][step], weight[0][lane], \
src[(step + src_idx) % 4]); \ src[(step + src_idx) % 4]); \
@@ -66,7 +67,7 @@ struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 4, T, T2, T3, T4> {
template <int src_idx, int weight_idx, typename Func, typename T, typename T2, template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
typename T3, typename T4> typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, T, T2, T3, T4> { struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, T, T2, T3, T4> {
static void impl(T& c, T2& src, T3& weight) {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) {
#define cb(step, lane) \ #define cb(step, lane) \
c[0][step] = Func::template impl<lane>(c[0][step], weight[0][lane], \ c[0][step] = Func::template impl<lane>(c[0][step], weight[0][lane], \
src[(step + src_idx) % 8]); src[(step + src_idx) % 8]);
@@ -81,7 +82,7 @@ struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, T, T2, T3, T4> {
template <int src_idx, int weight_idx, typename Func, typename T, typename T2, template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
typename T3, typename T4> typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 4, T, T2, T3, T4> { struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 4, T, T2, T3, T4> {
static void impl(T& c, T2& src, T3& weight) {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) {
#define cb(step, lane) \ #define cb(step, lane) \
c[0][step] = Func::template impl<lane>(c[0][step], weight[0][lane], \ c[0][step] = Func::template impl<lane>(c[0][step], weight[0][lane], \
src[(step + src_idx) % 4]); src[(step + src_idx) % 4]);
@@ -96,7 +97,7 @@ struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 4, T, T2, T3, T4> {


template <int src_idx, int weight_idx, int c_dim, typename FUNC, int ow_block, template <int src_idx, int weight_idx, int c_dim, typename FUNC, int ow_block,
typename T, typename T2, typename T3> typename T, typename T2, typename T3>
inline void cal_helper(T& c, T2& src, T3& weight) {
MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight) {
ShiftCalHelper<src_idx, weight_idx, c_dim, FUNC, ow_block, T, T2, T3, ShiftCalHelper<src_idx, weight_idx, c_dim, FUNC, ow_block, T, T2, T3,
int>::impl(c, src, weight); int>::impl(c, src, weight);
}; };
@@ -462,9 +463,10 @@ inline void odd_even_split_iw8_even(float* sptr_base, const float* sptr,
vst1q_f32(sptr_base + odd_offset + 2 * ic_step, temp[5]); vst1q_f32(sptr_base + odd_offset + 2 * ic_step, temp[5]);
vst1q_f32(sptr_base + odd_offset + 3 * ic_step, temp[7]); vst1q_f32(sptr_base + odd_offset + 3 * ic_step, temp[7]);
} }
void odd_even_split_iw8_odd(float* sptr_base, const float* sptr,
const int odd_start, const int src_idx,
const int iw_idx) {

inline void odd_even_split_iw8_odd(float* sptr_base, const float* sptr,
const int odd_start, const int src_idx,
const int iw_idx) {
constexpr int ic_step = 4; constexpr int ic_step = 4;
const int src_offset = src_idx * ic_step; const int src_offset = src_idx * ic_step;
const int even_offset = (iw_idx + 1) / 2 * ic_step; const int even_offset = (iw_idx + 1) / 2 * ic_step;


+ 19
- 15
dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_kern.h View File

@@ -5,11 +5,13 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ */
#pragma once
#ifdef __ARM_FEATURE_DOTPROD #ifdef __ARM_FEATURE_DOTPROD


#include "megdnn/arch.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h" #include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/elemwise_op.h" #include "src/arm_common/elemwise_op.h"
#include "src/arm_common/intrinsic_helper.h" #include "src/arm_common/intrinsic_helper.h"
@@ -27,8 +29,8 @@ constexpr int filter_next_col =
IC_PACK_SIZE * OC_PACK_SIZE; //! [OC/4, IC/4, FH, FW, 4OC, 4IC] IC_PACK_SIZE * OC_PACK_SIZE; //! [OC/4, IC/4, FH, FW, 4OC, 4IC]


template <int row, BiasMode bias_mode> template <int row, BiasMode bias_mode>
inline void init_ocx_ow8(int32x4_t c[][8], const int32_t* bias_ptr,
int oc_step) {
MEGDNN_ALWAYS_INLINE void init_ocx_ow8(int32x4_t c[][8],
const int32_t* bias_ptr, int oc_step) {
static_assert(row == 1 || row == 2 || row == 3, "Invalid OC number."); static_assert(row == 1 || row == 2 || row == 3, "Invalid OC number.");
if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) {
#define BIAS_INIT(step, i) c[i][step] = vld1q_s32(bias_ptr + i * oc_step); #define BIAS_INIT(step, i) c[i][step] = vld1q_s32(bias_ptr + i * oc_step);
@@ -90,12 +92,13 @@ inline void init_ocx_ow8(int32x4_t c[][8], const int32_t* bias_ptr,


template <int row, int ow_remain, typename Op, typename T> template <int row, int ow_remain, typename Op, typename T>
struct StoreOCxOWx { struct StoreOCxOWx {
static void impl(int32x4_t res[][8], const Op& op, T* dst_ptr,
const int ld_dst_oc);
static MEGDNN_ALWAYS_INLINE void impl(int32x4_t res[][8], const Op& op,
T* dst_ptr, const int ld_dst_oc);
}; };


template <int ow_remain, typename Op, typename T> template <int ow_remain, typename Op, typename T>
struct StoreOCxOWx<1, ow_remain, Op, T> { struct StoreOCxOWx<1, ow_remain, Op, T> {

static void impl(int32x4_t res[][8], const Op& op, T* dst_ptr, static void impl(int32x4_t res[][8], const Op& op, T* dst_ptr,
const int ld_dst_oc) { const int ld_dst_oc) {
MEGDNN_MARK_USED_VAR(ld_dst_oc); MEGDNN_MARK_USED_VAR(ld_dst_oc);
@@ -128,8 +131,8 @@ struct StoreOCxOWx<1, ow_remain, Op, T> {


template <int ow_remain, typename Op, typename T> template <int ow_remain, typename Op, typename T>
struct StoreOCxOWx<2, ow_remain, Op, T> { struct StoreOCxOWx<2, ow_remain, Op, T> {
static void impl(int32x4_t res[][8], const Op& op, T* dst_ptr,
const int ld_dst_oc) {
static MEGDNN_ALWAYS_INLINE void impl(int32x4_t res[][8], const Op& op,
T* dst_ptr, const int ld_dst_oc) {
switch (ow_remain) { switch (ow_remain) {
case 8: case 8:
UNROLL_CALL_RAW(4, cb22); UNROLL_CALL_RAW(4, cb22);
@@ -159,8 +162,8 @@ struct StoreOCxOWx<2, ow_remain, Op, T> {


template <int ow_remain, typename Op, typename T> template <int ow_remain, typename Op, typename T>
struct StoreOCxOWx<3, ow_remain, Op, T> { struct StoreOCxOWx<3, ow_remain, Op, T> {
static void impl(int32x4_t res[][8], const Op& op, T* dst_ptr,
const int ld_dst_oc) {
static MEGDNN_ALWAYS_INLINE void impl(int32x4_t res[][8], const Op& op,
T* dst_ptr, const int ld_dst_oc) {
switch (ow_remain) { switch (ow_remain) {
case 8: case 8:
UNROLL_CALL_RAW(4, cb32); UNROLL_CALL_RAW(4, cb32);
@@ -196,15 +199,16 @@ struct StoreOCxOWx<3, ow_remain, Op, T> {
#undef cb32 #undef cb32


template <int row, int ow_remain, typename Op, typename T> template <int row, int ow_remain, typename Op, typename T>
inline void store_ocx_owx_remain_static(int32x4_t res[][8], const Op& op,
T* dst_ptr, const int ld_dst_oc) {
MEGDNN_ALWAYS_INLINE void store_ocx_owx_remain_static(int32x4_t res[][8],
const Op& op, T* dst_ptr,
const int ld_dst_oc) {
StoreOCxOWx<row, ow_remain, Op, T>::impl(res, op, dst_ptr, ld_dst_oc); StoreOCxOWx<row, ow_remain, Op, T>::impl(res, op, dst_ptr, ld_dst_oc);
} }


template <int res_row, int src_row, int src_start_idx, int weight_idx, template <int res_row, int src_row, int src_start_idx, int weight_idx,
typename FUNC, typename T, typename T2, typename T3> typename FUNC, typename T, typename T2, typename T3>
struct ShiftCalHelper { struct ShiftCalHelper {
static void impl(T& res, T2& src, T3& weight) {
static MEGDNN_ALWAYS_INLINE void impl(T& res, T2& src, T3& weight) {
#define cb(step) \ #define cb(step) \
res[res_row][step] = FUNC::template impl<((src_start_idx + step) % 4)>( \ res[res_row][step] = FUNC::template impl<((src_start_idx + step) % 4)>( \
res[res_row][step], weight[weight_idx], \ res[res_row][step], weight[weight_idx], \
@@ -216,7 +220,7 @@ struct ShiftCalHelper {


template <int res_row, int src_row, int src_start_idx, int weight_idx, template <int res_row, int src_row, int src_start_idx, int weight_idx,
typename FUNC, typename T, typename T2, typename T3> typename FUNC, typename T, typename T2, typename T3>
inline void cal_helper(T& res, T2& src, T3& weight) {
MEGDNN_ALWAYS_INLINE void cal_helper(T& res, T2& src, T3& weight) {
ShiftCalHelper<res_row, src_row, src_start_idx, weight_idx, FUNC, T, T2, ShiftCalHelper<res_row, src_row, src_start_idx, weight_idx, FUNC, T, T2,
T3>::impl(res, src, weight); T3>::impl(res, src, weight);
}; };
@@ -428,4 +432,4 @@ struct KernNeonSdotNCHW44<dst_type, 2, bias_mode, Op, ow_remain, filter_size,


#endif #endif


//vim: syntax=cpp.doxygen
// vim: syntax=cpp.doxygen

+ 19
- 15
dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h View File

@@ -10,6 +10,7 @@
* implied. * implied.
*/ */
#pragma once #pragma once
#include "megdnn/arch.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h" #include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/conv_bias/opr_impl.h" #include "src/arm_common/conv_bias/opr_impl.h"
#include "src/arm_common/elemwise_op.h" #include "src/arm_common/elemwise_op.h"
@@ -37,13 +38,13 @@ namespace {
template <int src_idx, int weight_idx, int c_dim, typename Func, int stride, template <int src_idx, int weight_idx, int c_dim, typename Func, int stride,
typename T, typename T2, typename T3, typename T4> typename T, typename T2, typename T3, typename T4>
struct ShiftCalHelper { struct ShiftCalHelper {
static void impl(T& c, T2& src, T3& weight, T4& temp);
static void impl(T& c, T2& src, T3& weight);
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight, T4& temp);
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight);
}; };
template <int src_idx, int weight_idx, typename Func, typename T, typename T2, template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
typename T3, typename T4> typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 2, T, T2, T3, T4> { struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 2, T, T2, T3, T4> {
static void impl(T& c, T2& src, T3& weight, T4& temp) {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight, T4& temp) {
c[0][0] = Func::impl(src[0 + src_idx], weight[0][weight_idx], c[0][0], c[0][0] = Func::impl(src[0 + src_idx], weight[0][weight_idx], c[0][0],
temp[0]); temp[0]);
c[1][0] = Func::impl(src[0 + src_idx], weight[1][weight_idx], c[1][0], c[1][0] = Func::impl(src[0 + src_idx], weight[1][weight_idx], c[1][0],
@@ -61,7 +62,7 @@ struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 2, T, T2, T3, T4> {
c[1][3] = Func::impl(src[3 + src_idx], weight[1][weight_idx], c[1][3], c[1][3] = Func::impl(src[3 + src_idx], weight[1][weight_idx], c[1][3],
temp[3]); temp[3]);
} }
static void impl(T& c, T2& src, T3& weight) {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) {
c[0][0] = Func::impl(src[0 + src_idx], weight[0][weight_idx], c[0][0]); c[0][0] = Func::impl(src[0 + src_idx], weight[0][weight_idx], c[0][0]);
c[1][0] = Func::impl(src[0 + src_idx], weight[1][weight_idx], c[1][0]); c[1][0] = Func::impl(src[0 + src_idx], weight[1][weight_idx], c[1][0]);
c[0][1] = Func::impl(src[1 + src_idx], weight[0][weight_idx], c[0][1]); c[0][1] = Func::impl(src[1 + src_idx], weight[0][weight_idx], c[0][1]);
@@ -75,7 +76,7 @@ struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 2, T, T2, T3, T4> {
template <int src_idx, int weight_idx, typename Func, typename T, typename T2, template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
typename T3, typename T4> typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 2, T, T2, T3, T4> { struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 2, T, T2, T3, T4> {
static void impl(T& c, T2& src, T3& weight, T4& temp) {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight, T4& temp) {
c[0][0] = Func::impl(src[0 + src_idx], weight[0][weight_idx], c[0][0], c[0][0] = Func::impl(src[0 + src_idx], weight[0][weight_idx], c[0][0],
temp[0]); temp[0]);
c[0][1] = Func::impl(src[1 + src_idx], weight[0][weight_idx], c[0][1], c[0][1] = Func::impl(src[1 + src_idx], weight[0][weight_idx], c[0][1],
@@ -85,7 +86,7 @@ struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 2, T, T2, T3, T4> {
c[0][3] = Func::impl(src[3 + src_idx], weight[0][weight_idx], c[0][3], c[0][3] = Func::impl(src[3 + src_idx], weight[0][weight_idx], c[0][3],
temp[2]); temp[2]);
} }
static void impl(T& c, T2& src, T3& weight) {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) {
c[0][0] = Func::impl(src[0 + src_idx], weight[0][weight_idx], c[0][0]); c[0][0] = Func::impl(src[0 + src_idx], weight[0][weight_idx], c[0][0]);
c[0][1] = Func::impl(src[1 + src_idx], weight[0][weight_idx], c[0][1]); c[0][1] = Func::impl(src[1 + src_idx], weight[0][weight_idx], c[0][1]);
c[0][2] = Func::impl(src[2 + src_idx], weight[0][weight_idx], c[0][2]); c[0][2] = Func::impl(src[2 + src_idx], weight[0][weight_idx], c[0][2]);
@@ -96,7 +97,7 @@ struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 2, T, T2, T3, T4> {
template <int src_idx, int weight_idx, typename Func, typename T, typename T2, template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
typename T3, typename T4> typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 1, T, T2, T3, T4> { struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 1, T, T2, T3, T4> {
static void impl(T& c, T2& src, T3& weight, T4& temp) {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight, T4& temp) {
c[0][0] = Func::impl(src[(0 + src_idx) % 8], weight[0][weight_idx], c[0][0] = Func::impl(src[(0 + src_idx) % 8], weight[0][weight_idx],
c[0][0], temp[0]); c[0][0], temp[0]);
c[1][0] = Func::impl(src[(0 + src_idx) % 8], weight[1][weight_idx], c[1][0] = Func::impl(src[(0 + src_idx) % 8], weight[1][weight_idx],
@@ -131,12 +132,12 @@ struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 1, T, T2, T3, T4> {
c[1][7] = Func::impl(src[(7 + src_idx) % 8], weight[1][weight_idx], c[1][7] = Func::impl(src[(7 + src_idx) % 8], weight[1][weight_idx],
c[1][7], temp[3]); c[1][7], temp[3]);
} }
static void impl(T&, T2&, T3&);
static MEGDNN_ALWAYS_INLINE void impl(T&, T2&, T3&);
}; };
template <int src_idx, int weight_idx, typename Func, typename T, typename T2, template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
typename T3, typename T4> typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 1, T, T2, T3, T4> { struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 1, T, T2, T3, T4> {
static void impl(T& c, T2& src, T3& weight, T4& temp) {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight, T4& temp) {
c[0][0] = Func::impl(src[(0 + src_idx) % 8], weight[0][weight_idx], c[0][0] = Func::impl(src[(0 + src_idx) % 8], weight[0][weight_idx],
c[0][0], temp[0]); c[0][0], temp[0]);
c[0][1] = Func::impl(src[(1 + src_idx) % 8], weight[0][weight_idx], c[0][1] = Func::impl(src[(1 + src_idx) % 8], weight[0][weight_idx],
@@ -154,18 +155,18 @@ struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 1, T, T2, T3, T4> {
c[0][7] = Func::impl(src[(7 + src_idx) % 8], weight[0][weight_idx], c[0][7] = Func::impl(src[(7 + src_idx) % 8], weight[0][weight_idx],
c[0][7], temp[3]); c[0][7], temp[3]);
} }
static void impl(T&, T2&, T3&);
static MEGDNN_ALWAYS_INLINE void impl(T&, T2&, T3&);
}; };


template <int src_idx, int weight_idx, int c_dim, typename FUNC, int stride, template <int src_idx, int weight_idx, int c_dim, typename FUNC, int stride,
typename T, typename T2, typename T3, typename T4> typename T, typename T2, typename T3, typename T4>
inline void cal_helper(T& c, T2& src, T3& weight, T4& temp) {
MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight, T4& temp) {
ShiftCalHelper<src_idx, weight_idx, c_dim, FUNC, stride, T, T2, T3, ShiftCalHelper<src_idx, weight_idx, c_dim, FUNC, stride, T, T2, T3,
T4>::impl(c, src, weight, temp); T4>::impl(c, src, weight, temp);
} }
template <int src_idx, int weight_idx, int c_dim, typename FUNC, int stride, template <int src_idx, int weight_idx, int c_dim, typename FUNC, int stride,
typename T, typename T2, typename T3> typename T, typename T2, typename T3>
inline void cal_helper(T& c, T2& src, T3& weight) {
MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight) {
ShiftCalHelper<src_idx, weight_idx, c_dim, FUNC, stride, T, T2, T3, ShiftCalHelper<src_idx, weight_idx, c_dim, FUNC, stride, T, T2, T3,
int>::impl(c, src, weight); int>::impl(c, src, weight);
}; };
@@ -703,8 +704,9 @@ struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 7, oc_block, 1> {


enum PACK_MODE { NO_PAD = 0, FIRST_PAD = 1, LAST_PAD = 2 }; enum PACK_MODE { NO_PAD = 0, FIRST_PAD = 1, LAST_PAD = 2 };
template <PACK_MODE mode> template <PACK_MODE mode>
inline void pack_src_one_line(const int8_t* inptr, int8_t* outptr, int left_pad,
int right_pad, const int iw) {
MEGDNN_ALWAYS_INLINE void pack_src_one_line(const int8_t* inptr, int8_t* outptr,
int left_pad, int right_pad,
const int iw) {
const int8_t* src_row_0 = inptr; const int8_t* src_row_0 = inptr;
const int8_t* src_row_1 = inptr + iw; const int8_t* src_row_1 = inptr + iw;
constexpr int combine_row = 2; constexpr int combine_row = 2;
@@ -1235,6 +1237,7 @@ struct ConvDiectStrideInt8NchwNchw44<bias_mode, Op, filter_size, 1> {
} }
} }
} }

if (oc_remain > 0) { if (oc_remain > 0) {
size_t oc_idx = oc_end; size_t oc_idx = oc_end;
const size_t weight_offset = oc_idx * ic * fh * fw; const size_t weight_offset = oc_idx * ic * fh * fw;
@@ -1284,4 +1287,5 @@ static void conv_direct_int8_nchw_nchw44(const int8_t* src,
} // namespace } // namespace
} // namespace arm_common } // namespace arm_common
} // namespace megdnn } // namespace megdnn
// vim: syntax=cpp.doxygen

// vim: syntax=cpp.doxygen

+ 95
- 93
dnn/src/arm_common/conv_bias/intrinsic_helper.h View File

@@ -15,18 +15,20 @@
#include "src/arm_common/simd_macro/marm_neon.h" #include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h" #include "src/common/unroll_macro.h"
#include "src/fallback/conv_bias/common.h" #include "src/fallback/conv_bias/common.h"

#define __ai inline __attribute__((__always_inline__))
namespace megdnn { namespace megdnn {
namespace { namespace {


////////////////////Store_OC4_OW8_Remain///////////////////////// ////////////////////Store_OC4_OW8_Remain/////////////////////////
template <int ow_remain, typename Op> template <int ow_remain, typename Op>
struct Store_OC4_OW8_Remain { struct Store_OC4_OW8_Remain {
static void impl(int32x4_t c[8], const Op& op, int8_t* dst_ptr);
static __ai void impl(int32x4_t c[8], const Op& op, int8_t* dst_ptr);
}; };


template <typename Op> template <typename Op>
struct Store_OC4_OW8_Remain<0, Op> { struct Store_OC4_OW8_Remain<0, Op> {
static void impl(int32x4_t c[8], const Op& op, int8_t* dst_ptr) {
static __ai void impl(int32x4_t c[8], const Op& op, int8_t* dst_ptr) {
op({{c[0], c[1]}}, reinterpret_cast<dt_qint8*>(dst_ptr)); op({{c[0], c[1]}}, reinterpret_cast<dt_qint8*>(dst_ptr));
op({{c[2], c[3]}}, reinterpret_cast<dt_qint8*>(dst_ptr + 8)); op({{c[2], c[3]}}, reinterpret_cast<dt_qint8*>(dst_ptr + 8));
op({{c[4], c[5]}}, reinterpret_cast<dt_qint8*>(dst_ptr + 16)); op({{c[4], c[5]}}, reinterpret_cast<dt_qint8*>(dst_ptr + 16));
@@ -36,7 +38,7 @@ struct Store_OC4_OW8_Remain<0, Op> {


template <typename Op> template <typename Op>
struct Store_OC4_OW8_Remain<7, Op> { struct Store_OC4_OW8_Remain<7, Op> {
static void impl(int32x4_t c[8], const Op& op, int8_t* dst_ptr) {
static __ai void impl(int32x4_t c[8], const Op& op, int8_t* dst_ptr) {
op({{c[0], c[1]}}, reinterpret_cast<dt_qint8*>(dst_ptr)); op({{c[0], c[1]}}, reinterpret_cast<dt_qint8*>(dst_ptr));
op({{c[2], c[3]}}, reinterpret_cast<dt_qint8*>(dst_ptr + 8)); op({{c[2], c[3]}}, reinterpret_cast<dt_qint8*>(dst_ptr + 8));
op({{c[4], c[5]}}, reinterpret_cast<dt_qint8*>(dst_ptr + 16)); op({{c[4], c[5]}}, reinterpret_cast<dt_qint8*>(dst_ptr + 16));
@@ -45,7 +47,7 @@ struct Store_OC4_OW8_Remain<7, Op> {
}; };
template <typename Op> template <typename Op>
struct Store_OC4_OW8_Remain<6, Op> { struct Store_OC4_OW8_Remain<6, Op> {
static void impl(int32x4_t c[8], const Op& op, int8_t* dst_ptr) {
static __ai void impl(int32x4_t c[8], const Op& op, int8_t* dst_ptr) {
op({{c[0], c[1]}}, reinterpret_cast<dt_qint8*>(dst_ptr)); op({{c[0], c[1]}}, reinterpret_cast<dt_qint8*>(dst_ptr));
op({{c[2], c[3]}}, reinterpret_cast<dt_qint8*>(dst_ptr + 8)); op({{c[2], c[3]}}, reinterpret_cast<dt_qint8*>(dst_ptr + 8));
op({{c[4], c[5]}}, reinterpret_cast<dt_qint8*>(dst_ptr + 16)); op({{c[4], c[5]}}, reinterpret_cast<dt_qint8*>(dst_ptr + 16));
@@ -53,7 +55,7 @@ struct Store_OC4_OW8_Remain<6, Op> {
}; };
template <typename Op> template <typename Op>
struct Store_OC4_OW8_Remain<5, Op> { struct Store_OC4_OW8_Remain<5, Op> {
static void impl(int32x4_t c[8], const Op& op, int8_t* dst_ptr) {
static __ai void impl(int32x4_t c[8], const Op& op, int8_t* dst_ptr) {
op({{c[0], c[1]}}, reinterpret_cast<dt_qint8*>(dst_ptr)); op({{c[0], c[1]}}, reinterpret_cast<dt_qint8*>(dst_ptr));
op({{c[2], c[3]}}, reinterpret_cast<dt_qint8*>(dst_ptr + 8)); op({{c[2], c[3]}}, reinterpret_cast<dt_qint8*>(dst_ptr + 8));
op(c[4], reinterpret_cast<dt_qint8*>(dst_ptr + 16)); op(c[4], reinterpret_cast<dt_qint8*>(dst_ptr + 16));
@@ -61,46 +63,46 @@ struct Store_OC4_OW8_Remain<5, Op> {
}; };
template <typename Op> template <typename Op>
struct Store_OC4_OW8_Remain<4, Op> { struct Store_OC4_OW8_Remain<4, Op> {
static void impl(int32x4_t c[8], const Op& op, int8_t* dst_ptr) {
static __ai void impl(int32x4_t c[8], const Op& op, int8_t* dst_ptr) {
op({{c[0], c[1]}}, reinterpret_cast<dt_qint8*>(dst_ptr)); op({{c[0], c[1]}}, reinterpret_cast<dt_qint8*>(dst_ptr));
op({{c[2], c[3]}}, reinterpret_cast<dt_qint8*>(dst_ptr + 8)); op({{c[2], c[3]}}, reinterpret_cast<dt_qint8*>(dst_ptr + 8));
} }
}; };
template <typename Op> template <typename Op>
struct Store_OC4_OW8_Remain<3, Op> { struct Store_OC4_OW8_Remain<3, Op> {
static void impl(int32x4_t c[8], const Op& op, int8_t* dst_ptr) {
static __ai void impl(int32x4_t c[8], const Op& op, int8_t* dst_ptr) {
op({{c[0], c[1]}}, reinterpret_cast<dt_qint8*>(dst_ptr)); op({{c[0], c[1]}}, reinterpret_cast<dt_qint8*>(dst_ptr));
op(c[2], reinterpret_cast<dt_qint8*>(dst_ptr + 8)); op(c[2], reinterpret_cast<dt_qint8*>(dst_ptr + 8));
} }
}; };
template <typename Op> template <typename Op>
struct Store_OC4_OW8_Remain<2, Op> { struct Store_OC4_OW8_Remain<2, Op> {
static void impl(int32x4_t c[8], const Op& op, int8_t* dst_ptr) {
static __ai void impl(int32x4_t c[8], const Op& op, int8_t* dst_ptr) {
op({{c[0], c[1]}}, reinterpret_cast<dt_qint8*>(dst_ptr)); op({{c[0], c[1]}}, reinterpret_cast<dt_qint8*>(dst_ptr));
} }
}; };
template <typename Op> template <typename Op>
struct Store_OC4_OW8_Remain<1, Op> { struct Store_OC4_OW8_Remain<1, Op> {
static void impl(int32x4_t c[8], const Op& op, int8_t* dst_ptr) {
static __ai void impl(int32x4_t c[8], const Op& op, int8_t* dst_ptr) {
op(c[0], reinterpret_cast<dt_qint8*>(dst_ptr)); op(c[0], reinterpret_cast<dt_qint8*>(dst_ptr));
} }
}; };


template <int ow_remain, typename Op> template <int ow_remain, typename Op>
inline void store_oc4_ow8_remain_static(int32x4_t c[8], const Op& op,
int8_t* dst_ptr) {
__ai void store_oc4_ow8_remain_static(int32x4_t c[8], const Op& op,
int8_t* dst_ptr) {
Store_OC4_OW8_Remain<ow_remain, Op>::impl(c, op, dst_ptr); Store_OC4_OW8_Remain<ow_remain, Op>::impl(c, op, dst_ptr);
} }


template <int c_dim, int ow_remain, typename Op, typename T> template <int c_dim, int ow_remain, typename Op, typename T>
struct StoreOcxOw4Remain { struct StoreOcxOw4Remain {
static void impl(T& c, const Op& op, int8_t* dst_ptr, int ld_dst_oc);
static __ai void impl(T& c, const Op& op, int8_t* dst_ptr, int ld_dst_oc);
}; };


template <typename Op, typename T> template <typename Op, typename T>
struct StoreOcxOw4Remain<2, 0, Op, T> { struct StoreOcxOw4Remain<2, 0, Op, T> {
static void impl(int32x4_t c[2][4], const Op& op, int8_t* dst_ptr,
int ld_dst_oc) {
static __ai void impl(int32x4_t c[2][4], const Op& op, int8_t* dst_ptr,
int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<dt_qint8*>(dst_ptr)); op({{c[0][0], c[0][1]}}, reinterpret_cast<dt_qint8*>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<dt_qint8*>(dst_ptr + 8)); op({{c[0][2], c[0][3]}}, reinterpret_cast<dt_qint8*>(dst_ptr + 8));


@@ -113,7 +115,7 @@ struct StoreOcxOw4Remain<2, 0, Op, T> {


template <typename Op, typename T> template <typename Op, typename T>
struct StoreOcxOw4Remain<2, 3, Op, T> { struct StoreOcxOw4Remain<2, 3, Op, T> {
static void impl(T& c, const Op& op, int8_t* dst_ptr, int ld_dst_oc) {
static __ai void impl(T& c, const Op& op, int8_t* dst_ptr, int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<dt_qint8*>(dst_ptr)); op({{c[0][0], c[0][1]}}, reinterpret_cast<dt_qint8*>(dst_ptr));
op(c[0][2], reinterpret_cast<dt_qint8*>(dst_ptr + 8)); op(c[0][2], reinterpret_cast<dt_qint8*>(dst_ptr + 8));


@@ -124,7 +126,7 @@ struct StoreOcxOw4Remain<2, 3, Op, T> {
}; };
template <typename Op, typename T> template <typename Op, typename T>
struct StoreOcxOw4Remain<2, 2, Op, T> { struct StoreOcxOw4Remain<2, 2, Op, T> {
static void impl(T& c, const Op& op, int8_t* dst_ptr, int ld_dst_oc) {
static __ai void impl(T& c, const Op& op, int8_t* dst_ptr, int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<dt_qint8*>(dst_ptr)); op({{c[0][0], c[0][1]}}, reinterpret_cast<dt_qint8*>(dst_ptr));
op({{c[1][0], c[1][1]}}, op({{c[1][0], c[1][1]}},
reinterpret_cast<dt_qint8*>(dst_ptr + ld_dst_oc)); reinterpret_cast<dt_qint8*>(dst_ptr + ld_dst_oc));
@@ -132,7 +134,7 @@ struct StoreOcxOw4Remain<2, 2, Op, T> {
}; };
template <typename Op, typename T> template <typename Op, typename T>
struct StoreOcxOw4Remain<2, 1, Op, T> { struct StoreOcxOw4Remain<2, 1, Op, T> {
static void impl(T& c, const Op& op, int8_t* dst_ptr, int ld_dst_oc) {
static __ai void impl(T& c, const Op& op, int8_t* dst_ptr, int ld_dst_oc) {
op(c[0][0], reinterpret_cast<dt_qint8*>(dst_ptr)); op(c[0][0], reinterpret_cast<dt_qint8*>(dst_ptr));
op(c[1][0], reinterpret_cast<dt_qint8*>(dst_ptr + ld_dst_oc)); op(c[1][0], reinterpret_cast<dt_qint8*>(dst_ptr + ld_dst_oc));
} }
@@ -140,8 +142,8 @@ struct StoreOcxOw4Remain<2, 1, Op, T> {


template <typename Op, typename T> template <typename Op, typename T>
struct StoreOcxOw4Remain<1, 0, Op, T> { struct StoreOcxOw4Remain<1, 0, Op, T> {
static void impl(int32x4_t c[2][4], const Op& op, int8_t* dst_ptr,
int ld_dst_oc) {
static __ai void impl(int32x4_t c[2][4], const Op& op, int8_t* dst_ptr,
int ld_dst_oc) {
MEGDNN_MARK_USED_VAR(ld_dst_oc); MEGDNN_MARK_USED_VAR(ld_dst_oc);
op({{c[0][0], c[0][1]}}, reinterpret_cast<dt_qint8*>(dst_ptr)); op({{c[0][0], c[0][1]}}, reinterpret_cast<dt_qint8*>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<dt_qint8*>(dst_ptr + 8)); op({{c[0][2], c[0][3]}}, reinterpret_cast<dt_qint8*>(dst_ptr + 8));
@@ -150,7 +152,7 @@ struct StoreOcxOw4Remain<1, 0, Op, T> {


template <typename Op, typename T> template <typename Op, typename T>
struct StoreOcxOw4Remain<1, 3, Op, T> { struct StoreOcxOw4Remain<1, 3, Op, T> {
static void impl(T& c, const Op& op, int8_t* dst_ptr, int ld_dst_oc) {
static __ai void impl(T& c, const Op& op, int8_t* dst_ptr, int ld_dst_oc) {
MEGDNN_MARK_USED_VAR(ld_dst_oc); MEGDNN_MARK_USED_VAR(ld_dst_oc);
op({{c[0][0], c[0][1]}}, reinterpret_cast<dt_qint8*>(dst_ptr)); op({{c[0][0], c[0][1]}}, reinterpret_cast<dt_qint8*>(dst_ptr));
op(c[0][2], reinterpret_cast<dt_qint8*>(dst_ptr + 8)); op(c[0][2], reinterpret_cast<dt_qint8*>(dst_ptr + 8));
@@ -158,33 +160,33 @@ struct StoreOcxOw4Remain<1, 3, Op, T> {
}; };
template <typename Op, typename T> template <typename Op, typename T>
struct StoreOcxOw4Remain<1, 2, Op, T> { struct StoreOcxOw4Remain<1, 2, Op, T> {
static void impl(T& c, const Op& op, int8_t* dst_ptr, int ld_dst_oc) {
static __ai void impl(T& c, const Op& op, int8_t* dst_ptr, int ld_dst_oc) {
MEGDNN_MARK_USED_VAR(ld_dst_oc); MEGDNN_MARK_USED_VAR(ld_dst_oc);
op({{c[0][0], c[0][1]}}, reinterpret_cast<dt_qint8*>(dst_ptr)); op({{c[0][0], c[0][1]}}, reinterpret_cast<dt_qint8*>(dst_ptr));
} }
}; };
template <typename Op, typename T> template <typename Op, typename T>
struct StoreOcxOw4Remain<1, 1, Op, T> { struct StoreOcxOw4Remain<1, 1, Op, T> {
static void impl(T& c, const Op& op, int8_t* dst_ptr, int ld_dst_oc) {
static __ai void impl(T& c, const Op& op, int8_t* dst_ptr, int ld_dst_oc) {
MEGDNN_MARK_USED_VAR(ld_dst_oc); MEGDNN_MARK_USED_VAR(ld_dst_oc);
op(c[0][0], reinterpret_cast<dt_qint8*>(dst_ptr)); op(c[0][0], reinterpret_cast<dt_qint8*>(dst_ptr));
} }
}; };
template <int c_dim, int ow_remain, typename Op, typename T> template <int c_dim, int ow_remain, typename Op, typename T>
inline void store_ocx_ow4_remain_static(T& c, const Op& op, int8_t* dst_ptr,
int ld_dst_oc) {
__ai void store_ocx_ow4_remain_static(T& c, const Op& op, int8_t* dst_ptr,
int ld_dst_oc) {
StoreOcxOw4Remain<c_dim, ow_remain, Op, T>::impl(c, op, dst_ptr, ld_dst_oc); StoreOcxOw4Remain<c_dim, ow_remain, Op, T>::impl(c, op, dst_ptr, ld_dst_oc);
} }
////////////////////Store_OCX_OW8_Remain///////////////////////// ////////////////////Store_OCX_OW8_Remain/////////////////////////
template <int c_dim, int ow_remain, typename Op, typename T, typename T2, template <int c_dim, int ow_remain, typename Op, typename T, typename T2,
typename T3> typename T3>
struct StoreOcxOw8Remain { struct StoreOcxOw8Remain {
static void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc);
static __ai void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc);
}; };


template <typename Op, typename T, typename T2, typename T3> template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<2, 0, Op, T, T2, T3> { struct StoreOcxOw8Remain<2, 0, Op, T, T2, T3> {
static void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
static __ai void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8)); op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16)); op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16));
@@ -200,7 +202,7 @@ struct StoreOcxOw8Remain<2, 0, Op, T, T2, T3> {
}; };
template <typename Op, typename T, typename T2, typename T3> template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<2, 8, Op, T, T2, T3> { struct StoreOcxOw8Remain<2, 8, Op, T, T2, T3> {
static void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
static __ai void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8)); op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16)); op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16));
@@ -216,7 +218,7 @@ struct StoreOcxOw8Remain<2, 8, Op, T, T2, T3> {
}; };
template <typename Op, typename T, typename T2, typename T3> template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<2, 7, Op, T, T2, T3> { struct StoreOcxOw8Remain<2, 7, Op, T, T2, T3> {
static void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
static __ai void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8)); op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16)); op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16));
@@ -231,7 +233,7 @@ struct StoreOcxOw8Remain<2, 7, Op, T, T2, T3> {
}; };
template <typename Op, typename T, typename T2, typename T3> template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<2, 6, Op, T, T2, T3> { struct StoreOcxOw8Remain<2, 6, Op, T, T2, T3> {
static void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
static __ai void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8)); op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16)); op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16));
@@ -244,7 +246,7 @@ struct StoreOcxOw8Remain<2, 6, Op, T, T2, T3> {
}; };
template <typename Op, typename T, typename T2, typename T3> template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<2, 5, Op, T, T2, T3> { struct StoreOcxOw8Remain<2, 5, Op, T, T2, T3> {
static void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
static __ai void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8)); op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
op(c[0][4], reinterpret_cast<T3>(dst_ptr + 16)); op(c[0][4], reinterpret_cast<T3>(dst_ptr + 16));
@@ -256,7 +258,7 @@ struct StoreOcxOw8Remain<2, 5, Op, T, T2, T3> {
}; };
template <typename Op, typename T, typename T2, typename T3> template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<2, 4, Op, T, T2, T3> { struct StoreOcxOw8Remain<2, 4, Op, T, T2, T3> {
static void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
static __ai void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8)); op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));


@@ -266,7 +268,7 @@ struct StoreOcxOw8Remain<2, 4, Op, T, T2, T3> {
}; };
template <typename Op, typename T, typename T2, typename T3> template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<2, 3, Op, T, T2, T3> { struct StoreOcxOw8Remain<2, 3, Op, T, T2, T3> {
static void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
static __ai void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op(c[0][2], reinterpret_cast<T3>(dst_ptr + 8)); op(c[0][2], reinterpret_cast<T3>(dst_ptr + 8));


@@ -276,14 +278,14 @@ struct StoreOcxOw8Remain<2, 3, Op, T, T2, T3> {
}; };
template <typename Op, typename T, typename T2, typename T3> template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<2, 2, Op, T, T2, T3> { struct StoreOcxOw8Remain<2, 2, Op, T, T2, T3> {
static void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
static __ai void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc)); op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc));
} }
}; };
template <typename Op, typename T, typename T2, typename T3> template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<2, 1, Op, T, T2, T3> { struct StoreOcxOw8Remain<2, 1, Op, T, T2, T3> {
static void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
static __ai void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
op(c[0][0], reinterpret_cast<T3>(dst_ptr)); op(c[0][0], reinterpret_cast<T3>(dst_ptr));
op(c[1][0], reinterpret_cast<T3>(dst_ptr + ld_dst_oc)); op(c[1][0], reinterpret_cast<T3>(dst_ptr + ld_dst_oc));
} }
@@ -291,7 +293,7 @@ struct StoreOcxOw8Remain<2, 1, Op, T, T2, T3> {


template <typename Op, typename T, typename T2, typename T3> template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<1, 0, Op, T, T2, T3> { struct StoreOcxOw8Remain<1, 0, Op, T, T2, T3> {
static void impl(T& c, const Op& op, T2 dst_ptr, int) {
static __ai void impl(T& c, const Op& op, T2 dst_ptr, int) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8)); op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16)); op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16));
@@ -300,7 +302,7 @@ struct StoreOcxOw8Remain<1, 0, Op, T, T2, T3> {
}; };
template <typename Op, typename T, typename T2, typename T3> template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<1, 8, Op, T, T2, T3> { struct StoreOcxOw8Remain<1, 8, Op, T, T2, T3> {
static void impl(T& c, const Op& op, T2 dst_ptr, int) {
static __ai void impl(T& c, const Op& op, T2 dst_ptr, int) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8)); op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16)); op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16));
@@ -310,7 +312,7 @@ struct StoreOcxOw8Remain<1, 8, Op, T, T2, T3> {


template <typename Op, typename T, typename T2, typename T3> template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<1, 7, Op, T, T2, T3> { struct StoreOcxOw8Remain<1, 7, Op, T, T2, T3> {
static void impl(T& c, const Op& op, T2 dst_ptr, int) {
static __ai void impl(T& c, const Op& op, T2 dst_ptr, int) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8)); op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16)); op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16));
@@ -319,7 +321,7 @@ struct StoreOcxOw8Remain<1, 7, Op, T, T2, T3> {
}; };
template <typename Op, typename T, typename T2, typename T3> template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<1, 6, Op, T, T2, T3> { struct StoreOcxOw8Remain<1, 6, Op, T, T2, T3> {
static void impl(T& c, const Op& op, T2 dst_ptr, int) {
static __ai void impl(T& c, const Op& op, T2 dst_ptr, int) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8)); op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16)); op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16));
@@ -327,7 +329,7 @@ struct StoreOcxOw8Remain<1, 6, Op, T, T2, T3> {
}; };
template <typename Op, typename T, typename T2, typename T3> template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<1, 5, Op, T, T2, T3> { struct StoreOcxOw8Remain<1, 5, Op, T, T2, T3> {
static void impl(T& c, const Op& op, T2 dst_ptr, int) {
static __ai void impl(T& c, const Op& op, T2 dst_ptr, int) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8)); op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
op(c[0][4], reinterpret_cast<T3>(dst_ptr + 16)); op(c[0][4], reinterpret_cast<T3>(dst_ptr + 16));
@@ -335,41 +337,41 @@ struct StoreOcxOw8Remain<1, 5, Op, T, T2, T3> {
}; };
template <typename Op, typename T, typename T2, typename T3> template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<1, 4, Op, T, T2, T3> { struct StoreOcxOw8Remain<1, 4, Op, T, T2, T3> {
static void impl(T& c, const Op& op, T2 dst_ptr, int) {
static __ai void impl(T& c, const Op& op, T2 dst_ptr, int) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8)); op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
} }
}; };
template <typename Op, typename T, typename T2, typename T3> template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<1, 3, Op, T, T2, T3> { struct StoreOcxOw8Remain<1, 3, Op, T, T2, T3> {
static void impl(T& c, const Op& op, T2 dst_ptr, int) {
static __ai void impl(T& c, const Op& op, T2 dst_ptr, int) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op(c[0][2], reinterpret_cast<T3>(dst_ptr + 8)); op(c[0][2], reinterpret_cast<T3>(dst_ptr + 8));
} }
}; };
template <typename Op, typename T, typename T2, typename T3> template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<1, 2, Op, T, T2, T3> { struct StoreOcxOw8Remain<1, 2, Op, T, T2, T3> {
static void impl(T& c, const Op& op, T2 dst_ptr, int) {
static __ai void impl(T& c, const Op& op, T2 dst_ptr, int) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
} }
}; };
template <typename Op, typename T, typename T2, typename T3> template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<1, 1, Op, T, T2, T3> { struct StoreOcxOw8Remain<1, 1, Op, T, T2, T3> {
static void impl(T& c, const Op& op, T2 dst_ptr, int) {
static __ai void impl(T& c, const Op& op, T2 dst_ptr, int) {
op(c[0][0], reinterpret_cast<T3>(dst_ptr)); op(c[0][0], reinterpret_cast<T3>(dst_ptr));
} }
}; };


template <int c_dim, int ow_remain, typename Op, typename T, typename T2> template <int c_dim, int ow_remain, typename Op, typename T, typename T2>
inline void store_ocx_ow8_remain_static(T& c, const Op& op, T2 dst_ptr,
int ld_dst_oc) {
__ai void store_ocx_ow8_remain_static(T& c, const Op& op, T2 dst_ptr,
int ld_dst_oc) {
StoreOcxOw8Remain<c_dim, ow_remain, Op, T, T2, T2>::impl(c, op, dst_ptr, StoreOcxOw8Remain<c_dim, ow_remain, Op, T, T2, T2>::impl(c, op, dst_ptr,
ld_dst_oc); ld_dst_oc);
} }
template <int c_dim, int ow_remain, typename Op, typename T3, typename T, template <int c_dim, int ow_remain, typename Op, typename T3, typename T,
typename T2> typename T2>
inline void store_ocx_ow8_remain_static_dt(T& c, const Op& op, T2 dst_ptr,
int ld_dst_oc) {
__ai void store_ocx_ow8_remain_static_dt(T& c, const Op& op, T2 dst_ptr,
int ld_dst_oc) {
StoreOcxOw8Remain<c_dim, ow_remain, Op, T, T2, T3>::impl(c, op, dst_ptr, StoreOcxOw8Remain<c_dim, ow_remain, Op, T, T2, T3>::impl(c, op, dst_ptr,
ld_dst_oc); ld_dst_oc);
} }
@@ -377,14 +379,14 @@ inline void store_ocx_ow8_remain_static_dt(T& c, const Op& op, T2 dst_ptr,


template <int ow_remain, typename Op> template <int ow_remain, typename Op>
struct Store_OC8_OW8_Remain { struct Store_OC8_OW8_Remain {
static void impl(int32x4_t c[2][8], const Op& op, int8_t* dst_ptr,
int ld_dst_oc);
static __ai void impl(int32x4_t c[2][8], const Op& op, int8_t* dst_ptr,
int ld_dst_oc);
}; };


template <typename Op> template <typename Op>
struct Store_OC8_OW8_Remain<0, Op> { struct Store_OC8_OW8_Remain<0, Op> {
static void impl(int32x4_t c[2][8], const Op& op, int8_t* dst_ptr,
int ld_dst_oc) {
static __ai void impl(int32x4_t c[2][8], const Op& op, int8_t* dst_ptr,
int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<dt_qint8*>(dst_ptr)); op({{c[0][0], c[0][1]}}, reinterpret_cast<dt_qint8*>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<dt_qint8*>(dst_ptr + 8)); op({{c[0][2], c[0][3]}}, reinterpret_cast<dt_qint8*>(dst_ptr + 8));
op({{c[0][4], c[0][5]}}, reinterpret_cast<dt_qint8*>(dst_ptr + 16)); op({{c[0][4], c[0][5]}}, reinterpret_cast<dt_qint8*>(dst_ptr + 16));
@@ -403,8 +405,8 @@ struct Store_OC8_OW8_Remain<0, Op> {


template <typename Op> template <typename Op>
struct Store_OC8_OW8_Remain<7, Op> { struct Store_OC8_OW8_Remain<7, Op> {
static void impl(int32x4_t c[2][8], const Op& op, int8_t* dst_ptr,
int ld_dst_oc) {
static __ai void impl(int32x4_t c[2][8], const Op& op, int8_t* dst_ptr,
int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<dt_qint8*>(dst_ptr)); op({{c[0][0], c[0][1]}}, reinterpret_cast<dt_qint8*>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<dt_qint8*>(dst_ptr + 8)); op({{c[0][2], c[0][3]}}, reinterpret_cast<dt_qint8*>(dst_ptr + 8));
op({{c[0][4], c[0][5]}}, reinterpret_cast<dt_qint8*>(dst_ptr + 16)); op({{c[0][4], c[0][5]}}, reinterpret_cast<dt_qint8*>(dst_ptr + 16));
@@ -422,8 +424,8 @@ struct Store_OC8_OW8_Remain<7, Op> {


template <typename Op> template <typename Op>
struct Store_OC8_OW8_Remain<6, Op> { struct Store_OC8_OW8_Remain<6, Op> {
static void impl(int32x4_t c[2][8], const Op& op, int8_t* dst_ptr,
int ld_dst_oc) {
static __ai void impl(int32x4_t c[2][8], const Op& op, int8_t* dst_ptr,
int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<dt_qint8*>(dst_ptr)); op({{c[0][0], c[0][1]}}, reinterpret_cast<dt_qint8*>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<dt_qint8*>(dst_ptr + 8)); op({{c[0][2], c[0][3]}}, reinterpret_cast<dt_qint8*>(dst_ptr + 8));
op({{c[0][4], c[0][5]}}, reinterpret_cast<dt_qint8*>(dst_ptr + 16)); op({{c[0][4], c[0][5]}}, reinterpret_cast<dt_qint8*>(dst_ptr + 16));
@@ -439,8 +441,8 @@ struct Store_OC8_OW8_Remain<6, Op> {


template <typename Op> template <typename Op>
struct Store_OC8_OW8_Remain<5, Op> { struct Store_OC8_OW8_Remain<5, Op> {
static void impl(int32x4_t c[2][8], const Op& op, int8_t* dst_ptr,
int ld_dst_oc) {
static __ai void impl(int32x4_t c[2][8], const Op& op, int8_t* dst_ptr,
int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<dt_qint8*>(dst_ptr)); op({{c[0][0], c[0][1]}}, reinterpret_cast<dt_qint8*>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<dt_qint8*>(dst_ptr + 8)); op({{c[0][2], c[0][3]}}, reinterpret_cast<dt_qint8*>(dst_ptr + 8));
op(c[0][4], reinterpret_cast<dt_qint8*>(dst_ptr + 16)); op(c[0][4], reinterpret_cast<dt_qint8*>(dst_ptr + 16));
@@ -455,8 +457,8 @@ struct Store_OC8_OW8_Remain<5, Op> {


template <typename Op> template <typename Op>
struct Store_OC8_OW8_Remain<4, Op> { struct Store_OC8_OW8_Remain<4, Op> {
static void impl(int32x4_t c[2][8], const Op& op, int8_t* dst_ptr,
int ld_dst_oc) {
static __ai void impl(int32x4_t c[2][8], const Op& op, int8_t* dst_ptr,
int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<dt_qint8*>(dst_ptr)); op({{c[0][0], c[0][1]}}, reinterpret_cast<dt_qint8*>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<dt_qint8*>(dst_ptr + 8)); op({{c[0][2], c[0][3]}}, reinterpret_cast<dt_qint8*>(dst_ptr + 8));


@@ -469,8 +471,8 @@ struct Store_OC8_OW8_Remain<4, Op> {


template <typename Op> template <typename Op>
struct Store_OC8_OW8_Remain<3, Op> { struct Store_OC8_OW8_Remain<3, Op> {
static void impl(int32x4_t c[2][8], const Op& op, int8_t* dst_ptr,
int ld_dst_oc) {
static __ai void impl(int32x4_t c[2][8], const Op& op, int8_t* dst_ptr,
int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<dt_qint8*>(dst_ptr)); op({{c[0][0], c[0][1]}}, reinterpret_cast<dt_qint8*>(dst_ptr));
op(c[0][2], reinterpret_cast<dt_qint8*>(dst_ptr + 8)); op(c[0][2], reinterpret_cast<dt_qint8*>(dst_ptr + 8));


@@ -481,8 +483,8 @@ struct Store_OC8_OW8_Remain<3, Op> {
}; };
template <typename Op> template <typename Op>
struct Store_OC8_OW8_Remain<2, Op> { struct Store_OC8_OW8_Remain<2, Op> {
static void impl(int32x4_t c[2][8], const Op& op, int8_t* dst_ptr,
int ld_dst_oc) {
static __ai void impl(int32x4_t c[2][8], const Op& op, int8_t* dst_ptr,
int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<dt_qint8*>(dst_ptr)); op({{c[0][0], c[0][1]}}, reinterpret_cast<dt_qint8*>(dst_ptr));
op({{c[1][0], c[1][1]}}, op({{c[1][0], c[1][1]}},
reinterpret_cast<dt_qint8*>(dst_ptr + ld_dst_oc)); reinterpret_cast<dt_qint8*>(dst_ptr + ld_dst_oc));
@@ -490,8 +492,8 @@ struct Store_OC8_OW8_Remain<2, Op> {
}; };
template <typename Op> template <typename Op>
struct Store_OC8_OW8_Remain<1, Op> { struct Store_OC8_OW8_Remain<1, Op> {
static void impl(int32x4_t c[2][8], const Op& op, int8_t* dst_ptr,
int ld_dst_oc) {
static __ai void impl(int32x4_t c[2][8], const Op& op, int8_t* dst_ptr,
int ld_dst_oc) {
op(c[0][0], reinterpret_cast<dt_qint8*>(dst_ptr)); op(c[0][0], reinterpret_cast<dt_qint8*>(dst_ptr));
op(c[1][0], reinterpret_cast<dt_qint8*>(dst_ptr + ld_dst_oc)); op(c[1][0], reinterpret_cast<dt_qint8*>(dst_ptr + ld_dst_oc));
} }
@@ -500,14 +502,14 @@ struct Store_OC8_OW8_Remain<1, Op> {
/////////// ///////////


template <int ow_remain, typename Op, typename T, typename T2> template <int ow_remain, typename Op, typename T, typename T2>
inline void store_oc8_ow8_remain_static(T& c, const Op& op, T2 dst_ptr,
int ld_dst_oc) {
__ai void store_oc8_ow8_remain_static(T& c, const Op& op, T2 dst_ptr,
int ld_dst_oc) {
Store_OC8_OW8_Remain<ow_remain, Op>::impl(c, op, dst_ptr, ld_dst_oc); Store_OC8_OW8_Remain<ow_remain, Op>::impl(c, op, dst_ptr, ld_dst_oc);
} }


////////////////////////////////////// //////////////////////////////////////
template <BiasMode bias_mode> template <BiasMode bias_mode>
inline void init_oc4_ow8(int32x4_t c[8], const int32_t* bias_ptr) {
__ai void init_oc4_ow8(int32x4_t c[8], const int32_t* bias_ptr) {
if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) {
#define BAIS_INIT(step) c[step] = vld1q_s32(bias_ptr); #define BAIS_INIT(step) c[step] = vld1q_s32(bias_ptr);
UNROLL_CALL_RAW(8, BAIS_INIT); UNROLL_CALL_RAW(8, BAIS_INIT);
@@ -520,8 +522,8 @@ inline void init_oc4_ow8(int32x4_t c[8], const int32_t* bias_ptr) {
} }


template <BiasMode bias_mode> template <BiasMode bias_mode>
inline void init_oc8_ow8(int32x4_t c[2][8], const int32_t* bias_ptr,
int oc_step) {
__ai void init_oc8_ow8(int32x4_t c[2][8], const int32_t* bias_ptr,
int oc_step) {
if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) {
#define BAIS_INIT(step) \ #define BAIS_INIT(step) \
c[0][step] = vld1q_s32(bias_ptr); \ c[0][step] = vld1q_s32(bias_ptr); \
@@ -539,28 +541,28 @@ inline void init_oc8_ow8(int32x4_t c[2][8], const int32_t* bias_ptr,


/////////////////////////init_ocx_ow8//////////////////// /////////////////////////init_ocx_ow8////////////////////


inline float32x4_t neon_vdupq_n(float val) {
__ai float32x4_t neon_vdupq_n(float val) {
return vdupq_n_f32(val); return vdupq_n_f32(val);
} }


inline int32x4_t neon_vdupq_n(int val) {
__ai int32x4_t neon_vdupq_n(int val) {
return vdupq_n_s32(val); return vdupq_n_s32(val);
} }
inline float32x4_t neon_vld1q(const float* ptr) {
__ai float32x4_t neon_vld1q(const float* ptr) {
return vld1q_f32(ptr); return vld1q_f32(ptr);
} }


inline int32x4_t neon_vld1q(const int* ptr) {
__ai int32x4_t neon_vld1q(const int* ptr) {
return vld1q_s32(ptr); return vld1q_s32(ptr);
} }


template <int c_dim, BiasMode bias_mode, int ow_block, typename T, typename T2> template <int c_dim, BiasMode bias_mode, int ow_block, typename T, typename T2>
struct InitOcxOw8 { struct InitOcxOw8 {
static void impl(T& c, const T2* bias_ptr, int oc_step);
static __ai void impl(T& c, const T2* bias_ptr, int oc_step);
}; };
template <typename T, typename T2> template <typename T, typename T2>
struct InitOcxOw8<2, BiasMode::NO_BIAS, 8, T, T2> { struct InitOcxOw8<2, BiasMode::NO_BIAS, 8, T, T2> {
static void impl(T& c, const T2*, int) {
static __ai void impl(T& c, const T2*, int) {
#define BAIS_INIT(step) \ #define BAIS_INIT(step) \
c[0][step] = neon_vdupq_n(static_cast<T2>(0)); \ c[0][step] = neon_vdupq_n(static_cast<T2>(0)); \
c[1][step] = neon_vdupq_n(static_cast<T2>(0)); c[1][step] = neon_vdupq_n(static_cast<T2>(0));
@@ -570,7 +572,7 @@ struct InitOcxOw8<2, BiasMode::NO_BIAS, 8, T, T2> {
}; };
template <typename T, typename T2> template <typename T, typename T2>
struct InitOcxOw8<2, BiasMode::NO_BIAS, 4, T, T2> { struct InitOcxOw8<2, BiasMode::NO_BIAS, 4, T, T2> {
static void impl(T& c, const T2*, int) {
static __ai void impl(T& c, const T2*, int) {
#define BAIS_INIT(step) \ #define BAIS_INIT(step) \
c[0][step] = neon_vdupq_n(static_cast<T2>(0)); \ c[0][step] = neon_vdupq_n(static_cast<T2>(0)); \
c[1][step] = neon_vdupq_n(static_cast<T2>(0)); c[1][step] = neon_vdupq_n(static_cast<T2>(0));
@@ -580,7 +582,7 @@ struct InitOcxOw8<2, BiasMode::NO_BIAS, 4, T, T2> {
}; };
template <typename T, typename T2> template <typename T, typename T2>
struct InitOcxOw8<2, BiasMode::BROADCAST_CHANNEL_BIAS, 8, T, T2> { struct InitOcxOw8<2, BiasMode::BROADCAST_CHANNEL_BIAS, 8, T, T2> {
static void impl(T& c, const T2* bias_ptr, int oc_step) {
static __ai void impl(T& c, const T2* bias_ptr, int oc_step) {
#define BAIS_INIT(step) \ #define BAIS_INIT(step) \
c[0][step] = neon_vld1q(bias_ptr); \ c[0][step] = neon_vld1q(bias_ptr); \
c[1][step] = neon_vld1q(bias_ptr + oc_step); c[1][step] = neon_vld1q(bias_ptr + oc_step);
@@ -590,7 +592,7 @@ struct InitOcxOw8<2, BiasMode::BROADCAST_CHANNEL_BIAS, 8, T, T2> {
}; };
template <typename T, typename T2> template <typename T, typename T2>
struct InitOcxOw8<2, BiasMode::BROADCAST_CHANNEL_BIAS, 4, T, T2> { struct InitOcxOw8<2, BiasMode::BROADCAST_CHANNEL_BIAS, 4, T, T2> {
static void impl(T& c, const T2* bias_ptr, int oc_step) {
static __ai void impl(T& c, const T2* bias_ptr, int oc_step) {
#define BAIS_INIT(step) \ #define BAIS_INIT(step) \
c[0][step] = neon_vld1q(bias_ptr); \ c[0][step] = neon_vld1q(bias_ptr); \
c[1][step] = neon_vld1q(bias_ptr + oc_step); c[1][step] = neon_vld1q(bias_ptr + oc_step);
@@ -600,7 +602,7 @@ struct InitOcxOw8<2, BiasMode::BROADCAST_CHANNEL_BIAS, 4, T, T2> {
}; };
template <typename T, typename T2> template <typename T, typename T2>
struct InitOcxOw8<2, BiasMode::BIAS, 8, T, T2> { struct InitOcxOw8<2, BiasMode::BIAS, 8, T, T2> {
static void impl(T& c, const T2* bias_ptr, int oc_step) {
static __ai void impl(T& c, const T2* bias_ptr, int oc_step) {
constexpr int simd_len = 4; constexpr int simd_len = 4;
#define BAIS_INIT(step) \ #define BAIS_INIT(step) \
c[0][step] = neon_vld1q(bias_ptr + step * simd_len); \ c[0][step] = neon_vld1q(bias_ptr + step * simd_len); \
@@ -611,7 +613,7 @@ struct InitOcxOw8<2, BiasMode::BIAS, 8, T, T2> {
}; };
template <typename T, typename T2> template <typename T, typename T2>
struct InitOcxOw8<2, BiasMode::BIAS, 4, T, T2> { struct InitOcxOw8<2, BiasMode::BIAS, 4, T, T2> {
static void impl(T& c, const T2* bias_ptr, int oc_step) {
static __ai void impl(T& c, const T2* bias_ptr, int oc_step) {
constexpr int simd_len = 4; constexpr int simd_len = 4;
#define BAIS_INIT(step) \ #define BAIS_INIT(step) \
c[0][step] = neon_vld1q(bias_ptr + step * simd_len); \ c[0][step] = neon_vld1q(bias_ptr + step * simd_len); \
@@ -623,7 +625,7 @@ struct InitOcxOw8<2, BiasMode::BIAS, 4, T, T2> {


template <typename T, typename T2> template <typename T, typename T2>
struct InitOcxOw8<1, BiasMode::NO_BIAS, 8, T, T2> { struct InitOcxOw8<1, BiasMode::NO_BIAS, 8, T, T2> {
static void impl(T& c, const T2*, int) {
static __ai void impl(T& c, const T2*, int) {
#define BAIS_INIT(step) c[0][step] = neon_vdupq_n(static_cast<T2>(0)); #define BAIS_INIT(step) c[0][step] = neon_vdupq_n(static_cast<T2>(0));
UNROLL_CALL_RAW(8, BAIS_INIT); UNROLL_CALL_RAW(8, BAIS_INIT);
#undef BAIS_INIT #undef BAIS_INIT
@@ -631,7 +633,7 @@ struct InitOcxOw8<1, BiasMode::NO_BIAS, 8, T, T2> {
}; };
template <typename T, typename T2> template <typename T, typename T2>
struct InitOcxOw8<1, BiasMode::NO_BIAS, 4, T, T2> { struct InitOcxOw8<1, BiasMode::NO_BIAS, 4, T, T2> {
static void impl(T& c, const T2*, int) {
static __ai void impl(T& c, const T2*, int) {
#define BAIS_INIT(step) c[0][step] = neon_vdupq_n(static_cast<T2>(0)); #define BAIS_INIT(step) c[0][step] = neon_vdupq_n(static_cast<T2>(0));
UNROLL_CALL_RAW(4, BAIS_INIT); UNROLL_CALL_RAW(4, BAIS_INIT);
#undef BAIS_INIT #undef BAIS_INIT
@@ -639,7 +641,7 @@ struct InitOcxOw8<1, BiasMode::NO_BIAS, 4, T, T2> {
}; };
template <typename T, typename T2> template <typename T, typename T2>
struct InitOcxOw8<1, BiasMode::BROADCAST_CHANNEL_BIAS, 8, T, T2> { struct InitOcxOw8<1, BiasMode::BROADCAST_CHANNEL_BIAS, 8, T, T2> {
static void impl(T& c, const T2* bias_ptr, int) {
static __ai void impl(T& c, const T2* bias_ptr, int) {
#define BAIS_INIT(step) c[0][step] = neon_vld1q(bias_ptr); #define BAIS_INIT(step) c[0][step] = neon_vld1q(bias_ptr);
UNROLL_CALL_RAW(8, BAIS_INIT); UNROLL_CALL_RAW(8, BAIS_INIT);
#undef BAIS_INIT #undef BAIS_INIT
@@ -647,7 +649,7 @@ struct InitOcxOw8<1, BiasMode::BROADCAST_CHANNEL_BIAS, 8, T, T2> {
}; };
template <typename T, typename T2> template <typename T, typename T2>
struct InitOcxOw8<1, BiasMode::BROADCAST_CHANNEL_BIAS, 4, T, T2> { struct InitOcxOw8<1, BiasMode::BROADCAST_CHANNEL_BIAS, 4, T, T2> {
static void impl(T& c, const T2* bias_ptr, int) {
static __ai void impl(T& c, const T2* bias_ptr, int) {
#define BAIS_INIT(step) c[0][step] = neon_vld1q(bias_ptr); #define BAIS_INIT(step) c[0][step] = neon_vld1q(bias_ptr);
UNROLL_CALL_RAW(4, BAIS_INIT); UNROLL_CALL_RAW(4, BAIS_INIT);
#undef BAIS_INIT #undef BAIS_INIT
@@ -655,7 +657,7 @@ struct InitOcxOw8<1, BiasMode::BROADCAST_CHANNEL_BIAS, 4, T, T2> {
}; };
template <typename T, typename T2> template <typename T, typename T2>
struct InitOcxOw8<1, BiasMode::BIAS, 8, T, T2> { struct InitOcxOw8<1, BiasMode::BIAS, 8, T, T2> {
static void impl(T& c, const T2* bias_ptr, int) {
static __ai void impl(T& c, const T2* bias_ptr, int) {
constexpr int simd_len = 4; constexpr int simd_len = 4;
#define BAIS_INIT(step) c[0][step] = neon_vld1q(bias_ptr + step * simd_len); #define BAIS_INIT(step) c[0][step] = neon_vld1q(bias_ptr + step * simd_len);
UNROLL_CALL_RAW(8, BAIS_INIT); UNROLL_CALL_RAW(8, BAIS_INIT);
@@ -664,7 +666,7 @@ struct InitOcxOw8<1, BiasMode::BIAS, 8, T, T2> {
}; };
template <typename T, typename T2> template <typename T, typename T2>
struct InitOcxOw8<1, BiasMode::BIAS, 4, T, T2> { struct InitOcxOw8<1, BiasMode::BIAS, 4, T, T2> {
static void impl(T& c, const T2* bias_ptr, int) {
static __ai void impl(T& c, const T2* bias_ptr, int) {
constexpr int simd_len = 4; constexpr int simd_len = 4;
#define BAIS_INIT(step) c[0][step] = neon_vld1q(bias_ptr + step * simd_len); #define BAIS_INIT(step) c[0][step] = neon_vld1q(bias_ptr + step * simd_len);
UNROLL_CALL_RAW(4, BAIS_INIT); UNROLL_CALL_RAW(4, BAIS_INIT);
@@ -673,18 +675,18 @@ struct InitOcxOw8<1, BiasMode::BIAS, 4, T, T2> {
}; };


template <int c_dim, BiasMode bias_mode, int ow_block, typename T, typename T2> template <int c_dim, BiasMode bias_mode, int ow_block, typename T, typename T2>
inline void init_ocx_ow8(T& c, const T2* bias_ptr, int oc_step) {
__ai void init_ocx_ow8(T& c, const T2* bias_ptr, int oc_step) {
InitOcxOw8<c_dim, bias_mode, ow_block, T, T2>::impl(c, bias_ptr, oc_step); InitOcxOw8<c_dim, bias_mode, ow_block, T, T2>::impl(c, bias_ptr, oc_step);
} }
/////////////////////init_ocx_ow4///////////////////// /////////////////////init_ocx_ow4/////////////////////
template <int c_dim, BiasMode bias_mode, typename T> template <int c_dim, BiasMode bias_mode, typename T>
struct InitOcxOw4 { struct InitOcxOw4 {
static void impl(T& c, const int32_t* bias_ptr, int oc_step);
static __ai void impl(T& c, const int32_t* bias_ptr, int oc_step);
}; };


template <BiasMode bias_mode, typename T> template <BiasMode bias_mode, typename T>
struct InitOcxOw4<2, bias_mode, T> { struct InitOcxOw4<2, bias_mode, T> {
static void impl(T& c, const int32_t* bias_ptr, int oc_step) {
static __ai void impl(T& c, const int32_t* bias_ptr, int oc_step) {
if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) {
#define BAIS_INIT(step) \ #define BAIS_INIT(step) \
c[0][step] = vld1q_s32(bias_ptr); \ c[0][step] = vld1q_s32(bias_ptr); \
@@ -703,7 +705,7 @@ struct InitOcxOw4<2, bias_mode, T> {


template <BiasMode bias_mode, typename T> template <BiasMode bias_mode, typename T>
struct InitOcxOw4<1, bias_mode, T> { struct InitOcxOw4<1, bias_mode, T> {
static void impl(T& c, const int32_t* bias_ptr, int oc_step) {
static __ai void impl(T& c, const int32_t* bias_ptr, int oc_step) {
MEGDNN_MARK_USED_VAR(oc_step); MEGDNN_MARK_USED_VAR(oc_step);
if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) {
#define BAIS_INIT(step) c[0][step] = vld1q_s32(bias_ptr); #define BAIS_INIT(step) c[0][step] = vld1q_s32(bias_ptr);
@@ -718,12 +720,12 @@ struct InitOcxOw4<1, bias_mode, T> {
}; };


template <int c_dim, BiasMode bias_mode, typename T> template <int c_dim, BiasMode bias_mode, typename T>
inline void init_ocx_ow4(T& c, const int32_t* bias_ptr, int oc_step) {
__ai void init_ocx_ow4(T& c, const int32_t* bias_ptr, int oc_step) {
InitOcxOw4<c_dim, bias_mode, T>::impl(c, bias_ptr, oc_step); InitOcxOw4<c_dim, bias_mode, T>::impl(c, bias_ptr, oc_step);
} }
/////////////////////////////////////// ///////////////////////////////////////


} // namespace } // namespace
} // namespace megdnn } // namespace megdnn
#undef __ai
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen

+ 8
- 7
dnn/src/arm_common/intrinsic_helper.h View File

@@ -13,13 +13,14 @@
#include "src/arm_common/neon_struct.h" #include "src/arm_common/neon_struct.h"
#include "src/arm_common/simd_macro/marm_neon.h" #include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h" #include "src/common/unroll_macro.h"
#define __ai inline __attribute__((__always_inline__))
namespace megdnn { namespace megdnn {
namespace { namespace {


template <int weight_number, int base_offset, int ptr_step, int oc_block, template <int weight_number, int base_offset, int ptr_step, int oc_block,
typename Func, typename T, typename T2, typename... XT> typename Func, typename T, typename T2, typename... XT>
struct LoadHelper { struct LoadHelper {
static void impl(T& weight, T2 ptr, int oc_offset, XT... args);
static __ai void impl(T& weight, T2 ptr, int oc_offset, XT... args);
}; };


#define WEIGHT_CB(step) \ #define WEIGHT_CB(step) \
@@ -29,7 +30,7 @@ struct LoadHelper {
template <int base_offset, int ptr_step, typename Func, typename T, \ template <int base_offset, int ptr_step, typename Func, typename T, \
typename T2, typename... XT> \ typename T2, typename... XT> \
struct LoadHelper<step, base_offset, ptr_step, 0, Func, T, T2, XT...> { \ struct LoadHelper<step, base_offset, ptr_step, 0, Func, T, T2, XT...> { \
static void impl(T& src, T2 ptr, int, XT... args) { \
static __ai void impl(T& src, T2 ptr, int, XT... args) { \
UNROLL_CALL_RAW(step, WEIGHT_CB); \ UNROLL_CALL_RAW(step, WEIGHT_CB); \
} \ } \
} }
@@ -62,7 +63,7 @@ LOAD_HELPER(16);
template <int base_offset, int ptr_step, typename Func, typename T, \ template <int base_offset, int ptr_step, typename Func, typename T, \
typename T2> \ typename T2> \
struct LoadHelper<step, base_offset, ptr_step, 1, Func, T, T2> { \ struct LoadHelper<step, base_offset, ptr_step, 1, Func, T, T2> { \
static void impl(T& src, T2 ptr, int) { \
static __ai void impl(T& src, T2 ptr, int) { \
UNROLL_CALL_RAW(step, WEIGHT_CB); \ UNROLL_CALL_RAW(step, WEIGHT_CB); \
} \ } \
} }
@@ -89,7 +90,7 @@ LOAD_HELPER(9);
template <int base_offset, int ptr_step, typename Func, typename T, \ template <int base_offset, int ptr_step, typename Func, typename T, \
typename T2> \ typename T2> \
struct LoadHelper<step, base_offset, ptr_step, 2, Func, T, T2> { \ struct LoadHelper<step, base_offset, ptr_step, 2, Func, T, T2> { \
static void impl(T& src, T2 ptr, int oc_offset) { \
static __ai void impl(T& src, T2 ptr, int oc_offset) { \
UNROLL_CALL_RAW(step, WEIGHT_CB); \ UNROLL_CALL_RAW(step, WEIGHT_CB); \
} \ } \
} }
@@ -108,19 +109,19 @@ LOAD_HELPER(8);


template <int weight_number, int base_offset, int ptr_step, int c_dim, template <int weight_number, int base_offset, int ptr_step, int c_dim,
typename Func, typename T, typename T2> typename Func, typename T, typename T2>
inline void load_helper(T& weight, T2 ptr, int oc_offset) {
__ai void load_helper(T& weight, T2 ptr, int oc_offset) {
LoadHelper<weight_number, base_offset, ptr_step, c_dim, Func, T, T2>::impl( LoadHelper<weight_number, base_offset, ptr_step, c_dim, Func, T, T2>::impl(
weight, ptr, oc_offset); weight, ptr, oc_offset);
} }


template <int weight_number, int base_offset, int ptr_step, int c_dim, template <int weight_number, int base_offset, int ptr_step, int c_dim,
typename Func, typename T, typename T2, typename... XT> typename Func, typename T, typename T2, typename... XT>
inline void load_helper_x(T& weight, T2 ptr, int oc_offset, XT... args) {
__ai void load_helper_x(T& weight, T2 ptr, int oc_offset, XT... args) {
LoadHelper<weight_number, base_offset, ptr_step, c_dim, Func, T, T2, LoadHelper<weight_number, base_offset, ptr_step, c_dim, Func, T, T2,
XT...>::impl(weight, ptr, oc_offset, args...); XT...>::impl(weight, ptr, oc_offset, args...);
} }


} // namespace } // namespace
} // namespace megdnn } // namespace megdnn
#undef __ai
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen

+ 22
- 12
dnn/src/arm_common/neon_struct.h View File

@@ -11,59 +11,68 @@
*/ */
#pragma once #pragma once
#include "src/arm_common/simd_macro/marm_neon.h" #include "src/arm_common/simd_macro/marm_neon.h"

#define __ai inline __attribute__((__always_inline__))
namespace megdnn { namespace megdnn {
namespace { namespace {
struct Vdotq_s32_h { struct Vdotq_s32_h {
static int32x4_t impl(int8x16_t& a, int8x16_t& b, int32x4_t& c,
int16x8_t& temp) {
static __ai int32x4_t impl(int8x16_t& a, int8x16_t& b, int32x4_t& c,
int16x8_t& temp) {
return vdotq_s32_h(a, b, c, temp); return vdotq_s32_h(a, b, c, temp);
} }
}; };
struct Vdot2_s32_h { struct Vdot2_s32_h {
static int32x4_t impl(int8x8_t a, int8x8_t b, int32x4_t c, int16x8_t temp) {
static __ai int32x4_t impl(int8x8_t a, int8x8_t b, int32x4_t c,
int16x8_t temp) {
return vdot2_s32_h(a, b, c, temp); return vdot2_s32_h(a, b, c, temp);
} }
}; };


struct Vmlal_s16 { struct Vmlal_s16 {
static int32x4_t impl(int16x8_t a, int16x8_t b, int32x4_t c) {
static __ai int32x4_t impl(int16x8_t a, int16x8_t b, int32x4_t c) {
return vmlal_s16(c, vget_low_s16(a), vget_low_s16(b)); return vmlal_s16(c, vget_low_s16(a), vget_low_s16(b));
} }
}; };


struct Vld1q_s8 { struct Vld1q_s8 {
static int8x16_t impl(const int8_t* ptr) { return vld1q_s8(ptr); }
static __ai int8x16_t impl(const int8_t* ptr) { return vld1q_s8(ptr); }
}; };
struct Vld1q_f32 { struct Vld1q_f32 {
static float32x4_t impl(const float32_t* ptr) { return vld1q_f32(ptr); }
static __ai float32x4_t impl(const float32_t* ptr) {
return vld1q_f32(ptr);
}
}; };
struct Vld1_s8 { struct Vld1_s8 {
static int8x8_t impl(const int8_t* ptr) { return vld1_s8(ptr); }
static __ai int8x8_t impl(const int8_t* ptr) { return vld1_s8(ptr); }
}; };
struct Vldq_dup_4s8_8s16 { struct Vldq_dup_4s8_8s16 {
static int16x8_t impl(const int8_t* ptr) { return vldq_dup_4s8_8s16(ptr); }
static __ai int16x8_t impl(const int8_t* ptr) {
return vldq_dup_4s8_8s16(ptr);
}
}; };


struct Vldq_tbl_low_s8 { struct Vldq_tbl_low_s8 {
static int8x8_t impl(const int8_t* ptr, uint8x16_t idx) {
static __ai int8x8_t impl(const int8_t* ptr, uint8x16_t idx) {
return vldq_tbl_low_s8(ptr, idx); return vldq_tbl_low_s8(ptr, idx);
} }
}; };


struct Vld1_dup_s8_s16 { struct Vld1_dup_s8_s16 {
static int16x8_t impl(const int8_t* ptr) { return vld1_dup_s8_s16(ptr); }
static __ai int16x8_t impl(const int8_t* ptr) {
return vld1_dup_s8_s16(ptr);
}
}; };


struct Vfmaq_laneq_f32 { struct Vfmaq_laneq_f32 {
template <const int lane> template <const int lane>
static float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) {
static __ai float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) {
return vfmaq_laneq_f32(a, b, v, lane); return vfmaq_laneq_f32(a, b, v, lane);
} }
}; };
#if __ARM_FEATURE_DOTPROD #if __ARM_FEATURE_DOTPROD
struct Vdotq_laneq_s32 { struct Vdotq_laneq_s32 {
template <const int lane> template <const int lane>
static int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) {
static __ai int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) {
return vdotq_laneq_s32(a, b, v, lane); return vdotq_laneq_s32(a, b, v, lane);
} }
}; };
@@ -72,4 +81,5 @@ struct Vdotq_laneq_s32 {
} // namespace } // namespace
} // namespace megdnn } // namespace megdnn


#undef __ai
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen

+ 81
- 14
dnn/src/arm_common/simd_macro/marm_neon.h View File

@@ -20,7 +20,9 @@
#pragma GCC diagnostic push #pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wpragmas" #pragma GCC diagnostic ignored "-Wpragmas"
#pragma GCC diagnostic ignored "-Wattributes" #pragma GCC diagnostic ignored "-Wattributes"
#define __ai static inline __attribute__((__always_inline__, __nodebug__))
#define __ai \
static inline \
__attribute__((__gnu_inline__, __always_inline__, __nodebug__))


#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC && !MEGDNN_DISABLE_FLOAT16 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC && !MEGDNN_DISABLE_FLOAT16
#define MEGDNN_INC_ARM_FP16(_x) _x #define MEGDNN_INC_ARM_FP16(_x) _x
@@ -299,16 +301,20 @@ __ai uint32x2_t vdot2_u8(uint8x8_t a, uint8x8_t b) {


#endif // __ARM_FEATURE_DOTPROD #endif // __ARM_FEATURE_DOTPROD


#if __GNUC__ < 8
#undef vld1q_f32_x2 #undef vld1q_f32_x2
__ai float32x4x2_t vld1q_f32_x2(const float* p) { __ai float32x4x2_t vld1q_f32_x2(const float* p) {
return {{vld1q_f32(p), vld1q_f32(p + 4)}}; return {{vld1q_f32(p), vld1q_f32(p + 4)}};
} }
#endif


#if __GNUC__ < 9
#undef vst1q_f32_x2 #undef vst1q_f32_x2
__ai void vst1q_f32_x2(const float* p, float32x4x2_t v) { __ai void vst1q_f32_x2(const float* p, float32x4x2_t v) {
vst1q_f32(const_cast<float*>(p), v.val[0]); vst1q_f32(const_cast<float*>(p), v.val[0]);
vst1q_f32(const_cast<float*>(p) + 4, v.val[1]); vst1q_f32(const_cast<float*>(p) + 4, v.val[1]);
} }
#endif


__ai int8x16_t vtranslq_s8(int8x8_t a) { __ai int8x16_t vtranslq_s8(int8x8_t a) {
int8x16_t ret; int8x16_t ret;
@@ -472,18 +478,18 @@ __ai int8x16_t vqtbl1q_s8(int8x16_t& a, uint8x16_t& idx) {
namespace { namespace {
template <int lane> template <int lane>
struct Vdup_laneq_s16_armv7 { struct Vdup_laneq_s16_armv7 {
static int16x4_t impl(int16x8_t vec);
__ai int16x4_t impl(int16x8_t vec);
}; };
#define cb(step) \ #define cb(step) \
template <> \ template <> \
struct Vdup_laneq_s16_armv7<step + 4> { \ struct Vdup_laneq_s16_armv7<step + 4> { \
static int16x4_t impl(int16x8_t vec) { \
__ai int16x4_t impl(int16x8_t vec) { \
return vdup_lane_s16(vget_high_s16(vec), step); \ return vdup_lane_s16(vget_high_s16(vec), step); \
} \ } \
}; \ }; \
template <> \ template <> \
struct Vdup_laneq_s16_armv7<step> { \ struct Vdup_laneq_s16_armv7<step> { \
static int16x4_t impl(int16x8_t vec) { \
__ai int16x4_t impl(int16x8_t vec) { \
return vdup_lane_s16(vget_low_s16(vec), step); \ return vdup_lane_s16(vget_low_s16(vec), step); \
} \ } \
}; };
@@ -495,30 +501,30 @@ UNROLL_CALL_RAW(4, cb);
namespace { namespace {
template <int lane> template <int lane>
struct Vfmaq_laneq_f32_armv7 { struct Vfmaq_laneq_f32_armv7 {
static float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v);
__ai float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v);
}; };


template <> template <>
struct Vfmaq_laneq_f32_armv7<0> { struct Vfmaq_laneq_f32_armv7<0> {
static float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) {
__ai float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) {
return vmlaq_lane_f32(a, b, vget_low_f32(v), 0); return vmlaq_lane_f32(a, b, vget_low_f32(v), 0);
} }
}; };
template <> template <>
struct Vfmaq_laneq_f32_armv7<1> { struct Vfmaq_laneq_f32_armv7<1> {
static float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) {
__ai float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) {
return vmlaq_lane_f32(a, b, vget_low_f32(v), 1); return vmlaq_lane_f32(a, b, vget_low_f32(v), 1);
} }
}; };
template <> template <>
struct Vfmaq_laneq_f32_armv7<2> { struct Vfmaq_laneq_f32_armv7<2> {
static float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) {
__ai float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) {
return vmlaq_lane_f32(a, b, vget_high_f32(v), 0); return vmlaq_lane_f32(a, b, vget_high_f32(v), 0);
} }
}; };
template <> template <>
struct Vfmaq_laneq_f32_armv7<3> { struct Vfmaq_laneq_f32_armv7<3> {
static float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) {
__ai float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) {
return vmlaq_lane_f32(a, b, vget_high_f32(v), 1); return vmlaq_lane_f32(a, b, vget_high_f32(v), 1);
} }
}; };
@@ -527,37 +533,98 @@ struct Vfmaq_laneq_f32_armv7<3> {
Vfmaq_laneq_f32_armv7<lane>::impl(a, b, v) Vfmaq_laneq_f32_armv7<lane>::impl(a, b, v)


#if __ARM_FEATURE_DOTPROD #if __ARM_FEATURE_DOTPROD
namespace {
template <int lane> template <int lane>
struct Vdotq_laneq_s32_armv7 { struct Vdotq_laneq_s32_armv7 {
static int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v);
__ai int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v);
}; };
template <> template <>
struct Vdotq_laneq_s32_armv7<0> { struct Vdotq_laneq_s32_armv7<0> {
static int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) {
__ai int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) {
return vdotq_lane_s32(a, b, vget_low_s32(v), 0); return vdotq_lane_s32(a, b, vget_low_s32(v), 0);
} }
}; };
template <> template <>
struct Vdotq_laneq_s32_armv7<1> { struct Vdotq_laneq_s32_armv7<1> {
static int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) {
__ai int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) {
return vdotq_lane_s32(a, b, vget_low_s32(v), 1); return vdotq_lane_s32(a, b, vget_low_s32(v), 1);
} }
}; };
template <> template <>
struct Vdotq_laneq_s32_armv7<2> { struct Vdotq_laneq_s32_armv7<2> {
static int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) {
__ai int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) {
return vdotq_lane_s32(a, b, vget_high_s32(v), 0); return vdotq_lane_s32(a, b, vget_high_s32(v), 0);
} }
}; };
template <> template <>
struct Vdotq_laneq_s32_armv7<3> { struct Vdotq_laneq_s32_armv7<3> {
static int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) {
__ai int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) {
return vdotq_lane_s32(a, b, vget_high_f32(v), 1); return vdotq_lane_s32(a, b, vget_high_f32(v), 1);
} }
}; };
#define vdotq_laneq_s32(a, b, v, lane) \ #define vdotq_laneq_s32(a, b, v, lane) \
Vdotq_laneq_s32_armv7<lane>::impl(a, b, v) Vdotq_laneq_s32_armv7<lane>::impl(a, b, v)


} // namespace
#endif

#endif

//! GCC split fmla with lane to dup+fmla when version < 9
//! https://gcc.gnu.org/bugzilla/show_bug.cgi?id=89101
#if !defined(__clang__) && __GNUC__ < 9
#if MEGDNN_AARCH64
namespace {

template <int lane>
struct Vfmaq_laneq_f32_armv8 {
__ai float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v);
};
template <>
struct Vfmaq_laneq_f32_armv8<0> {
__ai float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) {
asm volatile("fmla %0.4s, %1.4s, %2.s[0]\n"
: "+w"(a)
: "w"(b), "w"(v)
:);
return a;
}
};
template <>
struct Vfmaq_laneq_f32_armv8<1> {
__ai float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) {
asm volatile("fmla %0.4s, %1.4s, %2.s[1]\n"
: "+w"(a)
: "w"(b), "w"(v)
:);
return a;
}
};
template <>
struct Vfmaq_laneq_f32_armv8<2> {
__ai float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) {
asm volatile("fmla %0.4s, %1.4s, %2.s[2]\n"
: "+w"(a)
: "w"(b), "w"(v)
:);
return a;
}
};
template <>
struct Vfmaq_laneq_f32_armv8<3> {
__ai float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) {
asm volatile("fmla %0.4s, %1.4s, %2.s[3]\n"
: "+w"(a)
: "w"(b), "w"(v)
:);
return a;
}
};
} // namespace
#undef vfmaq_laneq_f32
#define vfmaq_laneq_f32(a, b, v, lane) \
Vfmaq_laneq_f32_armv8<lane>::impl(a, b, v)

#endif #endif


#endif #endif


+ 15
- 11
dnn/test/arm_common/conv_bias_multi_thread.cpp View File

@@ -77,7 +77,7 @@ std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args(
bool only_no_bias = false) { bool only_no_bias = false) {
using namespace conv_bias; using namespace conv_bias;
using NLMode = param::ConvBias::NonlineMode; using NLMode = param::ConvBias::NonlineMode;
std::vector<TestArg> args; std::vector<TestArg> args;


auto pack = [&](size_t n, size_t oc, size_t ic, size_t h, size_t w, auto pack = [&](size_t n, size_t oc, size_t ic, size_t h, size_t w,
@@ -172,11 +172,11 @@ std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args(
bias_mode.emplace_back(megdnn::BiasMode::NO_BIAS); bias_mode.emplace_back(megdnn::BiasMode::NO_BIAS);
} }
if (support_full_bias) { if (support_full_bias) {
bias_mode.emplace_back(megdnn::BiasMode::BIAS);
bias_mode.emplace_back(megdnn::BiasMode::BIAS);
} }
for (auto bias : bias_mode) for (auto bias : bias_mode)
for (auto nlmode : nonlinemode) for (auto nlmode : nonlinemode)
for (size_t n : {1,2})
for (size_t n : {1, 2})
for (size_t kernel : kernel_vec) for (size_t kernel : kernel_vec)
for (size_t oc : {4, 12}) for (size_t oc : {4, 12})
for (size_t ic : {1, 3, 4, 12}) for (size_t ic : {1, 3, 4, 12})
@@ -364,8 +364,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_SMALL_GROUP) {
} }


TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K7) { TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K7) {
check_conv_bias(get_nchw44_conv_bias_args({7}, 1, false, true, true,
false, false, false),
check_conv_bias(get_nchw44_conv_bias_args({7}, 1, false, true, true, false,
false, false),
handle(), "F32_CONV_NCHW44_DIRECT"); handle(), "F32_CONV_NCHW44_DIRECT");
} }


@@ -403,10 +403,12 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2_SMALL_GROUP) {
check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 2, false, false, false), check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 2, false, false, false),
handle(), "F32STRD2_SMALL_GROUP"); handle(), "F32STRD2_SMALL_GROUP");
} }
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_NCHW_NCHW44_F32) {
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_NCHW_NCHW44_F32_S2) {
check_conv_bias(get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, check_conv_bias(get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false,
false, true), false, true),
handle(), "F32_CONV_NCHW_NCHW44"); handle(), "F32_CONV_NCHW_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_NCHW_NCHW44_F32_S1) {
check_conv_bias(get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, check_conv_bias(get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false,
false, true), false, true),
handle(), "F32_CONV_NCHW_NCHW44"); handle(), "F32_CONV_NCHW_NCHW44");
@@ -566,13 +568,15 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QS8_CHANNEL_WISE_DIRECT2_NCHW44) {
handle(), "S8_CHAN_WISE_STRD2_NCHW44"); handle(), "S8_CHAN_WISE_STRD2_NCHW44");
} }


TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44) {
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44_S1) {
checker_conv_bias_qint8x8x8( checker_conv_bias_qint8x8x8(
get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, false,
get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, false,
true), true),
handle(), "S8_CONV_NCHW_NCHW44"); handle(), "S8_CONV_NCHW_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44_S2) {
checker_conv_bias_qint8x8x8( checker_conv_bias_qint8x8x8(
get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, false,
get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, false,
true), true),
handle(), "S8_CONV_NCHW_NCHW44"); handle(), "S8_CONV_NCHW_NCHW44");
} }
@@ -1820,7 +1824,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32) {
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S1_MK4_PACK_F32) { TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S1_MK4_PACK_F32) {
using namespace conv_bias; using namespace conv_bias;
std::vector<conv_bias::TestArg> args = get_nchw44_conv_bias_args( std::vector<conv_bias::TestArg> args = get_nchw44_conv_bias_args(
{2, 4, 7}, 1, false, false, false, false, false, true,true);
{2, 4, 7}, 1, false, false, false, false, false, true, true);
#if MEGDNN_AARCH64 #if MEGDNN_AARCH64
check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1"); check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1");
#elif MEGDNN_ARMV7 #elif MEGDNN_ARMV7
@@ -1841,7 +1845,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S2_MK4_PACK_F32) {
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S2_MK4_PACK_F32_FUSE) { TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S2_MK4_PACK_F32_FUSE) {
using namespace conv_bias; using namespace conv_bias;
std::vector<conv_bias::TestArg> args = get_nchw44_conv_bias_args( std::vector<conv_bias::TestArg> args = get_nchw44_conv_bias_args(
{3}, 2, false, false, false, false, false, true, true,false);
{3}, 2, false, false, false, false, false, true, true, false);
#if MEGDNN_AARCH64 #if MEGDNN_AARCH64
check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1"); check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1");
#elif MEGDNN_ARMV7 #elif MEGDNN_ARMV7


+ 6
- 0
toolchains/aarch64-none-linux-gnu.toolchain.cmake View File

@@ -0,0 +1,6 @@
set(ARM_CROSS_BUILD_ARCH aarch64)
set(CMAKE_C_COMPILER "aarch64-none-linux-gnu-gcc")
set(CMAKE_CXX_COMPILER "aarch64-none-linux-gnu-g++")
set(CMAKE_C_FLAGS "-Werror=unused-parameter -Wno-psabi")
set(CMAKE_CXX_FLAGS "-Werror=unused-parameter -Wno-psabi")
set(CMAKE_STRIP "aarch64-none-linux-gnu-strip")

Loading…
Cancel
Save