|
|
@@ -54,6 +54,33 @@ cb(dt_uint8, __m256i, "avx2", uint8_t, __m256i, mm256, si256, epi8, SIMDType::AV |
|
|
|
cb(dt_float32, float, "avx2", float, __m256, mm256, ps, ps, SIMDType::AVX2); |
|
|
|
|
|
|
|
#undef cb |
|
|
|
|
|
|
|
//! visitor for handle BCAST101xX(4) at AVX2, load 128, broadcast to 256 |
|
|
|
template <typename ctype, SIMDType simd_type = SIMDType::AVX2> |
|
|
|
struct ParamElemVisitorHalfBoardCast; |
|
|
|
|
|
|
|
#define cb( \ |
|
|
|
_ctype, _simd_ptr_type, load_half_fuc, half_type, _simd_type, board_cast_func) \ |
|
|
|
template <> \ |
|
|
|
struct ParamElemVisitorHalfBoardCast<_ctype, SIMDType::AVX2> { \ |
|
|
|
MEGDNN_ATTRIBUTE_TARGET("avx2") \ |
|
|
|
_simd_type operator()(const _ctype* src) const { \ |
|
|
|
half_type tmp = \ |
|
|
|
load_half_fuc(reinterpret_cast<_simd_ptr_type const*>(src)); \ |
|
|
|
return board_cast_func(tmp, tmp); \ |
|
|
|
} \ |
|
|
|
} |
|
|
|
|
|
|
|
//! AVX2 specializations of ParamElemVisitorHalfBoardCast, one per element
//! type. Macro arguments, in order: element type, pointee type used for the
//! load, 128-bit load intrinsic, 128-bit vector type, 256-bit vector type,
//! intrinsic that combines two 128-bit vectors into a 256-bit one.
cb(dt_qint32, __m128i, _mm_loadu_si128, __m128i, __m256i, _mm256_set_m128i);
cb(dt_qint8, __m128i, _mm_loadu_si128, __m128i, __m256i, _mm256_set_m128i);
cb(dt_quint8, __m128i, _mm_loadu_si128, __m128i, __m256i, _mm256_set_m128i);
cb(dt_int32, __m128i, _mm_loadu_si128, __m128i, __m256i, _mm256_set_m128i);
cb(dt_int16, __m128i, _mm_loadu_si128, __m128i, __m256i, _mm256_set_m128i);
cb(dt_int8, __m128i, _mm_loadu_si128, __m128i, __m256i, _mm256_set_m128i);
cb(dt_uint8, __m128i, _mm_loadu_si128, __m128i, __m256i, _mm256_set_m128i);
//! float32 uses the unaligned load like the integer variants: nothing here
//! guarantees a 16-byte aligned source pointer, and the aligned _mm_load_ps
//! may fault on a misaligned address.
cb(dt_float32, float, _mm_loadu_ps, __m128, __m256, _mm256_set_m128);

