Browse Source

feat(dnn/arm_common): add nchw_nchw44 aarch64 int8 3x3s2 7x7s2 asm

GitOrigin-RevId: 871465335d
release-0.6
Megvii Engine Team 5 years ago
parent
commit
80ecabe8c6
1 changed files with 624 additions and 0 deletions
  1. +624
    -0
      dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h

+ 624
- 0
dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h View File

@@ -302,6 +302,470 @@ struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 7, oc_block, stride> {
store_ocx_ow4_remain_static<c_dim, remain_w>(c, op, dst_ptr, ld_dst_oc);
}
};
#if MEGDNN_AARCH64
/**
 * AArch64 inline-asm specialization of KerNeonXXs2NchwNchw44 for
 * remain_w = 0, filter_size = 7, oc_block = 8, stride = 2.
 *
 * impl() initializes the int32 accumulators c[c_dim][4] through
 * init_ocx_ow4, then for every input channel runs one asm block that
 * accumulates into c[0][0..3] (weights read through weight_ptr) and
 * c[1][0..3] (weights read through weight_ptr_oc), and finally hands c to
 * store_ocx_ow4_remain_static.
 *
 * NOTE(review): the layout of src/weight (packing produced by the caller) is
 * not visible in this block; the comments below only describe the addressing
 * and instructions that appear here.
 */
template <BiasMode bias_mode, typename Op>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, 0, 7, 8, 2> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
//! tbl indices: every even output byte takes source byte 0, every odd
//! output byte takes source byte 8
static const uint8_t src_idx_buffer[16] = {0, 8, 0, 8, 0, 8, 0, 8,
0, 8, 0, 8, 0, 8, 0, 8};
uint8x16_t vtbl = vld1q_u8(src_idx_buffer);

// constexpr int stride = 2;
constexpr int oc_block = 8;
constexpr int remain_w = 0;
constexpr int filter_size = 7;
constexpr int ic_step = 1;
constexpr int oc_step = 4;
constexpr int pack_iw_len = 4;
constexpr int fh_step = 2;
constexpr int c_dim = OCHelper<oc_block>::val;

//! byte distance between two consecutive input channels of src
const int ic_stride = ih * iw * pack_iw_len;
//! byte offset from weight_ptr to the weights used for c[1][*]
const int ld_dot4_weight_oc = oc_step * filter_size * filter_size * ic;
//! src advance per asm section: fh_step (= 2) rows
const size_t src_step = fh_step * iw * ic_step * pack_iw_len;
//! weight advance for a two-row section (56 bytes) and for the single last
//! row (28 bytes); 3 * 56 + 28 = filter_size * filter_size * pack_iw_len
const size_t weight_step = filter_size * pack_iw_len * fh_step;
const size_t weight_step_small = filter_size * pack_iw_len;
int32x4_t c[c_dim][4];

init_ocx_ow4<c_dim, bias_mode>(c, bias_ptr, oc_step);

for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride;

//! recomputed every iteration from weight_ptr, which the asm block below
//! advances (it is a "+r" operand)
const int8_t* weight_ptr_oc = weight_ptr + ld_dot4_weight_oc;

//! row 6 (the 7th filter row) of the current input channel
const int8_t* nchw_src_ptr_last_line =
src_ptr + ic_idx * ic_stride +
6 * iw * ic_step * pack_iw_len;
/**
 * Register usage inside the asm block:
 * c00..c13, vtbl: operands, registers are chosen by the compiler
 * v24-v31 int16 temporaries (v24-v27 feed c0x, v28-v31 feed c1x)
 * v8-v15 src
 * v16-v21 weight
 * x5, x6, x10 scratch addresses for the last-element loads
 *
 * Instruction pattern: smull / smlal2 build int16 products from the low /
 * high 8 bytes of the operands, sadalp pairwise-adds those products into
 * the int32 accumulators.
 */
