|
|
@@ -13,8 +13,8 @@ |
|
|
|
#pragma once |
|
|
|
|
|
|
|
#include "megdnn/basic_types.h" |
|
|
|
#include "src/arm_common/elemwise_helper/elemwise_op.h" |
|
|
|
#include "src/arm_common/elemwise_helper/kimpl/op_base.h" |
|
|
|
#include "src/arm_common/elemwise_op.h" |
|
|
|
#include "src/fallback/conv_bias/opr_impl.h" |
|
|
|
|
|
|
|
#include "midout.h" |
|
|
@@ -44,29 +44,29 @@ namespace { |
|
|
|
break; |
|
|
|
|
|
|
|
#define FOR_NONLINEAR_UNARY(_op) \ |
|
|
|
megdnn::arm_common::OpCallerUnary<_op<ctype>, megdnn::VEC>::run( \ |
|
|
|
megdnn::elemwise::OpCallerUnary<_op<ctype>, megdnn::elemwise::VEC>::run( \ |
|
|
|
static_cast<ctype*>(conv_dst_ptr), reinterpret_cast<ctype*>(dst_ptr), \ |
|
|
|
bias_type, dst_type, N* OC* OH* OW* pack_oc_size); |
|
|
|
|
|
|
|
#define FOR_NONLINEAR_BINARY_BROADCAST(_op) \ |
|
|
|
megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::VEC_BCAST101>::run( \ |
|
|
|
static_cast<ctype*>(conv_dst_ptr), \ |
|
|
|
reinterpret_cast<const ctype*>(bias_ptr), \ |
|
|
|
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, OC, \ |
|
|
|
#define FOR_NONLINEAR_BINARY_BROADCAST(_op) \ |
|
|
|
megdnn::elemwise::OpCallerBinary<_op<ctype>, megdnn::elemwise::VEC_BCAST101>::run( \ |
|
|
|
static_cast<ctype*>(conv_dst_ptr), \ |
|
|
|
reinterpret_cast<const ctype*>(bias_ptr), \ |
|
|
|
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, OC, \ |
|
|
|
OH* OW); |
|
|
|
|
|
|
|
#define FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX(_op) \ |
|
|
|
megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::VEC_BCAST101xX>::run( \ |
|
|
|
static_cast<ctype*>(conv_dst_ptr), \ |
|
|
|
reinterpret_cast<const ctype*>(bias_ptr), \ |
|
|
|
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, OC, \ |
|
|
|
OH* OW, pack_oc_size); |
|
|
|
|
|
|
|
#define FOR_NONLINEAR_BINARY(_op) \ |
|
|
|
megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::VEC_VEC>::run( \ |
|
|
|
static_cast<ctype*>(conv_dst_ptr), \ |
|
|
|
reinterpret_cast<const ctype*>(bias_ptr), \ |
|
|
|
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, \ |
|
|
|
megdnn::elemwise::OpCallerBinary<_op<ctype>, megdnn::elemwise::VEC_BCAST101xX>:: \ |
|
|
|
run(static_cast<ctype*>(conv_dst_ptr), \ |
|
|
|
reinterpret_cast<const ctype*>(bias_ptr), \ |
|
|
|
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, \ |
|
|
|
OC, OH* OW, pack_oc_size); |
|
|
|
|
|
|
|
#define FOR_NONLINEAR_BINARY(_op) \ |
|
|
|
megdnn::elemwise::OpCallerBinary<_op<ctype>, megdnn::elemwise::VEC_VEC>::run( \ |
|
|
|
static_cast<ctype*>(conv_dst_ptr), \ |
|
|
|
reinterpret_cast<const ctype*>(bias_ptr), \ |
|
|
|
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, \ |
|
|
|
N* OC* OH* OW* pack_oc_size); |
|
|
|
|
|
|
|
#define FOR_BIAS(_mode) \ |
|
|
@@ -167,33 +167,35 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> { |
|
|
|
#undef FOR_BIAS |
|
|
|
#undef HANDLE_IDENTITY |
|
|
|
|
|
|
|
#define FOR_NONLINEAR_UNARY(_op) \ |
|
|
|
megdnn::arm_common::OpCallerUnary<_op<opctype, opdtype>, megdnn::VEC>::run( \ |
|
|
|
static_cast<opctype*>(conv_dst_ptr), reinterpret_cast<opdtype*>(dst_ptr), \ |
|
|
|
bias_type, dst_type, N* OC* OH* OW* pack_oc_size); |
|
|
|
|
|
|
|
#define FOR_NONLINEAR_BINARY_BROADCAST(_op) \ |
|
|
|
megdnn::arm_common::OpCallerBinary<_op<opctype, opdtype>, megdnn::VEC_BCAST101>:: \ |
|
|
|
run(static_cast<opctype*>(conv_dst_ptr), \ |
|
|
|
reinterpret_cast<const opctype*>(bias_ptr), \ |
|
|
|
reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, dst_type, \ |
|
|
|
#define FOR_NONLINEAR_UNARY(_op) \ |
|
|
|
megdnn::elemwise::OpCallerUnary<_op<opctype, opdtype>, megdnn::elemwise::VEC>:: \ |
|
|
|
run(static_cast<opctype*>(conv_dst_ptr), \ |
|
|
|
reinterpret_cast<opdtype*>(dst_ptr), bias_type, dst_type, \ |
|
|
|
N* OC* OH* OW* pack_oc_size); |
|
|
|
|
|
|
|
#define FOR_NONLINEAR_BINARY_BROADCAST(_op) \ |
|
|
|
megdnn::elemwise::OpCallerBinary< \ |
|
|
|
_op<opctype, opdtype>, megdnn::elemwise::VEC_BCAST101>:: \ |
|
|
|
run(static_cast<opctype*>(conv_dst_ptr), \ |
|
|
|
reinterpret_cast<const opctype*>(bias_ptr), \ |
|
|
|
reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, dst_type, \ |
|
|
|
N, OC, OH* OW); |
|
|
|
|
|
|
|
#define FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX(_op) \ |
|
|
|
megdnn::arm_common:: \ |
|
|
|
OpCallerBinary<_op<opctype, opdtype>, megdnn::VEC_BCAST101xX>::run( \ |
|
|
|
static_cast<opctype*>(conv_dst_ptr), \ |
|
|
|
reinterpret_cast<const opctype*>(bias_ptr), \ |
|
|
|
reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, \ |
|
|
|
dst_type, N, OC, OH* OW, pack_oc_size); |
|
|
|
|
|
|
|
#define FOR_NONLINEAR_BINARY_BROADCAST_NCHW88(_op) \ |
|
|
|
megdnn::arm_common:: \ |
|
|
|
OpCallerBinary<_op<opctype, opdtype>, megdnn::VEC_BCAST101xX>::run( \ |
|
|
|
static_cast<opctype*>(conv_dst_ptr), \ |
|
|
|
reinterpret_cast<const opctype*>(bias_ptr), \ |
|
|
|
reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, \ |
|
|
|
dst_type, N, OC, OH* OW, pack_oc_size); |
|
|
|
#define FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX(_op) \ |
|
|
|
megdnn::elemwise::OpCallerBinary< \ |
|
|
|
_op<opctype, opdtype>, megdnn::elemwise::VEC_BCAST101xX>:: \ |
|
|
|
run(static_cast<opctype*>(conv_dst_ptr), \ |
|
|
|
reinterpret_cast<const opctype*>(bias_ptr), \ |
|
|
|
reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, dst_type, \ |
|
|
|
N, OC, OH* OW, pack_oc_size); |
|
|
|
|
|
|
|
#define FOR_NONLINEAR_BINARY_BROADCAST_NCHW88(_op) \ |
|
|
|
megdnn::elemwise::OpCallerBinary< \ |
|
|
|
_op<opctype, opdtype>, megdnn::elemwise::VEC_BCAST101xX>:: \ |
|
|
|
run(static_cast<opctype*>(conv_dst_ptr), \ |
|
|
|
reinterpret_cast<const opctype*>(bias_ptr), \ |
|
|
|
reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, dst_type, \ |
|
|
|
N, OC, OH* OW, pack_oc_size); |
|
|
|
|
|
|
|
#define HANDLE_IDENTITY(_caller, _op) \ |
|
|
|
case megdnn::NonlineMode::IDENTITY: \ |
|
|
@@ -267,25 +269,25 @@ struct PostProcess<opctype, opdtype, megdnn::PostprocessMode::QUANTIZED> { |
|
|
|
#undef FOR_NONLINEAR |
|
|
|
#undef FOR_BIAS |
|
|
|
|
|
|
|
#define FOR_BINARY_BROADCAST(_op) \ |
|
|
|
megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::VEC_BCAST101>::run( \ |
|
|
|
static_cast<ctype*>(conv_dst_ptr), \ |
|
|
|
reinterpret_cast<const ctype*>(bias_ptr), \ |
|
|
|
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, OC, \ |
|
|
|
#define FOR_BINARY_BROADCAST(_op) \ |
|
|
|
megdnn::elemwise::OpCallerBinary<_op<ctype>, megdnn::elemwise::VEC_BCAST101>::run( \ |
|
|
|
static_cast<ctype*>(conv_dst_ptr), \ |
|
|
|
reinterpret_cast<const ctype*>(bias_ptr), \ |
|
|
|
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, OC, \ |
|
|
|
OH* OW); |
|
|
|
|
|
|
|
#define FOR_BINARY_BROADCAST_NCHWXX(_op) \ |
|
|
|
megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::VEC_BCAST101xX>::run( \ |
|
|
|
static_cast<ctype*>(conv_dst_ptr), \ |
|
|
|
reinterpret_cast<const ctype*>(bias_ptr), \ |
|
|
|
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, OC, \ |
|
|
|
OH* OW, pack_oc_size); |
|
|
|
|
|
|
|
#define FOR_BINARY(_op) \ |
|
|
|
megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::VEC_VEC>::run( \ |
|
|
|
static_cast<ctype*>(conv_dst_ptr), \ |
|
|
|
reinterpret_cast<const ctype*>(bias_ptr), \ |
|
|
|
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, \ |
|
|
|
megdnn::elemwise::OpCallerBinary<_op<ctype>, megdnn::elemwise::VEC_BCAST101xX>:: \ |
|
|
|
run(static_cast<ctype*>(conv_dst_ptr), \ |
|
|
|
reinterpret_cast<const ctype*>(bias_ptr), \ |
|
|
|
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, \ |
|
|
|
OC, OH* OW, pack_oc_size); |
|
|
|
|
|
|
|
#define FOR_BINARY(_op) \ |
|
|
|
megdnn::elemwise::OpCallerBinary<_op<ctype>, megdnn::elemwise::VEC_VEC>::run( \ |
|
|
|
static_cast<ctype*>(conv_dst_ptr), \ |
|
|
|
reinterpret_cast<const ctype*>(bias_ptr), \ |
|
|
|
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, \ |
|
|
|
N* OC* OH* OW* pack_oc_size); |
|
|
|
|
|
|
|
#define FOR_BIAS(_bias_mode, OH, OW) \ |
|
|
|