|
@@ -207,24 +207,27 @@ struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 3, oc_block> { |
|
|
float32x4_t src[src_reg_size]; |
|
|
float32x4_t src[src_reg_size]; |
|
|
float32x4_t weight[c_dim][filter_size]; |
|
|
float32x4_t weight[c_dim][filter_size]; |
|
|
// row 0 |
|
|
// row 0 |
|
|
load_helper<5, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); |
|
|
|
|
|
load_helper<3, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, |
|
|
|
|
|
ld_weight_oc); |
|
|
|
|
|
|
|
|
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, |
|
|
|
|
|
0); |
|
|
|
|
|
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( |
|
|
|
|
|
weight, weight_ptr, ld_weight_oc); |
|
|
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32>(c, src, weight); |
|
|
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32>(c, src, weight); |
|
|
cal_helper<1, 1, c_dim, Vfmaq_laneq_f32>(c, src, weight); |
|
|
cal_helper<1, 1, c_dim, Vfmaq_laneq_f32>(c, src, weight); |
|
|
cal_helper<2, 2, c_dim, Vfmaq_laneq_f32>(c, src, weight); |
|
|
cal_helper<2, 2, c_dim, Vfmaq_laneq_f32>(c, src, weight); |
|
|
|
|
|
|
|
|
// row 1 |
|
|
// row 1 |
|
|
load_helper<5, 0, simd_len, 0, Vld1q_f32>(src, src_ptr + iw, 0); |
|
|
|
|
|
load_helper<3, 0, oc_step, c_dim, Vld1q_f32>( |
|
|
|
|
|
|
|
|
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>( |
|
|
|
|
|
src, src_ptr + iw, 0); |
|
|
|
|
|
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( |
|
|
weight, weight_ptr + 1 * ld_weight_fw, ld_weight_oc); |
|
|
weight, weight_ptr + 1 * ld_weight_fw, ld_weight_oc); |
|
|
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32>(c, src, weight); |
|
|
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32>(c, src, weight); |
|
|
cal_helper<1, 1, c_dim, Vfmaq_laneq_f32>(c, src, weight); |
|
|
cal_helper<1, 1, c_dim, Vfmaq_laneq_f32>(c, src, weight); |
|
|
cal_helper<2, 2, c_dim, Vfmaq_laneq_f32>(c, src, weight); |
|
|
cal_helper<2, 2, c_dim, Vfmaq_laneq_f32>(c, src, weight); |
|
|
|
|
|
|
|
|
// row 2 |
|
|
// row 2 |
|
|
load_helper<5, 0, simd_len, 0, Vld1q_f32>(src, src_ptr + 2 * iw, 0); |
|
|
|
|
|
load_helper<3, 0, oc_step, c_dim, Vld1q_f32>( |
|
|
|
|
|
|
|
|
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>( |
|
|
|
|
|
src, src_ptr + 2 * iw, 0); |
|
|
|
|
|
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( |
|
|
weight, weight_ptr + 2 * ld_weight_fw, ld_weight_oc); |
|
|
weight, weight_ptr + 2 * ld_weight_fw, ld_weight_oc); |
|
|
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32>(c, src, weight); |
|
|
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32>(c, src, weight); |
|
|
cal_helper<1, 1, c_dim, Vfmaq_laneq_f32>(c, src, weight); |
|
|
cal_helper<1, 1, c_dim, Vfmaq_laneq_f32>(c, src, weight); |
|
@@ -238,6 +241,52 @@ struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 3, oc_block> { |
|
|
} |
|
|
} |
|
|
}; |
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block> |
|
|
|
|
|
struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 2, oc_block> { |
|
|
|
|
|
static void impl(const float32_t* src_ptr, const float32_t* weight_ptr, |
|
|
|
|
|
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, |
|
|
|
|
|
int ih, int iw, int ld_dst_oc, const Op& op) { |
|
|
|
|
|
constexpr int loop_ic_step = 1; |
|
|
|
|
|
constexpr int filter_size = 2; |
|
|
|
|
|
constexpr int oc_step = 4; |
|
|
|
|
|
constexpr int simd_len = 4; |
|
|
|
|
|
constexpr int src_reg_size = 4; |
|
|
|
|
|
|
|
|
|
|
|
constexpr int ld_weight_fw = oc_step * filter_size; |
|
|
|
|
|
const int ld_weight_oc = oc_step * filter_size * filter_size * ic; |
|
|
|
|
|
const int ld_weight_ic = oc_step * filter_size * filter_size; |
|
|
|
|
|
const int ld_src_ic = ih * iw; |
|
|
|
|
|
constexpr int c_dim = OCHelper<oc_block>::val; |
|
|
|
|
|
float32x4_t c[c_dim][8]; |
|
|
|
|
|
init_ocx_ow8<c_dim, bias_mode>(c, bias_ptr, oc_step); |
|
|
|
|
|
|
|
|
|
|
|
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { |
|
|
|
|
|
float32x4_t src[src_reg_size]; |
|
|
|
|
|
float32x4_t weight[c_dim][filter_size]; |
|
|
|
|
|
// row 0 |
|
|
|
|
|
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, |
|
|
|
|
|
0); |
|
|
|
|
|
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( |
|
|
|
|
|
weight, weight_ptr, ld_weight_oc); |
|
|
|
|
|
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32>(c, src, weight); |
|
|
|
|
|
cal_helper<1, 1, c_dim, Vfmaq_laneq_f32>(c, src, weight); |
|
|
|
|
|
|
|
|
|
|
|
// row 1 |
|
|
|
|
|
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>( |
|
|
|
|
|
src, src_ptr + iw, 0); |
|
|
|
|
|
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( |
|
|
|
|
|
weight, weight_ptr + 1 * ld_weight_fw, ld_weight_oc); |
|
|
|
|
|
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32>(c, src, weight); |
|
|
|
|
|
cal_helper<1, 1, c_dim, Vfmaq_laneq_f32>(c, src, weight); |
|
|
|
|
|
|
|
|
|
|
|
src_ptr += ld_src_ic; |
|
|
|
|
|
weight_ptr += ld_weight_ic; |
|
|
|
|
|
} |
|
|
|
|
|
store_ocx_ow8_remain_static<c_dim, remain_w, Op>(c, op, dst_ptr, |
|
|
|
|
|
ld_dst_oc); |
|
|
|
|
|
} |
|
|
|
|
|
}; |
|
|
|
|
|
|
|
|
} // namespace |
|
|
} // namespace |
|
|
|
|
|
|
|
|
void conv_bias::pack_weight_fp32_nchw_nchw44(const float32_t* in_ptr, |
|
|
void conv_bias::pack_weight_fp32_nchw_nchw44(const float32_t* in_ptr, |
|
@@ -383,19 +432,12 @@ static void conv_direct_stride2_fp32_nchw_nchw44( |
|
|
ow, op, ph, pw); \ |
|
|
ow, op, ph, pw); \ |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
CONSTRUCT_FUNC(2); |
|
|
CONSTRUCT_FUNC(3); |
|
|
CONSTRUCT_FUNC(3); |
|
|
CONSTRUCT_FUNC(5); |
|
|
CONSTRUCT_FUNC(5); |
|
|
CONSTRUCT_FUNC(7); |
|
|
CONSTRUCT_FUNC(7); |
|
|
#undef CONSTRUCT_FUNC |
|
|
#undef CONSTRUCT_FUNC |
|
|
|
|
|
|
|
|
template <BiasMode bias_mode, typename Op> |
|
|
|
|
|
void conv_bias::conv_direct_stride2_2x2_fp32_nchw_nchw44( |
|
|
|
|
|
const float32_t*, const float32_t*, const float32_t*, float32_t*, |
|
|
|
|
|
float32_t*, const int, const int, const int, const int, const int, |
|
|
|
|
|
const int, const int, const Op&, const int, const int) { |
|
|
|
|
|
megdnn_assert(0, "not imple nchw_nchw44 2x2s2 conv"); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
#define INSTANTIATION(stride, i, bias, Op) \ |
|
|
#define INSTANTIATION(stride, i, bias, Op) \ |
|
|
template void conv_bias:: \ |
|
|
template void conv_bias:: \ |
|
|
conv_direct_##stride##_##i##x##i##_fp32_nchw_nchw44<bias, Op>( \ |
|
|
conv_direct_##stride##_##i##x##i##_fp32_nchw_nchw44<bias, Op>( \ |
|
|