asm volatile(

//! fh = 0
"ldp q8, q9, [%[nchw_src_ptr]]\n"
"ldp q16, q17, [%[weight_ptr]]\n"
"ldp q10, q11, [%[nchw_src_ptr], #32]\n"
"smull v24.8h, v8.8b, v16.8b\n"
"ldp q19, q20, [%[weight_ptr_oc]]\n"
"smull v25.8h, v9.8b, v16.8b\n"
"ldp q12, q13, [%[nchw_src_ptr], #64]\n"
"smull v26.8h, v10.8b, v16.8b\n"
"ldr q18, [%[weight_ptr],#32]\n"
"smull v27.8h, v11.8b, v16.8b\n"
"ldr q21, [%[weight_ptr_oc],#32]\n"
"smull v28.8h, v8.8b, v19.8b\n"
"smlal2 v24.8h, v8.16b, v16.16b\n"
"smlal2 v25.8h, v9.16b, v16.16b\n"
"smlal2 v26.8h, v10.16b, v16.16b\n"
"smlal2 v27.8h, v11.16b, v16.16b\n"
"sadalp %[c00].4s, v24.8h\n"
"smull v29.8h, v9.8b, v19.8b\n"
"sadalp %[c01].4s, v25.8h\n"
"smull v30.8h, v10.8b, v19.8b\n"
"sadalp %[c02].4s, v26.8h\n"
"smull v31.8h, v11.8b, v19.8b\n"
"sadalp %[c03].4s, v27.8h\n"
"smlal2 v28.8h, v8.16b, v19.16b\n"
"ldr d8, [%[nchw_src_ptr],#48]\n"
"smlal2 v29.8h, v9.16b, v19.16b\n"
"smlal2 v30.8h, v10.16b, v19.16b\n"
"smlal2 v31.8h, v11.16b, v19.16b\n"
"smull v24.8h, v9.8b, v17.8b\n"
"sadalp %[c10].4s, v28.8h\n"
"smull v25.8h, v10.8b, v17.8b\n"
"sadalp %[c11].4s, v29.8h\n"
"smull v26.8h, v11.8b, v17.8b\n"
"sadalp %[c12].4s, v30.8h\n"
"smull v27.8h, v12.8b, v17.8b\n"
"sadalp %[c13].4s, v31.8h\n"
"smlal2 v24.8h, v9.16b, v17.16b\n"
"smlal2 v25.8h, v10.16b, v17.16b\n"
"smlal2 v26.8h, v11.16b, v17.16b\n"
"smlal2 v27.8h, v12.16b, v17.16b\n"
"smull v28.8h, v9.8b, v20.8b\n"
"sadalp %[c00].4s, v24.8h\n"
"smull v29.8h, v10.8b, v20.8b\n"
"sadalp %[c01].4s, v25.8h\n"
"smull v30.8h, v11.8b, v20.8b\n"
"sadalp %[c02].4s, v26.8h\n"
"smull v31.8h, v12.8b, v20.8b\n"
"sadalp %[c03].4s, v27.8h\n"
"smlal2 v28.8h, v9.16b, v20.16b\n"
"ldr d9, [%[nchw_src_ptr],#64]\n"
"smlal2 v29.8h, v10.16b, v20.16b\n"
"ldr d14, [%[nchw_src_ptr],#80]\n"
"smlal2 v30.8h, v11.16b, v20.16b\n"
"smlal2 v31.8h, v12.16b, v20.16b\n"
"smull v24.8h, v10.8b, v18.8b\n"
"sadalp %[c10].4s, v28.8h\n"
"smull v25.8h, v11.8b, v18.8b\n"
"sadalp %[c11].4s, v29.8h\n"
"smull v26.8h, v12.8b, v18.8b\n"
"sadalp %[c12].4s, v30.8h\n"
"smull v27.8h, v13.8b, v18.8b\n"
"sadalp %[c13].4s, v31.8h\n"
"smlal2 v24.8h, v10.16b, v18.16b\n"
"ldr d19, [%[weight_ptr_oc],#48]\n"
"smlal2 v25.8h, v11.16b, v18.16b\n"
"ldr d15, [%[nchw_src_ptr],#96]\n"
"smlal2 v26.8h, v12.16b, v18.16b\n"
"smlal2 v27.8h, v13.16b, v18.16b\n"
"ldr d18, [%[weight_ptr],#48]\n"
"smull v28.8h, v10.8b, v21.8b\n"
"sadalp %[c00].4s, v24.8h\n"
"smull v29.8h, v11.8b, v21.8b\n"
"sadalp %[c01].4s, v25.8h\n"
"smull v30.8h, v12.8b, v21.8b\n"
"sadalp %[c02].4s, v26.8h\n"
"smull v31.8h, v13.8b, v21.8b\n"
"sadalp %[c03].4s, v27.8h\n"
"smlal2 v28.8h, v10.16b, v21.16b\n"
//! advance src / weight pointers to the next two rows while the tail
//! of the current section is still being accumulated
"add %[nchw_src_ptr], %[nchw_src_ptr], %[src_step]\n"
"smlal2 v29.8h, v11.16b, v21.16b\n"
"ldp q10, q11, [%[nchw_src_ptr], #32]\n"
"add %[weight_ptr], %[weight_ptr], %[weight_step]\n"
"smlal2 v30.8h, v12.16b, v21.16b\n"
"add %[weight_ptr_oc], %[weight_ptr_oc], "
"%[weight_step]\n"
"smlal2 v31.8h, v13.16b, v21.16b\n"
"ldp q16, q17, [%[weight_ptr]]\n"
"smull v24.8h, v8.8b, v18.8b\n"
"sadalp %[c10].4s, v28.8h\n"
"smull v25.8h, v9.8b, v18.8b\n"
"sadalp %[c11].4s, v29.8h\n"
"smull v26.8h, v14.8b, v18.8b\n"
"sadalp %[c12].4s, v30.8h\n"
"smull v27.8h, v15.8b, v18.8b\n"
"sadalp %[c13].4s, v31.8h\n"
"smull v28.8h, v8.8b, v19.8b\n"
"sadalp %[c00].4s, v24.8h\n"
"smull v29.8h, v9.8b, v19.8b\n"
"ldp q8, q9, [%[nchw_src_ptr]]\n"
"sadalp %[c01].4s, v25.8h\n"
"smull v30.8h, v14.8b, v19.8b\n"
"sadalp %[c02].4s, v26.8h\n"
"smull v31.8h, v15.8b, v19.8b\n"
"ldp q19, q20, [%[weight_ptr_oc]]\n"
"sadalp %[c03].4s, v27.8h\n"
"smull v24.8h, v8.8b, v16.8b\n"
"sadalp %[c10].4s, v28.8h\n"
"smull v25.8h, v9.8b, v16.8b\n"
"ldp q12, q13, [%[nchw_src_ptr], #64]\n"
"sadalp %[c11].4s, v29.8h\n"
"smull v26.8h, v10.8b, v16.8b\n"
"ldr q18, [%[weight_ptr],#32]\n"
"sadalp %[c12].4s, v30.8h\n"
"smull v27.8h, v11.8b, v16.8b\n"
"ldr q21, [%[weight_ptr_oc],#32]\n"
"sadalp %[c13].4s, v31.8h\n"
//! fh = 2
"smull v28.8h, v8.8b, v19.8b\n"
"smlal2 v24.8h, v8.16b, v16.16b\n"
"smlal2 v25.8h, v9.16b, v16.16b\n"
"smlal2 v26.8h, v10.16b, v16.16b\n"
"smlal2 v27.8h, v11.16b, v16.16b\n"
"sadalp %[c00].4s, v24.8h\n"
"smull v29.8h, v9.8b, v19.8b\n"
"sadalp %[c01].4s, v25.8h\n"
"smull v30.8h, v10.8b, v19.8b\n"
"sadalp %[c02].4s, v26.8h\n"
"smull v31.8h, v11.8b, v19.8b\n"
"sadalp %[c03].4s, v27.8h\n"
"smlal2 v28.8h, v8.16b, v19.16b\n"
"ldr d8, [%[nchw_src_ptr],#48]\n"
"smlal2 v29.8h, v9.16b, v19.16b\n"
"smlal2 v30.8h, v10.16b, v19.16b\n"
"smlal2 v31.8h, v11.16b, v19.16b\n"
"smull v24.8h, v9.8b, v17.8b\n"
"sadalp %[c10].4s, v28.8h\n"
"smull v25.8h, v10.8b, v17.8b\n"
"sadalp %[c11].4s, v29.8h\n"
"smull v26.8h, v11.8b, v17.8b\n"
"sadalp %[c12].4s, v30.8h\n"
"smull v27.8h, v12.8b, v17.8b\n"
"sadalp %[c13].4s, v31.8h\n"
"smlal2 v24.8h, v9.16b, v17.16b\n"
"smlal2 v25.8h, v10.16b, v17.16b\n"
"smlal2 v26.8h, v11.16b, v17.16b\n"
"smlal2 v27.8h, v12.16b, v17.16b\n"
"smull v28.8h, v9.8b, v20.8b\n"
"sadalp %[c00].4s, v24.8h\n"
"smull v29.8h, v10.8b, v20.8b\n"
"sadalp %[c01].4s, v25.8h\n"
"smull v30.8h, v11.8b, v20.8b\n"
"sadalp %[c02].4s, v26.8h\n"
"smull v31.8h, v12.8b, v20.8b\n"
"sadalp %[c03].4s, v27.8h\n"
"smlal2 v28.8h, v9.16b, v20.16b\n"
"ldr d9, [%[nchw_src_ptr],#64]\n"
"smlal2 v29.8h, v10.16b, v20.16b\n"
"ldr d14, [%[nchw_src_ptr],#80]\n"
"smlal2 v30.8h, v11.16b, v20.16b\n"
"smlal2 v31.8h, v12.16b, v20.16b\n"
"smull v24.8h, v10.8b, v18.8b\n"
"sadalp %[c10].4s, v28.8h\n"
"smull v25.8h, v11.8b, v18.8b\n"
"sadalp %[c11].4s, v29.8h\n"
"smull v26.8h, v12.8b, v18.8b\n"
"sadalp %[c12].4s, v30.8h\n"
"smull v27.8h, v13.8b, v18.8b\n"
"sadalp %[c13].4s, v31.8h\n"
"smlal2 v24.8h, v10.16b, v18.16b\n"
"ldr d19, [%[weight_ptr_oc],#48]\n"
"smlal2 v25.8h, v11.16b, v18.16b\n"
"ldr d15, [%[nchw_src_ptr],#96]\n"
"smlal2 v26.8h, v12.16b, v18.16b\n"
"smlal2 v27.8h, v13.16b, v18.16b\n"
"ldr d18, [%[weight_ptr],#48]\n"
"smull v28.8h, v10.8b, v21.8b\n"
"sadalp %[c00].4s, v24.8h\n"
"smull v29.8h, v11.8b, v21.8b\n"
"sadalp %[c01].4s, v25.8h\n"
"smull v30.8h, v12.8b, v21.8b\n"
"sadalp %[c02].4s, v26.8h\n"
"smull v31.8h, v13.8b, v21.8b\n"
"sadalp %[c03].4s, v27.8h\n"
"smlal2 v28.8h, v10.16b, v21.16b\n"
"add %[nchw_src_ptr], %[nchw_src_ptr], %[src_step]\n"
"smlal2 v29.8h, v11.16b, v21.16b\n"
"add %[weight_ptr], %[weight_ptr], %[weight_step]\n"
"smlal2 v30.8h, v12.16b, v21.16b\n"
"add %[weight_ptr_oc], %[weight_ptr_oc], "
"%[weight_step]\n"
"smlal2 v31.8h, v13.16b, v21.16b\n"
"ldp q16, q17, [%[weight_ptr]]\n"
"smull v24.8h, v8.8b, v18.8b\n"
"ldp q10, q11, [%[nchw_src_ptr], #32]\n"
"sadalp %[c10].4s, v28.8h\n"
"smull v25.8h, v9.8b, v18.8b\n"
"sadalp %[c11].4s, v29.8h\n"
"smull v26.8h, v14.8b, v18.8b\n"
"sadalp %[c12].4s, v30.8h\n"
"smull v27.8h, v15.8b, v18.8b\n"
"sadalp %[c13].4s, v31.8h\n"
"smull v28.8h, v8.8b, v19.8b\n"
"sadalp %[c00].4s, v24.8h\n"
"smull v29.8h, v9.8b, v19.8b\n"
"ldp q8, q9, [%[nchw_src_ptr]]\n"
"sadalp %[c01].4s, v25.8h\n"
"smull v30.8h, v14.8b, v19.8b\n"
"sadalp %[c02].4s, v26.8h\n"
"smull v31.8h, v15.8b, v19.8b\n"
"ldp q19, q20, [%[weight_ptr_oc]]\n"
"sadalp %[c03].4s, v27.8h\n"
"smull v24.8h, v8.8b, v16.8b\n"
"sadalp %[c10].4s, v28.8h\n"
"smull v25.8h, v9.8b, v16.8b\n"
"ldp q12, q13, [%[nchw_src_ptr], #64]\n"
"sadalp %[c11].4s, v29.8h\n"
"smull v26.8h, v10.8b, v16.8b\n"
"ldr q18, [%[weight_ptr],#32]\n"
"sadalp %[c12].4s, v30.8h\n"
"smull v27.8h, v11.8b, v16.8b\n"
"ldr q21, [%[weight_ptr_oc],#32]\n"
"sadalp %[c13].4s, v31.8h\n"
//! fh = 4
"smull v28.8h, v8.8b, v19.8b\n"
"smlal2 v24.8h, v8.16b, v16.16b\n"
"smlal2 v25.8h, v9.16b, v16.16b\n"
"smlal2 v26.8h, v10.16b, v16.16b\n"
"smlal2 v27.8h, v11.16b, v16.16b\n"
"sadalp %[c00].4s, v24.8h\n"
"smull v29.8h, v9.8b, v19.8b\n"
"sadalp %[c01].4s, v25.8h\n"
"smull v30.8h, v10.8b, v19.8b\n"
"sadalp %[c02].4s, v26.8h\n"
"smull v31.8h, v11.8b, v19.8b\n"
"sadalp %[c03].4s, v27.8h\n"
"smlal2 v28.8h, v8.16b, v19.16b\n"
"ldr d8, [%[nchw_src_ptr],#48]\n"
"smlal2 v29.8h, v9.16b, v19.16b\n"
"smlal2 v30.8h, v10.16b, v19.16b\n"
"smlal2 v31.8h, v11.16b, v19.16b\n"
"smull v24.8h, v9.8b, v17.8b\n"
"sadalp %[c10].4s, v28.8h\n"
"smull v25.8h, v10.8b, v17.8b\n"
"sadalp %[c11].4s, v29.8h\n"
"smull v26.8h, v11.8b, v17.8b\n"
"sadalp %[c12].4s, v30.8h\n"
"smull v27.8h, v12.8b, v17.8b\n"
"sadalp %[c13].4s, v31.8h\n"
"smlal2 v24.8h, v9.16b, v17.16b\n"
"smlal2 v25.8h, v10.16b, v17.16b\n"
"smlal2 v26.8h, v11.16b, v17.16b\n"
"smlal2 v27.8h, v12.16b, v17.16b\n"
"smull v28.8h, v9.8b, v20.8b\n"
"sadalp %[c00].4s, v24.8h\n"
"smull v29.8h, v10.8b, v20.8b\n"
"sadalp %[c01].4s, v25.8h\n"
"smull v30.8h, v11.8b, v20.8b\n"
"sadalp %[c02].4s, v26.8h\n"
"smull v31.8h, v12.8b, v20.8b\n"
"sadalp %[c03].4s, v27.8h\n"
"smlal2 v28.8h, v9.16b, v20.16b\n"
"ldr d9, [%[nchw_src_ptr],#64]\n"
"smlal2 v29.8h, v10.16b, v20.16b\n"
"ldr d14, [%[nchw_src_ptr],#80]\n"
"smlal2 v30.8h, v11.16b, v20.16b\n"
"smlal2 v31.8h, v12.16b, v20.16b\n"
"smull v24.8h, v10.8b, v18.8b\n"
"sadalp %[c10].4s, v28.8h\n"
"smull v25.8h, v11.8b, v18.8b\n"
"sadalp %[c11].4s, v29.8h\n"
"smull v26.8h, v12.8b, v18.8b\n"
"sadalp %[c12].4s, v30.8h\n"
"smull v27.8h, v13.8b, v18.8b\n"
"sadalp %[c13].4s, v31.8h\n"
"smlal2 v24.8h, v10.16b, v18.16b\n"
"ldr d19, [%[weight_ptr_oc],#48]\n"
"smlal2 v25.8h, v11.16b, v18.16b\n"
"ldr d15, [%[nchw_src_ptr],#96]\n"
"smlal2 v26.8h, v12.16b, v18.16b\n"
"smlal2 v27.8h, v13.16b, v18.16b\n"
"ldr d18, [%[weight_ptr],#48]\n"
"smull v28.8h, v10.8b, v21.8b\n"
"add %[weight_ptr], %[weight_ptr], %[weight_step]\n"
"sadalp %[c00].4s, v24.8h\n"
"smull v29.8h, v11.8b, v21.8b\n"
"add %[weight_ptr_oc], %[weight_ptr_oc], %[weight_step]\n"
"sadalp %[c01].4s, v25.8h\n"
"smull v30.8h, v12.8b, v21.8b\n"
"ldr q16, [%[weight_ptr]]\n"
"sadalp %[c02].4s, v26.8h\n"
"smull v31.8h, v13.8b, v21.8b\n"
"sadalp %[c03].4s, v27.8h\n"
"smlal2 v28.8h, v10.16b, v21.16b\n"
"smlal2 v29.8h, v11.16b, v21.16b\n"
//! from here on src is read through nchw_src_ptr_last_line (row 6)
"ldp q10, q11, [%[nchw_src_ptr_last_line], #32]\n"
"smlal2 v30.8h, v12.16b, v21.16b\n"
"smlal2 v31.8h, v13.16b, v21.16b\n"
"ldp q12, q13, [%[nchw_src_ptr_last_line], #64]\n"
"smull v24.8h, v8.8b, v18.8b\n"
"ldr d21, [%[weight_ptr_oc],#16]\n"
"sadalp %[c10].4s, v28.8h\n"
"smull v25.8h, v9.8b, v18.8b\n"
"sadalp %[c11].4s, v29.8h\n"
"smull v26.8h, v14.8b, v18.8b\n"
"sadalp %[c12].4s, v30.8h\n"
"smull v27.8h, v15.8b, v18.8b\n"
"ldr d18, [%[weight_ptr],#16]\n"
"sadalp %[c13].4s, v31.8h\n"
"smull v28.8h, v8.8b, v19.8b\n"
"sadalp %[c00].4s, v24.8h\n"
"smull v29.8h, v9.8b, v19.8b\n"
"ldp q8, q9, [%[nchw_src_ptr_last_line]]\n"
"sadalp %[c01].4s, v25.8h\n"
"smull v30.8h, v14.8b, v19.8b\n"
"sadalp %[c02].4s, v26.8h\n"
"smull v31.8h, v15.8b, v19.8b\n"
"ldr q19, [%[weight_ptr_oc]]\n"
//! rearrange each last-line src register so that bytes 0 and 8
//! alternate across all 16 lanes (see src_idx_buffer)
"tbl v8.16b, {v8.16b}, %[vtbl].16b\n"
"tbl v9.16b, {v9.16b}, %[vtbl].16b\n"
"sadalp %[c03].4s, v27.8h\n"
"tbl v10.16b, {v10.16b}, %[vtbl].16b\n"
"tbl v11.16b, {v11.16b}, %[vtbl].16b\n"
"sadalp %[c10].4s, v28.8h\n"
"tbl v12.16b, {v12.16b}, %[vtbl].16b\n"
"tbl v13.16b, {v13.16b}, %[vtbl].16b\n"
"sadalp %[c11].4s, v29.8h\n"
/// last line////
"smull v24.8h, v8.8b, v16.8b\n"
"sadalp %[c12].4s, v30.8h\n"
"smull v25.8h, v9.8b, v16.8b\n"
"sadalp %[c13].4s, v31.8h\n"
"smull v26.8h, v10.8b, v16.8b\n"
"smull v27.8h, v11.8b, v16.8b\n"
"smlal2 v24.8h, v9.16b, v16.16b\n"
"smlal2 v25.8h, v10.16b, v16.16b\n"
"smlal2 v26.8h, v11.16b, v16.16b\n"
"smlal2 v27.8h, v12.16b, v16.16b\n"
"sadalp %[c00].4s, v24.8h\n"
"smull v28.8h, v8.8b, v19.8b\n"
"sadalp %[c01].4s, v25.8h\n"
"smull v29.8h, v9.8b, v19.8b\n"
"sadalp %[c02].4s, v26.8h\n"
"smull v30.8h, v10.8b, v19.8b\n"
"sadalp %[c03].4s, v27.8h\n"
"smull v31.8h, v11.8b, v19.8b\n"
"smlal2 v28.8h, v9.16b, v19.16b\n"
//! v9 <- byte 0 of v11 broadcast (src for the last element, c00/c10)
"dup v9.8b, v11.b[0]\n"
"smlal2 v29.8h, v10.16b, v19.16b\n"
"smlal2 v30.8h, v11.16b, v19.16b\n"
"smlal2 v31.8h, v12.16b, v19.16b\n"
"sadalp %[c10].4s, v28.8h\n"
"smull v24.8h, v10.8b, v18.8b\n"
"sadalp %[c11].4s, v29.8h\n"
"smull v25.8h, v11.8b, v18.8b\n"
"sadalp %[c12].4s, v30.8h\n"
"smull v26.8h, v12.8b, v18.8b\n"
"sadalp %[c13].4s, v31.8h\n"
"smull v27.8h, v13.8b, v18.8b\n"
//! x10 / x5 / x6: addresses of the single src byte and of the 4-byte
//! weight groups consumed by the last-element step below
"add x10, %[nchw_src_ptr_last_line], #96\n"
"sadalp %[c00].4s, v24.8h\n"
"smull v28.8h, v10.8b, v21.8b\n"

"sadalp %[c01].4s, v25.8h\n"
"add x5, %[weight_ptr], #24\n"
"smull v29.8h, v11.8b, v21.8b\n"
"add x6, %[weight_ptr_oc], #24\n"
"sadalp %[c02].4s, v26.8h\n"
"smull v30.8h, v12.8b, v21.8b\n"
"sadalp %[c03].4s, v27.8h\n"
"smull v31.8h, v13.8b, v21.8b\n"
"dup v10.8b, v12.b[0]\n"
"sadalp %[c10].4s, v28.8h\n"
"ld1r {v12.8b}, [x10]\n"
"sadalp %[c11].4s, v29.8h\n"
"dup v11.8b, v13.b[0]\n"
"sadalp %[c12].4s, v30.8h\n"
"ld1r {v16.2s}, [x5]\n"
"sadalp %[c13].4s, v31.8h\n"
"sxtl v16.8h, v16.8b\n"
///////////////last element/////////
//! the remaining products are done in int16: src bytes (v9-v12) and
//! weights (v16, v19) are sign-extended with sxtl and accumulated
//! directly into the int32 accumulators with smlal
"add %[weight_ptr], %[weight_ptr], %[weight_step_small]\n"
"sxtl v9.8h, v9.8b\n"
"ld1r {v19.2s}, [x6]\n"
"sxtl v10.8h, v10.8b\n"
"sxtl v11.8h, v11.8b\n"
"smlal %[c00].4s, v9.4h, v16.4h\n"
"sxtl v12.8h, v12.8b\n"
"smlal %[c01].4s, v10.4h, v16.4h\n"
"sxtl v19.8h, v19.8b\n"
"smlal %[c02].4s, v11.4h, v16.4h\n"
"smlal %[c03].4s, v12.4h, v16.4h\n"
"smlal %[c10].4s, v9.4h, v19.4h\n"
"smlal %[c11].4s, v10.4h, v19.4h\n"
"smlal %[c12].4s, v11.4h, v19.4h\n"
"smlal %[c13].4s, v12.4h, v19.4h\n"
:

[c00] "+w"(c[0][0]), [c10] "+w"(c[1][0]),
[c01] "+w"(c[0][1]), [c11] "+w"(c[1][1]),
[c02] "+w"(c[0][2]), [c12] "+w"(c[1][2]),
[c03] "+w"(c[0][3]), [c13] "+w"(c[1][3]),
[nchw_src_ptr] "+r"(nchw_src_ptr),
[weight_ptr] "+r"(weight_ptr),
[weight_ptr_oc] "+r"(weight_ptr_oc)

: [vtbl] "w"(vtbl),
[nchw_src_ptr_last_line] "r"(nchw_src_ptr_last_line),
[src_step] "r"(src_step), [weight_step] "r"(weight_step),
[weight_step_small] "r"(weight_step_small)
//! NOTE(review): x7, x8 and x9 are declared as clobbered but are not
//! referenced by the asm text above; only x5, x6 and x10 are written.
: "x5", "x6", "x7", "x8", "x9", "x10", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
"v19", "v20", "v21", "v24", "v25", "v26", "v27", "v28",
"v29", "v30", "v31", "cc", "memory");
}
store_ocx_ow4_remain_static<c_dim, remain_w>(c, op, dst_ptr, ld_dst_oc);
}
};
#endif
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int stride>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 5, oc_block, stride> {
@@ -467,6 +931,166 @@ struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 3, oc_block, stride> {
store_ocx_ow4_remain_static<c_dim, remain_w>(c, op, dst_ptr, ld_dst_oc);
}
};
#if MEGDNN_AARCH64
/**
 * AArch64 inline-asm specialization of KerNeonXXs2NchwNchw44 for
 * remain_w = 0, filter_size = 3, oc_block = 8, stride = 2.
 *
 * impl() initializes the int32 accumulators c[c_dim][4] through
 * init_ocx_ow4, then for every input channel runs one asm block that
 * accumulates into c[0][0..3] (weights read through weight_ptr) and
 * c[1][0..3] (weights read through weight_ptr_oc), and finally hands c to
 * store_ocx_ow4_remain_static.
 *
 * NOTE(review): the layout of src/weight (packing produced by the caller) is
 * not visible in this block; the comments below only describe the addressing
 * and instructions that appear here.
 */
