@@ -6,7 +6,8 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 #pragma once
@@ -22,7 +23,6 @@ MIDOUT_DECL(arm_common_conv_bias_postprocess_helper)
 namespace {
 #define CONCAT_OP(_name) megdnn::arm_common::_name
 #define CONCAT_NL(_name) megdnn::NonlineMode::_name
@@ -57,9 +57,9 @@ namespace {
                 reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \
                 dst_type, N, OC, OH* OW);
-#define FOR_NONLINEAR_BINARY_BROADCAST_NCHW44(_op) \
+#define FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX(_op) \
     megdnn::arm_common::OpCallerBinary<_op<ctype>, \
-                                       megdnn::arm_common::VEC_BCAST101x4>:: \
+                                       megdnn::arm_common::VEC_BCAST101xX>:: \
             run(static_cast<ctype*>(conv_dst_ptr), \
                 reinterpret_cast<const ctype*>(bias_ptr), \
                 reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \
@@ -86,9 +86,9 @@ namespace {
             if (pack_oc_size == 1) { \
                 FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST); \
             } else { \
-                megdnn_assert(pack_oc_size == 4, \
-                              "Only support nchw44 in ARM"); \
-                FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST_NCHW44); \
+                megdnn_assert(pack_oc_size == 4 || pack_oc_size == 8, \
+                              "Only support nchw44/nchw88 in ARM"); \
+                FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX); \
             } \
         } \
         MIDOUT_END(); \
@@ -100,7 +100,7 @@ namespace {
         MIDOUT_END(); \
         break; \
     default: \
-        megdnn_throw("unknow biasmode"); \
+        megdnn_throw("unknow biasmode"); \
         break; \
     }
@@ -160,7 +160,7 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> {
 #undef FOR_NONLINEAR_UNARY
 #undef FOR_NONLINEAR_BINARY_BROADCAST
-#undef FOR_NONLINEAR_BINARY_BROADCAST_NCHW44
+#undef FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX
 #undef FOR_NONLINEAR_BINARY
 #undef FOR_NONLINEAR_NOBIAS
 #undef FOR_NONLINEAR
@@ -183,16 +183,24 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> {
                 reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, \
                 dst_type, N, OC, OH* OW);
-#define FOR_NONLINEAR_BINARY_BROADCAST_NCHW44(_op) \
+#define FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX(_op) \
+    megdnn::arm_common::OpCallerBinary<_op<opctype, opdtype>, \
+                                       megdnn::arm_common::VEC_BCAST101xX>:: \
+            run(static_cast<opctype*>(conv_dst_ptr), \
+                reinterpret_cast<const opctype*>(bias_ptr), \
+                reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, \
+                dst_type, N, OC, OH* OW, pack_oc_size);
+#define FOR_NONLINEAR_BINARY_BROADCAST_NCHW88(_op) \
     megdnn::arm_common::OpCallerBinary<_op<opctype, opdtype>, \
-                                       megdnn::arm_common::VEC_BCAST101x4>:: \
+                                       megdnn::arm_common::VEC_BCAST101xX>:: \
             run(static_cast<opctype*>(conv_dst_ptr), \
                 reinterpret_cast<const opctype*>(bias_ptr), \
                 reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, \
                 dst_type, N, OC, OH* OW, pack_oc_size);
-#define HANDLE_IDENTITY(_caller, _op) \
-    case megdnn::NonlineMode::IDENTITY: \
+#define HANDLE_IDENTITY(_caller, _op) \
+    case megdnn::NonlineMode::IDENTITY: \
         _caller(_op) break;
 #define FOR_NONLINEAR(_caller) \
@@ -220,9 +228,9 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> {
             if (pack_oc_size == 1) { \
                 FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST); \
             } else { \
-                megdnn_assert(pack_oc_size == 4, \
-                              "Only support nchw44 in ARM"); \
-                FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST_NCHW44); \
+                megdnn_assert(pack_oc_size == 4 || pack_oc_size == 8, \
+                              "Only support nchw44/nchw88 in ARM"); \
+                FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX); \
             } \
             break; \
         default: \
@@ -230,9 +238,9 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> {
             if (pack_oc_size == 1) { \
                 FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST); \
             } else { \
-                megdnn_assert(pack_oc_size == 4, \
-                              "Only support nchw44 in ARM"); \
-                FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST_NCHW44); \
+                megdnn_assert(pack_oc_size == 4 || pack_oc_size == 8, \
+                              "Only support nchw44/nchw88 in ARM"); \
+                FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX); \
            } \
            break; \
    } \
@@ -254,7 +262,7 @@ struct PostProcess<opctype, opdtype, megdnn::PostprocessMode::QUANTIZED> {
 #undef FOR_NONLINEAR_UNARY
 #undef FOR_NONLINEAR_BINARY_BROADCAST
-#undef FOR_NONLINEAR_BINARY_BROADCAST_NCHW44
+#undef FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX
 #undef FOR_NONLINEAR_BINARY
 #undef FOR_NONLINEAR_NOBIAS
 #undef FOR_NONLINEAR
@@ -268,9 +276,9 @@ struct PostProcess<opctype, opdtype, megdnn::PostprocessMode::QUANTIZED> {
                 reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \
                 dst_type, N, OC, OH* OW);
-#define FOR_BINARY_BROADCAST_NCHW44(_op) \
+#define FOR_BINARY_BROADCAST_NCHWXX(_op) \
     megdnn::arm_common::OpCallerBinary<_op<ctype>, \
-                                       megdnn::arm_common::VEC_BCAST101x4>:: \
+                                       megdnn::arm_common::VEC_BCAST101xX>:: \
             run(static_cast<ctype*>(conv_dst_ptr), \
                 reinterpret_cast<const ctype*>(bias_ptr), \
                 reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \
@@ -284,25 +292,25 @@ struct PostProcess<opctype, opdtype, megdnn::PostprocessMode::QUANTIZED> {
                 reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \
                 dst_type, N* OC* OH* OW* pack_oc_size);
-#define FOR_BIAS(_bias_mode, OH, OW) \
-    switch (_bias_mode) { \
-        case megdnn::BiasMode::NO_BIAS: \
-            break; \
-        case megdnn::BiasMode::BROADCAST_CHANNEL_BIAS: \
-            if (pack_oc_size == 1) { \
-                FOR_BINARY_BROADCAST(CONCAT_OP(AddOp)); \
-            } else { \
-                megdnn_assert(pack_oc_size == 4, \
-                              "Only support nchw44 in ARM"); \
-                FOR_BINARY_BROADCAST_NCHW44(CONCAT_OP(AddOp)); \
-            } \
-            break; \
-        case megdnn::BiasMode::BIAS: \
-            FOR_BINARY(CONCAT_OP(AddOp)); \
-            break; \
-        default: \
-            megdnn_throw("unknow biasmode"); \
-            break; \
+#define FOR_BIAS(_bias_mode, OH, OW) \
+    switch (_bias_mode) { \
+        case megdnn::BiasMode::NO_BIAS: \
+            break; \
+        case megdnn::BiasMode::BROADCAST_CHANNEL_BIAS: \
+            if (pack_oc_size == 1) { \
+                FOR_BINARY_BROADCAST(CONCAT_OP(AddOp)); \
+            } else { \
+                megdnn_assert(pack_oc_size == 4 || pack_oc_size == 8, \
+                              "Only support nchw44/nchw88 in ARM"); \
+                FOR_BINARY_BROADCAST_NCHWXX(CONCAT_OP(AddOp)); \
+            } \
+            break; \
+        case megdnn::BiasMode::BIAS: \
+            FOR_BINARY(CONCAT_OP(AddOp)); \
+            break; \
+        default: \
+            megdnn_throw("unknow biasmode"); \
+            break; \
     }
 template <typename ctype, typename dtype>
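Side note on the FOR_BIAS change above: for BROADCAST_CHANNEL_BIAS the macro keeps the plain per-channel path when pack_oc_size is 1 and now routes both 4 and 8 to the VEC_BCAST101xX caller, which receives pack_oc_size as the channel block dim. A minimal standalone sketch of that dispatch shape, with invented helper names (not the MegDNN macros):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdio>

// Hypothetical stand-ins for the callers the macro would select.
void add_bias_broadcast_plain() { std::puts("per-channel bias, pack_oc_size == 1"); }
void add_bias_broadcast_nchwxx(std::size_t pack) {
    std::printf("per-channel bias, channel block dim %zu\n", pack);
}
void add_bias_elementwise() { std::puts("full-tensor bias"); }

enum class BiasMode { NO_BIAS, BROADCAST_CHANNEL_BIAS, BIAS };

// Sketch of the dispatch after the patch: pack_oc_size picks the
// nchw44/nchw88 broadcast path, the other bias modes are unchanged.
void for_bias(BiasMode mode, std::size_t pack_oc_size) {
    switch (mode) {
        case BiasMode::NO_BIAS:
            break;
        case BiasMode::BROADCAST_CHANNEL_BIAS:
            if (pack_oc_size == 1) {
                add_bias_broadcast_plain();
            } else {
                assert(pack_oc_size == 4 || pack_oc_size == 8);
                add_bias_broadcast_nchwxx(pack_oc_size);
            }
            break;
        case BiasMode::BIAS:
            add_bias_elementwise();
            break;
    }
}

int main() {
    for_bias(BiasMode::BROADCAST_CHANNEL_BIAS, 8);  // takes the new nchw88 branch
}
```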
@@ -318,7 +326,7 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::ADD_BIAS> {
 };
 #undef FOR_BINARY_BROADCAST
-#undef FOR_BINARY_BROADCAST_NCHW44
+#undef FOR_BINARY_BROADCAST_NCHWXX
 #undef FOR_BINARY
 #undef FOR_BIAS
 #undef CB
@@ -105,25 +105,21 @@ bool ElemwiseImpl::AlgoBinaryVecBcast101::is_available(
     return false;
 }
-bool ElemwiseImpl::AlgoBinaryVecBcast101x4::is_available(
+bool ElemwiseImpl::AlgoBinaryVecBcast101xX::is_available(
         const KernParam& kern_param) const {
     if (!is_available_common(kern_param.mode) ||
-        ((BcastType::VEC_BCAST101x4 != kern_param.broad_cast_type) &&
-         (BcastType::BCAST101x4_VEC != kern_param.broad_cast_type)))
+        ((BcastType::VEC_BCAST101xX != kern_param.broad_cast_type) &&
+         (BcastType::BCAST101xX_VEC != kern_param.broad_cast_type)))
         return false;
     auto& elparam = kern_param.binary_elparam;
     auto& src0 = elparam[0];
-#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-    if (DNN_FLOAT16_SELECT(src0.layout.dtype == dtype::Float16{}, false)) {
-        return false;
-    }
-#endif
-    DISPATCH_TYPE("AlgoBinaryVecBcast101x::is_available"_hash);
+    DISPATCH_TYPE("AlgoBinaryVecBcast101xX::is_available"_hash);
     return false;
 }
 #undef DISPATCH_MODE_FLOAT
 #undef DISPATCH_MODE_INT
@@ -334,16 +330,19 @@ void ElemwiseImpl::AlgoBinaryVecBcast101::exec(
     return;
 }
-void ElemwiseImpl::AlgoBinaryVecBcast101x4::exec(
+void ElemwiseImpl::AlgoBinaryVecBcast101xX::exec(
        const KernParam& kern_param) const {
    auto& elparam = kern_param.binary_elparam;
    auto &src0 = elparam[0], &src1 = elparam[1];
    auto&& dst = *(kern_param.m_dst);
    BroadcastChannelInfo binfo;
-    // BcastType::VEC + BCAST_101x
-    if (BcastType::VEC_BCAST101x4 == kern_param.broad_cast_type &&
-        is_broadcastedx_channel_like<4>(src1.layout, binfo)) {
+    // BcastType::VEC + BCAST_101X
+    if (BcastType::VEC_BCAST101xX == kern_param.broad_cast_type) {
+        megdnn_assert(
+                is_broadcastedx_channel_like<4>(src1.layout, binfo) ||
+                        is_broadcastedx_channel_like<8>(src1.layout, binfo),
+                "only nchw44 and nchw88 supported");
 #define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \
     case Mode::_mode: \
         MIDOUT_BEGIN(megdnn_arm_common_elemwise_binary, midout_iv(_case), \
@@ -351,7 +350,7 @@ void ElemwiseImpl::AlgoBinaryVecBcast101x4::exec(
         thin_function<void(const _type*, const _type*, _type*, DType, \
                            DType, DType, size_t, size_t, size_t, size_t)> \
                 run = OpCallerBinary<_op<_type, _type>, \
-                                     BcastType::VEC_BCAST101x4>::run; \
+                                     BcastType::VEC_BCAST101xX>::run; \
         MEGDNN_DISPATCH_CPU_KERN( \
                 static_cast<naive::HandleImpl*>(kern_param.handle), \
                 run(static_cast<const _type*>(src0.raw_ptr), \
@@ -362,17 +361,19 @@ void ElemwiseImpl::AlgoBinaryVecBcast101x4::exec(
         } \
         MIDOUT_END(); \
         return
         size_t batch_size =
                 src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z);
-        DISPATCH_TYPE("AlgoBinaryVecBcast101x::exec_vec_b"_hash);
+        DISPATCH_TYPE("AlgoBinaryVecBcast101xX::exec_vec_b"_hash);
 #undef DISPATCH_BINARY
     }
     // BCAST_101x + BcastType::VEC
-    if (BcastType::BCAST101x4_VEC == kern_param.broad_cast_type &&
-        is_broadcastedx_channel_like<4>(src0.layout, binfo)) {
+    if (BcastType::BCAST101xX_VEC == kern_param.broad_cast_type) {
+        megdnn_assert(
+                is_broadcastedx_channel_like<4>(src0.layout, binfo) ||
+                        is_broadcastedx_channel_like<8>(src0.layout, binfo),
+                "only nchw44 and nchw88 supported");
 #define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \
     case Mode::_mode: \
         MIDOUT_BEGIN(megdnn_arm_common_elemwise_binary, midout_iv(_case), \
@@ -380,7 +381,7 @@ void ElemwiseImpl::AlgoBinaryVecBcast101x4::exec(
         thin_function<void(const _type*, const _type*, _type*, DType, \
                            DType, DType, size_t, size_t, size_t, size_t)> \
                 run = OpCallerBinary<_op<_type, _type>, \
-                                     BcastType::BCAST101x4_VEC>::run; \
+                                     BcastType::BCAST101xX_VEC>::run; \
         MEGDNN_DISPATCH_CPU_KERN( \
                 static_cast<naive::HandleImpl*>(kern_param.handle), \
                 run(static_cast<const _type*>(src0.raw_ptr), \
@@ -394,12 +395,13 @@ void ElemwiseImpl::AlgoBinaryVecBcast101x4::exec(
         size_t batch_size =
                 src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z);
-        DISPATCH_TYPE("AlgoBinaryVecBcast101x::exec_b_vec"_hash);
+        DISPATCH_TYPE("AlgoBinaryVecBcast101xX::exec_b_vec"_hash);
 #undef DISPATCH_BINARY
     }
     return;
 }
 #undef DISPATCH_MODE_FLOAT
 #undef DISPATCH_MODE_INT
@@ -34,7 +34,7 @@ namespace arm_common {
 DECL_CB(VecVec);
 DECL_CB(VecScalar);
 DECL_CB(VecBcast101);
-DECL_CB(VecBcast101x4);
+DECL_CB(VecBcast101xX);
 #undef DECL_CB
 } // namespace arm_common
 } // namespace megdnn
@@ -27,14 +27,14 @@ class ElemwiseImpl::AlgoPack {
     AlgoBinaryVecVec algo_binary_vec_vec;
     AlgoBinaryVecScalar algo_binary_vec_sca;
     AlgoBinaryVecBcast101 algo_binary_vec_bcast101;
-    AlgoBinaryVecBcast101x4 algo_binary_VEC_BCAST101x4;
+    AlgoBinaryVecBcast101xX algo_binary_VEC_BCAST101xX;
     AlgoTernaryFma3VecVecVec algo_ternaryfma3_vec_vec_vec;
     AlgoTernaryFma3VecVecScalar algo_ternaryfma3_vec_vecsca;
     AlgoTernaryFma3Bcast101VecBcast101 algo_ternaryfma3_bcast101_vec_bcast101;
-    AlgoTernaryFma3Bcast101x4VecBcast101x4
-            algo_ternaryfma3_bcast101x4_vec_bcast101x4;
+    AlgoTernaryFma3Bcast101xXVecBcast101xX
+            algo_ternaryfma3_bcast101xX_vec_bcast101xX;
     AlgoTernaryFma3VecBcast101Vec algo_ternaryfma3_vec_bcast101_vec;
-    AlgoTernaryFma3VecBcast101x4Vec algo_ternaryfma3_vec_bcast101x4_vec;
+    AlgoTernaryFma3VecBcast101xXVec algo_ternaryfma3_vec_bcast101xX_vec;
     AlgoTernaryFma3VecScalarVec algo_ternaryfma3_vec_sca_vec;
     AlgoTernaryFma3VecScalarScalar algo_ternaryfma3_vec_sca_sca;
@@ -44,13 +44,13 @@ public:
         all_algos.emplace_back(&algo_binary_vec_vec);
         all_algos.emplace_back(&algo_binary_vec_sca);
         all_algos.emplace_back(&algo_binary_vec_bcast101);
-        all_algos.emplace_back(&algo_binary_VEC_BCAST101x4);
+        all_algos.emplace_back(&algo_binary_VEC_BCAST101xX);
         all_algos.emplace_back(&algo_ternaryfma3_vec_vec_vec);
         all_algos.emplace_back(&algo_ternaryfma3_vec_vecsca);
         all_algos.emplace_back(&algo_ternaryfma3_bcast101_vec_bcast101);
-        all_algos.emplace_back(&algo_ternaryfma3_bcast101x4_vec_bcast101x4);
+        all_algos.emplace_back(&algo_ternaryfma3_bcast101xX_vec_bcast101xX);
         all_algos.emplace_back(&algo_ternaryfma3_vec_bcast101_vec);
-        all_algos.emplace_back(&algo_ternaryfma3_vec_bcast101x4_vec);
+        all_algos.emplace_back(&algo_ternaryfma3_vec_bcast101xX_vec);
         all_algos.emplace_back(&algo_ternaryfma3_vec_sca_vec);
         all_algos.emplace_back(&algo_ternaryfma3_vec_sca_sca);
     }
@@ -118,9 +118,10 @@ ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) {
         }
         if (is_vector(src1.layout) &&
-            is_broadcastedx_channel_like<4>(src0.layout, binfo) &&
+            (is_broadcastedx_channel_like<4>(src0.layout, binfo) ||
+             is_broadcastedx_channel_like<8>(src0.layout, binfo)) &&
             src0.layout.eq_layout(src2.layout)) {
-            kern_param.broad_cast_type = BcastType::BCAST101x4_VEC_BCAST101x4;
+            kern_param.broad_cast_type = BcastType::BCAST101xX_VEC_BCAST101xX;
             return kern_param;
         }
@@ -131,8 +132,9 @@ ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) {
         }
         if (is_vector(src0.layout) && src0.layout.eq_layout(src2.layout) &&
-            is_broadcastedx_channel_like<4>(src1.layout, binfo)) {
-            kern_param.broad_cast_type = BcastType::VEC_BCAST101x4_VEC;
+            (is_broadcastedx_channel_like<4>(src1.layout, binfo) ||
+             is_broadcastedx_channel_like<8>(src1.layout, binfo))) {
+            kern_param.broad_cast_type = BcastType::VEC_BCAST101xX_VEC;
             return kern_param;
         }
@@ -180,17 +182,18 @@ ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) {
         }
         if (is_vector(src0.layout) &&
-            is_broadcastedx_channel_like<4>(src1.layout, binfo)) {
-            kern_param.broad_cast_type = BcastType::VEC_BCAST101x4;
+            (is_broadcastedx_channel_like<4>(src1.layout, binfo) ||
+             is_broadcastedx_channel_like<8>(src1.layout, binfo))) {
+            kern_param.broad_cast_type = BcastType::VEC_BCAST101xX;
             return kern_param;
         }
         if (is_vector(src1.layout) &&
-            is_broadcastedx_channel_like<4>(src0.layout, binfo)) {
-            kern_param.broad_cast_type = BcastType::BCAST101x4_VEC;
+            (is_broadcastedx_channel_like<4>(src0.layout, binfo) ||
+             is_broadcastedx_channel_like<8>(src0.layout, binfo))) {
+            kern_param.broad_cast_type = BcastType::BCAST101xX_VEC;
             return kern_param;
         }
     } else if (opr->m_src->size() == 1) {
         kern_param.broad_cast_type = BcastType::VEC;
         kern_param.unary_elparam = opr->make_elemwise_op_param<1>();
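make_kern_param now accepts an inner block of 4 or 8 when classifying operands. As I read it, is_broadcastedx_channel_like&lt;N&gt; asks whether a layout looks like a per-channel-block tensor of inner width N (broadcast over batch and the spatial dims) and fills BroadcastChannelInfo; only the reported BcastType is widened here, the actual width is recovered later from the layout. A rough shape-only illustration under that assumption (not the MegDNN helper, which also inspects strides):

```cpp
#include <array>
#include <cstddef>
#include <cstdio>

// Rough stand-in for a 5-d NCHWxX layout: (N, C, H, W, X).
using Shape5 = std::array<std::size_t, 5>;

// Assumption: a "broadcasted channel-like" operand is a (1, C, 1, 1, X)
// tensor, i.e. one X-wide value block per channel group, broadcast over
// batch and the spatial dims.
template <std::size_t X>
bool broadcasted_channel_like(const Shape5& s) {
    return s[0] == 1 && s[2] == 1 && s[3] == 1 && s[4] == X;
}

int main() {
    Shape5 bias{1, 3, 1, 1, 8};  // nchw88 per-channel operand
    Shape5 feat{2, 3, 5, 7, 8};  // contiguous nchw88 operand
    // Mirrors the widened check in make_kern_param: accept 4 or 8.
    bool ok = broadcasted_channel_like<4>(bias) || broadcasted_channel_like<8>(bias);
    std::printf("bias is channel-like: %d, feature batch %zu\n", ok, feat[0]);
}
```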
@@ -10,7 +10,9 @@
  * implied.
  */
 #pragma once
 #include "src/fallback/elemwise/opr_impl.h"
 #include "src/arm_common/elemwise_op.h"
 namespace megdnn {
@@ -37,13 +39,13 @@ private:
     class AlgoBinaryVecVec;
     class AlgoBinaryVecScalar;
     class AlgoBinaryVecBcast101;
-    class AlgoBinaryVecBcast101x4;
+    class AlgoBinaryVecBcast101xX;
     class AlgoTernaryFma3VecVecVec;
     class AlgoTernaryFma3VecVecScalar;
     class AlgoTernaryFma3Bcast101VecBcast101;
-    class AlgoTernaryFma3Bcast101x4VecBcast101x4;
+    class AlgoTernaryFma3Bcast101xXVecBcast101xX;
     class AlgoTernaryFma3VecBcast101Vec;
-    class AlgoTernaryFma3VecBcast101x4Vec;
+    class AlgoTernaryFma3VecBcast101xXVec;
     class AlgoTernaryFma3VecScalarVec;
     class AlgoTernaryFma3VecScalarScalar;
     class AlgoPack;
@@ -42,9 +42,9 @@ using namespace arm_common;
 DECL_AVAILABLE(VecVecVec, BcastType::VEC_VEC_VEC);
 DECL_AVAILABLE(VecVecScalar, BcastType::VEC_VEC_SCALAR);
 DECL_AVAILABLE(Bcast101VecBcast101, BcastType::BCAST101_VEC_BCAST101);
-DECL_AVAILABLE(Bcast101x4VecBcast101x4, BcastType::BCAST101x4_VEC_BCAST101x4);
+DECL_AVAILABLE(Bcast101xXVecBcast101xX, BcastType::BCAST101xX_VEC_BCAST101xX);
 DECL_AVAILABLE(VecBcast101Vec, BcastType::VEC_BCAST101_VEC);
-DECL_AVAILABLE(VecBcast101x4Vec, BcastType::VEC_BCAST101x4_VEC);
+DECL_AVAILABLE(VecBcast101xXVec, BcastType::VEC_BCAST101xX_VEC);
 DECL_AVAILABLE(VecScalarVec, BcastType::VEC_SCALAR_VEC);
 DECL_AVAILABLE(VecScalarScalar, BcastType::VEC_SCALAR_SCALAR);
 #undef DECL_CB
@@ -161,13 +161,15 @@ void ElemwiseImpl::AlgoTernaryFma3Bcast101VecBcast101::exec(
     return;
 }
-void ElemwiseImpl::AlgoTernaryFma3Bcast101x4VecBcast101x4::exec(
+void ElemwiseImpl::AlgoTernaryFma3Bcast101xXVecBcast101xX::exec(
         const KernParam& kern_param) const {
     auto& elparam = kern_param.ternary_elparam;
     auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2];
     BroadcastChannelInfo binfo;
-    is_broadcastedx_channel_like<4>(src0.layout, binfo);
+    megdnn_assert(is_broadcastedx_channel_like<4>(src0.layout, binfo) ||
+                          is_broadcastedx_channel_like<8>(src0.layout, binfo),
+                  "only nchw44 and nchw88 supported");
 #define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \
     case Mode::_mode: \
         MIDOUT_BEGIN(megdnn_arm_common_elemwise_ternary, midout_iv(_case), \
@@ -177,7 +179,7 @@ void ElemwiseImpl::AlgoTernaryFma3Bcast101x4VecBcast101x4::exec(
                            size_t, size_t, size_t)> \
                 run = OpCallerTernary< \
                         _op<_type, _type>, \
-                        BcastType::BCAST101x4_VEC_BCAST101x4>::run; \
+                        BcastType::BCAST101xX_VEC_BCAST101xX>::run; \
         MEGDNN_DISPATCH_CPU_KERN( \
                 static_cast<naive::HandleImpl*>(kern_param.handle), \
                 run(static_cast<const _type*>(src0.raw_ptr), \
@@ -193,19 +195,21 @@ void ElemwiseImpl::AlgoTernaryFma3Bcast101x4VecBcast101x4::exec(
     size_t batch_size = src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z);
     auto&& dst = *(kern_param.m_dst);
-    DISPATCH_TYPE("AlgoTernaryFma3Bcast101x4VecBcast101x4::exec"_hash);
+    DISPATCH_TYPE("AlgoTernaryFma3Bcast101xXVecBcast101xX::exec"_hash);
 #undef DISPATCH_TERNARY
     return;
 }
-void ElemwiseImpl::AlgoTernaryFma3VecBcast101x4Vec::exec(
+void ElemwiseImpl::AlgoTernaryFma3VecBcast101xXVec::exec(
         const KernParam& kern_param) const {
     auto& elparam = kern_param.ternary_elparam;
     auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2];
     BroadcastChannelInfo binfo;
-    is_broadcastedx_channel_like<4>(src1.layout, binfo);
+    megdnn_assert(is_broadcastedx_channel_like<4>(src1.layout, binfo) ||
+                          is_broadcastedx_channel_like<8>(src1.layout, binfo),
+                  "only nchw44 and nchw88 supported");
 #define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \
     case Mode::_mode: \
         MIDOUT_BEGIN(megdnn_arm_common_elemwise_ternary, midout_iv(_case), \
@@ -214,7 +218,7 @@ void ElemwiseImpl::AlgoTernaryFma3VecBcast101x4Vec::exec(
                            _type*, DType, DType, DType, DType, size_t, \
                            size_t, size_t, size_t)> \
                 run = OpCallerTernary<_op<_type, _type>, \
-                                      BcastType::VEC_BCAST101x4_VEC>::run; \
+                                      BcastType::VEC_BCAST101xX_VEC>::run; \
         MEGDNN_DISPATCH_CPU_KERN( \
                 static_cast<naive::HandleImpl*>(kern_param.handle), \
                 run(static_cast<const _type*>(src0.raw_ptr), \
@@ -230,7 +234,7 @@ void ElemwiseImpl::AlgoTernaryFma3VecBcast101x4Vec::exec(
     size_t batch_size = src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z);
     auto&& dst = *(kern_param.m_dst);
-    DISPATCH_TYPE("AlgoTernaryFma3VecBcast101x4Vec::exec"_hash);
+    DISPATCH_TYPE("AlgoTernaryFma3VecBcast101xXVec::exec"_hash);
 #undef DISPATCH_TERNARY
     return;
@@ -34,9 +34,9 @@ namespace arm_common {
 DECL_CB(VecVecVec);
 DECL_CB(VecVecScalar);
 DECL_CB(Bcast101VecBcast101);
-DECL_CB(Bcast101x4VecBcast101x4);
+DECL_CB(Bcast101xXVecBcast101xX);
 DECL_CB(VecBcast101Vec);
-DECL_CB(VecBcast101x4Vec);
+DECL_CB(VecBcast101xXVec);
 DECL_CB(VecScalarVec);
 DECL_CB(VecScalarScalar);
 #undef DECL_CB
@@ -644,7 +644,8 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<2>& param,
     {
         BroadcastChannelInfo binfo;
         if (is_vector(src0.layout) &&
-            is_broadcastedx_channel_like<4>(src1.layout, binfo)) {
+            (is_broadcastedx_channel_like<4>(src1.layout, binfo) ||
+             is_broadcastedx_channel_like<8>(src1.layout, binfo))) {
 #define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \
     case _mode: { \
         using src_ctype = typename DTypeTrait<_src_dt>::ctype; \
@@ -653,14 +654,13 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<2>& param,
                            DType, DType, DType, size_t, size_t, size_t, \
                            size_t)> \
                 run = OpCallerBinary<_op<src_ctype, dst_ctype>, \
-                                     VEC_BCAST101x4>::run; \
+                                     VEC_BCAST101xX>::run; \
         MEGDNN_DISPATCH_CPU_KERN_OPR(run( \
                 src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
                 dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
                 dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \
         return; \
     }
         size_t batch_size =
                 src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z);
         DISPATCH()
@@ -679,14 +679,13 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<2>& param,
                            DType, DType, DType, size_t, size_t, size_t, \
                            size_t)> \
                 run = OpCallerBinary<_op<src_ctype, dst_ctype>, \
-                                     BCAST101x4_VEC>::run; \
+                                     BCAST101xX_VEC>::run; \
         MEGDNN_DISPATCH_CPU_KERN_OPR(run( \
                 src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
                 dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
                 dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \
         return; \
     }
         size_t batch_size =
                 src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z);
         DISPATCH()
@@ -818,7 +817,8 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<3>& param,
     {
         BroadcastChannelInfo binfo;
         if (is_vector(src0.layout) &&
-            is_broadcastedx_channel_like<4>(src1.layout, binfo) &&
+            (is_broadcastedx_channel_like<4>(src1.layout, binfo) ||
+             is_broadcastedx_channel_like<8>(src1.layout, binfo)) &&
             src0.layout.eq_shape(src2.layout)) {
 #define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \
     case _mode: { \
@@ -828,7 +828,7 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<3>& param,
                 const src_ctype*, dst_ctype*, DType, DType, DType, \
                 DType, size_t, size_t, size_t, size_t)> \
                 run = OpCallerTernary<_op<src_ctype, dst_ctype>, \
-                                      VEC_BCAST101x4_VEC>::run; \
+                                      VEC_BCAST101xX_VEC>::run; \
         MEGDNN_DISPATCH_CPU_KERN_OPR( \
                 run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
                     src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), \
@@ -846,7 +846,8 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<3>& param,
         //! BCAST101x + VEC +BCAST101x
         if (is_vector(src1.layout) &&
-            is_broadcastedx_channel_like<4>(src0.layout, binfo) &&
+            (is_broadcastedx_channel_like<4>(src0.layout, binfo) ||
+             is_broadcastedx_channel_like<8>(src0.layout, binfo)) &&
            src0.layout.eq_shape(src2.layout)) {
 #define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \
     case _mode: { \
@@ -856,7 +857,7 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<3>& param,
                 const src_ctype*, dst_ctype*, DType, DType, DType, \
                 DType, size_t, size_t, size_t, size_t)> \
                 run = OpCallerTernary<_op<src_ctype, dst_ctype>, \
-                                      BCAST101x4_VEC_BCAST101x4>::run; \
+                                      BCAST101xX_VEC_BCAST101xX>::run; \
         MEGDNN_DISPATCH_CPU_KERN_OPR( \
                 run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
                     src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), \
@@ -89,6 +89,21 @@ cb(dt_float32, float32_t, float32x4_t, f32);
 cb(dt_int32, int32_t, int32x4_t, s32);
 #undef cb
+template <typename ctype>
+struct ParamElemVisitorBcast101x8;
+#define cb(_ctype, _inner_ctype, _neon_type, _fun_suffix) \
+    template <> \
+    struct ParamElemVisitorBcast101x8<_ctype> { \
+        _neon_type operator()(const _ctype* src) const { \
+            return vld1q_##_fun_suffix( \
+                    reinterpret_cast<const _inner_ctype*>(src)); \
+        } \
+    }
+#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+cb(__fp16, __fp16, float16x8_t, f16);
+#endif
+#undef cb
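The new visitor is the eight-lane counterpart of ParamElemVisitorBcast101x4: an nchw88 channel block already fills a full 128-bit register, so the "broadcast" load is a single vld1q of the block, and it is only instantiated for __fp16 behind the fp16 vector-arithmetic guard. A hedged standalone illustration of the same load, independent of the MegDNN templates (function names invented):

```cpp
// Loading one nchw88 channel block (8 half-precision lanes) with a single
// vld1q, which is all the 101x8 "broadcast" visitor has to do because the
// block width equals the NEON vector width. Builds only on an ARM toolchain
// with the fp16 vector extension; elsewhere it compiles to nothing.
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
#include <arm_neon.h>

float16x8_t load_channel_block_fp16(const __fp16* block) {
    // One 128-bit load covers the whole 8-wide channel block.
    return vld1q_f16(reinterpret_cast<const float16_t*>(block));
}

// For comparison, the 4-wide (nchw44) float case loads a float32x4_t.
float32x4_t load_channel_block_f32(const float* block) {
    return vld1q_f32(block);
}
#endif
```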
 /*!
  * \brief broadcast type
  * BCAST_x[0]x[1]...: x[i] == !stride[i]
@@ -97,17 +112,17 @@ enum BcastType {
     VEC,
     VEC_VEC,
     VEC_BCAST101,
-    VEC_BCAST101x4,
+    VEC_BCAST101xX,
     VEC_SCALAR,
     SCALAR_VEC,
     BCAST101_VEC,
-    BCAST101x4_VEC,
+    BCAST101xX_VEC,
     VEC_VEC_VEC,
     VEC_VEC_SCALAR,
     BCAST101_VEC_BCAST101,
-    BCAST101x4_VEC_BCAST101x4,
+    BCAST101xX_VEC_BCAST101xX,
     VEC_BCAST101_VEC,
-    VEC_BCAST101x4_VEC,
+    VEC_BCAST101xX_VEC,
     VEC_SCALAR_VEC,
     VEC_SCALAR_SCALAR,
     UNKNOWN_BCAST_TYPE
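Reading of the rename: per the comment above the enum, each digit is !stride[i], so "101" describes a (1, C, 1)-style per-channel operand, and the trailing xX now stands for whichever inner channel-block width the layout carries (4 for nchw44, 8 for nchw88) instead of a hard-coded 4. A toy classifier showing how I understand the naming (values and logic illustrative only, not the MegDNN dispatch):

```cpp
#include <cstddef>
#include <cstdio>

// Trimmed copy of the renamed enumerators, for illustration.
enum class BcastType { VEC, VEC_BCAST101, VEC_BCAST101xX, BCAST101xX_VEC, UNKNOWN };

// "101" = broadcast (stride 0) on batch and spatial dims, real stride on the
// channel dim; "xX" = the inner channel-block width recorded by the layout.
BcastType classify(bool lhs_is_contiguous_vector, std::size_t rhs_inner_block) {
    if (!lhs_is_contiguous_vector)
        return BcastType::UNKNOWN;
    if (rhs_inner_block == 4 || rhs_inner_block == 8)
        return BcastType::VEC_BCAST101xX;  // one enumerator covers both widths now
    if (rhs_inner_block == 1)
        return BcastType::VEC_BCAST101;
    return BcastType::UNKNOWN;
}

int main() {
    std::printf("%d\n", static_cast<int>(classify(true, 8)));  // VEC_BCAST101xX
}
```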
@@ -334,7 +349,7 @@ struct OpCallerBinary<Op, VEC_BCAST101> {
 };
 template <typename ctype>
-struct OpCallerBinary<PowOp<ctype, ctype>, BCAST101x4_VEC> {
+struct OpCallerBinary<PowOp<ctype, ctype>, BCAST101xX_VEC> {
     using Op = PowOp<ctype, ctype>;
     static void run(const typename Op::src_ctype* src0,
                     const typename Op::src_ctype* src1,
@@ -360,18 +375,37 @@ struct OpCallerBinary<PowOp<ctype, ctype>, BCAST101x4_VEC> {
     }
 };
-template <typename Op>
-struct OpCallerBinary<Op, BCAST101x4_VEC> {
-    static void run(const typename Op::src_ctype* src0,
-                    const typename Op::src_ctype* src1,
-                    typename Op::dst_ctype* dst, DType src0_dtype,
-                    DType src1_dtype, DType dst_dtype, size_t batch,
-                    size_t nr_channel_blocks, size_t channel_stride,
-                    size_t channel_block_dim) {
-        megdnn_assert(channel_block_dim == 4, "only imp for nchw44");
-        Op op(src0_dtype, src1_dtype, dst_dtype);
-        ParamElemVisitorBcast101x4<typename Op::src_ctype> vis0;
-        ParamElemVisitor<typename Op::src_ctype> vis1;
+template <typename src_ctype, size_t channel_block_dim>
+struct OpCallerBinaryBcast101xXVec {
+    template <typename Op>
+    static void run(const src_ctype* src0, const src_ctype* src1,
+                    typename Op::dst_ctype* dst, const Op& op, size_t batch,
+                    size_t nr_channel_blocks, size_t channel_stride) {
+        for (size_t b = 0; b < batch; b++) {
+            auto src0_ptr = src0;
+            for (size_t cb = 0; cb < nr_channel_blocks; cb++) {
+                auto src0_block_ptr = src0_ptr + cb * channel_block_dim;
+                for (size_t img_index = 0; img_index < channel_stride;
+                     img_index++) {
+                    for (size_t c_iter = 0; c_iter < channel_block_dim;
+                         c_iter++) {
+                        op(*(src0_block_ptr + c_iter), *src1, dst);
+                        src1++;
+                        dst++;
+                    }
+                }
+            }
+        }
+    }
+};
+template <typename src_ctype, size_t channel_block_dim>
+struct OpCallerBinaryBcast101xDVec {
+    template <typename Op, typename Vis0, typename Vis1>
+    static void run(const src_ctype* src0, const src_ctype* src1,
+                    typename Op::dst_ctype* dst, const Op& op, const Vis0& vis0,
+                    const Vis1& vis1, size_t batch, size_t nr_channel_blocks,
+                    size_t channel_stride) {
         for (size_t b = 0; b < batch; b++) {
             auto src0_ptr = src0;
             for (size_t cb = 0; cb < nr_channel_blocks; cb++) {
@@ -400,8 +434,63 @@ struct OpCallerBinary<Op, BCAST101x4_VEC> {
     }
 };
+template <typename src_ctype>
+struct OpCallerBinaryBcast101xXVec<src_ctype, 4> {
+    template <typename Op>
+    static void run(const src_ctype* src0, const src_ctype* src1,
+                    typename Op::dst_ctype* dst, const Op& op, size_t batch,
+                    size_t nr_channel_blocks, size_t channel_stride) {
+        ParamElemVisitorBcast101x4<typename Op::src_ctype> vis0;
+        ParamElemVisitor<typename Op::src_ctype> vis1;
+        OpCallerBinaryBcast101xDVec<src_ctype, 4>::run(
+                src0, src1, dst, op, vis0, vis1, batch, nr_channel_blocks,
+                channel_stride);
+    }
+};
+#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+template <>
+struct OpCallerBinaryBcast101xXVec<__fp16, 8> {
+    using src_ctype = __fp16;
+    template <typename Op>
+    static void run(const src_ctype* src0, const src_ctype* src1,
+                    typename Op::dst_ctype* dst, const Op& op, size_t batch,
+                    size_t nr_channel_blocks, size_t channel_stride) {
+        ParamElemVisitorBcast101x8<src_ctype> vis0;
+        ParamElemVisitor<src_ctype> vis1;
+        OpCallerBinaryBcast101xDVec<src_ctype, 8>::run(
+                src0, src1, dst, op, vis0, vis1, batch, nr_channel_blocks,
+                channel_stride);
+    }
+};
+#endif
+template <typename Op>
+struct OpCallerBinary<Op, BCAST101xX_VEC> {
+    static void run(const typename Op::src_ctype* src0,
+                    const typename Op::src_ctype* src1,
+                    typename Op::dst_ctype* dst, DType src0_dtype,
+                    DType src1_dtype, DType dst_dtype, size_t batch,
+                    size_t nr_channel_blocks, size_t channel_stride,
+                    size_t channel_block_dim) {
+        megdnn_assert(channel_block_dim == 4 || channel_block_dim == 8,
+                      "only imp for nchw44/nchw88");
+        Op op(src0_dtype, src1_dtype, dst_dtype);
+        if (channel_block_dim == 4) {
+            OpCallerBinaryBcast101xXVec<typename Op::src_ctype, 4>::run(
+                    src0, src1, dst, op, batch, nr_channel_blocks,
+                    channel_stride);
+        } else {
+            OpCallerBinaryBcast101xXVec<typename Op::src_ctype, 8>::run(
+                    src0, src1, dst, op, batch, nr_channel_blocks,
+                    channel_stride);
+        }
+    }
+};
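The pattern here is a two-level split: OpCallerBinaryBcast101xXVec&lt;ctype, D&gt; is the generic scalar fallback over D-wide channel blocks, the &lt;ctype, 4&gt; and &lt;__fp16, 8&gt; specializations re-enter the vectorized path by pairing the matching visitors, and OpCallerBinary&lt;Op, BCAST101xX_VEC&gt; only turns the runtime channel_block_dim into a compile-time D. A compact, self-contained sketch of that runtime-to-compile-time dispatch (single batch, scalar body, invented names):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdio>
#include <vector>

// Generic per-channel-block broadcast: src0 holds one D-wide value block per
// channel group, src1 is contiguous. Mirrors the scalar fallback loop.
template <typename T, std::size_t D>
struct BlockBcastVecCaller {
    template <typename Op>
    static void run(const T* src0, const T* src1, T* dst, const Op& op,
                    std::size_t nr_channel_blocks, std::size_t channel_stride) {
        for (std::size_t cb = 0; cb < nr_channel_blocks; ++cb) {
            const T* block = src0 + cb * D;
            for (std::size_t s = 0; s < channel_stride; ++s)
                for (std::size_t d = 0; d < D; ++d)
                    *dst++ = op(block[d], *src1++);
        }
    }
};

// Runtime dim -> compile-time D, the same branch the patched caller performs.
template <typename T, typename Op>
void block_bcast_vec(const T* src0, const T* src1, T* dst, const Op& op,
                     std::size_t nr_channel_blocks, std::size_t channel_stride,
                     std::size_t channel_block_dim) {
    assert(channel_block_dim == 4 || channel_block_dim == 8);
    if (channel_block_dim == 4)
        BlockBcastVecCaller<T, 4>::run(src0, src1, dst, op, nr_channel_blocks, channel_stride);
    else
        BlockBcastVecCaller<T, 8>::run(src0, src1, dst, op, nr_channel_blocks, channel_stride);
}

int main() {
    std::vector<float> bias{1, 2, 3, 4, 5, 6, 7, 8};  // one 8-wide channel block
    std::vector<float> vec(16, 1.f), dst(16, 0.f);    // stride 2, block dim 8
    block_bcast_vec(bias.data(), vec.data(), dst.data(),
                    [](float a, float b) { return a + b; }, 1, 2, 8);
    std::printf("%g %g\n", dst[0], dst[15]);  // 2 and 9
}
```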
 template <typename ctype>
-struct OpCallerBinary<PowOp<ctype, ctype>, VEC_BCAST101x4> {
+struct OpCallerBinary<PowOp<ctype, ctype>, VEC_BCAST101xX> {
     using Op = PowOp<ctype, ctype>;
     static void run(const typename Op::src_ctype* src0,
                     const typename Op::src_ctype* src1,
@@ -427,18 +516,37 @@ struct OpCallerBinary<PowOp<ctype, ctype>, VEC_BCAST101x4> {
     }
 };
-template <typename Op>
-struct OpCallerBinary<Op, VEC_BCAST101x4> {
-    static void run(const typename Op::src_ctype* src0,
-                    const typename Op::src_ctype* src1,
-                    typename Op::dst_ctype* dst, DType src0_dtype,
-                    DType src1_dtype, DType dst_dtype, size_t batch,
-                    size_t nr_channel_blocks, size_t channel_stride,
-                    size_t channel_block_dim) {
-        megdnn_assert(channel_block_dim == 4, "only imp for nchw44");
-        Op op(src0_dtype, src1_dtype, dst_dtype);
-        ParamElemVisitor<typename Op::src_ctype> vis0;
-        ParamElemVisitorBcast101x4<typename Op::src_ctype> vis1;
+template <typename src_ctype, size_t channel_block_dim>
+struct OpCallerBinaryVecBcast101xX {
+    template <typename Op>
+    static void run(const src_ctype* src0, const src_ctype* src1,
+                    typename Op::dst_ctype* dst, const Op& op, size_t batch,
+                    size_t nr_channel_blocks, size_t channel_stride) {
+        for (size_t b = 0; b < batch; b++) {
+            auto src1_ptr = src1;
+            for (size_t cb = 0; cb < nr_channel_blocks; cb++) {
+                auto src1_block_ptr = src1_ptr + cb * channel_block_dim;
+                for (size_t img_index = 0; img_index < channel_stride;
+                     img_index++) {
+                    for (size_t c_iter = 0; c_iter < channel_block_dim;
+                         c_iter++) {
+                        op(*src0, *(src1_block_ptr + c_iter), dst);
+                        src0++;
+                        dst++;
+                    }
+                }
+            }
+        }
+    }
+};
+template <typename src_ctype, size_t channel_block_dim>
+struct OpCallerBinaryVecBcast101xD {
+    template <typename Op, typename Vis0, typename Vis1>
+    static void run(const src_ctype* src0, const src_ctype* src1,
+                    typename Op::dst_ctype* dst, const Op& op, const Vis0& vis0,
+                    const Vis1& vis1, size_t batch, size_t nr_channel_blocks,
+                    size_t channel_stride) {
         for (size_t b = 0; b < batch; b++) {
             auto src1_ptr = src1;
             for (size_t cb = 0; cb < nr_channel_blocks; cb++) {
@@ -467,6 +575,60 @@ struct OpCallerBinary<Op, VEC_BCAST101x4> {
     }
 };
+template <typename src_ctype>
+struct OpCallerBinaryVecBcast101xX<src_ctype, 4> {
+    template <typename Op>
+    static void run(const src_ctype* src0, const src_ctype* src1,
+                    typename Op::dst_ctype* dst, const Op& op, size_t batch,
+                    size_t nr_channel_blocks, size_t channel_stride) {
+        ParamElemVisitor<src_ctype> vis0;
+        ParamElemVisitorBcast101x4<src_ctype> vis1;
+        OpCallerBinaryVecBcast101xD<src_ctype, 4>::run(
+                src0, src1, dst, op, vis0, vis1, batch, nr_channel_blocks,
+                channel_stride);
+    }
+};
+#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+template <>
+struct OpCallerBinaryVecBcast101xX<__fp16, 8> {
+    using src_ctype = __fp16;
+    template <typename Op>
+    static void run(const src_ctype* src0, const src_ctype* src1,
+                    typename Op::dst_ctype* dst, const Op& op, size_t batch,
+                    size_t nr_channel_blocks, size_t channel_stride) {
+        ParamElemVisitor<src_ctype> vis0;
+        ParamElemVisitorBcast101x8<src_ctype> vis1;
+        OpCallerBinaryVecBcast101xD<src_ctype, 8>::run(
+                src0, src1, dst, op, vis0, vis1, batch, nr_channel_blocks,
+                channel_stride);
+    }
+};
+#endif
+template <typename Op>
+struct OpCallerBinary<Op, VEC_BCAST101xX> {
+    static void run(const typename Op::src_ctype* src0,
+                    const typename Op::src_ctype* src1,
+                    typename Op::dst_ctype* dst, DType src0_dtype,
+                    DType src1_dtype, DType dst_dtype, size_t batch,
+                    size_t nr_channel_blocks, size_t channel_stride,
+                    size_t channel_block_dim) {
+        megdnn_assert(channel_block_dim == 4 || channel_block_dim == 8,
+                      "only imp for nchw44/nchw88");
+        Op op(src0_dtype, src1_dtype, dst_dtype);
+        if (channel_block_dim == 4) {
+            OpCallerBinaryVecBcast101xX<typename Op::src_ctype, 4>::run(
+                    src0, src1, dst, op, batch, nr_channel_blocks,
+                    channel_stride);
+        } else {
+            OpCallerBinaryVecBcast101xX<typename Op::src_ctype, 8>::run(
+                    src0, src1, dst, op, batch, nr_channel_blocks,
+                    channel_stride);
+        }
+    }
+};
 template <typename Op>
 struct OpCallerBinary<Op, VEC_SCALAR> {
     static void run(const typename Op::src_ctype* src0,
@@ -683,21 +845,42 @@ struct OpCallerTernary<Op, BCAST101_VEC_BCAST101> {
     }
 };
-//! src0: CHW44, src1: vector, src2: CHW44
-template <typename Op>
-struct OpCallerTernary<Op, BCAST101x4_VEC_BCAST101x4> {
-    static void run(const typename Op::src_ctype* src0,
-                    const typename Op::src_ctype* src1,
-                    const typename Op::src_ctype* src2,
-                    typename Op::dst_ctype* dst, DType src0_dtype,
-                    DType src1_dtype, DType src2_dtype, DType dst_dtype,
-                    size_t batch, size_t nr_channel_blocks,
-                    size_t channel_stride, size_t channel_block_dim) {
-        megdnn_assert(channel_block_dim == 4, "only imp for nchw44");
-        Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype);
-        ParamElemVisitorBcast101x4<typename Op::src_ctype> vis0;
-        ParamElemVisitor<typename Op::src_ctype> vis1;
-        ParamElemVisitorBcast101x4<typename Op::src_ctype> vis2;
+template <typename src_ctype, size_t channel_block_dim>
+struct OpCallerTernaryBcast101xXVecBcast101xX {
+    template <typename Op>
+    static void run(const src_ctype* src0, const src_ctype* src1,
+                    const src_ctype* src2, typename Op::dst_ctype* dst,
+                    const Op& op, size_t batch, size_t nr_channel_blocks,
+                    size_t channel_stride) {
+        for (size_t b = 0; b < batch; b++) {
+            auto src0_ptr = src0;
+            auto src2_ptr = src2;
+            for (size_t cb = 0; cb < nr_channel_blocks; cb++) {
+                auto src0_block_ptr = src0_ptr + cb * channel_block_dim;
+                auto src2_block_ptr = src2_ptr + cb * channel_block_dim;
+                for (size_t img_index = 0; img_index < channel_stride;
+                     img_index++) {
+                    for (size_t c_iter = 0; c_iter < channel_block_dim;
+                         c_iter++) {
+                        op(*(src0_block_ptr + c_iter), *src1,
+                           *(src2_block_ptr + c_iter), dst);
+                        src1++;
+                        dst++;
+                    }
+                }
+            }
+        }
+    }
+};
+template <typename src_ctype, size_t channel_block_dim>
+struct OpCallerTernaryBcast101xDVecBcast101xD {
+    template <typename Op, typename Vis0, typename Vis1, typename Vis2>
+    static void run(const src_ctype* src0, const src_ctype* src1,
+                    const src_ctype* src2, typename Op::dst_ctype* dst,
+                    const Op& op, const Vis0& vis0, const Vis1& vis1,
+                    const Vis2& vis2, size_t batch, size_t nr_channel_blocks,
+                    size_t channel_stride) {
         for (size_t b = 0; b < batch; b++) {
             auto src0_ptr = src0;
             auto src2_ptr = src2;
@@ -731,6 +914,70 @@ struct OpCallerTernary<Op, BCAST101x4_VEC_BCAST101x4> {
     }
 };
+//! src0: CHW44, src1: vector, src2: CHW44
+template <typename src_ctype>
+struct OpCallerTernaryBcast101xXVecBcast101xX<src_ctype, 4> {
+    template <typename Op>
+    static void run(const src_ctype* src0, const src_ctype* src1,
+                    const src_ctype* src2, typename Op::dst_ctype* dst,
+                    const Op& op, size_t batch, size_t nr_channel_blocks,
+                    size_t channel_stride) {
+        ParamElemVisitorBcast101x4<src_ctype> vis0;
+        ParamElemVisitor<src_ctype> vis1;
+        ParamElemVisitorBcast101x4<src_ctype> vis2;
+        OpCallerTernaryBcast101xDVecBcast101xD<src_ctype, 4>::run(
+                src0, src1, src2, dst, op, vis0, vis1, vis2, batch,
+                nr_channel_blocks, channel_stride);
+    }
+};
+#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+template <>
+struct OpCallerTernaryBcast101xXVecBcast101xX<__fp16, 8> {
+    using src_ctype = __fp16;
+    template <typename Op>
+    static void run(const src_ctype* src0, const src_ctype* src1,
+                    const src_ctype* src2, typename Op::dst_ctype* dst,
+                    const Op& op, size_t batch, size_t nr_channel_blocks,
+                    size_t channel_stride) {
+        ParamElemVisitorBcast101x8<src_ctype> vis0;
+        ParamElemVisitor<src_ctype> vis1;
+        ParamElemVisitorBcast101x8<src_ctype> vis2;
+        OpCallerTernaryBcast101xDVecBcast101xD<src_ctype, 8>::run(
+                src0, src1, src2, dst, op, vis0, vis1, vis2, batch,
+                nr_channel_blocks, channel_stride);
+    }
+};
+#endif
+template <typename Op>
+struct OpCallerTernary<Op, BCAST101xX_VEC_BCAST101xX> {
+    static void run(const typename Op::src_ctype* src0,
+                    const typename Op::src_ctype* src1,
+                    const typename Op::src_ctype* src2,
+                    typename Op::dst_ctype* dst, DType src0_dtype,
+                    DType src1_dtype, DType src2_dtype, DType dst_dtype,
+                    size_t batch, size_t nr_channel_blocks,
+                    size_t channel_stride, size_t channel_block_dim) {
+        megdnn_assert(channel_block_dim == 4 || channel_block_dim == 8,
+                      "only imp for nchw44/nchw88");
+        Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype);
+        if (channel_block_dim == 4) {
+            OpCallerTernaryBcast101xXVecBcast101xX<typename Op::src_ctype,
+                                                   4>::run(src0, src1, src2,
+                                                           dst, op, batch,
+                                                           nr_channel_blocks,
+                                                           channel_stride);
+        } else {
+            OpCallerTernaryBcast101xXVecBcast101xX<typename Op::src_ctype,
+                                                   8>::run(src0, src1, src2,
+                                                           dst, op, batch,
+                                                           nr_channel_blocks,
+                                                           channel_stride);
+        }
+    }
+};
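The ternary caller mirrors the binary one: src0 and src2 each contribute one D-wide block per channel group, read at the same c_iter offset, while src1 and dst advance contiguously; with the fused multiply-add mode this amounts to a per-channel scale and shift. A simplified scalar sketch for one batch (illustrative only, not the MegDNN caller):

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>

// Sketch of the BCAST101xX_VEC_BCAST101xX loop for a single batch: src0 and
// src2 supply one D-wide block per channel group, src1 is the contiguous
// tensor; here Op is hard-wired to a * x + b for clarity.
template <std::size_t D>
void bcast_vec_bcast_fma(const float* src0, const float* src1, const float* src2,
                         float* dst, std::size_t nr_channel_blocks,
                         std::size_t channel_stride) {
    for (std::size_t cb = 0; cb < nr_channel_blocks; ++cb) {
        const float* b0 = src0 + cb * D;
        const float* b2 = src2 + cb * D;
        for (std::size_t s = 0; s < channel_stride; ++s)
            for (std::size_t d = 0; d < D; ++d)
                *dst++ = b0[d] * (*src1++) + b2[d];
    }
}

int main() {
    std::vector<float> scale(8, 2.f), shift(8, 1.f);  // one nchw88 channel block each
    std::vector<float> x(8, 3.f), y(8, 0.f);
    bcast_vec_bcast_fma<8>(scale.data(), x.data(), shift.data(), y.data(), 1, 1);
    std::printf("%g\n", y[0]);  // 2 * 3 + 1 = 7
}
```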
//! src1: 1C11, src0 and src2 are contig | //! src1: 1C11, src0 and src2 are contig | ||||
template <typename Op> | template <typename Op> | ||||
struct OpCallerTernary<Op, VEC_BCAST101_VEC> { | struct OpCallerTernary<Op, VEC_BCAST101_VEC> { | ||||
@@ -775,21 +1022,41 @@ struct OpCallerTernary<Op, VEC_BCAST101_VEC> { | |||||
} | } | ||||
}; | }; | ||||
template <typename src_ctype, size_t channel_block_dim> | |||||
struct OpCallerTernaryVecBcast101xXVec { | |||||
template <typename Op> | |||||
static void run(const src_ctype* src0, const src_ctype* src1, | |||||
const src_ctype* src2, typename Op::dst_ctype* dst, | |||||
const Op& op, size_t batch, size_t nr_channel_blocks, | |||||
size_t channel_stride) { | |||||
for (size_t b = 0; b < batch; b++) { | |||||
auto src1_ptr = src1; | |||||
for (size_t cb = 0; cb < nr_channel_blocks; cb++) { | |||||
auto src1_block_ptr = src1_ptr + cb * channel_block_dim; | |||||
for (size_t img_index = 0; img_index < channel_stride; | |||||
img_index++) { | |||||
for (size_t c_iter = 0; c_iter < channel_block_dim; | |||||
c_iter++) { | |||||
op(*src0, *(src1_block_ptr + c_iter), *src2, dst); | |||||
src0++; | |||||
src2++; | |||||
dst++; | |||||
} | |||||
} | |||||
} | |||||
} | |||||
} | |||||
}; | |||||
//! src1: CHW44, src0 and src2 are contig | //! src1: CHW44, src0 and src2 are contig | ||||
template <typename Op> | |||||
struct OpCallerTernary<Op, VEC_BCAST101x4_VEC> { | |||||
static void run(const typename Op::src_ctype* src0, | |||||
const typename Op::src_ctype* src1, | |||||
const typename Op::src_ctype* src2, | |||||
typename Op::dst_ctype* dst, DType src0_dtype, | |||||
DType src1_dtype, DType src2_dtype, DType dst_dtype, | |||||
size_t batch, size_t nr_channel_blocks, | |||||
size_t channel_stride, size_t channel_block_dim) { | |||||
megdnn_assert(channel_block_dim == 4, "only imp for nchw44"); | |||||
Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); | |||||
ParamElemVisitor<typename Op::src_ctype> vis0; | |||||
ParamElemVisitorBcast101x4<typename Op::src_ctype> vis1; | |||||
ParamElemVisitor<typename Op::src_ctype> vis2; | |||||
template <typename src_ctype, size_t channel_block_dim> | |||||
struct OpCallerTernaryVecBcast101xDVec { | |||||
template <typename Op, typename Vis0, typename Vis1, typename Vis2> | |||||
static void run(const src_ctype* src0, const src_ctype* src1, | |||||
const src_ctype* src2, typename Op::dst_ctype* dst, | |||||
const Op& op, const Vis0& vis0, const Vis1& vis1, | |||||
const Vis2& vis2, size_t batch, size_t nr_channel_blocks, | |||||
size_t channel_stride) { | |||||
for (size_t b = 0; b < batch; b++) { | for (size_t b = 0; b < batch; b++) { | ||||
auto src1_ptr = src1; | auto src1_ptr = src1; | ||||
for (size_t cb = 0; cb < nr_channel_blocks; cb++) { | for (size_t cb = 0; cb < nr_channel_blocks; cb++) { | ||||
@@ -821,6 +1088,66 @@ struct OpCallerTernary<Op, VEC_BCAST101x4_VEC> { | |||||
} | } | ||||
}; | }; | ||||
template <typename src_ctype> | |||||
struct OpCallerTernaryVecBcast101xXVec<src_ctype, 4> { | |||||
template <typename Op> | |||||
static void run(const src_ctype* src0, const src_ctype* src1, | |||||
const src_ctype* src2, typename Op::dst_ctype* dst, | |||||
const Op& op, size_t batch, size_t nr_channel_blocks, | |||||
size_t channel_stride) { | |||||
ParamElemVisitor<src_ctype> vis0; | |||||
ParamElemVisitorBcast101x4<src_ctype> vis1; | |||||
ParamElemVisitor<src_ctype> vis2; | |||||
OpCallerTernaryVecBcast101xDVec<src_ctype, 4>::run( | |||||
src0, src1, src2, dst, op, vis0, vis1, vis2, batch, | |||||
nr_channel_blocks, channel_stride); | |||||
} | |||||
}; | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
template <> | |||||
struct OpCallerTernaryVecBcast101xXVec<__fp16, 8> { | |||||
using src_ctype = __fp16; | |||||
template <typename Op> | |||||
static void run(const src_ctype* src0, const src_ctype* src1, | |||||
const src_ctype* src2, typename Op::dst_ctype* dst, | |||||
const Op& op, size_t batch, size_t nr_channel_blocks, | |||||
size_t channel_stride) { | |||||
ParamElemVisitor<src_ctype> vis0; | |||||
ParamElemVisitorBcast101x8<src_ctype> vis1; | |||||
ParamElemVisitor<src_ctype> vis2; | |||||
OpCallerTernaryVecBcast101xDVec<src_ctype, 8>::run( | |||||
src0, src1, src2, dst, op, vis0, vis1, vis2, batch, | |||||
nr_channel_blocks, channel_stride); | |||||
} | |||||
}; | |||||
#endif | |||||
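
The fp16 specialization is the only 8-lane one because a single 128-bit NEON register holds exactly eight __fp16 values, so one load can cover a whole nchw88 channel block. A hedged sketch of that load follows, guarded the same way; ParamElemVisitorBcast101x8 is the in-tree mechanism, and the helper name here is only illustrative:

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#include <arm_neon.h>

// Illustrative only: load the eight per-channel values of one nchw88 block;
// the resulting vector can then be reused for every spatial position of that
// block, which is what the broadcast visitor amounts to.
static inline float16x8_t load_nchw88_channel_block(const __fp16* block_ptr) {
    return vld1q_f16(block_ptr);
}
#endif
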
template <typename Op> | |||||
struct OpCallerTernary<Op, VEC_BCAST101xX_VEC> { | |||||
static void run(const typename Op::src_ctype* src0, | |||||
const typename Op::src_ctype* src1, | |||||
const typename Op::src_ctype* src2, | |||||
typename Op::dst_ctype* dst, DType src0_dtype, | |||||
DType src1_dtype, DType src2_dtype, DType dst_dtype, | |||||
size_t batch, size_t nr_channel_blocks, | |||||
size_t channel_stride, size_t channel_block_dim) { | |||||
megdnn_assert(channel_block_dim == 4 || channel_block_dim == 8, | |||||
"only imp for nchw44/nchw88"); | |||||
Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); | |||||
if (channel_block_dim == 4) { | |||||
OpCallerTernaryVecBcast101xXVec<typename Op::src_ctype, 4>::run( | |||||
src0, src1, src2, dst, op, batch, nr_channel_blocks, | |||||
channel_stride); | |||||
} else { | |||||
OpCallerTernaryVecBcast101xXVec<typename Op::src_ctype, 8>::run( | |||||
src0, src1, src2, dst, op, batch, nr_channel_blocks, | |||||
channel_stride); | |||||
} | |||||
} | |||||
}; | |||||
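
The dispatcher above turns the runtime channel_block_dim into a compile-time block width, so each branch instantiates the matching 4- or 8-lane specialization. A minimal sketch of the same pattern in isolation; BlockKernel and dispatch are illustrative names, not library code:

#include <cstddef>
#include <cstdio>

// Illustrative stand-in for the block-width-templated caller.
template <size_t Block>
struct BlockKernel {
    static void run() { std::printf("%zu-lane kernel\n", Block); }
};

// Mirrors the assert-then-branch shape of
// OpCallerTernary<Op, VEC_BCAST101xX_VEC>::run above.
static void dispatch(size_t channel_block_dim) {
    if (channel_block_dim == 4) {
        BlockKernel<4>::run();  // nchw44 path
    } else if (channel_block_dim == 8) {
        BlockKernel<8>::run();  // nchw88 path
    }
}

int main() {
    dispatch(4);
    dispatch(8);
    return 0;
}
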
//! src1: scalar, src0 and src2 has the same shape | //! src1: scalar, src0 and src2 has the same shape | ||||
template <typename Op> | template <typename Op> | ||||
struct OpCallerTernary<Op, VEC_SCALAR_VEC> { | struct OpCallerTernary<Op, VEC_SCALAR_VEC> { | ||||
@@ -53,6 +53,20 @@ TEST_F(ARM_COMMON, ELEMWISE_FORWARD_TERNARY) { | |||||
checker.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}}); | checker.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}}); | ||||
checker.execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}}); | checker.execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}}); | ||||
//! nchw88: src0 and src2 broadcast (1C11 per block) | |||||
checker.execs({{1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}}); | |||||
checker.execs({{1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}}); | |||||
checker.execs({{1, 8, 1, 1, 8}, {3, 8, 5, 3, 8}, {1, 8, 1, 1, 8}, {}}); | |||||
checker.execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}}); | |||||
checker.execs({{1, 2, 1, 1, 8}, {1, 2, 5, 7, 8}, {1, 2, 1, 1, 8}, {}}); | |||||
//! nchw88: src1 broadcast (1C11 per block) | |||||
checker.execs({{1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {}}); | |||||
checker.execs({{2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {}}); | |||||
checker.execs({{3, 8, 5, 3, 8}, {1, 8, 1, 1, 8}, {3, 8, 5, 3, 8}, {}}); | |||||
checker.execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}}); | |||||
checker.execs({{1, 2, 5, 7, 8}, {1, 2, 1, 1, 8}, {1, 2, 5, 7, 8}, {}}); | |||||
checker.execs({{3, 4, 7}, {3, 4, 7}, {3, 4, 7}, {}}); | checker.execs({{3, 4, 7}, {3, 4, 7}, {3, 4, 7}, {}}); | ||||
checker.execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}, {}}); | checker.execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}, {}}); | ||||
checker.execs({{1, 4, 1}, {3, 4, 7}, {1, 4, 1}, {}}); | checker.execs({{1, 4, 1}, {3, 4, 7}, {1, 4, 1}, {}}); | ||||
@@ -227,6 +241,78 @@ TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NCHW44_FP32) { | |||||
run(Mode::POW); | run(Mode::POW); | ||||
} | } | ||||
TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NCHW88_FP) { | |||||
using Mode = ElemwiseForward::Param::Mode; | |||||
Checker<ElemwiseForward> checker(handle()); | |||||
checker.set_param(Mode::FUSE_ADD_RELU) | |||||
.execs({{1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {}}); | |||||
checker.set_param(Mode::FUSE_ADD_RELU) | |||||
.execs({{1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {}}); | |||||
checker.set_param(Mode::FUSE_ADD_RELU) | |||||
.execs({{1, 8, 1, 1, 8}, {3, 8, 5, 3, 8}, {}}); | |||||
checker.set_param(Mode::FUSE_ADD_RELU) | |||||
.execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}}); | |||||
checker.set_param(Mode::FUSE_ADD_RELU) | |||||
.execs({{1, 2, 1, 1, 8}, {1, 2, 5, 7, 8}, {}}); | |||||
checker.set_param(Mode::FUSE_ADD_RELU) | |||||
.execs({{1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}}); | |||||
checker.set_param(Mode::FUSE_ADD_RELU) | |||||
.execs({{2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}}); | |||||
checker.set_param(Mode::FUSE_ADD_RELU) | |||||
.execs({{3, 8, 5, 3, 8}, {1, 8, 1, 1, 8}, {}}); | |||||
checker.set_param(Mode::FUSE_ADD_RELU) | |||||
.execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}}); | |||||
checker.set_param(Mode::FUSE_ADD_RELU) | |||||
.execs({{1, 2, 5, 7, 8}, {1, 2, 1, 1, 8}, {}}); | |||||
auto run = [&](Mode mode) { | |||||
// VEC_BCAST101x | |||||
checker.set_param(mode).execs({{1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}}); | |||||
checker.set_param(mode).execs({{2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}}); | |||||
checker.set_param(mode).execs({{3, 8, 5, 3, 8}, {1, 8, 1, 1, 8}, {}}); | |||||
checker.set_param(mode).execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}}); | |||||
checker.set_param(mode).execs({{1, 2, 5, 7, 8}, {1, 2, 1, 1, 8}, {}}); | |||||
// BCAST101x_VEC, not for powOp | |||||
checker.set_param(mode).execs({{1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {}}); | |||||
checker.set_param(mode).execs({{1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {}}); | |||||
checker.set_param(mode).execs({{1, 8, 1, 1, 8}, {3, 8, 5, 3, 8}, {}}); | |||||
checker.set_param(mode).execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}}); | |||||
checker.set_param(mode).execs({{1, 2, 1, 1, 8}, {1, 2, 5, 7, 8}, {}}); | |||||
}; | |||||
auto run_all = [&]() { | |||||
run(Mode::ADD); | |||||
run(Mode::FUSE_ADD_H_SWISH); | |||||
run(Mode::FUSE_ADD_RELU); | |||||
run(Mode::MAX); | |||||
run(Mode::MIN); | |||||
run(Mode::MUL); | |||||
run(Mode::SUB); | |||||
run(Mode::TRUE_DIV); | |||||
run(Mode::POW); | |||||
}; | |||||
{ | |||||
UniformFloatRNG rng(1e-5, 7e1); | |||||
checker.set_rng(0, &rng); | |||||
checker.set_epsilon(1e-5); | |||||
checker.set_dtype(0, dtype::Float32()); | |||||
checker.set_dtype(1, dtype::Float32()); | |||||
run_all(); | |||||
} | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
{ | |||||
UniformFloatRNG rng(1, 2); | |||||
checker.set_rng(0, &rng); | |||||
checker.set_epsilon(3e-3); | |||||
checker.set_dtype(0, dtype::Float16()); | |||||
checker.set_dtype(1, dtype::Float16()); | |||||
run_all(); | |||||
} | |||||
#endif | |||||
} | |||||
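
A note on the tolerances in the test above: binary16 keeps 10 fraction bits, so its machine epsilon is about 9.8e-4, which presumably is why the fp16 branch relaxes the checker to 3e-3 while float32 (epsilon about 1.2e-7) comfortably meets 1e-5; the narrower rng range for fp16 likely keeps TRUE_DIV and POW away from overflow as well. A quick standalone check, illustrative only and not part of the test suite:

#include <cmath>
#include <cstdio>

int main() {
    // Machine epsilon = spacing between 1.0 and the next representable value.
    std::printf("fp16 eps ~ %g\n", std::ldexp(1.0, -10));  // 10 fraction bits
    std::printf("fp32 eps ~ %g\n", std::ldexp(1.0, -23));  // 23 fraction bits
    return 0;
}
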
#if MEGDNN_WITH_BENCHMARK | #if MEGDNN_WITH_BENCHMARK | ||||
namespace { | namespace { | ||||
void run_elemwise_benchmark(const TensorShapeArray& shapes, | void run_elemwise_benchmark(const TensorShapeArray& shapes, | ||||