GitOrigin-RevId: fb4300004c
release-1.7
@@ -104,6 +104,21 @@ bool ElemwiseImpl::AlgoBinaryVecBcast101::is_available( | |||
return false; | |||
} | |||
bool ElemwiseImpl::AlgoBinaryVecBcast111C::is_available( | |||
const KernParam& kern_param) const { | |||
if (!is_available_common(kern_param.mode) || | |||
((BcastType::VEC_BCAST111C != kern_param.broad_cast_type) && | |||
(BcastType::BCAST111C_VEC != kern_param.broad_cast_type))) | |||
return false; | |||
auto& elparam = kern_param.binary_elparam; | |||
auto& src0 = elparam[0]; | |||
DISPATCH_TYPE("AlgoBinaryVecBcast111C::is_available"_hash); | |||
return false; | |||
} | |||
bool ElemwiseImpl::AlgoBinaryVecBcast101xX::is_available( | |||
const KernParam& kern_param) const { | |||
if (!is_available_common(kern_param.mode) || | |||
@@ -333,6 +348,72 @@ void ElemwiseImpl::AlgoBinaryVecBcast101::exec(const KernParam& kern_param) cons | |||
return; | |||
} | |||
void ElemwiseImpl::AlgoBinaryVecBcast111C::exec(const KernParam& kern_param) const { | |||
auto& elparam = kern_param.binary_elparam; | |||
auto &src0 = elparam[0], &src1 = elparam[1]; | |||
auto&& dst = *(kern_param.m_dst); | |||
BroadcastChannelInfo binfo; | |||
// Case extra: BcastType::VEC + BCAST_111C | |||
if (BcastType::VEC_BCAST111C == kern_param.broad_cast_type && | |||
is_NHWC_broadcasted_channel_like(src1.layout, binfo)) { | |||
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \ | |||
case Mode::_mode: \ | |||
MIDOUT_BEGIN( \ | |||
megdnn_arm_common_elemwise_binary, midout_iv(_case), \ | |||
midout_iv(Mode::_mode), _type_midout_id) { \ | |||
thin_function<void( \ | |||
const _type*, const _type*, _type*, DType, DType, DType, size_t, \ | |||
size_t, size_t)> \ | |||
run = OpCallerBinary< \ | |||
_op<_type, _type>, BcastType::VEC_BCAST111C>::run; \ | |||
MEGDNN_DISPATCH_CPU_KERN( \ | |||
static_cast<naive::HandleImpl*>(kern_param.handle), \ | |||
run(static_cast<const _type*>(src0.raw_ptr), \ | |||
static_cast<const _type*>(src1.raw_ptr), \ | |||
static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ | |||
src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \ | |||
binfo.z)); \ | |||
} \ | |||
MIDOUT_END(); \ | |||
return | |||
DISPATCH_TYPE("AlgoBinaryVecBcast111C::exec_vec_b"_hash); | |||
#undef DISPATCH_BINARY | |||
} | |||
// BCAST_111C + BcastType::VEC | |||
if (BcastType::BCAST111C_VEC == kern_param.broad_cast_type && | |||
is_NHWC_broadcasted_channel_like(src0.layout, binfo)) { | |||
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \ | |||
case Mode::_mode: \ | |||
MIDOUT_BEGIN( \ | |||
megdnn_arm_common_elemwise_binary, midout_iv(_case), \ | |||
midout_iv(Mode::_mode), _type_midout_id) { \ | |||
thin_function<void( \ | |||
const _type*, const _type*, _type*, DType, DType, DType, size_t, \ | |||
size_t, size_t)> \ | |||
run = OpCallerBinary< \ | |||
_op<_type, _type>, BcastType::BCAST111C_VEC>::run; \ | |||
MEGDNN_DISPATCH_CPU_KERN( \ | |||
static_cast<naive::HandleImpl*>(kern_param.handle), \ | |||
run(static_cast<const _type*>(src0.raw_ptr), \ | |||
static_cast<const _type*>(src1.raw_ptr), \ | |||
static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ | |||
src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \ | |||
binfo.z)); \ | |||
} \ | |||
MIDOUT_END(); \ | |||
return | |||
DISPATCH_TYPE("AlgoBinaryVecBcast111C::exec_b_vec"_hash); | |||
#undef DISPATCH_BINARY | |||
} | |||
return; | |||
} | |||
void ElemwiseImpl::AlgoBinaryVecBcast101xX::exec(const KernParam& kern_param) const { | |||
auto& elparam = kern_param.binary_elparam; | |||
auto &src0 = elparam[0], &src1 = elparam[1]; | |||
@@ -33,6 +33,7 @@ namespace arm_common { | |||
DECL_CB(VecVec); | |||
DECL_CB(VecScalar); | |||
DECL_CB(VecBcast101); | |||
DECL_CB(VecBcast111C); | |||
DECL_CB(VecBcast101xX); | |||
#undef DECL_CB | |||
} // namespace arm_common | |||
@@ -27,12 +27,15 @@ class ElemwiseImpl::AlgoPack { | |||
AlgoBinaryVecVec algo_binary_vec_vec; | |||
AlgoBinaryVecScalar algo_binary_vec_sca; | |||
AlgoBinaryVecBcast101 algo_binary_vec_bcast101; | |||
AlgoBinaryVecBcast111C algo_binary_vec_bcast110; | |||
AlgoBinaryVecBcast101xX algo_binary_VEC_BCAST101xX; | |||
AlgoTernaryFma3VecVecVec algo_ternaryfma3_vec_vec_vec; | |||
AlgoTernaryFma3VecVecScalar algo_ternaryfma3_vec_vecsca; | |||
AlgoTernaryFma3Bcast101VecBcast101 algo_ternaryfma3_bcast101_vec_bcast101; | |||
AlgoTernaryFma3Bcast111CVecBcast111C algo_ternaryfma3_bcast110_vec_bcast110; | |||
AlgoTernaryFma3Bcast101xXVecBcast101xX algo_ternaryfma3_bcast101xX_vec_bcast101xX; | |||
AlgoTernaryFma3VecBcast101Vec algo_ternaryfma3_vec_bcast101_vec; | |||
AlgoTernaryFma3VecBcast111CVec algo_ternaryfma3_vec_bcast110_vec; | |||
AlgoTernaryFma3VecBcast101xXVec algo_ternaryfma3_vec_bcast101xX_vec; | |||
AlgoTernaryFma3VecScalarVec algo_ternaryfma3_vec_sca_vec; | |||
AlgoTernaryFma3VecScalarScalar algo_ternaryfma3_vec_sca_sca; | |||
@@ -43,12 +46,15 @@ public: | |||
all_algos.emplace_back(&algo_binary_vec_vec); | |||
all_algos.emplace_back(&algo_binary_vec_sca); | |||
all_algos.emplace_back(&algo_binary_vec_bcast101); | |||
all_algos.emplace_back(&algo_binary_vec_bcast110); | |||
all_algos.emplace_back(&algo_binary_VEC_BCAST101xX); | |||
all_algos.emplace_back(&algo_ternaryfma3_vec_vec_vec); | |||
all_algos.emplace_back(&algo_ternaryfma3_vec_vecsca); | |||
all_algos.emplace_back(&algo_ternaryfma3_bcast101_vec_bcast101); | |||
all_algos.emplace_back(&algo_ternaryfma3_bcast110_vec_bcast110); | |||
all_algos.emplace_back(&algo_ternaryfma3_bcast101xX_vec_bcast101xX); | |||
all_algos.emplace_back(&algo_ternaryfma3_vec_bcast101_vec); | |||
all_algos.emplace_back(&algo_ternaryfma3_vec_bcast110_vec); | |||
all_algos.emplace_back(&algo_ternaryfma3_vec_bcast101xX_vec); | |||
all_algos.emplace_back(&algo_ternaryfma3_vec_sca_vec); | |||
all_algos.emplace_back(&algo_ternaryfma3_vec_sca_sca); | |||
@@ -87,6 +93,14 @@ ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) { | |||
kern_param.mode = opr->param().mode; | |||
kern_param.handle = opr->handle(); | |||
auto is_legal_layout_for_nhwc = [](const TensorLayout& l) { | |||
if (is_vector(l)) | |||
return true; | |||
if (l.ndim == 2 && l.stride[1] == 1) | |||
return true; | |||
return false; | |||
}; | |||
if ((opr->m_src->size() == 3) && (opr->param().mode == Mode::FUSE_MUL_ADD3)) { | |||
kern_param.ternary_elparam = opr->make_elemwise_op_param<3>(); | |||
bool c_is_scalar; | |||
@@ -127,6 +141,20 @@ ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) { | |||
return kern_param; | |||
} | |||
if (is_legal_layout_for_nhwc(src1.layout) && | |||
is_NHWC_broadcasted_channel_like(src0.layout, binfo) && | |||
src0.layout.eq_layout(src2.layout)) { | |||
kern_param.broad_cast_type = BcastType::BCAST111C_VEC_BCAST111C; | |||
return kern_param; | |||
} | |||
if (is_legal_layout_for_nhwc(src0.layout) && | |||
src2.layout.eq_layout(src0.layout) && | |||
is_NHWC_broadcasted_channel_like(src1.layout, binfo)) { | |||
kern_param.broad_cast_type = BcastType::VEC_BCAST111C_VEC; | |||
return kern_param; | |||
} | |||
if (is_vector(src0.layout) && src0.layout.eq_layout(src2.layout) && | |||
(is_broadcastedx_channel_like<4>(src1.layout, binfo) || | |||
is_broadcastedx_channel_like<8>(src1.layout, binfo))) { | |||
@@ -174,6 +202,18 @@ ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) { | |||
return kern_param; | |||
} | |||
if (is_legal_layout_for_nhwc(src1.layout) && | |||
is_NHWC_broadcasted_channel_like(src0.layout, binfo)) { | |||
kern_param.broad_cast_type = BcastType::BCAST111C_VEC; | |||
return kern_param; | |||
} | |||
if (is_legal_layout_for_nhwc(src0.layout) && | |||
is_NHWC_broadcasted_channel_like(src1.layout, binfo)) { | |||
kern_param.broad_cast_type = BcastType::VEC_BCAST111C; | |||
return kern_param; | |||
} | |||
if (is_vector(src0.layout) && | |||
(is_broadcastedx_channel_like<4>(src1.layout, binfo) || | |||
is_broadcastedx_channel_like<8>(src1.layout, binfo))) { | |||
@@ -38,12 +38,15 @@ private: | |||
class AlgoBinaryVecVec; | |||
class AlgoBinaryVecScalar; | |||
class AlgoBinaryVecBcast101; | |||
class AlgoBinaryVecBcast111C; | |||
class AlgoBinaryVecBcast101xX; | |||
class AlgoTernaryFma3VecVecVec; | |||
class AlgoTernaryFma3VecVecScalar; | |||
class AlgoTernaryFma3Bcast101VecBcast101; | |||
class AlgoTernaryFma3Bcast111CVecBcast111C; | |||
class AlgoTernaryFma3Bcast101xXVecBcast101xX; | |||
class AlgoTernaryFma3VecBcast101Vec; | |||
class AlgoTernaryFma3VecBcast111CVec; | |||
class AlgoTernaryFma3VecBcast101xXVec; | |||
class AlgoTernaryFma3VecScalarVec; | |||
class AlgoTernaryFma3VecScalarScalar; | |||
@@ -42,8 +42,10 @@ using namespace arm_common; | |||
DECL_AVAILABLE(VecVecVec, BcastType::VEC_VEC_VEC); | |||
DECL_AVAILABLE(VecVecScalar, BcastType::VEC_VEC_SCALAR); | |||
DECL_AVAILABLE(Bcast101VecBcast101, BcastType::BCAST101_VEC_BCAST101); | |||
DECL_AVAILABLE(Bcast111CVecBcast111C, BcastType::BCAST111C_VEC_BCAST111C); | |||
DECL_AVAILABLE(Bcast101xXVecBcast101xX, BcastType::BCAST101xX_VEC_BCAST101xX); | |||
DECL_AVAILABLE(VecBcast101Vec, BcastType::VEC_BCAST101_VEC); | |||
DECL_AVAILABLE(VecBcast111CVec, BcastType::VEC_BCAST111C_VEC); | |||
DECL_AVAILABLE(VecBcast101xXVec, BcastType::VEC_BCAST101xX_VEC); | |||
DECL_AVAILABLE(VecScalarVec, BcastType::VEC_SCALAR_VEC); | |||
DECL_AVAILABLE(VecScalarScalar, BcastType::VEC_SCALAR_SCALAR); | |||
@@ -164,6 +166,45 @@ void ElemwiseImpl::AlgoTernaryFma3Bcast101VecBcast101::exec( | |||
return; | |||
} | |||
void ElemwiseImpl::AlgoTernaryFma3Bcast111CVecBcast111C::exec( | |||
const KernParam& kern_param) const { | |||
auto& elparam = kern_param.ternary_elparam; | |||
auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2]; | |||
// Case 3: shape of src0 and src2 is {1, 1, 1, C} | |||
BroadcastChannelInfo binfo; | |||
is_NHWC_broadcasted_channel_like(src0.layout, binfo); | |||
#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ | |||
case Mode::_mode: \ | |||
MIDOUT_BEGIN( \ | |||
megdnn_arm_common_elemwise_ternary, midout_iv(_case), \ | |||
midout_iv(Mode::_mode), _type_midout_id) { \ | |||
thin_function<void( \ | |||
const _type*, const _type*, size_t, const _type*, _type*, DType, \ | |||
DType, DType, DType, size_t, size_t, size_t)> \ | |||
run = OpCallerTernary< \ | |||
_op<_type, _type>, \ | |||
BcastType::BCAST111C_VEC_BCAST111C>::run; \ | |||
MEGDNN_DISPATCH_CPU_KERN( \ | |||
static_cast<naive::HandleImpl*>(kern_param.handle), \ | |||
run(static_cast<const _type*>(src0.raw_ptr), \ | |||
static_cast<const _type*>(src1.raw_ptr), \ | |||
is_vector(src1.layout) ? 0 : src1.layout.stride[0] - binfo.z, \ | |||
static_cast<const _type*>(src2.raw_ptr), \ | |||
static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ | |||
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ | |||
binfo.x, binfo.y, binfo.z)); \ | |||
} \ | |||
MIDOUT_END(); \ | |||
return | |||
auto&& dst = *(kern_param.m_dst); | |||
DISPATCH_TYPE("AlgoTernaryFma3Bcast111CVecBcast111C::exec"_hash); | |||
#undef DISPATCH_TERNARY | |||
return; | |||
} | |||
void ElemwiseImpl::AlgoTernaryFma3Bcast101xXVecBcast101xX::exec( | |||
const KernParam& kern_param) const { | |||
auto& elparam = kern_param.ternary_elparam; | |||
@@ -282,6 +323,45 @@ void ElemwiseImpl::AlgoTernaryFma3VecBcast101Vec::exec( | |||
return; | |||
} | |||
void ElemwiseImpl::AlgoTernaryFma3VecBcast111CVec::exec( | |||
const KernParam& kern_param) const { | |||
auto& elparam = kern_param.ternary_elparam; | |||
auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2]; | |||
// Case 4: shape of src1 is {1, 1, 1, C}, and src0 and src2 are contig | |||
BroadcastChannelInfo binfo; | |||
is_NHWC_broadcasted_channel_like(src1.layout, binfo); | |||
#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ | |||
case Mode::_mode: \ | |||
MIDOUT_BEGIN( \ | |||
megdnn_arm_common_elemwise_ternary, midout_iv(_case), \ | |||
midout_iv(Mode::_mode), _type_midout_id) { \ | |||
thin_function<void( \ | |||
const _type*, size_t, const _type*, const _type*, size_t, _type*, \ | |||
DType, DType, DType, DType, size_t, size_t, size_t)> \ | |||
run = OpCallerTernary< \ | |||
_op<_type, _type>, BcastType::VEC_BCAST111C_VEC>::run; \ | |||
MEGDNN_DISPATCH_CPU_KERN( \ | |||
static_cast<naive::HandleImpl*>(kern_param.handle), \ | |||
run(static_cast<const _type*>(src0.raw_ptr), \ | |||
is_vector(src0.layout) ? 0 : src0.layout.stride[0] - binfo.z, \ | |||
static_cast<const _type*>(src1.raw_ptr), \ | |||
static_cast<const _type*>(src2.raw_ptr), \ | |||
is_vector(src2.layout) ? 0 : src2.layout.stride[0] - binfo.z, \ | |||
static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ | |||
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ | |||
binfo.x, binfo.y, binfo.z)); \ | |||
} \ | |||
MIDOUT_END(); \ | |||
return | |||
auto&& dst = *(kern_param.m_dst); | |||
DISPATCH_TYPE("AlgoTernaryFma3VecBcast111CVec::exec"_hash); | |||
#undef DISPATCH_TERNARY | |||
return; | |||
} | |||
void ElemwiseImpl::AlgoTernaryFma3VecScalarVec::exec( | |||
const KernParam& kern_param) const { | |||
auto& elparam = kern_param.ternary_elparam; | |||
@@ -33,8 +33,10 @@ namespace arm_common { | |||
DECL_CB(VecVecVec); | |||
DECL_CB(VecVecScalar); | |||
DECL_CB(Bcast101VecBcast101); | |||
DECL_CB(Bcast111CVecBcast111C); | |||
DECL_CB(Bcast101xXVecBcast101xX); | |||
DECL_CB(VecBcast101Vec); | |||
DECL_CB(VecBcast111CVec); | |||
DECL_CB(VecBcast101xXVec); | |||
DECL_CB(VecScalarVec); | |||
DECL_CB(VecScalarScalar); | |||
@@ -107,16 +107,20 @@ enum BcastType { | |||
VEC, | |||
VEC_VEC, | |||
VEC_BCAST101, | |||
VEC_BCAST111C, | |||
VEC_BCAST101xX, | |||
VEC_SCALAR, | |||
SCALAR_VEC, | |||
BCAST101_VEC, | |||
BCAST111C_VEC, | |||
BCAST101xX_VEC, | |||
VEC_VEC_VEC, | |||
VEC_VEC_SCALAR, | |||
BCAST101_VEC_BCAST101, | |||
BCAST111C_VEC_BCAST111C, | |||
BCAST101xX_VEC_BCAST101xX, | |||
VEC_BCAST101_VEC, | |||
VEC_BCAST111C_VEC, | |||
VEC_BCAST101xX_VEC, | |||
VEC_SCALAR_VEC, | |||
VEC_SCALAR_SCALAR, | |||
@@ -227,6 +231,60 @@ struct OpCallerBinary<PowOp<ctype, ctype>, VEC_BCAST101> { | |||
}; | |||
template <typename ctype> | |||
struct OpCallerBinary<PowOp<ctype, ctype>, VEC_BCAST111C> { | |||
using Op = PowOp<ctype, ctype>; | |||
static void run( | |||
const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, | |||
typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, | |||
DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { | |||
Op op(src0_dtype, src1_dtype, dst_dtype); | |||
for (size_t b = 0; b < batch; b++) { | |||
for (size_t c = 0; c < channel; c++) { | |||
size_t i = 0; | |||
const typename Op::src_ctype* src1_ptr = src1; | |||
#if MEGDNN_FIX_AARCH32_BUG | |||
// FIXME: as llvm may cause cannot select error if enable vectorize | |||
#pragma clang loop vectorize(disable) | |||
#endif | |||
for (; i < channel_stride; i++) { | |||
op(*src0, *src1_ptr, dst); | |||
src0++; | |||
src1_ptr++; | |||
dst++; | |||
} | |||
} | |||
} | |||
} | |||
}; | |||
template <typename ctype> | |||
struct OpCallerBinary<PowOp<ctype, ctype>, BCAST111C_VEC> { | |||
using Op = PowOp<ctype, ctype>; | |||
static void run( | |||
const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, | |||
typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, | |||
DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { | |||
Op op(src0_dtype, src1_dtype, dst_dtype); | |||
for (size_t b = 0; b < batch; b++) { | |||
for (size_t c = 0; c < channel; c++) { | |||
size_t i = 0; | |||
const typename Op::src_ctype* src0_ptr = src0; | |||
#if MEGDNN_FIX_AARCH32_BUG | |||
// FIXME: as llvm may cause cannot select error if enable vectorize | |||
#pragma clang loop vectorize(disable) | |||
#endif | |||
for (; i < channel_stride; i++) { | |||
op(*src0_ptr, *src1, dst); | |||
src0_ptr++; | |||
src1++; | |||
dst++; | |||
} | |||
} | |||
} | |||
} | |||
}; | |||
template <typename ctype> | |||
struct OpCallerBinary<PowOp<ctype, ctype>, SCALAR_VEC> { | |||
using Op = PowOp<ctype, ctype>; | |||
static void run( | |||
@@ -340,6 +398,84 @@ struct OpCallerBinary<Op, VEC_BCAST101> { | |||
} | |||
}; | |||
template <typename Op> | |||
struct OpCallerBinary<Op, VEC_BCAST111C> { | |||
static void run( | |||
const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, | |||
typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, | |||
DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { | |||
Op op(src0_dtype, src1_dtype, dst_dtype); | |||
ParamElemVisitor<typename Op::src_ctype> vis; | |||
for (size_t b = 0; b < batch; b++) { | |||
for (size_t c = 0; c < channel; c++) { | |||
size_t rest = channel_stride; | |||
const typename Op::src_ctype* src1_ptr = src1; | |||
while (rest >= Op::SIMD_WIDTH * 2) { | |||
auto src0_neon0 = vis(src0); | |||
auto src0_neon1 = vis(src0 + Op::SIMD_WIDTH); | |||
auto src1_neon0 = vis(src1_ptr); | |||
auto src1_neon1 = vis(src1_ptr + Op::SIMD_WIDTH); | |||
src0 += Op::SIMD_WIDTH * 2; | |||
src1_ptr += Op::SIMD_WIDTH * 2; | |||
op({{src0_neon0, src0_neon1}}, {{src1_neon0, src1_neon1}}, dst); | |||
dst += Op::SIMD_WIDTH * 2; | |||
rest -= Op::SIMD_WIDTH * 2; | |||
} | |||
#if MEGDNN_FIX_AARCH32_BUG | |||
// FIXME: as llvm may cause cannot select error if enable vectorize | |||
#pragma clang loop vectorize(disable) | |||
#endif | |||
while (rest > 0) { | |||
op(*src0, *src1_ptr, dst); | |||
dst++; | |||
src0++; | |||
src1_ptr++; | |||
rest--; | |||
} | |||
} | |||
} | |||
} | |||
}; | |||
template <typename Op> | |||
struct OpCallerBinary<Op, BCAST111C_VEC> { | |||
static void run( | |||
const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, | |||
typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, | |||
DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { | |||
Op op(src0_dtype, src1_dtype, dst_dtype); | |||
ParamElemVisitor<typename Op::src_ctype> vis; | |||
for (size_t b = 0; b < batch; b++) { | |||
for (size_t c = 0; c < channel; c++) { | |||
size_t rest = channel_stride; | |||
const typename Op::src_ctype* src0_ptr = src0; | |||
while (rest >= Op::SIMD_WIDTH * 2) { | |||
auto src0_neon0 = vis(src0_ptr); | |||
auto src0_neon1 = vis(src0_ptr + Op::SIMD_WIDTH); | |||
auto src1_neon0 = vis(src1); | |||
auto src1_neon1 = vis(src1 + Op::SIMD_WIDTH); | |||
src0_ptr += Op::SIMD_WIDTH * 2; | |||
src1 += Op::SIMD_WIDTH * 2; | |||
op({{src0_neon0, src0_neon1}}, {{src1_neon0, src1_neon1}}, dst); | |||
dst += Op::SIMD_WIDTH * 2; | |||
rest -= Op::SIMD_WIDTH * 2; | |||
} | |||
#if MEGDNN_FIX_AARCH32_BUG | |||
// FIXME: as llvm may cause cannot select error if enable vectorize | |||
#pragma clang loop vectorize(disable) | |||
#endif | |||
while (rest > 0) { | |||
op(*src0_ptr, *src1, dst); | |||
dst++; | |||
src0_ptr++; | |||
src1++; | |||
rest--; | |||
} | |||
} | |||
} | |||
} | |||
}; | |||
template <typename ctype> | |||
struct OpCallerBinary<PowOp<ctype, ctype>, BCAST101xX_VEC> { | |||
using Op = PowOp<ctype, ctype>; | |||
@@ -824,6 +960,54 @@ struct OpCallerTernary<Op, BCAST101_VEC_BCAST101> { | |||
} | |||
}; | |||
//! src0: 111C, src1: vector, src2: 111C, src1 may not be contig | |||
template <typename Op> | |||
struct OpCallerTernary<Op, BCAST111C_VEC_BCAST111C> { | |||
static void run( | |||
const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, | |||
size_t src1_offset, const typename Op::src_ctype* src2, | |||
typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, | |||
DType src2_dtype, DType dst_dtype, size_t batch_size, size_t channel_size, | |||
size_t channel_stride) { | |||
Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); | |||
ParamElemVisitor<typename Op::src_ctype> vis; | |||
for (size_t batch = 0; batch < batch_size; batch++) { | |||
for (size_t channel = 0; channel < channel_size; channel++) { | |||
auto src0_ptr = src0; | |||
auto src2_ptr = src2; | |||
size_t i = 0; | |||
for (; i + Op::SIMD_WIDTH * 2 <= channel_stride; | |||
i += Op::SIMD_WIDTH * 2) { | |||
auto src0_neon0 = vis(src0_ptr); | |||
auto src0_neon1 = vis(src0_ptr + Op::SIMD_WIDTH); | |||
auto src1_neon0 = vis(src1); | |||
auto src1_neon1 = vis(src1 + Op::SIMD_WIDTH); | |||
auto src2_neon0 = vis(src2_ptr); | |||
auto src2_neon1 = vis(src2_ptr + Op::SIMD_WIDTH); | |||
op({{src0_neon0, src0_neon1}}, {{src1_neon0, src1_neon1}}, | |||
{{src2_neon0, src2_neon1}}, dst); | |||
src0_ptr += Op::SIMD_WIDTH * 2; | |||
src1 += Op::SIMD_WIDTH * 2; | |||
src2_ptr += Op::SIMD_WIDTH * 2; | |||
dst += Op::SIMD_WIDTH * 2; | |||
} | |||
#if MEGDNN_FIX_AARCH32_BUG | |||
// FIXME: as llvm may cause cannot select error if enable vectorize | |||
#pragma clang loop vectorize(disable) | |||
#endif | |||
for (; i < channel_stride; i++) { | |||
op(*src0_ptr, *src1, *src2_ptr, dst); | |||
src0_ptr++; | |||
src1++; | |||
src2_ptr++; | |||
dst++; | |||
} | |||
src1 += src1_offset; | |||
} | |||
} | |||
} | |||
}; | |||
template <typename src_ctype, size_t channel_block_dim> | |||
struct OpCallerTernaryBcast101xXVecBcast101xX { | |||
template <typename Op> | |||
@@ -992,6 +1176,51 @@ struct OpCallerTernary<Op, VEC_BCAST101_VEC> { | |||
} | |||
}; | |||
//! src1: 111C, src0 and src2 may not be contig | |||
template <typename Op> | |||
struct OpCallerTernary<Op, VEC_BCAST111C_VEC> { | |||
static void run( | |||
const typename Op::src_ctype* src0, size_t src0_offset, | |||
const typename Op::src_ctype* src1, const typename Op::src_ctype* src2, | |||
size_t src2_offset, typename Op::dst_ctype* dst, DType src0_dtype, | |||
DType src1_dtype, DType src2_dtype, DType dst_dtype, size_t batch_size, | |||
size_t channel_size, size_t channel_stride) { | |||
Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); | |||
ParamElemVisitor<typename Op::src_ctype> vis0; | |||
ParamElemVisitor<typename Op::src_ctype> vis1; | |||
ParamElemVisitor<typename Op::src_ctype> vis2; | |||
for (size_t batch = 0; batch < batch_size; batch++) { | |||
for (size_t channel = 0; channel < channel_size; channel++) { | |||
auto src1_ptr = src1; | |||
size_t i = 0; | |||
for (; i + Op::SIMD_WIDTH * 2 <= channel_stride; | |||
i += Op::SIMD_WIDTH * 2) { | |||
op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, | |||
{{vis1(src1_ptr), vis1(src1_ptr + Op::SIMD_WIDTH)}}, | |||
{{vis2(src2), vis2(src2 + Op::SIMD_WIDTH)}}, dst); | |||
src0 += Op::SIMD_WIDTH * 2; | |||
src1_ptr += Op::SIMD_WIDTH * 2; | |||
src2 += Op::SIMD_WIDTH * 2; | |||
dst += Op::SIMD_WIDTH * 2; | |||
} | |||
#if MEGDNN_FIX_AARCH32_BUG | |||
// FIXME: as llvm may cause cannot select error if enable vectorize | |||
#pragma clang loop vectorize(disable) | |||
#endif | |||
for (; i < channel_stride; i++) { | |||
op(*src0, *src1_ptr, *src2, dst); | |||
src0++; | |||
src1_ptr++; | |||
src2++; | |||
dst++; | |||
} | |||
src0 += src0_offset; | |||
src2 += src2_offset; | |||
} | |||
} | |||
} | |||
}; | |||
template <typename src_ctype, size_t channel_block_dim> | |||
struct OpCallerTernaryVecBcast101xXVec { | |||
template <typename Op> | |||
@@ -50,6 +50,20 @@ inline dt_qint32 QConverter::convert(const float& src) { | |||
saturate<int32_t, float>(std::round(src), -2147483648, 2147483647)); | |||
} | |||
template <> | |||
inline float32x4x2_t QConverter::convert(const int16x8_t& vsrc) { | |||
int32x4_t vhi = vmovl_s16(vget_high_s16(vsrc)); | |||
int32x4_t vlo = vmovl_s16(vget_low_s16(vsrc)); | |||
return {{vcvtq_f32_s32(vlo), vcvtq_f32_s32(vhi)}}; | |||
} | |||
template <> | |||
inline float32x4x2_t QConverter::convert(const uint16x8_t& vsrc) { | |||
uint32x4_t vhi = vmovl_u16(vget_high_u16(vsrc)); | |||
uint32x4_t vlo = vmovl_u16(vget_low_u16(vsrc)); | |||
return {{vcvtq_f32_u32(vlo), vcvtq_f32_u32(vhi)}}; | |||
} | |||
#if __ARM_ARCH >= 8 | |||
template <> | |||
inline int8x8_t QConverter::convert(const float32x4x2_t& vsrc) { | |||
@@ -17,6 +17,7 @@ | |||
#include "src/common/utils.h" | |||
#include "src/naive/handle.h" | |||
MIDOUT_DECL(megdnn_arm_typecvt_fix2float) | |||
MIDOUT_DECL(megdnn_arm_typecvt_quantized) | |||
MIDOUT_DECL(megdnn_arm_typecvt_float) | |||
@@ -325,6 +326,48 @@ struct FloatTypeCvter<float, __fp16> { | |||
}; | |||
#endif | |||
template <typename ctype, typename dtype> | |||
struct Fix2FloatTypeCvter; | |||
template <> | |||
struct Fix2FloatTypeCvter<int16_t, float> { | |||
using stype = int16_t; | |||
using dst_type = float; | |||
static constexpr size_t SIMD_WIDTH = 8; | |||
Fix2FloatTypeCvter(DType src_dtype, DType dst_dtype) { | |||
MEGDNN_MARK_USED_VAR(src_dtype); | |||
MEGDNN_MARK_USED_VAR(dst_dtype); | |||
} | |||
void cvt(const int16_t* src, float* dst) { | |||
int16x8_t vitem = vld1q_s16(src); | |||
auto vres = QConverter::convert<float32x4x2_t, int16x8_t>(vitem); | |||
vst1q_f32_x2(dst, vres); | |||
} | |||
void cvt_remain(const int16_t* src, float* dst) { *dst = *src; } | |||
}; | |||
template <> | |||
struct Fix2FloatTypeCvter<uint16_t, float> { | |||
using stype = uint16_t; | |||
using dst_type = float; | |||
static constexpr size_t SIMD_WIDTH = 8; | |||
Fix2FloatTypeCvter(DType src_dtype, DType dst_dtype) { | |||
MEGDNN_MARK_USED_VAR(src_dtype); | |||
MEGDNN_MARK_USED_VAR(dst_dtype); | |||
} | |||
void cvt(const uint16_t* src, float* dst) { | |||
uint16x8_t vitem = vld1q_u16(src); | |||
auto vres = QConverter::convert<float32x4x2_t, uint16x8_t>(vitem); | |||
vst1q_f32_x2(dst, vres); | |||
} | |||
void cvt_remain(const uint16_t* src, float* dst) { *dst = *src; } | |||
}; | |||
template <typename TypeCvter> | |||
void do_typecvt( | |||
const typename TypeCvter::stype* src, typename TypeCvter::dst_type* dst, | |||
@@ -347,6 +390,43 @@ void do_typecvt( | |||
} | |||
} | |||
template <typename TypeCvter> | |||
void do_typecvt( | |||
const typename TypeCvter::stype* src, typename TypeCvter::dst_type* dst, | |||
DType src_dtype, DType dst_dtype, const TensorLayout& src_layout) { | |||
TypeCvter typecvt(src_dtype, dst_dtype); | |||
size_t calc_num = 1; | |||
size_t nr_elems = src_layout.total_nr_elems(); | |||
size_t src_stride = nr_elems; | |||
//! adjust calc_num nr_elems and src_stride according to src_collapse_layout | |||
auto src_collapse_layout = src_layout.collapse_contiguous(); | |||
if (src_collapse_layout.ndim == 2) { | |||
calc_num = src_collapse_layout.shape[0]; | |||
nr_elems = src_collapse_layout.shape[1]; | |||
src_stride = src_collapse_layout.stride[0]; | |||
} | |||
for (size_t c = 0; c < calc_num; ++c) { | |||
size_t i = 0; | |||
for (; i + TypeCvter::SIMD_WIDTH <= nr_elems; i += TypeCvter::SIMD_WIDTH) { | |||
typecvt.cvt(src, dst); | |||
src += TypeCvter::SIMD_WIDTH; | |||
dst += TypeCvter::SIMD_WIDTH; | |||
} | |||
#if MEGDNN_FIX_AARCH32_BUG | |||
// FIXME: as llvm may cause cannot select error if enable vectorize | |||
#pragma clang loop vectorize(disable) | |||
#endif | |||
for (; i < nr_elems; i++) { | |||
typecvt.cvt_remain(src, dst); | |||
src++; | |||
dst++; | |||
} | |||
src += src_stride - nr_elems; | |||
} | |||
} | |||
} // anonymous namespace | |||
void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { | |||
@@ -354,7 +434,30 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { | |||
DType dst_dtype = dst.layout.dtype; | |||
size_t nr_elems = src.layout.total_nr_elems(); | |||
bool execed = false; | |||
if (src.layout.is_contiguous()) { | |||
auto src_collapse_layout = src.layout.collapse_contiguous(); | |||
bool has_int16_special_impl = | |||
(src.layout.dtype.enumv() == DTypeEnum::Int16 || | |||
src.layout.dtype.enumv() == DTypeEnum::Uint16) && | |||
(src.layout.is_contiguous() || src_collapse_layout.ndim == 2) && | |||
dst.layout.is_contiguous(); | |||
if (has_int16_special_impl) { | |||
using namespace dtype; | |||
#define DISPATCH_FIX2FLOAT(_stype_enumv, _stype, _dtype_enumv, _dtype, _midout_iv) \ | |||
if (src_dtype.enumv() == DTypeTrait<_stype_enumv>::enumv && \ | |||
dst_dtype.enumv() == DTypeTrait<_dtype_enumv>::enumv) { \ | |||
MIDOUT_BEGIN(megdnn_arm_typecvt_fix2float, midout_iv(_midout_iv)) { \ | |||
using _TypeCvter = Fix2FloatTypeCvter<_stype, _dtype>; \ | |||
MEGDNN_DISPATCH_CPU_KERN_OPR(do_typecvt<_TypeCvter>( \ | |||
src.compatible_ptr<_stype>(), dst.compatible_ptr<_dtype>(), \ | |||
src_dtype, dst_dtype, src.layout)); \ | |||
execed = true; \ | |||
} \ | |||
MIDOUT_END(); \ | |||
} | |||
DISPATCH_FIX2FLOAT(Int16, int16_t, Float32, float, 0); | |||
DISPATCH_FIX2FLOAT(Uint16, uint16_t, Float32, float, 1); | |||
#undef DISPATCH_FIX2FLOAT | |||
} else if (src.layout.is_contiguous()) { | |||
using namespace dtype; | |||
#define DISPATCH_QUANTIZED(_stype_enumv, _stype, _dtype_enumv, _dtype, _midout_iv) \ | |||
if (src_dtype.enumv() == DTypeTrait<_stype_enumv>::enumv && \ | |||
@@ -377,6 +480,7 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { | |||
DISPATCH_QUANTIZED(QuantizedS32, int32_t, QuantizedS32, int32_t, 5); | |||
DISPATCH_QUANTIZED(float, float, QuantizedS8, int8_t, 6); | |||
DISPATCH_QUANTIZED(float, float, Quantized8Asymm, uint8_t, 7); | |||
#undef DISPATCH_QUANTIZED | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
#define DISPATCH_FLOAT(_stype_enumv, _stype, _dtype_enumv, _dtype, _midout_iv) \ | |||
@@ -394,6 +498,7 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { | |||
} | |||
DISPATCH_FLOAT(dt_float16, __fp16, float, float, 0); | |||
DISPATCH_FLOAT(float, float, dt_float16, __fp16, 1); | |||
#undef DISPATCH_FLOAT | |||
#endif | |||
} | |||
if (!execed) { | |||
@@ -150,6 +150,19 @@ bool ElemwiseLayoutHelper::is_broadcasted_channel_like( | |||
return false; | |||
} | |||
bool ElemwiseLayoutHelper::is_NHWC_broadcasted_channel_like( | |||
const TensorLayout& layout, BroadcastChannelInfo& info) { | |||
if (layout.format.type() == TensorFormat::Type::DEFAULT) { | |||
if (layout.ndim == 2 && layout.stride[1] == 1 && layout.stride[0] == 0) { | |||
info.x = 1; | |||
info.y = layout.shape[0]; | |||
info.z = layout.shape[1]; | |||
return true; | |||
} | |||
} | |||
return false; | |||
} | |||
bool ElemwiseLayoutHelper::is_broadcasted_1x( | |||
const TensorLayout& layout, Broadcast1xInfo& binfo) { | |||
if (layout.ndim == 2 && layout.stride[0] == 0 && layout.stride[1] == 1) { | |||
@@ -81,6 +81,16 @@ public: | |||
const TensorLayout& layout, BroadcastChannelInfo& info); | |||
/*! | |||
* \brief check whether layout matches BroadcastChannelInfo under NHWC | |||
* layout | |||
* | |||
* Note that Input must be 2-dimensional, and must be [1, y] broadacsted | |||
* into [z, y] and x would be set to 1. | |||
*/ | |||
static bool is_NHWC_broadcasted_channel_like( | |||
const TensorLayout& layout, BroadcastChannelInfo& info); | |||
/*! | |||
* \brief check whether layout matches BroadcastChannelInfo | |||
* | |||
* Note that Input can also be 3-dimensional, and must be [x, 1, z] | |||
@@ -309,7 +309,8 @@ void on_dest_ctype(_megdnn_tensor_in src, _megdnn_tensor_out dst) { | |||
break; \ | |||
} | |||
MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | |||
cb(::megdnn::dtype::Bool) case DTypeEnum::QuantizedS8 | |||
cb(::megdnn::dtype::Bool) | |||
cb(::megdnn::dtype::Uint16) case DTypeEnum::QuantizedS8 | |||
: MIDOUT_BEGIN( | |||
megdnn_fb_typecvt_src_dtype, | |||
midout_iv(DTypeEnum::QuantizedS8)) { | |||
@@ -467,7 +468,8 @@ void run_contiguous(_megdnn_tensor_in src, _megdnn_tensor_out dst) { | |||
} | |||
MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | |||
cb(::megdnn::dtype::Bool) case DTypeEnum::QuantizedS8 | |||
cb(::megdnn::dtype::Bool) | |||
cb(::megdnn::dtype::Uint16) case DTypeEnum::QuantizedS8 | |||
: MIDOUT_BEGIN( | |||
megdnn_fb_typecvt_dst_dtype, | |||
midout_iv(DTypeEnum::QuantizedS8)) { | |||
@@ -78,7 +78,7 @@ void on_dest_ctype(HandleImpl* handle, const TensorND& dest, const TensorND& src | |||
MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | |||
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) | |||
MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) | |||
cb(::megdnn::dtype::Bool) | |||
cb(::megdnn::dtype::Bool) cb(::megdnn::dtype::Uint16) | |||
#undef cb | |||
default : megdnn_throw("bad dtype"); | |||
} | |||
@@ -99,7 +99,7 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { | |||
MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | |||
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) | |||
MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) | |||
cb(::megdnn::dtype::Bool) | |||
cb(::megdnn::dtype::Bool) cb(::megdnn::dtype::Uint16) | |||
#undef cb | |||
default : megdnn_throw("bad dtype"); | |||
} | |||
@@ -14,6 +14,7 @@ | |||
#include "test/common/benchmarker.h" | |||
#include "test/common/checker.h" | |||
#include "megdnn/opr_param_defs.h" | |||
#include "megdnn/oprs/general.h" | |||
using namespace megdnn; | |||
@@ -298,6 +299,63 @@ TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NCHW88_FP) { | |||
#endif | |||
} | |||
TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NHWC_FP32_BCAST) { | |||
using Mode = ElemwiseForward::Param::Mode; | |||
Checker<ElemwiseForward> checker(handle()); | |||
UniformFloatRNG rng(1e-5, 7e1); | |||
checker.set_rng(0, &rng); | |||
checker.set_epsilon(1e-5); | |||
checker.set_dtype(0, dtype::Float32()); | |||
checker.set_dtype(1, dtype::Float32()); | |||
//! 2 dim | |||
auto run = [&](Mode mode) { | |||
// VEC_BCAST111C | |||
checker.set_param(mode).execs({{1, 2, 2, 12}, {1, 1, 1, 12}, {}}); | |||
checker.set_param(mode).execs({{2, 5, 3, 28}, {1, 1, 1, 28}, {}}); | |||
checker.set_param(mode).execs({{3, 5, 8, 32}, {1, 1, 1, 32}, {}}); | |||
// BCAST111C_VEC | |||
checker.set_param(mode).execs({{1, 1, 1, 12}, {1, 2, 2, 12}, {}}); | |||
checker.set_param(mode).execs({{1, 1, 1, 28}, {2, 5, 3, 28}, {}}); | |||
checker.set_param(mode).execs({{1, 1, 1, 32}, {3, 5, 8, 32}, {}}); | |||
}; | |||
run(Mode::ADD); | |||
run(Mode::MUL); | |||
run(Mode::SUB); | |||
//! 3 dim contig | |||
auto run_3d_contig = [&](Mode mode) { | |||
// BCAST111C_VEC_BCAST111C | |||
checker.set_param(mode).execs( | |||
{{1, 1, 1, 12}, {1, 2, 2, 12}, {1, 1, 1, 12}, {}}); | |||
checker.set_param(mode).execs( | |||
{{1, 1, 1, 28}, {2, 5, 3, 28}, {1, 1, 1, 28}, {}}); | |||
checker.set_param(mode).execs( | |||
{{1, 1, 1, 32}, {3, 5, 8, 32}, {1, 1, 1, 32}, {}}); | |||
// VEC_BCAST111C_VEC | |||
checker.set_param(mode).execs( | |||
{{1, 2, 2, 12}, {1, 1, 1, 12}, {1, 2, 2, 12}, {}}); | |||
checker.set_param(mode).execs( | |||
{{2, 5, 3, 28}, {1, 1, 1, 28}, {2, 5, 3, 28}, {}}); | |||
checker.set_param(mode).execs( | |||
{{3, 5, 8, 32}, {1, 1, 1, 32}, {3, 5, 8, 32}, {}}); | |||
}; | |||
run_3d_contig(Mode::FUSE_MUL_ADD3); | |||
//! 3 dim incontig | |||
auto run_3d_incontig = [&](Mode mode) { | |||
megdnn::TensorLayout src0({1, 1, 1, 12}, dtype::Float32()); | |||
megdnn::TensorLayout src1({1, 2, 2, 12}, {80, 40, 20, 1}, dtype::Float32()); | |||
// BCAST111C_VEC_BCAST111C | |||
checker.set_param(mode).execl({src0, src1, src0, {}}); | |||
// VEC_BCAST111C_VEC | |||
checker.set_param(mode).execl({src1, src0, src1, {}}); | |||
}; | |||
run_3d_incontig(Mode::FUSE_MUL_ADD3); | |||
} | |||
#if MEGDNN_WITH_BENCHMARK | |||
namespace { | |||
void run_elemwise_benchmark( | |||
@@ -354,6 +412,39 @@ void run_elemwise_benchmark( | |||
} | |||
} // namespace | |||
TEST_F(ARM_COMMON, BENCHMARK_NCHW_VS_NHWC) { | |||
Benchmarker<Elemwise> benchmarker(handle()); | |||
constexpr size_t RUN = 50; | |||
benchmarker.set_times(RUN).set_display(false); | |||
auto run = [&](size_t N, size_t C, size_t H, size_t W, param::Elemwise::Mode mode, | |||
const char* mode_name) { | |||
megdnn::param::Elemwise param; | |||
param.mode = mode; | |||
benchmarker.set_param(param); | |||
megdnn::TensorShape nhwc_src0{N, H, W, C}; | |||
megdnn::TensorShape nhwc_src1{1, 1, 1, C}; | |||
megdnn::TensorShape nchw_src0{N, C, H, W}; | |||
megdnn::TensorShape nchw_src1{1, C, 1, 1}; | |||
float computations = N * C * H * W; | |||
auto nhwc_time = benchmarker.execs({nhwc_src1, nhwc_src0, {}}) / RUN; | |||
auto nchw_time = benchmarker.execs({nchw_src1, nchw_src0, {}}) / RUN; | |||
auto perf_nhwc = computations / nhwc_time / 1e6; | |||
auto perf_nchw = computations / nchw_time / 1e6; | |||
printf("Elemwise Mode : %s\nNHWC : %fms %fGflops\nNCHW : %fms " | |||
"%fGflops\n", | |||
mode_name, nhwc_time, perf_nhwc, nchw_time, perf_nchw); | |||
}; | |||
run(1, 120, 16, 24, param::Elemwise::Mode::ADD, "ADD"); | |||
run(1, 120, 16, 24, param::Elemwise::Mode::MUL, "MUL"); | |||
run(1, 120, 32, 48, param::Elemwise::Mode::ADD, "ADD"); | |||
run(1, 120, 32, 48, param::Elemwise::Mode::MUL, "MUL"); | |||
run(1, 120, 64, 96, param::Elemwise::Mode::ADD, "ADD"); | |||
run(1, 120, 64, 96, param::Elemwise::Mode::MUL, "MUL"); | |||
} | |||
#define INT_RUN(shape, mode) \ | |||
run_elemwise_benchmark(shape, mode, #mode, dtype::Int8{}, handle()); \ | |||
run_elemwise_benchmark(shape, mode, #mode, dtype::Int16{}, handle()); \ | |||
@@ -88,6 +88,26 @@ TEST_F(ARM_COMMON, TYPE_CVT) { | |||
.execs({{1, 32, 24, 128}, {1, 32, 24, 128}}); | |||
} | |||
TEST_F(ARM_COMMON, TYPE_CVT_16_F32) { | |||
Checker<TypeCvt> checker(handle()); | |||
UniformIntRNG rng{INT16_MIN >> 1, INT16_MAX >> 1}; | |||
for (size_t size : {3, 7, 15, 33, 10000}) { | |||
checker.set_rng(0, &rng); | |||
checker.set_dtype(0, dtype::Int16()).execs({{size}, {size}}); | |||
checker.set_dtype(0, dtype::Uint16()).execs({{size}, {size}}); | |||
} | |||
TensorLayout src_int16{ | |||
{1, 96, 64, 120}, {128 * 64 * 96, 128 * 64, 128, 1}, dtype::Int16()}; | |||
TensorLayout dst_int16{{1, 96, 64, 120}, dtype::Float32()}; | |||
checker.execl({src_int16, dst_int16}); | |||
TensorLayout src_uint16{ | |||
{1, 96, 64, 120}, {128 * 64 * 96, 128 * 64, 128, 1}, dtype::Uint16()}; | |||
TensorLayout dst_uint16{{1, 96, 64, 120}, dtype::Float32()}; | |||
checker.execl({src_uint16, dst_uint16}); | |||
} | |||
#if MEGDNN_WITH_BENCHMARK | |||
TEST_F(ARM_COMMON, BENCHMARK_TYPE_CVT) { | |||
auto run = [&](const TensorShapeArray& shapes) { | |||
@@ -158,8 +158,9 @@ void copy_tensors( | |||
//! In order to avoid an unnecessary increase in binary size, we just | |||
//! use QuantizedS16 dtype in winograd_filter_preprocess now. | |||
cb(::megdnn::dtype::QuantizedS16) MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) | |||
cb(::megdnn::dtype::Uint16) | |||
#undef cb | |||
default : megdnn_trap(); | |||
default : megdnn_trap(); | |||
} | |||
} | |||
@@ -202,6 +202,9 @@ void IIDRNG::gen(const TensorND& tensor) { | |||
memset(tensor.raw_ptr, 0, tensor.layout.access_bytes()); | |||
return; | |||
} | |||
if (tensor.layout.dtype.enumv() == DTypeEnum::Uint16) { | |||
return; | |||
} | |||
megdnn_assert( | |||
0, "IIDRNG does not know how to generate value for DType %s", | |||
tensor.layout.dtype.name()); | |||
@@ -25,6 +25,11 @@ TEST_F(CUDA, TYPE_CVT) { | |||
TensorLayout src({10, 10}, sdtype), dst({10, 10}, ddtype); | |||
Checker<TypeCvt> checker(handle_cuda()); | |||
checker.set_rng(0, &init).exec(TensorLayoutArray{src, dst}); | |||
TensorLayout non_contig_src( | |||
{1, 96, 64, 120}, {96 * 64 * 128, 64 * 128, 128, 1}, sdtype); | |||
TensorLayout non_contig_dst({1, 96, 64, 120}, ddtype); | |||
checker.exec(TensorLayoutArray{non_contig_src, non_contig_dst}); | |||
} | |||
} | |||
@@ -37,8 +37,22 @@ TEST_F(X86, TYPE_CVT) { | |||
for (auto ddtype : dtypes) { | |||
checker.set_dtype(0, sdtype).set_dtype(1, ddtype).execs( | |||
{{size}, {size}}); | |||
TensorLayout non_contig_src( | |||
{1, 10, 10, 12}, {10 * 10 * 18, 10 * 18, 18, 1}, sdtype); | |||
TensorLayout non_contig_dst({1, 10, 10, 12}, ddtype); | |||
checker.exec(TensorLayoutArray{non_contig_src, non_contig_dst}); | |||
} | |||
} | |||
for (size_t size : {1, 7, 15, 33}) { | |||
checker.set_dtype(0, dtype::Uint16()) | |||
.set_dtype(1, dtype::Float32()) | |||
.execs({{size}, {size}}); | |||
} | |||
TensorLayout non_contig_src( | |||
{1, 10, 10, 12}, {10 * 10 * 18, 10 * 18, 18, 1}, dtype::Uint16()); | |||
TensorLayout non_contig_dst({1, 10, 10, 12}, dtype::Float32()); | |||
checker.exec(TensorLayoutArray{non_contig_src, non_contig_dst}); | |||
} | |||
TEST_F(X86, TYPE_CVT_NO_CONTIGUOUS) { | |||
@@ -772,8 +772,10 @@ void TypeCvt::perform( | |||
} | |||
void TypeCvt::add_input_layout_constraint() { | |||
//! Because the implementation of typecvt on arm/x86/cuda/opencl support | |||
//! non-contiguous memory. So we change constraint of typecvt to monotone | |||
for (auto i : input()) { | |||
i->add_layout_constraint_contiguous(); | |||
i->add_layout_constraint_monotone(); | |||
} | |||
} | |||