/**
 * \file dnn/src/arm_common/elemwise/binary/algo.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
#include "src/arm_common/elemwise/binary/algo.h"
#include "src/arm_common/elemwise_op.h"

#include "src/common/utils.h"
#include "src/naive/handle.h"

#include "midout.h"

MIDOUT_DECL(megdnn_arm_common_elemwise_binary)

using namespace megdnn;
using namespace arm_common;

namespace {
static inline bool is_available_common(Elemwise::Mode mode) {
    /**
     * Fused sigmoid & tanh may be slower than the naive algo, because the
     * time taken by the neon function `exp_ps_f32` depends on its input.
     */
    if (mode == Elemwise::Mode::FUSE_ADD_SIGMOID ||
        mode == Elemwise::Mode::FUSE_ADD_TANH) {
        return false;
    }

    return true;
}
}  // anonymous namespace
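
//! Note: the mode whitelists below differ between AArch64 and ARMv7:
//! TRUE_DIV is only accepted on AArch64, presumably because AArch64 NEON
//! provides a vector floating-point divide (vdivq_f32) while ARMv7 NEON
//! does not.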
#if MEGDNN_AARCH64
#define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id)             \
    auto mode = kern_param.mode;                                       \
    if (mode == Mode::MIN || mode == Mode::MAX || mode == Mode::ADD || \
        mode == Mode::SUB || mode == Mode::MUL || mode == Mode::POW || \
        mode == Mode::TRUE_DIV || mode == Mode::FUSE_ADD_RELU ||       \
        mode == Mode::FUSE_ADD_H_SWISH)                                \
        return true;
#else
#define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id)             \
    auto mode = kern_param.mode;                                       \
    if (mode == Mode::MIN || mode == Mode::MAX || mode == Mode::ADD || \
        mode == Mode::SUB || mode == Mode::MUL || mode == Mode::POW || \
        mode == Mode::FUSE_ADD_RELU || mode == Mode::FUSE_ADD_H_SWISH) \
        return true;
#endif

#define DISPATCH_MODE_INT(_case, _type, _type_midout_id)                 \
    auto mode = kern_param.mode;                                         \
    if (mode == Mode::MIN || mode == Mode::MAX || mode == Mode::ADD ||   \
        mode == Mode::SUB || mode == Mode::MUL || mode == Mode::RMULH || \
        mode == Mode::FUSE_ADD_RELU)                                     \
        return true;

bool ElemwiseImpl::AlgoBinaryVecVec::is_available(
        const KernParam& kern_param) const {
    if (!is_available_common(kern_param.mode) ||
        (BcastType::VEC_VEC != kern_param.broad_cast_type))
        return false;

    auto& elparam = kern_param.binary_elparam;
    auto& src0 = elparam[0];

    //! exactly match [x, y] + [x, y]
    DISPATCH_TYPE("AlgoBinaryVecVec::is_available"_hash);

    return false;
}

bool ElemwiseImpl::AlgoBinaryVecScalar::is_available(
        const KernParam& kern_param) const {
    if (!is_available_common(kern_param.mode) ||
        ((BcastType::VEC_SCALAR != kern_param.broad_cast_type) &&
         (BcastType::SCALAR_VEC != kern_param.broad_cast_type)))
        return false;

    auto& elparam = kern_param.binary_elparam;
    auto& src0 = elparam[0];

    DISPATCH_TYPE("AlgoBinaryVecScalar::is_available"_hash);

    return false;
}

bool ElemwiseImpl::AlgoBinaryVecBcast101::is_available(
        const KernParam& kern_param) const {
    if (!is_available_common(kern_param.mode) ||
        ((BcastType::VEC_BCAST101 != kern_param.broad_cast_type) &&
         (BcastType::BCAST101_VEC != kern_param.broad_cast_type)))
        return false;

    auto& elparam = kern_param.binary_elparam;
    auto& src0 = elparam[0];

    DISPATCH_TYPE("AlgoBinaryVecBcast101::is_available"_hash);

    return false;
}

bool ElemwiseImpl::AlgoBinaryVecBcast101x4::is_available(
        const KernParam& kern_param) const {
    if (!is_available_common(kern_param.mode) ||
        ((BcastType::VEC_BCAST101x4 != kern_param.broad_cast_type) &&
         (BcastType::BCAST101x4_VEC != kern_param.broad_cast_type)))
        return false;

    auto& elparam = kern_param.binary_elparam;
    auto& src0 = elparam[0];

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    if (DNN_FLOAT16_SELECT(src0.layout.dtype == dtype::Float16{}, false)) {
        return false;
    }
#endif

    DISPATCH_TYPE("AlgoBinaryVecBcast101x::is_available"_hash);

    return false;
}

#undef DISPATCH_MODE_FLOAT
#undef DISPATCH_MODE_INT

#if MEGDNN_AARCH64
#define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id)                   \
    switch (kern_param.mode) {                                               \
        DISPATCH_BINARY(MIN, _case, _type, _type_midout_id, MinOp);          \
        DISPATCH_BINARY(MAX, _case, _type, _type_midout_id, MaxOp);          \
        DISPATCH_BINARY(ADD, _case, _type, _type_midout_id, AddOp);          \
        DISPATCH_BINARY(SUB, _case, _type, _type_midout_id, SubOp);          \
        DISPATCH_BINARY(MUL, _case, _type, _type_midout_id, MulOp);          \
        DISPATCH_BINARY(POW, _case, _type, _type_midout_id, PowOp);          \
        DISPATCH_BINARY(TRUE_DIV, _case, _type, _type_midout_id, TrueDivOp); \
        DISPATCH_BINARY(FUSE_ADD_RELU, _case, _type, _type_midout_id,        \
                        FuseAddReluOp);                                      \
        DISPATCH_BINARY(FUSE_ADD_H_SWISH, _case, _type, _type_midout_id,     \
                        FuseAddHSwishOp);                                    \
        default:                                                             \
            megdnn_throw(ssprintf("No available algo found for: %d",         \
                                  static_cast<int>(kern_param.mode)));       \
    }
#else
#define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id)                   \
    switch (kern_param.mode) {                                               \
        DISPATCH_BINARY(MIN, _case, _type, _type_midout_id, MinOp);          \
        DISPATCH_BINARY(MAX, _case, _type, _type_midout_id, MaxOp);          \
        DISPATCH_BINARY(ADD, _case, _type, _type_midout_id, AddOp);          \
        DISPATCH_BINARY(SUB, _case, _type, _type_midout_id, SubOp);          \
        DISPATCH_BINARY(MUL, _case, _type, _type_midout_id, MulOp);          \
        DISPATCH_BINARY(POW, _case, _type, _type_midout_id, PowOp);          \
        DISPATCH_BINARY(FUSE_ADD_RELU, _case, _type, _type_midout_id,        \
                        FuseAddReluOp);                                      \
        DISPATCH_BINARY(FUSE_ADD_H_SWISH, _case, _type, _type_midout_id,     \
                        FuseAddHSwishOp);                                    \
        default:                                                             \
            megdnn_throw(ssprintf("No available algo found for: %d",         \
                                  static_cast<int>(kern_param.mode)));       \
    }
#endif

#define DISPATCH_MODE_INT(_case, _type, _type_midout_id)                \
    switch (kern_param.mode) {                                          \
        DISPATCH_BINARY(MIN, _case, _type, _type_midout_id, MinOp);     \
        DISPATCH_BINARY(MAX, _case, _type, _type_midout_id, MaxOp);     \
        DISPATCH_BINARY(ADD, _case, _type, _type_midout_id, AddOp);     \
        DISPATCH_BINARY(SUB, _case, _type, _type_midout_id, SubOp);     \
        DISPATCH_BINARY(MUL, _case, _type, _type_midout_id, MulOp);     \
        DISPATCH_BINARY(RMULH, _case, _type, _type_midout_id, RmulhOp); \
        DISPATCH_BINARY(FUSE_ADD_RELU, _case, _type, _type_midout_id,   \
                        FuseAddReluOp);                                 \
        default:                                                        \
            megdnn_throw(ssprintf("No available algo found for: %d",    \
                                  static_cast<int>(kern_param.mode)));  \
    }
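
//! Illustrative sketch only (not actual preprocessor output): DISPATCH_TYPE,
//! defined in the elemwise dispatch headers, selects DISPATCH_MODE_FLOAT or
//! DISPATCH_MODE_INT based on the source dtype and forwards a concrete
//! _type. Assuming a float32 VEC_VEC input, the Mode::ADD case in
//! AlgoBinaryVecVec::exec below then expands to roughly:
//!
//!     case Mode::ADD: {
//!         thin_function<void(const float*, const float*, float*, DType,
//!                            DType, DType, size_t)>
//!                 run = OpCallerBinary<AddOp<float, float>,
//!                                      BcastType::VEC_VEC>::run;
//!         MEGDNN_DISPATCH_CPU_KERN(
//!                 static_cast<naive::HandleImpl*>(kern_param.handle),
//!                 run(/* src0, src1, dst pointers and dtypes ... */
//!                     src0.layout.total_nr_elems()));
//!         return;
//!     }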

void ElemwiseImpl::AlgoBinaryVecVec::exec(const KernParam& kern_param) const {
    auto& elparam = kern_param.binary_elparam;
    auto &src0 = elparam[0], &src1 = elparam[1];

    //! exactly match [x, y] + [x, y]
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op)           \
    case Mode::_mode:                                                        \
        MIDOUT_BEGIN(megdnn_arm_common_elemwise_binary, midout_iv(_case),    \
                     midout_iv(Mode::_mode), _type_midout_id) {              \
            thin_function<void(const _type*, const _type*, _type*, DType,    \
                               DType, DType, size_t)>                        \
                    run = OpCallerBinary<_op<_type, _type>,                  \
                                         BcastType::VEC_VEC>::run;           \
            MEGDNN_DISPATCH_CPU_KERN(                                        \
                    static_cast<naive::HandleImpl*>(kern_param.handle),      \
                    run(static_cast<const _type*>(src0.raw_ptr),             \
                        static_cast<const _type*>(src1.raw_ptr),             \
                        static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
                        src1.layout.dtype, dst.layout.dtype,                 \
                        src0.layout.total_nr_elems()));                      \
        }                                                                    \
        MIDOUT_END();                                                        \
        return

    auto&& dst = *(kern_param.m_dst);
    DISPATCH_TYPE("AlgoBinaryVecVec::exec"_hash);
#undef DISPATCH_BINARY

    return;
}

void ElemwiseImpl::AlgoBinaryVecScalar::exec(
        const KernParam& kern_param) const {
    auto& elparam = kern_param.binary_elparam;
    auto &src0 = elparam[0], &src1 = elparam[1];
    auto&& dst = *(kern_param.m_dst);

    // Case 2: vector + scalar
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op)           \
    case Mode::_mode:                                                        \
        MIDOUT_BEGIN(megdnn_arm_common_elemwise_binary, midout_iv(_case),    \
                     midout_iv(Mode::_mode), _type_midout_id) {              \
            thin_function<void(const _type*, const _type, _type*, DType,     \
                               DType, DType, size_t)>                        \
                    run = OpCallerBinary<_op<_type, _type>,                  \
                                         BcastType::VEC_SCALAR>::run;        \
            MEGDNN_DISPATCH_CPU_KERN(                                        \
                    static_cast<naive::HandleImpl*>(kern_param.handle),      \
                    run(static_cast<const _type*>(src0.raw_ptr),             \
                        static_cast<const _type*>(src1.raw_ptr)[0],          \
                        static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
                        src1.layout.dtype, dst.layout.dtype,                 \
                        src0.layout.total_nr_elems()));                      \
        }                                                                    \
        MIDOUT_END();                                                        \
        return

    if (BcastType::VEC_SCALAR == kern_param.broad_cast_type) {
        DISPATCH_TYPE("AlgoBinaryVecScalar::exec_vec_sca"_hash);
    }
#undef DISPATCH_BINARY

    // scalar + vector
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op)           \
    case Mode::_mode:                                                        \
        MIDOUT_BEGIN(megdnn_arm_common_elemwise_binary, midout_iv(_case),    \
                     midout_iv(Mode::_mode), _type_midout_id) {              \
            thin_function<void(const _type, const _type*, _type*, DType,     \
                               DType, DType, size_t)>                        \
                    run = OpCallerBinary<_op<_type, _type>,                  \
                                         BcastType::SCALAR_VEC>::run;        \
            MEGDNN_DISPATCH_CPU_KERN(                                        \
                    static_cast<naive::HandleImpl*>(kern_param.handle),      \
                    run(static_cast<const _type*>(src0.raw_ptr)[0],          \
                        static_cast<const _type*>(src1.raw_ptr),             \
                        static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
                        src1.layout.dtype, dst.layout.dtype,                 \
                        src1.layout.total_nr_elems()));                      \
        }                                                                    \
        MIDOUT_END();                                                        \
        return

    if (BcastType::SCALAR_VEC == kern_param.broad_cast_type) {
        DISPATCH_TYPE("AlgoBinaryVecScalar::exec_sca_vec"_hash);
    }
#undef DISPATCH_BINARY
    return;
}
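
//! Worked example of the channel-like broadcast handled below (illustrative
//! only, assumed NCHW layout): adding a per-channel bias of shape
//! (1, C, 1, 1) to a tensor of shape (N, C, H, W).
//! is_broadcasted_channel_like() views the vector operand as (x, y, z),
//! presumably with x = N, y = C and z = H * W; these are the binfo.x/y/z
//! values forwarded to OpCallerBinary.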

void ElemwiseImpl::AlgoBinaryVecBcast101::exec(
        const KernParam& kern_param) const {
    auto& elparam = kern_param.binary_elparam;
    auto &src0 = elparam[0], &src1 = elparam[1];
    auto&& dst = *(kern_param.m_dst);
    BroadcastChannelInfo binfo;

    // Case 3: BcastType::VEC + BCAST_101
    if (BcastType::VEC_BCAST101 == kern_param.broad_cast_type &&
        is_broadcasted_channel_like(src1.layout, binfo)) {
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op)           \
    case Mode::_mode:                                                        \
        MIDOUT_BEGIN(megdnn_arm_common_elemwise_binary, midout_iv(_case),    \
                     midout_iv(Mode::_mode), _type_midout_id) {              \
            thin_function<void(const _type*, const _type*, _type*, DType,    \
                               DType, DType, size_t, size_t, size_t)>        \
                    run = OpCallerBinary<_op<_type, _type>,                  \
                                         BcastType::VEC_BCAST101>::run;      \
            MEGDNN_DISPATCH_CPU_KERN(                                        \
                    static_cast<naive::HandleImpl*>(kern_param.handle),      \
                    run(static_cast<const _type*>(src0.raw_ptr),             \
                        static_cast<const _type*>(src1.raw_ptr),             \
                        static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
                        src1.layout.dtype, dst.layout.dtype, binfo.x,        \
                        binfo.y, binfo.z));                                  \
        }                                                                    \
        MIDOUT_END();                                                        \
        return

        DISPATCH_TYPE("AlgoBinaryVecBcast101::exec_vec_b"_hash);
#undef DISPATCH_BINARY
    }

    // BCAST_101 + BcastType::VEC
    if (BcastType::BCAST101_VEC == kern_param.broad_cast_type &&
        is_broadcasted_channel_like(src0.layout, binfo)) {
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op)           \
    case Mode::_mode:                                                        \
        MIDOUT_BEGIN(megdnn_arm_common_elemwise_binary, midout_iv(_case),    \
                     midout_iv(Mode::_mode), _type_midout_id) {              \
            thin_function<void(const _type*, const _type*, _type*, DType,    \
                               DType, DType, size_t, size_t, size_t)>        \
                    run = OpCallerBinary<_op<_type, _type>,                  \
                                         BcastType::BCAST101_VEC>::run;      \
            MEGDNN_DISPATCH_CPU_KERN(                                        \
                    static_cast<naive::HandleImpl*>(kern_param.handle),      \
                    run(static_cast<const _type*>(src0.raw_ptr),             \
                        static_cast<const _type*>(src1.raw_ptr),             \
                        static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
                        src1.layout.dtype, dst.layout.dtype, binfo.x,        \
                        binfo.y, binfo.z));                                  \
        }                                                                    \
        MIDOUT_END();                                                        \
        return

        DISPATCH_TYPE("AlgoBinaryVecBcast101::exec_b_vec"_hash);
#undef DISPATCH_BINARY
    }
    return;
}

void ElemwiseImpl::AlgoBinaryVecBcast101x4::exec(
        const KernParam& kern_param) const {
    auto& elparam = kern_param.binary_elparam;
    auto &src0 = elparam[0], &src1 = elparam[1];
    auto&& dst = *(kern_param.m_dst);
    BroadcastChannelInfo binfo;

    // BcastType::VEC + BCAST_101x
    if (BcastType::VEC_BCAST101x4 == kern_param.broad_cast_type &&
        is_broadcastedx_channel_like<4>(src1.layout, binfo)) {
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op)           \
    case Mode::_mode:                                                        \
        MIDOUT_BEGIN(megdnn_arm_common_elemwise_binary, midout_iv(_case),    \
                     midout_iv(Mode::_mode), _type_midout_id) {              \
            thin_function<void(const _type*, const _type*, _type*, DType,    \
                               DType, DType, size_t, size_t, size_t,         \
                               size_t)>                                      \
                    run = OpCallerBinary<_op<_type, _type>,                  \
                                         BcastType::VEC_BCAST101x4>::run;    \
            MEGDNN_DISPATCH_CPU_KERN(                                        \
                    static_cast<naive::HandleImpl*>(kern_param.handle),      \
                    run(static_cast<const _type*>(src0.raw_ptr),             \
                        static_cast<const _type*>(src1.raw_ptr),             \
                        static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
                        src1.layout.dtype, dst.layout.dtype, batch_size,     \
                        binfo.x, binfo.y, binfo.z));                         \
        }                                                                    \
        MIDOUT_END();                                                        \
        return

        size_t batch_size =
                src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z);
        DISPATCH_TYPE("AlgoBinaryVecBcast101x::exec_vec_b"_hash);
#undef DISPATCH_BINARY
    }

    // BCAST_101x + BcastType::VEC
    if (BcastType::BCAST101x4_VEC == kern_param.broad_cast_type &&
        is_broadcastedx_channel_like<4>(src0.layout, binfo)) {
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op)           \
    case Mode::_mode:                                                        \
        MIDOUT_BEGIN(megdnn_arm_common_elemwise_binary, midout_iv(_case),    \
                     midout_iv(Mode::_mode), _type_midout_id) {              \
            thin_function<void(const _type*, const _type*, _type*, DType,    \
                               DType, DType, size_t, size_t, size_t,         \
                               size_t)>                                      \
                    run = OpCallerBinary<_op<_type, _type>,                  \
                                         BcastType::BCAST101x4_VEC>::run;    \
            MEGDNN_DISPATCH_CPU_KERN(                                        \
                    static_cast<naive::HandleImpl*>(kern_param.handle),      \
                    run(static_cast<const _type*>(src0.raw_ptr),             \
                        static_cast<const _type*>(src1.raw_ptr),             \
                        static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
                        src1.layout.dtype, dst.layout.dtype, batch_size,     \
                        binfo.x, binfo.y, binfo.z));                         \
        }                                                                    \
        MIDOUT_END();                                                        \
        return

        size_t batch_size =
                src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z);
        DISPATCH_TYPE("AlgoBinaryVecBcast101x::exec_b_vec"_hash);
#undef DISPATCH_BINARY
    }
    return;
}

#undef DISPATCH_MODE_FLOAT
#undef DISPATCH_MODE_INT

// vim: syntax=cpp.doxygen