#undef cb
|
|
|
/*! |
|
|
|
* \brief broadcast type |
|
|
|
* BCAST_x[0]x[1]...: x[i] == !stride[i] |
|
|
@@ -71,7 +98,8 @@ enum BcastType { |
|
|
|
BCAST101_VEC_BCAST101, |
|
|
|
VEC_BCAST101_VEC, |
|
|
|
VEC_SCALAR_VEC, |
|
|
|
VEC_SCALAR_SCALAR |
|
|
|
VEC_SCALAR_SCALAR, |
|
|
|
VEC_BCAST101xX |
|
|
|
}; |
|
|
|
|
|
|
|
///////////////////////////////// OpCaller ///////////////////////////// |
|
|
@@ -227,6 +255,106 @@ struct OpCallerBinary<Op, SIMDType::NONE, VEC_BCAST101> { |
|
|
|
}; |
|
|
|
#undef OP_CALLER |
|
|
|
|
|
|
|
template <typename Op> |
|
|
|
struct OpCallerBinary<Op, SIMDType::SSE4_2, VEC_BCAST101xX> { |
|
|
|
MEGDNN_ATTRIBUTE_TARGET("sse4.2") |
|
|
|
static void run( |
|
|
|
const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, |
|
|
|
typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, |
|
|
|
DType dst_dtype, size_t batch, size_t channel, size_t channel_stride, |
|
|
|
size_t channel_block_dim) { |
|
|
|
megdnn_assert(channel_block_dim == 4, "only imp for nchw44"); |
|
|
|
Op op(src0_dtype, src1_dtype, dst_dtype); |
|
|
|
ParamElemVisitor<typename Op::src_ctype, SIMDType::SSE4_2> vis0; |
|
|
|
ParamElemVisitor<typename Op::src_ctype, SIMDType::SSE4_2> vis1; |
|
|
|
for (size_t b = 0; b < batch; b++) { |
|
|
|
const typename Op::src_ctype* src1_ptr = src1; |
|
|
|
for (size_t c = 0; c < channel; c++) { |
|
|
|
auto src1_block_ptr = src1_ptr + c * channel_block_dim; |
|
|
|
auto channel_block_vec = vis1(src1_block_ptr); |
|
|
|
size_t img_index = 0; |
|
|
|
auto src0_offset = Op::SIMD_WIDTH / channel_block_dim; |
|
|
|
for (; img_index + 2 * src0_offset <= channel_stride; |
|
|
|
img_index += 2 * src0_offset) { |
|
|
|
op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, |
|
|
|
{{channel_block_vec, channel_block_vec}}, dst); |
|
|
|
src0 += Op::SIMD_WIDTH * 2; |
|
|
|
dst += Op::SIMD_WIDTH * 2; |
|
|
|
} |
|
|
|
for (; img_index < channel_stride; img_index++) { |
|
|
|
for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { |
|
|
|
op(*src0, *(src1_block_ptr + c_iter), dst); |
|
|
|
src0++; |
|
|
|
dst++; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
}; |
|
|
|
|
|
|
|
template <typename Op> |
|
|
|
struct OpCallerBinary<Op, SIMDType::AVX2, VEC_BCAST101xX> { |
|
|
|
MEGDNN_ATTRIBUTE_TARGET("avx2") |
|
|
|
static void run( |
|
|
|
const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, |
|
|
|
typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, |
|
|
|
DType dst_dtype, size_t batch, size_t channel, size_t channel_stride, |
|
|
|
size_t channel_block_dim) { |
|
|
|
megdnn_assert(channel_block_dim == 4, "only imp for nchw44"); |
|
|
|
Op op(src0_dtype, src1_dtype, dst_dtype); |
|
|
|
ParamElemVisitor<typename Op::src_ctype, SIMDType::AVX2> vis0; |
|
|
|
ParamElemVisitorHalfBoardCast<typename Op::src_ctype, SIMDType::AVX2> vis1; |
|
|
|
for (size_t b = 0; b < batch; b++) { |
|
|
|
const typename Op::src_ctype* src1_ptr = src1; |
|
|
|
for (size_t c = 0; c < channel; c++) { |
|
|
|
auto src1_block_ptr = src1_ptr + c * channel_block_dim; |
|
|
|
auto channel_block_vec = vis1(src1_block_ptr); |
|
|
|
size_t img_index = 0; |
|
|
|
auto src0_offset = Op::SIMD_WIDTH / channel_block_dim; |
|
|
|
for (; img_index + 2 * src0_offset <= channel_stride; |
|
|
|
img_index += 2 * src0_offset) { |
|
|
|
op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, |
|
|
|
{{channel_block_vec, channel_block_vec}}, dst); |
|
|
|
src0 += Op::SIMD_WIDTH * 2; |
|
|
|
dst += Op::SIMD_WIDTH * 2; |
|
|
|
} |
|
|
|
for (; img_index < channel_stride; img_index++) { |
|
|
|
for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { |
|
|
|
op(*src0, *(src1_block_ptr + c_iter), dst); |
|
|
|
src0++; |
|
|
|
dst++; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
}; |
|
|
|
|
|
|
|
template <typename Op> |
|
|
|
struct OpCallerBinary<Op, SIMDType::NONE, VEC_BCAST101xX> { |
|
|
|
static void run( |
|
|
|
const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, |
|
|
|
typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, |
|
|
|
DType dst_dtype, size_t batch, size_t channel, size_t channel_stride, |
|
|
|
size_t channel_block_dim) { |
|
|
|
Op op(src0_dtype, src1_dtype, dst_dtype); |
|
|
|
for (size_t b = 0; b < batch; b++) { |
|
|
|
auto src1_ptr = src1; |
|
|
|
for (size_t cb = 0; cb < channel; cb++) { |
|
|
|
auto src1_block_ptr = src1_ptr + cb * channel_block_dim; |
|
|
|
for (size_t img_index = 0; img_index < channel_stride; img_index++) { |
|
|
|
for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { |
|
|
|
op(*src0, *(src1_block_ptr + c_iter), dst); |
|
|
|
src0++; |
|
|
|
dst++; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
}; |
|
|
|
|
|
|
|
#define OP_CALLER(simd_type, target_simd) \ |
|
|
|
template <typename Op> \ |
|
|
|
struct OpCallerBinary<Op, simd_type, VEC_SCALAR> { \ |
|
|
|