BCAST111C_VEC_BCAST111C and BCAST101_VEC_BCAST101
GitOrigin-RevId: 0e26553c90
tags/v1.8.0
@@ -144,21 +144,35 @@ void ElemwiseImpl::AlgoTernaryFma3Bcast101VecBcast101::exec( | |||
midout_iv(Mode::_mode), _type_midout_id) { \ | |||
thin_function<void( \ | |||
const _type*, const _type*, const _type*, _type*, DType, DType, \ | |||
DType, DType, size_t, size_t, size_t)> \ | |||
DType, DType, size_t, size_t, size_t, size_t)> \ | |||
run = OpCallerTernary< \ | |||
_op<_type, _type>, BcastType::BCAST101_VEC_BCAST101>::run; \ | |||
MEGDNN_DISPATCH_CPU_KERN( \ | |||
static_cast<naive::HandleImpl*>(kern_param.handle), \ | |||
run(static_cast<const _type*>(src0.raw_ptr()), \ | |||
static_cast<const _type*>(src1.raw_ptr()), \ | |||
static_cast<const _type*>(src2.raw_ptr()), \ | |||
static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \ | |||
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ | |||
binfo.x, binfo.y, binfo.z)); \ | |||
auto kernel = [nr_channels, nr_channels_per_thread, src0, src1, src2, \ | |||
binfo, dst, run](size_t task_id, size_t) { \ | |||
size_t offset = task_id * nr_channels_per_thread; \ | |||
size_t nr_channels_thread = \ | |||
std::min(nr_channels - offset, nr_channels_per_thread); \ | |||
run(static_cast<const _type*>(src0.raw_ptr()) + offset, \ | |||
static_cast<const _type*>(src1.raw_ptr()) + offset * binfo.z, \ | |||
static_cast<const _type*>(src2.raw_ptr()) + offset, \ | |||
static_cast<_type*>(dst.raw_ptr()) + offset * binfo.z, \ | |||
src0.layout.dtype, src1.layout.dtype, src2.layout.dtype, \ | |||
dst.layout.dtype, binfo.x, nr_channels_thread, binfo.z, \ | |||
binfo.y * binfo.z); \ | |||
}; \ | |||
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ | |||
static_cast<naive::HandleImpl*>(kern_param.handle), nr_threads, \ | |||
kernel); \ | |||
} \ | |||
MIDOUT_END(); \ | |||
return | |||
size_t nr_threads = static_cast<naive::HandleImpl*>(kern_param.handle) | |||
->megcore_dispatcher() | |||
->nr_threads(); | |||
size_t nr_channels = binfo.y; | |||
size_t nr_channels_per_thread = (nr_channels + nr_threads - 1) / nr_threads; | |||
auto&& dst = *(kern_param.m_dst); | |||
DISPATCH_TYPE("AlgoTernaryFma3Bcast101VecBcast101::exec"_hash); | |||
#undef DISPATCH_TERNARY | |||
@@ -181,23 +195,39 @@ void ElemwiseImpl::AlgoTernaryFma3Bcast111CVecBcast111C::exec( | |||
midout_iv(Mode::_mode), _type_midout_id) { \ | |||
thin_function<void( \ | |||
const _type*, const _type*, size_t, const _type*, _type*, DType, \ | |||
DType, DType, DType, size_t, size_t, size_t)> \ | |||
DType, DType, DType, size_t, size_t, size_t, size_t)> \ | |||
run = OpCallerTernary< \ | |||
_op<_type, _type>, \ | |||
BcastType::BCAST111C_VEC_BCAST111C>::run; \ | |||
MEGDNN_DISPATCH_CPU_KERN( \ | |||
static_cast<naive::HandleImpl*>(kern_param.handle), \ | |||
run(static_cast<const _type*>(src0.raw_ptr()), \ | |||
static_cast<const _type*>(src1.raw_ptr()), \ | |||
is_vector(src1.layout) ? 0 : src1.layout.stride[0] - binfo.z, \ | |||
static_cast<const _type*>(src2.raw_ptr()), \ | |||
static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \ | |||
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ | |||
binfo.x, binfo.y, binfo.z)); \ | |||
auto kernel = [nr_channels, nr_channels_per_thread, src0, src1, src2, \ | |||
binfo, dst, run](size_t task_id, size_t) { \ | |||
size_t offset = task_id * nr_channels_per_thread; \ | |||
size_t nr_channels_thread = \ | |||
std::min(nr_channels - offset, nr_channels_per_thread); \ | |||
size_t src1_offset = \ | |||
is_vector(src1.layout) ? 0 : src1.layout.stride[0] - binfo.z; \ | |||
run(static_cast<const _type*>(src0.raw_ptr()), \ | |||
static_cast<const _type*>(src1.raw_ptr()) + \ | |||
offset * (binfo.z + src1_offset), \ | |||
src1_offset, static_cast<const _type*>(src2.raw_ptr()), \ | |||
static_cast<_type*>(dst.raw_ptr()) + offset * binfo.z, \ | |||
src0.layout.dtype, src1.layout.dtype, src2.layout.dtype, \ | |||
dst.layout.dtype, binfo.x, nr_channels_thread, binfo.z, \ | |||
binfo.y * binfo.z); \ | |||
}; \ | |||
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ | |||
static_cast<naive::HandleImpl*>(kern_param.handle), nr_threads, \ | |||
kernel); \ | |||
} \ | |||
MIDOUT_END(); \ | |||
return | |||
size_t nr_threads = static_cast<naive::HandleImpl*>(kern_param.handle) | |||
->megcore_dispatcher() | |||
->nr_threads(); | |||
size_t nr_channels = binfo.y; | |||
size_t nr_channels_per_thread = (nr_channels + nr_threads - 1) / nr_threads; | |||
auto&& dst = *(kern_param.m_dst); | |||
DISPATCH_TYPE("AlgoTernaryFma3Bcast111CVecBcast111C::exec"_hash); | |||
#undef DISPATCH_TERNARY | |||
@@ -772,13 +772,14 @@ void ElemwiseMultiTypeImpl::on_quantized_mode( | |||
using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ | |||
thin_function<void( \ | |||
const src_ctype*, const src_ctype*, const src_ctype*, dst_ctype*, \ | |||
DType, DType, DType, DType, size_t, size_t, size_t)> \ | |||
DType, DType, DType, DType, size_t, size_t, size_t, size_t)> \ | |||
run = OpCallerTernary< \ | |||
_op<src_ctype, dst_ctype>, BCAST101_VEC_BCAST101>::run; \ | |||
MEGDNN_DISPATCH_CPU_KERN_OPR(run( \ | |||
src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), src2.ptr<src_ctype>(), \ | |||
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \ | |||
src2.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, binfo.z)); \ | |||
MEGDNN_DISPATCH_CPU_KERN_OPR( \ | |||
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \ | |||
src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), src0.layout.dtype, \ | |||
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, binfo.x, \ | |||
binfo.y, binfo.z, binfo.y* binfo.z)); \ | |||
return; \ | |||
} | |||
@@ -1060,7 +1060,8 @@ struct OpCallerTernary<Op, BCAST101_VEC_BCAST101> { | |||
const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, | |||
const typename Op::src_ctype* src2, typename Op::dst_ctype* dst, | |||
DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, | |||
size_t batch_size, size_t channel_size, size_t channel_stride) { | |||
size_t batch_size, size_t channel_size, size_t channel_stride, | |||
size_t batch_offset) { | |||
Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); | |||
ParamElemVisitor<typename Op::src_ctype> vis1; | |||
ParamElemVisitorDup<typename Op::src_ctype> vis0; | |||
@@ -1068,6 +1069,7 @@ struct OpCallerTernary<Op, BCAST101_VEC_BCAST101> { | |||
for (size_t batch = 0; batch < batch_size; batch++) { | |||
auto src0_ptr = src0; | |||
auto src2_ptr = src2; | |||
auto b_offset = batch_offset; | |||
for (size_t channel = 0; channel < channel_size; channel++) { | |||
size_t i = 0; | |||
auto src0_neon = vis0(src0_ptr); | |||
@@ -1079,6 +1081,7 @@ struct OpCallerTernary<Op, BCAST101_VEC_BCAST101> { | |||
{{src2_neon, src2_neon}}, dst); | |||
src1 += Op::SIMD_WIDTH * 2; | |||
dst += Op::SIMD_WIDTH * 2; | |||
b_offset -= Op::SIMD_WIDTH * 2; | |||
} | |||
#if MEGDNN_FIX_AARCH32_BUG | |||
// FIXME: as llvm may cause cannot select error if enable vectorize | |||
@@ -1088,10 +1091,13 @@ struct OpCallerTernary<Op, BCAST101_VEC_BCAST101> { | |||
op(*src0_ptr, *src1, *src2_ptr, dst); | |||
src1++; | |||
dst++; | |||
b_offset--; | |||
} | |||
src0_ptr++; | |||
src2_ptr++; | |||
} | |||
src1 += b_offset; | |||
dst += b_offset; | |||
} | |||
} | |||
}; | |||
@@ -1104,10 +1110,11 @@ struct OpCallerTernary<Op, BCAST111C_VEC_BCAST111C> { | |||
size_t src1_offset, const typename Op::src_ctype* src2, | |||
typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, | |||
DType src2_dtype, DType dst_dtype, size_t batch_size, size_t channel_size, | |||
size_t channel_stride) { | |||
size_t channel_stride, size_t batch_offset) { | |||
Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); | |||
ParamElemVisitor<typename Op::src_ctype> vis; | |||
for (size_t batch = 0; batch < batch_size; batch++) { | |||
auto b_offset = batch_offset; | |||
for (size_t channel = 0; channel < channel_size; channel++) { | |||
auto src0_ptr = src0; | |||
auto src2_ptr = src2; | |||
@@ -1126,6 +1133,7 @@ struct OpCallerTernary<Op, BCAST111C_VEC_BCAST111C> { | |||
src1 += Op::SIMD_WIDTH * 2; | |||
src2_ptr += Op::SIMD_WIDTH * 2; | |||
dst += Op::SIMD_WIDTH * 2; | |||
b_offset -= Op::SIMD_WIDTH * 2; | |||
} | |||
#if MEGDNN_FIX_AARCH32_BUG | |||
// FIXME: as llvm may cause cannot select error if enable vectorize | |||
@@ -1137,9 +1145,12 @@ struct OpCallerTernary<Op, BCAST111C_VEC_BCAST111C> { | |||
src1++; | |||
src2_ptr++; | |||
dst++; | |||
b_offset--; | |||
} | |||
src1 += src1_offset; | |||
} | |||
src1 += b_offset; | |||
dst += b_offset; | |||
} | |||
} | |||
}; | |||
@@ -300,7 +300,7 @@ TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NCHW88_FP) { | |||
#endif | |||
} | |||
TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NHWC_FP32_BCAST) { | |||
TEST_F(ARM_COMMON_MULTI_THREADS, ELEMWISE_FORWARD_NHWC_FP32_BCAST) { | |||
using Mode = ElemwiseForward::Param::Mode; | |||
Checker<ElemwiseForward> checker(handle()); | |||