@@ -11,6 +11,7 @@
 #include "src/arm_common/type_cvt/opr_impl.h"

 #include <cstring>
+#include <deque>
 #include "midout.h"
 #include "src/arm_common/quantized_converter.h"
 #include "src/arm_common/simd_macro/marm_neon.h"
@@ -18,6 +19,7 @@
 #include "src/naive/handle.h"

 MIDOUT_DECL(megdnn_arm_typecvt_fix2float)
+MIDOUT_DECL(megdnn_arm_typecvt_quan2float)
 MIDOUT_DECL(megdnn_arm_typecvt_quantized)
 MIDOUT_DECL(megdnn_arm_typecvt_float)
@@ -326,8 +328,34 @@ struct FloatTypeCvter<float, __fp16> {
 };
 #endif

+template <typename TypeCvter>
+void do_typecvt(
+        const typename TypeCvter::stype* src, typename TypeCvter::dst_type* dst,
+        DType src_dtype, DType dst_dtype, size_t nr_elems) {
+    TypeCvter typecvt(src_dtype, dst_dtype);
+    size_t i = 0;
+    for (; i + TypeCvter::SIMD_WIDTH <= nr_elems; i += TypeCvter::SIMD_WIDTH) {
+        typecvt.cvt(src, dst);
+        src += TypeCvter::SIMD_WIDTH;
+        dst += TypeCvter::SIMD_WIDTH;
+    }
+#if MEGDNN_FIX_AARCH32_BUG
+// FIXME: as llvm may cause cannot select error if enable vectorize
+#pragma clang loop vectorize(disable)
+#endif
+    for (; i < nr_elems; i++) {
+        typecvt.cvt_remain(src, dst);
+        src++;
+        dst++;
+    }
+}
+
 template <typename ctype, typename dtype>
 struct Fix2FloatTypeCvter;

+template <typename ctype, typename dtype>
+struct Quan2FloatTypeCvter;
+
 template <>
 struct Fix2FloatTypeCvter<int16_t, float> {
     using stype = int16_t;
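The `do_typecvt` overload added above follows a common pattern: convert `SIMD_WIDTH` elements per iteration with NEON (`cvt`), then finish the tail one element at a time (`cvt_remain`). The `#pragma clang loop vectorize(disable)` guard only works around an LLVM selection failure on AArch32; it does not change the result. A minimal, portable sketch of the main-loop/tail pattern follows; `ToyFix2Float` and `convert` are hypothetical names, and a scalar loop stands in for the NEON body.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical converter used only for this sketch: int16 -> float.
struct ToyFix2Float {
    using stype = int16_t;
    using dst_type = float;
    static constexpr size_t SIMD_WIDTH = 8;  // stand-in for one NEON register of lanes

    // The real cvt() uses NEON intrinsics; here a plain loop converts one block.
    void cvt(const int16_t* src, float* dst) const {
        for (size_t k = 0; k < SIMD_WIDTH; ++k)
            dst[k] = static_cast<float>(src[k]);
    }
    void cvt_remain(const int16_t* src, float* dst) const {
        *dst = static_cast<float>(*src);
    }
};

// Same shape as do_typecvt(..., size_t nr_elems): a wide main loop plus a scalar tail.
template <typename Cvter>
void convert(const typename Cvter::stype* src, typename Cvter::dst_type* dst,
             size_t nr_elems) {
    Cvter cvter;
    size_t i = 0;
    for (; i + Cvter::SIMD_WIDTH <= nr_elems; i += Cvter::SIMD_WIDTH) {
        cvter.cvt(src, dst);
        src += Cvter::SIMD_WIDTH;
        dst += Cvter::SIMD_WIDTH;
    }
    for (; i < nr_elems; ++i) {  // the last nr_elems % SIMD_WIDTH elements
        cvter.cvt_remain(src++, dst++);
    }
}

int main() {
    std::vector<int16_t> in(19, 7);
    std::vector<float> out(in.size());
    convert<ToyFix2Float>(in.data(), out.data(), in.size());
    return out.back() == 7.0f ? 0 : 1;
}
```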
@@ -368,62 +396,184 @@ struct Fix2FloatTypeCvter<uint16_t, float> {
     void cvt_remain(const uint16_t* src, float* dst) { *dst = *src; }
 };

-template <typename TypeCvter>
-void do_typecvt(
-        const typename TypeCvter::stype* src, typename TypeCvter::dst_type* dst,
-        DType src_dtype, DType dst_dtype, size_t nr_elems) {
-    TypeCvter typecvt(src_dtype, dst_dtype);
-    size_t i = 0;
-    for (; i + TypeCvter::SIMD_WIDTH <= nr_elems; i += TypeCvter::SIMD_WIDTH) {
-        typecvt.cvt(src, dst);
-        src += TypeCvter::SIMD_WIDTH;
-        dst += TypeCvter::SIMD_WIDTH;
-    }
-#if MEGDNN_FIX_AARCH32_BUG
-// FIXME: as llvm may cause cannot select error if enable vectorize
-#pragma clang loop vectorize(disable)
-#endif
-    for (; i < nr_elems; i++) {
-        typecvt.cvt_remain(src, dst);
-        src++;
-        dst++;
-    }
-}
+template <>
+struct Fix2FloatTypeCvter<int8_t, float> {
+    using stype = int8_t;
+    using dst_type = float;
+    static constexpr size_t SIMD_WIDTH = 16;
+
+    Fix2FloatTypeCvter(DType src_dtype, DType dst_dtype) {
+        MEGDNN_MARK_USED_VAR(src_dtype);
+        MEGDNN_MARK_USED_VAR(dst_dtype);
+    }
+
+    void cvt(const int8_t* src, float* dst) {
+        int8x16_t vitem = vld1q_s8(src);
+        int16x8_t vtrans_high = vmovl_s8(vget_high_s8(vitem));
+        int16x8_t vtrans_low = vmovl_s8(vget_low_s8(vitem));
+        auto vres_high = QConverter::convert<float32x4x2_t, int16x8_t>(vtrans_high);
+        auto vres_low = QConverter::convert<float32x4x2_t, int16x8_t>(vtrans_low);
+        vst1q_f32_x2(dst, vres_low);
+        vst1q_f32_x2(dst + 8, vres_high);
+    }
+
+    void cvt_remain(const int8_t* src, float* dst) { *dst = *src; }
+};
+
+template <>
+struct Fix2FloatTypeCvter<uint8_t, float> {
+    using stype = uint8_t;
+    using dst_type = float;
+    static constexpr size_t SIMD_WIDTH = 16;
+
+    Fix2FloatTypeCvter(DType src_dtype, DType dst_dtype) {
+        MEGDNN_MARK_USED_VAR(src_dtype);
+        MEGDNN_MARK_USED_VAR(dst_dtype);
+    }
+
+    void cvt(const uint8_t* src, float* dst) {
+        uint8x16_t vitem = vld1q_u8(src);
+        uint16x8_t vtrans_high = vmovl_u8(vget_high_u8(vitem));
+        uint16x8_t vtrans_low = vmovl_u8(vget_low_u8(vitem));
+        auto vres_high = QConverter::convert<float32x4x2_t, uint16x8_t>(vtrans_high);
+        auto vres_low = QConverter::convert<float32x4x2_t, uint16x8_t>(vtrans_low);
+        vst1q_f32_x2(dst, vres_low);
+        vst1q_f32_x2(dst + 8, vres_high);
+    }
+
+    void cvt_remain(const uint8_t* src, float* dst) { *dst = *src; }
+};
+
+template <>
+struct Quan2FloatTypeCvter<int8_t, float> {
+    using stype = int8_t;
+    using dst_type = float;
+    static constexpr size_t SIMD_WIDTH = 16;
+    float _scale = 0.0f;
+    float32x4_t _vscale;
+
+    Quan2FloatTypeCvter(DType src_dtype, DType dst_dtype) {
+        _scale = src_dtype.param<dtype::QuantizedS8>().scale;
+        _vscale = vdupq_n_f32(_scale);
+        MEGDNN_MARK_USED_VAR(dst_dtype);
+    }
+
+    void cvt(const int8_t* src, float* dst) {
+        int8x16_t vitem = vld1q_s8(src);
+        int16x8_t vtrans_high = vmovl_s8(vget_high_s8(vitem));
+        int16x8_t vtrans_low = vmovl_s8(vget_low_s8(vitem));
+        auto vres_high = QConverter::convert<float32x4x2_t, int16x8_t>(vtrans_high);
+        auto vres_low = QConverter::convert<float32x4x2_t, int16x8_t>(vtrans_low);
+        vst1q_f32(dst, vmulq_f32(vres_low.val[0], _vscale));
+        vst1q_f32(dst + 4, vmulq_f32(vres_low.val[1], _vscale));
+        vst1q_f32(dst + 8, vmulq_f32(vres_high.val[0], _vscale));
+        vst1q_f32(dst + 12, vmulq_f32(vres_high.val[1], _vscale));
+    }
+
+    void cvt_remain(const int8_t* src, float* dst) { *dst = *src * _scale; }
+};
+
+#if defined(__ARM_FEATURE_FMA)
+#define Vfmaq_f32(d, n, m) vfmaq_f32(d, n, m)
+#else
+#define Vfmaq_f32(d, n, m) vmlaq_f32(d, n, m)
+#endif
+
+template <>
+struct Quan2FloatTypeCvter<uint8_t, float> {
+    using stype = uint8_t;
+    using dst_type = float;
+    static constexpr size_t SIMD_WIDTH = 16;
+    float _scale = 0.0f;
+    float32x4_t _vscale;
+    uint8_t _zp = 0;
+    float32x4_t _vbias;
+
+    Quan2FloatTypeCvter(DType src_dtype, DType dst_dtype) {
+        _scale = src_dtype.param<dtype::Quantized8Asymm>().scale;
+        _vscale = vdupq_n_f32(_scale);
+        _zp = src_dtype.param<dtype::Quantized8Asymm>().zero_point;
+        float bias = -_zp * 1.0f * _scale;
+        _vbias = vdupq_n_f32(bias);
+        MEGDNN_MARK_USED_VAR(dst_dtype);
+    }
+
+    void cvt(const uint8_t* src, float* dst) {
+        uint8x16_t vitem = vld1q_u8(src);
+        uint16x8_t vtrans_high = vmovl_u8(vget_high_u8(vitem));
+        uint16x8_t vtrans_low = vmovl_u8(vget_low_u8(vitem));
+        auto vres_high = QConverter::convert<float32x4x2_t, uint16x8_t>(vtrans_high);
+        auto vres_low = QConverter::convert<float32x4x2_t, uint16x8_t>(vtrans_low);
+        vst1q_f32(dst, Vfmaq_f32(_vbias, vres_low.val[0], _vscale));
+        vst1q_f32(dst + 4, Vfmaq_f32(_vbias, vres_low.val[1], _vscale));
+        vst1q_f32(dst + 8, Vfmaq_f32(_vbias, vres_high.val[0], _vscale));
+        vst1q_f32(dst + 12, Vfmaq_f32(_vbias, vres_high.val[1], _vscale));
+    }
+
+    void cvt_remain(const uint8_t* src, float* dst) { *dst = (*src - _zp) * _scale; }
+};
+
+#undef Vfmaq_f32
+
+template <typename stype, typename dtype>
+struct TypeCvtTask {
+    const stype* src;
+    dtype* dst;
+    size_t dim;
+    size_t nr_elems;
+
+    explicit TypeCvtTask(const stype* s, dtype* d, size_t n, size_t tot)
+            : src(s), dst(d), dim(n), nr_elems(tot) {}
+};

 template <typename TypeCvter>
 void do_typecvt(
         const typename TypeCvter::stype* src, typename TypeCvter::dst_type* dst,
         DType src_dtype, DType dst_dtype, const TensorLayout& src_layout) {
     TypeCvter typecvt(src_dtype, dst_dtype);
-    size_t calc_num = 1;
-    size_t nr_elems = src_layout.total_nr_elems();
-    size_t src_stride = nr_elems;

-    //! adjust calc_num nr_elems and src_stride according to src_collapse_layout
     auto src_collapse_layout = src_layout.collapse_contiguous();
-    if (src_collapse_layout.ndim == 2) {
-        calc_num = src_collapse_layout.shape[0];
-        nr_elems = src_collapse_layout.shape[1];
-        src_stride = src_collapse_layout.stride[0];
-    }

-    for (size_t c = 0; c < calc_num; ++c) {
-        size_t i = 0;
-        for (; i + TypeCvter::SIMD_WIDTH <= nr_elems; i += TypeCvter::SIMD_WIDTH) {
-            typecvt.cvt(src, dst);
-            src += TypeCvter::SIMD_WIDTH;
-            dst += TypeCvter::SIMD_WIDTH;
-        }
+    using TypeCvtTaskWithType =
+            TypeCvtTask<typename TypeCvter::stype, typename TypeCvter::dst_type>;
+    std::deque<TypeCvtTaskWithType> task_queue;
+    task_queue.emplace_back(src, dst, 0, src_collapse_layout.total_nr_elems());
+
+    while (!task_queue.empty()) {
+        auto&& task = task_queue.front();
+        const typename TypeCvter::stype* psrc = task.src;
+        typename TypeCvter::dst_type* pdst = task.dst;
+        size_t dim = task.dim;
+        size_t nr_elems = task.nr_elems;
+
+        //! calc according to stride information
+        if (src_collapse_layout.stride[dim] == 1) {
+            size_t i = 0;
+            for (; i + TypeCvter::SIMD_WIDTH < nr_elems; i += TypeCvter::SIMD_WIDTH) {
+                typecvt.cvt(psrc, pdst);
+                psrc += TypeCvter::SIMD_WIDTH;
+                pdst += TypeCvter::SIMD_WIDTH;
+            }
 #if MEGDNN_FIX_AARCH32_BUG
 // FIXME: as llvm may cause cannot select error if enable vectorize
 #pragma clang loop vectorize(disable)
 #endif
-        for (; i < nr_elems; i++) {
-            typecvt.cvt_remain(src, dst);
-            src++;
-            dst++;
-        }
-        src += src_stride - nr_elems;
+            for (; i < nr_elems; i++) {
+                typecvt.cvt_remain(psrc, pdst);
+                psrc++;
+                pdst++;
+            }
+        } else {
+            size_t calc_num = src_collapse_layout.shape[dim];
+            size_t src_stride = src_collapse_layout.stride[dim];
+            size_t dst_stride = nr_elems / calc_num;
+            for (size_t i = 0; i < calc_num; ++i) {
+                task_queue.emplace_back(psrc, pdst, dim + 1, dst_stride);
+                psrc += src_stride;
+                pdst += dst_stride;
+            }
+        }
+
+        task_queue.pop_front();
     }
 }
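Two points in the hunk above are worth spelling out. `Quan2FloatTypeCvter<uint8_t, float>` folds the zero point into a bias so each vector can be dequantized with a single multiply-add: `(q - zp) * scale == q * scale + (-zp * scale)`, which is exactly what `Vfmaq_f32(_vbias, v, _vscale)` computes. And the layout-based `do_typecvt` walks a non-contiguous source by queuing one sub-task per element of each strided outer dimension until it reaches a stride-1 run, which it converts directly. A portable sketch of that traversal is below; the names (`ToyLayout`, `CopyTask`, `copy_strided`) are hypothetical and a plain `float` copy stands in for the NEON conversion.

```cpp
#include <cstddef>
#include <deque>
#include <vector>

// Toy stand-in for a collapsed layout: shape/stride per dimension, innermost last.
struct ToyLayout {
    std::vector<size_t> shape;
    std::vector<std::ptrdiff_t> stride;
};

// One pending region to process, mirroring TypeCvtTask.
struct CopyTask {
    const float* src;
    float* dst;
    size_t dim;
    size_t nr_elems;  // destination elements covered by this task (dst is contiguous)
};

// Same idea as the layout-based do_typecvt: split along strided outer dimensions
// until a stride-1 dimension is reached, then process that contiguous run directly.
void copy_strided(const float* src, float* dst, const ToyLayout& layout, size_t total) {
    std::deque<CopyTask> tasks;
    tasks.push_back({src, dst, 0, total});
    while (!tasks.empty()) {
        CopyTask t = tasks.front();
        tasks.pop_front();
        if (layout.stride[t.dim] == 1) {
            for (size_t i = 0; i < t.nr_elems; ++i)  // the real code converts with NEON here
                t.dst[i] = t.src[i];
        } else {
            size_t n = layout.shape[t.dim];
            size_t dst_step = t.nr_elems / n;
            for (size_t i = 0; i < n; ++i)
                tasks.push_back({t.src + i * layout.stride[t.dim],
                                 t.dst + i * dst_step, t.dim + 1, dst_step});
        }
    }
}

int main() {
    // A 2x3 view over rows of length 4, so the two rows are not adjacent in memory.
    float src[8] = {0, 1, 2, -1, 10, 11, 12, -1};
    float dst[6] = {};
    copy_strided(src, dst, {{2, 3}, {4, 1}}, 6);
    return dst[5] == 12.0f ? 0 : 1;
}
```

Because the destination must be contiguous, each sub-task's destination block is simply `nr_elems / shape[dim]` elements wide; this is why `exec` only takes these layout-based paths when `dst.layout.is_contiguous()`.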
@@ -432,15 +582,56 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
 void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
     DType src_dtype = src.layout.dtype;
     DType dst_dtype = dst.layout.dtype;
-    size_t nr_elems = src.layout.total_nr_elems();
     bool execed = false;
     auto src_collapse_layout = src.layout.collapse_contiguous();
-    bool has_int16_special_impl =
-            (src.layout.dtype.enumv() == DTypeEnum::Int16 ||
-             src.layout.dtype.enumv() == DTypeEnum::Uint16) &&
-            (src.layout.is_contiguous() || src_collapse_layout.ndim == 2) &&
-            dst.layout.is_contiguous();
-    if (has_int16_special_impl) {
+
+    if (src.layout.is_contiguous()) {
+        using namespace dtype;
+        size_t nr_elems = src.layout.total_nr_elems();
+#define DISPATCH_QUANTIZED(_stype_enumv, _stype, _dtype_enumv, _dtype, _midout_iv) \
+    if (src_dtype.enumv() == DTypeTrait<_stype_enumv>::enumv && \
+        dst_dtype.enumv() == DTypeTrait<_dtype_enumv>::enumv) { \
+        MIDOUT_BEGIN(megdnn_arm_typecvt_quantized, midout_iv(_midout_iv)) { \
+            using _TypeCvter = QuantizedTypeCvter<_stype, _dtype>; \
+            MEGDNN_DISPATCH_CPU_KERN_OPR(do_typecvt<_TypeCvter>( \
+                    src.compatible_ptr<_stype>(), dst.compatible_ptr<_dtype>(), \
+                    src_dtype, dst_dtype, nr_elems)); \
+            execed = true; \
+        } \
+        MIDOUT_END(); \
+    }
+        DISPATCH_QUANTIZED(QuantizedS32, int32_t, Quantized8Asymm, uint8_t, 0);
+        DISPATCH_QUANTIZED(QuantizedS32, int32_t, QuantizedS8, int8_t, 1);
+        DISPATCH_QUANTIZED(QuantizedS8, int8_t, QuantizedS32, int32_t, 2);
+        DISPATCH_QUANTIZED(QuantizedS8, int8_t, QuantizedS8, int8_t, 3);
+        DISPATCH_QUANTIZED(Quantized8Asymm, uint8_t, Quantized8Asymm, uint8_t, 4);
+        DISPATCH_QUANTIZED(QuantizedS32, int32_t, QuantizedS32, int32_t, 5);
+        DISPATCH_QUANTIZED(float, float, QuantizedS8, int8_t, 6);
+        DISPATCH_QUANTIZED(float, float, Quantized8Asymm, uint8_t, 7);
+#undef DISPATCH_QUANTIZED
+
+#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+#define DISPATCH_FLOAT(_stype_enumv, _stype, _dtype_enumv, _dtype, _midout_iv) \
+    if (src_dtype.enumv() == DTypeTrait<_stype_enumv>::enumv && \
+        dst_dtype.enumv() == DTypeTrait<_dtype_enumv>::enumv) { \
+        MIDOUT_BEGIN(megdnn_arm_typecvt_float, midout_iv(_midout_iv)) { \
+            using _TypeCvter = FloatTypeCvter<_stype, _dtype>; \
+            MEGDNN_DISPATCH_CPU_KERN_OPR(do_typecvt<_TypeCvter>( \
+                    reinterpret_cast<_stype*>(src.raw_ptr()), \
+                    reinterpret_cast<_dtype*>(dst.raw_ptr()), src_dtype, dst_dtype, \
+                    nr_elems)); \
+            execed = true; \
+        } \
+        MIDOUT_END(); \
+    }
+        DISPATCH_FLOAT(dt_float16, __fp16, float, float, 0);
+        DISPATCH_FLOAT(float, float, dt_float16, __fp16, 1);
+#undef DISPATCH_FLOAT
+#endif
+    }
+
+    size_t last_stride = src_collapse_layout.stride[src_collapse_layout.ndim - 1];
+    if (!execed && last_stride == 1 && dst.layout.is_contiguous()) {
         using namespace dtype;
 #define DISPATCH_FIX2FLOAT(_stype_enumv, _stype, _dtype_enumv, _dtype, _midout_iv) \
     if (src_dtype.enumv() == DTypeTrait<_stype_enumv>::enumv && \
@@ -456,9 +647,26 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
     }
         DISPATCH_FIX2FLOAT(Int16, int16_t, Float32, float, 0);
         DISPATCH_FIX2FLOAT(Uint16, uint16_t, Float32, float, 1);
+        DISPATCH_FIX2FLOAT(Int8, int8_t, Float32, float, 2);
+        DISPATCH_FIX2FLOAT(Uint8, uint8_t, Float32, float, 3);
 #undef DISPATCH_FIX2FLOAT
-    } else if (src.layout.is_contiguous()) {
-        using namespace dtype;

+#define DISPATCH_QUAN2FLOAT(_stype_enumv, _stype, _dtype_enumv, _dtype, _midout_iv) \
+    if (src_dtype.enumv() == DTypeTrait<_stype_enumv>::enumv && \
+        dst_dtype.enumv() == DTypeTrait<_dtype_enumv>::enumv) { \
+        MIDOUT_BEGIN(megdnn_arm_typecvt_quan2float, midout_iv(_midout_iv)) { \
+            using _TypeCvter = Quan2FloatTypeCvter<_stype, _dtype>; \
+            MEGDNN_DISPATCH_CPU_KERN_OPR(do_typecvt<_TypeCvter>( \
+                    src.compatible_ptr<_stype>(), dst.compatible_ptr<_dtype>(), \
+                    src_dtype, dst_dtype, src.layout)); \
+            execed = true; \
+        } \
+        MIDOUT_END(); \
+    }
+        DISPATCH_QUAN2FLOAT(QuantizedS8, int8_t, Float32, float, 0);
+        DISPATCH_QUAN2FLOAT(Quantized8Asymm, uint8_t, Float32, float, 1);
+#undef DISPATCH_QUAN2FLOAT
+
 #define DISPATCH_QUANTIZED(_stype_enumv, _stype, _dtype_enumv, _dtype, _midout_iv) \
     if (src_dtype.enumv() == DTypeTrait<_stype_enumv>::enumv && \
         dst_dtype.enumv() == DTypeTrait<_dtype_enumv>::enumv) { \
@@ -466,12 +674,11 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
             using _TypeCvter = QuantizedTypeCvter<_stype, _dtype>; \
             MEGDNN_DISPATCH_CPU_KERN_OPR(do_typecvt<_TypeCvter>( \
                     src.compatible_ptr<_stype>(), dst.compatible_ptr<_dtype>(), \
-                    src_dtype, dst_dtype, nr_elems)); \
+                    src_dtype, dst_dtype, src.layout)); \
             execed = true; \
         } \
         MIDOUT_END(); \
     }

         DISPATCH_QUANTIZED(QuantizedS32, int32_t, Quantized8Asymm, uint8_t, 0);
         DISPATCH_QUANTIZED(QuantizedS32, int32_t, QuantizedS8, int8_t, 1);
         DISPATCH_QUANTIZED(QuantizedS8, int8_t, QuantizedS32, int32_t, 2);
@@ -491,7 +698,7 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
             MEGDNN_DISPATCH_CPU_KERN_OPR(do_typecvt<_TypeCvter>( \
                     reinterpret_cast<_stype*>(src.raw_ptr()), \
                     reinterpret_cast<_dtype*>(dst.raw_ptr()), src_dtype, dst_dtype, \
-                    nr_elems)); \
+                    src.layout)); \
             execed = true; \
         } \
         MIDOUT_END(); \
@@ -501,6 +708,7 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
 #undef DISPATCH_FLOAT
 #endif
     }
+
     if (!execed) {
         fallback::TypeCvtImpl::exec(src, dst);
     }
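The `DISPATCH_*` macros in `exec` all share one shape: compare the runtime dtype enums of source and destination, and on a match instantiate the matching `TypeCvter`, dispatch the kernel, and record `execed` so that unhandled combinations fall through to the fallback implementation. A stripped-down, self-contained sketch of that pattern is shown below; `DISPATCH_CVT`, `do_cvt`, and `typecvt` are hypothetical names, and MIDOUT profiling is omitted.

```cpp
#include <cstddef>
#include <cstdint>

enum class DTypeEnum { Float32, Int8, Int16 };

// Stand-in for do_typecvt<SomeTypeCvter>: plain element-wise conversion.
template <typename SrcT, typename DstT>
void do_cvt(const void* src, void* dst, size_t n) {
    auto* s = static_cast<const SrcT*>(src);
    auto* d = static_cast<DstT*>(dst);
    for (size_t i = 0; i < n; ++i)
        d[i] = static_cast<DstT>(s[i]);
}

// Same structure as DISPATCH_QUANTIZED / DISPATCH_FIX2FLOAT: an if-chain keyed on
// the runtime dtype pair that instantiates the matching template and sets execed.
#define DISPATCH_CVT(_src_enum, _src_t, _dst_enum, _dst_t) \
    if (src_dtype == DTypeEnum::_src_enum &&               \
        dst_dtype == DTypeEnum::_dst_enum) {               \
        do_cvt<_src_t, _dst_t>(src, dst, n);               \
        execed = true;                                     \
    }

bool typecvt(const void* src, void* dst, size_t n, DTypeEnum src_dtype,
             DTypeEnum dst_dtype) {
    bool execed = false;
    DISPATCH_CVT(Int16, int16_t, Float32, float);
    DISPATCH_CVT(Int8, int8_t, Float32, float);
#undef DISPATCH_CVT
    return execed;  // caller falls back to a generic implementation when false
}

int main() {
    int16_t in[4] = {1, 2, 3, 4};
    float out[4] = {};
    bool ok = typecvt(in, out, 4, DTypeEnum::Int16, DTypeEnum::Float32);
    return (ok && out[3] == 4.0f) ? 0 : 1;
}
```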