|
|
@@ -37,6 +37,26 @@ struct ShiftCalHelper<src_idx, weight_idx, c_dim, stride, 0, T, T2, T3> { |
|
|
|
static MEGDNN_ALWAYS_INLINE void impl(T&, T2&, T3&) {} |
|
|
|
}; |
|
|
|
|
|
|
|
#if defined(GI_TARGET_X86) || defined(GI_RVV_INTRINSICS) |
|
|
|
//! x86 and rvv GiSimdFmaLane API is slowly, as an alternate, use |
|
|
|
//! GiMultiplyAddScalarFloat32 |
|
|
|
#define MLA GiMultiplyAddScalarFloat32 |
|
|
|
#define cb(step) \ |
|
|
|
c[0][step] = GiFloat32Type2FixLenType(MLA( \ |
|
|
|
GiFixLenType2GiFloat32Type(c[0][step]), \ |
|
|
|
GiFixLenType2GiFloat32Type(weight[0][weight_idx]), \ |
|
|
|
*(src[(step * stride + src_idx) / 4] + (step * stride + src_idx) % 4))); \ |
|
|
|
c[1][step] = GiFloat32Type2FixLenType(MLA( \ |
|
|
|
GiFixLenType2GiFloat32Type(c[1][step]), \ |
|
|
|
GiFixLenType2GiFloat32Type(weight[1][weight_idx]), \ |
|
|
|
*(src[(step * stride + src_idx) / 4] + (step * stride + src_idx) % 4))); |
|
|
|
|
|
|
|
#define cb2(step) \ |
|
|
|
c[0][step] = GiFloat32Type2FixLenType(MLA( \ |
|
|
|
GiFixLenType2GiFloat32Type(c[0][step]), \ |
|
|
|
GiFixLenType2GiFloat32Type(weight[0][weight_idx]), \ |
|
|
|
*(src[(step * stride + src_idx) / 4] + (step * stride + src_idx) % 4))); |
|
|
|
#else |
|
|
|
#define cb(step) \ |
|
|
|
c[0][step] = GiFloat32Type2FixLenType(GiSimdFmaLane( \ |
|
|
|
GiFixLenType2GiFloat32Type(c[0][step]), \ |
|
|
@@ -55,6 +75,8 @@ struct ShiftCalHelper<src_idx, weight_idx, c_dim, stride, 0, T, T2, T3> { |
|
|
|
GiFixLenType2GiFloat32Type(weight[0][weight_idx]), \ |
|
|
|
GiFixLenType2GiFloat32Type(src[(step * stride + src_idx) / 4]), \ |
|
|
|
(step * stride + src_idx) % 4)); |
|
|
|
#undef MLA |
|
|
|
#endif |
|
|
|
|
|
|
|
#define SHIFT_CAL_HELPER(ow_remain) \ |
|
|
|
template < \ |
|
|
@@ -151,23 +173,38 @@ struct KerGiXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 7, oc_block, stride, ow_ |
|
|
|
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step); |
|
|
|
|
|
|
|
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { |
|
|
|
//! x86 and rvv GiSimdFmaLane API is slowly, as an alternate, use |
|
|
|
//! GiMultiplyAddScalarFloat32 |
|
|
|
#if defined(GI_TARGET_X86) || defined(GI_RVV_INTRINSICS) |
|
|
|
const float* src[src_reg_size]; |
|
|
|
#else |
|
|
|
GI_FLOAT32_FIXLEN_t src[src_reg_size]; |
|
|
|
#endif |
|
|
|
GI_FLOAT32_FIXLEN_t weight[c_dim][filter_size]; |
|
|
|
|
|
|
|
#define KERNEL_CB(step) \ |
|
|
|
load_helper<src_reg_size, 0, simd_len, 0, Vld1qF32S>(src, src_ptr + step * iw, 0); \ |
|
|
|
load_helper<filter_size, 0, oc_step, c_dim, Vld1qF32S>( \ |
|
|
|
weight, weight_ptr + step * ld_weight_fw, ld_weight_oc); \ |
|
|
|
cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); \ |
|
|
|
cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight); \ |
|
|
|
cal_helper<2, 2, c_dim, stride, remain_w>(c, src, weight); \ |
|
|
|
cal_helper<3, 3, c_dim, stride, remain_w>(c, src, weight); \ |
|
|
|
cal_helper<4, 4, c_dim, stride, remain_w>(c, src, weight); \ |
|
|
|
cal_helper<5, 5, c_dim, stride, remain_w>(c, src, weight); \ |
|
|
|
#if defined(GI_TARGET_X86) || defined(GI_RVV_INTRINSICS) |
|
|
|
#define SRC_LOAD(step) \ |
|
|
|
load_ptr_helper<src_reg_size, 0, simd_len, 0>(src, src_ptr + step * iw, 0) |
|
|
|
#else |
|
|
|
#define SRC_LOAD(step) \ |
|
|
|
load_helper<src_reg_size, 0, simd_len, 0, Vld1qF32S>(src, src_ptr + step * iw, 0) |
|
|
|
#endif |
|
|
|
|
|
|
|
#define KERNEL_CB(step) \ |
|
|
|
SRC_LOAD(step); \ |
|
|
|
load_helper<filter_size, 0, oc_step, c_dim, Vld1qF32S>( \ |
|
|
|
weight, weight_ptr + step * ld_weight_fw, ld_weight_oc); \ |
|
|
|
cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); \ |
|
|
|
cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight); \ |
|
|
|
cal_helper<2, 2, c_dim, stride, remain_w>(c, src, weight); \ |
|
|
|
cal_helper<3, 3, c_dim, stride, remain_w>(c, src, weight); \ |
|
|
|
cal_helper<4, 4, c_dim, stride, remain_w>(c, src, weight); \ |
|
|
|
cal_helper<5, 5, c_dim, stride, remain_w>(c, src, weight); \ |
|
|
|
cal_helper<6, 6, c_dim, stride, remain_w>(c, src, weight); |
|
|
|
|
|
|
|
UNROLL_CALL_RAW(7, KERNEL_CB) |
|
|
|
#undef KERNEL_CB |
|
|
|
#undef SRC_LOAD |
|
|
|
|
|
|
|
src_ptr += ld_src_ic; |
|
|
|
weight_ptr += ld_weight_ic; |
|
|
@@ -200,20 +237,33 @@ struct KerGiXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 5, oc_block, stride, ow_ |
|
|
|
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step); |
|
|
|
|
|
|
|
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { |
|
|
|
#if defined(GI_TARGET_X86) || defined(GI_RVV_INTRINSICS) |
|
|
|
const float* src[src_reg_size]; |
|
|
|
#else |
|
|
|
GI_FLOAT32_FIXLEN_t src[src_reg_size]; |
|
|
|
#endif |
|
|
|
GI_FLOAT32_FIXLEN_t weight[c_dim][filter_size]; |
|
|
|
|
|
|
|
#define KERNEL_CB(step) \ |
|
|
|
load_helper<src_reg_size, 0, simd_len, 0, Vld1qF32S>(src, src_ptr + step * iw, 0); \ |
|
|
|
load_helper<filter_size, 0, oc_step, c_dim, Vld1qF32S>( \ |
|
|
|
weight, weight_ptr + step * ld_weight_fw, ld_weight_oc); \ |
|
|
|
cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); \ |
|
|
|
cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight); \ |
|
|
|
cal_helper<2, 2, c_dim, stride, remain_w>(c, src, weight); \ |
|
|
|
cal_helper<3, 3, c_dim, stride, remain_w>(c, src, weight); \ |
|
|
|
#if defined(GI_TARGET_X86) || defined(GI_RVV_INTRINSICS) |
|
|
|
#define SRC_LOAD(step) \ |
|
|
|
load_ptr_helper<src_reg_size, 0, simd_len, 0>(src, src_ptr + step * iw, 0); |
|
|
|
#else |
|
|
|
#define SRC_LOAD(step) \ |
|
|
|
load_helper<src_reg_size, 0, simd_len, 0, Vld1qF32S>(src, src_ptr + step * iw, 0); |
|
|
|
#endif |
|
|
|
|
|
|
|
#define KERNEL_CB(step) \ |
|
|
|
SRC_LOAD(step); \ |
|
|
|
load_helper<filter_size, 0, oc_step, c_dim, Vld1qF32S>( \ |
|
|
|
weight, weight_ptr + step * ld_weight_fw, ld_weight_oc); \ |
|
|
|
cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); \ |
|
|
|
cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight); \ |
|
|
|
cal_helper<2, 2, c_dim, stride, remain_w>(c, src, weight); \ |
|
|
|
cal_helper<3, 3, c_dim, stride, remain_w>(c, src, weight); \ |
|
|
|
cal_helper<4, 4, c_dim, stride, remain_w>(c, src, weight); |
|
|
|
UNROLL_CALL_RAW(5, KERNEL_CB) |
|
|
|
#undef KERNEL_CB |
|
|
|
#undef SRC_LOAD |
|
|
|
|
|
|
|
src_ptr += ld_src_ic; |
|
|
|
weight_ptr += ld_weight_ic; |
|
|
@@ -246,10 +296,18 @@ struct KerGiXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 3, oc_block, stride, ow_ |
|
|
|
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step); |
|
|
|
|
|
|
|
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { |
|
|
|
#if defined(GI_TARGET_X86) || defined(GI_RVV_INTRINSICS) |
|
|
|
const float* src[src_reg_size]; |
|
|
|
#else |
|
|
|
GI_FLOAT32_FIXLEN_t src[src_reg_size]; |
|
|
|
#endif |
|
|
|
GI_FLOAT32_FIXLEN_t weight[c_dim][filter_size]; |
|
|
|
// row 0 |
|
|
|
#if defined(GI_TARGET_X86) || defined(GI_RVV_INTRINSICS) |
|
|
|
load_ptr_helper<src_reg_size, 0, simd_len, 0>(src, src_ptr, 0); |
|
|
|
#else |
|
|
|
load_helper<src_reg_size, 0, simd_len, 0, Vld1qF32S>(src, src_ptr, 0); |
|
|
|
#endif |
|
|
|
load_helper<filter_size, 0, oc_step, c_dim, Vld1qF32S>( |
|
|
|
weight, weight_ptr, ld_weight_oc); |
|
|
|
cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); |
|
|
@@ -257,7 +315,11 @@ struct KerGiXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 3, oc_block, stride, ow_ |
|
|
|
cal_helper<2, 2, c_dim, stride, remain_w>(c, src, weight); |
|
|
|
|
|
|
|
// row 1 |
|
|
|
#if defined(GI_TARGET_X86) || defined(GI_RVV_INTRINSICS) |
|
|
|
load_ptr_helper<src_reg_size, 0, simd_len, 0>(src, src_ptr + iw, 0); |
|
|
|
#else |
|
|
|
load_helper<src_reg_size, 0, simd_len, 0, Vld1qF32S>(src, src_ptr + iw, 0); |
|
|
|
#endif |
|
|
|
load_helper<filter_size, 0, oc_step, c_dim, Vld1qF32S>( |
|
|
|
weight, weight_ptr + 1 * ld_weight_fw, ld_weight_oc); |
|
|
|
cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); |
|
|
@@ -265,8 +327,12 @@ struct KerGiXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 3, oc_block, stride, ow_ |
|
|
|
cal_helper<2, 2, c_dim, stride, remain_w>(c, src, weight); |
|
|
|
|
|
|
|
// row 2 |
|
|
|
#if defined(GI_TARGET_X86) || defined(GI_RVV_INTRINSICS) |
|
|
|
load_ptr_helper<src_reg_size, 0, simd_len, 0>(src, src_ptr + 2 * iw, 0); |
|
|
|
#else |
|
|
|
load_helper<src_reg_size, 0, simd_len, 0, Vld1qF32S>( |
|
|
|
src, src_ptr + 2 * iw, 0); |
|
|
|
#endif |
|
|
|
load_helper<filter_size, 0, oc_step, c_dim, Vld1qF32S>( |
|
|
|
weight, weight_ptr + 2 * ld_weight_fw, ld_weight_oc); |
|
|
|
cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); |
|
|
@@ -637,17 +703,29 @@ struct KerGiXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 2, oc_block, stride, ow_ |
|
|
|
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step); |
|
|
|
|
|
|
|
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { |
|
|
|
#if defined(GI_TARGET_X86) || defined(GI_RVV_INTRINSICS) |
|
|
|
const float* src[src_reg_size]; |
|
|
|
#else |
|
|
|
GI_FLOAT32_FIXLEN_t src[src_reg_size]; |
|
|
|
#endif |
|
|
|
GI_FLOAT32_FIXLEN_t weight[c_dim][filter_size]; |
|
|
|
// row 0 |
|
|
|
#if defined(GI_TARGET_X86) || defined(GI_RVV_INTRINSICS) |
|
|
|
load_ptr_helper<src_reg_size, 0, simd_len, 0>(src, src_ptr, 0); |
|
|
|
#else |
|
|
|
load_helper<src_reg_size, 0, simd_len, 0, Vld1qF32S>(src, src_ptr, 0); |
|
|
|
#endif |
|
|
|
load_helper<filter_size, 0, oc_step, c_dim, Vld1qF32S>( |
|
|
|
weight, weight_ptr, ld_weight_oc); |
|
|
|
cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); |
|
|
|
cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight); |
|
|
|
|
|
|
|
// row 1 |
|
|
|
#if defined(GI_TARGET_X86) || defined(GI_RVV_INTRINSICS) |
|
|
|
load_ptr_helper<src_reg_size, 0, simd_len, 0>(src, src_ptr + iw, 0); |
|
|
|
#else |
|
|
|
load_helper<src_reg_size, 0, simd_len, 0, Vld1qF32S>(src, src_ptr + iw, 0); |
|
|
|
#endif |
|
|
|
load_helper<filter_size, 0, oc_step, c_dim, Vld1qF32S>( |
|
|
|
weight, weight_ptr + 1 * ld_weight_fw, ld_weight_oc); |
|
|
|
cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); |
|
|
@@ -670,7 +748,7 @@ struct ConvDirectFp32NchwNchw44 { |
|
|
|
constexpr int fh = filter_size; |
|
|
|
constexpr int fw = filter_size; |
|
|
|
constexpr int ic_step = 1; |
|
|
|
#if MEGDNN_ARMV7 |
|
|
|
#if defined(GI_TARGET_X86) || defined(GI_RVV_INTRINSICS) || defined(MEGDNN_ARMV7) |
|
|
|
constexpr int big_oc_step = 4; |
|
|
|
#else |
|
|
|
constexpr int big_oc_step = 8; |
|
|
|