
refactor(dnn/arm): split arm direct kernel to cut compile time

GitOrigin-RevId: b06fba83eb
tags/v1.0.0-rc1
Megvii Engine Team, 5 years ago
parent commit: 4d56371e0b
48 changed files with 6492 additions and 5552 deletions
1. +173 -0    dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern.cpp
2. +14 -0     dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1.cpp
3. +14 -0     dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2.cpp
4. +14 -0     dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1.cpp
5. +14 -0     dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2.cpp
6. +14 -0     dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1.cpp
7. +14 -0     dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2.cpp
8. +14 -0     dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1.cpp
9. +14 -0     dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2.cpp
10. +81 -152  dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h
11. +89 -251  dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h
12. +14 -0    dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1.cpp
13. +14 -0    dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2.cpp
14. +14 -0    dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1.cpp
15. +14 -0    dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2.cpp
16. +14 -0    dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1.cpp
17. +14 -0    dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2.cpp
18. +14 -0    dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1.cpp
19. +14 -0    dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2.cpp
20. +443 -0   dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h
21. +10 -32   dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp
22. +34 -0    dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h
23. +4 -2     dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_algo.cpp
24. +12 -395  dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h
25. +0 -40    dnn/src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.h
26. +0 -40    dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.h
27. +4 -228   dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.cpp
28. +10 -16   dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h
29. +5 -9     dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_algo.cpp
30. +0 -435   dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_kern.h
31. +245 -0   dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h
32. +320 -0   dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1.cpp
33. +322 -0   dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2.cpp
34. +448 -0   dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s1.cpp
35. +437 -0   dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s2.cpp
36. +743 -0   dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw44_s1.cpp
37. +778 -0   dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw44_s2.cpp
38. +47 -0    dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_common.h
39. +561 -0   dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.cpp
40. +1412 -0  dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s2.cpp
41. +26 -57   dnn/src/arm_common/conv_bias/int8/direct_nchw44_algo.cpp
42. +7 -1337  dnn/src/arm_common/conv_bias/int8/direct_nchw44_kern.h
43. +10 -9    dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_algo.cpp
44. +3 -1854  dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h
45. +5 -4     dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp
46. +14 -662  dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h
47. +2 -2     dnn/test/arm_common/conv_bias_multi_thread.cpp
48. +23 -27   dnn/test/arm_common/matrix_mul.cpp
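The compile-time cut comes from the pattern visible in this list: each heavy templated kernel body now lives in a *_common header, and every filter-size/stride combination gets a tiny translation unit that only includes that header and expands one instantiation macro, so the kernels build as many small parallel TUs instead of one monolithic file. A minimal sketch of the pattern, with hypothetical names rather than MegEngine's actual ones:

// conv_kern_common.h: the heavy template, shared by all stubs
#pragma once

template <int filter_size, int stride>
void conv_direct_kern(const float* src, const float* flt, float* dst, int n) {
    // stand-in for a deeply unrolled NEON kernel body
    for (int i = 0; i < n; ++i) {
        dst[i] = src[i] * flt[i % (filter_size * filter_size)];
    }
}

// Each stub TU expands exactly one explicit instantiation:
#define INSTANCE_CONV_SKETCH(filter, stride)              \
    template void conv_direct_kern<filter, stride>(       \
            const float*, const float*, float*, int)

// conv_kern_3x3s1.cpp would then be, in its entirety:
//     #include "conv_kern_common.h"
//     INSTANCE_CONV_SKETCH(3, 1);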

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern.cpp (+173, -0)

@@ -0,0 +1,173 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h"
#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/fallback/conv_bias/common.h"
namespace megdnn {
namespace arm_common {
namespace conv_bias {
template <>
void pack_src_fp32_nchw44<1>(float* sptr_base, const float* sptr_origin,
const int, const int pw, const int pad_right,
const int ih, const int iw, const int iw2,
const int pad_top, const int pad_bottom,
const int ic, const int ic_stride) {
constexpr int ic_step = 4;
rep_step(ic_idx, ic, ic_step) {
const float* sptr = sptr_origin + ic_idx * ic_stride;
memset(sptr_base, 0, sizeof(float) * iw2 * pad_top * ic_step);
sptr_base += iw2 * pad_top * ic_step;
rep(ih_idx, ih) {
memset(sptr_base, 0, sizeof(float) * pw * ic_step);
sptr_base += pw * ic_step;
memcpy(sptr_base, sptr, sizeof(float) * iw * ic_step);
sptr_base += iw * ic_step;
sptr += iw * ic_step;
memset(sptr_base, 0, sizeof(float) * pad_right * ic_step);
sptr_base += pad_right * ic_step;
}
memset(sptr_base, 0, sizeof(float) * iw2 * pad_bottom * ic_step);
sptr_base += iw2 * pad_bottom * ic_step;
}
}
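// Layout produced above, per 4-channel group (a reading aid inferred from
// the code): pad_top zero rows of width iw2, then each input row stored as
// [pw zeros | iw source columns | pad_right zeros], then pad_bottom zero
// rows. Padding is materialized once so the stride-1 kernels need no
// border checks in their inner loops.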

namespace {

static inline void odd_even_split_iw8_even(float* sptr_base, const float* sptr,
const int odd_start,
const int src_idx,
const int iw_idx) {
constexpr int ic_step = 4;
const int src_offset = src_idx * ic_step;
const int even_offset = iw_idx / 2 * ic_step;
const int odd_offset = (odd_start + iw_idx / 2) * ic_step;
float32x4_t temp[8];
temp[0] = vld1q_f32(sptr + src_offset + 0 * ic_step);
temp[1] = vld1q_f32(sptr + src_offset + 1 * ic_step);
temp[2] = vld1q_f32(sptr + src_offset + 2 * ic_step);
temp[3] = vld1q_f32(sptr + src_offset + 3 * ic_step);
temp[4] = vld1q_f32(sptr + src_offset + 4 * ic_step);
temp[5] = vld1q_f32(sptr + src_offset + 5 * ic_step);
temp[6] = vld1q_f32(sptr + src_offset + 6 * ic_step);
temp[7] = vld1q_f32(sptr + src_offset + 7 * ic_step);
vst1q_f32(sptr_base + even_offset + 0 * ic_step, temp[0]);
vst1q_f32(sptr_base + even_offset + 1 * ic_step, temp[2]);
vst1q_f32(sptr_base + even_offset + 2 * ic_step, temp[4]);
vst1q_f32(sptr_base + even_offset + 3 * ic_step, temp[6]);
vst1q_f32(sptr_base + odd_offset + 0 * ic_step, temp[1]);
vst1q_f32(sptr_base + odd_offset + 1 * ic_step, temp[3]);
vst1q_f32(sptr_base + odd_offset + 2 * ic_step, temp[5]);
vst1q_f32(sptr_base + odd_offset + 3 * ic_step, temp[7]);
}

static inline void odd_even_split_iw8_odd(float* sptr_base, const float* sptr,
const int odd_start,
const int src_idx, const int iw_idx) {
constexpr int ic_step = 4;
const int src_offset = src_idx * ic_step;
const int even_offset = (iw_idx + 1) / 2 * ic_step;
const int odd_offset = (odd_start + iw_idx / 2) * ic_step;
float32x4_t temp[8];
temp[0] = vld1q_f32(sptr + src_offset + 0 * ic_step);
temp[1] = vld1q_f32(sptr + src_offset + 1 * ic_step);
temp[2] = vld1q_f32(sptr + src_offset + 2 * ic_step);
temp[3] = vld1q_f32(sptr + src_offset + 3 * ic_step);
temp[4] = vld1q_f32(sptr + src_offset + 4 * ic_step);
temp[5] = vld1q_f32(sptr + src_offset + 5 * ic_step);
temp[6] = vld1q_f32(sptr + src_offset + 6 * ic_step);
temp[7] = vld1q_f32(sptr + src_offset + 7 * ic_step);
vst1q_f32(sptr_base + odd_offset + 0 * ic_step, temp[0]);
vst1q_f32(sptr_base + odd_offset + 1 * ic_step, temp[2]);
vst1q_f32(sptr_base + odd_offset + 2 * ic_step, temp[4]);
vst1q_f32(sptr_base + odd_offset + 3 * ic_step, temp[6]);
vst1q_f32(sptr_base + even_offset + 0 * ic_step, temp[1]);
vst1q_f32(sptr_base + even_offset + 1 * ic_step, temp[3]);
vst1q_f32(sptr_base + even_offset + 2 * ic_step, temp[5]);
vst1q_f32(sptr_base + even_offset + 3 * ic_step, temp[7]);
}
} // namespace

template <>
void pack_src_fp32_nchw44<2>(float* sptr_base, const float* sptr_origin,
const int ph, const int pw, const int pad_right,
const int ih, const int iw, const int iw2,
const int pad_top, const int pad_bottom,
const int ic, const int ic_stride) {
constexpr int ic_step = 4;
int odd_start = megdnn::div_ceil(iw2, 2);
float32x4_t zero_v = vdupq_n_f32(0.f);
MEGDNN_MARK_USED_VAR(ph);
bool even_start = pw % 2 == 0;
rep_step(ic_idx, ic, ic_step) {
const float* sptr = sptr_origin + ic_idx * ic_stride;
memset(sptr_base, 0, sizeof(float) * iw2 * pad_top * ic_step);
sptr_base += iw2 * pad_top * ic_step;
rep(ih_idx, ih) {
int iw_idx = 0;
rep(idx, pw) {
if (iw_idx % 2 == 0) {
vst1q_f32(sptr_base + iw_idx / 2 * ic_step, zero_v);
} else {
vst1q_f32(sptr_base + (odd_start + iw_idx / 2) * ic_step,
zero_v);
}
++iw_idx;
}
int src_idx = 0;
if (even_start) {
for (; src_idx + 7 < iw; src_idx += 8) {
odd_even_split_iw8_even(sptr_base, sptr, odd_start, src_idx,
iw_idx);
iw_idx += 8;
}
} else {
for (; src_idx + 7 < iw; src_idx += 8) {
odd_even_split_iw8_odd(sptr_base, sptr, odd_start, src_idx,
iw_idx);
iw_idx += 8;
}
}
for (; src_idx < iw; ++src_idx) {
if (iw_idx % 2 == 0) {
vst1q_f32(sptr_base + iw_idx / 2 * ic_step,
vld1q_f32(sptr + src_idx * ic_step));
} else {
vst1q_f32(sptr_base + (odd_start + iw_idx / 2) * ic_step,
vld1q_f32(sptr + src_idx * ic_step));
}
++iw_idx;
}
rep(idx, pad_right) {
if (iw_idx % 2 == 0) {
vst1q_f32(sptr_base + iw_idx / 2 * ic_step, zero_v);
} else {
vst1q_f32(sptr_base + (odd_start + iw_idx / 2) * ic_step,
zero_v);
}
++iw_idx;
}
sptr_base += iw2 * ic_step;
sptr += iw * ic_step;
}
memset(sptr_base, 0, sizeof(float) * iw2 * pad_bottom * ic_step);
sptr_base += iw2 * pad_bottom * ic_step;
}
}

} // namespace conv_bias
} // namespace arm_common
} // namespace megdnn

// vim: syntax=cpp.doxygen
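For stride 2 the packing additionally splits every padded row by column parity: even columns are stored first and odd columns start at odd_start = div_ceil(iw2, 2), so a stride-2 kernel can walk each phase contiguously. A scalar sketch of the layout, for illustration only (it assumes the same ic_step = 4 NCHW44 grouping; the real code above does this eight columns at a time with NEON):

#include <algorithm>

// Pack one already-padded row of iw2 column groups into even/odd layout.
void pack_row_odd_even_ref(float* dst, const float* src, int iw2) {
    constexpr int ic_step = 4;            // NCHW44: 4 channels per column group
    const int odd_start = (iw2 + 1) / 2;  // same as megdnn::div_ceil(iw2, 2)
    for (int x = 0; x < iw2; ++x) {
        const int slot = (x % 2 == 0) ? x / 2 : odd_start + x / 2;
        std::copy_n(src + x * ic_step, ic_step, dst + slot * ic_step);
    }
}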

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1.cpp (+14, -0)

@@ -0,0 +1,14 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h"
INSTANTIATION_CONV_S1(2);

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2.cpp (+14, -0)

@@ -0,0 +1,14 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h"
INSTANTIATION_CONV_S2(2);

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1.cpp (+14, -0)

@@ -0,0 +1,14 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h"
INSTANTIATION_CONV_S1(3);

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2.cpp (+14, -0)

@@ -0,0 +1,14 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h"
INSTANTIATION_CONV_S2(3);

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1.cpp (+14, -0)

@@ -0,0 +1,14 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h"
INSTANTIATION_CONV_S1(5);

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2.cpp (+14, -0)

@@ -0,0 +1,14 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h"
INSTANTIATION_CONV_S2(5);

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1.cpp (+14, -0)

@@ -0,0 +1,14 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h"
INSTANTIATION_CONV_S1(7);

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2.cpp (+14, -0)

@@ -0,0 +1,14 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h"
INSTANTIATION_CONV_S2(7);

dnn/src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.cpp → dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.cpp
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
@@ -12,7 +12,7 @@
*/

#include "megdnn/arch.h"
#include "src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.h"
#include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
@@ -24,21 +24,21 @@ using namespace megdnn;
using namespace arm_common;
namespace {

template <int src_idx, int weight_idx, int c_dim, typename Func, int ow_block,
typename T, typename T2, typename T3, typename T4>
template <int src_idx, int weight_idx, int c_dim, int ow_block, typename T,
typename T2, typename T3, typename T4>
struct ShiftCalHelper {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight);
};

template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 8, T, T2, T3, T4> {
template <int src_idx, int weight_idx, typename T, typename T2, typename T3,
typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 2, 8, T, T2, T3, T4> {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) {
#define cb(step, lane) \
c[0][step] = Func::template impl<lane>(c[0][step], weight[0][lane], \
src[(step + src_idx) % 8]); \
c[1][step] = Func::template impl<lane>(c[1][step], weight[1][lane], \
src[(step + src_idx) % 8]);
#define cb(step, lane) \
c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \
src[(step + src_idx) % 8], lane); \
c[1][step] = vfmaq_laneq_f32(c[1][step], weight[1][lane], \
src[(step + src_idx) % 8], lane);

UNROLL_CALL_RAW(8, cb, 0);
UNROLL_CALL_RAW(8, cb, 1);
@@ -47,15 +47,15 @@ struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 8, T, T2, T3, T4> {
#undef cb
}
};
template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 4, T, T2, T3, T4> {
template <int src_idx, int weight_idx, typename T, typename T2, typename T3,
typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 2, 4, T, T2, T3, T4> {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) {
#define cb(step, lane) \
c[0][step] = Func::template impl<lane>(c[0][step], weight[0][lane], \
src[(step + src_idx) % 4]); \
c[1][step] = Func::template impl<lane>(c[1][step], weight[1][lane], \
src[(step + src_idx) % 4]);
#define cb(step, lane) \
c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \
src[(step + src_idx) % 4], lane); \
c[1][step] = vfmaq_laneq_f32(c[1][step], weight[1][lane], \
src[(step + src_idx) % 4], lane);

UNROLL_CALL_RAW(4, cb, 0);
UNROLL_CALL_RAW(4, cb, 1);
@@ -64,13 +64,13 @@ struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 4, T, T2, T3, T4> {
#undef cb
}
};
template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, T, T2, T3, T4> {
template <int src_idx, int weight_idx, typename T, typename T2, typename T3,
typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 1, 8, T, T2, T3, T4> {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) {
#define cb(step, lane) \
c[0][step] = Func::template impl<lane>(c[0][step], weight[0][lane], \
src[(step + src_idx) % 8]);
#define cb(step, lane) \
c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \
src[(step + src_idx) % 8], lane);

UNROLL_CALL_RAW(8, cb, 0);
UNROLL_CALL_RAW(8, cb, 1);
@@ -79,13 +79,13 @@ struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, T, T2, T3, T4> {
#undef cb
}
};
template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 4, T, T2, T3, T4> {
template <int src_idx, int weight_idx, typename T, typename T2, typename T3,
typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 1, 4, T, T2, T3, T4> {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) {
#define cb(step, lane) \
c[0][step] = Func::template impl<lane>(c[0][step], weight[0][lane], \
src[(step + src_idx) % 4]);
#define cb(step, lane) \
c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \
src[(step + src_idx) % 4], lane);

UNROLL_CALL_RAW(4, cb, 0);
UNROLL_CALL_RAW(4, cb, 1);
@@ -95,11 +95,11 @@ struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 4, T, T2, T3, T4> {
}
};

template <int src_idx, int weight_idx, int c_dim, typename FUNC, int ow_block,
typename T, typename T2, typename T3>
template <int src_idx, int weight_idx, int c_dim, int ow_block, typename T,
typename T2, typename T3>
MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight) {
ShiftCalHelper<src_idx, weight_idx, c_dim, FUNC, ow_block, T, T2, T3,
int>::impl(c, src, weight);
ShiftCalHelper<src_idx, weight_idx, c_dim, ow_block, T, T2, T3, int>::impl(
c, src, weight);
};
template <int oc>
struct OCHelper {
@@ -162,13 +162,11 @@ struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 2, oc_block, ow_block> {
0);
load_helper<ic_step, 0, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<0, 0, c_dim, ow_block>(c, src, weight);
src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step);
load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<1, 0, c_dim, ow_block>(c, src, weight);
src_ptr += ld_src_iw;
weight_ptr += ld_weight_fh;
}
@@ -209,18 +207,15 @@ struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 3, oc_block, ow_block> {
0);
load_helper<ic_step, 0, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<0, 0, c_dim, ow_block>(c, src, weight);
src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step);
load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<1, 0, c_dim, ow_block>(c, src, weight);
src[1] = vld1q_f32(src_ptr + (ow_block + 1) * ic_step);
load_helper<ic_step, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<2, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<2, 0, c_dim, ow_block>(c, src, weight);
src_ptr += ld_src_iw;
weight_ptr += ld_weight_fh;
}
@@ -260,32 +255,27 @@ struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 5, oc_block, ow_block> {
0);
load_helper<ic_step, 0, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<0, 0, c_dim, ow_block>(c, src, weight);

src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step);
load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<1, 0, c_dim, ow_block>(c, src, weight);

src[1] = vld1q_f32(src_ptr + (ow_block + 1) * ic_step);
load_helper<ic_step, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<2, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<2, 0, c_dim, ow_block>(c, src, weight);

src[2] = vld1q_f32(src_ptr + (ow_block + 2) * ic_step);
load_helper<ic_step, 3 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<3, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<3, 0, c_dim, ow_block>(c, src, weight);

src[3] = vld1q_f32(src_ptr + (ow_block + 3) * ic_step);
load_helper<ic_step, 4 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<4, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<4, 0, c_dim, ow_block>(c, src, weight);
src_ptr += ld_src_iw;
weight_ptr += ld_weight_fh;
}
@@ -326,44 +316,37 @@ struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 7, oc_block, ow_block> {
0);
load_helper<ic_step, 0, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<0, 0, c_dim, ow_block>(c, src, weight);

src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step);
load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<1, 0, c_dim, ow_block>(c, src, weight);

src[1] = vld1q_f32(src_ptr + (ow_block + 1) * ic_step);
load_helper<ic_step, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<2, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<2, 0, c_dim, ow_block>(c, src, weight);

src[2] = vld1q_f32(src_ptr + (ow_block + 2) * ic_step);
load_helper<ic_step, 3 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<3, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<3, 0, c_dim, ow_block>(c, src, weight);

src[3] = vld1q_f32(src_ptr + (ow_block + 3) * ic_step);
load_helper<ic_step, 4 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<4, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<4, 0, c_dim, ow_block>(c, src, weight);

src[4] = vld1q_f32(src_ptr + (ow_block + 4) * ic_step);
load_helper<ic_step, 5 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<5, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<5, 0, c_dim, ow_block>(c, src, weight);

src[5] = vld1q_f32(src_ptr + (ow_block + 5) * ic_step);
load_helper<ic_step, 6 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<6, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<6, 0, c_dim, ow_block>(c, src, weight);
src_ptr += ld_src_iw;
weight_ptr += ld_weight_fh;
}
@@ -375,36 +358,14 @@ struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 7, oc_block, ow_block> {

} // namespace

void conv_bias::pack_src_fp32_nchw44_stride1(
float* sptr_base, const float* sptr_origin, const int, const int pw,
const int pad_right, const int ih, const int iw, const int iw2,
const int pad_top, const int pad_bottom, const int ic,
const int ic_stride) {
constexpr int ic_step = 4;
rep_step(ic_idx, ic, ic_step) {
const float* sptr = sptr_origin + ic_idx * ic_stride;
memset(sptr_base, 0, sizeof(float) * iw2 * pad_top * ic_step);
sptr_base += iw2 * pad_top * ic_step;
rep(ih_idx, ih) {
memset(sptr_base, 0, sizeof(float) * pw * ic_step);
sptr_base += pw * ic_step;
memcpy(sptr_base, sptr, sizeof(float) * iw * ic_step);
sptr_base += iw * ic_step;
sptr += iw * ic_step;
memset(sptr_base, 0, sizeof(float) * pad_right * ic_step);
sptr_base += pad_right * ic_step;
}
memset(sptr_base, 0, sizeof(float) * iw2 * pad_bottom * ic_step);
sptr_base += iw2 * pad_bottom * ic_step;
}
}

template <BiasMode bias_mode, typename Op, int filter_size>
static void conv_direct_stride1_fp32_nchw44(
const float32_t* src, const float32_t* filter, const float32_t* bias,
float32_t*, float32_t* dst, const int oc, const int ic, const int ih,
const int iw, const int oh, const int oh_block, const int ow,
const Op& op, const int, const int) {
template <BiasMode bias_mode, typename Op, int filter_size, int stride>
void conv_bias::conv_direct_fp32_nchw44(const float* src, const float* filter,
const float* bias, float*, float* dst,
const int oc, const int ic,
const int ih, const int iw,
const int oh, const int oh_block,
const int ow, const Op& op, const int,
const int) {
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int ic_step = 4;
@@ -518,55 +479,23 @@ static void conv_direct_stride1_fp32_nchw44(
}
}

#define CONSTRUCT_FUNC(filter_size) \
template <BiasMode bias_mode, typename Op> \
void conv_bias:: \
conv_direct_stride1_##filter_size##x##filter_size##_fp32_nchw44( \
const float32_t* src, const float32_t* filter, \
const float32_t* bias, float32_t* temp, float32_t* dst, \
const int oc, const int ic, const int ih, const int iw, \
const int oh, const int oh_block, const int ow, \
const Op& op, const int ph, const int pw) { \
conv_direct_stride1_fp32_nchw44<bias_mode, Op, filter_size>( \
src, filter, bias, temp, dst, oc, ic, ih, iw, oh, oh_block, \
ow, op, ph, pw); \
}
CONSTRUCT_FUNC(2);
CONSTRUCT_FUNC(3);
CONSTRUCT_FUNC(5);
CONSTRUCT_FUNC(7);
#undef CONSTRUCT_FUNC

#define INSTANTIATION(stride, i, bias, Op) \
template void conv_bias::conv_direct_##stride##_##i##x##i##_fp32_nchw44< \
bias, Op>(const float32_t*, const float32_t*, const float32_t*, \
float32_t*, float32_t*, const int, const int, const int, \
const int, const int, const int, const int, const Op&, \
const int, const int);

#define FOR_OP(stride, i, bias) \
INSTANTIATION(stride, i, bias, NoneOp<dt_float32>) \
INSTANTIATION(stride, i, bias, ReluOp<dt_float32>) \
INSTANTIATION(stride, i, bias, HSwishOp<dt_float32>) \
INSTANTIATION(stride, i, bias, SigmoidOp<dt_float32>)

#define FOR_BIAS(stride, i) \
FOR_OP(stride, i, BiasMode::NO_BIAS) \
FOR_OP(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS) \
FOR_OP(stride, i, BiasMode::BIAS)

#define FOR_FILTER(stride) \
FOR_BIAS(stride, 2) \
FOR_BIAS(stride, 3) \
FOR_BIAS(stride, 5) \
FOR_BIAS(stride, 7)

FOR_FILTER(stride1)

#undef FOR_STRIDE
#undef FOR_FILTER
#undef FOR_IC
#undef FOR_BIAS
#undef FOR_NONLINEAR
#undef FOR_REMAIN
#undef INSTANTIATION
#define INSTANTIATION(filter_size, bias_mode, Op) \
template void \
conv_bias::conv_direct_fp32_nchw44<bias_mode, Op, filter_size, 1>( \
const float* src, const float* filter, const float* bias, float*, \
float* dst, const int oc, const int ic, const int ih, \
const int iw, const int oh, const int oh_block, const int ow, \
const Op& op, const int, const int);

#define FOR_OP(filter_size, bias) \
INSTANTIATION(filter_size, bias, NoneOp<dt_float32>) \
INSTANTIATION(filter_size, bias, ReluOp<dt_float32>) \
INSTANTIATION(filter_size, bias, HSwishOp<dt_float32>) \
INSTANTIATION(filter_size, bias, SigmoidOp<dt_float32>)

#define INSTANTIATION_CONV_S1(filter_size) \
FOR_OP(filter_size, BiasMode::NO_BIAS) \
FOR_OP(filter_size, BiasMode::BROADCAST_CHANNEL_BIAS) \
FOR_OP(filter_size, BiasMode::BIAS)

// vim: syntax=cpp.doxygen
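The net change in this hunk, repeated throughout the file: the Func functor parameter (in practice always Vfmaq_laneq_f32) is dropped from ShiftCalHelper and cal_helper, and the specializations call the NEON intrinsic directly:

// before: cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight);
// after:  cal_helper<1, 0, c_dim, ow_block>(c, src, weight);
// where each accumulation step is now hard-wired to the lane-indexed FMA:
// c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane],
//                              src[(step + src_idx) % 8], lane);

The per-filter wrapper functions (conv_direct_stride1_KxK_fp32_nchw44) and the FOR_FILTER instantiation list are likewise replaced by a single conv_direct_fp32_nchw44<bias_mode, Op, filter_size, stride> template plus the INSTANTIATION_CONV_S1(filter_size) macro, which the tiny per-filter .cpp stubs above expand one combination at a time; fewer template parameters and smaller translation units are where the compile-time saving in this file plausibly comes from.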

dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.cpp → dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.cpp
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
@@ -12,7 +12,7 @@
*/

#include "megdnn/arch.h"
#include "src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.h"
#include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
@@ -24,21 +24,21 @@ using namespace megdnn;
using namespace arm_common;
namespace {

template <int src_idx, int weight_idx, int c_dim, typename Func, int ow_block,
typename T, typename T2, typename T3, typename T4>
template <int src_idx, int weight_idx, int c_dim, int ow_block, typename T,
typename T2, typename T3, typename T4>
struct ShiftCalHelper {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight);
};

template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 8, T, T2, T3, T4> {
template <int src_idx, int weight_idx, typename T, typename T2, typename T3,
typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 2, 8, T, T2, T3, T4> {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) {
#define cb(step, lane) \
c[0][step] = Func::template impl<lane>(c[0][step], weight[0][lane], \
src[(step + src_idx) % 8]); \
c[1][step] = Func::template impl<lane>(c[1][step], weight[1][lane], \
src[(step + src_idx) % 8]);
#define cb(step, lane) \
c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \
src[(step + src_idx) % 8], lane); \
c[1][step] = vfmaq_laneq_f32(c[1][step], weight[1][lane], \
src[(step + src_idx) % 8], lane);

UNROLL_CALL_RAW(8, cb, 0);
UNROLL_CALL_RAW(8, cb, 1);
@@ -47,15 +47,15 @@ struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 8, T, T2, T3, T4> {
#undef cb
}
};
template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 4, T, T2, T3, T4> {
template <int src_idx, int weight_idx, typename T, typename T2, typename T3,
typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 2, 4, T, T2, T3, T4> {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) {
#define cb(step, lane) \
c[0][step] = Func::template impl<lane>(c[0][step], weight[0][lane], \
src[(step + src_idx) % 4]); \
c[1][step] = Func::template impl<lane>(c[1][step], weight[1][lane], \
src[(step + src_idx) % 4]);
#define cb(step, lane) \
c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \
src[(step + src_idx) % 4], lane); \
c[1][step] = vfmaq_laneq_f32(c[1][step], weight[1][lane], \
src[(step + src_idx) % 4], lane);

UNROLL_CALL_RAW(4, cb, 0);
UNROLL_CALL_RAW(4, cb, 1);
@@ -64,13 +64,13 @@ struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 4, T, T2, T3, T4> {
#undef cb
}
};
template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, T, T2, T3, T4> {
template <int src_idx, int weight_idx, typename T, typename T2, typename T3,
typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 1, 8, T, T2, T3, T4> {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) {
#define cb(step, lane) \
c[0][step] = Func::template impl<lane>(c[0][step], weight[0][lane], \
src[(step + src_idx) % 8]);
#define cb(step, lane) \
c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \
src[(step + src_idx) % 8], lane);

UNROLL_CALL_RAW(8, cb, 0);
UNROLL_CALL_RAW(8, cb, 1);
@@ -79,13 +79,13 @@ struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, T, T2, T3, T4> {
#undef cb
}
};
template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 4, T, T2, T3, T4> {
template <int src_idx, int weight_idx, typename T, typename T2, typename T3,
typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 1, 4, T, T2, T3, T4> {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) {
#define cb(step, lane) \
c[0][step] = Func::template impl<lane>(c[0][step], weight[0][lane], \
src[(step + src_idx) % 4]);
#define cb(step, lane) \
c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \
src[(step + src_idx) % 4], lane);

UNROLL_CALL_RAW(4, cb, 0);
UNROLL_CALL_RAW(4, cb, 1);
@@ -95,11 +95,11 @@ struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 4, T, T2, T3, T4> {
}
};

template <int src_idx, int weight_idx, int c_dim, typename FUNC, int ow_block,
typename T, typename T2, typename T3>
template <int src_idx, int weight_idx, int c_dim, int ow_block, typename T,
typename T2, typename T3>
MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight) {
ShiftCalHelper<src_idx, weight_idx, c_dim, FUNC, ow_block, T, T2, T3,
int>::impl(c, src, weight);
ShiftCalHelper<src_idx, weight_idx, c_dim, ow_block, T, T2, T3, int>::impl(
c, src, weight);
};
template <int oc>
struct OCHelper {
@@ -163,13 +163,13 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 2, oc_block, ow_block> {
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0);
load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr,
ld_weight_oc);
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight);
cal_helper<0, 0, c_dim, ow_block>(c, src, weight);

load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd,
0);
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight);
cal_helper<0, 0, c_dim, ow_block>(c, src, weight);
src_ptr += ld_src_iw;
src_ptr_odd += ld_src_iw;
weight_ptr += ld_weight_fh;
@@ -177,13 +177,13 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 2, oc_block, ow_block> {
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0);
load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr,
ld_weight_oc);
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight);
cal_helper<0, 0, c_dim, ow_block>(c, src, weight);

load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd,
0);
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight);
cal_helper<0, 0, c_dim, ow_block>(c, src, weight);
src_ptr += ld_src_iw;
src_ptr_odd += ld_src_iw;
weight_ptr += ld_weight_fh;
@@ -224,18 +224,18 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 3, oc_block, ow_block> {
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0);
load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr,
ld_weight_oc);
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight);
cal_helper<0, 0, c_dim, ow_block>(c, src, weight);

src[0] = vld1q_f32(src_ptr + ow_block * simd_len);
load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight);
cal_helper<1, 0, c_dim, ow_block>(c, src, weight);

load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd,
0);
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight);
cal_helper<0, 0, c_dim, ow_block>(c, src, weight);
src_ptr += ld_src_iw;
src_ptr_odd += ld_src_iw;
weight_ptr += ld_weight_fh;
@@ -243,17 +243,17 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 3, oc_block, ow_block> {
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0);
load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr,
ld_weight_oc);
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight);
cal_helper<0, 0, c_dim, ow_block>(c, src, weight);
src[0] = vld1q_f32(src_ptr + ow_block * simd_len);
load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight);
cal_helper<1, 0, c_dim, ow_block>(c, src, weight);

load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd,
0);
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight);
cal_helper<0, 0, c_dim, ow_block>(c, src, weight);
src_ptr += ld_src_iw;
src_ptr_odd += ld_src_iw;
weight_ptr += ld_weight_fh;
@@ -261,18 +261,18 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 3, oc_block, ow_block> {
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0);
load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr,
ld_weight_oc);
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight);
cal_helper<0, 0, c_dim, ow_block>(c, src, weight);
src[0] = vld1q_f32(src_ptr + ow_block * simd_len);

load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight);
cal_helper<1, 0, c_dim, ow_block>(c, src, weight);

load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd,
0);
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight);
cal_helper<0, 0, c_dim, ow_block>(c, src, weight);
src_ptr += ld_src_iw;
src_ptr_odd += ld_src_iw;
weight_ptr += ld_weight_fh;
@@ -316,30 +316,25 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 5, oc_block, ow_block> {
0);
load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr,
ld_weight_oc);
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<0, 0, c_dim, ow_block>(c, src, weight);
src[0] = vld1q_f32(src_ptr + ow_block * simd_len);
load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<1, 0, c_dim, ow_block>(c, src, weight);
src[1] = vld1q_f32(src_ptr + (ow_block + 1) * simd_len);
load_helper<4, 4 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<2, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<2, 0, c_dim, ow_block>(c, src, weight);
// odd element
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(
src, src_ptr_odd, 0);
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<0, 0, c_dim, ow_block>(c, src, weight);
src[0] = vld1q_f32(src_ptr_odd + ow_block * simd_len);
load_helper<4, 3 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<1, 0, c_dim, ow_block>(c, src, weight);

src_ptr += ld_src_iw;
src_ptr_odd += ld_src_iw;
@@ -390,40 +385,33 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 7, oc_block, ow_block> {
0);
load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr,
ld_weight_oc);
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<0, 0, c_dim, ow_block>(c, src, weight);
src[0] = vld1q_f32(src_ptr + ow_block * simd_len);
load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<1, 0, c_dim, ow_block>(c, src, weight);
src[1] = vld1q_f32(src_ptr + (ow_block + 1) * simd_len);
load_helper<4, 4 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<2, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<2, 0, c_dim, ow_block>(c, src, weight);
src[2] = vld1q_f32(src_ptr + (ow_block + 2) * simd_len);
load_helper<4, 6 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<3, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<3, 0, c_dim, ow_block>(c, src, weight);
// odd element
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(
src, src_ptr_odd, 0);
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<0, 0, c_dim, ow_block>(c, src, weight);
src[0] = vld1q_f32(src_ptr_odd + ow_block * simd_len);
load_helper<4, 3 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<1, 0, c_dim, ow_block>(c, src, weight);
src[1] = vld1q_f32(src_ptr_odd + (ow_block + 1) * simd_len);
load_helper<4, 5 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<2, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<2, 0, c_dim, ow_block>(c, src, weight);

src_ptr += ld_src_iw;
src_ptr_odd += ld_src_iw;
@@ -436,133 +424,15 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 7, oc_block, ow_block> {
};

} // namespace
namespace {

inline void odd_even_split_iw8_even(float* sptr_base, const float* sptr,
const int odd_start, const int src_idx,
const int iw_idx) {
constexpr int ic_step = 4;
const int src_offset = src_idx * ic_step;
const int even_offset = iw_idx / 2 * ic_step;
const int odd_offset = (odd_start + iw_idx / 2) * ic_step;
float32x4_t temp[8];
temp[0] = vld1q_f32(sptr + src_offset + 0 * ic_step);
temp[1] = vld1q_f32(sptr + src_offset + 1 * ic_step);
temp[2] = vld1q_f32(sptr + src_offset + 2 * ic_step);
temp[3] = vld1q_f32(sptr + src_offset + 3 * ic_step);
temp[4] = vld1q_f32(sptr + src_offset + 4 * ic_step);
temp[5] = vld1q_f32(sptr + src_offset + 5 * ic_step);
temp[6] = vld1q_f32(sptr + src_offset + 6 * ic_step);
temp[7] = vld1q_f32(sptr + src_offset + 7 * ic_step);
vst1q_f32(sptr_base + even_offset + 0 * ic_step, temp[0]);
vst1q_f32(sptr_base + even_offset + 1 * ic_step, temp[2]);
vst1q_f32(sptr_base + even_offset + 2 * ic_step, temp[4]);
vst1q_f32(sptr_base + even_offset + 3 * ic_step, temp[6]);
vst1q_f32(sptr_base + odd_offset + 0 * ic_step, temp[1]);
vst1q_f32(sptr_base + odd_offset + 1 * ic_step, temp[3]);
vst1q_f32(sptr_base + odd_offset + 2 * ic_step, temp[5]);
vst1q_f32(sptr_base + odd_offset + 3 * ic_step, temp[7]);
}

inline void odd_even_split_iw8_odd(float* sptr_base, const float* sptr,
const int odd_start, const int src_idx,
const int iw_idx) {
constexpr int ic_step = 4;
const int src_offset = src_idx * ic_step;
const int even_offset = (iw_idx + 1) / 2 * ic_step;
const int odd_offset = (odd_start + iw_idx / 2) * ic_step;
float32x4_t temp[8];
temp[0] = vld1q_f32(sptr + src_offset + 0 * ic_step);
temp[1] = vld1q_f32(sptr + src_offset + 1 * ic_step);
temp[2] = vld1q_f32(sptr + src_offset + 2 * ic_step);
temp[3] = vld1q_f32(sptr + src_offset + 3 * ic_step);
temp[4] = vld1q_f32(sptr + src_offset + 4 * ic_step);
temp[5] = vld1q_f32(sptr + src_offset + 5 * ic_step);
temp[6] = vld1q_f32(sptr + src_offset + 6 * ic_step);
temp[7] = vld1q_f32(sptr + src_offset + 7 * ic_step);
vst1q_f32(sptr_base + odd_offset + 0 * ic_step, temp[0]);
vst1q_f32(sptr_base + odd_offset + 1 * ic_step, temp[2]);
vst1q_f32(sptr_base + odd_offset + 2 * ic_step, temp[4]);
vst1q_f32(sptr_base + odd_offset + 3 * ic_step, temp[6]);
vst1q_f32(sptr_base + even_offset + 0 * ic_step, temp[1]);
vst1q_f32(sptr_base + even_offset + 1 * ic_step, temp[3]);
vst1q_f32(sptr_base + even_offset + 2 * ic_step, temp[5]);
vst1q_f32(sptr_base + even_offset + 3 * ic_step, temp[7]);
}
} // namespace

void conv_bias::pack_src_fp32_nchw44_stride2(
float* sptr_base, const float* sptr_origin, const int ph, const int pw,
const int pad_right, const int ih, const int iw, const int iw2,
const int pad_top, const int pad_bottom, const int ic,
const int ic_stride) {
constexpr int ic_step = 4;
int odd_start = megdnn::div_ceil(iw2, 2);
float32x4_t zero_v = vdupq_n_f32(0.f);
MEGDNN_MARK_USED_VAR(ph);
bool even_start = pw % 2 == 0;
rep_step(ic_idx, ic, ic_step) {
const float* sptr = sptr_origin + ic_idx * ic_stride;
memset(sptr_base, 0, sizeof(float) * iw2 * pad_top * ic_step);
sptr_base += iw2 * pad_top * ic_step;
rep(ih_idx, ih) {
int iw_idx = 0;
rep(idx, pw) {
if (iw_idx % 2 == 0) {
vst1q_f32(sptr_base + iw_idx / 2 * ic_step, zero_v);
} else {
vst1q_f32(sptr_base + (odd_start + iw_idx / 2) * ic_step,
zero_v);
}
++iw_idx;
}
int src_idx = 0;
if (even_start) {
for (; src_idx + 7 < iw; src_idx += 8) {
odd_even_split_iw8_even(sptr_base, sptr, odd_start, src_idx,
iw_idx);
iw_idx += 8;
}
} else {
for (; src_idx + 7 < iw; src_idx += 8) {
odd_even_split_iw8_odd(sptr_base, sptr, odd_start, src_idx,
iw_idx);
iw_idx += 8;
}
}
for (; src_idx < iw; ++src_idx) {
if (iw_idx % 2 == 0) {
vst1q_f32(sptr_base + iw_idx / 2 * ic_step,
vld1q_f32(sptr + src_idx * ic_step));
} else {
vst1q_f32(sptr_base + (odd_start + iw_idx / 2) * ic_step,
vld1q_f32(sptr + src_idx * ic_step));
}
++iw_idx;
}
rep(idx, pad_right) {
if (iw_idx % 2 == 0) {
vst1q_f32(sptr_base + iw_idx / 2 * ic_step, zero_v);
} else {
vst1q_f32(sptr_base + (odd_start + iw_idx / 2) * ic_step,
zero_v);
}
++iw_idx;
}
sptr_base += iw2 * ic_step;
sptr += iw * ic_step;
}
memset(sptr_base, 0, sizeof(float) * iw2 * pad_bottom * ic_step);
sptr_base += iw2 * pad_bottom * ic_step;
}
}

template <BiasMode bias_mode, typename Op, int filter_size>
static void conv_direct_stride2_fp32_nchw44(
const float32_t* src, const float32_t* filter, const float32_t* bias,
float32_t*, float32_t* dst, const int oc, const int ic, const int ih,
const int iw, const int oh, const int oh_block, const int ow,
const Op& op, const int, const int) {
template <BiasMode bias_mode, typename Op, int filter_size, int stride>
void conv_bias::conv_direct_fp32_nchw44(const float* src, const float* filter,
const float* bias, float*, float* dst,
const int oc, const int ic,
const int ih, const int iw,
const int oh, const int oh_block,
const int ow, const Op& op, const int,
const int) {
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int ic_step = 4;
@@ -697,55 +567,23 @@ static void conv_direct_stride2_fp32_nchw44(
}
}

#define CONSTRUCT_FUNC(filter_size) \
template <BiasMode bias_mode, typename Op> \
void conv_bias:: \
conv_direct_stride2_##filter_size##x##filter_size##_fp32_nchw44( \
const float32_t* src, const float32_t* filter, \
const float32_t* bias, float32_t* temp, float32_t* dst, \
const int oc, const int ic, const int ih, const int iw, \
const int oh, const int oh_block, const int ow, \
const Op& op, const int ph, const int pw) { \
conv_direct_stride2_fp32_nchw44<bias_mode, Op, filter_size>( \
src, filter, bias, temp, dst, oc, ic, ih, iw, oh, oh_block, \
ow, op, ph, pw); \
}
CONSTRUCT_FUNC(2);
CONSTRUCT_FUNC(3);
CONSTRUCT_FUNC(5);
CONSTRUCT_FUNC(7);
#undef CONSTRUCT_FUNC

#define INSTANTIATION(stride, i, bias, Op) \
template void conv_bias::conv_direct_##stride##_##i##x##i##_fp32_nchw44< \
bias, Op>(const float32_t*, const float32_t*, const float32_t*, \
float32_t*, float32_t*, const int, const int, const int, \
const int, const int, const int, const int, const Op&, \
const int, const int);

#define FOR_OP(stride, i, bias) \
INSTANTIATION(stride, i, bias, NoneOp<dt_float32>) \
INSTANTIATION(stride, i, bias, ReluOp<dt_float32>) \
INSTANTIATION(stride, i, bias, HSwishOp<dt_float32>) \
INSTANTIATION(stride, i, bias, SigmoidOp<dt_float32>)

#define FOR_BIAS(stride, i) \
FOR_OP(stride, i, BiasMode::NO_BIAS) \
FOR_OP(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS) \
FOR_OP(stride, i, BiasMode::BIAS)

#define FOR_FILTER(stride) \
FOR_BIAS(stride, 2) \
FOR_BIAS(stride, 3) \
FOR_BIAS(stride, 5) \
FOR_BIAS(stride, 7)

FOR_FILTER(stride2)

#undef FOR_STRIDE
#undef FOR_FILTER
#undef FOR_IC
#undef FOR_BIAS
#undef FOR_NONLINEAR
#undef FOR_REMAIN
#undef INSTANTIATION
#define INSTANTIATION(filter_size, bias_mode, Op) \
template void \
conv_bias::conv_direct_fp32_nchw44<bias_mode, Op, filter_size, 2>( \
const float* src, const float* filter, const float* bias, float*, \
float* dst, const int oc, const int ic, const int ih, \
const int iw, const int oh, const int oh_block, const int ow, \
const Op& op, const int, const int);

#define FOR_OP(filter_size, bias) \
INSTANTIATION(filter_size, bias, NoneOp<dt_float32>) \
INSTANTIATION(filter_size, bias, ReluOp<dt_float32>) \
INSTANTIATION(filter_size, bias, HSwishOp<dt_float32>) \
INSTANTIATION(filter_size, bias, SigmoidOp<dt_float32>)

#define INSTANTIATION_CONV_S2(filter_size) \
FOR_OP(filter_size, BiasMode::NO_BIAS) \
FOR_OP(filter_size, BiasMode::BROADCAST_CHANNEL_BIAS) \
FOR_OP(filter_size, BiasMode::BIAS)

// vim: syntax=cpp.doxygen
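The stride-2 kernels consume that parity-split packing directly. Per filter row, even input columns (read through src_ptr) pair with even filter taps, and odd columns (read through src_ptr_odd) pair with odd taps: in the 5x5 kernel above, for example, src_ptr is combined with weights at offsets 0, 2, and 4 times ld_weight, while src_ptr_odd uses offsets 1 and 3. That pairing is exactly why pack_src_fp32_nchw44<2> groups source columns by parity.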

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1.cpp (+14, -0)

@@ -0,0 +1,14 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV(2, 1);

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2.cpp (+14, -0)

@@ -0,0 +1,14 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV(2, 2);

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1.cpp (+14, -0)

@@ -0,0 +1,14 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV(3, 1);

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2.cpp (+14, -0)

@@ -0,0 +1,14 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV(3, 2);

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1.cpp (+14, -0)

@@ -0,0 +1,14 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV(5, 1);

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2.cpp (+14, -0)

@@ -0,0 +1,14 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV(5, 2);

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1.cpp (+14, -0)

@@ -0,0 +1,14 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV(7, 1);

+ 14
- 0
dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2.cpp View File

@@ -0,0 +1,14 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV(7, 2);

+ 443
- 0
dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h View File

@@ -0,0 +1,443 @@
/**
* \file dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/arch.h"
#include "src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h"

using namespace megdnn;
using namespace arm_common;

namespace {
/**
 * \brief ShiftCalHelper implements the core multiply-accumulate step
 * \tparam src_idx offset into the src registers
 * \tparam weight_idx offset into the weight registers
 * \tparam T type of the output registers
 * \tparam T2 type of the src registers
 * \tparam T3 type of the weight registers
 */
template <int src_idx, int weight_idx, int c_dim, int stride, typename T,
typename T2, typename T3>
struct ShiftCalHelper {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight);
};

template <int src_idx, int weight_idx, int stride, typename T, typename T2,
typename T3>
struct ShiftCalHelper<src_idx, weight_idx, 2, stride, T, T2, T3> {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) {
#define cb(step) \
c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][weight_idx], \
src[(step * stride + src_idx) / 4], \
(step * stride + src_idx) % 4); \
c[1][step] = vfmaq_laneq_f32(c[1][step], weight[1][weight_idx], \
src[(step * stride + src_idx) / 4], \
(step * stride + src_idx) % 4);

UNROLL_CALL_RAW(8, cb);
#undef cb
}
};
template <int src_idx, int weight_idx, int stride, typename T, typename T2,
typename T3>
struct ShiftCalHelper<src_idx, weight_idx, 1, stride, T, T2, T3> {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) {
#define cb(step) \
c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][weight_idx], \
src[(step * stride + src_idx) / 4], \
(step * stride + src_idx) % 4);

UNROLL_CALL_RAW(8, cb);
#undef cb
}
};

template <int src_idx, int weight_idx, int c_dim, int stride, typename T,
typename T2, typename T3>
MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight) {
ShiftCalHelper<src_idx, weight_idx, c_dim, stride, T, T2, T3>::impl(c, src,
weight);
};
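// Illustrative expansion (not part of the build): with c_dim = 1, src_idx = 0
// and stride = 2, cal_helper<0, 0, 1, 2>(c, src, weight) unrolls into eight
// FMAs of the form
//     c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][0],
//                                  src[(step * 2) / 4], (step * 2) % 4);
// i.e. output pixel `step` multiplies weight column 0 by lane (step * 2) % 4
// of source vector src[(step * 2) / 4].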
template <int oc>
struct OCHelper {
public:
static const int val = -1;
};

template <>
struct OCHelper<4> {
public:
static const int val = 1;
};

template <>
struct OCHelper<8> {
public:
static const int val = 2;
};
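// OCHelper maps the oc_block handled per kernel call to the number of
// 4-channel accumulator rows (c_dim): 4 -> 1 row, 8 -> 2 rows. The default
// val = -1 makes any unsupported block size fail at compile time when the
// accumulator array c[c_dim][8] is declared.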
/**
 * oc8_ow8 (m = 8, n = 8) and oc4_ow8 (m = 4, n = 8) GEMM-like kernels
 **/
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
int oc_block, int stride, int ow_block>
struct KerNeonXXs2NchwNchw44FP32 {
static void impl(const float32_t* src_ptr, const float32_t* weight_ptr,
const float32_t* bias_ptr, float32_t* dst_ptr, int ic,
int ih, int iw, int ld_dst_oc, const Op& op);
};
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int stride, int ow_block>
struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 7, oc_block, stride,
ow_block> {
static void impl(const float32_t* src_ptr, const float32_t* weight_ptr,
const float32_t* bias_ptr, float32_t* dst_ptr, int ic,
int ih, int iw, int ld_dst_oc, const Op& op) {
constexpr int loop_ic_step = 1;
constexpr int filter_size = 7;
constexpr int oc_step = 4;
constexpr int simd_len = 4;
constexpr int src_reg_size =
(ow_block * stride + filter_size - stride + simd_len - 1) /
simd_len;
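        // (ow_block - 1) * stride + filter_size input columns feed ow_block
        // outputs; src_reg_size rounds that up to whole float32x4 vectors.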

constexpr int ld_weight_fw = oc_step * filter_size;
const int ld_weight_oc = oc_step * filter_size * filter_size * ic;
const int ld_weight_ic = oc_step * filter_size * filter_size;
const int ld_src_ic = ih * iw;
constexpr int c_dim = OCHelper<oc_block>::val;
float32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);
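        // c caches ow_block (8) output pixels for each of the c_dim 4-channel
        // output blocks; init_ocx_ow8 seeds the accumulators (from the bias or
        // from zero, depending on bias_mode).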

for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
float32x4_t src[src_reg_size];
float32x4_t weight[c_dim][filter_size];

#define KERNEL_CB(step) \
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>( \
src, src_ptr + step * iw, 0); \
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( \
weight, weight_ptr + step * ld_weight_fw, ld_weight_oc); \
cal_helper<0, 0, c_dim, stride>(c, src, weight); \
cal_helper<1, 1, c_dim, stride>(c, src, weight); \
cal_helper<2, 2, c_dim, stride>(c, src, weight); \
cal_helper<3, 3, c_dim, stride>(c, src, weight); \
cal_helper<4, 4, c_dim, stride>(c, src, weight); \
cal_helper<5, 5, c_dim, stride>(c, src, weight); \
cal_helper<6, 6, c_dim, stride>(c, src, weight);

UNROLL_CALL_RAW(7, KERNEL_CB)
#undef KERNEL_CB

src_ptr += ld_src_ic;
weight_ptr += ld_weight_ic;
}
store_ocx_ow8_remain_static<c_dim, remain_w, Op>(c, op, dst_ptr,
ld_dst_oc);
}
};

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int stride, int ow_block>
struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 5, oc_block, stride,
ow_block> {
static void impl(const float32_t* src_ptr, const float32_t* weight_ptr,
const float32_t* bias_ptr, float32_t* dst_ptr, int ic,
int ih, int iw, int ld_dst_oc, const Op& op) {
constexpr int loop_ic_step = 1;
constexpr int filter_size = 5;
constexpr int oc_step = 4;
constexpr int simd_len = 4;
constexpr int src_reg_size =
(ow_block * stride + filter_size - stride + simd_len - 1) /
simd_len;

constexpr int ld_weight_fw = oc_step * filter_size;
const int ld_weight_oc = oc_step * filter_size * filter_size * ic;
const int ld_weight_ic = oc_step * filter_size * filter_size;
const int ld_src_ic = ih * iw;
constexpr int c_dim = OCHelper<oc_block>::val;
float32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);

for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
float32x4_t src[src_reg_size];
float32x4_t weight[c_dim][filter_size];

#define KERNEL_CB(step) \
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>( \
src, src_ptr + step * iw, 0); \
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( \
weight, weight_ptr + step * ld_weight_fw, ld_weight_oc); \
cal_helper<0, 0, c_dim, stride>(c, src, weight); \
cal_helper<1, 1, c_dim, stride>(c, src, weight); \
cal_helper<2, 2, c_dim, stride>(c, src, weight); \
cal_helper<3, 3, c_dim, stride>(c, src, weight); \
cal_helper<4, 4, c_dim, stride>(c, src, weight);
UNROLL_CALL_RAW(5, KERNEL_CB)
#undef KERNEL_CB

src_ptr += ld_src_ic;
weight_ptr += ld_weight_ic;
}
store_ocx_ow8_remain_static<c_dim, remain_w, Op>(c, op, dst_ptr,
ld_dst_oc);
}
};

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int stride, int ow_block>
struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 3, oc_block, stride,
ow_block> {
static void impl(const float32_t* src_ptr, const float32_t* weight_ptr,
const float32_t* bias_ptr, float32_t* dst_ptr, int ic,
int ih, int iw, int ld_dst_oc, const Op& op) {
constexpr int loop_ic_step = 1;
constexpr int filter_size = 3;
constexpr int oc_step = 4;
constexpr int simd_len = 4;
constexpr int src_reg_size =
(ow_block * stride + filter_size - stride + simd_len - 1) /
simd_len;

constexpr int ld_weight_fw = oc_step * filter_size;
const int ld_weight_oc = oc_step * filter_size * filter_size * ic;
const int ld_weight_ic = oc_step * filter_size * filter_size;
const int ld_src_ic = ih * iw;
constexpr int c_dim = OCHelper<oc_block>::val;
float32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);

for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
float32x4_t src[src_reg_size];
float32x4_t weight[c_dim][filter_size];
// row 0
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(src, src_ptr,
0);
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, stride>(c, src, weight);
cal_helper<1, 1, c_dim, stride>(c, src, weight);
cal_helper<2, 2, c_dim, stride>(c, src, weight);

// row 1
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(
src, src_ptr + iw, 0);
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr + 1 * ld_weight_fw, ld_weight_oc);
cal_helper<0, 0, c_dim, stride>(c, src, weight);
cal_helper<1, 1, c_dim, stride>(c, src, weight);
cal_helper<2, 2, c_dim, stride>(c, src, weight);

// row 2
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(
src, src_ptr + 2 * iw, 0);
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr + 2 * ld_weight_fw, ld_weight_oc);
cal_helper<0, 0, c_dim, stride>(c, src, weight);
cal_helper<1, 1, c_dim, stride>(c, src, weight);
cal_helper<2, 2, c_dim, stride>(c, src, weight);

src_ptr += ld_src_ic;
weight_ptr += ld_weight_ic;
}
store_ocx_ow8_remain_static<c_dim, remain_w, Op>(c, op, dst_ptr,
ld_dst_oc);
}
};

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int stride, int ow_block>
struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 2, oc_block, stride,
ow_block> {
static void impl(const float32_t* src_ptr, const float32_t* weight_ptr,
const float32_t* bias_ptr, float32_t* dst_ptr, int ic,
int ih, int iw, int ld_dst_oc, const Op& op) {
constexpr int loop_ic_step = 1;
constexpr int filter_size = 2;
constexpr int oc_step = 4;
constexpr int simd_len = 4;
constexpr int src_reg_size =
(ow_block * stride + filter_size - stride + simd_len - 1) /
simd_len;

constexpr int ld_weight_fw = oc_step * filter_size;
const int ld_weight_oc = oc_step * filter_size * filter_size * ic;
const int ld_weight_ic = oc_step * filter_size * filter_size;
const int ld_src_ic = ih * iw;
constexpr int c_dim = OCHelper<oc_block>::val;
float32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);

for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
float32x4_t src[src_reg_size];
float32x4_t weight[c_dim][filter_size];
// row 0
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(src, src_ptr,
0);
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, stride>(c, src, weight);
cal_helper<1, 1, c_dim, stride>(c, src, weight);

// row 1
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(
src, src_ptr + iw, 0);
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr + 1 * ld_weight_fw, ld_weight_oc);
cal_helper<0, 0, c_dim, stride>(c, src, weight);
cal_helper<1, 1, c_dim, stride>(c, src, weight);

src_ptr += ld_src_ic;
weight_ptr += ld_weight_ic;
}
store_ocx_ow8_remain_static<c_dim, remain_w, Op>(c, op, dst_ptr,
ld_dst_oc);
}
};

} // namespace

template <BiasMode bias_mode, typename Op, int filter_size, int stride>
void fp32_direct_nchw_nchw44::conv_direct_fp32_nchw_nchw44(
const float32_t* src, const float32_t* filter, const float32_t* bias,
float32_t*, float32_t* dst, const int oc, const int ic, const int ih,
const int iw, const int oh, const int oh_block, const int ow,
const Op& op, const int, const int) {
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int ic_step = 1;
constexpr int big_oc_step = 8;
constexpr int oc_step = 4;
constexpr int ih_step = 1;
constexpr int oh_step = 1;
constexpr int ow_step = 8;
constexpr int stride_h = stride;
constexpr int stride_w = stride;
constexpr int pack_iw_len = 1;

const int img_stride = oh * ow;
const int ow_end = ow / ow_step * ow_step;
const int ow_remain = ow - ow_end;
const int oc_end = oc / big_oc_step * big_oc_step;
const int oc_remain = oc - oc_end;
const int ld_dst_oc = oc_step * img_stride;

using remain_fun = std::function<void(
const float32_t* src_ptr, const float32_t* weight_ptr,
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op)>;
remain_fun kern_big_oc_remain = nullptr;
remain_fun kern_small_oc_remain = nullptr;

switch (ow_remain) {
#define cb(step) \
case step: \
kern_big_oc_remain = \
KerNeonXXs2NchwNchw44FP32<bias_mode, Op, step, filter_size, \
big_oc_step, stride, ow_step>::impl; \
kern_small_oc_remain = \
KerNeonXXs2NchwNchw44FP32<bias_mode, Op, step, filter_size, \
oc_step, stride, ow_step>::impl; \
break;

UNROLL_CALL_RAW(8, cb);
default:
megdnn_assert(0, "no remain %d for kern", ow_remain);
}
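    // The switch above binds the runtime value ow_remain to kernels whose
    // width remainder is a compile-time template argument, so the unrolled
    // kernel bodies below never branch per pixel.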
for (int oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) {
const int weight_offset = oc_idx * ic * fh * fw;
for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) {
for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const int src_offset =
(oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) *
ic_step * pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
KerNeonXXs2NchwNchw44FP32<bias_mode, Op, ow_step, filter_size,
big_oc_step, stride,
ow_step>::impl(src + src_offset,
filter + weight_offset,
bias + oc_idx,
dst + dst_offset, ic,
ih, iw, ld_dst_oc, op);
}
if (ow_remain > 0) {
const int src_offset =
(oh_idx * stride_h * iw + ow_end * stride_w * ih_step) *
ic_step * pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
kern_big_oc_remain(src + src_offset, filter + weight_offset,
bias + oc_idx, dst + dst_offset, ic, ih, iw,
ld_dst_oc, op);
}
}
}
if (oc_remain > 0) {
int oc_idx = oc_end;
const int weight_offset = oc_idx * ic * fh * fw;
for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) {
for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const int src_offset =
(oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) *
ic_step * pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
KerNeonXXs2NchwNchw44FP32<bias_mode, Op, ow_step, filter_size,
oc_step, stride,
ow_step>::impl(src + src_offset,
filter + weight_offset,
bias + oc_idx,
dst + dst_offset, ic,
ih, iw, ld_dst_oc, op);
}
if (ow_remain > 0) {
const int src_offset =
(oh_idx * stride_h * iw + ow_end * stride_w * ih_step) *
ic_step * pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
kern_small_oc_remain(src + src_offset, filter + weight_offset,
bias + oc_idx, dst + dst_offset, ic, ih,
iw, ld_dst_oc, op);
}
}
}
}

#define INSTANTIATION(stride, filter_size, bias_mode, Op) \
template void fp32_direct_nchw_nchw44::conv_direct_fp32_nchw_nchw44< \
bias_mode, Op, filter_size, stride>( \
const float32_t* src, const float32_t* filter, \
const float32_t* bias, float32_t*, float32_t* dst, const int oc, \
const int ic, const int ih, const int iw, const int oh, \
const int oh_block, const int ow, const Op& op, const int, \
const int);

#define FOR_OP(stride, filter, bias) \
INSTANTIATION(stride, filter, bias, NoneOp<dt_float32>) \
INSTANTIATION(stride, filter, bias, ReluOp<dt_float32>) \
INSTANTIATION(stride, filter, bias, HSwishOp<dt_float32>)

#define INSTANCE_CONV(filter, stride) \
FOR_OP(stride, filter, BiasMode::NO_BIAS) \
FOR_OP(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) \
FOR_OP(stride, filter, BiasMode::BIAS)
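
// Expansion sketch: each per-shape TU (e.g.
// f32_direct_nchw_nchw44_kern_3x3s1.cpp) contains only INSTANCE_CONV(3, 1),
// which expands through FOR_OP into nine explicit instantiations of
// conv_direct_fp32_nchw_nchw44 (3 bias modes x 3 ops) for filter_size = 3,
// stride = 1; spreading these across files is what cuts the compile time of
// the former single TU.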

// vim: syntax=cpp.doxygen

+ 10
- 32
dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp View File

@@ -13,8 +13,8 @@
#include "megdnn/oprs.h"
#include "src/arm_common/conv_bias/block_helper.h"
#include "src/arm_common/conv_bias/fp32/algos.h"
#include "src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.h"
#include "src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.h"
#include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h"
#include "src/arm_common/elemwise_op.h"

#include "midout.h"
@@ -112,17 +112,11 @@ static void do_conv_kern(const WorkspaceBundle& bundle,
const size_t src_size = get_perthread_cache_bytes(ic, ih2, iw2);
float* sptr = reinterpret_cast<float*>((int8_t*)bundle.get(0) +
ncb_index.thread_id * src_size);
if (stride == 1) {
conv_bias::pack_src_fp32_nchw44_stride1(
sptr, origin_sptr, ph, pw, remain_right_pad,
ih_real - src_top_pad - src_bottom_pad, iw, iw2, src_top_pad,
src_bottom_pad, ic, ih * iw);
} else {
conv_bias::pack_src_fp32_nchw44_stride2(
sptr, origin_sptr, ph, pw, remain_right_pad,
ih_real - src_top_pad - src_bottom_pad, iw, iw2, src_top_pad,
src_bottom_pad, ic, ih * iw);
}

conv_bias::pack_src_fp32_nchw44<stride>(
sptr, origin_sptr, ph, pw, remain_right_pad,
ih_real - src_top_pad - src_bottom_pad, iw, iw2, src_top_pad,
src_bottom_pad, ic, ih * iw);

const float* fptr =
kern_param.filter<dt_float32>(group_id) + oc_idx * fh * fw * ic;
@@ -135,25 +129,9 @@ static void do_conv_kern(const WorkspaceBundle& bundle,
kern_param.bias<dt_float32>(batch_id, group_id) + bias_offset;

Op op;
if (stride == 1) {
#define KERN1_NCHW44_CONV(filter) \
conv_bias::conv_direct_stride1_##filter##x##filter##_fp32_nchw44< \
\
bias_mode, Op>(sptr, fptr, bptr, nullptr, dst, oc_block, ic, \
ih_real, iw2, oh, oh_block_real, ow, op, ph, pw)

DISPATCH_FILTER(filter, KERN1_NCHW44_CONV);
#undef KERN1_NCHW44_CONV
} else {
#define KERN1_NCHW44_CONV(filter) \
conv_bias::conv_direct_stride2_##filter##x##filter##_fp32_nchw44< \
\
bias_mode, Op>(sptr, fptr, bptr, nullptr, dst, oc_block, ic, \
ih_real, iw2, oh, oh_block_real, ow, op, ph, pw)

DISPATCH_FILTER(filter, KERN1_NCHW44_CONV);
#undef KERN1_NCHW44_CONV
}
conv_bias::conv_direct_fp32_nchw44<bias_mode, Op, filter, stride>(
sptr, fptr, bptr, nullptr, dst, oc_block, ic, ih_real, iw2, oh,
oh_block_real, ow, op, ph, pw);
}

} // namespace


+ 34
- 0
dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h View File

@@ -0,0 +1,34 @@
/**
* \file dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/fallback/conv_bias/common.h"
namespace megdnn {
namespace arm_common {
namespace conv_bias {

template <BiasMode bias_mode, typename Op, int filter_size, int stride>
void conv_direct_fp32_nchw44(const float* src, const float* filter,
const float* bias, float*, float* dst,
const int oc, const int ic, const int ih,
const int iw, const int oh, const int oh_block,
const int ow, const Op& op, const int, const int);
template <int stride>
void pack_src_fp32_nchw44(float* sptr_base, const float* sptr_origin, const int,
const int pw, const int pad_right, const int ih,
const int iw, const int iw2, const int pad_top,
const int pad_bottom, const int ic,
const int ic_stride);
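
// Usage sketch (hypothetical shapes; the real call sites live in
// f32_direct_nchw44_algo.cpp): pack the padded source, then run the kernel:
//     conv_bias::pack_src_fp32_nchw44<2>(sptr, origin_sptr, ph, pw, pad_right,
//                                        ih, iw, iw2, pad_top, pad_bottom, ic,
//                                        ih * iw);
//     conv_bias::conv_direct_fp32_nchw44<BiasMode::NO_BIAS,
//                                        NoneOp<dt_float32>, 3, 2>(
//             sptr, fptr, bptr, nullptr, dst, oc, ic, ih, iw2, oh, oh_block,
//             ow, op, ph, pw);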

} // namespace conv_bias
} // namespace arm_common
} // namespace megdnn

+ 4
- 2
dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_algo.cpp View File

@@ -120,7 +120,8 @@ static void pack_weight(const WorkspaceBundle& bundle,
kern_param.filter<dt_float32>(group_id) + oc_idx * fh * fw * ic;
auto packed_weight = reinterpret_cast<float*>(bundle.get(1)) +
group_id * oc * ic * fh * fw + oc_idx * ic * fh * fw;
pack_weight_fp32_nchw_nchw44(fptr, packed_weight, oc_block, fh, fw, ic);
fp32_direct_nchw_nchw44::pack_weight_fp32_nchw_nchw44(fptr, packed_weight,
oc_block, fh, fw, ic);
}

template <size_t filter_size, BiasMode bias_mode, typename Op, size_t stride>
@@ -180,7 +181,8 @@ static void do_conv_kern(const WorkspaceBundle& bundle,
kern_param.bias<dt_float32>(batch_id, group_id) + oc_idx;
Op op;

conv_direct_fp32_nchw_nchw44<bias_mode, Op, filter_size, stride>(
fp32_direct_nchw_nchw44::conv_direct_fp32_nchw_nchw44<bias_mode, Op,
filter_size, stride>(
sptr, packed_weight, bptr, nullptr, dst, oc_block, ic, ih_real, iw2,
oh, oh_block_real, ow, op, ph, pw);
}


+ 12
- 395
dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h View File

@@ -20,295 +20,12 @@
#include "src/fallback/conv_bias/common.h"
namespace megdnn {
namespace arm_common {
namespace {
/**
*\brief ShiftCalHelper is core calculate code
*\tparam src_idx is offset for src regs
*\tparam weight_idx is offset for weight regs
*\tparam T is type of output regs
*\tparam T2 is type of src regs
*\tparam T3 is type of weight regs
*/
template <int src_idx, int weight_idx, int c_dim, typename Func, int stride,
typename T, typename T2, typename T3>
struct ShiftCalHelper {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight);
};

template <int src_idx, int weight_idx, typename Func, int stride, typename T,
typename T2, typename T3>
struct ShiftCalHelper<src_idx, weight_idx, 2, Func, stride, T, T2, T3> {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) {
#define cb(step) \
c[0][step] = Func::template impl<(step * stride + src_idx) % 4>( \
c[0][step], weight[0][weight_idx], \
src[(step * stride + src_idx) / 4]); \
c[1][step] = Func::template impl<(step * stride + src_idx) % 4>( \
c[1][step], weight[1][weight_idx], \
src[(step * stride + src_idx) / 4]);

UNROLL_CALL_RAW(8, cb);
#undef cb
}
};
template <int src_idx, int weight_idx, typename Func, int stride, typename T,
typename T2, typename T3>
struct ShiftCalHelper<src_idx, weight_idx, 1, Func, stride, T, T2, T3> {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) {
#define cb(step) \
c[0][step] = Func::template impl<(step * stride + src_idx) % 4>( \
c[0][step], weight[0][weight_idx], \
src[(step * stride + src_idx) / 4]);

UNROLL_CALL_RAW(8, cb);
#undef cb
}
};

template <int src_idx, int weight_idx, int c_dim, typename FUNC, int stride,
typename T, typename T2, typename T3>
MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight) {
ShiftCalHelper<src_idx, weight_idx, c_dim, FUNC, stride, T, T2, T3>::impl(
c, src, weight);
};
template <int oc>
struct OCHelper {
public:
static const int val = -1;
};

template <>
struct OCHelper<4> {
public:
static const int val = 1;
};

template <>
struct OCHelper<8> {
public:
static const int val = 2;
};
/**
* oc8_ow8(m = 8, n = 8) and oc4_ow8(m = 4, n = 8) gemm like kernel
**/
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
int oc_block, int stride, int ow_block>
struct KerNeonXXs2NchwNchw44FP32 {
static void impl(const float32_t* src_ptr, const float32_t* weight_ptr,
const float32_t* bias_ptr, float32_t* dst_ptr, int ic,
int ih, int iw, int ld_dst_oc, const Op& op);
};
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int stride, int ow_block>
struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 7, oc_block, stride,
ow_block> {
static void impl(const float32_t* src_ptr, const float32_t* weight_ptr,
const float32_t* bias_ptr, float32_t* dst_ptr, int ic,
int ih, int iw, int ld_dst_oc, const Op& op) {
constexpr int loop_ic_step = 1;
constexpr int filter_size = 7;
constexpr int oc_step = 4;
constexpr int simd_len = 4;
constexpr int src_reg_size =
(ow_block * stride + filter_size - stride + simd_len - 1) /
simd_len;

constexpr int ld_weight_fw = oc_step * filter_size;
const int ld_weight_oc = oc_step * filter_size * filter_size * ic;
const int ld_weight_ic = oc_step * filter_size * filter_size;
const int ld_src_ic = ih * iw;
constexpr int c_dim = OCHelper<oc_block>::val;
float32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);

for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
float32x4_t src[src_reg_size];
float32x4_t weight[c_dim][filter_size];

#define KERNEL_CB(step) \
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>( \
src, src_ptr + step * iw, 0); \
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( \
weight, weight_ptr + step * ld_weight_fw, ld_weight_oc); \
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \
cal_helper<1, 1, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \
cal_helper<2, 2, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \
cal_helper<3, 3, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \
cal_helper<4, 4, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \
cal_helper<5, 5, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \
cal_helper<6, 6, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight);

UNROLL_CALL_RAW(7, KERNEL_CB)
#undef KERNEL_CB

src_ptr += ld_src_ic;
weight_ptr += ld_weight_ic;
}
store_ocx_ow8_remain_static<c_dim, remain_w, Op>(c, op, dst_ptr,
ld_dst_oc);
}
};
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int stride, int ow_block>
struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 5, oc_block, stride,
ow_block> {
static void impl(const float32_t* src_ptr, const float32_t* weight_ptr,
const float32_t* bias_ptr, float32_t* dst_ptr, int ic,
int ih, int iw, int ld_dst_oc, const Op& op) {
constexpr int loop_ic_step = 1;
constexpr int filter_size = 5;
constexpr int oc_step = 4;
constexpr int simd_len = 4;
constexpr int src_reg_size =
(ow_block * stride + filter_size - stride + simd_len - 1) /
simd_len;

constexpr int ld_weight_fw = oc_step * filter_size;
const int ld_weight_oc = oc_step * filter_size * filter_size * ic;
const int ld_weight_ic = oc_step * filter_size * filter_size;
const int ld_src_ic = ih * iw;
constexpr int c_dim = OCHelper<oc_block>::val;
float32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);

for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
float32x4_t src[src_reg_size];
float32x4_t weight[c_dim][filter_size];

#define KERNEL_CB(step) \
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>( \
src, src_ptr + step * iw, 0); \
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( \
weight, weight_ptr + step * ld_weight_fw, ld_weight_oc); \
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \
cal_helper<1, 1, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \
cal_helper<2, 2, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \
cal_helper<3, 3, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \
cal_helper<4, 4, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight);
UNROLL_CALL_RAW(5, KERNEL_CB)
#undef KERNEL_CB

src_ptr += ld_src_ic;
weight_ptr += ld_weight_ic;
}
store_ocx_ow8_remain_static<c_dim, remain_w, Op>(c, op, dst_ptr,
ld_dst_oc);
}
};

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int stride, int ow_block>
struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 3, oc_block, stride,
ow_block> {
static void impl(const float32_t* src_ptr, const float32_t* weight_ptr,
const float32_t* bias_ptr, float32_t* dst_ptr, int ic,
int ih, int iw, int ld_dst_oc, const Op& op) {
constexpr int loop_ic_step = 1;
constexpr int filter_size = 3;
constexpr int oc_step = 4;
constexpr int simd_len = 4;
constexpr int src_reg_size =
(ow_block * stride + filter_size - stride + simd_len - 1) /
simd_len;

constexpr int ld_weight_fw = oc_step * filter_size;
const int ld_weight_oc = oc_step * filter_size * filter_size * ic;
const int ld_weight_ic = oc_step * filter_size * filter_size;
const int ld_src_ic = ih * iw;
constexpr int c_dim = OCHelper<oc_block>::val;
float32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);

for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
float32x4_t src[src_reg_size];
float32x4_t weight[c_dim][filter_size];
// row 0
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(src, src_ptr,
0);
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight);
cal_helper<1, 1, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight);
cal_helper<2, 2, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight);
namespace fp32_direct_nchw_nchw44 {

// row 1
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(
src, src_ptr + iw, 0);
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr + 1 * ld_weight_fw, ld_weight_oc);
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight);
cal_helper<1, 1, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight);
cal_helper<2, 2, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight);

// row 2
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(
src, src_ptr + 2 * iw, 0);
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr + 2 * ld_weight_fw, ld_weight_oc);
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight);
cal_helper<1, 1, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight);
cal_helper<2, 2, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight);

src_ptr += ld_src_ic;
weight_ptr += ld_weight_ic;
}
store_ocx_ow8_remain_static<c_dim, remain_w, Op>(c, op, dst_ptr,
ld_dst_oc);
}
};

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int stride, int ow_block>
struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 2, oc_block, stride,
ow_block> {
static void impl(const float32_t* src_ptr, const float32_t* weight_ptr,
const float32_t* bias_ptr, float32_t* dst_ptr, int ic,
int ih, int iw, int ld_dst_oc, const Op& op) {
constexpr int loop_ic_step = 1;
constexpr int filter_size = 2;
constexpr int oc_step = 4;
constexpr int simd_len = 4;
constexpr int src_reg_size =
(ow_block * stride + filter_size - stride + simd_len - 1) /
simd_len;

constexpr int ld_weight_fw = oc_step * filter_size;
const int ld_weight_oc = oc_step * filter_size * filter_size * ic;
const int ld_weight_ic = oc_step * filter_size * filter_size;
const int ld_src_ic = ih * iw;
constexpr int c_dim = OCHelper<oc_block>::val;
float32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);

for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
float32x4_t src[src_reg_size];
float32x4_t weight[c_dim][filter_size];
// row 0
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(src, src_ptr,
0);
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight);
cal_helper<1, 1, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight);

// row 1
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(
src, src_ptr + iw, 0);
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr + 1 * ld_weight_fw, ld_weight_oc);
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight);
cal_helper<1, 1, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight);

src_ptr += ld_src_ic;
weight_ptr += ld_weight_ic;
}
store_ocx_ow8_remain_static<c_dim, remain_w, Op>(c, op, dst_ptr,
ld_dst_oc);
}
};
void pack_weight_fp32_nchw_nchw44(const float32_t* in_ptr, float32_t* dst_ptr,
const int oc, const int kh, const int kw,
const int ic) {
static inline void pack_weight_fp32_nchw_nchw44(const float32_t* in_ptr,
float32_t* dst_ptr,
const int oc, const int kh,
const int kw, const int ic) {
constexpr int oc_step = 4;
const int filter_oc_stride = kh * kw * ic;
const int filter_ic_stride = kh * kw * oc_step;
@@ -327,115 +44,15 @@ void pack_weight_fp32_nchw_nchw44(const float32_t* in_ptr, float32_t* dst_ptr,
}
}
}

template <BiasMode bias_mode, typename Op, int filter_size, int stride>
static void conv_direct_fp32_nchw_nchw44(
const float32_t* src, const float32_t* filter, const float32_t* bias,
float32_t*, float32_t* dst, const int oc, const int ic, const int ih,
const int iw, const int oh, const int oh_block, const int ow,
const Op& op, const int, const int) {
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int ic_step = 1;
constexpr int big_oc_step = 8;
constexpr int oc_step = 4;
constexpr int ih_step = 1;
constexpr int oh_step = 1;
constexpr int ow_step = 8;
constexpr int stride_h = stride;
constexpr int stride_w = stride;
constexpr int pack_iw_len = 1;
void conv_direct_fp32_nchw_nchw44(const float32_t* src, const float32_t* filter,
const float32_t* bias, float32_t*,
float32_t* dst, const int oc, const int ic,
const int ih, const int iw, const int oh,
const int oh_block, const int ow,
const Op& op, const int, const int);
} // namespace fp32_direct_nchw_nchw44

const int img_stride = oh * ow;
const int ow_end = ow / ow_step * ow_step;
const int ow_remain = ow - ow_end;
const int oc_end = oc / big_oc_step * big_oc_step;
const int oc_remain = oc - oc_end;
const int ld_dst_oc = oc_step * img_stride;

using remain_fun = std::function<void(
const float32_t* src_ptr, const float32_t* weight_ptr,
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op)>;
remain_fun kern_big_oc_remain = nullptr;
remain_fun kern_small_oc_remain = nullptr;

switch (ow_remain) {
#define cb(step) \
case step: \
kern_big_oc_remain = \
KerNeonXXs2NchwNchw44FP32<bias_mode, Op, step, filter_size, \
big_oc_step, stride, ow_step>::impl; \
kern_small_oc_remain = \
KerNeonXXs2NchwNchw44FP32<bias_mode, Op, step, filter_size, \
oc_step, stride, ow_step>::impl; \
break;

UNROLL_CALL_RAW(8, cb);
default:
megdnn_assert(0, "no remain %d for kern", ow_remain);
}
for (int oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) {
const int weight_offset = oc_idx * ic * fh * fw;
for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) {
for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const int src_offset =
(oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) *
ic_step * pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
KerNeonXXs2NchwNchw44FP32<bias_mode, Op, ow_step, filter_size,
big_oc_step, stride,
ow_step>::impl(src + src_offset,
filter + weight_offset,
bias + oc_idx,
dst + dst_offset, ic,
ih, iw, ld_dst_oc, op);
}
if (ow_remain > 0) {
const int src_offset =
(oh_idx * stride_h * iw + ow_end * stride_w * ih_step) *
ic_step * pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
kern_big_oc_remain(src + src_offset, filter + weight_offset,
bias + oc_idx, dst + dst_offset, ic, ih, iw,
ld_dst_oc, op);
}
}
}
if (oc_remain > 0) {
int oc_idx = oc_end;
const int weight_offset = oc_idx * ic * fh * fw;
for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) {
for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const int src_offset =
(oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) *
ic_step * pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
KerNeonXXs2NchwNchw44FP32<bias_mode, Op, ow_step, filter_size,
oc_step, stride,
ow_step>::impl(src + src_offset,
filter + weight_offset,
bias + oc_idx,
dst + dst_offset, ic,
ih, iw, ld_dst_oc, op);
}
if (ow_remain > 0) {
const int src_offset =
(oh_idx * stride_h * iw + ow_end * stride_w * ih_step) *
ic_step * pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
kern_small_oc_remain(src + src_offset, filter + weight_offset,
bias + oc_idx, dst + dst_offset, ic, ih,
iw, ld_dst_oc, op);
}
}
}
}
} // namespace
} // namespace arm_common
} // namespace megdnn
// vim: syntax=cpp.doxygen

+ 0
- 40
dnn/src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.h View File

@@ -1,40 +0,0 @@
/**
* \file dnn/src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/fallback/conv_bias/common.h"
namespace megdnn {
namespace arm_common {
namespace conv_bias {
#define KERN(stride, i, layout) \
template <BiasMode bias_mode, typename Op> \
void conv_direct_##stride##_##i##x##i##_fp32_##layout( \
const float* src, const float* filter, const float* bias, \
float* temp, float* dst, const int oc, const int ic, const int ih, \
const int iw, const int oh, const int oh_block, const int ow, \
const Op& op, const int ph, const int pw);

KERN(stride1, 2, nchw44)
KERN(stride1, 3, nchw44)
KERN(stride1, 5, nchw44)
KERN(stride1, 7, nchw44)
#undef KERN

void pack_src_fp32_nchw44_stride1(float* sptr_base, const float* sptr_origin,
const int ph, const int pw,
const int pad_right, const int ih,
const int iw, const int iw2,
const int pad_top, const int pad_bottom,
const int ic, const int ic_stride);
} // namespace conv_bias
} // namespace arm_common
} // namespace megdnn

+ 0
- 40
dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.h View File

@@ -1,40 +0,0 @@
/**
* \file dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/fallback/conv_bias/common.h"
namespace megdnn {
namespace arm_common {
namespace conv_bias {
#define KERN(stride, i, layout) \
template <BiasMode bias_mode, typename Op> \
void conv_direct_##stride##_##i##x##i##_fp32_##layout( \
const float* src, const float* filter, const float* bias, \
float* temp, float* dst, const int oc, const int ic, const int ih, \
const int iw, const int oh, const int oh_block, const int ow, \
const Op& op, const int ph, const int pw);

KERN(stride2, 2, nchw44)
KERN(stride2, 3, nchw44)
KERN(stride2, 5, nchw44)
KERN(stride2, 7, nchw44)
#undef KERN

void pack_src_fp32_nchw44_stride2(float* sptr_base, const float* sptr_origin,
const int ph, const int pw,
const int pad_right, const int ih,
const int iw, const int iw2,
const int pad_top, const int pad_bottom,
const int ic, const int ic_stride);
} // namespace conv_bias
} // namespace arm_common
} // namespace megdnn

+ 4
- 228
dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.cpp View File

@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#ifdef __ARM_FEATURE_DOTPROD
@@ -17,7 +18,7 @@
#include "src/fallback/conv_bias/common.h"

#include "src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h"
#include "src/arm_common/conv_bias/int8/direct_dotprod_nchw44_kern.h"
namespace megdnn {
namespace arm_common {
namespace direct_dotprod_nchw44 {
@@ -139,234 +140,9 @@ void copy_packed_src_int8_nchw44<2>(int8_t* dst, const int dst_step,
}
}

template <typename dst_type, int stride, BiasMode bias_mode, typename Op,
int filter_size>
void conv_direct_sdot_int8_nchw44(dst_type* dst, const int oh, const int ow,
const int8_t* src, const int ih, const int iw,
const int8_t* filter, const int32_t* bias,
const int oh_size, const int oc, const int ic,
const Op& op) {
constexpr int FH = filter_size;
constexpr int FW = filter_size;
constexpr int IC_PACK_SIZE = 4;
constexpr int OC_PACK_SIZE = 4;

#if MEGDNN_AARCH64
constexpr int OC_BIG_INTERVAL = 12;
constexpr int OC_MID_INTERVAL = 8;
constexpr int OC_SMA_INTERVAL = 4;
#else
constexpr int OC_BIG_INTERVAL = 4;
constexpr int OC_MID_INTERVAL = 4;
constexpr int OC_SMA_INTERVAL = 4;
#endif

constexpr int OW_INTERVAL = 8;
constexpr int SH = stride;

const int dst_numbers_per_channel = oh * ow;
const int ow_remain = ow % OW_INTERVAL;
const int ow_end_idx = ow - ow_remain;
const int oc_remain =
oc % OC_BIG_INTERVAL; //! NCHW44 means oc_remain = 4 or 8
const int oc_end_idx = oc - oc_remain;
const int dst_numbers_4channel_packed =
dst_numbers_per_channel * OC_PACK_SIZE;

using remain_fun = std::function<void(
dst_type * dst, const int dst_step, const int8_t* src, const int ih,
const int iw, const int8_t* filter, const int32_t* bias,
const int ic, const Op& op)>;

remain_fun kern_big_oc_remain = nullptr;
remain_fun kern_mid_oc_remain = nullptr;
remain_fun kern_sma_oc_remain = nullptr;

switch (ow_remain) {
#define cb(step) \
case step: \
kern_big_oc_remain = \
KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, step, \
filter_size, OC_BIG_INTERVAL, \
OW_INTERVAL>::impl; \
kern_mid_oc_remain = \
KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, step, \
filter_size, OC_MID_INTERVAL, \
OW_INTERVAL>::impl; \
kern_sma_oc_remain = \
KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, step, \
filter_size, OC_SMA_INTERVAL, \
OW_INTERVAL>::impl; \
break;
UNROLL_CALL_RAW(8, cb);
#undef cb
default:
megdnn_assert(0, "no remain %d for kern", ow_remain);
}

//! filter layout is [OC/4, IC/4, FH, FW, 4OC, 4IC]
//! cut [oc, oh, ow] into [oc/OC_INTERVAL, 1, ow/OW_INTERVAL, OW_INTERVAL,
//! oh, OC_INTERVAL] tiles; each KernNeonSdotNCHW44 call computes an
//! [OW_INTERVAL, 1, OC_INTERVAL] block
for (int oc_idx = 0; oc_idx < oc_end_idx; oc_idx += OC_BIG_INTERVAL) {
const int filter_offset_in_element = oc_idx * ic * FH * FW;
for (int oh_idx = 0; oh_idx < oh_size; ++oh_idx) {
for (int ow_idx = 0; ow_idx < ow_end_idx; ow_idx += OW_INTERVAL) {
const int src_offset_in_element =
(oh_idx * SH * iw + ow_idx) * IC_PACK_SIZE;
const int dst_offset_in_element =
oc_idx * dst_numbers_per_channel +
(oh_idx * ow + ow_idx) * OC_PACK_SIZE;
const int bias_offset_in_element = oc_idx;
KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, OW_INTERVAL,
filter_size, OC_BIG_INTERVAL, OW_INTERVAL>::
impl(dst + dst_offset_in_element,
dst_numbers_4channel_packed,
src + src_offset_in_element, ih, iw,
filter + filter_offset_in_element,
bias + bias_offset_in_element, ic, op);
}
if (ow_remain) {
const int src_offset_in_element =
(oh_idx * SH * iw + ow_end_idx) * IC_PACK_SIZE;
const int dst_offset_in_element =
oc_idx * dst_numbers_per_channel +
(oh_idx * ow + ow_end_idx) * OC_PACK_SIZE;
const int bias_offset_in_element = oc_idx;
kern_big_oc_remain(dst + dst_offset_in_element,
dst_numbers_4channel_packed,
src + src_offset_in_element, ih, iw,
filter + filter_offset_in_element,
bias + bias_offset_in_element, ic, op);
}
}
}

#ifdef MEGDNN_AARCH64
//! oc_remain must be 4 or 8 on aarch64 and must be 0 on aarch32
if (oc_remain) {
int oc_idx = oc_end_idx;
const int filter_offset_in_element = oc_idx * ic * FH * FW;
for (int oh_idx = 0; oh_idx < oh_size; ++oh_idx) {
for (int ow_idx = 0; ow_idx < ow_end_idx; ow_idx += OW_INTERVAL) {
const int src_offset_in_element =
(oh_idx * SH * iw + ow_idx) * IC_PACK_SIZE;
const int dst_offset_in_element =
oc_idx * dst_numbers_per_channel +
(oh_idx * ow + ow_idx) * OC_PACK_SIZE;
const int bias_offset_in_element = oc_idx;
if (oc_remain == 8) {
KernNeonSdotNCHW44<
dst_type, stride, bias_mode, Op, OW_INTERVAL,
filter_size, OC_MID_INTERVAL,
OW_INTERVAL>::impl(dst + dst_offset_in_element,
dst_numbers_4channel_packed,
src + src_offset_in_element, ih,
iw,
filter +
filter_offset_in_element,
bias + bias_offset_in_element,
ic, op);
} else {
KernNeonSdotNCHW44<
dst_type, stride, bias_mode, Op, OW_INTERVAL,
filter_size, OC_SMA_INTERVAL,
OW_INTERVAL>::impl(dst + dst_offset_in_element,
dst_numbers_4channel_packed,
src + src_offset_in_element, ih,
iw,
filter +
filter_offset_in_element,
bias + bias_offset_in_element,
ic, op);
}
}
if (ow_remain) {
const int src_offset_in_element =
(oh_idx * SH * iw + ow_end_idx) * IC_PACK_SIZE;
const int dst_offset_in_element =
oc_idx * dst_numbers_per_channel +
(oh_idx * ow + ow_end_idx) * OC_PACK_SIZE;
const int bias_offset_in_element = oc_idx;
if (oc_remain == 8) {
kern_mid_oc_remain(dst + dst_offset_in_element,
dst_numbers_4channel_packed,
src + src_offset_in_element, ih, iw,
filter + filter_offset_in_element,
bias + bias_offset_in_element, ic, op);
} else {
kern_sma_oc_remain(dst + dst_offset_in_element,
dst_numbers_4channel_packed,
src + src_offset_in_element, ih, iw,
filter + filter_offset_in_element,
bias + bias_offset_in_element, ic, op);
}
}
}
}
#endif
}

#define CONSTRUCT_FUNC(filter_size) \
template <typename dst_type, BiasMode bias_mode, typename Op, int stride> \
void conv_direct_##filter_size##x##filter_size##_int8_nchw44( \
dst_type* dst, const int oh, const int ow, const int8_t* src, \
const int ih, const int iw, const int8_t* weight, \
const int32_t* bias, const int oh_size, const int oc, \
const int ic, const Op& op) { \
conv_direct_sdot_int8_nchw44<dst_type, stride, bias_mode, Op, \
filter_size>( \
dst, oh, ow, src, ih, iw, weight, bias, oh_size, oc, ic, op); \
}

CONSTRUCT_FUNC(2);
CONSTRUCT_FUNC(3);
CONSTRUCT_FUNC(5);
CONSTRUCT_FUNC(7);
#undef CONSTRUCT_FUNC

#define INSTANTIATION(dst_type, stride, i, bias_mode, Op) \
template void conv_direct_##i##x##i##_int8_nchw44<dst_type, bias_mode, Op, \
stride>( \
dst_type * dst, const int oh, const int ow, const int8_t* src, \
const int ih, const int iw, const int8_t* weight, \
const int32_t* bias, const int oh_size, const int oc, \
const int ic, const Op& op);

#define FOR_OP(stride, i, bias_mode) \
INSTANTIATION(dt_int8, stride, i, bias_mode, \
TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
INSTANTIATION(dt_int32, stride, i, bias_mode, \
NoneOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
INSTANTIATION(dt_int8, stride, i, bias_mode, \
ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
INSTANTIATION(dt_int8, stride, i, bias_mode, \
HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>)

#define FOR_BIAS(stride, i) \
FOR_OP(stride, i, BiasMode::NO_BIAS) \
FOR_OP(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS)

#define FOR_FILTER(stride) \
FOR_BIAS(stride, 2) \
FOR_BIAS(stride, 3) \
FOR_BIAS(stride, 5) \
FOR_BIAS(stride, 7)

FOR_FILTER(1)
FOR_FILTER(2)

#undef FOR_STRIDE
#undef FOR_FILTER
#undef FOR_IC
#undef FOR_BIAS
#undef FOR_NONLINEAR
#undef FOR_REMAIN
#undef INSTANTIATION

} // namespace direct_dotprod_nchw44
} // namespace arm_common
} // namespace megdnn
#endif

//vim: syntax=cpp.doxygen
// vim: syntax=cpp.doxygen

+ 10
- 16
dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h View File

@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#if __ARM_FEATURE_DOTPROD
@@ -42,20 +43,13 @@ using BiasMode = ConvBiasForward::BiasMode;
* @return none
*/

#define KERN(filter_size) \
template <typename dst_type, BiasMode bias_mode, typename Op, int stride> \
void conv_direct_##filter_size##x##filter_size##_int8_nchw44( \
dst_type* dst, const int oh, const int ow, const int8_t* src, \
const int ih, const int iw, const int8_t* weight, \
const int32_t* bias, const int oh_size, const int oc, \
const int ic, const Op& op)

KERN(2);
KERN(3);
KERN(5);
KERN(7);

#undef KERN
template <typename dst_type, int stride, BiasMode bias_mode, typename Op,
int filter_size>
void conv_direct_sdot_int8_nchw44(dst_type* dst, const int oh, const int ow,
const int8_t* src, const int ih, const int iw,
const int8_t* filter, const int32_t* bias,
const int oh_size, const int oc, const int ic,
const Op& op);
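//! note: this single template replaces the former per-filter-size
//! conv_direct_ixi_int8_nchw44 wrappers; filter_size and stride are now
//! template arguments dispatched at the call site (see
//! direct_dotprod_nchw44_algo.cpp).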
/**
* @brief : copy data from src to dst for direct conv with no side effect
* @param : [output ptr] dst
@@ -84,4 +78,4 @@ void copy_packed_src_int8_nchw44(int8_t* dst, const int dst_step,

#endif

//vim: syntax=cpp.doxygen
// vim: syntax=cpp.doxygen

+ 5
- 9
dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_algo.cpp View File

@@ -148,14 +148,10 @@ static void conv_kern(const WorkspaceBundle& bundle,
float scale_dst = ncb_param.dst_type.param<dtype::QuantizedS8>().scale;
op = Op(scale_bias, scale_dst);
}

#define KERN1_NCHW44_CONV(filter) \
direct_dotprod_nchw44::conv_direct_##filter##x##filter##_int8_nchw44< \
dst_type, bias_mode, Op, stride>(dst, OH, OW, copy_dst, \
ih_real_size, iw2, weights, bias, \
oh_real_size, OC, IC, op);
DISPATCH_FILTER(filter_size, KERN1_NCHW44_CONV);
#undef KERN1_NCHW44_CONV
direct_dotprod_nchw44::conv_direct_sdot_int8_nchw44<
dst_type, stride, bias_mode, Op, filter_size>(
dst, OH, OW, copy_dst, ih_real_size, iw2, weights, bias,
oh_real_size, OC, IC, op);
}

} // namespace
@@ -342,4 +338,4 @@ ConvBiasImpl::AlgoDotS8Direct_NCHW44::dispatch_kerns(

#endif

//vim: syntax=cpp.doxygen
// vim: syntax=cpp.doxygen

+ 0
- 435
dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_kern.h View File

@@ -1,435 +0,0 @@
/**
* \file dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_kern.h
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#ifdef __ARM_FEATURE_DOTPROD

#include "megdnn/arch.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/intrinsic_helper.h"
#include "src/arm_common/neon_struct.h"
#include "src/common/unroll_macro.h"

namespace megdnn {
namespace arm_common {
namespace direct_dotprod_nchw44 {

constexpr int SIMD_LEN = 16;
constexpr int IC_PACK_SIZE = 4;
constexpr int OC_PACK_SIZE = 4;
constexpr int filter_next_col =
IC_PACK_SIZE * OC_PACK_SIZE; //! [OC/4, IC/4, FH, FW, 4OC, 4IC]

template <int row, BiasMode bias_mode>
MEGDNN_ALWAYS_INLINE void init_ocx_ow8(int32x4_t c[][8],
const int32_t* bias_ptr, int oc_step) {
static_assert(row == 1 || row == 2 || row == 3, "Invalid OC number.");
if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) {
#define BIAS_INIT(step, i) c[i][step] = vld1q_s32(bias_ptr + i * oc_step);
switch (row) {
case 3:
UNROLL_CALL_RAW(8, BIAS_INIT, 2);
case 2:
UNROLL_CALL_RAW(8, BIAS_INIT, 1);
default:
UNROLL_CALL_RAW(8, BIAS_INIT, 0);
}
#undef BIAS_INIT
} else {
#define BIAS_INIT(step, i) c[i][step] = vdupq_n_s32(0);
switch (row) {
case 3:
UNROLL_CALL_RAW(8, BIAS_INIT, 2);
case 2:
UNROLL_CALL_RAW(8, BIAS_INIT, 1);
default:
UNROLL_CALL_RAW(8, BIAS_INIT, 0);
}
#undef BIAS_INIT
}
}

#define cb11(col) \
op(res[0][col], reinterpret_cast<dt_qint8*>(dst_ptr + col / 2 * 8));

#define cb21(col) \
op(res[0][col], reinterpret_cast<dt_qint8*>(dst_ptr + col / 2 * 8)); \
op(res[1][col], \
reinterpret_cast<dt_qint8*>(dst_ptr + ld_dst_oc + col / 2 * 8));

#define cb31(col) \
op(res[0][col], reinterpret_cast<dt_qint8*>(dst_ptr + col / 2 * 8)); \
op(res[1][col], \
reinterpret_cast<dt_qint8*>(dst_ptr + ld_dst_oc + col / 2 * 8)); \
op(res[2][col], reinterpret_cast<dt_qint8*>(dst_ptr + ld_dst_oc + \
ld_dst_oc + col / 2 * 8));

#define cb12(step) \
op({{res[0][2 * step], res[0][2 * step + 1]}}, \
reinterpret_cast<dt_qint8*>(dst_ptr + step * 8));

#define cb22(step) \
op({{res[0][2 * step], res[0][2 * step + 1]}}, \
reinterpret_cast<dt_qint8*>(dst_ptr + step * 8)); \
op({{res[1][2 * step], res[1][2 * step + 1]}}, \
reinterpret_cast<dt_qint8*>(dst_ptr + ld_dst_oc + step * 8));

#define cb32(step) \
op({{res[0][2 * step], res[0][2 * step + 1]}}, \
reinterpret_cast<dt_qint8*>(dst_ptr + step * 8)); \
op({{res[1][2 * step], res[1][2 * step + 1]}}, \
reinterpret_cast<dt_qint8*>(dst_ptr + ld_dst_oc + step * 8)); \
op({{res[2][2 * step], res[2][2 * step + 1]}}, \
reinterpret_cast<dt_qint8*>(dst_ptr + 2 * ld_dst_oc + step * 8));

template <int row, int ow_remain, typename Op, typename T>
struct StoreOCxOWx {
static MEGDNN_ALWAYS_INLINE void impl(int32x4_t res[][8], const Op& op,
T* dst_ptr, const int ld_dst_oc);
};

template <int ow_remain, typename Op, typename T>
struct StoreOCxOWx<1, ow_remain, Op, T> {

static void impl(int32x4_t res[][8], const Op& op, T* dst_ptr,
const int ld_dst_oc) {
MEGDNN_MARK_USED_VAR(ld_dst_oc);
switch (ow_remain) {
case 8:
UNROLL_CALL_RAW(4, cb12);
break;
case 7:
cb11(6);
case 6:
UNROLL_CALL_RAW(3, cb12);
break;
case 5:
cb11(4);
case 4:
UNROLL_CALL_RAW(2, cb12);
break;
case 3:
cb11(2);
case 2:
UNROLL_CALL_RAW(1, cb12);
break;
case 1:
cb11(0);
default:
break;
}
}
};

template <int ow_remain, typename Op, typename T>
struct StoreOCxOWx<2, ow_remain, Op, T> {
static MEGDNN_ALWAYS_INLINE void impl(int32x4_t res[][8], const Op& op,
T* dst_ptr, const int ld_dst_oc) {
switch (ow_remain) {
case 8:
UNROLL_CALL_RAW(4, cb22);
break;
case 7:
cb21(6);
case 6:
UNROLL_CALL_RAW(3, cb22);
break;
case 5:
cb21(4);
case 4:
UNROLL_CALL_RAW(2, cb22);
break;
case 3:
cb21(2);
case 2:
UNROLL_CALL_RAW(1, cb22);
break;
case 1:
cb21(0);
default:
break;
}
}
};

template <int ow_remain, typename Op, typename T>
struct StoreOCxOWx<3, ow_remain, Op, T> {
static MEGDNN_ALWAYS_INLINE void impl(int32x4_t res[][8], const Op& op,
T* dst_ptr, const int ld_dst_oc) {
switch (ow_remain) {
case 8:
UNROLL_CALL_RAW(4, cb32);
break;
case 7:
cb31(6);
case 6:
UNROLL_CALL_RAW(3, cb32);
break;
case 5:
cb31(4);
case 4:
UNROLL_CALL_RAW(2, cb32);
break;
case 3:
cb31(2);
case 2:
UNROLL_CALL_RAW(1, cb32);
break;
case 1:
cb31(0);
default:
break;
}
}
};

#undef cb11
#undef cb21
#undef cb31
#undef cb12
#undef cb22
#undef cb32

template <int row, int ow_remain, typename Op, typename T>
MEGDNN_ALWAYS_INLINE void store_ocx_owx_remain_static(int32x4_t res[][8],
const Op& op, T* dst_ptr,
const int ld_dst_oc) {
StoreOCxOWx<row, ow_remain, Op, T>::impl(res, op, dst_ptr, ld_dst_oc);
}

template <int res_row, int src_row, int src_start_idx, int weight_idx,
typename FUNC, typename T, typename T2, typename T3>
struct ShiftCalHelper {
static MEGDNN_ALWAYS_INLINE void impl(T& res, T2& src, T3& weight) {
#define cb(step) \
res[res_row][step] = FUNC::template impl<((src_start_idx + step) % 4)>( \
res[res_row][step], weight[weight_idx], \
src[src_row][(src_start_idx + step) / 4]);
UNROLL_CALL_RAW(8, cb);
#undef cb
}
};

template <int res_row, int src_row, int src_start_idx, int weight_idx,
typename FUNC, typename T, typename T2, typename T3>
MEGDNN_ALWAYS_INLINE void cal_helper(T& res, T2& src, T3& weight) {
ShiftCalHelper<res_row, src_row, src_start_idx, weight_idx, FUNC, T, T2,
T3>::impl(res, src, weight);
};

/**
* oc12_owx(m = 12, n = x) and oc8_owx(m = 8, n = x) and oc4_owx(m = 4, n = x)
* gemm like kernel
* */
template <typename dst_type, int stride, BiasMode bias_mode, typename Op,
int ow_remain, int filter_size, int oc_interval, int ow_interval>
struct KernNeonSdotNCHW44 {
static void impl(dst_type* dst, const int dst_step, const int8_t* src,
const int ih, const int iw, const int8_t* filter,
const int32_t* bias, const int ic, const Op& op);
};

template <typename dst_type, BiasMode bias_mode, typename Op, int ow_remain,
int filter_size, int oc_interval, int ow_interval>
struct KernNeonSdotNCHW44<dst_type, 1, bias_mode, Op, ow_remain, filter_size,
oc_interval, ow_interval> {
static void impl(dst_type* dst, const int dst_step, const int8_t* src,
const int ih, const int iw, const int8_t* filter,
const int32_t* bias, const int ic, const Op& op) {
constexpr int FH = filter_size;
constexpr int FW = filter_size;
constexpr int filter_next_row =
FW * OC_PACK_SIZE *
IC_PACK_SIZE; //! [OC/4, IC/4, FH, FW, 4OC, 4IC]

const int filter_next_4oc =
FH * FW * ic * OC_PACK_SIZE; //! [OC/4, IC/4, FH, FW, 4OC, 4IC]
const int src_next_ic = ih * iw;
const int src_next_row = iw * IC_PACK_SIZE;

constexpr int NSRC = (ow_interval + filter_size - 1) / 4 + 1;
constexpr int LOOP = oc_interval / 4;

int32x4_t res[3][ow_interval];
init_ocx_ow8<LOOP, bias_mode>(res, bias, OC_PACK_SIZE);

for (int ic_idx = 0; ic_idx < ic; ic_idx += IC_PACK_SIZE) {
const int8_t* i_src = src + ic_idx * src_next_ic;
const int8_t* i_filter = filter + ic_idx * FH * FW * OC_PACK_SIZE;
for (int fh_idx = 0; fh_idx < FH; ++fh_idx) {
int8x16_t src[1][4];
int8x16_t weight[3];

load_helper<NSRC, 0, SIMD_LEN, 1, Vld1q_s8>(src, i_src, 0);

//! do not reorder the switch as 3,2,1: it measurably slows the kernel.
#define CALC_PART(step) \
switch (LOOP) { \
case 1: \
weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \
filter_next_col * step); \
cal_helper<0, 0, step, 0, Vdotq_laneq_s32>(res, src, weight); \
break; \
case 2: \
weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \
filter_next_col * step); \
cal_helper<0, 0, step, 0, Vdotq_laneq_s32>(res, src, weight); \
weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 + \
filter_next_col * step); \
cal_helper<1, 0, step, 1, Vdotq_laneq_s32>(res, src, weight); \
break; \
case 3: \
weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \
filter_next_col * step); \
cal_helper<0, 0, step, 0, Vdotq_laneq_s32>(res, src, weight); \
weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 + \
filter_next_col * step); \
cal_helper<1, 0, step, 1, Vdotq_laneq_s32>(res, src, weight); \
weight[2] = vld1q_s8(i_filter + filter_next_4oc * 2 + \
filter_next_col * step); \
cal_helper<2, 0, step, 2, Vdotq_laneq_s32>(res, src, weight); \
break; \
default: \
break; \
}

switch (filter_size) {
case 2:
UNROLL_CALL_RAW(2, CALC_PART);
break;
case 3:
UNROLL_CALL_RAW(3, CALC_PART);
break;
case 5:
UNROLL_CALL_RAW(5, CALC_PART);
break;
case 7:
UNROLL_CALL_RAW(7, CALC_PART);
break;
default:
break;
}
#undef CALC_PART

i_filter += filter_next_row;
i_src += src_next_row;
}
}
store_ocx_owx_remain_static<LOOP, ow_remain, Op>(res, op, dst,
dst_step);
}
};

template <typename dst_type, BiasMode bias_mode, typename Op, int ow_remain,
int filter_size, int oc_interval, int ow_interval>
struct KernNeonSdotNCHW44<dst_type, 2, bias_mode, Op, ow_remain, filter_size,
oc_interval, ow_interval> {
static void impl(dst_type* dst, const int dst_step, const int8_t* src,
const int ih, const int iw, const int8_t* filter,
const int32_t* bias, const int ic, const Op& op) {
constexpr int FH = filter_size;
constexpr int FW = filter_size;
constexpr int filter_next_row =
FW * OC_PACK_SIZE *
IC_PACK_SIZE; //! [OC/4, IC/4, FH, FW, 4OC, 4IC]

const int filter_next_4oc =
FH * FW * ic * OC_PACK_SIZE; //! [OC/4, IC/4, FH, FW, 4OC, 4IC]
const int src_next_ic = ih * iw;
const int src_next_row = iw * IC_PACK_SIZE;

constexpr int NSRC = (ow_interval * 2 + filter_size - 3) / 8 + 1;
constexpr int LOOP = oc_interval / 4;

int32x4_t res[3][ow_interval];
init_ocx_ow8<LOOP, bias_mode>(res, bias, OC_PACK_SIZE);

for (int ic_idx = 0; ic_idx < ic; ic_idx += IC_PACK_SIZE) {
const int8_t* i_src = src + ic_idx * src_next_ic;
const int8_t* i_filter = filter + ic_idx * FH * FW * OC_PACK_SIZE;
for (int fh_idx = 0; fh_idx < FH; ++fh_idx) {
int8x16_t src[2][3];
int8x16_t weight[3];
const int offset = megdnn::div_ceil(iw, 2) * IC_PACK_SIZE;

load_helper<NSRC, 0, SIMD_LEN, 2, Vld1q_s8>(src, i_src, offset);

        //! do not reorder the switch cases to 3,2,1: it slows the kernel down.
#define CALC_PART(step) \
switch (LOOP) { \
case 1: \
weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \
filter_next_col * step); \
cal_helper<0, step % 2, step / 2, 0, Vdotq_laneq_s32>(res, src, \
weight); \
break; \
case 2: \
weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \
filter_next_col * step); \
cal_helper<0, step % 2, step / 2, 0, Vdotq_laneq_s32>(res, src, \
weight); \
weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 + \
filter_next_col * step); \
cal_helper<1, step % 2, step / 2, 1, Vdotq_laneq_s32>(res, src, \
weight); \
break; \
case 3: \
weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \
filter_next_col * step); \
cal_helper<0, step % 2, step / 2, 0, Vdotq_laneq_s32>(res, src, \
weight); \
weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 + \
filter_next_col * step); \
cal_helper<1, step % 2, step / 2, 1, Vdotq_laneq_s32>(res, src, \
weight); \
weight[2] = vld1q_s8(i_filter + filter_next_4oc * 2 + \
filter_next_col * step); \
cal_helper<2, step % 2, step / 2, 2, Vdotq_laneq_s32>(res, src, \
weight); \
break; \
default: \
break; \
}

switch (filter_size) {
case 2:
UNROLL_CALL_RAW(2, CALC_PART);
break;
case 3:
UNROLL_CALL_RAW(3, CALC_PART);
break;
case 5:
UNROLL_CALL_RAW(5, CALC_PART);
break;
case 7:
UNROLL_CALL_RAW(7, CALC_PART);
break;
default:
break;
}
#undef CALC_PART

i_filter += filter_next_row;
i_src += src_next_row;
}
}
store_ocx_owx_remain_static<LOOP, ow_remain, Op>(res, op, dst,
dst_step);
}
};

} // namespace direct_dotprod_nchw44
} // namespace arm_common
} // namespace megdnn

#endif

// vim: syntax=cpp.doxygen

+ 245
- 0
dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h

@@ -0,0 +1,245 @@
/**
* \file
* dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#if __ARM_FEATURE_DOTPROD
#include "megdnn/arch.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/intrinsic_helper.h"
#include "src/arm_common/neon_struct.h"
#include "src/common/unroll_macro.h"

namespace megdnn {
namespace arm_common {
namespace direct_dotprod_nchw44 {

constexpr int SIMD_LEN = 16;
constexpr int IC_PACK_SIZE = 4;
constexpr int OC_PACK_SIZE = 4;
constexpr int filter_next_col =
IC_PACK_SIZE * OC_PACK_SIZE; //! [OC/4, IC/4, FH, FW, 4OC, 4IC]
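
//! moving one position along FW in this layout skips one full 4OC x 4IC
//! block of 16 int8 values, hence filter_next_col = IC_PACK_SIZE * OC_PACK_SIZE.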

template <int row, BiasMode bias_mode>
MEGDNN_ALWAYS_INLINE void init_ocx_ow8(int32x4_t c[][8],
const int32_t* bias_ptr, int oc_step) {
static_assert(row == 1 || row == 2 || row == 3, "Invalid OC number.");
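    //! the switches below fall through on purpose: row == 3 initializes
    //! accumulator rows 2, 1 and 0, row == 2 initializes rows 1 and 0.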
if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) {
#define BIAS_INIT(step, i) c[i][step] = vld1q_s32(bias_ptr + i * oc_step);
switch (row) {
case 3:
UNROLL_CALL_RAW(8, BIAS_INIT, 2);
case 2:
UNROLL_CALL_RAW(8, BIAS_INIT, 1);
default:
UNROLL_CALL_RAW(8, BIAS_INIT, 0);
}
#undef BIAS_INIT
} else {
#define BIAS_INIT(step, i) c[i][step] = vdupq_n_s32(0);
switch (row) {
case 3:
UNROLL_CALL_RAW(8, BIAS_INIT, 2);
case 2:
UNROLL_CALL_RAW(8, BIAS_INIT, 1);
default:
UNROLL_CALL_RAW(8, BIAS_INIT, 0);
}
#undef BIAS_INIT
}
}

#define cb11(col) \
op(res[0][col], reinterpret_cast<dt_qint8*>(dst_ptr + col / 2 * 8));

#define cb21(col) \
op(res[0][col], reinterpret_cast<dt_qint8*>(dst_ptr + col / 2 * 8)); \
op(res[1][col], \
reinterpret_cast<dt_qint8*>(dst_ptr + ld_dst_oc + col / 2 * 8));

#define cb31(col) \
op(res[0][col], reinterpret_cast<dt_qint8*>(dst_ptr + col / 2 * 8)); \
op(res[1][col], \
reinterpret_cast<dt_qint8*>(dst_ptr + ld_dst_oc + col / 2 * 8)); \
    op(res[2][col], reinterpret_cast<dt_qint8*>(dst_ptr + 2 * ld_dst_oc + \
                                                col / 2 * 8));

#define cb12(step) \
op({{res[0][2 * step], res[0][2 * step + 1]}}, \
reinterpret_cast<dt_qint8*>(dst_ptr + step * 8));

#define cb22(step) \
op({{res[0][2 * step], res[0][2 * step + 1]}}, \
reinterpret_cast<dt_qint8*>(dst_ptr + step * 8)); \
op({{res[1][2 * step], res[1][2 * step + 1]}}, \
reinterpret_cast<dt_qint8*>(dst_ptr + ld_dst_oc + step * 8));

#define cb32(step) \
op({{res[0][2 * step], res[0][2 * step + 1]}}, \
reinterpret_cast<dt_qint8*>(dst_ptr + step * 8)); \
op({{res[1][2 * step], res[1][2 * step + 1]}}, \
reinterpret_cast<dt_qint8*>(dst_ptr + ld_dst_oc + step * 8)); \
op({{res[2][2 * step], res[2][2 * step + 1]}}, \
reinterpret_cast<dt_qint8*>(dst_ptr + 2 * ld_dst_oc + step * 8));

template <int row, int ow_remain, typename Op, typename T>
struct StoreOCxOWx {
static MEGDNN_ALWAYS_INLINE void impl(int32x4_t res[][8], const Op& op,
T* dst_ptr, const int ld_dst_oc);
};

template <int ow_remain, typename Op, typename T>
struct StoreOCxOWx<1, ow_remain, Op, T> {
static void impl(int32x4_t res[][8], const Op& op, T* dst_ptr,
const int ld_dst_oc) {
MEGDNN_MARK_USED_VAR(ld_dst_oc);
switch (ow_remain) {
case 8:
UNROLL_CALL_RAW(4, cb12);
break;
case 7:
cb11(6);
case 6:
UNROLL_CALL_RAW(3, cb12);
break;
case 5:
cb11(4);
case 4:
UNROLL_CALL_RAW(2, cb12);
break;
case 3:
cb11(2);
case 2:
UNROLL_CALL_RAW(1, cb12);
break;
case 1:
cb11(0);
default:
break;
}
}
};

template <int ow_remain, typename Op, typename T>
struct StoreOCxOWx<2, ow_remain, Op, T> {
static MEGDNN_ALWAYS_INLINE void impl(int32x4_t res[][8], const Op& op,
T* dst_ptr, const int ld_dst_oc) {
switch (ow_remain) {
case 8:
UNROLL_CALL_RAW(4, cb22);
break;
case 7:
cb21(6);
case 6:
UNROLL_CALL_RAW(3, cb22);
break;
case 5:
cb21(4);
case 4:
UNROLL_CALL_RAW(2, cb22);
break;
case 3:
cb21(2);
case 2:
UNROLL_CALL_RAW(1, cb22);
break;
case 1:
cb21(0);
default:
break;
}
}
};

template <int ow_remain, typename Op, typename T>
struct StoreOCxOWx<3, ow_remain, Op, T> {
static MEGDNN_ALWAYS_INLINE void impl(int32x4_t res[][8], const Op& op,
T* dst_ptr, const int ld_dst_oc) {
switch (ow_remain) {
case 8:
UNROLL_CALL_RAW(4, cb32);
break;
case 7:
cb31(6);
case 6:
UNROLL_CALL_RAW(3, cb32);
break;
case 5:
cb31(4);
case 4:
UNROLL_CALL_RAW(2, cb32);
break;
case 3:
cb31(2);
case 2:
UNROLL_CALL_RAW(1, cb32);
break;
case 1:
cb31(0);
default:
break;
}
}
};

#undef cb11
#undef cb21
#undef cb31
#undef cb12
#undef cb22
#undef cb32

template <int row, int ow_remain, typename Op, typename T>
MEGDNN_ALWAYS_INLINE void store_ocx_owx_remain_static(int32x4_t res[][8],
const Op& op, T* dst_ptr,
const int ld_dst_oc) {
StoreOCxOWx<row, ow_remain, Op, T>::impl(res, op, dst_ptr, ld_dst_oc);
}

template <int res_row, int src_row, int src_start_idx, int weight_idx,
typename T, typename T2, typename T3>
struct ShiftCalHelper {
static MEGDNN_ALWAYS_INLINE void impl(T& res, T2& src, T3& weight) {
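        //! (src_start_idx + step) / 4 picks the 16-byte src register and
        //! (src_start_idx + step) % 4 the 4-byte lane which vdotq_laneq_s32
        //! dots against weight[weight_idx] for all four OC lanes at once.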
#define cb(step) \
res[res_row][step] = \
vdotq_laneq_s32(res[res_row][step], weight[weight_idx], \
src[src_row][(src_start_idx + step) / 4], \
(src_start_idx + step) % 4);
UNROLL_CALL_RAW(8, cb);
#undef cb
}
};

template <int res_row, int src_row, int src_start_idx, int weight_idx,
typename T, typename T2, typename T3>
MEGDNN_ALWAYS_INLINE void cal_helper(T& res, T2& src, T3& weight) {
ShiftCalHelper<res_row, src_row, src_start_idx, weight_idx, T, T2,
T3>::impl(res, src, weight);
};

/**
 * GEMM-like kernels: oc12_owx (m = 12, n = x), oc8_owx (m = 8, n = x) and
 * oc4_owx (m = 4, n = x)
 */
template <typename dst_type, int stride, BiasMode bias_mode, typename Op,
int ow_remain, int filter_size, int oc_interval, int ow_interval>
struct KernNeonSdotNCHW44 {
static void impl(dst_type* dst, const int dst_step, const int8_t* src,
const int ih, const int iw, const int8_t* filter,
const int32_t* bias, const int ic, const Op& op);
};
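
//! declared only; the stride-1 and stride-2 specializations are defined in
//! dot_direct_nchw44_s1.cpp and dot_direct_nchw44_s2.cpp, e.g.
//!     KernNeonSdotNCHW44<dt_int8, 1, bias_mode, Op, ow_remain, 3, 12,
//!                        8>::impl(...)
//! runs the 3x3 stride-1 kernel on a 12-OC x 8-OW register tile (aarch64).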

} // namespace direct_dotprod_nchw44
} // namespace arm_common
} // namespace megdnn
#endif
// vim: syntax=cpp.doxygen

+ 320
- 0
dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1.cpp

@@ -0,0 +1,320 @@
/**
* \file
* dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#if __ARM_FEATURE_DOTPROD
#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h"

namespace megdnn {
namespace arm_common {
namespace direct_dotprod_nchw44 {
template <typename dst_type, BiasMode bias_mode, typename Op, int ow_remain,
int filter_size, int oc_interval, int ow_interval>
struct KernNeonSdotNCHW44<dst_type, 1, bias_mode, Op, ow_remain, filter_size,
oc_interval, ow_interval> {
static void impl(dst_type* dst, const int dst_step, const int8_t* src,
const int ih, const int iw, const int8_t* filter,
const int32_t* bias, const int ic, const Op& op) {
constexpr int FH = filter_size;
constexpr int FW = filter_size;
constexpr int filter_next_row =
FW * OC_PACK_SIZE *
IC_PACK_SIZE; //! [OC/4, IC/4, FH, FW, 4OC, 4IC]

const int filter_next_4oc =
FH * FW * ic * OC_PACK_SIZE; //! [OC/4, IC/4, FH, FW, 4OC, 4IC]
const int src_next_ic = ih * iw;
const int src_next_row = iw * IC_PACK_SIZE;

constexpr int NSRC = (ow_interval + filter_size - 1) / 4 + 1;
constexpr int LOOP = oc_interval / 4;
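        //! NSRC: number of 16-byte src registers needed to cover the
        //! ow_interval + filter_size - 1 input columns one output row reads
        //! (4 NCHW44 columns per register); LOOP: 4-OC blocks accumulated.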

int32x4_t res[3][ow_interval];
init_ocx_ow8<LOOP, bias_mode>(res, bias, OC_PACK_SIZE);

for (int ic_idx = 0; ic_idx < ic; ic_idx += IC_PACK_SIZE) {
const int8_t* i_src = src + ic_idx * src_next_ic;
const int8_t* i_filter = filter + ic_idx * FH * FW * OC_PACK_SIZE;
for (int fh_idx = 0; fh_idx < FH; ++fh_idx) {
int8x16_t src[1][4];
int8x16_t weight[3];

load_helper<NSRC, 0, SIMD_LEN, 1, Vld1q_s8>(src, i_src, 0);

        //! do not reorder the switch cases to 3,2,1: it slows the kernel down.
#define CALC_PART(step) \
switch (LOOP) { \
case 1: \
weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \
filter_next_col * step); \
cal_helper<0, 0, step, 0>(res, src, weight); \
break; \
case 2: \
weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \
filter_next_col * step); \
cal_helper<0, 0, step, 0>(res, src, weight); \
weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 + \
filter_next_col * step); \
cal_helper<1, 0, step, 1>(res, src, weight); \
break; \
case 3: \
weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \
filter_next_col * step); \
cal_helper<0, 0, step, 0>(res, src, weight); \
weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 + \
filter_next_col * step); \
cal_helper<1, 0, step, 1>(res, src, weight); \
weight[2] = vld1q_s8(i_filter + filter_next_4oc * 2 + \
filter_next_col * step); \
cal_helper<2, 0, step, 2>(res, src, weight); \
break; \
default: \
break; \
}

switch (filter_size) {
case 2:
UNROLL_CALL_RAW(2, CALC_PART);
break;
case 3:
UNROLL_CALL_RAW(3, CALC_PART);
break;
case 5:
UNROLL_CALL_RAW(5, CALC_PART);
break;
case 7:
UNROLL_CALL_RAW(7, CALC_PART);
break;
default:
break;
}
#undef CALC_PART

i_filter += filter_next_row;
i_src += src_next_row;
}
}
store_ocx_owx_remain_static<LOOP, ow_remain, Op>(res, op, dst,
dst_step);
}
};

template <typename dst_type, int stride, BiasMode bias_mode, typename Op,
int filter_size>
void conv_direct_sdot_int8_nchw44(dst_type* dst, const int oh, const int ow,
const int8_t* src, const int ih, const int iw,
const int8_t* filter, const int32_t* bias,
const int oh_size, const int oc, const int ic,
const Op& op) {
constexpr int FH = filter_size;
constexpr int FW = filter_size;
constexpr int IC_PACK_SIZE = 4;
constexpr int OC_PACK_SIZE = 4;

#if MEGDNN_AARCH64
constexpr int OC_BIG_INTERVAL = 12;
constexpr int OC_MID_INTERVAL = 8;
constexpr int OC_SMA_INTERVAL = 4;
#else
constexpr int OC_BIG_INTERVAL = 4;
constexpr int OC_MID_INTERVAL = 4;
constexpr int OC_SMA_INTERVAL = 4;
#endif

constexpr int OW_INTERVAL = 8;
constexpr int SH = stride;

const int dst_numbers_per_channel = oh * ow;
const int ow_remain = ow % OW_INTERVAL;
const int ow_end_idx = ow - ow_remain;
const int oc_remain =
            oc % OC_BIG_INTERVAL;  //! NCHW44 keeps oc a multiple of 4, so
                                   //! oc_remain is 0, 4 or 8
const int oc_end_idx = oc - oc_remain;
const int dst_numbers_4channel_packed =
dst_numbers_per_channel * OC_PACK_SIZE;

using remain_fun = std::function<void(
dst_type * dst, const int dst_step, const int8_t* src, const int ih,
const int iw, const int8_t* filter, const int32_t* bias,
const int ic, const Op& op)>;

remain_fun kern_big_oc_remain = nullptr;
remain_fun kern_mid_oc_remain = nullptr;
remain_fun kern_sma_oc_remain = nullptr;

switch (ow_remain) {
#define cb(step) \
case step: \
kern_big_oc_remain = \
KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, step, \
filter_size, OC_BIG_INTERVAL, \
OW_INTERVAL>::impl; \
kern_mid_oc_remain = \
KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, step, \
filter_size, OC_MID_INTERVAL, \
OW_INTERVAL>::impl; \
kern_sma_oc_remain = \
KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, step, \
filter_size, OC_SMA_INTERVAL, \
OW_INTERVAL>::impl; \
break;
UNROLL_CALL_RAW(8, cb);
#undef cb
default:
megdnn_assert(0, "no remain %d for kern", ow_remain);
}

//! filter layout is [OC/4, IC/4, FH, FW, 4OC, 4IC]
    //! [oc, oh, ow] is cut into [oc/OC_INTERVAL, 1, ow/OW_INTERVAL,
    //! OW_INTERVAL, oh, OC_INTERVAL] tiles; each KernNeonSdotNCHW44 call
    //! computes one [OW_INTERVAL, 1, OC_INTERVAL] tile
for (int oc_idx = 0; oc_idx < oc_end_idx; oc_idx += OC_BIG_INTERVAL) {
const int filter_offset_in_element = oc_idx * ic * FH * FW;
for (int oh_idx = 0; oh_idx < oh_size; ++oh_idx) {
for (int ow_idx = 0; ow_idx < ow_end_idx; ow_idx += OW_INTERVAL) {
const int src_offset_in_element =
(oh_idx * SH * iw + ow_idx) * IC_PACK_SIZE;
const int dst_offset_in_element =
oc_idx * dst_numbers_per_channel +
(oh_idx * ow + ow_idx) * OC_PACK_SIZE;
const int bias_offset_in_element = oc_idx;
KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, OW_INTERVAL,
filter_size, OC_BIG_INTERVAL, OW_INTERVAL>::
impl(dst + dst_offset_in_element,
dst_numbers_4channel_packed,
src + src_offset_in_element, ih, iw,
filter + filter_offset_in_element,
bias + bias_offset_in_element, ic, op);
}
if (ow_remain) {
const int src_offset_in_element =
(oh_idx * SH * iw + ow_end_idx) * IC_PACK_SIZE;
const int dst_offset_in_element =
oc_idx * dst_numbers_per_channel +
(oh_idx * ow + ow_end_idx) * OC_PACK_SIZE;
const int bias_offset_in_element = oc_idx;
kern_big_oc_remain(dst + dst_offset_in_element,
dst_numbers_4channel_packed,
src + src_offset_in_element, ih, iw,
filter + filter_offset_in_element,
bias + bias_offset_in_element, ic, op);
}
}
}

#if MEGDNN_AARCH64
//! oc_remain must be 4 or 8 on aarch64 and must be 0 on aarch32
if (oc_remain) {
int oc_idx = oc_end_idx;
const int filter_offset_in_element = oc_idx * ic * FH * FW;
for (int oh_idx = 0; oh_idx < oh_size; ++oh_idx) {
for (int ow_idx = 0; ow_idx < ow_end_idx; ow_idx += OW_INTERVAL) {
const int src_offset_in_element =
(oh_idx * SH * iw + ow_idx) * IC_PACK_SIZE;
const int dst_offset_in_element =
oc_idx * dst_numbers_per_channel +
(oh_idx * ow + ow_idx) * OC_PACK_SIZE;
const int bias_offset_in_element = oc_idx;
if (oc_remain == 8) {
KernNeonSdotNCHW44<
dst_type, stride, bias_mode, Op, OW_INTERVAL,
filter_size, OC_MID_INTERVAL,
OW_INTERVAL>::impl(dst + dst_offset_in_element,
dst_numbers_4channel_packed,
src + src_offset_in_element, ih,
iw,
filter +
filter_offset_in_element,
bias + bias_offset_in_element,
ic, op);
} else {
KernNeonSdotNCHW44<
dst_type, stride, bias_mode, Op, OW_INTERVAL,
filter_size, OC_SMA_INTERVAL,
OW_INTERVAL>::impl(dst + dst_offset_in_element,
dst_numbers_4channel_packed,
src + src_offset_in_element, ih,
iw,
filter +
filter_offset_in_element,
bias + bias_offset_in_element,
ic, op);
}
}
if (ow_remain) {
const int src_offset_in_element =
(oh_idx * SH * iw + ow_end_idx) * IC_PACK_SIZE;
const int dst_offset_in_element =
oc_idx * dst_numbers_per_channel +
(oh_idx * ow + ow_end_idx) * OC_PACK_SIZE;
const int bias_offset_in_element = oc_idx;
if (oc_remain == 8) {
kern_mid_oc_remain(dst + dst_offset_in_element,
dst_numbers_4channel_packed,
src + src_offset_in_element, ih, iw,
filter + filter_offset_in_element,
bias + bias_offset_in_element, ic, op);
} else {
kern_sma_oc_remain(dst + dst_offset_in_element,
dst_numbers_4channel_packed,
src + src_offset_in_element, ih, iw,
filter + filter_offset_in_element,
bias + bias_offset_in_element, ic, op);
}
}
}
}
#endif
}

#define INSTANTIATION(dst_type, stride, filter_size, bias_mode, Op) \
template void conv_direct_sdot_int8_nchw44<dst_type, stride, bias_mode, \
Op, filter_size>( \
dst_type * dst, const int oh, const int ow, const int8_t* src, \
const int ih, const int iw, const int8_t* weight, \
const int32_t* bias, const int oh_size, const int oc, \
const int ic, const Op& op);

#define FOR_OP(stride, i, bias_mode) \
INSTANTIATION(dt_int8, stride, i, bias_mode, \
TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
INSTANTIATION(dt_int32, stride, i, bias_mode, \
NoneOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
INSTANTIATION(dt_int8, stride, i, bias_mode, \
ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
INSTANTIATION(dt_int8, stride, i, bias_mode, \
HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>)

#define FOR_BIAS(stride, i) \
FOR_OP(stride, i, BiasMode::NO_BIAS) \
FOR_OP(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS)

#define FOR_FILTER(stride) \
FOR_BIAS(stride, 2) \
FOR_BIAS(stride, 3) \
FOR_BIAS(stride, 5) \
FOR_BIAS(stride, 7)

FOR_FILTER(1)

#undef FOR_FILTER
#undef FOR_BIAS
#undef FOR_OP
#undef INSTANTIATION

} // namespace direct_dotprod_nchw44
} // namespace arm_common
} // namespace megdnn
#endif
// vim: syntax=cpp.doxygen

+ 322
- 0
dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2.cpp

@@ -0,0 +1,322 @@
/**
* \file
* dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#if __ARM_FEATURE_DOTPROD

#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h"
namespace megdnn {
namespace arm_common {
namespace direct_dotprod_nchw44 {
template <typename dst_type, BiasMode bias_mode, typename Op, int ow_remain,
int filter_size, int oc_interval, int ow_interval>
struct KernNeonSdotNCHW44<dst_type, 2, bias_mode, Op, ow_remain, filter_size,
oc_interval, ow_interval> {
static void impl(dst_type* dst, const int dst_step, const int8_t* src,
const int ih, const int iw, const int8_t* filter,
const int32_t* bias, const int ic, const Op& op) {
constexpr int FH = filter_size;
constexpr int FW = filter_size;
constexpr int filter_next_row =
FW * OC_PACK_SIZE *
IC_PACK_SIZE; //! [OC/4, IC/4, FH, FW, 4OC, 4IC]

const int filter_next_4oc =
FH * FW * ic * OC_PACK_SIZE; //! [OC/4, IC/4, FH, FW, 4OC, 4IC]
const int src_next_ic = ih * iw;
const int src_next_row = iw * IC_PACK_SIZE;

constexpr int NSRC = (ow_interval * 2 + filter_size - 3) / 8 + 1;
constexpr int LOOP = oc_interval / 4;
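        //! stride 2: the packed source keeps even and odd input columns in
        //! two separate groups, so NSRC counts the registers needed per
        //! group rather than per contiguous row.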

int32x4_t res[3][ow_interval];
init_ocx_ow8<LOOP, bias_mode>(res, bias, OC_PACK_SIZE);

for (int ic_idx = 0; ic_idx < ic; ic_idx += IC_PACK_SIZE) {
const int8_t* i_src = src + ic_idx * src_next_ic;
const int8_t* i_filter = filter + ic_idx * FH * FW * OC_PACK_SIZE;
for (int fh_idx = 0; fh_idx < FH; ++fh_idx) {
int8x16_t src[2][3];
int8x16_t weight[3];
const int offset = megdnn::div_ceil(iw, 2) * IC_PACK_SIZE;
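                //! offset: distance from the even-column group to the odd
                //! group within one packed source row.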

load_helper<NSRC, 0, SIMD_LEN, 2, Vld1q_s8>(src, i_src, offset);

        //! do not reorder the switch cases to 3,2,1: it slows the kernel down.
#define CALC_PART(step) \
switch (LOOP) { \
case 1: \
weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \
filter_next_col * step); \
cal_helper<0, step % 2, step / 2, 0>(res, src, weight); \
break; \
case 2: \
weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \
filter_next_col * step); \
cal_helper<0, step % 2, step / 2, 0>(res, src, weight); \
weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 + \
filter_next_col * step); \
cal_helper<1, step % 2, step / 2, 1>(res, src, weight); \
break; \
case 3: \
weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \
filter_next_col * step); \
cal_helper<0, step % 2, step / 2, 0>(res, src, weight); \
weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 + \
filter_next_col * step); \
cal_helper<1, step % 2, step / 2, 1>(res, src, weight); \
weight[2] = vld1q_s8(i_filter + filter_next_4oc * 2 + \
filter_next_col * step); \
cal_helper<2, step % 2, step / 2, 2>(res, src, weight); \
break; \
default: \
break; \
}

switch (filter_size) {
case 2:
UNROLL_CALL_RAW(2, CALC_PART);
break;
case 3:
UNROLL_CALL_RAW(3, CALC_PART);
break;
case 5:
UNROLL_CALL_RAW(5, CALC_PART);
break;
case 7:
UNROLL_CALL_RAW(7, CALC_PART);
break;
default:
break;
}
#undef CALC_PART

i_filter += filter_next_row;
i_src += src_next_row;
}
}
store_ocx_owx_remain_static<LOOP, ow_remain, Op>(res, op, dst,
dst_step);
}
};

template <typename dst_type, int stride, BiasMode bias_mode, typename Op,
int filter_size>
void conv_direct_sdot_int8_nchw44(dst_type* dst, const int oh, const int ow,
const int8_t* src, const int ih, const int iw,
const int8_t* filter, const int32_t* bias,
const int oh_size, const int oc, const int ic,
const Op& op) {
constexpr int FH = filter_size;
constexpr int FW = filter_size;
constexpr int IC_PACK_SIZE = 4;
constexpr int OC_PACK_SIZE = 4;

#if MEGDNN_AARCH64
constexpr int OC_BIG_INTERVAL = 12;
constexpr int OC_MID_INTERVAL = 8;
constexpr int OC_SMA_INTERVAL = 4;
#else
constexpr int OC_BIG_INTERVAL = 4;
constexpr int OC_MID_INTERVAL = 4;
constexpr int OC_SMA_INTERVAL = 4;
#endif

constexpr int OW_INTERVAL = 8;
constexpr int SH = stride;

const int dst_numbers_per_channel = oh * ow;
const int ow_remain = ow % OW_INTERVAL;
const int ow_end_idx = ow - ow_remain;
const int oc_remain =
            oc % OC_BIG_INTERVAL;  //! NCHW44 keeps oc a multiple of 4, so
                                   //! oc_remain is 0, 4 or 8
const int oc_end_idx = oc - oc_remain;
const int dst_numbers_4channel_packed =
dst_numbers_per_channel * OC_PACK_SIZE;

using remain_fun = std::function<void(
dst_type * dst, const int dst_step, const int8_t* src, const int ih,
const int iw, const int8_t* filter, const int32_t* bias,
const int ic, const Op& op)>;

remain_fun kern_big_oc_remain = nullptr;
remain_fun kern_mid_oc_remain = nullptr;
remain_fun kern_sma_oc_remain = nullptr;

switch (ow_remain) {
#define cb(step) \
case step: \
kern_big_oc_remain = \
KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, step, \
filter_size, OC_BIG_INTERVAL, \
OW_INTERVAL>::impl; \
kern_mid_oc_remain = \
KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, step, \
filter_size, OC_MID_INTERVAL, \
OW_INTERVAL>::impl; \
kern_sma_oc_remain = \
KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, step, \
filter_size, OC_SMA_INTERVAL, \
OW_INTERVAL>::impl; \
break;
UNROLL_CALL_RAW(8, cb);
#undef cb
default:
megdnn_assert(0, "no remain %d for kern", ow_remain);
}

//! filter layout is [OC/4, IC/4, FH, FW, 4OC, 4IC]
    //! [oc, oh, ow] is cut into [oc/OC_INTERVAL, 1, ow/OW_INTERVAL,
    //! OW_INTERVAL, oh, OC_INTERVAL] tiles; each KernNeonSdotNCHW44 call
    //! computes one [OW_INTERVAL, 1, OC_INTERVAL] tile
for (int oc_idx = 0; oc_idx < oc_end_idx; oc_idx += OC_BIG_INTERVAL) {
const int filter_offset_in_element = oc_idx * ic * FH * FW;
for (int oh_idx = 0; oh_idx < oh_size; ++oh_idx) {
for (int ow_idx = 0; ow_idx < ow_end_idx; ow_idx += OW_INTERVAL) {
const int src_offset_in_element =
(oh_idx * SH * iw + ow_idx) * IC_PACK_SIZE;
const int dst_offset_in_element =
oc_idx * dst_numbers_per_channel +
(oh_idx * ow + ow_idx) * OC_PACK_SIZE;
const int bias_offset_in_element = oc_idx;
KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, OW_INTERVAL,
filter_size, OC_BIG_INTERVAL, OW_INTERVAL>::
impl(dst + dst_offset_in_element,
dst_numbers_4channel_packed,
src + src_offset_in_element, ih, iw,
filter + filter_offset_in_element,
bias + bias_offset_in_element, ic, op);
}
if (ow_remain) {
const int src_offset_in_element =
(oh_idx * SH * iw + ow_end_idx) * IC_PACK_SIZE;
const int dst_offset_in_element =
oc_idx * dst_numbers_per_channel +
(oh_idx * ow + ow_end_idx) * OC_PACK_SIZE;
const int bias_offset_in_element = oc_idx;
kern_big_oc_remain(dst + dst_offset_in_element,
dst_numbers_4channel_packed,
src + src_offset_in_element, ih, iw,
filter + filter_offset_in_element,
bias + bias_offset_in_element, ic, op);
}
}
}

#if MEGDNN_AARCH64
//! oc_remain must be 4 or 8 on aarch64 and must be 0 on aarch32
if (oc_remain) {
int oc_idx = oc_end_idx;
const int filter_offset_in_element = oc_idx * ic * FH * FW;
for (int oh_idx = 0; oh_idx < oh_size; ++oh_idx) {
for (int ow_idx = 0; ow_idx < ow_end_idx; ow_idx += OW_INTERVAL) {
const int src_offset_in_element =
(oh_idx * SH * iw + ow_idx) * IC_PACK_SIZE;
const int dst_offset_in_element =
oc_idx * dst_numbers_per_channel +
(oh_idx * ow + ow_idx) * OC_PACK_SIZE;
const int bias_offset_in_element = oc_idx;
if (oc_remain == 8) {
KernNeonSdotNCHW44<
dst_type, stride, bias_mode, Op, OW_INTERVAL,
filter_size, OC_MID_INTERVAL,
OW_INTERVAL>::impl(dst + dst_offset_in_element,
dst_numbers_4channel_packed,
src + src_offset_in_element, ih,
iw,
filter +
filter_offset_in_element,
bias + bias_offset_in_element,
ic, op);
} else {
KernNeonSdotNCHW44<
dst_type, stride, bias_mode, Op, OW_INTERVAL,
filter_size, OC_SMA_INTERVAL,
OW_INTERVAL>::impl(dst + dst_offset_in_element,
dst_numbers_4channel_packed,
src + src_offset_in_element, ih,
iw,
filter +
filter_offset_in_element,
bias + bias_offset_in_element,
ic, op);
}
}
if (ow_remain) {
const int src_offset_in_element =
(oh_idx * SH * iw + ow_end_idx) * IC_PACK_SIZE;
const int dst_offset_in_element =
oc_idx * dst_numbers_per_channel +
(oh_idx * ow + ow_end_idx) * OC_PACK_SIZE;
const int bias_offset_in_element = oc_idx;
if (oc_remain == 8) {
kern_mid_oc_remain(dst + dst_offset_in_element,
dst_numbers_4channel_packed,
src + src_offset_in_element, ih, iw,
filter + filter_offset_in_element,
bias + bias_offset_in_element, ic, op);
} else {
kern_sma_oc_remain(dst + dst_offset_in_element,
dst_numbers_4channel_packed,
src + src_offset_in_element, ih, iw,
filter + filter_offset_in_element,
bias + bias_offset_in_element, ic, op);
}
}
}
}
#endif
}

#define INSTANTIATION(dst_type, stride, filter_size, bias_mode, Op) \
template void conv_direct_sdot_int8_nchw44<dst_type, stride, bias_mode, \
Op, filter_size>( \
dst_type * dst, const int oh, const int ow, const int8_t* src, \
const int ih, const int iw, const int8_t* weight, \
const int32_t* bias, const int oh_size, const int oc, \
const int ic, const Op& op);

#define FOR_OP(stride, i, bias_mode) \
INSTANTIATION(dt_int8, stride, i, bias_mode, \
TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
INSTANTIATION(dt_int32, stride, i, bias_mode, \
NoneOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
INSTANTIATION(dt_int8, stride, i, bias_mode, \
ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
INSTANTIATION(dt_int8, stride, i, bias_mode, \
HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>)

#define FOR_BIAS(stride, i) \
FOR_OP(stride, i, BiasMode::NO_BIAS) \
FOR_OP(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS)

#define FOR_FILTER(stride) \
FOR_BIAS(stride, 2) \
FOR_BIAS(stride, 3) \
FOR_BIAS(stride, 5) \
FOR_BIAS(stride, 7)

FOR_FILTER(2)

#undef FOR_FILTER
#undef FOR_BIAS
#undef FOR_OP
#undef INSTANTIATION

} // namespace direct_dotprod_nchw44
} // namespace arm_common
} // namespace megdnn

#endif
// vim: syntax=cpp.doxygen

+ 448
- 0
dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s1.cpp

@@ -0,0 +1,448 @@
/**
* \file
* dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s1.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#if __ARM_FEATURE_DOTPROD
#include "src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h"
namespace megdnn {
namespace arm_common {
namespace dot_direct_nchw_nchw44 {
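
//! ShiftCalHelper<src_idx, weight_idx, c_dim, Func, ow_block, stride, ...>:
//! the specializations below handle c_dim == 2 (two 4-OC blocks per step)
//! and c_dim == 1 (a single block) for 8-wide output rows at stride 1.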

template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 8, 1, T, T2, T3, T4> {
static void impl(T& c, T2& src, T3& weight) {
#define cb(step) \
c[0][step] = Func::template impl<(src_idx + step) % 4>( \
c[0][step], weight[0][weight_idx], src[(src_idx + step) / 4]); \
c[1][step] = Func::template impl<(src_idx + step) % 4>( \
c[1][step], weight[1][weight_idx], src[(src_idx + step) / 4]);

UNROLL_CALL_RAW(8, cb);
#undef cb
}
};

template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, 1, T, T2, T3, T4> {
static void impl(T& c, T2& src, T3& weight) {
#define cb(step) \
c[0][step] = Func::template impl<(src_idx + step) % 4>( \
c[0][step], weight[0][weight_idx], src[(src_idx + step) / 4]);

UNROLL_CALL_RAW(8, cb);
#undef cb
}
};
////////////////////stride 1///////////////////
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int ow_block>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 2, oc_block, ow_block,
1> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int stride = 1;
constexpr int filter_hight = 2;
constexpr int filter_width = 4;
constexpr int weight_reg = 2;
constexpr int src_reg = 2;

constexpr int oc_step = 4;
constexpr int ic_step = 1;
constexpr int pack_iw_len = 4;
constexpr int simd_len = 16;

const int ld_bias = oc_step;
const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;

int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias);
for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
int8x16_t src[src_reg];
int8x16_t weight[c_dim][weight_reg];
// row 0
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
src, src_ptr + 0 * iw * pack_iw_len, 0);
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
weight);
// row 1
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
src, src_ptr + 1 * iw * pack_iw_len, 0);
cal_helper<0, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
weight);

src_ptr += ic_stride;
weight_ptr += filter_hight * filter_width * oc_step;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int ow_block>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 3, oc_block, ow_block,
1> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int stride = 1;
constexpr int filter_hight = 3;
constexpr int filter_width = 4;
constexpr int weight_reg = 3;
constexpr int src_reg = 2;

constexpr int oc_step = 4;
constexpr int ic_step = 1;
constexpr int pack_iw_len = 4;
constexpr int simd_len = 16;

const int ld_bias = oc_step;
const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;

int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias);
for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
int8x16_t src[src_reg];
int8x16_t weight[c_dim][weight_reg];
// row 0
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
src, src_ptr + 0 * iw * pack_iw_len, 0);
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
weight);
// row 1
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
src, src_ptr + 1 * iw * pack_iw_len, 0);
cal_helper<0, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
weight);
// row 2
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
src, src_ptr + 2 * iw * pack_iw_len, 0);
cal_helper<0, 2, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
weight);

src_ptr += ic_stride;
weight_ptr += filter_hight * filter_width * oc_step;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int ow_block>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 5, oc_block, ow_block,
1> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int stride = 1;
constexpr int filter_hight = 5;
constexpr int filter_width = 8;
constexpr int src_reg = 3;
constexpr int weight_reg = 2;

constexpr int oc_step = 4;
constexpr int ic_step = 1;
constexpr int pack_iw_len = 4;
constexpr int simd_len = 16;

const int ld_bias = oc_step;
const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;

int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias);

for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
int8x16_t src[src_reg];
int8x16_t weight[c_dim][weight_reg];

#define cb(step) \
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( \
src, src_ptr + step * iw * pack_iw_len, 0); \
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( \
weight, weight_ptr + step * 2 * simd_len, ld_weight_oc); \
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, \
weight); \
cal_helper<4, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight);

UNROLL_CALL_RAW(5, cb);
#undef cb
src_ptr += ic_stride;
weight_ptr += filter_hight * filter_width * oc_step;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int ow_block>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 7, oc_block, ow_block,
1> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int stride = 1;
constexpr int filter_hight = 7;
constexpr int filter_width = 8;
constexpr int src_reg = 3;
constexpr int weight_reg = 2;

constexpr int oc_step = 4;
constexpr int ic_step = 1;
constexpr int pack_iw_len = 4;
constexpr int simd_len = 16;

const int ld_bias = oc_step;
const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;

int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias);

for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
int8x16_t src[src_reg];
int8x16_t weight[c_dim][weight_reg];
#define cb(step) \
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( \
src, src_ptr + step * iw * pack_iw_len, 0); \
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( \
weight, weight_ptr + step * 2 * simd_len, ld_weight_oc); \
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, \
weight); \
cal_helper<4, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight);

UNROLL_CALL_RAW(7, cb);
#undef cb
src_ptr += ic_stride;
weight_ptr += filter_hight * filter_width * oc_step;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};
template <>
void pack_src_int8_nchw_nchw44_dot<1>(int8_t* sptr_base,
const int8_t* sptr_origin, const int,
const int pw, const int, const int ih,
const int iw, const int iw2,
const int pad_top, const int pad_bottom,
const int ic, const int ic_stride,
int8_t* temp_ptr) {
    static const uint8_t reorder_idx[16] = {0, 1, 2, 3, 1, 2, 3, 4,
                                            2, 3, 4, 5, 3, 4, 5, 6};
uint8x16_t tbl_idx = vld1q_u8(&reorder_idx[0]);
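    //! each 4-entry group of reorder_idx is a sliding 4-byte window starting
    //! at offsets 0..3, so vqtbl1q_s8 expands the first 7 bytes of every
    //! 16-byte load into four overlapping windows -- the pack_iw_len == 4
    //! layout the stride-1 kernels feed to the dot instructions.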

constexpr int iw_step = 16;
constexpr int pack_iw_len = 4;
const int iw_with_pad = iw + 2 * pw;
const int iw_with_pad_end = iw_with_pad / iw_step * iw_step;
rep(ic_idx, ic) {
const int8_t* sptr = sptr_origin + ic_idx * ic_stride;
memset(sptr_base, 0,
sizeof(int8_t) * iw2 * (ih + pad_top + pad_bottom) *
pack_iw_len);
sptr_base += iw2 * pad_top * pack_iw_len;
rep(ih_idx, ih) {
memset(temp_ptr, 0, iw_with_pad * sizeof(int8_t));
memcpy(temp_ptr + pw, sptr, sizeof(int8_t) * iw);
for (int iw_idx = 0; iw_idx < iw_with_pad_end; iw_idx += iw_step) {
int8x16_t src[4];
int8x16_t dst[4];
src[0] = vld1q_s8(temp_ptr + iw_idx);
src[1] = vld1q_s8(temp_ptr + iw_idx + 4);
src[2] = vld1q_s8(temp_ptr + iw_idx + 8);
src[3] = vld1q_s8(temp_ptr + iw_idx + 12);
dst[0] = vqtbl1q_s8(src[0], tbl_idx);
dst[1] = vqtbl1q_s8(src[1], tbl_idx);
dst[2] = vqtbl1q_s8(src[2], tbl_idx);
dst[3] = vqtbl1q_s8(src[3], tbl_idx);
vst1q_s8(sptr_base + iw_idx * pack_iw_len + 0, dst[0]);
vst1q_s8(sptr_base + iw_idx * pack_iw_len + 16, dst[1]);
vst1q_s8(sptr_base + iw_idx * pack_iw_len + 32, dst[2]);
vst1q_s8(sptr_base + iw_idx * pack_iw_len + 48, dst[3]);
}
for (int iw_idx = iw_with_pad_end; iw_idx < iw_with_pad; ++iw_idx) {
*(sptr_base + iw_idx * pack_iw_len + 0) =
*(temp_ptr + iw_idx + 0);
*(sptr_base + iw_idx * pack_iw_len + 1) =
*(temp_ptr + iw_idx + 1);
*(sptr_base + iw_idx * pack_iw_len + 2) =
*(temp_ptr + iw_idx + 2);
*(sptr_base + iw_idx * pack_iw_len + 3) =
*(temp_ptr + iw_idx + 3);
}
sptr_base += iw2 * pack_iw_len;
sptr += iw;
}
sptr_base += iw2 * pad_bottom * pack_iw_len;
}
}

template <BiasMode bias_mode, typename Op, int filter_size, int stride>
void conv_direct_int8_nchw_nchw44_dot(const int8_t* src, const int8_t* filter,
const int32_t* bias, int32_t* temp,
int8_t* dst, const int oc, const int ic,
const int ih, const int iw, const int oh,
const int oh_block, const int ow,
const Op& op) {
MEGDNN_MARK_USED_VAR(temp);
constexpr int fh = filter_size;
constexpr int fw = (filter_size + 3) / 4 * 4;
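    //! fw rounds the filter width up to a multiple of 4 so each filter row
    //! can be consumed by 4-element dot instructions; the tail column is
    //! expected to be zero-padded by the weight packing.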
#if MEGDNN_AARCH64
constexpr int big_oc_step = 8;
#else
constexpr int big_oc_step = 4;
#endif
constexpr int oc_step = 4;
constexpr int ih_step = 1;
constexpr int oh_step = 1;
constexpr int ow_step = 8;
constexpr int stride_h = stride;
constexpr int stride_w = stride;
constexpr int pack_iw_len = stride == 2 ? 1 : 4;

const int img_stride = oh * ow;
const int ow_end = ow / ow_step * ow_step;
const int ow_remain = ow - ow_end;
const int oc_end = oc / big_oc_step * big_oc_step;
const int oc_remain = oc - oc_end;
const int ld_dst_oc = oc_step * img_stride;

using remain_fun =
std::function<void(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic,
int ih, int iw, int ld_dst_oc, const Op& op)>;
remain_fun kern_big_oc_remain = nullptr;
remain_fun kern_small_oc_remain = nullptr;
switch (ow_remain) {
#define cb(step) \
case step: \
kern_big_oc_remain = \
KerNeonDotXXs2Nchw44Int8<bias_mode, Op, step, filter_size, \
big_oc_step, ow_step, stride>::impl; \
kern_small_oc_remain = \
KerNeonDotXXs2Nchw44Int8<bias_mode, Op, step, filter_size, \
oc_step, ow_step, stride>::impl; \
break;

        UNROLL_CALL_RAW(8, cb);
#undef cb
        default:
megdnn_assert(0, "no remain %d for kern", ow_remain);
}

for (int oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) {
const int weight_offset = oc_idx * ic * fh * fw;
for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) {
for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const int src_offset =
(oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) *
pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
KerNeonDotXXs2Nchw44Int8<bias_mode, Op, ow_step, filter_size,
big_oc_step, ow_step,
stride>::impl(src + src_offset,
filter + weight_offset,
bias + oc_idx,
dst + dst_offset, ic, ih,
iw, ld_dst_oc, op);
}
if (ow_remain > 0) {
const int src_offset =
(oh_idx * stride_h * iw + ow_end * stride_w * ih_step) *
pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
kern_big_oc_remain(src + src_offset, filter + weight_offset,
bias + oc_idx, dst + dst_offset, ic, ih, iw,
ld_dst_oc, op);
}
}
}
if (oc_remain > 0) {
int oc_idx = oc_end;
const int weight_offset = oc_idx * ic * fh * fw;
for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) {
for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const int src_offset =
(oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) *
pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
KerNeonDotXXs2Nchw44Int8<bias_mode, Op, ow_step, filter_size,
oc_step, ow_step,
stride>::impl(src + src_offset,
filter + weight_offset,
bias + oc_idx,
dst + dst_offset, ic, ih,
iw, ld_dst_oc, op);
}
if (ow_remain > 0) {
const int src_offset =
(oh_idx * stride_h * iw + ow_end * stride_w * ih_step) *
pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
kern_small_oc_remain(src + src_offset, filter + weight_offset,
bias + oc_idx, dst + dst_offset, ic, ih,
iw, ld_dst_oc, op);
}
}
}
}
#define DO_CONV_KERN_FUN(stride, filter_size, bias_mode, Op) \
template void \
conv_direct_int8_nchw_nchw44_dot<bias_mode, Op, filter_size, stride>( \
const int8_t* src, const int8_t* filter, const int32_t* bias, \
int32_t* temp, int8_t* dst, const int oc, const int ic, \
const int ih, const int iw, const int oh, const int oh_block, \
const int ow, const Op& op);

#define GET_OP_PARAM(stride, filter, bias_mode) \
DO_CONV_KERN_FUN(stride, filter, bias_mode, \
TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
DO_CONV_KERN_FUN(stride, filter, bias_mode, \
ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
DO_CONV_KERN_FUN(stride, filter, bias_mode, \
HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>)

#define GET_BIAS_MODE_PARAM(stride, filter) \
GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \
GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS)

#define DISPATCH_CONV_KERN(stride) \
GET_BIAS_MODE_PARAM(stride, 2) \
GET_BIAS_MODE_PARAM(stride, 3) \
GET_BIAS_MODE_PARAM(stride, 5) \
GET_BIAS_MODE_PARAM(stride, 7)

DISPATCH_CONV_KERN(1);

} // namespace dot_direct_nchw_nchw44
} // namespace arm_common
} // namespace megdnn
#endif
// vim: syntax=cpp.doxygen

+ 437
- 0
dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s2.cpp

@@ -0,0 +1,437 @@
/**
* \file
* dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s2.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#if __ARM_FEATURE_DOTPROD
#include "src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h"
namespace megdnn {
namespace arm_common {
namespace dot_direct_nchw_nchw44 {
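
//! stride-2 ShiftCalHelper specializations: even output columns are fed by
//! src[0] and odd ones by src[1], hence the interleaved c[.][step * 2] and
//! c[.][step * 2 + 1] accumulators below.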

template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 8, 2, T, T2, T3, T4> {
static void impl(T& c, T2& src, T3& weight) {
#define cb(step) \
c[0][step * 2] = Func::template impl<(src_idx + step) % 4>( \
c[0][step * 2], weight[0][weight_idx], \
src[0][(src_idx + step) / 4]); \
c[1][step * 2] = Func::template impl<(src_idx + step) % 4>( \
c[1][step * 2], weight[1][weight_idx], \
src[0][(src_idx + step) / 4]); \
c[0][step * 2 + 1] = Func::template impl<(src_idx + step) % 4>( \
c[0][step * 2 + 1], weight[0][weight_idx], \
src[1][(src_idx + step) / 4]); \
c[1][step * 2 + 1] = Func::template impl<(src_idx + step) % 4>( \
c[1][step * 2 + 1], weight[1][weight_idx], \
src[1][(src_idx + step) / 4]);

UNROLL_CALL_RAW(4, cb);
#undef cb
}
};

template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, 2, T, T2, T3, T4> {
static void impl(T& c, T2& src, T3& weight) {
#define cb(step) \
c[0][step * 2] = Func::template impl<(src_idx + step) % 4>( \
c[0][step * 2], weight[0][weight_idx], \
src[0][(src_idx + step) / 4]); \
c[0][step * 2 + 1] = Func::template impl<(src_idx + step) % 4>( \
c[0][step * 2 + 1], weight[0][weight_idx], \
src[1][(src_idx + step) / 4]);

UNROLL_CALL_RAW(4, cb);
#undef cb
}
};

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int ow_block>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 2, oc_block, ow_block,
2> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int stride = 2;
constexpr int filter_hight = 2;
constexpr int filter_width = 4;
constexpr int weight_reg = 1;
constexpr int src_reg = 1;

constexpr int oc_step = 4;
constexpr int ic_step = 1;
constexpr int pack_iw_len = 1;
constexpr int simd_len = 16;

const int ld_bias = oc_step;
const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;

int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias);
for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
int8x16_t src[2][src_reg];
int8x16_t weight[c_dim][weight_reg];
// row 0
load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>(
src, src_ptr + 0 * iw, stride);
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
weight);
// row 1
load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>(
src, src_ptr + 1 * iw, stride);
load_helper<weight_reg, 1 * simd_len, simd_len, c_dim, Vld1q_s8>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
weight);

src_ptr += ic_stride;
weight_ptr += filter_hight * filter_width * oc_step;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int ow_block>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 3, oc_block, ow_block,
2> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int stride = 2;
constexpr int filter_hight = 3;
constexpr int filter_width = 4;
constexpr int weight_reg = 1;
constexpr int src_reg = 1;

constexpr int oc_step = 4;
constexpr int ic_step = 1;
constexpr int pack_iw_len = 1;
constexpr int simd_len = 16;

const int ld_bias = oc_step;
const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;

int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias);
for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
int8x16_t src[2][src_reg];
int8x16_t weight[c_dim][weight_reg];
// row 0
load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>(
src, src_ptr + 0 * iw, stride);
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
weight);
// row 1
load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>(
src, src_ptr + 1 * iw, stride);
load_helper<weight_reg, 1 * simd_len, simd_len, c_dim, Vld1q_s8>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
weight);
// row 2
load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>(
src, src_ptr + 2 * iw, stride);
load_helper<weight_reg, 2 * simd_len, simd_len, c_dim, Vld1q_s8>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
weight);

src_ptr += ic_stride;
weight_ptr += filter_hight * filter_width * oc_step;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int ow_block>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 5, oc_block, ow_block,
2> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int stride = 2;
constexpr int filter_hight = 5;
constexpr int filter_width = 8;
constexpr int src_reg = 2;
constexpr int weight_reg = 2;

constexpr int oc_step = 4;
constexpr int ic_step = 1;
constexpr int pack_iw_len = 1;
constexpr int simd_len = 16;

const int ld_bias = oc_step;
const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;

int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias);

for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
int8x16_t src[2][src_reg];
int8x16_t weight[c_dim][weight_reg];
#define cb(step) \
load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>(src, src_ptr + step * iw, \
stride); \
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( \
weight, weight_ptr + step * 2 * simd_len, ld_weight_oc); \
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, \
weight); \
cal_helper<1, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight);
UNROLL_CALL_RAW(5, cb);
#undef cb
src_ptr += ic_stride;
            weight_ptr += filter_hight * filter_width * oc_step;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};

/**
* oc = 8, ow = 8
 * dots 4 elements at a time; the last filter column is zero-padded so every
 * filter row is handled by two dot products, with the row laid out as below
* --------------------------
* |x, x, x, x,| x, x, x, 0 |
* --------------------------
**/
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int ow_block>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 7, oc_block, ow_block,
2> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int stride = 2;
constexpr int filter_hight = 7;
constexpr int filter_width = 8;
constexpr int src_reg = 2;
constexpr int weight_reg = 2;

constexpr int oc_step = 4;
constexpr int ic_step = 1;
constexpr int pack_iw_len = 1;
constexpr int simd_len = 16;

const int ld_bias = oc_step;
const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;

int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias);

for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
int8x16_t src[2][src_reg];
int8x16_t weight[c_dim][weight_reg];
#define cb(step) \
load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>(src, src_ptr + step * iw, \
stride); \
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( \
weight, weight_ptr + step * 2 * simd_len, ld_weight_oc); \
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, \
weight); \
cal_helper<1, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight);
UNROLL_CALL_RAW(7, cb);
#undef cb
src_ptr += ic_stride;
            weight_ptr += filter_hight * filter_width * oc_step;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};

template <>
void pack_src_int8_nchw_nchw44_dot<2>(
int8_t* sptr_base, const int8_t* sptr_origin, const int, const int pw,
const int, const int ih, const int iw, const int iw2, const int pad_top,
const int pad_bottom, const int ic, const int ic_stride, int8_t*) {
constexpr int ic_step = 1;
rep_step(ic_idx, ic, ic_step) {
const int8_t* sptr = sptr_origin + ic_idx * ic_stride;
memset(sptr_base, 0,
sizeof(int8_t) * ic_step * iw2 * (ih + pad_top + pad_bottom));
sptr_base += iw2 * pad_top * ic_step;
rep(ih_idx, ih) {
memcpy(sptr_base + pw * ic_step, sptr,
sizeof(int8_t) * iw * ic_step);
sptr_base += iw2 * ic_step;
sptr += iw * ic_step;
}
sptr_base += iw2 * pad_bottom * ic_step;
}
}

template <BiasMode bias_mode, typename Op, int filter_size, int stride>
void conv_direct_int8_nchw_nchw44_dot(const int8_t* src, const int8_t* filter,
const int32_t* bias, int32_t* temp,
int8_t* dst, const int oc, const int ic,
const int ih, const int iw, const int oh,
const int oh_block, const int ow,
const Op& op) {
MEGDNN_MARK_USED_VAR(temp);
constexpr int fh = filter_size;
constexpr int fw = (filter_size + 3) / 4 * 4;
#if MEGDNN_AARCH64
constexpr int big_oc_step = 8;
#else
constexpr int big_oc_step = 4;
#endif
constexpr int oc_step = 4;
constexpr int ih_step = 1;
constexpr int oh_step = 1;
constexpr int ow_step = 8;
constexpr int stride_h = stride;
constexpr int stride_w = stride;
constexpr int pack_iw_len = stride == 2 ? 1 : 4;

const int img_stride = oh * ow;
const int ow_end = ow / ow_step * ow_step;
const int ow_remain = ow - ow_end;
const int oc_end = oc / big_oc_step * big_oc_step;
const int oc_remain = oc - oc_end;
const int ld_dst_oc = oc_step * img_stride;

using remain_fun =
std::function<void(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic,
int ih, int iw, int ld_dst_oc, const Op& op)>;
remain_fun kern_big_oc_remain = nullptr;
remain_fun kern_small_oc_remain = nullptr;
switch (ow_remain) {
#define cb(step) \
case step: \
kern_big_oc_remain = \
KerNeonDotXXs2Nchw44Int8<bias_mode, Op, step, filter_size, \
big_oc_step, ow_step, stride>::impl; \
kern_small_oc_remain = \
KerNeonDotXXs2Nchw44Int8<bias_mode, Op, step, filter_size, \
oc_step, ow_step, stride>::impl; \
break;

        UNROLL_CALL_RAW(8, cb);
#undef cb
        default:
megdnn_assert(0, "no remain %d for kern", ow_remain);
}

for (int oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) {
const int weight_offset = oc_idx * ic * fh * fw;
for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) {
for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const int src_offset =
(oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) *
pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
KerNeonDotXXs2Nchw44Int8<bias_mode, Op, ow_step, filter_size,
big_oc_step, ow_step,
stride>::impl(src + src_offset,
filter + weight_offset,
bias + oc_idx,
dst + dst_offset, ic, ih,
iw, ld_dst_oc, op);
}
if (ow_remain > 0) {
const int src_offset =
(oh_idx * stride_h * iw + ow_end * stride_w * ih_step) *
pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
kern_big_oc_remain(src + src_offset, filter + weight_offset,
bias + oc_idx, dst + dst_offset, ic, ih, iw,
ld_dst_oc, op);
}
}
}
if (oc_remain > 0) {
int oc_idx = oc_end;
const int weight_offset = oc_idx * ic * fh * fw;
for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) {
for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const int src_offset =
(oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) *
pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
KerNeonDotXXs2Nchw44Int8<bias_mode, Op, ow_step, filter_size,
oc_step, ow_step,
stride>::impl(src + src_offset,
filter + weight_offset,
bias + oc_idx,
dst + dst_offset, ic, ih,
iw, ld_dst_oc, op);
}
if (ow_remain > 0) {
const int src_offset =
(oh_idx * stride_h * iw + ow_end * stride_w * ih_step) *
pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
kern_small_oc_remain(src + src_offset, filter + weight_offset,
bias + oc_idx, dst + dst_offset, ic, ih,
iw, ld_dst_oc, op);
}
}
}
}

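//! The macros below emit explicit instantiations so every (stride,
//! filter_size, bias_mode, Op) kernel is compiled in this translation unit
//! alone, which is how this split cuts compile time. For example,
//! GET_BIAS_MODE_PARAM(2, 3) expands in part to:
//!     template void conv_direct_int8_nchw_nchw44_dot<
//!             BiasMode::NO_BIAS, TypeCvtOp<dt_qint32, dt_qint8>, 3, 2>(
//!             const int8_t* src, const int8_t* filter, const int32_t* bias,
//!             int32_t* temp, int8_t* dst, const int oc, const int ic,
//!             const int ih, const int iw, const int oh, const int oh_block,
//!             const int ow, const TypeCvtOp<dt_qint32, dt_qint8>& op);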
#define DO_CONV_KERN_FUN(stride, filter_size, bias_mode, Op) \
template void \
conv_direct_int8_nchw_nchw44_dot<bias_mode, Op, filter_size, stride>( \
const int8_t* src, const int8_t* filter, const int32_t* bias, \
int32_t* temp, int8_t* dst, const int oc, const int ic, \
const int ih, const int iw, const int oh, const int oh_block, \
const int ow, const Op& op);

#define GET_OP_PARAM(stride, filter, bias_mode) \
DO_CONV_KERN_FUN(stride, filter, bias_mode, \
TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
DO_CONV_KERN_FUN(stride, filter, bias_mode, \
ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
DO_CONV_KERN_FUN(stride, filter, bias_mode, \
HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>)

#define GET_BIAS_MODE_PARAM(stride, filter) \
GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \
GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS)

#define DISPATCH_CONV_KERN(stride) \
GET_BIAS_MODE_PARAM(stride, 2) \
GET_BIAS_MODE_PARAM(stride, 3) \
GET_BIAS_MODE_PARAM(stride, 5) \
GET_BIAS_MODE_PARAM(stride, 7)

DISPATCH_CONV_KERN(2);

} // namespace dot_direct_nchw_nchw44
} // namespace arm_common
} // namespace megdnn

#endif
// vim: syntax=cpp.doxygen

+ 743 - 0   dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw44_s1.cpp

@@ -0,0 +1,743 @@
/**
* \file
* dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw44_s1.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/int8/direct.h"
#include "src/arm_common/conv_bias/int8/direct_nchw44_kern.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h"

namespace megdnn {
namespace arm_common {
namespace {
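//! 2x2 stride-1 kernel on an oc8 x ow8 register tile: c[2][8] holds two oc4
//! groups of accumulators, and a sliding window of nine packed src vectors
//! covers the 2-tap horizontal filter for all eight outputs, so each src
//! vector is loaded once and reused for two neighbouring outputs.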
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
int c_dim, typename DstType>
static void ker_neon_dirctconv_2x2s1_oc8_ow8(const int8_t* src_ptr,
const int8_t* weight_ptr,
const int32_t* bias_ptr,
DstType* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc,
const Op& op) {
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int ic_step = 4;
constexpr int oc_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
constexpr int pack_iw_len = 4;

const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc4 = oc_step * fh * fw * ic;

int32x4_t c[2][8];
int8x16_t weight[2][2];
int8x16_t src[8 + 1];
int16x8_t temp_c[4];
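    //! temp_c are scratch int16 registers for the widening multiplies inside
    //! vdotq_s32_h; rotating across four of them keeps consecutive
    //! multiply-accumulate chains independent, which helps the core overlap
    //! them.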

init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);

for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
fh_idx * iw * ic_step * pack_iw_len;

src[0] = vld1q_s8(src_ic_0_3);
src[1] = vld1q_s8((src_ic_0_3 + 16));
src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));

// oc == 0
const int8_t* read_weight_ptr =
weight_ptr + fh_idx * fw * ld_weight_ic4;

weight[0][0] = vld1q_s8(read_weight_ptr);
weight[0][1] = vld1q_s8(read_weight_ptr + 16);
weight[1][0] = vld1q_s8(read_weight_ptr + ld_weight_oc4);
weight[1][1] = vld1q_s8(read_weight_ptr + ld_weight_oc4 + 16);

c[0][0] = vdotq_s32_h(weight[0][0], src[0], c[0][0], temp_c[0]);
c[1][0] = vdotq_s32_h(weight[1][0], src[0], c[1][0], temp_c[1]);
c[0][1] = vdotq_s32_h(weight[0][0], src[1], c[0][1], temp_c[2]);
c[1][1] = vdotq_s32_h(weight[1][0], src[1], c[1][1], temp_c[3]);
c[0][0] = vdotq_s32_h(weight[0][1], src[1], c[0][0], temp_c[0]);
c[1][0] = vdotq_s32_h(weight[1][1], src[1], c[1][0], temp_c[1]);
c[0][1] = vdotq_s32_h(weight[0][1], src[2], c[0][1], temp_c[2]);
c[1][1] = vdotq_s32_h(weight[1][1], src[2], c[1][1], temp_c[3]);

c[0][2] = vdotq_s32_h(weight[0][0], src[2], c[0][2], temp_c[0]);
c[1][2] = vdotq_s32_h(weight[1][0], src[2], c[1][2], temp_c[1]);
c[0][3] = vdotq_s32_h(weight[0][0], src[3], c[0][3], temp_c[2]);
c[1][3] = vdotq_s32_h(weight[1][0], src[3], c[1][3], temp_c[3]);
c[0][2] = vdotq_s32_h(weight[0][1], src[3], c[0][2], temp_c[0]);
c[1][2] = vdotq_s32_h(weight[1][1], src[3], c[1][2], temp_c[1]);
c[0][3] = vdotq_s32_h(weight[0][1], src[4], c[0][3], temp_c[2]);
c[1][3] = vdotq_s32_h(weight[1][1], src[4], c[1][3], temp_c[3]);

c[0][4] = vdotq_s32_h(weight[0][0], src[4], c[0][4], temp_c[0]);
c[1][4] = vdotq_s32_h(weight[1][0], src[4], c[1][4], temp_c[1]);
c[0][5] = vdotq_s32_h(weight[0][0], src[5], c[0][5], temp_c[2]);
c[1][5] = vdotq_s32_h(weight[1][0], src[5], c[1][5], temp_c[3]);
c[0][4] = vdotq_s32_h(weight[0][1], src[5], c[0][4], temp_c[0]);
c[1][4] = vdotq_s32_h(weight[1][1], src[5], c[1][4], temp_c[1]);
c[0][5] = vdotq_s32_h(weight[0][1], src[6], c[0][5], temp_c[2]);
c[1][5] = vdotq_s32_h(weight[1][1], src[6], c[1][5], temp_c[3]);

c[0][6] = vdotq_s32_h(weight[0][0], src[6], c[0][6], temp_c[0]);
c[1][6] = vdotq_s32_h(weight[1][0], src[6], c[1][6], temp_c[1]);
c[0][7] = vdotq_s32_h(weight[0][0], src[7], c[0][7], temp_c[2]);
c[1][7] = vdotq_s32_h(weight[1][0], src[7], c[1][7], temp_c[3]);
c[0][6] = vdotq_s32_h(weight[0][1], src[7], c[0][6], temp_c[0]);
c[1][6] = vdotq_s32_h(weight[1][1], src[7], c[1][6], temp_c[1]);
c[0][7] = vdotq_s32_h(weight[0][1], src[8], c[0][7], temp_c[2]);
c[1][7] = vdotq_s32_h(weight[1][1], src[8], c[1][7], temp_c[3]);
}
weight_ptr += fh * fw * ld_weight_ic4;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, DstType*>(
c, op, dst_ptr, ld_dst_oc);
}

template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
int c_dim, typename DstType>
static void ker_neon_dirctconv_2x2s1_oc4_ow8(const int8_t* src_ptr,
const int8_t* weight_ptr,
const int32_t* bias_ptr,
DstType* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc,
const Op& op) {
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int oc_step = 4;
constexpr int ic_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
constexpr int pack_iw_len = 4;

const int ic_stride = ih * iw * pack_iw_len;

int32x4_t c[1][8];
int8x16_t weight[1][2];
int8x16_t src[8 + 1];
int16x8_t temp_c[2];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);

for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
fh_idx * iw * ic_step * pack_iw_len;

src[0] = vld1q_s8(src_ic_0_3);
src[1] = vld1q_s8((src_ic_0_3 + 16));
src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));

// oc == 0
const int8_t* read_weight_ptr =
weight_ptr + fh_idx * fw * ld_weight_ic4;

weight[0][0] = vld1q_s8(read_weight_ptr);
weight[0][1] = vld1q_s8(read_weight_ptr + 16);

c[0][0] = vdotq_s32_h(weight[0][0], src[0], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[0][0], src[1], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[0][1], src[1], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[0][1], src[2], c[0][1], temp_c[1]);

c[0][2] = vdotq_s32_h(weight[0][0], src[2], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[0][0], src[3], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[0][1], src[3], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[0][1], src[4], c[0][3], temp_c[1]);

c[0][4] = vdotq_s32_h(weight[0][0], src[4], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[0][0], src[5], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[0][1], src[5], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[0][1], src[6], c[0][5], temp_c[1]);

c[0][6] = vdotq_s32_h(weight[0][0], src[6], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[0][0], src[7], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[0][1], src[7], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[0][1], src[8], c[0][7], temp_c[1]);
}
weight_ptr += fh * fw * ld_weight_ic4;
}

store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, DstType*>(
c, op, dst_ptr, ld_dst_oc);
}

template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
int c_dim, typename DstType>
struct KerNeonDirectStride1Int8 {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih,
int iw, const Op& op, int ld_dst_oc);
};
/**
dot-like impl: dot 4 ic into 1 oc, accumulate into c<ow, oc>
example: (format like weight<oc, ic>)
packed weight
low 64 bit <0, 0> <0, 1> <1, 2> <1, 3> | <2, 0> <2, 1> <3, 2> <3, 3>
---------------------------------------------------------------------
high 64 bit <0, 3> <0, 2> <1, 1> <1, 0> | <2, 3> <2, 2> <3, 1> <3, 0>
dot: (<0, 0> + <0, 3>) + (<0, 1> + <0, 2>) -> <0>
**/
//! TODO: can try oh = 2 impl, oc = 8 impl
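/**
 * For reference, a minimal sketch of the vdotq_s32_h helper used throughout
 * this file (the real definition lives in marm_neon.h; the exact signature
 * here is assumed). It emulates a 4-wide int8 dot product on targets without
 * the dot-product extension: widen with vmull/vmlal, then pairwise-accumulate
 * the int16 products into int32 lanes.
 *
 *     static inline int32x4_t vdotq_s32_h(int8x16_t a, int8x16_t b,
 *                                         int32x4_t c, int16x8_t temp) {
 *         temp = vmull_s8(vget_low_s8(a), vget_low_s8(b));
 *         temp = vmlal_s8(temp, vget_high_s8(a), vget_high_s8(b));
 *         return vpadalq_s16(c, temp);
 *     }
 *
 * Lane i of temp holds low_i * low_i' + high_i * high_i', and vpadalq_s16
 * folds adjacent lanes, so each int32 lane collects four products; with the
 * mirrored weight packing above this is exactly
 * (<0, 0> + <0, 3>) + (<0, 1> + <0, 2>) -> <0>.
 */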
template <BiasMode bias_mode, typename Op, int remain_w, int c_dim,
typename DstType>
struct KerNeonDirectStride1Int8<bias_mode, Op, remain_w, 3, c_dim, DstType> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih,
int iw, const Op& op, int ld_dst_oc) {
constexpr int filter_size = 3;
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int oc_step = 4;
constexpr int ic_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
constexpr int pack_iw_len = 4;

const int ic_stride = ih * iw * pack_iw_len;

int32x4_t c[c_dim][8];
int8x16_t weight[3];
int8x16_t src[8 + 2];
int16x8_t temp_c[2];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);

for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
fh_idx * iw * ic_step * pack_iw_len;

src[0] = vld1q_s8(src_ic_0_3);
src[1] = vld1q_s8((src_ic_0_3 + 16));
src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
src[9] = vld1q_s8((src_ic_0_3 + 9 * 16));

// oc == 0
const int8_t* read_weight_ptr =
weight_ptr + fh_idx * fw * ld_weight_ic4;

weight[0] = vld1q_s8(read_weight_ptr);
weight[1] = vld1q_s8(read_weight_ptr + 16);
weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);

c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[0], src[1], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[1], src[2], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[2], src[2], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[2], src[3], c[0][1], temp_c[1]);

c[0][2] = vdotq_s32_h(weight[0], src[2], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[0], src[3], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[1], src[3], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[1], src[4], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[2], src[4], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[2], src[5], c[0][3], temp_c[1]);

c[0][4] = vdotq_s32_h(weight[0], src[4], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[0], src[5], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[1], src[5], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[1], src[6], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[2], src[6], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[2], src[7], c[0][5], temp_c[1]);

c[0][6] = vdotq_s32_h(weight[0], src[6], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[0], src[7], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[1], src[7], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[1], src[8], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[2], src[8], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[2], src[9], c[0][7], temp_c[1]);
}
weight_ptr += fh * fw * ld_weight_ic4;
}
store_ocx_ow8_remain_static_dt<1, remain_w, Op, DstType*>(
c, op, dst_ptr, ld_dst_oc);
}
};

template <BiasMode bias_mode, typename Op, int remain_w, int c_dim,
typename DstType>
struct KerNeonDirectStride1Int8<bias_mode, Op, remain_w, 5, c_dim, DstType> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih,
int iw, const Op& op, int ld_dst_oc) {
constexpr int filter_size = 5;
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int oc_step = 4;
constexpr int ic_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
constexpr int pack_iw_len = 4;

const int ic_stride = ih * iw * pack_iw_len;

int32x4_t c[c_dim][8];
int8x16_t weight[5];
int8x16_t src[8 + 2];
int16x8_t temp_c[2];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);

for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
fh_idx * iw * ic_step * pack_iw_len;

src[0] = vld1q_s8(src_ic_0_3);
src[1] = vld1q_s8((src_ic_0_3 + 16));
src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
src[9] = vld1q_s8((src_ic_0_3 + 9 * 16));

// oc == 0
const int8_t* read_weight_ptr =
weight_ptr + fh_idx * fw * ld_weight_ic4;

weight[0] = vld1q_s8(read_weight_ptr);
weight[1] = vld1q_s8(read_weight_ptr + 16);
weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
weight[3] = vld1q_s8(read_weight_ptr + 3 * 16);
weight[4] = vld1q_s8(read_weight_ptr + 4 * 16);

c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[0], src[1], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[1], src[2], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[2], src[2], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[2], src[3], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[3], src[3], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[3], src[4], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[4], src[4], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[4], src[5], c[0][1], temp_c[1]);

c[0][2] = vdotq_s32_h(weight[0], src[2], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[0], src[3], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[1], src[3], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[1], src[4], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[2], src[4], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[2], src[5], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[3], src[5], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[3], src[6], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[4], src[6], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[4], src[7], c[0][3], temp_c[1]);

c[0][4] = vdotq_s32_h(weight[0], src[4], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[0], src[5], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[1], src[5], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[1], src[6], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[2], src[6], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[2], src[7], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[3], src[7], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[3], src[8], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[4], src[8], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[4], src[9], c[0][5], temp_c[1]);

src[0] = vld1q_s8(src_ic_0_3 + 10 * 16);
src[1] = vld1q_s8((src_ic_0_3 + 11 * 16));

c[0][6] = vdotq_s32_h(weight[0], src[6], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[0], src[7], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[1], src[7], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[1], src[8], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[2], src[8], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[2], src[9], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[3], src[9], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[3], src[0], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[4], src[0], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[4], src[1], c[0][7], temp_c[1]);
}
weight_ptr += fh * fw * ld_weight_ic4;
}
store_ocx_ow8_remain_static_dt<1, remain_w, Op, DstType*>(
c, op, dst_ptr, ld_dst_oc);
}
};

template <BiasMode bias_mode, typename Op, int remain_w, int c_dim,
typename DstType>
struct KerNeonDirectStride1Int8<bias_mode, Op, remain_w, 7, c_dim, DstType> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih,
int iw, const Op& op, int ld_dst_oc) {
constexpr int filter_size = 7;
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int oc_step = 4;
constexpr int ic_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
constexpr int pack_iw_len = 4;

const int ic_stride = ih * iw * pack_iw_len;

int32x4_t c[c_dim][8];
int8x16_t weight[7];
int8x16_t src[8 + 2];
int16x8_t temp_c[2];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);

for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
fh_idx * iw * ic_step * pack_iw_len;

src[0] = vld1q_s8(src_ic_0_3);
src[1] = vld1q_s8((src_ic_0_3 + 16));
src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
src[9] = vld1q_s8((src_ic_0_3 + 9 * 16));

// oc == 0
const int8_t* read_weight_ptr =
weight_ptr + fh_idx * fw * ld_weight_ic4;

weight[0] = vld1q_s8(read_weight_ptr);
weight[1] = vld1q_s8(read_weight_ptr + 16);
weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
weight[3] = vld1q_s8(read_weight_ptr + 3 * 16);
weight[4] = vld1q_s8(read_weight_ptr + 4 * 16);
weight[5] = vld1q_s8(read_weight_ptr + 5 * 16);
weight[6] = vld1q_s8(read_weight_ptr + 6 * 16);

c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[0], src[1], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[1], src[2], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[2], src[2], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[2], src[3], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[3], src[3], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[3], src[4], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[4], src[4], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[4], src[5], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[5], src[5], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[5], src[6], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[6], src[6], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[6], src[7], c[0][1], temp_c[1]);

c[0][2] = vdotq_s32_h(weight[0], src[2], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[0], src[3], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[1], src[3], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[1], src[4], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[2], src[4], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[2], src[5], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[3], src[5], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[3], src[6], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[4], src[6], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[4], src[7], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[5], src[7], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[5], src[8], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[6], src[8], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[6], src[9], c[0][3], temp_c[1]);

src[0] = vld1q_s8(src_ic_0_3 + 10 * 16);
src[1] = vld1q_s8((src_ic_0_3 + 11 * 16));

c[0][4] = vdotq_s32_h(weight[0], src[4], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[0], src[5], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[1], src[5], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[1], src[6], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[2], src[6], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[2], src[7], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[3], src[7], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[3], src[8], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[4], src[8], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[4], src[9], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[5], src[9], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[5], src[0], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[6], src[0], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[6], src[1], c[0][5], temp_c[1]);

src[2] = vld1q_s8(src_ic_0_3 + 12 * 16);
src[3] = vld1q_s8((src_ic_0_3 + 13 * 16));

c[0][6] = vdotq_s32_h(weight[0], src[6], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[0], src[7], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[1], src[7], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[1], src[8], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[2], src[8], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[2], src[9], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[3], src[9], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[3], src[0], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[4], src[0], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[4], src[1], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[5], src[1], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[5], src[2], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[6], src[2], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[6], src[3], c[0][7], temp_c[1]);
}
weight_ptr += fh * fw * ld_weight_ic4;
}
store_ocx_ow8_remain_static_dt<1, remain_w, Op, DstType*>(
c, op, dst_ptr, ld_dst_oc);
}
};

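//! Driver for the hand-written 2x2 stride-1 kernels: output channels are
//! walked in blocks of eight (two oc4 groups per call) with a four-channel
//! kernel for the tail, and the output-width tail reuses the
//! remain_w-specialized kernels picked by the switch below.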
template <BiasMode bias_mode, typename Op, typename DstType>
void conv_direct_stride1_2x2_int8_nchw44(const int8_t* src,
const int8_t* filter,
const int32_t* bias, int32_t* temp,
DstType* dst, const size_t oc,
const size_t ic, const size_t ih,
const size_t iw, const size_t oh,
const size_t ow, const Op& op) {
MEGDNN_MARK_USED_VAR(temp);
constexpr size_t filter_size = 2;
constexpr size_t fh = filter_size;
constexpr size_t fw = filter_size;
constexpr size_t ic_step = 4;
constexpr size_t oc_step = 4;
constexpr size_t big_oc_step = 8;
constexpr size_t oh_step = 1;
constexpr size_t ow_step = 8;
constexpr int pack_iw_len = 4;

const size_t img_stride = oh * ow;
const size_t ow_end = ow / ow_step * ow_step;
const size_t ow_remain = ow - ow_end;
const size_t oc_end = oc / big_oc_step * big_oc_step;
const size_t oc_remain = oc - oc_end;
const int ld_oc = oh * ow * oc_step;

using remain_fun = std::function<void(
const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih, int iw,
int ld_dst_oc, const Op& op)>;
remain_fun kern_big_oc_remain = nullptr;
remain_fun kern_small_oc_remain = nullptr;

switch (ow_remain) {
#define cb(step) \
case step: \
kern_big_oc_remain = \
ker_neon_dirctconv_2x2s1_oc8_ow8<bias_mode, Op, step, \
filter_size, 2, DstType>; \
kern_small_oc_remain = \
ker_neon_dirctconv_2x2s1_oc4_ow8<bias_mode, Op, step, \
filter_size, 1, DstType>; \
break;

UNROLL_CALL_RAW(8, cb);
default:
megdnn_assert(0, "no remain %zu for kern", ow_remain);
}
#undef cb
for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) {
const size_t weight_offset = oc_idx * ic * fh * fw;
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const size_t src_offset =
(oh_idx * iw + ow_idx) * ic_step * pack_iw_len;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
ker_neon_dirctconv_2x2s1_oc8_ow8<bias_mode, Op, 0, filter_size,
2, DstType>(
src + src_offset, filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw, ld_oc, op);
}
if (ow_remain > 0) {
const size_t src_offset =
(oh_idx * iw + ow_end) * ic_step * pack_iw_len;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
kern_big_oc_remain(src + src_offset, filter + weight_offset,
bias + oc_idx, dst + dst_offset, ic, ih, iw,
ld_oc, op);
}
}
}
if (oc_remain > 0) {
const size_t oc_idx = oc_end;
const size_t weight_offset = oc_idx * ic * fh * fw;
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const size_t src_offset =
(oh_idx * iw + ow_idx) * ic_step * pack_iw_len;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
ker_neon_dirctconv_2x2s1_oc4_ow8<bias_mode, Op, 0, filter_size,
1, DstType>(
src + src_offset, filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw, ld_oc, op);
}
if (ow_remain > 0) {
const size_t src_offset =
(oh_idx * iw + ow_end) * ic_step * pack_iw_len;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
kern_small_oc_remain(src + src_offset, filter + weight_offset,
bias + oc_idx, dst + dst_offset, ic, ih,
iw, ld_oc, op);
}
}
}
}
template <BiasMode bias_mode, typename Op, int filter_size, typename DstType>
void conv_direct_stride1_int8_nchw44_kern(const int8_t* src,
const int8_t* filter,
const int32_t* bias, int32_t* temp,
DstType* dst, const size_t oc,
const size_t ic, const size_t ih,
const size_t iw, const size_t oh,
const size_t ow, const Op& op) {
MEGDNN_MARK_USED_VAR(temp);
constexpr size_t fh = filter_size;
constexpr size_t fw = filter_size;
constexpr size_t ic_step = 4;
constexpr size_t oc_step = 4;
constexpr size_t oh_step = 1;
constexpr size_t ow_step = 8;
constexpr int pack_iw_len = 4;

const size_t img_stride = oh * ow;
const int ld_dst_oc = oh * ow * oc_step;
const size_t ow_end = ow / ow_step * ow_step;
const size_t ow_remain = ow - ow_end;

using remain_fun = std::function<void(
const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih, int iw,
const Op& op, int ld_dst_oc)>;

remain_fun kern_small_oc_remain = nullptr;
switch (ow_remain) {
#define cb(step) \
case step: \
kern_small_oc_remain = \
KerNeonDirectStride1Int8<bias_mode, Op, step, filter_size, 1, \
DstType>::impl; \
break;

UNROLL_CALL_RAW(8, cb);
default:
megdnn_assert(0, "no remain %zu for kern", ow_remain);
}
#undef cb

for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) {
const size_t weight_offset = oc_idx * ic * fh * fw;
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const size_t src_offset =
(oh_idx * iw + ow_idx) * ic_step * pack_iw_len;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
KerNeonDirectStride1Int8<bias_mode, Op, ow_step, filter_size, 1,
DstType>::impl(src + src_offset,
filter + weight_offset,
bias + oc_idx,
dst + dst_offset, ic,
ih, iw, op, ld_dst_oc);
}
if (ow_remain > 0) {
const size_t src_offset =
(oh_idx * iw + ow_end) * ic_step * pack_iw_len;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
kern_small_oc_remain(src + src_offset, filter + weight_offset,
bias + oc_idx, dst + dst_offset, ic, ih,
iw, op, ld_dst_oc);
}
}
}
}
} // namespace

namespace int8_direct_nchw44 {
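//! ConvDirectInt8Nchw44Choose dispatches on (filter_size, stride) at compile
//! time: 2x2 takes the dedicated oc8/oc4 path above, all other filter sizes
//! take the generic stride-1 kernels. A hypothetical call site (argument
//! names assumed) would look like:
//!     ConvDirectInt8Nchw44Choose<BiasMode::BROADCAST_CHANNEL_BIAS,
//!                                ReluOp<dt_qint32, dt_qint8>, 3, dt_qint8,
//!                                1>::impl(src, filter, bias, temp, dst, oc,
//!                                         ic, ih, iw, oh, ow, op);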
template <BiasMode bias_mode, typename Op, int filter_size, typename DstType>
struct ConvDirectInt8Nchw44Choose<bias_mode, Op, filter_size, DstType, 1> {
static void impl(const int8_t* src, const int8_t* filter,
const int32_t* bias, int32_t* temp, DstType* dst,
const size_t oc, const size_t ic, const size_t ih,
const size_t iw, const size_t oh, const size_t ow,
const Op& op) {
conv_direct_stride1_int8_nchw44_kern<bias_mode, Op, filter_size,
DstType>(
src, filter, bias, temp, dst, oc, ic, ih, iw, oh, ow, op);
}
};

template <BiasMode bias_mode, typename Op, typename DstType>
struct ConvDirectInt8Nchw44Choose<bias_mode, Op, 2, DstType, 1> {
static void impl(const int8_t* src, const int8_t* filter,
const int32_t* bias, int32_t* temp, DstType* dst,
const size_t oc, const size_t ic, const size_t ih,
const size_t iw, const size_t oh, const size_t ow,
const Op& op) {
conv_direct_stride1_2x2_int8_nchw44<bias_mode, Op, DstType>(
src, filter, bias, temp, dst, oc, ic, ih, iw, oh, ow, op);
}
};

#define DO_CONV_KERN_FUN(stride, DstType, filter_size, bias_mode, Op) \
template struct ConvDirectInt8Nchw44Choose<bias_mode, Op, filter_size, \
DstType, stride>;

#define GET_OP_PARAM(stride, filter, bias_mode)                             \
    DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode,                   \
                     TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>)            \
    DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode,                   \
                     ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>)               \
    DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode,                   \
                     HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>)             \
    DO_CONV_KERN_FUN(stride, dt_int32, filter, bias_mode, NoneOp<dt_int32>)

#define GET_BIAS_MODE_PARAM(stride, filter) \
GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \
GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS)

#define DISPATCH_CONV_KERN(stride) \
GET_BIAS_MODE_PARAM(stride, 2) \
GET_BIAS_MODE_PARAM(stride, 3) \
GET_BIAS_MODE_PARAM(stride, 5) \
GET_BIAS_MODE_PARAM(stride, 7)

DISPATCH_CONV_KERN(1);

} // namespace int8_direct_nchw44
} // namespace arm_common
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 778 - 0   dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw44_s2.cpp

@@ -0,0 +1,778 @@
/**
* \file
* dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw44_s2.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "src/arm_common/conv_bias/int8/direct.h"
#include "src/arm_common/conv_bias/int8/direct_nchw44_kern.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h"

namespace megdnn {
namespace arm_common {
namespace {
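//! Stride-2 counterparts of the kernels in int8_direct_nchw44_s1.cpp. With
//! pack_iw_len = 4, one output step advances two packed src vectors, so
//! filter tap k only ever multiplies packed columns whose parity matches k,
//! and the later column loads are interleaved with the multiply-accumulates
//! to hide load latency.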
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
int c_dim, typename DstType>
struct KerNeonDirectStride2Int8 {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih,
int iw, const Op& op, int ld_dst_oc);
};

template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
int c_dim, typename DstType>
static void ker_neon_dirctconv_2x2s2_oc8_ow8(const int8_t* src_ptr,
const int8_t* weight_ptr,
const int32_t* bias_ptr,
DstType* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc,
const Op& op) {
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int ic_step = 4;
constexpr int oc_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
constexpr int pack_iw_len = 4;

const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc4 = oc_step * fh * fw * ic;

int32x4_t c[2][8];
int8x16_t weight[2][2];
int8x16_t src[8 + 1];
int16x8_t temp_c[4];

init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);

for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
fh_idx * iw * ic_step * pack_iw_len;

src[0] = vld1q_s8(src_ic_0_3);
src[1] = vld1q_s8(src_ic_0_3 + 16);
src[2] = vld1q_s8(src_ic_0_3 + 2 * 16);
src[3] = vld1q_s8(src_ic_0_3 + 3 * 16);
src[4] = vld1q_s8(src_ic_0_3 + 4 * 16);
src[5] = vld1q_s8(src_ic_0_3 + 5 * 16);
src[6] = vld1q_s8(src_ic_0_3 + 6 * 16);
src[7] = vld1q_s8(src_ic_0_3 + 7 * 16);
src[8] = vld1q_s8(src_ic_0_3 + 8 * 16);

// oc == 0
const int8_t* read_weight_ptr =
weight_ptr + fh_idx * fw * ld_weight_ic4;

weight[0][0] = vld1q_s8(read_weight_ptr);
weight[0][1] = vld1q_s8(read_weight_ptr + 16);
weight[1][0] = vld1q_s8(read_weight_ptr + ld_weight_oc4);
weight[1][1] = vld1q_s8(read_weight_ptr + ld_weight_oc4 + 16);

c[0][0] = vdotq_s32_h(weight[0][0], src[0], c[0][0], temp_c[0]);
c[1][0] = vdotq_s32_h(weight[1][0], src[0], c[1][0], temp_c[1]);
c[0][1] = vdotq_s32_h(weight[0][0], src[2], c[0][1], temp_c[2]);
c[1][1] = vdotq_s32_h(weight[1][0], src[2], c[1][1], temp_c[3]);
c[0][0] = vdotq_s32_h(weight[0][1], src[1], c[0][0], temp_c[0]);
c[1][0] = vdotq_s32_h(weight[1][1], src[1], c[1][0], temp_c[1]);
c[0][1] = vdotq_s32_h(weight[0][1], src[3], c[0][1], temp_c[2]);
c[1][1] = vdotq_s32_h(weight[1][1], src[3], c[1][1], temp_c[3]);

c[0][2] = vdotq_s32_h(weight[0][0], src[4], c[0][2], temp_c[0]);
c[1][2] = vdotq_s32_h(weight[1][0], src[4], c[1][2], temp_c[1]);
c[0][3] = vdotq_s32_h(weight[0][0], src[6], c[0][3], temp_c[2]);
c[1][3] = vdotq_s32_h(weight[1][0], src[6], c[1][3], temp_c[3]);
c[0][2] = vdotq_s32_h(weight[0][1], src[5], c[0][2], temp_c[0]);
c[1][2] = vdotq_s32_h(weight[1][1], src[5], c[1][2], temp_c[1]);
c[0][3] = vdotq_s32_h(weight[0][1], src[7], c[0][3], temp_c[2]);
c[1][3] = vdotq_s32_h(weight[1][1], src[7], c[1][3], temp_c[3]);

src[0] = vld1q_s8(src_ic_0_3 + 9 * 16);
src[1] = vld1q_s8(src_ic_0_3 + 10 * 16);
src[2] = vld1q_s8(src_ic_0_3 + 11 * 16);
c[0][4] = vdotq_s32_h(weight[0][0], src[8], c[0][4], temp_c[0]);
c[1][4] = vdotq_s32_h(weight[1][0], src[8], c[1][4], temp_c[1]);
c[0][5] = vdotq_s32_h(weight[0][0], src[1], c[0][5], temp_c[2]);
c[1][5] = vdotq_s32_h(weight[1][0], src[1], c[1][5], temp_c[3]);
c[0][4] = vdotq_s32_h(weight[0][1], src[0], c[0][4], temp_c[0]);
c[1][4] = vdotq_s32_h(weight[1][1], src[0], c[1][4], temp_c[1]);
c[0][5] = vdotq_s32_h(weight[0][1], src[2], c[0][5], temp_c[2]);
c[1][5] = vdotq_s32_h(weight[1][1], src[2], c[1][5], temp_c[3]);

src[3] = vld1q_s8(src_ic_0_3 + 12 * 16);
src[4] = vld1q_s8(src_ic_0_3 + 13 * 16);
src[5] = vld1q_s8(src_ic_0_3 + 14 * 16);
src[6] = vld1q_s8(src_ic_0_3 + 15 * 16);
c[0][6] = vdotq_s32_h(weight[0][0], src[3], c[0][6], temp_c[0]);
c[1][6] = vdotq_s32_h(weight[1][0], src[3], c[1][6], temp_c[1]);
c[0][7] = vdotq_s32_h(weight[0][0], src[5], c[0][7], temp_c[2]);
c[1][7] = vdotq_s32_h(weight[1][0], src[5], c[1][7], temp_c[3]);
c[0][6] = vdotq_s32_h(weight[0][1], src[4], c[0][6], temp_c[0]);
c[1][6] = vdotq_s32_h(weight[1][1], src[4], c[1][6], temp_c[1]);
c[0][7] = vdotq_s32_h(weight[0][1], src[6], c[0][7], temp_c[2]);
c[1][7] = vdotq_s32_h(weight[1][1], src[6], c[1][7], temp_c[3]);
}
weight_ptr += fh * fw * ld_weight_ic4;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, DstType*>(
c, op, dst_ptr, ld_dst_oc);
}

template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
int c_dim, typename DstType>
static void ker_neon_dirctconv_2x2s2_oc4_ow8(const int8_t* src_ptr,
const int8_t* weight_ptr,
const int32_t* bias_ptr,
DstType* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc,
const Op& op) {
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int oc_step = 4;
constexpr int ic_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
constexpr int pack_iw_len = 4;

const int ic_stride = ih * iw * pack_iw_len;

int32x4_t c[c_dim][8];
int8x16_t weight[2];
int8x16_t src[8 + 1];
int16x8_t temp_c[2];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);

for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
fh_idx * iw * ic_step * pack_iw_len;

src[0] = vld1q_s8(src_ic_0_3);
src[1] = vld1q_s8((src_ic_0_3 + 16));
src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));

// oc == 0
const int8_t* read_weight_ptr =
weight_ptr + fh_idx * fw * ld_weight_ic4;

weight[0] = vld1q_s8(read_weight_ptr);
weight[1] = vld1q_s8(read_weight_ptr + 16);

c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[0], src[2], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[1], src[3], c[0][1], temp_c[1]);

c[0][2] = vdotq_s32_h(weight[0], src[4], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[0], src[6], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[1], src[5], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[1], src[7], c[0][3], temp_c[1]);

src[0] = vld1q_s8(src_ic_0_3 + 9 * 16);
src[1] = vld1q_s8(src_ic_0_3 + 10 * 16);
src[2] = vld1q_s8(src_ic_0_3 + 11 * 16);
c[0][4] = vdotq_s32_h(weight[0], src[8], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[0], src[1], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[1], src[0], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[1], src[2], c[0][5], temp_c[1]);

src[3] = vld1q_s8(src_ic_0_3 + 12 * 16);
src[4] = vld1q_s8(src_ic_0_3 + 13 * 16);
src[5] = vld1q_s8(src_ic_0_3 + 14 * 16);
src[6] = vld1q_s8(src_ic_0_3 + 15 * 16);
c[0][6] = vdotq_s32_h(weight[0], src[3], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[0], src[5], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[1], src[4], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[1], src[6], c[0][7], temp_c[1]);
}
weight_ptr += fh * fw * ld_weight_ic4;
}

store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, DstType*>(
c, op, dst_ptr, ld_dst_oc);
}
/**
dot-like impl: dot 4 ic into 1 oc, accumulate into c<ow, oc>
example: (format like weight<oc, ic>)
packed weight
low 64 bit <0, 0> <0, 1> <1, 2> <1, 3> | <2, 0> <2, 1> <3, 2> <3, 3>
---------------------------------------------------------------------
high 64 bit <0, 3> <0, 2> <1, 1> <1, 0> | <2, 3> <2, 2> <3, 1> <3, 0>
dot: (<0, 0> + <0, 3>) + (<0, 1> + <0, 2>) -> <0>
**/
// TODO: can try oh = 2 impl, oc = 8 impl
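//! (See the vdotq_s32_h sketch in int8_direct_nchw44_s1.cpp for how the
//! mirrored packing turns the widening multiply-accumulate into a 4-wide dot
//! product.)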
template <BiasMode bias_mode, typename Op, int remain_w, int c_dim,
typename DstType>
struct KerNeonDirectStride2Int8<bias_mode, Op, remain_w, 3, c_dim, DstType> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih,
int iw, const Op& op, int ld_dst_oc) {
constexpr int filter_size = 3;
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int oc_step = 4;
constexpr int ic_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
constexpr int pack_iw_len = 4;

const int ic_stride = ih * iw * pack_iw_len;

int32x4_t c[c_dim][8];
int8x16_t weight[3];
int8x16_t src[8 + 2];
int16x8_t temp_c[4];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);

for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
fh_idx * iw * ic_step * pack_iw_len;

src[0] = vld1q_s8(src_ic_0_3);
src[1] = vld1q_s8((src_ic_0_3 + 16));
src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
src[9] = vld1q_s8((src_ic_0_3 + 9 * 16));

// oc == 0
const int8_t* read_weight_ptr =
weight_ptr + fh_idx * fw * ld_weight_ic4;

weight[0] = vld1q_s8(read_weight_ptr);
weight[1] = vld1q_s8(read_weight_ptr + 16);
weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);

c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[0], src[2], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[2]);
c[0][1] = vdotq_s32_h(weight[1], src[3], c[0][1], temp_c[3]);
c[0][0] = vdotq_s32_h(weight[2], src[2], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[2], src[4], c[0][1], temp_c[1]);

c[0][2] = vdotq_s32_h(weight[0], src[4], c[0][2], temp_c[2]);
c[0][3] = vdotq_s32_h(weight[0], src[6], c[0][3], temp_c[3]);
c[0][2] = vdotq_s32_h(weight[1], src[5], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[1], src[7], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[2], src[6], c[0][2], temp_c[2]);
c[0][3] = vdotq_s32_h(weight[2], src[8], c[0][3], temp_c[3]);

src[0] = vld1q_s8(src_ic_0_3 + 10 * 16);
src[1] = vld1q_s8((src_ic_0_3 + 11 * 16));
src[2] = vld1q_s8((src_ic_0_3 + 12 * 16));
c[0][4] = vdotq_s32_h(weight[0], src[8], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[0], src[0], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[1], src[9], c[0][4], temp_c[2]);
c[0][5] = vdotq_s32_h(weight[1], src[1], c[0][5], temp_c[3]);
c[0][4] = vdotq_s32_h(weight[2], src[0], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[2], src[2], c[0][5], temp_c[1]);

src[3] = vld1q_s8((src_ic_0_3 + 13 * 16));
src[4] = vld1q_s8((src_ic_0_3 + 14 * 16));
src[5] = vld1q_s8((src_ic_0_3 + 15 * 16));
src[6] = vld1q_s8((src_ic_0_3 + 16 * 16));
c[0][6] = vdotq_s32_h(weight[0], src[2], c[0][6], temp_c[2]);
c[0][7] = vdotq_s32_h(weight[0], src[4], c[0][7], temp_c[3]);
c[0][6] = vdotq_s32_h(weight[1], src[3], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[1], src[5], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[2], src[4], c[0][6], temp_c[2]);
c[0][7] = vdotq_s32_h(weight[2], src[6], c[0][7], temp_c[3]);
}
weight_ptr += fh * fw * ld_weight_ic4;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, DstType*>(
c, op, dst_ptr, ld_dst_oc);
}
};
template <BiasMode bias_mode, typename Op, int remain_w, int c_dim,
typename DstType>
struct KerNeonDirectStride2Int8<bias_mode, Op, remain_w, 5, c_dim, DstType> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih,
int iw, const Op& op, int ld_dst_oc) {
constexpr int filter_size = 5;
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int oc_step = 4;
constexpr int ic_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
constexpr int pack_iw_len = 4;

const int ic_stride = ih * iw * pack_iw_len;

int32x4_t c[c_dim][8];
int8x16_t weight[5];
int8x16_t src[8 + 2];
int16x8_t temp_c[4];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
fh_idx * iw * ic_step * pack_iw_len;

src[0] = vld1q_s8(src_ic_0_3);
src[1] = vld1q_s8((src_ic_0_3 + 16));
src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
src[9] = vld1q_s8((src_ic_0_3 + 9 * 16));

// oc == 0
const int8_t* read_weight_ptr =
weight_ptr + fh_idx * fw * ld_weight_ic4;

weight[0] = vld1q_s8(read_weight_ptr);
weight[1] = vld1q_s8(read_weight_ptr + 16);
weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
weight[3] = vld1q_s8(read_weight_ptr + 3 * 16);
weight[4] = vld1q_s8(read_weight_ptr + 4 * 16);

c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[0], src[2], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[2]);
c[0][1] = vdotq_s32_h(weight[1], src[3], c[0][1], temp_c[3]);
c[0][0] = vdotq_s32_h(weight[2], src[2], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[2], src[4], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[3], src[3], c[0][0], temp_c[2]);
c[0][1] = vdotq_s32_h(weight[3], src[5], c[0][1], temp_c[3]);
c[0][0] = vdotq_s32_h(weight[4], src[4], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[4], src[6], c[0][1], temp_c[1]);

src[0] = vld1q_s8(src_ic_0_3 + 10 * 16);
c[0][2] = vdotq_s32_h(weight[0], src[4], c[0][2], temp_c[2]);
c[0][3] = vdotq_s32_h(weight[0], src[6], c[0][3], temp_c[3]);
c[0][2] = vdotq_s32_h(weight[1], src[5], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[1], src[7], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[2], src[6], c[0][2], temp_c[2]);
c[0][3] = vdotq_s32_h(weight[2], src[8], c[0][3], temp_c[3]);
c[0][2] = vdotq_s32_h(weight[3], src[7], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[3], src[9], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[4], src[8], c[0][2], temp_c[2]);
c[0][3] = vdotq_s32_h(weight[4], src[0], c[0][3], temp_c[3]);

src[1] = vld1q_s8((src_ic_0_3 + 11 * 16));
src[2] = vld1q_s8((src_ic_0_3 + 12 * 16));
src[3] = vld1q_s8((src_ic_0_3 + 13 * 16));
src[4] = vld1q_s8((src_ic_0_3 + 14 * 16));
c[0][4] = vdotq_s32_h(weight[0], src[8], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[0], src[0], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[1], src[9], c[0][4], temp_c[2]);
c[0][5] = vdotq_s32_h(weight[1], src[1], c[0][5], temp_c[3]);
c[0][4] = vdotq_s32_h(weight[2], src[0], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[2], src[2], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[3], src[1], c[0][4], temp_c[2]);
c[0][5] = vdotq_s32_h(weight[3], src[3], c[0][5], temp_c[3]);
c[0][4] = vdotq_s32_h(weight[4], src[2], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[4], src[4], c[0][5], temp_c[1]);

src[5] = vld1q_s8((src_ic_0_3 + 15 * 16));
src[6] = vld1q_s8((src_ic_0_3 + 16 * 16));
src[7] = vld1q_s8((src_ic_0_3 + 17 * 16));
src[8] = vld1q_s8((src_ic_0_3 + 18 * 16));
c[0][6] = vdotq_s32_h(weight[0], src[2], c[0][6], temp_c[2]);
c[0][7] = vdotq_s32_h(weight[0], src[4], c[0][7], temp_c[3]);
c[0][6] = vdotq_s32_h(weight[1], src[3], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[1], src[5], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[2], src[4], c[0][6], temp_c[2]);
c[0][7] = vdotq_s32_h(weight[2], src[6], c[0][7], temp_c[3]);
c[0][6] = vdotq_s32_h(weight[3], src[5], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[3], src[7], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[4], src[6], c[0][6], temp_c[2]);
c[0][7] = vdotq_s32_h(weight[4], src[8], c[0][7], temp_c[3]);
}
weight_ptr += fh * fw * ld_weight_ic4;
}

store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, DstType*>(
c, op, dst_ptr, ld_dst_oc);
}
};
template <BiasMode bias_mode, typename Op, int remain_w, int c_dim,
typename DstType>
struct KerNeonDirectStride2Int8<bias_mode, Op, remain_w, 7, c_dim, DstType> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih,
int iw, const Op& op, int ld_dst_oc) {
constexpr int filter_size = 7;
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int oc_step = 4;
constexpr int ic_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
constexpr int pack_iw_len = 4;

const int ic_stride = ih * iw * pack_iw_len;

int32x4_t c[c_dim][8];
int8x16_t weight[7];
int8x16_t src[8 + 2];
int16x8_t temp_c[4];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
fh_idx * iw * ic_step * pack_iw_len;

src[0] = vld1q_s8(src_ic_0_3);
src[1] = vld1q_s8(src_ic_0_3 + 1 * 16);
src[2] = vld1q_s8(src_ic_0_3 + 2 * 16);
src[3] = vld1q_s8(src_ic_0_3 + 3 * 16);
src[4] = vld1q_s8(src_ic_0_3 + 4 * 16);
src[5] = vld1q_s8(src_ic_0_3 + 5 * 16);
src[6] = vld1q_s8(src_ic_0_3 + 6 * 16);
src[7] = vld1q_s8(src_ic_0_3 + 7 * 16);
src[8] = vld1q_s8(src_ic_0_3 + 8 * 16);
src[9] = vld1q_s8(src_ic_0_3 + 9 * 16);

// oc == 0
const int8_t* read_weight_ptr =
weight_ptr + fh_idx * fw * ld_weight_ic4;

weight[0] = vld1q_s8(read_weight_ptr);
weight[1] = vld1q_s8(read_weight_ptr + 16);
weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
weight[3] = vld1q_s8(read_weight_ptr + 3 * 16);
weight[4] = vld1q_s8(read_weight_ptr + 4 * 16);
weight[5] = vld1q_s8(read_weight_ptr + 5 * 16);
weight[6] = vld1q_s8(read_weight_ptr + 6 * 16);

c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[0], src[2], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[2]);
c[0][1] = vdotq_s32_h(weight[1], src[3], c[0][1], temp_c[3]);
c[0][0] = vdotq_s32_h(weight[2], src[2], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[2], src[4], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[3], src[3], c[0][0], temp_c[2]);
c[0][1] = vdotq_s32_h(weight[3], src[5], c[0][1], temp_c[3]);
c[0][0] = vdotq_s32_h(weight[4], src[4], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[4], src[6], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[5], src[5], c[0][0], temp_c[2]);
c[0][1] = vdotq_s32_h(weight[5], src[7], c[0][1], temp_c[3]);
c[0][0] = vdotq_s32_h(weight[6], src[6], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[6], src[8], c[0][1], temp_c[1]);

src[0] = vld1q_s8(src_ic_0_3 + 10 * 16);
src[1] = vld1q_s8(src_ic_0_3 + 11 * 16);
src[2] = vld1q_s8(src_ic_0_3 + 12 * 16);
c[0][2] = vdotq_s32_h(weight[0], src[4], c[0][2], temp_c[2]);
c[0][3] = vdotq_s32_h(weight[0], src[6], c[0][3], temp_c[3]);
c[0][2] = vdotq_s32_h(weight[1], src[5], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[1], src[7], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[2], src[6], c[0][2], temp_c[2]);
c[0][3] = vdotq_s32_h(weight[2], src[8], c[0][3], temp_c[3]);
c[0][2] = vdotq_s32_h(weight[3], src[7], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[3], src[9], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[4], src[8], c[0][2], temp_c[2]);
c[0][3] = vdotq_s32_h(weight[4], src[0], c[0][3], temp_c[3]);
c[0][2] = vdotq_s32_h(weight[5], src[9], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[5], src[1], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[6], src[0], c[0][2], temp_c[2]);
c[0][3] = vdotq_s32_h(weight[6], src[2], c[0][3], temp_c[3]);

src[3] = vld1q_s8(src_ic_0_3 + 13 * 16);
src[4] = vld1q_s8(src_ic_0_3 + 14 * 16);
src[5] = vld1q_s8(src_ic_0_3 + 15 * 16);
src[6] = vld1q_s8(src_ic_0_3 + 16 * 16);
c[0][4] = vdotq_s32_h(weight[0], src[8], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[0], src[0], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[1], src[9], c[0][4], temp_c[2]);
c[0][5] = vdotq_s32_h(weight[1], src[1], c[0][5], temp_c[3]);
c[0][4] = vdotq_s32_h(weight[2], src[0], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[2], src[2], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[3], src[1], c[0][4], temp_c[2]);
c[0][5] = vdotq_s32_h(weight[3], src[3], c[0][5], temp_c[3]);
c[0][4] = vdotq_s32_h(weight[4], src[2], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[4], src[4], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[5], src[3], c[0][4], temp_c[2]);
c[0][5] = vdotq_s32_h(weight[5], src[5], c[0][5], temp_c[3]);
c[0][4] = vdotq_s32_h(weight[6], src[4], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[6], src[6], c[0][5], temp_c[1]);

src[7] = vld1q_s8(src_ic_0_3 + 17 * 16);
src[8] = vld1q_s8(src_ic_0_3 + 18 * 16);
src[9] = vld1q_s8(src_ic_0_3 + 19 * 16);
src[0] = vld1q_s8(src_ic_0_3 + 20 * 16);
c[0][6] = vdotq_s32_h(weight[0], src[2], c[0][6], temp_c[2]);
c[0][7] = vdotq_s32_h(weight[0], src[4], c[0][7], temp_c[3]);
c[0][6] = vdotq_s32_h(weight[1], src[3], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[1], src[5], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[2], src[4], c[0][6], temp_c[2]);
c[0][7] = vdotq_s32_h(weight[2], src[6], c[0][7], temp_c[3]);
c[0][6] = vdotq_s32_h(weight[3], src[5], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[3], src[7], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[4], src[6], c[0][6], temp_c[2]);
c[0][7] = vdotq_s32_h(weight[4], src[8], c[0][7], temp_c[3]);
c[0][6] = vdotq_s32_h(weight[5], src[7], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[5], src[9], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[6], src[8], c[0][6], temp_c[2]);
c[0][7] = vdotq_s32_h(weight[6], src[0], c[0][7], temp_c[3]);
}
weight_ptr += fh * fw * ld_weight_ic4;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, DstType*>(
c, op, dst_ptr, ld_dst_oc);
}
};

template <BiasMode bias_mode, typename Op, typename DstType>
void conv_direct_stride2_2x2_int8_nchw44(
const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t*,
DstType* dst, const size_t oc, const size_t ic, const size_t ih,
const size_t iw, const size_t oh, const size_t ow, const Op& op) {
constexpr size_t filter_size = 2;
constexpr size_t fh = filter_size;
constexpr size_t fw = filter_size;
constexpr size_t ic_step = 4;
constexpr size_t oc_step = 4;
constexpr size_t big_oc_step = 8;
constexpr size_t oh_step = 1;
constexpr size_t ow_step = 8;
constexpr size_t stride_h = 2;
constexpr size_t stride_w = 2;
constexpr int pack_iw_len = 4;

const size_t out_img_stride = oh * ow;
const size_t ow_end = ow / ow_step * ow_step;
const size_t ow_remain = ow - ow_end;
const size_t oc_end = oc / big_oc_step * big_oc_step;
const size_t oc_remain = oc - oc_end;
const int ld_dst_oc = oh * ow * oc_step;

using remain_fun = std::function<void(
const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih, int iw,
int ld_dst_oc, const Op& op)>;
remain_fun kern_big_oc_remain = nullptr;
remain_fun kern_small_oc_remain = nullptr;

switch (ow_remain) {
#define cb(step) \
case step: \
kern_big_oc_remain = \
ker_neon_dirctconv_2x2s2_oc8_ow8<bias_mode, Op, step, \
filter_size, 2, DstType>; \
kern_small_oc_remain = \
ker_neon_dirctconv_2x2s2_oc4_ow8<bias_mode, Op, step, \
filter_size, 1, DstType>; \
break;

UNROLL_CALL_RAW(8, cb);
default:
megdnn_assert(0, "no remain %zu for kern", ow_remain);
}
#undef cb

for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) {
const size_t weight_offset = oc_idx * ic * fh * fw;
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const size_t src_offset =
(oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step *
pack_iw_len;
const size_t dst_offset = oc_idx * out_img_stride +
(oh_idx * ow + ow_idx) * oc_step;
ker_neon_dirctconv_2x2s2_oc8_ow8<bias_mode, Op, ow_step,
filter_size, 2, DstType>(
src + src_offset, filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw, ld_dst_oc, op);
}
if (ow_remain > 0) {
const size_t src_offset =
(oh_idx * stride_h * iw + ow_end * stride_w) * ic_step *
pack_iw_len;
const size_t dst_offset = oc_idx * out_img_stride +
(oh_idx * ow + ow_end) * oc_step;
kern_big_oc_remain(src + src_offset, filter + weight_offset,
bias + oc_idx, dst + dst_offset, ic, ih, iw,
ld_dst_oc, op);
}
}
}

if (oc_remain > 0) {
const size_t oc_idx = oc_end;
const size_t weight_offset = oc_idx * ic * fh * fw;
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const size_t src_offset =
(oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step *
pack_iw_len;
const size_t dst_offset = oc_idx * out_img_stride +
(oh_idx * ow + ow_idx) * oc_step;
ker_neon_dirctconv_2x2s2_oc4_ow8<bias_mode, Op, ow_step,
filter_size, 1, DstType>(
src + src_offset, filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw, ld_dst_oc, op);
}
if (ow_remain > 0) {
const size_t src_offset =
(oh_idx * stride_h * iw + ow_end * stride_w) * ic_step *
pack_iw_len;
const size_t dst_offset = oc_idx * out_img_stride +
(oh_idx * ow + ow_end) * oc_step;
kern_small_oc_remain(src + src_offset, filter + weight_offset,
bias + oc_idx, dst + dst_offset, ic, ih,
iw, ld_dst_oc, op);
}
}
}
}
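
// The generic stride-2 kernel below covers the remaining filter sizes with a
// single 4-channel output block per iteration; only the 2x2 case above also
// blocks 8 output channels at a time.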

template <BiasMode bias_mode, typename Op, int filter_size, typename DstType>
void conv_direct_stride2_int8_nchw44_kern(
const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t*,
DstType* dst, const size_t oc, const size_t ic, const size_t ih,
const size_t iw, const size_t oh, const size_t ow, const Op& op) {
constexpr size_t fh = filter_size;
constexpr size_t fw = filter_size;
constexpr size_t ic_step = 4;
constexpr size_t oc_step = 4;
constexpr size_t oh_step = 1;
constexpr size_t ow_step = 8;
constexpr size_t stride_h = 2;
constexpr size_t stride_w = 2;
constexpr int pack_iw_len = 4;

const size_t img_stride = oh * ow;
const size_t ow_end = ow / ow_step * ow_step;
const size_t ow_remain = ow - ow_end;
const int ld_dst_oc = oh * ow * oc_step;

using remain_fun = std::function<void(
const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih, int iw,
const Op& op, int ld_dst_oc)>;

remain_fun kern_small_oc_remain = nullptr;
switch (ow_remain) {
#define cb(step) \
case step: \
kern_small_oc_remain = \
KerNeonDirectStride2Int8<bias_mode, Op, step, filter_size, 1, \
DstType>::impl; \
break;

UNROLL_CALL_RAW(8, cb);
default:
megdnn_assert(0, "no remain %zu for kern", ow_remain);
}
#undef cb

for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) {
const size_t weight_offset = oc_idx * ic * fh * fw;
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const size_t src_offset =
(oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step *
pack_iw_len;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
KerNeonDirectStride2Int8<bias_mode, Op, ow_step, filter_size, 1,
DstType>::impl(src + src_offset,
filter + weight_offset,
bias + oc_idx,
dst + dst_offset, ic,
ih, iw, op, ld_dst_oc);
}
if (ow_remain > 0) {
const size_t src_offset =
(oh_idx * stride_h * iw + ow_end * stride_w) * ic_step *
pack_iw_len;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
kern_small_oc_remain(src + src_offset, filter + weight_offset,
bias + oc_idx, dst + dst_offset, ic, ih,
iw, op, ld_dst_oc);
}
}
}
}
} // namespace

namespace int8_direct_nchw44 {
template <BiasMode bias_mode, typename Op, int filter_size, typename DstType>
struct ConvDirectInt8Nchw44Choose<bias_mode, Op, filter_size, DstType, 2> {
static void impl(const int8_t* src, const int8_t* filter,
const int32_t* bias, int32_t* temp, DstType* dst,
const size_t oc, const size_t ic, const size_t ih,
const size_t iw, const size_t oh, const size_t ow,
const Op& op) {
conv_direct_stride2_int8_nchw44_kern<bias_mode, Op, filter_size,
DstType>(
src, filter, bias, temp, dst, oc, ic, ih, iw, oh, ow, op);
}
};

template <BiasMode bias_mode, typename Op, typename DstType>
struct ConvDirectInt8Nchw44Choose<bias_mode, Op, 2, DstType, 2> {
static void impl(const int8_t* src, const int8_t* filter,
const int32_t* bias, int32_t* temp, DstType* dst,
const size_t oc, const size_t ic, const size_t ih,
const size_t iw, const size_t oh, const size_t ow,
const Op& op) {
conv_direct_stride2_2x2_int8_nchw44<bias_mode, Op, DstType>(
src, filter, bias, temp, dst, oc, ic, ih, iw, oh, ow, op);
}
};
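
// The macros below explicitly instantiate every (bias mode, activation,
// filter size) combination for stride 2 in this translation unit; other
// files only see the ConvDirectInt8Nchw44Choose declaration, which is how
// splitting the kernels into separate .cpp files cuts compile time.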

#define DO_CONV_KERN_FUN(stride, DstType, filter_size, bias_mode, Op) \
template struct ConvDirectInt8Nchw44Choose<bias_mode, Op, filter_size, \
DstType, stride>;

#define GET_OP_PARAM(stride, filter, bias_mode)                              \
    DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode,                    \
                     TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>)             \
    DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode,                    \
                     ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>)                \
    DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode,                    \
                     HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>)              \
    DO_CONV_KERN_FUN(stride, dt_int32, filter, bias_mode, NoneOp<dt_int32>)

#define GET_BIAS_MODE_PARAM(stride, filter) \
GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \
GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS)

#define DISPATCH_CONV_KERN(stride) \
GET_BIAS_MODE_PARAM(stride, 2) \
GET_BIAS_MODE_PARAM(stride, 3) \
GET_BIAS_MODE_PARAM(stride, 5) \
GET_BIAS_MODE_PARAM(stride, 7)

DISPATCH_CONV_KERN(2);

} // namespace int8_direct_nchw44
} // namespace arm_common
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 47  - 0    dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_common.h

@@ -0,0 +1,47 @@
/**
 * \file
 * dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_common.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#include "src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h"
namespace megdnn {
namespace arm_common {
namespace {

template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
int oc_block, int stride>
struct KerNeonXXs2NchwNchw44 {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op);
};
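
// OCHelper maps the output-channel block size a kernel handles (4 or 8
// channels) to the number of accumulator rows (c_dim) it needs; other sizes
// keep the unused default of 0.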

template <int oc>
struct OCHelper {
public:
static const int val = 0;
};
template <>
struct OCHelper<4> {
public:
static const int val = 1;
};
template <>
struct OCHelper<8> {
public:
static const int val = 2;
};

} // namespace
} // namespace arm_common
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 561  - 0    dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.cpp

@@ -0,0 +1,561 @@
/**
* \file
* dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_common.h"
#include "src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h"
namespace megdnn {
namespace arm_common {
namespace {
/**
 * @brief core code for the calculation pattern
 *
 * @tparam src_idx is the offset of the src reg
 * @tparam weight_idx is the offset of the weight reg
 * @tparam c_dim is the number of output-channel blocks
 * @tparam stride
 * @tparam T output regs type
 * @tparam T2 src regs type
 * @tparam T3 weight regs type
 * @tparam T4 temp regs type
 */

template <int src_idx, int weight_idx, int c_dim, int stride, typename T,
typename T2, typename T3, typename T4>
struct ShiftCalHelper {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight, T4& temp);
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight);
};
template <int src_idx, int weight_idx, int c_dim, int stride, typename T,
typename T2, typename T3, typename T4>
MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight, T4& temp) {
ShiftCalHelper<src_idx, weight_idx, c_dim, stride, T, T2, T3, T4>::impl(
c, src, weight, temp);
}
template <int src_idx, int weight_idx, int c_dim, int stride, typename T,
typename T2, typename T3>
MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight) {
ShiftCalHelper<src_idx, weight_idx, c_dim, stride, T, T2, T3, int>::impl(
c, src, weight);
};
template <int src_idx, int weight_idx, typename T, typename T2, typename T3,
typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 2, 1, T, T2, T3, T4> {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight, T4& temp) {
c[0][0] = vdotq_s32_h(src[(0 + src_idx) % 8], weight[0][weight_idx],
c[0][0], temp[0]);
c[1][0] = vdotq_s32_h(src[(0 + src_idx) % 8], weight[1][weight_idx],
c[1][0], temp[1]);
c[0][1] = vdotq_s32_h(src[(1 + src_idx) % 8], weight[0][weight_idx],
c[0][1], temp[2]);
c[1][1] = vdotq_s32_h(src[(1 + src_idx) % 8], weight[1][weight_idx],
c[1][1], temp[3]);
c[0][2] = vdotq_s32_h(src[(2 + src_idx) % 8], weight[0][weight_idx],
c[0][2], temp[0]);
c[1][2] = vdotq_s32_h(src[(2 + src_idx) % 8], weight[1][weight_idx],
c[1][2], temp[1]);
c[0][3] = vdotq_s32_h(src[(3 + src_idx) % 8], weight[0][weight_idx],
c[0][3], temp[2]);
c[1][3] = vdotq_s32_h(src[(3 + src_idx) % 8], weight[1][weight_idx],
c[1][3], temp[3]);

c[0][4] = vdotq_s32_h(src[(4 + src_idx) % 8], weight[0][weight_idx],
c[0][4], temp[0]);
c[1][4] = vdotq_s32_h(src[(4 + src_idx) % 8], weight[1][weight_idx],
c[1][4], temp[1]);
c[0][5] = vdotq_s32_h(src[(5 + src_idx) % 8], weight[0][weight_idx],
c[0][5], temp[2]);
c[1][5] = vdotq_s32_h(src[(5 + src_idx) % 8], weight[1][weight_idx],
c[1][5], temp[3]);
c[0][6] = vdotq_s32_h(src[(6 + src_idx) % 8], weight[0][weight_idx],
c[0][6], temp[0]);
c[1][6] = vdotq_s32_h(src[(6 + src_idx) % 8], weight[1][weight_idx],
c[1][6], temp[1]);
c[0][7] = vdotq_s32_h(src[(7 + src_idx) % 8], weight[0][weight_idx],
c[0][7], temp[2]);
c[1][7] = vdotq_s32_h(src[(7 + src_idx) % 8], weight[1][weight_idx],
c[1][7], temp[3]);
}
static MEGDNN_ALWAYS_INLINE void impl(T&, T2&, T3&);
};
template <int src_idx, int weight_idx, typename T, typename T2, typename T3,
typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 1, 1, T, T2, T3, T4> {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight, T4& temp) {
c[0][0] = vdotq_s32_h(src[(0 + src_idx) % 8], weight[0][weight_idx],
c[0][0], temp[0]);
c[0][1] = vdotq_s32_h(src[(1 + src_idx) % 8], weight[0][weight_idx],
c[0][1], temp[1]);
c[0][2] = vdotq_s32_h(src[(2 + src_idx) % 8], weight[0][weight_idx],
c[0][2], temp[2]);
c[0][3] = vdotq_s32_h(src[(3 + src_idx) % 8], weight[0][weight_idx],
c[0][3], temp[3]);
c[0][4] = vdotq_s32_h(src[(4 + src_idx) % 8], weight[0][weight_idx],
c[0][4], temp[0]);
c[0][5] = vdotq_s32_h(src[(5 + src_idx) % 8], weight[0][weight_idx],
c[0][5], temp[1]);
c[0][6] = vdotq_s32_h(src[(6 + src_idx) % 8], weight[0][weight_idx],
c[0][6], temp[2]);
c[0][7] = vdotq_s32_h(src[(7 + src_idx) % 8], weight[0][weight_idx],
c[0][7], temp[3]);
}
static MEGDNN_ALWAYS_INLINE void impl(T&, T2&, T3&);
};
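
// For reference, vdotq_s32_h is MegEngine's NEON helper emulating a 4-way
// int8 dot product without the ARMv8.2 DOT extension. A minimal sketch,
// assuming AArch64 intrinsics (the in-tree helper may differ in detail):
//
//     static inline int32x4_t vdotq_s32_h_sketch(int8x16_t a, int8x16_t b,
//                                                int32x4_t c, int16x8_t& t) {
//         t = vmull_s8(vget_low_s8(a), vget_low_s8(b));  // low 8 products
//         t = vmlal_high_s8(t, a, b);                    // + high 8 products
//         return vpadalq_s16(c, t);  // pairwise-add int16 -> 4 int32 lanes
//     }
//
// Each int32 lane accumulates four int8 products, which is why the packed
// src repeats adjacent pixel pairs (see pack_nchw_src_for_nchw44_conv<1>
// below).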

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 2, oc_block, 1> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int stride = 1;
constexpr int filter_height = 2;
constexpr int filter_width = 4;
constexpr int oc_step = 4;
constexpr int loop_ic_step = 1;
constexpr int simd_len = 16;
constexpr int pack_iw_len = 16;
constexpr int src_reg = 8;
constexpr int weight_reg = 1;

const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_height * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;
int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);

for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride;
int8x16_t src[src_reg];
int8x16_t dot4_weight[c_dim][weight_reg];
int16x8_t temp_c[4];
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
dot4_weight, weight_ptr, ld_weight_oc);
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
src, nchw_src_ptr + 0 * iw * pack_iw_len, 0);
cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c);

load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
dot4_weight, weight_ptr + 1 * filter_width * oc_step,
ld_weight_oc);
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
src, nchw_src_ptr + 1 * iw * pack_iw_len, 0);
cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c);

weight_ptr += oc_step * filter_height * filter_width;
}

store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};
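
// Filter rows are padded to a multiple of 4 taps (filter_width = 4 here, 8
// for the 5x5/7x7 kernels below) so that one 16-byte weight register covers
// a whole row for 4 output channels; pack_nchw44_weight_for_nchw_conv
// zero-fills the padded taps, so they contribute nothing to the sum.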

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 3, oc_block, 1> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int stride = 1;
constexpr int filter_height = 3;
constexpr int filter_width = 4;
constexpr int oc_step = 4;
constexpr int loop_ic_step = 1;
constexpr int simd_len = 16;
constexpr int pack_iw_len = 16;
constexpr int src_reg = 8;
constexpr int weight_reg = 1;

const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_height * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;
int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);

for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride;
int8x16_t src[src_reg];
int8x16_t dot4_weight[c_dim][weight_reg];
int16x8_t temp_c[4];
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
dot4_weight, weight_ptr, ld_weight_oc);

load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
src, nchw_src_ptr + 0 * iw * pack_iw_len, 0);
cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c);
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
dot4_weight, weight_ptr + 1 * filter_width * oc_step,
ld_weight_oc);

load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
src, nchw_src_ptr + 1 * iw * pack_iw_len, 0);
cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c);

load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
dot4_weight, weight_ptr + 2 * filter_width * oc_step,
ld_weight_oc);
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
src, nchw_src_ptr + 2 * iw * pack_iw_len, 0);
cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c);

weight_ptr += oc_step * filter_height * filter_width;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 5, oc_block, 1> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int stride = 1;
constexpr int filter_height = 5;
constexpr int filter_width = 8;
constexpr int oc_step = 4;
constexpr int loop_ic_step = 1;
constexpr int simd_len = 16;
constexpr int pack_iw_len = 16;
constexpr int src_reg = 8;
constexpr int weight_reg = 2;

const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_height * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;
int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);

for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride;
int8x16_t src[src_reg];
int8x16_t dot4_weight[c_dim][weight_reg];
int16x8_t temp_c[4];
#define cb(step) \
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( \
dot4_weight, weight_ptr + step * filter_width * oc_step, \
ld_weight_oc); \
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( \
src, nchw_src_ptr + step * iw * pack_iw_len, 0); \
cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c); \
load_helper<4, 0, simd_len, 0, Vld1q_s8>( \
src, \
nchw_src_ptr + step * iw * pack_iw_len + src_reg * pack_iw_len, \
0); \
cal_helper<4, 1, c_dim, stride>(c, src, dot4_weight, temp_c);
UNROLL_CALL_RAW(5, cb);
#undef cb
weight_ptr += oc_step * filter_height * filter_width;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 7, oc_block, 1> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int stride = 1;
constexpr int filter_height = 7;
constexpr int filter_width = 8;
constexpr int oc_step = 4;
constexpr int loop_ic_step = 1;
constexpr int simd_len = 16;
constexpr int pack_iw_len = 16;
constexpr int src_reg = 8;
constexpr int weight_reg = 2;

const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_height * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;
int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);

for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride;
int8x16_t src[src_reg];
int8x16_t dot4_weight[c_dim][weight_reg];
int16x8_t temp_c[4];
#define cb(step) \
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( \
dot4_weight, weight_ptr + step * filter_width * oc_step, \
ld_weight_oc); \
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( \
src, nchw_src_ptr + step * iw * pack_iw_len, 0); \
cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c); \
load_helper<4, 0, simd_len, 0, Vld1q_s8>( \
src, \
nchw_src_ptr + step * iw * pack_iw_len + src_reg * pack_iw_len, \
0); \
cal_helper<4, 1, c_dim, stride>(c, src, dot4_weight, temp_c);

UNROLL_CALL_RAW(7, cb);
#undef cb
weight_ptr += oc_step * filter_height * filter_width;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};
} // namespace

namespace int8_direct_nchw_nchw44 {
/**
 * pack {oc / 4, fh, fw, ic, 4(oc)} to {oc / 4, ic, fh, fw / 4, 4(oc) * 4(fw)}
 * every two adjacent taps in a filter row are interleaved into one packed row
 * */
template <>
void pack_nchw44_weight_for_nchw_conv<1>(const int8_t* src_ptr, int8_t* dst_ptr,
const int ic, const int fh,
const int fw, const int oc) {
constexpr int oc_step = 4;
const int fw2 = round_up(fw, 4);
const int fw_remain = fw2 - fw;
const int dst_ic_stride = fh * fw2;
const int oc_step_stride = fh * fw2 * ic * oc_step;
static const uint8_t transpose_4x4_idx[16] = {0, 4, 1, 5, 2, 6, 3, 7,
8, 12, 9, 13, 10, 14, 11, 15};
uint8x16_t tbl_transpose_4x4 = vld1q_u8(&transpose_4x4_idx[0]);
rep_step(oc_idx, oc, oc_step) {
int32_t* dst_temp_ptr =
reinterpret_cast<int32_t*>(dst_ptr + oc_idx * ic * fh * fw2);
const int32_t* src_temp_ptr = reinterpret_cast<const int32_t*>(
src_ptr + oc_idx * ic * fh * fw);
// transpose ic and pad
rep(fh_idx, fh) {
rep(fw_idx, fw) {
rep(ic_idx, ic) {
*(dst_temp_ptr + ic_idx * dst_ic_stride) = *src_temp_ptr;
src_temp_ptr++;
}
dst_temp_ptr++;
}
rep(ic_idx, ic) {
memset(dst_temp_ptr + ic_idx * dst_ic_stride, 0,
sizeof(int8_t) * oc_step * fw_remain);
}
dst_temp_ptr += fw_remain;
}
// transpose fw oc
int8_t* trans_dst_temp_ptr =
reinterpret_cast<int8_t*>(dst_ptr + oc_idx * ic * fh * fw2);

rep_step(idx, oc_step_stride, 16) {
int8x16_t temp = vld1q_s8(trans_dst_temp_ptr + idx);
vst1q_s8(trans_dst_temp_ptr + idx,
vqtbl1q_s8(temp, tbl_transpose_4x4));
}
}
};
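
// Effect of the vqtbl1q_s8 step on one 16-byte group: after the ic transpose
// the bytes are ordered [fw][oc] (w = filter column, c = output channel):
//     in : w0c0 w0c1 w0c2 w0c3 | w1c0 .. w1c3 | w2c0 .. w2c3 | w3c0 .. w3c3
//     out: w0c0 w1c0 w0c1 w1c1 .. w0c3 w1c3 | w2c0 w3c0 .. w2c3 w3c3
// i.e. every two adjacent filter columns are interleaved per output channel,
// matching the pair layout consumed by the vdotq_s32_h kernels.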

/**
 * pack (ic, h, w) to (ic, h, w * 16)
 * every two adjacent pixels in a src row are interleaved, the pair repeated
 * 4 times, and stored to one row
 * */
template <>
void pack_nchw_src_for_nchw44_conv<1>(const int8_t* sptr_origin,
int8_t* sptr_base, const int ic,
const int pad_top, const int pad_bottom,
const int, const int, const int ih,
const int iw, const int iw2, const int pw,
int8_t* temp_ptr) {
static uint8_t reorder_idx[16] = {0, 1, 0, 1, 0, 1, 0, 1,
2, 3, 2, 3, 2, 3, 2, 3};
uint8x16_t tbl_idx = vld1q_u8(&reorder_idx[0]);

constexpr int iw_step = 4;
constexpr int pack_iw_len = 16;
const int ic_stride = ih * iw;
const int iw_with_pad = iw + 2 * pw;
const int iw_with_pad_end = iw_with_pad / iw_step * iw_step;
rep(ic_idx, ic) {
const int8_t* sptr = sptr_origin + ic_idx * ic_stride;
memset(sptr_base, 0,
sizeof(int8_t) * iw2 * (ih + pad_top + pad_bottom) *
pack_iw_len);
sptr_base += iw2 * pad_top * pack_iw_len;
rep(ih_idx, ih) {
memset(temp_ptr, 0, iw_with_pad * sizeof(int8_t));
memcpy(temp_ptr + pw, sptr, sizeof(int8_t) * iw);
for (int iw_idx = 0; iw_idx < iw_with_pad_end; iw_idx += iw_step) {
int8x16_t src[4];
int8x16_t dst[4];
src[0] = vld1q_s8(temp_ptr + iw_idx);
src[1] = vld1q_s8(temp_ptr + iw_idx + 1);
src[2] = vld1q_s8(temp_ptr + iw_idx + 2);
src[3] = vld1q_s8(temp_ptr + iw_idx + 3);
dst[0] = vqtbl1q_s8(src[0], tbl_idx);
dst[1] = vqtbl1q_s8(src[1], tbl_idx);
dst[2] = vqtbl1q_s8(src[2], tbl_idx);
dst[3] = vqtbl1q_s8(src[3], tbl_idx);
vst1q_s8(sptr_base + iw_idx * pack_iw_len + 0, dst[0]);
vst1q_s8(sptr_base + iw_idx * pack_iw_len + 16, dst[1]);
vst1q_s8(sptr_base + iw_idx * pack_iw_len + 32, dst[2]);
vst1q_s8(sptr_base + iw_idx * pack_iw_len + 48, dst[3]);
}
for (int iw_idx = iw_with_pad_end; iw_idx < iw_with_pad; ++iw_idx) {
int8x16_t src = vld1q_s8(temp_ptr + iw_idx);
int8x16_t dst = vqtbl1q_s8(src, tbl_idx);
vst1q_s8(sptr_base + iw_idx * pack_iw_len, dst);
}
sptr_base += iw2 * pack_iw_len;
sptr += iw;
}
sptr_base += iw2 * pad_bottom * pack_iw_len;
}
}
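
// Effect of the reorder on one 16-byte group: for input pixels p0 p1 p2 ...,
// vqtbl1q_s8 with reorder_idx produces
//     p0 p1 p0 p1 p0 p1 p0 p1 | p2 p3 p2 p3 p2 p3 p2 p3
// so each pair of adjacent pixels is repeated four times -- exactly the src
// operand layout vdotq_s32_h expects.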

template <BiasMode bias_mode, typename Op, size_t filter_size>
struct ConvDiectStrideInt8NchwNchw44<bias_mode, Op, filter_size, 1> {
static void impl(const int8_t* src, const int8_t* filter,
const int32_t* bias, int32_t* temp, int8_t* dst,
const size_t oc, const size_t ic, const size_t ih,
const size_t iw, const size_t oh, const size_t ow,
const Op& op) {
MEGDNN_MARK_USED_VAR(temp);
constexpr int stride = 1;
constexpr size_t fh = filter_size;
constexpr size_t fw = (filter_size + 3) / 4 * 4;
constexpr size_t ic_step = 1;
constexpr size_t big_oc_step = 8;
constexpr size_t oc_step = 4;
constexpr size_t ih_step = 1;
constexpr size_t oh_step = 1;
constexpr size_t ow_step = 8;
constexpr size_t stride_h = stride;
constexpr size_t stride_w = stride;
constexpr int pack_iw_len = 16;

const size_t img_stride = oh * ow;
const size_t ow_end = ow / ow_step * ow_step;
const size_t ow_remain = ow - ow_end;
const size_t oc_end = oc / big_oc_step * big_oc_step;
const size_t oc_remain = oc - oc_end;
const int ld_dst_oc = oc_step * img_stride;

using remain_fun = std::function<void(
const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op)>;
remain_fun kern_big_oc_remain = nullptr;
remain_fun kern_small_oc_remain = nullptr;
switch (ow_remain) {
#define cb(step) \
case step: \
kern_big_oc_remain = \
KerNeonXXs2NchwNchw44<bias_mode, Op, step, filter_size, \
big_oc_step, stride>::impl; \
kern_small_oc_remain = \
KerNeonXXs2NchwNchw44<bias_mode, Op, step, filter_size, \
oc_step, stride>::impl; \
break;

UNROLL_CALL_RAW(8, cb);
default:
megdnn_assert(0, "no remain %zu for kern", ow_remain);
}

for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) {
const size_t weight_offset = oc_idx * ic * fh * fw;
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const size_t src_offset = (oh_idx * stride_h * iw +
ow_idx * stride_w * ih_step) *
ic_step * pack_iw_len;
const size_t dst_offset = oc_idx * img_stride +
(oh_idx * ow + ow_idx) * oc_step;

KerNeonXXs2NchwNchw44<bias_mode, Op, ow_step, filter_size,
big_oc_step,
stride>::impl(src + src_offset,
filter + weight_offset,
bias + oc_idx,
dst + dst_offset, ic,
ih, iw, ld_dst_oc, op);
}
if (ow_remain > 0) {
const size_t src_offset = (oh_idx * stride_h * iw +
ow_end * stride_w * ih_step) *
ic_step * pack_iw_len;
const size_t dst_offset = oc_idx * img_stride +
(oh_idx * ow + ow_end) * oc_step;
kern_big_oc_remain(src + src_offset, filter + weight_offset,
bias + oc_idx, dst + dst_offset, ic, ih,
iw, ld_dst_oc, op);
}
}
}

if (oc_remain > 0) {
size_t oc_idx = oc_end;
const size_t weight_offset = oc_idx * ic * fh * fw;
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const size_t src_offset = (oh_idx * stride_h * iw +
ow_idx * stride_w * ih_step) *
ic_step * pack_iw_len;
const size_t dst_offset = oc_idx * img_stride +
(oh_idx * ow + ow_idx) * oc_step;
KerNeonXXs2NchwNchw44<bias_mode, Op, ow_step, filter_size,
oc_step,
stride>::impl(src + src_offset,
filter + weight_offset,
bias + oc_idx,
dst + dst_offset, ic,
ih, iw, ld_dst_oc, op);
}
if (ow_remain > 0) {
const size_t src_offset = (oh_idx * stride_h * iw +
ow_end * stride_w * ih_step) *
ic_step * pack_iw_len;
const size_t dst_offset = oc_idx * img_stride +
(oh_idx * ow + ow_end) * oc_step;
kern_small_oc_remain(src + src_offset,
filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw,
ld_dst_oc, op);
}
}
}
}
};
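
// As in the stride-2 nchw44 kernels, the macros below pin every template
// combination into this translation unit via explicit instantiation.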

#define INSTANCE_CONV_KERN_FUN(stride, filter_size, bias_mode, Op) \
template struct ConvDiectStrideInt8NchwNchw44<bias_mode, Op, filter_size, \
stride>;

#define INSTANCE_OP_PARAM(stride, filter, bias_mode) \
INSTANCE_CONV_KERN_FUN(stride, filter, bias_mode, \
TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
INSTANCE_CONV_KERN_FUN(stride, filter, bias_mode, \
ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
INSTANCE_CONV_KERN_FUN(stride, filter, bias_mode, \
HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>)

#define INSTANCE_BIAS_MODE_PARAM(stride, filter) \
INSTANCE_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \
INSTANCE_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS)

#define INSTANCE_CONV_KERN(stride) \
INSTANCE_BIAS_MODE_PARAM(stride, 2) \
INSTANCE_BIAS_MODE_PARAM(stride, 3) \
INSTANCE_BIAS_MODE_PARAM(stride, 5) \
INSTANCE_BIAS_MODE_PARAM(stride, 7)

INSTANCE_CONV_KERN(1);

} // namespace int8_direct_nchw_nchw44
} // namespace arm_common
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 1412  - 0    dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s2.cpp
(diff suppressed: file too large)


+ 26  - 57    dnn/src/arm_common/conv_bias/int8/direct_nchw44_algo.cpp

@@ -114,7 +114,7 @@ static void copy_padding_kern(const WorkspaceBundle& bundle,
rep(ih_idx, IH) {
std::memset(sptr_base, 0, nr_pad_w * sizeof(int8_t));
sptr_base += nr_pad_w;
nchw44_pack_src(sptr, sptr_base, IW);
int8_direct_nchw44::nchw44_pack_src(sptr, sptr_base, IW);
sptr_base += IW * pack_ic * expend_element;
sptr += IW * pack_ic;
std::memset(sptr_base, 0, row_last_pad * sizeof(int8_t));
@@ -125,8 +125,8 @@ static void copy_padding_kern(const WorkspaceBundle& bundle,
}
}

template <size_t filter, BiasMode bias_mode, typename Op, int ow_remain,
typename DstType, int stride>
template <size_t filter, BiasMode bias_mode, typename Op, typename DstType,
int stride>
static void do_conv_kern(const WorkspaceBundle& bundle,
const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index,
@@ -182,8 +182,10 @@ static void do_conv_kern(const WorkspaceBundle& bundle,
kern_param.bias<dt_int32>(batch_id, group_id) + oc_idx;
auto packed_weight = reinterpret_cast<int8_t*>(bundle.get(1)) +
group_id * OC * IC * FH * FW + oc_idx * IC * FH * FW;
nchw44_pack_filter(fptr, packed_weight, oc_block / 4 * IC / 4 * FH * FW);
conv_direct_int8_nchw44<bias_mode, Op, ow_remain, filter, DstType, stride>(
int8_direct_nchw44::nchw44_pack_filter(fptr, packed_weight,
oc_block / 4 * IC / 4 * FH * FW);
int8_direct_nchw44::conv_direct_int8_nchw44<bias_mode, Op, filter, DstType,
stride>(
sptr, packed_weight, bptr, nullptr, static_cast<DstType*>(dst),
oc_block, IC, IH2, IW2, OH, OW, op);
}
@@ -233,40 +235,38 @@ ConvBiasImpl::AlgoS8DirectNCHW44::dispatch_kerns(
size_t N = param.n;
size_t IC = fm.icpg;
size_t OC = fm.ocpg;
size_t OW = param.osz[1];
size_t group = fm.group;
size_t fh = fm.spatial[0];
size_t fw = fm.spatial[1];
WorkspaceBundle wbundle = get_bundle(param);
conv_fun do_conv_fun = nullptr;
int ow_remain = OW % 8;
bool need_post_process = param.dst_type.enumv() == DTypeEnum::QuantizedS8;
// NOTE: remain_w is no longer part of the midout hash, so the hash stays
// stable when input shapes change at runtime
#define DO_CONV_KERN_FUN(stride, dst_type, filter, bias_mode, remain_w, op) \
#define DO_CONV_KERN_FUN(stride, dst_type, filter, bias_mode, op) \
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw44, \
midout_iv(#stride #dst_type #filter #bias_mode #op##_hash)) { \
do_conv_fun = do_conv_kern<filter, bias_mode, op, remain_w, dst_type, \
stride>; \
do_conv_fun = do_conv_kern<filter, bias_mode, op, dst_type, stride>; \
} \
MIDOUT_END();

#define GET_OP_PARAM(stride, filter, bias_mode, remain_w) \
#define GET_OP_PARAM(stride, filter, bias_mode) \
if (need_post_process) { \
switch (param.nonlineMode) { \
case param::ConvBias::NonlineMode::IDENTITY: \
DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \
remain_w, \
\
TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
break; \
case param::ConvBias::NonlineMode::RELU: \
DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \
remain_w, \
\
ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
break; \
case param::ConvBias::NonlineMode::H_SWISH: \
DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \
remain_w, \
\
HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
break; \
default: \
@@ -277,7 +277,7 @@ ConvBiasImpl::AlgoS8DirectNCHW44::dispatch_kerns(
switch (param.nonlineMode) { \
case param::ConvBias::NonlineMode::IDENTITY: \
DO_CONV_KERN_FUN(stride, dt_int32, filter, bias_mode, \
remain_w, NoneOp<dt_int32>) \
NoneOp<dt_int32>) \
break; \
default: \
megdnn_assert( \
@@ -287,48 +287,17 @@ ConvBiasImpl::AlgoS8DirectNCHW44::dispatch_kerns(
} \
}

#define GET_REMAIN_W_PARAM(stride, filter, bias_mode) \
switch (ow_remain) { \
case 0: \
GET_OP_PARAM(stride, filter, bias_mode, 0); \
break; \
case 1: \
GET_OP_PARAM(stride, filter, bias_mode, 1); \
break; \
case 2: \
GET_OP_PARAM(stride, filter, bias_mode, 2); \
break; \
case 3: \
GET_OP_PARAM(stride, filter, bias_mode, 3); \
break; \
case 4: \
GET_OP_PARAM(stride, filter, bias_mode, 4); \
break; \
case 5: \
GET_OP_PARAM(stride, filter, bias_mode, 5); \
break; \
case 6: \
GET_OP_PARAM(stride, filter, bias_mode, 6); \
break; \
case 7: \
GET_OP_PARAM(stride, filter, bias_mode, 7); \
break; \
default: \
megdnn_assert(0); \
}

#define GET_BIAS_MODE_PARAM(stride, filter) \
switch (param.bias_mode) { \
case BiasMode::NO_BIAS: \
GET_REMAIN_W_PARAM(stride, filter, BiasMode::NO_BIAS) \
break; \
case BiasMode::BROADCAST_CHANNEL_BIAS: \
GET_REMAIN_W_PARAM(stride, filter, \
BiasMode::BROADCAST_CHANNEL_BIAS) \
break; \
default: \
megdnn_assert(0); \
break; \
#define GET_BIAS_MODE_PARAM(stride, filter) \
switch (param.bias_mode) { \
case BiasMode::NO_BIAS: \
GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \
break; \
case BiasMode::BROADCAST_CHANNEL_BIAS: \
GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) \
break; \
default: \
megdnn_assert(0); \
break; \
}

#define DISPATCH_CONV_KERN(stride) \


+ 7  - 1337    dnn/src/arm_common/conv_bias/int8/direct_nchw44_kern.h
(diff suppressed: file too large)


+ 10  - 9    dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_algo.cpp

@@ -117,11 +117,11 @@ static void copy_padding_kern(const WorkspaceBundle& bundle,
const size_t tmp_size = get_temp_bytes(iw, pw);
int8_t* tmp_ptr = reinterpret_cast<int8_t*>(bundle.get(2)) +
ncb_index.thread_id * tmp_size;
pack_nchw_src_for_nchw44_conv<1>(sptr, sptr_base, 1, ph, ph, pw, pw, ih,
iw, iw2, pw, tmp_ptr);
int8_direct_nchw_nchw44::pack_nchw_src_for_nchw44_conv<1>(
sptr, sptr_base, 1, ph, ph, pw, pw, ih, iw, iw2, pw, tmp_ptr);
} else {
pack_nchw_src_for_nchw44_conv<2>(sptr, sptr_base, 1, ph, ph, pw, pw, ih,
iw, iw2, pw, nullptr);
int8_direct_nchw_nchw44::pack_nchw_src_for_nchw44_conv<2>(
sptr, sptr_base, 1, ph, ph, pw, pw, ih, iw, iw2, pw, nullptr);
}
}
static void pack_weight(const WorkspaceBundle& bundle,
@@ -142,11 +142,11 @@ static void pack_weight(const WorkspaceBundle& bundle,
group_id * oc * ic * fh * fw2 + oc_idx * ic * fh * fw2;

if (stride_h == 1) {
pack_nchw44_weight_for_nchw_conv<1>(fptr, packed_weight, ic, fh, fw,
oc_block);
int8_direct_nchw_nchw44::pack_nchw44_weight_for_nchw_conv<1>(
fptr, packed_weight, ic, fh, fw, oc_block);
} else {
pack_nchw44_weight_for_nchw_conv<2>(fptr, packed_weight, ic, fh, fw,
oc_block);
int8_direct_nchw_nchw44::pack_nchw44_weight_for_nchw_conv<2>(
fptr, packed_weight, ic, fh, fw, oc_block);
}
}
template <size_t filter, BiasMode bias_mode, typename Op, int stride>
@@ -208,7 +208,8 @@ static void do_conv_kern(const WorkspaceBundle& bundle,
int8_t* packed_weight = reinterpret_cast<int8_t*>(bundle.get(1)) +
group_id * oc * ic * fh * fw2 +
oc_idx * ic * fh * fw2;
conv_direct_int8_nchw_nchw44<bias_mode, Op, filter, stride>(
int8_direct_nchw_nchw44::conv_direct_int8_nchw_nchw44<bias_mode, Op, filter,
stride>(
sptr, packed_weight, bptr, nullptr, dst, oc_block, ic, ih2, iw2, oh,
ow, op);
}


+ 3  - 1854    dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h
(diff suppressed: file too large)


+ 5  - 4    dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp

@@ -93,8 +93,8 @@ void do_weight_trans(const WorkspaceBundle& bundle,
const int fw2 = round_up(fw, 4);
auto packed_weight = reinterpret_cast<int8_t*>(bundle.get(1));
auto origin_weight = kern_param.filter<dt_int8>();
pack_weight_int8_nchw_nchw44_dot(packed_weight, origin_weight, oc, ic, fh,
fw, fw2);
dot_direct_nchw_nchw44::pack_weight_int8_nchw_nchw44_dot(
packed_weight, origin_weight, oc, ic, fh, fw, fw2);
}

template <size_t filter, BiasMode bias_mode, typename Op, int stride>
@@ -147,7 +147,7 @@ static void do_conv_kern(const WorkspaceBundle& bundle,
tmp_ptr = reinterpret_cast<int8_t*>(bundle.get(2)) +
ncb_index.thread_id * tmp_size;
}
pack_src_int8_nchw_nchw44_dot<stride>(
dot_direct_nchw_nchw44::pack_src_int8_nchw_nchw44_dot<stride>(
sptr, origin_sptr, ph, pw, remain_right_pad,
ih_real - src_top_pad - src_bottom_pad, iw, iw2, src_top_pad,
src_bottom_pad, ic, ih * iw, tmp_ptr);
@@ -164,7 +164,8 @@ static void do_conv_kern(const WorkspaceBundle& bundle,
float scale_bias = kern_param.bias_type.param<dtype::QuantizedS32>().scale;
float scale_dst = kern_param.dst_type.param<dtype::QuantizedS8>().scale;
Op op(scale_bias, scale_dst);
conv_direct_int8_nchw_nchw44_dot<bias_mode, Op, filter, stride>(
dot_direct_nchw_nchw44::conv_direct_int8_nchw_nchw44_dot<bias_mode, Op,
filter, stride>(
sptr, fptr, bptr, nullptr, dst, oc_block, ic, ih_real, iw2, oh,
oh_block_real, ow, op);
}


+ 14  - 662    dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h

@@ -20,83 +20,15 @@
#include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h"

using namespace megdnn;
using namespace arm_common;
namespace {
namespace megdnn {
namespace arm_common {
namespace dot_direct_nchw_nchw44 {
template <int src_idx, int weight_idx, int c_dim, typename Func, int ow_block,
int stride, typename T, typename T2, typename T3, typename T4>
struct ShiftCalHelper {
static void impl(T& c, T2& src, T3& weight);
};

template <int src_idx, int weight_idx, typename Func, int stride, typename T,
typename T2, typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 8, stride, T, T2, T3, T4> {
static void impl(T& c, T2& src, T3& weight) {
#define cb(step) \
c[0][step * 2] = Func::template impl<(src_idx + step) % 4>( \
c[0][step * 2], weight[0][weight_idx], \
src[0][(src_idx + step) / 4]); \
c[1][step * 2] = Func::template impl<(src_idx + step) % 4>( \
c[1][step * 2], weight[1][weight_idx], \
src[0][(src_idx + step) / 4]); \
c[0][step * 2 + 1] = Func::template impl<(src_idx + step) % 4>( \
c[0][step * 2 + 1], weight[0][weight_idx], \
src[1][(src_idx + step) / 4]); \
c[1][step * 2 + 1] = Func::template impl<(src_idx + step) % 4>( \
c[1][step * 2 + 1], weight[1][weight_idx], \
src[1][(src_idx + step) / 4]);

UNROLL_CALL_RAW(4, cb);
#undef cb
}
};

template <int src_idx, int weight_idx, typename Func, int stride, typename T,
typename T2, typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, stride, T, T2, T3, T4> {
static void impl(T& c, T2& src, T3& weight) {
#define cb(step) \
c[0][step * 2] = Func::template impl<(src_idx + step) % 4>( \
c[0][step * 2], weight[0][weight_idx], \
src[0][(src_idx + step) / 4]); \
c[0][step * 2 + 1] = Func::template impl<(src_idx + step) % 4>( \
c[0][step * 2 + 1], weight[0][weight_idx], \
src[1][(src_idx + step) / 4]);

UNROLL_CALL_RAW(4, cb);
#undef cb
}
};

template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 8, 1, T, T2, T3, T4> {
static void impl(T& c, T2& src, T3& weight) {
#define cb(step) \
c[0][step] = Func::template impl<(src_idx + step) % 4>( \
c[0][step], weight[0][weight_idx], src[(src_idx + step) / 4]); \
c[1][step] = Func::template impl<(src_idx + step) % 4>( \
c[1][step], weight[1][weight_idx], src[(src_idx + step) / 4]);

UNROLL_CALL_RAW(8, cb);
#undef cb
}
};

template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, 1, T, T2, T3, T4> {
static void impl(T& c, T2& src, T3& weight) {
#define cb(step) \
c[0][step] = Func::template impl<(src_idx + step) % 4>( \
c[0][step], weight[0][weight_idx], src[(src_idx + step) / 4]);

UNROLL_CALL_RAW(8, cb);
#undef cb
}
};

template <int src_idx, int weight_idx, int c_dim, typename FUNC, int ow_block,
int stride, typename T, typename T2, typename T3>
inline void cal_helper(T& c, T2& src, T3& weight) {
@@ -133,490 +65,12 @@ struct KerNeonDotXXs2Nchw44Int8 {
int iw, int ld_dst_oc, const Op& op);
};

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int ow_block, int stride>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 2, oc_block, ow_block,
stride> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int filter_hight = 2;
constexpr int filter_width = 4;
constexpr int weight_reg = 1;
constexpr int src_reg = 1;

constexpr int oc_step = 4;
constexpr int ic_step = 1;
constexpr int pack_iw_len = 1;
constexpr int simd_len = 16;

const int ld_bias = oc_step;
const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;

int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias);
for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
int8x16_t src[2][src_reg];
int8x16_t weight[c_dim][weight_reg];
// row 0
load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>(
src, src_ptr + 0 * iw, stride);
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
weight);
// row 1
load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>(
src, src_ptr + 1 * iw, stride);
load_helper<weight_reg, 1 * simd_len, simd_len, c_dim, Vld1q_s8>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
weight);

src_ptr += ic_stride;
weight_ptr += filter_hight * filter_width * oc_step;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int ow_block, int stride>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 3, oc_block, ow_block,
stride> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int filter_hight = 3;
constexpr int filter_width = 4;
constexpr int weight_reg = 1;
constexpr int src_reg = 1;

constexpr int oc_step = 4;
constexpr int ic_step = 1;
constexpr int pack_iw_len = 1;
constexpr int simd_len = 16;

const int ld_bias = oc_step;
const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;

int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias);
for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
int8x16_t src[2][src_reg];
int8x16_t weight[c_dim][weight_reg];
// row 0
load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>(
src, src_ptr + 0 * iw, stride);
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
weight);
// row 1
load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>(
src, src_ptr + 1 * iw, stride);
load_helper<weight_reg, 1 * simd_len, simd_len, c_dim, Vld1q_s8>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
weight);
// row 2
load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>(
src, src_ptr + 2 * iw, stride);
load_helper<weight_reg, 2 * simd_len, simd_len, c_dim, Vld1q_s8>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
weight);

src_ptr += ic_stride;
weight_ptr += filter_hight * filter_width * oc_step;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int ow_block, int stride>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 5, oc_block, ow_block,
stride> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int filter_hight = 5;
constexpr int filter_width = 8;
constexpr int src_reg = 2;
constexpr int weight_reg = 2;

constexpr int oc_step = 4;
constexpr int ic_step = 1;
constexpr int pack_iw_len = 1;
constexpr int simd_len = 16;

const int ld_bias = oc_step;
const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;

int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias);

for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
int8x16_t src[2][src_reg];
int8x16_t weight[c_dim][weight_reg];
#define cb(step) \
load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>(src, src_ptr + step * iw, \
stride); \
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( \
weight, weight_ptr + step * 2 * simd_len, ld_weight_oc); \
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, \
weight); \
cal_helper<1, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight);
UNROLL_CALL_RAW(5, cb);
#undef cb
src_ptr += ic_stride;
weight_ptr += 5 * 32;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};

/**
 * oc = 8, ow = 8
 * dot 4 elements at a time; the filter width is padded from 7 to 8, so each
 * filter row takes two dot operations; a padded row looks like:
 * --------------------------
 * |x, x, x, x,| x, x, x, 0 |
 * --------------------------
 **/
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int ow_block, int stride>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 7, oc_block, ow_block,
stride> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int filter_hight = 7;
constexpr int filter_width = 8;
constexpr int src_reg = 2;
constexpr int weight_reg = 2;

constexpr int oc_step = 4;
constexpr int ic_step = 1;
constexpr int pack_iw_len = 1;
constexpr int simd_len = 16;

const int ld_bias = oc_step;
const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;

int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias);

for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
int8x16_t src[2][src_reg];
int8x16_t weight[c_dim][weight_reg];
#define cb(step) \
load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>(src, src_ptr + step * iw, \
stride); \
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( \
weight, weight_ptr + step * 2 * simd_len, ld_weight_oc); \
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, \
weight); \
cal_helper<1, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight);
UNROLL_CALL_RAW(7, cb);
#undef cb
src_ptr += ic_stride;
weight_ptr += 7 * 32;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};
////////////////////stride 1///////////////////
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int ow_block>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 2, oc_block, ow_block,
1> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int stride = 1;
constexpr int filter_hight = 2;
constexpr int filter_width = 4;
constexpr int weight_reg = 2;
constexpr int src_reg = 2;

constexpr int oc_step = 4;
constexpr int ic_step = 1;
constexpr int pack_iw_len = 4;
constexpr int simd_len = 16;

const int ld_bias = oc_step;
const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;

int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias);
for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
int8x16_t src[src_reg];
int8x16_t weight[c_dim][weight_reg];
// row 0
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
src, src_ptr + 0 * iw * pack_iw_len, 0);
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
weight);
// row 1
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
src, src_ptr + 1 * iw * pack_iw_len, 0);
cal_helper<0, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
weight);

src_ptr += ic_stride;
weight_ptr += filter_hight * filter_width * oc_step;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int ow_block>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 3, oc_block, ow_block,
1> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int stride = 1;
constexpr int filter_hight = 3;
constexpr int filter_width = 4;
constexpr int weight_reg = 3;
constexpr int src_reg = 2;

constexpr int oc_step = 4;
constexpr int ic_step = 1;
constexpr int pack_iw_len = 4;
constexpr int simd_len = 16;

const int ld_bias = oc_step;
const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;

int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias);
for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
int8x16_t src[src_reg];
int8x16_t weight[c_dim][weight_reg];
// row 0
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
src, src_ptr + 0 * iw * pack_iw_len, 0);
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
weight);
// row 1
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
src, src_ptr + 1 * iw * pack_iw_len, 0);
cal_helper<0, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
weight);
// row 2
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
src, src_ptr + 2 * iw * pack_iw_len, 0);
cal_helper<0, 2, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
weight);

src_ptr += ic_stride;
weight_ptr += filter_hight * filter_width * oc_step;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int ow_block>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 5, oc_block, ow_block,
1> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int stride = 1;
constexpr int filter_hight = 5;
constexpr int filter_width = 8;
constexpr int src_reg = 3;
constexpr int weight_reg = 2;

constexpr int oc_step = 4;
constexpr int ic_step = 1;
constexpr int pack_iw_len = 4;
constexpr int simd_len = 16;

const int ld_bias = oc_step;
const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;

int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias);

for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
int8x16_t src[src_reg];
int8x16_t weight[c_dim][weight_reg];
#define cb(step) \
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( \
src, src_ptr + step * iw * pack_iw_len, 0); \
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( \
weight, weight_ptr + step * 2 * simd_len, ld_weight_oc); \
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, \
weight); \
cal_helper<4, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight);

UNROLL_CALL_RAW(5, cb);
#undef cb
src_ptr += ic_stride;
weight_ptr += filter_hight * filter_width * oc_step;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int ow_block>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 7, oc_block, ow_block,
1> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int stride = 1;
constexpr int filter_hight = 7;
constexpr int filter_width = 8;
constexpr int src_reg = 3;
constexpr int weight_reg = 2;

constexpr int oc_step = 4;
constexpr int ic_step = 1;
constexpr int pack_iw_len = 4;
constexpr int simd_len = 16;

const int ld_bias = oc_step;
const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;

int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias);

for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
int8x16_t src[src_reg];
int8x16_t weight[c_dim][weight_reg];
#define cb(step) \
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( \
src, src_ptr + step * iw * pack_iw_len, 0); \
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( \
weight, weight_ptr + step * 2 * simd_len, ld_weight_oc); \
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, \
weight); \
cal_helper<4, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight);

UNROLL_CALL_RAW(7, cb);
#undef cb
src_ptr += ic_stride;
weight_ptr += filter_hight * filter_width * oc_step;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};

template <int stride>
void pack_src_int8_nchw_nchw44_dot(int8_t* sptr_base, const int8_t* sptr_origin,
const int, const int pw, const int,
const int ih, const int iw, const int iw2,
const int pad_top, const int pad_bottom,
const int ic, const int ic_stride, int8_t*) {
constexpr int ic_step = 1;
rep_step(ic_idx, ic, ic_step) {
const int8_t* sptr = sptr_origin + ic_idx * ic_stride;
memset(sptr_base, 0,
sizeof(int8_t) * ic_step * iw2 * (ih + pad_top + pad_bottom));
sptr_base += iw2 * pad_top * ic_step;
rep(ih_idx, ih) {
memcpy(sptr_base + pw * ic_step, sptr,
sizeof(int8_t) * iw * ic_step);
sptr_base += iw2 * ic_step;
sptr += iw * ic_step;
}
sptr_base += iw2 * pad_bottom * ic_step;
}
}

template <>
void pack_src_int8_nchw_nchw44_dot<1>(int8_t* sptr_base,
const int8_t* sptr_origin, const int,
const int pw, const int, const int ih,
const int iw, const int iw2,
const int pad_top, const int pad_bottom,
const int ic, const int ic_stride,
int8_t* temp_ptr) {
static uint8_t reorder_idx[16] = {0, 1, 2, 3, 1, 2, 3, 4,
2, 3, 4, 5, 3, 4, 5, 6};
uint8x16_t tbl_idx = vld1q_u8(&reorder_idx[0]);

constexpr int iw_step = 16;
constexpr int pack_iw_len = 4;
const int iw_with_pad = iw + 2 * pw;
const int iw_with_pad_end = iw_with_pad / iw_step * iw_step;
rep(ic_idx, ic) {
const int8_t* sptr = sptr_origin + ic_idx * ic_stride;
memset(sptr_base, 0,
sizeof(int8_t) * iw2 * (ih + pad_top + pad_bottom) *
pack_iw_len);
sptr_base += iw2 * pad_top * pack_iw_len;
rep(ih_idx, ih) {
memset(temp_ptr, 0, iw_with_pad * sizeof(int8_t));
memcpy(temp_ptr + pw, sptr, sizeof(int8_t) * iw);
for (int iw_idx = 0; iw_idx < iw_with_pad_end; iw_idx += iw_step) {
int8x16_t src[4];
int8x16_t dst[4];
src[0] = vld1q_s8(temp_ptr + iw_idx);
src[1] = vld1q_s8(temp_ptr + iw_idx + 4);
src[2] = vld1q_s8(temp_ptr + iw_idx + 8);
src[3] = vld1q_s8(temp_ptr + iw_idx + 12);
dst[0] = vqtbl1q_s8(src[0], tbl_idx);
dst[1] = vqtbl1q_s8(src[1], tbl_idx);
dst[2] = vqtbl1q_s8(src[2], tbl_idx);
dst[3] = vqtbl1q_s8(src[3], tbl_idx);
vst1q_s8(sptr_base + iw_idx * pack_iw_len + 0, dst[0]);
vst1q_s8(sptr_base + iw_idx * pack_iw_len + 16, dst[1]);
vst1q_s8(sptr_base + iw_idx * pack_iw_len + 32, dst[2]);
vst1q_s8(sptr_base + iw_idx * pack_iw_len + 48, dst[3]);
}
for (int iw_idx = iw_with_pad_end; iw_idx < iw_with_pad; ++iw_idx) {
*(sptr_base + iw_idx * pack_iw_len + 0) =
*(temp_ptr + iw_idx + 0);
*(sptr_base + iw_idx * pack_iw_len + 1) =
*(temp_ptr + iw_idx + 1);
*(sptr_base + iw_idx * pack_iw_len + 2) =
*(temp_ptr + iw_idx + 2);
*(sptr_base + iw_idx * pack_iw_len + 3) =
*(temp_ptr + iw_idx + 3);
}
sptr_base += iw2 * pack_iw_len;
sptr += iw;
}
sptr_base += iw2 * pad_bottom * pack_iw_len;
}
}
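
// For stride 1 the dot packing widens each pixel into a 4-byte sliding
// window: reorder_idx above turns pixels p0 p1 p2 ... into
//     p0 p1 p2 p3 | p1 p2 p3 p4 | p2 p3 p4 p5 | p3 p4 p5 p6
// so each 4-byte group holds the taps under one output position.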
template <int stride>
void pack_src_int8_nchw_nchw44_dot(int8_t* sptr_base, const int8_t* sptr_origin,
                                   const int, const int pw, const int,
                                   const int ih, const int iw, const int iw2,
                                   const int pad_top, const int pad_bottom,
                                   const int ic, const int ic_stride, int8_t*);

static inline void pack_weight_int8_nchw_nchw44_dot(int8_t* dst_ptr,
const int8_t* src_ptr,
@@ -663,117 +117,15 @@ static inline void pack_weight_int8_nchw_nchw44_dot(int8_t* dst_ptr,
}

template <BiasMode bias_mode, typename Op, int filter_size, int stride>
static void conv_direct_int8_nchw_nchw44_dot(
const int8_t* src, const int8_t* filter, const int32_t* bias,
int32_t* temp, int8_t* dst, const int oc, const int ic, const int ih,
const int iw, const int oh, const int oh_block, const int ow,
const Op& op) {
MEGDNN_MARK_USED_VAR(temp);
constexpr int fh = filter_size;
constexpr int fw = (filter_size + 3) / 4 * 4;
#if MEGDNN_AARCH64
constexpr int big_oc_step = 8;
#else
constexpr int big_oc_step = 4;
#endif
constexpr int oc_step = 4;
constexpr int ih_step = 1;
constexpr int oh_step = 1;
constexpr int ow_step = 8;
constexpr int stride_h = stride;
constexpr int stride_w = stride;
constexpr int pack_iw_len = stride == 2 ? 1 : 4;

const int img_stride = oh * ow;
const int ow_end = ow / ow_step * ow_step;
const int ow_remain = ow - ow_end;
const int oc_end = oc / big_oc_step * big_oc_step;
const int oc_remain = oc - oc_end;
const int ld_dst_oc = oc_step * img_stride;

using remain_fun =
std::function<void(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic,
int ih, int iw, int ld_dst_oc, const Op& op)>;
remain_fun kern_big_oc_remain = nullptr;
remain_fun kern_small_oc_remain = nullptr;
switch (ow_remain) {
#define cb(step) \
case step: \
kern_big_oc_remain = \
KerNeonDotXXs2Nchw44Int8<bias_mode, Op, step, filter_size, \
big_oc_step, ow_step, stride>::impl; \
kern_small_oc_remain = \
KerNeonDotXXs2Nchw44Int8<bias_mode, Op, step, filter_size, \
oc_step, ow_step, stride>::impl; \
break;

UNROLL_CALL_RAW(8, cb);
default:
megdnn_assert(0, "no remain %d for kern", ow_remain);
}

for (int oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) {
const int weight_offset = oc_idx * ic * fh * fw;
for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) {
for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const int src_offset =
(oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) *
pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
KerNeonDotXXs2Nchw44Int8<bias_mode, Op, ow_step, filter_size,
big_oc_step, ow_step,
stride>::impl(src + src_offset,
filter + weight_offset,
bias + oc_idx,
dst + dst_offset, ic, ih,
iw, ld_dst_oc, op);
}
if (ow_remain > 0) {
const int src_offset =
(oh_idx * stride_h * iw + ow_end * stride_w * ih_step) *
pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
kern_big_oc_remain(src + src_offset, filter + weight_offset,
bias + oc_idx, dst + dst_offset, ic, ih, iw,
ld_dst_oc, op);
}
}
}
if (oc_remain > 0) {
int oc_idx = oc_end;
const int weight_offset = oc_idx * ic * fh * fw;
for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) {
for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const int src_offset =
(oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) *
pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
KerNeonDotXXs2Nchw44Int8<bias_mode, Op, ow_step, filter_size,
oc_step, ow_step,
stride>::impl(src + src_offset,
filter + weight_offset,
bias + oc_idx,
dst + dst_offset, ic, ih,
iw, ld_dst_oc, op);
}
if (ow_remain > 0) {
const int src_offset =
(oh_idx * stride_h * iw + ow_end * stride_w * ih_step) *
pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
kern_small_oc_remain(src + src_offset, filter + weight_offset,
bias + oc_idx, dst + dst_offset, ic, ih,
iw, ld_dst_oc, op);
}
}
}
}
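
// A stripped-down sketch of the remainder-dispatch pattern above (names
// invented here; the real kernels take many more parameters). All widths
// 0..7 are instantiated at compile time and the runtime switch only picks a
// function pointer, so the hot loops stay branch-free.
template <int remain>
struct KernRemainSketch {
    static void impl(const int8_t* src, int8_t* dst, int ld_dst) {
        // would compute `remain` output columns here
        (void)src;
        (void)dst;
        (void)ld_dst;
    }
};

using remain_fn_sketch = void (*)(const int8_t*, int8_t*, int);

static inline remain_fn_sketch select_remain_kernel(int ow_remain) {
    switch (ow_remain) {
#define cb(step) \
    case step:   \
        return KernRemainSketch<step>::impl;
        cb(0) cb(1) cb(2) cb(3) cb(4) cb(5) cb(6) cb(7)
#undef cb
        default:
            return nullptr;  // unreachable: 0 <= ow_remain < ow_step == 8
    }
}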

} // namespace
void conv_direct_int8_nchw_nchw44_dot(const int8_t* src, const int8_t* filter,
const int32_t* bias, int32_t* temp,
int8_t* dst, const int oc, const int ic,
const int ih, const int iw, const int oh,
const int oh_block, const int ow,
const Op& op);

} // namespace dot_direct_nchw_nchw44
} // namespace arm_common
} // namespace megdnn
#endif
// vim: syntax=cpp.doxygen
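
The plain declaration above hints at the compile-time tactic: the heavy template bodies live in one shared header, and each concrete filter-size/stride pair is explicitly instantiated in its own small translation unit. A minimal sketch of that layout, with hypothetical file and symbol names:

// dot_kern_common.h -- the heavy template *definition* lives here and is
// included by every per-shape .cpp (hypothetical names, pattern sketch only)
template <int filter_size, int stride>
void dot_kern(const int8_t* src, int8_t* dst, int n) {
    // ... NEON body parameterized on filter_size / stride ...
}

// dot_kern_3x3s1.cpp -- one small translation unit per (filter, stride)
// pair; each TU compiles exactly one instantiation, so per-file compile
// time stays bounded and the .cpp files build in parallel.
#include "dot_kern_common.h"
template void dot_kern<3, 1>(const int8_t*, int8_t*, int);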

+ 2
- 2
dnn/test/arm_common/conv_bias_multi_thread.cpp

@@ -2344,7 +2344,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F32) {
#endif
std::vector<conv_bias::TestArg> gemv_args;
for (auto&& arg : args)
if(arg.src.shape[2] == 1 && arg.src.shape[3] == 1) {
if (arg.src.shape[2] == 1 && arg.src.shape[3] == 1) {
gemv_args.emplace_back(arg);
}
check_conv_bias(gemv_args, handle(), "CONV1x1_GEMV");
@@ -2361,7 +2361,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_MK4_PACK_F32) {
#endif
std::vector<conv_bias::TestArg> gemv_args;
for (auto&& arg : args)
if(arg.src.shape[2] == 1 && arg.src.shape[3] == 1) {
if (arg.src.shape[2] == 1 && arg.src.shape[3] == 1) {
gemv_args.emplace_back(arg);
}
check_conv_bias(gemv_args, handle(), "CONV1x1_GEMV");


+ 23
- 27
dnn/test/arm_common/matrix_mul.cpp

@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "test/arm_common/fixture.h"

@@ -30,8 +31,7 @@ TEST_F(ARM_COMMON, MATRIX_MUL_INT8x8x16) {

TEST_F(ARM_COMMON, MATRIX_MUL_QUINT8) {
matrix_mul::check_matrix_mul(dtype::Quantized8Asymm(1.2f, (uint8_t)127),
dtype::Quantized8Asymm(1.3f, (uint8_t)129),
{},
dtype::Quantized8Asymm(1.3f, (uint8_t)129), {},
handle());
}

@@ -232,8 +232,7 @@ TEST_F(ARM_COMMON, QINT8x8x32_GEVM) {
Checker<MatrixMul> checker(handle());
using Param = MatrixMul::Param;

checker.set_before_exec_callback(
AlgoChecker<MatrixMul>("ARM_COMMON_GEVM"));
checker.set_before_exec_callback(AlgoChecker<MatrixMul>("ARM_COMMON_GEVM"));

std::unique_ptr<RNG> rng = std::make_unique<UniformIntRNG>(-127, 127);
checker.set_rng(0, rng.get()).set_rng(1, rng.get());
@@ -251,7 +250,7 @@ TEST_F(ARM_COMMON, QINT8x8x32_GEVM) {
.set_dtype(2, dtype::QuantizedS32(6.25f))
.execs({A, B, {}});
};
// M = 1
for (size_t N : {1, 10, 16, 33, 64})
for (size_t K : {7, 512, 1024})
@@ -263,8 +262,7 @@ TEST_F(ARM_COMMON, FP32_GEVM) {
Checker<MatrixMul> checker(handle());
using Param = MatrixMul::Param;

checker.set_before_exec_callback(
AlgoChecker<MatrixMul>("ARM_COMMON_GEVM"));
checker.set_before_exec_callback(AlgoChecker<MatrixMul>("ARM_COMMON_GEVM"));

checker.set_epsilon(1e-2);
auto run = [&](size_t M, size_t K, size_t N) {
@@ -276,7 +274,7 @@ TEST_F(ARM_COMMON, FP32_GEVM) {
B = TensorShape{N, K};
checker.set_param(param).execs({A, B, {}});
};
// M = 1
for (size_t M : {1})
for (size_t K : {1000, 4096, 25088})
@@ -298,15 +296,15 @@ TEST_F(ARM_COMMON, FP32_GEMV_MK4) {
param.transposeA = false;
param.transposeB = false;
TensorShape A, B;
A = TensorShape{M/4, K/4, 4, 4};
B = TensorShape{K/4, 1, 4};
A = TensorShape{M / 4, K / 4, 4, 4};
B = TensorShape{K / 4, 1, 4};
checker.set_param(param).execs({A, B, {}});
};
// N = 1
for (size_t M : {4, 16, 128, 1024})
for (size_t K : {4, 8, 12, 128, 256, 4096})
run(M, K);
run(M, K);
}
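
For context, the MK4 test packs A as {M/4, K/4, 4, 4} and B as {K/4, 1, 4}, i.e. 4x4 and 4x1 panels. A small sketch of the shape construction, assuming M and K are multiples of 4 (true for every value iterated above; the helper names are invented):

#include <cassert>
#include <cstddef>

struct Shape4 { std::size_t s[4]; };
struct Shape3 { std::size_t s[3]; };

static Shape4 mk4_lhs_shape(std::size_t M, std::size_t K) {
    assert(M % 4 == 0 && K % 4 == 0);
    return {{M / 4, K / 4, 4, 4}};  // A packed into 4x4 panels
}

static Shape3 mk4_rhs_shape(std::size_t K) {
    assert(K % 4 == 0);
    return {{K / 4, 1, 4}};  // B packed into 4x1 panels
}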

#if MEGDNN_WITH_BENCHMARK
@@ -343,7 +341,7 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMV) {

for (size_t M : {4, 64, 1024, 4096})
for (size_t K : {128, 256, 1024, 4096})
run(M, K, 1);
run(M, K, 1);
}

TEST_F(ARM_COMMON, BENCHMARK_SGEMV_FP32) {
@@ -372,7 +370,7 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMV_FP32) {
.exec({{2, 1024}, {1024, 512}, {}});
benchmarker.set_display(true);
}
// run gemv
run(12, 48, 1);
run(48, 12, 1);
@@ -396,14 +394,14 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMV_MK4) {
Benchmarker<MatrixMul> benchmarker(handle());
benchmarker.set_times(exec_times);
benchmarker.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_param(param);
.set_dtype(1, dtype::Float32())
.set_param(param);

auto run = [&](size_t M, size_t K) {
printf("SGEMV_MK4: (%zu, %zu, %zu)\n", M, K, N);
printf("SGEMV_MK4: (%zu, %zu)\n", M, K);
TensorShape A, B;
A = TensorShape{M/4, K/4, 4, 4};
B = TensorShape{K/4, 1, 4};
A = TensorShape{M / 4, K / 4, 4, 4};
B = TensorShape{K / 4, 1, 4};
auto time = benchmarker.exec({A, B, {}}) / exec_times;
auto computations = 2.f * M * K * 1e-6;
auto perf = computations / time;
@@ -422,7 +420,7 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMV_MK4) {
// run gemv mk4
for (size_t M : {4, 64, 1024, 4096})
for (size_t K : {128, 1024, 4096})
run(M, K);
run(M, K);
}
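
As a sanity check on the arithmetic in this benchmark: `computations = 2.f * M * K * 1e-6` counts each multiply-accumulate as two operations, in millions; since `time` is (presumably) milliseconds per run, `computations / time` lands directly in GFLOPS. For example, M = 1024 and K = 4096 give 2 * 1024 * 4096 * 1e-6 ≈ 8.39 MFLOPs, i.e. about 8.39 GFLOPS if one run takes 1 ms.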

TEST_F(ARM_COMMON, BENCHMARK_SGEMV_FP16) {
@@ -490,7 +488,7 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMM) {
//////////////////////// gemv //////////////////////////
for (size_t M : {8, 64, 112, 256}) {
for (size_t K : {8, 64, 112, 256}) {
run (M, 1, K);
run(M, 1, K);
}
}

@@ -502,10 +500,8 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMM) {
}
}
}

}


TEST_F(ARM_COMMON, BENCHMARK_MATRIX_MUL_INT8x8x32) {
constexpr size_t RUNS = 50;
param::MatrixMul param;
@@ -514,7 +510,8 @@ TEST_F(ARM_COMMON, BENCHMARK_MATRIX_MUL_INT8x8x32) {
.set_dtype(0, dtype::Int8{})
.set_dtype(1, dtype::Int8{})
.set_dtype(2, dtype::Int32{})
.set_param(param).set_display(false);
.set_param(param)
.set_display(false);
Benchmarker<MatrixMul> benchmarker_float(handle());
benchmarker_float.set_display(false).set_times(RUNS);

@@ -533,7 +530,7 @@ TEST_F(ARM_COMMON, BENCHMARK_MATRIX_MUL_INT8x8x32) {
//////////////////////// gemv //////////////////////////
for (size_t M : {8, 64, 112, 256}) {
for (size_t K : {8, 64, 112, 256}) {
run (M, 1, K);
run(M, 1, K);
}
}

@@ -618,5 +615,4 @@ TEST_F(ARM_COMMON, BENCHMARK_TRANSPOSED_MATRIX_MUL_QUINT8) {

#endif


// vim: syntax=cpp.doxygen