template <BiasMode bias_mode, typename Op>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, 0, 3, 8, 2> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int filter_size = 3;
//! tbl indices: every even output byte takes source byte 0, every odd
//! output byte takes source byte 8
static const uint8_t src_idx_buffer[16] = {0, 8, 0, 8, 0, 8, 0, 8,
0, 8, 0, 8, 0, 8, 0, 8};
constexpr int oc_block = 8;
constexpr int remain_w = 0;

constexpr int oc_step = 4;
constexpr int ic_step = 1;
constexpr int loop_ic_step = 1;
constexpr int pack_iw_len = 4;

//! byte distance between two consecutive input channels of src
const int ic_stride = ih * iw * pack_iw_len;
//! byte offset from weight_ptr to the weights used for c[1][*]
const int ld_weight_oc = oc_step * filter_size * filter_size * ic;
//! weight bytes consumed per input channel (3 * 3 * 4 = 36)
const size_t weight_step = filter_size * filter_size * pack_iw_len;
constexpr int c_dim = OCHelper<oc_block>::val;

int32x4_t c[c_dim][4];
init_ocx_ow4<c_dim, bias_mode>(c, bias_ptr, oc_step);
uint8x16_t vtbl = vld1q_u8(src_idx_buffer);
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride;
//! row 2 (the 3rd filter row) of the current input channel
const int8_t* nchw_src_ptr_last_line =
src_ptr + ic_idx * ic_stride +
2 * iw * ic_step * pack_iw_len;
//! recomputed every iteration from weight_ptr, which the asm block
//! below advances by weight_step (it is a "+r" operand)
const int8_t* weight_ptr_oc = weight_ptr + ld_weight_oc;
/**
 * Register usage inside the asm block:
 * c00..c13, vtbl: operands, registers are chosen by the compiler
 * v24-v31 int16 temporaries (v24-v27 feed c0x, v28-v31 feed c1x)
 * v8-v15 src
 * v16-v19 weight
 * x5, x6, x7 scratch addresses for the last-element loads
 *
 * Instruction pattern: smull / smlal2 build int16 products from the low /
 * high 8 bytes of the operands, sadalp pairwise-adds those products into
 * the int32 accumulators.
 */
