|
|
@@ -141,6 +141,7 @@ template <typename Op, typename T> |
|
|
|
struct StoreOcxOw4Remain<1, 0, Op, T> { |
|
|
|
// Hands all four vectors of accumulator row 0 to `op`, two at a time:
//   {c[0][0], c[0][1]} is written at dst_ptr,
//   {c[0][2], c[0][3]} is written at dst_ptr + 8.
// Both destinations are passed to `op` reinterpreted as dt_qint8*.
// NOTE(review): presumably the remain value 0 selects the "full width of 4"
// path - confirm against the dispatching caller.
static void impl(int32x4_t c[2][4], const Op& op, int8_t* dst_ptr,
                 int ld_dst_oc) {
    // Only row 0 is touched, so the stride between rows is never needed.
    MEGDNN_MARK_USED_VAR(ld_dst_oc);
    int32x4_t* const row = c[0];
    int8_t* const upper_half = dst_ptr + 8;
    op({{row[0], row[1]}}, reinterpret_cast<dt_qint8*>(dst_ptr));
    op({{row[2], row[3]}}, reinterpret_cast<dt_qint8*>(upper_half));
}
|
|
@@ -149,6 +150,7 @@ struct StoreOcxOw4Remain<1, 0, Op, T> { |
|
|
|
template <typename Op, typename T> |
|
|
|
struct StoreOcxOw4Remain<1, 3, Op, T> { |
|
|
|
// Hands the first three vectors of accumulator row 0 to `op`:
//   the pair {c[0][0], c[0][1]} is written at dst_ptr,
//   the single vector c[0][2] is written at dst_ptr + 8.
// Both destinations are passed to `op` reinterpreted as dt_qint8*.
static void impl(T& c, const Op& op, int8_t* dst_ptr, int ld_dst_oc) {
    // Only row 0 is touched, so the stride between rows is never needed.
    MEGDNN_MARK_USED_VAR(ld_dst_oc);
    auto&& row = c[0];
    int8_t* const tail = dst_ptr + 8;
    op({{row[0], row[1]}}, reinterpret_cast<dt_qint8*>(dst_ptr));
    op(row[2], reinterpret_cast<dt_qint8*>(tail));
}
|
|
@@ -156,12 +158,14 @@ struct StoreOcxOw4Remain<1, 3, Op, T> { |
|
|
|
template <typename Op, typename T> |
|
|
|
struct StoreOcxOw4Remain<1, 2, Op, T> { |
|
|
|
static void impl(T& c, const Op& op, int8_t* dst_ptr, int ld_dst_oc) { |
|
|
|
MEGDNN_MARK_USED_VAR(ld_dst_oc); |
|
|
|
op({{c[0][0], c[0][1]}}, reinterpret_cast<dt_qint8*>(dst_ptr)); |
|
|
|
} |
|
|
|
}; |
|
|
|
template <typename Op, typename T> |
|
|
|
struct StoreOcxOw4Remain<1, 1, Op, T> { |
|
|
|
static void impl(T& c, const Op& op, int8_t* dst_ptr, int ld_dst_oc) { |
|
|
|
MEGDNN_MARK_USED_VAR(ld_dst_oc); |
|
|
|
op(c[0][0], reinterpret_cast<dt_qint8*>(dst_ptr)); |
|
|
|
} |
|
|
|
}; |
|
|
@@ -360,6 +364,7 @@ struct InitOcxOw4<2, bias_mode, T> { |
|
|
|
template <BiasMode bias_mode, typename T> |
|
|
|
struct InitOcxOw4<1, bias_mode, T> { |
|
|
|
static void impl(T& c, const int32_t* bias_ptr, int oc_step) { |
|
|
|
MEGDNN_MARK_USED_VAR(oc_step); |
|
|
|
if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { |
|
|
|
#define BAIS_INIT(step) c[0][step] = vld1q_s32(bias_ptr); |
|
|
|
UNROLL_CALL_RAW(4, BAIS_INIT); |
|
|
@@ -412,6 +417,7 @@ template <int base_offset, int ptr_step, typename Func, typename T, |
|
|
|
typename... XT> |
|
|
|
// Partial specialization of LoadHelper selected when the first template
// argument is 4 and the fourth is 0; extra call arguments are accepted
// through the XT... pack.
// NOTE(review): WEIGHT_CB and UNROLL_CALL_RAW are defined outside this view;
// presumably WEIGHT_CB is expanded once per step 0..3 and fills src from ptr
// (possibly forwarding args) - confirm against the macro definitions.
struct LoadHelper<4, base_offset, ptr_step, 0, Func, T, XT...> { |
|
|
|
static void impl(T& src, const int8_t* ptr, int oc_offset, XT... args) { |
|
|
|
// oc_offset is not referenced by any visible statement in this body.
MEGDNN_MARK_USED_VAR(oc_offset); |
|
|
|
UNROLL_CALL_RAW(4, WEIGHT_CB); |
|
|
|
} |
|
|
|
}; |
|
|
@@ -419,6 +425,7 @@ template <int base_offset, int ptr_step, typename Func, typename T, |
|
|
|
typename... XT> |
|
|
|
// Partial specialization of LoadHelper selected when the first template
// argument is 5 and the fourth is 0; extra call arguments are accepted
// through the XT... pack.
// NOTE(review): WEIGHT_CB and UNROLL_CALL_RAW are defined outside this view;
// presumably WEIGHT_CB is expanded once per step 0..4 and fills src from ptr
// (possibly forwarding args) - confirm against the macro definitions.
struct LoadHelper<5, base_offset, ptr_step, 0, Func, T, XT...> { |
|
|
|
static void impl(T& src, const int8_t* ptr, int oc_offset, XT... args) { |
|
|
|
// oc_offset is not referenced by any visible statement in this body.
MEGDNN_MARK_USED_VAR(oc_offset); |
|
|
|
UNROLL_CALL_RAW(5, WEIGHT_CB); |
|
|
|
} |
|
|
|
}; |
|
|
@@ -426,6 +433,7 @@ template <int base_offset, int ptr_step, typename Func, typename T, |
|
|
|
typename... XT> |
|
|
|
// Partial specialization of LoadHelper selected when the first template
// argument is 6 and the fourth is 0; extra call arguments are accepted
// through the XT... pack.
// NOTE(review): WEIGHT_CB and UNROLL_CALL_RAW are defined outside this view;
// presumably WEIGHT_CB is expanded once per step 0..5 and fills src from ptr
// (possibly forwarding args) - confirm against the macro definitions.
struct LoadHelper<6, base_offset, ptr_step, 0, Func, T, XT...> { |
|
|
|
static void impl(T& src, const int8_t* ptr, int oc_offset, XT... args) { |
|
|
|
// oc_offset is not referenced by any visible statement in this body.
MEGDNN_MARK_USED_VAR(oc_offset); |
|
|
|
UNROLL_CALL_RAW(6, WEIGHT_CB); |
|
|
|
} |
|
|
|
}; |
|
|
@@ -436,12 +444,14 @@ struct LoadHelper<6, base_offset, ptr_step, 0, Func, T, XT...> { |
|
|
|
// Partial specialization of LoadHelper selected when the first template
// argument is 1 and the fourth is 1. Unlike the variadic specializations,
// impl() takes no trailing argument pack.
// NOTE(review): WEIGHT_CB and UNROLL_CALL_RAW are defined outside this view;
// presumably WEIGHT_CB is expanded for step 0 only and fills src from ptr -
// confirm against the macro definitions.
template <int base_offset, int ptr_step, typename Func, typename T> |
|
|
|
struct LoadHelper<1, base_offset, ptr_step, 1, Func, T> { |
|
|
|
static void impl(T& src, const int8_t* ptr, int oc_offset) { |
|
|
|
// oc_offset is not referenced by any visible statement in this body.
MEGDNN_MARK_USED_VAR(oc_offset); |
|
|
|
UNROLL_CALL_RAW(1, WEIGHT_CB); |
|
|
|
} |
|
|
|
}; |
|
|
|
// Partial specialization of LoadHelper selected when the first template
// argument is 2 and the fourth is 1. Unlike the variadic specializations,
// impl() takes no trailing argument pack.
// NOTE(review): WEIGHT_CB and UNROLL_CALL_RAW are defined outside this view;
// presumably WEIGHT_CB is expanded once per step 0..1 and fills src from ptr
// - confirm against the macro definitions.
template <int base_offset, int ptr_step, typename Func, typename T> |
|
|
|
struct LoadHelper<2, base_offset, ptr_step, 1, Func, T> { |
|
|
|
static void impl(T& src, const int8_t* ptr, int oc_offset) { |
|
|
|
// oc_offset is not referenced by any visible statement in this body.
MEGDNN_MARK_USED_VAR(oc_offset); |
|
|
|
UNROLL_CALL_RAW(2, WEIGHT_CB); |
|
|
|
} |
|
|
|
}; |
|
|
@@ -449,6 +459,7 @@ struct LoadHelper<2, base_offset, ptr_step, 1, Func, T> { |
|
|
|
// Partial specialization of LoadHelper selected when the first template
// argument is 3 and the fourth is 1. Unlike the variadic specializations,
// impl() takes no trailing argument pack.
// NOTE(review): WEIGHT_CB and UNROLL_CALL_RAW are defined outside this view;
// presumably WEIGHT_CB is expanded once per step 0..2 and fills src from ptr
// - confirm against the macro definitions.
template <int base_offset, int ptr_step, typename Func, typename T> |
|
|
|
struct LoadHelper<3, base_offset, ptr_step, 1, Func, T> { |
|
|
|
static void impl(T& src, const int8_t* ptr, int oc_offset) { |
|
|
|
// oc_offset is not referenced by any visible statement in this body.
MEGDNN_MARK_USED_VAR(oc_offset); |
|
|
|
UNROLL_CALL_RAW(3, WEIGHT_CB); |
|
|
|
} |
|
|
|
}; |
|
|
@@ -497,4 +508,4 @@ inline void load_helper_x(T& weight, const int8_t* ptr, int oc_offset, |
|
|
|
} |
|
|
|
|
|
|
|
} // namespace |
|
|
|
} // namespace megdnn |
|
|
|
} // namespace megdnn |