GitOrigin-RevId: f1c4cae667
tags/v1.0.0-rc1
@@ -24,82 +24,75 @@ using namespace megdnn; | |||||
using namespace arm_common; | using namespace arm_common; | ||||
namespace { | namespace { | ||||
template <int src_idx, int weight_idx, int c_dim, int ow_block, typename T, | |||||
typename T2, typename T3, typename T4> | |||||
template <int src_idx, int weight_idx, int c_dim, int ow_block, int remain_w, | |||||
typename T, typename T2, typename T3, typename T4> | |||||
struct ShiftCalHelper { | struct ShiftCalHelper { | ||||
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight); | static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight); | ||||
}; | }; | ||||
template <int src_idx, int weight_idx, typename T, typename T2, typename T3, | |||||
typename T4> | |||||
struct ShiftCalHelper<src_idx, weight_idx, 2, 8, T, T2, T3, T4> { | |||||
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { | |||||
#define cb(step, lane) \ | |||||
c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \ | |||||
src[(step + src_idx) % 8], lane); \ | |||||
c[1][step] = vfmaq_laneq_f32(c[1][step], weight[1][lane], \ | |||||
src[(step + src_idx) % 8], lane); | |||||
UNROLL_CALL_RAW(8, cb, 0); | |||||
UNROLL_CALL_RAW(8, cb, 1); | |||||
UNROLL_CALL_RAW(8, cb, 2); | |||||
UNROLL_CALL_RAW(8, cb, 3); | |||||
#undef cb | |||||
} | |||||
}; | |||||
template <int src_idx, int weight_idx, typename T, typename T2, typename T3, | |||||
typename T4> | |||||
struct ShiftCalHelper<src_idx, weight_idx, 2, 4, T, T2, T3, T4> { | |||||
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { | |||||
#define cb(step, lane) \ | |||||
c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \ | |||||
src[(step + src_idx) % 4], lane); \ | |||||
c[1][step] = vfmaq_laneq_f32(c[1][step], weight[1][lane], \ | |||||
src[(step + src_idx) % 4], lane); | |||||
UNROLL_CALL_RAW(4, cb, 0); | |||||
UNROLL_CALL_RAW(4, cb, 1); | |||||
UNROLL_CALL_RAW(4, cb, 2); | |||||
UNROLL_CALL_RAW(4, cb, 3); | |||||
#undef cb | |||||
} | |||||
template <int src_idx, int weight_idx, int c_dim, int ow_block, typename T, | |||||
typename T2, typename T3, typename T4> | |||||
struct ShiftCalHelper<src_idx, weight_idx, c_dim, ow_block, 0, T, T2, T3, T4> { | |||||
static MEGDNN_ALWAYS_INLINE void impl(T&, T2&, T3&) {} | |||||
}; | }; | ||||
template <int src_idx, int weight_idx, typename T, typename T2, typename T3, | |||||
typename T4> | |||||
struct ShiftCalHelper<src_idx, weight_idx, 1, 8, T, T2, T3, T4> { | |||||
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { | |||||
#define cb(step, lane) \ | |||||
c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \ | |||||
src[(step + src_idx) % 8], lane); | |||||
UNROLL_CALL_RAW(8, cb, 0); | |||||
UNROLL_CALL_RAW(8, cb, 1); | |||||
UNROLL_CALL_RAW(8, cb, 2); | |||||
UNROLL_CALL_RAW(8, cb, 3); | |||||
#undef cb | |||||
} | |||||
}; | |||||
template <int src_idx, int weight_idx, typename T, typename T2, typename T3, | |||||
typename T4> | |||||
struct ShiftCalHelper<src_idx, weight_idx, 1, 4, T, T2, T3, T4> { | |||||
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { | |||||
#define cb(step, lane) \ | |||||
c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \ | |||||
src[(step + src_idx) % 4], lane); | |||||
#define cb2(step, lane, ow_block) \ | |||||
c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \ | |||||
src[(step + src_idx) % ow_block], lane); \ | |||||
c[1][step] = vfmaq_laneq_f32(c[1][step], weight[1][lane], \ | |||||
src[(step + src_idx) % ow_block], lane); | |||||
UNROLL_CALL_RAW(4, cb, 0); | |||||
UNROLL_CALL_RAW(4, cb, 1); | |||||
UNROLL_CALL_RAW(4, cb, 2); | |||||
UNROLL_CALL_RAW(4, cb, 3); | |||||
#define cb(step, lane, ow_block) \ | |||||
c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \ | |||||
src[(step + src_idx) % ow_block], lane); | |||||
#define SHIFT_CAL_HELPER(ow_block, remain_w) \ | |||||
template <int src_idx, int weight_idx, typename T, typename T2, \ | |||||
typename T3, typename T4> \ | |||||
struct ShiftCalHelper<src_idx, weight_idx, 2, ow_block, remain_w, T, T2, \ | |||||
T3, T4> { \ | |||||
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { \ | |||||
UNROLL_CALL_RAW(remain_w, cb2, 0, ow_block); \ | |||||
UNROLL_CALL_RAW(remain_w, cb2, 1, ow_block); \ | |||||
UNROLL_CALL_RAW(remain_w, cb2, 2, ow_block); \ | |||||
UNROLL_CALL_RAW(remain_w, cb2, 3, ow_block); \ | |||||
} \ | |||||
}; \ | |||||
template <int src_idx, int weight_idx, typename T, typename T2, \ | |||||
typename T3, typename T4> \ | |||||
struct ShiftCalHelper<src_idx, weight_idx, 1, ow_block, remain_w, T, T2, \ | |||||
T3, T4> { \ | |||||
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { \ | |||||
UNROLL_CALL_RAW(remain_w, cb, 0, ow_block); \ | |||||
UNROLL_CALL_RAW(remain_w, cb, 1, ow_block); \ | |||||
UNROLL_CALL_RAW(remain_w, cb, 2, ow_block); \ | |||||
UNROLL_CALL_RAW(remain_w, cb, 3, ow_block); \ | |||||
} \ | |||||
}; | |||||
SHIFT_CAL_HELPER(8, 1); | |||||
SHIFT_CAL_HELPER(8, 2); | |||||
SHIFT_CAL_HELPER(8, 3); | |||||
SHIFT_CAL_HELPER(8, 4); | |||||
SHIFT_CAL_HELPER(8, 5); | |||||
SHIFT_CAL_HELPER(8, 6); | |||||
SHIFT_CAL_HELPER(8, 7); | |||||
SHIFT_CAL_HELPER(8, 8); | |||||
SHIFT_CAL_HELPER(4, 1); | |||||
SHIFT_CAL_HELPER(4, 2); | |||||
SHIFT_CAL_HELPER(4, 3); | |||||
SHIFT_CAL_HELPER(4, 4); | |||||
#undef SHIFT_CAL_HELPER | |||||
#undef cb | #undef cb | ||||
} | |||||
}; | |||||
#undef cb2 | |||||
template <int src_idx, int weight_idx, int c_dim, int ow_block, typename T, | |||||
typename T2, typename T3> | |||||
template <int src_idx, int weight_idx, int c_dim, int ow_block, int remain_w, | |||||
typename T, typename T2, typename T3> | |||||
MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight) { | MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight) { | ||||
ShiftCalHelper<src_idx, weight_idx, c_dim, ow_block, T, T2, T3, int>::impl( | |||||
c, src, weight); | |||||
ShiftCalHelper<src_idx, weight_idx, c_dim, ow_block, remain_w, T, T2, T3, | |||||
int>::impl(c, src, weight); | |||||
}; | }; | ||||
template <int oc> | template <int oc> | ||||
struct OCHelper { | struct OCHelper { | ||||
@@ -151,7 +144,7 @@ struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 2, oc_block, ow_block> { | |||||
const int ld_src_iw = iw * oc_step; | const int ld_src_iw = iw * oc_step; | ||||
constexpr int c_dim = OCHelper<oc_block>::val; | constexpr int c_dim = OCHelper<oc_block>::val; | ||||
float32x4_t c[c_dim][ow_block]; | float32x4_t c[c_dim][ow_block]; | ||||
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias); | |||||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias); | |||||
for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | ||||
const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; | const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; | ||||
@@ -162,11 +155,11 @@ struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 2, oc_block, ow_block> { | |||||
0); | 0); | ||||
load_helper<ic_step, 0, oc_step, c_dim, Vld1q_f32>( | load_helper<ic_step, 0, oc_step, c_dim, Vld1q_f32>( | ||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<0, 0, c_dim, ow_block>(c, src, weight); | |||||
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||||
src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step); | src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step); | ||||
load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<1, 0, c_dim, ow_block>(c, src, weight); | |||||
cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||||
src_ptr += ld_src_iw; | src_ptr += ld_src_iw; | ||||
weight_ptr += ld_weight_fh; | weight_ptr += ld_weight_fh; | ||||
} | } | ||||
@@ -196,7 +189,7 @@ struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 3, oc_block, ow_block> { | |||||
const int ld_src_iw = iw * oc_step; | const int ld_src_iw = iw * oc_step; | ||||
constexpr int c_dim = OCHelper<oc_block>::val; | constexpr int c_dim = OCHelper<oc_block>::val; | ||||
float32x4_t c[c_dim][ow_block]; | float32x4_t c[c_dim][ow_block]; | ||||
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias); | |||||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias); | |||||
for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | ||||
const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; | const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; | ||||
@@ -207,15 +200,15 @@ struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 3, oc_block, ow_block> { | |||||
0); | 0); | ||||
load_helper<ic_step, 0, oc_step, c_dim, Vld1q_f32>( | load_helper<ic_step, 0, oc_step, c_dim, Vld1q_f32>( | ||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<0, 0, c_dim, ow_block>(c, src, weight); | |||||
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||||
src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step); | src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step); | ||||
load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<1, 0, c_dim, ow_block>(c, src, weight); | |||||
cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||||
src[1] = vld1q_f32(src_ptr + (ow_block + 1) * ic_step); | src[1] = vld1q_f32(src_ptr + (ow_block + 1) * ic_step); | ||||
load_helper<ic_step, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<ic_step, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<2, 0, c_dim, ow_block>(c, src, weight); | |||||
cal_helper<2, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||||
src_ptr += ld_src_iw; | src_ptr += ld_src_iw; | ||||
weight_ptr += ld_weight_fh; | weight_ptr += ld_weight_fh; | ||||
} | } | ||||
@@ -244,7 +237,7 @@ struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 5, oc_block, ow_block> { | |||||
const int ld_src_iw = iw * oc_step; | const int ld_src_iw = iw * oc_step; | ||||
constexpr int c_dim = OCHelper<oc_block>::val; | constexpr int c_dim = OCHelper<oc_block>::val; | ||||
float32x4_t c[c_dim][ow_block]; | float32x4_t c[c_dim][ow_block]; | ||||
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias); | |||||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias); | |||||
for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | ||||
const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; | const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; | ||||
@@ -255,27 +248,27 @@ struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 5, oc_block, ow_block> { | |||||
0); | 0); | ||||
load_helper<ic_step, 0, oc_step, c_dim, Vld1q_f32>( | load_helper<ic_step, 0, oc_step, c_dim, Vld1q_f32>( | ||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<0, 0, c_dim, ow_block>(c, src, weight); | |||||
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||||
src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step); | src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step); | ||||
load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<1, 0, c_dim, ow_block>(c, src, weight); | |||||
cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||||
src[1] = vld1q_f32(src_ptr + (ow_block + 1) * ic_step); | src[1] = vld1q_f32(src_ptr + (ow_block + 1) * ic_step); | ||||
load_helper<ic_step, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<ic_step, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<2, 0, c_dim, ow_block>(c, src, weight); | |||||
cal_helper<2, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||||
src[2] = vld1q_f32(src_ptr + (ow_block + 2) * ic_step); | src[2] = vld1q_f32(src_ptr + (ow_block + 2) * ic_step); | ||||
load_helper<ic_step, 3 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<ic_step, 3 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<3, 0, c_dim, ow_block>(c, src, weight); | |||||
cal_helper<3, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||||
src[3] = vld1q_f32(src_ptr + (ow_block + 3) * ic_step); | src[3] = vld1q_f32(src_ptr + (ow_block + 3) * ic_step); | ||||
load_helper<ic_step, 4 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<ic_step, 4 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<4, 0, c_dim, ow_block>(c, src, weight); | |||||
cal_helper<4, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||||
src_ptr += ld_src_iw; | src_ptr += ld_src_iw; | ||||
weight_ptr += ld_weight_fh; | weight_ptr += ld_weight_fh; | ||||
} | } | ||||
@@ -305,7 +298,7 @@ struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 7, oc_block, ow_block> { | |||||
const int ld_src_iw = iw * oc_step; | const int ld_src_iw = iw * oc_step; | ||||
constexpr int c_dim = OCHelper<oc_block>::val; | constexpr int c_dim = OCHelper<oc_block>::val; | ||||
float32x4_t c[c_dim][ow_block]; | float32x4_t c[c_dim][ow_block]; | ||||
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias); | |||||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias); | |||||
for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | ||||
const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; | const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; | ||||
@@ -316,37 +309,37 @@ struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 7, oc_block, ow_block> { | |||||
0); | 0); | ||||
load_helper<ic_step, 0, oc_step, c_dim, Vld1q_f32>( | load_helper<ic_step, 0, oc_step, c_dim, Vld1q_f32>( | ||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<0, 0, c_dim, ow_block>(c, src, weight); | |||||
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||||
src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step); | src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step); | ||||
load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<1, 0, c_dim, ow_block>(c, src, weight); | |||||
cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||||
src[1] = vld1q_f32(src_ptr + (ow_block + 1) * ic_step); | src[1] = vld1q_f32(src_ptr + (ow_block + 1) * ic_step); | ||||
load_helper<ic_step, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<ic_step, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<2, 0, c_dim, ow_block>(c, src, weight); | |||||
cal_helper<2, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||||
src[2] = vld1q_f32(src_ptr + (ow_block + 2) * ic_step); | src[2] = vld1q_f32(src_ptr + (ow_block + 2) * ic_step); | ||||
load_helper<ic_step, 3 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<ic_step, 3 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<3, 0, c_dim, ow_block>(c, src, weight); | |||||
cal_helper<3, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||||
src[3] = vld1q_f32(src_ptr + (ow_block + 3) * ic_step); | src[3] = vld1q_f32(src_ptr + (ow_block + 3) * ic_step); | ||||
load_helper<ic_step, 4 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<ic_step, 4 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<4, 0, c_dim, ow_block>(c, src, weight); | |||||
cal_helper<4, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||||
src[4] = vld1q_f32(src_ptr + (ow_block + 4) * ic_step); | src[4] = vld1q_f32(src_ptr + (ow_block + 4) * ic_step); | ||||
load_helper<ic_step, 5 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<ic_step, 5 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<5, 0, c_dim, ow_block>(c, src, weight); | |||||
cal_helper<5, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||||
src[5] = vld1q_f32(src_ptr + (ow_block + 5) * ic_step); | src[5] = vld1q_f32(src_ptr + (ow_block + 5) * ic_step); | ||||
load_helper<ic_step, 6 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<ic_step, 6 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<6, 0, c_dim, ow_block>(c, src, weight); | |||||
cal_helper<6, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||||
src_ptr += ld_src_iw; | src_ptr += ld_src_iw; | ||||
weight_ptr += ld_weight_fh; | weight_ptr += ld_weight_fh; | ||||
} | } | ||||
@@ -24,83 +24,77 @@ using namespace megdnn; | |||||
using namespace arm_common; | using namespace arm_common; | ||||
namespace { | namespace { | ||||
template <int src_idx, int weight_idx, int c_dim, int ow_block, typename T, | |||||
typename T2, typename T3, typename T4> | |||||
template <int src_idx, int weight_idx, int c_dim, int ow_block, int remain_w, | |||||
typename T, typename T2, typename T3, typename T4> | |||||
struct ShiftCalHelper { | struct ShiftCalHelper { | ||||
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight); | static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight); | ||||
}; | }; | ||||
template <int src_idx, int weight_idx, typename T, typename T2, typename T3, | |||||
typename T4> | |||||
struct ShiftCalHelper<src_idx, weight_idx, 2, 8, T, T2, T3, T4> { | |||||
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { | |||||
#define cb(step, lane) \ | |||||
c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \ | |||||
src[(step + src_idx) % 8], lane); \ | |||||
c[1][step] = vfmaq_laneq_f32(c[1][step], weight[1][lane], \ | |||||
src[(step + src_idx) % 8], lane); | |||||
UNROLL_CALL_RAW(8, cb, 0); | |||||
UNROLL_CALL_RAW(8, cb, 1); | |||||
UNROLL_CALL_RAW(8, cb, 2); | |||||
UNROLL_CALL_RAW(8, cb, 3); | |||||
#undef cb | |||||
} | |||||
}; | |||||
template <int src_idx, int weight_idx, typename T, typename T2, typename T3, | |||||
typename T4> | |||||
struct ShiftCalHelper<src_idx, weight_idx, 2, 4, T, T2, T3, T4> { | |||||
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { | |||||
#define cb(step, lane) \ | |||||
c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \ | |||||
src[(step + src_idx) % 4], lane); \ | |||||
c[1][step] = vfmaq_laneq_f32(c[1][step], weight[1][lane], \ | |||||
src[(step + src_idx) % 4], lane); | |||||
UNROLL_CALL_RAW(4, cb, 0); | |||||
UNROLL_CALL_RAW(4, cb, 1); | |||||
UNROLL_CALL_RAW(4, cb, 2); | |||||
UNROLL_CALL_RAW(4, cb, 3); | |||||
#undef cb | |||||
} | |||||
template <int src_idx, int weight_idx, int c_dim, int ow_block, typename T, | |||||
typename T2, typename T3, typename T4> | |||||
struct ShiftCalHelper<src_idx, weight_idx, c_dim, ow_block, 0, T, T2, T3, T4> { | |||||
static MEGDNN_ALWAYS_INLINE void impl(T&, T2&, T3&) {} | |||||
}; | }; | ||||
template <int src_idx, int weight_idx, typename T, typename T2, typename T3, | |||||
typename T4> | |||||
struct ShiftCalHelper<src_idx, weight_idx, 1, 8, T, T2, T3, T4> { | |||||
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { | |||||
#define cb(step, lane) \ | |||||
c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \ | |||||
src[(step + src_idx) % 8], lane); | |||||
UNROLL_CALL_RAW(8, cb, 0); | |||||
UNROLL_CALL_RAW(8, cb, 1); | |||||
UNROLL_CALL_RAW(8, cb, 2); | |||||
UNROLL_CALL_RAW(8, cb, 3); | |||||
#undef cb | |||||
} | |||||
}; | |||||
template <int src_idx, int weight_idx, typename T, typename T2, typename T3, | |||||
typename T4> | |||||
struct ShiftCalHelper<src_idx, weight_idx, 1, 4, T, T2, T3, T4> { | |||||
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { | |||||
#define cb(step, lane) \ | |||||
c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \ | |||||
src[(step + src_idx) % 4], lane); | |||||
#define cb2(step, lane, ow_block) \ | |||||
c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \ | |||||
src[(step + src_idx) % ow_block], lane); \ | |||||
c[1][step] = vfmaq_laneq_f32(c[1][step], weight[1][lane], \ | |||||
src[(step + src_idx) % ow_block], lane); | |||||
UNROLL_CALL_RAW(4, cb, 0); | |||||
UNROLL_CALL_RAW(4, cb, 1); | |||||
UNROLL_CALL_RAW(4, cb, 2); | |||||
UNROLL_CALL_RAW(4, cb, 3); | |||||
#define cb(step, lane, ow_block) \ | |||||
c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \ | |||||
src[(step + src_idx) % ow_block], lane); | |||||
#define SHIFT_CAL_HELPER(ow_block, remain_w) \ | |||||
template <int src_idx, int weight_idx, typename T, typename T2, \ | |||||
typename T3, typename T4> \ | |||||
struct ShiftCalHelper<src_idx, weight_idx, 2, ow_block, remain_w, T, T2, \ | |||||
T3, T4> { \ | |||||
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { \ | |||||
UNROLL_CALL_RAW(remain_w, cb2, 0, ow_block); \ | |||||
UNROLL_CALL_RAW(remain_w, cb2, 1, ow_block); \ | |||||
UNROLL_CALL_RAW(remain_w, cb2, 2, ow_block); \ | |||||
UNROLL_CALL_RAW(remain_w, cb2, 3, ow_block); \ | |||||
} \ | |||||
}; \ | |||||
template <int src_idx, int weight_idx, typename T, typename T2, \ | |||||
typename T3, typename T4> \ | |||||
struct ShiftCalHelper<src_idx, weight_idx, 1, ow_block, remain_w, T, T2, \ | |||||
T3, T4> { \ | |||||
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { \ | |||||
UNROLL_CALL_RAW(remain_w, cb, 0, ow_block); \ | |||||
UNROLL_CALL_RAW(remain_w, cb, 1, ow_block); \ | |||||
UNROLL_CALL_RAW(remain_w, cb, 2, ow_block); \ | |||||
UNROLL_CALL_RAW(remain_w, cb, 3, ow_block); \ | |||||
} \ | |||||
}; | |||||
SHIFT_CAL_HELPER(8, 1); | |||||
SHIFT_CAL_HELPER(8, 2); | |||||
SHIFT_CAL_HELPER(8, 3); | |||||
SHIFT_CAL_HELPER(8, 4); | |||||
SHIFT_CAL_HELPER(8, 5); | |||||
SHIFT_CAL_HELPER(8, 6); | |||||
SHIFT_CAL_HELPER(8, 7); | |||||
SHIFT_CAL_HELPER(8, 8); | |||||
SHIFT_CAL_HELPER(4, 1); | |||||
SHIFT_CAL_HELPER(4, 2); | |||||
SHIFT_CAL_HELPER(4, 3); | |||||
SHIFT_CAL_HELPER(4, 4); | |||||
#undef SHIFT_CAL_HELPER | |||||
#undef cb | #undef cb | ||||
} | |||||
}; | |||||
#undef cb2 | |||||
template <int src_idx, int weight_idx, int c_dim, int ow_block, typename T, | |||||
typename T2, typename T3> | |||||
template <int src_idx, int weight_idx, int c_dim, int ow_block, int remain_w, | |||||
typename T, typename T2, typename T3> | |||||
MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight) { | MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight) { | ||||
ShiftCalHelper<src_idx, weight_idx, c_dim, ow_block, T, T2, T3, int>::impl( | |||||
c, src, weight); | |||||
ShiftCalHelper<src_idx, weight_idx, c_dim, ow_block, remain_w, T, T2, T3, | |||||
int>::impl(c, src, weight); | |||||
}; | }; | ||||
template <int oc> | template <int oc> | ||||
struct OCHelper { | struct OCHelper { | ||||
public: | public: | ||||
@@ -151,7 +145,7 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 2, oc_block, ow_block> { | |||||
const int ld_src_iw = iw * oc_step; | const int ld_src_iw = iw * oc_step; | ||||
constexpr int c_dim = OCHelper<oc_block>::val; | constexpr int c_dim = OCHelper<oc_block>::val; | ||||
float32x4_t c[c_dim][ow_block]; | float32x4_t c[c_dim][ow_block]; | ||||
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias); | |||||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias); | |||||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | ||||
const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; | const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; | ||||
@@ -163,13 +157,13 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 2, oc_block, ow_block> { | |||||
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); | load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); | ||||
load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, | load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, | ||||
ld_weight_oc); | ld_weight_oc); | ||||
cal_helper<0, 0, c_dim, ow_block>(c, src, weight); | |||||
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||||
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd, | load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd, | ||||
0); | 0); | ||||
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<0, 0, c_dim, ow_block>(c, src, weight); | |||||
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||||
src_ptr += ld_src_iw; | src_ptr += ld_src_iw; | ||||
src_ptr_odd += ld_src_iw; | src_ptr_odd += ld_src_iw; | ||||
weight_ptr += ld_weight_fh; | weight_ptr += ld_weight_fh; | ||||
@@ -177,13 +171,13 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 2, oc_block, ow_block> { | |||||
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); | load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); | ||||
load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, | load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, | ||||
ld_weight_oc); | ld_weight_oc); | ||||
cal_helper<0, 0, c_dim, ow_block>(c, src, weight); | |||||
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||||
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd, | load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd, | ||||
0); | 0); | ||||
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<0, 0, c_dim, ow_block>(c, src, weight); | |||||
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||||
src_ptr += ld_src_iw; | src_ptr += ld_src_iw; | ||||
src_ptr_odd += ld_src_iw; | src_ptr_odd += ld_src_iw; | ||||
weight_ptr += ld_weight_fh; | weight_ptr += ld_weight_fh; | ||||
@@ -213,7 +207,7 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 3, oc_block, ow_block> { | |||||
const int ld_src_iw = iw * oc_step; | const int ld_src_iw = iw * oc_step; | ||||
constexpr int c_dim = OCHelper<oc_block>::val; | constexpr int c_dim = OCHelper<oc_block>::val; | ||||
float32x4_t c[c_dim][ow_block]; | float32x4_t c[c_dim][ow_block]; | ||||
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias); | |||||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias); | |||||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | ||||
const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; | const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; | ||||
const float* src_ptr_odd = src_ptr_odd_origin + ic_idx * ld_src_ic; | const float* src_ptr_odd = src_ptr_odd_origin + ic_idx * ld_src_ic; | ||||
@@ -224,18 +218,18 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 3, oc_block, ow_block> { | |||||
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); | load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); | ||||
load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, | load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, | ||||
ld_weight_oc); | ld_weight_oc); | ||||
cal_helper<0, 0, c_dim, ow_block>(c, src, weight); | |||||
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||||
src[0] = vld1q_f32(src_ptr + ow_block * simd_len); | src[0] = vld1q_f32(src_ptr + ow_block * simd_len); | ||||
load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<1, 0, c_dim, ow_block>(c, src, weight); | |||||
cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||||
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd, | load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd, | ||||
0); | 0); | ||||
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<0, 0, c_dim, ow_block>(c, src, weight); | |||||
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||||
src_ptr += ld_src_iw; | src_ptr += ld_src_iw; | ||||
src_ptr_odd += ld_src_iw; | src_ptr_odd += ld_src_iw; | ||||
weight_ptr += ld_weight_fh; | weight_ptr += ld_weight_fh; | ||||
@@ -243,17 +237,17 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 3, oc_block, ow_block> { | |||||
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); | load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); | ||||
load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, | load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, | ||||
ld_weight_oc); | ld_weight_oc); | ||||
cal_helper<0, 0, c_dim, ow_block>(c, src, weight); | |||||
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||||
src[0] = vld1q_f32(src_ptr + ow_block * simd_len); | src[0] = vld1q_f32(src_ptr + ow_block * simd_len); | ||||
load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<1, 0, c_dim, ow_block>(c, src, weight); | |||||
cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||||
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd, | load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd, | ||||
0); | 0); | ||||
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<0, 0, c_dim, ow_block>(c, src, weight); | |||||
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||||
src_ptr += ld_src_iw; | src_ptr += ld_src_iw; | ||||
src_ptr_odd += ld_src_iw; | src_ptr_odd += ld_src_iw; | ||||
weight_ptr += ld_weight_fh; | weight_ptr += ld_weight_fh; | ||||
@@ -261,18 +255,18 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 3, oc_block, ow_block> { | |||||
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); | load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); | ||||
load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, | load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, | ||||
ld_weight_oc); | ld_weight_oc); | ||||
cal_helper<0, 0, c_dim, ow_block>(c, src, weight); | |||||
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||||
src[0] = vld1q_f32(src_ptr + ow_block * simd_len); | src[0] = vld1q_f32(src_ptr + ow_block * simd_len); | ||||
load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<1, 0, c_dim, ow_block>(c, src, weight); | |||||
cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||||
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd, | load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd, | ||||
0); | 0); | ||||
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<0, 0, c_dim, ow_block>(c, src, weight); | |||||
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||||
src_ptr += ld_src_iw; | src_ptr += ld_src_iw; | ||||
src_ptr_odd += ld_src_iw; | src_ptr_odd += ld_src_iw; | ||||
weight_ptr += ld_weight_fh; | weight_ptr += ld_weight_fh; | ||||
@@ -302,7 +296,7 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 5, oc_block, ow_block> { | |||||
const int ld_src_iw = iw * oc_step; | const int ld_src_iw = iw * oc_step; | ||||
constexpr int c_dim = OCHelper<oc_block>::val; | constexpr int c_dim = OCHelper<oc_block>::val; | ||||
float32x4_t c[c_dim][ow_block]; | float32x4_t c[c_dim][ow_block]; | ||||
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias); | |||||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias); | |||||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | ||||
const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; | const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; | ||||
@@ -316,25 +310,25 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 5, oc_block, ow_block> { | |||||
0); | 0); | ||||
load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, | load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, | ||||
ld_weight_oc); | ld_weight_oc); | ||||
cal_helper<0, 0, c_dim, ow_block>(c, src, weight); | |||||
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||||
src[0] = vld1q_f32(src_ptr + ow_block * simd_len); | src[0] = vld1q_f32(src_ptr + ow_block * simd_len); | ||||
load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<1, 0, c_dim, ow_block>(c, src, weight); | |||||
cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||||
src[1] = vld1q_f32(src_ptr + (ow_block + 1) * simd_len); | src[1] = vld1q_f32(src_ptr + (ow_block + 1) * simd_len); | ||||
load_helper<4, 4 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<4, 4 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<2, 0, c_dim, ow_block>(c, src, weight); | |||||
cal_helper<2, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||||
// odd element | // odd element | ||||
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>( | load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>( | ||||
src, src_ptr_odd, 0); | src, src_ptr_odd, 0); | ||||
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<0, 0, c_dim, ow_block>(c, src, weight); | |||||
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||||
src[0] = vld1q_f32(src_ptr_odd + ow_block * simd_len); | src[0] = vld1q_f32(src_ptr_odd + ow_block * simd_len); | ||||
load_helper<4, 3 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<4, 3 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<1, 0, c_dim, ow_block>(c, src, weight); | |||||
cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||||
src_ptr += ld_src_iw; | src_ptr += ld_src_iw; | ||||
src_ptr_odd += ld_src_iw; | src_ptr_odd += ld_src_iw; | ||||
@@ -371,7 +365,7 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 7, oc_block, ow_block> { | |||||
const int ld_src_iw = iw * oc_step; | const int ld_src_iw = iw * oc_step; | ||||
constexpr int c_dim = OCHelper<oc_block>::val; | constexpr int c_dim = OCHelper<oc_block>::val; | ||||
float32x4_t c[c_dim][ow_block]; | float32x4_t c[c_dim][ow_block]; | ||||
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias); | |||||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias); | |||||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | ||||
const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; | const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; | ||||
@@ -385,33 +379,33 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 7, oc_block, ow_block> { | |||||
0); | 0); | ||||
load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, | load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, | ||||
ld_weight_oc); | ld_weight_oc); | ||||
cal_helper<0, 0, c_dim, ow_block>(c, src, weight); | |||||
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||||
src[0] = vld1q_f32(src_ptr + ow_block * simd_len); | src[0] = vld1q_f32(src_ptr + ow_block * simd_len); | ||||
load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<1, 0, c_dim, ow_block>(c, src, weight); | |||||
cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||||
src[1] = vld1q_f32(src_ptr + (ow_block + 1) * simd_len); | src[1] = vld1q_f32(src_ptr + (ow_block + 1) * simd_len); | ||||
load_helper<4, 4 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<4, 4 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<2, 0, c_dim, ow_block>(c, src, weight); | |||||
cal_helper<2, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||||
src[2] = vld1q_f32(src_ptr + (ow_block + 2) * simd_len); | src[2] = vld1q_f32(src_ptr + (ow_block + 2) * simd_len); | ||||
load_helper<4, 6 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<4, 6 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<3, 0, c_dim, ow_block>(c, src, weight); | |||||
cal_helper<3, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||||
// odd element | // odd element | ||||
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>( | load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>( | ||||
src, src_ptr_odd, 0); | src, src_ptr_odd, 0); | ||||
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<0, 0, c_dim, ow_block>(c, src, weight); | |||||
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||||
src[0] = vld1q_f32(src_ptr_odd + ow_block * simd_len); | src[0] = vld1q_f32(src_ptr_odd + ow_block * simd_len); | ||||
load_helper<4, 3 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<4, 3 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<1, 0, c_dim, ow_block>(c, src, weight); | |||||
cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||||
src[1] = vld1q_f32(src_ptr_odd + (ow_block + 1) * simd_len); | src[1] = vld1q_f32(src_ptr_odd + (ow_block + 1) * simd_len); | ||||
load_helper<4, 5 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<4, 5 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<2, 0, c_dim, ow_block>(c, src, weight); | |||||
cal_helper<2, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||||
src_ptr += ld_src_iw; | src_ptr += ld_src_iw; | ||||
src_ptr_odd += ld_src_iw; | src_ptr_odd += ld_src_iw; | ||||
@@ -39,16 +39,18 @@ namespace { | |||||
*\tparam T2 is type of src regs | *\tparam T2 is type of src regs | ||||
*\tparam T3 is type of weight regs | *\tparam T3 is type of weight regs | ||||
*/ | */ | ||||
template <int src_idx, int weight_idx, int c_dim, int stride, typename T, | |||||
typename T2, typename T3> | |||||
template <int src_idx, int weight_idx, int c_dim, int stride, int remain_w, | |||||
typename T, typename T2, typename T3> | |||||
struct ShiftCalHelper { | struct ShiftCalHelper { | ||||
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight); | static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight); | ||||
}; | }; | ||||
template <int src_idx, int weight_idx, int stride, typename T, typename T2, | |||||
typename T3> | |||||
struct ShiftCalHelper<src_idx, weight_idx, 2, stride, T, T2, T3> { | |||||
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { | |||||
template <int src_idx, int weight_idx, int c_dim, int stride, typename T, | |||||
typename T2, typename T3> | |||||
struct ShiftCalHelper<src_idx, weight_idx, c_dim, stride, 0, T, T2, T3> { | |||||
static MEGDNN_ALWAYS_INLINE void impl(T&, T2&, T3&) {} | |||||
}; | |||||
#define cb(step) \ | #define cb(step) \ | ||||
c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][weight_idx], \ | c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][weight_idx], \ | ||||
src[(step * stride + src_idx) / 4], \ | src[(step * stride + src_idx) / 4], \ | ||||
@@ -57,29 +59,47 @@ struct ShiftCalHelper<src_idx, weight_idx, 2, stride, T, T2, T3> { | |||||
src[(step * stride + src_idx) / 4], \ | src[(step * stride + src_idx) / 4], \ | ||||
(step * stride + src_idx) % 4); | (step * stride + src_idx) % 4); | ||||
UNROLL_CALL_RAW(8, cb); | |||||
#undef cb | |||||
} | |||||
}; | |||||
template <int src_idx, int weight_idx, int stride, typename T, typename T2, | |||||
typename T3> | |||||
struct ShiftCalHelper<src_idx, weight_idx, 1, stride, T, T2, T3> { | |||||
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { | |||||
#define cb(step) \ | |||||
#define cb2(step) \ | |||||
c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][weight_idx], \ | c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][weight_idx], \ | ||||
src[(step * stride + src_idx) / 4], \ | src[(step * stride + src_idx) / 4], \ | ||||
(step * stride + src_idx) % 4); | (step * stride + src_idx) % 4); | ||||
UNROLL_CALL_RAW(8, cb); | |||||
#define SHIFT_CAL_HELPER(ow_remain) \ | |||||
template <int src_idx, int weight_idx, int stride, typename T, \ | |||||
typename T2, typename T3> \ | |||||
struct ShiftCalHelper<src_idx, weight_idx, 2, stride, ow_remain, T, T2, \ | |||||
T3> { \ | |||||
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { \ | |||||
UNROLL_CALL_RAW(ow_remain, cb); \ | |||||
} \ | |||||
}; \ | |||||
template <int src_idx, int weight_idx, int stride, typename T, \ | |||||
typename T2, typename T3> \ | |||||
struct ShiftCalHelper<src_idx, weight_idx, 1, stride, ow_remain, T, T2, \ | |||||
T3> { \ | |||||
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { \ | |||||
UNROLL_CALL_RAW(ow_remain, cb2); \ | |||||
} \ | |||||
}; | |||||
SHIFT_CAL_HELPER(1) | |||||
SHIFT_CAL_HELPER(2) | |||||
SHIFT_CAL_HELPER(3) | |||||
SHIFT_CAL_HELPER(4) | |||||
SHIFT_CAL_HELPER(5) | |||||
SHIFT_CAL_HELPER(6) | |||||
SHIFT_CAL_HELPER(7) | |||||
SHIFT_CAL_HELPER(8) | |||||
#undef SHIFT_CAL_HELPER | |||||
#undef cb | #undef cb | ||||
} | |||||
}; | |||||
#undef cb2 | |||||
template <int src_idx, int weight_idx, int c_dim, int stride, typename T, | |||||
typename T2, typename T3> | |||||
template <int src_idx, int weight_idx, int c_dim, int stride, int remain_w, | |||||
typename T, typename T2, typename T3> | |||||
MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight) { | MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight) { | ||||
ShiftCalHelper<src_idx, weight_idx, c_dim, stride, T, T2, T3>::impl(c, src, | |||||
weight); | |||||
ShiftCalHelper<src_idx, weight_idx, c_dim, stride, remain_w, T, T2, | |||||
T3>::impl(c, src, weight); | |||||
}; | }; | ||||
enum CpuTag { | enum CpuTag { | ||||
DEFAULT_CPU_TAG = 0, | DEFAULT_CPU_TAG = 0, | ||||
@@ -134,7 +154,7 @@ struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 7, oc_block, stride, | |||||
const int ld_src_ic = ih * iw; | const int ld_src_ic = ih * iw; | ||||
constexpr int c_dim = OCHelper<oc_block>::val; | constexpr int c_dim = OCHelper<oc_block>::val; | ||||
float32x4_t c[c_dim][8]; | float32x4_t c[c_dim][8]; | ||||
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step); | |||||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step); | |||||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | ||||
float32x4_t src[src_reg_size]; | float32x4_t src[src_reg_size]; | ||||
@@ -145,13 +165,13 @@ struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 7, oc_block, stride, | |||||
src, src_ptr + step * iw, 0); \ | src, src_ptr + step * iw, 0); \ | ||||
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( \ | load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( \ | ||||
weight, weight_ptr + step * ld_weight_fw, ld_weight_oc); \ | weight, weight_ptr + step * ld_weight_fw, ld_weight_oc); \ | ||||
cal_helper<0, 0, c_dim, stride>(c, src, weight); \ | |||||
cal_helper<1, 1, c_dim, stride>(c, src, weight); \ | |||||
cal_helper<2, 2, c_dim, stride>(c, src, weight); \ | |||||
cal_helper<3, 3, c_dim, stride>(c, src, weight); \ | |||||
cal_helper<4, 4, c_dim, stride>(c, src, weight); \ | |||||
cal_helper<5, 5, c_dim, stride>(c, src, weight); \ | |||||
cal_helper<6, 6, c_dim, stride>(c, src, weight); | |||||
cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); \ | |||||
cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight); \ | |||||
cal_helper<2, 2, c_dim, stride, remain_w>(c, src, weight); \ | |||||
cal_helper<3, 3, c_dim, stride, remain_w>(c, src, weight); \ | |||||
cal_helper<4, 4, c_dim, stride, remain_w>(c, src, weight); \ | |||||
cal_helper<5, 5, c_dim, stride, remain_w>(c, src, weight); \ | |||||
cal_helper<6, 6, c_dim, stride, remain_w>(c, src, weight); | |||||
UNROLL_CALL_RAW(7, KERNEL_CB) | UNROLL_CALL_RAW(7, KERNEL_CB) | ||||
#undef KERNEL_CB | #undef KERNEL_CB | ||||
@@ -185,7 +205,7 @@ struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 5, oc_block, stride, | |||||
const int ld_src_ic = ih * iw; | const int ld_src_ic = ih * iw; | ||||
constexpr int c_dim = OCHelper<oc_block>::val; | constexpr int c_dim = OCHelper<oc_block>::val; | ||||
float32x4_t c[c_dim][8]; | float32x4_t c[c_dim][8]; | ||||
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step); | |||||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step); | |||||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | ||||
float32x4_t src[src_reg_size]; | float32x4_t src[src_reg_size]; | ||||
@@ -196,11 +216,11 @@ struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 5, oc_block, stride, | |||||
src, src_ptr + step * iw, 0); \ | src, src_ptr + step * iw, 0); \ | ||||
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( \ | load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( \ | ||||
weight, weight_ptr + step * ld_weight_fw, ld_weight_oc); \ | weight, weight_ptr + step * ld_weight_fw, ld_weight_oc); \ | ||||
cal_helper<0, 0, c_dim, stride>(c, src, weight); \ | |||||
cal_helper<1, 1, c_dim, stride>(c, src, weight); \ | |||||
cal_helper<2, 2, c_dim, stride>(c, src, weight); \ | |||||
cal_helper<3, 3, c_dim, stride>(c, src, weight); \ | |||||
cal_helper<4, 4, c_dim, stride>(c, src, weight); | |||||
cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); \ | |||||
cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight); \ | |||||
cal_helper<2, 2, c_dim, stride, remain_w>(c, src, weight); \ | |||||
cal_helper<3, 3, c_dim, stride, remain_w>(c, src, weight); \ | |||||
cal_helper<4, 4, c_dim, stride, remain_w>(c, src, weight); | |||||
UNROLL_CALL_RAW(5, KERNEL_CB) | UNROLL_CALL_RAW(5, KERNEL_CB) | ||||
#undef KERNEL_CB | #undef KERNEL_CB | ||||
@@ -233,7 +253,7 @@ struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 3, oc_block, stride, | |||||
const int ld_src_ic = ih * iw; | const int ld_src_ic = ih * iw; | ||||
constexpr int c_dim = OCHelper<oc_block>::val; | constexpr int c_dim = OCHelper<oc_block>::val; | ||||
float32x4_t c[c_dim][8]; | float32x4_t c[c_dim][8]; | ||||
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step); | |||||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step); | |||||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | ||||
float32x4_t src[src_reg_size]; | float32x4_t src[src_reg_size]; | ||||
@@ -243,27 +263,27 @@ struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 3, oc_block, stride, | |||||
0); | 0); | ||||
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( | load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( | ||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<0, 0, c_dim, stride>(c, src, weight); | |||||
cal_helper<1, 1, c_dim, stride>(c, src, weight); | |||||
cal_helper<2, 2, c_dim, stride>(c, src, weight); | |||||
cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); | |||||
cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight); | |||||
cal_helper<2, 2, c_dim, stride, remain_w>(c, src, weight); | |||||
// row 1 | // row 1 | ||||
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>( | load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>( | ||||
src, src_ptr + iw, 0); | src, src_ptr + iw, 0); | ||||
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( | load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( | ||||
weight, weight_ptr + 1 * ld_weight_fw, ld_weight_oc); | weight, weight_ptr + 1 * ld_weight_fw, ld_weight_oc); | ||||
cal_helper<0, 0, c_dim, stride>(c, src, weight); | |||||
cal_helper<1, 1, c_dim, stride>(c, src, weight); | |||||
cal_helper<2, 2, c_dim, stride>(c, src, weight); | |||||
cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); | |||||
cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight); | |||||
cal_helper<2, 2, c_dim, stride, remain_w>(c, src, weight); | |||||
// row 2 | // row 2 | ||||
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>( | load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>( | ||||
src, src_ptr + 2 * iw, 0); | src, src_ptr + 2 * iw, 0); | ||||
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( | load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( | ||||
weight, weight_ptr + 2 * ld_weight_fw, ld_weight_oc); | weight, weight_ptr + 2 * ld_weight_fw, ld_weight_oc); | ||||
cal_helper<0, 0, c_dim, stride>(c, src, weight); | |||||
cal_helper<1, 1, c_dim, stride>(c, src, weight); | |||||
cal_helper<2, 2, c_dim, stride>(c, src, weight); | |||||
cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); | |||||
cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight); | |||||
cal_helper<2, 2, c_dim, stride, remain_w>(c, src, weight); | |||||
src_ptr += ld_src_ic; | src_ptr += ld_src_ic; | ||||
weight_ptr += ld_weight_ic; | weight_ptr += ld_weight_ic; | ||||
@@ -634,7 +654,7 @@ struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 2, oc_block, stride, | |||||
const int ld_src_ic = ih * iw; | const int ld_src_ic = ih * iw; | ||||
constexpr int c_dim = OCHelper<oc_block>::val; | constexpr int c_dim = OCHelper<oc_block>::val; | ||||
float32x4_t c[c_dim][8]; | float32x4_t c[c_dim][8]; | ||||
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step); | |||||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step); | |||||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | ||||
float32x4_t src[src_reg_size]; | float32x4_t src[src_reg_size]; | ||||
@@ -644,16 +664,16 @@ struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 2, oc_block, stride, | |||||
0); | 0); | ||||
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( | load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( | ||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<0, 0, c_dim, stride>(c, src, weight); | |||||
cal_helper<1, 1, c_dim, stride>(c, src, weight); | |||||
cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); | |||||
cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight); | |||||
// row 1 | // row 1 | ||||
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>( | load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>( | ||||
src, src_ptr + iw, 0); | src, src_ptr + iw, 0); | ||||
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( | load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( | ||||
weight, weight_ptr + 1 * ld_weight_fw, ld_weight_oc); | weight, weight_ptr + 1 * ld_weight_fw, ld_weight_oc); | ||||
cal_helper<0, 0, c_dim, stride>(c, src, weight); | |||||
cal_helper<1, 1, c_dim, stride>(c, src, weight); | |||||
cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); | |||||
cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight); | |||||
src_ptr += ld_src_ic; | src_ptr += ld_src_ic; | ||||
weight_ptr += ld_weight_ic; | weight_ptr += ld_weight_ic; | ||||
@@ -6,14 +6,15 @@ | |||||
* | * | ||||
* Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | */ | ||||
#include <algorithm> | #include <algorithm> | ||||
#include "src/arm_common/conv_bias/fp32/do_conv_stride1.h" | #include "src/arm_common/conv_bias/fp32/do_conv_stride1.h" | ||||
#include "src/arm_common/simd_macro/neon_helper.h" | |||||
#include "src/arm_common/conv_bias/postprocess_helper.h" | #include "src/arm_common/conv_bias/postprocess_helper.h" | ||||
#include "src/arm_common/simd_macro/neon_helper.h" | |||||
#include "midout.h" | #include "midout.h" | ||||
@@ -27,10 +28,9 @@ using namespace conv_stride1; | |||||
using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam; | using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam; | ||||
using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam; | using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam; | ||||
void conv_stride1::do_conv_2x2_stride1(const float* src, const float* filter, float* dst, | |||||
size_t IH, size_t IW, size_t OH, size_t OW, | |||||
size_t IC) { | |||||
void conv_stride1::do_conv_2x2_stride1(const float* src, const float* filter, | |||||
float* dst, size_t IH, size_t IW, | |||||
size_t OH, size_t OW, size_t IC) { | |||||
const size_t tail_step = IW - OW; | const size_t tail_step = IW - OW; | ||||
//! unroll of 2 | //! unroll of 2 | ||||
size_t ic = 0; | size_t ic = 0; | ||||
@@ -143,9 +143,9 @@ void conv_stride1::do_conv_2x2_stride1(const float* src, const float* filter, fl | |||||
} | } | ||||
} | } | ||||
void conv_stride1::do_conv_3x3_stride1(const float* src, const float* filter, float* dst, | |||||
size_t IH, size_t IW, size_t OH, size_t OW, | |||||
size_t IC) { | |||||
void conv_stride1::do_conv_3x3_stride1(const float* src, const float* filter, | |||||
float* dst, size_t IH, size_t IW, | |||||
size_t OH, size_t OW, size_t IC) { | |||||
const size_t tail_step = IW - OW; | const size_t tail_step = IW - OW; | ||||
rep(ic, IC) { | rep(ic, IC) { | ||||
@@ -193,7 +193,7 @@ void conv_stride1::do_conv_3x3_stride1(const float* src, const float* filter, fl | |||||
MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r20n, 2); | MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r20n, 2); | ||||
MEGDNN_SIMD_TYPE _r30 = MEGDNN_SIMD_LOADU(r3); | MEGDNN_SIMD_TYPE _r30 = MEGDNN_SIMD_LOADU(r3); | ||||
MEGDNN_SIMD_TYPE _r30n = MEGDNN_SIMD_LOADU(r3 + 4); | |||||
MEGDNN_SIMD_TYPE _r30n = MEGDNN_SIMD_LOADU_2(r3 + 4); | |||||
MEGDNN_SIMD_TYPE _r31 = MEGDNN_SIMD_EXT(_r30, _r30n, 1); | MEGDNN_SIMD_TYPE _r31 = MEGDNN_SIMD_EXT(_r30, _r30n, 1); | ||||
MEGDNN_SIMD_TYPE _r32 = MEGDNN_SIMD_EXT(_r30, _r30n, 2); | MEGDNN_SIMD_TYPE _r32 = MEGDNN_SIMD_EXT(_r30, _r30n, 2); | ||||
@@ -290,9 +290,9 @@ void conv_stride1::do_conv_3x3_stride1(const float* src, const float* filter, fl | |||||
} | } | ||||
} | } | ||||
void conv_stride1::do_conv_5x5_stride1(const float* src, const float* filter, float* dst, | |||||
size_t IH, size_t IW, size_t OH, size_t OW, | |||||
size_t IC) { | |||||
void conv_stride1::do_conv_5x5_stride1(const float* src, const float* filter, | |||||
float* dst, size_t IH, size_t IW, | |||||
size_t OH, size_t OW, size_t IC) { | |||||
const size_t tail_step = IW - OW; | const size_t tail_step = IW - OW; | ||||
rep(ic, IC) { | rep(ic, IC) { | ||||
@@ -530,9 +530,9 @@ void conv_stride1::do_conv_5x5_stride1(const float* src, const float* filter, fl | |||||
} | } | ||||
} | } | ||||
void conv_stride1::do_conv_7x7_stride1(const float* src, const float* filter, float* dst, | |||||
size_t IH, size_t IW, size_t OH, size_t OW, | |||||
size_t IC) { | |||||
void conv_stride1::do_conv_7x7_stride1(const float* src, const float* filter, | |||||
float* dst, size_t IH, size_t IW, | |||||
size_t OH, size_t OW, size_t IC) { | |||||
const size_t tail_step = IW - OW; | const size_t tail_step = IW - OW; | ||||
rep(ic, IC) { | rep(ic, IC) { | ||||
@@ -688,7 +688,7 @@ void conv_stride1::do_conv_7x7_stride1(const float* src, const float* filter, fl | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r56, _k39404142, 2); | _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r56, _k39404142, 2); | ||||
MEGDNN_SIMD_TYPE _k42434445 = MEGDNN_SIMD_LOADU(k6); | MEGDNN_SIMD_TYPE _k42434445 = MEGDNN_SIMD_LOADU(k6); | ||||
MEGDNN_SIMD_TYPE _k46474849 = MEGDNN_SIMD_LOADU(k6 + 4); | |||||
MEGDNN_SIMD_TYPE _k46474849 = MEGDNN_SIMD_LOADU_3(k6 + 4); | |||||
MEGDNN_SIMD_TYPE _r60 = MEGDNN_SIMD_LOADU(r6); | MEGDNN_SIMD_TYPE _r60 = MEGDNN_SIMD_LOADU(r6); | ||||
MEGDNN_SIMD_TYPE _r64 = MEGDNN_SIMD_LOADU(r6 + 4); | MEGDNN_SIMD_TYPE _r64 = MEGDNN_SIMD_LOADU(r6 + 4); | ||||
@@ -126,7 +126,8 @@ static void do_conv_kern(const WorkspaceBundle& bundle, | |||||
? oh_idx * oh_block * ow * pack_c | ? oh_idx * oh_block * ow * pack_c | ||||
: oc_idx; | : oc_idx; | ||||
const float* bptr = | const float* bptr = | ||||
kern_param.bias<dt_float32>(batch_id, group_id) + bias_offset; | |||||
kern_param.bias<dt_float32>(batch_id, group_id, oc_idx, 1, pack_c) + | |||||
bias_offset; | |||||
Op op; | Op op; | ||||
conv_bias::conv_direct_fp32_nchw44<bias_mode, Op, filter, stride>( | conv_bias::conv_direct_fp32_nchw44<bias_mode, Op, filter, stride>( | ||||
@@ -69,7 +69,7 @@ struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 2, oc_block, ow_block, | |||||
constexpr int c_dim = OCHelper<oc_block>::val; | constexpr int c_dim = OCHelper<oc_block>::val; | ||||
int32x4_t c[c_dim][8]; | int32x4_t c[c_dim][8]; | ||||
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias); | |||||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias); | |||||
for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | ||||
int8x16_t src[src_reg]; | int8x16_t src[src_reg]; | ||||
int8x16_t weight[c_dim][weight_reg]; | int8x16_t weight[c_dim][weight_reg]; | ||||
@@ -117,7 +117,7 @@ struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 3, oc_block, ow_block, | |||||
constexpr int c_dim = OCHelper<oc_block>::val; | constexpr int c_dim = OCHelper<oc_block>::val; | ||||
int32x4_t c[c_dim][8]; | int32x4_t c[c_dim][8]; | ||||
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias); | |||||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias); | |||||
for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | ||||
int8x16_t src[src_reg]; | int8x16_t src[src_reg]; | ||||
int8x16_t weight[c_dim][weight_reg]; | int8x16_t weight[c_dim][weight_reg]; | ||||
@@ -171,7 +171,7 @@ struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 5, oc_block, ow_block, | |||||
constexpr int c_dim = OCHelper<oc_block>::val; | constexpr int c_dim = OCHelper<oc_block>::val; | ||||
int32x4_t c[c_dim][8]; | int32x4_t c[c_dim][8]; | ||||
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias); | |||||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias); | |||||
for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | ||||
int8x16_t src[src_reg]; | int8x16_t src[src_reg]; | ||||
@@ -220,7 +220,7 @@ struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 7, oc_block, ow_block, | |||||
constexpr int c_dim = OCHelper<oc_block>::val; | constexpr int c_dim = OCHelper<oc_block>::val; | ||||
int32x4_t c[c_dim][8]; | int32x4_t c[c_dim][8]; | ||||
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias); | |||||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias); | |||||
for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | ||||
int8x16_t src[src_reg]; | int8x16_t src[src_reg]; | ||||
@@ -80,7 +80,7 @@ struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 2, oc_block, ow_block, | |||||
constexpr int c_dim = OCHelper<oc_block>::val; | constexpr int c_dim = OCHelper<oc_block>::val; | ||||
int32x4_t c[c_dim][8]; | int32x4_t c[c_dim][8]; | ||||
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias); | |||||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias); | |||||
for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | ||||
int8x16_t src[2][src_reg]; | int8x16_t src[2][src_reg]; | ||||
int8x16_t weight[c_dim][weight_reg]; | int8x16_t weight[c_dim][weight_reg]; | ||||
@@ -131,7 +131,7 @@ struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 3, oc_block, ow_block, | |||||
constexpr int c_dim = OCHelper<oc_block>::val; | constexpr int c_dim = OCHelper<oc_block>::val; | ||||
int32x4_t c[c_dim][8]; | int32x4_t c[c_dim][8]; | ||||
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias); | |||||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias); | |||||
for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | ||||
int8x16_t src[2][src_reg]; | int8x16_t src[2][src_reg]; | ||||
int8x16_t weight[c_dim][weight_reg]; | int8x16_t weight[c_dim][weight_reg]; | ||||
@@ -189,7 +189,7 @@ struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 5, oc_block, ow_block, | |||||
constexpr int c_dim = OCHelper<oc_block>::val; | constexpr int c_dim = OCHelper<oc_block>::val; | ||||
int32x4_t c[c_dim][8]; | int32x4_t c[c_dim][8]; | ||||
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias); | |||||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias); | |||||
for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | ||||
int8x16_t src[2][src_reg]; | int8x16_t src[2][src_reg]; | ||||
@@ -244,7 +244,7 @@ struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 7, oc_block, ow_block, | |||||
constexpr int c_dim = OCHelper<oc_block>::val; | constexpr int c_dim = OCHelper<oc_block>::val; | ||||
int32x4_t c[c_dim][8]; | int32x4_t c[c_dim][8]; | ||||
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias); | |||||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias); | |||||
for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | ||||
int8x16_t src[2][src_reg]; | int8x16_t src[2][src_reg]; | ||||
@@ -45,7 +45,7 @@ static void ker_neon_dirctconv_2x2s1_oc8_ow8(const int8_t* src_ptr, | |||||
int8x16_t src[8 + 1]; | int8x16_t src[8 + 1]; | ||||
int16x8_t temp_c[4]; | int16x8_t temp_c[4]; | ||||
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step); | |||||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step); | |||||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | ||||
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { | for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { | ||||
@@ -135,7 +135,7 @@ static void ker_neon_dirctconv_2x2s1_oc4_ow8(const int8_t* src_ptr, | |||||
int8x16_t weight[1][2]; | int8x16_t weight[1][2]; | ||||
int8x16_t src[8 + 1]; | int8x16_t src[8 + 1]; | ||||
int16x8_t temp_c[2]; | int16x8_t temp_c[2]; | ||||
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step); | |||||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step); | |||||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | ||||
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { | for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { | ||||
@@ -224,7 +224,7 @@ struct KerNeonDirectStride1Int8<bias_mode, Op, remain_w, 3, c_dim, DstType> { | |||||
int8x16_t weight[3]; | int8x16_t weight[3]; | ||||
int8x16_t src[8 + 2]; | int8x16_t src[8 + 2]; | ||||
int16x8_t temp_c[2]; | int16x8_t temp_c[2]; | ||||
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step); | |||||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step); | |||||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | ||||
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { | for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { | ||||
@@ -306,7 +306,7 @@ struct KerNeonDirectStride1Int8<bias_mode, Op, remain_w, 5, c_dim, DstType> { | |||||
int8x16_t weight[5]; | int8x16_t weight[5]; | ||||
int8x16_t src[8 + 2]; | int8x16_t src[8 + 2]; | ||||
int16x8_t temp_c[2]; | int16x8_t temp_c[2]; | ||||
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step); | |||||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step); | |||||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | ||||
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { | for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { | ||||
@@ -409,7 +409,7 @@ struct KerNeonDirectStride1Int8<bias_mode, Op, remain_w, 7, c_dim, DstType> { | |||||
int8x16_t weight[7]; | int8x16_t weight[7]; | ||||
int8x16_t src[8 + 2]; | int8x16_t src[8 + 2]; | ||||
int16x8_t temp_c[2]; | int16x8_t temp_c[2]; | ||||
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step); | |||||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step); | |||||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | ||||
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { | for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { | ||||
@@ -569,7 +569,7 @@ void conv_direct_stride1_2x2_int8_nchw44(const int8_t* src, | |||||
(oh_idx * iw + ow_idx) * ic_step * pack_iw_len; | (oh_idx * iw + ow_idx) * ic_step * pack_iw_len; | ||||
const size_t dst_offset = | const size_t dst_offset = | ||||
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | ||||
ker_neon_dirctconv_2x2s1_oc8_ow8<bias_mode, Op, 0, filter_size, | |||||
ker_neon_dirctconv_2x2s1_oc8_ow8<bias_mode, Op, ow_step, filter_size, | |||||
2, DstType>( | 2, DstType>( | ||||
src + src_offset, filter + weight_offset, bias + oc_idx, | src + src_offset, filter + weight_offset, bias + oc_idx, | ||||
dst + dst_offset, ic, ih, iw, ld_oc, op); | dst + dst_offset, ic, ih, iw, ld_oc, op); | ||||
@@ -594,7 +594,7 @@ void conv_direct_stride1_2x2_int8_nchw44(const int8_t* src, | |||||
(oh_idx * iw + ow_idx) * ic_step * pack_iw_len; | (oh_idx * iw + ow_idx) * ic_step * pack_iw_len; | ||||
const size_t dst_offset = | const size_t dst_offset = | ||||
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | ||||
ker_neon_dirctconv_2x2s1_oc4_ow8<bias_mode, Op, 0, filter_size, | |||||
ker_neon_dirctconv_2x2s1_oc4_ow8<bias_mode, Op, ow_step, filter_size, | |||||
1, DstType>( | 1, DstType>( | ||||
src + src_offset, filter + weight_offset, bias + oc_idx, | src + src_offset, filter + weight_offset, bias + oc_idx, | ||||
dst + dst_offset, ic, ih, iw, ld_oc, op); | dst + dst_offset, ic, ih, iw, ld_oc, op); | ||||
@@ -54,7 +54,7 @@ static void ker_neon_dirctconv_2x2s2_oc8_ow8(const int8_t* src_ptr, | |||||
int8x16_t src[8 + 1]; | int8x16_t src[8 + 1]; | ||||
int16x8_t temp_c[4]; | int16x8_t temp_c[4]; | ||||
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step); | |||||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step); | |||||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | ||||
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { | for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { | ||||
@@ -151,7 +151,7 @@ static void ker_neon_dirctconv_2x2s2_oc4_ow8(const int8_t* src_ptr, | |||||
int8x16_t weight[2]; | int8x16_t weight[2]; | ||||
int8x16_t src[8 + 1]; | int8x16_t src[8 + 1]; | ||||
int16x8_t temp_c[2]; | int16x8_t temp_c[2]; | ||||
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step); | |||||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step); | |||||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | ||||
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { | for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { | ||||
@@ -239,7 +239,7 @@ struct KerNeonDirectStride2Int8<bias_mode, Op, remain_w, 3, c_dim, DstType> { | |||||
int8x16_t weight[3]; | int8x16_t weight[3]; | ||||
int8x16_t src[8 + 2]; | int8x16_t src[8 + 2]; | ||||
int16x8_t temp_c[4]; | int16x8_t temp_c[4]; | ||||
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step); | |||||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step); | |||||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | ||||
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { | for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { | ||||
@@ -327,7 +327,7 @@ struct KerNeonDirectStride2Int8<bias_mode, Op, remain_w, 5, c_dim, DstType> { | |||||
int8x16_t weight[5]; | int8x16_t weight[5]; | ||||
int8x16_t src[8 + 2]; | int8x16_t src[8 + 2]; | ||||
int16x8_t temp_c[4]; | int16x8_t temp_c[4]; | ||||
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step); | |||||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step); | |||||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | ||||
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { | for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { | ||||
const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + | const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + | ||||
@@ -435,7 +435,7 @@ struct KerNeonDirectStride2Int8<bias_mode, Op, remain_w, 7, c_dim, DstType> { | |||||
int8x16_t weight[7]; | int8x16_t weight[7]; | ||||
int8x16_t src[8 + 2]; | int8x16_t src[8 + 2]; | ||||
int16x8_t temp_c[4]; | int16x8_t temp_c[4]; | ||||
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step); | |||||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step); | |||||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | ||||
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { | for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { | ||||
const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + | const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + | ||||
@@ -131,7 +131,7 @@ struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 2, oc_block, 1> { | |||||
const int ld_weight_oc = oc_step * filter_height * filter_width * ic; | const int ld_weight_oc = oc_step * filter_height * filter_width * ic; | ||||
constexpr int c_dim = OCHelper<oc_block>::val; | constexpr int c_dim = OCHelper<oc_block>::val; | ||||
int32x4_t c[c_dim][8]; | int32x4_t c[c_dim][8]; | ||||
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step); | |||||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step); | |||||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | ||||
const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; | const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; | ||||
@@ -178,7 +178,7 @@ struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 3, oc_block, 1> { | |||||
const int ld_weight_oc = oc_step * filter_height * filter_width * ic; | const int ld_weight_oc = oc_step * filter_height * filter_width * ic; | ||||
constexpr int c_dim = OCHelper<oc_block>::val; | constexpr int c_dim = OCHelper<oc_block>::val; | ||||
int32x4_t c[c_dim][8]; | int32x4_t c[c_dim][8]; | ||||
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step); | |||||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step); | |||||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | ||||
const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; | const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; | ||||
@@ -232,7 +232,7 @@ struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 5, oc_block, 1> { | |||||
const int ld_weight_oc = oc_step * filter_height * filter_width * ic; | const int ld_weight_oc = oc_step * filter_height * filter_width * ic; | ||||
constexpr int c_dim = OCHelper<oc_block>::val; | constexpr int c_dim = OCHelper<oc_block>::val; | ||||
int32x4_t c[c_dim][8]; | int32x4_t c[c_dim][8]; | ||||
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step); | |||||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step); | |||||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | ||||
const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; | const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; | ||||
@@ -279,7 +279,7 @@ struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 7, oc_block, 1> { | |||||
const int ld_weight_oc = oc_step * filter_height * filter_width * ic; | const int ld_weight_oc = oc_step * filter_height * filter_width * ic; | ||||
constexpr int c_dim = OCHelper<oc_block>::val; | constexpr int c_dim = OCHelper<oc_block>::val; | ||||
int32x4_t c[c_dim][8]; | int32x4_t c[c_dim][8]; | ||||
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step); | |||||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step); | |||||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | ||||
const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; | const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; | ||||
@@ -643,128 +643,95 @@ __ai int32x4_t neon_vld1q(const int* ptr) { | |||||
__ai int16x8_t neon_vld1q(const int16_t* ptr) { | __ai int16x8_t neon_vld1q(const int16_t* ptr) { | ||||
return vld1q_s16(ptr); | return vld1q_s16(ptr); | ||||
} | } | ||||
template <int c_dim, BiasMode bias_mode, int ow_block, typename T, typename T2> | |||||
template <typename T> | |||||
struct NeonLdqSimd; | |||||
template <> | |||||
struct NeonLdqSimd<float> { | |||||
static constexpr int simd_len = 4; | |||||
}; | |||||
template <> | |||||
struct NeonLdqSimd<int> { | |||||
static constexpr int simd_len = 4; | |||||
}; | |||||
template <> | |||||
struct NeonLdqSimd<int16_t> { | |||||
static constexpr int simd_len = 8; | |||||
}; | |||||
template <int c_dim, BiasMode bias_mode, int ow_remain, typename T, typename T2> | |||||
struct InitOcxOw8 { | struct InitOcxOw8 { | ||||
static __ai void impl(T& c, const T2* bias_ptr, int oc_step); | static __ai void impl(T& c, const T2* bias_ptr, int oc_step); | ||||
}; | }; | ||||
template <typename T, typename T2> | |||||
struct InitOcxOw8<2, BiasMode::NO_BIAS, 8, T, T2> { | |||||
static __ai void impl(T& c, const T2*, int) { | |||||
#define BAIS_INIT(step) \ | |||||
c[0][step] = neon_vdupq_n(static_cast<T2>(0)); \ | |||||
c[1][step] = neon_vdupq_n(static_cast<T2>(0)); | |||||
UNROLL_CALL_RAW(8, BAIS_INIT); | |||||
#undef BAIS_INIT | |||||
} | |||||
template <int c_dim, BiasMode bias_mode, typename T, typename T2> | |||||
struct InitOcxOw8<c_dim, bias_mode, 0, T, T2> { | |||||
static __ai void impl(T&, const T2*, int) {} | |||||
}; | }; | ||||
template <typename T, typename T2> | |||||
struct InitOcxOw8<2, BiasMode::NO_BIAS, 4, T, T2> { | |||||
static __ai void impl(T& c, const T2*, int) { | |||||
#define BAIS_INIT(step) \ | |||||
#define BAIS_INIT_NO_BIAS_C2(step) \ | |||||
c[0][step] = neon_vdupq_n(static_cast<T2>(0)); \ | c[0][step] = neon_vdupq_n(static_cast<T2>(0)); \ | ||||
c[1][step] = neon_vdupq_n(static_cast<T2>(0)); | c[1][step] = neon_vdupq_n(static_cast<T2>(0)); | ||||
UNROLL_CALL_RAW(4, BAIS_INIT); | |||||
#undef BAIS_INIT | |||||
} | |||||
}; | |||||
template <typename T, typename T2> | |||||
struct InitOcxOw8<2, BiasMode::BROADCAST_CHANNEL_BIAS, 8, T, T2> { | |||||
static __ai void impl(T& c, const T2* bias_ptr, int oc_step) { | |||||
#define BAIS_INIT(step) \ | |||||
c[0][step] = neon_vld1q(bias_ptr); \ | |||||
c[1][step] = neon_vld1q(bias_ptr + oc_step); | |||||
UNROLL_CALL_RAW(8, BAIS_INIT); | |||||
#undef BAIS_INIT | |||||
} | |||||
}; | |||||
template <typename T, typename T2> | |||||
struct InitOcxOw8<2, BiasMode::BROADCAST_CHANNEL_BIAS, 4, T, T2> { | |||||
static __ai void impl(T& c, const T2* bias_ptr, int oc_step) { | |||||
#define BAIS_INIT(step) \ | |||||
#define BAIS_INIT_NO_BIAS_C1(step) \ | |||||
c[0][step] = neon_vdupq_n(static_cast<T2>(0)); | |||||
#define BAIS_INIT_BROADCAST_C2(step) \ | |||||
c[0][step] = neon_vld1q(bias_ptr); \ | c[0][step] = neon_vld1q(bias_ptr); \ | ||||
c[1][step] = neon_vld1q(bias_ptr + oc_step); | c[1][step] = neon_vld1q(bias_ptr + oc_step); | ||||
UNROLL_CALL_RAW(4, BAIS_INIT); | |||||
#undef BAIS_INIT | |||||
} | |||||
}; | |||||
template <typename T, typename T2> | |||||
struct InitOcxOw8<2, BiasMode::BIAS, 8, T, T2> { | |||||
static __ai void impl(T& c, const T2* bias_ptr, int oc_step) { | |||||
constexpr int simd_len = 4; | |||||
#define BAIS_INIT(step) \ | |||||
c[0][step] = neon_vld1q(bias_ptr + step * simd_len); \ | |||||
c[1][step] = neon_vld1q(bias_ptr + oc_step + step * simd_len); | |||||
UNROLL_CALL_RAW(8, BAIS_INIT); | |||||
#undef BAIS_INIT | |||||
} | |||||
}; | |||||
template <typename T, typename T2> | |||||
struct InitOcxOw8<2, BiasMode::BIAS, 4, T, T2> { | |||||
static __ai void impl(T& c, const T2* bias_ptr, int oc_step) { | |||||
constexpr int simd_len = 4; | |||||
#define BAIS_INIT(step) \ | |||||
#define BAIS_INIT_BROADCAST_C1(step) c[0][step] = neon_vld1q(bias_ptr); | |||||
#define BAIS_INIT_BIAS_C2(step) \ | |||||
c[0][step] = neon_vld1q(bias_ptr + step * simd_len); \ | c[0][step] = neon_vld1q(bias_ptr + step * simd_len); \ | ||||
c[1][step] = neon_vld1q(bias_ptr + oc_step + step * simd_len); | c[1][step] = neon_vld1q(bias_ptr + oc_step + step * simd_len); | ||||
UNROLL_CALL_RAW(4, BAIS_INIT); | |||||
#undef BAIS_INIT | |||||
} | |||||
}; | |||||
template <typename T, typename T2> | |||||
struct InitOcxOw8<1, BiasMode::NO_BIAS, 8, T, T2> { | |||||
static __ai void impl(T& c, const T2*, int) { | |||||
#define BAIS_INIT(step) c[0][step] = neon_vdupq_n(static_cast<T2>(0)); | |||||
UNROLL_CALL_RAW(8, BAIS_INIT); | |||||
#undef BAIS_INIT | |||||
} | |||||
}; | |||||
template <typename T, typename T2> | |||||
struct InitOcxOw8<1, BiasMode::NO_BIAS, 4, T, T2> { | |||||
static __ai void impl(T& c, const T2*, int) { | |||||
#define BAIS_INIT(step) c[0][step] = neon_vdupq_n(static_cast<T2>(0)); | |||||
UNROLL_CALL_RAW(4, BAIS_INIT); | |||||
#undef BAIS_INIT | |||||
} | |||||
}; | |||||
template <typename T, typename T2> | |||||
struct InitOcxOw8<1, BiasMode::BROADCAST_CHANNEL_BIAS, 8, T, T2> { | |||||
static __ai void impl(T& c, const T2* bias_ptr, int) { | |||||
#define BAIS_INIT(step) c[0][step] = neon_vld1q(bias_ptr); | |||||
UNROLL_CALL_RAW(8, BAIS_INIT); | |||||
#undef BAIS_INIT | |||||
} | |||||
}; | |||||
template <typename T, typename T2> | |||||
struct InitOcxOw8<1, BiasMode::BROADCAST_CHANNEL_BIAS, 4, T, T2> { | |||||
static __ai void impl(T& c, const T2* bias_ptr, int) { | |||||
#define BAIS_INIT(step) c[0][step] = neon_vld1q(bias_ptr); | |||||
UNROLL_CALL_RAW(4, BAIS_INIT); | |||||
#undef BAIS_INIT | |||||
} | |||||
}; | |||||
template <typename T, typename T2> | |||||
struct InitOcxOw8<1, BiasMode::BIAS, 8, T, T2> { | |||||
static __ai void impl(T& c, const T2* bias_ptr, int) { | |||||
constexpr int simd_len = 4; | |||||
#define BAIS_INIT(step) c[0][step] = neon_vld1q(bias_ptr + step * simd_len); | |||||
UNROLL_CALL_RAW(8, BAIS_INIT); | |||||
#undef BAIS_INIT | |||||
} | |||||
}; | |||||
template <typename T, typename T2> | |||||
struct InitOcxOw8<1, BiasMode::BIAS, 4, T, T2> { | |||||
static __ai void impl(T& c, const T2* bias_ptr, int) { | |||||
constexpr int simd_len = 4; | |||||
#define BAIS_INIT(step) c[0][step] = neon_vld1q(bias_ptr + step * simd_len); | |||||
UNROLL_CALL_RAW(4, BAIS_INIT); | |||||
#undef BAIS_INIT | |||||
} | |||||
}; | |||||
template <int c_dim, BiasMode bias_mode, int ow_block, typename T, typename T2> | |||||
#define BAIS_INIT_BIAS_C1(step) \ | |||||
c[0][step] = neon_vld1q(bias_ptr + step * simd_len); | |||||
#define INSTANCE_InitOcxOw8(ow_remain, cdim) \ | |||||
template <typename T, typename T2> \ | |||||
struct InitOcxOw8<cdim, BiasMode::NO_BIAS, ow_remain, T, T2> { \ | |||||
static __ai void impl(T& c, const T2*, int) { \ | |||||
UNROLL_CALL_RAW(ow_remain, BAIS_INIT_NO_BIAS_C##cdim); \ | |||||
} \ | |||||
}; \ | |||||
template <typename T, typename T2> \ | |||||
struct InitOcxOw8<cdim, BiasMode::BROADCAST_CHANNEL_BIAS, ow_remain, T, \ | |||||
T2> { \ | |||||
static __ai void impl(T& c, const T2* bias_ptr, int oc_step) { \ | |||||
(void)oc_step; \ | |||||
UNROLL_CALL_RAW(ow_remain, BAIS_INIT_BROADCAST_C##cdim); \ | |||||
} \ | |||||
}; \ | |||||
template <typename T, typename T2> \ | |||||
struct InitOcxOw8<cdim, BiasMode::BIAS, ow_remain, T, T2> { \ | |||||
static __ai void impl(T& c, const T2* bias_ptr, int oc_step) { \ | |||||
constexpr int simd_len = NeonLdqSimd<T2>::simd_len; \ | |||||
(void)oc_step; \ | |||||
UNROLL_CALL_RAW(ow_remain, BAIS_INIT_BIAS_C##cdim); \ | |||||
} \ | |||||
}; | |||||
#define INSTANCE_InitOcxOw8_C(ow_remain) \ | |||||
INSTANCE_InitOcxOw8(ow_remain, 2); \ | |||||
INSTANCE_InitOcxOw8(ow_remain, 1); | |||||
INSTANCE_InitOcxOw8_C(1); | |||||
INSTANCE_InitOcxOw8_C(2); | |||||
INSTANCE_InitOcxOw8_C(3); | |||||
INSTANCE_InitOcxOw8_C(4); | |||||
INSTANCE_InitOcxOw8_C(5); | |||||
INSTANCE_InitOcxOw8_C(6); | |||||
INSTANCE_InitOcxOw8_C(7); | |||||
INSTANCE_InitOcxOw8_C(8); | |||||
#undef INSTANCE_InitOcxOw8 | |||||
#undef INSTANCE_InitOcxOw8_C | |||||
#undef BAIS_INIT_BIAS_C1 | |||||
#undef BAIS_INIT_BIAS_C2 | |||||
#undef BAIS_INIT_BROADCAST_C1 | |||||
#undef BAIS_INIT_BROADCAST_C2 | |||||
#undef BAIS_INIT_NO_BIAS_C1 | |||||
#undef BAIS_INIT_NO_BIAS_C2 | |||||
template <int c_dim, BiasMode bias_mode, int ow_remain, typename T, typename T2> | |||||
__ai void init_ocx_ow8(T& c, const T2* bias_ptr, int oc_step) { | __ai void init_ocx_ow8(T& c, const T2* bias_ptr, int oc_step) { | ||||
InitOcxOw8<c_dim, bias_mode, ow_block, T, T2>::impl(c, bias_ptr, oc_step); | |||||
InitOcxOw8<c_dim, bias_mode, ow_remain, T, T2>::impl(c, bias_ptr, oc_step); | |||||
} | } | ||||
/////////////////////init_ocx_ow4///////////////////// | /////////////////////init_ocx_ow4///////////////////// | ||||
template <int c_dim, BiasMode bias_mode, typename T> | template <int c_dim, BiasMode bias_mode, typename T> | ||||
@@ -18,6 +18,8 @@ | |||||
#define MEGDNN_SIMD_TYPE float32x4_t | #define MEGDNN_SIMD_TYPE float32x4_t | ||||
#define MEGDNN_SIMD_TYPE2 float32x4x2_t | #define MEGDNN_SIMD_TYPE2 float32x4x2_t | ||||
#define MEGDNN_SIMD_LOADU(addr) vld1q_f32(addr) | #define MEGDNN_SIMD_LOADU(addr) vld1q_f32(addr) | ||||
#define MEGDNN_SIMD_LOADU_2(addr) vcombine_f32(vld1_f32(addr), vdup_n_f32(0.f)) | |||||
#define MEGDNN_SIMD_LOADU_3(addr) vld1q_lane_f32(addr + 2, vcombine_f32(vld1_f32(addr), vdup_n_f32(0.f)), 2) | |||||
#define MEGDNN_SIMD_STOREU(addr, reg) vst1q_f32(addr, reg) | #define MEGDNN_SIMD_STOREU(addr, reg) vst1q_f32(addr, reg) | ||||
#define MEGDNN_SIMD_SETZERO() vdupq_n_f32(0.0f) | #define MEGDNN_SIMD_SETZERO() vdupq_n_f32(0.0f) | ||||
#define MEGDNN_SIMD_SET1(num) vdupq_n_f32(num) | #define MEGDNN_SIMD_SET1(num) vdupq_n_f32(num) | ||||
@@ -23,6 +23,20 @@ | |||||
#include <regex> | #include <regex> | ||||
#include <unordered_map> | #include <unordered_map> | ||||
// clang-format off | |||||
#if defined(__has_feature) | |||||
#if __has_feature(address_sanitizer) | |||||
#define MEGDNN_TEST_ASAN 1 | |||||
#else | |||||
#define MEGDNN_TEST_ASAN 0 | |||||
#endif | |||||
#elif defined(__SANITIZE_ADDRESS__) | |||||
#define MEGDNN_TEST_ASAN 1 | |||||
#else | |||||
#define MEGDNN_TEST_ASAN 0 | |||||
#endif | |||||
// clang-format on | |||||
namespace megdnn { | namespace megdnn { | ||||
namespace test { | namespace test { | ||||
@@ -76,6 +76,9 @@ TEST_F(FALLBACK, WARP_PERSPECTIVE) { | |||||
checker.set_param(param); | checker.set_param(param); | ||||
checker.exec({{1000, 2, 10, 11}, {1000, 3, 3}, {1000, 2, 12, 13}}); | checker.exec({{1000, 2, 10, 11}, {1000, 3, 3}, {1000, 2, 12, 13}}); | ||||
} | } | ||||
#if MEGDNN_TEST_ASAN | |||||
//! asan detect nan will make test failed | |||||
#else | |||||
// resize nan case | // resize nan case | ||||
UniformFloatRNG rng_zero(0, 0); | UniformFloatRNG rng_zero(0, 0); | ||||
checker.set_rng(1, &rng_zero); | checker.set_rng(1, &rng_zero); | ||||
@@ -85,6 +88,7 @@ TEST_F(FALLBACK, WARP_PERSPECTIVE) { | |||||
checker.set_param(param); | checker.set_param(param); | ||||
checker.exec({{1000, 2, 10, 11}, {1000, 3, 3}, {1000, 2, 12, 13}}); | checker.exec({{1000, 2, 10, 11}, {1000, 3, 3}, {1000, 2, 12, 13}}); | ||||
} | } | ||||
#endif | |||||
} | } | ||||
TEST_F(FALLBACK, WARP_PERSPECTIVE_MAT_IDX) { | TEST_F(FALLBACK, WARP_PERSPECTIVE_MAT_IDX) { | ||||
@@ -352,6 +352,9 @@ TEST_F(NAIVE_MULTI_THREADS, WARP_PERSPECTIVE_FORWARD_HWCD4) { | |||||
checker.execs({{22, 10, 1, 11, 4}, {22, 3, 3}, {22, 11, 1, 12, 4}}); | checker.execs({{22, 10, 1, 11, 4}, {22, 3, 3}, {22, 11, 1, 12, 4}}); | ||||
} | } | ||||
} | } | ||||
#if MEGDNN_TEST_ASAN | |||||
//! asan detect nan will make test failed | |||||
#else | |||||
// nan case | // nan case | ||||
NanMatRNG rng_nan; | NanMatRNG rng_nan; | ||||
UniformFloatRNG rng_zero(0, 0); | UniformFloatRNG rng_zero(0, 0); | ||||
@@ -369,6 +372,7 @@ TEST_F(NAIVE_MULTI_THREADS, WARP_PERSPECTIVE_FORWARD_HWCD4) { | |||||
checker.set_param(param); | checker.set_param(param); | ||||
checker.exec({{10, 10, 1, 11, 4}, {10, 3, 3}, {10, 12, 1, 13, 4}}); | checker.exec({{10, 10, 1, 11, 4}, {10, 3, 3}, {10, 12, 1, 13, 4}}); | ||||
} | } | ||||
#endif | |||||
} | } | ||||
#if MEGDNN_WITH_BENCHMARK | #if MEGDNN_WITH_BENCHMARK | ||||