@@ -62,7 +62,7 @@ namespace megdnn { | |||
#define MEGDNN_FOREACH_PARAMETERIZED_DTYPE_OTHERS(cb) \ | |||
cb(QuantizedS32) cb(QuantizedS8) cb(Quantized4Asymm) cb(QuantizedS4) \ | |||
cb(QuantizedS16) | |||
cb(QuantizedS16) cb(QuantizedS1) | |||
#define MEGDNN_FOREACH_PARAMETERIZED_DTYPE_2(cb_first, cb_others) \ | |||
MEGDNN_FOREACH_PARAMETERIZED_DTYPE_FIRST(cb_first) \ | |||
@@ -112,7 +112,7 @@ namespace megdnn { | |||
#define MEGDNN_FOREACH_QUANTIZED_DTYPE_SYMM(cb) \ | |||
cb(::megdnn::dtype::QuantizedS32) cb(::megdnn::dtype::QuantizedS8) \ | |||
cb(::megdnn::dtype::QuantizedS4) | |||
cb(::megdnn::dtype::QuantizedS4) cb(::megdnn::dtype::QuantizedS1) | |||
#define MEGDNN_FOREACH_QUANTIZED_DTYPE_ASYMM(cb) \ | |||
cb(::megdnn::dtype::Quantized8Asymm) cb(::megdnn::dtype::Quantized4Asymm) | |||
@@ -292,10 +292,27 @@ public: | |||
}; | |||
using dt_qint4 = dt_qlowbit<4>; | |||
class dt_qint1 { | |||
int8_t _; | |||
public: | |||
MEGDNN_DEVICE int8_t as_int8() const { return _; } | |||
MEGDNN_HOST MEGDNN_DEVICE explicit dt_qint1(int8_t val) : _(val) {} | |||
#ifdef MEGDNN_CC_HOST | |||
explicit operator int8_t() { return _; } | |||
#endif | |||
bool operator<(const dt_qint1& b) const { return _ < b._; } | |||
bool operator>(const dt_qint1& b) const { return _ > b._; } | |||
bool operator==(const dt_qint1& b) const { return _ == b._; } | |||
bool operator!=(const dt_qint1& b) const { return _ != b._; } | |||
} MEGDNN_PACKED; | |||
#ifdef __clang__ | |||
#pragma clang diagnostic pop | |||
#endif | |||
MEGDNN_STATIC_ASSERT(sizeof(dt_byte) == 1, "bad dt_byte size"); | |||
MEGDNN_STATIC_ASSERT(sizeof(dt_qint1) == 1, "bad dt_qint1 size"); | |||
MEGDNN_STATIC_ASSERT(sizeof(dt_quint8) == 1, "bad dt_quint8 size"); | |||
MEGDNN_STATIC_ASSERT(sizeof(dt_qint16) == 2, "bad dt_qint16 size"); | |||
MEGDNN_STATIC_ASSERT(sizeof(dt_qint32) == 4, "bad dt_qint32 size"); | |||
@@ -677,7 +694,7 @@ MEGDNN_FOREACH_LOWBIT_DTYPE(MEGDNN_DEF_FRACTION_DT) | |||
return static_cast<_itype>(_maxval); \ | |||
} \ | |||
}; | |||
MEGDNN_DEF_PARAMETERIZED_DT(QuantizedS1, dt_qint1, int8_t, QUANTIZED, SIGNED, 0, 1, 0); | |||
MEGDNN_DEF_PARAMETERIZED_DT( | |||
Quantized4Asymm, dt_quint4, uint8_t, QUANTIZED, SIGNED, 0, 15, 4); | |||
MEGDNN_DEF_PARAMETERIZED_DT(QuantizedS4, dt_qint4, int8_t, QUANTIZED, SIGNED, -8, 7, 4); | |||
@@ -877,6 +894,26 @@ struct DTypeParamImpl<dt_quint4> { | |||
}; | |||
template <> | |||
struct DTypeParamImpl<dt_qint1> { | |||
float scale; | |||
DTypeParamImpl<dt_qint1>() = default; | |||
MGE_WIN_DECLSPEC_FUC DTypeParamImpl<dt_qint1>(float scale); | |||
#ifdef MEGDNN_CC_HOST | |||
std::size_t hash() const; | |||
#endif | |||
bool operator==(const DTypeParam<dt_qint1>& rhs) const; | |||
MEGDNN_DEVICE dt_qint1 quantize(float in) const { | |||
float v = in / scale; | |||
v = roundf(v); | |||
v = fmin(fmax(0.f, v), 1.f); | |||
return static_cast<dt_qint1>(v); | |||
} | |||
MEGDNN_DEVICE float dequantize(int8_t in) const { return in * scale; } | |||
MEGDNN_DEVICE float dequantize(dt_qint1 in) const { return in.as_int8() * scale; } | |||
}; | |||
template <> | |||
struct DTypeParamImpl<dt_qint4> { | |||
float scale; | |||
@@ -142,6 +142,19 @@ inline bool DTypeParam<dt_qint32>::operator==(const DTypeParam<dt_qint32>& rhs) | |||
return scale == rhs.scale; | |||
} | |||
DTypeParam<dt_qint1>::DTypeParamImpl(float scale) : scale{scale} { | |||
//! As the nan is not equal to any value | |||
megdnn_assert(!std::isnan(scale), "nan number compare is not support"); | |||
} | |||
inline std::size_t DTypeParam<dt_qint1>::hash() const { | |||
return std::hash<float>()(scale); | |||
} | |||
inline bool DTypeParam<dt_qint1>::operator==(const DTypeParam<dt_qint1>& rhs) const { | |||
return scale == rhs.scale; | |||
} | |||
DTypeParam<dt_quint4>::DTypeParamImpl(float scale, uint8_t zero_point) | |||
: scale{scale}, zero_point{zero_point} { | |||
//! As the nan is not equal to any value | |||
@@ -241,6 +241,7 @@ float megdnn::mul_scale(DType lhs, DType rhs) { | |||
return lhs.param<dt>().scale * rhs.param<dt>().scale; | |||
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) | |||
MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) | |||
cb(::megdnn::dtype::QuantizedS1) | |||
#undef cb | |||
megdnn_assert_internal(0); | |||
} | |||
@@ -253,8 +254,9 @@ float megdnn::get_scale(DType dt) { | |||
return dt.param<_dt>().scale; | |||
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) | |||
MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) | |||
cb(::megdnn::dtype::QuantizedS1) | |||
#undef cb | |||
megdnn_assert_internal(0); | |||
megdnn_assert_internal(0); | |||
} | |||
bool megdnn::dtype_almost_equal(DType lhs, DType rhs) { | |||
@@ -160,6 +160,9 @@ INST_FOR_CTYPE | |||
#define ct dt_bool | |||
INST_FOR_CTYPE | |||
#undef ct | |||
#define ct dt_qint1 | |||
INST_FOR_CTYPE | |||
#undef ct | |||
#undef INST_FOR_CTYPE | |||
#undef INST | |||
@@ -210,6 +213,9 @@ INST_FOR_CTYPE | |||
#define ct dt_bool | |||
INST_FOR_CTYPE | |||
#undef ct | |||
#define ct dt_qint1 | |||
INST_FOR_CTYPE | |||
#undef ct | |||
#undef ndim_cb | |||
@@ -221,6 +227,7 @@ INST(dt_int8); | |||
INST(dt_uint8); | |||
INST(dt_bool); | |||
INST(dt_qint8); | |||
INST(dt_qint1); | |||
INST(dt_quint8); | |||
#undef dt_ibyte | |||
@@ -96,6 +96,7 @@ INST(dt_bool, uchar4); | |||
#undef as_raw | |||
#define as_raw(x) x.as_int8() | |||
INST(dt_qint8, char4); | |||
INST(dt_qint1, char4); | |||
#undef as_raw | |||
#define as_raw(x) x.as_uint8() | |||
INST(dt_quint8, uchar4); | |||
@@ -466,6 +467,7 @@ INST_PARAM_VECT_VISITOR; | |||
INST_DT_IBYTE(dt_int8); | |||
INST_DT_IBYTE(dt_uint8); | |||
INST_DT_IBYTE(dt_qint8); | |||
INST_DT_IBYTE(dt_qint1); | |||
INST_DT_IBYTE(dt_quint8); | |||
INST_DT_IBYTE(dt_bool); | |||
#undef INST_DT_IBYTE | |||
@@ -1299,6 +1301,7 @@ private: | |||
INST_DT_IBYTE(dt_int8); | |||
INST_DT_IBYTE(dt_uint8); | |||
INST_DT_IBYTE(dt_qint8); | |||
INST_DT_IBYTE(dt_qint1); | |||
INST_DT_IBYTE(dt_quint8); | |||
INST_DT_IBYTE(dt_bool); | |||
#undef INST_DT_IBYTE | |||
@@ -1649,6 +1652,7 @@ public: | |||
INST_DT_IBYTE(dt_int8); | |||
INST_DT_IBYTE(dt_uint8); | |||
INST_DT_IBYTE(dt_qint8); | |||
INST_DT_IBYTE(dt_qint1); | |||
INST_DT_IBYTE(dt_quint8); | |||
INST_DT_IBYTE(dt_bool); | |||
#undef INST_DT_IBYTE | |||
@@ -88,6 +88,7 @@ struct TypeCvtOpToQuantized< | |||
typename std::enable_if< | |||
std::is_same<ctype_src, dt_int8>::value || | |||
std::is_same<ctype_src, dt_uint8>::value || | |||
std::is_same<ctype_src, dt_qint1>::value || | |||
std::is_same<ctype_src, dt_bool>::value>::type> { | |||
ctype_dest* dest; | |||
CudaDTypeParam<ctype_dest> param; | |||
@@ -111,6 +112,7 @@ struct TypeCvtOpFromQuantized< | |||
ctype_dest, ctype_src, | |||
typename std::enable_if< | |||
std::is_same<ctype_src, dt_qint8>::value || | |||
std::is_same<ctype_src, dt_qint1>::value || | |||
std::is_same<ctype_src, dt_quint8>::value>::type> { | |||
ctype_dest* dest; | |||
CudaDTypeParam<ctype_src> param; | |||
@@ -134,7 +136,8 @@ struct TypeCvtOpBetweenQuantized< | |||
ctype_dest, ctype_src, | |||
typename std::enable_if< | |||
(std::is_same<ctype_src, dt_qint8>::value || | |||
std::is_same<ctype_src, dt_quint8>::value) && | |||
std::is_same<ctype_src, dt_quint8>::value || | |||
std::is_same<ctype_src, dt_qint1>::value) && | |||
IsNotTypeQ4<ctype_dest>::value>::type> { | |||
ctype_dest* dest; | |||
CudaDTypeParam<ctype_src> src_param; | |||
@@ -306,6 +309,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src, cudaStream_t st | |||
cb(dtype_src, dt_quint8) \ | |||
cb(dtype_src, dt_qint32) \ | |||
cb(dtype_src, dt_qint8) \ | |||
cb(dtype_src, dt_qint1) \ | |||
#define INST_SRC_QUANTIZED(dtype_src) \ | |||
MEGDNN_FOREACH_COMPUTING_DTYPE_WITH_DTYPE_SRC(dtype_src, INST_Q2N) \ | |||
@@ -330,7 +334,8 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src, cudaStream_t st | |||
cb(dt_qint32) \ | |||
cb(dt_qint8) \ | |||
cb(dt_qint4) \ | |||
cb(dt_quint4) | |||
cb(dt_quint4) \ | |||
cb(dt_qint1) | |||
MEGDNN_FOREACH_QUANTIZED_CTYPE(INST_SRC_QUANTIZED) | |||
MEGDNN_FOREACH_COMPUTING_CTYPE(INST_SRC_NORMAL) | |||
@@ -50,6 +50,7 @@ void exec_src_quantized( | |||
return; \ | |||
} | |||
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb); | |||
cb(::megdnn::dtype::QuantizedS1); | |||
default: | |||
megdnn_assert_internal(0); | |||
#undef cb | |||
@@ -101,6 +102,7 @@ void exec_src_normal(const TensorND& dst, const TensorND& src, cudaStream_t stre | |||
return; \ | |||
} | |||
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb); | |||
cb(::megdnn::dtype::QuantizedS1); | |||
#undef cb | |||
default: | |||
megdnn_assert_internal(0); | |||
@@ -150,9 +152,9 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { | |||
} | |||
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) | |||
MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) | |||
cb(::megdnn::dtype::QuantizedS1) | |||
#undef cb | |||
default: | |||
megdnn_assert_internal(0); | |||
default : megdnn_assert_internal(0); | |||
} | |||
} | |||
} | |||
@@ -241,6 +241,23 @@ struct CudaDTypeParamImpl<dt_qint4> : DTypeParamImpl<dt_qint4> { | |||
} | |||
}; | |||
template <> | |||
struct CudaDTypeParamImpl<dt_qint1> : DTypeParamImpl<dt_qint1> { | |||
float inv_scale; | |||
CudaDTypeParamImpl() = default; | |||
CudaDTypeParamImpl(float scale) | |||
: DTypeParamImpl<dt_qint1>(scale), inv_scale(1.0f / scale) {} | |||
CudaDTypeParamImpl(const DTypeParamImpl<dt_qint1>& param) | |||
: CudaDTypeParamImpl(param.scale) {} | |||
__device__ dt_qint1 quantize(float in) const { | |||
float v = in * inv_scale; | |||
v = roundf(v); | |||
v = fmin(fmax(0.f, v), 1.f); | |||
return static_cast<dt_qint1>(v); | |||
} | |||
}; | |||
#if MEGDNN_CC_CUDA | |||
static inline MEGDNN_DEVICE void dot_prod(int a, int b, int c, int& d) { | |||
#if __CUDA_ARCH__ >= 610 | |||
@@ -510,7 +510,9 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { | |||
}; | |||
if (src.layout.is_contiguous() && dst.layout.is_contiguous() && | |||
!is_quantize_lowbit(src.layout.dtype) && | |||
!is_quantize_lowbit(dst.layout.dtype)) { | |||
!is_quantize_lowbit(dst.layout.dtype) && | |||
dst.layout.dtype.enumv() != DTypeEnum::QuantizedS1 && | |||
src.layout.dtype.enumv() != DTypeEnum::QuantizedS1) { | |||
MEGDNN_DISPATCH_CPU_KERN_OPR(run_contiguous(src, dst)); | |||
} else { | |||
naive::TypeCvtImpl::exec(src, dst); | |||
@@ -79,8 +79,9 @@ void on_dest_ctype(HandleImpl* handle, const TensorND& dest, const TensorND& src | |||
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) | |||
MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) | |||
cb(::megdnn::dtype::Bool) cb(::megdnn::dtype::Uint16) | |||
cb(::megdnn::dtype::QuantizedS1) | |||
#undef cb | |||
default : megdnn_throw("bad dtype"); | |||
default : megdnn_throw("bad dtype"); | |||
} | |||
} | |||
@@ -100,8 +101,9 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { | |||
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) | |||
MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) | |||
cb(::megdnn::dtype::Bool) cb(::megdnn::dtype::Uint16) | |||
cb(::megdnn::dtype::QuantizedS1) | |||
#undef cb | |||
default : megdnn_throw("bad dtype"); | |||
default : megdnn_throw("bad dtype"); | |||
} | |||
} | |||
@@ -79,7 +79,8 @@ template <typename ctype> | |||
const char* expr0, const char* expr1, const TensorND& v0, const TensorND& v1, | |||
float maxerr, float maxerr_avg, float maxerr_avg_biased, bool allow_invalid) { | |||
if (!std::is_same<ctype, dt_qint4>::value && | |||
!std::is_same<ctype, dt_quint4>::value) { | |||
!std::is_same<ctype, dt_quint4>::value && | |||
!std::is_same<ctype, dt_qint1>::value) { | |||
if (v0.layout.is_physical_contiguous() && v1.layout.is_physical_contiguous()) { | |||
return assert_tensor_eq_with_iter<ctype>( | |||
expr0, expr1, v0.ptr<ctype>(), v1.ptr<ctype>(), v0.layout, maxerr, | |||
@@ -158,7 +159,7 @@ void copy_tensors( | |||
//! In order to avoid an unnecessary increase in binary size, we just | |||
//! use QuantizedS16 dtype in winograd_filter_preprocess now. | |||
cb(::megdnn::dtype::QuantizedS16) MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) | |||
cb(::megdnn::dtype::Uint16) | |||
cb(::megdnn::dtype::Uint16) cb(::megdnn::dtype::QuantizedS1) | |||
#undef cb | |||
default : megdnn_trap(); | |||
} | |||
@@ -71,6 +71,32 @@ TEST(TestDType, TestQuantized8Asymm) { | |||
EXPECT_ANY_THROW(DType::from_enum(DTypeEnum::Quantized8Asymm)); | |||
} | |||
TEST(TestDType, QuantizedS1) { | |||
using namespace megdnn; | |||
dtype::QuantizedS1 qint1(0.1f); | |||
EXPECT_EQ(qint1.size(1), 1u); | |||
EXPECT_FLOAT_EQ(qint1.param().scale, 0.1f); | |||
dtype::QuantizedS1 qint1_copy = qint1; | |||
EXPECT_NO_THROW(qint1_copy.assert_is(qint1)); | |||
EXPECT_FLOAT_EQ(qint1_copy.param().scale, 0.1f); | |||
dtype::QuantizedS1 qint1_reconstruct_with_same_param(0.1f); | |||
EXPECT_NO_THROW(qint1_reconstruct_with_same_param.assert_is(qint1)); | |||
dtype::QuantizedS1 qint1_diff(0.2f); | |||
EXPECT_ANY_THROW(qint1_diff.assert_is(qint1)); | |||
DType parent = qint1; | |||
ASSERT_NO_THROW(dtype::QuantizedS1::downcast_from(parent)); | |||
auto param = dtype::QuantizedS1::downcast_from(parent).param(); | |||
EXPECT_FLOAT_EQ(param.scale, 0.1f); | |||
EXPECT_ANY_THROW(dtype::QuantizedS1::downcast_from(dtype::IntB1())); | |||
EXPECT_ANY_THROW(DType::from_enum(DTypeEnum::QuantizedS1)); | |||
} | |||
TEST(TestDType, TestQuantizedS4) { | |||
using namespace megdnn; | |||
@@ -149,7 +149,7 @@ void IIDRNG::gen(const TensorND& tensor) { | |||
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) | |||
//! In order to avoid an unnecessary increase in binary size, we just | |||
//! use QuantizedS16 dtype in winograd_filter_preprocess now. | |||
cb(::megdnn::dtype::QuantizedS16) | |||
cb(::megdnn::dtype::QuantizedS16) cb(::megdnn::dtype::QuantizedS1) | |||
#undef cb | |||
if (tensor.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) { | |||
auto ptr = static_cast<uint8_t*>(tensor.raw_ptr()); | |||
@@ -226,6 +226,10 @@ static inline int diff(dt_qint4 x, dt_qint4 y) { | |||
return x.as_int8() - y.as_int8(); | |||
} | |||
static inline int diff(dt_qint1 x, dt_qint1 y) { | |||
return x.as_int8() - y.as_int8(); | |||
} | |||
static inline int diff(dt_quint4 x, dt_quint4 y) { | |||
return x.as_uint8() - y.as_uint8(); | |||
} | |||
@@ -339,6 +343,10 @@ static inline bool good_float(dt_qint4) { | |||
return true; | |||
} | |||
static inline bool good_float(dt_qint1) { | |||
return true; | |||
} | |||
static inline bool good_float(dt_quint4) { | |||
return true; | |||
} | |||
@@ -373,6 +381,11 @@ static inline int operator+(dt_qint4 lhs, int rhs) { | |||
megdnn_assert(rhs == 0, "unexpected rhs"); | |||
return lhs.as_int8(); | |||
} | |||
static inline int operator+(dt_qint1 lhs, int rhs) { | |||
megdnn_assert(rhs == 0, "unexpected rhs"); | |||
return lhs.as_int8(); | |||
} | |||
} // namespace test | |||
static inline bool operator==(const TensorLayout& a, const TensorLayout& b) { | |||
@@ -77,16 +77,19 @@ TEST_F(CUDA, QUANTIZED_TYPECVT) { | |||
}; | |||
run(dtype::Float32(), dtype::QuantizedS8(3.0f)); | |||
run(dtype::Float32(), dtype::QuantizedS1(3.0f)); | |||
run(dtype::Float16(), dtype::QuantizedS8(3.0f)); | |||
run(dtype::Int32(), dtype::QuantizedS32(5.0f)); | |||
run(dtype::Int8(), dtype::QuantizedS32(10.0f)); | |||
run(dtype::Float32(), dtype::QuantizedS8(2e-3f)); | |||
run(dtype::Float32(), dtype::QuantizedS1(2e-3f)); | |||
run(dtype::Float16(), dtype::QuantizedS8(1e-3f)); | |||
run(dtype::Int32(), dtype::QuantizedS32(1e-3f)); | |||
run(dtype::Int8(), dtype::QuantizedS32(7e-4f)); | |||
run(dtype::QuantizedS8(3.0f), dtype::QuantizedS8(10.0f)); | |||
run(dtype::QuantizedS1(3.0f), dtype::QuantizedS1(10.0f)); | |||
run(dtype::QuantizedS32(3.0f), dtype::QuantizedS8(10.0f)); | |||
run(dtype::QuantizedS8(3.0f), dtype::QuantizedS32(10.0f)); | |||
run(dtype::QuantizedS32(3.0f), dtype::QuantizedS32(10.0f)); | |||
@@ -95,6 +98,7 @@ TEST_F(CUDA, QUANTIZED_TYPECVT) { | |||
run(dtype::QuantizedS32(2e-3f), dtype::QuantizedS8(9e-4f)); | |||
run(dtype::QuantizedS8(9e-4f), dtype::QuantizedS32(7e-4f)); | |||
run(dtype::QuantizedS32(5e-3f), dtype::QuantizedS32(1e-3f)); | |||
run(dtype::QuantizedS1(1e-3f), dtype::Float32()); | |||
run(dtype::Quantized8Asymm(5.0f, (uint8_t)128), dtype::Float32()); | |||
run(dtype::Quantized8Asymm(5.0f, (uint8_t)124), dtype::Float16()); | |||
@@ -94,6 +94,7 @@ _builtin_quant_dtypes = { | |||
"qint8_narrow": QuantDtypeMeta("qint8_narrow", "QuantizedS8", "int8", -127, 127), | |||
"quint4": QuantDtypeMeta("quint4", "Quantized4Asymm", "uint8", 0, 15), | |||
"qint4": QuantDtypeMeta("qint4", "QuantizedS4", "int8", -8, 7), | |||
"qint1": QuantDtypeMeta("qint1", "QuantizedS1", "int8", 0, 1), | |||
"qint32": QuantDtypeMeta( | |||
"qint32", "QuantizedS32", "int32", -(2 ** 31), 2 ** 31 - 1, | |||
), | |||
@@ -192,6 +193,13 @@ def qint4(scale): | |||
return create_quantized_dtype(_builtin_quant_dtypes["qint4"], scale, None) | |||
def qint1(scale): | |||
r"""Construct a quantized int1 data type with ``scale`` (float). The real value | |||
represented by a qint1 data type is float_val = scale * int1_val | |||
""" | |||
return create_quantized_dtype(_builtin_quant_dtypes["qint1"], scale, None) | |||
def _convert_to_quantized_dtype( | |||
arr: np.ndarray, dtype: np.dtype, dtype_meta: QuantDtypeMeta | |||
): | |||
@@ -335,3 +343,22 @@ def convert_from_qint4(arr: np.ndarray): | |||
arr: Input ndarray. | |||
""" | |||
return _convert_from_quantized_dtype(arr, _builtin_quant_dtypes["qint4"]) | |||
def convert_to_qint1(arr: np.ndarray, q: np.dtype): | |||
r"""Quantize a float NumPy ndarray into a qint1 one with specified params. | |||
Args: | |||
arr: Input ndarray. | |||
q: Target data type, should be a qint1. | |||
""" | |||
return _convert_to_quantized_dtype(arr, q, _builtin_quant_dtypes["qint1"]) | |||
def convert_from_qint1(arr: np.ndarray): | |||
r"""Dequantize a qint1 NumPy ndarray into a float one. | |||
Args: | |||
arr: Input ndarray. | |||
""" | |||
return _convert_from_quantized_dtype(arr, _builtin_quant_dtypes["qint1"]) |
@@ -214,6 +214,14 @@ std::unique_ptr<PyArray_Descr, PyArrayDescrDeleter> dtype_mgb2np_descr(DType dty | |||
if (dtype.has_param()) { | |||
PyArray_Descr* type_descr; | |||
switch (dtype.enumv()) { | |||
case DTypeEnum::QuantizedS1: { | |||
auto& param = dtype.param<dtype::QuantizedS1>(); | |||
type_descr = PyArray_DescrNewFromType(NPY_INT8); | |||
type_descr->metadata = build_mgb_dtype_dict( | |||
DTypeTrait<dtype::QuantizedS1>::name, | |||
{{"scale", PyFloat_FromDouble(param.scale)}}); | |||
break; | |||
} | |||
case DTypeEnum::Quantized4Asymm: { | |||
auto& param = dtype.param<dtype::Quantized4Asymm>(); | |||
type_descr = PyArray_DescrNewFromType(NPY_UINT8); | |||
@@ -354,7 +362,7 @@ DType dtype_np2mgb_descr(PyArray_Descr* descr) { | |||
static_cast<uint8_t>(zero_point)); | |||
} | |||
if (dtype_name == "QuantizedS32" || dtype_name == "QuantizedS8" || | |||
dtype_name == "QuantizedS4") { | |||
dtype_name == "QuantizedS4" || dtype_name == "QuantizedS1") { | |||
PyObject* scale_py = PyDict_GetItemString(metadata, "scale"); | |||
mgb_assert(scale_py, "Invalid metadata: missing scale"); | |||
mgb_assert( | |||
@@ -364,8 +372,10 @@ DType dtype_np2mgb_descr(PyArray_Descr* descr) { | |||
return dtype::QuantizedS32(scale); | |||
} else if (dtype_name == "QuantizedS8") { | |||
return dtype::QuantizedS8(scale); | |||
} else { | |||
} else if (dtype_name == "QuantizedS4") { | |||
return dtype::QuantizedS4(scale); | |||
} else if (dtype_name == "QuantizedS1") { | |||
return dtype::QuantizedS1(scale); | |||
} | |||
} | |||
throw ConversionError( | |||
@@ -15,10 +15,12 @@ import megengine.core.tensor.megbrain_graph as G | |||
from megengine.core.ops import builtin as ops | |||
from megengine.core.tensor.dtype import ( | |||
_builtin_quant_dtypes, | |||
convert_from_qint1, | |||
convert_from_qint4, | |||
convert_from_qint8, | |||
convert_from_quint4, | |||
convert_from_quint8, | |||
convert_to_qint1, | |||
convert_to_qint4, | |||
convert_to_qint8, | |||
convert_to_quint4, | |||
@@ -26,6 +28,7 @@ from megengine.core.tensor.dtype import ( | |||
get_scale, | |||
get_zero_point, | |||
is_quantize, | |||
qint1, | |||
qint4, | |||
qint8, | |||
quint4, | |||
@@ -113,9 +116,20 @@ def test_dtype_qint4(): | |||
np.testing.assert_allclose(get_scale(dt), 0.01) | |||
def test_dtype_qint1(): | |||
dt = qint1(0.01) | |||
assert isinstance(dt, np.dtype) | |||
assert "mgb_dtype" in dt.metadata | |||
np.testing.assert_allclose(dt.metadata["mgb_dtype"]["scale"], 0.01) | |||
assert is_quantize(dt) | |||
np.testing.assert_allclose(get_scale(dt), 0.01) | |||
@pytest.mark.parametrize( | |||
"dtype, dtype_name", | |||
[ | |||
(qint1(0.01), "qint1"), | |||
(quint4(0.01, 5), "quint4"), | |||
(qint4(0.01), "qint4"), | |||
(quint8(0.01, 135), "quint8"), | |||
@@ -141,6 +155,7 @@ def test_dtype_qint_mgb_ffi_handle(dtype, dtype_name): | |||
@pytest.mark.parametrize( | |||
"dtype, dtype_name", | |||
[ | |||
(qint1(0.01), "qint1"), | |||
(quint4(0.01, 5), "quint4"), | |||
(qint4(0.01), "qint4"), | |||
(quint8(0.01, 135), "quint8"), | |||
@@ -178,6 +193,7 @@ def test_qint_typecvt(dtype, dtype_name): | |||
@pytest.mark.parametrize( | |||
"dtype, dtype_name", | |||
[ | |||
(qint1(0.01), "qint1"), | |||
(quint4(0.01, 5), "quint4"), | |||
(qint4(0.01), "qint4"), | |||
(quint8(0.01, 135), "quint8"), | |||
@@ -207,6 +223,7 @@ def test_qint_astype(dtype, dtype_name): | |||
@pytest.mark.parametrize( | |||
"dtype, dtype_name", | |||
[ | |||
(qint1(0.01), "qint1"), | |||
(quint4(0.01, 5), "quint4"), | |||
(qint4(0.01), "qint4"), | |||
(quint8(0.01, 135), "quint8"), | |||
@@ -42,6 +42,10 @@ double as_double(megdnn::dt_qint4& a) { | |||
return static_cast<double>(a.as_int8()); | |||
} | |||
template <> | |||
double as_double(megdnn::dt_qint1& a) { | |||
return static_cast<double>(a.as_int8()); | |||
} | |||
template <> | |||
double as_double(megdnn::dt_qint32& a) { | |||
return static_cast<double>(a.as_int32()); | |||
} | |||
@@ -111,7 +115,7 @@ void print_host_val( | |||
MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | |||
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) | |||
MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) | |||
cb(dtype::Bool) | |||
cb(dtype::Bool) cb(::megdnn::dtype::QuantizedS1) | |||
#undef cb | |||
default : mgb_throw( | |||
MegBrainError, | |||
@@ -23,6 +23,7 @@ enum DTypeEnum : byte { | |||
BFloat16, | |||
Bool, | |||
Uint16, | |||
QuantizedS1, | |||
} | |||
table LinearQuantizationParam { | |||
@@ -55,6 +55,8 @@ megdnn::DType load_dtype(const fbs::DType* dtype) { | |||
return dtype::_dt{}; | |||
MEGDNN_FOREACH_DTYPE_NAME(cb) | |||
#undef cb | |||
case DTypeEnum_QuantizedS1: | |||
return dtype::QuantizedS1{param->scale()}; | |||
case DTypeEnum_QuantizedS4: | |||
return dtype::QuantizedS4{param->scale()}; | |||
case DTypeEnum_QuantizedS8: | |||
@@ -113,6 +115,7 @@ flatbuffers::Offset<fbs::DType> build_dtype( | |||
break; | |||
CASE_ASYMMETRIC(Quantized4Asymm) | |||
CASE_ASYMMETRIC(Quantized8Asymm) | |||
CASE_SYMMETRIC(QuantizedS1) | |||
CASE_SYMMETRIC(QuantizedS4) | |||
CASE_SYMMETRIC(QuantizedS8) | |||
CASE_SYMMETRIC(QuantizedS16) | |||