GitOrigin-RevId: fe7b335545
release-1.10
@@ -15,7 +15,7 @@ | |||||
#include "src/arm_common/elemwise_helper/op_binary.h" | #include "src/arm_common/elemwise_helper/op_binary.h" | ||||
#include "src/arm_common/elemwise_helper/op_ternary.h" | #include "src/arm_common/elemwise_helper/op_ternary.h" | ||||
#include "src/arm_common/elemwise_helper/op_unary.h" | #include "src/arm_common/elemwise_helper/op_unary.h" | ||||
#include "src/fallback/elemwise_helper/elemwise_op.h" | |||||
#include "src/fallback/elemwise_helper/op_common.h" | |||||
namespace megdnn { | namespace megdnn { | ||||
namespace elemwise { | namespace elemwise { | ||||
@@ -364,19 +364,11 @@ void ElemwiseMultiTypeImpl::on_quantized_mode( | |||||
} | } | ||||
#define DISPATCH() \ | #define DISPATCH() \ | ||||
if (param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS8 && \ | |||||
dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { \ | |||||
DISPATCH_QUANTIZED_MODE(dtype::QuantizedS8, dtype::QuantizedS8) \ | |||||
} else if ( \ | |||||
param[0].layout.dtype.enumv() == DTypeEnum::Quantized8Asymm && \ | |||||
dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) { \ | |||||
if (param[0].layout.dtype.enumv() == DTypeEnum::Quantized8Asymm && \ | |||||
dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) { \ | |||||
DISPATCH_QUANTIZED_MODE(dtype::Quantized8Asymm, dtype::Quantized8Asymm) \ | DISPATCH_QUANTIZED_MODE(dtype::Quantized8Asymm, dtype::Quantized8Asymm) \ | ||||
} else if ( \ | } else if ( \ | ||||
param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS32 && \ | param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS32 && \ | ||||
dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { \ | |||||
DISPATCH_MODE(dtype::QuantizedS32, dtype::QuantizedS8) \ | |||||
} else if ( \ | |||||
param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS32 && \ | |||||
dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) { \ | dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) { \ | ||||
DISPATCH_MODE(dtype::QuantizedS32, dtype::Quantized8Asymm) \ | DISPATCH_MODE(dtype::QuantizedS32, dtype::Quantized8Asymm) \ | ||||
} | } | ||||
@@ -467,17 +459,9 @@ void ElemwiseMultiTypeImpl::on_quantized_mode( | |||||
#define DISPATCH() \ | #define DISPATCH() \ | ||||
if (param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS32 && \ | if (param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS32 && \ | ||||
dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { \ | |||||
DISPATCH_MODE(dtype::QuantizedS32, dtype::QuantizedS8) \ | |||||
} else if ( \ | |||||
param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS32 && \ | |||||
dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) { \ | |||||
dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) { \ | |||||
DISPATCH_MODE(dtype::QuantizedS32, dtype::Quantized8Asymm) \ | DISPATCH_MODE(dtype::QuantizedS32, dtype::Quantized8Asymm) \ | ||||
} else if ( \ | } else if ( \ | ||||
param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS8 && \ | |||||
dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { \ | |||||
DISPATCH_QUANTIZED_MODE(dtype::QuantizedS8, dtype::QuantizedS8) \ | |||||
} else if ( \ | |||||
param[0].layout.dtype.enumv() == DTypeEnum::Quantized8Asymm && \ | param[0].layout.dtype.enumv() == DTypeEnum::Quantized8Asymm && \ | ||||
dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) { \ | dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) { \ | ||||
DISPATCH_QUANTIZED_MODE(dtype::Quantized8Asymm, dtype::Quantized8Asymm) \ | DISPATCH_QUANTIZED_MODE(dtype::Quantized8Asymm, dtype::Quantized8Asymm) \ | ||||
@@ -701,12 +685,8 @@ void ElemwiseMultiTypeImpl::on_quantized_mode( | |||||
} | } | ||||
#define DISPATCH() \ | #define DISPATCH() \ | ||||
if (param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS8 && \ | |||||
dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { \ | |||||
DISPATCH_QUANTIZED_MODE(dtype::QuantizedS8, dtype::QuantizedS8) \ | |||||
} else if ( \ | |||||
param[0].layout.dtype.enumv() == DTypeEnum::Quantized8Asymm && \ | |||||
dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) { \ | |||||
if (param[0].layout.dtype.enumv() == DTypeEnum::Quantized8Asymm && \ | |||||
dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) { \ | |||||
DISPATCH_QUANTIZED_MODE(dtype::Quantized8Asymm, dtype::Quantized8Asymm) \ | DISPATCH_QUANTIZED_MODE(dtype::Quantized8Asymm, dtype::Quantized8Asymm) \ | ||||
} | } | ||||
@@ -12,61 +12,4 @@ | |||||
#include "src/fallback/general_intrinsic/gi_float.h" | #include "src/fallback/general_intrinsic/gi_float.h" | ||||
#include "src/fallback/general_intrinsic/gi_int.h" | #include "src/fallback/general_intrinsic/gi_int.h" | ||||
namespace megdnn { | |||||
namespace elemwise { | |||||
///////////////////////////////// ParamElemVistor /////////////////////////// | |||||
#define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix) \ | |||||
template <> \ | |||||
struct ParamElemVisitor<_ctype> { \ | |||||
_simd_type operator()(const _ctype* src) const { \ | |||||
return GiLoad##_fun_suffix(src); \ | |||||
} \ | |||||
}; \ | |||||
template <> \ | |||||
struct ParamElemVisitorDup<_ctype> { \ | |||||
_simd_type operator()(const _ctype* src) const { \ | |||||
return GiBroadcast##_fun_suffix( \ | |||||
*reinterpret_cast<const _inner_ctype*>(src)); \ | |||||
} \ | |||||
} | |||||
cb(dt_qint32, int32_t, GI_INT32_t, Int32); | |||||
cb(dt_qint8, int8_t, GI_INT8_t, Int8); | |||||
cb(dt_float32, float, GI_FLOAT32_t, Float32); | |||||
cb(dt_int32, int32_t, GI_INT32_t, Int32); | |||||
cb(dt_int8, int8_t, GI_INT8_t, Int8); | |||||
#undef cb | |||||
template <typename ctype> | |||||
struct ParamElemVisitorBcast101x4; | |||||
#define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix, rel_suffix) \ | |||||
template <> \ | |||||
struct ParamElemVisitorBcast101x4<_ctype> { \ | |||||
_simd_type operator()(const _ctype* src) const { \ | |||||
return GiReinter##rel_suffix##To##_fun_suffix(GiBroadcast##rel_suffix( \ | |||||
*reinterpret_cast<const _inner_ctype*>(src))); \ | |||||
} \ | |||||
} | |||||
cb(dt_qint8, int32_t, GI_INT8_t, Int8, Int32); | |||||
cb(dt_int8, int32_t, GI_INT8_t, Int8, Int32); | |||||
#undef cb | |||||
#define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix) \ | |||||
template <> \ | |||||
struct ParamElemVisitorBcast101x4<_ctype> { \ | |||||
_simd_type operator()(const _ctype* src) const { \ | |||||
return GiLoad##_fun_suffix(src); \ | |||||
} \ | |||||
} | |||||
cb(dt_qint32, int32_t, GI_INT32_t, Int32); | |||||
cb(dt_float32, float, GI_FLOAT32_t, Float32); | |||||
cb(dt_int32, int32_t, GI_INT32_t, Int32); | |||||
#undef cb | |||||
} // namespace elemwise | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -87,7 +87,7 @@ template <> | |||||
struct FuseAddHSwishOp<dt_qint32, dt_qint8> : FuseAddHSwishOpBase<dt_qint32, dt_qint8> { | struct FuseAddHSwishOp<dt_qint32, dt_qint8> : FuseAddHSwishOpBase<dt_qint32, dt_qint8> { | ||||
using FuseAddHSwishOpBase::FuseAddHSwishOpBase; | using FuseAddHSwishOpBase::FuseAddHSwishOpBase; | ||||
using FuseAddHSwishOpBase::operator(); | using FuseAddHSwishOpBase::operator(); | ||||
constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t); | |||||
constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int32_t); | |||||
void operator()( | void operator()( | ||||
const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1, | const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1, | ||||
dt_qint8* dst) const { | dt_qint8* dst) const { | ||||
@@ -41,7 +41,7 @@ struct UnaryOpBase : OpBase<src_ctype, dst_ctype> { | |||||
GiStoreLowInt8( \ | GiStoreLowInt8( \ | ||||
reinterpret_cast<int8_t*>(dst + 8), \ | reinterpret_cast<int8_t*>(dst + 8), \ | ||||
operator()({{GiMoveLowLongInt16(vsrct1), GiMoveHighLongInt16(vsrct1)}})); \ | operator()({{GiMoveLowLongInt16(vsrct1), GiMoveHighLongInt16(vsrct1)}})); \ | ||||
GI_INT16_t vsrct2 = GiMoveHighLongInt8(vsrc.val[1]); \ | |||||
GI_INT16_t vsrct2 = GiMoveLowLongInt8(vsrc.val[1]); \ | |||||
GiStoreLowInt8( \ | GiStoreLowInt8( \ | ||||
reinterpret_cast<int8_t*>(dst + 16), \ | reinterpret_cast<int8_t*>(dst + 16), \ | ||||
operator()({{GiMoveLowLongInt16(vsrct2), GiMoveHighLongInt16(vsrct2)}})); \ | operator()({{GiMoveLowLongInt16(vsrct2), GiMoveHighLongInt16(vsrct2)}})); \ | ||||
@@ -330,7 +330,7 @@ struct UnaryQuantizationOp; | |||||
template <typename Op> | template <typename Op> | ||||
struct UnaryQuantizationOp<dt_qint8, dt_qint8, Op> : UnaryOpBase<dt_qint8, dt_qint8> { | struct UnaryQuantizationOp<dt_qint8, dt_qint8, Op> : UnaryOpBase<dt_qint8, dt_qint8> { | ||||
using UnaryOpBase<dt_qint8, dt_qint8>::UnaryOpBase; | using UnaryOpBase<dt_qint8, dt_qint8>::UnaryOpBase; | ||||
constexpr static size_t SIMD_WIDTH = 16; | |||||
constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t); | |||||
Op op; | Op op; | ||||
void operator()(const dt_qint8& src, dt_qint8* dst) const { | void operator()(const dt_qint8& src, dt_qint8* dst) const { | ||||
@@ -354,7 +354,7 @@ struct UnaryQuantizationOp<dt_qint8, dt_qint8, Op> : UnaryOpBase<dt_qint8, dt_qi | |||||
auto val = this->op({{vitem0, vitem1}}); | auto val = this->op({{vitem0, vitem1}}); | ||||
val.val[0] = GiMultiplyFloat32(val.val[0], this->vscale_dst); | val.val[0] = GiMultiplyFloat32(val.val[0], this->vscale_dst); | ||||
val.val[1] = GiMultiplyFloat32(val.val[1], this->vscale_dst); | val.val[1] = GiMultiplyFloat32(val.val[1], this->vscale_dst); | ||||
return QConverter::convert<GI_INT8_t, GI_FLOAT32_V4_t>(val); | |||||
return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>(val); | |||||
} | } | ||||
}; | }; | ||||
@@ -364,7 +364,7 @@ struct BinaryQuantizationOp; | |||||
template <typename Op> | template <typename Op> | ||||
struct BinaryQuantizationOp<dt_qint8, dt_qint8, Op> : BinaryOpBase<dt_qint8, dt_qint8> { | struct BinaryQuantizationOp<dt_qint8, dt_qint8, Op> : BinaryOpBase<dt_qint8, dt_qint8> { | ||||
using BinaryOpBase<dt_qint8, dt_qint8>::BinaryOpBase; | using BinaryOpBase<dt_qint8, dt_qint8>::BinaryOpBase; | ||||
constexpr static size_t SIMD_WIDTH = 16; | |||||
constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t); | |||||
Op op; | Op op; | ||||
void operator()(const dt_qint8& src0, const dt_qint8& src1, dt_qint8* dst) const { | void operator()(const dt_qint8& src0, const dt_qint8& src1, dt_qint8* dst) const { | ||||
@@ -403,7 +403,7 @@ template <typename Op> | |||||
struct TernaryQuantizationOp<dt_qint8, dt_qint8, Op> | struct TernaryQuantizationOp<dt_qint8, dt_qint8, Op> | ||||
: TernaryOpBase<dt_qint8, dt_qint8> { | : TernaryOpBase<dt_qint8, dt_qint8> { | ||||
using TernaryOpBase<dt_qint8, dt_qint8>::TernaryOpBase; | using TernaryOpBase<dt_qint8, dt_qint8>::TernaryOpBase; | ||||
constexpr static size_t SIMD_WIDTH = 16; | |||||
constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t); | |||||
Op op; | Op op; | ||||
void operator()( | void operator()( | ||||
@@ -69,7 +69,7 @@ struct ReluOpBase<dt_qint8, dt_qint8> : UnaryOpBase<dt_qint8, dt_qint8> { | |||||
template <> | template <> | ||||
struct ReluOp<dt_qint8, dt_qint8> : ReluOpBase<dt_qint8, dt_qint8> { | struct ReluOp<dt_qint8, dt_qint8> : ReluOpBase<dt_qint8, dt_qint8> { | ||||
using ReluOpBase::ReluOpBase; | using ReluOpBase::ReluOpBase; | ||||
constexpr static size_t SIMD_WIDTH = 16; | |||||
constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t); | |||||
using ReluOpBase::operator(); | using ReluOpBase::operator(); | ||||
void operator()(const GI_INT8_V2_t& vsrc, dt_qint8* dst) const { | void operator()(const GI_INT8_V2_t& vsrc, dt_qint8* dst) const { | ||||
@@ -8,6 +8,7 @@ | |||||
namespace megdnn { | namespace megdnn { | ||||
namespace elemwise { | namespace elemwise { | ||||
/*! | /*! | ||||
* \brief broadcast type | * \brief broadcast type | ||||
* BCAST_x[0]x[1]...: x[i] == !stride[i] | * BCAST_x[0]x[1]...: x[i] == !stride[i] | ||||
@@ -49,6 +50,55 @@ struct ParamElemVisitorDup; | |||||
template <typename ctype> | template <typename ctype> | ||||
struct ParamElemVisitorBcast101x4; | struct ParamElemVisitorBcast101x4; | ||||
#define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix) \ | |||||
template <> \ | |||||
struct ParamElemVisitor<_ctype> { \ | |||||
_simd_type operator()(const _ctype* src) const { \ | |||||
return GiLoad##_fun_suffix(src); \ | |||||
} \ | |||||
}; \ | |||||
template <> \ | |||||
struct ParamElemVisitorDup<_ctype> { \ | |||||
_simd_type operator()(const _ctype* src) const { \ | |||||
return GiBroadcast##_fun_suffix( \ | |||||
*reinterpret_cast<const _inner_ctype*>(src)); \ | |||||
} \ | |||||
} | |||||
cb(dt_qint32, int32_t, GI_INT32_t, Int32); | |||||
cb(dt_qint8, int8_t, GI_INT8_t, Int8); | |||||
cb(dt_float32, float, GI_FLOAT32_t, Float32); | |||||
cb(dt_int32, int32_t, GI_INT32_t, Int32); | |||||
cb(dt_int8, int8_t, GI_INT8_t, Int8); | |||||
#undef cb | |||||
template <typename ctype> | |||||
struct ParamElemVisitorBcast101x4; | |||||
#define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix, rel_suffix) \ | |||||
template <> \ | |||||
struct ParamElemVisitorBcast101x4<_ctype> { \ | |||||
_simd_type operator()(const _ctype* src) const { \ | |||||
return GiReinter##rel_suffix##To##_fun_suffix(GiBroadcast##rel_suffix( \ | |||||
*reinterpret_cast<const _inner_ctype*>(src))); \ | |||||
} \ | |||||
} | |||||
cb(dt_qint8, int32_t, GI_INT8_t, Int8, Int32); | |||||
cb(dt_int8, int32_t, GI_INT8_t, Int8, Int32); | |||||
#undef cb | |||||
#define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix) \ | |||||
template <> \ | |||||
struct ParamElemVisitorBcast101x4<_ctype> { \ | |||||
_simd_type operator()(const _ctype* src) const { \ | |||||
return GiLoad##_fun_suffix(src); \ | |||||
} \ | |||||
} | |||||
cb(dt_qint32, int32_t, GI_INT32_t, Int32); | |||||
cb(dt_float32, float, GI_FLOAT32_t, Float32); | |||||
cb(dt_int32, int32_t, GI_INT32_t, Int32); | |||||
#undef cb | |||||
///////////////////////////////// OpCaller ///////////////////////////// | ///////////////////////////////// OpCaller ///////////////////////////// | ||||
template <typename Op, BcastType bcast_type> | template <typename Op, BcastType bcast_type> | ||||
struct OpCallerUnary; | struct OpCallerUnary; | ||||
@@ -50,6 +50,18 @@ protected: | |||||
void on_fuse_mul_add3_uint8xf32xf32xf32( | void on_fuse_mul_add3_uint8xf32xf32xf32( | ||||
const ElemwiseOpParamN<3>& param, const TensorND& dst) override; | const ElemwiseOpParamN<3>& param, const TensorND& dst) override; | ||||
void on_quantized_mode( | |||||
const ElemwiseOpParamN<1>& param, const TensorND& dst, | |||||
Elemwise::Mode mode) override; | |||||
void on_quantized_mode( | |||||
const ElemwiseOpParamN<2>& param, const TensorND& dst, | |||||
Elemwise::Mode mode) override; | |||||
void on_quantized_mode( | |||||
const ElemwiseOpParamN<3>& param, const TensorND& dst, | |||||
Elemwise::Mode mode) override; | |||||
public: | public: | ||||
using naive::ElemwiseMultiTypeImpl::ElemwiseMultiTypeImpl; | using naive::ElemwiseMultiTypeImpl::ElemwiseMultiTypeImpl; | ||||
}; | }; | ||||
@@ -0,0 +1,499 @@ | |||||
/** | |||||
* \file dnn/src/fallback/elemwise_multi_type/quantized_impl.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "megdnn/tensor_iter.h" | |||||
#include "src/fallback/elemwise_helper/elemwise_op.h" | |||||
#include "src/fallback/elemwise_multi_type/opr_impl.h" | |||||
#include "src/naive/handle.h" | |||||
using namespace megdnn; | |||||
using namespace fallback; | |||||
using namespace elemwise; | |||||
void ElemwiseMultiTypeImpl::on_quantized_mode( | |||||
const ElemwiseOpParamN<1>& param, const TensorND& dst, Elemwise::Mode mode) { | |||||
megdnn_assert(param[0].layout.dtype.category() == DTypeCategory::QUANTIZED); | |||||
megdnn_assert(dst.layout.dtype.category() == DTypeCategory::QUANTIZED); | |||||
#define DISPATCH_MODE(_src_dt, _dst_dt) \ | |||||
switch (mode) { \ | |||||
DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::RELU, ReluOp) \ | |||||
DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::H_SWISH, HSwishOp) \ | |||||
default: \ | |||||
break; \ | |||||
} | |||||
#define DISPATCH_QUANTIZED_MODE(_src_dt, _dst_dt) \ | |||||
switch (mode) { \ | |||||
DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::RELU, ReluOp) \ | |||||
DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::ABS, AbsOp) \ | |||||
DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::SIGMOID, SigmoidOp) \ | |||||
DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::EXP, ExpOp) \ | |||||
DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::TANH, TanhOp) \ | |||||
DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::FAST_TANH, FastTanhOp) \ | |||||
DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::H_SWISH, HSwishOp) \ | |||||
default: \ | |||||
break; \ | |||||
} | |||||
#define DISPATCH() \ | |||||
if (param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS8 && \ | |||||
dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { \ | |||||
DISPATCH_QUANTIZED_MODE(dtype::QuantizedS8, dtype::QuantizedS8) \ | |||||
} else if ( \ | |||||
param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS32 && \ | |||||
dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { \ | |||||
DISPATCH_MODE(dtype::QuantizedS32, dtype::QuantizedS8) \ | |||||
} | |||||
TensorND src = param[0]; | |||||
size_t nr_elems = src.layout.total_nr_elems(); | |||||
#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ | |||||
case _mode: { \ | |||||
using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ | |||||
using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ | |||||
thin_function<void(const src_ctype*, dst_ctype*, DType, DType, size_t)> run = \ | |||||
OpCallerUnary<_op<src_ctype, dst_ctype>, VEC>::run; \ | |||||
MEGDNN_DISPATCH_CPU_KERN_OPR( \ | |||||
run(src.ptr<src_ctype>(), dst.ptr<dst_ctype>(), src.layout.dtype, \ | |||||
dst.layout.dtype, nr_elems)); \ | |||||
return; \ | |||||
} | |||||
DISPATCH() | |||||
naive::ElemwiseMultiTypeImpl::on_quantized_mode(param, dst, mode); | |||||
#undef DISPATCH_SINGLE_MODE | |||||
#undef DISPATCH | |||||
#undef DISPATCH_QUANTIZED_MODE | |||||
#undef DISPATCH_MODE | |||||
} | |||||
void ElemwiseMultiTypeImpl::on_quantized_mode( | |||||
const ElemwiseOpParamN<2>& param, const TensorND& dst, Elemwise::Mode mode) { | |||||
megdnn_assert( | |||||
param[0].layout.dtype.enumv() == param[1].layout.dtype.enumv() && | |||||
param[0].layout.dtype.category() == DTypeCategory::QUANTIZED); | |||||
megdnn_assert(dst.layout.dtype.category() == DTypeCategory::QUANTIZED); | |||||
#define DISPATCH_MODE(_src_dt, _dst_dt) \ | |||||
switch (mode) { \ | |||||
DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::ADD, AddOp) \ | |||||
DISPATCH_SINGLE_MODE( \ | |||||
_src_dt, _dst_dt, Elemwise::Mode::FUSE_ADD_RELU, FuseAddReluOp) \ | |||||
DISPATCH_SINGLE_MODE( \ | |||||
_src_dt, _dst_dt, Elemwise::Mode::FUSE_ADD_H_SWISH, FuseAddHSwishOp) \ | |||||
default: \ | |||||
break; \ | |||||
} | |||||
#define DISPATCH_QUANTIZED_MODE(_src_dt, _dst_dt) \ | |||||
switch (mode) { \ | |||||
DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::ADD, AddOp) \ | |||||
DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::MIN, MinOp) \ | |||||
DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::MAX, MaxOp) \ | |||||
DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::SUB, SubOp) \ | |||||
DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::MUL, MulOp) \ | |||||
DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::TRUE_DIV, TrueDivOp) \ | |||||
DISPATCH_SINGLE_MODE( \ | |||||
_src_dt, _dst_dt, Elemwise::Mode::FUSE_ADD_RELU, FuseAddReluOp) \ | |||||
DISPATCH_SINGLE_MODE( \ | |||||
_src_dt, _dst_dt, Elemwise::Mode::FUSE_ADD_SIGMOID, FuseAddSigmoidOp) \ | |||||
DISPATCH_SINGLE_MODE( \ | |||||
_src_dt, _dst_dt, Elemwise::Mode::FUSE_ADD_TANH, FuseAddTanhOp) \ | |||||
DISPATCH_SINGLE_MODE( \ | |||||
_src_dt, _dst_dt, Elemwise::Mode::FUSE_ADD_H_SWISH, FuseAddHSwishOp) \ | |||||
default: \ | |||||
break; \ | |||||
} | |||||
#define DISPATCH() \ | |||||
if (param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS32 && \ | |||||
dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { \ | |||||
DISPATCH_MODE(dtype::QuantizedS32, dtype::QuantizedS8) \ | |||||
} else if ( \ | |||||
param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS8 && \ | |||||
dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { \ | |||||
DISPATCH_QUANTIZED_MODE(dtype::QuantizedS8, dtype::QuantizedS8) \ | |||||
} | |||||
TensorND src0 = param[0]; | |||||
TensorND src1 = param[1]; | |||||
//! VEC + VEC | |||||
if (is_vector(src0.layout) && is_vector(src1.layout)) { | |||||
size_t nr_elems = src0.layout.total_nr_elems(); | |||||
#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ | |||||
case _mode: { \ | |||||
using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ | |||||
using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ | |||||
thin_function<void( \ | |||||
const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \ | |||||
size_t)> \ | |||||
run = OpCallerBinary<_op<src_ctype, dst_ctype>, VEC_VEC>::run; \ | |||||
MEGDNN_DISPATCH_CPU_KERN_OPR(run( \ | |||||
src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), dst.ptr<dst_ctype>(), \ | |||||
src0.layout.dtype, src1.layout.dtype, dst.layout.dtype, nr_elems)); \ | |||||
return; \ | |||||
} | |||||
DISPATCH() | |||||
#undef DISPATCH_SINGLE_MODE | |||||
} | |||||
//! VEC + SCALAR | |||||
{ | |||||
bool normal_case = is_vector(src0.layout) && is_broadcasted_scalar(src1.layout); | |||||
bool swap_case = false; | |||||
bool commutable = false; | |||||
if (mode != Elemwise::Mode::SUB && mode != Elemwise::Mode::TRUE_DIV) | |||||
commutable = true; | |||||
if (!normal_case && commutable) { | |||||
swap_case = is_vector(src1.layout) && is_broadcasted_scalar(src0.layout); | |||||
} | |||||
if (normal_case || swap_case) { | |||||
auto &lhs = src0, &rhs = src1; | |||||
if (swap_case) { | |||||
std::swap(lhs, rhs); | |||||
} | |||||
#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ | |||||
case _mode: { \ | |||||
using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ | |||||
using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ | |||||
thin_function<void( \ | |||||
const src_ctype*, const src_ctype, dst_ctype*, DType, DType, DType, \ | |||||
size_t)> \ | |||||
run = OpCallerBinary<_op<src_ctype, dst_ctype>, VEC_SCALAR>::run; \ | |||||
MEGDNN_DISPATCH_CPU_KERN_OPR( \ | |||||
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>()[0], \ | |||||
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \ | |||||
dst.layout.dtype, src0.layout.total_nr_elems())); \ | |||||
return; \ | |||||
} | |||||
DISPATCH() | |||||
#undef DISPATCH_SINGLE_MODE | |||||
} | |||||
//! SCALAR + VEC | |||||
if (!commutable && is_vector(src1.layout) && | |||||
is_broadcasted_scalar(src0.layout)) { | |||||
#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ | |||||
case _mode: { \ | |||||
using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ | |||||
using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ | |||||
thin_function<void( \ | |||||
const src_ctype, const src_ctype*, dst_ctype*, DType, DType, DType, \ | |||||
size_t)> \ | |||||
run = OpCallerBinary<_op<src_ctype, dst_ctype>, SCALAR_VEC>::run; \ | |||||
MEGDNN_DISPATCH_CPU_KERN_OPR( \ | |||||
run(src0.ptr<src_ctype>()[0], src1.ptr<src_ctype>(), \ | |||||
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \ | |||||
dst.layout.dtype, src1.layout.total_nr_elems())); \ | |||||
return; \ | |||||
} | |||||
DISPATCH() | |||||
#undef DISPATCH_SINGLE_MODE | |||||
} | |||||
} | |||||
//! VEC + BCAST101 | |||||
{ | |||||
BroadcastChannelInfo binfo; | |||||
bool normal_case = is_vector(src0.layout) && | |||||
is_broadcasted_channel_like(src1.layout, binfo); | |||||
bool swap_case = false; | |||||
bool commutable = false; | |||||
if (mode != Elemwise::Mode::SUB && mode != Elemwise::Mode::TRUE_DIV) | |||||
commutable = true; | |||||
if (!normal_case && commutable) { | |||||
swap_case = is_vector(src1.layout) && | |||||
is_broadcasted_channel_like(src0.layout, binfo); | |||||
} | |||||
if (normal_case || swap_case) { | |||||
auto &lhs = src0, &rhs = src1; | |||||
if (swap_case) | |||||
std::swap(lhs, rhs); | |||||
#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ | |||||
case _mode: { \ | |||||
using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ | |||||
using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ | |||||
thin_function<void( \ | |||||
const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \ | |||||
size_t, size_t, size_t)> \ | |||||
run = OpCallerBinary<_op<src_ctype, dst_ctype>, VEC_BCAST101>::run; \ | |||||
MEGDNN_DISPATCH_CPU_KERN_OPR( \ | |||||
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \ | |||||
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \ | |||||
dst.layout.dtype, binfo.x, binfo.y, binfo.z)); \ | |||||
return; \ | |||||
} | |||||
DISPATCH() | |||||
#undef DISPATCH_SINGLE_MODE | |||||
} | |||||
//! BCAST101 + VEC : only for SUB or TRUE_DIV | |||||
if (!commutable && is_vector(src1.layout) && | |||||
is_broadcasted_channel_like(src0.layout, binfo)) { | |||||
#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ | |||||
case _mode: { \ | |||||
using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ | |||||
using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ | |||||
thin_function<void( \ | |||||
const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \ | |||||
size_t, size_t, size_t)> \ | |||||
run = OpCallerBinary<_op<src_ctype, dst_ctype>, BCAST101_VEC>::run; \ | |||||
MEGDNN_DISPATCH_CPU_KERN_OPR( \ | |||||
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \ | |||||
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \ | |||||
dst.layout.dtype, binfo.x, binfo.y, binfo.z)); \ | |||||
return; \ | |||||
} | |||||
DISPATCH() | |||||
#undef DISPATCH_SINGLE_MODE | |||||
} | |||||
} | |||||
//! VEC + BCAST101x4 | |||||
{ | |||||
BroadcastChannelInfo binfo; | |||||
if (is_vector(src0.layout) && | |||||
(is_broadcastedx_channel_like<4>(src1.layout, binfo) || | |||||
is_broadcastedx_channel_like<8>(src1.layout, binfo))) { | |||||
#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ | |||||
case _mode: { \ | |||||
using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ | |||||
using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ | |||||
thin_function<void( \ | |||||
const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \ | |||||
size_t, size_t, size_t, size_t)> \ | |||||
run = OpCallerBinary<_op<src_ctype, dst_ctype>, VEC_BCAST101xX>::run; \ | |||||
MEGDNN_DISPATCH_CPU_KERN_OPR( \ | |||||
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \ | |||||
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \ | |||||
dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \ | |||||
return; \ | |||||
} | |||||
size_t batch_size = src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z); | |||||
DISPATCH() | |||||
#undef DISPATCH_SINGLE_MODE | |||||
} | |||||
//! BCAST101x + VEC | |||||
if (is_vector(src1.layout) && | |||||
is_broadcastedx_channel_like<4>(src0.layout, binfo)) { | |||||
#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ | |||||
case _mode: { \ | |||||
using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ | |||||
using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ | |||||
thin_function<void( \ | |||||
const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \ | |||||
size_t, size_t, size_t, size_t)> \ | |||||
run = OpCallerBinary<_op<src_ctype, dst_ctype>, BCAST101xX_VEC>::run; \ | |||||
MEGDNN_DISPATCH_CPU_KERN_OPR( \ | |||||
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \ | |||||
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \ | |||||
dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \ | |||||
return; \ | |||||
} | |||||
size_t batch_size = src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z); | |||||
DISPATCH() | |||||
#undef DISPATCH_SINGLE_MODE | |||||
} | |||||
} | |||||
naive::ElemwiseMultiTypeImpl::on_quantized_mode(param, dst, mode); | |||||
#undef DISPATCH_MODE | |||||
#undef DISPATCH_QUANTIZED_MODE | |||||
#undef DISPATCH | |||||
} | |||||
void ElemwiseMultiTypeImpl::on_quantized_mode( | |||||
const ElemwiseOpParamN<3>& param, const TensorND& dst, Elemwise::Mode mode) { | |||||
megdnn_assert( | |||||
param[0].layout.dtype.enumv() == param[1].layout.dtype.enumv() && | |||||
param[0].layout.dtype.enumv() == param[2].layout.dtype.enumv() && | |||||
param[0].layout.dtype.category() == DTypeCategory::QUANTIZED); | |||||
megdnn_assert(dst.layout.dtype.category() == DTypeCategory::QUANTIZED); | |||||
#define DISPATCH_QUANTIZED_MODE(_src_dt, _dst_dt) \ | |||||
switch (mode) { \ | |||||
DISPATCH_SINGLE_MODE( \ | |||||
_src_dt, _dst_dt, Elemwise::Mode::FUSE_MUL_ADD3, FuseMulAdd3Op) \ | |||||
default: \ | |||||
break; \ | |||||
} | |||||
#define DISPATCH() \ | |||||
if (param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS8 && \ | |||||
dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { \ | |||||
DISPATCH_QUANTIZED_MODE(dtype::QuantizedS8, dtype::QuantizedS8) \ | |||||
} | |||||
TensorND src0 = param[0]; | |||||
TensorND src1 = param[1]; | |||||
TensorND src2 = param[2]; | |||||
//! VEC + VEC + VEC | |||||
if (is_vector(src0.layout) && is_vector(src1.layout) && is_vector(src2.layout)) { | |||||
size_t nr_elems = src0.layout.total_nr_elems(); | |||||
#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ | |||||
case _mode: { \ | |||||
using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ | |||||
using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ | |||||
thin_function<void( \ | |||||
const src_ctype*, const src_ctype*, const src_ctype*, dst_ctype*, \ | |||||
DType, DType, DType, DType, size_t)> \ | |||||
run = OpCallerTernary<_op<src_ctype, dst_ctype>, VEC_VEC_VEC>::run; \ | |||||
MEGDNN_DISPATCH_CPU_KERN_OPR(run( \ | |||||
src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), src2.ptr<src_ctype>(), \ | |||||
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \ | |||||
src2.layout.dtype, dst.layout.dtype, nr_elems)); \ | |||||
return; \ | |||||
} | |||||
DISPATCH() | |||||
#undef DISPATCH_SINGLE_MODE | |||||
} | |||||
//! VEC + VEC + SCALAR | |||||
if (is_vector(src0.layout) && is_vector(src1.layout) && | |||||
is_broadcasted_scalar(src2.layout)) { | |||||
#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ | |||||
case _mode: { \ | |||||
using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ | |||||
using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ | |||||
thin_function<void( \ | |||||
const src_ctype*, const src_ctype*, const src_ctype, dst_ctype*, \ | |||||
DType, DType, DType, DType, size_t)> \ | |||||
run = OpCallerTernary<_op<src_ctype, dst_ctype>, VEC_VEC_SCALAR>::run; \ | |||||
MEGDNN_DISPATCH_CPU_KERN_OPR( \ | |||||
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \ | |||||
src2.ptr<src_ctype>()[0], dst.ptr<dst_ctype>(), src0.layout.dtype, \ | |||||
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ | |||||
src0.layout.total_nr_elems())); \ | |||||
return; \ | |||||
} | |||||
DISPATCH() | |||||
#undef DISPATCH_SINGLE_MODE | |||||
} | |||||
//! BCAST101 + VEC + BCAST101 | |||||
{ | |||||
BroadcastChannelInfo binfo; | |||||
bool normal_case = is_vector(src1.layout) && | |||||
is_broadcasted_channel_like(src0.layout, binfo) && | |||||
src0.layout.eq_shape(src2.layout); | |||||
if (normal_case) { | |||||
#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ | |||||
case _mode: { \ | |||||
using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ | |||||
using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ | |||||
thin_function<void( \ | |||||
const src_ctype*, const src_ctype*, const src_ctype*, dst_ctype*, \ | |||||
DType, DType, DType, DType, size_t, size_t, size_t, size_t)> \ | |||||
run = OpCallerTernary< \ | |||||
_op<src_ctype, dst_ctype>, BCAST101_VEC_BCAST101>::run; \ | |||||
MEGDNN_DISPATCH_CPU_KERN_OPR( \ | |||||
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \ | |||||
src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), src0.layout.dtype, \ | |||||
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, binfo.x, \ | |||||
binfo.y, binfo.z, binfo.y* binfo.z)); \ | |||||
return; \ | |||||
} | |||||
DISPATCH() | |||||
#undef DISPATCH_SINGLE_MODE | |||||
} | |||||
} | |||||
//! VEC + BCAST101x4 + VEC | |||||
{ | |||||
BroadcastChannelInfo binfo; | |||||
if (is_vector(src0.layout) && | |||||
(is_broadcastedx_channel_like<4>(src1.layout, binfo) || | |||||
is_broadcastedx_channel_like<8>(src1.layout, binfo)) && | |||||
src0.layout.eq_shape(src2.layout)) { | |||||
#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ | |||||
case _mode: { \ | |||||
using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ | |||||
using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ | |||||
thin_function<void( \ | |||||
const src_ctype*, const src_ctype*, const src_ctype*, dst_ctype*, \ | |||||
DType, DType, DType, DType, size_t, size_t, size_t, size_t)> \ | |||||
run = OpCallerTernary< \ | |||||
_op<src_ctype, dst_ctype>, VEC_BCAST101xX_VEC>::run; \ | |||||
MEGDNN_DISPATCH_CPU_KERN_OPR( \ | |||||
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \ | |||||
src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), src0.layout.dtype, \ | |||||
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ | |||||
batch_size, binfo.x, binfo.y, binfo.z)); \ | |||||
return; \ | |||||
} | |||||
size_t batch_size = src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z); | |||||
DISPATCH() | |||||
#undef DISPATCH_SINGLE_MODE | |||||
} | |||||
//! BCAST101x + VEC +BCAST101x | |||||
if (is_vector(src1.layout) && | |||||
(is_broadcastedx_channel_like<4>(src0.layout, binfo) || | |||||
is_broadcastedx_channel_like<8>(src0.layout, binfo)) && | |||||
src0.layout.eq_shape(src2.layout)) { | |||||
#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ | |||||
case _mode: { \ | |||||
using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ | |||||
using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ | |||||
thin_function<void( \ | |||||
const src_ctype*, const src_ctype*, const src_ctype*, dst_ctype*, \ | |||||
DType, DType, DType, DType, size_t, size_t, size_t, size_t)> \ | |||||
run = OpCallerTernary< \ | |||||
_op<src_ctype, dst_ctype>, BCAST101xX_VEC_BCAST101xX>::run; \ | |||||
MEGDNN_DISPATCH_CPU_KERN_OPR( \ | |||||
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \ | |||||
src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), src0.layout.dtype, \ | |||||
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ | |||||
batch_size, binfo.x, binfo.y, binfo.z)); \ | |||||
return; \ | |||||
} | |||||
size_t batch_size = src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z); | |||||
DISPATCH() | |||||
#undef DISPATCH_SINGLE_MODE | |||||
} | |||||
} | |||||
naive::ElemwiseMultiTypeImpl::on_quantized_mode(param, dst, mode); | |||||
#undef DISPATCH | |||||
#undef DISPATCH_QUANTIZED_MODE | |||||
} | |||||
// vim: syntax=cpp.doxygen |
@@ -60,6 +60,7 @@ | |||||
#define GI_NEON_INTRINSICS | #define GI_NEON_INTRINSICS | ||||
#if defined(__aarch64__) | #if defined(__aarch64__) | ||||
#define GI_NEON64_INTRINSICS | #define GI_NEON64_INTRINSICS | ||||
#define GI_NEON32_INTRINSICS | |||||
#else | #else | ||||
#define GI_NEON32_INTRINSICS | #define GI_NEON32_INTRINSICS | ||||
#endif | #endif | ||||
@@ -11,8 +11,10 @@ | |||||
*/ | */ | ||||
#include "test/common/elemwise_multi_type.h" | #include "test/common/elemwise_multi_type.h" | ||||
#include "megdnn/opr_param_defs.h" | |||||
#include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
#include "test/arm_common/fixture.h" | #include "test/arm_common/fixture.h" | ||||
#include "test/common/benchmarker.h" | |||||
#include "test/common/checker.h" | #include "test/common/checker.h" | ||||
#include "test/common/task_record_check.h" | #include "test/common/task_record_check.h" | ||||
#include "test/common/timer.h" | #include "test/common/timer.h" | ||||
@@ -559,4 +561,95 @@ TEST_F(ARM_COMMON, ELEMWISE_FMA3_UINT8xF32xF32xF32_RECORD) { | |||||
.execs({{16, 128, 16, 16}, {1, 1, 1, 1}, {1, 1, 1, 1}, {}}); | .execs({{16, 128, 16, 16}, {1, 1, 1, 1}, {1, 1, 1, 1}, {}}); | ||||
} | } | ||||
#if MEGDNN_WITH_BENCHMARK | |||||
namespace { | |||||
void run_elemwise_benchmark( | |||||
const TensorShapeArray& shapes, ElemwiseMultiType::Param::Mode mode, | |||||
const char* mode_str, std::vector<DType> types, Handle* handle_bench) { | |||||
auto handle_fallback = create_cpu_handle(1); | |||||
Benchmarker<ElemwiseMultiType> benchmarker_bench(handle_bench); | |||||
Benchmarker<ElemwiseMultiType> benchmarker_fallback(handle_fallback.get()); | |||||
float throughput = 0; | |||||
SmallVector<TensorLayout> layouts; | |||||
std::string src_strs; | |||||
for (size_t i = 0; i < shapes.size(); i++) { | |||||
layouts.emplace_back(shapes[i], types[i]); | |||||
throughput += layouts.back().span().dist_byte(); | |||||
src_strs += layouts.back().to_string(); | |||||
if (i != shapes.size() - 1) { | |||||
src_strs += ","; | |||||
} | |||||
} | |||||
constexpr size_t RUN = 50; | |||||
benchmarker_fallback.set_times(RUN).set_display(false); | |||||
benchmarker_bench.set_times(RUN).set_display(false); | |||||
benchmarker_fallback.set_param(mode); | |||||
benchmarker_bench.set_param(mode); | |||||
TensorLayout dst_layout; | |||||
dst_layout.dtype = types.back(); | |||||
auto opr = handle_bench->create_operator<ElemwiseMultiType>(); | |||||
opr->param() = mode; | |||||
opr->deduce_layout(layouts, dst_layout); | |||||
float computations = | |||||
dst_layout.total_nr_elems() * (std::max<size_t>(shapes.size(), 2) - 1); | |||||
throughput += dst_layout.span().dist_byte(); | |||||
computations *= (1e3 / (1024.0 * 1024)); | |||||
throughput *= (1e3 / (1024.0 * 1024)); | |||||
layouts.emplace_back(dst_layout); | |||||
auto fallback_time = benchmarker_fallback.execl(layouts) / RUN; | |||||
auto bench_time = benchmarker_bench.execl(layouts) / RUN; | |||||
float fallback_flops = computations / fallback_time; | |||||
float bench_flops = computations / bench_time; | |||||
float fallback_thr = throughput / fallback_time; | |||||
float bench_thr = throughput / bench_time; | |||||
printf("%s = %s (mode: %s) cpu=%fMFLOPS %fMB/s, bench=%fMFLOPS " | |||||
"%fMB/s " | |||||
"computations: %fx, throughput: %fx\n", | |||||
src_strs.c_str(), dst_layout.to_string().c_str(), mode_str, fallback_flops, | |||||
fallback_thr, bench_flops, bench_thr, bench_flops / fallback_flops, | |||||
bench_thr / fallback_thr); | |||||
} | |||||
} // namespace | |||||
#define RUN_WITH_MODE(shape, mode, types) \ | |||||
run_elemwise_benchmark(shape, mode, #mode, types, handle()); | |||||
TEST_F(ARM_COMMON, BENCHMARK_UNARY_MULTI_TYPE) { | |||||
using Mode = ElemwiseMultiType::Param::Mode; | |||||
for (auto mode : | |||||
{Mode::QRELU, Mode::QABS, Mode::QSIGMOID, Mode::QEXP, Mode::QTANH, | |||||
Mode::QFAST_TANH, Mode::QH_SWISH}) { | |||||
std::vector<DType> types = {dtype::QuantizedS8(1.4f), dtype::QuantizedS8(3.4f)}; | |||||
TensorShapeArray shapes = {{10000}}; | |||||
RUN_WITH_MODE(shapes, mode, types); | |||||
std::vector<DType> types2 = { | |||||
dtype::QuantizedS32(1.4f), dtype::QuantizedS8(3.4f)}; | |||||
RUN_WITH_MODE(shapes, mode, types2); | |||||
} | |||||
} | |||||
TEST_F(ARM_COMMON, BENCHMARK_BINARY_MULTI_TYPE) { | |||||
using Mode = ElemwiseMultiType::Param::Mode; | |||||
for (auto mode : {Mode::QADD, Mode::QFUSE_ADD_RELU, Mode::QFUSE_ADD_H_SWISH}) { | |||||
std::vector<DType> types = { | |||||
dtype::QuantizedS8(1.4f), dtype::QuantizedS8(3.4f), | |||||
dtype::QuantizedS8(1.6f)}; | |||||
TensorShapeArray shapes = {{10000}, {10000}}; | |||||
RUN_WITH_MODE(shapes, mode, types); | |||||
std::vector<DType> types2 = { | |||||
dtype::QuantizedS32(1.4f), dtype::QuantizedS32(3.4f), | |||||
dtype::QuantizedS8(1.6f)}; | |||||
RUN_WITH_MODE(shapes, mode, types2); | |||||
} | |||||
} | |||||
#endif | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -26,6 +26,175 @@ TYPED_TEST(FALLBACK_ELEMWISE_MULTI_TYPE, run) { | |||||
elemwise_multi_type::run_test<TypeParam>(this->handle()); | elemwise_multi_type::run_test<TypeParam>(this->handle()); | ||||
} | } | ||||
TEST_F(FALLBACK, ELEMWISE_QUANTIZED_MODE_UNARY) { | |||||
using Mode = ElemwiseMultiType::Param::Mode; | |||||
Checker<ElemwiseMultiType> checker(handle()); | |||||
std::unique_ptr<RNG> rng; | |||||
for (auto mode : | |||||
{Mode::QRELU, Mode::QABS, Mode::QSIGMOID, Mode::QEXP, Mode::QTANH, | |||||
Mode::QFAST_TANH, Mode::QH_SWISH}) { | |||||
checker.set_param({mode}); | |||||
for (DType src_type : | |||||
std::vector<DType>{dtype::QuantizedS8(1.4f), dtype::QuantizedS32(1.3f)}) { | |||||
checker.set_dtype(0, src_type); | |||||
if (src_type.enumv() == DTypeEnum::QuantizedS8) { | |||||
rng = std::make_unique<UniformIntRNG>(-127, 127); | |||||
checker.set_dtype(1, dtype::QuantizedS8(1.7f)); | |||||
} else { | |||||
rng = std::make_unique<UniformIntRNG>(INT16_MIN >> 1, INT16_MAX >> 1); | |||||
} | |||||
checker.set_rng(0, rng.get()); | |||||
auto run = [&]() { | |||||
checker.execs({{3, 4, 5, 6}, {}}); | |||||
checker.execs({{3}, {}}); | |||||
checker.execs({{9}, {}}); | |||||
checker.execs({{17}, {}}); | |||||
}; | |||||
if (src_type.enumv() == DTypeEnum::QuantizedS32) { | |||||
for (DType dst_type : | |||||
std::vector<DType>{dtype::QuantizedS8(32718.6f)}) { | |||||
checker.set_dtype(1, dst_type); | |||||
run(); | |||||
} | |||||
} else { | |||||
run(); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
TEST_F(FALLBACK, ELEMWISE_QUANTIZED_MODE_BINARY) { | |||||
using Mode = ElemwiseMultiType::Param::Mode; | |||||
Checker<ElemwiseMultiType> checker(handle()); | |||||
auto run = [&]() { | |||||
//! nchw44 | |||||
checker.execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); | |||||
checker.execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}}); | |||||
checker.execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}}); | |||||
checker.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}}); | |||||
checker.execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}}); | |||||
checker.execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}}); | |||||
checker.execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}}); | |||||
checker.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}}); | |||||
checker.execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}}); | |||||
//! VEC + SCALAR | |||||
checker.execs({{3, 4, 5, 6}, {1, 1, 1, 1}, {}}); | |||||
checker.execs({{1, 1, 1, 1}, {3, 4, 5, 6}, {}}); | |||||
checker.execs({{3, 4, 5, 6}, {1}, {}}); | |||||
checker.execs({{1}, {3, 4, 5, 6}, {}}); | |||||
//! VEC + 1C11 | |||||
checker.execs({{3, 4, 5, 6}, {1, 4, 1, 1}, {}}); | |||||
checker.execs({{1, 4, 1, 1}, {3, 4, 5, 6}, {}}); | |||||
//! VEC + VEC | |||||
checker.execs({{3}, {3}, {}}); | |||||
checker.execs({{9}, {9}, {}}); | |||||
checker.execs({{17}, {17}, {}}); | |||||
}; | |||||
// qint32 to qint8/quint8 | |||||
for (auto mode : {Mode::QADD, Mode::QFUSE_ADD_RELU, Mode::QFUSE_ADD_H_SWISH}) { | |||||
checker.set_param({mode}); | |||||
UniformIntRNG rng{INT16_MIN >> 1, INT16_MAX >> 1}; | |||||
checker.set_rng(0, &rng) | |||||
.set_rng(1, &rng) | |||||
.set_dtype(0, dtype::QuantizedS32(1.3f)) | |||||
.set_dtype(1, dtype::QuantizedS32(1.2f)); | |||||
for (DType dst_type : std::vector<DType>{dtype::QuantizedS8(32718.6f)}) { | |||||
checker.set_dtype(2, dst_type); | |||||
run(); | |||||
} | |||||
} | |||||
for (auto mode : | |||||
{Mode::QMUL, Mode::QADD, Mode::QMIN, Mode::QMAX, Mode::QSUB, | |||||
Mode::QFUSE_ADD_RELU, Mode::QFUSE_ADD_SIGMOID, Mode::QFUSE_ADD_H_SWISH}) { | |||||
checker.set_param({mode}); | |||||
// qint8 to qint8 | |||||
UniformIntRNG rng_int8{-127, 127}; | |||||
checker.set_rng(0, &rng_int8) | |||||
.set_rng(1, &rng_int8) | |||||
.set_dtype(0, dtype::QuantizedS8(1.35f)) | |||||
.set_dtype(1, dtype::QuantizedS8(1.15f)) | |||||
.set_dtype(2, dtype::QuantizedS8(1.75f)); | |||||
run(); | |||||
} | |||||
//! TRUE_DIV : 0.0 / 0.0 will fail | |||||
checker.set_param({Mode::QTRUE_DIV}); | |||||
UniformIntRNG rng_int8_1{-127, 127}; | |||||
UniformIntRNG rng_int8_2{-127, -1}; | |||||
checker.set_rng(0, &rng_int8_1) | |||||
.set_rng(1, &rng_int8_2) | |||||
.set_dtype(0, dtype::QuantizedS8(1.4f)) | |||||
.set_dtype(1, dtype::QuantizedS8(1.1f)) | |||||
.set_dtype(2, dtype::QuantizedS8(1.7f)); | |||||
run(); | |||||
//! TANH | |||||
checker.set_param({Mode::QFUSE_ADD_TANH}); | |||||
UniformIntRNG rng_int8{-5, 5}; | |||||
checker.set_rng(0, &rng_int8) | |||||
.set_rng(1, &rng_int8) | |||||
.set_dtype(0, dtype::QuantizedS8(1.1f)) | |||||
.set_dtype(1, dtype::QuantizedS8(1.4f)) | |||||
.set_dtype(2, dtype::QuantizedS8(1.7f)); | |||||
run(); | |||||
} | |||||
TEST_F(FALLBACK, ELEMWISE_QUANTIZED_MODE_TERNARY) { | |||||
using Mode = ElemwiseMultiType::Param::Mode; | |||||
Checker<ElemwiseMultiType> checker(handle()); | |||||
auto run = [&]() { | |||||
//! nchw44 | |||||
checker.execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); | |||||
checker.execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); | |||||
checker.execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}}); | |||||
checker.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}}); | |||||
checker.execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}}); | |||||
//! nchw44 | |||||
checker.execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}}); | |||||
checker.execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}}); | |||||
checker.execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}}); | |||||
checker.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}}); | |||||
checker.execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}}); | |||||
checker.execs({{3, 4, 5, 6}, {3, 4, 5, 6}, {1, 1, 1, 1}, {}}); | |||||
checker.execs({{1, 4, 1, 1}, {3, 4, 5, 6}, {1, 4, 1, 1}, {}}); | |||||
checker.execs({{3}, {3}, {3}, {}}); | |||||
checker.execs({{9}, {9}, {9}, {}}); | |||||
checker.execs({{17}, {17}, {17}, {}}); | |||||
checker.execs({{3, 4, 5, 6}, {3, 4, 5, 6}, {3, 4, 5, 6}, {}}); | |||||
}; | |||||
for (auto mode : {Mode::QFUSE_MUL_ADD3}) { | |||||
checker.set_param({mode}); | |||||
// qint8 to qint8 | |||||
UniformIntRNG rng_int8{-127, 127}; | |||||
checker.set_rng(0, &rng_int8) | |||||
.set_rng(1, &rng_int8) | |||||
.set_rng(2, &rng_int8) | |||||
.set_dtype(0, dtype::QuantizedS8(1.45f)) | |||||
.set_dtype(1, dtype::QuantizedS8(1.15f)) | |||||
.set_dtype(2, dtype::QuantizedS8(1.75f)) | |||||
.set_dtype(3, dtype::QuantizedS8(1.35f)); | |||||
run(); | |||||
} | |||||
} | |||||
TEST_F(FALLBACK, ELEMWISE_MULTI_TYPE_RECORD_FMA3_INT16x32x32x32) { | TEST_F(FALLBACK, ELEMWISE_MULTI_TYPE_RECORD_FMA3_INT16x32x32x32) { | ||||
TaskRecordChecker<ElemwiseMultiType> checker{1}; | TaskRecordChecker<ElemwiseMultiType> checker{1}; | ||||
checker.set_param({ElemwiseMultiType::Mode::FUSE_MUL_ADD3_INT16x32x32x32}); | checker.set_param({ElemwiseMultiType::Mode::FUSE_MUL_ADD3_INT16x32x32x32}); | ||||