diff --git a/dnn/src/arm_common/type_cvt/opr_impl.cpp b/dnn/src/arm_common/type_cvt/opr_impl.cpp
index 831248c4..dadf87ce 100644
--- a/dnn/src/arm_common/type_cvt/opr_impl.cpp
+++ b/dnn/src/arm_common/type_cvt/opr_impl.cpp
@@ -11,6 +11,7 @@
 #include "src/arm_common/type_cvt/opr_impl.h"
 
 #include <cstring>
+#include <deque>
 
 #include "midout.h"
 #include "src/arm_common/quantized_converter.h"
 #include "src/arm_common/simd_macro/marm_neon.h"
@@ -18,6 +19,7 @@
 #include "src/naive/handle.h"
 
 MIDOUT_DECL(megdnn_arm_typecvt_fix2float)
+MIDOUT_DECL(megdnn_arm_typecvt_quan2float)
 MIDOUT_DECL(megdnn_arm_typecvt_quantized)
 MIDOUT_DECL(megdnn_arm_typecvt_float)
 
@@ -326,8 +328,34 @@ struct FloatTypeCvter {
 };
 #endif
 
+template <typename TypeCvter>
+void do_typecvt(
+        const typename TypeCvter::stype* src, typename TypeCvter::dst_type* dst,
+        DType src_dtype, DType dst_dtype, size_t nr_elems) {
+    TypeCvter typecvt(src_dtype, dst_dtype);
+    size_t i = 0;
+    for (; i + TypeCvter::SIMD_WIDTH <= nr_elems; i += TypeCvter::SIMD_WIDTH) {
+        typecvt.cvt(src, dst);
+        src += TypeCvter::SIMD_WIDTH;
+        dst += TypeCvter::SIMD_WIDTH;
+    }
+#if MEGDNN_FIX_AARCH32_BUG
+// FIXME: as llvm may cause cannot select error if enable vectorize
+#pragma clang loop vectorize(disable)
+#endif
+    for (; i < nr_elems; i++) {
+        typecvt.cvt_remain(src, dst);
+        src++;
+        dst++;
+    }
+}
+
 template <typename stype, typename dtype>
 struct Fix2FloatTypeCvter;
+
+template <typename stype, typename dtype>
+struct Quan2FloatTypeCvter;
+
 template <>
 struct Fix2FloatTypeCvter<int16_t, float> {
     using stype = int16_t;
@@ -368,62 +396,184 @@ struct Fix2FloatTypeCvter<uint16_t, float> {
     void cvt_remain(const uint16_t* src, float* dst) { *dst = *src; }
 };
 
-template <typename TypeCvter>
-void do_typecvt(
-        const typename TypeCvter::stype* src, typename TypeCvter::dst_type* dst,
-        DType src_dtype, DType dst_dtype, size_t nr_elems) {
-    TypeCvter typecvt(src_dtype, dst_dtype);
-    size_t i = 0;
-    for (; i + TypeCvter::SIMD_WIDTH <= nr_elems; i += TypeCvter::SIMD_WIDTH) {
-        typecvt.cvt(src, dst);
-        src += TypeCvter::SIMD_WIDTH;
-        dst += TypeCvter::SIMD_WIDTH;
-    }
-#if MEGDNN_FIX_AARCH32_BUG
-// FIXME: as llvm may cause cannot select error if enable vectorize
-#pragma clang loop vectorize(disable)
-#endif
-    for (; i < nr_elems; i++) {
-        typecvt.cvt_remain(src, dst);
-        src++;
-        dst++;
-    }
-}
+template <>
+struct Fix2FloatTypeCvter<int8_t, float> {
+    using stype = int8_t;
+    using dst_type = float;
+    static constexpr size_t SIMD_WIDTH = 16;
+
+    Fix2FloatTypeCvter(DType src_dtype, DType dst_dtype) {
+        MEGDNN_MARK_USED_VAR(src_dtype);
+        MEGDNN_MARK_USED_VAR(dst_dtype);
+    }
+
+    void cvt(const int8_t* src, float* dst) {
+        int8x16_t vitem = vld1q_s8(src);
+        int16x8_t vtrans_high = vmovl_s8(vget_high_s8(vitem));
+        int16x8_t vtrans_low = vmovl_s8(vget_low_s8(vitem));
+        auto vres_high = QConverter::convert<float32x4x2_t, int16x8_t>(vtrans_high);
+        auto vres_low = QConverter::convert<float32x4x2_t, int16x8_t>(vtrans_low);
+        vst1q_f32_x2(dst, vres_low);
+        vst1q_f32_x2(dst + 8, vres_high);
+    }
+
+    void cvt_remain(const int8_t* src, float* dst) { *dst = *src; }
+};
+
+template <>
+struct Fix2FloatTypeCvter<uint8_t, float> {
+    using stype = uint8_t;
+    using dst_type = float;
+    static constexpr size_t SIMD_WIDTH = 16;
+
+    Fix2FloatTypeCvter(DType src_dtype, DType dst_dtype) {
+        MEGDNN_MARK_USED_VAR(src_dtype);
+        MEGDNN_MARK_USED_VAR(dst_dtype);
+    }
+
+    void cvt(const uint8_t* src, float* dst) {
+        uint8x16_t vitem = vld1q_u8(src);
+        uint16x8_t vtrans_high = vmovl_u8(vget_high_u8(vitem));
+        uint16x8_t vtrans_low = vmovl_u8(vget_low_u8(vitem));
+        auto vres_high = QConverter::convert<float32x4x2_t, uint16x8_t>(vtrans_high);
+        auto vres_low = QConverter::convert<float32x4x2_t, uint16x8_t>(vtrans_low);
+        vst1q_f32_x2(dst, vres_low);
+        vst1q_f32_x2(dst + 8, vres_high);
+    }
+
+    void cvt_remain(const uint8_t* src, float* dst) { *dst = *src; }
+};
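The two specializations above lean on QConverter::convert<float32x4x2_t, ...> for the int16/uint16 to float32 step. A minimal standalone sketch of the same widen-then-convert idiom, assuming only a NEON-capable target (cvt16_i8_to_f32 is a hypothetical name for illustration, not part of the patch):

#include <arm_neon.h>
#include <cstdio>

// Convert 16 int8 values to float: widen s8 -> s16 -> s32, then s32 -> f32.
void cvt16_i8_to_f32(const int8_t* src, float* dst) {
    int8x16_t v = vld1q_s8(src);               // 16 x int8
    int16x8_t lo = vmovl_s8(vget_low_s8(v));   // widen low 8 lanes to int16
    int16x8_t hi = vmovl_s8(vget_high_s8(v));  // widen high 8 lanes to int16
    vst1q_f32(dst + 0, vcvtq_f32_s32(vmovl_s16(vget_low_s16(lo))));
    vst1q_f32(dst + 4, vcvtq_f32_s32(vmovl_s16(vget_high_s16(lo))));
    vst1q_f32(dst + 8, vcvtq_f32_s32(vmovl_s16(vget_low_s16(hi))));
    vst1q_f32(dst + 12, vcvtq_f32_s32(vmovl_s16(vget_high_s16(hi))));
}

int main() {
    int8_t src[16];
    float dst[16];
    for (int i = 0; i < 16; ++i)
        src[i] = static_cast<int8_t>(i - 8);
    cvt16_i8_to_f32(src, dst);
    for (int i = 0; i < 16; ++i)
        printf("%.1f ", dst[i]);  // prints -8.0 ... 7.0
    printf("\n");
}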
+
+template <>
+struct Quan2FloatTypeCvter<int8_t, float> {
+    using stype = int8_t;
+    using dst_type = float;
+    static constexpr size_t SIMD_WIDTH = 16;
+    float _scale = 0.0f;
+    float32x4_t _vscale;
+
+    Quan2FloatTypeCvter(DType src_dtype, DType dst_dtype) {
+        _scale = src_dtype.param<dtype::QuantizedS8>().scale;
+        _vscale = vdupq_n_f32(_scale);
+        MEGDNN_MARK_USED_VAR(dst_dtype);
+    }
+
+    void cvt(const int8_t* src, float* dst) {
+        int8x16_t vitem = vld1q_s8(src);
+        int16x8_t vtrans_high = vmovl_s8(vget_high_s8(vitem));
+        int16x8_t vtrans_low = vmovl_s8(vget_low_s8(vitem));
+        auto vres_high = QConverter::convert<float32x4x2_t, int16x8_t>(vtrans_high);
+        auto vres_low = QConverter::convert<float32x4x2_t, int16x8_t>(vtrans_low);
+        vst1q_f32(dst, vmulq_f32(vres_low.val[0], _vscale));
+        vst1q_f32(dst + 4, vmulq_f32(vres_low.val[1], _vscale));
+        vst1q_f32(dst + 8, vmulq_f32(vres_high.val[0], _vscale));
+        vst1q_f32(dst + 12, vmulq_f32(vres_high.val[1], _vscale));
+    }
+
+    void cvt_remain(const int8_t* src, float* dst) { *dst = *src * _scale; }
+};
+
+#if defined(__ARM_FEATURE_FMA)
+#define Vfmaq_f32(d, n, m) vfmaq_f32(d, n, m)
+#else
+#define Vfmaq_f32(d, n, m) vmlaq_f32(d, n, m)
+#endif
+
+template <>
+struct Quan2FloatTypeCvter<uint8_t, float> {
+    using stype = uint8_t;
+    using dst_type = float;
+    static constexpr size_t SIMD_WIDTH = 16;
+    float _scale = 0.0f;
+    float32x4_t _vscale;
+    uint8_t _zp = 0;
+    float32x4_t _vbias;
+
+    Quan2FloatTypeCvter(DType src_dtype, DType dst_dtype) {
+        _scale = src_dtype.param<dtype::Quantized8Asymm>().scale;
+        _vscale = vdupq_n_f32(_scale);
+        _zp = src_dtype.param<dtype::Quantized8Asymm>().zero_point;
+        float bias = -_zp * 1.0f * _scale;
+        _vbias = vdupq_n_f32(bias);
+        MEGDNN_MARK_USED_VAR(dst_dtype);
+    }
+
+    void cvt(const uint8_t* src, float* dst) {
+        uint8x16_t vitem = vld1q_u8(src);
+        uint16x8_t vtrans_high = vmovl_u8(vget_high_u8(vitem));
+        uint16x8_t vtrans_low = vmovl_u8(vget_low_u8(vitem));
+        auto vres_high = QConverter::convert<float32x4x2_t, uint16x8_t>(vtrans_high);
+        auto vres_low = QConverter::convert<float32x4x2_t, uint16x8_t>(vtrans_low);
+        vst1q_f32(dst, Vfmaq_f32(_vbias, vres_low.val[0], _vscale));
+        vst1q_f32(dst + 4, Vfmaq_f32(_vbias, vres_low.val[1], _vscale));
+        vst1q_f32(dst + 8, Vfmaq_f32(_vbias, vres_high.val[0], _vscale));
+        vst1q_f32(dst + 12, Vfmaq_f32(_vbias, vres_high.val[1], _vscale));
+    }
+
+    void cvt_remain(const uint8_t* src, float* dst) { *dst = (*src - _zp) * _scale; }
+};
+
+#undef Vfmaq_f32
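The asymmetric-quantized path folds the zero-point subtraction into a fused multiply-add: since (x - zp) * scale == x * scale + bias with bias = -zp * scale, the constructor precomputes _vbias once and cvt() needs a single Vfmaq_f32 per four lanes instead of a subtract plus a multiply. A scalar sanity check of that identity (illustration only; the two forms can differ by rounding in the last bits, hence the tolerance):

#include <cassert>
#include <cmath>

int main() {
    const float scale = 0.3f;  // sample values, mirroring the test dtypes below
    const int zp = 8;
    const float bias = -zp * scale;
    for (int x = 0; x < 256; ++x) {
        float direct = (x - zp) * scale;  // what cvt_remain computes
        float fused = x * scale + bias;   // what the SIMD path computes
        assert(std::fabs(direct - fused) < 1e-5f);
    }
    return 0;
}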
+
+template <typename stype, typename dtype>
+struct TypeCvtTask {
+    const stype* src;
+    dtype* dst;
+    size_t dim;
+    size_t nr_elems;
+
+    explicit TypeCvtTask(const stype* s, dtype* d, size_t n, size_t tot)
+            : src(s), dst(d), dim(n), nr_elems(tot) {}
+};
 
 template <typename TypeCvter>
 void do_typecvt(
         const typename TypeCvter::stype* src, typename TypeCvter::dst_type* dst,
         DType src_dtype, DType dst_dtype, const TensorLayout& src_layout) {
     TypeCvter typecvt(src_dtype, dst_dtype);
-    size_t calc_num = 1;
-    size_t nr_elems = src_layout.total_nr_elems();
-    size_t src_stride = nr_elems;
-
-    //! adjust calc_num nr_elems and src_stride according to src_collapse_layout
     auto src_collapse_layout = src_layout.collapse_contiguous();
-    if (src_collapse_layout.ndim == 2) {
-        calc_num = src_collapse_layout.shape[0];
-        nr_elems = src_collapse_layout.shape[1];
-        src_stride = src_collapse_layout.stride[0];
-    }
-
-    for (size_t c = 0; c < calc_num; ++c) {
-        size_t i = 0;
-        for (; i + TypeCvter::SIMD_WIDTH <= nr_elems; i += TypeCvter::SIMD_WIDTH) {
-            typecvt.cvt(src, dst);
-            src += TypeCvter::SIMD_WIDTH;
-            dst += TypeCvter::SIMD_WIDTH;
-        }
+
+    using TypeCvtTaskWithType =
+            TypeCvtTask<typename TypeCvter::stype, typename TypeCvter::dst_type>;
+    std::deque<TypeCvtTaskWithType> task_queue;
+    task_queue.emplace_back(src, dst, 0, src_collapse_layout.total_nr_elems());
+
+    while (!task_queue.empty()) {
+        auto&& task = task_queue.front();
+        const typename TypeCvter::stype* psrc = task.src;
+        typename TypeCvter::dst_type* pdst = task.dst;
+        size_t dim = task.dim;
+        size_t nr_elems = task.nr_elems;
+
+        //! calc according to stride information
+        if (src_collapse_layout.stride[dim] == 1) {
+            size_t i = 0;
+            for (; i + TypeCvter::SIMD_WIDTH < nr_elems; i += TypeCvter::SIMD_WIDTH) {
+                typecvt.cvt(psrc, pdst);
+                psrc += TypeCvter::SIMD_WIDTH;
+                pdst += TypeCvter::SIMD_WIDTH;
+            }
 #if MEGDNN_FIX_AARCH32_BUG
 // FIXME: as llvm may cause cannot select error if enable vectorize
 #pragma clang loop vectorize(disable)
 #endif
-        for (; i < nr_elems; i++) {
-            typecvt.cvt_remain(src, dst);
-            src++;
-            dst++;
+            for (; i < nr_elems; i++) {
+                typecvt.cvt_remain(psrc, pdst);
+                psrc++;
+                pdst++;
+            }
+        } else {
+            size_t calc_num = src_collapse_layout.shape[dim];
+            size_t src_stride = src_collapse_layout.stride[dim];
+            size_t dst_stride = nr_elems / calc_num;
+            for (size_t i = 0; i < calc_num; ++i) {
+                task_queue.emplace_back(psrc, pdst, dim + 1, dst_stride);
+                psrc += src_stride;
+                pdst += dst_stride;
+            }
         }
-        src += src_stride - nr_elems;
+
+        task_queue.pop_front();
     }
 }
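The queue walk above is breadth-first over dimensions: the initial task covers the whole collapsed tensor, any task whose current dim has a non-unit stride is split into one sub-task per index of that dim, and a task that reaches a stride-1 dim is converted as one contiguous run; dst advances densely because the output is contiguous. A simplified standalone model of the same decomposition (illustration only, not the MegDNN API), using a 2x3x4 tensor whose rows are padded from 4 to 5 elements:

#include <cstddef>
#include <cstdio>
#include <deque>

struct Task {
    size_t offset;    // element offset into the strided source
    size_t dim;       // dimension currently being split
    size_t nr_elems;  // logical elements covered by this task
};

int main() {
    size_t shape[] = {2, 3, 4};
    ptrdiff_t stride[] = {15, 5, 1};  // each row: 4 valid + 1 pad element
    std::deque<Task> q{{0, 0, 2 * 3 * 4}};
    while (!q.empty()) {
        Task t = q.front();
        q.pop_front();
        if (stride[t.dim] == 1) {
            // stride-1 dim reached: this span is contiguous in the source
            printf("contiguous run: offset %zu, %zu elems\n", t.offset, t.nr_elems);
        } else {
            size_t sub = t.nr_elems / shape[t.dim];
            for (size_t i = 0; i < shape[t.dim]; ++i)
                q.push_back({t.offset + i * stride[t.dim], t.dim + 1, sub});
        }
    }
    // prints six runs of 4 elements at offsets 0, 5, 10, 15, 20, 25
}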
@@ -432,15 +582,56 @@ void do_typecvt(
 void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
     DType src_dtype = src.layout.dtype;
     DType dst_dtype = dst.layout.dtype;
-    size_t nr_elems = src.layout.total_nr_elems();
     bool execed = false;
     auto src_collapse_layout = src.layout.collapse_contiguous();
-    bool has_int16_special_impl =
-            (src.layout.dtype.enumv() == DTypeEnum::Int16 ||
-             src.layout.dtype.enumv() == DTypeEnum::Uint16) &&
-            (src.layout.is_contiguous() || src_collapse_layout.ndim == 2) &&
-            dst.layout.is_contiguous();
-    if (has_int16_special_impl) {
+
+    if (src.layout.is_contiguous()) {
+        using namespace dtype;
+        size_t nr_elems = src.layout.total_nr_elems();
+#define DISPATCH_QUANTIZED(_stype_enumv, _stype, _dtype_enumv, _dtype, _midout_iv) \
+    if (src_dtype.enumv() == DTypeTrait<_stype_enumv>::enumv &&                    \
+        dst_dtype.enumv() == DTypeTrait<_dtype_enumv>::enumv) {                    \
+        MIDOUT_BEGIN(megdnn_arm_typecvt_quantized, midout_iv(_midout_iv)) {        \
+            using _TypeCvter = QuantizedTypeCvter<_stype, _dtype>;                 \
+            MEGDNN_DISPATCH_CPU_KERN_OPR(do_typecvt<_TypeCvter>(                   \
+                    src.compatible_ptr<_stype>(), dst.compatible_ptr<_dtype>(),    \
+                    src_dtype, dst_dtype, nr_elems));                              \
+            execed = true;                                                         \
+        }                                                                          \
+        MIDOUT_END();                                                              \
+    }
+        DISPATCH_QUANTIZED(QuantizedS32, int32_t, Quantized8Asymm, uint8_t, 0);
+        DISPATCH_QUANTIZED(QuantizedS32, int32_t, QuantizedS8, int8_t, 1);
+        DISPATCH_QUANTIZED(QuantizedS8, int8_t, QuantizedS32, int32_t, 2);
+        DISPATCH_QUANTIZED(QuantizedS8, int8_t, QuantizedS8, int8_t, 3);
+        DISPATCH_QUANTIZED(Quantized8Asymm, uint8_t, Quantized8Asymm, uint8_t, 4);
+        DISPATCH_QUANTIZED(QuantizedS32, int32_t, QuantizedS32, int32_t, 5);
+        DISPATCH_QUANTIZED(float, float, QuantizedS8, int8_t, 6);
+        DISPATCH_QUANTIZED(float, float, Quantized8Asymm, uint8_t, 7);
+#undef DISPATCH_QUANTIZED
+
+#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+#define DISPATCH_FLOAT(_stype_enumv, _stype, _dtype_enumv, _dtype, _midout_iv)      \
+    if (src_dtype.enumv() == DTypeTrait<_stype_enumv>::enumv &&                     \
+        dst_dtype.enumv() == DTypeTrait<_dtype_enumv>::enumv) {                     \
+        MIDOUT_BEGIN(megdnn_arm_typecvt_float, midout_iv(_midout_iv)) {             \
+            using _TypeCvter = FloatTypeCvter<_stype, _dtype>;                      \
+            MEGDNN_DISPATCH_CPU_KERN_OPR(do_typecvt<_TypeCvter>(                    \
+                    reinterpret_cast<_stype*>(src.raw_ptr()),                       \
+                    reinterpret_cast<_dtype*>(dst.raw_ptr()), src_dtype, dst_dtype, \
+                    nr_elems));                                                     \
+            execed = true;                                                          \
+        }                                                                           \
+        MIDOUT_END();                                                               \
+    }
+        DISPATCH_FLOAT(dt_float16, __fp16, float, float, 0);
+        DISPATCH_FLOAT(float, float, dt_float16, __fp16, 1);
+#undef DISPATCH_FLOAT
+#endif
+    }
+
+    size_t last_stride = src_collapse_layout.stride[src_collapse_layout.ndim - 1];
+    if (!execed && last_stride == 1 && dst.layout.is_contiguous()) {
         using namespace dtype;
 #define DISPATCH_FIX2FLOAT(_stype_enumv, _stype, _dtype_enumv, _dtype, _midout_iv) \
     if (src_dtype.enumv() == DTypeTrait<_stype_enumv>::enumv &&                    \
@@ -456,9 +647,26 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
     }
         DISPATCH_FIX2FLOAT(Int16, int16_t, Float32, float, 0);
         DISPATCH_FIX2FLOAT(Uint16, uint16_t, Float32, float, 1);
+        DISPATCH_FIX2FLOAT(Int8, int8_t, Float32, float, 2);
+        DISPATCH_FIX2FLOAT(Uint8, uint8_t, Float32, float, 3);
 #undef DISPATCH_FIX2FLOAT
-    } else if (src.layout.is_contiguous()) {
-        using namespace dtype;
+
+#define DISPATCH_QUAN2FLOAT(_stype_enumv, _stype, _dtype_enumv, _dtype, _midout_iv) \
+    if (src_dtype.enumv() == DTypeTrait<_stype_enumv>::enumv &&                     \
+        dst_dtype.enumv() == DTypeTrait<_dtype_enumv>::enumv) {                     \
+        MIDOUT_BEGIN(megdnn_arm_typecvt_quan2float, midout_iv(_midout_iv)) {        \
+            using _TypeCvter = Quan2FloatTypeCvter<_stype, _dtype>;                 \
+            MEGDNN_DISPATCH_CPU_KERN_OPR(do_typecvt<_TypeCvter>(                    \
+                    src.compatible_ptr<_stype>(), dst.compatible_ptr<_dtype>(),     \
+                    src_dtype, dst_dtype, src.layout));                             \
+            execed = true;                                                          \
+        }                                                                           \
+        MIDOUT_END();                                                               \
+    }
+        DISPATCH_QUAN2FLOAT(QuantizedS8, int8_t, Float32, float, 0);
+        DISPATCH_QUAN2FLOAT(Quantized8Asymm, uint8_t, Float32, float, 1);
+#undef DISPATCH_QUAN2FLOAT
+
 #define DISPATCH_QUANTIZED(_stype_enumv, _stype, _dtype_enumv, _dtype, _midout_iv) \
     if (src_dtype.enumv() == DTypeTrait<_stype_enumv>::enumv &&                    \
        dst_dtype.enumv() == DTypeTrait<_dtype_enumv>::enumv) {                     \
@@ -466,12 +674,11 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
             using _TypeCvter = QuantizedTypeCvter<_stype, _dtype>;                 \
             MEGDNN_DISPATCH_CPU_KERN_OPR(do_typecvt<_TypeCvter>(                   \
                     src.compatible_ptr<_stype>(), dst.compatible_ptr<_dtype>(),    \
-                    src_dtype, dst_dtype, nr_elems));                              \
+                    src_dtype, dst_dtype, src.layout));                            \
             execed = true;                                                         \
         }                                                                          \
         MIDOUT_END();                                                              \
     }
-
     DISPATCH_QUANTIZED(QuantizedS32, int32_t, Quantized8Asymm, uint8_t, 0);
     DISPATCH_QUANTIZED(QuantizedS32, int32_t, QuantizedS8, int8_t, 1);
     DISPATCH_QUANTIZED(QuantizedS8, int8_t, QuantizedS32, int32_t, 2);
@@ -491,7 +698,7 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
             MEGDNN_DISPATCH_CPU_KERN_OPR(do_typecvt<_TypeCvter>(                    \
                     reinterpret_cast<_stype*>(src.raw_ptr()),                       \
                     reinterpret_cast<_dtype*>(dst.raw_ptr()), src_dtype, dst_dtype, \
-                    nr_elems));                                                     \
+                    src.layout));                                                   \
             execed = true;                                                          \
         }                                                                           \
         MIDOUT_END();                                                               \
@@ -501,6 +708,7 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
 #undef DISPATCH_FLOAT
 #endif
     }
+
     if (!execed) {
         fallback::TypeCvtImpl::exec(src, dst);
     }
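The test changes below drop the old int16/uint16-only TYPE_CVT_16_F32 case and instead sweep every dispatched conversion over non-contiguous source layouts, built by padding the row pitch from W to W + 8. A small sketch of the offset arithmetic encoded by the {C*H*(W+8), H*(W+8), W+8, 1} stride vector used in the first test (plain C++, illustration only):

#include <cstddef>
#include <cstdio>

int main() {
    const size_t H = 64, W = 120, pitch = W + 8;  // shape {1, 96, 64, 120}
    auto offset = [&](size_t c, size_t h, size_t w) {
        return c * H * pitch + h * pitch + w;  // element offsets, not bytes
    };
    // last valid element of row 0 vs first element of row 1:
    printf("(0,0,119) -> %zu, (0,1,0) -> %zu: 8 padding elements between rows\n",
           offset(0, 0, W - 1), offset(0, 1, 0));
}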
diff --git a/dnn/test/arm_common/type_cvt.cpp b/dnn/test/arm_common/type_cvt.cpp
index a06b46f1..427ef303 100644
--- a/dnn/test/arm_common/type_cvt.cpp
+++ b/dnn/test/arm_common/type_cvt.cpp
@@ -161,24 +161,251 @@ TEST_F(ARM_COMMON, TYPE_CVT_RECORD) {
             .execs({{1, 32, 24, 128}, {1, 32, 24, 128}});
 }
 
-TEST_F(ARM_COMMON, TYPE_CVT_16_F32) {
+TEST_F(ARM_COMMON, TYPE_CVT_NONCONTIGUOUS) {
+    UniformIntRNG rng32{INT32_MIN >> 1, INT32_MAX >> 1};
+    UniformIntRNG rng16{INT16_MIN >> 1, INT16_MAX >> 1};
+    UniformIntRNG rng8{INT8_MIN >> 1, INT8_MAX >> 1};
+
     Checker<TypeCvt> checker(handle());
-    UniformIntRNG rng{INT16_MIN >> 1, INT16_MAX >> 1};
+
+    size_t N = 1;
+    size_t C = 96;
+    size_t H = 64;
+    size_t W = 120;
+    TensorShape shape{N, C, H, W};
+    std::vector<ptrdiff_t> stride{
+            static_cast<ptrdiff_t>(C * H * (W + 8)), static_cast<ptrdiff_t>(H * (W + 8)),
+            static_cast<ptrdiff_t>(W + 8), 1};
+    TensorLayout src, dst;
+
+    //! float32 -> float16
+    src = TensorLayout{shape, stride, dtype::Float32()};
+    dst = TensorLayout{shape, dtype::Float16()};
+    checker.execl({src, dst});
+
+    //! float16 -> float32
+    src = TensorLayout{shape, stride, dtype::Float16()};
+    dst = TensorLayout{shape, dtype::Float32()};
+    checker.execl({src, dst});
+
+    //! float -> s8
+    src = TensorLayout{shape, stride, dtype::Float32()};
+    dst = TensorLayout{shape, dtype::QuantizedS8(0.245121f)};
+    checker.execl({src, dst});
+
+    //! float -> as8
+    src = TensorLayout{shape, stride, dtype::Float32()};
+    dst = TensorLayout{shape, dtype::Quantized8Asymm(0.1f, static_cast<uint8_t>(3))};
+    checker.execl({src, dst});
+
+    checker.set_rng(0, &rng32);
+    //! s32 -> as8
+    src = TensorLayout{shape, stride, dtype::QuantizedS32(0.0000113264f)};
+    dst = TensorLayout{
+            shape, dtype::Quantized8Asymm(0.018909f, static_cast<uint8_t>(3))};
+    checker.execl({src, dst});
+
+    src = TensorLayout{shape, stride, dtype::QuantizedS32(0.0003f)};
+    dst = TensorLayout{shape, dtype::Quantized8Asymm(0.1f, static_cast<uint8_t>(3))};
+    checker.execl({src, dst});
+
+    //! s32 -> s8
+    src = TensorLayout{shape, stride, dtype::QuantizedS32(0.000815917f)};
+    dst = TensorLayout{shape, dtype::QuantizedS8(0.245121f)};
+    checker.execl({src, dst});
+
+    src = TensorLayout{shape, stride, dtype::QuantizedS32(0.0003f)};
+    dst = TensorLayout{shape, dtype::QuantizedS8(0.2f)};
+    checker.execl({src, dst});
+
+    checker.set_rng(0, &rng8);
+    //! s32 -> s32
+    src = TensorLayout{shape, stride, dtype::QuantizedS32(0.0004f)};
+    dst = TensorLayout{shape, dtype::QuantizedS32(0.0002f)};
+    checker.execl({src, dst});
+
+    //! s8 -> s8
+    src = TensorLayout{shape, stride, dtype::QuantizedS8(0.3f)};
+    dst = TensorLayout{shape, dtype::QuantizedS8(0.2f)};
+    checker.execl({src, dst});
+
+    //! as8 -> as8
+    src = TensorLayout{
+            shape, stride, dtype::Quantized8Asymm(0.3f, static_cast<uint8_t>(8))};
+    dst = TensorLayout{shape, dtype::Quantized8Asymm(0.1f, static_cast<uint8_t>(3))};
+    checker.execl({src, dst});
+
+    //! s8 -> s32
+    src = TensorLayout{shape, stride, dtype::QuantizedS8(0.245121f)};
+    dst = TensorLayout{shape, dtype::QuantizedS32(0.000815917f)};
+    checker.execl({src, dst});
+
+    src = TensorLayout{shape, stride, dtype::QuantizedS8(0.2f)};
+    dst = TensorLayout{shape, dtype::QuantizedS32(0.0003f)};
+    checker.execl({src, dst});
+
+    //! s8 -> float
+    src = TensorLayout{shape, stride, dtype::QuantizedS8(0.3f)};
+    dst = TensorLayout{shape, dtype::Float32()};
+    checker.execl({src, dst});
+
+    //! as8 -> float
+    src = TensorLayout{
+            shape, stride, dtype::Quantized8Asymm(0.3f, static_cast<uint8_t>(8))};
+    dst = TensorLayout{shape, dtype::Float32()};
+    checker.execl({src, dst});
+
+    //! int8/uint8 -> float
+    src = TensorLayout{shape, stride, dtype::Int8()};
+    dst = TensorLayout{shape, dtype::Float32()};
+    checker.execl({src, dst});
+
+    src = TensorLayout{shape, stride, dtype::Uint8()};
+    dst = TensorLayout{shape, dtype::Float32()};
+    checker.execl({src, dst});
+
+    //! int16/uint16 -> float
+    checker.set_rng(0, &rng16);
     for (size_t size : {3, 7, 15, 33, 10000}) {
-        checker.set_rng(0, &rng);
         checker.set_dtype(0, dtype::Int16()).execs({{size}, {size}});
         checker.set_dtype(0, dtype::Uint16()).execs({{size}, {size}});
     }
-    TensorLayout src_int16{
-            {1, 96, 64, 120}, {128 * 64 * 96, 128 * 64, 128, 1}, dtype::Int16()};
-    TensorLayout dst_int16{{1, 96, 64, 120}, dtype::Float32()};
-    checker.execl({src_int16, dst_int16});
-
-    TensorLayout src_uint16{
-            {1, 96, 64, 120}, {128 * 64 * 96, 128 * 64, 128, 1}, dtype::Uint16()};
-    TensorLayout dst_uint16{{1, 96, 64, 120}, dtype::Float32()};
-    checker.execl({src_uint16, dst_uint16});
+
+    src = TensorLayout{shape, stride, dtype::Int16()};
+    dst = TensorLayout{shape, dtype::Float32()};
+    checker.execl({src, dst});
+
+    src = TensorLayout{shape, stride, dtype::Uint16()};
+    dst = TensorLayout{shape, dtype::Float32()};
+    checker.execl({src, dst});
+
+    UniformIntRNG narrow_rng{-40000, 40000};
+    checker.set_rng(0, &narrow_rng);
+    //! s32 -> as8
+    src = TensorLayout{shape, stride, dtype::QuantizedS32(0.000163794f)};
+    dst = TensorLayout{
+            shape, dtype::Quantized8Asymm(0.0479196f, static_cast<uint8_t>(144))};
+    checker.execl({src, dst});
+}
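TYPE_CVT_NONCONTIGUOUS pads only the innermost W dimension, so collapse_contiguous folds the four dims down to two and a single level of task splitting already reaches contiguous rows. The TYPE_CVT_MONOTONOUS test below pads C, H, and W at once; no adjacent dims can merge, so the task queue must recurse through several strided dims before it reaches the stride-1 runs, exercising the multi-level path of the new do_typecvt.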
+
+TEST_F(ARM_COMMON, TYPE_CVT_MONOTONOUS) {
+    UniformIntRNG rng32{INT32_MIN >> 1, INT32_MAX >> 1};
+    UniformIntRNG rng16{INT16_MIN >> 1, INT16_MAX >> 1};
+    UniformIntRNG rng8{INT8_MIN >> 1, INT8_MAX >> 1};
+
+    Checker<TypeCvt> checker(handle());
+
+    size_t N = 1;
+    size_t C = 96;
+    size_t H = 64;
+    size_t W = 120;
+    TensorShape shape{N, C, H, W};
+    std::vector<ptrdiff_t> stride{
+            static_cast<ptrdiff_t>((C + 8) * (H + 8) * (W + 8)),
+            static_cast<ptrdiff_t>((H + 8) * (W + 8)),
+            static_cast<ptrdiff_t>(W + 8), 1};
+    TensorLayout src, dst;
+
+    //! float32 -> float16
+    src = TensorLayout{shape, stride, dtype::Float32()};
+    dst = TensorLayout{shape, dtype::Float16()};
+    checker.execl({src, dst});
+
+    //! float16 -> float32
+    src = TensorLayout{shape, stride, dtype::Float16()};
+    dst = TensorLayout{shape, dtype::Float32()};
+    checker.execl({src, dst});
+
+    //! float -> s8
+    src = TensorLayout{shape, stride, dtype::Float32()};
+    dst = TensorLayout{shape, dtype::QuantizedS8(0.245121f)};
+    checker.execl({src, dst});
+
+    //! float -> as8
+    src = TensorLayout{shape, stride, dtype::Float32()};
+    dst = TensorLayout{shape, dtype::Quantized8Asymm(0.1f, static_cast<uint8_t>(3))};
+    checker.execl({src, dst});
+
+    checker.set_rng(0, &rng32);
+    //! s32 -> as8
+    src = TensorLayout{shape, stride, dtype::QuantizedS32(0.0000113264f)};
+    dst = TensorLayout{
+            shape, dtype::Quantized8Asymm(0.018909f, static_cast<uint8_t>(3))};
+    checker.execl({src, dst});
+
+    src = TensorLayout{shape, stride, dtype::QuantizedS32(0.0003f)};
+    dst = TensorLayout{shape, dtype::Quantized8Asymm(0.1f, static_cast<uint8_t>(3))};
+    checker.execl({src, dst});
+
+    //! s32 -> s8
+    src = TensorLayout{shape, stride, dtype::QuantizedS32(0.000815917f)};
+    dst = TensorLayout{shape, dtype::QuantizedS8(0.245121f)};
+    checker.execl({src, dst});
+
+    src = TensorLayout{shape, stride, dtype::QuantizedS32(0.0003f)};
+    dst = TensorLayout{shape, dtype::QuantizedS8(0.2f)};
+    checker.execl({src, dst});
+
+    checker.set_rng(0, &rng8);
+    //! s32 -> s32
+    src = TensorLayout{shape, stride, dtype::QuantizedS32(0.0004f)};
+    dst = TensorLayout{shape, dtype::QuantizedS32(0.0002f)};
+    checker.execl({src, dst});
+
+    //! s8 -> s8
+    src = TensorLayout{shape, stride, dtype::QuantizedS8(0.3f)};
+    dst = TensorLayout{shape, dtype::QuantizedS8(0.2f)};
+    checker.execl({src, dst});
+
+    //! as8 -> as8
+    src = TensorLayout{
+            shape, stride, dtype::Quantized8Asymm(0.3f, static_cast<uint8_t>(8))};
+    dst = TensorLayout{shape, dtype::Quantized8Asymm(0.1f, static_cast<uint8_t>(3))};
+    checker.execl({src, dst});
+
+    //! s8 -> s32
+    src = TensorLayout{shape, stride, dtype::QuantizedS8(0.245121f)};
+    dst = TensorLayout{shape, dtype::QuantizedS32(0.000815917f)};
+    checker.execl({src, dst});
+
+    src = TensorLayout{shape, stride, dtype::QuantizedS8(0.2f)};
+    dst = TensorLayout{shape, dtype::QuantizedS32(0.0003f)};
+    checker.execl({src, dst});
+
+    //! s8 -> float
+    src = TensorLayout{shape, stride, dtype::QuantizedS8(0.3f)};
+    dst = TensorLayout{shape, dtype::Float32()};
+    checker.execl({src, dst});
+
+    //! as8 -> float
+    src = TensorLayout{
+            shape, stride, dtype::Quantized8Asymm(0.3f, static_cast<uint8_t>(8))};
+    dst = TensorLayout{shape, dtype::Float32()};
+    checker.execl({src, dst});
+
+    //! int8/uint8 -> float
+    src = TensorLayout{shape, stride, dtype::Int8()};
+    dst = TensorLayout{shape, dtype::Float32()};
+    checker.execl({src, dst});
+
+    src = TensorLayout{shape, stride, dtype::Uint8()};
+    dst = TensorLayout{shape, dtype::Float32()};
+    checker.execl({src, dst});
+
+    src = TensorLayout{shape, stride, dtype::Int16()};
+    dst = TensorLayout{shape, dtype::Float32()};
+    checker.execl({src, dst});
+
+    src = TensorLayout{shape, stride, dtype::Uint16()};
+    dst = TensorLayout{shape, dtype::Float32()};
+    checker.execl({src, dst});
+
+    UniformIntRNG narrow_rng{-40000, 40000};
+    checker.set_rng(0, &narrow_rng);
+    //! s32 -> as8
+    src = TensorLayout{shape, stride, dtype::QuantizedS32(0.000163794f)};
+    dst = TensorLayout{
+            shape, dtype::Quantized8Asymm(0.0479196f, static_cast<uint8_t>(144))};
+    checker.execl({src, dst});
+}
 
 #if MEGDNN_WITH_BENCHMARK