asm volatile(
//! load src 0,1
"ldp q8,q9, [%[nchw_src_ptr]]\n"
"ldr q16, [%[weight_ptr]]\n"
"ldp q10,q11, [%[nchw_src_ptr], #32]\n"
//! x5 / x6: address of the 4-byte weight group used by the
//! last-element step at the end of the block
"add x5, %[weight_ptr], #32\n"
"smull v24.8h, v8.8b, v16.8b\n"
"ldr q17, [%[weight_ptr_oc]]\n"
"smull v25.8h, v9.8b, v16.8b\n"
"add x6, %[weight_ptr_oc], #32\n"
"smull v26.8h, v10.8b, v16.8b\n"
"smull v27.8h, v11.8b, v16.8b\n"
"smlal2 v24.8h, v8.16b, v16.16b\n"
//! x7: address of the single src byte broadcast into v15 below
"add x7, %[nchw_src_ptr_last_line], #64\n"
"smlal2 v25.8h, v9.16b, v16.16b\n"
"smlal2 v26.8h, v10.16b, v16.16b\n"
"smlal2 v27.8h, v11.16b, v16.16b\n"
"sadalp %[c00].4s, v24.8h\n"
"smull v28.8h, v8.8b, v17.8b\n"
"ldr d12, [%[nchw_src_ptr],#16]\n"
"sadalp %[c01].4s, v25.8h\n"
"smull v29.8h, v9.8b, v17.8b\n"
"ldr d13, [%[nchw_src_ptr],#32]\n"
"sadalp %[c02].4s, v26.8h\n"
"smull v30.8h, v10.8b, v17.8b\n"
"ldr d14, [%[nchw_src_ptr],#48]\n"
"sadalp %[c03].4s, v27.8h\n"
"smull v31.8h, v11.8b, v17.8b\n"
"ldr d18, [%[weight_ptr],#16]\n"
"smlal2 v28.8h, v8.16b, v17.16b\n"
"ldr d19, [%[weight_ptr_oc],#16]\n"
"smlal2 v29.8h, v9.16b, v17.16b\n"
"ldr d15, [%[nchw_src_ptr],#64]\n"
"smlal2 v30.8h, v10.16b, v17.16b\n"
//! from here on v8-v11 hold the last line (row 2)
"ldp q8,q9, [%[nchw_src_ptr_last_line]]\n"
"smull v24.8h, v12.8b, v18.8b\n"
"sadalp %[c10].4s, v28.8h\n"
"smlal2 v31.8h, v11.16b, v17.16b\n"
"ldp q10,q11, [%[nchw_src_ptr_last_line], #32]\n"
"sadalp %[c11].4s, v29.8h\n"
"smull v25.8h, v13.8b, v18.8b\n"
//! rearrange each last-line src register so that bytes 0 and 8
//! alternate across all 16 lanes (see src_idx_buffer)
"tbl v8.16b, {v8.16b}, %[vtbl].16b\n"
"sadalp %[c12].4s, v30.8h\n"
"smull v26.8h, v14.8b, v18.8b\n"
"ldr d16, [%[weight_ptr],#24]\n"
"sadalp %[c13].4s, v31.8h\n"
"ldr d17, [%[weight_ptr_oc],#24]\n"
"smull v27.8h, v15.8b, v18.8b\n"
"tbl v9.16b, {v9.16b}, %[vtbl].16b\n"
"sadalp %[c00].4s, v24.8h\n"
"smull v28.8h, v12.8b, v19.8b\n"
"tbl v10.16b, {v10.16b}, %[vtbl].16b\n"
"sadalp %[c01].4s, v25.8h\n"
"smull v29.8h, v13.8b, v19.8b\n"
"tbl v11.16b, {v11.16b}, %[vtbl].16b\n"
"sadalp %[c02].4s, v26.8h\n"
"smull v30.8h, v14.8b, v19.8b\n"
"ld1r {v18.2s}, [x5]\n"
"sadalp %[c03].4s, v27.8h\n"
"smull v31.8h, v15.8b, v19.8b\n"
"ld1r {v19.2s}, [x6]\n"
"sadalp %[c10].4s, v28.8h\n"
"smull v24.8h, v8.8b, v16.8b\n"
"sadalp %[c11].4s, v29.8h\n"
"smull v25.8h, v9.8b, v16.8b\n"
//! v12-v14 <- byte 0 of v9-v11 broadcast (src for the last element).
//! The original issued this dup of v9 into v12 twice back to back
//! with no intervening write to v9 or v12; the redundant copy is
//! dropped.
"dup v12.8b, v9.b[0]\n"
"sadalp %[c12].4s, v30.8h\n"
"smull v26.8h, v10.8b, v16.8b\n"
"sadalp %[c13].4s, v31.8h\n"
"smull v27.8h, v11.8b, v16.8b\n"
"dup v13.8b, v10.b[0]\n"
"smull v28.8h, v8.8b, v17.8b\n"
"dup v14.8b, v11.b[0]\n"
"sadalp %[c00].4s, v24.8h\n"
"smull v29.8h, v9.8b, v17.8b\n"
"ld1r {v15.8b}, [x7]\n"
"sadalp %[c01].4s, v25.8h\n"
"smull v30.8h, v10.8b, v17.8b\n"
//! last element: src bytes (v12-v15) and weights (v18, v19) are
//! sign-extended to int16 and accumulated with smlal
"sxtl v12.8h, v12.8b\n"
"sxtl v18.8h, v18.8b\n"
"sadalp %[c02].4s, v26.8h\n"
"smull v31.8h, v11.8b, v17.8b\n"
"sxtl v13.8h, v13.8b\n"
"sadalp %[c03].4s, v27.8h\n"
"smlal %[c00].4s, v12.4h, v18.4h\n"
"sxtl v14.8h, v14.8b\n"
"sadalp %[c10].4s, v28.8h\n"
"smlal %[c01].4s, v13.4h, v18.4h\n"
"sxtl v15.8h, v15.8b\n"
"sadalp %[c11].4s, v29.8h\n"
"smlal %[c02].4s, v14.4h, v18.4h\n"
"sxtl v19.8h, v19.8b\n"
"sadalp %[c12].4s, v30.8h\n"
//! step weight_ptr to the next input channel's weights
"add %[weight_ptr], %[weight_ptr], %[weight_step]\n"
"smlal %[c03].4s, v15.4h, v18.4h\n"
"sadalp %[c13].4s, v31.8h\n"
"smlal %[c10].4s, v12.4h, v19.4h\n"
"smlal %[c11].4s, v13.4h, v19.4h\n"
"smlal %[c12].4s, v14.4h, v19.4h\n"
"smlal %[c13].4s, v15.4h, v19.4h\n"
:

[c00] "+w"(c[0][0]), [c10] "+w"(c[1][0]),
[c01] "+w"(c[0][1]), [c11] "+w"(c[1][1]),
[c02] "+w"(c[0][2]), [c12] "+w"(c[1][2]),
[c03] "+w"(c[0][3]), [c13] "+w"(c[1][3]),

[weight_ptr] "+r"(weight_ptr),
[weight_ptr_oc] "+r"(weight_ptr_oc)
: [vtbl] "w"(vtbl), [nchw_src_ptr] "r"(nchw_src_ptr),
[nchw_src_ptr_last_line] "r"(nchw_src_ptr_last_line),
[weight_step] "r"(weight_step)
: "x5", "x6", "x7", "v8", "v9", "v10", "v11", "v12", "v13",
"v14", "v15", "v16", "v17", "v18", "v19", "v24", "v25",
"v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory");
}
store_ocx_ow4_remain_static<c_dim, remain_w>(c, op, dst_ptr, ld_dst_oc);
}
};
#endif

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int stride>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 2, oc_block, stride> {


Loading…
Cancel
Save