/** * \file dnn/src/fallback/elemwise_helper/kimpl/sub.h */ #pragma once #include "src/fallback/elemwise_helper/kimpl/op_base.h" namespace megdnn { namespace fallback { template struct SubOpBase : BinaryOpBase { using BinaryOpBase::BinaryOpBase; void operator()( const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const { *dst = operator()(src0, src1); } dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { return src0 - src1; } }; template struct SubOp; #define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width) \ template <> \ struct SubOp<_ctype> : SubOpBase<_ctype> { \ using SubOpBase::SubOpBase; \ using SubOpBase::operator(); \ constexpr static size_t SIMD_WIDTH = _simd_width; \ void operator()( \ const _simd_type2& src0, const _simd_type2& src1, \ dst_ctype* dst) const { \ auto vitem = operator()(src0, src1); \ GiStore##_func_suffix(dst, vitem.val[0]); \ GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ } \ _simd_type2 operator()( \ const _simd_type2& src0, const _simd_type2& src1) const { \ auto vitem0 = GiSubtract##_func_suffix(src0.val[0], src1.val[0]); \ auto vitem1 = GiSubtract##_func_suffix(src0.val[1], src1.val[1]); \ return {{vitem0, vitem1}}; \ } \ void operator()( \ const _simd_type& src0, const _simd_type& src1, \ dst_ctype* dst) const { \ auto vitem = operator()(src0, src1); \ GiStore##_func_suffix(dst, vitem); \ } \ _simd_type operator()(const _simd_type& src0, const _simd_type& src1) const { \ return GiSubtract##_func_suffix(src0, src1); \ } \ }; OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float)) OP(dt_int32, GI_INT32_t, GI_INT32_V2_t, Int32, GI_SIMD_LEN_BYTE / sizeof(int32_t)) OP(dt_int8, GI_INT8_t, GI_INT8_V2_t, Int8, GI_SIMD_LEN_BYTE / sizeof(int8_t)) #undef OP template <> struct SubOpBase : BinaryOpBase { using BinaryOpBase::BinaryOpBase; void operator()(const dt_qint8& src0, const dt_qint8& src1, dt_qint8* dst) const { *dst = operator()(src0, src1); } dt_qint8 operator()(const dt_qint8& src0, const dt_qint8& src1) const { return QConverter::convert( src0.as_int8() * scale0 - src1.as_int8() * scale1); } }; template <> struct SubOp : SubOpBase { using SubOpBase::SubOpBase; constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t); using SubOpBase::operator(); void operator()( const GI_INT8_V2_t& vsrc0, const GI_INT8_V2_t& vsrc1, dt_qint8* dst) const { OPERATOR_BINARY_QINT8_FALLBACK; } GI_INT8_t operator()(const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1) const { auto vitem0 = GiSubtractFloat32( GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[0]), this->vscale0), GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[0]), this->vscale1)); auto vitem1 = GiSubtractFloat32( GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[1]), this->vscale0), GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[1]), this->vscale1)); return QConverter::convert({{vitem0, vitem1}}); } }; } // namespace fallback } // namespace megdnn // vim: syntax=cpp.doxygen