BCAST111C_VEC_BCAST111C and BCAST101_VEC_BCAST101
GitOrigin-RevId: 0e26553c90
tags/v1.8.0
@@ -144,21 +144,35 @@ void ElemwiseImpl::AlgoTernaryFma3Bcast101VecBcast101::exec( | |||||
midout_iv(Mode::_mode), _type_midout_id) { \ | midout_iv(Mode::_mode), _type_midout_id) { \ | ||||
thin_function<void( \ | thin_function<void( \ | ||||
const _type*, const _type*, const _type*, _type*, DType, DType, \ | const _type*, const _type*, const _type*, _type*, DType, DType, \ | ||||
DType, DType, size_t, size_t, size_t)> \ | |||||
DType, DType, size_t, size_t, size_t, size_t)> \ | |||||
run = OpCallerTernary< \ | run = OpCallerTernary< \ | ||||
_op<_type, _type>, BcastType::BCAST101_VEC_BCAST101>::run; \ | _op<_type, _type>, BcastType::BCAST101_VEC_BCAST101>::run; \ | ||||
MEGDNN_DISPATCH_CPU_KERN( \ | |||||
static_cast<naive::HandleImpl*>(kern_param.handle), \ | |||||
run(static_cast<const _type*>(src0.raw_ptr()), \ | |||||
static_cast<const _type*>(src1.raw_ptr()), \ | |||||
static_cast<const _type*>(src2.raw_ptr()), \ | |||||
static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \ | |||||
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ | |||||
binfo.x, binfo.y, binfo.z)); \ | |||||
auto kernel = [nr_channels, nr_channels_per_thread, src0, src1, src2, \ | |||||
binfo, dst, run](size_t task_id, size_t) { \ | |||||
size_t offset = task_id * nr_channels_per_thread; \ | |||||
size_t nr_channels_thread = \ | |||||
std::min(nr_channels - offset, nr_channels_per_thread); \ | |||||
run(static_cast<const _type*>(src0.raw_ptr()) + offset, \ | |||||
static_cast<const _type*>(src1.raw_ptr()) + offset * binfo.z, \ | |||||
static_cast<const _type*>(src2.raw_ptr()) + offset, \ | |||||
static_cast<_type*>(dst.raw_ptr()) + offset * binfo.z, \ | |||||
src0.layout.dtype, src1.layout.dtype, src2.layout.dtype, \ | |||||
dst.layout.dtype, binfo.x, nr_channels_thread, binfo.z, \ | |||||
binfo.y * binfo.z); \ | |||||
}; \ | |||||
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ | |||||
static_cast<naive::HandleImpl*>(kern_param.handle), nr_threads, \ | |||||
kernel); \ | |||||
} \ | } \ | ||||
MIDOUT_END(); \ | MIDOUT_END(); \ | ||||
return | return | ||||
size_t nr_threads = static_cast<naive::HandleImpl*>(kern_param.handle) | |||||
->megcore_dispatcher() | |||||
->nr_threads(); | |||||
size_t nr_channels = binfo.y; | |||||
size_t nr_channels_per_thread = (nr_channels + nr_threads - 1) / nr_threads; | |||||
auto&& dst = *(kern_param.m_dst); | auto&& dst = *(kern_param.m_dst); | ||||
DISPATCH_TYPE("AlgoTernaryFma3Bcast101VecBcast101::exec"_hash); | DISPATCH_TYPE("AlgoTernaryFma3Bcast101VecBcast101::exec"_hash); | ||||
#undef DISPATCH_TERNARY | #undef DISPATCH_TERNARY | ||||
@@ -181,23 +195,39 @@ void ElemwiseImpl::AlgoTernaryFma3Bcast111CVecBcast111C::exec( | |||||
midout_iv(Mode::_mode), _type_midout_id) { \ | midout_iv(Mode::_mode), _type_midout_id) { \ | ||||
thin_function<void( \ | thin_function<void( \ | ||||
const _type*, const _type*, size_t, const _type*, _type*, DType, \ | const _type*, const _type*, size_t, const _type*, _type*, DType, \ | ||||
DType, DType, DType, size_t, size_t, size_t)> \ | |||||
DType, DType, DType, size_t, size_t, size_t, size_t)> \ | |||||
run = OpCallerTernary< \ | run = OpCallerTernary< \ | ||||
_op<_type, _type>, \ | _op<_type, _type>, \ | ||||
BcastType::BCAST111C_VEC_BCAST111C>::run; \ | BcastType::BCAST111C_VEC_BCAST111C>::run; \ | ||||
MEGDNN_DISPATCH_CPU_KERN( \ | |||||
static_cast<naive::HandleImpl*>(kern_param.handle), \ | |||||
run(static_cast<const _type*>(src0.raw_ptr()), \ | |||||
static_cast<const _type*>(src1.raw_ptr()), \ | |||||
is_vector(src1.layout) ? 0 : src1.layout.stride[0] - binfo.z, \ | |||||
static_cast<const _type*>(src2.raw_ptr()), \ | |||||
static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \ | |||||
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ | |||||
binfo.x, binfo.y, binfo.z)); \ | |||||
auto kernel = [nr_channels, nr_channels_per_thread, src0, src1, src2, \ | |||||
binfo, dst, run](size_t task_id, size_t) { \ | |||||
size_t offset = task_id * nr_channels_per_thread; \ | |||||
size_t nr_channels_thread = \ | |||||
std::min(nr_channels - offset, nr_channels_per_thread); \ | |||||
size_t src1_offset = \ | |||||
is_vector(src1.layout) ? 0 : src1.layout.stride[0] - binfo.z; \ | |||||
run(static_cast<const _type*>(src0.raw_ptr()), \ | |||||
static_cast<const _type*>(src1.raw_ptr()) + \ | |||||
offset * (binfo.z + src1_offset), \ | |||||
src1_offset, static_cast<const _type*>(src2.raw_ptr()), \ | |||||
static_cast<_type*>(dst.raw_ptr()) + offset * binfo.z, \ | |||||
src0.layout.dtype, src1.layout.dtype, src2.layout.dtype, \ | |||||
dst.layout.dtype, binfo.x, nr_channels_thread, binfo.z, \ | |||||
binfo.y * binfo.z); \ | |||||
}; \ | |||||
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ | |||||
static_cast<naive::HandleImpl*>(kern_param.handle), nr_threads, \ | |||||
kernel); \ | |||||
} \ | } \ | ||||
MIDOUT_END(); \ | MIDOUT_END(); \ | ||||
return | return | ||||
size_t nr_threads = static_cast<naive::HandleImpl*>(kern_param.handle) | |||||
->megcore_dispatcher() | |||||
->nr_threads(); | |||||
size_t nr_channels = binfo.y; | |||||
size_t nr_channels_per_thread = (nr_channels + nr_threads - 1) / nr_threads; | |||||
auto&& dst = *(kern_param.m_dst); | auto&& dst = *(kern_param.m_dst); | ||||
DISPATCH_TYPE("AlgoTernaryFma3Bcast111CVecBcast111C::exec"_hash); | DISPATCH_TYPE("AlgoTernaryFma3Bcast111CVecBcast111C::exec"_hash); | ||||
#undef DISPATCH_TERNARY | #undef DISPATCH_TERNARY | ||||
@@ -772,13 +772,14 @@ void ElemwiseMultiTypeImpl::on_quantized_mode( | |||||
using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ | using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ | ||||
thin_function<void( \ | thin_function<void( \ | ||||
const src_ctype*, const src_ctype*, const src_ctype*, dst_ctype*, \ | const src_ctype*, const src_ctype*, const src_ctype*, dst_ctype*, \ | ||||
DType, DType, DType, DType, size_t, size_t, size_t)> \ | |||||
DType, DType, DType, DType, size_t, size_t, size_t, size_t)> \ | |||||
run = OpCallerTernary< \ | run = OpCallerTernary< \ | ||||
_op<src_ctype, dst_ctype>, BCAST101_VEC_BCAST101>::run; \ | _op<src_ctype, dst_ctype>, BCAST101_VEC_BCAST101>::run; \ | ||||
MEGDNN_DISPATCH_CPU_KERN_OPR(run( \ | |||||
src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), src2.ptr<src_ctype>(), \ | |||||
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \ | |||||
src2.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, binfo.z)); \ | |||||
MEGDNN_DISPATCH_CPU_KERN_OPR( \ | |||||
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \ | |||||
src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), src0.layout.dtype, \ | |||||
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, binfo.x, \ | |||||
binfo.y, binfo.z, binfo.y* binfo.z)); \ | |||||
return; \ | return; \ | ||||
} | } | ||||
@@ -1060,7 +1060,8 @@ struct OpCallerTernary<Op, BCAST101_VEC_BCAST101> { | |||||
const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, | const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, | ||||
const typename Op::src_ctype* src2, typename Op::dst_ctype* dst, | const typename Op::src_ctype* src2, typename Op::dst_ctype* dst, | ||||
DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, | DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, | ||||
size_t batch_size, size_t channel_size, size_t channel_stride) { | |||||
size_t batch_size, size_t channel_size, size_t channel_stride, | |||||
size_t batch_offset) { | |||||
Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); | Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); | ||||
ParamElemVisitor<typename Op::src_ctype> vis1; | ParamElemVisitor<typename Op::src_ctype> vis1; | ||||
ParamElemVisitorDup<typename Op::src_ctype> vis0; | ParamElemVisitorDup<typename Op::src_ctype> vis0; | ||||
@@ -1068,6 +1069,7 @@ struct OpCallerTernary<Op, BCAST101_VEC_BCAST101> { | |||||
for (size_t batch = 0; batch < batch_size; batch++) { | for (size_t batch = 0; batch < batch_size; batch++) { | ||||
auto src0_ptr = src0; | auto src0_ptr = src0; | ||||
auto src2_ptr = src2; | auto src2_ptr = src2; | ||||
auto b_offset = batch_offset; | |||||
for (size_t channel = 0; channel < channel_size; channel++) { | for (size_t channel = 0; channel < channel_size; channel++) { | ||||
size_t i = 0; | size_t i = 0; | ||||
auto src0_neon = vis0(src0_ptr); | auto src0_neon = vis0(src0_ptr); | ||||
@@ -1079,6 +1081,7 @@ struct OpCallerTernary<Op, BCAST101_VEC_BCAST101> { | |||||
{{src2_neon, src2_neon}}, dst); | {{src2_neon, src2_neon}}, dst); | ||||
src1 += Op::SIMD_WIDTH * 2; | src1 += Op::SIMD_WIDTH * 2; | ||||
dst += Op::SIMD_WIDTH * 2; | dst += Op::SIMD_WIDTH * 2; | ||||
b_offset -= Op::SIMD_WIDTH * 2; | |||||
} | } | ||||
#if MEGDNN_FIX_AARCH32_BUG | #if MEGDNN_FIX_AARCH32_BUG | ||||
// FIXME: as llvm may cause cannot select error if enable vectorize | // FIXME: as llvm may cause cannot select error if enable vectorize | ||||
@@ -1088,10 +1091,13 @@ struct OpCallerTernary<Op, BCAST101_VEC_BCAST101> { | |||||
op(*src0_ptr, *src1, *src2_ptr, dst); | op(*src0_ptr, *src1, *src2_ptr, dst); | ||||
src1++; | src1++; | ||||
dst++; | dst++; | ||||
b_offset--; | |||||
} | } | ||||
src0_ptr++; | src0_ptr++; | ||||
src2_ptr++; | src2_ptr++; | ||||
} | } | ||||
src1 += b_offset; | |||||
dst += b_offset; | |||||
} | } | ||||
} | } | ||||
}; | }; | ||||
@@ -1104,10 +1110,11 @@ struct OpCallerTernary<Op, BCAST111C_VEC_BCAST111C> { | |||||
size_t src1_offset, const typename Op::src_ctype* src2, | size_t src1_offset, const typename Op::src_ctype* src2, | ||||
typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, | typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, | ||||
DType src2_dtype, DType dst_dtype, size_t batch_size, size_t channel_size, | DType src2_dtype, DType dst_dtype, size_t batch_size, size_t channel_size, | ||||
size_t channel_stride) { | |||||
size_t channel_stride, size_t batch_offset) { | |||||
Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); | Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); | ||||
ParamElemVisitor<typename Op::src_ctype> vis; | ParamElemVisitor<typename Op::src_ctype> vis; | ||||
for (size_t batch = 0; batch < batch_size; batch++) { | for (size_t batch = 0; batch < batch_size; batch++) { | ||||
auto b_offset = batch_offset; | |||||
for (size_t channel = 0; channel < channel_size; channel++) { | for (size_t channel = 0; channel < channel_size; channel++) { | ||||
auto src0_ptr = src0; | auto src0_ptr = src0; | ||||
auto src2_ptr = src2; | auto src2_ptr = src2; | ||||
@@ -1126,6 +1133,7 @@ struct OpCallerTernary<Op, BCAST111C_VEC_BCAST111C> { | |||||
src1 += Op::SIMD_WIDTH * 2; | src1 += Op::SIMD_WIDTH * 2; | ||||
src2_ptr += Op::SIMD_WIDTH * 2; | src2_ptr += Op::SIMD_WIDTH * 2; | ||||
dst += Op::SIMD_WIDTH * 2; | dst += Op::SIMD_WIDTH * 2; | ||||
b_offset -= Op::SIMD_WIDTH * 2; | |||||
} | } | ||||
#if MEGDNN_FIX_AARCH32_BUG | #if MEGDNN_FIX_AARCH32_BUG | ||||
// FIXME: as llvm may cause cannot select error if enable vectorize | // FIXME: as llvm may cause cannot select error if enable vectorize | ||||
@@ -1137,9 +1145,12 @@ struct OpCallerTernary<Op, BCAST111C_VEC_BCAST111C> { | |||||
src1++; | src1++; | ||||
src2_ptr++; | src2_ptr++; | ||||
dst++; | dst++; | ||||
b_offset--; | |||||
} | } | ||||
src1 += src1_offset; | src1 += src1_offset; | ||||
} | } | ||||
src1 += b_offset; | |||||
dst += b_offset; | |||||
} | } | ||||
} | } | ||||
}; | }; | ||||
@@ -300,7 +300,7 @@ TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NCHW88_FP) { | |||||
#endif | #endif | ||||
} | } | ||||
TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NHWC_FP32_BCAST) { | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, ELEMWISE_FORWARD_NHWC_FP32_BCAST) { | |||||
using Mode = ElemwiseForward::Param::Mode; | using Mode = ElemwiseForward::Param::Mode; | ||||
Checker<ElemwiseForward> checker(handle()); | Checker<ElemwiseForward> checker(handle()); | ||||